//===- DAGCombiner.cpp - Implement a DAG node combiner --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass combines dag nodes to form fewer, simpler DAG nodes.  It can be run
// both before and after the DAG is legalized.
//
// This pass is not a substitute for the LLVM IR instcombine pass. This pass is
// primarily intended to handle simplification opportunities that are implicit
// in the LLVM IR and exposed by the various codegen lowering phases.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/ByteProvider.h"
#include "llvm/CodeGen/DAGCombine.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineMemOperand.h"
#include "llvm/CodeGen/MachineValueType.h"
#include "llvm/CodeGen/RuntimeLibcalls.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/SelectionDAGTargetInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugCounter.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <iterator>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <variant>

using namespace llvm;

#define DEBUG_TYPE "dagcombine"

STATISTIC(NodesCombined   , "Number of dag nodes combined");
STATISTIC(PreIndexedNodes , "Number of pre-indexed nodes created");
STATISTIC(PostIndexedNodes, "Number of post-indexed nodes created");
STATISTIC(OpsNarrowed     , "Number of load/op/store sequences narrowed");
STATISTIC(LdStFP2Int      , "Number of fp load/store pairs transformed to int");
STATISTIC(SlicedLoads, "Number of loads sliced");
STATISTIC(NumFPLogicOpsConv, "Number of logic ops converted to fp ops");

DEBUG_COUNTER(DAGCombineCounter, "dagcombine",
              "Controls whether a DAG combine is performed for a node");

static cl::opt<bool>
CombinerGlobalAA("combiner-global-alias-analysis", cl::Hidden,
                 cl::desc("Enable DAG combiner's use of IR alias analysis"));

static cl::opt<bool>
UseTBAA("combiner-use-tbaa", cl::Hidden, cl::init(true),
        cl::desc("Enable DAG combiner's use of TBAA"));

#ifndef NDEBUG
static cl::opt<std::string>
CombinerAAOnlyFunc("combiner-aa-only-func", cl::Hidden,
                   cl::desc("Only use DAG-combiner alias analysis in this"
                            " function"));
#endif

/// Hidden option to stress test load slicing, i.e., when this option
/// is enabled, load slicing bypasses most of its profitability guards.
static cl::opt<bool>
StressLoadSlicing("combiner-stress-load-slicing", cl::Hidden,
                  cl::desc("Bypass the profitability model of load slicing"),
                  cl::init(false));

static cl::opt<bool>
  MaySplitLoadIndex("combiner-split-load-index", cl::Hidden, cl::init(true),
                    cl::desc("DAG combiner may split indexing from loads"));

static cl::opt<bool>
    EnableStoreMerging("combiner-store-merging", cl::Hidden, cl::init(true),
                       cl::desc("DAG combiner enable merging multiple stores "
                                "into a wider store"));

static cl::opt<unsigned> TokenFactorInlineLimit(
    "combiner-tokenfactor-inline-limit", cl::Hidden, cl::init(2048),
    cl::desc("Limit the number of operands to inline for Token Factors"));

static cl::opt<unsigned> StoreMergeDependenceLimit(
    "combiner-store-merge-dependence-limit", cl::Hidden, cl::init(10),
    cl::desc("Limit the number of times for the same StoreNode and RootNode "
             "to bail out in store merging dependence check"));

static cl::opt<bool> EnableReduceLoadOpStoreWidth(
    "combiner-reduce-load-op-store-width", cl::Hidden, cl::init(true),
    cl::desc("DAG combiner enable reducing the width of load/op/store "
             "sequence"));

static cl::opt<bool> EnableShrinkLoadReplaceStoreWithStore(
    "combiner-shrink-load-replace-store-with-store", cl::Hidden, cl::init(true),
    cl::desc("DAG combiner enable load/<replace bytes>/store with "
             "a narrower store"));

static cl::opt<bool> EnableVectorFCopySignExtendRound(
    "combiner-vector-fcopysign-extend-round", cl::Hidden, cl::init(false),
    cl::desc(
        "Enable merging extends and rounds into FCOPYSIGN on vector types"));

namespace {

  class DAGCombiner {
    SelectionDAG &DAG;
    const TargetLowering &TLI;
    const SelectionDAGTargetInfo *STI;
    CombineLevel Level = BeforeLegalizeTypes;
    CodeGenOptLevel OptLevel;
    bool LegalDAG = false;
    bool LegalOperations = false;
    bool LegalTypes = false;
    bool ForCodeSize;
    bool DisableGenericCombines;

    /// Worklist of all of the nodes that need to be simplified.
    ///
    /// This must behave as a stack -- new nodes to process are pushed onto the
    /// back and when processing we pop off of the back.
    ///
    /// The worklist will not contain duplicates but may contain null entries
    /// due to nodes being deleted from the underlying DAG.
    SmallVector<SDNode *, 64> Worklist;

    /// Mapping from an SDNode to its position on the worklist.
    ///
    /// This is used to find and remove nodes from the worklist (by nulling
    /// them) when they are deleted from the underlying DAG. It relies on
    /// stable indices of nodes within the worklist.
    DenseMap<SDNode *, unsigned> WorklistMap;

    /// This records all nodes attempted to be added to the worklist since we
    /// considered a new worklist entry. Since we do not add duplicate nodes
    /// to the worklist, this is different from the tail of the worklist.
    SmallSetVector<SDNode *, 32> PruningList;

    /// Set of nodes which have been combined (at least once).
    ///
    /// This is used to allow us to reliably add any operands of a DAG node
    /// which have not yet been combined to the worklist.
    SmallPtrSet<SDNode *, 32> CombinedNodes;

    /// Map from candidate StoreNode to the pair of RootNode and count.
    /// The count is used to track how many times we have seen the StoreNode
    /// with the same RootNode bail out in dependence check. If we have seen
    /// the bail out for the same pair many times over a limit, we won't
    /// consider the StoreNode with the same RootNode as store merging
    /// candidate again.
    DenseMap<SDNode *, std::pair<SDNode *, unsigned>> StoreRootCountMap;

    // AA - Used for DAG load/store alias analysis.
    AliasAnalysis *AA;

    /// When an instruction is simplified, add all users of the instruction to
    /// the worklist because they might get more simplified now.
    void AddUsersToWorklist(SDNode *N) {
      for (SDNode *Node : N->uses())
        AddToWorklist(Node);
    }

    /// Convenient shorthand to add a node and all of its users to the
    /// worklist.
    void AddToWorklistWithUsers(SDNode *N) {
      AddUsersToWorklist(N);
      AddToWorklist(N);
    }

    // Prune potentially dangling nodes. This is called after
    // any visit to a node, but should also be called during a visit after any
    // failed combine which may have created a DAG node.
    void clearAddedDanglingWorklistEntries() {
      // Check any nodes added to the worklist to see if they are prunable.
      while (!PruningList.empty()) {
        auto *N = PruningList.pop_back_val();
        if (N->use_empty())
          recursivelyDeleteUnusedNodes(N);
      }
    }

    SDNode *getNextWorklistEntry() {
      // Before we do any work, remove nodes that are not in use.
      clearAddedDanglingWorklistEntries();
      SDNode *N = nullptr;
      // The Worklist holds the SDNodes in order, but it may contain null
      // entries.
      while (!N && !Worklist.empty()) {
        N = Worklist.pop_back_val();
      }

      if (N) {
        bool GoodWorklistEntry = WorklistMap.erase(N);
        (void)GoodWorklistEntry;
        assert(GoodWorklistEntry &&
               "Found a worklist entry without a corresponding map entry!");
      }
      return N;
    }

    /// Call the node-specific routine that folds each particular type of node.
    SDValue visit(SDNode *N);

  public:
    DAGCombiner(SelectionDAG &D, AliasAnalysis *AA, CodeGenOptLevel OL)
        : DAG(D), TLI(D.getTargetLoweringInfo()),
          STI(D.getSubtarget().getSelectionDAGInfo()), OptLevel(OL), AA(AA) {
      ForCodeSize = DAG.shouldOptForSize();
      DisableGenericCombines = STI && STI->disableGenericCombines(OptLevel);

      MaximumLegalStoreInBits = 0;
      // We use the minimum store size here, since that's all we can guarantee
      // for the scalable vector types.
      for (MVT VT : MVT::all_valuetypes())
        if (EVT(VT).isSimple() && VT != MVT::Other &&
            TLI.isTypeLegal(EVT(VT)) &&
            VT.getSizeInBits().getKnownMinValue() >= MaximumLegalStoreInBits)
          MaximumLegalStoreInBits = VT.getSizeInBits().getKnownMinValue();
    }

    void ConsiderForPruning(SDNode *N) {
      // Mark this for potential pruning.
      PruningList.insert(N);
    }

    /// Add to the worklist, making sure its instance is at the back (next to
    /// be processed).
    void AddToWorklist(SDNode *N, bool IsCandidateForPruning = true) {
      assert(N->getOpcode() != ISD::DELETED_NODE &&
             "Deleted Node added to Worklist");

      // Skip handle nodes as they can't usefully be combined and confuse the
      // zero-use deletion strategy.
      if (N->getOpcode() == ISD::HANDLENODE)
        return;

      if (IsCandidateForPruning)
        ConsiderForPruning(N);

      if (WorklistMap.insert(std::make_pair(N, Worklist.size())).second)
        Worklist.push_back(N);
    }

    /// Remove all instances of N from the worklist.
    void removeFromWorklist(SDNode *N) {
      CombinedNodes.erase(N);
      PruningList.remove(N);
      StoreRootCountMap.erase(N);

      auto It = WorklistMap.find(N);
      if (It == WorklistMap.end())
        return; // Not in the worklist.

      // Null out the entry rather than erasing it to avoid a linear operation.
      Worklist[It->second] = nullptr;
      WorklistMap.erase(It);
    }

    void deleteAndRecombine(SDNode *N);
    bool recursivelyDeleteUnusedNodes(SDNode *N);

    /// Replaces all uses of the results of one DAG node with new values.
    SDValue CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
                      bool AddTo = true);

    /// Replaces all uses of the results of one DAG node with new values.
    SDValue CombineTo(SDNode *N, SDValue Res, bool AddTo = true) {
      return CombineTo(N, &Res, 1, AddTo);
    }

    /// Replaces all uses of the results of one DAG node with new values.
    SDValue CombineTo(SDNode *N, SDValue Res0, SDValue Res1,
                      bool AddTo = true) {
      SDValue To[] = { Res0, Res1 };
      return CombineTo(N, To, 2, AddTo);
    }

    void CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO);

  private:
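    // Largest known-minimum bit width among the target's legal value types,
    // computed once in the constructor above.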
    unsigned MaximumLegalStoreInBits;

    /// Check the specified integer node value to see if it can be simplified or
    /// if things it uses can be simplified by bit propagation.
    /// If so, return true.
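    ///
    /// For example, demanding only the low 8 bits of (and X, 0xFF) lets
    /// SimplifyDemandedBits prove the mask redundant and replace the AND
    /// with X itself.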
    bool SimplifyDemandedBits(SDValue Op) {
      unsigned BitWidth = Op.getScalarValueSizeInBits();
      APInt DemandedBits = APInt::getAllOnes(BitWidth);
      return SimplifyDemandedBits(Op, DemandedBits);
    }

    bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits) {
      TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
      KnownBits Known;
      if (!TLI.SimplifyDemandedBits(Op, DemandedBits, Known, TLO, 0, false))
        return false;

      // Revisit the node.
      AddToWorklist(Op.getNode());

      CommitTargetLoweringOpt(TLO);
      return true;
    }

    /// Check the specified vector node value to see if it can be simplified or
    /// if things it uses can be simplified as it only uses some of the
    /// elements. If so, return true.
    bool SimplifyDemandedVectorElts(SDValue Op) {
      // TODO: For now just pretend it cannot be simplified.
      if (Op.getValueType().isScalableVector())
        return false;

      unsigned NumElts = Op.getValueType().getVectorNumElements();
      APInt DemandedElts = APInt::getAllOnes(NumElts);
      return SimplifyDemandedVectorElts(Op, DemandedElts);
    }

    bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
                              const APInt &DemandedElts,
                              bool AssumeSingleUse = false);
    bool SimplifyDemandedVectorElts(SDValue Op, const APInt &DemandedElts,
                                    bool AssumeSingleUse = false);

    bool CombineToPreIndexedLoadStore(SDNode *N);
    bool CombineToPostIndexedLoadStore(SDNode *N);
    SDValue SplitIndexingFromLoad(LoadSDNode *LD);
    bool SliceUpLoad(SDNode *N);

    // Looks up the chain to find a unique (unaliased) store feeding the passed
    // load. If no such store is found, returns a nullptr.
    // Note: This will look past a CALLSEQ_START if the load is chained to it
    //       so that it can find stack stores for byval params.
    StoreSDNode *getUniqueStoreFeeding(LoadSDNode *LD, int64_t &Offset);
    // Scalars have size 0 to distinguish from singleton vectors.
    SDValue ForwardStoreValueToDirectLoad(LoadSDNode *LD);
    bool getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val);
    bool extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val);

    /// Replace an ISD::EXTRACT_VECTOR_ELT of a load with a narrowed
    ///   load.
    ///
    /// \param EVE ISD::EXTRACT_VECTOR_ELT to be replaced.
    /// \param InVecVT type of the input vector to EVE with bitcasts resolved.
    /// \param EltNo index of the vector element to load.
    /// \param OriginalLoad load that EVE came from to be replaced.
    /// \returns EVE on success, SDValue() on failure.
    SDValue scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
                                         SDValue EltNo,
                                         LoadSDNode *OriginalLoad);
    void ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad);
    SDValue PromoteOperand(SDValue Op, EVT PVT, bool &Replace);
    SDValue SExtPromoteOperand(SDValue Op, EVT PVT);
    SDValue ZExtPromoteOperand(SDValue Op, EVT PVT);
    SDValue PromoteIntBinOp(SDValue Op);
    SDValue PromoteIntShiftOp(SDValue Op);
    SDValue PromoteExtend(SDValue Op);
    bool PromoteLoad(SDValue Op);

    SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
                                SDValue RHS, SDValue True, SDValue False,
                                ISD::CondCode CC);

    /// Call the node-specific routine that knows how to fold each
    /// particular type of node. If that doesn't do anything, try the
    /// target-specific DAG combines.
    SDValue combine(SDNode *N);

    // Visitation implementation - Implement dag node combining for different
    // node types.  The semantics are as follows:
    // Return Value:
    //   SDValue.getNode() == 0 - No change was made
    //   SDValue.getNode() == N - N was replaced, is dead and has been handled.
    //   otherwise              - N should be replaced by the returned Operand.
    //
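    // For example, a visit routine that folds (add X, 0) would simply return
    // N->getOperand(0), indicating that N should be replaced by that operand;
    // returning SDValue() instead would leave N unchanged.
    //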
    SDValue visitTokenFactor(SDNode *N);
    SDValue visitMERGE_VALUES(SDNode *N);
    SDValue visitADD(SDNode *N);
    SDValue visitADDLike(SDNode *N);
    SDValue visitADDLikeCommutative(SDValue N0, SDValue N1, SDNode *LocReference);
    SDValue visitSUB(SDNode *N);
    SDValue visitADDSAT(SDNode *N);
    SDValue visitSUBSAT(SDNode *N);
    SDValue visitADDC(SDNode *N);
    SDValue visitADDO(SDNode *N);
    SDValue visitUADDOLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitSUBC(SDNode *N);
    SDValue visitSUBO(SDNode *N);
    SDValue visitADDE(SDNode *N);
    SDValue visitUADDO_CARRY(SDNode *N);
    SDValue visitSADDO_CARRY(SDNode *N);
    SDValue visitUADDO_CARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
                                 SDNode *N);
    SDValue visitSADDO_CARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
                                 SDNode *N);
    SDValue visitSUBE(SDNode *N);
    SDValue visitUSUBO_CARRY(SDNode *N);
    SDValue visitSSUBO_CARRY(SDNode *N);
    SDValue visitMUL(SDNode *N);
    SDValue visitMULFIX(SDNode *N);
    SDValue useDivRem(SDNode *N);
    SDValue visitSDIV(SDNode *N);
    SDValue visitSDIVLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitUDIV(SDNode *N);
    SDValue visitUDIVLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitREM(SDNode *N);
    SDValue visitMULHU(SDNode *N);
    SDValue visitMULHS(SDNode *N);
    SDValue visitAVG(SDNode *N);
    SDValue visitABD(SDNode *N);
    SDValue visitSMUL_LOHI(SDNode *N);
    SDValue visitUMUL_LOHI(SDNode *N);
    SDValue visitMULO(SDNode *N);
    SDValue visitIMINMAX(SDNode *N);
    SDValue visitAND(SDNode *N);
    SDValue visitANDLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitOR(SDNode *N);
    SDValue visitORLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitXOR(SDNode *N);
    SDValue SimplifyVCastOp(SDNode *N, const SDLoc &DL);
    SDValue SimplifyVBinOp(SDNode *N, const SDLoc &DL);
    SDValue visitSHL(SDNode *N);
    SDValue visitSRA(SDNode *N);
    SDValue visitSRL(SDNode *N);
    SDValue visitFunnelShift(SDNode *N);
    SDValue visitSHLSAT(SDNode *N);
    SDValue visitRotate(SDNode *N);
    SDValue visitABS(SDNode *N);
    SDValue visitBSWAP(SDNode *N);
    SDValue visitBITREVERSE(SDNode *N);
    SDValue visitCTLZ(SDNode *N);
    SDValue visitCTLZ_ZERO_UNDEF(SDNode *N);
    SDValue visitCTTZ(SDNode *N);
    SDValue visitCTTZ_ZERO_UNDEF(SDNode *N);
    SDValue visitCTPOP(SDNode *N);
    SDValue visitSELECT(SDNode *N);
    SDValue visitVSELECT(SDNode *N);
    SDValue visitSELECT_CC(SDNode *N);
    SDValue visitSETCC(SDNode *N);
    SDValue visitSETCCCARRY(SDNode *N);
    SDValue visitSIGN_EXTEND(SDNode *N);
    SDValue visitZERO_EXTEND(SDNode *N);
    SDValue visitANY_EXTEND(SDNode *N);
    SDValue visitAssertExt(SDNode *N);
    SDValue visitAssertAlign(SDNode *N);
    SDValue visitSIGN_EXTEND_INREG(SDNode *N);
    SDValue visitEXTEND_VECTOR_INREG(SDNode *N);
    SDValue visitTRUNCATE(SDNode *N);
    SDValue visitBITCAST(SDNode *N);
    SDValue visitFREEZE(SDNode *N);
    SDValue visitBUILD_PAIR(SDNode *N);
    SDValue visitFADD(SDNode *N);
    SDValue visitVP_FADD(SDNode *N);
    SDValue visitVP_FSUB(SDNode *N);
    SDValue visitSTRICT_FADD(SDNode *N);
    SDValue visitFSUB(SDNode *N);
    SDValue visitFMUL(SDNode *N);
    template <class MatchContextClass> SDValue visitFMA(SDNode *N);
    SDValue visitFMAD(SDNode *N);
    SDValue visitFDIV(SDNode *N);
    SDValue visitFREM(SDNode *N);
    SDValue visitFSQRT(SDNode *N);
    SDValue visitFCOPYSIGN(SDNode *N);
    SDValue visitFPOW(SDNode *N);
    SDValue visitSINT_TO_FP(SDNode *N);
    SDValue visitUINT_TO_FP(SDNode *N);
    SDValue visitFP_TO_SINT(SDNode *N);
    SDValue visitFP_TO_UINT(SDNode *N);
    SDValue visitXRINT(SDNode *N);
    SDValue visitFP_ROUND(SDNode *N);
    SDValue visitFP_EXTEND(SDNode *N);
    SDValue visitFNEG(SDNode *N);
    SDValue visitFABS(SDNode *N);
    SDValue visitFCEIL(SDNode *N);
    SDValue visitFTRUNC(SDNode *N);
    SDValue visitFFREXP(SDNode *N);
    SDValue visitFFLOOR(SDNode *N);
    SDValue visitFMinMax(SDNode *N);
    SDValue visitBRCOND(SDNode *N);
    SDValue visitBR_CC(SDNode *N);
    SDValue visitLOAD(SDNode *N);

    SDValue replaceStoreChain(StoreSDNode *ST, SDValue BetterChain);
    SDValue replaceStoreOfFPConstant(StoreSDNode *ST);
    SDValue replaceStoreOfInsertLoad(StoreSDNode *ST);

    bool refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(SDNode *N);

    SDValue visitSTORE(SDNode *N);
    SDValue visitLIFETIME_END(SDNode *N);
    SDValue visitINSERT_VECTOR_ELT(SDNode *N);
    SDValue visitEXTRACT_VECTOR_ELT(SDNode *N);
    SDValue visitBUILD_VECTOR(SDNode *N);
    SDValue visitCONCAT_VECTORS(SDNode *N);
    SDValue visitEXTRACT_SUBVECTOR(SDNode *N);
    SDValue visitVECTOR_SHUFFLE(SDNode *N);
    SDValue visitSCALAR_TO_VECTOR(SDNode *N);
    SDValue visitINSERT_SUBVECTOR(SDNode *N);
    SDValue visitMLOAD(SDNode *N);
    SDValue visitMSTORE(SDNode *N);
    SDValue visitMGATHER(SDNode *N);
    SDValue visitMSCATTER(SDNode *N);
    SDValue visitVPGATHER(SDNode *N);
    SDValue visitVPSCATTER(SDNode *N);
    SDValue visitVP_STRIDED_LOAD(SDNode *N);
    SDValue visitVP_STRIDED_STORE(SDNode *N);
    SDValue visitFP_TO_FP16(SDNode *N);
    SDValue visitFP16_TO_FP(SDNode *N);
    SDValue visitFP_TO_BF16(SDNode *N);
    SDValue visitBF16_TO_FP(SDNode *N);
    SDValue visitVECREDUCE(SDNode *N);
    SDValue visitVPOp(SDNode *N);
    SDValue visitGET_FPENV_MEM(SDNode *N);
    SDValue visitSET_FPENV_MEM(SDNode *N);

    template <class MatchContextClass>
    SDValue visitFADDForFMACombine(SDNode *N);
    template <class MatchContextClass>
    SDValue visitFSUBForFMACombine(SDNode *N);
    SDValue visitFMULForFMADistributiveCombine(SDNode *N);

    SDValue XformToShuffleWithZero(SDNode *N);
    bool reassociationCanBreakAddressingModePattern(unsigned Opc,
                                                    const SDLoc &DL,
                                                    SDNode *N,
                                                    SDValue N0,
                                                    SDValue N1);
    SDValue reassociateOpsCommutative(unsigned Opc, const SDLoc &DL, SDValue N0,
                                      SDValue N1, SDNodeFlags Flags);
    SDValue reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
                           SDValue N1, SDNodeFlags Flags);
    SDValue reassociateReduction(unsigned RedOpc, unsigned Opc, const SDLoc &DL,
                                 EVT VT, SDValue N0, SDValue N1,
                                 SDNodeFlags Flags = SDNodeFlags());

    SDValue visitShiftByConstant(SDNode *N);

    SDValue foldSelectOfConstants(SDNode *N);
    SDValue foldVSelectOfConstants(SDNode *N);
    SDValue foldBinOpIntoSelect(SDNode *BO);
    bool SimplifySelectOps(SDNode *SELECT, SDValue LHS, SDValue RHS);
    SDValue hoistLogicOpWithSameOpcodeHands(SDNode *N);
    SDValue SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2);
    SDValue SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
                             SDValue N2, SDValue N3, ISD::CondCode CC,
                             bool NotExtCompare = false);
    SDValue convertSelectOfFPConstantsToLoadOffset(
        const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
        ISD::CondCode CC);
    SDValue foldSignChangeInBitcast(SDNode *N);
    SDValue foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0, SDValue N1,
                                   SDValue N2, SDValue N3, ISD::CondCode CC);
    SDValue foldSelectOfBinops(SDNode *N);
    SDValue foldSextSetcc(SDNode *N);
    SDValue foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
                              const SDLoc &DL);
    SDValue foldSubToUSubSat(EVT DstVT, SDNode *N);
    SDValue foldABSToABD(SDNode *N);
    SDValue unfoldMaskedMerge(SDNode *N);
    SDValue unfoldExtremeBitClearingToShifts(SDNode *N);
    SDValue SimplifySetCC(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
                          const SDLoc &DL, bool foldBooleans);
    SDValue rebuildSetCC(SDValue N);

    bool isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
                           SDValue &CC, bool MatchStrict = false) const;
    bool isOneUseSetCC(SDValue N) const;

    SDValue SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
                                         unsigned HiOp);
    SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
    SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
                                 const TargetLowering &TLI);

    SDValue CombineExtLoad(SDNode *N);
    SDValue CombineZExtLogicopShiftLoad(SDNode *N);
    SDValue combineRepeatedFPDivisors(SDNode *N);
    SDValue combineFMulOrFDivWithIntPow2(SDNode *N);
    SDValue mergeInsertEltWithShuffle(SDNode *N, unsigned InsIndex);
    SDValue combineInsertEltToShuffle(SDNode *N, unsigned InsIndex);
    SDValue combineInsertEltToLoad(SDNode *N, unsigned InsIndex);
    SDValue ConstantFoldBITCASTofBUILD_VECTOR(SDNode *, EVT);
    SDValue BuildSDIV(SDNode *N);
    SDValue BuildSDIVPow2(SDNode *N);
    SDValue BuildUDIV(SDNode *N);
    SDValue BuildSREMPow2(SDNode *N);
    SDValue buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N);
    SDValue BuildLogBase2(SDValue V, const SDLoc &DL,
                          bool KnownNeverZero = false,
                          bool InexpensiveOnly = false,
                          std::optional<EVT> OutVT = std::nullopt);
    SDValue BuildDivEstimate(SDValue N, SDValue Op, SDNodeFlags Flags);
    SDValue buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags);
    SDValue buildSqrtEstimate(SDValue Op, SDNodeFlags Flags);
    SDValue buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags, bool Recip);
    SDValue buildSqrtNROneConst(SDValue Arg, SDValue Est, unsigned Iterations,
                                SDNodeFlags Flags, bool Reciprocal);
    SDValue buildSqrtNRTwoConst(SDValue Arg, SDValue Est, unsigned Iterations,
                                SDNodeFlags Flags, bool Reciprocal);
    SDValue MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
                               bool DemandHighBits = true);
    SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
    SDValue MatchRotatePosNeg(SDValue Shifted, SDValue Pos, SDValue Neg,
                              SDValue InnerPos, SDValue InnerNeg, bool HasPos,
                              unsigned PosOpcode, unsigned NegOpcode,
                              const SDLoc &DL);
    SDValue MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos, SDValue Neg,
                              SDValue InnerPos, SDValue InnerNeg, bool HasPos,
                              unsigned PosOpcode, unsigned NegOpcode,
                              const SDLoc &DL);
    SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
    SDValue MatchLoadCombine(SDNode *N);
    SDValue mergeTruncStores(StoreSDNode *N);
    SDValue reduceLoadWidth(SDNode *N);
    SDValue ReduceLoadOpStoreWidth(SDNode *N);
    SDValue splitMergedValStore(StoreSDNode *ST);
    SDValue TransformFPLoadStorePair(SDNode *N);
    SDValue convertBuildVecZextToZext(SDNode *N);
    SDValue convertBuildVecZextToBuildVecWithZeros(SDNode *N);
    SDValue reduceBuildVecExtToExtBuildVec(SDNode *N);
    SDValue reduceBuildVecTruncToBitCast(SDNode *N);
    SDValue reduceBuildVecToShuffle(SDNode *N);
    SDValue createBuildVecShuffle(const SDLoc &DL, SDNode *N,
                                  ArrayRef<int> VectorMask, SDValue VecIn1,
                                  SDValue VecIn2, unsigned LeftIdx,
                                  bool DidSplitVec);
    SDValue matchVSelectOpSizesWithSetCC(SDNode *Cast);

    /// Walk up chain skipping non-aliasing memory nodes,
    /// looking for aliasing nodes and adding them to the Aliases vector.
    void GatherAllAliases(SDNode *N, SDValue OriginalChain,
                          SmallVectorImpl<SDValue> &Aliases);

    /// Return true if there is any possibility that the two addresses overlap.
    bool mayAlias(SDNode *Op0, SDNode *Op1) const;

    /// Walk up chain skipping non-aliasing memory nodes, looking for a better
    /// chain (aliasing node.)
    SDValue FindBetterChain(SDNode *N, SDValue Chain);

    /// Try to replace a store and any possibly adjacent stores on
    /// consecutive chains with better chains. Return true only if St is
    /// replaced.
    ///
    /// Notice that other chains may still be replaced even if the function
    /// returns false.
    bool findBetterNeighborChains(StoreSDNode *St);

    // Helper for findBetterNeighborChains. Walk up the store chain, adding
    // additional chained stores that do not overlap and can be parallelized.
    bool parallelizeChainedStores(StoreSDNode *St);

    /// Holds a pointer to an LSBaseSDNode as well as information on where it
    /// is located in a sequence of memory operations connected by a chain.
    struct MemOpLink {
      // Ptr to the mem node.
      LSBaseSDNode *MemNode;

      // Offset from the base ptr.
      int64_t OffsetFromBase;

      MemOpLink(LSBaseSDNode *N, int64_t Offset)
          : MemNode(N), OffsetFromBase(Offset) {}
    };

    // Classify the origin of a stored value.
    enum class StoreSource { Unknown, Constant, Extract, Load };
    StoreSource getStoreSource(SDValue StoreVal) {
      switch (StoreVal.getOpcode()) {
      case ISD::Constant:
      case ISD::ConstantFP:
        return StoreSource::Constant;
      case ISD::BUILD_VECTOR:
        if (ISD::isBuildVectorOfConstantSDNodes(StoreVal.getNode()) ||
            ISD::isBuildVectorOfConstantFPSDNodes(StoreVal.getNode()))
          return StoreSource::Constant;
        return StoreSource::Unknown;
      case ISD::EXTRACT_VECTOR_ELT:
      case ISD::EXTRACT_SUBVECTOR:
        return StoreSource::Extract;
      case ISD::LOAD:
        return StoreSource::Load;
      default:
        return StoreSource::Unknown;
      }
    }
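    // For example, a store whose value is (extract_vector_elt V, I) is
    // classified as StoreSource::Extract; mergeConsecutiveStores may then
    // turn a run of such scalar stores into a single vector store.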

    /// This is a helper function for visitMUL to check the profitability
    /// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
    /// MulNode is the original multiply, AddNode is (add x, c1),
    /// and ConstNode is c2.
    bool isMulAddWithConstProfitable(SDNode *MulNode, SDValue AddNode,
                                     SDValue ConstNode);

    /// This is a helper function for visitAND and visitZERO_EXTEND.  Returns
    /// true if the (and (load x) c) pattern matches an extload.  ExtVT returns
    /// the type of the loaded value to be extended.
    bool isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
                          EVT LoadResultTy, EVT &ExtVT);

    /// Helper function to calculate whether the given Load/Store can have its
    /// width reduced to ExtVT.
    bool isLegalNarrowLdSt(LSBaseSDNode *LDSTN, ISD::LoadExtType ExtType,
                           EVT &MemVT, unsigned ShAmt = 0);

    /// Used by BackwardsPropagateMask to find suitable loads.
    bool SearchForAndLoads(SDNode *N, SmallVectorImpl<LoadSDNode*> &Loads,
                           SmallPtrSetImpl<SDNode*> &NodesWithConsts,
                           ConstantSDNode *Mask, SDNode *&NodeToMask);
    /// Attempt to propagate a given AND node back to load leaves so that they
    /// can be combined into narrow loads.
    bool BackwardsPropagateMask(SDNode *N);

    /// Helper function for mergeConsecutiveStores which merges the component
    /// store chains.
    SDValue getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
                                unsigned NumStores);

    /// Helper function for mergeConsecutiveStores which checks if all the store
    /// nodes have the same underlying object. We can still reuse the first
    /// store's pointer info if all the stores are from the same object.
    bool hasSameUnderlyingObj(ArrayRef<MemOpLink> StoreNodes);

    /// This is a helper function for mergeConsecutiveStores. When the source
    /// elements of the consecutive stores are all constants or all extracted
    /// vector elements, try to merge them into one larger store introducing
    /// bitcasts if necessary.  \return True if a merged store was created.
    bool mergeStoresOfConstantsOrVecElts(SmallVectorImpl<MemOpLink> &StoreNodes,
                                         EVT MemVT, unsigned NumStores,
                                         bool IsConstantSrc, bool UseVector,
                                         bool UseTrunc);

    /// This is a helper function for mergeConsecutiveStores. Stores that
    /// potentially may be merged with St are placed in StoreNodes. RootNode is
    /// a chain predecessor to all store candidates.
    void getStoreMergeCandidates(StoreSDNode *St,
                                 SmallVectorImpl<MemOpLink> &StoreNodes,
                                 SDNode *&Root);

    /// Helper function for mergeConsecutiveStores. Checks if candidate stores
    /// have indirect dependency through their operands. RootNode is the
    /// predecessor to all stores calculated by getStoreMergeCandidates and is
    /// used to prune the dependency check. \return True if safe to merge.
    bool checkMergeStoreCandidatesForDependencies(
        SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
        SDNode *RootNode);

    /// This is a helper function for mergeConsecutiveStores. Given a list of
    /// store candidates, find the first N that are consecutive in memory.
    /// Returns 0 if there are not at least 2 consecutive stores to try merging.
    unsigned getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
                                  int64_t ElementSizeBytes) const;

    /// This is a helper function for mergeConsecutiveStores. It is used for
    /// store chains that are composed entirely of constant values.
    bool tryStoreMergeOfConstants(SmallVectorImpl<MemOpLink> &StoreNodes,
                                  unsigned NumConsecutiveStores,
                                  EVT MemVT, SDNode *Root, bool AllowVectors);

    /// This is a helper function for mergeConsecutiveStores. It is used for
    /// store chains that are composed entirely of extracted vector elements.
    /// When extracting multiple vector elements, try to store them in one
    /// vector store rather than a sequence of scalar stores.
    bool tryStoreMergeOfExtracts(SmallVectorImpl<MemOpLink> &StoreNodes,
                                 unsigned NumConsecutiveStores, EVT MemVT,
                                 SDNode *Root);

    /// This is a helper function for mergeConsecutiveStores. It is used for
    /// store chains that are composed entirely of loaded values.
    bool tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
                              unsigned NumConsecutiveStores, EVT MemVT,
                              SDNode *Root, bool AllowVectors,
                              bool IsNonTemporalStore, bool IsNonTemporalLoad);

    /// Merge consecutive store operations into a wide store.
    /// This optimization uses wide integers or vectors when possible.
    /// \return true if stores were merged.
    bool mergeConsecutiveStores(StoreSDNode *St);

    /// Try to transform a truncation where C is a constant:
    ///     (trunc (and X, C)) -> (and (trunc X), (trunc C))
    ///
    /// \p N needs to be a truncation and its first operand an AND. Other
    /// requirements are checked by the function (e.g. that trunc is
    /// single-use) and if missed an empty SDValue is returned.
    SDValue distributeTruncateThroughAnd(SDNode *N);

    /// Helper function to determine whether the target supports operation
    /// given by \p Opcode for type \p VT, that is, whether the operation
    /// is legal or custom before legalizing operations, and whether it is
    /// legal (but not custom) after legalization.
    bool hasOperation(unsigned Opcode, EVT VT) {
      return TLI.isOperationLegalOrCustom(Opcode, VT, LegalOperations);
    }

  public:
    /// Runs the dag combiner on all nodes in the work list
    void Run(CombineLevel AtLevel);

    SelectionDAG &getDAG() const { return DAG; }

    /// Returns a type large enough to hold any valid shift amount - before type
    /// legalization these can be huge.
    EVT getShiftAmountTy(EVT LHSTy) {
      assert(LHSTy.isInteger() && "Shift amount is not an integer type!");
      return TLI.getShiftAmountTy(LHSTy, DAG.getDataLayout(), LegalTypes);
    }

    /// This method returns true if we are running before type legalization or
    /// if the specified VT is legal.
    bool isTypeLegal(const EVT &VT) {
      if (!LegalTypes) return true;
      return TLI.isTypeLegal(VT);
    }

    /// Convenience wrapper around TargetLowering::getSetCCResultType
    EVT getSetCCResultType(EVT VT) const {
      return TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
    }

    void ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
                         SDValue OrigLoad, SDValue ExtLoad,
                         ISD::NodeType ExtType);
  };

/// This class is a DAGUpdateListener that removes any deleted
/// nodes from the worklist.
class WorklistRemover : public SelectionDAG::DAGUpdateListener {
  DAGCombiner &DC;

public:
  explicit WorklistRemover(DAGCombiner &dc)
    : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}

  void NodeDeleted(SDNode *N, SDNode *E) override {
    DC.removeFromWorklist(N);
  }
};

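/// This class is a DAGUpdateListener that flags newly inserted nodes as
/// candidates for pruning if they end up unused.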
class WorklistInserter : public SelectionDAG::DAGUpdateListener {
  DAGCombiner &DC;

public:
  explicit WorklistInserter(DAGCombiner &dc)
      : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}

  // FIXME: Ideally we could add N to the worklist, but this causes exponential
  //        compile time costs in large DAGs, e.g. Halide.
  void NodeInserted(SDNode *N) override { DC.ConsiderForPruning(N); }
};

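/// A trivial matching context: matches candidate nodes by raw opcode and
/// creates replacement nodes directly via SelectionDAG::getNode. This is the
/// context the templated combines use for ordinary (non-VP) nodes.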
class EmptyMatchContext {
  SelectionDAG &DAG;
  const TargetLowering &TLI;

public:
  EmptyMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root)
      : DAG(DAG), TLI(TLI) {}

  bool match(SDValue OpN, unsigned Opcode) const {
    return Opcode == OpN->getOpcode();
  }

  // Same as SelectionDAG::getNode().
  template <typename... ArgT> SDValue getNode(ArgT &&...Args) {
    return DAG.getNode(std::forward<ArgT>(Args)...);
  }

  bool isOperationLegalOrCustom(unsigned Op, EVT VT,
                                bool LegalOnly = false) const {
    return TLI.isOperationLegalOrCustom(Op, VT, LegalOnly);
  }
};

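/// Matching context for vector-predicated (VP) nodes: a VP node matches a
/// base opcode when its mask and explicit vector length (EVL) are compatible
/// with the root's, and newly created nodes inherit the root's mask and EVL.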
class VPMatchContext {
  SelectionDAG &DAG;
  const TargetLowering &TLI;
  SDValue RootMaskOp;
  SDValue RootVectorLenOp;

public:
  VPMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root)
      : DAG(DAG), TLI(TLI), RootMaskOp(), RootVectorLenOp() {
    assert(Root->isVPOpcode());
    if (auto RootMaskPos = ISD::getVPMaskIdx(Root->getOpcode()))
      RootMaskOp = Root->getOperand(*RootMaskPos);

    if (auto RootVLenPos =
            ISD::getVPExplicitVectorLengthIdx(Root->getOpcode()))
      RootVectorLenOp = Root->getOperand(*RootVLenPos);
  }

  /// Return true if \p OpVal is a node that is functionally compatible with
  /// the NodeType \p Opc.
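  /// For example, a VP_FADD whose mask is all-ones (or equal to the root's
  /// mask) and whose EVL matches the root's is compatible with ISD::FADD.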
  bool match(SDValue OpVal, unsigned Opc) const {
    if (!OpVal->isVPOpcode())
      return OpVal->getOpcode() == Opc;

    auto BaseOpc = ISD::getBaseOpcodeForVP(OpVal->getOpcode(),
                                           !OpVal->getFlags().hasNoFPExcept());
    if (BaseOpc != Opc)
      return false;

    // Make sure the mask of OpVal is the all-true mask or is the same as
    // Root's.
    unsigned VPOpcode = OpVal->getOpcode();
    if (auto MaskPos = ISD::getVPMaskIdx(VPOpcode)) {
      SDValue MaskOp = OpVal.getOperand(*MaskPos);
      if (RootMaskOp != MaskOp &&
          !ISD::isConstantSplatVectorAllOnes(MaskOp.getNode()))
        return false;
    }

    // Make sure the EVL of OpVal is the same as Root's.
    if (auto VLenPos = ISD::getVPExplicitVectorLengthIdx(VPOpcode))
      if (RootVectorLenOp != OpVal.getOperand(*VLenPos))
        return false;
    return true;
  }

  // Specialize based on number of operands.
  // TODO emit VP intrinsics where MaskOp/VectorLenOp != null
  // SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT) { return
  // DAG.getNode(Opcode, DL, VT); }
  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand) {
    unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
    assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
           ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
    return DAG.getNode(VPOpcode, DL, VT,
                       {Operand, RootMaskOp, RootVectorLenOp});
  }

  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
                  SDValue N2) {
    unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
    assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
           ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
    return DAG.getNode(VPOpcode, DL, VT,
                       {N1, N2, RootMaskOp, RootVectorLenOp});
  }

  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
                  SDValue N2, SDValue N3) {
    unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
    assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
           ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
    return DAG.getNode(VPOpcode, DL, VT,
                       {N1, N2, N3, RootMaskOp, RootVectorLenOp});
  }

  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand,
                  SDNodeFlags Flags) {
    unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
    assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
           ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
    return DAG.getNode(VPOpcode, DL, VT, {Operand, RootMaskOp, RootVectorLenOp},
                       Flags);
  }

  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
                  SDValue N2, SDNodeFlags Flags) {
    unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
    assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
           ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
    return DAG.getNode(VPOpcode, DL, VT, {N1, N2, RootMaskOp, RootVectorLenOp},
                       Flags);
  }

  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
                  SDValue N2, SDValue N3, SDNodeFlags Flags) {
    unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
    assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
           ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
    return DAG.getNode(VPOpcode, DL, VT,
                       {N1, N2, N3, RootMaskOp, RootVectorLenOp}, Flags);
  }

  bool isOperationLegalOrCustom(unsigned Op, EVT VT,
                                bool LegalOnly = false) const {
    unsigned VPOp = ISD::getVPForBaseOpcode(Op);
    return TLI.isOperationLegalOrCustom(VPOp, VT, LegalOnly);
  }
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
//  TargetLowering::DAGCombinerInfo implementation
//===----------------------------------------------------------------------===//

void TargetLowering::DAGCombinerInfo::AddToWorklist(SDNode *N) {
  ((DAGCombiner*)DC)->AddToWorklist(N);
}

SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode *N, ArrayRef<SDValue> To, bool AddTo) {
  return ((DAGCombiner*)DC)->CombineTo(N, &To[0], To.size(), AddTo);
}

SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode *N, SDValue Res, bool AddTo) {
  return ((DAGCombiner*)DC)->CombineTo(N, Res, AddTo);
}

SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode *N, SDValue Res0, SDValue Res1, bool AddTo) {
  return ((DAGCombiner*)DC)->CombineTo(N, Res0, Res1, AddTo);
}

bool TargetLowering::DAGCombinerInfo::
recursivelyDeleteUnusedNodes(SDNode *N) {
  return ((DAGCombiner*)DC)->recursivelyDeleteUnusedNodes(N);
}

void TargetLowering::DAGCombinerInfo::
CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
  return ((DAGCombiner*)DC)->CommitTargetLoweringOpt(TLO);
}

//===----------------------------------------------------------------------===//
// Helper Functions
//===----------------------------------------------------------------------===//

void DAGCombiner::deleteAndRecombine(SDNode *N) {
  removeFromWorklist(N);

  // If the operands of this node are only used by the node, they will now be
  // dead. Make sure to re-visit them and recursively delete dead nodes.
  for (const SDValue &Op : N->ops())
    // For an operand generating multiple values, one of the values may
    // become dead allowing further simplification (e.g. split index
    // arithmetic from an indexed load).
    if (Op->hasOneUse() || Op->getNumValues() > 1)
      AddToWorklist(Op.getNode());

  DAG.DeleteNode(N);
}

// APInts must be the same size for most operations; this helper
// function zero extends the shorter of the pair so that they match.
// We provide an Offset so that we can create bitwidths that won't overflow.
static void zeroExtendToMatch(APInt &LHS, APInt &RHS, unsigned Offset = 0) {
  unsigned Bits = Offset + std::max(LHS.getBitWidth(), RHS.getBitWidth());
  LHS = LHS.zext(Bits);
  RHS = RHS.zext(Bits);
}
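// For example, an 8-bit LHS and a 16-bit RHS with Offset == 1 are both
// zero extended to 17 bits (1 + max(8, 16)).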

// Return true if this node is a setcc, or is a select_cc
// that selects between the target values used for true and false, making it
// equivalent to a setcc. Also, set the incoming LHS, RHS, and CC references to
// the appropriate nodes based on the type of node we are checking. This
// simplifies life a bit for the callers.
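// For example, on a target with ZeroOrOneBooleanContent,
// (select_cc lhs, rhs, 1, 0, cc) is equivalent to (setcc lhs, rhs, cc).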
bool DAGCombiner::isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
                                    SDValue &CC, bool MatchStrict) const {
  if (N.getOpcode() == ISD::SETCC) {
    LHS = N.getOperand(0);
    RHS = N.getOperand(1);
    CC  = N.getOperand(2);
    return true;
  }

  if (MatchStrict &&
      (N.getOpcode() == ISD::STRICT_FSETCC ||
       N.getOpcode() == ISD::STRICT_FSETCCS)) {
    LHS = N.getOperand(1);
    RHS = N.getOperand(2);
    CC  = N.getOperand(3);
    return true;
  }

  if (N.getOpcode() != ISD::SELECT_CC || !TLI.isConstTrueVal(N.getOperand(2)) ||
      !TLI.isConstFalseVal(N.getOperand(3)))
    return false;

  if (TLI.getBooleanContents(N.getValueType()) ==
      TargetLowering::UndefinedBooleanContent)
    return false;

  LHS = N.getOperand(0);
  RHS = N.getOperand(1);
  CC  = N.getOperand(4);
  return true;
}

/// Return true if this is a SetCC-equivalent operation with only one use.
/// If this is true, it allows the users to invert the operation for free when
/// it is profitable to do so.
bool DAGCombiner::isOneUseSetCC(SDValue N) const {
  SDValue N0, N1, N2;
  if (isSetCCEquivalent(N, N0, N1, N2) && N->hasOneUse())
    return true;
  return false;
}

static bool isConstantSplatVectorMaskForType(SDNode *N, EVT ScalarTy) {
  if (!ScalarTy.isSimple())
    return false;

  uint64_t MaskForTy = 0ULL;
  switch (ScalarTy.getSimpleVT().SimpleTy) {
  case MVT::i8:
    MaskForTy = 0xFFULL;
    break;
  case MVT::i16:
    MaskForTy = 0xFFFFULL;
    break;
  case MVT::i32:
    MaskForTy = 0xFFFFFFFFULL;
    break;
  default:
    return false;
  }

  APInt Val;
  if (ISD::isConstantSplatVector(N, Val))
    return Val.getLimitedValue() == MaskForTy;

  return false;
}

// Determines if it is a constant integer or a splat/build vector of constant
// integers (and undefs).
// Do not permit build vector implicit truncation.
static bool isConstantOrConstantVector(SDValue N, bool NoOpaques = false) {
  if (ConstantSDNode *Const = dyn_cast<ConstantSDNode>(N))
    return !(Const->isOpaque() && NoOpaques);
  if (N.getOpcode() != ISD::BUILD_VECTOR && N.getOpcode() != ISD::SPLAT_VECTOR)
    return false;
  unsigned BitWidth = N.getScalarValueSizeInBits();
  for (const SDValue &Op : N->op_values()) {
    if (Op.isUndef())
      continue;
    ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Op);
    if (!Const || Const->getAPIntValue().getBitWidth() != BitWidth ||
        (Const->isOpaque() && NoOpaques))
      return false;
  }
  return true;
}
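// For example, (BUILD_VECTOR 0, 1, undef, 3) qualifies, while a BUILD_VECTOR
// whose constant operands are wider than the vector's element type (implicit
// truncation) does not.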

// Determines if a BUILD_VECTOR is composed of all-constants possibly mixed with
// undef's.
static bool isAnyConstantBuildVector(SDValue V, bool NoOpaques = false) {
  if (V.getOpcode() != ISD::BUILD_VECTOR)
    return false;
  return isConstantOrConstantVector(V, NoOpaques) ||
         ISD::isBuildVectorOfConstantFPSDNodes(V.getNode());
}

// Determine if this is an indexed load whose offset computation may be split
// off, i.e., index splitting is enabled and the index is not an opaque target
// constant.
static bool canSplitIdx(LoadSDNode *LD) {
  return MaySplitLoadIndex &&
         (LD->getOperand(2).getOpcode() != ISD::TargetConstant ||
          !cast<ConstantSDNode>(LD->getOperand(2))->isOpaque());
}
1199 
1200 bool DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc,
1201                                                              const SDLoc &DL,
1202                                                              SDNode *N,
1203                                                              SDValue N0,
1204                                                              SDValue N1) {
1205   // Currently this only tries to ensure we don't undo the GEP splits done by
1206   // CodeGenPrepare when shouldConsiderGEPOffsetSplit is true. To ensure this,
1207   // we check if the following transformation would be problematic:
1208   // (load/store (add, (add, x, offset1), offset2)) ->
1209   // (load/store (add, x, offset1+offset2)).
1210   // We also check whether the analogous reassociation would be problematic:
1211   // (load/store (add, (add, x, y), offset2)) ->
1212   // (load/store (add, (add, x, offset2), y)).
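  // For example, given (add (add x, 8), 16) feeding a load, folding to
  // (add x, 24) is undesirable when x[16] is a legal addressing mode for the
  // target but x[24] is not.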
1213 
1214   if (Opc != ISD::ADD || N0.getOpcode() != ISD::ADD)
1215     return false;
1216 
1217   auto *C2 = dyn_cast<ConstantSDNode>(N1);
1218   if (!C2)
1219     return false;
1220 
1221   const APInt &C2APIntVal = C2->getAPIntValue();
1222   if (C2APIntVal.getSignificantBits() > 64)
1223     return false;
1224 
1225   if (auto *C1 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
1226     if (N0.hasOneUse())
1227       return false;
1228 
1229     const APInt &C1APIntVal = C1->getAPIntValue();
1230     const APInt CombinedValueIntVal = C1APIntVal + C2APIntVal;
1231     if (CombinedValueIntVal.getSignificantBits() > 64)
1232       return false;
1233     const int64_t CombinedValue = CombinedValueIntVal.getSExtValue();
1234 
1235     for (SDNode *Node : N->uses()) {
1236       if (auto *LoadStore = dyn_cast<MemSDNode>(Node)) {
1237         // Is x[offset2] already not a legal addressing mode? If so then
1238         // reassociating the constants breaks nothing (we test offset2 because
1239         // that's the one we hope to fold into the load or store).
1240         TargetLoweringBase::AddrMode AM;
1241         AM.HasBaseReg = true;
1242         AM.BaseOffs = C2APIntVal.getSExtValue();
1243         EVT VT = LoadStore->getMemoryVT();
1244         unsigned AS = LoadStore->getAddressSpace();
1245         Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
1246         if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
1247           continue;
1248 
1249         // Would x[offset1+offset2] still be a legal addressing mode?
1250         AM.BaseOffs = CombinedValue;
1251         if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
1252           return true;
1253       }
1254     }
1255   } else {
1256     if (auto *GA = dyn_cast<GlobalAddressSDNode>(N0.getOperand(1)))
1257       if (GA->getOpcode() == ISD::GlobalAddress && TLI.isOffsetFoldingLegal(GA))
1258         return false;
1259 
1260     for (SDNode *Node : N->uses()) {
1261       auto *LoadStore = dyn_cast<MemSDNode>(Node);
1262       if (!LoadStore)
1263         return false;
1264 
1265       // Is x[offset2] a legal addressing mode? If so, then reassociating
1266       // the constants breaks the address pattern.
1267       TargetLoweringBase::AddrMode AM;
1268       AM.HasBaseReg = true;
1269       AM.BaseOffs = C2APIntVal.getSExtValue();
1270       EVT VT = LoadStore->getMemoryVT();
1271       unsigned AS = LoadStore->getAddressSpace();
1272       Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
1273       if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
1274         return false;
1275     }
1276     return true;
1277   }
1278 
1279   return false;
1280 }
1281 
1282 // Helper for DAGCombiner::reassociateOps. Try to reassociate an expression
1283 // such as (Opc N0, N1), if \p N0 is the same kind of operation as \p Opc.
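// e.g. (add (add x, c1), c2) folds to (add x, c1+c2) when the constants fold,
// and (add (add x, c1), y) can become (add (add x, y), c1) when the target
// considers that profitable.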
1284 SDValue DAGCombiner::reassociateOpsCommutative(unsigned Opc, const SDLoc &DL,
1285                                                SDValue N0, SDValue N1,
1286                                                SDNodeFlags Flags) {
1287   EVT VT = N0.getValueType();
1288 
1289   if (N0.getOpcode() != Opc)
1290     return SDValue();
1291 
1292   SDValue N00 = N0.getOperand(0);
1293   SDValue N01 = N0.getOperand(1);
1294 
1295   if (DAG.isConstantIntBuildVectorOrConstantInt(peekThroughBitcasts(N01))) {
1296     if (DAG.isConstantIntBuildVectorOrConstantInt(peekThroughBitcasts(N1))) {
1297       // Reassociate: (op (op x, c1), c2) -> (op x, (op c1, c2))
1298       if (SDValue OpNode = DAG.FoldConstantArithmetic(Opc, DL, VT, {N01, N1}))
1299         return DAG.getNode(Opc, DL, VT, N00, OpNode);
1300       return SDValue();
1301     }
1302     if (TLI.isReassocProfitable(DAG, N0, N1)) {
1303       // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1)
1304       //              iff (op x, c1) has one use
1305       SDNodeFlags NewFlags;
1306       if (N0.getOpcode() == ISD::ADD && N0->getFlags().hasNoUnsignedWrap() &&
1307           Flags.hasNoUnsignedWrap())
1308         NewFlags.setNoUnsignedWrap(true);
1309       SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N00, N1, NewFlags);
1310       return DAG.getNode(Opc, DL, VT, OpNode, N01, NewFlags);
1311     }
1312   }
1313 
1314   // Check for repeated operand logic simplifications.
1315   if (Opc == ISD::AND || Opc == ISD::OR) {
1316     // (N00 & N01) & N00 --> N00 & N01
1317     // (N00 & N01) & N01 --> N00 & N01
1318     // (N00 | N01) | N00 --> N00 | N01
1319     // (N00 | N01) | N01 --> N00 | N01
1320     if (N1 == N00 || N1 == N01)
1321       return N0;
1322   }
1323   if (Opc == ISD::XOR) {
1324     // (N00 ^ N01) ^ N00 --> N01
1325     if (N1 == N00)
1326       return N01;
1327     // (N00 ^ N01) ^ N01 --> N00
1328     if (N1 == N01)
1329       return N00;
1330   }
1331 
1332   if (TLI.isReassocProfitable(DAG, N0, N1)) {
1333     if (N1 != N01) {
1334       // Reassociate if (op N00, N1) already exists.
1335       if (SDNode *NE = DAG.getNodeIfExists(Opc, DAG.getVTList(VT), {N00, N1})) {
1336         // If (Op (Op N00, N1), N01) already exists, we must stop
1337         // reassociating to avoid an infinite loop.
1338         if (!DAG.doesNodeExist(Opc, DAG.getVTList(VT), {SDValue(NE, 0), N01}))
1339           return DAG.getNode(Opc, DL, VT, SDValue(NE, 0), N01);
1340       }
1341     }
1342 
1343     if (N1 != N00) {
1344       // Reassociate if (op N01, N1) already exists.
1345       if (SDNode *NE = DAG.getNodeIfExists(Opc, DAG.getVTList(VT), {N01, N1})) {
1346         // If (Op (Op N01, N1), N00) already exists, we must stop
1347         // reassociating to avoid an infinite loop.
1348         if (!DAG.doesNodeExist(Opc, DAG.getVTList(VT), {SDValue(NE, 0), N00}))
1349           return DAG.getNode(Opc, DL, VT, SDValue(NE, 0), N00);
1350       }
1351     }
1352 
1353     // Reassociate the operands from (OR/AND (OR/AND(N00, N01)), N1) to
1354     // (OR/AND (OR/AND(N00, N1)), N01) when N00 and N1 are comparisons with
1355     // the same predicate, or to (OR/AND (OR/AND(N1, N01)), N00) when N01 and
1356     // N1 are comparisons with the same predicate. This enables optimizations
1357     // such as:
1358     // CMP(A,C)||CMP(B,C) => CMP(MIN/MAX(A,B), C)
1359     // CMP(A,C)&&CMP(B,C) => CMP(MIN/MAX(A,B), C)
1360     if (Opc == ISD::AND || Opc == ISD::OR) {
1361       if (N1->getOpcode() == ISD::SETCC && N00->getOpcode() == ISD::SETCC &&
1362           N01->getOpcode() == ISD::SETCC) {
1363         ISD::CondCode CC1 = cast<CondCodeSDNode>(N1.getOperand(2))->get();
1364         ISD::CondCode CC00 = cast<CondCodeSDNode>(N00.getOperand(2))->get();
1365         ISD::CondCode CC01 = cast<CondCodeSDNode>(N01.getOperand(2))->get();
1366         if (CC1 == CC00 && CC1 != CC01) {
1367           SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N00, N1, Flags);
1368           return DAG.getNode(Opc, DL, VT, OpNode, N01, Flags);
1369         }
1370         if (CC1 == CC01 && CC1 != CC00) {
1371           SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N01, N1, Flags);
1372           return DAG.getNode(Opc, DL, VT, OpNode, N00, Flags);
1373         }
1374       }
1375     }
1376   }
1377 
1378   return SDValue();
1379 }
1380 
1381 // Try to reassociate commutative binops.
1382 SDValue DAGCombiner::reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
1383                                     SDValue N1, SDNodeFlags Flags) {
1384   assert(TLI.isCommutativeBinOp(Opc) && "Operation not commutative.");
1385 
1386   // Floating-point reassociation is not allowed without loose FP math.
1387   if (N0.getValueType().isFloatingPoint() ||
1388       N1.getValueType().isFloatingPoint())
1389     if (!Flags.hasAllowReassociation() || !Flags.hasNoSignedZeros())
1390       return SDValue();
1391 
1392   if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N0, N1, Flags))
1393     return Combined;
1394   if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N1, N0, Flags))
1395     return Combined;
1396   return SDValue();
1397 }
1398 
1399 // Try to fold Opc(vecreduce(x), vecreduce(y)) -> vecreduce(Opc(x, y))
1400 // Note that we only expect Flags to be passed from FP operations. For integer
1401 // operations they need to be dropped.
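// e.g. (fadd (vecreduce_fadd x), (vecreduce_fadd y))
//        -> (vecreduce_fadd (fadd x, y))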
1402 SDValue DAGCombiner::reassociateReduction(unsigned RedOpc, unsigned Opc,
1403                                           const SDLoc &DL, EVT VT, SDValue N0,
1404                                           SDValue N1, SDNodeFlags Flags) {
1405   if (N0.getOpcode() == RedOpc && N1.getOpcode() == RedOpc &&
1406       N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType() &&
1407       N0->hasOneUse() && N1->hasOneUse() &&
1408       TLI.isOperationLegalOrCustom(Opc, N0.getOperand(0).getValueType()) &&
1409       TLI.shouldReassociateReduction(RedOpc, N0.getOperand(0).getValueType())) {
1410     SelectionDAG::FlagInserter FlagsInserter(DAG, Flags);
1411     return DAG.getNode(RedOpc, DL, VT,
1412                        DAG.getNode(Opc, DL, N0.getOperand(0).getValueType(),
1413                                    N0.getOperand(0), N1.getOperand(0)));
1414   }
1415   return SDValue();
1416 }
1417 
1418 SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
1419                                bool AddTo) {
1420   assert(N->getNumValues() == NumTo && "Broken CombineTo call!");
1421   ++NodesCombined;
1422   LLVM_DEBUG(dbgs() << "\nReplacing.1 "; N->dump(&DAG); dbgs() << "\nWith: ";
1423              To[0].dump(&DAG);
1424              dbgs() << " and " << NumTo - 1 << " other values\n");
1425   for (unsigned i = 0, e = NumTo; i != e; ++i)
1426     assert((!To[i].getNode() ||
1427             N->getValueType(i) == To[i].getValueType()) &&
1428            "Cannot combine value to value of different type!");
1429 
1430   WorklistRemover DeadNodes(*this);
1431   DAG.ReplaceAllUsesWith(N, To);
1432   if (AddTo) {
1433     // Push the new nodes and any users onto the worklist
1434     for (unsigned i = 0, e = NumTo; i != e; ++i) {
1435       if (To[i].getNode())
1436         AddToWorklistWithUsers(To[i].getNode());
1437     }
1438   }
1439 
1440   // Finally, if the node is now dead, remove it from the graph.  The node
1441   // may not be dead if the replacement process recursively simplified to
1442   // something else needing this node.
1443   if (N->use_empty())
1444     deleteAndRecombine(N);
1445   return SDValue(N, 0);
1446 }
1447 
1448 void DAGCombiner::
1449 CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
1450   // Replace the old value with the new one.
1451   ++NodesCombined;
1452   LLVM_DEBUG(dbgs() << "\nReplacing.2 "; TLO.Old.dump(&DAG);
1453              dbgs() << "\nWith: "; TLO.New.dump(&DAG); dbgs() << '\n');
1454 
1455   // Replace all uses.
1456   DAG.ReplaceAllUsesOfValueWith(TLO.Old, TLO.New);
1457 
1458   // Push the new node and any (possibly new) users onto the worklist.
1459   AddToWorklistWithUsers(TLO.New.getNode());
1460 
1461   // Finally, if the node is now dead, remove it from the graph.
1462   recursivelyDeleteUnusedNodes(TLO.Old.getNode());
1463 }
1464 
1465 /// Check the specified integer node value to see if it can be simplified or if
1466 /// things it uses can be simplified by bit propagation. If so, return true.
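/// e.g. if only the low byte of (and x, 0xFF00FF) is demanded, the mask may
/// be shrunk to 0xFF.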
1467 bool DAGCombiner::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
1468                                        const APInt &DemandedElts,
1469                                        bool AssumeSingleUse) {
1470   TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
1471   KnownBits Known;
1472   if (!TLI.SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO, 0,
1473                                 AssumeSingleUse))
1474     return false;
1475 
1476   // Revisit the node.
1477   AddToWorklist(Op.getNode());
1478 
1479   CommitTargetLoweringOpt(TLO);
1480   return true;
1481 }
1482 
1483 /// Check the specified vector node value to see if it can be simplified or
1484 /// if things it uses can be simplified as it only uses some of the elements.
1485 /// If so, return true.
1486 bool DAGCombiner::SimplifyDemandedVectorElts(SDValue Op,
1487                                              const APInt &DemandedElts,
1488                                              bool AssumeSingleUse) {
1489   TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
1490   APInt KnownUndef, KnownZero;
1491   if (!TLI.SimplifyDemandedVectorElts(Op, DemandedElts, KnownUndef, KnownZero,
1492                                       TLO, 0, AssumeSingleUse))
1493     return false;
1494 
1495   // Revisit the node.
1496   AddToWorklist(Op.getNode());
1497 
1498   CommitTargetLoweringOpt(TLO);
1499   return true;
1500 }
1501 
1502 void DAGCombiner::ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad) {
1503   SDLoc DL(Load);
1504   EVT VT = Load->getValueType(0);
1505   SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, VT, SDValue(ExtLoad, 0));
1506 
1507   LLVM_DEBUG(dbgs() << "\nReplacing.9 "; Load->dump(&DAG); dbgs() << "\nWith: ";
1508              Trunc.dump(&DAG); dbgs() << '\n');
1509 
1510   DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), Trunc);
1511   DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), SDValue(ExtLoad, 1));
1512 
1513   AddToWorklist(Trunc.getNode());
1514   recursivelyDeleteUnusedNodes(Load);
1515 }
1516 
1517 SDValue DAGCombiner::PromoteOperand(SDValue Op, EVT PVT, bool &Replace) {
1518   Replace = false;
1519   SDLoc DL(Op);
1520   if (ISD::isUNINDEXEDLoad(Op.getNode())) {
1521     LoadSDNode *LD = cast<LoadSDNode>(Op);
1522     EVT MemVT = LD->getMemoryVT();
1523     ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) ? ISD::EXTLOAD
1524                                                       : LD->getExtensionType();
1525     Replace = true;
1526     return DAG.getExtLoad(ExtType, DL, PVT,
1527                           LD->getChain(), LD->getBasePtr(),
1528                           MemVT, LD->getMemOperand());
1529   }
1530 
1531   unsigned Opc = Op.getOpcode();
1532   switch (Opc) {
1533   default: break;
1534   case ISD::AssertSext:
1535     if (SDValue Op0 = SExtPromoteOperand(Op.getOperand(0), PVT))
1536       return DAG.getNode(ISD::AssertSext, DL, PVT, Op0, Op.getOperand(1));
1537     break;
1538   case ISD::AssertZext:
1539     if (SDValue Op0 = ZExtPromoteOperand(Op.getOperand(0), PVT))
1540       return DAG.getNode(ISD::AssertZext, DL, PVT, Op0, Op.getOperand(1));
1541     break;
1542   case ISD::Constant: {
1543     unsigned ExtOpc =
1544       Op.getValueType().isByteSized() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
1545     return DAG.getNode(ExtOpc, DL, PVT, Op);
1546   }
1547   }
1548 
1549   if (!TLI.isOperationLegal(ISD::ANY_EXTEND, PVT))
1550     return SDValue();
1551   return DAG.getNode(ISD::ANY_EXTEND, DL, PVT, Op);
1552 }
1553 
1554 SDValue DAGCombiner::SExtPromoteOperand(SDValue Op, EVT PVT) {
1555   if (!TLI.isOperationLegal(ISD::SIGN_EXTEND_INREG, PVT))
1556     return SDValue();
1557   EVT OldVT = Op.getValueType();
1558   SDLoc DL(Op);
1559   bool Replace = false;
1560   SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1561   if (!NewOp.getNode())
1562     return SDValue();
1563   AddToWorklist(NewOp.getNode());
1564 
1565   if (Replace)
1566     ReplaceLoadWithPromotedLoad(Op.getNode(), NewOp.getNode());
1567   return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, NewOp.getValueType(), NewOp,
1568                      DAG.getValueType(OldVT));
1569 }
1570 
1571 SDValue DAGCombiner::ZExtPromoteOperand(SDValue Op, EVT PVT) {
1572   EVT OldVT = Op.getValueType();
1573   SDLoc DL(Op);
1574   bool Replace = false;
1575   SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1576   if (!NewOp.getNode())
1577     return SDValue();
1578   AddToWorklist(NewOp.getNode());
1579 
1580   if (Replace)
1581     ReplaceLoadWithPromotedLoad(Op.getNode(), NewOp.getNode());
1582   return DAG.getZeroExtendInReg(NewOp, DL, OldVT);
1583 }
1584 
1585 /// Promote the specified integer binary operation if the target indicates it is
1586 /// beneficial. e.g. On x86, it's usually better to promote i16 operations to
1587 /// i32 since i16 instructions are longer.
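/// e.g. (i16 (add x, y)) becomes
/// (i16 (trunc (i32 (add (any_extend x), (any_extend y))))).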
1588 SDValue DAGCombiner::PromoteIntBinOp(SDValue Op) {
1589   if (!LegalOperations)
1590     return SDValue();
1591 
1592   EVT VT = Op.getValueType();
1593   if (VT.isVector() || !VT.isInteger())
1594     return SDValue();
1595 
1596   // If operation type is 'undesirable', e.g. i16 on x86, consider
1597   // promoting it.
1598   unsigned Opc = Op.getOpcode();
1599   if (TLI.isTypeDesirableForOp(Opc, VT))
1600     return SDValue();
1601 
1602   EVT PVT = VT;
1603   // Consult the target about whether it is a good idea to promote this
1604   // operation and what type to promote it to.
1605   if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1606     assert(PVT != VT && "Don't know what type to promote to!");
1607 
1608     LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1609 
1610     bool Replace0 = false;
1611     SDValue N0 = Op.getOperand(0);
1612     SDValue NN0 = PromoteOperand(N0, PVT, Replace0);
1613 
1614     bool Replace1 = false;
1615     SDValue N1 = Op.getOperand(1);
1616     SDValue NN1 = PromoteOperand(N1, PVT, Replace1);
1617     SDLoc DL(Op);
1618 
1619     SDValue RV =
1620         DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, NN0, NN1));
1621 
1622     // We are always replacing N0/N1's use in N and only need additional
1623     // replacements if there are additional uses.
1624     // Note: We are checking uses of the *nodes* (SDNode) rather than values
1625     //       (SDValue) here because the node may reference multiple values
1626     //       (for example, the chain value of a load node).
1627     Replace0 &= !N0->hasOneUse();
1628     Replace1 &= (N0 != N1) && !N1->hasOneUse();
1629 
1630     // Combine Op here so it is preserved past replacements.
1631     CombineTo(Op.getNode(), RV);
1632 
1633     // If the operands have a use ordering, make sure we deal with the
1634     // predecessor first.
1635     if (Replace0 && Replace1 && N0->isPredecessorOf(N1.getNode())) {
1636       std::swap(N0, N1);
1637       std::swap(NN0, NN1);
1638     }
1639 
1640     if (Replace0) {
1641       AddToWorklist(NN0.getNode());
1642       ReplaceLoadWithPromotedLoad(N0.getNode(), NN0.getNode());
1643     }
1644     if (Replace1) {
1645       AddToWorklist(NN1.getNode());
1646       ReplaceLoadWithPromotedLoad(N1.getNode(), NN1.getNode());
1647     }
1648     return Op;
1649   }
1650   return SDValue();
1651 }
1652 
1653 /// Promote the specified integer shift operation if the target indicates it is
1654 /// beneficial. e.g. On x86, it's usually better to promote i16 operations to
1655 /// i32 since i16 instructions are longer.
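/// e.g. an i16 SRA is performed as an i32 SRA of the sign-extended operand,
/// with the result truncated back to i16.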
1656 SDValue DAGCombiner::PromoteIntShiftOp(SDValue Op) {
1657   if (!LegalOperations)
1658     return SDValue();
1659 
1660   EVT VT = Op.getValueType();
1661   if (VT.isVector() || !VT.isInteger())
1662     return SDValue();
1663 
1664   // If operation type is 'undesirable', e.g. i16 on x86, consider
1665   // promoting it.
1666   unsigned Opc = Op.getOpcode();
1667   if (TLI.isTypeDesirableForOp(Opc, VT))
1668     return SDValue();
1669 
1670   EVT PVT = VT;
1671   // Consult the target about whether it is a good idea to promote this
1672   // operation and what type to promote it to.
1673   if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1674     assert(PVT != VT && "Don't know what type to promote to!");
1675 
1676     LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1677 
1678     bool Replace = false;
1679     SDValue N0 = Op.getOperand(0);
1680     if (Opc == ISD::SRA)
1681       N0 = SExtPromoteOperand(N0, PVT);
1682     else if (Opc == ISD::SRL)
1683       N0 = ZExtPromoteOperand(N0, PVT);
1684     else
1685       N0 = PromoteOperand(N0, PVT, Replace);
1686 
1687     if (!N0.getNode())
1688       return SDValue();
1689 
1690     SDLoc DL(Op);
1691     SDValue N1 = Op.getOperand(1);
1692     SDValue RV =
1693         DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, N0, N1));
1694 
1695     if (Replace)
1696       ReplaceLoadWithPromotedLoad(Op.getOperand(0).getNode(), N0.getNode());
1697 
1698     // Deal with Op being deleted.
1699     if (Op && Op.getOpcode() != ISD::DELETED_NODE)
1700       return RV;
1701   }
1702   return SDValue();
1703 }
1704 
1705 SDValue DAGCombiner::PromoteExtend(SDValue Op) {
1706   if (!LegalOperations)
1707     return SDValue();
1708 
1709   EVT VT = Op.getValueType();
1710   if (VT.isVector() || !VT.isInteger())
1711     return SDValue();
1712 
1713   // If operation type is 'undesirable', e.g. i16 on x86, consider
1714   // promoting it.
1715   unsigned Opc = Op.getOpcode();
1716   if (TLI.isTypeDesirableForOp(Opc, VT))
1717     return SDValue();
1718 
1719   EVT PVT = VT;
1720   // Consult the target about whether it is a good idea to promote this
1721   // operation and what type to promote it to.
1722   if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1723     assert(PVT != VT && "Don't know what type to promote to!");
1724     // fold (aext (aext x)) -> (aext x)
1725     // fold (aext (zext x)) -> (zext x)
1726     // fold (aext (sext x)) -> (sext x)
1727     LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1728     return DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, Op.getOperand(0));
1729   }
1730   return SDValue();
1731 }
1732 
1733 bool DAGCombiner::PromoteLoad(SDValue Op) {
1734   if (!LegalOperations)
1735     return false;
1736 
1737   if (!ISD::isUNINDEXEDLoad(Op.getNode()))
1738     return false;
1739 
1740   EVT VT = Op.getValueType();
1741   if (VT.isVector() || !VT.isInteger())
1742     return false;
1743 
1744   // If operation type is 'undesirable', e.g. i16 on x86, consider
1745   // promoting it.
1746   unsigned Opc = Op.getOpcode();
1747   if (TLI.isTypeDesirableForOp(Opc, VT))
1748     return false;
1749 
1750   EVT PVT = VT;
1751   // Consult the target about whether it is a good idea to promote this
1752   // operation and what type to promote it to.
1753   if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1754     assert(PVT != VT && "Don't know what type to promote to!");
1755 
1756     SDLoc DL(Op);
1757     SDNode *N = Op.getNode();
1758     LoadSDNode *LD = cast<LoadSDNode>(N);
1759     EVT MemVT = LD->getMemoryVT();
1760     ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) ? ISD::EXTLOAD
1761                                                       : LD->getExtensionType();
1762     SDValue NewLD = DAG.getExtLoad(ExtType, DL, PVT,
1763                                    LD->getChain(), LD->getBasePtr(),
1764                                    MemVT, LD->getMemOperand());
1765     SDValue Result = DAG.getNode(ISD::TRUNCATE, DL, VT, NewLD);
1766 
1767     LLVM_DEBUG(dbgs() << "\nPromoting "; N->dump(&DAG); dbgs() << "\nTo: ";
1768                Result.dump(&DAG); dbgs() << '\n');
1769 
1770     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result);
1771     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), NewLD.getValue(1));
1772 
1773     AddToWorklist(Result.getNode());
1774     recursivelyDeleteUnusedNodes(N);
1775     return true;
1776   }
1777 
1778   return false;
1779 }
1780 
1781 /// Recursively delete a node which has no uses and any operands for
1782 /// which it is the only use.
1783 ///
1784 /// Note that this both deletes the nodes and removes them from the worklist.
1785 /// It also adds any nodes that have had a user deleted to the worklist, as
1786 /// they may now have only one use and be subject to other combines.
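///
/// e.g. deleting an unused (add (mul x, y), z) also deletes the mul if the add
/// was its only user, while z (still used elsewhere) is re-added to the
/// worklist.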
1787 bool DAGCombiner::recursivelyDeleteUnusedNodes(SDNode *N) {
1788   if (!N->use_empty())
1789     return false;
1790 
1791   SmallSetVector<SDNode *, 16> Nodes;
1792   Nodes.insert(N);
1793   do {
1794     N = Nodes.pop_back_val();
1795     if (!N)
1796       continue;
1797 
1798     if (N->use_empty()) {
1799       for (const SDValue &ChildN : N->op_values())
1800         Nodes.insert(ChildN.getNode());
1801 
1802       removeFromWorklist(N);
1803       DAG.DeleteNode(N);
1804     } else {
1805       AddToWorklist(N);
1806     }
1807   } while (!Nodes.empty());
1808   return true;
1809 }
1810 
1811 //===----------------------------------------------------------------------===//
1812 //  Main DAG Combiner implementation
1813 //===----------------------------------------------------------------------===//
1814 
1815 void DAGCombiner::Run(CombineLevel AtLevel) {
1816   // Set the instance variables so the various visit routines can use them.
1817   Level = AtLevel;
1818   LegalDAG = Level >= AfterLegalizeDAG;
1819   LegalOperations = Level >= AfterLegalizeVectorOps;
1820   LegalTypes = Level >= AfterLegalizeTypes;
1821 
1822   WorklistInserter AddNodes(*this);
1823 
1824   // Add all the dag nodes to the worklist.
1825   //
1826   // Note: not all nodes are added to the PruningList here. Only nodes that
1827   // have no uses can be deleted, and all other nodes that would otherwise be
1828   // added to the worklist by the first call to getNextWorklistEntry are
1829   // already present in it.
1830   for (SDNode &Node : DAG.allnodes())
1831     AddToWorklist(&Node, /* IsCandidateForPruning */ Node.use_empty());
1832 
1833   // Create a dummy node (which is not added to allnodes), that adds a reference
1834   // to the root node, preventing it from being deleted, and tracking any
1835   // changes of the root.
1836   HandleSDNode Dummy(DAG.getRoot());
1837 
1838   // While we have a valid worklist entry node, try to combine it.
1839   while (SDNode *N = getNextWorklistEntry()) {
1840     // If N has no uses, it is dead.  Make sure to revisit all N's operands once
1841     // N is deleted from the DAG, since they too may now be dead or may have a
1842     // reduced number of uses, allowing other xforms.
1843     if (recursivelyDeleteUnusedNodes(N))
1844       continue;
1845 
1846     WorklistRemover DeadNodes(*this);
1847 
1848     // If this combine is running after legalizing the DAG, re-legalize any
1849     // nodes pulled off the worklist.
1850     if (LegalDAG) {
1851       SmallSetVector<SDNode *, 16> UpdatedNodes;
1852       bool NIsValid = DAG.LegalizeOp(N, UpdatedNodes);
1853 
1854       for (SDNode *LN : UpdatedNodes)
1855         AddToWorklistWithUsers(LN);
1856 
1857       if (!NIsValid)
1858         continue;
1859     }
1860 
1861     LLVM_DEBUG(dbgs() << "\nCombining: "; N->dump(&DAG));
1862 
1863     // Add any operands of the new node which have not yet been combined to the
1864     // worklist as well. Because the worklist uniques things already, this
1865     // won't repeatedly process the same operand.
1866     for (const SDValue &ChildN : N->op_values())
1867       if (!CombinedNodes.count(ChildN.getNode()))
1868         AddToWorklist(ChildN.getNode());
1869 
1870     CombinedNodes.insert(N);
1871     SDValue RV = combine(N);
1872 
1873     if (!RV.getNode())
1874       continue;
1875 
1876     ++NodesCombined;
1877 
1878     // If we get back the same node we passed in, rather than a new node or
1879     // zero, we know that the node must have defined multiple values and
1880     // CombineTo was used.  Since CombineTo takes care of the worklist
1881     // mechanics for us, we have no work to do in this case.
1882     if (RV.getNode() == N)
1883       continue;
1884 
1885     assert(N->getOpcode() != ISD::DELETED_NODE &&
1886            RV.getOpcode() != ISD::DELETED_NODE &&
1887            "Node was deleted but visit returned new node!");
1888 
1889     LLVM_DEBUG(dbgs() << " ... into: "; RV.dump(&DAG));
1890 
1891     if (N->getNumValues() == RV->getNumValues())
1892       DAG.ReplaceAllUsesWith(N, RV.getNode());
1893     else {
1894       assert(N->getValueType(0) == RV.getValueType() &&
1895              N->getNumValues() == 1 && "Type mismatch");
1896       DAG.ReplaceAllUsesWith(N, &RV);
1897     }
1898 
1899     // Push the new node and any users onto the worklist.  Omit this if the
1900     // new node is the EntryToken (e.g. if a store managed to get optimized
1901     // out), because re-visiting the EntryToken and its users will not uncover
1902     // any additional opportunities, but there may be a large number of such
1903     // users, potentially causing compile time explosion.
1904     if (RV.getOpcode() != ISD::EntryToken)
1905       AddToWorklistWithUsers(RV.getNode());
1906 
1907     // Finally, if the node is now dead, remove it from the graph.  The node
1908     // may not be dead if the replacement process recursively simplified to
1909     // something else needing this node. This will also take care of adding any
1910     // operands which have lost a user to the worklist.
1911     recursivelyDeleteUnusedNodes(N);
1912   }
1913 
1914   // If the root changed (e.g. it was a dead load), update the root.
1915   DAG.setRoot(Dummy.getValue());
1916   DAG.RemoveDeadNodes();
1917 }
1918 
1919 SDValue DAGCombiner::visit(SDNode *N) {
1920   // clang-format off
1921   switch (N->getOpcode()) {
1922   default: break;
1923   case ISD::TokenFactor:        return visitTokenFactor(N);
1924   case ISD::MERGE_VALUES:       return visitMERGE_VALUES(N);
1925   case ISD::ADD:                return visitADD(N);
1926   case ISD::SUB:                return visitSUB(N);
1927   case ISD::SADDSAT:
1928   case ISD::UADDSAT:            return visitADDSAT(N);
1929   case ISD::SSUBSAT:
1930   case ISD::USUBSAT:            return visitSUBSAT(N);
1931   case ISD::ADDC:               return visitADDC(N);
1932   case ISD::SADDO:
1933   case ISD::UADDO:              return visitADDO(N);
1934   case ISD::SUBC:               return visitSUBC(N);
1935   case ISD::SSUBO:
1936   case ISD::USUBO:              return visitSUBO(N);
1937   case ISD::ADDE:               return visitADDE(N);
1938   case ISD::UADDO_CARRY:        return visitUADDO_CARRY(N);
1939   case ISD::SADDO_CARRY:        return visitSADDO_CARRY(N);
1940   case ISD::SUBE:               return visitSUBE(N);
1941   case ISD::USUBO_CARRY:        return visitUSUBO_CARRY(N);
1942   case ISD::SSUBO_CARRY:        return visitSSUBO_CARRY(N);
1943   case ISD::SMULFIX:
1944   case ISD::SMULFIXSAT:
1945   case ISD::UMULFIX:
1946   case ISD::UMULFIXSAT:         return visitMULFIX(N);
1947   case ISD::MUL:                return visitMUL(N);
1948   case ISD::SDIV:               return visitSDIV(N);
1949   case ISD::UDIV:               return visitUDIV(N);
1950   case ISD::SREM:
1951   case ISD::UREM:               return visitREM(N);
1952   case ISD::MULHU:              return visitMULHU(N);
1953   case ISD::MULHS:              return visitMULHS(N);
1954   case ISD::AVGFLOORS:
1955   case ISD::AVGFLOORU:
1956   case ISD::AVGCEILS:
1957   case ISD::AVGCEILU:           return visitAVG(N);
1958   case ISD::ABDS:
1959   case ISD::ABDU:               return visitABD(N);
1960   case ISD::SMUL_LOHI:          return visitSMUL_LOHI(N);
1961   case ISD::UMUL_LOHI:          return visitUMUL_LOHI(N);
1962   case ISD::SMULO:
1963   case ISD::UMULO:              return visitMULO(N);
1964   case ISD::SMIN:
1965   case ISD::SMAX:
1966   case ISD::UMIN:
1967   case ISD::UMAX:               return visitIMINMAX(N);
1968   case ISD::AND:                return visitAND(N);
1969   case ISD::OR:                 return visitOR(N);
1970   case ISD::XOR:                return visitXOR(N);
1971   case ISD::SHL:                return visitSHL(N);
1972   case ISD::SRA:                return visitSRA(N);
1973   case ISD::SRL:                return visitSRL(N);
1974   case ISD::ROTR:
1975   case ISD::ROTL:               return visitRotate(N);
1976   case ISD::FSHL:
1977   case ISD::FSHR:               return visitFunnelShift(N);
1978   case ISD::SSHLSAT:
1979   case ISD::USHLSAT:            return visitSHLSAT(N);
1980   case ISD::ABS:                return visitABS(N);
1981   case ISD::BSWAP:              return visitBSWAP(N);
1982   case ISD::BITREVERSE:         return visitBITREVERSE(N);
1983   case ISD::CTLZ:               return visitCTLZ(N);
1984   case ISD::CTLZ_ZERO_UNDEF:    return visitCTLZ_ZERO_UNDEF(N);
1985   case ISD::CTTZ:               return visitCTTZ(N);
1986   case ISD::CTTZ_ZERO_UNDEF:    return visitCTTZ_ZERO_UNDEF(N);
1987   case ISD::CTPOP:              return visitCTPOP(N);
1988   case ISD::SELECT:             return visitSELECT(N);
1989   case ISD::VSELECT:            return visitVSELECT(N);
1990   case ISD::SELECT_CC:          return visitSELECT_CC(N);
1991   case ISD::SETCC:              return visitSETCC(N);
1992   case ISD::SETCCCARRY:         return visitSETCCCARRY(N);
1993   case ISD::SIGN_EXTEND:        return visitSIGN_EXTEND(N);
1994   case ISD::ZERO_EXTEND:        return visitZERO_EXTEND(N);
1995   case ISD::ANY_EXTEND:         return visitANY_EXTEND(N);
1996   case ISD::AssertSext:
1997   case ISD::AssertZext:         return visitAssertExt(N);
1998   case ISD::AssertAlign:        return visitAssertAlign(N);
1999   case ISD::SIGN_EXTEND_INREG:  return visitSIGN_EXTEND_INREG(N);
2000   case ISD::SIGN_EXTEND_VECTOR_INREG:
2001   case ISD::ZERO_EXTEND_VECTOR_INREG:
2002   case ISD::ANY_EXTEND_VECTOR_INREG: return visitEXTEND_VECTOR_INREG(N);
2003   case ISD::TRUNCATE:           return visitTRUNCATE(N);
2004   case ISD::BITCAST:            return visitBITCAST(N);
2005   case ISD::BUILD_PAIR:         return visitBUILD_PAIR(N);
2006   case ISD::FADD:               return visitFADD(N);
2007   case ISD::STRICT_FADD:        return visitSTRICT_FADD(N);
2008   case ISD::FSUB:               return visitFSUB(N);
2009   case ISD::FMUL:               return visitFMUL(N);
2010   case ISD::FMA:                return visitFMA<EmptyMatchContext>(N);
2011   case ISD::FMAD:               return visitFMAD(N);
2012   case ISD::FDIV:               return visitFDIV(N);
2013   case ISD::FREM:               return visitFREM(N);
2014   case ISD::FSQRT:              return visitFSQRT(N);
2015   case ISD::FCOPYSIGN:          return visitFCOPYSIGN(N);
2016   case ISD::FPOW:               return visitFPOW(N);
2017   case ISD::SINT_TO_FP:         return visitSINT_TO_FP(N);
2018   case ISD::UINT_TO_FP:         return visitUINT_TO_FP(N);
2019   case ISD::FP_TO_SINT:         return visitFP_TO_SINT(N);
2020   case ISD::FP_TO_UINT:         return visitFP_TO_UINT(N);
2021   case ISD::LRINT:
2022   case ISD::LLRINT:             return visitXRINT(N);
2023   case ISD::FP_ROUND:           return visitFP_ROUND(N);
2024   case ISD::FP_EXTEND:          return visitFP_EXTEND(N);
2025   case ISD::FNEG:               return visitFNEG(N);
2026   case ISD::FABS:               return visitFABS(N);
2027   case ISD::FFLOOR:             return visitFFLOOR(N);
2028   case ISD::FMINNUM:
2029   case ISD::FMAXNUM:
2030   case ISD::FMINIMUM:
2031   case ISD::FMAXIMUM:           return visitFMinMax(N);
2032   case ISD::FCEIL:              return visitFCEIL(N);
2033   case ISD::FTRUNC:             return visitFTRUNC(N);
2034   case ISD::FFREXP:             return visitFFREXP(N);
2035   case ISD::BRCOND:             return visitBRCOND(N);
2036   case ISD::BR_CC:              return visitBR_CC(N);
2037   case ISD::LOAD:               return visitLOAD(N);
2038   case ISD::STORE:              return visitSTORE(N);
2039   case ISD::INSERT_VECTOR_ELT:  return visitINSERT_VECTOR_ELT(N);
2040   case ISD::EXTRACT_VECTOR_ELT: return visitEXTRACT_VECTOR_ELT(N);
2041   case ISD::BUILD_VECTOR:       return visitBUILD_VECTOR(N);
2042   case ISD::CONCAT_VECTORS:     return visitCONCAT_VECTORS(N);
2043   case ISD::EXTRACT_SUBVECTOR:  return visitEXTRACT_SUBVECTOR(N);
2044   case ISD::VECTOR_SHUFFLE:     return visitVECTOR_SHUFFLE(N);
2045   case ISD::SCALAR_TO_VECTOR:   return visitSCALAR_TO_VECTOR(N);
2046   case ISD::INSERT_SUBVECTOR:   return visitINSERT_SUBVECTOR(N);
2047   case ISD::MGATHER:            return visitMGATHER(N);
2048   case ISD::MLOAD:              return visitMLOAD(N);
2049   case ISD::MSCATTER:           return visitMSCATTER(N);
2050   case ISD::MSTORE:             return visitMSTORE(N);
2051   case ISD::LIFETIME_END:       return visitLIFETIME_END(N);
2052   case ISD::FP_TO_FP16:         return visitFP_TO_FP16(N);
2053   case ISD::FP16_TO_FP:         return visitFP16_TO_FP(N);
2054   case ISD::FP_TO_BF16:         return visitFP_TO_BF16(N);
2055   case ISD::BF16_TO_FP:         return visitBF16_TO_FP(N);
2056   case ISD::FREEZE:             return visitFREEZE(N);
2057   case ISD::GET_FPENV_MEM:      return visitGET_FPENV_MEM(N);
2058   case ISD::SET_FPENV_MEM:      return visitSET_FPENV_MEM(N);
2059   case ISD::VECREDUCE_FADD:
2060   case ISD::VECREDUCE_FMUL:
2061   case ISD::VECREDUCE_ADD:
2062   case ISD::VECREDUCE_MUL:
2063   case ISD::VECREDUCE_AND:
2064   case ISD::VECREDUCE_OR:
2065   case ISD::VECREDUCE_XOR:
2066   case ISD::VECREDUCE_SMAX:
2067   case ISD::VECREDUCE_SMIN:
2068   case ISD::VECREDUCE_UMAX:
2069   case ISD::VECREDUCE_UMIN:
2070   case ISD::VECREDUCE_FMAX:
2071   case ISD::VECREDUCE_FMIN:
2072   case ISD::VECREDUCE_FMAXIMUM:
2073   case ISD::VECREDUCE_FMINIMUM:     return visitVECREDUCE(N);
2074 #define BEGIN_REGISTER_VP_SDNODE(SDOPC, ...) case ISD::SDOPC:
2075 #include "llvm/IR/VPIntrinsics.def"
2076     return visitVPOp(N);
2077   }
2078   // clang-format on
2079   return SDValue();
2080 }
2081 
2082 SDValue DAGCombiner::combine(SDNode *N) {
2083   if (!DebugCounter::shouldExecute(DAGCombineCounter))
2084     return SDValue();
2085 
2086   SDValue RV;
2087   if (!DisableGenericCombines)
2088     RV = visit(N);
2089 
2090   // If nothing happened, try a target-specific DAG combine.
2091   if (!RV.getNode()) {
2092     assert(N->getOpcode() != ISD::DELETED_NODE &&
2093            "Node was deleted but visit returned NULL!");
2094 
2095     if (N->getOpcode() >= ISD::BUILTIN_OP_END ||
2096         TLI.hasTargetDAGCombine((ISD::NodeType)N->getOpcode())) {
2097 
2098       // Expose the DAG combiner to the target combiner impls.
2099       TargetLowering::DAGCombinerInfo
2100         DagCombineInfo(DAG, Level, false, this);
2101 
2102       RV = TLI.PerformDAGCombine(N, DagCombineInfo);
2103     }
2104   }
2105 
2106   // If nothing happened still, try promoting the operation.
2107   if (!RV.getNode()) {
2108     switch (N->getOpcode()) {
2109     default: break;
2110     case ISD::ADD:
2111     case ISD::SUB:
2112     case ISD::MUL:
2113     case ISD::AND:
2114     case ISD::OR:
2115     case ISD::XOR:
2116       RV = PromoteIntBinOp(SDValue(N, 0));
2117       break;
2118     case ISD::SHL:
2119     case ISD::SRA:
2120     case ISD::SRL:
2121       RV = PromoteIntShiftOp(SDValue(N, 0));
2122       break;
2123     case ISD::SIGN_EXTEND:
2124     case ISD::ZERO_EXTEND:
2125     case ISD::ANY_EXTEND:
2126       RV = PromoteExtend(SDValue(N, 0));
2127       break;
2128     case ISD::LOAD:
2129       if (PromoteLoad(SDValue(N, 0)))
2130         RV = SDValue(N, 0);
2131       break;
2132     }
2133   }
2134 
2135   // If N is a commutative binary node, try to eliminate it if the commuted
2136   // version is already present in the DAG.
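  // e.g. if (add x, y) is being combined and (add y, x) already exists, N can
  // simply be replaced by the existing node.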
2137   if (!RV.getNode() && TLI.isCommutativeBinOp(N->getOpcode())) {
2138     SDValue N0 = N->getOperand(0);
2139     SDValue N1 = N->getOperand(1);
2140 
2141     // Constant operands are canonicalized to RHS.
2142     if (N0 != N1 && (isa<ConstantSDNode>(N0) || !isa<ConstantSDNode>(N1))) {
2143       SDValue Ops[] = {N1, N0};
2144       SDNode *CSENode = DAG.getNodeIfExists(N->getOpcode(), N->getVTList(), Ops,
2145                                             N->getFlags());
2146       if (CSENode)
2147         return SDValue(CSENode, 0);
2148     }
2149   }
2150 
2151   return RV;
2152 }
2153 
2154 /// Given a node, return its input chain if it has one, otherwise return a
2155 /// null SDValue.
2156 static SDValue getInputChainForNode(SDNode *N) {
2157   if (unsigned NumOps = N->getNumOperands()) {
2158     if (N->getOperand(0).getValueType() == MVT::Other)
2159       return N->getOperand(0);
2160     if (N->getOperand(NumOps-1).getValueType() == MVT::Other)
2161       return N->getOperand(NumOps-1);
2162     for (unsigned i = 1; i < NumOps-1; ++i)
2163       if (N->getOperand(i).getValueType() == MVT::Other)
2164         return N->getOperand(i);
2165   }
2166   return SDValue();
2167 }
2168 
2169 SDValue DAGCombiner::visitTokenFactor(SDNode *N) {
2170   // If N has two operands, where one has an input chain equal to the other,
2171   // the 'other' chain is redundant.
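  // e.g. (TokenFactor t1, ch) where t1 = (load ch, p) reduces to t1, because
  // t1 already depends on ch.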
2172   if (N->getNumOperands() == 2) {
2173     if (getInputChainForNode(N->getOperand(0).getNode()) == N->getOperand(1))
2174       return N->getOperand(0);
2175     if (getInputChainForNode(N->getOperand(1).getNode()) == N->getOperand(0))
2176       return N->getOperand(1);
2177   }
2178 
2179   // Don't simplify token factors if optnone.
2180   if (OptLevel == CodeGenOptLevel::None)
2181     return SDValue();
2182 
2183   // Don't simplify the token factor if the node itself has too many operands.
2184   if (N->getNumOperands() > TokenFactorInlineLimit)
2185     return SDValue();
2186 
2187   // If the sole user is a token factor, we should make sure we have a
2188   // chance to merge them together. This prevents TF chains from inhibiting
2189   // optimizations.
2190   if (N->hasOneUse() && N->use_begin()->getOpcode() == ISD::TokenFactor)
2191     AddToWorklist(*(N->use_begin()));
2192 
2193   SmallVector<SDNode *, 8> TFs;     // List of token factors to visit.
2194   SmallVector<SDValue, 8> Ops;      // Ops for replacing token factor.
2195   SmallPtrSet<SDNode*, 16> SeenOps;
2196   bool Changed = false;             // If we should replace this token factor.
2197 
2198   // Start out with this token factor.
2199   TFs.push_back(N);
2200 
2201   // Iterate through token factors. The TFs list grows when new token
2202   // factors are encountered.
2203   for (unsigned i = 0; i < TFs.size(); ++i) {
2204     // Limit number of nodes to inline, to avoid quadratic compile times.
2205     // We have to add the outstanding Token Factors to Ops, otherwise we might
2206     // drop Ops from the resulting Token Factors.
2207     if (Ops.size() > TokenFactorInlineLimit) {
2208       for (unsigned j = i; j < TFs.size(); j++)
2209         Ops.emplace_back(TFs[j], 0);
2210       // Drop unprocessed Token Factors from TFs, so we do not add them to the
2211       // combiner worklist later.
2212       TFs.resize(i);
2213       break;
2214     }
2215 
2216     SDNode *TF = TFs[i];
2217     // Check each of the operands.
2218     for (const SDValue &Op : TF->op_values()) {
2219       switch (Op.getOpcode()) {
2220       case ISD::EntryToken:
2221         // Entry tokens don't need to be added to the list. They are
2222         // redundant.
2223         Changed = true;
2224         break;
2225 
2226       case ISD::TokenFactor:
2227         if (Op.hasOneUse() && !is_contained(TFs, Op.getNode())) {
2228           // Queue up for processing.
2229           TFs.push_back(Op.getNode());
2230           Changed = true;
2231           break;
2232         }
2233         [[fallthrough]];
2234 
2235       default:
2236         // Only add if it isn't already in the list.
2237         if (SeenOps.insert(Op.getNode()).second)
2238           Ops.push_back(Op);
2239         else
2240           Changed = true;
2241         break;
2242       }
2243     }
2244   }
2245 
2246   // Re-visit inlined Token Factors, to clean them up in case they have been
2247   // removed. Skip the first Token Factor, as this is the current node.
2248   for (unsigned i = 1, e = TFs.size(); i < e; i++)
2249     AddToWorklist(TFs[i]);
2250 
2251   // Remove nodes that are chained to another node in the list. Do so by
2252   // walking up chains breadth-first, stopping when we've seen another
2253   // operand. In general we must climb to the EntryNode, but we can exit
2254   // early if we find all remaining work is associated with just one operand
2255   // as no further pruning is possible.
2256 
2257   // List of nodes to search through and original Ops from which they originate.
2258   SmallVector<std::pair<SDNode *, unsigned>, 8> Worklist;
2259   SmallVector<unsigned, 8> OpWorkCount; // Count of work for each Op.
2260   SmallPtrSet<SDNode *, 16> SeenChains;
2261   bool DidPruneOps = false;
2262 
2263   unsigned NumLeftToConsider = 0;
2264   for (const SDValue &Op : Ops) {
2265     Worklist.push_back(std::make_pair(Op.getNode(), NumLeftToConsider++));
2266     OpWorkCount.push_back(1);
2267   }
2268 
2269   auto AddToWorklist = [&](unsigned CurIdx, SDNode *Op, unsigned OpNumber) {
2270     // If this is an Op, we can remove the op from the list. Re-mark any
2271     // search associated with it as coming from the current OpNumber.
2272     if (SeenOps.contains(Op)) {
2273       Changed = true;
2274       DidPruneOps = true;
2275       unsigned OrigOpNumber = 0;
2276       while (OrigOpNumber < Ops.size() && Ops[OrigOpNumber].getNode() != Op)
2277         OrigOpNumber++;
2278       assert((OrigOpNumber != Ops.size()) &&
2279              "expected to find TokenFactor Operand");
2280       // Re-mark worklist from OrigOpNumber to OpNumber
2281       for (unsigned i = CurIdx + 1; i < Worklist.size(); ++i) {
2282         if (Worklist[i].second == OrigOpNumber) {
2283           Worklist[i].second = OpNumber;
2284         }
2285       }
2286       OpWorkCount[OpNumber] += OpWorkCount[OrigOpNumber];
2287       OpWorkCount[OrigOpNumber] = 0;
2288       NumLeftToConsider--;
2289     }
2290     // Add if it's a new chain
2291     if (SeenChains.insert(Op).second) {
2292       OpWorkCount[OpNumber]++;
2293       Worklist.push_back(std::make_pair(Op, OpNumber));
2294     }
2295   };
2296 
2297   for (unsigned i = 0; i < Worklist.size() && i < 1024; ++i) {
2298     // We need to consider at least 2 Ops to prune.
2299     if (NumLeftToConsider <= 1)
2300       break;
2301     auto CurNode = Worklist[i].first;
2302     auto CurOpNumber = Worklist[i].second;
2303     assert((OpWorkCount[CurOpNumber] > 0) &&
2304            "Node should not appear in worklist");
2305     switch (CurNode->getOpcode()) {
2306     case ISD::EntryToken:
2307       // Hitting EntryToken is the only way for the search to terminate
2308       // without hitting another operand's search. Avoid marking this
2309       // operand as considered: the increment here offsets the decrement at
2310       // the bottom of the loop.
2311       NumLeftToConsider++;
2312       break;
2313     case ISD::TokenFactor:
2314       for (const SDValue &Op : CurNode->op_values())
2315         AddToWorklist(i, Op.getNode(), CurOpNumber);
2316       break;
2317     case ISD::LIFETIME_START:
2318     case ISD::LIFETIME_END:
2319     case ISD::CopyFromReg:
2320     case ISD::CopyToReg:
2321       AddToWorklist(i, CurNode->getOperand(0).getNode(), CurOpNumber);
2322       break;
2323     default:
2324       if (auto *MemNode = dyn_cast<MemSDNode>(CurNode))
2325         AddToWorklist(i, MemNode->getChain().getNode(), CurOpNumber);
2326       break;
2327     }
2328     OpWorkCount[CurOpNumber]--;
2329     if (OpWorkCount[CurOpNumber] == 0)
2330       NumLeftToConsider--;
2331   }
2332 
2333   // If we've changed things around then replace token factor.
2334   if (Changed) {
2335     SDValue Result;
2336     if (Ops.empty()) {
2337       // The entry token is the only possible outcome.
2338       Result = DAG.getEntryNode();
2339     } else {
2340       if (DidPruneOps) {
2341         SmallVector<SDValue, 8> PrunedOps;
2342         // Keep only Ops that were not reached while walking another Op's chain.
2343         for (const SDValue &Op : Ops) {
2344           if (SeenChains.count(Op.getNode()) == 0)
2345             PrunedOps.push_back(Op);
2346         }
2347         Result = DAG.getTokenFactor(SDLoc(N), PrunedOps);
2348       } else {
2349         Result = DAG.getTokenFactor(SDLoc(N), Ops);
2350       }
2351     }
2352     return Result;
2353   }
2354   return SDValue();
2355 }
2356 
2357 /// MERGE_VALUES can always be eliminated.
2358 SDValue DAGCombiner::visitMERGE_VALUES(SDNode *N) {
2359   WorklistRemover DeadNodes(*this);
2360   // Replacing results may cause a different MERGE_VALUES to suddenly
2361   // be CSE'd with N, and carry its uses with it. Iterate until no
2362   // uses remain, to ensure that the node can be safely deleted.
2363   // First add the users of this node to the work list so that they
2364   // can be tried again once they have new operands.
2365   AddUsersToWorklist(N);
2366   do {
2367     // Do as a single replacement to avoid rewalking use lists.
2368     SmallVector<SDValue, 8> Ops;
2369     for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
2370       Ops.push_back(N->getOperand(i));
2371     DAG.ReplaceAllUsesWith(N, Ops.data());
2372   } while (!N->use_empty());
2373   deleteAndRecombine(N);
2374   return SDValue(N, 0);   // Return N so it doesn't get rechecked!
2375 }
2376 
2377 /// If \p N is a ConstantSDNode with isOpaque() == false, return it cast to a
2378 /// ConstantSDNode pointer; otherwise return nullptr.
2379 static ConstantSDNode *getAsNonOpaqueConstant(SDValue N) {
2380   ConstantSDNode *Const = dyn_cast<ConstantSDNode>(N);
2381   return Const != nullptr && !Const->isOpaque() ? Const : nullptr;
2382 }
2383 
2384 // isTruncateOf - If N is a truncate of some other value, return true, record
2385 // the value being truncated in Op and which of Op's bits are zero/one in Known.
2386 // This function computes KnownBits to avoid a duplicated call to
2387 // computeKnownBits in the caller.
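// Besides a plain TRUNCATE, (setcc x, 0, ne) is treated as a truncate to i1
// when every bit of x other than bit 0 is known to be zero.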
2388 static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op,
2389                          KnownBits &Known) {
2390   if (N->getOpcode() == ISD::TRUNCATE) {
2391     Op = N->getOperand(0);
2392     Known = DAG.computeKnownBits(Op);
2393     return true;
2394   }
2395 
2396   if (N.getOpcode() != ISD::SETCC ||
2397       N.getValueType().getScalarType() != MVT::i1 ||
2398       cast<CondCodeSDNode>(N.getOperand(2))->get() != ISD::SETNE)
2399     return false;
2400 
2401   SDValue Op0 = N->getOperand(0);
2402   SDValue Op1 = N->getOperand(1);
2403   assert(Op0.getValueType() == Op1.getValueType());
2404 
2405   if (isNullOrNullSplat(Op0))
2406     Op = Op1;
2407   else if (isNullOrNullSplat(Op1))
2408     Op = Op0;
2409   else
2410     return false;
2411 
2412   Known = DAG.computeKnownBits(Op);
2413 
2414   return (Known.Zero | 1).isAllOnes();
2415 }
2416 
2417 /// Return true if 'Use' is a load or a store that uses N as its base pointer
2418 /// and that N may be folded in the load / store addressing mode.
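/// e.g. (add x, 16) used as the pointer of a store is foldable if the target
/// supports a [reg + imm] addressing mode with offset 16 for the stored type.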
2419 static bool canFoldInAddressingMode(SDNode *N, SDNode *Use, SelectionDAG &DAG,
2420                                     const TargetLowering &TLI) {
2421   EVT VT;
2422   unsigned AS;
2423 
2424   if (LoadSDNode *LD = dyn_cast<LoadSDNode>(Use)) {
2425     if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2426       return false;
2427     VT = LD->getMemoryVT();
2428     AS = LD->getAddressSpace();
2429   } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(Use)) {
2430     if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2431       return false;
2432     VT = ST->getMemoryVT();
2433     AS = ST->getAddressSpace();
2434   } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(Use)) {
2435     if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2436       return false;
2437     VT = LD->getMemoryVT();
2438     AS = LD->getAddressSpace();
2439   } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(Use)) {
2440     if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2441       return false;
2442     VT = ST->getMemoryVT();
2443     AS = ST->getAddressSpace();
2444   } else {
2445     return false;
2446   }
2447 
2448   TargetLowering::AddrMode AM;
2449   if (N->getOpcode() == ISD::ADD) {
2450     AM.HasBaseReg = true;
2451     ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(N->getOperand(1));
2452     if (Offset)
2453       // [reg +/- imm]
2454       AM.BaseOffs = Offset->getSExtValue();
2455     else
2456       // [reg +/- reg]
2457       AM.Scale = 1;
2458   } else if (N->getOpcode() == ISD::SUB) {
2459     AM.HasBaseReg = true;
2460     ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(N->getOperand(1));
2461     if (Offset)
2462       // [reg +/- imm]
2463       AM.BaseOffs = -Offset->getSExtValue();
2464     else
2465       // [reg +/- reg]
2466       AM.Scale = 1;
2467   } else {
2468     return false;
2469   }
2470 
2471   return TLI.isLegalAddressingMode(DAG.getDataLayout(), AM,
2472                                    VT.getTypeForEVT(*DAG.getContext()), AS);
2473 }
2474 
2475 /// This inverts a canonicalization in IR that replaces a variable select arm
2476 /// with an identity constant. Codegen improves if we re-use the variable
2477 /// operand rather than load a constant. This can also be converted into a
2478 /// masked vector operation if the target supports it.
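/// e.g. (add x, (vselect cond, 0, y)) -> (vselect cond, x, (add x, y)), so the
/// true arm reuses x instead of materializing the identity constant 0.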
2479 static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
2480                                               bool ShouldCommuteOperands) {
2481   // Match a select as operand 1. The identity constant that we are looking for
2482   // is only valid as operand 1 of a non-commutative binop.
2483   SDValue N0 = N->getOperand(0);
2484   SDValue N1 = N->getOperand(1);
2485   if (ShouldCommuteOperands)
2486     std::swap(N0, N1);
2487 
2488   // TODO: Should this apply to scalar select too?
2489   if (N1.getOpcode() != ISD::VSELECT || !N1.hasOneUse())
2490     return SDValue();
2491 
2492   // We can't hoist all instructions because of immediate UB (they are not
2493   // speculatable), e.g. div/rem by zero.
2494   if (!DAG.isSafeToSpeculativelyExecuteNode(N))
2495     return SDValue();
2496 
2497   unsigned Opcode = N->getOpcode();
2498   EVT VT = N->getValueType(0);
2499   SDValue Cond = N1.getOperand(0);
2500   SDValue TVal = N1.getOperand(1);
2501   SDValue FVal = N1.getOperand(2);
2502 
2503   // This transform increases uses of N0, so freeze it to be safe.
2504   // binop N0, (vselect Cond, IDC, FVal) --> vselect Cond, N0, (binop N0, FVal)
2505   unsigned OpNo = ShouldCommuteOperands ? 0 : 1;
2506   if (isNeutralConstant(Opcode, N->getFlags(), TVal, OpNo)) {
2507     SDValue F0 = DAG.getFreeze(N0);
2508     SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, FVal, N->getFlags());
2509     return DAG.getSelect(SDLoc(N), VT, Cond, F0, NewBO);
2510   }
2511   // binop N0, (vselect Cond, TVal, IDC) --> vselect Cond, (binop N0, TVal), N0
2512   if (isNeutralConstant(Opcode, N->getFlags(), FVal, OpNo)) {
2513     SDValue F0 = DAG.getFreeze(N0);
2514     SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, TVal, N->getFlags());
2515     return DAG.getSelect(SDLoc(N), VT, Cond, NewBO, F0);
2516   }
2517 
2518   return SDValue();
2519 }

SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {
  assert(TLI.isBinOp(BO->getOpcode()) && BO->getNumValues() == 1 &&
         "Unexpected binary operator");

  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  auto BinOpcode = BO->getOpcode();
  EVT VT = BO->getValueType(0);
  if (TLI.shouldFoldSelectWithIdentityConstant(BinOpcode, VT)) {
    if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, false))
      return Sel;

    if (TLI.isCommutativeBinOp(BO->getOpcode()))
      if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, true))
        return Sel;
  }

  // Don't do this unless the old select is going away. We want to eliminate the
  // binary operator, not replace a binop with a select.
  // TODO: Handle ISD::SELECT_CC.
  unsigned SelOpNo = 0;
  SDValue Sel = BO->getOperand(0);
  if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse()) {
    SelOpNo = 1;
    Sel = BO->getOperand(1);

    // Peek through trunc to shift amount type.
    if ((BinOpcode == ISD::SHL || BinOpcode == ISD::SRA ||
         BinOpcode == ISD::SRL) && Sel.hasOneUse()) {
      // This is valid when the truncated bits of x are already zero.
      SDValue Op;
      KnownBits Known;
      if (isTruncateOf(DAG, Sel, Op, Known) &&
          Known.countMaxActiveBits() < Sel.getScalarValueSizeInBits())
        Sel = Op;
    }
  }

  if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse())
    return SDValue();

  SDValue CT = Sel.getOperand(1);
  if (!isConstantOrConstantVector(CT, true) &&
      !DAG.isConstantFPBuildVectorOrConstantFP(CT))
    return SDValue();

  SDValue CF = Sel.getOperand(2);
  if (!isConstantOrConstantVector(CF, true) &&
      !DAG.isConstantFPBuildVectorOrConstantFP(CF))
    return SDValue();
  // Bail out if any constants are opaque because we can't constant fold those.
  // The exception is "and" and "or" with either 0 or -1, in which case we can
  // propagate non-constant operands into the select. I.e.:
  // and (select Cond, 0, -1), X --> select Cond, 0, X
  // or X, (select Cond, -1, 0) --> select Cond, -1, X
  bool CanFoldNonConst =
      (BinOpcode == ISD::AND || BinOpcode == ISD::OR) &&
      ((isNullOrNullSplat(CT) && isAllOnesOrAllOnesSplat(CF)) ||
       (isNullOrNullSplat(CF) && isAllOnesOrAllOnesSplat(CT)));

  SDValue CBO = BO->getOperand(SelOpNo ^ 1);
  if (!CanFoldNonConst &&
      !isConstantOrConstantVector(CBO, true) &&
      !DAG.isConstantFPBuildVectorOrConstantFP(CBO))
    return SDValue();

  SDLoc DL(Sel);
  SDValue NewCT, NewCF;

  if (CanFoldNonConst) {
    // If CBO is an opaque constant, we can't rely on getNode to constant fold.
    if ((BinOpcode == ISD::AND && isNullOrNullSplat(CT)) ||
        (BinOpcode == ISD::OR && isAllOnesOrAllOnesSplat(CT)))
      NewCT = CT;
    else
      NewCT = CBO;

    if ((BinOpcode == ISD::AND && isNullOrNullSplat(CF)) ||
        (BinOpcode == ISD::OR && isAllOnesOrAllOnesSplat(CF)))
      NewCF = CF;
    else
      NewCF = CBO;
  } else {
    // We have a select-of-constants followed by a binary operator with a
    // constant. Eliminate the binop by pulling the constant math into the
    // select. Example: add (select Cond, CT, CF), CBO --> select Cond, CT +
    // CBO, CF + CBO
    NewCT = SelOpNo ? DAG.FoldConstantArithmetic(BinOpcode, DL, VT, {CBO, CT})
                    : DAG.FoldConstantArithmetic(BinOpcode, DL, VT, {CT, CBO});
    if (!NewCT)
      return SDValue();

    NewCF = SelOpNo ? DAG.FoldConstantArithmetic(BinOpcode, DL, VT, {CBO, CF})
                    : DAG.FoldConstantArithmetic(BinOpcode, DL, VT, {CF, CBO});
    if (!NewCF)
      return SDValue();
  }

  SDValue SelectOp = DAG.getSelect(DL, VT, Sel.getOperand(0), NewCT, NewCF);
  SelectOp->setFlags(BO->getFlags());
  return SelectOp;
}

static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, SelectionDAG &DAG) {
  assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
         "Expecting add or sub");

  // Match a constant operand and a zext operand for the math instruction:
  // add Z, C
  // sub C, Z
  bool IsAdd = N->getOpcode() == ISD::ADD;
  SDValue C = IsAdd ? N->getOperand(1) : N->getOperand(0);
  SDValue Z = IsAdd ? N->getOperand(0) : N->getOperand(1);
  auto *CN = dyn_cast<ConstantSDNode>(C);
  if (!CN || Z.getOpcode() != ISD::ZERO_EXTEND)
    return SDValue();

  // Match the zext operand as a setcc of a boolean.
  if (Z.getOperand(0).getOpcode() != ISD::SETCC ||
      Z.getOperand(0).getValueType() != MVT::i1)
    return SDValue();

  // Match the compare as: setcc (X & 1), 0, eq.
  SDValue SetCC = Z.getOperand(0);
  ISD::CondCode CC = cast<CondCodeSDNode>(SetCC->getOperand(2))->get();
  if (CC != ISD::SETEQ || !isNullConstant(SetCC.getOperand(1)) ||
      SetCC.getOperand(0).getOpcode() != ISD::AND ||
      !isOneConstant(SetCC.getOperand(0).getOperand(1)))
    return SDValue();

  // We are adding/subtracting a constant and an inverted low bit. Turn that
  // into a subtract/add of the low bit with incremented/decremented constant:
  // add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1))
  // sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1))
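  // This holds because zext i1 (seteq (X & 1), 0) == 1 - (X & 1).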
  EVT VT = C.getValueType();
  SDLoc DL(N);
  SDValue LowBit = DAG.getZExtOrTrunc(SetCC.getOperand(0), DL, VT);
  SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT) :
                       DAG.getConstant(CN->getAPIntValue() - 1, DL, VT);
  return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, C1, LowBit);
}

/// Try to fold a 'not' shifted sign-bit with add/sub with constant operand into
/// a shift and add with a different constant.
static SDValue foldAddSubOfSignBit(SDNode *N, SelectionDAG &DAG) {
  assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
         "Expecting add or sub");

  // We need a constant operand for the add/sub, and the other operand is a
  // logical shift right: add (srl), C or sub C, (srl).
  bool IsAdd = N->getOpcode() == ISD::ADD;
  SDValue ConstantOp = IsAdd ? N->getOperand(1) : N->getOperand(0);
  SDValue ShiftOp = IsAdd ? N->getOperand(0) : N->getOperand(1);
  if (!DAG.isConstantIntBuildVectorOrConstantInt(ConstantOp) ||
      ShiftOp.getOpcode() != ISD::SRL)
    return SDValue();

  // The shift must be of a 'not' value.
  SDValue Not = ShiftOp.getOperand(0);
  if (!Not.hasOneUse() || !isBitwiseNot(Not))
    return SDValue();

  // The shift must be moving the sign bit to the least-significant-bit.
  EVT VT = ShiftOp.getValueType();
  SDValue ShAmt = ShiftOp.getOperand(1);
  ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt);
  if (!ShAmtC || ShAmtC->getAPIntValue() != (VT.getScalarSizeInBits() - 1))
    return SDValue();

  // Eliminate the 'not' by adjusting the shift and add/sub constant:
  // add (srl (not X), 31), C --> add (sra X, 31), (C + 1)
  // sub C, (srl (not X), 31) --> add (srl X, 31), (C - 1)
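  // These follow from (srl (not X), 31) == 1 - (srl X, 31) and
  // (sra X, 31) == -(srl X, 31).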
  SDLoc DL(N);
  if (SDValue NewC = DAG.FoldConstantArithmetic(
          IsAdd ? ISD::ADD : ISD::SUB, DL, VT,
          {ConstantOp, DAG.getConstant(1, DL, VT)})) {
    SDValue NewShift = DAG.getNode(IsAdd ? ISD::SRA : ISD::SRL, DL, VT,
                                   Not.getOperand(0), ShAmt);
    return DAG.getNode(ISD::ADD, DL, VT, NewShift, NewC);
  }

  return SDValue();
}

static bool
areBitwiseNotOfEachother(SDValue Op0, SDValue Op1) {
  return (isBitwiseNot(Op0) && Op0.getOperand(0) == Op1) ||
         (isBitwiseNot(Op1) && Op1.getOperand(0) == Op0);
}

/// Try to fold a node that behaves like an ADD (note that N isn't necessarily
/// an ISD::ADD here, it could for example be an ISD::OR if we know that there
/// are no common bits set in the operands).
SDValue DAGCombiner::visitADDLike(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  SDLoc DL(N);

  // fold (add x, undef) -> undef
  if (N0.isUndef())
    return N0;
  if (N1.isUndef())
    return N1;

  // fold (add c1, c2) -> c1+c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(ISD::ADD, DL, VT, N1, N0);

  if (areBitwiseNotOfEachother(N0, N1))
    return DAG.getConstant(APInt::getAllOnes(VT.getScalarSizeInBits()),
                           SDLoc(N), VT);

  // fold vector ops
  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

    // fold (add x, 0) -> x, vector edition
    if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
      return N0;
  }

  // fold (add x, 0) -> x
  if (isNullConstant(N1))
    return N0;

  if (N0.getOpcode() == ISD::SUB) {
    SDValue N00 = N0.getOperand(0);
    SDValue N01 = N0.getOperand(1);

    // fold ((A-c1)+c2) -> (A+(c2-c1))
    if (SDValue Sub = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N1, N01}))
      return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Sub);

    // fold ((c1-A)+c2) -> (c1+c2)-A
    if (SDValue Add = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N1, N00}))
      return DAG.getNode(ISD::SUB, DL, VT, Add, N0.getOperand(1));
  }

  // add (sext i1 X), 1 -> zext (not i1 X)
  // We don't transform this pattern:
  //   add (zext i1 X), -1 -> sext (not i1 X)
  // because most (?) targets generate better code for the zext form.
  if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
      isOneOrOneSplat(N1)) {
    SDValue X = N0.getOperand(0);
    if ((!LegalOperations ||
         (TLI.isOperationLegal(ISD::XOR, X.getValueType()) &&
          TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) &&
        X.getScalarValueSizeInBits() == 1) {
      SDValue Not = DAG.getNOT(DL, X, X.getValueType());
      return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Not);
    }
  }

  // Fold (add (or x, c0), c1) -> (add x, (c0 + c1))
  // iff (or x, c0) is equivalent to (add x, c0).
  // Fold (add (xor x, c0), c1) -> (add x, (c0 + c1))
  // iff (xor x, c0) is equivalent to (add x, c0).
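  // (or x, c0) behaves like an add when x and c0 share no common set bits;
  // (xor x, c0) behaves like an add when c0 is the sign-mask constant, since
  // flipping the sign bit is the same as adding it modulo 2^BW.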
  if (DAG.isADDLike(N0)) {
    SDValue N01 = N0.getOperand(1);
    if (SDValue Add = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N1, N01}))
      return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Add);
  }

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // reassociate add
  if (!reassociationCanBreakAddressingModePattern(ISD::ADD, DL, N, N0, N1)) {
    if (SDValue RADD = reassociateOps(ISD::ADD, DL, N0, N1, N->getFlags()))
      return RADD;

    // Reassociate (add (or x, c), y) -> (add add(x, y), c) if (or x, c) is
    // equivalent to (add x, c).
    // Reassociate (add (xor x, c), y) -> (add add(x, y), c) if (xor x, c) is
    // equivalent to (add x, c).
    // Do this optimization only when adding c does not introduce instructions
    // for adding carries.
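    // (An add on a type that legalization will split is expanded into an
    // add/add-with-carry chain, for example.)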
    auto ReassociateAddOr = [&](SDValue N0, SDValue N1) {
      if (DAG.isADDLike(N0) && N0.hasOneUse() &&
          isConstantOrConstantVector(N0.getOperand(1), /* NoOpaque */ true)) {
        // If N0's type will not be split, or the constant is the sign mask,
        // the new add does not introduce carry instructions.
        auto TyActn = TLI.getTypeAction(*DAG.getContext(), N0.getValueType());
        bool NoAddCarry = TyActn == TargetLoweringBase::TypeLegal ||
                          TyActn == TargetLoweringBase::TypePromoteInteger ||
                          isMinSignedConstant(N0.getOperand(1));
        if (NoAddCarry)
          return DAG.getNode(
              ISD::ADD, DL, VT,
              DAG.getNode(ISD::ADD, DL, VT, N1, N0.getOperand(0)),
              N0.getOperand(1));
      }
      return SDValue();
    };
    if (SDValue Add = ReassociateAddOr(N0, N1))
      return Add;
    if (SDValue Add = ReassociateAddOr(N1, N0))
      return Add;

    // Fold add(vecreduce(x), vecreduce(y)) -> vecreduce(add(x, y))
    if (SDValue SD =
            reassociateReduction(ISD::VECREDUCE_ADD, ISD::ADD, DL, VT, N0, N1))
      return SD;
  }
  // fold ((0-A) + B) -> B-A
  if (N0.getOpcode() == ISD::SUB && isNullOrNullSplat(N0.getOperand(0)))
    return DAG.getNode(ISD::SUB, DL, VT, N1, N0.getOperand(1));

  // fold (A + (0-B)) -> A-B
  if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0)))
    return DAG.getNode(ISD::SUB, DL, VT, N0, N1.getOperand(1));

  // fold (A+(B-A)) -> B
  if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(1))
    return N1.getOperand(0);

  // fold ((B-A)+A) -> B
  if (N0.getOpcode() == ISD::SUB && N1 == N0.getOperand(1))
    return N0.getOperand(0);

  // fold ((A-B)+(C-A)) -> (C-B)
  if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
      N0.getOperand(0) == N1.getOperand(1))
    return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
                       N0.getOperand(1));

  // fold ((A-B)+(B-C)) -> (A-C)
  if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
      N0.getOperand(1) == N1.getOperand(0))
    return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0),
                       N1.getOperand(1));

  // fold (A+(B-(A+C))) to (B-C)
  if (N1.getOpcode() == ISD::SUB && N1.getOperand(1).getOpcode() == ISD::ADD &&
      N0 == N1.getOperand(1).getOperand(0))
    return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
                       N1.getOperand(1).getOperand(1));

  // fold (A+(B-(C+A))) to (B-C)
  if (N1.getOpcode() == ISD::SUB && N1.getOperand(1).getOpcode() == ISD::ADD &&
      N0 == N1.getOperand(1).getOperand(1))
    return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
                       N1.getOperand(1).getOperand(0));

  // fold (A+((B-A)+or-C)) to (B+or-C)
  if ((N1.getOpcode() == ISD::SUB || N1.getOpcode() == ISD::ADD) &&
      N1.getOperand(0).getOpcode() == ISD::SUB &&
      N0 == N1.getOperand(0).getOperand(1))
    return DAG.getNode(N1.getOpcode(), DL, VT, N1.getOperand(0).getOperand(0),
                       N1.getOperand(1));

  // fold (A-B)+(C-D) to (A+C)-(B+D) when A or C is constant
  if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
      N0->hasOneUse() && N1->hasOneUse()) {
    SDValue N00 = N0.getOperand(0);
    SDValue N01 = N0.getOperand(1);
    SDValue N10 = N1.getOperand(0);
    SDValue N11 = N1.getOperand(1);

    if (isConstantOrConstantVector(N00) || isConstantOrConstantVector(N10))
      return DAG.getNode(ISD::SUB, DL, VT,
                         DAG.getNode(ISD::ADD, SDLoc(N0), VT, N00, N10),
                         DAG.getNode(ISD::ADD, SDLoc(N1), VT, N01, N11));
  }

  // fold (add (umax X, C), -C) --> (usubsat X, C)
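  // This holds because umax(X, C) - C == (X >= C ? X - C : 0) == usubsat(X, C).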
  if (N0.getOpcode() == ISD::UMAX && hasOperation(ISD::USUBSAT, VT)) {
    auto MatchUSUBSAT = [](ConstantSDNode *Max, ConstantSDNode *Op) {
      return (!Max && !Op) ||
             (Max && Op && Max->getAPIntValue() == (-Op->getAPIntValue()));
    };
    if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchUSUBSAT,
                                  /*AllowUndefs*/ true))
      return DAG.getNode(ISD::USUBSAT, DL, VT, N0.getOperand(0),
                         N0.getOperand(1));
  }

  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  if (isOneOrOneSplat(N1)) {
    // fold (add (xor a, -1), 1) -> (sub 0, a)
    if (isBitwiseNot(N0))
      return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
                         N0.getOperand(0));

    // fold (add (add (xor a, -1), b), 1) -> (sub b, a)
    if (N0.getOpcode() == ISD::ADD) {
      SDValue A, Xor;

      if (isBitwiseNot(N0.getOperand(0))) {
        A = N0.getOperand(1);
        Xor = N0.getOperand(0);
      } else if (isBitwiseNot(N0.getOperand(1))) {
        A = N0.getOperand(0);
        Xor = N0.getOperand(1);
      }

      if (Xor)
        return DAG.getNode(ISD::SUB, DL, VT, A, Xor.getOperand(0));
    }

    // Look for:
    //   add (add x, y), 1
    // And if the target does not like this form then turn into:
    //   sub y, (xor x, -1)
    if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD &&
        N0.hasOneUse() &&
        // Limit this to after legalization if the add has wrap flags
        (Level >= AfterLegalizeDAG || (!N->getFlags().hasNoUnsignedWrap() &&
                                       !N->getFlags().hasNoSignedWrap()))) {
      SDValue Not = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(0),
                                DAG.getAllOnesConstant(DL, VT));
      return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(1), Not);
    }
  }

  // (x - y) + -1  ->  add (xor y, -1), x
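  // This holds because (x - y) - 1 == x + ~y in two's complement, since
  // ~y == -y - 1.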
  if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
      isAllOnesOrAllOnesSplat(N1)) {
    SDValue Xor = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(1), N1);
    return DAG.getNode(ISD::ADD, DL, VT, Xor, N0.getOperand(0));
  }

  if (SDValue Combined = visitADDLikeCommutative(N0, N1, N))
    return Combined;

  if (SDValue Combined = visitADDLikeCommutative(N1, N0, N))
    return Combined;

  return SDValue();
}

SDValue DAGCombiner::visitADD(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  SDLoc DL(N);

  if (SDValue Combined = visitADDLike(N))
    return Combined;

  if (SDValue V = foldAddSubBoolOfMaskedVal(N, DAG))
    return V;

  if (SDValue V = foldAddSubOfSignBit(N, DAG))
    return V;

  // fold (a+b) -> (a|b) iff a and b share no bits.
  if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
      DAG.haveNoCommonBitsSet(N0, N1))
    return DAG.getNode(ISD::OR, DL, VT, N0, N1);

  // Fold (add (vscale * C0), (vscale * C1)) to (vscale * (C0 + C1)).
  if (N0.getOpcode() == ISD::VSCALE && N1.getOpcode() == ISD::VSCALE) {
    const APInt &C0 = N0->getConstantOperandAPInt(0);
    const APInt &C1 = N1->getConstantOperandAPInt(0);
    return DAG.getVScale(DL, VT, C0 + C1);
  }

  // fold a+vscale(c1)+vscale(c2) -> a+vscale(c1+c2)
  if (N0.getOpcode() == ISD::ADD &&
      N0.getOperand(1).getOpcode() == ISD::VSCALE &&
      N1.getOpcode() == ISD::VSCALE) {
    const APInt &VS0 = N0.getOperand(1)->getConstantOperandAPInt(0);
    const APInt &VS1 = N1->getConstantOperandAPInt(0);
    SDValue VS = DAG.getVScale(DL, VT, VS0 + VS1);
    return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), VS);
  }
  // Fold (add step_vector(c1), step_vector(c2)) to step_vector(c1+c2)
  if (N0.getOpcode() == ISD::STEP_VECTOR &&
      N1.getOpcode() == ISD::STEP_VECTOR) {
    const APInt &C0 = N0->getConstantOperandAPInt(0);
    const APInt &C1 = N1->getConstantOperandAPInt(0);
    APInt NewStep = C0 + C1;
    return DAG.getStepVector(DL, VT, NewStep);
  }

  // Fold a + step_vector(c1) + step_vector(c2) to a + step_vector(c1+c2)
  if (N0.getOpcode() == ISD::ADD &&
      N0.getOperand(1).getOpcode() == ISD::STEP_VECTOR &&
      N1.getOpcode() == ISD::STEP_VECTOR) {
    const APInt &SV0 = N0.getOperand(1)->getConstantOperandAPInt(0);
    const APInt &SV1 = N1->getConstantOperandAPInt(0);
    APInt NewStep = SV0 + SV1;
    SDValue SV = DAG.getStepVector(DL, VT, NewStep);
    return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), SV);
  }

  return SDValue();
}

SDValue DAGCombiner::visitADDSAT(SDNode *N) {
  unsigned Opcode = N->getOpcode();
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  bool IsSigned = Opcode == ISD::SADDSAT;
  SDLoc DL(N);

  // fold (add_sat x, undef) -> -1
  if (N0.isUndef() || N1.isUndef())
    return DAG.getAllOnesConstant(DL, VT);

  // fold (add_sat c1, c2) -> c3
  if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(Opcode, DL, VT, N1, N0);

  // fold vector ops
  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

    // fold (add_sat x, 0) -> x, vector edition
    if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
      return N0;
  }

  // fold (add_sat x, 0) -> x
  if (isNullConstant(N1))
    return N0;

  // If it cannot overflow, transform into an add.
  if (DAG.willNotOverflowAdd(IsSigned, N0, N1))
    return DAG.getNode(ISD::ADD, DL, VT, N0, N1);

  return SDValue();
}

static SDValue getAsCarry(const TargetLowering &TLI, SDValue V,
                          bool ForceCarryReconstruction = false) {
  bool Masked = false;

  // First, peel away TRUNCATE/ZERO_EXTEND/AND nodes due to legalization.
  while (true) {
    if (V.getOpcode() == ISD::TRUNCATE || V.getOpcode() == ISD::ZERO_EXTEND) {
      V = V.getOperand(0);
      continue;
    }

    if (V.getOpcode() == ISD::AND && isOneConstant(V.getOperand(1))) {
      if (ForceCarryReconstruction)
        return V;

      Masked = true;
      V = V.getOperand(0);
      continue;
    }

    if (ForceCarryReconstruction && V.getValueType() == MVT::i1)
      return V;

    break;
  }

  // If this is not a carry, return.
  if (V.getResNo() != 1)
    return SDValue();

  if (V.getOpcode() != ISD::UADDO_CARRY && V.getOpcode() != ISD::USUBO_CARRY &&
      V.getOpcode() != ISD::UADDO && V.getOpcode() != ISD::USUBO)
    return SDValue();

  EVT VT = V->getValueType(0);
  if (!TLI.isOperationLegalOrCustom(V.getOpcode(), VT))
    return SDValue();

  // If the result is masked, then no matter what kind of bool it is we can
  // return. If it isn't, then we need to make sure the bool type is either 0 or
  // 1 and not other values.
  if (Masked ||
      TLI.getBooleanContents(V.getValueType()) ==
          TargetLoweringBase::ZeroOrOneBooleanContent)
    return V;

  return SDValue();
}

/// Given the operands of an add/sub operation, see if the 2nd operand is a
/// masked 0/1 whose source operand is actually known to be 0/-1. If so, invert
/// the opcode and bypass the mask operation.
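/// For example, if X is known to be 0 or -1 (all sign bits), then (and X, 1)
/// equals -X, so:
///   add N0, (and X, 1) == N0 - X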
static SDValue foldAddSubMasked1(bool IsAdd, SDValue N0, SDValue N1,
                                 SelectionDAG &DAG, const SDLoc &DL) {
  if (N1.getOpcode() == ISD::ZERO_EXTEND)
    N1 = N1.getOperand(0);

  if (N1.getOpcode() != ISD::AND || !isOneOrOneSplat(N1->getOperand(1)))
    return SDValue();

  EVT VT = N0.getValueType();
  SDValue N10 = N1.getOperand(0);
  if (N10.getValueType() != VT && N10.getOpcode() == ISD::TRUNCATE)
    N10 = N10.getOperand(0);

  if (N10.getValueType() != VT)
    return SDValue();

  if (DAG.ComputeNumSignBits(N10) != VT.getScalarSizeInBits())
    return SDValue();

  // add N0, (and (AssertSext X, i1), 1) --> sub N0, X
  // sub N0, (and (AssertSext X, i1), 1) --> add N0, X
  return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, N0, N10);
}

/// Helper for doing combines based on N0 and N1 being added to each other.
SDValue DAGCombiner::visitADDLikeCommutative(SDValue N0, SDValue N1,
                                          SDNode *LocReference) {
  EVT VT = N0.getValueType();
  SDLoc DL(LocReference);

  // fold (add x, shl(0 - y, n)) -> sub(x, shl(y, n))
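  // This is valid because (0 - y) << n == 0 - (y << n) in modular arithmetic.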
  if (N1.getOpcode() == ISD::SHL && N1.getOperand(0).getOpcode() == ISD::SUB &&
      isNullOrNullSplat(N1.getOperand(0).getOperand(0)))
    return DAG.getNode(ISD::SUB, DL, VT, N0,
                       DAG.getNode(ISD::SHL, DL, VT,
                                   N1.getOperand(0).getOperand(1),
                                   N1.getOperand(1)));

  if (SDValue V = foldAddSubMasked1(true, N0, N1, DAG, DL))
    return V;

  // Look for:
  //   add (add x, 1), y
  // And if the target does not like this form then turn into:
  //   sub y, (xor x, -1)
  if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD &&
      N0.hasOneUse() && isOneOrOneSplat(N0.getOperand(1)) &&
      // Limit this to after legalization if the add has wrap flags
      (Level >= AfterLegalizeDAG || (!N0->getFlags().hasNoUnsignedWrap() &&
                                     !N0->getFlags().hasNoSignedWrap()))) {
    SDValue Not = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(0),
                              DAG.getAllOnesConstant(DL, VT));
    return DAG.getNode(ISD::SUB, DL, VT, N1, Not);
  }

  if (N0.getOpcode() == ISD::SUB && N0.hasOneUse()) {
    // Hoist one-use subtraction by non-opaque constant:
    //   (x - C) + y  ->  (x + y) - C
    // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
    if (isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
      SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), N1);
      return DAG.getNode(ISD::SUB, DL, VT, Add, N0.getOperand(1));
    }
    // Hoist one-use subtraction from non-opaque constant:
    //   (C - x) + y  ->  (y - x) + C
    if (isConstantOrConstantVector(N0.getOperand(0), /*NoOpaques=*/true)) {
      SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N1, N0.getOperand(1));
      return DAG.getNode(ISD::ADD, DL, VT, Sub, N0.getOperand(0));
    }
  }

  // add (mul x, C), x -> mul x, C+1
  if (N0.getOpcode() == ISD::MUL && N0.getOperand(0) == N1 &&
      isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true) &&
      N0.hasOneUse()) {
    SDValue NewC = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1),
                               DAG.getConstant(1, DL, VT));
    return DAG.getNode(ISD::MUL, DL, VT, N0.getOperand(0), NewC);
  }

  // If the target's bool is represented as 0/1, prefer to make this 'sub 0/1'
  // rather than 'add 0/-1' (the zext should get folded).
  // add (sext i1 Y), X --> sub X, (zext i1 Y)
  if (N0.getOpcode() == ISD::SIGN_EXTEND &&
      N0.getOperand(0).getScalarValueSizeInBits() == 1 &&
      TLI.getBooleanContents(VT) == TargetLowering::ZeroOrOneBooleanContent) {
    SDValue ZExt = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0));
    return DAG.getNode(ISD::SUB, DL, VT, N1, ZExt);
  }

  // add X, (sextinreg Y i1) -> sub X, (and Y 1)
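  // (sextinreg Y i1) is 0 or -1, i.e. -(Y & 1), so the add becomes a sub.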
  if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
    VTSDNode *TN = cast<VTSDNode>(N1.getOperand(1));
    if (TN->getVT() == MVT::i1) {
      SDValue ZExt = DAG.getNode(ISD::AND, DL, VT, N1.getOperand(0),
                                 DAG.getConstant(1, DL, VT));
      return DAG.getNode(ISD::SUB, DL, VT, N0, ZExt);
    }
  }

  // (add X, (uaddo_carry Y, 0, Carry)) -> (uaddo_carry X, Y, Carry)
  if (N1.getOpcode() == ISD::UADDO_CARRY && isNullConstant(N1.getOperand(1)) &&
      N1.getResNo() == 0)
    return DAG.getNode(ISD::UADDO_CARRY, DL, N1->getVTList(),
                       N0, N1.getOperand(0), N1.getOperand(2));

  // (add X, Carry) -> (uaddo_carry X, 0, Carry)
  if (TLI.isOperationLegalOrCustom(ISD::UADDO_CARRY, VT))
    if (SDValue Carry = getAsCarry(TLI, N1))
      return DAG.getNode(ISD::UADDO_CARRY, DL,
                         DAG.getVTList(VT, Carry.getValueType()), N0,
                         DAG.getConstant(0, DL, VT), Carry);

  return SDValue();
}

SDValue DAGCombiner::visitADDC(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  SDLoc DL(N);

  // If the flag result is dead, turn this into an ADD.
  if (!N->hasAnyUseOfValue(1))
    return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
                     DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));

  // canonicalize constant to RHS.
  ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
  if (N0C && !N1C)
    return DAG.getNode(ISD::ADDC, DL, N->getVTList(), N1, N0);

  // fold (addc x, 0) -> x + no carry out
  if (isNullConstant(N1))
    return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE,
                                        DL, MVT::Glue));

  // If it cannot overflow, transform into an add.
  if (DAG.computeOverflowForUnsignedAdd(N0, N1) == SelectionDAG::OFK_Never)
    return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
                     DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));

  return SDValue();
}
/**
 * Flips a boolean if it is cheaper to compute. If the Force parameter is set,
 * then the flip also occurs if computing the inverse is the same cost.
 * This function returns an empty SDValue in case it cannot flip the boolean
 * without increasing the cost of the computation. If you want to flip a boolean
 * no matter what, use DAG.getLogicalNOT.
 */
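// For example, with ZeroOrOneBooleanContent, if V == (xor b, 1) then the
// inverse of V is simply b, so b is returned without creating a new node.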
static SDValue extractBooleanFlip(SDValue V, SelectionDAG &DAG,
                                  const TargetLowering &TLI,
                                  bool Force) {
  if (Force && isa<ConstantSDNode>(V))
    return DAG.getLogicalNOT(SDLoc(V), V, V.getValueType());

  if (V.getOpcode() != ISD::XOR)
    return SDValue();

  ConstantSDNode *Const = isConstOrConstSplat(V.getOperand(1), false);
  if (!Const)
    return SDValue();

  EVT VT = V.getValueType();

  bool IsFlip = false;
  switch(TLI.getBooleanContents(VT)) {
    case TargetLowering::ZeroOrOneBooleanContent:
      IsFlip = Const->isOne();
      break;
    case TargetLowering::ZeroOrNegativeOneBooleanContent:
      IsFlip = Const->isAllOnes();
      break;
    case TargetLowering::UndefinedBooleanContent:
      IsFlip = (Const->getAPIntValue() & 0x01) == 1;
      break;
  }

  if (IsFlip)
    return V.getOperand(0);
  if (Force)
    return DAG.getLogicalNOT(SDLoc(V), V, V.getValueType());
  return SDValue();
}

SDValue DAGCombiner::visitADDO(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  bool IsSigned = (ISD::SADDO == N->getOpcode());

  EVT CarryVT = N->getValueType(1);
  SDLoc DL(N);

  // If the flag result is dead, turn this into an ADD.
  if (!N->hasAnyUseOfValue(1))
    return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
                     DAG.getUNDEF(CarryVT));

  // canonicalize constant to RHS.
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(N->getOpcode(), DL, N->getVTList(), N1, N0);

  // fold (addo x, 0) -> x + no carry out
  if (isNullOrNullSplat(N1))
    return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT));

  // If it cannot overflow, transform into an add.
  if (DAG.willNotOverflowAdd(IsSigned, N0, N1))
    return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
                     DAG.getConstant(0, DL, CarryVT));

  if (IsSigned) {
    // fold (saddo (xor a, -1), 1) -> (ssub 0, a).
    if (isBitwiseNot(N0) && isOneOrOneSplat(N1))
      return DAG.getNode(ISD::SSUBO, DL, N->getVTList(),
                         DAG.getConstant(0, DL, VT), N0.getOperand(0));
  } else {
    // fold (uaddo (xor a, -1), 1) -> (usub 0, a) and flip carry.
    if (isBitwiseNot(N0) && isOneOrOneSplat(N1)) {
      SDValue Sub = DAG.getNode(ISD::USUBO, DL, N->getVTList(),
                                DAG.getConstant(0, DL, VT), N0.getOperand(0));
      return CombineTo(
          N, Sub, DAG.getLogicalNOT(DL, Sub.getValue(1), Sub->getValueType(1)));
    }

    if (SDValue Combined = visitUADDOLike(N0, N1, N))
      return Combined;

    if (SDValue Combined = visitUADDOLike(N1, N0, N))
      return Combined;
  }

  return SDValue();
}

SDValue DAGCombiner::visitUADDOLike(SDValue N0, SDValue N1, SDNode *N) {
  EVT VT = N0.getValueType();
  if (VT.isVector())
    return SDValue();

  // (uaddo X, (uaddo_carry Y, 0, Carry)) -> (uaddo_carry X, Y, Carry)
  // If Y + 1 cannot overflow.
  if (N1.getOpcode() == ISD::UADDO_CARRY && isNullConstant(N1.getOperand(1))) {
    SDValue Y = N1.getOperand(0);
    SDValue One = DAG.getConstant(1, SDLoc(N), Y.getValueType());
    if (DAG.computeOverflowForUnsignedAdd(Y, One) == SelectionDAG::OFK_Never)
      return DAG.getNode(ISD::UADDO_CARRY, SDLoc(N), N->getVTList(), N0, Y,
                         N1.getOperand(2));
  }

  // (uaddo X, Carry) -> (uaddo_carry X, 0, Carry)
  if (TLI.isOperationLegalOrCustom(ISD::UADDO_CARRY, VT))
    if (SDValue Carry = getAsCarry(TLI, N1))
      return DAG.getNode(ISD::UADDO_CARRY, SDLoc(N), N->getVTList(), N0,
                         DAG.getConstant(0, SDLoc(N), VT), Carry);

  return SDValue();
}

SDValue DAGCombiner::visitADDE(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue CarryIn = N->getOperand(2);

  // canonicalize constant to RHS
  ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
  if (N0C && !N1C)
    return DAG.getNode(ISD::ADDE, SDLoc(N), N->getVTList(),
                       N1, N0, CarryIn);

  // fold (adde x, y, false) -> (addc x, y)
  if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
    return DAG.getNode(ISD::ADDC, SDLoc(N), N->getVTList(), N0, N1);

  return SDValue();
}

SDValue DAGCombiner::visitUADDO_CARRY(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue CarryIn = N->getOperand(2);
  SDLoc DL(N);

  // canonicalize constant to RHS
  ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
  if (N0C && !N1C)
    return DAG.getNode(ISD::UADDO_CARRY, DL, N->getVTList(), N1, N0, CarryIn);

  // fold (uaddo_carry x, y, false) -> (uaddo x, y)
  if (isNullConstant(CarryIn)) {
    if (!LegalOperations ||
        TLI.isOperationLegalOrCustom(ISD::UADDO, N->getValueType(0)))
      return DAG.getNode(ISD::UADDO, DL, N->getVTList(), N0, N1);
  }

  // fold (uaddo_carry 0, 0, X) -> (and (ext/trunc X), 1) and no carry.
  if (isNullConstant(N0) && isNullConstant(N1)) {
    EVT VT = N0.getValueType();
    EVT CarryVT = CarryIn.getValueType();
    SDValue CarryExt = DAG.getBoolExtOrTrunc(CarryIn, DL, VT, CarryVT);
    AddToWorklist(CarryExt.getNode());
    return CombineTo(N, DAG.getNode(ISD::AND, DL, VT, CarryExt,
                                    DAG.getConstant(1, DL, VT)),
                     DAG.getConstant(0, DL, CarryVT));
  }

  if (SDValue Combined = visitUADDO_CARRYLike(N0, N1, CarryIn, N))
    return Combined;

  if (SDValue Combined = visitUADDO_CARRYLike(N1, N0, CarryIn, N))
    return Combined;

  // We want to avoid useless duplication.
  // TODO: This is done automatically for binary operations. As UADDO_CARRY is
  // not a binary operation, it is not really possible to leverage this
  // existing mechanism for it. However, if more operations require the same
  // deduplication logic, then it may be worth generalizing.
  SDValue Ops[] = {N1, N0, CarryIn};
  SDNode *CSENode =
      DAG.getNodeIfExists(ISD::UADDO_CARRY, N->getVTList(), Ops, N->getFlags());
  if (CSENode)
    return SDValue(CSENode, 0);

  return SDValue();
}

/**
 * If we are facing some sort of diamond carry propagation pattern try to
 * break it up to generate something like:
 *   (uaddo_carry X, 0, (uaddo_carry A, B, Z):Carry)
 *
 * The end result is usually an increase in operations required, but because
 * the carry is now linearized, other transforms can kick in and optimize the
 * DAG.
 *
 * Patterns typically look something like
 *                (uaddo A, B)
 *                /          \
 *             Carry         Sum
 *               |             \
 *               | (uaddo_carry *, 0, Z)
 *               |       /
 *                \   Carry
 *                 |   /
 * (uaddo_carry X, *, *)
 *
 * But numerous variations exist. Our goal is to identify A, B, X and Z and
 * produce a combine with a single path for carry propagation.
 */
static SDValue combineUADDO_CARRYDiamond(DAGCombiner &Combiner,
                                         SelectionDAG &DAG, SDValue X,
                                         SDValue Carry0, SDValue Carry1,
                                         SDNode *N) {
  if (Carry1.getResNo() != 1 || Carry0.getResNo() != 1)
    return SDValue();
  if (Carry1.getOpcode() != ISD::UADDO)
    return SDValue();

  SDValue Z;

  /**
   * First look for a suitable Z. It will present itself in the form of
   * (uaddo_carry Y, 0, Z) or its equivalent (uaddo Y, 1) for Z=true
   */
  if (Carry0.getOpcode() == ISD::UADDO_CARRY &&
      isNullConstant(Carry0.getOperand(1))) {
    Z = Carry0.getOperand(2);
  } else if (Carry0.getOpcode() == ISD::UADDO &&
             isOneConstant(Carry0.getOperand(1))) {
    EVT VT = Combiner.getSetCCResultType(Carry0.getValueType());
    Z = DAG.getConstant(1, SDLoc(Carry0.getOperand(1)), VT);
  } else {
    // We couldn't find a suitable Z.
    return SDValue();
  }


  auto cancelDiamond = [&](SDValue A, SDValue B) {
    SDLoc DL(N);
    SDValue NewY =
        DAG.getNode(ISD::UADDO_CARRY, DL, Carry0->getVTList(), A, B, Z);
    Combiner.AddToWorklist(NewY.getNode());
    return DAG.getNode(ISD::UADDO_CARRY, DL, N->getVTList(), X,
                       DAG.getConstant(0, DL, X.getValueType()),
                       NewY.getValue(1));
  };

  /**
   *         (uaddo A, B)
   *              |
   *             Sum
   *              |
   * (uaddo_carry *, 0, Z)
   */
  if (Carry0.getOperand(0) == Carry1.getValue(0)) {
    return cancelDiamond(Carry1.getOperand(0), Carry1.getOperand(1));
  }

  /**
   * (uaddo_carry A, 0, Z)
   *         |
   *        Sum
   *         |
   *  (uaddo *, B)
   */
  if (Carry1.getOperand(0) == Carry0.getValue(0)) {
    return cancelDiamond(Carry0.getOperand(0), Carry1.getOperand(1));
  }

  if (Carry1.getOperand(1) == Carry0.getValue(0)) {
    return cancelDiamond(Carry1.getOperand(0), Carry0.getOperand(0));
  }

  return SDValue();
}

// If we are facing some sort of diamond carry/borrow in/out pattern try to
// match patterns like:
//
//          (uaddo A, B)            CarryIn
//            |  \                     |
//            |   \                    |
//    PartialSum   PartialCarryOutX   /
//            |        |             /
//            |    ____|____________/
//            |   /    |
//     (uaddo *, *)    \________
//       |  \                   \
//       |   \                   |
//       |    PartialCarryOutY   |
//       |        \              |
//       |         \            /
//   AddCarrySum    |    ______/
//                  |   /
//   CarryOut = (or *, *)
//
// And generate UADDO_CARRY (or USUBO_CARRY) with two result values:
//
//    {AddCarrySum, CarryOut} = (uaddo_carry A, B, CarryIn)
//
// Our goal is to identify A, B, and CarryIn and produce UADDO_CARRY/USUBO_CARRY
// with a single path for carry/borrow out propagation.
static SDValue combineCarryDiamond(SelectionDAG &DAG, const TargetLowering &TLI,
                                   SDValue N0, SDValue N1, SDNode *N) {
  SDValue Carry0 = getAsCarry(TLI, N0);
  if (!Carry0)
    return SDValue();
  SDValue Carry1 = getAsCarry(TLI, N1);
  if (!Carry1)
    return SDValue();

  unsigned Opcode = Carry0.getOpcode();
  if (Opcode != Carry1.getOpcode())
    return SDValue();
  if (Opcode != ISD::UADDO && Opcode != ISD::USUBO)
    return SDValue();
  // Guarantee identical type of CarryOut
  EVT CarryOutType = N->getValueType(0);
  if (CarryOutType != Carry0.getValue(1).getValueType() ||
      CarryOutType != Carry1.getValue(1).getValueType())
    return SDValue();

  // Canonicalize the add/sub of A and B (the top node in the above ASCII art)
  // as Carry0 and the add/sub of the carry in as Carry1 (the middle node).
  if (Carry1.getNode()->isOperandOf(Carry0.getNode()))
    std::swap(Carry0, Carry1);

  // Check if nodes are connected in expected way.
  if (Carry1.getOperand(0) != Carry0.getValue(0) &&
      Carry1.getOperand(1) != Carry0.getValue(0))
    return SDValue();
  // The carry-in value must be on the right-hand side for subtraction.
  unsigned CarryInOperandNum =
      Carry1.getOperand(0) == Carry0.getValue(0) ? 1 : 0;
  if (Opcode == ISD::USUBO && CarryInOperandNum != 1)
    return SDValue();
  SDValue CarryIn = Carry1.getOperand(CarryInOperandNum);

  unsigned NewOp = Opcode == ISD::UADDO ? ISD::UADDO_CARRY : ISD::USUBO_CARRY;
  if (!TLI.isOperationLegalOrCustom(NewOp, Carry0.getValue(0).getValueType()))
    return SDValue();

  // Verify that the carry/borrow in is plausibly a carry/borrow bit.
  CarryIn = getAsCarry(TLI, CarryIn, true);
  if (!CarryIn)
    return SDValue();

  SDLoc DL(N);
  SDValue Merged =
      DAG.getNode(NewOp, DL, Carry1->getVTList(), Carry0.getOperand(0),
                  Carry0.getOperand(1), CarryIn);

  // Please note that because we have proven that the result of the UADDO/USUBO
  // of A and B feeds into the UADDO/USUBO that does the carry/borrow in, we can
  // therefore prove that if the first UADDO/USUBO overflows, the second
  // UADDO/USUBO cannot. For example consider 8-bit numbers where 0xFF is the
  // maximum value.
  //
  //   0xFF + 0xFF == 0xFE with carry but 0xFE + 1 does not carry
  //   0x00 - 0xFF == 1 with a carry/borrow but 1 - 1 == 0 (no carry/borrow)
  //
  // This is important because it means that OR and XOR can be used to merge
  // carry flags; and that AND can return a constant zero.
  //
  // TODO: match other operations that can merge flags (ADD, etc)
  DAG.ReplaceAllUsesOfValueWith(Carry1.getValue(0), Merged.getValue(0));
  if (N->getOpcode() == ISD::AND)
    return DAG.getConstant(0, DL, CarryOutType);
  return Merged.getValue(1);
}

SDValue DAGCombiner::visitUADDO_CARRYLike(SDValue N0, SDValue N1,
                                          SDValue CarryIn, SDNode *N) {
  // fold (uaddo_carry (xor a, -1), b, c) -> (usubo_carry b, a, !c) and flip
  // carry.
  if (isBitwiseNot(N0))
    if (SDValue NotC = extractBooleanFlip(CarryIn, DAG, TLI, true)) {
      SDLoc DL(N);
      SDValue Sub = DAG.getNode(ISD::USUBO_CARRY, DL, N->getVTList(), N1,
                                N0.getOperand(0), NotC);
      return CombineTo(
          N, Sub, DAG.getLogicalNOT(DL, Sub.getValue(1), Sub->getValueType(1)));
    }

  // Iff the flag result is dead:
  // (uaddo_carry (add|uaddo X, Y), 0, Carry) -> (uaddo_carry X, Y, Carry)
  // Don't do this if the Carry comes from the uaddo. It won't remove the uaddo
  // or the dependency between the instructions.
  if ((N0.getOpcode() == ISD::ADD ||
       (N0.getOpcode() == ISD::UADDO && N0.getResNo() == 0 &&
        N0.getValue(1) != CarryIn)) &&
      isNullConstant(N1) && !N->hasAnyUseOfValue(1))
    return DAG.getNode(ISD::UADDO_CARRY, SDLoc(N), N->getVTList(),
                       N0.getOperand(0), N0.getOperand(1), CarryIn);

  /**
   * When one of the uaddo_carry arguments is itself a carry, we may be facing
   * a diamond carry propagation. In that case we try to transform the DAG
   * to ensure linear carry propagation if that is possible.
   */
  if (auto Y = getAsCarry(TLI, N1)) {
    // Because both are carries, Y and Z can be swapped.
    if (auto R = combineUADDO_CARRYDiamond(*this, DAG, N0, Y, CarryIn, N))
      return R;
    if (auto R = combineUADDO_CARRYDiamond(*this, DAG, N0, CarryIn, Y, N))
      return R;
  }

  return SDValue();
}

SDValue DAGCombiner::visitSADDO_CARRYLike(SDValue N0, SDValue N1,
                                          SDValue CarryIn, SDNode *N) {
  // fold (saddo_carry (xor a, -1), b, c) -> (ssubo_carry b, a, !c)
  if (isBitwiseNot(N0)) {
    if (SDValue NotC = extractBooleanFlip(CarryIn, DAG, TLI, true))
      return DAG.getNode(ISD::SSUBO_CARRY, SDLoc(N), N->getVTList(), N1,
                         N0.getOperand(0), NotC);
  }

  return SDValue();
}

SDValue DAGCombiner::visitSADDO_CARRY(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue CarryIn = N->getOperand(2);
  SDLoc DL(N);

  // canonicalize constant to RHS
  ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
  if (N0C && !N1C)
    return DAG.getNode(ISD::SADDO_CARRY, DL, N->getVTList(), N1, N0, CarryIn);

  // fold (saddo_carry x, y, false) -> (saddo x, y)
  if (isNullConstant(CarryIn)) {
    if (!LegalOperations ||
        TLI.isOperationLegalOrCustom(ISD::SADDO, N->getValueType(0)))
      return DAG.getNode(ISD::SADDO, DL, N->getVTList(), N0, N1);
  }

  if (SDValue Combined = visitSADDO_CARRYLike(N0, N1, CarryIn, N))
    return Combined;

  if (SDValue Combined = visitSADDO_CARRYLike(N1, N0, CarryIn, N))
    return Combined;

  return SDValue();
}

// Attempt to create a USUBSAT(LHS, RHS) node with DstVT, performing a
// clamp/truncation if necessary.
static SDValue getTruncatedUSUBSAT(EVT DstVT, EVT SrcVT, SDValue LHS,
                                   SDValue RHS, SelectionDAG &DAG,
                                   const SDLoc &DL) {
  assert(DstVT.getScalarSizeInBits() <= SrcVT.getScalarSizeInBits() &&
         "Illegal truncation");

  if (DstVT == SrcVT)
    return DAG.getNode(ISD::USUBSAT, DL, DstVT, LHS, RHS);

  // If the LHS is zero-extended then we can perform the USUBSAT as DstVT by
  // clamping RHS.
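  // e.g. for i32 -> i16, if LHS fits in 16 bits, any RHS above 0xFFFF would
  // saturate the result to 0 anyway, so RHS can be clamped to 0xFFFF first.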
  APInt UpperBits = APInt::getBitsSetFrom(SrcVT.getScalarSizeInBits(),
                                          DstVT.getScalarSizeInBits());
  if (!DAG.MaskedValueIsZero(LHS, UpperBits))
    return SDValue();

  SDValue SatLimit =
      DAG.getConstant(APInt::getLowBitsSet(SrcVT.getScalarSizeInBits(),
                                           DstVT.getScalarSizeInBits()),
                      DL, SrcVT);
  RHS = DAG.getNode(ISD::UMIN, DL, SrcVT, RHS, SatLimit);
  RHS = DAG.getNode(ISD::TRUNCATE, DL, DstVT, RHS);
  LHS = DAG.getNode(ISD::TRUNCATE, DL, DstVT, LHS);
  return DAG.getNode(ISD::USUBSAT, DL, DstVT, LHS, RHS);
}

// Try to find umax(a,b) - b or a - umin(a,b) patterns that may be converted to
// usubsat(a,b), optionally as a truncated type.
SDValue DAGCombiner::foldSubToUSubSat(EVT DstVT, SDNode *N) {
  if (N->getOpcode() != ISD::SUB ||
      !(!LegalOperations || hasOperation(ISD::USUBSAT, DstVT)))
    return SDValue();

  EVT SubVT = N->getValueType(0);
  SDValue Op0 = N->getOperand(0);
  SDValue Op1 = N->getOperand(1);

  // Try to find umax(a,b) - b or a - umin(a,b) patterns
  // that may be converted to usubsat(a,b).
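  // Both forms compute (a >= b ? a - b : 0) for unsigned values, which is
  // exactly usubsat(a, b).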
3755   if (Op0.getOpcode() == ISD::UMAX && Op0.hasOneUse()) {
3756     SDValue MaxLHS = Op0.getOperand(0);
3757     SDValue MaxRHS = Op0.getOperand(1);
3758     if (MaxLHS == Op1)
3759       return getTruncatedUSUBSAT(DstVT, SubVT, MaxRHS, Op1, DAG, SDLoc(N));
3760     if (MaxRHS == Op1)
3761       return getTruncatedUSUBSAT(DstVT, SubVT, MaxLHS, Op1, DAG, SDLoc(N));
3762   }
3763 
3764   if (Op1.getOpcode() == ISD::UMIN && Op1.hasOneUse()) {
3765     SDValue MinLHS = Op1.getOperand(0);
3766     SDValue MinRHS = Op1.getOperand(1);
3767     if (MinLHS == Op0)
3768       return getTruncatedUSUBSAT(DstVT, SubVT, Op0, MinRHS, DAG, SDLoc(N));
3769     if (MinRHS == Op0)
3770       return getTruncatedUSUBSAT(DstVT, SubVT, Op0, MinLHS, DAG, SDLoc(N));
3771   }
3772 
3773   // sub(a,trunc(umin(zext(a),b))) -> usubsat(a,trunc(umin(b,SatLimit)))
3774   if (Op1.getOpcode() == ISD::TRUNCATE &&
3775       Op1.getOperand(0).getOpcode() == ISD::UMIN &&
3776       Op1.getOperand(0).hasOneUse()) {
3777     SDValue MinLHS = Op1.getOperand(0).getOperand(0);
3778     SDValue MinRHS = Op1.getOperand(0).getOperand(1);
3779     if (MinLHS.getOpcode() == ISD::ZERO_EXTEND && MinLHS.getOperand(0) == Op0)
3780       return getTruncatedUSUBSAT(DstVT, MinLHS.getValueType(), MinLHS, MinRHS,
3781                                  DAG, SDLoc(N));
3782     if (MinRHS.getOpcode() == ISD::ZERO_EXTEND && MinRHS.getOperand(0) == Op0)
3783       return getTruncatedUSUBSAT(DstVT, MinLHS.getValueType(), MinRHS, MinLHS,
3784                                  DAG, SDLoc(N));
3785   }
3786 
3787   return SDValue();
3788 }
3789 
3790 // Since it may not be valid to emit a fold to zero for vector initializers
3791 // check if we can before folding.
tryFoldToZero(const SDLoc & DL,const TargetLowering & TLI,EVT VT,SelectionDAG & DAG,bool LegalOperations)3792 static SDValue tryFoldToZero(const SDLoc &DL, const TargetLowering &TLI, EVT VT,
3793                              SelectionDAG &DAG, bool LegalOperations) {
3794   if (!VT.isVector())
3795     return DAG.getConstant(0, DL, VT);
3796   if (!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT))
3797     return DAG.getConstant(0, DL, VT);
3798   return SDValue();
3799 }
3800 
visitSUB(SDNode * N)3801 SDValue DAGCombiner::visitSUB(SDNode *N) {
3802   SDValue N0 = N->getOperand(0);
3803   SDValue N1 = N->getOperand(1);
3804   EVT VT = N0.getValueType();
3805   SDLoc DL(N);
3806 
3807   auto PeekThroughFreeze = [](SDValue N) {
3808     if (N->getOpcode() == ISD::FREEZE && N.hasOneUse())
3809       return N->getOperand(0);
3810     return N;
3811   };
3812 
3813   // fold (sub x, x) -> 0
3814   // FIXME: Refactor this and xor and other similar operations together.
3815   if (PeekThroughFreeze(N0) == PeekThroughFreeze(N1))
3816     return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
3817 
3818   // fold (sub c1, c2) -> c3
3819   if (SDValue C = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0, N1}))
3820     return C;
3821 
3822   // fold vector ops
3823   if (VT.isVector()) {
3824     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
3825       return FoldedVOp;
3826 
3827     // fold (sub x, 0) -> x, vector edition
3828     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
3829       return N0;
3830   }
3831 
3832   if (SDValue NewSel = foldBinOpIntoSelect(N))
3833     return NewSel;
3834 
3835   ConstantSDNode *N1C = getAsNonOpaqueConstant(N1);
3836 
3837   // fold (sub x, c) -> (add x, -c)
3838   if (N1C) {
3839     return DAG.getNode(ISD::ADD, DL, VT, N0,
3840                        DAG.getConstant(-N1C->getAPIntValue(), DL, VT));
3841   }
3842 
3843   if (isNullOrNullSplat(N0)) {
3844     unsigned BitWidth = VT.getScalarSizeInBits();
3845     // Right-shifting everything out but the sign bit followed by negation is
3846     // the same as flipping arithmetic/logical shift type without the negation:
3847     // -(X >>u 31) -> (X >>s 31)
3848     // -(X >>s 31) -> (X >>u 31)
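    // e.g. for i32 with X = 0x80000000: X >>u 31 == 1 and 0 - 1 == -1, which
    // matches X >>s 31 == -1; for any X with a clear sign bit both sides are
    // 0, so the two shift kinds are interchangeable here.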
3849     if (N1->getOpcode() == ISD::SRA || N1->getOpcode() == ISD::SRL) {
3850       ConstantSDNode *ShiftAmt = isConstOrConstSplat(N1.getOperand(1));
3851       if (ShiftAmt && ShiftAmt->getAPIntValue() == (BitWidth - 1)) {
3852         auto NewSh = N1->getOpcode() == ISD::SRA ? ISD::SRL : ISD::SRA;
3853         if (!LegalOperations || TLI.isOperationLegal(NewSh, VT))
3854           return DAG.getNode(NewSh, DL, VT, N1.getOperand(0), N1.getOperand(1));
3855       }
3856     }
3857 
3858     // 0 - X --> 0 if the sub is NUW.
3859     if (N->getFlags().hasNoUnsignedWrap())
3860       return N0;
3861 
3862     if (DAG.MaskedValueIsZero(N1, ~APInt::getSignMask(BitWidth))) {
3863       // N1 is either 0 or the minimum signed value. If the sub is NSW, then
3864       // N1 must be 0 because negating the minimum signed value is undefined.
3865       if (N->getFlags().hasNoSignedWrap())
3866         return N0;
3867 
3868       // 0 - X --> X if X is 0 or the minimum signed value.
3869       return N1;
3870     }
3871 
3872     // Convert 0 - abs(x).
3873     if (N1.getOpcode() == ISD::ABS && N1.hasOneUse() &&
3874         !TLI.isOperationLegalOrCustom(ISD::ABS, VT))
3875       if (SDValue Result = TLI.expandABS(N1.getNode(), DAG, true))
3876         return Result;
3877 
3878     // Fold neg(splat(neg(x))) -> splat(x)
3879     if (VT.isVector()) {
3880       SDValue N1S = DAG.getSplatValue(N1, true);
3881       if (N1S && N1S.getOpcode() == ISD::SUB &&
3882           isNullConstant(N1S.getOperand(0)))
3883         return DAG.getSplat(VT, DL, N1S.getOperand(1));
3884     }
3885   }
3886 
3887   // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1)
3888   if (isAllOnesOrAllOnesSplat(N0))
3889     return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
3890 
3891   // fold (A - (0-B)) -> A+B
3892   if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0)))
3893     return DAG.getNode(ISD::ADD, DL, VT, N0, N1.getOperand(1));
3894 
3895   // fold A-(A-B) -> B
3896   if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(0))
3897     return N1.getOperand(1);
3898 
3899   // fold (A+B)-A -> B
3900   if (N0.getOpcode() == ISD::ADD && N0.getOperand(0) == N1)
3901     return N0.getOperand(1);
3902 
3903   // fold (A+B)-B -> A
3904   if (N0.getOpcode() == ISD::ADD && N0.getOperand(1) == N1)
3905     return N0.getOperand(0);
3906 
3907   // fold (A+C1)-C2 -> A+(C1-C2)
3908   if (N0.getOpcode() == ISD::ADD) {
3909     SDValue N01 = N0.getOperand(1);
3910     if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N01, N1}))
3911       return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), NewC);
3912   }
3913 
3914   // fold C2-(A+C1) -> (C2-C1)-A
3915   if (N1.getOpcode() == ISD::ADD) {
3916     SDValue N11 = N1.getOperand(1);
3917     if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0, N11}))
3918       return DAG.getNode(ISD::SUB, DL, VT, NewC, N1.getOperand(0));
3919   }
3920 
3921   // fold (A-C1)-C2 -> A-(C1+C2)
3922   if (N0.getOpcode() == ISD::SUB) {
3923     SDValue N01 = N0.getOperand(1);
3924     if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N01, N1}))
3925       return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), NewC);
3926   }
3927 
3928   // fold (c1-A)-c2 -> (c1-c2)-A
3929   if (N0.getOpcode() == ISD::SUB) {
3930     SDValue N00 = N0.getOperand(0);
3931     if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N00, N1}))
3932       return DAG.getNode(ISD::SUB, DL, VT, NewC, N0.getOperand(1));
3933   }
3934 
3935   // fold ((A+(B+or-C))-B) -> A+or-C
3936   if (N0.getOpcode() == ISD::ADD &&
3937       (N0.getOperand(1).getOpcode() == ISD::SUB ||
3938        N0.getOperand(1).getOpcode() == ISD::ADD) &&
3939       N0.getOperand(1).getOperand(0) == N1)
3940     return DAG.getNode(N0.getOperand(1).getOpcode(), DL, VT, N0.getOperand(0),
3941                        N0.getOperand(1).getOperand(1));
3942 
3943   // fold ((A+(C+B))-B) -> A+C
3944   if (N0.getOpcode() == ISD::ADD && N0.getOperand(1).getOpcode() == ISD::ADD &&
3945       N0.getOperand(1).getOperand(1) == N1)
3946     return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0),
3947                        N0.getOperand(1).getOperand(0));
3948 
3949   // fold ((A-(B-C))-C) -> A-B
3950   if (N0.getOpcode() == ISD::SUB && N0.getOperand(1).getOpcode() == ISD::SUB &&
3951       N0.getOperand(1).getOperand(1) == N1)
3952     return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0),
3953                        N0.getOperand(1).getOperand(0));
3954 
3955   // fold (A-(B-C)) -> A+(C-B)
3956   if (N1.getOpcode() == ISD::SUB && N1.hasOneUse())
3957     return DAG.getNode(ISD::ADD, DL, VT, N0,
3958                        DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(1),
3959                                    N1.getOperand(0)));
3960 
3961   // A - (A & B)  ->  A & (~B)
3962   if (N1.getOpcode() == ISD::AND) {
3963     SDValue A = N1.getOperand(0);
3964     SDValue B = N1.getOperand(1);
3965     if (A != N0)
3966       std::swap(A, B);
3967     if (A == N0 &&
3968         (N1.hasOneUse() || isConstantOrConstantVector(B, /*NoOpaques=*/true))) {
3969       SDValue InvB =
3970           DAG.getNode(ISD::XOR, DL, VT, B, DAG.getAllOnesConstant(DL, VT));
3971       return DAG.getNode(ISD::AND, DL, VT, A, InvB);
3972     }
3973   }
3974 
3975   // fold (X - (-Y * Z)) -> (X + (Y * Z))
3976   if (N1.getOpcode() == ISD::MUL && N1.hasOneUse()) {
3977     if (N1.getOperand(0).getOpcode() == ISD::SUB &&
3978         isNullOrNullSplat(N1.getOperand(0).getOperand(0))) {
3979       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT,
3980                                 N1.getOperand(0).getOperand(1),
3981                                 N1.getOperand(1));
3982       return DAG.getNode(ISD::ADD, DL, VT, N0, Mul);
3983     }
3984     if (N1.getOperand(1).getOpcode() == ISD::SUB &&
3985         isNullOrNullSplat(N1.getOperand(1).getOperand(0))) {
3986       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT,
3987                                 N1.getOperand(0),
3988                                 N1.getOperand(1).getOperand(1));
3989       return DAG.getNode(ISD::ADD, DL, VT, N0, Mul);
3990     }
3991   }
3992 
3993   // If either operand of a sub is undef, the result is undef
3994   if (N0.isUndef())
3995     return N0;
3996   if (N1.isUndef())
3997     return N1;
3998 
3999   if (SDValue V = foldAddSubBoolOfMaskedVal(N, DAG))
4000     return V;
4001 
4002   if (SDValue V = foldAddSubOfSignBit(N, DAG))
4003     return V;
4004 
4005   if (SDValue V = foldAddSubMasked1(false, N0, N1, DAG, SDLoc(N)))
4006     return V;
4007 
4008   if (SDValue V = foldSubToUSubSat(VT, N))
4009     return V;
4010 
4011   // (x - y) - 1  ->  add (xor y, -1), x
4012   if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() && isOneOrOneSplat(N1)) {
4013     SDValue Xor = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(1),
4014                               DAG.getAllOnesConstant(DL, VT));
4015     return DAG.getNode(ISD::ADD, DL, VT, Xor, N0.getOperand(0));
4016   }
4017 
4018   // Look for:
4019   //   sub y, (xor x, -1)
4020   // And if the target does not like this form then turn into:
4021   //   add (add x, y), 1
4022   if (TLI.preferIncOfAddToSubOfNot(VT) && N1.hasOneUse() && isBitwiseNot(N1)) {
4023     SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, N1.getOperand(0));
4024     return DAG.getNode(ISD::ADD, DL, VT, Add, DAG.getConstant(1, DL, VT));
4025   }
4026 
4027   // Hoist one-use addition by non-opaque constant:
4028   //   (x + C) - y  ->  (x - y) + C
4029   if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
4030       isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
4031     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), N1);
4032     return DAG.getNode(ISD::ADD, DL, VT, Sub, N0.getOperand(1));
4033   }
4034   // y - (x + C)  ->  (y - x) - C
4035   if (N1.getOpcode() == ISD::ADD && N1.hasOneUse() &&
4036       isConstantOrConstantVector(N1.getOperand(1), /*NoOpaques=*/true)) {
4037     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, N1.getOperand(0));
4038     return DAG.getNode(ISD::SUB, DL, VT, Sub, N1.getOperand(1));
4039   }
4040   // (x - C) - y  ->  (x - y) - C
4041   // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
4042   if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
4043       isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
4044     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), N1);
4045     return DAG.getNode(ISD::SUB, DL, VT, Sub, N0.getOperand(1));
4046   }
4047   // (C - x) - y  ->  C - (x + y)
4048   if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
4049       isConstantOrConstantVector(N0.getOperand(0), /*NoOpaques=*/true)) {
4050     SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1), N1);
4051     return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), Add);
4052   }
4053 
4054   // If the target's bool is represented as 0/-1, prefer to make this 'add 0/-1'
4055   // rather than 'sub 0/1' (the sext should get folded).
4056   // sub X, (zext i1 Y) --> add X, (sext i1 Y)
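  // e.g. if Y == 1 then (sub X, zext(Y)) == X - 1 and (add X, sext(Y)) ==
  // X + (-1) == X - 1; if Y == 0 both forms are just X.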
4057   if (N1.getOpcode() == ISD::ZERO_EXTEND &&
4058       N1.getOperand(0).getScalarValueSizeInBits() == 1 &&
4059       TLI.getBooleanContents(VT) ==
4060           TargetLowering::ZeroOrNegativeOneBooleanContent) {
4061     SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N1.getOperand(0));
4062     return DAG.getNode(ISD::ADD, DL, VT, N0, SExt);
4063   }
4064 
4065   // fold Y = sra (X, size(X)-1); sub (xor (X, Y), Y) -> (abs X)
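  // e.g. for i8 X = -5: Y = X >>s 7 = -1, so (X ^ Y) - Y = 4 - (-1) = 5 = |X|;
  // for X >= 0, Y = 0 and the expression reduces to X.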
4066   if (TLI.isOperationLegalOrCustom(ISD::ABS, VT)) {
4067     if (N0.getOpcode() == ISD::XOR && N1.getOpcode() == ISD::SRA) {
4068       SDValue X0 = N0.getOperand(0), X1 = N0.getOperand(1);
4069       SDValue S0 = N1.getOperand(0);
4070       if ((X0 == S0 && X1 == N1) || (X0 == N1 && X1 == S0))
4071         if (ConstantSDNode *C = isConstOrConstSplat(N1.getOperand(1)))
4072           if (C->getAPIntValue() == (VT.getScalarSizeInBits() - 1))
4073             return DAG.getNode(ISD::ABS, SDLoc(N), VT, S0);
4074     }
4075   }
4076 
4077   // If the relocation model supports it, consider symbol offsets.
4078   if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(N0))
4079     if (!LegalOperations && TLI.isOffsetFoldingLegal(GA)) {
4080       // fold (sub Sym+c1, Sym+c2) -> c1-c2
4081       if (GlobalAddressSDNode *GB = dyn_cast<GlobalAddressSDNode>(N1))
4082         if (GA->getGlobal() == GB->getGlobal())
4083           return DAG.getConstant((uint64_t)GA->getOffset() - GB->getOffset(),
4084                                  DL, VT);
4085     }
4086 
4087   // sub X, (sextinreg Y i1) -> add X, (and Y 1)
4088   if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
4089     VTSDNode *TN = cast<VTSDNode>(N1.getOperand(1));
4090     if (TN->getVT() == MVT::i1) {
4091       SDValue ZExt = DAG.getNode(ISD::AND, DL, VT, N1.getOperand(0),
4092                                  DAG.getConstant(1, DL, VT));
4093       return DAG.getNode(ISD::ADD, DL, VT, N0, ZExt);
4094     }
4095   }
4096 
4097   // canonicalize (sub X, (vscale * C)) to (add X, (vscale * -C))
4098   if (N1.getOpcode() == ISD::VSCALE && N1.hasOneUse()) {
4099     const APInt &IntVal = N1.getConstantOperandAPInt(0);
4100     return DAG.getNode(ISD::ADD, DL, VT, N0, DAG.getVScale(DL, VT, -IntVal));
4101   }
4102 
4103   // canonicalize (sub X, step_vector(C)) to (add X, step_vector(-C))
4104   if (N1.getOpcode() == ISD::STEP_VECTOR && N1.hasOneUse()) {
4105     APInt NewStep = -N1.getConstantOperandAPInt(0);
4106     return DAG.getNode(ISD::ADD, DL, VT, N0,
4107                        DAG.getStepVector(DL, VT, NewStep));
4108   }
4109 
4110   // Prefer an add for more folding potential and possibly better codegen:
4111   // sub N0, (lshr N10, width-1) --> add N0, (ashr N10, width-1)
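  // (lshr N10, width-1) is the sign bit as 0/1, while (ashr N10, width-1) is
  // the same bit as 0/-1, so subtracting the former equals adding the latter.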
4112   if (!LegalOperations && N1.getOpcode() == ISD::SRL && N1.hasOneUse()) {
4113     SDValue ShAmt = N1.getOperand(1);
4114     ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt);
4115     if (ShAmtC &&
4116         ShAmtC->getAPIntValue() == (N1.getScalarValueSizeInBits() - 1)) {
4117       SDValue SRA = DAG.getNode(ISD::SRA, DL, VT, N1.getOperand(0), ShAmt);
4118       return DAG.getNode(ISD::ADD, DL, VT, N0, SRA);
4119     }
4120   }
4121 
4122   // As with the previous fold, prefer add for more folding potential.
4123   // Subtracting SMIN/0 is the same as adding SMIN/0:
4124   // N0 - (X << BW-1) --> N0 + (X << BW-1)
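  // (X << BW-1) is either 0 or the sign-bit-only value 0x80..0, and in two's
  // complement 0x80..0 is its own negation, so sub and add agree here.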
4125   if (N1.getOpcode() == ISD::SHL) {
4126     ConstantSDNode *ShlC = isConstOrConstSplat(N1.getOperand(1));
4127     if (ShlC && ShlC->getAPIntValue() == VT.getScalarSizeInBits() - 1)
4128       return DAG.getNode(ISD::ADD, DL, VT, N1, N0);
4129   }
4130 
4131   // (sub (usubo_carry X, 0, Carry), Y) -> (usubo_carry X, Y, Carry)
4132   if (N0.getOpcode() == ISD::USUBO_CARRY && isNullConstant(N0.getOperand(1)) &&
4133       N0.getResNo() == 0 && N0.hasOneUse())
4134     return DAG.getNode(ISD::USUBO_CARRY, DL, N0->getVTList(),
4135                        N0.getOperand(0), N1, N0.getOperand(2));
4136 
4137   if (TLI.isOperationLegalOrCustom(ISD::UADDO_CARRY, VT)) {
4138     // (sub Carry, X)  ->  (uaddo_carry (sub 0, X), 0, Carry)
4139     if (SDValue Carry = getAsCarry(TLI, N0)) {
4140       SDValue X = N1;
4141       SDValue Zero = DAG.getConstant(0, DL, VT);
4142       SDValue NegX = DAG.getNode(ISD::SUB, DL, VT, Zero, X);
4143       return DAG.getNode(ISD::UADDO_CARRY, DL,
4144                          DAG.getVTList(VT, Carry.getValueType()), NegX, Zero,
4145                          Carry);
4146     }
4147   }
4148 
4149   // If there's no chance of borrowing from adjacent bits, then sub is xor:
4150   // sub C0, X --> xor X, C0
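  // e.g. with C0 = 0b1100 and known bits proving X is 0b0100 or 0, no bit
  // position can borrow: 12 - 4 == 8 == 12 ^ 4 and 12 - 0 == 12 == 12 ^ 0.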
4151   if (ConstantSDNode *C0 = isConstOrConstSplat(N0)) {
4152     if (!C0->isOpaque()) {
4153       const APInt &C0Val = C0->getAPIntValue();
4154       const APInt &MaybeOnes = ~DAG.computeKnownBits(N1).Zero;
4155       if ((C0Val - MaybeOnes) == (C0Val ^ MaybeOnes))
4156         return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
4157     }
4158   }
4159 
4160   // max(a,b) - min(a,b) --> abd(a,b)
4161   auto MatchSubMaxMin = [&](unsigned Max, unsigned Min, unsigned Abd) {
4162     if (N0.getOpcode() != Max || N1.getOpcode() != Min)
4163       return SDValue();
4164     if ((N0.getOperand(0) != N1.getOperand(0) ||
4165          N0.getOperand(1) != N1.getOperand(1)) &&
4166         (N0.getOperand(0) != N1.getOperand(1) ||
4167          N0.getOperand(1) != N1.getOperand(0)))
4168       return SDValue();
4169     if (!hasOperation(Abd, VT))
4170       return SDValue();
4171     return DAG.getNode(Abd, DL, VT, N0.getOperand(0), N0.getOperand(1));
4172   };
4173   if (SDValue R = MatchSubMaxMin(ISD::SMAX, ISD::SMIN, ISD::ABDS))
4174     return R;
4175   if (SDValue R = MatchSubMaxMin(ISD::UMAX, ISD::UMIN, ISD::ABDU))
4176     return R;
4177 
4178   return SDValue();
4179 }
4180 
4181 SDValue DAGCombiner::visitSUBSAT(SDNode *N) {
4182   unsigned Opcode = N->getOpcode();
4183   SDValue N0 = N->getOperand(0);
4184   SDValue N1 = N->getOperand(1);
4185   EVT VT = N0.getValueType();
4186   bool IsSigned = Opcode == ISD::SSUBSAT;
4187   SDLoc DL(N);
4188 
4189   // fold (sub_sat x, undef) -> 0
4190   if (N0.isUndef() || N1.isUndef())
4191     return DAG.getConstant(0, DL, VT);
4192 
4193   // fold (sub_sat x, x) -> 0
4194   if (N0 == N1)
4195     return DAG.getConstant(0, DL, VT);
4196 
4197   // fold (sub_sat c1, c2) -> c3
4198   if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
4199     return C;
4200 
4201   // fold vector ops
4202   if (VT.isVector()) {
4203     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4204       return FoldedVOp;
4205 
4206     // fold (sub_sat x, 0) -> x, vector edition
4207     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
4208       return N0;
4209   }
4210 
4211   // fold (sub_sat x, 0) -> x
4212   if (isNullConstant(N1))
4213     return N0;
4214 
4215   // If it cannot overflow, transform into a sub.
4216   if (DAG.willNotOverflowSub(IsSigned, N0, N1))
4217     return DAG.getNode(ISD::SUB, DL, VT, N0, N1);
4218 
4219   return SDValue();
4220 }
4221 
4222 SDValue DAGCombiner::visitSUBC(SDNode *N) {
4223   SDValue N0 = N->getOperand(0);
4224   SDValue N1 = N->getOperand(1);
4225   EVT VT = N0.getValueType();
4226   SDLoc DL(N);
4227 
4228   // If the flag result is dead, turn this into a SUB.
4229   if (!N->hasAnyUseOfValue(1))
4230     return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
4231                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
4232 
4233   // fold (subc x, x) -> 0 + no borrow
4234   if (N0 == N1)
4235     return CombineTo(N, DAG.getConstant(0, DL, VT),
4236                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
4237 
4238   // fold (subc x, 0) -> x + no borrow
4239   if (isNullConstant(N1))
4240     return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
4241 
4242   // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1) + no borrow
4243   if (isAllOnesConstant(N0))
4244     return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0),
4245                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
4246 
4247   return SDValue();
4248 }
4249 
4250 SDValue DAGCombiner::visitSUBO(SDNode *N) {
4251   SDValue N0 = N->getOperand(0);
4252   SDValue N1 = N->getOperand(1);
4253   EVT VT = N0.getValueType();
4254   bool IsSigned = (ISD::SSUBO == N->getOpcode());
4255 
4256   EVT CarryVT = N->getValueType(1);
4257   SDLoc DL(N);
4258 
4259   // If the flag result is dead, turn this into a SUB.
4260   if (!N->hasAnyUseOfValue(1))
4261     return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
4262                      DAG.getUNDEF(CarryVT));
4263 
4264   // fold (subo x, x) -> 0 + no borrow
4265   if (N0 == N1)
4266     return CombineTo(N, DAG.getConstant(0, DL, VT),
4267                      DAG.getConstant(0, DL, CarryVT));
4268 
4269   ConstantSDNode *N1C = getAsNonOpaqueConstant(N1);
4270 
4271   // fold (subo x, c) -> (addo x, -c)
4272   if (IsSigned && N1C && !N1C->isMinSignedValue()) {
4273     return DAG.getNode(ISD::SADDO, DL, N->getVTList(), N0,
4274                        DAG.getConstant(-N1C->getAPIntValue(), DL, VT));
4275   }
4276 
4277   // fold (subo x, 0) -> x + no borrow
4278   if (isNullOrNullSplat(N1))
4279     return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT));
4280 
4281   // If it cannot overflow, transform into a sub.
4282   if (DAG.willNotOverflowSub(IsSigned, N0, N1))
4283     return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
4284                      DAG.getConstant(0, DL, CarryVT));
4285 
4286   // Canonicalize (usubo -1, x) -> ~x, i.e. (xor x, -1) + no borrow
4287   if (!IsSigned && isAllOnesOrAllOnesSplat(N0))
4288     return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0),
4289                      DAG.getConstant(0, DL, CarryVT));
4290 
4291   return SDValue();
4292 }
4293 
4294 SDValue DAGCombiner::visitSUBE(SDNode *N) {
4295   SDValue N0 = N->getOperand(0);
4296   SDValue N1 = N->getOperand(1);
4297   SDValue CarryIn = N->getOperand(2);
4298 
4299   // fold (sube x, y, false) -> (subc x, y)
4300   if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
4301     return DAG.getNode(ISD::SUBC, SDLoc(N), N->getVTList(), N0, N1);
4302 
4303   return SDValue();
4304 }
4305 
4306 SDValue DAGCombiner::visitUSUBO_CARRY(SDNode *N) {
4307   SDValue N0 = N->getOperand(0);
4308   SDValue N1 = N->getOperand(1);
4309   SDValue CarryIn = N->getOperand(2);
4310 
4311   // fold (usubo_carry x, y, false) -> (usubo x, y)
4312   if (isNullConstant(CarryIn)) {
4313     if (!LegalOperations ||
4314         TLI.isOperationLegalOrCustom(ISD::USUBO, N->getValueType(0)))
4315       return DAG.getNode(ISD::USUBO, SDLoc(N), N->getVTList(), N0, N1);
4316   }
4317 
4318   return SDValue();
4319 }
4320 
4321 SDValue DAGCombiner::visitSSUBO_CARRY(SDNode *N) {
4322   SDValue N0 = N->getOperand(0);
4323   SDValue N1 = N->getOperand(1);
4324   SDValue CarryIn = N->getOperand(2);
4325 
4326   // fold (ssubo_carry x, y, false) -> (ssubo x, y)
4327   if (isNullConstant(CarryIn)) {
4328     if (!LegalOperations ||
4329         TLI.isOperationLegalOrCustom(ISD::SSUBO, N->getValueType(0)))
4330       return DAG.getNode(ISD::SSUBO, SDLoc(N), N->getVTList(), N0, N1);
4331   }
4332 
4333   return SDValue();
4334 }
4335 
4336 // Notice that "mulfix" can be any of SMULFIX, SMULFIXSAT, UMULFIX and
4337 // UMULFIXSAT here.
4338 SDValue DAGCombiner::visitMULFIX(SDNode *N) {
4339   SDValue N0 = N->getOperand(0);
4340   SDValue N1 = N->getOperand(1);
4341   SDValue Scale = N->getOperand(2);
4342   EVT VT = N0.getValueType();
4343 
4344   // fold (mulfix x, undef, scale) -> 0
4345   if (N0.isUndef() || N1.isUndef())
4346     return DAG.getConstant(0, SDLoc(N), VT);
4347 
4348   // Canonicalize constant to RHS (vector doesn't have to splat)
4349   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4350      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4351     return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0, Scale);
4352 
4353   // fold (mulfix x, 0, scale) -> 0
4354   if (isNullConstant(N1))
4355     return DAG.getConstant(0, SDLoc(N), VT);
4356 
4357   return SDValue();
4358 }
4359 
4360 SDValue DAGCombiner::visitMUL(SDNode *N) {
4361   SDValue N0 = N->getOperand(0);
4362   SDValue N1 = N->getOperand(1);
4363   EVT VT = N0.getValueType();
4364   SDLoc DL(N);
4365 
4366   // fold (mul x, undef) -> 0
4367   if (N0.isUndef() || N1.isUndef())
4368     return DAG.getConstant(0, DL, VT);
4369 
4370   // fold (mul c1, c2) -> c1*c2
4371   if (SDValue C = DAG.FoldConstantArithmetic(ISD::MUL, DL, VT, {N0, N1}))
4372     return C;
4373 
4374   // canonicalize constant to RHS (vector doesn't have to splat)
4375   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4376       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4377     return DAG.getNode(ISD::MUL, DL, VT, N1, N0);
4378 
4379   bool N1IsConst = false;
4380   bool N1IsOpaqueConst = false;
4381   APInt ConstValue1;
4382 
4383   // fold vector ops
4384   if (VT.isVector()) {
4385     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4386       return FoldedVOp;
4387 
4388     N1IsConst = ISD::isConstantSplatVector(N1.getNode(), ConstValue1);
4389     assert((!N1IsConst ||
4390             ConstValue1.getBitWidth() == VT.getScalarSizeInBits()) &&
4391            "Splat APInt should be element width");
4392   } else {
4393     N1IsConst = isa<ConstantSDNode>(N1);
4394     if (N1IsConst) {
4395       ConstValue1 = N1->getAsAPIntVal();
4396       N1IsOpaqueConst = cast<ConstantSDNode>(N1)->isOpaque();
4397     }
4398   }
4399 
4400   // fold (mul x, 0) -> 0
4401   if (N1IsConst && ConstValue1.isZero())
4402     return N1;
4403 
4404   // fold (mul x, 1) -> x
4405   if (N1IsConst && ConstValue1.isOne())
4406     return N0;
4407 
4408   if (SDValue NewSel = foldBinOpIntoSelect(N))
4409     return NewSel;
4410 
4411   // fold (mul x, -1) -> 0-x
4412   if (N1IsConst && ConstValue1.isAllOnes())
4413     return DAG.getNegative(N0, DL, VT);
4414 
4415   // fold (mul x, (1 << c)) -> x << c
4416   if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
4417       (!VT.isVector() || Level <= AfterLegalizeVectorOps)) {
4418     if (SDValue LogBase2 = BuildLogBase2(N1, DL)) {
4419       EVT ShiftVT = getShiftAmountTy(N0.getValueType());
4420       SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
4421       return DAG.getNode(ISD::SHL, DL, VT, N0, Trunc);
4422     }
4423   }
4424 
4425   // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
4426   if (N1IsConst && !N1IsOpaqueConst && ConstValue1.isNegatedPowerOf2()) {
4427     unsigned Log2Val = (-ConstValue1).logBase2();
4428     EVT ShiftVT = getShiftAmountTy(N0.getValueType());
4429 
4430     // FIXME: If the input is something that is easily negated (e.g. a
4431     // single-use add), we should put the negate there.
4432     return DAG.getNode(ISD::SUB, DL, VT,
4433                        DAG.getConstant(0, DL, VT),
4434                        DAG.getNode(ISD::SHL, DL, VT, N0,
4435                             DAG.getConstant(Log2Val, DL, ShiftVT)));
4436   }
4437 
4438   // Attempt to reuse an existing umul_lohi/smul_lohi node, but only if the
4439   // hi result is in use, in case we hit this mid-legalization.
4440   for (unsigned LoHiOpc : {ISD::UMUL_LOHI, ISD::SMUL_LOHI}) {
4441     if (!LegalOperations || TLI.isOperationLegalOrCustom(LoHiOpc, VT)) {
4442       SDVTList LoHiVT = DAG.getVTList(VT, VT);
4443       // TODO: Can we match commutable operands with getNodeIfExists?
4444       if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N0, N1}))
4445         if (LoHi->hasAnyUseOfValue(1))
4446           return SDValue(LoHi, 0);
4447       if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N1, N0}))
4448         if (LoHi->hasAnyUseOfValue(1))
4449           return SDValue(LoHi, 0);
4450     }
4451   }
4452 
4453   // Try to transform:
4454   // (1) multiply-by-(power-of-2 +/- 1) into shift and add/sub.
4455   // mul x, (2^N + 1) --> add (shl x, N), x
4456   // mul x, (2^N - 1) --> sub (shl x, N), x
4457   // Examples: x * 33 --> (x << 5) + x
4458   //           x * 15 --> (x << 4) - x
4459   //           x * -33 --> -((x << 5) + x)
4460   //           x * -15 --> -((x << 4) - x) ; this reduces --> x - (x << 4)
4461   // (2) multiply-by-(power-of-2 +/- power-of-2) into shifts and add/sub.
4462   // mul x, (2^N + 2^M) --> (add (shl x, N), (shl x, M))
4463   // mul x, (2^N - 2^M) --> (sub (shl x, N), (shl x, M))
4464   // Examples: x * 0x8800 --> (x << 15) + (x << 11)
4465   //           x * 0xf800 --> (x << 16) - (x << 11)
4466   //           x * -0x8800 --> -((x << 15) + (x << 11))
4467   //           x * -0xf800 --> -((x << 16) - (x << 11)) ; (x << 11) - (x << 16)
4468   if (N1IsConst && TLI.decomposeMulByConstant(*DAG.getContext(), VT, N1)) {
4469     // TODO: We could handle more general decomposition of any constant by
4470     //       having the target set a limit on number of ops and making a
4471     //       callback to determine that sequence (similar to sqrt expansion).
4472     unsigned MathOp = ISD::DELETED_NODE;
4473     APInt MulC = ConstValue1.abs();
4474     // The constant `2` should be treated as (2^0 + 1).
4475     unsigned TZeros = MulC == 2 ? 0 : MulC.countr_zero();
4476     MulC.lshrInPlace(TZeros);
4477     if ((MulC - 1).isPowerOf2())
4478       MathOp = ISD::ADD;
4479     else if ((MulC + 1).isPowerOf2())
4480       MathOp = ISD::SUB;
4481 
4482     if (MathOp != ISD::DELETED_NODE) {
4483       unsigned ShAmt =
4484           MathOp == ISD::ADD ? (MulC - 1).logBase2() : (MulC + 1).logBase2();
4485       ShAmt += TZeros;
4486       assert(ShAmt < VT.getScalarSizeInBits() &&
4487              "multiply-by-constant generated out of bounds shift");
4488       SDValue Shl =
4489           DAG.getNode(ISD::SHL, DL, VT, N0, DAG.getConstant(ShAmt, DL, VT));
4490       SDValue R =
4491           TZeros ? DAG.getNode(MathOp, DL, VT, Shl,
4492                                DAG.getNode(ISD::SHL, DL, VT, N0,
4493                                            DAG.getConstant(TZeros, DL, VT)))
4494                  : DAG.getNode(MathOp, DL, VT, Shl, N0);
4495       if (ConstValue1.isNegative())
4496         R = DAG.getNegative(R, DL, VT);
4497       return R;
4498     }
4499   }
4500 
4501   // (mul (shl X, c1), c2) -> (mul X, c2 << c1)
4502   if (N0.getOpcode() == ISD::SHL) {
4503     SDValue N01 = N0.getOperand(1);
4504     if (SDValue C3 = DAG.FoldConstantArithmetic(ISD::SHL, DL, VT, {N1, N01}))
4505       return DAG.getNode(ISD::MUL, DL, VT, N0.getOperand(0), C3);
4506   }
4507 
4508   // Change (mul (shl X, C), Y) -> (shl (mul X, Y), C) when the shift has one
4509   // use.
4510   {
4511     SDValue Sh, Y;
4512 
4513     // Check for both (mul (shl X, C), Y)  and  (mul Y, (shl X, C)).
4514     if (N0.getOpcode() == ISD::SHL &&
4515         isConstantOrConstantVector(N0.getOperand(1)) && N0->hasOneUse()) {
4516       Sh = N0; Y = N1;
4517     } else if (N1.getOpcode() == ISD::SHL &&
4518                isConstantOrConstantVector(N1.getOperand(1)) &&
4519                N1->hasOneUse()) {
4520       Sh = N1; Y = N0;
4521     }
4522 
4523     if (Sh.getNode()) {
4524       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, Sh.getOperand(0), Y);
4525       return DAG.getNode(ISD::SHL, DL, VT, Mul, Sh.getOperand(1));
4526     }
4527   }
4528 
4529   // fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2)
4530   if (N0.getOpcode() == ISD::ADD &&
4531       DAG.isConstantIntBuildVectorOrConstantInt(N1) &&
4532       DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1)) &&
4533       isMulAddWithConstProfitable(N, N0, N1))
4534     return DAG.getNode(
4535         ISD::ADD, DL, VT,
4536         DAG.getNode(ISD::MUL, SDLoc(N0), VT, N0.getOperand(0), N1),
4537         DAG.getNode(ISD::MUL, SDLoc(N1), VT, N0.getOperand(1), N1));
4538 
4539   // Fold (mul (vscale * C0), C1) to (vscale * (C0 * C1)).
4540   ConstantSDNode *NC1 = isConstOrConstSplat(N1);
4541   if (N0.getOpcode() == ISD::VSCALE && NC1) {
4542     const APInt &C0 = N0.getConstantOperandAPInt(0);
4543     const APInt &C1 = NC1->getAPIntValue();
4544     return DAG.getVScale(DL, VT, C0 * C1);
4545   }
4546 
4547   // Fold (mul step_vector(C0), C1) to (step_vector(C0 * C1)).
4548   APInt MulVal;
4549   if (N0.getOpcode() == ISD::STEP_VECTOR &&
4550       ISD::isConstantSplatVector(N1.getNode(), MulVal)) {
4551     const APInt &C0 = N0.getConstantOperandAPInt(0);
4552     APInt NewStep = C0 * MulVal;
4553     return DAG.getStepVector(DL, VT, NewStep);
4554   }
4555 
4556   // Fold (mul x, 0/undef) -> 0 and
4557   //      (mul x, 1) -> x
4558   // into and(x, mask):
4559   // We can replace vectors with '0' and '1' factors with a clearing mask.
4560   if (VT.isFixedLengthVector()) {
4561     unsigned NumElts = VT.getVectorNumElements();
4562     SmallBitVector ClearMask;
4563     ClearMask.reserve(NumElts);
4564     auto IsClearMask = [&ClearMask](ConstantSDNode *V) {
4565       if (!V || V->isZero()) {
4566         ClearMask.push_back(true);
4567         return true;
4568       }
4569       ClearMask.push_back(false);
4570       return V->isOne();
4571     };
4572     if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::AND, VT)) &&
4573         ISD::matchUnaryPredicate(N1, IsClearMask, /*AllowUndefs*/ true)) {
4574       assert(N1.getOpcode() == ISD::BUILD_VECTOR && "Unknown constant vector");
4575       EVT LegalSVT = N1.getOperand(0).getValueType();
4576       SDValue Zero = DAG.getConstant(0, DL, LegalSVT);
4577       SDValue AllOnes = DAG.getAllOnesConstant(DL, LegalSVT);
4578       SmallVector<SDValue, 16> Mask(NumElts, AllOnes);
4579       for (unsigned I = 0; I != NumElts; ++I)
4580         if (ClearMask[I])
4581           Mask[I] = Zero;
4582       return DAG.getNode(ISD::AND, DL, VT, N0, DAG.getBuildVector(VT, DL, Mask));
4583     }
4584   }
4585 
4586   // reassociate mul
4587   if (SDValue RMUL = reassociateOps(ISD::MUL, DL, N0, N1, N->getFlags()))
4588     return RMUL;
4589 
4590   // Fold mul(vecreduce(x), vecreduce(y)) -> vecreduce(mul(x, y))
4591   if (SDValue SD =
4592           reassociateReduction(ISD::VECREDUCE_MUL, ISD::MUL, DL, VT, N0, N1))
4593     return SD;
4594 
4595   // Simplify the operands using demanded-bits information.
4596   if (SimplifyDemandedBits(SDValue(N, 0)))
4597     return SDValue(N, 0);
4598 
4599   return SDValue();
4600 }
4601 
4602 /// Return true if a divmod libcall is available.
4603 static bool isDivRemLibcallAvailable(SDNode *Node, bool isSigned,
4604                                      const TargetLowering &TLI) {
4605   RTLIB::Libcall LC;
4606   EVT NodeType = Node->getValueType(0);
4607   if (!NodeType.isSimple())
4608     return false;
4609   switch (NodeType.getSimpleVT().SimpleTy) {
4610   default: return false; // No libcall for vector types.
4611   case MVT::i8:   LC= isSigned ? RTLIB::SDIVREM_I8  : RTLIB::UDIVREM_I8;  break;
4612   case MVT::i16:  LC= isSigned ? RTLIB::SDIVREM_I16 : RTLIB::UDIVREM_I16; break;
4613   case MVT::i32:  LC= isSigned ? RTLIB::SDIVREM_I32 : RTLIB::UDIVREM_I32; break;
4614   case MVT::i64:  LC= isSigned ? RTLIB::SDIVREM_I64 : RTLIB::UDIVREM_I64; break;
4615   case MVT::i128: LC= isSigned ? RTLIB::SDIVREM_I128:RTLIB::UDIVREM_I128; break;
4616   }
4617 
4618   return TLI.getLibcallName(LC) != nullptr;
4619 }
4620 
4621 /// Issue divrem if both quotient and remainder are needed.
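/// e.g. if both (sdiv x, y) and (srem x, y) appear in the DAG, both can be
/// replaced by the two results of a single (sdivrem x, y) node.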
4622 SDValue DAGCombiner::useDivRem(SDNode *Node) {
4623   if (Node->use_empty())
4624     return SDValue(); // This is a dead node, leave it alone.
4625 
4626   unsigned Opcode = Node->getOpcode();
4627   bool isSigned = (Opcode == ISD::SDIV) || (Opcode == ISD::SREM);
4628   unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM;
4629 
4630   // DivMod lib calls can still work on non-legal types if using lib-calls.
4631   EVT VT = Node->getValueType(0);
4632   if (VT.isVector() || !VT.isInteger())
4633     return SDValue();
4634 
4635   if (!TLI.isTypeLegal(VT) && !TLI.isOperationCustom(DivRemOpc, VT))
4636     return SDValue();
4637 
4638   // If DIVREM is going to get expanded into a libcall,
4639   // but there is no libcall available, then don't combine.
4640   if (!TLI.isOperationLegalOrCustom(DivRemOpc, VT) &&
4641       !isDivRemLibcallAvailable(Node, isSigned, TLI))
4642     return SDValue();
4643 
4644   // If div is legal, it's better to do the normal expansion
4645   unsigned OtherOpcode = 0;
4646   if ((Opcode == ISD::SDIV) || (Opcode == ISD::UDIV)) {
4647     OtherOpcode = isSigned ? ISD::SREM : ISD::UREM;
4648     if (TLI.isOperationLegalOrCustom(Opcode, VT))
4649       return SDValue();
4650   } else {
4651     OtherOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
4652     if (TLI.isOperationLegalOrCustom(OtherOpcode, VT))
4653       return SDValue();
4654   }
4655 
4656   SDValue Op0 = Node->getOperand(0);
4657   SDValue Op1 = Node->getOperand(1);
4658   SDValue combined;
4659   for (SDNode *User : Op0->uses()) {
4660     if (User == Node || User->getOpcode() == ISD::DELETED_NODE ||
4661         User->use_empty())
4662       continue;
4663     // Convert the other matching node(s), too;
4664     // otherwise, the DIVREM may get target-legalized into something
4665     // target-specific that we won't be able to recognize.
4666     unsigned UserOpc = User->getOpcode();
4667     if ((UserOpc == Opcode || UserOpc == OtherOpcode || UserOpc == DivRemOpc) &&
4668         User->getOperand(0) == Op0 &&
4669         User->getOperand(1) == Op1) {
4670       if (!combined) {
4671         if (UserOpc == OtherOpcode) {
4672           SDVTList VTs = DAG.getVTList(VT, VT);
4673           combined = DAG.getNode(DivRemOpc, SDLoc(Node), VTs, Op0, Op1);
4674         } else if (UserOpc == DivRemOpc) {
4675           combined = SDValue(User, 0);
4676         } else {
4677           assert(UserOpc == Opcode);
4678           continue;
4679         }
4680       }
4681       if (UserOpc == ISD::SDIV || UserOpc == ISD::UDIV)
4682         CombineTo(User, combined);
4683       else if (UserOpc == ISD::SREM || UserOpc == ISD::UREM)
4684         CombineTo(User, combined.getValue(1));
4685     }
4686   }
4687   return combined;
4688 }
4689 
4690 static SDValue simplifyDivRem(SDNode *N, SelectionDAG &DAG) {
4691   SDValue N0 = N->getOperand(0);
4692   SDValue N1 = N->getOperand(1);
4693   EVT VT = N->getValueType(0);
4694   SDLoc DL(N);
4695 
4696   unsigned Opc = N->getOpcode();
4697   bool IsDiv = (ISD::SDIV == Opc) || (ISD::UDIV == Opc);
4698   ConstantSDNode *N1C = isConstOrConstSplat(N1);
4699 
4700   // X / undef -> undef
4701   // X % undef -> undef
4702   // X / 0 -> undef
4703   // X % 0 -> undef
4704   // NOTE: This includes vectors where any divisor element is zero/undef.
4705   if (DAG.isUndef(Opc, {N0, N1}))
4706     return DAG.getUNDEF(VT);
4707 
4708   // undef / X -> 0
4709   // undef % X -> 0
4710   if (N0.isUndef())
4711     return DAG.getConstant(0, DL, VT);
4712 
4713   // 0 / X -> 0
4714   // 0 % X -> 0
4715   ConstantSDNode *N0C = isConstOrConstSplat(N0);
4716   if (N0C && N0C->isZero())
4717     return N0;
4718 
4719   // X / X -> 1
4720   // X % X -> 0
4721   if (N0 == N1)
4722     return DAG.getConstant(IsDiv ? 1 : 0, DL, VT);
4723 
4724   // X / 1 -> X
4725   // X % 1 -> 0
4726   // If this is a boolean op (single-bit element type), we can't have
4727   // division-by-zero or remainder-by-zero, so assume the divisor is 1.
4728   // TODO: Similarly, if we're zero-extending a boolean divisor, then assume
4729   // it's a 1.
4730   if ((N1C && N1C->isOne()) || (VT.getScalarType() == MVT::i1))
4731     return IsDiv ? N0 : DAG.getConstant(0, DL, VT);
4732 
4733   return SDValue();
4734 }
4735 
4736 SDValue DAGCombiner::visitSDIV(SDNode *N) {
4737   SDValue N0 = N->getOperand(0);
4738   SDValue N1 = N->getOperand(1);
4739   EVT VT = N->getValueType(0);
4740   EVT CCVT = getSetCCResultType(VT);
4741   SDLoc DL(N);
4742 
4743   // fold (sdiv c1, c2) -> c1/c2
4744   if (SDValue C = DAG.FoldConstantArithmetic(ISD::SDIV, DL, VT, {N0, N1}))
4745     return C;
4746 
4747   // fold vector ops
4748   if (VT.isVector())
4749     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4750       return FoldedVOp;
4751 
4752   // fold (sdiv X, -1) -> 0-X
4753   ConstantSDNode *N1C = isConstOrConstSplat(N1);
4754   if (N1C && N1C->isAllOnes())
4755     return DAG.getNegative(N0, DL, VT);
4756 
4757   // fold (sdiv X, MIN_SIGNED) -> select(X == MIN_SIGNED, 1, 0)
4758   if (N1C && N1C->isMinSignedValue())
4759     return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
4760                          DAG.getConstant(1, DL, VT),
4761                          DAG.getConstant(0, DL, VT));
4762 
4763   if (SDValue V = simplifyDivRem(N, DAG))
4764     return V;
4765 
4766   if (SDValue NewSel = foldBinOpIntoSelect(N))
4767     return NewSel;
4768 
4769   // If we know the sign bits of both operands are zero, strength reduce to a
4770   // udiv instead.  Handles (X&15) /s 4 -> X&15 >> 2
4771   if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
4772     return DAG.getNode(ISD::UDIV, DL, N1.getValueType(), N0, N1);
4773 
4774   if (SDValue V = visitSDIVLike(N0, N1, N)) {
4775     // If the corresponding remainder node exists, update its users with
4776     // (Dividend - (Quotient * Divisor)).
4777     if (SDNode *RemNode = DAG.getNodeIfExists(ISD::SREM, N->getVTList(),
4778                                               { N0, N1 })) {
4779       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1);
4780       SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
4781       AddToWorklist(Mul.getNode());
4782       AddToWorklist(Sub.getNode());
4783       CombineTo(RemNode, Sub);
4784     }
4785     return V;
4786   }
4787 
4788   // sdiv, srem -> sdivrem
4789   // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
4790   // true.  Otherwise, we break the simplification logic in visitREM().
4791   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4792   if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
4793     if (SDValue DivRem = useDivRem(N))
4794         return DivRem;
4795 
4796   return SDValue();
4797 }
4798 
4799 static bool isDivisorPowerOfTwo(SDValue Divisor) {
4800   // Helper for determining whether a value is a power-2 constant scalar or a
4801   // vector of such elements.
4802   auto IsPowerOfTwo = [](ConstantSDNode *C) {
4803     if (C->isZero() || C->isOpaque())
4804       return false;
4805     if (C->getAPIntValue().isPowerOf2())
4806       return true;
4807     if (C->getAPIntValue().isNegatedPowerOf2())
4808       return true;
4809     return false;
4810   };
4811 
4812   return ISD::matchUnaryPredicate(Divisor, IsPowerOfTwo);
4813 }
4814 
4815 SDValue DAGCombiner::visitSDIVLike(SDValue N0, SDValue N1, SDNode *N) {
4816   SDLoc DL(N);
4817   EVT VT = N->getValueType(0);
4818   EVT CCVT = getSetCCResultType(VT);
4819   unsigned BitWidth = VT.getScalarSizeInBits();
4820 
4821   // fold (sdiv X, pow2) -> simple ops after legalize
4822   // FIXME: We check for the exact bit here because the generic lowering gives
4823   // better results in that case. The target-specific lowering should learn how
4824   // to handle exact sdivs efficiently.
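  // Worked example for i32 X sdiv 4 using the generic sequence below:
  // Sign = X >>s 31; Srl = Sign >>u 30, which is 3 when X < 0 and 0
  // otherwise; the result is (X + Srl) >>s 2, e.g. (-7 + 3) >>s 2 == -1
  // and 7 >>s 2 == 1, i.e. division rounded toward zero.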
4825   if (!N->getFlags().hasExact() && isDivisorPowerOfTwo(N1)) {
4826     // Target-specific implementation of sdiv x, pow2.
4827     if (SDValue Res = BuildSDIVPow2(N))
4828       return Res;
4829 
4830     // Create constants that are functions of the shift amount value.
4831     EVT ShiftAmtTy = getShiftAmountTy(N0.getValueType());
4832     SDValue Bits = DAG.getConstant(BitWidth, DL, ShiftAmtTy);
4833     SDValue C1 = DAG.getNode(ISD::CTTZ, DL, VT, N1);
4834     C1 = DAG.getZExtOrTrunc(C1, DL, ShiftAmtTy);
4835     SDValue Inexact = DAG.getNode(ISD::SUB, DL, ShiftAmtTy, Bits, C1);
4836     if (!isConstantOrConstantVector(Inexact))
4837       return SDValue();
4838 
4839     // Splat the sign bit into the register
4840     SDValue Sign = DAG.getNode(ISD::SRA, DL, VT, N0,
4841                                DAG.getConstant(BitWidth - 1, DL, ShiftAmtTy));
4842     AddToWorklist(Sign.getNode());
4843 
4844     // Add (N0 < 0) ? Pow2 - 1 : 0, so the shift rounds toward zero.
4845     SDValue Srl = DAG.getNode(ISD::SRL, DL, VT, Sign, Inexact);
4846     AddToWorklist(Srl.getNode());
4847     SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, Srl);
4848     AddToWorklist(Add.getNode());
4849     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Add, C1);
4850     AddToWorklist(Sra.getNode());
4851 
4852     // Special case: (sdiv X, 1) -> X
4853     // Special Case: (sdiv X, -1) -> 0-X
4854     SDValue One = DAG.getConstant(1, DL, VT);
4855     SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
4856     SDValue IsOne = DAG.getSetCC(DL, CCVT, N1, One, ISD::SETEQ);
4857     SDValue IsAllOnes = DAG.getSetCC(DL, CCVT, N1, AllOnes, ISD::SETEQ);
4858     SDValue IsOneOrAllOnes = DAG.getNode(ISD::OR, DL, CCVT, IsOne, IsAllOnes);
4859     Sra = DAG.getSelect(DL, VT, IsOneOrAllOnes, N0, Sra);
4860 
4861     // If dividing by a positive value, we're done. Otherwise, the result must
4862     // be negated.
4863     SDValue Zero = DAG.getConstant(0, DL, VT);
4864     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, Zero, Sra);
4865 
4866     // FIXME: Use SELECT_CC once we improve SELECT_CC constant-folding.
4867     SDValue IsNeg = DAG.getSetCC(DL, CCVT, N1, Zero, ISD::SETLT);
4868     SDValue Res = DAG.getSelect(DL, VT, IsNeg, Sub, Sra);
4869     return Res;
4870   }
4871 
4872   // If integer divide is expensive and we satisfy the requirements, emit an
4873   // alternate sequence.  Targets may check function attributes for size/speed
4874   // trade-offs.
4875   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4876   if (isConstantOrConstantVector(N1) &&
4877       !TLI.isIntDivCheap(N->getValueType(0), Attr))
4878     if (SDValue Op = BuildSDIV(N))
4879       return Op;
4880 
4881   return SDValue();
4882 }
4883 
4884 SDValue DAGCombiner::visitUDIV(SDNode *N) {
4885   SDValue N0 = N->getOperand(0);
4886   SDValue N1 = N->getOperand(1);
4887   EVT VT = N->getValueType(0);
4888   EVT CCVT = getSetCCResultType(VT);
4889   SDLoc DL(N);
4890 
4891   // fold (udiv c1, c2) -> c1/c2
4892   if (SDValue C = DAG.FoldConstantArithmetic(ISD::UDIV, DL, VT, {N0, N1}))
4893     return C;
4894 
4895   // fold vector ops
4896   if (VT.isVector())
4897     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4898       return FoldedVOp;
4899 
4900   // fold (udiv X, -1) -> select(X == -1, 1, 0)
4901   ConstantSDNode *N1C = isConstOrConstSplat(N1);
4902   if (N1C && N1C->isAllOnes() && CCVT.isVector() == VT.isVector()) {
4903     return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
4904                          DAG.getConstant(1, DL, VT),
4905                          DAG.getConstant(0, DL, VT));
4906   }
4907 
4908   if (SDValue V = simplifyDivRem(N, DAG))
4909     return V;
4910 
4911   if (SDValue NewSel = foldBinOpIntoSelect(N))
4912     return NewSel;
4913 
4914   if (SDValue V = visitUDIVLike(N0, N1, N)) {
4915     // If the corresponding remainder node exists, update its users with
4916     // (Dividend - (Quotient * Divisor)).
4917     if (SDNode *RemNode = DAG.getNodeIfExists(ISD::UREM, N->getVTList(),
4918                                               { N0, N1 })) {
4919       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1);
4920       SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
4921       AddToWorklist(Mul.getNode());
4922       AddToWorklist(Sub.getNode());
4923       CombineTo(RemNode, Sub);
4924     }
4925     return V;
4926   }
4927 
4928   // udiv, urem -> udivrem
4929   // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
4930   // true.  Otherwise, we break the simplification logic in visitREM().
4931   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4932   if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
4933     if (SDValue DivRem = useDivRem(N))
4934         return DivRem;
4935 
4936   return SDValue();
4937 }
4938 
4939 SDValue DAGCombiner::visitUDIVLike(SDValue N0, SDValue N1, SDNode *N) {
4940   SDLoc DL(N);
4941   EVT VT = N->getValueType(0);
4942 
4943   // fold (udiv x, (1 << c)) -> x >>u c
4944   if (isConstantOrConstantVector(N1, /*NoOpaques*/ true)) {
4945     if (SDValue LogBase2 = BuildLogBase2(N1, DL)) {
4946       AddToWorklist(LogBase2.getNode());
4947 
4948       EVT ShiftVT = getShiftAmountTy(N0.getValueType());
4949       SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
4950       AddToWorklist(Trunc.getNode());
4951       return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
4952     }
4953   }
4954 
4955   // fold (udiv x, (shl c, y)) -> x >>u (log2(c)+y) iff c is power of 2
4956   if (N1.getOpcode() == ISD::SHL) {
4957     SDValue N10 = N1.getOperand(0);
4958     if (isConstantOrConstantVector(N10, /*NoOpaques*/ true)) {
4959       if (SDValue LogBase2 = BuildLogBase2(N10, DL)) {
4960         AddToWorklist(LogBase2.getNode());
4961 
4962         EVT ADDVT = N1.getOperand(1).getValueType();
4963         SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT);
4964         AddToWorklist(Trunc.getNode());
4965         SDValue Add = DAG.getNode(ISD::ADD, DL, ADDVT, N1.getOperand(1), Trunc);
4966         AddToWorklist(Add.getNode());
4967         return DAG.getNode(ISD::SRL, DL, VT, N0, Add);
4968       }
4969     }
4970   }
4971 
4972   // fold (udiv x, c) -> alternate
4973   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4974   if (isConstantOrConstantVector(N1) &&
4975       !TLI.isIntDivCheap(N->getValueType(0), Attr))
4976     if (SDValue Op = BuildUDIV(N))
4977       return Op;
4978 
4979   return SDValue();
4980 }
4981 
4982 SDValue DAGCombiner::buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N) {
4983   if (!N->getFlags().hasExact() && isDivisorPowerOfTwo(N1) &&
4984       !DAG.doesNodeExist(ISD::SDIV, N->getVTList(), {N0, N1})) {
4985     // Target-specific implementation of srem x, pow2.
4986     if (SDValue Res = BuildSREMPow2(N))
4987       return Res;
4988   }
4989   return SDValue();
4990 }
4991 
4992 // Handles ISD::SREM and ISD::UREM.
4993 SDValue DAGCombiner::visitREM(SDNode *N) {
4994   unsigned Opcode = N->getOpcode();
4995   SDValue N0 = N->getOperand(0);
4996   SDValue N1 = N->getOperand(1);
4997   EVT VT = N->getValueType(0);
4998   EVT CCVT = getSetCCResultType(VT);
4999 
5000   bool isSigned = (Opcode == ISD::SREM);
5001   SDLoc DL(N);
5002 
5003   // fold (rem c1, c2) -> c1%c2
5004   if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
5005     return C;
5006 
5007   // fold (urem X, -1) -> select(FX == -1, 0, FX)
5008   // Freeze the numerator to avoid a miscompile with an undefined value.
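  // (Without the freeze, the compare and the returned value could observe
  // two different values for the same undef numerator.)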
5009   if (!isSigned && llvm::isAllOnesOrAllOnesSplat(N1, /*AllowUndefs*/ false) &&
5010       CCVT.isVector() == VT.isVector()) {
5011     SDValue F0 = DAG.getFreeze(N0);
5012     SDValue EqualsNeg1 = DAG.getSetCC(DL, CCVT, F0, N1, ISD::SETEQ);
5013     return DAG.getSelect(DL, VT, EqualsNeg1, DAG.getConstant(0, DL, VT), F0);
5014   }
5015 
5016   if (SDValue V = simplifyDivRem(N, DAG))
5017     return V;
5018 
5019   if (SDValue NewSel = foldBinOpIntoSelect(N))
5020     return NewSel;
5021 
5022   if (isSigned) {
5023     // If we know the sign bits of both operands are zero, strength reduce to a
5024     // urem instead.  Handles (X & 0x0FFFFFFF) %s 16 -> X&15
5025     if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
5026       return DAG.getNode(ISD::UREM, DL, VT, N0, N1);
5027   } else {
5028     if (DAG.isKnownToBeAPowerOfTwo(N1)) {
5029       // fold (urem x, pow2) -> (and x, pow2-1)
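      // e.g. (urem x, 8) == (and x, 7): the remainder modulo a power of two
      // is just the low log2 bits of the dividend.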
5030       SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
5031       SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, NegOne);
5032       AddToWorklist(Add.getNode());
5033       return DAG.getNode(ISD::AND, DL, VT, N0, Add);
5034     }
5035     // fold (urem x, (shl pow2, y)) -> (and x, (add (shl pow2, y), -1))
5036     // fold (urem x, (lshr pow2, y)) -> (and x, (add (lshr pow2, y), -1))
5037     // TODO: We should sink the following into isKnownToBePowerOfTwo
5038     // using a OrZero parameter analogous to our handling in ValueTracking.
5039     if ((N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) &&
5040         DAG.isKnownToBeAPowerOfTwo(N1.getOperand(0))) {
5041       SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
5042       SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, NegOne);
5043       AddToWorklist(Add.getNode());
5044       return DAG.getNode(ISD::AND, DL, VT, N0, Add);
5045     }
5046   }
5047 
5048   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
5049 
5050   // If X/C can be simplified by the division-by-constant logic, lower
5051   // X%C to the equivalent of X-X/C*C.
5052   // Reuse the SDIVLike/UDIVLike combines - to avoid mangling nodes, the
5053   // speculative DIV must not cause a DIVREM conversion.  We guard against this
5054   // by skipping the simplification if isIntDivCheap().  When div is not cheap,
5055   // combine will not return a DIVREM.  Regardless, checking cheapness here
5056   // makes sense since the simplification results in fatter code.
5057   if (DAG.isKnownNeverZero(N1) && !TLI.isIntDivCheap(VT, Attr)) {
5058     if (isSigned) {
5059       // check if we can build faster implementation for srem
5060       if (SDValue OptimizedRem = buildOptimizedSREM(N0, N1, N))
5061         return OptimizedRem;
5062     }
5063 
5064     SDValue OptimizedDiv =
5065         isSigned ? visitSDIVLike(N0, N1, N) : visitUDIVLike(N0, N1, N);
5066     if (OptimizedDiv.getNode() && OptimizedDiv.getNode() != N) {
5067       // If the equivalent Div node also exists, update its users.
5068       unsigned DivOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
5069       if (SDNode *DivNode = DAG.getNodeIfExists(DivOpcode, N->getVTList(),
5070                                                 { N0, N1 }))
5071         CombineTo(DivNode, OptimizedDiv);
5072       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, OptimizedDiv, N1);
5073       SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
5074       AddToWorklist(OptimizedDiv.getNode());
5075       AddToWorklist(Mul.getNode());
5076       return Sub;
5077     }
5078   }
5079 
5080   // sdiv, srem -> sdivrem / udiv, urem -> udivrem
5081   if (SDValue DivRem = useDivRem(N))
5082     return DivRem.getValue(1);
5083 
5084   return SDValue();
5085 }
5086 
5087 SDValue DAGCombiner::visitMULHS(SDNode *N) {
5088   SDValue N0 = N->getOperand(0);
5089   SDValue N1 = N->getOperand(1);
5090   EVT VT = N->getValueType(0);
5091   SDLoc DL(N);
5092 
5093   // fold (mulhs c1, c2)
5094   if (SDValue C = DAG.FoldConstantArithmetic(ISD::MULHS, DL, VT, {N0, N1}))
5095     return C;
5096 
5097   // canonicalize constant to RHS.
5098   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
5099       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
5100     return DAG.getNode(ISD::MULHS, DL, N->getVTList(), N1, N0);
5101 
5102   if (VT.isVector()) {
5103     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5104       return FoldedVOp;
5105 
5106     // fold (mulhs x, 0) -> 0
5107     // Do not return N1, because it may contain undef elements.
5108     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
5109       return DAG.getConstant(0, DL, VT);
5110   }
5111 
5112   // fold (mulhs x, 0) -> 0
5113   if (isNullConstant(N1))
5114     return N1;
5115 
5116   // fold (mulhs x, 1) -> (sra x, size(x)-1)
5117   if (isOneConstant(N1))
5118     return DAG.getNode(ISD::SRA, DL, N0.getValueType(), N0,
5119                        DAG.getConstant(N0.getScalarValueSizeInBits() - 1, DL,
5120                                        getShiftAmountTy(N0.getValueType())));
5121 
5122   // fold (mulhs x, undef) -> 0
5123   if (N0.isUndef() || N1.isUndef())
5124     return DAG.getConstant(0, DL, VT);
5125 
5126   // If the type twice as wide is legal, transform the mulhs to a wider multiply
5127   // plus a shift.
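  // e.g. mulhs i16 a, b when i32 MUL is legal:
  //   trunc_i16(((sext_i32 a) * (sext_i32 b)) >>u 16).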
5128   if (!TLI.isOperationLegalOrCustom(ISD::MULHS, VT) && VT.isSimple() &&
5129       !VT.isVector()) {
5130     MVT Simple = VT.getSimpleVT();
5131     unsigned SimpleSize = Simple.getSizeInBits();
5132     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
5133     if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
5134       N0 = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N0);
5135       N1 = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N1);
5136       N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
5137       N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
5138             DAG.getConstant(SimpleSize, DL,
5139                             getShiftAmountTy(N1.getValueType())));
5140       return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
5141     }
5142   }
5143 
5144   return SDValue();
5145 }
5146 
5147 SDValue DAGCombiner::visitMULHU(SDNode *N) {
5148   SDValue N0 = N->getOperand(0);
5149   SDValue N1 = N->getOperand(1);
5150   EVT VT = N->getValueType(0);
5151   SDLoc DL(N);
5152 
5153   // fold (mulhu c1, c2)
5154   if (SDValue C = DAG.FoldConstantArithmetic(ISD::MULHU, DL, VT, {N0, N1}))
5155     return C;
5156 
5157   // canonicalize constant to RHS.
5158   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
5159       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
5160     return DAG.getNode(ISD::MULHU, DL, N->getVTList(), N1, N0);
5161 
5162   if (VT.isVector()) {
5163     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5164       return FoldedVOp;
5165 
5166     // fold (mulhu x, 0) -> 0
5167     // Do not return N1, because it may contain undef elements.
5168     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
5169       return DAG.getConstant(0, DL, VT);
5170   }
5171 
5172   // fold (mulhu x, 0) -> 0
5173   if (isNullConstant(N1))
5174     return N1;
5175 
5176   // fold (mulhu x, 1) -> 0
5177   if (isOneConstant(N1))
5178     return DAG.getConstant(0, DL, N0.getValueType());
5179 
5180   // fold (mulhu x, undef) -> 0
5181   if (N0.isUndef() || N1.isUndef())
5182     return DAG.getConstant(0, DL, VT);
5183 
5184   // fold (mulhu x, (1 << c)) -> x >> (bitwidth - c)
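  // e.g. for i8, mulhu(x, 16) is the high byte of the i16 product x * 16,
  // i.e. (x << 4) >> 8 == x >> 4 == x >> (8 - 4).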
  if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
      hasOperation(ISD::SRL, VT)) {
    if (SDValue LogBase2 = BuildLogBase2(N1, DL)) {
      unsigned NumEltBits = VT.getScalarSizeInBits();
      SDValue SRLAmt = DAG.getNode(
          ISD::SUB, DL, VT, DAG.getConstant(NumEltBits, DL, VT), LogBase2);
      EVT ShiftVT = getShiftAmountTy(N0.getValueType());
      SDValue Trunc = DAG.getZExtOrTrunc(SRLAmt, DL, ShiftVT);
      return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
    }
  }

  // If the type twice as wide is legal, transform the mulhu to a wider multiply
  // plus a shift.
  if (!TLI.isOperationLegalOrCustom(ISD::MULHU, VT) && VT.isSimple() &&
      !VT.isVector()) {
    MVT Simple = VT.getSimpleVT();
    unsigned SimpleSize = Simple.getSizeInBits();
    EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
    if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
      N0 = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N0);
      N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N1);
      N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
      N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
            DAG.getConstant(SimpleSize, DL,
                            getShiftAmountTy(N1.getValueType())));
      return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
    }
  }

  // Simplify the operands using demanded-bits information.
  // We don't have demanded bits support for MULHU so this just enables constant
  // folding based on known bits.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  return SDValue();
}

SDValue DAGCombiner::visitAVG(SDNode *N) {
  unsigned Opcode = N->getOpcode();
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  // fold (avg c1, c2)
  if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS.
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(Opcode, DL, N->getVTList(), N1, N0);

  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

    // fold (avgfloor x, 0) -> x >> 1
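    // e.g. avgfloors(-7, 0) = floor(-7 / 2) = -4, which is -7 >>s 1; the
    // unsigned variant uses a logical shift instead.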
    if (ISD::isConstantSplatVectorAllZeros(N1.getNode())) {
      if (Opcode == ISD::AVGFLOORS)
        return DAG.getNode(ISD::SRA, DL, VT, N0, DAG.getConstant(1, DL, VT));
      if (Opcode == ISD::AVGFLOORU)
        return DAG.getNode(ISD::SRL, DL, VT, N0, DAG.getConstant(1, DL, VT));
    }
  }

  // fold (avg x, undef) -> x
  if (N0.isUndef())
    return N1;
  if (N1.isUndef())
    return N0;

  // Fold (avg x, x) --> x
  if (N0 == N1 && Level >= AfterLegalizeTypes)
    return N0;

  // TODO: If we use avg for scalars anywhere, we can add (avgfl x, 0) -> x >> 1

  return SDValue();
}

SDValue DAGCombiner::visitABD(SDNode *N) {
  unsigned Opcode = N->getOpcode();
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  // fold (abd c1, c2)
  if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS.
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(Opcode, DL, N->getVTList(), N1, N0);

  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

    // fold (abds x, 0) -> abs x
    // fold (abdu x, 0) -> x
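    // e.g. abds(-5, 0) = |-5 - 0| = 5 = abs(-5), and abdu(x, 0) = x because
    // x - 0 never wraps for unsigned x.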
    if (ISD::isConstantSplatVectorAllZeros(N1.getNode())) {
      if (Opcode == ISD::ABDS)
        return DAG.getNode(ISD::ABS, DL, VT, N0);
      if (Opcode == ISD::ABDU)
        return N0;
    }
  }

  // fold (abd x, undef) -> 0
  if (N0.isUndef() || N1.isUndef())
    return DAG.getConstant(0, DL, VT);

  // fold (abds x, y) -> (abdu x, y) iff both args are known positive
  if (Opcode == ISD::ABDS && hasOperation(ISD::ABDU, VT) &&
      DAG.SignBitIsZero(N0) && DAG.SignBitIsZero(N1))
    return DAG.getNode(ISD::ABDU, DL, VT, N1, N0);

  return SDValue();
}

/// Perform optimizations common to nodes that compute two values. LoOp and HiOp
/// give the opcodes for the two computations that are being performed. Returns
/// the simplified value, or a null SDValue if no simplification was made.
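/// For example, if only the low result of an SMUL_LOHI is used and a plain
/// MUL is legal, the whole node is replaced by a single MUL.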
SDValue DAGCombiner::SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
                                                unsigned HiOp) {
  // If the high half is not needed, just compute the low half.
  bool HiExists = N->hasAnyUseOfValue(1);
  if (!HiExists && (!LegalOperations ||
                    TLI.isOperationLegalOrCustom(LoOp, N->getValueType(0)))) {
    SDValue Res = DAG.getNode(LoOp, SDLoc(N), N->getValueType(0), N->ops());
    return CombineTo(N, Res, Res);
  }

  // If the low half is not needed, just compute the high half.
  bool LoExists = N->hasAnyUseOfValue(0);
  if (!LoExists && (!LegalOperations ||
                    TLI.isOperationLegalOrCustom(HiOp, N->getValueType(1)))) {
    SDValue Res = DAG.getNode(HiOp, SDLoc(N), N->getValueType(1), N->ops());
    return CombineTo(N, Res, Res);
  }

  // If both halves are used, return as it is.
  if (LoExists && HiExists)
    return SDValue();

  // If the two computed results can be simplified separately, separate them.
  if (LoExists) {
    SDValue Lo = DAG.getNode(LoOp, SDLoc(N), N->getValueType(0), N->ops());
    AddToWorklist(Lo.getNode());
    SDValue LoOpt = combine(Lo.getNode());
    if (LoOpt.getNode() && LoOpt.getNode() != Lo.getNode() &&
        (!LegalOperations ||
         TLI.isOperationLegalOrCustom(LoOpt.getOpcode(), LoOpt.getValueType())))
      return CombineTo(N, LoOpt, LoOpt);
  }

  if (HiExists) {
    SDValue Hi = DAG.getNode(HiOp, SDLoc(N), N->getValueType(1), N->ops());
    AddToWorklist(Hi.getNode());
    SDValue HiOpt = combine(Hi.getNode());
    if (HiOpt.getNode() && HiOpt != Hi &&
        (!LegalOperations ||
         TLI.isOperationLegalOrCustom(HiOpt.getOpcode(), HiOpt.getValueType())))
      return CombineTo(N, HiOpt, HiOpt);
  }

  return SDValue();
}

SDValue DAGCombiner::visitSMUL_LOHI(SDNode *N) {
  if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHS))
    return Res;

  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  // Constant fold.
  if (isa<ConstantSDNode>(N0) && isa<ConstantSDNode>(N1))
    return DAG.getNode(ISD::SMUL_LOHI, DL, N->getVTList(), N0, N1);

  // canonicalize constant to RHS (vector doesn't have to splat)
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(ISD::SMUL_LOHI, DL, N->getVTList(), N1, N0);

  // If the type twice as wide is legal, transform the smul_lohi to a wider
  // multiply plus a shift.
  if (VT.isSimple() && !VT.isVector()) {
    MVT Simple = VT.getSimpleVT();
    unsigned SimpleSize = Simple.getSizeInBits();
    EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
    if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
      SDValue Lo = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N0);
      SDValue Hi = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N1);
      Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
      // Compute the high half of the result.
      Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
            DAG.getConstant(SimpleSize, DL,
                            getShiftAmountTy(Lo.getValueType())));
      Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
      // Compute the low half of the result.
      Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
      return CombineTo(N, Lo, Hi);
    }
  }

  return SDValue();
}

SDValue DAGCombiner::visitUMUL_LOHI(SDNode *N) {
  if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHU))
    return Res;

  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  // Constant fold.
  if (isa<ConstantSDNode>(N0) && isa<ConstantSDNode>(N1))
    return DAG.getNode(ISD::UMUL_LOHI, DL, N->getVTList(), N0, N1);

  // canonicalize constant to RHS (vector doesn't have to splat)
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(ISD::UMUL_LOHI, DL, N->getVTList(), N1, N0);

  // (umul_lohi N0, 0) -> (0, 0)
  if (isNullConstant(N1)) {
    SDValue Zero = DAG.getConstant(0, DL, VT);
    return CombineTo(N, Zero, Zero);
  }

  // (umul_lohi N0, 1) -> (N0, 0)
  if (isOneConstant(N1)) {
    SDValue Zero = DAG.getConstant(0, DL, VT);
    return CombineTo(N, N0, Zero);
  }

  // If the type twice as wide is legal, transform the umul_lohi to a wider
  // multiply plus a shift.
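  // e.g. for i16 with legal i32 MUL, the full product is formed once and both
  // halves are extracted from it:
  //   lo = trunc i16 (zext i32 x * zext i32 y)
  //   hi = trunc i16 ((zext i32 x * zext i32 y) srl 16)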
  if (VT.isSimple() && !VT.isVector()) {
    MVT Simple = VT.getSimpleVT();
    unsigned SimpleSize = Simple.getSizeInBits();
    EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
    if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
      SDValue Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N0);
      SDValue Hi = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N1);
      Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
      // Compute the high half of the result.
      Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
            DAG.getConstant(SimpleSize, DL,
                            getShiftAmountTy(Lo.getValueType())));
      Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
      // Compute the low half of the result.
      Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
      return CombineTo(N, Lo, Hi);
    }
  }

  return SDValue();
}

SDValue DAGCombiner::visitMULO(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  bool IsSigned = (ISD::SMULO == N->getOpcode());

  EVT CarryVT = N->getValueType(1);
  SDLoc DL(N);

  ConstantSDNode *N0C = isConstOrConstSplat(N0);
  ConstantSDNode *N1C = isConstOrConstSplat(N1);

  // fold operation with constant operands.
  // TODO: Move this to FoldConstantArithmetic when it supports nodes with
  // multiple results.
  if (N0C && N1C) {
    bool Overflow;
    APInt Result =
        IsSigned ? N0C->getAPIntValue().smul_ov(N1C->getAPIntValue(), Overflow)
                 : N0C->getAPIntValue().umul_ov(N1C->getAPIntValue(), Overflow);
    return CombineTo(N, DAG.getConstant(Result, DL, VT),
                     DAG.getBoolConstant(Overflow, DL, CarryVT, CarryVT));
  }

  // canonicalize constant to RHS.
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(N->getOpcode(), DL, N->getVTList(), N1, N0);

  // fold (mulo x, 0) -> 0 + no carry out
  if (isNullOrNullSplat(N1))
    return CombineTo(N, DAG.getConstant(0, DL, VT),
                     DAG.getConstant(0, DL, CarryVT));

  // (mulo x, 2) -> (addo x, x)
  // FIXME: This needs a freeze.
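  // x * 2 and x + x compute the same value and overflow under the same
  // conditions, so the multiply can be rewritten as the usually cheaper
  // add-with-overflow. The signed case excludes types of 2 bits or fewer,
  // where the constant 2 is not even representable as a signed value.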
  if (N1C && N1C->getAPIntValue() == 2 &&
      (!IsSigned || VT.getScalarSizeInBits() > 2))
    return DAG.getNode(IsSigned ? ISD::SADDO : ISD::UADDO, DL,
                       N->getVTList(), N0, N0);

  // A 1-bit SMULO overflows if both inputs are 1.
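  // In 1-bit signed arithmetic the only values are 0 and -1; the product
  // (-1) * (-1) == +1 is not representable, so overflow occurs exactly when
  // both inputs are non-zero, i.e. when (and x, y) != 0.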
  if (IsSigned && VT.getScalarSizeInBits() == 1) {
    SDValue And = DAG.getNode(ISD::AND, DL, VT, N0, N1);
    SDValue Cmp = DAG.getSetCC(DL, CarryVT, And,
                               DAG.getConstant(0, DL, VT), ISD::SETNE);
    return CombineTo(N, And, Cmp);
  }

  // If it cannot overflow, transform into a mul.
  if (DAG.willNotOverflowMul(IsSigned, N0, N1))
    return CombineTo(N, DAG.getNode(ISD::MUL, DL, VT, N0, N1),
                     DAG.getConstant(0, DL, CarryVT));
  return SDValue();
}

// Function to calculate whether the Min/Max pair of SDNodes (potentially
// swapped around) make a signed saturate pattern, clamping to between a signed
// saturate of -2^(BW-1) and 2^(BW-1)-1, or an unsigned saturate of 0 and
// 2^BW-1. Returns the node being clamped and the bitwidth of the clamp in BW.
// Should work with both SMIN/SMAX nodes and a setcc/select combo. The operands
// are the same as SimplifySelectCC: N0 < N1 ? N2 : N3.
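// e.g. smin(smax(x, -128), 127) clamps x to the signed i8 range, so this
// returns x with BW = 8; smax(smin(x, 255), 0) is the unsigned i8 form.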
static SDValue isSaturatingMinMax(SDValue N0, SDValue N1, SDValue N2,
                                  SDValue N3, ISD::CondCode CC, unsigned &BW,
                                  bool &Unsigned, SelectionDAG &DAG) {
  auto isSignedMinMax = [&](SDValue N0, SDValue N1, SDValue N2, SDValue N3,
                            ISD::CondCode CC) {
    // The compare and select operand should be the same or the select operands
    // should be truncated versions of the comparison.
    if (N0 != N2 && (N2.getOpcode() != ISD::TRUNCATE || N0 != N2.getOperand(0)))
      return 0;
    // The constants need to be the same or a truncated version of each other.
    ConstantSDNode *N1C = isConstOrConstSplat(peekThroughTruncates(N1));
    ConstantSDNode *N3C = isConstOrConstSplat(peekThroughTruncates(N3));
    if (!N1C || !N3C)
      return 0;
    const APInt &C1 = N1C->getAPIntValue().trunc(N1.getScalarValueSizeInBits());
    const APInt &C2 = N3C->getAPIntValue().trunc(N3.getScalarValueSizeInBits());
    if (C1.getBitWidth() < C2.getBitWidth() || C1 != C2.sext(C1.getBitWidth()))
      return 0;
    return CC == ISD::SETLT ? ISD::SMIN : (CC == ISD::SETGT ? ISD::SMAX : 0);
  };

  // Check the initial value is an SMIN/SMAX equivalent.
  unsigned Opcode0 = isSignedMinMax(N0, N1, N2, N3, CC);
  if (!Opcode0)
    return SDValue();

  // We may only need one range check if the fptosi can never produce the
  // upper value.
  if (N0.getOpcode() == ISD::FP_TO_SINT && Opcode0 == ISD::SMAX) {
    if (isNullOrNullSplat(N3)) {
      EVT IntVT = N0.getValueType().getScalarType();
      EVT FPVT = N0.getOperand(0).getValueType().getScalarType();
      if (FPVT.isSimple()) {
        Type *InputTy = FPVT.getTypeForEVT(*DAG.getContext());
        const fltSemantics &Semantics = InputTy->getFltSemantics();
        uint32_t MinBitWidth =
          APFloatBase::semanticsIntSizeInBits(Semantics, /*isSigned*/ true);
        if (IntVT.getSizeInBits() >= MinBitWidth) {
          Unsigned = true;
          BW = PowerOf2Ceil(MinBitWidth);
          return N0;
        }
      }
    }
  }

  SDValue N00, N01, N02, N03;
  ISD::CondCode N0CC;
  switch (N0.getOpcode()) {
  case ISD::SMIN:
  case ISD::SMAX:
    N00 = N02 = N0.getOperand(0);
    N01 = N03 = N0.getOperand(1);
    N0CC = N0.getOpcode() == ISD::SMIN ? ISD::SETLT : ISD::SETGT;
    break;
  case ISD::SELECT_CC:
    N00 = N0.getOperand(0);
    N01 = N0.getOperand(1);
    N02 = N0.getOperand(2);
    N03 = N0.getOperand(3);
    N0CC = cast<CondCodeSDNode>(N0.getOperand(4))->get();
    break;
  case ISD::SELECT:
  case ISD::VSELECT:
    if (N0.getOperand(0).getOpcode() != ISD::SETCC)
      return SDValue();
    N00 = N0.getOperand(0).getOperand(0);
    N01 = N0.getOperand(0).getOperand(1);
    N02 = N0.getOperand(1);
    N03 = N0.getOperand(2);
    N0CC = cast<CondCodeSDNode>(N0.getOperand(0).getOperand(2))->get();
    break;
  default:
    return SDValue();
  }

  unsigned Opcode1 = isSignedMinMax(N00, N01, N02, N03, N0CC);
  if (!Opcode1 || Opcode0 == Opcode1)
    return SDValue();

  ConstantSDNode *MinCOp = isConstOrConstSplat(Opcode0 == ISD::SMIN ? N1 : N01);
  ConstantSDNode *MaxCOp = isConstOrConstSplat(Opcode0 == ISD::SMIN ? N01 : N1);
  if (!MinCOp || !MaxCOp || MinCOp->getValueType(0) != MaxCOp->getValueType(0))
    return SDValue();

  const APInt &MinC = MinCOp->getAPIntValue();
  const APInt &MaxC = MaxCOp->getAPIntValue();
  APInt MinCPlus1 = MinC + 1;
  if (-MaxC == MinCPlus1 && MinCPlus1.isPowerOf2()) {
    BW = MinCPlus1.exactLogBase2() + 1;
    Unsigned = false;
    return N02;
  }

  if (MaxC == 0 && MinCPlus1.isPowerOf2()) {
    BW = MinCPlus1.exactLogBase2();
    Unsigned = true;
    return N02;
  }

  return SDValue();
}

static SDValue PerformMinMaxFpToSatCombine(SDValue N0, SDValue N1, SDValue N2,
                                           SDValue N3, ISD::CondCode CC,
                                           SelectionDAG &DAG) {
  unsigned BW;
  bool Unsigned;
  SDValue Fp = isSaturatingMinMax(N0, N1, N2, N3, CC, BW, Unsigned, DAG);
  if (!Fp || Fp.getOpcode() != ISD::FP_TO_SINT)
    return SDValue();
  EVT FPVT = Fp.getOperand(0).getValueType();
  EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), BW);
  if (FPVT.isVector())
    NewVT = EVT::getVectorVT(*DAG.getContext(), NewVT,
                             FPVT.getVectorElementCount());
  unsigned NewOpc = Unsigned ? ISD::FP_TO_UINT_SAT : ISD::FP_TO_SINT_SAT;
  if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(NewOpc, FPVT, NewVT))
    return SDValue();
  SDLoc DL(Fp);
  SDValue Sat = DAG.getNode(NewOpc, DL, NewVT, Fp.getOperand(0),
                            DAG.getValueType(NewVT.getScalarType()));
  return DAG.getExtOrTrunc(!Unsigned, Sat, DL, N2->getValueType(0));
}

static SDValue PerformUMinFpToSatCombine(SDValue N0, SDValue N1, SDValue N2,
                                         SDValue N3, ISD::CondCode CC,
                                         SelectionDAG &DAG) {
  // We are looking for UMIN(FPTOUI(X), (2^n)-1), which may have come via a
  // select/vselect/select_cc. The two operand pairs for the select (N2/N3) may
  // be truncated versions of the setcc (N0/N1).
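  // e.g. umin(fptoui(f), 255): 255 + 1 = 256 is a power of two, so BW = 8 and
  // the whole pattern becomes an fp_to_uint_sat of f saturating at i8.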
  if ((N0 != N2 &&
       (N2.getOpcode() != ISD::TRUNCATE || N0 != N2.getOperand(0))) ||
      N0.getOpcode() != ISD::FP_TO_UINT || CC != ISD::SETULT)
    return SDValue();
  ConstantSDNode *N1C = isConstOrConstSplat(N1);
  ConstantSDNode *N3C = isConstOrConstSplat(N3);
  if (!N1C || !N3C)
    return SDValue();
  const APInt &C1 = N1C->getAPIntValue();
  const APInt &C3 = N3C->getAPIntValue();
  if (!(C1 + 1).isPowerOf2() || C1.getBitWidth() < C3.getBitWidth() ||
      C1 != C3.zext(C1.getBitWidth()))
    return SDValue();

  unsigned BW = (C1 + 1).exactLogBase2();
  EVT FPVT = N0.getOperand(0).getValueType();
  EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), BW);
  if (FPVT.isVector())
    NewVT = EVT::getVectorVT(*DAG.getContext(), NewVT,
                             FPVT.getVectorElementCount());
  if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(ISD::FP_TO_UINT_SAT,
                                                        FPVT, NewVT))
    return SDValue();

  SDValue Sat =
      DAG.getNode(ISD::FP_TO_UINT_SAT, SDLoc(N0), NewVT, N0.getOperand(0),
                  DAG.getValueType(NewVT.getScalarType()));
  return DAG.getZExtOrTrunc(Sat, SDLoc(N0), N3.getValueType());
}

SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  unsigned Opcode = N->getOpcode();
  SDLoc DL(N);

  // fold operation with constant operands.
  if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
    return C;

  // If the operands are the same, this is a no-op.
  if (N0 == N1)
    return N0;

  // canonicalize constant to RHS
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(Opcode, DL, VT, N1, N0);

  // fold vector ops
  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

  // If the sign bits are zero, flip between UMIN/UMAX and SMIN/SMAX.
  // Only do this if the current op isn't legal and the flipped is.
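  // When both operands are known non-negative, signed and unsigned ordering
  // agree, e.g. smax(x, y) == umax(x, y) whenever neither sign bit can be set.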
  if (!TLI.isOperationLegal(Opcode, VT) &&
      (N0.isUndef() || DAG.SignBitIsZero(N0)) &&
      (N1.isUndef() || DAG.SignBitIsZero(N1))) {
    unsigned AltOpcode;
    switch (Opcode) {
    case ISD::SMIN: AltOpcode = ISD::UMIN; break;
    case ISD::SMAX: AltOpcode = ISD::UMAX; break;
    case ISD::UMIN: AltOpcode = ISD::SMIN; break;
    case ISD::UMAX: AltOpcode = ISD::SMAX; break;
    default: llvm_unreachable("Unknown MINMAX opcode");
    }
    if (TLI.isOperationLegal(AltOpcode, VT))
      return DAG.getNode(AltOpcode, DL, VT, N0, N1);
  }

  if (Opcode == ISD::SMIN || Opcode == ISD::SMAX)
    if (SDValue S = PerformMinMaxFpToSatCombine(
            N0, N1, N0, N1, Opcode == ISD::SMIN ? ISD::SETLT : ISD::SETGT, DAG))
      return S;
  if (Opcode == ISD::UMIN)
    if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N0, N1, ISD::SETULT, DAG))
      return S;

  // Fold min/max(vecreduce(x), vecreduce(y)) -> vecreduce(min/max(x, y))
  auto ReductionOpcode = [](unsigned Opcode) {
    switch (Opcode) {
    case ISD::SMIN:
      return ISD::VECREDUCE_SMIN;
    case ISD::SMAX:
      return ISD::VECREDUCE_SMAX;
    case ISD::UMIN:
      return ISD::VECREDUCE_UMIN;
    case ISD::UMAX:
      return ISD::VECREDUCE_UMAX;
    default:
      llvm_unreachable("Unexpected opcode");
    }
  };
  if (SDValue SD = reassociateReduction(ReductionOpcode(Opcode), Opcode,
                                        SDLoc(N), VT, N0, N1))
    return SD;

  // Simplify the operands using demanded-bits information.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  return SDValue();
}

/// If this is a bitwise logic instruction and both operands have the same
/// opcode, try to sink the other opcode after the logic instruction.
SDValue DAGCombiner::hoistLogicOpWithSameOpcodeHands(SDNode *N) {
  SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  unsigned LogicOpcode = N->getOpcode();
  unsigned HandOpcode = N0.getOpcode();
  assert(ISD::isBitwiseLogicOp(LogicOpcode) && "Expected logic opcode");
  assert(HandOpcode == N1.getOpcode() && "Bad input!");

  // Bail early if none of these transforms apply.
  if (N0.getNumOperands() == 0)
    return SDValue();

  // FIXME: We should check number of uses of the operands to not increase
  //        the instruction count for all transforms.

  // Handle size-changing casts (or sign_extend_inreg).
  SDValue X = N0.getOperand(0);
  SDValue Y = N1.getOperand(0);
  EVT XVT = X.getValueType();
  SDLoc DL(N);
  if (ISD::isExtOpcode(HandOpcode) || ISD::isExtVecInRegOpcode(HandOpcode) ||
      (HandOpcode == ISD::SIGN_EXTEND_INREG &&
       N0.getOperand(1) == N1.getOperand(1))) {
    // If both operands have other uses, this transform would create extra
    // instructions without eliminating anything.
    if (!N0.hasOneUse() && !N1.hasOneUse())
      return SDValue();
    // We need matching integer source types.
    if (XVT != Y.getValueType())
      return SDValue();
    // Don't create an illegal op during or after legalization. Don't ever
    // create an unsupported vector op.
    if ((VT.isVector() || LegalOperations) &&
        !TLI.isOperationLegalOrCustom(LogicOpcode, XVT))
      return SDValue();
    // Avoid infinite looping with PromoteIntBinOp.
    // TODO: Should we apply desirable/legal constraints to all opcodes?
    if ((HandOpcode == ISD::ANY_EXTEND ||
         HandOpcode == ISD::ANY_EXTEND_VECTOR_INREG) &&
        LegalTypes && !TLI.isTypeDesirableForOp(LogicOpcode, XVT))
      return SDValue();
    // logic_op (hand_op X), (hand_op Y) --> hand_op (logic_op X, Y)
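    // e.g. (and (zext i8 x to i32), (zext i8 y to i32))
    //        --> (zext (and i8 x, y) to i32)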
    SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
    if (HandOpcode == ISD::SIGN_EXTEND_INREG)
      return DAG.getNode(HandOpcode, DL, VT, Logic, N0.getOperand(1));
    return DAG.getNode(HandOpcode, DL, VT, Logic);
  }

  // logic_op (truncate x), (truncate y) --> truncate (logic_op x, y)
  if (HandOpcode == ISD::TRUNCATE) {
    // If both operands have other uses, this transform would create extra
    // instructions without eliminating anything.
    if (!N0.hasOneUse() && !N1.hasOneUse())
      return SDValue();
    // We need matching source types.
    if (XVT != Y.getValueType())
      return SDValue();
    // Don't create an illegal op during or after legalization.
    if (LegalOperations && !TLI.isOperationLegal(LogicOpcode, XVT))
      return SDValue();
    // Be extra careful sinking truncate. If it's free, there's no benefit in
    // widening a binop. Also, don't create a logic op on an illegal type.
    if (TLI.isZExtFree(VT, XVT) && TLI.isTruncateFree(XVT, VT))
      return SDValue();
    if (!TLI.isTypeLegal(XVT))
      return SDValue();
    SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
    return DAG.getNode(HandOpcode, DL, VT, Logic);
  }

  // For binops SHL/SRL/SRA/AND:
  //   logic_op (OP x, z), (OP y, z) --> OP (logic_op x, y), z
  if ((HandOpcode == ISD::SHL || HandOpcode == ISD::SRL ||
       HandOpcode == ISD::SRA || HandOpcode == ISD::AND) &&
      N0.getOperand(1) == N1.getOperand(1)) {
    // If either operand has other uses, this transform is not an improvement.
    if (!N0.hasOneUse() || !N1.hasOneUse())
      return SDValue();
    SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
    return DAG.getNode(HandOpcode, DL, VT, Logic, N0.getOperand(1));
  }

  // Unary ops: logic_op (bswap x), (bswap y) --> bswap (logic_op x, y)
  if (HandOpcode == ISD::BSWAP) {
    // If either operand has other uses, this transform is not an improvement.
    if (!N0.hasOneUse() || !N1.hasOneUse())
      return SDValue();
    SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
    return DAG.getNode(HandOpcode, DL, VT, Logic);
  }

  // For funnel shifts FSHL/FSHR:
  // logic_op (OP x, x1, s), (OP y, y1, s) -->
  // --> OP (logic_op x, y), (logic_op, x1, y1), s
  if ((HandOpcode == ISD::FSHL || HandOpcode == ISD::FSHR) &&
      N0.getOperand(2) == N1.getOperand(2)) {
    if (!N0.hasOneUse() || !N1.hasOneUse())
      return SDValue();
    SDValue X1 = N0.getOperand(1);
    SDValue Y1 = N1.getOperand(1);
    SDValue S = N0.getOperand(2);
    SDValue Logic0 = DAG.getNode(LogicOpcode, DL, VT, X, Y);
    SDValue Logic1 = DAG.getNode(LogicOpcode, DL, VT, X1, Y1);
    return DAG.getNode(HandOpcode, DL, VT, Logic0, Logic1, S);
  }

  // Simplify xor/and/or (bitcast(A), bitcast(B)) -> bitcast(op (A,B))
  // Only perform this optimization up until type legalization, before
  // LegalizeVectorOps. LegalizeVectorOps promotes vector operations by
  // adding bitcasts. For example (xor v4i32) is promoted to (v2i64), and
  // we don't want to undo this promotion.
  // We also handle SCALAR_TO_VECTOR because xor/or/and operations are cheaper
  // on scalars.
  if ((HandOpcode == ISD::BITCAST || HandOpcode == ISD::SCALAR_TO_VECTOR) &&
       Level <= AfterLegalizeTypes) {
    // Input types must be integer and the same.
    if (XVT.isInteger() && XVT == Y.getValueType() &&
        !(VT.isVector() && TLI.isTypeLegal(VT) &&
          !XVT.isVector() && !TLI.isTypeLegal(XVT))) {
      SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
      return DAG.getNode(HandOpcode, DL, VT, Logic);
    }
  }

  // Xor/and/or are indifferent to the swizzle operation (shuffle of one value).
  // Simplify xor/and/or (shuff(A), shuff(B)) -> shuff(op (A,B))
  // If both shuffles use the same mask, and both shuffle within a single
  // vector, then it is worthwhile to move the swizzle after the operation.
  // The type-legalizer generates this pattern when loading illegal
  // vector types from memory. In many cases this allows additional shuffle
  // optimizations.
  // There are other cases where moving the shuffle after the xor/and/or
  // is profitable even if shuffles don't perform a swizzle.
  // If both shuffles use the same mask, and both shuffles have the same first
  // or second operand, then it might still be profitable to move the shuffle
  // after the xor/and/or operation.
  if (HandOpcode == ISD::VECTOR_SHUFFLE && Level < AfterLegalizeDAG) {
    auto *SVN0 = cast<ShuffleVectorSDNode>(N0);
    auto *SVN1 = cast<ShuffleVectorSDNode>(N1);
    assert(X.getValueType() == Y.getValueType() &&
           "Inputs to shuffles are not the same type");

    // Check that both shuffles use the same mask. The masks are known to be of
    // the same length because the result vector type is the same.
    // Check also that shuffles have only one use to avoid introducing extra
    // instructions.
    if (!SVN0->hasOneUse() || !SVN1->hasOneUse() ||
        !SVN0->getMask().equals(SVN1->getMask()))
      return SDValue();

    // Don't try to fold this node if it requires introducing a
    // build vector of all zeros that might be illegal at this stage.
    SDValue ShOp = N0.getOperand(1);
    if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
      ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);

    // (logic_op (shuf (A, C), shuf (B, C))) --> shuf (logic_op (A, B), C)
    if (N0.getOperand(1) == N1.getOperand(1) && ShOp.getNode()) {
      SDValue Logic = DAG.getNode(LogicOpcode, DL, VT,
                                  N0.getOperand(0), N1.getOperand(0));
      return DAG.getVectorShuffle(VT, DL, Logic, ShOp, SVN0->getMask());
    }

    // Don't try to fold this node if it requires introducing a
    // build vector of all zeros that might be illegal at this stage.
    ShOp = N0.getOperand(0);
    if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
      ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);

    // (logic_op (shuf (C, A), shuf (C, B))) --> shuf (C, logic_op (A, B))
    if (N0.getOperand(0) == N1.getOperand(0) && ShOp.getNode()) {
      SDValue Logic = DAG.getNode(LogicOpcode, DL, VT, N0.getOperand(1),
                                  N1.getOperand(1));
      return DAG.getVectorShuffle(VT, DL, ShOp, Logic, SVN0->getMask());
    }
  }

  return SDValue();
}

/// Try to make (and/or setcc (LL, LR), setcc (RL, RR)) more efficient.
SDValue DAGCombiner::foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
                                       const SDLoc &DL) {
  SDValue LL, LR, RL, RR, N0CC, N1CC;
  if (!isSetCCEquivalent(N0, LL, LR, N0CC) ||
      !isSetCCEquivalent(N1, RL, RR, N1CC))
    return SDValue();

  assert(N0.getValueType() == N1.getValueType() &&
         "Unexpected operand types for bitwise logic op");
  assert(LL.getValueType() == LR.getValueType() &&
         RL.getValueType() == RR.getValueType() &&
         "Unexpected operand types for setcc");

  // If we're here post-legalization or the logic op type is not i1, the logic
  // op type must match a setcc result type. Also, all folds require new
  // operations on the left and right operands, so those types must match.
  EVT VT = N0.getValueType();
  EVT OpVT = LL.getValueType();
  if (LegalOperations || VT.getScalarType() != MVT::i1)
    if (VT != getSetCCResultType(OpVT))
      return SDValue();
  if (OpVT != RL.getValueType())
    return SDValue();

  ISD::CondCode CC0 = cast<CondCodeSDNode>(N0CC)->get();
  ISD::CondCode CC1 = cast<CondCodeSDNode>(N1CC)->get();
  bool IsInteger = OpVT.isInteger();
  if (LR == RR && CC0 == CC1 && IsInteger) {
    bool IsZero = isNullOrNullSplat(LR);
    bool IsNeg1 = isAllOnesOrAllOnesSplat(LR);

    // All bits clear?
    bool AndEqZero = IsAnd && CC1 == ISD::SETEQ && IsZero;
    // All sign bits clear?
    bool AndGtNeg1 = IsAnd && CC1 == ISD::SETGT && IsNeg1;
    // Any bits set?
    bool OrNeZero = !IsAnd && CC1 == ISD::SETNE && IsZero;
    // Any sign bits set?
    bool OrLtZero = !IsAnd && CC1 == ISD::SETLT && IsZero;

    // (and (seteq X,  0), (seteq Y,  0)) --> (seteq (or X, Y),  0)
    // (and (setgt X, -1), (setgt Y, -1)) --> (setgt (or X, Y), -1)
    // (or  (setne X,  0), (setne Y,  0)) --> (setne (or X, Y),  0)
    // (or  (setlt X,  0), (setlt Y,  0)) --> (setlt (or X, Y),  0)
    if (AndEqZero || AndGtNeg1 || OrNeZero || OrLtZero) {
      SDValue Or = DAG.getNode(ISD::OR, SDLoc(N0), OpVT, LL, RL);
      AddToWorklist(Or.getNode());
      return DAG.getSetCC(DL, VT, Or, LR, CC1);
    }

    // All bits set?
    bool AndEqNeg1 = IsAnd && CC1 == ISD::SETEQ && IsNeg1;
    // All sign bits set?
    bool AndLtZero = IsAnd && CC1 == ISD::SETLT && IsZero;
    // Any bits clear?
    bool OrNeNeg1 = !IsAnd && CC1 == ISD::SETNE && IsNeg1;
    // Any sign bits clear?
    bool OrGtNeg1 = !IsAnd && CC1 == ISD::SETGT && IsNeg1;

    // (and (seteq X, -1), (seteq Y, -1)) --> (seteq (and X, Y), -1)
    // (and (setlt X,  0), (setlt Y,  0)) --> (setlt (and X, Y),  0)
    // (or  (setne X, -1), (setne Y, -1)) --> (setne (and X, Y), -1)
    // (or  (setgt X, -1), (setgt Y, -1)) --> (setgt (and X, Y), -1)
    if (AndEqNeg1 || AndLtZero || OrNeNeg1 || OrGtNeg1) {
      SDValue And = DAG.getNode(ISD::AND, SDLoc(N0), OpVT, LL, RL);
      AddToWorklist(And.getNode());
      return DAG.getSetCC(DL, VT, And, LR, CC1);
    }
  }

  // TODO: What is the 'or' equivalent of this fold?
  // (and (setne X, 0), (setne X, -1)) --> (setuge (add X, 1), 2)
  if (IsAnd && LL == RL && CC0 == CC1 && OpVT.getScalarSizeInBits() > 1 &&
      IsInteger && CC0 == ISD::SETNE &&
      ((isNullConstant(LR) && isAllOnesConstant(RR)) ||
       (isAllOnesConstant(LR) && isNullConstant(RR)))) {
    SDValue One = DAG.getConstant(1, DL, OpVT);
    SDValue Two = DAG.getConstant(2, DL, OpVT);
    SDValue Add = DAG.getNode(ISD::ADD, SDLoc(N0), OpVT, LL, One);
    AddToWorklist(Add.getNode());
    return DAG.getSetCC(DL, VT, Add, Two, ISD::SETUGE);
  }

  // Try more general transforms if the predicates match and the only user of
  // the compares is the 'and' or 'or'.
  if (IsInteger && TLI.convertSetCCLogicToBitwiseLogic(OpVT) && CC0 == CC1 &&
      N0.hasOneUse() && N1.hasOneUse()) {
    // and (seteq A, B), (seteq C, D) --> seteq (or (xor A, B), (xor C, D)), 0
    // or  (setne A, B), (setne C, D) --> setne (or (xor A, B), (xor C, D)), 0
    if ((IsAnd && CC1 == ISD::SETEQ) || (!IsAnd && CC1 == ISD::SETNE)) {
      SDValue XorL = DAG.getNode(ISD::XOR, SDLoc(N0), OpVT, LL, LR);
      SDValue XorR = DAG.getNode(ISD::XOR, SDLoc(N1), OpVT, RL, RR);
      SDValue Or = DAG.getNode(ISD::OR, DL, OpVT, XorL, XorR);
      SDValue Zero = DAG.getConstant(0, DL, OpVT);
      return DAG.getSetCC(DL, VT, Or, Zero, CC1);
    }

    // Turn compare of constants whose difference is 1 bit into add+and+setcc.
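    // e.g. (x != 5) & (x != 7): the constants differ only in bit 1, so
    // ((x - 5) & ~2) != 0 tests both inequalities with a single compare.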
    if ((IsAnd && CC1 == ISD::SETNE) || (!IsAnd && CC1 == ISD::SETEQ)) {
      // Match a shared variable operand and 2 non-opaque constant operands.
      auto MatchDiffPow2 = [&](ConstantSDNode *C0, ConstantSDNode *C1) {
        // The difference of the constants must be a single bit.
        const APInt &CMax =
            APIntOps::umax(C0->getAPIntValue(), C1->getAPIntValue());
        const APInt &CMin =
            APIntOps::umin(C0->getAPIntValue(), C1->getAPIntValue());
        return !C0->isOpaque() && !C1->isOpaque() && (CMax - CMin).isPowerOf2();
      };
      if (LL == RL && ISD::matchBinaryPredicate(LR, RR, MatchDiffPow2)) {
        // and/or (setcc X, CMax, ne), (setcc X, CMin, ne/eq) -->
        // setcc ((sub X, CMin), ~(CMax - CMin)), 0, ne/eq
        SDValue Max = DAG.getNode(ISD::UMAX, DL, OpVT, LR, RR);
        SDValue Min = DAG.getNode(ISD::UMIN, DL, OpVT, LR, RR);
        SDValue Offset = DAG.getNode(ISD::SUB, DL, OpVT, LL, Min);
        SDValue Diff = DAG.getNode(ISD::SUB, DL, OpVT, Max, Min);
        SDValue Mask = DAG.getNOT(DL, Diff, OpVT);
        SDValue And = DAG.getNode(ISD::AND, DL, OpVT, Offset, Mask);
        SDValue Zero = DAG.getConstant(0, DL, OpVT);
        return DAG.getSetCC(DL, VT, And, Zero, CC0);
      }
    }
  }

  // Canonicalize equivalent operands to LL == RL.
  if (LL == RR && LR == RL) {
    CC1 = ISD::getSetCCSwappedOperands(CC1);
    std::swap(RL, RR);
  }

  // (and (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
  // (or  (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
  if (LL == RL && LR == RR) {
    ISD::CondCode NewCC = IsAnd ? ISD::getSetCCAndOperation(CC0, CC1, OpVT)
                                : ISD::getSetCCOrOperation(CC0, CC1, OpVT);
    if (NewCC != ISD::SETCC_INVALID &&
        (!LegalOperations ||
         (TLI.isCondCodeLegal(NewCC, LL.getSimpleValueType()) &&
          TLI.isOperationLegal(ISD::SETCC, OpVT))))
      return DAG.getSetCC(DL, VT, LL, LR, NewCC);
  }

  return SDValue();
}

static bool arebothOperandsNotSNan(SDValue Operand1, SDValue Operand2,
                                   SelectionDAG &DAG) {
  return DAG.isKnownNeverSNaN(Operand2) && DAG.isKnownNeverSNaN(Operand1);
}

static bool arebothOperandsNotNan(SDValue Operand1, SDValue Operand2,
                                  SelectionDAG &DAG) {
  return DAG.isKnownNeverNaN(Operand2) && DAG.isKnownNeverNaN(Operand1);
}

static unsigned getMinMaxOpcodeForFP(SDValue Operand1, SDValue Operand2,
                                     ISD::CondCode CC, unsigned OrAndOpcode,
                                     SelectionDAG &DAG,
                                     bool isFMAXNUMFMINNUM_IEEE,
                                     bool isFMAXNUMFMINNUM) {
  // The optimization cannot be applied for all the predicates because
  // of the way FMINNUM/FMAXNUM and FMINNUM_IEEE/FMAXNUM_IEEE handle
  // NaNs. For FMINNUM_IEEE/FMAXNUM_IEEE, the optimization cannot be
  // applied at all if one of the operands is a signaling NaN.

  // It is safe to use FMINNUM_IEEE/FMAXNUM_IEEE if all the operands
  // are non NaN values.
  if (((CC == ISD::SETLT || CC == ISD::SETLE) && (OrAndOpcode == ISD::OR)) ||
      ((CC == ISD::SETGT || CC == ISD::SETGE) && (OrAndOpcode == ISD::AND)))
    return arebothOperandsNotNan(Operand1, Operand2, DAG) &&
                   isFMAXNUMFMINNUM_IEEE
               ? ISD::FMINNUM_IEEE
               : ISD::DELETED_NODE;
  else if (((CC == ISD::SETGT || CC == ISD::SETGE) &&
            (OrAndOpcode == ISD::OR)) ||
           ((CC == ISD::SETLT || CC == ISD::SETLE) &&
            (OrAndOpcode == ISD::AND)))
    return arebothOperandsNotNan(Operand1, Operand2, DAG) &&
                   isFMAXNUMFMINNUM_IEEE
               ? ISD::FMAXNUM_IEEE
               : ISD::DELETED_NODE;
  // Both FMINNUM/FMAXNUM and FMINNUM_IEEE/FMAXNUM_IEEE handle quiet
  // NaNs in the same way. But, FMINNUM/FMAXNUM and FMINNUM_IEEE/
  // FMAXNUM_IEEE handle signaling NaNs differently. If we cannot prove
  // that there are not any sNaNs, then the optimization is not valid
  // for FMINNUM_IEEE/FMAXNUM_IEEE. In the presence of sNaNs, we apply
  // the optimization using FMINNUM/FMAXNUM for the following cases. If
  // we can prove that we do not have any sNaNs, then we can do the
  // optimization using FMINNUM_IEEE/FMAXNUM_IEEE for the following
  // cases.
  else if (((CC == ISD::SETOLT || CC == ISD::SETOLE) &&
            (OrAndOpcode == ISD::OR)) ||
           ((CC == ISD::SETUGT || CC == ISD::SETUGE) &&
            (OrAndOpcode == ISD::AND)))
    return isFMAXNUMFMINNUM ? ISD::FMINNUM
                            : arebothOperandsNotSNan(Operand1, Operand2, DAG) &&
                                      isFMAXNUMFMINNUM_IEEE
                                  ? ISD::FMINNUM_IEEE
                                  : ISD::DELETED_NODE;
  else if (((CC == ISD::SETOGT || CC == ISD::SETOGE) &&
            (OrAndOpcode == ISD::OR)) ||
           ((CC == ISD::SETULT || CC == ISD::SETULE) &&
            (OrAndOpcode == ISD::AND)))
    return isFMAXNUMFMINNUM ? ISD::FMAXNUM
                            : arebothOperandsNotSNan(Operand1, Operand2, DAG) &&
                                      isFMAXNUMFMINNUM_IEEE
                                  ? ISD::FMAXNUM_IEEE
                                  : ISD::DELETED_NODE;
  return ISD::DELETED_NODE;
}

static SDValue foldAndOrOfSETCC(SDNode *LogicOp, SelectionDAG &DAG) {
  using AndOrSETCCFoldKind = TargetLowering::AndOrSETCCFoldKind;
  assert(
      (LogicOp->getOpcode() == ISD::AND || LogicOp->getOpcode() == ISD::OR) &&
      "Invalid Op to combine SETCC with");

  // TODO: Search past casts/truncates.
  SDValue LHS = LogicOp->getOperand(0);
  SDValue RHS = LogicOp->getOperand(1);
  if (LHS->getOpcode() != ISD::SETCC || RHS->getOpcode() != ISD::SETCC ||
      !LHS->hasOneUse() || !RHS->hasOneUse())
    return SDValue();

  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  AndOrSETCCFoldKind TargetPreference = TLI.isDesirableToCombineLogicOpOfSETCC(
      LogicOp, LHS.getNode(), RHS.getNode());

  SDValue LHS0 = LHS->getOperand(0);
  SDValue RHS0 = RHS->getOperand(0);
  SDValue LHS1 = LHS->getOperand(1);
  SDValue RHS1 = RHS->getOperand(1);
  // TODO: We don't actually need a splat here, for vectors we just need the
  // invariants to hold for each element.
  auto *LHS1C = isConstOrConstSplat(LHS1);
  auto *RHS1C = isConstOrConstSplat(RHS1);
  ISD::CondCode CCL = cast<CondCodeSDNode>(LHS.getOperand(2))->get();
  ISD::CondCode CCR = cast<CondCodeSDNode>(RHS.getOperand(2))->get();
  EVT VT = LogicOp->getValueType(0);
  EVT OpVT = LHS0.getValueType();
  SDLoc DL(LogicOp);

  // Check if the operands of an and/or operation are comparisons and if they
  // compare against the same value. Replace the and/or-cmp-cmp sequence with a
  // min/max-cmp sequence. If LHS1 is equal to RHS1, then the or-cmp-cmp
  // sequence will be replaced with a min-cmp sequence:
  // (LHS0 < LHS1) | (RHS0 < RHS1) -> min(LHS0, RHS0) < LHS1
  // and the and-cmp-cmp sequence will be replaced with a max-cmp sequence:
  // (LHS0 < LHS1) & (RHS0 < RHS1) -> max(LHS0, RHS0) < LHS1
  // The optimization does not work for `==` or `!=`.
  // The two comparisons should have either the same predicate, or the
  // predicate of one of the comparisons should be the swapped form of the
  // other one.
  bool isFMAXNUMFMINNUM_IEEE = TLI.isOperationLegal(ISD::FMAXNUM_IEEE, OpVT) &&
                               TLI.isOperationLegal(ISD::FMINNUM_IEEE, OpVT);
  bool isFMAXNUMFMINNUM = TLI.isOperationLegalOrCustom(ISD::FMAXNUM, OpVT) &&
                          TLI.isOperationLegalOrCustom(ISD::FMINNUM, OpVT);
  if (((OpVT.isInteger() && TLI.isOperationLegal(ISD::UMAX, OpVT) &&
        TLI.isOperationLegal(ISD::SMAX, OpVT) &&
        TLI.isOperationLegal(ISD::UMIN, OpVT) &&
        TLI.isOperationLegal(ISD::SMIN, OpVT)) ||
       (OpVT.isFloatingPoint() &&
        (isFMAXNUMFMINNUM_IEEE || isFMAXNUMFMINNUM))) &&
      !ISD::isIntEqualitySetCC(CCL) && !ISD::isFPEqualitySetCC(CCL) &&
      CCL != ISD::SETFALSE && CCL != ISD::SETO && CCL != ISD::SETUO &&
      CCL != ISD::SETTRUE &&
      (CCL == CCR || CCL == ISD::getSetCCSwappedOperands(CCR))) {

    SDValue CommonValue, Operand1, Operand2;
    ISD::CondCode CC = ISD::SETCC_INVALID;
    if (CCL == CCR) {
      if (LHS0 == RHS0) {
        CommonValue = LHS0;
        Operand1 = LHS1;
        Operand2 = RHS1;
        CC = ISD::getSetCCSwappedOperands(CCL);
      } else if (LHS1 == RHS1) {
        CommonValue = LHS1;
        Operand1 = LHS0;
        Operand2 = RHS0;
        CC = CCL;
      }
    } else {
      assert(CCL == ISD::getSetCCSwappedOperands(CCR) && "Unexpected CC");
      if (LHS0 == RHS1) {
        CommonValue = LHS0;
        Operand1 = LHS1;
        Operand2 = RHS0;
        CC = CCR;
      } else if (RHS0 == LHS1) {
        CommonValue = LHS1;
        Operand1 = LHS0;
        Operand2 = RHS1;
        CC = CCL;
      }
    }

    // Don't do this transform for sign bit tests. Let foldLogicOfSetCCs
    // handle it using OR/AND.
    if (CC == ISD::SETLT && isNullOrNullSplat(CommonValue))
      CC = ISD::SETCC_INVALID;
    else if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(CommonValue))
      CC = ISD::SETCC_INVALID;

    if (CC != ISD::SETCC_INVALID) {
      unsigned NewOpcode = ISD::DELETED_NODE;
      bool IsSigned = isSignedIntSetCC(CC);
      if (OpVT.isInteger()) {
        bool IsLess = (CC == ISD::SETLE || CC == ISD::SETULE ||
                       CC == ISD::SETLT || CC == ISD::SETULT);
        bool IsOr = (LogicOp->getOpcode() == ISD::OR);
        if (IsLess == IsOr)
          NewOpcode = IsSigned ? ISD::SMIN : ISD::UMIN;
        else
          NewOpcode = IsSigned ? ISD::SMAX : ISD::UMAX;
      } else if (OpVT.isFloatingPoint())
        NewOpcode =
            getMinMaxOpcodeForFP(Operand1, Operand2, CC, LogicOp->getOpcode(),
                                 DAG, isFMAXNUMFMINNUM_IEEE, isFMAXNUMFMINNUM);

      if (NewOpcode != ISD::DELETED_NODE) {
        SDValue MinMaxValue =
            DAG.getNode(NewOpcode, DL, OpVT, Operand1, Operand2);
        return DAG.getSetCC(DL, VT, MinMaxValue, CommonValue, CC);
      }
    }
  }

  if (TargetPreference == AndOrSETCCFoldKind::None)
    return SDValue();

  if (CCL == CCR &&
      CCL == (LogicOp->getOpcode() == ISD::AND ? ISD::SETNE : ISD::SETEQ) &&
      LHS0 == RHS0 && LHS1C && RHS1C && OpVT.isInteger()) {
    const APInt &APLhs = LHS1C->getAPIntValue();
    const APInt &APRhs = RHS1C->getAPIntValue();

    // Preference is to use ISD::ABS or we already have an ISD::ABS (in which
    // case this is just a compare).
    if (APLhs == (-APRhs) &&
        ((TargetPreference & AndOrSETCCFoldKind::ABS) ||
         DAG.doesNodeExist(ISD::ABS, DAG.getVTList(OpVT), {LHS0}))) {
      const APInt &C = APLhs.isNegative() ? APRhs : APLhs;
      // (icmp eq A, C) | (icmp eq A, -C)
      //    -> (icmp eq Abs(A), C)
      // (icmp ne A, C) & (icmp ne A, -C)
      //    -> (icmp ne Abs(A), C)
      SDValue AbsOp = DAG.getNode(ISD::ABS, DL, OpVT, LHS0);
      return DAG.getNode(ISD::SETCC, DL, VT, AbsOp,
                         DAG.getConstant(C, DL, OpVT), LHS.getOperand(2));
    } else if (TargetPreference &
               (AndOrSETCCFoldKind::AddAnd | AndOrSETCCFoldKind::NotAnd)) {

      // AndOrSETCCFoldKind::AddAnd:
      // A == C0 | A == C1
      //  IF IsPow2(smax(C0, C1)-smin(C0, C1))
      //    -> ((A - smin(C0, C1)) & ~(smax(C0, C1)-smin(C0, C1))) == 0
      // A != C0 & A != C1
      //  IF IsPow2(smax(C0, C1)-smin(C0, C1))
      //    -> ((A - smin(C0, C1)) & ~(smax(C0, C1)-smin(C0, C1))) != 0

      // AndOrSETCCFoldKind::NotAnd:
      // A == C0 | A == C1
      //  IF smax(C0, C1) == -1 AND IsPow2(smax(C0, C1) - smin(C0, C1))
      //    -> ~A & smin(C0, C1) == 0
      // A != C0 & A != C1
      //  IF smax(C0, C1) == -1 AND IsPow2(smax(C0, C1) - smin(C0, C1))
      //    -> ~A & smin(C0, C1) != 0

      const APInt &MaxC = APIntOps::smax(APRhs, APLhs);
      const APInt &MinC = APIntOps::smin(APRhs, APLhs);
      APInt Dif = MaxC - MinC;
      if (!Dif.isZero() && Dif.isPowerOf2()) {
        if (MaxC.isAllOnes() &&
            (TargetPreference & AndOrSETCCFoldKind::NotAnd)) {
          SDValue NotOp = DAG.getNOT(DL, LHS0, OpVT);
          SDValue AndOp = DAG.getNode(ISD::AND, DL, OpVT, NotOp,
                                      DAG.getConstant(MinC, DL, OpVT));
          return DAG.getNode(ISD::SETCC, DL, VT, AndOp,
                             DAG.getConstant(0, DL, OpVT), LHS.getOperand(2));
        } else if (TargetPreference & AndOrSETCCFoldKind::AddAnd) {

          SDValue AddOp = DAG.getNode(ISD::ADD, DL, OpVT, LHS0,
                                      DAG.getConstant(-MinC, DL, OpVT));
          SDValue AndOp = DAG.getNode(ISD::AND, DL, OpVT, AddOp,
                                      DAG.getConstant(~Dif, DL, OpVT));
          return DAG.getNode(ISD::SETCC, DL, VT, AndOp,
                             DAG.getConstant(0, DL, OpVT), LHS.getOperand(2));
        }
      }
    }
  }

  return SDValue();
}

// Combine `(select c, (X & 1), 0)` -> `(and (zext c), X)`.
// We canonicalize to the `select` form in the middle end, but the `and` form
// gets better codegen on all tested targets (arm, x86, riscv).
static SDValue combineSelectAsExtAnd(SDValue Cond, SDValue T, SDValue F,
                                     const SDLoc &DL, SelectionDAG &DAG) {
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  if (!isNullConstant(F))
    return SDValue();

  EVT CondVT = Cond.getValueType();
  if (TLI.getBooleanContents(CondVT) !=
      TargetLoweringBase::ZeroOrOneBooleanContent)
    return SDValue();

  if (T.getOpcode() != ISD::AND)
    return SDValue();

  if (!isOneConstant(T.getOperand(1)))
    return SDValue();

  EVT OpVT = T.getValueType();

  SDValue CondMask =
      OpVT == CondVT ? Cond : DAG.getBoolExtOrTrunc(Cond, DL, OpVT, CondVT);
  return DAG.getNode(ISD::AND, DL, OpVT, CondMask, T.getOperand(0));
}

/// This contains all DAGCombine rules which reduce two values combined by
/// an And operation to a single value. This makes them reusable in the context
/// of visitSELECT(). Rules involving constants are not included as
/// visitSELECT() already handles those cases.
SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1, SDNode *N) {
  EVT VT = N1.getValueType();
  SDLoc DL(N);

  // fold (and x, undef) -> 0
  if (N0.isUndef() || N1.isUndef())
    return DAG.getConstant(0, DL, VT);

  if (SDValue V = foldLogicOfSetCCs(true, N0, N1, DL))
    return V;

  // Canonicalize:
  //   and(x, add) -> and(add, x)
  if (N1.getOpcode() == ISD::ADD)
    std::swap(N0, N1);

  // TODO: Rewrite this to return a new 'AND' instead of using CombineTo.
  if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::SRL &&
      VT.getSizeInBits() <= 64 && N0->hasOneUse()) {
    if (ConstantSDNode *ADDI = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
      if (ConstantSDNode *SRLI = dyn_cast<ConstantSDNode>(N1.getOperand(1))) {
        // Look for (and (add x, c1), (lshr y, c2)). If C1 wasn't a legal
        // immediate for an add, but it is legal if its top c2 bits are set,
        // transform the ADD so the immediate doesn't need to be materialized
        // in a register.
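        // Since (lshr y, c2) has its top c2 bits clear, the AND discards the
        // top c2 bits of the add result anyway, so those bits of c1 are
        // don't-cares and can be set to form a cheaper add immediate.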
        APInt ADDC = ADDI->getAPIntValue();
        APInt SRLC = SRLI->getAPIntValue();
        if (ADDC.getSignificantBits() <= 64 && SRLC.ult(VT.getSizeInBits()) &&
            !TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
          APInt Mask = APInt::getHighBitsSet(VT.getSizeInBits(),
                                             SRLC.getZExtValue());
          if (DAG.MaskedValueIsZero(N0.getOperand(1), Mask)) {
            ADDC |= Mask;
            if (TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
              SDLoc DL0(N0);
              SDValue NewAdd =
                DAG.getNode(ISD::ADD, DL0, VT,
                            N0.getOperand(0), DAG.getConstant(ADDC, DL, VT));
              CombineTo(N0.getNode(), NewAdd);
              // Return N so it doesn't get rechecked!
              return SDValue(N, 0);
            }
          }
        }
      }
    }
  }

  return SDValue();
}

bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
                                   EVT LoadResultTy, EVT &ExtVT) {
  if (!AndC->getAPIntValue().isMask())
    return false;

  unsigned ActiveBits = AndC->getAPIntValue().countr_one();
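  // e.g. for (and (load x), 255), the mask has 8 trailing ones, so try to
  // turn the load into a zextload of i8.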
6419 
6420   ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
6421   EVT LoadedVT = LoadN->getMemoryVT();
6422 
6423   if (ExtVT == LoadedVT &&
6424       (!LegalOperations ||
6425        TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))) {
6426     // ZEXTLOAD will match without needing to change the size of the value being
6427     // loaded.
6428     return true;
6429   }
6430 
6431   // Do not change the width of a volatile or atomic load.
6432   if (!LoadN->isSimple())
6433     return false;
6434 
6435   // Do not generate loads of non-round integer types since these can
6436   // be expensive (and would be wrong if the type is not byte sized).
6437   if (!LoadedVT.bitsGT(ExtVT) || !ExtVT.isRound())
6438     return false;
6439 
6440   if (LegalOperations &&
6441       !TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))
6442     return false;
6443 
6444   if (!TLI.shouldReduceLoadWidth(LoadN, ISD::ZEXTLOAD, ExtVT))
6445     return false;
6446 
6447   return true;
6448 }
6449 
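/// Check whether it is legal and desirable to shrink the memory access of
/// LDST to MemVT, taken ShAmt bits above the original base address. Sketch of
/// an illustrative case (little-endian): narrowing (load i32 p) to
/// (zextload i8 p+2) corresponds to MemVT = i8 and ShAmt = 16; the offset
/// must be a whole number of bytes and the narrow access must be supported
/// and sufficiently aligned.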
6450 bool DAGCombiner::isLegalNarrowLdSt(LSBaseSDNode *LDST,
6451                                     ISD::LoadExtType ExtType, EVT &MemVT,
6452                                     unsigned ShAmt) {
6453   if (!LDST)
6454     return false;
6455   // Only allow byte offsets.
6456   if (ShAmt % 8)
6457     return false;
6458 
6459   // Do not generate loads of non-round integer types since these can
6460   // be expensive (and would be wrong if the type is not byte sized).
6461   if (!MemVT.isRound())
6462     return false;
6463 
6464   // Don't change the width of a volatile or atomic load.
6465   if (!LDST->isSimple())
6466     return false;
6467 
6468   EVT LdStMemVT = LDST->getMemoryVT();
6469 
6470   // Bail out when changing the scalable property, since we can't be sure that
6471   // we're actually narrowing here.
6472   if (LdStMemVT.isScalableVector() != MemVT.isScalableVector())
6473     return false;
6474 
6475   // Verify that we are actually reducing a load width here.
6476   if (LdStMemVT.bitsLT(MemVT))
6477     return false;
6478 
6479   // Ensure that this isn't going to produce an unsupported memory access.
6480   if (ShAmt) {
6481     assert(ShAmt % 8 == 0 && "ShAmt is byte offset");
6482     const unsigned ByteShAmt = ShAmt / 8;
6483     const Align LDSTAlign = LDST->getAlign();
6484     const Align NarrowAlign = commonAlignment(LDSTAlign, ByteShAmt);
6485     if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
6486                                 LDST->getAddressSpace(), NarrowAlign,
6487                                 LDST->getMemOperand()->getFlags()))
6488       return false;
6489   }
6490 
6491   // It's not possible to generate a constant of extended or untyped type.
6492   EVT PtrType = LDST->getBasePtr().getValueType();
6493   if (PtrType == MVT::Untyped || PtrType.isExtended())
6494     return false;
6495 
6496   if (isa<LoadSDNode>(LDST)) {
6497     LoadSDNode *Load = cast<LoadSDNode>(LDST);
6498     // Don't transform one with multiple uses; this would require adding a
6499     // new load.
6500     if (!SDValue(Load, 0).hasOneUse())
6501       return false;
6502 
6503     if (LegalOperations &&
6504         !TLI.isLoadExtLegal(ExtType, Load->getValueType(0), MemVT))
6505       return false;
6506 
6507     // For the transform to be legal, the load must produce only two values
6508     // (the value loaded and the chain).  Don't transform a pre-increment
6509     // load, for example, which produces an extra value.  Otherwise the
6510     // transformation is not equivalent, and the downstream logic to replace
6511     // uses gets things wrong.
6512     if (Load->getNumValues() > 2)
6513       return false;
6514 
6515     // If the load that we're shrinking is an extload and we're not just
6516     // discarding the extension, we can't simply shrink the load. Bail.
6517     // TODO: It would be possible to merge the extensions in some cases.
6518     if (Load->getExtensionType() != ISD::NON_EXTLOAD &&
6519         Load->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
6520       return false;
6521 
6522     if (!TLI.shouldReduceLoadWidth(Load, ExtType, MemVT))
6523       return false;
6524   } else {
6525     assert(isa<StoreSDNode>(LDST) && "It is not a Load nor a Store SDNode");
6526     StoreSDNode *Store = cast<StoreSDNode>(LDST);
6527     // Can't write outside the original store
6528     if (Store->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
6529       return false;
6530 
6531     if (LegalOperations &&
6532         !TLI.isTruncStoreLegal(Store->getValue().getValueType(), MemVT))
6533       return false;
6534   }
6535   return true;
6536 }
6537 
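/// Recursively walk the operands of the and/or/xor tree rooted at N, looking
/// for loads that could be narrowed if the mask were propagated back to them.
/// At most one other non-constant node is tolerated and reported through
/// NodeToMask so the caller can mask it explicitly. Illustrative shape:
///   (and (or (load a), (xor (load b), C)), Mask)
/// collects both loads, and records the xor in NodesWithConsts when C has
/// bits outside Mask.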
6538 bool DAGCombiner::SearchForAndLoads(SDNode *N,
6539                                     SmallVectorImpl<LoadSDNode*> &Loads,
6540                                     SmallPtrSetImpl<SDNode*> &NodesWithConsts,
6541                                     ConstantSDNode *Mask,
6542                                     SDNode *&NodeToMask) {
6543   // Recursively search for the operands, looking for loads which can be
6544   // narrowed.
6545   for (SDValue Op : N->op_values()) {
6546     if (Op.getValueType().isVector())
6547       return false;
6548 
6549     // Some constants may need fixing up later if they are too large.
6550     if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
6551       if ((N->getOpcode() == ISD::OR || N->getOpcode() == ISD::XOR) &&
6552           (Mask->getAPIntValue() & C->getAPIntValue()) != C->getAPIntValue())
6553         NodesWithConsts.insert(N);
6554       continue;
6555     }
6556 
6557     if (!Op.hasOneUse())
6558       return false;
6559 
6560     switch (Op.getOpcode()) {
6561     case ISD::LOAD: {
6562       auto *Load = cast<LoadSDNode>(Op);
6563       EVT ExtVT;
6564       if (isAndLoadExtLoad(Mask, Load, Load->getValueType(0), ExtVT) &&
6565           isLegalNarrowLdSt(Load, ISD::ZEXTLOAD, ExtVT)) {
6566 
6567         // ZEXTLOAD is already small enough.
6568         if (Load->getExtensionType() == ISD::ZEXTLOAD &&
6569             ExtVT.bitsGE(Load->getMemoryVT()))
6570           continue;
6571 
6572         // Use LE to convert equal sized loads to zext.
6573         if (ExtVT.bitsLE(Load->getMemoryVT()))
6574           Loads.push_back(Load);
6575 
6576         continue;
6577       }
6578       return false;
6579     }
6580     case ISD::ZERO_EXTEND:
6581     case ISD::AssertZext: {
6582       unsigned ActiveBits = Mask->getAPIntValue().countr_one();
6583       EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
6584       EVT VT = Op.getOpcode() == ISD::AssertZext ?
6585         cast<VTSDNode>(Op.getOperand(1))->getVT() :
6586         Op.getOperand(0).getValueType();
6587 
6588       // We can accept extending nodes if the mask is wider than, or equal
6589       // in width to, the original type.
6590       if (ExtVT.bitsGE(VT))
6591         continue;
6592       break;
6593     }
6594     case ISD::OR:
6595     case ISD::XOR:
6596     case ISD::AND:
6597       if (!SearchForAndLoads(Op.getNode(), Loads, NodesWithConsts, Mask,
6598                              NodeToMask))
6599         return false;
6600       continue;
6601     }
6602 
6603     // Allow one node which will be masked along with any loads found.
6604     if (NodeToMask)
6605       return false;
6606 
6607     // Also ensure that the node to be masked only produces one data result.
6608     NodeToMask = Op.getNode();
6609     if (NodeToMask->getNumValues() > 1) {
6610       bool HasValue = false;
6611       for (unsigned i = 0, e = NodeToMask->getNumValues(); i < e; ++i) {
6612         MVT VT = SDValue(NodeToMask, i).getSimpleValueType();
6613         if (VT != MVT::Glue && VT != MVT::Other) {
6614           if (HasValue) {
6615             NodeToMask = nullptr;
6616             return false;
6617           }
6618           HasValue = true;
6619         }
6620       }
6621       assert(HasValue && "Node to be masked has no data result?");
6622     }
6623   }
6624   return true;
6625 }
6626 
6627 bool DAGCombiner::BackwardsPropagateMask(SDNode *N) {
6628   auto *Mask = dyn_cast<ConstantSDNode>(N->getOperand(1));
6629   if (!Mask)
6630     return false;
6631 
6632   if (!Mask->getAPIntValue().isMask())
6633     return false;
6634 
6635   // No need to do anything if the and directly uses a load.
6636   if (isa<LoadSDNode>(N->getOperand(0)))
6637     return false;
6638 
6639   SmallVector<LoadSDNode*, 8> Loads;
6640   SmallPtrSet<SDNode*, 2> NodesWithConsts;
6641   SDNode *FixupNode = nullptr;
6642   if (SearchForAndLoads(N, Loads, NodesWithConsts, Mask, FixupNode)) {
6643     if (Loads.empty())
6644       return false;
6645 
6646     LLVM_DEBUG(dbgs() << "Backwards propagate AND: "; N->dump());
6647     SDValue MaskOp = N->getOperand(1);
6648 
6649     // If it exists, fix up the single node we allow in the tree that
6650     // needs masking.
6651     if (FixupNode) {
6652       LLVM_DEBUG(dbgs() << "First, need to fix up: "; FixupNode->dump());
6653       SDValue And = DAG.getNode(ISD::AND, SDLoc(FixupNode),
6654                                 FixupNode->getValueType(0),
6655                                 SDValue(FixupNode, 0), MaskOp);
6656       DAG.ReplaceAllUsesOfValueWith(SDValue(FixupNode, 0), And);
6657       if (And.getOpcode() == ISD::AND)
6658         DAG.UpdateNodeOperands(And.getNode(), SDValue(FixupNode, 0), MaskOp);
6659     }
6660 
6661     // Narrow any constants that need it.
6662     for (auto *LogicN : NodesWithConsts) {
6663       SDValue Op0 = LogicN->getOperand(0);
6664       SDValue Op1 = LogicN->getOperand(1);
6665 
6666       if (isa<ConstantSDNode>(Op0))
6667         Op0 =
6668             DAG.getNode(ISD::AND, SDLoc(Op0), Op0.getValueType(), Op0, MaskOp);
6669 
6670       if (isa<ConstantSDNode>(Op1))
6671         Op1 =
6672             DAG.getNode(ISD::AND, SDLoc(Op1), Op1.getValueType(), Op1, MaskOp);
6673 
6674       if (isa<ConstantSDNode>(Op0) && !isa<ConstantSDNode>(Op1))
6675         std::swap(Op0, Op1);
6676 
6677       DAG.UpdateNodeOperands(LogicN, Op0, Op1);
6678     }
6679 
6680     // Create narrow loads.
6681     for (auto *Load : Loads) {
6682       LLVM_DEBUG(dbgs() << "Propagate AND back to: "; Load->dump());
6683       SDValue And = DAG.getNode(ISD::AND, SDLoc(Load), Load->getValueType(0),
6684                                 SDValue(Load, 0), MaskOp);
6685       DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), And);
6686       if (And.getOpcode() == ISD::AND)
6687         And = SDValue(
6688             DAG.UpdateNodeOperands(And.getNode(), SDValue(Load, 0), MaskOp), 0);
6689       SDValue NewLoad = reduceLoadWidth(And.getNode());
6690       assert(NewLoad &&
6691              "Shouldn't be masking the load if it can't be narrowed");
6692       CombineTo(Load, NewLoad, NewLoad.getValue(1));
6693     }
6694     DAG.ReplaceAllUsesWith(N, N->getOperand(0).getNode());
6695     return true;
6696   }
6697   return false;
6698 }
6699 
6700 // Unfold
6701 //    x &  (-1 'logical shift' y)
6702 // To
6703 //    (x 'opposite logical shift' y) 'logical shift' y
6704 // if it is better for performance.
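// For example (illustrative, with logical shifts): x & (-1 << y) becomes
// (x >> y) << y, and x & (-1 >> y) becomes (x << y) >> y; both clear the same
// y low/high bits without materializing the all-ones mask.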
6705 SDValue DAGCombiner::unfoldExtremeBitClearingToShifts(SDNode *N) {
6706   assert(N->getOpcode() == ISD::AND);
6707 
6708   SDValue N0 = N->getOperand(0);
6709   SDValue N1 = N->getOperand(1);
6710 
6711   // Do we actually prefer shifts over a mask?
6712   if (!TLI.shouldFoldMaskToVariableShiftPair(N0))
6713     return SDValue();
6714 
6715   // Try to match  (-1 '[outer] logical shift' y)
6716   unsigned OuterShift;
6717   unsigned InnerShift; // The opposite direction to the OuterShift.
6718   SDValue Y;           // Shift amount.
6719   auto matchMask = [&OuterShift, &InnerShift, &Y](SDValue M) -> bool {
6720     if (!M.hasOneUse())
6721       return false;
6722     OuterShift = M->getOpcode();
6723     if (OuterShift == ISD::SHL)
6724       InnerShift = ISD::SRL;
6725     else if (OuterShift == ISD::SRL)
6726       InnerShift = ISD::SHL;
6727     else
6728       return false;
6729     if (!isAllOnesConstant(M->getOperand(0)))
6730       return false;
6731     Y = M->getOperand(1);
6732     return true;
6733   };
6734 
6735   SDValue X;
6736   if (matchMask(N1))
6737     X = N0;
6738   else if (matchMask(N0))
6739     X = N1;
6740   else
6741     return SDValue();
6742 
6743   SDLoc DL(N);
6744   EVT VT = N->getValueType(0);
6745 
6746   //     tmp = x   'opposite logical shift' y
6747   SDValue T0 = DAG.getNode(InnerShift, DL, VT, X, Y);
6748   //     ret = tmp 'logical shift' y
6749   SDValue T1 = DAG.getNode(OuterShift, DL, VT, T0, Y);
6750 
6751   return T1;
6752 }
6753 
6754 /// Try to replace shift/logic that tests if a bit is clear with mask + setcc.
6755 /// For a target with a bit test, this is expected to become test + set and save
6756 /// at least 1 instruction.
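/// Illustrative instance: and (not (srl X, 5)), 1 --> (and X, 32) == 0,
/// which lets a target bit-test instruction check bit 5 of X directly.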
6757 static SDValue combineShiftAnd1ToBitTest(SDNode *And, SelectionDAG &DAG) {
6758   assert(And->getOpcode() == ISD::AND && "Expected an 'and' op");
6759 
6760   // Look through an optional extension.
6761   SDValue And0 = And->getOperand(0), And1 = And->getOperand(1);
6762   if (And0.getOpcode() == ISD::ANY_EXTEND && And0.hasOneUse())
6763     And0 = And0.getOperand(0);
6764   if (!isOneConstant(And1) || !And0.hasOneUse())
6765     return SDValue();
6766 
6767   SDValue Src = And0;
6768 
6769   // Attempt to find a 'not' op.
6770   // TODO: Should we favor test+set even without the 'not' op?
6771   bool FoundNot = false;
6772   if (isBitwiseNot(Src)) {
6773     FoundNot = true;
6774     Src = Src.getOperand(0);
6775 
6776     // Look through an optional truncation. The source operand may not be the
6777     // same type as the original 'and', but that is ok because we are masking
6778     // off everything but the low bit.
6779     if (Src.getOpcode() == ISD::TRUNCATE && Src.hasOneUse())
6780       Src = Src.getOperand(0);
6781   }
6782 
6783   // Match a shift-right by constant.
6784   if (Src.getOpcode() != ISD::SRL || !Src.hasOneUse())
6785     return SDValue();
6786 
6787   // This is probably not worthwhile without a supported type.
6788   EVT SrcVT = Src.getValueType();
6789   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
6790   if (!TLI.isTypeLegal(SrcVT))
6791     return SDValue();
6792 
6793   // We might have looked through casts that make this transform invalid.
6794   unsigned BitWidth = SrcVT.getScalarSizeInBits();
6795   SDValue ShiftAmt = Src.getOperand(1);
6796   auto *ShiftAmtC = dyn_cast<ConstantSDNode>(ShiftAmt);
6797   if (!ShiftAmtC || !ShiftAmtC->getAPIntValue().ult(BitWidth))
6798     return SDValue();
6799 
6800   // Set source to shift source.
6801   Src = Src.getOperand(0);
6802 
6803   // Try again to find a 'not' op.
6804   // TODO: Should we favor test+set even with two 'not' ops?
6805   if (!FoundNot) {
6806     if (!isBitwiseNot(Src))
6807       return SDValue();
6808     Src = Src.getOperand(0);
6809   }
6810 
6811   if (!TLI.hasBitTest(Src, ShiftAmt))
6812     return SDValue();
6813 
6814   // Turn this into a bit-test pattern using mask op + setcc:
6815   // and (not (srl X, C)), 1 --> (and X, 1<<C) == 0
6816   // and (srl (not X), C), 1 --> (and X, 1<<C) == 0
6817   SDLoc DL(And);
6818   SDValue X = DAG.getZExtOrTrunc(Src, DL, SrcVT);
6819   EVT CCVT =
6820       TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), SrcVT);
6821   SDValue Mask = DAG.getConstant(
6822       APInt::getOneBitSet(BitWidth, ShiftAmtC->getZExtValue()), DL, SrcVT);
6823   SDValue NewAnd = DAG.getNode(ISD::AND, DL, SrcVT, X, Mask);
6824   SDValue Zero = DAG.getConstant(0, DL, SrcVT);
6825   SDValue Setcc = DAG.getSetCC(DL, CCVT, NewAnd, Zero, ISD::SETEQ);
6826   return DAG.getZExtOrTrunc(Setcc, DL, And->getValueType(0));
6827 }
6828 
6829 /// For targets that support usubsat, match a bit-hack form of that operation
6830 /// that ends in 'and' and convert it.
6831 static SDValue foldAndToUsubsat(SDNode *N, SelectionDAG &DAG) {
6832   SDValue N0 = N->getOperand(0);
6833   SDValue N1 = N->getOperand(1);
6834   EVT VT = N1.getValueType();
6835 
6836   // Canonicalize SRA as operand 1.
6837   if (N0.getOpcode() == ISD::SRA)
6838     std::swap(N0, N1);
6839 
6840   // xor/add with SMIN (signmask) are logically equivalent.
6841   if (N0.getOpcode() != ISD::XOR && N0.getOpcode() != ISD::ADD)
6842     return SDValue();
6843 
6844   if (N1.getOpcode() != ISD::SRA || !N0.hasOneUse() || !N1.hasOneUse() ||
6845       N0.getOperand(0) != N1.getOperand(0))
6846     return SDValue();
6847 
6848   unsigned BitWidth = VT.getScalarSizeInBits();
6849   ConstantSDNode *XorC = isConstOrConstSplat(N0.getOperand(1), true);
6850   ConstantSDNode *SraC = isConstOrConstSplat(N1.getOperand(1), true);
6851   if (!XorC || !XorC->getAPIntValue().isSignMask() ||
6852       !SraC || SraC->getAPIntValue() != BitWidth - 1)
6853     return SDValue();
6854 
6855   // (i8 X ^ 128) & (i8 X s>> 7) --> usubsat X, 128
6856   // (i8 X + 128) & (i8 X s>> 7) --> usubsat X, 128
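  // Sketch of why this holds (i8, illustrative): if X u< 128 the sign bit is
  // clear, X s>> 7 is 0, and the 'and' result is 0 == usubsat(X, 128); if
  // X u>= 128 the mask is all-ones and X ^ 128 == X + 128 == X - 128
  // (mod 256) == usubsat(X, 128).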
6857   SDLoc DL(N);
6858   SDValue SignMask = DAG.getConstant(XorC->getAPIntValue(), DL, VT);
6859   return DAG.getNode(ISD::USUBSAT, DL, VT, N0.getOperand(0), SignMask);
6860 }
6861 
6862 /// Given a bitwise logic operation N with a matching bitwise logic operand,
6863 /// fold a pattern where 2 of the source operands are identically shifted
6864 /// values. For example:
6865 /// ((X0 << Y) | Z) | (X1 << Y) --> ((X0 | X1) << Y) | Z
6866 static SDValue foldLogicOfShifts(SDNode *N, SDValue LogicOp, SDValue ShiftOp,
6867                                  SelectionDAG &DAG) {
6868   unsigned LogicOpcode = N->getOpcode();
6869   assert(ISD::isBitwiseLogicOp(LogicOpcode) &&
6870          "Expected bitwise logic operation");
6871 
6872   if (!LogicOp.hasOneUse() || !ShiftOp.hasOneUse())
6873     return SDValue();
6874 
6875   // Match another bitwise logic op and a shift.
6876   unsigned ShiftOpcode = ShiftOp.getOpcode();
6877   if (LogicOp.getOpcode() != LogicOpcode ||
6878       !(ShiftOpcode == ISD::SHL || ShiftOpcode == ISD::SRL ||
6879         ShiftOpcode == ISD::SRA))
6880     return SDValue();
6881 
6882   // Match another shift op inside the first logic operand. Handle both commuted
6883   // possibilities.
6884   // LOGIC (LOGIC (SH X0, Y), Z), (SH X1, Y) --> LOGIC (SH (LOGIC X0, X1), Y), Z
6885   // LOGIC (LOGIC Z, (SH X0, Y)), (SH X1, Y) --> LOGIC (SH (LOGIC X0, X1), Y), Z
6886   SDValue X1 = ShiftOp.getOperand(0);
6887   SDValue Y = ShiftOp.getOperand(1);
6888   SDValue X0, Z;
6889   if (LogicOp.getOperand(0).getOpcode() == ShiftOpcode &&
6890       LogicOp.getOperand(0).getOperand(1) == Y) {
6891     X0 = LogicOp.getOperand(0).getOperand(0);
6892     Z = LogicOp.getOperand(1);
6893   } else if (LogicOp.getOperand(1).getOpcode() == ShiftOpcode &&
6894              LogicOp.getOperand(1).getOperand(1) == Y) {
6895     X0 = LogicOp.getOperand(1).getOperand(0);
6896     Z = LogicOp.getOperand(0);
6897   } else {
6898     return SDValue();
6899   }
6900 
6901   EVT VT = N->getValueType(0);
6902   SDLoc DL(N);
6903   SDValue LogicX = DAG.getNode(LogicOpcode, DL, VT, X0, X1);
6904   SDValue NewShift = DAG.getNode(ShiftOpcode, DL, VT, LogicX, Y);
6905   return DAG.getNode(LogicOpcode, DL, VT, NewShift, Z);
6906 }
6907 
6908 /// Given a tree of logic operations with shape like
6909 /// (LOGIC (LOGIC (X, Y), LOGIC (Z, Y)))
6910 /// try to match and fold shift operations with the same shift amount.
6911 /// For example:
6912 /// LOGIC (LOGIC (SH X0, Y), Z), (LOGIC (SH X1, Y), W) -->
6913 /// --> LOGIC (SH (LOGIC X0, X1), Y), (LOGIC Z, W)
6914 static SDValue foldLogicTreeOfShifts(SDNode *N, SDValue LeftHand,
6915                                      SDValue RightHand, SelectionDAG &DAG) {
6916   unsigned LogicOpcode = N->getOpcode();
6917   assert(ISD::isBitwiseLogicOp(LogicOpcode) &&
6918          "Expected bitwise logic operation");
6919   if (LeftHand.getOpcode() != LogicOpcode ||
6920       RightHand.getOpcode() != LogicOpcode)
6921     return SDValue();
6922   if (!LeftHand.hasOneUse() || !RightHand.hasOneUse())
6923     return SDValue();
6924 
6925   // Try to match one of following patterns:
6926   // LOGIC (LOGIC (SH X0, Y), Z), (LOGIC (SH X1, Y), W)
6927   // LOGIC (LOGIC (SH X0, Y), Z), (LOGIC W, (SH X1, Y))
6928   // Note that foldLogicOfShifts will handle commuted versions of the left hand
6929   // itself.
6930   SDValue CombinedShifts, W;
6931   SDValue R0 = RightHand.getOperand(0);
6932   SDValue R1 = RightHand.getOperand(1);
6933   if ((CombinedShifts = foldLogicOfShifts(N, LeftHand, R0, DAG)))
6934     W = R1;
6935   else if ((CombinedShifts = foldLogicOfShifts(N, LeftHand, R1, DAG)))
6936     W = R0;
6937   else
6938     return SDValue();
6939 
6940   EVT VT = N->getValueType(0);
6941   SDLoc DL(N);
6942   return DAG.getNode(LogicOpcode, DL, VT, CombinedShifts, W);
6943 }
6944 
6945 SDValue DAGCombiner::visitAND(SDNode *N) {
6946   SDValue N0 = N->getOperand(0);
6947   SDValue N1 = N->getOperand(1);
6948   EVT VT = N1.getValueType();
6949 
6950   // x & x --> x
6951   if (N0 == N1)
6952     return N0;
6953 
6954   // fold (and c1, c2) -> c1&c2
6955   if (SDValue C = DAG.FoldConstantArithmetic(ISD::AND, SDLoc(N), VT, {N0, N1}))
6956     return C;
6957 
6958   // canonicalize constant to RHS
6959   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
6960       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
6961     return DAG.getNode(ISD::AND, SDLoc(N), VT, N1, N0);
6962 
6963   if (areBitwiseNotOfEachother(N0, N1))
6964     return DAG.getConstant(APInt::getZero(VT.getScalarSizeInBits()), SDLoc(N),
6965                            VT);
6966 
6967   // fold vector ops
6968   if (VT.isVector()) {
6969     if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N)))
6970       return FoldedVOp;
6971 
6972     // fold (and x, 0) -> 0, vector edition
6973     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
6974       // Do not return N1, because an undef node may exist in N1.
6975       return DAG.getConstant(APInt::getZero(N1.getScalarValueSizeInBits()),
6976                              SDLoc(N), N1.getValueType());
6977 
6978     // fold (and x, -1) -> x, vector edition
6979     if (ISD::isConstantSplatVectorAllOnes(N1.getNode()))
6980       return N0;
6981 
6982     // fold (and (masked_load) (splat_vec (x, ...))) to zext_masked_load
6983     auto *MLoad = dyn_cast<MaskedLoadSDNode>(N0);
6984     ConstantSDNode *Splat = isConstOrConstSplat(N1, true, true);
6985     if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD && Splat &&
6986         N1.hasOneUse()) {
6987       EVT LoadVT = MLoad->getMemoryVT();
6988       EVT ExtVT = VT;
6989       if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, ExtVT, LoadVT)) {
6990         // For this AND to be a zero extension of the masked load, the
6991         // elements of the BuildVec must mask the bottom bits of the
6992         // extended element type.
6993         uint64_t ElementSize =
6994             LoadVT.getVectorElementType().getScalarSizeInBits();
6995         if (Splat->getAPIntValue().isMask(ElementSize)) {
6996           auto NewLoad = DAG.getMaskedLoad(
6997               ExtVT, SDLoc(N), MLoad->getChain(), MLoad->getBasePtr(),
6998               MLoad->getOffset(), MLoad->getMask(), MLoad->getPassThru(),
6999               LoadVT, MLoad->getMemOperand(), MLoad->getAddressingMode(),
7000               ISD::ZEXTLOAD, MLoad->isExpandingLoad());
7001           bool LoadHasOtherUsers = !N0.hasOneUse();
7002           CombineTo(N, NewLoad);
7003           if (LoadHasOtherUsers)
7004             CombineTo(MLoad, NewLoad.getValue(0), NewLoad.getValue(1));
7005           return SDValue(N, 0);
7006         }
7007       }
7008     }
7009   }
7010 
7011   // fold (and x, -1) -> x
7012   if (isAllOnesConstant(N1))
7013     return N0;
7014 
7015   // if (and x, c) is known to be zero, return 0
7016   unsigned BitWidth = VT.getScalarSizeInBits();
7017   ConstantSDNode *N1C = isConstOrConstSplat(N1);
7018   if (N1C && DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnes(BitWidth)))
7019     return DAG.getConstant(0, SDLoc(N), VT);
7020 
7021   if (SDValue R = foldAndOrOfSETCC(N, DAG))
7022     return R;
7023 
7024   if (SDValue NewSel = foldBinOpIntoSelect(N))
7025     return NewSel;
7026 
7027   // reassociate and
7028   if (SDValue RAND = reassociateOps(ISD::AND, SDLoc(N), N0, N1, N->getFlags()))
7029     return RAND;
7030 
7031   // Fold and(vecreduce(x), vecreduce(y)) -> vecreduce(and(x, y))
7032   if (SDValue SD = reassociateReduction(ISD::VECREDUCE_AND, ISD::AND, SDLoc(N),
7033                                         VT, N0, N1))
7034     return SD;
7035 
7036   // fold (and (or x, C), D) -> D if (C & D) == D
7037   auto MatchSubset = [](ConstantSDNode *LHS, ConstantSDNode *RHS) {
7038     return RHS->getAPIntValue().isSubsetOf(LHS->getAPIntValue());
7039   };
7040   if (N0.getOpcode() == ISD::OR &&
7041       ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchSubset))
7042     return N1;
7043 
7044   if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
7045     SDValue N0Op0 = N0.getOperand(0);
7046     EVT SrcVT = N0Op0.getValueType();
7047     unsigned SrcBitWidth = SrcVT.getScalarSizeInBits();
7048     APInt Mask = ~N1C->getAPIntValue();
7049     Mask = Mask.trunc(SrcBitWidth);
7050 
7051     // fold (and (any_ext V), c) -> (zero_ext V) if 'and' only clears top bits.
7052     if (DAG.MaskedValueIsZero(N0Op0, Mask))
7053       return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, N0Op0);
7054 
7055     // fold (and (any_ext V), c) -> (zero_ext (and (trunc V), c)) if profitable.
7056     if (N1C->getAPIntValue().countLeadingZeros() >= (BitWidth - SrcBitWidth) &&
7057         TLI.isTruncateFree(VT, SrcVT) && TLI.isZExtFree(SrcVT, VT) &&
7058         TLI.isTypeDesirableForOp(ISD::AND, SrcVT) &&
7059         TLI.isNarrowingProfitable(VT, SrcVT)) {
7060       SDLoc DL(N);
7061       return DAG.getNode(ISD::ZERO_EXTEND, DL, VT,
7062                          DAG.getNode(ISD::AND, DL, SrcVT, N0Op0,
7063                                      DAG.getZExtOrTrunc(N1, DL, SrcVT)));
7064     }
7065   }
7066 
7067   // fold (and (ext (and V, c1)), c2) -> (and (ext V), (and c1, (ext c2)))
7068   if (ISD::isExtOpcode(N0.getOpcode())) {
7069     unsigned ExtOpc = N0.getOpcode();
7070     SDValue N0Op0 = N0.getOperand(0);
7071     if (N0Op0.getOpcode() == ISD::AND &&
7072         (ExtOpc != ISD::ZERO_EXTEND || !TLI.isZExtFree(N0Op0, VT)) &&
7073         DAG.isConstantIntBuildVectorOrConstantInt(N1) &&
7074         DAG.isConstantIntBuildVectorOrConstantInt(N0Op0.getOperand(1)) &&
7075         N0->hasOneUse() && N0Op0->hasOneUse()) {
7076       SDLoc DL(N);
7077       SDValue NewMask =
7078           DAG.getNode(ISD::AND, DL, VT, N1,
7079                       DAG.getNode(ExtOpc, DL, VT, N0Op0.getOperand(1)));
7080       return DAG.getNode(ISD::AND, DL, VT,
7081                          DAG.getNode(ExtOpc, DL, VT, N0Op0.getOperand(0)),
7082                          NewMask);
7083     }
7084   }
7085 
7086   // Similarly fold (and (X (load ([non_ext|any_ext|zero_ext] V))), c) ->
7087   // (X (load ([non_ext|zero_ext] V))) if 'and' only clears top bits which must
7088   // already be zero by virtue of the width of the base type of the load.
7089   //
7090   // The 'X' node here can either be nothing or an extract_vector_elt to catch
7091   // more cases.
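  // Illustrative instance: for (and (extload i32 <- i8, p), 0xFF), the mask
  // truncated to the 8-bit memory type is all-ones, so the 'and' becomes
  // redundant once the extload is converted into a zextload.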
7092   if ((N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
7093        N0.getValueSizeInBits() == N0.getOperand(0).getScalarValueSizeInBits() &&
7094        N0.getOperand(0).getOpcode() == ISD::LOAD &&
7095        N0.getOperand(0).getResNo() == 0) ||
7096       (N0.getOpcode() == ISD::LOAD && N0.getResNo() == 0)) {
7097     LoadSDNode *Load = cast<LoadSDNode>( (N0.getOpcode() == ISD::LOAD) ?
7098                                          N0 : N0.getOperand(0) );
7099 
7100     // Get the constant (if applicable) the zero'th operand is being ANDed with.
7101     // This can be a pure constant or a vector splat, in which case we treat the
7102     // vector as a scalar and use the splat value.
7103     APInt Constant = APInt::getZero(1);
7104     if (const ConstantSDNode *C = isConstOrConstSplat(
7105             N1, /*AllowUndef=*/false, /*AllowTruncation=*/true)) {
7106       Constant = C->getAPIntValue();
7107     } else if (BuildVectorSDNode *Vector = dyn_cast<BuildVectorSDNode>(N1)) {
7108       unsigned EltBitWidth = Vector->getValueType(0).getScalarSizeInBits();
7109       APInt SplatValue, SplatUndef;
7110       unsigned SplatBitSize;
7111       bool HasAnyUndefs;
7112       // Endianness should not matter here. Code below makes sure that we only
7113       // use the result if the SplatBitSize is a multiple of the vector element
7114       // size. And after that we AND all element sized parts of the splat
7115       // together. So the end result should be the same regardless of in which
7116       // order we do those operations.
7117       const bool IsBigEndian = false;
7118       bool IsSplat =
7119           Vector->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
7120                                   HasAnyUndefs, EltBitWidth, IsBigEndian);
7121 
7122       // Make sure that variable 'Constant' is only set if 'SplatBitSize' is a
7123       // multiple of 'EltBitWidth'. Otherwise, we could propagate a wrong value.
7124       if (IsSplat && (SplatBitSize % EltBitWidth) == 0) {
7125         // Undef bits can contribute to a possible optimisation if set, so
7126         // set them.
7127         SplatValue |= SplatUndef;
7128 
7129         // The splat value may be something like "0x00FFFFFF", which means 0 for
7130         // the first vector value and FF for the rest, repeating. We need a mask
7131         // that will apply equally to all members of the vector, so AND all the
7132         // lanes of the constant together.
7133         Constant = APInt::getAllOnes(EltBitWidth);
7134         for (unsigned i = 0, n = (SplatBitSize / EltBitWidth); i < n; ++i)
7135           Constant &= SplatValue.extractBits(EltBitWidth, i * EltBitWidth);
7136       }
7137     }
7138 
7139     // If we want to change an EXTLOAD to a ZEXTLOAD, ensure a ZEXTLOAD is
7140     // actually legal and isn't going to get expanded, else this is a false
7141     // optimisation.
7142     bool CanZextLoadProfitably = TLI.isLoadExtLegal(ISD::ZEXTLOAD,
7143                                                     Load->getValueType(0),
7144                                                     Load->getMemoryVT());
7145 
7146     // Resize the constant to the same size as the original memory access before
7147     // extension. If it is still the AllOnesValue then this AND is completely
7148     // unneeded.
7149     Constant = Constant.zextOrTrunc(Load->getMemoryVT().getScalarSizeInBits());
7150 
7151     bool B;
7152     switch (Load->getExtensionType()) {
7153     default: B = false; break;
7154     case ISD::EXTLOAD: B = CanZextLoadProfitably; break;
7155     case ISD::ZEXTLOAD:
7156     case ISD::NON_EXTLOAD: B = true; break;
7157     }
7158 
7159     if (B && Constant.isAllOnes()) {
7160       // If the load type was an EXTLOAD, convert to ZEXTLOAD in order to
7161       // preserve semantics once we get rid of the AND.
7162       SDValue NewLoad(Load, 0);
7163 
7164       // Fold the AND away. NewLoad may get replaced immediately.
7165       CombineTo(N, (N0.getNode() == Load) ? NewLoad : N0);
7166 
7167       if (Load->getExtensionType() == ISD::EXTLOAD) {
7168         NewLoad = DAG.getLoad(Load->getAddressingMode(), ISD::ZEXTLOAD,
7169                               Load->getValueType(0), SDLoc(Load),
7170                               Load->getChain(), Load->getBasePtr(),
7171                               Load->getOffset(), Load->getMemoryVT(),
7172                               Load->getMemOperand());
7173         // Replace uses of the EXTLOAD with the new ZEXTLOAD.
7174         if (Load->getNumValues() == 3) {
7175           // PRE/POST_INC loads have 3 values.
7176           SDValue To[] = { NewLoad.getValue(0), NewLoad.getValue(1),
7177                            NewLoad.getValue(2) };
7178           CombineTo(Load, To, 3, true);
7179         } else {
7180           CombineTo(Load, NewLoad.getValue(0), NewLoad.getValue(1));
7181         }
7182       }
7183 
7184       return SDValue(N, 0); // Return N so it doesn't get rechecked!
7185     }
7186   }
7187 
7188   // Try to convert a constant mask AND into a shuffle clear mask.
7189   if (VT.isVector())
7190     if (SDValue Shuffle = XformToShuffleWithZero(N))
7191       return Shuffle;
7192 
7193   if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
7194     return Combined;
7195 
7196   if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR && N0.hasOneUse() && N1C &&
7197       ISD::isExtOpcode(N0.getOperand(0).getOpcode())) {
7198     SDValue Ext = N0.getOperand(0);
7199     EVT ExtVT = Ext->getValueType(0);
7200     SDValue Extendee = Ext->getOperand(0);
7201 
7202     unsigned ScalarWidth = Extendee.getValueType().getScalarSizeInBits();
7203     if (N1C->getAPIntValue().isMask(ScalarWidth) &&
7204         (!LegalOperations || TLI.isOperationLegal(ISD::ZERO_EXTEND, ExtVT))) {
7205       //    (and (extract_subvector (zext|anyext|sext v) _) iN_mask)
7206       // => (extract_subvector (iN_zeroext v))
7207       SDValue ZeroExtExtendee =
7208           DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), ExtVT, Extendee);
7209 
7210       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), VT, ZeroExtExtendee,
7211                          N0.getOperand(1));
7212     }
7213   }
7214 
7215   // fold (and (masked_gather x)) -> (zext_masked_gather x)
7216   if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(N0)) {
7217     EVT MemVT = GN0->getMemoryVT();
7218     EVT ScalarVT = MemVT.getScalarType();
7219 
7220     if (SDValue(GN0, 0).hasOneUse() &&
7221         isConstantSplatVectorMaskForType(N1.getNode(), ScalarVT) &&
7222         TLI.isVectorLoadExtDesirable(SDValue(GN0, 0))) {
7223       SDValue Ops[] = {GN0->getChain(),   GN0->getPassThru(), GN0->getMask(),
7224                        GN0->getBasePtr(), GN0->getIndex(),    GN0->getScale()};
7225 
7226       SDValue ZExtLoad = DAG.getMaskedGather(
7227           DAG.getVTList(VT, MVT::Other), MemVT, SDLoc(N), Ops,
7228           GN0->getMemOperand(), GN0->getIndexType(), ISD::ZEXTLOAD);
7229 
7230       CombineTo(N, ZExtLoad);
7231       AddToWorklist(ZExtLoad.getNode());
7232       // Avoid recheck of N.
7233       return SDValue(N, 0);
7234     }
7235   }
7236 
7237   // fold (and (load x), 255) -> (zextload x, i8)
7238   // fold (and (extload x, i16), 255) -> (zextload x, i8)
7239   if (N1C && N0.getOpcode() == ISD::LOAD && !VT.isVector())
7240     if (SDValue Res = reduceLoadWidth(N))
7241       return Res;
7242 
7243   if (LegalTypes) {
7244     // Attempt to propagate the AND back up to the leaves which, if they're
7245     // loads, can be combined to narrow loads and the AND node can be removed.
7246     // Perform after legalization so that extend nodes will already be
7247     // combined into the loads.
7248     if (BackwardsPropagateMask(N))
7249       return SDValue(N, 0);
7250   }
7251 
7252   if (SDValue Combined = visitANDLike(N0, N1, N))
7253     return Combined;
7254 
7255   // Simplify: (and (op x...), (op y...))  -> (op (and x, y))
7256   if (N0.getOpcode() == N1.getOpcode())
7257     if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
7258       return V;
7259 
7260   if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG))
7261     return R;
7262   if (SDValue R = foldLogicOfShifts(N, N1, N0, DAG))
7263     return R;
7264 
7265   // Masking the negated extension of a boolean is just the zero-extended
7266   // boolean:
7267   // and (sub 0, zext(bool X)), 1 --> zext(bool X)
7268   // and (sub 0, sext(bool X)), 1 --> zext(bool X)
7269   //
7270   // Note: the SimplifyDemandedBits fold below can make an information-losing
7271   // transform, and then we have no way to find this better fold.
7272   if (N1C && N1C->isOne() && N0.getOpcode() == ISD::SUB) {
7273     if (isNullOrNullSplat(N0.getOperand(0))) {
7274       SDValue SubRHS = N0.getOperand(1);
7275       if (SubRHS.getOpcode() == ISD::ZERO_EXTEND &&
7276           SubRHS.getOperand(0).getScalarValueSizeInBits() == 1)
7277         return SubRHS;
7278       if (SubRHS.getOpcode() == ISD::SIGN_EXTEND &&
7279           SubRHS.getOperand(0).getScalarValueSizeInBits() == 1)
7280         return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, SubRHS.getOperand(0));
7281     }
7282   }
7283 
7284   // fold (and (sign_extend_inreg x, i16 to i32), 1) -> (and x, 1)
7285   // fold (and (sra)) -> (and (srl)) when possible.
7286   if (SimplifyDemandedBits(SDValue(N, 0)))
7287     return SDValue(N, 0);
7288 
7289   // fold (zext_inreg (extload x)) -> (zextload x)
7290   // fold (zext_inreg (sextload x)) -> (zextload x) iff load has one use
7291   if (ISD::isUNINDEXEDLoad(N0.getNode()) &&
7292       (ISD::isEXTLoad(N0.getNode()) ||
7293        (ISD::isSEXTLoad(N0.getNode()) && N0.hasOneUse()))) {
7294     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
7295     EVT MemVT = LN0->getMemoryVT();
7296     // If we zero all the possible extended bits, then we can turn this into
7297     // a zextload if we are running before legalize or the operation is legal.
7298     unsigned ExtBitSize = N1.getScalarValueSizeInBits();
7299     unsigned MemBitSize = MemVT.getScalarSizeInBits();
7300     APInt ExtBits = APInt::getHighBitsSet(ExtBitSize, ExtBitSize - MemBitSize);
7301     if (DAG.MaskedValueIsZero(N1, ExtBits) &&
7302         ((!LegalOperations && LN0->isSimple()) ||
7303          TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT))) {
7304       SDValue ExtLoad =
7305           DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(N0), VT, LN0->getChain(),
7306                          LN0->getBasePtr(), MemVT, LN0->getMemOperand());
7307       AddToWorklist(N);
7308       CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
7309       return SDValue(N, 0); // Return N so it doesn't get rechecked!
7310     }
7311   }
7312 
7313   // fold (and (or (srl N, 8), (shl N, 8)), 0xffff) -> (srl (bswap N), const)
7314   if (N1C && N1C->getAPIntValue() == 0xffff && N0.getOpcode() == ISD::OR) {
7315     if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
7316                                            N0.getOperand(1), false))
7317       return BSwap;
7318   }
7319 
7320   if (SDValue Shifts = unfoldExtremeBitClearingToShifts(N))
7321     return Shifts;
7322 
7323   if (SDValue V = combineShiftAnd1ToBitTest(N, DAG))
7324     return V;
7325 
7326   // Recognize the following pattern:
7327   //
7328   // AndVT = (and (sign_extend NarrowVT to AndVT) #bitmask)
7329   //
7330   // where bitmask is a mask that clears the upper bits of AndVT. The
7331   // number of bits in bitmask must be a power of two.
7332   auto IsAndZeroExtMask = [](SDValue LHS, SDValue RHS) {
7333     if (LHS->getOpcode() != ISD::SIGN_EXTEND)
7334       return false;
7335 
7336     auto *C = dyn_cast<ConstantSDNode>(RHS);
7337     if (!C)
7338       return false;
7339 
7340     if (!C->getAPIntValue().isMask(
7341             LHS.getOperand(0).getValueType().getFixedSizeInBits()))
7342       return false;
7343 
7344     return true;
7345   };
7346 
7347   // Replace (and (sign_extend ...) #bitmask) with (zero_extend ...).
7348   if (IsAndZeroExtMask(N0, N1))
7349     return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, N0.getOperand(0));
7350 
7351   if (hasOperation(ISD::USUBSAT, VT))
7352     if (SDValue V = foldAndToUsubsat(N, DAG))
7353       return V;
7354 
7355   // Postpone until legalization has completed to avoid interference with
7356   // bswap folding.
7357   if (LegalOperations || VT.isVector())
7358     if (SDValue R = foldLogicTreeOfShifts(N, N0, N1, DAG))
7359       return R;
7360 
7361   return SDValue();
7362 }
7363 
7364 /// Match (a >> 8) | (a << 8) as (bswap a) >> 16.
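/// Byte-level sketch (illustrative, i32 a == [b3 b2 b1 b0]): the masked
/// (a >> 8) | (a << 8) pattern produces [0 0 b0 b1] in the low halfword, and
/// (bswap a) >> 16 == [b0 b1 b2 b3] >> 16 yields the same value.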
7365 SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
7366                                         bool DemandHighBits) {
7367   if (!LegalOperations)
7368     return SDValue();
7369 
7370   EVT VT = N->getValueType(0);
7371   if (VT != MVT::i64 && VT != MVT::i32 && VT != MVT::i16)
7372     return SDValue();
7373   if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
7374     return SDValue();
7375 
7376   // Recognize (and (shl a, 8), 0xff00), (and (srl a, 8), 0xff)
7377   bool LookPassAnd0 = false;
7378   bool LookPassAnd1 = false;
7379   if (N0.getOpcode() == ISD::AND && N0.getOperand(0).getOpcode() == ISD::SRL)
7380     std::swap(N0, N1);
7381   if (N1.getOpcode() == ISD::AND && N1.getOperand(0).getOpcode() == ISD::SHL)
7382     std::swap(N0, N1);
7383   if (N0.getOpcode() == ISD::AND) {
7384     if (!N0->hasOneUse())
7385       return SDValue();
7386     ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
7387     // Also handle 0xffff since the LHS is guaranteed to have zeros there.
7388     // This is needed for X86.
7389     if (!N01C || (N01C->getZExtValue() != 0xFF00 &&
7390                   N01C->getZExtValue() != 0xFFFF))
7391       return SDValue();
7392     N0 = N0.getOperand(0);
7393     LookPassAnd0 = true;
7394   }
7395 
7396   if (N1.getOpcode() == ISD::AND) {
7397     if (!N1->hasOneUse())
7398       return SDValue();
7399     ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(N1.getOperand(1));
7400     if (!N11C || N11C->getZExtValue() != 0xFF)
7401       return SDValue();
7402     N1 = N1.getOperand(0);
7403     LookPassAnd1 = true;
7404   }
7405 
7406   if (N0.getOpcode() == ISD::SRL && N1.getOpcode() == ISD::SHL)
7407     std::swap(N0, N1);
7408   if (N0.getOpcode() != ISD::SHL || N1.getOpcode() != ISD::SRL)
7409     return SDValue();
7410   if (!N0->hasOneUse() || !N1->hasOneUse())
7411     return SDValue();
7412 
7413   ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
7414   ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(N1.getOperand(1));
7415   if (!N01C || !N11C)
7416     return SDValue();
7417   if (N01C->getZExtValue() != 8 || N11C->getZExtValue() != 8)
7418     return SDValue();
7419 
7420   // Look for (shl (and a, 0xff), 8), (srl (and a, 0xff00), 8)
7421   SDValue N00 = N0->getOperand(0);
7422   if (!LookPassAnd0 && N00.getOpcode() == ISD::AND) {
7423     if (!N00->hasOneUse())
7424       return SDValue();
7425     ConstantSDNode *N001C = dyn_cast<ConstantSDNode>(N00.getOperand(1));
7426     if (!N001C || N001C->getZExtValue() != 0xFF)
7427       return SDValue();
7428     N00 = N00.getOperand(0);
7429     LookPassAnd0 = true;
7430   }
7431 
7432   SDValue N10 = N1->getOperand(0);
7433   if (!LookPassAnd1 && N10.getOpcode() == ISD::AND) {
7434     if (!N10->hasOneUse())
7435       return SDValue();
7436     ConstantSDNode *N101C = dyn_cast<ConstantSDNode>(N10.getOperand(1));
7437     // Also allow 0xFFFF since the bits will be shifted out. This is needed
7438     // for X86.
7439     if (!N101C || (N101C->getZExtValue() != 0xFF00 &&
7440                    N101C->getZExtValue() != 0xFFFF))
7441       return SDValue();
7442     N10 = N10.getOperand(0);
7443     LookPassAnd1 = true;
7444   }
7445 
7446   if (N00 != N10)
7447     return SDValue();
7448 
7449   // Make sure everything beyond the low halfword gets set to zero since the SRL
7450   // 16 will clear the top bits.
7451   unsigned OpSizeInBits = VT.getSizeInBits();
7452   if (OpSizeInBits > 16) {
7453     // If the left-shift isn't masked out then the only way this is a bswap is
7454     // if all bits beyond the low 8 are 0. In that case the entire pattern
7455     // reduces to a left shift anyway: leave it for other parts of the combiner.
7456     if (DemandHighBits && !LookPassAnd0)
7457       return SDValue();
7458 
7459     // However, if the right shift isn't masked out then it might be because
7460     // it's not needed. See if we can spot that too. If the high bits aren't
7461     // demanded, we only need bits 23:16 to be zero. Otherwise, we need all
7462     // upper bits to be zero.
7463     if (!LookPassAnd1) {
7464       unsigned HighBit = DemandHighBits ? OpSizeInBits : 24;
7465       if (!DAG.MaskedValueIsZero(N10,
7466                                  APInt::getBitsSet(OpSizeInBits, 16, HighBit)))
7467         return SDValue();
7468     }
7469   }
7470 
7471   SDValue Res = DAG.getNode(ISD::BSWAP, SDLoc(N), VT, N00);
7472   if (OpSizeInBits > 16) {
7473     SDLoc DL(N);
7474     Res = DAG.getNode(ISD::SRL, DL, VT, Res,
7475                       DAG.getConstant(OpSizeInBits - 16, DL,
7476                                       getShiftAmountTy(VT)));
7477   }
7478   return Res;
7479 }
7480 
7481 /// Return true if the specified node is an element that makes up a 32-bit
7482 /// packed halfword byteswap.
7483 /// ((x & 0x000000ff) << 8) |
7484 /// ((x & 0x0000ff00) >> 8) |
7485 /// ((x & 0x00ff0000) << 8) |
7486 /// ((x & 0xff000000) >> 8)
7487 static bool isBSwapHWordElement(SDValue N, MutableArrayRef<SDNode *> Parts) {
7488   if (!N->hasOneUse())
7489     return false;
7490 
7491   unsigned Opc = N.getOpcode();
7492   if (Opc != ISD::AND && Opc != ISD::SHL && Opc != ISD::SRL)
7493     return false;
7494 
7495   SDValue N0 = N.getOperand(0);
7496   unsigned Opc0 = N0.getOpcode();
7497   if (Opc0 != ISD::AND && Opc0 != ISD::SHL && Opc0 != ISD::SRL)
7498     return false;
7499 
7500   ConstantSDNode *N1C = nullptr;
7501   // SHL or SRL: look upstream for AND mask operand
7502   if (Opc == ISD::AND)
7503     N1C = dyn_cast<ConstantSDNode>(N.getOperand(1));
7504   else if (Opc0 == ISD::AND)
7505     N1C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
7506   if (!N1C)
7507     return false;
7508 
7509   unsigned MaskByteOffset;
7510   switch (N1C->getZExtValue()) {
7511   default:
7512     return false;
7513   case 0xFF:       MaskByteOffset = 0; break;
7514   case 0xFF00:     MaskByteOffset = 1; break;
7515   case 0xFFFF:
7516     // In case demanded bits didn't clear the bits that will be shifted out.
7517     // This is needed for X86.
7518     if (Opc == ISD::SRL || (Opc == ISD::AND && Opc0 == ISD::SHL)) {
7519       MaskByteOffset = 1;
7520       break;
7521     }
7522     return false;
7523   case 0xFF0000:   MaskByteOffset = 2; break;
7524   case 0xFF000000: MaskByteOffset = 3; break;
7525   }
7526 
7527   // Look for (x & 0xff) << 8 as well as ((x << 8) & 0xff00).
7528   if (Opc == ISD::AND) {
7529     if (MaskByteOffset == 0 || MaskByteOffset == 2) {
7530       // (x >> 8) & 0xff
7531       // (x >> 8) & 0xff0000
7532       if (Opc0 != ISD::SRL)
7533         return false;
7534       ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
7535       if (!C || C->getZExtValue() != 8)
7536         return false;
7537     } else {
7538       // (x << 8) & 0xff00
7539       // (x << 8) & 0xff000000
7540       if (Opc0 != ISD::SHL)
7541         return false;
7542       ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
7543       if (!C || C->getZExtValue() != 8)
7544         return false;
7545     }
7546   } else if (Opc == ISD::SHL) {
7547     // (x & 0xff) << 8
7548     // (x & 0xff0000) << 8
7549     if (MaskByteOffset != 0 && MaskByteOffset != 2)
7550       return false;
7551     ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1));
7552     if (!C || C->getZExtValue() != 8)
7553       return false;
7554   } else { // Opc == ISD::SRL
7555     // (x & 0xff00) >> 8
7556     // (x & 0xff000000) >> 8
7557     if (MaskByteOffset != 1 && MaskByteOffset != 3)
7558       return false;
7559     ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1));
7560     if (!C || C->getZExtValue() != 8)
7561       return false;
7562   }
7563 
7564   if (Parts[MaskByteOffset])
7565     return false;
7566 
7567   Parts[MaskByteOffset] = N0.getOperand(0).getNode();
7568   return true;
7569 }
7570 
7571 // Match 2 elements of a packed halfword bswap.
7572 static bool isBSwapHWordPair(SDValue N, MutableArrayRef<SDNode *> Parts) {
7573   if (N.getOpcode() == ISD::OR)
7574     return isBSwapHWordElement(N.getOperand(0), Parts) &&
7575            isBSwapHWordElement(N.getOperand(1), Parts);
7576 
7577   if (N.getOpcode() == ISD::SRL && N.getOperand(0).getOpcode() == ISD::BSWAP) {
7578     ConstantSDNode *C = isConstOrConstSplat(N.getOperand(1));
7579     if (!C || C->getAPIntValue() != 16)
7580       return false;
7581     Parts[0] = Parts[1] = N.getOperand(0).getOperand(0).getNode();
7582     return true;
7583   }
7584 
7585   return false;
7586 }
7587 
7588 // Match this pattern:
7589 //   (or (and (shl (A, 8)), 0xff00ff00), (and (srl (A, 8)), 0x00ff00ff))
7590 // And rewrite this to:
7591 //   (rotr (bswap A), 16)
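// Byte-level check (illustrative, A == [b3 b2 b1 b0]):
//   ((A << 8) & 0xff00ff00) | ((A >> 8) & 0x00ff00ff) == [b2 b3 b0 b1]
//   rotr((bswap A), 16) == rotr([b0 b1 b2 b3], 16) == [b2 b3 b0 b1]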
7592 static SDValue matchBSwapHWordOrAndAnd(const TargetLowering &TLI,
7593                                        SelectionDAG &DAG, SDNode *N, SDValue N0,
7594                                        SDValue N1, EVT VT, EVT ShiftAmountTy) {
7595   assert(N->getOpcode() == ISD::OR && VT == MVT::i32 &&
7596          "MatchBSwapHWordOrAndAnd: expecting i32");
7597   if (!TLI.isOperationLegalOrCustom(ISD::ROTR, VT))
7598     return SDValue();
7599   if (N0.getOpcode() != ISD::AND || N1.getOpcode() != ISD::AND)
7600     return SDValue();
7601   // TODO: this is too restrictive; lifting this restriction requires more tests
7602   if (!N0->hasOneUse() || !N1->hasOneUse())
7603     return SDValue();
7604   ConstantSDNode *Mask0 = isConstOrConstSplat(N0.getOperand(1));
7605   ConstantSDNode *Mask1 = isConstOrConstSplat(N1.getOperand(1));
7606   if (!Mask0 || !Mask1)
7607     return SDValue();
7608   if (Mask0->getAPIntValue() != 0xff00ff00 ||
7609       Mask1->getAPIntValue() != 0x00ff00ff)
7610     return SDValue();
7611   SDValue Shift0 = N0.getOperand(0);
7612   SDValue Shift1 = N1.getOperand(0);
7613   if (Shift0.getOpcode() != ISD::SHL || Shift1.getOpcode() != ISD::SRL)
7614     return SDValue();
7615   ConstantSDNode *ShiftAmt0 = isConstOrConstSplat(Shift0.getOperand(1));
7616   ConstantSDNode *ShiftAmt1 = isConstOrConstSplat(Shift1.getOperand(1));
7617   if (!ShiftAmt0 || !ShiftAmt1)
7618     return SDValue();
7619   if (ShiftAmt0->getAPIntValue() != 8 || ShiftAmt1->getAPIntValue() != 8)
7620     return SDValue();
7621   if (Shift0.getOperand(0) != Shift1.getOperand(0))
7622     return SDValue();
7623 
7624   SDLoc DL(N);
7625   SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT, Shift0.getOperand(0));
7626   SDValue ShAmt = DAG.getConstant(16, DL, ShiftAmountTy);
7627   return DAG.getNode(ISD::ROTR, DL, VT, BSwap, ShAmt);
7628 }
7629 
7630 /// Match a 32-bit packed halfword bswap. That is
7631 /// ((x & 0x000000ff) << 8) |
7632 /// ((x & 0x0000ff00) >> 8) |
7633 /// ((x & 0x00ff0000) << 8) |
7634 /// ((x & 0xff000000) >> 8)
7635 /// => (rotl (bswap x), 16)
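/// Byte-level sketch (illustrative, x == [b3 b2 b1 b0]): the four masked
/// terms contribute [0 0 b0 0], [0 0 0 b1], [b2 0 0 0] and [0 b3 0 0], which
/// OR together to [b2 b3 b0 b1] == rotl([b0 b1 b2 b3], 16).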
7636 SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
7637   if (!LegalOperations)
7638     return SDValue();
7639 
7640   EVT VT = N->getValueType(0);
7641   if (VT != MVT::i32)
7642     return SDValue();
7643   if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
7644     return SDValue();
7645 
7646   if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N0, N1, VT,
7647                                               getShiftAmountTy(VT)))
7648     return BSwap;
7649 
7650   // Try again with commuted operands.
7651   if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N1, N0, VT,
7652                                               getShiftAmountTy(VT)))
7653     return BSwap;
7654 
7655 
7656   // Look for either
7657   // (or (bswaphpair), (bswaphpair))
7658   // (or (or (bswaphpair), (and)), (and))
7659   // (or (or (and), (bswaphpair)), (and))
7660   SDNode *Parts[4] = {};
7661 
7662   if (isBSwapHWordPair(N0, Parts)) {
7663     // (or (or (and), (and)), (or (and), (and)))
7664     if (!isBSwapHWordPair(N1, Parts))
7665       return SDValue();
7666   } else if (N0.getOpcode() == ISD::OR) {
7667     // (or (or (or (and), (and)), (and)), (and))
7668     if (!isBSwapHWordElement(N1, Parts))
7669       return SDValue();
7670     SDValue N00 = N0.getOperand(0);
7671     SDValue N01 = N0.getOperand(1);
7672     if (!(isBSwapHWordElement(N01, Parts) && isBSwapHWordPair(N00, Parts)) &&
7673         !(isBSwapHWordElement(N00, Parts) && isBSwapHWordPair(N01, Parts)))
7674       return SDValue();
7675   } else {
7676     return SDValue();
7677   }
7678 
7679   // Make sure the parts are all coming from the same node.
7680   if (Parts[0] != Parts[1] || Parts[0] != Parts[2] || Parts[0] != Parts[3])
7681     return SDValue();
7682 
7683   SDLoc DL(N);
7684   SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT,
7685                               SDValue(Parts[0], 0));
7686 
7687   // The result of the bswap should be rotated by 16. If rotates are not
7688   // legal, emit (x << 16) | (x >> 16) instead.
7689   SDValue ShAmt = DAG.getConstant(16, DL, getShiftAmountTy(VT));
7690   if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT))
7691     return DAG.getNode(ISD::ROTL, DL, VT, BSwap, ShAmt);
7692   if (TLI.isOperationLegalOrCustom(ISD::ROTR, VT))
7693     return DAG.getNode(ISD::ROTR, DL, VT, BSwap, ShAmt);
7694   return DAG.getNode(ISD::OR, DL, VT,
7695                      DAG.getNode(ISD::SHL, DL, VT, BSwap, ShAmt),
7696                      DAG.getNode(ISD::SRL, DL, VT, BSwap, ShAmt));
7697 }
7698 
7699 /// This contains all DAGCombine rules which reduce two values combined by
7700 /// an Or operation to a single value \see visitANDLike().
7701 SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, SDNode *N) {
7702   EVT VT = N1.getValueType();
7703   SDLoc DL(N);
7704 
7705   // fold (or x, undef) -> -1
7706   if (!LegalOperations && (N0.isUndef() || N1.isUndef()))
7707     return DAG.getAllOnesConstant(DL, VT);
7708 
7709   if (SDValue V = foldLogicOfSetCCs(false, N0, N1, DL))
7710     return V;
7711 
7712   // (or (and X, C1), (and Y, C2))  -> (and (or X, Y), C3) if possible.
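  // Illustrative instance: with C1 == 0xFF00 and C2 == 0x00FF, if X's low
  // byte and Y's high byte are known zero, (or (and X, 0xFF00), (and Y,
  // 0x00FF)) becomes (and (or X, Y), 0xFFFF).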
7713   if (N0.getOpcode() == ISD::AND && N1.getOpcode() == ISD::AND &&
7714       // Don't increase # computations.
7715       (N0->hasOneUse() || N1->hasOneUse())) {
7718     if (const ConstantSDNode *N0O1C =
7719         getAsNonOpaqueConstant(N0.getOperand(1))) {
7720       if (const ConstantSDNode *N1O1C =
7721           getAsNonOpaqueConstant(N1.getOperand(1))) {
7722         // We can only do this xform if we know that bits from X that are set in
7723         // C2 but not in C1 are already zero.  Likewise for Y.
7724         const APInt &LHSMask = N0O1C->getAPIntValue();
7725         const APInt &RHSMask = N1O1C->getAPIntValue();
7726 
7727         if (DAG.MaskedValueIsZero(N0.getOperand(0), RHSMask&~LHSMask) &&
7728             DAG.MaskedValueIsZero(N1.getOperand(0), LHSMask&~RHSMask)) {
7729           SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
7730                                   N0.getOperand(0), N1.getOperand(0));
7731           return DAG.getNode(ISD::AND, DL, VT, X,
7732                              DAG.getConstant(LHSMask | RHSMask, DL, VT));
7733         }
7734       }
7735     }
7736   }
7737 
7738   // (or (and X, M), (and X, N)) -> (and X, (or M, N))
7739   if (N0.getOpcode() == ISD::AND &&
7740       N1.getOpcode() == ISD::AND &&
7741       N0.getOperand(0) == N1.getOperand(0) &&
7742       // Don't increase # computations.
7743       (N0->hasOneUse() || N1->hasOneUse())) {
7744     SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
7745                             N0.getOperand(1), N1.getOperand(1));
7746     return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), X);
7747   }
7748 
7749   return SDValue();
7750 }
7751 
7752 /// OR combines for which the commuted variant will be tried as well.
7753 static SDValue visitORCommutative(SelectionDAG &DAG, SDValue N0, SDValue N1,
7754                                   SDNode *N) {
7755   EVT VT = N0.getValueType();
7756 
7757   auto peekThroughResize = [](SDValue V) {
7758     if (V->getOpcode() == ISD::ZERO_EXTEND || V->getOpcode() == ISD::TRUNCATE)
7759       return V->getOperand(0);
7760     return V;
7761   };
7762 
7763   SDValue N0Resized = peekThroughResize(N0);
7764   if (N0Resized.getOpcode() == ISD::AND) {
7765     SDValue N1Resized = peekThroughResize(N1);
7766     SDValue N00 = N0Resized.getOperand(0);
7767     SDValue N01 = N0Resized.getOperand(1);
7768 
7769     // fold or (and x, y), x --> x
7770     if (N00 == N1Resized || N01 == N1Resized)
7771       return N1;
7772 
7773     // fold (or (and X, (xor Y, -1)), Y) -> (or X, Y)
7774     // TODO: Set AllowUndefs = true.
7775     if (SDValue NotOperand = getBitwiseNotOperand(N01, N00,
7776                                                   /* AllowUndefs */ false)) {
7777       if (peekThroughResize(NotOperand) == N1Resized)
7778         return DAG.getNode(ISD::OR, SDLoc(N), VT,
7779                            DAG.getZExtOrTrunc(N00, SDLoc(N), VT), N1);
7780     }
7781 
7782     // fold (or (and (xor Y, -1), X), Y) -> (or X, Y)
7783     if (SDValue NotOperand = getBitwiseNotOperand(N00, N01,
7784                                                   /* AllowUndefs */ false)) {
7785       if (peekThroughResize(NotOperand) == N1Resized)
7786         return DAG.getNode(ISD::OR, SDLoc(N), VT,
7787                            DAG.getZExtOrTrunc(N01, SDLoc(N), VT), N1);
7788     }
7789   }
7790 
7791   if (N0.getOpcode() == ISD::XOR) {
7792     // fold or (xor x, y), x --> or x, y
7793     //      or (xor x, y), (x and/or y) --> or x, y
7794     SDValue N00 = N0.getOperand(0);
7795     SDValue N01 = N0.getOperand(1);
7796     if (N00 == N1)
7797       return DAG.getNode(ISD::OR, SDLoc(N), VT, N01, N1);
7798     if (N01 == N1)
7799       return DAG.getNode(ISD::OR, SDLoc(N), VT, N00, N1);
7800 
7801     if (N1.getOpcode() == ISD::AND || N1.getOpcode() == ISD::OR) {
7802       SDValue N10 = N1.getOperand(0);
7803       SDValue N11 = N1.getOperand(1);
7804       if ((N00 == N10 && N01 == N11) || (N00 == N11 && N01 == N10))
7805         return DAG.getNode(ISD::OR, SDLoc(N), VT, N00, N01);
7806     }
7807   }
7808 
7809   if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG))
7810     return R;
7811 
7812   auto peekThroughZext = [](SDValue V) {
7813     if (V->getOpcode() == ISD::ZERO_EXTEND)
7814       return V->getOperand(0);
7815     return V;
7816   };
7817 
7818   // (fshl X, ?, Y) | (shl X, Y) --> fshl X, ?, Y
7819   if (N0.getOpcode() == ISD::FSHL && N1.getOpcode() == ISD::SHL &&
7820       N0.getOperand(0) == N1.getOperand(0) &&
7821       peekThroughZext(N0.getOperand(2)) == peekThroughZext(N1.getOperand(1)))
7822     return N0;
7823 
7824   // (fshr ?, X, Y) | (srl X, Y) --> fshr ?, X, Y
7825   if (N0.getOpcode() == ISD::FSHR && N1.getOpcode() == ISD::SRL &&
7826       N0.getOperand(1) == N1.getOperand(0) &&
7827       peekThroughZext(N0.getOperand(2)) == peekThroughZext(N1.getOperand(1)))
7828     return N0;
7829 
7830   return SDValue();
7831 }
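
// A worked instance of the xor folds above (illustrative): (x ^ y) | x sets
// every bit that is set in x or in exactly one of x/y, so it equals x | y.
// The same holds for (x ^ y) | (x & y) and (x ^ y) | (x | y), which is why
// both AND and OR are accepted for N1 in the second pattern.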

SDValue DAGCombiner::visitOR(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N1.getValueType();

  // x | x --> x
  if (N0 == N1)
    return N0;

  // fold (or c1, c2) -> c1|c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N), VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(ISD::OR, SDLoc(N), VT, N1, N0);

  // fold vector ops
  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N)))
      return FoldedVOp;

    // fold (or x, 0) -> x, vector edition
    if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
      return N0;

    // fold (or x, -1) -> -1, vector edition
    if (ISD::isConstantSplatVectorAllOnes(N1.getNode()))
      // do not return N1, because an undef node may exist in N1
      return DAG.getAllOnesConstant(SDLoc(N), N1.getValueType());

    // fold (or (shuf A, V_0, MA), (shuf B, V_0, MB)) -> (shuf A, B, Mask)
    // Do this only if the resulting type / shuffle is legal.
    auto *SV0 = dyn_cast<ShuffleVectorSDNode>(N0);
    auto *SV1 = dyn_cast<ShuffleVectorSDNode>(N1);
    if (SV0 && SV1 && TLI.isTypeLegal(VT)) {
      bool ZeroN00 = ISD::isBuildVectorAllZeros(N0.getOperand(0).getNode());
      bool ZeroN01 = ISD::isBuildVectorAllZeros(N0.getOperand(1).getNode());
      bool ZeroN10 = ISD::isBuildVectorAllZeros(N1.getOperand(0).getNode());
      bool ZeroN11 = ISD::isBuildVectorAllZeros(N1.getOperand(1).getNode());
      // Ensure both shuffles have a zero input.
      if ((ZeroN00 != ZeroN01) && (ZeroN10 != ZeroN11)) {
        assert((!ZeroN00 || !ZeroN01) && "Both inputs zero!");
        assert((!ZeroN10 || !ZeroN11) && "Both inputs zero!");
        bool CanFold = true;
        int NumElts = VT.getVectorNumElements();
        SmallVector<int, 4> Mask(NumElts, -1);

        for (int i = 0; i != NumElts; ++i) {
          int M0 = SV0->getMaskElt(i);
          int M1 = SV1->getMaskElt(i);

          // Determine if either index is pointing to a zero vector.
          bool M0Zero = M0 < 0 || (ZeroN00 == (M0 < NumElts));
          bool M1Zero = M1 < 0 || (ZeroN10 == (M1 < NumElts));

          // If one element is zero and the other side is undef, keep undef.
          // This also handles the case that both are undef.
          if ((M0Zero && M1 < 0) || (M1Zero && M0 < 0))
            continue;

          // Make sure only one of the elements is zero.
          if (M0Zero == M1Zero) {
            CanFold = false;
            break;
          }

          assert((M0 >= 0 || M1 >= 0) && "Undef index!");

          // We have a zero and non-zero element. If the non-zero came from
          // SV0 make the index a LHS index. If it came from SV1, make it
          // a RHS index. We need to mod by NumElts because we don't care
          // which operand it came from in the original shuffles.
          Mask[i] = M1Zero ? M0 % NumElts : (M1 % NumElts) + NumElts;
        }

        if (CanFold) {
          SDValue NewLHS = ZeroN00 ? N0.getOperand(1) : N0.getOperand(0);
          SDValue NewRHS = ZeroN10 ? N1.getOperand(1) : N1.getOperand(0);

          SDValue LegalShuffle =
              TLI.buildLegalVectorShuffle(VT, SDLoc(N), NewLHS, NewRHS,
                                          Mask, DAG);
          if (LegalShuffle)
            return LegalShuffle;
        }
      }
    }
  }

  // fold (or x, 0) -> x
  if (isNullConstant(N1))
    return N0;

  // fold (or x, -1) -> -1
  if (isAllOnesConstant(N1))
    return N1;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // fold (or x, c) -> c iff (x & ~c) == 0
  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
  if (N1C && DAG.MaskedValueIsZero(N0, ~N1C->getAPIntValue()))
    return N1;

  if (SDValue R = foldAndOrOfSETCC(N, DAG))
    return R;

  if (SDValue Combined = visitORLike(N0, N1, N))
    return Combined;

  if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
    return Combined;

  // Recognize halfword bswaps as (bswap + rotl 16) or (bswap + shl 16)
  if (SDValue BSwap = MatchBSwapHWord(N, N0, N1))
    return BSwap;
  if (SDValue BSwap = MatchBSwapHWordLow(N, N0, N1))
    return BSwap;

  // reassociate or
  if (SDValue ROR = reassociateOps(ISD::OR, SDLoc(N), N0, N1, N->getFlags()))
    return ROR;

  // Fold or(vecreduce(x), vecreduce(y)) -> vecreduce(or(x, y))
  if (SDValue SD = reassociateReduction(ISD::VECREDUCE_OR, ISD::OR, SDLoc(N),
                                        VT, N0, N1))
    return SD;

  // Canonicalize (or (and X, c1), c2) -> (and (or X, c2), c1|c2)
  // iff (c1 & c2) != 0 or c1/c2 are undef.
  auto MatchIntersect = [](ConstantSDNode *C1, ConstantSDNode *C2) {
    return !C1 || !C2 || C1->getAPIntValue().intersects(C2->getAPIntValue());
  };
  if (N0.getOpcode() == ISD::AND && N0->hasOneUse() &&
      ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchIntersect, true)) {
    if (SDValue COR = DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N1), VT,
                                                 {N1, N0.getOperand(1)})) {
      SDValue IOR = DAG.getNode(ISD::OR, SDLoc(N0), VT, N0.getOperand(0), N1);
      AddToWorklist(IOR.getNode());
      return DAG.getNode(ISD::AND, SDLoc(N), VT, COR, IOR);
    }
  }

  if (SDValue Combined = visitORCommutative(DAG, N0, N1, N))
    return Combined;
  if (SDValue Combined = visitORCommutative(DAG, N1, N0, N))
    return Combined;

  // Simplify: (or (op x...), (op y...))  -> (op (or x, y))
  if (N0.getOpcode() == N1.getOpcode())
    if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
      return V;

  // See if this is some rotate idiom.
  if (SDValue Rot = MatchRotate(N0, N1, SDLoc(N)))
    return Rot;

  if (SDValue Load = MatchLoadCombine(N))
    return Load;

  // Simplify the operands using demanded-bits information.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  // If OR can be rewritten into ADD, try combines based on ADD.
  if ((!LegalOperations || TLI.isOperationLegal(ISD::ADD, VT)) &&
      DAG.isADDLike(SDValue(N, 0)))
    if (SDValue Combined = visitADDLike(N))
      return Combined;

  // Postpone until legalization has completed to avoid interference with
  // bswap folding.
  if (LegalOperations || VT.isVector())
    if (SDValue R = foldLogicTreeOfShifts(N, N0, N1, DAG))
      return R;

  return SDValue();
}
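
// Illustrative note on the isADDLike step above: when the two OR operands
// have no set bits in common (e.g. (or (shl x, 8), (and y, 0xFF))), OR and
// ADD compute the same value, so the ADD-oriented combines can safely be
// reused on the OR node.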

static SDValue stripConstantMask(const SelectionDAG &DAG, SDValue Op,
                                 SDValue &Mask) {
  if (Op.getOpcode() == ISD::AND &&
      DAG.isConstantIntBuildVectorOrConstantInt(Op.getOperand(1))) {
    Mask = Op.getOperand(1);
    return Op.getOperand(0);
  }
  return Op;
}

/// Match "(X shl/srl V1) & V2" where V2 may not be present.
static bool matchRotateHalf(const SelectionDAG &DAG, SDValue Op, SDValue &Shift,
                            SDValue &Mask) {
  Op = stripConstantMask(DAG, Op, Mask);
  if (Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SHL) {
    Shift = Op;
    return true;
  }
  return false;
}
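
// For example (illustrative): matching (and (srl X, 8), 0xFF) sets
// Shift = (srl X, 8) and Mask = 0xFF, while matching a bare (shl X, 24)
// sets Shift = (shl X, 24) and leaves Mask untouched.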

/// Helper function for visitOR to extract the needed side of a rotate idiom
/// from a shl/srl/mul/udiv.  This is meant to handle cases where
/// InstCombine merged some outside op with one of the shifts from
/// the rotate pattern.
/// \returns An empty \c SDValue if the needed shift couldn't be extracted.
/// Otherwise, returns an expansion of \p ExtractFrom based on the following
/// patterns:
///
///   (or (add v v) (srl v bitwidth-1)):
///     expands (add v v) -> (shl v 1)
///
///   (or (mul v c0) (srl (mul v c1) c2)):
///     expands (mul v c0) -> (shl (mul v c1) c3)
///
///   (or (udiv v c0) (shl (udiv v c1) c2)):
///     expands (udiv v c0) -> (srl (udiv v c1) c3)
///
///   (or (shl v c0) (srl (shl v c1) c2)):
///     expands (shl v c0) -> (shl (shl v c1) c3)
///
///   (or (srl v c0) (shl (srl v c1) c2)):
///     expands (srl v c0) -> (srl (srl v c1) c3)
///
/// Such that in all cases, c3+c2==bitwidth(op v c1).
static SDValue extractShiftForRotate(SelectionDAG &DAG, SDValue OppShift,
                                     SDValue ExtractFrom, SDValue &Mask,
                                     const SDLoc &DL) {
  assert(OppShift && ExtractFrom && "Empty SDValue");
  if (OppShift.getOpcode() != ISD::SHL && OppShift.getOpcode() != ISD::SRL)
    return SDValue();

  ExtractFrom = stripConstantMask(DAG, ExtractFrom, Mask);

  // Value and Type of the shift.
  SDValue OppShiftLHS = OppShift.getOperand(0);
  EVT ShiftedVT = OppShiftLHS.getValueType();

  // Amount of the existing shift.
  ConstantSDNode *OppShiftCst = isConstOrConstSplat(OppShift.getOperand(1));

  // (add v v) -> (shl v 1)
  // TODO: Should this be a general DAG canonicalization?
  if (OppShift.getOpcode() == ISD::SRL && OppShiftCst &&
      ExtractFrom.getOpcode() == ISD::ADD &&
      ExtractFrom.getOperand(0) == ExtractFrom.getOperand(1) &&
      ExtractFrom.getOperand(0) == OppShiftLHS &&
      OppShiftCst->getAPIntValue() == ShiftedVT.getScalarSizeInBits() - 1)
    return DAG.getNode(ISD::SHL, DL, ShiftedVT, OppShiftLHS,
                       DAG.getShiftAmountConstant(1, ShiftedVT, DL));

  // Preconditions:
  //    (or (op0 v c0) (shiftl/r (op0 v c1) c2))
  //
  // Find opcode of the needed shift to be extracted from (op0 v c0).
  unsigned Opcode = ISD::DELETED_NODE;
  bool IsMulOrDiv = false;
  // Set Opcode and IsMulOrDiv if the extract opcode matches the needed shift
  // opcode or its arithmetic (mul or udiv) variant.
  auto SelectOpcode = [&](unsigned NeededShift, unsigned MulOrDivVariant) {
    IsMulOrDiv = ExtractFrom.getOpcode() == MulOrDivVariant;
    if (!IsMulOrDiv && ExtractFrom.getOpcode() != NeededShift)
      return false;
    Opcode = NeededShift;
    return true;
  };
  // op0 must be either the needed shift opcode or the mul/udiv equivalent
  // that the needed shift can be extracted from.
  if ((OppShift.getOpcode() != ISD::SRL || !SelectOpcode(ISD::SHL, ISD::MUL)) &&
      (OppShift.getOpcode() != ISD::SHL || !SelectOpcode(ISD::SRL, ISD::UDIV)))
    return SDValue();

  // op0 must be the same opcode on both sides, have the same LHS argument,
  // and produce the same value type.
  if (OppShiftLHS.getOpcode() != ExtractFrom.getOpcode() ||
      OppShiftLHS.getOperand(0) != ExtractFrom.getOperand(0) ||
      ShiftedVT != ExtractFrom.getValueType())
    return SDValue();

  // Constant mul/udiv/shift amount from the RHS of the shift's LHS op.
  ConstantSDNode *OppLHSCst = isConstOrConstSplat(OppShiftLHS.getOperand(1));
  // Constant mul/udiv/shift amount from the RHS of the ExtractFrom op.
  ConstantSDNode *ExtractFromCst =
      isConstOrConstSplat(ExtractFrom.getOperand(1));
  // TODO: We should be able to handle non-uniform constant vectors for these
  //       values.
  // Check that we have constant values.
  if (!OppShiftCst || !OppShiftCst->getAPIntValue() ||
      !OppLHSCst || !OppLHSCst->getAPIntValue() ||
      !ExtractFromCst || !ExtractFromCst->getAPIntValue())
    return SDValue();

  // Compute the shift amount we need to extract to complete the rotate.
  const unsigned VTWidth = ShiftedVT.getScalarSizeInBits();
  if (OppShiftCst->getAPIntValue().ugt(VTWidth))
    return SDValue();
  APInt NeededShiftAmt = VTWidth - OppShiftCst->getAPIntValue();
  // Normalize the bitwidth of the two mul/udiv/shift constant operands.
  APInt ExtractFromAmt = ExtractFromCst->getAPIntValue();
  APInt OppLHSAmt = OppLHSCst->getAPIntValue();
  zeroExtendToMatch(ExtractFromAmt, OppLHSAmt);

  // Now try to extract the needed shift from the ExtractFrom op and see if the
  // result matches up with the existing shift's LHS op.
  if (IsMulOrDiv) {
    // Op to extract from is a mul or udiv by a constant.
    // Check (using the labels from the precondition above):
    //     c0 / (1 << (bitwidth(op0 v c1) - c2)) == c1
    //     c0 % (1 << (bitwidth(op0 v c1) - c2)) == 0
    const APInt ExtractDiv = APInt::getOneBitSet(ExtractFromAmt.getBitWidth(),
                                                 NeededShiftAmt.getZExtValue());
    APInt ResultAmt;
    APInt Rem;
    APInt::udivrem(ExtractFromAmt, ExtractDiv, ResultAmt, Rem);
    if (Rem != 0 || ResultAmt != OppLHSAmt)
      return SDValue();
  } else {
    // Op to extract from is a shift by a constant.
    // Check (using the labels from the precondition above):
    //      c0 - (bitwidth(op0 v c1) - c2) == c1
    if (OppLHSAmt != ExtractFromAmt - NeededShiftAmt.zextOrTrunc(
                                          ExtractFromAmt.getBitWidth()))
      return SDValue();
  }

  // Return the expanded shift op that should allow a rotate to be formed.
  EVT ShiftVT = OppShift.getOperand(1).getValueType();
  EVT ResVT = ExtractFrom.getValueType();
  SDValue NewShiftNode = DAG.getConstant(NeededShiftAmt, DL, ShiftVT);
  return DAG.getNode(Opcode, DL, ResVT, OppShiftLHS, NewShiftNode);
}
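
// A worked instance of the mul case above (numbers chosen for exposition):
// for i32 v with c1 = 3 and c2 = 28, the needed shift is c3 = 32 - 28 = 4,
// so c0 must be c1 << c3 = 48. Then (or (mul v, 48), (srl (mul v, 3), 28))
// is rewritten as (or (shl (mul v, 3), 4), (srl (mul v, 3), 28)), which
// MatchRotate can turn into (rotl (mul v, 3), 4).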

// Return true if we can prove that, whenever Neg and Pos are both in the
// range [0, EltSize), Neg == (Pos == 0 ? 0 : EltSize - Pos).  This means that
// for two opposing shifts shift1 and shift2 and a value X with OpBits bits:
//
//     (or (shift1 X, Neg), (shift2 X, Pos))
//
// reduces to a rotate in direction shift2 by Pos or (equivalently) a rotate
// in direction shift1 by Neg.  The range [0, EltSize) means that we only need
// to consider shift amounts with defined behavior.
//
// The IsRotate flag should be set when the LHS of both shifts is the same.
// Otherwise if matching a general funnel shift, it should be clear.
static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
                           SelectionDAG &DAG, bool IsRotate) {
  const auto &TLI = DAG.getTargetLoweringInfo();
  // If EltSize is a power of 2 then:
  //
  //  (a) (Pos == 0 ? 0 : EltSize - Pos) == (EltSize - Pos) & (EltSize - 1)
  //  (b) Neg == Neg & (EltSize - 1) whenever Neg is in [0, EltSize).
  //
  // So if EltSize is a power of 2 and Neg is (and Neg', EltSize-1), we check
  // for the stronger condition:
  //
  //     Neg & (EltSize - 1) == (EltSize - Pos) & (EltSize - 1)    [A]
  //
  // for all Neg and Pos.  Since Neg & (EltSize - 1) == Neg' & (EltSize - 1)
  // we can just replace Neg with Neg' for the rest of the function.
  //
  // In other cases we check for the even stronger condition:
  //
  //     Neg == EltSize - Pos                                    [B]
  //
  // for all Neg and Pos.  Note that the (or ...) then invokes undefined
  // behavior if Pos == 0 (and consequently Neg == EltSize).
  //
  // We could actually use [A] whenever EltSize is a power of 2, but the
  // only extra cases that it would match are those uninteresting ones
  // where Neg and Pos are never in range at the same time.  E.g. for
  // EltSize == 32, using [A] would allow a Neg of the form (sub 64, Pos)
  // as well as (sub 32, Pos), but:
  //
  //     (or (shift1 X, (sub 64, Pos)), (shift2 X, Pos))
  //
  // always invokes undefined behavior for 32-bit X.
  //
  // Below, Mask == EltSize - 1 when using [A] and is all-ones otherwise.
  // This allows us to peek through any operations that only affect Mask's
  // un-demanded bits.
  //
  // NOTE: We can only do this when matching operations which won't modify the
  // least Log2(EltSize) significant bits and not a general funnel shift.
  unsigned MaskLoBits = 0;
  if (IsRotate && isPowerOf2_64(EltSize)) {
    unsigned Bits = Log2_64(EltSize);
    unsigned NegBits = Neg.getScalarValueSizeInBits();
    if (NegBits >= Bits) {
      APInt DemandedBits = APInt::getLowBitsSet(NegBits, Bits);
      if (SDValue Inner =
              TLI.SimplifyMultipleUseDemandedBits(Neg, DemandedBits, DAG)) {
        Neg = Inner;
        MaskLoBits = Bits;
      }
    }
  }

  // Check whether Neg has the form (sub NegC, NegOp1) for some NegC and NegOp1.
  if (Neg.getOpcode() != ISD::SUB)
    return false;
  ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(0));
  if (!NegC)
    return false;
  SDValue NegOp1 = Neg.getOperand(1);

  // On the RHS of [A], if Pos is the result of an operation on Pos' that won't
  // affect Mask's demanded bits, just replace Pos with Pos'. These operations
  // are redundant for the purpose of the equality.
  if (MaskLoBits) {
    unsigned PosBits = Pos.getScalarValueSizeInBits();
    if (PosBits >= MaskLoBits) {
      APInt DemandedBits = APInt::getLowBitsSet(PosBits, MaskLoBits);
      if (SDValue Inner =
              TLI.SimplifyMultipleUseDemandedBits(Pos, DemandedBits, DAG)) {
        Pos = Inner;
      }
    }
  }

  // The condition we need is now:
  //
  //     (NegC - NegOp1) & Mask == (EltSize - Pos) & Mask
  //
  // If NegOp1 == Pos then we need:
  //
  //              EltSize & Mask == NegC & Mask
  //
  // (because "x & Mask" is a truncation and distributes through subtraction).
  //
  // We also need to account for a potential truncation of NegOp1 if the amount
  // has already been legalized to a shift amount type.
  APInt Width;
  if ((Pos == NegOp1) ||
      (NegOp1.getOpcode() == ISD::TRUNCATE && Pos == NegOp1.getOperand(0)))
    Width = NegC->getAPIntValue();

  // Check for cases where Pos has the form (add NegOp1, PosC) for some PosC.
  // Then the condition we want to prove becomes:
  //
  //     (NegC - NegOp1) & Mask == (EltSize - (NegOp1 + PosC)) & Mask
  //
  // which, again because "x & Mask" is a truncation, becomes:
  //
  //                NegC & Mask == (EltSize - PosC) & Mask
  //             EltSize & Mask == (NegC + PosC) & Mask
  else if (Pos.getOpcode() == ISD::ADD && Pos.getOperand(0) == NegOp1) {
    if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1)))
      Width = PosC->getAPIntValue() + NegC->getAPIntValue();
    else
      return false;
  } else
    return false;

  // Now we just need to check that EltSize & Mask == Width & Mask.
  if (MaskLoBits)
    // EltSize & Mask is 0 since Mask is EltSize - 1.
    return Width.getLoBits(MaskLoBits) == 0;
  return Width == EltSize;
}
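
// Two illustrative inputs that satisfy the check above for EltSize == 32:
//   Neg = (sub 32, y),          Pos = y   -> Width = NegC = 32, rule [B]
//   Neg = (and (sub 0, y), 31), Pos = y   -> the mask may be peeled off, so
//                                            Width = NegC = 0, and
//                                            0 & 31 == 32 & 31, rule [A]
// Both therefore describe the same rotate amount pair.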

// A subroutine of MatchRotate used once we have found an OR of two opposite
// shifts of Shifted.  If Neg == <operand size> - Pos then the OR reduces
// to both (PosOpcode Shifted, Pos) and (NegOpcode Shifted, Neg), with the
// former being preferred if supported.  InnerPos and InnerNeg are Pos and
// Neg with outer conversions stripped away.
SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
                                       SDValue Neg, SDValue InnerPos,
                                       SDValue InnerNeg, bool HasPos,
                                       unsigned PosOpcode, unsigned NegOpcode,
                                       const SDLoc &DL) {
  // fold (or (shl x, (*ext y)),
  //          (srl x, (*ext (sub 32, y)))) ->
  //   (rotl x, y) or (rotr x, (sub 32, y))
  //
  // fold (or (shl x, (*ext (sub 32, y))),
  //          (srl x, (*ext y))) ->
  //   (rotr x, y) or (rotl x, (sub 32, y))
  EVT VT = Shifted.getValueType();
  if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits(), DAG,
                     /*IsRotate*/ true)) {
    return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, Shifted,
                       HasPos ? Pos : Neg);
  }

  return SDValue();
}

// A subroutine of MatchRotate used once we have found an OR of two opposite
// shifts of N0 + N1.  If Neg == <operand size> - Pos then the OR reduces
// to both (PosOpcode N0, N1, Pos) and (NegOpcode N0, N1, Neg), with the
// former being preferred if supported.  InnerPos and InnerNeg are Pos and
// Neg with outer conversions stripped away.
// TODO: Merge with MatchRotatePosNeg.
SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
                                       SDValue Neg, SDValue InnerPos,
                                       SDValue InnerNeg, bool HasPos,
                                       unsigned PosOpcode, unsigned NegOpcode,
                                       const SDLoc &DL) {
  EVT VT = N0.getValueType();
  unsigned EltBits = VT.getScalarSizeInBits();

  // fold (or (shl x0, (*ext y)),
  //          (srl x1, (*ext (sub 32, y)))) ->
  //   (fshl x0, x1, y) or (fshr x0, x1, (sub 32, y))
  //
  // fold (or (shl x0, (*ext (sub 32, y))),
  //          (srl x1, (*ext y))) ->
  //   (fshr x0, x1, y) or (fshl x0, x1, (sub 32, y))
  if (matchRotateSub(InnerPos, InnerNeg, EltBits, DAG, /*IsRotate*/ N0 == N1)) {
    return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, N0, N1,
                       HasPos ? Pos : Neg);
  }

  // Matching the shift+xor cases, we can't easily use the xor'd shift amount
  // so for now just use the PosOpcode case if it's legal.
  // TODO: When can we use the NegOpcode case?
  if (PosOpcode == ISD::FSHL && isPowerOf2_32(EltBits)) {
    auto IsBinOpImm = [](SDValue Op, unsigned BinOpc, unsigned Imm) {
      if (Op.getOpcode() != BinOpc)
        return false;
      ConstantSDNode *Cst = isConstOrConstSplat(Op.getOperand(1));
      return Cst && (Cst->getAPIntValue() == Imm);
    };

    // fold (or (shl x0, y), (srl (srl x1, 1), (xor y, 31)))
    //   -> (fshl x0, x1, y)
    if (IsBinOpImm(N1, ISD::SRL, 1) &&
        IsBinOpImm(InnerNeg, ISD::XOR, EltBits - 1) &&
        InnerPos == InnerNeg.getOperand(0) &&
        TLI.isOperationLegalOrCustom(ISD::FSHL, VT)) {
      return DAG.getNode(ISD::FSHL, DL, VT, N0, N1.getOperand(0), Pos);
    }

    // fold (or (shl (shl x0, 1), (xor y, 31)), (srl x1, y))
    //   -> (fshr x0, x1, y)
    if (IsBinOpImm(N0, ISD::SHL, 1) &&
        IsBinOpImm(InnerPos, ISD::XOR, EltBits - 1) &&
        InnerNeg == InnerPos.getOperand(0) &&
        TLI.isOperationLegalOrCustom(ISD::FSHR, VT)) {
      return DAG.getNode(ISD::FSHR, DL, VT, N0.getOperand(0), N1, Neg);
    }

    // fold (or (shl (add x0, x0), (xor y, 31)), (srl x1, y))
    //   -> (fshr x0, x1, y)
    // TODO: Should add(x,x) -> shl(x,1) be a general DAG canonicalization?
    if (N0.getOpcode() == ISD::ADD && N0.getOperand(0) == N0.getOperand(1) &&
        IsBinOpImm(InnerPos, ISD::XOR, EltBits - 1) &&
        InnerNeg == InnerPos.getOperand(0) &&
        TLI.isOperationLegalOrCustom(ISD::FSHR, VT)) {
      return DAG.getNode(ISD::FSHR, DL, VT, N0.getOperand(0), N1, Neg);
    }
  }

  return SDValue();
}
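
// Why the xor trick above is sound (a brief sketch): when EltBits is a power
// of two, (xor y, EltBits-1) == EltBits-1-y for y in [0, EltBits). Combined
// with the extra shift by 1, the total opposing shift is EltBits-y, which is
// exactly the funnel-shift complement of the shift by y.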

// MatchRotate - Handle an 'or' of two operands.  If this is one of the many
// idioms for rotate, and if the target supports rotation instructions,
// generate a rot[lr]. This also matches funnel shift patterns, similar to
// rotation but with different shifted sources.
SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
  EVT VT = LHS.getValueType();

  // The target must have at least one rotate/funnel flavor.
  // We still try to match rotate by constant pre-legalization.
  // TODO: Support pre-legalization funnel-shift by constant.
  bool HasROTL = hasOperation(ISD::ROTL, VT);
  bool HasROTR = hasOperation(ISD::ROTR, VT);
  bool HasFSHL = hasOperation(ISD::FSHL, VT);
  bool HasFSHR = hasOperation(ISD::FSHR, VT);

  // If the type is going to be promoted and the target has enabled custom
  // lowering for rotate, allow matching rotate by non-constants. Only allow
  // this for scalar types.
  if (VT.isScalarInteger() && TLI.getTypeAction(*DAG.getContext(), VT) ==
                                  TargetLowering::TypePromoteInteger) {
    HasROTL |= TLI.getOperationAction(ISD::ROTL, VT) == TargetLowering::Custom;
    HasROTR |= TLI.getOperationAction(ISD::ROTR, VT) == TargetLowering::Custom;
  }

  if (LegalOperations && !HasROTL && !HasROTR && !HasFSHL && !HasFSHR)
    return SDValue();

  // Check for truncated rotate.
  if (LHS.getOpcode() == ISD::TRUNCATE && RHS.getOpcode() == ISD::TRUNCATE &&
      LHS.getOperand(0).getValueType() == RHS.getOperand(0).getValueType()) {
    assert(LHS.getValueType() == RHS.getValueType());
    if (SDValue Rot = MatchRotate(LHS.getOperand(0), RHS.getOperand(0), DL)) {
      return DAG.getNode(ISD::TRUNCATE, SDLoc(LHS), LHS.getValueType(), Rot);
    }
  }

  // Match "(X shl/srl V1) & V2" where V2 may not be present.
  SDValue LHSShift;   // The shift.
  SDValue LHSMask;    // AND value if any.
  matchRotateHalf(DAG, LHS, LHSShift, LHSMask);

  SDValue RHSShift;   // The shift.
  SDValue RHSMask;    // AND value if any.
  matchRotateHalf(DAG, RHS, RHSShift, RHSMask);

  // If neither side matched a rotate half, bail.
  if (!LHSShift && !RHSShift)
    return SDValue();

  // InstCombine may have combined a constant shl, srl, mul, or udiv with one
  // side of the rotate, so try to handle that here. In all cases we need to
  // pass the matched shift from the opposite side to compute the opcode and
  // needed shift amount to extract.  We still want to do this if both sides
  // matched a rotate half because one half may be a potential overshift that
  // can be broken down (i.e. if InstCombine merged two shl or srl ops into a
  // single one).

  // Have LHS side of the rotate, try to extract the needed shift from the RHS.
  if (LHSShift)
    if (SDValue NewRHSShift =
            extractShiftForRotate(DAG, LHSShift, RHS, RHSMask, DL))
      RHSShift = NewRHSShift;
  // Have RHS side of the rotate, try to extract the needed shift from the LHS.
  if (RHSShift)
    if (SDValue NewLHSShift =
            extractShiftForRotate(DAG, RHSShift, LHS, LHSMask, DL))
      LHSShift = NewLHSShift;

  // If a side is still missing, nothing else we can do.
  if (!RHSShift || !LHSShift)
    return SDValue();

  // At this point we've matched or extracted a shift op on each side.

  if (LHSShift.getOpcode() == RHSShift.getOpcode())
    return SDValue(); // Shifts must disagree.

  // Canonicalize shl to left side in a shl/srl pair.
  if (RHSShift.getOpcode() == ISD::SHL) {
    std::swap(LHS, RHS);
    std::swap(LHSShift, RHSShift);
    std::swap(LHSMask, RHSMask);
  }

  // Something has gone wrong - we've lost the shl/srl pair - bail.
  if (LHSShift.getOpcode() != ISD::SHL || RHSShift.getOpcode() != ISD::SRL)
    return SDValue();

  unsigned EltSizeInBits = VT.getScalarSizeInBits();
  SDValue LHSShiftArg = LHSShift.getOperand(0);
  SDValue LHSShiftAmt = LHSShift.getOperand(1);
  SDValue RHSShiftArg = RHSShift.getOperand(0);
  SDValue RHSShiftAmt = RHSShift.getOperand(1);

  auto MatchRotateSum = [EltSizeInBits](ConstantSDNode *LHS,
                                        ConstantSDNode *RHS) {
    return (LHS->getAPIntValue() + RHS->getAPIntValue()) == EltSizeInBits;
  };

  auto ApplyMasks = [&](SDValue Res) {
    // If there is an AND of either shifted operand, apply it to the result.
    if (LHSMask.getNode() || RHSMask.getNode()) {
      SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
      SDValue Mask = AllOnes;

      if (LHSMask.getNode()) {
        SDValue RHSBits = DAG.getNode(ISD::SRL, DL, VT, AllOnes, RHSShiftAmt);
        Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
                           DAG.getNode(ISD::OR, DL, VT, LHSMask, RHSBits));
      }
      if (RHSMask.getNode()) {
        SDValue LHSBits = DAG.getNode(ISD::SHL, DL, VT, AllOnes, LHSShiftAmt);
        Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
                           DAG.getNode(ISD::OR, DL, VT, RHSMask, LHSBits));
      }

      Res = DAG.getNode(ISD::AND, DL, VT, Res, Mask);
    }

    return Res;
  };

  // TODO: Support pre-legalization funnel-shift by constant.
  bool IsRotate = LHSShiftArg == RHSShiftArg;
  if (!IsRotate && !(HasFSHL || HasFSHR)) {
    if (TLI.isTypeLegal(VT) && LHS.hasOneUse() && RHS.hasOneUse() &&
        ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) {
      // Look for a disguised rotate by constant.
      // The common shifted operand X may be hidden inside another 'or'.
      SDValue X, Y;
      auto matchOr = [&X, &Y](SDValue Or, SDValue CommonOp) {
        if (!Or.hasOneUse() || Or.getOpcode() != ISD::OR)
          return false;
        if (CommonOp == Or.getOperand(0)) {
          X = CommonOp;
          Y = Or.getOperand(1);
          return true;
        }
        if (CommonOp == Or.getOperand(1)) {
          X = CommonOp;
          Y = Or.getOperand(0);
          return true;
        }
        return false;
      };

      SDValue Res;
      if (matchOr(LHSShiftArg, RHSShiftArg)) {
        // (shl (X | Y), C1) | (srl X, C2) --> (rotl X, C1) | (shl Y, C1)
        SDValue RotX = DAG.getNode(ISD::ROTL, DL, VT, X, LHSShiftAmt);
        SDValue ShlY = DAG.getNode(ISD::SHL, DL, VT, Y, LHSShiftAmt);
        Res = DAG.getNode(ISD::OR, DL, VT, RotX, ShlY);
      } else if (matchOr(RHSShiftArg, LHSShiftArg)) {
        // (shl X, C1) | (srl (X | Y), C2) --> (rotl X, C1) | (srl Y, C2)
        SDValue RotX = DAG.getNode(ISD::ROTL, DL, VT, X, LHSShiftAmt);
        SDValue SrlY = DAG.getNode(ISD::SRL, DL, VT, Y, RHSShiftAmt);
        Res = DAG.getNode(ISD::OR, DL, VT, RotX, SrlY);
      } else {
        return SDValue();
      }

      return ApplyMasks(Res);
    }

    return SDValue(); // Requires funnel shift support.
  }

  // fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1)
  // fold (or (shl x, C1), (srl x, C2)) -> (rotr x, C2)
  // fold (or (shl x, C1), (srl y, C2)) -> (fshl x, y, C1)
  // fold (or (shl x, C1), (srl y, C2)) -> (fshr x, y, C2)
  // iff C1+C2 == EltSizeInBits
  if (ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) {
    SDValue Res;
    if (IsRotate && (HasROTL || HasROTR || !(HasFSHL || HasFSHR))) {
      bool UseROTL = !LegalOperations || HasROTL;
      Res = DAG.getNode(UseROTL ? ISD::ROTL : ISD::ROTR, DL, VT, LHSShiftArg,
                        UseROTL ? LHSShiftAmt : RHSShiftAmt);
    } else {
      bool UseFSHL = !LegalOperations || HasFSHL;
      Res = DAG.getNode(UseFSHL ? ISD::FSHL : ISD::FSHR, DL, VT, LHSShiftArg,
                        RHSShiftArg, UseFSHL ? LHSShiftAmt : RHSShiftAmt);
    }

    return ApplyMasks(Res);
  }

  // Even pre-legalization, we can't easily rotate/funnel-shift by a variable
  // shift.
  if (!HasROTL && !HasROTR && !HasFSHL && !HasFSHR)
    return SDValue();

  // If there is a mask here, and we have a variable shift, we can't be sure
  // that we're masking out the right stuff.
  if (LHSMask.getNode() || RHSMask.getNode())
    return SDValue();

  // If the shift amount is sign/zext/any-extended just peel it off.
  SDValue LExtOp0 = LHSShiftAmt;
  SDValue RExtOp0 = RHSShiftAmt;
  if ((LHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
       LHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
       LHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
       LHSShiftAmt.getOpcode() == ISD::TRUNCATE) &&
      (RHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
       RHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
       RHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
       RHSShiftAmt.getOpcode() == ISD::TRUNCATE)) {
    LExtOp0 = LHSShiftAmt.getOperand(0);
    RExtOp0 = RHSShiftAmt.getOperand(0);
  }

  if (IsRotate && (HasROTL || HasROTR)) {
    SDValue TryL =
        MatchRotatePosNeg(LHSShiftArg, LHSShiftAmt, RHSShiftAmt, LExtOp0,
                          RExtOp0, HasROTL, ISD::ROTL, ISD::ROTR, DL);
    if (TryL)
      return TryL;

    SDValue TryR =
        MatchRotatePosNeg(RHSShiftArg, RHSShiftAmt, LHSShiftAmt, RExtOp0,
                          LExtOp0, HasROTR, ISD::ROTR, ISD::ROTL, DL);
    if (TryR)
      return TryR;
  }

  SDValue TryL =
      MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, LHSShiftAmt, RHSShiftAmt,
                        LExtOp0, RExtOp0, HasFSHL, ISD::FSHL, ISD::FSHR, DL);
  if (TryL)
    return TryL;

  SDValue TryR =
      MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, RHSShiftAmt, LHSShiftAmt,
                        RExtOp0, LExtOp0, HasFSHR, ISD::FSHR, ISD::FSHL, DL);
  if (TryR)
    return TryR;

  return SDValue();
}

/// Recursively traverses the expression calculating the origin of the requested
/// byte of the given value. Returns std::nullopt if the provider can't be
/// calculated.
///
/// For all the values except the root of the expression, we verify that the
/// value has exactly one use and if not then return std::nullopt. This way if
/// the origin of the byte is returned it's guaranteed that the values which
/// contribute to the byte are not used outside of this expression.
///
/// However, there is a special case when dealing with vector loads -- we allow
/// more than one use if the load is a vector type.  Since the values that
/// contribute to the byte ultimately come from the ExtractVectorElements of the
/// Load, we don't care if the Load has uses other than ExtractVectorElements,
/// because those operations are independent from the pattern to be combined.
/// For vector loads, we simply care that the ByteProviders are adjacent
/// positions of the same vector, and their index matches the byte that is being
/// provided. This is captured by the \p VectorIndex algorithm. \p VectorIndex
/// is the index used in an ExtractVectorElement, and \p StartingIndex is the
/// byte position we are trying to provide for the LoadCombine. If these do
/// not match, then we cannot combine the vector loads. \p Index uses the
/// byte position we are trying to provide for and is matched against the
/// shl and load size. The \p Index algorithm ensures the requested byte is
/// provided for by the pattern, and the pattern does not over-provide bytes.
///
/// The supported LoadCombine pattern for vector loads is as follows:
///                              or
///                          /        \
///                         or        shl
///                       /     \      |
///                     or      shl   zext
///                   /    \     |     |
///                 shl   zext  zext  EVE*
///                  |     |     |     |
///                 zext  EVE*  EVE*  LOAD
///                  |     |     |
///                 EVE*  LOAD  LOAD
///                  |
///                 LOAD
///
/// *ExtractVectorElement
using SDByteProvider = ByteProvider<SDNode *>;

static std::optional<SDByteProvider>
calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
                      std::optional<uint64_t> VectorIndex,
                      unsigned StartingIndex = 0) {

  // The typical i64-built-from-i8 pattern requires recursion to a depth of
  // about 8 calls.
  if (Depth == 10)
    return std::nullopt;

  // Only allow multiple uses if the instruction is a vector load (in which
  // case we will use the load for every ExtractVectorElement).
  if (Depth && !Op.hasOneUse() &&
      (Op.getOpcode() != ISD::LOAD || !Op.getValueType().isVector()))
    return std::nullopt;

  // Fail to combine if we have encountered anything but a LOAD after handling
  // an ExtractVectorElement.
  if (Op.getOpcode() != ISD::LOAD && VectorIndex.has_value())
    return std::nullopt;

  unsigned BitWidth = Op.getValueSizeInBits();
  if (BitWidth % 8 != 0)
    return std::nullopt;
  unsigned ByteWidth = BitWidth / 8;
  assert(Index < ByteWidth && "invalid index requested");
  (void) ByteWidth;

  switch (Op.getOpcode()) {
  case ISD::OR: {
    auto LHS =
        calculateByteProvider(Op->getOperand(0), Index, Depth + 1, VectorIndex);
    if (!LHS)
      return std::nullopt;
    auto RHS =
        calculateByteProvider(Op->getOperand(1), Index, Depth + 1, VectorIndex);
    if (!RHS)
      return std::nullopt;

    if (LHS->isConstantZero())
      return RHS;
    if (RHS->isConstantZero())
      return LHS;
    return std::nullopt;
  }
  case ISD::SHL: {
    auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
    if (!ShiftOp)
      return std::nullopt;

    uint64_t BitShift = ShiftOp->getZExtValue();

    if (BitShift % 8 != 0)
      return std::nullopt;
    uint64_t ByteShift = BitShift / 8;

    // If we are shifting by an amount greater than the index we are trying to
    // provide, then do not provide anything. Otherwise, subtract the byte
    // shift from the index.
    return Index < ByteShift
               ? SDByteProvider::getConstantZero()
               : calculateByteProvider(Op->getOperand(0), Index - ByteShift,
                                       Depth + 1, VectorIndex, Index);
  }
  case ISD::ANY_EXTEND:
  case ISD::SIGN_EXTEND:
  case ISD::ZERO_EXTEND: {
    SDValue NarrowOp = Op->getOperand(0);
    unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
    if (NarrowBitWidth % 8 != 0)
      return std::nullopt;
    uint64_t NarrowByteWidth = NarrowBitWidth / 8;

    if (Index >= NarrowByteWidth)
      return Op.getOpcode() == ISD::ZERO_EXTEND
                 ? std::optional<SDByteProvider>(
                       SDByteProvider::getConstantZero())
                 : std::nullopt;
    return calculateByteProvider(NarrowOp, Index, Depth + 1, VectorIndex,
                                 StartingIndex);
  }
  case ISD::BSWAP:
    return calculateByteProvider(Op->getOperand(0), ByteWidth - Index - 1,
                                 Depth + 1, VectorIndex, StartingIndex);
  case ISD::EXTRACT_VECTOR_ELT: {
    auto OffsetOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
    if (!OffsetOp)
      return std::nullopt;

    VectorIndex = OffsetOp->getZExtValue();

    SDValue NarrowOp = Op->getOperand(0);
    unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
    if (NarrowBitWidth % 8 != 0)
      return std::nullopt;
    uint64_t NarrowByteWidth = NarrowBitWidth / 8;

    // Check to see if the position of the element in the vector corresponds
    // with the byte we are trying to provide for. In the case of a vector of
    // i8, this simply means VectorIndex == StartingIndex. For non-i8 cases,
    // the element will provide a range of bytes. For example, if we have a
    // vector of i16s, each element provides two bytes (V[1] provides bytes 2
    // and 3).
    if (*VectorIndex * NarrowByteWidth > StartingIndex)
      return std::nullopt;
    if ((*VectorIndex + 1) * NarrowByteWidth <= StartingIndex)
      return std::nullopt;

    return calculateByteProvider(Op->getOperand(0), Index, Depth + 1,
                                 VectorIndex, StartingIndex);
  }
  case ISD::LOAD: {
    auto L = cast<LoadSDNode>(Op.getNode());
    if (!L->isSimple() || L->isIndexed())
      return std::nullopt;

    unsigned NarrowBitWidth = L->getMemoryVT().getSizeInBits();
    if (NarrowBitWidth % 8 != 0)
      return std::nullopt;
    uint64_t NarrowByteWidth = NarrowBitWidth / 8;

    // If the width of the load does not reach the byte we are trying to
    // provide for, and it is not a ZEXTLOAD, then the load does not provide
    // for the byte in question.
    if (Index >= NarrowByteWidth)
      return L->getExtensionType() == ISD::ZEXTLOAD
                 ? std::optional<SDByteProvider>(
                       SDByteProvider::getConstantZero())
                 : std::nullopt;

    unsigned BPVectorIndex = VectorIndex.value_or(0U);
    return SDByteProvider::getSrc(L, Index, BPVectorIndex);
  }
  }

  return std::nullopt;
}
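
// A small illustration of the recursion above (little endian, i32 result):
// for t = (or (zext (load i8 p)), (shl (zext (load i8 p+1)), 8)),
// calculateByteProvider(t, 0) resolves to byte 0 of the load at p, and
// calculateByteProvider(t, 1) walks through the shl (subtracting the byte
// shift of 1) to byte 0 of the load at p+1; bytes 2 and 3 come back as
// constant zero from the zexts.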

static unsigned littleEndianByteAt(unsigned BW, unsigned i) {
  return i;
}

static unsigned bigEndianByteAt(unsigned BW, unsigned i) {
  return BW - i - 1;
}
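
// For example, with BW == 4 the little-endian byte order is {0, 1, 2, 3}
// and the big-endian order is {3, 2, 1, 0}.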

// Check if the byte offsets we are looking at match with either big or
// little endian value loaded. Return true for big endian, false for little
// endian, and std::nullopt if match failed.
static std::optional<bool> isBigEndian(const ArrayRef<int64_t> ByteOffsets,
                                       int64_t FirstOffset) {
  // The endian can be decided only when it is 2 bytes at least.
  unsigned Width = ByteOffsets.size();
  if (Width < 2)
    return std::nullopt;

  bool BigEndian = true, LittleEndian = true;
  for (unsigned i = 0; i < Width; i++) {
    int64_t CurrentByteOffset = ByteOffsets[i] - FirstOffset;
    LittleEndian &= CurrentByteOffset == littleEndianByteAt(Width, i);
    BigEndian &= CurrentByteOffset == bigEndianByteAt(Width, i);
    if (!BigEndian && !LittleEndian)
      return std::nullopt;
  }

  assert((BigEndian != LittleEndian) && "It should be either big endian or "
                                        "little endian");
  return BigEndian;
}
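
// Illustrative calls (with FirstOffset == 0): offsets {0, 1, 2, 3} return
// false (little endian), {3, 2, 1, 0} return true (big endian), and a
// shuffled order such as {1, 0, 3, 2} returns std::nullopt.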
8842 
stripTruncAndExt(SDValue Value)8843 static SDValue stripTruncAndExt(SDValue Value) {
8844   switch (Value.getOpcode()) {
8845   case ISD::TRUNCATE:
8846   case ISD::ZERO_EXTEND:
8847   case ISD::SIGN_EXTEND:
8848   case ISD::ANY_EXTEND:
8849     return stripTruncAndExt(Value.getOperand(0));
8850   }
8851   return Value;
8852 }
8853 
8854 /// Match a pattern where a wide type scalar value is stored by several narrow
8855 /// stores. Fold it into a single store or a BSWAP and a store if the targets
8856 /// supports it.
8857 ///
8858 /// Assuming little endian target:
8859 ///  i8 *p = ...
8860 ///  i32 val = ...
8861 ///  p[0] = (val >> 0) & 0xFF;
8862 ///  p[1] = (val >> 8) & 0xFF;
8863 ///  p[2] = (val >> 16) & 0xFF;
8864 ///  p[3] = (val >> 24) & 0xFF;
8865 /// =>
8866 ///  *((i32)p) = val;
8867 ///
8868 ///  i8 *p = ...
8869 ///  i32 val = ...
8870 ///  p[0] = (val >> 24) & 0xFF;
8871 ///  p[1] = (val >> 16) & 0xFF;
8872 ///  p[2] = (val >> 8) & 0xFF;
8873 ///  p[3] = (val >> 0) & 0xFF;
8874 /// =>
8875 ///  *((i32)p) = BSWAP(val);
mergeTruncStores(StoreSDNode * N)8876 SDValue DAGCombiner::mergeTruncStores(StoreSDNode *N) {
8877   // The matching looks for "store (trunc x)" patterns that appear early but are
8878   // likely to be replaced by truncating store nodes during combining.
8879   // TODO: If there is evidence that running this later would help, this
8880   //       limitation could be removed. Legality checks may need to be added
8881   //       for the created store and optional bswap/rotate.
8882   if (LegalOperations || OptLevel == CodeGenOptLevel::None)
8883     return SDValue();
8884 
8885   // We only handle merging simple stores of 1-4 bytes.
8886   // TODO: Allow unordered atomics when wider type is legal (see D66309)
8887   EVT MemVT = N->getMemoryVT();
8888   if (!(MemVT == MVT::i8 || MemVT == MVT::i16 || MemVT == MVT::i32) ||
8889       !N->isSimple() || N->isIndexed())
8890     return SDValue();
8891 
8892   // Collect all of the stores in the chain, upto the maximum store width (i64).
8893   SDValue Chain = N->getChain();
8894   SmallVector<StoreSDNode *, 8> Stores = {N};
8895   unsigned NarrowNumBits = MemVT.getScalarSizeInBits();
8896   unsigned MaxWideNumBits = 64;
8897   unsigned MaxStores = MaxWideNumBits / NarrowNumBits;
8898   while (auto *Store = dyn_cast<StoreSDNode>(Chain)) {
8899     // All stores must be the same size to ensure that we are writing all of the
8900     // bytes in the wide value.
8901     // This store should have exactly one use as a chain operand for another
8902     // store in the merging set. If there are other chain uses, then the
8903     // transform may not be safe because order of loads/stores outside of this
8904     // set may not be preserved.
8905     // TODO: We could allow multiple sizes by tracking each stored byte.
8906     if (Store->getMemoryVT() != MemVT || !Store->isSimple() ||
8907         Store->isIndexed() || !Store->hasOneUse())
8908       return SDValue();
8909     Stores.push_back(Store);
8910     Chain = Store->getChain();
8911     if (MaxStores < Stores.size())
8912       return SDValue();
8913   }
8914   // There is no reason to continue if we do not have at least a pair of stores.
8915   if (Stores.size() < 2)
8916     return SDValue();
8917 
8918   // Handle simple types only.
8919   LLVMContext &Context = *DAG.getContext();
8920   unsigned NumStores = Stores.size();
8921   unsigned WideNumBits = NumStores * NarrowNumBits;
8922   EVT WideVT = EVT::getIntegerVT(Context, WideNumBits);
8923   if (WideVT != MVT::i16 && WideVT != MVT::i32 && WideVT != MVT::i64)
8924     return SDValue();
8925 
8926   // Check if all bytes of the source value that we are looking at are stored
8927   // to the same base address. Collect offsets from Base address into OffsetMap.
8928   SDValue SourceValue;
8929   SmallVector<int64_t, 8> OffsetMap(NumStores, INT64_MAX);
8930   int64_t FirstOffset = INT64_MAX;
8931   StoreSDNode *FirstStore = nullptr;
8932   std::optional<BaseIndexOffset> Base;
8933   for (auto *Store : Stores) {
8934     // All the stores store different parts of the CombinedValue. A truncate is
8935     // required to get the partial value.
8936     SDValue Trunc = Store->getValue();
8937     if (Trunc.getOpcode() != ISD::TRUNCATE)
8938       return SDValue();
8939     // Other than the first/last part, a shift operation is required to get the
8940     // offset.
8941     int64_t Offset = 0;
8942     SDValue WideVal = Trunc.getOperand(0);
8943     if ((WideVal.getOpcode() == ISD::SRL || WideVal.getOpcode() == ISD::SRA) &&
8944         isa<ConstantSDNode>(WideVal.getOperand(1))) {
8945       // The shift amount must be a constant multiple of the narrow type.
8946       // It is translated to the offset address in the wide source value "y".
8947       //
8948       // x = srl y, ShiftAmtC
8949       // i8 z = trunc x
8950       // store z, ...
8951       uint64_t ShiftAmtC = WideVal.getConstantOperandVal(1);
8952       if (ShiftAmtC % NarrowNumBits != 0)
8953         return SDValue();
8954 
8955       Offset = ShiftAmtC / NarrowNumBits;
8956       WideVal = WideVal.getOperand(0);
8957     }
8958 
8959     // Stores must share the same source value with different offsets.
8960     // Truncate and extends should be stripped to get the single source value.
8961     if (!SourceValue)
8962       SourceValue = WideVal;
8963     else if (stripTruncAndExt(SourceValue) != stripTruncAndExt(WideVal))
8964       return SDValue();
8965     else if (SourceValue.getValueType() != WideVT) {
8966       if (WideVal.getValueType() == WideVT ||
8967           WideVal.getScalarValueSizeInBits() >
8968               SourceValue.getScalarValueSizeInBits())
8969         SourceValue = WideVal;
8970       // Give up if the source value type is smaller than the store size.
8971       if (SourceValue.getScalarValueSizeInBits() < WideVT.getScalarSizeInBits())
8972         return SDValue();
8973     }
8974 
8975     // Stores must share the same base address.
8976     BaseIndexOffset Ptr = BaseIndexOffset::match(Store, DAG);
8977     int64_t ByteOffsetFromBase = 0;
8978     if (!Base)
8979       Base = Ptr;
8980     else if (!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
8981       return SDValue();
8982 
8983     // Remember the first store.
8984     if (ByteOffsetFromBase < FirstOffset) {
8985       FirstStore = Store;
8986       FirstOffset = ByteOffsetFromBase;
8987     }
8988     // Map the offset in the store and the offset in the combined value, and
8989     // early return if it has been set before.
8990     if (Offset < 0 || Offset >= NumStores || OffsetMap[Offset] != INT64_MAX)
8991       return SDValue();
8992     OffsetMap[Offset] = ByteOffsetFromBase;
8993   }
8994 
8995   assert(FirstOffset != INT64_MAX && "First byte offset must be set");
8996   assert(FirstStore && "First store must be set");
8997 
8998   // Check that a store of the wide type is both allowed and fast on the target
8999   const DataLayout &Layout = DAG.getDataLayout();
9000   unsigned Fast = 0;
9001   bool Allowed = TLI.allowsMemoryAccess(Context, Layout, WideVT,
9002                                         *FirstStore->getMemOperand(), &Fast);
9003   if (!Allowed || !Fast)
9004     return SDValue();
9005 
9006   // Check if the pieces of the value are going to the expected places in memory
9007   // to merge the stores.
9008   auto checkOffsets = [&](bool MatchLittleEndian) {
9009     if (MatchLittleEndian) {
9010       for (unsigned i = 0; i != NumStores; ++i)
9011         if (OffsetMap[i] != i * (NarrowNumBits / 8) + FirstOffset)
9012           return false;
9013     } else { // MatchBigEndian by reversing loop counter.
9014       for (unsigned i = 0, j = NumStores - 1; i != NumStores; ++i, --j)
9015         if (OffsetMap[j] != i * (NarrowNumBits / 8) + FirstOffset)
9016           return false;
9017     }
9018     return true;
9019   };
9020 
9021   // Check if the offsets line up for the native data layout of this target.
9022   bool NeedBswap = false;
9023   bool NeedRotate = false;
9024   if (!checkOffsets(Layout.isLittleEndian())) {
9025     // Special-case: check if byte offsets line up for the opposite endian.
9026     if (NarrowNumBits == 8 && checkOffsets(Layout.isBigEndian()))
9027       NeedBswap = true;
9028     else if (NumStores == 2 && checkOffsets(Layout.isBigEndian()))
9029       NeedRotate = true;
9030     else
9031       return SDValue();
9032   }
9033 
9034   SDLoc DL(N);
9035   if (WideVT != SourceValue.getValueType()) {
9036     assert(SourceValue.getValueType().getScalarSizeInBits() > WideNumBits &&
9037            "Unexpected store value to merge");
9038     SourceValue = DAG.getNode(ISD::TRUNCATE, DL, WideVT, SourceValue);
9039   }
9040 
9041   // Before legalize we can introduce illegal bswaps/rotates which will be later
9042   // converted to an explicit bswap sequence. This way we end up with a single
9043   // store and byte shuffling instead of several stores and byte shuffling.
9044   if (NeedBswap) {
9045     SourceValue = DAG.getNode(ISD::BSWAP, DL, WideVT, SourceValue);
9046   } else if (NeedRotate) {
9047     assert(WideNumBits % 2 == 0 && "Unexpected type for rotate");
9048     SDValue RotAmt = DAG.getConstant(WideNumBits / 2, DL, WideVT);
9049     SourceValue = DAG.getNode(ISD::ROTR, DL, WideVT, SourceValue, RotAmt);
9050   }
9051 
9052   SDValue NewStore =
9053       DAG.getStore(Chain, DL, SourceValue, FirstStore->getBasePtr(),
9054                    FirstStore->getPointerInfo(), FirstStore->getAlign());
9055 
9056   // Rely on other DAG combine rules to remove the other individual stores.
9057   DAG.ReplaceAllUsesWith(N, NewStore.getNode());
9058   return NewStore;
9059 }
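// A small illustrative model of the pattern handled above (hypothetical
// helper, not part of this file): on a little-endian target, storing
// trunc(srl(X, 8*I)) at P+I for each byte I is exactly one wide store of X.
//
//   void storeWideByBytes(uint8_t *P, uint32_t X) {
//     for (unsigned I = 0; I != 4; ++I)
//       P[I] = uint8_t(X >> (8 * I)); // i8 store of trunc(srl(X, 8*I)) at P+I
//   }
//
// When the byte offsets instead match the opposite endianness, the stores are
// still merged, but a BSWAP (byte-sized elements) or a half-width ROTR
// (exactly two elements) is inserted first.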
9060 
9061 /// Match a pattern where a wide type scalar value is loaded by several narrow
9062 /// loads and combined by shifts and ors. Fold it into a single load or a load
9063 /// and a BSWAP if the target supports it.
9064 ///
9065 /// Assuming little endian target:
9066 ///  i8 *a = ...
9067 ///  i32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
9068 /// =>
9069 ///  i32 val = *((i32)a)
9070 ///
9071 ///  i8 *a = ...
9072 ///  i32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
9073 /// =>
9074 ///  i32 val = BSWAP(*((i32)a))
9075 ///
9076 /// TODO: This rule matches complex patterns with OR node roots and doesn't
9077 /// interact well with the worklist mechanism. When a part of the pattern is
9078 /// updated (e.g. one of the loads) its direct users are put into the worklist,
9079 /// but the root node of the pattern which triggers the load combine is not
9080 /// necessarily a direct user of the changed node. For example, once the address
9081 /// of t28 load is reassociated load combine won't be triggered:
9082 ///             t25: i32 = add t4, Constant:i32<2>
9083 ///           t26: i64 = sign_extend t25
9084 ///        t27: i64 = add t2, t26
9085 ///       t28: i8,ch = load<LD1[%tmp9]> t0, t27, undef:i64
9086 ///     t29: i32 = zero_extend t28
9087 ///   t32: i32 = shl t29, Constant:i8<8>
9088 /// t33: i32 = or t23, t32
9089 /// As a possible fix visitLoad can check if the load can be a part of a load
9090 /// combine pattern and add corresponding OR roots to the worklist.
9091 SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
9092   assert(N->getOpcode() == ISD::OR &&
9093          "Can only match load combining against OR nodes");
9094 
9095   // Handles simple types only
9096   EVT VT = N->getValueType(0);
9097   if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
9098     return SDValue();
9099   unsigned ByteWidth = VT.getSizeInBits() / 8;
9100 
9101   bool IsBigEndianTarget = DAG.getDataLayout().isBigEndian();
9102   auto MemoryByteOffset = [&](SDByteProvider P) {
9103     assert(P.hasSrc() && "Must be a memory byte provider");
9104     auto *Load = cast<LoadSDNode>(P.Src.value());
9105 
9106     unsigned LoadBitWidth = Load->getMemoryVT().getScalarSizeInBits();
9107 
9108     assert(LoadBitWidth % 8 == 0 &&
9109            "can only analyze providers for individual bytes not bit");
9110     unsigned LoadByteWidth = LoadBitWidth / 8;
9111     return IsBigEndianTarget ? bigEndianByteAt(LoadByteWidth, P.DestOffset)
9112                              : littleEndianByteAt(LoadByteWidth, P.DestOffset);
9113   };
9114 
9115   std::optional<BaseIndexOffset> Base;
9116   SDValue Chain;
9117 
9118   SmallPtrSet<LoadSDNode *, 8> Loads;
9119   std::optional<SDByteProvider> FirstByteProvider;
9120   int64_t FirstOffset = INT64_MAX;
9121 
9122   // Check if all the bytes of the OR we are looking at are loaded from the same
9123   // base address. Collect byte offsets from the Base address in ByteOffsets.
9124   SmallVector<int64_t, 8> ByteOffsets(ByteWidth);
9125   unsigned ZeroExtendedBytes = 0;
9126   for (int i = ByteWidth - 1; i >= 0; --i) {
9127     auto P =
9128         calculateByteProvider(SDValue(N, 0), i, 0, /*VectorIndex*/ std::nullopt,
9129                               /*StartingIndex*/ i);
9130     if (!P)
9131       return SDValue();
9132 
9133     if (P->isConstantZero()) {
9134       // It's OK for the N most significant bytes to be 0; we can just
9135       // zero-extend the load.
9136       if (++ZeroExtendedBytes != (ByteWidth - static_cast<unsigned>(i)))
9137         return SDValue();
9138       continue;
9139     }
9140     assert(P->hasSrc() && "provenance should either be memory or zero");
9141     auto *L = cast<LoadSDNode>(P->Src.value());
9142 
9143     // All loads must share the same chain
9144     SDValue LChain = L->getChain();
9145     if (!Chain)
9146       Chain = LChain;
9147     else if (Chain != LChain)
9148       return SDValue();
9149 
9150     // Loads must share the same base address
9151     BaseIndexOffset Ptr = BaseIndexOffset::match(L, DAG);
9152     int64_t ByteOffsetFromBase = 0;
9153 
9154     // For vector loads, the expected load combine pattern will have an
9155     // ExtractElement for each index in the vector. While each of these
9156     // ExtractElements will be accessing the same base address as determined
9157     // by the load instruction, the actual bytes they interact with will differ
9158     // due to different ExtractElement indices. To accurately determine the
9159     // byte position of an ExtractElement, we offset the base load ptr with
9160     // the index multiplied by the byte size of each element in the vector.
9161     if (L->getMemoryVT().isVector()) {
9162       unsigned LoadWidthInBit = L->getMemoryVT().getScalarSizeInBits();
9163       if (LoadWidthInBit % 8 != 0)
9164         return SDValue();
9165       unsigned ByteOffsetFromVector = P->SrcOffset * LoadWidthInBit / 8;
9166       Ptr.addToOffset(ByteOffsetFromVector);
9167     }
9168 
9169     if (!Base)
9170       Base = Ptr;
9171 
9172     else if (!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
9173       return SDValue();
9174 
9175     // Calculate the offset of the current byte from the base address
9176     ByteOffsetFromBase += MemoryByteOffset(*P);
9177     ByteOffsets[i] = ByteOffsetFromBase;
9178 
9179     // Remember the first byte load
9180     if (ByteOffsetFromBase < FirstOffset) {
9181       FirstByteProvider = P;
9182       FirstOffset = ByteOffsetFromBase;
9183     }
9184 
9185     Loads.insert(L);
9186   }
9187 
9188   assert(!Loads.empty() && "All the bytes of the value must be loaded from "
9189          "memory, so there must be at least one load which produces the value");
9190   assert(Base && "Base address of the accessed memory location must be set");
9191   assert(FirstOffset != INT64_MAX && "First byte offset must be set");
9192 
9193   bool NeedsZext = ZeroExtendedBytes > 0;
9194 
9195   EVT MemVT =
9196       EVT::getIntegerVT(*DAG.getContext(), (ByteWidth - ZeroExtendedBytes) * 8);
9197 
9198   if (!MemVT.isSimple())
9199     return SDValue();
9200 
9201   // Before legalization we can introduce too-wide illegal loads which will
9202   // later be split into legal-sized loads. This lets us combine a pattern that
9203   // loads an i64 from i8 pieces into a couple of i32 loads on 32-bit targets.
9204   if (LegalOperations &&
9205       !TLI.isOperationLegal(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD,
9206                             MemVT))
9207     return SDValue();
9208 
9209   // Check if the bytes of the OR we are looking at match with either big or
9210   // little endian value load
9211   std::optional<bool> IsBigEndian = isBigEndian(
9212       ArrayRef(ByteOffsets).drop_back(ZeroExtendedBytes), FirstOffset);
9213   if (!IsBigEndian)
9214     return SDValue();
9215 
9216   assert(FirstByteProvider && "must be set");
9217 
9218   // Ensure that the first byte is loaded from offset zero of the first load,
9219   // so that the combined value can be loaded from the first load's address.
9220   if (MemoryByteOffset(*FirstByteProvider) != 0)
9221     return SDValue();
9222   auto *FirstLoad = cast<LoadSDNode>(FirstByteProvider->Src.value());
9223 
9224   // The node we are looking at matches with the pattern, check if we can
9225   // replace it with a single (possibly zero-extended) load and bswap + shift if
9226   // needed.
9227 
9228   // If the load needs byte swap check if the target supports it
9229   bool NeedsBswap = IsBigEndianTarget != *IsBigEndian;
9230 
9231   // Before legalize we can introduce illegal bswaps which will be later
9232   // converted to an explicit bswap sequence. This way we end up with a single
9233   // load and byte shuffling instead of several loads and byte shuffling.
9234   // We do not introduce illegal bswaps when zero-extending as this tends to
9235   // introduce too many arithmetic instructions.
9236   if (NeedsBswap && (LegalOperations || NeedsZext) &&
9237       !TLI.isOperationLegal(ISD::BSWAP, VT))
9238     return SDValue();
9239 
9240   // If we need to bswap and zero extend, we have to insert a shift. Check that
9241   // it is legal.
9242   if (NeedsBswap && NeedsZext && LegalOperations &&
9243       !TLI.isOperationLegal(ISD::SHL, VT))
9244     return SDValue();
9245 
9246   // Check that a load of the wide type is both allowed and fast on the target
9247   unsigned Fast = 0;
9248   bool Allowed =
9249       TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
9250                              *FirstLoad->getMemOperand(), &Fast);
9251   if (!Allowed || !Fast)
9252     return SDValue();
9253 
9254   SDValue NewLoad =
9255       DAG.getExtLoad(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, SDLoc(N), VT,
9256                      Chain, FirstLoad->getBasePtr(),
9257                      FirstLoad->getPointerInfo(), MemVT, FirstLoad->getAlign());
9258 
9259   // Transfer chain users from old loads to the new load.
9260   for (LoadSDNode *L : Loads)
9261     DAG.makeEquivalentMemoryOrdering(L, NewLoad);
9262 
9263   if (!NeedsBswap)
9264     return NewLoad;
9265 
9266   SDValue ShiftedLoad =
9267       NeedsZext
9268           ? DAG.getNode(ISD::SHL, SDLoc(N), VT, NewLoad,
9269                         DAG.getShiftAmountConstant(ZeroExtendedBytes * 8, VT,
9270                                                    SDLoc(N), LegalOperations))
9271           : NewLoad;
9272   return DAG.getNode(ISD::BSWAP, SDLoc(N), VT, ShiftedLoad);
9273 }
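// A worked example of the zero-extend path above, on a little-endian target
// with VT == i32:
//   i32 val = a[0] | (a[1] << 8)
// calculateByteProvider reports the two most significant bytes as
// isConstantZero(), so ZeroExtendedBytes == 2, MemVT becomes i16, and the
// whole expression is replaced by (zextload i16 a). If the byte order is
// additionally reversed, a SHL by ZeroExtendedBytes * 8 and a BSWAP are
// wrapped around the new load.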
9274 
9275 // If the target has andn, bsl, or a similar bit-select instruction,
9276 // we want to unfold masked merge, with canonical pattern of:
9277 //   |        A  |  |B|
9278 //   ((x ^ y) & m) ^ y
9279 //    |  D  |
9280 // Into:
9281 //   (x & m) | (y & ~m)
9282 // If y is a constant, m is not a 'not', and the 'andn' does not work with
9283 // immediates, we unfold into a different pattern:
9284 //   ~(~x & m) & (m | y)
9285 // If x is a constant, m is a 'not', and the 'andn' does not work with
9286 // immediates, we unfold into a different pattern:
9287 //   (x | ~m) & ~(~m & ~y)
9288 // NOTE: we don't unfold the pattern if 'xor' is actually a 'not', because at
9289 //       the very least that breaks andnpd / andnps patterns, and because those
9290 //       patterns are simplified in IR and shouldn't be created in the DAG
9291 SDValue DAGCombiner::unfoldMaskedMerge(SDNode *N) {
9292   assert(N->getOpcode() == ISD::XOR);
9293 
9294   // Don't touch 'not' (i.e. where y = -1).
9295   if (isAllOnesOrAllOnesSplat(N->getOperand(1)))
9296     return SDValue();
9297 
9298   EVT VT = N->getValueType(0);
9299 
9300   // There are 3 commutable operators in the pattern,
9301   // so we have to deal with 8 possible variants of the basic pattern.
9302   SDValue X, Y, M;
9303   auto matchAndXor = [&X, &Y, &M](SDValue And, unsigned XorIdx, SDValue Other) {
9304     if (And.getOpcode() != ISD::AND || !And.hasOneUse())
9305       return false;
9306     SDValue Xor = And.getOperand(XorIdx);
9307     if (Xor.getOpcode() != ISD::XOR || !Xor.hasOneUse())
9308       return false;
9309     SDValue Xor0 = Xor.getOperand(0);
9310     SDValue Xor1 = Xor.getOperand(1);
9311     // Don't touch 'not' (i.e. where y = -1).
9312     if (isAllOnesOrAllOnesSplat(Xor1))
9313       return false;
9314     if (Other == Xor0)
9315       std::swap(Xor0, Xor1);
9316     if (Other != Xor1)
9317       return false;
9318     X = Xor0;
9319     Y = Xor1;
9320     M = And.getOperand(XorIdx ? 0 : 1);
9321     return true;
9322   };
9323 
9324   SDValue N0 = N->getOperand(0);
9325   SDValue N1 = N->getOperand(1);
9326   if (!matchAndXor(N0, 0, N1) && !matchAndXor(N0, 1, N1) &&
9327       !matchAndXor(N1, 0, N0) && !matchAndXor(N1, 1, N0))
9328     return SDValue();
9329 
9330   // Don't do anything if the mask is constant. This should not be reachable:
9331   // InstCombine should have already unfolded this pattern, and DAGCombiner
9332   // probably shouldn't produce it either.
9333   if (isa<ConstantSDNode>(M.getNode()))
9334     return SDValue();
9335 
9336   // We can transform if the target has AndNot
9337   if (!TLI.hasAndNot(M))
9338     return SDValue();
9339 
9340   SDLoc DL(N);
9341 
9342   // If Y is a constant, check that 'andn' works with immediates. Unless M is
9343   // a bitwise not that would already allow ANDN to be used.
9344   if (!TLI.hasAndNot(Y) && !isBitwiseNot(M)) {
9345     assert(TLI.hasAndNot(X) && "Only mask is a variable? Unreachable.");
9346     // If not, we need to do a bit more work to make sure andn is still used.
9347     SDValue NotX = DAG.getNOT(DL, X, VT);
9348     SDValue LHS = DAG.getNode(ISD::AND, DL, VT, NotX, M);
9349     SDValue NotLHS = DAG.getNOT(DL, LHS, VT);
9350     SDValue RHS = DAG.getNode(ISD::OR, DL, VT, M, Y);
9351     return DAG.getNode(ISD::AND, DL, VT, NotLHS, RHS);
9352   }
9353 
9354   // If X is a constant and M is a bitwise not, check that 'andn' works with
9355   // immediates.
9356   if (!TLI.hasAndNot(X) && isBitwiseNot(M)) {
9357     assert(TLI.hasAndNot(Y) && "Only mask is a variable? Unreachable.");
9358     // If not, we need to do a bit more work to make sure andn is still used.
9359     SDValue NotM = M.getOperand(0);
9360     SDValue LHS = DAG.getNode(ISD::OR, DL, VT, X, NotM);
9361     SDValue NotY = DAG.getNOT(DL, Y, VT);
9362     SDValue RHS = DAG.getNode(ISD::AND, DL, VT, NotM, NotY);
9363     SDValue NotRHS = DAG.getNOT(DL, RHS, VT);
9364     return DAG.getNode(ISD::AND, DL, VT, LHS, NotRHS);
9365   }
9366 
9367   SDValue LHS = DAG.getNode(ISD::AND, DL, VT, X, M);
9368   SDValue NotM = DAG.getNOT(DL, M, VT);
9369   SDValue RHS = DAG.getNode(ISD::AND, DL, VT, Y, NotM);
9370 
9371   return DAG.getNode(ISD::OR, DL, VT, LHS, RHS);
9372 }
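// The three unfoldings above are plain bitwise identities. A minimal
// standalone check (illustrative sketch; the name is hypothetical and this
// would live in a unit test, not here):
//
//   #include <cassert>
//   #include <cstdint>
//   void checkMaskedMergeIdentities(uint32_t X, uint32_t Y, uint32_t M) {
//     uint32_t Merged = ((X ^ Y) & M) ^ Y;       // canonical masked merge
//     assert(Merged == ((X & M) | (Y & ~M)));    // generic unfold
//     assert(Merged == (~(~X & M) & (M | Y)));   // unfold when Y is constant
//     assert(Merged == ((X | ~M) & ~(~M & ~Y))); // unfold when X is constant
//   }                                            // and M is a 'not'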
9373 
9374 SDValue DAGCombiner::visitXOR(SDNode *N) {
9375   SDValue N0 = N->getOperand(0);
9376   SDValue N1 = N->getOperand(1);
9377   EVT VT = N0.getValueType();
9378   SDLoc DL(N);
9379 
9380   // fold (xor undef, undef) -> 0. This is a common idiom (misuse).
9381   if (N0.isUndef() && N1.isUndef())
9382     return DAG.getConstant(0, DL, VT);
9383 
9384   // fold (xor x, undef) -> undef
9385   if (N0.isUndef())
9386     return N0;
9387   if (N1.isUndef())
9388     return N1;
9389 
9390   // fold (xor c1, c2) -> c1^c2
9391   if (SDValue C = DAG.FoldConstantArithmetic(ISD::XOR, DL, VT, {N0, N1}))
9392     return C;
9393 
9394   // canonicalize constant to RHS
9395   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
9396       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
9397     return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
9398 
9399   // fold vector ops
9400   if (VT.isVector()) {
9401     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
9402       return FoldedVOp;
9403 
9404     // fold (xor x, 0) -> x, vector edition
9405     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
9406       return N0;
9407   }
9408 
9409   // fold (xor x, 0) -> x
9410   if (isNullConstant(N1))
9411     return N0;
9412 
9413   if (SDValue NewSel = foldBinOpIntoSelect(N))
9414     return NewSel;
9415 
9416   // reassociate xor
9417   if (SDValue RXOR = reassociateOps(ISD::XOR, DL, N0, N1, N->getFlags()))
9418     return RXOR;
9419 
9420   // Fold xor(vecreduce(x), vecreduce(y)) -> vecreduce(xor(x, y))
9421   if (SDValue SD =
9422           reassociateReduction(ISD::VECREDUCE_XOR, ISD::XOR, DL, VT, N0, N1))
9423     return SD;
9424 
9425   // fold (a^b) -> (a|b) iff a and b share no bits.
9426   if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
9427       DAG.haveNoCommonBitsSet(N0, N1))
9428     return DAG.getNode(ISD::OR, DL, VT, N0, N1);
9429 
9430   // look for 'add-like' folds:
9431   // XOR(N0,MIN_SIGNED_VALUE) == ADD(N0,MIN_SIGNED_VALUE)
9432   if ((!LegalOperations || TLI.isOperationLegal(ISD::ADD, VT)) &&
9433       isMinSignedConstant(N1))
9434     if (SDValue Combined = visitADDLike(N))
9435       return Combined;
9436 
9437   // fold !(x cc y) -> (x !cc y)
9438   unsigned N0Opcode = N0.getOpcode();
9439   SDValue LHS, RHS, CC;
9440   if (TLI.isConstTrueVal(N1) &&
9441       isSetCCEquivalent(N0, LHS, RHS, CC, /*MatchStrict*/ true)) {
9442     ISD::CondCode NotCC = ISD::getSetCCInverse(cast<CondCodeSDNode>(CC)->get(),
9443                                                LHS.getValueType());
9444     if (!LegalOperations ||
9445         TLI.isCondCodeLegal(NotCC, LHS.getSimpleValueType())) {
9446       switch (N0Opcode) {
9447       default:
9448         llvm_unreachable("Unhandled SetCC Equivalent!");
9449       case ISD::SETCC:
9450         return DAG.getSetCC(SDLoc(N0), VT, LHS, RHS, NotCC);
9451       case ISD::SELECT_CC:
9452         return DAG.getSelectCC(SDLoc(N0), LHS, RHS, N0.getOperand(2),
9453                                N0.getOperand(3), NotCC);
9454       case ISD::STRICT_FSETCC:
9455       case ISD::STRICT_FSETCCS: {
9456         if (N0.hasOneUse()) {
9457           // FIXME Can we handle multiple uses? Could we token factor the chain
9458           // results from the new/old setcc?
9459           SDValue SetCC =
9460               DAG.getSetCC(SDLoc(N0), VT, LHS, RHS, NotCC,
9461                            N0.getOperand(0), N0Opcode == ISD::STRICT_FSETCCS);
9462           CombineTo(N, SetCC);
9463           DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), SetCC.getValue(1));
9464           recursivelyDeleteUnusedNodes(N0.getNode());
9465           return SDValue(N, 0); // Return N so it doesn't get rechecked!
9466         }
9467         break;
9468       }
9469       }
9470     }
9471   }
9472 
9473   // fold (not (zext (setcc x, y))) -> (zext (not (setcc x, y)))
9474   if (isOneConstant(N1) && N0Opcode == ISD::ZERO_EXTEND && N0.hasOneUse() &&
9475       isSetCCEquivalent(N0.getOperand(0), LHS, RHS, CC)){
9476     SDValue V = N0.getOperand(0);
9477     SDLoc DL0(N0);
9478     V = DAG.getNode(ISD::XOR, DL0, V.getValueType(), V,
9479                     DAG.getConstant(1, DL0, V.getValueType()));
9480     AddToWorklist(V.getNode());
9481     return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, V);
9482   }
9483 
9484   // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are setcc
9485   if (isOneConstant(N1) && VT == MVT::i1 && N0.hasOneUse() &&
9486       (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
9487     SDValue N00 = N0.getOperand(0), N01 = N0.getOperand(1);
9488     if (isOneUseSetCC(N01) || isOneUseSetCC(N00)) {
9489       unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
9490       N00 = DAG.getNode(ISD::XOR, SDLoc(N00), VT, N00, N1); // N00 = ~N00
9491       N01 = DAG.getNode(ISD::XOR, SDLoc(N01), VT, N01, N1); // N01 = ~N01
9492       AddToWorklist(N00.getNode()); AddToWorklist(N01.getNode());
9493       return DAG.getNode(NewOpcode, DL, VT, N00, N01);
9494     }
9495   }
9496   // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are constants
9497   if (isAllOnesConstant(N1) && N0.hasOneUse() &&
9498       (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
9499     SDValue N00 = N0.getOperand(0), N01 = N0.getOperand(1);
9500     if (isa<ConstantSDNode>(N01) || isa<ConstantSDNode>(N00)) {
9501       unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
9502       N00 = DAG.getNode(ISD::XOR, SDLoc(N00), VT, N00, N1); // N00 = ~N00
9503       N01 = DAG.getNode(ISD::XOR, SDLoc(N01), VT, N01, N1); // N01 = ~N01
9504       AddToWorklist(N00.getNode()); AddToWorklist(N01.getNode());
9505       return DAG.getNode(NewOpcode, DL, VT, N00, N01);
9506     }
9507   }
9508 
9509   // fold (not (neg x)) -> (add X, -1)
9510   // FIXME: This can be generalized to (not (sub Y, X)) -> (add X, ~Y) if
9511   // Y is a constant or the subtract has a single use.
9512   if (isAllOnesConstant(N1) && N0.getOpcode() == ISD::SUB &&
9513       isNullConstant(N0.getOperand(0))) {
9514     return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1),
9515                        DAG.getAllOnesConstant(DL, VT));
9516   }
9517 
9518   // fold (not (add X, -1)) -> (neg X)
9519   if (isAllOnesConstant(N1) && N0.getOpcode() == ISD::ADD &&
9520       isAllOnesOrAllOnesSplat(N0.getOperand(1))) {
9521     return DAG.getNegative(N0.getOperand(0), DL, VT);
9522   }
9523 
9524   // fold (xor (and x, y), y) -> (and (not x), y)
9525   if (N0Opcode == ISD::AND && N0.hasOneUse() && N0->getOperand(1) == N1) {
9526     SDValue X = N0.getOperand(0);
9527     SDValue NotX = DAG.getNOT(SDLoc(X), X, VT);
9528     AddToWorklist(NotX.getNode());
9529     return DAG.getNode(ISD::AND, DL, VT, NotX, N1);
9530   }
9531 
9532   // fold Y = sra (X, size(X)-1); xor (add (X, Y), Y) -> (abs X)
9533   if (TLI.isOperationLegalOrCustom(ISD::ABS, VT)) {
9534     SDValue A = N0Opcode == ISD::ADD ? N0 : N1;
9535     SDValue S = N0Opcode == ISD::SRA ? N0 : N1;
9536     if (A.getOpcode() == ISD::ADD && S.getOpcode() == ISD::SRA) {
9537       SDValue A0 = A.getOperand(0), A1 = A.getOperand(1);
9538       SDValue S0 = S.getOperand(0);
9539       if ((A0 == S && A1 == S0) || (A1 == S && A0 == S0))
9540         if (ConstantSDNode *C = isConstOrConstSplat(S.getOperand(1)))
9541           if (C->getAPIntValue() == (VT.getScalarSizeInBits() - 1))
9542             return DAG.getNode(ISD::ABS, DL, VT, S0);
9543     }
9544   }
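  // For example, on i32 with Y = sra(X, 31): Y is 0 when X >= 0 and -1
  // otherwise, so (add(X, Y)) ^ Y computes X ^ 0 == X or (X - 1) ^ -1 == -X,
  // i.e. the classic branchless abs idiom recognized above.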
9545 
9546   // fold (xor x, x) -> 0
9547   if (N0 == N1)
9548     return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
9549 
9550   // fold (xor (shl 1, x), -1) -> (rotl ~1, x)
9551   // Here is a concrete example of this equivalence:
9552   // i16   x ==  14
9553   // i16 shl ==   1 << 14  == 16384 == 0b0100000000000000
9554   // i16 xor == ~(1 << 14) == 49151 == 0b1011111111111111
9555   //
9556   // =>
9557   //
9558   // i16     ~1      == 0b1111111111111110
9559   // i16 rol(~1, 14) == 0b1011111111111111
9560   //
9561   // Some additional tips to help conceptualize this transform:
9562   // - Try to see the operation as placing a single zero in a value of all ones.
9563   // - There exists no value for x which would allow the result to contain zero.
9564   // - Values of x larger than the bitwidth are undefined and do not require a
9565   //   consistent result.
9566   // - Pushing the zero left requires shifting one-bits in from the right.
9567   // A rotate left of ~1 is a nice way of achieving the desired result.
9568   if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT) && N0Opcode == ISD::SHL &&
9569       isAllOnesConstant(N1) && isOneConstant(N0.getOperand(0))) {
9570     return DAG.getNode(ISD::ROTL, DL, VT, DAG.getConstant(~1, DL, VT),
9571                        N0.getOperand(1));
9572   }
9573 
9574   // Simplify: xor (op x...), (op y...)  -> (op (xor x, y))
9575   if (N0Opcode == N1.getOpcode())
9576     if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
9577       return V;
9578 
9579   if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG))
9580     return R;
9581   if (SDValue R = foldLogicOfShifts(N, N1, N0, DAG))
9582     return R;
9583   if (SDValue R = foldLogicTreeOfShifts(N, N0, N1, DAG))
9584     return R;
9585 
9586   // Unfold  ((x ^ y) & m) ^ y  into  (x & m) | (y & ~m)  if profitable
9587   if (SDValue MM = unfoldMaskedMerge(N))
9588     return MM;
9589 
9590   // Simplify the expression using non-local knowledge.
9591   if (SimplifyDemandedBits(SDValue(N, 0)))
9592     return SDValue(N, 0);
9593 
9594   if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
9595     return Combined;
9596 
9597   return SDValue();
9598 }
9599 
9600 /// If we have a shift-by-constant of a bitwise logic op that itself has a
9601 /// shift-by-constant operand with identical opcode, we may be able to convert
9602 /// that into 2 independent shifts followed by the logic op. This is a
9603 /// throughput improvement.
9604 static SDValue combineShiftOfShiftedLogic(SDNode *Shift, SelectionDAG &DAG) {
9605   // Match a one-use bitwise logic op.
9606   SDValue LogicOp = Shift->getOperand(0);
9607   if (!LogicOp.hasOneUse())
9608     return SDValue();
9609 
9610   unsigned LogicOpcode = LogicOp.getOpcode();
9611   if (LogicOpcode != ISD::AND && LogicOpcode != ISD::OR &&
9612       LogicOpcode != ISD::XOR)
9613     return SDValue();
9614 
9615   // Find a matching one-use shift by constant.
9616   unsigned ShiftOpcode = Shift->getOpcode();
9617   SDValue C1 = Shift->getOperand(1);
9618   ConstantSDNode *C1Node = isConstOrConstSplat(C1);
9619   assert(C1Node && "Expected a shift with constant operand");
9620   const APInt &C1Val = C1Node->getAPIntValue();
9621   auto matchFirstShift = [&](SDValue V, SDValue &ShiftOp,
9622                              const APInt *&ShiftAmtVal) {
9623     if (V.getOpcode() != ShiftOpcode || !V.hasOneUse())
9624       return false;
9625 
9626     ConstantSDNode *ShiftCNode = isConstOrConstSplat(V.getOperand(1));
9627     if (!ShiftCNode)
9628       return false;
9629 
9630     // Capture the shifted operand and shift amount value.
9631     ShiftOp = V.getOperand(0);
9632     ShiftAmtVal = &ShiftCNode->getAPIntValue();
9633 
9634     // Shift amount types do not have to match their operand type, so check that
9635     // the constants are the same width.
9636     if (ShiftAmtVal->getBitWidth() != C1Val.getBitWidth())
9637       return false;
9638 
9639     // The fold is not valid if the sum of the shift values doesn't fit in the
9640     // given shift amount type.
9641     bool Overflow = false;
9642     APInt NewShiftAmt = C1Val.uadd_ov(*ShiftAmtVal, Overflow);
9643     if (Overflow)
9644       return false;
9645 
9646     // The fold is not valid if the sum of the shift values exceeds bitwidth.
9647     if (NewShiftAmt.uge(V.getScalarValueSizeInBits()))
9648       return false;
9649 
9650     return true;
9651   };
9652 
9653   // Logic ops are commutative, so check each operand for a match.
9654   SDValue X, Y;
9655   const APInt *C0Val;
9656   if (matchFirstShift(LogicOp.getOperand(0), X, C0Val))
9657     Y = LogicOp.getOperand(1);
9658   else if (matchFirstShift(LogicOp.getOperand(1), X, C0Val))
9659     Y = LogicOp.getOperand(0);
9660   else
9661     return SDValue();
9662 
9663   // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1)
9664   SDLoc DL(Shift);
9665   EVT VT = Shift->getValueType(0);
9666   EVT ShiftAmtVT = Shift->getOperand(1).getValueType();
9667   SDValue ShiftSumC = DAG.getConstant(*C0Val + C1Val, DL, ShiftAmtVT);
9668   SDValue NewShift1 = DAG.getNode(ShiftOpcode, DL, VT, X, ShiftSumC);
9669   SDValue NewShift2 = DAG.getNode(ShiftOpcode, DL, VT, Y, C1);
9670   return DAG.getNode(LogicOpcode, DL, VT, NewShift1, NewShift2);
9671 }
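// For example, assuming the one-use constraints hold:
//   srl (or (srl X, 2), Y), 3 --> or (srl X, 5), (srl Y, 3)
// The two shifts in the result are independent of each other, so they can
// execute in parallel; that is the throughput improvement referred to above.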
9672 
9673 /// Handle transforms common to the three shifts, when the shift amount is a
9674 /// constant.
9675 /// We are looking for: (shift being one of shl/sra/srl)
9676 ///   shift (binop X, C0), C1
9677 /// And want to transform into:
9678 ///   binop (shift X, C1), (shift C0, C1)
9679 SDValue DAGCombiner::visitShiftByConstant(SDNode *N) {
9680   assert(isConstOrConstSplat(N->getOperand(1)) && "Expected constant operand");
9681 
9682   // Do not turn a 'not' into a regular xor.
9683   if (isBitwiseNot(N->getOperand(0)))
9684     return SDValue();
9685 
9686   // The inner binop must be one-use, since we want to replace it.
9687   SDValue LHS = N->getOperand(0);
9688   if (!LHS.hasOneUse() || !TLI.isDesirableToCommuteWithShift(N, Level))
9689     return SDValue();
9690 
9691   // Fold shift(bitop(shift(x,c1),y), c2) -> bitop(shift(x,c1+c2),shift(y,c2)).
9692   if (SDValue R = combineShiftOfShiftedLogic(N, DAG))
9693     return R;
9694 
9695   // We want to pull some binops through shifts, so that we have (and (shift))
9696   // instead of (shift (and)), likewise for add, or, xor, etc.  This sort of
9697   // thing happens with address calculations, so it's important to canonicalize
9698   // it.
9699   switch (LHS.getOpcode()) {
9700   default:
9701     return SDValue();
9702   case ISD::OR:
9703   case ISD::XOR:
9704   case ISD::AND:
9705     break;
9706   case ISD::ADD:
9707     if (N->getOpcode() != ISD::SHL)
9708       return SDValue(); // only shl(add) not sr[al](add).
9709     break;
9710   }
9711 
9712   // FIXME: disable this unless the input to the binop is a shift by a constant
9713   // or is copy/select. Enable this in other cases once we figure out when it
9714   // is exactly profitable.
9715   SDValue BinOpLHSVal = LHS.getOperand(0);
9716   bool IsShiftByConstant = (BinOpLHSVal.getOpcode() == ISD::SHL ||
9717                             BinOpLHSVal.getOpcode() == ISD::SRA ||
9718                             BinOpLHSVal.getOpcode() == ISD::SRL) &&
9719                            isa<ConstantSDNode>(BinOpLHSVal.getOperand(1));
9720   bool IsCopyOrSelect = BinOpLHSVal.getOpcode() == ISD::CopyFromReg ||
9721                         BinOpLHSVal.getOpcode() == ISD::SELECT;
9722 
9723   if (!IsShiftByConstant && !IsCopyOrSelect)
9724     return SDValue();
9725 
9726   if (IsCopyOrSelect && N->hasOneUse())
9727     return SDValue();
9728 
9729   // Attempt to fold the constants, shifting the binop RHS by the shift amount.
9730   SDLoc DL(N);
9731   EVT VT = N->getValueType(0);
9732   if (SDValue NewRHS = DAG.FoldConstantArithmetic(
9733           N->getOpcode(), DL, VT, {LHS.getOperand(1), N->getOperand(1)})) {
9734     SDValue NewShift = DAG.getNode(N->getOpcode(), DL, VT, LHS.getOperand(0),
9735                                    N->getOperand(1));
9736     return DAG.getNode(LHS.getOpcode(), DL, VT, NewShift, NewRHS);
9737   }
9738 
9739   return SDValue();
9740 }
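// For example, with an inner shift-by-constant feeding the binop (and
// assuming the target profitability hooks allow the commute):
//   shl (and (srl X, 4), 0xFF), 2 --> and (shl (srl X, 4), 2), 0x3FC
// The binop constant is shifted by the outer amount (0xFF << 2 == 0x3FC)
// while the shift itself is pulled inside the AND.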
9741 
9742 SDValue DAGCombiner::distributeTruncateThroughAnd(SDNode *N) {
9743   assert(N->getOpcode() == ISD::TRUNCATE);
9744   assert(N->getOperand(0).getOpcode() == ISD::AND);
9745 
9746   // (truncate:TruncVT (and N00, N01C)) -> (and (truncate:TruncVT N00), TruncC)
9747   EVT TruncVT = N->getValueType(0);
9748   if (N->hasOneUse() && N->getOperand(0).hasOneUse() &&
9749       TLI.isTypeDesirableForOp(ISD::AND, TruncVT)) {
9750     SDValue N01 = N->getOperand(0).getOperand(1);
9751     if (isConstantOrConstantVector(N01, /* NoOpaques */ true)) {
9752       SDLoc DL(N);
9753       SDValue N00 = N->getOperand(0).getOperand(0);
9754       SDValue Trunc00 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N00);
9755       SDValue Trunc01 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N01);
9756       AddToWorklist(Trunc00.getNode());
9757       AddToWorklist(Trunc01.getNode());
9758       return DAG.getNode(ISD::AND, DL, TruncVT, Trunc00, Trunc01);
9759     }
9760   }
9761 
9762   return SDValue();
9763 }
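// For example, truncating i32 -> i8:
//   (truncate:i8 (and X, 0x1F0)) --> (and (truncate:i8 X), 0xF0)
// Truncation only drops high bits, so it distributes over AND and the
// constant mask is truncated alongside the variable operand.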
9764 
9765 SDValue DAGCombiner::visitRotate(SDNode *N) {
9766   SDLoc dl(N);
9767   SDValue N0 = N->getOperand(0);
9768   SDValue N1 = N->getOperand(1);
9769   EVT VT = N->getValueType(0);
9770   unsigned Bitsize = VT.getScalarSizeInBits();
9771 
9772   // fold (rot x, 0) -> x
9773   if (isNullOrNullSplat(N1))
9774     return N0;
9775 
9776   // fold (rot x, c) -> x iff (c % BitSize) == 0
9777   if (isPowerOf2_32(Bitsize) && Bitsize > 1) {
9778     APInt ModuloMask(N1.getScalarValueSizeInBits(), Bitsize - 1);
9779     if (DAG.MaskedValueIsZero(N1, ModuloMask))
9780       return N0;
9781   }
9782 
9783   // fold (rot x, c) -> (rot x, c % BitSize)
9784   bool OutOfRange = false;
9785   auto MatchOutOfRange = [Bitsize, &OutOfRange](ConstantSDNode *C) {
9786     OutOfRange |= C->getAPIntValue().uge(Bitsize);
9787     return true;
9788   };
9789   if (ISD::matchUnaryPredicate(N1, MatchOutOfRange) && OutOfRange) {
9790     EVT AmtVT = N1.getValueType();
9791     SDValue Bits = DAG.getConstant(Bitsize, dl, AmtVT);
9792     if (SDValue Amt =
9793             DAG.FoldConstantArithmetic(ISD::UREM, dl, AmtVT, {N1, Bits}))
9794       return DAG.getNode(N->getOpcode(), dl, VT, N0, Amt);
9795   }
9796 
9797   // rot i16 X, 8 --> bswap X
9798   auto *RotAmtC = isConstOrConstSplat(N1);
9799   if (RotAmtC && RotAmtC->getAPIntValue() == 8 &&
9800       VT.getScalarSizeInBits() == 16 && hasOperation(ISD::BSWAP, VT))
9801     return DAG.getNode(ISD::BSWAP, dl, VT, N0);
9802 
9803   // Simplify the operands using demanded-bits information.
9804   if (SimplifyDemandedBits(SDValue(N, 0)))
9805     return SDValue(N, 0);
9806 
9807   // fold (rot* x, (trunc (and y, c))) -> (rot* x, (and (trunc y), (trunc c))).
9808   if (N1.getOpcode() == ISD::TRUNCATE &&
9809       N1.getOperand(0).getOpcode() == ISD::AND) {
9810     if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
9811       return DAG.getNode(N->getOpcode(), dl, VT, N0, NewOp1);
9812   }
9813 
9814   unsigned NextOp = N0.getOpcode();
9815 
9816   // fold (rot* (rot* x, c2), c1)
9817   //   -> (rot* x, ((c1 % bitsize) +- (c2 % bitsize) + bitsize) % bitsize)
9818   if (NextOp == ISD::ROTL || NextOp == ISD::ROTR) {
9819     SDNode *C1 = DAG.isConstantIntBuildVectorOrConstantInt(N1);
9820     SDNode *C2 = DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1));
9821     if (C1 && C2 && C1->getValueType(0) == C2->getValueType(0)) {
9822       EVT ShiftVT = C1->getValueType(0);
9823       bool SameSide = (N->getOpcode() == NextOp);
9824       unsigned CombineOp = SameSide ? ISD::ADD : ISD::SUB;
9825       SDValue BitsizeC = DAG.getConstant(Bitsize, dl, ShiftVT);
9826       SDValue Norm1 = DAG.FoldConstantArithmetic(ISD::UREM, dl, ShiftVT,
9827                                                  {N1, BitsizeC});
9828       SDValue Norm2 = DAG.FoldConstantArithmetic(ISD::UREM, dl, ShiftVT,
9829                                                  {N0.getOperand(1), BitsizeC});
9830       if (Norm1 && Norm2)
9831         if (SDValue CombinedShift = DAG.FoldConstantArithmetic(
9832                 CombineOp, dl, ShiftVT, {Norm1, Norm2})) {
9833           CombinedShift = DAG.FoldConstantArithmetic(ISD::ADD, dl, ShiftVT,
9834                                                      {CombinedShift, BitsizeC});
9835           SDValue CombinedShiftNorm = DAG.FoldConstantArithmetic(
9836               ISD::UREM, dl, ShiftVT, {CombinedShift, BitsizeC});
9837           return DAG.getNode(N->getOpcode(), dl, VT, N0->getOperand(0),
9838                              CombinedShiftNorm);
9839         }
9840     }
9841   }
9842   return SDValue();
9843 }
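// For example, composing constant rotates on i16:
//   rotl (rotl X, 5), 9 --> rotl X, 14  // same direction: amounts add
//   rotl (rotr X, 3), 5 --> rotl X, 2   // opposite directions: amounts
//                                       // subtract, normalized mod bitsize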
9844 
9845 SDValue DAGCombiner::visitSHL(SDNode *N) {
9846   SDValue N0 = N->getOperand(0);
9847   SDValue N1 = N->getOperand(1);
9848   if (SDValue V = DAG.simplifyShift(N0, N1))
9849     return V;
9850 
9851   EVT VT = N0.getValueType();
9852   EVT ShiftVT = N1.getValueType();
9853   unsigned OpSizeInBits = VT.getScalarSizeInBits();
9854 
9855   // fold (shl c1, c2) -> c1<<c2
9856   if (SDValue C = DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N), VT, {N0, N1}))
9857     return C;
9858 
9859   // fold vector ops
9860   if (VT.isVector()) {
9861     if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N)))
9862       return FoldedVOp;
9863 
9864     BuildVectorSDNode *N1CV = dyn_cast<BuildVectorSDNode>(N1);
9865     // If setcc produces all-one true value then:
9866     // (shl (and (setcc) N01CV) N1CV) -> (and (setcc) N01CV<<N1CV)
9867     if (N1CV && N1CV->isConstant()) {
9868       if (N0.getOpcode() == ISD::AND) {
9869         SDValue N00 = N0->getOperand(0);
9870         SDValue N01 = N0->getOperand(1);
9871         BuildVectorSDNode *N01CV = dyn_cast<BuildVectorSDNode>(N01);
9872 
9873         if (N01CV && N01CV->isConstant() && N00.getOpcode() == ISD::SETCC &&
9874             TLI.getBooleanContents(N00.getOperand(0).getValueType()) ==
9875                 TargetLowering::ZeroOrNegativeOneBooleanContent) {
9876           if (SDValue C =
9877                   DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N), VT, {N01, N1}))
9878             return DAG.getNode(ISD::AND, SDLoc(N), VT, N00, C);
9879         }
9880       }
9881     }
9882   }
9883 
9884   if (SDValue NewSel = foldBinOpIntoSelect(N))
9885     return NewSel;
9886 
9887   // if (shl x, c) is known to be zero, return 0
9888   if (DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnes(OpSizeInBits)))
9889     return DAG.getConstant(0, SDLoc(N), VT);
9890 
9891   // fold (shl x, (trunc (and y, c))) -> (shl x, (and (trunc y), (trunc c))).
9892   if (N1.getOpcode() == ISD::TRUNCATE &&
9893       N1.getOperand(0).getOpcode() == ISD::AND) {
9894     if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
9895       return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, NewOp1);
9896   }
9897 
9898   // fold (shl (shl x, c1), c2) -> 0 or (shl x, (add c1, c2))
9899   if (N0.getOpcode() == ISD::SHL) {
9900     auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
9901                                           ConstantSDNode *RHS) {
9902       APInt c1 = LHS->getAPIntValue();
9903       APInt c2 = RHS->getAPIntValue();
9904       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
9905       return (c1 + c2).uge(OpSizeInBits);
9906     };
9907     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
9908       return DAG.getConstant(0, SDLoc(N), VT);
9909 
9910     auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
9911                                        ConstantSDNode *RHS) {
9912       APInt c1 = LHS->getAPIntValue();
9913       APInt c2 = RHS->getAPIntValue();
9914       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
9915       return (c1 + c2).ult(OpSizeInBits);
9916     };
9917     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
9918       SDLoc DL(N);
9919       SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
9920       return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Sum);
9921     }
9922   }
9923 
9924   // fold (shl (ext (shl x, c1)), c2) -> (shl (ext x), (add c1, c2))
9925   // For this to be valid, the second form must not preserve any of the bits
9926   // that are shifted out by the inner shift in the first form.  This means
9927   // the outer shift size must be >= the number of bits added by the ext.
9928   // As a corollary, we don't care what kind of ext it is.
9929   if ((N0.getOpcode() == ISD::ZERO_EXTEND ||
9930        N0.getOpcode() == ISD::ANY_EXTEND ||
9931        N0.getOpcode() == ISD::SIGN_EXTEND) &&
9932       N0.getOperand(0).getOpcode() == ISD::SHL) {
9933     SDValue N0Op0 = N0.getOperand(0);
9934     SDValue InnerShiftAmt = N0Op0.getOperand(1);
9935     EVT InnerVT = N0Op0.getValueType();
9936     uint64_t InnerBitwidth = InnerVT.getScalarSizeInBits();
9937 
9938     auto MatchOutOfRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
9939                                                          ConstantSDNode *RHS) {
9940       APInt c1 = LHS->getAPIntValue();
9941       APInt c2 = RHS->getAPIntValue();
9942       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
9943       return c2.uge(OpSizeInBits - InnerBitwidth) &&
9944              (c1 + c2).uge(OpSizeInBits);
9945     };
9946     if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchOutOfRange,
9947                                   /*AllowUndefs*/ false,
9948                                   /*AllowTypeMismatch*/ true))
9949       return DAG.getConstant(0, SDLoc(N), VT);
9950 
9951     auto MatchInRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
9952                                                       ConstantSDNode *RHS) {
9953       APInt c1 = LHS->getAPIntValue();
9954       APInt c2 = RHS->getAPIntValue();
9955       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
9956       return c2.uge(OpSizeInBits - InnerBitwidth) &&
9957              (c1 + c2).ult(OpSizeInBits);
9958     };
9959     if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchInRange,
9960                                   /*AllowUndefs*/ false,
9961                                   /*AllowTypeMismatch*/ true)) {
9962       SDLoc DL(N);
9963       SDValue Ext = DAG.getNode(N0.getOpcode(), DL, VT, N0Op0.getOperand(0));
9964       SDValue Sum = DAG.getZExtOrTrunc(InnerShiftAmt, DL, ShiftVT);
9965       Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, Sum, N1);
9966       return DAG.getNode(ISD::SHL, DL, VT, Ext, Sum);
9967     }
9968   }
9969 
9970   // fold (shl (zext (srl x, C)), C) -> (zext (shl (srl x, C), C))
9971   // Only fold this if the inner zext has no other uses to avoid increasing
9972   // the total number of instructions.
9973   if (N0.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse() &&
9974       N0.getOperand(0).getOpcode() == ISD::SRL) {
9975     SDValue N0Op0 = N0.getOperand(0);
9976     SDValue InnerShiftAmt = N0Op0.getOperand(1);
9977 
9978     auto MatchEqual = [VT](ConstantSDNode *LHS, ConstantSDNode *RHS) {
9979       APInt c1 = LHS->getAPIntValue();
9980       APInt c2 = RHS->getAPIntValue();
9981       zeroExtendToMatch(c1, c2);
9982       return c1.ult(VT.getScalarSizeInBits()) && (c1 == c2);
9983     };
9984     if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchEqual,
9985                                   /*AllowUndefs*/ false,
9986                                   /*AllowTypeMismatch*/ true)) {
9987       SDLoc DL(N);
9988       EVT InnerShiftAmtVT = N0Op0.getOperand(1).getValueType();
9989       SDValue NewSHL = DAG.getZExtOrTrunc(N1, DL, InnerShiftAmtVT);
9990       NewSHL = DAG.getNode(ISD::SHL, DL, N0Op0.getValueType(), N0Op0, NewSHL);
9991       AddToWorklist(NewSHL.getNode());
9992       return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N0), VT, NewSHL);
9993     }
9994   }
9995 
9996   if (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SRA) {
9997     auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
9998                                            ConstantSDNode *RHS) {
9999       const APInt &LHSC = LHS->getAPIntValue();
10000       const APInt &RHSC = RHS->getAPIntValue();
10001       return LHSC.ult(OpSizeInBits) && RHSC.ult(OpSizeInBits) &&
10002              LHSC.getZExtValue() <= RHSC.getZExtValue();
10003     };
10004 
10005     SDLoc DL(N);
10006 
10007     // fold (shl (sr[la] exact X,  C1), C2) -> (shl    X, (C2-C1)) if C1 <= C2
10008     // fold (shl (sr[la] exact X,  C1), C2) -> (sr[la] X, (C2-C1)) if C1 >= C2
10009     if (N0->getFlags().hasExact()) {
10010       if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount,
10011                                     /*AllowUndefs*/ false,
10012                                     /*AllowTypeMismatch*/ true)) {
10013         SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
10014         SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01);
10015         return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff);
10016       }
10017       if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
10018                                     /*AllowUndefs*/ false,
10019                                     /*AllowTypeMismatch*/ true)) {
10020         SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
10021         SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1);
10022         return DAG.getNode(N0.getOpcode(), DL, VT, N0.getOperand(0), Diff);
10023       }
10024     }
10025 
10026     // fold (shl (srl x, c1), c2) -> (and (shl x, (sub c2, c1)), MASK) or
10027     //                               (and (srl x, (sub c1, c2)), MASK)
10028     // Only fold this if the inner shift has no other uses -- if it does,
10029     // folding this will increase the total number of instructions.
10030     if (N0.getOpcode() == ISD::SRL &&
10031         (N0.getOperand(1) == N1 || N0.hasOneUse()) &&
10032         TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
10033       if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
10034                                     /*AllowUndefs*/ false,
10035                                     /*AllowTypeMismatch*/ true)) {
10036         SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
10037         SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1);
10038         SDValue Mask = DAG.getAllOnesConstant(DL, VT);
10039         Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, N01);
10040         Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, Diff);
10041         SDValue Shift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Diff);
10042         return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
10043       }
10044       if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount,
10045                                     /*AllowUndefs*/ false,
10046                                     /*AllowTypeMismatch*/ true)) {
10047         SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
10048         SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01);
10049         SDValue Mask = DAG.getAllOnesConstant(DL, VT);
10050         Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, N1);
10051         SDValue Shift = DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff);
10052         return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
10053       }
10054     }
10055   }
10056 
10057   // fold (shl (sra x, c1), c1) -> (and x, (shl -1, c1))
10058   if (N0.getOpcode() == ISD::SRA && N1 == N0.getOperand(1) &&
10059       isConstantOrConstantVector(N1, /* No Opaques */ true)) {
10060     SDLoc DL(N);
10061     SDValue AllBits = DAG.getAllOnesConstant(DL, VT);
10062     SDValue HiBitsMask = DAG.getNode(ISD::SHL, DL, VT, AllBits, N1);
10063     return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), HiBitsMask);
10064   }
10065 
10066   // fold (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2)
10067   // fold (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2)
10068   // This is a variant of the fold done on multiply, except that a mul by a
10069   // power of 2 has been turned into a shift.
10070   if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::OR) &&
10071       N0->hasOneUse() && TLI.isDesirableToCommuteWithShift(N, Level)) {
10072     SDValue N01 = N0.getOperand(1);
10073     if (SDValue Shl1 =
10074             DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N1), VT, {N01, N1})) {
10075       SDValue Shl0 = DAG.getNode(ISD::SHL, SDLoc(N0), VT, N0.getOperand(0), N1);
10076       AddToWorklist(Shl0.getNode());
10077       SDNodeFlags Flags;
10078       // Preserve the disjoint flag for Or.
10079       if (N0.getOpcode() == ISD::OR && N0->getFlags().hasDisjoint())
10080         Flags.setDisjoint(true);
10081       return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, Shl0, Shl1, Flags);
10082     }
10083   }
10084 
10085   // fold (shl (sext (add_nsw x, c1)), c2) -> (add (shl (sext x), c2), c1 << c2)
10086   // TODO: Add zext/add_nuw variant with suitable test coverage
10087   // TODO: Should we limit this with isLegalAddImmediate?
10088   if (N0.getOpcode() == ISD::SIGN_EXTEND &&
10089       N0.getOperand(0).getOpcode() == ISD::ADD &&
10090       N0.getOperand(0)->getFlags().hasNoSignedWrap() && N0->hasOneUse() &&
10091       N0.getOperand(0)->hasOneUse() &&
10092       TLI.isDesirableToCommuteWithShift(N, Level)) {
10093     SDValue Add = N0.getOperand(0);
10094     SDLoc DL(N0);
10095     if (SDValue ExtC = DAG.FoldConstantArithmetic(N0.getOpcode(), DL, VT,
10096                                                   {Add.getOperand(1)})) {
10097       if (SDValue ShlC =
10098               DAG.FoldConstantArithmetic(ISD::SHL, DL, VT, {ExtC, N1})) {
10099         SDValue ExtX = DAG.getNode(N0.getOpcode(), DL, VT, Add.getOperand(0));
10100         SDValue ShlX = DAG.getNode(ISD::SHL, DL, VT, ExtX, N1);
10101         return DAG.getNode(ISD::ADD, DL, VT, ShlX, ShlC);
10102       }
10103     }
10104   }
10105 
10106   // fold (shl (mul x, c1), c2) -> (mul x, c1 << c2)
10107   if (N0.getOpcode() == ISD::MUL && N0->hasOneUse()) {
10108     SDValue N01 = N0.getOperand(1);
10109     if (SDValue Shl =
10110             DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N1), VT, {N01, N1}))
10111       return DAG.getNode(ISD::MUL, SDLoc(N), VT, N0.getOperand(0), Shl);
10112   }
10113 
10114   ConstantSDNode *N1C = isConstOrConstSplat(N1);
10115   if (N1C && !N1C->isOpaque())
10116     if (SDValue NewSHL = visitShiftByConstant(N))
10117       return NewSHL;
10118 
10119   if (SimplifyDemandedBits(SDValue(N, 0)))
10120     return SDValue(N, 0);
10121 
10122   // Fold (shl (vscale * C0), C1) to (vscale * (C0 << C1)).
10123   if (N0.getOpcode() == ISD::VSCALE && N1C) {
10124     const APInt &C0 = N0.getConstantOperandAPInt(0);
10125     const APInt &C1 = N1C->getAPIntValue();
10126     return DAG.getVScale(SDLoc(N), VT, C0 << C1);
10127   }
10128 
10129   // Fold (shl step_vector(C0), C1) to (step_vector(C0 << C1)).
10130   APInt ShlVal;
10131   if (N0.getOpcode() == ISD::STEP_VECTOR &&
10132       ISD::isConstantSplatVector(N1.getNode(), ShlVal)) {
10133     const APInt &C0 = N0.getConstantOperandAPInt(0);
10134     if (ShlVal.ult(C0.getBitWidth())) {
10135       APInt NewStep = C0 << ShlVal;
10136       return DAG.getStepVector(SDLoc(N), VT, NewStep);
10137     }
10138   }
10139 
10140   return SDValue();
10141 }
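// For example, the shl-of-shl folds in visitSHL give:
//   shl (shl X, 3), 4   --> shl X, 7  (i32: 3 + 4 < 32)
//   shl (shl X, 20), 15 --> 0         (i32: 20 + 15 >= 32, every bit of X
//                                      is shifted out)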
10142 
10143 // Transform a right shift of a multiply into a multiply-high.
10144 // Examples:
10145 // (srl (mul (zext i32:$a to i64), (zext i32:$b to i64)), 32) -> (mulhu $a, $b)
10146 // (sra (mul (sext i32:$a to i64), (sext i32:$b to i64)), 32) -> (mulhs $a, $b)
10147 static SDValue combineShiftToMULH(SDNode *N, SelectionDAG &DAG,
10148                                   const TargetLowering &TLI) {
10149   assert((N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) &&
10150          "SRL or SRA node is required here!");
10151 
10152   // Check the shift amount. Proceed with the transformation if the shift
10153   // amount is constant.
10154   ConstantSDNode *ShiftAmtSrc = isConstOrConstSplat(N->getOperand(1));
10155   if (!ShiftAmtSrc)
10156     return SDValue();
10157 
10158   SDLoc DL(N);
10159 
10160   // The operation feeding into the shift must be a multiply.
10161   SDValue ShiftOperand = N->getOperand(0);
10162   if (ShiftOperand.getOpcode() != ISD::MUL)
10163     return SDValue();
10164 
10165   // Both operands must be equivalent extend nodes.
10166   SDValue LeftOp = ShiftOperand.getOperand(0);
10167   SDValue RightOp = ShiftOperand.getOperand(1);
10168 
10169   bool IsSignExt = LeftOp.getOpcode() == ISD::SIGN_EXTEND;
10170   bool IsZeroExt = LeftOp.getOpcode() == ISD::ZERO_EXTEND;
10171 
10172   if (!IsSignExt && !IsZeroExt)
10173     return SDValue();
10174 
10175   EVT NarrowVT = LeftOp.getOperand(0).getValueType();
10176   unsigned NarrowVTSize = NarrowVT.getScalarSizeInBits();
10177 
10178   // return true if U may use the lower bits of its operands
10179   auto UserOfLowerBits = [NarrowVTSize](SDNode *U) {
10180     if (U->getOpcode() != ISD::SRL && U->getOpcode() != ISD::SRA) {
10181       return true;
10182     }
10183     ConstantSDNode *UShiftAmtSrc = isConstOrConstSplat(U->getOperand(1));
10184     if (!UShiftAmtSrc) {
10185       return true;
10186     }
10187     unsigned UShiftAmt = UShiftAmtSrc->getZExtValue();
10188     return UShiftAmt < NarrowVTSize;
10189   };
10190 
10191   // If the lower part of the MUL is also used and MUL_LOHI is supported
10192   // do not introduce the MULH in favor of MUL_LOHI
10193   unsigned MulLoHiOp = IsSignExt ? ISD::SMUL_LOHI : ISD::UMUL_LOHI;
10194   if (!ShiftOperand.hasOneUse() &&
10195       TLI.isOperationLegalOrCustom(MulLoHiOp, NarrowVT) &&
10196       llvm::any_of(ShiftOperand->uses(), UserOfLowerBits)) {
10197     return SDValue();
10198   }
10199 
10200   SDValue MulhRightOp;
10201   if (ConstantSDNode *Constant = isConstOrConstSplat(RightOp)) {
10202     unsigned ActiveBits = IsSignExt
10203                               ? Constant->getAPIntValue().getSignificantBits()
10204                               : Constant->getAPIntValue().getActiveBits();
10205     if (ActiveBits > NarrowVTSize)
10206       return SDValue();
10207     MulhRightOp = DAG.getConstant(
10208         Constant->getAPIntValue().trunc(NarrowVT.getScalarSizeInBits()), DL,
10209         NarrowVT);
10210   } else {
10211     if (LeftOp.getOpcode() != RightOp.getOpcode())
10212       return SDValue();
10213     // Check that the two extend nodes are the same type.
10214     if (NarrowVT != RightOp.getOperand(0).getValueType())
10215       return SDValue();
10216     MulhRightOp = RightOp.getOperand(0);
10217   }
10218 
10219   EVT WideVT = LeftOp.getValueType();
10220   // Proceed with the transformation if the wide types match.
10221   assert((WideVT == RightOp.getValueType()) &&
10222          "Cannot have a multiply node with two different operand types.");
10223 
10224   // Proceed with the transformation if the wide type is twice as large
10225   // as the narrow type.
10226   if (WideVT.getScalarSizeInBits() != 2 * NarrowVTSize)
10227     return SDValue();
10228 
10229   // Check the shift amount with the narrow type size.
10230   // Proceed with the transformation if the shift amount is the width
10231   // of the narrow type.
10232   unsigned ShiftAmt = ShiftAmtSrc->getZExtValue();
10233   if (ShiftAmt != NarrowVTSize)
10234     return SDValue();
10235 
10236   // If the operation feeding into the MUL is a sign extend (sext),
10237   // we use mulhs. Otherwise, zero extends (zext) use mulhu.
10238   unsigned MulhOpcode = IsSignExt ? ISD::MULHS : ISD::MULHU;
10239 
10240   // Combine to mulh if mulh is legal/custom for the narrow type on the target
10241   // or if it is a vector type then we could transform to an acceptable type and
10242   // rely on legalization to split/combine the result.
10243   if (NarrowVT.isVector()) {
10244     EVT TransformVT = TLI.getTypeToTransformTo(*DAG.getContext(), NarrowVT);
10245     if (TransformVT.getVectorElementType() != NarrowVT.getVectorElementType() ||
10246         !TLI.isOperationLegalOrCustom(MulhOpcode, TransformVT))
10247       return SDValue();
10248   } else {
10249     if (!TLI.isOperationLegalOrCustom(MulhOpcode, NarrowVT))
10250       return SDValue();
10251   }
10252 
10253   SDValue Result =
10254       DAG.getNode(MulhOpcode, DL, NarrowVT, LeftOp.getOperand(0), MulhRightOp);
10255   bool IsSigned = N->getOpcode() == ISD::SRA;
10256   return DAG.getExtOrTrunc(IsSigned, Result, DL, WideVT);
10257 }
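// The scalar shape matched above corresponds to the usual C idiom for a
// high-half multiply (illustrative sketch; mulhu32 is a hypothetical name):
//
//   #include <cstdint>
//   uint32_t mulhu32(uint32_t A, uint32_t B) {
//     // (srl (mul (zext A), (zext B)), 32) in DAG terms.
//     return uint32_t((uint64_t(A) * uint64_t(B)) >> 32);
//   }
//
// With MULHU legal or custom for i32, the zext/mul/srl sequence collapses to
// a single MULHU node whose result is then zero-extended back to i64.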
10258 
10259 // fold (bswap (logic_op(bswap(x),y))) -> logic_op(x,bswap(y))
10260 // This helper function accepts SDNodes with opcode ISD::BSWAP or ISD::BITREVERSE.
10261 static SDValue foldBitOrderCrossLogicOp(SDNode *N, SelectionDAG &DAG) {
10262   unsigned Opcode = N->getOpcode();
10263   if (Opcode != ISD::BSWAP && Opcode != ISD::BITREVERSE)
10264     return SDValue();
10265 
10266   SDValue N0 = N->getOperand(0);
10267   EVT VT = N->getValueType(0);
10268   SDLoc DL(N);
10269   if (ISD::isBitwiseLogicOp(N0.getOpcode()) && N0.hasOneUse()) {
10270     SDValue OldLHS = N0.getOperand(0);
10271     SDValue OldRHS = N0.getOperand(1);
10272 
10273     // If both operands are bswap/bitreverse, ignore the multiuse restriction;
10274     // otherwise, ensure the logic_op and the bswap/bitreverse(x) have one use.
10275     if (OldLHS.getOpcode() == Opcode && OldRHS.getOpcode() == Opcode) {
10276       return DAG.getNode(N0.getOpcode(), DL, VT, OldLHS.getOperand(0),
10277                          OldRHS.getOperand(0));
10278     }
10279 
10280     if (OldLHS.getOpcode() == Opcode && OldLHS.hasOneUse()) {
10281       SDValue NewBitReorder = DAG.getNode(Opcode, DL, VT, OldRHS);
10282       return DAG.getNode(N0.getOpcode(), DL, VT, OldLHS.getOperand(0),
10283                          NewBitReorder);
10284     }
10285 
10286     if (OldRHS.getOpcode() == Opcode && OldRHS.hasOneUse()) {
10287       SDValue NewBitReorder = DAG.getNode(Opcode, DL, VT, OldLHS);
10288       return DAG.getNode(N0.getOpcode(), DL, VT, NewBitReorder,
10289                          OldRHS.getOperand(0));
10290     }
10291   }
10292   return SDValue();
10293 }
10294 
10295 SDValue DAGCombiner::visitSRA(SDNode *N) {
10296   SDValue N0 = N->getOperand(0);
10297   SDValue N1 = N->getOperand(1);
10298   if (SDValue V = DAG.simplifyShift(N0, N1))
10299     return V;
10300 
10301   EVT VT = N0.getValueType();
10302   unsigned OpSizeInBits = VT.getScalarSizeInBits();
10303 
10304   // fold (sra c1, c2) -> c1 >>s c2
10305   if (SDValue C = DAG.FoldConstantArithmetic(ISD::SRA, SDLoc(N), VT, {N0, N1}))
10306     return C;
10307 
10308   // Arithmetic shifting an all-sign-bit value is a no-op.
10309   // fold (sra 0, x) -> 0
10310   // fold (sra -1, x) -> -1
10311   if (DAG.ComputeNumSignBits(N0) == OpSizeInBits)
10312     return N0;
10313 
10314   // fold vector ops
10315   if (VT.isVector())
10316     if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N)))
10317       return FoldedVOp;
10318 
10319   if (SDValue NewSel = foldBinOpIntoSelect(N))
10320     return NewSel;
10321 
10322   ConstantSDNode *N1C = isConstOrConstSplat(N1);
10323 
10324   // fold (sra (sra x, c1), c2) -> (sra x, (add c1, c2))
10325   // clamp (add c1, c2) to max shift.
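  // e.g. (sra (sra x, 3), 5) -> (sra x, 8); on i8 the sum is clamped:
  //      (sra (sra i8:x, 4), 7) -> (sra i8:x, 7)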
10326   if (N0.getOpcode() == ISD::SRA) {
10327     SDLoc DL(N);
10328     EVT ShiftVT = N1.getValueType();
10329     EVT ShiftSVT = ShiftVT.getScalarType();
10330     SmallVector<SDValue, 16> ShiftValues;
10331 
10332     auto SumOfShifts = [&](ConstantSDNode *LHS, ConstantSDNode *RHS) {
10333       APInt c1 = LHS->getAPIntValue();
10334       APInt c2 = RHS->getAPIntValue();
10335       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
10336       APInt Sum = c1 + c2;
10337       unsigned ShiftSum =
10338           Sum.uge(OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue();
10339       ShiftValues.push_back(DAG.getConstant(ShiftSum, DL, ShiftSVT));
10340       return true;
10341     };
10342     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), SumOfShifts)) {
10343       SDValue ShiftValue;
10344       if (N1.getOpcode() == ISD::BUILD_VECTOR)
10345         ShiftValue = DAG.getBuildVector(ShiftVT, DL, ShiftValues);
10346       else if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
10347         assert(ShiftValues.size() == 1 &&
10348                "Expected matchBinaryPredicate to return one element for "
10349                "SPLAT_VECTORs");
10350         ShiftValue = DAG.getSplatVector(ShiftVT, DL, ShiftValues[0]);
10351       } else
10352         ShiftValue = ShiftValues[0];
10353       return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0), ShiftValue);
10354     }
10355   }
10356 
10357   // fold (sra (shl X, m), (sub result_size, n))
10358   // -> (sign_extend (trunc (shl X, (sub (sub result_size, n), m)))) for
10359   // result_size - n != m.
10360   // If truncate is free for the target, sext(shl) is likely to result in
10361   // better code.
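  // e.g. on i32, assuming an i32 -> i8 truncate is free:
  //   (sra (shl x, 8), 24) -> (sext (trunc (srl x, 16) to i8) to i32)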
10362   if (N0.getOpcode() == ISD::SHL && N1C) {
10363     // Get the two constants of the shifts, CN0 = m, CN = n.
10364     const ConstantSDNode *N01C = isConstOrConstSplat(N0.getOperand(1));
10365     if (N01C) {
10366       LLVMContext &Ctx = *DAG.getContext();
10367       // Determine what the truncate's result bitsize and type would be.
10368       EVT TruncVT = EVT::getIntegerVT(Ctx, OpSizeInBits - N1C->getZExtValue());
10369 
10370       if (VT.isVector())
10371         TruncVT = EVT::getVectorVT(Ctx, TruncVT, VT.getVectorElementCount());
10372 
10373       // Determine the residual right-shift amount.
10374       int ShiftAmt = N1C->getZExtValue() - N01C->getZExtValue();
10375 
10376       // If the shift is not a no-op (in which case this should be just a sign
10377       // extend already), the type to truncate to is legal, sign_extend is
10378       // legal on that type, and the truncate to that type is both legal and
10379       // free, perform the transform.
10380       if ((ShiftAmt > 0) &&
10381           TLI.isOperationLegalOrCustom(ISD::SIGN_EXTEND, TruncVT) &&
10382           TLI.isOperationLegalOrCustom(ISD::TRUNCATE, VT) &&
10383           TLI.isTruncateFree(VT, TruncVT)) {
10384         SDLoc DL(N);
10385         SDValue Amt = DAG.getConstant(ShiftAmt, DL,
10386             getShiftAmountTy(N0.getOperand(0).getValueType()));
10387         SDValue Shift = DAG.getNode(ISD::SRL, DL, VT,
10388                                     N0.getOperand(0), Amt);
10389         SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, TruncVT,
10390                                     Shift);
10391         return DAG.getNode(ISD::SIGN_EXTEND, DL,
10392                            N->getValueType(0), Trunc);
10393       }
10394     }
10395   }
10396 
10397   // We convert trunc/ext to opposing shifts in IR, but casts may be cheaper.
10398   //   sra (add (shl X, N1C), AddC), N1C -->
10399   //   sext (add (trunc X to (width - N1C)), AddC')
10400   //   sra (sub AddC, (shl X, N1C)), N1C -->
10401   //   sext (sub AddC1',(trunc X to (width - N1C)))
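  // e.g. on i32, assuming an i32 -> i16 truncate is free and AddC == (C << 16):
  //   (sra (add (shl x, 16), AddC), 16)
  //     -> (sext (add (trunc x to i16), C) to i32)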
10402   if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::SUB) && N1C &&
10403       N0.hasOneUse()) {
10404     bool IsAdd = N0.getOpcode() == ISD::ADD;
10405     SDValue Shl = N0.getOperand(IsAdd ? 0 : 1);
10406     if (Shl.getOpcode() == ISD::SHL && Shl.getOperand(1) == N1 &&
10407         Shl.hasOneUse()) {
10408       // TODO: AddC does not need to be a splat.
10409       if (ConstantSDNode *AddC =
10410               isConstOrConstSplat(N0.getOperand(IsAdd ? 1 : 0))) {
10411         // Determine what the truncate's type would be and ask the target if
10412         // that is a free operation.
10413         LLVMContext &Ctx = *DAG.getContext();
10414         unsigned ShiftAmt = N1C->getZExtValue();
10415         EVT TruncVT = EVT::getIntegerVT(Ctx, OpSizeInBits - ShiftAmt);
10416         if (VT.isVector())
10417           TruncVT = EVT::getVectorVT(Ctx, TruncVT, VT.getVectorElementCount());
10418 
10419         // TODO: The simple type check probably belongs in the default hook
10420         //       implementation and/or target-specific overrides (because
10421         //       non-simple types likely require masking when legalized), but
10422         //       that restriction may conflict with other transforms.
10423         if (TruncVT.isSimple() && isTypeLegal(TruncVT) &&
10424             TLI.isTruncateFree(VT, TruncVT)) {
10425           SDLoc DL(N);
10426           SDValue Trunc = DAG.getZExtOrTrunc(Shl.getOperand(0), DL, TruncVT);
10427           SDValue ShiftC =
10428               DAG.getConstant(AddC->getAPIntValue().lshr(ShiftAmt).trunc(
10429                                   TruncVT.getScalarSizeInBits()),
10430                               DL, TruncVT);
10431           SDValue Add;
10432           if (IsAdd)
10433             Add = DAG.getNode(ISD::ADD, DL, TruncVT, Trunc, ShiftC);
10434           else
10435             Add = DAG.getNode(ISD::SUB, DL, TruncVT, ShiftC, Trunc);
10436           return DAG.getSExtOrTrunc(Add, DL, VT);
10437         }
10438       }
10439     }
10440   }
10441 
10442   // fold (sra x, (trunc (and y, c))) -> (sra x, (and (trunc y), (trunc c))).
10443   if (N1.getOpcode() == ISD::TRUNCATE &&
10444       N1.getOperand(0).getOpcode() == ISD::AND) {
10445     if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
10446       return DAG.getNode(ISD::SRA, SDLoc(N), VT, N0, NewOp1);
10447   }
10448 
10449   // fold (sra (trunc (sra x, c1)), c2) -> (trunc (sra x, c1 + c2))
10450   // fold (sra (trunc (srl x, c1)), c2) -> (trunc (sra x, c1 + c2))
10451   //      if c1 is equal to the number of bits the trunc removes
10452   // TODO - support non-uniform vector shift amounts.
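  // e.g. (sra (trunc (srl i64:x, 32) to i32), 5) -> (trunc (sra x, 37) to i32)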
10453   if (N0.getOpcode() == ISD::TRUNCATE &&
10454       (N0.getOperand(0).getOpcode() == ISD::SRL ||
10455        N0.getOperand(0).getOpcode() == ISD::SRA) &&
10456       N0.getOperand(0).hasOneUse() &&
10457       N0.getOperand(0).getOperand(1).hasOneUse() && N1C) {
10458     SDValue N0Op0 = N0.getOperand(0);
10459     if (ConstantSDNode *LargeShift = isConstOrConstSplat(N0Op0.getOperand(1))) {
10460       EVT LargeVT = N0Op0.getValueType();
10461       unsigned TruncBits = LargeVT.getScalarSizeInBits() - OpSizeInBits;
10462       if (LargeShift->getAPIntValue() == TruncBits) {
10463         SDLoc DL(N);
10464         EVT LargeShiftVT = getShiftAmountTy(LargeVT);
10465         SDValue Amt = DAG.getZExtOrTrunc(N1, DL, LargeShiftVT);
10466         Amt = DAG.getNode(ISD::ADD, DL, LargeShiftVT, Amt,
10467                           DAG.getConstant(TruncBits, DL, LargeShiftVT));
10468         SDValue SRA =
10469             DAG.getNode(ISD::SRA, DL, LargeVT, N0Op0.getOperand(0), Amt);
10470         return DAG.getNode(ISD::TRUNCATE, DL, VT, SRA);
10471       }
10472     }
10473   }
10474 
10475   // Simplify, based on bits shifted out of the LHS.
10476   if (SimplifyDemandedBits(SDValue(N, 0)))
10477     return SDValue(N, 0);
10478 
10479   // If the sign bit is known to be zero, switch this to a SRL.
10480   if (DAG.SignBitIsZero(N0))
10481     return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0, N1);
10482 
10483   if (N1C && !N1C->isOpaque())
10484     if (SDValue NewSRA = visitShiftByConstant(N))
10485       return NewSRA;
10486 
10487   // Try to transform this shift into a multiply-high if
10488   // it matches the appropriate pattern detected in combineShiftToMULH.
10489   if (SDValue MULH = combineShiftToMULH(N, DAG, TLI))
10490     return MULH;
10491 
10492   // Attempt to convert a sra of a load into a narrower sign-extending load.
10493   if (SDValue NarrowLoad = reduceLoadWidth(N))
10494     return NarrowLoad;
10495 
10496   return SDValue();
10497 }
10498 
10499 SDValue DAGCombiner::visitSRL(SDNode *N) {
10500   SDValue N0 = N->getOperand(0);
10501   SDValue N1 = N->getOperand(1);
10502   if (SDValue V = DAG.simplifyShift(N0, N1))
10503     return V;
10504 
10505   EVT VT = N0.getValueType();
10506   EVT ShiftVT = N1.getValueType();
10507   unsigned OpSizeInBits = VT.getScalarSizeInBits();
10508 
10509   // fold (srl c1, c2) -> c1 >>u c2
10510   if (SDValue C = DAG.FoldConstantArithmetic(ISD::SRL, SDLoc(N), VT, {N0, N1}))
10511     return C;
10512 
10513   // fold vector ops
10514   if (VT.isVector())
10515     if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N)))
10516       return FoldedVOp;
10517 
10518   if (SDValue NewSel = foldBinOpIntoSelect(N))
10519     return NewSel;
10520 
10521   // if (srl x, c) is known to be zero, return 0
10522   ConstantSDNode *N1C = isConstOrConstSplat(N1);
10523   if (N1C &&
10524       DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnes(OpSizeInBits)))
10525     return DAG.getConstant(0, SDLoc(N), VT);
10526 
10527   // fold (srl (srl x, c1), c2) -> 0 or (srl x, (add c1, c2))
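  // e.g. (srl (srl x, 2), 3) -> (srl x, 5), while
  //      (srl (srl i8:x, 4), 4) -> 0 since all bits are shifted out.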
10528   if (N0.getOpcode() == ISD::SRL) {
10529     auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
10530                                           ConstantSDNode *RHS) {
10531       APInt c1 = LHS->getAPIntValue();
10532       APInt c2 = RHS->getAPIntValue();
10533       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
10534       return (c1 + c2).uge(OpSizeInBits);
10535     };
10536     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
10537       return DAG.getConstant(0, SDLoc(N), VT);
10538 
10539     auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
10540                                        ConstantSDNode *RHS) {
10541       APInt c1 = LHS->getAPIntValue();
10542       APInt c2 = RHS->getAPIntValue();
10543       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
10544       return (c1 + c2).ult(OpSizeInBits);
10545     };
10546     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
10547       SDLoc DL(N);
10548       SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
10549       return DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Sum);
10550     }
10551   }
10552 
10553   if (N1C && N0.getOpcode() == ISD::TRUNCATE &&
10554       N0.getOperand(0).getOpcode() == ISD::SRL) {
10555     SDValue InnerShift = N0.getOperand(0);
10556     // TODO - support non-uniform vector shift amounts.
10557     if (auto *N001C = isConstOrConstSplat(InnerShift.getOperand(1))) {
10558       uint64_t c1 = N001C->getZExtValue();
10559       uint64_t c2 = N1C->getZExtValue();
10560       EVT InnerShiftVT = InnerShift.getValueType();
10561       EVT ShiftAmtVT = InnerShift.getOperand(1).getValueType();
10562       uint64_t InnerShiftSize = InnerShiftVT.getScalarSizeInBits();
10563       // srl (trunc (srl x, c1)), c2 --> 0 or (trunc (srl x, (add c1, c2)))
10564       // This is only valid if OpSizeInBits + c1 == the size of the inner shift.
10565       if (c1 + OpSizeInBits == InnerShiftSize) {
10566         SDLoc DL(N);
10567         if (c1 + c2 >= InnerShiftSize)
10568           return DAG.getConstant(0, DL, VT);
10569         SDValue NewShiftAmt = DAG.getConstant(c1 + c2, DL, ShiftAmtVT);
10570         SDValue NewShift = DAG.getNode(ISD::SRL, DL, InnerShiftVT,
10571                                        InnerShift.getOperand(0), NewShiftAmt);
10572         return DAG.getNode(ISD::TRUNCATE, DL, VT, NewShift);
10573       }
10574       // In the more general case, we can clear the high bits after the shift:
10575       // srl (trunc (srl x, c1)), c2 --> trunc (and (srl x, (c1+c2)), Mask)
10576       if (N0.hasOneUse() && InnerShift.hasOneUse() &&
10577           c1 + c2 < InnerShiftSize) {
10578         SDLoc DL(N);
10579         SDValue NewShiftAmt = DAG.getConstant(c1 + c2, DL, ShiftAmtVT);
10580         SDValue NewShift = DAG.getNode(ISD::SRL, DL, InnerShiftVT,
10581                                        InnerShift.getOperand(0), NewShiftAmt);
10582         SDValue Mask = DAG.getConstant(APInt::getLowBitsSet(InnerShiftSize,
10583                                                             OpSizeInBits - c2),
10584                                        DL, InnerShiftVT);
10585         SDValue And = DAG.getNode(ISD::AND, DL, InnerShiftVT, NewShift, Mask);
10586         return DAG.getNode(ISD::TRUNCATE, DL, VT, And);
10587       }
10588     }
10589   }
10590 
10591   // fold (srl (shl x, c1), c2) -> (and (shl x, (sub c1, c2)), MASK) or
10592   //                               (and (srl x, (sub c2, c1)), MASK)
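  // e.g. on i8: (srl (shl x, 5), 3) -> (and (shl x, 2), 0x1C)
  //             (srl (shl x, 3), 5) -> (and (srl x, 2), 0x07)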
10593   if (N0.getOpcode() == ISD::SHL &&
10594       (N0.getOperand(1) == N1 || N0->hasOneUse()) &&
10595       TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
10596     auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
10597                                            ConstantSDNode *RHS) {
10598       const APInt &LHSC = LHS->getAPIntValue();
10599       const APInt &RHSC = RHS->getAPIntValue();
10600       return LHSC.ult(OpSizeInBits) && RHSC.ult(OpSizeInBits) &&
10601              LHSC.getZExtValue() <= RHSC.getZExtValue();
10602     };
10603     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
10604                                   /*AllowUndefs*/ false,
10605                                   /*AllowTypeMismatch*/ true)) {
10606       SDLoc DL(N);
10607       SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
10608       SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1);
10609       SDValue Mask = DAG.getAllOnesConstant(DL, VT);
10610       Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, N01);
10611       Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, Diff);
10612       SDValue Shift = DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff);
10613       return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
10614     }
10615     if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount,
10616                                   /*AllowUndefs*/ false,
10617                                   /*AllowTypeMismatch*/ true)) {
10618       SDLoc DL(N);
10619       SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
10620       SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01);
10621       SDValue Mask = DAG.getAllOnesConstant(DL, VT);
10622       Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, N1);
10623       SDValue Shift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Diff);
10624       return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
10625     }
10626   }
10627 
10628   // fold (srl (anyextend x), c) -> (and (anyextend (srl x, c)), mask)
10629   // TODO - support non-uniform vector shift amounts.
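  // e.g. (srl (any_extend i16:x to i32), 4)
  //        -> (and (any_extend (srl x, 4) to i32), 0x0FFFFFFF)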
10630   if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
10631     // Shifting in all undef bits?
10632     EVT SmallVT = N0.getOperand(0).getValueType();
10633     unsigned BitSize = SmallVT.getScalarSizeInBits();
10634     if (N1C->getAPIntValue().uge(BitSize))
10635       return DAG.getUNDEF(VT);
10636 
10637     if (!LegalTypes || TLI.isTypeDesirableForOp(ISD::SRL, SmallVT)) {
10638       uint64_t ShiftAmt = N1C->getZExtValue();
10639       SDLoc DL0(N0);
10640       SDValue SmallShift = DAG.getNode(ISD::SRL, DL0, SmallVT,
10641                                        N0.getOperand(0),
10642                           DAG.getConstant(ShiftAmt, DL0,
10643                                           getShiftAmountTy(SmallVT)));
10644       AddToWorklist(SmallShift.getNode());
10645       APInt Mask = APInt::getLowBitsSet(OpSizeInBits, OpSizeInBits - ShiftAmt);
10646       SDLoc DL(N);
10647       return DAG.getNode(ISD::AND, DL, VT,
10648                          DAG.getNode(ISD::ANY_EXTEND, DL, VT, SmallShift),
10649                          DAG.getConstant(Mask, DL, VT));
10650     }
10651   }
10652 
10653   // fold (srl (sra X, Y), 31) -> (srl X, 31).  This srl only looks at the sign
10654   // bit, which is unmodified by sra.
10655   if (N1C && N1C->getAPIntValue() == (OpSizeInBits - 1)) {
10656     if (N0.getOpcode() == ISD::SRA)
10657       return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0.getOperand(0), N1);
10658   }
10659 
10660   // fold (srl (ctlz x), "5") -> x iff x has one bit set (the low bit), and
10661   // x has a power-of-two bitwidth. The "5" represents (log2 (bitwidth x)).
10662   if (N1C && N0.getOpcode() == ISD::CTLZ &&
10663       isPowerOf2_32(OpSizeInBits) &&
10664       N1C->getAPIntValue() == Log2_32(OpSizeInBits)) {
10665     KnownBits Known = DAG.computeKnownBits(N0.getOperand(0));
10666 
10667     // If any of the input bits are KnownOne, then the input couldn't be all
10668     // zeros, thus the result of the srl will always be zero.
10669     if (Known.One.getBoolValue()) return DAG.getConstant(0, SDLoc(N0), VT);
10670 
10671     // If all of the bits input to the ctlz node are known to be zero, then
10672     // the result of the ctlz is "32" and the result of the shift is one.
10673     APInt UnknownBits = ~Known.Zero;
10674     if (UnknownBits == 0) return DAG.getConstant(1, SDLoc(N0), VT);
10675 
10676     // Otherwise, check to see if there is exactly one bit input to the ctlz.
10677     if (UnknownBits.isPowerOf2()) {
10678       // Okay, we know that only the single bit specified by UnknownBits
10679       // could be set on input to the CTLZ node. If this bit is set, the SRL
10680       // will return 0, if it is clear, it returns 1. Change the CTLZ/SRL pair
10681       // to an SRL/XOR pair, which is likely to simplify more.
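      // e.g. on i32, if only bit 4 of the ctlz input can be set:
      //   (srl (ctlz x), 5) -> (xor (srl x, 4), 1)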
10682       unsigned ShAmt = UnknownBits.countr_zero();
10683       SDValue Op = N0.getOperand(0);
10684 
10685       if (ShAmt) {
10686         SDLoc DL(N0);
10687         Op = DAG.getNode(ISD::SRL, DL, VT, Op,
10688                   DAG.getConstant(ShAmt, DL,
10689                                   getShiftAmountTy(Op.getValueType())));
10690         AddToWorklist(Op.getNode());
10691       }
10692 
10693       SDLoc DL(N);
10694       return DAG.getNode(ISD::XOR, DL, VT,
10695                          Op, DAG.getConstant(1, DL, VT));
10696     }
10697   }
10698 
10699   // fold (srl x, (trunc (and y, c))) -> (srl x, (and (trunc y), (trunc c))).
10700   if (N1.getOpcode() == ISD::TRUNCATE &&
10701       N1.getOperand(0).getOpcode() == ISD::AND) {
10702     if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
10703       return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0, NewOp1);
10704   }
10705 
10706   // fold operands of srl based on knowledge that the low bits are not
10707   // demanded.
10708   if (SimplifyDemandedBits(SDValue(N, 0)))
10709     return SDValue(N, 0);
10710 
10711   if (N1C && !N1C->isOpaque())
10712     if (SDValue NewSRL = visitShiftByConstant(N))
10713       return NewSRL;
10714 
10715   // Attempt to convert a srl of a load into a narrower zero-extending load.
10716   if (SDValue NarrowLoad = reduceLoadWidth(N))
10717     return NarrowLoad;
10718 
10719   // Here is a common situation. We want to optimize:
10720   //
10721   //   %a = ...
10722   //   %b = and i32 %a, 2
10723   //   %c = srl i32 %b, 1
10724   //   brcond i32 %c ...
10725   //
10726   // into
10727   //
10728   //   %a = ...
10729   //   %b = and %a, 2
10730   //   %c = setcc eq %b, 0
10731   //   brcond %c ...
10732   //
10733   // However, after the source operand of the SRL is optimized into an AND,
10734   // the SRL itself may not be optimized further. Look for it and add the
10735   // BRCOND into the worklist.
10736   //
10737   // This also tends to happen for binary operations when SimplifyDemandedBits
10738   // is involved.
10739   //
10740   // FIXME: This is unnecessary if we process the DAG in topological order,
10741   // which we plan to do. This workaround can be removed once the DAG is
10742   // processed in topological order.
10743   if (N->hasOneUse()) {
10744     SDNode *Use = *N->use_begin();
10745 
10746     // Look past the truncate.
10747     if (Use->getOpcode() == ISD::TRUNCATE && Use->hasOneUse())
10748       Use = *Use->use_begin();
10749 
10750     if (Use->getOpcode() == ISD::BRCOND || Use->getOpcode() == ISD::AND ||
10751         Use->getOpcode() == ISD::OR || Use->getOpcode() == ISD::XOR)
10752       AddToWorklist(Use);
10753   }
10754 
10755   // Try to transform this shift into a multiply-high if
10756   // it matches the appropriate pattern detected in combineShiftToMULH.
10757   if (SDValue MULH = combineShiftToMULH(N, DAG, TLI))
10758     return MULH;
10759 
10760   return SDValue();
10761 }
10762 
10763 SDValue DAGCombiner::visitFunnelShift(SDNode *N) {
10764   EVT VT = N->getValueType(0);
10765   SDValue N0 = N->getOperand(0);
10766   SDValue N1 = N->getOperand(1);
10767   SDValue N2 = N->getOperand(2);
10768   bool IsFSHL = N->getOpcode() == ISD::FSHL;
10769   unsigned BitWidth = VT.getScalarSizeInBits();
10770 
10771   // fold (fshl N0, N1, 0) -> N0
10772   // fold (fshr N0, N1, 0) -> N1
10773   if (isPowerOf2_32(BitWidth))
10774     if (DAG.MaskedValueIsZero(
10775             N2, APInt(N2.getScalarValueSizeInBits(), BitWidth - 1)))
10776       return IsFSHL ? N0 : N1;
10777 
10778   auto IsUndefOrZero = [](SDValue V) {
10779     return V.isUndef() || isNullOrNullSplat(V, /*AllowUndefs*/ true);
10780   };
10781 
10782   // TODO - support non-uniform vector shift amounts.
10783   if (ConstantSDNode *Cst = isConstOrConstSplat(N2)) {
10784     EVT ShAmtTy = N2.getValueType();
10785 
10786     // fold (fsh* N0, N1, c) -> (fsh* N0, N1, c % BitWidth)
10787     if (Cst->getAPIntValue().uge(BitWidth)) {
10788       uint64_t RotAmt = Cst->getAPIntValue().urem(BitWidth);
10789       return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N0, N1,
10790                          DAG.getConstant(RotAmt, SDLoc(N), ShAmtTy));
10791     }
10792 
10793     unsigned ShAmt = Cst->getZExtValue();
10794     if (ShAmt == 0)
10795       return IsFSHL ? N0 : N1;
10796 
10797     // fold fshl(undef_or_zero, N1, C) -> lshr(N1, BW-C)
10798     // fold fshr(undef_or_zero, N1, C) -> lshr(N1, C)
10799     // fold fshl(N0, undef_or_zero, C) -> shl(N0, C)
10800     // fold fshr(N0, undef_or_zero, C) -> shl(N0, BW-C)
10801     if (IsUndefOrZero(N0))
10802       return DAG.getNode(ISD::SRL, SDLoc(N), VT, N1,
10803                          DAG.getConstant(IsFSHL ? BitWidth - ShAmt : ShAmt,
10804                                          SDLoc(N), ShAmtTy));
10805     if (IsUndefOrZero(N1))
10806       return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0,
10807                          DAG.getConstant(IsFSHL ? ShAmt : BitWidth - ShAmt,
10808                                          SDLoc(N), ShAmtTy));
10809 
10810     // fold (fshl ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
10811     // fold (fshr ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
10812     // TODO - bigendian support once we have test coverage.
10813     // TODO - can we merge this with CombineConsecutiveLoads/MatchLoadCombine?
10814     // TODO - permit LHS EXTLOAD if extensions are shifted out.
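    // e.g. on little-endian with consecutive i32 loads ld0 = (load @p) and
    // ld1 = (load @p+4): (fshl ld1, ld0, 8) -> (load @p+3), i.e. the i32
    // made of bytes p[3..6].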
10815     if ((BitWidth % 8) == 0 && (ShAmt % 8) == 0 && !VT.isVector() &&
10816         !DAG.getDataLayout().isBigEndian()) {
10817       auto *LHS = dyn_cast<LoadSDNode>(N0);
10818       auto *RHS = dyn_cast<LoadSDNode>(N1);
10819       if (LHS && RHS && LHS->isSimple() && RHS->isSimple() &&
10820           LHS->getAddressSpace() == RHS->getAddressSpace() &&
10821           (LHS->hasOneUse() || RHS->hasOneUse()) && ISD::isNON_EXTLoad(RHS) &&
10822           ISD::isNON_EXTLoad(LHS)) {
10823         if (DAG.areNonVolatileConsecutiveLoads(LHS, RHS, BitWidth / 8, 1)) {
10824           SDLoc DL(RHS);
10825           uint64_t PtrOff =
10826               IsFSHL ? (((BitWidth - ShAmt) % BitWidth) / 8) : (ShAmt / 8);
10827           Align NewAlign = commonAlignment(RHS->getAlign(), PtrOff);
10828           unsigned Fast = 0;
10829           if (TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
10830                                      RHS->getAddressSpace(), NewAlign,
10831                                      RHS->getMemOperand()->getFlags(), &Fast) &&
10832               Fast) {
10833             SDValue NewPtr = DAG.getMemBasePlusOffset(
10834                 RHS->getBasePtr(), TypeSize::getFixed(PtrOff), DL);
10835             AddToWorklist(NewPtr.getNode());
10836             SDValue Load = DAG.getLoad(
10837                 VT, DL, RHS->getChain(), NewPtr,
10838                 RHS->getPointerInfo().getWithOffset(PtrOff), NewAlign,
10839                 RHS->getMemOperand()->getFlags(), RHS->getAAInfo());
10840             // Replace the old load's chain with the new load's chain.
10841             WorklistRemover DeadNodes(*this);
10842             DAG.ReplaceAllUsesOfValueWith(N1.getValue(1), Load.getValue(1));
10843             return Load;
10844           }
10845         }
10846       }
10847     }
10848   }
10849 
10850   // fold fshr(undef_or_zero, N1, N2) -> lshr(N1, N2)
10851   // fold fshl(N0, undef_or_zero, N2) -> shl(N0, N2)
10852   // iff we know the shift amount is in range.
10853   // TODO: when is it worth doing SUB(BW, N2) as well?
10854   if (isPowerOf2_32(BitWidth)) {
10855     APInt ModuloBits(N2.getScalarValueSizeInBits(), BitWidth - 1);
10856     if (IsUndefOrZero(N0) && !IsFSHL && DAG.MaskedValueIsZero(N2, ~ModuloBits))
10857       return DAG.getNode(ISD::SRL, SDLoc(N), VT, N1, N2);
10858     if (IsUndefOrZero(N1) && IsFSHL && DAG.MaskedValueIsZero(N2, ~ModuloBits))
10859       return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, N2);
10860   }
10861 
10862   // fold (fshl N0, N0, N2) -> (rotl N0, N2)
10863   // fold (fshr N0, N0, N2) -> (rotr N0, N2)
10864   // TODO: Investigate flipping this rotate if only one is legal; if funnel
10865   // shift is legal as well, we might be better off avoiding non-constant (BW - N2).
10866   unsigned RotOpc = IsFSHL ? ISD::ROTL : ISD::ROTR;
10867   if (N0 == N1 && hasOperation(RotOpc, VT))
10868     return DAG.getNode(RotOpc, SDLoc(N), VT, N0, N2);
10869 
10870   // Simplify, based on bits shifted out of N0/N1.
10871   if (SimplifyDemandedBits(SDValue(N, 0)))
10872     return SDValue(N, 0);
10873 
10874   return SDValue();
10875 }
10876 
10877 SDValue DAGCombiner::visitSHLSAT(SDNode *N) {
10878   SDValue N0 = N->getOperand(0);
10879   SDValue N1 = N->getOperand(1);
10880   if (SDValue V = DAG.simplifyShift(N0, N1))
10881     return V;
10882 
10883   EVT VT = N0.getValueType();
10884 
10885   // fold (*shlsat c1, c2) -> c1<<c2
10886   if (SDValue C =
10887           DAG.FoldConstantArithmetic(N->getOpcode(), SDLoc(N), VT, {N0, N1}))
10888     return C;
10889 
10890   ConstantSDNode *N1C = isConstOrConstSplat(N1);
10891 
10892   if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::SHL, VT)) {
10893     // fold (sshlsat x, c) -> (shl x, c)
10894     if (N->getOpcode() == ISD::SSHLSAT && N1C &&
10895         N1C->getAPIntValue().ult(DAG.ComputeNumSignBits(N0)))
10896       return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, N1);
10897 
10898     // fold (ushlsat x, c) -> (shl x, c)
10899     if (N->getOpcode() == ISD::USHLSAT && N1C &&
10900         N1C->getAPIntValue().ule(
10901             DAG.computeKnownBits(N0).countMinLeadingZeros()))
10902       return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, N1);
10903   }
10904 
10905   return SDValue();
10906 }
10907 
10908 // Given an ABS node, detect the following patterns:
10909 // (ABS (SUB (EXTEND a), (EXTEND b))).
10910 // (TRUNC (ABS (SUB (EXTEND a), (EXTEND b)))).
10911 // Generates a UABD/SABD instruction.
10912 SDValue DAGCombiner::foldABSToABD(SDNode *N) {
10913   EVT SrcVT = N->getValueType(0);
10914 
10915   if (N->getOpcode() == ISD::TRUNCATE)
10916     N = N->getOperand(0).getNode();
10917 
10918   if (N->getOpcode() != ISD::ABS)
10919     return SDValue();
10920 
10921   EVT VT = N->getValueType(0);
10922   SDValue AbsOp1 = N->getOperand(0);
10923   SDValue Op0, Op1;
10924   SDLoc DL(N);
10925 
10926   if (AbsOp1.getOpcode() != ISD::SUB)
10927     return SDValue();
10928 
10929   Op0 = AbsOp1.getOperand(0);
10930   Op1 = AbsOp1.getOperand(1);
10931 
10932   unsigned Opc0 = Op0.getOpcode();
10933 
10934   // Check if the operands of the sub are (zero|sign)-extended.
10935   // TODO: Should we use ValueTracking instead?
10936   if (Opc0 != Op1.getOpcode() ||
10937       (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND &&
10938        Opc0 != ISD::SIGN_EXTEND_INREG)) {
10939     // fold (abs (sub nsw x, y)) -> abds(x, y)
10940     if (AbsOp1->getFlags().hasNoSignedWrap() && hasOperation(ISD::ABDS, VT) &&
10941         TLI.preferABDSToABSWithNSW(VT)) {
10942       SDValue ABD = DAG.getNode(ISD::ABDS, DL, VT, Op0, Op1);
10943       return DAG.getZExtOrTrunc(ABD, DL, SrcVT);
10944     }
10945     return SDValue();
10946   }
10947 
10948   EVT VT0, VT1;
10949   if (Opc0 == ISD::SIGN_EXTEND_INREG) {
10950     VT0 = cast<VTSDNode>(Op0.getOperand(1))->getVT();
10951     VT1 = cast<VTSDNode>(Op1.getOperand(1))->getVT();
10952   } else {
10953     VT0 = Op0.getOperand(0).getValueType();
10954     VT1 = Op1.getOperand(0).getValueType();
10955   }
10956   unsigned ABDOpcode = (Opc0 == ISD::ZERO_EXTEND) ? ISD::ABDU : ISD::ABDS;
10957 
10958   // fold abs(sext(x) - sext(y)) -> zext(abds(x, y))
10959   // fold abs(zext(x) - zext(y)) -> zext(abdu(x, y))
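  // e.g. abs((zext i8:a to i32) - (zext i8:b to i32)) -> (zext (abdu a, b) to i32)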
10960   EVT MaxVT = VT0.bitsGT(VT1) ? VT0 : VT1;
10961   if ((VT0 == MaxVT || Op0->hasOneUse()) &&
10962       (VT1 == MaxVT || Op1->hasOneUse()) && hasOperation(ABDOpcode, MaxVT)) {
10963     SDValue ABD = DAG.getNode(ABDOpcode, DL, MaxVT,
10964                               DAG.getNode(ISD::TRUNCATE, DL, MaxVT, Op0),
10965                               DAG.getNode(ISD::TRUNCATE, DL, MaxVT, Op1));
10966     ABD = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, ABD);
10967     return DAG.getZExtOrTrunc(ABD, DL, SrcVT);
10968   }
10969 
10970   // fold abs(sext(x) - sext(y)) -> abds(sext(x), sext(y))
10971   // fold abs(zext(x) - zext(y)) -> abdu(zext(x), zext(y))
10972   if (hasOperation(ABDOpcode, VT)) {
10973     SDValue ABD = DAG.getNode(ABDOpcode, DL, VT, Op0, Op1);
10974     return DAG.getZExtOrTrunc(ABD, DL, SrcVT);
10975   }
10976 
10977   return SDValue();
10978 }
10979 
10980 SDValue DAGCombiner::visitABS(SDNode *N) {
10981   SDValue N0 = N->getOperand(0);
10982   EVT VT = N->getValueType(0);
10983 
10984   // fold (abs c1) -> c2
10985   if (SDValue C = DAG.FoldConstantArithmetic(ISD::ABS, SDLoc(N), VT, {N0}))
10986     return C;
10987   // fold (abs (abs x)) -> (abs x)
10988   if (N0.getOpcode() == ISD::ABS)
10989     return N0;
10990   // fold (abs x) -> x iff not-negative
10991   if (DAG.SignBitIsZero(N0))
10992     return N0;
10993 
10994   if (SDValue ABD = foldABSToABD(N))
10995     return ABD;
10996 
10997   // fold (abs (sign_extend_inreg x)) -> (zero_extend (abs (truncate x)))
10998   // iff zero_extend/truncate are free.
10999   if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG) {
11000     EVT ExtVT = cast<VTSDNode>(N0.getOperand(1))->getVT();
11001     if (TLI.isTruncateFree(VT, ExtVT) && TLI.isZExtFree(ExtVT, VT) &&
11002         TLI.isTypeDesirableForOp(ISD::ABS, ExtVT) &&
11003         hasOperation(ISD::ABS, ExtVT)) {
11004       SDLoc DL(N);
11005       return DAG.getNode(
11006           ISD::ZERO_EXTEND, DL, VT,
11007           DAG.getNode(ISD::ABS, DL, ExtVT,
11008                       DAG.getNode(ISD::TRUNCATE, DL, ExtVT, N0.getOperand(0))));
11009     }
11010   }
11011 
11012   return SDValue();
11013 }
11014 
11015 SDValue DAGCombiner::visitBSWAP(SDNode *N) {
11016   SDValue N0 = N->getOperand(0);
11017   EVT VT = N->getValueType(0);
11018   SDLoc DL(N);
11019 
11020   // fold (bswap c1) -> c2
11021   if (SDValue C = DAG.FoldConstantArithmetic(ISD::BSWAP, DL, VT, {N0}))
11022     return C;
11023   // fold (bswap (bswap x)) -> x
11024   if (N0.getOpcode() == ISD::BSWAP)
11025     return N0.getOperand(0);
11026 
11027   // Canonicalize bswap(bitreverse(x)) -> bitreverse(bswap(x)). If bitreverse
11028   // isn't supported, it will be expanded to bswap followed by a manual reversal
11029   // of bits in each byte. By placing bswaps before bitreverse, we can remove
11030   // the two bswaps if the bitreverse gets expanded.
11031   if (N0.getOpcode() == ISD::BITREVERSE && N0.hasOneUse()) {
11032     SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT, N0.getOperand(0));
11033     return DAG.getNode(ISD::BITREVERSE, DL, VT, BSwap);
11034   }
11035 
11036   // fold (bswap shl(x,c)) -> (zext(bswap(trunc(shl(x,sub(c,bw/2))))))
11037   // iff c >= bw/2 (i.e. the lower half of the result is known zero)
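  // e.g. on i32: (bswap (shl x, 16)) -> (zext (bswap (trunc x to i16)) to i32)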
11038   unsigned BW = VT.getScalarSizeInBits();
11039   if (BW >= 32 && N0.getOpcode() == ISD::SHL && N0.hasOneUse()) {
11040     auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1));
11041     EVT HalfVT = EVT::getIntegerVT(*DAG.getContext(), BW / 2);
11042     if (ShAmt && ShAmt->getAPIntValue().ult(BW) &&
11043         ShAmt->getZExtValue() >= (BW / 2) &&
11044         (ShAmt->getZExtValue() % 16) == 0 && TLI.isTypeLegal(HalfVT) &&
11045         TLI.isTruncateFree(VT, HalfVT) &&
11046         (!LegalOperations || hasOperation(ISD::BSWAP, HalfVT))) {
11047       SDValue Res = N0.getOperand(0);
11048       if (uint64_t NewShAmt = (ShAmt->getZExtValue() - (BW / 2)))
11049         Res = DAG.getNode(ISD::SHL, DL, VT, Res,
11050                           DAG.getConstant(NewShAmt, DL, getShiftAmountTy(VT)));
11051       Res = DAG.getZExtOrTrunc(Res, DL, HalfVT);
11052       Res = DAG.getNode(ISD::BSWAP, DL, HalfVT, Res);
11053       return DAG.getZExtOrTrunc(Res, DL, VT);
11054     }
11055   }
11056 
11057   // Try to canonicalize bswap-of-logical-shift-by-8-bit-multiple as
11058   // inverse-shift-of-bswap:
11059   // bswap (X u<< C) --> (bswap X) u>> C
11060   // bswap (X u>> C) --> (bswap X) u<< C
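  // e.g. (bswap (shl i32:x, 8)) -> (srl (bswap x), 8)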
11061   if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
11062       N0.hasOneUse()) {
11063     auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1));
11064     if (ShAmt && ShAmt->getAPIntValue().ult(BW) &&
11065         ShAmt->getZExtValue() % 8 == 0) {
11066       SDValue NewSwap = DAG.getNode(ISD::BSWAP, DL, VT, N0.getOperand(0));
11067       unsigned InverseShift = N0.getOpcode() == ISD::SHL ? ISD::SRL : ISD::SHL;
11068       return DAG.getNode(InverseShift, DL, VT, NewSwap, N0.getOperand(1));
11069     }
11070   }
11071 
11072   if (SDValue V = foldBitOrderCrossLogicOp(N, DAG))
11073     return V;
11074 
11075   return SDValue();
11076 }
11077 
11078 SDValue DAGCombiner::visitBITREVERSE(SDNode *N) {
11079   SDValue N0 = N->getOperand(0);
11080   EVT VT = N->getValueType(0);
11081   SDLoc DL(N);
11082 
11083   // fold (bitreverse c1) -> c2
11084   if (SDValue C = DAG.FoldConstantArithmetic(ISD::BITREVERSE, DL, VT, {N0}))
11085     return C;
11086   // fold (bitreverse (bitreverse x)) -> x
11087   if (N0.getOpcode() == ISD::BITREVERSE)
11088     return N0.getOperand(0);
11089   return SDValue();
11090 }
11091 
11092 SDValue DAGCombiner::visitCTLZ(SDNode *N) {
11093   SDValue N0 = N->getOperand(0);
11094   EVT VT = N->getValueType(0);
11095   SDLoc DL(N);
11096 
11097   // fold (ctlz c1) -> c2
11098   if (SDValue C = DAG.FoldConstantArithmetic(ISD::CTLZ, DL, VT, {N0}))
11099     return C;
11100 
11101   // If the value is known never to be zero, switch to the undef version.
11102   if (!LegalOperations || TLI.isOperationLegal(ISD::CTLZ_ZERO_UNDEF, VT))
11103     if (DAG.isKnownNeverZero(N0))
11104       return DAG.getNode(ISD::CTLZ_ZERO_UNDEF, DL, VT, N0);
11105 
11106   return SDValue();
11107 }
11108 
11109 SDValue DAGCombiner::visitCTLZ_ZERO_UNDEF(SDNode *N) {
11110   SDValue N0 = N->getOperand(0);
11111   EVT VT = N->getValueType(0);
11112   SDLoc DL(N);
11113 
11114   // fold (ctlz_zero_undef c1) -> c2
11115   if (SDValue C =
11116           DAG.FoldConstantArithmetic(ISD::CTLZ_ZERO_UNDEF, DL, VT, {N0}))
11117     return C;
11118   return SDValue();
11119 }
11120 
11121 SDValue DAGCombiner::visitCTTZ(SDNode *N) {
11122   SDValue N0 = N->getOperand(0);
11123   EVT VT = N->getValueType(0);
11124   SDLoc DL(N);
11125 
11126   // fold (cttz c1) -> c2
11127   if (SDValue C = DAG.FoldConstantArithmetic(ISD::CTTZ, DL, VT, {N0}))
11128     return C;
11129 
11130   // If the value is known never to be zero, switch to the undef version.
11131   if (!LegalOperations || TLI.isOperationLegal(ISD::CTTZ_ZERO_UNDEF, VT))
11132     if (DAG.isKnownNeverZero(N0))
11133       return DAG.getNode(ISD::CTTZ_ZERO_UNDEF, DL, VT, N0);
11134 
11135   return SDValue();
11136 }
11137 
11138 SDValue DAGCombiner::visitCTTZ_ZERO_UNDEF(SDNode *N) {
11139   SDValue N0 = N->getOperand(0);
11140   EVT VT = N->getValueType(0);
11141   SDLoc DL(N);
11142 
11143   // fold (cttz_zero_undef c1) -> c2
11144   if (SDValue C =
11145           DAG.FoldConstantArithmetic(ISD::CTTZ_ZERO_UNDEF, DL, VT, {N0}))
11146     return C;
11147   return SDValue();
11148 }
11149 
11150 SDValue DAGCombiner::visitCTPOP(SDNode *N) {
11151   SDValue N0 = N->getOperand(0);
11152   EVT VT = N->getValueType(0);
11153   SDLoc DL(N);
11154 
11155   // fold (ctpop c1) -> c2
11156   if (SDValue C = DAG.FoldConstantArithmetic(ISD::CTPOP, DL, VT, {N0}))
11157     return C;
11158   return SDValue();
11159 }
11160 
11161 // FIXME: This should be checking for no signed zeros on individual operands, as
11162 // well as no nans.
11163 static bool isLegalToCombineMinNumMaxNum(SelectionDAG &DAG, SDValue LHS,
11164                                          SDValue RHS,
11165                                          const TargetLowering &TLI) {
11166   const TargetOptions &Options = DAG.getTarget().Options;
11167   EVT VT = LHS.getValueType();
11168 
11169   return Options.NoSignedZerosFPMath && VT.isFloatingPoint() &&
11170          TLI.isProfitableToCombineMinNumMaxNum(VT) &&
11171          DAG.isKnownNeverNaN(LHS) && DAG.isKnownNeverNaN(RHS);
11172 }
11173 
11174 static SDValue combineMinNumMaxNumImpl(const SDLoc &DL, EVT VT, SDValue LHS,
11175                                        SDValue RHS, SDValue True, SDValue False,
11176                                        ISD::CondCode CC,
11177                                        const TargetLowering &TLI,
11178                                        SelectionDAG &DAG) {
11179   EVT TransformVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
11180   switch (CC) {
11181   case ISD::SETOLT:
11182   case ISD::SETOLE:
11183   case ISD::SETLT:
11184   case ISD::SETLE:
11185   case ISD::SETULT:
11186   case ISD::SETULE: {
11187     // Since the operands are known never to be nan to get here, either fminnum
11188     // or fminnum_ieee is OK. Try the ieee version first, since fminnum is
11189     // expanded in terms of it.
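    // e.g. (select (setolt LHS, RHS), LHS, RHS) -> (fminnum_ieee LHS, RHS)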
11190     unsigned IEEEOpcode = (LHS == True) ? ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE;
11191     if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT))
11192       return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS);
11193 
11194     unsigned Opcode = (LHS == True) ? ISD::FMINNUM : ISD::FMAXNUM;
11195     if (TLI.isOperationLegalOrCustom(Opcode, TransformVT))
11196       return DAG.getNode(Opcode, DL, VT, LHS, RHS);
11197     return SDValue();
11198   }
11199   case ISD::SETOGT:
11200   case ISD::SETOGE:
11201   case ISD::SETGT:
11202   case ISD::SETGE:
11203   case ISD::SETUGT:
11204   case ISD::SETUGE: {
11205     unsigned IEEEOpcode = (LHS == True) ? ISD::FMAXNUM_IEEE : ISD::FMINNUM_IEEE;
11206     if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT))
11207       return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS);
11208 
11209     unsigned Opcode = (LHS == True) ? ISD::FMAXNUM : ISD::FMINNUM;
11210     if (TLI.isOperationLegalOrCustom(Opcode, TransformVT))
11211       return DAG.getNode(Opcode, DL, VT, LHS, RHS);
11212     return SDValue();
11213   }
11214   default:
11215     return SDValue();
11216   }
11217 }
11218 
11219 /// Generate Min/Max node
11220 SDValue DAGCombiner::combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
11221                                          SDValue RHS, SDValue True,
11222                                          SDValue False, ISD::CondCode CC) {
11223   if ((LHS == True && RHS == False) || (LHS == False && RHS == True))
11224     return combineMinNumMaxNumImpl(DL, VT, LHS, RHS, True, False, CC, TLI, DAG);
11225 
11226   // If we can't directly match this, try to see if we can pull an fneg out of
11227   // the select.
11228   SDValue NegTrue = TLI.getCheaperOrNeutralNegatedExpression(
11229       True, DAG, LegalOperations, ForCodeSize);
11230   if (!NegTrue)
11231     return SDValue();
11232 
11233   HandleSDNode NegTrueHandle(NegTrue);
11234 
11235   // Try to unfold an fneg from the select if we are comparing the negated
11236   // constant.
11237   //
11238   // select (setcc x, K) (fneg x), -K -> fneg(minnum(x, K))
11239   //
11240   // TODO: Handle fabs
11241   if (LHS == NegTrue) {
11242     // If we can't directly match this, try to see if we can pull an fneg out of
11243     // the select.
11244     SDValue NegRHS = TLI.getCheaperOrNeutralNegatedExpression(
11245         RHS, DAG, LegalOperations, ForCodeSize);
11246     if (NegRHS) {
11247       HandleSDNode NegRHSHandle(NegRHS);
11248       if (NegRHS == False) {
11249         SDValue Combined = combineMinNumMaxNumImpl(DL, VT, LHS, RHS, NegTrue,
11250                                                    False, CC, TLI, DAG);
11251         if (Combined)
11252           return DAG.getNode(ISD::FNEG, DL, VT, Combined);
11253       }
11254     }
11255   }
11256 
11257   return SDValue();
11258 }
11259 
11260 /// If a (v)select has a condition value that is a sign-bit test, try to smear
11261 /// the condition operand sign-bit across the value width and use it as a mask.
11262 static SDValue foldSelectOfConstantsUsingSra(SDNode *N, SelectionDAG &DAG) {
11263   SDValue Cond = N->getOperand(0);
11264   SDValue C1 = N->getOperand(1);
11265   SDValue C2 = N->getOperand(2);
11266   if (!isConstantOrConstantVector(C1) || !isConstantOrConstantVector(C2))
11267     return SDValue();
11268 
11269   EVT VT = N->getValueType(0);
11270   if (Cond.getOpcode() != ISD::SETCC || !Cond.hasOneUse() ||
11271       VT != Cond.getOperand(0).getValueType())
11272     return SDValue();
11273 
11274   // The inverted-condition + commuted-select variants of these patterns are
11275   // canonicalized to these forms in IR.
11276   SDValue X = Cond.getOperand(0);
11277   SDValue CondC = Cond.getOperand(1);
11278   ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
11279   if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(CondC) &&
11280       isAllOnesOrAllOnesSplat(C2)) {
11281     // i32 X > -1 ? C1 : -1 --> (X >>s 31) | C1
11282     SDLoc DL(N);
11283     SDValue ShAmtC = DAG.getConstant(X.getScalarValueSizeInBits() - 1, DL, VT);
11284     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, X, ShAmtC);
11285     return DAG.getNode(ISD::OR, DL, VT, Sra, C1);
11286   }
11287   if (CC == ISD::SETLT && isNullOrNullSplat(CondC) && isNullOrNullSplat(C2)) {
11288     // i8 X < 0 ? C1 : 0 --> (X >>s 7) & C1
11289     SDLoc DL(N);
11290     SDValue ShAmtC = DAG.getConstant(X.getScalarValueSizeInBits() - 1, DL, VT);
11291     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, X, ShAmtC);
11292     return DAG.getNode(ISD::AND, DL, VT, Sra, C1);
11293   }
11294   return SDValue();
11295 }
11296 
11297 static bool shouldConvertSelectOfConstantsToMath(const SDValue &Cond, EVT VT,
11298                                                  const TargetLowering &TLI) {
11299   if (!TLI.convertSelectOfConstantsToMath(VT))
11300     return false;
11301 
11302   if (Cond.getOpcode() != ISD::SETCC || !Cond->hasOneUse())
11303     return true;
11304   if (!TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT))
11305     return true;
11306 
11307   ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
11308   if (CC == ISD::SETLT && isNullOrNullSplat(Cond.getOperand(1)))
11309     return true;
11310   if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(Cond.getOperand(1)))
11311     return true;
11312 
11313   return false;
11314 }
11315 
11316 SDValue DAGCombiner::foldSelectOfConstants(SDNode *N) {
11317   SDValue Cond = N->getOperand(0);
11318   SDValue N1 = N->getOperand(1);
11319   SDValue N2 = N->getOperand(2);
11320   EVT VT = N->getValueType(0);
11321   EVT CondVT = Cond.getValueType();
11322   SDLoc DL(N);
11323 
11324   if (!VT.isInteger())
11325     return SDValue();
11326 
11327   auto *C1 = dyn_cast<ConstantSDNode>(N1);
11328   auto *C2 = dyn_cast<ConstantSDNode>(N2);
11329   if (!C1 || !C2)
11330     return SDValue();
11331 
11332   if (CondVT != MVT::i1 || LegalOperations) {
11333     // fold (select Cond, 0, 1) -> (xor Cond, 1)
11334     // We can't do this reliably if integer-based booleans have different contents
11335     // from floating-point-based booleans. This is because we can't tell whether we
11336     // have an integer-based boolean or a floating-point-based boolean unless we
11337     // can find the SETCC that produced it and inspect its operands. This is
11338     // fairly easy if C is the SETCC node, but it can potentially be
11339     // undiscoverable (or not reasonably discoverable). For example, it could be
11340     // in another basic block or it could require searching a complicated
11341     // expression.
11342     if (CondVT.isInteger() &&
11343         TLI.getBooleanContents(/*isVec*/false, /*isFloat*/true) ==
11344             TargetLowering::ZeroOrOneBooleanContent &&
11345         TLI.getBooleanContents(/*isVec*/false, /*isFloat*/false) ==
11346             TargetLowering::ZeroOrOneBooleanContent &&
11347         C1->isZero() && C2->isOne()) {
11348       SDValue NotCond =
11349           DAG.getNode(ISD::XOR, DL, CondVT, Cond, DAG.getConstant(1, DL, CondVT));
11350       if (VT.bitsEq(CondVT))
11351         return NotCond;
11352       return DAG.getZExtOrTrunc(NotCond, DL, VT);
11353     }
11354 
11355     return SDValue();
11356   }
11357 
11358   // Only do this before legalization to avoid conflicting with target-specific
11359   // transforms in the other direction (create a select from a zext/sext). There
11360   // is also a target-independent combine here in DAGCombiner in the other
11361   // direction for (select Cond, -1, 0) when the condition is not i1.
11362   assert(CondVT == MVT::i1 && !LegalOperations);
11363 
11364   // select Cond, 1, 0 --> zext (Cond)
11365   if (C1->isOne() && C2->isZero())
11366     return DAG.getZExtOrTrunc(Cond, DL, VT);
11367 
11368   // select Cond, -1, 0 --> sext (Cond)
11369   if (C1->isAllOnes() && C2->isZero())
11370     return DAG.getSExtOrTrunc(Cond, DL, VT);
11371 
11372   // select Cond, 0, 1 --> zext (!Cond)
11373   if (C1->isZero() && C2->isOne()) {
11374     SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
11375     NotCond = DAG.getZExtOrTrunc(NotCond, DL, VT);
11376     return NotCond;
11377   }
11378 
11379   // select Cond, 0, -1 --> sext (!Cond)
11380   if (C1->isZero() && C2->isAllOnes()) {
11381     SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
11382     NotCond = DAG.getSExtOrTrunc(NotCond, DL, VT);
11383     return NotCond;
11384   }
11385 
11386   // Use a target hook because some targets may prefer to transform in the
11387   // other direction.
11388   if (!shouldConvertSelectOfConstantsToMath(Cond, VT, TLI))
11389     return SDValue();
11390 
11391   // For any constants that differ by 1, we can transform the select into
11392   // an extend and add.
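  // e.g. select Cond, 5, 4 --> add (zext Cond), 4
  //      select Cond, 3, 4 --> add (sext Cond), 4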
11393   const APInt &C1Val = C1->getAPIntValue();
11394   const APInt &C2Val = C2->getAPIntValue();
11395 
11396   // select Cond, C1, C1-1 --> add (zext Cond), C1-1
11397   if (C1Val - 1 == C2Val) {
11398     Cond = DAG.getZExtOrTrunc(Cond, DL, VT);
11399     return DAG.getNode(ISD::ADD, DL, VT, Cond, N2);
11400   }
11401 
11402   // select Cond, C1, C1+1 --> add (sext Cond), C1+1
11403   if (C1Val + 1 == C2Val) {
11404     Cond = DAG.getSExtOrTrunc(Cond, DL, VT);
11405     return DAG.getNode(ISD::ADD, DL, VT, Cond, N2);
11406   }
11407 
11408   // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2)
11409   if (C1Val.isPowerOf2() && C2Val.isZero()) {
11410     Cond = DAG.getZExtOrTrunc(Cond, DL, VT);
11411     SDValue ShAmtC =
11412         DAG.getShiftAmountConstant(C1Val.exactLogBase2(), VT, DL);
11413     return DAG.getNode(ISD::SHL, DL, VT, Cond, ShAmtC);
11414   }
11415 
11416   // select Cond, -1, C --> or (sext Cond), C
11417   if (C1->isAllOnes()) {
11418     Cond = DAG.getSExtOrTrunc(Cond, DL, VT);
11419     return DAG.getNode(ISD::OR, DL, VT, Cond, N2);
11420   }
11421 
11422   // select Cond, C, -1 --> or (sext (not Cond)), C
11423   if (C2->isAllOnes()) {
11424     SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
11425     NotCond = DAG.getSExtOrTrunc(NotCond, DL, VT);
11426     return DAG.getNode(ISD::OR, DL, VT, NotCond, N1);
11427   }
11428 
11429   if (SDValue V = foldSelectOfConstantsUsingSra(N, DAG))
11430     return V;
11431 
11432   return SDValue();
11433 }
11434 
11435 static SDValue foldBoolSelectToLogic(SDNode *N, SelectionDAG &DAG) {
11436   assert((N->getOpcode() == ISD::SELECT || N->getOpcode() == ISD::VSELECT) &&
11437          "Expected a (v)select");
11438   SDValue Cond = N->getOperand(0);
11439   SDValue T = N->getOperand(1), F = N->getOperand(2);
11440   EVT VT = N->getValueType(0);
11441   if (VT != Cond.getValueType() || VT.getScalarSizeInBits() != 1)
11442     return SDValue();
11443 
11444   // select Cond, Cond, F --> or Cond, F
11445   // select Cond, 1, F    --> or Cond, F
11446   if (Cond == T || isOneOrOneSplat(T, /* AllowUndefs */ true))
11447     return DAG.getNode(ISD::OR, SDLoc(N), VT, Cond, F);
11448 
11449   // select Cond, T, Cond --> and Cond, T
11450   // select Cond, T, 0    --> and Cond, T
11451   if (Cond == F || isNullOrNullSplat(F, /* AllowUndefs */ true))
11452     return DAG.getNode(ISD::AND, SDLoc(N), VT, Cond, T);
11453 
11454   // select Cond, T, 1 --> or (not Cond), T
11455   if (isOneOrOneSplat(F, /* AllowUndefs */ true)) {
11456     SDValue NotCond = DAG.getNOT(SDLoc(N), Cond, VT);
11457     return DAG.getNode(ISD::OR, SDLoc(N), VT, NotCond, T);
11458   }
11459 
11460   // select Cond, 0, F --> and (not Cond), F
11461   if (isNullOrNullSplat(T, /* AllowUndefs */ true)) {
11462     SDValue NotCond = DAG.getNOT(SDLoc(N), Cond, VT);
11463     return DAG.getNode(ISD::AND, SDLoc(N), VT, NotCond, F);
11464   }
11465 
11466   return SDValue();
11467 }
11468 
11469 static SDValue foldVSelectToSignBitSplatMask(SDNode *N, SelectionDAG &DAG) {
11470   SDValue N0 = N->getOperand(0);
11471   SDValue N1 = N->getOperand(1);
11472   SDValue N2 = N->getOperand(2);
11473   EVT VT = N->getValueType(0);
11474   if (N0.getOpcode() != ISD::SETCC || !N0.hasOneUse())
11475     return SDValue();
11476 
11477   SDValue Cond0 = N0.getOperand(0);
11478   SDValue Cond1 = N0.getOperand(1);
11479   ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
11480   if (VT != Cond0.getValueType())
11481     return SDValue();
11482 
11483   // Match a signbit check of Cond0 as "Cond0 s<0". Swap select operands if the
11484   // compare is inverted from that pattern ("Cond0 s> -1").
11485   if (CC == ISD::SETLT && isNullOrNullSplat(Cond1))
11486     ; // This is the pattern we are looking for.
11487   else if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(Cond1))
11488     std::swap(N1, N2);
11489   else
11490     return SDValue();
11491 
11492   // (Cond0 s< 0) ? N1 : 0 --> (Cond0 s>> BW-1) & N1
11493   if (isNullOrNullSplat(N2)) {
11494     SDLoc DL(N);
11495     SDValue ShiftAmt = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
11496     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
11497     return DAG.getNode(ISD::AND, DL, VT, Sra, N1);
11498   }
11499 
11500   // (Cond0 s< 0) ? -1 : N2 --> (Cond0 s>> BW-1) | N2
11501   if (isAllOnesOrAllOnesSplat(N1)) {
11502     SDLoc DL(N);
11503     SDValue ShiftAmt = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
11504     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
11505     return DAG.getNode(ISD::OR, DL, VT, Sra, N2);
11506   }
11507 
11508   // If we have to invert the sign bit mask, only do that transform if the
11509   // target has a bitwise 'and not' instruction (the invert is free).
11510   // (Cond0 s< 0) ? 0 : N2 --> ~(Cond0 s>> BW-1) & N2
11511   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
11512   if (isNullOrNullSplat(N1) && TLI.hasAndNot(N1)) {
11513     SDLoc DL(N);
11514     SDValue ShiftAmt = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
11515     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
11516     SDValue Not = DAG.getNOT(DL, Sra, VT);
11517     return DAG.getNode(ISD::AND, DL, VT, Not, N2);
11518   }
11519 
11520   // TODO: There's another pattern in this family, but it may require
11521   //       implementing hasOrNot() to check for profitability:
11522   //       (Cond0 s> -1) ? -1 : N2 --> ~(Cond0 s>> BW-1) | N2
11523 
11524   return SDValue();
11525 }
11526 
11527 SDValue DAGCombiner::visitSELECT(SDNode *N) {
11528   SDValue N0 = N->getOperand(0);
11529   SDValue N1 = N->getOperand(1);
11530   SDValue N2 = N->getOperand(2);
11531   EVT VT = N->getValueType(0);
11532   EVT VT0 = N0.getValueType();
11533   SDLoc DL(N);
11534   SDNodeFlags Flags = N->getFlags();
11535 
11536   if (SDValue V = DAG.simplifySelect(N0, N1, N2))
11537     return V;
11538 
11539   if (SDValue V = foldBoolSelectToLogic(N, DAG))
11540     return V;
11541 
11542   // select (not Cond), N1, N2 -> select Cond, N2, N1
11543   if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false)) {
11544     SDValue SelectOp = DAG.getSelect(DL, VT, F, N2, N1);
11545     SelectOp->setFlags(Flags);
11546     return SelectOp;
11547   }
11548 
11549   if (SDValue V = foldSelectOfConstants(N))
11550     return V;
11551 
11552   // If we can fold this based on the true/false value, do so.
11553   if (SimplifySelectOps(N, N1, N2))
11554     return SDValue(N, 0); // Don't revisit N.
11555 
11556   if (VT0 == MVT::i1) {
11557     // The code in this block deals with the following 2 equivalences:
11558     //    select(C0|C1, x, y) <=> select(C0, x, select(C1, x, y))
11559     //    select(C0&C1, x, y) <=> select(C0, select(C1, x, y), y)
11560     // The target can specify its preferred form with the
11561     // shouldNormalizeToSelectSequence() callback. However, we always
11562     // transform to the right-hand form if the inner select already exists
11563     // in the DAG, and we always transform to the left-hand form if we know
11564     // we can further optimize the combination of the conditions.
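    // For example (illustrative): select (and C0, C1), X, Y expands to
    // select C0, (select C1, X, Y), Y, so Y is chosen unless both conditions
    // are true.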
11565     bool normalizeToSequence =
11566         TLI.shouldNormalizeToSelectSequence(*DAG.getContext(), VT);
11567     // select (and Cond0, Cond1), X, Y
11568     //   -> select Cond0, (select Cond1, X, Y), Y
11569     if (N0->getOpcode() == ISD::AND && N0->hasOneUse()) {
11570       SDValue Cond0 = N0->getOperand(0);
11571       SDValue Cond1 = N0->getOperand(1);
11572       SDValue InnerSelect =
11573           DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond1, N1, N2, Flags);
11574       if (normalizeToSequence || !InnerSelect.use_empty())
11575         return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0,
11576                            InnerSelect, N2, Flags);
11577       // Cleanup on failure.
11578       if (InnerSelect.use_empty())
11579         recursivelyDeleteUnusedNodes(InnerSelect.getNode());
11580     }
11581     // select (or Cond0, Cond1), X, Y -> select Cond0, X, (select Cond1, X, Y)
11582     if (N0->getOpcode() == ISD::OR && N0->hasOneUse()) {
11583       SDValue Cond0 = N0->getOperand(0);
11584       SDValue Cond1 = N0->getOperand(1);
11585       SDValue InnerSelect = DAG.getNode(ISD::SELECT, DL, N1.getValueType(),
11586                                         Cond1, N1, N2, Flags);
11587       if (normalizeToSequence || !InnerSelect.use_empty())
11588         return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0, N1,
11589                            InnerSelect, Flags);
11590       // Cleanup on failure.
11591       if (InnerSelect.use_empty())
11592         recursivelyDeleteUnusedNodes(InnerSelect.getNode());
11593     }
11594 
11595     // select Cond0, (select Cond1, X, Y), Y -> select (and Cond0, Cond1), X, Y
11596     if (N1->getOpcode() == ISD::SELECT && N1->hasOneUse()) {
11597       SDValue N1_0 = N1->getOperand(0);
11598       SDValue N1_1 = N1->getOperand(1);
11599       SDValue N1_2 = N1->getOperand(2);
11600       if (N1_2 == N2 && N0.getValueType() == N1_0.getValueType()) {
11601         // Create the actual and node if we can generate good code for it.
11602         if (!normalizeToSequence) {
11603           SDValue And = DAG.getNode(ISD::AND, DL, N0.getValueType(), N0, N1_0);
11604           return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), And, N1_1,
11605                              N2, Flags);
11606         }
11607         // Otherwise see if we can optimize the "and" to a better pattern.
11608         if (SDValue Combined = visitANDLike(N0, N1_0, N)) {
11609           return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1_1,
11610                              N2, Flags);
11611         }
11612       }
11613     }
11614     // select Cond0, X, (select Cond1, X, Y) -> select (or Cond0, Cond1), X, Y
11615     if (N2->getOpcode() == ISD::SELECT && N2->hasOneUse()) {
11616       SDValue N2_0 = N2->getOperand(0);
11617       SDValue N2_1 = N2->getOperand(1);
11618       SDValue N2_2 = N2->getOperand(2);
11619       if (N2_1 == N1 && N0.getValueType() == N2_0.getValueType()) {
11620         // Create the actual or node if we can generate good code for it.
11621         if (!normalizeToSequence) {
11622           SDValue Or = DAG.getNode(ISD::OR, DL, N0.getValueType(), N0, N2_0);
11623           return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Or, N1,
11624                              N2_2, Flags);
11625         }
11626         // Otherwise see if we can optimize to a better pattern.
11627         if (SDValue Combined = visitORLike(N0, N2_0, N))
11628           return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1,
11629                              N2_2, Flags);
11630       }
11631     }
11632   }
11633 
11634   // Fold selects based on a setcc into other things, such as min/max/abs.
11635   if (N0.getOpcode() == ISD::SETCC) {
11636     SDValue Cond0 = N0.getOperand(0), Cond1 = N0.getOperand(1);
11637     ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
11638 
11639     // select (fcmp lt x, y), x, y -> fminnum x, y
11640     // select (fcmp gt x, y), x, y -> fmaxnum x, y
11641     //
11642     // This is OK if we don't care what happens if either operand is a NaN.
11643     if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, N1, N2, TLI))
11644       if (SDValue FMinMax =
11645               combineMinNumMaxNum(DL, VT, Cond0, Cond1, N1, N2, CC))
11646         return FMinMax;
11647 
11648     // Use 'unsigned add with overflow' to optimize an unsigned saturating add.
11649     // This is conservatively limited to pre-legal-operations to give targets
11650     // a chance to reverse the transform if they want to do that. Also, it is
11651     // unlikely that the pattern would be formed late, so it's probably not
11652     // worth going through the other checks.
11653     if (!LegalOperations && TLI.isOperationLegalOrCustom(ISD::UADDO, VT) &&
11654         CC == ISD::SETUGT && N0.hasOneUse() && isAllOnesConstant(N1) &&
11655         N2.getOpcode() == ISD::ADD && Cond0 == N2.getOperand(0)) {
11656       auto *C = dyn_cast<ConstantSDNode>(N2.getOperand(1));
11657       auto *NotC = dyn_cast<ConstantSDNode>(Cond1);
11658       if (C && NotC && C->getAPIntValue() == ~NotC->getAPIntValue()) {
11659         // select (setcc Cond0, ~C, ugt), -1, (add Cond0, C) -->
11660         // uaddo Cond0, C; select uaddo.1, -1, uaddo.0
11661         //
11662         // The IR equivalent of this transform would have this form:
11663         //   %a = add %x, C
11664         //   %c = icmp ugt %x, ~C
11665         //   %r = select %c, -1, %a
11666         //   =>
11667         //   %u = call {iN,i1} llvm.uadd.with.overflow(%x, C)
11668         //   %u0 = extractvalue %u, 0
11669         //   %u1 = extractvalue %u, 1
11670         //   %r = select %u1, -1, %u0
11671         SDVTList VTs = DAG.getVTList(VT, VT0);
11672         SDValue UAO = DAG.getNode(ISD::UADDO, DL, VTs, Cond0, N2.getOperand(1));
11673         return DAG.getSelect(DL, VT, UAO.getValue(1), N1, UAO.getValue(0));
11674       }
11675     }
11676 
11677     if (TLI.isOperationLegal(ISD::SELECT_CC, VT) ||
11678         (!LegalOperations &&
11679          TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT))) {
11680       // Any flags available in a select/setcc fold will be on the setcc as they
11681       // migrated from fcmp
11682       Flags = N0->getFlags();
11683       SDValue SelectNode = DAG.getNode(ISD::SELECT_CC, DL, VT, Cond0, Cond1, N1,
11684                                        N2, N0.getOperand(2));
11685       SelectNode->setFlags(Flags);
11686       return SelectNode;
11687     }
11688 
11689     if (SDValue NewSel = SimplifySelect(DL, N0, N1, N2))
11690       return NewSel;
11691   }
11692 
11693   if (!VT.isVector())
11694     if (SDValue BinOp = foldSelectOfBinops(N))
11695       return BinOp;
11696 
11697   if (SDValue R = combineSelectAsExtAnd(N0, N1, N2, DL, DAG))
11698     return R;
11699 
11700   return SDValue();
11701 }
11702 
11703 // This function assumes all of the vselect's arguments are CONCAT_VECTORS
11704 // nodes and that the condition is a BV of ConstantSDNodes (or undefs).
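// Illustrative example: with Cond = <0,0,-1,-1>, LHS = concat(A,B) and
// RHS = concat(C,D), the result is concat(C,B): the bottom half selects from
// RHS and the top half from LHS.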
11705 static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
11706   SDLoc DL(N);
11707   SDValue Cond = N->getOperand(0);
11708   SDValue LHS = N->getOperand(1);
11709   SDValue RHS = N->getOperand(2);
11710   EVT VT = N->getValueType(0);
11711   int NumElems = VT.getVectorNumElements();
11712   assert(LHS.getOpcode() == ISD::CONCAT_VECTORS &&
11713          RHS.getOpcode() == ISD::CONCAT_VECTORS &&
11714          Cond.getOpcode() == ISD::BUILD_VECTOR);
11715 
11716   // CONCAT_VECTORS can take an arbitrary number of arguments. We only care
11717   // about binary ones here.
11718   if (LHS->getNumOperands() != 2 || RHS->getNumOperands() != 2)
11719     return SDValue();
11720 
11721   // We're sure we have an even number of elements due to the
11722   // concat_vectors we have as arguments to vselect.
11723   // Skip BV elements until we find one that's not an UNDEF.
11724   // After we find a non-UNDEF element, keep looping until we get to half the
11725   // length of the BV and check that all the non-undef nodes are the same.
11726   ConstantSDNode *BottomHalf = nullptr;
11727   for (int i = 0; i < NumElems / 2; ++i) {
11728     if (Cond->getOperand(i)->isUndef())
11729       continue;
11730 
11731     if (BottomHalf == nullptr)
11732       BottomHalf = cast<ConstantSDNode>(Cond.getOperand(i));
11733     else if (Cond->getOperand(i).getNode() != BottomHalf)
11734       return SDValue();
11735   }
11736 
11737   // Do the same for the second half of the BuildVector
11738   ConstantSDNode *TopHalf = nullptr;
11739   for (int i = NumElems / 2; i < NumElems; ++i) {
11740     if (Cond->getOperand(i)->isUndef())
11741       continue;
11742 
11743     if (TopHalf == nullptr)
11744       TopHalf = cast<ConstantSDNode>(Cond.getOperand(i));
11745     else if (Cond->getOperand(i).getNode() != TopHalf)
11746       return SDValue();
11747   }
11748 
11749   assert(TopHalf && BottomHalf &&
11750          "One half of the selector was all UNDEFs and the other was all the "
11751          "same value. This should have been addressed before this function.");
11752   return DAG.getNode(
11753       ISD::CONCAT_VECTORS, DL, VT,
11754       BottomHalf->isZero() ? RHS->getOperand(0) : LHS->getOperand(0),
11755       TopHalf->isZero() ? RHS->getOperand(1) : LHS->getOperand(1));
11756 }
11757 
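// Try to fold a uniform (splat) component of a gather/scatter index into the
// scalar base pointer. Illustrative example: a gather with BasePtr = null and
// Index = splat(%p) + %offsets becomes BasePtr = %p with Index = %offsets.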
11758 bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
11759                        SelectionDAG &DAG, const SDLoc &DL) {
11760 
11761   // Only perform the transformation when existing operands can be reused.
11762   if (IndexIsScaled)
11763     return false;
11764 
11765   if (!isNullConstant(BasePtr) && !Index.hasOneUse())
11766     return false;
11767 
11768   EVT VT = BasePtr.getValueType();
11769 
11770   if (SDValue SplatVal = DAG.getSplatValue(Index);
11771       SplatVal && !isNullConstant(SplatVal) &&
11772       SplatVal.getValueType() == VT) {
11773     BasePtr = DAG.getNode(ISD::ADD, DL, VT, BasePtr, SplatVal);
11774     Index = DAG.getSplat(Index.getValueType(), DL, DAG.getConstant(0, DL, VT));
11775     return true;
11776   }
11777 
11778   if (Index.getOpcode() != ISD::ADD)
11779     return false;
11780 
11781   if (SDValue SplatVal = DAG.getSplatValue(Index.getOperand(0));
11782       SplatVal && SplatVal.getValueType() == VT) {
11783     BasePtr = DAG.getNode(ISD::ADD, DL, VT, BasePtr, SplatVal);
11784     Index = Index.getOperand(1);
11785     return true;
11786   }
11787   if (SDValue SplatVal = DAG.getSplatValue(Index.getOperand(1));
11788       SplatVal && SplatVal.getValueType() == VT) {
11789     BasePtr = DAG.getNode(ISD::ADD, DL, VT, BasePtr, SplatVal);
11790     Index = Index.getOperand(0);
11791     return true;
11792   }
11793   return false;
11794 }
11795 
11796 // Fold sext/zext of index into index type.
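// Illustrative example: a gather whose Index is (zext nxv4i32 %idxs to
// nxv4i64) can use %idxs directly as an UNSIGNED_SCALED index when the target
// reports the extend as removable, since zero extension never changes the
// unsigned index value.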
11797 bool refineIndexType(SDValue &Index, ISD::MemIndexType &IndexType, EVT DataVT,
11798                      SelectionDAG &DAG) {
11799   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
11800 
11801   // It's always safe to look through zero extends.
11802   if (Index.getOpcode() == ISD::ZERO_EXTEND) {
11803     if (TLI.shouldRemoveExtendFromGSIndex(Index, DataVT)) {
11804       IndexType = ISD::UNSIGNED_SCALED;
11805       Index = Index.getOperand(0);
11806       return true;
11807     }
11808     if (ISD::isIndexTypeSigned(IndexType)) {
11809       IndexType = ISD::UNSIGNED_SCALED;
11810       return true;
11811     }
11812   }
11813 
11814   // It's only safe to look through sign extends when Index is signed.
11815   if (Index.getOpcode() == ISD::SIGN_EXTEND &&
11816       ISD::isIndexTypeSigned(IndexType) &&
11817       TLI.shouldRemoveExtendFromGSIndex(Index, DataVT)) {
11818     Index = Index.getOperand(0);
11819     return true;
11820   }
11821 
11822   return false;
11823 }
11824 
11825 SDValue DAGCombiner::visitVPSCATTER(SDNode *N) {
11826   VPScatterSDNode *MSC = cast<VPScatterSDNode>(N);
11827   SDValue Mask = MSC->getMask();
11828   SDValue Chain = MSC->getChain();
11829   SDValue Index = MSC->getIndex();
11830   SDValue Scale = MSC->getScale();
11831   SDValue StoreVal = MSC->getValue();
11832   SDValue BasePtr = MSC->getBasePtr();
11833   SDValue VL = MSC->getVectorLength();
11834   ISD::MemIndexType IndexType = MSC->getIndexType();
11835   SDLoc DL(N);
11836 
11837   // Zap scatters with a zero mask.
11838   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
11839     return Chain;
11840 
11841   if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG, DL)) {
11842     SDValue Ops[] = {Chain, StoreVal, BasePtr, Index, Scale, Mask, VL};
11843     return DAG.getScatterVP(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
11844                             DL, Ops, MSC->getMemOperand(), IndexType);
11845   }
11846 
11847   if (refineIndexType(Index, IndexType, StoreVal.getValueType(), DAG)) {
11848     SDValue Ops[] = {Chain, StoreVal, BasePtr, Index, Scale, Mask, VL};
11849     return DAG.getScatterVP(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
11850                             DL, Ops, MSC->getMemOperand(), IndexType);
11851   }
11852 
11853   return SDValue();
11854 }
11855 
11856 SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
11857   MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(N);
11858   SDValue Mask = MSC->getMask();
11859   SDValue Chain = MSC->getChain();
11860   SDValue Index = MSC->getIndex();
11861   SDValue Scale = MSC->getScale();
11862   SDValue StoreVal = MSC->getValue();
11863   SDValue BasePtr = MSC->getBasePtr();
11864   ISD::MemIndexType IndexType = MSC->getIndexType();
11865   SDLoc DL(N);
11866 
11867   // Zap scatters with a zero mask.
11868   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
11869     return Chain;
11870 
11871   if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG, DL)) {
11872     SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
11873     return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
11874                                 DL, Ops, MSC->getMemOperand(), IndexType,
11875                                 MSC->isTruncatingStore());
11876   }
11877 
11878   if (refineIndexType(Index, IndexType, StoreVal.getValueType(), DAG)) {
11879     SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
11880     return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
11881                                 DL, Ops, MSC->getMemOperand(), IndexType,
11882                                 MSC->isTruncatingStore());
11883   }
11884 
11885   return SDValue();
11886 }
11887 
11888 SDValue DAGCombiner::visitMSTORE(SDNode *N) {
11889   MaskedStoreSDNode *MST = cast<MaskedStoreSDNode>(N);
11890   SDValue Mask = MST->getMask();
11891   SDValue Chain = MST->getChain();
11892   SDValue Value = MST->getValue();
11893   SDValue Ptr = MST->getBasePtr();
11894   SDLoc DL(N);
11895 
11896   // Zap masked stores with a zero mask.
11897   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
11898     return Chain;
11899 
11900   // Remove a masked store if base pointers and masks are equal.
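  // Illustrative example: two back-to-back masked stores to the same pointer
  // with the same mask make the earlier store dead, so its chain can be
  // bypassed.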
11901   if (MaskedStoreSDNode *MST1 = dyn_cast<MaskedStoreSDNode>(Chain)) {
11902     if (MST->isUnindexed() && MST->isSimple() && MST1->isUnindexed() &&
11903         MST1->isSimple() && MST1->getBasePtr() == Ptr &&
11904         !MST->getBasePtr().isUndef() &&
11905         ((Mask == MST1->getMask() && MST->getMemoryVT().getStoreSize() ==
11906                                          MST1->getMemoryVT().getStoreSize()) ||
11907          ISD::isConstantSplatVectorAllOnes(Mask.getNode())) &&
11908         TypeSize::isKnownLE(MST1->getMemoryVT().getStoreSize(),
11909                             MST->getMemoryVT().getStoreSize())) {
11910       CombineTo(MST1, MST1->getChain());
11911       if (N->getOpcode() != ISD::DELETED_NODE)
11912         AddToWorklist(N);
11913       return SDValue(N, 0);
11914     }
11915   }
11916 
11917   // If this is a masked store with an all ones mask, we can use an unmasked store.
11918   // FIXME: Can we do this for indexed, compressing, or truncating stores?
11919   if (ISD::isConstantSplatVectorAllOnes(Mask.getNode()) && MST->isUnindexed() &&
11920       !MST->isCompressingStore() && !MST->isTruncatingStore())
11921     return DAG.getStore(MST->getChain(), SDLoc(N), MST->getValue(),
11922                         MST->getBasePtr(), MST->getPointerInfo(),
11923                         MST->getOriginalAlign(), MachineMemOperand::MOStore,
11924                         MST->getAAInfo());
11925 
11926   // Try transforming N to an indexed store.
11927   if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
11928     return SDValue(N, 0);
11929 
11930   if (MST->isTruncatingStore() && MST->isUnindexed() &&
11931       Value.getValueType().isInteger() &&
11932       (!isa<ConstantSDNode>(Value) ||
11933        !cast<ConstantSDNode>(Value)->isOpaque())) {
11934     APInt TruncDemandedBits =
11935         APInt::getLowBitsSet(Value.getScalarValueSizeInBits(),
11936                              MST->getMemoryVT().getScalarSizeInBits());
11937 
11938     // See if we can simplify the operation with
11939     // SimplifyDemandedBits, which only works if the value has a single use.
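    // Illustrative example: a truncating store to i8 only demands the low 8
    // bits of Value, so a masking operation such as (and Value, 0xff) feeding
    // the store can be dropped.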
11940     if (SimplifyDemandedBits(Value, TruncDemandedBits)) {
11941       // Re-visit the store if anything changed and the store hasn't been
11942       // merged with another node (in which case N is deleted).
11943       // SimplifyDemandedBits will add Value's node back to the worklist if
11944       // necessary, but we also need to re-visit the Store node itself.
11945       if (N->getOpcode() != ISD::DELETED_NODE)
11946         AddToWorklist(N);
11947       return SDValue(N, 0);
11948     }
11949   }
11950 
11951   // If this is a TRUNC followed by a masked store, fold this into a masked
11952   // truncating store.  We can do this even if this is already a masked
11953   // truncstore.
11954   // TODO: Try to combine to a masked compress store if possible.
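  // Illustrative example, roughly in IR terms: a masked store of
  // (trunc %v to <4 x i16>) becomes a masked truncating store of %v with
  // memory VT v4i16, with the mask promoted to match the wider value type.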
11955   if ((Value.getOpcode() == ISD::TRUNCATE) && Value->hasOneUse() &&
11956       MST->isUnindexed() && !MST->isCompressingStore() &&
11957       TLI.canCombineTruncStore(Value.getOperand(0).getValueType(),
11958                                MST->getMemoryVT(), LegalOperations)) {
11959     auto Mask = TLI.promoteTargetBoolean(DAG, MST->getMask(),
11960                                          Value.getOperand(0).getValueType());
11961     return DAG.getMaskedStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
11962                               MST->getOffset(), Mask, MST->getMemoryVT(),
11963                               MST->getMemOperand(), MST->getAddressingMode(),
11964                               /*IsTruncating=*/true);
11965   }
11966 
11967   return SDValue();
11968 }
11969 
11970 SDValue DAGCombiner::visitVP_STRIDED_STORE(SDNode *N) {
11971   auto *SST = cast<VPStridedStoreSDNode>(N);
11972   EVT EltVT = SST->getValue().getValueType().getVectorElementType();
11973   // Combine strided stores with unit-stride to a regular VP store.
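  // Illustrative example: for v4i32 elements, a constant stride of 4 bytes
  // (the element store size) means the accesses are contiguous.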
11974   if (auto *CStride = dyn_cast<ConstantSDNode>(SST->getStride());
11975       CStride && CStride->getZExtValue() == EltVT.getStoreSize()) {
11976     return DAG.getStoreVP(SST->getChain(), SDLoc(N), SST->getValue(),
11977                           SST->getBasePtr(), SST->getOffset(), SST->getMask(),
11978                           SST->getVectorLength(), SST->getMemoryVT(),
11979                           SST->getMemOperand(), SST->getAddressingMode(),
11980                           SST->isTruncatingStore(), SST->isCompressingStore());
11981   }
11982   return SDValue();
11983 }
11984 
11985 SDValue DAGCombiner::visitVPGATHER(SDNode *N) {
11986   VPGatherSDNode *MGT = cast<VPGatherSDNode>(N);
11987   SDValue Mask = MGT->getMask();
11988   SDValue Chain = MGT->getChain();
11989   SDValue Index = MGT->getIndex();
11990   SDValue Scale = MGT->getScale();
11991   SDValue BasePtr = MGT->getBasePtr();
11992   SDValue VL = MGT->getVectorLength();
11993   ISD::MemIndexType IndexType = MGT->getIndexType();
11994   SDLoc DL(N);
11995 
11996   if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG, DL)) {
11997     SDValue Ops[] = {Chain, BasePtr, Index, Scale, Mask, VL};
11998     return DAG.getGatherVP(
11999         DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
12000         Ops, MGT->getMemOperand(), IndexType);
12001   }
12002 
12003   if (refineIndexType(Index, IndexType, N->getValueType(0), DAG)) {
12004     SDValue Ops[] = {Chain, BasePtr, Index, Scale, Mask, VL};
12005     return DAG.getGatherVP(
12006         DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
12007         Ops, MGT->getMemOperand(), IndexType);
12008   }
12009 
12010   return SDValue();
12011 }
12012 
12013 SDValue DAGCombiner::visitMGATHER(SDNode *N) {
12014   MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(N);
12015   SDValue Mask = MGT->getMask();
12016   SDValue Chain = MGT->getChain();
12017   SDValue Index = MGT->getIndex();
12018   SDValue Scale = MGT->getScale();
12019   SDValue PassThru = MGT->getPassThru();
12020   SDValue BasePtr = MGT->getBasePtr();
12021   ISD::MemIndexType IndexType = MGT->getIndexType();
12022   SDLoc DL(N);
12023 
12024   // Zap gathers with a zero mask.
12025   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
12026     return CombineTo(N, PassThru, MGT->getChain());
12027 
12028   if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG, DL)) {
12029     SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
12030     return DAG.getMaskedGather(
12031         DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
12032         Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
12033   }
12034 
12035   if (refineIndexType(Index, IndexType, N->getValueType(0), DAG)) {
12036     SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
12037     return DAG.getMaskedGather(
12038         DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
12039         Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
12040   }
12041 
12042   return SDValue();
12043 }
12044 
12045 SDValue DAGCombiner::visitMLOAD(SDNode *N) {
12046   MaskedLoadSDNode *MLD = cast<MaskedLoadSDNode>(N);
12047   SDValue Mask = MLD->getMask();
12048   SDLoc DL(N);
12049 
12050   // Zap masked loads with a zero mask.
12051   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
12052     return CombineTo(N, MLD->getPassThru(), MLD->getChain());
12053 
12054   // If this is a masked load with an all ones mask, we can use an unmasked load.
12055   // FIXME: Can we do this for indexed, expanding, or extending loads?
12056   if (ISD::isConstantSplatVectorAllOnes(Mask.getNode()) && MLD->isUnindexed() &&
12057       !MLD->isExpandingLoad() && MLD->getExtensionType() == ISD::NON_EXTLOAD) {
12058     SDValue NewLd = DAG.getLoad(
12059         N->getValueType(0), SDLoc(N), MLD->getChain(), MLD->getBasePtr(),
12060         MLD->getPointerInfo(), MLD->getOriginalAlign(),
12061         MachineMemOperand::MOLoad, MLD->getAAInfo(), MLD->getRanges());
12062     return CombineTo(N, NewLd, NewLd.getValue(1));
12063   }
12064 
12065   // Try transforming N to an indexed load.
12066   if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
12067     return SDValue(N, 0);
12068 
12069   return SDValue();
12070 }
12071 
12072 SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
12073   auto *SLD = cast<VPStridedLoadSDNode>(N);
12074   EVT EltVT = SLD->getValueType(0).getVectorElementType();
12075   // Combine strided loads with unit-stride to a regular VP load.
12076   if (auto *CStride = dyn_cast<ConstantSDNode>(SLD->getStride());
12077       CStride && CStride->getZExtValue() == EltVT.getStoreSize()) {
12078     SDValue NewLd = DAG.getLoadVP(
12079         SLD->getAddressingMode(), SLD->getExtensionType(), SLD->getValueType(0),
12080         SDLoc(N), SLD->getChain(), SLD->getBasePtr(), SLD->getOffset(),
12081         SLD->getMask(), SLD->getVectorLength(), SLD->getMemoryVT(),
12082         SLD->getMemOperand(), SLD->isExpandingLoad());
12083     return CombineTo(N, NewLd, NewLd.getValue(1));
12084   }
12085   return SDValue();
12086 }
12087 
12088 /// A vector select of 2 constant vectors can be simplified to math/logic to
12089 /// avoid a variable select instruction and possibly avoid constant loads.
12090 SDValue DAGCombiner::foldVSelectOfConstants(SDNode *N) {
12091   SDValue Cond = N->getOperand(0);
12092   SDValue N1 = N->getOperand(1);
12093   SDValue N2 = N->getOperand(2);
12094   EVT VT = N->getValueType(0);
12095   if (!Cond.hasOneUse() || Cond.getScalarValueSizeInBits() != 1 ||
12096       !shouldConvertSelectOfConstantsToMath(Cond, VT, TLI) ||
12097       !ISD::isBuildVectorOfConstantSDNodes(N1.getNode()) ||
12098       !ISD::isBuildVectorOfConstantSDNodes(N2.getNode()))
12099     return SDValue();
12100 
12101   // Check if we can use the condition value to increment/decrement a single
12102   // constant value. This simplifies a select to an add and removes a constant
12103   // load/materialization from the general case.
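  // Illustrative example: vselect Cond, <3,5>, <2,4> has C1 == C2 + 1 in each
  // lane, so it becomes add (zext Cond), <2,4>.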
12104   bool AllAddOne = true;
12105   bool AllSubOne = true;
12106   unsigned Elts = VT.getVectorNumElements();
12107   for (unsigned i = 0; i != Elts; ++i) {
12108     SDValue N1Elt = N1.getOperand(i);
12109     SDValue N2Elt = N2.getOperand(i);
12110     if (N1Elt.isUndef() || N2Elt.isUndef())
12111       continue;
12112     if (N1Elt.getValueType() != N2Elt.getValueType())
12113       continue;
12114 
12115     const APInt &C1 = N1Elt->getAsAPIntVal();
12116     const APInt &C2 = N2Elt->getAsAPIntVal();
12117     if (C1 != C2 + 1)
12118       AllAddOne = false;
12119     if (C1 != C2 - 1)
12120       AllSubOne = false;
12121   }
12122 
12123   // Further simplifications for the extra-special cases where the constants are
12124   // all 0 or all -1 should be implemented as folds of these patterns.
12125   SDLoc DL(N);
12126   if (AllAddOne || AllSubOne) {
12127     // vselect <N x i1> Cond, C+1, C --> add (zext Cond), C
12128     // vselect <N x i1> Cond, C-1, C --> add (sext Cond), C
12129     auto ExtendOpcode = AllAddOne ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND;
12130     SDValue ExtendedCond = DAG.getNode(ExtendOpcode, DL, VT, Cond);
12131     return DAG.getNode(ISD::ADD, DL, VT, ExtendedCond, N2);
12132   }
12133 
12134   // select Cond, Pow2C, 0 --> (zext Cond) << log2(Pow2C)
12135   APInt Pow2C;
12136   if (ISD::isConstantSplatVector(N1.getNode(), Pow2C) && Pow2C.isPowerOf2() &&
12137       isNullOrNullSplat(N2)) {
12138     SDValue ZextCond = DAG.getZExtOrTrunc(Cond, DL, VT);
12139     SDValue ShAmtC = DAG.getConstant(Pow2C.exactLogBase2(), DL, VT);
12140     return DAG.getNode(ISD::SHL, DL, VT, ZextCond, ShAmtC);
12141   }
12142 
12143   if (SDValue V = foldSelectOfConstantsUsingSra(N, DAG))
12144     return V;
12145 
12146   // The general case for select-of-constants:
12147   // vselect <N x i1> Cond, C1, C2 --> xor (and (sext Cond), (C1^C2)), C2
12148   // ...but that only makes sense if a vselect is slower than 2 logic ops, so
12149   // leave that to a machine-specific pass.
12150   return SDValue();
12151 }
12152 
12153 SDValue DAGCombiner::visitVSELECT(SDNode *N) {
12154   SDValue N0 = N->getOperand(0);
12155   SDValue N1 = N->getOperand(1);
12156   SDValue N2 = N->getOperand(2);
12157   EVT VT = N->getValueType(0);
12158   SDLoc DL(N);
12159 
12160   if (SDValue V = DAG.simplifySelect(N0, N1, N2))
12161     return V;
12162 
12163   if (SDValue V = foldBoolSelectToLogic(N, DAG))
12164     return V;
12165 
12166   // vselect (not Cond), N1, N2 -> vselect Cond, N2, N1
12167   if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false))
12168     return DAG.getSelect(DL, VT, F, N2, N1);
12169 
12170   // Canonicalize integer abs.
12171   // vselect (setg[te] X,  0),  X, -X ->
12172   // vselect (setgt    X, -1),  X, -X ->
12173   // vselect (setl[te] X,  0), -X,  X ->
12174   // Y = sra (X, size(X)-1); xor (add (X, Y), Y)
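  // Illustrative example (i32 lanes): for X = -5, Y = sra(X, 31) = -1,
  // add(X, Y) = -6, and xor(-6, -1) = 5 = |X|.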
12175   if (N0.getOpcode() == ISD::SETCC) {
12176     SDValue LHS = N0.getOperand(0), RHS = N0.getOperand(1);
12177     ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
12178     bool isAbs = false;
12179     bool RHSIsAllZeros = ISD::isBuildVectorAllZeros(RHS.getNode());
12180 
12181     if (((RHSIsAllZeros && (CC == ISD::SETGT || CC == ISD::SETGE)) ||
12182          (ISD::isBuildVectorAllOnes(RHS.getNode()) && CC == ISD::SETGT)) &&
12183         N1 == LHS && N2.getOpcode() == ISD::SUB && N1 == N2.getOperand(1))
12184       isAbs = ISD::isBuildVectorAllZeros(N2.getOperand(0).getNode());
12185     else if ((RHSIsAllZeros && (CC == ISD::SETLT || CC == ISD::SETLE)) &&
12186              N2 == LHS && N1.getOpcode() == ISD::SUB && N2 == N1.getOperand(1))
12187       isAbs = ISD::isBuildVectorAllZeros(N1.getOperand(0).getNode());
12188 
12189     if (isAbs) {
12190       if (TLI.isOperationLegalOrCustom(ISD::ABS, VT))
12191         return DAG.getNode(ISD::ABS, DL, VT, LHS);
12192 
12193       SDValue Shift = DAG.getNode(ISD::SRA, DL, VT, LHS,
12194                                   DAG.getConstant(VT.getScalarSizeInBits() - 1,
12195                                                   DL, getShiftAmountTy(VT)));
12196       SDValue Add = DAG.getNode(ISD::ADD, DL, VT, LHS, Shift);
12197       AddToWorklist(Shift.getNode());
12198       AddToWorklist(Add.getNode());
12199       return DAG.getNode(ISD::XOR, DL, VT, Add, Shift);
12200     }
12201 
12202     // vselect x, y (fcmp lt x, y) -> fminnum x, y
12203     // vselect x, y (fcmp gt x, y) -> fmaxnum x, y
12204     //
12205     // This is OK if we don't care about what happens if either operand is a
12206     // NaN.
12207     //
12208     if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, LHS, RHS, TLI)) {
12209       if (SDValue FMinMax = combineMinNumMaxNum(DL, VT, LHS, RHS, N1, N2, CC))
12210         return FMinMax;
12211     }
12212 
12213     if (SDValue S = PerformMinMaxFpToSatCombine(LHS, RHS, N1, N2, CC, DAG))
12214       return S;
12215     if (SDValue S = PerformUMinFpToSatCombine(LHS, RHS, N1, N2, CC, DAG))
12216       return S;
12217 
12218     // If this select has a condition (setcc) with narrower operands than the
12219     // select, try to widen the compare to match the select width.
12220     // TODO: This should be extended to handle any constant.
12221     // TODO: This could be extended to handle non-loading patterns, but that
12222     //       requires thorough testing to avoid regressions.
12223     if (isNullOrNullSplat(RHS)) {
12224       EVT NarrowVT = LHS.getValueType();
12225       EVT WideVT = N1.getValueType().changeVectorElementTypeToInteger();
12226       EVT SetCCVT = getSetCCResultType(LHS.getValueType());
12227       unsigned SetCCWidth = SetCCVT.getScalarSizeInBits();
12228       unsigned WideWidth = WideVT.getScalarSizeInBits();
12229       bool IsSigned = isSignedIntSetCC(CC);
12230       auto LoadExtOpcode = IsSigned ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
12231       if (LHS.getOpcode() == ISD::LOAD && LHS.hasOneUse() &&
12232           SetCCWidth != 1 && SetCCWidth < WideWidth &&
12233           TLI.isLoadExtLegalOrCustom(LoadExtOpcode, WideVT, NarrowVT) &&
12234           TLI.isOperationLegalOrCustom(ISD::SETCC, WideVT)) {
12235         // Both compare operands can be widened for free. The LHS can use an
12236         // extended load, and the RHS is a constant:
12237         //   vselect (ext (setcc load(X), C)), N1, N2 -->
12238         //   vselect (setcc extload(X), C'), N1, N2
12239         auto ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
12240         SDValue WideLHS = DAG.getNode(ExtOpcode, DL, WideVT, LHS);
12241         SDValue WideRHS = DAG.getNode(ExtOpcode, DL, WideVT, RHS);
12242         EVT WideSetCCVT = getSetCCResultType(WideVT);
12243         SDValue WideSetCC = DAG.getSetCC(DL, WideSetCCVT, WideLHS, WideRHS, CC);
12244         return DAG.getSelect(DL, N1.getValueType(), WideSetCC, N1, N2);
12245       }
12246     }
12247 
12248     // Match VSELECTs with absolute difference patterns.
12249     // (vselect (setcc a, b, set?gt), (sub a, b), (sub b, a)) --> (abd? a, b)
12250     // (vselect (setcc a, b, set?ge), (sub a, b), (sub b, a)) --> (abd? a, b)
12251     // (vselect (setcc a, b, set?lt), (sub b, a), (sub a, b)) --> (abd? a, b)
12252     // (vselect (setcc a, b, set?le), (sub b, a), (sub a, b)) --> (abd? a, b)
12253     if (N1.getOpcode() == ISD::SUB && N2.getOpcode() == ISD::SUB &&
12254         N1.getOperand(0) == N2.getOperand(1) &&
12255         N1.getOperand(1) == N2.getOperand(0)) {
12256       bool IsSigned = isSignedIntSetCC(CC);
12257       unsigned ABDOpc = IsSigned ? ISD::ABDS : ISD::ABDU;
12258       if (hasOperation(ABDOpc, VT)) {
12259         switch (CC) {
12260         case ISD::SETGT:
12261         case ISD::SETGE:
12262         case ISD::SETUGT:
12263         case ISD::SETUGE:
12264           if (LHS == N1.getOperand(0) && RHS == N1.getOperand(1))
12265             return DAG.getNode(ABDOpc, DL, VT, LHS, RHS);
12266           break;
12267         case ISD::SETLT:
12268         case ISD::SETLE:
12269         case ISD::SETULT:
12270         case ISD::SETULE:
12271           if (RHS == N1.getOperand(0) && LHS == N1.getOperand(1))
12272             return DAG.getNode(ABDOpc, DL, VT, LHS, RHS);
12273           break;
12274         default:
12275           break;
12276         }
12277       }
12278     }
12279 
12280     // Match VSELECTs into add with unsigned saturation.
12281     if (hasOperation(ISD::UADDSAT, VT)) {
12282       // Check if one of the arms of the VSELECT is vector with all bits set.
12283       // If it's on the left side invert the predicate to simplify logic below.
12284       SDValue Other;
12285       ISD::CondCode SatCC = CC;
12286       if (ISD::isConstantSplatVectorAllOnes(N1.getNode())) {
12287         Other = N2;
12288         SatCC = ISD::getSetCCInverse(SatCC, VT.getScalarType());
12289       } else if (ISD::isConstantSplatVectorAllOnes(N2.getNode())) {
12290         Other = N1;
12291       }
12292 
12293       if (Other && Other.getOpcode() == ISD::ADD) {
12294         SDValue CondLHS = LHS, CondRHS = RHS;
12295         SDValue OpLHS = Other.getOperand(0), OpRHS = Other.getOperand(1);
12296 
12297         // Canonicalize condition operands.
12298         if (SatCC == ISD::SETUGE) {
12299           std::swap(CondLHS, CondRHS);
12300           SatCC = ISD::SETULE;
12301         }
12302 
12303         // We can test against either of the addition operands.
12304         // x <= x+y ? x+y : ~0 --> uaddsat x, y
12305         // x+y >= x ? x+y : ~0 --> uaddsat x, y
12306         if (SatCC == ISD::SETULE && Other == CondRHS &&
12307             (OpLHS == CondLHS || OpRHS == CondLHS))
12308           return DAG.getNode(ISD::UADDSAT, DL, VT, OpLHS, OpRHS);
12309 
12310         if (OpRHS.getOpcode() == CondRHS.getOpcode() &&
12311             (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
12312              OpRHS.getOpcode() == ISD::SPLAT_VECTOR) &&
12313             CondLHS == OpLHS) {
12314           // If the RHS is a constant we have to reverse the const
12315           // canonicalization.
12316           // x >= ~C ? x+C : ~0 --> uaddsat x, C
12317           auto MatchUADDSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
12318             return Cond->getAPIntValue() == ~Op->getAPIntValue();
12319           };
12320           if (SatCC == ISD::SETULE &&
12321               ISD::matchBinaryPredicate(OpRHS, CondRHS, MatchUADDSAT))
12322             return DAG.getNode(ISD::UADDSAT, DL, VT, OpLHS, OpRHS);
12323         }
12324       }
12325     }
12326 
12327     // Match VSELECTs into sub with unsigned saturation.
12328     if (hasOperation(ISD::USUBSAT, VT)) {
12329       // Check if one of the arms of the VSELECT is a zero vector. If it's on
12330       // the left side invert the predicate to simplify logic below.
12331       SDValue Other;
12332       ISD::CondCode SatCC = CC;
12333       if (ISD::isConstantSplatVectorAllZeros(N1.getNode())) {
12334         Other = N2;
12335         SatCC = ISD::getSetCCInverse(SatCC, VT.getScalarType());
12336       } else if (ISD::isConstantSplatVectorAllZeros(N2.getNode())) {
12337         Other = N1;
12338       }
12339 
12340       // zext(x) >= y ? trunc(zext(x) - y) : 0
12341       // --> usubsat(trunc(zext(x)),trunc(umin(y,SatLimit)))
12342       // zext(x) >  y ? trunc(zext(x) - y) : 0
12343       // --> usubsat(trunc(zext(x)),trunc(umin(y,SatLimit)))
12344       if (Other && Other.getOpcode() == ISD::TRUNCATE &&
12345           Other.getOperand(0).getOpcode() == ISD::SUB &&
12346           (SatCC == ISD::SETUGE || SatCC == ISD::SETUGT)) {
12347         SDValue OpLHS = Other.getOperand(0).getOperand(0);
12348         SDValue OpRHS = Other.getOperand(0).getOperand(1);
12349         if (LHS == OpLHS && RHS == OpRHS && LHS.getOpcode() == ISD::ZERO_EXTEND)
12350           if (SDValue R = getTruncatedUSUBSAT(VT, LHS.getValueType(), LHS, RHS,
12351                                               DAG, DL))
12352             return R;
12353       }
12354 
12355       if (Other && Other.getNumOperands() == 2) {
12356         SDValue CondRHS = RHS;
12357         SDValue OpLHS = Other.getOperand(0), OpRHS = Other.getOperand(1);
12358 
12359         if (OpLHS == LHS) {
12360           // Look for a general sub with unsigned saturation first.
12361           // x >= y ? x-y : 0 --> usubsat x, y
12362           // x >  y ? x-y : 0 --> usubsat x, y
12363           if ((SatCC == ISD::SETUGE || SatCC == ISD::SETUGT) &&
12364               Other.getOpcode() == ISD::SUB && OpRHS == CondRHS)
12365             return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
12366 
12367           if (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
12368               OpRHS.getOpcode() == ISD::SPLAT_VECTOR) {
12369             if (CondRHS.getOpcode() == ISD::BUILD_VECTOR ||
12370                 CondRHS.getOpcode() == ISD::SPLAT_VECTOR) {
12371               // If the RHS is a constant we have to reverse the const
12372               // canonicalization.
12373               // x > C-1 ? x+-C : 0 --> usubsat x, C
12374               auto MatchUSUBSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
12375                 return (!Op && !Cond) ||
12376                        (Op && Cond &&
12377                         Cond->getAPIntValue() == (-Op->getAPIntValue() - 1));
12378               };
12379               if (SatCC == ISD::SETUGT && Other.getOpcode() == ISD::ADD &&
12380                   ISD::matchBinaryPredicate(OpRHS, CondRHS, MatchUSUBSAT,
12381                                             /*AllowUndefs*/ true)) {
12382                 OpRHS = DAG.getNegative(OpRHS, DL, VT);
12383                 return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
12384               }
12385 
12386               // Another special case: If C was a sign bit, the sub has been
12387               // canonicalized into a xor.
12388               // FIXME: Would it be better to use computeKnownBits to
12389               // determine whether it's safe to decanonicalize the xor?
12390               // x s< 0 ? x^C : 0 --> usubsat x, C
12391               APInt SplatValue;
12392               if (SatCC == ISD::SETLT && Other.getOpcode() == ISD::XOR &&
12393                   ISD::isConstantSplatVector(OpRHS.getNode(), SplatValue) &&
12394                   ISD::isConstantSplatVectorAllZeros(CondRHS.getNode()) &&
12395                   SplatValue.isSignMask()) {
12396                 // Note that we have to rebuild the RHS constant here to
12397                 // ensure we don't rely on particular values of undef lanes.
12398                 OpRHS = DAG.getConstant(SplatValue, DL, VT);
12399                 return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
12400               }
12401             }
12402           }
12403         }
12404       }
12405     }
12406   }
12407 
12408   if (SimplifySelectOps(N, N1, N2))
12409     return SDValue(N, 0);  // Don't revisit N.
12410 
12411   // Fold (vselect all_ones, N1, N2) -> N1
12412   if (ISD::isConstantSplatVectorAllOnes(N0.getNode()))
12413     return N1;
12414   // Fold (vselect all_zeros, N1, N2) -> N2
12415   if (ISD::isConstantSplatVectorAllZeros(N0.getNode()))
12416     return N2;
12417 
12418   // The ConvertSelectToConcatVector function assumes both the above
12419   // checks for (vselect (build_vector all{ones,zeros}) ...) have been made
12420   // and addressed.
12421   if (N1.getOpcode() == ISD::CONCAT_VECTORS &&
12422       N2.getOpcode() == ISD::CONCAT_VECTORS &&
12423       ISD::isBuildVectorOfConstantSDNodes(N0.getNode())) {
12424     if (SDValue CV = ConvertSelectToConcatVector(N, DAG))
12425       return CV;
12426   }
12427 
12428   if (SDValue V = foldVSelectOfConstants(N))
12429     return V;
12430 
12431   if (hasOperation(ISD::SRA, VT))
12432     if (SDValue V = foldVSelectToSignBitSplatMask(N, DAG))
12433       return V;
12434 
12435   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
12436     return SDValue(N, 0);
12437 
12438   return SDValue();
12439 }
12440 
12441 SDValue DAGCombiner::visitSELECT_CC(SDNode *N) {
12442   SDValue N0 = N->getOperand(0);
12443   SDValue N1 = N->getOperand(1);
12444   SDValue N2 = N->getOperand(2);
12445   SDValue N3 = N->getOperand(3);
12446   SDValue N4 = N->getOperand(4);
12447   ISD::CondCode CC = cast<CondCodeSDNode>(N4)->get();
12448 
12449   // fold select_cc lhs, rhs, x, x, cc -> x
12450   if (N2 == N3)
12451     return N2;
12452 
12453   // select_cc bool, 0, x, y, seteq -> select bool, y, x
12454   if (CC == ISD::SETEQ && !LegalTypes && N0.getValueType() == MVT::i1 &&
12455       isNullConstant(N1))
12456     return DAG.getSelect(SDLoc(N), N2.getValueType(), N0, N3, N2);
12457 
12458   // Determine if the condition we're dealing with is constant
12459   if (SDValue SCC = SimplifySetCC(getSetCCResultType(N0.getValueType()), N0, N1,
12460                                   CC, SDLoc(N), false)) {
12461     AddToWorklist(SCC.getNode());
12462 
12463     // cond always true -> true val
12464     // cond always false -> false val
12465     if (auto *SCCC = dyn_cast<ConstantSDNode>(SCC.getNode()))
12466       return SCCC->isZero() ? N3 : N2;
12467 
12468     // When the condition is UNDEF, just return the first operand. This is
12469     // consistent with DAG creation: no setcc node is created in this case.
12470     if (SCC->isUndef())
12471       return N2;
12472 
12473     // Fold to a simpler select_cc
12474     if (SCC.getOpcode() == ISD::SETCC) {
12475       SDValue SelectOp = DAG.getNode(
12476           ISD::SELECT_CC, SDLoc(N), N2.getValueType(), SCC.getOperand(0),
12477           SCC.getOperand(1), N2, N3, SCC.getOperand(2));
12478       SelectOp->setFlags(SCC->getFlags());
12479       return SelectOp;
12480     }
12481   }
12482 
12483   // If we can fold this based on the true/false value, do so.
12484   if (SimplifySelectOps(N, N2, N3))
12485     return SDValue(N, 0);  // Don't revisit N.
12486 
12487   // fold select_cc into other things, such as min/max/abs
12488   return SimplifySelectCC(SDLoc(N), N0, N1, N2, N3, CC);
12489 }
12490 
12491 SDValue DAGCombiner::visitSETCC(SDNode *N) {
12492   // setcc is very commonly used as an argument to brcond. This pattern
12493   // also lends itself to numerous combines and, as a result, it is desirable
12494   // to keep the argument to a brcond as a setcc as much as possible.
12495   bool PreferSetCC =
12496       N->hasOneUse() && N->use_begin()->getOpcode() == ISD::BRCOND;
12497 
12498   ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get();
12499   EVT VT = N->getValueType(0);
12500   SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
12501 
12502   SDValue Combined = SimplifySetCC(VT, N0, N1, Cond, SDLoc(N), !PreferSetCC);
12503 
12504   if (Combined) {
12505     // If we prefer to have a setcc, and we don't, we'll try our best to
12506     // recreate one using rebuildSetCC.
12507     if (PreferSetCC && Combined.getOpcode() != ISD::SETCC) {
12508       SDValue NewSetCC = rebuildSetCC(Combined);
12509 
12510       // We don't have anything interesting to combine to.
12511       if (NewSetCC.getNode() == N)
12512         return SDValue();
12513 
12514       if (NewSetCC)
12515         return NewSetCC;
12516     }
12517     return Combined;
12518   }
12519 
12520   // Optimize
12521   //    1) (icmp eq/ne (and X, C0), (shift X, C1))
12522   // or
12523   //    2) (icmp eq/ne X, (rotate X, C1))
12524   // If C0 is a mask or shifted mask and the shift amt (C1) isolates the
12525   // remaining bits (i.e. something like `(x64 & UINT32_MAX) == (x64 >> 32)`),
12526   // then:
12527   // If C1 is a power of 2, the rotate and shift+and versions are
12528   // equivalent, so we can interchange them depending on target preference.
12529   // Otherwise, if we have the shift+and version we can interchange srl/shl,
12530   // which in turn affects the constant C0. We can use this to get better
12531   // constants, again determined by target preference.
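  // Illustrative example: `(x64 & UINT32_MAX) == (x64 >> 32)` holds exactly
  // when x64 == rot(x64, 32), so the shift+and and rotate forms can be
  // exchanged when the shift amount is a power of 2.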
12532   if (Cond == ISD::SETNE || Cond == ISD::SETEQ) {
12533     auto IsAndWithShift = [](SDValue A, SDValue B) {
12534       return A.getOpcode() == ISD::AND &&
12535              (B.getOpcode() == ISD::SRL || B.getOpcode() == ISD::SHL) &&
12536              A.getOperand(0) == B.getOperand(0);
12537     };
12538     auto IsRotateWithOp = [](SDValue A, SDValue B) {
12539       return (B.getOpcode() == ISD::ROTL || B.getOpcode() == ISD::ROTR) &&
12540              B.getOperand(0) == A;
12541     };
12542     SDValue AndOrOp = SDValue(), ShiftOrRotate = SDValue();
12543     bool IsRotate = false;
12544 
12545     // Find either shift+and or rotate pattern.
12546     if (IsAndWithShift(N0, N1)) {
12547       AndOrOp = N0;
12548       ShiftOrRotate = N1;
12549     } else if (IsAndWithShift(N1, N0)) {
12550       AndOrOp = N1;
12551       ShiftOrRotate = N0;
12552     } else if (IsRotateWithOp(N0, N1)) {
12553       IsRotate = true;
12554       AndOrOp = N0;
12555       ShiftOrRotate = N1;
12556     } else if (IsRotateWithOp(N1, N0)) {
12557       IsRotate = true;
12558       AndOrOp = N1;
12559       ShiftOrRotate = N0;
12560     }
12561 
12562     if (AndOrOp && ShiftOrRotate && ShiftOrRotate.hasOneUse() &&
12563         (IsRotate || AndOrOp.hasOneUse())) {
12564       EVT OpVT = N0.getValueType();
12565       // Get the constant shift/rotate amount and possibly the mask (if it's
12566       // the shift+and variant).
12567       auto GetAPIntValue = [](SDValue Op) -> std::optional<APInt> {
12568         ConstantSDNode *CNode = isConstOrConstSplat(Op, /*AllowUndefs*/ false,
12569                                                     /*AllowTrunc*/ false);
12570         if (CNode == nullptr)
12571           return std::nullopt;
12572         return CNode->getAPIntValue();
12573       };
12574       std::optional<APInt> AndCMask =
12575           IsRotate ? std::nullopt : GetAPIntValue(AndOrOp.getOperand(1));
12576       std::optional<APInt> ShiftCAmt =
12577           GetAPIntValue(ShiftOrRotate.getOperand(1));
12578       unsigned NumBits = OpVT.getScalarSizeInBits();
12579 
12580       // We found constants.
12581       if (ShiftCAmt && (IsRotate || AndCMask) && ShiftCAmt->ult(NumBits)) {
12582         unsigned ShiftOpc = ShiftOrRotate.getOpcode();
12583         // Check that the constants meet the constraints.
12584         bool CanTransform = IsRotate;
12585         if (!CanTransform) {
12586           // Check that the mask and shift complement each other.
12587           CanTransform = *ShiftCAmt == (~*AndCMask).popcount();
12588           // Check that we are comparing all bits.
12589           CanTransform &= (*ShiftCAmt + AndCMask->popcount()) == NumBits;
12590           // Check that the and mask is correct for the shift.
12591           CanTransform &=
12592               ShiftOpc == ISD::SHL ? (~*AndCMask).isMask() : AndCMask->isMask();
12593         }
12594 
12595         // See if target prefers another shift/rotate opcode.
12596         unsigned NewShiftOpc = TLI.preferedOpcodeForCmpEqPiecesOfOperand(
12597             OpVT, ShiftOpc, ShiftCAmt->isPowerOf2(), *ShiftCAmt, AndCMask);
12598         // Transform is valid and we have a new preference.
12599         if (CanTransform && NewShiftOpc != ShiftOpc) {
12600           SDLoc DL(N);
12601           SDValue NewShiftOrRotate =
12602               DAG.getNode(NewShiftOpc, DL, OpVT, ShiftOrRotate.getOperand(0),
12603                           ShiftOrRotate.getOperand(1));
12604           SDValue NewAndOrOp = SDValue();
12605 
12606           if (NewShiftOpc == ISD::SHL || NewShiftOpc == ISD::SRL) {
12607             APInt NewMask =
12608                 NewShiftOpc == ISD::SHL
12609                     ? APInt::getHighBitsSet(NumBits,
12610                                             NumBits - ShiftCAmt->getZExtValue())
12611                     : APInt::getLowBitsSet(NumBits,
12612                                            NumBits - ShiftCAmt->getZExtValue());
12613             NewAndOrOp =
12614                 DAG.getNode(ISD::AND, DL, OpVT, ShiftOrRotate.getOperand(0),
12615                             DAG.getConstant(NewMask, DL, OpVT));
12616           } else {
12617             NewAndOrOp = ShiftOrRotate.getOperand(0);
12618           }
12619 
12620           return DAG.getSetCC(DL, VT, NewAndOrOp, NewShiftOrRotate, Cond);
12621         }
12622       }
12623     }
12624   }
12625   return SDValue();
12626 }
12627 
12628 SDValue DAGCombiner::visitSETCCCARRY(SDNode *N) {
12629   SDValue LHS = N->getOperand(0);
12630   SDValue RHS = N->getOperand(1);
12631   SDValue Carry = N->getOperand(2);
12632   SDValue Cond = N->getOperand(3);
12633 
12634   // If Carry is false, fold to a regular SETCC.
12635   if (isNullConstant(Carry))
12636     return DAG.getNode(ISD::SETCC, SDLoc(N), N->getVTList(), LHS, RHS, Cond);
12637 
12638   return SDValue();
12639 }
12640 
12641 /// Check if N satisfies:
12642 ///   N is used once.
12643 ///   N is a Load.
12644 ///   The load is compatible with ExtOpcode. This means:
12645 ///     If the load has an explicit zero/sign extension, ExtOpcode must have
12646 ///     the same extension.
12647 ///     Otherwise the function returns true.
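/// Illustrative example: a SEXTLOAD is compatible with SIGN_EXTEND but not
/// with ZERO_EXTEND, while a NON_EXTLOAD or EXTLOAD is compatible with any
/// extension opcode.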
12648 static bool isCompatibleLoad(SDValue N, unsigned ExtOpcode) {
12649   if (!N.hasOneUse())
12650     return false;
12651 
12652   if (!isa<LoadSDNode>(N))
12653     return false;
12654 
12655   LoadSDNode *Load = cast<LoadSDNode>(N);
12656   ISD::LoadExtType LoadExt = Load->getExtensionType();
12657   if (LoadExt == ISD::NON_EXTLOAD || LoadExt == ISD::EXTLOAD)
12658     return true;
12659 
12660   // Now LoadExt is either SEXTLOAD or ZEXTLOAD, ExtOpcode must have the same
12661   // extension.
12662   if ((LoadExt == ISD::SEXTLOAD && ExtOpcode != ISD::SIGN_EXTEND) ||
12663       (LoadExt == ISD::ZEXTLOAD && ExtOpcode != ISD::ZERO_EXTEND))
12664     return false;
12665 
12666   return true;
12667 }
12668 
12669 /// Fold
12670 ///   (sext (select c, load x, load y)) -> (select c, sextload x, sextload y)
12671 ///   (zext (select c, load x, load y)) -> (select c, zextload x, zextload y)
12672 ///   (aext (select c, load x, load y)) -> (select c, extload x, extload y)
12673 /// This function is called by the DAGCombiner when visiting sext/zext/aext
12674 /// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
12675 static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI,
12676                                          SelectionDAG &DAG,
12677                                          CombineLevel Level) {
12678   unsigned Opcode = N->getOpcode();
12679   SDValue N0 = N->getOperand(0);
12680   EVT VT = N->getValueType(0);
12681   SDLoc DL(N);
12682 
12683   assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
12684           Opcode == ISD::ANY_EXTEND) &&
12685          "Expected EXTEND dag node in input!");
12686 
12687   if (!(N0->getOpcode() == ISD::SELECT || N0->getOpcode() == ISD::VSELECT) ||
12688       !N0.hasOneUse())
12689     return SDValue();
12690 
12691   SDValue Op1 = N0->getOperand(1);
12692   SDValue Op2 = N0->getOperand(2);
12693   if (!isCompatibleLoad(Op1, Opcode) || !isCompatibleLoad(Op2, Opcode))
12694     return SDValue();
12695 
12696   auto ExtLoadOpcode = ISD::EXTLOAD;
12697   if (Opcode == ISD::SIGN_EXTEND)
12698     ExtLoadOpcode = ISD::SEXTLOAD;
12699   else if (Opcode == ISD::ZERO_EXTEND)
12700     ExtLoadOpcode = ISD::ZEXTLOAD;
12701 
12702   // An illegal VSELECT may make ISel fail if it appears after legalization
12703   // (DAGCombine2), so we should conservatively check the OperationAction.
12704   LoadSDNode *Load1 = cast<LoadSDNode>(Op1);
12705   LoadSDNode *Load2 = cast<LoadSDNode>(Op2);
12706   if (!TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load1->getMemoryVT()) ||
12707       !TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load2->getMemoryVT()) ||
12708       (N0->getOpcode() == ISD::VSELECT && Level >= AfterLegalizeTypes &&
12709        TLI.getOperationAction(ISD::VSELECT, VT) != TargetLowering::Legal))
12710     return SDValue();
12711 
12712   SDValue Ext1 = DAG.getNode(Opcode, DL, VT, Op1);
12713   SDValue Ext2 = DAG.getNode(Opcode, DL, VT, Op2);
12714   return DAG.getSelect(DL, VT, N0->getOperand(0), Ext1, Ext2);
12715 }
12716 
12717 /// Try to fold a sext/zext/aext dag node into a ConstantSDNode or
12718 /// a build_vector of constants.
12719 /// This function is called by the DAGCombiner when visiting sext/zext/aext
12720 /// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
12721 /// Vector extends are not folded if operations are legal; this is to
12722 /// avoid introducing illegal build_vector dag nodes.
12723 static SDValue tryToFoldExtendOfConstant(SDNode *N, const TargetLowering &TLI,
12724                                          SelectionDAG &DAG, bool LegalTypes) {
12725   unsigned Opcode = N->getOpcode();
12726   SDValue N0 = N->getOperand(0);
12727   EVT VT = N->getValueType(0);
12728   SDLoc DL(N);
12729 
12730   assert((ISD::isExtOpcode(Opcode) || ISD::isExtVecInRegOpcode(Opcode)) &&
12731          "Expected EXTEND dag node in input!");
12732 
12733   // fold (sext c1) -> c1
12734   // fold (zext c1) -> c1
12735   // fold (aext c1) -> c1
12736   if (isa<ConstantSDNode>(N0))
12737     return DAG.getNode(Opcode, DL, VT, N0);
12738 
12739   // fold (sext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
12740   // fold (zext (select cond, c1, c2)) -> (select cond, zext c1, zext c2)
12741   // fold (aext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
12742   if (N0->getOpcode() == ISD::SELECT) {
12743     SDValue Op1 = N0->getOperand(1);
12744     SDValue Op2 = N0->getOperand(2);
12745     if (isa<ConstantSDNode>(Op1) && isa<ConstantSDNode>(Op2) &&
12746         (Opcode != ISD::ZERO_EXTEND || !TLI.isZExtFree(N0.getValueType(), VT))) {
12747       // For any_extend, choose sign extension of the constants to allow a
12748       // possible further transform to sign_extend_inreg, i.e.:
12749       //
12750       // t1: i8 = select t0, Constant:i8<-1>, Constant:i8<0>
12751       // t2: i64 = any_extend t1
12752       // -->
12753       // t3: i64 = select t0, Constant:i64<-1>, Constant:i64<0>
12754       // -->
12755       // t4: i64 = sign_extend_inreg t3
12756       unsigned FoldOpc = Opcode;
12757       if (FoldOpc == ISD::ANY_EXTEND)
12758         FoldOpc = ISD::SIGN_EXTEND;
12759       return DAG.getSelect(DL, VT, N0->getOperand(0),
12760                            DAG.getNode(FoldOpc, DL, VT, Op1),
12761                            DAG.getNode(FoldOpc, DL, VT, Op2));
12762     }
12763   }
12764 
12765   // fold (sext (build_vector AllConstants)) -> (build_vector AllConstants)
12766   // fold (zext (build_vector AllConstants)) -> (build_vector AllConstants)
12767   // fold (aext (build_vector AllConstants)) -> (build_vector AllConstants)
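  // For illustration (element types are only an example):
  //   (v4i32 (zext (v4i16 (build_vector 1, 2, undef, 4))))
  //     --> (v4i32 (build_vector 1, 2, 0, 4))
  // Note that undef elements become 0 for sext/zext but stay undef for aext.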
12768   EVT SVT = VT.getScalarType();
12769   if (!(VT.isVector() && (!LegalTypes || TLI.isTypeLegal(SVT)) &&
12770       ISD::isBuildVectorOfConstantSDNodes(N0.getNode())))
12771     return SDValue();
12772 
12773   // We can fold this node into a build_vector.
12774   unsigned VTBits = SVT.getSizeInBits();
12775   unsigned EVTBits = N0->getValueType(0).getScalarSizeInBits();
12776   SmallVector<SDValue, 8> Elts;
12777   unsigned NumElts = VT.getVectorNumElements();
12778 
12779   for (unsigned i = 0; i != NumElts; ++i) {
12780     SDValue Op = N0.getOperand(i);
12781     if (Op.isUndef()) {
12782       if (Opcode == ISD::ANY_EXTEND || Opcode == ISD::ANY_EXTEND_VECTOR_INREG)
12783         Elts.push_back(DAG.getUNDEF(SVT));
12784       else
12785         Elts.push_back(DAG.getConstant(0, DL, SVT));
12786       continue;
12787     }
12788 
12789     SDLoc DL(Op);
12790     // Get the constant value and if needed trunc it to the size of the type.
12791     // Nodes like build_vector might have constants wider than the scalar type.
12792     APInt C = Op->getAsAPIntVal().zextOrTrunc(EVTBits);
12793     if (Opcode == ISD::SIGN_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG)
12794       Elts.push_back(DAG.getConstant(C.sext(VTBits), DL, SVT));
12795     else
12796       Elts.push_back(DAG.getConstant(C.zext(VTBits), DL, SVT));
12797   }
12798 
12799   return DAG.getBuildVector(VT, DL, Elts);
12800 }
12801 
12802 // ExtendUsesToFormExtLoad - Try to extend the uses of a load to enable this:
12803 // "fold ({s|z|a}ext (load x)) -> ({s|z|a}ext (truncate ({s|z|a}extload x)))"
12804 // transformation. Returns true if the extensions are possible and the
12805 // transformation is profitable.
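// For example, a user such as (setcc (load x), 0, seteq) does not block the
// transformation: it is recorded in ExtendNodes and later rewritten to compare
// the extended load against the extended constant (see ExtendSetCCUses).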
12806 static bool ExtendUsesToFormExtLoad(EVT VT, SDNode *N, SDValue N0,
12807                                     unsigned ExtOpc,
12808                                     SmallVectorImpl<SDNode *> &ExtendNodes,
12809                                     const TargetLowering &TLI) {
12810   bool HasCopyToRegUses = false;
12811   bool isTruncFree = TLI.isTruncateFree(VT, N0.getValueType());
12812   for (SDNode::use_iterator UI = N0->use_begin(), UE = N0->use_end(); UI != UE;
12813        ++UI) {
12814     SDNode *User = *UI;
12815     if (User == N)
12816       continue;
12817     if (UI.getUse().getResNo() != N0.getResNo())
12818       continue;
12819     // FIXME: Only extend SETCC N, N and SETCC N, c for now.
12820     if (ExtOpc != ISD::ANY_EXTEND && User->getOpcode() == ISD::SETCC) {
12821       ISD::CondCode CC = cast<CondCodeSDNode>(User->getOperand(2))->get();
12822       if (ExtOpc == ISD::ZERO_EXTEND && ISD::isSignedIntSetCC(CC))
12823         // Sign bits will be lost after a zext.
12824         return false;
12825       bool Add = false;
12826       for (unsigned i = 0; i != 2; ++i) {
12827         SDValue UseOp = User->getOperand(i);
12828         if (UseOp == N0)
12829           continue;
12830         if (!isa<ConstantSDNode>(UseOp))
12831           return false;
12832         Add = true;
12833       }
12834       if (Add)
12835         ExtendNodes.push_back(User);
12836       continue;
12837     }
12838     // If truncates aren't free and there are users we can't
12839     // extend, it isn't worthwhile.
12840     if (!isTruncFree)
12841       return false;
12842     // Remember if this value is live-out.
12843     if (User->getOpcode() == ISD::CopyToReg)
12844       HasCopyToRegUses = true;
12845   }
12846 
12847   if (HasCopyToRegUses) {
12848     bool BothLiveOut = false;
12849     for (SDNode::use_iterator UI = N->use_begin(), UE = N->use_end();
12850          UI != UE; ++UI) {
12851       SDUse &Use = UI.getUse();
12852       if (Use.getResNo() == 0 && Use.getUser()->getOpcode() == ISD::CopyToReg) {
12853         BothLiveOut = true;
12854         break;
12855       }
12856     }
12857     if (BothLiveOut)
12858       // Both unextended and extended values are live out. There had better be
12859       // a good reason for the transformation.
12860       return !ExtendNodes.empty();
12861   }
12862   return true;
12863 }
12864 
12865 void DAGCombiner::ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
12866                                   SDValue OrigLoad, SDValue ExtLoad,
12867                                   ISD::NodeType ExtType) {
12868   // Extend SetCC uses if necessary.
12869   SDLoc DL(ExtLoad);
12870   for (SDNode *SetCC : SetCCs) {
12871     SmallVector<SDValue, 4> Ops;
12872 
12873     for (unsigned j = 0; j != 2; ++j) {
12874       SDValue SOp = SetCC->getOperand(j);
12875       if (SOp == OrigLoad)
12876         Ops.push_back(ExtLoad);
12877       else
12878         Ops.push_back(DAG.getNode(ExtType, DL, ExtLoad->getValueType(0), SOp));
12879     }
12880 
12881     Ops.push_back(SetCC->getOperand(2));
12882     CombineTo(SetCC, DAG.getNode(ISD::SETCC, DL, SetCC->getValueType(0), Ops));
12883   }
12884 }
12885 
12886 // FIXME: Bring more similar combines here, common to sext/zext (maybe aext?).
12887 SDValue DAGCombiner::CombineExtLoad(SDNode *N) {
12888   SDValue N0 = N->getOperand(0);
12889   EVT DstVT = N->getValueType(0);
12890   EVT SrcVT = N0.getValueType();
12891 
12892   assert((N->getOpcode() == ISD::SIGN_EXTEND ||
12893           N->getOpcode() == ISD::ZERO_EXTEND) &&
12894          "Unexpected node type (not an extend)!");
12895 
12896   // fold (sext (load x)) to multiple smaller sextloads; same for zext.
12897   // For example, on a target with legal v4i32, but illegal v8i32, turn:
12898   //   (v8i32 (sext (v8i16 (load x))))
12899   // into:
12900   //   (v8i32 (concat_vectors (v4i32 (sextload x)),
12901   //                          (v4i32 (sextload (x + 16)))))
12902   // Where uses of the original load, i.e.:
12903   //   (v8i16 (load x))
12904   // are replaced with:
12905   //   (v8i16 (truncate
12906   //     (v8i32 (concat_vectors (v4i32 (sextload x)),
12907   //                            (v4i32 (sextload (x + 16)))))))
12908   //
12909   // This combine is only applicable to illegal, but splittable, vectors.
12910   // All legal types, and illegal non-vector types, are handled elsewhere.
12911   // This combine is controlled by TargetLowering::isVectorLoadExtDesirable.
12912   //
12913   if (N0->getOpcode() != ISD::LOAD)
12914     return SDValue();
12915 
12916   LoadSDNode *LN0 = cast<LoadSDNode>(N0);
12917 
12918   if (!ISD::isNON_EXTLoad(LN0) || !ISD::isUNINDEXEDLoad(LN0) ||
12919       !N0.hasOneUse() || !LN0->isSimple() ||
12920       !DstVT.isVector() || !DstVT.isPow2VectorType() ||
12921       !TLI.isVectorLoadExtDesirable(SDValue(N, 0)))
12922     return SDValue();
12923 
12924   SmallVector<SDNode *, 4> SetCCs;
12925   if (!ExtendUsesToFormExtLoad(DstVT, N, N0, N->getOpcode(), SetCCs, TLI))
12926     return SDValue();
12927 
12928   ISD::LoadExtType ExtType =
12929       N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
12930 
12931   // Try to split the vector types to get down to legal types.
12932   EVT SplitSrcVT = SrcVT;
12933   EVT SplitDstVT = DstVT;
12934   while (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT) &&
12935          SplitSrcVT.getVectorNumElements() > 1) {
12936     SplitDstVT = DAG.GetSplitDestVTs(SplitDstVT).first;
12937     SplitSrcVT = DAG.GetSplitDestVTs(SplitSrcVT).first;
12938   }
12939 
12940   if (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT))
12941     return SDValue();
12942 
12943   assert(!DstVT.isScalableVector() && "Unexpected scalable vector type");
12944 
12945   SDLoc DL(N);
12946   const unsigned NumSplits =
12947       DstVT.getVectorNumElements() / SplitDstVT.getVectorNumElements();
12948   const unsigned Stride = SplitSrcVT.getStoreSize();
12949   SmallVector<SDValue, 4> Loads;
12950   SmallVector<SDValue, 4> Chains;
12951 
12952   SDValue BasePtr = LN0->getBasePtr();
12953   for (unsigned Idx = 0; Idx < NumSplits; Idx++) {
12954     const unsigned Offset = Idx * Stride;
12955     const Align Align = commonAlignment(LN0->getAlign(), Offset);
12956 
12957     SDValue SplitLoad = DAG.getExtLoad(
12958         ExtType, SDLoc(LN0), SplitDstVT, LN0->getChain(), BasePtr,
12959         LN0->getPointerInfo().getWithOffset(Offset), SplitSrcVT, Align,
12960         LN0->getMemOperand()->getFlags(), LN0->getAAInfo());
12961 
12962     BasePtr = DAG.getMemBasePlusOffset(BasePtr, TypeSize::getFixed(Stride), DL);
12963 
12964     Loads.push_back(SplitLoad.getValue(0));
12965     Chains.push_back(SplitLoad.getValue(1));
12966   }
12967 
12968   SDValue NewChain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Chains);
12969   SDValue NewValue = DAG.getNode(ISD::CONCAT_VECTORS, DL, DstVT, Loads);
12970 
12971   // Simplify TF.
12972   AddToWorklist(NewChain.getNode());
12973 
12974   CombineTo(N, NewValue);
12975 
12976   // Replace uses of the original load (before extension)
12977   // with a truncate of the concatenated sextloaded vectors.
12978   SDValue Trunc =
12979       DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), NewValue);
12980   ExtendSetCCUses(SetCCs, N0, NewValue, (ISD::NodeType)N->getOpcode());
12981   CombineTo(N0.getNode(), Trunc, NewChain);
12982   return SDValue(N, 0); // Return N so it doesn't get rechecked!
12983 }
12984 
12985 // fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
12986 //      (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
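// For illustration (types and constants are only an example):
//   (i64 (zext (and (srl (i32 (load x)), 8), 255)))
//     --> (and (srl (i64 (zextload x)), 8), 255)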
12987 SDValue DAGCombiner::CombineZExtLogicopShiftLoad(SDNode *N) {
12988   assert(N->getOpcode() == ISD::ZERO_EXTEND);
12989   EVT VT = N->getValueType(0);
12990   EVT OrigVT = N->getOperand(0).getValueType();
12991   if (TLI.isZExtFree(OrigVT, VT))
12992     return SDValue();
12993 
12994   // and/or/xor
12995   SDValue N0 = N->getOperand(0);
12996   if (!ISD::isBitwiseLogicOp(N0.getOpcode()) ||
12997       N0.getOperand(1).getOpcode() != ISD::Constant ||
12998       (LegalOperations && !TLI.isOperationLegal(N0.getOpcode(), VT)))
12999     return SDValue();
13000 
13001   // shl/shr
13002   SDValue N1 = N0->getOperand(0);
13003   if (!(N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) ||
13004       N1.getOperand(1).getOpcode() != ISD::Constant ||
13005       (LegalOperations && !TLI.isOperationLegal(N1.getOpcode(), VT)))
13006     return SDValue();
13007 
13008   // load
13009   if (!isa<LoadSDNode>(N1.getOperand(0)))
13010     return SDValue();
13011   LoadSDNode *Load = cast<LoadSDNode>(N1.getOperand(0));
13012   EVT MemVT = Load->getMemoryVT();
13013   if (!TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) ||
13014       Load->getExtensionType() == ISD::SEXTLOAD || Load->isIndexed())
13015     return SDValue();
13016 
13018   // If the shift op is SHL, the logic op must be AND, otherwise the result
13019   // will be wrong.
13020   if (N1.getOpcode() == ISD::SHL && N0.getOpcode() != ISD::AND)
13021     return SDValue();
13022 
13023   if (!N0.hasOneUse() || !N1.hasOneUse())
13024     return SDValue();
13025 
13026   SmallVector<SDNode*, 4> SetCCs;
13027   if (!ExtendUsesToFormExtLoad(VT, N1.getNode(), N1.getOperand(0),
13028                                ISD::ZERO_EXTEND, SetCCs, TLI))
13029     return SDValue();
13030 
13031   // Actually do the transformation.
13032   SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(Load), VT,
13033                                    Load->getChain(), Load->getBasePtr(),
13034                                    Load->getMemoryVT(), Load->getMemOperand());
13035 
13036   SDLoc DL1(N1);
13037   SDValue Shift = DAG.getNode(N1.getOpcode(), DL1, VT, ExtLoad,
13038                               N1.getOperand(1));
13039 
13040   APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
13041   SDLoc DL0(N0);
13042   SDValue And = DAG.getNode(N0.getOpcode(), DL0, VT, Shift,
13043                             DAG.getConstant(Mask, DL0, VT));
13044 
13045   ExtendSetCCUses(SetCCs, N1.getOperand(0), ExtLoad, ISD::ZERO_EXTEND);
13046   CombineTo(N, And);
13047   if (SDValue(Load, 0).hasOneUse()) {
13048     DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), ExtLoad.getValue(1));
13049   } else {
13050     SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(Load),
13051                                 Load->getValueType(0), ExtLoad);
13052     CombineTo(Load, Trunc, ExtLoad.getValue(1));
13053   }
13054 
13055   // N0 is dead at this point.
13056   recursivelyDeleteUnusedNodes(N0.getNode());
13057 
13058   return SDValue(N, 0); // Return N so it doesn't get rechecked!
13059 }
13060 
13061 /// If we're narrowing or widening the result of a vector select and the final
13062 /// size is the same size as a setcc (compare) feeding the select, then try to
13063 /// apply the cast operation to the select's operands because matching vector
13064 /// sizes for a select condition and other operands should be more efficient.
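/// For illustration, assuming the setcc result type is v4i32:
///   (v4i32 (sext (v4i16 (vselect (setcc X, Y), A, B))))
///     --> (v4i32 (vselect (setcc X, Y), (sext A), (sext B)))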
13065 SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
13066   unsigned CastOpcode = Cast->getOpcode();
13067   assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
13068           CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
13069           CastOpcode == ISD::FP_ROUND) &&
13070          "Unexpected opcode for vector select narrowing/widening");
13071 
13072   // We only do this transform before legal ops because the pattern may be
13073   // obfuscated by target-specific operations after legalization. Do not create
13074   // an illegal select op, however, because that may be difficult to lower.
13075   EVT VT = Cast->getValueType(0);
13076   if (LegalOperations || !TLI.isOperationLegalOrCustom(ISD::VSELECT, VT))
13077     return SDValue();
13078 
13079   SDValue VSel = Cast->getOperand(0);
13080   if (VSel.getOpcode() != ISD::VSELECT || !VSel.hasOneUse() ||
13081       VSel.getOperand(0).getOpcode() != ISD::SETCC)
13082     return SDValue();
13083 
13084   // Does the setcc have the same vector size as the casted select?
13085   SDValue SetCC = VSel.getOperand(0);
13086   EVT SetCCVT = getSetCCResultType(SetCC.getOperand(0).getValueType());
13087   if (SetCCVT.getSizeInBits() != VT.getSizeInBits())
13088     return SDValue();
13089 
13090   // cast (vsel (setcc X), A, B) --> vsel (setcc X), (cast A), (cast B)
13091   SDValue A = VSel.getOperand(1);
13092   SDValue B = VSel.getOperand(2);
13093   SDValue CastA, CastB;
13094   SDLoc DL(Cast);
13095   if (CastOpcode == ISD::FP_ROUND) {
13096     // FP_ROUND (fptrunc) has an extra flag operand to pass along.
13097     CastA = DAG.getNode(CastOpcode, DL, VT, A, Cast->getOperand(1));
13098     CastB = DAG.getNode(CastOpcode, DL, VT, B, Cast->getOperand(1));
13099   } else {
13100     CastA = DAG.getNode(CastOpcode, DL, VT, A);
13101     CastB = DAG.getNode(CastOpcode, DL, VT, B);
13102   }
13103   return DAG.getNode(ISD::VSELECT, DL, VT, SetCC, CastA, CastB);
13104 }
13105 
13106 // fold ([s|z]ext ([s|z]extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
13107 // fold ([s|z]ext (     extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
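// For example, on the sext path a prior (zextload x) is rejected because the
// extension kinds conflict, while a plain (extload x) can be refolded into a
// (sextload x).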
13108 static SDValue tryToFoldExtOfExtload(SelectionDAG &DAG, DAGCombiner &Combiner,
13109                                      const TargetLowering &TLI, EVT VT,
13110                                      bool LegalOperations, SDNode *N,
13111                                      SDValue N0, ISD::LoadExtType ExtLoadType) {
13112   SDNode *N0Node = N0.getNode();
13113   bool isAExtLoad = (ExtLoadType == ISD::SEXTLOAD) ? ISD::isSEXTLoad(N0Node)
13114                                                    : ISD::isZEXTLoad(N0Node);
13115   if ((!isAExtLoad && !ISD::isEXTLoad(N0Node)) ||
13116       !ISD::isUNINDEXEDLoad(N0Node) || !N0.hasOneUse())
13117     return SDValue();
13118 
13119   LoadSDNode *LN0 = cast<LoadSDNode>(N0);
13120   EVT MemVT = LN0->getMemoryVT();
13121   if ((LegalOperations || !LN0->isSimple() ||
13122        VT.isVector()) &&
13123       !TLI.isLoadExtLegal(ExtLoadType, VT, MemVT))
13124     return SDValue();
13125 
13126   SDValue ExtLoad =
13127       DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(),
13128                      LN0->getBasePtr(), MemVT, LN0->getMemOperand());
13129   Combiner.CombineTo(N, ExtLoad);
13130   DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
13131   if (LN0->use_empty())
13132     Combiner.recursivelyDeleteUnusedNodes(LN0);
13133   return SDValue(N, 0); // Return N so it doesn't get rechecked!
13134 }
13135 
13136 // fold ([s|z]ext (load x)) -> ([s|z]ext (truncate ([s|z]extload x)))
13137 // Only generate vector extloads when 1) they're legal, and 2) they are
13138 // deemed desirable by the target.
13139 static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner,
13140                                   const TargetLowering &TLI, EVT VT,
13141                                   bool LegalOperations, SDNode *N, SDValue N0,
13142                                   ISD::LoadExtType ExtLoadType,
13143                                   ISD::NodeType ExtOpc) {
13144   // TODO: isFixedLengthVector() should be removed, with any negative effects
13145   // on code generation addressed through the target's implementation of
13146   // isVectorLoadExtDesirable().
13147   if (!ISD::isNON_EXTLoad(N0.getNode()) ||
13148       !ISD::isUNINDEXEDLoad(N0.getNode()) ||
13149       ((LegalOperations || VT.isFixedLengthVector() ||
13150         !cast<LoadSDNode>(N0)->isSimple()) &&
13151        !TLI.isLoadExtLegal(ExtLoadType, VT, N0.getValueType())))
13152     return {};
13153 
13154   bool DoXform = true;
13155   SmallVector<SDNode *, 4> SetCCs;
13156   if (!N0.hasOneUse())
13157     DoXform = ExtendUsesToFormExtLoad(VT, N, N0, ExtOpc, SetCCs, TLI);
13158   if (VT.isVector())
13159     DoXform &= TLI.isVectorLoadExtDesirable(SDValue(N, 0));
13160   if (!DoXform)
13161     return {};
13162 
13163   LoadSDNode *LN0 = cast<LoadSDNode>(N0);
13164   SDValue ExtLoad = DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(),
13165                                    LN0->getBasePtr(), N0.getValueType(),
13166                                    LN0->getMemOperand());
13167   Combiner.ExtendSetCCUses(SetCCs, N0, ExtLoad, ExtOpc);
13168   // If the load value is used only by N, replace it via CombineTo N.
13169   bool NoReplaceTrunc = SDValue(LN0, 0).hasOneUse();
13170   Combiner.CombineTo(N, ExtLoad);
13171   if (NoReplaceTrunc) {
13172     DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
13173     Combiner.recursivelyDeleteUnusedNodes(LN0);
13174   } else {
13175     SDValue Trunc =
13176         DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), ExtLoad);
13177     Combiner.CombineTo(LN0, Trunc, ExtLoad.getValue(1));
13178   }
13179   return SDValue(N, 0); // Return N so it doesn't get rechecked!
13180 }
13181 
13182 static SDValue
13183 tryToFoldExtOfMaskedLoad(SelectionDAG &DAG, const TargetLowering &TLI, EVT VT,
13184                          bool LegalOperations, SDNode *N, SDValue N0,
13185                          ISD::LoadExtType ExtLoadType, ISD::NodeType ExtOpc) {
13186   if (!N0.hasOneUse())
13187     return SDValue();
13188 
13189   MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(N0);
13190   if (!Ld || Ld->getExtensionType() != ISD::NON_EXTLOAD)
13191     return SDValue();
13192 
13193   if ((LegalOperations || !cast<MaskedLoadSDNode>(N0)->isSimple()) &&
13194       !TLI.isLoadExtLegalOrCustom(ExtLoadType, VT, Ld->getValueType(0)))
13195     return SDValue();
13196 
13197   if (!TLI.isVectorLoadExtDesirable(SDValue(N, 0)))
13198     return SDValue();
13199 
13200   SDLoc dl(Ld);
13201   SDValue PassThru = DAG.getNode(ExtOpc, dl, VT, Ld->getPassThru());
13202   SDValue NewLoad = DAG.getMaskedLoad(
13203       VT, dl, Ld->getChain(), Ld->getBasePtr(), Ld->getOffset(), Ld->getMask(),
13204       PassThru, Ld->getMemoryVT(), Ld->getMemOperand(), Ld->getAddressingMode(),
13205       ExtLoadType, Ld->isExpandingLoad());
13206   DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), SDValue(NewLoad.getNode(), 1));
13207   return NewLoad;
13208 }
13209 
13210 static SDValue foldExtendedSignBitTest(SDNode *N, SelectionDAG &DAG,
13211                                        bool LegalOperations) {
13212   assert((N->getOpcode() == ISD::SIGN_EXTEND ||
13213           N->getOpcode() == ISD::ZERO_EXTEND) && "Expected sext or zext");
13214 
13215   SDValue SetCC = N->getOperand(0);
13216   if (LegalOperations || SetCC.getOpcode() != ISD::SETCC ||
13217       !SetCC.hasOneUse() || SetCC.getValueType() != MVT::i1)
13218     return SDValue();
13219 
13220   SDValue X = SetCC.getOperand(0);
13221   SDValue Ones = SetCC.getOperand(1);
13222   ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get();
13223   EVT VT = N->getValueType(0);
13224   EVT XVT = X.getValueType();
13225   // setge X, C is canonicalized to setgt, so we do not need to match that
13226   // pattern. The setlt sibling is folded in SimplifySelectCC() because it does
13227   // not require the 'not' op.
13228   if (CC == ISD::SETGT && isAllOnesConstant(Ones) && VT == XVT) {
13229     // Invert and smear/shift the sign bit:
13230     // sext i1 (setgt iN X, -1) --> sra (not X), (N - 1)
13231     // zext i1 (setgt iN X, -1) --> srl (not X), (N - 1)
13232     SDLoc DL(N);
13233     unsigned ShCt = VT.getSizeInBits() - 1;
13234     const TargetLowering &TLI = DAG.getTargetLoweringInfo();
13235     if (!TLI.shouldAvoidTransformToShift(VT, ShCt)) {
13236       SDValue NotX = DAG.getNOT(DL, X, VT);
13237       SDValue ShiftAmount = DAG.getConstant(ShCt, DL, VT);
13238       auto ShiftOpcode =
13239         N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SRA : ISD::SRL;
13240       return DAG.getNode(ShiftOpcode, DL, VT, NotX, ShiftAmount);
13241     }
13242   }
13243   return SDValue();
13244 }
13245 
13246 SDValue DAGCombiner::foldSextSetcc(SDNode *N) {
13247   SDValue N0 = N->getOperand(0);
13248   if (N0.getOpcode() != ISD::SETCC)
13249     return SDValue();
13250 
13251   SDValue N00 = N0.getOperand(0);
13252   SDValue N01 = N0.getOperand(1);
13253   ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
13254   EVT VT = N->getValueType(0);
13255   EVT N00VT = N00.getValueType();
13256   SDLoc DL(N);
13257 
13258   // Propagate fast-math-flags.
13259   SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
13260 
13261   // On some architectures (such as SSE/NEON/etc) the SETCC result type is
13262   // the same size as the compared operands. Try to optimize sext(setcc())
13263   // if this is the case.
13264   if (VT.isVector() && !LegalOperations &&
13265       TLI.getBooleanContents(N00VT) ==
13266           TargetLowering::ZeroOrNegativeOneBooleanContent) {
13267     EVT SVT = getSetCCResultType(N00VT);
13268 
13269     // If we already have the desired type, don't change it.
13270     if (SVT != N0.getValueType()) {
13271       // We know that the # elements of the results is the same as the
13272       // # elements of the compare (and the # elements of the compare result
13273       // for that matter).  Check to see that they are the same size.  If so,
13274       // we know that the element size of the sext'd result matches the
13275       // element size of the compare operands.
13276       if (VT.getSizeInBits() == SVT.getSizeInBits())
13277         return DAG.getSetCC(DL, VT, N00, N01, CC);
13278 
13279       // If the desired elements are smaller or larger than the source
13280       // elements, we can use a matching integer vector type and then
13281       // truncate/sign extend.
13282       EVT MatchingVecType = N00VT.changeVectorElementTypeToInteger();
13283       if (SVT == MatchingVecType) {
13284         SDValue VsetCC = DAG.getSetCC(DL, MatchingVecType, N00, N01, CC);
13285         return DAG.getSExtOrTrunc(VsetCC, DL, VT);
13286       }
13287     }
13288 
13289     // Try to eliminate the sext of a setcc by zexting the compare operands.
13290     if (N0.hasOneUse() && TLI.isOperationLegalOrCustom(ISD::SETCC, VT) &&
13291         !TLI.isOperationLegalOrCustom(ISD::SETCC, SVT)) {
13292       bool IsSignedCmp = ISD::isSignedIntSetCC(CC);
13293       unsigned LoadOpcode = IsSignedCmp ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
13294       unsigned ExtOpcode = IsSignedCmp ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
13295 
13296       // We have an unsupported narrow vector compare op that would be legal
13297       // if extended to the destination type. See if the compare operands
13298       // can be freely extended to the destination type.
13299       auto IsFreeToExtend = [&](SDValue V) {
13300         if (isConstantOrConstantVector(V, /*NoOpaques*/ true))
13301           return true;
13302         // Match a simple, non-extended load that can be converted to a
13303         // legal {z/s}ext-load.
13304         // TODO: Allow widening of an existing {z/s}ext-load?
13305         if (!(ISD::isNON_EXTLoad(V.getNode()) &&
13306               ISD::isUNINDEXEDLoad(V.getNode()) &&
13307               cast<LoadSDNode>(V)->isSimple() &&
13308               TLI.isLoadExtLegal(LoadOpcode, VT, V.getValueType())))
13309           return false;
13310 
13311         // Non-chain users of this value must either be the setcc in this
13312         // sequence or extends that can be folded into the new {z/s}ext-load.
13313         for (SDNode::use_iterator UI = V->use_begin(), UE = V->use_end();
13314              UI != UE; ++UI) {
13315           // Skip uses of the chain and the setcc.
13316           SDNode *User = *UI;
13317           if (UI.getUse().getResNo() != 0 || User == N0.getNode())
13318             continue;
13319           // Extra users must have exactly the same cast we are about to create.
13320           // TODO: This restriction could be eased if ExtendUsesToFormExtLoad()
13321           //       is enhanced similarly.
13322           if (User->getOpcode() != ExtOpcode || User->getValueType(0) != VT)
13323             return false;
13324         }
13325         return true;
13326       };
13327 
13328       if (IsFreeToExtend(N00) && IsFreeToExtend(N01)) {
13329         SDValue Ext0 = DAG.getNode(ExtOpcode, DL, VT, N00);
13330         SDValue Ext1 = DAG.getNode(ExtOpcode, DL, VT, N01);
13331         return DAG.getSetCC(DL, VT, Ext0, Ext1, CC);
13332       }
13333     }
13334   }
13335 
13336   // sext(setcc x, y, cc) -> (select (setcc x, y, cc), T, 0)
13337   // Here, T can be 1 or -1, depending on the type of the setcc and
13338   // getBooleanContents().
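  // For illustration, with an i1 setcc (so the "true" value is sext(i1 1)):
  //   (i32 (sext (i1 (setcc x, y, cc)))) --> (select (setcc x, y, cc), -1, 0)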
13339   unsigned SetCCWidth = N0.getScalarValueSizeInBits();
13340 
13341   // To determine the "true" side of the select, we need to know the high bit
13342   // of the value returned by the setcc if it evaluates to true.
13343   // If the type of the setcc is i1, then the true case of the select is just
13344   // sext(i1 1), that is, -1.
13345   // If the type of the setcc is larger (say, i8) then the value of the high
13346   // bit depends on getBooleanContents(), so ask TLI for a real "true" value
13347   // of the appropriate width.
13348   SDValue ExtTrueVal = (SetCCWidth == 1)
13349                            ? DAG.getAllOnesConstant(DL, VT)
13350                            : DAG.getBoolConstant(true, DL, VT, N00VT);
13351   SDValue Zero = DAG.getConstant(0, DL, VT);
13352   if (SDValue SCC = SimplifySelectCC(DL, N00, N01, ExtTrueVal, Zero, CC, true))
13353     return SCC;
13354 
13355   if (!VT.isVector() && !shouldConvertSelectOfConstantsToMath(N0, VT, TLI)) {
13356     EVT SetCCVT = getSetCCResultType(N00VT);
13357     // Don't do this transform for i1 because there's a select transform
13358     // that would reverse it.
13359     // TODO: We should not do this transform at all without a target hook
13360     // because a sext is likely cheaper than a select?
13361     if (SetCCVT.getScalarSizeInBits() != 1 &&
13362         (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, N00VT))) {
13363       SDValue SetCC = DAG.getSetCC(DL, SetCCVT, N00, N01, CC);
13364       return DAG.getSelect(DL, VT, SetCC, ExtTrueVal, Zero);
13365     }
13366   }
13367 
13368   return SDValue();
13369 }
13370 
13371 SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
13372   SDValue N0 = N->getOperand(0);
13373   EVT VT = N->getValueType(0);
13374   SDLoc DL(N);
13375 
13376   if (VT.isVector())
13377     if (SDValue FoldedVOp = SimplifyVCastOp(N, DL))
13378       return FoldedVOp;
13379 
13380   // sext(undef) = 0 because the top bit will all be the same.
13381   if (N0.isUndef())
13382     return DAG.getConstant(0, DL, VT);
13383 
13384   if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
13385     return Res;
13386 
13387   // fold (sext (sext x)) -> (sext x)
13388   // fold (sext (aext x)) -> (sext x)
13389   if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND)
13390     return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N0.getOperand(0));
13391 
13392   // fold (sext (aext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x)
13393   // fold (sext (sext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x)
13394   if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
13395       N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG)
13396     return DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, SDLoc(N), VT,
13397                        N0.getOperand(0));
13398 
13399   // fold (sext (sext_inreg x)) -> (sext (trunc x))
13400   if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG) {
13401     SDValue N00 = N0.getOperand(0);
13402     EVT ExtVT = cast<VTSDNode>(N0->getOperand(1))->getVT();
13403     if ((N00.getOpcode() == ISD::TRUNCATE || TLI.isTruncateFree(N00, ExtVT)) &&
13404         (!LegalTypes || TLI.isTypeLegal(ExtVT))) {
13405       SDValue T = DAG.getNode(ISD::TRUNCATE, DL, ExtVT, N00);
13406       return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, T);
13407     }
13408   }
13409 
13410   if (N0.getOpcode() == ISD::TRUNCATE) {
13411     // fold (sext (truncate (load x))) -> (sext (smaller load x))
13412     // fold (sext (truncate (srl (load x), c))) -> (sext (smaller load (x+c/n)))
13413     if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
13414       SDNode *oye = N0.getOperand(0).getNode();
13415       if (NarrowLoad.getNode() != N0.getNode()) {
13416         CombineTo(N0.getNode(), NarrowLoad);
13417         // CombineTo deleted the truncate, if needed, but not what's under it.
13418         AddToWorklist(oye);
13419       }
13420       return SDValue(N, 0);   // Return N so it doesn't get rechecked!
13421     }
13422 
13423     // See if the value being truncated is already sign extended.  If so, just
13424     // eliminate the trunc/sext pair.
13425     SDValue Op = N0.getOperand(0);
13426     unsigned OpBits   = Op.getScalarValueSizeInBits();
13427     unsigned MidBits  = N0.getScalarValueSizeInBits();
13428     unsigned DestBits = VT.getScalarSizeInBits();
13429     unsigned NumSignBits = DAG.ComputeNumSignBits(Op);
13430 
13431     if (OpBits == DestBits) {
13432       // Op is i32, Mid is i8, and Dest is i32.  If Op has more than 24 sign
13433       // bits, it is already properly sign-extended; use it directly.
13434       if (NumSignBits > DestBits-MidBits)
13435         return Op;
13436     } else if (OpBits < DestBits) {
13437       // Op is i32, Mid is i8, and Dest is i64.  If Op has more than 24 sign
13438       // bits, just sext from i32.
13439       if (NumSignBits > OpBits-MidBits)
13440         return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op);
13441     } else {
13442       // Op is i64, Mid is i8, and Dest is i32.  If Op has more than 56 sign
13443       // bits, just truncate to i32.
13444       if (NumSignBits > OpBits-MidBits)
13445         return DAG.getNode(ISD::TRUNCATE, DL, VT, Op);
13446     }
13447 
13448     // fold (sext (truncate x)) -> (sextinreg x).
13449     if (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND_INREG,
13450                                                  N0.getValueType())) {
13451       if (OpBits < DestBits)
13452         Op = DAG.getNode(ISD::ANY_EXTEND, SDLoc(N0), VT, Op);
13453       else if (OpBits > DestBits)
13454         Op = DAG.getNode(ISD::TRUNCATE, SDLoc(N0), VT, Op);
13455       return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, Op,
13456                          DAG.getValueType(N0.getValueType()));
13457     }
13458   }
13459 
13460   // Try to simplify (sext (load x)).
13461   if (SDValue foldedExt =
13462           tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
13463                              ISD::SEXTLOAD, ISD::SIGN_EXTEND))
13464     return foldedExt;
13465 
13466   if (SDValue foldedExt =
13467           tryToFoldExtOfMaskedLoad(DAG, TLI, VT, LegalOperations, N, N0,
13468                                    ISD::SEXTLOAD, ISD::SIGN_EXTEND))
13469     return foldedExt;
13470 
13471   // fold (sext (load x)) to multiple smaller sextloads.
13472   // Only on illegal but splittable vectors.
13473   if (SDValue ExtLoad = CombineExtLoad(N))
13474     return ExtLoad;
13475 
13476   // Try to simplify (sext (sextload x)).
13477   if (SDValue foldedExt = tryToFoldExtOfExtload(
13478           DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::SEXTLOAD))
13479     return foldedExt;
13480 
13481   // fold (sext (and/or/xor (load x), cst)) ->
13482   //      (and/or/xor (sextload x), (sext cst))
13483   if (ISD::isBitwiseLogicOp(N0.getOpcode()) &&
13484       isa<LoadSDNode>(N0.getOperand(0)) &&
13485       N0.getOperand(1).getOpcode() == ISD::Constant &&
13486       (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
13487     LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
13488     EVT MemVT = LN00->getMemoryVT();
13489     if (TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, MemVT) &&
13490         LN00->getExtensionType() != ISD::ZEXTLOAD && LN00->isUnindexed()) {
13491       SmallVector<SDNode*, 4> SetCCs;
13492       bool DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
13493                                              ISD::SIGN_EXTEND, SetCCs, TLI);
13494       if (DoXform) {
13495         SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(LN00), VT,
13496                                          LN00->getChain(), LN00->getBasePtr(),
13497                                          LN00->getMemoryVT(),
13498                                          LN00->getMemOperand());
13499         APInt Mask = N0.getConstantOperandAPInt(1).sext(VT.getSizeInBits());
13500         SDValue And = DAG.getNode(N0.getOpcode(), DL, VT,
13501                                   ExtLoad, DAG.getConstant(Mask, DL, VT));
13502         ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::SIGN_EXTEND);
13503         bool NoReplaceTruncAnd = !N0.hasOneUse();
13504         bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
13505         CombineTo(N, And);
13506         // If N0 has multiple uses, change other uses as well.
13507         if (NoReplaceTruncAnd) {
13508           SDValue TruncAnd =
13509               DAG.getNode(ISD::TRUNCATE, DL, N0.getValueType(), And);
13510           CombineTo(N0.getNode(), TruncAnd);
13511         }
13512         if (NoReplaceTrunc) {
13513           DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1));
13514         } else {
13515           SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00),
13516                                       LN00->getValueType(0), ExtLoad);
13517           CombineTo(LN00, Trunc, ExtLoad.getValue(1));
13518         }
13519         return SDValue(N, 0); // Return N so it doesn't get rechecked!
13520       }
13521     }
13522   }
13523 
13524   if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
13525     return V;
13526 
13527   if (SDValue V = foldSextSetcc(N))
13528     return V;
13529 
13530   // fold (sext x) -> (zext x) if the sign bit is known zero.
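  // For example, (sext (srl x, 1)) can become (zext (srl x, 1)): the shift
  // guarantees the sign bit of the narrower value is zero.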
13531   if (!TLI.isSExtCheaperThanZExt(N0.getValueType(), VT) &&
13532       (!LegalOperations || TLI.isOperationLegal(ISD::ZERO_EXTEND, VT)) &&
13533       DAG.SignBitIsZero(N0)) {
13534     SDNodeFlags Flags;
13535     Flags.setNonNeg(true);
13536     return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0, Flags);
13537   }
13538 
13539   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
13540     return NewVSel;
13541 
13542   // Eliminate this sign extend by doing a negation in the destination type:
13543   // sext i32 (0 - (zext i8 X to i32)) to i64 --> 0 - (zext i8 X to i64)
13544   if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
13545       isNullOrNullSplat(N0.getOperand(0)) &&
13546       N0.getOperand(1).getOpcode() == ISD::ZERO_EXTEND &&
13547       TLI.isOperationLegalOrCustom(ISD::SUB, VT)) {
13548     SDValue Zext = DAG.getZExtOrTrunc(N0.getOperand(1).getOperand(0), DL, VT);
13549     return DAG.getNegative(Zext, DL, VT);
13550   }
13551   // Eliminate this sign extend by doing a decrement in the destination type:
13552   // sext i32 ((zext i8 X to i32) + (-1)) to i64 --> (zext i8 X to i64) + (-1)
13553   if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
13554       isAllOnesOrAllOnesSplat(N0.getOperand(1)) &&
13555       N0.getOperand(0).getOpcode() == ISD::ZERO_EXTEND &&
13556       TLI.isOperationLegalOrCustom(ISD::ADD, VT)) {
13557     SDValue Zext = DAG.getZExtOrTrunc(N0.getOperand(0).getOperand(0), DL, VT);
13558     return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT));
13559   }
13560 
13561   // fold sext (not i1 X) -> add (zext i1 X), -1
13562   // TODO: This could be extended to handle bool vectors.
13563   if (N0.getValueType() == MVT::i1 && isBitwiseNot(N0) && N0.hasOneUse() &&
13564       (!LegalOperations || (TLI.isOperationLegal(ISD::ZERO_EXTEND, VT) &&
13565                             TLI.isOperationLegal(ISD::ADD, VT)))) {
13566     // If we can eliminate the 'not', the sext form should be better
13567     if (SDValue NewXor = visitXOR(N0.getNode())) {
13568       // Returning N0 is a form of in-visit replacement that may have
13569       // invalidated N0.
13570       if (NewXor.getNode() == N0.getNode()) {
13571         // Return SDValue here as the xor should have already been replaced in
13572         // this sext.
13573         return SDValue();
13574       }
13575 
13576       // Return a new sext with the new xor.
13577       return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, NewXor);
13578     }
13579 
13580     SDValue Zext = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0));
13581     return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT));
13582   }
13583 
13584   if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, Level))
13585     return Res;
13586 
13587   return SDValue();
13588 }
13589 
13590 /// Given an extending node with a pop-count operand, if the target does not
13591 /// support a pop-count in the narrow source type but does support it in the
13592 /// destination type, widen the pop-count to the destination type.
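/// For example, if i16 CTPOP is unsupported but i32 CTPOP is legal or custom:
///   (i32 (zext (i16 (ctpop X)))) --> (i32 (ctpop (i32 (zext X))))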
13593 static SDValue widenCtPop(SDNode *Extend, SelectionDAG &DAG) {
13594   assert((Extend->getOpcode() == ISD::ZERO_EXTEND ||
13595           Extend->getOpcode() == ISD::ANY_EXTEND) && "Expected extend op");
13596 
13597   SDValue CtPop = Extend->getOperand(0);
13598   if (CtPop.getOpcode() != ISD::CTPOP || !CtPop.hasOneUse())
13599     return SDValue();
13600 
13601   EVT VT = Extend->getValueType(0);
13602   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
13603   if (TLI.isOperationLegalOrCustom(ISD::CTPOP, CtPop.getValueType()) ||
13604       !TLI.isOperationLegalOrCustom(ISD::CTPOP, VT))
13605     return SDValue();
13606 
13607   // zext (ctpop X) --> ctpop (zext X)
13608   SDLoc DL(Extend);
13609   SDValue NewZext = DAG.getZExtOrTrunc(CtPop.getOperand(0), DL, VT);
13610   return DAG.getNode(ISD::CTPOP, DL, VT, NewZext);
13611 }
13612 
13613 // If we have (zext (abs X)) where X is a type that will be promoted by type
13614 // legalization, convert to (abs (sext X)). But don't extend past a legal type.
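// For illustration, assuming i8 is promoted to i32 by type legalization:
//   (i16 (zext (i8 (abs X)))) --> (i16 (trunc (i32 (abs (i32 (sext X))))))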
13615 static SDValue widenAbs(SDNode *Extend, SelectionDAG &DAG) {
13616   assert(Extend->getOpcode() == ISD::ZERO_EXTEND && "Expected zero extend.");
13617 
13618   EVT VT = Extend->getValueType(0);
13619   if (VT.isVector())
13620     return SDValue();
13621 
13622   SDValue Abs = Extend->getOperand(0);
13623   if (Abs.getOpcode() != ISD::ABS || !Abs.hasOneUse())
13624     return SDValue();
13625 
13626   EVT AbsVT = Abs.getValueType();
13627   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
13628   if (TLI.getTypeAction(*DAG.getContext(), AbsVT) !=
13629       TargetLowering::TypePromoteInteger)
13630     return SDValue();
13631 
13632   EVT LegalVT = TLI.getTypeToTransformTo(*DAG.getContext(), AbsVT);
13633 
13634   SDValue SExt =
13635       DAG.getNode(ISD::SIGN_EXTEND, SDLoc(Abs), LegalVT, Abs.getOperand(0));
13636   SDValue NewAbs = DAG.getNode(ISD::ABS, SDLoc(Abs), LegalVT, SExt);
13637   return DAG.getZExtOrTrunc(NewAbs, SDLoc(Extend), VT);
13638 }
13639 
13640 SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
13641   SDValue N0 = N->getOperand(0);
13642   EVT VT = N->getValueType(0);
13643   SDLoc DL(N);
13644 
13645   if (VT.isVector())
13646     if (SDValue FoldedVOp = SimplifyVCastOp(N, DL))
13647       return FoldedVOp;
13648 
13649   // zext(undef) = 0
13650   if (N0.isUndef())
13651     return DAG.getConstant(0, DL, VT);
13652 
13653   if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
13654     return Res;
13655 
13656   // fold (zext (zext x)) -> (zext x)
13657   // fold (zext (aext x)) -> (zext x)
13658   if (N0.getOpcode() == ISD::ZERO_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND)
13659     return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0));
13660 
13661   // fold (zext (aext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x)
13662   // fold (zext (zext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x)
13663   if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
13664       N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG)
13665     return DAG.getNode(ISD::ZERO_EXTEND_VECTOR_INREG, SDLoc(N), VT,
13666                        N0.getOperand(0));
13667 
13668   // fold (zext (truncate x)) -> (zext x) or
13669   //      (zext (truncate x)) -> (truncate x)
13670   // This is valid when the truncated bits of x are already zero.
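  // For example, if the truncated bits of x are already known zero:
  //   (i64 (zext (i8 (trunc (i32 (and x, 255))))))
  //     --> (i64 (zext (i32 (and x, 255))))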
13671   SDValue Op;
13672   KnownBits Known;
13673   if (isTruncateOf(DAG, N0, Op, Known)) {
13674     APInt TruncatedBits =
13675       (Op.getScalarValueSizeInBits() == N0.getScalarValueSizeInBits()) ?
13676       APInt(Op.getScalarValueSizeInBits(), 0) :
13677       APInt::getBitsSet(Op.getScalarValueSizeInBits(),
13678                         N0.getScalarValueSizeInBits(),
13679                         std::min(Op.getScalarValueSizeInBits(),
13680                                  VT.getScalarSizeInBits()));
13681     if (TruncatedBits.isSubsetOf(Known.Zero)) {
13682       SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, DL, VT);
13683       DAG.salvageDebugInfo(*N0.getNode());
13684 
13685       return ZExtOrTrunc;
13686     }
13687   }
13688 
13689   // fold (zext (truncate x)) -> (and x, mask)
13690   if (N0.getOpcode() == ISD::TRUNCATE) {
13691     // fold (zext (truncate (load x))) -> (zext (smaller load x))
13692     // fold (zext (truncate (srl (load x), c))) -> (zext (smaller load (x+c/n)))
13693     if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
13694       SDNode *oye = N0.getOperand(0).getNode();
13695       if (NarrowLoad.getNode() != N0.getNode()) {
13696         CombineTo(N0.getNode(), NarrowLoad);
13697         // CombineTo deleted the truncate, if needed, but not what's under it.
13698         AddToWorklist(oye);
13699       }
13700       return SDValue(N, 0); // Return N so it doesn't get rechecked!
13701     }
13702 
13703     EVT SrcVT = N0.getOperand(0).getValueType();
13704     EVT MinVT = N0.getValueType();
13705 
13706     // Try to mask before the extension to avoid having to generate a larger
13707     // mask, possibly over several sub-vectors.
13708     if (SrcVT.bitsLT(VT) && VT.isVector()) {
13709       if (!LegalOperations || (TLI.isOperationLegal(ISD::AND, SrcVT) &&
13710                                TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) {
13711         SDValue Op = N0.getOperand(0);
13712         Op = DAG.getZeroExtendInReg(Op, DL, MinVT);
13713         AddToWorklist(Op.getNode());
13714         SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, DL, VT);
13715         // Transfer the debug info; the new node is equivalent to N0.
13716         DAG.transferDbgValues(N0, ZExtOrTrunc);
13717         return ZExtOrTrunc;
13718       }
13719     }
13720 
13721     if (!LegalOperations || TLI.isOperationLegal(ISD::AND, VT)) {
13722       SDValue Op = DAG.getAnyExtOrTrunc(N0.getOperand(0), DL, VT);
13723       AddToWorklist(Op.getNode());
13724       SDValue And = DAG.getZeroExtendInReg(Op, DL, MinVT);
13725       // We may safely transfer the debug info describing the truncate node over
13726       // to the equivalent and operation.
13727       DAG.transferDbgValues(N0, And);
13728       return And;
13729     }
13730   }
13731 
13732   // Fold (zext (and (trunc x), cst)) -> (and x, cst),
13733   // if either of the casts is not free.
13734   if (N0.getOpcode() == ISD::AND &&
13735       N0.getOperand(0).getOpcode() == ISD::TRUNCATE &&
13736       N0.getOperand(1).getOpcode() == ISD::Constant &&
13737       (!TLI.isTruncateFree(N0.getOperand(0).getOperand(0), N0.getValueType()) ||
13738        !TLI.isZExtFree(N0.getValueType(), VT))) {
13739     SDValue X = N0.getOperand(0).getOperand(0);
13740     X = DAG.getAnyExtOrTrunc(X, SDLoc(X), VT);
13741     APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
13742     return DAG.getNode(ISD::AND, DL, VT,
13743                        X, DAG.getConstant(Mask, DL, VT));
13744   }
13745 
13746   // Try to simplify (zext (load x)).
13747   if (SDValue foldedExt =
13748           tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
13749                              ISD::ZEXTLOAD, ISD::ZERO_EXTEND))
13750     return foldedExt;
13751 
13752   if (SDValue foldedExt =
13753           tryToFoldExtOfMaskedLoad(DAG, TLI, VT, LegalOperations, N, N0,
13754                                    ISD::ZEXTLOAD, ISD::ZERO_EXTEND))
13755     return foldedExt;
13756 
13757   // fold (zext (load x)) to multiple smaller zextloads.
13758   // Only on illegal but splittable vectors.
13759   if (SDValue ExtLoad = CombineExtLoad(N))
13760     return ExtLoad;
13761 
13762   // fold (zext (and/or/xor (load x), cst)) ->
13763   //      (and/or/xor (zextload x), (zext cst))
13764   // Unless (and (load x) cst) will match as a zextload already and has
13765   // additional users, or the zext is already free.
13766   if (ISD::isBitwiseLogicOp(N0.getOpcode()) && !TLI.isZExtFree(N0, VT) &&
13767       isa<LoadSDNode>(N0.getOperand(0)) &&
13768       N0.getOperand(1).getOpcode() == ISD::Constant &&
13769       (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
13770     LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
13771     EVT MemVT = LN00->getMemoryVT();
13772     if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) &&
13773         LN00->getExtensionType() != ISD::SEXTLOAD && LN00->isUnindexed()) {
13774       bool DoXform = true;
13775       SmallVector<SDNode*, 4> SetCCs;
13776       if (!N0.hasOneUse()) {
13777         if (N0.getOpcode() == ISD::AND) {
13778           auto *AndC = cast<ConstantSDNode>(N0.getOperand(1));
13779           EVT LoadResultTy = AndC->getValueType(0);
13780           EVT ExtVT;
13781           if (isAndLoadExtLoad(AndC, LN00, LoadResultTy, ExtVT))
13782             DoXform = false;
13783         }
13784       }
13785       if (DoXform)
13786         DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
13787                                           ISD::ZERO_EXTEND, SetCCs, TLI);
13788       if (DoXform) {
13789         SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(LN00), VT,
13790                                          LN00->getChain(), LN00->getBasePtr(),
13791                                          LN00->getMemoryVT(),
13792                                          LN00->getMemOperand());
13793         APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
13794         SDValue And = DAG.getNode(N0.getOpcode(), DL, VT,
13795                                   ExtLoad, DAG.getConstant(Mask, DL, VT));
13796         ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::ZERO_EXTEND);
13797         bool NoReplaceTruncAnd = !N0.hasOneUse();
13798         bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
13799         CombineTo(N, And);
13800         // If N0 has multiple uses, change other uses as well.
13801         if (NoReplaceTruncAnd) {
13802           SDValue TruncAnd =
13803               DAG.getNode(ISD::TRUNCATE, DL, N0.getValueType(), And);
13804           CombineTo(N0.getNode(), TruncAnd);
13805         }
13806         if (NoReplaceTrunc) {
13807           DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1));
13808         } else {
13809           SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00),
13810                                       LN00->getValueType(0), ExtLoad);
13811           CombineTo(LN00, Trunc, ExtLoad.getValue(1));
13812         }
13813         return SDValue(N, 0); // Return N so it doesn't get rechecked!
13814       }
13815     }
13816   }
13817 
13818   // fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
13819   //      (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
13820   if (SDValue ZExtLoad = CombineZExtLogicopShiftLoad(N))
13821     return ZExtLoad;
13822 
13823   // Try to simplify (zext (zextload x)).
13824   if (SDValue foldedExt = tryToFoldExtOfExtload(
13825           DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::ZEXTLOAD))
13826     return foldedExt;
13827 
13828   if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
13829     return V;
13830 
13831   if (N0.getOpcode() == ISD::SETCC) {
13832     // Propagate fast-math-flags.
13833     SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
13834 
13835     // Only do this before legalize for now.
13836     if (!LegalOperations && VT.isVector() &&
13837         N0.getValueType().getVectorElementType() == MVT::i1) {
13838       EVT N00VT = N0.getOperand(0).getValueType();
13839       if (getSetCCResultType(N00VT) == N0.getValueType())
13840         return SDValue();
13841 
13842       // We know that the # elements of the results is the same as the #
13843       // elements of the compare (and the # elements of the compare result for
13844       // that matter). Check to see that they are the same size. If so, we know
13845       // that the element size of the sext'd result matches the element size of
13846       // the compare operands.
13847       if (VT.getSizeInBits() == N00VT.getSizeInBits()) {
13848         // zext(setcc) -> zext_in_reg(vsetcc) for vectors.
13849         SDValue VSetCC = DAG.getNode(ISD::SETCC, DL, VT, N0.getOperand(0),
13850                                      N0.getOperand(1), N0.getOperand(2));
13851         return DAG.getZeroExtendInReg(VSetCC, DL, N0.getValueType());
13852       }
13853 
13854       // If the desired elements are smaller or larger than the source
13855       // elements we can use a matching integer vector type and then
13856       // truncate/any extend followed by zext_in_reg.
13857       EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
13858       SDValue VsetCC =
13859           DAG.getNode(ISD::SETCC, DL, MatchingVectorType, N0.getOperand(0),
13860                       N0.getOperand(1), N0.getOperand(2));
13861       return DAG.getZeroExtendInReg(DAG.getAnyExtOrTrunc(VsetCC, DL, VT), DL,
13862                                     N0.getValueType());
13863     }
13864 
13865     // zext(setcc x,y,cc) -> zext(select x, y, true, false, cc)
13866     EVT N0VT = N0.getValueType();
13867     EVT N00VT = N0.getOperand(0).getValueType();
13868     if (SDValue SCC = SimplifySelectCC(
13869             DL, N0.getOperand(0), N0.getOperand(1),
13870             DAG.getBoolConstant(true, DL, N0VT, N00VT),
13871             DAG.getBoolConstant(false, DL, N0VT, N00VT),
13872             cast<CondCodeSDNode>(N0.getOperand(2))->get(), true))
13873       return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, SCC);
13874   }
13875 
13876   // (zext (shl/srl (zext x), cst)) -> (shl/srl (zext x), cst)
13877   if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
13878       !TLI.isZExtFree(N0, VT)) {
13879     SDValue ShVal = N0.getOperand(0);
13880     SDValue ShAmt = N0.getOperand(1);
13881     if (auto *ShAmtC = dyn_cast<ConstantSDNode>(ShAmt)) {
13882       if (ShVal.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse()) {
13883         if (N0.getOpcode() == ISD::SHL) {
13884           // If the original shl may be shifting out bits, do not perform this
13885           // transformation.
13886           // TODO: Add MaskedValueIsZero check.
13887           unsigned KnownZeroBits = ShVal.getValueSizeInBits() -
13888                                    ShVal.getOperand(0).getValueSizeInBits();
13889           if (ShAmtC->getAPIntValue().ugt(KnownZeroBits))
13890             return SDValue();
13891         }
13892 
13893         // Ensure that the shift amount is wide enough for the shifted value.
13894         if (Log2_32_Ceil(VT.getSizeInBits()) > ShAmt.getValueSizeInBits())
13895           ShAmt = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, ShAmt);
13896 
13897         return DAG.getNode(N0.getOpcode(), DL, VT,
13898                            DAG.getNode(ISD::ZERO_EXTEND, DL, VT, ShVal), ShAmt);
13899       }
13900     }
13901   }
13902 
13903   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
13904     return NewVSel;
13905 
13906   if (SDValue NewCtPop = widenCtPop(N, DAG))
13907     return NewCtPop;
13908 
13909   if (SDValue V = widenAbs(N, DAG))
13910     return V;
13911 
13912   if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, Level))
13913     return Res;
13914 
13915   return SDValue();
13916 }
13917 
13918 SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
13919   SDValue N0 = N->getOperand(0);
13920   EVT VT = N->getValueType(0);
13921 
13922   // aext(undef) = undef
13923   if (N0.isUndef())
13924     return DAG.getUNDEF(VT);
13925 
13926   if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
13927     return Res;
13928 
13929   // fold (aext (aext x)) -> (aext x)
13930   // fold (aext (zext x)) -> (zext x)
13931   // fold (aext (sext x)) -> (sext x)
13932   if (N0.getOpcode() == ISD::ANY_EXTEND  ||
13933       N0.getOpcode() == ISD::ZERO_EXTEND ||
13934       N0.getOpcode() == ISD::SIGN_EXTEND)
13935     return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, N0.getOperand(0));
13936 
13937   // fold (aext (aext_extend_vector_inreg x)) -> (aext_extend_vector_inreg x)
13938   // fold (aext (zext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x)
13939   // fold (aext (sext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x)
13940   if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
13941       N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG ||
13942       N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG)
13943     return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, N0.getOperand(0));
13944 
13945   // fold (aext (truncate (load x))) -> (aext (smaller load x))
13946   // fold (aext (truncate (srl (load x), c))) -> (aext (small load (x+c/n)))
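  // E.g. (i64 aext (i16 truncate (i32 load x))) can become
  // (i64 aext (i16 load x)) on a little-endian target.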
13947   if (N0.getOpcode() == ISD::TRUNCATE) {
13948     if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
13949       SDNode *oye = N0.getOperand(0).getNode();
13950       if (NarrowLoad.getNode() != N0.getNode()) {
13951         CombineTo(N0.getNode(), NarrowLoad);
13952         // CombineTo deleted the truncate, if needed, but not what's under it.
13953         AddToWorklist(oye);
13954       }
13955       return SDValue(N, 0);   // Return N so it doesn't get rechecked!
13956     }
13957   }
13958 
13959   // fold (aext (truncate x))
13960   if (N0.getOpcode() == ISD::TRUNCATE)
13961     return DAG.getAnyExtOrTrunc(N0.getOperand(0), SDLoc(N), VT);
13962 
13963   // Fold (aext (and (trunc x), cst)) -> (and x, cst)
13964   // if the trunc is not free.
13965   if (N0.getOpcode() == ISD::AND &&
13966       N0.getOperand(0).getOpcode() == ISD::TRUNCATE &&
13967       N0.getOperand(1).getOpcode() == ISD::Constant &&
13968       !TLI.isTruncateFree(N0.getOperand(0).getOperand(0), N0.getValueType())) {
13969     SDLoc DL(N);
13970     SDValue X = DAG.getAnyExtOrTrunc(N0.getOperand(0).getOperand(0), DL, VT);
13971     SDValue Y = DAG.getNode(ISD::ANY_EXTEND, DL, VT, N0.getOperand(1));
13972     assert(isa<ConstantSDNode>(Y) && "Expected constant to be folded!");
13973     return DAG.getNode(ISD::AND, DL, VT, X, Y);
13974   }
13975 
13976   // fold (aext (load x)) -> (aext (truncate (extload x)))
13977   // None of the supported targets knows how to perform load and any_ext
13978   // on vectors in one instruction, so attempt to fold to zext instead.
13979   if (VT.isVector()) {
13980     // Try to simplify (zext (load x)).
13981     if (SDValue foldedExt =
13982             tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
13983                                ISD::ZEXTLOAD, ISD::ZERO_EXTEND))
13984       return foldedExt;
13985   } else if (ISD::isNON_EXTLoad(N0.getNode()) &&
13986              ISD::isUNINDEXEDLoad(N0.getNode()) &&
13987              TLI.isLoadExtLegal(ISD::EXTLOAD, VT, N0.getValueType())) {
13988     bool DoXform = true;
13989     SmallVector<SDNode *, 4> SetCCs;
13990     if (!N0.hasOneUse())
13991       DoXform =
13992           ExtendUsesToFormExtLoad(VT, N, N0, ISD::ANY_EXTEND, SetCCs, TLI);
13993     if (DoXform) {
13994       LoadSDNode *LN0 = cast<LoadSDNode>(N0);
13995       SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, SDLoc(N), VT,
13996                                        LN0->getChain(), LN0->getBasePtr(),
13997                                        N0.getValueType(), LN0->getMemOperand());
13998       ExtendSetCCUses(SetCCs, N0, ExtLoad, ISD::ANY_EXTEND);
13999       // If the load value is used only by N, replace it via CombineTo N.
14000       bool NoReplaceTrunc = N0.hasOneUse();
14001       CombineTo(N, ExtLoad);
14002       if (NoReplaceTrunc) {
14003         DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
14004         recursivelyDeleteUnusedNodes(LN0);
14005       } else {
14006         SDValue Trunc =
14007             DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), ExtLoad);
14008         CombineTo(LN0, Trunc, ExtLoad.getValue(1));
14009       }
14010       return SDValue(N, 0); // Return N so it doesn't get rechecked!
14011     }
14012   }
14013 
14014   // fold (aext (zextload x)) -> (aext (truncate (zextload x)))
14015   // fold (aext (sextload x)) -> (aext (truncate (sextload x)))
14016   // fold (aext ( extload x)) -> (aext (truncate (extload  x)))
14017   if (N0.getOpcode() == ISD::LOAD && !ISD::isNON_EXTLoad(N0.getNode()) &&
14018       ISD::isUNINDEXEDLoad(N0.getNode()) && N0.hasOneUse()) {
14019     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
14020     ISD::LoadExtType ExtType = LN0->getExtensionType();
14021     EVT MemVT = LN0->getMemoryVT();
14022     if (!LegalOperations || TLI.isLoadExtLegal(ExtType, VT, MemVT)) {
14023       SDValue ExtLoad = DAG.getExtLoad(ExtType, SDLoc(N),
14024                                        VT, LN0->getChain(), LN0->getBasePtr(),
14025                                        MemVT, LN0->getMemOperand());
14026       CombineTo(N, ExtLoad);
14027       DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
14028       recursivelyDeleteUnusedNodes(LN0);
14029       return SDValue(N, 0);   // Return N so it doesn't get rechecked!
14030     }
14031   }
14032 
14033   if (N0.getOpcode() == ISD::SETCC) {
14034     // Propagate fast-math-flags.
14035     SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
14036 
14037     // For vectors:
14038     // aext(setcc) -> vsetcc
14039     // aext(setcc) -> truncate(vsetcc)
14040     // aext(setcc) -> aext(vsetcc)
14041     // Only do this before legalize for now.
14042     if (VT.isVector() && !LegalOperations) {
14043       EVT N00VT = N0.getOperand(0).getValueType();
14044       if (getSetCCResultType(N00VT) == N0.getValueType())
14045         return SDValue();
14046 
14047       // We know that the # elements of the result is the same as the
14048       // # elements of the compare (and the # elements of the compare result
14049       // for that matter).  Check to see that they are the same size.  If so,
14050       // we know that the element size of the extended result matches the
14051       // element size of the compare operands.
14052       if (VT.getSizeInBits() == N00VT.getSizeInBits())
14053         return DAG.getSetCC(SDLoc(N), VT, N0.getOperand(0),
14054                              N0.getOperand(1),
14055                              cast<CondCodeSDNode>(N0.getOperand(2))->get());
14056 
14057       // If the desired elements are smaller or larger than the source
14058       // elements we can use a matching integer vector type and then
14059       // truncate/any extend
14060       EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
14061       SDValue VsetCC =
14062         DAG.getSetCC(SDLoc(N), MatchingVectorType, N0.getOperand(0),
14063                       N0.getOperand(1),
14064                       cast<CondCodeSDNode>(N0.getOperand(2))->get());
14065       return DAG.getAnyExtOrTrunc(VsetCC, SDLoc(N), VT);
14066     }
14067 
14068     // aext(setcc x,y,cc) -> select_cc x, y, 1, 0, cc
14069     SDLoc DL(N);
14070     if (SDValue SCC = SimplifySelectCC(
14071             DL, N0.getOperand(0), N0.getOperand(1), DAG.getConstant(1, DL, VT),
14072             DAG.getConstant(0, DL, VT),
14073             cast<CondCodeSDNode>(N0.getOperand(2))->get(), true))
14074       return SCC;
14075   }
14076 
14077   if (SDValue NewCtPop = widenCtPop(N, DAG))
14078     return NewCtPop;
14079 
14080   if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, Level))
14081     return Res;
14082 
14083   return SDValue();
14084 }
14085 
14086 SDValue DAGCombiner::visitAssertExt(SDNode *N) {
14087   unsigned Opcode = N->getOpcode();
14088   SDValue N0 = N->getOperand(0);
14089   SDValue N1 = N->getOperand(1);
14090   EVT AssertVT = cast<VTSDNode>(N1)->getVT();
14091 
14092   // fold (assert?ext (assert?ext x, vt), vt) -> (assert?ext x, vt)
14093   if (N0.getOpcode() == Opcode &&
14094       AssertVT == cast<VTSDNode>(N0.getOperand(1))->getVT())
14095     return N0;
14096 
14097   if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
14098       N0.getOperand(0).getOpcode() == Opcode) {
14099     // We have an assert, truncate, assert sandwich. Make one stronger assert
14100     // by asserting the smallest asserted type on the larger source type.
14101     // This eliminates the later assert:
14102     // assert (trunc (assert X, i8) to iN), i1 --> trunc (assert X, i1) to iN
14103     // assert (trunc (assert X, i1) to iN), i8 --> trunc (assert X, i1) to iN
14104     SDLoc DL(N);
14105     SDValue BigA = N0.getOperand(0);
14106     EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT();
14107     EVT MinAssertVT = AssertVT.bitsLT(BigA_AssertVT) ? AssertVT : BigA_AssertVT;
14108     SDValue MinAssertVTVal = DAG.getValueType(MinAssertVT);
14109     SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(),
14110                                     BigA.getOperand(0), MinAssertVTVal);
14111     return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert);
14112   }
14113 
14114   // If we have (AssertZext (truncate (AssertSext X, iX)), iY) and Y is smaller
14115   // than X, just move the AssertZext in front of the truncate and drop the
14116   // AssertSext.
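  // E.g. (AssertZext (trunc (AssertSext X, i16) to i32), i8)
  //        -> (trunc (AssertZext X, i8) to i32)
  // Bits 8-31 of the truncated value were asserted zero, which forces the i16
  // sign bit of X to be zero as well, so asserting an i8 zext on X is valid.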
14117   if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
14118       N0.getOperand(0).getOpcode() == ISD::AssertSext &&
14119       Opcode == ISD::AssertZext) {
14120     SDValue BigA = N0.getOperand(0);
14121     EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT();
14122     if (AssertVT.bitsLT(BigA_AssertVT)) {
14123       SDLoc DL(N);
14124       SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(),
14125                                       BigA.getOperand(0), N1);
14126       return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert);
14127     }
14128   }
14129 
14130   return SDValue();
14131 }
14132 
14133 SDValue DAGCombiner::visitAssertAlign(SDNode *N) {
14134   SDLoc DL(N);
14135 
14136   Align AL = cast<AssertAlignSDNode>(N)->getAlign();
14137   SDValue N0 = N->getOperand(0);
14138 
14139   // Fold (assertalign (assertalign x, AL0), AL1) ->
14140   // (assertalign x, max(AL0, AL1))
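  // E.g. (assertalign (assertalign x, align 4), align 8)
  //        -> (assertalign x, align 8)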
14141   if (auto *AAN = dyn_cast<AssertAlignSDNode>(N0))
14142     return DAG.getAssertAlign(DL, N0.getOperand(0),
14143                               std::max(AL, AAN->getAlign()));
14144 
14145   // In rare cases, there are trivial arithmetic ops in the source operands.
14146   // Sink this assert down to the source operands so that those arithmetic ops
14147   // can be exposed to DAG combining.
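  // E.g. in (assertalign (add x, 16), align 8), the constant 16 is already
  // 8-byte aligned, so the alignment info can be attached to x instead:
  //   (add (assertalign x, align 8), 16), exposing the add to combining.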
14148   switch (N0.getOpcode()) {
14149   default:
14150     break;
14151   case ISD::ADD:
14152   case ISD::SUB: {
14153     unsigned AlignShift = Log2(AL);
14154     SDValue LHS = N0.getOperand(0);
14155     SDValue RHS = N0.getOperand(1);
14156     unsigned LHSAlignShift = DAG.computeKnownBits(LHS).countMinTrailingZeros();
14157     unsigned RHSAlignShift = DAG.computeKnownBits(RHS).countMinTrailingZeros();
14158     if (LHSAlignShift >= AlignShift || RHSAlignShift >= AlignShift) {
14159       if (LHSAlignShift < AlignShift)
14160         LHS = DAG.getAssertAlign(DL, LHS, AL);
14161       if (RHSAlignShift < AlignShift)
14162         RHS = DAG.getAssertAlign(DL, RHS, AL);
14163       return DAG.getNode(N0.getOpcode(), DL, N0.getValueType(), LHS, RHS);
14164     }
14165     break;
14166   }
14167   }
14168 
14169   return SDValue();
14170 }
14171 
14172 /// If the result of a load is shifted/masked/truncated to an effectively
14173 /// narrower type, try to transform the load to a narrower type and/or
14174 /// use an extending load.
14175 SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
14176   unsigned Opc = N->getOpcode();
14177 
14178   ISD::LoadExtType ExtType = ISD::NON_EXTLOAD;
14179   SDValue N0 = N->getOperand(0);
14180   EVT VT = N->getValueType(0);
14181   EVT ExtVT = VT;
14182 
14183   // This transformation isn't valid for vector loads.
14184   if (VT.isVector())
14185     return SDValue();
14186 
14187   // The ShAmt variable is used to indicate that we've consumed a right
14188   // shift. I.e. we want to narrow the width of the load by skipping the ShAmt
14189   // least significant bits.
14190   unsigned ShAmt = 0;
14191   // A special case is when the least significant bits from the load are masked
14192   // away, but using an AND rather than a right shift. HasShiftedOffset is used
14193   // to indicate that the narrowed load should be left-shifted ShAmt bits to get
14194   // the result.
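  // E.g. for (and (i32 load x), 0xffffff00), the narrowed load reads 24 bits
  // from x+1 (little endian), and the result is shifted left by ShAmt (8)
  // afterwards, assuming the target supports the narrow extending load.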
14195   bool HasShiftedOffset = false;
14196   // Special case: SIGN_EXTEND_INREG is basically truncating to ExtVT then
14197   // extending to VT.
14198   if (Opc == ISD::SIGN_EXTEND_INREG) {
14199     ExtType = ISD::SEXTLOAD;
14200     ExtVT = cast<VTSDNode>(N->getOperand(1))->getVT();
14201   } else if (Opc == ISD::SRL || Opc == ISD::SRA) {
14202     // Another special-case: SRL/SRA is basically zero/sign-extending a narrower
14203     // value, or it may be shifting a higher subword, half or byte into the
14204     // lowest bits.
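    // E.g. (srl (i32 load x), 16) selects the upper halfword, so it can
    // become a zextload of i16 from x+2 on a little-endian target.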
14205 
14206     // Only handle shift with constant shift amount, and the shiftee must be a
14207     // load.
14208     auto *LN = dyn_cast<LoadSDNode>(N0);
14209     auto *N1C = dyn_cast<ConstantSDNode>(N->getOperand(1));
14210     if (!N1C || !LN)
14211       return SDValue();
14212     // If the shift amount is larger than the memory type then we're not
14213     // accessing any of the loaded bytes.
14214     ShAmt = N1C->getZExtValue();
14215     uint64_t MemoryWidth = LN->getMemoryVT().getScalarSizeInBits();
14216     if (MemoryWidth <= ShAmt)
14217       return SDValue();
14218     // Attempt to fold away the SRL by using ZEXTLOAD and SRA by using SEXTLOAD.
14219     ExtType = Opc == ISD::SRL ? ISD::ZEXTLOAD : ISD::SEXTLOAD;
14220     ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShAmt);
14221     // If original load is a SEXTLOAD then we can't simply replace it by a
14222     // ZEXTLOAD (we could potentially replace it by a more narrow SEXTLOAD
14223     // followed by a ZEXT, but that is not handled at the moment). Similarly if
14224     // the original load is a ZEXTLOAD and we want to use a SEXTLOAD.
14225     if ((LN->getExtensionType() == ISD::SEXTLOAD ||
14226          LN->getExtensionType() == ISD::ZEXTLOAD) &&
14227         LN->getExtensionType() != ExtType)
14228       return SDValue();
14229   } else if (Opc == ISD::AND) {
14230     // An AND with a constant mask is the same as a truncate + zero-extend.
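    // E.g. (and (i32 load x), 0xffff) behaves like zext(trunc to i16), so it
    // can be folded into an i16 zextload of the low halfword.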
14231     auto AndC = dyn_cast<ConstantSDNode>(N->getOperand(1));
14232     if (!AndC)
14233       return SDValue();
14234 
14235     const APInt &Mask = AndC->getAPIntValue();
14236     unsigned ActiveBits = 0;
14237     if (Mask.isMask()) {
14238       ActiveBits = Mask.countr_one();
14239     } else if (Mask.isShiftedMask(ShAmt, ActiveBits)) {
14240       HasShiftedOffset = true;
14241     } else {
14242       return SDValue();
14243     }
14244 
14245     ExtType = ISD::ZEXTLOAD;
14246     ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
14247   }
14248 
14249   // In case Opc==SRL we've already prepared ExtVT/ExtType/ShAmt based on doing
14250   // a right shift. Here we redo some of those checks, to possibly adjust the
14251   // ExtVT even further based on "a masking AND". We could also end up here for
14252   // other reasons (e.g. based on Opc==TRUNCATE) and that is why some checks
14253   // need to be done here as well.
14254   if (Opc == ISD::SRL || N0.getOpcode() == ISD::SRL) {
14255     SDValue SRL = Opc == ISD::SRL ? SDValue(N, 0) : N0;
14256     // Bail out when the SRL has more than one use. This is done for historical
14257     // (undocumented) reasons. Maybe the intent was to guard the AND-masking
14258     // check below? And maybe the transform is unprofitable when the SRL has
14259     // multiple uses and we get here with Opc != ISD::SRL?
14260     // FIXME: Can't we just skip this check for the Opc == ISD::SRL case?
14261     if (!SRL.hasOneUse())
14262       return SDValue();
14263 
14264     // Only handle shift with constant shift amount, and the shiftee must be a
14265     // load.
14266     auto *LN = dyn_cast<LoadSDNode>(SRL.getOperand(0));
14267     auto *SRL1C = dyn_cast<ConstantSDNode>(SRL.getOperand(1));
14268     if (!SRL1C || !LN)
14269       return SDValue();
14270 
14271     // If the shift amount is larger than the input type then we're not
14272     // accessing any of the loaded bytes.  If the load was a zextload/extload
14273     // then the result of the shift+trunc is zero/undef (handled elsewhere).
14274     ShAmt = SRL1C->getZExtValue();
14275     uint64_t MemoryWidth = LN->getMemoryVT().getSizeInBits();
14276     if (ShAmt >= MemoryWidth)
14277       return SDValue();
14278 
14279     // Because a SRL must be assumed to *need* to zero-extend the high bits
14280     // (as opposed to anyext the high bits), we can't combine the zextload
14281     // lowering of SRL and an sextload.
14282     if (LN->getExtensionType() == ISD::SEXTLOAD)
14283       return SDValue();
14284 
14285     // Avoid reading outside the memory accessed by the original load (could
14286     // happen if we only adjusted the load base pointer by ShAmt). Instead we
14287     // try to narrow the load even further. The typical scenario here is:
14288     //   (i64 (truncate (i96 (srl (load x), 64)))) ->
14289     //     (i64 (truncate (i96 (zextload (load i32 + offset) from i32))))
14290     if (ExtVT.getScalarSizeInBits() > MemoryWidth - ShAmt) {
14291       // Don't replace sextload by zextload.
14292       if (ExtType == ISD::SEXTLOAD)
14293         return SDValue();
14294       // Narrow the load.
14295       ExtType = ISD::ZEXTLOAD;
14296       ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShAmt);
14297     }
14298 
14299     // If the SRL is only used by a masking AND, we may be able to adjust
14300     // the ExtVT to make the AND redundant.
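    // E.g. for (and (srl (i64 load x), 32), 0xff), the initial i32 ExtVT can
    // shrink to i8 (when that zextload is legal); the zextload then zeroes
    // the high bits itself and the AND becomes redundant.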
14301     SDNode *Mask = *(SRL->use_begin());
14302     if (SRL.hasOneUse() && Mask->getOpcode() == ISD::AND &&
14303         isa<ConstantSDNode>(Mask->getOperand(1))) {
14304       const APInt& ShiftMask = Mask->getConstantOperandAPInt(1);
14305       if (ShiftMask.isMask()) {
14306         EVT MaskedVT =
14307             EVT::getIntegerVT(*DAG.getContext(), ShiftMask.countr_one());
14308         // If the mask is smaller, recompute the type.
14309         if ((ExtVT.getScalarSizeInBits() > MaskedVT.getScalarSizeInBits()) &&
14310             TLI.isLoadExtLegal(ExtType, SRL.getValueType(), MaskedVT))
14311           ExtVT = MaskedVT;
14312       }
14313     }
14314 
14315     N0 = SRL.getOperand(0);
14316   }
14317 
14318   // If the load is shifted left (and the result isn't shifted back right), we
14319   // can fold a truncate through the shift. The typical scenario is that N
14320   // points at a TRUNCATE here so the attempted fold is:
14321   //   (truncate (shl (load x), c)) -> (shl (narrow load x), c)
14322   // ShLeftAmt will indicate how much a narrowed load should be shifted left.
14323   unsigned ShLeftAmt = 0;
14324   if (ShAmt == 0 && N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
14325       ExtVT == VT && TLI.isNarrowingProfitable(N0.getValueType(), VT)) {
14326     if (ConstantSDNode *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
14327       ShLeftAmt = N01->getZExtValue();
14328       N0 = N0.getOperand(0);
14329     }
14330   }
14331 
14332   // If we haven't found a load, we can't narrow it.
14333   if (!isa<LoadSDNode>(N0))
14334     return SDValue();
14335 
14336   LoadSDNode *LN0 = cast<LoadSDNode>(N0);
14337   // Reducing the width of a volatile load is illegal.  For atomics, we may be
14338   // able to reduce the width provided we never widen again. (see D66309)
14339   if (!LN0->isSimple() ||
14340       !isLegalNarrowLdSt(LN0, ExtType, ExtVT, ShAmt))
14341     return SDValue();
14342 
14343   auto AdjustBigEndianShift = [&](unsigned ShAmt) {
14344     unsigned LVTStoreBits =
14345         LN0->getMemoryVT().getStoreSizeInBits().getFixedValue();
14346     unsigned EVTStoreBits = ExtVT.getStoreSizeInBits().getFixedValue();
14347     return LVTStoreBits - EVTStoreBits - ShAmt;
14348   };
14349 
14350   // We need to adjust the pointer to the load by ShAmt bits in order to load
14351   // the correct bytes.
14352   unsigned PtrAdjustmentInBits =
14353       DAG.getDataLayout().isBigEndian() ? AdjustBigEndianShift(ShAmt) : ShAmt;
14354 
14355   uint64_t PtrOff = PtrAdjustmentInBits / 8;
14356   Align NewAlign = commonAlignment(LN0->getAlign(), PtrOff);
14357   SDLoc DL(LN0);
14358   // The original load itself didn't wrap, so an offset within it doesn't.
14359   SDNodeFlags Flags;
14360   Flags.setNoUnsignedWrap(true);
14361   SDValue NewPtr = DAG.getMemBasePlusOffset(
14362       LN0->getBasePtr(), TypeSize::getFixed(PtrOff), DL, Flags);
14363   AddToWorklist(NewPtr.getNode());
14364 
14365   SDValue Load;
14366   if (ExtType == ISD::NON_EXTLOAD)
14367     Load = DAG.getLoad(VT, DL, LN0->getChain(), NewPtr,
14368                        LN0->getPointerInfo().getWithOffset(PtrOff), NewAlign,
14369                        LN0->getMemOperand()->getFlags(), LN0->getAAInfo());
14370   else
14371     Load = DAG.getExtLoad(ExtType, DL, VT, LN0->getChain(), NewPtr,
14372                           LN0->getPointerInfo().getWithOffset(PtrOff), ExtVT,
14373                           NewAlign, LN0->getMemOperand()->getFlags(),
14374                           LN0->getAAInfo());
14375 
14376   // Replace the old load's chain with the new load's chain.
14377   WorklistRemover DeadNodes(*this);
14378   DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));
14379 
14380   // Shift the result left, if we've swallowed a left shift.
14381   SDValue Result = Load;
14382   if (ShLeftAmt != 0) {
14383     EVT ShImmTy = getShiftAmountTy(Result.getValueType());
14384     if (!isUIntN(ShImmTy.getScalarSizeInBits(), ShLeftAmt))
14385       ShImmTy = VT;
14386     // If the shift amount is as large as the result size (but, presumably,
14387     // no larger than the source) then the useful bits of the result are
14388     // zero; we can't simply return the shortened shift, because the result
14389     // of that operation is undefined.
14390     if (ShLeftAmt >= VT.getScalarSizeInBits())
14391       Result = DAG.getConstant(0, DL, VT);
14392     else
14393       Result = DAG.getNode(ISD::SHL, DL, VT,
14394                           Result, DAG.getConstant(ShLeftAmt, DL, ShImmTy));
14395   }
14396 
14397   if (HasShiftedOffset) {
14398     // We're using a shifted mask, so the load now has an offset. This means
14399     // the data has been loaded into lower bytes than it would have been
14400     // before, so we need to shl the loaded data into the correct position
14401     // in the register.
14402     SDValue ShiftC = DAG.getConstant(ShAmt, DL, VT);
14403     Result = DAG.getNode(ISD::SHL, DL, VT, Result, ShiftC);
14404     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result);
14405   }
14406 
14407   // Return the new loaded value.
14408   return Result;
14409 }
14410 
14411 SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
14412   SDValue N0 = N->getOperand(0);
14413   SDValue N1 = N->getOperand(1);
14414   EVT VT = N->getValueType(0);
14415   EVT ExtVT = cast<VTSDNode>(N1)->getVT();
14416   unsigned VTBits = VT.getScalarSizeInBits();
14417   unsigned ExtVTBits = ExtVT.getScalarSizeInBits();
14418 
14419   // sext_in_reg(undef) = 0 because the top bits will all be the same.
14420   if (N0.isUndef())
14421     return DAG.getConstant(0, SDLoc(N), VT);
14422 
14423   // fold (sext_in_reg c1) -> c1
14424   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
14425     return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, N0, N1);
14426 
14427   // If the input is already sign extended, just drop the extension.
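  // E.g. (sext_in_reg (sra x, 24), i8) with i32 x: the sra leaves at most 8
  // significant bits, so the in-register sign extension is a no-op.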
14428   if (ExtVTBits >= DAG.ComputeMaxSignificantBits(N0))
14429     return N0;
14430 
14431   // fold (sext_in_reg (sext_in_reg x, VT2), VT1) -> (sext_in_reg x, minVT) pt2
14432   if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
14433       ExtVT.bitsLT(cast<VTSDNode>(N0.getOperand(1))->getVT()))
14434     return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, N0.getOperand(0),
14435                        N1);
14436 
14437   // fold (sext_in_reg (sext x)) -> (sext x)
14438   // fold (sext_in_reg (aext x)) -> (sext x)
14439   // if x is small enough or if we know that x has more than 1 sign bit and the
14440   // sign_extend_inreg is extending from one of them.
14441   if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) {
14442     SDValue N00 = N0.getOperand(0);
14443     unsigned N00Bits = N00.getScalarValueSizeInBits();
14444     if ((N00Bits <= ExtVTBits ||
14445          DAG.ComputeMaxSignificantBits(N00) <= ExtVTBits) &&
14446         (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT)))
14447       return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00);
14448   }
14449 
14450   // fold (sext_in_reg (*_extend_vector_inreg x)) -> (sext_vector_inreg x)
14451   // if x is small enough or if we know that x has more than 1 sign bit and the
14452   // sign_extend_inreg is extending from one of them.
14453   if (ISD::isExtVecInRegOpcode(N0.getOpcode())) {
14454     SDValue N00 = N0.getOperand(0);
14455     unsigned N00Bits = N00.getScalarValueSizeInBits();
14456     unsigned DstElts = N0.getValueType().getVectorMinNumElements();
14457     unsigned SrcElts = N00.getValueType().getVectorMinNumElements();
14458     bool IsZext = N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG;
14459     APInt DemandedSrcElts = APInt::getLowBitsSet(SrcElts, DstElts);
14460     if ((N00Bits == ExtVTBits ||
14461          (!IsZext && (N00Bits < ExtVTBits ||
14462                       DAG.ComputeMaxSignificantBits(N00) <= ExtVTBits))) &&
14463         (!LegalOperations ||
14464          TLI.isOperationLegal(ISD::SIGN_EXTEND_VECTOR_INREG, VT)))
14465       return DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, SDLoc(N), VT, N00);
14466   }
14467 
14468   // fold (sext_in_reg (zext x)) -> (sext x)
14469   // iff we are extending the source sign bit.
14470   if (N0.getOpcode() == ISD::ZERO_EXTEND) {
14471     SDValue N00 = N0.getOperand(0);
14472     if (N00.getScalarValueSizeInBits() == ExtVTBits &&
14473         (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT)))
14474       return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00);
14475   }
14476 
14477   // fold (sext_in_reg x) -> (zext_in_reg x) if the sign bit is known zero.
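  // E.g. (sext_in_reg (and x, 0x7f), i8): bit 7 is known zero, so sign
  // extending from i8 is equivalent to zero extending from i8.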
14478   if (DAG.MaskedValueIsZero(N0, APInt::getOneBitSet(VTBits, ExtVTBits - 1)))
14479     return DAG.getZeroExtendInReg(N0, SDLoc(N), ExtVT);
14480 
14481   // fold operands of sext_in_reg based on knowledge that the top bits are not
14482   // demanded.
14483   if (SimplifyDemandedBits(SDValue(N, 0)))
14484     return SDValue(N, 0);
14485 
14486   // fold (sext_in_reg (load x)) -> (smaller sextload x)
14487   // fold (sext_in_reg (srl (load x), c)) -> (smaller sextload (x+c/evtbits))
14488   if (SDValue NarrowLoad = reduceLoadWidth(N))
14489     return NarrowLoad;
14490 
14491   // fold (sext_in_reg (srl X, 24), i8) -> (sra X, 24)
14492   // fold (sext_in_reg (srl X, 23), i8) -> (sra X, 23) iff possible.
14493   // We already fold "(sext_in_reg (srl X, 25), i8) -> srl X, 25" above.
14494   if (N0.getOpcode() == ISD::SRL) {
14495     if (auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1)))
14496       if (ShAmt->getAPIntValue().ule(VTBits - ExtVTBits)) {
14497         // We can turn this into an SRA iff the input to the SRL is already sign
14498         // extended enough.
14499         unsigned InSignBits = DAG.ComputeNumSignBits(N0.getOperand(0));
14500         if (((VTBits - ExtVTBits) - ShAmt->getZExtValue()) < InSignBits)
14501           return DAG.getNode(ISD::SRA, SDLoc(N), VT, N0.getOperand(0),
14502                              N0.getOperand(1));
14503       }
14504   }
14505 
14506   // fold (sext_inreg (extload x)) -> (sextload x)
14507   // If sextload is not supported by target, we can only do the combine when
14508   // load has one use. Doing otherwise can block folding the extload with other
14509   // extends that the target does support.
14510   if (ISD::isEXTLoad(N0.getNode()) &&
14511       ISD::isUNINDEXEDLoad(N0.getNode()) &&
14512       ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
14513       ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple() &&
14514         N0.hasOneUse()) ||
14515        TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT))) {
14516     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
14517     SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(N), VT,
14518                                      LN0->getChain(),
14519                                      LN0->getBasePtr(), ExtVT,
14520                                      LN0->getMemOperand());
14521     CombineTo(N, ExtLoad);
14522     CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
14523     AddToWorklist(ExtLoad.getNode());
14524     return SDValue(N, 0);   // Return N so it doesn't get rechecked!
14525   }
14526 
14527   // fold (sext_inreg (zextload x)) -> (sextload x) iff load has one use
14528   if (ISD::isZEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) &&
14529       N0.hasOneUse() &&
14530       ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
14531       ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple()) &&
14532        TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT))) {
14533     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
14534     SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(N), VT,
14535                                      LN0->getChain(),
14536                                      LN0->getBasePtr(), ExtVT,
14537                                      LN0->getMemOperand());
14538     CombineTo(N, ExtLoad);
14539     CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
14540     return SDValue(N, 0);   // Return N so it doesn't get rechecked!
14541   }
14542 
14543   // fold (sext_inreg (masked_load x)) -> (sext_masked_load x)
14544   // ignore it if the masked load is already sign extended
14545   if (MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(N0)) {
14546     if (ExtVT == Ld->getMemoryVT() && N0.hasOneUse() &&
14547         Ld->getExtensionType() != ISD::LoadExtType::NON_EXTLOAD &&
14548         TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT)) {
14549       SDValue ExtMaskedLoad = DAG.getMaskedLoad(
14550           VT, SDLoc(N), Ld->getChain(), Ld->getBasePtr(), Ld->getOffset(),
14551           Ld->getMask(), Ld->getPassThru(), ExtVT, Ld->getMemOperand(),
14552           Ld->getAddressingMode(), ISD::SEXTLOAD, Ld->isExpandingLoad());
14553       CombineTo(N, ExtMaskedLoad);
14554       CombineTo(N0.getNode(), ExtMaskedLoad, ExtMaskedLoad.getValue(1));
14555       return SDValue(N, 0); // Return N so it doesn't get rechecked!
14556     }
14557   }
14558 
14559   // fold (sext_inreg (masked_gather x)) -> (sext_masked_gather x)
14560   if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(N0)) {
14561     if (SDValue(GN0, 0).hasOneUse() &&
14562         ExtVT == GN0->getMemoryVT() &&
14563         TLI.isVectorLoadExtDesirable(SDValue(SDValue(GN0, 0)))) {
14564       SDValue Ops[] = {GN0->getChain(),   GN0->getPassThru(), GN0->getMask(),
14565                        GN0->getBasePtr(), GN0->getIndex(),    GN0->getScale()};
14566 
14567       SDValue ExtLoad = DAG.getMaskedGather(
14568           DAG.getVTList(VT, MVT::Other), ExtVT, SDLoc(N), Ops,
14569           GN0->getMemOperand(), GN0->getIndexType(), ISD::SEXTLOAD);
14570 
14571       CombineTo(N, ExtLoad);
14572       CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
14573       AddToWorklist(ExtLoad.getNode());
14574       return SDValue(N, 0); // Return N so it doesn't get rechecked!
14575     }
14576   }
14577 
14578   // Form (sext_inreg (bswap >> 16)) or (sext_inreg (rotl (bswap) 16))
14579   if (ExtVTBits <= 16 && N0.getOpcode() == ISD::OR) {
14580     if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
14581                                            N0.getOperand(1), false))
14582       return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, BSwap, N1);
14583   }
14584 
14585   // Fold (iM_signext_inreg
14586   //        (extract_subvector (zext|anyext|sext iN_v to _) _)
14587   //        from iN)
14588   //      -> (extract_subvector (signext iN_v to iM))
14589   if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR && N0.hasOneUse() &&
14590       ISD::isExtOpcode(N0.getOperand(0).getOpcode())) {
14591     SDValue InnerExt = N0.getOperand(0);
14592     EVT InnerExtVT = InnerExt->getValueType(0);
14593     SDValue Extendee = InnerExt->getOperand(0);
14594 
14595     if (ExtVTBits == Extendee.getValueType().getScalarSizeInBits() &&
14596         (!LegalOperations ||
14597          TLI.isOperationLegal(ISD::SIGN_EXTEND, InnerExtVT))) {
14598       SDValue SignExtExtendee =
14599           DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), InnerExtVT, Extendee);
14600       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), VT, SignExtExtendee,
14601                          N0.getOperand(1));
14602     }
14603   }
14604 
14605   return SDValue();
14606 }
14607 
14608 static SDValue
14609 foldExtendVectorInregToExtendOfSubvector(SDNode *N, const TargetLowering &TLI,
14610                                          SelectionDAG &DAG,
14611                                          bool LegalOperations) {
14612   unsigned InregOpcode = N->getOpcode();
14613   unsigned Opcode = DAG.getOpcode_EXTEND(InregOpcode);
14614 
14615   SDValue Src = N->getOperand(0);
14616   EVT VT = N->getValueType(0);
14617   EVT SrcVT = EVT::getVectorVT(*DAG.getContext(),
14618                                Src.getValueType().getVectorElementType(),
14619                                VT.getVectorElementCount());
14620 
14621   assert(ISD::isExtVecInRegOpcode(InregOpcode) &&
14622          "Expected EXTEND_VECTOR_INREG dag node in input!");
14623 
14624   // Profitability check: our operand must be a one-use CONCAT_VECTORS.
14625   // FIXME: one-use check may be overly restrictive
14626   if (!Src.hasOneUse() || Src.getOpcode() != ISD::CONCAT_VECTORS)
14627     return SDValue();
14628 
14629   // Profitability check: we must be extending exactly one of its operands.
14630   // FIXME: this is probably overly restrictive.
14631   Src = Src.getOperand(0);
14632   if (Src.getValueType() != SrcVT)
14633     return SDValue();
14634 
14635   if (LegalOperations && !TLI.isOperationLegal(Opcode, VT))
14636     return SDValue();
14637 
14638   return DAG.getNode(Opcode, SDLoc(N), VT, Src);
14639 }
14640 
14641 SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
14642   SDValue N0 = N->getOperand(0);
14643   EVT VT = N->getValueType(0);
14644 
14645   if (N0.isUndef()) {
14646     // aext_vector_inreg(undef) = undef because the top bits are undefined.
14647     // {s/z}ext_vector_inreg(undef) = 0 because the top bits must be the same.
14648     return N->getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG
14649                ? DAG.getUNDEF(VT)
14650                : DAG.getConstant(0, SDLoc(N), VT);
14651   }
14652 
14653   if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
14654     return Res;
14655 
14656   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
14657     return SDValue(N, 0);
14658 
14659   if (SDValue R = foldExtendVectorInregToExtendOfSubvector(N, TLI, DAG,
14660                                                            LegalOperations))
14661     return R;
14662 
14663   return SDValue();
14664 }
14665 
14666 SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
14667   SDValue N0 = N->getOperand(0);
14668   EVT VT = N->getValueType(0);
14669   EVT SrcVT = N0.getValueType();
14670   bool isLE = DAG.getDataLayout().isLittleEndian();
14671 
14672   // trunc(undef) = undef
14673   if (N0.isUndef())
14674     return DAG.getUNDEF(VT);
14675 
14676   // fold (truncate (truncate x)) -> (truncate x)
14677   if (N0.getOpcode() == ISD::TRUNCATE)
14678     return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0.getOperand(0));
14679 
14680   // fold (truncate c1) -> c1
14681   if (SDValue C = DAG.FoldConstantArithmetic(ISD::TRUNCATE, SDLoc(N), VT, {N0}))
14682     return C;
14683 
14684   // fold (truncate (ext x)) -> (ext x) or (truncate x) or x
14685   if (N0.getOpcode() == ISD::ZERO_EXTEND ||
14686       N0.getOpcode() == ISD::SIGN_EXTEND ||
14687       N0.getOpcode() == ISD::ANY_EXTEND) {
14688     // if the source is smaller than the dest, we still need an extend.
14689     if (N0.getOperand(0).getValueType().bitsLT(VT))
14690       return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, N0.getOperand(0));
14691     // if the source is larger than the dest, then we just need the truncate.
14692     if (N0.getOperand(0).getValueType().bitsGT(VT))
14693       return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0.getOperand(0));
14694     // if the source and dest are the same type, we can drop both the extend
14695     // and the truncate.
14696     return N0.getOperand(0);
14697   }
14698 
14699   // Try to narrow a truncate-of-sext_in_reg to the destination type:
14700   // trunc (sign_ext_inreg X, iM) to iN --> sign_ext_inreg (trunc X to iN), iM
14701   if (!LegalTypes && N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
14702       N0.hasOneUse()) {
14703     SDValue X = N0.getOperand(0);
14704     SDValue ExtVal = N0.getOperand(1);
14705     EVT ExtVT = cast<VTSDNode>(ExtVal)->getVT();
14706     if (ExtVT.bitsLT(VT) && TLI.preferSextInRegOfTruncate(VT, SrcVT, ExtVT)) {
14707       SDValue TrX = DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, X);
14708       return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, TrX, ExtVal);
14709     }
14710   }
14711 
14712   // If this is anyext(trunc), don't fold it; allow ourselves to be folded.
14713   if (N->hasOneUse() && (N->use_begin()->getOpcode() == ISD::ANY_EXTEND))
14714     return SDValue();
14715 
14716   // Fold extract-and-trunc into a narrow extract. For example:
14717   //   i64 x = EXTRACT_VECTOR_ELT(v2i64 val, i32 1)
14718   //   i32 y = TRUNCATE(i64 x)
14719   //        -- becomes --
14720   //   v16i8 b = BITCAST (v2i64 val)
14721   //   i8 x = EXTRACT_VECTOR_ELT(v16i8 b, i32 8)
14722   //
14723   // Note: We only run this optimization after type legalization (which often
14724   // creates this pattern) and before operation legalization after which
14725   // we need to be more careful about the vector instructions that we generate.
14726   if (N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
14727       LegalTypes && !LegalOperations && N0->hasOneUse() && VT != MVT::i1) {
14728     EVT VecTy = N0.getOperand(0).getValueType();
14729     EVT ExTy = N0.getValueType();
14730     EVT TrTy = N->getValueType(0);
14731 
14732     auto EltCnt = VecTy.getVectorElementCount();
14733     unsigned SizeRatio = ExTy.getSizeInBits()/TrTy.getSizeInBits();
14734     auto NewEltCnt = EltCnt * SizeRatio;
14735 
14736     EVT NVT = EVT::getVectorVT(*DAG.getContext(), TrTy, NewEltCnt);
14737     assert(NVT.getSizeInBits() == VecTy.getSizeInBits() && "Invalid Size");
14738 
14739     SDValue EltNo = N0->getOperand(1);
14740     if (isa<ConstantSDNode>(EltNo) && isTypeLegal(NVT)) {
14741       int Elt = EltNo->getAsZExtVal();
14742       int Index = isLE ? (Elt*SizeRatio) : (Elt*SizeRatio + (SizeRatio-1));
14743 
14744       SDLoc DL(N);
14745       return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, TrTy,
14746                          DAG.getBitcast(NVT, N0.getOperand(0)),
14747                          DAG.getVectorIdxConstant(Index, DL));
14748     }
14749   }
14750 
14751   // trunc (select c, a, b) -> select c, (trunc a), (trunc b)
14752   if (N0.getOpcode() == ISD::SELECT && N0.hasOneUse()) {
14753     if ((!LegalOperations || TLI.isOperationLegal(ISD::SELECT, SrcVT)) &&
14754         TLI.isTruncateFree(SrcVT, VT)) {
14755       SDLoc SL(N0);
14756       SDValue Cond = N0.getOperand(0);
14757       SDValue TruncOp0 = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(1));
14758       SDValue TruncOp1 = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(2));
14759       return DAG.getNode(ISD::SELECT, SDLoc(N), VT, Cond, TruncOp0, TruncOp1);
14760     }
14761   }
14762 
14763   // trunc (shl x, K) -> shl (trunc x), K, provided K < VT.getScalarSizeInBits()
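  // E.g. (i32 trunc (shl i64:x, K)) -> (shl (i32 trunc x), K): the low 32
  // bits of the wide shift depend only on the low 32 bits of x when K < 32.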
14764   if (N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
14765       (!LegalOperations || TLI.isOperationLegal(ISD::SHL, VT)) &&
14766       TLI.isTypeDesirableForOp(ISD::SHL, VT)) {
14767     SDValue Amt = N0.getOperand(1);
14768     KnownBits Known = DAG.computeKnownBits(Amt);
14769     unsigned Size = VT.getScalarSizeInBits();
14770     if (Known.countMaxActiveBits() <= Log2_32(Size)) {
14771       SDLoc SL(N);
14772       EVT AmtVT = TLI.getShiftAmountTy(VT, DAG.getDataLayout());
14773 
14774       SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(0));
14775       if (AmtVT != Amt.getValueType()) {
14776         Amt = DAG.getZExtOrTrunc(Amt, SL, AmtVT);
14777         AddToWorklist(Amt.getNode());
14778       }
14779       return DAG.getNode(ISD::SHL, SL, VT, Trunc, Amt);
14780     }
14781   }
14782 
14783   if (SDValue V = foldSubToUSubSat(VT, N0.getNode()))
14784     return V;
14785 
14786   if (SDValue ABD = foldABSToABD(N))
14787     return ABD;
14788 
14789   // Attempt to pre-truncate BUILD_VECTOR sources.
14790   if (N0.getOpcode() == ISD::BUILD_VECTOR && !LegalOperations &&
14791       N0.hasOneUse() &&
14792       TLI.isTruncateFree(SrcVT.getScalarType(), VT.getScalarType()) &&
14793       // Avoid creating illegal types if running after type legalizer.
14794       (!LegalTypes || TLI.isTypeLegal(VT.getScalarType()))) {
14795     SDLoc DL(N);
14796     EVT SVT = VT.getScalarType();
14797     SmallVector<SDValue, 8> TruncOps;
14798     for (const SDValue &Op : N0->op_values()) {
14799       SDValue TruncOp = DAG.getNode(ISD::TRUNCATE, DL, SVT, Op);
14800       TruncOps.push_back(TruncOp);
14801     }
14802     return DAG.getBuildVector(VT, DL, TruncOps);
14803   }
14804 
14805   // trunc (splat_vector x) -> splat_vector (trunc x)
14806   if (N0.getOpcode() == ISD::SPLAT_VECTOR &&
14807       (!LegalTypes || TLI.isTypeLegal(VT.getScalarType())) &&
14808       (!LegalOperations || TLI.isOperationLegal(ISD::SPLAT_VECTOR, VT))) {
14809     SDLoc DL(N);
14810     EVT SVT = VT.getScalarType();
14811     return DAG.getSplatVector(
14812         VT, DL, DAG.getNode(ISD::TRUNCATE, DL, SVT, N0->getOperand(0)));
14813   }
14814 
14815   // Fold a series of buildvector, bitcast, and truncate if possible.
14816   // For example fold
14817   //   (2xi32 trunc (bitcast ((4xi32)buildvector x, x, y, y) 2xi64)) to
14818   //   (2xi32 (buildvector x, y)).
14819   if (Level == AfterLegalizeVectorOps && VT.isVector() &&
14820       N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
14821       N0.getOperand(0).getOpcode() == ISD::BUILD_VECTOR &&
14822       N0.getOperand(0).hasOneUse()) {
14823     SDValue BuildVect = N0.getOperand(0);
14824     EVT BuildVectEltTy = BuildVect.getValueType().getVectorElementType();
14825     EVT TruncVecEltTy = VT.getVectorElementType();
14826 
14827     // Check that the element types match.
14828     if (BuildVectEltTy == TruncVecEltTy) {
14829       // Now we only need to compute the offset of the truncated elements.
14830       unsigned BuildVecNumElts =  BuildVect.getNumOperands();
14831       unsigned TruncVecNumElts = VT.getVectorNumElements();
14832       unsigned TruncEltOffset = BuildVecNumElts / TruncVecNumElts;
14833 
14834       assert((BuildVecNumElts % TruncVecNumElts) == 0 &&
14835              "Invalid number of elements");
14836 
14837       SmallVector<SDValue, 8> Opnds;
14838       for (unsigned i = 0, e = BuildVecNumElts; i != e; i += TruncEltOffset)
14839         Opnds.push_back(BuildVect.getOperand(i));
14840 
14841       return DAG.getBuildVector(VT, SDLoc(N), Opnds);
14842     }
14843   }
14844 
14845   // fold (truncate (load x)) -> (smaller load x)
14846   // fold (truncate (srl (load x), c)) -> (smaller load (x+c/evtbits))
14847   if (!LegalTypes || TLI.isTypeDesirableForOp(N0.getOpcode(), VT)) {
14848     if (SDValue Reduced = reduceLoadWidth(N))
14849       return Reduced;
14850 
14851     // Handle the case where the truncated result is at least as wide as the
14852     // loaded type.
14853     if (N0.hasOneUse() && ISD::isUNINDEXEDLoad(N0.getNode())) {
14854       auto *LN0 = cast<LoadSDNode>(N0);
14855       if (LN0->isSimple() && LN0->getMemoryVT().bitsLE(VT)) {
14856         SDValue NewLoad = DAG.getExtLoad(
14857             LN0->getExtensionType(), SDLoc(LN0), VT, LN0->getChain(),
14858             LN0->getBasePtr(), LN0->getMemoryVT(), LN0->getMemOperand());
14859         DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), NewLoad.getValue(1));
14860         return NewLoad;
14861       }
14862     }
14863   }
14864 
14865   // fold (trunc (concat ... x ...)) -> (concat ..., (trunc x), ...)),
14866   // where ... are all 'undef'.
14867   if (N0.getOpcode() == ISD::CONCAT_VECTORS && !LegalTypes) {
14868     SmallVector<EVT, 8> VTs;
14869     SDValue V;
14870     unsigned Idx = 0;
14871     unsigned NumDefs = 0;
14872 
14873     for (unsigned i = 0, e = N0.getNumOperands(); i != e; ++i) {
14874       SDValue X = N0.getOperand(i);
14875       if (!X.isUndef()) {
14876         V = X;
14877         Idx = i;
14878         NumDefs++;
14879       }
14880       // Stop if more than one member is non-undef.
14881       if (NumDefs > 1)
14882         break;
14883 
14884       VTs.push_back(EVT::getVectorVT(*DAG.getContext(),
14885                                      VT.getVectorElementType(),
14886                                      X.getValueType().getVectorElementCount()));
14887     }
14888 
14889     if (NumDefs == 0)
14890       return DAG.getUNDEF(VT);
14891 
14892     if (NumDefs == 1) {
14893       assert(V.getNode() && "The single defined operand is empty!");
14894       SmallVector<SDValue, 8> Opnds;
14895       for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
14896         if (i != Idx) {
14897           Opnds.push_back(DAG.getUNDEF(VTs[i]));
14898           continue;
14899         }
14900         SDValue NV = DAG.getNode(ISD::TRUNCATE, SDLoc(V), VTs[i], V);
14901         AddToWorklist(NV.getNode());
14902         Opnds.push_back(NV);
14903       }
14904       return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Opnds);
14905     }
14906   }
14907 
14908   // Fold truncate of a bitcast of a vector to an extract of the low vector
14909   // element.
14910   //
14911   // e.g. trunc (i64 (bitcast v2i32:x)) -> extract_vector_elt v2i32:x, idx
14912   if (N0.getOpcode() == ISD::BITCAST && !VT.isVector()) {
14913     SDValue VecSrc = N0.getOperand(0);
14914     EVT VecSrcVT = VecSrc.getValueType();
14915     if (VecSrcVT.isVector() && VecSrcVT.getScalarType() == VT &&
14916         (!LegalOperations ||
14917          TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VecSrcVT))) {
14918       SDLoc SL(N);
14919 
14920       unsigned Idx = isLE ? 0 : VecSrcVT.getVectorNumElements() - 1;
14921       return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, VT, VecSrc,
14922                          DAG.getVectorIdxConstant(Idx, SL));
14923     }
14924   }
14925 
14926   // Simplify the operands using demanded-bits information.
14927   if (SimplifyDemandedBits(SDValue(N, 0)))
14928     return SDValue(N, 0);
14929 
14930   // fold (truncate (extract_subvector(ext x))) ->
14931   //      (extract_subvector x)
14932   // TODO: This can be generalized to cover cases where the truncate and extract
14933   // do not fully cancel each other out.
14934   if (!LegalTypes && N0.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
14935     SDValue N00 = N0.getOperand(0);
14936     if (N00.getOpcode() == ISD::SIGN_EXTEND ||
14937         N00.getOpcode() == ISD::ZERO_EXTEND ||
14938         N00.getOpcode() == ISD::ANY_EXTEND) {
14939       if (N00.getOperand(0)->getValueType(0).getVectorElementType() ==
14940           VT.getVectorElementType())
14941         return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N0->getOperand(0)), VT,
14942                            N00.getOperand(0), N0.getOperand(1));
14943     }
14944   }
14945 
14946   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
14947     return NewVSel;
14948 
14949   // Narrow a suitable binary operation with a non-opaque constant operand by
14950   // moving it ahead of the truncate. This is limited to pre-legalization
14951   // because targets may prefer a wider type during later combines and invert
14952   // this transform.
14953   switch (N0.getOpcode()) {
14954   case ISD::ADD:
14955   case ISD::SUB:
14956   case ISD::MUL:
14957   case ISD::AND:
14958   case ISD::OR:
14959   case ISD::XOR:
14960     if (!LegalOperations && N0.hasOneUse() &&
14961         (isConstantOrConstantVector(N0.getOperand(0), true) ||
14962          isConstantOrConstantVector(N0.getOperand(1), true))) {
14963       // TODO: We already restricted this to pre-legalization, but for vectors
14964       // we are extra cautious to not create an unsupported operation.
14965       // Target-specific changes are likely needed to avoid regressions here.
14966       if (VT.isScalarInteger() || TLI.isOperationLegal(N0.getOpcode(), VT)) {
14967         SDLoc DL(N);
14968         SDValue NarrowL = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
14969         SDValue NarrowR = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(1));
14970         return DAG.getNode(N0.getOpcode(), DL, VT, NarrowL, NarrowR);
14971       }
14972     }
14973     break;
14974   case ISD::ADDE:
14975   case ISD::UADDO_CARRY:
14976     // (trunc adde(X, Y, Carry)) -> (adde trunc(X), trunc(Y), Carry)
14977     // (trunc uaddo_carry(X, Y, Carry)) ->
14978     //     (uaddo_carry trunc(X), trunc(Y), Carry)
14979     // When the adde's carry is not used.
14980     // We only do this for uaddo_carry before operation legalization.
14981     if (((!LegalOperations && N0.getOpcode() == ISD::UADDO_CARRY) ||
14982          TLI.isOperationLegal(N0.getOpcode(), VT)) &&
14983         N0.hasOneUse() && !N0->hasAnyUseOfValue(1)) {
14984       SDLoc DL(N);
14985       SDValue X = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
14986       SDValue Y = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(1));
14987       SDVTList VTs = DAG.getVTList(VT, N0->getValueType(1));
14988       return DAG.getNode(N0.getOpcode(), DL, VTs, X, Y, N0.getOperand(2));
14989     }
14990     break;
14991   case ISD::USUBSAT:
14992     // Truncate the USUBSAT only if LHS is a known zero-extension; it's not
14993     // enough to know that the upper bits are zero, we must also ensure that
14994     // we don't introduce an extra truncate.
14995     if (!LegalOperations && N0.hasOneUse() &&
14996         N0.getOperand(0).getOpcode() == ISD::ZERO_EXTEND &&
14997         N0.getOperand(0).getOperand(0).getScalarValueSizeInBits() <=
14998             VT.getScalarSizeInBits() &&
14999         hasOperation(N0.getOpcode(), VT)) {
15000       return getTruncatedUSUBSAT(VT, SrcVT, N0.getOperand(0), N0.getOperand(1),
15001                                  DAG, SDLoc(N));
15002     }
15003     break;
15004   }
15005 
15006   return SDValue();
15007 }
15008 
15009 static SDNode *getBuildPairElt(SDNode *N, unsigned i) {
15010   SDValue Elt = N->getOperand(i);
15011   if (Elt.getOpcode() != ISD::MERGE_VALUES)
15012     return Elt.getNode();
15013   return Elt.getOperand(Elt.getResNo()).getNode();
15014 }
15015 
15016 /// build_pair (load, load) -> load
15017 /// if load locations are consecutive.
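/// E.g. (i64 build_pair (i32 load x), (i32 load x+4)) -> (i64 load x) on a
/// little-endian target, provided the wide load is legal and fast.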
15018 SDValue DAGCombiner::CombineConsecutiveLoads(SDNode *N, EVT VT) {
15019   assert(N->getOpcode() == ISD::BUILD_PAIR);
15020 
15021   auto *LD1 = dyn_cast<LoadSDNode>(getBuildPairElt(N, 0));
15022   auto *LD2 = dyn_cast<LoadSDNode>(getBuildPairElt(N, 1));
15023 
15024   // A BUILD_PAIR always has the least significant part in elt 0 and the
15025   // most significant part in elt 1. So when combining into one large load, we
15026   // need to consider the endianness.
15027   if (DAG.getDataLayout().isBigEndian())
15028     std::swap(LD1, LD2);
15029 
15030   if (!LD1 || !LD2 || !ISD::isNON_EXTLoad(LD1) || !ISD::isNON_EXTLoad(LD2) ||
15031       !LD1->hasOneUse() || !LD2->hasOneUse() ||
15032       LD1->getAddressSpace() != LD2->getAddressSpace())
15033     return SDValue();
15034 
15035   unsigned LD1Fast = 0;
15036   EVT LD1VT = LD1->getValueType(0);
15037   unsigned LD1Bytes = LD1VT.getStoreSize();
15038   if ((!LegalOperations || TLI.isOperationLegal(ISD::LOAD, VT)) &&
15039       DAG.areNonVolatileConsecutiveLoads(LD2, LD1, LD1Bytes, 1) &&
15040       TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
15041                              *LD1->getMemOperand(), &LD1Fast) && LD1Fast)
15042     return DAG.getLoad(VT, SDLoc(N), LD1->getChain(), LD1->getBasePtr(),
15043                        LD1->getPointerInfo(), LD1->getAlign());
15044 
15045   return SDValue();
15046 }
15047 
15048 static unsigned getPPCf128HiElementSelector(const SelectionDAG &DAG) {
15049   // On little-endian machines, bitcasting from ppcf128 to i128 does swap the Hi
15050   // and Lo parts; on big-endian machines it doesn't.
15051   return DAG.getDataLayout().isBigEndian() ? 1 : 0;
15052 }
15053 
15054 SDValue DAGCombiner::foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
15055                                           const TargetLowering &TLI) {
15056   // If this is not a bitcast to an FP type or if the target doesn't have
15057   // IEEE754-compliant FP logic, we're done.
15058   EVT VT = N->getValueType(0);
15059   SDValue N0 = N->getOperand(0);
15060   EVT SourceVT = N0.getValueType();
15061 
15062   if (!VT.isFloatingPoint())
15063     return SDValue();
15064 
15065   // TODO: Handle cases where the integer constant is a different scalar
15066   // bitwidth to the FP.
15067   if (VT.getScalarSizeInBits() != SourceVT.getScalarSizeInBits())
15068     return SDValue();
15069 
15070   unsigned FPOpcode;
15071   APInt SignMask;
15072   switch (N0.getOpcode()) {
15073   case ISD::AND:
15074     FPOpcode = ISD::FABS;
15075     SignMask = ~APInt::getSignMask(SourceVT.getScalarSizeInBits());
15076     break;
15077   case ISD::XOR:
15078     FPOpcode = ISD::FNEG;
15079     SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits());
15080     break;
15081   case ISD::OR:
15082     FPOpcode = ISD::FABS;
15083     SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits());
15084     break;
15085   default:
15086     return SDValue();
15087   }
15088 
15089   if (LegalOperations && !TLI.isOperationLegal(FPOpcode, VT))
15090     return SDValue();
15091 
15092   // This needs to be the inverse of logic in foldSignChangeInBitcast.
15093   // FIXME: I don't think looking for bitcast intrinsically makes sense, but
15094   // removing this would require more changes.
15095   auto IsBitCastOrFree = [&TLI, FPOpcode](SDValue Op, EVT VT) {
15096     if (Op.getOpcode() == ISD::BITCAST && Op.getOperand(0).getValueType() == VT)
15097       return true;
15098 
15099     return FPOpcode == ISD::FABS ? TLI.isFAbsFree(VT) : TLI.isFNegFree(VT);
15100   };
15101 
15102   // Fold (bitcast int (and (bitcast fp X to int), 0x7fff...) to fp) -> fabs X
15103   // Fold (bitcast int (xor (bitcast fp X to int), 0x8000...) to fp) -> fneg X
15104   // Fold (bitcast int (or (bitcast fp X to int), 0x8000...) to fp) ->
15105   //   fneg (fabs X)
15106   SDValue LogicOp0 = N0.getOperand(0);
15107   ConstantSDNode *LogicOp1 = isConstOrConstSplat(N0.getOperand(1), true);
15108   if (LogicOp1 && LogicOp1->getAPIntValue() == SignMask &&
15109       IsBitCastOrFree(LogicOp0, VT)) {
15110     SDValue CastOp0 = DAG.getNode(ISD::BITCAST, SDLoc(N), VT, LogicOp0);
15111     SDValue FPOp = DAG.getNode(FPOpcode, SDLoc(N), VT, CastOp0);
15112     NumFPLogicOpsConv++;
15113     if (N0.getOpcode() == ISD::OR)
15114       return DAG.getNode(ISD::FNEG, SDLoc(N), VT, FPOp);
15115     return FPOp;
15116   }
15117 
15118   return SDValue();
15119 }
15120 
15121 SDValue DAGCombiner::visitBITCAST(SDNode *N) {
15122   SDValue N0 = N->getOperand(0);
15123   EVT VT = N->getValueType(0);
15124 
15125   if (N0.isUndef())
15126     return DAG.getUNDEF(VT);
15127 
15128   // If the input is a BUILD_VECTOR with all constant elements, fold this now.
15129   // Only do this before legalize types, unless both types are integer and the
15130   // scalar type is legal. Only do this before legalize ops, since the target
15131   // may depend on the bitcast.
15132   // First check to see if this is all constant.
15133   // TODO: Support FP bitcasts after legalize types.
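  // e.g. (assuming standard IEEE754 f32 bit patterns):
  //   (v2i32 (bitcast (build_vector (f32 1.0), (f32 2.0))))
  //     -> (build_vector (i32 0x3f800000), (i32 0x40000000))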
15134   if (VT.isVector() &&
15135       (!LegalTypes ||
15136        (!LegalOperations && VT.isInteger() && N0.getValueType().isInteger() &&
15137         TLI.isTypeLegal(VT.getVectorElementType()))) &&
15138       N0.getOpcode() == ISD::BUILD_VECTOR && N0->hasOneUse() &&
15139       cast<BuildVectorSDNode>(N0)->isConstant())
15140     return ConstantFoldBITCASTofBUILD_VECTOR(N0.getNode(),
15141                                              VT.getVectorElementType());
15142 
15143   // If the input is a constant, let getNode fold it.
15144   if (isIntOrFPConstant(N0)) {
15145     // If we can't allow illegal operations, we need to check that this is just
15146     // an fp -> int or int -> fp conversion and that the resulting operation
15147     // will be legal.
15148     if (!LegalOperations ||
15149         (isa<ConstantSDNode>(N0) && VT.isFloatingPoint() && !VT.isVector() &&
15150          TLI.isOperationLegal(ISD::ConstantFP, VT)) ||
15151         (isa<ConstantFPSDNode>(N0) && VT.isInteger() && !VT.isVector() &&
15152          TLI.isOperationLegal(ISD::Constant, VT))) {
15153       SDValue C = DAG.getBitcast(VT, N0);
15154       if (C.getNode() != N)
15155         return C;
15156     }
15157   }
15158 
15159   // (conv (conv x, t1), t2) -> (conv x, t2)
15160   if (N0.getOpcode() == ISD::BITCAST)
15161     return DAG.getBitcast(VT, N0.getOperand(0));
15162 
15163   // fold (conv (logicop (conv x), (c))) -> (logicop x, (conv c))
15164   // iff the current bitwise logicop type isn't legal
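  // e.g. on a hypothetical target where v2i64 is illegal but v4i32 is legal:
  //   (v4i32 (bitcast (and (v2i64 (bitcast X:v4i32)), C:v2i64)))
  //     -> (and X, (v4i32 (bitcast C)))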
15165   if (ISD::isBitwiseLogicOp(N0.getOpcode()) && VT.isInteger() &&
15166       !TLI.isTypeLegal(N0.getOperand(0).getValueType())) {
15167     auto IsFreeBitcast = [VT](SDValue V) {
15168       return (V.getOpcode() == ISD::BITCAST &&
15169               V.getOperand(0).getValueType() == VT) ||
15170              (ISD::isBuildVectorOfConstantSDNodes(V.getNode()) &&
15171               V->hasOneUse());
15172     };
15173     if (IsFreeBitcast(N0.getOperand(0)) && IsFreeBitcast(N0.getOperand(1)))
15174       return DAG.getNode(N0.getOpcode(), SDLoc(N), VT,
15175                          DAG.getBitcast(VT, N0.getOperand(0)),
15176                          DAG.getBitcast(VT, N0.getOperand(1)));
15177   }
15178 
15179   // fold (conv (load x)) -> (load (conv*)x)
15180   // If the resultant load doesn't need a higher alignment than the original!
15181   if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
15182       // Do not remove the cast if the types differ in endian layout.
15183       TLI.hasBigEndianPartOrdering(N0.getValueType(), DAG.getDataLayout()) ==
15184           TLI.hasBigEndianPartOrdering(VT, DAG.getDataLayout()) &&
15185       // If the load is volatile, we only want to change the load type if the
15186       // resulting load is legal. Otherwise we might increase the number of
15187       // memory accesses. We don't care if the original type was legal or not
15188       // as we assume software couldn't rely on the number of accesses of an
15189       // illegal type.
15190       ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple()) ||
15191        TLI.isOperationLegal(ISD::LOAD, VT))) {
15192     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
15193 
15194     if (TLI.isLoadBitCastBeneficial(N0.getValueType(), VT, DAG,
15195                                     *LN0->getMemOperand())) {
15196       SDValue Load =
15197           DAG.getLoad(VT, SDLoc(N), LN0->getChain(), LN0->getBasePtr(),
15198                       LN0->getMemOperand());
15199       DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));
15200       return Load;
15201     }
15202   }
15203 
15204   if (SDValue V = foldBitcastedFPLogic(N, DAG, TLI))
15205     return V;
15206 
15207   // fold (bitconvert (fneg x)) -> (xor (bitconvert x), signbit)
15208   // fold (bitconvert (fabs x)) -> (and (bitconvert x), (not signbit))
15209   //
15210   // For ppc_fp128:
15211   // fold (bitcast (fneg x)) ->
15212   //     flipbit = signbit
15213   //     (xor (bitcast x) (build_pair flipbit, flipbit))
15214   //
15215   // fold (bitcast (fabs x)) ->
15216   //     flipbit = (and (extract_element (bitcast x), 0), signbit)
15217   //     (xor (bitcast x) (build_pair flipbit, flipbit))
15218   // This often reduces constant pool loads.
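  // e.g. for f32 (sign mask 0x80000000):
  //   (i32 (bitcast (fneg X))) -> (xor (i32 (bitcast X)), 0x80000000)
  // The sign mask becomes an integer immediate instead of an FP constant
  // that would likely be loaded from the constant pool.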
15219   if (((N0.getOpcode() == ISD::FNEG && !TLI.isFNegFree(N0.getValueType())) ||
15220        (N0.getOpcode() == ISD::FABS && !TLI.isFAbsFree(N0.getValueType()))) &&
15221       N0->hasOneUse() && VT.isInteger() && !VT.isVector() &&
15222       !N0.getValueType().isVector()) {
15223     SDValue NewConv = DAG.getBitcast(VT, N0.getOperand(0));
15224     AddToWorklist(NewConv.getNode());
15225 
15226     SDLoc DL(N);
15227     if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
15228       assert(VT.getSizeInBits() == 128);
15229       SDValue SignBit = DAG.getConstant(
15230           APInt::getSignMask(VT.getSizeInBits() / 2), SDLoc(N0), MVT::i64);
15231       SDValue FlipBit;
15232       if (N0.getOpcode() == ISD::FNEG) {
15233         FlipBit = SignBit;
15234         AddToWorklist(FlipBit.getNode());
15235       } else {
15236         assert(N0.getOpcode() == ISD::FABS);
15237         SDValue Hi =
15238             DAG.getNode(ISD::EXTRACT_ELEMENT, SDLoc(NewConv), MVT::i64, NewConv,
15239                         DAG.getIntPtrConstant(getPPCf128HiElementSelector(DAG),
15240                                               SDLoc(NewConv)));
15241         AddToWorklist(Hi.getNode());
15242         FlipBit = DAG.getNode(ISD::AND, SDLoc(N0), MVT::i64, Hi, SignBit);
15243         AddToWorklist(FlipBit.getNode());
15244       }
15245       SDValue FlipBits =
15246           DAG.getNode(ISD::BUILD_PAIR, SDLoc(N0), VT, FlipBit, FlipBit);
15247       AddToWorklist(FlipBits.getNode());
15248       return DAG.getNode(ISD::XOR, DL, VT, NewConv, FlipBits);
15249     }
15250     APInt SignBit = APInt::getSignMask(VT.getSizeInBits());
15251     if (N0.getOpcode() == ISD::FNEG)
15252       return DAG.getNode(ISD::XOR, DL, VT,
15253                          NewConv, DAG.getConstant(SignBit, DL, VT));
15254     assert(N0.getOpcode() == ISD::FABS);
15255     return DAG.getNode(ISD::AND, DL, VT,
15256                        NewConv, DAG.getConstant(~SignBit, DL, VT));
15257   }
15258 
15259   // fold (bitconvert (fcopysign cst, x)) ->
15260   //         (or (and (bitconvert x), sign), (and cst, (not sign)))
15261   // Note that we don't handle (copysign x, cst) because this can always be
15262   // folded to an fneg or fabs.
15263   //
15264   // For ppc_fp128:
15265   // fold (bitcast (fcopysign cst, x)) ->
15266   //     flipbit = (and (extract_element
15267   //                     (xor (bitcast cst), (bitcast x)), 0),
15268   //                    signbit)
15269   //     (xor (bitcast cst) (build_pair flipbit, flipbit))
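  // e.g. for f32 with cst == 2.0 (bits 0x40000000, sign mask 0x80000000):
  //   (i32 (bitcast (fcopysign 2.0, X)))
  //     -> (or (and (i32 (bitcast X)), 0x80000000), 0x40000000)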
15270   if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse() &&
15271       isa<ConstantFPSDNode>(N0.getOperand(0)) && VT.isInteger() &&
15272       !VT.isVector()) {
15273     unsigned OrigXWidth = N0.getOperand(1).getValueSizeInBits();
15274     EVT IntXVT = EVT::getIntegerVT(*DAG.getContext(), OrigXWidth);
15275     if (isTypeLegal(IntXVT)) {
15276       SDValue X = DAG.getBitcast(IntXVT, N0.getOperand(1));
15277       AddToWorklist(X.getNode());
15278 
15279       // If X has a different width than the result/lhs, sext it or truncate it.
15280       unsigned VTWidth = VT.getSizeInBits();
15281       if (OrigXWidth < VTWidth) {
15282         X = DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, X);
15283         AddToWorklist(X.getNode());
15284       } else if (OrigXWidth > VTWidth) {
15285         // To get the sign bit in the right place, we have to shift it right
15286         // before truncating.
15287         SDLoc DL(X);
15288         X = DAG.getNode(ISD::SRL, DL,
15289                         X.getValueType(), X,
15290                         DAG.getConstant(OrigXWidth-VTWidth, DL,
15291                                         X.getValueType()));
15292         AddToWorklist(X.getNode());
15293         X = DAG.getNode(ISD::TRUNCATE, SDLoc(X), VT, X);
15294         AddToWorklist(X.getNode());
15295       }
15296 
15297       if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
15298         APInt SignBit = APInt::getSignMask(VT.getSizeInBits() / 2);
15299         SDValue Cst = DAG.getBitcast(VT, N0.getOperand(0));
15300         AddToWorklist(Cst.getNode());
15301         SDValue X = DAG.getBitcast(VT, N0.getOperand(1));
15302         AddToWorklist(X.getNode());
15303         SDValue XorResult = DAG.getNode(ISD::XOR, SDLoc(N0), VT, Cst, X);
15304         AddToWorklist(XorResult.getNode());
15305         SDValue XorResult64 = DAG.getNode(
15306             ISD::EXTRACT_ELEMENT, SDLoc(XorResult), MVT::i64, XorResult,
15307             DAG.getIntPtrConstant(getPPCf128HiElementSelector(DAG),
15308                                   SDLoc(XorResult)));
15309         AddToWorklist(XorResult64.getNode());
15310         SDValue FlipBit =
15311             DAG.getNode(ISD::AND, SDLoc(XorResult64), MVT::i64, XorResult64,
15312                         DAG.getConstant(SignBit, SDLoc(XorResult64), MVT::i64));
15313         AddToWorklist(FlipBit.getNode());
15314         SDValue FlipBits =
15315             DAG.getNode(ISD::BUILD_PAIR, SDLoc(N0), VT, FlipBit, FlipBit);
15316         AddToWorklist(FlipBits.getNode());
15317         return DAG.getNode(ISD::XOR, SDLoc(N), VT, Cst, FlipBits);
15318       }
15319       APInt SignBit = APInt::getSignMask(VT.getSizeInBits());
15320       X = DAG.getNode(ISD::AND, SDLoc(X), VT,
15321                       X, DAG.getConstant(SignBit, SDLoc(X), VT));
15322       AddToWorklist(X.getNode());
15323 
15324       SDValue Cst = DAG.getBitcast(VT, N0.getOperand(0));
15325       Cst = DAG.getNode(ISD::AND, SDLoc(Cst), VT,
15326                         Cst, DAG.getConstant(~SignBit, SDLoc(Cst), VT));
15327       AddToWorklist(Cst.getNode());
15328 
15329       return DAG.getNode(ISD::OR, SDLoc(N), VT, X, Cst);
15330     }
15331   }
15332 
15333   // bitconvert(build_pair(ld, ld)) -> ld iff load locations are consecutive.
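  // e.g. (f64 (bitcast (build_pair (i32 load [P]), (i32 load [P+4])))) can
  // become a single f64 load from P when the wide access is legal and fast.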
15334   if (N0.getOpcode() == ISD::BUILD_PAIR)
15335     if (SDValue CombineLD = CombineConsecutiveLoads(N0.getNode(), VT))
15336       return CombineLD;
15337 
15338   // Remove double bitcasts from shuffles - this is often a legacy of
15339   // XformToShuffleWithZero being used to combine bitmaskings (of
15340   // float vectors bitcast to integer vectors) into shuffles.
15341   // bitcast(shuffle(bitcast(s0),bitcast(s1))) -> shuffle(s0,s1)
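  // e.g. with VT == v4i32 and a v2f64 shuffle source (MaskScale == 2):
  //   (v4i32 (bitcast (shuffle (v2f64 (bitcast A:v4i32)),
  //                            (v2f64 (bitcast B:v4i32)), <0,3>)))
  //     -> (shuffle A, B, <0,1,6,7>)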
15342   if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT) && VT.isVector() &&
15343       N0->getOpcode() == ISD::VECTOR_SHUFFLE && N0.hasOneUse() &&
15344       VT.getVectorNumElements() >= N0.getValueType().getVectorNumElements() &&
15345       !(VT.getVectorNumElements() % N0.getValueType().getVectorNumElements())) {
15346     ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N0);
15347 
15348     // If operands are a bitcast, peek through if it casts the original VT.
15349     // If operands are a constant, just bitcast back to original VT.
15350     auto PeekThroughBitcast = [&](SDValue Op) {
15351       if (Op.getOpcode() == ISD::BITCAST &&
15352           Op.getOperand(0).getValueType() == VT)
15353         return SDValue(Op.getOperand(0));
15354       if (Op.isUndef() || isAnyConstantBuildVector(Op))
15355         return DAG.getBitcast(VT, Op);
15356       return SDValue();
15357     };
15358 
15359     // FIXME: If either input vector is bitcast, try to convert the shuffle to
15360     // the result type of this bitcast. This would eliminate at least one
15361     // bitcast. See the transform in InstCombine.
15362     SDValue SV0 = PeekThroughBitcast(N0->getOperand(0));
15363     SDValue SV1 = PeekThroughBitcast(N0->getOperand(1));
15364     if (!(SV0 && SV1))
15365       return SDValue();
15366 
15367     int MaskScale =
15368         VT.getVectorNumElements() / N0.getValueType().getVectorNumElements();
15369     SmallVector<int, 8> NewMask;
15370     for (int M : SVN->getMask())
15371       for (int i = 0; i != MaskScale; ++i)
15372         NewMask.push_back(M < 0 ? -1 : M * MaskScale + i);
15373 
15374     SDValue LegalShuffle =
15375         TLI.buildLegalVectorShuffle(VT, SDLoc(N), SV0, SV1, NewMask, DAG);
15376     if (LegalShuffle)
15377       return LegalShuffle;
15378   }
15379 
15380   return SDValue();
15381 }
15382 
15383 SDValue DAGCombiner::visitBUILD_PAIR(SDNode *N) {
15384   EVT VT = N->getValueType(0);
15385   return CombineConsecutiveLoads(N, VT);
15386 }
15387 
15388 SDValue DAGCombiner::visitFREEZE(SDNode *N) {
15389   SDValue N0 = N->getOperand(0);
15390 
15391   if (DAG.isGuaranteedNotToBeUndefOrPoison(N0, /*PoisonOnly*/ false))
15392     return N0;
15393 
15394   // Fold freeze(op(x, ...)) -> op(freeze(x), ...).
15395   // Try to push freeze through instructions that propagate but don't produce
15396   // poison as far as possible. If an operand of the freeze 1) has one use,
15397   // 2) does not produce poison, and 3) has all but one operand guaranteed
15398   // non-poison (or is a BUILD_VECTOR or similar), then push the freeze
15399   // through to the operands that are not guaranteed non-poison.
15400   // NOTE: we will strip poison-generating flags, so ignore them here.
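  // e.g. (freeze (add X, 1)) -> (add (freeze X), 1): the constant operand is
  // guaranteed non-poison, so only X needs a freeze (and any nsw/nuw flags on
  // the add are dropped).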
15401   if (DAG.canCreateUndefOrPoison(N0, /*PoisonOnly*/ false,
15402                                  /*ConsiderFlags*/ false) ||
15403       N0->getNumValues() != 1 || !N0->hasOneUse())
15404     return SDValue();
15405 
15406   bool AllowMultipleMaybePoisonOperands = N0.getOpcode() == ISD::BUILD_VECTOR ||
15407                                           N0.getOpcode() == ISD::BUILD_PAIR ||
15408                                           N0.getOpcode() == ISD::CONCAT_VECTORS;
15409 
15410   SmallSetVector<SDValue, 8> MaybePoisonOperands;
15411   for (SDValue Op : N0->ops()) {
15412     if (DAG.isGuaranteedNotToBeUndefOrPoison(Op, /*PoisonOnly*/ false,
15413                                              /*Depth*/ 1))
15414       continue;
15415     bool HadMaybePoisonOperands = !MaybePoisonOperands.empty();
15416     bool IsNewMaybePoisonOperand = MaybePoisonOperands.insert(Op);
15417     if (!HadMaybePoisonOperands)
15418       continue;
15419     if (IsNewMaybePoisonOperand && !AllowMultipleMaybePoisonOperands) {
15420       // Multiple maybe-poison ops when not allowed - bail out.
15421       return SDValue();
15422     }
15423   }
15424   // NOTE: the whole op may still not be guaranteed not to be undef or poison,
15425   // because it could create undef or poison due to its poison-generating
15426   // flags. So not finding any maybe-poison operands is fine.
15427 
15428   for (SDValue MaybePoisonOperand : MaybePoisonOperands) {
15429     // Don't replace every single UNDEF everywhere with frozen UNDEF, though.
15430     if (MaybePoisonOperand.getOpcode() == ISD::UNDEF)
15431       continue;
15432     // First, freeze each offending operand.
15433     SDValue FrozenMaybePoisonOperand = DAG.getFreeze(MaybePoisonOperand);
15434     // Then, change all other uses of unfrozen operand to use frozen operand.
15435     DAG.ReplaceAllUsesOfValueWith(MaybePoisonOperand, FrozenMaybePoisonOperand);
15436     if (FrozenMaybePoisonOperand.getOpcode() == ISD::FREEZE &&
15437         FrozenMaybePoisonOperand.getOperand(0) == FrozenMaybePoisonOperand) {
15438       // But, that also updated the use in the freeze we just created, thus
15439       // creating a cycle in a DAG. Let's undo that by mutating the freeze.
15440       DAG.UpdateNodeOperands(FrozenMaybePoisonOperand.getNode(),
15441                              MaybePoisonOperand);
15442     }
15443   }
15444 
15445   // This node has been merged with another.
15446   if (N->getOpcode() == ISD::DELETED_NODE)
15447     return SDValue(N, 0);
15448 
15449   // The whole node may have been updated, so the value we were holding
15450   // may no longer be valid. Re-fetch the operand we're `freeze`ing.
15451   N0 = N->getOperand(0);
15452 
15453   // Finally, recreate the node; its operands were updated to use frozen
15454   // operands, so we just need to use its "original" operands.
15455   SmallVector<SDValue> Ops(N0->op_begin(), N0->op_end());
15456   // Special-handle ISD::UNDEF: each single one of them can be its own thing.
15457   for (SDValue &Op : Ops) {
15458     if (Op.getOpcode() == ISD::UNDEF)
15459       Op = DAG.getFreeze(Op);
15460   }
15461   // NOTE: this strips poison generating flags.
15462   SDValue R = DAG.getNode(N0.getOpcode(), SDLoc(N0), N0->getVTList(), Ops);
15463   assert(DAG.isGuaranteedNotToBeUndefOrPoison(R, /*PoisonOnly*/ false) &&
15464          "Can't create node that may be undef/poison!");
15465   return R;
15466 }
15467 
15468 /// We know that BV is a build_vector node with Constant, ConstantFP or Undef
15469 /// operands. DstEltVT indicates the destination element value type.
15470 SDValue DAGCombiner::
15471 ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) {
15472   EVT SrcEltVT = BV->getValueType(0).getVectorElementType();
15473 
15474   // If this is already the right type, we're done.
15475   if (SrcEltVT == DstEltVT) return SDValue(BV, 0);
15476 
15477   unsigned SrcBitSize = SrcEltVT.getSizeInBits();
15478   unsigned DstBitSize = DstEltVT.getSizeInBits();
15479 
15480   // If this is a conversion of N elements of one type to N elements of another
15481   // type, convert each element.  This handles FP<->INT cases.
15482   if (SrcBitSize == DstBitSize) {
15483     SmallVector<SDValue, 8> Ops;
15484     for (SDValue Op : BV->op_values()) {
15485       // If the vector element type is not legal, the BUILD_VECTOR operands
15486       // are promoted and implicitly truncated.  Make that explicit here.
15487       if (Op.getValueType() != SrcEltVT)
15488         Op = DAG.getNode(ISD::TRUNCATE, SDLoc(BV), SrcEltVT, Op);
15489       Ops.push_back(DAG.getBitcast(DstEltVT, Op));
15490       AddToWorklist(Ops.back().getNode());
15491     }
15492     EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT,
15493                               BV->getValueType(0).getVectorNumElements());
15494     return DAG.getBuildVector(VT, SDLoc(BV), Ops);
15495   }
15496 
15497   // Otherwise, we're growing or shrinking the elements.  To avoid having to
15498   // handle annoying details of growing/shrinking FP values, we convert them to
15499   // int first.
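  // e.g. on little-endian, (v2i32 (build_vector 0x00020001, 0x00040003))
  // bitcast to v4i16 becomes (build_vector 1, 2, 3, 4).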
15500   if (SrcEltVT.isFloatingPoint()) {
15501     // Convert the input float vector to an int vector where the elements
15502     // are the same size.
15503     EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), SrcEltVT.getSizeInBits());
15504     BV = ConstantFoldBITCASTofBUILD_VECTOR(BV, IntVT).getNode();
15505     SrcEltVT = IntVT;
15506   }
15507 
15508   // Now we know the input is an integer vector.  If the output is a FP type,
15509   // convert to integer first, then to FP of the right size.
15510   if (DstEltVT.isFloatingPoint()) {
15511     EVT TmpVT = EVT::getIntegerVT(*DAG.getContext(), DstEltVT.getSizeInBits());
15512     SDNode *Tmp = ConstantFoldBITCASTofBUILD_VECTOR(BV, TmpVT).getNode();
15513 
15514     // Next, convert to FP elements of the same size.
15515     return ConstantFoldBITCASTofBUILD_VECTOR(Tmp, DstEltVT);
15516   }
15517 
15518   // Okay, we know the src/dst types are both integers of differing widths.
15519   assert(SrcEltVT.isInteger() && DstEltVT.isInteger());
15520 
15521   // TODO: Should ConstantFoldBITCASTofBUILD_VECTOR always take a
15522   // BuildVectorSDNode?
15523   auto *BVN = cast<BuildVectorSDNode>(BV);
15524 
15525   // Extract the constant raw bit data.
15526   BitVector UndefElements;
15527   SmallVector<APInt> RawBits;
15528   bool IsLE = DAG.getDataLayout().isLittleEndian();
15529   if (!BVN->getConstantRawBits(IsLE, DstBitSize, RawBits, UndefElements))
15530     return SDValue();
15531 
15532   SDLoc DL(BV);
15533   SmallVector<SDValue, 8> Ops;
15534   for (unsigned I = 0, E = RawBits.size(); I != E; ++I) {
15535     if (UndefElements[I])
15536       Ops.push_back(DAG.getUNDEF(DstEltVT));
15537     else
15538       Ops.push_back(DAG.getConstant(RawBits[I], DL, DstEltVT));
15539   }
15540 
15541   EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT, Ops.size());
15542   return DAG.getBuildVector(VT, DL, Ops);
15543 }
15544 
15545 // Returns true if floating-point contraction is allowed on the FMUL-SDValue
15546 // `N`.
15547 static bool isContractableFMUL(const TargetOptions &Options, SDValue N) {
15548   assert(N.getOpcode() == ISD::FMUL);
15549 
15550   return Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath ||
15551          N->getFlags().hasAllowContract();
15552 }
15553 
15554 // Returns true if `N` may assume no infinities are involved in its computation.
15555 static bool hasNoInfs(const TargetOptions &Options, SDValue N) {
15556   return Options.NoInfsFPMath || N->getFlags().hasNoInfs();
15557 }
15558 
15559 /// Try to perform FMA combining on a given FADD node.
15560 template <class MatchContextClass>
15561 SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
15562   SDValue N0 = N->getOperand(0);
15563   SDValue N1 = N->getOperand(1);
15564   EVT VT = N->getValueType(0);
15565   SDLoc SL(N);
15566   MatchContextClass matcher(DAG, TLI, N);
15567   const TargetOptions &Options = DAG.getTarget().Options;
15568 
15569   bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
15570 
15571   // Floating-point multiply-add with intermediate rounding.
15572   // FIXME: Make isFMADLegal have specific behavior when using VPMatchContext.
15573   // FIXME: Add VP_FMAD opcode.
15574   bool HasFMAD = !UseVP && (LegalOperations && TLI.isFMADLegal(DAG, N));
15575 
15576   // Floating-point multiply-add without intermediate rounding.
15577   bool HasFMA =
15578       TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
15579       (!LegalOperations || matcher.isOperationLegalOrCustom(ISD::FMA, VT));
15580 
15581   // No valid opcode, do not combine.
15582   if (!HasFMAD && !HasFMA)
15583     return SDValue();
15584 
15585   bool CanReassociate =
15586       Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
15587   bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
15588                               Options.UnsafeFPMath || HasFMAD);
15589   // If the addition is not contractable, do not combine.
15590   if (!AllowFusionGlobally && !N->getFlags().hasAllowContract())
15591     return SDValue();
15592 
15593   // Folding fadd (fmul x, y), (fmul x, y) -> fma x, y, (fmul x, y) is never
15594   // beneficial. It does not reduce latency. It increases register pressure. It
15595   // replaces an fadd with an fma which is a more complex instruction, so is
15596   // likely to have a larger encoding, use more functional units, etc.
15597   if (N0 == N1)
15598     return SDValue();
15599 
15600   if (TLI.generateFMAsInMachineCombiner(VT, OptLevel))
15601     return SDValue();
15602 
15603   // Always prefer FMAD to FMA for precision.
15604   unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
15605   bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
15606 
15607   auto isFusedOp = [&](SDValue N) {
15608     return matcher.match(N, ISD::FMA) || matcher.match(N, ISD::FMAD);
15609   };
15610 
15611   // Is the node an FMUL and contractable either due to global flags or
15612   // SDNodeFlags.
15613   auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) {
15614     if (!matcher.match(N, ISD::FMUL))
15615       return false;
15616     return AllowFusionGlobally || N->getFlags().hasAllowContract();
15617   };
15618   // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
15619   // prefer to fold the multiply with fewer uses.
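  // e.g. if (fmul u, v) has a single use and (fmul x, y) has several, fusing
  // the single-use multiply lets it be deleted, while the multi-use multiply
  // must stay alive for its other users anyway.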
15620   if (Aggressive && isContractableFMUL(N0) && isContractableFMUL(N1)) {
15621     if (N0->use_size() > N1->use_size())
15622       std::swap(N0, N1);
15623   }
15624 
15625   // fold (fadd (fmul x, y), z) -> (fma x, y, z)
15626   if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) {
15627     return matcher.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0),
15628                            N0.getOperand(1), N1);
15629   }
15630 
15631   // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
15632   // Note: Commutes FADD operands.
15633   if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) {
15634     return matcher.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0),
15635                            N1.getOperand(1), N0);
15636   }
15637 
15638   // fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E)
15639   // fadd E, (fma A, B, (fmul C, D)) --> fma A, B, (fma C, D, E)
15640   // This also works with nested fma instructions:
15641   // fadd (fma A, B, (fma C, D, (fmul E, F))), G -->
15642   //   fma A, B, (fma C, D, (fma E, F, G))
15643   // fadd G, (fma A, B, (fma C, D, (fmul E, F))) -->
15644   //   fma A, B, (fma C, D, (fma E, F, G)).
15645   // This requires reassociation because it changes the order of operations.
15646   if (CanReassociate) {
15647     SDValue FMA, E;
15648     if (isFusedOp(N0) && N0.hasOneUse()) {
15649       FMA = N0;
15650       E = N1;
15651     } else if (isFusedOp(N1) && N1.hasOneUse()) {
15652       FMA = N1;
15653       E = N0;
15654     }
15655 
15656     SDValue TmpFMA = FMA;
15657     while (E && isFusedOp(TmpFMA) && TmpFMA.hasOneUse()) {
15658       SDValue FMul = TmpFMA->getOperand(2);
15659       if (matcher.match(FMul, ISD::FMUL) && FMul.hasOneUse()) {
15660         SDValue C = FMul.getOperand(0);
15661         SDValue D = FMul.getOperand(1);
15662         SDValue CDE = matcher.getNode(PreferredFusedOpcode, SL, VT, C, D, E);
15663         DAG.ReplaceAllUsesOfValueWith(FMul, CDE);
15664         // Replacing the inner FMul could cause the outer FMA to be simplified
15665         // away.
15666         return FMA.getOpcode() == ISD::DELETED_NODE ? SDValue(N, 0) : FMA;
15667       }
15668 
15669       TmpFMA = TmpFMA->getOperand(2);
15670     }
15671   }
15672 
15673   // Look through FP_EXTEND nodes to do more combining.
15674 
15675   // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
15676   if (matcher.match(N0, ISD::FP_EXTEND)) {
15677     SDValue N00 = N0.getOperand(0);
15678     if (isContractableFMUL(N00) &&
15679         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
15680                             N00.getValueType())) {
15681       return matcher.getNode(
15682           PreferredFusedOpcode, SL, VT,
15683           matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
15684           matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)), N1);
15685     }
15686   }
15687 
15688   // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x)
15689   // Note: Commutes FADD operands.
15690   if (matcher.match(N1, ISD::FP_EXTEND)) {
15691     SDValue N10 = N1.getOperand(0);
15692     if (isContractableFMUL(N10) &&
15693         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
15694                             N10.getValueType())) {
15695       return matcher.getNode(
15696           PreferredFusedOpcode, SL, VT,
15697           matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0)),
15698           matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), N0);
15699     }
15700   }
15701 
15702   // More folding opportunities when target permits.
15703   if (Aggressive) {
15704     // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
15705     //   -> (fma x, y, (fma (fpext u), (fpext v), z))
15706     auto FoldFAddFMAFPExtFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
15707                                     SDValue Z) {
15708       return matcher.getNode(
15709           PreferredFusedOpcode, SL, VT, X, Y,
15710           matcher.getNode(PreferredFusedOpcode, SL, VT,
15711                           matcher.getNode(ISD::FP_EXTEND, SL, VT, U),
15712                           matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z));
15713     };
15714     if (isFusedOp(N0)) {
15715       SDValue N02 = N0.getOperand(2);
15716       if (matcher.match(N02, ISD::FP_EXTEND)) {
15717         SDValue N020 = N02.getOperand(0);
15718         if (isContractableFMUL(N020) &&
15719             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
15720                                 N020.getValueType())) {
15721           return FoldFAddFMAFPExtFMul(N0.getOperand(0), N0.getOperand(1),
15722                                       N020.getOperand(0), N020.getOperand(1),
15723                                       N1);
15724         }
15725       }
15726     }
15727 
15728     // fold (fadd (fpext (fma x, y, (fmul u, v))), z)
15729     //   -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
15730     // FIXME: This turns two single-precision and one double-precision
15731     // operation into two double-precision operations, which might not be
15732     // interesting for all targets, especially GPUs.
15733     auto FoldFAddFPExtFMAFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
15734                                     SDValue Z) {
15735       return matcher.getNode(
15736           PreferredFusedOpcode, SL, VT,
15737           matcher.getNode(ISD::FP_EXTEND, SL, VT, X),
15738           matcher.getNode(ISD::FP_EXTEND, SL, VT, Y),
15739           matcher.getNode(PreferredFusedOpcode, SL, VT,
15740                           matcher.getNode(ISD::FP_EXTEND, SL, VT, U),
15741                           matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z));
15742     };
15743     if (N0.getOpcode() == ISD::FP_EXTEND) {
15744       SDValue N00 = N0.getOperand(0);
15745       if (isFusedOp(N00)) {
15746         SDValue N002 = N00.getOperand(2);
15747         if (isContractableFMUL(N002) &&
15748             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
15749                                 N00.getValueType())) {
15750           return FoldFAddFPExtFMAFMul(N00.getOperand(0), N00.getOperand(1),
15751                                       N002.getOperand(0), N002.getOperand(1),
15752                                       N1);
15753         }
15754       }
15755     }
15756 
15757     // fold (fadd x, (fma y, z, (fpext (fmul u, v))))
15758     //   -> (fma y, z, (fma (fpext u), (fpext v), x))
15759     if (isFusedOp(N1)) {
15760       SDValue N12 = N1.getOperand(2);
15761       if (N12.getOpcode() == ISD::FP_EXTEND) {
15762         SDValue N120 = N12.getOperand(0);
15763         if (isContractableFMUL(N120) &&
15764             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
15765                                 N120.getValueType())) {
15766           return FoldFAddFMAFPExtFMul(N1.getOperand(0), N1.getOperand(1),
15767                                       N120.getOperand(0), N120.getOperand(1),
15768                                       N0);
15769         }
15770       }
15771     }
15772 
15773     // fold (fadd x, (fpext (fma y, z, (fmul u, v))))
15774     //   -> (fma (fpext y), (fpext z), (fma (fpext u), (fpext v), x))
15775     // FIXME: This turns two single-precision and one double-precision
15776     // operation into two double-precision operations, which might not be
15777     // interesting for all targets, especially GPUs.
15778     if (N1.getOpcode() == ISD::FP_EXTEND) {
15779       SDValue N10 = N1.getOperand(0);
15780       if (isFusedOp(N10)) {
15781         SDValue N102 = N10.getOperand(2);
15782         if (isContractableFMUL(N102) &&
15783             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
15784                                 N10.getValueType())) {
15785           return FoldFAddFPExtFMAFMul(N10.getOperand(0), N10.getOperand(1),
15786                                       N102.getOperand(0), N102.getOperand(1),
15787                                       N0);
15788         }
15789       }
15790     }
15791   }
15792 
15793   return SDValue();
15794 }
15795 
15796 /// Try to perform FMA combining on a given FSUB node.
15797 template <class MatchContextClass>
15798 SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
15799   SDValue N0 = N->getOperand(0);
15800   SDValue N1 = N->getOperand(1);
15801   EVT VT = N->getValueType(0);
15802   SDLoc SL(N);
15803   MatchContextClass matcher(DAG, TLI, N);
15804   const TargetOptions &Options = DAG.getTarget().Options;
15805 
15806   bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
15807 
15808   // Floating-point multiply-add with intermediate rounding.
15809   // FIXME: Make isFMADLegal have specific behavior when using VPMatchContext.
15810   // FIXME: Add VP_FMAD opcode.
15811   bool HasFMAD = !UseVP && (LegalOperations && TLI.isFMADLegal(DAG, N));
15812 
15813   // Floating-point multiply-add without intermediate rounding.
15814   bool HasFMA =
15815       TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
15816       (!LegalOperations || matcher.isOperationLegalOrCustom(ISD::FMA, VT));
15817 
15818   // No valid opcode, do not combine.
15819   if (!HasFMAD && !HasFMA)
15820     return SDValue();
15821 
15822   const SDNodeFlags Flags = N->getFlags();
15823   bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
15824                               Options.UnsafeFPMath || HasFMAD);
15825 
15826   // If the subtraction is not contractable, do not combine.
15827   if (!AllowFusionGlobally && !N->getFlags().hasAllowContract())
15828     return SDValue();
15829 
15830   if (TLI.generateFMAsInMachineCombiner(VT, OptLevel))
15831     return SDValue();
15832 
15833   // Always prefer FMAD to FMA for precision.
15834   unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
15835   bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
15836   bool NoSignedZero = Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros();
15837 
15838   // Is the node an FMUL and contractable either due to global flags or
15839   // SDNodeFlags.
15840   auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) {
15841     if (!matcher.match(N, ISD::FMUL))
15842       return false;
15843     return AllowFusionGlobally || N->getFlags().hasAllowContract();
15844   };
15845 
15846   // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
15847   auto tryToFoldXYSubZ = [&](SDValue XY, SDValue Z) {
15848     if (isContractableFMUL(XY) && (Aggressive || XY->hasOneUse())) {
15849       return matcher.getNode(PreferredFusedOpcode, SL, VT, XY.getOperand(0),
15850                              XY.getOperand(1),
15851                              matcher.getNode(ISD::FNEG, SL, VT, Z));
15852     }
15853     return SDValue();
15854   };
15855 
15856   // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
15857   // Note: Commutes FSUB operands.
15858   auto tryToFoldXSubYZ = [&](SDValue X, SDValue YZ) {
15859     if (isContractableFMUL(YZ) && (Aggressive || YZ->hasOneUse())) {
15860       return matcher.getNode(
15861           PreferredFusedOpcode, SL, VT,
15862           matcher.getNode(ISD::FNEG, SL, VT, YZ.getOperand(0)),
15863           YZ.getOperand(1), X);
15864     }
15865     return SDValue();
15866   };
15867 
15868   // If we have two choices trying to fold (fsub (fmul u, v), (fmul x, y)),
15869   // prefer to fold the multiply with fewer uses.
15870   if (isContractableFMUL(N0) && isContractableFMUL(N1) &&
15871       (N0->use_size() > N1->use_size())) {
15872     // fold (fsub (fmul a, b), (fmul c, d)) -> (fma (fneg c), d, (fmul a, b))
15873     if (SDValue V = tryToFoldXSubYZ(N0, N1))
15874       return V;
15875     // fold (fsub (fmul a, b), (fmul c, d)) -> (fma a, b, (fneg (fmul c, d)))
15876     if (SDValue V = tryToFoldXYSubZ(N0, N1))
15877       return V;
15878   } else {
15879     // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
15880     if (SDValue V = tryToFoldXYSubZ(N0, N1))
15881       return V;
15882     // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
15883     if (SDValue V = tryToFoldXSubYZ(N0, N1))
15884       return V;
15885   }
15886 
15887   // fold (fsub (fneg (fmul x, y)), z) -> (fma (fneg x), y, (fneg z))
15888   if (matcher.match(N0, ISD::FNEG) && isContractableFMUL(N0.getOperand(0)) &&
15889       (Aggressive || (N0->hasOneUse() && N0.getOperand(0).hasOneUse()))) {
15890     SDValue N00 = N0.getOperand(0).getOperand(0);
15891     SDValue N01 = N0.getOperand(0).getOperand(1);
15892     return matcher.getNode(PreferredFusedOpcode, SL, VT,
15893                            matcher.getNode(ISD::FNEG, SL, VT, N00), N01,
15894                            matcher.getNode(ISD::FNEG, SL, VT, N1));
15895   }
15896 
15897   // Look through FP_EXTEND nodes to do more combining.
15898 
15899   // fold (fsub (fpext (fmul x, y)), z)
15900   //   -> (fma (fpext x), (fpext y), (fneg z))
15901   if (matcher.match(N0, ISD::FP_EXTEND)) {
15902     SDValue N00 = N0.getOperand(0);
15903     if (isContractableFMUL(N00) &&
15904         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
15905                             N00.getValueType())) {
15906       return matcher.getNode(
15907           PreferredFusedOpcode, SL, VT,
15908           matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
15909           matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
15910           matcher.getNode(ISD::FNEG, SL, VT, N1));
15911     }
15912   }
15913 
15914   // fold (fsub x, (fpext (fmul y, z)))
15915   //   -> (fma (fneg (fpext y)), (fpext z), x)
15916   // Note: Commutes FSUB operands.
15917   if (matcher.match(N1, ISD::FP_EXTEND)) {
15918     SDValue N10 = N1.getOperand(0);
15919     if (isContractableFMUL(N10) &&
15920         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
15921                             N10.getValueType())) {
15922       return matcher.getNode(
15923           PreferredFusedOpcode, SL, VT,
15924           matcher.getNode(
15925               ISD::FNEG, SL, VT,
15926               matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0))),
15927           matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), N0);
15928     }
15929   }
15930 
15931   // fold (fsub (fpext (fneg (fmul x, y))), z)
15932   //   -> (fneg (fma (fpext x), (fpext y), z))
15933   // Note: This could be removed with appropriate canonicalization of the
15934   // input expression into (fneg (fadd (fpext (fmul x, y)), z)). However, the
15935   // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent
15936   // us from implementing the canonicalization in visitFSUB.
15937   if (matcher.match(N0, ISD::FP_EXTEND)) {
15938     SDValue N00 = N0.getOperand(0);
15939     if (matcher.match(N00, ISD::FNEG)) {
15940       SDValue N000 = N00.getOperand(0);
15941       if (isContractableFMUL(N000) &&
15942           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
15943                               N00.getValueType())) {
15944         return matcher.getNode(
15945             ISD::FNEG, SL, VT,
15946             matcher.getNode(
15947                 PreferredFusedOpcode, SL, VT,
15948                 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)),
15949                 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)),
15950                 N1));
15951       }
15952     }
15953   }
15954 
15955   // fold (fsub (fneg (fpext (fmul x, y))), z)
15956   //   -> (fneg (fma (fpext x), (fpext y), z))
15957   // Note: This could be removed with appropriate canonicalization of the
15958   // input expression into (fneg (fadd (fpext (fmul x, y)), z)). However, the
15959   // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent
15960   // us from implementing the canonicalization in visitFSUB.
15961   if (matcher.match(N0, ISD::FNEG)) {
15962     SDValue N00 = N0.getOperand(0);
15963     if (matcher.match(N00, ISD::FP_EXTEND)) {
15964       SDValue N000 = N00.getOperand(0);
15965       if (isContractableFMUL(N000) &&
15966           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
15967                               N000.getValueType())) {
15968         return matcher.getNode(
15969             ISD::FNEG, SL, VT,
15970             matcher.getNode(
15971                 PreferredFusedOpcode, SL, VT,
15972                 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)),
15973                 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)),
15974                 N1));
15975       }
15976     }
15977   }
15978 
15979   auto isReassociable = [&Options](SDNode *N) {
15980     return Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
15981   };
15982 
15983   auto isContractableAndReassociableFMUL = [&isContractableFMUL,
15984                                             &isReassociable](SDValue N) {
15985     return isContractableFMUL(N) && isReassociable(N.getNode());
15986   };
15987 
15988   auto isFusedOp = [&](SDValue N) {
15989     return matcher.match(N, ISD::FMA) || matcher.match(N, ISD::FMAD);
15990   };
15991 
15992   // More folding opportunities when target permits.
15993   if (Aggressive && isReassociable(N)) {
15994     bool CanFuse = Options.UnsafeFPMath || N->getFlags().hasAllowContract();
15995     // fold (fsub (fma x, y, (fmul u, v)), z)
15996     //   -> (fma x, y, (fma u, v, (fneg z)))
15997     if (CanFuse && isFusedOp(N0) &&
15998         isContractableAndReassociableFMUL(N0.getOperand(2)) &&
15999         N0->hasOneUse() && N0.getOperand(2)->hasOneUse()) {
16000       return matcher.getNode(
16001           PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1),
16002           matcher.getNode(PreferredFusedOpcode, SL, VT,
16003                           N0.getOperand(2).getOperand(0),
16004                           N0.getOperand(2).getOperand(1),
16005                           matcher.getNode(ISD::FNEG, SL, VT, N1)));
16006     }
16007 
16008     // fold (fsub x, (fma y, z, (fmul u, v)))
16009     //   -> (fma (fneg y), z, (fma (fneg u), v, x))
16010     if (CanFuse && isFusedOp(N1) &&
16011         isContractableAndReassociableFMUL(N1.getOperand(2)) &&
16012         N1->hasOneUse() && NoSignedZero) {
16013       SDValue N20 = N1.getOperand(2).getOperand(0);
16014       SDValue N21 = N1.getOperand(2).getOperand(1);
16015       return matcher.getNode(
16016           PreferredFusedOpcode, SL, VT,
16017           matcher.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)),
16018           N1.getOperand(1),
16019           matcher.getNode(PreferredFusedOpcode, SL, VT,
16020                           matcher.getNode(ISD::FNEG, SL, VT, N20), N21, N0));
16021     }
16022 
16023     // fold (fsub (fma x, y, (fpext (fmul u, v))), z)
16024     //   -> (fma x, y, (fma (fpext u), (fpext v), (fneg z)))
16025     if (isFusedOp(N0) && N0->hasOneUse()) {
16026       SDValue N02 = N0.getOperand(2);
16027       if (matcher.match(N02, ISD::FP_EXTEND)) {
16028         SDValue N020 = N02.getOperand(0);
16029         if (isContractableAndReassociableFMUL(N020) &&
16030             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
16031                                 N020.getValueType())) {
16032           return matcher.getNode(
16033               PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1),
16034               matcher.getNode(
16035                   PreferredFusedOpcode, SL, VT,
16036                   matcher.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(0)),
16037                   matcher.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(1)),
16038                   matcher.getNode(ISD::FNEG, SL, VT, N1)));
16039         }
16040       }
16041     }
16042 
16043     // fold (fsub (fpext (fma x, y, (fmul u, v))), z)
16044     //   -> (fma (fpext x), (fpext y),
16045     //           (fma (fpext u), (fpext v), (fneg z)))
16046     // FIXME: This turns two single-precision and one double-precision
16047     // operation into two double-precision operations, which might not be
16048     // interesting for all targets, especially GPUs.
16049     if (matcher.match(N0, ISD::FP_EXTEND)) {
16050       SDValue N00 = N0.getOperand(0);
16051       if (isFusedOp(N00)) {
16052         SDValue N002 = N00.getOperand(2);
16053         if (isContractableAndReassociableFMUL(N002) &&
16054             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
16055                                 N00.getValueType())) {
16056           return matcher.getNode(
16057               PreferredFusedOpcode, SL, VT,
16058               matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
16059               matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
16060               matcher.getNode(
16061                   PreferredFusedOpcode, SL, VT,
16062                   matcher.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(0)),
16063                   matcher.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(1)),
16064                   matcher.getNode(ISD::FNEG, SL, VT, N1)));
16065         }
16066       }
16067     }
16068 
16069     // fold (fsub x, (fma y, z, (fpext (fmul u, v))))
16070     //   -> (fma (fneg y), z, (fma (fneg (fpext u)), (fpext v), x))
16071     if (isFusedOp(N1) && matcher.match(N1.getOperand(2), ISD::FP_EXTEND) &&
16072         N1->hasOneUse()) {
16073       SDValue N120 = N1.getOperand(2).getOperand(0);
16074       if (isContractableAndReassociableFMUL(N120) &&
16075           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
16076                               N120.getValueType())) {
16077         SDValue N1200 = N120.getOperand(0);
16078         SDValue N1201 = N120.getOperand(1);
16079         return matcher.getNode(
16080             PreferredFusedOpcode, SL, VT,
16081             matcher.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)),
16082             N1.getOperand(1),
16083             matcher.getNode(
16084                 PreferredFusedOpcode, SL, VT,
16085                 matcher.getNode(ISD::FNEG, SL, VT,
16086                                 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1200)),
16087                 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1201), N0));
16088       }
16089     }
16090 
16091     // fold (fsub x, (fpext (fma y, z, (fmul u, v))))
16092     //   -> (fma (fneg (fpext y)), (fpext z),
16093     //           (fma (fneg (fpext u)), (fpext v), x))
16094     // FIXME: This turns two single-precision and one double-precision
16095     // operation into two double-precision operations, which might not be
16096     // interesting for all targets, especially GPUs.
16097     if (matcher.match(N1, ISD::FP_EXTEND) && isFusedOp(N1.getOperand(0))) {
16098       SDValue CvtSrc = N1.getOperand(0);
16099       SDValue N100 = CvtSrc.getOperand(0);
16100       SDValue N101 = CvtSrc.getOperand(1);
16101       SDValue N102 = CvtSrc.getOperand(2);
16102       if (isContractableAndReassociableFMUL(N102) &&
16103           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
16104                               CvtSrc.getValueType())) {
16105         SDValue N1020 = N102.getOperand(0);
16106         SDValue N1021 = N102.getOperand(1);
16107         return matcher.getNode(
16108             PreferredFusedOpcode, SL, VT,
16109             matcher.getNode(ISD::FNEG, SL, VT,
16110                             matcher.getNode(ISD::FP_EXTEND, SL, VT, N100)),
16111             matcher.getNode(ISD::FP_EXTEND, SL, VT, N101),
16112             matcher.getNode(
16113                 PreferredFusedOpcode, SL, VT,
16114                 matcher.getNode(ISD::FNEG, SL, VT,
16115                                 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1020)),
16116                 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1021), N0));
16117       }
16118     }
16119   }
16120 
16121   return SDValue();
16122 }
16123 
16124 /// Try to perform FMA combining on a given FMUL node based on the distributive
16125 /// law x * (y + 1) = x * y + x and variants thereof (commuted versions,
16126 /// subtraction instead of addition).
16127 SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) {
16128   SDValue N0 = N->getOperand(0);
16129   SDValue N1 = N->getOperand(1);
16130   EVT VT = N->getValueType(0);
16131   SDLoc SL(N);
16132 
16133   assert(N->getOpcode() == ISD::FMUL && "Expected FMUL Operation");
16134 
16135   const TargetOptions &Options = DAG.getTarget().Options;
16136 
16137   // The transforms below are incorrect when x == 0 and y == inf, because the
16138   // intermediate multiplication produces a nan.
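  // e.g. for (fmul (fadd x, 1.0), y) with x == 0.0 and y == inf:
  // (0.0 + 1.0) * inf == inf, but the fused (fma x, y, y) computes the
  // intermediate product 0.0 * inf == nan, so the result becomes nan.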
16139   SDValue FAdd = N0.getOpcode() == ISD::FADD ? N0 : N1;
16140   if (!hasNoInfs(Options, FAdd))
16141     return SDValue();
16142 
16143   // Floating-point multiply-add without intermediate rounding.
16144   bool HasFMA =
16145       isContractableFMUL(Options, SDValue(N, 0)) &&
16146       TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
16147       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
16148 
16149   // Floating-point multiply-add with intermediate rounding. This can yield
16150   // a less precise result due to the changed rounding order.
16151   bool HasFMAD = Options.UnsafeFPMath &&
16152                  (LegalOperations && TLI.isFMADLegal(DAG, N));
16153 
16154   // No valid opcode, do not combine.
16155   if (!HasFMAD && !HasFMA)
16156     return SDValue();
16157 
16158   // Always prefer FMAD to FMA for precision.
16159   unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
16160   bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
16161 
16162   // fold (fmul (fadd x0, +1.0), y) -> (fma x0, y, y)
16163   // fold (fmul (fadd x0, -1.0), y) -> (fma x0, y, (fneg y))
16164   auto FuseFADD = [&](SDValue X, SDValue Y) {
16165     if (X.getOpcode() == ISD::FADD && (Aggressive || X->hasOneUse())) {
16166       if (auto *C = isConstOrConstSplatFP(X.getOperand(1), true)) {
16167         if (C->isExactlyValue(+1.0))
16168           return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
16169                              Y);
16170         if (C->isExactlyValue(-1.0))
16171           return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
16172                              DAG.getNode(ISD::FNEG, SL, VT, Y));
16173       }
16174     }
16175     return SDValue();
16176   };
16177 
16178   if (SDValue FMA = FuseFADD(N0, N1))
16179     return FMA;
16180   if (SDValue FMA = FuseFADD(N1, N0))
16181     return FMA;
16182 
16183   // fold (fmul (fsub +1.0, x1), y) -> (fma (fneg x1), y, y)
16184   // fold (fmul (fsub -1.0, x1), y) -> (fma (fneg x1), y, (fneg y))
16185   // fold (fmul (fsub x0, +1.0), y) -> (fma x0, y, (fneg y))
16186   // fold (fmul (fsub x0, -1.0), y) -> (fma x0, y, y)
16187   auto FuseFSUB = [&](SDValue X, SDValue Y) {
16188     if (X.getOpcode() == ISD::FSUB && (Aggressive || X->hasOneUse())) {
16189       if (auto *C0 = isConstOrConstSplatFP(X.getOperand(0), true)) {
16190         if (C0->isExactlyValue(+1.0))
16191           return DAG.getNode(PreferredFusedOpcode, SL, VT,
16192                              DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
16193                              Y);
16194         if (C0->isExactlyValue(-1.0))
16195           return DAG.getNode(PreferredFusedOpcode, SL, VT,
16196                              DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
16197                              DAG.getNode(ISD::FNEG, SL, VT, Y));
16198       }
16199       if (auto *C1 = isConstOrConstSplatFP(X.getOperand(1), true)) {
16200         if (C1->isExactlyValue(+1.0))
16201           return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
16202                              DAG.getNode(ISD::FNEG, SL, VT, Y));
16203         if (C1->isExactlyValue(-1.0))
16204           return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
16205                              Y);
16206       }
16207     }
16208     return SDValue();
16209   };
16210 
16211   if (SDValue FMA = FuseFSUB(N0, N1))
16212     return FMA;
16213   if (SDValue FMA = FuseFSUB(N1, N0))
16214     return FMA;
16215 
16216   return SDValue();
16217 }
16218 
16219 SDValue DAGCombiner::visitVP_FADD(SDNode *N) {
16220   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16221 
16222   // FADD -> FMA combines:
16223   if (SDValue Fused = visitFADDForFMACombine<VPMatchContext>(N)) {
16224     if (Fused.getOpcode() != ISD::DELETED_NODE)
16225       AddToWorklist(Fused.getNode());
16226     return Fused;
16227   }
16228   return SDValue();
16229 }
16230 
16231 SDValue DAGCombiner::visitFADD(SDNode *N) {
16232   SDValue N0 = N->getOperand(0);
16233   SDValue N1 = N->getOperand(1);
16234   SDNode *N0CFP = DAG.isConstantFPBuildVectorOrConstantFP(N0);
16235   SDNode *N1CFP = DAG.isConstantFPBuildVectorOrConstantFP(N1);
16236   EVT VT = N->getValueType(0);
16237   SDLoc DL(N);
16238   const TargetOptions &Options = DAG.getTarget().Options;
16239   SDNodeFlags Flags = N->getFlags();
16240   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16241 
16242   if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
16243     return R;
16244 
16245   // fold (fadd c1, c2) -> c1 + c2
16246   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FADD, DL, VT, {N0, N1}))
16247     return C;
16248 
16249   // canonicalize constant to RHS
16250   if (N0CFP && !N1CFP)
16251     return DAG.getNode(ISD::FADD, DL, VT, N1, N0);
16252 
16253   // fold vector ops
16254   if (VT.isVector())
16255     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
16256       return FoldedVOp;
16257 
16258   // N0 + -0.0 --> N0 (also allowed with +0.0 and fast-math)
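  // (Folding x + (+0.0) -> x unconditionally would be wrong: for x == -0.0
  // the IEEE sum is +0.0, hence the no-signed-zeros requirement below.)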
16259   ConstantFPSDNode *N1C = isConstOrConstSplatFP(N1, true);
16260   if (N1C && N1C->isZero())
16261     if (N1C->isNegative() || Options.NoSignedZerosFPMath ||
              Flags.hasNoSignedZeros())
16262       return N0;
16263 
16264   if (SDValue NewSel = foldBinOpIntoSelect(N))
16265     return NewSel;
16266 
16267   // fold (fadd A, (fneg B)) -> (fsub A, B)
16268   if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT))
16269     if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
16270             N1, DAG, LegalOperations, ForCodeSize))
16271       return DAG.getNode(ISD::FSUB, DL, VT, N0, NegN1);
16272 
16273   // fold (fadd (fneg A), B) -> (fsub B, A)
16274   if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT))
16275     if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
16276             N0, DAG, LegalOperations, ForCodeSize))
16277       return DAG.getNode(ISD::FSUB, DL, VT, N1, NegN0);
16278 
16279   auto isFMulNegTwo = [](SDValue FMul) {
16280     if (!FMul.hasOneUse() || FMul.getOpcode() != ISD::FMUL)
16281       return false;
16282     auto *C = isConstOrConstSplatFP(FMul.getOperand(1), true);
16283     return C && C->isExactlyValue(-2.0);
16284   };
16285 
16286   // fadd (fmul B, -2.0), A --> fsub A, (fadd B, B)
16287   if (isFMulNegTwo(N0)) {
16288     SDValue B = N0.getOperand(0);
16289     SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B);
16290     return DAG.getNode(ISD::FSUB, DL, VT, N1, Add);
16291   }
16292   // fadd A, (fmul B, -2.0) --> fsub A, (fadd B, B)
16293   if (isFMulNegTwo(N1)) {
16294     SDValue B = N1.getOperand(0);
16295     SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B);
16296     return DAG.getNode(ISD::FSUB, DL, VT, N0, Add);
16297   }
16298 
16299   // No FP constant should be created after legalization as the Instruction
16300   // Selection pass has a hard time dealing with FP constants.
16301   bool AllowNewConst = (Level < AfterLegalizeDAG);
16302 
16303   // If nnan is enabled, fold lots of things.
16304   if ((Options.NoNaNsFPMath || Flags.hasNoNaNs()) && AllowNewConst) {
16305     // If allowed, fold (fadd (fneg x), x) -> 0.0
16306     if (N0.getOpcode() == ISD::FNEG && N0.getOperand(0) == N1)
16307       return DAG.getConstantFP(0.0, DL, VT);
16308 
16309     // If allowed, fold (fadd x, (fneg x)) -> 0.0
16310     if (N1.getOpcode() == ISD::FNEG && N1.getOperand(0) == N0)
16311       return DAG.getConstantFP(0.0, DL, VT);
16312   }
16313 
16314   // If 'unsafe math' or reassoc and nsz, fold lots of things.
16315   // TODO: break out portions of the transformations below for which Unsafe is
16316   //       considered and which do not require both nsz and reassoc
16317   if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
16318        (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
16319       AllowNewConst) {
16320     // fadd (fadd x, c1), c2 -> fadd x, c1 + c2
16321     if (N1CFP && N0.getOpcode() == ISD::FADD &&
16322         DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) {
16323       SDValue NewC = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1), N1);
16324       return DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(0), NewC);
16325     }
16326 
16327     // We can fold chains of FADD's of the same value into multiplications.
16328     // This transform is not safe in general because we are reducing the number
16329     // of rounding steps.
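    // A hedged illustration of the rounding concern (not from the code
    // below): (fadd (fmul x, c), x) rounds twice at run time, once for the
    // fmul and once for the fadd, while the folded (fmul x, c+1) rounds the
    // constant c+1 once at compile time and then rounds once at run time.
    // The two can therefore differ in the last ulp, which is why
    // reassociation (or unsafe math) is required here.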
16330     if (TLI.isOperationLegalOrCustom(ISD::FMUL, VT) && !N0CFP && !N1CFP) {
16331       if (N0.getOpcode() == ISD::FMUL) {
16332         SDNode *CFP00 =
16333             DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(0));
16334         SDNode *CFP01 =
16335             DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1));
16336 
16337         // (fadd (fmul x, c), x) -> (fmul x, c+1)
16338         if (CFP01 && !CFP00 && N0.getOperand(0) == N1) {
16339           SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1),
16340                                        DAG.getConstantFP(1.0, DL, VT));
16341           return DAG.getNode(ISD::FMUL, DL, VT, N1, NewCFP);
16342         }
16343 
16344         // (fadd (fmul x, c), (fadd x, x)) -> (fmul x, c+2)
16345         if (CFP01 && !CFP00 && N1.getOpcode() == ISD::FADD &&
16346             N1.getOperand(0) == N1.getOperand(1) &&
16347             N0.getOperand(0) == N1.getOperand(0)) {
16348           SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1),
16349                                        DAG.getConstantFP(2.0, DL, VT));
16350           return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), NewCFP);
16351         }
16352       }
16353 
16354       if (N1.getOpcode() == ISD::FMUL) {
16355         SDNode *CFP10 =
16356             DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(0));
16357         SDNode *CFP11 =
16358             DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(1));
16359 
16360         // (fadd x, (fmul x, c)) -> (fmul x, c+1)
16361         if (CFP11 && !CFP10 && N1.getOperand(0) == N0) {
16362           SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N1.getOperand(1),
16363                                        DAG.getConstantFP(1.0, DL, VT));
16364           return DAG.getNode(ISD::FMUL, DL, VT, N0, NewCFP);
16365         }
16366 
16367         // (fadd (fadd x, x), (fmul x, c)) -> (fmul x, c+2)
16368         if (CFP11 && !CFP10 && N0.getOpcode() == ISD::FADD &&
16369             N0.getOperand(0) == N0.getOperand(1) &&
16370             N1.getOperand(0) == N0.getOperand(0)) {
16371           SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N1.getOperand(1),
16372                                        DAG.getConstantFP(2.0, DL, VT));
16373           return DAG.getNode(ISD::FMUL, DL, VT, N1.getOperand(0), NewCFP);
16374         }
16375       }
16376 
16377       if (N0.getOpcode() == ISD::FADD) {
16378         SDNode *CFP00 =
16379             DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(0));
16380         // (fadd (fadd x, x), x) -> (fmul x, 3.0)
16381         if (!CFP00 && N0.getOperand(0) == N0.getOperand(1) &&
16382             (N0.getOperand(0) == N1)) {
16383           return DAG.getNode(ISD::FMUL, DL, VT, N1,
16384                              DAG.getConstantFP(3.0, DL, VT));
16385         }
16386       }
16387 
16388       if (N1.getOpcode() == ISD::FADD) {
16389         SDNode *CFP10 =
16390             DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(0));
16391         // (fadd x, (fadd x, x)) -> (fmul x, 3.0)
16392         if (!CFP10 && N1.getOperand(0) == N1.getOperand(1) &&
16393             N1.getOperand(0) == N0) {
16394           return DAG.getNode(ISD::FMUL, DL, VT, N0,
16395                              DAG.getConstantFP(3.0, DL, VT));
16396         }
16397       }
16398 
16399       // (fadd (fadd x, x), (fadd x, x)) -> (fmul x, 4.0)
16400       if (N0.getOpcode() == ISD::FADD && N1.getOpcode() == ISD::FADD &&
16401           N0.getOperand(0) == N0.getOperand(1) &&
16402           N1.getOperand(0) == N1.getOperand(1) &&
16403           N0.getOperand(0) == N1.getOperand(0)) {
16404         return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0),
16405                            DAG.getConstantFP(4.0, DL, VT));
16406       }
16407     }
16408 
16409     // Fold fadd(vecreduce(x), vecreduce(y)) -> vecreduce(fadd(x, y))
16410     if (SDValue SD = reassociateReduction(ISD::VECREDUCE_FADD, ISD::FADD, DL,
16411                                           VT, N0, N1, Flags))
16412       return SD;
16413   } // enable-unsafe-fp-math
16414 
16415   // FADD -> FMA combines:
16416   if (SDValue Fused = visitFADDForFMACombine<EmptyMatchContext>(N)) {
16417     if (Fused.getOpcode() != ISD::DELETED_NODE)
16418       AddToWorklist(Fused.getNode());
16419     return Fused;
16420   }
16421   return SDValue();
16422 }
16423 
16424 SDValue DAGCombiner::visitSTRICT_FADD(SDNode *N) {
16425   SDValue Chain = N->getOperand(0);
16426   SDValue N0 = N->getOperand(1);
16427   SDValue N1 = N->getOperand(2);
16428   EVT VT = N->getValueType(0);
16429   EVT ChainVT = N->getValueType(1);
16430   SDLoc DL(N);
16431   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16432 
16433   // fold (strict_fadd A, (fneg B)) -> (strict_fsub A, B)
16434   if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::STRICT_FSUB, VT))
16435     if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
16436             N1, DAG, LegalOperations, ForCodeSize)) {
16437       return DAG.getNode(ISD::STRICT_FSUB, DL, DAG.getVTList(VT, ChainVT),
16438                          {Chain, N0, NegN1});
16439     }
16440 
16441   // fold (strict_fadd (fneg A), B) -> (strict_fsub B, A)
16442   if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::STRICT_FSUB, VT))
16443     if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
16444             N0, DAG, LegalOperations, ForCodeSize)) {
16445       return DAG.getNode(ISD::STRICT_FSUB, DL, DAG.getVTList(VT, ChainVT),
16446                          {Chain, N1, NegN0});
16447     }
16448   return SDValue();
16449 }
16450 
16451 SDValue DAGCombiner::visitFSUB(SDNode *N) {
16452   SDValue N0 = N->getOperand(0);
16453   SDValue N1 = N->getOperand(1);
16454   ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, true);
16455   ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true);
16456   EVT VT = N->getValueType(0);
16457   SDLoc DL(N);
16458   const TargetOptions &Options = DAG.getTarget().Options;
16459   const SDNodeFlags Flags = N->getFlags();
16460   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16461 
16462   if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
16463     return R;
16464 
16465   // fold (fsub c1, c2) -> c1-c2
16466   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FSUB, DL, VT, {N0, N1}))
16467     return C;
16468 
16469   // fold vector ops
16470   if (VT.isVector())
16471     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
16472       return FoldedVOp;
16473 
16474   if (SDValue NewSel = foldBinOpIntoSelect(N))
16475     return NewSel;
16476 
16477   // (fsub A, 0) -> A
16478   if (N1CFP && N1CFP->isZero()) {
16479     if (!N1CFP->isNegative() || Options.NoSignedZerosFPMath ||
16480         Flags.hasNoSignedZeros()) {
16481       return N0;
16482     }
16483   }
16484 
16485   if (N0 == N1) {
16486     // (fsub x, x) -> 0.0
16487     if (Options.NoNaNsFPMath || Flags.hasNoNaNs())
16488       return DAG.getConstantFP(0.0f, DL, VT);
16489   }
16490 
16491   // (fsub -0.0, N1) -> -N1
16492   if (N0CFP && N0CFP->isZero()) {
16493     if (N0CFP->isNegative() ||
16494         (Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())) {
16495       // We cannot replace an FSUB(+-0.0,X) with FNEG(X) when denormals are
16496       // flushed to zero, unless all users treat denorms as zero (DAZ).
16497       // FIXME: This transform will change the sign of a NaN and the behavior
16498       // of a signaling NaN. It is only valid when a NoNaN flag is present.
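      // Illustrative sketch of the hazard: under DAZ/FTZ, for a denormal d,
      // FSUB(-0.0, d) flushes d and yields -0.0, while FNEG(d) only flips
      // the sign bit and yields the nonzero denormal -d, so the two results
      // would disagree.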
16499       DenormalMode DenormMode = DAG.getDenormalMode(VT);
16500       if (DenormMode == DenormalMode::getIEEE()) {
16501         if (SDValue NegN1 =
16502                 TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize))
16503           return NegN1;
16504         if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))
16505           return DAG.getNode(ISD::FNEG, DL, VT, N1);
16506       }
16507     }
16508   }
16509 
16510   if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
16511        (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
16512       N1.getOpcode() == ISD::FADD) {
16513     // X - (X + Y) -> -Y
16514     if (N0 == N1->getOperand(0))
16515       return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(1));
16516     // X - (Y + X) -> -Y
16517     if (N0 == N1->getOperand(1))
16518       return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(0));
16519   }
16520 
16521   // fold (fsub A, (fneg B)) -> (fadd A, B)
16522   if (SDValue NegN1 =
16523           TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize))
16524     return DAG.getNode(ISD::FADD, DL, VT, N0, NegN1);
16525 
16526   // FSUB -> FMA combines:
16527   if (SDValue Fused = visitFSUBForFMACombine<EmptyMatchContext>(N)) {
16528     AddToWorklist(Fused.getNode());
16529     return Fused;
16530   }
16531 
16532   return SDValue();
16533 }
16534 
16535 // Transform IEEE Floats:
16536 //      (fmul C, (uitofp Pow2))
16537 //          -> (bitcast_to_FP (add (bitcast_to_INT C), Log2(Pow2) << mantissa))
16538 //      (fdiv C, (uitofp Pow2))
16539 //          -> (bitcast_to_FP (sub (bitcast_to_INT C), Log2(Pow2) << mantissa))
16540 //
16541 // The rationale is that an fmul/fdiv by a power of 2 just changes the
16542 // exponent, so nothing more than an add/sub is needed.
16543 //
16544 // This is valid under the following circumstances:
16545 // 1) We are dealing with IEEE floats
16546 // 2) C is normal
16547 // 3) The fmul/fdiv add/sub will not go outside of min/max exponent bounds.
16548 // TODO: Much of this could also be used for generating `ldexp` on targets
16549 // that prefer it.
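// A worked f32 example of the exponent arithmetic (the numbers here are
// illustrative only): C = 1.5f has bit pattern 0x3FC00000 and f32 has a
// 23-bit mantissa, so for Pow2 = 4 (Log2(Pow2) = 2):
//   0x3FC00000 + (2 << 23) = 0x40C00000, which is 6.0f == 1.5f * 4.0f.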
16550 SDValue DAGCombiner::combineFMulOrFDivWithIntPow2(SDNode *N) {
16551   EVT VT = N->getValueType(0);
16552   SDValue ConstOp, Pow2Op;
16553 
16554   std::optional<int> Mantissa;
16555   auto GetConstAndPow2Ops = [&](unsigned ConstOpIdx) {
16556     if (ConstOpIdx == 1 && N->getOpcode() == ISD::FDIV)
16557       return false;
16558 
16559     ConstOp = peekThroughBitcasts(N->getOperand(ConstOpIdx));
16560     Pow2Op = N->getOperand(1 - ConstOpIdx);
16561     if (Pow2Op.getOpcode() != ISD::UINT_TO_FP &&
16562         (Pow2Op.getOpcode() != ISD::SINT_TO_FP ||
16563          !DAG.computeKnownBits(Pow2Op).isNonNegative()))
16564       return false;
16565 
16566     Pow2Op = Pow2Op.getOperand(0);
16567 
16568     // `Log2(Pow2Op) < Pow2Op.getScalarSizeInBits()`.
16569     // TODO: We could use knownbits to make this bound more precise.
16570     int MaxExpChange = Pow2Op.getValueType().getScalarSizeInBits();
16571 
16572     auto IsFPConstValid = [N, MaxExpChange, &Mantissa](ConstantFPSDNode *CFP) {
16573       if (CFP == nullptr)
16574         return false;
16575 
16576       const APFloat &APF = CFP->getValueAPF();
16577 
16578       // Make sure we have a normal IEEE constant.
16579       if (!APF.isNormal() || !APF.isIEEE())
16580         return false;
16581 
16582       // Make sure the float's exponent is within the bounds for which this
16583       // transform produces a bitwise-identical value.
16584       int CurExp = ilogb(APF);
16585       // FMul by pow2 will only increase exponent.
16586       int MinExp =
16587           N->getOpcode() == ISD::FMUL ? CurExp : (CurExp - MaxExpChange);
16588       // FDiv by pow2 will only decrease exponent.
16589       int MaxExp =
16590           N->getOpcode() == ISD::FDIV ? CurExp : (CurExp + MaxExpChange);
16591       if (MinExp <= APFloat::semanticsMinExponent(APF.getSemantics()) ||
16592           MaxExp >= APFloat::semanticsMaxExponent(APF.getSemantics()))
16593         return false;
16594 
16595       // Finally make sure we actually know the mantissa for the float type.
16596       int ThisMantissa = APFloat::semanticsPrecision(APF.getSemantics()) - 1;
16597       if (!Mantissa)
16598         Mantissa = ThisMantissa;
16599 
16600       return *Mantissa == ThisMantissa && ThisMantissa > 0;
16601     };
16602 
16603     // TODO: We may be able to include undefs.
16604     return ISD::matchUnaryFpPredicate(ConstOp, IsFPConstValid);
16605   };
16606 
16607   if (!GetConstAndPow2Ops(0) && !GetConstAndPow2Ops(1))
16608     return SDValue();
16609 
16610   if (!TLI.optimizeFMulOrFDivAsShiftAddBitcast(N, ConstOp, Pow2Op))
16611     return SDValue();
16612 
16613   // Get log2 after all other checks have taken place. This is because
16614   // BuildLogBase2 may create a new node.
16615   SDLoc DL(N);
16616   // Get Log2 type with same bitwidth as the float type (VT).
16617   EVT NewIntVT = EVT::getIntegerVT(*DAG.getContext(), VT.getScalarSizeInBits());
16618   if (VT.isVector())
16619     NewIntVT = EVT::getVectorVT(*DAG.getContext(), NewIntVT,
16620                                 VT.getVectorElementCount());
16621 
16622   SDValue Log2 = BuildLogBase2(Pow2Op, DL, DAG.isKnownNeverZero(Pow2Op),
16623                                /*InexpensiveOnly*/ true, NewIntVT);
16624   if (!Log2)
16625     return SDValue();
16626 
16627   // Perform actual transform.
16628   SDValue MantissaShiftCnt =
16629       DAG.getConstant(*Mantissa, DL, getShiftAmountTy(NewIntVT));
16630   // TODO: Sometimes Log2 is of form `(X + C)`. `(X + C) << C1` should fold to
16631   // `(X << C1) + (C << C1)`, but that isn't always the case because of the
16632   // cast. We could implement that here by handling the casts.
16633   SDValue Shift = DAG.getNode(ISD::SHL, DL, NewIntVT, Log2, MantissaShiftCnt);
16634   SDValue ResAsInt =
16635       DAG.getNode(N->getOpcode() == ISD::FMUL ? ISD::ADD : ISD::SUB, DL,
16636                   NewIntVT, DAG.getBitcast(NewIntVT, ConstOp), Shift);
16637   SDValue ResAsFP = DAG.getBitcast(VT, ResAsInt);
16638   return ResAsFP;
16639 }
16640 
16641 SDValue DAGCombiner::visitFMUL(SDNode *N) {
16642   SDValue N0 = N->getOperand(0);
16643   SDValue N1 = N->getOperand(1);
16644   ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true);
16645   EVT VT = N->getValueType(0);
16646   SDLoc DL(N);
16647   const TargetOptions &Options = DAG.getTarget().Options;
16648   const SDNodeFlags Flags = N->getFlags();
16649   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16650 
16651   if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
16652     return R;
16653 
16654   // fold (fmul c1, c2) -> c1*c2
16655   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FMUL, DL, VT, {N0, N1}))
16656     return C;
16657 
16658   // canonicalize constant to RHS
16659   if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
16660      !DAG.isConstantFPBuildVectorOrConstantFP(N1))
16661     return DAG.getNode(ISD::FMUL, DL, VT, N1, N0);
16662 
16663   // fold vector ops
16664   if (VT.isVector())
16665     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
16666       return FoldedVOp;
16667 
16668   if (SDValue NewSel = foldBinOpIntoSelect(N))
16669     return NewSel;
16670 
16671   if (Options.UnsafeFPMath || Flags.hasAllowReassociation()) {
16672     // fmul (fmul X, C1), C2 -> fmul X, C1 * C2
16673     if (DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
16674         N0.getOpcode() == ISD::FMUL) {
16675       SDValue N00 = N0.getOperand(0);
16676       SDValue N01 = N0.getOperand(1);
16677       // Avoid an infinite loop by making sure that N00 is not a constant
16678       // (the inner multiply has not been constant folded yet).
16679       if (DAG.isConstantFPBuildVectorOrConstantFP(N01) &&
16680           !DAG.isConstantFPBuildVectorOrConstantFP(N00)) {
16681         SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, N01, N1);
16682         return DAG.getNode(ISD::FMUL, DL, VT, N00, MulConsts);
16683       }
16684     }
16685 
16686     // Match a special case: we convert X * 2.0 into fadd.
16687     // fmul (fadd X, X), C -> fmul X, 2.0 * C
16688     if (N0.getOpcode() == ISD::FADD && N0.hasOneUse() &&
16689         N0.getOperand(0) == N0.getOperand(1)) {
16690       const SDValue Two = DAG.getConstantFP(2.0, DL, VT);
16691       SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, Two, N1);
16692       return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), MulConsts);
16693     }
16694 
16695     // Fold fmul(vecreduce(x), vecreduce(y)) -> vecreduce(fmul(x, y))
16696     if (SDValue SD = reassociateReduction(ISD::VECREDUCE_FMUL, ISD::FMUL, DL,
16697                                           VT, N0, N1, Flags))
16698       return SD;
16699   }
16700 
16701   // fold (fmul X, 2.0) -> (fadd X, X)
16702   if (N1CFP && N1CFP->isExactlyValue(+2.0))
16703     return DAG.getNode(ISD::FADD, DL, VT, N0, N0);
16704 
16705   // fold (fmul X, -1.0) -> (fsub -0.0, X)
16706   if (N1CFP && N1CFP->isExactlyValue(-1.0)) {
16707     if (!LegalOperations || TLI.isOperationLegal(ISD::FSUB, VT)) {
16708       return DAG.getNode(ISD::FSUB, DL, VT,
16709                          DAG.getConstantFP(-0.0, DL, VT), N0, Flags);
16710     }
16711   }
16712 
16713   // -N0 * -N1 --> N0 * N1
16714   TargetLowering::NegatibleCost CostN0 =
16715       TargetLowering::NegatibleCost::Expensive;
16716   TargetLowering::NegatibleCost CostN1 =
16717       TargetLowering::NegatibleCost::Expensive;
16718   SDValue NegN0 =
16719       TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
16720   if (NegN0) {
16721     HandleSDNode NegN0Handle(NegN0);
16722     SDValue NegN1 =
16723         TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
16724     if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
16725                   CostN1 == TargetLowering::NegatibleCost::Cheaper))
16726       return DAG.getNode(ISD::FMUL, DL, VT, NegN0, NegN1);
16727   }
16728 
16729   // fold (fmul X, (select (fcmp X > 0.0), -1.0, 1.0)) -> (fneg (fabs X))
16730   // fold (fmul X, (select (fcmp X > 0.0), 1.0, -1.0)) -> (fabs X)
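  // Reasoning sketch for the first fold: when X > 0.0 the select yields -1.0
  // and the product is -X == -fabs(X); when X <= 0.0 it yields 1.0 and the
  // product is X, which also equals -fabs(X). NaNs and -0.0 break this
  // equivalence, hence the nnan/nsz requirement.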
16731   if (Flags.hasNoNaNs() && Flags.hasNoSignedZeros() &&
16732       (N0.getOpcode() == ISD::SELECT || N1.getOpcode() == ISD::SELECT) &&
16733       TLI.isOperationLegal(ISD::FABS, VT)) {
16734     SDValue Select = N0, X = N1;
16735     if (Select.getOpcode() != ISD::SELECT)
16736       std::swap(Select, X);
16737 
16738     SDValue Cond = Select.getOperand(0);
16739     auto TrueOpnd  = dyn_cast<ConstantFPSDNode>(Select.getOperand(1));
16740     auto FalseOpnd = dyn_cast<ConstantFPSDNode>(Select.getOperand(2));
16741 
16742     if (TrueOpnd && FalseOpnd &&
16743         Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == X &&
16744         isa<ConstantFPSDNode>(Cond.getOperand(1)) &&
16745         cast<ConstantFPSDNode>(Cond.getOperand(1))->isExactlyValue(0.0)) {
16746       ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
16747       switch (CC) {
16748       default: break;
16749       case ISD::SETOLT:
16750       case ISD::SETULT:
16751       case ISD::SETOLE:
16752       case ISD::SETULE:
16753       case ISD::SETLT:
16754       case ISD::SETLE:
16755         std::swap(TrueOpnd, FalseOpnd);
16756         [[fallthrough]];
16757       case ISD::SETOGT:
16758       case ISD::SETUGT:
16759       case ISD::SETOGE:
16760       case ISD::SETUGE:
16761       case ISD::SETGT:
16762       case ISD::SETGE:
16763         if (TrueOpnd->isExactlyValue(-1.0) && FalseOpnd->isExactlyValue(1.0) &&
16764             TLI.isOperationLegal(ISD::FNEG, VT))
16765           return DAG.getNode(ISD::FNEG, DL, VT,
16766                    DAG.getNode(ISD::FABS, DL, VT, X));
16767         if (TrueOpnd->isExactlyValue(1.0) && FalseOpnd->isExactlyValue(-1.0))
16768           return DAG.getNode(ISD::FABS, DL, VT, X);
16769 
16770         break;
16771       }
16772     }
16773   }
16774 
16775   // FMUL -> FMA combines:
16776   if (SDValue Fused = visitFMULForFMADistributiveCombine(N)) {
16777     AddToWorklist(Fused.getNode());
16778     return Fused;
16779   }
16780 
16781   // Don't do `combineFMulOrFDivWithIntPow2` until after FMUL -> FMA has been
16782   // able to run.
16783   if (SDValue R = combineFMulOrFDivWithIntPow2(N))
16784     return R;
16785 
16786   return SDValue();
16787 }
16788 
16789 template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
16790   SDValue N0 = N->getOperand(0);
16791   SDValue N1 = N->getOperand(1);
16792   SDValue N2 = N->getOperand(2);
16793   ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
16794   ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
16795   EVT VT = N->getValueType(0);
16796   SDLoc DL(N);
16797   const TargetOptions &Options = DAG.getTarget().Options;
16798   // FMA nodes have flags that propagate to the created nodes.
16799   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16800   MatchContextClass matcher(DAG, TLI, N);
16801 
16802   bool CanReassociate =
16803       Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
16804 
16805   // Constant fold FMA.
16806   if (isa<ConstantFPSDNode>(N0) &&
16807       isa<ConstantFPSDNode>(N1) &&
16808       isa<ConstantFPSDNode>(N2)) {
16809     return matcher.getNode(ISD::FMA, DL, VT, N0, N1, N2);
16810   }
16811 
16812   // (-N0 * -N1) + N2 --> (N0 * N1) + N2
16813   TargetLowering::NegatibleCost CostN0 =
16814       TargetLowering::NegatibleCost::Expensive;
16815   TargetLowering::NegatibleCost CostN1 =
16816       TargetLowering::NegatibleCost::Expensive;
16817   SDValue NegN0 =
16818       TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
16819   if (NegN0) {
16820     HandleSDNode NegN0Handle(NegN0);
16821     SDValue NegN1 =
16822         TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
16823     if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
16824                   CostN1 == TargetLowering::NegatibleCost::Cheaper))
16825       return matcher.getNode(ISD::FMA, DL, VT, NegN0, NegN1, N2);
16826   }
16827 
16828   // FIXME: use fast math flags instead of Options.UnsafeFPMath
16829   if (Options.UnsafeFPMath) {
16830     if (N0CFP && N0CFP->isZero())
16831       return N2;
16832     if (N1CFP && N1CFP->isZero())
16833       return N2;
16834   }
16835 
16836   // FIXME: Support splat of constant.
16837   if (N0CFP && N0CFP->isExactlyValue(1.0))
16838     return matcher.getNode(ISD::FADD, SDLoc(N), VT, N1, N2);
16839   if (N1CFP && N1CFP->isExactlyValue(1.0))
16840     return matcher.getNode(ISD::FADD, SDLoc(N), VT, N0, N2);
16841 
16842   // Canonicalize (fma c, x, y) -> (fma x, c, y)
16843   if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
16844      !DAG.isConstantFPBuildVectorOrConstantFP(N1))
16845     return matcher.getNode(ISD::FMA, SDLoc(N), VT, N1, N0, N2);
16846 
16847   if (CanReassociate) {
16848     // (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2)
16849     if (matcher.match(N2, ISD::FMUL) && N0 == N2.getOperand(0) &&
16850         DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
16851         DAG.isConstantFPBuildVectorOrConstantFP(N2.getOperand(1))) {
16852       return matcher.getNode(
16853           ISD::FMUL, DL, VT, N0,
16854           matcher.getNode(ISD::FADD, DL, VT, N1, N2.getOperand(1)));
16855     }
16856 
16857     // (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y)
16858     if (matcher.match(N0, ISD::FMUL) &&
16859         DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
16860         DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) {
16861       return matcher.getNode(
16862           ISD::FMA, DL, VT, N0.getOperand(0),
16863           matcher.getNode(ISD::FMUL, DL, VT, N1, N0.getOperand(1)), N2);
16864     }
16865   }
16866 
16867   // (fma x, 1, y) -> (fadd x, y); (fma x, -1, y) -> (fadd (fneg x), y)
16868   // FIXME: Support splat of constant.
16869   if (N1CFP) {
16870     if (N1CFP->isExactlyValue(1.0))
16871       return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
16872 
16873     if (N1CFP->isExactlyValue(-1.0) &&
16874         (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))) {
16875       SDValue RHSNeg = matcher.getNode(ISD::FNEG, DL, VT, N0);
16876       AddToWorklist(RHSNeg.getNode());
16877       return matcher.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
16878     }
16879 
16880     // fma (fneg x), K, y -> fma x, -K, y
16881     if (matcher.match(N0, ISD::FNEG) &&
16882         (TLI.isOperationLegal(ISD::ConstantFP, VT) ||
16883          (N1.hasOneUse() &&
16884           !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
16885       return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
16886                              matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
16887     }
16888   }
16889 
16890   // FIXME: Support splat of constant.
16891   if (CanReassociate) {
16892     // (fma x, c, x) -> (fmul x, (c+1))
16893     if (N1CFP && N0 == N2) {
16894       return matcher.getNode(ISD::FMUL, DL, VT, N0,
16895                              matcher.getNode(ISD::FADD, DL, VT, N1,
16896                                              DAG.getConstantFP(1.0, DL, VT)));
16897     }
16898 
16899     // (fma x, c, (fneg x)) -> (fmul x, (c-1))
16900     if (N1CFP && matcher.match(N2, ISD::FNEG) && N2.getOperand(0) == N0) {
16901       return matcher.getNode(ISD::FMUL, DL, VT, N0,
16902                              matcher.getNode(ISD::FADD, DL, VT, N1,
16903                                              DAG.getConstantFP(-1.0, DL, VT)));
16904     }
16905   }
16906 
16907   // fold ((fma (fneg X), Y, (fneg Z)) -> fneg (fma X, Y, Z))
16908   // fold ((fma X, (fneg Y), (fneg Z)) -> fneg (fma X, Y, Z))
16909   if (!TLI.isFNegFree(VT))
16910     if (SDValue Neg = TLI.getCheaperNegatedExpression(
16911             SDValue(N, 0), DAG, LegalOperations, ForCodeSize))
16912       return matcher.getNode(ISD::FNEG, DL, VT, Neg);
16913   return SDValue();
16914 }
16915 
16916 SDValue DAGCombiner::visitFMAD(SDNode *N) {
16917   SDValue N0 = N->getOperand(0);
16918   SDValue N1 = N->getOperand(1);
16919   SDValue N2 = N->getOperand(2);
16920   EVT VT = N->getValueType(0);
16921   SDLoc DL(N);
16922 
16923   // Constant fold FMAD.
16924   if (isa<ConstantFPSDNode>(N0) && isa<ConstantFPSDNode>(N1) &&
16925       isa<ConstantFPSDNode>(N2))
16926     return DAG.getNode(ISD::FMAD, DL, VT, N0, N1, N2);
16927 
16928   return SDValue();
16929 }
16930 
16931 // Combine multiple FDIVs with the same divisor into multiple FMULs by the
16932 // reciprocal.
16933 // E.g., (a / D; b / D;) -> (recip = 1.0 / D; a * recip; b * recip)
16934 // Notice that this is not always beneficial. One reason is that different
16935 // targets may have different costs for FDIV and FMUL, so sometimes the cost
16936 // of two FDIVs may be lower than that of one FDIV and two FMULs. Another
16937 // reason is that the critical path grows from "one FDIV" to "one FDIV + one FMUL".
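// A minimal source-level sketch (assuming the divisions carry the 'arcp'
// fast-math flag or global unsafe math is enabled):
//   float f(float a, float b, float D) { return a / D + b / D; }
// is conceptually rewritten as:
//   float r = 1.0f / D; return a * r + b * r;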
16938 SDValue DAGCombiner::combineRepeatedFPDivisors(SDNode *N) {
16939   // TODO: Limit this transform based on optsize/minsize - it always creates at
16940   //       least 1 extra instruction. But the perf win may be substantial enough
16941   //       that only minsize should restrict this.
16942   bool UnsafeMath = DAG.getTarget().Options.UnsafeFPMath;
16943   const SDNodeFlags Flags = N->getFlags();
16944   if (LegalDAG || (!UnsafeMath && !Flags.hasAllowReciprocal()))
16945     return SDValue();
16946 
16947   // Skip if current node is a reciprocal/fneg-reciprocal.
16948   SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
16949   ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, /* AllowUndefs */ true);
16950   if (N0CFP && (N0CFP->isExactlyValue(1.0) || N0CFP->isExactlyValue(-1.0)))
16951     return SDValue();
16952 
16953   // Exit early if the target does not want this transform or if there can't
16954   // possibly be enough uses of the divisor to make the transform worthwhile.
16955   unsigned MinUses = TLI.combineRepeatedFPDivisors();
16956 
16957   // For splat vectors, scale the number of uses by the splat factor. If we can
16958   // convert the division into a scalar op, that will likely be much faster.
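  // E.g. (illustrative): with a <4 x float> splat divisor and MinUses == 4,
  // a single vector FDIV user counts as 4 scalar uses and meets the
  // threshold.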
16959   unsigned NumElts = 1;
16960   EVT VT = N->getValueType(0);
16961   if (VT.isVector() && DAG.isSplatValue(N1))
16962     NumElts = VT.getVectorMinNumElements();
16963 
16964   if (!MinUses || (N1->use_size() * NumElts) < MinUses)
16965     return SDValue();
16966 
16967   // Find all FDIV users of the same divisor.
16968   // Use a set because duplicates may be present in the user list.
16969   SetVector<SDNode *> Users;
16970   for (auto *U : N1->uses()) {
16971     if (U->getOpcode() == ISD::FDIV && U->getOperand(1) == N1) {
16972       // Skip X/sqrt(X) that has not been simplified to sqrt(X) yet.
16973       if (U->getOperand(1).getOpcode() == ISD::FSQRT &&
16974           U->getOperand(0) == U->getOperand(1).getOperand(0) &&
16975           U->getFlags().hasAllowReassociation() &&
16976           U->getFlags().hasNoSignedZeros())
16977         continue;
16978 
16979       // This division is eligible for optimization only if global unsafe math
16980       // is enabled or if this division allows reciprocal formation.
16981       if (UnsafeMath || U->getFlags().hasAllowReciprocal())
16982         Users.insert(U);
16983     }
16984   }
16985 
16986   // Now that we have the actual number of divisor uses, make sure it meets
16987   // the minimum threshold specified by the target.
16988   if ((Users.size() * NumElts) < MinUses)
16989     return SDValue();
16990 
16991   SDLoc DL(N);
16992   SDValue FPOne = DAG.getConstantFP(1.0, DL, VT);
16993   SDValue Reciprocal = DAG.getNode(ISD::FDIV, DL, VT, FPOne, N1, Flags);
16994 
16995   // Dividend / Divisor -> Dividend * Reciprocal
16996   for (auto *U : Users) {
16997     SDValue Dividend = U->getOperand(0);
16998     if (Dividend != FPOne) {
16999       SDValue NewNode = DAG.getNode(ISD::FMUL, SDLoc(U), VT, Dividend,
17000                                     Reciprocal, Flags);
17001       CombineTo(U, NewNode);
17002     } else if (U != Reciprocal.getNode()) {
17003       // In the absence of fast-math-flags, this user node is always the
17004       // same node as Reciprocal, but with FMF they may be different nodes.
17005       CombineTo(U, Reciprocal);
17006     }
17007   }
17008   return SDValue(N, 0);  // N was replaced.
17009 }
17010 
17011 SDValue DAGCombiner::visitFDIV(SDNode *N) {
17012   SDValue N0 = N->getOperand(0);
17013   SDValue N1 = N->getOperand(1);
17014   EVT VT = N->getValueType(0);
17015   SDLoc DL(N);
17016   const TargetOptions &Options = DAG.getTarget().Options;
17017   SDNodeFlags Flags = N->getFlags();
17018   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17019 
17020   if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
17021     return R;
17022 
17023   // fold (fdiv c1, c2) -> c1/c2
17024   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FDIV, DL, VT, {N0, N1}))
17025     return C;
17026 
17027   // fold vector ops
17028   if (VT.isVector())
17029     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
17030       return FoldedVOp;
17031 
17032   if (SDValue NewSel = foldBinOpIntoSelect(N))
17033     return NewSel;
17034 
17035   if (SDValue V = combineRepeatedFPDivisors(N))
17036     return V;
17037 
17038   if (Options.UnsafeFPMath || Flags.hasAllowReciprocal()) {
17039     // fold (fdiv X, c2) -> fmul X, 1/c2 if losing precision is acceptable.
17040     if (auto *N1CFP = dyn_cast<ConstantFPSDNode>(N1)) {
17041       // Compute the reciprocal 1.0 / c2.
17042       const APFloat &N1APF = N1CFP->getValueAPF();
17043       APFloat Recip(N1APF.getSemantics(), 1); // 1.0
17044       APFloat::opStatus st = Recip.divide(N1APF, APFloat::rmNearestTiesToEven);
17045       // Only do the transform if the reciprocal is a legal fp immediate that
17046       // isn't too nasty (e.g. NaN, denormal, ...).
17047       if ((st == APFloat::opOK || st == APFloat::opInexact) && // Not too nasty
17048           (!LegalOperations ||
17049            // FIXME: custom lowering of ConstantFP might fail (see e.g. ARM
17050            // backend)... we should handle this gracefully after Legalize.
17051            // TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT) ||
17052            TLI.isOperationLegal(ISD::ConstantFP, VT) ||
17053            TLI.isFPImmLegal(Recip, VT, ForCodeSize)))
17054         return DAG.getNode(ISD::FMUL, DL, VT, N0,
17055                            DAG.getConstantFP(Recip, DL, VT));
17056     }
17057 
17058     // If this FDIV is part of a reciprocal square root, it may be folded
17059     // into a target-specific square root estimate instruction.
17060     if (N1.getOpcode() == ISD::FSQRT) {
17061       if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0), Flags))
17062         return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
17063     } else if (N1.getOpcode() == ISD::FP_EXTEND &&
17064                N1.getOperand(0).getOpcode() == ISD::FSQRT) {
17065       if (SDValue RV =
17066               buildRsqrtEstimate(N1.getOperand(0).getOperand(0), Flags)) {
17067         RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(N1), VT, RV);
17068         AddToWorklist(RV.getNode());
17069         return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
17070       }
17071     } else if (N1.getOpcode() == ISD::FP_ROUND &&
17072                N1.getOperand(0).getOpcode() == ISD::FSQRT) {
17073       if (SDValue RV =
17074               buildRsqrtEstimate(N1.getOperand(0).getOperand(0), Flags)) {
17075         RV = DAG.getNode(ISD::FP_ROUND, SDLoc(N1), VT, RV, N1.getOperand(1));
17076         AddToWorklist(RV.getNode());
17077         return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
17078       }
17079     } else if (N1.getOpcode() == ISD::FMUL) {
17080       // Look through an FMUL. Even though this won't remove the FDIV directly,
17081       // it's still worthwhile to get rid of the FSQRT if possible.
17082       SDValue Sqrt, Y;
17083       if (N1.getOperand(0).getOpcode() == ISD::FSQRT) {
17084         Sqrt = N1.getOperand(0);
17085         Y = N1.getOperand(1);
17086       } else if (N1.getOperand(1).getOpcode() == ISD::FSQRT) {
17087         Sqrt = N1.getOperand(1);
17088         Y = N1.getOperand(0);
17089       }
17090       if (Sqrt.getNode()) {
17091         // If the other multiply operand is known positive, pull it into the
17092         // sqrt. That will eliminate the division if we convert to an estimate.
17093         if (Flags.hasAllowReassociation() && N1.hasOneUse() &&
17094             N1->getFlags().hasAllowReassociation() && Sqrt.hasOneUse()) {
17095           SDValue A;
17096           if (Y.getOpcode() == ISD::FABS && Y.hasOneUse())
17097             A = Y.getOperand(0);
17098           else if (Y == Sqrt.getOperand(0))
17099             A = Y;
17100           if (A) {
17101             // X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z)
17102             // X / (A * sqrt(A))       --> X / sqrt(A*A*A) --> X * rsqrt(A*A*A)
17103             SDValue AA = DAG.getNode(ISD::FMUL, DL, VT, A, A);
17104             SDValue AAZ =
17105                 DAG.getNode(ISD::FMUL, DL, VT, AA, Sqrt.getOperand(0));
17106             if (SDValue Rsqrt = buildRsqrtEstimate(AAZ, Flags))
17107               return DAG.getNode(ISD::FMUL, DL, VT, N0, Rsqrt);
17108 
17109             // Estimate creation failed. Clean up speculatively created nodes.
17110             recursivelyDeleteUnusedNodes(AAZ.getNode());
17111           }
17112         }
17113 
17114         // We found a FSQRT, so try to make this fold:
17115         // X / (Y * sqrt(Z)) -> X * (rsqrt(Z) / Y)
17116         if (SDValue Rsqrt = buildRsqrtEstimate(Sqrt.getOperand(0), Flags)) {
17117           SDValue Div = DAG.getNode(ISD::FDIV, SDLoc(N1), VT, Rsqrt, Y);
17118           AddToWorklist(Div.getNode());
17119           return DAG.getNode(ISD::FMUL, DL, VT, N0, Div);
17120         }
17121       }
17122     }
17123 
17124     // Fold into a reciprocal estimate and multiply instead of a real divide.
17125     if (Options.NoInfsFPMath || Flags.hasNoInfs())
17126       if (SDValue RV = BuildDivEstimate(N0, N1, Flags))
17127         return RV;
17128   }
17129 
17130   // Fold X/Sqrt(X) -> Sqrt(X)
17131   if ((Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros()) &&
17132       (Options.UnsafeFPMath || Flags.hasAllowReassociation()))
17133     if (N1.getOpcode() == ISD::FSQRT && N0 == N1.getOperand(0))
17134       return N1;
17135 
17136   // (fdiv (fneg X), (fneg Y)) -> (fdiv X, Y)
17137   TargetLowering::NegatibleCost CostN0 =
17138       TargetLowering::NegatibleCost::Expensive;
17139   TargetLowering::NegatibleCost CostN1 =
17140       TargetLowering::NegatibleCost::Expensive;
17141   SDValue NegN0 =
17142       TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
17143   if (NegN0) {
17144     HandleSDNode NegN0Handle(NegN0);
17145     SDValue NegN1 =
17146         TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
17147     if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
17148                   CostN1 == TargetLowering::NegatibleCost::Cheaper))
17149       return DAG.getNode(ISD::FDIV, SDLoc(N), VT, NegN0, NegN1);
17150   }
17151 
17152   if (SDValue R = combineFMulOrFDivWithIntPow2(N))
17153     return R;
17154 
17155   return SDValue();
17156 }
17157 
17158 SDValue DAGCombiner::visitFREM(SDNode *N) {
17159   SDValue N0 = N->getOperand(0);
17160   SDValue N1 = N->getOperand(1);
17161   EVT VT = N->getValueType(0);
17162   SDNodeFlags Flags = N->getFlags();
17163   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17164 
17165   if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
17166     return R;
17167 
17168   // fold (frem c1, c2) -> fmod(c1,c2)
17169   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FREM, SDLoc(N), VT, {N0, N1}))
17170     return C;
17171 
17172   if (SDValue NewSel = foldBinOpIntoSelect(N))
17173     return NewSel;
17174 
17175   return SDValue();
17176 }
17177 
17178 SDValue DAGCombiner::visitFSQRT(SDNode *N) {
17179   SDNodeFlags Flags = N->getFlags();
17180   const TargetOptions &Options = DAG.getTarget().Options;
17181 
17182   // Require 'ninf' flag since sqrt(+Inf) = +Inf, but the estimation goes as:
17183   // sqrt(+Inf) == rsqrt(+Inf) * +Inf = 0 * +Inf = NaN
17184   if (!Flags.hasApproximateFuncs() ||
17185       (!Options.NoInfsFPMath && !Flags.hasNoInfs()))
17186     return SDValue();
17187 
17188   SDValue N0 = N->getOperand(0);
17189   if (TLI.isFsqrtCheap(N0, DAG))
17190     return SDValue();
17191 
17192   // FSQRT nodes have flags that propagate to the created nodes.
17193   // TODO: If this is N0/sqrt(N0), and we reach this node before trying to
17194   //       transform the fdiv, we may produce a sub-optimal estimate sequence
17195   //       because the reciprocal calculation may not have to filter out a
17196   //       0.0 input.
17197   return buildSqrtEstimate(N0, Flags);
17198 }
17199 
17200 /// copysign(x, fp_extend(y)) -> copysign(x, y)
17201 /// copysign(x, fp_round(y)) -> copysign(x, y)
17202 /// Operands to the function are the types of X and Y, respectively.
17203 static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(EVT XTy, EVT YTy) {
17204   // Always fold no-op FP casts.
17205   if (XTy == YTy)
17206     return true;
17207 
17208   // Do not optimize out type conversion of f128 type yet.
17209   // For some targets like x86_64, configuration is changed to keep one f128
17210   // value in one SSE register, but instruction selection cannot handle
17211   // FCOPYSIGN on SSE registers yet.
17212   if (YTy == MVT::f128)
17213     return false;
17214 
17215   return !YTy.isVector() || EnableVectorFCopySignExtendRound;
17216 }
17217 
17218 static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(SDNode *N) {
17219   SDValue N1 = N->getOperand(1);
17220   if (N1.getOpcode() != ISD::FP_EXTEND &&
17221       N1.getOpcode() != ISD::FP_ROUND)
17222     return false;
17223   EVT N1VT = N1->getValueType(0);
17224   EVT N1Op0VT = N1->getOperand(0).getValueType();
17225   return CanCombineFCOPYSIGN_EXTEND_ROUND(N1VT, N1Op0VT);
17226 }
17227 
17228 SDValue DAGCombiner::visitFCOPYSIGN(SDNode *N) {
17229   SDValue N0 = N->getOperand(0);
17230   SDValue N1 = N->getOperand(1);
17231   EVT VT = N->getValueType(0);
17232 
17233   // fold (fcopysign c1, c2) -> fcopysign(c1,c2)
17234   if (SDValue C =
17235           DAG.FoldConstantArithmetic(ISD::FCOPYSIGN, SDLoc(N), VT, {N0, N1}))
17236     return C;
17237 
17238   if (ConstantFPSDNode *N1C = isConstOrConstSplatFP(N->getOperand(1))) {
17239     const APFloat &V = N1C->getValueAPF();
17240     // copysign(x, c1) -> fabs(x)       iff ispos(c1)
17241     // copysign(x, c1) -> fneg(fabs(x)) iff isneg(c1)
17242     if (!V.isNegative()) {
17243       if (!LegalOperations || TLI.isOperationLegal(ISD::FABS, VT))
17244         return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0);
17245     } else {
17246       if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))
17247         return DAG.getNode(ISD::FNEG, SDLoc(N), VT,
17248                            DAG.getNode(ISD::FABS, SDLoc(N0), VT, N0));
17249     }
17250   }
17251 
17252   // copysign(fabs(x), y) -> copysign(x, y)
17253   // copysign(fneg(x), y) -> copysign(x, y)
17254   // copysign(copysign(x,z), y) -> copysign(x, y)
17255   if (N0.getOpcode() == ISD::FABS || N0.getOpcode() == ISD::FNEG ||
17256       N0.getOpcode() == ISD::FCOPYSIGN)
17257     return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0.getOperand(0), N1);
17258 
17259   // copysign(x, abs(y)) -> abs(x)
17260   if (N1.getOpcode() == ISD::FABS)
17261     return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0);
17262 
17263   // copysign(x, copysign(y,z)) -> copysign(x, z)
17264   if (N1.getOpcode() == ISD::FCOPYSIGN)
17265     return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0, N1.getOperand(1));
17266 
17267   // copysign(x, fp_extend(y)) -> copysign(x, y)
17268   // copysign(x, fp_round(y)) -> copysign(x, y)
17269   if (CanCombineFCOPYSIGN_EXTEND_ROUND(N))
17270     return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0, N1.getOperand(0));
17271 
17272   return SDValue();
17273 }
17274 
17275 SDValue DAGCombiner::visitFPOW(SDNode *N) {
17276   ConstantFPSDNode *ExponentC = isConstOrConstSplatFP(N->getOperand(1));
17277   if (!ExponentC)
17278     return SDValue();
17279   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17280 
17281   // Try to convert x ** (1/3) into cube root.
17282   // TODO: Handle the various flavors of long double.
17283   // TODO: Since we're approximating, we don't need an exact 1/3 exponent.
17284   //       Some range near 1/3 should be fine.
17285   EVT VT = N->getValueType(0);
17286   if ((VT == MVT::f32 && ExponentC->getValueAPF().isExactlyValue(1.0f/3.0f)) ||
17287       (VT == MVT::f64 && ExponentC->getValueAPF().isExactlyValue(1.0/3.0))) {
17288     // pow(-0.0, 1/3) = +0.0; cbrt(-0.0) = -0.0.
17289     // pow(-inf, 1/3) = +inf; cbrt(-inf) = -inf.
17290     // pow(-val, 1/3) =  nan; cbrt(-val) = -cbrt(val).
17291     // For regular numbers, rounding may cause the results to differ.
17292     // Therefore, we require { nsz ninf nnan afn } for this transform.
17293     // TODO: We could select out the special cases if we don't have nsz/ninf.
17294     SDNodeFlags Flags = N->getFlags();
17295     if (!Flags.hasNoSignedZeros() || !Flags.hasNoInfs() || !Flags.hasNoNaNs() ||
17296         !Flags.hasApproximateFuncs())
17297       return SDValue();
17298 
17299     // Do not create a cbrt() libcall if the target does not have it, and do not
17300     // turn a pow that has lowering support into a cbrt() libcall.
17301     if (!DAG.getLibInfo().has(LibFunc_cbrt) ||
17302         (!DAG.getTargetLoweringInfo().isOperationExpand(ISD::FPOW, VT) &&
17303          DAG.getTargetLoweringInfo().isOperationExpand(ISD::FCBRT, VT)))
17304       return SDValue();
17305 
17306     return DAG.getNode(ISD::FCBRT, SDLoc(N), VT, N->getOperand(0));
17307   }
17308 
17309   // Try to convert x ** (1/4) and x ** (3/4) into square roots.
17310   // x ** (1/2) is canonicalized to sqrt, so we do not bother with that case.
17311   // TODO: This could be extended (using a target hook) to handle smaller
17312   // power-of-2 fractional exponents.
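  // Mathematically: x**(1/4) == sqrt(sqrt(x)), and
  // x**(3/4) == x**(1/2) * x**(1/4) == sqrt(x) * sqrt(sqrt(x)).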
17313   bool ExponentIs025 = ExponentC->getValueAPF().isExactlyValue(0.25);
17314   bool ExponentIs075 = ExponentC->getValueAPF().isExactlyValue(0.75);
17315   if (ExponentIs025 || ExponentIs075) {
17316     // pow(-0.0, 0.25) = +0.0; sqrt(sqrt(-0.0)) = -0.0.
17317     // pow(-inf, 0.25) = +inf; sqrt(sqrt(-inf)) =  NaN.
17318     // pow(-0.0, 0.75) = +0.0; sqrt(-0.0) * sqrt(sqrt(-0.0)) = +0.0.
17319     // pow(-inf, 0.75) = +inf; sqrt(-inf) * sqrt(sqrt(-inf)) =  NaN.
17320     // For regular numbers, rounding may cause the results to differ.
17321     // Therefore, we require { nsz ninf afn } for this transform.
17322     // TODO: We could select out the special cases if we don't have nsz/ninf.
17323     SDNodeFlags Flags = N->getFlags();
17324 
17325     // We only need no signed zeros for the 0.25 case.
17326     if ((!Flags.hasNoSignedZeros() && ExponentIs025) || !Flags.hasNoInfs() ||
17327         !Flags.hasApproximateFuncs())
17328       return SDValue();
17329 
17330     // Don't double the number of libcalls. We are trying to inline fast code.
17331     if (!DAG.getTargetLoweringInfo().isOperationLegalOrCustom(ISD::FSQRT, VT))
17332       return SDValue();
17333 
17334     // Assume that libcalls are the smallest code.
17335     // TODO: This restriction should probably be lifted for vectors.
17336     if (ForCodeSize)
17337       return SDValue();
17338 
17339     // pow(X, 0.25) --> sqrt(sqrt(X))
17340     SDLoc DL(N);
17341     SDValue Sqrt = DAG.getNode(ISD::FSQRT, DL, VT, N->getOperand(0));
17342     SDValue SqrtSqrt = DAG.getNode(ISD::FSQRT, DL, VT, Sqrt);
17343     if (ExponentIs025)
17344       return SqrtSqrt;
17345     // pow(X, 0.75) --> sqrt(X) * sqrt(sqrt(X))
17346     return DAG.getNode(ISD::FMUL, DL, VT, Sqrt, SqrtSqrt);
17347   }
17348 
17349   return SDValue();
17350 }
17351 
17352 static SDValue foldFPToIntToFP(SDNode *N, SelectionDAG &DAG,
17353                                const TargetLowering &TLI) {
17354   // We only do this if the target has legal ftrunc. Otherwise, we'd likely be
17355   // replacing casts with a libcall. We also must be allowed to ignore -0.0
17356   // because FTRUNC will return -0.0 for (-1.0, -0.0), but using integer
17357   // conversions would return +0.0.
17358   // FIXME: We should be able to use node-level FMF here.
17359   // TODO: If strict math, should we use FABS (+ range check for signed cast)?
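  // Worked example of the signed-zero hazard (illustrative): for x = 2.7,
  // sitofp(fptosi(x)) == 2.0 == ftrunc(2.7); but for x = -0.25,
  // sitofp(fptosi(x)) == +0.0 while ftrunc(-0.25) == -0.0, so the fold is
  // only safe when signed zeros can be ignored.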
17360   EVT VT = N->getValueType(0);
17361   if (!TLI.isOperationLegal(ISD::FTRUNC, VT) ||
17362       !DAG.getTarget().Options.NoSignedZerosFPMath)
17363     return SDValue();
17364 
17365   // fptosi/fptoui round towards zero, so converting from FP to integer and
17366   // back is the same as an 'ftrunc': [us]itofp (fpto[us]i X) --> ftrunc X
17367   SDValue N0 = N->getOperand(0);
17368   if (N->getOpcode() == ISD::SINT_TO_FP && N0.getOpcode() == ISD::FP_TO_SINT &&
17369       N0.getOperand(0).getValueType() == VT)
17370     return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0.getOperand(0));
17371 
17372   if (N->getOpcode() == ISD::UINT_TO_FP && N0.getOpcode() == ISD::FP_TO_UINT &&
17373       N0.getOperand(0).getValueType() == VT)
17374     return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0.getOperand(0));
17375 
17376   return SDValue();
17377 }
17378 
17379 SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) {
17380   SDValue N0 = N->getOperand(0);
17381   EVT VT = N->getValueType(0);
17382   EVT OpVT = N0.getValueType();
17383 
17384   // [us]itofp(undef) = 0, because the result value is bounded.
17385   if (N0.isUndef())
17386     return DAG.getConstantFP(0.0, SDLoc(N), VT);
17387 
17388   // fold (sint_to_fp c1) -> c1fp
17389   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
17390       // ...but only if the target supports immediate floating-point values
17391       (!LegalOperations ||
17392        TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
17393     return DAG.getNode(ISD::SINT_TO_FP, SDLoc(N), VT, N0);
17394 
17395   // If the input is a legal type, and SINT_TO_FP is not legal on this target,
17396   // but UINT_TO_FP is legal on this target, try to convert.
17397   if (!hasOperation(ISD::SINT_TO_FP, OpVT) &&
17398       hasOperation(ISD::UINT_TO_FP, OpVT)) {
17399     // If the sign bit is known to be zero, we can change this to UINT_TO_FP.
17400     if (DAG.SignBitIsZero(N0))
17401       return DAG.getNode(ISD::UINT_TO_FP, SDLoc(N), VT, N0);
17402   }
17403 
17404   // The next optimizations are desirable only if SELECT_CC can be lowered.
17405   // fold (sint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), -1.0, 0.0)
17406   if (N0.getOpcode() == ISD::SETCC && N0.getValueType() == MVT::i1 &&
17407       !VT.isVector() &&
17408       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
17409     SDLoc DL(N);
17410     return DAG.getSelect(DL, VT, N0, DAG.getConstantFP(-1.0, DL, VT),
17411                          DAG.getConstantFP(0.0, DL, VT));
17412   }
17413 
17414   // fold (sint_to_fp (zext (setcc x, y, cc))) ->
17415   //      (select (setcc x, y, cc), 1.0, 0.0)
17416   if (N0.getOpcode() == ISD::ZERO_EXTEND &&
17417       N0.getOperand(0).getOpcode() == ISD::SETCC && !VT.isVector() &&
17418       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
17419     SDLoc DL(N);
17420     return DAG.getSelect(DL, VT, N0.getOperand(0),
17421                          DAG.getConstantFP(1.0, DL, VT),
17422                          DAG.getConstantFP(0.0, DL, VT));
17423   }
17424 
17425   if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI))
17426     return FTrunc;
17427 
17428   return SDValue();
17429 }
17430 
17431 SDValue DAGCombiner::visitUINT_TO_FP(SDNode *N) {
17432   SDValue N0 = N->getOperand(0);
17433   EVT VT = N->getValueType(0);
17434   EVT OpVT = N0.getValueType();
17435 
17436   // [us]itofp(undef) = 0, because the result value is bounded.
17437   if (N0.isUndef())
17438     return DAG.getConstantFP(0.0, SDLoc(N), VT);
17439 
17440   // fold (uint_to_fp c1) -> c1fp
17441   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
17442       // ...but only if the target supports immediate floating-point values
17443       (!LegalOperations ||
17444        TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
17445     return DAG.getNode(ISD::UINT_TO_FP, SDLoc(N), VT, N0);
17446 
17447   // If the input is a legal type, and UINT_TO_FP is not legal on this target,
17448   // but SINT_TO_FP is legal on this target, try to convert.
17449   if (!hasOperation(ISD::UINT_TO_FP, OpVT) &&
17450       hasOperation(ISD::SINT_TO_FP, OpVT)) {
17451     // If the sign bit is known to be zero, we can change this to SINT_TO_FP.
17452     if (DAG.SignBitIsZero(N0))
17453       return DAG.getNode(ISD::SINT_TO_FP, SDLoc(N), VT, N0);
17454   }
17455 
17456   // fold (uint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), 1.0, 0.0)
17457   if (N0.getOpcode() == ISD::SETCC && !VT.isVector() &&
17458       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
17459     SDLoc DL(N);
17460     return DAG.getSelect(DL, VT, N0, DAG.getConstantFP(1.0, DL, VT),
17461                          DAG.getConstantFP(0.0, DL, VT));
17462   }
17463 
17464   if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI))
17465     return FTrunc;
17466 
17467   return SDValue();
17468 }
17469 
17470 // Fold (fp_to_{s/u}int ({s/u}int_to_fp x)) -> zext x, sext x, trunc x, or x
17471 static SDValue FoldIntToFPToInt(SDNode *N, SelectionDAG &DAG) {
17472   SDValue N0 = N->getOperand(0);
17473   EVT VT = N->getValueType(0);
17474 
17475   if (N0.getOpcode() != ISD::UINT_TO_FP && N0.getOpcode() != ISD::SINT_TO_FP)
17476     return SDValue();
17477 
17478   SDValue Src = N0.getOperand(0);
17479   EVT SrcVT = Src.getValueType();
17480   bool IsInputSigned = N0.getOpcode() == ISD::SINT_TO_FP;
17481   bool IsOutputSigned = N->getOpcode() == ISD::FP_TO_SINT;
17482 
17483   // We can safely assume the conversion won't overflow the output range,
17484   // because (for example) (uint8_t)18293.f is undefined behavior.
17485 
17486   // Since we can assume the conversion won't overflow, our decision as to
17487   // whether the input will fit in the float should depend on the minimum
17488   // of the input range and output range.
17489 
17490   // This means this is also safe for a signed input and unsigned output, since
17491   // a negative input would lead to undefined behavior.
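  // Worked example (illustrative): i16 -> f32 -> i32 folds because f32
  // carries 24 bits of precision >= 16, so the round trip is exact and
  // becomes sext/zext; i32 -> f32 -> i32 is skipped because, e.g.,
  // 2147483647 (INT32_MAX) is not exactly representable in f32.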
17492   unsigned InputSize = (int)SrcVT.getScalarSizeInBits() - IsInputSigned;
17493   unsigned OutputSize = (int)VT.getScalarSizeInBits();
17494   unsigned ActualSize = std::min(InputSize, OutputSize);
17495   const fltSemantics &sem = DAG.EVTToAPFloatSemantics(N0.getValueType());
17496 
17497   // We can only fold away the float conversion if the input range can be
17498   // represented exactly in the float range.
17499   if (APFloat::semanticsPrecision(sem) >= ActualSize) {
17500     if (VT.getScalarSizeInBits() > SrcVT.getScalarSizeInBits()) {
17501       unsigned ExtOp = IsInputSigned && IsOutputSigned ? ISD::SIGN_EXTEND
17502                                                        : ISD::ZERO_EXTEND;
17503       return DAG.getNode(ExtOp, SDLoc(N), VT, Src);
17504     }
17505     if (VT.getScalarSizeInBits() < SrcVT.getScalarSizeInBits())
17506       return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Src);
17507     return DAG.getBitcast(VT, Src);
17508   }
17509   return SDValue();
17510 }
17511 
17512 SDValue DAGCombiner::visitFP_TO_SINT(SDNode *N) {
17513   SDValue N0 = N->getOperand(0);
17514   EVT VT = N->getValueType(0);
17515 
17516   // fold (fp_to_sint undef) -> undef
17517   if (N0.isUndef())
17518     return DAG.getUNDEF(VT);
17519 
17520   // fold (fp_to_sint c1fp) -> c1
17521   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
17522     return DAG.getNode(ISD::FP_TO_SINT, SDLoc(N), VT, N0);
17523 
17524   return FoldIntToFPToInt(N, DAG);
17525 }
17526 
17527 SDValue DAGCombiner::visitFP_TO_UINT(SDNode *N) {
17528   SDValue N0 = N->getOperand(0);
17529   EVT VT = N->getValueType(0);
17530 
17531   // fold (fp_to_uint undef) -> undef
17532   if (N0.isUndef())
17533     return DAG.getUNDEF(VT);
17534 
17535   // fold (fp_to_uint c1fp) -> c1
17536   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
17537     return DAG.getNode(ISD::FP_TO_UINT, SDLoc(N), VT, N0);
17538 
17539   return FoldIntToFPToInt(N, DAG);
17540 }
17541 
17542 SDValue DAGCombiner::visitXRINT(SDNode *N) {
17543   SDValue N0 = N->getOperand(0);
17544   EVT VT = N->getValueType(0);
17545 
17546   // fold (lrint|llrint undef) -> undef
17547   if (N0.isUndef())
17548     return DAG.getUNDEF(VT);
17549 
17550   // fold (lrint|llrint c1fp) -> c1
17551   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
17552     return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N0);
17553 
17554   return SDValue();
17555 }
17556 
17557 SDValue DAGCombiner::visitFP_ROUND(SDNode *N) {
17558   SDValue N0 = N->getOperand(0);
17559   SDValue N1 = N->getOperand(1);
17560   EVT VT = N->getValueType(0);
17561 
17562   // fold (fp_round c1fp) -> c1fp
17563   if (SDValue C =
17564           DAG.FoldConstantArithmetic(ISD::FP_ROUND, SDLoc(N), VT, {N0, N1}))
17565     return C;
17566 
17567   // fold (fp_round (fp_extend x)) -> x
17568   if (N0.getOpcode() == ISD::FP_EXTEND && VT == N0.getOperand(0).getValueType())
17569     return N0.getOperand(0);
17570 
17571   // fold (fp_round (fp_round x)) -> (fp_round x)
17572   if (N0.getOpcode() == ISD::FP_ROUND) {
17573     const bool NIsTrunc = N->getConstantOperandVal(1) == 1;
17574     const bool N0IsTrunc = N0.getConstantOperandVal(1) == 1;
17575 
17576     // Avoid folding legal fp_rounds into non-legal ones.
17577     if (!hasOperation(ISD::FP_ROUND, VT))
17578       return SDValue();
17579 
17580     // Skip this folding if it results in an fp_round from f80 to f16.
17581     //
17582     // f80 to f16 always generates an expensive (and as yet, unimplemented)
17583     // libcall to __truncxfhf2 instead of selecting native f16 conversion
17584     // instructions from f32 or f64.  Moreover, the first (value-preserving)
17585     // fp_round from f80 to either f32 or f64 may become a NOP in platforms like
17586     // x86.
17587     if (N0.getOperand(0).getValueType() == MVT::f80 && VT == MVT::f16)
17588       return SDValue();
17589 
17590     // If the first fp_round isn't a value-preserving truncation, it might
17591     // introduce a tie in the second fp_round, which wouldn't occur in the
17592     // single-step fp_round we want to fold to.
17593     // In other words, rounding twice isn't the same as rounding once.
17594     // Also, this is a value-preserving truncation iff both fp_rounds are.
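    // Illustrative decimal analogy (not from the source): rounding 1.49 to
    // one decimal place gives 1.5, and rounding 1.5 to an integer gives 2,
    // whereas rounding 1.49 to an integer in one step gives 1.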
17595     if (DAG.getTarget().Options.UnsafeFPMath || N0IsTrunc) {
17596       SDLoc DL(N);
17597       return DAG.getNode(
17598           ISD::FP_ROUND, DL, VT, N0.getOperand(0),
17599           DAG.getIntPtrConstant(NIsTrunc && N0IsTrunc, DL, /*isTarget=*/true));
17600     }
17601   }
17602 
17603   // fold (fp_round (copysign X, Y)) -> (copysign (fp_round X), Y)
17604   // Note: From a legality perspective, this is a two-step transform.  First,
17605   // we duplicate the fp_round to the arguments of the copysign, then we
17606   // eliminate the fp_round on Y.  The second step requires an additional
17607   // predicate to match the implementation above.
17608   if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse() &&
17609       CanCombineFCOPYSIGN_EXTEND_ROUND(VT,
17610                                        N0.getValueType())) {
17611     SDValue Tmp = DAG.getNode(ISD::FP_ROUND, SDLoc(N0), VT,
17612                               N0.getOperand(0), N1);
17613     AddToWorklist(Tmp.getNode());
17614     return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT,
17615                        Tmp, N0.getOperand(1));
17616   }
17617 
17618   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
17619     return NewVSel;
17620 
17621   return SDValue();
17622 }
17623 
17624 SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) {
17625   SDValue N0 = N->getOperand(0);
17626   EVT VT = N->getValueType(0);
17627 
17628   if (VT.isVector())
17629     if (SDValue FoldedVOp = SimplifyVCastOp(N, SDLoc(N)))
17630       return FoldedVOp;
17631 
17632   // If this is fp_round(fpextend), don't fold it, allow ourselves to be folded.
17633   if (N->hasOneUse() &&
17634       N->use_begin()->getOpcode() == ISD::FP_ROUND)
17635     return SDValue();
17636 
17637   // fold (fp_extend c1fp) -> c1fp
17638   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
17639     return DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, N0);
17640 
17641   // fold (fp_extend (fp16_to_fp op)) -> (fp16_to_fp op)
17642   if (N0.getOpcode() == ISD::FP16_TO_FP &&
17643       TLI.getOperationAction(ISD::FP16_TO_FP, VT) == TargetLowering::Legal)
17644     return DAG.getNode(ISD::FP16_TO_FP, SDLoc(N), VT, N0.getOperand(0));
17645 
17646   // Turn fp_extend(fp_round(X, 1)) -> X since the fp_round doesn't affect the
17647   // value of X.
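  // (Explanatory note, not from the source: the constant second operand of
  // fp_round is the "trunc" flag; when it is 1 the truncation is known not to
  // change the value, which is what licenses looking through it here.)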
17648   if (N0.getOpcode() == ISD::FP_ROUND
17649       && N0.getConstantOperandVal(1) == 1) {
17650     SDValue In = N0.getOperand(0);
17651     if (In.getValueType() == VT) return In;
17652     if (VT.bitsLT(In.getValueType()))
17653       return DAG.getNode(ISD::FP_ROUND, SDLoc(N), VT,
17654                          In, N0.getOperand(1));
17655     return DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, In);
17656   }
17657 
17658   // fold (fpext (load x)) -> (fpext (fptrunc (extload x)))
17659   if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
17660       TLI.isLoadExtLegalOrCustom(ISD::EXTLOAD, VT, N0.getValueType())) {
17661     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
17662     SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, SDLoc(N), VT,
17663                                      LN0->getChain(),
17664                                      LN0->getBasePtr(), N0.getValueType(),
17665                                      LN0->getMemOperand());
17666     CombineTo(N, ExtLoad);
17667     CombineTo(
17668         N0.getNode(),
17669         DAG.getNode(ISD::FP_ROUND, SDLoc(N0), N0.getValueType(), ExtLoad,
17670                     DAG.getIntPtrConstant(1, SDLoc(N0), /*isTarget=*/true)),
17671         ExtLoad.getValue(1));
17672     return SDValue(N, 0);   // Return N so it doesn't get rechecked!
17673   }
17674 
17675   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
17676     return NewVSel;
17677 
17678   return SDValue();
17679 }
17680 
17681 SDValue DAGCombiner::visitFCEIL(SDNode *N) {
17682   SDValue N0 = N->getOperand(0);
17683   EVT VT = N->getValueType(0);
17684 
17685   // fold (fceil c1) -> fceil(c1)
17686   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
17687     return DAG.getNode(ISD::FCEIL, SDLoc(N), VT, N0);
17688 
17689   return SDValue();
17690 }
17691 
17692 SDValue DAGCombiner::visitFTRUNC(SDNode *N) {
17693   SDValue N0 = N->getOperand(0);
17694   EVT VT = N->getValueType(0);
17695 
17696   // fold (ftrunc c1) -> ftrunc(c1)
17697   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
17698     return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0);
17699 
17700   // fold ftrunc (known rounded int x) -> x
17701   // ftrunc is part of the fptosi/fptoui expansion on some targets, so this is
17702   // likely to be generated when extracting an integer from a rounded FP value.
17703   switch (N0.getOpcode()) {
17704   default: break;
17705   case ISD::FRINT:
17706   case ISD::FTRUNC:
17707   case ISD::FNEARBYINT:
17708   case ISD::FROUNDEVEN:
17709   case ISD::FFLOOR:
17710   case ISD::FCEIL:
17711     return N0;
17712   }
17713 
17714   return SDValue();
17715 }
17716 
17717 SDValue DAGCombiner::visitFFREXP(SDNode *N) {
17718   SDValue N0 = N->getOperand(0);
17719 
17720   // fold (ffrexp c1) -> ffrexp(c1)
17721   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
17722     return DAG.getNode(ISD::FFREXP, SDLoc(N), N->getVTList(), N0);
17723   return SDValue();
17724 }
17725 
17726 SDValue DAGCombiner::visitFFLOOR(SDNode *N) {
17727   SDValue N0 = N->getOperand(0);
17728   EVT VT = N->getValueType(0);
17729 
17730   // fold (ffloor c1) -> ffloor(c1)
17731   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
17732     return DAG.getNode(ISD::FFLOOR, SDLoc(N), VT, N0);
17733 
17734   return SDValue();
17735 }
17736 
17737 SDValue DAGCombiner::visitFNEG(SDNode *N) {
17738   SDValue N0 = N->getOperand(0);
17739   EVT VT = N->getValueType(0);
17740   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17741 
17742   // Constant fold FNEG.
17743   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
17744     return DAG.getNode(ISD::FNEG, SDLoc(N), VT, N0);
17745 
17746   if (SDValue NegN0 =
17747           TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize))
17748     return NegN0;
17749 
17750   // -(X-Y) -> (Y-X) is unsafe because when X==Y, -0.0 != +0.0
17751   // FIXME: This is duplicated in getNegatibleCost, but getNegatibleCost doesn't
17752   // know it was called from a context with a nsz flag if the input fsub does
17753   // not.
17754   if (N0.getOpcode() == ISD::FSUB &&
17755       (DAG.getTarget().Options.NoSignedZerosFPMath ||
17756        N->getFlags().hasNoSignedZeros()) && N0.hasOneUse()) {
17757     return DAG.getNode(ISD::FSUB, SDLoc(N), VT, N0.getOperand(1),
17758                        N0.getOperand(0));
17759   }
17760 
17761   if (SDValue Cast = foldSignChangeInBitcast(N))
17762     return Cast;
17763 
17764   return SDValue();
17765 }
17766 
17767 SDValue DAGCombiner::visitFMinMax(SDNode *N) {
17768   SDValue N0 = N->getOperand(0);
17769   SDValue N1 = N->getOperand(1);
17770   EVT VT = N->getValueType(0);
17771   const SDNodeFlags Flags = N->getFlags();
17772   unsigned Opc = N->getOpcode();
17773   bool PropagatesNaN = Opc == ISD::FMINIMUM || Opc == ISD::FMAXIMUM;
17774   bool IsMin = Opc == ISD::FMINNUM || Opc == ISD::FMINIMUM;
17775   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17776 
17777   // Constant fold.
17778   if (SDValue C = DAG.FoldConstantArithmetic(Opc, SDLoc(N), VT, {N0, N1}))
17779     return C;
17780 
17781   // Canonicalize to constant on RHS.
17782   if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
17783       !DAG.isConstantFPBuildVectorOrConstantFP(N1))
17784     return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0);
17785 
17786   if (const ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1)) {
17787     const APFloat &AF = N1CFP->getValueAPF();
17788 
17789     // minnum(X, nan) -> X
17790     // maxnum(X, nan) -> X
17791     // minimum(X, nan) -> nan
17792     // maximum(X, nan) -> nan
17793     if (AF.isNaN())
17794       return PropagatesNaN ? N->getOperand(1) : N->getOperand(0);
17795 
17796     // In the following folds, inf can be replaced with the largest finite
17797     // float, if the ninf flag is set.
17798     if (AF.isInfinity() || (Flags.hasNoInfs() && AF.isLargest())) {
17799       // minnum(X, -inf) -> -inf
17800       // maxnum(X, +inf) -> +inf
17801       // minimum(X, -inf) -> -inf if nnan
17802       // maximum(X, +inf) -> +inf if nnan
17803       if (IsMin == AF.isNegative() && (!PropagatesNaN || Flags.hasNoNaNs()))
17804         return N->getOperand(1);
17805 
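      // Why the guards differ (explanatory note, not from the source):
      // minnum(NaN, +inf) is +inf, so minnum(X, +inf) -> X is only sound when
      // X cannot be NaN, while the NaN-propagating minimum(NaN, +inf) is NaN,
      // which already equals X, so that form folds unconditionally.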
17806       // minnum(X, +inf) -> X if nnan
17807       // maxnum(X, -inf) -> X if nnan
17808       // minimum(X, +inf) -> X
17809       // maximum(X, -inf) -> X
17810       if (IsMin != AF.isNegative() && (PropagatesNaN || Flags.hasNoNaNs()))
17811         return N->getOperand(0);
17812     }
17813   }
17814 
17815   if (SDValue SD = reassociateReduction(
17816           PropagatesNaN
17817               ? (IsMin ? ISD::VECREDUCE_FMINIMUM : ISD::VECREDUCE_FMAXIMUM)
17818               : (IsMin ? ISD::VECREDUCE_FMIN : ISD::VECREDUCE_FMAX),
17819           Opc, SDLoc(N), VT, N0, N1, Flags))
17820     return SD;
17821 
17822   return SDValue();
17823 }
17824 
17825 SDValue DAGCombiner::visitFABS(SDNode *N) {
17826   SDValue N0 = N->getOperand(0);
17827   EVT VT = N->getValueType(0);
17828 
17829   // fold (fabs c1) -> fabs(c1)
17830   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
17831     return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0);
17832 
17833   // fold (fabs (fabs x)) -> (fabs x)
17834   if (N0.getOpcode() == ISD::FABS)
17835     return N->getOperand(0);
17836 
17837   // fold (fabs (fneg x)) -> (fabs x)
17838   // fold (fabs (fcopysign x, y)) -> (fabs x)
17839   if (N0.getOpcode() == ISD::FNEG || N0.getOpcode() == ISD::FCOPYSIGN)
17840     return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0.getOperand(0));
17841 
17842   if (SDValue Cast = foldSignChangeInBitcast(N))
17843     return Cast;
17844 
17845   return SDValue();
17846 }
17847 
17848 SDValue DAGCombiner::visitBRCOND(SDNode *N) {
17849   SDValue Chain = N->getOperand(0);
17850   SDValue N1 = N->getOperand(1);
17851   SDValue N2 = N->getOperand(2);
17852 
17853   // BRCOND(FREEZE(cond)) is equivalent to BRCOND(cond) (both are
17854   // nondeterministic jumps).
17855   if (N1->getOpcode() == ISD::FREEZE && N1.hasOneUse()) {
17856     return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other, Chain,
17857                        N1->getOperand(0), N2);
17858   }
17859 
17860   // Variant of the previous fold where there is a SETCC in between:
17861   //   BRCOND(SETCC(FREEZE(X), CONST, Cond))
17862   // =>
17863   //   BRCOND(FREEZE(SETCC(X, CONST, Cond)))
17864   // =>
17865   //   BRCOND(SETCC(X, CONST, Cond))
17866   // This is correct if FREEZE(X) has one use and SETCC(FREEZE(X), CONST, Cond)
17867   // isn't equivalent to true or false.
17868   // For example, SETCC(FREEZE(X), -128, SETULT) cannot be folded to
17869   // FREEZE(SETCC(X, -128, SETULT)) because X can be poison.
17870   if (N1->getOpcode() == ISD::SETCC && N1.hasOneUse()) {
17871     SDValue S0 = N1->getOperand(0), S1 = N1->getOperand(1);
17872     ISD::CondCode Cond = cast<CondCodeSDNode>(N1->getOperand(2))->get();
17873     ConstantSDNode *S0C = dyn_cast<ConstantSDNode>(S0);
17874     ConstantSDNode *S1C = dyn_cast<ConstantSDNode>(S1);
17875     bool Updated = false;
17876 
17877     // Is 'X Cond C' always true or false?
17878     auto IsAlwaysTrueOrFalse = [](ISD::CondCode Cond, ConstantSDNode *C) {
17879       bool False = (Cond == ISD::SETULT && C->isZero()) ||
17880                    (Cond == ISD::SETLT && C->isMinSignedValue()) ||
17881                    (Cond == ISD::SETUGT && C->isAllOnes()) ||
17882                    (Cond == ISD::SETGT && C->isMaxSignedValue());
17883       bool True = (Cond == ISD::SETULE && C->isAllOnes()) ||
17884                   (Cond == ISD::SETLE && C->isMaxSignedValue()) ||
17885                   (Cond == ISD::SETUGE && C->isZero()) ||
17886                   (Cond == ISD::SETGE && C->isMinSignedValue());
17887       return True || False;
17888     };
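    // E.g. (illustrative): unsigned X < 0 (SETULT with C == 0) is always
    // false, and unsigned X >= 0 (SETUGE with C == 0) is always true.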
17889 
17890     if (S0->getOpcode() == ISD::FREEZE && S0.hasOneUse() && S1C) {
17891       if (!IsAlwaysTrueOrFalse(Cond, S1C)) {
17892         S0 = S0->getOperand(0);
17893         Updated = true;
17894       }
17895     }
17896     if (S1->getOpcode() == ISD::FREEZE && S1.hasOneUse() && S0C) {
17897       if (!IsAlwaysTrueOrFalse(ISD::getSetCCSwappedOperands(Cond), S0C)) {
17898         S1 = S1->getOperand(0);
17899         Updated = true;
17900       }
17901     }
17902 
17903     if (Updated)
17904       return DAG.getNode(
17905           ISD::BRCOND, SDLoc(N), MVT::Other, Chain,
17906           DAG.getSetCC(SDLoc(N1), N1->getValueType(0), S0, S1, Cond), N2);
17907   }
17908 
17909   // If N is a constant we could fold this into a fallthrough or unconditional
17910   // branch. However that doesn't happen very often in normal code, because
17911   // Instcombine/SimplifyCFG should have handled the available opportunities.
17912   // If we did this folding here, it would be necessary to update the
17913   // MachineBasicBlock CFG, which is awkward.
17914 
17915   // fold a brcond with a setcc condition into a BR_CC node if BR_CC is legal
17916   // on the target.
17917   if (N1.getOpcode() == ISD::SETCC &&
17918       TLI.isOperationLegalOrCustom(ISD::BR_CC,
17919                                    N1.getOperand(0).getValueType())) {
17920     return DAG.getNode(ISD::BR_CC, SDLoc(N), MVT::Other,
17921                        Chain, N1.getOperand(2),
17922                        N1.getOperand(0), N1.getOperand(1), N2);
17923   }
17924 
17925   if (N1.hasOneUse()) {
17926     // rebuildSetCC calls visitXor which may change the Chain when there is a
17927     // STRICT_FSETCC/STRICT_FSETCCS involved. Use a handle to track changes.
17928     HandleSDNode ChainHandle(Chain);
17929     if (SDValue NewN1 = rebuildSetCC(N1))
17930       return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other,
17931                          ChainHandle.getValue(), NewN1, N2);
17932   }
17933 
17934   return SDValue();
17935 }
17936 
17937 SDValue DAGCombiner::rebuildSetCC(SDValue N) {
17938   if (N.getOpcode() == ISD::SRL ||
17939       (N.getOpcode() == ISD::TRUNCATE &&
17940        (N.getOperand(0).hasOneUse() &&
17941         N.getOperand(0).getOpcode() == ISD::SRL))) {
17942     // Look past the truncate.
17943     if (N.getOpcode() == ISD::TRUNCATE)
17944       N = N.getOperand(0);
17945 
17946     // Match this pattern so that we can generate simpler code:
17947     //
17948     //   %a = ...
17949     //   %b = and i32 %a, 2
17950     //   %c = srl i32 %b, 1
17951     //   brcond i32 %c ...
17952     //
17953     // into
17954     //
17955     //   %a = ...
17956     //   %b = and i32 %a, 2
17957     //   %c = setcc eq %b, 0
17958     //   brcond %c ...
17959     //
17960     // This applies only when the AND constant value has one bit set and the
17961     // SRL constant is equal to the log2 of the AND constant. The back-end is
17962     // smart enough to convert the result into a TEST/JMP sequence.
17963     SDValue Op0 = N.getOperand(0);
17964     SDValue Op1 = N.getOperand(1);
17965 
17966     if (Op0.getOpcode() == ISD::AND && Op1.getOpcode() == ISD::Constant) {
17967       SDValue AndOp1 = Op0.getOperand(1);
17968 
17969       if (AndOp1.getOpcode() == ISD::Constant) {
17970         const APInt &AndConst = AndOp1->getAsAPIntVal();
17971 
17972         if (AndConst.isPowerOf2() &&
17973             Op1->getAsAPIntVal() == AndConst.logBase2()) {
17974           SDLoc DL(N);
17975           return DAG.getSetCC(DL, getSetCCResultType(Op0.getValueType()),
17976                               Op0, DAG.getConstant(0, DL, Op0.getValueType()),
17977                               ISD::SETNE);
17978         }
17979       }
17980     }
17981   }
17982 
17983   // Transform (brcond (xor x, y)) -> (brcond (setcc, x, y, ne))
17984   // Transform (brcond (xor (xor x, y), -1)) -> (brcond (setcc, x, y, eq))
17985   if (N.getOpcode() == ISD::XOR) {
17986     // Because we may call this on a speculatively constructed
17987     // SimplifiedSetCC node, we need to simplify this node first.
17988     // Ideally this should be folded into SimplifySetCC and not
17989     // here. For now, grab a handle to N so we don't lose it from
17990     // replacements internal to the visit.
17991     HandleSDNode XORHandle(N);
17992     while (N.getOpcode() == ISD::XOR) {
17993       SDValue Tmp = visitXOR(N.getNode());
17994       // No simplification done.
17995       if (!Tmp.getNode())
17996         break;
17997       // Returning N is a form of in-visit replacement that may have
17998       // invalidated N. Grab the value from the handle.
17999       if (Tmp.getNode() == N.getNode())
18000         N = XORHandle.getValue();
18001       else // Node simplified. Try simplifying again.
18002         N = Tmp;
18003     }
18004 
18005     if (N.getOpcode() != ISD::XOR)
18006       return N;
18007 
18008     SDValue Op0 = N->getOperand(0);
18009     SDValue Op1 = N->getOperand(1);
18010 
18011     if (Op0.getOpcode() != ISD::SETCC && Op1.getOpcode() != ISD::SETCC) {
18012       bool Equal = false;
18013       // (brcond (xor (xor x, y), -1)) -> (brcond (setcc x, y, eq))
18014       if (isBitwiseNot(N) && Op0.hasOneUse() && Op0.getOpcode() == ISD::XOR &&
18015           Op0.getValueType() == MVT::i1) {
18016         N = Op0;
18017         Op0 = N->getOperand(0);
18018         Op1 = N->getOperand(1);
18019         Equal = true;
18020       }
18021 
18022       EVT SetCCVT = N.getValueType();
18023       if (LegalTypes)
18024         SetCCVT = getSetCCResultType(SetCCVT);
18025       // Replace the uses of XOR with SETCC
18026       return DAG.getSetCC(SDLoc(N), SetCCVT, Op0, Op1,
18027                           Equal ? ISD::SETEQ : ISD::SETNE);
18028     }
18029   }
18030 
18031   return SDValue();
18032 }
18033 
18034 // Operand List for BR_CC: Chain, CondCC, CondLHS, CondRHS, DestBB.
18035 //
18036 SDValue DAGCombiner::visitBR_CC(SDNode *N) {
18037   CondCodeSDNode *CC = cast<CondCodeSDNode>(N->getOperand(1));
18038   SDValue CondLHS = N->getOperand(2), CondRHS = N->getOperand(3);
18039 
18040   // If N is a constant we could fold this into a fallthrough or unconditional
18041   // branch. However that doesn't happen very often in normal code, because
18042   // Instcombine/SimplifyCFG should have handled the available opportunities.
18043   // If we did this folding here, it would be necessary to update the
18044   // MachineBasicBlock CFG, which is awkward.
18045 
18046   // Use SimplifySetCC to simplify SETCC's.
18047   SDValue Simp = SimplifySetCC(getSetCCResultType(CondLHS.getValueType()),
18048                                CondLHS, CondRHS, CC->get(), SDLoc(N),
18049                                false);
18050   if (Simp.getNode()) AddToWorklist(Simp.getNode());
18051 
18052   // fold to a simpler setcc
18053   if (Simp.getNode() && Simp.getOpcode() == ISD::SETCC)
18054     return DAG.getNode(ISD::BR_CC, SDLoc(N), MVT::Other,
18055                        N->getOperand(0), Simp.getOperand(2),
18056                        Simp.getOperand(0), Simp.getOperand(1),
18057                        N->getOperand(4));
18058 
18059   return SDValue();
18060 }
18061 
18062 static bool getCombineLoadStoreParts(SDNode *N, unsigned Inc, unsigned Dec,
18063                                      bool &IsLoad, bool &IsMasked, SDValue &Ptr,
18064                                      const TargetLowering &TLI) {
18065   if (LoadSDNode *LD = dyn_cast<LoadSDNode>(N)) {
18066     if (LD->isIndexed())
18067       return false;
18068     EVT VT = LD->getMemoryVT();
18069     if (!TLI.isIndexedLoadLegal(Inc, VT) && !TLI.isIndexedLoadLegal(Dec, VT))
18070       return false;
18071     Ptr = LD->getBasePtr();
18072   } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(N)) {
18073     if (ST->isIndexed())
18074       return false;
18075     EVT VT = ST->getMemoryVT();
18076     if (!TLI.isIndexedStoreLegal(Inc, VT) && !TLI.isIndexedStoreLegal(Dec, VT))
18077       return false;
18078     Ptr = ST->getBasePtr();
18079     IsLoad = false;
18080   } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(N)) {
18081     if (LD->isIndexed())
18082       return false;
18083     EVT VT = LD->getMemoryVT();
18084     if (!TLI.isIndexedMaskedLoadLegal(Inc, VT) &&
18085         !TLI.isIndexedMaskedLoadLegal(Dec, VT))
18086       return false;
18087     Ptr = LD->getBasePtr();
18088     IsMasked = true;
18089   } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(N)) {
18090     if (ST->isIndexed())
18091       return false;
18092     EVT VT = ST->getMemoryVT();
18093     if (!TLI.isIndexedMaskedStoreLegal(Inc, VT) &&
18094         !TLI.isIndexedMaskedStoreLegal(Dec, VT))
18095       return false;
18096     Ptr = ST->getBasePtr();
18097     IsLoad = false;
18098     IsMasked = true;
18099   } else {
18100     return false;
18101   }
18102   return true;
18103 }
18104 
18105 /// Try turning a load/store into a pre-indexed load/store when the base
18106 /// pointer is an add or subtract and it has other uses besides the load/store.
18107 /// After the transformation, the new indexed load/store has effectively folded
18108 /// the add/subtract in and all of its other uses are redirected to the
18109 /// new load/store.
18110 bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) {
18111   if (Level < AfterLegalizeDAG)
18112     return false;
18113 
18114   bool IsLoad = true;
18115   bool IsMasked = false;
18116   SDValue Ptr;
18117   if (!getCombineLoadStoreParts(N, ISD::PRE_INC, ISD::PRE_DEC, IsLoad, IsMasked,
18118                                 Ptr, TLI))
18119     return false;
18120 
18121   // If the pointer is not an add/sub, or if it doesn't have multiple uses, bail
18122   // out.  There is no reason to make this a preinc/predec.
18123   if ((Ptr.getOpcode() != ISD::ADD && Ptr.getOpcode() != ISD::SUB) ||
18124       Ptr->hasOneUse())
18125     return false;
18126 
18127   // Ask the target to do addressing mode selection.
18128   SDValue BasePtr;
18129   SDValue Offset;
18130   ISD::MemIndexedMode AM = ISD::UNINDEXED;
18131   if (!TLI.getPreIndexedAddressParts(N, BasePtr, Offset, AM, DAG))
18132     return false;
18133 
18134   // Backends without true r+i pre-indexed forms may need to pass a
18135   // constant base with a variable offset so that constant coercion
18136   // will work with the patterns in canonical form.
18137   bool Swapped = false;
18138   if (isa<ConstantSDNode>(BasePtr)) {
18139     std::swap(BasePtr, Offset);
18140     Swapped = true;
18141   }
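  // For instance (illustrative, not from the source): had the target returned
  // BasePtr = 1024 and Offset = %reg, the swap canonicalizes this to
  // BasePtr = %reg and Offset = 1024 so r+i patterns can still match.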
18142 
18143   // Don't create an indexed load / store with zero offset.
18144   if (isNullConstant(Offset))
18145     return false;
18146 
18147   // Try turning it into a pre-indexed load / store except when:
18148   // 1) The new base ptr is a frame index.
18149   // 2) If N is a store and the new base ptr is either the same as or is a
18150   //    predecessor of the value being stored.
18151   // 3) Another use of old base ptr is a predecessor of N. If ptr is folded
18152   //    that would create a cycle.
18153   // 4) All uses are load / store ops that use it as old base ptr.
18154 
18155   // Check #1.  Preinc'ing a frame index would require copying the stack pointer
18156   // (plus the implicit offset) to a register to preinc anyway.
18157   if (isa<FrameIndexSDNode>(BasePtr) || isa<RegisterSDNode>(BasePtr))
18158     return false;
18159 
18160   // Check #2.
18161   if (!IsLoad) {
18162     SDValue Val = IsMasked ? cast<MaskedStoreSDNode>(N)->getValue()
18163                            : cast<StoreSDNode>(N)->getValue();
18164 
18165     // Would require a copy.
18166     if (Val == BasePtr)
18167       return false;
18168 
18169     // Would create a cycle.
18170     if (Val == Ptr || Ptr->isPredecessorOf(Val.getNode()))
18171       return false;
18172   }
18173 
18174   // Caches for hasPredecessorHelper.
18175   SmallPtrSet<const SDNode *, 32> Visited;
18176   SmallVector<const SDNode *, 16> Worklist;
18177   Worklist.push_back(N);
18178 
18179   // If the offset is a constant, there may be other adds of constants that
18180   // can be folded with this one. We should do this to avoid having to keep
18181   // a copy of the original base pointer.
18182   SmallVector<SDNode *, 16> OtherUses;
18183   constexpr unsigned int MaxSteps = 8192;
18184   if (isa<ConstantSDNode>(Offset))
18185     for (SDNode::use_iterator UI = BasePtr->use_begin(),
18186                               UE = BasePtr->use_end();
18187          UI != UE; ++UI) {
18188       SDUse &Use = UI.getUse();
18189       // Skip the use that is Ptr and uses of other results from BasePtr's
18190       // node (important for nodes that return multiple results).
18191       if (Use.getUser() == Ptr.getNode() || Use != BasePtr)
18192         continue;
18193 
18194       if (SDNode::hasPredecessorHelper(Use.getUser(), Visited, Worklist,
18195                                        MaxSteps))
18196         continue;
18197 
18198       if (Use.getUser()->getOpcode() != ISD::ADD &&
18199           Use.getUser()->getOpcode() != ISD::SUB) {
18200         OtherUses.clear();
18201         break;
18202       }
18203 
18204       SDValue Op1 = Use.getUser()->getOperand((UI.getOperandNo() + 1) & 1);
18205       if (!isa<ConstantSDNode>(Op1)) {
18206         OtherUses.clear();
18207         break;
18208       }
18209 
18210       // FIXME: In some cases, we can be smarter about this.
18211       if (Op1.getValueType() != Offset.getValueType()) {
18212         OtherUses.clear();
18213         break;
18214       }
18215 
18216       OtherUses.push_back(Use.getUser());
18217     }
18218 
18219   if (Swapped)
18220     std::swap(BasePtr, Offset);
18221 
18222   // Now check for #3 and #4.
18223   bool RealUse = false;
18224 
18225   for (SDNode *Use : Ptr->uses()) {
18226     if (Use == N)
18227       continue;
18228     if (SDNode::hasPredecessorHelper(Use, Visited, Worklist, MaxSteps))
18229       return false;
18230 
18231     // If Ptr may be folded into the addressing mode of another use, then it's
18232     // not profitable to do this transformation.
18233     if (!canFoldInAddressingMode(Ptr.getNode(), Use, DAG, TLI))
18234       RealUse = true;
18235   }
18236 
18237   if (!RealUse)
18238     return false;
18239 
18240   SDValue Result;
18241   if (!IsMasked) {
18242     if (IsLoad)
18243       Result = DAG.getIndexedLoad(SDValue(N, 0), SDLoc(N), BasePtr, Offset, AM);
18244     else
18245       Result =
18246           DAG.getIndexedStore(SDValue(N, 0), SDLoc(N), BasePtr, Offset, AM);
18247   } else {
18248     if (IsLoad)
18249       Result = DAG.getIndexedMaskedLoad(SDValue(N, 0), SDLoc(N), BasePtr,
18250                                         Offset, AM);
18251     else
18252       Result = DAG.getIndexedMaskedStore(SDValue(N, 0), SDLoc(N), BasePtr,
18253                                          Offset, AM);
18254   }
18255   ++PreIndexedNodes;
18256   ++NodesCombined;
18257   LLVM_DEBUG(dbgs() << "\nReplacing.4 "; N->dump(&DAG); dbgs() << "\nWith: ";
18258              Result.dump(&DAG); dbgs() << '\n');
18259   WorklistRemover DeadNodes(*this);
18260   if (IsLoad) {
18261     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0));
18262     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Result.getValue(2));
18263   } else {
18264     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(1));
18265   }
18266 
18267   // Finally, since the node is now dead, remove it from the graph.
18268   deleteAndRecombine(N);
18269 
18270   if (Swapped)
18271     std::swap(BasePtr, Offset);
18272 
18273   // Replace other uses of BasePtr that can be updated to use Ptr
18274   for (unsigned i = 0, e = OtherUses.size(); i != e; ++i) {
18275     unsigned OffsetIdx = 1;
18276     if (OtherUses[i]->getOperand(OffsetIdx).getNode() == BasePtr.getNode())
18277       OffsetIdx = 0;
18278     assert(OtherUses[i]->getOperand(!OffsetIdx).getNode() ==
18279            BasePtr.getNode() && "Expected BasePtr operand");
18280 
18281     // We need to replace ptr0 in the following expression:
18282     //   x0 * offset0 + y0 * ptr0 = t0
18283     // knowing that
18284     //   x1 * offset1 + y1 * ptr0 = t1 (the indexed load/store)
18285     //
18286     // where x0, x1, y0 and y1 in {-1, 1} are given by the types of the
18287     // indexed load/store and the expression that needs to be re-written.
18288     //
18289     // Therefore, we have:
18290     //   t0 = (x0 * offset0 - x1 * y0 * y1 * offset1) + (y0 * y1) * t1
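    //
    // Concrete illustration (not from the source): if the indexed access
    // computes t1 = ptr0 + 8 and another use computes t0 = ptr0 + 16, then
    // x0 = x1 = y0 = y1 = 1 and t0 is rewritten as t0 = 8 + t1.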
18291 
18292     auto *CN = cast<ConstantSDNode>(OtherUses[i]->getOperand(OffsetIdx));
18293     const APInt &Offset0 = CN->getAPIntValue();
18294     const APInt &Offset1 = Offset->getAsAPIntVal();
18295     int X0 = (OtherUses[i]->getOpcode() == ISD::SUB && OffsetIdx == 1) ? -1 : 1;
18296     int Y0 = (OtherUses[i]->getOpcode() == ISD::SUB && OffsetIdx == 0) ? -1 : 1;
18297     int X1 = (AM == ISD::PRE_DEC && !Swapped) ? -1 : 1;
18298     int Y1 = (AM == ISD::PRE_DEC && Swapped) ? -1 : 1;
18299 
18300     unsigned Opcode = (Y0 * Y1 < 0) ? ISD::SUB : ISD::ADD;
18301 
18302     APInt CNV = Offset0;
18303     if (X0 < 0) CNV = -CNV;
18304     if (X1 * Y0 * Y1 < 0) CNV = CNV + Offset1;
18305     else CNV = CNV - Offset1;
18306 
18307     SDLoc DL(OtherUses[i]);
18308 
18309     // We can now generate the new expression.
18310     SDValue NewOp1 = DAG.getConstant(CNV, DL, CN->getValueType(0));
18311     SDValue NewOp2 = Result.getValue(IsLoad ? 1 : 0);
18312 
18313     SDValue NewUse = DAG.getNode(Opcode,
18314                                  DL,
18315                                  OtherUses[i]->getValueType(0), NewOp1, NewOp2);
18316     DAG.ReplaceAllUsesOfValueWith(SDValue(OtherUses[i], 0), NewUse);
18317     deleteAndRecombine(OtherUses[i]);
18318   }
18319 
18320   // Replace the uses of Ptr with uses of the updated base value.
18321   DAG.ReplaceAllUsesOfValueWith(Ptr, Result.getValue(IsLoad ? 1 : 0));
18322   deleteAndRecombine(Ptr.getNode());
18323   AddToWorklist(Result.getNode());
18324 
18325   return true;
18326 }
18327 
18328 static bool shouldCombineToPostInc(SDNode *N, SDValue Ptr, SDNode *PtrUse,
18329                                    SDValue &BasePtr, SDValue &Offset,
18330                                    ISD::MemIndexedMode &AM,
18331                                    SelectionDAG &DAG,
18332                                    const TargetLowering &TLI) {
18333   if (PtrUse == N ||
18334       (PtrUse->getOpcode() != ISD::ADD && PtrUse->getOpcode() != ISD::SUB))
18335     return false;
18336 
18337   if (!TLI.getPostIndexedAddressParts(N, PtrUse, BasePtr, Offset, AM, DAG))
18338     return false;
18339 
18340   // Don't create an indexed load / store with zero offset.
18341   if (isNullConstant(Offset))
18342     return false;
18343 
18344   if (isa<FrameIndexSDNode>(BasePtr) || isa<RegisterSDNode>(BasePtr))
18345     return false;
18346 
18347   SmallPtrSet<const SDNode *, 32> Visited;
18348   for (SDNode *Use : BasePtr->uses()) {
18349     if (Use == Ptr.getNode())
18350       continue;
18351 
18352     // Don't combine if a later user could perform the indexing instead.
18353     if (isa<MemSDNode>(Use)) {
18354       bool IsLoad = true;
18355       bool IsMasked = false;
18356       SDValue OtherPtr;
18357       if (getCombineLoadStoreParts(Use, ISD::POST_INC, ISD::POST_DEC, IsLoad,
18358                                    IsMasked, OtherPtr, TLI)) {
18359         SmallVector<const SDNode *, 2> Worklist;
18360         Worklist.push_back(Use);
18361         if (SDNode::hasPredecessorHelper(N, Visited, Worklist))
18362           return false;
18363       }
18364     }
18365 
18366     // If all the uses are load / store addresses, then don't do the
18367     // transformation.
18368     if (Use->getOpcode() == ISD::ADD || Use->getOpcode() == ISD::SUB) {
18369       for (SDNode *UseUse : Use->uses())
18370         if (canFoldInAddressingMode(Use, UseUse, DAG, TLI))
18371           return false;
18372     }
18373   }
18374   return true;
18375 }
18376 
18377 static SDNode *getPostIndexedLoadStoreOp(SDNode *N, bool &IsLoad,
18378                                          bool &IsMasked, SDValue &Ptr,
18379                                          SDValue &BasePtr, SDValue &Offset,
18380                                          ISD::MemIndexedMode &AM,
18381                                          SelectionDAG &DAG,
18382                                          const TargetLowering &TLI) {
18383   if (!getCombineLoadStoreParts(N, ISD::POST_INC, ISD::POST_DEC, IsLoad,
18384                                 IsMasked, Ptr, TLI) ||
18385       Ptr->hasOneUse())
18386     return nullptr;
18387 
18388   // Try turning it into a post-indexed load / store except when
18389   // 1) All uses are load / store ops that use it as base ptr (and
18390   //    it may be folded into the addressing mode).
18391   // 2) Op must be independent of N, i.e. Op is neither a predecessor
18392   //    nor a successor of N. Otherwise, if Op is folded that would
18393   //    create a cycle.
18394   for (SDNode *Op : Ptr->uses()) {
18395     // Check for #1.
18396     if (!shouldCombineToPostInc(N, Ptr, Op, BasePtr, Offset, AM, DAG, TLI))
18397       continue;
18398 
18399     // Check for #2.
18400     SmallPtrSet<const SDNode *, 32> Visited;
18401     SmallVector<const SDNode *, 8> Worklist;
18402     constexpr unsigned int MaxSteps = 8192;
18403     // Ptr is predecessor to both N and Op.
18404     Visited.insert(Ptr.getNode());
18405     Worklist.push_back(N);
18406     Worklist.push_back(Op);
18407     if (!SDNode::hasPredecessorHelper(N, Visited, Worklist, MaxSteps) &&
18408         !SDNode::hasPredecessorHelper(Op, Visited, Worklist, MaxSteps))
18409       return Op;
18410   }
18411   return nullptr;
18412 }
18413 
18414 /// Try to combine a load/store with an add/sub of the base pointer node into
18415 /// a post-indexed load/store. The transformation effectively folds the
18416 /// add/subtract into the new indexed load/store, and all other uses of the
18417 /// add/subtract are redirected to the new load/store.
18418 bool DAGCombiner::CombineToPostIndexedLoadStore(SDNode *N) {
18419   if (Level < AfterLegalizeDAG)
18420     return false;
18421 
18422   bool IsLoad = true;
18423   bool IsMasked = false;
18424   SDValue Ptr;
18425   SDValue BasePtr;
18426   SDValue Offset;
18427   ISD::MemIndexedMode AM = ISD::UNINDEXED;
18428   SDNode *Op = getPostIndexedLoadStoreOp(N, IsLoad, IsMasked, Ptr, BasePtr,
18429                                          Offset, AM, DAG, TLI);
18430   if (!Op)
18431     return false;
18432 
18433   SDValue Result;
18434   if (!IsMasked)
18435     Result = IsLoad ? DAG.getIndexedLoad(SDValue(N, 0), SDLoc(N), BasePtr,
18436                                          Offset, AM)
18437                     : DAG.getIndexedStore(SDValue(N, 0), SDLoc(N),
18438                                           BasePtr, Offset, AM);
18439   else
18440     Result = IsLoad ? DAG.getIndexedMaskedLoad(SDValue(N, 0), SDLoc(N),
18441                                                BasePtr, Offset, AM)
18442                     : DAG.getIndexedMaskedStore(SDValue(N, 0), SDLoc(N),
18443                                                 BasePtr, Offset, AM);
18444   ++PostIndexedNodes;
18445   ++NodesCombined;
18446   LLVM_DEBUG(dbgs() << "\nReplacing.5 "; N->dump(&DAG); dbgs() << "\nWith: ";
18447              Result.dump(&DAG); dbgs() << '\n');
18448   WorklistRemover DeadNodes(*this);
18449   if (IsLoad) {
18450     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0));
18451     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Result.getValue(2));
18452   } else {
18453     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(1));
18454   }
18455 
18456   // Finally, since the node is now dead, remove it from the graph.
18457   deleteAndRecombine(N);
18458 
18459   // Replace the uses of Use with uses of the updated base value.
18460   DAG.ReplaceAllUsesOfValueWith(SDValue(Op, 0),
18461                                 Result.getValue(IsLoad ? 1 : 0));
18462   deleteAndRecombine(Op);
18463   return true;
18464 }
18465 
18466 /// Return the base-pointer arithmetic from an indexed \p LD.
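/// For example (illustrative): for a pre-increment load with base %p and
/// offset 8, this returns the node (add %p, 8).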
18467 SDValue DAGCombiner::SplitIndexingFromLoad(LoadSDNode *LD) {
18468   ISD::MemIndexedMode AM = LD->getAddressingMode();
18469   assert(AM != ISD::UNINDEXED);
18470   SDValue BP = LD->getOperand(1);
18471   SDValue Inc = LD->getOperand(2);
18472 
18473   // Some backends use TargetConstants for load offsets, but don't expect
18474   // TargetConstants in general ADD nodes. We can convert these constants into
18475   // regular Constants (if the constant is not opaque).
18476   assert((Inc.getOpcode() != ISD::TargetConstant ||
18477           !cast<ConstantSDNode>(Inc)->isOpaque()) &&
18478          "Cannot split out indexing using opaque target constants");
18479   if (Inc.getOpcode() == ISD::TargetConstant) {
18480     ConstantSDNode *ConstInc = cast<ConstantSDNode>(Inc);
18481     Inc = DAG.getConstant(*ConstInc->getConstantIntValue(), SDLoc(Inc),
18482                           ConstInc->getValueType(0));
18483   }
18484 
18485   unsigned Opc =
18486       (AM == ISD::PRE_INC || AM == ISD::POST_INC ? ISD::ADD : ISD::SUB);
18487   return DAG.getNode(Opc, SDLoc(LD), BP.getSimpleValueType(), BP, Inc);
18488 }
18489 
18490 static inline ElementCount numVectorEltsOrZero(EVT T) {
18491   return T.isVector() ? T.getVectorElementCount() : ElementCount::getFixed(0);
18492 }
18493 
18494 bool DAGCombiner::getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val) {
18495   EVT STType = Val.getValueType();
18496   EVT STMemType = ST->getMemoryVT();
18497   if (STType == STMemType)
18498     return true;
18499   if (isTypeLegal(STMemType))
18500     return false; // fail.
18501   if (STType.isFloatingPoint() && STMemType.isFloatingPoint() &&
18502       TLI.isOperationLegal(ISD::FTRUNC, STMemType)) {
18503     Val = DAG.getNode(ISD::FTRUNC, SDLoc(ST), STMemType, Val);
18504     return true;
18505   }
18506   if (numVectorEltsOrZero(STType) == numVectorEltsOrZero(STMemType) &&
18507       STType.isInteger() && STMemType.isInteger()) {
18508     Val = DAG.getNode(ISD::TRUNCATE, SDLoc(ST), STMemType, Val);
18509     return true;
18510   }
18511   if (STType.getSizeInBits() == STMemType.getSizeInBits()) {
18512     Val = DAG.getBitcast(STMemType, Val);
18513     return true;
18514   }
18515   return false; // fail.
18516 }
18517 
18518 bool DAGCombiner::extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val) {
18519   EVT LDMemType = LD->getMemoryVT();
18520   EVT LDType = LD->getValueType(0);
18521   assert(Val.getValueType() == LDMemType &&
18522          "Attempting to extend value of non-matching type");
18523   if (LDType == LDMemType)
18524     return true;
18525   if (LDMemType.isInteger() && LDType.isInteger()) {
18526     switch (LD->getExtensionType()) {
18527     case ISD::NON_EXTLOAD:
18528       Val = DAG.getBitcast(LDType, Val);
18529       return true;
18530     case ISD::EXTLOAD:
18531       Val = DAG.getNode(ISD::ANY_EXTEND, SDLoc(LD), LDType, Val);
18532       return true;
18533     case ISD::SEXTLOAD:
18534       Val = DAG.getNode(ISD::SIGN_EXTEND, SDLoc(LD), LDType, Val);
18535       return true;
18536     case ISD::ZEXTLOAD:
18537       Val = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(LD), LDType, Val);
18538       return true;
18539     }
18540   }
18541   return false;
18542 }
18543 
18544 StoreSDNode *DAGCombiner::getUniqueStoreFeeding(LoadSDNode *LD,
18545                                                 int64_t &Offset) {
18546   SDValue Chain = LD->getOperand(0);
18547 
18548   // Look through CALLSEQ_START.
18549   if (Chain.getOpcode() == ISD::CALLSEQ_START)
18550     Chain = Chain->getOperand(0);
18551 
18552   StoreSDNode *ST = nullptr;
18553   SmallVector<SDValue, 8> Aliases;
18554   if (Chain.getOpcode() == ISD::TokenFactor) {
18555     // Look for unique store within the TokenFactor.
18556     for (SDValue Op : Chain->ops()) {
18557       StoreSDNode *Store = dyn_cast<StoreSDNode>(Op.getNode());
18558       if (!Store)
18559         continue;
18560       BaseIndexOffset BasePtrLD = BaseIndexOffset::match(LD, DAG);
18561       BaseIndexOffset BasePtrST = BaseIndexOffset::match(Store, DAG);
18562       if (!BasePtrST.equalBaseIndex(BasePtrLD, DAG, Offset))
18563         continue;
18564       // Make sure the store is not aliased with any nodes in TokenFactor.
18565       GatherAllAliases(Store, Chain, Aliases);
18566       if (Aliases.empty() ||
18567           (Aliases.size() == 1 && Aliases.front().getNode() == Store))
18568         ST = Store;
18569       break;
18570     }
18571   } else {
18572     StoreSDNode *Store = dyn_cast<StoreSDNode>(Chain.getNode());
18573     if (Store) {
18574       BaseIndexOffset BasePtrLD = BaseIndexOffset::match(LD, DAG);
18575       BaseIndexOffset BasePtrST = BaseIndexOffset::match(Store, DAG);
18576       if (BasePtrST.equalBaseIndex(BasePtrLD, DAG, Offset))
18577         ST = Store;
18578     }
18579   }
18580 
18581   return ST;
18582 }
18583 
18584 SDValue DAGCombiner::ForwardStoreValueToDirectLoad(LoadSDNode *LD) {
18585   if (OptLevel == CodeGenOptLevel::None || !LD->isSimple())
18586     return SDValue();
18587   SDValue Chain = LD->getOperand(0);
18588   int64_t Offset;
18589 
18590   StoreSDNode *ST = getUniqueStoreFeeding(LD, Offset);
18591   // TODO: Relax this restriction for unordered atomics (see D66309)
18592   if (!ST || !ST->isSimple() || ST->getAddressSpace() != LD->getAddressSpace())
18593     return SDValue();
18594 
18595   EVT LDType = LD->getValueType(0);
18596   EVT LDMemType = LD->getMemoryVT();
18597   EVT STMemType = ST->getMemoryVT();
18598   EVT STType = ST->getValue().getValueType();
18599 
18600   // There are two cases to consider here:
18601   //  1. The store is fixed width and the load is scalable. In this case we
18602   //     don't know at compile time if the store completely envelops the load
18603   //     so we abandon the optimisation.
18604   //  2. The store is scalable and the load is fixed width. We could
18605   //     potentially support a limited number of cases here, but there has been
18606   //     no cost-benefit analysis to prove it's worth it.
18607   bool LdStScalable = LDMemType.isScalableVT();
18608   if (LdStScalable != STMemType.isScalableVT())
18609     return SDValue();
18610 
18611   // If we are dealing with scalable vectors on a big endian platform the
18612   // calculation of offsets below becomes trickier, since we do not know at
18613   // compile time the absolute size of the vector. Until we've done more
18614   // analysis on big-endian platforms it seems better to bail out for now.
18615   if (LdStScalable && DAG.getDataLayout().isBigEndian())
18616     return SDValue();
18617 
18618   // Normalize for endianness. After this, Offset=0 will denote that the least
18619   // significant bit in the loaded value maps to the least significant bit in
18620   // the stored value. With Offset=n (for n > 0) the loaded value starts at the
18621   // n:th least significant byte of the stored value.
18622   int64_t OrigOffset = Offset;
18623   if (DAG.getDataLayout().isBigEndian())
18624     Offset = ((int64_t)STMemType.getStoreSizeInBits().getFixedValue() -
18625               (int64_t)LDMemType.getStoreSizeInBits().getFixedValue()) /
18626                  8 -
18627              Offset;
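  // Illustrative example (not from the source): for a big-endian 8-byte store
  // feeding a 4-byte load of the same address (OrigOffset = 0), the load reads
  // the four most significant stored bytes, so Offset becomes (64 - 32) / 8 =
  // 4.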
18628 
18629   // Check that the stored value covers all bits that are loaded.
18630   bool STCoversLD;
18631 
18632   TypeSize LdMemSize = LDMemType.getSizeInBits();
18633   TypeSize StMemSize = STMemType.getSizeInBits();
18634   if (LdStScalable)
18635     STCoversLD = (Offset == 0) && LdMemSize == StMemSize;
18636   else
18637     STCoversLD = (Offset >= 0) && (Offset * 8 + LdMemSize.getFixedValue() <=
18638                                    StMemSize.getFixedValue());
18639 
18640   auto ReplaceLd = [&](LoadSDNode *LD, SDValue Val, SDValue Chain) -> SDValue {
18641     if (LD->isIndexed()) {
18642       // Cannot handle opaque target constants and we must respect the user's
18643       // request not to split indexes from loads.
18644       if (!canSplitIdx(LD))
18645         return SDValue();
18646       SDValue Idx = SplitIndexingFromLoad(LD);
18647       SDValue Ops[] = {Val, Idx, Chain};
18648       return CombineTo(LD, Ops, 3);
18649     }
18650     return CombineTo(LD, Val, Chain);
18651   };
18652 
18653   if (!STCoversLD)
18654     return SDValue();
18655 
18656   // Memory as copy space (potentially masked).
18657   if (Offset == 0 && LDType == STType && STMemType == LDMemType) {
18658     // Simple case: Direct non-truncating forwarding
18659     if (LDType.getSizeInBits() == LdMemSize)
18660       return ReplaceLd(LD, ST->getValue(), Chain);
18661     // Can we model the truncate and extension with an and mask?
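      // E.g. (illustrative, not from the source): an i32 value stored with a
      // truncating store to i16 and reloaded by a zero-extending i16 load can
      // be forwarded as (and StoredVal, 0xFFFF).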
18662     if (STType.isInteger() && LDMemType.isInteger() && !STType.isVector() &&
18663         !LDMemType.isVector() && LD->getExtensionType() != ISD::SEXTLOAD) {
18664       // Mask to size of LDMemType
18665       auto Mask =
18666           DAG.getConstant(APInt::getLowBitsSet(STType.getFixedSizeInBits(),
18667                                                StMemSize.getFixedValue()),
18668                           SDLoc(ST), STType);
18669       auto Val = DAG.getNode(ISD::AND, SDLoc(LD), LDType, ST->getValue(), Mask);
18670       return ReplaceLd(LD, Val, Chain);
18671     }
18672   }
18673 
18674   // Handle some cases for big-endian that would be Offset 0 and handled for
18675   // little-endian.
18676   SDValue Val = ST->getValue();
18677   if (DAG.getDataLayout().isBigEndian() && Offset > 0 && OrigOffset == 0) {
18678     if (STType.isInteger() && !STType.isVector() && LDType.isInteger() &&
18679         !LDType.isVector() && isTypeLegal(STType) &&
18680         TLI.isOperationLegal(ISD::SRL, STType)) {
18681       Val = DAG.getNode(ISD::SRL, SDLoc(LD), STType, Val,
18682                         DAG.getConstant(Offset * 8, SDLoc(LD), STType));
18683       Offset = 0;
18684     }
18685   }
18686 
18687   // TODO: Deal with nonzero offset.
18688   if (LD->getBasePtr().isUndef() || Offset != 0)
18689     return SDValue();
18690   // Model the necessary truncations / extensions.
18691   // Truncate the value to the stored memory size.
18692   do {
18693     if (!getTruncatedStoreValue(ST, Val))
18694       continue;
18695     if (!isTypeLegal(LDMemType))
18696       continue;
18697     if (STMemType != LDMemType) {
18698       // TODO: Support vectors? This requires extract_subvector/bitcast.
18699       if (!STMemType.isVector() && !LDMemType.isVector() &&
18700           STMemType.isInteger() && LDMemType.isInteger())
18701         Val = DAG.getNode(ISD::TRUNCATE, SDLoc(LD), LDMemType, Val);
18702       else
18703         continue;
18704     }
18705     if (!extendLoadedValueToExtension(LD, Val))
18706       continue;
18707     return ReplaceLd(LD, Val, Chain);
18708   } while (false);
18709 
18710   // On failure, cleanup dead nodes we may have created.
18711   if (Val->use_empty())
18712     deleteAndRecombine(Val.getNode());
18713   return SDValue();
18714 }
18715 
18716 SDValue DAGCombiner::visitLOAD(SDNode *N) {
18717   LoadSDNode *LD  = cast<LoadSDNode>(N);
18718   SDValue Chain = LD->getChain();
18719   SDValue Ptr   = LD->getBasePtr();
18720 
18721   // If load is not volatile and there are no uses of the loaded value (and
18722   // the updated indexed value in case of indexed loads), change uses of the
18723   // chain value into uses of the chain input (i.e. delete the dead load).
18724   // TODO: Allow this for unordered atomics (see D66309)
18725   if (LD->isSimple()) {
18726     if (N->getValueType(1) == MVT::Other) {
18727       // Unindexed loads.
18728       if (!N->hasAnyUseOfValue(0)) {
18729         // It's not safe to use the two value CombineTo variant here. e.g.
18730         // v1, chain2 = load chain1, loc
18731         // v2, chain3 = load chain2, loc
18732         // v3         = add v2, c
18733         // Now we replace use of chain2 with chain1.  This makes the second load
18734         // isomorphic to the one we are deleting, and thus makes this load live.
18735         LLVM_DEBUG(dbgs() << "\nReplacing.6 "; N->dump(&DAG);
18736                    dbgs() << "\nWith chain: "; Chain.dump(&DAG);
18737                    dbgs() << "\n");
18738         WorklistRemover DeadNodes(*this);
18739         DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain);
18740         AddUsersToWorklist(Chain.getNode());
18741         if (N->use_empty())
18742           deleteAndRecombine(N);
18743 
18744         return SDValue(N, 0);   // Return N so it doesn't get rechecked!
18745       }
18746     } else {
18747       // Indexed loads.
18748       assert(N->getValueType(2) == MVT::Other && "Malformed indexed loads?");
18749 
18750       // If this load has an opaque TargetConstant offset, then we cannot split
18751       // the indexing into an add/sub directly (that TargetConstant may not be
18752       // valid for a different type of node, and we cannot convert an opaque
18753       // target constant into a regular constant).
18754       bool CanSplitIdx = canSplitIdx(LD);
18755 
18756       if (!N->hasAnyUseOfValue(0) && (CanSplitIdx || !N->hasAnyUseOfValue(1))) {
18757         SDValue Undef = DAG.getUNDEF(N->getValueType(0));
18758         SDValue Index;
18759         if (N->hasAnyUseOfValue(1) && CanSplitIdx) {
18760           Index = SplitIndexingFromLoad(LD);
18761           // Try to fold the base pointer arithmetic into subsequent loads and
18762           // stores.
18763           AddUsersToWorklist(N);
18764         } else
18765           Index = DAG.getUNDEF(N->getValueType(1));
18766         LLVM_DEBUG(dbgs() << "\nReplacing.7 "; N->dump(&DAG);
18767                    dbgs() << "\nWith: "; Undef.dump(&DAG);
18768                    dbgs() << " and 2 other values\n");
18769         WorklistRemover DeadNodes(*this);
18770         DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Undef);
18771         DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Index);
18772         DAG.ReplaceAllUsesOfValueWith(SDValue(N, 2), Chain);
18773         deleteAndRecombine(N);
18774         return SDValue(N, 0);   // Return N so it doesn't get rechecked!
18775       }
18776     }
18777   }
18778 
18779   // If this load is directly stored, replace the load value with the stored
18780   // value.
18781   if (auto V = ForwardStoreValueToDirectLoad(LD))
18782     return V;
18783 
18784   // Try to infer better alignment information than the load already has.
18785   if (OptLevel != CodeGenOptLevel::None && LD->isUnindexed() &&
18786       !LD->isAtomic()) {
18787     if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) {
18788       if (*Alignment > LD->getAlign() &&
18789           isAligned(*Alignment, LD->getSrcValueOffset())) {
18790         SDValue NewLoad = DAG.getExtLoad(
18791             LD->getExtensionType(), SDLoc(N), LD->getValueType(0), Chain, Ptr,
18792             LD->getPointerInfo(), LD->getMemoryVT(), *Alignment,
18793             LD->getMemOperand()->getFlags(), LD->getAAInfo());
18794         // NewLoad will always be N as we are only refining the alignment
18795         assert(NewLoad.getNode() == N);
18796         (void)NewLoad;
18797       }
18798     }
18799   }
18800 
18801   if (LD->isUnindexed()) {
18802     // Walk up chain skipping non-aliasing memory nodes.
18803     SDValue BetterChain = FindBetterChain(LD, Chain);
18804 
18805     // If there is a better chain.
18806     if (Chain != BetterChain) {
18807       SDValue ReplLoad;
18808 
18809       // Replace the chain to avoid a dependency.
18810       if (LD->getExtensionType() == ISD::NON_EXTLOAD) {
18811         ReplLoad = DAG.getLoad(N->getValueType(0), SDLoc(LD),
18812                                BetterChain, Ptr, LD->getMemOperand());
18813       } else {
18814         ReplLoad = DAG.getExtLoad(LD->getExtensionType(), SDLoc(LD),
18815                                   LD->getValueType(0),
18816                                   BetterChain, Ptr, LD->getMemoryVT(),
18817                                   LD->getMemOperand());
18818       }
18819 
18820       // Create token factor to keep old chain connected.
18821       SDValue Token = DAG.getNode(ISD::TokenFactor, SDLoc(N),
18822                                   MVT::Other, Chain, ReplLoad.getValue(1));
18823 
18824       // Replace uses with load result and token factor
18825       return CombineTo(N, ReplLoad.getValue(0), Token);
18826     }
18827   }
18828 
18829   // Try transforming N to an indexed load.
18830   if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
18831     return SDValue(N, 0);
18832 
18833   // Try to slice up N to more direct loads if the slices are mapped to
18834   // different register banks or pairing can take place.
18835   if (SliceUpLoad(N))
18836     return SDValue(N, 0);
18837 
18838   return SDValue();
18839 }
18840 
18841 namespace {
18842 
18843 /// Helper structure used to slice a load into smaller loads.
18844 /// Basically a slice is obtained from the following sequence:
18845 /// Origin = load Ty1, Base
18846 /// Shift = srl Ty1 Origin, CstTy Amount
18847 /// Inst = trunc Shift to Ty2
18848 ///
18849 /// Then, it will be rewritten into:
18850 /// Slice = load SliceTy, Base + SliceOffset
18851 /// [Inst = zext Slice to Ty2], only if SliceTy <> Ty2
18852 ///
18853 /// SliceTy is deduced from the number of bits that are actually used to
18854 /// build Inst.
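/// For example (illustrative, little-endian): Origin = load i32, Base;
/// Shift = srl Origin, 16; Inst = trunc Shift to i16 can be rewritten as
/// Slice = load i16, Base + 2.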
18855 struct LoadedSlice {
18856   /// Helper structure used to compute the cost of a slice.
18857   struct Cost {
18858     /// Are we optimizing for code size.
18859     bool ForCodeSize = false;
18860 
18861     /// Various costs.
    unsigned Loads = 0;
    unsigned Truncates = 0;
    unsigned CrossRegisterBanksCopies = 0;
    unsigned ZExts = 0;
    unsigned Shift = 0;

    explicit Cost(bool ForCodeSize) : ForCodeSize(ForCodeSize) {}

    /// Get the cost of one isolated slice.
    Cost(const LoadedSlice &LS, bool ForCodeSize)
        : ForCodeSize(ForCodeSize), Loads(1) {
      EVT TruncType = LS.Inst->getValueType(0);
      EVT LoadedType = LS.getLoadedType();
      if (TruncType != LoadedType &&
          !LS.DAG->getTargetLoweringInfo().isZExtFree(LoadedType, TruncType))
        ZExts = 1;
    }

    /// Account for slicing gain in the current cost.
    /// Slicing provides a few gains, like removing a shift or a
    /// truncate. This method allows growing the cost of the original
    /// load with the gain from this slice.
    void addSliceGain(const LoadedSlice &LS) {
      // Each slice saves a truncate.
      const TargetLowering &TLI = LS.DAG->getTargetLoweringInfo();
      if (!TLI.isTruncateFree(LS.Inst->getOperand(0), LS.Inst->getValueType(0)))
        ++Truncates;
      // If there is a shift amount, this slice gets rid of it.
      if (LS.Shift)
        ++Shift;
      // If this slice can merge a cross register bank copy, account for it.
      if (LS.canMergeExpensiveCrossRegisterBankCopy())
        ++CrossRegisterBanksCopies;
    }

    Cost &operator+=(const Cost &RHS) {
      Loads += RHS.Loads;
      Truncates += RHS.Truncates;
      CrossRegisterBanksCopies += RHS.CrossRegisterBanksCopies;
      ZExts += RHS.ZExts;
      Shift += RHS.Shift;
      return *this;
    }

    bool operator==(const Cost &RHS) const {
      return Loads == RHS.Loads && Truncates == RHS.Truncates &&
             CrossRegisterBanksCopies == RHS.CrossRegisterBanksCopies &&
             ZExts == RHS.ZExts && Shift == RHS.Shift;
    }

    bool operator!=(const Cost &RHS) const { return !(*this == RHS); }

    bool operator<(const Cost &RHS) const {
      // Assume cross register banks copies are as expensive as loads.
      // FIXME: Do we want some more target hooks?
      unsigned ExpensiveOpsLHS = Loads + CrossRegisterBanksCopies;
      unsigned ExpensiveOpsRHS = RHS.Loads + RHS.CrossRegisterBanksCopies;
      // Unless we are optimizing for code size, consider the
      // expensive operation first.
      if (!ForCodeSize && ExpensiveOpsLHS != ExpensiveOpsRHS)
        return ExpensiveOpsLHS < ExpensiveOpsRHS;
      return (Truncates + ZExts + Shift + ExpensiveOpsLHS) <
             (RHS.Truncates + RHS.ZExts + RHS.Shift + ExpensiveOpsRHS);
    }

    bool operator>(const Cost &RHS) const { return RHS < *this; }

    bool operator<=(const Cost &RHS) const { return !(RHS < *this); }

    bool operator>=(const Cost &RHS) const { return !(*this < RHS); }
  };

  // The last instruction that represents the slice. This should be a
  // truncate instruction.
  SDNode *Inst;

  // The original load instruction.
  LoadSDNode *Origin;

  // The right shift amount in bits from the original load.
  unsigned Shift;

  // The DAG from which Origin comes.
  // This is used to get some contextual information about legal types, etc.
  SelectionDAG *DAG;

  LoadedSlice(SDNode *Inst = nullptr, LoadSDNode *Origin = nullptr,
              unsigned Shift = 0, SelectionDAG *DAG = nullptr)
      : Inst(Inst), Origin(Origin), Shift(Shift), DAG(DAG) {}

  /// Get the bits used in a chunk of bits \p BitWidth large.
  /// \return Result is \p BitWidth and has used bits set to 1 and
  ///         not used bits set to 0.
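  /// Illustrative example (added for exposition): for an i32 Origin with
  /// Shift == 16 and an i8 Inst, the mask starts as 0xff, is zero-extended
  /// to 32 bits, and is shifted left by 16, giving 0x00ff0000.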
  APInt getUsedBits() const {
    // Reproduce the trunc(lshr) sequence:
    // - Start from the truncated value.
    // - Zero extend to the desired bit width.
    // - Shift left.
    assert(Origin && "No original load to compare against.");
    unsigned BitWidth = Origin->getValueSizeInBits(0);
    assert(Inst && "This slice is not bound to an instruction");
    assert(Inst->getValueSizeInBits(0) <= BitWidth &&
           "Extracted slice is bigger than the whole type!");
    APInt UsedBits(Inst->getValueSizeInBits(0), 0);
    UsedBits.setAllBits();
    UsedBits = UsedBits.zext(BitWidth);
    UsedBits <<= Shift;
    return UsedBits;
  }

  /// Get the size of the slice to be loaded in bytes.
  unsigned getLoadedSize() const {
    unsigned SliceSize = getUsedBits().popcount();
    assert(!(SliceSize & 0x7) && "Size is not a multiple of a byte.");
    return SliceSize / 8;
  }

  /// Get the type that will be loaded for this slice.
  /// Note: This may not be the final type for the slice.
  EVT getLoadedType() const {
    assert(DAG && "Missing context");
    LLVMContext &Ctxt = *DAG->getContext();
    return EVT::getIntegerVT(Ctxt, getLoadedSize() * 8);
  }

  /// Get the alignment of the load used for this slice.
  Align getAlign() const {
    Align Alignment = Origin->getAlign();
    uint64_t Offset = getOffsetFromBase();
    if (Offset != 0)
      Alignment = commonAlignment(Alignment, Alignment.value() + Offset);
    return Alignment;
  }

  /// Check if this slice can be rewritten with legal operations.
  bool isLegal() const {
    // An invalid slice is not legal.
    if (!Origin || !Inst || !DAG)
      return false;

    // Offsets are for indexed loads only; we do not handle that.
    if (!Origin->getOffset().isUndef())
      return false;

    const TargetLowering &TLI = DAG->getTargetLoweringInfo();

    // Check that the type is legal.
    EVT SliceType = getLoadedType();
    if (!TLI.isTypeLegal(SliceType))
      return false;

    // Check that the load is legal for this type.
    if (!TLI.isOperationLegal(ISD::LOAD, SliceType))
      return false;

    // Check that the offset can be computed.
    // 1. Check its type.
    EVT PtrType = Origin->getBasePtr().getValueType();
    if (PtrType == MVT::Untyped || PtrType.isExtended())
      return false;

    // 2. Check that it fits in the immediate.
    if (!TLI.isLegalAddImmediate(getOffsetFromBase()))
      return false;

    // 3. Check that the computation is legal.
    if (!TLI.isOperationLegal(ISD::ADD, PtrType))
      return false;

    // Check that the zext is legal if it needs one.
    EVT TruncateType = Inst->getValueType(0);
    if (TruncateType != SliceType &&
        !TLI.isOperationLegal(ISD::ZERO_EXTEND, TruncateType))
      return false;

    return true;
  }

  /// Get the offset in bytes of this slice in the original chunk of
  /// bits.
  /// \pre DAG != nullptr.
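  /// Illustrative example (added for exposition): for an i32 Origin with
  /// Shift == 16 and a one-byte slice, the offset is 16 / 8 == 2 on a
  /// little-endian target, and 4 - 2 - 1 == 1 on a big-endian one.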
  uint64_t getOffsetFromBase() const {
    assert(DAG && "Missing context.");
    bool IsBigEndian = DAG->getDataLayout().isBigEndian();
    assert(!(Shift & 0x7) && "Shifts not aligned on Bytes are not supported.");
    uint64_t Offset = Shift / 8;
    unsigned TySizeInBytes = Origin->getValueSizeInBits(0) / 8;
    assert(!(Origin->getValueSizeInBits(0) & 0x7) &&
           "The size of the original loaded type is not a multiple of a"
           " byte.");
    // If Offset is bigger than TySizeInBytes, it means we are loading all
    // zeros. This should have been optimized before in the process.
    assert(TySizeInBytes > Offset &&
           "Invalid shift amount for given loaded size");
    if (IsBigEndian)
      Offset = TySizeInBytes - Offset - getLoadedSize();
    return Offset;
  }

  /// Generate the sequence of instructions to load the slice
  /// represented by this object and redirect the uses of this slice to
  /// this new sequence of instructions.
  /// \pre this->Inst && this->Origin are valid Instructions and this
  /// object passed the legal check: LoadedSlice::isLegal returned true.
  /// \return The last instruction of the sequence used to load the slice.
  SDValue loadSlice() const {
    assert(Inst && Origin && "Unable to replace a non-existing slice.");
    const SDValue &OldBaseAddr = Origin->getBasePtr();
    SDValue BaseAddr = OldBaseAddr;
    // Get the offset in that chunk of bytes w.r.t. the endianness.
    int64_t Offset = static_cast<int64_t>(getOffsetFromBase());
    assert(Offset >= 0 && "Offset too big to fit in int64_t!");
    if (Offset) {
      // BaseAddr = BaseAddr + Offset.
      EVT ArithType = BaseAddr.getValueType();
      SDLoc DL(Origin);
      BaseAddr = DAG->getNode(ISD::ADD, DL, ArithType, BaseAddr,
                              DAG->getConstant(Offset, DL, ArithType));
    }

    // Create the type of the loaded slice according to its size.
    EVT SliceType = getLoadedType();

    // Create the load for the slice.
    SDValue LastInst =
        DAG->getLoad(SliceType, SDLoc(Origin), Origin->getChain(), BaseAddr,
                     Origin->getPointerInfo().getWithOffset(Offset), getAlign(),
                     Origin->getMemOperand()->getFlags());
    // If the final type is not the same as the loaded type, this means that
    // we have to pad with zero. Create a zero extend for that.
    EVT FinalType = Inst->getValueType(0);
    if (SliceType != FinalType)
      LastInst =
          DAG->getNode(ISD::ZERO_EXTEND, SDLoc(LastInst), FinalType, LastInst);
    return LastInst;
  }

  /// Check if this slice can be merged with an expensive cross register
  /// bank copy. E.g.,
  /// i = load i32
  /// f = bitcast i32 i to float
  bool canMergeExpensiveCrossRegisterBankCopy() const {
    if (!Inst || !Inst->hasOneUse())
      return false;
    SDNode *Use = *Inst->use_begin();
    if (Use->getOpcode() != ISD::BITCAST)
      return false;
    assert(DAG && "Missing context");
    const TargetLowering &TLI = DAG->getTargetLoweringInfo();
    EVT ResVT = Use->getValueType(0);
    const TargetRegisterClass *ResRC =
        TLI.getRegClassFor(ResVT.getSimpleVT(), Use->isDivergent());
    const TargetRegisterClass *ArgRC =
        TLI.getRegClassFor(Use->getOperand(0).getValueType().getSimpleVT(),
                           Use->getOperand(0)->isDivergent());
    if (ArgRC == ResRC || !TLI.isOperationLegal(ISD::LOAD, ResVT))
      return false;

    // At this point, we know that we perform a cross-register-bank copy.
    // Check if it is expensive.
    const TargetRegisterInfo *TRI = DAG->getSubtarget().getRegisterInfo();
    // Assume bitcasts are cheap, unless both register classes do not
    // explicitly share a common subclass.
    if (!TRI || TRI->getCommonSubClass(ArgRC, ResRC))
      return false;

    // Check if it will be merged with the load.
    // 1. Check the alignment / fast memory access constraint.
    unsigned IsFast = 0;
    if (!TLI.allowsMemoryAccess(*DAG->getContext(), DAG->getDataLayout(), ResVT,
                                Origin->getAddressSpace(), getAlign(),
                                Origin->getMemOperand()->getFlags(), &IsFast) ||
        !IsFast)
      return false;

    // 2. Check that the load is a legal operation for that type.
    if (!TLI.isOperationLegal(ISD::LOAD, ResVT))
      return false;

    // 3. Check that we do not have a zext in the way.
    if (Inst->getValueType(0) != getLoadedType())
      return false;

    return true;
  }
};

} // end anonymous namespace

/// Check that all bits set in \p UsedBits form a dense region, i.e.,
/// \p UsedBits looks like 0..0 1..1 0..0.
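/// Illustrative example (added for exposition): 0x00ff0000 is dense (one
/// contiguous run of ones), while 0x00ff00ff is not (the run of ones is
/// split by a hole of zeros).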
static bool areUsedBitsDense(const APInt &UsedBits) {
  // If all the bits are one, this is dense!
  if (UsedBits.isAllOnes())
    return true;

  // Get rid of the unused bits on the right.
  APInt NarrowedUsedBits = UsedBits.lshr(UsedBits.countr_zero());
  // Get rid of the unused bits on the left.
  if (NarrowedUsedBits.countl_zero())
    NarrowedUsedBits = NarrowedUsedBits.trunc(NarrowedUsedBits.getActiveBits());
  // Check that the chunk of bits is completely used.
  return NarrowedUsedBits.isAllOnes();
}

/// Check whether or not \p First and \p Second are next to each other
/// in memory. This means that there is no hole between the bits loaded
/// by \p First and the bits loaded by \p Second.
static bool areSlicesNextToEachOther(const LoadedSlice &First,
                                     const LoadedSlice &Second) {
  assert(First.Origin == Second.Origin && First.Origin &&
         "Unable to match different memory origins.");
  APInt UsedBits = First.getUsedBits();
  assert((UsedBits & Second.getUsedBits()) == 0 &&
         "Slices are not supposed to overlap.");
  UsedBits |= Second.getUsedBits();
  return areUsedBitsDense(UsedBits);
}

/// Adjust the \p GlobalLSCost according to the target
/// pairing capabilities and the layout of the slices.
/// \pre \p GlobalLSCost should account for at least as many loads as
/// there are in the slices in \p LoadedSlices.
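///
/// Illustrative example (added for exposition): if the target can pair two
/// adjacent i16 loads into one memory operation, two neighboring i16 slices
/// cost one load instead of two, so GlobalLSCost.Loads is decremented once
/// per matched pair.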
static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices,
                                 LoadedSlice::Cost &GlobalLSCost) {
  unsigned NumberOfSlices = LoadedSlices.size();
  // If there are fewer than 2 elements, no pairing is possible.
  if (NumberOfSlices < 2)
    return;

  // Sort the slices so that elements that are likely to be next to each
  // other in memory are next to each other in the list.
  llvm::sort(LoadedSlices, [](const LoadedSlice &LHS, const LoadedSlice &RHS) {
    assert(LHS.Origin == RHS.Origin && "Different bases not implemented.");
    return LHS.getOffsetFromBase() < RHS.getOffsetFromBase();
  });
  const TargetLowering &TLI = LoadedSlices[0].DAG->getTargetLoweringInfo();
  // First (resp. Second) is the first (resp. second) potential candidate
  // to be placed in a paired load.
  const LoadedSlice *First = nullptr;
  const LoadedSlice *Second = nullptr;
  for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice,
                // Set the beginning of the pair.
                                                           First = Second) {
    Second = &LoadedSlices[CurrSlice];

    // If First is NULL, it means we start a new pair.
    // Get to the next slice.
    if (!First)
      continue;

    EVT LoadedType = First->getLoadedType();

    // If the types of the slices are different, we cannot pair them.
    if (LoadedType != Second->getLoadedType())
      continue;

    // Check if the target supplies paired loads for this type.
    Align RequiredAlignment;
    if (!TLI.hasPairedLoad(LoadedType, RequiredAlignment)) {
      // Move to the next pair; this type is hopeless.
      Second = nullptr;
      continue;
    }
    // Check if we meet the alignment requirement.
    if (First->getAlign() < RequiredAlignment)
      continue;

    // Check that both loads are next to each other in memory.
    if (!areSlicesNextToEachOther(*First, *Second))
      continue;

    assert(GlobalLSCost.Loads > 0 && "We save more loads than we created!");
    --GlobalLSCost.Loads;
    // Move to the next pair.
    Second = nullptr;
  }
}

/// Check the profitability of all involved LoadedSlice.
/// Currently, it is considered profitable if there are exactly two
/// involved slices (1) which are (2) next to each other in memory, and
/// whose cost (\see LoadedSlice::Cost) is smaller than the original load (3).
///
/// Note: The order of the elements in \p LoadedSlices may be modified, but not
/// the elements themselves.
///
/// FIXME: When the cost model is mature enough, we can relax
/// constraints (1) and (2).
static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices,
                                const APInt &UsedBits, bool ForCodeSize) {
  unsigned NumberOfSlices = LoadedSlices.size();
  if (StressLoadSlicing)
    return NumberOfSlices > 1;

  // Check (1).
  if (NumberOfSlices != 2)
    return false;

  // Check (2).
  if (!areUsedBitsDense(UsedBits))
    return false;

  // Check (3).
  LoadedSlice::Cost OrigCost(ForCodeSize), GlobalSlicingCost(ForCodeSize);
  // The original code has one big load.
  OrigCost.Loads = 1;
  for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice) {
    const LoadedSlice &LS = LoadedSlices[CurrSlice];
    // Accumulate the cost of all the slices.
    LoadedSlice::Cost SliceCost(LS, ForCodeSize);
    GlobalSlicingCost += SliceCost;

    // Account as cost in the original configuration the gain obtained
    // with the current slices.
    OrigCost.addSliceGain(LS);
  }

  // If the target supports paired load, adjust the cost accordingly.
  adjustCostForPairing(LoadedSlices, GlobalSlicingCost);
  return OrigCost > GlobalSlicingCost;
}

/// If the given load, \p N, is used only by trunc or trunc(lshr)
/// operations, split it into the various pieces being extracted.
///
/// This sort of thing is introduced by SROA.
/// This slicing takes care not to insert overlapping loads.
/// \pre \p N is a simple load (i.e., not an atomic or volatile load).
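///
/// Illustrative example (added for exposition): a load i32 whose only uses
/// are (trunc i16) and (trunc i16 (srl 16)) covers bits [0, 16) and
/// [16, 32); the two slices are dense, so the load may be replaced by two
/// independent i16 loads when the cost model agrees.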
bool DAGCombiner::SliceUpLoad(SDNode *N) {
  if (Level < AfterLegalizeDAG)
    return false;

  LoadSDNode *LD = cast<LoadSDNode>(N);
  if (!LD->isSimple() || !ISD::isNormalLoad(LD) ||
      !LD->getValueType(0).isInteger())
    return false;

  // The algorithm to split up a load of a scalable vector into individual
  // elements currently requires knowing the length of the loaded type,
  // so will need adjusting to work on scalable vectors.
  if (LD->getValueType(0).isScalableVector())
    return false;

  // Keep track of already used bits to detect overlapping values.
  // In that case, we will just abort the transformation.
  APInt UsedBits(LD->getValueSizeInBits(0), 0);

  SmallVector<LoadedSlice, 4> LoadedSlices;

  // Check if this load is used as several smaller chunks of bits.
  // Basically, look for uses in trunc or trunc(lshr) and record a new chain
  // of computation for each trunc.
  for (SDNode::use_iterator UI = LD->use_begin(), UIEnd = LD->use_end();
       UI != UIEnd; ++UI) {
    // Skip the uses of the chain.
    if (UI.getUse().getResNo() != 0)
      continue;

    SDNode *User = *UI;
    unsigned Shift = 0;

    // Check if this is a trunc(lshr).
    if (User->getOpcode() == ISD::SRL && User->hasOneUse() &&
        isa<ConstantSDNode>(User->getOperand(1))) {
      Shift = User->getConstantOperandVal(1);
      User = *User->use_begin();
    }

    // At this point, User is a truncate iff we encountered trunc or
    // trunc(lshr).
    if (User->getOpcode() != ISD::TRUNCATE)
      return false;

    // The width of the type must be a power of 2 and at least 8 bits.
    // Otherwise the load cannot be represented in LLVM IR.
    // Moreover, if we shifted by an amount that is not a multiple of 8 bits,
    // the slice will span several bytes. We do not support that.
    unsigned Width = User->getValueSizeInBits(0);
    if (Width < 8 || !isPowerOf2_32(Width) || (Shift & 0x7))
      return false;

    // Build the slice for this chain of computations.
    LoadedSlice LS(User, LD, Shift, &DAG);
    APInt CurrentUsedBits = LS.getUsedBits();

    // Check if this slice overlaps with another.
    if ((CurrentUsedBits & UsedBits) != 0)
      return false;
    // Update the bits used globally.
    UsedBits |= CurrentUsedBits;

    // Check if the new slice would be legal.
    if (!LS.isLegal())
      return false;

    // Record the slice.
    LoadedSlices.push_back(LS);
  }

  // Abort slicing if it does not seem to be profitable.
  if (!isSlicingProfitable(LoadedSlices, UsedBits, ForCodeSize))
    return false;

  ++SlicedLoads;

  // Rewrite each chain to use an independent load.
  // By construction, each chain can be represented by a unique load.

  // Prepare the argument for the new token factor for all the slices.
  SmallVector<SDValue, 8> ArgChains;
  for (const LoadedSlice &LS : LoadedSlices) {
    SDValue SliceInst = LS.loadSlice();
    CombineTo(LS.Inst, SliceInst, true);
    if (SliceInst.getOpcode() != ISD::LOAD)
      SliceInst = SliceInst.getOperand(0);
    assert(SliceInst->getOpcode() == ISD::LOAD &&
           "It takes more than a zext to get to the loaded slice!!");
    ArgChains.push_back(SliceInst.getValue(1));
  }

  SDValue Chain = DAG.getNode(ISD::TokenFactor, SDLoc(LD), MVT::Other,
                              ArgChains);
  DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain);
  AddToWorklist(Chain.getNode());
  return true;
}

/// Check to see if V is (and (load ptr), imm), where the load has
/// specific bytes cleared out.  If so, return the byte size being masked out
/// and the shift amount.
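///
/// Illustrative example (added for exposition): for
///   V = (and (load i32 Ptr), 0xffff00ff)
/// byte 1 is cleared, so the result is {1, 1}: one byte masked out,
/// starting one byte above the base.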
static std::pair<unsigned, unsigned>
CheckForMaskedLoad(SDValue V, SDValue Ptr, SDValue Chain) {
  std::pair<unsigned, unsigned> Result(0, 0);

  // Check for the structure we're looking for.
  if (V->getOpcode() != ISD::AND ||
      !isa<ConstantSDNode>(V->getOperand(1)) ||
      !ISD::isNormalLoad(V->getOperand(0).getNode()))
    return Result;

  // Check the chain and pointer.
  LoadSDNode *LD = cast<LoadSDNode>(V->getOperand(0));
  if (LD->getBasePtr() != Ptr) return Result;  // Not from same pointer.

  // This only handles simple types.
  if (V.getValueType() != MVT::i16 &&
      V.getValueType() != MVT::i32 &&
      V.getValueType() != MVT::i64)
    return Result;

  // Check the constant mask.  Invert it so that the bits being masked out are
  // 0 and the bits being kept are 1.  Use getSExtValue so that leading bits
  // follow the sign bit for uniformity.
  uint64_t NotMask = ~cast<ConstantSDNode>(V->getOperand(1))->getSExtValue();
  unsigned NotMaskLZ = llvm::countl_zero(NotMask);
  if (NotMaskLZ & 7) return Result;  // Must be multiple of a byte.
  unsigned NotMaskTZ = llvm::countr_zero(NotMask);
  if (NotMaskTZ & 7) return Result;  // Must be multiple of a byte.
  if (NotMaskLZ == 64) return Result;  // All zero mask.

  // See if we have a continuous run of bits.  If so, we have 0*1+0*
  if (llvm::countr_one(NotMask >> NotMaskTZ) + NotMaskTZ + NotMaskLZ != 64)
    return Result;

  // Adjust NotMaskLZ down to be from the actual size of the int instead of i64.
  if (V.getValueType() != MVT::i64 && NotMaskLZ)
    NotMaskLZ -= 64-V.getValueSizeInBits();

  unsigned MaskedBytes = (V.getValueSizeInBits()-NotMaskLZ-NotMaskTZ)/8;
  switch (MaskedBytes) {
  case 1:
  case 2:
  case 4: break;
  default: return Result; // All one mask, or 5-byte mask.
  }

  // Verify that the first bit starts at a multiple of the mask size so that
  // the access is aligned the same as the access width.
  if (NotMaskTZ && NotMaskTZ/8 % MaskedBytes) return Result;

  // For narrowing to be valid, the load must be the memory operation
  // immediately preceding the store.
  if (LD == Chain.getNode())
    ; // ok.
  else if (Chain->getOpcode() == ISD::TokenFactor &&
           SDValue(LD, 1).hasOneUse()) {
    // LD has only 1 chain use, so there are no indirect dependencies.
    if (!LD->isOperandOf(Chain.getNode()))
      return Result;
  } else
    return Result; // Fail.

  Result.first = MaskedBytes;
  Result.second = NotMaskTZ/8;
  return Result;
}

/// Check to see if IVal is something that provides a value as specified by
/// MaskInfo. If so, replace the specified store with a narrower store of
/// truncated IVal.
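///
/// Illustrative example (added for exposition): with MaskInfo == {1, 1} on
/// a little-endian target, a store of
///   (or (and (load Ptr), 0xffff00ff), IVal)
/// can become an i8 store of (IVal >> 8) at Ptr + 1, provided the other
/// bytes of IVal are known to be zero.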
static SDValue
ShrinkLoadReplaceStoreWithStore(const std::pair<unsigned, unsigned> &MaskInfo,
                                SDValue IVal, StoreSDNode *St,
                                DAGCombiner *DC) {
  unsigned NumBytes = MaskInfo.first;
  unsigned ByteShift = MaskInfo.second;
  SelectionDAG &DAG = DC->getDAG();

  // Check to see if IVal is all zeros in the part being masked in by the 'or'
  // that uses this.  If not, this is not a replacement.
  APInt Mask = ~APInt::getBitsSet(IVal.getValueSizeInBits(),
                                  ByteShift*8, (ByteShift+NumBytes)*8);
  if (!DAG.MaskedValueIsZero(IVal, Mask)) return SDValue();

  // Check that it is legal on the target to do this.  It is legal if the new
  // VT we're shrinking to (i8/i16/i32) is legal or we're still before type
  // legalization. If the source type is legal, but the store type isn't, see
  // if we can use a truncating store.
  MVT VT = MVT::getIntegerVT(NumBytes * 8);
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  bool UseTruncStore;
  if (DC->isTypeLegal(VT))
    UseTruncStore = false;
  else if (TLI.isTypeLegal(IVal.getValueType()) &&
           TLI.isTruncStoreLegal(IVal.getValueType(), VT))
    UseTruncStore = true;
  else
    return SDValue();

  // Can't do this for indexed stores.
  if (St->isIndexed())
    return SDValue();

  // Check that the target doesn't think this is a bad idea.
  if (St->getMemOperand() &&
      !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
                              *St->getMemOperand()))
    return SDValue();

  // Okay, we can do this!  Replace the 'St' store with a store of IVal that is
  // shifted by ByteShift and truncated down to NumBytes.
  if (ByteShift) {
    SDLoc DL(IVal);
    IVal = DAG.getNode(ISD::SRL, DL, IVal.getValueType(), IVal,
                       DAG.getConstant(ByteShift*8, DL,
                                    DC->getShiftAmountTy(IVal.getValueType())));
  }

  // Figure out the offset for the store and the alignment of the access.
  unsigned StOffset;
  if (DAG.getDataLayout().isLittleEndian())
    StOffset = ByteShift;
  else
    StOffset = IVal.getValueType().getStoreSize() - ByteShift - NumBytes;

  SDValue Ptr = St->getBasePtr();
  if (StOffset) {
    SDLoc DL(IVal);
    Ptr = DAG.getMemBasePlusOffset(Ptr, TypeSize::getFixed(StOffset), DL);
  }

  ++OpsNarrowed;
  if (UseTruncStore)
    return DAG.getTruncStore(St->getChain(), SDLoc(St), IVal, Ptr,
                             St->getPointerInfo().getWithOffset(StOffset),
                             VT, St->getOriginalAlign());

  // Truncate down to the new size.
  IVal = DAG.getNode(ISD::TRUNCATE, SDLoc(IVal), VT, IVal);

  return DAG
      .getStore(St->getChain(), SDLoc(St), IVal, Ptr,
                St->getPointerInfo().getWithOffset(StOffset),
                St->getOriginalAlign());
}

/// Look for sequence of load / op / store where op is one of 'or', 'xor', and
/// 'and' of immediates. If 'op' is only touching some of the loaded bits, try
/// narrowing the load and store if it would end up being a win for performance
/// or code size.
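///
/// Illustrative example (added for exposition): on a little-endian target,
///   store (or (load i32 Ptr), 0x00ff0000), Ptr
/// only changes byte 2, so it can be narrowed to an i8 load from Ptr + 2,
/// an i8 'or' with 0xff, and an i8 store back to Ptr + 2, provided the
/// target finds the narrow access legal and fast.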
SDValue DAGCombiner::ReduceLoadOpStoreWidth(SDNode *N) {
  StoreSDNode *ST  = cast<StoreSDNode>(N);
  if (!ST->isSimple())
    return SDValue();

  SDValue Chain = ST->getChain();
  SDValue Value = ST->getValue();
  SDValue Ptr   = ST->getBasePtr();
  EVT VT = Value.getValueType();

  if (ST->isTruncatingStore() || VT.isVector())
    return SDValue();

  unsigned Opc = Value.getOpcode();

  if ((Opc != ISD::OR && Opc != ISD::XOR && Opc != ISD::AND) ||
      !Value.hasOneUse())
    return SDValue();

  // If this is "store (or X, Y), P" and X is "(and (load P), cst)", where cst
  // is a byte mask indicating a consecutive number of bytes, check to see if
  // Y is known to provide just those bytes.  If so, we try to replace the
  // load + replace + store sequence with a single (narrower) store, which makes
  // the load dead.
  if (Opc == ISD::OR && EnableShrinkLoadReplaceStoreWithStore) {
    std::pair<unsigned, unsigned> MaskedLoad;
    MaskedLoad = CheckForMaskedLoad(Value.getOperand(0), Ptr, Chain);
    if (MaskedLoad.first)
      if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskedLoad,
                                                  Value.getOperand(1), ST,this))
        return NewST;

    // Or is commutative, so try swapping X and Y.
    MaskedLoad = CheckForMaskedLoad(Value.getOperand(1), Ptr, Chain);
    if (MaskedLoad.first)
      if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskedLoad,
                                                  Value.getOperand(0), ST,this))
        return NewST;
  }

  if (!EnableReduceLoadOpStoreWidth)
    return SDValue();

  if (Value.getOperand(1).getOpcode() != ISD::Constant)
    return SDValue();

  SDValue N0 = Value.getOperand(0);
  if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
      Chain == SDValue(N0.getNode(), 1)) {
    LoadSDNode *LD = cast<LoadSDNode>(N0);
    if (LD->getBasePtr() != Ptr ||
        LD->getPointerInfo().getAddrSpace() !=
        ST->getPointerInfo().getAddrSpace())
      return SDValue();

    // Find the type to which to narrow the load / op / store.
    SDValue N1 = Value.getOperand(1);
    unsigned BitWidth = N1.getValueSizeInBits();
    APInt Imm = N1->getAsAPIntVal();
    if (Opc == ISD::AND)
      Imm ^= APInt::getAllOnes(BitWidth);
    if (Imm == 0 || Imm.isAllOnes())
      return SDValue();
    unsigned ShAmt = Imm.countr_zero();
    unsigned MSB = BitWidth - Imm.countl_zero() - 1;
    unsigned NewBW = NextPowerOf2(MSB - ShAmt);
    EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), NewBW);
    // The narrowing should be profitable, the load/store operation should be
    // legal (or custom) and the store size should be equal to the NewVT width.
    while (NewBW < BitWidth &&
           (NewVT.getStoreSizeInBits() != NewBW ||
            !TLI.isOperationLegalOrCustom(Opc, NewVT) ||
            !TLI.isNarrowingProfitable(VT, NewVT))) {
      NewBW = NextPowerOf2(NewBW);
      NewVT = EVT::getIntegerVT(*DAG.getContext(), NewBW);
    }
    if (NewBW >= BitWidth)
      return SDValue();

    // If the changed bits do not start at a NewBW-bit boundary, round the
    // shift amount down to the previous boundary.
    if (ShAmt % NewBW)
      ShAmt = (((ShAmt + NewBW - 1) / NewBW) * NewBW) - NewBW;
    APInt Mask = APInt::getBitsSet(BitWidth, ShAmt,
                                   std::min(BitWidth, ShAmt + NewBW));
    if ((Imm & Mask) == Imm) {
      APInt NewImm = (Imm & Mask).lshr(ShAmt).trunc(NewBW);
      if (Opc == ISD::AND)
        NewImm ^= APInt::getAllOnes(NewBW);
      uint64_t PtrOff = ShAmt / 8;
      // For big endian targets, we need to adjust the offset to the pointer to
      // load the correct bytes.
      if (DAG.getDataLayout().isBigEndian())
        PtrOff = (BitWidth + 7 - NewBW) / 8 - PtrOff;

      unsigned IsFast = 0;
      Align NewAlign = commonAlignment(LD->getAlign(), PtrOff);
      if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), NewVT,
                                  LD->getAddressSpace(), NewAlign,
                                  LD->getMemOperand()->getFlags(), &IsFast) ||
          !IsFast)
        return SDValue();

      SDValue NewPtr =
          DAG.getMemBasePlusOffset(Ptr, TypeSize::getFixed(PtrOff), SDLoc(LD));
      SDValue NewLD =
          DAG.getLoad(NewVT, SDLoc(N0), LD->getChain(), NewPtr,
                      LD->getPointerInfo().getWithOffset(PtrOff), NewAlign,
                      LD->getMemOperand()->getFlags(), LD->getAAInfo());
      SDValue NewVal = DAG.getNode(Opc, SDLoc(Value), NewVT, NewLD,
                                   DAG.getConstant(NewImm, SDLoc(Value),
                                                   NewVT));
      SDValue NewST =
          DAG.getStore(Chain, SDLoc(N), NewVal, NewPtr,
                       ST->getPointerInfo().getWithOffset(PtrOff), NewAlign);

      AddToWorklist(NewPtr.getNode());
      AddToWorklist(NewLD.getNode());
      AddToWorklist(NewVal.getNode());
      WorklistRemover DeadNodes(*this);
      DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), NewLD.getValue(1));
      ++OpsNarrowed;
      return NewST;
    }
  }

  return SDValue();
}

/// For a given floating point load / store pair, if the load value isn't used
/// by any other operations, then consider transforming the pair to integer
/// load / store operations if the target deems the transformation profitable.
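///
/// Illustrative example (added for exposition):
///   (store f64 (load f64 Ptr), Ptr2)
/// becomes
///   (store i64 (load i64 Ptr), Ptr2)
/// when the target reports the integer load/store as legal, fast, and
/// desirable, avoiding a round trip through the floating-point registers.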
SDValue DAGCombiner::TransformFPLoadStorePair(SDNode *N) {
  StoreSDNode *ST  = cast<StoreSDNode>(N);
  SDValue Value = ST->getValue();
  if (ISD::isNormalStore(ST) && ISD::isNormalLoad(Value.getNode()) &&
      Value.hasOneUse()) {
    LoadSDNode *LD = cast<LoadSDNode>(Value);
    EVT VT = LD->getMemoryVT();
    if (!VT.isFloatingPoint() ||
        VT != ST->getMemoryVT() ||
        LD->isNonTemporal() ||
        ST->isNonTemporal() ||
        LD->getPointerInfo().getAddrSpace() != 0 ||
        ST->getPointerInfo().getAddrSpace() != 0)
      return SDValue();

    TypeSize VTSize = VT.getSizeInBits();

    // We don't know the size of scalable types at compile time so we cannot
    // create an integer of the equivalent size.
    if (VTSize.isScalable())
      return SDValue();

    unsigned FastLD = 0, FastST = 0;
    EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), VTSize.getFixedValue());
    if (!TLI.isOperationLegal(ISD::LOAD, IntVT) ||
        !TLI.isOperationLegal(ISD::STORE, IntVT) ||
        !TLI.isDesirableToTransformToIntegerOp(ISD::LOAD, VT) ||
        !TLI.isDesirableToTransformToIntegerOp(ISD::STORE, VT) ||
        !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), IntVT,
                                *LD->getMemOperand(), &FastLD) ||
        !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), IntVT,
                                *ST->getMemOperand(), &FastST) ||
        !FastLD || !FastST)
      return SDValue();

    SDValue NewLD =
        DAG.getLoad(IntVT, SDLoc(Value), LD->getChain(), LD->getBasePtr(),
                    LD->getPointerInfo(), LD->getAlign());

    SDValue NewST =
        DAG.getStore(ST->getChain(), SDLoc(N), NewLD, ST->getBasePtr(),
                     ST->getPointerInfo(), ST->getAlign());

    AddToWorklist(NewLD.getNode());
    AddToWorklist(NewST.getNode());
    WorklistRemover DeadNodes(*this);
    DAG.ReplaceAllUsesOfValueWith(Value.getValue(1), NewLD.getValue(1));
    ++LdStFP2Int;
    return NewST;
  }

  return SDValue();
}

// This is a helper function for visitMUL to check the profitability
// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
// MulNode is the original multiply, AddNode is (add x, c1),
// and ConstNode is c2.
//
// If the (add x, c1) has multiple uses, we could increase
// the number of adds if we make this transformation.
// It would only be worth doing this if we can remove a
// multiply in the process. Check for that here.
// To illustrate:
//     (A + c1) * c3
//     (A + c2) * c3
// We're checking for cases where we have common "c3 * A" expressions.
bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode, SDValue AddNode,
                                              SDValue ConstNode) {
  APInt Val;

  // If the add only has one use, and the target thinks the folding is
  // profitable or does not lead to worse code, this would be OK to do.
  if (AddNode->hasOneUse() &&
      TLI.isMulAddWithConstProfitable(AddNode, ConstNode))
    return true;

  // Walk all the users of the constant with which we're multiplying.
  for (SDNode *Use : ConstNode->uses()) {
    if (Use == MulNode) // This use is the one we're on right now. Skip it.
      continue;

    if (Use->getOpcode() == ISD::MUL) { // We have another multiply use.
      SDNode *OtherOp;
      SDNode *MulVar = AddNode.getOperand(0).getNode();

      // OtherOp is what we're multiplying against the constant.
      if (Use->getOperand(0) == ConstNode)
        OtherOp = Use->getOperand(1).getNode();
      else
        OtherOp = Use->getOperand(0).getNode();

      // Check to see if multiply is with the same operand of our "add".
      //
      //     ConstNode  = CONST
      //     Use = ConstNode * A  <-- visiting Use. OtherOp is A.
      //     ...
      //     AddNode  = (A + c1)  <-- MulVar is A.
      //         = AddNode * ConstNode   <-- current visiting instruction.
      //
      // If we make this transformation, we will have a common
      // multiply (ConstNode * A) that we can save.
      if (OtherOp == MulVar)
        return true;

      // Now check to see if a future expansion will give us a common
      // multiply.
      //
      //     ConstNode  = CONST
      //     AddNode    = (A + c1)
      //     ...   = AddNode * ConstNode <-- current visiting instruction.
      //     ...
      //     OtherOp = (A + c2)
      //     Use     = OtherOp * ConstNode <-- visiting Use.
      //
      // If we make this transformation, we will have a common
      // multiply (CONST * A) after we also do the same transformation
      // to the "Use" instruction.
      if (OtherOp->getOpcode() == ISD::ADD &&
          DAG.isConstantIntBuildVectorOrConstantInt(OtherOp->getOperand(1)) &&
          OtherOp->getOperand(0).getNode() == MulVar)
        return true;
    }
  }

  // Didn't find a case where this would be profitable.
  return false;
}

SDValue DAGCombiner::getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
                                         unsigned NumStores) {
  SmallVector<SDValue, 8> Chains;
  SmallPtrSet<const SDNode *, 8> Visited;
  SDLoc StoreDL(StoreNodes[0].MemNode);

  for (unsigned i = 0; i < NumStores; ++i) {
    Visited.insert(StoreNodes[i].MemNode);
  }

  // Don't include nodes that are children or repeated nodes.
  for (unsigned i = 0; i < NumStores; ++i) {
    if (Visited.insert(StoreNodes[i].MemNode->getChain().getNode()).second)
      Chains.push_back(StoreNodes[i].MemNode->getChain());
  }

  assert(!Chains.empty() && "Chain should have generated a chain");
  return DAG.getTokenFactor(StoreDL, Chains);
}

bool DAGCombiner::hasSameUnderlyingObj(ArrayRef<MemOpLink> StoreNodes) {
  const Value *UnderlyingObj = nullptr;
  for (const auto &MemOp : StoreNodes) {
    const MachineMemOperand *MMO = MemOp.MemNode->getMemOperand();
    // A pseudo value like a stack frame has its own frame index and size; we
    // should not use the first store's frame index for other frames.
    if (MMO->getPseudoValue())
      return false;

    if (!MMO->getValue())
      return false;

    const Value *Obj = getUnderlyingObject(MMO->getValue());

    if (UnderlyingObj && UnderlyingObj != Obj)
      return false;

    if (!UnderlyingObj)
      UnderlyingObj = Obj;
  }

  return true;
}

bool DAGCombiner::mergeStoresOfConstantsOrVecElts(
    SmallVectorImpl<MemOpLink> &StoreNodes, EVT MemVT, unsigned NumStores,
    bool IsConstantSrc, bool UseVector, bool UseTrunc) {
  // Make sure we have something to merge.
  if (NumStores < 2)
    return false;

  assert((!UseTrunc || !UseVector) &&
         "This optimization cannot emit a vector truncating store");

  // The latest Node in the DAG.
  SDLoc DL(StoreNodes[0].MemNode);

  TypeSize ElementSizeBits = MemVT.getStoreSizeInBits();
  unsigned SizeInBits = NumStores * ElementSizeBits;
  unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;

  std::optional<MachineMemOperand::Flags> Flags;
  AAMDNodes AAInfo;
  for (unsigned I = 0; I != NumStores; ++I) {
    StoreSDNode *St = cast<StoreSDNode>(StoreNodes[I].MemNode);
    if (!Flags) {
      Flags = St->getMemOperand()->getFlags();
      AAInfo = St->getAAInfo();
      continue;
    }
    // Skip merging if there's an inconsistent flag.
    if (Flags != St->getMemOperand()->getFlags())
      return false;
    // Concatenate AA metadata.
    AAInfo = AAInfo.concat(St->getAAInfo());
  }

  EVT StoreTy;
  if (UseVector) {
    unsigned Elts = NumStores * NumMemElts;
    // Get the type for the merged vector store.
    StoreTy = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
  } else
    StoreTy = EVT::getIntegerVT(*DAG.getContext(), SizeInBits);

  SDValue StoredVal;
  if (UseVector) {
    if (IsConstantSrc) {
      SmallVector<SDValue, 8> BuildVector;
      for (unsigned I = 0; I != NumStores; ++I) {
        StoreSDNode *St = cast<StoreSDNode>(StoreNodes[I].MemNode);
        SDValue Val = St->getValue();
        // If constant is of the wrong type, convert it now.  This comes up
        // when one of our stores was truncating.
        if (MemVT != Val.getValueType()) {
          Val = peekThroughBitcasts(Val);
          // Deal with constants of wrong size.
          if (ElementSizeBits != Val.getValueSizeInBits()) {
            auto *C = dyn_cast<ConstantSDNode>(Val);
            if (!C)
              // Not clear how to truncate FP values.
              // TODO: Handle truncation of build_vector constants
              return false;

            EVT IntMemVT =
                EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits());
            Val = DAG.getConstant(C->getAPIntValue()
                                      .zextOrTrunc(Val.getValueSizeInBits())
                                      .zextOrTrunc(ElementSizeBits),
                                  SDLoc(C), IntMemVT);
          }
          // Make sure the correctly-sized value also has the correct type.
          Val = DAG.getBitcast(MemVT, Val);
        }
        BuildVector.push_back(Val);
      }
      StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS
                                               : ISD::BUILD_VECTOR,
                              DL, StoreTy, BuildVector);
    } else {
      SmallVector<SDValue, 8> Ops;
      for (unsigned i = 0; i < NumStores; ++i) {
        StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
        SDValue Val = peekThroughBitcasts(St->getValue());
        // All operands of BUILD_VECTOR / CONCAT_VECTOR must be of
        // type MemVT. If the underlying value is not the correct
        // type, but it is an extraction of an appropriate vector we
        // can recast Val to be of the correct type. This may require
        // converting between EXTRACT_VECTOR_ELT and
        // EXTRACT_SUBVECTOR.
        if ((MemVT != Val.getValueType()) &&
            (Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
             Val.getOpcode() == ISD::EXTRACT_SUBVECTOR)) {
          EVT MemVTScalarTy = MemVT.getScalarType();
          // We may need to add a bitcast here to get types to line up.
          if (MemVTScalarTy != Val.getValueType().getScalarType()) {
            Val = DAG.getBitcast(MemVT, Val);
          } else if (MemVT.isVector() &&
                     Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
            Val = DAG.getNode(ISD::BUILD_VECTOR, DL, MemVT, Val);
          } else {
            unsigned OpC = MemVT.isVector() ? ISD::EXTRACT_SUBVECTOR
                                            : ISD::EXTRACT_VECTOR_ELT;
            SDValue Vec = Val.getOperand(0);
            SDValue Idx = Val.getOperand(1);
            Val = DAG.getNode(OpC, SDLoc(Val), MemVT, Vec, Idx);
          }
        }
        Ops.push_back(Val);
      }

      // Build the extracted vector elements back into a vector.
      StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS
                                               : ISD::BUILD_VECTOR,
                              DL, StoreTy, Ops);
    }
  } else {
    // We should always use a vector store when merging extracted vector
    // elements, so this path implies a store of constants.
    assert(IsConstantSrc && "Merged vector elements should use vector store");

    APInt StoreInt(SizeInBits, 0);

    // Construct a single integer constant which is made of the smaller
    // constant inputs.
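    // Illustrative example (added for exposition): merging an i8 store of
    // 0x11 at Ptr with an i8 store of 0x22 at Ptr + 1 on a little-endian
    // target visits the higher-addressed value first, producing
    // StoreInt == 0x2211, i.e. a single i16 store at Ptr.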
    bool IsLE = DAG.getDataLayout().isLittleEndian();
    for (unsigned i = 0; i < NumStores; ++i) {
      unsigned Idx = IsLE ? (NumStores - 1 - i) : i;
      StoreSDNode *St  = cast<StoreSDNode>(StoreNodes[Idx].MemNode);

      SDValue Val = St->getValue();
      Val = peekThroughBitcasts(Val);
      StoreInt <<= ElementSizeBits;
      if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val)) {
        StoreInt |= C->getAPIntValue()
                        .zextOrTrunc(ElementSizeBits)
                        .zextOrTrunc(SizeInBits);
      } else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Val)) {
        StoreInt |= C->getValueAPF()
                        .bitcastToAPInt()
                        .zextOrTrunc(ElementSizeBits)
                        .zextOrTrunc(SizeInBits);
        // If fp truncation is necessary give up for now.
        if (MemVT.getSizeInBits() != ElementSizeBits)
          return false;
      } else if (ISD::isBuildVectorOfConstantSDNodes(Val.getNode()) ||
                 ISD::isBuildVectorOfConstantFPSDNodes(Val.getNode())) {
        // Not yet handled
        return false;
      } else {
        llvm_unreachable("Invalid constant element type");
      }
    }

    // Create the new Load and Store operations.
    StoredVal = DAG.getConstant(StoreInt, DL, StoreTy);
  }

  LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
  SDValue NewChain = getMergeStoreChains(StoreNodes, NumStores);
  bool CanReusePtrInfo = hasSameUnderlyingObj(StoreNodes);

  // Make sure we use a truncating store if that's necessary for legality.
  // When generating the new widened store, if the first store's pointer info
  // cannot be reused, discard the pointer info (except for the address space),
  // because the widened store can no longer be represented by the original
  // pointer info, which describes the narrower memory object.
20011   SDValue NewStore;
20012   if (!UseTrunc) {
20013     NewStore = DAG.getStore(
20014         NewChain, DL, StoredVal, FirstInChain->getBasePtr(),
20015         CanReusePtrInfo
20016             ? FirstInChain->getPointerInfo()
20017             : MachinePointerInfo(FirstInChain->getPointerInfo().getAddrSpace()),
20018         FirstInChain->getAlign(), *Flags, AAInfo);
20019   } else { // Must be realized as a trunc store
20020     EVT LegalizedStoredValTy =
20021         TLI.getTypeToTransformTo(*DAG.getContext(), StoredVal.getValueType());
20022     unsigned LegalizedStoreSize = LegalizedStoredValTy.getSizeInBits();
20023     ConstantSDNode *C = cast<ConstantSDNode>(StoredVal);
20024     SDValue ExtendedStoreVal =
20025         DAG.getConstant(C->getAPIntValue().zextOrTrunc(LegalizedStoreSize), DL,
20026                         LegalizedStoredValTy);
20027     NewStore = DAG.getTruncStore(
20028         NewChain, DL, ExtendedStoreVal, FirstInChain->getBasePtr(),
20029         CanReusePtrInfo
20030             ? FirstInChain->getPointerInfo()
20031             : MachinePointerInfo(FirstInChain->getPointerInfo().getAddrSpace()),
20032         StoredVal.getValueType() /*TVT*/, FirstInChain->getAlign(), *Flags,
20033         AAInfo);
20034   }
20035 
20036   // Replace all merged stores with the new store.
20037   for (unsigned i = 0; i < NumStores; ++i)
20038     CombineTo(StoreNodes[i].MemNode, NewStore);
20039 
20040   AddToWorklist(NewChain.getNode());
20041   return true;
20042 }
20043 
getStoreMergeCandidates(StoreSDNode * St,SmallVectorImpl<MemOpLink> & StoreNodes,SDNode * & RootNode)20044 void DAGCombiner::getStoreMergeCandidates(
20045     StoreSDNode *St, SmallVectorImpl<MemOpLink> &StoreNodes,
20046     SDNode *&RootNode) {
20047   // This holds the base pointer, index, and the offset in bytes from the base
20048   // pointer. We must have a base and an offset. Do not handle stores to undef
20049   // base pointers.
20050   BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
20051   if (!BasePtr.getBase().getNode() || BasePtr.getBase().isUndef())
20052     return;
20053 
20054   SDValue Val = peekThroughBitcasts(St->getValue());
20055   StoreSource StoreSrc = getStoreSource(Val);
20056   assert(StoreSrc != StoreSource::Unknown && "Expected known source for store");
20057 
20058   // Match on loadbaseptr if relevant.
20059   EVT MemVT = St->getMemoryVT();
20060   BaseIndexOffset LBasePtr;
20061   EVT LoadVT;
20062   if (StoreSrc == StoreSource::Load) {
20063     auto *Ld = cast<LoadSDNode>(Val);
20064     LBasePtr = BaseIndexOffset::match(Ld, DAG);
20065     LoadVT = Ld->getMemoryVT();
20066     // Load and store should be the same type.
20067     if (MemVT != LoadVT)
20068       return;
20069     // Loads must only have one use.
20070     if (!Ld->hasNUsesOfValue(1, 0))
20071       return;
20072     // The memory operands must not be volatile/indexed/atomic.
20073     // TODO: May be able to relax for unordered atomics (see D66309)
20074     if (!Ld->isSimple() || Ld->isIndexed())
20075       return;
20076   }
20077   auto CandidateMatch = [&](StoreSDNode *Other, BaseIndexOffset &Ptr,
20078                             int64_t &Offset) -> bool {
20079     // The memory operands must not be volatile/indexed/atomic.
20080     // TODO: May be able to relax for unordered atomics (see D66309)
20081     if (!Other->isSimple() || Other->isIndexed())
20082       return false;
20083     // Don't mix temporal stores with non-temporal stores.
20084     if (St->isNonTemporal() != Other->isNonTemporal())
20085       return false;
20086     if (!TLI.areTwoSDNodeTargetMMOFlagsMergeable(*St, *Other))
20087       return false;
20088     SDValue OtherBC = peekThroughBitcasts(Other->getValue());
20089     // Allow merging constants of different types as integers.
20090     bool NoTypeMatch = (MemVT.isInteger()) ? !MemVT.bitsEq(Other->getMemoryVT())
20091                                            : Other->getMemoryVT() != MemVT;
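    // Illustrative example: with an integer MemVT such as i32, a candidate
    // storing a bitcast f32 constant still has a 32-bit memory type, so it
    // passes this check and can be merged as an integer.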
20092     switch (StoreSrc) {
20093     case StoreSource::Load: {
20094       if (NoTypeMatch)
20095         return false;
20096       // The Load's Base Ptr must also match.
20097       auto *OtherLd = dyn_cast<LoadSDNode>(OtherBC);
20098       if (!OtherLd)
20099         return false;
20100       BaseIndexOffset LPtr = BaseIndexOffset::match(OtherLd, DAG);
20101       if (LoadVT != OtherLd->getMemoryVT())
20102         return false;
20103       // Loads must only have one use.
20104       if (!OtherLd->hasNUsesOfValue(1, 0))
20105         return false;
20106       // The memory operands must not be volatile/indexed/atomic.
20107       // TODO: May be able to relax for unordered atomics (see D66309)
20108       if (!OtherLd->isSimple() || OtherLd->isIndexed())
20109         return false;
20110       // Don't mix temporal loads with non-temporal loads.
20111       if (cast<LoadSDNode>(Val)->isNonTemporal() != OtherLd->isNonTemporal())
20112         return false;
20113       if (!TLI.areTwoSDNodeTargetMMOFlagsMergeable(*cast<LoadSDNode>(Val),
20114                                                    *OtherLd))
20115         return false;
20116       if (!(LBasePtr.equalBaseIndex(LPtr, DAG)))
20117         return false;
20118       break;
20119     }
20120     case StoreSource::Constant:
20121       if (NoTypeMatch)
20122         return false;
20123       if (getStoreSource(OtherBC) != StoreSource::Constant)
20124         return false;
20125       break;
20126     case StoreSource::Extract:
20127       // Do not merge truncated stores here.
20128       if (Other->isTruncatingStore())
20129         return false;
20130       if (!MemVT.bitsEq(OtherBC.getValueType()))
20131         return false;
20132       if (OtherBC.getOpcode() != ISD::EXTRACT_VECTOR_ELT &&
20133           OtherBC.getOpcode() != ISD::EXTRACT_SUBVECTOR)
20134         return false;
20135       break;
20136     default:
20137       llvm_unreachable("Unhandled store source for merging");
20138     }
20139     Ptr = BaseIndexOffset::match(Other, DAG);
20140     return (BasePtr.equalBaseIndex(Ptr, DAG, Offset));
20141   };
20142 
20143   // Check if the pair of StoreNode and RootNode has already bailed out of
20144   // the dependence check more times than the limit allows.
20145   auto OverLimitInDependenceCheck = [&](SDNode *StoreNode,
20146                                         SDNode *RootNode) -> bool {
20147     auto RootCount = StoreRootCountMap.find(StoreNode);
20148     return RootCount != StoreRootCountMap.end() &&
20149            RootCount->second.first == RootNode &&
20150            RootCount->second.second > StoreMergeDependenceLimit;
20151   };
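  // Note: the counter consulted above is incremented in
  // checkMergeStoreCandidatesForDependencies whenever its dependence search
  // bails out, so repeatedly expensive (store, root) pairs are skipped.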
20152 
20153   auto TryToAddCandidate = [&](SDNode::use_iterator UseIter) {
20154     // This must be a chain use.
20155     if (UseIter.getOperandNo() != 0)
20156       return;
20157     if (auto *OtherStore = dyn_cast<StoreSDNode>(*UseIter)) {
20158       BaseIndexOffset Ptr;
20159       int64_t PtrDiff;
20160       if (CandidateMatch(OtherStore, Ptr, PtrDiff) &&
20161           !OverLimitInDependenceCheck(OtherStore, RootNode))
20162         StoreNodes.push_back(MemOpLink(OtherStore, PtrDiff));
20163     }
20164   };
20165 
20166   // We are looking for a root node which is an ancestor to all mergeable
20167   // stores. We search up through a load, to our root and then down
20168   // through all children. For instance we will find Store{1,2,3} if
20169   // St is Store1, Store2, or Store3 where the root is not a load,
20170   // which is always true for nonvolatile ops. TODO: Expand
20171   // the search to find all valid candidates through multiple layers of loads.
20172   //
20173   // Root
20174   // |-------|-------|
20175   // Load    Load    Store3
20176   // |       |
20177   // Store1   Store2
20178   //
20179   // FIXME: We should be able to climb and
20180   // descend TokenFactors to find candidates as well.
20181 
20182   RootNode = St->getChain().getNode();
20183 
20184   unsigned NumNodesExplored = 0;
20185   const unsigned MaxSearchNodes = 1024;
20186   if (auto *Ldn = dyn_cast<LoadSDNode>(RootNode)) {
20187     RootNode = Ldn->getChain().getNode();
20188     for (auto I = RootNode->use_begin(), E = RootNode->use_end();
20189          I != E && NumNodesExplored < MaxSearchNodes; ++I, ++NumNodesExplored) {
20190       if (I.getOperandNo() == 0 && isa<LoadSDNode>(*I)) { // walk down chain
20191         for (auto I2 = (*I)->use_begin(), E2 = (*I)->use_end(); I2 != E2; ++I2)
20192           TryToAddCandidate(I2);
20193       }
20194       // Check stores that depend on the root (e.g. Store 3 in the chart above).
20195       if (I.getOperandNo() == 0 && isa<StoreSDNode>(*I)) {
20196         TryToAddCandidate(I);
20197       }
20198     }
20199   } else {
20200     for (auto I = RootNode->use_begin(), E = RootNode->use_end();
20201          I != E && NumNodesExplored < MaxSearchNodes; ++I, ++NumNodesExplored)
20202       TryToAddCandidate(I);
20203   }
20204 }
20205 
20206 // We need to check that merging these stores does not cause a loop in the
20207 // DAG. Any store candidate may depend on another candidate indirectly through
20208 // its operands. Check in parallel by searching up from operands of candidates.
20209 bool DAGCombiner::checkMergeStoreCandidatesForDependencies(
20210     SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
20211     SDNode *RootNode) {
20212   // FIXME: We should be able to truncate a full search of
20213   // predecessors by doing a BFS and keeping tabs on the originating
20214   // stores from which worklist nodes come, in a similar way to
20215   // TokenFactor simplification.
20216 
20217   SmallPtrSet<const SDNode *, 32> Visited;
20218   SmallVector<const SDNode *, 8> Worklist;
20219 
20220   // RootNode is a predecessor to all candidates so we need not search
20221   // past it. Add RootNode (peeking through TokenFactors). Do not count
20222   // these towards the size check.
20223 
20224   Worklist.push_back(RootNode);
20225   while (!Worklist.empty()) {
20226     auto N = Worklist.pop_back_val();
20227     if (!Visited.insert(N).second)
20228       continue; // Already present in Visited.
20229     if (N->getOpcode() == ISD::TokenFactor) {
20230       for (SDValue Op : N->ops())
20231         Worklist.push_back(Op.getNode());
20232     }
20233   }
20234 
20235   // Don't count pruning nodes towards max.
20236   unsigned int Max = 1024 + Visited.size();
20237   // Search Ops of store candidates.
20238   for (unsigned i = 0; i < NumStores; ++i) {
20239     SDNode *N = StoreNodes[i].MemNode;
20240     // Of the 4 Store Operands:
20241     //   * Chain (Op 0) -> We have already considered these
20242     //                     in candidate selection, but only by following the
20243     //                     chain dependencies. We could still have a chain
20244     //                     dependency to a load, that has a non-chain dep to
20245     //                     another load, that depends on a store, etc. So it is
20246     //                     possible to have dependencies that consist of a mix
20247     //                     of chain and non-chain deps, and we need to include
20248     //                     chain operands in the analysis here.
20249     //   * Value (Op 1) -> Cycles may happen (e.g. through load chains)
20250     //   * Address (Op 2) -> Merged addresses may only vary by a fixed constant,
20251     //                       but aren't necessarily from the same base node, so
20252     //                       cycles possible (e.g. via indexed store).
20253     //   * (Op 3) -> Represents the pre or post-indexing offset (or undef for
20254     //               non-indexed stores). Not constant on all targets (e.g. ARM)
20255     //               and so can participate in a cycle.
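    // Illustrative (hypothetical) cycle: a candidate store's value comes from
    // a load whose address is computed from another candidate's chain result;
    // merging both candidates into one node would make the merged store a
    // predecessor of itself.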
20256     for (unsigned j = 0; j < N->getNumOperands(); ++j)
20257       Worklist.push_back(N->getOperand(j).getNode());
20258   }
20259   // Search through DAG. We can stop early if we find a store node.
20260   for (unsigned i = 0; i < NumStores; ++i)
20261     if (SDNode::hasPredecessorHelper(StoreNodes[i].MemNode, Visited, Worklist,
20262                                      Max)) {
20263       // If the search bails out, record the StoreNode and RootNode in the
20264       // StoreRootCountMap. If we have seen the pair more times than the limit,
20265       // we won't add the StoreNode into the StoreNodes set again.
20266       if (Visited.size() >= Max) {
20267         auto &RootCount = StoreRootCountMap[StoreNodes[i].MemNode];
20268         if (RootCount.first == RootNode)
20269           RootCount.second++;
20270         else
20271           RootCount = {RootNode, 1};
20272       }
20273       return false;
20274     }
20275   return true;
20276 }
20277 
20278 unsigned
20279 DAGCombiner::getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
20280                                   int64_t ElementSizeBytes) const {
20281   while (true) {
20282     // Find a store past the width of the first store.
20283     size_t StartIdx = 0;
20284     while ((StartIdx + 1 < StoreNodes.size()) &&
20285            StoreNodes[StartIdx].OffsetFromBase + ElementSizeBytes !=
20286               StoreNodes[StartIdx + 1].OffsetFromBase)
20287       ++StartIdx;
20288 
20289     // Bail if we don't have enough candidates to merge.
20290     if (StartIdx + 1 >= StoreNodes.size())
20291       return 0;
20292 
20293     // Trim stores that overlapped with the first store.
20294     if (StartIdx)
20295       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + StartIdx);
20296 
20297     // Scan the memory operations on the chain and find the first
20298     // non-consecutive store memory address.
20299     unsigned NumConsecutiveStores = 1;
20300     int64_t StartAddress = StoreNodes[0].OffsetFromBase;
20301     // Check that the addresses are consecutive starting from the second
20302     // element in the list of stores.
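    // For example, with ElementSizeBytes == 4 and sorted offsets {0, 4, 8, 20},
    // the scan stops at offset 20 and NumConsecutiveStores becomes 3.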
20303     for (unsigned i = 1, e = StoreNodes.size(); i < e; ++i) {
20304       int64_t CurrAddress = StoreNodes[i].OffsetFromBase;
20305       if (CurrAddress - StartAddress != (ElementSizeBytes * i))
20306         break;
20307       NumConsecutiveStores = i + 1;
20308     }
20309     if (NumConsecutiveStores > 1)
20310       return NumConsecutiveStores;
20311 
20312     // There are no consecutive stores at the start of the list.
20313     // Remove the first store and try again.
20314     StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 1);
20315   }
20316 }
20317 
20318 bool DAGCombiner::tryStoreMergeOfConstants(
20319     SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores,
20320     EVT MemVT, SDNode *RootNode, bool AllowVectors) {
20321   LLVMContext &Context = *DAG.getContext();
20322   const DataLayout &DL = DAG.getDataLayout();
20323   int64_t ElementSizeBytes = MemVT.getStoreSize();
20324   unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
20325   bool MadeChange = false;
20326 
20327   // Store the constants into memory as one consecutive store.
20328   while (NumConsecutiveStores >= 2) {
20329     LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
20330     unsigned FirstStoreAS = FirstInChain->getAddressSpace();
20331     Align FirstStoreAlign = FirstInChain->getAlign();
20332     unsigned LastLegalType = 1;
20333     unsigned LastLegalVectorType = 1;
20334     bool LastIntegerTrunc = false;
20335     bool NonZero = false;
20336     unsigned FirstZeroAfterNonZero = NumConsecutiveStores;
20337     for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
20338       StoreSDNode *ST = cast<StoreSDNode>(StoreNodes[i].MemNode);
20339       SDValue StoredVal = ST->getValue();
20340       bool IsElementZero = false;
20341       if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(StoredVal))
20342         IsElementZero = C->isZero();
20343       else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(StoredVal))
20344         IsElementZero = C->getConstantFPValue()->isNullValue();
20345       else if (ISD::isBuildVectorAllZeros(StoredVal.getNode()))
20346         IsElementZero = true;
20347       if (IsElementZero) {
20348         if (NonZero && FirstZeroAfterNonZero == NumConsecutiveStores)
20349           FirstZeroAfterNonZero = i;
20350       }
20351       NonZero |= !IsElementZero;
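      // FirstZeroAfterNonZero is consulted by the pruning loop below so we do
      // not skip past the first zero element that follows a non-zero one;
      // dropping a non-zero value can change which merges are possible.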
20352 
20353       // Find a legal type for the constant store.
20354       unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
20355       EVT StoreTy = EVT::getIntegerVT(Context, SizeInBits);
20356       unsigned IsFast = 0;
20357 
20358       // Break early when size is too large to be legal.
20359       if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
20360         break;
20361 
20362       if (TLI.isTypeLegal(StoreTy) &&
20363           TLI.canMergeStoresTo(FirstStoreAS, StoreTy,
20364                                DAG.getMachineFunction()) &&
20365           TLI.allowsMemoryAccess(Context, DL, StoreTy,
20366                                  *FirstInChain->getMemOperand(), &IsFast) &&
20367           IsFast) {
20368         LastIntegerTrunc = false;
20369         LastLegalType = i + 1;
20370         // Or check whether a truncstore is legal.
20371       } else if (TLI.getTypeAction(Context, StoreTy) ==
20372                  TargetLowering::TypePromoteInteger) {
20373         EVT LegalizedStoredValTy =
20374             TLI.getTypeToTransformTo(Context, StoredVal.getValueType());
20375         if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) &&
20376             TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy,
20377                                  DAG.getMachineFunction()) &&
20378             TLI.allowsMemoryAccess(Context, DL, StoreTy,
20379                                    *FirstInChain->getMemOperand(), &IsFast) &&
20380             IsFast) {
20381           LastIntegerTrunc = true;
20382           LastLegalType = i + 1;
20383         }
20384       }
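      // Illustrative, target-dependent example: six merged i8 constant stores
      // give an i48 StoreTy; a target that promotes i48 to i64 can still
      // realize the merge as an i64 value stored with a truncating i48 store,
      // which is what LastIntegerTrunc records.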
20385 
20386       // We only use vectors if the target allows it and the function is not
20387       // marked with the noimplicitfloat attribute.
20388       if (TLI.storeOfVectorConstantIsCheap(!NonZero, MemVT, i + 1, FirstStoreAS) &&
20389           AllowVectors) {
20390         // Find a legal type for the vector store.
20391         unsigned Elts = (i + 1) * NumMemElts;
20392         EVT Ty = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
20393         if (TLI.isTypeLegal(Ty) && TLI.isTypeLegal(MemVT) &&
20394             TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG.getMachineFunction()) &&
20395             TLI.allowsMemoryAccess(Context, DL, Ty,
20396                                    *FirstInChain->getMemOperand(), &IsFast) &&
20397             IsFast)
20398           LastLegalVectorType = i + 1;
20399       }
20400     }
20401 
20402     bool UseVector = (LastLegalVectorType > LastLegalType) && AllowVectors;
20403     unsigned NumElem = (UseVector) ? LastLegalVectorType : LastLegalType;
20404     bool UseTrunc = LastIntegerTrunc && !UseVector;
20405 
20406     // Check if we found a legal type that creates a meaningful
20407     // merge.
20408     if (NumElem < 2) {
20409       // We know that candidate stores are in order and of correct
20410       // shape. While there is no mergeable sequence from the
20411       // beginning, one may start later in the sequence. The only
20412       // reason a merge of size N could have failed where another of
20413       // the same size would not have, is if the alignment has
20414       // improved or we've dropped a non-zero value. Drop as many
20415       // candidates as we can here.
20416       unsigned NumSkip = 1;
20417       while ((NumSkip < NumConsecutiveStores) &&
20418              (NumSkip < FirstZeroAfterNonZero) &&
20419              (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
20420         NumSkip++;
20421 
20422       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
20423       NumConsecutiveStores -= NumSkip;
20424       continue;
20425     }
20426 
20427     // Check that we can merge these candidates without causing a cycle.
20428     if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem,
20429                                                   RootNode)) {
20430       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
20431       NumConsecutiveStores -= NumElem;
20432       continue;
20433     }
20434 
20435     MadeChange |= mergeStoresOfConstantsOrVecElts(StoreNodes, MemVT, NumElem,
20436                                                   /*IsConstantSrc*/ true,
20437                                                   UseVector, UseTrunc);
20438 
20439     // Remove merged stores for next iteration.
20440     StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
20441     NumConsecutiveStores -= NumElem;
20442   }
20443   return MadeChange;
20444 }
20445 
20446 bool DAGCombiner::tryStoreMergeOfExtracts(
20447     SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores,
20448     EVT MemVT, SDNode *RootNode) {
20449   LLVMContext &Context = *DAG.getContext();
20450   const DataLayout &DL = DAG.getDataLayout();
20451   unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
20452   bool MadeChange = false;
20453 
20454   // Keep looping over the consecutive stores while merging succeeds.
20455   while (NumConsecutiveStores >= 2) {
20456     LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
20457     unsigned FirstStoreAS = FirstInChain->getAddressSpace();
20458     Align FirstStoreAlign = FirstInChain->getAlign();
20459     unsigned NumStoresToMerge = 1;
20460     for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
20461       // Find a legal type for the vector store.
20462       unsigned Elts = (i + 1) * NumMemElts;
20463       EVT Ty = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
20464       unsigned IsFast = 0;
20465 
20466       // Break early when size is too large to be legal.
20467       if (Ty.getSizeInBits() > MaximumLegalStoreInBits)
20468         break;
20469 
20470       if (TLI.isTypeLegal(Ty) &&
20471           TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG.getMachineFunction()) &&
20472           TLI.allowsMemoryAccess(Context, DL, Ty,
20473                                  *FirstInChain->getMemOperand(), &IsFast) &&
20474           IsFast)
20475         NumStoresToMerge = i + 1;
20476     }
20477 
20478     // Check if we found a legal vector type that creates a meaningful
20479     // merge.
20480     if (NumStoresToMerge < 2) {
20481       // We know that candidate stores are in order and of correct
20482       // shape. While there is no mergeable sequence from the
20483       // beginning, one may start later in the sequence. The only
20484       // reason a merge of size N could have failed where another of
20485       // the same size would not have, is if the alignment has
20486       // improved. Drop as many candidates as we can here.
20487       unsigned NumSkip = 1;
20488       while ((NumSkip < NumConsecutiveStores) &&
20489              (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
20490         NumSkip++;
20491 
20492       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
20493       NumConsecutiveStores -= NumSkip;
20494       continue;
20495     }
20496 
20497     // Check that we can merge these candidates without causing a cycle.
20498     if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumStoresToMerge,
20499                                                   RootNode)) {
20500       StoreNodes.erase(StoreNodes.begin(),
20501                        StoreNodes.begin() + NumStoresToMerge);
20502       NumConsecutiveStores -= NumStoresToMerge;
20503       continue;
20504     }
20505 
20506     MadeChange |= mergeStoresOfConstantsOrVecElts(
20507         StoreNodes, MemVT, NumStoresToMerge, /*IsConstantSrc*/ false,
20508         /*UseVector*/ true, /*UseTrunc*/ false);
20509 
20510     StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumStoresToMerge);
20511     NumConsecutiveStores -= NumStoresToMerge;
20512   }
20513   return MadeChange;
20514 }
20515 
20516 bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
20517                                        unsigned NumConsecutiveStores, EVT MemVT,
20518                                        SDNode *RootNode, bool AllowVectors,
20519                                        bool IsNonTemporalStore,
20520                                        bool IsNonTemporalLoad) {
20521   LLVMContext &Context = *DAG.getContext();
20522   const DataLayout &DL = DAG.getDataLayout();
20523   int64_t ElementSizeBytes = MemVT.getStoreSize();
20524   unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
20525   bool MadeChange = false;
20526 
20527   // Look for load nodes which are used by the stored values.
20528   SmallVector<MemOpLink, 8> LoadNodes;
20529 
20530   // Find acceptable loads. Loads must have the same chain (token factor),
20531   // must not be zext/volatile/indexed, and must be consecutive.
20532   BaseIndexOffset LdBasePtr;
20533 
20534   for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
20535     StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
20536     SDValue Val = peekThroughBitcasts(St->getValue());
20537     LoadSDNode *Ld = cast<LoadSDNode>(Val);
20538 
20539     BaseIndexOffset LdPtr = BaseIndexOffset::match(Ld, DAG);
20540     // If this is not the first ptr that we check.
20541     int64_t LdOffset = 0;
20542     if (LdBasePtr.getBase().getNode()) {
20543       // The base ptr must be the same.
20544       if (!LdBasePtr.equalBaseIndex(LdPtr, DAG, LdOffset))
20545         break;
20546     } else {
20547       // Remember the first base pointer; all later loads must match it.
20548       LdBasePtr = LdPtr;
20549     }
20550 
20551     // We found a potential memory operand to merge.
20552     LoadNodes.push_back(MemOpLink(Ld, LdOffset));
20553   }
20554 
20555   while (NumConsecutiveStores >= 2 && LoadNodes.size() >= 2) {
20556     Align RequiredAlignment;
20557     bool NeedRotate = false;
20558     if (LoadNodes.size() == 2) {
20559       // If we have load/store pair instructions and we only have two values,
20560       // don't bother merging.
20561       if (TLI.hasPairedLoad(MemVT, RequiredAlignment) &&
20562           StoreNodes[0].MemNode->getAlign() >= RequiredAlignment) {
20563         StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 2);
20564         LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + 2);
20565         break;
20566       }
20567       // If the loads are reversed, see if we can rotate the halves into place.
20568       int64_t Offset0 = LoadNodes[0].OffsetFromBase;
20569       int64_t Offset1 = LoadNodes[1].OffsetFromBase;
20570       EVT PairVT = EVT::getIntegerVT(Context, ElementSizeBytes * 8 * 2);
20571       if (Offset0 - Offset1 == ElementSizeBytes &&
20572           (hasOperation(ISD::ROTL, PairVT) ||
20573            hasOperation(ISD::ROTR, PairVT))) {
20574         std::swap(LoadNodes[0], LoadNodes[1]);
20575         NeedRotate = true;
20576       }
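      // Illustrative example (little-endian, hypothetical): stores to p and
      // p+4 fed by loads of q+4 and q, respectively. One wide load of q
      // followed by a rotate of half the width puts the two halves in the
      // order the stores expect.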
20577     }
20578     LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
20579     unsigned FirstStoreAS = FirstInChain->getAddressSpace();
20580     Align FirstStoreAlign = FirstInChain->getAlign();
20581     LoadSDNode *FirstLoad = cast<LoadSDNode>(LoadNodes[0].MemNode);
20582 
20583     // Scan the memory operations on the chain and find the first
20584     // non-consecutive load memory address. These variables hold the index in
20585     // the store node array.
20586 
20587     unsigned LastConsecutiveLoad = 1;
20588 
20589     // These variables refer to a size (element count), not an index in the array.
20590     unsigned LastLegalVectorType = 1;
20591     unsigned LastLegalIntegerType = 1;
20592     bool isDereferenceable = true;
20593     bool DoIntegerTruncate = false;
20594     int64_t StartAddress = LoadNodes[0].OffsetFromBase;
20595     SDValue LoadChain = FirstLoad->getChain();
20596     for (unsigned i = 1; i < LoadNodes.size(); ++i) {
20597       // All loads must share the same chain.
20598       if (LoadNodes[i].MemNode->getChain() != LoadChain)
20599         break;
20600 
20601       int64_t CurrAddress = LoadNodes[i].OffsetFromBase;
20602       if (CurrAddress - StartAddress != (ElementSizeBytes * i))
20603         break;
20604       LastConsecutiveLoad = i;
20605 
20606       if (isDereferenceable && !LoadNodes[i].MemNode->isDereferenceable())
20607         isDereferenceable = false;
20608 
20609       // Find a legal type for the vector store.
20610       unsigned Elts = (i + 1) * NumMemElts;
20611       EVT StoreTy = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
20612 
20613       // Break early when size is too large to be legal.
20614       if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
20615         break;
20616 
20617       unsigned IsFastSt = 0;
20618       unsigned IsFastLd = 0;
20619       // Don't try vector types if we need a rotate. We may still fail the
20620       // legality checks for the integer type, but we can't handle the rotate
20621       // case with vectors.
20622       // FIXME: We could use a shuffle in place of the rotate.
20623       if (!NeedRotate && TLI.isTypeLegal(StoreTy) &&
20624           TLI.canMergeStoresTo(FirstStoreAS, StoreTy,
20625                                DAG.getMachineFunction()) &&
20626           TLI.allowsMemoryAccess(Context, DL, StoreTy,
20627                                  *FirstInChain->getMemOperand(), &IsFastSt) &&
20628           IsFastSt &&
20629           TLI.allowsMemoryAccess(Context, DL, StoreTy,
20630                                  *FirstLoad->getMemOperand(), &IsFastLd) &&
20631           IsFastLd) {
20632         LastLegalVectorType = i + 1;
20633       }
20634 
20635       // Find a legal type for the integer store.
20636       unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
20637       StoreTy = EVT::getIntegerVT(Context, SizeInBits);
20638       if (TLI.isTypeLegal(StoreTy) &&
20639           TLI.canMergeStoresTo(FirstStoreAS, StoreTy,
20640                                DAG.getMachineFunction()) &&
20641           TLI.allowsMemoryAccess(Context, DL, StoreTy,
20642                                  *FirstInChain->getMemOperand(), &IsFastSt) &&
20643           IsFastSt &&
20644           TLI.allowsMemoryAccess(Context, DL, StoreTy,
20645                                  *FirstLoad->getMemOperand(), &IsFastLd) &&
20646           IsFastLd) {
20647         LastLegalIntegerType = i + 1;
20648         DoIntegerTruncate = false;
20649         // Or check whether a truncstore and extload is legal.
20650       } else if (TLI.getTypeAction(Context, StoreTy) ==
20651                  TargetLowering::TypePromoteInteger) {
20652         EVT LegalizedStoredValTy = TLI.getTypeToTransformTo(Context, StoreTy);
20653         if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) &&
20654             TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy,
20655                                  DAG.getMachineFunction()) &&
20656             TLI.isLoadExtLegal(ISD::ZEXTLOAD, LegalizedStoredValTy, StoreTy) &&
20657             TLI.isLoadExtLegal(ISD::SEXTLOAD, LegalizedStoredValTy, StoreTy) &&
20658             TLI.isLoadExtLegal(ISD::EXTLOAD, LegalizedStoredValTy, StoreTy) &&
20659             TLI.allowsMemoryAccess(Context, DL, StoreTy,
20660                                    *FirstInChain->getMemOperand(), &IsFastSt) &&
20661             IsFastSt &&
20662             TLI.allowsMemoryAccess(Context, DL, StoreTy,
20663                                    *FirstLoad->getMemOperand(), &IsFastLd) &&
20664             IsFastLd) {
20665           LastLegalIntegerType = i + 1;
20666           DoIntegerTruncate = true;
20667         }
20668       }
20669     }
20670 
20671     // Only use vector types if the vector type is larger than the integer
20672     // type. If they are the same, use integers.
20673     bool UseVectorTy =
20674         LastLegalVectorType > LastLegalIntegerType && AllowVectors;
20675     unsigned LastLegalType =
20676         std::max(LastLegalVectorType, LastLegalIntegerType);
20677 
20678     // We add +1 here because the LastXXX variables refer to an array index
20679     // while NumElem refers to an element count.
20680     unsigned NumElem = std::min(NumConsecutiveStores, LastConsecutiveLoad + 1);
20681     NumElem = std::min(LastLegalType, NumElem);
20682     Align FirstLoadAlign = FirstLoad->getAlign();
20683 
20684     if (NumElem < 2) {
20685       // We know that candidate stores are in order and of correct
20686       // shape. While there is no mergeable sequence from the
20687       // beginning, one may start later in the sequence. The only
20688       // reason a merge of size N could have failed where another of
20689       // the same size would not have is if the alignment of either
20690       // the load or store has improved. Drop as many candidates as we
20691       // can here.
20692       unsigned NumSkip = 1;
20693       while ((NumSkip < LoadNodes.size()) &&
20694              (LoadNodes[NumSkip].MemNode->getAlign() <= FirstLoadAlign) &&
20695              (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
20696         NumSkip++;
20697       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
20698       LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumSkip);
20699       NumConsecutiveStores -= NumSkip;
20700       continue;
20701     }
20702 
20703     // Check that we can merge these candidates without causing a cycle.
20704     if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem,
20705                                                   RootNode)) {
20706       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
20707       LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
20708       NumConsecutiveStores -= NumElem;
20709       continue;
20710     }
20711 
20712     // Find if it is better to use vectors or integers to load and store
20713     // to memory.
20714     EVT JointMemOpVT;
20715     if (UseVectorTy) {
20716       // Find a legal type for the vector store.
20717       unsigned Elts = NumElem * NumMemElts;
20718       JointMemOpVT = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
20719     } else {
20720       unsigned SizeInBits = NumElem * ElementSizeBytes * 8;
20721       JointMemOpVT = EVT::getIntegerVT(Context, SizeInBits);
20722     }
20723 
20724     SDLoc LoadDL(LoadNodes[0].MemNode);
20725     SDLoc StoreDL(StoreNodes[0].MemNode);
20726 
20727     // The merged loads are required to have the same incoming chain, so
20728     // using the first's chain is acceptable.
20729 
20730     SDValue NewStoreChain = getMergeStoreChains(StoreNodes, NumElem);
20731     bool CanReusePtrInfo = hasSameUnderlyingObj(StoreNodes);
20732     AddToWorklist(NewStoreChain.getNode());
20733 
20734     MachineMemOperand::Flags LdMMOFlags =
20735         isDereferenceable ? MachineMemOperand::MODereferenceable
20736                           : MachineMemOperand::MONone;
20737     if (IsNonTemporalLoad)
20738       LdMMOFlags |= MachineMemOperand::MONonTemporal;
20739 
20740     LdMMOFlags |= TLI.getTargetMMOFlags(*FirstLoad);
20741 
20742     MachineMemOperand::Flags StMMOFlags = IsNonTemporalStore
20743                                               ? MachineMemOperand::MONonTemporal
20744                                               : MachineMemOperand::MONone;
20745 
20746     StMMOFlags |= TLI.getTargetMMOFlags(*StoreNodes[0].MemNode);
20747 
20748     SDValue NewLoad, NewStore;
20749     if (UseVectorTy || !DoIntegerTruncate) {
20750       NewLoad = DAG.getLoad(
20751           JointMemOpVT, LoadDL, FirstLoad->getChain(), FirstLoad->getBasePtr(),
20752           FirstLoad->getPointerInfo(), FirstLoadAlign, LdMMOFlags);
20753       SDValue StoreOp = NewLoad;
20754       if (NeedRotate) {
20755         unsigned LoadWidth = ElementSizeBytes * 8 * 2;
20756         assert(JointMemOpVT == EVT::getIntegerVT(Context, LoadWidth) &&
20757                "Unexpected type for rotate-able load pair");
20758         SDValue RotAmt =
20759             DAG.getShiftAmountConstant(LoadWidth / 2, JointMemOpVT, LoadDL);
20760         // Target can convert to the identical ROTR if it does not have ROTL.
20761         StoreOp = DAG.getNode(ISD::ROTL, LoadDL, JointMemOpVT, NewLoad, RotAmt);
20762       }
20763       NewStore = DAG.getStore(
20764           NewStoreChain, StoreDL, StoreOp, FirstInChain->getBasePtr(),
20765           CanReusePtrInfo ? FirstInChain->getPointerInfo()
20766                           : MachinePointerInfo(FirstStoreAS),
20767           FirstStoreAlign, StMMOFlags);
20768     } else { // This must be the truncstore/extload case
20769       EVT ExtendedTy =
20770           TLI.getTypeToTransformTo(*DAG.getContext(), JointMemOpVT);
20771       NewLoad = DAG.getExtLoad(ISD::EXTLOAD, LoadDL, ExtendedTy,
20772                                FirstLoad->getChain(), FirstLoad->getBasePtr(),
20773                                FirstLoad->getPointerInfo(), JointMemOpVT,
20774                                FirstLoadAlign, LdMMOFlags);
20775       NewStore = DAG.getTruncStore(
20776           NewStoreChain, StoreDL, NewLoad, FirstInChain->getBasePtr(),
20777           CanReusePtrInfo ? FirstInChain->getPointerInfo()
20778                           : MachinePointerInfo(FirstStoreAS),
20779           JointMemOpVT, FirstInChain->getAlign(),
20780           FirstInChain->getMemOperand()->getFlags());
20781     }
20782 
20783     // Transfer chain users from old loads to the new load.
20784     for (unsigned i = 0; i < NumElem; ++i) {
20785       LoadSDNode *Ld = cast<LoadSDNode>(LoadNodes[i].MemNode);
20786       DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1),
20787                                     SDValue(NewLoad.getNode(), 1));
20788     }
20789 
20790     // Replace all stores with the new store. Recursively remove corresponding
20791     // values if they are no longer used.
20792     for (unsigned i = 0; i < NumElem; ++i) {
20793       SDValue Val = StoreNodes[i].MemNode->getOperand(1);
20794       CombineTo(StoreNodes[i].MemNode, NewStore);
20795       if (Val->use_empty())
20796         recursivelyDeleteUnusedNodes(Val.getNode());
20797     }
20798 
20799     MadeChange = true;
20800     StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
20801     LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
20802     NumConsecutiveStores -= NumElem;
20803   }
20804   return MadeChange;
20805 }
20806 
20807 bool DAGCombiner::mergeConsecutiveStores(StoreSDNode *St) {
20808   if (OptLevel == CodeGenOptLevel::None || !EnableStoreMerging)
20809     return false;
20810 
20811   // TODO: Extend this function to merge stores of scalable vectors.
20812   // (i.e. two <vscale x 8 x i8> stores can be merged to one <vscale x 16 x i8>
20813   // store since we know <vscale x 16 x i8> is exactly twice as large as
20814   // <vscale x 8 x i8>). Until then, bail out for scalable vectors.
20815   EVT MemVT = St->getMemoryVT();
20816   if (MemVT.isScalableVT())
20817     return false;
20818   if (!MemVT.isSimple() || MemVT.getSizeInBits() * 2 > MaximumLegalStoreInBits)
20819     return false;
20820 
20821   // This function cannot currently deal with non-byte-sized memory sizes.
20822   int64_t ElementSizeBytes = MemVT.getStoreSize();
20823   if (ElementSizeBytes * 8 != (int64_t)MemVT.getSizeInBits())
20824     return false;
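  // For example, an i19 store has a 3-byte store size, so 24 bits != 19 bits
  // and it is rejected here.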
20825 
20826   // Do not bother looking at stored values that are not constants, loads, or
20827   // extracted vector elements.
20828   SDValue StoredVal = peekThroughBitcasts(St->getValue());
20829   const StoreSource StoreSrc = getStoreSource(StoredVal);
20830   if (StoreSrc == StoreSource::Unknown)
20831     return false;
20832 
20833   SmallVector<MemOpLink, 8> StoreNodes;
20834   SDNode *RootNode;
20835   // Find potential store merge candidates by searching through chain sub-DAG
20836   getStoreMergeCandidates(St, StoreNodes, RootNode);
20837 
20838   // Check if there is anything to merge.
20839   if (StoreNodes.size() < 2)
20840     return false;
20841 
20842   // Sort the memory operands according to their distance from the
20843   // base pointer.
20844   llvm::sort(StoreNodes, [](MemOpLink LHS, MemOpLink RHS) {
20845     return LHS.OffsetFromBase < RHS.OffsetFromBase;
20846   });
20847 
20848   bool AllowVectors = !DAG.getMachineFunction().getFunction().hasFnAttribute(
20849       Attribute::NoImplicitFloat);
20850   bool IsNonTemporalStore = St->isNonTemporal();
20851   bool IsNonTemporalLoad = StoreSrc == StoreSource::Load &&
20852                            cast<LoadSDNode>(StoredVal)->isNonTemporal();
20853 
20854   // Store merging attempts to merge the lowest stores first. This generally
20855   // works out well; if successful, the remaining stores are checked
20856   // after the first collection of stores is merged. However, in the
20857   // case that a non-mergeable store is found first, e.g., {p[-2],
20858   // p[0], p[1], p[2], p[3]}, we would fail and miss the subsequent
20859   // mergeable cases. To prevent this, we prune such stores from the
20860   // front of StoreNodes here.
20861   bool MadeChange = false;
20862   while (StoreNodes.size() > 1) {
20863     unsigned NumConsecutiveStores =
20864         getConsecutiveStores(StoreNodes, ElementSizeBytes);
20865     // There are no more stores in the list to examine.
20866     if (NumConsecutiveStores == 0)
20867       return MadeChange;
20868 
20869     // We have at least 2 consecutive stores. Try to merge them.
20870     assert(NumConsecutiveStores >= 2 && "Expected at least 2 stores");
20871     switch (StoreSrc) {
20872     case StoreSource::Constant:
20873       MadeChange |= tryStoreMergeOfConstants(StoreNodes, NumConsecutiveStores,
20874                                              MemVT, RootNode, AllowVectors);
20875       break;
20876 
20877     case StoreSource::Extract:
20878       MadeChange |= tryStoreMergeOfExtracts(StoreNodes, NumConsecutiveStores,
20879                                             MemVT, RootNode);
20880       break;
20881 
20882     case StoreSource::Load:
20883       MadeChange |= tryStoreMergeOfLoads(StoreNodes, NumConsecutiveStores,
20884                                          MemVT, RootNode, AllowVectors,
20885                                          IsNonTemporalStore, IsNonTemporalLoad);
20886       break;
20887 
20888     default:
20889       llvm_unreachable("Unhandled store source type");
20890     }
20891   }
20892   return MadeChange;
20893 }
20894 
20895 SDValue DAGCombiner::replaceStoreChain(StoreSDNode *ST, SDValue BetterChain) {
20896   SDLoc SL(ST);
20897   SDValue ReplStore;
20898 
20899   // Replace the chain to avoid dependency.
20900   if (ST->isTruncatingStore()) {
20901     ReplStore = DAG.getTruncStore(BetterChain, SL, ST->getValue(),
20902                                   ST->getBasePtr(), ST->getMemoryVT(),
20903                                   ST->getMemOperand());
20904   } else {
20905     ReplStore = DAG.getStore(BetterChain, SL, ST->getValue(), ST->getBasePtr(),
20906                              ST->getMemOperand());
20907   }
20908 
20909   // Create token to keep both nodes around.
20910   SDValue Token = DAG.getNode(ISD::TokenFactor, SL,
20911                               MVT::Other, ST->getChain(), ReplStore);
20912 
20913   // Make sure the new and old chains are cleaned up.
20914   AddToWorklist(Token.getNode());
20915 
20916   // Don't add users to work list.
20917   return CombineTo(ST, Token, false);
20918 }
20919 
20920 SDValue DAGCombiner::replaceStoreOfFPConstant(StoreSDNode *ST) {
20921   SDValue Value = ST->getValue();
20922   if (Value.getOpcode() == ISD::TargetConstantFP)
20923     return SDValue();
20924 
20925   if (!ISD::isNormalStore(ST))
20926     return SDValue();
20927 
20928   SDLoc DL(ST);
20929 
20930   SDValue Chain = ST->getChain();
20931   SDValue Ptr = ST->getBasePtr();
20932 
20933   const ConstantFPSDNode *CFP = cast<ConstantFPSDNode>(Value);
20934 
20935   // NOTE: If the original store is volatile, this transform must not increase
20936   // the number of stores.  For example, on x86-32 an f64 can be stored in one
20937   // processor operation but an i64 (which is not legal) requires two.  So the
20938   // transform should not be done in this case.
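  // Illustrative example: 'store float 1.0' can become 'store i32 0x3F800000',
  // the bit pattern of 1.0f; see the MVT::f32 case below.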
20939 
20940   SDValue Tmp;
20941   switch (CFP->getSimpleValueType(0).SimpleTy) {
20942   default:
20943     llvm_unreachable("Unknown FP type");
20944   case MVT::f16:    // We don't do this for these yet.
20945   case MVT::bf16:
20946   case MVT::f80:
20947   case MVT::f128:
20948   case MVT::ppcf128:
20949     return SDValue();
20950   case MVT::f32:
20951     if ((isTypeLegal(MVT::i32) && !LegalOperations && ST->isSimple()) ||
20952         TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) {
20953       Tmp = DAG.getConstant((uint32_t)CFP->getValueAPF().
20954                             bitcastToAPInt().getZExtValue(), SDLoc(CFP),
20955                             MVT::i32);
20956       return DAG.getStore(Chain, DL, Tmp, Ptr, ST->getMemOperand());
20957     }
20958 
20959     return SDValue();
20960   case MVT::f64:
20961     if ((TLI.isTypeLegal(MVT::i64) && !LegalOperations &&
20962          ST->isSimple()) ||
20963         TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i64)) {
20964       Tmp = DAG.getConstant(CFP->getValueAPF().bitcastToAPInt().
20965                             getZExtValue(), SDLoc(CFP), MVT::i64);
20966       return DAG.getStore(Chain, DL, Tmp,
20967                           Ptr, ST->getMemOperand());
20968     }
20969 
20970     if (ST->isSimple() && TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32) &&
20971         !TLI.isFPImmLegal(CFP->getValueAPF(), MVT::f64)) {
20972       // Many FP stores are not made apparent until after legalize, e.g. for
20973       // argument passing.  Since this is so common, custom legalize the
20974       // 64-bit integer store into two 32-bit stores.
20975       uint64_t Val = CFP->getValueAPF().bitcastToAPInt().getZExtValue();
20976       SDValue Lo = DAG.getConstant(Val & 0xFFFFFFFF, SDLoc(CFP), MVT::i32);
20977       SDValue Hi = DAG.getConstant(Val >> 32, SDLoc(CFP), MVT::i32);
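      // For example, for 1.0 (bit pattern 0x3FF0000000000000):
      // Lo = 0x00000000 and Hi = 0x3FF00000.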
20978       if (DAG.getDataLayout().isBigEndian())
20979         std::swap(Lo, Hi);
20980 
20981       MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
20982       AAMDNodes AAInfo = ST->getAAInfo();
20983 
20984       SDValue St0 = DAG.getStore(Chain, DL, Lo, Ptr, ST->getPointerInfo(),
20985                                  ST->getOriginalAlign(), MMOFlags, AAInfo);
20986       Ptr = DAG.getMemBasePlusOffset(Ptr, TypeSize::getFixed(4), DL);
20987       SDValue St1 = DAG.getStore(Chain, DL, Hi, Ptr,
20988                                  ST->getPointerInfo().getWithOffset(4),
20989                                  ST->getOriginalAlign(), MMOFlags, AAInfo);
20990       return DAG.getNode(ISD::TokenFactor, DL, MVT::Other,
20991                          St0, St1);
20992     }
20993 
20994     return SDValue();
20995   }
20996 }
20997 
20998 // (store (insert_vector_elt (load p), x, i), p) -> (store x, p+offset)
20999 //
21000 // If a store of a load with an element inserted into it has no other
21001 // uses in between the chain, then we can consider the vector store
21002 // dead and replace it with just the single scalar element store.
21003 SDValue DAGCombiner::replaceStoreOfInsertLoad(StoreSDNode *ST) {
21004   SDLoc DL(ST);
21005   SDValue Value = ST->getValue();
21006   SDValue Ptr = ST->getBasePtr();
21007   SDValue Chain = ST->getChain();
21008   if (Value.getOpcode() != ISD::INSERT_VECTOR_ELT || !Value.hasOneUse())
21009     return SDValue();
21010 
21011   SDValue Elt = Value.getOperand(1);
21012   SDValue Idx = Value.getOperand(2);
21013 
21014   // If the element isn't byte sized or is implicitly truncated then we can't
21015   // compute an offset.
21016   EVT EltVT = Elt.getValueType();
21017   if (!EltVT.isByteSized() ||
21018       EltVT != Value.getOperand(0).getValueType().getVectorElementType())
21019     return SDValue();
21020 
21021   auto *Ld = dyn_cast<LoadSDNode>(Value.getOperand(0));
21022   if (!Ld || Ld->getBasePtr() != Ptr ||
21023       ST->getMemoryVT() != Ld->getMemoryVT() || !ST->isSimple() ||
21024       !ISD::isNormalStore(ST) ||
21025       Ld->getAddressSpace() != ST->getAddressSpace() ||
21026       !Chain.reachesChainWithoutSideEffects(SDValue(Ld, 1)))
21027     return SDValue();
21028 
21029   unsigned IsFast;
21030   if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(),
21031                               Elt.getValueType(), ST->getAddressSpace(),
21032                               ST->getAlign(), ST->getMemOperand()->getFlags(),
21033                               &IsFast) ||
21034       !IsFast)
21035     return SDValue();
21036 
21037   MachinePointerInfo PointerInfo(ST->getAddressSpace());
21038 
21039   // If the offset is a known constant then try to recover the pointer
21040   // info.
21041   SDValue NewPtr;
21042   if (auto *CIdx = dyn_cast<ConstantSDNode>(Idx)) {
21043     unsigned COffset = CIdx->getSExtValue() * EltVT.getSizeInBits() / 8;
21044     NewPtr = DAG.getMemBasePlusOffset(Ptr, TypeSize::getFixed(COffset), DL);
21045     PointerInfo = ST->getPointerInfo().getWithOffset(COffset);
21046   } else {
21047     NewPtr = TLI.getVectorElementPointer(DAG, Ptr, Value.getValueType(), Idx);
21048   }
21049 
21050   return DAG.getStore(Chain, DL, Elt, NewPtr, PointerInfo, ST->getAlign(),
21051                       ST->getMemOperand()->getFlags());
21052 }
21053 
21054 SDValue DAGCombiner::visitSTORE(SDNode *N) {
21055   StoreSDNode *ST  = cast<StoreSDNode>(N);
21056   SDValue Chain = ST->getChain();
21057   SDValue Value = ST->getValue();
21058   SDValue Ptr   = ST->getBasePtr();
21059 
21060   // If this is a store of a bit convert, store the input value if the
21061   // resultant store does not need a higher alignment than the original.
21062   if (Value.getOpcode() == ISD::BITCAST && !ST->isTruncatingStore() &&
21063       ST->isUnindexed()) {
21064     EVT SVT = Value.getOperand(0).getValueType();
21065     // If the store is volatile, we only want to change the store type if the
21066     // resulting store is legal. Otherwise we might increase the number of
21067     // memory accesses. We don't care if the original type was legal or not
21068     // as we assume software couldn't rely on the number of accesses of an
21069     // illegal type.
21070     // TODO: May be able to relax for unordered atomics (see D66309)
21071     if (((!LegalOperations && ST->isSimple()) ||
21072          TLI.isOperationLegal(ISD::STORE, SVT)) &&
21073         TLI.isStoreBitCastBeneficial(Value.getValueType(), SVT,
21074                                      DAG, *ST->getMemOperand())) {
21075       return DAG.getStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
21076                           ST->getMemOperand());
21077     }
21078   }
21079 
21080   // Turn 'store undef, Ptr' -> nothing.
21081   if (Value.isUndef() && ST->isUnindexed())
21082     return Chain;
21083 
21084   // Try to infer better alignment information than the store already has.
21085   if (OptLevel != CodeGenOptLevel::None && ST->isUnindexed() &&
21086       !ST->isAtomic()) {
21087     if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) {
21088       if (*Alignment > ST->getAlign() &&
21089           isAligned(*Alignment, ST->getSrcValueOffset())) {
21090         SDValue NewStore =
21091             DAG.getTruncStore(Chain, SDLoc(N), Value, Ptr, ST->getPointerInfo(),
21092                               ST->getMemoryVT(), *Alignment,
21093                               ST->getMemOperand()->getFlags(), ST->getAAInfo());
21094         // NewStore will always be N as we are only refining the alignment
21095         assert(NewStore.getNode() == N);
21096         (void)NewStore;
21097       }
21098     }
21099   }
21100 
21101   // Try transforming a pair of floating point load / store ops to integer
21102   // load / store ops.
21103   if (SDValue NewST = TransformFPLoadStorePair(N))
21104     return NewST;
21105 
21106   // Try transforming several stores into STORE (BSWAP).
21107   if (SDValue Store = mergeTruncStores(ST))
21108     return Store;
21109 
21110   if (ST->isUnindexed()) {
21111     // Walk up chain skipping non-aliasing memory nodes, on this store and any
21112     // adjacent stores.
21113     if (findBetterNeighborChains(ST)) {
21114       // replaceStoreChain uses CombineTo, which handled all of the worklist
21115       // manipulation. Return the original node to not do anything else.
21116       return SDValue(ST, 0);
21117     }
21118     Chain = ST->getChain();
21119   }
21120 
21121   // FIXME: is there such a thing as a truncating indexed store?
21122   if (ST->isTruncatingStore() && ST->isUnindexed() &&
21123       Value.getValueType().isInteger() &&
21124       (!isa<ConstantSDNode>(Value) ||
21125        !cast<ConstantSDNode>(Value)->isOpaque())) {
21126     // Convert a truncating store of an extension into a standard store.
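    // For example: (truncstore (zext x:i16 to i32), i16) -> (store x:i16).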
21127     if ((Value.getOpcode() == ISD::ZERO_EXTEND ||
21128          Value.getOpcode() == ISD::SIGN_EXTEND ||
21129          Value.getOpcode() == ISD::ANY_EXTEND) &&
21130         Value.getOperand(0).getValueType() == ST->getMemoryVT() &&
21131         TLI.isOperationLegalOrCustom(ISD::STORE, ST->getMemoryVT()))
21132       return DAG.getStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
21133                           ST->getMemOperand());
21134 
21135     APInt TruncDemandedBits =
21136         APInt::getLowBitsSet(Value.getScalarValueSizeInBits(),
21137                              ST->getMemoryVT().getScalarSizeInBits());
21138 
21139     // See if we can simplify the operation with SimplifyDemandedBits, which
21140     // only works if the value has a single use.
21141     AddToWorklist(Value.getNode());
21142     if (SimplifyDemandedBits(Value, TruncDemandedBits)) {
21143       // Re-visit the store if anything changed and the store hasn't been merged
21144       // with another node (N is deleted). SimplifyDemandedBits will add Value's
21145       // node back to the worklist if necessary, but we also need to re-visit
21146       // the Store node itself.
21147       if (N->getOpcode() != ISD::DELETED_NODE)
21148         AddToWorklist(N);
21149       return SDValue(N, 0);
21150     }
21151 
21152     // Otherwise, see if we can simplify the input to this truncstore with
21153     // knowledge that only the low bits are being used.  For example:
21154     // "truncstore (or (shl x, 8), y), i8"  -> "truncstore y, i8"
21155     if (SDValue Shorter =
21156             TLI.SimplifyMultipleUseDemandedBits(Value, TruncDemandedBits, DAG))
21157       return DAG.getTruncStore(Chain, SDLoc(N), Shorter, Ptr, ST->getMemoryVT(),
21158                                ST->getMemOperand());
21159 
21160     // If we're storing a truncated constant, see if we can simplify it.
21161     // TODO: Move this to targetShrinkDemandedConstant?
21162     if (auto *Cst = dyn_cast<ConstantSDNode>(Value))
21163       if (!Cst->isOpaque()) {
21164         const APInt &CValue = Cst->getAPIntValue();
21165         APInt NewVal = CValue & TruncDemandedBits;
21166         if (NewVal != CValue) {
21167           SDValue Shorter =
21168               DAG.getConstant(NewVal, SDLoc(N), Value.getValueType());
21169           return DAG.getTruncStore(Chain, SDLoc(N), Shorter, Ptr,
21170                                    ST->getMemoryVT(), ST->getMemOperand());
21171         }
21172       }
21173   }
21174 
21175   // If this is a load followed by a store to the same location, then the store
21176   // is dead/noop. Peek through any truncates if canCombineTruncStore failed.
21177   // TODO: Add big-endian truncate support with test coverage.
21178   // TODO: Can relax for unordered atomics (see D66309)
21179   SDValue TruncVal = DAG.getDataLayout().isLittleEndian()
21180                          ? peekThroughTruncates(Value)
21181                          : Value;
21182   if (auto *Ld = dyn_cast<LoadSDNode>(TruncVal)) {
21183     if (Ld->getBasePtr() == Ptr && ST->getMemoryVT() == Ld->getMemoryVT() &&
21184         ST->isUnindexed() && ST->isSimple() &&
21185         Ld->getAddressSpace() == ST->getAddressSpace() &&
21186         // There can't be any side effects between the load and store, such as
21187         // a call or store.
21188         Chain.reachesChainWithoutSideEffects(SDValue(Ld, 1))) {
21189       // The store is dead, remove it.
21190       return Chain;
21191     }
21192   }
21193 
21194   // Try scalarizing vector stores of loads where we only change one element
21195   if (SDValue NewST = replaceStoreOfInsertLoad(ST))
21196     return NewST;
21197 
21198   // TODO: Can relax for unordered atomics (see D66309)
21199   if (StoreSDNode *ST1 = dyn_cast<StoreSDNode>(Chain)) {
21200     if (ST->isUnindexed() && ST->isSimple() &&
21201         ST1->isUnindexed() && ST1->isSimple()) {
21202       if (OptLevel != CodeGenOptLevel::None && ST1->getBasePtr() == Ptr &&
21203           ST1->getValue() == Value && ST->getMemoryVT() == ST1->getMemoryVT() &&
21204           ST->getAddressSpace() == ST1->getAddressSpace()) {
21205         // If this is a store followed by a store with the same value to the
21206         // same location, then the store is dead/noop.
21207         return Chain;
21208       }
21209 
21210       if (OptLevel != CodeGenOptLevel::None && ST1->hasOneUse() &&
21211           !ST1->getBasePtr().isUndef() &&
21212           ST->getAddressSpace() == ST1->getAddressSpace()) {
21213         // If either store has a scalable vector type, we cannot compare
21214         // their sizes as plain compile-time integers, because a scalable
21215         // store's actual size is unknown until runtime. Only remove the
21216         // earlier store when its size is known to be <= the later one's.
21217         if (ST->getMemoryVT().isScalableVector() ||
21218             ST1->getMemoryVT().isScalableVector()) {
21219           if (ST1->getBasePtr() == Ptr &&
21220               TypeSize::isKnownLE(ST1->getMemoryVT().getStoreSize(),
21221                                   ST->getMemoryVT().getStoreSize())) {
21222             CombineTo(ST1, ST1->getChain());
21223             return SDValue(N, 0);
21224           }
21225         } else {
21226           const BaseIndexOffset STBase = BaseIndexOffset::match(ST, DAG);
21227           const BaseIndexOffset ChainBase = BaseIndexOffset::match(ST1, DAG);
21228           // If the preceding store writes to a subset of the current store's
21229           // location and no other node is chained to that store, we can
21230           // effectively drop the preceding store. Do not remove stores to
21231           // undef as they may be used as data sinks.
21232           if (STBase.contains(DAG, ST->getMemoryVT().getFixedSizeInBits(),
21233                               ChainBase,
21234                               ST1->getMemoryVT().getFixedSizeInBits())) {
21235             CombineTo(ST1, ST1->getChain());
21236             return SDValue(N, 0);
21237           }
21238         }
21239       }
21240     }
21241   }
21242 
21243   // If this is an FP_ROUND or TRUNC followed by a store, fold this into a
21244   // truncating store.  We can do this even if this is already a truncstore.
21245   if ((Value.getOpcode() == ISD::FP_ROUND ||
21246        Value.getOpcode() == ISD::TRUNCATE) &&
21247       Value->hasOneUse() && ST->isUnindexed() &&
21248       TLI.canCombineTruncStore(Value.getOperand(0).getValueType(),
21249                                ST->getMemoryVT(), LegalOperations)) {
21250     return DAG.getTruncStore(Chain, SDLoc(N), Value.getOperand(0),
21251                              Ptr, ST->getMemoryVT(), ST->getMemOperand());
21252   }
21253 
21254   // Always perform this optimization before types are legal. If the target
21255   // prefers, also try this after legalization to catch stores that were created
21256   // by intrinsics or other nodes.
21257   if (!LegalTypes || (TLI.mergeStoresAfterLegalization(ST->getMemoryVT()))) {
21258     while (true) {
21259       // There can be multiple store sequences on the same chain.
21260       // Keep trying to merge store sequences until we are unable to do so
21261       // or until we merge the last store on the chain.
21262       bool Changed = mergeConsecutiveStores(ST);
21263       if (!Changed) break;
21264       // Return N; mergeConsecutiveStores only uses CombineTo, so no
21265       // worklist cleanup is necessary.
21266       if (N->getOpcode() == ISD::DELETED_NODE || !isa<StoreSDNode>(N))
21267         return SDValue(N, 0);
21268     }
21269   }
21270 
21271   // Try transforming N to an indexed store.
21272   if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
21273     return SDValue(N, 0);
21274 
21275   // Turn 'store float 1.0, Ptr' -> 'store int 0x12345678, Ptr'
21276   //
21277   // Make sure to do this only after attempting to merge stores in order to
21278   //  avoid changing the types of some subset of stores due to visit order,
21279   //  preventing their merging.
21280   if (isa<ConstantFPSDNode>(ST->getValue())) {
21281     if (SDValue NewSt = replaceStoreOfFPConstant(ST))
21282       return NewSt;
21283   }
21284 
21285   if (SDValue NewSt = splitMergedValStore(ST))
21286     return NewSt;
21287 
21288   return ReduceLoadOpStoreWidth(N);
21289 }
21290 
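// Remove a store whose stored-to object's lifetime ends immediately after it
// on the chain. An illustrative sketch (hedged, not from the source):
//   st: ch = store t0, %val, %slot     ; lies entirely within %slot's bounds
//   le: ch = lifetime_end st, %slot    ; %val can never be read -> drop st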
21291 SDValue DAGCombiner::visitLIFETIME_END(SDNode *N) {
21292   const auto *LifetimeEnd = cast<LifetimeSDNode>(N);
21293   if (!LifetimeEnd->hasOffset())
21294     return SDValue();
21295 
21296   const BaseIndexOffset LifetimeEndBase(N->getOperand(1), SDValue(),
21297                                         LifetimeEnd->getOffset(), false);
21298 
21299   // We walk up the chains to find stores.
21300   SmallVector<SDValue, 8> Chains = {N->getOperand(0)};
21301   while (!Chains.empty()) {
21302     SDValue Chain = Chains.pop_back_val();
21303     if (!Chain.hasOneUse())
21304       continue;
21305     switch (Chain.getOpcode()) {
21306     case ISD::TokenFactor:
21307       for (unsigned Nops = Chain.getNumOperands(); Nops;)
21308         Chains.push_back(Chain.getOperand(--Nops));
21309       break;
21310     case ISD::LIFETIME_START:
21311     case ISD::LIFETIME_END:
21312       // We can forward past any lifetime start/end that can be proven not to
21313       // alias the node.
21314       if (!mayAlias(Chain.getNode(), N))
21315         Chains.push_back(Chain.getOperand(0));
21316       break;
21317     case ISD::STORE: {
21318       StoreSDNode *ST = cast<StoreSDNode>(Chain);
21319       // TODO: Can relax for unordered atomics (see D66309)
21320       if (!ST->isSimple() || ST->isIndexed())
21321         continue;
21322       const TypeSize StoreSize = ST->getMemoryVT().getStoreSize();
21323       // The bounds of a scalable store are not known until runtime, so this
21324       // store cannot be elided.
21325       if (StoreSize.isScalable())
21326         continue;
21327       const BaseIndexOffset StoreBase = BaseIndexOffset::match(ST, DAG);
21328       // If we store purely within object bounds just before its lifetime ends,
21329       // we can remove the store.
21330       if (LifetimeEndBase.contains(DAG, LifetimeEnd->getSize() * 8, StoreBase,
21331                                    StoreSize.getFixedValue() * 8)) {
21332         LLVM_DEBUG(dbgs() << "\nRemoving store:"; StoreBase.dump();
21333                    dbgs() << "\nwithin LIFETIME_END of : ";
21334                    LifetimeEndBase.dump(); dbgs() << "\n");
21335         CombineTo(ST, ST->getChain());
21336         return SDValue(N, 0);
21337       }
21338     }
21339     }
21340   }
21341   return SDValue();
21342 }
21343 
21344 /// For the store instruction sequence below, the F and I values
21345 /// are bundled together as an i64 value before being stored into memory.
21346 /// Sometimes it is more efficient to generate separate stores for F and I,
21347 /// which can remove the bitwise instructions or sink them to colder places.
21348 ///
21349 ///   (store (or (zext (bitcast F to i32) to i64),
21350 ///              (shl (zext I to i64), 32)), addr)  -->
21351 ///   (store F, addr) and (store I, addr+4)
21352 ///
21353 /// Similarly, splitting for other merged store can also be beneficial, like:
21354 /// For pair of {i32, i32}, i64 store --> two i32 stores.
21355 /// For pair of {i32, i16}, i64 store --> two i32 stores.
21356 /// For pair of {i16, i16}, i32 store --> two i16 stores.
21357 /// For pair of {i16, i8},  i32 store --> two i16 stores.
21358 /// For pair of {i8, i8},   i16 store --> two i8 stores.
21359 ///
21360 /// We allow each target to determine specifically which kind of splitting is
21361 /// supported.
21362 ///
21363 /// The store patterns are commonly seen in code like the snippet below when
21364 /// only std::make_pair(...) is SROA-transformed before being inlined into hoo.
21365 ///   void goo(const std::pair<int, float> &);
21366 ///   hoo() {
21367 ///     ...
21368 ///     goo(std::make_pair(tmp, ftmp));
21369 ///     ...
21370 ///   }
21371 ///
21372 SDValue DAGCombiner::splitMergedValStore(StoreSDNode *ST) {
21373   if (OptLevel == CodeGenOptLevel::None)
21374     return SDValue();
21375 
21376   // Can't change the number of memory accesses for a volatile store or break
21377   // atomicity for an atomic one.
21378   if (!ST->isSimple())
21379     return SDValue();
21380 
21381   SDValue Val = ST->getValue();
21382   SDLoc DL(ST);
21383 
21384   // Match OR operand.
21385   if (!Val.getValueType().isScalarInteger() || Val.getOpcode() != ISD::OR)
21386     return SDValue();
21387 
21388   // Match SHL operand and get Lower and Higher parts of Val.
21389   SDValue Op1 = Val.getOperand(0);
21390   SDValue Op2 = Val.getOperand(1);
21391   SDValue Lo, Hi;
21392   if (Op1.getOpcode() != ISD::SHL) {
21393     std::swap(Op1, Op2);
21394     if (Op1.getOpcode() != ISD::SHL)
21395       return SDValue();
21396   }
21397   Lo = Op2;
21398   Hi = Op1.getOperand(0);
21399   if (!Op1.hasOneUse())
21400     return SDValue();
21401 
21402   // Match shift amount to HalfValBitSize.
21403   unsigned HalfValBitSize = Val.getValueSizeInBits() / 2;
21404   ConstantSDNode *ShAmt = dyn_cast<ConstantSDNode>(Op1.getOperand(1));
21405   if (!ShAmt || ShAmt->getAPIntValue() != HalfValBitSize)
21406     return SDValue();
21407 
21408   // Lo and Hi are zero-extended from integers whose sizes are at most
21409   // HalfValBitSize bits.
21410   if (Lo.getOpcode() != ISD::ZERO_EXTEND || !Lo.hasOneUse() ||
21411       !Lo.getOperand(0).getValueType().isScalarInteger() ||
21412       Lo.getOperand(0).getValueSizeInBits() > HalfValBitSize ||
21413       Hi.getOpcode() != ISD::ZERO_EXTEND || !Hi.hasOneUse() ||
21414       !Hi.getOperand(0).getValueType().isScalarInteger() ||
21415       Hi.getOperand(0).getValueSizeInBits() > HalfValBitSize)
21416     return SDValue();
21417 
21418   // Use the EVT of low and high parts before bitcast as the input
21419   // of target query.
21420   EVT LowTy = (Lo.getOperand(0).getOpcode() == ISD::BITCAST)
21421                   ? Lo.getOperand(0).getValueType()
21422                   : Lo.getValueType();
21423   EVT HighTy = (Hi.getOperand(0).getOpcode() == ISD::BITCAST)
21424                    ? Hi.getOperand(0).getValueType()
21425                    : Hi.getValueType();
21426   if (!TLI.isMultiStoresCheaperThanBitsMerge(LowTy, HighTy))
21427     return SDValue();
21428 
21429   // Start to split store.
21430   MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
21431   AAMDNodes AAInfo = ST->getAAInfo();
21432 
21433   // Change the sizes of Lo and Hi's value types to HalfValBitSize.
21434   EVT VT = EVT::getIntegerVT(*DAG.getContext(), HalfValBitSize);
21435   Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Lo.getOperand(0));
21436   Hi = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Hi.getOperand(0));
21437 
21438   SDValue Chain = ST->getChain();
21439   SDValue Ptr = ST->getBasePtr();
21440   // Lower value store.
21441   SDValue St0 = DAG.getStore(Chain, DL, Lo, Ptr, ST->getPointerInfo(),
21442                              ST->getOriginalAlign(), MMOFlags, AAInfo);
21443   Ptr =
21444       DAG.getMemBasePlusOffset(Ptr, TypeSize::getFixed(HalfValBitSize / 8), DL);
21445   // Higher value store.
21446   SDValue St1 = DAG.getStore(
21447       St0, DL, Hi, Ptr, ST->getPointerInfo().getWithOffset(HalfValBitSize / 8),
21448       ST->getOriginalAlign(), MMOFlags, AAInfo);
21449   return St1;
21450 }
21451 
21452 // Merge an insertion into an existing shuffle:
21453 // (insert_vector_elt (vector_shuffle X, Y, Mask),
21454 //                    (extract_vector_elt X, N), InsIndex)
21455 //   --> (vector_shuffle X, Y, NewMask)
21456 //  and variations where shuffle operands may be CONCAT_VECTORS.
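// A concrete illustration (hedged, not from the source): with 4-element
// vectors X and Y,
//   (insert_vector_elt (vector_shuffle X, Y, <0,5,2,7>),
//                      (extract_vector_elt X, 1), 2)
//     --> (vector_shuffle X, Y, <0,5,1,7>)   ; lane 2 now reads X[1]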
21457 static bool mergeEltWithShuffle(SDValue &X, SDValue &Y, ArrayRef<int> Mask,
21458                                 SmallVectorImpl<int> &NewMask, SDValue Elt,
21459                                 unsigned InsIndex) {
21460   if (Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
21461       !isa<ConstantSDNode>(Elt.getOperand(1)))
21462     return false;
21463 
21464   // Vec's operand 0 is using indices from 0 to N-1 and
21465   // operand 1 from N to 2N - 1, where N is the number of
21466   // elements in the vectors.
21467   SDValue InsertVal0 = Elt.getOperand(0);
21468   int ElementOffset = -1;
21469 
21470   // We explore the inputs of the shuffle in order to see if we find the
21471   // source of the extract_vector_elt. If so, we can use it to modify the
21472   // shuffle rather than perform an insert_vector_elt.
21473   SmallVector<std::pair<int, SDValue>, 8> ArgWorkList;
21474   ArgWorkList.emplace_back(Mask.size(), Y);
21475   ArgWorkList.emplace_back(0, X);
21476 
21477   while (!ArgWorkList.empty()) {
21478     int ArgOffset;
21479     SDValue ArgVal;
21480     std::tie(ArgOffset, ArgVal) = ArgWorkList.pop_back_val();
21481 
21482     if (ArgVal == InsertVal0) {
21483       ElementOffset = ArgOffset;
21484       break;
21485     }
21486 
21487     // Peek through concat_vector.
21488     if (ArgVal.getOpcode() == ISD::CONCAT_VECTORS) {
21489       int CurrentArgOffset =
21490           ArgOffset + ArgVal.getValueType().getVectorNumElements();
21491       int Step = ArgVal.getOperand(0).getValueType().getVectorNumElements();
21492       for (SDValue Op : reverse(ArgVal->ops())) {
21493         CurrentArgOffset -= Step;
21494         ArgWorkList.emplace_back(CurrentArgOffset, Op);
21495       }
21496 
21497       // Make sure we went through all the elements and did not screw up index
21498       // computation.
21499       assert(CurrentArgOffset == ArgOffset);
21500     }
21501   }
21502 
21503   // If we failed to find a match, see if we can replace an UNDEF shuffle
21504   // operand.
21505   if (ElementOffset == -1) {
21506     if (!Y.isUndef() || InsertVal0.getValueType() != Y.getValueType())
21507       return false;
21508     ElementOffset = Mask.size();
21509     Y = InsertVal0;
21510   }
21511 
21512   NewMask.assign(Mask.begin(), Mask.end());
21513   NewMask[InsIndex] = ElementOffset + Elt.getConstantOperandVal(1);
21514   assert(NewMask[InsIndex] < (int)(2 * Mask.size()) && NewMask[InsIndex] >= 0 &&
21515          "NewMask[InsIndex] is out of bound");
21516   return true;
21517 }
21518 
21519 // Merge an insertion into an existing shuffle:
21520 // (insert_vector_elt (vector_shuffle X, Y), (extract_vector_elt X, N),
21521 // InsIndex)
21522 //   --> (vector_shuffle X, Y) and variations where shuffle operands may be
21523 //   CONCAT_VECTORS.
21524 SDValue DAGCombiner::mergeInsertEltWithShuffle(SDNode *N, unsigned InsIndex) {
21525   assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT &&
21526          "Expected insert_vector_elt");
21527   SDValue InsertVal = N->getOperand(1);
21528   SDValue Vec = N->getOperand(0);
21529 
21530   auto *SVN = dyn_cast<ShuffleVectorSDNode>(Vec);
21531   if (!SVN || !Vec.hasOneUse())
21532     return SDValue();
21533 
21534   ArrayRef<int> Mask = SVN->getMask();
21535   SDValue X = Vec.getOperand(0);
21536   SDValue Y = Vec.getOperand(1);
21537 
21538   SmallVector<int, 16> NewMask(Mask);
21539   if (mergeEltWithShuffle(X, Y, Mask, NewMask, InsertVal, InsIndex)) {
21540     SDValue LegalShuffle = TLI.buildLegalVectorShuffle(
21541         Vec.getValueType(), SDLoc(N), X, Y, NewMask, DAG);
21542     if (LegalShuffle)
21543       return LegalShuffle;
21544   }
21545 
21546   return SDValue();
21547 }
21548 
21549 // Convert a disguised subvector insertion into a shuffle:
21550 // insert_vector_elt V, (bitcast X from vector type), IdxC -->
21551 // bitcast(shuffle (bitcast V), (extended X), Mask)
21552 // Note: We do not use an insert_subvector node because that requires a
21553 // legal subvector type.
21554 SDValue DAGCombiner::combineInsertEltToShuffle(SDNode *N, unsigned InsIndex) {
21555   assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT &&
21556          "Expected insert_vector_elt");
21557   SDValue InsertVal = N->getOperand(1);
21558 
21559   if (InsertVal.getOpcode() != ISD::BITCAST || !InsertVal.hasOneUse() ||
21560       !InsertVal.getOperand(0).getValueType().isVector())
21561     return SDValue();
21562 
21563   SDValue SubVec = InsertVal.getOperand(0);
21564   SDValue DestVec = N->getOperand(0);
21565   EVT SubVecVT = SubVec.getValueType();
21566   EVT VT = DestVec.getValueType();
21567   unsigned NumSrcElts = SubVecVT.getVectorNumElements();
21568   // If the source has only a single vector element, the cost of building a
21569   // vector from it is likely to exceed the cost of an insert_vector_elt.
21570   if (NumSrcElts == 1)
21571     return SDValue();
21572   unsigned ExtendRatio = VT.getSizeInBits() / SubVecVT.getSizeInBits();
21573   unsigned NumMaskVals = ExtendRatio * NumSrcElts;
21574 
21575   // Step 1: Create a shuffle mask that implements this insert operation. The
21576   // vector that we are inserting into will be operand 0 of the shuffle, so
21577   // those elements are just 'i'. The inserted subvector is in the first
21578   // positions of operand 1 of the shuffle. Example:
21579   // insert v4i32 V, (v2i16 X), 2 --> shuffle v8i16 V', X', {0,1,2,3,8,9,6,7}
21580   SmallVector<int, 16> Mask(NumMaskVals);
21581   for (unsigned i = 0; i != NumMaskVals; ++i) {
21582     if (i / NumSrcElts == InsIndex)
21583       Mask[i] = (i % NumSrcElts) + NumMaskVals;
21584     else
21585       Mask[i] = i;
21586   }
21587 
21588   // Bail out if the target can not handle the shuffle we want to create.
21589   EVT SubVecEltVT = SubVecVT.getVectorElementType();
21590   EVT ShufVT = EVT::getVectorVT(*DAG.getContext(), SubVecEltVT, NumMaskVals);
21591   if (!TLI.isShuffleMaskLegal(Mask, ShufVT))
21592     return SDValue();
21593 
21594   // Step 2: Create a wide vector from the inserted source vector by appending
21595   // undefined elements. This is the same size as our destination vector.
21596   SDLoc DL(N);
21597   SmallVector<SDValue, 8> ConcatOps(ExtendRatio, DAG.getUNDEF(SubVecVT));
21598   ConcatOps[0] = SubVec;
21599   SDValue PaddedSubV = DAG.getNode(ISD::CONCAT_VECTORS, DL, ShufVT, ConcatOps);
21600 
21601   // Step 3: Shuffle in the padded subvector.
21602   SDValue DestVecBC = DAG.getBitcast(ShufVT, DestVec);
21603   SDValue Shuf = DAG.getVectorShuffle(ShufVT, DL, DestVecBC, PaddedSubV, Mask);
21604   AddToWorklist(PaddedSubV.getNode());
21605   AddToWorklist(DestVecBC.getNode());
21606   AddToWorklist(Shuf.getNode());
21607   return DAG.getBitcast(VT, Shuf);
21608 }
21609 
21610 // Combine insert(shuffle(load, <u,0,1,2>), load, 0) into a single load if
21611 // possible and the new load will be fast. We use more loads but fewer
21612 // shuffles and inserts.
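// An illustrative sketch of the last-lane case (hedged): with
// x = load <4 x i16> from %p and s = load i16 from %p + 8 (the element just
// past x),
//   insert (shuffle x, <1,2,3,u>), s, 3  --> load <4 x i16> from %p + 2
// assuming the new, possibly unaligned, load is legal and fast on the target.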
21613 SDValue DAGCombiner::combineInsertEltToLoad(SDNode *N, unsigned InsIndex) {
21614   EVT VT = N->getValueType(0);
21615 
21616   // InsIndex is expected to be the first or the last lane.
21617   if (!VT.isFixedLengthVector() ||
21618       (InsIndex != 0 && InsIndex != VT.getVectorNumElements() - 1))
21619     return SDValue();
21620 
21621   // Look for a shuffle with the mask u,0,1,2,3,4,5,6 or 1,2,3,4,5,6,7,u
21622   // depending on the InsIndex.
21623   auto *Shuffle = dyn_cast<ShuffleVectorSDNode>(N->getOperand(0));
21624   SDValue Scalar = N->getOperand(1);
21625   if (!Shuffle || !all_of(enumerate(Shuffle->getMask()), [&](auto P) {
21626         return InsIndex == P.index() || P.value() < 0 ||
21627                (InsIndex == 0 && P.value() == (int)P.index() - 1) ||
21628                (InsIndex == VT.getVectorNumElements() - 1 &&
21629                 P.value() == (int)P.index() + 1);
21630       }))
21631     return SDValue();
21632 
21633   // We optionally skip over an extend so long as both loads are extended in the
21634   // same way from the same type.
21635   unsigned Extend = 0;
21636   if (Scalar.getOpcode() == ISD::ZERO_EXTEND ||
21637       Scalar.getOpcode() == ISD::SIGN_EXTEND ||
21638       Scalar.getOpcode() == ISD::ANY_EXTEND) {
21639     Extend = Scalar.getOpcode();
21640     Scalar = Scalar.getOperand(0);
21641   }
21642 
21643   auto *ScalarLoad = dyn_cast<LoadSDNode>(Scalar);
21644   if (!ScalarLoad)
21645     return SDValue();
21646 
21647   SDValue Vec = Shuffle->getOperand(0);
21648   if (Extend) {
21649     if (Vec.getOpcode() != Extend)
21650       return SDValue();
21651     Vec = Vec.getOperand(0);
21652   }
21653   auto *VecLoad = dyn_cast<LoadSDNode>(Vec);
21654   if (!VecLoad || Vec.getValueType().getScalarType() != Scalar.getValueType())
21655     return SDValue();
21656 
21657   int EltSize = ScalarLoad->getValueType(0).getScalarSizeInBits();
21658   if (EltSize == 0 || EltSize % 8 != 0 || !ScalarLoad->isSimple() ||
21659       !VecLoad->isSimple() || VecLoad->getExtensionType() != ISD::NON_EXTLOAD ||
21660       ScalarLoad->getExtensionType() != ISD::NON_EXTLOAD ||
21661       ScalarLoad->getAddressSpace() != VecLoad->getAddressSpace())
21662     return SDValue();
21663 
21664   // Check that the offset between the pointers is such that the loads form
21665   // a single contiguous load.
21666   if (InsIndex == 0) {
21667     if (!DAG.areNonVolatileConsecutiveLoads(ScalarLoad, VecLoad, EltSize / 8,
21668                                             -1))
21669       return SDValue();
21670   } else {
21671     if (!DAG.areNonVolatileConsecutiveLoads(
21672             VecLoad, ScalarLoad, VT.getVectorNumElements() * EltSize / 8, -1))
21673       return SDValue();
21674   }
21675 
21676   // And that the new unaligned load will be fast.
21677   unsigned IsFast = 0;
21678   Align NewAlign = commonAlignment(VecLoad->getAlign(), EltSize / 8);
21679   if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(),
21680                               Vec.getValueType(), VecLoad->getAddressSpace(),
21681                               NewAlign, VecLoad->getMemOperand()->getFlags(),
21682                               &IsFast) ||
21683       !IsFast)
21684     return SDValue();
21685 
21686   // Calculate the new Ptr and create the new load.
21687   SDLoc DL(N);
21688   SDValue Ptr = ScalarLoad->getBasePtr();
21689   if (InsIndex != 0)
21690     Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), VecLoad->getBasePtr(),
21691                       DAG.getConstant(EltSize / 8, DL, Ptr.getValueType()));
21692   MachinePointerInfo PtrInfo =
21693       InsIndex == 0 ? ScalarLoad->getPointerInfo()
21694                     : VecLoad->getPointerInfo().getWithOffset(EltSize / 8);
21695 
21696   SDValue Load = DAG.getLoad(VecLoad->getValueType(0), DL,
21697                              ScalarLoad->getChain(), Ptr, PtrInfo, NewAlign);
21698   DAG.makeEquivalentMemoryOrdering(ScalarLoad, Load.getValue(1));
21699   DAG.makeEquivalentMemoryOrdering(VecLoad, Load.getValue(1));
21700   return Extend ? DAG.getNode(Extend, DL, VT, Load) : Load;
21701 }
21702 
21703 SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
21704   SDValue InVec = N->getOperand(0);
21705   SDValue InVal = N->getOperand(1);
21706   SDValue EltNo = N->getOperand(2);
21707   SDLoc DL(N);
21708 
21709   EVT VT = InVec.getValueType();
21710   auto *IndexC = dyn_cast<ConstantSDNode>(EltNo);
21711 
21712   // Inserting into an out-of-bounds element is undefined.
21713   if (IndexC && VT.isFixedLengthVector() &&
21714       IndexC->getZExtValue() >= VT.getVectorNumElements())
21715     return DAG.getUNDEF(VT);
21716 
21717   // Remove redundant insertions:
21718   // (insert_vector_elt x (extract_vector_elt x idx) idx) -> x
21719   if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
21720       InVec == InVal.getOperand(0) && EltNo == InVal.getOperand(1))
21721     return InVec;
21722 
21723   if (!IndexC) {
21724     // If this is a variable insert into an undef vector, it might be better
21725     // to splat: inselt undef, InVal, EltNo --> build_vector <InVal, InVal, ...>
21726     if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT))
21727       return DAG.getSplat(VT, DL, InVal);
21728     return SDValue();
21729   }
21730 
21731   if (VT.isScalableVector())
21732     return SDValue();
21733 
21734   unsigned NumElts = VT.getVectorNumElements();
21735 
21736   // We must know which element is being inserted for folds below here.
21737   unsigned Elt = IndexC->getZExtValue();
21738 
21739   // Handle <1 x ???> vector insertion special cases.
21740   if (NumElts == 1) {
21741     // insert_vector_elt(x, extract_vector_elt(y, 0), 0) -> y
21742     if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
21743         InVal.getOperand(0).getValueType() == VT &&
21744         isNullConstant(InVal.getOperand(1)))
21745       return InVal.getOperand(0);
21746   }
21747 
21748   // Canonicalize insert_vector_elt dag nodes.
21749   // Example:
21750   // (insert_vector_elt (insert_vector_elt A, Idx0), Idx1)
21751   // -> (insert_vector_elt (insert_vector_elt A, Idx1), Idx0)
21752   //
21753   // Do this only if the child insert_vector_elt node has one use; also
21754   // do this only if indices are both constants and Idx1 < Idx0.
21755   if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT && InVec.hasOneUse()
21756       && isa<ConstantSDNode>(InVec.getOperand(2))) {
21757     unsigned OtherElt = InVec.getConstantOperandVal(2);
21758     if (Elt < OtherElt) {
21759       // Swap nodes.
21760       SDValue NewOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT,
21761                                   InVec.getOperand(0), InVal, EltNo);
21762       AddToWorklist(NewOp.getNode());
21763       return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(InVec.getNode()),
21764                          VT, NewOp, InVec.getOperand(1), InVec.getOperand(2));
21765     }
21766   }
21767 
21768   if (SDValue Shuf = mergeInsertEltWithShuffle(N, Elt))
21769     return Shuf;
21770 
21771   if (SDValue Shuf = combineInsertEltToShuffle(N, Elt))
21772     return Shuf;
21773 
21774   if (SDValue Shuf = combineInsertEltToLoad(N, Elt))
21775     return Shuf;
21776 
21777   // Attempt to convert an insert_vector_elt chain into a legal build_vector.
21778   if (!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)) {
21779     // Single-element vector - we don't need to recurse.
21780     if (NumElts == 1)
21781       return DAG.getBuildVector(VT, DL, {InVal});
21782 
21783     // If we haven't already collected the element, insert into the op list.
21784     EVT MaxEltVT = InVal.getValueType();
21785     auto AddBuildVectorOp = [&](SmallVectorImpl<SDValue> &Ops, SDValue Elt,
21786                                 unsigned Idx) {
21787       if (!Ops[Idx]) {
21788         Ops[Idx] = Elt;
21789         if (VT.isInteger()) {
21790           EVT EltVT = Elt.getValueType();
21791           MaxEltVT = MaxEltVT.bitsGE(EltVT) ? MaxEltVT : EltVT;
21792         }
21793       }
21794     };
21795 
21796     // Ensure all the operands are the same value type, fill any missing
21797     // operands with UNDEF and create the BUILD_VECTOR.
21798     auto CanonicalizeBuildVector = [&](SmallVectorImpl<SDValue> &Ops) {
21799       assert(Ops.size() == NumElts && "Unexpected vector size");
21800       for (SDValue &Op : Ops) {
21801         if (Op)
21802           Op = VT.isInteger() ? DAG.getAnyExtOrTrunc(Op, DL, MaxEltVT) : Op;
21803         else
21804           Op = DAG.getUNDEF(MaxEltVT);
21805       }
21806       return DAG.getBuildVector(VT, DL, Ops);
21807     };
21808 
21809     SmallVector<SDValue, 8> Ops(NumElts, SDValue());
21810     Ops[Elt] = InVal;
21811 
21812     // Recurse up a INSERT_VECTOR_ELT chain to build a BUILD_VECTOR.
21813     for (SDValue CurVec = InVec; CurVec;) {
21814       // UNDEF - build new BUILD_VECTOR from already inserted operands.
21815       if (CurVec.isUndef())
21816         return CanonicalizeBuildVector(Ops);
21817 
21818       // BUILD_VECTOR - insert unused operands and build new BUILD_VECTOR.
21819       if (CurVec.getOpcode() == ISD::BUILD_VECTOR && CurVec.hasOneUse()) {
21820         for (unsigned I = 0; I != NumElts; ++I)
21821           AddBuildVectorOp(Ops, CurVec.getOperand(I), I);
21822         return CanonicalizeBuildVector(Ops);
21823       }
21824 
21825       // SCALAR_TO_VECTOR - insert unused scalar and build new BUILD_VECTOR.
21826       if (CurVec.getOpcode() == ISD::SCALAR_TO_VECTOR && CurVec.hasOneUse()) {
21827         AddBuildVectorOp(Ops, CurVec.getOperand(0), 0);
21828         return CanonicalizeBuildVector(Ops);
21829       }
21830 
21831       // INSERT_VECTOR_ELT - insert operand and continue up the chain.
21832       if (CurVec.getOpcode() == ISD::INSERT_VECTOR_ELT && CurVec.hasOneUse())
21833         if (auto *CurIdx = dyn_cast<ConstantSDNode>(CurVec.getOperand(2)))
21834           if (CurIdx->getAPIntValue().ult(NumElts)) {
21835             unsigned Idx = CurIdx->getZExtValue();
21836             AddBuildVectorOp(Ops, CurVec.getOperand(1), Idx);
21837 
21838             // Found entire BUILD_VECTOR.
21839             if (all_of(Ops, [](SDValue Op) { return !!Op; }))
21840               return CanonicalizeBuildVector(Ops);
21841 
21842             CurVec = CurVec->getOperand(0);
21843             continue;
21844           }
21845 
21846       // VECTOR_SHUFFLE - if all the operands match the shuffle's sources,
21847       // update the shuffle mask (and second operand if we started with unary
21848       // shuffle) and create a new legal shuffle.
21849       if (CurVec.getOpcode() == ISD::VECTOR_SHUFFLE && CurVec.hasOneUse()) {
21850         auto *SVN = cast<ShuffleVectorSDNode>(CurVec);
21851         SDValue LHS = SVN->getOperand(0);
21852         SDValue RHS = SVN->getOperand(1);
21853         SmallVector<int, 16> Mask(SVN->getMask());
21854         bool Merged = true;
21855         for (auto I : enumerate(Ops)) {
21856           SDValue &Op = I.value();
21857           if (Op) {
21858             SmallVector<int, 16> NewMask;
21859             if (!mergeEltWithShuffle(LHS, RHS, Mask, NewMask, Op, I.index())) {
21860               Merged = false;
21861               break;
21862             }
21863             Mask = std::move(NewMask);
21864           }
21865         }
21866         if (Merged)
21867           if (SDValue NewShuffle =
21868                   TLI.buildLegalVectorShuffle(VT, DL, LHS, RHS, Mask, DAG))
21869             return NewShuffle;
21870       }
21871 
21872       // If all the inserted values are zero, try to convert to an AND mask.
21873       // TODO: Do this for -1 with OR mask?
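      // Sketch (illustrative): inserting zeros into lanes 1 and 3 of %vec
      // becomes (and %vec, (build_vector -1, 0, -1, 0)).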
21874       if (!LegalOperations && llvm::isNullConstant(InVal) &&
21875           all_of(Ops, [InVal](SDValue Op) { return !Op || Op == InVal; }) &&
21876           count_if(Ops, [InVal](SDValue Op) { return Op == InVal; }) >= 2) {
21877         SDValue Zero = DAG.getConstant(0, DL, MaxEltVT);
21878         SDValue AllOnes = DAG.getAllOnesConstant(DL, MaxEltVT);
21879         SmallVector<SDValue, 8> Mask(NumElts);
21880         for (unsigned I = 0; I != NumElts; ++I)
21881           Mask[I] = Ops[I] ? Zero : AllOnes;
21882         return DAG.getNode(ISD::AND, DL, VT, CurVec,
21883                            DAG.getBuildVector(VT, DL, Mask));
21884       }
21885 
21886       // Failed to find a match in the chain - bail.
21887       break;
21888     }
21889 
21890     // See if we can fill in the missing constant elements as zeros.
21891     // TODO: Should we do this for any constant?
21892     APInt DemandedZeroElts = APInt::getZero(NumElts);
21893     for (unsigned I = 0; I != NumElts; ++I)
21894       if (!Ops[I])
21895         DemandedZeroElts.setBit(I);
21896 
21897     if (DAG.MaskedVectorIsZero(InVec, DemandedZeroElts)) {
21898       SDValue Zero = VT.isInteger() ? DAG.getConstant(0, DL, MaxEltVT)
21899                                     : DAG.getConstantFP(0, DL, MaxEltVT);
21900       for (unsigned I = 0; I != NumElts; ++I)
21901         if (!Ops[I])
21902           Ops[I] = Zero;
21903 
21904       return CanonicalizeBuildVector(Ops);
21905     }
21906   }
21907 
21908   return SDValue();
21909 }
21910 
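// Replace an extract of a loaded vector element with a narrow scalar load.
// An illustrative sketch (hedged, assuming a byte-sized element type and a
// target that allows the narrowed access):
//   (i32 (extract_vector_elt (v4i32 load %p), 2)) --> (i32 load %p + 8)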
21911 SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
21912                                                   SDValue EltNo,
21913                                                   LoadSDNode *OriginalLoad) {
21914   assert(OriginalLoad->isSimple());
21915 
21916   EVT ResultVT = EVE->getValueType(0);
21917   EVT VecEltVT = InVecVT.getVectorElementType();
21918 
21919   // If the vector element type is not a multiple of a byte then we are unable
21920   // to correctly compute an address to load only the extracted element as a
21921   // scalar.
21922   if (!VecEltVT.isByteSized())
21923     return SDValue();
21924 
21925   ISD::LoadExtType ExtTy =
21926       ResultVT.bitsGT(VecEltVT) ? ISD::NON_EXTLOAD : ISD::EXTLOAD;
21927   if (!TLI.isOperationLegalOrCustom(ISD::LOAD, VecEltVT) ||
21928       !TLI.shouldReduceLoadWidth(OriginalLoad, ExtTy, VecEltVT))
21929     return SDValue();
21930 
21931   Align Alignment = OriginalLoad->getAlign();
21932   MachinePointerInfo MPI;
21933   SDLoc DL(EVE);
21934   if (auto *ConstEltNo = dyn_cast<ConstantSDNode>(EltNo)) {
21935     int Elt = ConstEltNo->getZExtValue();
21936     unsigned PtrOff = VecEltVT.getSizeInBits() * Elt / 8;
21937     MPI = OriginalLoad->getPointerInfo().getWithOffset(PtrOff);
21938     Alignment = commonAlignment(Alignment, PtrOff);
21939   } else {
21940     // Discard the pointer info except the address space because the memory
21941     // operand can't represent this new access since the offset is variable.
21942     MPI = MachinePointerInfo(OriginalLoad->getPointerInfo().getAddrSpace());
21943     Alignment = commonAlignment(Alignment, VecEltVT.getSizeInBits() / 8);
21944   }
21945 
21946   unsigned IsFast = 0;
21947   if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VecEltVT,
21948                               OriginalLoad->getAddressSpace(), Alignment,
21949                               OriginalLoad->getMemOperand()->getFlags(),
21950                               &IsFast) ||
21951       !IsFast)
21952     return SDValue();
21953 
21954   SDValue NewPtr = TLI.getVectorElementPointer(DAG, OriginalLoad->getBasePtr(),
21955                                                InVecVT, EltNo);
21956 
21957   // We are replacing a vector load with a scalar load. The new load must have
21958   // identical memory op ordering to the original.
21959   SDValue Load;
21960   if (ResultVT.bitsGT(VecEltVT)) {
21961     // If the result type of vextract is wider than the load, then issue an
21962     // extending load instead.
21963     ISD::LoadExtType ExtType =
21964         TLI.isLoadExtLegal(ISD::ZEXTLOAD, ResultVT, VecEltVT) ? ISD::ZEXTLOAD
21965                                                               : ISD::EXTLOAD;
21966     Load = DAG.getExtLoad(ExtType, DL, ResultVT, OriginalLoad->getChain(),
21967                           NewPtr, MPI, VecEltVT, Alignment,
21968                           OriginalLoad->getMemOperand()->getFlags(),
21969                           OriginalLoad->getAAInfo());
21970     DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
21971   } else {
21972     // The result type is narrower or the same width as the vector element
21973     Load = DAG.getLoad(VecEltVT, DL, OriginalLoad->getChain(), NewPtr, MPI,
21974                        Alignment, OriginalLoad->getMemOperand()->getFlags(),
21975                        OriginalLoad->getAAInfo());
21976     DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
21977     if (ResultVT.bitsLT(VecEltVT))
21978       Load = DAG.getNode(ISD::TRUNCATE, DL, ResultVT, Load);
21979     else
21980       Load = DAG.getBitcast(ResultVT, Load);
21981   }
21982   ++OpsNarrowed;
21983   return Load;
21984 }
21985 
21986 /// Transform a vector binary operation into a scalar binary operation by moving
21987 /// the math/logic after an extract element of a vector.
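/// For example (an illustrative sketch):
///   (extract_vector_elt (add %x, <1,2,3,4>), 2)
///     --> (add (extract_vector_elt %x, 2), 3)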
21988 static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
21989                                        bool LegalOperations) {
21990   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
21991   SDValue Vec = ExtElt->getOperand(0);
21992   SDValue Index = ExtElt->getOperand(1);
21993   auto *IndexC = dyn_cast<ConstantSDNode>(Index);
21994   if (!IndexC || !TLI.isBinOp(Vec.getOpcode()) || !Vec.hasOneUse() ||
21995       Vec->getNumValues() != 1)
21996     return SDValue();
21997 
21998   // Targets may want to avoid this to prevent an expensive register transfer.
21999   if (!TLI.shouldScalarizeBinop(Vec))
22000     return SDValue();
22001 
22002   // Extracting an element of a vector constant is constant-folded, so this
22003   // transform is just replacing a vector op with a scalar op while moving the
22004   // extract.
22005   SDValue Op0 = Vec.getOperand(0);
22006   SDValue Op1 = Vec.getOperand(1);
22007   APInt SplatVal;
22008   if (isAnyConstantBuildVector(Op0, true) ||
22009       ISD::isConstantSplatVector(Op0.getNode(), SplatVal) ||
22010       isAnyConstantBuildVector(Op1, true) ||
22011       ISD::isConstantSplatVector(Op1.getNode(), SplatVal)) {
22012     // extractelt (binop X, C), IndexC --> binop (extractelt X, IndexC), C'
22013     // extractelt (binop C, X), IndexC --> binop C', (extractelt X, IndexC)
22014     SDLoc DL(ExtElt);
22015     EVT VT = ExtElt->getValueType(0);
22016     SDValue Ext0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Index);
22017     SDValue Ext1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op1, Index);
22018     return DAG.getNode(Vec.getOpcode(), DL, VT, Ext0, Ext1);
22019   }
22020 
22021   return SDValue();
22022 }
22023 
22024 // Given an ISD::EXTRACT_VECTOR_ELT, which is a glorified bit sequence
22025 // extract, recursively analyse all of its users and try to model them as bit
22026 // sequence extractions. If all of them agree on the new, narrower element
22027 // type, and all of them can be modelled as ISD::EXTRACT_VECTOR_ELT's of that
22028 // new element type, do so now.
22029 // This is mainly useful to recover from legalization that scalarized
22030 // the vector as wide elements, but tries to rebuild it with narrower elements.
22031 //
22032 // Some more nodes could be modelled if that helps cover interesting patterns.
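// An illustrative little-endian sketch (hedged): for
//   %w = extract_vector_elt <2 x i64> %v, 1
// users (trunc %w to i32) and (trunc (srl %w, 32) to i32) extract bits
// [64,96) and [96,128) of %v, so both can be rebuilt as extracts of elements
// 2 and 3 of (bitcast %v to <4 x i32>).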
22033 bool DAGCombiner::refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(
22034     SDNode *N) {
22035   // We perform this optimization post type-legalization because
22036   // the type-legalizer often scalarizes integer-promoted vectors.
22037   // Performing this optimization earlier may cause legalization cycles.
22038   if (Level != AfterLegalizeVectorOps && Level != AfterLegalizeTypes)
22039     return false;
22040 
22041   // TODO: Add support for big-endian.
22042   if (DAG.getDataLayout().isBigEndian())
22043     return false;
22044 
22045   SDValue VecOp = N->getOperand(0);
22046   EVT VecVT = VecOp.getValueType();
22047   assert(!VecVT.isScalableVector() && "Only for fixed vectors.");
22048 
22049   // We must start with a constant extraction index.
22050   auto *IndexC = dyn_cast<ConstantSDNode>(N->getOperand(1));
22051   if (!IndexC)
22052     return false;
22053 
22054   assert(IndexC->getZExtValue() < VecVT.getVectorNumElements() &&
22055          "Original ISD::EXTRACT_VECTOR_ELT is undefined?");
22056 
22057   // TODO: deal with the case of implicit anyext of the extraction.
22058   unsigned VecEltBitWidth = VecVT.getScalarSizeInBits();
22059   EVT ScalarVT = N->getValueType(0);
22060   if (VecVT.getScalarType() != ScalarVT)
22061     return false;
22062 
22063   // TODO: deal with the cases other than everything being integer-typed.
22064   if (!ScalarVT.isScalarInteger())
22065     return false;
22066 
22067   struct Entry {
22068     SDNode *Producer;
22069 
22070     // Which bits of VecOp does it contain?
22071     unsigned BitPos;
22072     int NumBits;
22073     // NOTE: the actual width of \p Producer may be wider than NumBits!
22074 
22075     Entry(Entry &&) = default;
22076     Entry(SDNode *Producer_, unsigned BitPos_, int NumBits_)
22077         : Producer(Producer_), BitPos(BitPos_), NumBits(NumBits_) {}
22078 
22079     Entry() = delete;
22080     Entry(const Entry &) = delete;
22081     Entry &operator=(const Entry &) = delete;
22082     Entry &operator=(Entry &&) = delete;
22083   };
22084   SmallVector<Entry, 32> Worklist;
22085   SmallVector<Entry, 32> Leafs;
22086 
22087   // We start at the "root" ISD::EXTRACT_VECTOR_ELT.
22088   Worklist.emplace_back(N, /*BitPos=*/VecEltBitWidth * IndexC->getZExtValue(),
22089                         /*NumBits=*/VecEltBitWidth);
22090 
22091   while (!Worklist.empty()) {
22092     Entry E = Worklist.pop_back_val();
22093     // Does the node not even use any of the VecOp bits?
22094     if (!(E.NumBits > 0 && E.BitPos < VecVT.getSizeInBits() &&
22095           E.BitPos + E.NumBits <= VecVT.getSizeInBits()))
22096       return false; // Let the other combines clean this up first.
22097     // Did we fail to model any of the users of the Producer?
22098     bool ProducerIsLeaf = false;
22099     // Look at each user of this Producer.
22100     for (SDNode *User : E.Producer->uses()) {
22101       switch (User->getOpcode()) {
22102       // TODO: support ISD::BITCAST
22103       // TODO: support ISD::ANY_EXTEND
22104       // TODO: support ISD::ZERO_EXTEND
22105       // TODO: support ISD::SIGN_EXTEND
22106       case ISD::TRUNCATE:
22107         // Truncation simply means we keep position, but extract less bits.
22108         Worklist.emplace_back(User, E.BitPos,
22109                               /*NumBits=*/User->getValueSizeInBits(0));
22110         break;
22111       // TODO: support ISD::SRA
22112       // TODO: support ISD::SHL
22113       case ISD::SRL:
22114         // We should be shifting the Producer by a constant amount.
22115         if (auto *ShAmtC = dyn_cast<ConstantSDNode>(User->getOperand(1));
22116             User->getOperand(0).getNode() == E.Producer && ShAmtC) {
22117           // Logical right-shift means that we start extraction later,
22118           // but stop it at the same position we did previously.
22119           unsigned ShAmt = ShAmtC->getZExtValue();
22120           Worklist.emplace_back(User, E.BitPos + ShAmt, E.NumBits - ShAmt);
22121           break;
22122         }
22123         [[fallthrough]];
22124       default:
22125         // We can not model this user of the Producer.
22126         // Which means the current Producer will be a ISD::EXTRACT_VECTOR_ELT.
22127         ProducerIsLeaf = true;
22128         // Profitability check: all users that we can not model
22129         //                      must be ISD::BUILD_VECTOR's.
22130         if (User->getOpcode() != ISD::BUILD_VECTOR)
22131           return false;
22132         break;
22133       }
22134     }
22135     if (ProducerIsLeaf)
22136       Leafs.emplace_back(std::move(E));
22137   }
22138 
22139   unsigned NewVecEltBitWidth = Leafs.front().NumBits;
22140 
22141   // If we are still at the same element granularity, give up.
22142   if (NewVecEltBitWidth == VecEltBitWidth)
22143     return false;
22144 
22145   // The vector width must be a multiple of the new element width.
22146   if (VecVT.getSizeInBits() % NewVecEltBitWidth != 0)
22147     return false;
22148 
22149   // All leafs must agree on the new element width.
22150   // All leafs must not expect any "padding" bits on top of that width.
22151   // All leafs must start extraction from a multiple of that width.
22152   if (!all_of(Leafs, [NewVecEltBitWidth](const Entry &E) {
22153         return (unsigned)E.NumBits == NewVecEltBitWidth &&
22154                E.Producer->getValueSizeInBits(0) == NewVecEltBitWidth &&
22155                E.BitPos % NewVecEltBitWidth == 0;
22156       }))
22157     return false;
22158 
22159   EVT NewScalarVT = EVT::getIntegerVT(*DAG.getContext(), NewVecEltBitWidth);
22160   EVT NewVecVT = EVT::getVectorVT(*DAG.getContext(), NewScalarVT,
22161                                   VecVT.getSizeInBits() / NewVecEltBitWidth);
22162 
22163   if (LegalTypes &&
22164       !(TLI.isTypeLegal(NewScalarVT) && TLI.isTypeLegal(NewVecVT)))
22165     return false;
22166 
22167   if (LegalOperations &&
22168       !(TLI.isOperationLegalOrCustom(ISD::BITCAST, NewVecVT) &&
22169         TLI.isOperationLegalOrCustom(ISD::EXTRACT_VECTOR_ELT, NewVecVT)))
22170     return false;
22171 
22172   SDValue NewVecOp = DAG.getBitcast(NewVecVT, VecOp);
22173   for (const Entry &E : Leafs) {
22174     SDLoc DL(E.Producer);
22175     unsigned NewIndex = E.BitPos / NewVecEltBitWidth;
22176     assert(NewIndex < NewVecVT.getVectorNumElements() &&
22177            "Creating out-of-bounds ISD::EXTRACT_VECTOR_ELT?");
22178     SDValue V = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, NewScalarVT, NewVecOp,
22179                             DAG.getVectorIdxConstant(NewIndex, DL));
22180     CombineTo(E.Producer, V);
22181   }
22182 
22183   return true;
22184 }
22185 
22186 SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
22187   SDValue VecOp = N->getOperand(0);
22188   SDValue Index = N->getOperand(1);
22189   EVT ScalarVT = N->getValueType(0);
22190   EVT VecVT = VecOp.getValueType();
22191   if (VecOp.isUndef())
22192     return DAG.getUNDEF(ScalarVT);
22193 
22194   // extract_vector_elt (insert_vector_elt vec, val, idx), idx) -> val
22195   //
22196   // This only really matters if the index is non-constant since other combines
22197   // on the constant elements already work.
22198   SDLoc DL(N);
22199   if (VecOp.getOpcode() == ISD::INSERT_VECTOR_ELT &&
22200       Index == VecOp.getOperand(2)) {
22201     SDValue Elt = VecOp.getOperand(1);
22202     return VecVT.isInteger() ? DAG.getAnyExtOrTrunc(Elt, DL, ScalarVT) : Elt;
22203   }
22204 
22205   // (vextract (scalar_to_vector val), 0) -> val
22206   if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR) {
22207     // Only 0'th element of SCALAR_TO_VECTOR is defined.
22208     if (DAG.isKnownNeverZero(Index))
22209       return DAG.getUNDEF(ScalarVT);
22210 
22211     // Check if the result type doesn't match the inserted element type.
22212     // The inserted element and extracted element may have mismatched
22213     // bitwidths, so EXTRACT_VECTOR_ELT may extend or truncate the value.
22214     SDValue InOp = VecOp.getOperand(0);
22215     if (InOp.getValueType() != ScalarVT) {
22216       assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
22217       if (InOp.getValueType().bitsGT(ScalarVT))
22218         return DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, InOp);
22219       return DAG.getNode(ISD::ANY_EXTEND, DL, ScalarVT, InOp);
22220     }
22221     return InOp;
22222   }
22223 
22224   // extract_vector_elt of out-of-bounds element -> UNDEF
22225   auto *IndexC = dyn_cast<ConstantSDNode>(Index);
22226   if (IndexC && VecVT.isFixedLengthVector() &&
22227       IndexC->getAPIntValue().uge(VecVT.getVectorNumElements()))
22228     return DAG.getUNDEF(ScalarVT);
22229 
22230   // extract_vector_elt(freeze(x)), idx -> freeze(extract_vector_elt(x)), idx
22231   if (VecOp.hasOneUse() && VecOp.getOpcode() == ISD::FREEZE) {
22232     return DAG.getFreeze(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT,
22233                                      VecOp.getOperand(0), Index));
22234   }
22235 
22236   // extract_vector_elt (build_vector x, y), 1 -> y
22237   if (((IndexC && VecOp.getOpcode() == ISD::BUILD_VECTOR) ||
22238        VecOp.getOpcode() == ISD::SPLAT_VECTOR) &&
22239       TLI.isTypeLegal(VecVT)) {
22240     assert((VecOp.getOpcode() != ISD::BUILD_VECTOR ||
22241             VecVT.isFixedLengthVector()) &&
22242            "BUILD_VECTOR used for scalable vectors");
22243     unsigned IndexVal =
22244         VecOp.getOpcode() == ISD::BUILD_VECTOR ? IndexC->getZExtValue() : 0;
22245     SDValue Elt = VecOp.getOperand(IndexVal);
22246     EVT InEltVT = Elt.getValueType();
22247 
22248     if (VecOp.hasOneUse() || TLI.aggressivelyPreferBuildVectorSources(VecVT) ||
22249         isNullConstant(Elt)) {
22250       // Sometimes build_vector's scalar input types do not match result type.
22251       if (ScalarVT == InEltVT)
22252         return Elt;
22253 
22254       // TODO: It may be useful to truncate if free if the build_vector
22255       // implicitly converts.
22256     }
22257   }
22258 
22259   if (SDValue BO = scalarizeExtractedBinop(N, DAG, LegalOperations))
22260     return BO;
22261 
22262   if (VecVT.isScalableVector())
22263     return SDValue();
22264 
22265   // All the code from this point onwards assumes fixed width vectors, but it's
22266   // possible that some of the combinations could be made to work for scalable
22267   // vectors too.
22268   unsigned NumElts = VecVT.getVectorNumElements();
22269   unsigned VecEltBitWidth = VecVT.getScalarSizeInBits();
22270 
22271   // See if the extracted element is constant, in which case fold it if it
22272   // is a legal fp immediate.
22273   if (IndexC && ScalarVT.isFloatingPoint()) {
22274     APInt EltMask = APInt::getOneBitSet(NumElts, IndexC->getZExtValue());
22275     KnownBits KnownElt = DAG.computeKnownBits(VecOp, EltMask);
22276     if (KnownElt.isConstant()) {
22277       APFloat CstFP =
22278           APFloat(DAG.EVTToAPFloatSemantics(ScalarVT), KnownElt.getConstant());
22279       if (TLI.isFPImmLegal(CstFP, ScalarVT))
22280         return DAG.getConstantFP(CstFP, DL, ScalarVT);
22281     }
22282   }
22283 
22284   // TODO: These transforms should not require the 'hasOneUse' restriction, but
22285   // there are regressions on multiple targets without it. We can end up with a
22286   // mess of scalar and vector code if we reduce only part of the DAG to scalar.
22287   if (IndexC && VecOp.getOpcode() == ISD::BITCAST && VecVT.isInteger() &&
22288       VecOp.hasOneUse()) {
22289     // The vector index of the LSBs of the source depends on the endianness.
22290     bool IsLE = DAG.getDataLayout().isLittleEndian();
22291     unsigned ExtractIndex = IndexC->getZExtValue();
22292     // extract_elt (v2i32 (bitcast i64:x)), BCTruncElt -> i32 (trunc i64:x)
22293     unsigned BCTruncElt = IsLE ? 0 : NumElts - 1;
22294     SDValue BCSrc = VecOp.getOperand(0);
22295     if (ExtractIndex == BCTruncElt && BCSrc.getValueType().isScalarInteger())
22296       return DAG.getAnyExtOrTrunc(BCSrc, DL, ScalarVT);
22297 
22298     if (LegalTypes && BCSrc.getValueType().isInteger() &&
22299         BCSrc.getOpcode() == ISD::SCALAR_TO_VECTOR) {
22300       // ext_elt (bitcast (scalar_to_vec i64 X to v2i64) to v4i32), TruncElt -->
22301       // trunc i64 X to i32
22302       SDValue X = BCSrc.getOperand(0);
22303       assert(X.getValueType().isScalarInteger() && ScalarVT.isScalarInteger() &&
22304              "Extract element and scalar to vector can't change element type "
22305              "from FP to integer.");
22306       unsigned XBitWidth = X.getValueSizeInBits();
22307       BCTruncElt = IsLE ? 0 : XBitWidth / VecEltBitWidth - 1;
22308 
22309       // An extract element return value type can be wider than its vector
22310       // operand element type. In that case, the high bits are undefined, so
22311       // it's possible that we may need to extend rather than truncate.
22312       if (ExtractIndex == BCTruncElt && XBitWidth > VecEltBitWidth) {
22313         assert(XBitWidth % VecEltBitWidth == 0 &&
22314                "Scalar bitwidth must be a multiple of vector element bitwidth");
22315         return DAG.getAnyExtOrTrunc(X, DL, ScalarVT);
22316       }
22317     }
22318   }
22319 
22320   // Transform: (EXTRACT_VECTOR_ELT (VECTOR_SHUFFLE)) -> EXTRACT_VECTOR_ELT.
22321   // We only perform this optimization before the op legalization phase because
22322   // we may introduce new vector instructions which are not backed by TD
22323   // patterns. For example on AVX, extracting elements from a wide vector
22324   // without using extract_subvector. However, if we can find an underlying
22325   // scalar value, then we can always use that.
22326   if (IndexC && VecOp.getOpcode() == ISD::VECTOR_SHUFFLE) {
22327     auto *Shuf = cast<ShuffleVectorSDNode>(VecOp);
22328     // Find the new index to extract from.
22329     int OrigElt = Shuf->getMaskElt(IndexC->getZExtValue());
22330 
22331     // Extracting an undef index is undef.
22332     if (OrigElt == -1)
22333       return DAG.getUNDEF(ScalarVT);
22334 
22335     // Select the right vector half to extract from.
22336     SDValue SVInVec;
22337     if (OrigElt < (int)NumElts) {
22338       SVInVec = VecOp.getOperand(0);
22339     } else {
22340       SVInVec = VecOp.getOperand(1);
22341       OrigElt -= NumElts;
22342     }
22343 
22344     if (SVInVec.getOpcode() == ISD::BUILD_VECTOR) {
22345       SDValue InOp = SVInVec.getOperand(OrigElt);
22346       if (InOp.getValueType() != ScalarVT) {
22347         assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
22348         InOp = DAG.getSExtOrTrunc(InOp, DL, ScalarVT);
22349       }
22350 
22351       return InOp;
22352     }
22353 
22354     // FIXME: We should handle recursing on other vector shuffles and
22355     // scalar_to_vector here as well.
22356 
22357     if (!LegalOperations ||
22358         // FIXME: Should really be just isOperationLegalOrCustom.
22359         TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VecVT) ||
22360         TLI.isOperationExpand(ISD::VECTOR_SHUFFLE, VecVT)) {
22361       return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, SVInVec,
22362                          DAG.getVectorIdxConstant(OrigElt, DL));
22363     }
22364   }
22365 
22366   // If only EXTRACT_VECTOR_ELT nodes use the source vector we can
22367   // simplify it based on the (valid) extraction indices.
22368   if (llvm::all_of(VecOp->uses(), [&](SDNode *Use) {
22369         return Use->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
22370                Use->getOperand(0) == VecOp &&
22371                isa<ConstantSDNode>(Use->getOperand(1));
22372       })) {
22373     APInt DemandedElts = APInt::getZero(NumElts);
22374     for (SDNode *Use : VecOp->uses()) {
22375       auto *CstElt = cast<ConstantSDNode>(Use->getOperand(1));
22376       if (CstElt->getAPIntValue().ult(NumElts))
22377         DemandedElts.setBit(CstElt->getZExtValue());
22378     }
22379     if (SimplifyDemandedVectorElts(VecOp, DemandedElts, true)) {
22380       // We simplified the vector operand of this extract element. If this
22381       // extract is not dead, visit it again so it is folded properly.
22382       if (N->getOpcode() != ISD::DELETED_NODE)
22383         AddToWorklist(N);
22384       return SDValue(N, 0);
22385     }
22386     APInt DemandedBits = APInt::getAllOnes(VecEltBitWidth);
22387     if (SimplifyDemandedBits(VecOp, DemandedBits, DemandedElts, true)) {
22388       // We simplified the vector operand of this extract element. If this
22389       // extract is not dead, visit it again so it is folded properly.
22390       if (N->getOpcode() != ISD::DELETED_NODE)
22391         AddToWorklist(N);
22392       return SDValue(N, 0);
22393     }
22394   }
22395 
22396   if (refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(N))
22397     return SDValue(N, 0);
22398 
22399   // Everything under here is trying to match an extract of a loaded value.
22400   // If the result of the load has to be truncated, then it's not necessarily
22401   // profitable.
22402   bool BCNumEltsChanged = false;
22403   EVT ExtVT = VecVT.getVectorElementType();
22404   EVT LVT = ExtVT;
22405   if (ScalarVT.bitsLT(LVT) && !TLI.isTruncateFree(LVT, ScalarVT))
22406     return SDValue();
22407 
22408   if (VecOp.getOpcode() == ISD::BITCAST) {
22409     // Don't duplicate a load with other uses.
22410     if (!VecOp.hasOneUse())
22411       return SDValue();
22412 
22413     EVT BCVT = VecOp.getOperand(0).getValueType();
22414     if (!BCVT.isVector() || ExtVT.bitsGT(BCVT.getVectorElementType()))
22415       return SDValue();
22416     if (NumElts != BCVT.getVectorNumElements())
22417       BCNumEltsChanged = true;
22418     VecOp = VecOp.getOperand(0);
22419     ExtVT = BCVT.getVectorElementType();
22420   }
22421 
22422   // extract (vector load $addr), i --> load $addr + i * size
22423   if (!LegalOperations && !IndexC && VecOp.hasOneUse() &&
22424       ISD::isNormalLoad(VecOp.getNode()) &&
22425       !Index->hasPredecessor(VecOp.getNode())) {
22426     auto *VecLoad = dyn_cast<LoadSDNode>(VecOp);
22427     if (VecLoad && VecLoad->isSimple())
22428       return scalarizeExtractedVectorLoad(N, VecVT, Index, VecLoad);
22429   }
22430 
22431   // Perform only after legalization to ensure build_vector / vector_shuffle
22432   // optimizations have already been done.
22433   if (!LegalOperations || !IndexC)
22434     return SDValue();
22435 
22436   // (vextract (v4f32 load $addr), c) -> (f32 load $addr+c*size)
22437   // (vextract (v4f32 s2v (f32 load $addr)), c) -> (f32 load $addr+c*size)
22438   // (vextract (v4f32 shuffle (load $addr), <1,u,u,u>), 0) -> (f32 load $addr)
22439   int Elt = IndexC->getZExtValue();
22440   LoadSDNode *LN0 = nullptr;
22441   if (ISD::isNormalLoad(VecOp.getNode())) {
22442     LN0 = cast<LoadSDNode>(VecOp);
22443   } else if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR &&
22444              VecOp.getOperand(0).getValueType() == ExtVT &&
22445              ISD::isNormalLoad(VecOp.getOperand(0).getNode())) {
22446     // Don't duplicate a load with other uses.
22447     if (!VecOp.hasOneUse())
22448       return SDValue();
22449 
22450     LN0 = cast<LoadSDNode>(VecOp.getOperand(0));
22451   }
22452   if (auto *Shuf = dyn_cast<ShuffleVectorSDNode>(VecOp)) {
22453     // (vextract (vector_shuffle (load $addr), v2, <1, u, u, u>), 1)
22454     // =>
22455     // (load $addr+1*size)
22456 
22457     // Don't duplicate a load with other uses.
22458     if (!VecOp.hasOneUse())
22459       return SDValue();
22460 
22461     // If the bit convert changed the number of elements, it is unsafe
22462     // to examine the mask.
22463     if (BCNumEltsChanged)
22464       return SDValue();
22465 
22466     // Select the input vector, guarding against out of range extract vector.
22467     int Idx = (Elt > (int)NumElts) ? -1 : Shuf->getMaskElt(Elt);
22468     VecOp = (Idx < (int)NumElts) ? VecOp.getOperand(0) : VecOp.getOperand(1);
22469 
22470     if (VecOp.getOpcode() == ISD::BITCAST) {
22471       // Don't duplicate a load with other uses.
22472       if (!VecOp.hasOneUse())
22473         return SDValue();
22474 
22475       VecOp = VecOp.getOperand(0);
22476     }
22477     if (ISD::isNormalLoad(VecOp.getNode())) {
22478       LN0 = cast<LoadSDNode>(VecOp);
22479       Elt = (Idx < (int)NumElts) ? Idx : Idx - (int)NumElts;
22480       Index = DAG.getConstant(Elt, DL, Index.getValueType());
22481     }
22482   } else if (VecOp.getOpcode() == ISD::CONCAT_VECTORS && !BCNumEltsChanged &&
22483              VecVT.getVectorElementType() == ScalarVT &&
22484              (!LegalTypes ||
22485               TLI.isTypeLegal(
22486                   VecOp.getOperand(0).getValueType().getVectorElementType()))) {
22487     // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 0
22488     //      -> extract_vector_elt a, 0
22489     // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 1
22490     //      -> extract_vector_elt a, 1
22491     // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 2
22492     //      -> extract_vector_elt b, 0
22493     // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 3
22494     //      -> extract_vector_elt b, 1
22495     SDLoc SL(N);
22496     EVT ConcatVT = VecOp.getOperand(0).getValueType();
22497     unsigned ConcatNumElts = ConcatVT.getVectorNumElements();
22498     SDValue NewIdx = DAG.getConstant(Elt % ConcatNumElts, SL,
22499                                      Index.getValueType());
22500 
22501     SDValue ConcatOp = VecOp.getOperand(Elt / ConcatNumElts);
22502     SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL,
22503                               ConcatVT.getVectorElementType(),
22504                               ConcatOp, NewIdx);
22505     return DAG.getNode(ISD::BITCAST, SL, ScalarVT, Elt);
22506   }
22507 
22508   // Make sure we found a non-volatile load and the extractelement is
22509   // the only use.
22510   if (!LN0 || !LN0->hasNUsesOfValue(1,0) || !LN0->isSimple())
22511     return SDValue();
22512 
22513   // If Idx was -1 above, Elt is going to be -1, so just return undef.
22514   if (Elt == -1)
22515     return DAG.getUNDEF(LVT);
22516 
22517   return scalarizeExtractedVectorLoad(N, VecVT, Index, LN0);
22518 }
22519 
22520 // Simplify (build_vec (ext )) to (bitcast (build_vec ))
22521 SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) {
22522   // We perform this optimization post type-legalization because
22523   // the type-legalizer often scalarizes integer-promoted vectors.
22524   // Performing this optimization before may create bit-casts which
22525   // will be type-legalized to complex code sequences.
22526   // We perform this optimization only before the operation legalizer because we
22527   // may introduce illegal operations.
22528   if (Level != AfterLegalizeVectorOps && Level != AfterLegalizeTypes)
22529     return SDValue();
22530 
22531   unsigned NumInScalars = N->getNumOperands();
22532   SDLoc DL(N);
22533   EVT VT = N->getValueType(0);
22534 
22535   // Check to see if this is a BUILD_VECTOR of a bunch of values
22536   // which come from any_extend or zero_extend nodes. If so, we can create
22537   // a new BUILD_VECTOR using bit-casts which may enable other BUILD_VECTOR
22538   // optimizations. We do not handle sign-extend because we can't fill the sign
22539   // using shuffles.
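        // For example (an illustrative sketch; a..d are placeholder scalars),
        // on a little-endian target with zero-extends from i16 to i32:
        //   (v4i32 build_vector (zext i16:a), (zext i16:b),
        //                       (zext i16:c), (zext i16:d))
        //     --> (v4i32 bitcast (v8i16 build_vector a, 0, b, 0, c, 0, d, 0))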
22540   EVT SourceType = MVT::Other;
22541   bool AllAnyExt = true;
22542 
22543   for (unsigned i = 0; i != NumInScalars; ++i) {
22544     SDValue In = N->getOperand(i);
22545     // Ignore undef inputs.
22546     if (In.isUndef()) continue;
22547 
22548     bool AnyExt  = In.getOpcode() == ISD::ANY_EXTEND;
22549     bool ZeroExt = In.getOpcode() == ISD::ZERO_EXTEND;
22550 
22551     // Abort if the element is not an extension.
22552     if (!ZeroExt && !AnyExt) {
22553       SourceType = MVT::Other;
22554       break;
22555     }
22556 
22557     // The input is a ZeroExt or AnyExt. Check the original type.
22558     EVT InTy = In.getOperand(0).getValueType();
22559 
22560     // Check that all of the widened source types are the same.
22561     if (SourceType == MVT::Other)
22562       // First time.
22563       SourceType = InTy;
22564     else if (InTy != SourceType) {
22565       // Multiple incoming source types. Abort.
22566       SourceType = MVT::Other;
22567       break;
22568     }
22569 
22570     // Check if all of the extends are ANY_EXTENDs.
22571     AllAnyExt &= AnyExt;
22572   }
22573 
22574   // In order to have valid types, all of the inputs must be extended from the
22575   // same source type and all of the inputs must be any or zero extend.
22576   // Scalar sizes must be a power of two.
22577   EVT OutScalarTy = VT.getScalarType();
22578   bool ValidTypes =
22579       SourceType != MVT::Other &&
22580       llvm::has_single_bit<uint32_t>(OutScalarTy.getSizeInBits()) &&
22581       llvm::has_single_bit<uint32_t>(SourceType.getSizeInBits());
22582 
22583   // Create a new simpler BUILD_VECTOR sequence which other optimizations can
22584   // turn into a single shuffle instruction.
22585   if (!ValidTypes)
22586     return SDValue();
22587 
22588   // If we already have a splat buildvector, then don't fold it if it means
22589   // introducing zeros.
22590   if (!AllAnyExt && DAG.isSplatValue(SDValue(N, 0), /*AllowUndefs*/ true))
22591     return SDValue();
22592 
22593   bool isLE = DAG.getDataLayout().isLittleEndian();
22594   unsigned ElemRatio = OutScalarTy.getSizeInBits()/SourceType.getSizeInBits();
22595   assert(ElemRatio > 1 && "Invalid element size ratio");
22596   SDValue Filler = AllAnyExt ? DAG.getUNDEF(SourceType):
22597                                DAG.getConstant(0, DL, SourceType);
22598 
22599   unsigned NewBVElems = ElemRatio * VT.getVectorNumElements();
22600   SmallVector<SDValue, 8> Ops(NewBVElems, Filler);
22601 
22602   // Populate the new build_vector
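        // E.g. with ElemRatio == 2 (a hypothetical ratio), operand i lands at
        // new index 2*i on little-endian and 2*i + 1 on big-endian; the other
        // slot keeps the filler (undef for any_extend, zero for zero_extend).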
22603   for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
22604     SDValue Cast = N->getOperand(i);
22605     assert((Cast.getOpcode() == ISD::ANY_EXTEND ||
22606             Cast.getOpcode() == ISD::ZERO_EXTEND ||
22607             Cast.isUndef()) && "Invalid cast opcode");
22608     SDValue In;
22609     if (Cast.isUndef())
22610       In = DAG.getUNDEF(SourceType);
22611     else
22612       In = Cast->getOperand(0);
22613     unsigned Index = isLE ? (i * ElemRatio) :
22614                             (i * ElemRatio + (ElemRatio - 1));
22615 
22616     assert(Index < Ops.size() && "Invalid index");
22617     Ops[Index] = In;
22618   }
22619 
22620   // The type of the new BUILD_VECTOR node.
22621   EVT VecVT = EVT::getVectorVT(*DAG.getContext(), SourceType, NewBVElems);
22622   assert(VecVT.getSizeInBits() == VT.getSizeInBits() &&
22623          "Invalid vector size");
22624   // Check if the new vector type is legal.
22625   if (!isTypeLegal(VecVT) ||
22626       (!TLI.isOperationLegal(ISD::BUILD_VECTOR, VecVT) &&
22627        TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)))
22628     return SDValue();
22629 
22630   // Make the new BUILD_VECTOR.
22631   SDValue BV = DAG.getBuildVector(VecVT, DL, Ops);
22632 
22633   // The new BUILD_VECTOR node has the potential to be further optimized.
22634   AddToWorklist(BV.getNode());
22635   // Bitcast to the desired type.
22636   return DAG.getBitcast(VT, BV);
22637 }
22638 
22639 // Simplify (build_vec (trunc $1)
22640 //                     (trunc (srl $1 half-width))
22641 //                     (trunc (srl $1 (2 * half-width))))
22642 // to (bitcast $1)
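      // For example (a hypothetical little-endian case; x is a placeholder):
      //   (v2i32 build_vector (i32 trunc i64:x),
      //                       (i32 trunc (i64 srl i64:x, 32)))
      //     --> (v2i32 bitcast i64:x)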
22643 SDValue DAGCombiner::reduceBuildVecTruncToBitCast(SDNode *N) {
22644   assert(N->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");
22645 
22646   EVT VT = N->getValueType(0);
22647 
22648   // Don't run this before LegalizeTypes if VT is legal.
22649   // Targets may have other preferences.
22650   if (Level < AfterLegalizeTypes && TLI.isTypeLegal(VT))
22651     return SDValue();
22652 
22653   // Only for little endian
22654   if (!DAG.getDataLayout().isLittleEndian())
22655     return SDValue();
22656 
22657   SDLoc DL(N);
22658   EVT OutScalarTy = VT.getScalarType();
22659   uint64_t ScalarTypeBitsize = OutScalarTy.getSizeInBits();
22660 
22661   // Only for power of two types to be sure that bitcast works well
22662   if (!isPowerOf2_64(ScalarTypeBitsize))
22663     return SDValue();
22664 
22665   unsigned NumInScalars = N->getNumOperands();
22666 
22667   // Look through bitcasts
22668   auto PeekThroughBitcast = [](SDValue Op) {
22669     if (Op.getOpcode() == ISD::BITCAST)
22670       return Op.getOperand(0);
22671     return Op;
22672   };
22673 
22674   // The source value where all the parts are extracted.
22675   SDValue Src;
22676   for (unsigned i = 0; i != NumInScalars; ++i) {
22677     SDValue In = PeekThroughBitcast(N->getOperand(i));
22678     // Ignore undef inputs.
22679     if (In.isUndef()) continue;
22680 
22681     if (In.getOpcode() != ISD::TRUNCATE)
22682       return SDValue();
22683 
22684     In = PeekThroughBitcast(In.getOperand(0));
22685 
22686     if (In.getOpcode() != ISD::SRL) {
22687       // For now, handle only a build_vector without shuffling; shifts could
22688       // be handled here in the future.
22689       if (i != 0)
22690         return SDValue();
22691 
22692       Src = In;
22693     } else {
22694       // In is SRL
22695       SDValue part = PeekThroughBitcast(In.getOperand(0));
22696 
22697       if (!Src) {
22698         Src = part;
22699       } else if (Src != part) {
22700         // Vector parts do not stem from the same variable
22701         return SDValue();
22702       }
22703 
22704       SDValue ShiftAmtVal = In.getOperand(1);
22705       if (!isa<ConstantSDNode>(ShiftAmtVal))
22706         return SDValue();
22707 
22708       uint64_t ShiftAmt = In.getConstantOperandVal(1);
22709 
22710       // The extracted value is not extracted at the right position
22711       if (ShiftAmt != i * ScalarTypeBitsize)
22712         return SDValue();
22713     }
22714   }
22715 
22716   // Only cast if the size is the same
22717   if (!Src || Src.getValueType().getSizeInBits() != VT.getSizeInBits())
22718     return SDValue();
22719 
22720   return DAG.getBitcast(VT, Src);
22721 }
22722 
22723 SDValue DAGCombiner::createBuildVecShuffle(const SDLoc &DL, SDNode *N,
22724                                            ArrayRef<int> VectorMask,
22725                                            SDValue VecIn1, SDValue VecIn2,
22726                                            unsigned LeftIdx, bool DidSplitVec) {
22727   SDValue ZeroIdx = DAG.getVectorIdxConstant(0, DL);
22728 
22729   EVT VT = N->getValueType(0);
22730   EVT InVT1 = VecIn1.getValueType();
22731   EVT InVT2 = VecIn2.getNode() ? VecIn2.getValueType() : InVT1;
22732 
22733   unsigned NumElems = VT.getVectorNumElements();
22734   unsigned ShuffleNumElems = NumElems;
22735 
22736   // If we artificially split a vector in two already, then the offsets in the
22737   // operands will all be based off of VecIn1, even those in VecIn2.
22738   unsigned Vec2Offset = DidSplitVec ? 0 : InVT1.getVectorNumElements();
22739 
22740   uint64_t VTSize = VT.getFixedSizeInBits();
22741   uint64_t InVT1Size = InVT1.getFixedSizeInBits();
22742   uint64_t InVT2Size = InVT2.getFixedSizeInBits();
22743 
22744   assert(InVT2Size <= InVT1Size &&
22745          "Inputs must be sorted to be in non-increasing vector size order.");
22746 
22747   // We can't generate a shuffle node with mismatched input and output types.
22748   // Try to make the types match the type of the output.
22749   if (InVT1 != VT || InVT2 != VT) {
22750     if ((VTSize % InVT1Size == 0) && InVT1 == InVT2) {
22751       // If the output vector length is a multiple of both input lengths,
22752       // we can concatenate them and pad the rest with undefs.
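            // E.g. (hypothetically) with VT = v8i32 and InVT1 = InVT2 = v2i32:
            //   VecIn1 = concat_vectors v2i32:a, v2i32:b, undef, undef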
22753       unsigned NumConcats = VTSize / InVT1Size;
22754       assert(NumConcats >= 2 && "Concat needs at least two inputs!");
22755       SmallVector<SDValue, 2> ConcatOps(NumConcats, DAG.getUNDEF(InVT1));
22756       ConcatOps[0] = VecIn1;
22757       ConcatOps[1] = VecIn2 ? VecIn2 : DAG.getUNDEF(InVT1);
22758       VecIn1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
22759       VecIn2 = SDValue();
22760     } else if (InVT1Size == VTSize * 2) {
22761       if (!TLI.isExtractSubvectorCheap(VT, InVT1, NumElems))
22762         return SDValue();
22763 
22764       if (!VecIn2.getNode()) {
22765         // If we only have one input vector, and it's twice the size of the
22766         // output, split it in two.
22767         VecIn2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, VecIn1,
22768                              DAG.getVectorIdxConstant(NumElems, DL));
22769         VecIn1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, VecIn1, ZeroIdx);
22770         // Since we now have shorter input vectors, adjust the offset of the
22771         // second vector's start.
22772         Vec2Offset = NumElems;
22773       } else {
22774         assert(InVT2Size <= InVT1Size &&
22775                "Second input is not going to be larger than the first one.");
22776 
22777         // VecIn1 is wider than the output, and we have another, possibly
22778         // smaller input. Pad the smaller input with undefs, shuffle at the
22779         // input vector width, and extract the output.
22780         // The shuffle type is different than VT, so check legality again.
22781         if (LegalOperations &&
22782             !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, InVT1))
22783           return SDValue();
22784 
22785         // Legalizing INSERT_SUBVECTOR is tricky - you basically have to
22786         // lower it back into a BUILD_VECTOR. So if the inserted type is
22787         // illegal, don't even try.
22788         if (InVT1 != InVT2) {
22789           if (!TLI.isTypeLegal(InVT2))
22790             return SDValue();
22791           VecIn2 = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, InVT1,
22792                                DAG.getUNDEF(InVT1), VecIn2, ZeroIdx);
22793         }
22794         ShuffleNumElems = NumElems * 2;
22795       }
22796     } else if (InVT2Size * 2 == VTSize && InVT1Size == VTSize) {
22797       SmallVector<SDValue, 2> ConcatOps(2, DAG.getUNDEF(InVT2));
22798       ConcatOps[0] = VecIn2;
22799       VecIn2 = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
22800     } else if (InVT1Size / VTSize > 1 && InVT1Size % VTSize == 0) {
22801       if (!TLI.isExtractSubvectorCheap(VT, InVT1, NumElems) ||
22802           !TLI.isTypeLegal(InVT1) || !TLI.isTypeLegal(InVT2))
22803         return SDValue();
22804       // If the dest vector has fewer than two elements, then using a shuffle
22805       // and extracting from larger registers will cost even more.
22806       if (VT.getVectorNumElements() <= 2 || !VecIn2.getNode())
22807         return SDValue();
22808       assert(InVT2Size <= InVT1Size &&
22809              "Second input is not going to be larger than the first one.");
22810 
22811       // VecIn1 is wider than the output, and we have another, possibly
22812       // smaller input. Pad the smaller input with undefs, shuffle at the
22813       // input vector width, and extract the output.
22814       // The shuffle type is different than VT, so check legality again.
22815       if (LegalOperations && !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, InVT1))
22816         return SDValue();
22817 
22818       if (InVT1 != InVT2) {
22819         VecIn2 = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, InVT1,
22820                              DAG.getUNDEF(InVT1), VecIn2, ZeroIdx);
22821       }
22822       ShuffleNumElems = InVT1Size / VTSize * NumElems;
22823     } else {
22824       // TODO: Support cases where the length mismatch isn't exactly by a
22825       // factor of 2.
22826       // TODO: Move this check upwards, so that if we have bad type
22827       // mismatches, we don't create any DAG nodes.
22828       return SDValue();
22829     }
22830   }
22831 
22832   // Initialize mask to undef.
22833   SmallVector<int, 8> Mask(ShuffleNumElems, -1);
22834 
22835   // Only need to run up to the number of elements actually used, not the
22836   // total number of elements in the shuffle - if we are shuffling a wider
22837   // vector, the high lanes should be set to undef.
22838   for (unsigned i = 0; i != NumElems; ++i) {
22839     if (VectorMask[i] <= 0)
22840       continue;
22841 
22842     unsigned ExtIndex = N->getOperand(i).getConstantOperandVal(1);
22843     if (VectorMask[i] == (int)LeftIdx) {
22844       Mask[i] = ExtIndex;
22845     } else if (VectorMask[i] == (int)LeftIdx + 1) {
22846       Mask[i] = Vec2Offset + ExtIndex;
22847     }
22848   }
22849 
22850   // The types of the input vectors may have changed above.
22851   InVT1 = VecIn1.getValueType();
22852 
22853   // If we already have a VecIn2, it should have the same type as VecIn1.
22854   // If we don't, get an undef/zero vector of the appropriate type.
22855   VecIn2 = VecIn2.getNode() ? VecIn2 : DAG.getUNDEF(InVT1);
22856   assert(InVT1 == VecIn2.getValueType() && "Unexpected second input type.");
22857 
22858   SDValue Shuffle = DAG.getVectorShuffle(InVT1, DL, VecIn1, VecIn2, Mask);
22859   if (ShuffleNumElems > NumElems)
22860     Shuffle = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Shuffle, ZeroIdx);
22861 
22862   return Shuffle;
22863 }
22864 
22865 static SDValue reduceBuildVecToShuffleWithZero(SDNode *BV, SelectionDAG &DAG) {
22866   assert(BV->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");
22867 
22868   // First, determine where the build vector is not undef.
22869   // TODO: We could extend this to handle zero elements as well as undefs.
22870   int NumBVOps = BV->getNumOperands();
22871   int ZextElt = -1;
22872   for (int i = 0; i != NumBVOps; ++i) {
22873     SDValue Op = BV->getOperand(i);
22874     if (Op.isUndef())
22875       continue;
22876     if (ZextElt == -1)
22877       ZextElt = i;
22878     else
22879       return SDValue();
22880   }
22881   // Bail out if there's no non-undef element.
22882   if (ZextElt == -1)
22883     return SDValue();
22884 
22885   // The build vector contains some number of undef elements and exactly
22886   // one other element. That other element must be a zero-extended scalar
22887   // extracted from a vector at a constant index to turn this into a shuffle.
22888   // Also, require that the build vector does not implicitly truncate/extend
22889   // its elements.
22890   // TODO: This could be enhanced to allow ANY_EXTEND as well as ZERO_EXTEND.
22891   EVT VT = BV->getValueType(0);
22892   SDValue Zext = BV->getOperand(ZextElt);
22893   if (Zext.getOpcode() != ISD::ZERO_EXTEND || !Zext.hasOneUse() ||
22894       Zext.getOperand(0).getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
22895       !isa<ConstantSDNode>(Zext.getOperand(0).getOperand(1)) ||
22896       Zext.getValueSizeInBits() != VT.getScalarSizeInBits())
22897     return SDValue();
22898 
22899   // The zero-extend must be a multiple of the source size, and we must be
22900   // building a vector of the same size as the source of the extract element.
22901   SDValue Extract = Zext.getOperand(0);
22902   unsigned DestSize = Zext.getValueSizeInBits();
22903   unsigned SrcSize = Extract.getValueSizeInBits();
22904   if (DestSize % SrcSize != 0 ||
22905       Extract.getOperand(0).getValueSizeInBits() != VT.getSizeInBits())
22906     return SDValue();
22907 
22908   // Create a shuffle mask that will combine the extracted element with zeros
22909   // and undefs.
22910   int ZextRatio = DestSize / SrcSize;
22911   int NumMaskElts = NumBVOps * ZextRatio;
22912   SmallVector<int, 32> ShufMask(NumMaskElts, -1);
22913   for (int i = 0; i != NumMaskElts; ++i) {
22914     if (i / ZextRatio == ZextElt) {
22915       // The low bits of the (potentially translated) extracted element map to
22916       // the source vector. The high bits map to zero. We will use a zero vector
22917       // as the 2nd source operand of the shuffle, so use the 1st element of
22918       // that vector (mask value is number-of-elements) for the high bits.
22919       int Low = DAG.getDataLayout().isBigEndian() ? (ZextRatio - 1) : 0;
22920       ShufMask[i] = (i % ZextRatio == Low) ? Extract.getConstantOperandVal(1)
22921                                            : NumMaskElts;
22922     }
22923 
22924     // Undef elements of the build vector remain undef because we initialize
22925     // the shuffle mask with -1.
22926   }
22927 
22928   // buildvec undef, ..., (zext (extractelt V, IndexC)), undef... -->
22929   // bitcast (shuffle V, ZeroVec, VectorMask)
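        // E.g. (hypothetical, little-endian; V is a placeholder vector): for
        //   (v2i64 build_vector (i64 zext (i32 extractelt v4i32:V, 1)), undef)
        // ZextRatio is 2 and ShufMask becomes <1, 4, u, u> (4 picks element 0
        // of the zero vector), so the result is
        //   (v2i64 bitcast (shuffle V, zerovec, <1, 4, u, u>))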
22930   SDLoc DL(BV);
22931   EVT VecVT = Extract.getOperand(0).getValueType();
22932   SDValue ZeroVec = DAG.getConstant(0, DL, VecVT);
22933   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
22934   SDValue Shuf = TLI.buildLegalVectorShuffle(VecVT, DL, Extract.getOperand(0),
22935                                              ZeroVec, ShufMask, DAG);
22936   if (!Shuf)
22937     return SDValue();
22938   return DAG.getBitcast(VT, Shuf);
22939 }
22940 
22941 // FIXME: promote to STLExtras.
22942 template <typename R, typename T>
22943 static auto getFirstIndexOf(R &&Range, const T &Val) {
22944   auto I = find(Range, Val);
22945   if (I == Range.end())
22946     return static_cast<decltype(std::distance(Range.begin(), I))>(-1);
22947   return std::distance(Range.begin(), I);
22948 }
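      // A usage sketch (hypothetical values): getFirstIndexOf(Vec, 5) == 1 for
      // Vec = {3, 5, 5}, and it returns -1 (as a signed distance) on a miss.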
22949 
22950 // Check to see if this is a BUILD_VECTOR of a bunch of EXTRACT_VECTOR_ELT
22951 // operations. If the types of the vectors we're extracting from allow it,
22952 // turn this into a vector_shuffle node.
22953 SDValue DAGCombiner::reduceBuildVecToShuffle(SDNode *N) {
22954   SDLoc DL(N);
22955   EVT VT = N->getValueType(0);
22956 
22957   // Only type-legal BUILD_VECTOR nodes are converted to shuffle nodes.
22958   if (!isTypeLegal(VT))
22959     return SDValue();
22960 
22961   if (SDValue V = reduceBuildVecToShuffleWithZero(N, DAG))
22962     return V;
22963 
22964   // May only combine to shuffle after legalize if shuffle is legal.
22965   if (LegalOperations && !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, VT))
22966     return SDValue();
22967 
22968   bool UsesZeroVector = false;
22969   unsigned NumElems = N->getNumOperands();
22970 
22971   // Record, for each element of the newly built vector, which input vector
22972   // that element comes from. -1 stands for undef, 0 for the zero vector,
22973   // and positive values for the input vectors.
22974   // VectorMask maps each element to its vector number, and VecIn maps vector
22975   // numbers to their initial SDValues.
22976 
22977   SmallVector<int, 8> VectorMask(NumElems, -1);
22978   SmallVector<SDValue, 8> VecIn;
22979   VecIn.push_back(SDValue());
22980 
22981   for (unsigned i = 0; i != NumElems; ++i) {
22982     SDValue Op = N->getOperand(i);
22983 
22984     if (Op.isUndef())
22985       continue;
22986 
22987     // See if we can use a blend with a zero vector.
22988     // TODO: Should we generalize this to a blend with an arbitrary constant
22989     // vector?
22990     if (isNullConstant(Op) || isNullFPConstant(Op)) {
22991       UsesZeroVector = true;
22992       VectorMask[i] = 0;
22993       continue;
22994     }
22995 
22996     // Not an undef or zero. If the input is something other than an
22997     // EXTRACT_VECTOR_ELT with an in-range constant index, bail out.
22998     if (Op.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
22999         !isa<ConstantSDNode>(Op.getOperand(1)))
23000       return SDValue();
23001     SDValue ExtractedFromVec = Op.getOperand(0);
23002 
23003     if (ExtractedFromVec.getValueType().isScalableVector())
23004       return SDValue();
23005 
23006     const APInt &ExtractIdx = Op.getConstantOperandAPInt(1);
23007     if (ExtractIdx.uge(ExtractedFromVec.getValueType().getVectorNumElements()))
23008       return SDValue();
23009 
23010     // All inputs must have the same element type as the output.
23011     if (VT.getVectorElementType() !=
23012         ExtractedFromVec.getValueType().getVectorElementType())
23013       return SDValue();
23014 
23015     // Have we seen this input vector before?
23016     // The vectors are expected to be tiny (usually 1 or 2 elements), so using
23017     // a map back from SDValues to numbers isn't worth it.
23018     int Idx = getFirstIndexOf(VecIn, ExtractedFromVec);
23019     if (Idx == -1) { // A new source vector?
23020       Idx = VecIn.size();
23021       VecIn.push_back(ExtractedFromVec);
23022     }
23023 
23024     VectorMask[i] = Idx;
23025   }
23026 
23027   // If we didn't find at least one input vector, bail out.
23028   if (VecIn.size() < 2)
23029     return SDValue();
23030 
23031   // If all of the operands of the BUILD_VECTOR extract from the same
23032   // vector, then split that vector efficiently based on the maximum
23033   // vector access index and adjust the VectorMask and
23034   // VecIn accordingly.
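        // E.g. (hypothetical): two extracts at indices 0 and 9 from v16i32:V
        // with NumElems == 2 give MaxIndex = 9, NearestPow2 = 16 and
        // SplitSize = 8, so V is split into two v8i32 halves and VectorMask is
        // remapped onto them.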
23035   bool DidSplitVec = false;
23036   if (VecIn.size() == 2) {
23037     unsigned MaxIndex = 0;
23038     unsigned NearestPow2 = 0;
23039     SDValue Vec = VecIn.back();
23040     EVT InVT = Vec.getValueType();
23041     SmallVector<unsigned, 8> IndexVec(NumElems, 0);
23042 
23043     for (unsigned i = 0; i < NumElems; i++) {
23044       if (VectorMask[i] <= 0)
23045         continue;
23046       unsigned Index = N->getOperand(i).getConstantOperandVal(1);
23047       IndexVec[i] = Index;
23048       MaxIndex = std::max(MaxIndex, Index);
23049     }
23050 
23051     NearestPow2 = PowerOf2Ceil(MaxIndex);
23052     if (InVT.isSimple() && NearestPow2 > 2 && MaxIndex < NearestPow2 &&
23053         NumElems * 2 < NearestPow2) {
23054       unsigned SplitSize = NearestPow2 / 2;
23055       EVT SplitVT = EVT::getVectorVT(*DAG.getContext(),
23056                                      InVT.getVectorElementType(), SplitSize);
23057       if (TLI.isTypeLegal(SplitVT) &&
23058           SplitSize + SplitVT.getVectorNumElements() <=
23059               InVT.getVectorNumElements()) {
23060         SDValue VecIn2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SplitVT, Vec,
23061                                      DAG.getVectorIdxConstant(SplitSize, DL));
23062         SDValue VecIn1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SplitVT, Vec,
23063                                      DAG.getVectorIdxConstant(0, DL));
23064         VecIn.pop_back();
23065         VecIn.push_back(VecIn1);
23066         VecIn.push_back(VecIn2);
23067         DidSplitVec = true;
23068 
23069         for (unsigned i = 0; i < NumElems; i++) {
23070           if (VectorMask[i] <= 0)
23071             continue;
23072           VectorMask[i] = (IndexVec[i] < SplitSize) ? 1 : 2;
23073         }
23074       }
23075     }
23076   }
23077 
23078   // Sort input vectors by decreasing vector element count,
23079   // while preserving the relative order of equally-sized vectors.
23080   // Note that we keep the first "implicit" zero vector as-is.
23081   SmallVector<SDValue, 8> SortedVecIn(VecIn);
23082   llvm::stable_sort(MutableArrayRef<SDValue>(SortedVecIn).drop_front(),
23083                     [](const SDValue &a, const SDValue &b) {
23084                       return a.getValueType().getVectorNumElements() >
23085                              b.getValueType().getVectorNumElements();
23086                     });
23087 
23088   // We now also need to rebuild the VectorMask, because it referenced element
23089   // order in VecIn, and we just sorted them.
23090   for (int &SourceVectorIndex : VectorMask) {
23091     if (SourceVectorIndex <= 0)
23092       continue;
23093     unsigned Idx = getFirstIndexOf(SortedVecIn, VecIn[SourceVectorIndex]);
23094     assert(Idx > 0 && Idx < SortedVecIn.size() &&
23095            VecIn[SourceVectorIndex] == SortedVecIn[Idx] && "Remapping failure");
23096     SourceVectorIndex = Idx;
23097   }
23098 
23099   VecIn = std::move(SortedVecIn);
23100 
23101   // TODO: Should this fire if some of the input vectors have an illegal type
23102   // (like it does now), or should we let legalization run its course first?
23103 
23104   // Shuffle phase:
23105   // Take pairs of vectors, and shuffle them so that the result has elements
23106   // from these vectors in the correct places.
23107   // For example, given:
23108   // t10: i32 = extract_vector_elt t1, Constant:i64<0>
23109   // t11: i32 = extract_vector_elt t2, Constant:i64<0>
23110   // t12: i32 = extract_vector_elt t3, Constant:i64<0>
23111   // t13: i32 = extract_vector_elt t1, Constant:i64<1>
23112   // t14: v4i32 = BUILD_VECTOR t10, t11, t12, t13
23113   // We will generate:
23114   // t20: v4i32 = vector_shuffle<0,4,u,1> t1, t2
23115   // t21: v4i32 = vector_shuffle<u,u,0,u> t3, undef
23116   SmallVector<SDValue, 4> Shuffles;
23117   for (unsigned In = 0, Len = (VecIn.size() / 2); In < Len; ++In) {
23118     unsigned LeftIdx = 2 * In + 1;
23119     SDValue VecLeft = VecIn[LeftIdx];
23120     SDValue VecRight =
23121         (LeftIdx + 1) < VecIn.size() ? VecIn[LeftIdx + 1] : SDValue();
23122 
23123     if (SDValue Shuffle = createBuildVecShuffle(DL, N, VectorMask, VecLeft,
23124                                                 VecRight, LeftIdx, DidSplitVec))
23125       Shuffles.push_back(Shuffle);
23126     else
23127       return SDValue();
23128   }
23129 
23130   // If we need the zero vector as an "ingredient" in the blend tree, add it
23131   // to the list of shuffles.
23132   if (UsesZeroVector)
23133     Shuffles.push_back(VT.isInteger() ? DAG.getConstant(0, DL, VT)
23134                                       : DAG.getConstantFP(0.0, DL, VT));
23135 
23136   // If we only have one shuffle, we're done.
23137   if (Shuffles.size() == 1)
23138     return Shuffles[0];
23139 
23140   // Update the vector mask to point to the post-shuffle vectors.
23141   for (int &Vec : VectorMask)
23142     if (Vec == 0)
23143       Vec = Shuffles.size() - 1;
23144     else
23145       Vec = (Vec - 1) / 2;
23146 
23147   // More than one shuffle. Generate a binary tree of blends, e.g. if from
23148   // the previous step we got the set of shuffles t10, t11, t12, t13, we will
23149   // generate:
23150   // t10: v8i32 = vector_shuffle<0,8,u,u,u,u,u,u> t1, t2
23151   // t11: v8i32 = vector_shuffle<u,u,0,8,u,u,u,u> t3, t4
23152   // t12: v8i32 = vector_shuffle<u,u,u,u,0,8,u,u> t5, t6
23153   // t13: v8i32 = vector_shuffle<u,u,u,u,u,u,0,8> t7, t8
23154   // t20: v8i32 = vector_shuffle<0,1,10,11,u,u,u,u> t10, t11
23155   // t21: v8i32 = vector_shuffle<u,u,u,u,4,5,14,15> t12, t13
23156   // t30: v8i32 = vector_shuffle<0,1,2,3,12,13,14,15> t20, t21
23157 
23158   // Make sure the initial size of the shuffle list is even.
23159   if (Shuffles.size() % 2)
23160     Shuffles.push_back(DAG.getUNDEF(VT));
23161 
23162   for (unsigned CurSize = Shuffles.size(); CurSize > 1; CurSize /= 2) {
23163     if (CurSize % 2) {
23164       Shuffles[CurSize] = DAG.getUNDEF(VT);
23165       CurSize++;
23166     }
23167     for (unsigned In = 0, Len = CurSize / 2; In < Len; ++In) {
23168       int Left = 2 * In;
23169       int Right = 2 * In + 1;
23170       SmallVector<int, 8> Mask(NumElems, -1);
23171       SDValue L = Shuffles[Left];
23172       ArrayRef<int> LMask;
23173       bool IsLeftShuffle = L.getOpcode() == ISD::VECTOR_SHUFFLE &&
23174                            L.use_empty() && L.getOperand(1).isUndef() &&
23175                            L.getOperand(0).getValueType() == L.getValueType();
23176       if (IsLeftShuffle) {
23177         LMask = cast<ShuffleVectorSDNode>(L.getNode())->getMask();
23178         L = L.getOperand(0);
23179       }
23180       SDValue R = Shuffles[Right];
23181       ArrayRef<int> RMask;
23182       bool IsRightShuffle = R.getOpcode() == ISD::VECTOR_SHUFFLE &&
23183                             R.use_empty() && R.getOperand(1).isUndef() &&
23184                             R.getOperand(0).getValueType() == R.getValueType();
23185       if (IsRightShuffle) {
23186         RMask = cast<ShuffleVectorSDNode>(R.getNode())->getMask();
23187         R = R.getOperand(0);
23188       }
23189       for (unsigned I = 0; I != NumElems; ++I) {
23190         if (VectorMask[I] == Left) {
23191           Mask[I] = I;
23192           if (IsLeftShuffle)
23193             Mask[I] = LMask[I];
23194           VectorMask[I] = In;
23195         } else if (VectorMask[I] == Right) {
23196           Mask[I] = I + NumElems;
23197           if (IsRightShuffle)
23198             Mask[I] = RMask[I] + NumElems;
23199           VectorMask[I] = In;
23200         }
23201       }
23202 
23203       Shuffles[In] = DAG.getVectorShuffle(VT, DL, L, R, Mask);
23204     }
23205   }
23206   return Shuffles[0];
23207 }
23208 
23209 // Try to turn a build vector of zero extends of extract vector elts into a
23210 // vector zero extend and possibly an extract subvector.
23211 // TODO: Support sign extend?
23212 // TODO: Allow undef elements?
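      // For example (hypothetical; X is a placeholder vector):
      //   (v2i32 build_vector (zext (i16 extractelt v8i16:X, 4)),
      //                       (zext (i16 extractelt v8i16:X, 5)))
      //     --> (v2i32 zero_extend (v2i16 extract_subvector X, 4))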
23213 SDValue DAGCombiner::convertBuildVecZextToZext(SDNode *N) {
23214   if (LegalOperations)
23215     return SDValue();
23216 
23217   EVT VT = N->getValueType(0);
23218 
23219   bool FoundZeroExtend = false;
23220   SDValue Op0 = N->getOperand(0);
23221   auto checkElem = [&](SDValue Op) -> int64_t {
23222     unsigned Opc = Op.getOpcode();
23223     FoundZeroExtend |= (Opc == ISD::ZERO_EXTEND);
23224     if ((Opc == ISD::ZERO_EXTEND || Opc == ISD::ANY_EXTEND) &&
23225         Op.getOperand(0).getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
23226         Op0.getOperand(0).getOperand(0) == Op.getOperand(0).getOperand(0))
23227       if (auto *C = dyn_cast<ConstantSDNode>(Op.getOperand(0).getOperand(1)))
23228         return C->getZExtValue();
23229     return -1;
23230   };
23231 
23232   // Make sure the first element matches
23233   // (zext (extract_vector_elt X, C))
23234   // Offset must be a constant multiple of the
23235   // known-minimum vector length of the result type.
23236   int64_t Offset = checkElem(Op0);
23237   if (Offset < 0 || (Offset % VT.getVectorNumElements()) != 0)
23238     return SDValue();
23239 
23240   unsigned NumElems = N->getNumOperands();
23241   SDValue In = Op0.getOperand(0).getOperand(0);
23242   EVT InSVT = In.getValueType().getScalarType();
23243   EVT InVT = EVT::getVectorVT(*DAG.getContext(), InSVT, NumElems);
23244 
23245   // Don't create an illegal input type after type legalization.
23246   if (LegalTypes && !TLI.isTypeLegal(InVT))
23247     return SDValue();
23248 
23249   // Ensure all the elements come from the same vector and are adjacent.
23250   for (unsigned i = 1; i != NumElems; ++i) {
23251     if ((Offset + i) != checkElem(N->getOperand(i)))
23252       return SDValue();
23253   }
23254 
23255   SDLoc DL(N);
23256   In = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InVT, In,
23257                    Op0.getOperand(0).getOperand(1));
23258   return DAG.getNode(FoundZeroExtend ? ISD::ZERO_EXTEND : ISD::ANY_EXTEND, DL,
23259                      VT, In);
23260 }
23261 
23262 // If this is a very simple BUILD_VECTOR with its first element being a
23263 // ZERO_EXTEND and all other elements being constant zeros, granularize the
23264 // BUILD_VECTOR's element width, absorbing the ZERO_EXTEND, turning it into a
23265 // constant zero op. This pattern can appear during legalization.
23266 //
23267 // NOTE: This can be generalized to allow more than a single
23268 //       non-constant-zero op, UNDEF's, and to be KnownBits-based.
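      // For example (hypothetical, little-endian; x is a placeholder scalar):
      //   (v2i64 build_vector (i64 zero_extend i32:x), (i64 0))
      //     --> (v2i64 bitcast (v4i32 build_vector i32:x, 0, 0, 0))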
23269 SDValue DAGCombiner::convertBuildVecZextToBuildVecWithZeros(SDNode *N) {
23270   // Don't run this after legalization. Targets may have other preferences.
23271   if (Level >= AfterLegalizeDAG)
23272     return SDValue();
23273 
23274   // FIXME: support big-endian.
23275   if (DAG.getDataLayout().isBigEndian())
23276     return SDValue();
23277 
23278   EVT VT = N->getValueType(0);
23279   EVT OpVT = N->getOperand(0).getValueType();
23280   assert(!VT.isScalableVector() && "Encountered scalable BUILD_VECTOR?");
23281 
23282   EVT OpIntVT = EVT::getIntegerVT(*DAG.getContext(), OpVT.getSizeInBits());
23283 
23284   if (!TLI.isTypeLegal(OpIntVT) ||
23285       (LegalOperations && !TLI.isOperationLegalOrCustom(ISD::BITCAST, OpIntVT)))
23286     return SDValue();
23287 
23288   unsigned EltBitwidth = VT.getScalarSizeInBits();
23289   // NOTE: the actual width of operands may be wider than that!
23290 
23291   // Analyze all operands of this BUILD_VECTOR. What is the largest number of
23292   // active bits they all have? We'll want to truncate them all to that width.
23293   unsigned ActiveBits = 0;
23294   APInt KnownZeroOps(VT.getVectorNumElements(), 0);
23295   for (auto I : enumerate(N->ops())) {
23296     SDValue Op = I.value();
23297     // FIXME: support UNDEF elements?
23298     if (auto *Cst = dyn_cast<ConstantSDNode>(Op)) {
23299       unsigned OpActiveBits =
23300           Cst->getAPIntValue().trunc(EltBitwidth).getActiveBits();
23301       if (OpActiveBits == 0) {
23302         KnownZeroOps.setBit(I.index());
23303         continue;
23304       }
23305       // Profitability check: don't allow non-zero constant operands.
23306       return SDValue();
23307     }
23308     // Profitability check: there must only be a single non-zero operand,
23309     // and it must be the first operand of the BUILD_VECTOR.
23310     if (I.index() != 0)
23311       return SDValue();
23312     // The operand must be a zero-extension itself.
23313     // FIXME: this could be generalized to known leading zeros check.
23314     if (Op.getOpcode() != ISD::ZERO_EXTEND)
23315       return SDValue();
23316     unsigned CurrActiveBits =
23317         Op.getOperand(0).getValueSizeInBits().getFixedValue();
23318     assert(!ActiveBits && "Already encountered non-constant-zero operand?");
23319     ActiveBits = CurrActiveBits;
23320     // We want to at least halve the element size.
23321     if (2 * ActiveBits > EltBitwidth)
23322       return SDValue();
23323   }
23324 
23325   // This BUILD_VECTOR must have at least one non-constant-zero operand.
23326   if (ActiveBits == 0)
23327     return SDValue();
23328 
23329   // We have EltBitwidth bits, the *minimal* chunk size is ActiveBits,
23330   // into how many chunks can we split our element width?
23331   EVT NewScalarIntVT, NewIntVT;
23332   std::optional<unsigned> Factor;
23333   // We can split the element into at least two chunks, but not into more
23334   // than |_ EltBitwidth / ActiveBits _| chunks. Find the largest split factor
23335   // for which the element width is a multiple of it,
23336   // and the resulting types/operations on that chunk width are legal.
23337   assert(2 * ActiveBits <= EltBitwidth &&
23338          "We know that half or less bits of the element are active.");
23339   for (unsigned Scale = EltBitwidth / ActiveBits; Scale >= 2; --Scale) {
23340     if (EltBitwidth % Scale != 0)
23341       continue;
23342     unsigned ChunkBitwidth = EltBitwidth / Scale;
23343     assert(ChunkBitwidth >= ActiveBits && "As per starting point.");
23344     NewScalarIntVT = EVT::getIntegerVT(*DAG.getContext(), ChunkBitwidth);
23345     NewIntVT = EVT::getVectorVT(*DAG.getContext(), NewScalarIntVT,
23346                                 Scale * N->getNumOperands());
23347     if (!TLI.isTypeLegal(NewScalarIntVT) || !TLI.isTypeLegal(NewIntVT) ||
23348         (LegalOperations &&
23349          !(TLI.isOperationLegalOrCustom(ISD::TRUNCATE, NewScalarIntVT) &&
23350            TLI.isOperationLegalOrCustom(ISD::BUILD_VECTOR, NewIntVT))))
23351       continue;
23352     Factor = Scale;
23353     break;
23354   }
23355   if (!Factor)
23356     return SDValue();
23357 
23358   SDLoc DL(N);
23359   SDValue ZeroOp = DAG.getConstant(0, DL, NewScalarIntVT);
23360 
23361   // Recreate the BUILD_VECTOR, with elements now being Factor times smaller.
23362   SmallVector<SDValue, 16> NewOps;
23363   NewOps.reserve(NewIntVT.getVectorNumElements());
23364   for (auto I : enumerate(N->ops())) {
23365     SDValue Op = I.value();
23366     assert(!Op.isUndef() && "FIXME: after allowing UNDEF's, handle them here.");
23367     unsigned SrcOpIdx = I.index();
23368     if (KnownZeroOps[SrcOpIdx]) {
23369       NewOps.append(*Factor, ZeroOp);
23370       continue;
23371     }
23372     Op = DAG.getBitcast(OpIntVT, Op);
23373     Op = DAG.getNode(ISD::TRUNCATE, DL, NewScalarIntVT, Op);
23374     NewOps.emplace_back(Op);
23375     NewOps.append(*Factor - 1, ZeroOp);
23376   }
23377   assert(NewOps.size() == NewIntVT.getVectorNumElements());
23378   SDValue NewBV = DAG.getBuildVector(NewIntVT, DL, NewOps);
23379   NewBV = DAG.getBitcast(VT, NewBV);
23380   return NewBV;
23381 }
23382 
23383 SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) {
23384   EVT VT = N->getValueType(0);
23385 
23386   // A vector built entirely of undefs is undef.
23387   if (ISD::allOperandsUndef(N))
23388     return DAG.getUNDEF(VT);
23389 
23390   // If this is a splat of a bitcast from another vector, change to a
23391   // concat_vector.
23392   // For example:
23393   //   (build_vector (i64 (bitcast (v2i32 X))), (i64 (bitcast (v2i32 X)))) ->
23394   //     (v2i64 (bitcast (concat_vectors (v2i32 X), (v2i32 X))))
23395   //
23396   // If X is a build_vector itself, the concat can become a larger build_vector.
23397   // TODO: Maybe this is useful for non-splat too?
23398   if (!LegalOperations) {
23399     if (SDValue Splat = cast<BuildVectorSDNode>(N)->getSplatValue()) {
23400       Splat = peekThroughBitcasts(Splat);
23401       EVT SrcVT = Splat.getValueType();
23402       if (SrcVT.isVector()) {
23403         unsigned NumElts = N->getNumOperands() * SrcVT.getVectorNumElements();
23404         EVT NewVT = EVT::getVectorVT(*DAG.getContext(),
23405                                      SrcVT.getVectorElementType(), NumElts);
23406         if (!LegalTypes || TLI.isTypeLegal(NewVT)) {
23407           SmallVector<SDValue, 8> Ops(N->getNumOperands(), Splat);
23408           SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N),
23409                                        NewVT, Ops);
23410           return DAG.getBitcast(VT, Concat);
23411         }
23412       }
23413     }
23414   }
23415 
23416   // Check if we can express the BUILD_VECTOR via a subvector extract.
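        // For example (hypothetical; X is a placeholder vector):
        //   (v2i32 build_vector (extractelt v4i32:X, 2), (extractelt v4i32:X, 3))
        //     --> (v2i32 extract_subvector X, 2)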
23417   if (!LegalTypes && (N->getNumOperands() > 1)) {
23418     SDValue Op0 = N->getOperand(0);
23419     auto checkElem = [&](SDValue Op) -> uint64_t {
23420       if ((Op.getOpcode() == ISD::EXTRACT_VECTOR_ELT) &&
23421           (Op0.getOperand(0) == Op.getOperand(0)))
23422         if (auto CNode = dyn_cast<ConstantSDNode>(Op.getOperand(1)))
23423           return CNode->getZExtValue();
23424       return -1;
23425     };
23426 
23427     int Offset = checkElem(Op0);
23428     for (unsigned i = 0; i < N->getNumOperands(); ++i) {
23429       if (Offset + i != checkElem(N->getOperand(i))) {
23430         Offset = -1;
23431         break;
23432       }
23433     }
23434 
23435     if ((Offset == 0) &&
23436         (Op0.getOperand(0).getValueType() == N->getValueType(0)))
23437       return Op0.getOperand(0);
23438     if ((Offset != -1) &&
23439         ((Offset % N->getValueType(0).getVectorNumElements()) ==
23440          0)) // IDX must be multiple of output size.
23441       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), N->getValueType(0),
23442                          Op0.getOperand(0), Op0.getOperand(1));
23443   }
23444 
23445   if (SDValue V = convertBuildVecZextToZext(N))
23446     return V;
23447 
23448   if (SDValue V = convertBuildVecZextToBuildVecWithZeros(N))
23449     return V;
23450 
23451   if (SDValue V = reduceBuildVecExtToExtBuildVec(N))
23452     return V;
23453 
23454   if (SDValue V = reduceBuildVecTruncToBitCast(N))
23455     return V;
23456 
23457   if (SDValue V = reduceBuildVecToShuffle(N))
23458     return V;
23459 
23460   // A splat of a single element is a SPLAT_VECTOR if supported on the target.
23461   // Do this late as some of the above may replace the splat.
23462   if (TLI.getOperationAction(ISD::SPLAT_VECTOR, VT) != TargetLowering::Expand)
23463     if (SDValue V = cast<BuildVectorSDNode>(N)->getSplatValue()) {
23464       assert(!V.isUndef() && "Splat of undef should have been handled earlier");
23465       return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), VT, V);
23466     }
23467 
23468   return SDValue();
23469 }
23470 
23471 static SDValue combineConcatVectorOfScalars(SDNode *N, SelectionDAG &DAG) {
23472   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
23473   EVT OpVT = N->getOperand(0).getValueType();
23474 
23475   // If the operands are legal vectors, leave them alone.
23476   if (TLI.isTypeLegal(OpVT) || OpVT.isScalableVector())
23477     return SDValue();
23478 
23479   SDLoc DL(N);
23480   EVT VT = N->getValueType(0);
23481   SmallVector<SDValue, 8> Ops;
23482 
23483   EVT SVT = EVT::getIntegerVT(*DAG.getContext(), OpVT.getSizeInBits());
23484   SDValue ScalarUndef = DAG.getNode(ISD::UNDEF, DL, SVT);
23485 
23486   // Keep track of what we encounter.
23487   bool AnyInteger = false;
23488   bool AnyFP = false;
23489   for (const SDValue &Op : N->ops()) {
23490     if (ISD::BITCAST == Op.getOpcode() &&
23491         !Op.getOperand(0).getValueType().isVector())
23492       Ops.push_back(Op.getOperand(0));
23493     else if (ISD::UNDEF == Op.getOpcode())
23494       Ops.push_back(ScalarUndef);
23495     else
23496       return SDValue();
23497 
23498     // Note whether we encounter an integer or floating point scalar.
23499     // If it's neither, bail out, it could be something weird like x86mmx.
23500     EVT LastOpVT = Ops.back().getValueType();
23501     if (LastOpVT.isFloatingPoint())
23502       AnyFP = true;
23503     else if (LastOpVT.isInteger())
23504       AnyInteger = true;
23505     else
23506       return SDValue();
23507   }
23508 
23509   // If any of the operands is a floating point scalar bitcast to a vector,
23510   // use floating point types throughout, and bitcast everything.
23511   // Replace UNDEFs by another scalar UNDEF node, of the final desired type.
23512   if (AnyFP) {
23513     SVT = EVT::getFloatingPointVT(OpVT.getSizeInBits());
23514     ScalarUndef = DAG.getNode(ISD::UNDEF, DL, SVT);
23515     if (AnyInteger) {
23516       for (SDValue &Op : Ops) {
23517         if (Op.getValueType() == SVT)
23518           continue;
23519         if (Op.isUndef())
23520           Op = ScalarUndef;
23521         else
23522           Op = DAG.getBitcast(SVT, Op);
23523       }
23524     }
23525   }
23526 
23527   EVT VecVT = EVT::getVectorVT(*DAG.getContext(), SVT,
23528                                VT.getSizeInBits() / SVT.getSizeInBits());
23529   return DAG.getBitcast(VT, DAG.getBuildVector(VecVT, DL, Ops));
23530 }
23531 
23532 // Attempt to merge nested concat_vectors/undefs.
23533 // Fold concat_vectors(concat_vectors(x,y,z,w),u,u,concat_vectors(a,b,c,d))
23534 //  --> concat_vectors(x,y,z,w,u,u,u,u,u,u,u,u,a,b,c,d)
23535 static SDValue combineConcatVectorOfConcatVectors(SDNode *N,
23536                                                   SelectionDAG &DAG) {
23537   EVT VT = N->getValueType(0);
23538 
23539   // Ensure we're concatenating UNDEF and CONCAT_VECTORS nodes of similar types.
23540   EVT SubVT;
23541   SDValue FirstConcat;
23542   for (const SDValue &Op : N->ops()) {
23543     if (Op.isUndef())
23544       continue;
23545     if (Op.getOpcode() != ISD::CONCAT_VECTORS)
23546       return SDValue();
23547     if (!FirstConcat) {
23548       SubVT = Op.getOperand(0).getValueType();
23549       if (!DAG.getTargetLoweringInfo().isTypeLegal(SubVT))
23550         return SDValue();
23551       FirstConcat = Op;
23552       continue;
23553     }
23554     if (SubVT != Op.getOperand(0).getValueType())
23555       return SDValue();
23556   }
23557   assert(FirstConcat && "Concat of all-undefs found");
23558 
23559   SmallVector<SDValue> ConcatOps;
23560   for (const SDValue &Op : N->ops()) {
23561     if (Op.isUndef()) {
23562       ConcatOps.append(FirstConcat->getNumOperands(), DAG.getUNDEF(SubVT));
23563       continue;
23564     }
23565     ConcatOps.append(Op->op_begin(), Op->op_end());
23566   }
23567   return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, ConcatOps);
23568 }
23569 
23570 // Check to see if this is a CONCAT_VECTORS of a bunch of EXTRACT_SUBVECTOR
23571 // operations. If so, and if the EXTRACT_SUBVECTOR vector inputs come from at
23572 // most two distinct vectors the same size as the result, attempt to turn this
23573 // into a legal shuffle.
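      // For example (hypothetical; A and B are placeholder vectors):
      //   (v4i32 concat_vectors (extract_subvector v4i32:A, 0),
      //                         (extract_subvector v4i32:B, 2))
      //     --> (v4i32 vector_shuffle<0,1,6,7> A, B)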
23574 static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) {
23575   EVT VT = N->getValueType(0);
23576   EVT OpVT = N->getOperand(0).getValueType();
23577 
23578   // We currently can't generate an appropriate shuffle for a scalable vector.
23579   if (VT.isScalableVector())
23580     return SDValue();
23581 
23582   int NumElts = VT.getVectorNumElements();
23583   int NumOpElts = OpVT.getVectorNumElements();
23584 
23585   SDValue SV0 = DAG.getUNDEF(VT), SV1 = DAG.getUNDEF(VT);
23586   SmallVector<int, 8> Mask;
23587 
23588   for (SDValue Op : N->ops()) {
23589     Op = peekThroughBitcasts(Op);
23590 
23591     // UNDEF nodes convert to UNDEF shuffle mask values.
23592     if (Op.isUndef()) {
23593       Mask.append((unsigned)NumOpElts, -1);
23594       continue;
23595     }
23596 
23597     if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
23598       return SDValue();
23599 
23600     // What vector are we extracting the subvector from and at what index?
23601     SDValue ExtVec = Op.getOperand(0);
23602     int ExtIdx = Op.getConstantOperandVal(1);
23603 
23604     // We want the EVT of the original extraction to correctly scale the
23605     // extraction index.
23606     EVT ExtVT = ExtVec.getValueType();
23607     ExtVec = peekThroughBitcasts(ExtVec);
23608 
23609     // UNDEF nodes convert to UNDEF shuffle mask values.
23610     if (ExtVec.isUndef()) {
23611       Mask.append((unsigned)NumOpElts, -1);
23612       continue;
23613     }
23614 
23615     // Ensure that we are extracting a subvector from a vector the same
23616     // size as the result.
23617     if (ExtVT.getSizeInBits() != VT.getSizeInBits())
23618       return SDValue();
23619 
23620     // Scale the subvector index to account for any bitcast.
23621     int NumExtElts = ExtVT.getVectorNumElements();
23622     if (0 == (NumExtElts % NumElts))
23623       ExtIdx /= (NumExtElts / NumElts);
23624     else if (0 == (NumElts % NumExtElts))
23625       ExtIdx *= (NumElts / NumExtElts);
23626     else
23627       return SDValue();
23628 
23629     // At most we can reference 2 inputs in the final shuffle.
23630     if (SV0.isUndef() || SV0 == ExtVec) {
23631       SV0 = ExtVec;
23632       for (int i = 0; i != NumOpElts; ++i)
23633         Mask.push_back(i + ExtIdx);
23634     } else if (SV1.isUndef() || SV1 == ExtVec) {
23635       SV1 = ExtVec;
23636       for (int i = 0; i != NumOpElts; ++i)
23637         Mask.push_back(i + ExtIdx + NumElts);
23638     } else {
23639       return SDValue();
23640     }
23641   }
23642 
23643   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
23644   return TLI.buildLegalVectorShuffle(VT, SDLoc(N), DAG.getBitcast(VT, SV0),
23645                                      DAG.getBitcast(VT, SV1), Mask, DAG);
23646 }
23647 
23648 static SDValue combineConcatVectorOfCasts(SDNode *N, SelectionDAG &DAG) {
23649   unsigned CastOpcode = N->getOperand(0).getOpcode();
23650   switch (CastOpcode) {
23651   case ISD::SINT_TO_FP:
23652   case ISD::UINT_TO_FP:
23653   case ISD::FP_TO_SINT:
23654   case ISD::FP_TO_UINT:
23655     // TODO: Allow more opcodes?
23656     //  case ISD::BITCAST:
23657     //  case ISD::TRUNCATE:
23658     //  case ISD::ZERO_EXTEND:
23659     //  case ISD::SIGN_EXTEND:
23660     //  case ISD::FP_EXTEND:
23661     break;
23662   default:
23663     return SDValue();
23664   }
23665 
23666   EVT SrcVT = N->getOperand(0).getOperand(0).getValueType();
23667   if (!SrcVT.isVector())
23668     return SDValue();
23669 
23670   // All operands of the concat must be the same kind of cast from the same
23671   // source type.
23672   SmallVector<SDValue, 4> SrcOps;
23673   for (SDValue Op : N->ops()) {
23674     if (Op.getOpcode() != CastOpcode || !Op.hasOneUse() ||
23675         Op.getOperand(0).getValueType() != SrcVT)
23676       return SDValue();
23677     SrcOps.push_back(Op.getOperand(0));
23678   }
23679 
23680   // The wider cast must be supported by the target. This is unusual because
23681   // the operation support type parameter depends on the opcode. In addition,
23682   // check the other type in the cast to make sure this is really legal.
23683   EVT VT = N->getValueType(0);
23684   EVT SrcEltVT = SrcVT.getVectorElementType();
23685   ElementCount NumElts = SrcVT.getVectorElementCount() * N->getNumOperands();
23686   EVT ConcatSrcVT = EVT::getVectorVT(*DAG.getContext(), SrcEltVT, NumElts);
23687   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
23688   switch (CastOpcode) {
23689   case ISD::SINT_TO_FP:
23690   case ISD::UINT_TO_FP:
23691     if (!TLI.isOperationLegalOrCustom(CastOpcode, ConcatSrcVT) ||
23692         !TLI.isTypeLegal(VT))
23693       return SDValue();
23694     break;
23695   case ISD::FP_TO_SINT:
23696   case ISD::FP_TO_UINT:
23697     if (!TLI.isOperationLegalOrCustom(CastOpcode, VT) ||
23698         !TLI.isTypeLegal(ConcatSrcVT))
23699       return SDValue();
23700     break;
23701   default:
23702     llvm_unreachable("Unexpected cast opcode");
23703   }
23704 
23705   // concat (cast X), (cast Y)... -> cast (concat X, Y...)
23706   SDLoc DL(N);
23707   SDValue NewConcat = DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatSrcVT, SrcOps);
23708   return DAG.getNode(CastOpcode, DL, VT, NewConcat);
23709 }
23710 
23711 // See if this is a simple CONCAT_VECTORS with no UNDEF operands, and if one of
23712 // the operands is a SHUFFLE_VECTOR, and all other operands are also operands
23713 // to that SHUFFLE_VECTOR, create a wider SHUFFLE_VECTOR.
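      // For example (hypothetical; A is a placeholder vector):
      //   (v8i32 concat_vectors (v4i32 shuffle<1,0,3,2> A, undef), (v4i32 A))
      //     --> (v8i32 shuffle<1,0,3,2,0,1,2,3> (concat_vectors A, undef),
      //          undef)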
23714 static SDValue combineConcatVectorOfShuffleAndItsOperands(
23715     SDNode *N, SelectionDAG &DAG, const TargetLowering &TLI, bool LegalTypes,
23716     bool LegalOperations) {
23717   EVT VT = N->getValueType(0);
23718   EVT OpVT = N->getOperand(0).getValueType();
23719   if (VT.isScalableVector())
23720     return SDValue();
23721 
23722   // For now, only allow simple 2-operand concatenations.
23723   if (N->getNumOperands() != 2)
23724     return SDValue();
23725 
23726   // Don't create illegal types/shuffles when not allowed to.
23727   if ((LegalTypes && !TLI.isTypeLegal(VT)) ||
23728       (LegalOperations &&
23729        !TLI.isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, VT)))
23730     return SDValue();
23731 
23732   // Analyze all of the operands of the CONCAT_VECTORS. Out of all of them,
23733   // we want to find one that is: (1) a SHUFFLE_VECTOR (2) only used by us,
23734   // and (3) all operands of CONCAT_VECTORS must be either that SHUFFLE_VECTOR,
23735   // or one of the operands of that SHUFFLE_VECTOR (but not UNDEF!).
23736   // (4) and for now, the SHUFFLE_VECTOR must be unary.
23737   ShuffleVectorSDNode *SVN = nullptr;
23738   for (SDValue Op : N->ops()) {
23739     if (auto *CurSVN = dyn_cast<ShuffleVectorSDNode>(Op);
23740         CurSVN && CurSVN->getOperand(1).isUndef() && N->isOnlyUserOf(CurSVN) &&
23741         all_of(N->ops(), [CurSVN](SDValue Op) {
23742           // FIXME: can we allow UNDEF operands?
23743           return !Op.isUndef() &&
23744                  (Op.getNode() == CurSVN || is_contained(CurSVN->ops(), Op));
23745         })) {
23746       SVN = CurSVN;
23747       break;
23748     }
23749   }
23750   if (!SVN)
23751     return SDValue();
23752 
23753   // We are going to pad the shuffle operands, so any index that was picking
23754   // from the second operand must be adjusted.
23755   SmallVector<int, 16> AdjustedMask;
23756   AdjustedMask.reserve(SVN->getMask().size());
23757   assert(SVN->getOperand(1).isUndef() && "Expected unary shuffle!");
23758   append_range(AdjustedMask, SVN->getMask());
23759 
23760   // Identity masks for the operands of the (padded) shuffle.
23761   SmallVector<int, 32> IdentityMask(2 * OpVT.getVectorNumElements());
23762   MutableArrayRef<int> FirstShufOpIdentityMask =
23763       MutableArrayRef<int>(IdentityMask)
23764           .take_front(OpVT.getVectorNumElements());
23765   MutableArrayRef<int> SecondShufOpIdentityMask =
23766       MutableArrayRef<int>(IdentityMask).take_back(OpVT.getVectorNumElements());
23767   std::iota(FirstShufOpIdentityMask.begin(), FirstShufOpIdentityMask.end(), 0);
23768   std::iota(SecondShufOpIdentityMask.begin(), SecondShufOpIdentityMask.end(),
23769             VT.getVectorNumElements());
23770 
23771   // New combined shuffle mask.
23772   SmallVector<int, 32> Mask;
23773   Mask.reserve(VT.getVectorNumElements());
23774   for (SDValue Op : N->ops()) {
23775     assert(!Op.isUndef() && "Not expecting to concatenate UNDEF.");
23776     if (Op.getNode() == SVN) {
23777       append_range(Mask, AdjustedMask);
23778       continue;
23779     }
23780     if (Op == SVN->getOperand(0)) {
23781       append_range(Mask, FirstShufOpIdentityMask);
23782       continue;
23783     }
23784     if (Op == SVN->getOperand(1)) {
23785       append_range(Mask, SecondShufOpIdentityMask);
23786       continue;
23787     }
23788     llvm_unreachable("Unexpected operand!");
23789   }
23790 
23791   // Don't create illegal shuffle masks.
23792   if (!TLI.isShuffleMaskLegal(Mask, VT))
23793     return SDValue();
23794 
23795   // Pad the shuffle operands with UNDEF.
23796   SDLoc dl(N);
23797   std::array<SDValue, 2> ShufOps;
23798   for (auto I : zip(SVN->ops(), ShufOps)) {
23799     SDValue ShufOp = std::get<0>(I);
23800     SDValue &NewShufOp = std::get<1>(I);
23801     if (ShufOp.isUndef())
23802       NewShufOp = DAG.getUNDEF(VT);
23803     else {
23804       SmallVector<SDValue, 2> ShufOpParts(N->getNumOperands(),
23805                                           DAG.getUNDEF(OpVT));
23806       ShufOpParts[0] = ShufOp;
23807       NewShufOp = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, ShufOpParts);
23808     }
23809   }
23810   // Finally, create the new wide shuffle.
23811   return DAG.getVectorShuffle(VT, dl, ShufOps[0], ShufOps[1], Mask);
23812 }
23813 
23814 SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
23815   // If we only have one input vector, we don't need to do any concatenation.
23816   if (N->getNumOperands() == 1)
23817     return N->getOperand(0);
23818 
23819   // Check if all of the operands are undefs.
23820   EVT VT = N->getValueType(0);
23821   if (ISD::allOperandsUndef(N))
23822     return DAG.getUNDEF(VT);
23823 
23824   // Optimize concat_vectors where all but the first of the vectors are undef.
23825   if (all_of(drop_begin(N->ops()),
23826              [](const SDValue &Op) { return Op.isUndef(); })) {
23827     SDValue In = N->getOperand(0);
23828     assert(In.getValueType().isVector() && "Must concat vectors");
23829 
23830     // If the input is a concat_vectors, just make a larger concat by padding
23831     // with smaller undefs.
23832     //
23833     // Legalizing in AArch64TargetLowering::LowerCONCAT_VECTORS() and combining
23834     // here could cause an infinite loop. That legalization happens when
23835     // LegalDAG is true and the input of
23836     // AArch64TargetLowering::LowerCONCAT_VECTORS() is scalable.
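    // For example (illustrative types):
    //   (v8i32 concat (v4i32 concat A:v2i32, B:v2i32), undef)
    //   --> (v8i32 concat A, B, undef:v2i32, undef:v2i32)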
23837     if (In.getOpcode() == ISD::CONCAT_VECTORS && In.hasOneUse() &&
23838         !(LegalDAG && In.getValueType().isScalableVector())) {
23839       unsigned NumOps = N->getNumOperands() * In.getNumOperands();
23840       SmallVector<SDValue, 4> Ops(In->op_begin(), In->op_end());
23841       Ops.resize(NumOps, DAG.getUNDEF(Ops[0].getValueType()));
23842       return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
23843     }
23844 
23845     SDValue Scalar = peekThroughOneUseBitcasts(In);
23846 
23847     // concat_vectors(scalar_to_vector(scalar), undef) ->
23848     //     scalar_to_vector(scalar)
23849     if (!LegalOperations && Scalar.getOpcode() == ISD::SCALAR_TO_VECTOR &&
23850          Scalar.hasOneUse()) {
23851       EVT SVT = Scalar.getValueType().getVectorElementType();
23852       if (SVT == Scalar.getOperand(0).getValueType())
23853         Scalar = Scalar.getOperand(0);
23854     }
23855 
23856     // concat_vectors(scalar, undef) -> scalar_to_vector(scalar)
23857     if (!Scalar.getValueType().isVector()) {
23858       // If the bitcast type isn't legal, it might be a trunc of a legal type;
23859       // look through the trunc so we can still do the transform:
23860       //   concat_vectors(trunc(scalar), undef) -> scalar_to_vector(scalar)
23861       if (Scalar->getOpcode() == ISD::TRUNCATE &&
23862           !TLI.isTypeLegal(Scalar.getValueType()) &&
23863           TLI.isTypeLegal(Scalar->getOperand(0).getValueType()))
23864         Scalar = Scalar->getOperand(0);
23865 
23866       EVT SclTy = Scalar.getValueType();
23867 
23868       if (!SclTy.isFloatingPoint() && !SclTy.isInteger())
23869         return SDValue();
23870 
23871       // Bail out if the vector size is not a multiple of the scalar size.
23872       if (VT.getSizeInBits() % SclTy.getSizeInBits())
23873         return SDValue();
23874 
23875       unsigned VNTNumElms = VT.getSizeInBits() / SclTy.getSizeInBits();
23876       if (VNTNumElms < 2)
23877         return SDValue();
23878 
23879       EVT NVT = EVT::getVectorVT(*DAG.getContext(), SclTy, VNTNumElms);
23880       if (!TLI.isTypeLegal(NVT) || !TLI.isTypeLegal(Scalar.getValueType()))
23881         return SDValue();
23882 
23883       SDValue Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), NVT, Scalar);
23884       return DAG.getBitcast(VT, Res);
23885     }
23886   }
23887 
23888   // Fold any combination of BUILD_VECTOR or UNDEF nodes into one BUILD_VECTOR.
23889   // We have already tested above for an UNDEF only concatenation.
23890   // fold (concat_vectors (BUILD_VECTOR A, B, ...), (BUILD_VECTOR C, D, ...))
23891   // -> (BUILD_VECTOR A, B, ..., C, D, ...)
23892   auto IsBuildVectorOrUndef = [](const SDValue &Op) {
23893     return ISD::UNDEF == Op.getOpcode() || ISD::BUILD_VECTOR == Op.getOpcode();
23894   };
23895   if (llvm::all_of(N->ops(), IsBuildVectorOrUndef)) {
23896     SmallVector<SDValue, 8> Opnds;
23897     EVT SVT = VT.getScalarType();
23898 
23899     EVT MinVT = SVT;
23900     if (!SVT.isFloatingPoint()) {
23901       // If the BUILD_VECTORs are built from integers, they may have different
23902       // operand types. Get the smallest type and truncate all operands to it.
23903       bool FoundMinVT = false;
23904       for (const SDValue &Op : N->ops())
23905         if (ISD::BUILD_VECTOR == Op.getOpcode()) {
23906           EVT OpSVT = Op.getOperand(0).getValueType();
23907           MinVT = (!FoundMinVT || OpSVT.bitsLE(MinVT)) ? OpSVT : MinVT;
23908           FoundMinVT = true;
23909         }
23910       assert(FoundMinVT && "Concat vector type mismatch");
23911     }
23912 
23913     for (const SDValue &Op : N->ops()) {
23914       EVT OpVT = Op.getValueType();
23915       unsigned NumElts = OpVT.getVectorNumElements();
23916 
23917       if (ISD::UNDEF == Op.getOpcode())
23918         Opnds.append(NumElts, DAG.getUNDEF(MinVT));
23919 
23920       if (ISD::BUILD_VECTOR == Op.getOpcode()) {
23921         if (SVT.isFloatingPoint()) {
23922           assert(SVT == OpVT.getScalarType() && "Concat vector type mismatch");
23923           Opnds.append(Op->op_begin(), Op->op_begin() + NumElts);
23924         } else {
23925           for (unsigned i = 0; i != NumElts; ++i)
23926             Opnds.push_back(
23927                 DAG.getNode(ISD::TRUNCATE, SDLoc(N), MinVT, Op.getOperand(i)));
23928         }
23929       }
23930     }
23931 
23932     assert(VT.getVectorNumElements() == Opnds.size() &&
23933            "Concat vector type mismatch");
23934     return DAG.getBuildVector(VT, SDLoc(N), Opnds);
23935   }
23936 
23937   // Fold CONCAT_VECTORS of only bitcast scalars (or undef) to BUILD_VECTOR.
23938   // FIXME: Add support for concat_vectors(bitcast(vec0),bitcast(vec1),...).
23939   if (SDValue V = combineConcatVectorOfScalars(N, DAG))
23940     return V;
23941 
23942   if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT)) {
23943     // Fold CONCAT_VECTORS of CONCAT_VECTORS (or undef) to VECTOR_SHUFFLE.
23944     if (SDValue V = combineConcatVectorOfConcatVectors(N, DAG))
23945       return V;
23946 
23947     // Fold CONCAT_VECTORS of EXTRACT_SUBVECTOR (or undef) to VECTOR_SHUFFLE.
23948     if (SDValue V = combineConcatVectorOfExtracts(N, DAG))
23949       return V;
23950   }
23951 
23952   if (SDValue V = combineConcatVectorOfCasts(N, DAG))
23953     return V;
23954 
23955   if (SDValue V = combineConcatVectorOfShuffleAndItsOperands(
23956           N, DAG, TLI, LegalTypes, LegalOperations))
23957     return V;
23958 
23959   // Type legalization of vectors and DAG canonicalization of SHUFFLE_VECTOR
23960   // nodes often generate nop CONCAT_VECTOR nodes. Scan the CONCAT_VECTOR
23961   // operands and look for CONCAT operations that place the incoming vectors
23962   // at the exact same location.
23963   //
23964   // For scalable vectors, EXTRACT_SUBVECTOR indexes are implicitly scaled.
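  // For example (illustrative, with X : v8i32 and v4i32 concat operands):
  //   concat (extract_subvector X, 0), (extract_subvector X, 4) --> X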
23965   SDValue SingleSource = SDValue();
23966   unsigned PartNumElem =
23967       N->getOperand(0).getValueType().getVectorMinNumElements();
23968 
23969   for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
23970     SDValue Op = N->getOperand(i);
23971 
23972     if (Op.isUndef())
23973       continue;
23974 
23975     // Check if this is the identity extract:
23976     if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
23977       return SDValue();
23978 
23979     // Find the single incoming vector for the extract_subvector.
23980     if (SingleSource.getNode()) {
23981       if (Op.getOperand(0) != SingleSource)
23982         return SDValue();
23983     } else {
23984       SingleSource = Op.getOperand(0);
23985 
23986       // Check that the source type is the same as the type of the result.
23987       // If not, this concat may extend the vector, so we cannot
23988       // optimize it away.
23989       if (SingleSource.getValueType() != N->getValueType(0))
23990         return SDValue();
23991     }
23992 
23993     // Check that we are reading from the identity index.
23994     unsigned IdentityIndex = i * PartNumElem;
23995     if (Op.getConstantOperandAPInt(1) != IdentityIndex)
23996       return SDValue();
23997   }
23998 
23999   if (SingleSource.getNode())
24000     return SingleSource;
24001 
24002   return SDValue();
24003 }
24004 
24005 // Helper that peeks through INSERT_SUBVECTOR/CONCAT_VECTORS to find
24006 // if the subvector can be sourced for free.
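// For example (illustrative), with SubVT = v4i32 and Index = 4:
//   (concat_vectors A:v4i32, B:v4i32) yields B, and
//   (insert_subvector W, S:v4i32, 4) yields S.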
24007 static SDValue getSubVectorSrc(SDValue V, SDValue Index, EVT SubVT) {
24008   if (V.getOpcode() == ISD::INSERT_SUBVECTOR &&
24009       V.getOperand(1).getValueType() == SubVT && V.getOperand(2) == Index) {
24010     return V.getOperand(1);
24011   }
24012   auto *IndexC = dyn_cast<ConstantSDNode>(Index);
24013   if (IndexC && V.getOpcode() == ISD::CONCAT_VECTORS &&
24014       V.getOperand(0).getValueType() == SubVT &&
24015       (IndexC->getZExtValue() % SubVT.getVectorMinNumElements()) == 0) {
24016     uint64_t SubIdx = IndexC->getZExtValue() / SubVT.getVectorMinNumElements();
24017     return V.getOperand(SubIdx);
24018   }
24019   return SDValue();
24020 }
24021 
24022 static SDValue narrowInsertExtractVectorBinOp(SDNode *Extract,
24023                                               SelectionDAG &DAG,
24024                                               bool LegalOperations) {
24025   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24026   SDValue BinOp = Extract->getOperand(0);
24027   unsigned BinOpcode = BinOp.getOpcode();
24028   if (!TLI.isBinOp(BinOpcode) || BinOp->getNumValues() != 1)
24029     return SDValue();
24030 
24031   EVT VecVT = BinOp.getValueType();
24032   SDValue Bop0 = BinOp.getOperand(0), Bop1 = BinOp.getOperand(1);
24033   if (VecVT != Bop0.getValueType() || VecVT != Bop1.getValueType())
24034     return SDValue();
24035 
24036   SDValue Index = Extract->getOperand(1);
24037   EVT SubVT = Extract->getValueType(0);
24038   if (!TLI.isOperationLegalOrCustom(BinOpcode, SubVT, LegalOperations))
24039     return SDValue();
24040 
24041   SDValue Sub0 = getSubVectorSrc(Bop0, Index, SubVT);
24042   SDValue Sub1 = getSubVectorSrc(Bop1, Index, SubVT);
24043 
24044   // TODO: We could handle the case where only 1 operand is being inserted by
24045   //       creating an extract of the other operand, but that requires checking
24046   //       number of uses and/or costs.
24047   if (!Sub0 || !Sub1)
24048     return SDValue();
24049 
24050   // We are inserting both operands of the wide binop only to extract back
24051   // to the narrow vector size. Eliminate all of the insert/extract:
24052   // ext (binop (ins ?, X, Index), (ins ?, Y, Index)), Index --> binop X, Y
24053   return DAG.getNode(BinOpcode, SDLoc(Extract), SubVT, Sub0, Sub1,
24054                      BinOp->getFlags());
24055 }
24056 
24057 /// If we are extracting a subvector produced by a wide binary operator try
24058 /// to use a narrow binary operator and/or avoid concatenation and extraction.
24059 static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG,
24060                                           bool LegalOperations) {
24061   // TODO: Refactor with the caller (visitEXTRACT_SUBVECTOR), so we can share
24062   // some of these bailouts with other transforms.
24063 
24064   if (SDValue V = narrowInsertExtractVectorBinOp(Extract, DAG, LegalOperations))
24065     return V;
24066 
24067   // The extract index must be a constant, so we can map it to a concat operand.
24068   auto *ExtractIndexC = dyn_cast<ConstantSDNode>(Extract->getOperand(1));
24069   if (!ExtractIndexC)
24070     return SDValue();
24071 
24072   // We are looking for an optionally bitcasted wide vector binary operator
24073   // feeding an extract subvector.
24074   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24075   SDValue BinOp = peekThroughBitcasts(Extract->getOperand(0));
24076   unsigned BOpcode = BinOp.getOpcode();
24077   if (!TLI.isBinOp(BOpcode) || BinOp->getNumValues() != 1)
24078     return SDValue();
24079 
24080   // Exclude the fake form of fneg (fsub -0.0, x) because that is likely to be
24081   // reduced to the unary fneg when it is visited, and we probably want to deal
24082   // with fneg in a target-specific way.
24083   if (BOpcode == ISD::FSUB) {
24084     auto *C = isConstOrConstSplatFP(BinOp.getOperand(0), /*AllowUndefs*/ true);
24085     if (C && C->getValueAPF().isNegZero())
24086       return SDValue();
24087   }
24088 
24089   // The binop must be a vector type, so we can extract some fraction of it.
24090   EVT WideBVT = BinOp.getValueType();
24091   // The optimisations below currently assume we are dealing with fixed length
24092   // vectors. It is possible to add support for scalable vectors, but at the
24093   // moment we've done no analysis to prove whether they are profitable or not.
24094   if (!WideBVT.isFixedLengthVector())
24095     return SDValue();
24096 
24097   EVT VT = Extract->getValueType(0);
24098   unsigned ExtractIndex = ExtractIndexC->getZExtValue();
24099   assert(ExtractIndex % VT.getVectorNumElements() == 0 &&
24100          "Extract index is not a multiple of the vector length.");
24101 
24102   // Bail out if this is not a proper multiple width extraction.
24103   unsigned WideWidth = WideBVT.getSizeInBits();
24104   unsigned NarrowWidth = VT.getSizeInBits();
24105   if (WideWidth % NarrowWidth != 0)
24106     return SDValue();
24107 
24108   // Bail out if we are extracting a fraction of a single operation. This can
24109   // occur because we potentially looked through a bitcast of the binop.
24110   unsigned NarrowingRatio = WideWidth / NarrowWidth;
24111   unsigned WideNumElts = WideBVT.getVectorNumElements();
24112   if (WideNumElts % NarrowingRatio != 0)
24113     return SDValue();
24114 
24115   // Bail out if the target does not support a narrower version of the binop.
24116   EVT NarrowBVT = EVT::getVectorVT(*DAG.getContext(), WideBVT.getScalarType(),
24117                                    WideNumElts / NarrowingRatio);
24118   if (!TLI.isOperationLegalOrCustomOrPromote(BOpcode, NarrowBVT,
24119                                              LegalOperations))
24120     return SDValue();
24121 
24122   // If extraction is cheap, we don't need to look at the binop operands
24123   // for concat ops. The narrow binop alone makes this transform profitable.
24124   // We can't just reuse the original extract index operand because we may have
24125   // bitcasted.
24126   unsigned ConcatOpNum = ExtractIndex / VT.getVectorNumElements();
24127   unsigned ExtBOIdx = ConcatOpNum * NarrowBVT.getVectorNumElements();
24128   if (TLI.isExtractSubvectorCheap(NarrowBVT, WideBVT, ExtBOIdx) &&
24129       BinOp.hasOneUse() && Extract->getOperand(0)->hasOneUse()) {
24130     // extract (binop B0, B1), N --> binop (extract B0, N), (extract B1, N)
24131     SDLoc DL(Extract);
24132     SDValue NewExtIndex = DAG.getVectorIdxConstant(ExtBOIdx, DL);
24133     SDValue X = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
24134                             BinOp.getOperand(0), NewExtIndex);
24135     SDValue Y = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
24136                             BinOp.getOperand(1), NewExtIndex);
24137     SDValue NarrowBinOp =
24138         DAG.getNode(BOpcode, DL, NarrowBVT, X, Y, BinOp->getFlags());
24139     return DAG.getBitcast(VT, NarrowBinOp);
24140   }
24141 
24142   // Only handle the case where we are doubling and then halving. A larger ratio
24143   // may require more than two narrow binops to replace the wide binop.
24144   if (NarrowingRatio != 2)
24145     return SDValue();
24146 
24147   // TODO: The motivating case for this transform is an x86 AVX1 target. That
24148   // target has temptingly almost legal versions of bitwise logic ops in 256-bit
24149   // flavors, but no other 256-bit integer support. This could be extended to
24150   // handle any binop, but that may require fixing/adding other folds to avoid
24151   // codegen regressions.
24152   if (BOpcode != ISD::AND && BOpcode != ISD::OR && BOpcode != ISD::XOR)
24153     return SDValue();
24154 
24155   // We need at least one concatenation operation of a binop operand to make
24156   // this transform worthwhile. The concat must double the input vector sizes.
24157   auto GetSubVector = [ConcatOpNum](SDValue V) -> SDValue {
24158     if (V.getOpcode() == ISD::CONCAT_VECTORS && V.getNumOperands() == 2)
24159       return V.getOperand(ConcatOpNum);
24160     return SDValue();
24161   };
24162   SDValue SubVecL = GetSubVector(peekThroughBitcasts(BinOp.getOperand(0)));
24163   SDValue SubVecR = GetSubVector(peekThroughBitcasts(BinOp.getOperand(1)));
24164 
24165   if (SubVecL || SubVecR) {
24166     // If a binop operand was not the result of a concat, we must extract a
24167     // half-sized operand for our new narrow binop:
24168     // extract (binop (concat X1, X2), (concat Y1, Y2)), N --> binop XN, YN
24169     // extract (binop (concat X1, X2), Y), N --> binop XN, (extract Y, IndexC)
24170     // extract (binop X, (concat Y1, Y2)), N --> binop (extract X, IndexC), YN
24171     SDLoc DL(Extract);
24172     SDValue IndexC = DAG.getVectorIdxConstant(ExtBOIdx, DL);
24173     SDValue X = SubVecL ? DAG.getBitcast(NarrowBVT, SubVecL)
24174                         : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
24175                                       BinOp.getOperand(0), IndexC);
24176 
24177     SDValue Y = SubVecR ? DAG.getBitcast(NarrowBVT, SubVecR)
24178                         : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
24179                                       BinOp.getOperand(1), IndexC);
24180 
24181     SDValue NarrowBinOp = DAG.getNode(BOpcode, DL, NarrowBVT, X, Y);
24182     return DAG.getBitcast(VT, NarrowBinOp);
24183   }
24184 
24185   return SDValue();
24186 }
24187 
24188 /// If we are extracting a subvector from a wide vector load, convert to a
24189 /// narrow load to eliminate the extraction:
24190 /// (extract_subvector (load wide vector)) --> (load narrow vector)
24191 static SDValue narrowExtractedVectorLoad(SDNode *Extract, SelectionDAG &DAG) {
24192   // TODO: Add support for big-endian. The offset calculation must be adjusted.
24193   if (DAG.getDataLayout().isBigEndian())
24194     return SDValue();
24195 
24196   auto *Ld = dyn_cast<LoadSDNode>(Extract->getOperand(0));
24197   if (!Ld || Ld->getExtensionType() || !Ld->isSimple())
24198     return SDValue();
24199 
24200   // Allow targets to opt-out.
24201   EVT VT = Extract->getValueType(0);
24202 
24203   // We can only create byte sized loads.
24204   if (!VT.isByteSized())
24205     return SDValue();
24206 
24207   unsigned Index = Extract->getConstantOperandVal(1);
24208   unsigned NumElts = VT.getVectorMinNumElements();
24209   // A fixed length vector being extracted from a scalable vector
24210   // may not be any *smaller* than the scalable one.
24211   if (Index == 0 && NumElts >= Ld->getValueType(0).getVectorMinNumElements())
24212     return SDValue();
24213 
24214   // The definition of EXTRACT_SUBVECTOR states that the index must be a
24215   // multiple of the minimum number of elements in the result type.
24216   assert(Index % NumElts == 0 && "The extract subvector index is not a "
24217                                  "multiple of the result's element count");
24218 
24219   // It's fine to use TypeSize here as we know the offset will not be negative.
24220   TypeSize Offset = VT.getStoreSize() * (Index / NumElts);
24221 
24222   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24223   if (!TLI.shouldReduceLoadWidth(Ld, Ld->getExtensionType(), VT))
24224     return SDValue();
24225 
24226   // The narrow load will be offset from the base address of the old load if
24227   // we are extracting from something besides index 0 (little-endian).
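  // For example (illustrative), extracting v2f32 at index 2 from a v8f32 load
  // reads from (base + 8 bytes).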
24228   SDLoc DL(Extract);
24229 
24230   // TODO: Use "BaseIndexOffset" to make this more effective.
24231   SDValue NewAddr = DAG.getMemBasePlusOffset(Ld->getBasePtr(), Offset, DL);
24232 
24233   uint64_t StoreSize = MemoryLocation::getSizeOrUnknown(VT.getStoreSize());
24234   MachineFunction &MF = DAG.getMachineFunction();
24235   MachineMemOperand *MMO;
24236   if (Offset.isScalable()) {
24237     MachinePointerInfo MPI =
24238         MachinePointerInfo(Ld->getPointerInfo().getAddrSpace());
24239     MMO = MF.getMachineMemOperand(Ld->getMemOperand(), MPI, StoreSize);
24240   } else
24241     MMO = MF.getMachineMemOperand(Ld->getMemOperand(), Offset.getFixedValue(),
24242                                   StoreSize);
24243 
24244   SDValue NewLd = DAG.getLoad(VT, DL, Ld->getChain(), NewAddr, MMO);
24245   DAG.makeEquivalentMemoryOrdering(Ld, NewLd);
24246   return NewLd;
24247 }
24248 
24249 /// Given  EXTRACT_SUBVECTOR(VECTOR_SHUFFLE(Op0, Op1, Mask)),
24250 /// try to produce  VECTOR_SHUFFLE(EXTRACT_SUBVECTOR(Op?, ?),
24251 ///                                EXTRACT_SUBVECTOR(Op?, ?),
24252 ///                                Mask'))
24253 /// iff it is legal and profitable to do so. Notably, the trimmed mask
24254 /// (containing only the elements that are extracted)
24255 /// must reference at most two subvectors.
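/// For example (illustrative, with v8i32 inputs, extracting the low v4i32):
///   extract_subvector (shuffle X, Y, <0,8,1,9, u,u,u,u>), 0
///   --> shuffle (extract_subvector X, 0), (extract_subvector Y, 0), <0,4,1,5>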
24256 static SDValue foldExtractSubvectorFromShuffleVector(SDNode *N,
24257                                                      SelectionDAG &DAG,
24258                                                      const TargetLowering &TLI,
24259                                                      bool LegalOperations) {
24260   assert(N->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
24261          "Must only be called on EXTRACT_SUBVECTOR's");
24262 
24263   SDValue N0 = N->getOperand(0);
24264 
24265   // Only deal with non-scalable vectors.
24266   EVT NarrowVT = N->getValueType(0);
24267   EVT WideVT = N0.getValueType();
24268   if (!NarrowVT.isFixedLengthVector() || !WideVT.isFixedLengthVector())
24269     return SDValue();
24270 
24271   // The operand must be a shufflevector.
24272   auto *WideShuffleVector = dyn_cast<ShuffleVectorSDNode>(N0);
24273   if (!WideShuffleVector)
24274     return SDValue();
24275 
24276   // The old shuffle needs to go away.
24277   if (!WideShuffleVector->hasOneUse())
24278     return SDValue();
24279 
24280   // And the narrow shufflevector that we'll form must be legal.
24281   if (LegalOperations &&
24282       !TLI.isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, NarrowVT))
24283     return SDValue();
24284 
24285   uint64_t FirstExtractedEltIdx = N->getConstantOperandVal(1);
24286   int NumEltsExtracted = NarrowVT.getVectorNumElements();
24287   assert((FirstExtractedEltIdx % NumEltsExtracted) == 0 &&
24288          "Extract index is not a multiple of the output vector length.");
24289 
24290   int WideNumElts = WideVT.getVectorNumElements();
24291 
24292   SmallVector<int, 16> NewMask;
24293   NewMask.reserve(NumEltsExtracted);
24294   SmallSetVector<std::pair<SDValue /*Op*/, int /*SubvectorIndex*/>, 2>
24295       DemandedSubvectors;
24296 
24297   // Try to decode the wide mask into a narrow mask from at most two subvectors.
24298   for (int M : WideShuffleVector->getMask().slice(FirstExtractedEltIdx,
24299                                                   NumEltsExtracted)) {
24300     assert((M >= -1) && (M < (2 * WideNumElts)) &&
24301            "Out-of-bounds shuffle mask?");
24302 
24303     if (M < 0) {
24304       // Does not depend on operands, does not require adjustment.
24305       NewMask.emplace_back(M);
24306       continue;
24307     }
24308 
24309     // From which operand of the shuffle does this shuffle mask element pick?
24310     int WideShufOpIdx = M / WideNumElts;
24311     // Which element of that operand is picked?
24312     int OpEltIdx = M % WideNumElts;
24313 
24314     assert((OpEltIdx + WideShufOpIdx * WideNumElts) == M &&
24315            "Shuffle mask vector decomposition failure.");
24316 
24317     // And which NumEltsExtracted-sized subvector of that operand is that?
24318     int OpSubvecIdx = OpEltIdx / NumEltsExtracted;
24319     // And which element within that subvector of that operand is that?
24320     int OpEltIdxInSubvec = OpEltIdx % NumEltsExtracted;
24321 
24322     assert((OpEltIdxInSubvec + OpSubvecIdx * NumEltsExtracted) == OpEltIdx &&
24323            "Shuffle mask subvector decomposition failure.");
24324 
24325     assert((OpEltIdxInSubvec + OpSubvecIdx * NumEltsExtracted +
24326             WideShufOpIdx * WideNumElts) == M &&
24327            "Shuffle mask full decomposition failure.");
24328 
24329     SDValue Op = WideShuffleVector->getOperand(WideShufOpIdx);
24330 
24331     if (Op.isUndef()) {
24332       // Picking from an undef operand. Let's adjust the mask instead.
24333       NewMask.emplace_back(-1);
24334       continue;
24335     }
24336 
24337     const std::pair<SDValue, int> DemandedSubvector =
24338         std::make_pair(Op, OpSubvecIdx);
24339 
24340     if (DemandedSubvectors.insert(DemandedSubvector)) {
24341       if (DemandedSubvectors.size() > 2)
24342         return SDValue(); // We can't handle more than two subvectors.
24343       // How many elements into the WideVT does this subvector start?
24344       int Index = NumEltsExtracted * OpSubvecIdx;
24345       // Bail out if the extraction isn't going to be cheap.
24346       if (!TLI.isExtractSubvectorCheap(NarrowVT, WideVT, Index))
24347         return SDValue();
24348     }
24349 
24350     // Ok, but from which operand of the new shuffle will this element pick?
24351     int NewOpIdx =
24352         getFirstIndexOf(DemandedSubvectors.getArrayRef(), DemandedSubvector);
24353     assert((NewOpIdx == 0 || NewOpIdx == 1) && "Unexpected operand index.");
24354 
24355     int AdjM = OpEltIdxInSubvec + NewOpIdx * NumEltsExtracted;
24356     NewMask.emplace_back(AdjM);
24357   }
24358   assert(NewMask.size() == (unsigned)NumEltsExtracted && "Produced bad mask.");
24359   assert(DemandedSubvectors.size() <= 2 &&
24360          "Should have ended up demanding at most two subvectors.");
24361 
24362   // Did we discover that the shuffle does not actually depend on operands?
24363   if (DemandedSubvectors.empty())
24364     return DAG.getUNDEF(NarrowVT);
24365 
24366   // Profitability check: only deal with extractions from the first subvector
24367   // unless the mask becomes an identity mask.
24368   if (!ShuffleVectorInst::isIdentityMask(NewMask, NewMask.size()) ||
24369       any_of(NewMask, [](int M) { return M < 0; }))
24370     for (auto &DemandedSubvector : DemandedSubvectors)
24371       if (DemandedSubvector.second != 0)
24372         return SDValue();
24373 
24374   // We still perform the exact same EXTRACT_SUBVECTOR, just on different
24375   // operand[s]/index[es], so there is no point in checking its legality.
24376 
24377   // Do not turn a legal shuffle into an illegal one.
24378   if (TLI.isShuffleMaskLegal(WideShuffleVector->getMask(), WideVT) &&
24379       !TLI.isShuffleMaskLegal(NewMask, NarrowVT))
24380     return SDValue();
24381 
24382   SDLoc DL(N);
24383 
24384   SmallVector<SDValue, 2> NewOps;
24385   for (const std::pair<SDValue /*Op*/, int /*SubvectorIndex*/>
24386            &DemandedSubvector : DemandedSubvectors) {
24387     // How many elements into the WideVT does this subvector start?
24388     int Index = NumEltsExtracted * DemandedSubvector.second;
24389     SDValue IndexC = DAG.getVectorIdxConstant(Index, DL);
24390     NewOps.emplace_back(DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowVT,
24391                                     DemandedSubvector.first, IndexC));
24392   }
24393   assert((NewOps.size() == 1 || NewOps.size() == 2) &&
24394          "Should end up with either one or two ops");
24395 
24396   // If we ended up with only one operand, pad with an undef.
24397   if (NewOps.size() == 1)
24398     NewOps.emplace_back(DAG.getUNDEF(NarrowVT));
24399 
24400   return DAG.getVectorShuffle(NarrowVT, DL, NewOps[0], NewOps[1], NewMask);
24401 }
24402 
24403 SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) {
24404   EVT NVT = N->getValueType(0);
24405   SDValue V = N->getOperand(0);
24406   uint64_t ExtIdx = N->getConstantOperandVal(1);
24407 
24408   // Extract from UNDEF is UNDEF.
24409   if (V.isUndef())
24410     return DAG.getUNDEF(NVT);
24411 
24412   if (TLI.isOperationLegalOrCustomOrPromote(ISD::LOAD, NVT))
24413     if (SDValue NarrowLoad = narrowExtractedVectorLoad(N, DAG))
24414       return NarrowLoad;
24415 
24416   // Combine an extract of an extract into a single extract_subvector.
24417   // ext (ext X, C), 0 --> ext X, C
24418   if (ExtIdx == 0 && V.getOpcode() == ISD::EXTRACT_SUBVECTOR && V.hasOneUse()) {
24419     if (TLI.isExtractSubvectorCheap(NVT, V.getOperand(0).getValueType(),
24420                                     V.getConstantOperandVal(1)) &&
24421         TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NVT)) {
24422       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), NVT, V.getOperand(0),
24423                          V.getOperand(1));
24424     }
24425   }
24426 
24427   // ty1 extract_vector(ty2 splat(V)) -> ty1 splat(V)
24428   if (V.getOpcode() == ISD::SPLAT_VECTOR)
24429     if (DAG.isConstantValueOfAnyType(V.getOperand(0)) || V.hasOneUse())
24430       if (!LegalOperations || TLI.isOperationLegal(ISD::SPLAT_VECTOR, NVT))
24431         return DAG.getSplatVector(NVT, SDLoc(N), V.getOperand(0));
24432 
24433   // Try to move vector bitcast after extract_subv by scaling extraction index:
24434   // extract_subv (bitcast X), Index --> bitcast (extract_subv X, Index')
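  // For example (illustrative types):
  //   v2i32 extract_subv (v4i32 bitcast (v8i16 X)), 2
  //   --> v2i32 bitcast (v4i16 extract_subv X, 4)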
24435   if (V.getOpcode() == ISD::BITCAST &&
24436       V.getOperand(0).getValueType().isVector() &&
24437       (!LegalOperations || TLI.isOperationLegal(ISD::BITCAST, NVT))) {
24438     SDValue SrcOp = V.getOperand(0);
24439     EVT SrcVT = SrcOp.getValueType();
24440     unsigned SrcNumElts = SrcVT.getVectorMinNumElements();
24441     unsigned DestNumElts = V.getValueType().getVectorMinNumElements();
24442     if ((SrcNumElts % DestNumElts) == 0) {
24443       unsigned SrcDestRatio = SrcNumElts / DestNumElts;
24444       ElementCount NewExtEC = NVT.getVectorElementCount() * SrcDestRatio;
24445       EVT NewExtVT = EVT::getVectorVT(*DAG.getContext(), SrcVT.getScalarType(),
24446                                       NewExtEC);
24447       if (TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NewExtVT)) {
24448         SDLoc DL(N);
24449         SDValue NewIndex = DAG.getVectorIdxConstant(ExtIdx * SrcDestRatio, DL);
24450         SDValue NewExtract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewExtVT,
24451                                          V.getOperand(0), NewIndex);
24452         return DAG.getBitcast(NVT, NewExtract);
24453       }
24454     }
24455     if ((DestNumElts % SrcNumElts) == 0) {
24456       unsigned DestSrcRatio = DestNumElts / SrcNumElts;
24457       if (NVT.getVectorElementCount().isKnownMultipleOf(DestSrcRatio)) {
24458         ElementCount NewExtEC =
24459             NVT.getVectorElementCount().divideCoefficientBy(DestSrcRatio);
24460         EVT ScalarVT = SrcVT.getScalarType();
24461         if ((ExtIdx % DestSrcRatio) == 0) {
24462           SDLoc DL(N);
24463           unsigned IndexValScaled = ExtIdx / DestSrcRatio;
24464           EVT NewExtVT =
24465               EVT::getVectorVT(*DAG.getContext(), ScalarVT, NewExtEC);
24466           if (TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NewExtVT)) {
24467             SDValue NewIndex = DAG.getVectorIdxConstant(IndexValScaled, DL);
24468             SDValue NewExtract =
24469                 DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewExtVT,
24470                             V.getOperand(0), NewIndex);
24471             return DAG.getBitcast(NVT, NewExtract);
24472           }
24473           if (NewExtEC.isScalar() &&
24474               TLI.isOperationLegalOrCustom(ISD::EXTRACT_VECTOR_ELT, ScalarVT)) {
24475             SDValue NewIndex = DAG.getVectorIdxConstant(IndexValScaled, DL);
24476             SDValue NewExtract =
24477                 DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT,
24478                             V.getOperand(0), NewIndex);
24479             return DAG.getBitcast(NVT, NewExtract);
24480           }
24481         }
24482       }
24483     }
24484   }
24485 
24486   if (V.getOpcode() == ISD::CONCAT_VECTORS) {
24487     unsigned ExtNumElts = NVT.getVectorMinNumElements();
24488     EVT ConcatSrcVT = V.getOperand(0).getValueType();
24489     assert(ConcatSrcVT.getVectorElementType() == NVT.getVectorElementType() &&
24490            "Concat and extract subvector do not change element type");
24491     assert((ExtIdx % ExtNumElts) == 0 &&
24492            "Extract index is not a multiple of the input vector length.");
24493 
24494     unsigned ConcatSrcNumElts = ConcatSrcVT.getVectorMinNumElements();
24495     unsigned ConcatOpIdx = ExtIdx / ConcatSrcNumElts;
24496 
24497     // If the concatenated source types match this extract, it's a direct
24498     // simplification:
24499     // extract_subvec (concat V1, V2, ...), i --> Vi
24500     if (NVT.getVectorElementCount() == ConcatSrcVT.getVectorElementCount())
24501       return V.getOperand(ConcatOpIdx);
24502 
24503     // If the concatenated source vectors are a multiple length of this extract,
24504     // then extract a fraction of one of those source vectors directly from a
24505     // concat operand. Example:
24506     //   v2i8 extract_subvec (v16i8 concat (v8i8 X), (v8i8 Y)), 14 -->
24507     //   v2i8 extract_subvec v8i8 Y, 6
24508     if (NVT.isFixedLengthVector() && ConcatSrcVT.isFixedLengthVector() &&
24509         ConcatSrcNumElts % ExtNumElts == 0) {
24510       SDLoc DL(N);
24511       unsigned NewExtIdx = ExtIdx - ConcatOpIdx * ConcatSrcNumElts;
24512       assert(NewExtIdx + ExtNumElts <= ConcatSrcNumElts &&
24513              "Trying to extract from >1 concat operand?");
24514       assert(NewExtIdx % ExtNumElts == 0 &&
24515              "Extract index is not a multiple of the input vector length.");
24516       SDValue NewIndexC = DAG.getVectorIdxConstant(NewExtIdx, DL);
24517       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NVT,
24518                          V.getOperand(ConcatOpIdx), NewIndexC);
24519     }
24520   }
24521 
24522   if (SDValue V =
24523           foldExtractSubvectorFromShuffleVector(N, DAG, TLI, LegalOperations))
24524     return V;
24525 
24526   V = peekThroughBitcasts(V);
24527 
24528   // If the input is a build vector, try to make a smaller build vector.
24529   if (V.getOpcode() == ISD::BUILD_VECTOR) {
24530     EVT InVT = V.getValueType();
24531     unsigned ExtractSize = NVT.getSizeInBits();
24532     unsigned EltSize = InVT.getScalarSizeInBits();
24533     // Only do this if we won't split any elements.
24534     if (ExtractSize % EltSize == 0) {
24535       unsigned NumElems = ExtractSize / EltSize;
24536       EVT EltVT = InVT.getVectorElementType();
24537       EVT ExtractVT =
24538           NumElems == 1 ? EltVT
24539                         : EVT::getVectorVT(*DAG.getContext(), EltVT, NumElems);
24540       if ((Level < AfterLegalizeDAG ||
24541            (NumElems == 1 ||
24542             TLI.isOperationLegal(ISD::BUILD_VECTOR, ExtractVT))) &&
24543           (!LegalTypes || TLI.isTypeLegal(ExtractVT))) {
24544         unsigned IdxVal = (ExtIdx * NVT.getScalarSizeInBits()) / EltSize;
24545 
24546         if (NumElems == 1) {
24547           SDValue Src = V->getOperand(IdxVal);
24548           if (EltVT != Src.getValueType())
24549             Src = DAG.getNode(ISD::TRUNCATE, SDLoc(N), EltVT, Src);
24550           return DAG.getBitcast(NVT, Src);
24551         }
24552 
24553         // Extract the pieces from the original build_vector.
24554         SDValue BuildVec = DAG.getBuildVector(ExtractVT, SDLoc(N),
24555                                               V->ops().slice(IdxVal, NumElems));
24556         return DAG.getBitcast(NVT, BuildVec);
24557       }
24558     }
24559   }
24560 
24561   if (V.getOpcode() == ISD::INSERT_SUBVECTOR) {
24562     // Handle only the simple case where the vector being inserted and the
24563     // vector being extracted are of the same size.
24564     EVT SmallVT = V.getOperand(1).getValueType();
24565     if (!NVT.bitsEq(SmallVT))
24566       return SDValue();
24567 
24568     // Combine:
24569     //    (extract_subvec (insert_subvec V1, V2, InsIdx), ExtIdx)
24570     // Into:
24571     //    indices are equal or bit offsets are equal => V1
24572     //    otherwise => (extract_subvec V1, ExtIdx)
24573     uint64_t InsIdx = V.getConstantOperandVal(2);
24574     if (InsIdx * SmallVT.getScalarSizeInBits() ==
24575         ExtIdx * NVT.getScalarSizeInBits()) {
24576       if (LegalOperations && !TLI.isOperationLegal(ISD::BITCAST, NVT))
24577         return SDValue();
24578 
24579       return DAG.getBitcast(NVT, V.getOperand(1));
24580     }
24581     return DAG.getNode(
24582         ISD::EXTRACT_SUBVECTOR, SDLoc(N), NVT,
24583         DAG.getBitcast(N->getOperand(0).getValueType(), V.getOperand(0)),
24584         N->getOperand(1));
24585   }
24586 
24587   if (SDValue NarrowBOp = narrowExtractedVectorBinOp(N, DAG, LegalOperations))
24588     return NarrowBOp;
24589 
24590   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
24591     return SDValue(N, 0);
24592 
24593   return SDValue();
24594 }
24595 
24596 /// Try to convert a wide shuffle of concatenated vectors into 2 narrow shuffles
24597 /// followed by concatenation. Narrow vector ops may have better performance
24598 /// than wide ops, and this can unlock further narrowing of other vector ops.
24599 /// Targets can invert this transform later if it is not profitable.
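/// For example (illustrative, with v4i32 concat operands X and Y):
///   shuffle (concat X, undef), (concat Y, undef), <0,9,1,8,2,11,3,10>
///   --> concat (shuffle X, Y, <0,5,1,4>), (shuffle X, Y, <2,7,3,6>)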
24600 static SDValue foldShuffleOfConcatUndefs(ShuffleVectorSDNode *Shuf,
24601                                          SelectionDAG &DAG) {
24602   SDValue N0 = Shuf->getOperand(0), N1 = Shuf->getOperand(1);
24603   if (N0.getOpcode() != ISD::CONCAT_VECTORS || N0.getNumOperands() != 2 ||
24604       N1.getOpcode() != ISD::CONCAT_VECTORS || N1.getNumOperands() != 2 ||
24605       !N0.getOperand(1).isUndef() || !N1.getOperand(1).isUndef())
24606     return SDValue();
24607 
24608   // Split the wide shuffle mask into halves. Any mask element that is accessing
24609   // operand 1 is offset down to account for narrowing of the vectors.
24610   ArrayRef<int> Mask = Shuf->getMask();
24611   EVT VT = Shuf->getValueType(0);
24612   unsigned NumElts = VT.getVectorNumElements();
24613   unsigned HalfNumElts = NumElts / 2;
24614   SmallVector<int, 16> Mask0(HalfNumElts, -1);
24615   SmallVector<int, 16> Mask1(HalfNumElts, -1);
24616   for (unsigned i = 0; i != NumElts; ++i) {
24617     if (Mask[i] == -1)
24618       continue;
24619     // If we reference the upper (undef) subvector then the element is undef.
24620     if ((Mask[i] % NumElts) >= HalfNumElts)
24621       continue;
24622     int M = Mask[i] < (int)NumElts ? Mask[i] : Mask[i] - (int)HalfNumElts;
24623     if (i < HalfNumElts)
24624       Mask0[i] = M;
24625     else
24626       Mask1[i - HalfNumElts] = M;
24627   }
24628 
24629   // Ask the target if this is a valid transform.
24630   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24631   EVT HalfVT = EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(),
24632                                 HalfNumElts);
24633   if (!TLI.isShuffleMaskLegal(Mask0, HalfVT) ||
24634       !TLI.isShuffleMaskLegal(Mask1, HalfVT))
24635     return SDValue();
24636 
24637   // shuffle (concat X, undef), (concat Y, undef), Mask -->
24638   // concat (shuffle X, Y, Mask0), (shuffle X, Y, Mask1)
24639   SDValue X = N0.getOperand(0), Y = N1.getOperand(0);
24640   SDLoc DL(Shuf);
24641   SDValue Shuf0 = DAG.getVectorShuffle(HalfVT, DL, X, Y, Mask0);
24642   SDValue Shuf1 = DAG.getVectorShuffle(HalfVT, DL, X, Y, Mask1);
24643   return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Shuf0, Shuf1);
24644 }
24645 
24646 // Tries to turn a shuffle of two CONCAT_VECTORS into a single concat,
24647 // or turn a shuffle of a single concat into a simpler shuffle then a concat.
24648 static SDValue partitionShuffleOfConcats(SDNode *N, SelectionDAG &DAG) {
24649   EVT VT = N->getValueType(0);
24650   unsigned NumElts = VT.getVectorNumElements();
24651 
24652   SDValue N0 = N->getOperand(0);
24653   SDValue N1 = N->getOperand(1);
24654   ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N);
24655   ArrayRef<int> Mask = SVN->getMask();
24656 
24657   SmallVector<SDValue, 4> Ops;
24658   EVT ConcatVT = N0.getOperand(0).getValueType();
24659   unsigned NumElemsPerConcat = ConcatVT.getVectorNumElements();
24660   unsigned NumConcats = NumElts / NumElemsPerConcat;
24661 
24662   auto IsUndefMaskElt = [](int i) { return i == -1; };
24663 
24664   // Special case: shuffle(concat(A,B)) can be more efficiently represented
24665   // as concat(shuffle(A,B),UNDEF) if the shuffle doesn't set any of the high
24666   // half vector elements.
24667   if (NumElemsPerConcat * 2 == NumElts && N1.isUndef() &&
24668       llvm::all_of(Mask.slice(NumElemsPerConcat, NumElemsPerConcat),
24669                    IsUndefMaskElt)) {
24670     N0 = DAG.getVectorShuffle(ConcatVT, SDLoc(N), N0.getOperand(0),
24671                               N0.getOperand(1),
24672                               Mask.slice(0, NumElemsPerConcat));
24673     N1 = DAG.getUNDEF(ConcatVT);
24674     return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, N0, N1);
24675   }
24676 
24677   // Look at every vector that's inserted. We're looking for exact
24678   // subvector-sized copies from a concatenated vector.
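  // For example (illustrative, with 4 elements per concat operand):
  //   shuffle (concat A, B), (concat C, D), <4,5,6,7,8,9,10,11>
  //   --> concat B, C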
24679   for (unsigned I = 0; I != NumConcats; ++I) {
24680     unsigned Begin = I * NumElemsPerConcat;
24681     ArrayRef<int> SubMask = Mask.slice(Begin, NumElemsPerConcat);
24682 
24683     // Make sure we're dealing with a copy.
24684     if (llvm::all_of(SubMask, IsUndefMaskElt)) {
24685       Ops.push_back(DAG.getUNDEF(ConcatVT));
24686       continue;
24687     }
24688 
24689     int OpIdx = -1;
24690     for (int i = 0; i != (int)NumElemsPerConcat; ++i) {
24691       if (IsUndefMaskElt(SubMask[i]))
24692         continue;
24693       if ((SubMask[i] % (int)NumElemsPerConcat) != i)
24694         return SDValue();
24695       int EltOpIdx = SubMask[i] / NumElemsPerConcat;
24696       if (0 <= OpIdx && EltOpIdx != OpIdx)
24697         return SDValue();
24698       OpIdx = EltOpIdx;
24699     }
24700     assert(0 <= OpIdx && "Unknown concat_vectors op");
24701 
24702     if (OpIdx < (int)N0.getNumOperands())
24703       Ops.push_back(N0.getOperand(OpIdx));
24704     else
24705       Ops.push_back(N1.getOperand(OpIdx - N0.getNumOperands()));
24706   }
24707 
24708   return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
24709 }
24710 
24711 // Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
24712 // BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
24713 //
24714 // SHUFFLE(BUILD_VECTOR(), BUILD_VECTOR()) -> BUILD_VECTOR() is always
24715 // a simplification in some sense, but it isn't appropriate in general: some
24716 // BUILD_VECTORs are substantially cheaper than others. The general case
24717 // of a BUILD_VECTOR requires inserting each element individually (or
24718 // performing the equivalent in a temporary stack variable). A BUILD_VECTOR of
24719 // all constants is a single constant pool load.  A BUILD_VECTOR where each
24720 // element is identical is a splat.  A BUILD_VECTOR where most of the operands
24721 // are undef lowers to a small number of element insertions.
24722 //
24723 // To deal with this, we currently use a bunch of mostly arbitrary heuristics.
24724 // We don't fold shuffles where one side is a non-zero constant, and we don't
24725 // fold shuffles if the resulting (non-splat) BUILD_VECTOR would have duplicate
24726 // non-constant operands. This seems to work out reasonably well in practice.
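// For example (illustrative):
//   shuffle (build_vector A, B), (build_vector C, D), <0,3>
//   --> build_vector A, D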
24727 static SDValue combineShuffleOfScalars(ShuffleVectorSDNode *SVN,
24728                                        SelectionDAG &DAG,
24729                                        const TargetLowering &TLI) {
24730   EVT VT = SVN->getValueType(0);
24731   unsigned NumElts = VT.getVectorNumElements();
24732   SDValue N0 = SVN->getOperand(0);
24733   SDValue N1 = SVN->getOperand(1);
24734 
24735   if (!N0->hasOneUse())
24736     return SDValue();
24737 
24738   // If only one of N0,N1 is constant, bail out if it is not ALL_ZEROS as
24739   // discussed above.
24740   if (!N1.isUndef()) {
24741     if (!N1->hasOneUse())
24742       return SDValue();
24743 
24744     bool N0AnyConst = isAnyConstantBuildVector(N0);
24745     bool N1AnyConst = isAnyConstantBuildVector(N1);
24746     if (N0AnyConst && !N1AnyConst && !ISD::isBuildVectorAllZeros(N0.getNode()))
24747       return SDValue();
24748     if (!N0AnyConst && N1AnyConst && !ISD::isBuildVectorAllZeros(N1.getNode()))
24749       return SDValue();
24750   }
24751 
24752   // If both inputs are splats of the same value then we can safely merge this
24753   // to a single BUILD_VECTOR with undef elements based on the shuffle mask.
24754   bool IsSplat = false;
24755   auto *BV0 = dyn_cast<BuildVectorSDNode>(N0);
24756   auto *BV1 = dyn_cast<BuildVectorSDNode>(N1);
24757   if (BV0 && BV1)
24758     if (SDValue Splat0 = BV0->getSplatValue())
24759       IsSplat = (Splat0 == BV1->getSplatValue());
24760 
24761   SmallVector<SDValue, 8> Ops;
24762   SmallSet<SDValue, 16> DuplicateOps;
24763   for (int M : SVN->getMask()) {
24764     SDValue Op = DAG.getUNDEF(VT.getScalarType());
24765     if (M >= 0) {
24766       int Idx = M < (int)NumElts ? M : M - NumElts;
24767       SDValue &S = (M < (int)NumElts ? N0 : N1);
24768       if (S.getOpcode() == ISD::BUILD_VECTOR) {
24769         Op = S.getOperand(Idx);
24770       } else if (S.getOpcode() == ISD::SCALAR_TO_VECTOR) {
24771         SDValue Op0 = S.getOperand(0);
24772         Op = Idx == 0 ? Op0 : DAG.getUNDEF(Op0.getValueType());
24773       } else {
24774         // Operand can't be combined - bail out.
24775         return SDValue();
24776       }
24777     }
24778 
24779     // Don't duplicate a non-constant BUILD_VECTOR operand unless we're
24780     // generating a splat; semantically, this is fine, but it's likely to
24781     // generate low-quality code if the target can't reconstruct an appropriate
24782     // shuffle.
24783     if (!Op.isUndef() && !isIntOrFPConstant(Op))
24784       if (!IsSplat && !DuplicateOps.insert(Op).second)
24785         return SDValue();
24786 
24787     Ops.push_back(Op);
24788   }
24789 
24790   // BUILD_VECTOR requires all inputs to be of the same type, find the
24791   // maximum type and extend them all.
24792   EVT SVT = VT.getScalarType();
24793   if (SVT.isInteger())
24794     for (SDValue &Op : Ops)
24795       SVT = (SVT.bitsLT(Op.getValueType()) ? Op.getValueType() : SVT);
24796   if (SVT != VT.getScalarType())
24797     for (SDValue &Op : Ops)
24798       Op = Op.isUndef() ? DAG.getUNDEF(SVT)
24799                         : (TLI.isZExtFree(Op.getValueType(), SVT)
24800                                ? DAG.getZExtOrTrunc(Op, SDLoc(SVN), SVT)
24801                                : DAG.getSExtOrTrunc(Op, SDLoc(SVN), SVT));
24802   return DAG.getBuildVector(VT, SDLoc(SVN), Ops);
24803 }
24804 
24805 // Match shuffles that can be converted to *_vector_extend_in_reg.
24806 // This is often generated during legalization.
24807 // e.g. v4i32 <0,u,1,u> -> (v2i64 any_vector_extend_in_reg(v4i32 src)),
24808 // and returns the EVT to which the extension should be performed.
24809 // NOTE: this assumes that the src is the first operand of the shuffle.
24810 static std::optional<EVT> canCombineShuffleToExtendVectorInreg(
24811     unsigned Opcode, EVT VT, std::function<bool(unsigned)> Match,
24812     SelectionDAG &DAG, const TargetLowering &TLI, bool LegalTypes,
24813     bool LegalOperations) {
24814   bool IsBigEndian = DAG.getDataLayout().isBigEndian();
24815 
24816   // TODO Add support for big-endian when we have a test case.
24817   if (!VT.isInteger() || IsBigEndian)
24818     return std::nullopt;
24819 
24820   unsigned NumElts = VT.getVectorNumElements();
24821   unsigned EltSizeInBits = VT.getScalarSizeInBits();
24822 
24823   // Attempt to match a '*_extend_vector_inreg' shuffle; we just search for
24824   // power-of-2 extensions as they are the most likely.
24825   // FIXME: should try the Scale == NumElts case too.
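  // For example (illustrative), for v8i16 this tries v4i32 (Scale == 2) and
  // v2i64 (Scale == 4).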
24826   for (unsigned Scale = 2; Scale < NumElts; Scale *= 2) {
24827     // The vector width must be a multiple of Scale.
24828     if (NumElts % Scale != 0)
24829       continue;
24830 
24831     EVT OutSVT = EVT::getIntegerVT(*DAG.getContext(), EltSizeInBits * Scale);
24832     EVT OutVT = EVT::getVectorVT(*DAG.getContext(), OutSVT, NumElts / Scale);
24833 
24834     if ((LegalTypes && !TLI.isTypeLegal(OutVT)) ||
24835         (LegalOperations && !TLI.isOperationLegalOrCustom(Opcode, OutVT)))
24836       continue;
24837 
24838     if (Match(Scale))
24839       return OutVT;
24840   }
24841 
24842   return std::nullopt;
24843 }
24844 
24845 // Match shuffles that can be converted to any_vector_extend_in_reg.
24846 // This is often generated during legalization.
24847 // e.g. v4i32 <0,u,1,u> -> (v2i64 any_vector_extend_in_reg(v4i32 src))
24848 static SDValue combineShuffleToAnyExtendVectorInreg(ShuffleVectorSDNode *SVN,
24849                                                     SelectionDAG &DAG,
24850                                                     const TargetLowering &TLI,
24851                                                     bool LegalOperations) {
24852   EVT VT = SVN->getValueType(0);
24853   bool IsBigEndian = DAG.getDataLayout().isBigEndian();
24854 
24855   // TODO Add support for big-endian when we have a test case.
24856   if (!VT.isInteger() || IsBigEndian)
24857     return SDValue();
24858 
24859   // shuffle<0,-1,1,-1> == (v2i64 anyextend_vector_inreg(v4i32))
24860   auto isAnyExtend = [NumElts = VT.getVectorNumElements(),
24861                       Mask = SVN->getMask()](unsigned Scale) {
24862     for (unsigned i = 0; i != NumElts; ++i) {
24863       if (Mask[i] < 0)
24864         continue;
24865       if ((i % Scale) == 0 && Mask[i] == (int)(i / Scale))
24866         continue;
24867       return false;
24868     }
24869     return true;
24870   };
24871 
24872   unsigned Opcode = ISD::ANY_EXTEND_VECTOR_INREG;
24873   SDValue N0 = SVN->getOperand(0);
24874   // Never create an illegal type. Only create unsupported operations if we
24875   // are pre-legalization.
24876   std::optional<EVT> OutVT = canCombineShuffleToExtendVectorInreg(
24877       Opcode, VT, isAnyExtend, DAG, TLI, /*LegalTypes=*/true, LegalOperations);
24878   if (!OutVT)
24879     return SDValue();
24880   return DAG.getBitcast(VT, DAG.getNode(Opcode, SDLoc(SVN), *OutVT, N0));
24881 }
24882 
24883 // Match shuffles that can be converted to zero_extend_vector_inreg.
24884 // This is often generated during legalization.
24885 // e.g. v4i32 <0,z,1,u> -> (v2i64 zero_extend_vector_inreg(v4i32 src))
24886 static SDValue combineShuffleToZeroExtendVectorInReg(ShuffleVectorSDNode *SVN,
24887                                                      SelectionDAG &DAG,
24888                                                      const TargetLowering &TLI,
24889                                                      bool LegalOperations) {
24890   bool LegalTypes = true;
24891   EVT VT = SVN->getValueType(0);
24892   assert(!VT.isScalableVector() && "Encountered scalable shuffle?");
24893   unsigned NumElts = VT.getVectorNumElements();
24894   unsigned EltSizeInBits = VT.getScalarSizeInBits();
24895 
24896   // TODO: add support for big-endian when we have a test case.
24897   bool IsBigEndian = DAG.getDataLayout().isBigEndian();
24898   if (!VT.isInteger() || IsBigEndian)
24899     return SDValue();
24900 
24901   SmallVector<int, 16> Mask(SVN->getMask().begin(), SVN->getMask().end());
24902   auto ForEachDecomposedIndice = [NumElts, &Mask](auto Fn) {
24903     for (int &Indice : Mask) {
24904       if (Indice < 0)
24905         continue;
24906       int OpIdx = (unsigned)Indice < NumElts ? 0 : 1;
24907       int OpEltIdx = (unsigned)Indice < NumElts ? Indice : Indice - NumElts;
24908       Fn(Indice, OpIdx, OpEltIdx);
24909     }
24910   };
24911 
24912   // Which elements of which operand does this shuffle demand?
24913   std::array<APInt, 2> OpsDemandedElts;
24914   for (APInt &OpDemandedElts : OpsDemandedElts)
24915     OpDemandedElts = APInt::getZero(NumElts);
24916   ForEachDecomposedIndice(
24917       [&OpsDemandedElts](int &Indice, int OpIdx, int OpEltIdx) {
24918         OpsDemandedElts[OpIdx].setBit(OpEltIdx);
24919       });
24920 
24921   // Element-wise(!), which of these demanded elements are known to be zero?
24922   std::array<APInt, 2> OpsKnownZeroElts;
24923   for (auto I : zip(SVN->ops(), OpsDemandedElts, OpsKnownZeroElts))
24924     std::get<2>(I) =
24925         DAG.computeVectorKnownZeroElements(std::get<0>(I), std::get<1>(I));
24926 
24927   // Manifest zeroable element knowledge in the shuffle mask.
24928   // NOTE: we don't have a 'zeroable' sentinel value in the generic DAG,
24929   //       this is a local invention, but it won't leak into DAG.
24930   // FIXME: should we not manifest them, but just check when matching?
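  // For example (illustrative), with shuffle<0,4,1,4> where operand 1 is known
  // to be all-zero, the mask becomes <0,-2,1,-2>, which can then match as a
  // zero_extend_vector_inreg pattern below.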
24931   bool HadZeroableElts = false;
24932   ForEachDecomposedIndice([&OpsKnownZeroElts, &HadZeroableElts](
24933                               int &Indice, int OpIdx, int OpEltIdx) {
24934     if (OpsKnownZeroElts[OpIdx][OpEltIdx]) {
24935       Indice = -2; // Zeroable element.
24936       HadZeroableElts = true;
24937     }
24938   });
24939 
24940   // Don't proceed unless we've refined at least one zeroable mask index.
24941   // If we didn't, then we are still trying to match the same shuffle mask
24942   // we previously tried to match as ISD::ANY_EXTEND_VECTOR_INREG,
24943   // and evidently failed. Proceeding will lead to endless combine loops.
24944   if (!HadZeroableElts)
24945     return SDValue();
24946 
24947   // The shuffle may be more fine-grained than we want. Widen elements first.
24948   // FIXME: should we do this before manifesting zeroable shuffle mask indices?
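        // e.g. (writing the local -2 sentinel as 'z') a v8i16 mask
        // <0,1,z,z,2,3,z,z> widens pairwise to the v4i32 mask <0,z,1,z>,
        // giving Prescale == 2 (illustrative example).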
24949   SmallVector<int, 16> ScaledMask;
24950   getShuffleMaskWithWidestElts(Mask, ScaledMask);
24951   assert(Mask.size() >= ScaledMask.size() &&
24952          Mask.size() % ScaledMask.size() == 0 && "Unexpected mask widening.");
24953   int Prescale = Mask.size() / ScaledMask.size();
24954 
24955   NumElts = ScaledMask.size();
24956   EltSizeInBits *= Prescale;
24957 
24958   EVT PrescaledVT = EVT::getVectorVT(
24959       *DAG.getContext(), EVT::getIntegerVT(*DAG.getContext(), EltSizeInBits),
24960       NumElts);
24961 
24962   if (LegalTypes && !TLI.isTypeLegal(PrescaledVT) && TLI.isTypeLegal(VT))
24963     return SDValue();
24964 
24965   // For example,
24966   // shuffle<0,z,1,-1> == (v2i64 zero_extend_vector_inreg(v4i32))
24967   // But not shuffle<z,z,1,-1> and not shuffle<0,z,z,-1>! (for same types)
24968   auto isZeroExtend = [NumElts, &ScaledMask](unsigned Scale) {
24969     assert(Scale >= 2 && Scale <= NumElts && NumElts % Scale == 0 &&
24970            "Unexpected mask scaling factor.");
24971     ArrayRef<int> Mask = ScaledMask;
24972     for (unsigned SrcElt = 0, NumSrcElts = NumElts / Scale;
24973          SrcElt != NumSrcElts; ++SrcElt) {
24974       // Analyze the shuffle mask in Scale-sized chunks.
24975       ArrayRef<int> MaskChunk = Mask.take_front(Scale);
24976       assert(MaskChunk.size() == Scale && "Unexpected mask size.");
24977       Mask = Mask.drop_front(MaskChunk.size());
24978       // The first index in this chunk must be SrcElt, but not zero!
24979       // FIXME: undef should be fine, but that results in a more-defined result.
24980       if (int FirstIndice = MaskChunk[0]; (unsigned)FirstIndice != SrcElt)
24981         return false;
24982       // The rest of the indices in this chunk must be zeros.
24983       // FIXME: undef should be fine, but that results in a more-defined result.
24984       if (!all_of(MaskChunk.drop_front(1),
24985                   [](int Indice) { return Indice == -2; }))
24986         return false;
24987     }
24988     assert(Mask.empty() && "Did not process the whole mask?");
24989     return true;
24990   };
24991 
24992   unsigned Opcode = ISD::ZERO_EXTEND_VECTOR_INREG;
24993   for (bool Commuted : {false, true}) {
24994     SDValue Op = SVN->getOperand(!Commuted ? 0 : 1);
24995     if (Commuted)
24996       ShuffleVectorSDNode::commuteMask(ScaledMask);
24997     std::optional<EVT> OutVT = canCombineShuffleToExtendVectorInreg(
24998         Opcode, PrescaledVT, isZeroExtend, DAG, TLI, LegalTypes,
24999         LegalOperations);
25000     if (OutVT)
25001       return DAG.getBitcast(VT, DAG.getNode(Opcode, SDLoc(SVN), *OutVT,
25002                                             DAG.getBitcast(PrescaledVT, Op)));
25003   }
25004   return SDValue();
25005 }
25006 
25007 // Detect 'truncate_vector_inreg' style shuffles that pack the lower parts of
25008 // each source element of a large type into the lowest elements of a smaller
25009 // destination type. This is often generated during legalization.
25010 // If the source node itself was a '*_extend_vector_inreg' node then we should
25011 // be able to remove it.
25012 static SDValue combineTruncationShuffle(ShuffleVectorSDNode *SVN,
25013                                         SelectionDAG &DAG) {
25014   EVT VT = SVN->getValueType(0);
25015   bool IsBigEndian = DAG.getDataLayout().isBigEndian();
25016 
25017   // TODO: Add support for big-endian when we have a test case.
25018   if (!VT.isInteger() || IsBigEndian)
25019     return SDValue();
25020 
25021   SDValue N0 = peekThroughBitcasts(SVN->getOperand(0));
25022 
25023   unsigned Opcode = N0.getOpcode();
25024   if (!ISD::isExtVecInRegOpcode(Opcode))
25025     return SDValue();
25026 
25027   SDValue N00 = N0.getOperand(0);
25028   ArrayRef<int> Mask = SVN->getMask();
25029   unsigned NumElts = VT.getVectorNumElements();
25030   unsigned EltSizeInBits = VT.getScalarSizeInBits();
25031   unsigned ExtSrcSizeInBits = N00.getScalarValueSizeInBits();
25032   unsigned ExtDstSizeInBits = N0.getScalarValueSizeInBits();
25033 
25034   if (ExtDstSizeInBits % ExtSrcSizeInBits != 0)
25035     return SDValue();
25036   unsigned ExtScale = ExtDstSizeInBits / ExtSrcSizeInBits;
25037 
25038   // (v4i32 truncate_vector_inreg(v2i64)) == shuffle<0,2,-1,-1>
25039   // (v8i16 truncate_vector_inreg(v4i32)) == shuffle<0,2,4,6,-1,-1,-1,-1>
25040   // (v8i16 truncate_vector_inreg(v2i64)) == shuffle<0,4,-1,-1,-1,-1,-1,-1>
25041   auto isTruncate = [&Mask, &NumElts](unsigned Scale) {
25042     for (unsigned i = 0; i != NumElts; ++i) {
25043       if (Mask[i] < 0)
25044         continue;
25045       if ((i * Scale) < NumElts && Mask[i] == (int)(i * Scale))
25046         continue;
25047       return false;
25048     }
25049     return true;
25050   };
25051 
25052   // At the moment we just handle the case where we've truncated back to the
25053   // same size as before the extension.
25054   // TODO: handle more extension/truncation cases as cases arise.
25055   if (EltSizeInBits != ExtSrcSizeInBits)
25056     return SDValue();
25057 
25058   // We can remove *extend_vector_inreg only if the truncation happens at
25059   // the same scale as the extension.
25060   if (isTruncate(ExtScale))
25061     return DAG.getBitcast(VT, N00);
25062 
25063   return SDValue();
25064 }
25065 
25066 // Combine shuffles of splat-shuffles of the form:
25067 // shuffle (shuffle V, undef, splat-mask), undef, M
25068 // If splat-mask contains undef elements, we need to be careful about
25069 // introducing undefs in the folded mask which are not the result of composing
25070 // the masks of the shuffles.
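      // e.g. if X = (build_vector A, A, A, B), then
      //   shuffle X, undef, <0,2,u,u> --> shuffle X, undef, <0,0,u,u>
      // since the demanded elements 0 and 2 are identical (illustrative
      // example).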
25071 static SDValue combineShuffleOfSplatVal(ShuffleVectorSDNode *Shuf,
25072                                         SelectionDAG &DAG) {
25073   EVT VT = Shuf->getValueType(0);
25074   unsigned NumElts = VT.getVectorNumElements();
25075 
25076   if (!Shuf->getOperand(1).isUndef())
25077     return SDValue();
25078 
25079   // See if this unary non-splat shuffle is actually a splat shuffle
25080   // in disguise, with all demanded elements being identical.
25081   // FIXME: this can be done per-operand.
25082   if (!Shuf->isSplat()) {
25083     APInt DemandedElts(NumElts, 0);
25084     for (int Idx : Shuf->getMask()) {
25085       if (Idx < 0)
25086         continue; // Ignore sentinel indices.
25087       assert((unsigned)Idx < NumElts && "Out-of-bounds shuffle indice?");
25088       DemandedElts.setBit(Idx);
25089     }
25090     assert(DemandedElts.popcount() > 1 && "Is a splat shuffle already?");
25091     APInt UndefElts;
25092     if (DAG.isSplatValue(Shuf->getOperand(0), DemandedElts, UndefElts)) {
25093       // Even if all demanded elements are splat, some of them could be undef.
25094       // Which lowest demanded element is *not* known-undef?
25095       std::optional<unsigned> MinNonUndefIdx;
25096       for (int Idx : Shuf->getMask()) {
25097         if (Idx < 0 || UndefElts[Idx])
25098           continue; // Ignore sentinel indices, and undef elements.
25099         MinNonUndefIdx = std::min<unsigned>(Idx, MinNonUndefIdx.value_or(~0U));
25100       }
25101       if (!MinNonUndefIdx)
25102         return DAG.getUNDEF(VT); // All undef - result is undef.
25103       assert(*MinNonUndefIdx < NumElts && "Expected valid element index.");
25104       SmallVector<int, 8> SplatMask(Shuf->getMask().begin(),
25105                                     Shuf->getMask().end());
25106       for (int &Idx : SplatMask) {
25107         if (Idx < 0)
25108           continue; // Passthrough sentinel indices.
25109         // Otherwise, just pick the lowest demanded non-undef element.
25110         // Or sentinel undef, if we know we'd pick a known-undef element.
25111         Idx = UndefElts[Idx] ? -1 : *MinNonUndefIdx;
25112       }
25113       assert(SplatMask != Shuf->getMask() && "Expected mask to change!");
25114       return DAG.getVectorShuffle(VT, SDLoc(Shuf), Shuf->getOperand(0),
25115                                   Shuf->getOperand(1), SplatMask);
25116     }
25117   }
25118 
25119   // If the inner operand is a known splat with no undefs, return it directly.
25120   // TODO: Create DemandedElts mask from Shuf's mask.
25121   // TODO: Allow undef elements and merge with the shuffle code below.
25122   if (DAG.isSplatValue(Shuf->getOperand(0), /*AllowUndefs*/ false))
25123     return Shuf->getOperand(0);
25124 
25125   auto *Splat = dyn_cast<ShuffleVectorSDNode>(Shuf->getOperand(0));
25126   if (!Splat || !Splat->isSplat())
25127     return SDValue();
25128 
25129   ArrayRef<int> ShufMask = Shuf->getMask();
25130   ArrayRef<int> SplatMask = Splat->getMask();
25131   assert(ShufMask.size() == SplatMask.size() && "Mask length mismatch");
25132 
25133   // Prefer simplifying to the splat-shuffle, if possible. This is legal if
25134   // every undef mask element in the splat-shuffle has a corresponding undef
25135   // element in the user-shuffle's mask or if the composition of mask elements
25136   // would result in undef.
25137   // Examples for (shuffle (shuffle v, undef, SplatMask), undef, UserMask):
25138   // * UserMask=[0,2,u,u], SplatMask=[2,u,2,u] -> [2,2,u,u]
25139   //   In this case it is not legal to simplify to the splat-shuffle because we
25140   //   may be exposing to the users of the shuffle an undef element at
25141   //   index 1 which was not there before the combine.
25142   // * UserMask=[0,u,2,u], SplatMask=[2,u,2,u] -> [2,u,2,u]
25143   //   In this case the composition of masks yields SplatMask, so it's ok to
25144   //   simplify to the splat-shuffle.
25145   // * UserMask=[3,u,2,u], SplatMask=[2,u,2,u] -> [u,u,2,u]
25146   //   In this case the composed mask includes all undef elements of SplatMask
25147   //   and in addition sets element zero to undef. It is safe to simplify to
25148   //   the splat-shuffle.
25149   auto CanSimplifyToExistingSplat = [](ArrayRef<int> UserMask,
25150                                        ArrayRef<int> SplatMask) {
25151     for (unsigned i = 0, e = UserMask.size(); i != e; ++i)
25152       if (UserMask[i] != -1 && SplatMask[i] == -1 &&
25153           SplatMask[UserMask[i]] != -1)
25154         return false;
25155     return true;
25156   };
25157   if (CanSimplifyToExistingSplat(ShufMask, SplatMask))
25158     return Shuf->getOperand(0);
25159 
25160   // Create a new shuffle with a mask that is composed of the two shuffles'
25161   // masks.
25162   SmallVector<int, 32> NewMask;
25163   for (int Idx : ShufMask)
25164     NewMask.push_back(Idx == -1 ? -1 : SplatMask[Idx]);
25165 
25166   return DAG.getVectorShuffle(Splat->getValueType(0), SDLoc(Splat),
25167                               Splat->getOperand(0), Splat->getOperand(1),
25168                               NewMask);
25169 }
25170 
25171 // Combine shuffles of bitcasts into a shuffle of the bitcast type, provided
25172 // the mask can be treated as a larger type.
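      // e.g. shuffle (bitcast X:v2i64 to v4i32), (bitcast Y:v2i64 to v4i32),
      //      <0,1,6,7> --> bitcast (shuffle X, Y, <0,3>), since each pair of
      //      i32 lanes spans one whole i64 element (illustrative example).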
25173 static SDValue combineShuffleOfBitcast(ShuffleVectorSDNode *SVN,
25174                                        SelectionDAG &DAG,
25175                                        const TargetLowering &TLI,
25176                                        bool LegalOperations) {
25177   SDValue Op0 = SVN->getOperand(0);
25178   SDValue Op1 = SVN->getOperand(1);
25179   EVT VT = SVN->getValueType(0);
25180   if (Op0.getOpcode() != ISD::BITCAST)
25181     return SDValue();
25182   EVT InVT = Op0.getOperand(0).getValueType();
25183   if (!InVT.isVector() ||
25184       (!Op1.isUndef() && (Op1.getOpcode() != ISD::BITCAST ||
25185                           Op1.getOperand(0).getValueType() != InVT)))
25186     return SDValue();
25187   if (isAnyConstantBuildVector(Op0.getOperand(0)) &&
25188       (Op1.isUndef() || isAnyConstantBuildVector(Op1.getOperand(0))))
25189     return SDValue();
25190 
25191   int VTLanes = VT.getVectorNumElements();
25192   int InLanes = InVT.getVectorNumElements();
25193   if (VTLanes <= InLanes || VTLanes % InLanes != 0 ||
25194       (LegalOperations &&
25195        !TLI.isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, InVT)))
25196     return SDValue();
25197   int Factor = VTLanes / InLanes;
25198 
25199   // Check that each group of lanes in the mask is either undef or makes a valid
25200   // mask for the wider lane type.
25201   ArrayRef<int> Mask = SVN->getMask();
25202   SmallVector<int> NewMask;
25203   if (!widenShuffleMaskElts(Factor, Mask, NewMask))
25204     return SDValue();
25205 
25206   if (!TLI.isShuffleMaskLegal(NewMask, InVT))
25207     return SDValue();
25208 
25209   // Create the new shuffle with the new mask and bitcast it back to the
25210   // original type.
25211   SDLoc DL(SVN);
25212   Op0 = Op0.getOperand(0);
25213   Op1 = Op1.isUndef() ? DAG.getUNDEF(InVT) : Op1.getOperand(0);
25214   SDValue NewShuf = DAG.getVectorShuffle(InVT, DL, Op0, Op1, NewMask);
25215   return DAG.getBitcast(VT, NewShuf);
25216 }
25217 
25218 /// Combine shuffle of shuffle of the form:
25219 /// shuf (shuf X, undef, InnerMask), undef, OuterMask --> splat X
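      /// e.g. shuf (shuf X, undef, <2,3,u,u>), undef, <0,0,u,u>
      ///      --> shuf X, undef, <2,2,u,u>, since every non-undef lane
      ///      resolves to element 2 of X (illustrative example).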
25220 static SDValue formSplatFromShuffles(ShuffleVectorSDNode *OuterShuf,
25221                                      SelectionDAG &DAG) {
25222   if (!OuterShuf->getOperand(1).isUndef())
25223     return SDValue();
25224   auto *InnerShuf = dyn_cast<ShuffleVectorSDNode>(OuterShuf->getOperand(0));
25225   if (!InnerShuf || !InnerShuf->getOperand(1).isUndef())
25226     return SDValue();
25227 
25228   ArrayRef<int> OuterMask = OuterShuf->getMask();
25229   ArrayRef<int> InnerMask = InnerShuf->getMask();
25230   unsigned NumElts = OuterMask.size();
25231   assert(NumElts == InnerMask.size() && "Mask length mismatch");
25232   SmallVector<int, 32> CombinedMask(NumElts, -1);
25233   int SplatIndex = -1;
25234   for (unsigned i = 0; i != NumElts; ++i) {
25235     // Undef lanes remain undef.
25236     int OuterMaskElt = OuterMask[i];
25237     if (OuterMaskElt == -1)
25238       continue;
25239 
25240     // Peek through the shuffle masks to get the underlying source element.
25241     int InnerMaskElt = InnerMask[OuterMaskElt];
25242     if (InnerMaskElt == -1)
25243       continue;
25244 
25245     // Initialize the splatted element.
25246     if (SplatIndex == -1)
25247       SplatIndex = InnerMaskElt;
25248 
25249     // Non-matching index - this is not a splat.
25250     if (SplatIndex != InnerMaskElt)
25251       return SDValue();
25252 
25253     CombinedMask[i] = InnerMaskElt;
25254   }
25255   assert((all_of(CombinedMask, [](int M) { return M == -1; }) ||
25256           getSplatIndex(CombinedMask) != -1) &&
25257          "Expected a splat mask");
25258 
25259   // TODO: The transform may be a win even if the mask is not legal.
25260   EVT VT = OuterShuf->getValueType(0);
25261   assert(VT == InnerShuf->getValueType(0) && "Expected matching shuffle types");
25262   if (!DAG.getTargetLoweringInfo().isShuffleMaskLegal(CombinedMask, VT))
25263     return SDValue();
25264 
25265   return DAG.getVectorShuffle(VT, SDLoc(OuterShuf), InnerShuf->getOperand(0),
25266                               InnerShuf->getOperand(1), CombinedMask);
25267 }
25268 
25269 /// If the shuffle mask is taking exactly one element from the first vector
25270 /// operand and passing through all other elements from the second vector
25271 /// operand, return the index of the mask element that is choosing an element
25272 /// from the first operand. Otherwise, return -1.
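      /// e.g. for 4 elements, Mask = <4,5,1,7> returns 2: lane 2 takes element
      /// 1 of operand 0 while every other lane passes through operand 1
      /// (illustrative example).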
25273 static int getShuffleMaskIndexOfOneElementFromOp0IntoOp1(ArrayRef<int> Mask) {
25274   int MaskSize = Mask.size();
25275   int EltFromOp0 = -1;
25276   // TODO: This does not match if there are undef elements in the shuffle mask.
25277   // Should we ignore undefs in the shuffle mask instead? The trade-off is
25278   // removing an instruction (a shuffle), but losing the knowledge that some
25279   // vector lanes are not needed.
25280   for (int i = 0; i != MaskSize; ++i) {
25281     if (Mask[i] >= 0 && Mask[i] < MaskSize) {
25282       // We're looking for a shuffle of exactly one element from operand 0.
25283       if (EltFromOp0 != -1)
25284         return -1;
25285       EltFromOp0 = i;
25286     } else if (Mask[i] != i + MaskSize) {
25287       // Nothing from operand 1 can change lanes.
25288       return -1;
25289     }
25290   }
25291   return EltFromOp0;
25292 }
25293 
25294 /// If a shuffle inserts exactly one element from a source vector operand into
25295 /// another vector operand and we can access the specified element as a scalar,
25296 /// then we can eliminate the shuffle.
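      /// e.g. shuffle (insertelt v1, x, 1), v2, <4,1,6,7>
      ///      --> insertelt v2, x, 1, since the scalar x stays in lane 1
      ///      (illustrative example).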
25297 static SDValue replaceShuffleOfInsert(ShuffleVectorSDNode *Shuf,
25298                                       SelectionDAG &DAG) {
25299   // First, check if we are taking one element of a vector and shuffling that
25300   // element into another vector.
25301   ArrayRef<int> Mask = Shuf->getMask();
25302   SmallVector<int, 16> CommutedMask(Mask);
25303   SDValue Op0 = Shuf->getOperand(0);
25304   SDValue Op1 = Shuf->getOperand(1);
25305   int ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(Mask);
25306   if (ShufOp0Index == -1) {
25307     // Commute mask and check again.
25308     ShuffleVectorSDNode::commuteMask(CommutedMask);
25309     ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(CommutedMask);
25310     if (ShufOp0Index == -1)
25311       return SDValue();
25312     // Commute operands to match the commuted shuffle mask.
25313     std::swap(Op0, Op1);
25314     Mask = CommutedMask;
25315   }
25316 
25317   // The shuffle inserts exactly one element from operand 0 into operand 1.
25318   // Now see if we can access that element as a scalar via a real insert element
25319   // instruction.
25320   // TODO: We can try harder to locate the element as a scalar. Examples: it
25321   // could be an operand of SCALAR_TO_VECTOR, BUILD_VECTOR, or a constant.
25322   assert(Mask[ShufOp0Index] >= 0 && Mask[ShufOp0Index] < (int)Mask.size() &&
25323          "Shuffle mask value must be from operand 0");
25324   if (Op0.getOpcode() != ISD::INSERT_VECTOR_ELT)
25325     return SDValue();
25326 
25327   auto *InsIndexC = dyn_cast<ConstantSDNode>(Op0.getOperand(2));
25328   if (!InsIndexC || InsIndexC->getSExtValue() != Mask[ShufOp0Index])
25329     return SDValue();
25330 
25331   // There's an existing insertelement with constant insertion index, so we
25332   // don't need to check the legality/profitability of a replacement operation
25333   // that differs at most in the constant value. The target should be able to
25334   // lower any of those in a similar way. If not, legalization will expand this
25335   // to a scalar-to-vector plus shuffle.
25336   //
25337   // Note that the shuffle may move the scalar from the position that the insert
25338   // element used. Therefore, our new insert element occurs at the shuffle's
25339   // mask index value, not the insert's index value.
25340   // shuffle (insertelt v1, x, C), v2, mask --> insertelt v2, x, C'
25341   SDValue NewInsIndex = DAG.getVectorIdxConstant(ShufOp0Index, SDLoc(Shuf));
25342   return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(Shuf), Op0.getValueType(),
25343                      Op1, Op0.getOperand(1), NewInsIndex);
25344 }
25345 
25346 /// If we have a unary shuffle of a shuffle, see if it can be folded away
25347 /// completely. This has the potential to lose undef knowledge because the first
25348 /// shuffle may not have an undef mask element where the second one does. So
25349 /// only call this after doing simplifications based on demanded elements.
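      /// e.g. shuf (shuf0 X, Y, <0,0,2,2>), undef, <1,0,3,2> --> shuf0, since
      /// lane by lane both shuffles select the same source element
      /// (illustrative example).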
25350 static SDValue simplifyShuffleOfShuffle(ShuffleVectorSDNode *Shuf) {
25351   // shuf (shuf0 X, Y, Mask0), undef, Mask
25352   auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(Shuf->getOperand(0));
25353   if (!Shuf0 || !Shuf->getOperand(1).isUndef())
25354     return SDValue();
25355 
25356   ArrayRef<int> Mask = Shuf->getMask();
25357   ArrayRef<int> Mask0 = Shuf0->getMask();
25358   for (int i = 0, e = (int)Mask.size(); i != e; ++i) {
25359     // Ignore undef elements.
25360     if (Mask[i] == -1)
25361       continue;
25362     assert(Mask[i] >= 0 && Mask[i] < e && "Unexpected shuffle mask value");
25363 
25364     // Is the element of the shuffle operand chosen by this shuffle the same as
25365     // the element chosen by the shuffle operand itself?
25366     if (Mask0[Mask[i]] != Mask0[i])
25367       return SDValue();
25368   }
25369   // Every element of this shuffle is identical to the result of the previous
25370   // shuffle, so we can replace this value.
25371   return Shuf->getOperand(0);
25372 }
25373 
25374 SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
25375   EVT VT = N->getValueType(0);
25376   unsigned NumElts = VT.getVectorNumElements();
25377 
25378   SDValue N0 = N->getOperand(0);
25379   SDValue N1 = N->getOperand(1);
25380 
25381   assert(N0.getValueType() == VT && "Vector shuffle must be normalized in DAG");
25382 
25383   // Canonicalize shuffle undef, undef -> undef
25384   if (N0.isUndef() && N1.isUndef())
25385     return DAG.getUNDEF(VT);
25386 
25387   ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N);
25388 
25389   // Canonicalize shuffle v, v -> v, undef
25390   if (N0 == N1)
25391     return DAG.getVectorShuffle(VT, SDLoc(N), N0, DAG.getUNDEF(VT),
25392                                 createUnaryMask(SVN->getMask(), NumElts));
25393 
25394   // Canonicalize shuffle undef, v -> v, undef.  Commute the shuffle mask.
25395   if (N0.isUndef())
25396     return DAG.getCommutedVectorShuffle(*SVN);
25397 
25398   // Remove references to rhs if it is undef
25399   if (N1.isUndef()) {
25400     bool Changed = false;
25401     SmallVector<int, 8> NewMask;
25402     for (unsigned i = 0; i != NumElts; ++i) {
25403       int Idx = SVN->getMaskElt(i);
25404       if (Idx >= (int)NumElts) {
25405         Idx = -1;
25406         Changed = true;
25407       }
25408       NewMask.push_back(Idx);
25409     }
25410     if (Changed)
25411       return DAG.getVectorShuffle(VT, SDLoc(N), N0, N1, NewMask);
25412   }
25413 
25414   if (SDValue InsElt = replaceShuffleOfInsert(SVN, DAG))
25415     return InsElt;
25416 
25417   // A shuffle of a single vector that is a splatted value can always be folded.
25418   if (SDValue V = combineShuffleOfSplatVal(SVN, DAG))
25419     return V;
25420 
25421   if (SDValue V = formSplatFromShuffles(SVN, DAG))
25422     return V;
25423 
25424   // If it is a splat, check if the argument vector is another splat or a
25425   // build_vector.
25426   if (SVN->isSplat() && SVN->getSplatIndex() < (int)NumElts) {
25427     int SplatIndex = SVN->getSplatIndex();
25428     if (N0.hasOneUse() && TLI.isExtractVecEltCheap(VT, SplatIndex) &&
25429         TLI.isBinOp(N0.getOpcode()) && N0->getNumValues() == 1) {
25430       // splat (vector_bo L, R), Index -->
25431       // splat (scalar_bo (extelt L, Index), (extelt R, Index))
25432       SDValue L = N0.getOperand(0), R = N0.getOperand(1);
25433       SDLoc DL(N);
25434       EVT EltVT = VT.getScalarType();
25435       SDValue Index = DAG.getVectorIdxConstant(SplatIndex, DL);
25436       SDValue ExtL = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, L, Index);
25437       SDValue ExtR = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, R, Index);
25438       SDValue NewBO =
25439           DAG.getNode(N0.getOpcode(), DL, EltVT, ExtL, ExtR, N0->getFlags());
25440       SDValue Insert = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT, NewBO);
25441       SmallVector<int, 16> ZeroMask(VT.getVectorNumElements(), 0);
25442       return DAG.getVectorShuffle(VT, DL, Insert, DAG.getUNDEF(VT), ZeroMask);
25443     }
25444 
25445     // splat(scalar_to_vector(x), 0) -> build_vector(x,...,x)
25446     // splat(insert_vector_elt(v, x, c), c) -> build_vector(x,...,x)
25447     if ((!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)) &&
25448         N0.hasOneUse()) {
25449       if (N0.getOpcode() == ISD::SCALAR_TO_VECTOR && SplatIndex == 0)
25450         return DAG.getSplatBuildVector(VT, SDLoc(N), N0.getOperand(0));
25451 
25452       if (N0.getOpcode() == ISD::INSERT_VECTOR_ELT)
25453         if (auto *Idx = dyn_cast<ConstantSDNode>(N0.getOperand(2)))
25454           if (Idx->getAPIntValue() == SplatIndex)
25455             return DAG.getSplatBuildVector(VT, SDLoc(N), N0.getOperand(1));
25456 
25457       // Look through a bitcast if LE and splatting lane 0, through to a
25458       // scalar_to_vector or a build_vector.
25459       if (N0.getOpcode() == ISD::BITCAST && N0.getOperand(0).hasOneUse() &&
25460           SplatIndex == 0 && DAG.getDataLayout().isLittleEndian() &&
25461           (N0.getOperand(0).getOpcode() == ISD::SCALAR_TO_VECTOR ||
25462            N0.getOperand(0).getOpcode() == ISD::BUILD_VECTOR)) {
25463         EVT N00VT = N0.getOperand(0).getValueType();
25464         if (VT.getScalarSizeInBits() <= N00VT.getScalarSizeInBits() &&
25465             VT.isInteger() && N00VT.isInteger()) {
25466           EVT InVT =
25467               TLI.getTypeToTransformTo(*DAG.getContext(), VT.getScalarType());
25468           SDValue Op = DAG.getZExtOrTrunc(N0.getOperand(0).getOperand(0),
25469                                           SDLoc(N), InVT);
25470           return DAG.getSplatBuildVector(VT, SDLoc(N), Op);
25471         }
25472       }
25473     }
25474 
25475     // If this is a bit convert that changes the element type of the vector but
25476     // not the number of vector elements, look through it.  Be careful not to
25477     // look through conversions that change things like v4f32 to v2f64.
25478     SDNode *V = N0.getNode();
25479     if (V->getOpcode() == ISD::BITCAST) {
25480       SDValue ConvInput = V->getOperand(0);
25481       if (ConvInput.getValueType().isVector() &&
25482           ConvInput.getValueType().getVectorNumElements() == NumElts)
25483         V = ConvInput.getNode();
25484     }
25485 
25486     if (V->getOpcode() == ISD::BUILD_VECTOR) {
25487       assert(V->getNumOperands() == NumElts &&
25488              "BUILD_VECTOR has wrong number of operands");
25489       SDValue Base;
25490       bool AllSame = true;
25491       for (unsigned i = 0; i != NumElts; ++i) {
25492         if (!V->getOperand(i).isUndef()) {
25493           Base = V->getOperand(i);
25494           break;
25495         }
25496       }
25497       // Splat of <u, u, u, u>, return <u, u, u, u>
25498       if (!Base.getNode())
25499         return N0;
25500       for (unsigned i = 0; i != NumElts; ++i) {
25501         if (V->getOperand(i) != Base) {
25502           AllSame = false;
25503           break;
25504         }
25505       }
25506       // Splat of <x, x, x, x>, return <x, x, x, x>
25507       if (AllSame)
25508         return N0;
25509 
25510       // Canonicalize any other splat as a build_vector.
25511       SDValue Splatted = V->getOperand(SplatIndex);
25512       SmallVector<SDValue, 8> Ops(NumElts, Splatted);
25513       SDValue NewBV = DAG.getBuildVector(V->getValueType(0), SDLoc(N), Ops);
25514 
25515       // We may have jumped through bitcasts, so the type of the
25516       // BUILD_VECTOR may not match the type of the shuffle.
25517       if (V->getValueType(0) != VT)
25518         NewBV = DAG.getBitcast(VT, NewBV);
25519       return NewBV;
25520     }
25521   }
25522 
25523   // Simplify source operands based on shuffle mask.
25524   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
25525     return SDValue(N, 0);
25526 
25527   // This is intentionally placed after demanded elements simplification because
25528   // it could eliminate knowledge of undef elements created by this shuffle.
25529   if (SDValue ShufOp = simplifyShuffleOfShuffle(SVN))
25530     return ShufOp;
25531 
25532   // Match shuffles that can be converted to any_extend_vector_inreg.
25533   if (SDValue V =
25534           combineShuffleToAnyExtendVectorInreg(SVN, DAG, TLI, LegalOperations))
25535     return V;
25536 
25537   // Combine "truncate_vector_inreg" style shuffles.
25538   if (SDValue V = combineTruncationShuffle(SVN, DAG))
25539     return V;
25540 
25541   if (N0.getOpcode() == ISD::CONCAT_VECTORS &&
25542       Level < AfterLegalizeVectorOps &&
25543       (N1.isUndef() ||
25544       (N1.getOpcode() == ISD::CONCAT_VECTORS &&
25545        N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType()))) {
25546     if (SDValue V = partitionShuffleOfConcats(N, DAG))
25547       return V;
25548   }
25549 
25550   // A shuffle of a concat of the same narrow vector can be reduced to use
25551   // only low-half elements of a concat with undef:
25552   // shuf (concat X, X), undef, Mask --> shuf (concat X, undef), undef, Mask'
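        // e.g. for v4: shuf (concat X, X), undef, <2,3,0,1>
        //      --> shuf (concat X, undef), undef, <0,1,0,1>
        //      (illustrative example).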
25553   if (N0.getOpcode() == ISD::CONCAT_VECTORS && N1.isUndef() &&
25554       N0.getNumOperands() == 2 &&
25555       N0.getOperand(0) == N0.getOperand(1)) {
25556     int HalfNumElts = (int)NumElts / 2;
25557     SmallVector<int, 8> NewMask;
25558     for (unsigned i = 0; i != NumElts; ++i) {
25559       int Idx = SVN->getMaskElt(i);
25560       if (Idx >= HalfNumElts) {
25561         assert(Idx < (int)NumElts && "Shuffle mask chooses undef op");
25562         Idx -= HalfNumElts;
25563       }
25564       NewMask.push_back(Idx);
25565     }
25566     if (TLI.isShuffleMaskLegal(NewMask, VT)) {
25567       SDValue UndefVec = DAG.getUNDEF(N0.getOperand(0).getValueType());
25568       SDValue NewCat = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT,
25569                                    N0.getOperand(0), UndefVec);
25570       return DAG.getVectorShuffle(VT, SDLoc(N), NewCat, N1, NewMask);
25571     }
25572   }
25573 
25574   // See if we can replace a shuffle with an insert_subvector.
25575   // e.g. v2i32 into v8i32:
25576   // shuffle(lhs,concat(rhs0,rhs1,rhs2,rhs3),0,1,2,3,10,11,6,7).
25577   // --> insert_subvector(lhs,rhs1,4).
25578   if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT) &&
25579       TLI.isOperationLegalOrCustom(ISD::INSERT_SUBVECTOR, VT)) {
25580     auto ShuffleToInsert = [&](SDValue LHS, SDValue RHS, ArrayRef<int> Mask) {
25581       // Ensure RHS subvectors are legal.
25582       assert(RHS.getOpcode() == ISD::CONCAT_VECTORS && "Can't find subvectors");
25583       EVT SubVT = RHS.getOperand(0).getValueType();
25584       int NumSubVecs = RHS.getNumOperands();
25585       int NumSubElts = SubVT.getVectorNumElements();
25586       assert((NumElts % NumSubElts) == 0 && "Subvector mismatch");
25587       if (!TLI.isTypeLegal(SubVT))
25588         return SDValue();
25589 
25590       // Don't bother if we have a unary shuffle (matches undef + LHS elts).
25591       if (all_of(Mask, [NumElts](int M) { return M < (int)NumElts; }))
25592         return SDValue();
25593 
25594       // Search [NumSubElts] spans for RHS sequence.
25595       // TODO: Can we avoid nested loops to increase performance?
25596       SmallVector<int> InsertionMask(NumElts);
25597       for (int SubVec = 0; SubVec != NumSubVecs; ++SubVec) {
25598         for (int SubIdx = 0; SubIdx != (int)NumElts; SubIdx += NumSubElts) {
25599           // Reset mask to identity.
25600           std::iota(InsertionMask.begin(), InsertionMask.end(), 0);
25601 
25602           // Add subvector insertion.
25603           std::iota(InsertionMask.begin() + SubIdx,
25604                     InsertionMask.begin() + SubIdx + NumSubElts,
25605                     NumElts + (SubVec * NumSubElts));
25606 
25607           // See if the shuffle mask matches the reference insertion mask.
25608           bool MatchingShuffle = true;
25609           for (int i = 0; i != (int)NumElts; ++i) {
25610             int ExpectIdx = InsertionMask[i];
25611             int ActualIdx = Mask[i];
25612             if (0 <= ActualIdx && ExpectIdx != ActualIdx) {
25613               MatchingShuffle = false;
25614               break;
25615             }
25616           }
25617 
25618           if (MatchingShuffle)
25619             return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, LHS,
25620                                RHS.getOperand(SubVec),
25621                                DAG.getVectorIdxConstant(SubIdx, SDLoc(N)));
25622         }
25623       }
25624       return SDValue();
25625     };
25626     ArrayRef<int> Mask = SVN->getMask();
25627     if (N1.getOpcode() == ISD::CONCAT_VECTORS)
25628       if (SDValue InsertN1 = ShuffleToInsert(N0, N1, Mask))
25629         return InsertN1;
25630     if (N0.getOpcode() == ISD::CONCAT_VECTORS) {
25631       SmallVector<int> CommuteMask(Mask);
25632       ShuffleVectorSDNode::commuteMask(CommuteMask);
25633       if (SDValue InsertN0 = ShuffleToInsert(N1, N0, CommuteMask))
25634         return InsertN0;
25635     }
25636   }
25637 
25638   // If we're not performing a select/blend shuffle, see if we can convert the
25639   // shuffle into an AND node, where all the out-of-lane elements are known zero.
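        // e.g. shuf X, Y, <0,6,2,7>, where the demanded lanes 2 and 3 of Y are
        // known zero, is equivalent to (and X, <-1,0,-1,0>) (illustrative
        // example).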
25640   if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
25641     bool IsInLaneMask = true;
25642     ArrayRef<int> Mask = SVN->getMask();
25643     SmallVector<int, 16> ClearMask(NumElts, -1);
25644     APInt DemandedLHS = APInt::getZero(NumElts);
25645     APInt DemandedRHS = APInt::getZero(NumElts);
25646     for (int I = 0; I != (int)NumElts; ++I) {
25647       int M = Mask[I];
25648       if (M < 0)
25649         continue;
25650       ClearMask[I] = M == I ? I : (I + NumElts);
25651       IsInLaneMask &= (M == I) || (M == (int)(I + NumElts));
25652       if (M != I) {
25653         APInt &Demanded = M < (int)NumElts ? DemandedLHS : DemandedRHS;
25654         Demanded.setBit(M % NumElts);
25655       }
25656     }
25657     // TODO: Should we try to mask with N1 as well?
25658     if (!IsInLaneMask && (!DemandedLHS.isZero() || !DemandedRHS.isZero()) &&
25659         (DemandedLHS.isZero() || DAG.MaskedVectorIsZero(N0, DemandedLHS)) &&
25660         (DemandedRHS.isZero() || DAG.MaskedVectorIsZero(N1, DemandedRHS))) {
25661       SDLoc DL(N);
25662       EVT IntVT = VT.changeVectorElementTypeToInteger();
25663       EVT IntSVT = VT.getVectorElementType().changeTypeToInteger();
25664       // Transform the type to a legal type so that the buildvector constant
25665       // elements are not illegal. Make sure that the result is larger than the
25666       // original type, in case the value is split into two (e.g. i64->i32).
25667       if (!TLI.isTypeLegal(IntSVT) && LegalTypes)
25668         IntSVT = TLI.getTypeToTransformTo(*DAG.getContext(), IntSVT);
25669       if (IntSVT.getSizeInBits() >= IntVT.getScalarSizeInBits()) {
25670         SDValue ZeroElt = DAG.getConstant(0, DL, IntSVT);
25671         SDValue AllOnesElt = DAG.getAllOnesConstant(DL, IntSVT);
25672         SmallVector<SDValue, 16> AndMask(NumElts, DAG.getUNDEF(IntSVT));
25673         for (int I = 0; I != (int)NumElts; ++I)
25674           if (0 <= Mask[I])
25675             AndMask[I] = Mask[I] == I ? AllOnesElt : ZeroElt;
25676 
25677         // See if a clear mask is legal instead of going via
25678         // XformToShuffleWithZero which loses UNDEF mask elements.
25679         if (TLI.isVectorClearMaskLegal(ClearMask, IntVT))
25680           return DAG.getBitcast(
25681               VT, DAG.getVectorShuffle(IntVT, DL, DAG.getBitcast(IntVT, N0),
25682                                       DAG.getConstant(0, DL, IntVT), ClearMask));
25683 
25684         if (TLI.isOperationLegalOrCustom(ISD::AND, IntVT))
25685           return DAG.getBitcast(
25686               VT, DAG.getNode(ISD::AND, DL, IntVT, DAG.getBitcast(IntVT, N0),
25687                               DAG.getBuildVector(IntVT, DL, AndMask)));
25688       }
25689     }
25690   }
25691 
25692   // Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
25693   // BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
25694   if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT))
25695     if (SDValue Res = combineShuffleOfScalars(SVN, DAG, TLI))
25696       return Res;
25697 
25698   // If this shuffle only has a single input that is a bitcasted shuffle,
25699   // attempt to merge the 2 shuffles and suitably bitcast the inputs/output
25700   // back to their original types.
25701   if (N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
25702       N1.isUndef() && Level < AfterLegalizeVectorOps &&
25703       TLI.isTypeLegal(VT)) {
25704 
25705     SDValue BC0 = peekThroughOneUseBitcasts(N0);
25706     if (BC0.getOpcode() == ISD::VECTOR_SHUFFLE && BC0.hasOneUse()) {
25707       EVT SVT = VT.getScalarType();
25708       EVT InnerVT = BC0->getValueType(0);
25709       EVT InnerSVT = InnerVT.getScalarType();
25710 
25711       // Determine which shuffle works with the smaller scalar type.
25712       EVT ScaleVT = SVT.bitsLT(InnerSVT) ? VT : InnerVT;
25713       EVT ScaleSVT = ScaleVT.getScalarType();
25714 
25715       if (TLI.isTypeLegal(ScaleVT) &&
25716           0 == (InnerSVT.getSizeInBits() % ScaleSVT.getSizeInBits()) &&
25717           0 == (SVT.getSizeInBits() % ScaleSVT.getSizeInBits())) {
25718         int InnerScale = InnerSVT.getSizeInBits() / ScaleSVT.getSizeInBits();
25719         int OuterScale = SVT.getSizeInBits() / ScaleSVT.getSizeInBits();
25720 
25721         // Scale the shuffle masks to the smaller scalar type.
25722         ShuffleVectorSDNode *InnerSVN = cast<ShuffleVectorSDNode>(BC0);
25723         SmallVector<int, 8> InnerMask;
25724         SmallVector<int, 8> OuterMask;
25725         narrowShuffleMaskElts(InnerScale, InnerSVN->getMask(), InnerMask);
25726         narrowShuffleMaskElts(OuterScale, SVN->getMask(), OuterMask);
25727 
25728         // Merge the shuffle masks.
25729         SmallVector<int, 8> NewMask;
25730         for (int M : OuterMask)
25731           NewMask.push_back(M < 0 ? -1 : InnerMask[M]);
25732 
25733         // Test for shuffle mask legality over both commutations.
25734         SDValue SV0 = BC0->getOperand(0);
25735         SDValue SV1 = BC0->getOperand(1);
25736         bool LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
25737         if (!LegalMask) {
25738           std::swap(SV0, SV1);
25739           ShuffleVectorSDNode::commuteMask(NewMask);
25740           LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
25741         }
25742 
25743         if (LegalMask) {
25744           SV0 = DAG.getBitcast(ScaleVT, SV0);
25745           SV1 = DAG.getBitcast(ScaleVT, SV1);
25746           return DAG.getBitcast(
25747               VT, DAG.getVectorShuffle(ScaleVT, SDLoc(N), SV0, SV1, NewMask));
25748         }
25749       }
25750     }
25751   }
25752 
25753   // Match shuffles of bitcasts, so long as the mask can be treated as the
25754   // larger type.
25755   if (SDValue V = combineShuffleOfBitcast(SVN, DAG, TLI, LegalOperations))
25756     return V;
25757 
25758   // Compute the combined shuffle mask for a shuffle with SV0 as the first
25759   // operand, and SV1 as the second operand.
25760   // i.e. Merge SVN(OtherSVN, N1) -> shuffle(SV0, SV1, Mask) iff Commute = false
25761   //      Merge SVN(N1, OtherSVN) -> shuffle(SV0, SV1, Mask') iff Commute = true
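        // e.g. with v4 operands and Commute = false,
        //   Merge SVN(OtherSVN(A, B, <0,4,1,5>), C) with mask <0,2,4,6>
        //   -> shuffle(A, C, <0,1,4,6>) (illustrative example).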
25762   auto MergeInnerShuffle =
25763       [NumElts, &VT](bool Commute, ShuffleVectorSDNode *SVN,
25764                      ShuffleVectorSDNode *OtherSVN, SDValue N1,
25765                      const TargetLowering &TLI, SDValue &SV0, SDValue &SV1,
25766                      SmallVectorImpl<int> &Mask) -> bool {
25767     // Don't try to fold splats; they're likely to simplify somehow, or they
25768     // might be free.
25769     if (OtherSVN->isSplat())
25770       return false;
25771 
25772     SV0 = SV1 = SDValue();
25773     Mask.clear();
25774 
25775     for (unsigned i = 0; i != NumElts; ++i) {
25776       int Idx = SVN->getMaskElt(i);
25777       if (Idx < 0) {
25778         // Propagate Undef.
25779         Mask.push_back(Idx);
25780         continue;
25781       }
25782 
25783       if (Commute)
25784         Idx = (Idx < (int)NumElts) ? (Idx + NumElts) : (Idx - NumElts);
25785 
25786       SDValue CurrentVec;
25787       if (Idx < (int)NumElts) {
25788         // This shuffle index refers to the inner shuffle N0. Lookup the inner
25789         // shuffle mask to identify which vector is actually referenced.
25790         Idx = OtherSVN->getMaskElt(Idx);
25791         if (Idx < 0) {
25792           // Propagate Undef.
25793           Mask.push_back(Idx);
25794           continue;
25795         }
25796         CurrentVec = (Idx < (int)NumElts) ? OtherSVN->getOperand(0)
25797                                           : OtherSVN->getOperand(1);
25798       } else {
25799         // This shuffle index references an element within N1.
25800         CurrentVec = N1;
25801       }
25802 
25803       // Simple case where 'CurrentVec' is UNDEF.
25804       if (CurrentVec.isUndef()) {
25805         Mask.push_back(-1);
25806         continue;
25807       }
25808 
25809       // Canonicalize the shuffle index. We don't know yet if CurrentVec
25810       // will be the first or second operand of the combined shuffle.
25811       Idx = Idx % NumElts;
25812       if (!SV0.getNode() || SV0 == CurrentVec) {
25813         // Ok. CurrentVec is the left hand side.
25814         // Update the mask accordingly.
25815         SV0 = CurrentVec;
25816         Mask.push_back(Idx);
25817         continue;
25818       }
25819       if (!SV1.getNode() || SV1 == CurrentVec) {
25820         // Ok. CurrentVec is the right hand side.
25821         // Update the mask accordingly.
25822         SV1 = CurrentVec;
25823         Mask.push_back(Idx + NumElts);
25824         continue;
25825       }
25826 
25827       // Last chance - see if the vector is another shuffle and if it
25828       // uses one of the existing candidate shuffle ops.
25829       if (auto *CurrentSVN = dyn_cast<ShuffleVectorSDNode>(CurrentVec)) {
25830         int InnerIdx = CurrentSVN->getMaskElt(Idx);
25831         if (InnerIdx < 0) {
25832           Mask.push_back(-1);
25833           continue;
25834         }
25835         SDValue InnerVec = (InnerIdx < (int)NumElts)
25836                                ? CurrentSVN->getOperand(0)
25837                                : CurrentSVN->getOperand(1);
25838         if (InnerVec.isUndef()) {
25839           Mask.push_back(-1);
25840           continue;
25841         }
25842         InnerIdx %= NumElts;
25843         if (InnerVec == SV0) {
25844           Mask.push_back(InnerIdx);
25845           continue;
25846         }
25847         if (InnerVec == SV1) {
25848           Mask.push_back(InnerIdx + NumElts);
25849           continue;
25850         }
25851       }
25852 
25853       // Bail out if we cannot convert the shuffle pair into a single shuffle.
25854       return false;
25855     }
25856 
25857     if (llvm::all_of(Mask, [](int M) { return M < 0; }))
25858       return true;
25859 
25860     // Avoid introducing shuffles with illegal mask.
25861     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
25862     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
25863     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
25864     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, A, M2)
25865     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, A, M2)
25866     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, B, M2)
25867     if (TLI.isShuffleMaskLegal(Mask, VT))
25868       return true;
25869 
25870     std::swap(SV0, SV1);
25871     ShuffleVectorSDNode::commuteMask(Mask);
25872     return TLI.isShuffleMaskLegal(Mask, VT);
25873   };
25874 
25875   if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
25876     // Canonicalize shuffles according to rules:
25877     //  shuffle(A, shuffle(A, B)) -> shuffle(shuffle(A,B), A)
25878     //  shuffle(B, shuffle(A, B)) -> shuffle(shuffle(A,B), B)
25879     //  shuffle(B, shuffle(A, Undef)) -> shuffle(shuffle(A, Undef), B)
25880     if (N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
25881         N0.getOpcode() != ISD::VECTOR_SHUFFLE) {
25882       // The incoming shuffle must be of the same type as the result of the
25883       // current shuffle.
25884       assert(N1->getOperand(0).getValueType() == VT &&
25885              "Shuffle types don't match");
25886 
25887       SDValue SV0 = N1->getOperand(0);
25888       SDValue SV1 = N1->getOperand(1);
25889       bool HasSameOp0 = N0 == SV0;
25890       bool IsSV1Undef = SV1.isUndef();
25891       if (HasSameOp0 || IsSV1Undef || N0 == SV1)
25892         // Commute the operands of this shuffle so merging below will trigger.
25893         return DAG.getCommutedVectorShuffle(*SVN);
25894     }
25895 
25896     // Canonicalize splat shuffles to the RHS to improve merging below.
25897     //  shuffle(splat(A,u), shuffle(C,D)) -> shuffle'(shuffle(C,D), splat(A,u))
25898     if (N0.getOpcode() == ISD::VECTOR_SHUFFLE &&
25899         N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
25900         cast<ShuffleVectorSDNode>(N0)->isSplat() &&
25901         !cast<ShuffleVectorSDNode>(N1)->isSplat()) {
25902       return DAG.getCommutedVectorShuffle(*SVN);
25903     }
25904 
25905     // Try to fold according to rules:
25906     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
25907     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
25908     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
25909     // Don't try to fold shuffles with illegal type.
25910     // Only fold if this shuffle is the only user of the other shuffle.
25911     // Try matching shuffle(C,shuffle(A,B)) commuted patterns as well.
25912     for (int i = 0; i != 2; ++i) {
25913       if (N->getOperand(i).getOpcode() == ISD::VECTOR_SHUFFLE &&
25914           N->isOnlyUserOf(N->getOperand(i).getNode())) {
25915         // The incoming shuffle must be of the same type as the result of the
25916         // current shuffle.
25917         auto *OtherSV = cast<ShuffleVectorSDNode>(N->getOperand(i));
25918         assert(OtherSV->getOperand(0).getValueType() == VT &&
25919                "Shuffle types don't match");
25920 
25921         SDValue SV0, SV1;
25922         SmallVector<int, 4> Mask;
25923         if (MergeInnerShuffle(i != 0, SVN, OtherSV, N->getOperand(1 - i), TLI,
25924                               SV0, SV1, Mask)) {
25925           // Check if all indices in Mask are Undef. If so, propagate Undef.
25926           if (llvm::all_of(Mask, [](int M) { return M < 0; }))
25927             return DAG.getUNDEF(VT);
25928 
25929           return DAG.getVectorShuffle(VT, SDLoc(N),
25930                                       SV0 ? SV0 : DAG.getUNDEF(VT),
25931                                       SV1 ? SV1 : DAG.getUNDEF(VT), Mask);
25932         }
25933       }
25934     }
25935 
25936     // Merge shuffles through binops if we are able to merge them with at
25937     // least one other shuffle.
25938     // shuffle(bop(shuffle(x,y),shuffle(z,w)),undef)
25939     // shuffle(bop(shuffle(x,y),shuffle(z,w)),bop(shuffle(a,b),shuffle(c,d)))
25940     unsigned SrcOpcode = N0.getOpcode();
25941     if (TLI.isBinOp(SrcOpcode) && N->isOnlyUserOf(N0.getNode()) &&
25942         (N1.isUndef() ||
25943          (SrcOpcode == N1.getOpcode() && N->isOnlyUserOf(N1.getNode())))) {
25944       // Get binop source ops, or just pass on the undef.
25945       SDValue Op00 = N0.getOperand(0);
25946       SDValue Op01 = N0.getOperand(1);
25947       SDValue Op10 = N1.isUndef() ? N1 : N1.getOperand(0);
25948       SDValue Op11 = N1.isUndef() ? N1 : N1.getOperand(1);
25949       // TODO: We might be able to relax the VT check but we don't currently
25950       // have any isBinOp() that has different result/ops VTs so play safe until
25951       // we have test coverage.
25952       if (Op00.getValueType() == VT && Op10.getValueType() == VT &&
25953           Op01.getValueType() == VT && Op11.getValueType() == VT &&
25954           (Op00.getOpcode() == ISD::VECTOR_SHUFFLE ||
25955            Op10.getOpcode() == ISD::VECTOR_SHUFFLE ||
25956            Op01.getOpcode() == ISD::VECTOR_SHUFFLE ||
25957            Op11.getOpcode() == ISD::VECTOR_SHUFFLE)) {
25958         auto CanMergeInnerShuffle = [&](SDValue &SV0, SDValue &SV1,
25959                                         SmallVectorImpl<int> &Mask, bool LeftOp,
25960                                         bool Commute) {
25961           SDValue InnerN = Commute ? N1 : N0;
25962           SDValue Op0 = LeftOp ? Op00 : Op01;
25963           SDValue Op1 = LeftOp ? Op10 : Op11;
25964           if (Commute)
25965             std::swap(Op0, Op1);
25966           // Only accept the merged shuffle if we don't introduce undef elements,
25967           // or the inner shuffle already contained undef elements.
25968           auto *SVN0 = dyn_cast<ShuffleVectorSDNode>(Op0);
25969           return SVN0 && InnerN->isOnlyUserOf(SVN0) &&
25970                  MergeInnerShuffle(Commute, SVN, SVN0, Op1, TLI, SV0, SV1,
25971                                    Mask) &&
25972                  (llvm::any_of(SVN0->getMask(), [](int M) { return M < 0; }) ||
25973                   llvm::none_of(Mask, [](int M) { return M < 0; }));
25974         };
25975 
25976         // Ensure we don't increase the number of shuffles - we must merge a
25977         // shuffle from at least one of the LHS and RHS ops.
25978         bool MergedLeft = false;
25979         SDValue LeftSV0, LeftSV1;
25980         SmallVector<int, 4> LeftMask;
25981         if (CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, false) ||
25982             CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, true)) {
25983           MergedLeft = true;
25984         } else {
25985           LeftMask.assign(SVN->getMask().begin(), SVN->getMask().end());
25986           LeftSV0 = Op00, LeftSV1 = Op10;
25987         }
25988 
25989         bool MergedRight = false;
25990         SDValue RightSV0, RightSV1;
25991         SmallVector<int, 4> RightMask;
25992         if (CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, false) ||
25993             CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, true)) {
25994           MergedRight = true;
25995         } else {
25996           RightMask.assign(SVN->getMask().begin(), SVN->getMask().end());
25997           RightSV0 = Op01, RightSV1 = Op11;
25998         }
25999 
26000         if (MergedLeft || MergedRight) {
26001           SDLoc DL(N);
26002           SDValue LHS = DAG.getVectorShuffle(
26003               VT, DL, LeftSV0 ? LeftSV0 : DAG.getUNDEF(VT),
26004               LeftSV1 ? LeftSV1 : DAG.getUNDEF(VT), LeftMask);
26005           SDValue RHS = DAG.getVectorShuffle(
26006               VT, DL, RightSV0 ? RightSV0 : DAG.getUNDEF(VT),
26007               RightSV1 ? RightSV1 : DAG.getUNDEF(VT), RightMask);
26008           return DAG.getNode(SrcOpcode, DL, VT, LHS, RHS);
26009         }
26010       }
26011     }
26012   }
26013 
26014   if (SDValue V = foldShuffleOfConcatUndefs(SVN, DAG))
26015     return V;
26016 
26017   // Match shuffles that can be converted to ISD::ZERO_EXTEND_VECTOR_INREG.
26018   // Perform this really late, because it could eliminate knowledge
26019   // of undef elements created by this shuffle.
26020   if (Level < AfterLegalizeTypes)
26021     if (SDValue V = combineShuffleToZeroExtendVectorInReg(SVN, DAG, TLI,
26022                                                           LegalOperations))
26023       return V;
26024 
26025   return SDValue();
26026 }
26027 
26028 SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) {
26029   EVT VT = N->getValueType(0);
26030   if (!VT.isFixedLengthVector())
26031     return SDValue();
26032 
26033   // Try to convert a scalar binop with an extracted vector element to a vector
26034   // binop. This is intended to reduce potentially expensive register moves.
26035   // TODO: Check if both operands are extracted.
26036   // TODO: How to prefer scalar/vector ops with multiple uses of the extract?
26037   // TODO: Generalize this, so it can be called from visitINSERT_VECTOR_ELT().
26038   SDValue Scalar = N->getOperand(0);
26039   unsigned Opcode = Scalar.getOpcode();
26040   EVT VecEltVT = VT.getScalarType();
26041   if (Scalar.hasOneUse() && Scalar->getNumValues() == 1 &&
26042       TLI.isBinOp(Opcode) && Scalar.getValueType() == VecEltVT &&
26043       Scalar.getOperand(0).getValueType() == VecEltVT &&
26044       Scalar.getOperand(1).getValueType() == VecEltVT &&
26045       Scalar->isOnlyUserOf(Scalar.getOperand(0).getNode()) &&
26046       Scalar->isOnlyUserOf(Scalar.getOperand(1).getNode()) &&
26047       DAG.isSafeToSpeculativelyExecute(Opcode) && hasOperation(Opcode, VT)) {
26048     // Match an extract element and get a shuffle mask equivalent.
26049     SmallVector<int, 8> ShufMask(VT.getVectorNumElements(), -1);
26050 
26051     for (int i : {0, 1}) {
26052       // s2v (bo (extelt V, Idx), C) --> shuffle (bo V, C'), {Idx, -1, -1...}
26053       // s2v (bo C, (extelt V, Idx)) --> shuffle (bo C', V), {Idx, -1, -1...}
26054       SDValue EE = Scalar.getOperand(i);
26055       auto *C = dyn_cast<ConstantSDNode>(Scalar.getOperand(i ? 0 : 1));
26056       if (C && EE.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
26057           EE.getOperand(0).getValueType() == VT &&
26058           isa<ConstantSDNode>(EE.getOperand(1))) {
26059         // Mask = {ExtractIndex, undef, undef....}
26060         ShufMask[0] = EE.getConstantOperandVal(1);
26061         // Make sure the shuffle is legal if we are crossing lanes.
26062         if (TLI.isShuffleMaskLegal(ShufMask, VT)) {
26063           SDLoc DL(N);
26064           SDValue V[] = {EE.getOperand(0),
26065                          DAG.getConstant(C->getAPIntValue(), DL, VT)};
26066           SDValue VecBO = DAG.getNode(Opcode, DL, VT, V[i], V[1 - i]);
26067           return DAG.getVectorShuffle(VT, DL, VecBO, DAG.getUNDEF(VT),
26068                                       ShufMask);
26069         }
26070       }
26071     }
26072   }
26073 
26074   // Replace a SCALAR_TO_VECTOR(EXTRACT_VECTOR_ELT(V,C0)) pattern
26075   // with a VECTOR_SHUFFLE and possible truncate.
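        // e.g. scalar_to_vector (extractelt V:v4i32, 2)
        //      --> shuffle V, undef, <2,u,u,u>, possibly shortened via
        //      extract_subvector when VT has fewer elements (illustrative
        //      example).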
26076   if (Opcode != ISD::EXTRACT_VECTOR_ELT ||
26077       !Scalar.getOperand(0).getValueType().isFixedLengthVector())
26078     return SDValue();
26079 
26080   // If we have an implicit truncate, truncate here if it is legal.
26081   if (VecEltVT != Scalar.getValueType() &&
26082       Scalar.getValueType().isScalarInteger() && isTypeLegal(VecEltVT)) {
26083     SDValue Val = DAG.getNode(ISD::TRUNCATE, SDLoc(Scalar), VecEltVT, Scalar);
26084     return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), VT, Val);
26085   }
26086 
26087   auto *ExtIndexC = dyn_cast<ConstantSDNode>(Scalar.getOperand(1));
26088   if (!ExtIndexC)
26089     return SDValue();
26090 
26091   SDValue SrcVec = Scalar.getOperand(0);
26092   EVT SrcVT = SrcVec.getValueType();
26093   unsigned SrcNumElts = SrcVT.getVectorNumElements();
26094   unsigned VTNumElts = VT.getVectorNumElements();
26095   if (VecEltVT == SrcVT.getScalarType() && VTNumElts <= SrcNumElts) {
26096     // Create a shuffle equivalent for scalar-to-vector: {ExtIndex, -1, -1, ...}
26097     SmallVector<int, 8> Mask(SrcNumElts, -1);
26098     Mask[0] = ExtIndexC->getZExtValue();
26099     SDValue LegalShuffle = TLI.buildLegalVectorShuffle(
26100         SrcVT, SDLoc(N), SrcVec, DAG.getUNDEF(SrcVT), Mask, DAG);
26101     if (!LegalShuffle)
26102       return SDValue();
26103 
26104     // If the initial vector is the same size, the shuffle is the result.
26105     if (VT == SrcVT)
26106       return LegalShuffle;
26107 
26108     // If not, shorten the shuffled vector.
26109     if (VTNumElts != SrcNumElts) {
26110       SDValue ZeroIdx = DAG.getVectorIdxConstant(0, SDLoc(N));
26111       EVT SubVT = EVT::getVectorVT(*DAG.getContext(),
26112                                    SrcVT.getVectorElementType(), VTNumElts);
26113       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), SubVT, LegalShuffle,
26114                          ZeroIdx);
26115     }
26116   }
26117 
26118   return SDValue();
26119 }
26120 
26121 SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) {
26122   EVT VT = N->getValueType(0);
26123   SDValue N0 = N->getOperand(0);
26124   SDValue N1 = N->getOperand(1);
26125   SDValue N2 = N->getOperand(2);
26126   uint64_t InsIdx = N->getConstantOperandVal(2);
26127 
26128   // If inserting an UNDEF, just return the original vector.
26129   if (N1.isUndef())
26130     return N0;
26131 
26132   // If this is an insert of an extracted vector into an undef vector, we can
26133   // just use the input to the extract if the types match, and can simplify
26134   // in some cases even if they don't.
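      // i.e. insert_subvector undef, (extract_subvector V, Idx), Idx --> V,
      // when the type of V matches the result type.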
26135   if (N0.isUndef() && N1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
26136       N1.getOperand(1) == N2) {
26137     EVT SrcVT = N1.getOperand(0).getValueType();
26138     if (SrcVT == VT)
26139       return N1.getOperand(0);
26140     // TODO: To remove the zero check, we would need to adjust the offset
26141     // to a multiple of the new source type.
26142     if (isNullConstant(N2) &&
26143         VT.isScalableVector() == SrcVT.isScalableVector()) {
26144       if (VT.getVectorMinNumElements() >= SrcVT.getVectorMinNumElements())
26145         return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N),
26146                            VT, N0, N1.getOperand(0), N2);
26147       else
26148         return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N),
26149                            VT, N1.getOperand(0), N2);
26150     }
26151   }
26152 
26153   // Simplify scalar inserts into an undef vector:
26154   // insert_subvector undef, (splat X), N2 -> splat X
26155   if (N0.isUndef() && N1.getOpcode() == ISD::SPLAT_VECTOR)
26156     if (DAG.isConstantValueOfAnyType(N1.getOperand(0)) || N1.hasOneUse())
26157       return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), VT, N1.getOperand(0));
26158 
26159   // If we are inserting a bitcast value into an undef, with the same
26160   // number of elements, just use the bitcast input of the extract.
26161   // i.e. INSERT_SUBVECTOR UNDEF (BITCAST N1) N2 ->
26162   //        BITCAST (INSERT_SUBVECTOR UNDEF N1 N2)
26163   if (N0.isUndef() && N1.getOpcode() == ISD::BITCAST &&
26164       N1.getOperand(0).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
26165       N1.getOperand(0).getOperand(1) == N2 &&
26166       N1.getOperand(0).getOperand(0).getValueType().getVectorElementCount() ==
26167           VT.getVectorElementCount() &&
26168       N1.getOperand(0).getOperand(0).getValueType().getSizeInBits() ==
26169           VT.getSizeInBits()) {
26170     return DAG.getBitcast(VT, N1.getOperand(0).getOperand(0));
26171   }
26172 
26173   // If both N0 and N1 are bitcast values on which insert_subvector
26174   // would make sense, pull the bitcast through.
26175   // i.e. INSERT_SUBVECTOR (BITCAST N0) (BITCAST N1) N2 ->
26176   //        BITCAST (INSERT_SUBVECTOR N0 N1 N2)
26177   if (N0.getOpcode() == ISD::BITCAST && N1.getOpcode() == ISD::BITCAST) {
26178     SDValue CN0 = N0.getOperand(0);
26179     SDValue CN1 = N1.getOperand(0);
26180     EVT CN0VT = CN0.getValueType();
26181     EVT CN1VT = CN1.getValueType();
26182     if (CN0VT.isVector() && CN1VT.isVector() &&
26183         CN0VT.getVectorElementType() == CN1VT.getVectorElementType() &&
26184         CN0VT.getVectorElementCount() == VT.getVectorElementCount()) {
26185       SDValue NewINSERT = DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N),
26186                                       CN0.getValueType(), CN0, CN1, N2);
26187       return DAG.getBitcast(VT, NewINSERT);
26188     }
26189   }
26190 
26191   // Combine INSERT_SUBVECTORs where we are inserting to the same index.
26192   // INSERT_SUBVECTOR( INSERT_SUBVECTOR( Vec, SubOld, Idx ), SubNew, Idx )
26193   // --> INSERT_SUBVECTOR( Vec, SubNew, Idx )
26194   if (N0.getOpcode() == ISD::INSERT_SUBVECTOR &&
26195       N0.getOperand(1).getValueType() == N1.getValueType() &&
26196       N0.getOperand(2) == N2)
26197     return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0.getOperand(0),
26198                        N1, N2);
26199 
26200   // Eliminate an intermediate insert into an undef vector:
26201   // insert_subvector undef, (insert_subvector undef, X, 0), 0 -->
26202   // insert_subvector undef, X, 0
26203   if (N0.isUndef() && N1.getOpcode() == ISD::INSERT_SUBVECTOR &&
26204       N1.getOperand(0).isUndef() && isNullConstant(N1.getOperand(2)) &&
26205       isNullConstant(N2))
26206     return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0,
26207                        N1.getOperand(1), N2);
26208 
26209   // Push subvector bitcasts to the output, adjusting the index as we go.
26210   // insert_subvector(bitcast(v), bitcast(s), c1)
26211   // -> bitcast(insert_subvector(v, s, c2))
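      // e.g. with illustrative types (not taken from a particular target):
      //   insert_subvector (bitcast v2i64 X to v4i32),
      //                    (bitcast v1i64 Y to v2i32), 2
      //     --> bitcast (insert_subvector v2i64 X, v1i64 Y, 1) to v4i32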
26212   if ((N0.isUndef() || N0.getOpcode() == ISD::BITCAST) &&
26213       N1.getOpcode() == ISD::BITCAST) {
26214     SDValue N0Src = peekThroughBitcasts(N0);
26215     SDValue N1Src = peekThroughBitcasts(N1);
26216     EVT N0SrcSVT = N0Src.getValueType().getScalarType();
26217     EVT N1SrcSVT = N1Src.getValueType().getScalarType();
26218     if ((N0.isUndef() || N0SrcSVT == N1SrcSVT) &&
26219         N0Src.getValueType().isVector() && N1Src.getValueType().isVector()) {
26220       EVT NewVT;
26221       SDLoc DL(N);
26222       SDValue NewIdx;
26223       LLVMContext &Ctx = *DAG.getContext();
26224       ElementCount NumElts = VT.getVectorElementCount();
26225       unsigned EltSizeInBits = VT.getScalarSizeInBits();
26226       if ((EltSizeInBits % N1SrcSVT.getSizeInBits()) == 0) {
26227         unsigned Scale = EltSizeInBits / N1SrcSVT.getSizeInBits();
26228         NewVT = EVT::getVectorVT(Ctx, N1SrcSVT, NumElts * Scale);
26229         NewIdx = DAG.getVectorIdxConstant(InsIdx * Scale, DL);
26230       } else if ((N1SrcSVT.getSizeInBits() % EltSizeInBits) == 0) {
26231         unsigned Scale = N1SrcSVT.getSizeInBits() / EltSizeInBits;
26232         if (NumElts.isKnownMultipleOf(Scale) && (InsIdx % Scale) == 0) {
26233           NewVT = EVT::getVectorVT(Ctx, N1SrcSVT,
26234                                    NumElts.divideCoefficientBy(Scale));
26235           NewIdx = DAG.getVectorIdxConstant(InsIdx / Scale, DL);
26236         }
26237       }
26238       if (NewIdx && hasOperation(ISD::INSERT_SUBVECTOR, NewVT)) {
26239         SDValue Res = DAG.getBitcast(NewVT, N0Src);
26240         Res = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, NewVT, Res, N1Src, NewIdx);
26241         return DAG.getBitcast(VT, Res);
26242       }
26243     }
26244   }
26245 
26246   // Canonicalize insert_subvector dag nodes.
26247   // Example:
26248   // (insert_subvector (insert_subvector A, SubA, Idx0), SubB, Idx1)
26249   // -> (insert_subvector (insert_subvector A, SubB, Idx1), SubA, Idx0)
26250   if (N0.getOpcode() == ISD::INSERT_SUBVECTOR && N0.hasOneUse() &&
26251       N1.getValueType() == N0.getOperand(1).getValueType()) {
26252     unsigned OtherIdx = N0.getConstantOperandVal(2);
26253     if (InsIdx < OtherIdx) {
26254       // Swap nodes.
26255       SDValue NewOp = DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT,
26256                                   N0.getOperand(0), N1, N2);
26257       AddToWorklist(NewOp.getNode());
26258       return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N0.getNode()),
26259                          VT, NewOp, N0.getOperand(1), N0.getOperand(2));
26260     }
26261   }
26262 
26263   // If the input vector is a concatenation, and the insert replaces
26264   // one of the pieces, we can optimize into a single concat_vectors.
26265   if (N0.getOpcode() == ISD::CONCAT_VECTORS && N0.hasOneUse() &&
26266       N0.getOperand(0).getValueType() == N1.getValueType() &&
26267       N0.getOperand(0).getValueType().isScalableVector() ==
26268           N1.getValueType().isScalableVector()) {
26269     unsigned Factor = N1.getValueType().getVectorMinNumElements();
26270     SmallVector<SDValue, 8> Ops(N0->op_begin(), N0->op_end());
26271     Ops[InsIdx / Factor] = N1;
26272     return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
26273   }
26274 
26275   // Simplify source operands based on insertion.
26276   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
26277     return SDValue(N, 0);
26278 
26279   return SDValue();
26280 }
26281 
26282 SDValue DAGCombiner::visitFP_TO_FP16(SDNode *N) {
26283   SDValue N0 = N->getOperand(0);
26284 
26285   // fold (fp_to_fp16 (fp16_to_fp op)) -> op
26286   if (N0->getOpcode() == ISD::FP16_TO_FP)
26287     return N0->getOperand(0);
26288 
26289   return SDValue();
26290 }
26291 
26292 SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
26293   auto Op = N->getOpcode();
26294   assert((Op == ISD::FP16_TO_FP || Op == ISD::BF16_TO_FP) &&
26295          "opcode should be FP16_TO_FP or BF16_TO_FP.");
26296   SDValue N0 = N->getOperand(0);
26297 
26298   // fold fp16_to_fp(op & 0xffff) -> fp16_to_fp(op) or
26299   // fold bf16_to_fp(op & 0xffff) -> bf16_to_fp(op)
26300   if (!TLI.shouldKeepZExtForFP16Conv() && N0->getOpcode() == ISD::AND) {
26301     ConstantSDNode *AndConst = getAsNonOpaqueConstant(N0.getOperand(1));
26302     if (AndConst && AndConst->getAPIntValue() == 0xffff) {
26303       return DAG.getNode(Op, SDLoc(N), N->getValueType(0), N0.getOperand(0));
26304     }
26305   }
26306 
26307   return SDValue();
26308 }
26309 
26310 SDValue DAGCombiner::visitFP_TO_BF16(SDNode *N) {
26311   SDValue N0 = N->getOperand(0);
26312 
26313   // fold (fp_to_bf16 (bf16_to_fp op)) -> op
26314   if (N0->getOpcode() == ISD::BF16_TO_FP)
26315     return N0->getOperand(0);
26316 
26317   return SDValue();
26318 }
26319 
26320 SDValue DAGCombiner::visitBF16_TO_FP(SDNode *N) {
26321   // fold bf16_to_fp(op & 0xffff) -> bf16_to_fp(op)
26322   return visitFP16_TO_FP(N);
26323 }
26324 
26325 SDValue DAGCombiner::visitVECREDUCE(SDNode *N) {
26326   SDValue N0 = N->getOperand(0);
26327   EVT VT = N0.getValueType();
26328   unsigned Opcode = N->getOpcode();
26329 
26330   // VECREDUCE over 1-element vector is just an extract.
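      // e.g. vecreduce_add (v1i32 V) --> extract_vector_elt V, 0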
26331   if (VT.getVectorElementCount().isScalar()) {
26332     SDLoc dl(N);
26333     SDValue Res =
26334         DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT.getVectorElementType(), N0,
26335                     DAG.getVectorIdxConstant(0, dl));
26336     if (Res.getValueType() != N->getValueType(0))
26337       Res = DAG.getNode(ISD::ANY_EXTEND, dl, N->getValueType(0), Res);
26338     return Res;
26339   }
26340 
26341   // On a boolean vector, an and/or reduction is the same as a umin/umax
26342   // reduction. Convert them if the latter is legal while the former isn't.
26343   if (Opcode == ISD::VECREDUCE_AND || Opcode == ISD::VECREDUCE_OR) {
26344     unsigned NewOpcode = Opcode == ISD::VECREDUCE_AND
26345         ? ISD::VECREDUCE_UMIN : ISD::VECREDUCE_UMAX;
26346     if (!TLI.isOperationLegalOrCustom(Opcode, VT) &&
26347         TLI.isOperationLegalOrCustom(NewOpcode, VT) &&
26348         DAG.ComputeNumSignBits(N0) == VT.getScalarSizeInBits())
26349       return DAG.getNode(NewOpcode, SDLoc(N), N->getValueType(0), N0);
26350   }
26351 
26352   // vecreduce_or(insert_subvector(zero or undef, val)) -> vecreduce_or(val)
26353   // vecreduce_and(insert_subvector(ones or undef, val)) -> vecreduce_and(val)
26354   if (N0.getOpcode() == ISD::INSERT_SUBVECTOR &&
26355       TLI.isTypeLegal(N0.getOperand(1).getValueType())) {
26356     SDValue Vec = N0.getOperand(0);
26357     SDValue Subvec = N0.getOperand(1);
26358     if ((Opcode == ISD::VECREDUCE_OR &&
26359          (N0.getOperand(0).isUndef() || isNullOrNullSplat(Vec))) ||
26360         (Opcode == ISD::VECREDUCE_AND &&
26361          (N0.getOperand(0).isUndef() || isAllOnesOrAllOnesSplat(Vec))))
26362       return DAG.getNode(Opcode, SDLoc(N), N->getValueType(0), Subvec);
26363   }
26364 
26365   return SDValue();
26366 }
26367 
26368 SDValue DAGCombiner::visitVP_FSUB(SDNode *N) {
26369   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
26370 
26371   // FSUB -> FMA combines:
26372   if (SDValue Fused = visitFSUBForFMACombine<VPMatchContext>(N)) {
26373     AddToWorklist(Fused.getNode());
26374     return Fused;
26375   }
26376   return SDValue();
26377 }
26378 
26379 SDValue DAGCombiner::visitVPOp(SDNode *N) {
26380 
26381   if (N->getOpcode() == ISD::VP_GATHER)
26382     if (SDValue SD = visitVPGATHER(N))
26383       return SD;
26384 
26385   if (N->getOpcode() == ISD::VP_SCATTER)
26386     if (SDValue SD = visitVPSCATTER(N))
26387       return SD;
26388 
26389   if (N->getOpcode() == ISD::EXPERIMENTAL_VP_STRIDED_LOAD)
26390     if (SDValue SD = visitVP_STRIDED_LOAD(N))
26391       return SD;
26392 
26393   if (N->getOpcode() == ISD::EXPERIMENTAL_VP_STRIDED_STORE)
26394     if (SDValue SD = visitVP_STRIDED_STORE(N))
26395       return SD;
26396 
26397   // VP operations in which all vector elements are disabled - either by
26398   // determining that the mask is all false or that the EVL is 0 - can be
26399   // eliminated.
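      // e.g. a vp_add whose mask is all-false or whose EVL is 0 produces no
      // observable result, so it can be replaced by undef below.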
26400   bool AreAllEltsDisabled = false;
26401   if (auto EVLIdx = ISD::getVPExplicitVectorLengthIdx(N->getOpcode()))
26402     AreAllEltsDisabled |= isNullConstant(N->getOperand(*EVLIdx));
26403   if (auto MaskIdx = ISD::getVPMaskIdx(N->getOpcode()))
26404     AreAllEltsDisabled |=
26405         ISD::isConstantSplatVectorAllZeros(N->getOperand(*MaskIdx).getNode());
26406 
26407   // This is the only generic VP combine we support for now.
26408   if (!AreAllEltsDisabled) {
26409     switch (N->getOpcode()) {
26410     case ISD::VP_FADD:
26411       return visitVP_FADD(N);
26412     case ISD::VP_FSUB:
26413       return visitVP_FSUB(N);
26414     case ISD::VP_FMA:
26415       return visitFMA<VPMatchContext>(N);
26416     }
26417     return SDValue();
26418   }
26419 
26420   // Binary operations can be replaced by UNDEF.
26421   if (ISD::isVPBinaryOp(N->getOpcode()))
26422     return DAG.getUNDEF(N->getValueType(0));
26423 
26424   // VP Memory operations can be replaced by either the chain (stores) or the
26425   // chain + undef (loads).
26426   if (const auto *MemSD = dyn_cast<MemSDNode>(N)) {
26427     if (MemSD->writeMem())
26428       return MemSD->getChain();
26429     return CombineTo(N, DAG.getUNDEF(N->getValueType(0)), MemSD->getChain());
26430   }
26431 
26432   // Reduction operations return the start operand when no elements are active.
26433   if (ISD::isVPReduction(N->getOpcode()))
26434     return N->getOperand(0);
26435 
26436   return SDValue();
26437 }
26438 
26439 SDValue DAGCombiner::visitGET_FPENV_MEM(SDNode *N) {
26440   SDValue Chain = N->getOperand(0);
26441   SDValue Ptr = N->getOperand(1);
26442   EVT MemVT = cast<FPStateAccessSDNode>(N)->getMemoryVT();
26443 
26444   // Check if the memory where the FP state is written is used only in a
26445   // single load operation.
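      // i.e. (get_fpenv_mem Ptr); (store (load Ptr), Addr)
      //        --> (get_fpenv_mem Addr)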
26446   LoadSDNode *LdNode = nullptr;
26447   for (auto *U : Ptr->uses()) {
26448     if (U == N)
26449       continue;
26450     if (auto *Ld = dyn_cast<LoadSDNode>(U)) {
26451       if (LdNode && LdNode != Ld)
26452         return SDValue();
26453       LdNode = Ld;
26454       continue;
26455     }
26456     return SDValue();
26457   }
26458   if (!LdNode || !LdNode->isSimple() || LdNode->isIndexed() ||
26459       !LdNode->getOffset().isUndef() || LdNode->getMemoryVT() != MemVT ||
26460       !LdNode->getChain().reachesChainWithoutSideEffects(SDValue(N, 0)))
26461     return SDValue();
26462 
26463   // Check if the loaded value is used only in a store operation.
26464   StoreSDNode *StNode = nullptr;
26465   for (auto I = LdNode->use_begin(), E = LdNode->use_end(); I != E; ++I) {
26466     SDUse &U = I.getUse();
26467     if (U.getResNo() == 0) {
26468       if (auto *St = dyn_cast<StoreSDNode>(U.getUser())) {
26469         if (StNode)
26470           return SDValue();
26471         StNode = St;
26472       } else {
26473         return SDValue();
26474       }
26475     }
26476   }
26477   if (!StNode || !StNode->isSimple() || StNode->isIndexed() ||
26478       !StNode->getOffset().isUndef() || StNode->getMemoryVT() != MemVT ||
26479       !StNode->getChain().reachesChainWithoutSideEffects(SDValue(LdNode, 1)))
26480     return SDValue();
26481 
26482   // Create a new GET_FPENV_MEM node, which uses the store address to write
26483   // the FP environment.
26484   SDValue Res = DAG.getGetFPEnv(Chain, SDLoc(N), StNode->getBasePtr(), MemVT,
26485                                 StNode->getMemOperand());
26486   CombineTo(StNode, Res, false);
26487   return Res;
26488 }
26489 
26490 SDValue DAGCombiner::visitSET_FPENV_MEM(SDNode *N) {
26491   SDValue Chain = N->getOperand(0);
26492   SDValue Ptr = N->getOperand(1);
26493   EVT MemVT = cast<FPStateAccessSDNode>(N)->getMemoryVT();
26494 
26495   // Check that the FP state address is used only in a single store operation.
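      // i.e. (store (load Addr), Ptr); (set_fpenv_mem Ptr)
      //        --> (set_fpenv_mem Addr)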
26496   StoreSDNode *StNode = nullptr;
26497   for (auto *U : Ptr->uses()) {
26498     if (U == N)
26499       continue;
26500     if (auto *St = dyn_cast<StoreSDNode>(U)) {
26501       if (StNode && StNode != St)
26502         return SDValue();
26503       StNode = St;
26504       continue;
26505     }
26506     return SDValue();
26507   }
26508   if (!StNode || !StNode->isSimple() || StNode->isIndexed() ||
26509       !StNode->getOffset().isUndef() || StNode->getMemoryVT() != MemVT ||
26510       !Chain.reachesChainWithoutSideEffects(SDValue(StNode, 0)))
26511     return SDValue();
26512 
26513   // Check if the stored value is loaded from some location and the loaded
26514   // value is used only in the store operation.
26515   SDValue StValue = StNode->getValue();
26516   auto *LdNode = dyn_cast<LoadSDNode>(StValue);
26517   if (!LdNode || !LdNode->isSimple() || LdNode->isIndexed() ||
26518       !LdNode->getOffset().isUndef() || LdNode->getMemoryVT() != MemVT ||
26519       !StNode->getChain().reachesChainWithoutSideEffects(SDValue(LdNode, 1)))
26520     return SDValue();
26521 
26522   // Create a new SET_FPENV_MEM node, which uses the load address to read
26523   // the FP environment.
26524   SDValue Res =
26525       DAG.getSetFPEnv(LdNode->getChain(), SDLoc(N), LdNode->getBasePtr(), MemVT,
26526                       LdNode->getMemOperand());
26527   return Res;
26528 }
26529 
26530 /// Returns a vector_shuffle if it is able to transform an AND to a
26531 /// vector_shuffle with the destination vector and a zero vector.
26532 /// e.g. AND V, <0xffffffff, 0, 0xffffffff, 0> ==>
26533 ///      vector_shuffle V, Zero, <0, 4, 2, 4>
26534 SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) {
26535   assert(N->getOpcode() == ISD::AND && "Unexpected opcode!");
26536 
26537   EVT VT = N->getValueType(0);
26538   SDValue LHS = N->getOperand(0);
26539   SDValue RHS = peekThroughBitcasts(N->getOperand(1));
26540   SDLoc DL(N);
26541 
26542   // Make sure we're not running after operation legalization where it
26543   // may have custom lowered the vector shuffles.
26544   if (LegalOperations)
26545     return SDValue();
26546 
26547   if (RHS.getOpcode() != ISD::BUILD_VECTOR)
26548     return SDValue();
26549 
26550   EVT RVT = RHS.getValueType();
26551   unsigned NumElts = RHS.getNumOperands();
26552 
26553   // Attempt to create a valid clear mask by splitting the mask into
26554   // sub-elements and checking that each is either all zeros or all
26555   // ones - suitable for shuffle masking.
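      // e.g. with Split == 2 (and assuming little-endian), the v2i64 mask
      // <0x00000000ffffffff, 0xffffffff00000000> yields the v4i32 shuffle
      // mask <0, 5, 6, 3>.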
26556   auto BuildClearMask = [&](int Split) {
26557     int NumSubElts = NumElts * Split;
26558     int NumSubBits = RVT.getScalarSizeInBits() / Split;
26559 
26560     SmallVector<int, 8> Indices;
26561     for (int i = 0; i != NumSubElts; ++i) {
26562       int EltIdx = i / Split;
26563       int SubIdx = i % Split;
26564       SDValue Elt = RHS.getOperand(EltIdx);
26565       // X & undef --> 0 (not undef). So this lane must be converted to choose
26566       // from the zero constant vector (same as if the element had all 0-bits).
26567       if (Elt.isUndef()) {
26568         Indices.push_back(i + NumSubElts);
26569         continue;
26570       }
26571 
26572       APInt Bits;
26573       if (auto *Cst = dyn_cast<ConstantSDNode>(Elt))
26574         Bits = Cst->getAPIntValue();
26575       else if (auto *CstFP = dyn_cast<ConstantFPSDNode>(Elt))
26576         Bits = CstFP->getValueAPF().bitcastToAPInt();
26577       else
26578         return SDValue();
26579 
26580       // Extract the sub element from the constant bit mask.
26581       if (DAG.getDataLayout().isBigEndian())
26582         Bits = Bits.extractBits(NumSubBits, (Split - SubIdx - 1) * NumSubBits);
26583       else
26584         Bits = Bits.extractBits(NumSubBits, SubIdx * NumSubBits);
26585 
26586       if (Bits.isAllOnes())
26587         Indices.push_back(i);
26588       else if (Bits == 0)
26589         Indices.push_back(i + NumSubElts);
26590       else
26591         return SDValue();
26592     }
26593 
26594     // Let's see if the target supports this vector_shuffle.
26595     EVT ClearSVT = EVT::getIntegerVT(*DAG.getContext(), NumSubBits);
26596     EVT ClearVT = EVT::getVectorVT(*DAG.getContext(), ClearSVT, NumSubElts);
26597     if (!TLI.isVectorClearMaskLegal(Indices, ClearVT))
26598       return SDValue();
26599 
26600     SDValue Zero = DAG.getConstant(0, DL, ClearVT);
26601     return DAG.getBitcast(VT, DAG.getVectorShuffle(ClearVT, DL,
26602                                                    DAG.getBitcast(ClearVT, LHS),
26603                                                    Zero, Indices));
26604   };
26605 
26606   // Determine maximum split level (byte level masking).
26607   int MaxSplit = 1;
26608   if (RVT.getScalarSizeInBits() % 8 == 0)
26609     MaxSplit = RVT.getScalarSizeInBits() / 8;
26610 
26611   for (int Split = 1; Split <= MaxSplit; ++Split)
26612     if (RVT.getScalarSizeInBits() % Split == 0)
26613       if (SDValue S = BuildClearMask(Split))
26614         return S;
26615 
26616   return SDValue();
26617 }
26618 
26619 /// If a vector binop is performed on splat values, it may be profitable to
26620 /// extract, scalarize, and insert/splat.
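      /// e.g. (v4i32 add (splat X), (splat Y)) --> splat (i32 add X, Y)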
26621 static SDValue scalarizeBinOpOfSplats(SDNode *N, SelectionDAG &DAG,
26622                                       const SDLoc &DL) {
26623   SDValue N0 = N->getOperand(0);
26624   SDValue N1 = N->getOperand(1);
26625   unsigned Opcode = N->getOpcode();
26626   EVT VT = N->getValueType(0);
26627   EVT EltVT = VT.getVectorElementType();
26628   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
26629 
26630   // TODO: Remove/replace the extract cost check? If the elements are available
26631   //       as scalars, then there may be no extract cost. Should we ask if
26632   //       inserting a scalar back into a vector is cheap instead?
26633   int Index0, Index1;
26634   SDValue Src0 = DAG.getSplatSourceVector(N0, Index0);
26635   SDValue Src1 = DAG.getSplatSourceVector(N1, Index1);
26636   // Extracting an element from a splat_vector should be free.
26637   // TODO: use DAG.isSplatValue instead?
26638   bool IsBothSplatVector = N0.getOpcode() == ISD::SPLAT_VECTOR &&
26639                            N1.getOpcode() == ISD::SPLAT_VECTOR;
26640   if (!Src0 || !Src1 || Index0 != Index1 ||
26641       Src0.getValueType().getVectorElementType() != EltVT ||
26642       Src1.getValueType().getVectorElementType() != EltVT ||
26643       !(IsBothSplatVector || TLI.isExtractVecEltCheap(VT, Index0)) ||
26644       !TLI.isOperationLegalOrCustom(Opcode, EltVT))
26645     return SDValue();
26646 
26647   SDValue IndexC = DAG.getVectorIdxConstant(Index0, DL);
26648   SDValue X = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Src0, IndexC);
26649   SDValue Y = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Src1, IndexC);
26650   SDValue ScalarBO = DAG.getNode(Opcode, DL, EltVT, X, Y, N->getFlags());
26651 
26652   // If all lanes but 1 are undefined, no need to splat the scalar result.
26653   // TODO: Keep track of undefs and use that info in the general case.
26654   if (N0.getOpcode() == ISD::BUILD_VECTOR && N0.getOpcode() == N1.getOpcode() &&
26655       count_if(N0->ops(), [](SDValue V) { return !V.isUndef(); }) == 1 &&
26656       count_if(N1->ops(), [](SDValue V) { return !V.isUndef(); }) == 1) {
26657     // bo (build_vec ..undef, X, undef...), (build_vec ..undef, Y, undef...) -->
26658     // build_vec ..undef, (bo X, Y), undef...
26659     SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), DAG.getUNDEF(EltVT));
26660     Ops[Index0] = ScalarBO;
26661     return DAG.getBuildVector(VT, DL, Ops);
26662   }
26663 
26664   // bo (splat X, Index), (splat Y, Index) --> splat (bo X, Y), Index
26665   return DAG.getSplat(VT, DL, ScalarBO);
26666 }
26667 
26668 /// Visit a vector cast operation, like FP_EXTEND.
26669 SDValue DAGCombiner::SimplifyVCastOp(SDNode *N, const SDLoc &DL) {
26670   EVT VT = N->getValueType(0);
26671   assert(VT.isVector() && "SimplifyVCastOp only works on vectors!");
26672   EVT EltVT = VT.getVectorElementType();
26673   unsigned Opcode = N->getOpcode();
26674 
26675   SDValue N0 = N->getOperand(0);
26676   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
26677 
26678   // TODO: Promoting the operation might also be good here?
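      // e.g. (v4f64 fp_extend (v4f32 splat X))
      //        --> (v4f64 splat (f64 fp_extend X))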
26679   int Index0;
26680   SDValue Src0 = DAG.getSplatSourceVector(N0, Index0);
26681   if (Src0 &&
26682       (N0.getOpcode() == ISD::SPLAT_VECTOR ||
26683        TLI.isExtractVecEltCheap(VT, Index0)) &&
26684       TLI.isOperationLegalOrCustom(Opcode, EltVT) &&
26685       TLI.preferScalarizeSplat(N)) {
26686     EVT SrcVT = N0.getValueType();
26687     EVT SrcEltVT = SrcVT.getVectorElementType();
26688     SDValue IndexC = DAG.getVectorIdxConstant(Index0, DL);
26689     SDValue Elt =
26690         DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, SrcEltVT, Src0, IndexC);
26691     SDValue ScalarBO = DAG.getNode(Opcode, DL, EltVT, Elt, N->getFlags());
26692     if (VT.isScalableVector())
26693       return DAG.getSplatVector(VT, DL, ScalarBO);
26694     SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), ScalarBO);
26695     return DAG.getBuildVector(VT, DL, Ops);
26696   }
26697 
26698   return SDValue();
26699 }
26700 
26701 /// Visit a binary vector operation, like ADD.
26702 SDValue DAGCombiner::SimplifyVBinOp(SDNode *N, const SDLoc &DL) {
26703   EVT VT = N->getValueType(0);
26704   assert(VT.isVector() && "SimplifyVBinOp only works on vectors!");
26705 
26706   SDValue LHS = N->getOperand(0);
26707   SDValue RHS = N->getOperand(1);
26708   unsigned Opcode = N->getOpcode();
26709   SDNodeFlags Flags = N->getFlags();
26710 
26711   // Move unary shuffles with identical masks after a vector binop:
26712   // VBinOp (shuffle A, Undef, Mask), (shuffle B, Undef, Mask))
26713   //   --> shuffle (VBinOp A, B), Undef, Mask
26714   // This does not require type legality checks because we are creating the
26715   // same types of operations that are in the original sequence. We do have to
26716   // restrict ops like integer div that have immediate UB (eg, div-by-zero)
26717   // though. This code is adapted from the identical transform in instcombine.
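      // e.g. add (shuffle A, undef, <1,0,3,2>), (shuffle B, undef, <1,0,3,2>)
      //        --> shuffle (add A, B), undef, <1,0,3,2>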
26718   if (DAG.isSafeToSpeculativelyExecute(Opcode)) {
26719     auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(LHS);
26720     auto *Shuf1 = dyn_cast<ShuffleVectorSDNode>(RHS);
26721     if (Shuf0 && Shuf1 && Shuf0->getMask().equals(Shuf1->getMask()) &&
26722         LHS.getOperand(1).isUndef() && RHS.getOperand(1).isUndef() &&
26723         (LHS.hasOneUse() || RHS.hasOneUse() || LHS == RHS)) {
26724       SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, LHS.getOperand(0),
26725                                      RHS.getOperand(0), Flags);
26726       SDValue UndefV = LHS.getOperand(1);
26727       return DAG.getVectorShuffle(VT, DL, NewBinOp, UndefV, Shuf0->getMask());
26728     }
26729 
26730     // Try to sink a splat shuffle after a binop with a uniform constant.
26731     // This is limited to cases where neither the shuffle nor the constant have
26732     // undefined elements because that could be poison-unsafe or inhibit
26733     // demanded elements analysis. It is further limited to not change a splat
26734     // of an inserted scalar because that may be optimized better by
26735     // load-folding or other target-specific behaviors.
26736     if (isConstOrConstSplat(RHS) && Shuf0 && all_equal(Shuf0->getMask()) &&
26737         Shuf0->hasOneUse() && Shuf0->getOperand(1).isUndef() &&
26738         Shuf0->getOperand(0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
26739       // binop (splat X), (splat C) --> splat (binop X, C)
26740       SDValue X = Shuf0->getOperand(0);
26741       SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, X, RHS, Flags);
26742       return DAG.getVectorShuffle(VT, DL, NewBinOp, DAG.getUNDEF(VT),
26743                                   Shuf0->getMask());
26744     }
26745     if (isConstOrConstSplat(LHS) && Shuf1 && all_equal(Shuf1->getMask()) &&
26746         Shuf1->hasOneUse() && Shuf1->getOperand(1).isUndef() &&
26747         Shuf1->getOperand(0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
26748       // binop (splat C), (splat X) --> splat (binop C, X)
26749       SDValue X = Shuf1->getOperand(0);
26750       SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, LHS, X, Flags);
26751       return DAG.getVectorShuffle(VT, DL, NewBinOp, DAG.getUNDEF(VT),
26752                                   Shuf1->getMask());
26753     }
26754   }
26755 
26756   // The following pattern is likely to emerge with vector reduction ops. Moving
26757   // the binary operation ahead of insertion may allow using a narrower vector
26758   // instruction that has better performance than the wide version of the op:
26759   // VBinOp (ins undef, X, Z), (ins undef, Y, Z) --> ins VecC, (VBinOp X, Y), Z
26760   if (LHS.getOpcode() == ISD::INSERT_SUBVECTOR && LHS.getOperand(0).isUndef() &&
26761       RHS.getOpcode() == ISD::INSERT_SUBVECTOR && RHS.getOperand(0).isUndef() &&
26762       LHS.getOperand(2) == RHS.getOperand(2) &&
26763       (LHS.hasOneUse() || RHS.hasOneUse())) {
26764     SDValue X = LHS.getOperand(1);
26765     SDValue Y = RHS.getOperand(1);
26766     SDValue Z = LHS.getOperand(2);
26767     EVT NarrowVT = X.getValueType();
26768     if (NarrowVT == Y.getValueType() &&
26769         TLI.isOperationLegalOrCustomOrPromote(Opcode, NarrowVT,
26770                                               LegalOperations)) {
26771       // (binop undef, undef) may not return undef, so compute that result.
26772       SDValue VecC =
26773           DAG.getNode(Opcode, DL, VT, DAG.getUNDEF(VT), DAG.getUNDEF(VT));
26774       SDValue NarrowBO = DAG.getNode(Opcode, DL, NarrowVT, X, Y);
26775       return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, VecC, NarrowBO, Z);
26776     }
26777   }
26778 
26779   // Make sure all but the first op are undef or constant.
26780   auto ConcatWithConstantOrUndef = [](SDValue Concat) {
26781     return Concat.getOpcode() == ISD::CONCAT_VECTORS &&
26782            all_of(drop_begin(Concat->ops()), [](const SDValue &Op) {
26783              return Op.isUndef() ||
26784                     ISD::isBuildVectorOfConstantSDNodes(Op.getNode());
26785            });
26786   };
26787 
26788   // The following pattern is likely to emerge with vector reduction ops. Moving
26789   // the binary operation ahead of the concat may allow using a narrower vector
26790   // instruction that has better performance than the wide version of the op:
26791   // VBinOp (concat X, undef/constant), (concat Y, undef/constant) -->
26792   //   concat (VBinOp X, Y), VecC
26793   if (ConcatWithConstantOrUndef(LHS) && ConcatWithConstantOrUndef(RHS) &&
26794       (LHS.hasOneUse() || RHS.hasOneUse())) {
26795     EVT NarrowVT = LHS.getOperand(0).getValueType();
26796     if (NarrowVT == RHS.getOperand(0).getValueType() &&
26797         TLI.isOperationLegalOrCustomOrPromote(Opcode, NarrowVT)) {
26798       unsigned NumOperands = LHS.getNumOperands();
26799       SmallVector<SDValue, 4> ConcatOps;
26800       for (unsigned i = 0; i != NumOperands; ++i) {
26801         // This constant-folds for operands 1 and up.
26802         ConcatOps.push_back(DAG.getNode(Opcode, DL, NarrowVT, LHS.getOperand(i),
26803                                         RHS.getOperand(i)));
26804       }
26805 
26806       return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
26807     }
26808   }
26809 
26810   if (SDValue V = scalarizeBinOpOfSplats(N, DAG, DL))
26811     return V;
26812 
26813   return SDValue();
26814 }
26815 
26816 SDValue DAGCombiner::SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1,
26817                                     SDValue N2) {
26818   assert(N0.getOpcode() == ISD::SETCC &&
26819          "First argument must be a SetCC node!");
26820 
26821   SDValue SCC = SimplifySelectCC(DL, N0.getOperand(0), N0.getOperand(1), N1, N2,
26822                                  cast<CondCodeSDNode>(N0.getOperand(2))->get());
26823 
26824   // If we got a simplified select_cc node back from SimplifySelectCC, then
26825   // break it down into a new SETCC node, and a new SELECT node, and then return
26826   // the SELECT node, since we were called with a SELECT node.
26827   if (SCC.getNode()) {
26828     // Check to see if we got a select_cc back (to turn into setcc/select).
26829     // Otherwise, just return whatever node we got back, like fabs.
26830     if (SCC.getOpcode() == ISD::SELECT_CC) {
26831       const SDNodeFlags Flags = N0->getFlags();
26832       SDValue SETCC = DAG.getNode(ISD::SETCC, SDLoc(N0),
26833                                   N0.getValueType(),
26834                                   SCC.getOperand(0), SCC.getOperand(1),
26835                                   SCC.getOperand(4), Flags);
26836       AddToWorklist(SETCC.getNode());
26837       SDValue SelectNode = DAG.getSelect(SDLoc(SCC), SCC.getValueType(), SETCC,
26838                                          SCC.getOperand(2), SCC.getOperand(3));
26839       SelectNode->setFlags(Flags);
26840       return SelectNode;
26841     }
26842 
26843     return SCC;
26844   }
26845   return SDValue();
26846 }
26847 
26848 /// Given a SELECT or a SELECT_CC node, where LHS and RHS are the two values
26849 /// being selected between, see if we can simplify the select.  Callers of this
26850 /// should assume that TheSelect is deleted if this returns true.  As such, they
26851 /// should return the appropriate thing (e.g. the node) back to the top-level of
26852 /// the DAG combiner loop to avoid it being looked at.
26853 bool DAGCombiner::SimplifySelectOps(SDNode *TheSelect, SDValue LHS,
26854                                     SDValue RHS) {
26855   // fold (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
26856   // The select + setcc is redundant, because fsqrt returns NaN for X < 0.
26857   if (const ConstantFPSDNode *NaN = isConstOrConstSplatFP(LHS)) {
26858     if (NaN->isNaN() && RHS.getOpcode() == ISD::FSQRT) {
26859       // We have: (select (setcc ?, ?, ?), NaN, (fsqrt ?))
26860       SDValue Sqrt = RHS;
26861       ISD::CondCode CC;
26862       SDValue CmpLHS;
26863       const ConstantFPSDNode *Zero = nullptr;
26864 
26865       if (TheSelect->getOpcode() == ISD::SELECT_CC) {
26866         CC = cast<CondCodeSDNode>(TheSelect->getOperand(4))->get();
26867         CmpLHS = TheSelect->getOperand(0);
26868         Zero = isConstOrConstSplatFP(TheSelect->getOperand(1));
26869       } else {
26870         // SELECT or VSELECT
26871         SDValue Cmp = TheSelect->getOperand(0);
26872         if (Cmp.getOpcode() == ISD::SETCC) {
26873           CC = cast<CondCodeSDNode>(Cmp.getOperand(2))->get();
26874           CmpLHS = Cmp.getOperand(0);
26875           Zero = isConstOrConstSplatFP(Cmp.getOperand(1));
26876         }
26877       }
26878       if (Zero && Zero->isZero() &&
26879           Sqrt.getOperand(0) == CmpLHS && (CC == ISD::SETOLT ||
26880           CC == ISD::SETULT || CC == ISD::SETLT)) {
26881         // We have: (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
26882         CombineTo(TheSelect, Sqrt);
26883         return true;
26884       }
26885     }
26886   }
26887   // Cannot simplify select with vector condition
26888   if (TheSelect->getOperand(0).getValueType().isVector()) return false;
26889 
26890   // If this is a select from two identical things, try to pull the operation
26891   // through the select.
26892   if (LHS.getOpcode() != RHS.getOpcode() ||
26893       !LHS.hasOneUse() || !RHS.hasOneUse())
26894     return false;
26895 
26896   // If this is a load and the token chain is identical, replace the select
26897   // of two loads with a load through a select of the address to load from.
26898   // This triggers in things like "select bool X, 10.0, 123.0" after the FP
26899   // constants have been dropped into the constant pool.
26900   if (LHS.getOpcode() == ISD::LOAD) {
26901     LoadSDNode *LLD = cast<LoadSDNode>(LHS);
26902     LoadSDNode *RLD = cast<LoadSDNode>(RHS);
26903 
26904     // Token chains must be identical.
26905     if (LHS.getOperand(0) != RHS.getOperand(0) ||
26906         // Do not let this transformation reduce the number of volatile loads.
26907         // Be conservative for atomics for the moment
26908         // TODO: This does appear to be legal for unordered atomics (see D66309)
26909         !LLD->isSimple() || !RLD->isSimple() ||
26910         // FIXME: If either is a pre/post inc/dec load,
26911         // we'd need to split out the address adjustment.
26912         LLD->isIndexed() || RLD->isIndexed() ||
26913         // If this is an EXTLOAD, the VT's must match.
26914         LLD->getMemoryVT() != RLD->getMemoryVT() ||
26915         // If this is an EXTLOAD, the kind of extension must match.
26916         (LLD->getExtensionType() != RLD->getExtensionType() &&
26917          // The only exception is if one of the extensions is anyext.
26918          LLD->getExtensionType() != ISD::EXTLOAD &&
26919          RLD->getExtensionType() != ISD::EXTLOAD) ||
26920         // FIXME: this discards src value information.  This is
26921         // over-conservative. It would be beneficial to be able to remember
26922         // both potential memory locations.  Since we are discarding
26923         // src value info, don't do the transformation if the memory
26924         // locations are not in the default address space.
26925         LLD->getPointerInfo().getAddrSpace() != 0 ||
26926         RLD->getPointerInfo().getAddrSpace() != 0 ||
26927         // We can't produce a CMOV of a TargetFrameIndex since we won't
26928         // generate the address generation required.
26929         LLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
26930         RLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
26931         !TLI.isOperationLegalOrCustom(TheSelect->getOpcode(),
26932                                       LLD->getBasePtr().getValueType()))
26933       return false;
26934 
26935     // The loads must not depend on one another.
26936     if (LLD->isPredecessorOf(RLD) || RLD->isPredecessorOf(LLD))
26937       return false;
26938 
26939     // Check that the select condition doesn't reach either load.  If so,
26940     // folding this will induce a cycle into the DAG.  If not, this is safe to
26941     // xform, so create a select of the addresses.
26942 
26943     SmallPtrSet<const SDNode *, 32> Visited;
26944     SmallVector<const SDNode *, 16> Worklist;
26945 
26946     // Always fail if LLD and RLD are not independent. TheSelect is a
26947     // predecessor to all Nodes in question so we need not search past it.
26948 
26949     Visited.insert(TheSelect);
26950     Worklist.push_back(LLD);
26951     Worklist.push_back(RLD);
26952 
26953     if (SDNode::hasPredecessorHelper(LLD, Visited, Worklist) ||
26954         SDNode::hasPredecessorHelper(RLD, Visited, Worklist))
26955       return false;
26956 
26957     SDValue Addr;
26958     if (TheSelect->getOpcode() == ISD::SELECT) {
26959       // We cannot do this optimization if any pair of {RLD, LLD} is a
26960       // predecessor to {RLD, LLD, CondNode}. As we've already compared the
26961       // Loads, we only need to check if CondNode is a successor to one of the
26962       // loads. We can further avoid this if there's no use of their chain
26963       // value.
26964       SDNode *CondNode = TheSelect->getOperand(0).getNode();
26965       Worklist.push_back(CondNode);
26966 
26967       if ((LLD->hasAnyUseOfValue(1) &&
26968            SDNode::hasPredecessorHelper(LLD, Visited, Worklist)) ||
26969           (RLD->hasAnyUseOfValue(1) &&
26970            SDNode::hasPredecessorHelper(RLD, Visited, Worklist)))
26971         return false;
26972 
26973       Addr = DAG.getSelect(SDLoc(TheSelect),
26974                            LLD->getBasePtr().getValueType(),
26975                            TheSelect->getOperand(0), LLD->getBasePtr(),
26976                            RLD->getBasePtr());
26977     } else {  // Otherwise SELECT_CC
26978       // We cannot do this optimization if any pair of {RLD, LLD} is a
26979       // predecessor to {RLD, LLD, CondLHS, CondRHS}. As we've already compared
26980       // the Loads, we only need to check if CondLHS/CondRHS is a successor to
26981       // one of the loads. We can further avoid this if there's no use of their
26982       // chain value.
26983 
26984       SDNode *CondLHS = TheSelect->getOperand(0).getNode();
26985       SDNode *CondRHS = TheSelect->getOperand(1).getNode();
26986       Worklist.push_back(CondLHS);
26987       Worklist.push_back(CondRHS);
26988 
26989       if ((LLD->hasAnyUseOfValue(1) &&
26990            SDNode::hasPredecessorHelper(LLD, Visited, Worklist)) ||
26991           (RLD->hasAnyUseOfValue(1) &&
26992            SDNode::hasPredecessorHelper(RLD, Visited, Worklist)))
26993         return false;
26994 
26995       Addr = DAG.getNode(ISD::SELECT_CC, SDLoc(TheSelect),
26996                          LLD->getBasePtr().getValueType(),
26997                          TheSelect->getOperand(0),
26998                          TheSelect->getOperand(1),
26999                          LLD->getBasePtr(), RLD->getBasePtr(),
27000                          TheSelect->getOperand(4));
27001     }
27002 
27003     SDValue Load;
27004     // It is safe to replace the two loads if they have different alignments,
27005     // but the new load must use the minimum (most restrictive) alignment of
27006     // the inputs.
27007     Align Alignment = std::min(LLD->getAlign(), RLD->getAlign());
27008     MachineMemOperand::Flags MMOFlags = LLD->getMemOperand()->getFlags();
27009     if (!RLD->isInvariant())
27010       MMOFlags &= ~MachineMemOperand::MOInvariant;
27011     if (!RLD->isDereferenceable())
27012       MMOFlags &= ~MachineMemOperand::MODereferenceable;
27013     if (LLD->getExtensionType() == ISD::NON_EXTLOAD) {
27014       // FIXME: Discards pointer and AA info.
27015       Load = DAG.getLoad(TheSelect->getValueType(0), SDLoc(TheSelect),
27016                          LLD->getChain(), Addr, MachinePointerInfo(), Alignment,
27017                          MMOFlags);
27018     } else {
27019       // FIXME: Discards pointer and AA info.
27020       Load = DAG.getExtLoad(
27021           LLD->getExtensionType() == ISD::EXTLOAD ? RLD->getExtensionType()
27022                                                   : LLD->getExtensionType(),
27023           SDLoc(TheSelect), TheSelect->getValueType(0), LLD->getChain(), Addr,
27024           MachinePointerInfo(), LLD->getMemoryVT(), Alignment, MMOFlags);
27025     }
27026 
27027     // Users of the select now use the result of the load.
27028     CombineTo(TheSelect, Load);
27029 
27030     // Users of the old loads now use the new load's chain.  We know the
27031     // old-load value is dead now.
27032     CombineTo(LHS.getNode(), Load.getValue(0), Load.getValue(1));
27033     CombineTo(RHS.getNode(), Load.getValue(0), Load.getValue(1));
27034     return true;
27035   }
27036 
27037   return false;
27038 }
27039 
27040 /// Try to fold an expression of the form (N0 cond N1) ? N2 : N3 to a shift and
27041 /// bitwise 'and'.
27042 SDValue DAGCombiner::foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0,
27043                                             SDValue N1, SDValue N2, SDValue N3,
27044                                             ISD::CondCode CC) {
27045   // If this is a select where the false operand is zero and the compare is a
27046   // check of the sign bit, see if we can perform the "gzip trick":
27047   // select_cc setlt X, 0, A, 0 -> and (sra X, size(X)-1), A
27048   // select_cc setgt X, 0, A, 0 -> and (not (sra X, size(X)-1)), A
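      // e.g. for i32 with a single-bit true value:
      //   select_cc setlt X, 0, 8, 0 --> and (srl X, 28), 8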
27049   EVT XType = N0.getValueType();
27050   EVT AType = N2.getValueType();
27051   if (!isNullConstant(N3) || !XType.bitsGE(AType))
27052     return SDValue();
27053 
27054   // If the comparison is testing for a positive value, we have to invert
27055   // the sign bit mask, so only do that transform if the target has a bitwise
27056   // 'and not' instruction (the invert is free).
27057   if (CC == ISD::SETGT && TLI.hasAndNot(N2)) {
27058     // (X > -1) ? A : 0
27059     // (X >  0) ? X : 0 <-- This is canonical signed max.
27060     if (!(isAllOnesConstant(N1) || (isNullConstant(N1) && N0 == N2)))
27061       return SDValue();
27062   } else if (CC == ISD::SETLT) {
27063     // (X <  0) ? A : 0
27064     // (X <  1) ? X : 0 <-- This is un-canonicalized signed min.
27065     if (!(isNullConstant(N1) || (isOneConstant(N1) && N0 == N2)))
27066       return SDValue();
27067   } else {
27068     return SDValue();
27069   }
27070 
27071   // and (sra X, size(X)-1), A -> "and (srl X, C2), A" iff A is a single-bit
27072   // constant.
27073   EVT ShiftAmtTy = getShiftAmountTy(N0.getValueType());
27074   auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
27075   if (N2C && ((N2C->getAPIntValue() & (N2C->getAPIntValue() - 1)) == 0)) {
27076     unsigned ShCt = XType.getSizeInBits() - N2C->getAPIntValue().logBase2() - 1;
27077     if (!TLI.shouldAvoidTransformToShift(XType, ShCt)) {
27078       SDValue ShiftAmt = DAG.getConstant(ShCt, DL, ShiftAmtTy);
27079       SDValue Shift = DAG.getNode(ISD::SRL, DL, XType, N0, ShiftAmt);
27080       AddToWorklist(Shift.getNode());
27081 
27082       if (XType.bitsGT(AType)) {
27083         Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift);
27084         AddToWorklist(Shift.getNode());
27085       }
27086 
27087       if (CC == ISD::SETGT)
27088         Shift = DAG.getNOT(DL, Shift, AType);
27089 
27090       return DAG.getNode(ISD::AND, DL, AType, Shift, N2);
27091     }
27092   }
27093 
27094   unsigned ShCt = XType.getSizeInBits() - 1;
27095   if (TLI.shouldAvoidTransformToShift(XType, ShCt))
27096     return SDValue();
27097 
27098   SDValue ShiftAmt = DAG.getConstant(ShCt, DL, ShiftAmtTy);
27099   SDValue Shift = DAG.getNode(ISD::SRA, DL, XType, N0, ShiftAmt);
27100   AddToWorklist(Shift.getNode());
27101 
27102   if (XType.bitsGT(AType)) {
27103     Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift);
27104     AddToWorklist(Shift.getNode());
27105   }
27106 
27107   if (CC == ISD::SETGT)
27108     Shift = DAG.getNOT(DL, Shift, AType);
27109 
27110   return DAG.getNode(ISD::AND, DL, AType, Shift, N2);
27111 }
27112 
27113 // Fold select(cc, binop(), binop()) -> binop(select(), select()) etc.
27114 SDValue DAGCombiner::foldSelectOfBinops(SDNode *N) {
27115   SDValue N0 = N->getOperand(0);
27116   SDValue N1 = N->getOperand(1);
27117   SDValue N2 = N->getOperand(2);
27118   SDLoc DL(N);
27119 
27120   unsigned BinOpc = N1.getOpcode();
27121   if (!TLI.isBinOp(BinOpc) || (N2.getOpcode() != BinOpc) ||
27122       (N1.getResNo() != N2.getResNo()))
27123     return SDValue();
27124 
27125   // The use checks are intentionally on SDNode because we may be dealing
27126   // with opcodes that produce more than one SDValue.
27127   // TODO: Do we really need to check N0 (the condition operand of the select)?
27128   //       But removing that clause could cause an infinite loop...
27129   if (!N0->hasOneUse() || !N1->hasOneUse() || !N2->hasOneUse())
27130     return SDValue();
27131 
27132   // Binops may include opcodes that return multiple values, so all values
27133   // must be created/propagated from the newly created binops below.
27134   SDVTList OpVTs = N1->getVTList();
27135 
27136   // Fold select(cond, binop(x, y), binop(z, y))
27137   //  --> binop(select(cond, x, z), y)
27138   if (N1.getOperand(1) == N2.getOperand(1)) {
27139     SDValue N10 = N1.getOperand(0);
27140     SDValue N20 = N2.getOperand(0);
27141     SDValue NewSel = DAG.getSelect(DL, N10.getValueType(), N0, N10, N20);
27142     SDValue NewBinOp = DAG.getNode(BinOpc, DL, OpVTs, NewSel, N1.getOperand(1));
27143     NewBinOp->setFlags(N1->getFlags());
27144     NewBinOp->intersectFlagsWith(N2->getFlags());
27145     return SDValue(NewBinOp.getNode(), N1.getResNo());
27146   }
27147 
27148   // Fold select(cond, binop(x, y), binop(x, z))
27149   //  --> binop(x, select(cond, y, z))
27150   if (N1.getOperand(0) == N2.getOperand(0)) {
27151     SDValue N11 = N1.getOperand(1);
27152     SDValue N21 = N2.getOperand(1);
27153     // Second op VT might be different (e.g. shift amount type)
27154     if (N11.getValueType() == N21.getValueType()) {
27155       SDValue NewSel = DAG.getSelect(DL, N11.getValueType(), N0, N11, N21);
27156       SDValue NewBinOp =
27157           DAG.getNode(BinOpc, DL, OpVTs, N1.getOperand(0), NewSel);
27158       NewBinOp->setFlags(N1->getFlags());
27159       NewBinOp->intersectFlagsWith(N2->getFlags());
27160       return SDValue(NewBinOp.getNode(), N1.getResNo());
27161     }
27162   }
27163 
27164   // TODO: Handle isCommutativeBinOp patterns as well?
27165   return SDValue();
27166 }
27167 
27168 // Transform (fneg/fabs (bitconvert x)) to avoid loading constant pool values.
27169 SDValue DAGCombiner::foldSignChangeInBitcast(SDNode *N) {
27170   SDValue N0 = N->getOperand(0);
27171   EVT VT = N->getValueType(0);
27172   bool IsFabs = N->getOpcode() == ISD::FABS;
27173   bool IsFree = IsFabs ? TLI.isFAbsFree(VT) : TLI.isFNegFree(VT);
27174 
27175   if (IsFree || N0.getOpcode() != ISD::BITCAST || !N0.hasOneUse())
27176     return SDValue();
27177 
27178   SDValue Int = N0.getOperand(0);
27179   EVT IntVT = Int.getValueType();
27180 
27181   // The operand of the cast should be an integer scalar.
27182   if (!IntVT.isInteger() || IntVT.isVector())
27183     return SDValue();
27184 
27185   // (fneg (bitconvert x)) -> (bitconvert (xor x sign))
27186   // (fabs (bitconvert x)) -> (bitconvert (and x ~sign))
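      // e.g. for an f32 payload: sign mask 0x80000000 (fneg) and inverse
      // 0x7fffffff (fabs).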
27187   APInt SignMask;
27188   if (N0.getValueType().isVector()) {
27189     // For vector, create a sign mask (0x80...) or its inverse (for fabs,
27190     // 0x7f...) per element and splat it.
27191     SignMask = APInt::getSignMask(N0.getScalarValueSizeInBits());
27192     if (IsFabs)
27193       SignMask = ~SignMask;
27194     SignMask = APInt::getSplat(IntVT.getSizeInBits(), SignMask);
27195   } else {
27196     // For scalar, just use the sign mask (0x80... or the inverse, 0x7f...)
27197     SignMask = APInt::getSignMask(IntVT.getSizeInBits());
27198     if (IsFabs)
27199       SignMask = ~SignMask;
27200   }
27201   SDLoc DL(N0);
27202   Int = DAG.getNode(IsFabs ? ISD::AND : ISD::XOR, DL, IntVT, Int,
27203                     DAG.getConstant(SignMask, DL, IntVT));
27204   AddToWorklist(Int.getNode());
27205   return DAG.getBitcast(VT, Int);
27206 }
27207 
27208 /// Turn "(a cond b) ? 1.0f : 2.0f" into "load (tmp + ((a cond b) ? 0 : 4))"
27209 /// where "tmp" is a constant pool entry containing an array with 1.0 and 2.0
27210 /// in it. This may be a win when the constant is not otherwise available
27211 /// because it replaces two constant pool loads with one.
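      /// e.g. for f32, "(a < b) ? 1.0f : 2.0f" becomes a load from a constant
      /// pool array [2.0f, 1.0f] at offset ((a < b) ? 4 : 0).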
27212 SDValue DAGCombiner::convertSelectOfFPConstantsToLoadOffset(
27213     const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
27214     ISD::CondCode CC) {
27215   if (!TLI.reduceSelectOfFPConstantLoads(N0.getValueType()))
27216     return SDValue();
27217 
27218   // If we are before legalize types, we want the other legalization to happen
27219   // first (for example, to avoid messing with soft float).
27220   auto *TV = dyn_cast<ConstantFPSDNode>(N2);
27221   auto *FV = dyn_cast<ConstantFPSDNode>(N3);
27222   EVT VT = N2.getValueType();
27223   if (!TV || !FV || !TLI.isTypeLegal(VT))
27224     return SDValue();
27225 
27226   // If a constant can be materialized without loads, this does not make sense.
27227   if (TLI.getOperationAction(ISD::ConstantFP, VT) == TargetLowering::Legal ||
27228       TLI.isFPImmLegal(TV->getValueAPF(), TV->getValueType(0), ForCodeSize) ||
27229       TLI.isFPImmLegal(FV->getValueAPF(), FV->getValueType(0), ForCodeSize))
27230     return SDValue();
27231 
27232   // If both constants have multiple uses, then we won't need to do an extra
27233   // load. The values are likely around in registers for other users.
27234   if (!TV->hasOneUse() && !FV->hasOneUse())
27235     return SDValue();
27236 
27237   Constant *Elts[] = { const_cast<ConstantFP*>(FV->getConstantFPValue()),
27238                        const_cast<ConstantFP*>(TV->getConstantFPValue()) };
27239   Type *FPTy = Elts[0]->getType();
27240   const DataLayout &TD = DAG.getDataLayout();
27241 
27242   // Create a ConstantArray of the two constants.
27243   Constant *CA = ConstantArray::get(ArrayType::get(FPTy, 2), Elts);
27244   SDValue CPIdx = DAG.getConstantPool(CA, TLI.getPointerTy(DAG.getDataLayout()),
27245                                       TD.getPrefTypeAlign(FPTy));
27246   Align Alignment = cast<ConstantPoolSDNode>(CPIdx)->getAlign();
27247 
27248   // Get offsets to the 0 and 1 elements of the array, so we can select between
27249   // them.
27250   SDValue Zero = DAG.getIntPtrConstant(0, DL);
27251   unsigned EltSize = (unsigned)TD.getTypeAllocSize(Elts[0]->getType());
27252   SDValue One = DAG.getIntPtrConstant(EltSize, SDLoc(FV));
27253   SDValue Cond =
27254       DAG.getSetCC(DL, getSetCCResultType(N0.getValueType()), N0, N1, CC);
27255   AddToWorklist(Cond.getNode());
27256   SDValue CstOffset = DAG.getSelect(DL, Zero.getValueType(), Cond, One, Zero);
27257   AddToWorklist(CstOffset.getNode());
27258   CPIdx = DAG.getNode(ISD::ADD, DL, CPIdx.getValueType(), CPIdx, CstOffset);
27259   AddToWorklist(CPIdx.getNode());
27260   return DAG.getLoad(TV->getValueType(0), DL, DAG.getEntryNode(), CPIdx,
27261                      MachinePointerInfo::getConstantPool(
27262                          DAG.getMachineFunction()), Alignment);
27263 }
27264 
27265 /// Simplify an expression of the form (N0 cond N1) ? N2 : N3
27266 /// where 'cond' is the comparison specified by CC.
27267 SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
27268                                       SDValue N2, SDValue N3, ISD::CondCode CC,
27269                                       bool NotExtCompare) {
27270   // (x ? y : y) -> y.
27271   if (N2 == N3) return N2;
27272 
27273   EVT CmpOpVT = N0.getValueType();
27274   EVT CmpResVT = getSetCCResultType(CmpOpVT);
27275   EVT VT = N2.getValueType();
27276   auto *N1C = dyn_cast<ConstantSDNode>(N1.getNode());
27277   auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
27278   auto *N3C = dyn_cast<ConstantSDNode>(N3.getNode());
27279 
27280   // Determine if the condition we're dealing with is constant.
27281   if (SDValue SCC = DAG.FoldSetCC(CmpResVT, N0, N1, CC, DL)) {
27282     AddToWorklist(SCC.getNode());
27283     if (auto *SCCC = dyn_cast<ConstantSDNode>(SCC)) {
27284       // fold select_cc true, x, y -> x
27285       // fold select_cc false, x, y -> y
27286       return !(SCCC->isZero()) ? N2 : N3;
27287     }
27288   }
27289 
27290   if (SDValue V =
27291           convertSelectOfFPConstantsToLoadOffset(DL, N0, N1, N2, N3, CC))
27292     return V;
27293 
27294   if (SDValue V = foldSelectCCToShiftAnd(DL, N0, N1, N2, N3, CC))
27295     return V;
27296 
27297   // fold (select_cc seteq (and x, y), 0, 0, A) -> (and (sra (shl x)) A)
27298   // where y has a single bit set.
27299   // In plain terms: we can turn the SELECT_CC into an AND when the
27300   // condition can be materialized as an all-ones register.  Any single
27301   // bit-test can be materialized as an all-ones register with shift-left
27302   // and shift-right-arith.
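  // For example (i32, y == 8, i.e. bit 3): (shl x, 28) moves the tested bit
  // into the sign bit and (sra ..., 31) smears it, producing all-ones exactly
  // when (x & 8) != 0, so the AND with A yields the selected value.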
27303   if (CC == ISD::SETEQ && N0->getOpcode() == ISD::AND &&
27304       N0->getValueType(0) == VT && isNullConstant(N1) && isNullConstant(N2)) {
27305     SDValue AndLHS = N0->getOperand(0);
27306     auto *ConstAndRHS = dyn_cast<ConstantSDNode>(N0->getOperand(1));
27307     if (ConstAndRHS && ConstAndRHS->getAPIntValue().popcount() == 1) {
27308       // Shift the tested bit over the sign bit.
27309       const APInt &AndMask = ConstAndRHS->getAPIntValue();
27310       if (TLI.shouldFoldSelectWithSingleBitTest(VT, AndMask)) {
27311         unsigned ShCt = AndMask.getBitWidth() - 1;
27312         SDValue ShlAmt =
27313             DAG.getConstant(AndMask.countl_zero(), SDLoc(AndLHS),
27314                             getShiftAmountTy(AndLHS.getValueType()));
27315         SDValue Shl = DAG.getNode(ISD::SHL, SDLoc(N0), VT, AndLHS, ShlAmt);
27316 
27317         // Now arithmetic right shift it all the way over, so the result is
27318         // either all-ones, or zero.
27319         SDValue ShrAmt =
27320           DAG.getConstant(ShCt, SDLoc(Shl),
27321                           getShiftAmountTy(Shl.getValueType()));
27322         SDValue Shr = DAG.getNode(ISD::SRA, SDLoc(N0), VT, Shl, ShrAmt);
27323 
27324         return DAG.getNode(ISD::AND, DL, VT, Shr, N3);
27325       }
27326     }
27327   }
27328 
27329   // fold select C, 16, 0 -> shl C, 4
27330   bool Fold = N2C && isNullConstant(N3) && N2C->getAPIntValue().isPowerOf2();
27331   bool Swap = N3C && isNullConstant(N2) && N3C->getAPIntValue().isPowerOf2();
27332 
27333   if ((Fold || Swap) &&
27334       TLI.getBooleanContents(CmpOpVT) ==
27335           TargetLowering::ZeroOrOneBooleanContent &&
27336       (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, CmpOpVT))) {
27337 
27338     if (Swap) {
27339       CC = ISD::getSetCCInverse(CC, CmpOpVT);
27340       std::swap(N2C, N3C);
27341     }
27342 
27343     // If the caller doesn't want us to simplify this into a zext of a compare,
27344     // don't do it.
27345     if (NotExtCompare && N2C->isOne())
27346       return SDValue();
27347 
27348     SDValue Temp, SCC;
27349     // zext (setcc n0, n1)
27350     if (LegalTypes) {
27351       SCC = DAG.getSetCC(DL, CmpResVT, N0, N1, CC);
27352       Temp = DAG.getZExtOrTrunc(SCC, SDLoc(N2), VT);
27353     } else {
27354       SCC = DAG.getSetCC(SDLoc(N0), MVT::i1, N0, N1, CC);
27355       Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2), VT, SCC);
27356     }
27357 
27358     AddToWorklist(SCC.getNode());
27359     AddToWorklist(Temp.getNode());
27360 
27361     if (N2C->isOne())
27362       return Temp;
27363 
27364     unsigned ShCt = N2C->getAPIntValue().logBase2();
27365     if (TLI.shouldAvoidTransformToShift(VT, ShCt))
27366       return SDValue();
27367 
27368     // shl setcc result by log2 n2c
27369     return DAG.getNode(ISD::SHL, DL, N2.getValueType(), Temp,
27370                        DAG.getConstant(ShCt, SDLoc(Temp),
27371                                        getShiftAmountTy(Temp.getValueType())));
27372   }
27373 
27374   // select_cc seteq X, 0, sizeof(X), ctlz(X) -> ctlz(X)
27375   // select_cc seteq X, 0, sizeof(X), ctlz_zero_undef(X) -> ctlz(X)
27376   // select_cc seteq X, 0, sizeof(X), cttz(X) -> cttz(X)
27377   // select_cc seteq X, 0, sizeof(X), cttz_zero_undef(X) -> cttz(X)
27378   // select_cc setne X, 0, ctlz(X), sizeof(X) -> ctlz(X)
27379   // select_cc setne X, 0, ctlz_zero_undef(X), sizeof(X) -> ctlz(X)
27380   // select_cc setne X, 0, cttz(X), sizeof(X) -> cttz(X)
27381   // select_cc setne X, 0, cttz_zero_undef(X), sizeof(X) -> cttz(X)
27382   if (N1C && N1C->isZero() && (CC == ISD::SETEQ || CC == ISD::SETNE)) {
27383     SDValue ValueOnZero = N2;
27384     SDValue Count = N3;
27385     // If the condition is NE instead of EQ, swap the operands.
27386     if (CC == ISD::SETNE)
27387       std::swap(ValueOnZero, Count);
27388     // Check if the value on zero is a constant equal to the bit width of the type.
27389     if (auto *ValueOnZeroC = dyn_cast<ConstantSDNode>(ValueOnZero)) {
27390       if (ValueOnZeroC->getAPIntValue() == VT.getSizeInBits()) {
27391         // If the other operand is cttz/cttz_zero_undef of N0, and cttz is
27392         // legal, combine to just cttz.
27393         if ((Count.getOpcode() == ISD::CTTZ ||
27394              Count.getOpcode() == ISD::CTTZ_ZERO_UNDEF) &&
27395             N0 == Count.getOperand(0) &&
27396             (!LegalOperations || TLI.isOperationLegal(ISD::CTTZ, VT)))
27397           return DAG.getNode(ISD::CTTZ, DL, VT, N0);
27398         // If the other operand is ctlz/ctlz_zero_undef of N0, and ctlz is
27399         // legal, combine to just ctlz.
27400         if ((Count.getOpcode() == ISD::CTLZ ||
27401              Count.getOpcode() == ISD::CTLZ_ZERO_UNDEF) &&
27402             N0 == Count.getOperand(0) &&
27403             (!LegalOperations || TLI.isOperationLegal(ISD::CTLZ, VT)))
27404           return DAG.getNode(ISD::CTLZ, DL, VT, N0);
27405       }
27406     }
27407   }
27408 
27409   // Fold select_cc setgt X, -1, C, ~C -> xor (ashr X, BW-1), C
27410   // Fold select_cc setlt X, 0, C, ~C -> xor (ashr X, BW-1), ~C
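  // For example (i32, C == 5, so ~C == -6):
  //   select_cc setgt X, -1, 5, -6 -> xor (ashr X, 31), 5
  // since (ashr X, 31) is 0 for X >= 0 and all-ones for X < 0, and
  // xor(all-ones, 5) == ~5 == -6.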
27411   if (!NotExtCompare && N1C && N2C && N3C &&
27412       N2C->getAPIntValue() == ~N3C->getAPIntValue() &&
27413       ((N1C->isAllOnes() && CC == ISD::SETGT) ||
27414        (N1C->isZero() && CC == ISD::SETLT)) &&
27415       !TLI.shouldAvoidTransformToShift(VT, CmpOpVT.getScalarSizeInBits() - 1)) {
27416     SDValue ASR = DAG.getNode(
27417         ISD::SRA, DL, CmpOpVT, N0,
27418         DAG.getConstant(CmpOpVT.getScalarSizeInBits() - 1, DL, CmpOpVT));
27419     return DAG.getNode(ISD::XOR, DL, VT, DAG.getSExtOrTrunc(ASR, DL, VT),
27420                        DAG.getSExtOrTrunc(CC == ISD::SETLT ? N3 : N2, DL, VT));
27421   }
27422 
27423   if (SDValue S = PerformMinMaxFpToSatCombine(N0, N1, N2, N3, CC, DAG))
27424     return S;
27425   if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N2, N3, CC, DAG))
27426     return S;
27427 
27428   return SDValue();
27429 }
27430 
27431 /// This is a stub for TargetLowering::SimplifySetCC.
27432 SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
27433                                    ISD::CondCode Cond, const SDLoc &DL,
27434                                    bool foldBooleans) {
27435   TargetLowering::DAGCombinerInfo
27436     DagCombineInfo(DAG, Level, false, this);
27437   return TLI.SimplifySetCC(VT, N0, N1, Cond, foldBooleans, DagCombineInfo, DL);
27438 }
27439 
27440 /// Given an ISD::SDIV node expressing a divide by constant, return
27441 /// a DAG expression to select that will generate the same value by multiplying
27442 /// by a magic number.
27443 /// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
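/// For illustration only (i32, divisor 7; the exact node sequence is chosen
/// by TargetLowering::BuildSDIV and is target-dependent):
///   q = mulhs(x, 0x92492493); q = q + x; q = sra(q, 2); q = q + srl(x, 31);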
27444 SDValue DAGCombiner::BuildSDIV(SDNode *N) {
27445   // When optimising for minimum size, we don't want to expand a div to a mul
27446   // and a shift.
27447   if (DAG.getMachineFunction().getFunction().hasMinSize())
27448     return SDValue();
27449 
27450   SmallVector<SDNode *, 8> Built;
27451   if (SDValue S = TLI.BuildSDIV(N, DAG, LegalOperations, Built)) {
27452     for (SDNode *N : Built)
27453       AddToWorklist(N);
27454     return S;
27455   }
27456 
27457   return SDValue();
27458 }
27459 
27460 /// Given an ISD::SDIV node expressing a divide by constant power of 2, return a
27461 /// DAG expression that will generate the same value by right shifting.
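/// For illustration (i32): sdiv x, 4 becomes roughly
///   sra(x + srl(sra(x, 31), 30), 2)
/// where the inner shifts materialize the (divisor - 1) bias that rounds
/// negative dividends toward zero.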
27462 SDValue DAGCombiner::BuildSDIVPow2(SDNode *N) {
27463   ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1));
27464   if (!C)
27465     return SDValue();
27466 
27467   // Avoid division by zero.
27468   if (C->isZero())
27469     return SDValue();
27470 
27471   SmallVector<SDNode *, 8> Built;
27472   if (SDValue S = TLI.BuildSDIVPow2(N, C->getAPIntValue(), DAG, Built)) {
27473     for (SDNode *N : Built)
27474       AddToWorklist(N);
27475     return S;
27476   }
27477 
27478   return SDValue();
27479 }
27480 
27481 /// Given an ISD::UDIV node expressing a divide by constant, return a DAG
27482 /// expression that will generate the same value by multiplying by a magic
27483 /// number.
27484 /// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
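/// For illustration only (i32, divisor 7; the exact node sequence is chosen
/// by TargetLowering::BuildUDIV and is target-dependent):
///   q = mulhu(x, 0x24924925); t = srl(x - q, 1) + q; result = srl(t, 2);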
27485 SDValue DAGCombiner::BuildUDIV(SDNode *N) {
27486   // When optimising for minimum size, we don't want to expand a div to a mul
27487   // and a shift.
27488   if (DAG.getMachineFunction().getFunction().hasMinSize())
27489     return SDValue();
27490 
27491   SmallVector<SDNode *, 8> Built;
27492   if (SDValue S = TLI.BuildUDIV(N, DAG, LegalOperations, Built)) {
27493     for (SDNode *N : Built)
27494       AddToWorklist(N);
27495     return S;
27496   }
27497 
27498   return SDValue();
27499 }
27500 
27501 /// Given an ISD::SREM node expressing a remainder by constant power of 2,
27502 /// return a DAG expression that will generate the same value.
27503 SDValue DAGCombiner::BuildSREMPow2(SDNode *N) {
27504   ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1));
27505   if (!C)
27506     return SDValue();
27507 
27508   // Avoid division by zero.
27509   if (C->isZero())
27510     return SDValue();
27511 
27512   SmallVector<SDNode *, 8> Built;
27513   if (SDValue S = TLI.BuildSREMPow2(N, C->getAPIntValue(), DAG, Built)) {
27514     for (SDNode *N : Built)
27515       AddToWorklist(N);
27516     return S;
27517   }
27518 
27519   return SDValue();
27520 }
27521 
27522 // This is basically just a port of takeLog2 from InstCombineMulDivRem.cpp
27523 //
27524 // Returns the node that represents `Log2(Op)`. This may create a new node. If
27525 // we are unable to compute `Log2(Op)`, this returns `SDValue()`.
27526 //
27527 // All nodes will be created at `DL` and the output will be of type `VT`.
27528 //
27529 // This will only return `Log2(Op)` if we can prove `Op` is non-zero. Set
27530 // `AssumeNonZero` if this function should simply assume (rather than require
27531 // a proof) that `Op` is non-zero.
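//
// For example, `Log2(select c, 8, 16)` folds to `select c, 3, 4`, and
// `Log2(x << y)` with a provably non-zero `x` folds to `Log2(x) + y`.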
27532 static SDValue takeInexpensiveLog2(SelectionDAG &DAG, const SDLoc &DL, EVT VT,
27533                                    SDValue Op, unsigned Depth,
27534                                    bool AssumeNonZero) {
27535   assert(VT.isInteger() && "Only integer types are supported!");
27536 
27537   auto PeekThroughCastsAndTrunc = [](SDValue V) {
27538     while (true) {
27539       switch (V.getOpcode()) {
27540       case ISD::TRUNCATE:
27541       case ISD::ZERO_EXTEND:
27542         V = V.getOperand(0);
27543         break;
27544       default:
27545         return V;
27546       }
27547     }
27548   };
27549 
27550   if (VT.isScalableVector())
27551     return SDValue();
27552 
27553   Op = PeekThroughCastsAndTrunc(Op);
27554 
27555   // Helper for determining whether a value is a power-of-2 constant scalar or
27556   // a vector of such elements.
27557   SmallVector<APInt> Pow2Constants;
27558   auto IsPowerOfTwo = [&Pow2Constants](ConstantSDNode *C) {
27559     if (C->isZero() || C->isOpaque())
27560       return false;
27561     // TODO: We may also be able to support negative powers of 2 here.
27562     if (C->getAPIntValue().isPowerOf2()) {
27563       Pow2Constants.emplace_back(C->getAPIntValue());
27564       return true;
27565     }
27566     return false;
27567   };
27568 
27569   if (ISD::matchUnaryPredicate(Op, IsPowerOfTwo)) {
27570     if (!VT.isVector())
27571       return DAG.getConstant(Pow2Constants.back().logBase2(), DL, VT);
27572     // We need to create a build vector of the per-element log2 values.
27573     SmallVector<SDValue> Log2Ops;
27574     for (const APInt &Pow2 : Pow2Constants)
27575       Log2Ops.emplace_back(
27576           DAG.getConstant(Pow2.logBase2(), DL, VT.getScalarType()));
27577     return DAG.getBuildVector(VT, DL, Log2Ops);
27578   }
27579 
27580   if (Depth >= DAG.MaxRecursionDepth)
27581     return SDValue();
27582 
27583   auto CastToVT = [&](EVT NewVT, SDValue ToCast) {
27584     ToCast = PeekThroughCastsAndTrunc(ToCast);
27585     EVT CurVT = ToCast.getValueType();
27586     if (NewVT == CurVT)
27587       return ToCast;
27588 
27589     if (NewVT.getSizeInBits() == CurVT.getSizeInBits())
27590       return DAG.getBitcast(NewVT, ToCast);
27591 
27592     return DAG.getZExtOrTrunc(ToCast, DL, NewVT);
27593   };
27594 
27595   // log2(X << Y) -> log2(X) + Y
27596   if (Op.getOpcode() == ISD::SHL) {
27597     // (1 << Y) and (X << Y) with nuw/nsw are both non-zero.
27598     if (AssumeNonZero || Op->getFlags().hasNoUnsignedWrap() ||
27599         Op->getFlags().hasNoSignedWrap() || isOneConstant(Op.getOperand(0)))
27600       if (SDValue LogX = takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(0),
27601                                              Depth + 1, AssumeNonZero))
27602         return DAG.getNode(ISD::ADD, DL, VT, LogX,
27603                            CastToVT(VT, Op.getOperand(1)));
27604   }
27605 
27606   // c ? X : Y -> c ? Log2(X) : Log2(Y)
27607   if ((Op.getOpcode() == ISD::SELECT || Op.getOpcode() == ISD::VSELECT) &&
27608       Op.hasOneUse()) {
27609     if (SDValue LogX = takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(1),
27610                                            Depth + 1, AssumeNonZero))
27611       if (SDValue LogY = takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(2),
27612                                              Depth + 1, AssumeNonZero))
27613         return DAG.getSelect(DL, VT, Op.getOperand(0), LogX, LogY);
27614   }
27615 
27616   // log2(umin(X, Y)) -> umin(log2(X), log2(Y))
27617   // log2(umax(X, Y)) -> umax(log2(X), log2(Y))
27618   if ((Op.getOpcode() == ISD::UMIN || Op.getOpcode() == ISD::UMAX) &&
27619       Op.hasOneUse()) {
27620     // Use AssumeNonZero as false here. Otherwise we can hit a case where
27621     // log2(umax(X, Y)) != umax(log2(X), log2(Y)) (because of overflow).
27622     if (SDValue LogX =
27623             takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(0), Depth + 1,
27624                                 /*AssumeNonZero*/ false))
27625       if (SDValue LogY =
27626               takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(1), Depth + 1,
27627                                   /*AssumeNonZero*/ false))
27628         return DAG.getNode(Op.getOpcode(), DL, VT, LogX, LogY);
27629   }
27630 
27631   return SDValue();
27632 }
27633 
27634 /// Determines the LogBase2 value for a non-null input value using the
27635 /// transform: LogBase2(V) = (EltBits - 1) - ctlz(V).
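/// For example (i32, V == 16): ctlz(16) == 27, and (32 - 1) - 27 == 4 ==
/// LogBase2(16).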
27636 SDValue DAGCombiner::BuildLogBase2(SDValue V, const SDLoc &DL,
27637                                    bool KnownNonZero, bool InexpensiveOnly,
27638                                    std::optional<EVT> OutVT) {
27639   EVT VT = OutVT ? *OutVT : V.getValueType();
27640   SDValue InexpensiveLogBase2 =
27641       takeInexpensiveLog2(DAG, DL, VT, V, /*Depth*/ 0, KnownNonZero);
27642   if (InexpensiveLogBase2 || InexpensiveOnly || !DAG.isKnownToBeAPowerOfTwo(V))
27643     return InexpensiveLogBase2;
27644 
27645   SDValue Ctlz = DAG.getNode(ISD::CTLZ, DL, VT, V);
27646   SDValue Base = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
27647   SDValue LogBase2 = DAG.getNode(ISD::SUB, DL, VT, Base, Ctlz);
27648   return LogBase2;
27649 }
27650 
27651 /// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
27652 /// For the reciprocal, we need to find the zero of the function:
27653 ///   F(X) = 1/X - A [which has a zero at X = 1/A]
27654 ///     =>
27655 ///   X_{i+1} = X_i (2 - A X_i) = X_i + X_i (1 - A X_i) [this second form
27656 ///     does not require additional intermediate precision]
27657 /// For the last iteration, put numerator N into it to gain more precision:
27658 ///   Result = N X_i + X_i (N - N A X_i)
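/// For intuition, with A == 3 and an initial estimate X_0 == 0.3:
///   X_1 = 0.3  * (2 - 3 * 0.3)  = 0.33
///   X_2 = 0.33 * (2 - 3 * 0.33) = 0.3333
/// i.e. the error roughly squares on each refinement step.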
27659 SDValue DAGCombiner::BuildDivEstimate(SDValue N, SDValue Op,
27660                                       SDNodeFlags Flags) {
27661   if (LegalDAG)
27662     return SDValue();
27663 
27664   // TODO: Handle extended types?
27665   EVT VT = Op.getValueType();
27666   if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
27667       VT.getScalarType() != MVT::f64)
27668     return SDValue();
27669 
27670   // If estimates are explicitly disabled for this function, we're done.
27671   MachineFunction &MF = DAG.getMachineFunction();
27672   int Enabled = TLI.getRecipEstimateDivEnabled(VT, MF);
27673   if (Enabled == TLI.ReciprocalEstimate::Disabled)
27674     return SDValue();
27675 
27676   // Estimates may be explicitly enabled for this type with a custom number of
27677   // refinement steps.
27678   int Iterations = TLI.getDivRefinementSteps(VT, MF);
27679   if (SDValue Est = TLI.getRecipEstimate(Op, DAG, Enabled, Iterations)) {
27680     AddToWorklist(Est.getNode());
27681 
27682     SDLoc DL(Op);
27683     if (Iterations) {
27684       SDValue FPOne = DAG.getConstantFP(1.0, DL, VT);
27685 
27686       // Newton iterations: Est = Est + Est (N - Arg * Est)
27687       // If this is the last iteration, also multiply by the numerator.
27688       for (int i = 0; i < Iterations; ++i) {
27689         SDValue MulEst = Est;
27690 
27691         if (i == Iterations - 1) {
27692           MulEst = DAG.getNode(ISD::FMUL, DL, VT, N, Est, Flags);
27693           AddToWorklist(MulEst.getNode());
27694         }
27695 
27696         SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Op, MulEst, Flags);
27697         AddToWorklist(NewEst.getNode());
27698 
27699         NewEst = DAG.getNode(ISD::FSUB, DL, VT,
27700                              (i == Iterations - 1 ? N : FPOne), NewEst, Flags);
27701         AddToWorklist(NewEst.getNode());
27702 
27703         NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags);
27704         AddToWorklist(NewEst.getNode());
27705 
27706         Est = DAG.getNode(ISD::FADD, DL, VT, MulEst, NewEst, Flags);
27707         AddToWorklist(Est.getNode());
27708       }
27709     } else {
27710       // If no iterations are available, multiply with N.
27711       Est = DAG.getNode(ISD::FMUL, DL, VT, Est, N, Flags);
27712       AddToWorklist(Est.getNode());
27713     }
27714 
27715     return Est;
27716   }
27717 
27718   return SDValue();
27719 }
27720 
27721 /// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
27722 /// For the reciprocal sqrt, we need to find the zero of the function:
27723 ///   F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
27724 ///     =>
27725 ///   X_{i+1} = X_i (1.5 - A X_i^2 / 2)
27726 /// As a result, we precompute A/2 prior to the iteration loop.
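/// For intuition, with A == 4 (so A/2 == 2) and an initial estimate
/// X_0 == 0.45:
///   X_1 = 0.45 * (1.5 - 2 * 0.45 * 0.45) = 0.49275
/// which is already close to the exact value 1/sqrt(4) == 0.5.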
27727 SDValue DAGCombiner::buildSqrtNROneConst(SDValue Arg, SDValue Est,
27728                                          unsigned Iterations,
27729                                          SDNodeFlags Flags, bool Reciprocal) {
27730   EVT VT = Arg.getValueType();
27731   SDLoc DL(Arg);
27732   SDValue ThreeHalves = DAG.getConstantFP(1.5, DL, VT);
27733 
27734   // We now need 0.5 * Arg which we can write as (1.5 * Arg - Arg) so that
27735   // this entire sequence requires only one FP constant.
27736   SDValue HalfArg = DAG.getNode(ISD::FMUL, DL, VT, ThreeHalves, Arg, Flags);
27737   HalfArg = DAG.getNode(ISD::FSUB, DL, VT, HalfArg, Arg, Flags);
27738 
27739   // Newton iterations: Est = Est * (1.5 - HalfArg * Est * Est)
27740   for (unsigned i = 0; i < Iterations; ++i) {
27741     SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, Est, Flags);
27742     NewEst = DAG.getNode(ISD::FMUL, DL, VT, HalfArg, NewEst, Flags);
27743     NewEst = DAG.getNode(ISD::FSUB, DL, VT, ThreeHalves, NewEst, Flags);
27744     Est = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags);
27745   }
27746 
27747   // If non-reciprocal square root is requested, multiply the result by Arg.
27748   if (!Reciprocal)
27749     Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Arg, Flags);
27750 
27751   return Est;
27752 }
27753 
27754 /// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
27755 /// For the reciprocal sqrt, we need to find the zero of the function:
27756 ///   F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
27757 ///     =>
27758 ///   X_{i+1} = (-0.5 * X_i) * (A * X_i * X_i + (-3.0))
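/// Note this is algebraically the same step as the one-constant form, since
///   (-0.5 * X_i) * (A * X_i^2 - 3.0) == X_i * (1.5 - 0.5 * A * X_i^2),
/// but it is written so that only the constants -0.5 and -3.0 are needed.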
27759 SDValue DAGCombiner::buildSqrtNRTwoConst(SDValue Arg, SDValue Est,
27760                                          unsigned Iterations,
27761                                          SDNodeFlags Flags, bool Reciprocal) {
27762   EVT VT = Arg.getValueType();
27763   SDLoc DL(Arg);
27764   SDValue MinusThree = DAG.getConstantFP(-3.0, DL, VT);
27765   SDValue MinusHalf = DAG.getConstantFP(-0.5, DL, VT);
27766 
27767   // This routine must enter the loop below to work correctly
27768   // when (Reciprocal == false).
27769   assert(Iterations > 0);
27770 
27771   // Newton iterations for reciprocal square root:
27772   // E = (E * -0.5) * ((A * E) * E + -3.0)
27773   for (unsigned i = 0; i < Iterations; ++i) {
27774     SDValue AE = DAG.getNode(ISD::FMUL, DL, VT, Arg, Est, Flags);
27775     SDValue AEE = DAG.getNode(ISD::FMUL, DL, VT, AE, Est, Flags);
27776     SDValue RHS = DAG.getNode(ISD::FADD, DL, VT, AEE, MinusThree, Flags);
27777 
27778     // When calculating a square root at the last iteration build:
27779     // S = ((A * E) * -0.5) * ((A * E) * E + -3.0)
27780     // (notice a common subexpression)
27781     SDValue LHS;
27782     if (Reciprocal || (i + 1) < Iterations) {
27783       // RSQRT: LHS = (E * -0.5)
27784       LHS = DAG.getNode(ISD::FMUL, DL, VT, Est, MinusHalf, Flags);
27785     } else {
27786       // SQRT: LHS = (A * E) * -0.5
27787       LHS = DAG.getNode(ISD::FMUL, DL, VT, AE, MinusHalf, Flags);
27788     }
27789 
27790     Est = DAG.getNode(ISD::FMUL, DL, VT, LHS, RHS, Flags);
27791   }
27792 
27793   return Est;
27794 }
27795 
27796 /// Build code to calculate either rsqrt(Op) or sqrt(Op). In the latter case
27797 /// Op*rsqrt(Op) is actually computed, so additional postprocessing is needed if
27798 /// Op can be zero.
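/// In particular, at Op == +0.0 the product Op * rsqrt(Op) is 0.0 * Inf, i.e.
/// NaN rather than 0.0, which is why the non-reciprocal path below selects a
/// target-provided result when the input is zero or denormal.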
27799 SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
27800                                            bool Reciprocal) {
27801   if (LegalDAG)
27802     return SDValue();
27803 
27804   // TODO: Handle extended types?
27805   EVT VT = Op.getValueType();
27806   if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
27807       VT.getScalarType() != MVT::f64)
27808     return SDValue();
27809 
27810   // If estimates are explicitly disabled for this function, we're done.
27811   MachineFunction &MF = DAG.getMachineFunction();
27812   int Enabled = TLI.getRecipEstimateSqrtEnabled(VT, MF);
27813   if (Enabled == TLI.ReciprocalEstimate::Disabled)
27814     return SDValue();
27815 
27816   // Estimates may be explicitly enabled for this type with a custom number of
27817   // refinement steps.
27818   int Iterations = TLI.getSqrtRefinementSteps(VT, MF);
27819 
27820   bool UseOneConstNR = false;
27821   if (SDValue Est =
27822       TLI.getSqrtEstimate(Op, DAG, Enabled, Iterations, UseOneConstNR,
27823                           Reciprocal)) {
27824     AddToWorklist(Est.getNode());
27825 
27826     if (Iterations > 0)
27827       Est = UseOneConstNR
27828             ? buildSqrtNROneConst(Op, Est, Iterations, Flags, Reciprocal)
27829             : buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal);
27830     if (!Reciprocal) {
27831       SDLoc DL(Op);
27832       // Try the target specific test first.
27833       SDValue Test = TLI.getSqrtInputTest(Op, DAG, DAG.getDenormalMode(VT));
27834 
27835       // The estimate is now completely wrong if the input was exactly 0.0 or
27836       // possibly a denormal. Force the answer to 0.0 or value provided by
27837       // target for those cases.
27838       Est = DAG.getNode(
27839           Test.getValueType().isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT,
27840           Test, TLI.getSqrtResultForDenormInput(Op, DAG), Est);
27841     }
27842     return Est;
27843   }
27844 
27845   return SDValue();
27846 }
27847 
27848 SDValue DAGCombiner::buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags) {
27849   return buildSqrtEstimateImpl(Op, Flags, true);
27850 }
27851 
27852 SDValue DAGCombiner::buildSqrtEstimate(SDValue Op, SDNodeFlags Flags) {
27853   return buildSqrtEstimateImpl(Op, Flags, false);
27854 }
27855 
27856 /// Return true if there is any possibility that the two addresses overlap.
27857 bool DAGCombiner::mayAlias(SDNode *Op0, SDNode *Op1) const {
27858 
27859   struct MemUseCharacteristics {
27860     bool IsVolatile;
27861     bool IsAtomic;
27862     SDValue BasePtr;
27863     int64_t Offset;
27864     std::optional<int64_t> NumBytes;
27865     MachineMemOperand *MMO;
27866   };
27867 
27868   auto getCharacteristics = [](SDNode *N) -> MemUseCharacteristics {
27869     if (const auto *LSN = dyn_cast<LSBaseSDNode>(N)) {
27870       int64_t Offset = 0;
27871       if (auto *C = dyn_cast<ConstantSDNode>(LSN->getOffset()))
27872         Offset = (LSN->getAddressingMode() == ISD::PRE_INC)
27873                      ? C->getSExtValue()
27874                      : (LSN->getAddressingMode() == ISD::PRE_DEC)
27875                            ? -1 * C->getSExtValue()
27876                            : 0;
27877       uint64_t Size =
27878           MemoryLocation::getSizeOrUnknown(LSN->getMemoryVT().getStoreSize());
27879       return {LSN->isVolatile(),
27880               LSN->isAtomic(),
27881               LSN->getBasePtr(),
27882               Offset /*base offset*/,
27883               std::optional<int64_t>(Size),
27884               LSN->getMemOperand()};
27885     }
27886     if (const auto *LN = cast<LifetimeSDNode>(N))
27887       return {false /*isVolatile*/,
27888               /*isAtomic*/ false,
27889               LN->getOperand(1),
27890               (LN->hasOffset()) ? LN->getOffset() : 0,
27891               (LN->hasOffset()) ? std::optional<int64_t>(LN->getSize())
27892                                 : std::optional<int64_t>(),
27893               (MachineMemOperand *)nullptr};
27894     // Default.
27895     return {false /*isvolatile*/,
27896             /*isAtomic*/ false,          SDValue(),
27897             (int64_t)0 /*offset*/,       std::optional<int64_t>() /*size*/,
27898             (MachineMemOperand *)nullptr};
27899   };
27900 
27901   MemUseCharacteristics MUC0 = getCharacteristics(Op0),
27902                         MUC1 = getCharacteristics(Op1);
27903 
27904   // If they are to the same address, then they must be aliases.
27905   if (MUC0.BasePtr.getNode() && MUC0.BasePtr == MUC1.BasePtr &&
27906       MUC0.Offset == MUC1.Offset)
27907     return true;
27908 
27909   // If they are both volatile then they cannot be reordered.
27910   if (MUC0.IsVolatile && MUC1.IsVolatile)
27911     return true;
27912 
27913   // Be conservative about atomics for the moment
27914   // TODO: This is way overconservative for unordered atomics (see D66309)
27915   if (MUC0.IsAtomic && MUC1.IsAtomic)
27916     return true;
27917 
27918   if (MUC0.MMO && MUC1.MMO) {
27919     if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
27920         (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
27921       return false;
27922   }
27923 
27924   // Try to prove that there is aliasing, or that there is no aliasing. Either
27925   // way, we can return now. If nothing can be proved, proceed with more tests.
27926   bool IsAlias;
27927   if (BaseIndexOffset::computeAliasing(Op0, MUC0.NumBytes, Op1, MUC1.NumBytes,
27928                                        DAG, IsAlias))
27929     return IsAlias;
27930 
27931   // The following all rely on MMO0 and MMO1 being valid. Fail conservatively if
27932   // either are not known.
27933   if (!MUC0.MMO || !MUC1.MMO)
27934     return true;
27935 
27936   // If one operation reads from invariant memory, and the other may store, they
27937   // cannot alias. These should really be checking the equivalent of mayWrite,
27938   // but it only matters for memory nodes other than load/store.
27939   if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
27940       (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
27941     return false;
27942 
27943   // If we know the required SrcValue1 and SrcValue2 have relatively large
27944   // alignment compared to the size and offset of the access, we may be able
27945   // to prove they do not alias. This check is conservative for now to catch
27946   // cases created by splitting vector types; it only works when the offsets
27947   // are multiples of the size of the data.
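  // For example, two 4-byte accesses sharing an 8-byte base alignment at
  // source offsets 0 and 4 get OffAlign values 0 and 4, and (0 + 4) <= 4
  // proves the accesses cannot overlap.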
27948   int64_t SrcValOffset0 = MUC0.MMO->getOffset();
27949   int64_t SrcValOffset1 = MUC1.MMO->getOffset();
27950   Align OrigAlignment0 = MUC0.MMO->getBaseAlign();
27951   Align OrigAlignment1 = MUC1.MMO->getBaseAlign();
27952   auto &Size0 = MUC0.NumBytes;
27953   auto &Size1 = MUC1.NumBytes;
27954   if (OrigAlignment0 == OrigAlignment1 && SrcValOffset0 != SrcValOffset1 &&
27955       Size0.has_value() && Size1.has_value() && *Size0 == *Size1 &&
27956       OrigAlignment0 > *Size0 && SrcValOffset0 % *Size0 == 0 &&
27957       SrcValOffset1 % *Size1 == 0) {
27958     int64_t OffAlign0 = SrcValOffset0 % OrigAlignment0.value();
27959     int64_t OffAlign1 = SrcValOffset1 % OrigAlignment1.value();
27960 
27961     // There is no overlap between these relatively aligned accesses of
27962     // similar size. Return no alias.
27963     if ((OffAlign0 + *Size0) <= OffAlign1 || (OffAlign1 + *Size1) <= OffAlign0)
27964       return false;
27965   }
27966 
27967   bool UseAA = CombinerGlobalAA.getNumOccurrences() > 0
27968                    ? CombinerGlobalAA
27969                    : DAG.getSubtarget().useAA();
27970 #ifndef NDEBUG
27971   if (CombinerAAOnlyFunc.getNumOccurrences() &&
27972       CombinerAAOnlyFunc != DAG.getMachineFunction().getName())
27973     UseAA = false;
27974 #endif
27975 
27976   if (UseAA && AA && MUC0.MMO->getValue() && MUC1.MMO->getValue() && Size0 &&
27977       Size1) {
27978     // Use alias analysis information.
27979     int64_t MinOffset = std::min(SrcValOffset0, SrcValOffset1);
27980     int64_t Overlap0 = *Size0 + SrcValOffset0 - MinOffset;
27981     int64_t Overlap1 = *Size1 + SrcValOffset1 - MinOffset;
27982     if (AA->isNoAlias(
27983             MemoryLocation(MUC0.MMO->getValue(), Overlap0,
27984                            UseTBAA ? MUC0.MMO->getAAInfo() : AAMDNodes()),
27985             MemoryLocation(MUC1.MMO->getValue(), Overlap1,
27986                            UseTBAA ? MUC1.MMO->getAAInfo() : AAMDNodes())))
27987       return false;
27988   }
27989 
27990   // Otherwise we have to assume they alias.
27991   return true;
27992 }
27993 
27994 /// Walk up chain skipping non-aliasing memory nodes,
27995 /// looking for aliasing nodes and adding them to the Aliases vector.
27996 void DAGCombiner::GatherAllAliases(SDNode *N, SDValue OriginalChain,
27997                                    SmallVectorImpl<SDValue> &Aliases) {
27998   SmallVector<SDValue, 8> Chains;     // List of chains to visit.
27999   SmallPtrSet<SDNode *, 16> Visited;  // Visited node set.
28000 
28001   // Get alias information for node.
28002   // TODO: relax aliasing for unordered atomics (see D66309)
28003   const bool IsLoad = isa<LoadSDNode>(N) && cast<LoadSDNode>(N)->isSimple();
28004 
28005   // Starting off.
28006   Chains.push_back(OriginalChain);
28007   unsigned Depth = 0;
28008 
28009   // Attempt to improve the chain by a single step.
28010   auto ImproveChain = [&](SDValue &C) -> bool {
28011     switch (C.getOpcode()) {
28012     case ISD::EntryToken:
28013       // No need to mark EntryToken.
28014       C = SDValue();
28015       return true;
28016     case ISD::LOAD:
28017     case ISD::STORE: {
28018       // Get alias information for C.
28019       // TODO: Relax aliasing for unordered atomics (see D66309)
28020       bool IsOpLoad = isa<LoadSDNode>(C.getNode()) &&
28021                       cast<LSBaseSDNode>(C.getNode())->isSimple();
28022       if ((IsLoad && IsOpLoad) || !mayAlias(N, C.getNode())) {
28023         // Look further up the chain.
28024         C = C.getOperand(0);
28025         return true;
28026       }
28027       // Alias, so stop here.
28028       return false;
28029     }
28030 
28031     case ISD::CopyFromReg:
28032       // Always forward past CopyFromReg.
28033       C = C.getOperand(0);
28034       return true;
28035 
28036     case ISD::LIFETIME_START:
28037     case ISD::LIFETIME_END: {
28038       // We can forward past any lifetime start/end that can be proven not to
28039       // alias the memory access.
28040       if (!mayAlias(N, C.getNode())) {
28041         // Look further up the chain.
28042         C = C.getOperand(0);
28043         return true;
28044       }
28045       return false;
28046     }
28047     default:
28048       return false;
28049     }
28050   };
28051 
28052   // Look at each chain and determine if it is an alias.  If so, add it to the
28053   // aliases list.  If not, then continue up the chain looking for the next
28054   // candidate.
28055   while (!Chains.empty()) {
28056     SDValue Chain = Chains.pop_back_val();
28057 
28058     // Don't bother if we've seen Chain before.
28059     if (!Visited.insert(Chain.getNode()).second)
28060       continue;
28061 
28062     // For TokenFactor nodes, look at each operand and only continue up the
28063     // chain until we reach the depth limit.
28064     //
28065     // FIXME: The depth check could be made to return the last non-aliasing
28066     // chain we found before we hit a tokenfactor rather than the original
28067     // chain.
28068     if (Depth > TLI.getGatherAllAliasesMaxDepth()) {
28069       Aliases.clear();
28070       Aliases.push_back(OriginalChain);
28071       return;
28072     }
28073 
28074     if (Chain.getOpcode() == ISD::TokenFactor) {
28075       // We have to check each of the operands of the token factor for "small"
28076       // token factors, so we queue them up.  Adding the operands to the queue
28077       // (stack) in reverse order maintains the original order and increases the
28078       // likelihood that getNode will find a matching token factor (CSE).
28079       if (Chain.getNumOperands() > 16) {
28080         Aliases.push_back(Chain);
28081         continue;
28082       }
28083       for (unsigned n = Chain.getNumOperands(); n;)
28084         Chains.push_back(Chain.getOperand(--n));
28085       ++Depth;
28086       continue;
28087     }
28088     // Everything else
28089     if (ImproveChain(Chain)) {
28090       // Updated chain found; consider the new chain if one exists.
28091       if (Chain.getNode())
28092         Chains.push_back(Chain);
28093       ++Depth;
28094       continue;
28095     }
28096     // No improved chain is possible, so treat this chain as an alias.
28097     Aliases.push_back(Chain);
28098   }
28099 }
28100 
28101 /// Walk up chain skipping non-aliasing memory nodes, looking for a better chain
28102 /// (aliasing node.)
28103 SDValue DAGCombiner::FindBetterChain(SDNode *N, SDValue OldChain) {
28104   if (OptLevel == CodeGenOptLevel::None)
28105     return OldChain;
28106 
28107   // Ops for replacing token factor.
28108   SmallVector<SDValue, 8> Aliases;
28109 
28110   // Accumulate all the aliases to this node.
28111   GatherAllAliases(N, OldChain, Aliases);
28112 
28113   // If no operands then chain to entry token.
28114   if (Aliases.empty())
28115     return DAG.getEntryNode();
28116 
28117   // If a single operand then chain to it.  We don't need to revisit it.
28118   if (Aliases.size() == 1)
28119     return Aliases[0];
28120 
28121   // Construct a custom tailored token factor.
28122   return DAG.getTokenFactor(SDLoc(N), Aliases);
28123 }
28124 
28125 // This function tries to collect a bunch of potentially interesting
28126 // nodes to improve the chains of, all at once. This might seem
28127 // redundant, as this function gets called when visiting every store
28128 // node, so why not let the work be done on each store as it's visited?
28129 //
28130 // I believe this is mainly important because mergeConsecutiveStores
28131 // is unable to deal with merging stores of different sizes, so unless
28132 // we improve the chains of all the potential candidates up-front
28133 // before running mergeConsecutiveStores, it might only see some of
28134 // the nodes that will eventually be candidates, and then not be able
28135 // to go from a partially-merged state to the desired final
28136 // fully-merged state.
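//
// For example, four chained i8 stores to consecutive addresses can only be
// merged into a single i32 store once each store's chain is redirected to a
// common ancestor, making the stores independent of one another.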
28137 
28138 bool DAGCombiner::parallelizeChainedStores(StoreSDNode *St) {
28139   SmallVector<StoreSDNode *, 8> ChainedStores;
28140   StoreSDNode *STChain = St;
28141   // Intervals records which offsets from BaseIndex have been covered. In
28142   // the common case, every store writes to the immediately preceding address
28143   // and is thus merged with the previous interval at insertion time.
28144 
28145   using IMap = llvm::IntervalMap<int64_t, std::monostate, 8,
28146                                  IntervalMapHalfOpenInfo<int64_t>>;
28147   IMap::Allocator A;
28148   IMap Intervals(A);
28149 
28150   // This holds the base pointer, index, and the offset in bytes from the base
28151   // pointer.
28152   const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
28153 
28154   // We must have a base and an offset.
28155   if (!BasePtr.getBase().getNode())
28156     return false;
28157 
28158   // Do not handle stores to undef base pointers.
28159   if (BasePtr.getBase().isUndef())
28160     return false;
28161 
28162   // Do not handle stores to opaque types.
28163   if (St->getMemoryVT().isZeroSized())
28164     return false;
28165 
28166   // BaseIndexOffset assumes that offsets are fixed-size, which
28167   // is not valid for scalable vectors where the offsets are
28168   // scaled by `vscale`, so bail out early.
28169   if (St->getMemoryVT().isScalableVT())
28170     return false;
28171 
28172   // Add ST's interval.
28173   Intervals.insert(0, (St->getMemoryVT().getSizeInBits() + 7) / 8,
28174                    std::monostate{});
28175 
28176   while (StoreSDNode *Chain = dyn_cast<StoreSDNode>(STChain->getChain())) {
28177     if (Chain->getMemoryVT().isScalableVector())
28178       return false;
28179 
28180     // If the chain has more than one use, then we can't reorder the mem ops.
28181     if (!SDValue(Chain, 0)->hasOneUse())
28182       break;
28183     // TODO: Relax for unordered atomics (see D66309)
28184     if (!Chain->isSimple() || Chain->isIndexed())
28185       break;
28186 
28187     // Find the base pointer and offset for this memory node.
28188     const BaseIndexOffset Ptr = BaseIndexOffset::match(Chain, DAG);
28189     // Check that the base pointer is the same as the original one.
28190     int64_t Offset;
28191     if (!BasePtr.equalBaseIndex(Ptr, DAG, Offset))
28192       break;
28193     int64_t Length = (Chain->getMemoryVT().getSizeInBits() + 7) / 8;
28194     // Make sure we don't overlap with other intervals by checking the ones to
28195     // the left or right before inserting.
28196     auto I = Intervals.find(Offset);
28197     // If there's a next interval, we should end before it.
28198     if (I != Intervals.end() && I.start() < (Offset + Length))
28199       break;
28200     // If there's a previous interval, we should start after it.
28201     if (I != Intervals.begin() && (--I).stop() <= Offset)
28202       break;
28203     Intervals.insert(Offset, Offset + Length, std::monostate{});
28204 
28205     ChainedStores.push_back(Chain);
28206     STChain = Chain;
28207   }
28208 
28209   // If we didn't find a chained store, exit.
28210   if (ChainedStores.empty())
28211     return false;
28212 
28213   // Improve all chained stores (St and ChainedStores members) starting from
28214   // where the store chain ended and return single TokenFactor.
28215   SDValue NewChain = STChain->getChain();
28216   SmallVector<SDValue, 8> TFOps;
28217   for (unsigned I = ChainedStores.size(); I;) {
28218     StoreSDNode *S = ChainedStores[--I];
28219     SDValue BetterChain = FindBetterChain(S, NewChain);
28220     S = cast<StoreSDNode>(DAG.UpdateNodeOperands(
28221         S, BetterChain, S->getOperand(1), S->getOperand(2), S->getOperand(3)));
28222     TFOps.push_back(SDValue(S, 0));
28223     ChainedStores[I] = S;
28224   }
28225 
28226   // Improve St's chain. Use a new node to avoid creating a loop from CombineTo.
28227   SDValue BetterChain = FindBetterChain(St, NewChain);
28228   SDValue NewST;
28229   if (St->isTruncatingStore())
28230     NewST = DAG.getTruncStore(BetterChain, SDLoc(St), St->getValue(),
28231                               St->getBasePtr(), St->getMemoryVT(),
28232                               St->getMemOperand());
28233   else
28234     NewST = DAG.getStore(BetterChain, SDLoc(St), St->getValue(),
28235                          St->getBasePtr(), St->getMemOperand());
28236 
28237   TFOps.push_back(NewST);
28238 
28239   // If we improved every element of TFOps, then we've lost the dependence on
28240   // NewChain to successors of St and we need to add it back to TFOps. Do so at
28241   // the beginning to keep relative order consistent with FindBetterChains.
28242   auto hasImprovedChain = [&](SDValue ST) -> bool {
28243     return ST->getOperand(0) != NewChain;
28244   };
28245   bool AddNewChain = llvm::all_of(TFOps, hasImprovedChain);
28246   if (AddNewChain)
28247     TFOps.insert(TFOps.begin(), NewChain);
28248 
28249   SDValue TF = DAG.getTokenFactor(SDLoc(STChain), TFOps);
28250   CombineTo(St, TF);
28251 
28252   // Add TF and its operands to the worklist.
28253   AddToWorklist(TF.getNode());
28254   for (const SDValue &Op : TF->ops())
28255     AddToWorklist(Op.getNode());
28256   AddToWorklist(STChain);
28257   return true;
28258 }
28259 
28260 bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) {
28261   if (OptLevel == CodeGenOptLevel::None)
28262     return false;
28263 
28264   const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
28265 
28266   // We must have a base and an offset.
28267   if (!BasePtr.getBase().getNode())
28268     return false;
28269 
28270   // Do not handle stores to undef base pointers.
28271   if (BasePtr.getBase().isUndef())
28272     return false;
28273 
28274   // Directly improve a chain of disjoint stores starting at St.
28275   if (parallelizeChainedStores(St))
28276     return true;
28277 
28278   // Improve St's chain.
28279   SDValue BetterChain = FindBetterChain(St, St->getChain());
28280   if (St->getChain() != BetterChain) {
28281     replaceStoreChain(St, BetterChain);
28282     return true;
28283   }
28284   return false;
28285 }
28286 
28287 /// This is the entry point for the file.
28288 void SelectionDAG::Combine(CombineLevel Level, AliasAnalysis *AA,
28289                            CodeGenOptLevel OptLevel) {
28290   /// This is the main entry point to this class.
28291   DAGCombiner(*this, AA, OptLevel).Run(Level);
28292 }
28293