//===- DAGCombiner.cpp - Implement a DAG node combiner --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass combines dag nodes to form fewer, simpler DAG nodes.  It can be run
// both before and after the DAG is legalized.
//
// This pass is not a substitute for the LLVM IR instcombine pass. This pass is
// primarily intended to handle simplification opportunities that are implicit
// in the LLVM IR and exposed by the various codegen lowering phases.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/DAGCombine.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineMemOperand.h"
#include "llvm/CodeGen/RuntimeLibcalls.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/SelectionDAGTargetInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/MachineValueType.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <iterator>
#include <string>
#include <tuple>
#include <utility>

using namespace llvm;

#define DEBUG_TYPE "dagcombine"

STATISTIC(NodesCombined   , "Number of dag nodes combined");
STATISTIC(PreIndexedNodes , "Number of pre-indexed nodes created");
STATISTIC(PostIndexedNodes, "Number of post-indexed nodes created");
STATISTIC(OpsNarrowed     , "Number of load/op/store narrowed");
STATISTIC(LdStFP2Int      , "Number of fp load/store pairs transformed to int");
STATISTIC(SlicedLoads, "Number of loads sliced");
STATISTIC(NumFPLogicOpsConv, "Number of logic ops converted to fp ops");

static cl::opt<bool>
CombinerGlobalAA("combiner-global-alias-analysis", cl::Hidden,
                 cl::desc("Enable DAG combiner's use of IR alias analysis"));

static cl::opt<bool>
UseTBAA("combiner-use-tbaa", cl::Hidden, cl::init(true),
        cl::desc("Enable DAG combiner's use of TBAA"));

#ifndef NDEBUG
static cl::opt<std::string>
CombinerAAOnlyFunc("combiner-aa-only-func", cl::Hidden,
                   cl::desc("Only use DAG-combiner alias analysis in this"
                            " function"));
#endif

/// Hidden option to stress test load slicing, i.e., when this option
/// is enabled, load slicing bypasses most of its profitability guards.
static cl::opt<bool>
StressLoadSlicing("combiner-stress-load-slicing", cl::Hidden,
                  cl::desc("Bypass the profitability model of load slicing"),
                  cl::init(false));

static cl::opt<bool>
  MaySplitLoadIndex("combiner-split-load-index", cl::Hidden, cl::init(true),
                    cl::desc("DAG combiner may split indexing from loads"));

static cl::opt<bool>
    EnableStoreMerging("combiner-store-merging", cl::Hidden, cl::init(true),
                       cl::desc("DAG combiner enable merging multiple stores "
                                "into a wider store"));

static cl::opt<unsigned> TokenFactorInlineLimit(
    "combiner-tokenfactor-inline-limit", cl::Hidden, cl::init(2048),
    cl::desc("Limit the number of operands to inline for Token Factors"));

static cl::opt<unsigned> StoreMergeDependenceLimit(
    "combiner-store-merge-dependence-limit", cl::Hidden, cl::init(10),
    cl::desc("Limit the number of times for the same StoreNode and RootNode "
             "to bail out in store merging dependence check"));

static cl::opt<bool> EnableReduceLoadOpStoreWidth(
    "combiner-reduce-load-op-store-width", cl::Hidden, cl::init(true),
    cl::desc("DAG combiner enable reducing the width of load/op/store "
             "sequence"));

static cl::opt<bool> EnableShrinkLoadReplaceStoreWithStore(
    "combiner-shrink-load-replace-store-with-store", cl::Hidden, cl::init(true),
    cl::desc("DAG combiner enable load/<replace bytes>/store with "
             "a narrower store"));

namespace {

  class DAGCombiner {
    SelectionDAG &DAG;
    const TargetLowering &TLI;
    const SelectionDAGTargetInfo *STI;
    CombineLevel Level;
    CodeGenOpt::Level OptLevel;
    bool LegalDAG = false;
    bool LegalOperations = false;
    bool LegalTypes = false;
    bool ForCodeSize;
    bool DisableGenericCombines;

    /// Worklist of all of the nodes that need to be simplified.
    ///
    /// This must behave as a stack -- new nodes to process are pushed onto the
    /// back and when processing we pop off of the back.
    ///
    /// The worklist will not contain duplicates but may contain null entries
    /// due to nodes being deleted from the underlying DAG.
    SmallVector<SDNode *, 64> Worklist;

    /// Mapping from an SDNode to its position on the worklist.
    ///
    /// This is used to find and remove nodes from the worklist (by nulling
    /// them) when they are deleted from the underlying DAG. It relies on
    /// stable indices of nodes within the worklist.
    DenseMap<SDNode *, unsigned> WorklistMap;
    /// This records all nodes attempted to be added to the worklist since we
    /// considered a new worklist entry. Since we do not add duplicate nodes
    /// to the worklist, this is different from the tail of the worklist.
    SmallSetVector<SDNode *, 32> PruningList;

    /// Set of nodes which have been combined (at least once).
    ///
    /// This is used to allow us to reliably add any operands of a DAG node
    /// which have not yet been combined to the worklist.
    SmallPtrSet<SDNode *, 32> CombinedNodes;

    /// Map from candidate StoreNode to the pair of RootNode and count.
    /// The count is used to track how many times we have seen the StoreNode
    /// with the same RootNode bail out in dependence check. If we have seen
    /// the bail out for the same pair many times over a limit, we won't
    /// consider the StoreNode with the same RootNode as store merging
    /// candidate again.
    DenseMap<SDNode *, std::pair<SDNode *, unsigned>> StoreRootCountMap;

    // AA - Used for DAG load/store alias analysis.
    AliasAnalysis *AA;

    /// When an instruction is simplified, add all users of the instruction to
    /// the worklist because they might get more simplified now.
    void AddUsersToWorklist(SDNode *N) {
      for (SDNode *Node : N->uses())
        AddToWorklist(Node);
    }

    /// Convenient shorthand to add a node and all of its users to the
    /// worklist.
    void AddToWorklistWithUsers(SDNode *N) {
      AddUsersToWorklist(N);
      AddToWorklist(N);
    }

    // Prune potentially dangling nodes. This is called after
    // any visit to a node, but should also be called during a visit after any
    // failed combine which may have created a DAG node.
    void clearAddedDanglingWorklistEntries() {
      // Check any nodes added to the worklist to see if they are prunable.
      while (!PruningList.empty()) {
        auto *N = PruningList.pop_back_val();
        if (N->use_empty())
          recursivelyDeleteUnusedNodes(N);
      }
    }

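    /// Pop the next live entry off the worklist, skipping null slots left
    /// behind by deleted nodes; returns nullptr once the worklist is empty.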
    SDNode *getNextWorklistEntry() {
      // Before we do any work, remove nodes that are not in use.
      clearAddedDanglingWorklistEntries();
      SDNode *N = nullptr;
      // The Worklist holds the SDNodes in order, but it may contain null
      // entries.
      while (!N && !Worklist.empty()) {
        N = Worklist.pop_back_val();
      }

      if (N) {
        bool GoodWorklistEntry = WorklistMap.erase(N);
        (void)GoodWorklistEntry;
        assert(GoodWorklistEntry &&
               "Found a worklist entry without a corresponding map entry!");
      }
      return N;
    }

    /// Call the node-specific routine that folds each particular type of node.
    SDValue visit(SDNode *N);

  public:
    DAGCombiner(SelectionDAG &D, AliasAnalysis *AA, CodeGenOpt::Level OL)
        : DAG(D), TLI(D.getTargetLoweringInfo()),
          STI(D.getSubtarget().getSelectionDAGInfo()),
          Level(BeforeLegalizeTypes), OptLevel(OL), AA(AA) {
      ForCodeSize = DAG.shouldOptForSize();
      DisableGenericCombines = STI && STI->disableGenericCombines(OptLevel);

      MaximumLegalStoreInBits = 0;
      // We use the minimum store size here, since that's all we can guarantee
      // for the scalable vector types.
      for (MVT VT : MVT::all_valuetypes())
        if (EVT(VT).isSimple() && VT != MVT::Other &&
            TLI.isTypeLegal(EVT(VT)) &&
            VT.getSizeInBits().getKnownMinSize() >= MaximumLegalStoreInBits)
          MaximumLegalStoreInBits = VT.getSizeInBits().getKnownMinSize();
    }

    void ConsiderForPruning(SDNode *N) {
      // Mark this for potential pruning.
      PruningList.insert(N);
    }

    /// Add to the worklist, making sure its instance is at the back (next to
    /// be processed).
    void AddToWorklist(SDNode *N) {
      assert(N->getOpcode() != ISD::DELETED_NODE &&
             "Deleted Node added to Worklist");

      // Skip handle nodes as they can't usefully be combined and confuse the
      // zero-use deletion strategy.
      if (N->getOpcode() == ISD::HANDLENODE)
        return;

      ConsiderForPruning(N);

      if (WorklistMap.insert(std::make_pair(N, Worklist.size())).second)
        Worklist.push_back(N);
    }

    /// Remove all instances of N from the worklist.
    void removeFromWorklist(SDNode *N) {
      CombinedNodes.erase(N);
      PruningList.remove(N);
      StoreRootCountMap.erase(N);

      auto It = WorklistMap.find(N);
      if (It == WorklistMap.end())
        return; // Not in the worklist.

      // Null out the entry rather than erasing it to avoid a linear operation.
      Worklist[It->second] = nullptr;
      WorklistMap.erase(It);
    }

    void deleteAndRecombine(SDNode *N);
    bool recursivelyDeleteUnusedNodes(SDNode *N);

    /// Replaces all uses of the results of one DAG node with new values.
    SDValue CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
                      bool AddTo = true);

    /// Replaces all uses of the results of one DAG node with new values.
    SDValue CombineTo(SDNode *N, SDValue Res, bool AddTo = true) {
      return CombineTo(N, &Res, 1, AddTo);
    }

    /// Replaces all uses of the results of one DAG node with new values.
    SDValue CombineTo(SDNode *N, SDValue Res0, SDValue Res1,
                      bool AddTo = true) {
      SDValue To[] = { Res0, Res1 };
      return CombineTo(N, To, 2, AddTo);
    }

    void CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO);

  private:
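    /// Widest store (in bits) that is legal for the target; for scalable
    /// vector types this tracks the known-minimum size (see the constructor).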
    unsigned MaximumLegalStoreInBits;

    /// Check the specified integer node value to see if it can be simplified or
    /// if things it uses can be simplified by bit propagation.
    /// If so, return true.
    bool SimplifyDemandedBits(SDValue Op) {
      unsigned BitWidth = Op.getScalarValueSizeInBits();
      APInt DemandedBits = APInt::getAllOnesValue(BitWidth);
      return SimplifyDemandedBits(Op, DemandedBits);
    }

    bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits) {
      TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
      KnownBits Known;
      if (!TLI.SimplifyDemandedBits(Op, DemandedBits, Known, TLO, 0, false))
        return false;

      // Revisit the node.
      AddToWorklist(Op.getNode());

      CommitTargetLoweringOpt(TLO);
      return true;
    }

    /// Check the specified vector node value to see if it can be simplified or
    /// if things it uses can be simplified as it only uses some of the
    /// elements. If so, return true.
    bool SimplifyDemandedVectorElts(SDValue Op) {
      // TODO: For now just pretend it cannot be simplified.
      if (Op.getValueType().isScalableVector())
        return false;

      unsigned NumElts = Op.getValueType().getVectorNumElements();
      APInt DemandedElts = APInt::getAllOnesValue(NumElts);
      return SimplifyDemandedVectorElts(Op, DemandedElts);
    }

    bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
                              const APInt &DemandedElts,
                              bool AssumeSingleUse = false);
    bool SimplifyDemandedVectorElts(SDValue Op, const APInt &DemandedElts,
                                    bool AssumeSingleUse = false);

    bool CombineToPreIndexedLoadStore(SDNode *N);
    bool CombineToPostIndexedLoadStore(SDNode *N);
    SDValue SplitIndexingFromLoad(LoadSDNode *LD);
    bool SliceUpLoad(SDNode *N);

    // Scalars have size 0 to distinguish from singleton vectors.
    SDValue ForwardStoreValueToDirectLoad(LoadSDNode *LD);
    bool getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val);
    bool extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val);

    /// Replace an ISD::EXTRACT_VECTOR_ELT of a load with a narrowed load.
    ///
    /// \param EVE ISD::EXTRACT_VECTOR_ELT to be replaced.
    /// \param InVecVT type of the input vector to EVE with bitcasts resolved.
    /// \param EltNo index of the vector element to load.
    /// \param OriginalLoad load that EVE came from to be replaced.
    /// \returns EVE on success, SDValue() on failure.
    SDValue scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
                                         SDValue EltNo,
                                         LoadSDNode *OriginalLoad);
    void ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad);
    SDValue PromoteOperand(SDValue Op, EVT PVT, bool &Replace);
    SDValue SExtPromoteOperand(SDValue Op, EVT PVT);
    SDValue ZExtPromoteOperand(SDValue Op, EVT PVT);
    SDValue PromoteIntBinOp(SDValue Op);
    SDValue PromoteIntShiftOp(SDValue Op);
    SDValue PromoteExtend(SDValue Op);
    bool PromoteLoad(SDValue Op);

    /// Call the node-specific routine that knows how to fold each
    /// particular type of node. If that doesn't do anything, try the
    /// target-specific DAG combines.
    SDValue combine(SDNode *N);

    // Visitation implementation - Implement dag node combining for different
    // node types.  The semantics are as follows:
    // Return Value:
    //   SDValue.getNode() == 0 - No change was made
    //   SDValue.getNode() == N - N was replaced, is dead and has been handled.
    //   otherwise              - N should be replaced by the returned Operand.
    //
    SDValue visitTokenFactor(SDNode *N);
    SDValue visitMERGE_VALUES(SDNode *N);
    SDValue visitADD(SDNode *N);
    SDValue visitADDLike(SDNode *N);
    SDValue visitADDLikeCommutative(SDValue N0, SDValue N1, SDNode *LocReference);
    SDValue visitSUB(SDNode *N);
    SDValue visitADDSAT(SDNode *N);
    SDValue visitSUBSAT(SDNode *N);
    SDValue visitADDC(SDNode *N);
    SDValue visitADDO(SDNode *N);
    SDValue visitUADDOLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitSUBC(SDNode *N);
    SDValue visitSUBO(SDNode *N);
    SDValue visitADDE(SDNode *N);
    SDValue visitADDCARRY(SDNode *N);
    SDValue visitSADDO_CARRY(SDNode *N);
    SDValue visitADDCARRYLike(SDValue N0, SDValue N1, SDValue CarryIn, SDNode *N);
    SDValue visitSUBE(SDNode *N);
    SDValue visitSUBCARRY(SDNode *N);
    SDValue visitSSUBO_CARRY(SDNode *N);
    SDValue visitMUL(SDNode *N);
    SDValue visitMULFIX(SDNode *N);
    SDValue useDivRem(SDNode *N);
    SDValue visitSDIV(SDNode *N);
    SDValue visitSDIVLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitUDIV(SDNode *N);
    SDValue visitUDIVLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitREM(SDNode *N);
    SDValue visitMULHU(SDNode *N);
    SDValue visitMULHS(SDNode *N);
    SDValue visitSMUL_LOHI(SDNode *N);
    SDValue visitUMUL_LOHI(SDNode *N);
    SDValue visitMULO(SDNode *N);
    SDValue visitIMINMAX(SDNode *N);
    SDValue visitAND(SDNode *N);
    SDValue visitANDLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitOR(SDNode *N);
    SDValue visitORLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitXOR(SDNode *N);
    SDValue SimplifyVBinOp(SDNode *N);
    SDValue visitSHL(SDNode *N);
    SDValue visitSRA(SDNode *N);
    SDValue visitSRL(SDNode *N);
    SDValue visitFunnelShift(SDNode *N);
    SDValue visitRotate(SDNode *N);
    SDValue visitABS(SDNode *N);
    SDValue visitBSWAP(SDNode *N);
    SDValue visitBITREVERSE(SDNode *N);
    SDValue visitCTLZ(SDNode *N);
    SDValue visitCTLZ_ZERO_UNDEF(SDNode *N);
    SDValue visitCTTZ(SDNode *N);
    SDValue visitCTTZ_ZERO_UNDEF(SDNode *N);
    SDValue visitCTPOP(SDNode *N);
    SDValue visitSELECT(SDNode *N);
    SDValue visitVSELECT(SDNode *N);
    SDValue visitSELECT_CC(SDNode *N);
    SDValue visitSETCC(SDNode *N);
    SDValue visitSETCCCARRY(SDNode *N);
    SDValue visitSIGN_EXTEND(SDNode *N);
    SDValue visitZERO_EXTEND(SDNode *N);
    SDValue visitANY_EXTEND(SDNode *N);
    SDValue visitAssertExt(SDNode *N);
    SDValue visitAssertAlign(SDNode *N);
    SDValue visitSIGN_EXTEND_INREG(SDNode *N);
    SDValue visitSIGN_EXTEND_VECTOR_INREG(SDNode *N);
    SDValue visitZERO_EXTEND_VECTOR_INREG(SDNode *N);
    SDValue visitTRUNCATE(SDNode *N);
    SDValue visitBITCAST(SDNode *N);
    SDValue visitFREEZE(SDNode *N);
    SDValue visitBUILD_PAIR(SDNode *N);
    SDValue visitFADD(SDNode *N);
    SDValue visitSTRICT_FADD(SDNode *N);
    SDValue visitFSUB(SDNode *N);
    SDValue visitFMUL(SDNode *N);
    SDValue visitFMA(SDNode *N);
    SDValue visitFDIV(SDNode *N);
    SDValue visitFREM(SDNode *N);
    SDValue visitFSQRT(SDNode *N);
    SDValue visitFCOPYSIGN(SDNode *N);
    SDValue visitFPOW(SDNode *N);
    SDValue visitSINT_TO_FP(SDNode *N);
    SDValue visitUINT_TO_FP(SDNode *N);
    SDValue visitFP_TO_SINT(SDNode *N);
    SDValue visitFP_TO_UINT(SDNode *N);
    SDValue visitFP_ROUND(SDNode *N);
    SDValue visitFP_EXTEND(SDNode *N);
    SDValue visitFNEG(SDNode *N);
    SDValue visitFABS(SDNode *N);
    SDValue visitFCEIL(SDNode *N);
    SDValue visitFTRUNC(SDNode *N);
    SDValue visitFFLOOR(SDNode *N);
    SDValue visitFMINNUM(SDNode *N);
    SDValue visitFMAXNUM(SDNode *N);
    SDValue visitFMINIMUM(SDNode *N);
    SDValue visitFMAXIMUM(SDNode *N);
    SDValue visitBRCOND(SDNode *N);
    SDValue visitBR_CC(SDNode *N);
    SDValue visitLOAD(SDNode *N);

    SDValue replaceStoreChain(StoreSDNode *ST, SDValue BetterChain);
    SDValue replaceStoreOfFPConstant(StoreSDNode *ST);

    SDValue visitSTORE(SDNode *N);
    SDValue visitLIFETIME_END(SDNode *N);
    SDValue visitINSERT_VECTOR_ELT(SDNode *N);
    SDValue visitEXTRACT_VECTOR_ELT(SDNode *N);
    SDValue visitBUILD_VECTOR(SDNode *N);
    SDValue visitCONCAT_VECTORS(SDNode *N);
    SDValue visitEXTRACT_SUBVECTOR(SDNode *N);
    SDValue visitVECTOR_SHUFFLE(SDNode *N);
    SDValue visitSCALAR_TO_VECTOR(SDNode *N);
    SDValue visitINSERT_SUBVECTOR(SDNode *N);
    SDValue visitMLOAD(SDNode *N);
    SDValue visitMSTORE(SDNode *N);
    SDValue visitMGATHER(SDNode *N);
    SDValue visitMSCATTER(SDNode *N);
    SDValue visitFP_TO_FP16(SDNode *N);
    SDValue visitFP16_TO_FP(SDNode *N);
    SDValue visitVECREDUCE(SDNode *N);

    SDValue visitFADDForFMACombine(SDNode *N);
    SDValue visitFSUBForFMACombine(SDNode *N);
    SDValue visitFMULForFMADistributiveCombine(SDNode *N);

    SDValue XformToShuffleWithZero(SDNode *N);
    bool reassociationCanBreakAddressingModePattern(unsigned Opc,
                                                    const SDLoc &DL, SDValue N0,
                                                    SDValue N1);
    SDValue reassociateOpsCommutative(unsigned Opc, const SDLoc &DL, SDValue N0,
                                      SDValue N1);
    SDValue reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
                           SDValue N1, SDNodeFlags Flags);

    SDValue visitShiftByConstant(SDNode *N);

    SDValue foldSelectOfConstants(SDNode *N);
    SDValue foldVSelectOfConstants(SDNode *N);
    SDValue foldBinOpIntoSelect(SDNode *BO);
    bool SimplifySelectOps(SDNode *SELECT, SDValue LHS, SDValue RHS);
    SDValue hoistLogicOpWithSameOpcodeHands(SDNode *N);
    SDValue SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2);
    SDValue SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
                             SDValue N2, SDValue N3, ISD::CondCode CC,
                             bool NotExtCompare = false);
    SDValue convertSelectOfFPConstantsToLoadOffset(
        const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
        ISD::CondCode CC);
    SDValue foldSignChangeInBitcast(SDNode *N);
    SDValue foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0, SDValue N1,
                                   SDValue N2, SDValue N3, ISD::CondCode CC);
    SDValue foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
                              const SDLoc &DL);
    SDValue unfoldMaskedMerge(SDNode *N);
    SDValue unfoldExtremeBitClearingToShifts(SDNode *N);
    SDValue SimplifySetCC(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
                          const SDLoc &DL, bool foldBooleans);
    SDValue rebuildSetCC(SDValue N);

    bool isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
                           SDValue &CC, bool MatchStrict = false) const;
    bool isOneUseSetCC(SDValue N) const;

    SDValue SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
                                       unsigned HiOp);
    SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
    SDValue CombineExtLoad(SDNode *N);
    SDValue CombineZExtLogicopShiftLoad(SDNode *N);
    SDValue combineRepeatedFPDivisors(SDNode *N);
    SDValue combineInsertEltToShuffle(SDNode *N, unsigned InsIndex);
    SDValue ConstantFoldBITCASTofBUILD_VECTOR(SDNode *, EVT);
    SDValue BuildSDIV(SDNode *N);
    SDValue BuildSDIVPow2(SDNode *N);
    SDValue BuildUDIV(SDNode *N);
    SDValue BuildLogBase2(SDValue V, const SDLoc &DL);
    SDValue BuildDivEstimate(SDValue N, SDValue Op, SDNodeFlags Flags);
    SDValue buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags);
    SDValue buildSqrtEstimate(SDValue Op, SDNodeFlags Flags);
    SDValue buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags, bool Recip);
    SDValue buildSqrtNROneConst(SDValue Arg, SDValue Est, unsigned Iterations,
                                SDNodeFlags Flags, bool Reciprocal);
    SDValue buildSqrtNRTwoConst(SDValue Arg, SDValue Est, unsigned Iterations,
                                SDNodeFlags Flags, bool Reciprocal);
    SDValue MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
                               bool DemandHighBits = true);
    SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
    SDValue MatchRotatePosNeg(SDValue Shifted, SDValue Pos, SDValue Neg,
                              SDValue InnerPos, SDValue InnerNeg,
                              unsigned PosOpcode, unsigned NegOpcode,
                              const SDLoc &DL);
    SDValue MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos, SDValue Neg,
                              SDValue InnerPos, SDValue InnerNeg,
                              unsigned PosOpcode, unsigned NegOpcode,
                              const SDLoc &DL);
    SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
    SDValue MatchLoadCombine(SDNode *N);
    SDValue mergeTruncStores(StoreSDNode *N);
    SDValue ReduceLoadWidth(SDNode *N);
    SDValue ReduceLoadOpStoreWidth(SDNode *N);
    SDValue splitMergedValStore(StoreSDNode *ST);
    SDValue TransformFPLoadStorePair(SDNode *N);
    SDValue convertBuildVecZextToZext(SDNode *N);
    SDValue reduceBuildVecExtToExtBuildVec(SDNode *N);
    SDValue reduceBuildVecTruncToBitCast(SDNode *N);
    SDValue reduceBuildVecToShuffle(SDNode *N);
    SDValue createBuildVecShuffle(const SDLoc &DL, SDNode *N,
                                  ArrayRef<int> VectorMask, SDValue VecIn1,
                                  SDValue VecIn2, unsigned LeftIdx,
                                  bool DidSplitVec);
    SDValue matchVSelectOpSizesWithSetCC(SDNode *Cast);

    /// Walk up chain skipping non-aliasing memory nodes,
    /// looking for aliasing nodes and adding them to the Aliases vector.
    void GatherAllAliases(SDNode *N, SDValue OriginalChain,
                          SmallVectorImpl<SDValue> &Aliases);

    /// Return true if there is any possibility that the two addresses overlap.
    bool isAlias(SDNode *Op0, SDNode *Op1) const;

    /// Walk up chain skipping non-aliasing memory nodes, looking for a better
    /// chain (aliasing node.)
    SDValue FindBetterChain(SDNode *N, SDValue Chain);

    /// Try to replace a store and any possibly adjacent stores on
    /// consecutive chains with better chains. Return true only if St is
    /// replaced.
    ///
    /// Notice that other chains may still be replaced even if the function
    /// returns false.
    bool findBetterNeighborChains(StoreSDNode *St);

    // Helper for findBetterNeighborChains. Walk up the store chain, adding
    // additional chained stores that do not overlap and can be parallelized.
    bool parallelizeChainedStores(StoreSDNode *St);

    /// Holds a pointer to an LSBaseSDNode as well as information on where it
    /// is located in a sequence of memory operations connected by a chain.
    struct MemOpLink {
      // Ptr to the mem node.
      LSBaseSDNode *MemNode;

      // Offset from the base ptr.
      int64_t OffsetFromBase;

      MemOpLink(LSBaseSDNode *N, int64_t Offset)
          : MemNode(N), OffsetFromBase(Offset) {}
    };

    // Classify the origin of a stored value.
    enum class StoreSource { Unknown, Constant, Extract, Load };
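    /// Map a stored value to its StoreSource category based on its opcode;
    /// e.g. a stored ISD::ConstantFP classifies as StoreSource::Constant.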
    StoreSource getStoreSource(SDValue StoreVal) {
      switch (StoreVal.getOpcode()) {
      case ISD::Constant:
      case ISD::ConstantFP:
        return StoreSource::Constant;
      case ISD::EXTRACT_VECTOR_ELT:
      case ISD::EXTRACT_SUBVECTOR:
        return StoreSource::Extract;
      case ISD::LOAD:
        return StoreSource::Load;
      default:
        return StoreSource::Unknown;
      }
    }

    /// This is a helper function for visitMUL to check the profitability
    /// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
    /// MulNode is the original multiply, AddNode is (add x, c1),
    /// and ConstNode is c2.
    bool isMulAddWithConstProfitable(SDNode *MulNode,
                                     SDValue &AddNode,
                                     SDValue &ConstNode);

    /// This is a helper function for visitAND and visitZERO_EXTEND.  Returns
    /// true if the (and (load x) c) pattern matches an extload.  ExtVT returns
    /// the type of the loaded value to be extended.
    bool isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
                          EVT LoadResultTy, EVT &ExtVT);

    /// Helper function to calculate whether the given Load/Store can have its
    /// width reduced to MemVT.
    bool isLegalNarrowLdSt(LSBaseSDNode *LDSTN, ISD::LoadExtType ExtType,
                           EVT &MemVT, unsigned ShAmt = 0);

    /// Used by BackwardsPropagateMask to find suitable loads.
    bool SearchForAndLoads(SDNode *N, SmallVectorImpl<LoadSDNode*> &Loads,
                           SmallPtrSetImpl<SDNode*> &NodesWithConsts,
                           ConstantSDNode *Mask, SDNode *&NodeToMask);
    /// Attempt to propagate a given AND node back to load leaves so that they
    /// can be combined into narrow loads.
    bool BackwardsPropagateMask(SDNode *N);

    /// Helper function for mergeConsecutiveStores which merges the component
    /// store chains.
    SDValue getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
                                unsigned NumStores);

    /// This is a helper function for mergeConsecutiveStores. When the source
    /// elements of the consecutive stores are all constants or all extracted
    /// vector elements, try to merge them into one larger store introducing
    /// bitcasts if necessary.  \return True if a merged store was created.
    bool mergeStoresOfConstantsOrVecElts(SmallVectorImpl<MemOpLink> &StoreNodes,
                                         EVT MemVT, unsigned NumStores,
                                         bool IsConstantSrc, bool UseVector,
                                         bool UseTrunc);

    /// This is a helper function for mergeConsecutiveStores. Stores that
    /// potentially may be merged with St are placed in StoreNodes. RootNode is
    /// a chain predecessor to all store candidates.
    void getStoreMergeCandidates(StoreSDNode *St,
                                 SmallVectorImpl<MemOpLink> &StoreNodes,
                                 SDNode *&Root);

    /// Helper function for mergeConsecutiveStores. Checks if candidate stores
    /// have indirect dependency through their operands. RootNode is the
    /// predecessor to all stores calculated by getStoreMergeCandidates and is
    /// used to prune the dependency check. \return True if safe to merge.
    bool checkMergeStoreCandidatesForDependencies(
        SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
        SDNode *RootNode);

    /// This is a helper function for mergeConsecutiveStores. Given a list of
    /// store candidates, find the first N that are consecutive in memory.
    /// Returns 0 if there are not at least 2 consecutive stores to try merging.
    unsigned getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
                                  int64_t ElementSizeBytes) const;

    /// This is a helper function for mergeConsecutiveStores. It is used for
    /// store chains that are composed entirely of constant values.
    bool tryStoreMergeOfConstants(SmallVectorImpl<MemOpLink> &StoreNodes,
                                  unsigned NumConsecutiveStores,
                                  EVT MemVT, SDNode *Root, bool AllowVectors);

    /// This is a helper function for mergeConsecutiveStores. It is used for
    /// store chains that are composed entirely of extracted vector elements.
    /// When extracting multiple vector elements, try to store them in one
    /// vector store rather than a sequence of scalar stores.
    bool tryStoreMergeOfExtracts(SmallVectorImpl<MemOpLink> &StoreNodes,
                                 unsigned NumConsecutiveStores, EVT MemVT,
                                 SDNode *Root);

    /// This is a helper function for mergeConsecutiveStores. It is used for
    /// store chains that are composed entirely of loaded values.
    bool tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
                              unsigned NumConsecutiveStores, EVT MemVT,
                              SDNode *Root, bool AllowVectors,
                              bool IsNonTemporalStore, bool IsNonTemporalLoad);

    /// Merge consecutive store operations into a wide store.
    /// This optimization uses wide integers or vectors when possible.
    /// \return true if stores were merged.
    bool mergeConsecutiveStores(StoreSDNode *St);

    /// Try to transform a truncation where C is a constant:
    ///     (trunc (and X, C)) -> (and (trunc X), (trunc C))
    ///
    /// \p N needs to be a truncation and its first operand an AND. Other
    /// requirements are checked by the function (e.g. that trunc is
    /// single-use); if they are not met, an empty SDValue is returned.
    SDValue distributeTruncateThroughAnd(SDNode *N);

    /// Helper function to determine whether the target supports operation
    /// given by \p Opcode for type \p VT, that is, whether the operation
    /// is legal or custom before legalizing operations, and whether it is
    /// legal (but not custom) after legalization.
    bool hasOperation(unsigned Opcode, EVT VT) {
      return TLI.isOperationLegalOrCustom(Opcode, VT, LegalOperations);
    }

  public:
    /// Runs the dag combiner on all nodes in the worklist.
    void Run(CombineLevel AtLevel);

    SelectionDAG &getDAG() const { return DAG; }

    /// Returns a type large enough to hold any valid shift amount - before type
    /// legalization these can be huge.
    EVT getShiftAmountTy(EVT LHSTy) {
      assert(LHSTy.isInteger() && "Shift amount is not an integer type!");
      return TLI.getShiftAmountTy(LHSTy, DAG.getDataLayout(), LegalTypes);
    }

    /// This method returns true if we are running before type legalization or
    /// if the specified VT is legal.
    bool isTypeLegal(const EVT &VT) {
      if (!LegalTypes) return true;
      return TLI.isTypeLegal(VT);
    }

    /// Convenience wrapper around TargetLowering::getSetCCResultType
    EVT getSetCCResultType(EVT VT) const {
      return TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
    }

    void ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
                         SDValue OrigLoad, SDValue ExtLoad,
                         ISD::NodeType ExtType);
  };

/// This class is a DAGUpdateListener that removes any deleted
/// nodes from the worklist.
class WorklistRemover : public SelectionDAG::DAGUpdateListener {
  DAGCombiner &DC;

public:
  explicit WorklistRemover(DAGCombiner &dc)
    : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}

  void NodeDeleted(SDNode *N, SDNode *E) override {
    DC.removeFromWorklist(N);
  }
};

class WorklistInserter : public SelectionDAG::DAGUpdateListener {
  DAGCombiner &DC;

public:
  explicit WorklistInserter(DAGCombiner &dc)
      : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}

  // FIXME: Ideally we could add N to the worklist, but this causes exponential
  //        compile time costs in large DAGs, e.g. Halide.
  void NodeInserted(SDNode *N) override { DC.ConsiderForPruning(N); }
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
//  TargetLowering::DAGCombinerInfo implementation
//===----------------------------------------------------------------------===//

void TargetLowering::DAGCombinerInfo::AddToWorklist(SDNode *N) {
  ((DAGCombiner*)DC)->AddToWorklist(N);
}

SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode *N, ArrayRef<SDValue> To, bool AddTo) {
  return ((DAGCombiner*)DC)->CombineTo(N, &To[0], To.size(), AddTo);
}

SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode *N, SDValue Res, bool AddTo) {
  return ((DAGCombiner*)DC)->CombineTo(N, Res, AddTo);
}

SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode *N, SDValue Res0, SDValue Res1, bool AddTo) {
  return ((DAGCombiner*)DC)->CombineTo(N, Res0, Res1, AddTo);
}

bool TargetLowering::DAGCombinerInfo::
recursivelyDeleteUnusedNodes(SDNode *N) {
  return ((DAGCombiner*)DC)->recursivelyDeleteUnusedNodes(N);
}

void TargetLowering::DAGCombinerInfo::
CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
  return ((DAGCombiner*)DC)->CommitTargetLoweringOpt(TLO);
}

//===----------------------------------------------------------------------===//
// Helper Functions
//===----------------------------------------------------------------------===//

void DAGCombiner::deleteAndRecombine(SDNode *N) {
  removeFromWorklist(N);

  // If the operands of this node are only used by the node, they will now be
  // dead. Make sure to re-visit them and recursively delete dead nodes.
  for (const SDValue &Op : N->ops())
    // For an operand generating multiple values, one of the values may
    // become dead allowing further simplification (e.g. split index
    // arithmetic from an indexed load).
    if (Op->hasOneUse() || Op->getNumValues() > 1)
      AddToWorklist(Op.getNode());

  DAG.DeleteNode(N);
}

// APInts must be the same size for most operations; this helper function
// zero-extends the shorter of the pair so that they match. We provide an
// Offset so that we can create bitwidths that won't overflow.
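// For example, with Offset = 0, an 8-bit LHS and a 24-bit RHS are both
// widened to 24 bits before use.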
static void zeroExtendToMatch(APInt &LHS, APInt &RHS, unsigned Offset = 0) {
  unsigned Bits = Offset + std::max(LHS.getBitWidth(), RHS.getBitWidth());
  LHS = LHS.zextOrSelf(Bits);
  RHS = RHS.zextOrSelf(Bits);
}

// Return true if this node is a setcc, or is a select_cc
// that selects between the target values used for true and false, making it
// equivalent to a setcc. Also, set the incoming LHS, RHS, and CC references to
// the appropriate nodes based on the type of node we are checking. This
// simplifies life a bit for the callers.
bool DAGCombiner::isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
                                    SDValue &CC, bool MatchStrict) const {
  if (N.getOpcode() == ISD::SETCC) {
    LHS = N.getOperand(0);
    RHS = N.getOperand(1);
    CC  = N.getOperand(2);
    return true;
  }

  if (MatchStrict &&
      (N.getOpcode() == ISD::STRICT_FSETCC ||
       N.getOpcode() == ISD::STRICT_FSETCCS)) {
    LHS = N.getOperand(1);
    RHS = N.getOperand(2);
    CC  = N.getOperand(3);
    return true;
  }

  if (N.getOpcode() != ISD::SELECT_CC ||
      !TLI.isConstTrueVal(N.getOperand(2).getNode()) ||
      !TLI.isConstFalseVal(N.getOperand(3).getNode()))
    return false;

  if (TLI.getBooleanContents(N.getValueType()) ==
      TargetLowering::UndefinedBooleanContent)
    return false;

  LHS = N.getOperand(0);
  RHS = N.getOperand(1);
  CC  = N.getOperand(4);
  return true;
}

/// Return true if this is a SetCC-equivalent operation with only one use.
/// If this is true, it allows the users to invert the operation for free when
/// it is profitable to do so.
bool DAGCombiner::isOneUseSetCC(SDValue N) const {
  SDValue N0, N1, N2;
  if (isSetCCEquivalent(N, N0, N1, N2) && N.getNode()->hasOneUse())
    return true;
  return false;
}

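// Returns true if N is a splat vector of the all-ones bitmask for ScalarTy,
// e.g. 0xFFFF when ScalarTy is i16.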
static bool isConstantSplatVectorMaskForType(SDNode *N, EVT ScalarTy) {
  if (!ScalarTy.isSimple())
    return false;

  uint64_t MaskForTy = 0ULL;
  switch (ScalarTy.getSimpleVT().SimpleTy) {
  case MVT::i8:
    MaskForTy = 0xFFULL;
    break;
  case MVT::i16:
    MaskForTy = 0xFFFFULL;
    break;
  case MVT::i32:
    MaskForTy = 0xFFFFFFFFULL;
    break;
  default:
    return false;
  }

  APInt Val;
  if (ISD::isConstantSplatVector(N, Val))
    return Val.getLimitedValue() == MaskForTy;

  return false;
}

// Determines if it is a constant integer or a splat/build vector of constant
// integers (and undefs).
// Do not permit build vector implicit truncation.
static bool isConstantOrConstantVector(SDValue N, bool NoOpaques = false) {
  if (ConstantSDNode *Const = dyn_cast<ConstantSDNode>(N))
    return !(Const->isOpaque() && NoOpaques);
  if (N.getOpcode() != ISD::BUILD_VECTOR && N.getOpcode() != ISD::SPLAT_VECTOR)
    return false;
  unsigned BitWidth = N.getScalarValueSizeInBits();
  for (const SDValue &Op : N->op_values()) {
    if (Op.isUndef())
      continue;
    ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Op);
    if (!Const || Const->getAPIntValue().getBitWidth() != BitWidth ||
        (Const->isOpaque() && NoOpaques))
      return false;
  }
  return true;
}

// Determines if a BUILD_VECTOR is composed of all-constants possibly mixed
// with undefs.
static bool isAnyConstantBuildVector(SDValue V, bool NoOpaques = false) {
  if (V.getOpcode() != ISD::BUILD_VECTOR)
    return false;
  return isConstantOrConstantVector(V, NoOpaques) ||
         ISD::isBuildVectorOfConstantFPSDNodes(V.getNode());
}

// Determine if splitting the indexing from this load is allowed, i.e. the
// option is enabled and the index is not an opaque target constant.
static bool canSplitIdx(LoadSDNode *LD) {
  return MaySplitLoadIndex &&
         (LD->getOperand(2).getOpcode() != ISD::TargetConstant ||
          !cast<ConstantSDNode>(LD->getOperand(2))->isOpaque());
}

bool DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc,
                                                             const SDLoc &DL,
                                                             SDValue N0,
                                                             SDValue N1) {
  // Currently this only tries to ensure we don't undo the GEP splits done by
  // CodeGenPrepare when shouldConsiderGEPOffsetSplit is true. To ensure this,
  // we check if the following transformation would be problematic:
  // (load/store (add, (add, x, offset1), offset2)) ->
  // (load/store (add, x, offset1+offset2)).

  if (Opc != ISD::ADD || N0.getOpcode() != ISD::ADD)
    return false;

  if (N0.hasOneUse())
    return false;

  auto *C1 = dyn_cast<ConstantSDNode>(N0.getOperand(1));
  auto *C2 = dyn_cast<ConstantSDNode>(N1);
  if (!C1 || !C2)
    return false;

  const APInt &C1APIntVal = C1->getAPIntValue();
  const APInt &C2APIntVal = C2->getAPIntValue();
  if (C1APIntVal.getBitWidth() > 64 || C2APIntVal.getBitWidth() > 64)
    return false;

  const APInt CombinedValueIntVal = C1APIntVal + C2APIntVal;
  if (CombinedValueIntVal.getBitWidth() > 64)
    return false;
  const int64_t CombinedValue = CombinedValueIntVal.getSExtValue();

  for (SDNode *Node : N0->uses()) {
    auto *LoadStore = dyn_cast<MemSDNode>(Node);
    if (LoadStore) {
      // Is x[offset2] already not a legal addressing mode? If so then
      // reassociating the constants breaks nothing (we test offset2 because
      // that's the one we hope to fold into the load or store).
      TargetLoweringBase::AddrMode AM;
      AM.HasBaseReg = true;
      AM.BaseOffs = C2APIntVal.getSExtValue();
      EVT VT = LoadStore->getMemoryVT();
      unsigned AS = LoadStore->getAddressSpace();
      Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
      if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
        continue;

      // Would x[offset1+offset2] still be a legal addressing mode?
      AM.BaseOffs = CombinedValue;
      if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
        return true;
    }
  }

  return false;
}

// Helper for DAGCombiner::reassociateOps. Try to reassociate an expression
// such as (Opc N0, N1), if \p N0 is the same kind of operation as \p Opc.
SDValue DAGCombiner::reassociateOpsCommutative(unsigned Opc, const SDLoc &DL,
                                               SDValue N0, SDValue N1) {
  EVT VT = N0.getValueType();

  if (N0.getOpcode() != Opc)
    return SDValue();

  if (DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1))) {
    if (DAG.isConstantIntBuildVectorOrConstantInt(N1)) {
      // Reassociate: (op (op x, c1), c2) -> (op x, (op c1, c2))
      if (SDValue OpNode =
              DAG.FoldConstantArithmetic(Opc, DL, VT, {N0.getOperand(1), N1}))
        return DAG.getNode(Opc, DL, VT, N0.getOperand(0), OpNode);
      return SDValue();
    }
    if (N0.hasOneUse()) {
      // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1)
      //              iff (op x, c1) has one use
      SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N0.getOperand(0), N1);
      if (!OpNode.getNode())
        return SDValue();
      return DAG.getNode(Opc, DL, VT, OpNode, N0.getOperand(1));
    }
  }
  return SDValue();
}

// Try to reassociate commutative binops.
SDValue DAGCombiner::reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
                                    SDValue N1, SDNodeFlags Flags) {
  assert(TLI.isCommutativeBinOp(Opc) && "Operation not commutative.");

  // Floating-point reassociation is not allowed without loose FP math.
  if (N0.getValueType().isFloatingPoint() ||
      N1.getValueType().isFloatingPoint())
    if (!Flags.hasAllowReassociation() || !Flags.hasNoSignedZeros())
      return SDValue();

  if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N0, N1))
    return Combined;
  if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N1, N0))
    return Combined;
  return SDValue();
}

SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
                               bool AddTo) {
  assert(N->getNumValues() == NumTo && "Broken CombineTo call!");
  ++NodesCombined;
  LLVM_DEBUG(dbgs() << "\nReplacing.1 "; N->dump(&DAG); dbgs() << "\nWith: ";
             To[0].getNode()->dump(&DAG);
             dbgs() << " and " << NumTo - 1 << " other values\n");
  for (unsigned i = 0, e = NumTo; i != e; ++i)
    assert((!To[i].getNode() ||
            N->getValueType(i) == To[i].getValueType()) &&
           "Cannot combine value to value of different type!");

  WorklistRemover DeadNodes(*this);
  DAG.ReplaceAllUsesWith(N, To);
  if (AddTo) {
    // Push the new nodes and any users onto the worklist
    for (unsigned i = 0, e = NumTo; i != e; ++i) {
      if (To[i].getNode()) {
        AddToWorklist(To[i].getNode());
        AddUsersToWorklist(To[i].getNode());
      }
    }
  }

  // Finally, if the node is now dead, remove it from the graph.  The node
  // may not be dead if the replacement process recursively simplified to
  // something else needing this node.
  if (N->use_empty())
    deleteAndRecombine(N);
  return SDValue(N, 0);
}

void DAGCombiner::
CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
  // Replace the old value with the new one.
  ++NodesCombined;
  LLVM_DEBUG(dbgs() << "\nReplacing.2 "; TLO.Old.getNode()->dump(&DAG);
             dbgs() << "\nWith: "; TLO.New.getNode()->dump(&DAG);
             dbgs() << '\n');

  // Replace all uses.  If any nodes become isomorphic to other nodes and
  // are deleted, make sure to remove them from our worklist.
  WorklistRemover DeadNodes(*this);
  DAG.ReplaceAllUsesOfValueWith(TLO.Old, TLO.New);

  // Push the new node and any (possibly new) users onto the worklist.
  AddToWorklistWithUsers(TLO.New.getNode());

  // Finally, if the node is now dead, remove it from the graph.  The node
  // may not be dead if the replacement process recursively simplified to
  // something else needing this node.
  if (TLO.Old.getNode()->use_empty())
    deleteAndRecombine(TLO.Old.getNode());
}

/// Check the specified integer node value to see if it can be simplified or if
/// things it uses can be simplified by bit propagation. If so, return true.
bool DAGCombiner::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
                                       const APInt &DemandedElts,
                                       bool AssumeSingleUse) {
  TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
  KnownBits Known;
  if (!TLI.SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO, 0,
                                AssumeSingleUse))
    return false;

  // Revisit the node.
  AddToWorklist(Op.getNode());

  CommitTargetLoweringOpt(TLO);
  return true;
}

/// Check the specified vector node value to see if it can be simplified or
/// if things it uses can be simplified as it only uses some of the elements.
/// If so, return true.
bool DAGCombiner::SimplifyDemandedVectorElts(SDValue Op,
                                             const APInt &DemandedElts,
                                             bool AssumeSingleUse) {
  TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
  APInt KnownUndef, KnownZero;
  if (!TLI.SimplifyDemandedVectorElts(Op, DemandedElts, KnownUndef, KnownZero,
                                      TLO, 0, AssumeSingleUse))
    return false;

  // Revisit the node.
  AddToWorklist(Op.getNode());

  CommitTargetLoweringOpt(TLO);
  return true;
}

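// Replace all uses of Load's value with a truncate of ExtLoad's wider value
// and all uses of its chain with ExtLoad's chain, then delete the old load.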
void DAGCombiner::ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad) {
  SDLoc DL(Load);
  EVT VT = Load->getValueType(0);
  SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, VT, SDValue(ExtLoad, 0));

  LLVM_DEBUG(dbgs() << "\nReplacing.9 "; Load->dump(&DAG); dbgs() << "\nWith: ";
             Trunc.getNode()->dump(&DAG); dbgs() << '\n');
  WorklistRemover DeadNodes(*this);
  DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), Trunc);
  DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), SDValue(ExtLoad, 1));
  deleteAndRecombine(Load);
  AddToWorklist(Trunc.getNode());
}

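// Widen Op to the promoted type PVT. Unindexed loads are rebuilt as extending
// loads (setting Replace so the caller substitutes the new load); assertion
// nodes and constants are extended explicitly, anything else is any-extended.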
SDValue DAGCombiner::PromoteOperand(SDValue Op, EVT PVT, bool &Replace) {
  Replace = false;
  SDLoc DL(Op);
  if (ISD::isUNINDEXEDLoad(Op.getNode())) {
    LoadSDNode *LD = cast<LoadSDNode>(Op);
    EVT MemVT = LD->getMemoryVT();
    ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) ? ISD::EXTLOAD
                                                      : LD->getExtensionType();
    Replace = true;
    return DAG.getExtLoad(ExtType, DL, PVT,
                          LD->getChain(), LD->getBasePtr(),
                          MemVT, LD->getMemOperand());
  }

  unsigned Opc = Op.getOpcode();
  switch (Opc) {
  default: break;
  case ISD::AssertSext:
    if (SDValue Op0 = SExtPromoteOperand(Op.getOperand(0), PVT))
      return DAG.getNode(ISD::AssertSext, DL, PVT, Op0, Op.getOperand(1));
    break;
  case ISD::AssertZext:
    if (SDValue Op0 = ZExtPromoteOperand(Op.getOperand(0), PVT))
      return DAG.getNode(ISD::AssertZext, DL, PVT, Op0, Op.getOperand(1));
    break;
  case ISD::Constant: {
    unsigned ExtOpc =
      Op.getValueType().isByteSized() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
    return DAG.getNode(ExtOpc, DL, PVT, Op);
  }
  }

  if (!TLI.isOperationLegal(ISD::ANY_EXTEND, PVT))
    return SDValue();
  return DAG.getNode(ISD::ANY_EXTEND, DL, PVT, Op);
}

1244 SDValue DAGCombiner::SExtPromoteOperand(SDValue Op, EVT PVT) {
1245   if (!TLI.isOperationLegal(ISD::SIGN_EXTEND_INREG, PVT))
1246     return SDValue();
1247   EVT OldVT = Op.getValueType();
1248   SDLoc DL(Op);
1249   bool Replace = false;
1250   SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1251   if (!NewOp.getNode())
1252     return SDValue();
1253   AddToWorklist(NewOp.getNode());
1254 
1255   if (Replace)
1256     ReplaceLoadWithPromotedLoad(Op.getNode(), NewOp.getNode());
1257   return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, NewOp.getValueType(), NewOp,
1258                      DAG.getValueType(OldVT));
1259 }
1260 
1261 SDValue DAGCombiner::ZExtPromoteOperand(SDValue Op, EVT PVT) {
1262   EVT OldVT = Op.getValueType();
1263   SDLoc DL(Op);
1264   bool Replace = false;
1265   SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1266   if (!NewOp.getNode())
1267     return SDValue();
1268   AddToWorklist(NewOp.getNode());
1269 
1270   if (Replace)
1271     ReplaceLoadWithPromotedLoad(Op.getNode(), NewOp.getNode());
1272   return DAG.getZeroExtendInReg(NewOp, DL, OldVT);
1273 }
1274 
1275 /// Promote the specified integer binary operation if the target indicates it is
1276 /// beneficial. e.g. On x86, it's usually better to promote i16 operations to
1277 /// i32 since i16 instructions are longer.
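     /// A promoted (add i16 x, y) becomes
     /// (trunc (add i32 (anyext x), (anyext y))) when the target reports i32
     /// as the desirable type.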
1278 SDValue DAGCombiner::PromoteIntBinOp(SDValue Op) {
1279   if (!LegalOperations)
1280     return SDValue();
1281 
1282   EVT VT = Op.getValueType();
1283   if (VT.isVector() || !VT.isInteger())
1284     return SDValue();
1285 
1286   // If operation type is 'undesirable', e.g. i16 on x86, consider
1287   // promoting it.
1288   unsigned Opc = Op.getOpcode();
1289   if (TLI.isTypeDesirableForOp(Opc, VT))
1290     return SDValue();
1291 
1292   EVT PVT = VT;
1293   // Consult target whether it is a good idea to promote this operation and
1294   // what's the right type to promote it to.
1295   if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1296     assert(PVT != VT && "Don't know what type to promote to!");
1297 
1298     LLVM_DEBUG(dbgs() << "\nPromoting "; Op.getNode()->dump(&DAG));
1299 
1300     bool Replace0 = false;
1301     SDValue N0 = Op.getOperand(0);
1302     SDValue NN0 = PromoteOperand(N0, PVT, Replace0);
1303 
1304     bool Replace1 = false;
1305     SDValue N1 = Op.getOperand(1);
1306     SDValue NN1 = PromoteOperand(N1, PVT, Replace1);
1307     SDLoc DL(Op);
1308 
1309     SDValue RV =
1310         DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, NN0, NN1));
1311 
1312     // We are always replacing N0/N1's use in N and only need additional
1313     // replacements if there are additional uses.
1314     // Note: We are checking uses of the *nodes* (SDNode) rather than values
1315     //       (SDValue) here because the node may reference multiple values
1316     //       (for example, the chain value of a load node).
1317     Replace0 &= !N0->hasOneUse();
1318     Replace1 &= (N0 != N1) && !N1->hasOneUse();
1319 
1320     // Combine Op here so it is preserved past replacements.
1321     CombineTo(Op.getNode(), RV);
1322 
1323     // If operands have a use ordering, make sure we deal with
1324     // predecessor first.
1325     if (Replace0 && Replace1 && N0.getNode()->isPredecessorOf(N1.getNode())) {
1326       std::swap(N0, N1);
1327       std::swap(NN0, NN1);
1328     }
1329 
1330     if (Replace0) {
1331       AddToWorklist(NN0.getNode());
1332       ReplaceLoadWithPromotedLoad(N0.getNode(), NN0.getNode());
1333     }
1334     if (Replace1) {
1335       AddToWorklist(NN1.getNode());
1336       ReplaceLoadWithPromotedLoad(N1.getNode(), NN1.getNode());
1337     }
1338     return Op;
1339   }
1340   return SDValue();
1341 }
1342 
1343 /// Promote the specified integer shift operation if the target indicates it is
1344 /// beneficial. e.g. On x86, it's usually better to promote i16 operations to
1345 /// i32 since i16 instructions are longer.
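     /// The extension of the shifted value must match the shift kind: SRA
     /// sign-extends it, SRL zero-extends it, and SHL can use any extension,
     /// e.g. (srl i16 x, c) -> (trunc (srl i32 (zext x), c)).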
1346 SDValue DAGCombiner::PromoteIntShiftOp(SDValue Op) {
1347   if (!LegalOperations)
1348     return SDValue();
1349 
1350   EVT VT = Op.getValueType();
1351   if (VT.isVector() || !VT.isInteger())
1352     return SDValue();
1353 
1354   // If operation type is 'undesirable', e.g. i16 on x86, consider
1355   // promoting it.
1356   unsigned Opc = Op.getOpcode();
1357   if (TLI.isTypeDesirableForOp(Opc, VT))
1358     return SDValue();
1359 
1360   EVT PVT = VT;
1361   // Consult target whether it is a good idea to promote this operation and
1362   // what's the right type to promote it to.
1363   if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1364     assert(PVT != VT && "Don't know what type to promote to!");
1365 
1366     LLVM_DEBUG(dbgs() << "\nPromoting "; Op.getNode()->dump(&DAG));
1367 
1368     bool Replace = false;
1369     SDValue N0 = Op.getOperand(0);
1370     SDValue N1 = Op.getOperand(1);
1371     if (Opc == ISD::SRA)
1372       N0 = SExtPromoteOperand(N0, PVT);
1373     else if (Opc == ISD::SRL)
1374       N0 = ZExtPromoteOperand(N0, PVT);
1375     else
1376       N0 = PromoteOperand(N0, PVT, Replace);
1377 
1378     if (!N0.getNode())
1379       return SDValue();
1380 
1381     SDLoc DL(Op);
1382     SDValue RV =
1383         DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, N0, N1));
1384 
1385     if (Replace)
1386       ReplaceLoadWithPromotedLoad(Op.getOperand(0).getNode(), N0.getNode());
1387 
1388     // Deal with Op being deleted.
1389     if (Op && Op.getOpcode() != ISD::DELETED_NODE)
1390       return RV;
1391   }
1392   return SDValue();
1393 }
1394 
1395 SDValue DAGCombiner::PromoteExtend(SDValue Op) {
1396   if (!LegalOperations)
1397     return SDValue();
1398 
1399   EVT VT = Op.getValueType();
1400   if (VT.isVector() || !VT.isInteger())
1401     return SDValue();
1402 
1403   // If operation type is 'undesirable', e.g. i16 on x86, consider
1404   // promoting it.
1405   unsigned Opc = Op.getOpcode();
1406   if (TLI.isTypeDesirableForOp(Opc, VT))
1407     return SDValue();
1408 
1409   EVT PVT = VT;
1410   // Consult target whether it is a good idea to promote this operation and
1411   // what's the right type to promote it to.
1412   if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1413     assert(PVT != VT && "Don't know what type to promote to!");
1414     // fold (aext (aext x)) -> (aext x)
1415     // fold (aext (zext x)) -> (zext x)
1416     // fold (aext (sext x)) -> (sext x)
1417     LLVM_DEBUG(dbgs() << "\nPromoting "; Op.getNode()->dump(&DAG));
1418     return DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, Op.getOperand(0));
1419   }
1420   return SDValue();
1421 }
1422 
1423 bool DAGCombiner::PromoteLoad(SDValue Op) {
1424   if (!LegalOperations)
1425     return false;
1426 
1427   if (!ISD::isUNINDEXEDLoad(Op.getNode()))
1428     return false;
1429 
1430   EVT VT = Op.getValueType();
1431   if (VT.isVector() || !VT.isInteger())
1432     return false;
1433 
1434   // If operation type is 'undesirable', e.g. i16 on x86, consider
1435   // promoting it.
1436   unsigned Opc = Op.getOpcode();
1437   if (TLI.isTypeDesirableForOp(Opc, VT))
1438     return false;
1439 
1440   EVT PVT = VT;
1441   // Consult target whether it is a good idea to promote this operation and
1442   // what's the right type to promote it to.
1443   if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1444     assert(PVT != VT && "Don't know what type to promote to!");
1445 
1446     SDLoc DL(Op);
1447     SDNode *N = Op.getNode();
1448     LoadSDNode *LD = cast<LoadSDNode>(N);
1449     EVT MemVT = LD->getMemoryVT();
1450     ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) ? ISD::EXTLOAD
1451                                                       : LD->getExtensionType();
1452     SDValue NewLD = DAG.getExtLoad(ExtType, DL, PVT,
1453                                    LD->getChain(), LD->getBasePtr(),
1454                                    MemVT, LD->getMemOperand());
1455     SDValue Result = DAG.getNode(ISD::TRUNCATE, DL, VT, NewLD);
1456 
1457     LLVM_DEBUG(dbgs() << "\nPromoting "; N->dump(&DAG); dbgs() << "\nTo: ";
1458                Result.getNode()->dump(&DAG); dbgs() << '\n');
1459     WorklistRemover DeadNodes(*this);
1460     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result);
1461     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), NewLD.getValue(1));
1462     deleteAndRecombine(N);
1463     AddToWorklist(Result.getNode());
1464     return true;
1465   }
1466   return false;
1467 }
1468 
1469 /// Recursively delete a node which has no uses and any operands for
1470 /// which it is the only use.
1471 ///
1472 /// Note that this both deletes the nodes and removes them from the worklist.
1473 /// It also adds any nodes that have had a user deleted to the worklist, as they
1474 /// may now have only one use and be subject to other combines.
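     /// For example, deleting a dead (add x, y) may leave x with a single
     /// remaining use, unblocking combines that require a one-use operand.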
1475 bool DAGCombiner::recursivelyDeleteUnusedNodes(SDNode *N) {
1476   if (!N->use_empty())
1477     return false;
1478 
1479   SmallSetVector<SDNode *, 16> Nodes;
1480   Nodes.insert(N);
1481   do {
1482     N = Nodes.pop_back_val();
1483     if (!N)
1484       continue;
1485 
1486     if (N->use_empty()) {
1487       for (const SDValue &ChildN : N->op_values())
1488         Nodes.insert(ChildN.getNode());
1489 
1490       removeFromWorklist(N);
1491       DAG.DeleteNode(N);
1492     } else {
1493       AddToWorklist(N);
1494     }
1495   } while (!Nodes.empty());
1496   return true;
1497 }
1498 
1499 //===----------------------------------------------------------------------===//
1500 //  Main DAG Combiner implementation
1501 //===----------------------------------------------------------------------===//
1502 
1503 void DAGCombiner::Run(CombineLevel AtLevel) {
1504   // Set the instance variables so that the various visit routines may use them.
1505   Level = AtLevel;
1506   LegalDAG = Level >= AfterLegalizeDAG;
1507   LegalOperations = Level >= AfterLegalizeVectorOps;
1508   LegalTypes = Level >= AfterLegalizeTypes;
1509 
1510   WorklistInserter AddNodes(*this);
1511 
1512   // Add all the dag nodes to the worklist.
1513   for (SDNode &Node : DAG.allnodes())
1514     AddToWorklist(&Node);
1515 
1516   // Create a dummy node (which is not added to allnodes), that adds a reference
1517   // to the root node, preventing it from being deleted, and tracking any
1518   // changes of the root.
1519   HandleSDNode Dummy(DAG.getRoot());
1520 
1521   // While we have a valid worklist entry node, try to combine it.
1522   while (SDNode *N = getNextWorklistEntry()) {
1523     // If N has no uses, it is dead.  Make sure to revisit all N's operands once
1524     // N is deleted from the DAG, since they too may now be dead or may have a
1525     // reduced number of uses, allowing other xforms.
1526     if (recursivelyDeleteUnusedNodes(N))
1527       continue;
1528 
1529     WorklistRemover DeadNodes(*this);
1530 
1531     // If this combine is running after legalizing the DAG, re-legalize any
1532     // nodes pulled off the worklist.
1533     if (LegalDAG) {
1534       SmallSetVector<SDNode *, 16> UpdatedNodes;
1535       bool NIsValid = DAG.LegalizeOp(N, UpdatedNodes);
1536 
1537       for (SDNode *LN : UpdatedNodes)
1538         AddToWorklistWithUsers(LN);
1539 
1540       if (!NIsValid)
1541         continue;
1542     }
1543 
1544     LLVM_DEBUG(dbgs() << "\nCombining: "; N->dump(&DAG));
1545 
1546     // Add any operands of the new node which have not yet been combined to the
1547     // worklist as well. Because the worklist uniques things already, this
1548     // won't repeatedly process the same operand.
1549     CombinedNodes.insert(N);
1550     for (const SDValue &ChildN : N->op_values())
1551       if (!CombinedNodes.count(ChildN.getNode()))
1552         AddToWorklist(ChildN.getNode());
1553 
1554     SDValue RV = combine(N);
1555 
1556     if (!RV.getNode())
1557       continue;
1558 
1559     ++NodesCombined;
1560 
1561     // If we get back the same node we passed in, rather than a new node or
1562     // zero, we know that the node must have defined multiple values and
1563     // CombineTo was used.  Since CombineTo takes care of the worklist
1564     // mechanics for us, we have no work to do in this case.
1565     if (RV.getNode() == N)
1566       continue;
1567 
1568     assert(N->getOpcode() != ISD::DELETED_NODE &&
1569            RV.getOpcode() != ISD::DELETED_NODE &&
1570            "Node was deleted but visit returned new node!");
1571 
1572     LLVM_DEBUG(dbgs() << " ... into: "; RV.getNode()->dump(&DAG));
1573 
1574     if (N->getNumValues() == RV.getNode()->getNumValues())
1575       DAG.ReplaceAllUsesWith(N, RV.getNode());
1576     else {
1577       assert(N->getValueType(0) == RV.getValueType() &&
1578              N->getNumValues() == 1 && "Type mismatch");
1579       DAG.ReplaceAllUsesWith(N, &RV);
1580     }
1581 
1582     // Push the new node and any users onto the worklist.  Omit this if the
1583     // new node is the EntryToken (e.g. if a store managed to get optimized
1584     // out), because re-visiting the EntryToken and its users will not uncover
1585     // any additional opportunities, but there may be a large number of such
1586     // users, potentially causing compile time explosion.
1587     if (RV.getOpcode() != ISD::EntryToken) {
1588       AddToWorklist(RV.getNode());
1589       AddUsersToWorklist(RV.getNode());
1590     }
1591 
1592     // Finally, if the node is now dead, remove it from the graph.  The node
1593     // may not be dead if the replacement process recursively simplified to
1594     // something else needing this node. This will also take care of adding any
1595     // operands which have lost a user to the worklist.
1596     recursivelyDeleteUnusedNodes(N);
1597   }
1598 
1599   // If the root changed (e.g. it was a dead load), update the root.
1600   DAG.setRoot(Dummy.getValue());
1601   DAG.RemoveDeadNodes();
1602 }
1603 
1604 SDValue DAGCombiner::visit(SDNode *N) {
1605   switch (N->getOpcode()) {
1606   default: break;
1607   case ISD::TokenFactor:        return visitTokenFactor(N);
1608   case ISD::MERGE_VALUES:       return visitMERGE_VALUES(N);
1609   case ISD::ADD:                return visitADD(N);
1610   case ISD::SUB:                return visitSUB(N);
1611   case ISD::SADDSAT:
1612   case ISD::UADDSAT:            return visitADDSAT(N);
1613   case ISD::SSUBSAT:
1614   case ISD::USUBSAT:            return visitSUBSAT(N);
1615   case ISD::ADDC:               return visitADDC(N);
1616   case ISD::SADDO:
1617   case ISD::UADDO:              return visitADDO(N);
1618   case ISD::SUBC:               return visitSUBC(N);
1619   case ISD::SSUBO:
1620   case ISD::USUBO:              return visitSUBO(N);
1621   case ISD::ADDE:               return visitADDE(N);
1622   case ISD::ADDCARRY:           return visitADDCARRY(N);
1623   case ISD::SADDO_CARRY:        return visitSADDO_CARRY(N);
1624   case ISD::SUBE:               return visitSUBE(N);
1625   case ISD::SUBCARRY:           return visitSUBCARRY(N);
1626   case ISD::SSUBO_CARRY:        return visitSSUBO_CARRY(N);
1627   case ISD::SMULFIX:
1628   case ISD::SMULFIXSAT:
1629   case ISD::UMULFIX:
1630   case ISD::UMULFIXSAT:         return visitMULFIX(N);
1631   case ISD::MUL:                return visitMUL(N);
1632   case ISD::SDIV:               return visitSDIV(N);
1633   case ISD::UDIV:               return visitUDIV(N);
1634   case ISD::SREM:
1635   case ISD::UREM:               return visitREM(N);
1636   case ISD::MULHU:              return visitMULHU(N);
1637   case ISD::MULHS:              return visitMULHS(N);
1638   case ISD::SMUL_LOHI:          return visitSMUL_LOHI(N);
1639   case ISD::UMUL_LOHI:          return visitUMUL_LOHI(N);
1640   case ISD::SMULO:
1641   case ISD::UMULO:              return visitMULO(N);
1642   case ISD::SMIN:
1643   case ISD::SMAX:
1644   case ISD::UMIN:
1645   case ISD::UMAX:               return visitIMINMAX(N);
1646   case ISD::AND:                return visitAND(N);
1647   case ISD::OR:                 return visitOR(N);
1648   case ISD::XOR:                return visitXOR(N);
1649   case ISD::SHL:                return visitSHL(N);
1650   case ISD::SRA:                return visitSRA(N);
1651   case ISD::SRL:                return visitSRL(N);
1652   case ISD::ROTR:
1653   case ISD::ROTL:               return visitRotate(N);
1654   case ISD::FSHL:
1655   case ISD::FSHR:               return visitFunnelShift(N);
1656   case ISD::ABS:                return visitABS(N);
1657   case ISD::BSWAP:              return visitBSWAP(N);
1658   case ISD::BITREVERSE:         return visitBITREVERSE(N);
1659   case ISD::CTLZ:               return visitCTLZ(N);
1660   case ISD::CTLZ_ZERO_UNDEF:    return visitCTLZ_ZERO_UNDEF(N);
1661   case ISD::CTTZ:               return visitCTTZ(N);
1662   case ISD::CTTZ_ZERO_UNDEF:    return visitCTTZ_ZERO_UNDEF(N);
1663   case ISD::CTPOP:              return visitCTPOP(N);
1664   case ISD::SELECT:             return visitSELECT(N);
1665   case ISD::VSELECT:            return visitVSELECT(N);
1666   case ISD::SELECT_CC:          return visitSELECT_CC(N);
1667   case ISD::SETCC:              return visitSETCC(N);
1668   case ISD::SETCCCARRY:         return visitSETCCCARRY(N);
1669   case ISD::SIGN_EXTEND:        return visitSIGN_EXTEND(N);
1670   case ISD::ZERO_EXTEND:        return visitZERO_EXTEND(N);
1671   case ISD::ANY_EXTEND:         return visitANY_EXTEND(N);
1672   case ISD::AssertSext:
1673   case ISD::AssertZext:         return visitAssertExt(N);
1674   case ISD::AssertAlign:        return visitAssertAlign(N);
1675   case ISD::SIGN_EXTEND_INREG:  return visitSIGN_EXTEND_INREG(N);
1676   case ISD::SIGN_EXTEND_VECTOR_INREG: return visitSIGN_EXTEND_VECTOR_INREG(N);
1677   case ISD::ZERO_EXTEND_VECTOR_INREG: return visitZERO_EXTEND_VECTOR_INREG(N);
1678   case ISD::TRUNCATE:           return visitTRUNCATE(N);
1679   case ISD::BITCAST:            return visitBITCAST(N);
1680   case ISD::BUILD_PAIR:         return visitBUILD_PAIR(N);
1681   case ISD::FADD:               return visitFADD(N);
1682   case ISD::STRICT_FADD:        return visitSTRICT_FADD(N);
1683   case ISD::FSUB:               return visitFSUB(N);
1684   case ISD::FMUL:               return visitFMUL(N);
1685   case ISD::FMA:                return visitFMA(N);
1686   case ISD::FDIV:               return visitFDIV(N);
1687   case ISD::FREM:               return visitFREM(N);
1688   case ISD::FSQRT:              return visitFSQRT(N);
1689   case ISD::FCOPYSIGN:          return visitFCOPYSIGN(N);
1690   case ISD::FPOW:               return visitFPOW(N);
1691   case ISD::SINT_TO_FP:         return visitSINT_TO_FP(N);
1692   case ISD::UINT_TO_FP:         return visitUINT_TO_FP(N);
1693   case ISD::FP_TO_SINT:         return visitFP_TO_SINT(N);
1694   case ISD::FP_TO_UINT:         return visitFP_TO_UINT(N);
1695   case ISD::FP_ROUND:           return visitFP_ROUND(N);
1696   case ISD::FP_EXTEND:          return visitFP_EXTEND(N);
1697   case ISD::FNEG:               return visitFNEG(N);
1698   case ISD::FABS:               return visitFABS(N);
1699   case ISD::FFLOOR:             return visitFFLOOR(N);
1700   case ISD::FMINNUM:            return visitFMINNUM(N);
1701   case ISD::FMAXNUM:            return visitFMAXNUM(N);
1702   case ISD::FMINIMUM:           return visitFMINIMUM(N);
1703   case ISD::FMAXIMUM:           return visitFMAXIMUM(N);
1704   case ISD::FCEIL:              return visitFCEIL(N);
1705   case ISD::FTRUNC:             return visitFTRUNC(N);
1706   case ISD::BRCOND:             return visitBRCOND(N);
1707   case ISD::BR_CC:              return visitBR_CC(N);
1708   case ISD::LOAD:               return visitLOAD(N);
1709   case ISD::STORE:              return visitSTORE(N);
1710   case ISD::INSERT_VECTOR_ELT:  return visitINSERT_VECTOR_ELT(N);
1711   case ISD::EXTRACT_VECTOR_ELT: return visitEXTRACT_VECTOR_ELT(N);
1712   case ISD::BUILD_VECTOR:       return visitBUILD_VECTOR(N);
1713   case ISD::CONCAT_VECTORS:     return visitCONCAT_VECTORS(N);
1714   case ISD::EXTRACT_SUBVECTOR:  return visitEXTRACT_SUBVECTOR(N);
1715   case ISD::VECTOR_SHUFFLE:     return visitVECTOR_SHUFFLE(N);
1716   case ISD::SCALAR_TO_VECTOR:   return visitSCALAR_TO_VECTOR(N);
1717   case ISD::INSERT_SUBVECTOR:   return visitINSERT_SUBVECTOR(N);
1718   case ISD::MGATHER:            return visitMGATHER(N);
1719   case ISD::MLOAD:              return visitMLOAD(N);
1720   case ISD::MSCATTER:           return visitMSCATTER(N);
1721   case ISD::MSTORE:             return visitMSTORE(N);
1722   case ISD::LIFETIME_END:       return visitLIFETIME_END(N);
1723   case ISD::FP_TO_FP16:         return visitFP_TO_FP16(N);
1724   case ISD::FP16_TO_FP:         return visitFP16_TO_FP(N);
1725   case ISD::FREEZE:             return visitFREEZE(N);
1726   case ISD::VECREDUCE_FADD:
1727   case ISD::VECREDUCE_FMUL:
1728   case ISD::VECREDUCE_ADD:
1729   case ISD::VECREDUCE_MUL:
1730   case ISD::VECREDUCE_AND:
1731   case ISD::VECREDUCE_OR:
1732   case ISD::VECREDUCE_XOR:
1733   case ISD::VECREDUCE_SMAX:
1734   case ISD::VECREDUCE_SMIN:
1735   case ISD::VECREDUCE_UMAX:
1736   case ISD::VECREDUCE_UMIN:
1737   case ISD::VECREDUCE_FMAX:
1738   case ISD::VECREDUCE_FMIN:     return visitVECREDUCE(N);
1739   }
1740   return SDValue();
1741 }
1742 
1743 SDValue DAGCombiner::combine(SDNode *N) {
1744   SDValue RV;
1745   if (!DisableGenericCombines)
1746     RV = visit(N);
1747 
1748   // If nothing happened, try a target-specific DAG combine.
1749   if (!RV.getNode()) {
1750     assert(N->getOpcode() != ISD::DELETED_NODE &&
1751            "Node was deleted but visit returned NULL!");
1752 
1753     if (N->getOpcode() >= ISD::BUILTIN_OP_END ||
1754         TLI.hasTargetDAGCombine((ISD::NodeType)N->getOpcode())) {
1755 
1756       // Expose the DAG combiner to the target combiner impls.
1757       TargetLowering::DAGCombinerInfo
1758         DagCombineInfo(DAG, Level, false, this);
1759 
1760       RV = TLI.PerformDAGCombine(N, DagCombineInfo);
1761     }
1762   }
1763 
1764   // If nothing happened still, try promoting the operation.
1765   if (!RV.getNode()) {
1766     switch (N->getOpcode()) {
1767     default: break;
1768     case ISD::ADD:
1769     case ISD::SUB:
1770     case ISD::MUL:
1771     case ISD::AND:
1772     case ISD::OR:
1773     case ISD::XOR:
1774       RV = PromoteIntBinOp(SDValue(N, 0));
1775       break;
1776     case ISD::SHL:
1777     case ISD::SRA:
1778     case ISD::SRL:
1779       RV = PromoteIntShiftOp(SDValue(N, 0));
1780       break;
1781     case ISD::SIGN_EXTEND:
1782     case ISD::ZERO_EXTEND:
1783     case ISD::ANY_EXTEND:
1784       RV = PromoteExtend(SDValue(N, 0));
1785       break;
1786     case ISD::LOAD:
1787       if (PromoteLoad(SDValue(N, 0)))
1788         RV = SDValue(N, 0);
1789       break;
1790     }
1791   }
1792 
1793   // If N is a commutative binary node, try to eliminate it if the commuted
1794   // version is already present in the DAG.
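       // For example, if both (add x, y) and (add y, x) exist in the DAG, the
       // second can be eliminated by redirecting its uses to the first.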
1795   if (!RV.getNode() && TLI.isCommutativeBinOp(N->getOpcode()) &&
1796       N->getNumValues() == 1) {
1797     SDValue N0 = N->getOperand(0);
1798     SDValue N1 = N->getOperand(1);
1799 
1800     // Constant operands are canonicalized to RHS.
1801     if (N0 != N1 && (isa<ConstantSDNode>(N0) || !isa<ConstantSDNode>(N1))) {
1802       SDValue Ops[] = {N1, N0};
1803       SDNode *CSENode = DAG.getNodeIfExists(N->getOpcode(), N->getVTList(), Ops,
1804                                             N->getFlags());
1805       if (CSENode)
1806         return SDValue(CSENode, 0);
1807     }
1808   }
1809 
1810   return RV;
1811 }
1812 
1813 /// Given a node, return its input chain if it has one, otherwise return a null
1814 /// sd operand.
1815 static SDValue getInputChainForNode(SDNode *N) {
1816   if (unsigned NumOps = N->getNumOperands()) {
1817     if (N->getOperand(0).getValueType() == MVT::Other)
1818       return N->getOperand(0);
1819     if (N->getOperand(NumOps-1).getValueType() == MVT::Other)
1820       return N->getOperand(NumOps-1);
1821     for (unsigned i = 1; i < NumOps-1; ++i)
1822       if (N->getOperand(i).getValueType() == MVT::Other)
1823         return N->getOperand(i);
1824   }
1825   return SDValue();
1826 }
1827 
1828 SDValue DAGCombiner::visitTokenFactor(SDNode *N) {
1829   // If N has two operands, where one has an input chain equal to the other,
1830   // the 'other' chain is redundant.
1831   if (N->getNumOperands() == 2) {
1832     if (getInputChainForNode(N->getOperand(0).getNode()) == N->getOperand(1))
1833       return N->getOperand(0);
1834     if (getInputChainForNode(N->getOperand(1).getNode()) == N->getOperand(0))
1835       return N->getOperand(1);
1836   }
1837 
1838   // Don't simplify token factors if optnone.
1839   if (OptLevel == CodeGenOpt::None)
1840     return SDValue();
1841 
1842   // Don't simplify the token factor if the node itself has too many operands.
1843   if (N->getNumOperands() > TokenFactorInlineLimit)
1844     return SDValue();
1845 
1846   // If the sole user is a token factor, we should make sure we have a
1847   // chance to merge them together. This prevents TF chains from inhibiting
1848   // optimizations.
1849   if (N->hasOneUse() && N->use_begin()->getOpcode() == ISD::TokenFactor)
1850     AddToWorklist(*(N->use_begin()));
1851 
1852   SmallVector<SDNode *, 8> TFs;     // List of token factors to visit.
1853   SmallVector<SDValue, 8> Ops;      // Ops for replacing token factor.
1854   SmallPtrSet<SDNode*, 16> SeenOps;
1855   bool Changed = false;             // If we should replace this token factor.
1856 
1857   // Start out with this token factor.
1858   TFs.push_back(N);
1859 
1860   // Iterate through token factors. The TFs list grows as new token factors
1861   // are encountered.
1862   for (unsigned i = 0; i < TFs.size(); ++i) {
1863     // Limit number of nodes to inline, to avoid quadratic compile times.
1864     // We have to add the outstanding Token Factors to Ops, otherwise we might
1865     // drop Ops from the resulting Token Factors.
1866     if (Ops.size() > TokenFactorInlineLimit) {
1867       for (unsigned j = i; j < TFs.size(); j++)
1868         Ops.emplace_back(TFs[j], 0);
1869       // Drop unprocessed Token Factors from TFs, so we do not add them to the
1870       // combiner worklist later.
1871       TFs.resize(i);
1872       break;
1873     }
1874 
1875     SDNode *TF = TFs[i];
1876     // Check each of the operands.
1877     for (const SDValue &Op : TF->op_values()) {
1878       switch (Op.getOpcode()) {
1879       case ISD::EntryToken:
1880         // Entry tokens don't need to be added to the list. They are
1881         // redundant.
1882         Changed = true;
1883         break;
1884 
1885       case ISD::TokenFactor:
1886         if (Op.hasOneUse() && !is_contained(TFs, Op.getNode())) {
1887           // Queue up for processing.
1888           TFs.push_back(Op.getNode());
1889           Changed = true;
1890           break;
1891         }
1892         LLVM_FALLTHROUGH;
1893 
1894       default:
1895         // Only add if it isn't already in the list.
1896         if (SeenOps.insert(Op.getNode()).second)
1897           Ops.push_back(Op);
1898         else
1899           Changed = true;
1900         break;
1901       }
1902     }
1903   }
1904 
1905   // Re-visit inlined Token Factors, to clean them up in case they have been
1906   // removed. Skip the first Token Factor, as this is the current node.
1907   for (unsigned i = 1, e = TFs.size(); i < e; i++)
1908     AddToWorklist(TFs[i]);
1909 
1910   // Remove nodes that are chained to another node in the list. Do so by
1911   // walking up chains breadth-first, stopping when we've seen another
1912   // operand. In general we must climb to the EntryNode, but we can exit
1913   // early if we find all remaining work is associated with just one operand,
1914   // as no further pruning is possible.
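       // For example, in TokenFactor(A, B) where A is reachable by walking up
       // B's chain, A is redundant: B already orders after A transitively.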
1915 
1916   // List of nodes to search through and original Ops from which they originate.
1917   SmallVector<std::pair<SDNode *, unsigned>, 8> Worklist;
1918   SmallVector<unsigned, 8> OpWorkCount; // Count of work for each Op.
1919   SmallPtrSet<SDNode *, 16> SeenChains;
1920   bool DidPruneOps = false;
1921 
1922   unsigned NumLeftToConsider = 0;
1923   for (const SDValue &Op : Ops) {
1924     Worklist.push_back(std::make_pair(Op.getNode(), NumLeftToConsider++));
1925     OpWorkCount.push_back(1);
1926   }
1927 
1928   auto AddToWorklist = [&](unsigned CurIdx, SDNode *Op, unsigned OpNumber) {
1929     // If this is an Op, we can remove the op from the list. Re-mark any
1930     // search associated with it as coming from the current OpNumber.
1931     if (SeenOps.contains(Op)) {
1932       Changed = true;
1933       DidPruneOps = true;
1934       unsigned OrigOpNumber = 0;
1935       while (OrigOpNumber < Ops.size() && Ops[OrigOpNumber].getNode() != Op)
1936         OrigOpNumber++;
1937       assert((OrigOpNumber != Ops.size()) &&
1938              "expected to find TokenFactor Operand");
1939       // Re-mark worklist from OrigOpNumber to OpNumber
1940       for (unsigned i = CurIdx + 1; i < Worklist.size(); ++i) {
1941         if (Worklist[i].second == OrigOpNumber) {
1942           Worklist[i].second = OpNumber;
1943         }
1944       }
1945       OpWorkCount[OpNumber] += OpWorkCount[OrigOpNumber];
1946       OpWorkCount[OrigOpNumber] = 0;
1947       NumLeftToConsider--;
1948     }
1949     // Add if it's a new chain
1950     if (SeenChains.insert(Op).second) {
1951       OpWorkCount[OpNumber]++;
1952       Worklist.push_back(std::make_pair(Op, OpNumber));
1953     }
1954   };
1955 
1956   for (unsigned i = 0; i < Worklist.size() && i < 1024; ++i) {
1957     // We need to consider at least 2 Ops to prune.
1958     if (NumLeftToConsider <= 1)
1959       break;
1960     auto CurNode = Worklist[i].first;
1961     auto CurOpNumber = Worklist[i].second;
1962     assert((OpWorkCount[CurOpNumber] > 0) &&
1963            "Node should not appear in worklist");
1964     switch (CurNode->getOpcode()) {
1965     case ISD::EntryToken:
1966       // Hitting EntryToken is the only way for the search to terminate
1967       // without hitting another operand's search. Prevent us from marking
1968       // this operand as considered.
1970       NumLeftToConsider++;
1971       break;
1972     case ISD::TokenFactor:
1973       for (const SDValue &Op : CurNode->op_values())
1974         AddToWorklist(i, Op.getNode(), CurOpNumber);
1975       break;
1976     case ISD::LIFETIME_START:
1977     case ISD::LIFETIME_END:
1978     case ISD::CopyFromReg:
1979     case ISD::CopyToReg:
1980       AddToWorklist(i, CurNode->getOperand(0).getNode(), CurOpNumber);
1981       break;
1982     default:
1983       if (auto *MemNode = dyn_cast<MemSDNode>(CurNode))
1984         AddToWorklist(i, MemNode->getChain().getNode(), CurOpNumber);
1985       break;
1986     }
1987     OpWorkCount[CurOpNumber]--;
1988     if (OpWorkCount[CurOpNumber] == 0)
1989       NumLeftToConsider--;
1990   }
1991 
1992   // If we've changed things around then replace token factor.
1993   if (Changed) {
1994     SDValue Result;
1995     if (Ops.empty()) {
1996       // The entry token is the only possible outcome.
1997       Result = DAG.getEntryNode();
1998     } else {
1999       if (DidPruneOps) {
2000         SmallVector<SDValue, 8> PrunedOps;
2001         // Keep only ops not reached while walking up another op's chain.
2002         for (const SDValue &Op : Ops) {
2003           if (SeenChains.count(Op.getNode()) == 0)
2004             PrunedOps.push_back(Op);
2005         }
2006         Result = DAG.getTokenFactor(SDLoc(N), PrunedOps);
2007       } else {
2008         Result = DAG.getTokenFactor(SDLoc(N), Ops);
2009       }
2010     }
2011     return Result;
2012   }
2013   return SDValue();
2014 }
2015 
2016 /// MERGE_VALUES can always be eliminated.
2017 SDValue DAGCombiner::visitMERGE_VALUES(SDNode *N) {
2018   WorklistRemover DeadNodes(*this);
2019   // Replacing results may cause a different MERGE_VALUES to suddenly
2020   // be CSE'd with N, and carry its uses with it. Iterate until no
2021   // uses remain, to ensure that the node can be safely deleted.
2022   // First add the users of this node to the work list so that they
2023   // can be tried again once they have new operands.
2024   AddUsersToWorklist(N);
2025   do {
2026     // Do as a single replacement to avoid rewalking use lists.
2027     SmallVector<SDValue, 8> Ops;
2028     for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
2029       Ops.push_back(N->getOperand(i));
2030     DAG.ReplaceAllUsesWith(N, Ops.data());
2031   } while (!N->use_empty());
2032   deleteAndRecombine(N);
2033   return SDValue(N, 0);   // Return N so it doesn't get rechecked!
2034 }
2035 
2036 /// If \p N is a ConstantSDNode with isOpaque() == false, return it cast to a
2037 /// ConstantSDNode pointer; otherwise return nullptr.
2038 static ConstantSDNode *getAsNonOpaqueConstant(SDValue N) {
2039   ConstantSDNode *Const = dyn_cast<ConstantSDNode>(N);
2040   return Const != nullptr && !Const->isOpaque() ? Const : nullptr;
2041 }
2042 
2043 /// Return true if 'Use' is a load or a store that uses N as its base pointer
2044 /// and that N may be folded in the load / store addressing mode.
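     /// e.g. (load (add x, 16)) may fold the add into a [reg + imm] addressing
     /// mode when the target reports that mode as legal for the memory VT.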
2045 static bool canFoldInAddressingMode(SDNode *N, SDNode *Use, SelectionDAG &DAG,
2046                                     const TargetLowering &TLI) {
2047   EVT VT;
2048   unsigned AS;
2049 
2050   if (LoadSDNode *LD = dyn_cast<LoadSDNode>(Use)) {
2051     if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2052       return false;
2053     VT = LD->getMemoryVT();
2054     AS = LD->getAddressSpace();
2055   } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(Use)) {
2056     if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2057       return false;
2058     VT = ST->getMemoryVT();
2059     AS = ST->getAddressSpace();
2060   } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(Use)) {
2061     if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2062       return false;
2063     VT = LD->getMemoryVT();
2064     AS = LD->getAddressSpace();
2065   } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(Use)) {
2066     if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2067       return false;
2068     VT = ST->getMemoryVT();
2069     AS = ST->getAddressSpace();
2070   } else
2071     return false;
2072 
2073   TargetLowering::AddrMode AM;
2074   if (N->getOpcode() == ISD::ADD) {
2075     AM.HasBaseReg = true;
2076     ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(N->getOperand(1));
2077     if (Offset)
2078       // [reg +/- imm]
2079       AM.BaseOffs = Offset->getSExtValue();
2080     else
2081       // [reg +/- reg]
2082       AM.Scale = 1;
2083   } else if (N->getOpcode() == ISD::SUB) {
2084     AM.HasBaseReg = true;
2085     ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(N->getOperand(1));
2086     if (Offset)
2087       // [reg +/- imm]
2088       AM.BaseOffs = -Offset->getSExtValue();
2089     else
2090       // [reg +/- reg]
2091       AM.Scale = 1;
2092   } else
2093     return false;
2094 
2095   return TLI.isLegalAddressingMode(DAG.getDataLayout(), AM,
2096                                    VT.getTypeForEVT(*DAG.getContext()), AS);
2097 }
2098 
2099 SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {
2100   assert(TLI.isBinOp(BO->getOpcode()) && BO->getNumValues() == 1 &&
2101          "Unexpected binary operator");
2102 
2103   // Don't do this unless the old select is going away. We want to eliminate the
2104   // binary operator, not replace a binop with a select.
2105   // TODO: Handle ISD::SELECT_CC.
2106   unsigned SelOpNo = 0;
2107   SDValue Sel = BO->getOperand(0);
2108   if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse()) {
2109     SelOpNo = 1;
2110     Sel = BO->getOperand(1);
2111   }
2112 
2113   if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse())
2114     return SDValue();
2115 
2116   SDValue CT = Sel.getOperand(1);
2117   if (!isConstantOrConstantVector(CT, true) &&
2118       !DAG.isConstantFPBuildVectorOrConstantFP(CT))
2119     return SDValue();
2120 
2121   SDValue CF = Sel.getOperand(2);
2122   if (!isConstantOrConstantVector(CF, true) &&
2123       !DAG.isConstantFPBuildVectorOrConstantFP(CF))
2124     return SDValue();
2125 
2126   // Bail out if any constants are opaque because we can't constant fold those.
2127   // The exception is "and" and "or" with either 0 or -1, in which case we can
2128   // propagate non-constant operands into the select. I.e.:
2129   // and (select Cond, 0, -1), X --> select Cond, 0, X
2130   // or X, (select Cond, -1, 0) --> select Cond, -1, X
2131   auto BinOpcode = BO->getOpcode();
2132   bool CanFoldNonConst =
2133       (BinOpcode == ISD::AND || BinOpcode == ISD::OR) &&
2134       (isNullOrNullSplat(CT) || isAllOnesOrAllOnesSplat(CT)) &&
2135       (isNullOrNullSplat(CF) || isAllOnesOrAllOnesSplat(CF));
2136 
2137   SDValue CBO = BO->getOperand(SelOpNo ^ 1);
2138   if (!CanFoldNonConst &&
2139       !isConstantOrConstantVector(CBO, true) &&
2140       !DAG.isConstantFPBuildVectorOrConstantFP(CBO))
2141     return SDValue();
2142 
2143   EVT VT = BO->getValueType(0);
2144 
2145   // We have a select-of-constants followed by a binary operator with a
2146   // constant. Eliminate the binop by pulling the constant math into the select.
2147   // Example: add (select Cond, CT, CF), CBO --> select Cond, CT + CBO, CF + CBO
2148   SDLoc DL(Sel);
2149   SDValue NewCT = SelOpNo ? DAG.getNode(BinOpcode, DL, VT, CBO, CT)
2150                           : DAG.getNode(BinOpcode, DL, VT, CT, CBO);
2151   if (!CanFoldNonConst && !NewCT.isUndef() &&
2152       !isConstantOrConstantVector(NewCT, true) &&
2153       !DAG.isConstantFPBuildVectorOrConstantFP(NewCT))
2154     return SDValue();
2155 
2156   SDValue NewCF = SelOpNo ? DAG.getNode(BinOpcode, DL, VT, CBO, CF)
2157                           : DAG.getNode(BinOpcode, DL, VT, CF, CBO);
2158   if (!CanFoldNonConst && !NewCF.isUndef() &&
2159       !isConstantOrConstantVector(NewCF, true) &&
2160       !DAG.isConstantFPBuildVectorOrConstantFP(NewCF))
2161     return SDValue();
2162 
2163   SDValue SelectOp = DAG.getSelect(DL, VT, Sel.getOperand(0), NewCT, NewCF);
2164   SelectOp->setFlags(BO->getFlags());
2165   return SelectOp;
2166 }
2167 
2168 static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, SelectionDAG &DAG) {
2169   assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
2170          "Expecting add or sub");
2171 
2172   // Match a constant operand and a zext operand for the math instruction:
2173   // add Z, C
2174   // sub C, Z
2175   bool IsAdd = N->getOpcode() == ISD::ADD;
2176   SDValue C = IsAdd ? N->getOperand(1) : N->getOperand(0);
2177   SDValue Z = IsAdd ? N->getOperand(0) : N->getOperand(1);
2178   auto *CN = dyn_cast<ConstantSDNode>(C);
2179   if (!CN || Z.getOpcode() != ISD::ZERO_EXTEND)
2180     return SDValue();
2181 
2182   // Match the zext operand as a setcc of a boolean.
2183   if (Z.getOperand(0).getOpcode() != ISD::SETCC ||
2184       Z.getOperand(0).getValueType() != MVT::i1)
2185     return SDValue();
2186 
2187   // Match the compare as: setcc (X & 1), 0, eq.
2188   SDValue SetCC = Z.getOperand(0);
2189   ISD::CondCode CC = cast<CondCodeSDNode>(SetCC->getOperand(2))->get();
2190   if (CC != ISD::SETEQ || !isNullConstant(SetCC.getOperand(1)) ||
2191       SetCC.getOperand(0).getOpcode() != ISD::AND ||
2192       !isOneConstant(SetCC.getOperand(0).getOperand(1)))
2193     return SDValue();
2194 
2195   // We are adding/subtracting a constant and an inverted low bit. Turn that
2196   // into a subtract/add of the low bit with incremented/decremented constant:
2197   // add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1))
2198   // sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1))
2199   EVT VT = C.getValueType();
2200   SDLoc DL(N);
2201   SDValue LowBit = DAG.getZExtOrTrunc(SetCC.getOperand(0), DL, VT);
2202   SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT) :
2203                        DAG.getConstant(CN->getAPIntValue() - 1, DL, VT);
2204   return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, C1, LowBit);
2205 }
2206 
2207 /// Try to fold an add/sub of a constant and the shifted-down sign bit of a
2208 /// 'not' value into a shift and an add with a different constant.
2209 static SDValue foldAddSubOfSignBit(SDNode *N, SelectionDAG &DAG) {
2210   assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
2211          "Expecting add or sub");
2212 
2213   // We need a constant operand for the add/sub, and the other operand is a
2214   // logical shift right: add (srl), C or sub C, (srl).
2215   bool IsAdd = N->getOpcode() == ISD::ADD;
2216   SDValue ConstantOp = IsAdd ? N->getOperand(1) : N->getOperand(0);
2217   SDValue ShiftOp = IsAdd ? N->getOperand(0) : N->getOperand(1);
2218   if (!DAG.isConstantIntBuildVectorOrConstantInt(ConstantOp) ||
2219       ShiftOp.getOpcode() != ISD::SRL)
2220     return SDValue();
2221 
2222   // The shift must be of a 'not' value.
2223   SDValue Not = ShiftOp.getOperand(0);
2224   if (!Not.hasOneUse() || !isBitwiseNot(Not))
2225     return SDValue();
2226 
2227   // The shift must be moving the sign bit to the least-significant-bit.
2228   EVT VT = ShiftOp.getValueType();
2229   SDValue ShAmt = ShiftOp.getOperand(1);
2230   ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt);
2231   if (!ShAmtC || ShAmtC->getAPIntValue() != (VT.getScalarSizeInBits() - 1))
2232     return SDValue();
2233 
2234   // Eliminate the 'not' by adjusting the shift and add/sub constant:
2235   // add (srl (not X), 31), C --> add (sra X, 31), (C + 1)
2236   // sub C, (srl (not X), 31) --> add (srl X, 31), (C - 1)
2237   SDLoc DL(N);
2238   auto ShOpcode = IsAdd ? ISD::SRA : ISD::SRL;
2239   SDValue NewShift = DAG.getNode(ShOpcode, DL, VT, Not.getOperand(0), ShAmt);
2240   if (SDValue NewC =
2241           DAG.FoldConstantArithmetic(IsAdd ? ISD::ADD : ISD::SUB, DL, VT,
2242                                      {ConstantOp, DAG.getConstant(1, DL, VT)}))
2243     return DAG.getNode(ISD::ADD, DL, VT, NewShift, NewC);
2244   return SDValue();
2245 }
2246 
2247 /// Try to fold a node that behaves like an ADD (note that N isn't necessarily
2248 /// an ISD::ADD here, it could for example be an ISD::OR if we know that there
2249 /// are no common bits set in the operands).
2250 SDValue DAGCombiner::visitADDLike(SDNode *N) {
2251   SDValue N0 = N->getOperand(0);
2252   SDValue N1 = N->getOperand(1);
2253   EVT VT = N0.getValueType();
2254   SDLoc DL(N);
2255 
2256   // fold vector ops
2257   if (VT.isVector()) {
2258     if (SDValue FoldedVOp = SimplifyVBinOp(N))
2259       return FoldedVOp;
2260 
2261     // fold (add x, 0) -> x, vector edition
2262     if (ISD::isBuildVectorAllZeros(N1.getNode()))
2263       return N0;
2264     if (ISD::isBuildVectorAllZeros(N0.getNode()))
2265       return N1;
2266   }
2267 
2268   // fold (add x, undef) -> undef
2269   if (N0.isUndef())
2270     return N0;
2271 
2272   if (N1.isUndef())
2273     return N1;
2274 
2275   if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) {
2276     // canonicalize constant to RHS
2277     if (!DAG.isConstantIntBuildVectorOrConstantInt(N1))
2278       return DAG.getNode(ISD::ADD, DL, VT, N1, N0);
2279     // fold (add c1, c2) -> c1+c2
2280     return DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N0, N1});
2281   }
2282 
2283   // fold (add x, 0) -> x
2284   if (isNullConstant(N1))
2285     return N0;
2286 
2287   if (isConstantOrConstantVector(N1, /* NoOpaque */ true)) {
2288     // fold ((A-c1)+c2) -> (A+(c2-c1))
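         // e.g. ((x - 7) + 10) -> (x + 3)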
2289     if (N0.getOpcode() == ISD::SUB &&
2290         isConstantOrConstantVector(N0.getOperand(1), /* NoOpaque */ true)) {
2291       SDValue Sub =
2292           DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N1, N0.getOperand(1)});
2293       assert(Sub && "Constant folding failed");
2294       return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Sub);
2295     }
2296 
2297     // fold ((c1-A)+c2) -> (c1+c2)-A
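         // e.g. ((7 - x) + 10) -> (17 - x)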
2298     if (N0.getOpcode() == ISD::SUB &&
2299         isConstantOrConstantVector(N0.getOperand(0), /* NoOpaque */ true)) {
2300       SDValue Add =
2301           DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N1, N0.getOperand(0)});
2302       assert(Add && "Constant folding failed");
2303       return DAG.getNode(ISD::SUB, DL, VT, Add, N0.getOperand(1));
2304     }
2305 
2306     // add (sext i1 X), 1 -> zext (not i1 X)
2307     // We don't transform this pattern:
2308     //   add (zext i1 X), -1 -> sext (not i1 X)
2309     // because most (?) targets generate better code for the zext form.
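         // The first form is sound because (sext i1 X) is 0 or -1, so adding
         // 1 yields 1 or 0, which is exactly (zext (not i1 X)).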
2310     if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
2311         isOneOrOneSplat(N1)) {
2312       SDValue X = N0.getOperand(0);
2313       if ((!LegalOperations ||
2314            (TLI.isOperationLegal(ISD::XOR, X.getValueType()) &&
2315             TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) &&
2316           X.getScalarValueSizeInBits() == 1) {
2317         SDValue Not = DAG.getNOT(DL, X, X.getValueType());
2318         return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Not);
2319       }
2320     }
2321 
2322     // Fold (add (or x, c0), c1) -> (add x, (c0 + c1)) if (or x, c0) is
2323     // equivalent to (add x, c0).
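         // e.g. if the low four bits of x are known to be zero:
         // (add (or x, 3), 8) -> (add x, 11)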
2324     if (N0.getOpcode() == ISD::OR &&
2325         isConstantOrConstantVector(N0.getOperand(1), /* NoOpaque */ true) &&
2326         DAG.haveNoCommonBitsSet(N0.getOperand(0), N0.getOperand(1))) {
2327       if (SDValue Add0 = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT,
2328                                                     {N1, N0.getOperand(1)}))
2329         return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Add0);
2330     }
2331   }
2332 
2333   if (SDValue NewSel = foldBinOpIntoSelect(N))
2334     return NewSel;
2335 
2336   // reassociate add
2337   if (!reassociationCanBreakAddressingModePattern(ISD::ADD, DL, N0, N1)) {
2338     if (SDValue RADD = reassociateOps(ISD::ADD, DL, N0, N1, N->getFlags()))
2339       return RADD;
2340   }
2341   // fold ((0-A) + B) -> B-A
2342   if (N0.getOpcode() == ISD::SUB && isNullOrNullSplat(N0.getOperand(0)))
2343     return DAG.getNode(ISD::SUB, DL, VT, N1, N0.getOperand(1));
2344 
2345   // fold (A + (0-B)) -> A-B
2346   if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0)))
2347     return DAG.getNode(ISD::SUB, DL, VT, N0, N1.getOperand(1));
2348 
2349   // fold (A+(B-A)) -> B
2350   if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(1))
2351     return N1.getOperand(0);
2352 
2353   // fold ((B-A)+A) -> B
2354   if (N0.getOpcode() == ISD::SUB && N1 == N0.getOperand(1))
2355     return N0.getOperand(0);
2356 
2357   // fold ((A-B)+(C-A)) -> (C-B)
2358   if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
2359       N0.getOperand(0) == N1.getOperand(1))
2360     return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
2361                        N0.getOperand(1));
2362 
2363   // fold ((A-B)+(B-C)) -> (A-C)
2364   if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
2365       N0.getOperand(1) == N1.getOperand(0))
2366     return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0),
2367                        N1.getOperand(1));
2368 
2369   // fold (A+(B-(A+C))) to (B-C)
2370   if (N1.getOpcode() == ISD::SUB && N1.getOperand(1).getOpcode() == ISD::ADD &&
2371       N0 == N1.getOperand(1).getOperand(0))
2372     return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
2373                        N1.getOperand(1).getOperand(1));
2374 
2375   // fold (A+(B-(C+A))) to (B-C)
2376   if (N1.getOpcode() == ISD::SUB && N1.getOperand(1).getOpcode() == ISD::ADD &&
2377       N0 == N1.getOperand(1).getOperand(1))
2378     return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
2379                        N1.getOperand(1).getOperand(0));
2380 
2381   // fold (A+((B-A)+or-C)) to (B+or-C)
2382   if ((N1.getOpcode() == ISD::SUB || N1.getOpcode() == ISD::ADD) &&
2383       N1.getOperand(0).getOpcode() == ISD::SUB &&
2384       N0 == N1.getOperand(0).getOperand(1))
2385     return DAG.getNode(N1.getOpcode(), DL, VT, N1.getOperand(0).getOperand(0),
2386                        N1.getOperand(1));
2387 
2388   // fold (A-B)+(C-D) to (A+C)-(B+D) when A or C is constant
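       // e.g. ((5 - x) + (7 - y)) -> (12 - (x + y))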
2389   if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB) {
2390     SDValue N00 = N0.getOperand(0);
2391     SDValue N01 = N0.getOperand(1);
2392     SDValue N10 = N1.getOperand(0);
2393     SDValue N11 = N1.getOperand(1);
2394 
2395     if (isConstantOrConstantVector(N00) || isConstantOrConstantVector(N10))
2396       return DAG.getNode(ISD::SUB, DL, VT,
2397                          DAG.getNode(ISD::ADD, SDLoc(N0), VT, N00, N10),
2398                          DAG.getNode(ISD::ADD, SDLoc(N1), VT, N01, N11));
2399   }
2400 
2401   // fold (add (umax X, C), -C) --> (usubsat X, C)
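       // e.g. for i8: (add (umax x, 5), 251) -> (usubsat x, 5), since 251 is
       // -5 and (umax x, 5) plus -5 never goes below zero.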
2402   if (N0.getOpcode() == ISD::UMAX && hasOperation(ISD::USUBSAT, VT)) {
2403     auto MatchUSUBSAT = [](ConstantSDNode *Max, ConstantSDNode *Op) {
2404       return (!Max && !Op) ||
2405              (Max && Op && Max->getAPIntValue() == (-Op->getAPIntValue()));
2406     };
2407     if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchUSUBSAT,
2408                                   /*AllowUndefs*/ true))
2409       return DAG.getNode(ISD::USUBSAT, DL, VT, N0.getOperand(0),
2410                          N0.getOperand(1));
2411   }
2412 
2413   if (SimplifyDemandedBits(SDValue(N, 0)))
2414     return SDValue(N, 0);
2415 
2416   if (isOneOrOneSplat(N1)) {
2417     // fold (add (xor a, -1), 1) -> (sub 0, a)
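         // (xor a, -1) is -a - 1 in two's complement, so adding 1 yields -a.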
2418     if (isBitwiseNot(N0))
2419       return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
2420                          N0.getOperand(0));
2421 
2422     // fold (add (add (xor a, -1), b), 1) -> (sub b, a)
2423     if (N0.getOpcode() == ISD::ADD ||
2424         N0.getOpcode() == ISD::UADDO ||
2425         N0.getOpcode() == ISD::SADDO) {
2426       SDValue A, Xor;
2427 
2428       if (isBitwiseNot(N0.getOperand(0))) {
2429         A = N0.getOperand(1);
2430         Xor = N0.getOperand(0);
2431       } else if (isBitwiseNot(N0.getOperand(1))) {
2432         A = N0.getOperand(0);
2433         Xor = N0.getOperand(1);
2434       }
2435 
2436       if (Xor)
2437         return DAG.getNode(ISD::SUB, DL, VT, A, Xor.getOperand(0));
2438     }
2439 
2440     // Look for:
2441     //   add (add x, y), 1
2442     // And if the target does not like this form then turn into:
2443     //   sub y, (xor x, -1)
2444     if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.hasOneUse() &&
2445         N0.getOpcode() == ISD::ADD) {
2446       SDValue Not = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(0),
2447                                 DAG.getAllOnesConstant(DL, VT));
2448       return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(1), Not);
2449     }
2450   }
2451 
2452   // (x - y) + -1  ->  add (xor y, -1), x
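       // (xor y, -1) is -y - 1, so the result is x - y - 1, matching the
       // original (x - y) + -1.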
2453   if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB &&
2454       isAllOnesOrAllOnesSplat(N1)) {
2455     SDValue Xor = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(1), N1);
2456     return DAG.getNode(ISD::ADD, DL, VT, Xor, N0.getOperand(0));
2457   }
2458 
2459   if (SDValue Combined = visitADDLikeCommutative(N0, N1, N))
2460     return Combined;
2461 
2462   if (SDValue Combined = visitADDLikeCommutative(N1, N0, N))
2463     return Combined;
2464 
2465   return SDValue();
2466 }
2467 
2468 SDValue DAGCombiner::visitADD(SDNode *N) {
2469   SDValue N0 = N->getOperand(0);
2470   SDValue N1 = N->getOperand(1);
2471   EVT VT = N0.getValueType();
2472   SDLoc DL(N);
2473 
2474   if (SDValue Combined = visitADDLike(N))
2475     return Combined;
2476 
2477   if (SDValue V = foldAddSubBoolOfMaskedVal(N, DAG))
2478     return V;
2479 
2480   if (SDValue V = foldAddSubOfSignBit(N, DAG))
2481     return V;
2482 
2483   // fold (a+b) -> (a|b) iff a and b share no bits.
2484   if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
2485       DAG.haveNoCommonBitsSet(N0, N1))
2486     return DAG.getNode(ISD::OR, DL, VT, N0, N1);
2487 
2488   // Fold (add (vscale * C0), (vscale * C1)) to (vscale * (C0 + C1)).
2489   if (N0.getOpcode() == ISD::VSCALE && N1.getOpcode() == ISD::VSCALE) {
2490     const APInt &C0 = N0->getConstantOperandAPInt(0);
2491     const APInt &C1 = N1->getConstantOperandAPInt(0);
2492     return DAG.getVScale(DL, VT, C0 + C1);
2493   }
2494 
2495   // fold a+vscale(c1)+vscale(c2) -> a+vscale(c1+c2)
2496   if ((N0.getOpcode() == ISD::ADD) &&
2497       (N0.getOperand(1).getOpcode() == ISD::VSCALE) &&
2498       (N1.getOpcode() == ISD::VSCALE)) {
2499     const APInt &VS0 = N0.getOperand(1)->getConstantOperandAPInt(0);
2500     const APInt &VS1 = N1->getConstantOperandAPInt(0);
2501     SDValue VS = DAG.getVScale(DL, VT, VS0 + VS1);
2502     return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), VS);
2503   }
2504 
2505   return SDValue();
2506 }
2507 
2508 SDValue DAGCombiner::visitADDSAT(SDNode *N) {
2509   unsigned Opcode = N->getOpcode();
2510   SDValue N0 = N->getOperand(0);
2511   SDValue N1 = N->getOperand(1);
2512   EVT VT = N0.getValueType();
2513   SDLoc DL(N);
2514 
2515   // fold vector ops
2516   if (VT.isVector()) {
2517     // TODO SimplifyVBinOp
2518 
2519     // fold (add_sat x, 0) -> x, vector edition
2520     if (ISD::isBuildVectorAllZeros(N1.getNode()))
2521       return N0;
2522     if (ISD::isBuildVectorAllZeros(N0.getNode()))
2523       return N1;
2524   }
2525 
2526   // fold (add_sat x, undef) -> -1
2527   if (N0.isUndef() || N1.isUndef())
2528     return DAG.getAllOnesConstant(DL, VT);
2529 
2530   if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) {
2531     // canonicalize constant to RHS
2532     if (!DAG.isConstantIntBuildVectorOrConstantInt(N1))
2533       return DAG.getNode(Opcode, DL, VT, N1, N0);
2534     // fold (add_sat c1, c2) -> c3
2535     return DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1});
2536   }
2537 
2538   // fold (add_sat x, 0) -> x
2539   if (isNullConstant(N1))
2540     return N0;
2541 
2542   // If it cannot overflow, transform into an add.
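       // e.g. (uaddsat i16 (and x, 0x7FFF), (and y, 0x7FFF)) can never wrap,
       // so a plain add computes the same result.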
2543   if (Opcode == ISD::UADDSAT)
2544     if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never)
2545       return DAG.getNode(ISD::ADD, DL, VT, N0, N1);
2546 
2547   return SDValue();
2548 }
2549 
2550 static SDValue getAsCarry(const TargetLowering &TLI, SDValue V) {
2551   bool Masked = false;
2552 
2553   // First, peel away TRUNCATE/ZERO_EXTEND/AND nodes due to legalization.
2554   while (true) {
2555     if (V.getOpcode() == ISD::TRUNCATE || V.getOpcode() == ISD::ZERO_EXTEND) {
2556       V = V.getOperand(0);
2557       continue;
2558     }
2559 
2560     if (V.getOpcode() == ISD::AND && isOneConstant(V.getOperand(1))) {
2561       Masked = true;
2562       V = V.getOperand(0);
2563       continue;
2564     }
2565 
2566     break;
2567   }
2568 
2569   // If this is not a carry, return.
2570   if (V.getResNo() != 1)
2571     return SDValue();
2572 
2573   if (V.getOpcode() != ISD::ADDCARRY && V.getOpcode() != ISD::SUBCARRY &&
2574       V.getOpcode() != ISD::UADDO && V.getOpcode() != ISD::USUBO)
2575     return SDValue();
2576 
2577   EVT VT = V.getNode()->getValueType(0);
2578   if (!TLI.isOperationLegalOrCustom(V.getOpcode(), VT))
2579     return SDValue();
2580 
  // If the result is masked, then no matter what kind of bool it is we can
  // return it. If it isn't, then we need to make sure the bool represents
  // either 0 or 1 and no other values.
  if (Masked ||
      TLI.getBooleanContents(V.getValueType()) ==
          TargetLoweringBase::ZeroOrOneBooleanContent)
    return V;

  return SDValue();
}

/// Given the operands of an add/sub operation, see if the 2nd operand is a
/// masked 0/1 whose source operand is actually known to be 0/-1. If so, invert
/// the opcode and bypass the mask operation.
static SDValue foldAddSubMasked1(bool IsAdd, SDValue N0, SDValue N1,
                                 SelectionDAG &DAG, const SDLoc &DL) {
  if (N1.getOpcode() != ISD::AND || !isOneOrOneSplat(N1->getOperand(1)))
    return SDValue();

  EVT VT = N0.getValueType();
  if (DAG.ComputeNumSignBits(N1.getOperand(0)) != VT.getScalarSizeInBits())
    return SDValue();

  // add N0, (and (AssertSext X, i1), 1) --> sub N0, X
  // sub N0, (and (AssertSext X, i1), 1) --> add N0, X
  return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, N0, N1.getOperand(0));
}

/// Helper for doing combines based on N0 and N1 being added to each other.
SDValue DAGCombiner::visitADDLikeCommutative(SDValue N0, SDValue N1,
                                             SDNode *LocReference) {
  EVT VT = N0.getValueType();
  SDLoc DL(LocReference);

  // fold (add x, shl(0 - y, n)) -> sub(x, shl(y, n))
  if (N1.getOpcode() == ISD::SHL && N1.getOperand(0).getOpcode() == ISD::SUB &&
      isNullOrNullSplat(N1.getOperand(0).getOperand(0)))
    return DAG.getNode(ISD::SUB, DL, VT, N0,
                       DAG.getNode(ISD::SHL, DL, VT,
                                   N1.getOperand(0).getOperand(1),
                                   N1.getOperand(1)));

  if (SDValue V = foldAddSubMasked1(true, N0, N1, DAG, DL))
    return V;

  // Look for:
  //   add (add x, 1), y
  // And if the target does not like this form then turn into:
  //   sub y, (xor x, -1)
  if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.hasOneUse() &&
      N0.getOpcode() == ISD::ADD && isOneOrOneSplat(N0.getOperand(1))) {
    SDValue Not = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(0),
                              DAG.getAllOnesConstant(DL, VT));
    return DAG.getNode(ISD::SUB, DL, VT, N1, Not);
  }

  // Hoist one-use subtraction by non-opaque constant:
  //   (x - C) + y  ->  (x + y) - C
  // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
  if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB &&
      isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
    SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), N1);
    return DAG.getNode(ISD::SUB, DL, VT, Add, N0.getOperand(1));
  }
  // Hoist one-use subtraction from non-opaque constant:
  //   (C - x) + y  ->  (y - x) + C
  if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB &&
      isConstantOrConstantVector(N0.getOperand(0), /*NoOpaques=*/true)) {
    SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N1, N0.getOperand(1));
    return DAG.getNode(ISD::ADD, DL, VT, Sub, N0.getOperand(0));
  }

  // If the target's bool is represented as 0/1, prefer to make this 'sub 0/1'
  // rather than 'add 0/-1' (the zext should get folded).
  // add (sext i1 Y), X --> sub X, (zext i1 Y)
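  // (sext of i1 produces 0/-1 and zext of i1 produces 0/1, and adding 0/-1
  // is the same as subtracting 0/1.)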
  if (N0.getOpcode() == ISD::SIGN_EXTEND &&
      N0.getOperand(0).getScalarValueSizeInBits() == 1 &&
      TLI.getBooleanContents(VT) == TargetLowering::ZeroOrOneBooleanContent) {
    SDValue ZExt = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0));
    return DAG.getNode(ISD::SUB, DL, VT, N1, ZExt);
  }

  // add X, (sextinreg Y i1) -> sub X, (and Y 1)
  if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
    VTSDNode *TN = cast<VTSDNode>(N1.getOperand(1));
    if (TN->getVT() == MVT::i1) {
      SDValue ZExt = DAG.getNode(ISD::AND, DL, VT, N1.getOperand(0),
                                 DAG.getConstant(1, DL, VT));
      return DAG.getNode(ISD::SUB, DL, VT, N0, ZExt);
    }
  }

  // (add X, (addcarry Y, 0, Carry)) -> (addcarry X, Y, Carry)
  if (N1.getOpcode() == ISD::ADDCARRY && isNullConstant(N1.getOperand(1)) &&
      N1.getResNo() == 0)
    return DAG.getNode(ISD::ADDCARRY, DL, N1->getVTList(),
                       N0, N1.getOperand(0), N1.getOperand(2));

  // (add X, Carry) -> (addcarry X, 0, Carry)
  if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT))
    if (SDValue Carry = getAsCarry(TLI, N1))
      return DAG.getNode(ISD::ADDCARRY, DL,
                         DAG.getVTList(VT, Carry.getValueType()), N0,
                         DAG.getConstant(0, DL, VT), Carry);

  return SDValue();
}

SDValue DAGCombiner::visitADDC(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  SDLoc DL(N);

  // If the flag result is dead, turn this into an ADD.
  if (!N->hasAnyUseOfValue(1))
    return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
                     DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));

  // canonicalize constant to RHS.
  ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
  if (N0C && !N1C)
    return DAG.getNode(ISD::ADDC, DL, N->getVTList(), N1, N0);

  // fold (addc x, 0) -> x + no carry out
  if (isNullConstant(N1))
    return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE,
                                        DL, MVT::Glue));

  // If it cannot overflow, transform into an add.
  if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never)
    return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
                     DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));

  return SDValue();
}

/**
 * Flips a boolean if it is cheaper to compute. If the Force parameter is set,
 * then the flip also occurs if computing the inverse is the same cost.
 * This function returns an empty SDValue in case it cannot flip the boolean
 * without increasing the cost of the computation. If you want to flip a
 * boolean no matter what, use DAG.getLogicalNOT.
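 * For example, under ZeroOrOneBooleanContent, (xor X, 1) is the inverse of
 * the boolean X, so X itself is returned as the flipped value.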
 */
static SDValue extractBooleanFlip(SDValue V, SelectionDAG &DAG,
                                  const TargetLowering &TLI,
                                  bool Force) {
  if (Force && isa<ConstantSDNode>(V))
    return DAG.getLogicalNOT(SDLoc(V), V, V.getValueType());

  if (V.getOpcode() != ISD::XOR)
    return SDValue();

  ConstantSDNode *Const = isConstOrConstSplat(V.getOperand(1), false);
  if (!Const)
    return SDValue();

  EVT VT = V.getValueType();

  bool IsFlip = false;
  switch(TLI.getBooleanContents(VT)) {
    case TargetLowering::ZeroOrOneBooleanContent:
      IsFlip = Const->isOne();
      break;
    case TargetLowering::ZeroOrNegativeOneBooleanContent:
      IsFlip = Const->isAllOnesValue();
      break;
    case TargetLowering::UndefinedBooleanContent:
      IsFlip = (Const->getAPIntValue() & 0x01) == 1;
      break;
  }

  if (IsFlip)
    return V.getOperand(0);
  if (Force)
    return DAG.getLogicalNOT(SDLoc(V), V, V.getValueType());
  return SDValue();
}

SDValue DAGCombiner::visitADDO(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  bool IsSigned = (ISD::SADDO == N->getOpcode());

  EVT CarryVT = N->getValueType(1);
  SDLoc DL(N);

  // If the flag result is dead, turn this into an ADD.
  if (!N->hasAnyUseOfValue(1))
    return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
                     DAG.getUNDEF(CarryVT));

  // canonicalize constant to RHS.
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(N->getOpcode(), DL, N->getVTList(), N1, N0);

  // fold (addo x, 0) -> x + no carry out
  if (isNullOrNullSplat(N1))
    return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT));

  if (!IsSigned) {
    // If it cannot overflow, transform into an add.
    if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never)
      return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
                       DAG.getConstant(0, DL, CarryVT));

    // fold (uaddo (xor a, -1), 1) -> (usub 0, a) and flip carry.
    if (isBitwiseNot(N0) && isOneOrOneSplat(N1)) {
      SDValue Sub = DAG.getNode(ISD::USUBO, DL, N->getVTList(),
                                DAG.getConstant(0, DL, VT), N0.getOperand(0));
      return CombineTo(
          N, Sub, DAG.getLogicalNOT(DL, Sub.getValue(1), Sub->getValueType(1)));
    }

    if (SDValue Combined = visitUADDOLike(N0, N1, N))
      return Combined;

    if (SDValue Combined = visitUADDOLike(N1, N0, N))
      return Combined;
  }

  return SDValue();
}

SDValue DAGCombiner::visitUADDOLike(SDValue N0, SDValue N1, SDNode *N) {
  EVT VT = N0.getValueType();
  if (VT.isVector())
    return SDValue();

  // (uaddo X, (addcarry Y, 0, Carry)) -> (addcarry X, Y, Carry)
  // If Y + 1 cannot overflow.
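  // (Otherwise Y + Carry could wrap, and the carry out of the uaddo would not
  // reflect the carry of the full X + Y + Carry computation.)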
  if (N1.getOpcode() == ISD::ADDCARRY && isNullConstant(N1.getOperand(1))) {
    SDValue Y = N1.getOperand(0);
    SDValue One = DAG.getConstant(1, SDLoc(N), Y.getValueType());
    if (DAG.computeOverflowKind(Y, One) == SelectionDAG::OFK_Never)
      return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(), N0, Y,
                         N1.getOperand(2));
  }

  // (uaddo X, Carry) -> (addcarry X, 0, Carry)
  if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT))
    if (SDValue Carry = getAsCarry(TLI, N1))
      return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(), N0,
                         DAG.getConstant(0, SDLoc(N), VT), Carry);

  return SDValue();
}

SDValue DAGCombiner::visitADDE(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue CarryIn = N->getOperand(2);

  // canonicalize constant to RHS
  ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
  if (N0C && !N1C)
    return DAG.getNode(ISD::ADDE, SDLoc(N), N->getVTList(),
                       N1, N0, CarryIn);

  // fold (adde x, y, false) -> (addc x, y)
  if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
    return DAG.getNode(ISD::ADDC, SDLoc(N), N->getVTList(), N0, N1);

  return SDValue();
}

SDValue DAGCombiner::visitADDCARRY(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue CarryIn = N->getOperand(2);
  SDLoc DL(N);

  // canonicalize constant to RHS
  ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
  if (N0C && !N1C)
    return DAG.getNode(ISD::ADDCARRY, DL, N->getVTList(), N1, N0, CarryIn);

  // fold (addcarry x, y, false) -> (uaddo x, y)
  if (isNullConstant(CarryIn)) {
    if (!LegalOperations ||
        TLI.isOperationLegalOrCustom(ISD::UADDO, N->getValueType(0)))
      return DAG.getNode(ISD::UADDO, DL, N->getVTList(), N0, N1);
  }

  // fold (addcarry 0, 0, X) -> (and (ext/trunc X), 1) and no carry.
  if (isNullConstant(N0) && isNullConstant(N1)) {
    EVT VT = N0.getValueType();
    EVT CarryVT = CarryIn.getValueType();
    SDValue CarryExt = DAG.getBoolExtOrTrunc(CarryIn, DL, VT, CarryVT);
    AddToWorklist(CarryExt.getNode());
    return CombineTo(N, DAG.getNode(ISD::AND, DL, VT, CarryExt,
                                    DAG.getConstant(1, DL, VT)),
                     DAG.getConstant(0, DL, CarryVT));
  }

  if (SDValue Combined = visitADDCARRYLike(N0, N1, CarryIn, N))
    return Combined;

  if (SDValue Combined = visitADDCARRYLike(N1, N0, CarryIn, N))
    return Combined;

  return SDValue();
}

SDValue DAGCombiner::visitSADDO_CARRY(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue CarryIn = N->getOperand(2);
  SDLoc DL(N);

  // canonicalize constant to RHS
  ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
  if (N0C && !N1C)
    return DAG.getNode(ISD::SADDO_CARRY, DL, N->getVTList(), N1, N0, CarryIn);

  // fold (saddo_carry x, y, false) -> (saddo x, y)
  if (isNullConstant(CarryIn)) {
    if (!LegalOperations ||
        TLI.isOperationLegalOrCustom(ISD::SADDO, N->getValueType(0)))
      return DAG.getNode(ISD::SADDO, DL, N->getVTList(), N0, N1);
  }

  return SDValue();
}

/**
 * If we are facing some sort of diamond carry propagation pattern try to
 * break it up to generate something like:
 *   (addcarry X, 0, (addcarry A, B, Z):Carry)
 *
 * The end result is usually an increase in the number of operations required,
 * but because the carry is now linearized, other transforms can kick in and
 * optimize the DAG.
 *
 * Patterns typically look something like
 *            (uaddo A, B)
 *             /       \
 *          Carry      Sum
 *            |          \
 *            | (addcarry *, 0, Z)
 *            |       /
 *             \   Carry
 *              |   /
 * (addcarry X, *, *)
 *
 * But numerous variations exist. Our goal is to identify A, B, X and Z and
 * produce a combine with a single path for carry propagation.
 */
static SDValue combineADDCARRYDiamond(DAGCombiner &Combiner, SelectionDAG &DAG,
                                      SDValue X, SDValue Carry0, SDValue Carry1,
                                      SDNode *N) {
  if (Carry1.getResNo() != 1 || Carry0.getResNo() != 1)
    return SDValue();
  if (Carry1.getOpcode() != ISD::UADDO)
    return SDValue();

  SDValue Z;

  /**
   * First look for a suitable Z. It will present itself in the form of
   * (addcarry Y, 0, Z) or its equivalent (uaddo Y, 1) for Z=true
   */
  if (Carry0.getOpcode() == ISD::ADDCARRY &&
      isNullConstant(Carry0.getOperand(1))) {
    Z = Carry0.getOperand(2);
  } else if (Carry0.getOpcode() == ISD::UADDO &&
             isOneConstant(Carry0.getOperand(1))) {
    EVT VT = Combiner.getSetCCResultType(Carry0.getValueType());
    Z = DAG.getConstant(1, SDLoc(Carry0.getOperand(1)), VT);
  } else {
    // We couldn't find a suitable Z.
    return SDValue();
  }

  auto cancelDiamond = [&](SDValue A, SDValue B) {
    SDLoc DL(N);
    SDValue NewY = DAG.getNode(ISD::ADDCARRY, DL, Carry0->getVTList(), A, B, Z);
    Combiner.AddToWorklist(NewY.getNode());
    return DAG.getNode(ISD::ADDCARRY, DL, N->getVTList(), X,
                       DAG.getConstant(0, DL, X.getValueType()),
                       NewY.getValue(1));
  };

  /**
   *      (uaddo A, B)
   *           |
   *          Sum
   *           |
   * (addcarry *, 0, Z)
   */
  if (Carry0.getOperand(0) == Carry1.getValue(0)) {
    return cancelDiamond(Carry1.getOperand(0), Carry1.getOperand(1));
  }

  /**
   * (addcarry A, 0, Z)
   *         |
   *        Sum
   *         |
   *  (uaddo *, B)
   */
  if (Carry1.getOperand(0) == Carry0.getValue(0)) {
    return cancelDiamond(Carry0.getOperand(0), Carry1.getOperand(1));
  }

  if (Carry1.getOperand(1) == Carry0.getValue(0)) {
    return cancelDiamond(Carry1.getOperand(0), Carry0.getOperand(0));
  }

  return SDValue();
}

// If we are facing some sort of diamond carry/borrow in/out pattern try to
// match patterns like:
//
//          (uaddo A, B)            CarryIn
//            |  \                     |
//            |   \                    |
//    PartialSum   PartialCarryOutX   /
//            |        |             /
//            |    ____|____________/
//            |   /    |
//     (uaddo *, *)    \________
//       |  \                   \
//       |   \                   |
//       |    PartialCarryOutY   |
//       |        \              |
//       |         \            /
//   AddCarrySum    |    ______/
//                  |   /
//   CarryOut = (or *, *)
//
// And generate ADDCARRY (or SUBCARRY) with two result values:
//
//    {AddCarrySum, CarryOut} = (addcarry A, B, CarryIn)
//
// Our goal is to identify A, B, and CarryIn and produce ADDCARRY/SUBCARRY with
// a single path for carry/borrow out propagation:
static SDValue combineCarryDiamond(DAGCombiner &Combiner, SelectionDAG &DAG,
                                   const TargetLowering &TLI, SDValue Carry0,
                                   SDValue Carry1, SDNode *N) {
  if (Carry0.getResNo() != 1 || Carry1.getResNo() != 1)
    return SDValue();
  unsigned Opcode = Carry0.getOpcode();
  if (Opcode != Carry1.getOpcode())
    return SDValue();
  if (Opcode != ISD::UADDO && Opcode != ISD::USUBO)
    return SDValue();

  // Canonicalize the add/sub of A and B as Carry0 and the add/sub of the
  // carry/borrow in as Carry1. (The top and middle uaddo nodes respectively in
  // the above ASCII art.)
  if (Carry1.getOperand(0) != Carry0.getValue(0) &&
      Carry1.getOperand(1) != Carry0.getValue(0))
    std::swap(Carry0, Carry1);
  if (Carry1.getOperand(0) != Carry0.getValue(0) &&
      Carry1.getOperand(1) != Carry0.getValue(0))
    return SDValue();

  // The carry in value must be on the right-hand side for subtraction.
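  // (Subtraction is not commutative, so only "PartialSum - CarryIn" forms a
  // borrow chain.)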
  unsigned CarryInOperandNum =
      Carry1.getOperand(0) == Carry0.getValue(0) ? 1 : 0;
  if (Opcode == ISD::USUBO && CarryInOperandNum != 1)
    return SDValue();
  SDValue CarryIn = Carry1.getOperand(CarryInOperandNum);

  unsigned NewOp = Opcode == ISD::UADDO ? ISD::ADDCARRY : ISD::SUBCARRY;
  if (!TLI.isOperationLegalOrCustom(NewOp, Carry0.getValue(0).getValueType()))
    return SDValue();

  // Verify that the carry/borrow in is plausibly a carry/borrow bit.
  // TODO: make getAsCarry() aware of how partial carries are merged.
  if (CarryIn.getOpcode() != ISD::ZERO_EXTEND)
    return SDValue();
  CarryIn = CarryIn.getOperand(0);
  if (CarryIn.getValueType() != MVT::i1)
    return SDValue();

  SDLoc DL(N);
  SDValue Merged =
      DAG.getNode(NewOp, DL, Carry1->getVTList(), Carry0.getOperand(0),
                  Carry0.getOperand(1), CarryIn);

  // Note that because the result of the UADDO/USUBO of A and B feeds into the
  // UADDO/USUBO that consumes the carry/borrow in, we have proven that if the
  // first UADDO/USUBO overflows, the second one cannot. For example consider
  // 8-bit numbers where 0xFF is the maximum value.
  //
  //   0xFF + 0xFF == 0xFE with carry but 0xFE + 1 does not carry
  //   0x00 - 0xFF == 1 with a carry/borrow but 1 - 1 == 0 (no carry/borrow)
  //
  // This is important because it means that OR and XOR can be used to merge
  // carry flags; and that AND can return a constant zero.
  //
  // TODO: match other operations that can merge flags (ADD, etc)
  DAG.ReplaceAllUsesOfValueWith(Carry1.getValue(0), Merged.getValue(0));
  if (N->getOpcode() == ISD::AND)
    return DAG.getConstant(0, DL, MVT::i1);
  return Merged.getValue(1);
}

SDValue DAGCombiner::visitADDCARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
                                       SDNode *N) {
  // fold (addcarry (xor a, -1), b, c) -> (subcarry b, a, !c) and flip carry.
  if (isBitwiseNot(N0))
    if (SDValue NotC = extractBooleanFlip(CarryIn, DAG, TLI, true)) {
      SDLoc DL(N);
      SDValue Sub = DAG.getNode(ISD::SUBCARRY, DL, N->getVTList(), N1,
                                N0.getOperand(0), NotC);
      return CombineTo(
          N, Sub, DAG.getLogicalNOT(DL, Sub.getValue(1), Sub->getValueType(1)));
    }

  // Iff the flag result is dead:
  // (addcarry (add|uaddo X, Y), 0, Carry) -> (addcarry X, Y, Carry)
  // Don't do this if the Carry comes from the uaddo. It won't remove the uaddo
  // or the dependency between the instructions.
  if ((N0.getOpcode() == ISD::ADD ||
       (N0.getOpcode() == ISD::UADDO && N0.getResNo() == 0 &&
        N0.getValue(1) != CarryIn)) &&
      isNullConstant(N1) && !N->hasAnyUseOfValue(1))
    return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(),
                       N0.getOperand(0), N0.getOperand(1), CarryIn);

  /**
   * When one of the addcarry arguments is itself a carry, we may be facing
   * a diamond carry propagation, in which case we try to transform the DAG
   * to ensure linear carry propagation if that is possible.
   */
  if (auto Y = getAsCarry(TLI, N1)) {
    // Because both are carries, Y and Z can be swapped.
    if (auto R = combineADDCARRYDiamond(*this, DAG, N0, Y, CarryIn, N))
      return R;
    if (auto R = combineADDCARRYDiamond(*this, DAG, N0, CarryIn, Y, N))
      return R;
  }

  return SDValue();
}

// Since it may not be valid to emit a fold to zero for vector initializers,
// check whether we can before folding.
static SDValue tryFoldToZero(const SDLoc &DL, const TargetLowering &TLI, EVT VT,
                             SelectionDAG &DAG, bool LegalOperations) {
  if (!VT.isVector())
    return DAG.getConstant(0, DL, VT);
  if (!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT))
    return DAG.getConstant(0, DL, VT);
  return SDValue();
}

SDValue DAGCombiner::visitSUB(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  SDLoc DL(N);

  // fold vector ops
  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N))
      return FoldedVOp;

    // fold (sub x, 0) -> x, vector edition
    if (ISD::isBuildVectorAllZeros(N1.getNode()))
      return N0;
  }

  // fold (sub x, x) -> 0
  // FIXME: Refactor this and xor and other similar operations together.
  if (N0 == N1)
    return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);

  // fold (sub c1, c2) -> c3
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0, N1}))
    return C;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  ConstantSDNode *N1C = getAsNonOpaqueConstant(N1);

  // fold (sub x, c) -> (add x, -c)
  if (N1C) {
    return DAG.getNode(ISD::ADD, DL, VT, N0,
                       DAG.getConstant(-N1C->getAPIntValue(), DL, VT));
  }

  if (isNullOrNullSplat(N0)) {
    unsigned BitWidth = VT.getScalarSizeInBits();
    // Right-shifting everything out but the sign bit followed by negation is
    // the same as flipping arithmetic/logical shift type without the negation:
    // -(X >>u 31) -> (X >>s 31)
    // -(X >>s 31) -> (X >>u 31)
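    // (Only the sign bit survives the shift: the logical form yields 0/1, the
    // arithmetic form yields 0/-1, and negation maps one onto the other.)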
    if (N1->getOpcode() == ISD::SRA || N1->getOpcode() == ISD::SRL) {
      ConstantSDNode *ShiftAmt = isConstOrConstSplat(N1.getOperand(1));
      if (ShiftAmt && ShiftAmt->getAPIntValue() == (BitWidth - 1)) {
        auto NewSh = N1->getOpcode() == ISD::SRA ? ISD::SRL : ISD::SRA;
        if (!LegalOperations || TLI.isOperationLegal(NewSh, VT))
          return DAG.getNode(NewSh, DL, VT, N1.getOperand(0), N1.getOperand(1));
      }
    }

    // 0 - X --> 0 if the sub is NUW.
    if (N->getFlags().hasNoUnsignedWrap())
      return N0;

    if (DAG.MaskedValueIsZero(N1, ~APInt::getSignMask(BitWidth))) {
      // N1 is either 0 or the minimum signed value. If the sub is NSW, then
      // N1 must be 0 because negating the minimum signed value is undefined.
      if (N->getFlags().hasNoSignedWrap())
        return N0;

      // 0 - X --> X if X is 0 or the minimum signed value.
      return N1;
    }

    // Convert 0 - abs(x).
    SDValue Result;
    if (N1->getOpcode() == ISD::ABS &&
        !TLI.isOperationLegalOrCustom(ISD::ABS, VT) &&
        TLI.expandABS(N1.getNode(), Result, DAG, true))
      return Result;
  }

  // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1)
  if (isAllOnesOrAllOnesSplat(N0))
    return DAG.getNode(ISD::XOR, DL, VT, N1, N0);

  // fold (A - (0-B)) -> A+B
  if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0)))
    return DAG.getNode(ISD::ADD, DL, VT, N0, N1.getOperand(1));

  // fold A-(A-B) -> B
  if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(0))
    return N1.getOperand(1);

  // fold (A+B)-A -> B
  if (N0.getOpcode() == ISD::ADD && N0.getOperand(0) == N1)
    return N0.getOperand(1);

  // fold (A+B)-B -> A
  if (N0.getOpcode() == ISD::ADD && N0.getOperand(1) == N1)
    return N0.getOperand(0);

  // fold (A+C1)-C2 -> A+(C1-C2)
  if (N0.getOpcode() == ISD::ADD &&
      isConstantOrConstantVector(N1, /* NoOpaques */ true) &&
      isConstantOrConstantVector(N0.getOperand(1), /* NoOpaques */ true)) {
    SDValue NewC =
        DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0.getOperand(1), N1});
    assert(NewC && "Constant folding failed");
    return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), NewC);
  }

  // fold C2-(A+C1) -> (C2-C1)-A
  if (N1.getOpcode() == ISD::ADD) {
    SDValue N11 = N1.getOperand(1);
    if (isConstantOrConstantVector(N0, /* NoOpaques */ true) &&
        isConstantOrConstantVector(N11, /* NoOpaques */ true)) {
      SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0, N11});
      assert(NewC && "Constant folding failed");
      return DAG.getNode(ISD::SUB, DL, VT, NewC, N1.getOperand(0));
    }
  }

  // fold (A-C1)-C2 -> A-(C1+C2)
  if (N0.getOpcode() == ISD::SUB &&
      isConstantOrConstantVector(N1, /* NoOpaques */ true) &&
      isConstantOrConstantVector(N0.getOperand(1), /* NoOpaques */ true)) {
    SDValue NewC =
        DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N0.getOperand(1), N1});
    assert(NewC && "Constant folding failed");
    return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), NewC);
  }

  // fold (c1-A)-c2 -> (c1-c2)-A
  if (N0.getOpcode() == ISD::SUB &&
      isConstantOrConstantVector(N1, /* NoOpaques */ true) &&
      isConstantOrConstantVector(N0.getOperand(0), /* NoOpaques */ true)) {
    SDValue NewC =
        DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0.getOperand(0), N1});
    assert(NewC && "Constant folding failed");
    return DAG.getNode(ISD::SUB, DL, VT, NewC, N0.getOperand(1));
  }

  // fold ((A+(B+or-C))-B) -> A+or-C
  if (N0.getOpcode() == ISD::ADD &&
      (N0.getOperand(1).getOpcode() == ISD::SUB ||
       N0.getOperand(1).getOpcode() == ISD::ADD) &&
      N0.getOperand(1).getOperand(0) == N1)
    return DAG.getNode(N0.getOperand(1).getOpcode(), DL, VT, N0.getOperand(0),
                       N0.getOperand(1).getOperand(1));

  // fold ((A+(C+B))-B) -> A+C
  if (N0.getOpcode() == ISD::ADD && N0.getOperand(1).getOpcode() == ISD::ADD &&
      N0.getOperand(1).getOperand(1) == N1)
    return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0),
                       N0.getOperand(1).getOperand(0));

  // fold ((A-(B-C))-C) -> A-B
  if (N0.getOpcode() == ISD::SUB && N0.getOperand(1).getOpcode() == ISD::SUB &&
      N0.getOperand(1).getOperand(1) == N1)
    return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0),
                       N0.getOperand(1).getOperand(0));

  // fold (A-(B-C)) -> A+(C-B)
  if (N1.getOpcode() == ISD::SUB && N1.hasOneUse())
    return DAG.getNode(ISD::ADD, DL, VT, N0,
                       DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(1),
                                   N1.getOperand(0)));

  // A - (A & B)  ->  A & (~B)
  if (N1.getOpcode() == ISD::AND) {
    SDValue A = N1.getOperand(0);
    SDValue B = N1.getOperand(1);
    if (A != N0)
      std::swap(A, B);
    if (A == N0 &&
        (N1.hasOneUse() || isConstantOrConstantVector(B, /*NoOpaques=*/true))) {
      SDValue InvB =
          DAG.getNode(ISD::XOR, DL, VT, B, DAG.getAllOnesConstant(DL, VT));
      return DAG.getNode(ISD::AND, DL, VT, A, InvB);
    }
  }

  // fold (X - (-Y * Z)) -> (X + (Y * Z))
  if (N1.getOpcode() == ISD::MUL && N1.hasOneUse()) {
    if (N1.getOperand(0).getOpcode() == ISD::SUB &&
        isNullOrNullSplat(N1.getOperand(0).getOperand(0))) {
      SDValue Mul = DAG.getNode(ISD::MUL, DL, VT,
                                N1.getOperand(0).getOperand(1),
                                N1.getOperand(1));
      return DAG.getNode(ISD::ADD, DL, VT, N0, Mul);
    }
    if (N1.getOperand(1).getOpcode() == ISD::SUB &&
        isNullOrNullSplat(N1.getOperand(1).getOperand(0))) {
      SDValue Mul = DAG.getNode(ISD::MUL, DL, VT,
                                N1.getOperand(0),
                                N1.getOperand(1).getOperand(1));
      return DAG.getNode(ISD::ADD, DL, VT, N0, Mul);
    }
  }

  // If either operand of a sub is undef, the result is undef
  if (N0.isUndef())
    return N0;
  if (N1.isUndef())
    return N1;

  if (SDValue V = foldAddSubBoolOfMaskedVal(N, DAG))
    return V;

  if (SDValue V = foldAddSubOfSignBit(N, DAG))
    return V;

  if (SDValue V = foldAddSubMasked1(false, N0, N1, DAG, SDLoc(N)))
    return V;

  // (x - y) - 1  ->  add (xor y, -1), x
  if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB && isOneOrOneSplat(N1)) {
    SDValue Xor = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(1),
                              DAG.getAllOnesConstant(DL, VT));
    return DAG.getNode(ISD::ADD, DL, VT, Xor, N0.getOperand(0));
  }

  // Look for:
  //   sub y, (xor x, -1)
  // And if the target does not like this form then turn into:
  //   add (add x, y), 1
  if (TLI.preferIncOfAddToSubOfNot(VT) && N1.hasOneUse() && isBitwiseNot(N1)) {
    SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, N1.getOperand(0));
    return DAG.getNode(ISD::ADD, DL, VT, Add, DAG.getConstant(1, DL, VT));
  }

  // Hoist one-use addition by non-opaque constant:
  //   (x + C) - y  ->  (x - y) + C
  if (N0.hasOneUse() && N0.getOpcode() == ISD::ADD &&
      isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
    SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), N1);
    return DAG.getNode(ISD::ADD, DL, VT, Sub, N0.getOperand(1));
  }
  // y - (x + C)  ->  (y - x) - C
  if (N1.hasOneUse() && N1.getOpcode() == ISD::ADD &&
      isConstantOrConstantVector(N1.getOperand(1), /*NoOpaques=*/true)) {
    SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, N1.getOperand(0));
    return DAG.getNode(ISD::SUB, DL, VT, Sub, N1.getOperand(1));
  }
  // (x - C) - y  ->  (x - y) - C
  // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
  if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB &&
      isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
    SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), N1);
    return DAG.getNode(ISD::SUB, DL, VT, Sub, N0.getOperand(1));
  }
  // (C - x) - y  ->  C - (x + y)
  if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB &&
      isConstantOrConstantVector(N0.getOperand(0), /*NoOpaques=*/true)) {
    SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1), N1);
    return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), Add);
  }

  // If the target's bool is represented as 0/-1, prefer to make this 'add 0/-1'
  // rather than 'sub 0/1' (the sext should get folded).
  // sub X, (zext i1 Y) --> add X, (sext i1 Y)
  if (N1.getOpcode() == ISD::ZERO_EXTEND &&
      N1.getOperand(0).getScalarValueSizeInBits() == 1 &&
      TLI.getBooleanContents(VT) ==
          TargetLowering::ZeroOrNegativeOneBooleanContent) {
    SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N1.getOperand(0));
    return DAG.getNode(ISD::ADD, DL, VT, N0, SExt);
  }

  // fold Y = sra (X, size(X)-1); sub (xor (X, Y), Y) -> (abs X)
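  // (Y is 0 when X is non-negative and -1 when X is negative, so
  // (xor X, Y) - Y is the classic branchless absolute value.)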
  if (TLI.isOperationLegalOrCustom(ISD::ABS, VT)) {
    if (N0.getOpcode() == ISD::XOR && N1.getOpcode() == ISD::SRA) {
      SDValue X0 = N0.getOperand(0), X1 = N0.getOperand(1);
      SDValue S0 = N1.getOperand(0);
      if ((X0 == S0 && X1 == N1) || (X0 == N1 && X1 == S0))
        if (ConstantSDNode *C = isConstOrConstSplat(N1.getOperand(1)))
          if (C->getAPIntValue() == (VT.getScalarSizeInBits() - 1))
            return DAG.getNode(ISD::ABS, SDLoc(N), VT, S0);
    }
  }

  // If the relocation model supports it, consider symbol offsets.
  if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(N0))
    if (!LegalOperations && TLI.isOffsetFoldingLegal(GA)) {
      // fold (sub Sym, c) -> Sym-c
      if (N1C && GA->getOpcode() == ISD::GlobalAddress)
        return DAG.getGlobalAddress(GA->getGlobal(), SDLoc(N1C), VT,
                                    GA->getOffset() -
                                        (uint64_t)N1C->getSExtValue());
      // fold (sub Sym+c1, Sym+c2) -> c1-c2
      if (GlobalAddressSDNode *GB = dyn_cast<GlobalAddressSDNode>(N1))
        if (GA->getGlobal() == GB->getGlobal())
          return DAG.getConstant((uint64_t)GA->getOffset() - GB->getOffset(),
                                 DL, VT);
    }

  // sub X, (sextinreg Y i1) -> add X, (and Y 1)
  if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
    VTSDNode *TN = cast<VTSDNode>(N1.getOperand(1));
    if (TN->getVT() == MVT::i1) {
      SDValue ZExt = DAG.getNode(ISD::AND, DL, VT, N1.getOperand(0),
                                 DAG.getConstant(1, DL, VT));
      return DAG.getNode(ISD::ADD, DL, VT, N0, ZExt);
    }
  }

  // canonicalize (sub X, (vscale * C)) to (add X, (vscale * -C))
  if (N1.getOpcode() == ISD::VSCALE) {
    const APInt &IntVal = N1.getConstantOperandAPInt(0);
    return DAG.getNode(ISD::ADD, DL, VT, N0, DAG.getVScale(DL, VT, -IntVal));
  }

  // Prefer an add for more folding potential and possibly better codegen:
  // sub N0, (lshr N10, width-1) --> add N0, (ashr N10, width-1)
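  // (The shifted value is the sign bit, so subtracting 0/1 via lshr equals
  // adding 0/-1 via ashr.)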
  if (!LegalOperations && N1.getOpcode() == ISD::SRL && N1.hasOneUse()) {
    SDValue ShAmt = N1.getOperand(1);
    ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt);
    if (ShAmtC &&
        ShAmtC->getAPIntValue() == (N1.getScalarValueSizeInBits() - 1)) {
      SDValue SRA = DAG.getNode(ISD::SRA, DL, VT, N1.getOperand(0), ShAmt);
      return DAG.getNode(ISD::ADD, DL, VT, N0, SRA);
    }
  }

  if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT)) {
    // (sub Carry, X)  ->  (addcarry (sub 0, X), 0, Carry)
    if (SDValue Carry = getAsCarry(TLI, N0)) {
      SDValue X = N1;
      SDValue Zero = DAG.getConstant(0, DL, VT);
      SDValue NegX = DAG.getNode(ISD::SUB, DL, VT, Zero, X);
      return DAG.getNode(ISD::ADDCARRY, DL,
                         DAG.getVTList(VT, Carry.getValueType()), NegX, Zero,
                         Carry);
    }
  }

  return SDValue();
}

SDValue DAGCombiner::visitSUBSAT(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  SDLoc DL(N);

  // fold vector ops
  if (VT.isVector()) {
    // TODO SimplifyVBinOp

    // fold (sub_sat x, 0) -> x, vector edition
    if (ISD::isBuildVectorAllZeros(N1.getNode()))
      return N0;
  }

  // fold (sub_sat x, undef) -> 0
  if (N0.isUndef() || N1.isUndef())
    return DAG.getConstant(0, DL, VT);

  // fold (sub_sat x, x) -> 0
  if (N0 == N1)
    return DAG.getConstant(0, DL, VT);

  // fold (sub_sat c1, c2) -> c3
  if (SDValue C = DAG.FoldConstantArithmetic(N->getOpcode(), DL, VT, {N0, N1}))
    return C;

  // fold (sub_sat x, 0) -> x
  if (isNullConstant(N1))
    return N0;

  return SDValue();
}

SDValue DAGCombiner::visitSUBC(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  SDLoc DL(N);

  // If the flag result is dead, turn this into a SUB.
  if (!N->hasAnyUseOfValue(1))
    return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
                     DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));

  // fold (subc x, x) -> 0 + no borrow
  if (N0 == N1)
    return CombineTo(N, DAG.getConstant(0, DL, VT),
                     DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));

  // fold (subc x, 0) -> x + no borrow
  if (isNullConstant(N1))
    return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));

  // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1) + no borrow
  if (isAllOnesConstant(N0))
    return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0),
                     DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));

  return SDValue();
}

SDValue DAGCombiner::visitSUBO(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  bool IsSigned = (ISD::SSUBO == N->getOpcode());

  EVT CarryVT = N->getValueType(1);
  SDLoc DL(N);

  // If the flag result is dead, turn this into a SUB.
  if (!N->hasAnyUseOfValue(1))
    return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
                     DAG.getUNDEF(CarryVT));

  // fold (subo x, x) -> 0 + no borrow
  if (N0 == N1)
    return CombineTo(N, DAG.getConstant(0, DL, VT),
                     DAG.getConstant(0, DL, CarryVT));

  ConstantSDNode *N1C = getAsNonOpaqueConstant(N1);

  // fold (subo x, c) -> (addo x, -c)
  if (IsSigned && N1C && !N1C->getAPIntValue().isMinSignedValue()) {
    return DAG.getNode(ISD::SADDO, DL, N->getVTList(), N0,
                       DAG.getConstant(-N1C->getAPIntValue(), DL, VT));
  }

  // fold (subo x, 0) -> x + no borrow
  if (isNullOrNullSplat(N1))
    return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT));

  // Canonicalize (usubo -1, x) -> ~x, i.e. (xor x, -1) + no borrow
  if (!IsSigned && isAllOnesOrAllOnesSplat(N0))
    return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0),
                     DAG.getConstant(0, DL, CarryVT));

  return SDValue();
}

SDValue DAGCombiner::visitSUBE(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue CarryIn = N->getOperand(2);

  // fold (sube x, y, false) -> (subc x, y)
  if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
    return DAG.getNode(ISD::SUBC, SDLoc(N), N->getVTList(), N0, N1);

  return SDValue();
}

SDValue DAGCombiner::visitSUBCARRY(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue CarryIn = N->getOperand(2);

  // fold (subcarry x, y, false) -> (usubo x, y)
  if (isNullConstant(CarryIn)) {
    if (!LegalOperations ||
        TLI.isOperationLegalOrCustom(ISD::USUBO, N->getValueType(0)))
      return DAG.getNode(ISD::USUBO, SDLoc(N), N->getVTList(), N0, N1);
  }

  return SDValue();
}

SDValue DAGCombiner::visitSSUBO_CARRY(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue CarryIn = N->getOperand(2);

  // fold (ssubo_carry x, y, false) -> (ssubo x, y)
  if (isNullConstant(CarryIn)) {
    if (!LegalOperations ||
        TLI.isOperationLegalOrCustom(ISD::SSUBO, N->getValueType(0)))
      return DAG.getNode(ISD::SSUBO, SDLoc(N), N->getVTList(), N0, N1);
  }

  return SDValue();
}

// Notice that "mulfix" can be any of SMULFIX, SMULFIXSAT, UMULFIX and
// UMULFIXSAT here.
SDValue DAGCombiner::visitMULFIX(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue Scale = N->getOperand(2);
  EVT VT = N0.getValueType();

  // fold (mulfix x, undef, scale) -> 0
  if (N0.isUndef() || N1.isUndef())
    return DAG.getConstant(0, SDLoc(N), VT);

  // Canonicalize constant to RHS (vector doesn't have to splat)
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
     !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0, Scale);

  // fold (mulfix x, 0, scale) -> 0
  if (isNullConstant(N1))
    return DAG.getConstant(0, SDLoc(N), VT);

  return SDValue();
}

SDValue DAGCombiner::visitMUL(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();

  // fold (mul x, undef) -> 0
  if (N0.isUndef() || N1.isUndef())
    return DAG.getConstant(0, SDLoc(N), VT);

  bool N1IsConst = false;
  bool N1IsOpaqueConst = false;
  APInt ConstValue1;

  // fold vector ops
  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N))
      return FoldedVOp;

    N1IsConst = ISD::isConstantSplatVector(N1.getNode(), ConstValue1);
    assert((!N1IsConst ||
            ConstValue1.getBitWidth() == VT.getScalarSizeInBits()) &&
           "Splat APInt should be element width");
  } else {
    N1IsConst = isa<ConstantSDNode>(N1);
    if (N1IsConst) {
      ConstValue1 = cast<ConstantSDNode>(N1)->getAPIntValue();
      N1IsOpaqueConst = cast<ConstantSDNode>(N1)->isOpaque();
    }
  }

  // fold (mul c1, c2) -> c1*c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::MUL, SDLoc(N), VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS (vector doesn't have to splat)
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
     !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(ISD::MUL, SDLoc(N), VT, N1, N0);

  // fold (mul x, 0) -> 0
  if (N1IsConst && ConstValue1.isNullValue())
    return N1;

  // fold (mul x, 1) -> x
  if (N1IsConst && ConstValue1.isOneValue())
    return N0;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // fold (mul x, -1) -> 0-x
  if (N1IsConst && ConstValue1.isAllOnesValue()) {
    SDLoc DL(N);
    return DAG.getNode(ISD::SUB, DL, VT,
                       DAG.getConstant(0, DL, VT), N0);
  }

  // fold (mul x, (1 << c)) -> x << c
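  // e.g. (mul x, 16) --> (shl x, 4)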
  if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
      DAG.isKnownToBeAPowerOfTwo(N1) &&
      (!VT.isVector() || Level <= AfterLegalizeVectorOps)) {
    SDLoc DL(N);
    SDValue LogBase2 = BuildLogBase2(N1, DL);
    EVT ShiftVT = getShiftAmountTy(N0.getValueType());
    SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
    return DAG.getNode(ISD::SHL, DL, VT, N0, Trunc);
  }

  // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
  if (N1IsConst && !N1IsOpaqueConst && (-ConstValue1).isPowerOf2()) {
    unsigned Log2Val = (-ConstValue1).logBase2();
    SDLoc DL(N);
    // FIXME: If the input is something that is easily negated (e.g. a
    // single-use add), we should put the negate there.
    return DAG.getNode(ISD::SUB, DL, VT,
                       DAG.getConstant(0, DL, VT),
                       DAG.getNode(ISD::SHL, DL, VT, N0,
                            DAG.getConstant(Log2Val, DL,
                                      getShiftAmountTy(N0.getValueType()))));
  }

  // Try to transform:
  // (1) multiply-by-(power-of-2 +/- 1) into shift and add/sub.
  // mul x, (2^N + 1) --> add (shl x, N), x
  // mul x, (2^N - 1) --> sub (shl x, N), x
  // Examples: x * 33 --> (x << 5) + x
  //           x * 15 --> (x << 4) - x
  //           x * -33 --> -((x << 5) + x)
  //           x * -15 --> -((x << 4) - x) ; this reduces --> x - (x << 4)
  // (2) multiply-by-(power-of-2 +/- power-of-2) into shifts and add/sub.
  // mul x, (2^N + 2^M) --> (add (shl x, N), (shl x, M))
  // mul x, (2^N - 2^M) --> (sub (shl x, N), (shl x, M))
  // Examples: x * 0x8800 --> (x << 15) + (x << 11)
  //           x * 0xf800 --> (x << 16) - (x << 11)
  //           x * -0x8800 --> -((x << 15) + (x << 11))
  //           x * -0xf800 --> -((x << 16) - (x << 11)) ; (x << 11) - (x << 16)
  if (N1IsConst && TLI.decomposeMulByConstant(*DAG.getContext(), VT, N1)) {
    // TODO: We could handle more general decomposition of any constant by
    //       having the target set a limit on number of ops and making a
    //       callback to determine that sequence (similar to sqrt expansion).
    unsigned MathOp = ISD::DELETED_NODE;
    APInt MulC = ConstValue1.abs();
    // The constant `2` should be treated as (2^0 + 1).
    unsigned TZeros = MulC == 2 ? 0 : MulC.countTrailingZeros();
    MulC.lshrInPlace(TZeros);
    if ((MulC - 1).isPowerOf2())
      MathOp = ISD::ADD;
    else if ((MulC + 1).isPowerOf2())
      MathOp = ISD::SUB;

    if (MathOp != ISD::DELETED_NODE) {
      unsigned ShAmt =
          MathOp == ISD::ADD ? (MulC - 1).logBase2() : (MulC + 1).logBase2();
      ShAmt += TZeros;
      assert(ShAmt < VT.getScalarSizeInBits() &&
             "multiply-by-constant generated out of bounds shift");
      SDLoc DL(N);
      SDValue Shl =
          DAG.getNode(ISD::SHL, DL, VT, N0, DAG.getConstant(ShAmt, DL, VT));
      SDValue R =
          TZeros ? DAG.getNode(MathOp, DL, VT, Shl,
                               DAG.getNode(ISD::SHL, DL, VT, N0,
                                           DAG.getConstant(TZeros, DL, VT)))
                 : DAG.getNode(MathOp, DL, VT, Shl, N0);
      if (ConstValue1.isNegative())
        R = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), R);
      return R;
    }
  }

  // (mul (shl X, c1), c2) -> (mul X, c2 << c1)
  if (N0.getOpcode() == ISD::SHL &&
      isConstantOrConstantVector(N1, /* NoOpaques */ true) &&
      isConstantOrConstantVector(N0.getOperand(1), /* NoOpaques */ true)) {
    SDValue C3 = DAG.getNode(ISD::SHL, SDLoc(N), VT, N1, N0.getOperand(1));
    if (isConstantOrConstantVector(C3))
      return DAG.getNode(ISD::MUL, SDLoc(N), VT, N0.getOperand(0), C3);
  }

  // Change (mul (shl X, C), Y) -> (shl (mul X, Y), C) when the shift has one
  // use.
  {
    SDValue Sh(nullptr, 0), Y(nullptr, 0);

    // Check for both (mul (shl X, C), Y)  and  (mul Y, (shl X, C)).
    if (N0.getOpcode() == ISD::SHL &&
        isConstantOrConstantVector(N0.getOperand(1)) &&
        N0.getNode()->hasOneUse()) {
      Sh = N0; Y = N1;
    } else if (N1.getOpcode() == ISD::SHL &&
               isConstantOrConstantVector(N1.getOperand(1)) &&
               N1.getNode()->hasOneUse()) {
      Sh = N1; Y = N0;
    }

    if (Sh.getNode()) {
      SDValue Mul = DAG.getNode(ISD::MUL, SDLoc(N), VT, Sh.getOperand(0), Y);
      return DAG.getNode(ISD::SHL, SDLoc(N), VT, Mul, Sh.getOperand(1));
    }
  }

  // fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2)
  if (DAG.isConstantIntBuildVectorOrConstantInt(N1) &&
      N0.getOpcode() == ISD::ADD &&
      DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1)) &&
      isMulAddWithConstProfitable(N, N0, N1))
      return DAG.getNode(ISD::ADD, SDLoc(N), VT,
                         DAG.getNode(ISD::MUL, SDLoc(N0), VT,
                                     N0.getOperand(0), N1),
                         DAG.getNode(ISD::MUL, SDLoc(N1), VT,
                                     N0.getOperand(1), N1));

  // Fold (mul (vscale * C0), C1) to (vscale * (C0 * C1)).
  if (N0.getOpcode() == ISD::VSCALE)
    if (ConstantSDNode *NC1 = isConstOrConstSplat(N1)) {
      const APInt &C0 = N0.getConstantOperandAPInt(0);
      const APInt &C1 = NC1->getAPIntValue();
      return DAG.getVScale(SDLoc(N), VT, C0 * C1);
    }

  // Fold (mul x, 0/undef) -> 0 and (mul x, 1) -> x into and(x, mask):
  // we can replace vectors with '0' and '1' factors with a clearing mask.
  if (VT.isFixedLengthVector()) {
    unsigned NumElts = VT.getVectorNumElements();
    SmallBitVector ClearMask;
    ClearMask.reserve(NumElts);
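    // Per element: a 0 or undef factor clears the lane and a 1 keeps it; any
    // other factor makes the whole transform inapplicable.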
    auto IsClearMask = [&ClearMask](ConstantSDNode *V) {
      if (!V || V->isNullValue()) {
        ClearMask.push_back(true);
        return true;
      }
      ClearMask.push_back(false);
      return V->isOne();
    };
    if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::AND, VT)) &&
        ISD::matchUnaryPredicate(N1, IsClearMask, /*AllowUndefs*/ true)) {
      assert(N1.getOpcode() == ISD::BUILD_VECTOR && "Unknown constant vector");
      SDLoc DL(N);
      EVT LegalSVT = N1.getOperand(0).getValueType();
      SDValue Zero = DAG.getConstant(0, DL, LegalSVT);
      SDValue AllOnes = DAG.getAllOnesConstant(DL, LegalSVT);
      SmallVector<SDValue, 16> Mask(NumElts, AllOnes);
      for (unsigned I = 0; I != NumElts; ++I)
        if (ClearMask[I])
          Mask[I] = Zero;
      return DAG.getNode(ISD::AND, DL, VT, N0, DAG.getBuildVector(VT, DL, Mask));
    }
  }

  // reassociate mul
  if (SDValue RMUL = reassociateOps(ISD::MUL, SDLoc(N), N0, N1, N->getFlags()))
    return RMUL;

  return SDValue();
}

/// Return true if divmod libcall is available.
static bool isDivRemLibcallAvailable(SDNode *Node, bool isSigned,
                                     const TargetLowering &TLI) {
  RTLIB::Libcall LC;
  EVT NodeType = Node->getValueType(0);
  if (!NodeType.isSimple())
    return false;
  switch (NodeType.getSimpleVT().SimpleTy) {
  default: return false; // No libcall for vector types.
  case MVT::i8:   LC= isSigned ? RTLIB::SDIVREM_I8  : RTLIB::UDIVREM_I8;  break;
  case MVT::i16:  LC= isSigned ? RTLIB::SDIVREM_I16 : RTLIB::UDIVREM_I16; break;
  case MVT::i32:  LC= isSigned ? RTLIB::SDIVREM_I32 : RTLIB::UDIVREM_I32; break;
  case MVT::i64:  LC= isSigned ? RTLIB::SDIVREM_I64 : RTLIB::UDIVREM_I64; break;
  case MVT::i128: LC= isSigned ? RTLIB::SDIVREM_I128:RTLIB::UDIVREM_I128; break;
  }

  return TLI.getLibcallName(LC) != nullptr;
}

/// Issue divrem if both quotient and remainder are needed.
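/// For example, if both (sdiv X, Y) and (srem X, Y) are live, emit a single
/// (sdivrem X, Y) node and rewrite both users to take its two results.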
useDivRem(SDNode * Node)3875 SDValue DAGCombiner::useDivRem(SDNode *Node) {
3876   if (Node->use_empty())
3877     return SDValue(); // This is a dead node, leave it alone.
3878 
3879   unsigned Opcode = Node->getOpcode();
3880   bool isSigned = (Opcode == ISD::SDIV) || (Opcode == ISD::SREM);
3881   unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM;
3882 
3883   // DivMod lib calls can still work on non-legal types if using lib-calls.
3884   EVT VT = Node->getValueType(0);
3885   if (VT.isVector() || !VT.isInteger())
3886     return SDValue();
3887 
3888   if (!TLI.isTypeLegal(VT) && !TLI.isOperationCustom(DivRemOpc, VT))
3889     return SDValue();
3890 
3891   // If DIVREM is going to get expanded into a libcall,
3892   // but there is no libcall available, then don't combine.
3893   if (!TLI.isOperationLegalOrCustom(DivRemOpc, VT) &&
3894       !isDivRemLibcallAvailable(Node, isSigned, TLI))
3895     return SDValue();
3896 
3897   // If div is legal, it's better to do the normal expansion
3898   unsigned OtherOpcode = 0;
3899   if ((Opcode == ISD::SDIV) || (Opcode == ISD::UDIV)) {
3900     OtherOpcode = isSigned ? ISD::SREM : ISD::UREM;
3901     if (TLI.isOperationLegalOrCustom(Opcode, VT))
3902       return SDValue();
3903   } else {
3904     OtherOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
3905     if (TLI.isOperationLegalOrCustom(OtherOpcode, VT))
3906       return SDValue();
3907   }
3908 
3909   SDValue Op0 = Node->getOperand(0);
3910   SDValue Op1 = Node->getOperand(1);
3911   SDValue combined;
3912   for (SDNode::use_iterator UI = Op0.getNode()->use_begin(),
3913          UE = Op0.getNode()->use_end(); UI != UE; ++UI) {
3914     SDNode *User = *UI;
3915     if (User == Node || User->getOpcode() == ISD::DELETED_NODE ||
3916         User->use_empty())
3917       continue;
3918     // Convert the other matching node(s), too;
3919     // otherwise, the DIVREM may get target-legalized into something
3920     // target-specific that we won't be able to recognize.
3921     unsigned UserOpc = User->getOpcode();
3922     if ((UserOpc == Opcode || UserOpc == OtherOpcode || UserOpc == DivRemOpc) &&
3923         User->getOperand(0) == Op0 &&
3924         User->getOperand(1) == Op1) {
3925       if (!combined) {
3926         if (UserOpc == OtherOpcode) {
3927           SDVTList VTs = DAG.getVTList(VT, VT);
3928           combined = DAG.getNode(DivRemOpc, SDLoc(Node), VTs, Op0, Op1);
3929         } else if (UserOpc == DivRemOpc) {
3930           combined = SDValue(User, 0);
3931         } else {
3932           assert(UserOpc == Opcode);
3933           continue;
3934         }
3935       }
3936       if (UserOpc == ISD::SDIV || UserOpc == ISD::UDIV)
3937         CombineTo(User, combined);
3938       else if (UserOpc == ISD::SREM || UserOpc == ISD::UREM)
3939         CombineTo(User, combined.getValue(1));
3940     }
3941   }
3942   return combined;
3943 }
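// For example, when the DAG contains both
//   t0 = sdiv x, y
//   t1 = srem x, y
// and SDIVREM is legal/custom for the type (or a divrem libcall exists),
// useDivRem merges them into a single two-result node
//   t2, t3 = sdivrem x, y
// rewiring uses of t0 to the quotient and uses of t1 to the remainder.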
3944 
3945 static SDValue simplifyDivRem(SDNode *N, SelectionDAG &DAG) {
3946   SDValue N0 = N->getOperand(0);
3947   SDValue N1 = N->getOperand(1);
3948   EVT VT = N->getValueType(0);
3949   SDLoc DL(N);
3950 
3951   unsigned Opc = N->getOpcode();
3952   bool IsDiv = (ISD::SDIV == Opc) || (ISD::UDIV == Opc);
3953   ConstantSDNode *N1C = isConstOrConstSplat(N1);
3954 
3955   // X / undef -> undef
3956   // X % undef -> undef
3957   // X / 0 -> undef
3958   // X % 0 -> undef
3959   // NOTE: This includes vectors where any divisor element is zero/undef.
3960   if (DAG.isUndef(Opc, {N0, N1}))
3961     return DAG.getUNDEF(VT);
3962 
3963   // undef / X -> 0
3964   // undef % X -> 0
3965   if (N0.isUndef())
3966     return DAG.getConstant(0, DL, VT);
3967 
3968   // 0 / X -> 0
3969   // 0 % X -> 0
3970   ConstantSDNode *N0C = isConstOrConstSplat(N0);
3971   if (N0C && N0C->isNullValue())
3972     return N0;
3973 
3974   // X / X -> 1
3975   // X % X -> 0
3976   if (N0 == N1)
3977     return DAG.getConstant(IsDiv ? 1 : 0, DL, VT);
3978 
3979   // X / 1 -> X
3980   // X % 1 -> 0
3981   // If this is a boolean op (single-bit element type), we can't have
3982   // division-by-zero or remainder-by-zero, so assume the divisor is 1.
3983   // TODO: Similarly, if we're zero-extending a boolean divisor, then assume
3984   // it's a 1.
3985   if ((N1C && N1C->isOne()) || (VT.getScalarType() == MVT::i1))
3986     return IsDiv ? N0 : DAG.getConstant(0, DL, VT);
3987 
3988   return SDValue();
3989 }
3990 
3991 SDValue DAGCombiner::visitSDIV(SDNode *N) {
3992   SDValue N0 = N->getOperand(0);
3993   SDValue N1 = N->getOperand(1);
3994   EVT VT = N->getValueType(0);
3995   EVT CCVT = getSetCCResultType(VT);
3996 
3997   // fold vector ops
3998   if (VT.isVector())
3999     if (SDValue FoldedVOp = SimplifyVBinOp(N))
4000       return FoldedVOp;
4001 
4002   SDLoc DL(N);
4003 
4004   // fold (sdiv c1, c2) -> c1/c2
4005   ConstantSDNode *N1C = isConstOrConstSplat(N1);
4006   if (SDValue C = DAG.FoldConstantArithmetic(ISD::SDIV, DL, VT, {N0, N1}))
4007     return C;
4008 
4009   // fold (sdiv X, -1) -> 0-X
4010   if (N1C && N1C->isAllOnesValue())
4011     return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), N0);
4012 
4013   // fold (sdiv X, MIN_SIGNED) -> select(X == MIN_SIGNED, 1, 0)
4014   if (N1C && N1C->getAPIntValue().isMinSignedValue())
4015     return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
4016                          DAG.getConstant(1, DL, VT),
4017                          DAG.getConstant(0, DL, VT));
4018 
4019   if (SDValue V = simplifyDivRem(N, DAG))
4020     return V;
4021 
4022   if (SDValue NewSel = foldBinOpIntoSelect(N))
4023     return NewSel;
4024 
4025   // If we know the sign bits of both operands are zero, strength reduce to a
4026   // udiv instead.  Handles (X&15) /s 4 -> X&15 >> 2
4027   if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
4028     return DAG.getNode(ISD::UDIV, DL, N1.getValueType(), N0, N1);
4029 
4030   if (SDValue V = visitSDIVLike(N0, N1, N)) {
4031     // If the corresponding remainder node exists, update its users with
4032     // (Dividend - (Quotient * Divisor)).
4033     if (SDNode *RemNode = DAG.getNodeIfExists(ISD::SREM, N->getVTList(),
4034                                               { N0, N1 })) {
4035       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1);
4036       SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
4037       AddToWorklist(Mul.getNode());
4038       AddToWorklist(Sub.getNode());
4039       CombineTo(RemNode, Sub);
4040     }
4041     return V;
4042   }
4043 
4044   // sdiv, srem -> sdivrem
4045   // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
4046   // true.  Otherwise, we break the simplification logic in visitREM().
4047   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4048   if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
4049     if (SDValue DivRem = useDivRem(N))
4050         return DivRem;
4051 
4052   return SDValue();
4053 }
4054 
4055 SDValue DAGCombiner::visitSDIVLike(SDValue N0, SDValue N1, SDNode *N) {
4056   SDLoc DL(N);
4057   EVT VT = N->getValueType(0);
4058   EVT CCVT = getSetCCResultType(VT);
4059   unsigned BitWidth = VT.getScalarSizeInBits();
4060 
4061   // Helper for determining whether a value is a power-of-2 constant scalar or a
4062   // vector of such elements.
4063   auto IsPowerOfTwo = [](ConstantSDNode *C) {
4064     if (C->isNullValue() || C->isOpaque())
4065       return false;
4066     if (C->getAPIntValue().isPowerOf2())
4067       return true;
4068     if ((-C->getAPIntValue()).isPowerOf2())
4069       return true;
4070     return false;
4071   };
4072 
4073   // fold (sdiv X, pow2) -> simple ops after legalize
4074   // FIXME: We check for the exact bit here because the generic lowering gives
4075   // better results in that case. The target-specific lowering should learn how
4076   // to handle exact sdivs efficiently.
4077   if (!N->getFlags().hasExact() && ISD::matchUnaryPredicate(N1, IsPowerOfTwo)) {
4078     // Target-specific implementation of sdiv x, pow2.
4079     if (SDValue Res = BuildSDIVPow2(N))
4080       return Res;
4081 
4082     // Create constants that are functions of the shift amount value.
4083     EVT ShiftAmtTy = getShiftAmountTy(N0.getValueType());
4084     SDValue Bits = DAG.getConstant(BitWidth, DL, ShiftAmtTy);
4085     SDValue C1 = DAG.getNode(ISD::CTTZ, DL, VT, N1);
4086     C1 = DAG.getZExtOrTrunc(C1, DL, ShiftAmtTy);
4087     SDValue Inexact = DAG.getNode(ISD::SUB, DL, ShiftAmtTy, Bits, C1);
4088     if (!isConstantOrConstantVector(Inexact))
4089       return SDValue();
4090 
4091     // Splat the sign bit into the register
4092     SDValue Sign = DAG.getNode(ISD::SRA, DL, VT, N0,
4093                                DAG.getConstant(BitWidth - 1, DL, ShiftAmtTy));
4094     AddToWorklist(Sign.getNode());
4095 
4096     // Add (N0 < 0) ? abs2 - 1 : 0;
4097     SDValue Srl = DAG.getNode(ISD::SRL, DL, VT, Sign, Inexact);
4098     AddToWorklist(Srl.getNode());
4099     SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, Srl);
4100     AddToWorklist(Add.getNode());
4101     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Add, C1);
4102     AddToWorklist(Sra.getNode());
4103 
4104     // Special case: (sdiv X, 1) -> X
4105     // Special Case: (sdiv X, -1) -> 0-X
4106     SDValue One = DAG.getConstant(1, DL, VT);
4107     SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
4108     SDValue IsOne = DAG.getSetCC(DL, CCVT, N1, One, ISD::SETEQ);
4109     SDValue IsAllOnes = DAG.getSetCC(DL, CCVT, N1, AllOnes, ISD::SETEQ);
4110     SDValue IsOneOrAllOnes = DAG.getNode(ISD::OR, DL, CCVT, IsOne, IsAllOnes);
4111     Sra = DAG.getSelect(DL, VT, IsOneOrAllOnes, N0, Sra);
4112 
4113     // If dividing by a positive value, we're done. Otherwise, the result must
4114     // be negated.
4115     SDValue Zero = DAG.getConstant(0, DL, VT);
4116     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, Zero, Sra);
4117 
4118     // FIXME: Use SELECT_CC once we improve SELECT_CC constant-folding.
4119     SDValue IsNeg = DAG.getSetCC(DL, CCVT, N1, Zero, ISD::SETLT);
4120     SDValue Res = DAG.getSelect(DL, VT, IsNeg, Sub, Sra);
4121     return Res;
4122   }
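  // A worked instance of the sequence above (assuming i32 and no
  // target-specific BuildSDIVPow2): for (sdiv X, 4), C1 = cttz(4) = 2 and
  // Inexact = 32 - 2 = 30, giving
  //   Sign = sra X, 31        ; 0 for X >= 0, -1 for X < 0
  //   Srl  = srl Sign, 30     ; rounding bias: 0 or 3
  //   Add  = add X, Srl
  //   Sra  = sra Add, 2
  // which rounds toward zero for negative X, as sdiv requires.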
4123 
4124   // If integer divide is expensive and we satisfy the requirements, emit an
4125   // alternate sequence.  Targets may check function attributes for size/speed
4126   // trade-offs.
4127   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4128   if (isConstantOrConstantVector(N1) &&
4129       !TLI.isIntDivCheap(N->getValueType(0), Attr))
4130     if (SDValue Op = BuildSDIV(N))
4131       return Op;
4132 
4133   return SDValue();
4134 }
4135 
4136 SDValue DAGCombiner::visitUDIV(SDNode *N) {
4137   SDValue N0 = N->getOperand(0);
4138   SDValue N1 = N->getOperand(1);
4139   EVT VT = N->getValueType(0);
4140   EVT CCVT = getSetCCResultType(VT);
4141 
4142   // fold vector ops
4143   if (VT.isVector())
4144     if (SDValue FoldedVOp = SimplifyVBinOp(N))
4145       return FoldedVOp;
4146 
4147   SDLoc DL(N);
4148 
4149   // fold (udiv c1, c2) -> c1/c2
4150   ConstantSDNode *N1C = isConstOrConstSplat(N1);
4151   if (SDValue C = DAG.FoldConstantArithmetic(ISD::UDIV, DL, VT, {N0, N1}))
4152     return C;
4153 
4154   // fold (udiv X, -1) -> select(X == -1, 1, 0)
4155   if (N1C && N1C->getAPIntValue().isAllOnesValue())
4156     return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
4157                          DAG.getConstant(1, DL, VT),
4158                          DAG.getConstant(0, DL, VT));
4159 
4160   if (SDValue V = simplifyDivRem(N, DAG))
4161     return V;
4162 
4163   if (SDValue NewSel = foldBinOpIntoSelect(N))
4164     return NewSel;
4165 
4166   if (SDValue V = visitUDIVLike(N0, N1, N)) {
4167     // If the corresponding remainder node exists, update its users with
4168     // (Dividend - (Quotient * Divisor)).
4169     if (SDNode *RemNode = DAG.getNodeIfExists(ISD::UREM, N->getVTList(),
4170                                               { N0, N1 })) {
4171       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1);
4172       SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
4173       AddToWorklist(Mul.getNode());
4174       AddToWorklist(Sub.getNode());
4175       CombineTo(RemNode, Sub);
4176     }
4177     return V;
4178   }
4179 
4180   // udiv, urem -> udivrem
4181   // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
4182   // true.  Otherwise, we break the simplification logic in visitREM().
4183   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4184   if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
4185     if (SDValue DivRem = useDivRem(N))
4186         return DivRem;
4187 
4188   return SDValue();
4189 }
4190 
4191 SDValue DAGCombiner::visitUDIVLike(SDValue N0, SDValue N1, SDNode *N) {
4192   SDLoc DL(N);
4193   EVT VT = N->getValueType(0);
4194 
4195   // fold (udiv x, (1 << c)) -> x >>u c
4196   if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
4197       DAG.isKnownToBeAPowerOfTwo(N1)) {
4198     SDValue LogBase2 = BuildLogBase2(N1, DL);
4199     AddToWorklist(LogBase2.getNode());
4200 
4201     EVT ShiftVT = getShiftAmountTy(N0.getValueType());
4202     SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
4203     AddToWorklist(Trunc.getNode());
4204     return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
4205   }
4206 
4207   // fold (udiv x, (shl c, y)) -> x >>u (log2(c)+y) iff c is power of 2
4208   if (N1.getOpcode() == ISD::SHL) {
4209     SDValue N10 = N1.getOperand(0);
4210     if (isConstantOrConstantVector(N10, /*NoOpaques*/ true) &&
4211         DAG.isKnownToBeAPowerOfTwo(N10)) {
4212       SDValue LogBase2 = BuildLogBase2(N10, DL);
4213       AddToWorklist(LogBase2.getNode());
4214 
4215       EVT ADDVT = N1.getOperand(1).getValueType();
4216       SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT);
4217       AddToWorklist(Trunc.getNode());
4218       SDValue Add = DAG.getNode(ISD::ADD, DL, ADDVT, N1.getOperand(1), Trunc);
4219       AddToWorklist(Add.getNode());
4220       return DAG.getNode(ISD::SRL, DL, VT, N0, Add);
4221     }
4222   }
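  // For instance, (udiv x, 8) becomes (srl x, 3) via the first fold above,
  // and (udiv x, (shl 2, y)) becomes (srl x, (add y, 1)) via the second,
  // since log2(2) == 1.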
4223 
4224   // fold (udiv x, c) -> alternate
4225   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4226   if (isConstantOrConstantVector(N1) &&
4227       !TLI.isIntDivCheap(N->getValueType(0), Attr))
4228     if (SDValue Op = BuildUDIV(N))
4229       return Op;
4230 
4231   return SDValue();
4232 }
4233 
4234 // handles ISD::SREM and ISD::UREM
4235 SDValue DAGCombiner::visitREM(SDNode *N) {
4236   unsigned Opcode = N->getOpcode();
4237   SDValue N0 = N->getOperand(0);
4238   SDValue N1 = N->getOperand(1);
4239   EVT VT = N->getValueType(0);
4240   EVT CCVT = getSetCCResultType(VT);
4241 
4242   bool isSigned = (Opcode == ISD::SREM);
4243   SDLoc DL(N);
4244 
4245   // fold (rem c1, c2) -> c1%c2
4246   ConstantSDNode *N1C = isConstOrConstSplat(N1);
4247   if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
4248     return C;
4249 
4250   // fold (urem X, -1) -> select(X == -1, 0, X)
4251   if (!isSigned && N1C && N1C->getAPIntValue().isAllOnesValue())
4252     return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
4253                          DAG.getConstant(0, DL, VT), N0);
4254 
4255   if (SDValue V = simplifyDivRem(N, DAG))
4256     return V;
4257 
4258   if (SDValue NewSel = foldBinOpIntoSelect(N))
4259     return NewSel;
4260 
4261   if (isSigned) {
4262     // If we know the sign bits of both operands are zero, strength reduce to a
4263     // urem instead.  Handles (X & 0x0FFFFFFF) %s 16 -> X&15
4264     if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
4265       return DAG.getNode(ISD::UREM, DL, VT, N0, N1);
4266   } else {
4267     if (DAG.isKnownToBeAPowerOfTwo(N1)) {
4268       // fold (urem x, pow2) -> (and x, pow2-1)
4269       SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
4270       SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, NegOne);
4271       AddToWorklist(Add.getNode());
4272       return DAG.getNode(ISD::AND, DL, VT, N0, Add);
4273     }
4274     if (N1.getOpcode() == ISD::SHL &&
4275         DAG.isKnownToBeAPowerOfTwo(N1.getOperand(0))) {
4276       // fold (urem x, (shl pow2, y)) -> (and x, (add (shl pow2, y), -1))
4277       SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
4278       SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, NegOne);
4279       AddToWorklist(Add.getNode());
4280       return DAG.getNode(ISD::AND, DL, VT, N0, Add);
4281     }
4282   }
4283 
4284   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4285 
4286   // If X/C can be simplified by the division-by-constant logic, lower
4287   // X%C to the equivalent of X-X/C*C.
4288   // Reuse the SDIVLike/UDIVLike combines - to avoid mangling nodes, the
4289   // speculative DIV must not cause a DIVREM conversion.  We guard against this
4290   // by skipping the simplification if isIntDivCheap().  When div is not cheap,
4291   // combine will not return a DIVREM.  Regardless, checking cheapness here
4292   // makes sense since the simplification results in fatter code.
4293   if (DAG.isKnownNeverZero(N1) && !TLI.isIntDivCheap(VT, Attr)) {
4294     SDValue OptimizedDiv =
4295         isSigned ? visitSDIVLike(N0, N1, N) : visitUDIVLike(N0, N1, N);
4296     if (OptimizedDiv.getNode()) {
4297       // If the equivalent Div node also exists, update its users.
4298       unsigned DivOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
4299       if (SDNode *DivNode = DAG.getNodeIfExists(DivOpcode, N->getVTList(),
4300                                                 { N0, N1 }))
4301         CombineTo(DivNode, OptimizedDiv);
4302       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, OptimizedDiv, N1);
4303       SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
4304       AddToWorklist(OptimizedDiv.getNode());
4305       AddToWorklist(Mul.getNode());
4306       return Sub;
4307     }
4308   }
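  // For example, if (udiv x, 7) can be rewritten by BuildUDIV into a
  // multiply-by-magic-constant sequence, (urem x, 7) becomes
  // (sub x, (mul <optimized udiv>, 7)) and reuses the cheap division.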
4309 
4310   // sdiv, srem -> sdivrem
4311   if (SDValue DivRem = useDivRem(N))
4312     return DivRem.getValue(1);
4313 
4314   return SDValue();
4315 }
4316 
4317 SDValue DAGCombiner::visitMULHS(SDNode *N) {
4318   SDValue N0 = N->getOperand(0);
4319   SDValue N1 = N->getOperand(1);
4320   EVT VT = N->getValueType(0);
4321   SDLoc DL(N);
4322 
4323   if (VT.isVector()) {
4324     // fold (mulhs x, 0) -> 0
4325     // do not return N0/N1, because an undef node may exist.
4326     if (ISD::isBuildVectorAllZeros(N0.getNode()) ||
4327         ISD::isBuildVectorAllZeros(N1.getNode()))
4328       return DAG.getConstant(0, DL, VT);
4329   }
4330 
4331   // fold (mulhs x, 0) -> 0
4332   if (isNullConstant(N1))
4333     return N1;
4334   // fold (mulhs x, 1) -> (sra x, size(x)-1)
4335   if (isOneConstant(N1))
4336     return DAG.getNode(ISD::SRA, DL, N0.getValueType(), N0,
4337                        DAG.getConstant(N0.getScalarValueSizeInBits() - 1, DL,
4338                                        getShiftAmountTy(N0.getValueType())));
4339 
4340   // fold (mulhs x, undef) -> 0
4341   if (N0.isUndef() || N1.isUndef())
4342     return DAG.getConstant(0, DL, VT);
4343 
4344   // If the type twice as wide is legal, transform the mulhs to a wider multiply
4345   // plus a shift.
4346   if (!TLI.isOperationLegalOrCustom(ISD::MULHS, VT) && VT.isSimple() &&
4347       !VT.isVector()) {
4348     MVT Simple = VT.getSimpleVT();
4349     unsigned SimpleSize = Simple.getSizeInBits();
4350     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
4351     if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
4352       N0 = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N0);
4353       N1 = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N1);
4354       N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
4355       N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
4356             DAG.getConstant(SimpleSize, DL,
4357                             getShiftAmountTy(N1.getValueType())));
4358       return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
4359     }
4360   }
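  // As a sketch of the widening path above: for i16 mulhs on a target where
  // i32 MUL is legal, (mulhs x, y) -> (trunc (srl (mul (sext x), (sext y)), 16)).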
4361 
4362   return SDValue();
4363 }
4364 
4365 SDValue DAGCombiner::visitMULHU(SDNode *N) {
4366   SDValue N0 = N->getOperand(0);
4367   SDValue N1 = N->getOperand(1);
4368   EVT VT = N->getValueType(0);
4369   SDLoc DL(N);
4370 
4371   if (VT.isVector()) {
4372     // fold (mulhu x, 0) -> 0
4373     // do not return N0/N1, because an undef node may exist.
4374     if (ISD::isBuildVectorAllZeros(N0.getNode()) ||
4375         ISD::isBuildVectorAllZeros(N1.getNode()))
4376       return DAG.getConstant(0, DL, VT);
4377   }
4378 
4379   // fold (mulhu x, 0) -> 0
4380   if (isNullConstant(N1))
4381     return N1;
4382   // fold (mulhu x, 1) -> 0
4383   if (isOneConstant(N1))
4384     return DAG.getConstant(0, DL, N0.getValueType());
4385   // fold (mulhu x, undef) -> 0
4386   if (N0.isUndef() || N1.isUndef())
4387     return DAG.getConstant(0, DL, VT);
4388 
4389   // fold (mulhu x, (1 << c)) -> x >> (bitwidth - c)
4390   if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
4391       DAG.isKnownToBeAPowerOfTwo(N1) && hasOperation(ISD::SRL, VT)) {
4392     unsigned NumEltBits = VT.getScalarSizeInBits();
4393     SDValue LogBase2 = BuildLogBase2(N1, DL);
4394     SDValue SRLAmt = DAG.getNode(
4395         ISD::SUB, DL, VT, DAG.getConstant(NumEltBits, DL, VT), LogBase2);
4396     EVT ShiftVT = getShiftAmountTy(N0.getValueType());
4397     SDValue Trunc = DAG.getZExtOrTrunc(SRLAmt, DL, ShiftVT);
4398     return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
4399   }
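  // e.g. for i32, (mulhu x, 16) is ((zext x) * 16) >> 32, i.e. x >> 28,
  // which is exactly x >> (32 - log2(16)) as computed above.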
4400 
4401   // If the type twice as wide is legal, transform the mulhu to a wider multiply
4402   // plus a shift.
4403   if (!TLI.isOperationLegalOrCustom(ISD::MULHU, VT) && VT.isSimple() &&
4404       !VT.isVector()) {
4405     MVT Simple = VT.getSimpleVT();
4406     unsigned SimpleSize = Simple.getSizeInBits();
4407     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
4408     if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
4409       N0 = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N0);
4410       N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N1);
4411       N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
4412       N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
4413             DAG.getConstant(SimpleSize, DL,
4414                             getShiftAmountTy(N1.getValueType())));
4415       return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
4416     }
4417   }
4418 
4419   return SDValue();
4420 }
4421 
4422 /// Perform optimizations common to nodes that compute two values. LoOp and HiOp
4423 /// give the opcodes for the two computations that are being performed. Return
4424 /// the simplified value if a simplification was made.
4425 SDValue DAGCombiner::SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
4426                                                 unsigned HiOp) {
4427   // If the high half is not needed, just compute the low half.
4428   bool HiExists = N->hasAnyUseOfValue(1);
4429   if (!HiExists && (!LegalOperations ||
4430                     TLI.isOperationLegalOrCustom(LoOp, N->getValueType(0)))) {
4431     SDValue Res = DAG.getNode(LoOp, SDLoc(N), N->getValueType(0), N->ops());
4432     return CombineTo(N, Res, Res);
4433   }
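  // For instance, for (t0, t1) = smul_lohi x, y where only t0 (the low
  // half) has uses, the rewrite above leaves just a plain (mul x, y).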
4434 
4435   // If the low half is not needed, just compute the high half.
4436   bool LoExists = N->hasAnyUseOfValue(0);
4437   if (!LoExists && (!LegalOperations ||
4438                     TLI.isOperationLegalOrCustom(HiOp, N->getValueType(1)))) {
4439     SDValue Res = DAG.getNode(HiOp, SDLoc(N), N->getValueType(1), N->ops());
4440     return CombineTo(N, Res, Res);
4441   }
4442 
4443   // If both halves are used, return as it is.
4444   if (LoExists && HiExists)
4445     return SDValue();
4446 
4447   // If the two computed results can be simplified separately, separate them.
4448   if (LoExists) {
4449     SDValue Lo = DAG.getNode(LoOp, SDLoc(N), N->getValueType(0), N->ops());
4450     AddToWorklist(Lo.getNode());
4451     SDValue LoOpt = combine(Lo.getNode());
4452     if (LoOpt.getNode() && LoOpt.getNode() != Lo.getNode() &&
4453         (!LegalOperations ||
4454          TLI.isOperationLegalOrCustom(LoOpt.getOpcode(), LoOpt.getValueType())))
4455       return CombineTo(N, LoOpt, LoOpt);
4456   }
4457 
4458   if (HiExists) {
4459     SDValue Hi = DAG.getNode(HiOp, SDLoc(N), N->getValueType(1), N->ops());
4460     AddToWorklist(Hi.getNode());
4461     SDValue HiOpt = combine(Hi.getNode());
4462     if (HiOpt.getNode() && HiOpt != Hi &&
4463         (!LegalOperations ||
4464          TLI.isOperationLegalOrCustom(HiOpt.getOpcode(), HiOpt.getValueType())))
4465       return CombineTo(N, HiOpt, HiOpt);
4466   }
4467 
4468   return SDValue();
4469 }
4470 
4471 SDValue DAGCombiner::visitSMUL_LOHI(SDNode *N) {
4472   if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHS))
4473     return Res;
4474 
4475   EVT VT = N->getValueType(0);
4476   SDLoc DL(N);
4477 
4478   // If the type twice as wide is legal, transform the multiply to a wider
4479   // multiply plus a shift.
4480   if (VT.isSimple() && !VT.isVector()) {
4481     MVT Simple = VT.getSimpleVT();
4482     unsigned SimpleSize = Simple.getSizeInBits();
4483     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
4484     if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
4485       SDValue Lo = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N->getOperand(0));
4486       SDValue Hi = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N->getOperand(1));
4487       Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
4488       // Compute the high part (result value 1).
4489       Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
4490             DAG.getConstant(SimpleSize, DL,
4491                             getShiftAmountTy(Lo.getValueType())));
4492       Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
4493       // Compute the low part (result value 0).
4494       Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
4495       return CombineTo(N, Lo, Hi);
4496     }
4497   }
4498 
4499   return SDValue();
4500 }
4501 
4502 SDValue DAGCombiner::visitUMUL_LOHI(SDNode *N) {
4503   if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHU))
4504     return Res;
4505 
4506   EVT VT = N->getValueType(0);
4507   SDLoc DL(N);
4508 
4509   // (umul_lohi N0, 0) -> (0, 0)
4510   if (isNullConstant(N->getOperand(1))) {
4511     SDValue Zero = DAG.getConstant(0, DL, VT);
4512     return CombineTo(N, Zero, Zero);
4513   }
4514 
4515   // (umul_lohi N0, 1) -> (N0, 0)
4516   if (isOneConstant(N->getOperand(1))) {
4517     SDValue Zero = DAG.getConstant(0, DL, VT);
4518     return CombineTo(N, N->getOperand(0), Zero);
4519   }
4520 
4521   // If the type twice as wide is legal, transform the mulhu to a wider
4522   // multiply plus a shift.
4523   if (VT.isSimple() && !VT.isVector()) {
4524     MVT Simple = VT.getSimpleVT();
4525     unsigned SimpleSize = Simple.getSizeInBits();
4526     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
4527     if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
4528       SDValue Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N->getOperand(0));
4529       SDValue Hi = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N->getOperand(1));
4530       Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
4531       // Compute the high part (result value 1).
4532       Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
4533             DAG.getConstant(SimpleSize, DL,
4534                             getShiftAmountTy(Lo.getValueType())));
4535       Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
4536       // Compute the low part (result value 0).
4537       Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
4538       return CombineTo(N, Lo, Hi);
4539     }
4540   }
4541 
4542   return SDValue();
4543 }
4544 
4545 SDValue DAGCombiner::visitMULO(SDNode *N) {
4546   SDValue N0 = N->getOperand(0);
4547   SDValue N1 = N->getOperand(1);
4548   EVT VT = N0.getValueType();
4549   bool IsSigned = (ISD::SMULO == N->getOpcode());
4550 
4551   EVT CarryVT = N->getValueType(1);
4552   SDLoc DL(N);
4553 
4554   // canonicalize constant to RHS.
4555   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4556       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4557     return DAG.getNode(N->getOpcode(), DL, N->getVTList(), N1, N0);
4558 
4559   // fold (mulo x, 0) -> 0 + no carry out
4560   if (isNullOrNullSplat(N1))
4561     return CombineTo(N, DAG.getConstant(0, DL, VT),
4562                      DAG.getConstant(0, DL, CarryVT));
4563 
4564   // (mulo x, 2) -> (addo x, x)
4565   if (ConstantSDNode *C2 = isConstOrConstSplat(N1))
4566     if (C2->getAPIntValue() == 2)
4567       return DAG.getNode(IsSigned ? ISD::SADDO : ISD::UADDO, DL,
4568                          N->getVTList(), N0, N0);
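  // The (mulo x, 2) -> (addo x, x) fold is sound because doubling overflows
  // exactly when x + x overflows with the same signedness, so the second
  // (carry/overflow) result is preserved as well.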
4569 
4570   return SDValue();
4571 }
4572 
4573 SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
4574   SDValue N0 = N->getOperand(0);
4575   SDValue N1 = N->getOperand(1);
4576   EVT VT = N0.getValueType();
4577   unsigned Opcode = N->getOpcode();
4578 
4579   // fold vector ops
4580   if (VT.isVector())
4581     if (SDValue FoldedVOp = SimplifyVBinOp(N))
4582       return FoldedVOp;
4583 
4584   // fold operation with constant operands.
4585   if (SDValue C = DAG.FoldConstantArithmetic(Opcode, SDLoc(N), VT, {N0, N1}))
4586     return C;
4587 
4588   // canonicalize constant to RHS
4589   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4590       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4591     return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0);
4592 
4593   // If the sign bits are zero, flip between UMIN/UMAX and SMIN/SMAX.
4594   // Only do this if the current op isn't legal and the flipped is.
4595   if (!TLI.isOperationLegal(Opcode, VT) &&
4596       (N0.isUndef() || DAG.SignBitIsZero(N0)) &&
4597       (N1.isUndef() || DAG.SignBitIsZero(N1))) {
4598     unsigned AltOpcode;
4599     switch (Opcode) {
4600     case ISD::SMIN: AltOpcode = ISD::UMIN; break;
4601     case ISD::SMAX: AltOpcode = ISD::UMAX; break;
4602     case ISD::UMIN: AltOpcode = ISD::SMIN; break;
4603     case ISD::UMAX: AltOpcode = ISD::SMAX; break;
4604     default: llvm_unreachable("Unknown MINMAX opcode");
4605     }
4606     if (TLI.isOperationLegal(AltOpcode, VT))
4607       return DAG.getNode(AltOpcode, SDLoc(N), VT, N0, N1);
4608   }
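  // e.g. when both operands have a known-zero sign bit, (smax x, y) equals
  // (umax x, y) on every input, so the flip above swaps an illegal node for
  // a legal one without changing semantics.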
4609 
4610   // Simplify the operands using demanded-bits information.
4611   if (SimplifyDemandedBits(SDValue(N, 0)))
4612     return SDValue(N, 0);
4613 
4614   return SDValue();
4615 }
4616 
4617 /// If this is a bitwise logic instruction and both operands have the same
4618 /// opcode, try to sink the other opcode after the logic instruction.
4619 SDValue DAGCombiner::hoistLogicOpWithSameOpcodeHands(SDNode *N) {
4620   SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
4621   EVT VT = N0.getValueType();
4622   unsigned LogicOpcode = N->getOpcode();
4623   unsigned HandOpcode = N0.getOpcode();
4624   assert((LogicOpcode == ISD::AND || LogicOpcode == ISD::OR ||
4625           LogicOpcode == ISD::XOR) && "Expected logic opcode");
4626   assert(HandOpcode == N1.getOpcode() && "Bad input!");
4627 
4628   // Bail early if none of these transforms apply.
4629   if (N0.getNumOperands() == 0)
4630     return SDValue();
4631 
4632   // FIXME: We should check number of uses of the operands to not increase
4633   //        the instruction count for all transforms.
4634 
4635   // Handle size-changing casts.
4636   SDValue X = N0.getOperand(0);
4637   SDValue Y = N1.getOperand(0);
4638   EVT XVT = X.getValueType();
4639   SDLoc DL(N);
4640   if (HandOpcode == ISD::ANY_EXTEND || HandOpcode == ISD::ZERO_EXTEND ||
4641       HandOpcode == ISD::SIGN_EXTEND) {
4642     // If both operands have other uses, this transform would create extra
4643     // instructions without eliminating anything.
4644     if (!N0.hasOneUse() && !N1.hasOneUse())
4645       return SDValue();
4646     // We need matching integer source types.
4647     if (XVT != Y.getValueType())
4648       return SDValue();
4649     // Don't create an illegal op during or after legalization. Don't ever
4650     // create an unsupported vector op.
4651     if ((VT.isVector() || LegalOperations) &&
4652         !TLI.isOperationLegalOrCustom(LogicOpcode, XVT))
4653       return SDValue();
4654     // Avoid infinite looping with PromoteIntBinOp.
4655     // TODO: Should we apply desirable/legal constraints to all opcodes?
4656     if (HandOpcode == ISD::ANY_EXTEND && LegalTypes &&
4657         !TLI.isTypeDesirableForOp(LogicOpcode, XVT))
4658       return SDValue();
4659     // logic_op (hand_op X), (hand_op Y) --> hand_op (logic_op X, Y)
4660     SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
4661     return DAG.getNode(HandOpcode, DL, VT, Logic);
4662   }
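  // For example, (and (zext i8 x to i32), (zext i8 y to i32)) becomes
  // (zext (and x, y) to i32): the logic op runs on the narrow type and only
  // one extend remains.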
4663 
4664   // logic_op (truncate x), (truncate y) --> truncate (logic_op x, y)
4665   if (HandOpcode == ISD::TRUNCATE) {
4666     // If both operands have other uses, this transform would create extra
4667     // instructions without eliminating anything.
4668     if (!N0.hasOneUse() && !N1.hasOneUse())
4669       return SDValue();
4670     // We need matching source types.
4671     if (XVT != Y.getValueType())
4672       return SDValue();
4673     // Don't create an illegal op during or after legalization.
4674     if (LegalOperations && !TLI.isOperationLegal(LogicOpcode, XVT))
4675       return SDValue();
4676     // Be extra careful sinking truncate. If it's free, there's no benefit in
4677     // widening a binop. Also, don't create a logic op on an illegal type.
4678     if (TLI.isZExtFree(VT, XVT) && TLI.isTruncateFree(XVT, VT))
4679       return SDValue();
4680     if (!TLI.isTypeLegal(XVT))
4681       return SDValue();
4682     SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
4683     return DAG.getNode(HandOpcode, DL, VT, Logic);
4684   }
4685 
4686   // For binops SHL/SRL/SRA/AND:
4687   //   logic_op (OP x, z), (OP y, z) --> OP (logic_op x, y), z
4688   if ((HandOpcode == ISD::SHL || HandOpcode == ISD::SRL ||
4689        HandOpcode == ISD::SRA || HandOpcode == ISD::AND) &&
4690       N0.getOperand(1) == N1.getOperand(1)) {
4691     // If either operand has other uses, this transform is not an improvement.
4692     if (!N0.hasOneUse() || !N1.hasOneUse())
4693       return SDValue();
4694     SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
4695     return DAG.getNode(HandOpcode, DL, VT, Logic, N0.getOperand(1));
4696   }
4697 
4698   // Unary ops: logic_op (bswap x), (bswap y) --> bswap (logic_op x, y)
4699   if (HandOpcode == ISD::BSWAP) {
4700     // If either operand has other uses, this transform is not an improvement.
4701     if (!N0.hasOneUse() || !N1.hasOneUse())
4702       return SDValue();
4703     SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
4704     return DAG.getNode(HandOpcode, DL, VT, Logic);
4705   }
4706 
4707   // Simplify xor/and/or (bitcast(A), bitcast(B)) -> bitcast(op (A,B))
4708   // Only perform this optimization up until type legalization, before
4709   // LegalizeVectorOps. LegalizeVectorOps promotes vector operations by
4710   // adding bitcasts. For example (xor v4i32) is promoted to (v2i64), and
4711   // we don't want to undo this promotion.
4712   // We also handle SCALAR_TO_VECTOR because xor/or/and operations are cheaper
4713   // on scalars.
4714   if ((HandOpcode == ISD::BITCAST || HandOpcode == ISD::SCALAR_TO_VECTOR) &&
4715        Level <= AfterLegalizeTypes) {
4716     // Input types must be integer and the same.
4717     if (XVT.isInteger() && XVT == Y.getValueType() &&
4718         !(VT.isVector() && TLI.isTypeLegal(VT) &&
4719           !XVT.isVector() && !TLI.isTypeLegal(XVT))) {
4720       SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
4721       return DAG.getNode(HandOpcode, DL, VT, Logic);
4722     }
4723   }
4724 
4725   // Xor/and/or are indifferent to the swizzle operation (shuffle of one value).
4726   // Simplify xor/and/or (shuff(A), shuff(B)) -> shuff(op (A,B))
4727   // If both shuffles use the same mask, and both shuffle within a single
4728   // vector, then it is worthwhile to move the swizzle after the operation.
4729   // The type-legalizer generates this pattern when loading illegal
4730   // vector types from memory. In many cases this allows additional shuffle
4731   // optimizations.
4732   // There are other cases where moving the shuffle after the xor/and/or
4733   // is profitable even if shuffles don't perform a swizzle.
4734   // If both shuffles use the same mask, and both shuffles have the same first
4735   // or second operand, then it might still be profitable to move the shuffle
4736   // after the xor/and/or operation.
4737   if (HandOpcode == ISD::VECTOR_SHUFFLE && Level < AfterLegalizeDAG) {
4738     auto *SVN0 = cast<ShuffleVectorSDNode>(N0);
4739     auto *SVN1 = cast<ShuffleVectorSDNode>(N1);
4740     assert(X.getValueType() == Y.getValueType() &&
4741            "Inputs to shuffles are not the same type");
4742 
4743     // Check that both shuffles use the same mask. The masks are known to be of
4744     // the same length because the result vector type is the same.
4745     // Check also that shuffles have only one use to avoid introducing extra
4746     // instructions.
4747     if (!SVN0->hasOneUse() || !SVN1->hasOneUse() ||
4748         !SVN0->getMask().equals(SVN1->getMask()))
4749       return SDValue();
4750 
4751     // Don't try to fold this node if it requires introducing a
4752     // build vector of all zeros that might be illegal at this stage.
4753     SDValue ShOp = N0.getOperand(1);
4754     if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
4755       ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
4756 
4757     // (logic_op (shuf (A, C), shuf (B, C))) --> shuf (logic_op (A, B), C)
4758     if (N0.getOperand(1) == N1.getOperand(1) && ShOp.getNode()) {
4759       SDValue Logic = DAG.getNode(LogicOpcode, DL, VT,
4760                                   N0.getOperand(0), N1.getOperand(0));
4761       return DAG.getVectorShuffle(VT, DL, Logic, ShOp, SVN0->getMask());
4762     }
4763 
4764     // Don't try to fold this node if it requires introducing a
4765     // build vector of all zeros that might be illegal at this stage.
4766     ShOp = N0.getOperand(0);
4767     if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
4768       ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
4769 
4770     // (logic_op (shuf (C, A), shuf (C, B))) --> shuf (C, logic_op (A, B))
4771     if (N0.getOperand(0) == N1.getOperand(0) && ShOp.getNode()) {
4772       SDValue Logic = DAG.getNode(LogicOpcode, DL, VT, N0.getOperand(1),
4773                                   N1.getOperand(1));
4774       return DAG.getVectorShuffle(VT, DL, ShOp, Logic, SVN0->getMask());
4775     }
4776   }
4777 
4778   return SDValue();
4779 }
4780 
4781 /// Try to make (and/or setcc (LL, LR), setcc (RL, RR)) more efficient.
4782 SDValue DAGCombiner::foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
4783                                        const SDLoc &DL) {
4784   SDValue LL, LR, RL, RR, N0CC, N1CC;
4785   if (!isSetCCEquivalent(N0, LL, LR, N0CC) ||
4786       !isSetCCEquivalent(N1, RL, RR, N1CC))
4787     return SDValue();
4788 
4789   assert(N0.getValueType() == N1.getValueType() &&
4790          "Unexpected operand types for bitwise logic op");
4791   assert(LL.getValueType() == LR.getValueType() &&
4792          RL.getValueType() == RR.getValueType() &&
4793          "Unexpected operand types for setcc");
4794 
4795   // If we're here post-legalization or the logic op type is not i1, the logic
4796   // op type must match a setcc result type. Also, all folds require new
4797   // operations on the left and right operands, so those types must match.
4798   EVT VT = N0.getValueType();
4799   EVT OpVT = LL.getValueType();
4800   if (LegalOperations || VT.getScalarType() != MVT::i1)
4801     if (VT != getSetCCResultType(OpVT))
4802       return SDValue();
4803   if (OpVT != RL.getValueType())
4804     return SDValue();
4805 
4806   ISD::CondCode CC0 = cast<CondCodeSDNode>(N0CC)->get();
4807   ISD::CondCode CC1 = cast<CondCodeSDNode>(N1CC)->get();
4808   bool IsInteger = OpVT.isInteger();
4809   if (LR == RR && CC0 == CC1 && IsInteger) {
4810     bool IsZero = isNullOrNullSplat(LR);
4811     bool IsNeg1 = isAllOnesOrAllOnesSplat(LR);
4812 
4813     // All bits clear?
4814     bool AndEqZero = IsAnd && CC1 == ISD::SETEQ && IsZero;
4815     // All sign bits clear?
4816     bool AndGtNeg1 = IsAnd && CC1 == ISD::SETGT && IsNeg1;
4817     // Any bits set?
4818     bool OrNeZero = !IsAnd && CC1 == ISD::SETNE && IsZero;
4819     // Any sign bits set?
4820     bool OrLtZero = !IsAnd && CC1 == ISD::SETLT && IsZero;
4821 
4822     // (and (seteq X,  0), (seteq Y,  0)) --> (seteq (or X, Y),  0)
4823     // (and (setgt X, -1), (setgt Y, -1)) --> (setgt (or X, Y), -1)
4824     // (or  (setne X,  0), (setne Y,  0)) --> (setne (or X, Y),  0)
4825     // (or  (setlt X,  0), (setlt Y,  0)) --> (setlt (or X, Y),  0)
4826     if (AndEqZero || AndGtNeg1 || OrNeZero || OrLtZero) {
4827       SDValue Or = DAG.getNode(ISD::OR, SDLoc(N0), OpVT, LL, RL);
4828       AddToWorklist(Or.getNode());
4829       return DAG.getSetCC(DL, VT, Or, LR, CC1);
4830     }
4831 
4832     // All bits set?
4833     bool AndEqNeg1 = IsAnd && CC1 == ISD::SETEQ && IsNeg1;
4834     // All sign bits set?
4835     bool AndLtZero = IsAnd && CC1 == ISD::SETLT && IsZero;
4836     // Any bits clear?
4837     bool OrNeNeg1 = !IsAnd && CC1 == ISD::SETNE && IsNeg1;
4838     // Any sign bits clear?
4839     bool OrGtNeg1 = !IsAnd && CC1 == ISD::SETGT && IsNeg1;
4840 
4841     // (and (seteq X, -1), (seteq Y, -1)) --> (seteq (and X, Y), -1)
4842     // (and (setlt X,  0), (setlt Y,  0)) --> (setlt (and X, Y),  0)
4843     // (or  (setne X, -1), (setne Y, -1)) --> (setne (and X, Y), -1)
4844     // (or  (setgt X, -1), (setgt Y, -1)) --> (setgt (and X, Y), -1)
4845     if (AndEqNeg1 || AndLtZero || OrNeNeg1 || OrGtNeg1) {
4846       SDValue And = DAG.getNode(ISD::AND, SDLoc(N0), OpVT, LL, RL);
4847       AddToWorklist(And.getNode());
4848       return DAG.getSetCC(DL, VT, And, LR, CC1);
4849     }
4850   }
4851 
4852   // TODO: What is the 'or' equivalent of this fold?
4853   // (and (setne X, 0), (setne X, -1)) --> (setuge (add X, 1), 2)
4854   if (IsAnd && LL == RL && CC0 == CC1 && OpVT.getScalarSizeInBits() > 1 &&
4855       IsInteger && CC0 == ISD::SETNE &&
4856       ((isNullConstant(LR) && isAllOnesConstant(RR)) ||
4857        (isAllOnesConstant(LR) && isNullConstant(RR)))) {
4858     SDValue One = DAG.getConstant(1, DL, OpVT);
4859     SDValue Two = DAG.getConstant(2, DL, OpVT);
4860     SDValue Add = DAG.getNode(ISD::ADD, SDLoc(N0), OpVT, LL, One);
4861     AddToWorklist(Add.getNode());
4862     return DAG.getSetCC(DL, VT, Add, Two, ISD::SETUGE);
4863   }
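  // To see why the fold above is sound: for X == 0 or X == -1, (add X, 1)
  // is 1 or 0 respectively, both unsigned-less-than 2, so setuge is false
  // exactly on the two excluded values and true for every other X.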
4864 
4865   // Try more general transforms if the predicates match and the only user of
4866   // the compares is the 'and' or 'or'.
4867   if (IsInteger && TLI.convertSetCCLogicToBitwiseLogic(OpVT) && CC0 == CC1 &&
4868       N0.hasOneUse() && N1.hasOneUse()) {
4869     // and (seteq A, B), (seteq C, D) --> seteq (or (xor A, B), (xor C, D)), 0
4870     // or  (setne A, B), (setne C, D) --> setne (or (xor A, B), (xor C, D)), 0
4871     if ((IsAnd && CC1 == ISD::SETEQ) || (!IsAnd && CC1 == ISD::SETNE)) {
4872       SDValue XorL = DAG.getNode(ISD::XOR, SDLoc(N0), OpVT, LL, LR);
4873       SDValue XorR = DAG.getNode(ISD::XOR, SDLoc(N1), OpVT, RL, RR);
4874       SDValue Or = DAG.getNode(ISD::OR, DL, OpVT, XorL, XorR);
4875       SDValue Zero = DAG.getConstant(0, DL, OpVT);
4876       return DAG.getSetCC(DL, VT, Or, Zero, CC1);
4877     }
4878 
4879     // Turn compare of constants whose difference is 1 bit into add+and+setcc.
4880     // TODO - support non-uniform vector amounts.
4881     if ((IsAnd && CC1 == ISD::SETNE) || (!IsAnd && CC1 == ISD::SETEQ)) {
4882       // Match a shared variable operand and 2 non-opaque constant operands.
4883       ConstantSDNode *C0 = isConstOrConstSplat(LR);
4884       ConstantSDNode *C1 = isConstOrConstSplat(RR);
4885       if (LL == RL && C0 && C1 && !C0->isOpaque() && !C1->isOpaque()) {
4886         // Canonicalize larger constant as C0.
4887         if (C1->getAPIntValue().ugt(C0->getAPIntValue()))
4888           std::swap(C0, C1);
4889 
4890         // The difference of the constants must be a single bit.
4891         const APInt &C0Val = C0->getAPIntValue();
4892         const APInt &C1Val = C1->getAPIntValue();
4893         if ((C0Val - C1Val).isPowerOf2()) {
4894           // and/or (setcc X, C0, ne), (setcc X, C1, ne/eq) -->
4895           // setcc ((add X, -C1), ~(C0 - C1)), 0, ne/eq
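          // e.g. (and (setne X, 5), (setne X, 4)) with C0 = 5, C1 = 4 becomes
          // (setne (and (add X, -4), ~1), 0): the add maps {4,5} to {0,1} and
          // the mask clears the low bit, so the AND is zero exactly there.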
4896           SDValue OffsetC = DAG.getConstant(-C1Val, DL, OpVT);
4897           SDValue Add = DAG.getNode(ISD::ADD, DL, OpVT, LL, OffsetC);
4898           SDValue MaskC = DAG.getConstant(~(C0Val - C1Val), DL, OpVT);
4899           SDValue And = DAG.getNode(ISD::AND, DL, OpVT, Add, MaskC);
4900           SDValue Zero = DAG.getConstant(0, DL, OpVT);
4901           return DAG.getSetCC(DL, VT, And, Zero, CC0);
4902         }
4903       }
4904     }
4905   }
4906 
4907   // Canonicalize equivalent operands to LL == RL.
4908   if (LL == RR && LR == RL) {
4909     CC1 = ISD::getSetCCSwappedOperands(CC1);
4910     std::swap(RL, RR);
4911   }
4912 
4913   // (and (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
4914   // (or  (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
4915   if (LL == RL && LR == RR) {
4916     ISD::CondCode NewCC = IsAnd ? ISD::getSetCCAndOperation(CC0, CC1, OpVT)
4917                                 : ISD::getSetCCOrOperation(CC0, CC1, OpVT);
4918     if (NewCC != ISD::SETCC_INVALID &&
4919         (!LegalOperations ||
4920          (TLI.isCondCodeLegal(NewCC, LL.getSimpleValueType()) &&
4921           TLI.isOperationLegal(ISD::SETCC, OpVT))))
4922       return DAG.getSetCC(DL, VT, LL, LR, NewCC);
4923   }
4924 
4925   return SDValue();
4926 }
4927 
4928 /// This contains all DAGCombine rules which reduce two values combined by
4929 /// an And operation to a single value. This makes them reusable in the context
4930 /// of visitSELECT(). Rules involving constants are not included as
4931 /// visitSELECT() already handles those cases.
4932 SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1, SDNode *N) {
4933   EVT VT = N1.getValueType();
4934   SDLoc DL(N);
4935 
4936   // fold (and x, undef) -> 0
4937   if (N0.isUndef() || N1.isUndef())
4938     return DAG.getConstant(0, DL, VT);
4939 
4940   if (SDValue V = foldLogicOfSetCCs(true, N0, N1, DL))
4941     return V;
4942 
4943   if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::SRL &&
4944       VT.getSizeInBits() <= 64) {
4945     if (ConstantSDNode *ADDI = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
4946       if (ConstantSDNode *SRLI = dyn_cast<ConstantSDNode>(N1.getOperand(1))) {
4947         // Look for (and (add x, c1), (lshr y, c2)). If c1 isn't a legal
4948         // immediate for an add, but would be once its top c2 bits are set,
4949         // transform the ADD so the immediate doesn't need to be materialized
4950         // in a register.
4951         APInt ADDC = ADDI->getAPIntValue();
4952         APInt SRLC = SRLI->getAPIntValue();
4953         if (ADDC.getMinSignedBits() <= 64 &&
4954             SRLC.ult(VT.getSizeInBits()) &&
4955             !TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
4956           APInt Mask = APInt::getHighBitsSet(VT.getSizeInBits(),
4957                                              SRLC.getZExtValue());
4958           if (DAG.MaskedValueIsZero(N0.getOperand(1), Mask)) {
4959             ADDC |= Mask;
4960             if (TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
4961               SDLoc DL0(N0);
4962               SDValue NewAdd =
4963                 DAG.getNode(ISD::ADD, DL0, VT,
4964                             N0.getOperand(0), DAG.getConstant(ADDC, DL, VT));
4965               CombineTo(N0.getNode(), NewAdd);
4966               // Return N so it doesn't get rechecked!
4967               return SDValue(N, 0);
4968             }
4969           }
4970         }
4971       }
4972     }
4973   }
4974 
4975   // Reduce bit extract of low half of an integer to the narrower type.
4976   // (and (srl i64:x, K), KMask) ->
4977   //   (i64 zero_extend (and (srl (i32 (trunc i64:x)), K)), KMask)
4978   if (N0.getOpcode() == ISD::SRL && N0.hasOneUse()) {
4979     if (ConstantSDNode *CAnd = dyn_cast<ConstantSDNode>(N1)) {
4980       if (ConstantSDNode *CShift = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
4981         unsigned Size = VT.getSizeInBits();
4982         const APInt &AndMask = CAnd->getAPIntValue();
4983         unsigned ShiftBits = CShift->getZExtValue();
4984 
4985         // Bail out, this node will probably disappear anyway.
4986         if (ShiftBits == 0)
4987           return SDValue();
4988 
4989         unsigned MaskBits = AndMask.countTrailingOnes();
4990         EVT HalfVT = EVT::getIntegerVT(*DAG.getContext(), Size / 2);
4991 
4992         if (AndMask.isMask() &&
4993             // Required bits must not span the two halves of the integer and
4994             // must fit in the half size type.
4995             (ShiftBits + MaskBits <= Size / 2) &&
4996             TLI.isNarrowingProfitable(VT, HalfVT) &&
4997             TLI.isTypeDesirableForOp(ISD::AND, HalfVT) &&
4998             TLI.isTypeDesirableForOp(ISD::SRL, HalfVT) &&
4999             TLI.isTruncateFree(VT, HalfVT) &&
5000             TLI.isZExtFree(HalfVT, VT)) {
5001           // The isNarrowingProfitable check is to avoid regressions on PPC and
5002           // AArch64 which match a few 64-bit bit insert / bit extract patterns
5003           // on downstream users of this. Those patterns could probably be
5004           // extended to handle extensions mixed in.
5005 
5006           SDLoc SL(N0);
5007           assert(MaskBits <= Size);
5008 
5009           // Extracting the highest bit of the low half.
5010           EVT ShiftVT = TLI.getShiftAmountTy(HalfVT, DAG.getDataLayout());
5011           SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, HalfVT,
5012                                       N0.getOperand(0));
5013 
5014           SDValue NewMask = DAG.getConstant(AndMask.trunc(Size / 2), SL, HalfVT);
5015           SDValue ShiftK = DAG.getConstant(ShiftBits, SL, ShiftVT);
5016           SDValue Shift = DAG.getNode(ISD::SRL, SL, HalfVT, Trunc, ShiftK);
5017           SDValue And = DAG.getNode(ISD::AND, SL, HalfVT, Shift, NewMask);
5018           return DAG.getNode(ISD::ZERO_EXTEND, SL, VT, And);
5019         }
5020       }
5021     }
5022   }
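  // e.g. with i64 x, (and (srl x, 8), 0xFF) extracts bits 8..15, which lie
  // entirely in the low 32-bit half, so (when profitable for the target) it
  // becomes (zext (and (srl (trunc x), 8), 0xFF)) computed in i32.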
5023 
5024   return SDValue();
5025 }
5026 
5027 bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
5028                                    EVT LoadResultTy, EVT &ExtVT) {
5029   if (!AndC->getAPIntValue().isMask())
5030     return false;
5031 
5032   unsigned ActiveBits = AndC->getAPIntValue().countTrailingOnes();
5033 
5034   ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
5035   EVT LoadedVT = LoadN->getMemoryVT();
5036 
5037   if (ExtVT == LoadedVT &&
5038       (!LegalOperations ||
5039        TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))) {
5040     // ZEXTLOAD will match without needing to change the size of the value being
5041     // loaded.
5042     return true;
5043   }
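  // e.g. with AndC == 0xFF and an i32 load, ActiveBits is 8 and ExtVT is i8:
  // either the load already loads exactly 8 bits (handled just above), or
  // the checks below decide whether the wider load may be shrunk to an i8
  // zextload.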
5044 
5045   // Do not change the width of volatile or atomic loads.
5046   if (!LoadN->isSimple())
5047     return false;
5048 
5049   // Do not generate loads of non-round integer types since these can
5050   // be expensive (and would be wrong if the type is not byte sized).
5051   if (!LoadedVT.bitsGT(ExtVT) || !ExtVT.isRound())
5052     return false;
5053 
5054   if (LegalOperations &&
5055       !TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))
5056     return false;
5057 
5058   if (!TLI.shouldReduceLoadWidth(LoadN, ISD::ZEXTLOAD, ExtVT))
5059     return false;
5060 
5061   return true;
5062 }
5063 
5064 bool DAGCombiner::isLegalNarrowLdSt(LSBaseSDNode *LDST,
5065                                     ISD::LoadExtType ExtType, EVT &MemVT,
5066                                     unsigned ShAmt) {
5067   if (!LDST)
5068     return false;
5069   // Only allow byte offsets.
5070   if (ShAmt % 8)
5071     return false;
5072 
5073   // Do not generate loads of non-round integer types since these can
5074   // be expensive (and would be wrong if the type is not byte sized).
5075   if (!MemVT.isRound())
5076     return false;
5077 
5078   // Don't change the width of volatile or atomic loads.
5079   if (!LDST->isSimple())
5080     return false;
5081 
5082   EVT LdStMemVT = LDST->getMemoryVT();
5083 
5084   // Bail out when changing the scalable property, since we can't be sure that
5085   // we're actually narrowing here.
5086   if (LdStMemVT.isScalableVector() != MemVT.isScalableVector())
5087     return false;
5088 
5089   // Verify that we are actually reducing a load width here.
5090   if (LdStMemVT.bitsLT(MemVT))
5091     return false;
5092 
5093   // Ensure that this isn't going to produce an unsupported memory access.
5094   if (ShAmt) {
5095     assert(ShAmt % 8 == 0 && "ShAmt is byte offset");
5096     const unsigned ByteShAmt = ShAmt / 8;
5097     const Align LDSTAlign = LDST->getAlign();
5098     const Align NarrowAlign = commonAlignment(LDSTAlign, ByteShAmt);
5099     if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
5100                                 LDST->getAddressSpace(), NarrowAlign,
5101                                 LDST->getMemOperand()->getFlags()))
5102       return false;
5103   }
5104 
5105   // It's not possible to generate a constant of extended or untyped type.
5106   EVT PtrType = LDST->getBasePtr().getValueType();
5107   if (PtrType == MVT::Untyped || PtrType.isExtended())
5108     return false;
5109 
5110   if (isa<LoadSDNode>(LDST)) {
5111     LoadSDNode *Load = cast<LoadSDNode>(LDST);
5112     // Don't transform a load with multiple uses; this would require adding a
5113     // new load.
5114     if (!SDValue(Load, 0).hasOneUse())
5115       return false;
5116 
5117     if (LegalOperations &&
5118         !TLI.isLoadExtLegal(ExtType, Load->getValueType(0), MemVT))
5119       return false;
5120 
5121     // For the transform to be legal, the load must produce only two values
5122     // (the value loaded and the chain).  Don't transform a pre-increment
5123     // load, for example, which produces an extra value.  Otherwise the
5124     // transformation is not equivalent, and the downstream logic to replace
5125     // uses gets things wrong.
5126     if (Load->getNumValues() > 2)
5127       return false;
5128 
5129     // If the load that we're shrinking is an extload and we're not just
5130     // discarding the extension we can't simply shrink the load. Bail.
5131     // TODO: It would be possible to merge the extensions in some cases.
5132     if (Load->getExtensionType() != ISD::NON_EXTLOAD &&
5133         Load->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
5134       return false;
5135 
5136     if (!TLI.shouldReduceLoadWidth(Load, ExtType, MemVT))
5137       return false;
5138   } else {
5139     assert(isa<StoreSDNode>(LDST) && "It is not a Load nor a Store SDNode");
5140     StoreSDNode *Store = cast<StoreSDNode>(LDST);
5141     // Can't write outside the original store
5142     if (Store->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
5143       return false;
5144 
5145     if (LegalOperations &&
5146         !TLI.isTruncStoreLegal(Store->getValue().getValueType(), MemVT))
5147       return false;
5148   }
5149   return true;
5150 }
5151 
5152 bool DAGCombiner::SearchForAndLoads(SDNode *N,
5153                                     SmallVectorImpl<LoadSDNode*> &Loads,
5154                                     SmallPtrSetImpl<SDNode*> &NodesWithConsts,
5155                                     ConstantSDNode *Mask,
5156                                     SDNode *&NodeToMask) {
5157   // Recursively search for the operands, looking for loads which can be
5158   // narrowed.
5159   for (SDValue Op : N->op_values()) {
5160     if (Op.getValueType().isVector())
5161       return false;
5162 
5163     // Some constants may need fixing up later if they are too large.
5164     if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
5165       if ((N->getOpcode() == ISD::OR || N->getOpcode() == ISD::XOR) &&
5166           (Mask->getAPIntValue() & C->getAPIntValue()) != C->getAPIntValue())
5167         NodesWithConsts.insert(N);
5168       continue;
5169     }
5170 
5171     if (!Op.hasOneUse())
5172       return false;
5173 
5174     switch(Op.getOpcode()) {
5175     case ISD::LOAD: {
5176       auto *Load = cast<LoadSDNode>(Op);
5177       EVT ExtVT;
5178       if (isAndLoadExtLoad(Mask, Load, Load->getValueType(0), ExtVT) &&
5179           isLegalNarrowLdSt(Load, ISD::ZEXTLOAD, ExtVT)) {
5180 
5181         // ZEXTLOAD is already small enough.
5182         if (Load->getExtensionType() == ISD::ZEXTLOAD &&
5183             ExtVT.bitsGE(Load->getMemoryVT()))
5184           continue;
5185 
5186         // Use LE to convert equal sized loads to zext.
5187         if (ExtVT.bitsLE(Load->getMemoryVT()))
5188           Loads.push_back(Load);
5189 
5190         continue;
5191       }
5192       return false;
5193     }
5194     case ISD::ZERO_EXTEND:
5195     case ISD::AssertZext: {
5196       unsigned ActiveBits = Mask->getAPIntValue().countTrailingOnes();
5197       EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
5198       EVT VT = Op.getOpcode() == ISD::AssertZext ?
5199         cast<VTSDNode>(Op.getOperand(1))->getVT() :
5200         Op.getOperand(0).getValueType();
5201 
5202       // We can accept extending nodes if the mask is wider than or equal in
5203       // width to the original type.
5204       if (ExtVT.bitsGE(VT))
5205         continue;
5206       break;
5207     }
5208     case ISD::OR:
5209     case ISD::XOR:
5210     case ISD::AND:
5211       if (!SearchForAndLoads(Op.getNode(), Loads, NodesWithConsts, Mask,
5212                              NodeToMask))
5213         return false;
5214       continue;
5215     }
5216 
5217     // Allow one node which will be masked along with any loads found.
5218     if (NodeToMask)
5219       return false;
5220 
5221     // Also ensure that the node to be masked only produces one data result.
5222     NodeToMask = Op.getNode();
5223     if (NodeToMask->getNumValues() > 1) {
5224       bool HasValue = false;
5225       for (unsigned i = 0, e = NodeToMask->getNumValues(); i < e; ++i) {
5226         MVT VT = SDValue(NodeToMask, i).getSimpleValueType();
5227         if (VT != MVT::Glue && VT != MVT::Other) {
5228           if (HasValue) {
5229             NodeToMask = nullptr;
5230             return false;
5231           }
5232           HasValue = true;
5233         }
5234       }
5235       assert(HasValue && "Node to be masked has no data result?");
5236     }
5237   }
5238   return true;
5239 }
5240 
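/// Try to remove an AND with a low-bit mask by propagating the mask back
/// through the expression tree to the loads feeding it, which can then be
/// narrowed (see SearchForAndLoads). Returns true if the AND was eliminated.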
5241 bool DAGCombiner::BackwardsPropagateMask(SDNode *N) {
5242   auto *Mask = dyn_cast<ConstantSDNode>(N->getOperand(1));
5243   if (!Mask)
5244     return false;
5245 
5246   if (!Mask->getAPIntValue().isMask())
5247     return false;
5248 
5249   // No need to do anything if the and directly uses a load.
5250   if (isa<LoadSDNode>(N->getOperand(0)))
5251     return false;
5252 
5253   SmallVector<LoadSDNode*, 8> Loads;
5254   SmallPtrSet<SDNode*, 2> NodesWithConsts;
5255   SDNode *FixupNode = nullptr;
5256   if (SearchForAndLoads(N, Loads, NodesWithConsts, Mask, FixupNode)) {
5257     if (Loads.size() == 0)
5258       return false;
5259 
5260     LLVM_DEBUG(dbgs() << "Backwards propagate AND: "; N->dump());
5261     SDValue MaskOp = N->getOperand(1);
5262 
5263     // If it exists, fix up the single node we allow in the tree that needs
5264     // masking.
5265     if (FixupNode) {
5266       LLVM_DEBUG(dbgs() << "First, need to fix up: "; FixupNode->dump());
5267       SDValue And = DAG.getNode(ISD::AND, SDLoc(FixupNode),
5268                                 FixupNode->getValueType(0),
5269                                 SDValue(FixupNode, 0), MaskOp);
5270       DAG.ReplaceAllUsesOfValueWith(SDValue(FixupNode, 0), And);
5271       if (And.getOpcode() == ISD::AND)
5272         DAG.UpdateNodeOperands(And.getNode(), SDValue(FixupNode, 0), MaskOp);
5273     }
5274 
5275     // Narrow any constants that need it.
5276     for (auto *LogicN : NodesWithConsts) {
5277       SDValue Op0 = LogicN->getOperand(0);
5278       SDValue Op1 = LogicN->getOperand(1);
5279 
5280       if (isa<ConstantSDNode>(Op0))
5281           std::swap(Op0, Op1);
5282 
5283       SDValue And = DAG.getNode(ISD::AND, SDLoc(Op1), Op1.getValueType(),
5284                                 Op1, MaskOp);
5285 
5286       DAG.UpdateNodeOperands(LogicN, Op0, And);
5287     }
5288 
5289     // Create narrow loads.
5290     for (auto *Load : Loads) {
5291       LLVM_DEBUG(dbgs() << "Propagate AND back to: "; Load->dump());
5292       SDValue And = DAG.getNode(ISD::AND, SDLoc(Load), Load->getValueType(0),
5293                                 SDValue(Load, 0), MaskOp);
5294       DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), And);
5295       if (And.getOpcode() == ISD::AND)
5296         And = SDValue(
5297             DAG.UpdateNodeOperands(And.getNode(), SDValue(Load, 0), MaskOp), 0);
5298       SDValue NewLoad = ReduceLoadWidth(And.getNode());
5299       assert(NewLoad &&
5300              "Shouldn't be masking the load if it can't be narrowed");
5301       CombineTo(Load, NewLoad, NewLoad.getValue(1));
5302     }
5303     DAG.ReplaceAllUsesWith(N, N->getOperand(0).getNode());
5304     return true;
5305   }
5306   return false;
5307 }
5308 
5309 // Unfold
5310 //    x &  (-1 'logical shift' y)
5311 // To
5312 //    (x 'opposite logical shift' y) 'logical shift' y
5313 // if it is better for performance.
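// E.g. with a left 'outer' shift:  x & (-1 << y)  -->  (x >> y) << y,
// which clears the low y bits without materializing the mask.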
5314 SDValue DAGCombiner::unfoldExtremeBitClearingToShifts(SDNode *N) {
5315   assert(N->getOpcode() == ISD::AND);
5316 
5317   SDValue N0 = N->getOperand(0);
5318   SDValue N1 = N->getOperand(1);
5319 
5320   // Do we actually prefer shifts over a mask?
5321   if (!TLI.shouldFoldMaskToVariableShiftPair(N0))
5322     return SDValue();
5323 
5324   // Try to match  (-1 '[outer] logical shift' y)
5325   unsigned OuterShift;
5326   unsigned InnerShift; // The opposite direction to the OuterShift.
5327   SDValue Y;           // Shift amount.
5328   auto matchMask = [&OuterShift, &InnerShift, &Y](SDValue M) -> bool {
5329     if (!M.hasOneUse())
5330       return false;
5331     OuterShift = M->getOpcode();
5332     if (OuterShift == ISD::SHL)
5333       InnerShift = ISD::SRL;
5334     else if (OuterShift == ISD::SRL)
5335       InnerShift = ISD::SHL;
5336     else
5337       return false;
5338     if (!isAllOnesConstant(M->getOperand(0)))
5339       return false;
5340     Y = M->getOperand(1);
5341     return true;
5342   };
5343 
5344   SDValue X;
5345   if (matchMask(N1))
5346     X = N0;
5347   else if (matchMask(N0))
5348     X = N1;
5349   else
5350     return SDValue();
5351 
5352   SDLoc DL(N);
5353   EVT VT = N->getValueType(0);
5354 
5355   //     tmp = x   'opposite logical shift' y
5356   SDValue T0 = DAG.getNode(InnerShift, DL, VT, X, Y);
5357   //     ret = tmp 'logical shift' y
5358   SDValue T1 = DAG.getNode(OuterShift, DL, VT, T0, Y);
5359 
5360   return T1;
5361 }
5362 
5363 /// Try to replace shift/logic that tests if a bit is clear with mask + setcc.
5364 /// For a target with a bit test, this is expected to become test + set and save
5365 /// at least 1 instruction.
5366 static SDValue combineShiftAnd1ToBitTest(SDNode *And, SelectionDAG &DAG) {
5367   assert(And->getOpcode() == ISD::AND && "Expected an 'and' op");
5368 
5369   // This is probably not worthwhile without a supported type.
5370   EVT VT = And->getValueType(0);
5371   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
5372   if (!TLI.isTypeLegal(VT))
5373     return SDValue();
5374 
5375   // Look through an optional extension and find a 'not'.
5376   // TODO: Should we favor test+set even without the 'not' op?
5377   SDValue Not = And->getOperand(0), And1 = And->getOperand(1);
5378   if (Not.getOpcode() == ISD::ANY_EXTEND)
5379     Not = Not.getOperand(0);
5380   if (!isBitwiseNot(Not) || !Not.hasOneUse() || !isOneConstant(And1))
5381     return SDValue();
5382 
5383   // Look through an optional truncation. The source operand may not be the same
5384   // type as the original 'and', but that is ok because we are masking off
5385   // everything but the low bit.
5386   SDValue Srl = Not.getOperand(0);
5387   if (Srl.getOpcode() == ISD::TRUNCATE)
5388     Srl = Srl.getOperand(0);
5389 
5390   // Match a shift-right by constant.
5391   if (Srl.getOpcode() != ISD::SRL || !Srl.hasOneUse() ||
5392       !isa<ConstantSDNode>(Srl.getOperand(1)))
5393     return SDValue();
5394 
5395   // We might have looked through casts that make this transform invalid.
5396   // TODO: If the source type is wider than the result type, do the mask and
5397   //       compare in the source type.
5398   const APInt &ShiftAmt = Srl.getConstantOperandAPInt(1);
5399   unsigned VTBitWidth = VT.getSizeInBits();
5400   if (ShiftAmt.uge(VTBitWidth))
5401     return SDValue();
5402 
5403   // Turn this into a bit-test pattern using mask op + setcc:
5404   // and (not (srl X, C)), 1 --> (and X, 1<<C) == 0
5405   SDLoc DL(And);
5406   SDValue X = DAG.getZExtOrTrunc(Srl.getOperand(0), DL, VT);
5407   EVT CCVT = TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
5408   SDValue Mask = DAG.getConstant(
5409       APInt::getOneBitSet(VTBitWidth, ShiftAmt.getZExtValue()), DL, VT);
5410   SDValue NewAnd = DAG.getNode(ISD::AND, DL, VT, X, Mask);
5411   SDValue Zero = DAG.getConstant(0, DL, VT);
5412   SDValue Setcc = DAG.getSetCC(DL, CCVT, NewAnd, Zero, ISD::SETEQ);
5413   return DAG.getZExtOrTrunc(Setcc, DL, VT);
5414 }
5415 
5416 SDValue DAGCombiner::visitAND(SDNode *N) {
5417   SDValue N0 = N->getOperand(0);
5418   SDValue N1 = N->getOperand(1);
5419   EVT VT = N1.getValueType();
5420 
5421   // x & x --> x
5422   if (N0 == N1)
5423     return N0;
5424 
5425   // fold vector ops
5426   if (VT.isVector()) {
5427     if (SDValue FoldedVOp = SimplifyVBinOp(N))
5428       return FoldedVOp;
5429 
5430     // fold (and x, 0) -> 0, vector edition
5431     if (ISD::isBuildVectorAllZeros(N0.getNode()))
5432       // do not return N0, because an undef node may exist in N0
5433       return DAG.getConstant(APInt::getNullValue(N0.getScalarValueSizeInBits()),
5434                              SDLoc(N), N0.getValueType());
5435     if (ISD::isBuildVectorAllZeros(N1.getNode()))
5436       // do not return N1, because an undef node may exist in N1
5437       return DAG.getConstant(APInt::getNullValue(N1.getScalarValueSizeInBits()),
5438                              SDLoc(N), N1.getValueType());
5439 
5440     // fold (and x, -1) -> x, vector edition
5441     if (ISD::isBuildVectorAllOnes(N0.getNode()))
5442       return N1;
5443     if (ISD::isBuildVectorAllOnes(N1.getNode()))
5444       return N0;
5445 
5446     // fold (and (masked_load) (build_vec (x, ...))) to zext_masked_load
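    // e.g. (and (masked_load from <4 x i8>, extended to <4 x i32>), splat(0xFF))
    //      --> zero-extending masked load from <4 x i8> to <4 x i32>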
5447     auto *MLoad = dyn_cast<MaskedLoadSDNode>(N0);
5448     auto *BVec = dyn_cast<BuildVectorSDNode>(N1);
5449     if (MLoad && BVec && MLoad->getExtensionType() == ISD::EXTLOAD &&
5450         N0.hasOneUse() && N1.hasOneUse()) {
5451       EVT LoadVT = MLoad->getMemoryVT();
5452       EVT ExtVT = VT;
5453       if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, ExtVT, LoadVT)) {
5454         // For this AND to be a zero extension of the masked load, the
5455         // elements of the BuildVec must mask the bottom bits of the extended
5456         // element type.
5457         if (ConstantSDNode *Splat = BVec->getConstantSplatNode()) {
5458           uint64_t ElementSize =
5459               LoadVT.getVectorElementType().getScalarSizeInBits();
5460           if (Splat->getAPIntValue().isMask(ElementSize)) {
5461             return DAG.getMaskedLoad(
5462                 ExtVT, SDLoc(N), MLoad->getChain(), MLoad->getBasePtr(),
5463                 MLoad->getOffset(), MLoad->getMask(), MLoad->getPassThru(),
5464                 LoadVT, MLoad->getMemOperand(), MLoad->getAddressingMode(),
5465                 ISD::ZEXTLOAD, MLoad->isExpandingLoad());
5466           }
5467         }
5468       }
5469     }
5470   }
5471 
5472   // fold (and c1, c2) -> c1&c2
5473   ConstantSDNode *N1C = isConstOrConstSplat(N1);
5474   if (SDValue C = DAG.FoldConstantArithmetic(ISD::AND, SDLoc(N), VT, {N0, N1}))
5475     return C;
5476 
5477   // canonicalize constant to RHS
5478   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
5479       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
5480     return DAG.getNode(ISD::AND, SDLoc(N), VT, N1, N0);
5481 
5482   // fold (and x, -1) -> x
5483   if (isAllOnesConstant(N1))
5484     return N0;
5485 
5486   // if (and x, c) is known to be zero, return 0
5487   unsigned BitWidth = VT.getScalarSizeInBits();
5488   if (N1C && DAG.MaskedValueIsZero(SDValue(N, 0),
5489                                    APInt::getAllOnesValue(BitWidth)))
5490     return DAG.getConstant(0, SDLoc(N), VT);
5491 
5492   if (SDValue NewSel = foldBinOpIntoSelect(N))
5493     return NewSel;
5494 
5495   // reassociate and
5496   if (SDValue RAND = reassociateOps(ISD::AND, SDLoc(N), N0, N1, N->getFlags()))
5497     return RAND;
5498 
5499   // Try to convert a constant mask AND into a shuffle clear mask.
5500   if (VT.isVector())
5501     if (SDValue Shuffle = XformToShuffleWithZero(N))
5502       return Shuffle;
5503 
5504   if (SDValue Combined = combineCarryDiamond(*this, DAG, TLI, N0, N1, N))
5505     return Combined;
5506 
5507   // fold (and (or x, C), D) -> D if (C & D) == D
5508   auto MatchSubset = [](ConstantSDNode *LHS, ConstantSDNode *RHS) {
5509     return RHS->getAPIntValue().isSubsetOf(LHS->getAPIntValue());
5510   };
5511   if (N0.getOpcode() == ISD::OR &&
5512       ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchSubset))
5513     return N1;
5514   // fold (and (any_ext V), c) -> (zero_ext V) if 'and' only clears top bits.
5515   if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
5516     SDValue N0Op0 = N0.getOperand(0);
5517     APInt Mask = ~N1C->getAPIntValue();
5518     Mask = Mask.trunc(N0Op0.getScalarValueSizeInBits());
5519     if (DAG.MaskedValueIsZero(N0Op0, Mask)) {
5520       SDValue Zext = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N),
5521                                  N0.getValueType(), N0Op0);
5522 
5523       // Replace uses of the AND with uses of the Zero extend node.
5524       CombineTo(N, Zext);
5525 
5526       // We actually want to replace all uses of the any_extend with the
5527       // zero_extend, to avoid duplicating things.  This will later cause this
5528       // AND to be folded.
5529       CombineTo(N0.getNode(), Zext);
5530       return SDValue(N, 0);   // Return N so it doesn't get rechecked!
5531     }
5532   }
5533 
5534   // similarly fold (and (X (load ([non_ext|any_ext|zero_ext] V))), c) ->
5535   // (X (load ([non_ext|zero_ext] V))) if 'and' only clears top bits which must
5536   // already be zero by virtue of the width of the base type of the load.
5537   //
5538   // the 'X' node here can either be nothing or an extract_vector_elt to catch
5539   // more cases.
5540   if ((N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
5541        N0.getValueSizeInBits() == N0.getOperand(0).getScalarValueSizeInBits() &&
5542        N0.getOperand(0).getOpcode() == ISD::LOAD &&
5543        N0.getOperand(0).getResNo() == 0) ||
5544       (N0.getOpcode() == ISD::LOAD && N0.getResNo() == 0)) {
5545     LoadSDNode *Load = cast<LoadSDNode>( (N0.getOpcode() == ISD::LOAD) ?
5546                                          N0 : N0.getOperand(0) );
5547 
5548     // Get the constant (if applicable) that the zeroth operand is being ANDed with.
5549     // This can be a pure constant or a vector splat, in which case we treat the
5550     // vector as a scalar and use the splat value.
5551     APInt Constant = APInt::getNullValue(1);
5552     if (const ConstantSDNode *C = dyn_cast<ConstantSDNode>(N1)) {
5553       Constant = C->getAPIntValue();
5554     } else if (BuildVectorSDNode *Vector = dyn_cast<BuildVectorSDNode>(N1)) {
5555       APInt SplatValue, SplatUndef;
5556       unsigned SplatBitSize;
5557       bool HasAnyUndefs;
5558       bool IsSplat = Vector->isConstantSplat(SplatValue, SplatUndef,
5559                                              SplatBitSize, HasAnyUndefs);
5560       if (IsSplat) {
5561         // Undef bits can contribute to a possible optimisation if set, so
5562         // set them.
5563         SplatValue |= SplatUndef;
5564 
5565         // The splat value may be something like "0x00FFFFFF", which means 0 for
5566         // the first vector value and FF for the rest, repeating. We need a mask
5567         // that will apply equally to all members of the vector, so AND all the
5568         // lanes of the constant together.
5569         unsigned EltBitWidth = Vector->getValueType(0).getScalarSizeInBits();
5570 
5571         // If the splat value has been compressed to a bitlength lower
5572         // than the size of the vector lane, we need to re-expand it to
5573         // the lane size.
5574         if (EltBitWidth > SplatBitSize)
5575           for (SplatValue = SplatValue.zextOrTrunc(EltBitWidth);
5576                SplatBitSize < EltBitWidth; SplatBitSize = SplatBitSize * 2)
5577             SplatValue |= SplatValue.shl(SplatBitSize);
5578 
5579         // Make sure that variable 'Constant' is only set if 'SplatBitSize' is a
5580         // multiple of 'EltBitWidth'. Otherwise, we could propagate a wrong value.
5581         if ((SplatBitSize % EltBitWidth) == 0) {
5582           Constant = APInt::getAllOnesValue(EltBitWidth);
5583           for (unsigned i = 0, n = (SplatBitSize / EltBitWidth); i < n; ++i)
5584             Constant &= SplatValue.extractBits(EltBitWidth, i * EltBitWidth);
5585         }
5586       }
5587     }
5588 
5589     // If we want to change an EXTLOAD to a ZEXTLOAD, ensure a ZEXTLOAD is
5590     // actually legal and isn't going to get expanded, else this is a false
5591     // optimisation.
5592     bool CanZextLoadProfitably = TLI.isLoadExtLegal(ISD::ZEXTLOAD,
5593                                                     Load->getValueType(0),
5594                                                     Load->getMemoryVT());
5595 
5596     // Resize the constant to the same size as the original memory access before
5597     // extension. If it is still the AllOnesValue then this AND is completely
5598     // unneeded.
5599     Constant = Constant.zextOrTrunc(Load->getMemoryVT().getScalarSizeInBits());
5600 
5601     bool B;
5602     switch (Load->getExtensionType()) {
5603     default: B = false; break;
5604     case ISD::EXTLOAD: B = CanZextLoadProfitably; break;
5605     case ISD::ZEXTLOAD:
5606     case ISD::NON_EXTLOAD: B = true; break;
5607     }
5608 
5609     if (B && Constant.isAllOnesValue()) {
5610       // If the load type was an EXTLOAD, convert to ZEXTLOAD in order to
5611       // preserve semantics once we get rid of the AND.
5612       SDValue NewLoad(Load, 0);
5613 
5614       // Fold the AND away. NewLoad may get replaced immediately.
5615       CombineTo(N, (N0.getNode() == Load) ? NewLoad : N0);
5616 
5617       if (Load->getExtensionType() == ISD::EXTLOAD) {
5618         NewLoad = DAG.getLoad(Load->getAddressingMode(), ISD::ZEXTLOAD,
5619                               Load->getValueType(0), SDLoc(Load),
5620                               Load->getChain(), Load->getBasePtr(),
5621                               Load->getOffset(), Load->getMemoryVT(),
5622                               Load->getMemOperand());
5623         // Replace uses of the EXTLOAD with the new ZEXTLOAD.
5624         if (Load->getNumValues() == 3) {
5625           // PRE/POST_INC loads have 3 values.
5626           SDValue To[] = { NewLoad.getValue(0), NewLoad.getValue(1),
5627                            NewLoad.getValue(2) };
5628           CombineTo(Load, To, 3, true);
5629         } else {
5630           CombineTo(Load, NewLoad.getValue(0), NewLoad.getValue(1));
5631         }
5632       }
5633 
5634       return SDValue(N, 0); // Return N so it doesn't get rechecked!
5635     }
5636   }
5637 
5638   // fold (and (masked_gather x)) -> (zext_masked_gather x)
5639   if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(N0)) {
5640     EVT MemVT = GN0->getMemoryVT();
5641     EVT ScalarVT = MemVT.getScalarType();
5642 
5643     if (SDValue(GN0, 0).hasOneUse() &&
5644         isConstantSplatVectorMaskForType(N1.getNode(), ScalarVT) &&
5645         TLI.isVectorLoadExtDesirable(SDValue(GN0, 0))) {
5646       SDValue Ops[] = {GN0->getChain(),   GN0->getPassThru(), GN0->getMask(),
5647                        GN0->getBasePtr(), GN0->getIndex(),    GN0->getScale()};
5648 
5649       SDValue ZExtLoad = DAG.getMaskedGather(
5650           DAG.getVTList(VT, MVT::Other), MemVT, SDLoc(N), Ops,
5651           GN0->getMemOperand(), GN0->getIndexType(), ISD::ZEXTLOAD);
5652 
5653       CombineTo(N, ZExtLoad);
5654       AddToWorklist(ZExtLoad.getNode());
5655       // Avoid recheck of N.
5656       return SDValue(N, 0);
5657     }
5658   }
5659 
5660   // fold (and (load x), 255) -> (zextload x, i8)
5661   // fold (and (extload x, i16), 255) -> (zextload x, i8)
5662   // fold (and (any_ext (extload x, i16)), 255) -> (zextload x, i8)
5663   if (!VT.isVector() && N1C && (N0.getOpcode() == ISD::LOAD ||
5664                                 (N0.getOpcode() == ISD::ANY_EXTEND &&
5665                                  N0.getOperand(0).getOpcode() == ISD::LOAD))) {
5666     if (SDValue Res = ReduceLoadWidth(N)) {
5667       LoadSDNode *LN0 = N0->getOpcode() == ISD::ANY_EXTEND
5668         ? cast<LoadSDNode>(N0.getOperand(0)) : cast<LoadSDNode>(N0);
5669       AddToWorklist(N);
5670       DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 0), Res);
5671       return SDValue(N, 0);
5672     }
5673   }
5674 
5675   if (LegalTypes) {
5676     // Attempt to propagate the AND back up to the leaves which, if they're
5677     // loads, can be combined to narrow loads and the AND node can be removed.
5678     // Perform after legalization so that extend nodes will already be
5679     // combined into the loads.
5680     if (BackwardsPropagateMask(N))
5681       return SDValue(N, 0);
5682   }
5683 
5684   if (SDValue Combined = visitANDLike(N0, N1, N))
5685     return Combined;
5686 
5687   // Simplify: (and (op x...), (op y...))  -> (op (and x, y))
5688   if (N0.getOpcode() == N1.getOpcode())
5689     if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
5690       return V;
5691 
5692   // Masking the negated extension of a boolean is just the zero-extended
5693   // boolean:
5694   // and (sub 0, zext(bool X)), 1 --> zext(bool X)
5695   // and (sub 0, sext(bool X)), 1 --> zext(bool X)
5696   //
5697   // Note: the SimplifyDemandedBits fold below can make an information-losing
5698   // transform, and then we have no way to find this better fold.
5699   if (N1C && N1C->isOne() && N0.getOpcode() == ISD::SUB) {
5700     if (isNullOrNullSplat(N0.getOperand(0))) {
5701       SDValue SubRHS = N0.getOperand(1);
5702       if (SubRHS.getOpcode() == ISD::ZERO_EXTEND &&
5703           SubRHS.getOperand(0).getScalarValueSizeInBits() == 1)
5704         return SubRHS;
5705       if (SubRHS.getOpcode() == ISD::SIGN_EXTEND &&
5706           SubRHS.getOperand(0).getScalarValueSizeInBits() == 1)
5707         return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, SubRHS.getOperand(0));
5708     }
5709   }
5710 
5711   // fold (and (sign_extend_inreg x, i16 to i32), 1) -> (and x, 1)
5712   // fold (and (sra)) -> (and (srl)) when possible.
5713   if (SimplifyDemandedBits(SDValue(N, 0)))
5714     return SDValue(N, 0);
5715 
5716   // fold (zext_inreg (extload x)) -> (zextload x)
5717   // fold (zext_inreg (sextload x)) -> (zextload x) iff load has one use
5718   if (ISD::isUNINDEXEDLoad(N0.getNode()) &&
5719       (ISD::isEXTLoad(N0.getNode()) ||
5720        (ISD::isSEXTLoad(N0.getNode()) && N0.hasOneUse()))) {
5721     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
5722     EVT MemVT = LN0->getMemoryVT();
5723     // If we zero all the possible extended bits, then we can turn this into
5724     // a zextload if we are running before legalize or the operation is legal.
5725     unsigned ExtBitSize = N1.getScalarValueSizeInBits();
5726     unsigned MemBitSize = MemVT.getScalarSizeInBits();
5727     APInt ExtBits = APInt::getHighBitsSet(ExtBitSize, ExtBitSize - MemBitSize);
5728     if (DAG.MaskedValueIsZero(N1, ExtBits) &&
5729         ((!LegalOperations && LN0->isSimple()) ||
5730          TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT))) {
5731       SDValue ExtLoad =
5732           DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(N0), VT, LN0->getChain(),
5733                          LN0->getBasePtr(), MemVT, LN0->getMemOperand());
5734       AddToWorklist(N);
5735       CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
5736       return SDValue(N, 0); // Return N so it doesn't get rechecked!
5737     }
5738   }
5739 
5740   // fold (and (or (srl N, 8), (shl N, 8)), 0xffff) -> (srl (bswap N), const)
5741   if (N1C && N1C->getAPIntValue() == 0xffff && N0.getOpcode() == ISD::OR) {
5742     if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
5743                                            N0.getOperand(1), false))
5744       return BSwap;
5745   }
5746 
5747   if (SDValue Shifts = unfoldExtremeBitClearingToShifts(N))
5748     return Shifts;
5749 
5750   if (TLI.hasBitTest(N0, N1))
5751     if (SDValue V = combineShiftAnd1ToBitTest(N, DAG))
5752       return V;
5753 
5754   // Recognize the following pattern:
5755   //
5756   // AndVT = (and (sign_extend NarrowVT to AndVT) #bitmask)
5757   //
5758   // where bitmask is a mask that clears the upper bits of AndVT. The
5759   // number of bits in bitmask must be a power of two.
5760   auto IsAndZeroExtMask = [](SDValue LHS, SDValue RHS) {
5761     if (LHS->getOpcode() != ISD::SIGN_EXTEND)
5762       return false;
5763 
5764     auto *C = dyn_cast<ConstantSDNode>(RHS);
5765     if (!C)
5766       return false;
5767 
5768     if (!C->getAPIntValue().isMask(
5769             LHS.getOperand(0).getValueType().getFixedSizeInBits()))
5770       return false;
5771 
5772     return true;
5773   };
5774 
5775   // Replace (and (sign_extend ...) #bitmask) with (zero_extend ...).
5776   if (IsAndZeroExtMask(N0, N1))
5777     return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, N0.getOperand(0));
5778 
5779   return SDValue();
5780 }
5781 
5782 /// Match (a >> 8) | (a << 8) as (bswap a) >> 16.
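/// For types wider than i16, the byte-swapped halfword is brought down to the
/// low bits by shifting the full bswap result right by bitwidth-16.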
5783 SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
5784                                         bool DemandHighBits) {
5785   if (!LegalOperations)
5786     return SDValue();
5787 
5788   EVT VT = N->getValueType(0);
5789   if (VT != MVT::i64 && VT != MVT::i32 && VT != MVT::i16)
5790     return SDValue();
5791   if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
5792     return SDValue();
5793 
5794   // Recognize (and (shl a, 8), 0xff00), (and (srl a, 8), 0xff)
5795   bool LookPassAnd0 = false;
5796   bool LookPassAnd1 = false;
5797   if (N0.getOpcode() == ISD::AND && N0.getOperand(0).getOpcode() == ISD::SRL)
5798       std::swap(N0, N1);
5799   if (N1.getOpcode() == ISD::AND && N1.getOperand(0).getOpcode() == ISD::SHL)
5800       std::swap(N0, N1);
5801   if (N0.getOpcode() == ISD::AND) {
5802     if (!N0.getNode()->hasOneUse())
5803       return SDValue();
5804     ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
5805     // Also handle 0xffff since the LHS is guaranteed to have zeros there.
5806     // This is needed for X86.
5807     if (!N01C || (N01C->getZExtValue() != 0xFF00 &&
5808                   N01C->getZExtValue() != 0xFFFF))
5809       return SDValue();
5810     N0 = N0.getOperand(0);
5811     LookPassAnd0 = true;
5812   }
5813 
5814   if (N1.getOpcode() == ISD::AND) {
5815     if (!N1.getNode()->hasOneUse())
5816       return SDValue();
5817     ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(N1.getOperand(1));
5818     if (!N11C || N11C->getZExtValue() != 0xFF)
5819       return SDValue();
5820     N1 = N1.getOperand(0);
5821     LookPassAnd1 = true;
5822   }
5823 
5824   if (N0.getOpcode() == ISD::SRL && N1.getOpcode() == ISD::SHL)
5825     std::swap(N0, N1);
5826   if (N0.getOpcode() != ISD::SHL || N1.getOpcode() != ISD::SRL)
5827     return SDValue();
5828   if (!N0.getNode()->hasOneUse() || !N1.getNode()->hasOneUse())
5829     return SDValue();
5830 
5831   ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
5832   ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(N1.getOperand(1));
5833   if (!N01C || !N11C)
5834     return SDValue();
5835   if (N01C->getZExtValue() != 8 || N11C->getZExtValue() != 8)
5836     return SDValue();
5837 
5838   // Look for (shl (and a, 0xff), 8), (srl (and a, 0xff00), 8)
5839   SDValue N00 = N0->getOperand(0);
5840   if (!LookPassAnd0 && N00.getOpcode() == ISD::AND) {
5841     if (!N00.getNode()->hasOneUse())
5842       return SDValue();
5843     ConstantSDNode *N001C = dyn_cast<ConstantSDNode>(N00.getOperand(1));
5844     if (!N001C || N001C->getZExtValue() != 0xFF)
5845       return SDValue();
5846     N00 = N00.getOperand(0);
5847     LookPassAnd0 = true;
5848   }
5849 
5850   SDValue N10 = N1->getOperand(0);
5851   if (!LookPassAnd1 && N10.getOpcode() == ISD::AND) {
5852     if (!N10.getNode()->hasOneUse())
5853       return SDValue();
5854     ConstantSDNode *N101C = dyn_cast<ConstantSDNode>(N10.getOperand(1));
5855     // Also allow 0xFFFF since the bits will be shifted out. This is needed
5856     // for X86.
5857     if (!N101C || (N101C->getZExtValue() != 0xFF00 &&
5858                    N101C->getZExtValue() != 0xFFFF))
5859       return SDValue();
5860     N10 = N10.getOperand(0);
5861     LookPassAnd1 = true;
5862   }
5863 
5864   if (N00 != N10)
5865     return SDValue();
5866 
5867   // Make sure everything beyond the low halfword gets set to zero since the SRL
5868   // 16 will clear the top bits.
5869   unsigned OpSizeInBits = VT.getSizeInBits();
5870   if (DemandHighBits && OpSizeInBits > 16) {
5871     // If the left-shift isn't masked out then the only way this is a bswap is
5872     // if all bits beyond the low 8 are 0. In that case the entire pattern
5873     // reduces to a left shift anyway: leave it for other parts of the combiner.
5874     if (!LookPassAnd0)
5875       return SDValue();
5876 
5877     // However, if the right shift isn't masked out then it might be because
5878     // it's not needed. See if we can spot that too.
5879     if (!LookPassAnd1 &&
5880         !DAG.MaskedValueIsZero(
5881             N10, APInt::getHighBitsSet(OpSizeInBits, OpSizeInBits - 16)))
5882       return SDValue();
5883   }
5884 
5885   SDValue Res = DAG.getNode(ISD::BSWAP, SDLoc(N), VT, N00);
5886   if (OpSizeInBits > 16) {
5887     SDLoc DL(N);
5888     Res = DAG.getNode(ISD::SRL, DL, VT, Res,
5889                       DAG.getConstant(OpSizeInBits - 16, DL,
5890                                       getShiftAmountTy(VT)));
5891   }
5892   return Res;
5893 }
5894 
5895 /// Return true if the specified node is an element that makes up a 32-bit
5896 /// packed halfword byteswap.
5897 /// ((x & 0x000000ff) << 8) |
5898 /// ((x & 0x0000ff00) >> 8) |
5899 /// ((x & 0x00ff0000) << 8) |
5900 /// ((x & 0xff000000) >> 8)
5901 static bool isBSwapHWordElement(SDValue N, MutableArrayRef<SDNode *> Parts) {
5902   if (!N.getNode()->hasOneUse())
5903     return false;
5904 
5905   unsigned Opc = N.getOpcode();
5906   if (Opc != ISD::AND && Opc != ISD::SHL && Opc != ISD::SRL)
5907     return false;
5908 
5909   SDValue N0 = N.getOperand(0);
5910   unsigned Opc0 = N0.getOpcode();
5911   if (Opc0 != ISD::AND && Opc0 != ISD::SHL && Opc0 != ISD::SRL)
5912     return false;
5913 
5914   ConstantSDNode *N1C = nullptr;
5915   // SHL or SRL: look upstream for AND mask operand
5916   if (Opc == ISD::AND)
5917     N1C = dyn_cast<ConstantSDNode>(N.getOperand(1));
5918   else if (Opc0 == ISD::AND)
5919     N1C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
5920   if (!N1C)
5921     return false;
5922 
5923   unsigned MaskByteOffset;
5924   switch (N1C->getZExtValue()) {
5925   default:
5926     return false;
5927   case 0xFF:       MaskByteOffset = 0; break;
5928   case 0xFF00:     MaskByteOffset = 1; break;
5929   case 0xFFFF:
5930     // In case demanded bits didn't clear the bits that will be shifted out.
5931     // This is needed for X86.
5932     if (Opc == ISD::SRL || (Opc == ISD::AND && Opc0 == ISD::SHL)) {
5933       MaskByteOffset = 1;
5934       break;
5935     }
5936     return false;
5937   case 0xFF0000:   MaskByteOffset = 2; break;
5938   case 0xFF000000: MaskByteOffset = 3; break;
5939   }
5940 
5941   // Look for (x & 0xff) << 8 as well as ((x << 8) & 0xff00).
5942   if (Opc == ISD::AND) {
5943     if (MaskByteOffset == 0 || MaskByteOffset == 2) {
5944       // (x >> 8) & 0xff
5945       // (x >> 8) & 0xff0000
5946       if (Opc0 != ISD::SRL)
5947         return false;
5948       ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
5949       if (!C || C->getZExtValue() != 8)
5950         return false;
5951     } else {
5952       // (x << 8) & 0xff00
5953       // (x << 8) & 0xff000000
5954       if (Opc0 != ISD::SHL)
5955         return false;
5956       ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
5957       if (!C || C->getZExtValue() != 8)
5958         return false;
5959     }
5960   } else if (Opc == ISD::SHL) {
5961     // (x & 0xff) << 8
5962     // (x & 0xff0000) << 8
5963     if (MaskByteOffset != 0 && MaskByteOffset != 2)
5964       return false;
5965     ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1));
5966     if (!C || C->getZExtValue() != 8)
5967       return false;
5968   } else { // Opc == ISD::SRL
5969     // (x & 0xff00) >> 8
5970     // (x & 0xff000000) >> 8
5971     if (MaskByteOffset != 1 && MaskByteOffset != 3)
5972       return false;
5973     ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1));
5974     if (!C || C->getZExtValue() != 8)
5975       return false;
5976   }
5977 
5978   if (Parts[MaskByteOffset])
5979     return false;
5980 
5981   Parts[MaskByteOffset] = N0.getOperand(0).getNode();
5982   return true;
5983 }
5984 
5985 // Match 2 elements of a packed halfword bswap.
5986 static bool isBSwapHWordPair(SDValue N, MutableArrayRef<SDNode *> Parts) {
5987   if (N.getOpcode() == ISD::OR)
5988     return isBSwapHWordElement(N.getOperand(0), Parts) &&
5989            isBSwapHWordElement(N.getOperand(1), Parts);
5990 
5991   if (N.getOpcode() == ISD::SRL && N.getOperand(0).getOpcode() == ISD::BSWAP) {
5992     ConstantSDNode *C = isConstOrConstSplat(N.getOperand(1));
5993     if (!C || C->getAPIntValue() != 16)
5994       return false;
5995     Parts[0] = Parts[1] = N.getOperand(0).getOperand(0).getNode();
5996     return true;
5997   }
5998 
5999   return false;
6000 }
6001 
6002 // Match this pattern:
6003 //   (or (and (shl (A, 8)), 0xff00ff00), (and (srl (A, 8)), 0x00ff00ff))
6004 // And rewrite this to:
6005 //   (rotr (bswap A), 16)
6006 static SDValue matchBSwapHWordOrAndAnd(const TargetLowering &TLI,
6007                                        SelectionDAG &DAG, SDNode *N, SDValue N0,
6008                                        SDValue N1, EVT VT, EVT ShiftAmountTy) {
6009   assert(N->getOpcode() == ISD::OR && VT == MVT::i32 &&
6010          "MatchBSwapHWordOrAndAnd: expecting i32");
6011   if (!TLI.isOperationLegalOrCustom(ISD::ROTR, VT))
6012     return SDValue();
6013   if (N0.getOpcode() != ISD::AND || N1.getOpcode() != ISD::AND)
6014     return SDValue();
6015   // TODO: this is too restrictive; lifting this restriction requires more tests
6016   if (!N0->hasOneUse() || !N1->hasOneUse())
6017     return SDValue();
6018   ConstantSDNode *Mask0 = isConstOrConstSplat(N0.getOperand(1));
6019   ConstantSDNode *Mask1 = isConstOrConstSplat(N1.getOperand(1));
6020   if (!Mask0 || !Mask1)
6021     return SDValue();
6022   if (Mask0->getAPIntValue() != 0xff00ff00 ||
6023       Mask1->getAPIntValue() != 0x00ff00ff)
6024     return SDValue();
6025   SDValue Shift0 = N0.getOperand(0);
6026   SDValue Shift1 = N1.getOperand(0);
6027   if (Shift0.getOpcode() != ISD::SHL || Shift1.getOpcode() != ISD::SRL)
6028     return SDValue();
6029   ConstantSDNode *ShiftAmt0 = isConstOrConstSplat(Shift0.getOperand(1));
6030   ConstantSDNode *ShiftAmt1 = isConstOrConstSplat(Shift1.getOperand(1));
6031   if (!ShiftAmt0 || !ShiftAmt1)
6032     return SDValue();
6033   if (ShiftAmt0->getAPIntValue() != 8 || ShiftAmt1->getAPIntValue() != 8)
6034     return SDValue();
6035   if (Shift0.getOperand(0) != Shift1.getOperand(0))
6036     return SDValue();
6037 
6038   SDLoc DL(N);
6039   SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT, Shift0.getOperand(0));
6040   SDValue ShAmt = DAG.getConstant(16, DL, ShiftAmountTy);
6041   return DAG.getNode(ISD::ROTR, DL, VT, BSwap, ShAmt);
6042 }
6043 
6044 /// Match a 32-bit packed halfword bswap. That is
6045 /// ((x & 0x000000ff) << 8) |
6046 /// ((x & 0x0000ff00) >> 8) |
6047 /// ((x & 0x00ff0000) << 8) |
6048 /// ((x & 0xff000000) >> 8)
6049 /// => (rotl (bswap x), 16)
6050 SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
6051   if (!LegalOperations)
6052     return SDValue();
6053 
6054   EVT VT = N->getValueType(0);
6055   if (VT != MVT::i32)
6056     return SDValue();
6057   if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
6058     return SDValue();
6059 
6060   if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N0, N1, VT,
6061                                               getShiftAmountTy(VT)))
6062     return BSwap;
6063 
6064   // Try again with commuted operands.
6065   if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N1, N0, VT,
6066                                               getShiftAmountTy(VT)))
6067     return BSwap;
6068 
6069 
6070   // Look for either
6071   // (or (bswaphpair), (bswaphpair))
6072   // (or (or (bswaphpair), (and)), (and))
6073   // (or (or (and), (bswaphpair)), (and))
6074   SDNode *Parts[4] = {};
6075 
6076   if (isBSwapHWordPair(N0, Parts)) {
6077     // (or (or (and), (and)), (or (and), (and)))
6078     if (!isBSwapHWordPair(N1, Parts))
6079       return SDValue();
6080   } else if (N0.getOpcode() == ISD::OR) {
6081     // (or (or (or (and), (and)), (and)), (and))
6082     if (!isBSwapHWordElement(N1, Parts))
6083       return SDValue();
6084     SDValue N00 = N0.getOperand(0);
6085     SDValue N01 = N0.getOperand(1);
6086     if (!(isBSwapHWordElement(N01, Parts) && isBSwapHWordPair(N00, Parts)) &&
6087         !(isBSwapHWordElement(N00, Parts) && isBSwapHWordPair(N01, Parts)))
6088       return SDValue();
6089   } else
6090     return SDValue();
6091 
6092   // Make sure the parts are all coming from the same node.
6093   if (Parts[0] != Parts[1] || Parts[0] != Parts[2] || Parts[0] != Parts[3])
6094     return SDValue();
6095 
6096   SDLoc DL(N);
6097   SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT,
6098                               SDValue(Parts[0], 0));
6099 
6100   // Result of the bswap should be rotated by 16. If it's not legal, then
6101   // do (x << 16) | (x >> 16).
6102   SDValue ShAmt = DAG.getConstant(16, DL, getShiftAmountTy(VT));
6103   if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT))
6104     return DAG.getNode(ISD::ROTL, DL, VT, BSwap, ShAmt);
6105   if (TLI.isOperationLegalOrCustom(ISD::ROTR, VT))
6106     return DAG.getNode(ISD::ROTR, DL, VT, BSwap, ShAmt);
6107   return DAG.getNode(ISD::OR, DL, VT,
6108                      DAG.getNode(ISD::SHL, DL, VT, BSwap, ShAmt),
6109                      DAG.getNode(ISD::SRL, DL, VT, BSwap, ShAmt));
6110 }
6111 
6112 /// This contains all DAGCombine rules which reduce two values combined by
6113 /// an Or operation to a single value \see visitANDLike().
6114 SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, SDNode *N) {
6115   EVT VT = N1.getValueType();
6116   SDLoc DL(N);
6117 
6118   // fold (or x, undef) -> -1
6119   if (!LegalOperations && (N0.isUndef() || N1.isUndef()))
6120     return DAG.getAllOnesConstant(DL, VT);
6121 
6122   if (SDValue V = foldLogicOfSetCCs(false, N0, N1, DL))
6123     return V;
6124 
6125   // (or (and X, C1), (and Y, C2))  -> (and (or X, Y), C3) if possible.
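  // e.g. (or (and X, 0xF0), (and Y, 0x0F)) --> (and (or X, Y), 0xFF) when bits
  // 0-3 of X and bits 4-7 of Y are known to be zero.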
6126   if (N0.getOpcode() == ISD::AND && N1.getOpcode() == ISD::AND &&
6127       // Don't increase # computations.
6128       (N0.getNode()->hasOneUse() || N1.getNode()->hasOneUse())) {
6129     // We can only do this xform if we know that bits from X that are set in C2
6130     // but not in C1 are already zero.  Likewise for Y.
6131     if (const ConstantSDNode *N0O1C =
6132         getAsNonOpaqueConstant(N0.getOperand(1))) {
6133       if (const ConstantSDNode *N1O1C =
6134           getAsNonOpaqueConstant(N1.getOperand(1))) {
6137         const APInt &LHSMask = N0O1C->getAPIntValue();
6138         const APInt &RHSMask = N1O1C->getAPIntValue();
6139 
6140         if (DAG.MaskedValueIsZero(N0.getOperand(0), RHSMask&~LHSMask) &&
6141             DAG.MaskedValueIsZero(N1.getOperand(0), LHSMask&~RHSMask)) {
6142           SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
6143                                   N0.getOperand(0), N1.getOperand(0));
6144           return DAG.getNode(ISD::AND, DL, VT, X,
6145                              DAG.getConstant(LHSMask | RHSMask, DL, VT));
6146         }
6147       }
6148     }
6149   }
6150 
6151   // (or (and X, M), (and X, N)) -> (and X, (or M, N))
6152   if (N0.getOpcode() == ISD::AND &&
6153       N1.getOpcode() == ISD::AND &&
6154       N0.getOperand(0) == N1.getOperand(0) &&
6155       // Don't increase # computations.
6156       (N0.getNode()->hasOneUse() || N1.getNode()->hasOneUse())) {
6157     SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
6158                             N0.getOperand(1), N1.getOperand(1));
6159     return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), X);
6160   }
6161 
6162   return SDValue();
6163 }
6164 
6165 /// OR combines for which the commuted variant will be tried as well.
6166 static SDValue visitORCommutative(
6167     SelectionDAG &DAG, SDValue N0, SDValue N1, SDNode *N) {
6168   EVT VT = N0.getValueType();
6169   if (N0.getOpcode() == ISD::AND) {
6170     // fold (or (and X, (xor Y, -1)), Y) -> (or X, Y)
6171     if (isBitwiseNot(N0.getOperand(1)) && N0.getOperand(1).getOperand(0) == N1)
6172       return DAG.getNode(ISD::OR, SDLoc(N), VT, N0.getOperand(0), N1);
6173 
6174     // fold (or (and (xor Y, -1), X), Y) -> (or X, Y)
6175     if (isBitwiseNot(N0.getOperand(0)) && N0.getOperand(0).getOperand(0) == N1)
6176       return DAG.getNode(ISD::OR, SDLoc(N), VT, N0.getOperand(1), N1);
6177   }
6178 
6179   return SDValue();
6180 }
6181 
6182 SDValue DAGCombiner::visitOR(SDNode *N) {
6183   SDValue N0 = N->getOperand(0);
6184   SDValue N1 = N->getOperand(1);
6185   EVT VT = N1.getValueType();
6186 
6187   // x | x --> x
6188   if (N0 == N1)
6189     return N0;
6190 
6191   // fold vector ops
6192   if (VT.isVector()) {
6193     if (SDValue FoldedVOp = SimplifyVBinOp(N))
6194       return FoldedVOp;
6195 
6196     // fold (or x, 0) -> x, vector edition
6197     if (ISD::isBuildVectorAllZeros(N0.getNode()))
6198       return N1;
6199     if (ISD::isBuildVectorAllZeros(N1.getNode()))
6200       return N0;
6201 
6202     // fold (or x, -1) -> -1, vector edition
6203     if (ISD::isBuildVectorAllOnes(N0.getNode()))
6204       // do not return N0, because an undef node may exist in N0
6205       return DAG.getAllOnesConstant(SDLoc(N), N0.getValueType());
6206     if (ISD::isBuildVectorAllOnes(N1.getNode()))
6207       // do not return N1, because an undef node may exist in N1
6208       return DAG.getAllOnesConstant(SDLoc(N), N1.getValueType());
6209 
6210     // fold (or (shuf A, V_0, MA), (shuf B, V_0, MB)) -> (shuf A, B, Mask)
6211     // Do this only if the resulting shuffle is legal.
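    // e.g. (or (shuffle A, zero, <0,4,2,4>), (shuffle B, zero, <4,1,4,3>))
    //      --> (shuffle A, B, <0,5,2,7>)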
6212     if (isa<ShuffleVectorSDNode>(N0) &&
6213         isa<ShuffleVectorSDNode>(N1) &&
6214         // Avoid folding a node with illegal type.
6215         TLI.isTypeLegal(VT)) {
6216       bool ZeroN00 = ISD::isBuildVectorAllZeros(N0.getOperand(0).getNode());
6217       bool ZeroN01 = ISD::isBuildVectorAllZeros(N0.getOperand(1).getNode());
6218       bool ZeroN10 = ISD::isBuildVectorAllZeros(N1.getOperand(0).getNode());
6219       bool ZeroN11 = ISD::isBuildVectorAllZeros(N1.getOperand(1).getNode());
6220       // Ensure both shuffles have a zero input.
6221       if ((ZeroN00 != ZeroN01) && (ZeroN10 != ZeroN11)) {
6222         assert((!ZeroN00 || !ZeroN01) && "Both inputs zero!");
6223         assert((!ZeroN10 || !ZeroN11) && "Both inputs zero!");
6224         const ShuffleVectorSDNode *SV0 = cast<ShuffleVectorSDNode>(N0);
6225         const ShuffleVectorSDNode *SV1 = cast<ShuffleVectorSDNode>(N1);
6226         bool CanFold = true;
6227         int NumElts = VT.getVectorNumElements();
6228         SmallVector<int, 4> Mask(NumElts);
6229 
6230         for (int i = 0; i != NumElts; ++i) {
6231           int M0 = SV0->getMaskElt(i);
6232           int M1 = SV1->getMaskElt(i);
6233 
6234           // Determine if either index is pointing to a zero vector.
6235           bool M0Zero = M0 < 0 || (ZeroN00 == (M0 < NumElts));
6236           bool M1Zero = M1 < 0 || (ZeroN10 == (M1 < NumElts));
6237 
6238           // If one element is zero and the other side is undef, keep undef.
6239           // This also handles the case that both are undef.
6240           if ((M0Zero && M1 < 0) || (M1Zero && M0 < 0)) {
6241             Mask[i] = -1;
6242             continue;
6243           }
6244 
6245           // Make sure only one of the elements is zero.
6246           if (M0Zero == M1Zero) {
6247             CanFold = false;
6248             break;
6249           }
6250 
6251           assert((M0 >= 0 || M1 >= 0) && "Undef index!");
6252 
6253           // We have a zero and non-zero element. If the non-zero came from
6254           // SV0 make the index a LHS index. If it came from SV1, make it
6255           // a RHS index. We need to mod by NumElts because we don't care
6256           // which operand it came from in the original shuffles.
6257           Mask[i] = M1Zero ? M0 % NumElts : (M1 % NumElts) + NumElts;
6258         }
6259 
6260         if (CanFold) {
6261           SDValue NewLHS = ZeroN00 ? N0.getOperand(1) : N0.getOperand(0);
6262           SDValue NewRHS = ZeroN10 ? N1.getOperand(1) : N1.getOperand(0);
6263 
6264           SDValue LegalShuffle =
6265               TLI.buildLegalVectorShuffle(VT, SDLoc(N), NewLHS, NewRHS,
6266                                           Mask, DAG);
6267           if (LegalShuffle)
6268             return LegalShuffle;
6269         }
6270       }
6271     }
6272   }
6273 
6274   // fold (or c1, c2) -> c1|c2
6275   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
6276   if (SDValue C = DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N), VT, {N0, N1}))
6277     return C;
6278 
6279   // canonicalize constant to RHS
6280   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
6281      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
6282     return DAG.getNode(ISD::OR, SDLoc(N), VT, N1, N0);
6283 
6284   // fold (or x, 0) -> x
6285   if (isNullConstant(N1))
6286     return N0;
6287 
6288   // fold (or x, -1) -> -1
6289   if (isAllOnesConstant(N1))
6290     return N1;
6291 
6292   if (SDValue NewSel = foldBinOpIntoSelect(N))
6293     return NewSel;
6294 
6295   // fold (or x, c) -> c iff (x & ~c) == 0
6296   if (N1C && DAG.MaskedValueIsZero(N0, ~N1C->getAPIntValue()))
6297     return N1;
6298 
6299   if (SDValue Combined = visitORLike(N0, N1, N))
6300     return Combined;
6301 
6302   if (SDValue Combined = combineCarryDiamond(*this, DAG, TLI, N0, N1, N))
6303     return Combined;
6304 
6305   // Recognize halfword bswaps as (bswap + rotl 16) or (bswap + shl 16)
6306   if (SDValue BSwap = MatchBSwapHWord(N, N0, N1))
6307     return BSwap;
6308   if (SDValue BSwap = MatchBSwapHWordLow(N, N0, N1))
6309     return BSwap;
6310 
6311   // reassociate or
6312   if (SDValue ROR = reassociateOps(ISD::OR, SDLoc(N), N0, N1, N->getFlags()))
6313     return ROR;
6314 
6315   // Canonicalize (or (and X, c1), c2) -> (and (or X, c2), c1|c2)
6316   // iff (c1 & c2) != 0 or c1/c2 are undef.
6317   auto MatchIntersect = [](ConstantSDNode *C1, ConstantSDNode *C2) {
6318     return !C1 || !C2 || C1->getAPIntValue().intersects(C2->getAPIntValue());
6319   };
6320   if (N0.getOpcode() == ISD::AND && N0.getNode()->hasOneUse() &&
6321       ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchIntersect, true)) {
6322     if (SDValue COR = DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N1), VT,
6323                                                  {N1, N0.getOperand(1)})) {
6324       SDValue IOR = DAG.getNode(ISD::OR, SDLoc(N0), VT, N0.getOperand(0), N1);
6325       AddToWorklist(IOR.getNode());
6326       return DAG.getNode(ISD::AND, SDLoc(N), VT, COR, IOR);
6327     }
6328   }
6329 
6330   if (SDValue Combined = visitORCommutative(DAG, N0, N1, N))
6331     return Combined;
6332   if (SDValue Combined = visitORCommutative(DAG, N1, N0, N))
6333     return Combined;
6334 
6335   // Simplify: (or (op x...), (op y...))  -> (op (or x, y))
6336   if (N0.getOpcode() == N1.getOpcode())
6337     if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
6338       return V;
6339 
6340   // See if this is some rotate idiom.
6341   if (SDValue Rot = MatchRotate(N0, N1, SDLoc(N)))
6342     return Rot;
6343 
6344   if (SDValue Load = MatchLoadCombine(N))
6345     return Load;
6346 
6347   // Simplify the operands using demanded-bits information.
6348   if (SimplifyDemandedBits(SDValue(N, 0)))
6349     return SDValue(N, 0);
6350 
6351   // If OR can be rewritten into ADD, try combines based on ADD.
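  // e.g. (or X, 1) is equivalent to (add X, 1) when the low bit of X is known
  // to be zero, since no carry can occur.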
6352   if ((!LegalOperations || TLI.isOperationLegal(ISD::ADD, VT)) &&
6353       DAG.haveNoCommonBitsSet(N0, N1))
6354     if (SDValue Combined = visitADDLike(N))
6355       return Combined;
6356 
6357   return SDValue();
6358 }
6359 
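/// If \p Op is (and X, C) with C a constant or constant build vector, return X
/// and set \p Mask to C; otherwise return \p Op unchanged.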
6360 static SDValue stripConstantMask(SelectionDAG &DAG, SDValue Op, SDValue &Mask) {
6361   if (Op.getOpcode() == ISD::AND &&
6362       DAG.isConstantIntBuildVectorOrConstantInt(Op.getOperand(1))) {
6363     Mask = Op.getOperand(1);
6364     return Op.getOperand(0);
6365   }
6366   return Op;
6367 }
6368 
6369 /// Match "(X shl/srl V1) & V2" where V2 may not be present.
6370 static bool matchRotateHalf(SelectionDAG &DAG, SDValue Op, SDValue &Shift,
6371                             SDValue &Mask) {
6372   Op = stripConstantMask(DAG, Op, Mask);
6373   if (Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SHL) {
6374     Shift = Op;
6375     return true;
6376   }
6377   return false;
6378 }
6379 
6380 /// Helper function for visitOR to extract the needed side of a rotate idiom
6381 /// from a shl/srl/mul/udiv.  This is meant to handle cases where
6382 /// InstCombine merged some outside op with one of the shifts from
6383 /// the rotate pattern.
6384 /// \returns An empty \c SDValue if the needed shift couldn't be extracted.
6385 /// Otherwise, returns an expansion of \p ExtractFrom based on the following
6386 /// patterns:
6387 ///
6388 ///   (or (add v v) (shrl v bitwidth-1)):
6389 ///     expands (add v v) -> (shl v 1)
6390 ///
6391 ///   (or (mul v c0) (shrl (mul v c1) c2)):
6392 ///     expands (mul v c0) -> (shl (mul v c1) c3)
6393 ///
6394 ///   (or (udiv v c0) (shl (udiv v c1) c2)):
6395 ///     expands (udiv v c0) -> (shrl (udiv v c1) c3)
6396 ///
6397 ///   (or (shl v c0) (shrl (shl v c1) c2)):
6398 ///     expands (shl v c0) -> (shl (shl v c1) c3)
6399 ///
6400 ///   (or (shrl v c0) (shl (shrl v c1) c2)):
6401 ///     expands (shrl v c0) -> (shrl (shrl v c1) c3)
6402 ///
6403 /// Such that in all cases, c3+c2==bitwidth(op v c1).
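///
/// E.g. for i32, (or (mul v, 48), (srl (mul v, 3), 28)) expands
/// (mul v, 48) -> (shl (mul v, 3), 4), since 48 == 3 << 4 and 4 + 28 == 32,
/// exposing a rotate of (mul v, 3).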
6404 static SDValue extractShiftForRotate(SelectionDAG &DAG, SDValue OppShift,
6405                                      SDValue ExtractFrom, SDValue &Mask,
6406                                      const SDLoc &DL) {
6407   assert(OppShift && ExtractFrom && "Empty SDValue");
6408   assert(
6409       (OppShift.getOpcode() == ISD::SHL || OppShift.getOpcode() == ISD::SRL) &&
6410       "Existing shift must be valid as a rotate half");
6411 
6412   ExtractFrom = stripConstantMask(DAG, ExtractFrom, Mask);
6413 
6414   // Value and Type of the shift.
6415   SDValue OppShiftLHS = OppShift.getOperand(0);
6416   EVT ShiftedVT = OppShiftLHS.getValueType();
6417 
6418   // Amount of the existing shift.
6419   ConstantSDNode *OppShiftCst = isConstOrConstSplat(OppShift.getOperand(1));
6420 
6421   // (add v v) -> (shl v 1)
6422   // TODO: Should this be a general DAG canonicalization?
6423   if (OppShift.getOpcode() == ISD::SRL && OppShiftCst &&
6424       ExtractFrom.getOpcode() == ISD::ADD &&
6425       ExtractFrom.getOperand(0) == ExtractFrom.getOperand(1) &&
6426       ExtractFrom.getOperand(0) == OppShiftLHS &&
6427       OppShiftCst->getAPIntValue() == ShiftedVT.getScalarSizeInBits() - 1)
6428     return DAG.getNode(ISD::SHL, DL, ShiftedVT, OppShiftLHS,
6429                        DAG.getShiftAmountConstant(1, ShiftedVT, DL));
6430 
6431   // Preconditions:
6432   //    (or (op0 v c0) (shiftl/r (op0 v c1) c2))
6433   //
6434   // Find opcode of the needed shift to be extracted from (op0 v c0).
6435   unsigned Opcode = ISD::DELETED_NODE;
6436   bool IsMulOrDiv = false;
6437   // Set Opcode and IsMulOrDiv if the extract opcode matches the needed shift
6438   // opcode or its arithmetic (mul or udiv) variant.
6439   auto SelectOpcode = [&](unsigned NeededShift, unsigned MulOrDivVariant) {
6440     IsMulOrDiv = ExtractFrom.getOpcode() == MulOrDivVariant;
6441     if (!IsMulOrDiv && ExtractFrom.getOpcode() != NeededShift)
6442       return false;
6443     Opcode = NeededShift;
6444     return true;
6445   };
6446   // op0 must be either the needed shift opcode or the mul/udiv equivalent
6447   // that the needed shift can be extracted from.
6448   if ((OppShift.getOpcode() != ISD::SRL || !SelectOpcode(ISD::SHL, ISD::MUL)) &&
6449       (OppShift.getOpcode() != ISD::SHL || !SelectOpcode(ISD::SRL, ISD::UDIV)))
6450     return SDValue();
6451 
6452   // op0 must be the same opcode on both sides, have the same LHS argument,
6453   // and produce the same value type.
6454   if (OppShiftLHS.getOpcode() != ExtractFrom.getOpcode() ||
6455       OppShiftLHS.getOperand(0) != ExtractFrom.getOperand(0) ||
6456       ShiftedVT != ExtractFrom.getValueType())
6457     return SDValue();
6458 
6459   // Constant mul/udiv/shift amount from the RHS of the shift's LHS op.
6460   ConstantSDNode *OppLHSCst = isConstOrConstSplat(OppShiftLHS.getOperand(1));
6461   // Constant mul/udiv/shift amount from the RHS of the ExtractFrom op.
6462   ConstantSDNode *ExtractFromCst =
6463       isConstOrConstSplat(ExtractFrom.getOperand(1));
6464   // TODO: We should be able to handle non-uniform constant vectors for these values
6465   // Check that we have constant values.
6466   if (!OppShiftCst || !OppShiftCst->getAPIntValue() ||
6467       !OppLHSCst || !OppLHSCst->getAPIntValue() ||
6468       !ExtractFromCst || !ExtractFromCst->getAPIntValue())
6469     return SDValue();
6470 
6471   // Compute the shift amount we need to extract to complete the rotate.
6472   const unsigned VTWidth = ShiftedVT.getScalarSizeInBits();
6473   if (OppShiftCst->getAPIntValue().ugt(VTWidth))
6474     return SDValue();
6475   APInt NeededShiftAmt = VTWidth - OppShiftCst->getAPIntValue();
6476   // Normalize the bitwidth of the two mul/udiv/shift constant operands.
6477   APInt ExtractFromAmt = ExtractFromCst->getAPIntValue();
6478   APInt OppLHSAmt = OppLHSCst->getAPIntValue();
6479   zeroExtendToMatch(ExtractFromAmt, OppLHSAmt);
6480 
6481   // Now try to extract the needed shift from the ExtractFrom op and see if the
6482   // result matches up with the existing shift's LHS op.
6483   if (IsMulOrDiv) {
6484     // Op to extract from is a mul or udiv by a constant.
6485     // Check:
6486     //     c2 / (1 << (bitwidth(op0 v c0) - c1)) == c0
6487     //     c2 % (1 << (bitwidth(op0 v c0) - c1)) == 0
6488     const APInt ExtractDiv = APInt::getOneBitSet(ExtractFromAmt.getBitWidth(),
6489                                                  NeededShiftAmt.getZExtValue());
6490     APInt ResultAmt;
6491     APInt Rem;
6492     APInt::udivrem(ExtractFromAmt, ExtractDiv, ResultAmt, Rem);
6493     if (Rem != 0 || ResultAmt != OppLHSAmt)
6494       return SDValue();
6495   } else {
6496     // Op to extract from is a shift by a constant.
6497     // Check:
6498     //      c2 - (bitwidth(op0 v c0) - c1) == c0
6499     if (OppLHSAmt != ExtractFromAmt - NeededShiftAmt.zextOrTrunc(
6500                                           ExtractFromAmt.getBitWidth()))
6501       return SDValue();
6502   }
6503 
6504   // Return the expanded shift op that should allow a rotate to be formed.
6505   EVT ShiftVT = OppShift.getOperand(1).getValueType();
6506   EVT ResVT = ExtractFrom.getValueType();
6507   SDValue NewShiftNode = DAG.getConstant(NeededShiftAmt, DL, ShiftVT);
6508   return DAG.getNode(Opcode, DL, ResVT, OppShiftLHS, NewShiftNode);
6509 }
6510 
6511 // Return true if we can prove that, whenever Neg and Pos are both in the
6512 // range [0, EltSize), Neg == (Pos == 0 ? 0 : EltSize - Pos).  This means that
6513 // for two opposing shifts shift1 and shift2 and a value X with OpBits bits:
6514 //
6515 //     (or (shift1 X, Neg), (shift2 X, Pos))
6516 //
6517 // reduces to a rotate in direction shift2 by Pos or (equivalently) a rotate
6518 // in direction shift1 by Neg.  The range [0, EltSize) means that we only need
6519 // to consider shift amounts with defined behavior.
6520 //
6521 // The IsRotate flag should be set when the LHS of both shifts is the same.
6522 // Otherwise if matching a general funnel shift, it should be clear.
6523 static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
6524                            SelectionDAG &DAG, bool IsRotate) {
6525   // If EltSize is a power of 2 then:
6526   //
6527   //  (a) (Pos == 0 ? 0 : EltSize - Pos) == (EltSize - Pos) & (EltSize - 1)
6528   //  (b) Neg == Neg & (EltSize - 1) whenever Neg is in [0, EltSize).
6529   //
6530   // So if EltSize is a power of 2 and Neg is (and Neg', EltSize-1), we check
6531   // for the stronger condition:
6532   //
6533   //     Neg & (EltSize - 1) == (EltSize - Pos) & (EltSize - 1)    [A]
6534   //
6535   // for all Neg and Pos.  Since Neg & (EltSize - 1) == Neg' & (EltSize - 1)
6536   // we can just replace Neg with Neg' for the rest of the function.
6537   //
6538   // In other cases we check for the even stronger condition:
6539   //
6540   //     Neg == EltSize - Pos                                    [B]
6541   //
6542   // for all Neg and Pos.  Note that the (or ...) then invokes undefined
6543   // behavior if Pos == 0 (and consequently Neg == EltSize).
6544   //
6545   // We could actually use [A] whenever EltSize is a power of 2, but the
6546   // only extra cases that it would match are those uninteresting ones
6547   // where Neg and Pos are never in range at the same time.  E.g. for
6548   // EltSize == 32, using [A] would allow a Neg of the form (sub 64, Pos)
6549   // as well as (sub 32, Pos), but:
6550   //
6551   //     (or (shift1 X, (sub 64, Pos)), (shift2 X, Pos))
6552   //
6553   // always invokes undefined behavior for 32-bit X.
6554   //
6555   // Below, Mask == EltSize - 1 when using [A] and is all-ones otherwise.
6556   //
6557   // NOTE: We can only do this when matching an AND and not a general
6558   // funnel shift.
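  // For example (an assumed i32 rotate-by-negative-amount pattern):
  //     (or (shl X, (and (sub 0, Y), 31)), (srl X, (and Y, 31)))
  // Both ANDs are peeled off below (MaskLoBits == 5), leaving
  // Neg == (sub 0, Y) and Pos == Y, so NegC == 0 and NegOp1 == Pos; the
  // final check Width.getLoBits(5) == 0 then holds with Width == NegC == 0.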
6559   unsigned MaskLoBits = 0;
6560   if (IsRotate && Neg.getOpcode() == ISD::AND && isPowerOf2_64(EltSize)) {
6561     if (ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(1))) {
6562       KnownBits Known = DAG.computeKnownBits(Neg.getOperand(0));
6563       unsigned Bits = Log2_64(EltSize);
6564       if (NegC->getAPIntValue().getActiveBits() <= Bits &&
6565           ((NegC->getAPIntValue() | Known.Zero).countTrailingOnes() >= Bits)) {
6566         Neg = Neg.getOperand(0);
6567         MaskLoBits = Bits;
6568       }
6569     }
6570   }
6571 
6572   // Check whether Neg has the form (sub NegC, NegOp1) for some NegC and NegOp1.
6573   if (Neg.getOpcode() != ISD::SUB)
6574     return false;
6575   ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(0));
6576   if (!NegC)
6577     return false;
6578   SDValue NegOp1 = Neg.getOperand(1);
6579 
6580   // On the RHS of [A], if Pos is Pos' & (EltSize - 1), just replace Pos with
6581   // Pos'.  The truncation is redundant for the purpose of the equality.
6582   if (MaskLoBits && Pos.getOpcode() == ISD::AND) {
6583     if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1))) {
6584       KnownBits Known = DAG.computeKnownBits(Pos.getOperand(0));
6585       if (PosC->getAPIntValue().getActiveBits() <= MaskLoBits &&
6586           ((PosC->getAPIntValue() | Known.Zero).countTrailingOnes() >=
6587            MaskLoBits))
6588         Pos = Pos.getOperand(0);
6589     }
6590   }
6591 
6592   // The condition we need is now:
6593   //
6594   //     (NegC - NegOp1) & Mask == (EltSize - Pos) & Mask
6595   //
6596   // If NegOp1 == Pos then we need:
6597   //
6598   //              EltSize & Mask == NegC & Mask
6599   //
6600   // (because "x & Mask" is a truncation and distributes through subtraction).
6601   //
6602   // We also need to account for a potential truncation of NegOp1 if the amount
6603   // has already been legalized to a shift amount type.
6604   APInt Width;
6605   if ((Pos == NegOp1) ||
6606       (NegOp1.getOpcode() == ISD::TRUNCATE && Pos == NegOp1.getOperand(0)))
6607     Width = NegC->getAPIntValue();
6608 
6609   // Check for cases where Pos has the form (add NegOp1, PosC) for some PosC.
6610   // Then the condition we want to prove becomes:
6611   //
6612   //     (NegC - NegOp1) & Mask == (EltSize - (NegOp1 + PosC)) & Mask
6613   //
6614   // which, again because "x & Mask" is a truncation, becomes:
6615   //
6616   //                NegC & Mask == (EltSize - PosC) & Mask
6617   //             EltSize & Mask == (NegC + PosC) & Mask
6618   else if (Pos.getOpcode() == ISD::ADD && Pos.getOperand(0) == NegOp1) {
6619     if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1)))
6620       Width = PosC->getAPIntValue() + NegC->getAPIntValue();
6621     else
6622       return false;
6623   } else
6624     return false;
6625 
6626   // Now we just need to check that EltSize & Mask == Width & Mask.
6627   if (MaskLoBits)
6628     // EltSize & Mask is 0 since Mask is EltSize - 1.
6629     return Width.getLoBits(MaskLoBits) == 0;
6630   return Width == EltSize;
6631 }
6632 
6633 // A subroutine of MatchRotate used once we have found an OR of two opposite
6634 // shifts of Shifted.  If Neg == <operand size> - Pos then the OR reduces
6635 // to both (PosOpcode Shifted, Pos) and (NegOpcode Shifted, Neg), with the
6636 // former being preferred if supported.  InnerPos and InnerNeg are Pos and
6637 // Neg with outer conversions stripped away.
6638 SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
6639                                        SDValue Neg, SDValue InnerPos,
6640                                        SDValue InnerNeg, unsigned PosOpcode,
6641                                        unsigned NegOpcode, const SDLoc &DL) {
6642   // fold (or (shl x, (*ext y)),
6643   //          (srl x, (*ext (sub 32, y)))) ->
6644   //   (rotl x, y) or (rotr x, (sub 32, y))
6645   //
6646   // fold (or (shl x, (*ext (sub 32, y))),
6647   //          (srl x, (*ext y))) ->
6648   //   (rotr x, y) or (rotl x, (sub 32, y))
6649   EVT VT = Shifted.getValueType();
6650   if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits(), DAG,
6651                      /*IsRotate*/ true)) {
6652     bool HasPos = TLI.isOperationLegalOrCustom(PosOpcode, VT);
6653     return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, Shifted,
6654                        HasPos ? Pos : Neg);
6655   }
6656 
6657   return SDValue();
6658 }
6659 
6660 // A subroutine of MatchRotate used once we have found an OR of two opposite
6661 // shifts of N0 + N1.  If Neg == <operand size> - Pos then the OR reduces
6662 // to both (PosOpcode N0, N1, Pos) and (NegOpcode N0, N1, Neg), with the
6663 // former being preferred if supported.  InnerPos and InnerNeg are Pos and
6664 // Neg with outer conversions stripped away.
6665 // TODO: Merge with MatchRotatePosNeg.
6666 SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
6667                                        SDValue Neg, SDValue InnerPos,
6668                                        SDValue InnerNeg, unsigned PosOpcode,
6669                                        unsigned NegOpcode, const SDLoc &DL) {
6670   EVT VT = N0.getValueType();
6671   unsigned EltBits = VT.getScalarSizeInBits();
6672 
6673   // fold (or (shl x0, (*ext y)),
6674   //          (srl x1, (*ext (sub 32, y)))) ->
6675   //   (fshl x0, x1, y) or (fshr x0, x1, (sub 32, y))
6676   //
6677   // fold (or (shl x0, (*ext (sub 32, y))),
6678   //          (srl x1, (*ext y))) ->
6679   //   (fshr x0, x1, y) or (fshl x0, x1, (sub 32, y))
6680   if (matchRotateSub(InnerPos, InnerNeg, EltBits, DAG, /*IsRotate*/ N0 == N1)) {
6681     bool HasPos = TLI.isOperationLegalOrCustom(PosOpcode, VT);
6682     return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, N0, N1,
6683                        HasPos ? Pos : Neg);
6684   }
6685 
6686   // Matching the shift+xor cases, we can't easily use the xor'd shift amount
6687   // so for now just use the PosOpcode case if it's legal.
6688   // TODO: When can we use the NegOpcode case?
6689   if (PosOpcode == ISD::FSHL && isPowerOf2_32(EltBits)) {
6690     auto IsBinOpImm = [](SDValue Op, unsigned BinOpc, unsigned Imm) {
6691       if (Op.getOpcode() != BinOpc)
6692         return false;
6693       ConstantSDNode *Cst = isConstOrConstSplat(Op.getOperand(1));
6694       return Cst && (Cst->getAPIntValue() == Imm);
6695     };
6696 
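    // The matches below are sound because, for i32 and y in [0, 31],
    // (xor y, 31) == 31 - y, so (srl (srl x1, 1), (xor y, 31)) computes
    // (srl x1, 32 - y) as two in-range shifts; at y == 0 it yields 0,
    // leaving just x0, which is exactly fshl(x0, x1, 0).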
6697     // fold (or (shl x0, y), (srl (srl x1, 1), (xor y, 31)))
6698     //   -> (fshl x0, x1, y)
6699     if (IsBinOpImm(N1, ISD::SRL, 1) &&
6700         IsBinOpImm(InnerNeg, ISD::XOR, EltBits - 1) &&
6701         InnerPos == InnerNeg.getOperand(0) &&
6702         TLI.isOperationLegalOrCustom(ISD::FSHL, VT)) {
6703       return DAG.getNode(ISD::FSHL, DL, VT, N0, N1.getOperand(0), Pos);
6704     }
6705 
6706     // fold (or (shl (shl x0, 1), (xor y, 31)), (srl x1, y))
6707     //   -> (fshr x0, x1, y)
6708     if (IsBinOpImm(N0, ISD::SHL, 1) &&
6709         IsBinOpImm(InnerPos, ISD::XOR, EltBits - 1) &&
6710         InnerNeg == InnerPos.getOperand(0) &&
6711         TLI.isOperationLegalOrCustom(ISD::FSHR, VT)) {
6712       return DAG.getNode(ISD::FSHR, DL, VT, N0.getOperand(0), N1, Neg);
6713     }
6714 
6715     // fold (or (shl (add x0, x0), (xor y, 31)), (srl x1, y))
6716     //   -> (fshr x0, x1, y)
6717     // TODO: Should add(x,x) -> shl(x,1) be a general DAG canonicalization?
6718     if (N0.getOpcode() == ISD::ADD && N0.getOperand(0) == N0.getOperand(1) &&
6719         IsBinOpImm(InnerPos, ISD::XOR, EltBits - 1) &&
6720         InnerNeg == InnerPos.getOperand(0) &&
6721         TLI.isOperationLegalOrCustom(ISD::FSHR, VT)) {
6722       return DAG.getNode(ISD::FSHR, DL, VT, N0.getOperand(0), N1, Neg);
6723     }
6724   }
6725 
6726   return SDValue();
6727 }
6728 
6729 // MatchRotate - Handle an 'or' of two operands.  If this is one of the many
6730 // idioms for rotate, and if the target supports rotation instructions, generate
6731 // a rot[lr]. This also matches funnel shift patterns, similar to rotation but
6732 // with different shifted sources.
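// For example, on i32: (or (shl X, 8), (srl X, 24)) -> (rotl X, 8), and with
// different sources (or (shl X, Y), (srl Z, (sub 32, Y))) can become
// (fshl X, Z, Y).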
6733 SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
6734   // Must be a legal type. Expanded and promoted types won't work with rotates.
6735   EVT VT = LHS.getValueType();
6736   if (!TLI.isTypeLegal(VT))
6737     return SDValue();
6738 
6739   // The target must have at least one rotate/funnel flavor.
6740   bool HasROTL = hasOperation(ISD::ROTL, VT);
6741   bool HasROTR = hasOperation(ISD::ROTR, VT);
6742   bool HasFSHL = hasOperation(ISD::FSHL, VT);
6743   bool HasFSHR = hasOperation(ISD::FSHR, VT);
6744   if (!HasROTL && !HasROTR && !HasFSHL && !HasFSHR)
6745     return SDValue();
6746 
6747   // Check for truncated rotate.
6748   if (LHS.getOpcode() == ISD::TRUNCATE && RHS.getOpcode() == ISD::TRUNCATE &&
6749       LHS.getOperand(0).getValueType() == RHS.getOperand(0).getValueType()) {
6750     assert(LHS.getValueType() == RHS.getValueType());
6751     if (SDValue Rot = MatchRotate(LHS.getOperand(0), RHS.getOperand(0), DL)) {
6752       return DAG.getNode(ISD::TRUNCATE, SDLoc(LHS), LHS.getValueType(), Rot);
6753     }
6754   }
6755 
6756   // Match "(X shl/srl V1) & V2" where V2 may not be present.
6757   SDValue LHSShift;   // The shift.
6758   SDValue LHSMask;    // AND value if any.
6759   matchRotateHalf(DAG, LHS, LHSShift, LHSMask);
6760 
6761   SDValue RHSShift;   // The shift.
6762   SDValue RHSMask;    // AND value if any.
6763   matchRotateHalf(DAG, RHS, RHSShift, RHSMask);
6764 
6765   // If neither side matched a rotate half, bail
6766   if (!LHSShift && !RHSShift)
6767     return SDValue();
6768 
6769   // InstCombine may have combined a constant shl, srl, mul, or udiv with one
6770   // side of the rotate, so try to handle that here. In all cases we need to
6771   // pass the matched shift from the opposite side to compute the opcode and
6772   // needed shift amount to extract.  We still want to do this if both sides
6773   // matched a rotate half because one half may be a potential overshift that
6774   // can be broken down (i.e. if InstCombine merged two shl or srl ops into a
6775   // single one).
6776 
6777   // Have LHS side of the rotate, try to extract the needed shift from the RHS.
6778   if (LHSShift)
6779     if (SDValue NewRHSShift =
6780             extractShiftForRotate(DAG, LHSShift, RHS, RHSMask, DL))
6781       RHSShift = NewRHSShift;
6782   // Have RHS side of the rotate, try to extract the needed shift from the LHS.
6783   if (RHSShift)
6784     if (SDValue NewLHSShift =
6785             extractShiftForRotate(DAG, RHSShift, LHS, LHSMask, DL))
6786       LHSShift = NewLHSShift;
6787 
6788   // If a side is still missing, nothing else we can do.
6789   if (!RHSShift || !LHSShift)
6790     return SDValue();
6791 
6792   // At this point we've matched or extracted a shift op on each side.
6793 
6794   if (LHSShift.getOpcode() == RHSShift.getOpcode())
6795     return SDValue(); // Shifts must disagree.
6796 
6797   bool IsRotate = LHSShift.getOperand(0) == RHSShift.getOperand(0);
6798   if (!IsRotate && !(HasFSHL || HasFSHR))
6799     return SDValue(); // Requires funnel shift support.
6800 
6801   // Canonicalize shl to left side in a shl/srl pair.
6802   if (RHSShift.getOpcode() == ISD::SHL) {
6803     std::swap(LHS, RHS);
6804     std::swap(LHSShift, RHSShift);
6805     std::swap(LHSMask, RHSMask);
6806   }
6807 
6808   unsigned EltSizeInBits = VT.getScalarSizeInBits();
6809   SDValue LHSShiftArg = LHSShift.getOperand(0);
6810   SDValue LHSShiftAmt = LHSShift.getOperand(1);
6811   SDValue RHSShiftArg = RHSShift.getOperand(0);
6812   SDValue RHSShiftAmt = RHSShift.getOperand(1);
6813 
6814   // fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1)
6815   // fold (or (shl x, C1), (srl x, C2)) -> (rotr x, C2)
6816   // fold (or (shl x, C1), (srl y, C2)) -> (fshl x, y, C1)
6817   // fold (or (shl x, C1), (srl y, C2)) -> (fshr x, y, C2)
6818   // iff C1+C2 == EltSizeInBits
6819   auto MatchRotateSum = [EltSizeInBits](ConstantSDNode *LHS,
6820                                         ConstantSDNode *RHS) {
6821     return (LHS->getAPIntValue() + RHS->getAPIntValue()) == EltSizeInBits;
6822   };
6823   if (ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) {
6824     SDValue Res;
6825     if (IsRotate && (HasROTL || HasROTR))
6826       Res = DAG.getNode(HasROTL ? ISD::ROTL : ISD::ROTR, DL, VT, LHSShiftArg,
6827                         HasROTL ? LHSShiftAmt : RHSShiftAmt);
6828     else
6829       Res = DAG.getNode(HasFSHL ? ISD::FSHL : ISD::FSHR, DL, VT, LHSShiftArg,
6830                         RHSShiftArg, HasFSHL ? LHSShiftAmt : RHSShiftAmt);
6831 
6832     // If there is an AND of either shifted operand, apply it to the result.
6833     if (LHSMask.getNode() || RHSMask.getNode()) {
6834       SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
6835       SDValue Mask = AllOnes;
6836 
6837       if (LHSMask.getNode()) {
6838         SDValue RHSBits = DAG.getNode(ISD::SRL, DL, VT, AllOnes, RHSShiftAmt);
6839         Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
6840                            DAG.getNode(ISD::OR, DL, VT, LHSMask, RHSBits));
6841       }
6842       if (RHSMask.getNode()) {
6843         SDValue LHSBits = DAG.getNode(ISD::SHL, DL, VT, AllOnes, LHSShiftAmt);
6844         Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
6845                            DAG.getNode(ISD::OR, DL, VT, RHSMask, LHSBits));
6846       }
6847 
6848       Res = DAG.getNode(ISD::AND, DL, VT, Res, Mask);
6849     }
6850 
6851     return Res;
6852   }
6853 
6854   // If there is a mask here, and we have a variable shift, we can't be sure
6855   // that we're masking out the correct bits.
6856   if (LHSMask.getNode() || RHSMask.getNode())
6857     return SDValue();
6858 
6859   // If the shift amount is sign/zext/any-extended just peel it off.
6860   SDValue LExtOp0 = LHSShiftAmt;
6861   SDValue RExtOp0 = RHSShiftAmt;
6862   if ((LHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
6863        LHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
6864        LHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
6865        LHSShiftAmt.getOpcode() == ISD::TRUNCATE) &&
6866       (RHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
6867        RHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
6868        RHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
6869        RHSShiftAmt.getOpcode() == ISD::TRUNCATE)) {
6870     LExtOp0 = LHSShiftAmt.getOperand(0);
6871     RExtOp0 = RHSShiftAmt.getOperand(0);
6872   }
6873 
6874   if (IsRotate && (HasROTL || HasROTR)) {
6875     SDValue TryL =
6876         MatchRotatePosNeg(LHSShiftArg, LHSShiftAmt, RHSShiftAmt, LExtOp0,
6877                           RExtOp0, ISD::ROTL, ISD::ROTR, DL);
6878     if (TryL)
6879       return TryL;
6880 
6881     SDValue TryR =
6882         MatchRotatePosNeg(RHSShiftArg, RHSShiftAmt, LHSShiftAmt, RExtOp0,
6883                           LExtOp0, ISD::ROTR, ISD::ROTL, DL);
6884     if (TryR)
6885       return TryR;
6886   }
6887 
6888   SDValue TryL =
6889       MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, LHSShiftAmt, RHSShiftAmt,
6890                         LExtOp0, RExtOp0, ISD::FSHL, ISD::FSHR, DL);
6891   if (TryL)
6892     return TryL;
6893 
6894   SDValue TryR =
6895       MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, RHSShiftAmt, LHSShiftAmt,
6896                         RExtOp0, LExtOp0, ISD::FSHR, ISD::FSHL, DL);
6897   if (TryR)
6898     return TryR;
6899 
6900   return SDValue();
6901 }
6902 
6903 namespace {
6904 
6905 /// Represents the known origin of an individual byte in a load combine
6906 /// pattern. The value of the byte is either constant zero or comes from memory.
6907 struct ByteProvider {
6908   // For constant zero providers Load is set to nullptr. For memory providers
6909   // Load represents the node which loads the byte from memory.
6910   // ByteOffset is the offset of the byte in the value produced by the load.
6911   LoadSDNode *Load = nullptr;
6912   unsigned ByteOffset = 0;
6913 
6914   ByteProvider() = default;
6915 
6916   static ByteProvider getMemory(LoadSDNode *Load, unsigned ByteOffset) {
6917     return ByteProvider(Load, ByteOffset);
6918   }
6919 
6920   static ByteProvider getConstantZero() { return ByteProvider(nullptr, 0); }
6921 
6922   bool isConstantZero() const { return !Load; }
6923   bool isMemory() const { return Load; }
6924 
6925   bool operator==(const ByteProvider &Other) const {
6926     return Other.Load == Load && Other.ByteOffset == ByteOffset;
6927   }
6928 
6929 private:
6930   ByteProvider(LoadSDNode *Load, unsigned ByteOffset)
6931       : Load(Load), ByteOffset(ByteOffset) {}
6932 };
6933 
6934 } // end anonymous namespace
6935 
6936 /// Recursively traverses the expression calculating the origin of the requested
6937 /// byte of the given value. Returns None if the provider can't be calculated.
6938 ///
6939 /// For all the values except the root of the expression, verifies that the
6940 /// value has exactly one use, returning None otherwise. This way, if the
6941 /// origin of the byte is returned, it is guaranteed that the values which
6942 /// contribute to the byte are not used outside of this expression.
6943 ///
6944 /// Because the parts of the expression are not allowed to have more than one
6945 /// use this function iterates over trees, not DAGs. So it never visits the same
6946 /// node more than once.
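/// For example (a sketch, with p1 assumed to be p + 1):
///   (or (zext (load i8 p)), (shl (zext (load i8 p1)), 8))
/// yields byte 0 from (load p) at offset 0, byte 1 from (load p1) at
/// offset 0, and constant zero for any remaining high bytes.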
6947 static const Optional<ByteProvider>
6948 calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
6949                       bool Root = false) {
6950   // A typical i64-by-i8 pattern requires recursion up to a depth of 8 calls.
6951   if (Depth == 10)
6952     return None;
6953 
6954   if (!Root && !Op.hasOneUse())
6955     return None;
6956 
6957   assert(Op.getValueType().isScalarInteger() && "can't handle other types");
6958   unsigned BitWidth = Op.getValueSizeInBits();
6959   if (BitWidth % 8 != 0)
6960     return None;
6961   unsigned ByteWidth = BitWidth / 8;
6962   assert(Index < ByteWidth && "invalid index requested");
6963   (void) ByteWidth;
6964 
6965   switch (Op.getOpcode()) {
6966   case ISD::OR: {
6967     auto LHS = calculateByteProvider(Op->getOperand(0), Index, Depth + 1);
6968     if (!LHS)
6969       return None;
6970     auto RHS = calculateByteProvider(Op->getOperand(1), Index, Depth + 1);
6971     if (!RHS)
6972       return None;
6973 
6974     if (LHS->isConstantZero())
6975       return RHS;
6976     if (RHS->isConstantZero())
6977       return LHS;
6978     return None;
6979   }
6980   case ISD::SHL: {
6981     auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
6982     if (!ShiftOp)
6983       return None;
6984 
6985     uint64_t BitShift = ShiftOp->getZExtValue();
6986     if (BitShift % 8 != 0)
6987       return None;
6988     uint64_t ByteShift = BitShift / 8;
6989 
6990     return Index < ByteShift
6991                ? ByteProvider::getConstantZero()
6992                : calculateByteProvider(Op->getOperand(0), Index - ByteShift,
6993                                        Depth + 1);
6994   }
6995   case ISD::ANY_EXTEND:
6996   case ISD::SIGN_EXTEND:
6997   case ISD::ZERO_EXTEND: {
6998     SDValue NarrowOp = Op->getOperand(0);
6999     unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
7000     if (NarrowBitWidth % 8 != 0)
7001       return None;
7002     uint64_t NarrowByteWidth = NarrowBitWidth / 8;
7003 
7004     if (Index >= NarrowByteWidth)
7005       return Op.getOpcode() == ISD::ZERO_EXTEND
7006                  ? Optional<ByteProvider>(ByteProvider::getConstantZero())
7007                  : None;
7008     return calculateByteProvider(NarrowOp, Index, Depth + 1);
7009   }
7010   case ISD::BSWAP:
7011     return calculateByteProvider(Op->getOperand(0), ByteWidth - Index - 1,
7012                                  Depth + 1);
7013   case ISD::LOAD: {
7014     auto L = cast<LoadSDNode>(Op.getNode());
7015     if (!L->isSimple() || L->isIndexed())
7016       return None;
7017 
7018     unsigned NarrowBitWidth = L->getMemoryVT().getSizeInBits();
7019     if (NarrowBitWidth % 8 != 0)
7020       return None;
7021     uint64_t NarrowByteWidth = NarrowBitWidth / 8;
7022 
7023     if (Index >= NarrowByteWidth)
7024       return L->getExtensionType() == ISD::ZEXTLOAD
7025                  ? Optional<ByteProvider>(ByteProvider::getConstantZero())
7026                  : None;
7027     return ByteProvider::getMemory(L, Index);
7028   }
7029   }
7030 
7031   return None;
7032 }
7033 
7034 static unsigned littleEndianByteAt(unsigned BW, unsigned i) {
7035   return i;
7036 }
7037 
7038 static unsigned bigEndianByteAt(unsigned BW, unsigned i) {
7039   return BW - i - 1;
7040 }
7041 
7042 // Check if the byte offsets we are looking at match either a big- or
7043 // little-endian value load. Return true for big endian, false for little
7044 // endian, and None if the match failed.
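// For example, with FirstOffset == 0: offsets {0, 1, 2, 3} match a
// little-endian load (returns false), {3, 2, 1, 0} a big-endian one
// (returns true), and a mixed order such as {1, 0, 3, 2} returns None.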
7045 static Optional<bool> isBigEndian(const ArrayRef<int64_t> ByteOffsets,
7046                                   int64_t FirstOffset) {
7047   // Endianness can be decided only when there are at least 2 bytes.
7048   unsigned Width = ByteOffsets.size();
7049   if (Width < 2)
7050     return None;
7051 
7052   bool BigEndian = true, LittleEndian = true;
7053   for (unsigned i = 0; i < Width; i++) {
7054     int64_t CurrentByteOffset = ByteOffsets[i] - FirstOffset;
7055     LittleEndian &= CurrentByteOffset == littleEndianByteAt(Width, i);
7056     BigEndian &= CurrentByteOffset == bigEndianByteAt(Width, i);
7057     if (!BigEndian && !LittleEndian)
7058       return None;
7059   }
7060 
7061   assert((BigEndian != LittleEndian) && "It should be either big endian or "
7062                                         "little endian");
7063   return BigEndian;
7064 }
7065 
7066 static SDValue stripTruncAndExt(SDValue Value) {
7067   switch (Value.getOpcode()) {
7068   case ISD::TRUNCATE:
7069   case ISD::ZERO_EXTEND:
7070   case ISD::SIGN_EXTEND:
7071   case ISD::ANY_EXTEND:
7072     return stripTruncAndExt(Value.getOperand(0));
7073   }
7074   return Value;
7075 }
7076 
7077 /// Match a pattern where a wide type scalar value is stored by several narrow
7078 /// stores. Fold it into a single store or a BSWAP and a store if the target
7079 /// supports it.
7080 ///
7081 /// Assuming little endian target:
7082 ///  i8 *p = ...
7083 ///  i32 val = ...
7084 ///  p[0] = (val >> 0) & 0xFF;
7085 ///  p[1] = (val >> 8) & 0xFF;
7086 ///  p[2] = (val >> 16) & 0xFF;
7087 ///  p[3] = (val >> 24) & 0xFF;
7088 /// =>
7089 ///  *((i32)p) = val;
7090 ///
7091 ///  i8 *p = ...
7092 ///  i32 val = ...
7093 ///  p[0] = (val >> 24) & 0xFF;
7094 ///  p[1] = (val >> 16) & 0xFF;
7095 ///  p[2] = (val >> 8) & 0xFF;
7096 ///  p[3] = (val >> 0) & 0xFF;
7097 /// =>
7098 ///  *((i32)p) = BSWAP(val);
7099 SDValue DAGCombiner::mergeTruncStores(StoreSDNode *N) {
7100   // The matching looks for "store (trunc x)" patterns that appear early but are
7101   // likely to be replaced by truncating store nodes during combining.
7102   // TODO: If there is evidence that running this later would help, this
7103   //       limitation could be removed. Legality checks may need to be added
7104   //       for the created store and optional bswap/rotate.
7105   if (LegalOperations)
7106     return SDValue();
7107 
7108   // We only handle merging simple stores of 1-4 bytes.
7109   // TODO: Allow unordered atomics when wider type is legal (see D66309)
7110   EVT MemVT = N->getMemoryVT();
7111   if (!(MemVT == MVT::i8 || MemVT == MVT::i16 || MemVT == MVT::i32) ||
7112       !N->isSimple() || N->isIndexed())
7113     return SDValue();
7114 
7115   // Collect all of the stores in the chain.
7116   SDValue Chain = N->getChain();
7117   SmallVector<StoreSDNode *, 8> Stores = {N};
7118   while (auto *Store = dyn_cast<StoreSDNode>(Chain)) {
7119     // All stores must be the same size to ensure that we are writing all of the
7120     // bytes in the wide value.
7121     // TODO: We could allow multiple sizes by tracking each stored byte.
7122     if (Store->getMemoryVT() != MemVT || !Store->isSimple() ||
7123         Store->isIndexed())
7124       return SDValue();
7125     Stores.push_back(Store);
7126     Chain = Store->getChain();
7127   }
7128   // There is no reason to continue if we do not have at least a pair of stores.
7129   if (Stores.size() < 2)
7130     return SDValue();
7131 
7132   // Handle simple types only.
7133   LLVMContext &Context = *DAG.getContext();
7134   unsigned NumStores = Stores.size();
7135   unsigned NarrowNumBits = N->getMemoryVT().getScalarSizeInBits();
7136   unsigned WideNumBits = NumStores * NarrowNumBits;
7137   EVT WideVT = EVT::getIntegerVT(Context, WideNumBits);
7138   if (WideVT != MVT::i16 && WideVT != MVT::i32 && WideVT != MVT::i64)
7139     return SDValue();
7140 
7141   // Check if all bytes of the source value that we are looking at are stored
7142   // to the same base address. Collect offsets from Base address into OffsetMap.
7143   SDValue SourceValue;
7144   SmallVector<int64_t, 8> OffsetMap(NumStores, INT64_MAX);
7145   int64_t FirstOffset = INT64_MAX;
7146   StoreSDNode *FirstStore = nullptr;
7147   Optional<BaseIndexOffset> Base;
7148   for (auto Store : Stores) {
7149     // All the stores store different parts of the combined source value. A
7150     // truncate is required to get the partial value.
7151     SDValue Trunc = Store->getValue();
7152     if (Trunc.getOpcode() != ISD::TRUNCATE)
7153       return SDValue();
7154     // Other than the first/last part, a shift operation is required to get the
7155     // offset.
7156     int64_t Offset = 0;
7157     SDValue WideVal = Trunc.getOperand(0);
7158     if ((WideVal.getOpcode() == ISD::SRL || WideVal.getOpcode() == ISD::SRA) &&
7159         isa<ConstantSDNode>(WideVal.getOperand(1))) {
7160       // The shift amount must be a constant multiple of the narrow type.
7161       // It is translated to the offset address in the wide source value "y".
7162       //
7163       // x = srl y, ShiftAmtC
7164       // i8 z = trunc x
7165       // store z, ...
7166       uint64_t ShiftAmtC = WideVal.getConstantOperandVal(1);
7167       if (ShiftAmtC % NarrowNumBits != 0)
7168         return SDValue();
7169 
7170       Offset = ShiftAmtC / NarrowNumBits;
7171       WideVal = WideVal.getOperand(0);
7172     }
7173 
7174     // Stores must share the same source value with different offsets.
7175     // Truncate and extends should be stripped to get the single source value.
7176     if (!SourceValue)
7177       SourceValue = WideVal;
7178     else if (stripTruncAndExt(SourceValue) != stripTruncAndExt(WideVal))
7179       return SDValue();
7180     else if (SourceValue.getValueType() != WideVT) {
7181       if (WideVal.getValueType() == WideVT ||
7182           WideVal.getScalarValueSizeInBits() >
7183               SourceValue.getScalarValueSizeInBits())
7184         SourceValue = WideVal;
7185       // Give up if the source value type is smaller than the store size.
7186       if (SourceValue.getScalarValueSizeInBits() < WideVT.getScalarSizeInBits())
7187         return SDValue();
7188     }
7189 
7190     // Stores must share the same base address.
7191     BaseIndexOffset Ptr = BaseIndexOffset::match(Store, DAG);
7192     int64_t ByteOffsetFromBase = 0;
7193     if (!Base)
7194       Base = Ptr;
7195     else if (!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
7196       return SDValue();
7197 
7198     // Remember the first store.
7199     if (ByteOffsetFromBase < FirstOffset) {
7200       FirstStore = Store;
7201       FirstOffset = ByteOffsetFromBase;
7202     }
7203     // Map the offset in the store and the offset in the combined value, and
7204     // return early if it has been set before.
7205     if (Offset < 0 || Offset >= NumStores || OffsetMap[Offset] != INT64_MAX)
7206       return SDValue();
7207     OffsetMap[Offset] = ByteOffsetFromBase;
7208   }
7209 
7210   assert(FirstOffset != INT64_MAX && "First byte offset must be set");
7211   assert(FirstStore && "First store must be set");
7212 
7213   // Check that a store of the wide type is both allowed and fast on the target
7214   const DataLayout &Layout = DAG.getDataLayout();
7215   bool Fast = false;
7216   bool Allowed = TLI.allowsMemoryAccess(Context, Layout, WideVT,
7217                                         *FirstStore->getMemOperand(), &Fast);
7218   if (!Allowed || !Fast)
7219     return SDValue();
7220 
7221   // Check if the pieces of the value are going to the expected places in memory
7222   // to merge the stores.
7223   auto checkOffsets = [&](bool MatchLittleEndian) {
7224     if (MatchLittleEndian) {
7225       for (unsigned i = 0; i != NumStores; ++i)
7226         if (OffsetMap[i] != i * (NarrowNumBits / 8) + FirstOffset)
7227           return false;
7228     } else { // MatchBigEndian by reversing loop counter.
7229       for (unsigned i = 0, j = NumStores - 1; i != NumStores; ++i, --j)
7230         if (OffsetMap[j] != i * (NarrowNumBits / 8) + FirstOffset)
7231           return false;
7232     }
7233     return true;
7234   };
7235 
7236   // Check if the offsets line up for the native data layout of this target.
7237   bool NeedBswap = false;
7238   bool NeedRotate = false;
7239   if (!checkOffsets(Layout.isLittleEndian())) {
7240     // Special-case: check if byte offsets line up for the opposite endian.
7241     if (NarrowNumBits == 8 && checkOffsets(Layout.isBigEndian()))
7242       NeedBswap = true;
7243     else if (NumStores == 2 && checkOffsets(Layout.isBigEndian()))
7244       NeedRotate = true;
7245     else
7246       return SDValue();
7247   }
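  // For example (assuming a little-endian target): two i16 stores that write
  // the halves of an i32 swapped, p[0] = hi16 and p[1] = lo16, line up for
  // big-endian, so we store (rotr val, 16) as a whole i32 instead.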
7248 
7249   SDLoc DL(N);
7250   if (WideVT != SourceValue.getValueType()) {
7251     assert(SourceValue.getValueType().getScalarSizeInBits() > WideNumBits &&
7252            "Unexpected store value to merge");
7253     SourceValue = DAG.getNode(ISD::TRUNCATE, DL, WideVT, SourceValue);
7254   }
7255 
7256   // Before legalize we can introduce illegal bswaps/rotates which will be later
7257   // converted to an explicit bswap sequence. This way we end up with a single
7258   // store and byte shuffling instead of several stores and byte shuffling.
7259   if (NeedBswap) {
7260     SourceValue = DAG.getNode(ISD::BSWAP, DL, WideVT, SourceValue);
7261   } else if (NeedRotate) {
7262     assert(WideNumBits % 2 == 0 && "Unexpected type for rotate");
7263     SDValue RotAmt = DAG.getConstant(WideNumBits / 2, DL, WideVT);
7264     SourceValue = DAG.getNode(ISD::ROTR, DL, WideVT, SourceValue, RotAmt);
7265   }
7266 
7267   SDValue NewStore =
7268       DAG.getStore(Chain, DL, SourceValue, FirstStore->getBasePtr(),
7269                    FirstStore->getPointerInfo(), FirstStore->getAlign());
7270 
7271   // Rely on other DAG combine rules to remove the other individual stores.
7272   DAG.ReplaceAllUsesWith(N, NewStore.getNode());
7273   return NewStore;
7274 }
7275 
7276 /// Match a pattern where a wide type scalar value is loaded by several narrow
7277 /// loads and combined by shifts and ors. Fold it into a single load or a load
7278 /// and a BSWAP if the targets supports it.
7279 ///
7280 /// Assuming little endian target:
7281 ///  i8 *a = ...
7282 ///  i32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
7283 /// =>
7284 ///  i32 val = *((i32)a)
7285 ///
7286 ///  i8 *a = ...
7287 ///  i32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
7288 /// =>
7289 ///  i32 val = BSWAP(*((i32)a))
7290 ///
7291 /// TODO: This rule matches complex patterns with OR node roots and doesn't
7292 /// interact well with the worklist mechanism. When a part of the pattern is
7293 /// updated (e.g. one of the loads) its direct users are put into the worklist,
7294 /// but the root node of the pattern which triggers the load combine is not
7295 /// necessarily a direct user of the changed node. For example, once the address
7296 /// of t28 load is reassociated load combine won't be triggered:
7297 ///             t25: i32 = add t4, Constant:i32<2>
7298 ///           t26: i64 = sign_extend t25
7299 ///        t27: i64 = add t2, t26
7300 ///       t28: i8,ch = load<LD1[%tmp9]> t0, t27, undef:i64
7301 ///     t29: i32 = zero_extend t28
7302 ///   t32: i32 = shl t29, Constant:i8<8>
7303 /// t33: i32 = or t23, t32
7304 /// As a possible fix visitLoad can check if the load can be a part of a load
7305 /// combine pattern and add corresponding OR roots to the worklist.
7306 SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
7307   assert(N->getOpcode() == ISD::OR &&
7308          "Can only match load combining against OR nodes");
7309 
7310   // Handles simple types only
7311   EVT VT = N->getValueType(0);
7312   if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
7313     return SDValue();
7314   unsigned ByteWidth = VT.getSizeInBits() / 8;
7315 
7316   bool IsBigEndianTarget = DAG.getDataLayout().isBigEndian();
7317   auto MemoryByteOffset = [&] (ByteProvider P) {
7318     assert(P.isMemory() && "Must be a memory byte provider");
7319     unsigned LoadBitWidth = P.Load->getMemoryVT().getSizeInBits();
7320     assert(LoadBitWidth % 8 == 0 &&
7321            "can only analyze providers for individual bytes not bit");
7322     unsigned LoadByteWidth = LoadBitWidth / 8;
7323     return IsBigEndianTarget
7324             ? bigEndianByteAt(LoadByteWidth, P.ByteOffset)
7325             : littleEndianByteAt(LoadByteWidth, P.ByteOffset);
7326   };
7327 
7328   Optional<BaseIndexOffset> Base;
7329   SDValue Chain;
7330 
7331   SmallPtrSet<LoadSDNode *, 8> Loads;
7332   Optional<ByteProvider> FirstByteProvider;
7333   int64_t FirstOffset = INT64_MAX;
7334 
7335   // Check if all the bytes of the OR we are looking at are loaded from the same
7336   // base address. Collect byte offsets from the Base address in ByteOffsets.
7337   SmallVector<int64_t, 8> ByteOffsets(ByteWidth);
7338   unsigned ZeroExtendedBytes = 0;
7339   for (int i = ByteWidth - 1; i >= 0; --i) {
7340     auto P = calculateByteProvider(SDValue(N, 0), i, 0, /*Root=*/true);
7341     if (!P)
7342       return SDValue();
7343 
7344     if (P->isConstantZero()) {
7345     // It's OK for the N most significant bytes to be 0; we can just
7346       // zero-extend the load.
7347       if (++ZeroExtendedBytes != (ByteWidth - static_cast<unsigned>(i)))
7348         return SDValue();
7349       continue;
7350     }
7351     assert(P->isMemory() && "provenance should either be memory or zero");
7352 
7353     LoadSDNode *L = P->Load;
7354     assert(L->hasNUsesOfValue(1, 0) && L->isSimple() &&
7355            !L->isIndexed() &&
7356            "Must be enforced by calculateByteProvider");
7357     assert(L->getOffset().isUndef() && "Unindexed load must have undef offset");
7358 
7359     // All loads must share the same chain
7360     SDValue LChain = L->getChain();
7361     if (!Chain)
7362       Chain = LChain;
7363     else if (Chain != LChain)
7364       return SDValue();
7365 
7366     // Loads must share the same base address
7367     BaseIndexOffset Ptr = BaseIndexOffset::match(L, DAG);
7368     int64_t ByteOffsetFromBase = 0;
7369     if (!Base)
7370       Base = Ptr;
7371     else if (!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
7372       return SDValue();
7373 
7374     // Calculate the offset of the current byte from the base address
7375     ByteOffsetFromBase += MemoryByteOffset(*P);
7376     ByteOffsets[i] = ByteOffsetFromBase;
7377 
7378     // Remember the first byte load
7379     if (ByteOffsetFromBase < FirstOffset) {
7380       FirstByteProvider = P;
7381       FirstOffset = ByteOffsetFromBase;
7382     }
7383 
7384     Loads.insert(L);
7385   }
7386   assert(!Loads.empty() && "All the bytes of the value must be loaded from "
7387          "memory, so there must be at least one load which produces the value");
7388   assert(Base && "Base address of the accessed memory location must be set");
7389   assert(FirstOffset != INT64_MAX && "First byte offset must be set");
7390 
7391   bool NeedsZext = ZeroExtendedBytes > 0;
7392 
7393   EVT MemVT =
7394       EVT::getIntegerVT(*DAG.getContext(), (ByteWidth - ZeroExtendedBytes) * 8);
7395 
7396   if (!MemVT.isSimple())
7397     return SDValue();
7398 
7399   // Before legalize we can introduce too-wide illegal loads which will later
7400   // be split into legal-sized loads. This enables us to combine i64-by-i8 load
7401   // patterns into a couple of i32 loads on 32-bit targets.
7402   if (LegalOperations &&
7403       !TLI.isOperationLegal(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD,
7404                             MemVT))
7405     return SDValue();
7406 
7407   // Check if the bytes of the OR we are looking at match either a big- or
7408   // little-endian value load.
7409   Optional<bool> IsBigEndian = isBigEndian(
7410       makeArrayRef(ByteOffsets).drop_back(ZeroExtendedBytes), FirstOffset);
7411   if (!IsBigEndian.hasValue())
7412     return SDValue();
7413 
7414   assert(FirstByteProvider && "must be set");
7415 
7416   // Ensure that the first byte is loaded from zero offset of the first load.
7417   // So the combined value can be loaded from the first load address.
7418   if (MemoryByteOffset(*FirstByteProvider) != 0)
7419     return SDValue();
7420   LoadSDNode *FirstLoad = FirstByteProvider->Load;
7421 
7422   // The node we are looking at matches with the pattern, check if we can
7423   // replace it with a single (possibly zero-extended) load and bswap + shift if
7424   // needed.
7425 
7426   // If the load needs byte swap check if the target supports it
7427   bool NeedsBswap = IsBigEndianTarget != *IsBigEndian;
7428 
7429   // Before legalize we can introduce illegal bswaps which will be later
7430   // converted to an explicit bswap sequence. This way we end up with a single
7431   // load and byte shuffling instead of several loads and byte shuffling.
7432   // We do not introduce illegal bswaps when zero-extending as this tends to
7433   // introduce too many arithmetic instructions.
7434   if (NeedsBswap && (LegalOperations || NeedsZext) &&
7435       !TLI.isOperationLegal(ISD::BSWAP, VT))
7436     return SDValue();
7437 
7438   // If we need to bswap and zero extend, we have to insert a shift. Check that
7439   // it is legal.
7440   if (NeedsBswap && NeedsZext && LegalOperations &&
7441       !TLI.isOperationLegal(ISD::SHL, VT))
7442     return SDValue();
7443 
7444   // Check that a load of the wide type is both allowed and fast on the target
7445   bool Fast = false;
7446   bool Allowed =
7447       TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
7448                              *FirstLoad->getMemOperand(), &Fast);
7449   if (!Allowed || !Fast)
7450     return SDValue();
7451 
7452   SDValue NewLoad =
7453       DAG.getExtLoad(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, SDLoc(N), VT,
7454                      Chain, FirstLoad->getBasePtr(),
7455                      FirstLoad->getPointerInfo(), MemVT, FirstLoad->getAlign());
7456 
7457   // Transfer chain users from old loads to the new load.
7458   for (LoadSDNode *L : Loads)
7459     DAG.ReplaceAllUsesOfValueWith(SDValue(L, 1), SDValue(NewLoad.getNode(), 1));
7460 
7461   if (!NeedsBswap)
7462     return NewLoad;
7463 
7464   SDValue ShiftedLoad =
7465       NeedsZext
7466           ? DAG.getNode(ISD::SHL, SDLoc(N), VT, NewLoad,
7467                         DAG.getShiftAmountConstant(ZeroExtendedBytes * 8, VT,
7468                                                    SDLoc(N), LegalOperations))
7469           : NewLoad;
7470   return DAG.getNode(ISD::BSWAP, SDLoc(N), VT, ShiftedLoad);
7471 }
7472 
7473 // If the target has andn, bsl, or a similar bit-select instruction,
7474 // we want to unfold masked merge, with canonical pattern of:
7475 //   |        A  |  |B|
7476 //   ((x ^ y) & m) ^ y
7477 //    |  D  |
7478 // Into:
7479 //   (x & m) | (y & ~m)
7480 // If y is a constant, and the 'andn' does not work with immediates,
7481 // we unfold into a different pattern:
7482 //   ~(~x & m) & (m | y)
7483 // NOTE: we don't unfold the pattern if 'xor' is actually a 'not', because at
7484 //       the very least that breaks andnpd / andnps patterns, and because those
7485 //       patterns are simplified in IR and shouldn't be created in the DAG
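// A per-bit sanity check of the second pattern: where m == 1,
// ~(~x & 1) & (1 | y) reduces to x; where m == 0, ~0 & (0 | y) reduces to y.
// That is exactly (x & m) | (y & ~m).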
7486 SDValue DAGCombiner::unfoldMaskedMerge(SDNode *N) {
7487   assert(N->getOpcode() == ISD::XOR);
7488 
7489   // Don't touch 'not' (i.e. where y = -1).
7490   if (isAllOnesOrAllOnesSplat(N->getOperand(1)))
7491     return SDValue();
7492 
7493   EVT VT = N->getValueType(0);
7494 
7495   // There are 3 commutable operators in the pattern,
7496   // so we have to deal with 8 possible variants of the basic pattern.
7497   SDValue X, Y, M;
7498   auto matchAndXor = [&X, &Y, &M](SDValue And, unsigned XorIdx, SDValue Other) {
7499     if (And.getOpcode() != ISD::AND || !And.hasOneUse())
7500       return false;
7501     SDValue Xor = And.getOperand(XorIdx);
7502     if (Xor.getOpcode() != ISD::XOR || !Xor.hasOneUse())
7503       return false;
7504     SDValue Xor0 = Xor.getOperand(0);
7505     SDValue Xor1 = Xor.getOperand(1);
7506     // Don't touch 'not' (i.e. where y = -1).
7507     if (isAllOnesOrAllOnesSplat(Xor1))
7508       return false;
7509     if (Other == Xor0)
7510       std::swap(Xor0, Xor1);
7511     if (Other != Xor1)
7512       return false;
7513     X = Xor0;
7514     Y = Xor1;
7515     M = And.getOperand(XorIdx ? 0 : 1);
7516     return true;
7517   };
7518 
7519   SDValue N0 = N->getOperand(0);
7520   SDValue N1 = N->getOperand(1);
7521   if (!matchAndXor(N0, 0, N1) && !matchAndXor(N0, 1, N1) &&
7522       !matchAndXor(N1, 0, N0) && !matchAndXor(N1, 1, N0))
7523     return SDValue();
7524 
7525   // Don't do anything if the mask is constant. This should not be reachable.
7526   // InstCombine should have already unfolded this pattern, and DAGCombiner
7527   // probably shouldn't produce it either.
7528   if (isa<ConstantSDNode>(M.getNode()))
7529     return SDValue();
7530 
7531   // We can transform if the target has AndNot
7532   if (!TLI.hasAndNot(M))
7533     return SDValue();
7534 
7535   SDLoc DL(N);
7536 
7537   // If Y is a constant, check that 'andn' works with immediates.
7538   if (!TLI.hasAndNot(Y)) {
7539     assert(TLI.hasAndNot(X) && "Only mask is a variable? Unreachable.");
7540     // If not, we need to do a bit more work to make sure andn is still used.
7541     SDValue NotX = DAG.getNOT(DL, X, VT);
7542     SDValue LHS = DAG.getNode(ISD::AND, DL, VT, NotX, M);
7543     SDValue NotLHS = DAG.getNOT(DL, LHS, VT);
7544     SDValue RHS = DAG.getNode(ISD::OR, DL, VT, M, Y);
7545     return DAG.getNode(ISD::AND, DL, VT, NotLHS, RHS);
7546   }
7547 
7548   SDValue LHS = DAG.getNode(ISD::AND, DL, VT, X, M);
7549   SDValue NotM = DAG.getNOT(DL, M, VT);
7550   SDValue RHS = DAG.getNode(ISD::AND, DL, VT, Y, NotM);
7551 
7552   return DAG.getNode(ISD::OR, DL, VT, LHS, RHS);
7553 }
7554 
7555 SDValue DAGCombiner::visitXOR(SDNode *N) {
7556   SDValue N0 = N->getOperand(0);
7557   SDValue N1 = N->getOperand(1);
7558   EVT VT = N0.getValueType();
7559 
7560   // fold vector ops
7561   if (VT.isVector()) {
7562     if (SDValue FoldedVOp = SimplifyVBinOp(N))
7563       return FoldedVOp;
7564 
7565     // fold (xor x, 0) -> x, vector edition
7566     if (ISD::isBuildVectorAllZeros(N0.getNode()))
7567       return N1;
7568     if (ISD::isBuildVectorAllZeros(N1.getNode()))
7569       return N0;
7570   }
7571 
7572   // fold (xor undef, undef) -> 0. This is a common idiom (misuse).
7573   SDLoc DL(N);
7574   if (N0.isUndef() && N1.isUndef())
7575     return DAG.getConstant(0, DL, VT);
7576 
7577   // fold (xor x, undef) -> undef
7578   if (N0.isUndef())
7579     return N0;
7580   if (N1.isUndef())
7581     return N1;
7582 
7583   // fold (xor c1, c2) -> c1^c2
7584   if (SDValue C = DAG.FoldConstantArithmetic(ISD::XOR, DL, VT, {N0, N1}))
7585     return C;
7586 
7587   // canonicalize constant to RHS
7588   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
7589      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
7590     return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
7591 
7592   // fold (xor x, 0) -> x
7593   if (isNullConstant(N1))
7594     return N0;
7595 
7596   if (SDValue NewSel = foldBinOpIntoSelect(N))
7597     return NewSel;
7598 
7599   // reassociate xor
7600   if (SDValue RXOR = reassociateOps(ISD::XOR, DL, N0, N1, N->getFlags()))
7601     return RXOR;
7602 
7603   // fold !(x cc y) -> (x !cc y)
7604   unsigned N0Opcode = N0.getOpcode();
7605   SDValue LHS, RHS, CC;
7606   if (TLI.isConstTrueVal(N1.getNode()) &&
7607       isSetCCEquivalent(N0, LHS, RHS, CC, /*MatchStrict*/true)) {
7608     ISD::CondCode NotCC = ISD::getSetCCInverse(cast<CondCodeSDNode>(CC)->get(),
7609                                                LHS.getValueType());
7610     if (!LegalOperations ||
7611         TLI.isCondCodeLegal(NotCC, LHS.getSimpleValueType())) {
7612       switch (N0Opcode) {
7613       default:
7614         llvm_unreachable("Unhandled SetCC Equivalent!");
7615       case ISD::SETCC:
7616         return DAG.getSetCC(SDLoc(N0), VT, LHS, RHS, NotCC);
7617       case ISD::SELECT_CC:
7618         return DAG.getSelectCC(SDLoc(N0), LHS, RHS, N0.getOperand(2),
7619                                N0.getOperand(3), NotCC);
7620       case ISD::STRICT_FSETCC:
7621       case ISD::STRICT_FSETCCS: {
7622         if (N0.hasOneUse()) {
7623           // FIXME Can we handle multiple uses? Could we token factor the chain
7624           // results from the new/old setcc?
7625           SDValue SetCC =
7626               DAG.getSetCC(SDLoc(N0), VT, LHS, RHS, NotCC,
7627                            N0.getOperand(0), N0Opcode == ISD::STRICT_FSETCCS);
7628           CombineTo(N, SetCC);
7629           DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), SetCC.getValue(1));
7630           recursivelyDeleteUnusedNodes(N0.getNode());
7631           return SDValue(N, 0); // Return N so it doesn't get rechecked!
7632         }
7633         break;
7634       }
7635       }
7636     }
7637   }
7638 
7639   // fold (not (zext (setcc x, y))) -> (zext (not (setcc x, y)))
7640   if (isOneConstant(N1) && N0Opcode == ISD::ZERO_EXTEND && N0.hasOneUse() &&
7641       isSetCCEquivalent(N0.getOperand(0), LHS, RHS, CC)){
7642     SDValue V = N0.getOperand(0);
7643     SDLoc DL0(N0);
7644     V = DAG.getNode(ISD::XOR, DL0, V.getValueType(), V,
7645                     DAG.getConstant(1, DL0, V.getValueType()));
7646     AddToWorklist(V.getNode());
7647     return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, V);
7648   }
7649 
7650   // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are setcc
7651   if (isOneConstant(N1) && VT == MVT::i1 && N0.hasOneUse() &&
7652       (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
7653     SDValue N00 = N0.getOperand(0), N01 = N0.getOperand(1);
7654     if (isOneUseSetCC(N01) || isOneUseSetCC(N00)) {
7655       unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
7656       N00 = DAG.getNode(ISD::XOR, SDLoc(N00), VT, N00, N1); // N00 = ~N00
7657       N01 = DAG.getNode(ISD::XOR, SDLoc(N01), VT, N01, N1); // N01 = ~N01
7658       AddToWorklist(N00.getNode()); AddToWorklist(N01.getNode());
7659       return DAG.getNode(NewOpcode, DL, VT, N00, N01);
7660     }
7661   }
7662   // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are constants
7663   if (isAllOnesConstant(N1) && N0.hasOneUse() &&
7664       (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
7665     SDValue N00 = N0.getOperand(0), N01 = N0.getOperand(1);
7666     if (isa<ConstantSDNode>(N01) || isa<ConstantSDNode>(N00)) {
7667       unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
7668       N00 = DAG.getNode(ISD::XOR, SDLoc(N00), VT, N00, N1); // N00 = ~N00
7669       N01 = DAG.getNode(ISD::XOR, SDLoc(N01), VT, N01, N1); // N01 = ~N01
7670       AddToWorklist(N00.getNode()); AddToWorklist(N01.getNode());
7671       return DAG.getNode(NewOpcode, DL, VT, N00, N01);
7672     }
7673   }
7674 
7675   // fold (not (neg x)) -> (add X, -1)
7676   // FIXME: This can be generalized to (not (sub Y, X)) -> (add X, ~Y) if
7677   // Y is a constant or the subtract has a single use.
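  // This holds because ~a == -a - 1, so ~(0 - x) == x - 1 == (add x, -1).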
7678   if (isAllOnesConstant(N1) && N0.getOpcode() == ISD::SUB &&
7679       isNullConstant(N0.getOperand(0))) {
7680     return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1),
7681                        DAG.getAllOnesConstant(DL, VT));
7682   }
7683 
7684   // fold (not (add X, -1)) -> (neg X)
7685   if (isAllOnesConstant(N1) && N0.getOpcode() == ISD::ADD &&
7686       isAllOnesOrAllOnesSplat(N0.getOperand(1))) {
7687     return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
7688                        N0.getOperand(0));
7689   }
7690 
7691   // fold (xor (and x, y), y) -> (and (not x), y)
7692   if (N0Opcode == ISD::AND && N0.hasOneUse() && N0->getOperand(1) == N1) {
7693     SDValue X = N0.getOperand(0);
7694     SDValue NotX = DAG.getNOT(SDLoc(X), X, VT);
7695     AddToWorklist(NotX.getNode());
7696     return DAG.getNode(ISD::AND, DL, VT, NotX, N1);
7697   }
7698 
7699   if ((N0Opcode == ISD::SRL || N0Opcode == ISD::SHL) && N0.hasOneUse()) {
7700     ConstantSDNode *XorC = isConstOrConstSplat(N1);
7701     ConstantSDNode *ShiftC = isConstOrConstSplat(N0.getOperand(1));
7702     unsigned BitWidth = VT.getScalarSizeInBits();
7703     if (XorC && ShiftC) {
7704       // Don't crash on an oversized shift. We cannot guarantee that a bogus
7705       // shift has been simplified to undef.
7706       uint64_t ShiftAmt = ShiftC->getLimitedValue();
7707       if (ShiftAmt < BitWidth) {
7708         APInt Ones = APInt::getAllOnesValue(BitWidth);
7709         Ones = N0Opcode == ISD::SHL ? Ones.shl(ShiftAmt) : Ones.lshr(ShiftAmt);
7710         if (XorC->getAPIntValue() == Ones) {
7711           // If the xor constant is a shifted -1, do a 'not' before the shift:
7712           // xor (X << ShiftC), XorC --> (not X) << ShiftC
7713           // xor (X >> ShiftC), XorC --> (not X) >> ShiftC
7714           SDValue Not = DAG.getNOT(DL, N0.getOperand(0), VT);
7715           return DAG.getNode(N0Opcode, DL, VT, Not, N0.getOperand(1));
7716         }
7717       }
7718     }
7719   }
7720 
7721   // fold Y = sra (X, size(X)-1); xor (add (X, Y), Y) -> (abs X)
7722   if (TLI.isOperationLegalOrCustom(ISD::ABS, VT)) {
7723     SDValue A = N0Opcode == ISD::ADD ? N0 : N1;
7724     SDValue S = N0Opcode == ISD::SRA ? N0 : N1;
7725     if (A.getOpcode() == ISD::ADD && S.getOpcode() == ISD::SRA) {
7726       SDValue A0 = A.getOperand(0), A1 = A.getOperand(1);
7727       SDValue S0 = S.getOperand(0);
7728       if ((A0 == S && A1 == S0) || (A1 == S && A0 == S0))
7729         if (ConstantSDNode *C = isConstOrConstSplat(S.getOperand(1)))
7730           if (C->getAPIntValue() == (VT.getScalarSizeInBits() - 1))
7731             return DAG.getNode(ISD::ABS, DL, VT, S0);
7732     }
7733   }
7734 
7735   // fold (xor x, x) -> 0
7736   if (N0 == N1)
7737     return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
7738 
7739   // fold (xor (shl 1, x), -1) -> (rotl ~1, x)
7740   // Here is a concrete example of this equivalence:
7741   // i16   x ==  14
7742   // i16 shl ==   1 << 14  == 16384 == 0b0100000000000000
7743   // i16 xor == ~(1 << 14) == 49151 == 0b1011111111111111
7744   //
7745   // =>
7746   //
7747   // i16     ~1      == 0b1111111111111110
7748   // i16 rol(~1, 14) == 0b1011111111111111
7749   //
7750   // Some additional tips to help conceptualize this transform:
7751   // - Try to see the operation as placing a single zero in a value of all ones.
7752   // - There exists no value for x which would allow the result to contain zero.
7753   // - Values of x larger than the bitwidth are undefined and do not require a
7754   //   consistent result.
7755   // - Pushing the zero left requires shifting one-bits in from the right.
7756   // A rotate left of ~1 is a nice way of achieving the desired result.
7757   if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT) && N0Opcode == ISD::SHL &&
7758       isAllOnesConstant(N1) && isOneConstant(N0.getOperand(0))) {
7759     return DAG.getNode(ISD::ROTL, DL, VT, DAG.getConstant(~1, DL, VT),
7760                        N0.getOperand(1));
7761   }
7762 
7763   // Simplify: xor (op x...), (op y...)  -> (op (xor x, y))
7764   if (N0Opcode == N1.getOpcode())
7765     if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
7766       return V;
7767 
7768   // Unfold  ((x ^ y) & m) ^ y  into  (x & m) | (y & ~m)  if profitable
7769   if (SDValue MM = unfoldMaskedMerge(N))
7770     return MM;
7771 
7772   // Simplify the expression using non-local knowledge.
7773   if (SimplifyDemandedBits(SDValue(N, 0)))
7774     return SDValue(N, 0);
7775 
7776   if (SDValue Combined = combineCarryDiamond(*this, DAG, TLI, N0, N1, N))
7777     return Combined;
7778 
7779   return SDValue();
7780 }
7781 
7782 /// If we have a shift-by-constant of a bitwise logic op that itself has a
7783 /// shift-by-constant operand with identical opcode, we may be able to convert
7784 /// that into 2 independent shifts followed by the logic op. This is a
7785 /// throughput improvement.
7786 static SDValue combineShiftOfShiftedLogic(SDNode *Shift, SelectionDAG &DAG) {
7787   // Match a one-use bitwise logic op.
7788   SDValue LogicOp = Shift->getOperand(0);
7789   if (!LogicOp.hasOneUse())
7790     return SDValue();
7791 
7792   unsigned LogicOpcode = LogicOp.getOpcode();
7793   if (LogicOpcode != ISD::AND && LogicOpcode != ISD::OR &&
7794       LogicOpcode != ISD::XOR)
7795     return SDValue();
7796 
7797   // Find a matching one-use shift by constant.
7798   unsigned ShiftOpcode = Shift->getOpcode();
7799   SDValue C1 = Shift->getOperand(1);
7800   ConstantSDNode *C1Node = isConstOrConstSplat(C1);
7801   assert(C1Node && "Expected a shift with constant operand");
7802   const APInt &C1Val = C1Node->getAPIntValue();
7803   auto matchFirstShift = [&](SDValue V, SDValue &ShiftOp,
7804                              const APInt *&ShiftAmtVal) {
7805     if (V.getOpcode() != ShiftOpcode || !V.hasOneUse())
7806       return false;
7807 
7808     ConstantSDNode *ShiftCNode = isConstOrConstSplat(V.getOperand(1));
7809     if (!ShiftCNode)
7810       return false;
7811 
7812     // Capture the shifted operand and shift amount value.
7813     ShiftOp = V.getOperand(0);
7814     ShiftAmtVal = &ShiftCNode->getAPIntValue();
7815 
7816     // Shift amount types do not have to match their operand type, so check that
7817     // the constants are the same width.
7818     if (ShiftAmtVal->getBitWidth() != C1Val.getBitWidth())
7819       return false;
7820 
7821     // The fold is not valid if the sum of the shift values exceeds bitwidth.
7822     if ((*ShiftAmtVal + C1Val).uge(V.getScalarValueSizeInBits()))
7823       return false;
7824 
7825     return true;
7826   };
7827 
7828   // Logic ops are commutative, so check each operand for a match.
7829   SDValue X, Y;
7830   const APInt *C0Val;
7831   if (matchFirstShift(LogicOp.getOperand(0), X, C0Val))
7832     Y = LogicOp.getOperand(1);
7833   else if (matchFirstShift(LogicOp.getOperand(1), X, C0Val))
7834     Y = LogicOp.getOperand(0);
7835   else
7836     return SDValue();
7837 
7838   // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1)
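  // For example:
  //   srl (or (srl X, 2), Y), 3 --> or (srl X, 5), (srl Y, 3)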
7839   SDLoc DL(Shift);
7840   EVT VT = Shift->getValueType(0);
7841   EVT ShiftAmtVT = Shift->getOperand(1).getValueType();
7842   SDValue ShiftSumC = DAG.getConstant(*C0Val + C1Val, DL, ShiftAmtVT);
7843   SDValue NewShift1 = DAG.getNode(ShiftOpcode, DL, VT, X, ShiftSumC);
7844   SDValue NewShift2 = DAG.getNode(ShiftOpcode, DL, VT, Y, C1);
7845   return DAG.getNode(LogicOpcode, DL, VT, NewShift1, NewShift2);
7846 }
7847 
7848 /// Handle transforms common to the three shifts, when the shift amount is a
7849 /// constant.
7850 /// We are looking for: (shift being one of shl/sra/srl)
7851 ///   shift (binop X, C0), C1
7852 /// And want to transform into:
7853 ///   binop (shift X, C1), (shift C0, C1)
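/// For example:
///   shl (and X, 7), 2 --> and (shl X, 2), 28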
7854 SDValue DAGCombiner::visitShiftByConstant(SDNode *N) {
7855   assert(isConstOrConstSplat(N->getOperand(1)) && "Expected constant operand");
7856 
7857   // Do not turn a 'not' into a regular xor.
7858   if (isBitwiseNot(N->getOperand(0)))
7859     return SDValue();
7860 
7861   // The inner binop must be one-use, since we want to replace it.
7862   SDValue LHS = N->getOperand(0);
7863   if (!LHS.hasOneUse() || !TLI.isDesirableToCommuteWithShift(N, Level))
7864     return SDValue();
7865 
7866   // TODO: This is limited to early combining because it may reveal regressions
7867   //       otherwise. But since we just checked a target hook to see if this is
7868   //       desirable, that should have filtered out cases where this interferes
7869   //       with some other pattern matching.
7870   if (!LegalTypes)
7871     if (SDValue R = combineShiftOfShiftedLogic(N, DAG))
7872       return R;
7873 
7874   // We want to pull some binops through shifts, so that we have (and (shift))
7875   // instead of (shift (and)), likewise for add, or, xor, etc.  This sort of
7876   // thing happens with address calculations, so it's important to canonicalize
7877   // it.
7878   switch (LHS.getOpcode()) {
7879   default:
7880     return SDValue();
7881   case ISD::OR:
7882   case ISD::XOR:
7883   case ISD::AND:
7884     break;
7885   case ISD::ADD:
7886     if (N->getOpcode() != ISD::SHL)
7887       return SDValue(); // only shl(add) not sr[al](add).
7888     break;
7889   }
7890 
7891   // We require the RHS of the binop to be a constant and not opaque as well.
7892   ConstantSDNode *BinOpCst = getAsNonOpaqueConstant(LHS.getOperand(1));
7893   if (!BinOpCst)
7894     return SDValue();
7895 
7896   // FIXME: disable this unless the input to the binop is a shift by a constant
7897   // or is a copy/select. Enable this in other cases once we figure out when it
7898   // is exactly profitable.
7899   SDValue BinOpLHSVal = LHS.getOperand(0);
7900   bool IsShiftByConstant = (BinOpLHSVal.getOpcode() == ISD::SHL ||
7901                             BinOpLHSVal.getOpcode() == ISD::SRA ||
7902                             BinOpLHSVal.getOpcode() == ISD::SRL) &&
7903                            isa<ConstantSDNode>(BinOpLHSVal.getOperand(1));
7904   bool IsCopyOrSelect = BinOpLHSVal.getOpcode() == ISD::CopyFromReg ||
7905                         BinOpLHSVal.getOpcode() == ISD::SELECT;
7906 
7907   if (!IsShiftByConstant && !IsCopyOrSelect)
7908     return SDValue();
7909 
7910   if (IsCopyOrSelect && N->hasOneUse())
7911     return SDValue();
7912 
7913   // Fold the constants, shifting the binop RHS by the shift amount.
7914   SDLoc DL(N);
7915   EVT VT = N->getValueType(0);
7916   SDValue NewRHS = DAG.getNode(N->getOpcode(), DL, VT, LHS.getOperand(1),
7917                                N->getOperand(1));
7918   assert(isa<ConstantSDNode>(NewRHS) && "Folding was not successful!");
7919 
7920   SDValue NewShift = DAG.getNode(N->getOpcode(), DL, VT, LHS.getOperand(0),
7921                                  N->getOperand(1));
7922   return DAG.getNode(LHS.getOpcode(), DL, VT, NewShift, NewRHS);
7923 }
7924 
7925 SDValue DAGCombiner::distributeTruncateThroughAnd(SDNode *N) {
7926   assert(N->getOpcode() == ISD::TRUNCATE);
7927   assert(N->getOperand(0).getOpcode() == ISD::AND);
7928 
7929   // (truncate:TruncVT (and N00, N01C)) -> (and (truncate:TruncVT N00), TruncC)
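  // For example:
  //   trunc:i16 (and:i32 X, 0x00FF00FF) --> and:i16 (trunc:i16 X), 0x00FF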
7930   EVT TruncVT = N->getValueType(0);
7931   if (N->hasOneUse() && N->getOperand(0).hasOneUse() &&
7932       TLI.isTypeDesirableForOp(ISD::AND, TruncVT)) {
7933     SDValue N01 = N->getOperand(0).getOperand(1);
7934     if (isConstantOrConstantVector(N01, /* NoOpaques */ true)) {
7935       SDLoc DL(N);
7936       SDValue N00 = N->getOperand(0).getOperand(0);
7937       SDValue Trunc00 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N00);
7938       SDValue Trunc01 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N01);
7939       AddToWorklist(Trunc00.getNode());
7940       AddToWorklist(Trunc01.getNode());
7941       return DAG.getNode(ISD::AND, DL, TruncVT, Trunc00, Trunc01);
7942     }
7943   }
7944 
7945   return SDValue();
7946 }
7947 
7948 SDValue DAGCombiner::visitRotate(SDNode *N) {
7949   SDLoc dl(N);
7950   SDValue N0 = N->getOperand(0);
7951   SDValue N1 = N->getOperand(1);
7952   EVT VT = N->getValueType(0);
7953   unsigned Bitsize = VT.getScalarSizeInBits();
7954 
7955   // fold (rot x, 0) -> x
7956   if (isNullOrNullSplat(N1))
7957     return N0;
7958 
7959   // fold (rot x, c) -> x iff (c % BitSize) == 0
7960   if (isPowerOf2_32(Bitsize) && Bitsize > 1) {
7961     APInt ModuloMask(N1.getScalarValueSizeInBits(), Bitsize - 1);
7962     if (DAG.MaskedValueIsZero(N1, ModuloMask))
7963       return N0;
7964   }
7965 
7966   // fold (rot x, c) -> (rot x, c % BitSize)
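  // For example:
  //   rotl i8 X, 11 --> rotl i8 X, 3   (11 urem 8 == 3)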
7967   bool OutOfRange = false;
7968   auto MatchOutOfRange = [Bitsize, &OutOfRange](ConstantSDNode *C) {
7969     OutOfRange |= C->getAPIntValue().uge(Bitsize);
7970     return true;
7971   };
7972   if (ISD::matchUnaryPredicate(N1, MatchOutOfRange) && OutOfRange) {
7973     EVT AmtVT = N1.getValueType();
7974     SDValue Bits = DAG.getConstant(Bitsize, dl, AmtVT);
7975     if (SDValue Amt =
7976             DAG.FoldConstantArithmetic(ISD::UREM, dl, AmtVT, {N1, Bits}))
7977       return DAG.getNode(N->getOpcode(), dl, VT, N0, Amt);
7978   }
7979 
7980   // rot i16 X, 8 --> bswap X
7981   auto *RotAmtC = isConstOrConstSplat(N1);
7982   if (RotAmtC && RotAmtC->getAPIntValue() == 8 &&
7983       VT.getScalarSizeInBits() == 16 && hasOperation(ISD::BSWAP, VT))
7984     return DAG.getNode(ISD::BSWAP, dl, VT, N0);
7985 
7986   // Simplify the operands using demanded-bits information.
7987   if (SimplifyDemandedBits(SDValue(N, 0)))
7988     return SDValue(N, 0);
7989 
7990   // fold (rot* x, (trunc (and y, c))) -> (rot* x, (and (trunc y), (trunc c))).
7991   if (N1.getOpcode() == ISD::TRUNCATE &&
7992       N1.getOperand(0).getOpcode() == ISD::AND) {
7993     if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
7994       return DAG.getNode(N->getOpcode(), dl, VT, N0, NewOp1);
7995   }
7996 
7997   unsigned NextOp = N0.getOpcode();
7998   // fold (rot* (rot* x, c2), c1) -> (rot* x, c1 +- c2 % bitsize)
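  // For example:
  //   rotl i8 (rotl i8 X, 3), 4 --> rotl i8 X, 7
  //   rotl i8 (rotr i8 X, 2), 5 --> rotl i8 X, 3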
7999   if (NextOp == ISD::ROTL || NextOp == ISD::ROTR) {
8000     SDNode *C1 = DAG.isConstantIntBuildVectorOrConstantInt(N1);
8001     SDNode *C2 = DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1));
8002     if (C1 && C2 && C1->getValueType(0) == C2->getValueType(0)) {
8003       EVT ShiftVT = C1->getValueType(0);
8004       bool SameSide = (N->getOpcode() == NextOp);
8005       unsigned CombineOp = SameSide ? ISD::ADD : ISD::SUB;
8006       if (SDValue CombinedShift = DAG.FoldConstantArithmetic(
8007               CombineOp, dl, ShiftVT, {N1, N0.getOperand(1)})) {
8008         SDValue BitsizeC = DAG.getConstant(Bitsize, dl, ShiftVT);
8009         SDValue CombinedShiftNorm = DAG.FoldConstantArithmetic(
8010             ISD::SREM, dl, ShiftVT, {CombinedShift, BitsizeC});
8011         return DAG.getNode(N->getOpcode(), dl, VT, N0->getOperand(0),
8012                            CombinedShiftNorm);
8013       }
8014     }
8015   }
8016   return SDValue();
8017 }
8018 
8019 SDValue DAGCombiner::visitSHL(SDNode *N) {
8020   SDValue N0 = N->getOperand(0);
8021   SDValue N1 = N->getOperand(1);
8022   if (SDValue V = DAG.simplifyShift(N0, N1))
8023     return V;
8024 
8025   EVT VT = N0.getValueType();
8026   EVT ShiftVT = N1.getValueType();
8027   unsigned OpSizeInBits = VT.getScalarSizeInBits();
8028 
8029   // fold vector ops
8030   if (VT.isVector()) {
8031     if (SDValue FoldedVOp = SimplifyVBinOp(N))
8032       return FoldedVOp;
8033 
8034     BuildVectorSDNode *N1CV = dyn_cast<BuildVectorSDNode>(N1);
8035     // If setcc produces an all-ones true value then:
8036     // (shl (and (setcc) N01CV) N1CV) -> (and (setcc) N01CV<<N1CV)
8037     if (N1CV && N1CV->isConstant()) {
8038       if (N0.getOpcode() == ISD::AND) {
8039         SDValue N00 = N0->getOperand(0);
8040         SDValue N01 = N0->getOperand(1);
8041         BuildVectorSDNode *N01CV = dyn_cast<BuildVectorSDNode>(N01);
8042 
8043         if (N01CV && N01CV->isConstant() && N00.getOpcode() == ISD::SETCC &&
8044             TLI.getBooleanContents(N00.getOperand(0).getValueType()) ==
8045                 TargetLowering::ZeroOrNegativeOneBooleanContent) {
8046           if (SDValue C =
8047                   DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N), VT, {N01, N1}))
8048             return DAG.getNode(ISD::AND, SDLoc(N), VT, N00, C);
8049         }
8050       }
8051     }
8052   }
8053 
8054   ConstantSDNode *N1C = isConstOrConstSplat(N1);
8055 
8056   // fold (shl c1, c2) -> c1<<c2
8057   if (SDValue C = DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N), VT, {N0, N1}))
8058     return C;
8059 
8060   if (SDValue NewSel = foldBinOpIntoSelect(N))
8061     return NewSel;
8062 
8063   // if (shl x, c) is known to be zero, return 0
8064   if (DAG.MaskedValueIsZero(SDValue(N, 0),
8065                             APInt::getAllOnesValue(OpSizeInBits)))
8066     return DAG.getConstant(0, SDLoc(N), VT);
8067 
8068   // fold (shl x, (trunc (and y, c))) -> (shl x, (and (trunc y), (trunc c))).
8069   if (N1.getOpcode() == ISD::TRUNCATE &&
8070       N1.getOperand(0).getOpcode() == ISD::AND) {
8071     if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
8072       return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, NewOp1);
8073   }
8074 
8075   if (SimplifyDemandedBits(SDValue(N, 0)))
8076     return SDValue(N, 0);
8077 
8078   // fold (shl (shl x, c1), c2) -> 0 or (shl x, (add c1, c2))
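  // For example:
  //   shl i8 (shl i8 X, 3), 4 --> shl i8 X, 7
  //   shl i8 (shl i8 X, 5), 4 --> 0   (5 + 4 >= 8)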
8079   if (N0.getOpcode() == ISD::SHL) {
8080     auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
8081                                           ConstantSDNode *RHS) {
8082       APInt c1 = LHS->getAPIntValue();
8083       APInt c2 = RHS->getAPIntValue();
8084       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
8085       return (c1 + c2).uge(OpSizeInBits);
8086     };
8087     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
8088       return DAG.getConstant(0, SDLoc(N), VT);
8089 
8090     auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
8091                                        ConstantSDNode *RHS) {
8092       APInt c1 = LHS->getAPIntValue();
8093       APInt c2 = RHS->getAPIntValue();
8094       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
8095       return (c1 + c2).ult(OpSizeInBits);
8096     };
8097     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
8098       SDLoc DL(N);
8099       SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
8100       return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Sum);
8101     }
8102   }
8103 
8104   // fold (shl (ext (shl x, c1)), c2) -> (shl (ext x), (add c1, c2))
8105   // For this to be valid, the second form must not preserve any of the bits
8106   // that are shifted out by the inner shift in the first form.  This means
8107   // the outer shift size must be >= the number of bits added by the ext.
8108   // As a corollary, we don't care what kind of ext it is.
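  // For example (i8 zext to i32, where c2 >= 24 covers the extended bits):
  //   shl (zext (shl X, 3)), 26 --> shl (zext X), 29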
8109   if ((N0.getOpcode() == ISD::ZERO_EXTEND ||
8110        N0.getOpcode() == ISD::ANY_EXTEND ||
8111        N0.getOpcode() == ISD::SIGN_EXTEND) &&
8112       N0.getOperand(0).getOpcode() == ISD::SHL) {
8113     SDValue N0Op0 = N0.getOperand(0);
8114     SDValue InnerShiftAmt = N0Op0.getOperand(1);
8115     EVT InnerVT = N0Op0.getValueType();
8116     uint64_t InnerBitwidth = InnerVT.getScalarSizeInBits();
8117 
8118     auto MatchOutOfRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
8119                                                          ConstantSDNode *RHS) {
8120       APInt c1 = LHS->getAPIntValue();
8121       APInt c2 = RHS->getAPIntValue();
8122       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
8123       return c2.uge(OpSizeInBits - InnerBitwidth) &&
8124              (c1 + c2).uge(OpSizeInBits);
8125     };
8126     if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchOutOfRange,
8127                                   /*AllowUndefs*/ false,
8128                                   /*AllowTypeMismatch*/ true))
8129       return DAG.getConstant(0, SDLoc(N), VT);
8130 
8131     auto MatchInRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
8132                                                       ConstantSDNode *RHS) {
8133       APInt c1 = LHS->getAPIntValue();
8134       APInt c2 = RHS->getAPIntValue();
8135       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
8136       return c2.uge(OpSizeInBits - InnerBitwidth) &&
8137              (c1 + c2).ult(OpSizeInBits);
8138     };
8139     if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchInRange,
8140                                   /*AllowUndefs*/ false,
8141                                   /*AllowTypeMismatch*/ true)) {
8142       SDLoc DL(N);
8143       SDValue Ext = DAG.getNode(N0.getOpcode(), DL, VT, N0Op0.getOperand(0));
8144       SDValue Sum = DAG.getZExtOrTrunc(InnerShiftAmt, DL, ShiftVT);
8145       Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, Sum, N1);
8146       return DAG.getNode(ISD::SHL, DL, VT, Ext, Sum);
8147     }
8148   }
8149 
8150   // fold (shl (zext (srl x, C)), C) -> (zext (shl (srl x, C), C))
8151   // Only fold this if the inner zext has no other uses to avoid increasing
8152   // the total number of instructions.
8153   if (N0.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse() &&
8154       N0.getOperand(0).getOpcode() == ISD::SRL) {
8155     SDValue N0Op0 = N0.getOperand(0);
8156     SDValue InnerShiftAmt = N0Op0.getOperand(1);
8157 
8158     auto MatchEqual = [VT](ConstantSDNode *LHS, ConstantSDNode *RHS) {
8159       APInt c1 = LHS->getAPIntValue();
8160       APInt c2 = RHS->getAPIntValue();
8161       zeroExtendToMatch(c1, c2);
8162       return c1.ult(VT.getScalarSizeInBits()) && (c1 == c2);
8163     };
8164     if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchEqual,
8165                                   /*AllowUndefs*/ false,
8166                                   /*AllowTypeMismatch*/ true)) {
8167       SDLoc DL(N);
8168       EVT InnerShiftAmtVT = N0Op0.getOperand(1).getValueType();
8169       SDValue NewSHL = DAG.getZExtOrTrunc(N1, DL, InnerShiftAmtVT);
8170       NewSHL = DAG.getNode(ISD::SHL, DL, N0Op0.getValueType(), N0Op0, NewSHL);
8171       AddToWorklist(NewSHL.getNode());
8172       return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N0), VT, NewSHL);
8173     }
8174   }
8175 
8176   // fold (shl (sr[la] exact X,  C1), C2) -> (shl    X, (C2-C1)) if C1 <= C2
8177   // fold (shl (sr[la] exact X,  C1), C2) -> (sr[la] X, (C2-C1)) if C1  > C2
8178   // TODO - support non-uniform vector shift amounts.
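  // For example, with the exact flag on the inner shift:
  //   shl (srl exact X, 2), 5 --> shl X, 3
  //   shl (srl exact X, 5), 2 --> srl exact X, 3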
8179   if (N1C && (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SRA) &&
8180       N0->getFlags().hasExact()) {
8181     if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) {
8182       uint64_t C1 = N0C1->getZExtValue();
8183       uint64_t C2 = N1C->getZExtValue();
8184       SDLoc DL(N);
8185       if (C1 <= C2)
8186         return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0),
8187                            DAG.getConstant(C2 - C1, DL, ShiftVT));
8188       return DAG.getNode(N0.getOpcode(), DL, VT, N0.getOperand(0),
8189                          DAG.getConstant(C1 - C2, DL, ShiftVT));
8190     }
8191   }
8192 
8193   // fold (shl (srl x, c1), c2) -> (and (shl x, (sub c2, c1)), MASK) or
8194   //                               (and (srl x, (sub c1, c2)), MASK)
8195   // Only fold this if the inner shift has no other uses -- if it does, folding
8196   // this will increase the total number of instructions.
8197   // TODO - drop hasOneUse requirement if c1 == c2?
8198   // TODO - support non-uniform vector shift amounts.
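  // For example:
  //   shl i8 (srl i8 X, 3), 5 --> and i8 (shl i8 X, 2), 0xE0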
8199   if (N1C && N0.getOpcode() == ISD::SRL && N0.hasOneUse() &&
8200       TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
8201     if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) {
8202       if (N0C1->getAPIntValue().ult(OpSizeInBits)) {
8203         uint64_t c1 = N0C1->getZExtValue();
8204         uint64_t c2 = N1C->getZExtValue();
8205         APInt Mask = APInt::getHighBitsSet(OpSizeInBits, OpSizeInBits - c1);
8206         SDValue Shift;
8207         if (c2 > c1) {
8208           Mask <<= c2 - c1;
8209           SDLoc DL(N);
8210           Shift = DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0),
8211                               DAG.getConstant(c2 - c1, DL, ShiftVT));
8212         } else {
8213           Mask.lshrInPlace(c1 - c2);
8214           SDLoc DL(N);
8215           Shift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0),
8216                               DAG.getConstant(c1 - c2, DL, ShiftVT));
8217         }
8218         SDLoc DL(N0);
8219         return DAG.getNode(ISD::AND, DL, VT, Shift,
8220                            DAG.getConstant(Mask, DL, VT));
8221       }
8222     }
8223   }
8224 
8225   // fold (shl (sra x, c1), c1) -> (and x, (shl -1, c1))
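  // For example:
  //   shl i8 (sra i8 X, 3), 3 --> and i8 X, 0xF8   (0xF8 == -1 << 3)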
8226   if (N0.getOpcode() == ISD::SRA && N1 == N0.getOperand(1) &&
8227       isConstantOrConstantVector(N1, /* No Opaques */ true)) {
8228     SDLoc DL(N);
8229     SDValue AllBits = DAG.getAllOnesConstant(DL, VT);
8230     SDValue HiBitsMask = DAG.getNode(ISD::SHL, DL, VT, AllBits, N1);
8231     return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), HiBitsMask);
8232   }
8233 
8234   // fold (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2)
8235   // fold (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2)
8236   // This is a variant of the version done on multiply, except mul by a power
8237   // of 2 is turned into a shift.
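  // For example:
  //   shl (add X, 5), 3 --> add (shl X, 3), 40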
8238   if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::OR) &&
8239       N0.getNode()->hasOneUse() &&
8240       isConstantOrConstantVector(N1, /* No Opaques */ true) &&
8241       isConstantOrConstantVector(N0.getOperand(1), /* No Opaques */ true) &&
8242       TLI.isDesirableToCommuteWithShift(N, Level)) {
8243     SDValue Shl0 = DAG.getNode(ISD::SHL, SDLoc(N0), VT, N0.getOperand(0), N1);
8244     SDValue Shl1 = DAG.getNode(ISD::SHL, SDLoc(N1), VT, N0.getOperand(1), N1);
8245     AddToWorklist(Shl0.getNode());
8246     AddToWorklist(Shl1.getNode());
8247     return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, Shl0, Shl1);
8248   }
8249 
8250   // fold (shl (mul x, c1), c2) -> (mul x, c1 << c2)
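  // For example:
  //   shl (mul X, 5), 2 --> mul X, 20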
8251   if (N0.getOpcode() == ISD::MUL && N0.getNode()->hasOneUse() &&
8252       isConstantOrConstantVector(N1, /* No Opaques */ true) &&
8253       isConstantOrConstantVector(N0.getOperand(1), /* No Opaques */ true)) {
8254     SDValue Shl = DAG.getNode(ISD::SHL, SDLoc(N1), VT, N0.getOperand(1), N1);
8255     if (isConstantOrConstantVector(Shl))
8256       return DAG.getNode(ISD::MUL, SDLoc(N), VT, N0.getOperand(0), Shl);
8257   }
8258 
8259   if (N1C && !N1C->isOpaque())
8260     if (SDValue NewSHL = visitShiftByConstant(N))
8261       return NewSHL;
8262 
8263   // Fold (shl (vscale * C0), C1) to (vscale * (C0 << C1)).
8264   if (N0.getOpcode() == ISD::VSCALE)
8265     if (ConstantSDNode *NC1 = isConstOrConstSplat(N->getOperand(1))) {
8266       const APInt &C0 = N0.getConstantOperandAPInt(0);
8267       const APInt &C1 = NC1->getAPIntValue();
8268       return DAG.getVScale(SDLoc(N), VT, C0 << C1);
8269     }
8270 
8271   return SDValue();
8272 }
8273 
8274 // Transform a right shift of a multiply into a multiply-high.
8275 // Examples:
8276 // (srl (mul (zext i32:$a to i64), (zext i32:$b to i64)), 32) -> (mulhu $a, $b)
8277 // (sra (mul (sext i32:$a to i64), (sext i32:$b to i64)), 32) -> (mulhs $a, $b)
8278 static SDValue combineShiftToMULH(SDNode *N, SelectionDAG &DAG,
8279                                   const TargetLowering &TLI) {
8280   assert((N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) &&
8281          "SRL or SRA node is required here!");
8282 
8283   // Check the shift amount. Proceed with the transformation if the shift
8284   // amount is constant.
8285   ConstantSDNode *ShiftAmtSrc = isConstOrConstSplat(N->getOperand(1));
8286   if (!ShiftAmtSrc)
8287     return SDValue();
8288 
8289   SDLoc DL(N);
8290 
8291   // The operation feeding into the shift must be a multiply.
8292   SDValue ShiftOperand = N->getOperand(0);
8293   if (ShiftOperand.getOpcode() != ISD::MUL)
8294     return SDValue();
8295 
8296   // Both operands must be equivalent extend nodes.
8297   SDValue LeftOp = ShiftOperand.getOperand(0);
8298   SDValue RightOp = ShiftOperand.getOperand(1);
8299   bool IsSignExt = LeftOp.getOpcode() == ISD::SIGN_EXTEND;
8300   bool IsZeroExt = LeftOp.getOpcode() == ISD::ZERO_EXTEND;
8301 
8302   if ((!(IsSignExt || IsZeroExt)) || LeftOp.getOpcode() != RightOp.getOpcode())
8303     return SDValue();
8304 
8305   EVT WideVT1 = LeftOp.getValueType();
8306   EVT WideVT2 = RightOp.getValueType();
8307   (void)WideVT2;
8308   // Proceed with the transformation if the wide types match.
8309   assert((WideVT1 == WideVT2) &&
8310          "Cannot have a multiply node with two different operand types.");
8311 
8312   EVT NarrowVT = LeftOp.getOperand(0).getValueType();
8313   // Check that the two extend nodes are the same type.
8314   if (NarrowVT !=  RightOp.getOperand(0).getValueType())
8315     return SDValue();
8316 
8317   // Proceed with the transformation if the wide type is twice as large
8318   // as the narrow type.
8319   unsigned NarrowVTSize = NarrowVT.getScalarSizeInBits();
8320   if (WideVT1.getScalarSizeInBits() != 2 * NarrowVTSize)
8321     return SDValue();
8322 
8323   // Check the shift amount with the narrow type size.
8324   // Proceed with the transformation if the shift amount is the width
8325   // of the narrow type.
8326   unsigned ShiftAmt = ShiftAmtSrc->getZExtValue();
8327   if (ShiftAmt != NarrowVTSize)
8328     return SDValue();
8329 
8330   // If the operation feeding into the MUL is a sign extend (sext),
8331   // we use mulhs. Otherwise, zero extends (zext) use mulhu.
8332   unsigned MulhOpcode = IsSignExt ? ISD::MULHS : ISD::MULHU;
8333 
8334   // Combine to mulh if mulh is legal/custom for the narrow type on the target.
8335   if (!TLI.isOperationLegalOrCustom(MulhOpcode, NarrowVT))
8336     return SDValue();
8337 
8338   SDValue Result = DAG.getNode(MulhOpcode, DL, NarrowVT, LeftOp.getOperand(0),
8339                                RightOp.getOperand(0));
8340   return (N->getOpcode() == ISD::SRA ? DAG.getSExtOrTrunc(Result, DL, WideVT1)
8341                                      : DAG.getZExtOrTrunc(Result, DL, WideVT1));
8342 }
8343 
8344 SDValue DAGCombiner::visitSRA(SDNode *N) {
8345   SDValue N0 = N->getOperand(0);
8346   SDValue N1 = N->getOperand(1);
8347   if (SDValue V = DAG.simplifyShift(N0, N1))
8348     return V;
8349 
8350   EVT VT = N0.getValueType();
8351   unsigned OpSizeInBits = VT.getScalarSizeInBits();
8352 
8353   // Arithmetic shifting an all-sign-bit value is a no-op.
8354   // fold (sra 0, x) -> 0
8355   // fold (sra -1, x) -> -1
8356   if (DAG.ComputeNumSignBits(N0) == OpSizeInBits)
8357     return N0;
8358 
8359   // fold vector ops
8360   if (VT.isVector())
8361     if (SDValue FoldedVOp = SimplifyVBinOp(N))
8362       return FoldedVOp;
8363 
8364   ConstantSDNode *N1C = isConstOrConstSplat(N1);
8365 
8366   // fold (sra c1, c2) -> c1 >>s c2
8367   if (SDValue C = DAG.FoldConstantArithmetic(ISD::SRA, SDLoc(N), VT, {N0, N1}))
8368     return C;
8369 
8370   if (SDValue NewSel = foldBinOpIntoSelect(N))
8371     return NewSel;
8372 
8373   // fold (sra (shl x, c1), c1) -> sext_inreg for some c1 and target supports
8374   // sext_inreg.
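  // For example:
  //   sra i32 (shl i32 X, 24), 24 --> sext_inreg X, i8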
8375   if (N1C && N0.getOpcode() == ISD::SHL && N1 == N0.getOperand(1)) {
8376     unsigned LowBits = OpSizeInBits - (unsigned)N1C->getZExtValue();
8377     EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), LowBits);
8378     if (VT.isVector())
8379       ExtVT = EVT::getVectorVT(*DAG.getContext(),
8380                                ExtVT, VT.getVectorNumElements());
8381     if (!LegalOperations ||
8382         TLI.getOperationAction(ISD::SIGN_EXTEND_INREG, ExtVT) ==
8383         TargetLowering::Legal)
8384       return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT,
8385                          N0.getOperand(0), DAG.getValueType(ExtVT));
8386   }
8387 
8388   // fold (sra (sra x, c1), c2) -> (sra x, (add c1, c2))
8389   // clamp (add c1, c2) to max shift.
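  // For example:
  //   sra i8 (sra i8 X, 2), 3 --> sra i8 X, 5
  //   sra i8 (sra i8 X, 6), 5 --> sra i8 X, 7   (clamped to bitwidth - 1)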
8390   if (N0.getOpcode() == ISD::SRA) {
8391     SDLoc DL(N);
8392     EVT ShiftVT = N1.getValueType();
8393     EVT ShiftSVT = ShiftVT.getScalarType();
8394     SmallVector<SDValue, 16> ShiftValues;
8395 
8396     auto SumOfShifts = [&](ConstantSDNode *LHS, ConstantSDNode *RHS) {
8397       APInt c1 = LHS->getAPIntValue();
8398       APInt c2 = RHS->getAPIntValue();
8399       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
8400       APInt Sum = c1 + c2;
8401       unsigned ShiftSum =
8402           Sum.uge(OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue();
8403       ShiftValues.push_back(DAG.getConstant(ShiftSum, DL, ShiftSVT));
8404       return true;
8405     };
8406     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), SumOfShifts)) {
8407       SDValue ShiftValue;
8408       if (VT.isVector())
8409         ShiftValue = DAG.getBuildVector(ShiftVT, DL, ShiftValues);
8410       else
8411         ShiftValue = ShiftValues[0];
8412       return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0), ShiftValue);
8413     }
8414   }
8415 
8416   // fold (sra (shl X, m), (sub result_size, n))
8417   // -> (sign_extend (trunc (srl X, (sub (sub result_size, n), m)))) for
8418   // result_size - n != m.
8419   // If truncate is free for the target, this is likely to result in better
8420   // code.
8421   if (N0.getOpcode() == ISD::SHL && N1C) {
8422     // Get the two constants of the shifts, CN0 = m, CN = n.
8423     const ConstantSDNode *N01C = isConstOrConstSplat(N0.getOperand(1));
8424     if (N01C) {
8425       LLVMContext &Ctx = *DAG.getContext();
8426       // Determine what the truncate's result bitsize and type would be.
8427       EVT TruncVT = EVT::getIntegerVT(Ctx, OpSizeInBits - N1C->getZExtValue());
8428 
8429       if (VT.isVector())
8430         TruncVT = EVT::getVectorVT(Ctx, TruncVT, VT.getVectorNumElements());
8431 
8432       // Determine the residual right-shift amount.
8433       int ShiftAmt = N1C->getZExtValue() - N01C->getZExtValue();
8434 
8435       // If the shift is not a no-op (in which case this should be just a
8436       // sign extend already), sign_extend is legal on the truncated type,
8437       // and the truncate to that type is both legal and free, perform the
8438       // transform.
8439       if ((ShiftAmt > 0) &&
8440           TLI.isOperationLegalOrCustom(ISD::SIGN_EXTEND, TruncVT) &&
8441           TLI.isOperationLegalOrCustom(ISD::TRUNCATE, VT) &&
8442           TLI.isTruncateFree(VT, TruncVT)) {
8443         SDLoc DL(N);
8444         SDValue Amt = DAG.getConstant(ShiftAmt, DL,
8445             getShiftAmountTy(N0.getOperand(0).getValueType()));
8446         SDValue Shift = DAG.getNode(ISD::SRL, DL, VT,
8447                                     N0.getOperand(0), Amt);
8448         SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, TruncVT,
8449                                     Shift);
8450         return DAG.getNode(ISD::SIGN_EXTEND, DL,
8451                            N->getValueType(0), Trunc);
8452       }
8453     }
8454   }
8455 
8456   // We convert trunc/ext to opposing shifts in IR, but casts may be cheaper.
8457   //   sra (add (shl X, N1C), AddC), N1C -->
8458   //   sext (add (trunc X to (width - N1C)), AddC')
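  // For example, with N1C == 24 on i32:
  //   sra (add (shl X, 24), AddC), 24 --> sext (add (trunc:i8 X), (AddC u>> 24))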
8459   if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() && N1C &&
8460       N0.getOperand(0).getOpcode() == ISD::SHL &&
8461       N0.getOperand(0).getOperand(1) == N1 && N0.getOperand(0).hasOneUse()) {
8462     if (ConstantSDNode *AddC = isConstOrConstSplat(N0.getOperand(1))) {
8463       SDValue Shl = N0.getOperand(0);
8464       // Determine what the truncate's type would be and ask the target if that
8465       // is a free operation.
8466       LLVMContext &Ctx = *DAG.getContext();
8467       unsigned ShiftAmt = N1C->getZExtValue();
8468       EVT TruncVT = EVT::getIntegerVT(Ctx, OpSizeInBits - ShiftAmt);
8469       if (VT.isVector())
8470         TruncVT = EVT::getVectorVT(Ctx, TruncVT, VT.getVectorNumElements());
8471 
8472       // TODO: The simple type check probably belongs in the default hook
8473       //       implementation and/or target-specific overrides (because
8474       //       non-simple types likely require masking when legalized), but that
8475       //       restriction may conflict with other transforms.
8476       if (TruncVT.isSimple() && isTypeLegal(TruncVT) &&
8477           TLI.isTruncateFree(VT, TruncVT)) {
8478         SDLoc DL(N);
8479         SDValue Trunc = DAG.getZExtOrTrunc(Shl.getOperand(0), DL, TruncVT);
8480         SDValue ShiftC = DAG.getConstant(AddC->getAPIntValue().lshr(ShiftAmt).
8481                              trunc(TruncVT.getScalarSizeInBits()), DL, TruncVT);
8482         SDValue Add = DAG.getNode(ISD::ADD, DL, TruncVT, Trunc, ShiftC);
8483         return DAG.getSExtOrTrunc(Add, DL, VT);
8484       }
8485     }
8486   }
8487 
8488   // fold (sra x, (trunc (and y, c))) -> (sra x, (and (trunc y), (trunc c))).
8489   if (N1.getOpcode() == ISD::TRUNCATE &&
8490       N1.getOperand(0).getOpcode() == ISD::AND) {
8491     if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
8492       return DAG.getNode(ISD::SRA, SDLoc(N), VT, N0, NewOp1);
8493   }
8494 
8495   // fold (sra (trunc (sra x, c1)), c2) -> (trunc (sra x, c1 + c2))
8496   // fold (sra (trunc (srl x, c1)), c2) -> (trunc (sra x, c1 + c2))
8497   //      if c1 is equal to the number of bits the trunc removes
8498   // TODO - support non-uniform vector shift amounts.
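  // For example, truncating from i64 to i32 (c1 == 32):
  //   sra (trunc:i32 (sra X, 32)), 5 --> trunc:i32 (sra X, 37)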
8499   if (N0.getOpcode() == ISD::TRUNCATE &&
8500       (N0.getOperand(0).getOpcode() == ISD::SRL ||
8501        N0.getOperand(0).getOpcode() == ISD::SRA) &&
8502       N0.getOperand(0).hasOneUse() &&
8503       N0.getOperand(0).getOperand(1).hasOneUse() && N1C) {
8504     SDValue N0Op0 = N0.getOperand(0);
8505     if (ConstantSDNode *LargeShift = isConstOrConstSplat(N0Op0.getOperand(1))) {
8506       EVT LargeVT = N0Op0.getValueType();
8507       unsigned TruncBits = LargeVT.getScalarSizeInBits() - OpSizeInBits;
8508       if (LargeShift->getAPIntValue() == TruncBits) {
8509         SDLoc DL(N);
8510         SDValue Amt = DAG.getConstant(N1C->getZExtValue() + TruncBits, DL,
8511                                       getShiftAmountTy(LargeVT));
8512         SDValue SRA =
8513             DAG.getNode(ISD::SRA, DL, LargeVT, N0Op0.getOperand(0), Amt);
8514         return DAG.getNode(ISD::TRUNCATE, DL, VT, SRA);
8515       }
8516     }
8517   }
8518 
8519   // Simplify, based on bits shifted out of the LHS.
8520   if (SimplifyDemandedBits(SDValue(N, 0)))
8521     return SDValue(N, 0);
8522 
8523   // If the sign bit is known to be zero, switch this to a SRL.
8524   if (DAG.SignBitIsZero(N0))
8525     return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0, N1);
8526 
8527   if (N1C && !N1C->isOpaque())
8528     if (SDValue NewSRA = visitShiftByConstant(N))
8529       return NewSRA;
8530 
8531   // Try to transform this shift into a multiply-high if
8532   // it matches the appropriate pattern detected in combineShiftToMULH.
8533   if (SDValue MULH = combineShiftToMULH(N, DAG, TLI))
8534     return MULH;
8535 
8536   return SDValue();
8537 }
8538 
8539 SDValue DAGCombiner::visitSRL(SDNode *N) {
8540   SDValue N0 = N->getOperand(0);
8541   SDValue N1 = N->getOperand(1);
8542   if (SDValue V = DAG.simplifyShift(N0, N1))
8543     return V;
8544 
8545   EVT VT = N0.getValueType();
8546   unsigned OpSizeInBits = VT.getScalarSizeInBits();
8547 
8548   // fold vector ops
8549   if (VT.isVector())
8550     if (SDValue FoldedVOp = SimplifyVBinOp(N))
8551       return FoldedVOp;
8552 
8553   ConstantSDNode *N1C = isConstOrConstSplat(N1);
8554 
8555   // fold (srl c1, c2) -> c1 >>u c2
8556   if (SDValue C = DAG.FoldConstantArithmetic(ISD::SRL, SDLoc(N), VT, {N0, N1}))
8557     return C;
8558 
8559   if (SDValue NewSel = foldBinOpIntoSelect(N))
8560     return NewSel;
8561 
8562   // if (srl x, c) is known to be zero, return 0
8563   if (N1C && DAG.MaskedValueIsZero(SDValue(N, 0),
8564                                    APInt::getAllOnesValue(OpSizeInBits)))
8565     return DAG.getConstant(0, SDLoc(N), VT);
8566 
8567   // fold (srl (srl x, c1), c2) -> 0 or (srl x, (add c1, c2))
8568   if (N0.getOpcode() == ISD::SRL) {
8569     auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
8570                                           ConstantSDNode *RHS) {
8571       APInt c1 = LHS->getAPIntValue();
8572       APInt c2 = RHS->getAPIntValue();
8573       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
8574       return (c1 + c2).uge(OpSizeInBits);
8575     };
8576     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
8577       return DAG.getConstant(0, SDLoc(N), VT);
8578 
8579     auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
8580                                        ConstantSDNode *RHS) {
8581       APInt c1 = LHS->getAPIntValue();
8582       APInt c2 = RHS->getAPIntValue();
8583       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
8584       return (c1 + c2).ult(OpSizeInBits);
8585     };
8586     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
8587       SDLoc DL(N);
8588       EVT ShiftVT = N1.getValueType();
8589       SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
8590       return DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Sum);
8591     }
8592   }
8593 
8594   if (N1C && N0.getOpcode() == ISD::TRUNCATE &&
8595       N0.getOperand(0).getOpcode() == ISD::SRL) {
8596     SDValue InnerShift = N0.getOperand(0);
8597     // TODO - support non-uniform vector shift amounts.
8598     if (auto *N001C = isConstOrConstSplat(InnerShift.getOperand(1))) {
8599       uint64_t c1 = N001C->getZExtValue();
8600       uint64_t c2 = N1C->getZExtValue();
8601       EVT InnerShiftVT = InnerShift.getValueType();
8602       EVT ShiftAmtVT = InnerShift.getOperand(1).getValueType();
8603       uint64_t InnerShiftSize = InnerShiftVT.getScalarSizeInBits();
8604       // srl (trunc (srl x, c1)), c2 --> 0 or (trunc (srl x, (add c1, c2)))
8605       // This is only valid if the OpSizeInBits + c1 = size of inner shift.
8606       if (c1 + OpSizeInBits == InnerShiftSize) {
8607         SDLoc DL(N);
8608         if (c1 + c2 >= InnerShiftSize)
8609           return DAG.getConstant(0, DL, VT);
8610         SDValue NewShiftAmt = DAG.getConstant(c1 + c2, DL, ShiftAmtVT);
8611         SDValue NewShift = DAG.getNode(ISD::SRL, DL, InnerShiftVT,
8612                                        InnerShift.getOperand(0), NewShiftAmt);
8613         return DAG.getNode(ISD::TRUNCATE, DL, VT, NewShift);
8614       }
8615       // In the more general case, we can clear the high bits after the shift:
8616       // srl (trunc (srl x, c1)), c2 --> trunc (and (srl x, (c1+c2)), Mask)
8617       if (N0.hasOneUse() && InnerShift.hasOneUse() &&
8618           c1 + c2 < InnerShiftSize) {
8619         SDLoc DL(N);
8620         SDValue NewShiftAmt = DAG.getConstant(c1 + c2, DL, ShiftAmtVT);
8621         SDValue NewShift = DAG.getNode(ISD::SRL, DL, InnerShiftVT,
8622                                        InnerShift.getOperand(0), NewShiftAmt);
8623         SDValue Mask = DAG.getConstant(APInt::getLowBitsSet(InnerShiftSize,
8624                                                             OpSizeInBits - c2),
8625                                        DL, InnerShiftVT);
8626         SDValue And = DAG.getNode(ISD::AND, DL, InnerShiftVT, NewShift, Mask);
8627         return DAG.getNode(ISD::TRUNCATE, DL, VT, And);
8628       }
8629     }
8630   }
8631 
8632   // fold (srl (shl x, c), c) -> (and x, cst2)
8633   // TODO - (srl (shl x, c1), c2).
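  // For example:
  //   srl i8 (shl i8 X, 3), 3 --> and i8 X, 0x1F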
8634   if (N0.getOpcode() == ISD::SHL && N0.getOperand(1) == N1 &&
8635       isConstantOrConstantVector(N1, /* NoOpaques */ true)) {
8636     SDLoc DL(N);
8637     SDValue Mask =
8638         DAG.getNode(ISD::SRL, DL, VT, DAG.getAllOnesConstant(DL, VT), N1);
8639     AddToWorklist(Mask.getNode());
8640     return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), Mask);
8641   }
8642 
8643   // fold (srl (anyextend x), c) -> (and (anyextend (srl x, c)), mask)
8644   // TODO - support non-uniform vector shift amounts.
8645   if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
8646     // Shifting in all undef bits?
8647     EVT SmallVT = N0.getOperand(0).getValueType();
8648     unsigned BitSize = SmallVT.getScalarSizeInBits();
8649     if (N1C->getAPIntValue().uge(BitSize))
8650       return DAG.getUNDEF(VT);
8651 
8652     if (!LegalTypes || TLI.isTypeDesirableForOp(ISD::SRL, SmallVT)) {
8653       uint64_t ShiftAmt = N1C->getZExtValue();
8654       SDLoc DL0(N0);
8655       SDValue SmallShift = DAG.getNode(ISD::SRL, DL0, SmallVT,
8656                                        N0.getOperand(0),
8657                           DAG.getConstant(ShiftAmt, DL0,
8658                                           getShiftAmountTy(SmallVT)));
8659       AddToWorklist(SmallShift.getNode());
8660       APInt Mask = APInt::getLowBitsSet(OpSizeInBits, OpSizeInBits - ShiftAmt);
8661       SDLoc DL(N);
8662       return DAG.getNode(ISD::AND, DL, VT,
8663                          DAG.getNode(ISD::ANY_EXTEND, DL, VT, SmallShift),
8664                          DAG.getConstant(Mask, DL, VT));
8665     }
8666   }
8667 
8668   // fold (srl (sra X, Y), 31) -> (srl X, 31).  This srl only looks at the sign
8669   // bit, which is unmodified by sra.
8670   if (N1C && N1C->getAPIntValue() == (OpSizeInBits - 1)) {
8671     if (N0.getOpcode() == ISD::SRA)
8672       return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0.getOperand(0), N1);
8673   }
8674 
8675   // fold (srl (ctlz x), "5") -> x  iff x has one bit set (the low bit).
8676   if (N1C && N0.getOpcode() == ISD::CTLZ &&
8677       N1C->getAPIntValue() == Log2_32(OpSizeInBits)) {
8678     KnownBits Known = DAG.computeKnownBits(N0.getOperand(0));
8679 
8680     // If any of the input bits are KnownOne, then the input couldn't be all
8681     // zeros, thus the result of the srl will always be zero.
8682     if (Known.One.getBoolValue()) return DAG.getConstant(0, SDLoc(N0), VT);
8683 
8684     // If all of the bits input to the ctlz node are known to be zero, then
8685     // the result of the ctlz is the bitwidth and the result of the shift is one.
8686     APInt UnknownBits = ~Known.Zero;
8687     if (UnknownBits == 0) return DAG.getConstant(1, SDLoc(N0), VT);
8688 
8689     // Otherwise, check to see if there is exactly one bit input to the ctlz.
8690     if (UnknownBits.isPowerOf2()) {
8691       // Okay, we know that only the single bit specified by UnknownBits
8692       // could be set on input to the CTLZ node. If this bit is set, the SRL
8693       // will return 0; if it is clear, it returns 1. Change the CTLZ/SRL pair
8694       // to an SRL/XOR pair, which is likely to simplify more.
8695       unsigned ShAmt = UnknownBits.countTrailingZeros();
8696       SDValue Op = N0.getOperand(0);
8697 
8698       if (ShAmt) {
8699         SDLoc DL(N0);
8700         Op = DAG.getNode(ISD::SRL, DL, VT, Op,
8701                   DAG.getConstant(ShAmt, DL,
8702                                   getShiftAmountTy(Op.getValueType())));
8703         AddToWorklist(Op.getNode());
8704       }
8705 
8706       SDLoc DL(N);
8707       return DAG.getNode(ISD::XOR, DL, VT,
8708                          Op, DAG.getConstant(1, DL, VT));
8709     }
8710   }
8711 
8712   // fold (srl x, (trunc (and y, c))) -> (srl x, (and (trunc y), (trunc c))).
8713   if (N1.getOpcode() == ISD::TRUNCATE &&
8714       N1.getOperand(0).getOpcode() == ISD::AND) {
8715     if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
8716       return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0, NewOp1);
8717   }
8718 
8719   // fold operands of srl based on knowledge that the low bits are not
8720   // demanded.
8721   if (SimplifyDemandedBits(SDValue(N, 0)))
8722     return SDValue(N, 0);
8723 
8724   if (N1C && !N1C->isOpaque())
8725     if (SDValue NewSRL = visitShiftByConstant(N))
8726       return NewSRL;
8727 
8728   // Attempt to convert a srl of a load into a narrower zero-extending load.
8729   if (SDValue NarrowLoad = ReduceLoadWidth(N))
8730     return NarrowLoad;
8731 
8732   // Here is a common situation. We want to optimize:
8733   //
8734   //   %a = ...
8735   //   %b = and i32 %a, 2
8736   //   %c = srl i32 %b, 1
8737   //   brcond i32 %c ...
8738   //
8739   // into
8740   //
8741   //   %a = ...
8742   //   %b = and %a, 2
8743   //   %c = setcc eq %b, 0
8744   //   brcond %c ...
8745   //
8746   // However, after the source operand of the SRL is optimized into an AND, the
8747   // SRL itself may not be optimized further. Look for it and add the BRCOND to
8748   // the worklist.
8749   if (N->hasOneUse()) {
8750     SDNode *Use = *N->use_begin();
8751     if (Use->getOpcode() == ISD::BRCOND)
8752       AddToWorklist(Use);
8753     else if (Use->getOpcode() == ISD::TRUNCATE && Use->hasOneUse()) {
8754       // Also look past the truncate.
8755       Use = *Use->use_begin();
8756       if (Use->getOpcode() == ISD::BRCOND)
8757         AddToWorklist(Use);
8758     }
8759   }
8760 
8761   // Try to transform this shift into a multiply-high if
8762   // it matches the appropriate pattern detected in combineShiftToMULH.
8763   if (SDValue MULH = combineShiftToMULH(N, DAG, TLI))
8764     return MULH;
8765 
8766   return SDValue();
8767 }
8768 
8769 SDValue DAGCombiner::visitFunnelShift(SDNode *N) {
8770   EVT VT = N->getValueType(0);
8771   SDValue N0 = N->getOperand(0);
8772   SDValue N1 = N->getOperand(1);
8773   SDValue N2 = N->getOperand(2);
8774   bool IsFSHL = N->getOpcode() == ISD::FSHL;
8775   unsigned BitWidth = VT.getScalarSizeInBits();
8776 
8777   // fold (fshl N0, N1, 0) -> N0
8778   // fold (fshr N0, N1, 0) -> N1
8779   if (isPowerOf2_32(BitWidth))
8780     if (DAG.MaskedValueIsZero(
8781             N2, APInt(N2.getScalarValueSizeInBits(), BitWidth - 1)))
8782       return IsFSHL ? N0 : N1;
8783 
8784   auto IsUndefOrZero = [](SDValue V) {
8785     return V.isUndef() || isNullOrNullSplat(V, /*AllowUndefs*/ true);
8786   };
8787 
8788   // TODO - support non-uniform vector shift amounts.
8789   if (ConstantSDNode *Cst = isConstOrConstSplat(N2)) {
8790     EVT ShAmtTy = N2.getValueType();
8791 
8792     // fold (fsh* N0, N1, c) -> (fsh* N0, N1, c % BitWidth)
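    // For example:
    //   fshl i8 X, Y, 11 --> fshl i8 X, Y, 3   (11 urem 8 == 3)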
8793     if (Cst->getAPIntValue().uge(BitWidth)) {
8794       uint64_t RotAmt = Cst->getAPIntValue().urem(BitWidth);
8795       return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N0, N1,
8796                          DAG.getConstant(RotAmt, SDLoc(N), ShAmtTy));
8797     }
8798 
8799     unsigned ShAmt = Cst->getZExtValue();
8800     if (ShAmt == 0)
8801       return IsFSHL ? N0 : N1;
8802 
8803     // fold fshl(undef_or_zero, N1, C) -> lshr(N1, BW-C)
8804     // fold fshr(undef_or_zero, N1, C) -> lshr(N1, C)
8805     // fold fshl(N0, undef_or_zero, C) -> shl(N0, C)
8806     // fold fshr(N0, undef_or_zero, C) -> shl(N0, BW-C)
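    // For example, with i8:
    //   fshl 0, Y, 3 --> srl Y, 5
    //   fshr X, 0, 3 --> shl X, 5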
8807     if (IsUndefOrZero(N0))
8808       return DAG.getNode(ISD::SRL, SDLoc(N), VT, N1,
8809                          DAG.getConstant(IsFSHL ? BitWidth - ShAmt : ShAmt,
8810                                          SDLoc(N), ShAmtTy));
8811     if (IsUndefOrZero(N1))
8812       return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0,
8813                          DAG.getConstant(IsFSHL ? ShAmt : BitWidth - ShAmt,
8814                                          SDLoc(N), ShAmtTy));
8815 
8816     // fold (fshl ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
8817     // fold (fshr ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
8818     // TODO - bigendian support once we have test coverage.
8819     // TODO - can we merge this with CombineConsecutiveLoads/MatchLoadCombine?
8820     // TODO - permit LHS EXTLOAD if extensions are shifted out.
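    // For example (little endian): if ld0 loads the i32 at p and ld1 loads the
    // i32 at p+4, then fshr(ld1, ld0, 8) yields the i32 stored at p+1.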
8821     if ((BitWidth % 8) == 0 && (ShAmt % 8) == 0 && !VT.isVector() &&
8822         !DAG.getDataLayout().isBigEndian()) {
8823       auto *LHS = dyn_cast<LoadSDNode>(N0);
8824       auto *RHS = dyn_cast<LoadSDNode>(N1);
8825       if (LHS && RHS && LHS->isSimple() && RHS->isSimple() &&
8826           LHS->getAddressSpace() == RHS->getAddressSpace() &&
8827           (LHS->hasOneUse() || RHS->hasOneUse()) && ISD::isNON_EXTLoad(RHS) &&
8828           ISD::isNON_EXTLoad(LHS)) {
8829         if (DAG.areNonVolatileConsecutiveLoads(LHS, RHS, BitWidth / 8, 1)) {
8830           SDLoc DL(RHS);
8831           uint64_t PtrOff =
8832               IsFSHL ? (((BitWidth - ShAmt) % BitWidth) / 8) : (ShAmt / 8);
8833           Align NewAlign = commonAlignment(RHS->getAlign(), PtrOff);
8834           bool Fast = false;
8835           if (TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
8836                                      RHS->getAddressSpace(), NewAlign,
8837                                      RHS->getMemOperand()->getFlags(), &Fast) &&
8838               Fast) {
8839             SDValue NewPtr = DAG.getMemBasePlusOffset(
8840                 RHS->getBasePtr(), TypeSize::Fixed(PtrOff), DL);
8841             AddToWorklist(NewPtr.getNode());
8842             SDValue Load = DAG.getLoad(
8843                 VT, DL, RHS->getChain(), NewPtr,
8844                 RHS->getPointerInfo().getWithOffset(PtrOff), NewAlign,
8845                 RHS->getMemOperand()->getFlags(), RHS->getAAInfo());
8846             // Replace the old load's chain with the new load's chain.
8847             WorklistRemover DeadNodes(*this);
8848             DAG.ReplaceAllUsesOfValueWith(N1.getValue(1), Load.getValue(1));
8849             return Load;
8850           }
8851         }
8852       }
8853     }
8854   }
8855 
8856   // fold fshr(undef_or_zero, N1, N2) -> lshr(N1, N2)
8857   // fold fshl(N0, undef_or_zero, N2) -> shl(N0, N2)
8858   // iff we know the shift amount is in range.
8859   // TODO: when is it worth doing SUB(BW, N2) as well?
8860   if (isPowerOf2_32(BitWidth)) {
8861     APInt ModuloBits(N2.getScalarValueSizeInBits(), BitWidth - 1);
8862     if (IsUndefOrZero(N0) && !IsFSHL && DAG.MaskedValueIsZero(N2, ~ModuloBits))
8863       return DAG.getNode(ISD::SRL, SDLoc(N), VT, N1, N2);
8864     if (IsUndefOrZero(N1) && IsFSHL && DAG.MaskedValueIsZero(N2, ~ModuloBits))
8865       return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, N2);
8866   }
8867 
8868   // fold (fshl N0, N0, N2) -> (rotl N0, N2)
8869   // fold (fshr N0, N0, N2) -> (rotr N0, N2)
8870   // TODO: Investigate flipping this rotate if only one is legal; if the funnel
8871   // shift is legal as well, we might be better off avoiding non-constant (BW - N2).
8872   unsigned RotOpc = IsFSHL ? ISD::ROTL : ISD::ROTR;
8873   if (N0 == N1 && hasOperation(RotOpc, VT))
8874     return DAG.getNode(RotOpc, SDLoc(N), VT, N0, N2);
8875 
8876   // Simplify, based on bits shifted out of N0/N1.
8877   if (SimplifyDemandedBits(SDValue(N, 0)))
8878     return SDValue(N, 0);
8879 
8880   return SDValue();
8881 }
8882 
8883 SDValue DAGCombiner::visitABS(SDNode *N) {
8884   SDValue N0 = N->getOperand(0);
8885   EVT VT = N->getValueType(0);
8886 
8887   // fold (abs c1) -> c2
8888   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
8889     return DAG.getNode(ISD::ABS, SDLoc(N), VT, N0);
8890   // fold (abs (abs x)) -> (abs x)
8891   if (N0.getOpcode() == ISD::ABS)
8892     return N0;
8893   // fold (abs x) -> x iff not-negative
8894   if (DAG.SignBitIsZero(N0))
8895     return N0;
8896   return SDValue();
8897 }
8898 
8899 SDValue DAGCombiner::visitBSWAP(SDNode *N) {
8900   SDValue N0 = N->getOperand(0);
8901   EVT VT = N->getValueType(0);
8902 
8903   // fold (bswap c1) -> c2
8904   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
8905     return DAG.getNode(ISD::BSWAP, SDLoc(N), VT, N0);
8906   // fold (bswap (bswap x)) -> x
8907   if (N0.getOpcode() == ISD::BSWAP)
8908     return N0->getOperand(0);
8909   return SDValue();
8910 }
8911 
8912 SDValue DAGCombiner::visitBITREVERSE(SDNode *N) {
8913   SDValue N0 = N->getOperand(0);
8914   EVT VT = N->getValueType(0);
8915 
8916   // fold (bitreverse c1) -> c2
8917   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
8918     return DAG.getNode(ISD::BITREVERSE, SDLoc(N), VT, N0);
8919   // fold (bitreverse (bitreverse x)) -> x
8920   if (N0.getOpcode() == ISD::BITREVERSE)
8921     return N0.getOperand(0);
8922   return SDValue();
8923 }
8924 
8925 SDValue DAGCombiner::visitCTLZ(SDNode *N) {
8926   SDValue N0 = N->getOperand(0);
8927   EVT VT = N->getValueType(0);
8928 
8929   // fold (ctlz c1) -> c2
8930   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
8931     return DAG.getNode(ISD::CTLZ, SDLoc(N), VT, N0);
8932 
8933   // If the value is known never to be zero, switch to the undef version.
8934   if (!LegalOperations || TLI.isOperationLegal(ISD::CTLZ_ZERO_UNDEF, VT)) {
8935     if (DAG.isKnownNeverZero(N0))
8936       return DAG.getNode(ISD::CTLZ_ZERO_UNDEF, SDLoc(N), VT, N0);
8937   }
8938 
8939   return SDValue();
8940 }
8941 
8942 SDValue DAGCombiner::visitCTLZ_ZERO_UNDEF(SDNode *N) {
8943   SDValue N0 = N->getOperand(0);
8944   EVT VT = N->getValueType(0);
8945 
8946   // fold (ctlz_zero_undef c1) -> c2
8947   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
8948     return DAG.getNode(ISD::CTLZ_ZERO_UNDEF, SDLoc(N), VT, N0);
8949   return SDValue();
8950 }
8951 
8952 SDValue DAGCombiner::visitCTTZ(SDNode *N) {
8953   SDValue N0 = N->getOperand(0);
8954   EVT VT = N->getValueType(0);
8955 
8956   // fold (cttz c1) -> c2
8957   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
8958     return DAG.getNode(ISD::CTTZ, SDLoc(N), VT, N0);
8959 
8960   // If the value is known never to be zero, switch to the undef version.
8961   if (!LegalOperations || TLI.isOperationLegal(ISD::CTTZ_ZERO_UNDEF, VT)) {
8962     if (DAG.isKnownNeverZero(N0))
8963       return DAG.getNode(ISD::CTTZ_ZERO_UNDEF, SDLoc(N), VT, N0);
8964   }
8965 
8966   return SDValue();
8967 }
8968 
8969 SDValue DAGCombiner::visitCTTZ_ZERO_UNDEF(SDNode *N) {
8970   SDValue N0 = N->getOperand(0);
8971   EVT VT = N->getValueType(0);
8972 
8973   // fold (cttz_zero_undef c1) -> c2
8974   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
8975     return DAG.getNode(ISD::CTTZ_ZERO_UNDEF, SDLoc(N), VT, N0);
8976   return SDValue();
8977 }
8978 
8979 SDValue DAGCombiner::visitCTPOP(SDNode *N) {
8980   SDValue N0 = N->getOperand(0);
8981   EVT VT = N->getValueType(0);
8982 
8983   // fold (ctpop c1) -> c2
8984   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
8985     return DAG.getNode(ISD::CTPOP, SDLoc(N), VT, N0);
8986   return SDValue();
8987 }
8988 
8989 // FIXME: This should be checking for no signed zeros on individual operands, as
8990 // well as no nans.
8991 static bool isLegalToCombineMinNumMaxNum(SelectionDAG &DAG, SDValue LHS,
8992                                          SDValue RHS,
8993                                          const TargetLowering &TLI) {
8994   const TargetOptions &Options = DAG.getTarget().Options;
8995   EVT VT = LHS.getValueType();
8996 
8997   return Options.NoSignedZerosFPMath && VT.isFloatingPoint() &&
8998          TLI.isProfitableToCombineMinNumMaxNum(VT) &&
8999          DAG.isKnownNeverNaN(LHS) && DAG.isKnownNeverNaN(RHS);
9000 }
9001 
9002 /// Generate Min/Max node
9003 static SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
9004                                    SDValue RHS, SDValue True, SDValue False,
9005                                    ISD::CondCode CC, const TargetLowering &TLI,
9006                                    SelectionDAG &DAG) {
9007   if (!(LHS == True && RHS == False) && !(LHS == False && RHS == True))
9008     return SDValue();
9009 
9010   EVT TransformVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
9011   switch (CC) {
9012   case ISD::SETOLT:
9013   case ISD::SETOLE:
9014   case ISD::SETLT:
9015   case ISD::SETLE:
9016   case ISD::SETULT:
9017   case ISD::SETULE: {
9018     // Since the operands are known never to be NaN here, either fminnum or
9019     // fminnum_ieee is OK. Try the IEEE version first, since fminnum is
9020     // expanded in terms of it.
9021     unsigned IEEEOpcode = (LHS == True) ? ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE;
9022     if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT))
9023       return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS);
9024 
9025     unsigned Opcode = (LHS == True) ? ISD::FMINNUM : ISD::FMAXNUM;
9026     if (TLI.isOperationLegalOrCustom(Opcode, TransformVT))
9027       return DAG.getNode(Opcode, DL, VT, LHS, RHS);
9028     return SDValue();
9029   }
9030   case ISD::SETOGT:
9031   case ISD::SETOGE:
9032   case ISD::SETGT:
9033   case ISD::SETGE:
9034   case ISD::SETUGT:
9035   case ISD::SETUGE: {
9036     unsigned IEEEOpcode = (LHS == True) ? ISD::FMAXNUM_IEEE : ISD::FMINNUM_IEEE;
9037     if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT))
9038       return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS);
9039 
9040     unsigned Opcode = (LHS == True) ? ISD::FMAXNUM : ISD::FMINNUM;
9041     if (TLI.isOperationLegalOrCustom(Opcode, TransformVT))
9042       return DAG.getNode(Opcode, DL, VT, LHS, RHS);
9043     return SDValue();
9044   }
9045   default:
9046     return SDValue();
9047   }
9048 }
9049 
9050 /// If a (v)select has a condition value that is a sign-bit test, try to smear
9051 /// the condition operand sign-bit across the value width and use it as a mask.
9052 static SDValue foldSelectOfConstantsUsingSra(SDNode *N, SelectionDAG &DAG) {
9053   SDValue Cond = N->getOperand(0);
9054   SDValue C1 = N->getOperand(1);
9055   SDValue C2 = N->getOperand(2);
9056   assert(isConstantOrConstantVector(C1) && isConstantOrConstantVector(C2) &&
9057          "Expected select-of-constants");
9058 
9059   EVT VT = N->getValueType(0);
9060   if (Cond.getOpcode() != ISD::SETCC || !Cond.hasOneUse() ||
9061       VT != Cond.getOperand(0).getValueType())
9062     return SDValue();
9063 
9064   // The inverted-condition + commuted-select variants of these patterns are
9065   // canonicalized to these forms in IR.
9066   SDValue X = Cond.getOperand(0);
9067   SDValue CondC = Cond.getOperand(1);
9068   ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
9069   if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(CondC) &&
9070       isAllOnesOrAllOnesSplat(C2)) {
9071     // i32 X > -1 ? C1 : -1 --> (X >>s 31) | C1
9072     SDLoc DL(N);
9073     SDValue ShAmtC = DAG.getConstant(X.getScalarValueSizeInBits() - 1, DL, VT);
9074     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, X, ShAmtC);
9075     return DAG.getNode(ISD::OR, DL, VT, Sra, C1);
9076   }
9077   if (CC == ISD::SETLT && isNullOrNullSplat(CondC) && isNullOrNullSplat(C2)) {
9078     // i8 X < 0 ? C1 : 0 --> (X >>s 7) & C1
9079     SDLoc DL(N);
9080     SDValue ShAmtC = DAG.getConstant(X.getScalarValueSizeInBits() - 1, DL, VT);
9081     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, X, ShAmtC);
9082     return DAG.getNode(ISD::AND, DL, VT, Sra, C1);
9083   }
9084   return SDValue();
9085 }
9086 
9087 SDValue DAGCombiner::foldSelectOfConstants(SDNode *N) {
9088   SDValue Cond = N->getOperand(0);
9089   SDValue N1 = N->getOperand(1);
9090   SDValue N2 = N->getOperand(2);
9091   EVT VT = N->getValueType(0);
9092   EVT CondVT = Cond.getValueType();
9093   SDLoc DL(N);
9094 
9095   if (!VT.isInteger())
9096     return SDValue();
9097 
9098   auto *C1 = dyn_cast<ConstantSDNode>(N1);
9099   auto *C2 = dyn_cast<ConstantSDNode>(N2);
9100   if (!C1 || !C2)
9101     return SDValue();
9102 
9103   // Only do this before legalization to avoid conflicting with target-specific
9104   // transforms in the other direction (create a select from a zext/sext). There
9105   // is also a target-independent combine here in DAGCombiner in the other
9106   // direction for (select Cond, -1, 0) when the condition is not i1.
9107   if (CondVT == MVT::i1 && !LegalOperations) {
9108     if (C1->isNullValue() && C2->isOne()) {
9109       // select Cond, 0, 1 --> zext (!Cond)
9110       SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
9111       if (VT != MVT::i1)
9112         NotCond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, NotCond);
9113       return NotCond;
9114     }
9115     if (C1->isNullValue() && C2->isAllOnesValue()) {
9116       // select Cond, 0, -1 --> sext (!Cond)
9117       SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
9118       if (VT != MVT::i1)
9119         NotCond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, NotCond);
9120       return NotCond;
9121     }
9122     if (C1->isOne() && C2->isNullValue()) {
9123       // select Cond, 1, 0 --> zext (Cond)
9124       if (VT != MVT::i1)
9125         Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Cond);
9126       return Cond;
9127     }
9128     if (C1->isAllOnesValue() && C2->isNullValue()) {
9129       // select Cond, -1, 0 --> sext (Cond)
9130       if (VT != MVT::i1)
9131         Cond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Cond);
9132       return Cond;
9133     }
9134 
9135     // Use a target hook because some targets may prefer to transform in the
9136     // other direction.
9137     if (TLI.convertSelectOfConstantsToMath(VT)) {
9138       // For any constants that differ by 1, we can transform the select into an
9139       // extend and add.
9140       const APInt &C1Val = C1->getAPIntValue();
9141       const APInt &C2Val = C2->getAPIntValue();
9142       if (C1Val - 1 == C2Val) {
9143         // select Cond, C1, C1-1 --> add (zext Cond), C1-1
9144         if (VT != MVT::i1)
9145           Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Cond);
9146         return DAG.getNode(ISD::ADD, DL, VT, Cond, N2);
9147       }
9148       if (C1Val + 1 == C2Val) {
9149         // select Cond, C1, C1+1 --> add (sext Cond), C1+1
9150         if (VT != MVT::i1)
9151           Cond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Cond);
9152         return DAG.getNode(ISD::ADD, DL, VT, Cond, N2);
9153       }
9154 
9155       // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2)
9156       if (C1Val.isPowerOf2() && C2Val.isNullValue()) {
9157         if (VT != MVT::i1)
9158           Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Cond);
9159         SDValue ShAmtC = DAG.getConstant(C1Val.exactLogBase2(), DL, VT);
9160         return DAG.getNode(ISD::SHL, DL, VT, Cond, ShAmtC);
9161       }
9162 
9163       if (SDValue V = foldSelectOfConstantsUsingSra(N, DAG))
9164         return V;
9165     }
9166 
9167     return SDValue();
9168   }
9169 
9170   // fold (select Cond, 0, 1) -> (xor Cond, 1)
9171   // We can't do this reliably if integer-based booleans have different contents
9172   // from floating-point-based booleans. This is because we can't tell whether we
9173   // have an integer-based boolean or a floating-point-based boolean unless we
9174   // can find the SETCC that produced it and inspect its operands. This is
9175   // fairly easy if C is the SETCC node, but it can potentially be
9176   // undiscoverable (or not reasonably discoverable). For example, it could be
9177   // in another basic block or it could require searching a complicated
9178   // expression.
9179   if (CondVT.isInteger() &&
9180       TLI.getBooleanContents(/*isVec*/false, /*isFloat*/true) ==
9181           TargetLowering::ZeroOrOneBooleanContent &&
9182       TLI.getBooleanContents(/*isVec*/false, /*isFloat*/false) ==
9183           TargetLowering::ZeroOrOneBooleanContent &&
9184       C1->isNullValue() && C2->isOne()) {
9185     SDValue NotCond =
9186         DAG.getNode(ISD::XOR, DL, CondVT, Cond, DAG.getConstant(1, DL, CondVT));
9187     if (VT.bitsEq(CondVT))
9188       return NotCond;
9189     return DAG.getZExtOrTrunc(NotCond, DL, VT);
9190   }
9191 
9192   return SDValue();
9193 }
9194 
9195 SDValue DAGCombiner::visitSELECT(SDNode *N) {
9196   SDValue N0 = N->getOperand(0);
9197   SDValue N1 = N->getOperand(1);
9198   SDValue N2 = N->getOperand(2);
9199   EVT VT = N->getValueType(0);
9200   EVT VT0 = N0.getValueType();
9201   SDLoc DL(N);
9202   SDNodeFlags Flags = N->getFlags();
9203 
9204   if (SDValue V = DAG.simplifySelect(N0, N1, N2))
9205     return V;
9206 
9207   // fold (select X, X, Y) -> (or X, Y)
9208   // fold (select X, 1, Y) -> (or X, Y)
9209   if (VT == VT0 && VT == MVT::i1 && (N0 == N1 || isOneConstant(N1)))
9210     return DAG.getNode(ISD::OR, DL, VT, N0, N2);
9211 
9212   if (SDValue V = foldSelectOfConstants(N))
9213     return V;
9214 
9215   // fold (select C, 0, X) -> (and (not C), X)
9216   if (VT == VT0 && VT == MVT::i1 && isNullConstant(N1)) {
9217     SDValue NOTNode = DAG.getNOT(SDLoc(N0), N0, VT);
9218     AddToWorklist(NOTNode.getNode());
9219     return DAG.getNode(ISD::AND, DL, VT, NOTNode, N2);
9220   }
9221   // fold (select C, X, 1) -> (or (not C), X)
9222   if (VT == VT0 && VT == MVT::i1 && isOneConstant(N2)) {
9223     SDValue NOTNode = DAG.getNOT(SDLoc(N0), N0, VT);
9224     AddToWorklist(NOTNode.getNode());
9225     return DAG.getNode(ISD::OR, DL, VT, NOTNode, N1);
9226   }
9227   // fold (select X, Y, X) -> (and X, Y)
9228   // fold (select X, Y, 0) -> (and X, Y)
9229   if (VT == VT0 && VT == MVT::i1 && (N0 == N2 || isNullConstant(N2)))
9230     return DAG.getNode(ISD::AND, DL, VT, N0, N1);
9231 
9232   // If we can fold this based on the true/false value, do so.
9233   if (SimplifySelectOps(N, N1, N2))
9234     return SDValue(N, 0); // Don't revisit N.
9235 
9236   if (VT0 == MVT::i1) {
9237     // The code in this block deals with the following 2 equivalences:
9238     //    select(C0|C1, x, y) <=> select(C0, x, select(C1, x, y))
9239     //    select(C0&C1, x, y) <=> select(C0, select(C1, x, y), y)
9240     // The target can specify its preferred form with the
9241     // shouldNormalizeToSelectSequence() callback. However, we always transform
9242     // to the right-hand form if the inner select already exists in the DAG,
9243     // and we always transform to the left-hand form if we know that we can
9244     // further optimize the combination of the conditions.
9245     bool normalizeToSequence =
9246         TLI.shouldNormalizeToSelectSequence(*DAG.getContext(), VT);
9247     // select (and Cond0, Cond1), X, Y
9248     //   -> select Cond0, (select Cond1, X, Y), Y
9249     if (N0->getOpcode() == ISD::AND && N0->hasOneUse()) {
9250       SDValue Cond0 = N0->getOperand(0);
9251       SDValue Cond1 = N0->getOperand(1);
9252       SDValue InnerSelect =
9253           DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond1, N1, N2, Flags);
9254       if (normalizeToSequence || !InnerSelect.use_empty())
9255         return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0,
9256                            InnerSelect, N2, Flags);
9257       // Cleanup on failure.
9258       if (InnerSelect.use_empty())
9259         recursivelyDeleteUnusedNodes(InnerSelect.getNode());
9260     }
9261     // select (or Cond0, Cond1), X, Y -> select Cond0, X, (select Cond1, X, Y)
9262     if (N0->getOpcode() == ISD::OR && N0->hasOneUse()) {
9263       SDValue Cond0 = N0->getOperand(0);
9264       SDValue Cond1 = N0->getOperand(1);
9265       SDValue InnerSelect = DAG.getNode(ISD::SELECT, DL, N1.getValueType(),
9266                                         Cond1, N1, N2, Flags);
9267       if (normalizeToSequence || !InnerSelect.use_empty())
9268         return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0, N1,
9269                            InnerSelect, Flags);
9270       // Cleanup on failure.
9271       if (InnerSelect.use_empty())
9272         recursivelyDeleteUnusedNodes(InnerSelect.getNode());
9273     }
9274 
9275     // select Cond0, (select Cond1, X, Y), Y -> select (and Cond0, Cond1), X, Y
9276     if (N1->getOpcode() == ISD::SELECT && N1->hasOneUse()) {
9277       SDValue N1_0 = N1->getOperand(0);
9278       SDValue N1_1 = N1->getOperand(1);
9279       SDValue N1_2 = N1->getOperand(2);
9280       if (N1_2 == N2 && N0.getValueType() == N1_0.getValueType()) {
9281         // Create the actual and node if we can generate good code for it.
9282         if (!normalizeToSequence) {
9283           SDValue And = DAG.getNode(ISD::AND, DL, N0.getValueType(), N0, N1_0);
9284           return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), And, N1_1,
9285                              N2, Flags);
9286         }
9287         // Otherwise see if we can optimize the "and" to a better pattern.
9288         if (SDValue Combined = visitANDLike(N0, N1_0, N)) {
9289           return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1_1,
9290                              N2, Flags);
9291         }
9292       }
9293     }
9294     // select Cond0, X, (select Cond1, X, Y) -> select (or Cond0, Cond1), X, Y
9295     if (N2->getOpcode() == ISD::SELECT && N2->hasOneUse()) {
9296       SDValue N2_0 = N2->getOperand(0);
9297       SDValue N2_1 = N2->getOperand(1);
9298       SDValue N2_2 = N2->getOperand(2);
9299       if (N2_1 == N1 && N0.getValueType() == N2_0.getValueType()) {
9300         // Create the actual or node if we can generate good code for it.
9301         if (!normalizeToSequence) {
9302           SDValue Or = DAG.getNode(ISD::OR, DL, N0.getValueType(), N0, N2_0);
9303           return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Or, N1,
9304                              N2_2, Flags);
9305         }
9306         // Otherwise see if we can optimize to a better pattern.
9307         if (SDValue Combined = visitORLike(N0, N2_0, N))
9308           return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1,
9309                              N2_2, Flags);
9310       }
9311     }
9312   }
9313 
9314   // select (not Cond), N1, N2 -> select Cond, N2, N1
9315   if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false)) {
9316     SDValue SelectOp = DAG.getSelect(DL, VT, F, N2, N1);
9317     SelectOp->setFlags(Flags);
9318     return SelectOp;
9319   }
9320 
9321   // Fold selects based on a setcc into other things, such as min/max/abs.
9322   if (N0.getOpcode() == ISD::SETCC) {
9323     SDValue Cond0 = N0.getOperand(0), Cond1 = N0.getOperand(1);
9324     ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
9325 
9326     // select (fcmp lt x, y), x, y -> fminnum x, y
9327     // select (fcmp gt x, y), x, y -> fmaxnum x, y
9328     //
9329     // This is OK if we don't care what happens if either operand is a NaN.
9330     if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, N1, N2, TLI))
9331       if (SDValue FMinMax = combineMinNumMaxNum(DL, VT, Cond0, Cond1, N1, N2,
9332                                                 CC, TLI, DAG))
9333         return FMinMax;
9334 
9335     // Use 'unsigned add with overflow' to optimize an unsigned saturating add.
9336     // This is conservatively limited to pre-legal-operations to give targets
9337     // a chance to reverse the transform if they want to do that. Also, it is
9338     // unlikely that the pattern would be formed late, so it's probably not
9339     // worth going through the other checks.
9340     if (!LegalOperations && TLI.isOperationLegalOrCustom(ISD::UADDO, VT) &&
9341         CC == ISD::SETUGT && N0.hasOneUse() && isAllOnesConstant(N1) &&
9342         N2.getOpcode() == ISD::ADD && Cond0 == N2.getOperand(0)) {
9343       auto *C = dyn_cast<ConstantSDNode>(N2.getOperand(1));
9344       auto *NotC = dyn_cast<ConstantSDNode>(Cond1);
9345       if (C && NotC && C->getAPIntValue() == ~NotC->getAPIntValue()) {
9346         // select (setcc Cond0, ~C, ugt), -1, (add Cond0, C) -->
9347         // uaddo Cond0, C; select uaddo.1, -1, uaddo.0
9348         //
9349         // The IR equivalent of this transform would have this form:
9350         //   %a = add %x, C
9351         //   %c = icmp ugt %x, ~C
9352         //   %r = select %c, -1, %a
9353         //   =>
9354         //   %u = call {iN,i1} llvm.uadd.with.overflow(%x, C)
9355         //   %u0 = extractvalue %u, 0
9356         //   %u1 = extractvalue %u, 1
9357         //   %r = select %u1, -1, %u0
9358         SDVTList VTs = DAG.getVTList(VT, VT0);
9359         SDValue UAO = DAG.getNode(ISD::UADDO, DL, VTs, Cond0, N2.getOperand(1));
9360         return DAG.getSelect(DL, VT, UAO.getValue(1), N1, UAO.getValue(0));
9361       }
9362     }
9363 
9364     if (TLI.isOperationLegal(ISD::SELECT_CC, VT) ||
9365         (!LegalOperations &&
9366          TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT))) {
9367       // Any flags available in a select/setcc fold will be on the setcc as they
9368       // migrated from fcmp
9369       Flags = N0.getNode()->getFlags();
9370       SDValue SelectNode = DAG.getNode(ISD::SELECT_CC, DL, VT, Cond0, Cond1, N1,
9371                                        N2, N0.getOperand(2));
9372       SelectNode->setFlags(Flags);
9373       return SelectNode;
9374     }
9375 
9376     return SimplifySelect(DL, N0, N1, N2);
9377   }
9378 
9379   return SDValue();
9380 }
9381 
9382 // This function assumes all the vselect's arguments are CONCAT_VECTORS
9383 // nodes and that the condition is a BV of ConstantSDNodes (or undefs).
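// For example (illustrative operands, each concat input being half-width):
//   (vselect (build_vector 0, 0, -1, -1),
//            (concat_vectors A, B), (concat_vectors C, D))
//   --> (concat_vectors C, B)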
9384 static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
9385   SDLoc DL(N);
9386   SDValue Cond = N->getOperand(0);
9387   SDValue LHS = N->getOperand(1);
9388   SDValue RHS = N->getOperand(2);
9389   EVT VT = N->getValueType(0);
9390   int NumElems = VT.getVectorNumElements();
9391   assert(LHS.getOpcode() == ISD::CONCAT_VECTORS &&
9392          RHS.getOpcode() == ISD::CONCAT_VECTORS &&
9393          Cond.getOpcode() == ISD::BUILD_VECTOR);
9394 
9395   // CONCAT_VECTORS can take an arbitrary number of arguments. We only care about
9396   // binary ones here.
9397   if (LHS->getNumOperands() != 2 || RHS->getNumOperands() != 2)
9398     return SDValue();
9399 
9400   // We're sure we have an even number of elements due to the
9401   // concat_vectors we have as arguments to vselect.
9402   // Skip BV elements until we find one that's not an UNDEF.
9403   // After we find a non-UNDEF element, keep looping until we get to half the
9404   // length of the BV and check that all the non-undef nodes are the same.
9405   ConstantSDNode *BottomHalf = nullptr;
9406   for (int i = 0; i < NumElems / 2; ++i) {
9407     if (Cond->getOperand(i)->isUndef())
9408       continue;
9409 
9410     if (BottomHalf == nullptr)
9411       BottomHalf = cast<ConstantSDNode>(Cond.getOperand(i));
9412     else if (Cond->getOperand(i).getNode() != BottomHalf)
9413       return SDValue();
9414   }
9415 
9416   // Do the same for the second half of the BuildVector
9417   ConstantSDNode *TopHalf = nullptr;
9418   for (int i = NumElems / 2; i < NumElems; ++i) {
9419     if (Cond->getOperand(i)->isUndef())
9420       continue;
9421 
9422     if (TopHalf == nullptr)
9423       TopHalf = cast<ConstantSDNode>(Cond.getOperand(i));
9424     else if (Cond->getOperand(i).getNode() != TopHalf)
9425       return SDValue();
9426   }
9427 
9428   assert(TopHalf && BottomHalf &&
9429          "One half of the selector was all UNDEFs and the other was all the "
9430          "same value. This should have been addressed before this function.");
9431   return DAG.getNode(
9432       ISD::CONCAT_VECTORS, DL, VT,
9433       BottomHalf->isNullValue() ? RHS->getOperand(0) : LHS->getOperand(0),
9434       TopHalf->isNullValue() ? RHS->getOperand(1) : LHS->getOperand(1));
9435 }
9436 
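// Recognize a gather/scatter with a null base pointer whose index is a
// splatted scalar address plus a vector of offsets, and hoist the splat into
// the base pointer, e.g. (illustrative):
//   base: 0, index: (add (splat %addr), %offsets)
//   --> base: %addr, index: %offsets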
9437 bool refineUniformBase(SDValue &BasePtr, SDValue &Index, SelectionDAG &DAG) {
9438   if (!isNullConstant(BasePtr) || Index.getOpcode() != ISD::ADD)
9439     return false;
9440 
9441   // For now we check only the LHS of the add.
9442   SDValue LHS = Index.getOperand(0);
9443   SDValue SplatVal = DAG.getSplatValue(LHS);
9444   if (!SplatVal)
9445     return false;
9446 
9447   BasePtr = SplatVal;
9448   Index = Index.getOperand(1);
9449   return true;
9450 }
9451 
9452 // Fold sext/zext of index into index type.
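// e.g. an index of (zext %idx) can be replaced by %idx with an unsigned index
// type, assuming the target reports via shouldRemoveExtendFromGSIndex that it
// can handle the narrower index type directly.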
9453 bool refineIndexType(MaskedGatherScatterSDNode *MGS, SDValue &Index,
9454                      bool Scaled, SelectionDAG &DAG) {
9455   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
9456 
9457   if (Index.getOpcode() == ISD::ZERO_EXTEND) {
9458     SDValue Op = Index.getOperand(0);
9459     MGS->setIndexType(Scaled ? ISD::UNSIGNED_SCALED : ISD::UNSIGNED_UNSCALED);
9460     if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType())) {
9461       Index = Op;
9462       return true;
9463     }
9464   }
9465 
9466   if (Index.getOpcode() == ISD::SIGN_EXTEND) {
9467     SDValue Op = Index.getOperand(0);
9468     MGS->setIndexType(Scaled ? ISD::SIGNED_SCALED : ISD::SIGNED_UNSCALED);
9469     if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType())) {
9470       Index = Op;
9471       return true;
9472     }
9473   }
9474 
9475   return false;
9476 }
9477 
9478 SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
9479   MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(N);
9480   SDValue Mask = MSC->getMask();
9481   SDValue Chain = MSC->getChain();
9482   SDValue Index = MSC->getIndex();
9483   SDValue Scale = MSC->getScale();
9484   SDValue StoreVal = MSC->getValue();
9485   SDValue BasePtr = MSC->getBasePtr();
9486   SDLoc DL(N);
9487 
9488   // Zap scatters with a zero mask.
9489   if (ISD::isBuildVectorAllZeros(Mask.getNode()))
9490     return Chain;
9491 
9492   if (refineUniformBase(BasePtr, Index, DAG)) {
9493     SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
9494     return DAG.getMaskedScatter(
9495         DAG.getVTList(MVT::Other), StoreVal.getValueType(), DL, Ops,
9496         MSC->getMemOperand(), MSC->getIndexType(), MSC->isTruncatingStore());
9497   }
9498 
9499   if (refineIndexType(MSC, Index, MSC->isIndexScaled(), DAG)) {
9500     SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
9501     return DAG.getMaskedScatter(
9502         DAG.getVTList(MVT::Other), StoreVal.getValueType(), DL, Ops,
9503         MSC->getMemOperand(), MSC->getIndexType(), MSC->isTruncatingStore());
9504   }
9505 
9506   return SDValue();
9507 }
9508 
9509 SDValue DAGCombiner::visitMSTORE(SDNode *N) {
9510   MaskedStoreSDNode *MST = cast<MaskedStoreSDNode>(N);
9511   SDValue Mask = MST->getMask();
9512   SDValue Chain = MST->getChain();
9513   SDLoc DL(N);
9514 
9515   // Zap masked stores with a zero mask.
9516   if (ISD::isBuildVectorAllZeros(Mask.getNode()))
9517     return Chain;
9518 
9519   // If this is a masked store with an all-ones mask, we can use an unmasked store.
9520   // FIXME: Can we do this for indexed, compressing, or truncating stores?
9521   if (ISD::isBuildVectorAllOnes(Mask.getNode()) &&
9522       MST->isUnindexed() && !MST->isCompressingStore() &&
9523       !MST->isTruncatingStore())
9524     return DAG.getStore(MST->getChain(), SDLoc(N), MST->getValue(),
9525                         MST->getBasePtr(), MST->getMemOperand());
9526 
9527   // Try transforming N to an indexed store.
9528   if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
9529     return SDValue(N, 0);
9530 
9531   return SDValue();
9532 }
9533 
9534 SDValue DAGCombiner::visitMGATHER(SDNode *N) {
9535   MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(N);
9536   SDValue Mask = MGT->getMask();
9537   SDValue Chain = MGT->getChain();
9538   SDValue Index = MGT->getIndex();
9539   SDValue Scale = MGT->getScale();
9540   SDValue PassThru = MGT->getPassThru();
9541   SDValue BasePtr = MGT->getBasePtr();
9542   SDLoc DL(N);
9543 
9544   // Zap gathers with a zero mask.
9545   if (ISD::isBuildVectorAllZeros(Mask.getNode()))
9546     return CombineTo(N, PassThru, MGT->getChain());
9547 
9548   if (refineUniformBase(BasePtr, Index, DAG)) {
9549     SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
9550     return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
9551                                PassThru.getValueType(), DL, Ops,
9552                                MGT->getMemOperand(), MGT->getIndexType(),
9553                                MGT->getExtensionType());
9554   }
9555 
9556   if (refineIndexType(MGT, Index, MGT->isIndexScaled(), DAG)) {
9557     SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
9558     return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
9559                                PassThru.getValueType(), DL, Ops,
9560                                MGT->getMemOperand(), MGT->getIndexType(),
9561                                MGT->getExtensionType());
9562   }
9563 
9564   return SDValue();
9565 }
9566 
9567 SDValue DAGCombiner::visitMLOAD(SDNode *N) {
9568   MaskedLoadSDNode *MLD = cast<MaskedLoadSDNode>(N);
9569   SDValue Mask = MLD->getMask();
9570   SDLoc DL(N);
9571 
9572   // Zap masked loads with a zero mask.
9573   if (ISD::isBuildVectorAllZeros(Mask.getNode()))
9574     return CombineTo(N, MLD->getPassThru(), MLD->getChain());
9575 
9576   // If this is a masked load with an all-ones mask, we can use an unmasked load.
9577   // FIXME: Can we do this for indexed, expanding, or extending loads?
9578   if (ISD::isBuildVectorAllOnes(Mask.getNode()) &&
9579       MLD->isUnindexed() && !MLD->isExpandingLoad() &&
9580       MLD->getExtensionType() == ISD::NON_EXTLOAD) {
9581     SDValue NewLd = DAG.getLoad(N->getValueType(0), SDLoc(N), MLD->getChain(),
9582                                 MLD->getBasePtr(), MLD->getMemOperand());
9583     return CombineTo(N, NewLd, NewLd.getValue(1));
9584   }
9585 
9586   // Try transforming N to an indexed load.
9587   if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
9588     return SDValue(N, 0);
9589 
9590   return SDValue();
9591 }
9592 
9593 /// A vector select of 2 constant vectors can be simplified to math/logic to
9594 /// avoid a variable select instruction and possibly avoid constant loads.
9595 SDValue DAGCombiner::foldVSelectOfConstants(SDNode *N) {
9596   SDValue Cond = N->getOperand(0);
9597   SDValue N1 = N->getOperand(1);
9598   SDValue N2 = N->getOperand(2);
9599   EVT VT = N->getValueType(0);
9600   if (!Cond.hasOneUse() || Cond.getScalarValueSizeInBits() != 1 ||
9601       !TLI.convertSelectOfConstantsToMath(VT) ||
9602       !ISD::isBuildVectorOfConstantSDNodes(N1.getNode()) ||
9603       !ISD::isBuildVectorOfConstantSDNodes(N2.getNode()))
9604     return SDValue();
9605 
9606   // Check if we can use the condition value to increment/decrement a single
9607   // constant value. This simplifies a select to an add and removes a constant
9608   // load/materialization from the general case.
9609   bool AllAddOne = true;
9610   bool AllSubOne = true;
9611   unsigned Elts = VT.getVectorNumElements();
9612   for (unsigned i = 0; i != Elts; ++i) {
9613     SDValue N1Elt = N1.getOperand(i);
9614     SDValue N2Elt = N2.getOperand(i);
9615     if (N1Elt.isUndef() || N2Elt.isUndef())
9616       continue;
9617     if (N1Elt.getValueType() != N2Elt.getValueType())
9618       continue;
9619 
9620     const APInt &C1 = cast<ConstantSDNode>(N1Elt)->getAPIntValue();
9621     const APInt &C2 = cast<ConstantSDNode>(N2Elt)->getAPIntValue();
9622     if (C1 != C2 + 1)
9623       AllAddOne = false;
9624     if (C1 != C2 - 1)
9625       AllSubOne = false;
9626   }
9627 
9628   // Further simplifications for the extra-special cases where the constants are
9629   // all 0 or all -1 should be implemented as folds of these patterns.
9630   SDLoc DL(N);
9631   if (AllAddOne || AllSubOne) {
9632     // vselect <N x i1> Cond, C+1, C --> add (zext Cond), C
9633     // vselect <N x i1> Cond, C-1, C --> add (sext Cond), C
9634     auto ExtendOpcode = AllAddOne ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND;
9635     SDValue ExtendedCond = DAG.getNode(ExtendOpcode, DL, VT, Cond);
9636     return DAG.getNode(ISD::ADD, DL, VT, ExtendedCond, N2);
9637   }
9638 
9639   // select Cond, Pow2C, 0 --> (zext Cond) << log2(Pow2C)
9640   APInt Pow2C;
9641   if (ISD::isConstantSplatVector(N1.getNode(), Pow2C) && Pow2C.isPowerOf2() &&
9642       isNullOrNullSplat(N2)) {
9643     SDValue ZextCond = DAG.getZExtOrTrunc(Cond, DL, VT);
9644     SDValue ShAmtC = DAG.getConstant(Pow2C.exactLogBase2(), DL, VT);
9645     return DAG.getNode(ISD::SHL, DL, VT, ZextCond, ShAmtC);
9646   }
9647 
9648   if (SDValue V = foldSelectOfConstantsUsingSra(N, DAG))
9649     return V;
9650 
9651   // The general case for select-of-constants:
9652   // vselect <N x i1> Cond, C1, C2 --> xor (and (sext Cond), (C1^C2)), C2
9653   // ...but that only makes sense if a vselect is slower than 2 logic ops, so
9654   // leave that to a machine-specific pass.
9655   return SDValue();
9656 }
9657 
9658 SDValue DAGCombiner::visitVSELECT(SDNode *N) {
9659   SDValue N0 = N->getOperand(0);
9660   SDValue N1 = N->getOperand(1);
9661   SDValue N2 = N->getOperand(2);
9662   EVT VT = N->getValueType(0);
9663   SDLoc DL(N);
9664 
9665   if (SDValue V = DAG.simplifySelect(N0, N1, N2))
9666     return V;
9667 
9668   // vselect (not Cond), N1, N2 -> vselect Cond, N2, N1
9669   if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false))
9670     return DAG.getSelect(DL, VT, F, N2, N1);
9671 
9672   // Canonicalize integer abs.
9673   // vselect (setg[te] X,  0),  X, -X ->
9674   // vselect (setgt    X, -1),  X, -X ->
9675   // vselect (setl[te] X,  0), -X,  X ->
9676   // Y = sra (X, size(X)-1); xor (add (X, Y), Y)
9677   if (N0.getOpcode() == ISD::SETCC) {
9678     SDValue LHS = N0.getOperand(0), RHS = N0.getOperand(1);
9679     ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
9680     bool isAbs = false;
9681     bool RHSIsAllZeros = ISD::isBuildVectorAllZeros(RHS.getNode());
9682 
9683     if (((RHSIsAllZeros && (CC == ISD::SETGT || CC == ISD::SETGE)) ||
9684          (ISD::isBuildVectorAllOnes(RHS.getNode()) && CC == ISD::SETGT)) &&
9685         N1 == LHS && N2.getOpcode() == ISD::SUB && N1 == N2.getOperand(1))
9686       isAbs = ISD::isBuildVectorAllZeros(N2.getOperand(0).getNode());
9687     else if ((RHSIsAllZeros && (CC == ISD::SETLT || CC == ISD::SETLE)) &&
9688              N2 == LHS && N1.getOpcode() == ISD::SUB && N2 == N1.getOperand(1))
9689       isAbs = ISD::isBuildVectorAllZeros(N1.getOperand(0).getNode());
9690 
9691     if (isAbs) {
9692       if (TLI.isOperationLegalOrCustom(ISD::ABS, VT))
9693         return DAG.getNode(ISD::ABS, DL, VT, LHS);
9694 
9695       SDValue Shift = DAG.getNode(ISD::SRA, DL, VT, LHS,
9696                                   DAG.getConstant(VT.getScalarSizeInBits() - 1,
9697                                                   DL, getShiftAmountTy(VT)));
9698       SDValue Add = DAG.getNode(ISD::ADD, DL, VT, LHS, Shift);
9699       AddToWorklist(Shift.getNode());
9700       AddToWorklist(Add.getNode());
9701       return DAG.getNode(ISD::XOR, DL, VT, Add, Shift);
9702     }
9703 
9704     // vselect x, y (fcmp lt x, y) -> fminnum x, y
9705     // vselect x, y (fcmp gt x, y) -> fmaxnum x, y
9706     //
9707     // This is OK if we don't care about what happens if either operand is a
9708     // NaN.
9709     //
9710     if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, LHS, RHS, TLI)) {
9711       if (SDValue FMinMax =
9712               combineMinNumMaxNum(DL, VT, LHS, RHS, N1, N2, CC, TLI, DAG))
9713         return FMinMax;
9714     }
9715 
9716     // If this select has a condition (setcc) with narrower operands than the
9717     // select, try to widen the compare to match the select width.
9718     // TODO: This should be extended to handle any constant.
9719     // TODO: This could be extended to handle non-loading patterns, but that
9720     //       requires thorough testing to avoid regressions.
9721     if (isNullOrNullSplat(RHS)) {
9722       EVT NarrowVT = LHS.getValueType();
9723       EVT WideVT = N1.getValueType().changeVectorElementTypeToInteger();
9724       EVT SetCCVT = getSetCCResultType(LHS.getValueType());
9725       unsigned SetCCWidth = SetCCVT.getScalarSizeInBits();
9726       unsigned WideWidth = WideVT.getScalarSizeInBits();
9727       bool IsSigned = isSignedIntSetCC(CC);
9728       auto LoadExtOpcode = IsSigned ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
9729       if (LHS.getOpcode() == ISD::LOAD && LHS.hasOneUse() &&
9730           SetCCWidth != 1 && SetCCWidth < WideWidth &&
9731           TLI.isLoadExtLegalOrCustom(LoadExtOpcode, WideVT, NarrowVT) &&
9732           TLI.isOperationLegalOrCustom(ISD::SETCC, WideVT)) {
9733         // Both compare operands can be widened for free. The LHS can use an
9734         // extended load, and the RHS is a constant:
9735         //   vselect (ext (setcc load(X), C)), N1, N2 -->
9736         //   vselect (setcc extload(X), C'), N1, N2
9737         auto ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
9738         SDValue WideLHS = DAG.getNode(ExtOpcode, DL, WideVT, LHS);
9739         SDValue WideRHS = DAG.getNode(ExtOpcode, DL, WideVT, RHS);
9740         EVT WideSetCCVT = getSetCCResultType(WideVT);
9741         SDValue WideSetCC = DAG.getSetCC(DL, WideSetCCVT, WideLHS, WideRHS, CC);
9742         return DAG.getSelect(DL, N1.getValueType(), WideSetCC, N1, N2);
9743       }
9744     }
9745 
9746     // Match VSELECTs into add with unsigned saturation.
9747     if (hasOperation(ISD::UADDSAT, VT)) {
9748       // Check if one of the arms of the VSELECT is a vector with all bits set.
9749       // If it's on the left side, invert the predicate to simplify the logic below.
9750       SDValue Other;
9751       ISD::CondCode SatCC = CC;
9752       if (ISD::isBuildVectorAllOnes(N1.getNode())) {
9753         Other = N2;
9754         SatCC = ISD::getSetCCInverse(SatCC, VT.getScalarType());
9755       } else if (ISD::isBuildVectorAllOnes(N2.getNode())) {
9756         Other = N1;
9757       }
9758 
9759       if (Other && Other.getOpcode() == ISD::ADD) {
9760         SDValue CondLHS = LHS, CondRHS = RHS;
9761         SDValue OpLHS = Other.getOperand(0), OpRHS = Other.getOperand(1);
9762 
9763         // Canonicalize condition operands.
9764         if (SatCC == ISD::SETUGE) {
9765           std::swap(CondLHS, CondRHS);
9766           SatCC = ISD::SETULE;
9767         }
9768 
9769         // We can test against either of the addition operands.
9770         // x <= x+y ? x+y : ~0 --> uaddsat x, y
9771         // x+y >= x ? x+y : ~0 --> uaddsat x, y
9772         if (SatCC == ISD::SETULE && Other == CondRHS &&
9773             (OpLHS == CondLHS || OpRHS == CondLHS))
9774           return DAG.getNode(ISD::UADDSAT, DL, VT, OpLHS, OpRHS);
9775 
9776         if (isa<BuildVectorSDNode>(OpRHS) && isa<BuildVectorSDNode>(CondRHS) &&
9777             CondLHS == OpLHS) {
9778           // If the RHS is a constant we have to reverse the const
9779           // canonicalization.
9780           // x >= ~C ? x+C : ~0 --> uaddsat x, C
9781           auto MatchUADDSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
9782             return Cond->getAPIntValue() == ~Op->getAPIntValue();
9783           };
9784           if (SatCC == ISD::SETULE &&
9785               ISD::matchBinaryPredicate(OpRHS, CondRHS, MatchUADDSAT))
9786             return DAG.getNode(ISD::UADDSAT, DL, VT, OpLHS, OpRHS);
9787         }
9788       }
9789     }
9790 
9791     // Match VSELECTs into sub with unsigned saturation.
9792     if (hasOperation(ISD::USUBSAT, VT)) {
9793       // Check if one of the arms of the VSELECT is a zero vector. If it's on
9794       // the left side, invert the predicate to simplify the logic below.
9795       SDValue Other;
9796       ISD::CondCode SatCC = CC;
9797       if (ISD::isBuildVectorAllZeros(N1.getNode())) {
9798         Other = N2;
9799         SatCC = ISD::getSetCCInverse(SatCC, VT.getScalarType());
9800       } else if (ISD::isBuildVectorAllZeros(N2.getNode())) {
9801         Other = N1;
9802       }
9803 
9804       if (Other && Other.getNumOperands() == 2 && Other.getOperand(0) == LHS) {
9805         SDValue CondRHS = RHS;
9806         SDValue OpLHS = Other.getOperand(0), OpRHS = Other.getOperand(1);
9807 
9808         // Look for a general sub with unsigned saturation first.
9809         // x >= y ? x-y : 0 --> usubsat x, y
9810         // x >  y ? x-y : 0 --> usubsat x, y
9811         if ((SatCC == ISD::SETUGE || SatCC == ISD::SETUGT) &&
9812             Other.getOpcode() == ISD::SUB && OpRHS == CondRHS)
9813           return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
9814 
9815         if (auto *OpRHSBV = dyn_cast<BuildVectorSDNode>(OpRHS)) {
9816           if (isa<BuildVectorSDNode>(CondRHS)) {
9817             // If the RHS is a constant we have to reverse the const
9818             // canonicalization.
9819             // x > C-1 ? x+-C : 0 --> usubsat x, C
9820             auto MatchUSUBSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
9821               return (!Op && !Cond) ||
9822                      (Op && Cond &&
9823                       Cond->getAPIntValue() == (-Op->getAPIntValue() - 1));
9824             };
9825             if (SatCC == ISD::SETUGT && Other.getOpcode() == ISD::ADD &&
9826                 ISD::matchBinaryPredicate(OpRHS, CondRHS, MatchUSUBSAT,
9827                                           /*AllowUndefs*/ true)) {
9828               OpRHS = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
9829                                   OpRHS);
9830               return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
9831             }
9832 
9833             // Another special case: If C was a sign bit, the sub has been
9834             // canonicalized into a xor.
9835             // FIXME: Would it be better to use computeKnownBits to determine
9836             //        whether it's safe to decanonicalize the xor?
9837             // x s< 0 ? x^C : 0 --> usubsat x, C
9838             if (auto *OpRHSConst = OpRHSBV->getConstantSplatNode()) {
9839               if (SatCC == ISD::SETLT && Other.getOpcode() == ISD::XOR &&
9840                   ISD::isBuildVectorAllZeros(CondRHS.getNode()) &&
9841                   OpRHSConst->getAPIntValue().isSignMask()) {
9842                 // Note that we have to rebuild the RHS constant here to ensure
9843                 // we don't rely on particular values of undef lanes.
9844                 OpRHS = DAG.getConstant(OpRHSConst->getAPIntValue(), DL, VT);
9845                 return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
9846               }
9847             }
9848           }
9849         }
9850       }
9851     }
9852   }
9853 
9854   if (SimplifySelectOps(N, N1, N2))
9855     return SDValue(N, 0);  // Don't revisit N.
9856 
9857   // Fold (vselect (build_vector all_ones), N1, N2) -> N1
9858   if (ISD::isBuildVectorAllOnes(N0.getNode()))
9859     return N1;
9860   // Fold (vselect (build_vector all_zeros), N1, N2) -> N2
9861   if (ISD::isBuildVectorAllZeros(N0.getNode()))
9862     return N2;
9863 
9864   // The ConvertSelectToConcatVector function assumes both the above
9865   // checks for (vselect (build_vector all{ones,zeros}) ...) have been made
9866   // and addressed.
9867   if (N1.getOpcode() == ISD::CONCAT_VECTORS &&
9868       N2.getOpcode() == ISD::CONCAT_VECTORS &&
9869       ISD::isBuildVectorOfConstantSDNodes(N0.getNode())) {
9870     if (SDValue CV = ConvertSelectToConcatVector(N, DAG))
9871       return CV;
9872   }
9873 
9874   if (SDValue V = foldVSelectOfConstants(N))
9875     return V;
9876 
9877   return SDValue();
9878 }
9879 
9880 SDValue DAGCombiner::visitSELECT_CC(SDNode *N) {
9881   SDValue N0 = N->getOperand(0);
9882   SDValue N1 = N->getOperand(1);
9883   SDValue N2 = N->getOperand(2);
9884   SDValue N3 = N->getOperand(3);
9885   SDValue N4 = N->getOperand(4);
9886   ISD::CondCode CC = cast<CondCodeSDNode>(N4)->get();
9887 
9888   // fold select_cc lhs, rhs, x, x, cc -> x
9889   if (N2 == N3)
9890     return N2;
9891 
9892   // Determine if the condition we're dealing with is constant
9893   if (SDValue SCC = SimplifySetCC(getSetCCResultType(N0.getValueType()), N0, N1,
9894                                   CC, SDLoc(N), false)) {
9895     AddToWorklist(SCC.getNode());
9896 
9897     if (ConstantSDNode *SCCC = dyn_cast<ConstantSDNode>(SCC.getNode())) {
9898       if (!SCCC->isNullValue())
9899         return N2;    // cond always true -> true val
9900       else
9901         return N3;    // cond always false -> false val
9902     } else if (SCC->isUndef()) {
9903       // When the condition is UNDEF, just return the first operand. This is
9904       // coherent with the DAG creation: no setcc node is created in this case.
9905       return N2;
9906     } else if (SCC.getOpcode() == ISD::SETCC) {
9907       // Fold to a simpler select_cc
9908       SDValue SelectOp = DAG.getNode(
9909           ISD::SELECT_CC, SDLoc(N), N2.getValueType(), SCC.getOperand(0),
9910           SCC.getOperand(1), N2, N3, SCC.getOperand(2));
9911       SelectOp->setFlags(SCC->getFlags());
9912       return SelectOp;
9913     }
9914   }
9915 
9916   // If we can fold this based on the true/false value, do so.
9917   if (SimplifySelectOps(N, N2, N3))
9918     return SDValue(N, 0);  // Don't revisit N.
9919 
9920   // fold select_cc into other things, such as min/max/abs
9921   return SimplifySelectCC(SDLoc(N), N0, N1, N2, N3, CC);
9922 }
9923 
9924 SDValue DAGCombiner::visitSETCC(SDNode *N) {
9925   // setcc is very commonly used as an argument to brcond. This pattern
9926   // also lends itself to numerous combines and, as a result, it is desirable
9927   // to keep the argument to a brcond as a setcc as much as possible.
9928   bool PreferSetCC =
9929       N->hasOneUse() && N->use_begin()->getOpcode() == ISD::BRCOND;
9930 
9931   SDValue Combined = SimplifySetCC(
9932       N->getValueType(0), N->getOperand(0), N->getOperand(1),
9933       cast<CondCodeSDNode>(N->getOperand(2))->get(), SDLoc(N), !PreferSetCC);
9934 
9935   if (!Combined)
9936     return SDValue();
9937 
9938   // If we prefer to have a setcc, and we don't, we'll try our best to
9939   // recreate one using rebuildSetCC.
9940   if (PreferSetCC && Combined.getOpcode() != ISD::SETCC) {
9941     SDValue NewSetCC = rebuildSetCC(Combined);
9942 
9943     // We don't have anything interesting to combine to.
9944     if (NewSetCC.getNode() == N)
9945       return SDValue();
9946 
9947     if (NewSetCC)
9948       return NewSetCC;
9949   }
9950 
9951   return Combined;
9952 }
9953 
9954 SDValue DAGCombiner::visitSETCCCARRY(SDNode *N) {
9955   SDValue LHS = N->getOperand(0);
9956   SDValue RHS = N->getOperand(1);
9957   SDValue Carry = N->getOperand(2);
9958   SDValue Cond = N->getOperand(3);
9959 
9960   // If Carry is false, fold to a regular SETCC.
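  // e.g. (setcccarry LHS, RHS, 0, ult) --> (setcc LHS, RHS, ult)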
9961   if (isNullConstant(Carry))
9962     return DAG.getNode(ISD::SETCC, SDLoc(N), N->getVTList(), LHS, RHS, Cond);
9963 
9964   return SDValue();
9965 }
9966 
9967 /// Try to fold a sext/zext/aext dag node into a ConstantSDNode or
9968 /// a build_vector of constants.
9969 /// This function is called by the DAGCombiner when visiting sext/zext/aext
9970 /// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
9971 /// Vector extends are not folded if operations are legal; this is to
9972 /// avoid introducing illegal build_vector dag nodes.
9973 static SDValue tryToFoldExtendOfConstant(SDNode *N, const TargetLowering &TLI,
9974                                          SelectionDAG &DAG, bool LegalTypes) {
9975   unsigned Opcode = N->getOpcode();
9976   SDValue N0 = N->getOperand(0);
9977   EVT VT = N->getValueType(0);
9978   SDLoc DL(N);
9979 
9980   assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
9981          Opcode == ISD::ANY_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG ||
9982          Opcode == ISD::ZERO_EXTEND_VECTOR_INREG)
9983          && "Expected EXTEND dag node in input!");
9984 
9985   // fold (sext c1) -> c1
9986   // fold (zext c1) -> c1
9987   // fold (aext c1) -> c1
9988   if (isa<ConstantSDNode>(N0))
9989     return DAG.getNode(Opcode, DL, VT, N0);
9990 
9991   // fold (sext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
9992   // fold (zext (select cond, c1, c2)) -> (select cond, zext c1, zext c2)
9993   // fold (aext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
9994   if (N0->getOpcode() == ISD::SELECT) {
9995     SDValue Op1 = N0->getOperand(1);
9996     SDValue Op2 = N0->getOperand(2);
9997     if (isa<ConstantSDNode>(Op1) && isa<ConstantSDNode>(Op2) &&
9998         (Opcode != ISD::ZERO_EXTEND || !TLI.isZExtFree(N0.getValueType(), VT))) {
9999       // For any_extend, choose sign extension of the constants to allow a
10000       // possible further transform to sign_extend_inreg, i.e.:
10001       //
10002       // t1: i8 = select t0, Constant:i8<-1>, Constant:i8<0>
10003       // t2: i64 = any_extend t1
10004       // -->
10005       // t3: i64 = select t0, Constant:i64<-1>, Constant:i64<0>
10006       // -->
10007       // t4: i64 = sign_extend_inreg t3
10008       unsigned FoldOpc = Opcode;
10009       if (FoldOpc == ISD::ANY_EXTEND)
10010         FoldOpc = ISD::SIGN_EXTEND;
10011       return DAG.getSelect(DL, VT, N0->getOperand(0),
10012                            DAG.getNode(FoldOpc, DL, VT, Op1),
10013                            DAG.getNode(FoldOpc, DL, VT, Op2));
10014     }
10015   }
10016 
10017   // fold (sext (build_vector AllConstants) -> (build_vector AllConstants)
10018   // fold (zext (build_vector AllConstants) -> (build_vector AllConstants)
10019   // fold (aext (build_vector AllConstants) -> (build_vector AllConstants)
10020   EVT SVT = VT.getScalarType();
10021   if (!(VT.isVector() && (!LegalTypes || TLI.isTypeLegal(SVT)) &&
10022       ISD::isBuildVectorOfConstantSDNodes(N0.getNode())))
10023     return SDValue();
10024 
10025   // We can fold this node into a build_vector.
10026   unsigned VTBits = SVT.getSizeInBits();
10027   unsigned EVTBits = N0->getValueType(0).getScalarSizeInBits();
10028   SmallVector<SDValue, 8> Elts;
10029   unsigned NumElts = VT.getVectorNumElements();
10030 
10031   // For zero-extensions, UNDEF elements are still guaranteed to have the upper
10032   // bits set to zero.
10033   bool IsZext =
10034       Opcode == ISD::ZERO_EXTEND || Opcode == ISD::ZERO_EXTEND_VECTOR_INREG;
10035 
10036   for (unsigned i = 0; i != NumElts; ++i) {
10037     SDValue Op = N0.getOperand(i);
10038     if (Op.isUndef()) {
10039       Elts.push_back(IsZext ? DAG.getConstant(0, DL, SVT) : DAG.getUNDEF(SVT));
10040       continue;
10041     }
10042 
10043     SDLoc DL(Op);
10044     // Get the constant value and if needed trunc it to the size of the type.
10045     // Nodes like build_vector might have constants wider than the scalar type.
10046     APInt C = cast<ConstantSDNode>(Op)->getAPIntValue().zextOrTrunc(EVTBits);
10047     if (Opcode == ISD::SIGN_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG)
10048       Elts.push_back(DAG.getConstant(C.sext(VTBits), DL, SVT));
10049     else
10050       Elts.push_back(DAG.getConstant(C.zext(VTBits), DL, SVT));
10051   }
10052 
10053   return DAG.getBuildVector(VT, DL, Elts);
10054 }
10055 
10056 // ExtendUsesToFormExtLoad - Trying to extend uses of a load to enable this:
10057 // "fold ({s|z|a}ext (load x)) -> ({s|z|a}ext (truncate ({s|z|a}extload x)))"
10058 // transformation. Returns true if extensions are possible and the
10059 // above-mentioned transformation is profitable.
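// For example, if (load x) is also used by (setcc (load x), C), the setcc can
// later have its operands extended too (see ExtendSetCCUses), so it does not
// block the transformation; such users are collected in ExtendNodes.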
10060 static bool ExtendUsesToFormExtLoad(EVT VT, SDNode *N, SDValue N0,
10061                                     unsigned ExtOpc,
10062                                     SmallVectorImpl<SDNode *> &ExtendNodes,
10063                                     const TargetLowering &TLI) {
10064   bool HasCopyToRegUses = false;
10065   bool isTruncFree = TLI.isTruncateFree(VT, N0.getValueType());
10066   for (SDNode::use_iterator UI = N0.getNode()->use_begin(),
10067                             UE = N0.getNode()->use_end();
10068        UI != UE; ++UI) {
10069     SDNode *User = *UI;
10070     if (User == N)
10071       continue;
10072     if (UI.getUse().getResNo() != N0.getResNo())
10073       continue;
10074     // FIXME: Only extend SETCC N, N and SETCC N, c for now.
10075     if (ExtOpc != ISD::ANY_EXTEND && User->getOpcode() == ISD::SETCC) {
10076       ISD::CondCode CC = cast<CondCodeSDNode>(User->getOperand(2))->get();
10077       if (ExtOpc == ISD::ZERO_EXTEND && ISD::isSignedIntSetCC(CC))
10078         // Sign bits will be lost after a zext.
10079         return false;
10080       bool Add = false;
10081       for (unsigned i = 0; i != 2; ++i) {
10082         SDValue UseOp = User->getOperand(i);
10083         if (UseOp == N0)
10084           continue;
10085         if (!isa<ConstantSDNode>(UseOp))
10086           return false;
10087         Add = true;
10088       }
10089       if (Add)
10090         ExtendNodes.push_back(User);
10091       continue;
10092     }
10093     // If truncates aren't free and there are users we can't
10094     // extend, it isn't worthwhile.
10095     if (!isTruncFree)
10096       return false;
10097     // Remember if this value is live-out.
10098     if (User->getOpcode() == ISD::CopyToReg)
10099       HasCopyToRegUses = true;
10100   }
10101 
10102   if (HasCopyToRegUses) {
10103     bool BothLiveOut = false;
10104     for (SDNode::use_iterator UI = N->use_begin(), UE = N->use_end();
10105          UI != UE; ++UI) {
10106       SDUse &Use = UI.getUse();
10107       if (Use.getResNo() == 0 && Use.getUser()->getOpcode() == ISD::CopyToReg) {
10108         BothLiveOut = true;
10109         break;
10110       }
10111     }
10112     if (BothLiveOut)
10113       // Both unextended and extended values are live out. There had better be
10114       // a good reason for the transformation.
10115       return ExtendNodes.size();
10116   }
10117   return true;
10118 }
10119 
10120 void DAGCombiner::ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
10121                                   SDValue OrigLoad, SDValue ExtLoad,
10122                                   ISD::NodeType ExtType) {
10123   // Extend SetCC uses if necessary.
10124   SDLoc DL(ExtLoad);
10125   for (SDNode *SetCC : SetCCs) {
10126     SmallVector<SDValue, 4> Ops;
10127 
10128     for (unsigned j = 0; j != 2; ++j) {
10129       SDValue SOp = SetCC->getOperand(j);
10130       if (SOp == OrigLoad)
10131         Ops.push_back(ExtLoad);
10132       else
10133         Ops.push_back(DAG.getNode(ExtType, DL, ExtLoad->getValueType(0), SOp));
10134     }
10135 
10136     Ops.push_back(SetCC->getOperand(2));
10137     CombineTo(SetCC, DAG.getNode(ISD::SETCC, DL, SetCC->getValueType(0), Ops));
10138   }
10139 }
10140 
10141 // FIXME: Bring more similar combines here, common to sext/zext (maybe aext?).
10142 SDValue DAGCombiner::CombineExtLoad(SDNode *N) {
10143   SDValue N0 = N->getOperand(0);
10144   EVT DstVT = N->getValueType(0);
10145   EVT SrcVT = N0.getValueType();
10146 
10147   assert((N->getOpcode() == ISD::SIGN_EXTEND ||
10148           N->getOpcode() == ISD::ZERO_EXTEND) &&
10149          "Unexpected node type (not an extend)!");
10150 
10151   // fold (sext (load x)) to multiple smaller sextloads; same for zext.
10152   // For example, on a target with legal v4i32, but illegal v8i32, turn:
10153   //   (v8i32 (sext (v8i16 (load x))))
10154   // into:
10155   //   (v8i32 (concat_vectors (v4i32 (sextload x)),
10156   //                          (v4i32 (sextload (x + 16)))))
10157   // Where uses of the original load, i.e.:
10158   //   (v8i16 (load x))
10159   // are replaced with:
10160   //   (v8i16 (truncate
10161   //     (v8i32 (concat_vectors (v4i32 (sextload x)),
10162   //                            (v4i32 (sextload (x + 16)))))))
10163   //
10164   // This combine is only applicable to illegal, but splittable, vectors.
10165   // All legal types, and illegal non-vector types, are handled elsewhere.
10166   // This combine is controlled by TargetLowering::isVectorLoadExtDesirable.
10167   //
10168   if (N0->getOpcode() != ISD::LOAD)
10169     return SDValue();
10170 
10171   LoadSDNode *LN0 = cast<LoadSDNode>(N0);
10172 
10173   if (!ISD::isNON_EXTLoad(LN0) || !ISD::isUNINDEXEDLoad(LN0) ||
10174       !N0.hasOneUse() || !LN0->isSimple() ||
10175       !DstVT.isVector() || !DstVT.isPow2VectorType() ||
10176       !TLI.isVectorLoadExtDesirable(SDValue(N, 0)))
10177     return SDValue();
10178 
10179   SmallVector<SDNode *, 4> SetCCs;
10180   if (!ExtendUsesToFormExtLoad(DstVT, N, N0, N->getOpcode(), SetCCs, TLI))
10181     return SDValue();
10182 
10183   ISD::LoadExtType ExtType =
10184       N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
10185 
10186   // Try to split the vector types to get down to legal types.
10187   EVT SplitSrcVT = SrcVT;
10188   EVT SplitDstVT = DstVT;
10189   while (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT) &&
10190          SplitSrcVT.getVectorNumElements() > 1) {
10191     SplitDstVT = DAG.GetSplitDestVTs(SplitDstVT).first;
10192     SplitSrcVT = DAG.GetSplitDestVTs(SplitSrcVT).first;
10193   }
10194 
10195   if (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT))
10196     return SDValue();
10197 
10198   assert(!DstVT.isScalableVector() && "Unexpected scalable vector type");
10199 
10200   SDLoc DL(N);
10201   const unsigned NumSplits =
10202       DstVT.getVectorNumElements() / SplitDstVT.getVectorNumElements();
10203   const unsigned Stride = SplitSrcVT.getStoreSize();
10204   SmallVector<SDValue, 4> Loads;
10205   SmallVector<SDValue, 4> Chains;
10206 
10207   SDValue BasePtr = LN0->getBasePtr();
10208   for (unsigned Idx = 0; Idx < NumSplits; Idx++) {
10209     const unsigned Offset = Idx * Stride;
10210     const Align Align = commonAlignment(LN0->getAlign(), Offset);
10211 
10212     SDValue SplitLoad = DAG.getExtLoad(
10213         ExtType, SDLoc(LN0), SplitDstVT, LN0->getChain(), BasePtr,
10214         LN0->getPointerInfo().getWithOffset(Offset), SplitSrcVT, Align,
10215         LN0->getMemOperand()->getFlags(), LN0->getAAInfo());
10216 
10217     BasePtr = DAG.getMemBasePlusOffset(BasePtr, TypeSize::Fixed(Stride), DL);
10218 
10219     Loads.push_back(SplitLoad.getValue(0));
10220     Chains.push_back(SplitLoad.getValue(1));
10221   }
10222 
10223   SDValue NewChain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Chains);
10224   SDValue NewValue = DAG.getNode(ISD::CONCAT_VECTORS, DL, DstVT, Loads);
10225 
10226   // Simplify TF.
10227   AddToWorklist(NewChain.getNode());
10228 
10229   CombineTo(N, NewValue);
10230 
10231   // Replace uses of the original load (before extension)
10232   // with a truncate of the concatenated sextloaded vectors.
10233   SDValue Trunc =
10234       DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), NewValue);
10235   ExtendSetCCUses(SetCCs, N0, NewValue, (ISD::NodeType)N->getOpcode());
10236   CombineTo(N0.getNode(), Trunc, NewChain);
10237   return SDValue(N, 0); // Return N so it doesn't get rechecked!
10238 }
10239 
10240 // fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
10241 //      (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
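// A concrete instance (illustrative types only):
//   (i64 (zext (and (srl (i32 (load x)), 8), 255)))
//   --> (i64 (and (srl (zextload x), 8), 255))
// where the logic-op constant is zero-extended and the shift amount is reused.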
10242 SDValue DAGCombiner::CombineZExtLogicopShiftLoad(SDNode *N) {
10243   assert(N->getOpcode() == ISD::ZERO_EXTEND);
10244   EVT VT = N->getValueType(0);
10245   EVT OrigVT = N->getOperand(0).getValueType();
10246   if (TLI.isZExtFree(OrigVT, VT))
10247     return SDValue();
10248 
10249   // and/or/xor
10250   SDValue N0 = N->getOperand(0);
10251   if (!(N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR ||
10252         N0.getOpcode() == ISD::XOR) ||
10253       N0.getOperand(1).getOpcode() != ISD::Constant ||
10254       (LegalOperations && !TLI.isOperationLegal(N0.getOpcode(), VT)))
10255     return SDValue();
10256 
10257   // shl/shr
10258   SDValue N1 = N0->getOperand(0);
10259   if (!(N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) ||
10260       N1.getOperand(1).getOpcode() != ISD::Constant ||
10261       (LegalOperations && !TLI.isOperationLegal(N1.getOpcode(), VT)))
10262     return SDValue();
10263 
10264   // load
10265   if (!isa<LoadSDNode>(N1.getOperand(0)))
10266     return SDValue();
10267   LoadSDNode *Load = cast<LoadSDNode>(N1.getOperand(0));
10268   EVT MemVT = Load->getMemoryVT();
10269   if (!TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) ||
10270       Load->getExtensionType() == ISD::SEXTLOAD || Load->isIndexed())
10271     return SDValue();
10272 
10273 
10274   // If the shift op is SHL, the logic op must be AND, otherwise the result
10275   // will be wrong.
  if (N1.getOpcode() == ISD::SHL && N0.getOpcode() != ISD::AND)
    return SDValue();

  if (!N0.hasOneUse() || !N1.hasOneUse())
    return SDValue();

  SmallVector<SDNode*, 4> SetCCs;
  if (!ExtendUsesToFormExtLoad(VT, N1.getNode(), N1.getOperand(0),
                               ISD::ZERO_EXTEND, SetCCs, TLI))
    return SDValue();

  // Actually do the transformation.
  SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(Load), VT,
                                   Load->getChain(), Load->getBasePtr(),
                                   Load->getMemoryVT(), Load->getMemOperand());

  SDLoc DL1(N1);
  SDValue Shift = DAG.getNode(N1.getOpcode(), DL1, VT, ExtLoad,
                              N1.getOperand(1));

  APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
  SDLoc DL0(N0);
  SDValue And = DAG.getNode(N0.getOpcode(), DL0, VT, Shift,
                            DAG.getConstant(Mask, DL0, VT));

  ExtendSetCCUses(SetCCs, N1.getOperand(0), ExtLoad, ISD::ZERO_EXTEND);
  CombineTo(N, And);
  if (SDValue(Load, 0).hasOneUse()) {
    DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), ExtLoad.getValue(1));
  } else {
    SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(Load),
                                Load->getValueType(0), ExtLoad);
    CombineTo(Load, Trunc, ExtLoad.getValue(1));
  }

  // N0 is dead at this point.
  recursivelyDeleteUnusedNodes(N0.getNode());

  return SDValue(N,0); // Return N so it doesn't get rechecked!
}

/// If we're narrowing or widening the result of a vector select and the final
/// size is the same size as a setcc (compare) feeding the select, then try to
/// apply the cast operation to the select's operands because matching vector
/// sizes for a select condition and other operands should be more efficient.
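/// An illustrative case (not from the source):
///   trunc (vselect (setcc v4i32 X, Y), v4i64 A, v4i64 B) to v4i32
///     --> vselect (setcc v4i32 X, Y), (trunc A), (trunc B)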
SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
  unsigned CastOpcode = Cast->getOpcode();
  assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
          CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
          CastOpcode == ISD::FP_ROUND) &&
         "Unexpected opcode for vector select narrowing/widening");

  // We only do this transform before legal ops because the pattern may be
  // obfuscated by target-specific operations after legalization. Do not create
  // an illegal select op, however, because that may be difficult to lower.
  EVT VT = Cast->getValueType(0);
  if (LegalOperations || !TLI.isOperationLegalOrCustom(ISD::VSELECT, VT))
    return SDValue();

  SDValue VSel = Cast->getOperand(0);
  if (VSel.getOpcode() != ISD::VSELECT || !VSel.hasOneUse() ||
      VSel.getOperand(0).getOpcode() != ISD::SETCC)
    return SDValue();

  // Does the setcc have the same vector size as the casted select?
  SDValue SetCC = VSel.getOperand(0);
  EVT SetCCVT = getSetCCResultType(SetCC.getOperand(0).getValueType());
  if (SetCCVT.getSizeInBits() != VT.getSizeInBits())
    return SDValue();

  // cast (vsel (setcc X), A, B) --> vsel (setcc X), (cast A), (cast B)
  SDValue A = VSel.getOperand(1);
  SDValue B = VSel.getOperand(2);
  SDValue CastA, CastB;
  SDLoc DL(Cast);
  if (CastOpcode == ISD::FP_ROUND) {
    // FP_ROUND (fptrunc) has an extra flag operand to pass along.
    CastA = DAG.getNode(CastOpcode, DL, VT, A, Cast->getOperand(1));
    CastB = DAG.getNode(CastOpcode, DL, VT, B, Cast->getOperand(1));
  } else {
    CastA = DAG.getNode(CastOpcode, DL, VT, A);
    CastB = DAG.getNode(CastOpcode, DL, VT, B);
  }
  return DAG.getNode(ISD::VSELECT, DL, VT, SetCC, CastA, CastB);
}

// fold ([s|z]ext ([s|z]extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
// fold ([s|z]ext (     extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
static SDValue tryToFoldExtOfExtload(SelectionDAG &DAG, DAGCombiner &Combiner,
                                     const TargetLowering &TLI, EVT VT,
                                     bool LegalOperations, SDNode *N,
                                     SDValue N0, ISD::LoadExtType ExtLoadType) {
  SDNode *N0Node = N0.getNode();
  bool isAExtLoad = (ExtLoadType == ISD::SEXTLOAD) ? ISD::isSEXTLoad(N0Node)
                                                   : ISD::isZEXTLoad(N0Node);
  if ((!isAExtLoad && !ISD::isEXTLoad(N0Node)) ||
      !ISD::isUNINDEXEDLoad(N0Node) || !N0.hasOneUse())
    return SDValue();

  LoadSDNode *LN0 = cast<LoadSDNode>(N0);
  EVT MemVT = LN0->getMemoryVT();
  if ((LegalOperations || !LN0->isSimple() ||
       VT.isVector()) &&
      !TLI.isLoadExtLegal(ExtLoadType, VT, MemVT))
    return SDValue();

  SDValue ExtLoad =
      DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(),
                     LN0->getBasePtr(), MemVT, LN0->getMemOperand());
  Combiner.CombineTo(N, ExtLoad);
  DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
  if (LN0->use_empty())
    Combiner.recursivelyDeleteUnusedNodes(LN0);
  return SDValue(N, 0); // Return N so it doesn't get rechecked!
}

// fold ([s|z]ext (load x)) -> ([s|z]ext (truncate ([s|z]extload x)))
// Only generate vector extloads when 1) they're legal, and 2) they are
// deemed desirable by the target.
static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner,
                                  const TargetLowering &TLI, EVT VT,
                                  bool LegalOperations, SDNode *N, SDValue N0,
                                  ISD::LoadExtType ExtLoadType,
                                  ISD::NodeType ExtOpc) {
  if (!ISD::isNON_EXTLoad(N0.getNode()) ||
      !ISD::isUNINDEXEDLoad(N0.getNode()) ||
      ((LegalOperations || VT.isVector() ||
        !cast<LoadSDNode>(N0)->isSimple()) &&
       !TLI.isLoadExtLegal(ExtLoadType, VT, N0.getValueType())))
    return {};

  bool DoXform = true;
  SmallVector<SDNode *, 4> SetCCs;
  if (!N0.hasOneUse())
    DoXform = ExtendUsesToFormExtLoad(VT, N, N0, ExtOpc, SetCCs, TLI);
  if (VT.isVector())
    DoXform &= TLI.isVectorLoadExtDesirable(SDValue(N, 0));
  if (!DoXform)
    return {};

  LoadSDNode *LN0 = cast<LoadSDNode>(N0);
  SDValue ExtLoad = DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(),
                                   LN0->getBasePtr(), N0.getValueType(),
                                   LN0->getMemOperand());
  Combiner.ExtendSetCCUses(SetCCs, N0, ExtLoad, ExtOpc);
  // If the load value is used only by N, replace it via CombineTo N.
  bool NoReplaceTrunc = SDValue(LN0, 0).hasOneUse();
  Combiner.CombineTo(N, ExtLoad);
  if (NoReplaceTrunc) {
    DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
    Combiner.recursivelyDeleteUnusedNodes(LN0);
  } else {
    SDValue Trunc =
        DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), ExtLoad);
    Combiner.CombineTo(LN0, Trunc, ExtLoad.getValue(1));
  }
  return SDValue(N, 0); // Return N so it doesn't get rechecked!
}

static SDValue tryToFoldExtOfMaskedLoad(SelectionDAG &DAG,
                                        const TargetLowering &TLI, EVT VT,
                                        SDNode *N, SDValue N0,
                                        ISD::LoadExtType ExtLoadType,
                                        ISD::NodeType ExtOpc) {
  if (!N0.hasOneUse())
    return SDValue();

  MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(N0);
  if (!Ld || Ld->getExtensionType() != ISD::NON_EXTLOAD)
    return SDValue();

  if (!TLI.isLoadExtLegal(ExtLoadType, VT, Ld->getValueType(0)))
    return SDValue();

  if (!TLI.isVectorLoadExtDesirable(SDValue(N, 0)))
    return SDValue();

  SDLoc dl(Ld);
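  // Lanes where the mask is false take their value from the pass-through
  // operand, so it must be extended to the wider result type up front.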
  SDValue PassThru = DAG.getNode(ExtOpc, dl, VT, Ld->getPassThru());
  SDValue NewLoad = DAG.getMaskedLoad(
      VT, dl, Ld->getChain(), Ld->getBasePtr(), Ld->getOffset(), Ld->getMask(),
      PassThru, Ld->getMemoryVT(), Ld->getMemOperand(), Ld->getAddressingMode(),
      ExtLoadType, Ld->isExpandingLoad());
  DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), SDValue(NewLoad.getNode(), 1));
  return NewLoad;
}

static SDValue foldExtendedSignBitTest(SDNode *N, SelectionDAG &DAG,
                                       bool LegalOperations) {
  assert((N->getOpcode() == ISD::SIGN_EXTEND ||
          N->getOpcode() == ISD::ZERO_EXTEND) && "Expected sext or zext");

  SDValue SetCC = N->getOperand(0);
  if (LegalOperations || SetCC.getOpcode() != ISD::SETCC ||
      !SetCC.hasOneUse() || SetCC.getValueType() != MVT::i1)
    return SDValue();

  SDValue X = SetCC.getOperand(0);
  SDValue Ones = SetCC.getOperand(1);
  ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get();
  EVT VT = N->getValueType(0);
  EVT XVT = X.getValueType();
  // setge X, C is canonicalized to setgt, so we do not need to match that
  // pattern. The setlt sibling is folded in SimplifySelectCC() because it does
  // not require the 'not' op.
  if (CC == ISD::SETGT && isAllOnesConstant(Ones) && VT == XVT) {
    // Invert and smear/shift the sign bit:
    // sext i1 (setgt iN X, -1) --> sra (not X), (N - 1)
    // zext i1 (setgt iN X, -1) --> srl (not X), (N - 1)
    SDLoc DL(N);
    unsigned ShCt = VT.getSizeInBits() - 1;
    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
    if (!TLI.shouldAvoidTransformToShift(VT, ShCt)) {
      SDValue NotX = DAG.getNOT(DL, X, VT);
      SDValue ShiftAmount = DAG.getConstant(ShCt, DL, VT);
      auto ShiftOpcode =
        N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SRA : ISD::SRL;
      return DAG.getNode(ShiftOpcode, DL, VT, NotX, ShiftAmount);
    }
  }
  return SDValue();
}

SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
    return Res;

  // fold (sext (sext x)) -> (sext x)
  // fold (sext (aext x)) -> (sext x)
  if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND)
    return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N0.getOperand(0));

  if (N0.getOpcode() == ISD::TRUNCATE) {
    // fold (sext (truncate (load x))) -> (sext (smaller load x))
    // fold (sext (truncate (srl (load x), c))) -> (sext (smaller load (x+c/n)))
    if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
      SDNode *oye = N0.getOperand(0).getNode();
      if (NarrowLoad.getNode() != N0.getNode()) {
        CombineTo(N0.getNode(), NarrowLoad);
        // CombineTo deleted the truncate, if needed, but not what's under it.
        AddToWorklist(oye);
      }
      return SDValue(N, 0);   // Return N so it doesn't get rechecked!
    }

    // See if the value being truncated is already sign extended.  If so, just
    // eliminate the trunc/sext pair.
    SDValue Op = N0.getOperand(0);
    unsigned OpBits   = Op.getScalarValueSizeInBits();
    unsigned MidBits  = N0.getScalarValueSizeInBits();
    unsigned DestBits = VT.getScalarSizeInBits();
    unsigned NumSignBits = DAG.ComputeNumSignBits(Op);
    if (OpBits == DestBits) {
      // Op is i32, Mid is i8, and Dest is i32.  If Op has more than 24 sign
      // bits, it is already sign-extended and can be used as-is.
      if (NumSignBits > DestBits-MidBits)
        return Op;
    } else if (OpBits < DestBits) {
      // Op is i32, Mid is i8, and Dest is i64.  If Op has more than 24 sign
      // bits, just sext from i32.
      if (NumSignBits > OpBits-MidBits)
        return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op);
    } else {
      // Op is i64, Mid is i8, and Dest is i32.  If Op has more than 56 sign
      // bits, just truncate to i32.
      if (NumSignBits > OpBits-MidBits)
        return DAG.getNode(ISD::TRUNCATE, DL, VT, Op);
    }

    // fold (sext (truncate x)) -> (sextinreg x).
    if (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND_INREG,
                                                 N0.getValueType())) {
      if (OpBits < DestBits)
        Op = DAG.getNode(ISD::ANY_EXTEND, SDLoc(N0), VT, Op);
      else if (OpBits > DestBits)
        Op = DAG.getNode(ISD::TRUNCATE, SDLoc(N0), VT, Op);
      return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, Op,
                         DAG.getValueType(N0.getValueType()));
    }
  }

  // Try to simplify (sext (load x)).
  if (SDValue foldedExt =
          tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
                             ISD::SEXTLOAD, ISD::SIGN_EXTEND))
    return foldedExt;

  if (SDValue foldedExt =
      tryToFoldExtOfMaskedLoad(DAG, TLI, VT, N, N0, ISD::SEXTLOAD,
                               ISD::SIGN_EXTEND))
    return foldedExt;

  // fold (sext (load x)) to multiple smaller sextloads.
  // Only on illegal but splittable vectors.
  if (SDValue ExtLoad = CombineExtLoad(N))
    return ExtLoad;

  // Try to simplify (sext (sextload x)).
  if (SDValue foldedExt = tryToFoldExtOfExtload(
          DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::SEXTLOAD))
    return foldedExt;

  // fold (sext (and/or/xor (load x), cst)) ->
  //      (and/or/xor (sextload x), (sext cst))
  if ((N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR ||
       N0.getOpcode() == ISD::XOR) &&
      isa<LoadSDNode>(N0.getOperand(0)) &&
      N0.getOperand(1).getOpcode() == ISD::Constant &&
      (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
    LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
    EVT MemVT = LN00->getMemoryVT();
    if (TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, MemVT) &&
        LN00->getExtensionType() != ISD::ZEXTLOAD && LN00->isUnindexed()) {
      SmallVector<SDNode*, 4> SetCCs;
      bool DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
                                             ISD::SIGN_EXTEND, SetCCs, TLI);
      if (DoXform) {
        SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(LN00), VT,
                                         LN00->getChain(), LN00->getBasePtr(),
                                         LN00->getMemoryVT(),
                                         LN00->getMemOperand());
        APInt Mask = N0.getConstantOperandAPInt(1).sext(VT.getSizeInBits());
        SDValue And = DAG.getNode(N0.getOpcode(), DL, VT,
                                  ExtLoad, DAG.getConstant(Mask, DL, VT));
        ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::SIGN_EXTEND);
        bool NoReplaceTruncAnd = !N0.hasOneUse();
        bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
        CombineTo(N, And);
        // If N0 has multiple uses, change other uses as well.
        if (NoReplaceTruncAnd) {
          SDValue TruncAnd =
              DAG.getNode(ISD::TRUNCATE, DL, N0.getValueType(), And);
          CombineTo(N0.getNode(), TruncAnd);
        }
        if (NoReplaceTrunc) {
          DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1));
        } else {
          SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00),
                                      LN00->getValueType(0), ExtLoad);
          CombineTo(LN00, Trunc, ExtLoad.getValue(1));
        }
        return SDValue(N,0); // Return N so it doesn't get rechecked!
      }
    }
  }

  if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
    return V;

  if (N0.getOpcode() == ISD::SETCC) {
    SDValue N00 = N0.getOperand(0);
    SDValue N01 = N0.getOperand(1);
    ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
    EVT N00VT = N00.getValueType();

    // sext(setcc) -> sext_in_reg(vsetcc) for vectors.
    // Only do this before legalize for now.
    if (VT.isVector() && !LegalOperations &&
        TLI.getBooleanContents(N00VT) ==
            TargetLowering::ZeroOrNegativeOneBooleanContent) {
      // On some architectures (such as SSE/NEON/etc) the SETCC result type is
      // of the same size as the compared operands. Only optimize sext(setcc())
      // if this is the case.
      EVT SVT = getSetCCResultType(N00VT);

      // If we already have the desired type, don't change it.
      if (SVT != N0.getValueType()) {
        // We know that the # elements of the results is the same as the
        // # elements of the compare (and the # elements of the compare result
        // for that matter).  Check to see that they are the same size.  If so,
        // we know that the element size of the sext'd result matches the
        // element size of the compare operands.
        if (VT.getSizeInBits() == SVT.getSizeInBits())
          return DAG.getSetCC(DL, VT, N00, N01, CC);

        // If the desired elements are smaller or larger than the source
        // elements, we can use a matching integer vector type and then
        // truncate/sign extend.
        EVT MatchingVecType = N00VT.changeVectorElementTypeToInteger();
        if (SVT == MatchingVecType) {
          SDValue VsetCC = DAG.getSetCC(DL, MatchingVecType, N00, N01, CC);
          return DAG.getSExtOrTrunc(VsetCC, DL, VT);
        }
      }
    }

    // sext(setcc x, y, cc) -> (select (setcc x, y, cc), T, 0)
    // Here, T can be 1 or -1, depending on the type of the setcc and
    // getBooleanContents().
    unsigned SetCCWidth = N0.getScalarValueSizeInBits();

    // To determine the "true" side of the select, we need to know the high bit
    // of the value returned by the setcc if it evaluates to true.
    // If the type of the setcc is i1, then the true case of the select is just
    // sext(i1 1), that is, -1.
    // If the type of the setcc is larger (say, i8) then the value of the high
    // bit depends on getBooleanContents(), so ask TLI for a real "true" value
    // of the appropriate width.
    SDValue ExtTrueVal = (SetCCWidth == 1)
                             ? DAG.getAllOnesConstant(DL, VT)
                             : DAG.getBoolConstant(true, DL, VT, N00VT);
    SDValue Zero = DAG.getConstant(0, DL, VT);
    if (SDValue SCC =
            SimplifySelectCC(DL, N00, N01, ExtTrueVal, Zero, CC, true))
      return SCC;

    if (!VT.isVector() && !TLI.convertSelectOfConstantsToMath(VT)) {
      EVT SetCCVT = getSetCCResultType(N00VT);
      // Don't do this transform for i1 because there's a select transform
      // that would reverse it.
      // TODO: We should not do this transform at all without a target hook
      // because a sext is likely cheaper than a select?
      if (SetCCVT.getScalarSizeInBits() != 1 &&
          (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, N00VT))) {
        SDValue SetCC = DAG.getSetCC(DL, SetCCVT, N00, N01, CC);
        return DAG.getSelect(DL, VT, SetCC, ExtTrueVal, Zero);
      }
    }
  }

  // fold (sext x) -> (zext x) if the sign bit is known zero.
  if ((!LegalOperations || TLI.isOperationLegal(ISD::ZERO_EXTEND, VT)) &&
      DAG.SignBitIsZero(N0))
    return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0);

  if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
    return NewVSel;

  // Eliminate this sign extend by doing a negation in the destination type:
  // sext i32 (0 - (zext i8 X to i32)) to i64 --> 0 - (zext i8 X to i64)
  if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
      isNullOrNullSplat(N0.getOperand(0)) &&
      N0.getOperand(1).getOpcode() == ISD::ZERO_EXTEND &&
      TLI.isOperationLegalOrCustom(ISD::SUB, VT)) {
    SDValue Zext = DAG.getZExtOrTrunc(N0.getOperand(1).getOperand(0), DL, VT);
    return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), Zext);
  }
  // Eliminate this sign extend by doing a decrement in the destination type:
  // sext i32 ((zext i8 X to i32) + (-1)) to i64 --> (zext i8 X to i64) + (-1)
  if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
      isAllOnesOrAllOnesSplat(N0.getOperand(1)) &&
      N0.getOperand(0).getOpcode() == ISD::ZERO_EXTEND &&
      TLI.isOperationLegalOrCustom(ISD::ADD, VT)) {
    SDValue Zext = DAG.getZExtOrTrunc(N0.getOperand(0).getOperand(0), DL, VT);
    return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT));
  }

  // fold sext (not i1 X) -> add (zext i1 X), -1
  // TODO: This could be extended to handle bool vectors.
  if (N0.getValueType() == MVT::i1 && isBitwiseNot(N0) && N0.hasOneUse() &&
      (!LegalOperations || (TLI.isOperationLegal(ISD::ZERO_EXTEND, VT) &&
                            TLI.isOperationLegal(ISD::ADD, VT)))) {
    // If we can eliminate the 'not', the sext form should be better
    if (SDValue NewXor = visitXOR(N0.getNode())) {
      // Returning N0 is a form of in-visit replacement that may have
      // invalidated N0.
      if (NewXor.getNode() == N0.getNode()) {
        // Return SDValue here as the xor should have already been replaced in
        // this sext.
        return SDValue();
      } else {
        // Return a new sext with the new xor.
        return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, NewXor);
      }
    }

    SDValue Zext = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0));
    return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT));
  }

  return SDValue();
}

// isTruncateOf - If N is a truncate of some other value, return true and
// record the value being truncated in Op and which of Op's bits are zero/one
// in Known. This function computes KnownBits to avoid a duplicated call to
// computeKnownBits in the caller.
static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op,
                         KnownBits &Known) {
  if (N->getOpcode() == ISD::TRUNCATE) {
    Op = N->getOperand(0);
    Known = DAG.computeKnownBits(Op);
    return true;
  }

  if (N.getOpcode() != ISD::SETCC ||
      N.getValueType().getScalarType() != MVT::i1 ||
      cast<CondCodeSDNode>(N.getOperand(2))->get() != ISD::SETNE)
    return false;

  SDValue Op0 = N->getOperand(0);
  SDValue Op1 = N->getOperand(1);
  assert(Op0.getValueType() == Op1.getValueType());

  if (isNullOrNullSplat(Op0))
    Op = Op1;
  else if (isNullOrNullSplat(Op1))
    Op = Op0;
  else
    return false;

  Known = DAG.computeKnownBits(Op);

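  // A (setne Op, 0) behaves as a truncate of Op to i1 only when every bit of
  // Op other than bit 0 is known zero, i.e. Op is already 0 or 1.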
  return (Known.Zero | 1).isAllOnesValue();
}

/// Given an extending node with a pop-count operand, if the target does not
/// support a pop-count in the narrow source type but does support it in the
/// destination type, widen the pop-count to the destination type.
static SDValue widenCtPop(SDNode *Extend, SelectionDAG &DAG) {
  assert((Extend->getOpcode() == ISD::ZERO_EXTEND ||
          Extend->getOpcode() == ISD::ANY_EXTEND) && "Expected extend op");

  SDValue CtPop = Extend->getOperand(0);
  if (CtPop.getOpcode() != ISD::CTPOP || !CtPop.hasOneUse())
    return SDValue();

  EVT VT = Extend->getValueType(0);
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  if (TLI.isOperationLegalOrCustom(ISD::CTPOP, CtPop.getValueType()) ||
      !TLI.isOperationLegalOrCustom(ISD::CTPOP, VT))
    return SDValue();

  // zext (ctpop X) --> ctpop (zext X)
  SDLoc DL(Extend);
  SDValue NewZext = DAG.getZExtOrTrunc(CtPop.getOperand(0), DL, VT);
  return DAG.getNode(ISD::CTPOP, DL, VT, NewZext);
}

SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
    return Res;

  // fold (zext (zext x)) -> (zext x)
  // fold (zext (aext x)) -> (zext x)
  if (N0.getOpcode() == ISD::ZERO_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND)
    return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT,
                       N0.getOperand(0));

  // fold (zext (truncate x)) -> (zext x) or
  //      (zext (truncate x)) -> (truncate x)
  // This is valid when the truncated bits of x are already zero.
  SDValue Op;
  KnownBits Known;
  if (isTruncateOf(DAG, N0, Op, Known)) {
    APInt TruncatedBits =
      (Op.getScalarValueSizeInBits() == N0.getScalarValueSizeInBits()) ?
      APInt(Op.getScalarValueSizeInBits(), 0) :
      APInt::getBitsSet(Op.getScalarValueSizeInBits(),
                        N0.getScalarValueSizeInBits(),
                        std::min(Op.getScalarValueSizeInBits(),
                                 VT.getScalarSizeInBits()));
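    // TruncatedBits are the bits the truncate would discard that still lie
    // within the destination width; if they are all known zero, a single
    // zext/trunc of Op produces the same value as the zext(trunc) pair.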
    if (TruncatedBits.isSubsetOf(Known.Zero))
      return DAG.getZExtOrTrunc(Op, SDLoc(N), VT);
  }

  // fold (zext (truncate x)) -> (and x, mask)
  if (N0.getOpcode() == ISD::TRUNCATE) {
    // fold (zext (truncate (load x))) -> (zext (smaller load x))
    // fold (zext (truncate (srl (load x), c))) -> (zext (smaller load (x+c/n)))
    if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
      SDNode *oye = N0.getOperand(0).getNode();
      if (NarrowLoad.getNode() != N0.getNode()) {
        CombineTo(N0.getNode(), NarrowLoad);
        // CombineTo deleted the truncate, if needed, but not what's under it.
        AddToWorklist(oye);
      }
      return SDValue(N, 0); // Return N so it doesn't get rechecked!
    }

    EVT SrcVT = N0.getOperand(0).getValueType();
    EVT MinVT = N0.getValueType();

    // Try to mask before the extension to avoid having to generate a larger
    // mask, possibly over several sub-vectors.
    if (SrcVT.bitsLT(VT) && VT.isVector()) {
      if (!LegalOperations || (TLI.isOperationLegal(ISD::AND, SrcVT) &&
                               TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) {
        SDValue Op = N0.getOperand(0);
        Op = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT);
        AddToWorklist(Op.getNode());
        SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, SDLoc(N), VT);
        // Transfer the debug info; the new node is equivalent to N0.
        DAG.transferDbgValues(N0, ZExtOrTrunc);
        return ZExtOrTrunc;
      }
    }

    if (!LegalOperations || TLI.isOperationLegal(ISD::AND, VT)) {
      SDValue Op = DAG.getAnyExtOrTrunc(N0.getOperand(0), SDLoc(N), VT);
      AddToWorklist(Op.getNode());
      SDValue And = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT);
      // We may safely transfer the debug info describing the truncate node over
      // to the equivalent and operation.
      DAG.transferDbgValues(N0, And);
      return And;
    }
  }

  // Fold (zext (and (trunc x), cst)) -> (and x, cst),
  // if either of the casts is not free.
  if (N0.getOpcode() == ISD::AND &&
      N0.getOperand(0).getOpcode() == ISD::TRUNCATE &&
      N0.getOperand(1).getOpcode() == ISD::Constant &&
      (!TLI.isTruncateFree(N0.getOperand(0).getOperand(0).getValueType(),
                           N0.getValueType()) ||
       !TLI.isZExtFree(N0.getValueType(), VT))) {
    SDValue X = N0.getOperand(0).getOperand(0);
    X = DAG.getAnyExtOrTrunc(X, SDLoc(X), VT);
    APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
    SDLoc DL(N);
    return DAG.getNode(ISD::AND, DL, VT,
                       X, DAG.getConstant(Mask, DL, VT));
  }

  // Try to simplify (zext (load x)).
  if (SDValue foldedExt =
          tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
                             ISD::ZEXTLOAD, ISD::ZERO_EXTEND))
    return foldedExt;

  if (SDValue foldedExt =
      tryToFoldExtOfMaskedLoad(DAG, TLI, VT, N, N0, ISD::ZEXTLOAD,
                               ISD::ZERO_EXTEND))
    return foldedExt;

  // fold (zext (load x)) to multiple smaller zextloads.
  // Only on illegal but splittable vectors.
  if (SDValue ExtLoad = CombineExtLoad(N))
    return ExtLoad;

  // fold (zext (and/or/xor (load x), cst)) ->
  //      (and/or/xor (zextload x), (zext cst))
  // Unless (and (load x) cst) will match as a zextload already and has
  // additional users.
  if ((N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR ||
       N0.getOpcode() == ISD::XOR) &&
      isa<LoadSDNode>(N0.getOperand(0)) &&
      N0.getOperand(1).getOpcode() == ISD::Constant &&
      (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
    LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
    EVT MemVT = LN00->getMemoryVT();
    if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) &&
        LN00->getExtensionType() != ISD::SEXTLOAD && LN00->isUnindexed()) {
      bool DoXform = true;
      SmallVector<SDNode*, 4> SetCCs;
      if (!N0.hasOneUse()) {
        if (N0.getOpcode() == ISD::AND) {
          auto *AndC = cast<ConstantSDNode>(N0.getOperand(1));
          EVT LoadResultTy = AndC->getValueType(0);
          EVT ExtVT;
          if (isAndLoadExtLoad(AndC, LN00, LoadResultTy, ExtVT))
            DoXform = false;
        }
      }
      if (DoXform)
        DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
                                          ISD::ZERO_EXTEND, SetCCs, TLI);
      if (DoXform) {
        SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(LN00), VT,
                                         LN00->getChain(), LN00->getBasePtr(),
                                         LN00->getMemoryVT(),
                                         LN00->getMemOperand());
        APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
        SDLoc DL(N);
        SDValue And = DAG.getNode(N0.getOpcode(), DL, VT,
                                  ExtLoad, DAG.getConstant(Mask, DL, VT));
        ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::ZERO_EXTEND);
        bool NoReplaceTruncAnd = !N0.hasOneUse();
        bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
        CombineTo(N, And);
        // If N0 has multiple uses, change other uses as well.
        if (NoReplaceTruncAnd) {
          SDValue TruncAnd =
              DAG.getNode(ISD::TRUNCATE, DL, N0.getValueType(), And);
          CombineTo(N0.getNode(), TruncAnd);
        }
        if (NoReplaceTrunc) {
          DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1));
        } else {
          SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00),
                                      LN00->getValueType(0), ExtLoad);
          CombineTo(LN00, Trunc, ExtLoad.getValue(1));
        }
        return SDValue(N,0); // Return N so it doesn't get rechecked!
      }
    }
  }

  // fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
  //      (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
  if (SDValue ZExtLoad = CombineZExtLogicopShiftLoad(N))
    return ZExtLoad;

  // Try to simplify (zext (zextload x)).
  if (SDValue foldedExt = tryToFoldExtOfExtload(
          DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::ZEXTLOAD))
    return foldedExt;

  if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
    return V;

  if (N0.getOpcode() == ISD::SETCC) {
    // Only do this before legalize for now.
    if (!LegalOperations && VT.isVector() &&
        N0.getValueType().getVectorElementType() == MVT::i1) {
      EVT N00VT = N0.getOperand(0).getValueType();
      if (getSetCCResultType(N00VT) == N0.getValueType())
        return SDValue();

      // We know that the # elements of the results is the same as the #
      // elements of the compare (and the # elements of the compare result for
      // that matter). Check to see that they are the same size. If so, we know
      // that the element size of the sext'd result matches the element size of
      // the compare operands.
      SDLoc DL(N);
      if (VT.getSizeInBits() == N00VT.getSizeInBits()) {
        // zext(setcc) -> zext_in_reg(vsetcc) for vectors.
        SDValue VSetCC = DAG.getNode(ISD::SETCC, DL, VT, N0.getOperand(0),
                                     N0.getOperand(1), N0.getOperand(2));
        return DAG.getZeroExtendInReg(VSetCC, DL, N0.getValueType());
      }

      // If the desired elements are smaller or larger than the source
      // elements we can use a matching integer vector type and then
      // truncate/any extend followed by zext_in_reg.
      EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
      SDValue VsetCC =
          DAG.getNode(ISD::SETCC, DL, MatchingVectorType, N0.getOperand(0),
                      N0.getOperand(1), N0.getOperand(2));
      return DAG.getZeroExtendInReg(DAG.getAnyExtOrTrunc(VsetCC, DL, VT), DL,
                                    N0.getValueType());
    }

    // zext(setcc x,y,cc) -> zext(select x, y, true, false, cc)
    SDLoc DL(N);
    EVT N0VT = N0.getValueType();
    EVT N00VT = N0.getOperand(0).getValueType();
    if (SDValue SCC = SimplifySelectCC(
            DL, N0.getOperand(0), N0.getOperand(1),
            DAG.getBoolConstant(true, DL, N0VT, N00VT),
            DAG.getBoolConstant(false, DL, N0VT, N00VT),
            cast<CondCodeSDNode>(N0.getOperand(2))->get(), true))
      return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, SCC);
  }

  // (zext (shl (zext x), cst)) -> (shl (zext x), cst)
  if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
      isa<ConstantSDNode>(N0.getOperand(1)) &&
      N0.getOperand(0).getOpcode() == ISD::ZERO_EXTEND &&
      N0.hasOneUse()) {
    SDValue ShAmt = N0.getOperand(1);
    if (N0.getOpcode() == ISD::SHL) {
      SDValue InnerZExt = N0.getOperand(0);
      // If the original shl may be shifting out bits, do not perform this
      // transformation.
      unsigned KnownZeroBits = InnerZExt.getValueSizeInBits() -
        InnerZExt.getOperand(0).getValueSizeInBits();
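      // The inner zext guarantees the top KnownZeroBits bits are zero, so a
      // shift amount no larger than that cannot move set bits out of the
      // narrow type, and the wide shift computes the same value.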
      if (cast<ConstantSDNode>(ShAmt)->getAPIntValue().ugt(KnownZeroBits))
        return SDValue();
    }

    SDLoc DL(N);

    // Ensure that the shift amount is wide enough for the shifted value.
    if (Log2_32_Ceil(VT.getSizeInBits()) > ShAmt.getValueSizeInBits())
      ShAmt = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, ShAmt);

    return DAG.getNode(N0.getOpcode(), DL, VT,
                       DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0)),
                       ShAmt);
  }

  if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
    return NewVSel;

  if (SDValue NewCtPop = widenCtPop(N, DAG))
    return NewCtPop;

  return SDValue();
}

SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
    return Res;

  // fold (aext (aext x)) -> (aext x)
  // fold (aext (zext x)) -> (zext x)
  // fold (aext (sext x)) -> (sext x)
  if (N0.getOpcode() == ISD::ANY_EXTEND  ||
      N0.getOpcode() == ISD::ZERO_EXTEND ||
      N0.getOpcode() == ISD::SIGN_EXTEND)
    return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, N0.getOperand(0));

  // fold (aext (truncate (load x))) -> (aext (smaller load x))
  // fold (aext (truncate (srl (load x), c))) -> (aext (small load (x+c/n)))
  if (N0.getOpcode() == ISD::TRUNCATE) {
    if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
      SDNode *oye = N0.getOperand(0).getNode();
      if (NarrowLoad.getNode() != N0.getNode()) {
        CombineTo(N0.getNode(), NarrowLoad);
        // CombineTo deleted the truncate, if needed, but not what's under it.
        AddToWorklist(oye);
      }
      return SDValue(N, 0);   // Return N so it doesn't get rechecked!
    }
  }

  // fold (aext (truncate x))
  if (N0.getOpcode() == ISD::TRUNCATE)
    return DAG.getAnyExtOrTrunc(N0.getOperand(0), SDLoc(N), VT);

  // Fold (aext (and (trunc x), cst)) -> (and x, cst)
  // if the trunc is not free.
  if (N0.getOpcode() == ISD::AND &&
      N0.getOperand(0).getOpcode() == ISD::TRUNCATE &&
      N0.getOperand(1).getOpcode() == ISD::Constant &&
      !TLI.isTruncateFree(N0.getOperand(0).getOperand(0).getValueType(),
                          N0.getValueType())) {
    SDLoc DL(N);
    SDValue X = N0.getOperand(0).getOperand(0);
    X = DAG.getAnyExtOrTrunc(X, DL, VT);
    APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
    return DAG.getNode(ISD::AND, DL, VT,
                       X, DAG.getConstant(Mask, DL, VT));
  }

  // fold (aext (load x)) -> (aext (truncate (extload x)))
  // None of the supported targets knows how to perform load and any_ext
  // on vectors in one instruction, so attempt to fold to zext instead.
  if (VT.isVector()) {
    // Try to simplify (zext (load x)).
    if (SDValue foldedExt =
            tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
                               ISD::ZEXTLOAD, ISD::ZERO_EXTEND))
      return foldedExt;
  } else if (ISD::isNON_EXTLoad(N0.getNode()) &&
             ISD::isUNINDEXEDLoad(N0.getNode()) &&
             TLI.isLoadExtLegal(ISD::EXTLOAD, VT, N0.getValueType())) {
    bool DoXform = true;
    SmallVector<SDNode *, 4> SetCCs;
    if (!N0.hasOneUse())
      DoXform =
          ExtendUsesToFormExtLoad(VT, N, N0, ISD::ANY_EXTEND, SetCCs, TLI);
    if (DoXform) {
      LoadSDNode *LN0 = cast<LoadSDNode>(N0);
      SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, SDLoc(N), VT,
                                       LN0->getChain(), LN0->getBasePtr(),
                                       N0.getValueType(), LN0->getMemOperand());
      ExtendSetCCUses(SetCCs, N0, ExtLoad, ISD::ANY_EXTEND);
      // If the load value is used only by N, replace it via CombineTo N.
      bool NoReplaceTrunc = N0.hasOneUse();
      CombineTo(N, ExtLoad);
      if (NoReplaceTrunc) {
        DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
        recursivelyDeleteUnusedNodes(LN0);
      } else {
        SDValue Trunc =
            DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), ExtLoad);
        CombineTo(LN0, Trunc, ExtLoad.getValue(1));
      }
      return SDValue(N, 0); // Return N so it doesn't get rechecked!
    }
  }

  // fold (aext (zextload x)) -> (aext (truncate (zextload x)))
  // fold (aext (sextload x)) -> (aext (truncate (sextload x)))
  // fold (aext ( extload x)) -> (aext (truncate (extload  x)))
  if (N0.getOpcode() == ISD::LOAD && !ISD::isNON_EXTLoad(N0.getNode()) &&
      ISD::isUNINDEXEDLoad(N0.getNode()) && N0.hasOneUse()) {
    LoadSDNode *LN0 = cast<LoadSDNode>(N0);
    ISD::LoadExtType ExtType = LN0->getExtensionType();
    EVT MemVT = LN0->getMemoryVT();
    if (!LegalOperations || TLI.isLoadExtLegal(ExtType, VT, MemVT)) {
      SDValue ExtLoad = DAG.getExtLoad(ExtType, SDLoc(N),
                                       VT, LN0->getChain(), LN0->getBasePtr(),
                                       MemVT, LN0->getMemOperand());
      CombineTo(N, ExtLoad);
      DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
      recursivelyDeleteUnusedNodes(LN0);
      return SDValue(N, 0);   // Return N so it doesn't get rechecked!
    }
  }

  if (N0.getOpcode() == ISD::SETCC) {
    // For vectors:
    // aext(setcc) -> vsetcc
    // aext(setcc) -> truncate(vsetcc)
    // aext(setcc) -> aext(vsetcc)
    // Only do this before legalize for now.
    if (VT.isVector() && !LegalOperations) {
      EVT N00VT = N0.getOperand(0).getValueType();
      if (getSetCCResultType(N00VT) == N0.getValueType())
        return SDValue();

      // We know that the # elements of the results is the same as the
      // # elements of the compare (and the # elements of the compare result
      // for that matter).  Check to see that they are the same size.  If so,
      // we know that the element size of the sext'd result matches the
      // element size of the compare operands.
      if (VT.getSizeInBits() == N00VT.getSizeInBits())
        return DAG.getSetCC(SDLoc(N), VT, N0.getOperand(0),
                             N0.getOperand(1),
                             cast<CondCodeSDNode>(N0.getOperand(2))->get());

      // If the desired elements are smaller or larger than the source
      // elements we can use a matching integer vector type and then
      // truncate/any extend
      EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
      SDValue VsetCC =
        DAG.getSetCC(SDLoc(N), MatchingVectorType, N0.getOperand(0),
                      N0.getOperand(1),
                      cast<CondCodeSDNode>(N0.getOperand(2))->get());
      return DAG.getAnyExtOrTrunc(VsetCC, SDLoc(N), VT);
    }

    // aext(setcc x,y,cc) -> select_cc x, y, 1, 0, cc
    SDLoc DL(N);
    if (SDValue SCC = SimplifySelectCC(
            DL, N0.getOperand(0), N0.getOperand(1), DAG.getConstant(1, DL, VT),
            DAG.getConstant(0, DL, VT),
            cast<CondCodeSDNode>(N0.getOperand(2))->get(), true))
      return SCC;
  }

  if (SDValue NewCtPop = widenCtPop(N, DAG))
    return NewCtPop;

  return SDValue();
}

SDValue DAGCombiner::visitAssertExt(SDNode *N) {
  unsigned Opcode = N->getOpcode();
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT AssertVT = cast<VTSDNode>(N1)->getVT();

  // fold (assert?ext (assert?ext x, vt), vt) -> (assert?ext x, vt)
  if (N0.getOpcode() == Opcode &&
      AssertVT == cast<VTSDNode>(N0.getOperand(1))->getVT())
    return N0;

  if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
      N0.getOperand(0).getOpcode() == Opcode) {
    // We have an assert, truncate, assert sandwich. Make one stronger assert
    // by asserting the smaller asserted type on the larger source type.
    // This eliminates the later assert:
    // assert (trunc (assert X, i8) to iN), i1 --> trunc (assert X, i1) to iN
    // assert (trunc (assert X, i1) to iN), i8 --> trunc (assert X, i1) to iN
    SDValue BigA = N0.getOperand(0);
    EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT();
    assert(BigA_AssertVT.bitsLE(N0.getValueType()) &&
           "Asserting zero/sign-extended bits to a type larger than the "
           "truncated destination does not provide information");

    SDLoc DL(N);
    EVT MinAssertVT = AssertVT.bitsLT(BigA_AssertVT) ? AssertVT : BigA_AssertVT;
    SDValue MinAssertVTVal = DAG.getValueType(MinAssertVT);
    SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(),
                                    BigA.getOperand(0), MinAssertVTVal);
    return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert);
  }

  // If we have (AssertZext (truncate (AssertSext X, iX)), iY) and Y is smaller
  // than X, just move the AssertZext in front of the truncate and drop the
  // AssertSExt.
  if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
      N0.getOperand(0).getOpcode() == ISD::AssertSext &&
      Opcode == ISD::AssertZext) {
    SDValue BigA = N0.getOperand(0);
    EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT();
    assert(BigA_AssertVT.bitsLE(N0.getValueType()) &&
           "Asserting zero/sign-extended bits to a type larger than the "
           "truncated destination does not provide information");

    if (AssertVT.bitsLT(BigA_AssertVT)) {
      SDLoc DL(N);
      SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(),
                                      BigA.getOperand(0), N1);
      return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert);
    }
  }

  return SDValue();
}

SDValue DAGCombiner::visitAssertAlign(SDNode *N) {
  SDLoc DL(N);

  Align AL = cast<AssertAlignSDNode>(N)->getAlign();
  SDValue N0 = N->getOperand(0);

  // Fold (assertalign (assertalign x, AL0), AL1) ->
  // (assertalign x, max(AL0, AL1))
  if (auto *AAN = dyn_cast<AssertAlignSDNode>(N0))
    return DAG.getAssertAlign(DL, N0.getOperand(0),
                              std::max(AL, AAN->getAlign()));

  // In rare cases, there are trivial arithmetic ops in source operands. Sink
  // this assert down to source operands so that those arithmetic ops could be
  // exposed to the DAG combining.
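  // For example (illustrative): (assertalign (add x, 32), 16), where the
  // constant operand is known 16-byte aligned, becomes
  // (add (assertalign x, 16), 32), exposing the add to further combines.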
  switch (N0.getOpcode()) {
  default:
    break;
  case ISD::ADD:
  case ISD::SUB: {
    unsigned AlignShift = Log2(AL);
    SDValue LHS = N0.getOperand(0);
    SDValue RHS = N0.getOperand(1);
    unsigned LHSAlignShift = DAG.computeKnownBits(LHS).countMinTrailingZeros();
    unsigned RHSAlignShift = DAG.computeKnownBits(RHS).countMinTrailingZeros();
    if (LHSAlignShift >= AlignShift || RHSAlignShift >= AlignShift) {
      if (LHSAlignShift < AlignShift)
        LHS = DAG.getAssertAlign(DL, LHS, AL);
      if (RHSAlignShift < AlignShift)
        RHS = DAG.getAssertAlign(DL, RHS, AL);
      return DAG.getNode(N0.getOpcode(), DL, N0.getValueType(), LHS, RHS);
    }
    break;
  }
  }

  return SDValue();
}

/// If the result of a wider load is shifted right by N bits and then
/// truncated to a narrower type, where N is a multiple of the number of bits
/// in the narrower type, transform it to a narrower load from address + N /
/// (number of bits in the new type). Also narrow the load if the result is
/// masked with an AND to effectively produce a smaller type. If the result is
/// to be extended, also fold the extension to form an extending load.
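/// An illustrative little-endian case:
///   (i16 (trunc (srl (i32 (load x)), 16))) -> (i16 (load x+2))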
SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
  unsigned Opc = N->getOpcode();

  ISD::LoadExtType ExtType = ISD::NON_EXTLOAD;
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  EVT ExtVT = VT;

  // This transformation isn't valid for vector loads.
  if (VT.isVector())
    return SDValue();

  unsigned ShAmt = 0;
  bool HasShiftedOffset = false;
  // Special case: SIGN_EXTEND_INREG is basically truncating to ExtVT then
  // extended to VT.
  if (Opc == ISD::SIGN_EXTEND_INREG) {
    ExtType = ISD::SEXTLOAD;
    ExtVT = cast<VTSDNode>(N->getOperand(1))->getVT();
  } else if (Opc == ISD::SRL) {
    // Another special-case: SRL is basically zero-extending a narrower value,
    // or it may be shifting a higher subword, half, or byte into the lowest
    // bits.
    ExtType = ISD::ZEXTLOAD;
    N0 = SDValue(N, 0);

    auto *LN0 = dyn_cast<LoadSDNode>(N0.getOperand(0));
    auto *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1));
    if (!N01 || !LN0)
      return SDValue();

    uint64_t ShiftAmt = N01->getZExtValue();
    uint64_t MemoryWidth = LN0->getMemoryVT().getScalarSizeInBits();
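    // If the shift stays within the loaded width of a zext/any-ext load, only
    // MemoryWidth - ShiftAmt meaningful bits remain; otherwise fall back to a
    // conservative width based on VT.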
    if (LN0->getExtensionType() != ISD::SEXTLOAD && MemoryWidth > ShiftAmt)
      ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShiftAmt);
    else
      ExtVT = EVT::getIntegerVT(*DAG.getContext(),
                                VT.getScalarSizeInBits() - ShiftAmt);
  } else if (Opc == ISD::AND) {
    // An AND with a constant mask is the same as a truncate + zero-extend.
    auto AndC = dyn_cast<ConstantSDNode>(N->getOperand(1));
    if (!AndC)
      return SDValue();

    const APInt &Mask = AndC->getAPIntValue();
    unsigned ActiveBits = 0;
    if (Mask.isMask()) {
      ActiveBits = Mask.countTrailingOnes();
    } else if (Mask.isShiftedMask()) {
      ShAmt = Mask.countTrailingZeros();
      APInt ShiftedMask = Mask.lshr(ShAmt);
      ActiveBits = ShiftedMask.countTrailingOnes();
      HasShiftedOffset = true;
    } else
      return SDValue();

    ExtType = ISD::ZEXTLOAD;
    ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
  }

  if (N0.getOpcode() == ISD::SRL && N0.hasOneUse()) {
    SDValue SRL = N0;
    if (auto *ConstShift = dyn_cast<ConstantSDNode>(SRL.getOperand(1))) {
      ShAmt = ConstShift->getZExtValue();
      unsigned EVTBits = ExtVT.getScalarSizeInBits();
      // Is the shift amount a multiple of size of VT?
      if ((ShAmt & (EVTBits-1)) == 0) {
        N0 = N0.getOperand(0);
        // Is the load width a multiple of size of VT?
        if ((N0.getScalarValueSizeInBits() & (EVTBits - 1)) != 0)
          return SDValue();
      }

      // At this point, we must have a load or else we can't do the transform.
      auto *LN0 = dyn_cast<LoadSDNode>(N0);
      if (!LN0) return SDValue();

      // Because a SRL must be assumed to *need* to zero-extend the high bits
      // (as opposed to anyext the high bits), we can't combine the zextload
      // lowering of SRL and an sextload.
      if (LN0->getExtensionType() == ISD::SEXTLOAD)
        return SDValue();

      // If the shift amount is larger than the input type then we're not
      // accessing any of the loaded bytes.  If the load was a zextload/extload
      // then the result of the shift+trunc is zero/undef (handled elsewhere).
      if (ShAmt >= LN0->getMemoryVT().getSizeInBits())
        return SDValue();

      // If the SRL is only used by a masking AND, we may be able to adjust
      // the ExtVT to make the AND redundant.
      SDNode *Mask = *(SRL->use_begin());
      if (Mask->getOpcode() == ISD::AND &&
          isa<ConstantSDNode>(Mask->getOperand(1))) {
        const APInt& ShiftMask = Mask->getConstantOperandAPInt(1);
        if (ShiftMask.isMask()) {
          EVT MaskedVT = EVT::getIntegerVT(*DAG.getContext(),
                                           ShiftMask.countTrailingOnes());
          // If the mask is smaller, recompute the type.
          if ((ExtVT.getScalarSizeInBits() > MaskedVT.getScalarSizeInBits()) &&
              TLI.isLoadExtLegal(ExtType, N0.getValueType(), MaskedVT))
            ExtVT = MaskedVT;
        }
      }
    }
  }

  // If the load is shifted left (and the result isn't shifted back right),
  // we can fold the truncate through the shift.
  unsigned ShLeftAmt = 0;
  if (ShAmt == 0 && N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
      ExtVT == VT && TLI.isNarrowingProfitable(N0.getValueType(), VT)) {
    if (ConstantSDNode *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
      ShLeftAmt = N01->getZExtValue();
      N0 = N0.getOperand(0);
    }
  }

  // If we haven't found a load, we can't narrow it.
  if (!isa<LoadSDNode>(N0))
    return SDValue();

  LoadSDNode *LN0 = cast<LoadSDNode>(N0);
  // Reducing the width of a volatile load is illegal.  For atomics, we may be
  // able to reduce the width provided we never widen again. (see D66309)
  if (!LN0->isSimple() ||
      !isLegalNarrowLdSt(LN0, ExtType, ExtVT, ShAmt))
    return SDValue();

  auto AdjustBigEndianShift = [&](unsigned ShAmt) {
    unsigned LVTStoreBits =
        LN0->getMemoryVT().getStoreSizeInBits().getFixedSize();
    unsigned EVTStoreBits = ExtVT.getStoreSizeInBits().getFixedSize();
    return LVTStoreBits - EVTStoreBits - ShAmt;
  };

  // For big endian targets, we need to adjust the offset to the pointer to
  // load the correct bytes.
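  // For example, the low 16 bits of an i32 live at byte offset 2 on a
  // big-endian target, so the narrow load's pointer must be bumped
  // accordingly.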
  if (DAG.getDataLayout().isBigEndian())
    ShAmt = AdjustBigEndianShift(ShAmt);

  uint64_t PtrOff = ShAmt / 8;
  Align NewAlign = commonAlignment(LN0->getAlign(), PtrOff);
  SDLoc DL(LN0);
  // The original load itself didn't wrap, so an offset within it doesn't.
  SDNodeFlags Flags;
  Flags.setNoUnsignedWrap(true);
  SDValue NewPtr = DAG.getMemBasePlusOffset(LN0->getBasePtr(),
                                            TypeSize::Fixed(PtrOff), DL, Flags);
  AddToWorklist(NewPtr.getNode());

  SDValue Load;
  if (ExtType == ISD::NON_EXTLOAD)
    Load = DAG.getLoad(VT, DL, LN0->getChain(), NewPtr,
                       LN0->getPointerInfo().getWithOffset(PtrOff), NewAlign,
                       LN0->getMemOperand()->getFlags(), LN0->getAAInfo());
  else
    Load = DAG.getExtLoad(ExtType, DL, VT, LN0->getChain(), NewPtr,
                          LN0->getPointerInfo().getWithOffset(PtrOff), ExtVT,
                          NewAlign, LN0->getMemOperand()->getFlags(),
                          LN0->getAAInfo());

  // Replace the old load's chain with the new load's chain.
  WorklistRemover DeadNodes(*this);
  DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));

  // Shift the result left, if we've swallowed a left shift.
  SDValue Result = Load;
  if (ShLeftAmt != 0) {
    EVT ShImmTy = getShiftAmountTy(Result.getValueType());
    if (!isUIntN(ShImmTy.getScalarSizeInBits(), ShLeftAmt))
      ShImmTy = VT;
    // If the shift amount is as large as the result size (but, presumably,
    // no larger than the source) then the useful bits of the result are
    // zero; we can't simply return the shortened shift, because the result
    // of that operation is undefined.
    if (ShLeftAmt >= VT.getScalarSizeInBits())
      Result = DAG.getConstant(0, DL, VT);
    else
      Result = DAG.getNode(ISD::SHL, DL, VT,
                          Result, DAG.getConstant(ShLeftAmt, DL, ShImmTy));
  }

11501   if (HasShiftedOffset) {
11502     // Recalculate the shift amount after it has been altered to calculate
11503     // the offset.
11504     if (DAG.getDataLayout().isBigEndian())
11505       ShAmt = AdjustBigEndianShift(ShAmt);
11506 
11507     // We're using a shifted mask, so the load now has an offset. This means
11508     // that data has been loaded into the lower bytes than it would have been
11509     // before, so we need to shl the loaded data into the correct position in the
11510     // register.
11511     SDValue ShiftC = DAG.getConstant(ShAmt, DL, VT);
11512     Result = DAG.getNode(ISD::SHL, DL, VT, Result, ShiftC);
11513     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result);
11514   }
11515 
11516   // Return the new loaded value.
11517   return Result;
11518 }

SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  EVT ExtVT = cast<VTSDNode>(N1)->getVT();
  unsigned VTBits = VT.getScalarSizeInBits();
  unsigned ExtVTBits = ExtVT.getScalarSizeInBits();

  // sext_in_reg(undef) = 0 because the top bits will all be the same.
  if (N0.isUndef())
    return DAG.getConstant(0, SDLoc(N), VT);

  // fold (sext_in_reg c1) -> c1
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
    return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, N0, N1);

  // If the input is already sign extended, just drop the extension.
  if (DAG.ComputeNumSignBits(N0) >= (VTBits - ExtVTBits + 1))
    return N0;

  // fold (sext_in_reg (sext_in_reg x, VT2), VT1) -> (sext_in_reg x, minVT)
  if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
      ExtVT.bitsLT(cast<VTSDNode>(N0.getOperand(1))->getVT()))
    return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, N0.getOperand(0),
                       N1);

  // fold (sext_in_reg (sext x)) -> (sext x)
  // fold (sext_in_reg (aext x)) -> (sext x)
  // if x is small enough or if we know that x has more than 1 sign bit and the
  // sign_extend_inreg is extending from one of them.
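  // For example, (sext_in_reg (aext i8 x to i32), i8) -> (sext i8 x to i32),
  // since x is no wider than the extension width.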
  if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) {
    SDValue N00 = N0.getOperand(0);
    unsigned N00Bits = N00.getScalarValueSizeInBits();
    if ((N00Bits <= ExtVTBits ||
         (N00Bits - DAG.ComputeNumSignBits(N00)) < ExtVTBits) &&
        (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT)))
      return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00);
  }

  // fold (sext_in_reg (*_extend_vector_inreg x)) -> (sext_vector_inreg x)
  if ((N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
       N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG ||
       N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG) &&
      N0.getOperand(0).getScalarValueSizeInBits() == ExtVTBits) {
    if (!LegalOperations ||
        TLI.isOperationLegal(ISD::SIGN_EXTEND_VECTOR_INREG, VT))
      return DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, SDLoc(N), VT,
                         N0.getOperand(0));
  }

  // fold (sext_in_reg (zext x)) -> (sext x)
  // iff we are extending the source sign bit.
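  // For example, (sext_in_reg (zext i8 x to i32), i8) -> (sext i8 x to i32):
  // the bit being sign-extended is exactly x's sign bit.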
  if (N0.getOpcode() == ISD::ZERO_EXTEND) {
    SDValue N00 = N0.getOperand(0);
    if (N00.getScalarValueSizeInBits() == ExtVTBits &&
        (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT)))
      return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00);
  }

  // fold (sext_in_reg x) -> (zext_in_reg x) if the sign bit is known zero.
  if (DAG.MaskedValueIsZero(N0, APInt::getOneBitSet(VTBits, ExtVTBits - 1)))
    return DAG.getZeroExtendInReg(N0, SDLoc(N), ExtVT);

  // fold operands of sext_in_reg based on knowledge that the top bits are not
  // demanded.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  // fold (sext_in_reg (load x)) -> (smaller sextload x)
  // fold (sext_in_reg (srl (load x), c)) -> (smaller sextload (x+c/evtbits))
  if (SDValue NarrowLoad = ReduceLoadWidth(N))
    return NarrowLoad;

  // fold (sext_in_reg (srl X, 24), i8) -> (sra X, 24)
  // fold (sext_in_reg (srl X, 23), i8) -> (sra X, 23) iff possible.
  // We already fold "(sext_in_reg (srl X, 25), i8) -> srl X, 25" above.
  if (N0.getOpcode() == ISD::SRL) {
    if (auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1)))
      if (ShAmt->getAPIntValue().ule(VTBits - ExtVTBits)) {
        // We can turn this into an SRA iff the input to the SRL is already sign
        // extended enough.
        unsigned InSignBits = DAG.ComputeNumSignBits(N0.getOperand(0));
        if (((VTBits - ExtVTBits) - ShAmt->getZExtValue()) < InSignBits)
          return DAG.getNode(ISD::SRA, SDLoc(N), VT, N0.getOperand(0),
                             N0.getOperand(1));
      }
  }

  // fold (sext_inreg (extload x)) -> (sextload x)
  // If sextload is not supported by target, we can only do the combine when
  // load has one use. Doing otherwise can block folding the extload with other
  // extends that the target does support.
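  // For example, if the same extload also feeds a zero-extending user,
  // rewriting the shared load to a sextload here could keep the target from
  // folding that other extend into its load.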
  if (ISD::isEXTLoad(N0.getNode()) &&
      ISD::isUNINDEXEDLoad(N0.getNode()) &&
      ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
      ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple() &&
        N0.hasOneUse()) ||
       TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT))) {
    LoadSDNode *LN0 = cast<LoadSDNode>(N0);
    SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(N), VT,
                                     LN0->getChain(),
                                     LN0->getBasePtr(), ExtVT,
                                     LN0->getMemOperand());
    CombineTo(N, ExtLoad);
    CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
    AddToWorklist(ExtLoad.getNode());
    return SDValue(N, 0);   // Return N so it doesn't get rechecked!
  }
  // fold (sext_inreg (zextload x)) -> (sextload x) iff load has one use
  if (ISD::isZEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) &&
      N0.hasOneUse() &&
      ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
      ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple()) &&
       TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT))) {
    LoadSDNode *LN0 = cast<LoadSDNode>(N0);
    SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(N), VT,
                                     LN0->getChain(),
                                     LN0->getBasePtr(), ExtVT,
                                     LN0->getMemOperand());
    CombineTo(N, ExtLoad);
    CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
    return SDValue(N, 0);   // Return N so it doesn't get rechecked!
  }

  // fold (sext_inreg (masked_load x)) -> (sext_masked_load x)
  // ignore it if the masked load is already sign extended
  if (MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(N0)) {
    if (ExtVT == Ld->getMemoryVT() && N0.hasOneUse() &&
        Ld->getExtensionType() != ISD::LoadExtType::NON_EXTLOAD &&
        TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT)) {
      SDValue ExtMaskedLoad = DAG.getMaskedLoad(
          VT, SDLoc(N), Ld->getChain(), Ld->getBasePtr(), Ld->getOffset(),
          Ld->getMask(), Ld->getPassThru(), ExtVT, Ld->getMemOperand(),
          Ld->getAddressingMode(), ISD::SEXTLOAD, Ld->isExpandingLoad());
      CombineTo(N, ExtMaskedLoad);
      CombineTo(N0.getNode(), ExtMaskedLoad, ExtMaskedLoad.getValue(1));
      return SDValue(N, 0); // Return N so it doesn't get rechecked!
    }
  }

  // fold (sext_inreg (masked_gather x)) -> (sext_masked_gather x)
  if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(N0)) {
    if (SDValue(GN0, 0).hasOneUse() &&
        ExtVT == GN0->getMemoryVT() &&
        TLI.isVectorLoadExtDesirable(SDValue(GN0, 0))) {
      SDValue Ops[] = {GN0->getChain(),   GN0->getPassThru(), GN0->getMask(),
                       GN0->getBasePtr(), GN0->getIndex(),    GN0->getScale()};

      SDValue ExtLoad = DAG.getMaskedGather(
          DAG.getVTList(VT, MVT::Other), ExtVT, SDLoc(N), Ops,
          GN0->getMemOperand(), GN0->getIndexType(), ISD::SEXTLOAD);

      CombineTo(N, ExtLoad);
      CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
      AddToWorklist(ExtLoad.getNode());
      return SDValue(N, 0); // Return N so it doesn't get rechecked!
    }
  }

  // Form (sext_inreg (bswap >> 16)) or (sext_inreg (rotl (bswap) 16))
  if (ExtVTBits <= 16 && N0.getOpcode() == ISD::OR) {
    if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
                                           N0.getOperand(1), false))
      return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, BSwap, N1);
  }

  return SDValue();
}

SDValue DAGCombiner::visitSIGN_EXTEND_VECTOR_INREG(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  // sext_vector_inreg(undef) = 0 because the top bits will all be the same.
  if (N0.isUndef())
    return DAG.getConstant(0, SDLoc(N), VT);

  if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
    return Res;

  if (SimplifyDemandedVectorElts(SDValue(N, 0)))
    return SDValue(N, 0);

  return SDValue();
}

SDValue DAGCombiner::visitZERO_EXTEND_VECTOR_INREG(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  // zext_vector_inreg(undef) = 0 because the top bits will be zero.
  if (N0.isUndef())
    return DAG.getConstant(0, SDLoc(N), VT);

  if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
    return Res;

  if (SimplifyDemandedVectorElts(SDValue(N, 0)))
    return SDValue(N, 0);

  return SDValue();
}

SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  EVT SrcVT = N0.getValueType();
  bool isLE = DAG.getDataLayout().isLittleEndian();

  // noop truncate
  if (SrcVT == VT)
    return N0;

  // fold (truncate (truncate x)) -> (truncate x)
  if (N0.getOpcode() == ISD::TRUNCATE)
    return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0.getOperand(0));

  // fold (truncate c1) -> c1
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) {
    SDValue C = DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0);
    if (C.getNode() != N)
      return C;
  }

  // fold (truncate (ext x)) -> (ext x) or (truncate x) or x
  if (N0.getOpcode() == ISD::ZERO_EXTEND ||
      N0.getOpcode() == ISD::SIGN_EXTEND ||
      N0.getOpcode() == ISD::ANY_EXTEND) {
    // if the source is smaller than the dest, we still need an extend.
    if (N0.getOperand(0).getValueType().bitsLT(VT))
      return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, N0.getOperand(0));
    // if the source is larger than the dest, then we just need the truncate.
    if (N0.getOperand(0).getValueType().bitsGT(VT))
      return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0.getOperand(0));
    // if the source and dest are the same type, we can drop both the extend
    // and the truncate.
    return N0.getOperand(0);
  }

  // If this is anyext(trunc), don't fold it; allow ourselves to be folded.
  if (N->hasOneUse() && (N->use_begin()->getOpcode() == ISD::ANY_EXTEND))
    return SDValue();

  // Fold extract-and-trunc into a narrow extract. For example:
  //   i64 x = EXTRACT_VECTOR_ELT(v2i64 val, i32 1)
  //   i32 y = TRUNCATE(i64 x)
  //        -- becomes --
  //   v16i8 b = BITCAST (v2i64 val)
  //   i8 x = EXTRACT_VECTOR_ELT(v16i8 b, i32 8)
  //
  // Note: We only run this optimization after type legalization (which often
  // creates this pattern) and before operation legalization after which
  // we need to be more careful about the vector instructions that we generate.
  if (N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
      LegalTypes && !LegalOperations && N0->hasOneUse() && VT != MVT::i1) {
    EVT VecTy = N0.getOperand(0).getValueType();
    EVT ExTy = N0.getValueType();
    EVT TrTy = N->getValueType(0);

    auto EltCnt = VecTy.getVectorElementCount();
    unsigned SizeRatio = ExTy.getSizeInBits()/TrTy.getSizeInBits();
    auto NewEltCnt = EltCnt * SizeRatio;

    EVT NVT = EVT::getVectorVT(*DAG.getContext(), TrTy, NewEltCnt);
    assert(NVT.getSizeInBits() == VecTy.getSizeInBits() && "Invalid Size");

    SDValue EltNo = N0->getOperand(1);
    if (isa<ConstantSDNode>(EltNo) && isTypeLegal(NVT)) {
      int Elt = cast<ConstantSDNode>(EltNo)->getZExtValue();
      int Index = isLE ? (Elt*SizeRatio) : (Elt*SizeRatio + (SizeRatio-1));

      SDLoc DL(N);
      return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, TrTy,
                         DAG.getBitcast(NVT, N0.getOperand(0)),
                         DAG.getVectorIdxConstant(Index, DL));
    }
  }

  // trunc (select c, a, b) -> select c, (trunc a), (trunc b)
  if (N0.getOpcode() == ISD::SELECT && N0.hasOneUse()) {
    if ((!LegalOperations || TLI.isOperationLegal(ISD::SELECT, SrcVT)) &&
        TLI.isTruncateFree(SrcVT, VT)) {
      SDLoc SL(N0);
      SDValue Cond = N0.getOperand(0);
      SDValue TruncOp0 = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(1));
      SDValue TruncOp1 = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(2));
      return DAG.getNode(ISD::SELECT, SDLoc(N), VT, Cond, TruncOp0, TruncOp1);
    }
  }

  // trunc (shl x, K) -> shl (trunc x), K => K < VT.getScalarSizeInBits()
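  // For example, (i32 trunc (shl i64 x, 5)) -> (shl (i32 trunc x), 5), which
  // is safe because the shift amount 5 is known to be less than 32.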
  if (N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
      (!LegalOperations || TLI.isOperationLegal(ISD::SHL, VT)) &&
      TLI.isTypeDesirableForOp(ISD::SHL, VT)) {
    SDValue Amt = N0.getOperand(1);
    KnownBits Known = DAG.computeKnownBits(Amt);
    unsigned Size = VT.getScalarSizeInBits();
    if (Known.getBitWidth() - Known.countMinLeadingZeros() <= Log2_32(Size)) {
      SDLoc SL(N);
      EVT AmtVT = TLI.getShiftAmountTy(VT, DAG.getDataLayout());

      SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(0));
      if (AmtVT != Amt.getValueType()) {
        Amt = DAG.getZExtOrTrunc(Amt, SL, AmtVT);
        AddToWorklist(Amt.getNode());
      }
      return DAG.getNode(ISD::SHL, SL, VT, Trunc, Amt);
    }
  }

  // Attempt to pre-truncate BUILD_VECTOR sources.
  if (N0.getOpcode() == ISD::BUILD_VECTOR && !LegalOperations &&
      TLI.isTruncateFree(SrcVT.getScalarType(), VT.getScalarType()) &&
      // Avoid creating illegal types if running after type legalizer.
      (!LegalTypes || TLI.isTypeLegal(VT.getScalarType()))) {
    SDLoc DL(N);
    EVT SVT = VT.getScalarType();
    SmallVector<SDValue, 8> TruncOps;
    for (const SDValue &Op : N0->op_values()) {
      SDValue TruncOp = DAG.getNode(ISD::TRUNCATE, DL, SVT, Op);
      TruncOps.push_back(TruncOp);
    }
    return DAG.getBuildVector(VT, DL, TruncOps);
  }

  // Fold a series of buildvector, bitcast, and truncate if possible.
  // For example fold
  //   (2xi32 trunc (bitcast ((4xi32)buildvector x, x, y, y) 2xi64)) to
  //   (2xi32 (buildvector x, y)).
  if (Level == AfterLegalizeVectorOps && VT.isVector() &&
      N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
      N0.getOperand(0).getOpcode() == ISD::BUILD_VECTOR &&
      N0.getOperand(0).hasOneUse()) {
    SDValue BuildVect = N0.getOperand(0);
    EVT BuildVectEltTy = BuildVect.getValueType().getVectorElementType();
    EVT TruncVecEltTy = VT.getVectorElementType();

    // Check that the element types match.
    if (BuildVectEltTy == TruncVecEltTy) {
      // Now we only need to compute the offset of the truncated elements.
      unsigned BuildVecNumElts = BuildVect.getNumOperands();
      unsigned TruncVecNumElts = VT.getVectorNumElements();
      unsigned TruncEltOffset = BuildVecNumElts / TruncVecNumElts;

      assert((BuildVecNumElts % TruncVecNumElts) == 0 &&
             "Invalid number of elements");

      SmallVector<SDValue, 8> Opnds;
      for (unsigned i = 0, e = BuildVecNumElts; i != e; i += TruncEltOffset)
        Opnds.push_back(BuildVect.getOperand(i));

      return DAG.getBuildVector(VT, SDLoc(N), Opnds);
    }
  }

  // See if we can simplify the input to this truncate through knowledge that
  // only the low bits are being used.
  // For example "trunc (or (shl x, 8), y)" -> trunc y
  // Currently we only perform this optimization on scalars because vectors
  // may have different active low bits.
  if (!VT.isVector()) {
    APInt Mask =
        APInt::getLowBitsSet(N0.getValueSizeInBits(), VT.getSizeInBits());
    if (SDValue Shorter = DAG.GetDemandedBits(N0, Mask))
      return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Shorter);
  }

  // fold (truncate (load x)) -> (smaller load x)
  // fold (truncate (srl (load x), c)) -> (smaller load (x+c/evtbits))
  if (!LegalTypes || TLI.isTypeDesirableForOp(N0.getOpcode(), VT)) {
    if (SDValue Reduced = ReduceLoadWidth(N))
      return Reduced;

    // Handle the case where the load remains an extending load even
    // after truncation.
    if (N0.hasOneUse() && ISD::isUNINDEXEDLoad(N0.getNode())) {
      LoadSDNode *LN0 = cast<LoadSDNode>(N0);
      if (LN0->isSimple() && LN0->getMemoryVT().bitsLT(VT)) {
        SDValue NewLoad = DAG.getExtLoad(LN0->getExtensionType(), SDLoc(LN0),
                                         VT, LN0->getChain(), LN0->getBasePtr(),
                                         LN0->getMemoryVT(),
                                         LN0->getMemOperand());
        DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), NewLoad.getValue(1));
        return NewLoad;
      }
    }
  }

  // fold (trunc (concat ... x ...)) -> (concat ..., (trunc x), ...),
  // where ... are all 'undef'.
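  // For example, (v4i16 trunc (v4i32 concat (v2i32 x), undef)) ->
  // (v4i16 concat (v2i16 (trunc x)), undef).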
  if (N0.getOpcode() == ISD::CONCAT_VECTORS && !LegalTypes) {
    SmallVector<EVT, 8> VTs;
    SDValue V;
    unsigned Idx = 0;
    unsigned NumDefs = 0;

    for (unsigned i = 0, e = N0.getNumOperands(); i != e; ++i) {
      SDValue X = N0.getOperand(i);
      if (!X.isUndef()) {
        V = X;
        Idx = i;
        NumDefs++;
      }
      // Stop if more than one member is non-undef.
      if (NumDefs > 1)
        break;

      VTs.push_back(EVT::getVectorVT(*DAG.getContext(),
                                     VT.getVectorElementType(),
                                     X.getValueType().getVectorElementCount()));
    }

    if (NumDefs == 0)
      return DAG.getUNDEF(VT);

    if (NumDefs == 1) {
      assert(V.getNode() && "The single defined operand is empty!");
      SmallVector<SDValue, 8> Opnds;
      for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
        if (i != Idx) {
          Opnds.push_back(DAG.getUNDEF(VTs[i]));
          continue;
        }
        SDValue NV = DAG.getNode(ISD::TRUNCATE, SDLoc(V), VTs[i], V);
        AddToWorklist(NV.getNode());
        Opnds.push_back(NV);
      }
      return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Opnds);
    }
  }

  // Fold truncate of a bitcast of a vector to an extract of the low vector
  // element.
  //
  // e.g. trunc (i64 (bitcast v2i32:x)) -> extract_vector_elt v2i32:x, idx
  if (N0.getOpcode() == ISD::BITCAST && !VT.isVector()) {
    SDValue VecSrc = N0.getOperand(0);
    EVT VecSrcVT = VecSrc.getValueType();
    if (VecSrcVT.isVector() && VecSrcVT.getScalarType() == VT &&
        (!LegalOperations ||
         TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VecSrcVT))) {
      SDLoc SL(N);

      unsigned Idx = isLE ? 0 : VecSrcVT.getVectorNumElements() - 1;
      return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, VT, VecSrc,
                         DAG.getVectorIdxConstant(Idx, SL));
    }
  }

  // Simplify the operands using demanded-bits information.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  // (trunc adde(X, Y, Carry)) -> (adde trunc(X), trunc(Y), Carry)
  // (trunc addcarry(X, Y, Carry)) -> (addcarry trunc(X), trunc(Y), Carry)
  // When the adde's carry is not used.
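  // For example, (i32 trunc (i64 addcarry X, Y, C)) ->
  // (i32 addcarry (trunc X), (trunc Y), C) when the carry result is unused.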
  if ((N0.getOpcode() == ISD::ADDE || N0.getOpcode() == ISD::ADDCARRY) &&
      N0.hasOneUse() && !N0.getNode()->hasAnyUseOfValue(1) &&
      // We only do this for ADDCARRY before operation legalization.
      ((!LegalOperations && N0.getOpcode() == ISD::ADDCARRY) ||
       TLI.isOperationLegal(N0.getOpcode(), VT))) {
    SDLoc SL(N);
    auto X = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(0));
    auto Y = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(1));
    auto VTs = DAG.getVTList(VT, N0->getValueType(1));
    return DAG.getNode(N0.getOpcode(), SL, VTs, X, Y, N0.getOperand(2));
  }

  // fold (truncate (extract_subvector(ext x))) ->
  //      (extract_subvector x)
  // TODO: This can be generalized to cover cases where the truncate and extract
  // do not fully cancel each other out.
  if (!LegalTypes && N0.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
    SDValue N00 = N0.getOperand(0);
    if (N00.getOpcode() == ISD::SIGN_EXTEND ||
        N00.getOpcode() == ISD::ZERO_EXTEND ||
        N00.getOpcode() == ISD::ANY_EXTEND) {
      if (N00.getOperand(0)->getValueType(0).getVectorElementType() ==
          VT.getVectorElementType())
        return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N0->getOperand(0)), VT,
                           N00.getOperand(0), N0.getOperand(1));
    }
  }

  if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
    return NewVSel;

  // Narrow a suitable binary operation with a non-opaque constant operand by
  // moving it ahead of the truncate. This is limited to pre-legalization
  // because targets may prefer a wider type during later combines and invert
  // this transform.
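  // For example, (i16 trunc (i32 add x, 42)) ->
  // (i16 add (i16 trunc x), (i16 42)).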
  switch (N0.getOpcode()) {
  case ISD::ADD:
  case ISD::SUB:
  case ISD::MUL:
  case ISD::AND:
  case ISD::OR:
  case ISD::XOR:
    if (!LegalOperations && N0.hasOneUse() &&
        (isConstantOrConstantVector(N0.getOperand(0), true) ||
         isConstantOrConstantVector(N0.getOperand(1), true))) {
      // TODO: We already restricted this to pre-legalization, but for vectors
      // we are extra cautious to not create an unsupported operation.
      // Target-specific changes are likely needed to avoid regressions here.
      if (VT.isScalarInteger() || TLI.isOperationLegal(N0.getOpcode(), VT)) {
        SDLoc DL(N);
        SDValue NarrowL = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
        SDValue NarrowR = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(1));
        return DAG.getNode(N0.getOpcode(), DL, VT, NarrowL, NarrowR);
      }
    }
  }

  return SDValue();
}

static SDNode *getBuildPairElt(SDNode *N, unsigned i) {
  SDValue Elt = N->getOperand(i);
  if (Elt.getOpcode() != ISD::MERGE_VALUES)
    return Elt.getNode();
  return Elt.getOperand(Elt.getResNo()).getNode();
}

/// build_pair (load, load) -> load
/// if load locations are consecutive.
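/// For example, on a little-endian target,
/// (build_pair (i32 load [x]), (i32 load [x+4])) becomes one i64 load from x.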
CombineConsecutiveLoads(SDNode * N,EVT VT)12045 SDValue DAGCombiner::CombineConsecutiveLoads(SDNode *N, EVT VT) {
12046   assert(N->getOpcode() == ISD::BUILD_PAIR);
12047 
12048   LoadSDNode *LD1 = dyn_cast<LoadSDNode>(getBuildPairElt(N, 0));
12049   LoadSDNode *LD2 = dyn_cast<LoadSDNode>(getBuildPairElt(N, 1));
12050 
12051   // A BUILD_PAIR is always having the least significant part in elt 0 and the
12052   // most significant part in elt 1. So when combining into one large load, we
12053   // need to consider the endianness.
12054   if (DAG.getDataLayout().isBigEndian())
12055     std::swap(LD1, LD2);
12056 
12057   if (!LD1 || !LD2 || !ISD::isNON_EXTLoad(LD1) || !LD1->hasOneUse() ||
12058       LD1->getAddressSpace() != LD2->getAddressSpace())
12059     return SDValue();
12060   EVT LD1VT = LD1->getValueType(0);
12061   unsigned LD1Bytes = LD1VT.getStoreSize();
12062   if (ISD::isNON_EXTLoad(LD2) && LD2->hasOneUse() &&
12063       DAG.areNonVolatileConsecutiveLoads(LD2, LD1, LD1Bytes, 1)) {
12064     Align Alignment = LD1->getAlign();
12065     Align NewAlign = DAG.getDataLayout().getABITypeAlign(
12066         VT.getTypeForEVT(*DAG.getContext()));
12067 
12068     if (NewAlign <= Alignment &&
12069         (!LegalOperations || TLI.isOperationLegal(ISD::LOAD, VT)))
12070       return DAG.getLoad(VT, SDLoc(N), LD1->getChain(), LD1->getBasePtr(),
12071                          LD1->getPointerInfo(), Alignment);
12072   }
12073 
12074   return SDValue();
12075 }
12076 
getPPCf128HiElementSelector(const SelectionDAG & DAG)12077 static unsigned getPPCf128HiElementSelector(const SelectionDAG &DAG) {
12078   // On little-endian machines, bitcasting from ppcf128 to i128 does swap the Hi
12079   // and Lo parts; on big-endian machines it doesn't.
12080   return DAG.getDataLayout().isBigEndian() ? 1 : 0;
12081 }
12082 
foldBitcastedFPLogic(SDNode * N,SelectionDAG & DAG,const TargetLowering & TLI)12083 static SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
12084                                     const TargetLowering &TLI) {
12085   // If this is not a bitcast to an FP type or if the target doesn't have
12086   // IEEE754-compliant FP logic, we're done.
12087   EVT VT = N->getValueType(0);
12088   if (!VT.isFloatingPoint() || !TLI.hasBitPreservingFPLogic(VT))
12089     return SDValue();
12090 
12091   // TODO: Handle cases where the integer constant is a different scalar
12092   // bitwidth to the FP.
12093   SDValue N0 = N->getOperand(0);
12094   EVT SourceVT = N0.getValueType();
12095   if (VT.getScalarSizeInBits() != SourceVT.getScalarSizeInBits())
12096     return SDValue();
12097 
12098   unsigned FPOpcode;
12099   APInt SignMask;
12100   switch (N0.getOpcode()) {
12101   case ISD::AND:
12102     FPOpcode = ISD::FABS;
12103     SignMask = ~APInt::getSignMask(SourceVT.getScalarSizeInBits());
12104     break;
12105   case ISD::XOR:
12106     FPOpcode = ISD::FNEG;
12107     SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits());
12108     break;
12109   case ISD::OR:
12110     FPOpcode = ISD::FABS;
12111     SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits());
12112     break;
12113   default:
12114     return SDValue();
12115   }
12116 
12117   // Fold (bitcast int (and (bitcast fp X to int), 0x7fff...) to fp) -> fabs X
12118   // Fold (bitcast int (xor (bitcast fp X to int), 0x8000...) to fp) -> fneg X
12119   // Fold (bitcast int (or (bitcast fp X to int), 0x8000...) to fp) ->
12120   //   fneg (fabs X)
12121   SDValue LogicOp0 = N0.getOperand(0);
12122   ConstantSDNode *LogicOp1 = isConstOrConstSplat(N0.getOperand(1), true);
12123   if (LogicOp1 && LogicOp1->getAPIntValue() == SignMask &&
12124       LogicOp0.getOpcode() == ISD::BITCAST &&
12125       LogicOp0.getOperand(0).getValueType() == VT) {
12126     SDValue FPOp = DAG.getNode(FPOpcode, SDLoc(N), VT, LogicOp0.getOperand(0));
12127     NumFPLogicOpsConv++;
12128     if (N0.getOpcode() == ISD::OR)
12129       return DAG.getNode(ISD::FNEG, SDLoc(N), VT, FPOp);
12130     return FPOp;
12131   }
12132 
12133   return SDValue();
12134 }

SDValue DAGCombiner::visitBITCAST(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  if (N0.isUndef())
    return DAG.getUNDEF(VT);

  // If the input is a BUILD_VECTOR with all constant elements, fold this now.
  // Only do this before legalize types, unless both types are integer and the
  // scalar type is legal. Only do this before legalize ops, since the target
  // may be depending on the bitcast.
  // First check to see if this is all constant.
  // TODO: Support FP bitcasts after legalize types.
  if (VT.isVector() &&
      (!LegalTypes ||
       (!LegalOperations && VT.isInteger() && N0.getValueType().isInteger() &&
        TLI.isTypeLegal(VT.getVectorElementType()))) &&
      N0.getOpcode() == ISD::BUILD_VECTOR && N0.getNode()->hasOneUse() &&
      cast<BuildVectorSDNode>(N0)->isConstant())
    return ConstantFoldBITCASTofBUILD_VECTOR(N0.getNode(),
                                             VT.getVectorElementType());

  // If the input is a constant, let getNode fold it.
  if (isa<ConstantSDNode>(N0) || isa<ConstantFPSDNode>(N0)) {
    // If we can't allow illegal operations, we need to check that this is just
    // an fp -> int or int -> fp conversion and that the resulting operation
    // will be legal.
    if (!LegalOperations ||
        (isa<ConstantSDNode>(N0) && VT.isFloatingPoint() && !VT.isVector() &&
         TLI.isOperationLegal(ISD::ConstantFP, VT)) ||
        (isa<ConstantFPSDNode>(N0) && VT.isInteger() && !VT.isVector() &&
         TLI.isOperationLegal(ISD::Constant, VT))) {
      SDValue C = DAG.getBitcast(VT, N0);
      if (C.getNode() != N)
        return C;
    }
  }

  // (conv (conv x, t1), t2) -> (conv x, t2)
  if (N0.getOpcode() == ISD::BITCAST)
    return DAG.getBitcast(VT, N0.getOperand(0));

  // fold (conv (load x)) -> (load (conv*)x)
  // If the resultant load doesn't need a higher alignment than the original!
  if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
      // Do not remove the cast if the types differ in endian layout.
      TLI.hasBigEndianPartOrdering(N0.getValueType(), DAG.getDataLayout()) ==
          TLI.hasBigEndianPartOrdering(VT, DAG.getDataLayout()) &&
      // If the load is volatile, we only want to change the load type if the
      // resulting load is legal. Otherwise we might increase the number of
      // memory accesses. We don't care if the original type was legal or not
      // as we assume software couldn't rely on the number of accesses of an
      // illegal type.
      ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple()) ||
       TLI.isOperationLegal(ISD::LOAD, VT))) {
    LoadSDNode *LN0 = cast<LoadSDNode>(N0);

    if (TLI.isLoadBitCastBeneficial(N0.getValueType(), VT, DAG,
                                    *LN0->getMemOperand())) {
      SDValue Load =
          DAG.getLoad(VT, SDLoc(N), LN0->getChain(), LN0->getBasePtr(),
                      LN0->getPointerInfo(), LN0->getAlign(),
                      LN0->getMemOperand()->getFlags(), LN0->getAAInfo());
      DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));
      return Load;
    }
  }

  if (SDValue V = foldBitcastedFPLogic(N, DAG, TLI))
    return V;

  // fold (bitconvert (fneg x)) -> (xor (bitconvert x), signbit)
  // fold (bitconvert (fabs x)) -> (and (bitconvert x), (not signbit))
  //
  // For ppc_fp128:
  // fold (bitcast (fneg x)) ->
  //     flipbit = signbit
  //     (xor (bitcast x) (build_pair flipbit, flipbit))
  //
  // fold (bitcast (fabs x)) ->
  //     flipbit = (and (extract_element (bitcast x), 0), signbit)
  //     (xor (bitcast x) (build_pair flipbit, flipbit))
  // This often reduces constant pool loads.
  if (((N0.getOpcode() == ISD::FNEG && !TLI.isFNegFree(N0.getValueType())) ||
       (N0.getOpcode() == ISD::FABS && !TLI.isFAbsFree(N0.getValueType()))) &&
      N0.getNode()->hasOneUse() && VT.isInteger() &&
      !VT.isVector() && !N0.getValueType().isVector()) {
    SDValue NewConv = DAG.getBitcast(VT, N0.getOperand(0));
    AddToWorklist(NewConv.getNode());

    SDLoc DL(N);
    if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
      assert(VT.getSizeInBits() == 128);
      SDValue SignBit = DAG.getConstant(
          APInt::getSignMask(VT.getSizeInBits() / 2), SDLoc(N0), MVT::i64);
      SDValue FlipBit;
      if (N0.getOpcode() == ISD::FNEG) {
        FlipBit = SignBit;
        AddToWorklist(FlipBit.getNode());
      } else {
        assert(N0.getOpcode() == ISD::FABS);
        SDValue Hi =
            DAG.getNode(ISD::EXTRACT_ELEMENT, SDLoc(NewConv), MVT::i64, NewConv,
                        DAG.getIntPtrConstant(getPPCf128HiElementSelector(DAG),
                                              SDLoc(NewConv)));
        AddToWorklist(Hi.getNode());
        FlipBit = DAG.getNode(ISD::AND, SDLoc(N0), MVT::i64, Hi, SignBit);
        AddToWorklist(FlipBit.getNode());
      }
      SDValue FlipBits =
          DAG.getNode(ISD::BUILD_PAIR, SDLoc(N0), VT, FlipBit, FlipBit);
      AddToWorklist(FlipBits.getNode());
      return DAG.getNode(ISD::XOR, DL, VT, NewConv, FlipBits);
    }
    APInt SignBit = APInt::getSignMask(VT.getSizeInBits());
    if (N0.getOpcode() == ISD::FNEG)
      return DAG.getNode(ISD::XOR, DL, VT,
                         NewConv, DAG.getConstant(SignBit, DL, VT));
    assert(N0.getOpcode() == ISD::FABS);
    return DAG.getNode(ISD::AND, DL, VT,
                       NewConv, DAG.getConstant(~SignBit, DL, VT));
  }

  // fold (bitconvert (fcopysign cst, x)) ->
  //         (or (and (bitconvert x), sign), (and cst, (not sign)))
  // Note that we don't handle (copysign x, cst) because this can always be
  // folded to an fneg or fabs.
  //
  // For ppc_fp128:
  // fold (bitcast (fcopysign cst, x)) ->
  //     flipbit = (and (extract_element
  //                     (xor (bitcast cst), (bitcast x)), 0),
  //                    signbit)
  //     (xor (bitcast cst) (build_pair flipbit, flipbit))
  if (N0.getOpcode() == ISD::FCOPYSIGN && N0.getNode()->hasOneUse() &&
      isa<ConstantFPSDNode>(N0.getOperand(0)) &&
      VT.isInteger() && !VT.isVector()) {
    unsigned OrigXWidth = N0.getOperand(1).getValueSizeInBits();
    EVT IntXVT = EVT::getIntegerVT(*DAG.getContext(), OrigXWidth);
    if (isTypeLegal(IntXVT)) {
      SDValue X = DAG.getBitcast(IntXVT, N0.getOperand(1));
      AddToWorklist(X.getNode());

      // If X has a different width than the result/lhs, sext it or truncate it.
      unsigned VTWidth = VT.getSizeInBits();
      if (OrigXWidth < VTWidth) {
        X = DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, X);
        AddToWorklist(X.getNode());
      } else if (OrigXWidth > VTWidth) {
        // To get the sign bit in the right place, we have to shift it right
        // before truncating.
        SDLoc DL(X);
        X = DAG.getNode(ISD::SRL, DL,
                        X.getValueType(), X,
                        DAG.getConstant(OrigXWidth-VTWidth, DL,
                                        X.getValueType()));
        AddToWorklist(X.getNode());
        X = DAG.getNode(ISD::TRUNCATE, SDLoc(X), VT, X);
        AddToWorklist(X.getNode());
      }

      if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
        APInt SignBit = APInt::getSignMask(VT.getSizeInBits() / 2);
        SDValue Cst = DAG.getBitcast(VT, N0.getOperand(0));
        AddToWorklist(Cst.getNode());
        SDValue X = DAG.getBitcast(VT, N0.getOperand(1));
        AddToWorklist(X.getNode());
        SDValue XorResult = DAG.getNode(ISD::XOR, SDLoc(N0), VT, Cst, X);
        AddToWorklist(XorResult.getNode());
        SDValue XorResult64 = DAG.getNode(
            ISD::EXTRACT_ELEMENT, SDLoc(XorResult), MVT::i64, XorResult,
            DAG.getIntPtrConstant(getPPCf128HiElementSelector(DAG),
                                  SDLoc(XorResult)));
        AddToWorklist(XorResult64.getNode());
        SDValue FlipBit =
            DAG.getNode(ISD::AND, SDLoc(XorResult64), MVT::i64, XorResult64,
                        DAG.getConstant(SignBit, SDLoc(XorResult64), MVT::i64));
        AddToWorklist(FlipBit.getNode());
        SDValue FlipBits =
            DAG.getNode(ISD::BUILD_PAIR, SDLoc(N0), VT, FlipBit, FlipBit);
        AddToWorklist(FlipBits.getNode());
        return DAG.getNode(ISD::XOR, SDLoc(N), VT, Cst, FlipBits);
      }
      APInt SignBit = APInt::getSignMask(VT.getSizeInBits());
      X = DAG.getNode(ISD::AND, SDLoc(X), VT,
                      X, DAG.getConstant(SignBit, SDLoc(X), VT));
      AddToWorklist(X.getNode());

      SDValue Cst = DAG.getBitcast(VT, N0.getOperand(0));
      Cst = DAG.getNode(ISD::AND, SDLoc(Cst), VT,
                        Cst, DAG.getConstant(~SignBit, SDLoc(Cst), VT));
      AddToWorklist(Cst.getNode());

      return DAG.getNode(ISD::OR, SDLoc(N), VT, X, Cst);
    }
  }

  // bitconvert(build_pair(ld, ld)) -> ld iff load locations are consecutive.
  if (N0.getOpcode() == ISD::BUILD_PAIR)
    if (SDValue CombineLD = CombineConsecutiveLoads(N0.getNode(), VT))
      return CombineLD;

  // Remove double bitcasts from shuffles - this is often a legacy of
  // XformToShuffleWithZero being used to combine bitmaskings (of
  // float vectors bitcast to integer vectors) into shuffles.
  // bitcast(shuffle(bitcast(s0),bitcast(s1))) -> shuffle(s0,s1)
  if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT) && VT.isVector() &&
      N0->getOpcode() == ISD::VECTOR_SHUFFLE && N0.hasOneUse() &&
      VT.getVectorNumElements() >= N0.getValueType().getVectorNumElements() &&
      !(VT.getVectorNumElements() % N0.getValueType().getVectorNumElements())) {
    ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N0);

    // If operands are a bitcast, peek through if it casts the original VT.
    // If operands are a constant, just bitcast back to original VT.
    auto PeekThroughBitcast = [&](SDValue Op) {
      if (Op.getOpcode() == ISD::BITCAST &&
          Op.getOperand(0).getValueType() == VT)
        return SDValue(Op.getOperand(0));
      if (Op.isUndef() || ISD::isBuildVectorOfConstantSDNodes(Op.getNode()) ||
          ISD::isBuildVectorOfConstantFPSDNodes(Op.getNode()))
        return DAG.getBitcast(VT, Op);
      return SDValue();
    };

    // FIXME: If either input vector is bitcast, try to convert the shuffle to
    // the result type of this bitcast. This would eliminate at least one
    // bitcast. See the transform in InstCombine.
    SDValue SV0 = PeekThroughBitcast(N0->getOperand(0));
    SDValue SV1 = PeekThroughBitcast(N0->getOperand(1));
    if (!(SV0 && SV1))
      return SDValue();

    int MaskScale =
        VT.getVectorNumElements() / N0.getValueType().getVectorNumElements();
    SmallVector<int, 8> NewMask;
    for (int M : SVN->getMask())
      for (int i = 0; i != MaskScale; ++i)
        NewMask.push_back(M < 0 ? -1 : M * MaskScale + i);

    SDValue LegalShuffle =
        TLI.buildLegalVectorShuffle(VT, SDLoc(N), SV0, SV1, NewMask, DAG);
    if (LegalShuffle)
      return LegalShuffle;
  }

  return SDValue();
}

SDValue DAGCombiner::visitBUILD_PAIR(SDNode *N) {
  EVT VT = N->getValueType(0);
  return CombineConsecutiveLoads(N, VT);
}

SDValue DAGCombiner::visitFREEZE(SDNode *N) {
  SDValue N0 = N->getOperand(0);

  // (freeze (freeze x)) -> (freeze x)
  if (N0.getOpcode() == ISD::FREEZE)
    return N0;

  // If the input is a constant, return it.
  if (isa<ConstantSDNode>(N0) || isa<ConstantFPSDNode>(N0))
    return N0;

  return SDValue();
}

/// We know that BV is a build_vector node with Constant, ConstantFP or Undef
/// operands. DstEltVT indicates the destination element value type.
SDValue DAGCombiner::
ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) {
  EVT SrcEltVT = BV->getValueType(0).getVectorElementType();

  // If this is already the right type, we're done.
  if (SrcEltVT == DstEltVT) return SDValue(BV, 0);

  unsigned SrcBitSize = SrcEltVT.getSizeInBits();
  unsigned DstBitSize = DstEltVT.getSizeInBits();

  // If this is a conversion of N elements of one type to N elements of another
  // type, convert each element.  This handles FP<->INT cases.
  if (SrcBitSize == DstBitSize) {
    SmallVector<SDValue, 8> Ops;
    for (SDValue Op : BV->op_values()) {
      // If the vector element type is not legal, the BUILD_VECTOR operands
      // are promoted and implicitly truncated.  Make that explicit here.
      if (Op.getValueType() != SrcEltVT)
        Op = DAG.getNode(ISD::TRUNCATE, SDLoc(BV), SrcEltVT, Op);
      Ops.push_back(DAG.getBitcast(DstEltVT, Op));
      AddToWorklist(Ops.back().getNode());
    }
    EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT,
                              BV->getValueType(0).getVectorNumElements());
    return DAG.getBuildVector(VT, SDLoc(BV), Ops);
  }

  // Otherwise, we're growing or shrinking the elements.  To avoid having to
  // handle annoying details of growing/shrinking FP values, we convert them to
  // int first.
  if (SrcEltVT.isFloatingPoint()) {
    // Convert the input float vector to an integer vector whose elements are
    // the same size.
    EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), SrcEltVT.getSizeInBits());
    BV = ConstantFoldBITCASTofBUILD_VECTOR(BV, IntVT).getNode();
    SrcEltVT = IntVT;
  }

  // Now we know the input is an integer vector.  If the output is a FP type,
  // convert to integer first, then to FP of the right size.
  if (DstEltVT.isFloatingPoint()) {
    EVT TmpVT = EVT::getIntegerVT(*DAG.getContext(), DstEltVT.getSizeInBits());
    SDNode *Tmp = ConstantFoldBITCASTofBUILD_VECTOR(BV, TmpVT).getNode();

    // Next, convert to FP elements of the same size.
    return ConstantFoldBITCASTofBUILD_VECTOR(Tmp, DstEltVT);
  }

  SDLoc DL(BV);

  // Okay, we know the src/dst types are both integers of differing types.
  // Handling growing first.
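  // For example, on a little-endian target, bitcasting
  // (build_vector (i8 a), (i8 b)) to v1i16 produces the constant
  // (b << 8) | a, since element 0 supplies the least significant byte.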
  assert(SrcEltVT.isInteger() && DstEltVT.isInteger());
  if (SrcBitSize < DstBitSize) {
    unsigned NumInputsPerOutput = DstBitSize/SrcBitSize;

    SmallVector<SDValue, 8> Ops;
    for (unsigned i = 0, e = BV->getNumOperands(); i != e;
         i += NumInputsPerOutput) {
      bool isLE = DAG.getDataLayout().isLittleEndian();
      APInt NewBits = APInt(DstBitSize, 0);
      bool EltIsUndef = true;
      for (unsigned j = 0; j != NumInputsPerOutput; ++j) {
        // Shift the previously computed bits over.
        NewBits <<= SrcBitSize;
        SDValue Op = BV->getOperand(i + (isLE ? (NumInputsPerOutput-j-1) : j));
        if (Op.isUndef()) continue;
        EltIsUndef = false;

        NewBits |= cast<ConstantSDNode>(Op)->getAPIntValue().
                   zextOrTrunc(SrcBitSize).zext(DstBitSize);
      }

      if (EltIsUndef)
        Ops.push_back(DAG.getUNDEF(DstEltVT));
      else
        Ops.push_back(DAG.getConstant(NewBits, DL, DstEltVT));
    }

    EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT, Ops.size());
    return DAG.getBuildVector(VT, DL, Ops);
  }

  // Finally, this must be the case where we are shrinking elements: each input
  // turns into multiple outputs.
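  // For example, shrinking an i16 element 0xAABB into two i8 elements yields
  // (0xBB, 0xAA) on little-endian; big-endian swaps the pair below.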
12490   unsigned NumOutputsPerInput = SrcBitSize/DstBitSize;
12491   EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT,
12492                             NumOutputsPerInput*BV->getNumOperands());
12493   SmallVector<SDValue, 8> Ops;
12494 
12495   for (const SDValue &Op : BV->op_values()) {
12496     if (Op.isUndef()) {
12497       Ops.append(NumOutputsPerInput, DAG.getUNDEF(DstEltVT));
12498       continue;
12499     }
12500 
12501     APInt OpVal = cast<ConstantSDNode>(Op)->
12502                   getAPIntValue().zextOrTrunc(SrcBitSize);
12503 
12504     for (unsigned j = 0; j != NumOutputsPerInput; ++j) {
12505       APInt ThisVal = OpVal.trunc(DstBitSize);
12506       Ops.push_back(DAG.getConstant(ThisVal, DL, DstEltVT));
12507       OpVal.lshrInPlace(DstBitSize);
12508     }
12509 
12510     // For big endian targets, swap the order of the pieces of each element.
12511     if (DAG.getDataLayout().isBigEndian())
12512       std::reverse(Ops.end()-NumOutputsPerInput, Ops.end());
12513   }
12514 
12515   return DAG.getBuildVector(VT, DL, Ops);
12516 }
12517 
isContractable(SDNode * N)12518 static bool isContractable(SDNode *N) {
12519   SDNodeFlags F = N->getFlags();
12520   return F.hasAllowContract() || F.hasAllowReassociation();
12521 }
12522 
12523 /// Try to perform FMA combining on a given FADD node.
visitFADDForFMACombine(SDNode * N)12524 SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
12525   SDValue N0 = N->getOperand(0);
12526   SDValue N1 = N->getOperand(1);
12527   EVT VT = N->getValueType(0);
12528   SDLoc SL(N);
12529 
12530   const TargetOptions &Options = DAG.getTarget().Options;
12531 
12532   // Floating-point multiply-add with intermediate rounding.
12533   bool HasFMAD = (LegalOperations && TLI.isFMADLegal(DAG, N));
12534 
12535   // Floating-point multiply-add without intermediate rounding.
12536   bool HasFMA =
12537       TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
12538       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
12539 
12540   // No valid opcode, do not combine.
12541   if (!HasFMAD && !HasFMA)
12542     return SDValue();
12543 
12544   bool CanFuse = Options.UnsafeFPMath || isContractable(N);
12545   bool CanReassociate =
12546       Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
12547   bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
12548                               CanFuse || HasFMAD);
12549   // If the addition is not contractable, do not combine.
12550   if (!AllowFusionGlobally && !isContractable(N))
12551     return SDValue();
12552 
12553   if (STI && STI->generateFMAsInMachineCombiner(OptLevel))
12554     return SDValue();
12555 
12556   // Always prefer FMAD to FMA for precision.
12557   unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
12558   bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
12559 
12560   // Is the node an FMUL and contractable either due to global flags or
12561   // SDNodeFlags.
12562   auto isContractableFMUL = [AllowFusionGlobally](SDValue N) {
12563     if (N.getOpcode() != ISD::FMUL)
12564       return false;
12565     return AllowFusionGlobally || isContractable(N.getNode());
12566   };
12567   // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
12568   // prefer to fold the multiply with fewer uses.
12569   if (Aggressive && isContractableFMUL(N0) && isContractableFMUL(N1)) {
12570     if (N0.getNode()->use_size() > N1.getNode()->use_size())
12571       std::swap(N0, N1);
12572   }
12573 
12574   // fold (fadd (fmul x, y), z) -> (fma x, y, z)
12575   if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) {
12576     return DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0),
12577                        N0.getOperand(1), N1);
12578   }
12579 
12580   // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
12581   // Note: Commutes FADD operands.
12582   if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) {
12583     return DAG.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0),
12584                        N1.getOperand(1), N0);
12585   }
12586 
12587   // fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E)
12588   // fadd E, (fma A, B, (fmul C, D)) --> fma A, B, (fma C, D, E)
12589   // This requires reassociation because it changes the order of operations.
12590   SDValue FMA, E;
12591   if (CanReassociate && N0.getOpcode() == PreferredFusedOpcode &&
12592       N0.getOperand(2).getOpcode() == ISD::FMUL && N0.hasOneUse() &&
12593       N0.getOperand(2).hasOneUse()) {
12594     FMA = N0;
12595     E = N1;
12596   } else if (CanReassociate && N1.getOpcode() == PreferredFusedOpcode &&
12597              N1.getOperand(2).getOpcode() == ISD::FMUL && N1.hasOneUse() &&
12598              N1.getOperand(2).hasOneUse()) {
12599     FMA = N1;
12600     E = N0;
12601   }
12602   if (FMA && E) {
12603     SDValue A = FMA.getOperand(0);
12604     SDValue B = FMA.getOperand(1);
12605     SDValue C = FMA.getOperand(2).getOperand(0);
12606     SDValue D = FMA.getOperand(2).getOperand(1);
12607     SDValue CDE = DAG.getNode(PreferredFusedOpcode, SL, VT, C, D, E);
12608     return DAG.getNode(PreferredFusedOpcode, SL, VT, A, B, CDE);
12609   }
12610 
12611   // Look through FP_EXTEND nodes to do more combining.
12612 
12613   // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
12614   if (N0.getOpcode() == ISD::FP_EXTEND) {
12615     SDValue N00 = N0.getOperand(0);
12616     if (isContractableFMUL(N00) &&
12617         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
12618                             N00.getValueType())) {
12619       return DAG.getNode(PreferredFusedOpcode, SL, VT,
12620                          DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
12621                          DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
12622                          N1);
12623     }
12624   }
12625 
12626   // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x)
12627   // Note: Commutes FADD operands.
12628   if (N1.getOpcode() == ISD::FP_EXTEND) {
12629     SDValue N10 = N1.getOperand(0);
12630     if (isContractableFMUL(N10) &&
12631         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
12632                             N10.getValueType())) {
12633       return DAG.getNode(PreferredFusedOpcode, SL, VT,
12634                          DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0)),
12635                          DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)),
12636                          N0);
12637     }
12638   }
12639 
12640   // More folding opportunities when target permits.
12641   if (Aggressive) {
12642     // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
12643     //   -> (fma x, y, (fma (fpext u), (fpext v), z))
12644     auto FoldFAddFMAFPExtFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
12645                                     SDValue Z) {
12646       return DAG.getNode(PreferredFusedOpcode, SL, VT, X, Y,
12647                          DAG.getNode(PreferredFusedOpcode, SL, VT,
12648                                      DAG.getNode(ISD::FP_EXTEND, SL, VT, U),
12649                                      DAG.getNode(ISD::FP_EXTEND, SL, VT, V),
12650                                      Z));
12651     };
12652     if (N0.getOpcode() == PreferredFusedOpcode) {
12653       SDValue N02 = N0.getOperand(2);
12654       if (N02.getOpcode() == ISD::FP_EXTEND) {
12655         SDValue N020 = N02.getOperand(0);
12656         if (isContractableFMUL(N020) &&
12657             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
12658                                 N020.getValueType())) {
12659           return FoldFAddFMAFPExtFMul(N0.getOperand(0), N0.getOperand(1),
12660                                       N020.getOperand(0), N020.getOperand(1),
12661                                       N1);
12662         }
12663       }
12664     }
12665 
12666     // fold (fadd (fpext (fma x, y, (fmul u, v))), z)
12667     //   -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
12668     // FIXME: This turns two single-precision and one double-precision
12669     // operation into two double-precision operations, which might not be
12670     // interesting for all targets, especially GPUs.
12671     auto FoldFAddFPExtFMAFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
12672                                     SDValue Z) {
12673       return DAG.getNode(
12674           PreferredFusedOpcode, SL, VT, DAG.getNode(ISD::FP_EXTEND, SL, VT, X),
12675           DAG.getNode(ISD::FP_EXTEND, SL, VT, Y),
12676           DAG.getNode(PreferredFusedOpcode, SL, VT,
12677                       DAG.getNode(ISD::FP_EXTEND, SL, VT, U),
12678                       DAG.getNode(ISD::FP_EXTEND, SL, VT, V), Z));
12679     };
12680     if (N0.getOpcode() == ISD::FP_EXTEND) {
12681       SDValue N00 = N0.getOperand(0);
12682       if (N00.getOpcode() == PreferredFusedOpcode) {
12683         SDValue N002 = N00.getOperand(2);
12684         if (isContractableFMUL(N002) &&
12685             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
12686                                 N00.getValueType())) {
12687           return FoldFAddFPExtFMAFMul(N00.getOperand(0), N00.getOperand(1),
12688                                       N002.getOperand(0), N002.getOperand(1),
12689                                       N1);
12690         }
12691       }
12692     }
12693 
12694     // fold (fadd x, (fma y, z, (fpext (fmul u, v))))
12695     //   -> (fma y, z, (fma (fpext u), (fpext v), x))
12696     if (N1.getOpcode() == PreferredFusedOpcode) {
12697       SDValue N12 = N1.getOperand(2);
12698       if (N12.getOpcode() == ISD::FP_EXTEND) {
12699         SDValue N120 = N12.getOperand(0);
12700         if (isContractableFMUL(N120) &&
12701             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
12702                                 N120.getValueType())) {
12703           return FoldFAddFMAFPExtFMul(N1.getOperand(0), N1.getOperand(1),
12704                                       N120.getOperand(0), N120.getOperand(1),
12705                                       N0);
12706         }
12707       }
12708     }
12709 
12710     // fold (fadd x, (fpext (fma y, z, (fmul u, v))))
12711     //   -> (fma (fpext y), (fpext z), (fma (fpext u), (fpext v), x))
12712     // FIXME: This turns two single-precision and one double-precision
12713     // operation into two double-precision operations, which might not be
12714     // interesting for all targets, especially GPUs.
12715     if (N1.getOpcode() == ISD::FP_EXTEND) {
12716       SDValue N10 = N1.getOperand(0);
12717       if (N10.getOpcode() == PreferredFusedOpcode) {
12718         SDValue N102 = N10.getOperand(2);
12719         if (isContractableFMUL(N102) &&
12720             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
12721                                 N10.getValueType())) {
12722           return FoldFAddFPExtFMAFMul(N10.getOperand(0), N10.getOperand(1),
12723                                       N102.getOperand(0), N102.getOperand(1),
12724                                       N0);
12725         }
12726       }
12727     }
12728   }
12729 
12730   return SDValue();
12731 }
12732 
12733 /// Try to perform FMA combining on a given FSUB node.
12734 SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
12735   SDValue N0 = N->getOperand(0);
12736   SDValue N1 = N->getOperand(1);
12737   EVT VT = N->getValueType(0);
12738   SDLoc SL(N);
12739 
12740   const TargetOptions &Options = DAG.getTarget().Options;
12741   // Floating-point multiply-add with intermediate rounding.
12742   bool HasFMAD = (LegalOperations && TLI.isFMADLegal(DAG, N));
12743 
12744   // Floating-point multiply-add without intermediate rounding.
12745   bool HasFMA =
12746       TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
12747       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
12748 
12749   // No valid opcode, do not combine.
12750   if (!HasFMAD && !HasFMA)
12751     return SDValue();
12752 
12753   const SDNodeFlags Flags = N->getFlags();
12754   bool CanFuse = Options.UnsafeFPMath || isContractable(N);
12755   bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
12756                               CanFuse || HasFMAD);
12757 
12758   // If the subtraction is not contractable, do not combine.
12759   if (!AllowFusionGlobally && !isContractable(N))
12760     return SDValue();
12761 
12762   if (STI && STI->generateFMAsInMachineCombiner(OptLevel))
12763     return SDValue();
12764 
12765   // Always prefer FMAD to FMA for precision.
12766   unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
12767   bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
12768   bool NoSignedZero = Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros();
12769 
12770   // Is the node an FMUL and contractable either due to global flags or
12771   // SDNodeFlags.
12772   auto isContractableFMUL = [AllowFusionGlobally](SDValue N) {
12773     if (N.getOpcode() != ISD::FMUL)
12774       return false;
12775     return AllowFusionGlobally || isContractable(N.getNode());
12776   };
12777 
12778   // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
12779   auto tryToFoldXYSubZ = [&](SDValue XY, SDValue Z) {
12780     if (isContractableFMUL(XY) && (Aggressive || XY->hasOneUse())) {
12781       return DAG.getNode(PreferredFusedOpcode, SL, VT, XY.getOperand(0),
12782                          XY.getOperand(1), DAG.getNode(ISD::FNEG, SL, VT, Z));
12783     }
12784     return SDValue();
12785   };
12786 
12787   // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
12788   // Note: Commutes FSUB operands.
12789   auto tryToFoldXSubYZ = [&](SDValue X, SDValue YZ) {
12790     if (isContractableFMUL(YZ) && (Aggressive || YZ->hasOneUse())) {
12791       return DAG.getNode(PreferredFusedOpcode, SL, VT,
12792                          DAG.getNode(ISD::FNEG, SL, VT, YZ.getOperand(0)),
12793                          YZ.getOperand(1), X);
12794     }
12795     return SDValue();
12796   };
12797 
12798   // If we have two choices trying to fold (fsub (fmul u, v), (fmul x, y)),
12799   // prefer to fold the multiply with fewer uses.
12800   if (isContractableFMUL(N0) && isContractableFMUL(N1) &&
12801       (N0.getNode()->use_size() > N1.getNode()->use_size())) {
12802     // fold (fsub (fmul a, b), (fmul c, d)) -> (fma (fneg c), d, (fmul a, b))
12803     if (SDValue V = tryToFoldXSubYZ(N0, N1))
12804       return V;
12805     // fold (fsub (fmul a, b), (fmul c, d)) -> (fma a, b, (fneg (fmul c, d)))
12806     if (SDValue V = tryToFoldXYSubZ(N0, N1))
12807       return V;
12808   } else {
12809     // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
12810     if (SDValue V = tryToFoldXYSubZ(N0, N1))
12811       return V;
12812     // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
12813     if (SDValue V = tryToFoldXSubYZ(N0, N1))
12814       return V;
12815   }
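  // Rationale for the use-count preference above: fusing the multiply with
  // fewer users lets that multiply disappear entirely, while the more widely
  // used multiply survives unchanged for its remaining users.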
12816 
12817   // fold (fsub (fneg (fmul x, y)), z) -> (fma (fneg x), y, (fneg z))
12818   if (N0.getOpcode() == ISD::FNEG && isContractableFMUL(N0.getOperand(0)) &&
12819       (Aggressive || (N0->hasOneUse() && N0.getOperand(0).hasOneUse()))) {
12820     SDValue N00 = N0.getOperand(0).getOperand(0);
12821     SDValue N01 = N0.getOperand(0).getOperand(1);
12822     return DAG.getNode(PreferredFusedOpcode, SL, VT,
12823                        DAG.getNode(ISD::FNEG, SL, VT, N00), N01,
12824                        DAG.getNode(ISD::FNEG, SL, VT, N1));
12825   }
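  // Sign algebra for the fold above: -(x*y) - z == (-x)*y + (-z), so both
  // negations survive as operand negations of a single fused multiply-add.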
12826 
12827   // Look through FP_EXTEND nodes to do more combining.
12828 
12829   // fold (fsub (fpext (fmul x, y)), z)
12830   //   -> (fma (fpext x), (fpext y), (fneg z))
12831   if (N0.getOpcode() == ISD::FP_EXTEND) {
12832     SDValue N00 = N0.getOperand(0);
12833     if (isContractableFMUL(N00) &&
12834         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
12835                             N00.getValueType())) {
12836       return DAG.getNode(PreferredFusedOpcode, SL, VT,
12837                          DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
12838                          DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
12839                          DAG.getNode(ISD::FNEG, SL, VT, N1));
12840     }
12841   }
12842 
12843   // fold (fsub x, (fpext (fmul y, z)))
12844   //   -> (fma (fneg (fpext y)), (fpext z), x)
12845   // Note: Commutes FSUB operands.
12846   if (N1.getOpcode() == ISD::FP_EXTEND) {
12847     SDValue N10 = N1.getOperand(0);
12848     if (isContractableFMUL(N10) &&
12849         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
12850                             N10.getValueType())) {
12851       return DAG.getNode(
12852           PreferredFusedOpcode, SL, VT,
12853           DAG.getNode(ISD::FNEG, SL, VT,
12854                       DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0))),
12855           DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), N0);
12856     }
12857   }
12858 
12859   // fold (fsub (fpext (fneg (fmul x, y))), z)
12860   //   -> (fneg (fma (fpext x), (fpext y), z))
12861   // Note: This could be removed with appropriate canonicalization of the
12862   // input expression into (fneg (fadd (fpext (fmul x, y)), z)). However, the
12863   // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent
12864   // us from implementing the canonicalization in visitFSUB.
12865   if (N0.getOpcode() == ISD::FP_EXTEND) {
12866     SDValue N00 = N0.getOperand(0);
12867     if (N00.getOpcode() == ISD::FNEG) {
12868       SDValue N000 = N00.getOperand(0);
12869       if (isContractableFMUL(N000) &&
12870           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
12871                               N00.getValueType())) {
12872         return DAG.getNode(
12873             ISD::FNEG, SL, VT,
12874             DAG.getNode(PreferredFusedOpcode, SL, VT,
12875                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)),
12876                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)),
12877                         N1));
12878       }
12879     }
12880   }
12881 
12882   // fold (fsub (fneg (fpext (fmul x, y))), z)
12883   //   -> (fneg (fma (fpext x), (fpext y), z))
12884   // Note: This could be removed with appropriate canonicalization of the
12885   // input expression into (fneg (fadd (fpext (fmul x, y)), z)). However, the
12886   // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent
12887   // us from implementing the canonicalization in visitFSUB.
12888   if (N0.getOpcode() == ISD::FNEG) {
12889     SDValue N00 = N0.getOperand(0);
12890     if (N00.getOpcode() == ISD::FP_EXTEND) {
12891       SDValue N000 = N00.getOperand(0);
12892       if (isContractableFMUL(N000) &&
12893           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
12894                               N000.getValueType())) {
12895         return DAG.getNode(
12896             ISD::FNEG, SL, VT,
12897             DAG.getNode(PreferredFusedOpcode, SL, VT,
12898                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)),
12899                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)),
12900                         N1));
12901       }
12902     }
12903   }
12904 
12905   // More folding opportunities when target permits.
12906   if (Aggressive) {
12907     // fold (fsub (fma x, y, (fmul u, v)), z)
12908     //   -> (fma x, y, (fma u, v, (fneg z)))
12909     if (CanFuse && N0.getOpcode() == PreferredFusedOpcode &&
12910         isContractableFMUL(N0.getOperand(2)) && N0->hasOneUse() &&
12911         N0.getOperand(2)->hasOneUse()) {
12912       return DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0),
12913                          N0.getOperand(1),
12914                          DAG.getNode(PreferredFusedOpcode, SL, VT,
12915                                      N0.getOperand(2).getOperand(0),
12916                                      N0.getOperand(2).getOperand(1),
12917                                      DAG.getNode(ISD::FNEG, SL, VT, N1)));
12918     }
12919 
12920     // fold (fsub x, (fma y, z, (fmul u, v)))
12921     //   -> (fma (fneg y), z, (fma (fneg u), v, x))
12922     if (CanFuse && N1.getOpcode() == PreferredFusedOpcode &&
12923         isContractableFMUL(N1.getOperand(2)) &&
12924         N1->hasOneUse() && NoSignedZero) {
12925       SDValue N20 = N1.getOperand(2).getOperand(0);
12926       SDValue N21 = N1.getOperand(2).getOperand(1);
12927       return DAG.getNode(
12928           PreferredFusedOpcode, SL, VT,
12929           DAG.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)), N1.getOperand(1),
12930           DAG.getNode(PreferredFusedOpcode, SL, VT,
12931                       DAG.getNode(ISD::FNEG, SL, VT, N20), N21, N0));
12932     }
12933
12935     // fold (fsub (fma x, y, (fpext (fmul u, v))), z)
12936     //   -> (fma x, y, (fma (fpext u), (fpext v), (fneg z)))
12937     if (N0.getOpcode() == PreferredFusedOpcode &&
12938         N0->hasOneUse()) {
12939       SDValue N02 = N0.getOperand(2);
12940       if (N02.getOpcode() == ISD::FP_EXTEND) {
12941         SDValue N020 = N02.getOperand(0);
12942         if (isContractableFMUL(N020) &&
12943             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
12944                                 N020.getValueType())) {
12945           return DAG.getNode(
12946               PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1),
12947               DAG.getNode(
12948                   PreferredFusedOpcode, SL, VT,
12949                   DAG.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(0)),
12950                   DAG.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(1)),
12951                   DAG.getNode(ISD::FNEG, SL, VT, N1)));
12952         }
12953       }
12954     }
12955 
12956     // fold (fsub (fpext (fma x, y, (fmul u, v))), z)
12957     //   -> (fma (fpext x), (fpext y),
12958     //           (fma (fpext u), (fpext v), (fneg z)))
12959     // FIXME: This turns two single-precision and one double-precision
12960     // operation into two double-precision operations, which might not be
12961     // interesting for all targets, especially GPUs.
12962     if (N0.getOpcode() == ISD::FP_EXTEND) {
12963       SDValue N00 = N0.getOperand(0);
12964       if (N00.getOpcode() == PreferredFusedOpcode) {
12965         SDValue N002 = N00.getOperand(2);
12966         if (isContractableFMUL(N002) &&
12967             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
12968                                 N00.getValueType())) {
12969           return DAG.getNode(
12970               PreferredFusedOpcode, SL, VT,
12971               DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
12972               DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
12973               DAG.getNode(
12974                   PreferredFusedOpcode, SL, VT,
12975                   DAG.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(0)),
12976                   DAG.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(1)),
12977                   DAG.getNode(ISD::FNEG, SL, VT, N1)));
12978         }
12979       }
12980     }
12981 
12982     // fold (fsub x, (fma y, z, (fpext (fmul u, v))))
12983     //   -> (fma (fneg y), z, (fma (fneg (fpext u)), (fpext v), x))
12984     if (N1.getOpcode() == PreferredFusedOpcode &&
12985         N1.getOperand(2).getOpcode() == ISD::FP_EXTEND &&
12986         N1->hasOneUse()) {
12987       SDValue N120 = N1.getOperand(2).getOperand(0);
12988       if (isContractableFMUL(N120) &&
12989           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
12990                               N120.getValueType())) {
12991         SDValue N1200 = N120.getOperand(0);
12992         SDValue N1201 = N120.getOperand(1);
12993         return DAG.getNode(
12994             PreferredFusedOpcode, SL, VT,
12995             DAG.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)), N1.getOperand(1),
12996             DAG.getNode(PreferredFusedOpcode, SL, VT,
12997                         DAG.getNode(ISD::FNEG, SL, VT,
12998                                     DAG.getNode(ISD::FP_EXTEND, SL, VT, N1200)),
12999                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N1201), N0));
13000       }
13001     }
13002 
13003     // fold (fsub x, (fpext (fma y, z, (fmul u, v))))
13004     //   -> (fma (fneg (fpext y)), (fpext z),
13005     //           (fma (fneg (fpext u)), (fpext v), x))
13006     // FIXME: This turns two single-precision and one double-precision
13007     // operation into two double-precision operations, which might not be
13008     // interesting for all targets, especially GPUs.
13009     if (N1.getOpcode() == ISD::FP_EXTEND &&
13010         N1.getOperand(0).getOpcode() == PreferredFusedOpcode) {
13011       SDValue CvtSrc = N1.getOperand(0);
13012       SDValue N100 = CvtSrc.getOperand(0);
13013       SDValue N101 = CvtSrc.getOperand(1);
13014       SDValue N102 = CvtSrc.getOperand(2);
13015       if (isContractableFMUL(N102) &&
13016           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
13017                               CvtSrc.getValueType())) {
13018         SDValue N1020 = N102.getOperand(0);
13019         SDValue N1021 = N102.getOperand(1);
13020         return DAG.getNode(
13021             PreferredFusedOpcode, SL, VT,
13022             DAG.getNode(ISD::FNEG, SL, VT,
13023                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N100)),
13024             DAG.getNode(ISD::FP_EXTEND, SL, VT, N101),
13025             DAG.getNode(PreferredFusedOpcode, SL, VT,
13026                         DAG.getNode(ISD::FNEG, SL, VT,
13027                                     DAG.getNode(ISD::FP_EXTEND, SL, VT, N1020)),
13028                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N1021), N0));
13029       }
13030     }
13031   }
13032 
13033   return SDValue();
13034 }
13035 
13036 /// Try to perform FMA combining on a given FMUL node based on the distributive
13037 /// law x * (y + 1) = x * y + x and variants thereof (commuted versions,
13038 /// subtraction instead of addition).
13039 SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) {
13040   SDValue N0 = N->getOperand(0);
13041   SDValue N1 = N->getOperand(1);
13042   EVT VT = N->getValueType(0);
13043   SDLoc SL(N);
13044 
13045   assert(N->getOpcode() == ISD::FMUL && "Expected FMUL Operation");
13046 
13047   const TargetOptions &Options = DAG.getTarget().Options;
13048 
13049   // The transforms below are incorrect when x == 0 and y == inf, because the
13050   // intermediate multiplication produces a nan.
13051   if (!Options.NoInfsFPMath)
13052     return SDValue();
13053 
13054   // Floating-point multiply-add without intermediate rounding.
13055   bool HasFMA =
13056       (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath) &&
13057       TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
13058       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
13059 
13060   // Floating-point multiply-add with intermediate rounding. This can result
13061   // in a less precise result due to the changed rounding order.
13062   bool HasFMAD = Options.UnsafeFPMath &&
13063                  (LegalOperations && TLI.isFMADLegal(DAG, N));
13064 
13065   // No valid opcode, do not combine.
13066   if (!HasFMAD && !HasFMA)
13067     return SDValue();
13068 
13069   // Always prefer FMAD to FMA for precision.
13070   unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
13071   bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
13072 
13073   // fold (fmul (fadd x0, +1.0), y) -> (fma x0, y, y)
13074   // fold (fmul (fadd x0, -1.0), y) -> (fma x0, y, (fneg y))
13075   auto FuseFADD = [&](SDValue X, SDValue Y) {
13076     if (X.getOpcode() == ISD::FADD && (Aggressive || X->hasOneUse())) {
13077       if (auto *C = isConstOrConstSplatFP(X.getOperand(1), true)) {
13078         if (C->isExactlyValue(+1.0))
13079           return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
13080                              Y);
13081         if (C->isExactlyValue(-1.0))
13082           return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
13083                              DAG.getNode(ISD::FNEG, SL, VT, Y));
13084       }
13085     }
13086     return SDValue();
13087   };
13088 
13089   if (SDValue FMA = FuseFADD(N0, N1))
13090     return FMA;
13091   if (SDValue FMA = FuseFADD(N1, N0))
13092     return FMA;
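  // Distributive-law check for FuseFADD: (x0 + 1.0)*y == x0*y + y and
  // (x0 + -1.0)*y == x0*y - y, matching fma(x0, y, y) and
  // fma(x0, y, (fneg y)) respectively.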
13093 
13094   // fold (fmul (fsub +1.0, x1), y) -> (fma (fneg x1), y, y)
13095   // fold (fmul (fsub -1.0, x1), y) -> (fma (fneg x1), y, (fneg y))
13096   // fold (fmul (fsub x0, +1.0), y) -> (fma x0, y, (fneg y))
13097   // fold (fmul (fsub x0, -1.0), y) -> (fma x0, y, y)
13098   auto FuseFSUB = [&](SDValue X, SDValue Y) {
13099     if (X.getOpcode() == ISD::FSUB && (Aggressive || X->hasOneUse())) {
13100       if (auto *C0 = isConstOrConstSplatFP(X.getOperand(0), true)) {
13101         if (C0->isExactlyValue(+1.0))
13102           return DAG.getNode(PreferredFusedOpcode, SL, VT,
13103                              DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
13104                              Y);
13105         if (C0->isExactlyValue(-1.0))
13106           return DAG.getNode(PreferredFusedOpcode, SL, VT,
13107                              DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
13108                              DAG.getNode(ISD::FNEG, SL, VT, Y));
13109       }
13110       if (auto *C1 = isConstOrConstSplatFP(X.getOperand(1), true)) {
13111         if (C1->isExactlyValue(+1.0))
13112           return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
13113                              DAG.getNode(ISD::FNEG, SL, VT, Y));
13114         if (C1->isExactlyValue(-1.0))
13115           return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
13116                              Y);
13117       }
13118     }
13119     return SDValue();
13120   };
13121 
13122   if (SDValue FMA = FuseFSUB(N0, N1))
13123     return FMA;
13124   if (SDValue FMA = FuseFSUB(N1, N0))
13125     return FMA;
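  // Likewise for FuseFSUB: (1.0 - x1)*y == (-x1)*y + y,
  // (-1.0 - x1)*y == (-x1)*y - y, (x0 - 1.0)*y == x0*y - y, and
  // (x0 - -1.0)*y == x0*y + y, matching the four FMA forms above.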
13126 
13127   return SDValue();
13128 }
13129 
13130 SDValue DAGCombiner::visitFADD(SDNode *N) {
13131   SDValue N0 = N->getOperand(0);
13132   SDValue N1 = N->getOperand(1);
13133   bool N0CFP = DAG.isConstantFPBuildVectorOrConstantFP(N0);
13134   bool N1CFP = DAG.isConstantFPBuildVectorOrConstantFP(N1);
13135   EVT VT = N->getValueType(0);
13136   SDLoc DL(N);
13137   const TargetOptions &Options = DAG.getTarget().Options;
13138   SDNodeFlags Flags = N->getFlags();
13139   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
13140 
13141   if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
13142     return R;
13143 
13144   // fold vector ops
13145   if (VT.isVector())
13146     if (SDValue FoldedVOp = SimplifyVBinOp(N))
13147       return FoldedVOp;
13148 
13149   // fold (fadd c1, c2) -> c1 + c2
13150   if (N0CFP && N1CFP)
13151     return DAG.getNode(ISD::FADD, DL, VT, N0, N1);
13152 
13153   // canonicalize constant to RHS
13154   if (N0CFP && !N1CFP)
13155     return DAG.getNode(ISD::FADD, DL, VT, N1, N0);
13156 
13157   // N0 + -0.0 --> N0 (also allowed with +0.0 and fast-math)
13158   ConstantFPSDNode *N1C = isConstOrConstSplatFP(N1, true);
13159   if (N1C && N1C->isZero())
13160     if (N1C->isNegative() || Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())
13161       return N0;
13162 
13163   if (SDValue NewSel = foldBinOpIntoSelect(N))
13164     return NewSel;
13165 
13166   // fold (fadd A, (fneg B)) -> (fsub A, B)
13167   if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT))
13168     if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
13169             N1, DAG, LegalOperations, ForCodeSize))
13170       return DAG.getNode(ISD::FSUB, DL, VT, N0, NegN1);
13171 
13172   // fold (fadd (fneg A), B) -> (fsub B, A)
13173   if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT))
13174     if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
13175             N0, DAG, LegalOperations, ForCodeSize))
13176       return DAG.getNode(ISD::FSUB, DL, VT, N1, NegN0);
13177 
13178   auto isFMulNegTwo = [](SDValue FMul) {
13179     if (!FMul.hasOneUse() || FMul.getOpcode() != ISD::FMUL)
13180       return false;
13181     auto *C = isConstOrConstSplatFP(FMul.getOperand(1), true);
13182     return C && C->isExactlyValue(-2.0);
13183   };
13184 
13185   // fadd (fmul B, -2.0), A --> fsub A, (fadd B, B)
13186   if (isFMulNegTwo(N0)) {
13187     SDValue B = N0.getOperand(0);
13188     SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B);
13189     return DAG.getNode(ISD::FSUB, DL, VT, N1, Add);
13190   }
13191   // fadd A, (fmul B, -2.0) --> fsub A, (fadd B, B)
13192   if (isFMulNegTwo(N1)) {
13193     SDValue B = N1.getOperand(0);
13194     SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B);
13195     return DAG.getNode(ISD::FSUB, DL, VT, N0, Add);
13196   }
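  // Arithmetic behind the -2.0 folds: A + B*(-2.0) == A - (B + B), which
  // trades the multiply (and the -2.0 constant) for one extra add.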
13197 
13198   // No FP constant should be created after legalization as the Instruction
13199   // Selection pass has a hard time dealing with FP constants.
13200   bool AllowNewConst = (Level < AfterLegalizeDAG);
13201 
13202   // If nnan is enabled, fold lots of things.
13203   if ((Options.NoNaNsFPMath || Flags.hasNoNaNs()) && AllowNewConst) {
13204     // If allowed, fold (fadd (fneg x), x) -> 0.0
13205     if (N0.getOpcode() == ISD::FNEG && N0.getOperand(0) == N1)
13206       return DAG.getConstantFP(0.0, DL, VT);
13207 
13208     // If allowed, fold (fadd x, (fneg x)) -> 0.0
13209     if (N1.getOpcode() == ISD::FNEG && N1.getOperand(0) == N0)
13210       return DAG.getConstantFP(0.0, DL, VT);
13211   }
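  // The nnan gate matters above: if x is NaN or +/-infinity, (-x) + x is
  // NaN rather than 0.0. For any finite x the sum is exactly +0.0 under the
  // default rounding mode, so no nsz check is needed.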
13212 
13213   // If 'unsafe math' or reassoc and nsz, fold lots of things.
13214   // TODO: break out portions of the transformations below for which Unsafe is
13215   //       considered and which do not require both nsz and reassoc
13216   if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
13217        (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
13218       AllowNewConst) {
13219     // fadd (fadd x, c1), c2 -> fadd x, c1 + c2
13220     if (N1CFP && N0.getOpcode() == ISD::FADD &&
13221         DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) {
13222       SDValue NewC = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1), N1);
13223       return DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(0), NewC);
13224     }
13225 
13226     // We can fold chains of FADD's of the same value into multiplications.
13227     // This transform is not safe in general because we are reducing the number
13228     // of rounding steps.
13229     if (TLI.isOperationLegalOrCustom(ISD::FMUL, VT) && !N0CFP && !N1CFP) {
13230       if (N0.getOpcode() == ISD::FMUL) {
13231         bool CFP00 = DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(0));
13232         bool CFP01 = DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1));
13233 
13234         // (fadd (fmul x, c), x) -> (fmul x, c+1)
13235         if (CFP01 && !CFP00 && N0.getOperand(0) == N1) {
13236           SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1),
13237                                        DAG.getConstantFP(1.0, DL, VT));
13238           return DAG.getNode(ISD::FMUL, DL, VT, N1, NewCFP);
13239         }
13240 
13241         // (fadd (fmul x, c), (fadd x, x)) -> (fmul x, c+2)
13242         if (CFP01 && !CFP00 && N1.getOpcode() == ISD::FADD &&
13243             N1.getOperand(0) == N1.getOperand(1) &&
13244             N0.getOperand(0) == N1.getOperand(0)) {
13245           SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1),
13246                                        DAG.getConstantFP(2.0, DL, VT));
13247           return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), NewCFP);
13248         }
13249       }
13250 
13251       if (N1.getOpcode() == ISD::FMUL) {
13252         bool CFP10 = DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(0));
13253         bool CFP11 = DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(1));
13254 
13255         // (fadd x, (fmul x, c)) -> (fmul x, c+1)
13256         if (CFP11 && !CFP10 && N1.getOperand(0) == N0) {
13257           SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N1.getOperand(1),
13258                                        DAG.getConstantFP(1.0, DL, VT));
13259           return DAG.getNode(ISD::FMUL, DL, VT, N0, NewCFP);
13260         }
13261 
13262         // (fadd (fadd x, x), (fmul x, c)) -> (fmul x, c+2)
13263         if (CFP11 && !CFP10 && N0.getOpcode() == ISD::FADD &&
13264             N0.getOperand(0) == N0.getOperand(1) &&
13265             N1.getOperand(0) == N0.getOperand(0)) {
13266           SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N1.getOperand(1),
13267                                        DAG.getConstantFP(2.0, DL, VT));
13268           return DAG.getNode(ISD::FMUL, DL, VT, N1.getOperand(0), NewCFP);
13269         }
13270       }
13271 
13272       if (N0.getOpcode() == ISD::FADD) {
13273         bool CFP00 = DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(0));
13274         // (fadd (fadd x, x), x) -> (fmul x, 3.0)
13275         if (!CFP00 && N0.getOperand(0) == N0.getOperand(1) &&
13276             (N0.getOperand(0) == N1)) {
13277           return DAG.getNode(ISD::FMUL, DL, VT, N1,
13278                              DAG.getConstantFP(3.0, DL, VT));
13279         }
13280       }
13281 
13282       if (N1.getOpcode() == ISD::FADD) {
13283         bool CFP10 = DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(0));
13284         // (fadd x, (fadd x, x)) -> (fmul x, 3.0)
13285         if (!CFP10 && N1.getOperand(0) == N1.getOperand(1) &&
13286             N1.getOperand(0) == N0) {
13287           return DAG.getNode(ISD::FMUL, DL, VT, N0,
13288                              DAG.getConstantFP(3.0, DL, VT));
13289         }
13290       }
13291 
13292       // (fadd (fadd x, x), (fadd x, x)) -> (fmul x, 4.0)
13293       if (N0.getOpcode() == ISD::FADD && N1.getOpcode() == ISD::FADD &&
13294           N0.getOperand(0) == N0.getOperand(1) &&
13295           N1.getOperand(0) == N1.getOperand(1) &&
13296           N0.getOperand(0) == N1.getOperand(0)) {
13297         return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0),
13298                            DAG.getConstantFP(4.0, DL, VT));
13299       }
13300     }
13301   } // enable-unsafe-fp-math
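  // Example of why reassociation is required for the folds above:
  // (fadd (fmul x, c), x) and (fmul x, c+1) round different intermediate
  // values, so the results can disagree in the last ulp.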
13302 
13303   // FADD -> FMA combines:
13304   if (SDValue Fused = visitFADDForFMACombine(N)) {
13305     AddToWorklist(Fused.getNode());
13306     return Fused;
13307   }
13308   return SDValue();
13309 }
13310 
13311 SDValue DAGCombiner::visitSTRICT_FADD(SDNode *N) {
13312   SDValue Chain = N->getOperand(0);
13313   SDValue N0 = N->getOperand(1);
13314   SDValue N1 = N->getOperand(2);
13315   EVT VT = N->getValueType(0);
13316   EVT ChainVT = N->getValueType(1);
13317   SDLoc DL(N);
13318   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
13319 
13320   // fold (strict_fadd A, (fneg B)) -> (strict_fsub A, B)
13321   if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::STRICT_FSUB, VT))
13322     if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
13323             N1, DAG, LegalOperations, ForCodeSize)) {
13324       return DAG.getNode(ISD::STRICT_FSUB, DL, DAG.getVTList(VT, ChainVT),
13325                          {Chain, N0, NegN1});
13326     }
13327 
13328   // fold (strict_fadd (fneg A), B) -> (strict_fsub B, A)
13329   if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::STRICT_FSUB, VT))
13330     if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
13331             N0, DAG, LegalOperations, ForCodeSize)) {
13332       return DAG.getNode(ISD::STRICT_FSUB, DL, DAG.getVTList(VT, ChainVT),
13333                          {Chain, N1, NegN0});
13334     }
13335   return SDValue();
13336 }
13337 
13338 SDValue DAGCombiner::visitFSUB(SDNode *N) {
13339   SDValue N0 = N->getOperand(0);
13340   SDValue N1 = N->getOperand(1);
13341   ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, true);
13342   ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true);
13343   EVT VT = N->getValueType(0);
13344   SDLoc DL(N);
13345   const TargetOptions &Options = DAG.getTarget().Options;
13346   const SDNodeFlags Flags = N->getFlags();
13347   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
13348 
13349   if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
13350     return R;
13351 
13352   // fold vector ops
13353   if (VT.isVector())
13354     if (SDValue FoldedVOp = SimplifyVBinOp(N))
13355       return FoldedVOp;
13356 
13357   // fold (fsub c1, c2) -> c1-c2
13358   if (N0CFP && N1CFP)
13359     return DAG.getNode(ISD::FSUB, DL, VT, N0, N1);
13360 
13361   if (SDValue NewSel = foldBinOpIntoSelect(N))
13362     return NewSel;
13363 
13364   // (fsub A, 0) -> A
13365   if (N1CFP && N1CFP->isZero()) {
13366     if (!N1CFP->isNegative() || Options.NoSignedZerosFPMath ||
13367         Flags.hasNoSignedZeros()) {
13368       return N0;
13369     }
13370   }
13371 
13372   if (N0 == N1) {
13373     // (fsub x, x) -> 0.0
13374     if (Options.NoNaNsFPMath || Flags.hasNoNaNs())
13375       return DAG.getConstantFP(0.0f, DL, VT);
13376   }
13377 
13378   // (fsub -0.0, N1) -> -N1
13379   if (N0CFP && N0CFP->isZero()) {
13380     if (N0CFP->isNegative() ||
13381         (Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())) {
13382       // We cannot replace an FSUB(+-0.0,X) with FNEG(X) when denormals are
13383       // flushed to zero, unless all users treat denorms as zero (DAZ).
13384       // FIXME: This transform will change the sign of a NaN and the behavior
13385       // of a signaling NaN. It is only valid when a NoNaN flag is present.
13386       DenormalMode DenormMode = DAG.getDenormalMode(VT);
13387       if (DenormMode == DenormalMode::getIEEE()) {
13388         if (SDValue NegN1 =
13389                 TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize))
13390           return NegN1;
13391         if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))
13392           return DAG.getNode(ISD::FNEG, DL, VT, N1);
13393       }
13394     }
13395   }
13396 
13397   if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
13398        (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
13399       N1.getOpcode() == ISD::FADD) {
13400     // X - (X + Y) -> -Y
13401     if (N0 == N1->getOperand(0))
13402       return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(1));
13403     // X - (Y + X) -> -Y
13404     if (N0 == N1->getOperand(1))
13405       return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(0));
13406   }
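  // Signed-zero caveat for the folds above: with X == 0.0 and Y == 0.0,
  // X - (X + Y) evaluates to +0.0 while -Y is -0.0, so nsz (or the global
  // no-signed-zeros option) is required.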
13407 
13408   // fold (fsub A, (fneg B)) -> (fadd A, B)
13409   if (SDValue NegN1 =
13410           TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize))
13411     return DAG.getNode(ISD::FADD, DL, VT, N0, NegN1);
13412 
13413   // FSUB -> FMA combines:
13414   if (SDValue Fused = visitFSUBForFMACombine(N)) {
13415     AddToWorklist(Fused.getNode());
13416     return Fused;
13417   }
13418 
13419   return SDValue();
13420 }
13421 
13422 SDValue DAGCombiner::visitFMUL(SDNode *N) {
13423   SDValue N0 = N->getOperand(0);
13424   SDValue N1 = N->getOperand(1);
13425   ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, true);
13426   ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true);
13427   EVT VT = N->getValueType(0);
13428   SDLoc DL(N);
13429   const TargetOptions &Options = DAG.getTarget().Options;
13430   const SDNodeFlags Flags = N->getFlags();
13431   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
13432 
13433   if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
13434     return R;
13435 
13436   // fold vector ops
13437   if (VT.isVector()) {
13438     // This just handles C1 * C2 for vectors. Other vector folds are below.
13439     if (SDValue FoldedVOp = SimplifyVBinOp(N))
13440       return FoldedVOp;
13441   }
13442 
13443   // fold (fmul c1, c2) -> c1*c2
13444   if (N0CFP && N1CFP)
13445     return DAG.getNode(ISD::FMUL, DL, VT, N0, N1);
13446 
13447   // canonicalize constant to RHS
13448   if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
13449      !DAG.isConstantFPBuildVectorOrConstantFP(N1))
13450     return DAG.getNode(ISD::FMUL, DL, VT, N1, N0);
13451 
13452   if (SDValue NewSel = foldBinOpIntoSelect(N))
13453     return NewSel;
13454 
13455   if (Options.UnsafeFPMath || Flags.hasAllowReassociation()) {
13456     // fmul (fmul X, C1), C2 -> fmul X, C1 * C2
13457     if (DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
13458         N0.getOpcode() == ISD::FMUL) {
13459       SDValue N00 = N0.getOperand(0);
13460       SDValue N01 = N0.getOperand(1);
13461       // Avoid an infinite loop by making sure that N00 is not a constant
13462       // (the inner multiply has not been constant folded yet).
13463       if (DAG.isConstantFPBuildVectorOrConstantFP(N01) &&
13464           !DAG.isConstantFPBuildVectorOrConstantFP(N00)) {
13465         SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, N01, N1);
13466         return DAG.getNode(ISD::FMUL, DL, VT, N00, MulConsts);
13467       }
13468     }
13469 
13470     // Match a special-case: we convert X * 2.0 into fadd.
13471     // fmul (fadd X, X), C -> fmul X, 2.0 * C
13472     if (N0.getOpcode() == ISD::FADD && N0.hasOneUse() &&
13473         N0.getOperand(0) == N0.getOperand(1)) {
13474       const SDValue Two = DAG.getConstantFP(2.0, DL, VT);
13475       SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, Two, N1);
13476       return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), MulConsts);
13477     }
13478   }
13479 
13480   // fold (fmul X, 2.0) -> (fadd X, X)
13481   if (N1CFP && N1CFP->isExactlyValue(+2.0))
13482     return DAG.getNode(ISD::FADD, DL, VT, N0, N0);
13483 
13484   // fold (fmul X, -1.0) -> (fneg X)
13485   if (N1CFP && N1CFP->isExactlyValue(-1.0))
13486     if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))
13487       return DAG.getNode(ISD::FNEG, DL, VT, N0);
13488 
13489   // -N0 * -N1 --> N0 * N1
13490   TargetLowering::NegatibleCost CostN0 =
13491       TargetLowering::NegatibleCost::Expensive;
13492   TargetLowering::NegatibleCost CostN1 =
13493       TargetLowering::NegatibleCost::Expensive;
13494   SDValue NegN0 =
13495       TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
13496   SDValue NegN1 =
13497       TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
13498   if (NegN0 && NegN1 &&
13499       (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
13500        CostN1 == TargetLowering::NegatibleCost::Cheaper))
13501     return DAG.getNode(ISD::FMUL, DL, VT, NegN0, NegN1);
13502 
13503   // fold (fmul X, (select (fcmp X > 0.0), -1.0, 1.0)) -> (fneg (fabs X))
13504   // fold (fmul X, (select (fcmp X > 0.0), 1.0, -1.0)) -> (fabs X)
13505   if (Flags.hasNoNaNs() && Flags.hasNoSignedZeros() &&
13506       (N0.getOpcode() == ISD::SELECT || N1.getOpcode() == ISD::SELECT) &&
13507       TLI.isOperationLegal(ISD::FABS, VT)) {
13508     SDValue Select = N0, X = N1;
13509     if (Select.getOpcode() != ISD::SELECT)
13510       std::swap(Select, X);
13511 
13512     SDValue Cond = Select.getOperand(0);
13513     auto TrueOpnd  = dyn_cast<ConstantFPSDNode>(Select.getOperand(1));
13514     auto FalseOpnd = dyn_cast<ConstantFPSDNode>(Select.getOperand(2));
13515 
13516     if (TrueOpnd && FalseOpnd &&
13517         Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == X &&
13518         isa<ConstantFPSDNode>(Cond.getOperand(1)) &&
13519         cast<ConstantFPSDNode>(Cond.getOperand(1))->isExactlyValue(0.0)) {
13520       ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
13521       switch (CC) {
13522       default: break;
13523       case ISD::SETOLT:
13524       case ISD::SETULT:
13525       case ISD::SETOLE:
13526       case ISD::SETULE:
13527       case ISD::SETLT:
13528       case ISD::SETLE:
13529         std::swap(TrueOpnd, FalseOpnd);
13530         LLVM_FALLTHROUGH;
13531       case ISD::SETOGT:
13532       case ISD::SETUGT:
13533       case ISD::SETOGE:
13534       case ISD::SETUGE:
13535       case ISD::SETGT:
13536       case ISD::SETGE:
13537         if (TrueOpnd->isExactlyValue(-1.0) && FalseOpnd->isExactlyValue(1.0) &&
13538             TLI.isOperationLegal(ISD::FNEG, VT))
13539           return DAG.getNode(ISD::FNEG, DL, VT,
13540                    DAG.getNode(ISD::FABS, DL, VT, X));
13541         if (TrueOpnd->isExactlyValue(1.0) && FalseOpnd->isExactlyValue(-1.0))
13542           return DAG.getNode(ISD::FABS, DL, VT, X);
13543 
13544         break;
13545       }
13546     }
13547   }
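  // Worked instance of the fold above: in
  //   fmul X, (select (setcc X, 0.0, ogt), 1.0, -1.0)
  // a positive X is multiplied by 1.0 and a non-positive X by -1.0, giving
  // fabs(X). X == 0.0 would yield -0.0 instead of +0.0, and a NaN X would
  // compare false, which is why the nsz and nnan flags are required.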
13548 
13549   // FMUL -> FMA combines:
13550   if (SDValue Fused = visitFMULForFMADistributiveCombine(N)) {
13551     AddToWorklist(Fused.getNode());
13552     return Fused;
13553   }
13554 
13555   return SDValue();
13556 }
13557 
13558 SDValue DAGCombiner::visitFMA(SDNode *N) {
13559   SDValue N0 = N->getOperand(0);
13560   SDValue N1 = N->getOperand(1);
13561   SDValue N2 = N->getOperand(2);
13562   ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
13563   ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
13564   EVT VT = N->getValueType(0);
13565   SDLoc DL(N);
13566   const TargetOptions &Options = DAG.getTarget().Options;
13567   // FMA nodes have flags that propagate to the created nodes.
13568   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
13569 
13570   bool UnsafeFPMath =
13571       Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
13572 
13573   // Constant fold FMA.
13574   if (isa<ConstantFPSDNode>(N0) &&
13575       isa<ConstantFPSDNode>(N1) &&
13576       isa<ConstantFPSDNode>(N2)) {
13577     return DAG.getNode(ISD::FMA, DL, VT, N0, N1, N2);
13578   }
13579 
13580   // (-N0 * -N1) + N2 --> (N0 * N1) + N2
13581   TargetLowering::NegatibleCost CostN0 =
13582       TargetLowering::NegatibleCost::Expensive;
13583   TargetLowering::NegatibleCost CostN1 =
13584       TargetLowering::NegatibleCost::Expensive;
13585   SDValue NegN0 =
13586       TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
13587   SDValue NegN1 =
13588       TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
13589   if (NegN0 && NegN1 &&
13590       (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
13591        CostN1 == TargetLowering::NegatibleCost::Cheaper))
13592     return DAG.getNode(ISD::FMA, DL, VT, NegN0, NegN1, N2);
13593 
13594   if (UnsafeFPMath) {
13595     if (N0CFP && N0CFP->isZero())
13596       return N2;
13597     if (N1CFP && N1CFP->isZero())
13598       return N2;
13599   }
13600 
13601   if (N0CFP && N0CFP->isExactlyValue(1.0))
13602     return DAG.getNode(ISD::FADD, SDLoc(N), VT, N1, N2);
13603   if (N1CFP && N1CFP->isExactlyValue(1.0))
13604     return DAG.getNode(ISD::FADD, SDLoc(N), VT, N0, N2);
13605 
13606   // Canonicalize (fma c, x, y) -> (fma x, c, y)
13607   if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
13608      !DAG.isConstantFPBuildVectorOrConstantFP(N1))
13609     return DAG.getNode(ISD::FMA, SDLoc(N), VT, N1, N0, N2);
13610 
13611   if (UnsafeFPMath) {
13612     // (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2)
13613     if (N2.getOpcode() == ISD::FMUL && N0 == N2.getOperand(0) &&
13614         DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
13615         DAG.isConstantFPBuildVectorOrConstantFP(N2.getOperand(1))) {
13616       return DAG.getNode(ISD::FMUL, DL, VT, N0,
13617                          DAG.getNode(ISD::FADD, DL, VT, N1, N2.getOperand(1)));
13618     }
13619 
13620     // (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y)
13621     if (N0.getOpcode() == ISD::FMUL &&
13622         DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
13623         DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) {
13624       return DAG.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
13625                          DAG.getNode(ISD::FMUL, DL, VT, N1, N0.getOperand(1)),
13626                          N2);
13627     }
13628   }
13629 
13630   // (fma x, 1.0, y) -> (fadd x, y), (fma x, -1.0, y) -> (fadd (fneg x), y)
13631   if (N1CFP) {
13632     if (N1CFP->isExactlyValue(1.0))
13633       return DAG.getNode(ISD::FADD, DL, VT, N0, N2);
13634 
13635     if (N1CFP->isExactlyValue(-1.0) &&
13636         (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))) {
13637       SDValue RHSNeg = DAG.getNode(ISD::FNEG, DL, VT, N0);
13638       AddToWorklist(RHSNeg.getNode());
13639       return DAG.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
13640     }
13641 
13642     // fma (fneg x), K, y -> fma x, -K, y
13643     if (N0.getOpcode() == ISD::FNEG &&
13644         (TLI.isOperationLegal(ISD::ConstantFP, VT) ||
13645          (N1.hasOneUse() && !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT,
13646                                               ForCodeSize)))) {
13647       return DAG.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
13648                          DAG.getNode(ISD::FNEG, DL, VT, N1), N2);
13649     }
13650   }
13651 
13652   if (UnsafeFPMath) {
13653     // (fma x, c, x) -> (fmul x, (c+1))
13654     if (N1CFP && N0 == N2) {
13655       return DAG.getNode(
13656           ISD::FMUL, DL, VT, N0,
13657           DAG.getNode(ISD::FADD, DL, VT, N1, DAG.getConstantFP(1.0, DL, VT)));
13658     }
13659 
13660     // (fma x, c, (fneg x)) -> (fmul x, (c-1))
13661     if (N1CFP && N2.getOpcode() == ISD::FNEG && N2.getOperand(0) == N0) {
13662       return DAG.getNode(
13663           ISD::FMUL, DL, VT, N0,
13664           DAG.getNode(ISD::FADD, DL, VT, N1, DAG.getConstantFP(-1.0, DL, VT)));
13665     }
13666   }
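  // Algebra for the unsafe folds above: x*c + x == x*(c + 1.0) and
  // x*c + (-x) == x*(c - 1.0); both forms drop the fused rounding of the
  // FMA, hence the unsafe-math requirement.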
13667 
13668   // fold ((fma (fneg X), Y, (fneg Z)) -> fneg (fma X, Y, Z))
13669   // fold ((fma X, (fneg Y), (fneg Z)) -> fneg (fma X, Y, Z))
13670   if (!TLI.isFNegFree(VT))
13671     if (SDValue Neg = TLI.getCheaperNegatedExpression(
13672             SDValue(N, 0), DAG, LegalOperations, ForCodeSize))
13673       return DAG.getNode(ISD::FNEG, DL, VT, Neg);
13674   return SDValue();
13675 }
13676 
13677 // Combine multiple FDIVs with the same divisor into multiple FMULs by the
13678 // reciprocal.
13679 // E.g., (a / D; b / D;) -> (recip = 1.0 / D; a * recip; b * recip)
13680 // Notice that this is not always beneficial. One reason is different targets
13681 // may have different costs for FDIV and FMUL, so sometimes the cost of two
13682 // FDIVs may be lower than the cost of one FDIV and two FMULs. Another reason
13683 // is the critical path is increased from "one FDIV" to "one FDIV + one FMUL".
13684 SDValue DAGCombiner::combineRepeatedFPDivisors(SDNode *N) {
13685   // TODO: Limit this transform based on optsize/minsize - it always creates at
13686   //       least 1 extra instruction. But the perf win may be substantial enough
13687   //       that only minsize should restrict this.
13688   bool UnsafeMath = DAG.getTarget().Options.UnsafeFPMath;
13689   const SDNodeFlags Flags = N->getFlags();
13690   if (LegalDAG || (!UnsafeMath && !Flags.hasAllowReciprocal()))
13691     return SDValue();
13692 
13693   // Skip if current node is a reciprocal/fneg-reciprocal.
13694   SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
13695   ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, /* AllowUndefs */ true);
13696   if (N0CFP && (N0CFP->isExactlyValue(1.0) || N0CFP->isExactlyValue(-1.0)))
13697     return SDValue();
13698 
13699   // Exit early if the target does not want this transform or if there can't
13700   // possibly be enough uses of the divisor to make the transform worthwhile.
13701   unsigned MinUses = TLI.combineRepeatedFPDivisors();
13702 
13703   // For splat vectors, scale the number of uses by the splat factor. If we can
13704   // convert the division into a scalar op, that will likely be much faster.
13705   unsigned NumElts = 1;
13706   EVT VT = N->getValueType(0);
13707   if (VT.isVector() && DAG.isSplatValue(N1))
13708     NumElts = VT.getVectorNumElements();
13709 
13710   if (!MinUses || (N1->use_size() * NumElts) < MinUses)
13711     return SDValue();
13712 
13713   // Find all FDIV users of the same divisor.
13714   // Use a set because duplicates may be present in the user list.
13715   SetVector<SDNode *> Users;
13716   for (auto *U : N1->uses()) {
13717     if (U->getOpcode() == ISD::FDIV && U->getOperand(1) == N1) {
13718       // Skip X/sqrt(X) that has not been simplified to sqrt(X) yet.
13719       if (U->getOperand(1).getOpcode() == ISD::FSQRT &&
13720           U->getOperand(0) == U->getOperand(1).getOperand(0) &&
13721           U->getFlags().hasAllowReassociation() &&
13722           U->getFlags().hasNoSignedZeros())
13723         continue;
13724 
13725       // This division is eligible for optimization only if global unsafe math
13726       // is enabled or if this division allows reciprocal formation.
13727       if (UnsafeMath || U->getFlags().hasAllowReciprocal())
13728         Users.insert(U);
13729     }
13730   }
13731 
13732   // Now that we have the actual number of divisor uses, make sure it meets
13733   // the minimum threshold specified by the target.
13734   if ((Users.size() * NumElts) < MinUses)
13735     return SDValue();
13736 
13737   SDLoc DL(N);
13738   SDValue FPOne = DAG.getConstantFP(1.0, DL, VT);
13739   SDValue Reciprocal = DAG.getNode(ISD::FDIV, DL, VT, FPOne, N1, Flags);
13740 
13741   // Dividend / Divisor -> Dividend * Reciprocal
13742   for (auto *U : Users) {
13743     SDValue Dividend = U->getOperand(0);
13744     if (Dividend != FPOne) {
13745       SDValue NewNode = DAG.getNode(ISD::FMUL, SDLoc(U), VT, Dividend,
13746                                     Reciprocal, Flags);
13747       CombineTo(U, NewNode);
13748     } else if (U != Reciprocal.getNode()) {
13749       // In the absence of fast-math-flags, this user node is always the
13750       // same node as Reciprocal, but with FMF they may be different nodes.
13751       CombineTo(U, Reciprocal);
13752     }
13753   }
13754   return SDValue(N, 0);  // N was replaced.
13755 }
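// Illustrative shape of the rewrite performed above, assuming the use count
// meets the target's threshold:
//   t0 = fdiv a, D                r  = fdiv 1.0, D
//   t1 = fdiv b, D       -->      t0 = fmul a, r
//                                 t1 = fmul b, r
// Only one real division remains no matter how many dividends share D.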
13756 
13757 SDValue DAGCombiner::visitFDIV(SDNode *N) {
13758   SDValue N0 = N->getOperand(0);
13759   SDValue N1 = N->getOperand(1);
13760   ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
13761   ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
13762   EVT VT = N->getValueType(0);
13763   SDLoc DL(N);
13764   const TargetOptions &Options = DAG.getTarget().Options;
13765   SDNodeFlags Flags = N->getFlags();
13766   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
13767 
13768   if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
13769     return R;
13770 
13771   // fold vector ops
13772   if (VT.isVector())
13773     if (SDValue FoldedVOp = SimplifyVBinOp(N))
13774       return FoldedVOp;
13775 
13776   // fold (fdiv c1, c2) -> c1/c2
13777   if (N0CFP && N1CFP)
13778     return DAG.getNode(ISD::FDIV, SDLoc(N), VT, N0, N1);
13779 
13780   if (SDValue NewSel = foldBinOpIntoSelect(N))
13781     return NewSel;
13782 
13783   if (SDValue V = combineRepeatedFPDivisors(N))
13784     return V;
13785 
13786   if (Options.UnsafeFPMath || Flags.hasAllowReciprocal()) {
13787     // fold (fdiv X, c2) -> fmul X, 1/c2 if losing precision is acceptable.
13788     if (N1CFP) {
13789       // Compute the reciprocal 1.0 / c2.
13790       const APFloat &N1APF = N1CFP->getValueAPF();
13791       APFloat Recip(N1APF.getSemantics(), 1); // 1.0
13792       APFloat::opStatus st = Recip.divide(N1APF, APFloat::rmNearestTiesToEven);
13793       // Only do the transform if the reciprocal is a legal fp immediate that
13794       // isn't too nasty (e.g. NaN, denormal, ...).
13795       if ((st == APFloat::opOK || st == APFloat::opInexact) && // Not too nasty
13796           (!LegalOperations ||
13797            // FIXME: custom lowering of ConstantFP might fail (see e.g. ARM
13798            // backend)... we should handle this gracefully after Legalize.
13799            // TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT) ||
13800            TLI.isOperationLegal(ISD::ConstantFP, VT) ||
13801            TLI.isFPImmLegal(Recip, VT, ForCodeSize)))
13802         return DAG.getNode(ISD::FMUL, DL, VT, N0,
13803                            DAG.getConstantFP(Recip, DL, VT));
13804     }
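    // For example, X / 4.0 becomes X * 0.25 with an exact reciprocal (opOK),
    // while X / 3.0 becomes X * (1.0/3.0) with an inexact one (opInexact);
    // the latter is acceptable only under the reciprocal/unsafe gates of
    // this block.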
13805 
13806     // If this FDIV is part of a reciprocal square root, it may be folded
13807     // into a target-specific square root estimate instruction.
13808     if (N1.getOpcode() == ISD::FSQRT) {
13809       if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0), Flags))
13810         return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
13811     } else if (N1.getOpcode() == ISD::FP_EXTEND &&
13812                N1.getOperand(0).getOpcode() == ISD::FSQRT) {
13813       if (SDValue RV =
13814               buildRsqrtEstimate(N1.getOperand(0).getOperand(0), Flags)) {
13815         RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(N1), VT, RV);
13816         AddToWorklist(RV.getNode());
13817         return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
13818       }
13819     } else if (N1.getOpcode() == ISD::FP_ROUND &&
13820                N1.getOperand(0).getOpcode() == ISD::FSQRT) {
13821       if (SDValue RV =
13822               buildRsqrtEstimate(N1.getOperand(0).getOperand(0), Flags)) {
13823         RV = DAG.getNode(ISD::FP_ROUND, SDLoc(N1), VT, RV, N1.getOperand(1));
13824         AddToWorklist(RV.getNode());
13825         return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
13826       }
13827     } else if (N1.getOpcode() == ISD::FMUL) {
13828       // Look through an FMUL. Even though this won't remove the FDIV directly,
13829       // it's still worthwhile to get rid of the FSQRT if possible.
13830       SDValue Sqrt, Y;
13831       if (N1.getOperand(0).getOpcode() == ISD::FSQRT) {
13832         Sqrt = N1.getOperand(0);
13833         Y = N1.getOperand(1);
13834       } else if (N1.getOperand(1).getOpcode() == ISD::FSQRT) {
13835         Sqrt = N1.getOperand(1);
13836         Y = N1.getOperand(0);
13837       }
13838       if (Sqrt.getNode()) {
13839         // If the other multiply operand is known positive, pull it into the
13840         // sqrt. That will eliminate the division if we convert to an estimate.
13841         if (Flags.hasAllowReassociation() && N1.hasOneUse() &&
13842             N1->getFlags().hasAllowReassociation() && Sqrt.hasOneUse()) {
13843           SDValue A;
13844           if (Y.getOpcode() == ISD::FABS && Y.hasOneUse())
13845             A = Y.getOperand(0);
13846           else if (Y == Sqrt.getOperand(0))
13847             A = Y;
13848           if (A) {
13849             // X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z)
13850             // X / (A * sqrt(A))       --> X / sqrt(A*A*A) --> X * rsqrt(A*A*A)
13851             SDValue AA = DAG.getNode(ISD::FMUL, DL, VT, A, A);
13852             SDValue AAZ =
13853                 DAG.getNode(ISD::FMUL, DL, VT, AA, Sqrt.getOperand(0));
13854             if (SDValue Rsqrt = buildRsqrtEstimate(AAZ, Flags))
13855               return DAG.getNode(ISD::FMUL, DL, VT, N0, Rsqrt);
13856 
13857             // Estimate creation failed. Clean up speculatively created nodes.
13858             recursivelyDeleteUnusedNodes(AAZ.getNode());
13859           }
13860         }
13861 
13862         // We found a FSQRT, so try to make this fold:
13863         // X / (Y * sqrt(Z)) -> X * (rsqrt(Z) / Y)
13864         if (SDValue Rsqrt = buildRsqrtEstimate(Sqrt.getOperand(0), Flags)) {
13865           SDValue Div = DAG.getNode(ISD::FDIV, SDLoc(N1), VT, Rsqrt, Y);
13866           AddToWorklist(Div.getNode());
13867           return DAG.getNode(ISD::FMUL, DL, VT, N0, Div);
13868         }
13869       }
13870     }
13871 
13872     // Fold into a reciprocal estimate and multiply instead of a real divide.
13873     if (Options.NoInfsFPMath || Flags.hasNoInfs())
13874       if (SDValue RV = BuildDivEstimate(N0, N1, Flags))
13875         return RV;
13876   }
13877 
13878   // Fold X/Sqrt(X) -> Sqrt(X)
13879   if ((Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros()) &&
13880       (Options.UnsafeFPMath || Flags.hasAllowReassociation()))
13881     if (N1.getOpcode() == ISD::FSQRT && N0 == N1.getOperand(0))
13882       return N1;
13883 
13884   // (fdiv (fneg X), (fneg Y)) -> (fdiv X, Y)
13885   TargetLowering::NegatibleCost CostN0 =
13886       TargetLowering::NegatibleCost::Expensive;
13887   TargetLowering::NegatibleCost CostN1 =
13888       TargetLowering::NegatibleCost::Expensive;
13889   SDValue NegN0 =
13890       TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
13891   SDValue NegN1 =
13892       TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
13893   if (NegN0 && NegN1 &&
13894       (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
13895        CostN1 == TargetLowering::NegatibleCost::Cheaper))
13896     return DAG.getNode(ISD::FDIV, SDLoc(N), VT, NegN0, NegN1);
13897 
13898   return SDValue();
13899 }
13900 
13901 SDValue DAGCombiner::visitFREM(SDNode *N) {
13902   SDValue N0 = N->getOperand(0);
13903   SDValue N1 = N->getOperand(1);
13904   ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
13905   ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
13906   EVT VT = N->getValueType(0);
13907   SDNodeFlags Flags = N->getFlags();
13908   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
13909 
13910   if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
13911     return R;
13912 
13913   // fold (frem c1, c2) -> fmod(c1,c2)
13914   if (N0CFP && N1CFP)
13915     return DAG.getNode(ISD::FREM, SDLoc(N), VT, N0, N1);
13916 
13917   if (SDValue NewSel = foldBinOpIntoSelect(N))
13918     return NewSel;
13919 
13920   return SDValue();
13921 }
13922 
13923 SDValue DAGCombiner::visitFSQRT(SDNode *N) {
13924   SDNodeFlags Flags = N->getFlags();
13925   const TargetOptions &Options = DAG.getTarget().Options;
13926 
13927   // Require 'ninf' flag since sqrt(+Inf) = +Inf, but the estimation goes as:
13928   // sqrt(+Inf) == rsqrt(+Inf) * +Inf = 0 * +Inf = NaN
13929   if (!Flags.hasApproximateFuncs() ||
13930       (!Options.NoInfsFPMath && !Flags.hasNoInfs()))
13931     return SDValue();
13932 
13933   SDValue N0 = N->getOperand(0);
13934   if (TLI.isFsqrtCheap(N0, DAG))
13935     return SDValue();
13936 
13937   // FSQRT nodes have flags that propagate to the created nodes.
13938   // TODO: If this is N0/sqrt(N0), and we reach this node before trying to
13939   //       transform the fdiv, we may produce a sub-optimal estimate sequence
13940   //       because the reciprocal calculation may not have to filter out a
13941   //       0.0 input.
13942   return buildSqrtEstimate(N0, Flags);
13943 }
13944 
13945 /// copysign(x, fp_extend(y)) -> copysign(x, y)
13946 /// copysign(x, fp_round(y)) -> copysign(x, y)
13947 static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(SDNode *N) {
13948   SDValue N1 = N->getOperand(1);
13949   if ((N1.getOpcode() == ISD::FP_EXTEND ||
13950        N1.getOpcode() == ISD::FP_ROUND)) {
13951     // Do not optimize out type conversion of f128 type yet.
13952     // For some targets like x86_64, configuration is changed to keep one f128
13953     // value in one SSE register, but instruction selection cannot handle
13954     // FCOPYSIGN on SSE registers yet.
13955     EVT N1VT = N1->getValueType(0);
13956     EVT N1Op0VT = N1->getOperand(0).getValueType();
13957     return (N1VT == N1Op0VT || N1Op0VT != MVT::f128);
13958   }
13959   return false;
13960 }
13961 
13962 SDValue DAGCombiner::visitFCOPYSIGN(SDNode *N) {
13963   SDValue N0 = N->getOperand(0);
13964   SDValue N1 = N->getOperand(1);
13965   bool N0CFP = DAG.isConstantFPBuildVectorOrConstantFP(N0);
13966   bool N1CFP = DAG.isConstantFPBuildVectorOrConstantFP(N1);
13967   EVT VT = N->getValueType(0);
13968 
13969   if (N0CFP && N1CFP) // Constant fold
13970     return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0, N1);
13971 
13972   if (ConstantFPSDNode *N1C = isConstOrConstSplatFP(N->getOperand(1))) {
13973     const APFloat &V = N1C->getValueAPF();
13974     // copysign(x, c1) -> fabs(x)       iff ispos(c1)
13975     // copysign(x, c1) -> fneg(fabs(x)) iff isneg(c1)
13976     if (!V.isNegative()) {
13977       if (!LegalOperations || TLI.isOperationLegal(ISD::FABS, VT))
13978         return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0);
13979     } else {
13980       if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))
13981         return DAG.getNode(ISD::FNEG, SDLoc(N), VT,
13982                            DAG.getNode(ISD::FABS, SDLoc(N0), VT, N0));
13983     }
13984   }
13985 
13986   // copysign(fabs(x), y) -> copysign(x, y)
13987   // copysign(fneg(x), y) -> copysign(x, y)
13988   // copysign(copysign(x,z), y) -> copysign(x, y)
13989   if (N0.getOpcode() == ISD::FABS || N0.getOpcode() == ISD::FNEG ||
13990       N0.getOpcode() == ISD::FCOPYSIGN)
13991     return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0.getOperand(0), N1);
13992 
13993   // copysign(x, abs(y)) -> abs(x)
13994   if (N1.getOpcode() == ISD::FABS)
13995     return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0);
13996 
13997   // copysign(x, copysign(y,z)) -> copysign(x, z)
13998   if (N1.getOpcode() == ISD::FCOPYSIGN)
13999     return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0, N1.getOperand(1));
14000 
14001   // copysign(x, fp_extend(y)) -> copysign(x, y)
14002   // copysign(x, fp_round(y)) -> copysign(x, y)
14003   if (CanCombineFCOPYSIGN_EXTEND_ROUND(N))
14004     return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0, N1.getOperand(0));
14005 
14006   return SDValue();
14007 }
14008 
14009 SDValue DAGCombiner::visitFPOW(SDNode *N) {
14010   ConstantFPSDNode *ExponentC = isConstOrConstSplatFP(N->getOperand(1));
14011   if (!ExponentC)
14012     return SDValue();
14013   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
14014 
14015   // Try to convert x ** (1/3) into cube root.
14016   // TODO: Handle the various flavors of long double.
14017   // TODO: Since we're approximating, we don't need an exact 1/3 exponent.
14018   //       Some range near 1/3 should be fine.
14019   EVT VT = N->getValueType(0);
14020   if ((VT == MVT::f32 && ExponentC->getValueAPF().isExactlyValue(1.0f/3.0f)) ||
14021       (VT == MVT::f64 && ExponentC->getValueAPF().isExactlyValue(1.0/3.0))) {
14022     // pow(-0.0, 1/3) = +0.0; cbrt(-0.0) = -0.0.
14023     // pow(-inf, 1/3) = +inf; cbrt(-inf) = -inf.
14024     // pow(-val, 1/3) =  nan; cbrt(-val) = -num.
14025     // For regular numbers, rounding may cause the results to differ.
14026     // Therefore, we require { nsz ninf nnan afn } for this transform.
14027     // TODO: We could select out the special cases if we don't have nsz/ninf.
14028     SDNodeFlags Flags = N->getFlags();
14029     if (!Flags.hasNoSignedZeros() || !Flags.hasNoInfs() || !Flags.hasNoNaNs() ||
14030         !Flags.hasApproximateFuncs())
14031       return SDValue();
14032 
14033     // Do not create a cbrt() libcall if the target does not have it, and do not
14034     // turn a pow that has lowering support into a cbrt() libcall.
14035     if (!DAG.getLibInfo().has(LibFunc_cbrt) ||
14036         (!DAG.getTargetLoweringInfo().isOperationExpand(ISD::FPOW, VT) &&
14037          DAG.getTargetLoweringInfo().isOperationExpand(ISD::FCBRT, VT)))
14038       return SDValue();
14039 
14040     return DAG.getNode(ISD::FCBRT, SDLoc(N), VT, N->getOperand(0));
14041   }
14042 
14043   // Try to convert x ** (1/4) and x ** (3/4) into square roots.
14044   // x ** (1/2) is canonicalized to sqrt, so we do not bother with that case.
14045   // TODO: This could be extended (using a target hook) to handle smaller
14046   // power-of-2 fractional exponents.
14047   bool ExponentIs025 = ExponentC->getValueAPF().isExactlyValue(0.25);
14048   bool ExponentIs075 = ExponentC->getValueAPF().isExactlyValue(0.75);
14049   if (ExponentIs025 || ExponentIs075) {
14050     // pow(-0.0, 0.25) = +0.0; sqrt(sqrt(-0.0)) = -0.0.
14051     // pow(-inf, 0.25) = +inf; sqrt(sqrt(-inf)) =  NaN.
14052     // pow(-0.0, 0.75) = +0.0; sqrt(-0.0) * sqrt(sqrt(-0.0)) = +0.0.
14053     // pow(-inf, 0.75) = +inf; sqrt(-inf) * sqrt(sqrt(-inf)) =  NaN.
14054     // For regular numbers, rounding may cause the results to differ.
14055     // Therefore, we require { nsz ninf afn } for this transform.
14056     // TODO: We could select out the special cases if we don't have nsz/ninf.
14057     SDNodeFlags Flags = N->getFlags();
14058 
14059     // We only need no signed zeros for the 0.25 case.
14060     if ((!Flags.hasNoSignedZeros() && ExponentIs025) || !Flags.hasNoInfs() ||
14061         !Flags.hasApproximateFuncs())
14062       return SDValue();
14063 
14064     // Don't double the number of libcalls. We are trying to inline fast code.
14065     if (!DAG.getTargetLoweringInfo().isOperationLegalOrCustom(ISD::FSQRT, VT))
14066       return SDValue();
14067 
14068     // Assume that libcalls are the smallest code.
14069     // TODO: This restriction should probably be lifted for vectors.
14070     if (ForCodeSize)
14071       return SDValue();
14072 
14073     // pow(X, 0.25) --> sqrt(sqrt(X))
14074     SDLoc DL(N);
14075     SDValue Sqrt = DAG.getNode(ISD::FSQRT, DL, VT, N->getOperand(0));
14076     SDValue SqrtSqrt = DAG.getNode(ISD::FSQRT, DL, VT, Sqrt);
14077     if (ExponentIs025)
14078       return SqrtSqrt;
14079     // pow(X, 0.75) --> sqrt(X) * sqrt(sqrt(X))
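    // This follows from X ** 0.75 == X ** (1/2) * X ** (1/4)
    //                             == sqrt(X) * sqrt(sqrt(X)).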
14080     return DAG.getNode(ISD::FMUL, DL, VT, Sqrt, SqrtSqrt);
14081   }
14082 
14083   return SDValue();
14084 }
14085 
14086 static SDValue foldFPToIntToFP(SDNode *N, SelectionDAG &DAG,
14087                                const TargetLowering &TLI) {
14088   // This optimization is guarded by a function attribute because it may produce
14089   // unexpected results. I.e., programs may be relying on the platform-specific
14090   // undefined behavior when the float-to-int conversion overflows.
14091   const Function &F = DAG.getMachineFunction().getFunction();
14092   Attribute StrictOverflow = F.getFnAttribute("strict-float-cast-overflow");
14093   if (StrictOverflow.getValueAsString().equals("false"))
14094     return SDValue();
14095 
14096   // We only do this if the target has legal ftrunc. Otherwise, we'd likely be
14097   // replacing casts with a libcall. We also must be allowed to ignore -0.0
14098   // because FTRUNC will return -0.0 for (-1.0, -0.0), but using integer
14099   // conversions would return +0.0.
14100   // FIXME: We should be able to use node-level FMF here.
14101   // TODO: If strict math, should we use FABS (+ range check for signed cast)?
14102   EVT VT = N->getValueType(0);
14103   if (!TLI.isOperationLegal(ISD::FTRUNC, VT) ||
14104       !DAG.getTarget().Options.NoSignedZerosFPMath)
14105     return SDValue();
14106 
14107   // fptosi/fptoui round towards zero, so converting from FP to integer and
14108   // back is the same as an 'ftrunc': [us]itofp (fpto[us]i X) --> ftrunc X
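  // For example, sitofp (fptosi -5.7) rounds toward zero to -5 and converts
  // back to -5.0, and ftrunc -5.7 is also -5.0 (assuming no overflow).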
14109   SDValue N0 = N->getOperand(0);
14110   if (N->getOpcode() == ISD::SINT_TO_FP && N0.getOpcode() == ISD::FP_TO_SINT &&
14111       N0.getOperand(0).getValueType() == VT)
14112     return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0.getOperand(0));
14113 
14114   if (N->getOpcode() == ISD::UINT_TO_FP && N0.getOpcode() == ISD::FP_TO_UINT &&
14115       N0.getOperand(0).getValueType() == VT)
14116     return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0.getOperand(0));
14117 
14118   return SDValue();
14119 }
14120 
14121 SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) {
14122   SDValue N0 = N->getOperand(0);
14123   EVT VT = N->getValueType(0);
14124   EVT OpVT = N0.getValueType();
14125 
14126   // [us]itofp(undef) = 0, because the result value is bounded.
14127   if (N0.isUndef())
14128     return DAG.getConstantFP(0.0, SDLoc(N), VT);
14129 
14130   // fold (sint_to_fp c1) -> c1fp
14131   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
14132       // ...but only if the target supports immediate floating-point values
14133       (!LegalOperations ||
14134        TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
14135     return DAG.getNode(ISD::SINT_TO_FP, SDLoc(N), VT, N0);
14136 
14137   // If the input is a legal type, and SINT_TO_FP is not legal on this target,
14138   // but UINT_TO_FP is legal on this target, try to convert.
14139   if (!hasOperation(ISD::SINT_TO_FP, OpVT) &&
14140       hasOperation(ISD::UINT_TO_FP, OpVT)) {
14141     // If the sign bit is known to be zero, we can change this to UINT_TO_FP.
14142     if (DAG.SignBitIsZero(N0))
14143       return DAG.getNode(ISD::UINT_TO_FP, SDLoc(N), VT, N0);
14144   }
14145 
14146   // The next optimizations are desirable only if SELECT_CC can be lowered.
14147   // fold (sint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), -1.0, 0.0)
14148   if (N0.getOpcode() == ISD::SETCC && N0.getValueType() == MVT::i1 &&
14149       !VT.isVector() &&
14150       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
14151     SDLoc DL(N);
14152     return DAG.getSelect(DL, VT, N0, DAG.getConstantFP(-1.0, DL, VT),
14153                          DAG.getConstantFP(0.0, DL, VT));
14154   }
14155 
14156   // fold (sint_to_fp (zext (setcc x, y, cc))) ->
14157   //      (select (setcc x, y, cc), 1.0, 0.0)
14158   if (N0.getOpcode() == ISD::ZERO_EXTEND &&
14159       N0.getOperand(0).getOpcode() == ISD::SETCC && !VT.isVector() &&
14160       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
14161     SDLoc DL(N);
14162     return DAG.getSelect(DL, VT, N0.getOperand(0),
14163                          DAG.getConstantFP(1.0, DL, VT),
14164                          DAG.getConstantFP(0.0, DL, VT));
14165   }
14166 
14167   if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI))
14168     return FTrunc;
14169 
14170   return SDValue();
14171 }
14172 
14173 SDValue DAGCombiner::visitUINT_TO_FP(SDNode *N) {
14174   SDValue N0 = N->getOperand(0);
14175   EVT VT = N->getValueType(0);
14176   EVT OpVT = N0.getValueType();
14177 
14178   // [us]itofp(undef) = 0, because the result value is bounded.
14179   if (N0.isUndef())
14180     return DAG.getConstantFP(0.0, SDLoc(N), VT);
14181 
14182   // fold (uint_to_fp c1) -> c1fp
14183   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
14184       // ...but only if the target supports immediate floating-point values
14185       (!LegalOperations ||
14186        TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
14187     return DAG.getNode(ISD::UINT_TO_FP, SDLoc(N), VT, N0);
14188 
14189   // If the input is a legal type, and UINT_TO_FP is not legal on this target,
14190   // but SINT_TO_FP is legal on this target, try to convert.
14191   if (!hasOperation(ISD::UINT_TO_FP, OpVT) &&
14192       hasOperation(ISD::SINT_TO_FP, OpVT)) {
14193     // If the sign bit is known to be zero, we can change this to SINT_TO_FP.
14194     if (DAG.SignBitIsZero(N0))
14195       return DAG.getNode(ISD::SINT_TO_FP, SDLoc(N), VT, N0);
14196   }
14197 
14198   // fold (uint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), 1.0, 0.0)
14199   if (N0.getOpcode() == ISD::SETCC && !VT.isVector() &&
14200       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
14201     SDLoc DL(N);
14202     return DAG.getSelect(DL, VT, N0, DAG.getConstantFP(1.0, DL, VT),
14203                          DAG.getConstantFP(0.0, DL, VT));
14204   }
14205 
14206   if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI))
14207     return FTrunc;
14208 
14209   return SDValue();
14210 }
14211 
14212 // Fold (fp_to_{s/u}int ({s/u}int_to_fp x)) -> zext x, sext x, trunc x, or x
14213 static SDValue FoldIntToFPToInt(SDNode *N, SelectionDAG &DAG) {
14214   SDValue N0 = N->getOperand(0);
14215   EVT VT = N->getValueType(0);
14216 
14217   if (N0.getOpcode() != ISD::UINT_TO_FP && N0.getOpcode() != ISD::SINT_TO_FP)
14218     return SDValue();
14219 
14220   SDValue Src = N0.getOperand(0);
14221   EVT SrcVT = Src.getValueType();
14222   bool IsInputSigned = N0.getOpcode() == ISD::SINT_TO_FP;
14223   bool IsOutputSigned = N->getOpcode() == ISD::FP_TO_SINT;
14224 
14225   // We can safely assume the conversion won't overflow the output range,
14226   // because (for example) (uint8_t)18293.f is undefined behavior.
14227 
14228   // Since we can assume the conversion won't overflow, our decision as to
14229   // whether the input will fit in the float should depend on the minimum
14230   // of the input range and output range.
14231 
14232   // This means this is also safe for a signed input and unsigned output, since
14233   // a negative input would lead to undefined behavior.
14234   unsigned InputSize = (int)SrcVT.getScalarSizeInBits() - IsInputSigned;
14235   unsigned OutputSize = (int)VT.getScalarSizeInBits() - IsOutputSigned;
14236   unsigned ActualSize = std::min(InputSize, OutputSize);
14237   const fltSemantics &sem = DAG.EVTToAPFloatSemantics(N0.getValueType());
14238 
14239   // We can only fold away the float conversion if the input range can be
14240   // represented exactly in the float range.
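  // For example, f32 carries 24 bits of precision, so a signed i16 input
  // (15 value bits) round-trips exactly and can be folded, while a signed
  // i32 input (31 value bits) cannot.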
14241   if (APFloat::semanticsPrecision(sem) >= ActualSize) {
14242     if (VT.getScalarSizeInBits() > SrcVT.getScalarSizeInBits()) {
14243       unsigned ExtOp = IsInputSigned && IsOutputSigned ? ISD::SIGN_EXTEND
14244                                                        : ISD::ZERO_EXTEND;
14245       return DAG.getNode(ExtOp, SDLoc(N), VT, Src);
14246     }
14247     if (VT.getScalarSizeInBits() < SrcVT.getScalarSizeInBits())
14248       return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Src);
14249     return DAG.getBitcast(VT, Src);
14250   }
14251   return SDValue();
14252 }
14253 
14254 SDValue DAGCombiner::visitFP_TO_SINT(SDNode *N) {
14255   SDValue N0 = N->getOperand(0);
14256   EVT VT = N->getValueType(0);
14257 
14258   // fold (fp_to_sint undef) -> undef
14259   if (N0.isUndef())
14260     return DAG.getUNDEF(VT);
14261 
14262   // fold (fp_to_sint c1fp) -> c1
14263   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
14264     return DAG.getNode(ISD::FP_TO_SINT, SDLoc(N), VT, N0);
14265 
14266   return FoldIntToFPToInt(N, DAG);
14267 }
14268 
14269 SDValue DAGCombiner::visitFP_TO_UINT(SDNode *N) {
14270   SDValue N0 = N->getOperand(0);
14271   EVT VT = N->getValueType(0);
14272 
14273   // fold (fp_to_uint undef) -> undef
14274   if (N0.isUndef())
14275     return DAG.getUNDEF(VT);
14276 
14277   // fold (fp_to_uint c1fp) -> c1
14278   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
14279     return DAG.getNode(ISD::FP_TO_UINT, SDLoc(N), VT, N0);
14280 
14281   return FoldIntToFPToInt(N, DAG);
14282 }
14283 
14284 SDValue DAGCombiner::visitFP_ROUND(SDNode *N) {
14285   SDValue N0 = N->getOperand(0);
14286   SDValue N1 = N->getOperand(1);
14287   ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
14288   EVT VT = N->getValueType(0);
14289 
14290   // fold (fp_round c1fp) -> c1fp
14291   if (N0CFP)
14292     return DAG.getNode(ISD::FP_ROUND, SDLoc(N), VT, N0, N1);
14293 
14294   // fold (fp_round (fp_extend x)) -> x
14295   if (N0.getOpcode() == ISD::FP_EXTEND && VT == N0.getOperand(0).getValueType())
14296     return N0.getOperand(0);
14297 
14298   // fold (fp_round (fp_round x)) -> (fp_round x)
14299   if (N0.getOpcode() == ISD::FP_ROUND) {
14300     const bool NIsTrunc = N->getConstantOperandVal(1) == 1;
14301     const bool N0IsTrunc = N0.getConstantOperandVal(1) == 1;
14302 
14303     // Skip this folding if it results in an fp_round from f80 to f16.
14304     //
14305     // f80 to f16 always generates an expensive (and as yet, unimplemented)
14306     // libcall to __truncxfhf2 instead of selecting native f16 conversion
14307     // instructions from f32 or f64.  Moreover, the first (value-preserving)
14308     // fp_round from f80 to either f32 or f64 may become a NOP in platforms like
14309     // x86.
14310     if (N0.getOperand(0).getValueType() == MVT::f80 && VT == MVT::f16)
14311       return SDValue();
14312 
14313     // If the first fp_round isn't a value preserving truncation, it might
14314     // introduce a tie in the second fp_round, that wouldn't occur in the
14315     // single-step fp_round we want to fold to.
14316     // In other words, double rounding isn't the same as rounding.
14317     // Also, this is a value preserving truncation iff both fp_round's are.
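    // For example, under round-to-nearest-even, a value just above the final
    // type's tie point can round down to the tie in the intermediate type and
    // then round to even, whereas a single rounding would have rounded up.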
14318     if (DAG.getTarget().Options.UnsafeFPMath || N0IsTrunc) {
14319       SDLoc DL(N);
14320       return DAG.getNode(ISD::FP_ROUND, DL, VT, N0.getOperand(0),
14321                          DAG.getIntPtrConstant(NIsTrunc && N0IsTrunc, DL));
14322     }
14323   }
14324 
14325   // fold (fp_round (copysign X, Y)) -> (copysign (fp_round X), Y)
14326   if (N0.getOpcode() == ISD::FCOPYSIGN && N0.getNode()->hasOneUse()) {
14327     SDValue Tmp = DAG.getNode(ISD::FP_ROUND, SDLoc(N0), VT,
14328                               N0.getOperand(0), N1);
14329     AddToWorklist(Tmp.getNode());
14330     return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT,
14331                        Tmp, N0.getOperand(1));
14332   }
14333 
14334   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
14335     return NewVSel;
14336 
14337   return SDValue();
14338 }
14339 
14340 SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) {
14341   SDValue N0 = N->getOperand(0);
14342   EVT VT = N->getValueType(0);
14343 
14344   // If this is fp_round(fp_extend(x)), don't fold it; allow ourselves to be folded.
14345   if (N->hasOneUse() &&
14346       N->use_begin()->getOpcode() == ISD::FP_ROUND)
14347     return SDValue();
14348 
14349   // fold (fp_extend c1fp) -> c1fp
14350   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
14351     return DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, N0);
14352 
14353   // fold (fp_extend (fp16_to_fp op)) -> (fp16_to_fp op)
14354   if (N0.getOpcode() == ISD::FP16_TO_FP &&
14355       TLI.getOperationAction(ISD::FP16_TO_FP, VT) == TargetLowering::Legal)
14356     return DAG.getNode(ISD::FP16_TO_FP, SDLoc(N), VT, N0.getOperand(0));
14357 
14358   // Turn fp_extend(fp_round(X, 1)) -> X since the fp_round doesn't affect the
14359   // value of X.
14360   if (N0.getOpcode() == ISD::FP_ROUND
14361       && N0.getConstantOperandVal(1) == 1) {
14362     SDValue In = N0.getOperand(0);
14363     if (In.getValueType() == VT) return In;
14364     if (VT.bitsLT(In.getValueType()))
14365       return DAG.getNode(ISD::FP_ROUND, SDLoc(N), VT,
14366                          In, N0.getOperand(1));
14367     return DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, In);
14368   }
14369 
14370   // fold (fpext (load x)) -> (fpext (fptrunc (extload x)))
14371   if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
14372        TLI.isLoadExtLegal(ISD::EXTLOAD, VT, N0.getValueType())) {
14373     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
14374     SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, SDLoc(N), VT,
14375                                      LN0->getChain(),
14376                                      LN0->getBasePtr(), N0.getValueType(),
14377                                      LN0->getMemOperand());
14378     CombineTo(N, ExtLoad);
14379     CombineTo(N0.getNode(),
14380               DAG.getNode(ISD::FP_ROUND, SDLoc(N0),
14381                           N0.getValueType(), ExtLoad,
14382                           DAG.getIntPtrConstant(1, SDLoc(N0))),
14383               ExtLoad.getValue(1));
14384     return SDValue(N, 0);   // Return N so it doesn't get rechecked!
14385   }
14386 
14387   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
14388     return NewVSel;
14389 
14390   return SDValue();
14391 }
14392 
14393 SDValue DAGCombiner::visitFCEIL(SDNode *N) {
14394   SDValue N0 = N->getOperand(0);
14395   EVT VT = N->getValueType(0);
14396 
14397   // fold (fceil c1) -> fceil(c1)
14398   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
14399     return DAG.getNode(ISD::FCEIL, SDLoc(N), VT, N0);
14400 
14401   return SDValue();
14402 }
14403 
14404 SDValue DAGCombiner::visitFTRUNC(SDNode *N) {
14405   SDValue N0 = N->getOperand(0);
14406   EVT VT = N->getValueType(0);
14407 
14408   // fold (ftrunc c1) -> ftrunc(c1)
14409   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
14410     return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0);
14411 
14412   // fold ftrunc (known rounded int x) -> x
14413   // ftrunc is a part of fptosi/fptoui expansion on some targets, so this is
14414   // likely to be generated to extract integer from a rounded floating value.
14415   switch (N0.getOpcode()) {
14416   default: break;
14417   case ISD::FRINT:
14418   case ISD::FTRUNC:
14419   case ISD::FNEARBYINT:
14420   case ISD::FFLOOR:
14421   case ISD::FCEIL:
14422     return N0;
14423   }
14424 
14425   return SDValue();
14426 }
14427 
14428 SDValue DAGCombiner::visitFFLOOR(SDNode *N) {
14429   SDValue N0 = N->getOperand(0);
14430   EVT VT = N->getValueType(0);
14431 
14432   // fold (ffloor c1) -> ffloor(c1)
14433   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
14434     return DAG.getNode(ISD::FFLOOR, SDLoc(N), VT, N0);
14435 
14436   return SDValue();
14437 }
14438 
14439 SDValue DAGCombiner::visitFNEG(SDNode *N) {
14440   SDValue N0 = N->getOperand(0);
14441   EVT VT = N->getValueType(0);
14442   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
14443 
14444   // Constant fold FNEG.
14445   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
14446     return DAG.getNode(ISD::FNEG, SDLoc(N), VT, N0);
14447 
14448   if (SDValue NegN0 =
14449           TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize))
14450     return NegN0;
14451 
14452   // -(X-Y) -> (Y-X) is unsafe because when X==Y, -0.0 != +0.0
14453   // FIXME: This is duplicated in getNegatibleCost, but getNegatibleCost doesn't
14454   // know it was called from a context with a nsz flag if the input fsub does
14455   // not.
14456   if (N0.getOpcode() == ISD::FSUB &&
14457       (DAG.getTarget().Options.NoSignedZerosFPMath ||
14458        N->getFlags().hasNoSignedZeros()) && N0.hasOneUse()) {
14459     return DAG.getNode(ISD::FSUB, SDLoc(N), VT, N0.getOperand(1),
14460                        N0.getOperand(0));
14461   }
14462 
14463   if (SDValue Cast = foldSignChangeInBitcast(N))
14464     return Cast;
14465 
14466   return SDValue();
14467 }
14468 
14469 static SDValue visitFMinMax(SelectionDAG &DAG, SDNode *N,
14470                             APFloat (*Op)(const APFloat &, const APFloat &)) {
14471   SDValue N0 = N->getOperand(0);
14472   SDValue N1 = N->getOperand(1);
14473   EVT VT = N->getValueType(0);
14474   const ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0);
14475   const ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1);
14476   const SDNodeFlags Flags = N->getFlags();
14477   unsigned Opc = N->getOpcode();
14478   bool PropagatesNaN = Opc == ISD::FMINIMUM || Opc == ISD::FMAXIMUM;
14479   bool IsMin = Opc == ISD::FMINNUM || Opc == ISD::FMINIMUM;
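  // Per their IEEE-754 semantics, FMINNUM/FMAXNUM return the other operand
  // when one operand is a quiet NaN, while FMINIMUM/FMAXIMUM propagate NaN.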
14480   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
14481 
14482   if (N0CFP && N1CFP) {
14483     const APFloat &C0 = N0CFP->getValueAPF();
14484     const APFloat &C1 = N1CFP->getValueAPF();
14485     return DAG.getConstantFP(Op(C0, C1), SDLoc(N), VT);
14486   }
14487 
14488   // Canonicalize to constant on RHS.
14489   if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
14490       !DAG.isConstantFPBuildVectorOrConstantFP(N1))
14491     return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0);
14492 
14493   if (N1CFP) {
14494     const APFloat &AF = N1CFP->getValueAPF();
14495 
14496     // minnum(X, nan) -> X
14497     // maxnum(X, nan) -> X
14498     // minimum(X, nan) -> nan
14499     // maximum(X, nan) -> nan
14500     if (AF.isNaN())
14501       return PropagatesNaN ? N->getOperand(1) : N->getOperand(0);
14502 
14503     // In the following folds, inf can be replaced with the largest finite
14504     // float, if the ninf flag is set.
14505     if (AF.isInfinity() || (Flags.hasNoInfs() && AF.isLargest())) {
14506       // minnum(X, -inf) -> -inf
14507       // maxnum(X, +inf) -> +inf
14508       // minimum(X, -inf) -> -inf if nnan
14509       // maximum(X, +inf) -> +inf if nnan
14510       if (IsMin == AF.isNegative() && (!PropagatesNaN || Flags.hasNoNaNs()))
14511         return N->getOperand(1);
14512 
14513       // minnum(X, +inf) -> X if nnan
14514       // maxnum(X, -inf) -> X if nnan
14515       // minimum(X, +inf) -> X
14516       // maximum(X, -inf) -> X
14517       if (IsMin != AF.isNegative() && (PropagatesNaN || Flags.hasNoNaNs()))
14518         return N->getOperand(0);
14519     }
14520   }
14521 
14522   return SDValue();
14523 }
14524 
14525 SDValue DAGCombiner::visitFMINNUM(SDNode *N) {
14526   return visitFMinMax(DAG, N, minnum);
14527 }
14528 
14529 SDValue DAGCombiner::visitFMAXNUM(SDNode *N) {
14530   return visitFMinMax(DAG, N, maxnum);
14531 }
14532 
14533 SDValue DAGCombiner::visitFMINIMUM(SDNode *N) {
14534   return visitFMinMax(DAG, N, minimum);
14535 }
14536 
14537 SDValue DAGCombiner::visitFMAXIMUM(SDNode *N) {
14538   return visitFMinMax(DAG, N, maximum);
14539 }
14540 
14541 SDValue DAGCombiner::visitFABS(SDNode *N) {
14542   SDValue N0 = N->getOperand(0);
14543   EVT VT = N->getValueType(0);
14544 
14545   // fold (fabs c1) -> fabs(c1)
14546   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
14547     return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0);
14548 
14549   // fold (fabs (fabs x)) -> (fabs x)
14550   if (N0.getOpcode() == ISD::FABS)
14551     return N->getOperand(0);
14552 
14553   // fold (fabs (fneg x)) -> (fabs x)
14554   // fold (fabs (fcopysign x, y)) -> (fabs x)
14555   if (N0.getOpcode() == ISD::FNEG || N0.getOpcode() == ISD::FCOPYSIGN)
14556     return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0.getOperand(0));
14557 
14558   if (SDValue Cast = foldSignChangeInBitcast(N))
14559     return Cast;
14560 
14561   return SDValue();
14562 }
14563 
14564 SDValue DAGCombiner::visitBRCOND(SDNode *N) {
14565   SDValue Chain = N->getOperand(0);
14566   SDValue N1 = N->getOperand(1);
14567   SDValue N2 = N->getOperand(2);
14568 
14569   // BRCOND(FREEZE(cond)) is equivalent to BRCOND(cond) (both are
14570   // nondeterministic jumps).
14571   if (N1->getOpcode() == ISD::FREEZE && N1.hasOneUse()) {
14572     return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other, Chain,
14573                        N1->getOperand(0), N2);
14574   }
14575 
14576   // If N is a constant we could fold this into a fallthrough or unconditional
14577   // branch. However that doesn't happen very often in normal code, because
14578   // Instcombine/SimplifyCFG should have handled the available opportunities.
14579   // If we did this folding here, it would be necessary to update the
14580   // MachineBasicBlock CFG, which is awkward.
14581 
14582   // fold a brcond with a setcc condition into a BR_CC node if BR_CC is legal
14583   // on the target.
14584   if (N1.getOpcode() == ISD::SETCC &&
14585       TLI.isOperationLegalOrCustom(ISD::BR_CC,
14586                                    N1.getOperand(0).getValueType())) {
14587     return DAG.getNode(ISD::BR_CC, SDLoc(N), MVT::Other,
14588                        Chain, N1.getOperand(2),
14589                        N1.getOperand(0), N1.getOperand(1), N2);
14590   }
14591 
14592   if (N1.hasOneUse()) {
14593     // rebuildSetCC calls visitXor which may change the Chain when there is a
14594     // STRICT_FSETCC/STRICT_FSETCCS involved. Use a handle to track changes.
14595     HandleSDNode ChainHandle(Chain);
14596     if (SDValue NewN1 = rebuildSetCC(N1))
14597       return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other,
14598                          ChainHandle.getValue(), NewN1, N2);
14599   }
14600 
14601   return SDValue();
14602 }
14603 
14604 SDValue DAGCombiner::rebuildSetCC(SDValue N) {
14605   if (N.getOpcode() == ISD::SRL ||
14606       (N.getOpcode() == ISD::TRUNCATE &&
14607        (N.getOperand(0).hasOneUse() &&
14608         N.getOperand(0).getOpcode() == ISD::SRL))) {
14609     // Look past the truncate.
14610     if (N.getOpcode() == ISD::TRUNCATE)
14611       N = N.getOperand(0);
14612 
14613     // Match this pattern so that we can generate simpler code:
14614     //
14615     //   %a = ...
14616     //   %b = and i32 %a, 2
14617     //   %c = srl i32 %b, 1
14618     //   brcond i32 %c ...
14619     //
14620     // into
14621     //
14622     //   %a = ...
14623     //   %b = and i32 %a, 2
14624     //   %c = setcc eq %b, 0
14625     //   brcond %c ...
14626     //
14627     // This applies only when the AND constant value has one bit set and the
14628     // SRL constant is equal to the log2 of the AND constant. The back-end is
14629     // smart enough to convert the result into a TEST/JMP sequence.
14630     SDValue Op0 = N.getOperand(0);
14631     SDValue Op1 = N.getOperand(1);
14632 
14633     if (Op0.getOpcode() == ISD::AND && Op1.getOpcode() == ISD::Constant) {
14634       SDValue AndOp1 = Op0.getOperand(1);
14635 
14636       if (AndOp1.getOpcode() == ISD::Constant) {
14637         const APInt &AndConst = cast<ConstantSDNode>(AndOp1)->getAPIntValue();
14638 
14639         if (AndConst.isPowerOf2() &&
14640             cast<ConstantSDNode>(Op1)->getAPIntValue() == AndConst.logBase2()) {
14641           SDLoc DL(N);
14642           return DAG.getSetCC(DL, getSetCCResultType(Op0.getValueType()),
14643                               Op0, DAG.getConstant(0, DL, Op0.getValueType()),
14644                               ISD::SETNE);
14645         }
14646       }
14647     }
14648   }
14649 
14650   // Transform (brcond (xor x, y)) -> (brcond (setcc, x, y, ne))
14651   // Transform (brcond (xor (xor x, y), -1)) -> (brcond (setcc, x, y, eq))
14652   if (N.getOpcode() == ISD::XOR) {
14653     // Because we may call this on a speculatively constructed
14654     // SimplifiedSetCC Node, we need to simplify this node first.
14655     // Ideally this should be folded into SimplifySetCC and not
14656     // here. For now, grab a handle to N so we don't lose it from
14657     // replacements internal to the visit.
14658     HandleSDNode XORHandle(N);
14659     while (N.getOpcode() == ISD::XOR) {
14660       SDValue Tmp = visitXOR(N.getNode());
14661       // No simplification done.
14662       if (!Tmp.getNode())
14663         break;
14664       // Returning N is a form of in-visit replacement that may invalidate
14665       // N. Grab the value from the handle.
14666       if (Tmp.getNode() == N.getNode())
14667         N = XORHandle.getValue();
14668       else // Node simplified. Try simplifying again.
14669         N = Tmp;
14670     }
14671 
14672     if (N.getOpcode() != ISD::XOR)
14673       return N;
14674 
14675     SDValue Op0 = N->getOperand(0);
14676     SDValue Op1 = N->getOperand(1);
14677 
14678     if (Op0.getOpcode() != ISD::SETCC && Op1.getOpcode() != ISD::SETCC) {
14679       bool Equal = false;
14680       // (brcond (xor (xor x, y), -1)) -> (brcond (setcc x, y, eq))
14681       if (isBitwiseNot(N) && Op0.hasOneUse() && Op0.getOpcode() == ISD::XOR &&
14682           Op0.getValueType() == MVT::i1) {
14683         N = Op0;
14684         Op0 = N->getOperand(0);
14685         Op1 = N->getOperand(1);
14686         Equal = true;
14687       }
14688 
14689       EVT SetCCVT = N.getValueType();
14690       if (LegalTypes)
14691         SetCCVT = getSetCCResultType(SetCCVT);
14692       // Replace the uses of XOR with SETCC
14693       return DAG.getSetCC(SDLoc(N), SetCCVT, Op0, Op1,
14694                           Equal ? ISD::SETEQ : ISD::SETNE);
14695     }
14696   }
14697 
14698   return SDValue();
14699 }
14700 
14701 // Operand List for BR_CC: Chain, CondCC, CondLHS, CondRHS, DestBB.
14702 //
14703 SDValue DAGCombiner::visitBR_CC(SDNode *N) {
14704   CondCodeSDNode *CC = cast<CondCodeSDNode>(N->getOperand(1));
14705   SDValue CondLHS = N->getOperand(2), CondRHS = N->getOperand(3);
14706 
14707   // If N is a constant we could fold this into a fallthrough or unconditional
14708   // branch. However that doesn't happen very often in normal code, because
14709   // Instcombine/SimplifyCFG should have handled the available opportunities.
14710   // If we did this folding here, it would be necessary to update the
14711   // MachineBasicBlock CFG, which is awkward.
14712 
14713   // Use SimplifySetCC to simplify SETCC's.
14714   SDValue Simp = SimplifySetCC(getSetCCResultType(CondLHS.getValueType()),
14715                                CondLHS, CondRHS, CC->get(), SDLoc(N),
14716                                false);
14717   if (Simp.getNode()) AddToWorklist(Simp.getNode());
14718 
14719   // fold to a simpler setcc
14720   if (Simp.getNode() && Simp.getOpcode() == ISD::SETCC)
14721     return DAG.getNode(ISD::BR_CC, SDLoc(N), MVT::Other,
14722                        N->getOperand(0), Simp.getOperand(2),
14723                        Simp.getOperand(0), Simp.getOperand(1),
14724                        N->getOperand(4));
14725 
14726   return SDValue();
14727 }
14728 
14729 static bool getCombineLoadStoreParts(SDNode *N, unsigned Inc, unsigned Dec,
14730                                      bool &IsLoad, bool &IsMasked, SDValue &Ptr,
14731                                      const TargetLowering &TLI) {
14732   if (LoadSDNode *LD = dyn_cast<LoadSDNode>(N)) {
14733     if (LD->isIndexed())
14734       return false;
14735     EVT VT = LD->getMemoryVT();
14736     if (!TLI.isIndexedLoadLegal(Inc, VT) && !TLI.isIndexedLoadLegal(Dec, VT))
14737       return false;
14738     Ptr = LD->getBasePtr();
14739   } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(N)) {
14740     if (ST->isIndexed())
14741       return false;
14742     EVT VT = ST->getMemoryVT();
14743     if (!TLI.isIndexedStoreLegal(Inc, VT) && !TLI.isIndexedStoreLegal(Dec, VT))
14744       return false;
14745     Ptr = ST->getBasePtr();
14746     IsLoad = false;
14747   } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(N)) {
14748     if (LD->isIndexed())
14749       return false;
14750     EVT VT = LD->getMemoryVT();
14751     if (!TLI.isIndexedMaskedLoadLegal(Inc, VT) &&
14752         !TLI.isIndexedMaskedLoadLegal(Dec, VT))
14753       return false;
14754     Ptr = LD->getBasePtr();
14755     IsMasked = true;
14756   } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(N)) {
14757     if (ST->isIndexed())
14758       return false;
14759     EVT VT = ST->getMemoryVT();
14760     if (!TLI.isIndexedMaskedStoreLegal(Inc, VT) &&
14761         !TLI.isIndexedMaskedStoreLegal(Dec, VT))
14762       return false;
14763     Ptr = ST->getBasePtr();
14764     IsLoad = false;
14765     IsMasked = true;
14766   } else {
14767     return false;
14768   }
14769   return true;
14770 }
14771 
14772 /// Try turning a load/store into a pre-indexed load/store when the base
14773 /// pointer is an add or subtract and it has other uses besides the load/store.
14774 /// After the transformation, the new indexed load/store has effectively folded
14775 /// the add/subtract in and all of its other uses are redirected to the
14776 /// new load/store.
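/// For example, on ARM this can fold the increment into a pre-indexed access
/// such as "ldr r0, [r1, #4]!", which updates r1 as part of the load.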
14777 bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) {
14778   if (Level < AfterLegalizeDAG)
14779     return false;
14780 
14781   bool IsLoad = true;
14782   bool IsMasked = false;
14783   SDValue Ptr;
14784   if (!getCombineLoadStoreParts(N, ISD::PRE_INC, ISD::PRE_DEC, IsLoad, IsMasked,
14785                                 Ptr, TLI))
14786     return false;
14787 
14788   // If the pointer is not an add/sub, or if it doesn't have multiple uses, bail
14789   // out.  There is no reason to make this a preinc/predec.
14790   if ((Ptr.getOpcode() != ISD::ADD && Ptr.getOpcode() != ISD::SUB) ||
14791       Ptr.getNode()->hasOneUse())
14792     return false;
14793 
14794   // Ask the target to do addressing mode selection.
14795   SDValue BasePtr;
14796   SDValue Offset;
14797   ISD::MemIndexedMode AM = ISD::UNINDEXED;
14798   if (!TLI.getPreIndexedAddressParts(N, BasePtr, Offset, AM, DAG))
14799     return false;
14800 
14801   // Backends without true r+i pre-indexed forms may need to pass a
14802   // constant base with a variable offset so that constant coercion
14803   // will work with the patterns in canonical form.
14804   bool Swapped = false;
14805   if (isa<ConstantSDNode>(BasePtr)) {
14806     std::swap(BasePtr, Offset);
14807     Swapped = true;
14808   }
14809 
14810   // Don't create an indexed load / store with zero offset.
14811   if (isNullConstant(Offset))
14812     return false;
14813 
14814   // Try turning it into a pre-indexed load / store except when:
14815   // 1) The new base ptr is a frame index.
14816   // 2) If N is a store and the new base ptr is either the same as or is a
14817   //    predecessor of the value being stored.
14818   // 3) Another use of old base ptr is a predecessor of N. If ptr is folded,
14819   //    that would create a cycle.
14820   // 4) All uses are load / store ops that use it as old base ptr.
14821 
14822   // Check #1.  Preinc'ing a frame index would require copying the stack pointer
14823   // (plus the implicit offset) to a register to preinc anyway.
14824   if (isa<FrameIndexSDNode>(BasePtr) || isa<RegisterSDNode>(BasePtr))
14825     return false;
14826 
14827   // Check #2.
14828   if (!IsLoad) {
14829     SDValue Val = IsMasked ? cast<MaskedStoreSDNode>(N)->getValue()
14830                            : cast<StoreSDNode>(N)->getValue();
14831 
14832     // Would require a copy.
14833     if (Val == BasePtr)
14834       return false;
14835 
14836     // Would create a cycle.
14837     if (Val == Ptr || Ptr->isPredecessorOf(Val.getNode()))
14838       return false;
14839   }
14840 
14841   // Caches for hasPredecessorHelper.
14842   SmallPtrSet<const SDNode *, 32> Visited;
14843   SmallVector<const SDNode *, 16> Worklist;
14844   Worklist.push_back(N);
14845 
14846   // If the offset is a constant, there may be other adds of constants that
14847   // can be folded with this one. We should do this to avoid having to keep
14848   // a copy of the original base pointer.
14849   SmallVector<SDNode *, 16> OtherUses;
14850   if (isa<ConstantSDNode>(Offset))
14851     for (SDNode::use_iterator UI = BasePtr.getNode()->use_begin(),
14852                               UE = BasePtr.getNode()->use_end();
14853          UI != UE; ++UI) {
14854       SDUse &Use = UI.getUse();
14855       // Skip the use that is Ptr and uses of other results from BasePtr's
14856       // node (important for nodes that return multiple results).
14857       if (Use.getUser() == Ptr.getNode() || Use != BasePtr)
14858         continue;
14859 
14860       if (SDNode::hasPredecessorHelper(Use.getUser(), Visited, Worklist))
14861         continue;
14862 
14863       if (Use.getUser()->getOpcode() != ISD::ADD &&
14864           Use.getUser()->getOpcode() != ISD::SUB) {
14865         OtherUses.clear();
14866         break;
14867       }
14868 
14869       SDValue Op1 = Use.getUser()->getOperand((UI.getOperandNo() + 1) & 1);
14870       if (!isa<ConstantSDNode>(Op1)) {
14871         OtherUses.clear();
14872         break;
14873       }
14874 
14875       // FIXME: In some cases, we can be smarter about this.
14876       if (Op1.getValueType() != Offset.getValueType()) {
14877         OtherUses.clear();
14878         break;
14879       }
14880 
14881       OtherUses.push_back(Use.getUser());
14882     }
14883 
14884   if (Swapped)
14885     std::swap(BasePtr, Offset);
14886 
14887   // Now check for #3 and #4.
14888   bool RealUse = false;
14889 
14890   for (SDNode *Use : Ptr.getNode()->uses()) {
14891     if (Use == N)
14892       continue;
14893     if (SDNode::hasPredecessorHelper(Use, Visited, Worklist))
14894       return false;
14895 
14896     // If Ptr can be folded into the addressing mode of another use, then it's
14897     // not profitable to do this transformation.
14898     if (!canFoldInAddressingMode(Ptr.getNode(), Use, DAG, TLI))
14899       RealUse = true;
14900   }
14901 
14902   if (!RealUse)
14903     return false;
14904 
14905   SDValue Result;
14906   if (!IsMasked) {
14907     if (IsLoad)
14908       Result = DAG.getIndexedLoad(SDValue(N, 0), SDLoc(N), BasePtr, Offset, AM);
14909     else
14910       Result =
14911           DAG.getIndexedStore(SDValue(N, 0), SDLoc(N), BasePtr, Offset, AM);
14912   } else {
14913     if (IsLoad)
14914       Result = DAG.getIndexedMaskedLoad(SDValue(N, 0), SDLoc(N), BasePtr,
14915                                         Offset, AM);
14916     else
14917       Result = DAG.getIndexedMaskedStore(SDValue(N, 0), SDLoc(N), BasePtr,
14918                                          Offset, AM);
14919   }
14920   ++PreIndexedNodes;
14921   ++NodesCombined;
14922   LLVM_DEBUG(dbgs() << "\nReplacing.4 "; N->dump(&DAG); dbgs() << "\nWith: ";
14923              Result.getNode()->dump(&DAG); dbgs() << '\n');
14924   WorklistRemover DeadNodes(*this);
14925   if (IsLoad) {
14926     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0));
14927     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Result.getValue(2));
14928   } else {
14929     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(1));
14930   }
14931 
14932   // Finally, since the node is now dead, remove it from the graph.
14933   deleteAndRecombine(N);
14934 
14935   if (Swapped)
14936     std::swap(BasePtr, Offset);
14937 
14938   // Replace other uses of BasePtr that can be updated to use Ptr
14939   for (unsigned i = 0, e = OtherUses.size(); i != e; ++i) {
14940     unsigned OffsetIdx = 1;
14941     if (OtherUses[i]->getOperand(OffsetIdx).getNode() == BasePtr.getNode())
14942       OffsetIdx = 0;
14943     assert(OtherUses[i]->getOperand(!OffsetIdx).getNode() ==
14944            BasePtr.getNode() && "Expected BasePtr operand");
14945 
14946     // We need to replace ptr0 in the following expression:
14947     //   x0 * offset0 + y0 * ptr0 = t0
14948     // knowing that
14949     //   x1 * offset1 + y1 * ptr0 = t1 (the indexed load/store)
14950     //
14951     // where x0, x1, y0 and y1 in {-1, 1} are given by the types of the
14952     // indexed load/store and the expression that needs to be re-written.
14953     //
14954     // Therefore, we have:
14955     //   t0 = (x0 * offset0 - x1 * y0 * y1 *offset1) + (y0 * y1) * t1
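    // For example, if the indexed access is pre-inc (t1 = ptr0 + offset1,
    // so x1 = y1 = 1) and the other use is t0 = ptr0 + offset0 (x0 = y0 = 1),
    // this reduces to t0 = (offset0 - offset1) + t1.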
14956 
14957     auto *CN = cast<ConstantSDNode>(OtherUses[i]->getOperand(OffsetIdx));
14958     const APInt &Offset0 = CN->getAPIntValue();
14959     const APInt &Offset1 = cast<ConstantSDNode>(Offset)->getAPIntValue();
14960     int X0 = (OtherUses[i]->getOpcode() == ISD::SUB && OffsetIdx == 1) ? -1 : 1;
14961     int Y0 = (OtherUses[i]->getOpcode() == ISD::SUB && OffsetIdx == 0) ? -1 : 1;
14962     int X1 = (AM == ISD::PRE_DEC && !Swapped) ? -1 : 1;
14963     int Y1 = (AM == ISD::PRE_DEC && Swapped) ? -1 : 1;
14964 
14965     unsigned Opcode = (Y0 * Y1 < 0) ? ISD::SUB : ISD::ADD;
14966 
14967     APInt CNV = Offset0;
14968     if (X0 < 0) CNV = -CNV;
14969     if (X1 * Y0 * Y1 < 0) CNV = CNV + Offset1;
14970     else CNV = CNV - Offset1;
14971 
14972     SDLoc DL(OtherUses[i]);
14973 
14974     // We can now generate the new expression.
14975     SDValue NewOp1 = DAG.getConstant(CNV, DL, CN->getValueType(0));
14976     SDValue NewOp2 = Result.getValue(IsLoad ? 1 : 0);
14977 
14978     SDValue NewUse = DAG.getNode(Opcode,
14979                                  DL,
14980                                  OtherUses[i]->getValueType(0), NewOp1, NewOp2);
14981     DAG.ReplaceAllUsesOfValueWith(SDValue(OtherUses[i], 0), NewUse);
14982     deleteAndRecombine(OtherUses[i]);
14983   }
14984 
14985   // Replace the uses of Ptr with uses of the updated base value.
14986   DAG.ReplaceAllUsesOfValueWith(Ptr, Result.getValue(IsLoad ? 1 : 0));
14987   deleteAndRecombine(Ptr.getNode());
14988   AddToWorklist(Result.getNode());
14989 
14990   return true;
14991 }
14992 
14993 static bool shouldCombineToPostInc(SDNode *N, SDValue Ptr, SDNode *PtrUse,
14994                                    SDValue &BasePtr, SDValue &Offset,
14995                                    ISD::MemIndexedMode &AM,
14996                                    SelectionDAG &DAG,
14997                                    const TargetLowering &TLI) {
14998   if (PtrUse == N ||
14999       (PtrUse->getOpcode() != ISD::ADD && PtrUse->getOpcode() != ISD::SUB))
15000     return false;
15001 
15002   if (!TLI.getPostIndexedAddressParts(N, PtrUse, BasePtr, Offset, AM, DAG))
15003     return false;
15004 
15005   // Don't create an indexed load / store with zero offset.
15006   if (isNullConstant(Offset))
15007     return false;
15008 
15009   if (isa<FrameIndexSDNode>(BasePtr) || isa<RegisterSDNode>(BasePtr))
15010     return false;
15011 
15012   SmallPtrSet<const SDNode *, 32> Visited;
15013   for (SDNode *Use : BasePtr.getNode()->uses()) {
15014     if (Use == Ptr.getNode())
15015       continue;
15016 
15017     // Bail out if there's a later user which could perform the indexing instead.
15018     if (isa<MemSDNode>(Use)) {
15019       bool IsLoad = true;
15020       bool IsMasked = false;
15021       SDValue OtherPtr;
15022       if (getCombineLoadStoreParts(Use, ISD::POST_INC, ISD::POST_DEC, IsLoad,
15023                                    IsMasked, OtherPtr, TLI)) {
15024         SmallVector<const SDNode *, 2> Worklist;
15025         Worklist.push_back(Use);
15026         if (SDNode::hasPredecessorHelper(N, Visited, Worklist))
15027           return false;
15028       }
15029     }
15030 
15031     // If all the uses are load / store addresses, then don't do the
15032     // transformation.
15033     if (Use->getOpcode() == ISD::ADD || Use->getOpcode() == ISD::SUB) {
15034       for (SDNode *UseUse : Use->uses())
15035         if (canFoldInAddressingMode(Use, UseUse, DAG, TLI))
15036           return false;
15037     }
15038   }
15039   return true;
15040 }
15041 
15042 static SDNode *getPostIndexedLoadStoreOp(SDNode *N, bool &IsLoad,
15043                                          bool &IsMasked, SDValue &Ptr,
15044                                          SDValue &BasePtr, SDValue &Offset,
15045                                          ISD::MemIndexedMode &AM,
15046                                          SelectionDAG &DAG,
15047                                          const TargetLowering &TLI) {
15048   if (!getCombineLoadStoreParts(N, ISD::POST_INC, ISD::POST_DEC, IsLoad,
15049                                 IsMasked, Ptr, TLI) ||
15050       Ptr.getNode()->hasOneUse())
15051     return nullptr;
15052 
15053   // Try turning it into a post-indexed load / store except when
15054   // 1) All uses are load / store ops that use it as base ptr (and
15055   //    it may be folded as an addressing mode).
15056   // 2) Op must be independent of N, i.e. Op is neither a predecessor
15057   //    nor a successor of N. Otherwise, if Op is folded that would
15058   //    create a cycle.
15059   for (SDNode *Op : Ptr->uses()) {
15060     // Check for #1.
15061     if (!shouldCombineToPostInc(N, Ptr, Op, BasePtr, Offset, AM, DAG, TLI))
15062       continue;
15063 
15064     // Check for #2.
15065     SmallPtrSet<const SDNode *, 32> Visited;
15066     SmallVector<const SDNode *, 8> Worklist;
15067     // Ptr is predecessor to both N and Op.
15068     Visited.insert(Ptr.getNode());
15069     Worklist.push_back(N);
15070     Worklist.push_back(Op);
15071     if (!SDNode::hasPredecessorHelper(N, Visited, Worklist) &&
15072         !SDNode::hasPredecessorHelper(Op, Visited, Worklist))
15073       return Op;
15074   }
15075   return nullptr;
15076 }
15077 
15078 /// Try to combine a load/store with an add/sub of the base pointer node into
15079 /// a post-indexed load/store. The transformation effectively folds the
15080 /// add/subtract into the new indexed load/store, and all of the add/subtract's
15081 /// other uses are redirected to the new load/store.
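/// For example, on ARM a load from p whose pointer is later updated by
/// p = p + 4 can become a post-indexed access such as "ldr r0, [p], #4".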
15082 bool DAGCombiner::CombineToPostIndexedLoadStore(SDNode *N) {
15083   if (Level < AfterLegalizeDAG)
15084     return false;
15085 
15086   bool IsLoad = true;
15087   bool IsMasked = false;
15088   SDValue Ptr;
15089   SDValue BasePtr;
15090   SDValue Offset;
15091   ISD::MemIndexedMode AM = ISD::UNINDEXED;
15092   SDNode *Op = getPostIndexedLoadStoreOp(N, IsLoad, IsMasked, Ptr, BasePtr,
15093                                          Offset, AM, DAG, TLI);
15094   if (!Op)
15095     return false;
15096 
15097   SDValue Result;
15098   if (!IsMasked)
15099     Result = IsLoad ? DAG.getIndexedLoad(SDValue(N, 0), SDLoc(N), BasePtr,
15100                                          Offset, AM)
15101                     : DAG.getIndexedStore(SDValue(N, 0), SDLoc(N),
15102                                           BasePtr, Offset, AM);
15103   else
15104     Result = IsLoad ? DAG.getIndexedMaskedLoad(SDValue(N, 0), SDLoc(N),
15105                                                BasePtr, Offset, AM)
15106                     : DAG.getIndexedMaskedStore(SDValue(N, 0), SDLoc(N),
15107                                                 BasePtr, Offset, AM);
15108   ++PostIndexedNodes;
15109   ++NodesCombined;
15110   LLVM_DEBUG(dbgs() << "\nReplacing.5 "; N->dump(&DAG);
15111              dbgs() << "\nWith: "; Result.getNode()->dump(&DAG);
15112              dbgs() << '\n');
15113   WorklistRemover DeadNodes(*this);
15114   if (IsLoad) {
15115     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0));
15116     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Result.getValue(2));
15117   } else {
15118     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(1));
15119   }
15120 
15121   // Finally, since the node is now dead, remove it from the graph.
15122   deleteAndRecombine(N);
15123 
15124   // Replace the uses of Op with uses of the updated base value.
15125   DAG.ReplaceAllUsesOfValueWith(SDValue(Op, 0),
15126                                 Result.getValue(IsLoad ? 1 : 0));
15127   deleteAndRecombine(Op);
15128   return true;
15129 }
15130 
15131 /// Return the base-pointer arithmetic from an indexed \p LD.
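/// For PRE_INC/POST_INC loads this computes (add BasePtr, Offset); for
/// PRE_DEC/POST_DEC it computes (sub BasePtr, Offset).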
15132 SDValue DAGCombiner::SplitIndexingFromLoad(LoadSDNode *LD) {
15133   ISD::MemIndexedMode AM = LD->getAddressingMode();
15134   assert(AM != ISD::UNINDEXED);
15135   SDValue BP = LD->getOperand(1);
15136   SDValue Inc = LD->getOperand(2);
15137 
15138   // Some backends use TargetConstants for load offsets, but don't expect
15139   // TargetConstants in general ADD nodes. We can convert these constants into
15140   // regular Constants (if the constant is not opaque).
15141   assert((Inc.getOpcode() != ISD::TargetConstant ||
15142           !cast<ConstantSDNode>(Inc)->isOpaque()) &&
15143          "Cannot split out indexing using opaque target constants");
15144   if (Inc.getOpcode() == ISD::TargetConstant) {
15145     ConstantSDNode *ConstInc = cast<ConstantSDNode>(Inc);
15146     Inc = DAG.getConstant(*ConstInc->getConstantIntValue(), SDLoc(Inc),
15147                           ConstInc->getValueType(0));
15148   }
15149 
15150   unsigned Opc =
15151       (AM == ISD::PRE_INC || AM == ISD::POST_INC ? ISD::ADD : ISD::SUB);
15152   return DAG.getNode(Opc, SDLoc(LD), BP.getSimpleValueType(), BP, Inc);
15153 }
15154 
15155 static inline ElementCount numVectorEltsOrZero(EVT T) {
15156   return T.isVector() ? T.getVectorElementCount() : ElementCount::getFixed(0);
15157 }
15158 
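/// Try to narrow the value \p Val that \p ST stores to the store's memory
/// type, via FTRUNC for floats, TRUNCATE for integers, or a same-size
/// bitcast; returns true (with \p Val updated) on success.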
15159 bool DAGCombiner::getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val) {
15160   Val = ST->getValue();
15161   EVT STType = Val.getValueType();
15162   EVT STMemType = ST->getMemoryVT();
15163   if (STType == STMemType)
15164     return true;
15165   if (isTypeLegal(STMemType))
15166     return false; // fail.
15167   if (STType.isFloatingPoint() && STMemType.isFloatingPoint() &&
15168       TLI.isOperationLegal(ISD::FTRUNC, STMemType)) {
15169     Val = DAG.getNode(ISD::FTRUNC, SDLoc(ST), STMemType, Val);
15170     return true;
15171   }
15172   if (numVectorEltsOrZero(STType) == numVectorEltsOrZero(STMemType) &&
15173       STType.isInteger() && STMemType.isInteger()) {
15174     Val = DAG.getNode(ISD::TRUNCATE, SDLoc(ST), STMemType, Val);
15175     return true;
15176   }
15177   if (STType.getSizeInBits() == STMemType.getSizeInBits()) {
15178     Val = DAG.getBitcast(STMemType, Val);
15179     return true;
15180   }
15181   return false; // fail.
15182 }
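
// Example (illustrative): for a truncating store of an i32 value with an i16
// memory VT, getTruncatedStoreValue yields 'truncate i32 Val to i16'; a
// floating-point mismatch is modelled with an ISD::FTRUNC node of the memory
// type, provided that operation is legal, and a same-size mismatch falls back
// to a plain bitcast.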
15183 
15184 bool DAGCombiner::extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val) {
15185   EVT LDMemType = LD->getMemoryVT();
15186   EVT LDType = LD->getValueType(0);
15187   assert(Val.getValueType() == LDMemType &&
15188          "Attempting to extend value of non-matching type");
15189   if (LDType == LDMemType)
15190     return true;
15191   if (LDMemType.isInteger() && LDType.isInteger()) {
15192     switch (LD->getExtensionType()) {
15193     case ISD::NON_EXTLOAD:
15194       Val = DAG.getBitcast(LDType, Val);
15195       return true;
15196     case ISD::EXTLOAD:
15197       Val = DAG.getNode(ISD::ANY_EXTEND, SDLoc(LD), LDType, Val);
15198       return true;
15199     case ISD::SEXTLOAD:
15200       Val = DAG.getNode(ISD::SIGN_EXTEND, SDLoc(LD), LDType, Val);
15201       return true;
15202     case ISD::ZEXTLOAD:
15203       Val = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(LD), LDType, Val);
15204       return true;
15205     }
15206   }
15207   return false;
15208 }
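
// Example (illustrative): when forwarding into 'sextload i8 -> i32', the i8
// value is wrapped in 'sign_extend ... to i32'; an extload uses any_extend, a
// zextload uses zero_extend, and a non-extending load is simply bitcast.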
15209 
15210 SDValue DAGCombiner::ForwardStoreValueToDirectLoad(LoadSDNode *LD) {
15211   if (OptLevel == CodeGenOpt::None || !LD->isSimple())
15212     return SDValue();
15213   SDValue Chain = LD->getOperand(0);
15214   StoreSDNode *ST = dyn_cast<StoreSDNode>(Chain.getNode());
15215   // TODO: Relax this restriction for unordered atomics (see D66309)
15216   if (!ST || !ST->isSimple())
15217     return SDValue();
15218 
15219   EVT LDType = LD->getValueType(0);
15220   EVT LDMemType = LD->getMemoryVT();
15221   EVT STMemType = ST->getMemoryVT();
15222   EVT STType = ST->getValue().getValueType();
15223 
15224   // There are two cases to consider here:
15225   //  1. The store is fixed width and the load is scalable. In this case we
15226   //     don't know at compile time if the store completely envelops the
15227   //     load, so we abandon the optimization.
15228   //  2. The store is scalable and the load is fixed width. We could
15229   //     potentially support a limited number of cases here, but there has been
15230   //     no cost-benefit analysis to prove it's worth it.
15231   bool LdStScalable = LDMemType.isScalableVector();
15232   if (LdStScalable != STMemType.isScalableVector())
15233     return SDValue();
15234 
15235   // If we are dealing with scalable vectors on a big endian platform the
15236   // calculation of offsets below becomes trickier, since we do not know at
15237   // compile time the absolute size of the vector. Until we've done more
15238   // analysis on big-endian platforms it seems better to bail out for now.
15239   if (LdStScalable && DAG.getDataLayout().isBigEndian())
15240     return SDValue();
15241 
15242   BaseIndexOffset BasePtrLD = BaseIndexOffset::match(LD, DAG);
15243   BaseIndexOffset BasePtrST = BaseIndexOffset::match(ST, DAG);
15244   int64_t Offset;
15245   if (!BasePtrST.equalBaseIndex(BasePtrLD, DAG, Offset))
15246     return SDValue();
15247 
15248   // Normalize for endianness. After this, Offset == 0 will denote that the
15249   // least significant bit in the loaded value maps to the least significant
15250   // bit in the stored value. With Offset == n (for n > 0) the loaded value
15251   // starts at the n-th least significant byte of the stored value.
15252   if (DAG.getDataLayout().isBigEndian())
15253     Offset = ((int64_t)STMemType.getStoreSizeInBits().getFixedSize() -
15254               (int64_t)LDMemType.getStoreSizeInBits().getFixedSize()) /
15255                  8 -
15256              Offset;
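
  // E.g. (illustrative): with a 4-byte store feeding a 2-byte load on a
  // big-endian target, an address offset of 0 becomes (4 - 2) - 0 == 2,
  // i.e. the load reads the two least significant bytes of the stored value.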
15257 
15258   // Check that the stored value covers all bits that are loaded.
15259   bool STCoversLD;
15260 
15261   TypeSize LdMemSize = LDMemType.getSizeInBits();
15262   TypeSize StMemSize = STMemType.getSizeInBits();
15263   if (LdStScalable)
15264     STCoversLD = (Offset == 0) && LdMemSize == StMemSize;
15265   else
15266     STCoversLD = (Offset >= 0) && (Offset * 8 + LdMemSize.getFixedSize() <=
15267                                    StMemSize.getFixedSize());
15268 
15269   auto ReplaceLd = [&](LoadSDNode *LD, SDValue Val, SDValue Chain) -> SDValue {
15270     if (LD->isIndexed()) {
15271       // Cannot handle opaque target constants and we must respect the user's
15272       // request not to split indexes from loads.
15273       if (!canSplitIdx(LD))
15274         return SDValue();
15275       SDValue Idx = SplitIndexingFromLoad(LD);
15276       SDValue Ops[] = {Val, Idx, Chain};
15277       return CombineTo(LD, Ops, 3);
15278     }
15279     return CombineTo(LD, Val, Chain);
15280   };
15281 
15282   if (!STCoversLD)
15283     return SDValue();
15284 
15285   // Memory as copy space (potentially masked).
15286   if (Offset == 0 && LDType == STType && STMemType == LDMemType) {
15287     // Simple case: Direct non-truncating forwarding
15288     if (LDType.getSizeInBits() == LdMemSize)
15289       return ReplaceLd(LD, ST->getValue(), Chain);
15290     // Can we model the truncate and extension with an and mask?
15291     if (STType.isInteger() && LDMemType.isInteger() && !STType.isVector() &&
15292         !LDMemType.isVector() && LD->getExtensionType() != ISD::SEXTLOAD) {
15293       // Mask to size of LDMemType
15294       auto Mask =
15295           DAG.getConstant(APInt::getLowBitsSet(STType.getFixedSizeInBits(),
15296                                                StMemSize.getFixedSize()),
15297                           SDLoc(ST), STType);
15298       auto Val = DAG.getNode(ISD::AND, SDLoc(LD), LDType, ST->getValue(), Mask);
15299       return ReplaceLd(LD, Val, Chain);
15300     }
15301   }
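
  // E.g. (illustrative): 'truncstore i16 (i32 X)' followed by a
  // 'zextload i16 -> i32' of the same address forwards as
  // 'and i32 X, 0xFFFF', avoiding the round trip through memory.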
15302 
15303   // TODO: Deal with nonzero offset.
15304   if (LD->getBasePtr().isUndef() || Offset != 0)
15305     return SDValue();
15306   // Model the necessary truncations / extensions.
15307   SDValue Val;
15308   // Truncate the value to the stored memory size.
15309   do {
15310     if (!getTruncatedStoreValue(ST, Val))
15311       continue;
15312     if (!isTypeLegal(LDMemType))
15313       continue;
15314     if (STMemType != LDMemType) {
15315       // TODO: Support vectors? This requires extract_subvector/bitcast.
15316       if (!STMemType.isVector() && !LDMemType.isVector() &&
15317           STMemType.isInteger() && LDMemType.isInteger())
15318         Val = DAG.getNode(ISD::TRUNCATE, SDLoc(LD), LDMemType, Val);
15319       else
15320         continue;
15321     }
15322     if (!extendLoadedValueToExtension(LD, Val))
15323       continue;
15324     return ReplaceLd(LD, Val, Chain);
15325   } while (false);
15326 
15327   // On failure, cleanup dead nodes we may have created.
15328   if (Val->use_empty())
15329     deleteAndRecombine(Val.getNode());
15330   return SDValue();
15331 }
15332 
15333 SDValue DAGCombiner::visitLOAD(SDNode *N) {
15334   LoadSDNode *LD  = cast<LoadSDNode>(N);
15335   SDValue Chain = LD->getChain();
15336   SDValue Ptr   = LD->getBasePtr();
15337 
15338   // If load is not volatile and there are no uses of the loaded value (and
15339   // the updated indexed value in case of indexed loads), change uses of the
15340   // chain value into uses of the chain input (i.e. delete the dead load).
15341   // TODO: Allow this for unordered atomics (see D66309)
15342   if (LD->isSimple()) {
15343     if (N->getValueType(1) == MVT::Other) {
15344       // Unindexed loads.
15345       if (!N->hasAnyUseOfValue(0)) {
15346         // It's not safe to use the two-value CombineTo variant here, e.g.
15347         // v1, chain2 = load chain1, loc
15348         // v2, chain3 = load chain2, loc
15349         // v3         = add v2, c
15350         // Now we replace use of chain2 with chain1.  This makes the second load
15351         // isomorphic to the one we are deleting, and thus makes this load live.
15352         LLVM_DEBUG(dbgs() << "\nReplacing.6 "; N->dump(&DAG);
15353                    dbgs() << "\nWith chain: "; Chain.getNode()->dump(&DAG);
15354                    dbgs() << "\n");
15355         WorklistRemover DeadNodes(*this);
15356         DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain);
15357         AddUsersToWorklist(Chain.getNode());
15358         if (N->use_empty())
15359           deleteAndRecombine(N);
15360 
15361         return SDValue(N, 0);   // Return N so it doesn't get rechecked!
15362       }
15363     } else {
15364       // Indexed loads.
15365       assert(N->getValueType(2) == MVT::Other && "Malformed indexed loads?");
15366 
15367       // If this load has an opaque TargetConstant offset, then we cannot split
15368       // the indexing into an add/sub directly (that TargetConstant may not be
15369       // valid for a different type of node, and we cannot convert an opaque
15370       // target constant into a regular constant).
15371       bool CanSplitIdx = canSplitIdx(LD);
15372 
15373       if (!N->hasAnyUseOfValue(0) && (CanSplitIdx || !N->hasAnyUseOfValue(1))) {
15374         SDValue Undef = DAG.getUNDEF(N->getValueType(0));
15375         SDValue Index;
15376         if (N->hasAnyUseOfValue(1) && CanSplitIdx) {
15377           Index = SplitIndexingFromLoad(LD);
15378           // Try to fold the base pointer arithmetic into subsequent loads and
15379           // stores.
15380           AddUsersToWorklist(N);
15381         } else
15382           Index = DAG.getUNDEF(N->getValueType(1));
15383         LLVM_DEBUG(dbgs() << "\nReplacing.7 "; N->dump(&DAG);
15384                    dbgs() << "\nWith: "; Undef.getNode()->dump(&DAG);
15385                    dbgs() << " and 2 other values\n");
15386         WorklistRemover DeadNodes(*this);
15387         DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Undef);
15388         DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Index);
15389         DAG.ReplaceAllUsesOfValueWith(SDValue(N, 2), Chain);
15390         deleteAndRecombine(N);
15391         return SDValue(N, 0);   // Return N so it doesn't get rechecked!
15392       }
15393     }
15394   }
15395 
15396   // If this load is directly stored, replace the load value with the stored
15397   // value.
15398   if (auto V = ForwardStoreValueToDirectLoad(LD))
15399     return V;
15400 
15401   // Try to infer better alignment information than the load already has.
15402   if (OptLevel != CodeGenOpt::None && LD->isUnindexed() && !LD->isAtomic()) {
15403     if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) {
15404       if (*Alignment > LD->getAlign() &&
15405           isAligned(*Alignment, LD->getSrcValueOffset())) {
15406         SDValue NewLoad = DAG.getExtLoad(
15407             LD->getExtensionType(), SDLoc(N), LD->getValueType(0), Chain, Ptr,
15408             LD->getPointerInfo(), LD->getMemoryVT(), *Alignment,
15409             LD->getMemOperand()->getFlags(), LD->getAAInfo());
15410         // NewLoad will always be N as we are only refining the alignment
15411         assert(NewLoad.getNode() == N);
15412         (void)NewLoad;
15413       }
15414     }
15415   }
15416 
15417   if (LD->isUnindexed()) {
15418     // Walk up chain skipping non-aliasing memory nodes.
15419     SDValue BetterChain = FindBetterChain(LD, Chain);
15420 
15421     // If there is a better chain.
15422     if (Chain != BetterChain) {
15423       SDValue ReplLoad;
15424 
15425       // Replace the chain to avoid the dependency.
15426       if (LD->getExtensionType() == ISD::NON_EXTLOAD) {
15427         ReplLoad = DAG.getLoad(N->getValueType(0), SDLoc(LD),
15428                                BetterChain, Ptr, LD->getMemOperand());
15429       } else {
15430         ReplLoad = DAG.getExtLoad(LD->getExtensionType(), SDLoc(LD),
15431                                   LD->getValueType(0),
15432                                   BetterChain, Ptr, LD->getMemoryVT(),
15433                                   LD->getMemOperand());
15434       }
15435 
15436       // Create token factor to keep old chain connected.
15437       SDValue Token = DAG.getNode(ISD::TokenFactor, SDLoc(N),
15438                                   MVT::Other, Chain, ReplLoad.getValue(1));
15439 
15440       // Replace uses with load result and token factor
15441       return CombineTo(N, ReplLoad.getValue(0), Token);
15442     }
15443   }
15444 
15445   // Try transforming N to an indexed load.
15446   if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
15447     return SDValue(N, 0);
15448 
15449   // Try to slice up N to more direct loads if the slices are mapped to
15450   // different register banks or pairing can take place.
15451   if (SliceUpLoad(N))
15452     return SDValue(N, 0);
15453 
15454   return SDValue();
15455 }
15456 
15457 namespace {
15458 
15459 /// Helper structure used to slice a load in smaller loads.
15460 /// Basically a slice is obtained from the following sequence:
15461 /// Origin = load Ty1, Base
15462 /// Shift = srl Ty1 Origin, CstTy Amount
15463 /// Inst = trunc Shift to Ty2
15464 ///
15465 /// Then, it will be rewritten into:
15466 /// Slice = load SliceTy, Base + SliceOffset
15467 /// [Inst = zext Slice to Ty2], only if SliceTy <> Ty2
15468 ///
15469 /// SliceTy is deduced from the number of bits that are actually used to
15470 /// build Inst.
15471 struct LoadedSlice {
15472   /// Helper structure used to compute the cost of a slice.
15473   struct Cost {
15474     /// Are we optimizing for code size.
15475     bool ForCodeSize = false;
15476 
15477     /// Various costs.
15478     unsigned Loads = 0;
15479     unsigned Truncates = 0;
15480     unsigned CrossRegisterBanksCopies = 0;
15481     unsigned ZExts = 0;
15482     unsigned Shift = 0;
15483 
15484     explicit Cost(bool ForCodeSize) : ForCodeSize(ForCodeSize) {}
15485 
15486     /// Get the cost of one isolated slice.
15487     Cost(const LoadedSlice &LS, bool ForCodeSize)
15488         : ForCodeSize(ForCodeSize), Loads(1) {
15489       EVT TruncType = LS.Inst->getValueType(0);
15490       EVT LoadedType = LS.getLoadedType();
15491       if (TruncType != LoadedType &&
15492           !LS.DAG->getTargetLoweringInfo().isZExtFree(LoadedType, TruncType))
15493         ZExts = 1;
15494     }
15495 
15496     /// Account for slicing gain in the current cost.
15497     /// Slicing provides a few gains, like removing a shift or a
15498     /// truncate. This method allows growing the cost of the original
15499     /// load with the gain from this slice.
15500     void addSliceGain(const LoadedSlice &LS) {
15501       // Each slice saves a truncate.
15502       const TargetLowering &TLI = LS.DAG->getTargetLoweringInfo();
15503       if (!TLI.isTruncateFree(LS.Inst->getOperand(0).getValueType(),
15504                               LS.Inst->getValueType(0)))
15505         ++Truncates;
15506       // If there is a shift amount, this slice gets rid of it.
15507       if (LS.Shift)
15508         ++Shift;
15509       // If this slice can merge a cross register bank copy, account for it.
15510       if (LS.canMergeExpensiveCrossRegisterBankCopy())
15511         ++CrossRegisterBanksCopies;
15512     }
15513 
15514     Cost &operator+=(const Cost &RHS) {
15515       Loads += RHS.Loads;
15516       Truncates += RHS.Truncates;
15517       CrossRegisterBanksCopies += RHS.CrossRegisterBanksCopies;
15518       ZExts += RHS.ZExts;
15519       Shift += RHS.Shift;
15520       return *this;
15521     }
15522 
15523     bool operator==(const Cost &RHS) const {
15524       return Loads == RHS.Loads && Truncates == RHS.Truncates &&
15525              CrossRegisterBanksCopies == RHS.CrossRegisterBanksCopies &&
15526              ZExts == RHS.ZExts && Shift == RHS.Shift;
15527     }
15528 
15529     bool operator!=(const Cost &RHS) const { return !(*this == RHS); }
15530 
15531     bool operator<(const Cost &RHS) const {
15532       // Assume cross register banks copies are as expensive as loads.
15533       // FIXME: Do we want some more target hooks?
15534       unsigned ExpensiveOpsLHS = Loads + CrossRegisterBanksCopies;
15535       unsigned ExpensiveOpsRHS = RHS.Loads + RHS.CrossRegisterBanksCopies;
15536       // Unless we are optimizing for code size, consider the
15537       // expensive operation first.
15538       if (!ForCodeSize && ExpensiveOpsLHS != ExpensiveOpsRHS)
15539         return ExpensiveOpsLHS < ExpensiveOpsRHS;
15540       return (Truncates + ZExts + Shift + ExpensiveOpsLHS) <
15541              (RHS.Truncates + RHS.ZExts + RHS.Shift + ExpensiveOpsRHS);
15542     }
15543 
15544     bool operator>(const Cost &RHS) const { return RHS < *this; }
15545 
15546     bool operator<=(const Cost &RHS) const { return !(RHS < *this); }
15547 
15548     bool operator>=(const Cost &RHS) const { return !(*this < RHS); }
15549   };
15550 
15551   // The last instruction that represents the slice. This should be a
15552   // truncate instruction.
15553   SDNode *Inst;
15554 
15555   // The original load instruction.
15556   LoadSDNode *Origin;
15557 
15558   // The right shift amount in bits from the original load.
15559   unsigned Shift;
15560 
15561   // The DAG from which Origin came.
15562   // This is used to get some contextual information about legal types, etc.
15563   SelectionDAG *DAG;
15564 
15565   LoadedSlice(SDNode *Inst = nullptr, LoadSDNode *Origin = nullptr,
15566               unsigned Shift = 0, SelectionDAG *DAG = nullptr)
15567       : Inst(Inst), Origin(Origin), Shift(Shift), DAG(DAG) {}
15568 
15569   /// Get the bits used in a chunk of bits \p BitWidth large.
15570   /// \return Result is \p BitWidth bits wide, with used bits set to 1 and
15571   ///         unused bits set to 0.
15572   APInt getUsedBits() const {
15573     // Reproduce the trunc(lshr) sequence:
15574     // - Start from the truncated value.
15575     // - Zero extend to the desired bit width.
15576     // - Shift left.
15577     assert(Origin && "No original load to compare against.");
15578     unsigned BitWidth = Origin->getValueSizeInBits(0);
15579     assert(Inst && "This slice is not bound to an instruction");
15580     assert(Inst->getValueSizeInBits(0) <= BitWidth &&
15581            "Extracted slice is bigger than the whole type!");
15582     APInt UsedBits(Inst->getValueSizeInBits(0), 0);
15583     UsedBits.setAllBits();
15584     UsedBits = UsedBits.zext(BitWidth);
15585     UsedBits <<= Shift;
15586     return UsedBits;
15587   }
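
  // E.g. (illustrative): for 'trunc (srl i32 X, 16) to i8', the used bits are
  // computed as i8 -1, zero-extended to i32 (0x000000FF), then shifted left by
  // 16, giving 0x00FF0000.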
15588 
15589   /// Get the size of the slice to be loaded in bytes.
15590   unsigned getLoadedSize() const {
15591     unsigned SliceSize = getUsedBits().countPopulation();
15592     assert(!(SliceSize & 0x7) && "Size is not a multiple of a byte.");
15593     return SliceSize / 8;
15594   }
15595 
15596   /// Get the type that will be loaded for this slice.
15597   /// Note: This may not be the final type for the slice.
15598   EVT getLoadedType() const {
15599     assert(DAG && "Missing context");
15600     LLVMContext &Ctxt = *DAG->getContext();
15601     return EVT::getIntegerVT(Ctxt, getLoadedSize() * 8);
15602   }
15603 
15604   /// Get the alignment of the load used for this slice.
15605   Align getAlign() const {
15606     Align Alignment = Origin->getAlign();
15607     uint64_t Offset = getOffsetFromBase();
15608     if (Offset != 0)
15609       Alignment = commonAlignment(Alignment, Alignment.value() + Offset);
15610     return Alignment;
15611   }
15612 
15613   /// Check if this slice can be rewritten with legal operations.
15614   bool isLegal() const {
15615     // An invalid slice is not legal.
15616     if (!Origin || !Inst || !DAG)
15617       return false;
15618 
15619     // Offsets are for indexed loads only; we do not handle that.
15620     if (!Origin->getOffset().isUndef())
15621       return false;
15622 
15623     const TargetLowering &TLI = DAG->getTargetLoweringInfo();
15624 
15625     // Check that the type is legal.
15626     EVT SliceType = getLoadedType();
15627     if (!TLI.isTypeLegal(SliceType))
15628       return false;
15629 
15630     // Check that the load is legal for this type.
15631     if (!TLI.isOperationLegal(ISD::LOAD, SliceType))
15632       return false;
15633 
15634     // Check that the offset can be computed.
15635     // 1. Check its type.
15636     EVT PtrType = Origin->getBasePtr().getValueType();
15637     if (PtrType == MVT::Untyped || PtrType.isExtended())
15638       return false;
15639 
15640     // 2. Check that it fits in the immediate.
15641     if (!TLI.isLegalAddImmediate(getOffsetFromBase()))
15642       return false;
15643 
15644     // 3. Check that the computation is legal.
15645     if (!TLI.isOperationLegal(ISD::ADD, PtrType))
15646       return false;
15647 
15648     // Check that the zext is legal if it needs one.
15649     EVT TruncateType = Inst->getValueType(0);
15650     if (TruncateType != SliceType &&
15651         !TLI.isOperationLegal(ISD::ZERO_EXTEND, TruncateType))
15652       return false;
15653 
15654     return true;
15655   }
15656 
15657   /// Get the offset in bytes of this slice in the original chunk of
15658   /// bits.
15659   /// \pre DAG != nullptr.
15660   uint64_t getOffsetFromBase() const {
15661     assert(DAG && "Missing context.");
15662     bool IsBigEndian = DAG->getDataLayout().isBigEndian();
15663     assert(!(Shift & 0x7) && "Shifts not aligned on Bytes are not supported.");
15664     uint64_t Offset = Shift / 8;
15665     unsigned TySizeInBytes = Origin->getValueSizeInBits(0) / 8;
15666     assert(!(Origin->getValueSizeInBits(0) & 0x7) &&
15667            "The size of the original loaded type is not a multiple of a"
15668            " byte.");
15669   // If Offset is bigger than TySizeInBytes, it means we are loading all
15670   // zeros. This should have been optimized away earlier in the process.
15671     assert(TySizeInBytes > Offset &&
15672            "Invalid shift amount for given loaded size");
15673     if (IsBigEndian)
15674       Offset = TySizeInBytes - Offset - getLoadedSize();
15675     return Offset;
15676   }
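
  // E.g. (illustrative): a one-byte slice of an i32 load with Shift == 16 has
  // offset 16 / 8 == 2 on a little-endian target, and 4 - 2 - 1 == 1 on a
  // big-endian one.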
15677 
15678   /// Generate the sequence of instructions to load the slice
15679   /// represented by this object and redirect the uses of this slice to
15680   /// this new sequence of instructions.
15681   /// \pre this->Inst && this->Origin are valid Instructions and this
15682   /// object passed the legal check: LoadedSlice::isLegal returned true.
15683   /// \return The last instruction of the sequence used to load the slice.
15684   SDValue loadSlice() const {
15685     assert(Inst && Origin && "Unable to replace a non-existing slice.");
15686     const SDValue &OldBaseAddr = Origin->getBasePtr();
15687     SDValue BaseAddr = OldBaseAddr;
15688     // Get the offset in that chunk of bytes w.r.t. the endianness.
15689     int64_t Offset = static_cast<int64_t>(getOffsetFromBase());
15690     assert(Offset >= 0 && "Offset too big to fit in int64_t!");
15691     if (Offset) {
15692       // BaseAddr = BaseAddr + Offset.
15693       EVT ArithType = BaseAddr.getValueType();
15694       SDLoc DL(Origin);
15695       BaseAddr = DAG->getNode(ISD::ADD, DL, ArithType, BaseAddr,
15696                               DAG->getConstant(Offset, DL, ArithType));
15697     }
15698 
15699     // Create the type of the loaded slice according to its size.
15700     EVT SliceType = getLoadedType();
15701 
15702     // Create the load for the slice.
15703     SDValue LastInst =
15704         DAG->getLoad(SliceType, SDLoc(Origin), Origin->getChain(), BaseAddr,
15705                      Origin->getPointerInfo().getWithOffset(Offset), getAlign(),
15706                      Origin->getMemOperand()->getFlags());
15707     // If the final type is not the same as the loaded type, this means that
15708     // we have to pad with zero. Create a zero extend for that.
15709     EVT FinalType = Inst->getValueType(0);
15710     if (SliceType != FinalType)
15711       LastInst =
15712           DAG->getNode(ISD::ZERO_EXTEND, SDLoc(LastInst), FinalType, LastInst);
15713     return LastInst;
15714   }
15715 
15716   /// Check if this slice can be merged with an expensive cross register
15717   /// bank copy. E.g.,
15718   /// i = load i32
15719   /// f = bitcast i32 i to float
15720   bool canMergeExpensiveCrossRegisterBankCopy() const {
15721     if (!Inst || !Inst->hasOneUse())
15722       return false;
15723     SDNode *Use = *Inst->use_begin();
15724     if (Use->getOpcode() != ISD::BITCAST)
15725       return false;
15726     assert(DAG && "Missing context");
15727     const TargetLowering &TLI = DAG->getTargetLoweringInfo();
15728     EVT ResVT = Use->getValueType(0);
15729     const TargetRegisterClass *ResRC =
15730         TLI.getRegClassFor(ResVT.getSimpleVT(), Use->isDivergent());
15731     const TargetRegisterClass *ArgRC =
15732         TLI.getRegClassFor(Use->getOperand(0).getValueType().getSimpleVT(),
15733                            Use->getOperand(0)->isDivergent());
15734     if (ArgRC == ResRC || !TLI.isOperationLegal(ISD::LOAD, ResVT))
15735       return false;
15736 
15737     // At this point, we know that we perform a cross-register-bank copy.
15738     // Check if it is expensive.
15739     const TargetRegisterInfo *TRI = DAG->getSubtarget().getRegisterInfo();
15740     // Assume bitcasts are cheap, unless both register classes do not
15741     // explicitly share a common sub class.
15742     if (!TRI || TRI->getCommonSubClass(ArgRC, ResRC))
15743       return false;
15744 
15745     // Check if it will be merged with the load.
15746     // 1. Check the alignment constraint.
15747     Align RequiredAlignment = DAG->getDataLayout().getABITypeAlign(
15748         ResVT.getTypeForEVT(*DAG->getContext()));
15749 
15750     if (RequiredAlignment > getAlign())
15751       return false;
15752 
15753     // 2. Check that the load is a legal operation for that type.
15754     if (!TLI.isOperationLegal(ISD::LOAD, ResVT))
15755       return false;
15756 
15757     // 3. Check that we do not have a zext in the way.
15758     if (Inst->getValueType(0) != getLoadedType())
15759       return false;
15760 
15761     return true;
15762   }
15763 };
15764 
15765 } // end anonymous namespace
15766 
15767 /// Check that all bits set in \p UsedBits form a dense region, i.e.,
15768 /// \p UsedBits looks like 0..0 1..1 0..0.
15769 static bool areUsedBitsDense(const APInt &UsedBits) {
15770   // If all the bits are one, this is dense!
15771   if (UsedBits.isAllOnesValue())
15772     return true;
15773 
15774   // Get rid of the unused bits on the right.
15775   APInt NarrowedUsedBits = UsedBits.lshr(UsedBits.countTrailingZeros());
15776   // Get rid of the unused bits on the left.
15777   if (NarrowedUsedBits.countLeadingZeros())
15778     NarrowedUsedBits = NarrowedUsedBits.trunc(NarrowedUsedBits.getActiveBits());
15779   // Check that the chunk of bits is completely used.
15780   return NarrowedUsedBits.isAllOnesValue();
15781 }
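
// E.g. (illustrative): 0x00FF0000 is dense (a single contiguous run of ones),
// whereas 0x00FF00FF is not, because a hole separates the two runs.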
15782 
15783 /// Check whether or not \p First and \p Second are next to each other
15784 /// in memory. This means that there is no hole between the bits loaded
15785 /// by \p First and the bits loaded by \p Second.
15786 static bool areSlicesNextToEachOther(const LoadedSlice &First,
15787                                      const LoadedSlice &Second) {
15788   assert(First.Origin == Second.Origin && First.Origin &&
15789          "Unable to match different memory origins.");
15790   APInt UsedBits = First.getUsedBits();
15791   assert((UsedBits & Second.getUsedBits()) == 0 &&
15792          "Slices are not supposed to overlap.");
15793   UsedBits |= Second.getUsedBits();
15794   return areUsedBitsDense(UsedBits);
15795 }
15796 
15797 /// Adjust the \p GlobalLSCost according to the target
15798 /// pairing capabilities and the layout of the slices.
15799 /// \pre \p GlobalLSCost should account for at least as many loads as
15800 /// there are in the slices in \p LoadedSlices.
15801 static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices,
15802                                  LoadedSlice::Cost &GlobalLSCost) {
15803   unsigned NumberOfSlices = LoadedSlices.size();
15804   // If there are fewer than 2 elements, no pairing is possible.
15805   if (NumberOfSlices < 2)
15806     return;
15807 
15808   // Sort the slices so that elements that are likely to be next to each
15809   // other in memory are next to each other in the list.
15810   llvm::sort(LoadedSlices, [](const LoadedSlice &LHS, const LoadedSlice &RHS) {
15811     assert(LHS.Origin == RHS.Origin && "Different bases not implemented.");
15812     return LHS.getOffsetFromBase() < RHS.getOffsetFromBase();
15813   });
15814   const TargetLowering &TLI = LoadedSlices[0].DAG->getTargetLoweringInfo();
15815   // First (resp. Second) is the first (resp. second) candidate to be
15816   // placed in a paired load.
15817   const LoadedSlice *First = nullptr;
15818   const LoadedSlice *Second = nullptr;
15819   for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice,
15820                 // Set the beginning of the pair.
15821                                                            First = Second) {
15822     Second = &LoadedSlices[CurrSlice];
15823 
15824     // If First is NULL, it means we start a new pair.
15825     // Get to the next slice.
15826     if (!First)
15827       continue;
15828 
15829     EVT LoadedType = First->getLoadedType();
15830 
15831     // If the types of the slices are different, we cannot pair them.
15832     if (LoadedType != Second->getLoadedType())
15833       continue;
15834 
15835     // Check if the target supplies paired loads for this type.
15836     Align RequiredAlignment;
15837     if (!TLI.hasPairedLoad(LoadedType, RequiredAlignment)) {
15838       // Move to the next pair; this type is hopeless.
15839       Second = nullptr;
15840       continue;
15841     }
15842     // Check if we meet the alignment requirement.
15843     if (First->getAlign() < RequiredAlignment)
15844       continue;
15845 
15846     // Check that both loads are next to each other in memory.
15847     if (!areSlicesNextToEachOther(*First, *Second))
15848       continue;
15849 
15850     assert(GlobalLSCost.Loads > 0 && "We save more loads than we created!");
15851     --GlobalLSCost.Loads;
15852     // Move to the next pair.
15853     Second = nullptr;
15854   }
15855 }
15856 
15857 /// Check the profitability of all involved LoadedSlice.
15858 /// Currently, it is considered profitable if there are exactly two
15859 /// involved slices (1) which are (2) next to each other in memory, and
15860 /// whose cost (\see LoadedSlice::Cost) is smaller than the original load (3).
15861 ///
15862 /// Note: The order of the elements in \p LoadedSlices may be modified, but not
15863 /// the elements themselves.
15864 ///
15865 /// FIXME: When the cost model is mature enough, we can relax
15866 /// constraints (1) and (2).
15867 static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices,
15868                                 const APInt &UsedBits, bool ForCodeSize) {
15869   unsigned NumberOfSlices = LoadedSlices.size();
15870   if (StressLoadSlicing)
15871     return NumberOfSlices > 1;
15872 
15873   // Check (1).
15874   if (NumberOfSlices != 2)
15875     return false;
15876 
15877   // Check (2).
15878   if (!areUsedBitsDense(UsedBits))
15879     return false;
15880 
15881   // Check (3).
15882   LoadedSlice::Cost OrigCost(ForCodeSize), GlobalSlicingCost(ForCodeSize);
15883   // The original code has one big load.
15884   OrigCost.Loads = 1;
15885   for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice) {
15886     const LoadedSlice &LS = LoadedSlices[CurrSlice];
15887     // Accumulate the cost of all the slices.
15888     LoadedSlice::Cost SliceCost(LS, ForCodeSize);
15889     GlobalSlicingCost += SliceCost;
15890 
15891     // Account, as a cost in the original configuration, for the gain
15892     // obtained with the current slices.
15893     OrigCost.addSliceGain(LS);
15894   }
15895 
15896   // If the target supports paired load, adjust the cost accordingly.
15897   adjustCostForPairing(LoadedSlices, GlobalSlicingCost);
15898   return OrigCost > GlobalSlicingCost;
15899 }
15900 
15901 /// If the given load, \p LI, is used only by trunc or trunc(lshr)
15902 /// operations, split it in the various pieces being extracted.
15903 ///
15904 /// This sort of thing is introduced by SROA.
15905 /// This slicing takes care not to insert overlapping loads.
15906 /// \pre LI is a simple load (i.e., not an atomic or volatile load).
15907 bool DAGCombiner::SliceUpLoad(SDNode *N) {
15908   if (Level < AfterLegalizeDAG)
15909     return false;
15910 
15911   LoadSDNode *LD = cast<LoadSDNode>(N);
15912   if (!LD->isSimple() || !ISD::isNormalLoad(LD) ||
15913       !LD->getValueType(0).isInteger())
15914     return false;
15915 
15916   // The algorithm to split up a load of a scalable vector into individual
15917   // elements currently requires knowing the length of the loaded type,
15918   // so will need adjusting to work on scalable vectors.
15919   if (LD->getValueType(0).isScalableVector())
15920     return false;
15921 
15922   // Keep track of already used bits to detect overlapping values.
15923   // In that case, we will just abort the transformation.
15924   APInt UsedBits(LD->getValueSizeInBits(0), 0);
15925 
15926   SmallVector<LoadedSlice, 4> LoadedSlices;
15927 
15928   // Check if this load is used as several smaller chunks of bits.
15929   // Basically, look for uses in trunc or trunc(lshr) and record a new chain
15930   // of computation for each trunc.
15931   for (SDNode::use_iterator UI = LD->use_begin(), UIEnd = LD->use_end();
15932        UI != UIEnd; ++UI) {
15933     // Skip the uses of the chain.
15934     if (UI.getUse().getResNo() != 0)
15935       continue;
15936 
15937     SDNode *User = *UI;
15938     unsigned Shift = 0;
15939 
15940     // Check if this is a trunc(lshr).
15941     if (User->getOpcode() == ISD::SRL && User->hasOneUse() &&
15942         isa<ConstantSDNode>(User->getOperand(1))) {
15943       Shift = User->getConstantOperandVal(1);
15944       User = *User->use_begin();
15945     }
15946 
15947     // At this point, User is a TRUNCATE, since we encountered either a
15948     // plain trunc or a trunc(lshr).
15949     if (User->getOpcode() != ISD::TRUNCATE)
15950       return false;
15951 
15952     // The width of the type must be a power of 2 and at least 8 bits.
15953     // Otherwise the load cannot be represented in LLVM IR.
15954     // Moreover, if we shifted by an amount that is not a multiple of 8
15955     // bits, the slice would straddle byte boundaries. We do not support that.
15956     unsigned Width = User->getValueSizeInBits(0);
15957     if (Width < 8 || !isPowerOf2_32(Width) || (Shift & 0x7))
15958       return false;
15959 
15960     // Build the slice for this chain of computations.
15961     LoadedSlice LS(User, LD, Shift, &DAG);
15962     APInt CurrentUsedBits = LS.getUsedBits();
15963 
15964     // Check if this slice overlaps with another.
15965     if ((CurrentUsedBits & UsedBits) != 0)
15966       return false;
15967     // Update the bits used globally.
15968     UsedBits |= CurrentUsedBits;
15969 
15970     // Check if the new slice would be legal.
15971     if (!LS.isLegal())
15972       return false;
15973 
15974     // Record the slice.
15975     LoadedSlices.push_back(LS);
15976   }
15977 
15978   // Abort slicing if it does not seem to be profitable.
15979   if (!isSlicingProfitable(LoadedSlices, UsedBits, ForCodeSize))
15980     return false;
15981 
15982   ++SlicedLoads;
15983 
15984   // Rewrite each chain to use an independent load.
15985   // By construction, each chain can be represented by a unique load.
15986 
15987   // Prepare the argument for the new token factor for all the slices.
15988   SmallVector<SDValue, 8> ArgChains;
15989   for (SmallVectorImpl<LoadedSlice>::const_iterator
15990            LSIt = LoadedSlices.begin(),
15991            LSItEnd = LoadedSlices.end();
15992        LSIt != LSItEnd; ++LSIt) {
15993     SDValue SliceInst = LSIt->loadSlice();
15994     CombineTo(LSIt->Inst, SliceInst, true);
15995     if (SliceInst.getOpcode() != ISD::LOAD)
15996       SliceInst = SliceInst.getOperand(0);
15997     assert(SliceInst->getOpcode() == ISD::LOAD &&
15998            "It takes more than a zext to get to the loaded slice!!");
15999     ArgChains.push_back(SliceInst.getValue(1));
16000   }
16001 
16002   SDValue Chain = DAG.getNode(ISD::TokenFactor, SDLoc(LD), MVT::Other,
16003                               ArgChains);
16004   DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain);
16005   AddToWorklist(Chain.getNode());
16006   return true;
16007 }
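
// Illustrative sketch of the whole transformation (little endian assumed):
//   x = load i64 p
//   a = trunc x to i32
//   b = trunc (srl x, 32) to i32
// becomes, when the cost model above deems it profitable,
//   a = load i32 p
//   b = load i32 (p + 4)
// two independent loads that can live in different register banks.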
16008 
16009 /// Check to see if V is (and (load ptr), imm), where the load has
16010 /// specific bytes cleared out.  If so, return the byte size being masked out
16011 /// and the shift amount.
16012 static std::pair<unsigned, unsigned>
16013 CheckForMaskedLoad(SDValue V, SDValue Ptr, SDValue Chain) {
16014   std::pair<unsigned, unsigned> Result(0, 0);
16015 
16016   // Check for the structure we're looking for.
16017   if (V->getOpcode() != ISD::AND ||
16018       !isa<ConstantSDNode>(V->getOperand(1)) ||
16019       !ISD::isNormalLoad(V->getOperand(0).getNode()))
16020     return Result;
16021 
16022   // Check the chain and pointer.
16023   LoadSDNode *LD = cast<LoadSDNode>(V->getOperand(0));
16024   if (LD->getBasePtr() != Ptr) return Result;  // Not from same pointer.
16025 
16026   // This only handles simple types.
16027   if (V.getValueType() != MVT::i16 &&
16028       V.getValueType() != MVT::i32 &&
16029       V.getValueType() != MVT::i64)
16030     return Result;
16031 
16032   // Check the constant mask.  Invert it so that the bits being masked out are
16033   // 0 and the bits being kept are 1.  Use getSExtValue so that leading bits
16034   // follow the sign bit for uniformity.
16035   uint64_t NotMask = ~cast<ConstantSDNode>(V->getOperand(1))->getSExtValue();
16036   unsigned NotMaskLZ = countLeadingZeros(NotMask);
16037   if (NotMaskLZ & 7) return Result;  // Must be multiple of a byte.
16038   unsigned NotMaskTZ = countTrailingZeros(NotMask);
16039   if (NotMaskTZ & 7) return Result;  // Must be multiple of a byte.
16040   if (NotMaskLZ == 64) return Result;  // All zero mask.
16041 
16042   // See if we have a contiguous run of bits.  If so, we have 0*1+0*
16043   if (countTrailingOnes(NotMask >> NotMaskTZ) + NotMaskTZ + NotMaskLZ != 64)
16044     return Result;
16045 
16046   // Adjust NotMaskLZ down to be from the actual size of the int instead of i64.
16047   if (V.getValueType() != MVT::i64 && NotMaskLZ)
16048     NotMaskLZ -= 64-V.getValueSizeInBits();
16049 
16050   unsigned MaskedBytes = (V.getValueSizeInBits()-NotMaskLZ-NotMaskTZ)/8;
16051   switch (MaskedBytes) {
16052   case 1:
16053   case 2:
16054   case 4: break;
16055   default: return Result; // All one mask, or 5-byte mask.
16056   }
16057 
16058   // Verify that the masked region starts at a byte offset that is a multiple
16059   // of the mask size, so the access is aligned the same as the access width.
16060   if (NotMaskTZ && NotMaskTZ/8 % MaskedBytes) return Result;
16061 
16062   // For narrowing to be valid, the load must be the memory operation
16063   // immediately preceding the store.
16064   if (LD == Chain.getNode())
16065     ; // ok.
16066   else if (Chain->getOpcode() == ISD::TokenFactor &&
16067            SDValue(LD, 1).hasOneUse()) {
16068     // LD has only 1 chain use, so there are no indirect dependencies.
16069     if (!LD->isOperandOf(Chain.getNode()))
16070       return Result;
16071   } else
16072     return Result; // Fail.
16073 
16074   Result.first = MaskedBytes;
16075   Result.second = NotMaskTZ/8;
16076   return Result;
16077 }
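
// E.g. (illustrative): for 'V = and (load i32 p), 0xFFFF0000' the inverted
// mask is 0x0000FFFF, a contiguous two-byte run starting at byte 0, so the
// result is {MaskedBytes = 2, ByteShift = 0}: the two low bytes are the ones
// being replaced.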
16078 
16079 /// Check to see if IVal is something that provides a value as specified by
16080 /// MaskInfo. If so, replace the specified store with a narrower store of
16081 /// truncated IVal.
16082 static SDValue
16083 ShrinkLoadReplaceStoreWithStore(const std::pair<unsigned, unsigned> &MaskInfo,
16084                                 SDValue IVal, StoreSDNode *St,
16085                                 DAGCombiner *DC) {
16086   unsigned NumBytes = MaskInfo.first;
16087   unsigned ByteShift = MaskInfo.second;
16088   SelectionDAG &DAG = DC->getDAG();
16089 
16090   // Check to see if IVal is all zeros in the part being masked in by the 'or'
16091   // that uses this.  If not, this is not a replacement.
16092   APInt Mask = ~APInt::getBitsSet(IVal.getValueSizeInBits(),
16093                                   ByteShift*8, (ByteShift+NumBytes)*8);
16094   if (!DAG.MaskedValueIsZero(IVal, Mask)) return SDValue();
16095 
16096   // Check that it is legal on the target to do this.  It is legal if the new
16097   // VT we're shrinking to (i8/i16/i32) is legal or we're still before type
16098   // legalization (and the target doesn't explicitly think this is a bad idea).
16099   MVT VT = MVT::getIntegerVT(NumBytes * 8);
16100   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
16101   if (!DC->isTypeLegal(VT))
16102     return SDValue();
16103   if (St->getMemOperand() &&
16104       !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
16105                               *St->getMemOperand()))
16106     return SDValue();
16107 
16108   // Okay, we can do this!  Replace the 'St' store with a store of IVal that is
16109   // shifted by ByteShift and truncated down to NumBytes.
16110   if (ByteShift) {
16111     SDLoc DL(IVal);
16112     IVal = DAG.getNode(ISD::SRL, DL, IVal.getValueType(), IVal,
16113                        DAG.getConstant(ByteShift*8, DL,
16114                                     DC->getShiftAmountTy(IVal.getValueType())));
16115   }
16116 
16117   // Figure out the offset for the store and the alignment of the access.
16118   unsigned StOffset;
16119   if (DAG.getDataLayout().isLittleEndian())
16120     StOffset = ByteShift;
16121   else
16122     StOffset = IVal.getValueType().getStoreSize() - ByteShift - NumBytes;
16123 
16124   SDValue Ptr = St->getBasePtr();
16125   if (StOffset) {
16126     SDLoc DL(IVal);
16127     Ptr = DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(StOffset), DL);
16128   }
16129 
16130   // Truncate down to the new size.
16131   IVal = DAG.getNode(ISD::TRUNCATE, SDLoc(IVal), VT, IVal);
16132 
16133   ++OpsNarrowed;
16134   return DAG
16135       .getStore(St->getChain(), SDLoc(St), IVal, Ptr,
16136                 St->getPointerInfo().getWithOffset(StOffset),
16137                 St->getOriginalAlign());
16138 }
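
// Illustrative end-to-end sketch for the two helpers above (little endian,
// hypothetical values):
//   x = load i32 p
//   y = or (and x, 0xFFFFFF00), z   ; z known zero outside the low byte
//   store y, p
// can be rewritten as 'store (trunc z to i8), p', touching one byte instead
// of four and making the load dead.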
16139 
16140 /// Look for sequence of load / op / store where op is one of 'or', 'xor', and
16141 /// 'and' of immediates. If 'op' is only touching some of the loaded bits, try
16142 /// narrowing the load and store if it would end up being a win for performance
16143 /// or code size.
16144 SDValue DAGCombiner::ReduceLoadOpStoreWidth(SDNode *N) {
16145   StoreSDNode *ST  = cast<StoreSDNode>(N);
16146   if (!ST->isSimple())
16147     return SDValue();
16148 
16149   SDValue Chain = ST->getChain();
16150   SDValue Value = ST->getValue();
16151   SDValue Ptr   = ST->getBasePtr();
16152   EVT VT = Value.getValueType();
16153 
16154   if (ST->isTruncatingStore() || VT.isVector() || !Value.hasOneUse())
16155     return SDValue();
16156 
16157   unsigned Opc = Value.getOpcode();
16158 
16159   // If this is "store (or X, Y), P" and X is "(and (load P), cst)", where cst
16160   // is a byte mask indicating a consecutive number of bytes, check to see if
16161   // Y is known to provide just those bytes.  If so, we try to replace the
16162   // load + replace + store sequence with a single (narrower) store, which makes
16163   // the load dead.
16164   if (Opc == ISD::OR && EnableShrinkLoadReplaceStoreWithStore) {
16165     std::pair<unsigned, unsigned> MaskedLoad;
16166     MaskedLoad = CheckForMaskedLoad(Value.getOperand(0), Ptr, Chain);
16167     if (MaskedLoad.first)
16168       if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskedLoad,
16169                                                   Value.getOperand(1), ST,this))
16170         return NewST;
16171 
16172     // Or is commutative, so try swapping X and Y.
16173     MaskedLoad = CheckForMaskedLoad(Value.getOperand(1), Ptr, Chain);
16174     if (MaskedLoad.first)
16175       if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskedLoad,
16176                                                   Value.getOperand(0), ST,this))
16177         return NewST;
16178   }
16179 
16180   if (!EnableReduceLoadOpStoreWidth)
16181     return SDValue();
16182 
16183   if ((Opc != ISD::OR && Opc != ISD::XOR && Opc != ISD::AND) ||
16184       Value.getOperand(1).getOpcode() != ISD::Constant)
16185     return SDValue();
16186 
16187   SDValue N0 = Value.getOperand(0);
16188   if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
16189       Chain == SDValue(N0.getNode(), 1)) {
16190     LoadSDNode *LD = cast<LoadSDNode>(N0);
16191     if (LD->getBasePtr() != Ptr ||
16192         LD->getPointerInfo().getAddrSpace() !=
16193         ST->getPointerInfo().getAddrSpace())
16194       return SDValue();
16195 
16196     // Find the type to narrow the load / op / store to.
16197     SDValue N1 = Value.getOperand(1);
16198     unsigned BitWidth = N1.getValueSizeInBits();
16199     APInt Imm = cast<ConstantSDNode>(N1)->getAPIntValue();
16200     if (Opc == ISD::AND)
16201       Imm ^= APInt::getAllOnesValue(BitWidth);
16202     if (Imm == 0 || Imm.isAllOnesValue())
16203       return SDValue();
16204     unsigned ShAmt = Imm.countTrailingZeros();
16205     unsigned MSB = BitWidth - Imm.countLeadingZeros() - 1;
16206     unsigned NewBW = NextPowerOf2(MSB - ShAmt);
16207     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), NewBW);
16208     // The narrowing should be profitable, the load/store operation should be
16209     // legal (or custom) and the store size should be equal to the NewVT width.
16210     while (NewBW < BitWidth &&
16211            (NewVT.getStoreSizeInBits() != NewBW ||
16212             !TLI.isOperationLegalOrCustom(Opc, NewVT) ||
16213             !TLI.isNarrowingProfitable(VT, NewVT))) {
16214       NewBW = NextPowerOf2(NewBW);
16215       NewVT = EVT::getIntegerVT(*DAG.getContext(), NewBW);
16216     }
16217     if (NewBW >= BitWidth)
16218       return SDValue();
16219 
16220     // If the changed lsb does not start at a boundary of the new bitwidth,
16221     // start at the previous one.
16222     if (ShAmt % NewBW)
16223       ShAmt = (((ShAmt + NewBW - 1) / NewBW) * NewBW) - NewBW;
16224     APInt Mask = APInt::getBitsSet(BitWidth, ShAmt,
16225                                    std::min(BitWidth, ShAmt + NewBW));
16226     if ((Imm & Mask) == Imm) {
16227       APInt NewImm = (Imm & Mask).lshr(ShAmt).trunc(NewBW);
16228       if (Opc == ISD::AND)
16229         NewImm ^= APInt::getAllOnesValue(NewBW);
16230       uint64_t PtrOff = ShAmt / 8;
16231       // For big endian targets, we need to adjust the offset to the pointer to
16232       // load the correct bytes.
16233       if (DAG.getDataLayout().isBigEndian())
16234         PtrOff = (BitWidth + 7 - NewBW) / 8 - PtrOff;
16235 
16236       Align NewAlign = commonAlignment(LD->getAlign(), PtrOff);
16237       Type *NewVTTy = NewVT.getTypeForEVT(*DAG.getContext());
16238       if (NewAlign < DAG.getDataLayout().getABITypeAlign(NewVTTy))
16239         return SDValue();
16240 
16241       SDValue NewPtr =
16242           DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(PtrOff), SDLoc(LD));
16243       SDValue NewLD =
16244           DAG.getLoad(NewVT, SDLoc(N0), LD->getChain(), NewPtr,
16245                       LD->getPointerInfo().getWithOffset(PtrOff), NewAlign,
16246                       LD->getMemOperand()->getFlags(), LD->getAAInfo());
16247       SDValue NewVal = DAG.getNode(Opc, SDLoc(Value), NewVT, NewLD,
16248                                    DAG.getConstant(NewImm, SDLoc(Value),
16249                                                    NewVT));
16250       SDValue NewST =
16251           DAG.getStore(Chain, SDLoc(N), NewVal, NewPtr,
16252                        ST->getPointerInfo().getWithOffset(PtrOff), NewAlign);
16253 
16254       AddToWorklist(NewPtr.getNode());
16255       AddToWorklist(NewLD.getNode());
16256       AddToWorklist(NewVal.getNode());
16257       WorklistRemover DeadNodes(*this);
16258       DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), NewLD.getValue(1));
16259       ++OpsNarrowed;
16260       return NewST;
16261     }
16262   }
16263 
16264   return SDValue();
16265 }
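
// E.g. (illustrative): 'store (or (load i32 p), 0x00FF0000), p' can be
// narrowed to an i8 load / or / store at 'p + 2' on a little-endian target
// (p + 1 on a big-endian one), with the immediate shrunk to 0xFF.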
16266 
16267 /// For a given floating point load / store pair, if the load value isn't used
16268 /// by any other operations, then consider transforming the pair to integer
16269 /// load / store operations if the target deems the transformation profitable.
16270 SDValue DAGCombiner::TransformFPLoadStorePair(SDNode *N) {
16271   StoreSDNode *ST  = cast<StoreSDNode>(N);
16272   SDValue Value = ST->getValue();
16273   if (ISD::isNormalStore(ST) && ISD::isNormalLoad(Value.getNode()) &&
16274       Value.hasOneUse()) {
16275     LoadSDNode *LD = cast<LoadSDNode>(Value);
16276     EVT VT = LD->getMemoryVT();
16277     if (!VT.isFloatingPoint() ||
16278         VT != ST->getMemoryVT() ||
16279         LD->isNonTemporal() ||
16280         ST->isNonTemporal() ||
16281         LD->getPointerInfo().getAddrSpace() != 0 ||
16282         ST->getPointerInfo().getAddrSpace() != 0)
16283       return SDValue();
16284 
16285     TypeSize VTSize = VT.getSizeInBits();
16286 
16287     // We don't know the size of scalable types at compile time so we cannot
16288     // create an integer of the equivalent size.
16289     if (VTSize.isScalable())
16290       return SDValue();
16291 
16292     EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), VTSize.getFixedSize());
16293     if (!TLI.isOperationLegal(ISD::LOAD, IntVT) ||
16294         !TLI.isOperationLegal(ISD::STORE, IntVT) ||
16295         !TLI.isDesirableToTransformToIntegerOp(ISD::LOAD, VT) ||
16296         !TLI.isDesirableToTransformToIntegerOp(ISD::STORE, VT))
16297       return SDValue();
16298 
16299     Align LDAlign = LD->getAlign();
16300     Align STAlign = ST->getAlign();
16301     Type *IntVTTy = IntVT.getTypeForEVT(*DAG.getContext());
16302     Align ABIAlign = DAG.getDataLayout().getABITypeAlign(IntVTTy);
16303     if (LDAlign < ABIAlign || STAlign < ABIAlign)
16304       return SDValue();
16305 
16306     SDValue NewLD =
16307         DAG.getLoad(IntVT, SDLoc(Value), LD->getChain(), LD->getBasePtr(),
16308                     LD->getPointerInfo(), LDAlign);
16309 
16310     SDValue NewST =
16311         DAG.getStore(ST->getChain(), SDLoc(N), NewLD, ST->getBasePtr(),
16312                      ST->getPointerInfo(), STAlign);
16313 
16314     AddToWorklist(NewLD.getNode());
16315     AddToWorklist(NewST.getNode());
16316     WorklistRemover DeadNodes(*this);
16317     DAG.ReplaceAllUsesOfValueWith(Value.getValue(1), NewLD.getValue(1));
16318     ++LdStFP2Int;
16319     return NewST;
16320   }
16321 
16322   return SDValue();
16323 }
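
// E.g. (illustrative): a copy such as 'store f64 (load f64 p1), p2' becomes
// 'store i64 (load i64 p1), p2' on targets that report the integer forms as
// both legal and desirable, avoiding a floating-point register round trip.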
16324 
16325 // This is a helper function for visitMUL to check the profitability
16326 // of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
16327 // MulNode is the original multiply, AddNode is (add x, c1),
16328 // and ConstNode is c2.
16329 //
16330 // If the (add x, c1) has multiple uses, we could increase
16331 // the number of adds if we make this transformation.
16332 // It would only be worth doing this if we can remove a
16333 // multiply in the process. Check for that here.
16334 // To illustrate:
16335 //     (A + c1) * c3
16336 //     (A + c2) * c3
16337 // We're checking for cases where we have common "c3 * A" expressions.
16338 bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode,
16339                                               SDValue &AddNode,
16340                                               SDValue &ConstNode) {
16341   APInt Val;
16342 
16343   // If the add only has one use, this would be OK to do.
16344   if (AddNode.getNode()->hasOneUse())
16345     return true;
16346 
16347   // Walk all the users of the constant with which we're multiplying.
16348   for (SDNode *Use : ConstNode->uses()) {
16349     if (Use == MulNode) // This use is the one we're on right now. Skip it.
16350       continue;
16351 
16352     if (Use->getOpcode() == ISD::MUL) { // We have another multiply use.
16353       SDNode *OtherOp;
16354       SDNode *MulVar = AddNode.getOperand(0).getNode();
16355 
16356       // OtherOp is what we're multiplying against the constant.
16357       if (Use->getOperand(0) == ConstNode)
16358         OtherOp = Use->getOperand(1).getNode();
16359       else
16360         OtherOp = Use->getOperand(0).getNode();
16361 
16362       // Check to see if multiply is with the same operand of our "add".
16363       //
16364       //     ConstNode  = CONST
16365       //     Use = ConstNode * A  <-- visiting Use. OtherOp is A.
16366       //     ...
16367       //     AddNode  = (A + c1)  <-- MulVar is A.
16368       //         = AddNode * ConstNode   <-- current visiting instruction.
16369       //
16370       // If we make this transformation, we will have a common
16371       // multiply (ConstNode * A) that we can save.
16372       if (OtherOp == MulVar)
16373         return true;
16374 
16375       // Now check to see if a future expansion will give us a common
16376       // multiply.
16377       //
16378       //     ConstNode  = CONST
16379       //     AddNode    = (A + c1)
16380       //     ...   = AddNode * ConstNode <-- current visiting instruction.
16381       //     ...
16382       //     OtherOp = (A + c2)
16383       //     Use     = OtherOp * ConstNode <-- visiting Use.
16384       //
16385       // If we make this transformation, we will have a common
16386       // multiply (CONST * A) after we also do the same transformation
16387       // to the Use instruction.
16388       if (OtherOp->getOpcode() == ISD::ADD &&
16389           DAG.isConstantIntBuildVectorOrConstantInt(OtherOp->getOperand(1)) &&
16390           OtherOp->getOperand(0).getNode() == MulVar)
16391         return true;
16392     }
16393   }
16394 
16395   // Didn't find a case where this would be profitable.
16396   return false;
16397 }
16398 
16399 SDValue DAGCombiner::getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
16400                                          unsigned NumStores) {
16401   SmallVector<SDValue, 8> Chains;
16402   SmallPtrSet<const SDNode *, 8> Visited;
16403   SDLoc StoreDL(StoreNodes[0].MemNode);
16404 
16405   for (unsigned i = 0; i < NumStores; ++i) {
16406     Visited.insert(StoreNodes[i].MemNode);
16407   }
16408 
16409   // Don't include nodes that are children or repeated nodes.
16410   for (unsigned i = 0; i < NumStores; ++i) {
16411     if (Visited.insert(StoreNodes[i].MemNode->getChain().getNode()).second)
16412       Chains.push_back(StoreNodes[i].MemNode->getChain());
16413   }
16414 
16415   assert(!Chains.empty() && "Merged stores should have at least one chain");
16416   return DAG.getTokenFactor(StoreDL, Chains);
16417 }
16418 
16419 bool DAGCombiner::mergeStoresOfConstantsOrVecElts(
16420     SmallVectorImpl<MemOpLink> &StoreNodes, EVT MemVT, unsigned NumStores,
16421     bool IsConstantSrc, bool UseVector, bool UseTrunc) {
16422   // Make sure we have something to merge.
16423   if (NumStores < 2)
16424     return false;
16425 
16426   // Use the location of the first store for all newly created nodes.
16427   SDLoc DL(StoreNodes[0].MemNode);
16428 
16429   TypeSize ElementSizeBits = MemVT.getStoreSizeInBits();
16430   unsigned SizeInBits = NumStores * ElementSizeBits;
16431   unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
16432 
16433   EVT StoreTy;
16434   if (UseVector) {
16435     unsigned Elts = NumStores * NumMemElts;
16436     // Get the type for the merged vector store.
16437     StoreTy = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
16438   } else
16439     StoreTy = EVT::getIntegerVT(*DAG.getContext(), SizeInBits);
16440 
16441   SDValue StoredVal;
16442   if (UseVector) {
16443     if (IsConstantSrc) {
16444       SmallVector<SDValue, 8> BuildVector;
16445       for (unsigned I = 0; I != NumStores; ++I) {
16446         StoreSDNode *St = cast<StoreSDNode>(StoreNodes[I].MemNode);
16447         SDValue Val = St->getValue();
16448         // If the constant is of the wrong type, convert it now.
16449         if (MemVT != Val.getValueType()) {
16450           Val = peekThroughBitcasts(Val);
16451           // Deal with constants of wrong size.
16452           if (ElementSizeBits != Val.getValueSizeInBits()) {
16453             EVT IntMemVT =
16454                 EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits());
16455             if (isa<ConstantFPSDNode>(Val)) {
16456               // Not clear how to truncate FP values.
16457               return false;
16458             } else if (auto *C = dyn_cast<ConstantSDNode>(Val))
16459               Val = DAG.getConstant(C->getAPIntValue()
16460                                         .zextOrTrunc(Val.getValueSizeInBits())
16461                                         .zextOrTrunc(ElementSizeBits),
16462                                     SDLoc(C), IntMemVT);
16463           }
16464           // Make sure the correctly-sized value is bitcast to the correct type.
16465           Val = DAG.getBitcast(MemVT, Val);
16466         }
16467         BuildVector.push_back(Val);
16468       }
16469       StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS
16470                                                : ISD::BUILD_VECTOR,
16471                               DL, StoreTy, BuildVector);
16472     } else {
16473       SmallVector<SDValue, 8> Ops;
16474       for (unsigned i = 0; i < NumStores; ++i) {
16475         StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
16476         SDValue Val = peekThroughBitcasts(St->getValue());
16477         // All operands of BUILD_VECTOR / CONCAT_VECTORS must be of
16478         // type MemVT. If the underlying value is not the correct
16479         // type, but it is an extraction of an appropriate vector we
16480         // can recast Val to be of the correct type. This may require
16481         // converting between EXTRACT_VECTOR_ELT and
16482         // EXTRACT_SUBVECTOR.
16483         if ((MemVT != Val.getValueType()) &&
16484             (Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
16485              Val.getOpcode() == ISD::EXTRACT_SUBVECTOR)) {
16486           EVT MemVTScalarTy = MemVT.getScalarType();
16487           // We may need to add a bitcast here to get types to line up.
16488           if (MemVTScalarTy != Val.getValueType().getScalarType()) {
16489             Val = DAG.getBitcast(MemVT, Val);
16490           } else {
16491             unsigned OpC = MemVT.isVector() ? ISD::EXTRACT_SUBVECTOR
16492                                             : ISD::EXTRACT_VECTOR_ELT;
16493             SDValue Vec = Val.getOperand(0);
16494             SDValue Idx = Val.getOperand(1);
16495             Val = DAG.getNode(OpC, SDLoc(Val), MemVT, Vec, Idx);
16496           }
16497         }
16498         Ops.push_back(Val);
16499       }
16500 
16501       // Build the extracted vector elements back into a vector.
16502       StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS
16503                                                : ISD::BUILD_VECTOR,
16504                               DL, StoreTy, Ops);
16505     }
16506   } else {
16507     // We should always use a vector store when merging extracted vector
16508     // elements, so this path implies a store of constants.
16509     assert(IsConstantSrc && "Merged vector elements should use vector store");
16510 
16511     APInt StoreInt(SizeInBits, 0);
16512 
16513     // Construct a single integer constant which is made of the smaller
16514     // constant inputs.
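    // For example (hypothetical constants, little-endian): merging i16 stores
    // of 0xAAAA (offset 0) and 0xBBBB (offset 2) produces the i32 constant
    // 0xBBBBAAAA, whose in-memory byte order matches the original stores.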
16515     bool IsLE = DAG.getDataLayout().isLittleEndian();
16516     for (unsigned i = 0; i < NumStores; ++i) {
16517       unsigned Idx = IsLE ? (NumStores - 1 - i) : i;
16518       StoreSDNode *St  = cast<StoreSDNode>(StoreNodes[Idx].MemNode);
16519 
16520       SDValue Val = St->getValue();
16521       Val = peekThroughBitcasts(Val);
16522       StoreInt <<= ElementSizeBits;
16523       if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val)) {
16524         StoreInt |= C->getAPIntValue()
16525                         .zextOrTrunc(ElementSizeBits)
16526                         .zextOrTrunc(SizeInBits);
16527       } else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Val)) {
16528         StoreInt |= C->getValueAPF()
16529                         .bitcastToAPInt()
16530                         .zextOrTrunc(ElementSizeBits)
16531                         .zextOrTrunc(SizeInBits);
16532         // If fp truncation is necessary give up for now.
16533         if (MemVT.getSizeInBits() != ElementSizeBits)
16534           return false;
16535       } else {
16536         llvm_unreachable("Invalid constant element type");
16537       }
16538     }
16539 
16540     // Create the single merged integer constant that will be stored.
16541     StoredVal = DAG.getConstant(StoreInt, DL, StoreTy);
16542   }
16543 
16544   LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
16545   SDValue NewChain = getMergeStoreChains(StoreNodes, NumStores);
16546 
16547   // Make sure we use a trunc store if it's necessary to be legal.
16548   SDValue NewStore;
16549   if (!UseTrunc) {
16550     NewStore =
16551         DAG.getStore(NewChain, DL, StoredVal, FirstInChain->getBasePtr(),
16552                      FirstInChain->getPointerInfo(), FirstInChain->getAlign());
16553   } else { // Must be realized as a trunc store
16554     EVT LegalizedStoredValTy =
16555         TLI.getTypeToTransformTo(*DAG.getContext(), StoredVal.getValueType());
16556     unsigned LegalizedStoreSize = LegalizedStoredValTy.getSizeInBits();
16557     ConstantSDNode *C = cast<ConstantSDNode>(StoredVal);
16558     SDValue ExtendedStoreVal =
16559         DAG.getConstant(C->getAPIntValue().zextOrTrunc(LegalizedStoreSize), DL,
16560                         LegalizedStoredValTy);
16561     NewStore = DAG.getTruncStore(
16562         NewChain, DL, ExtendedStoreVal, FirstInChain->getBasePtr(),
16563         FirstInChain->getPointerInfo(), StoredVal.getValueType() /*TVT*/,
16564         FirstInChain->getAlign(), FirstInChain->getMemOperand()->getFlags());
16565   }
16566 
16567   // Replace all merged stores with the new store.
16568   for (unsigned i = 0; i < NumStores; ++i)
16569     CombineTo(StoreNodes[i].MemNode, NewStore);
16570 
16571   AddToWorklist(NewChain.getNode());
16572   return true;
16573 }
16574 
16575 void DAGCombiner::getStoreMergeCandidates(
16576     StoreSDNode *St, SmallVectorImpl<MemOpLink> &StoreNodes,
16577     SDNode *&RootNode) {
16578   // This holds the base pointer, index, and the offset in bytes from the base
16579   // pointer. We must have a base and an offset. Do not handle stores to undef
16580   // base pointers.
16581   BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
16582   if (!BasePtr.getBase().getNode() || BasePtr.getBase().isUndef())
16583     return;
16584 
16585   SDValue Val = peekThroughBitcasts(St->getValue());
16586   StoreSource StoreSrc = getStoreSource(Val);
16587   assert(StoreSrc != StoreSource::Unknown && "Expected known source for store");
16588 
16589   // Match on loadbaseptr if relevant.
16590   EVT MemVT = St->getMemoryVT();
16591   BaseIndexOffset LBasePtr;
16592   EVT LoadVT;
16593   if (StoreSrc == StoreSource::Load) {
16594     auto *Ld = cast<LoadSDNode>(Val);
16595     LBasePtr = BaseIndexOffset::match(Ld, DAG);
16596     LoadVT = Ld->getMemoryVT();
16597     // Load and store should be the same type.
16598     if (MemVT != LoadVT)
16599       return;
16600     // Loads must only have one use.
16601     if (!Ld->hasNUsesOfValue(1, 0))
16602       return;
16603     // The memory operands must not be volatile/indexed/atomic.
16604     // TODO: May be able to relax for unordered atomics (see D66309)
16605     if (!Ld->isSimple() || Ld->isIndexed())
16606       return;
16607   }
16608   auto CandidateMatch = [&](StoreSDNode *Other, BaseIndexOffset &Ptr,
16609                             int64_t &Offset) -> bool {
16610     // The memory operands must not be volatile/indexed/atomic.
16611     // TODO: May be able to relax for unordered atomics (see D66309)
16612     if (!Other->isSimple() || Other->isIndexed())
16613       return false;
16614     // Don't mix temporal stores with non-temporal stores.
16615     if (St->isNonTemporal() != Other->isNonTemporal())
16616       return false;
16617     SDValue OtherBC = peekThroughBitcasts(Other->getValue());
16618     // Allow merging constants of different types as integers.
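    // (E.g., for an integer MemVT an i32 store can be merged with a store
    // whose memory type is f32, since only the bit widths must match.)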
16619     bool NoTypeMatch = (MemVT.isInteger()) ? !MemVT.bitsEq(Other->getMemoryVT())
16620                                            : Other->getMemoryVT() != MemVT;
16621     switch (StoreSrc) {
16622     case StoreSource::Load: {
16623       if (NoTypeMatch)
16624         return false;
16625       // The Load's Base Ptr must also match.
16626       auto *OtherLd = dyn_cast<LoadSDNode>(OtherBC);
16627       if (!OtherLd)
16628         return false;
16629       BaseIndexOffset LPtr = BaseIndexOffset::match(OtherLd, DAG);
16630       if (LoadVT != OtherLd->getMemoryVT())
16631         return false;
16632       // Loads must only have one use.
16633       if (!OtherLd->hasNUsesOfValue(1, 0))
16634         return false;
16635       // The memory operands must not be volatile/indexed/atomic.
16636       // TODO: May be able to relax for unordered atomics (see D66309)
16637       if (!OtherLd->isSimple() || OtherLd->isIndexed())
16638         return false;
16639       // Don't mix temporal loads with non-temporal loads.
16640       if (cast<LoadSDNode>(Val)->isNonTemporal() != OtherLd->isNonTemporal())
16641         return false;
16642       if (!(LBasePtr.equalBaseIndex(LPtr, DAG)))
16643         return false;
16644       break;
16645     }
16646     case StoreSource::Constant:
16647       if (NoTypeMatch)
16648         return false;
16649       if (!(isa<ConstantSDNode>(OtherBC) || isa<ConstantFPSDNode>(OtherBC)))
16650         return false;
16651       break;
16652     case StoreSource::Extract:
16653       // Do not merge truncated stores here.
16654       if (Other->isTruncatingStore())
16655         return false;
16656       if (!MemVT.bitsEq(OtherBC.getValueType()))
16657         return false;
16658       if (OtherBC.getOpcode() != ISD::EXTRACT_VECTOR_ELT &&
16659           OtherBC.getOpcode() != ISD::EXTRACT_SUBVECTOR)
16660         return false;
16661       break;
16662     default:
16663       llvm_unreachable("Unhandled store source for merging");
16664     }
16665     Ptr = BaseIndexOffset::match(Other, DAG);
16666     return (BasePtr.equalBaseIndex(Ptr, DAG, Offset));
16667   };
16668 
16669   // Check if the pair of StoreNode and RootNode has already bailed out of
16670   // the dependence check more times than the limit allows.
16671   auto OverLimitInDependenceCheck = [&](SDNode *StoreNode,
16672                                         SDNode *RootNode) -> bool {
16673     auto RootCount = StoreRootCountMap.find(StoreNode);
16674     return RootCount != StoreRootCountMap.end() &&
16675            RootCount->second.first == RootNode &&
16676            RootCount->second.second > StoreMergeDependenceLimit;
16677   };
16678 
16679   auto TryToAddCandidate = [&](SDNode::use_iterator UseIter) {
16680     // This must be a chain use.
16681     if (UseIter.getOperandNo() != 0)
16682       return;
16683     if (auto *OtherStore = dyn_cast<StoreSDNode>(*UseIter)) {
16684       BaseIndexOffset Ptr;
16685       int64_t PtrDiff;
16686       if (CandidateMatch(OtherStore, Ptr, PtrDiff) &&
16687           !OverLimitInDependenceCheck(OtherStore, RootNode))
16688         StoreNodes.push_back(MemOpLink(OtherStore, PtrDiff));
16689     }
16690   };
16691 
16692   // We are looking for a root node which is an ancestor to all mergeable
16693   // stores. We search up through a load, to our root and then down
16694   // through all children. For instance we will find Store{1,2,3} if
16695   // St is Store1, Store2, or Store3 where the root is not a load
16696   // (which is always true for non-volatile ops). TODO: Expand
16697   // the search to find all valid candidates through multiple layers of loads.
16698   //
16699   // Root
16700   // |-------|-------|
16701   // Load    Load    Store3
16702   // |       |
16703   // Store1   Store2
16704   //
16705   // FIXME: We should be able to climb and
16706   // descend TokenFactors to find candidates as well.
16707 
16708   RootNode = St->getChain().getNode();
16709 
16710   unsigned NumNodesExplored = 0;
16711   const unsigned MaxSearchNodes = 1024;
16712   if (auto *Ldn = dyn_cast<LoadSDNode>(RootNode)) {
16713     RootNode = Ldn->getChain().getNode();
16714     for (auto I = RootNode->use_begin(), E = RootNode->use_end();
16715          I != E && NumNodesExplored < MaxSearchNodes; ++I, ++NumNodesExplored) {
16716       if (I.getOperandNo() == 0 && isa<LoadSDNode>(*I)) { // walk down chain
16717         for (auto I2 = (*I)->use_begin(), E2 = (*I)->use_end(); I2 != E2; ++I2)
16718           TryToAddCandidate(I2);
16719       }
16720     }
16721   } else {
16722     for (auto I = RootNode->use_begin(), E = RootNode->use_end();
16723          I != E && NumNodesExplored < MaxSearchNodes; ++I, ++NumNodesExplored)
16724       TryToAddCandidate(I);
16725   }
16726 }
16727 
16728 // We need to check that merging these stores does not cause a loop in
16729 // the DAG. Any store candidate may depend on another candidate
16730 // indirectly through its operand (we already consider dependencies
16731 // through the chain). Check in parallel by searching up from
16732 // non-chain operands of candidates.
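//
// A sketch of the hazard with hypothetical nodes: if candidate store S1
// stores a value loaded by L, and L is chained after candidate store S2,
// then a store merged from S1 and S2 would depend on L while L depends on
// the merged store, forming a cycle.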
16733 bool DAGCombiner::checkMergeStoreCandidatesForDependencies(
16734     SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
16735     SDNode *RootNode) {
16736   // FIXME: We should be able to truncate a full search of
16737   // predecessors by doing a BFS and keeping tabs on the originating
16738   // stores from which worklist nodes come, in a similar way to
16739   // TokenFactor simplification.
16740 
16741   SmallPtrSet<const SDNode *, 32> Visited;
16742   SmallVector<const SDNode *, 8> Worklist;
16743 
16744   // RootNode is a predecessor to all candidates so we need not search
16745   // past it. Add RootNode (peeking through TokenFactors). Do not count
16746   // these towards the size check.
16747 
16748   Worklist.push_back(RootNode);
16749   while (!Worklist.empty()) {
16750     auto N = Worklist.pop_back_val();
16751     if (!Visited.insert(N).second)
16752       continue; // Already present in Visited.
16753     if (N->getOpcode() == ISD::TokenFactor) {
16754       for (SDValue Op : N->ops())
16755         Worklist.push_back(Op.getNode());
16756     }
16757   }
16758 
16759   // Don't count pruning nodes towards max.
16760   unsigned Max = 1024 + Visited.size();
16761   // Search Ops of store candidates.
16762   for (unsigned i = 0; i < NumStores; ++i) {
16763     SDNode *N = StoreNodes[i].MemNode;
16764     // Of the 4 Store Operands:
16765     //   * Chain (Op 0) -> We have already considered these in
16766     //                     candidate selection, so they can be
16767     //                     safely ignored.
16768     //   * Value (Op 1) -> Cycles may happen (e.g. through load chains).
16769     //   * Address (Op 2) -> Merged addresses may only vary by a fixed constant,
16770     //                       but aren't necessarily from the same base node, so
16771     //                       cycles are possible (e.g. via an indexed store).
16772     //   * (Op 3) -> Represents the pre or post-indexing offset (or undef for
16773     //               non-indexed stores). Not constant on all targets (e.g. ARM)
16774     //               and so can participate in a cycle.
16775     for (unsigned j = 1; j < N->getNumOperands(); ++j)
16776       Worklist.push_back(N->getOperand(j).getNode());
16777   }
16778   // Search through DAG. We can stop early if we find a store node.
16779   for (unsigned i = 0; i < NumStores; ++i)
16780     if (SDNode::hasPredecessorHelper(StoreNodes[i].MemNode, Visited, Worklist,
16781                                      Max)) {
16782       // If the search bails out, record the StoreNode and RootNode in the
16783       // StoreRootCountMap. Once we have seen the pair more times than the
16784       // limit, we won't add the StoreNode into the StoreNodes set again.
16785       if (Visited.size() >= Max) {
16786         auto &RootCount = StoreRootCountMap[StoreNodes[i].MemNode];
16787         if (RootCount.first == RootNode)
16788           RootCount.second++;
16789         else
16790           RootCount = {RootNode, 1};
16791       }
16792       return false;
16793     }
16794   return true;
16795 }
16796 
16797 unsigned
16798 DAGCombiner::getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
16799                                   int64_t ElementSizeBytes) const {
16800   while (true) {
16801     // Find the first pair of adjacent candidates with consecutive addresses.
16802     size_t StartIdx = 0;
16803     while ((StartIdx + 1 < StoreNodes.size()) &&
16804            StoreNodes[StartIdx].OffsetFromBase + ElementSizeBytes !=
16805               StoreNodes[StartIdx + 1].OffsetFromBase)
16806       ++StartIdx;
16807 
16808     // Bail if we don't have enough candidates to merge.
16809     if (StartIdx + 1 >= StoreNodes.size())
16810       return 0;
16811 
16812     // Trim stores that overlapped with the first store.
16813     if (StartIdx)
16814       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + StartIdx);
16815 
16816     // Scan the memory operations on the chain and find the first
16817     // non-consecutive store memory address.
16818     unsigned NumConsecutiveStores = 1;
16819     int64_t StartAddress = StoreNodes[0].OffsetFromBase;
16820     // Check that the addresses are consecutive starting from the second
16821     // element in the list of stores.
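    // E.g., with hypothetical offsets {0, 4, 8, 20} and ElementSizeBytes of
    // 4, the first three stores form a run and this returns 3.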
16822     for (unsigned i = 1, e = StoreNodes.size(); i < e; ++i) {
16823       int64_t CurrAddress = StoreNodes[i].OffsetFromBase;
16824       if (CurrAddress - StartAddress != (ElementSizeBytes * i))
16825         break;
16826       NumConsecutiveStores = i + 1;
16827     }
16828     if (NumConsecutiveStores > 1)
16829       return NumConsecutiveStores;
16830 
16831     // There are no consecutive stores at the start of the list.
16832     // Remove the first store and try again.
16833     StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 1);
16834   }
16835 }
16836 
16837 bool DAGCombiner::tryStoreMergeOfConstants(
16838     SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores,
16839     EVT MemVT, SDNode *RootNode, bool AllowVectors) {
16840   LLVMContext &Context = *DAG.getContext();
16841   const DataLayout &DL = DAG.getDataLayout();
16842   int64_t ElementSizeBytes = MemVT.getStoreSize();
16843   unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
16844   bool MadeChange = false;
16845 
16846   // Store the constants into memory as one consecutive store.
16847   while (NumConsecutiveStores >= 2) {
16848     LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
16849     unsigned FirstStoreAS = FirstInChain->getAddressSpace();
16850     unsigned FirstStoreAlign = FirstInChain->getAlignment();
16851     unsigned LastLegalType = 1;
16852     unsigned LastLegalVectorType = 1;
16853     bool LastIntegerTrunc = false;
16854     bool NonZero = false;
16855     unsigned FirstZeroAfterNonZero = NumConsecutiveStores;
16856     for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
16857       StoreSDNode *ST = cast<StoreSDNode>(StoreNodes[i].MemNode);
16858       SDValue StoredVal = ST->getValue();
16859       bool IsElementZero = false;
16860       if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(StoredVal))
16861         IsElementZero = C->isNullValue();
16862       else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(StoredVal))
16863         IsElementZero = C->getConstantFPValue()->isNullValue();
16864       if (IsElementZero) {
16865         if (NonZero && FirstZeroAfterNonZero == NumConsecutiveStores)
16866           FirstZeroAfterNonZero = i;
16867       }
16868       NonZero |= !IsElementZero;
16869 
16870       // Find a legal type for the constant store.
16871       unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
16872       EVT StoreTy = EVT::getIntegerVT(Context, SizeInBits);
16873       bool IsFast = false;
16874 
16875       // Break early when size is too large to be legal.
16876       if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
16877         break;
16878 
16879       if (TLI.isTypeLegal(StoreTy) &&
16880           TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) &&
16881           TLI.allowsMemoryAccess(Context, DL, StoreTy,
16882                                  *FirstInChain->getMemOperand(), &IsFast) &&
16883           IsFast) {
16884         LastIntegerTrunc = false;
16885         LastLegalType = i + 1;
16886         // Or check whether a truncstore is legal.
16887       } else if (TLI.getTypeAction(Context, StoreTy) ==
16888                  TargetLowering::TypePromoteInteger) {
16889         EVT LegalizedStoredValTy =
16890             TLI.getTypeToTransformTo(Context, StoredVal.getValueType());
16891         if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) &&
16892             TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy, DAG) &&
16893             TLI.allowsMemoryAccess(Context, DL, StoreTy,
16894                                    *FirstInChain->getMemOperand(), &IsFast) &&
16895             IsFast) {
16896           LastIntegerTrunc = true;
16897           LastLegalType = i + 1;
16898         }
16899       }
16900 
16901       // We only use vectors if the constant is known to be zero or the
16902       // target allows it and the function is not marked with the
16903       // noimplicitfloat attribute.
16904       if ((!NonZero ||
16905            TLI.storeOfVectorConstantIsCheap(MemVT, i + 1, FirstStoreAS)) &&
16906           AllowVectors) {
16907         // Find a legal type for the vector store.
16908         unsigned Elts = (i + 1) * NumMemElts;
16909         EVT Ty = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
16910         if (TLI.isTypeLegal(Ty) && TLI.isTypeLegal(MemVT) &&
16911             TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG) &&
16912             TLI.allowsMemoryAccess(Context, DL, Ty,
16913                                    *FirstInChain->getMemOperand(), &IsFast) &&
16914             IsFast)
16915           LastLegalVectorType = i + 1;
16916       }
16917     }
16918 
16919     bool UseVector = (LastLegalVectorType > LastLegalType) && AllowVectors;
16920     unsigned NumElem = (UseVector) ? LastLegalVectorType : LastLegalType;
16921 
16922     // Check if we found a legal type that creates a meaningful
16923     // merge.
16924     if (NumElem < 2) {
16925       // We know that candidate stores are in order and of correct
16926       // shape. While there may be no mergeable sequence at the
16927       // beginning, one may start later in the sequence. The only
16928       // reason a merge of size N could have failed where another of
16929       // the same size would not have, is if the alignment has
16930       // improved or we've dropped a non-zero value. Drop as many
16931       // candidates as we can here.
16932       unsigned NumSkip = 1;
16933       while ((NumSkip < NumConsecutiveStores) &&
16934              (NumSkip < FirstZeroAfterNonZero) &&
16935              (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign))
16936         NumSkip++;
16937 
16938       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
16939       NumConsecutiveStores -= NumSkip;
16940       continue;
16941     }
16942 
16943     // Check that we can merge these candidates without causing a cycle.
16944     if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem,
16945                                                   RootNode)) {
16946       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
16947       NumConsecutiveStores -= NumElem;
16948       continue;
16949     }
16950 
16951     MadeChange |= mergeStoresOfConstantsOrVecElts(
16952         StoreNodes, MemVT, NumElem, true, UseVector, LastIntegerTrunc);
16953 
16954     // Remove merged stores for next iteration.
16955     StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
16956     NumConsecutiveStores -= NumElem;
16957   }
16958   return MadeChange;
16959 }
16960 
16961 bool DAGCombiner::tryStoreMergeOfExtracts(
16962     SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores,
16963     EVT MemVT, SDNode *RootNode) {
16964   LLVMContext &Context = *DAG.getContext();
16965   const DataLayout &DL = DAG.getDataLayout();
16966   unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
16967   bool MadeChange = false;
16968 
16969   // Keep looping over the consecutive stores while merging succeeds.
16970   while (NumConsecutiveStores >= 2) {
16971     LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
16972     unsigned FirstStoreAS = FirstInChain->getAddressSpace();
16973     unsigned FirstStoreAlign = FirstInChain->getAlignment();
16974     unsigned NumStoresToMerge = 1;
16975     for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
16976       // Find a legal type for the vector store.
16977       unsigned Elts = (i + 1) * NumMemElts;
16978       EVT Ty = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
16979       bool IsFast = false;
16980 
16981       // Break early when size is too large to be legal.
16982       if (Ty.getSizeInBits() > MaximumLegalStoreInBits)
16983         break;
16984 
16985       if (TLI.isTypeLegal(Ty) && TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG) &&
16986           TLI.allowsMemoryAccess(Context, DL, Ty,
16987                                  *FirstInChain->getMemOperand(), &IsFast) &&
16988           IsFast)
16989         NumStoresToMerge = i + 1;
16990     }
16991 
16992     // Check if we found a legal vector type that creates a meaningful
16993     // merge.
16994     if (NumStoresToMerge < 2) {
16995       // We know that candidate stores are in order and of correct
16996       // shape. While there may be no mergeable sequence at the
16997       // beginning, one may start later in the sequence. The only
16998       // reason a merge of size N could have failed where another of
16999       // the same size would not have, is if the alignment has
17000       // improved. Drop as many candidates as we can here.
17001       unsigned NumSkip = 1;
17002       while ((NumSkip < NumConsecutiveStores) &&
17003              (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign))
17004         NumSkip++;
17005 
17006       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
17007       NumConsecutiveStores -= NumSkip;
17008       continue;
17009     }
17010 
17011     // Check that we can merge these candidates without causing a cycle.
17012     if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumStoresToMerge,
17013                                                   RootNode)) {
17014       StoreNodes.erase(StoreNodes.begin(),
17015                        StoreNodes.begin() + NumStoresToMerge);
17016       NumConsecutiveStores -= NumStoresToMerge;
17017       continue;
17018     }
17019 
17020     MadeChange |= mergeStoresOfConstantsOrVecElts(
17021         StoreNodes, MemVT, NumStoresToMerge, false, true, false);
17022 
17023     StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumStoresToMerge);
17024     NumConsecutiveStores -= NumStoresToMerge;
17025   }
17026   return MadeChange;
17027 }
17028 
17029 bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
17030                                        unsigned NumConsecutiveStores, EVT MemVT,
17031                                        SDNode *RootNode, bool AllowVectors,
17032                                        bool IsNonTemporalStore,
17033                                        bool IsNonTemporalLoad) {
17034   LLVMContext &Context = *DAG.getContext();
17035   const DataLayout &DL = DAG.getDataLayout();
17036   int64_t ElementSizeBytes = MemVT.getStoreSize();
17037   unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
17038   bool MadeChange = false;
17039 
17040   int64_t StartAddress = StoreNodes[0].OffsetFromBase;
17041 
17042   // Look for load nodes which are used by the stored values.
17043   SmallVector<MemOpLink, 8> LoadNodes;
17044 
17045   // Find acceptable loads. Loads need to have the same chain (token factor),
17046   // must not be zext, volatile, or indexed, and they must be consecutive.
17047   BaseIndexOffset LdBasePtr;
17048 
17049   for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
17050     StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
17051     SDValue Val = peekThroughBitcasts(St->getValue());
17052     LoadSDNode *Ld = cast<LoadSDNode>(Val);
17053 
17054     BaseIndexOffset LdPtr = BaseIndexOffset::match(Ld, DAG);
17055     // If this is not the first ptr that we check.
17056     int64_t LdOffset = 0;
17057     if (LdBasePtr.getBase().getNode()) {
17058       // The base ptr must be the same.
17059       if (!LdBasePtr.equalBaseIndex(LdPtr, DAG, LdOffset))
17060         break;
17061     } else {
17062       // Check that all other base pointers are the same as this one.
17063       LdBasePtr = LdPtr;
17064     }
17065 
17066     // We found a potential memory operand to merge.
17067     LoadNodes.push_back(MemOpLink(Ld, LdOffset));
17068   }
17069 
17070   while (NumConsecutiveStores >= 2 && LoadNodes.size() >= 2) {
17071     Align RequiredAlignment;
17072     bool NeedRotate = false;
17073     if (LoadNodes.size() == 2) {
17074       // If we have load/store pair instructions and we only have two values,
17075       // don't bother merging.
17076       if (TLI.hasPairedLoad(MemVT, RequiredAlignment) &&
17077           StoreNodes[0].MemNode->getAlign() >= RequiredAlignment) {
17078         StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 2);
17079         LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + 2);
17080         break;
17081       }
17082       // If the loads are reversed, see if we can rotate the halves into place.
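      // E.g., with hypothetical loads at offsets {4, 0} feeding stores at
      // offsets {0, 4}, one wide load at offset 0 rotated by half its width
      // produces the value for one wide store.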
17083       int64_t Offset0 = LoadNodes[0].OffsetFromBase;
17084       int64_t Offset1 = LoadNodes[1].OffsetFromBase;
17085       EVT PairVT = EVT::getIntegerVT(Context, ElementSizeBytes * 8 * 2);
17086       if (Offset0 - Offset1 == ElementSizeBytes &&
17087           (hasOperation(ISD::ROTL, PairVT) ||
17088            hasOperation(ISD::ROTR, PairVT))) {
17089         std::swap(LoadNodes[0], LoadNodes[1]);
17090         NeedRotate = true;
17091       }
17092     }
17093     LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
17094     unsigned FirstStoreAS = FirstInChain->getAddressSpace();
17095     Align FirstStoreAlign = FirstInChain->getAlign();
17096     LoadSDNode *FirstLoad = cast<LoadSDNode>(LoadNodes[0].MemNode);
17097 
17098     // Scan the memory operations on the chain and find the first
17099     // non-consecutive load memory address. This variable holds the index in
17100     // the store node array.
17101 
17102     unsigned LastConsecutiveLoad = 1;
17103 
17104     // These variables refer to sizes, not indices in the array.
17105     unsigned LastLegalVectorType = 1;
17106     unsigned LastLegalIntegerType = 1;
17107     bool isDereferenceable = true;
17108     bool DoIntegerTruncate = false;
17109     StartAddress = LoadNodes[0].OffsetFromBase;
17110     SDValue LoadChain = FirstLoad->getChain();
17111     for (unsigned i = 1; i < LoadNodes.size(); ++i) {
17112       // All loads must share the same chain.
17113       if (LoadNodes[i].MemNode->getChain() != LoadChain)
17114         break;
17115 
17116       int64_t CurrAddress = LoadNodes[i].OffsetFromBase;
17117       if (CurrAddress - StartAddress != (ElementSizeBytes * i))
17118         break;
17119       LastConsecutiveLoad = i;
17120 
17121       if (isDereferenceable && !LoadNodes[i].MemNode->isDereferenceable())
17122         isDereferenceable = false;
17123 
17124       // Find a legal type for the vector store.
17125       unsigned Elts = (i + 1) * NumMemElts;
17126       EVT StoreTy = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
17127 
17128       // Break early when size is too large to be legal.
17129       if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
17130         break;
17131 
17132       bool IsFastSt = false;
17133       bool IsFastLd = false;
17134       if (TLI.isTypeLegal(StoreTy) &&
17135           TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) &&
17136           TLI.allowsMemoryAccess(Context, DL, StoreTy,
17137                                  *FirstInChain->getMemOperand(), &IsFastSt) &&
17138           IsFastSt &&
17139           TLI.allowsMemoryAccess(Context, DL, StoreTy,
17140                                  *FirstLoad->getMemOperand(), &IsFastLd) &&
17141           IsFastLd) {
17142         LastLegalVectorType = i + 1;
17143       }
17144 
17145       // Find a legal type for the integer store.
17146       unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
17147       StoreTy = EVT::getIntegerVT(Context, SizeInBits);
17148       if (TLI.isTypeLegal(StoreTy) &&
17149           TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) &&
17150           TLI.allowsMemoryAccess(Context, DL, StoreTy,
17151                                  *FirstInChain->getMemOperand(), &IsFastSt) &&
17152           IsFastSt &&
17153           TLI.allowsMemoryAccess(Context, DL, StoreTy,
17154                                  *FirstLoad->getMemOperand(), &IsFastLd) &&
17155           IsFastLd) {
17156         LastLegalIntegerType = i + 1;
17157         DoIntegerTruncate = false;
17158         // Or check whether a truncstore and extload is legal.
17159       } else if (TLI.getTypeAction(Context, StoreTy) ==
17160                  TargetLowering::TypePromoteInteger) {
17161         EVT LegalizedStoredValTy = TLI.getTypeToTransformTo(Context, StoreTy);
17162         if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) &&
17163             TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy, DAG) &&
17164             TLI.isLoadExtLegal(ISD::ZEXTLOAD, LegalizedStoredValTy, StoreTy) &&
17165             TLI.isLoadExtLegal(ISD::SEXTLOAD, LegalizedStoredValTy, StoreTy) &&
17166             TLI.isLoadExtLegal(ISD::EXTLOAD, LegalizedStoredValTy, StoreTy) &&
17167             TLI.allowsMemoryAccess(Context, DL, StoreTy,
17168                                    *FirstInChain->getMemOperand(), &IsFastSt) &&
17169             IsFastSt &&
17170             TLI.allowsMemoryAccess(Context, DL, StoreTy,
17171                                    *FirstLoad->getMemOperand(), &IsFastLd) &&
17172             IsFastLd) {
17173           LastLegalIntegerType = i + 1;
17174           DoIntegerTruncate = true;
17175         }
17176       }
17177     }
17178 
17179     // Only use vector types if the vector type is larger than the integer
17180     // type. If they are the same, use integers.
17181     bool UseVectorTy =
17182         LastLegalVectorType > LastLegalIntegerType && AllowVectors;
17183     unsigned LastLegalType =
17184         std::max(LastLegalVectorType, LastLegalIntegerType);
17185 
17186     // We add +1 here because LastConsecutiveLoad is an index into the
17187     // array, while NumElem is a count of elements.
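    // E.g., if loads at indices 0 through 2 are consecutive, then
    // LastConsecutiveLoad == 2 while three elements are mergeable.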
17188     unsigned NumElem = std::min(NumConsecutiveStores, LastConsecutiveLoad + 1);
17189     NumElem = std::min(LastLegalType, NumElem);
17190     Align FirstLoadAlign = FirstLoad->getAlign();
17191 
17192     if (NumElem < 2) {
17193       // We know that candidate stores are in order and of correct
17194       // shape. While there may be no mergeable sequence at the
17195       // beginning, one may start later in the sequence. The only
17196       // reason a merge of size N could have failed where another of
17197       // the same size would not have is if the alignment of either
17198       // the load or store has improved. Drop as many candidates as we
17199       // can here.
17200       unsigned NumSkip = 1;
17201       while ((NumSkip < LoadNodes.size()) &&
17202              (LoadNodes[NumSkip].MemNode->getAlign() <= FirstLoadAlign) &&
17203              (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
17204         NumSkip++;
17205       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
17206       LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumSkip);
17207       NumConsecutiveStores -= NumSkip;
17208       continue;
17209     }
17210 
17211     // Check that we can merge these candidates without causing a cycle.
17212     if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem,
17213                                                   RootNode)) {
17214       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
17215       LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
17216       NumConsecutiveStores -= NumElem;
17217       continue;
17218     }
17219 
17220     // Find if it is better to use vectors or integers to load and store
17221     // to memory.
17222     EVT JointMemOpVT;
17223     if (UseVectorTy) {
17224       // Find a legal type for the vector store.
17225       unsigned Elts = NumElem * NumMemElts;
17226       JointMemOpVT = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
17227     } else {
17228       unsigned SizeInBits = NumElem * ElementSizeBytes * 8;
17229       JointMemOpVT = EVT::getIntegerVT(Context, SizeInBits);
17230     }
17231 
17232     SDLoc LoadDL(LoadNodes[0].MemNode);
17233     SDLoc StoreDL(StoreNodes[0].MemNode);
17234 
17235     // The merged loads are required to have the same incoming chain, so
17236     // using the first's chain is acceptable.
17237 
17238     SDValue NewStoreChain = getMergeStoreChains(StoreNodes, NumElem);
17239     AddToWorklist(NewStoreChain.getNode());
17240 
17241     MachineMemOperand::Flags LdMMOFlags =
17242         isDereferenceable ? MachineMemOperand::MODereferenceable
17243                           : MachineMemOperand::MONone;
17244     if (IsNonTemporalLoad)
17245       LdMMOFlags |= MachineMemOperand::MONonTemporal;
17246 
17247     MachineMemOperand::Flags StMMOFlags = IsNonTemporalStore
17248                                               ? MachineMemOperand::MONonTemporal
17249                                               : MachineMemOperand::MONone;
17250 
17251     SDValue NewLoad, NewStore;
17252     if (UseVectorTy || !DoIntegerTruncate) {
17253       NewLoad = DAG.getLoad(
17254           JointMemOpVT, LoadDL, FirstLoad->getChain(), FirstLoad->getBasePtr(),
17255           FirstLoad->getPointerInfo(), FirstLoadAlign, LdMMOFlags);
17256       SDValue StoreOp = NewLoad;
17257       if (NeedRotate) {
17258         unsigned LoadWidth = ElementSizeBytes * 8 * 2;
17259         assert(JointMemOpVT == EVT::getIntegerVT(Context, LoadWidth) &&
17260                "Unexpected type for rotate-able load pair");
17261         SDValue RotAmt =
17262             DAG.getShiftAmountConstant(LoadWidth / 2, JointMemOpVT, LoadDL);
17263         // The target can convert this to the equivalent ROTR if it lacks ROTL.
17264         StoreOp = DAG.getNode(ISD::ROTL, LoadDL, JointMemOpVT, NewLoad, RotAmt);
17265       }
17266       NewStore = DAG.getStore(
17267           NewStoreChain, StoreDL, StoreOp, FirstInChain->getBasePtr(),
17268           FirstInChain->getPointerInfo(), FirstStoreAlign, StMMOFlags);
17269     } else { // This must be the truncstore/extload case
17270       EVT ExtendedTy =
17271           TLI.getTypeToTransformTo(*DAG.getContext(), JointMemOpVT);
17272       NewLoad = DAG.getExtLoad(ISD::EXTLOAD, LoadDL, ExtendedTy,
17273                                FirstLoad->getChain(), FirstLoad->getBasePtr(),
17274                                FirstLoad->getPointerInfo(), JointMemOpVT,
17275                                FirstLoadAlign, LdMMOFlags);
17276       NewStore = DAG.getTruncStore(
17277           NewStoreChain, StoreDL, NewLoad, FirstInChain->getBasePtr(),
17278           FirstInChain->getPointerInfo(), JointMemOpVT,
17279           FirstInChain->getAlign(), FirstInChain->getMemOperand()->getFlags());
17280     }
17281 
17282     // Transfer chain users from old loads to the new load.
17283     for (unsigned i = 0; i < NumElem; ++i) {
17284       LoadSDNode *Ld = cast<LoadSDNode>(LoadNodes[i].MemNode);
17285       DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1),
17286                                     SDValue(NewLoad.getNode(), 1));
17287     }
17288 
17289     // Replace all stores with the new store. Recursively remove corresponding
17290     // values if they are no longer used.
17291     for (unsigned i = 0; i < NumElem; ++i) {
17292       SDValue Val = StoreNodes[i].MemNode->getOperand(1);
17293       CombineTo(StoreNodes[i].MemNode, NewStore);
17294       if (Val.getNode()->use_empty())
17295         recursivelyDeleteUnusedNodes(Val.getNode());
17296     }
17297 
17298     MadeChange = true;
17299     StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
17300     LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
17301     NumConsecutiveStores -= NumElem;
17302   }
17303   return MadeChange;
17304 }
17305 
17306 bool DAGCombiner::mergeConsecutiveStores(StoreSDNode *St) {
17307   if (OptLevel == CodeGenOpt::None || !EnableStoreMerging)
17308     return false;
17309 
17310   // TODO: Extend this function to merge stores of scalable vectors.
17311   // (i.e. two <vscale x 8 x i8> stores can be merged to one <vscale x 16 x i8>
17312   // store since we know <vscale x 16 x i8> is exactly twice as large as
17313   // <vscale x 8 x i8>). Until then, bail out for scalable vectors.
17314   EVT MemVT = St->getMemoryVT();
17315   if (MemVT.isScalableVector())
17316     return false;
17317   if (!MemVT.isSimple() || MemVT.getSizeInBits() * 2 > MaximumLegalStoreInBits)
17318     return false;
17319 
17320   // This function cannot currently deal with non-byte-sized memory sizes.
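  // (E.g., an i1 store has a one-byte store size but a one-bit value size,
  // so it is rejected here.)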
17321   int64_t ElementSizeBytes = MemVT.getStoreSize();
17322   if (ElementSizeBytes * 8 != (int64_t)MemVT.getSizeInBits())
17323     return false;
17324 
17325   // Do not bother looking at stored values that are not constants, loads, or
17326   // extracted vector elements.
17327   SDValue StoredVal = peekThroughBitcasts(St->getValue());
17328   const StoreSource StoreSrc = getStoreSource(StoredVal);
17329   if (StoreSrc == StoreSource::Unknown)
17330     return false;
17331 
17332   SmallVector<MemOpLink, 8> StoreNodes;
17333   SDNode *RootNode;
17334   // Find potential store merge candidates by searching through the chain sub-DAG.
17335   getStoreMergeCandidates(St, StoreNodes, RootNode);
17336 
17337   // Check if there is anything to merge.
17338   if (StoreNodes.size() < 2)
17339     return false;
17340 
17341   // Sort the memory operands according to their distance from the
17342   // base pointer.
17343   llvm::sort(StoreNodes, [](MemOpLink LHS, MemOpLink RHS) {
17344     return LHS.OffsetFromBase < RHS.OffsetFromBase;
17345   });
17346 
17347   bool AllowVectors = !DAG.getMachineFunction().getFunction().hasFnAttribute(
17348       Attribute::NoImplicitFloat);
17349   bool IsNonTemporalStore = St->isNonTemporal();
17350   bool IsNonTemporalLoad = StoreSrc == StoreSource::Load &&
17351                            cast<LoadSDNode>(StoredVal)->isNonTemporal();
17352 
17353   // Store Merge attempts to merge the lowest stores first. This generally
17354   // works out well: when a merge succeeds, the remaining stores are checked
17355   // after the first collection of stores is merged. However, in the
17356   // case that a non-mergeable store is found first, e.g., {p[-2],
17357   // p[0], p[1], p[2], p[3]}, we would fail and miss the subsequent
17358   // mergeable cases. To prevent this, we prune such stores from the
17359   // front of StoreNodes here.
17360   bool MadeChange = false;
17361   while (StoreNodes.size() > 1) {
17362     unsigned NumConsecutiveStores =
17363         getConsecutiveStores(StoreNodes, ElementSizeBytes);
17364     // There are no more stores in the list to examine.
17365     if (NumConsecutiveStores == 0)
17366       return MadeChange;
17367 
17368     // We have at least 2 consecutive stores. Try to merge them.
17369     assert(NumConsecutiveStores >= 2 && "Expected at least 2 stores");
17370     switch (StoreSrc) {
17371     case StoreSource::Constant:
17372       MadeChange |= tryStoreMergeOfConstants(StoreNodes, NumConsecutiveStores,
17373                                              MemVT, RootNode, AllowVectors);
17374       break;
17375 
17376     case StoreSource::Extract:
17377       MadeChange |= tryStoreMergeOfExtracts(StoreNodes, NumConsecutiveStores,
17378                                             MemVT, RootNode);
17379       break;
17380 
17381     case StoreSource::Load:
17382       MadeChange |= tryStoreMergeOfLoads(StoreNodes, NumConsecutiveStores,
17383                                          MemVT, RootNode, AllowVectors,
17384                                          IsNonTemporalStore, IsNonTemporalLoad);
17385       break;
17386 
17387     default:
17388       llvm_unreachable("Unhandled store source type");
17389     }
17390   }
17391   return MadeChange;
17392 }
17393 
17394 SDValue DAGCombiner::replaceStoreChain(StoreSDNode *ST, SDValue BetterChain) {
17395   SDLoc SL(ST);
17396   SDValue ReplStore;
17397 
17398   // Replace the chain to avoid dependency.
17399   if (ST->isTruncatingStore()) {
17400     ReplStore = DAG.getTruncStore(BetterChain, SL, ST->getValue(),
17401                                   ST->getBasePtr(), ST->getMemoryVT(),
17402                                   ST->getMemOperand());
17403   } else {
17404     ReplStore = DAG.getStore(BetterChain, SL, ST->getValue(), ST->getBasePtr(),
17405                              ST->getMemOperand());
17406   }
17407 
17408   // Create token to keep both nodes around.
17409   SDValue Token = DAG.getNode(ISD::TokenFactor, SL,
17410                               MVT::Other, ST->getChain(), ReplStore);
17411 
17412   // Make sure the new and old chains are cleaned up.
17413   AddToWorklist(Token.getNode());
17414 
17415   // Don't add users to work list.
17416   return CombineTo(ST, Token, false);
17417 }
17418 
17419 SDValue DAGCombiner::replaceStoreOfFPConstant(StoreSDNode *ST) {
17420   SDValue Value = ST->getValue();
17421   if (Value.getOpcode() == ISD::TargetConstantFP)
17422     return SDValue();
17423 
17424   if (!ISD::isNormalStore(ST))
17425     return SDValue();
17426 
17427   SDLoc DL(ST);
17428 
17429   SDValue Chain = ST->getChain();
17430   SDValue Ptr = ST->getBasePtr();
17431 
17432   const ConstantFPSDNode *CFP = cast<ConstantFPSDNode>(Value);
17433 
17434   // NOTE: If the original store is volatile, this transform must not increase
17435   // the number of stores.  For example, on x86-32 an f64 can be stored in one
17436   // processor operation but an i64 (which is not legal) requires two.  So the
17437   // transform should not be done in this case.
17438 
17439   SDValue Tmp;
17440   switch (CFP->getSimpleValueType(0).SimpleTy) {
17441   default:
17442     llvm_unreachable("Unknown FP type");
17443   case MVT::f16:    // We don't do this for these yet.
17444   case MVT::f80:
17445   case MVT::f128:
17446   case MVT::ppcf128:
17447     return SDValue();
17448   case MVT::f32:
17449     if ((isTypeLegal(MVT::i32) && !LegalOperations && ST->isSimple()) ||
17450         TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) {
17452       Tmp = DAG.getConstant((uint32_t)CFP->getValueAPF().
17453                             bitcastToAPInt().getZExtValue(), SDLoc(CFP),
17454                             MVT::i32);
17455       return DAG.getStore(Chain, DL, Tmp, Ptr, ST->getMemOperand());
17456     }
17457 
17458     return SDValue();
17459   case MVT::f64:
17460     if ((TLI.isTypeLegal(MVT::i64) && !LegalOperations &&
17461          ST->isSimple()) ||
17462         TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i64)) {
17464       Tmp = DAG.getConstant(CFP->getValueAPF().bitcastToAPInt().
17465                             getZExtValue(), SDLoc(CFP), MVT::i64);
17466       return DAG.getStore(Chain, DL, Tmp,
17467                           Ptr, ST->getMemOperand());
17468     }
17469 
17470     if (ST->isSimple() &&
17471         TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) {
17472       // Many FP stores are not made apparent until after legalize, e.g. for
17473       // argument passing.  Since this is so common, custom legalize the
17474       // 64-bit integer store into two 32-bit stores.
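      //
      // E.g., storing a hypothetical f64 1.0 bitcasts to the i64
      // 0x3FF0000000000000, split into Lo = 0x00000000 and Hi = 0x3FF00000;
      // the two halves are swapped on big-endian targets.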
17475       uint64_t Val = CFP->getValueAPF().bitcastToAPInt().getZExtValue();
17476       SDValue Lo = DAG.getConstant(Val & 0xFFFFFFFF, SDLoc(CFP), MVT::i32);
17477       SDValue Hi = DAG.getConstant(Val >> 32, SDLoc(CFP), MVT::i32);
17478       if (DAG.getDataLayout().isBigEndian())
17479         std::swap(Lo, Hi);
17480 
17481       MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
17482       AAMDNodes AAInfo = ST->getAAInfo();
17483 
17484       SDValue St0 = DAG.getStore(Chain, DL, Lo, Ptr, ST->getPointerInfo(),
17485                                  ST->getOriginalAlign(), MMOFlags, AAInfo);
17486       Ptr = DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(4), DL);
17487       SDValue St1 = DAG.getStore(Chain, DL, Hi, Ptr,
17488                                  ST->getPointerInfo().getWithOffset(4),
17489                                  ST->getOriginalAlign(), MMOFlags, AAInfo);
17490       return DAG.getNode(ISD::TokenFactor, DL, MVT::Other,
17491                          St0, St1);
17492     }
17493 
17494     return SDValue();
17495   }
17496 }
17497 
17498 SDValue DAGCombiner::visitSTORE(SDNode *N) {
17499   StoreSDNode *ST  = cast<StoreSDNode>(N);
17500   SDValue Chain = ST->getChain();
17501   SDValue Value = ST->getValue();
17502   SDValue Ptr   = ST->getBasePtr();
17503 
17504   // If this is a store of a bit convert, store the input value if the
17505   // resultant store does not need a higher alignment than the original.
17506   if (Value.getOpcode() == ISD::BITCAST && !ST->isTruncatingStore() &&
17507       ST->isUnindexed()) {
17508     EVT SVT = Value.getOperand(0).getValueType();
17509     // If the store is volatile, we only want to change the store type if the
17510     // resulting store is legal. Otherwise we might increase the number of
17511     // memory accesses. We don't care if the original type was legal or not
17512     // as we assume software couldn't rely on the number of accesses of an
17513     // illegal type.
17514     // TODO: May be able to relax for unordered atomics (see D66309)
17515     if (((!LegalOperations && ST->isSimple()) ||
17516          TLI.isOperationLegal(ISD::STORE, SVT)) &&
17517         TLI.isStoreBitCastBeneficial(Value.getValueType(), SVT,
17518                                      DAG, *ST->getMemOperand())) {
17519       return DAG.getStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
17520                           ST->getMemOperand());
17521     }
17522   }
17523 
17524   // Turn 'store undef, Ptr' -> nothing.
17525   if (Value.isUndef() && ST->isUnindexed())
17526     return Chain;
17527 
17528   // Try to infer better alignment information than the store already has.
17529   if (OptLevel != CodeGenOpt::None && ST->isUnindexed() && !ST->isAtomic()) {
17530     if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) {
17531       if (*Alignment > ST->getAlign() &&
17532           isAligned(*Alignment, ST->getSrcValueOffset())) {
17533         SDValue NewStore =
17534             DAG.getTruncStore(Chain, SDLoc(N), Value, Ptr, ST->getPointerInfo(),
17535                               ST->getMemoryVT(), *Alignment,
17536                               ST->getMemOperand()->getFlags(), ST->getAAInfo());
17537         // NewStore will always be N, as we are only refining the alignment.
17538         assert(NewStore.getNode() == N);
17539         (void)NewStore;
17540       }
17541     }
17542   }
17543 
17544   // Try transforming a pair of floating-point load / store ops to integer
17545   // load / store ops.
17546   if (SDValue NewST = TransformFPLoadStorePair(N))
17547     return NewST;
17548 
17549   // Try transforming several stores into STORE (BSWAP).
17550   if (SDValue Store = mergeTruncStores(ST))
17551     return Store;
17552 
17553   if (ST->isUnindexed()) {
17554     // Walk up chain skipping non-aliasing memory nodes, on this store and any
17555     // adjacent stores.
17556     if (findBetterNeighborChains(ST)) {
17557       // replaceStoreChain uses CombineTo, which handles all of the worklist
17558       // manipulation. Return the original node to not do anything else.
17559       return SDValue(ST, 0);
17560     }
17561     Chain = ST->getChain();
17562   }
17563 
17564   // FIXME: is there such a thing as a truncating indexed store?
17565   if (ST->isTruncatingStore() && ST->isUnindexed() &&
17566       Value.getValueType().isInteger() &&
17567       (!isa<ConstantSDNode>(Value) ||
17568        !cast<ConstantSDNode>(Value)->isOpaque())) {
17569     APInt TruncDemandedBits =
17570         APInt::getLowBitsSet(Value.getScalarValueSizeInBits(),
17571                              ST->getMemoryVT().getScalarSizeInBits());
17572 
17573     // See if we can simplify the input to this truncstore with knowledge that
17574     // only the low bits are being used.  For example:
17575     // "truncstore (or (shl x, 8), y), i8"  -> "truncstore y, i8"
17576     AddToWorklist(Value.getNode());
17577     if (SDValue Shorter = DAG.GetDemandedBits(Value, TruncDemandedBits))
17578       return DAG.getTruncStore(Chain, SDLoc(N), Shorter, Ptr, ST->getMemoryVT(),
17579                                ST->getMemOperand());
17580 
17581     // Otherwise, see if we can simplify the operation with
17582     // SimplifyDemandedBits, which only works if the value has a single use.
17583     if (SimplifyDemandedBits(Value, TruncDemandedBits)) {
17584       // Re-visit the store if anything changed and the store hasn't been merged
17585       // with another node (N is deleted). SimplifyDemandedBits will add Value's
17586       // node back to the worklist if necessary, but we also need to re-visit
17587       // the Store node itself.
17588       if (N->getOpcode() != ISD::DELETED_NODE)
17589         AddToWorklist(N);
17590       return SDValue(N, 0);
17591     }
17592   }
17593 
17594   // If this is a load followed by a store to the same location, then the store
17595   // is dead/noop.
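  // For example, "store (load p), p" folds to the store's input chain when
  // nothing writes to p between the load and the store.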
17596   // TODO: Can relax for unordered atomics (see D66309)
17597   if (LoadSDNode *Ld = dyn_cast<LoadSDNode>(Value)) {
17598     if (Ld->getBasePtr() == Ptr && ST->getMemoryVT() == Ld->getMemoryVT() &&
17599         ST->isUnindexed() && ST->isSimple() &&
17600         // There can't be any side effects between the load and store, such as
17601         // a call or store.
17602         Chain.reachesChainWithoutSideEffects(SDValue(Ld, 1))) {
17603       // The store is dead, remove it.
17604       return Chain;
17605     }
17606   }
17607 
17608   // TODO: Can relax for unordered atomics (see D66309)
17609   if (StoreSDNode *ST1 = dyn_cast<StoreSDNode>(Chain)) {
17610     if (ST->isUnindexed() && ST->isSimple() &&
17611         ST1->isUnindexed() && ST1->isSimple()) {
17612       if (ST1->getBasePtr() == Ptr && ST1->getValue() == Value &&
17613           ST->getMemoryVT() == ST1->getMemoryVT()) {
17614         // If this is a store followed by a store with the same value to the
17615         // same location, then the store is dead/noop.
17616         return Chain;
17617       }
17618 
17619       if (OptLevel != CodeGenOpt::None && ST1->hasOneUse() &&
17620           !ST1->getBasePtr().isUndef() &&
17621           // BaseIndexOffset and the code below require knowing the size
17622           // of a vector, so bail out if MemoryVT is scalable.
17623           !ST->getMemoryVT().isScalableVector() &&
17624           !ST1->getMemoryVT().isScalableVector()) {
17625         const BaseIndexOffset STBase = BaseIndexOffset::match(ST, DAG);
17626         const BaseIndexOffset ChainBase = BaseIndexOffset::match(ST1, DAG);
17627         unsigned STBitSize = ST->getMemoryVT().getFixedSizeInBits();
17628         unsigned ChainBitSize = ST1->getMemoryVT().getFixedSizeInBits();
17629         // If the preceding store writes to a subset of the current store's
17630         // location and no other node is chained to that store, we can
17631         // effectively drop the preceding store. Do not remove stores to
17632         // undef as they may be used as data sinks.
17633         if (STBase.contains(DAG, STBitSize, ChainBase, ChainBitSize)) {
17634           CombineTo(ST1, ST1->getChain());
17635           return SDValue();
17636         }
17637       }
17638     }
17639   }
17640 
17641   // If this is an FP_ROUND or TRUNC followed by a store, fold this into a
17642   // truncating store.  We can do this even if this is already a truncstore.
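  // For example: "store (trunc i32:x to i16), p" --> "truncstore x, p, i16".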
17643   if ((Value.getOpcode() == ISD::FP_ROUND || Value.getOpcode() == ISD::TRUNCATE)
17644       && Value.getNode()->hasOneUse() && ST->isUnindexed() &&
17645       TLI.isTruncStoreLegal(Value.getOperand(0).getValueType(),
17646                             ST->getMemoryVT())) {
17647     return DAG.getTruncStore(Chain, SDLoc(N), Value.getOperand(0),
17648                              Ptr, ST->getMemoryVT(), ST->getMemOperand());
17649   }
17650 
17651   // Always perform this optimization before types are legal. If the target
17652   // prefers, also try this after legalization to catch stores that were created
17653   // by intrinsics or other nodes.
17654   if (!LegalTypes || (TLI.mergeStoresAfterLegalization(ST->getMemoryVT()))) {
17655     while (true) {
17656       // There can be multiple store sequences on the same chain.
17657       // Keep trying to merge store sequences until we are unable to do so
17658       // or until we merge the last store on the chain.
17659       bool Changed = mergeConsecutiveStores(ST);
17660       if (!Changed) break;
17661       // Return N, as the merge only uses CombineTo and no worklist
17662       // cleanup is necessary.
17663       if (N->getOpcode() == ISD::DELETED_NODE || !isa<StoreSDNode>(N))
17664         return SDValue(N, 0);
17665     }
17666   }
17667 
17668   // Try transforming N to an indexed store.
17669   if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
17670     return SDValue(N, 0);
17671 
17672   // Turn 'store float 1.0, Ptr' -> 'store int 0x3F800000, Ptr'
17673   //
17674   // Make sure to do this only after attempting to merge stores in order to
17675   //  avoid changing the types of some subset of stores due to visit order,
17676   //  preventing their merging.
17677   if (isa<ConstantFPSDNode>(ST->getValue())) {
17678     if (SDValue NewSt = replaceStoreOfFPConstant(ST))
17679       return NewSt;
17680   }
17681 
17682   if (SDValue NewSt = splitMergedValStore(ST))
17683     return NewSt;
17684 
17685   return ReduceLoadOpStoreWidth(N);
17686 }
17687 
17688 SDValue DAGCombiner::visitLIFETIME_END(SDNode *N) {
17689   const auto *LifetimeEnd = cast<LifetimeSDNode>(N);
17690   if (!LifetimeEnd->hasOffset())
17691     return SDValue();
17692 
17693   const BaseIndexOffset LifetimeEndBase(N->getOperand(1), SDValue(),
17694                                         LifetimeEnd->getOffset(), false);
17695 
17696   // We walk up the chains to find stores.
17697   SmallVector<SDValue, 8> Chains = {N->getOperand(0)};
17698   while (!Chains.empty()) {
17699     SDValue Chain = Chains.pop_back_val();
17700     if (!Chain.hasOneUse())
17701       continue;
17702     switch (Chain.getOpcode()) {
17703     case ISD::TokenFactor:
17704       for (unsigned Nops = Chain.getNumOperands(); Nops;)
17705         Chains.push_back(Chain.getOperand(--Nops));
17706       break;
17707     case ISD::LIFETIME_START:
17708     case ISD::LIFETIME_END:
17709       // We can forward past any lifetime start/end that can be proven not to
17710       // alias the node.
17711       if (!isAlias(Chain.getNode(), N))
17712         Chains.push_back(Chain.getOperand(0));
17713       break;
17714     case ISD::STORE: {
17715       StoreSDNode *ST = cast<StoreSDNode>(Chain);
17716       // TODO: Can relax for unordered atomics (see D66309)
17717       if (!ST->isSimple() || ST->isIndexed())
17718         continue;
17719       const TypeSize StoreSize = ST->getMemoryVT().getStoreSize();
17720       // The bounds of a scalable store are not known until runtime, so this
17721       // store cannot be elided.
17722       if (StoreSize.isScalable())
17723         continue;
17724       const BaseIndexOffset StoreBase = BaseIndexOffset::match(ST, DAG);
17725       // If we store purely within object bounds just before its lifetime ends,
17726       // we can remove the store.
17727       if (LifetimeEndBase.contains(DAG, LifetimeEnd->getSize() * 8, StoreBase,
17728                                    StoreSize.getFixedSize() * 8)) {
17729         LLVM_DEBUG(dbgs() << "\nRemoving store:"; StoreBase.dump();
17730                    dbgs() << "\nwithin LIFETIME_END of : ";
17731                    LifetimeEndBase.dump(); dbgs() << "\n");
17732         CombineTo(ST, ST->getChain());
17733         return SDValue(N, 0);
17734       }
17735     }
17736     }
17737   }
17738   return SDValue();
17739 }
17740 
17741 /// For the store instruction sequence below, the F and I values
17742 /// are bundled together as an i64 value before being stored into memory.
17743 /// Sometimes it is more efficient to generate separate stores for F and I,
17744 /// which can remove the bitwise instructions or sink them to colder places.
17745 ///
17746 ///   (store (or (zext (bitcast F to i32) to i64),
17747 ///              (shl (zext I to i64), 32)), addr)  -->
17748 ///   (store F, addr) and (store I, addr+4)
17749 ///
17750 /// Similarly, splitting other merged stores can also be beneficial, like:
17751 /// For pair of {i32, i32}, i64 store --> two i32 stores.
17752 /// For pair of {i32, i16}, i64 store --> two i32 stores.
17753 /// For pair of {i16, i16}, i32 store --> two i16 stores.
17754 /// For pair of {i16, i8},  i32 store --> two i16 stores.
17755 /// For pair of {i8, i8},   i16 store --> two i8 stores.
17756 ///
17757 /// We allow each target to determine specifically which kind of splitting is
17758 /// supported.
17759 ///
17760 /// The store patterns are commonly seen from the simple code snippet below
17761 /// when only std::make_pair(...) is SROA-transformed before being inlined into hoo.
17762 ///   void goo(const std::pair<int, float> &);
17763 ///   hoo() {
17764 ///     ...
17765 ///     goo(std::make_pair(tmp, ftmp));
17766 ///     ...
17767 ///   }
17768 ///
17769 SDValue DAGCombiner::splitMergedValStore(StoreSDNode *ST) {
17770   if (OptLevel == CodeGenOpt::None)
17771     return SDValue();
17772 
17773   // Can't change the number of memory accesses for a volatile store or break
17774   // atomicity for an atomic one.
17775   if (!ST->isSimple())
17776     return SDValue();
17777 
17778   SDValue Val = ST->getValue();
17779   SDLoc DL(ST);
17780 
17781   // Match OR operand.
17782   if (!Val.getValueType().isScalarInteger() || Val.getOpcode() != ISD::OR)
17783     return SDValue();
17784 
17785   // Match SHL operand and get Lower and Higher parts of Val.
17786   SDValue Op1 = Val.getOperand(0);
17787   SDValue Op2 = Val.getOperand(1);
17788   SDValue Lo, Hi;
17789   if (Op1.getOpcode() != ISD::SHL) {
17790     std::swap(Op1, Op2);
17791     if (Op1.getOpcode() != ISD::SHL)
17792       return SDValue();
17793   }
17794   Lo = Op2;
17795   Hi = Op1.getOperand(0);
17796   if (!Op1.hasOneUse())
17797     return SDValue();
17798 
17799   // Match shift amount to HalfValBitSize.
17800   unsigned HalfValBitSize = Val.getValueSizeInBits() / 2;
17801   ConstantSDNode *ShAmt = dyn_cast<ConstantSDNode>(Op1.getOperand(1));
17802   if (!ShAmt || ShAmt->getAPIntValue() != HalfValBitSize)
17803     return SDValue();
17804 
17805   // Lo and Hi must be zero-extended from integer types whose size is less
17806   // than or equal to HalfValBitSize.
17807   if (Lo.getOpcode() != ISD::ZERO_EXTEND || !Lo.hasOneUse() ||
17808       !Lo.getOperand(0).getValueType().isScalarInteger() ||
17809       Lo.getOperand(0).getValueSizeInBits() > HalfValBitSize ||
17810       Hi.getOpcode() != ISD::ZERO_EXTEND || !Hi.hasOneUse() ||
17811       !Hi.getOperand(0).getValueType().isScalarInteger() ||
17812       Hi.getOperand(0).getValueSizeInBits() > HalfValBitSize)
17813     return SDValue();
17814 
17815   // Use the EVTs of the low and high parts before any bitcast as the inputs
17816   // of the target query.
17817   EVT LowTy = (Lo.getOperand(0).getOpcode() == ISD::BITCAST)
17818                   ? Lo.getOperand(0).getValueType()
17819                   : Lo.getValueType();
17820   EVT HighTy = (Hi.getOperand(0).getOpcode() == ISD::BITCAST)
17821                    ? Hi.getOperand(0).getValueType()
17822                    : Hi.getValueType();
17823   if (!TLI.isMultiStoresCheaperThanBitsMerge(LowTy, HighTy))
17824     return SDValue();
17825 
17826   // Start to split store.
17827   MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
17828   AAMDNodes AAInfo = ST->getAAInfo();
17829 
17830   // Change the sizes of Lo and Hi's value types to HalfValBitSize.
17831   EVT VT = EVT::getIntegerVT(*DAG.getContext(), HalfValBitSize);
17832   Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Lo.getOperand(0));
17833   Hi = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Hi.getOperand(0));
17834 
17835   SDValue Chain = ST->getChain();
17836   SDValue Ptr = ST->getBasePtr();
17837   // Lower value store.
17838   SDValue St0 = DAG.getStore(Chain, DL, Lo, Ptr, ST->getPointerInfo(),
17839                              ST->getOriginalAlign(), MMOFlags, AAInfo);
17840   Ptr = DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(HalfValBitSize / 8), DL);
17841   // Higher value store.
17842   SDValue St1 = DAG.getStore(
17843       St0, DL, Hi, Ptr, ST->getPointerInfo().getWithOffset(HalfValBitSize / 8),
17844       ST->getOriginalAlign(), MMOFlags, AAInfo);
17845   return St1;
17846 }
17847 
17848 /// Convert a disguised subvector insertion into a shuffle:
17849 SDValue DAGCombiner::combineInsertEltToShuffle(SDNode *N, unsigned InsIndex) {
17850   assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT &&
17851          "Expected insert_vector_elt");
17852   SDValue InsertVal = N->getOperand(1);
17853   SDValue Vec = N->getOperand(0);
17854 
17855   // (insert_vector_elt (vector_shuffle X, Y), (extract_vector_elt X, N),
17856   // InsIndex)
17857   //   --> (vector_shuffle X, Y) and variations where shuffle operands may be
17858   //   CONCAT_VECTORS.
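  // For example (illustrative), with X and Y of type v4i32:
  //   (insert_vector_elt (vector_shuffle<0,5,2,7> X, Y),
  //                      (extract_vector_elt X, 1), 2)
  //     --> (vector_shuffle<0,5,1,7> X, Y)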
17859   if (Vec.getOpcode() == ISD::VECTOR_SHUFFLE && Vec.hasOneUse() &&
17860       InsertVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
17861       isa<ConstantSDNode>(InsertVal.getOperand(1))) {
17862     ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Vec.getNode());
17863     ArrayRef<int> Mask = SVN->getMask();
17864 
17865     SDValue X = Vec.getOperand(0);
17866     SDValue Y = Vec.getOperand(1);
17867 
17868     // Vec's operand 0 is using indices from 0 to N-1 and
17869     // operand 1 from N to 2N - 1, where N is the number of
17870     // elements in the vectors.
17871     SDValue InsertVal0 = InsertVal.getOperand(0);
17872     int ElementOffset = -1;
17873 
17874     // We explore the inputs of the shuffle in order to see if we find the
17875     // source of the extract_vector_elt. If so, we can use it to modify the
17876     // shuffle rather than perform an insert_vector_elt.
17877     SmallVector<std::pair<int, SDValue>, 8> ArgWorkList;
17878     ArgWorkList.emplace_back(Mask.size(), Y);
17879     ArgWorkList.emplace_back(0, X);
17880 
17881     while (!ArgWorkList.empty()) {
17882       int ArgOffset;
17883       SDValue ArgVal;
17884       std::tie(ArgOffset, ArgVal) = ArgWorkList.pop_back_val();
17885 
17886       if (ArgVal == InsertVal0) {
17887         ElementOffset = ArgOffset;
17888         break;
17889       }
17890 
17891       // Peek through concat_vectors.
17892       if (ArgVal.getOpcode() == ISD::CONCAT_VECTORS) {
17893         int CurrentArgOffset =
17894             ArgOffset + ArgVal.getValueType().getVectorNumElements();
17895         int Step = ArgVal.getOperand(0).getValueType().getVectorNumElements();
17896         for (SDValue Op : reverse(ArgVal->ops())) {
17897           CurrentArgOffset -= Step;
17898           ArgWorkList.emplace_back(CurrentArgOffset, Op);
17899         }
17900 
17901         // Make sure we went through all the elements and did not screw up index
17902         // computation.
17903         assert(CurrentArgOffset == ArgOffset);
17904       }
17905     }
17906 
17907     if (ElementOffset != -1) {
17908       SmallVector<int, 16> NewMask(Mask.begin(), Mask.end());
17909 
17910       auto *ExtrIndex = cast<ConstantSDNode>(InsertVal.getOperand(1));
17911       NewMask[InsIndex] = ElementOffset + ExtrIndex->getZExtValue();
17912       assert(NewMask[InsIndex] <
17913                  (int)(2 * Vec.getValueType().getVectorNumElements()) &&
17914              NewMask[InsIndex] >= 0 && "NewMask[InsIndex] is out of bound");
17915 
17916       SDValue LegalShuffle =
17917               TLI.buildLegalVectorShuffle(Vec.getValueType(), SDLoc(N), X,
17918                                           Y, NewMask, DAG);
17919       if (LegalShuffle)
17920         return LegalShuffle;
17921     }
17922   }
17923 
17924   // insert_vector_elt V, (bitcast X from vector type), IdxC -->
17925   // bitcast(shuffle (bitcast V), (extended X), Mask)
17926   // Note: We do not use an insert_subvector node because that requires a
17927   // legal subvector type.
17928   if (InsertVal.getOpcode() != ISD::BITCAST || !InsertVal.hasOneUse() ||
17929       !InsertVal.getOperand(0).getValueType().isVector())
17930     return SDValue();
17931 
17932   SDValue SubVec = InsertVal.getOperand(0);
17933   SDValue DestVec = N->getOperand(0);
17934   EVT SubVecVT = SubVec.getValueType();
17935   EVT VT = DestVec.getValueType();
17936   unsigned NumSrcElts = SubVecVT.getVectorNumElements();
17937   // If the source only has a single vector element, the cost of creating a
17938   // vector and adding it is likely to exceed the cost of an insert_vector_elt.
17939   if (NumSrcElts == 1)
17940     return SDValue();
17941   unsigned ExtendRatio = VT.getSizeInBits() / SubVecVT.getSizeInBits();
17942   unsigned NumMaskVals = ExtendRatio * NumSrcElts;
17943 
17944   // Step 1: Create a shuffle mask that implements this insert operation. The
17945   // vector that we are inserting into will be operand 0 of the shuffle, so
17946   // those elements are just 'i'. The inserted subvector is in the first
17947   // positions of operand 1 of the shuffle. Example:
17948   // insert v4i32 V, (v2i16 X), 2 --> shuffle v8i16 V', X', {0,1,2,3,8,9,6,7}
17949   SmallVector<int, 16> Mask(NumMaskVals);
17950   for (unsigned i = 0; i != NumMaskVals; ++i) {
17951     if (i / NumSrcElts == InsIndex)
17952       Mask[i] = (i % NumSrcElts) + NumMaskVals;
17953     else
17954       Mask[i] = i;
17955   }
17956 
17957   // Bail out if the target cannot handle the shuffle we want to create.
17958   EVT SubVecEltVT = SubVecVT.getVectorElementType();
17959   EVT ShufVT = EVT::getVectorVT(*DAG.getContext(), SubVecEltVT, NumMaskVals);
17960   if (!TLI.isShuffleMaskLegal(Mask, ShufVT))
17961     return SDValue();
17962 
17963   // Step 2: Create a wide vector from the inserted source vector by appending
17964   // undefined elements. This is the same size as our destination vector.
17965   SDLoc DL(N);
17966   SmallVector<SDValue, 8> ConcatOps(ExtendRatio, DAG.getUNDEF(SubVecVT));
17967   ConcatOps[0] = SubVec;
17968   SDValue PaddedSubV = DAG.getNode(ISD::CONCAT_VECTORS, DL, ShufVT, ConcatOps);
17969 
17970   // Step 3: Shuffle in the padded subvector.
17971   SDValue DestVecBC = DAG.getBitcast(ShufVT, DestVec);
17972   SDValue Shuf = DAG.getVectorShuffle(ShufVT, DL, DestVecBC, PaddedSubV, Mask);
17973   AddToWorklist(PaddedSubV.getNode());
17974   AddToWorklist(DestVecBC.getNode());
17975   AddToWorklist(Shuf.getNode());
17976   return DAG.getBitcast(VT, Shuf);
17977 }
17978 
17979 SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
17980   SDValue InVec = N->getOperand(0);
17981   SDValue InVal = N->getOperand(1);
17982   SDValue EltNo = N->getOperand(2);
17983   SDLoc DL(N);
17984 
17985   EVT VT = InVec.getValueType();
17986   auto *IndexC = dyn_cast<ConstantSDNode>(EltNo);
17987 
17988   // Inserting into an out-of-bounds element is undefined.
17989   if (IndexC && VT.isFixedLengthVector() &&
17990       IndexC->getZExtValue() >= VT.getVectorNumElements())
17991     return DAG.getUNDEF(VT);
17992 
17993   // Remove redundant insertions:
17994   // (insert_vector_elt x (extract_vector_elt x idx) idx) -> x
17995   if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
17996       InVec == InVal.getOperand(0) && EltNo == InVal.getOperand(1))
17997     return InVec;
17998 
17999   if (!IndexC) {
18000     // If this is a variable insert into an undef vector, it might be better to splat:
18001     // inselt undef, InVal, EltNo --> build_vector < InVal, InVal, ... >
18002     if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT)) {
18003       if (VT.isScalableVector())
18004         return DAG.getSplatVector(VT, DL, InVal);
18005       else {
18006         SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), InVal);
18007         return DAG.getBuildVector(VT, DL, Ops);
18008       }
18009     }
18010     return SDValue();
18011   }
18012 
18013   if (VT.isScalableVector())
18014     return SDValue();
18015 
18016   unsigned NumElts = VT.getVectorNumElements();
18017 
18018   // We must know which element is being inserted for folds below here.
18019   unsigned Elt = IndexC->getZExtValue();
18020   if (SDValue Shuf = combineInsertEltToShuffle(N, Elt))
18021     return Shuf;
18022 
18023   // Canonicalize insert_vector_elt dag nodes.
18024   // Example:
18025   // (insert_vector_elt (insert_vector_elt A, Idx0), Idx1)
18026   // -> (insert_vector_elt (insert_vector_elt A, Idx1), Idx0)
18027   //
18028   // Do this only if the child insert_vector_elt node has one use; also
18029   // do this only if indices are both constants and Idx1 < Idx0.
18030   if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT && InVec.hasOneUse()
18031       && isa<ConstantSDNode>(InVec.getOperand(2))) {
18032     unsigned OtherElt = InVec.getConstantOperandVal(2);
18033     if (Elt < OtherElt) {
18034       // Swap nodes.
18035       SDValue NewOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT,
18036                                   InVec.getOperand(0), InVal, EltNo);
18037       AddToWorklist(NewOp.getNode());
18038       return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(InVec.getNode()),
18039                          VT, NewOp, InVec.getOperand(1), InVec.getOperand(2));
18040     }
18041   }
18042 
18043   // If we can't generate a legal BUILD_VECTOR, exit
18044   if (LegalOperations && !TLI.isOperationLegal(ISD::BUILD_VECTOR, VT))
18045     return SDValue();
18046 
18047   // Check that the operand is a BUILD_VECTOR (or UNDEF, which can essentially
18048   // be converted to a BUILD_VECTOR).  Fill in the Ops vector with the
18049   // vector elements.
18050   SmallVector<SDValue, 8> Ops;
18051   // Do not combine these two vectors if the output vector will not replace
18052   // the input vector.
18053   if (InVec.getOpcode() == ISD::BUILD_VECTOR && InVec.hasOneUse()) {
18054     Ops.append(InVec.getNode()->op_begin(),
18055                InVec.getNode()->op_end());
18056   } else if (InVec.isUndef()) {
18057     Ops.append(NumElts, DAG.getUNDEF(InVal.getValueType()));
18058   } else {
18059     return SDValue();
18060   }
18061   assert(Ops.size() == NumElts && "Unexpected vector size");
18062 
18063   // Insert the element
18064   if (Elt < Ops.size()) {
18065     // All the operands of BUILD_VECTOR must have the same type;
18066     // we enforce that here.
18067     EVT OpVT = Ops[0].getValueType();
18068     Ops[Elt] = OpVT.isInteger() ? DAG.getAnyExtOrTrunc(InVal, DL, OpVT) : InVal;
18069   }
18070 
18071   // Return the new vector
18072   return DAG.getBuildVector(VT, DL, Ops);
18073 }
18074 
18075 SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
18076                                                   SDValue EltNo,
18077                                                   LoadSDNode *OriginalLoad) {
18078   assert(OriginalLoad->isSimple());
18079 
18080   EVT ResultVT = EVE->getValueType(0);
18081   EVT VecEltVT = InVecVT.getVectorElementType();
18082 
18083   // If the vector element type is not a multiple of a byte then we are unable
18084   // to correctly compute an address to load only the extracted element as a
18085   // scalar.
18086   if (!VecEltVT.isByteSized())
18087     return SDValue();
18088 
18089   Align Alignment = OriginalLoad->getAlign();
18090   Align NewAlign = DAG.getDataLayout().getABITypeAlign(
18091       VecEltVT.getTypeForEVT(*DAG.getContext()));
18092 
18093   if (NewAlign > Alignment ||
18094       !TLI.isOperationLegalOrCustom(ISD::LOAD, VecEltVT))
18095     return SDValue();
18096 
18097   ISD::LoadExtType ExtTy = ResultVT.bitsGT(VecEltVT) ?
18098     ISD::NON_EXTLOAD : ISD::EXTLOAD;
18099   if (!TLI.shouldReduceLoadWidth(OriginalLoad, ExtTy, VecEltVT))
18100     return SDValue();
18101 
18102   Alignment = NewAlign;
18103 
18104   SDValue NewPtr = OriginalLoad->getBasePtr();
18105   SDValue Offset;
18106   EVT PtrType = NewPtr.getValueType();
18107   MachinePointerInfo MPI;
18108   SDLoc DL(EVE);
18109   if (auto *ConstEltNo = dyn_cast<ConstantSDNode>(EltNo)) {
18110     int Elt = ConstEltNo->getZExtValue();
18111     unsigned PtrOff = VecEltVT.getSizeInBits() * Elt / 8;
18112     Offset = DAG.getConstant(PtrOff, DL, PtrType);
18113     MPI = OriginalLoad->getPointerInfo().getWithOffset(PtrOff);
18114   } else {
18115     Offset = DAG.getZExtOrTrunc(EltNo, DL, PtrType);
18116     Offset = DAG.getNode(
18117         ISD::MUL, DL, PtrType, Offset,
18118         DAG.getConstant(VecEltVT.getStoreSize(), DL, PtrType));
18119     // Discard the pointer info except the address space because the memory
18120     // operand can't represent this new access since the offset is variable.
18121     MPI = MachinePointerInfo(OriginalLoad->getPointerInfo().getAddrSpace());
18122   }
18123   NewPtr = DAG.getMemBasePlusOffset(NewPtr, Offset, DL);
18124 
18125   // The replacement we need to do here is a little tricky: we need to
18126   // replace an extractelement of a load with a load.
18127   // Use ReplaceAllUsesOfValuesWith to do the replacement.
18128   // Note that this replacement assumes that the extractelement is the only
18129   // use of the load; that's okay because we don't want to perform this
18130   // transformation in other cases anyway.
18131   SDValue Load;
18132   SDValue Chain;
18133   if (ResultVT.bitsGT(VecEltVT)) {
18134     // If the result type of vextract is wider than the load, then issue an
18135     // extending load instead.
18136     ISD::LoadExtType ExtType = TLI.isLoadExtLegal(ISD::ZEXTLOAD, ResultVT,
18137                                                   VecEltVT)
18138                                    ? ISD::ZEXTLOAD
18139                                    : ISD::EXTLOAD;
18140     Load = DAG.getExtLoad(ExtType, SDLoc(EVE), ResultVT,
18141                           OriginalLoad->getChain(), NewPtr, MPI, VecEltVT,
18142                           Alignment, OriginalLoad->getMemOperand()->getFlags(),
18143                           OriginalLoad->getAAInfo());
18144     Chain = Load.getValue(1);
18145   } else {
18146     Load = DAG.getLoad(
18147         VecEltVT, SDLoc(EVE), OriginalLoad->getChain(), NewPtr, MPI, Alignment,
18148         OriginalLoad->getMemOperand()->getFlags(), OriginalLoad->getAAInfo());
18149     Chain = Load.getValue(1);
18150     if (ResultVT.bitsLT(VecEltVT))
18151       Load = DAG.getNode(ISD::TRUNCATE, SDLoc(EVE), ResultVT, Load);
18152     else
18153       Load = DAG.getBitcast(ResultVT, Load);
18154   }
18155   WorklistRemover DeadNodes(*this);
18156   SDValue From[] = { SDValue(EVE, 0), SDValue(OriginalLoad, 1) };
18157   SDValue To[] = { Load, Chain };
18158   DAG.ReplaceAllUsesOfValuesWith(From, To, 2);
18159   // Make sure to revisit this node to clean it up; it will usually be dead.
18160   AddToWorklist(EVE);
18161   // Since we're explicitly calling ReplaceAllUses, add the new node to the
18162   // worklist explicitly as well.
18163   AddToWorklistWithUsers(Load.getNode());
18164   ++OpsNarrowed;
18165   return SDValue(EVE, 0);
18166 }
18167 
18168 /// Transform a vector binary operation into a scalar binary operation by moving
18169 /// the math/logic after an extract element of a vector.
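/// For example:
///   (extract_vector_elt (add X, Y), C)
///     --> (add (extract_vector_elt X, C), (extract_vector_elt Y, C))
/// This is done when one operand is a constant build_vector, since the
/// extract of the constant side is constant-folded away.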
18170 static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
18171                                        bool LegalOperations) {
18172   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
18173   SDValue Vec = ExtElt->getOperand(0);
18174   SDValue Index = ExtElt->getOperand(1);
18175   auto *IndexC = dyn_cast<ConstantSDNode>(Index);
18176   if (!IndexC || !TLI.isBinOp(Vec.getOpcode()) || !Vec.hasOneUse() ||
18177       Vec.getNode()->getNumValues() != 1)
18178     return SDValue();
18179 
18180   // Targets may want to avoid this to prevent an expensive register transfer.
18181   if (!TLI.shouldScalarizeBinop(Vec))
18182     return SDValue();
18183 
18184   // Extracting an element of a vector constant is constant-folded, so this
18185   // transform is just replacing a vector op with a scalar op while moving the
18186   // extract.
18187   SDValue Op0 = Vec.getOperand(0);
18188   SDValue Op1 = Vec.getOperand(1);
18189   if (isAnyConstantBuildVector(Op0, true) ||
18190       isAnyConstantBuildVector(Op1, true)) {
18191     // extractelt (binop X, C), IndexC --> binop (extractelt X, IndexC), C'
18192     // extractelt (binop C, X), IndexC --> binop C', (extractelt X, IndexC)
18193     SDLoc DL(ExtElt);
18194     EVT VT = ExtElt->getValueType(0);
18195     SDValue Ext0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Index);
18196     SDValue Ext1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op1, Index);
18197     return DAG.getNode(Vec.getOpcode(), DL, VT, Ext0, Ext1);
18198   }
18199 
18200   return SDValue();
18201 }
18202 
18203 SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
18204   SDValue VecOp = N->getOperand(0);
18205   SDValue Index = N->getOperand(1);
18206   EVT ScalarVT = N->getValueType(0);
18207   EVT VecVT = VecOp.getValueType();
18208   if (VecOp.isUndef())
18209     return DAG.getUNDEF(ScalarVT);
18210 
18211   // (extract_vector_elt (insert_vector_elt vec, val, idx), idx) -> val
18212   //
18213   // This only really matters if the index is non-constant since other combines
18214   // on the constant elements already work.
18215   SDLoc DL(N);
18216   if (VecOp.getOpcode() == ISD::INSERT_VECTOR_ELT &&
18217       Index == VecOp.getOperand(2)) {
18218     SDValue Elt = VecOp.getOperand(1);
18219     return VecVT.isInteger() ? DAG.getAnyExtOrTrunc(Elt, DL, ScalarVT) : Elt;
18220   }
18221 
18222   // (vextract (scalar_to_vector val), 0) -> val
18223   if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR) {
18224     // Only 0'th element of SCALAR_TO_VECTOR is defined.
18225     if (DAG.isKnownNeverZero(Index))
18226       return DAG.getUNDEF(ScalarVT);
18227 
18228     // Check if the result type doesn't match the inserted element type. A
18229     // SCALAR_TO_VECTOR may truncate the inserted element and the
18230     // EXTRACT_VECTOR_ELT may widen the extracted vector.
18231     SDValue InOp = VecOp.getOperand(0);
18232     if (InOp.getValueType() != ScalarVT) {
18233       assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
18234       return DAG.getSExtOrTrunc(InOp, DL, ScalarVT);
18235     }
18236     return InOp;
18237   }
18238 
18239   // extract_vector_elt of out-of-bounds element -> UNDEF
18240   auto *IndexC = dyn_cast<ConstantSDNode>(Index);
18241   if (IndexC && VecVT.isFixedLengthVector() &&
18242       IndexC->getAPIntValue().uge(VecVT.getVectorNumElements()))
18243     return DAG.getUNDEF(ScalarVT);
18244 
18245   // extract_vector_elt (build_vector x, y), 1 -> y
18246   if (((IndexC && VecOp.getOpcode() == ISD::BUILD_VECTOR) ||
18247        VecOp.getOpcode() == ISD::SPLAT_VECTOR) &&
18248       TLI.isTypeLegal(VecVT) &&
18249       (VecOp.hasOneUse() || TLI.aggressivelyPreferBuildVectorSources(VecVT))) {
18250     assert((VecOp.getOpcode() != ISD::BUILD_VECTOR ||
18251             VecVT.isFixedLengthVector()) &&
18252            "BUILD_VECTOR used for scalable vectors");
18253     unsigned IndexVal =
18254         VecOp.getOpcode() == ISD::BUILD_VECTOR ? IndexC->getZExtValue() : 0;
18255     SDValue Elt = VecOp.getOperand(IndexVal);
18256     EVT InEltVT = Elt.getValueType();
18257 
18258     // Sometimes a build_vector's scalar input types do not match the result type.
18259     if (ScalarVT == InEltVT)
18260       return Elt;
18261 
18262     // TODO: It may be useful to truncate, when the truncation is free, if
18263     // the build_vector implicitly converts.
18264   }
18265 
18266   if (VecVT.isScalableVector())
18267     return SDValue();
18268 
18269   // All the code from this point onwards assumes fixed width vectors, but it's
18270   // possible that some of the combinations could be made to work for scalable
18271   // vectors too.
18272   unsigned NumElts = VecVT.getVectorNumElements();
18273   unsigned VecEltBitWidth = VecVT.getScalarSizeInBits();
18274 
18275   // TODO: These transforms should not require the 'hasOneUse' restriction, but
18276   // there are regressions on multiple targets without it. We can end up with a
18277   // mess of scalar and vector code if we reduce only part of the DAG to scalar.
18278   if (IndexC && VecOp.getOpcode() == ISD::BITCAST && VecVT.isInteger() &&
18279       VecOp.hasOneUse()) {
18280     // The vector index of the LSBs of the source depends on the endianness.
18281     bool IsLE = DAG.getDataLayout().isLittleEndian();
18282     unsigned ExtractIndex = IndexC->getZExtValue();
18283     // extract_elt (v2i32 (bitcast i64:x)), BCTruncElt -> i32 (trunc i64:x)
18284     unsigned BCTruncElt = IsLE ? 0 : NumElts - 1;
18285     SDValue BCSrc = VecOp.getOperand(0);
18286     if (ExtractIndex == BCTruncElt && BCSrc.getValueType().isScalarInteger())
18287       return DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, BCSrc);
18288 
18289     if (LegalTypes && BCSrc.getValueType().isInteger() &&
18290         BCSrc.getOpcode() == ISD::SCALAR_TO_VECTOR) {
18291       // ext_elt (bitcast (scalar_to_vec i64 X to v2i64) to v4i32), TruncElt -->
18292       // trunc i64 X to i32
18293       SDValue X = BCSrc.getOperand(0);
18294       assert(X.getValueType().isScalarInteger() && ScalarVT.isScalarInteger() &&
18295              "Extract element and scalar to vector can't change element type "
18296              "from FP to integer.");
18297       unsigned XBitWidth = X.getValueSizeInBits();
18298       BCTruncElt = IsLE ? 0 : XBitWidth / VecEltBitWidth - 1;
18299 
18300       // An extract element return value type can be wider than its vector
18301       // operand element type. In that case, the high bits are undefined, so
18302       // it's possible that we may need to extend rather than truncate.
18303       if (ExtractIndex == BCTruncElt && XBitWidth > VecEltBitWidth) {
18304         assert(XBitWidth % VecEltBitWidth == 0 &&
18305                "Scalar bitwidth must be a multiple of vector element bitwidth");
18306         return DAG.getAnyExtOrTrunc(X, DL, ScalarVT);
18307       }
18308     }
18309   }
18310 
18311   if (SDValue BO = scalarizeExtractedBinop(N, DAG, LegalOperations))
18312     return BO;
18313 
18314   // Transform: (EXTRACT_VECTOR_ELT( VECTOR_SHUFFLE )) -> EXTRACT_VECTOR_ELT.
18315   // We only perform this optimization before the op legalization phase because
18316   // we may introduce new vector instructions which are not backed by TD
18317   // patterns. For example on AVX, we may end up extracting elements from a
18318   // wide vector without using extract_subvector. However, if we can find an
18319   // underlying scalar value, then we can always use that.
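  // For example:
  //   (extract_vector_elt (vector_shuffle<1,u,u,u> X, Y), 0)
  //     --> (extract_vector_elt X, 1)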
18320   if (IndexC && VecOp.getOpcode() == ISD::VECTOR_SHUFFLE) {
18321     auto *Shuf = cast<ShuffleVectorSDNode>(VecOp);
18322     // Find the new index to extract from.
18323     int OrigElt = Shuf->getMaskElt(IndexC->getZExtValue());
18324 
18325     // Extracting an undef index is undef.
18326     if (OrigElt == -1)
18327       return DAG.getUNDEF(ScalarVT);
18328 
18329     // Select the right vector half to extract from.
18330     SDValue SVInVec;
18331     if (OrigElt < (int)NumElts) {
18332       SVInVec = VecOp.getOperand(0);
18333     } else {
18334       SVInVec = VecOp.getOperand(1);
18335       OrigElt -= NumElts;
18336     }
18337 
18338     if (SVInVec.getOpcode() == ISD::BUILD_VECTOR) {
18339       SDValue InOp = SVInVec.getOperand(OrigElt);
18340       if (InOp.getValueType() != ScalarVT) {
18341         assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
18342         InOp = DAG.getSExtOrTrunc(InOp, DL, ScalarVT);
18343       }
18344 
18345       return InOp;
18346     }
18347 
18348     // FIXME: We should handle recursing on other vector shuffles and
18349     // scalar_to_vector here as well.
18350 
18351     if (!LegalOperations ||
18352         // FIXME: Should really be just isOperationLegalOrCustom.
18353         TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VecVT) ||
18354         TLI.isOperationExpand(ISD::VECTOR_SHUFFLE, VecVT)) {
18355       return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, SVInVec,
18356                          DAG.getVectorIdxConstant(OrigElt, DL));
18357     }
18358   }
18359 
18360   // If only EXTRACT_VECTOR_ELT nodes use the source vector we can
18361   // simplify it based on the (valid) extraction indices.
18362   if (llvm::all_of(VecOp->uses(), [&](SDNode *Use) {
18363         return Use->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
18364                Use->getOperand(0) == VecOp &&
18365                isa<ConstantSDNode>(Use->getOperand(1));
18366       })) {
18367     APInt DemandedElts = APInt::getNullValue(NumElts);
18368     for (SDNode *Use : VecOp->uses()) {
18369       auto *CstElt = cast<ConstantSDNode>(Use->getOperand(1));
18370       if (CstElt->getAPIntValue().ult(NumElts))
18371         DemandedElts.setBit(CstElt->getZExtValue());
18372     }
18373     if (SimplifyDemandedVectorElts(VecOp, DemandedElts, true)) {
18374       // We simplified the vector operand of this extract element. If this
18375       // extract is not dead, visit it again so it is folded properly.
18376       if (N->getOpcode() != ISD::DELETED_NODE)
18377         AddToWorklist(N);
18378       return SDValue(N, 0);
18379     }
18380     APInt DemandedBits = APInt::getAllOnesValue(VecEltBitWidth);
18381     if (SimplifyDemandedBits(VecOp, DemandedBits, DemandedElts, true)) {
18382       // We simplified the vector operand of this extract element. If this
18383       // extract is not dead, visit it again so it is folded properly.
18384       if (N->getOpcode() != ISD::DELETED_NODE)
18385         AddToWorklist(N);
18386       return SDValue(N, 0);
18387     }
18388   }
18389 
18390   // Everything under here is trying to match an extract of a loaded value.
18391   // If the result of the load has to be truncated, then it's not necessarily
18392   // profitable.
18393   bool BCNumEltsChanged = false;
18394   EVT ExtVT = VecVT.getVectorElementType();
18395   EVT LVT = ExtVT;
18396   if (ScalarVT.bitsLT(LVT) && !TLI.isTruncateFree(LVT, ScalarVT))
18397     return SDValue();
18398 
18399   if (VecOp.getOpcode() == ISD::BITCAST) {
18400     // Don't duplicate a load with other uses.
18401     if (!VecOp.hasOneUse())
18402       return SDValue();
18403 
18404     EVT BCVT = VecOp.getOperand(0).getValueType();
18405     if (!BCVT.isVector() || ExtVT.bitsGT(BCVT.getVectorElementType()))
18406       return SDValue();
18407     if (NumElts != BCVT.getVectorNumElements())
18408       BCNumEltsChanged = true;
18409     VecOp = VecOp.getOperand(0);
18410     ExtVT = BCVT.getVectorElementType();
18411   }
18412 
18413   // extract (vector load $addr), i --> load $addr + i * size
18414   if (!LegalOperations && !IndexC && VecOp.hasOneUse() &&
18415       ISD::isNormalLoad(VecOp.getNode()) &&
18416       !Index->hasPredecessor(VecOp.getNode())) {
18417     auto *VecLoad = dyn_cast<LoadSDNode>(VecOp);
18418     if (VecLoad && VecLoad->isSimple())
18419       return scalarizeExtractedVectorLoad(N, VecVT, Index, VecLoad);
18420   }
18421 
18422   // Perform only after legalization to ensure build_vector / vector_shuffle
18423   // optimizations have already been done.
18424   if (!LegalOperations || !IndexC)
18425     return SDValue();
18426 
18427   // (vextract (v4f32 load $addr), c) -> (f32 load $addr+c*size)
18428   // (vextract (v4f32 s2v (f32 load $addr)), c) -> (f32 load $addr+c*size)
18429   // (vextract (v4f32 shuffle (load $addr), <1,u,u,u>), 0) -> (f32 load $addr)
18430   int Elt = IndexC->getZExtValue();
18431   LoadSDNode *LN0 = nullptr;
18432   if (ISD::isNormalLoad(VecOp.getNode())) {
18433     LN0 = cast<LoadSDNode>(VecOp);
18434   } else if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR &&
18435              VecOp.getOperand(0).getValueType() == ExtVT &&
18436              ISD::isNormalLoad(VecOp.getOperand(0).getNode())) {
18437     // Don't duplicate a load with other uses.
18438     if (!VecOp.hasOneUse())
18439       return SDValue();
18440 
18441     LN0 = cast<LoadSDNode>(VecOp.getOperand(0));
18442   }
18443   if (auto *Shuf = dyn_cast<ShuffleVectorSDNode>(VecOp)) {
18444     // (vextract (vector_shuffle (load $addr), v2, <1, u, u, u>), 1)
18445     // =>
18446     // (load $addr+1*size)
18447 
18448     // Don't duplicate a load with other uses.
18449     if (!VecOp.hasOneUse())
18450       return SDValue();
18451 
18452     // If the bit convert changed the number of elements, it is unsafe
18453     // to examine the mask.
18454     if (BCNumEltsChanged)
18455       return SDValue();
18456 
18457     // Select the input vector, guarding against an out-of-range extract index.
18458     int Idx = (Elt > (int)NumElts) ? -1 : Shuf->getMaskElt(Elt);
18459     VecOp = (Idx < (int)NumElts) ? VecOp.getOperand(0) : VecOp.getOperand(1);
18460 
18461     if (VecOp.getOpcode() == ISD::BITCAST) {
18462       // Don't duplicate a load with other uses.
18463       if (!VecOp.hasOneUse())
18464         return SDValue();
18465 
18466       VecOp = VecOp.getOperand(0);
18467     }
18468     if (ISD::isNormalLoad(VecOp.getNode())) {
18469       LN0 = cast<LoadSDNode>(VecOp);
18470       Elt = (Idx < (int)NumElts) ? Idx : Idx - (int)NumElts;
18471       Index = DAG.getConstant(Elt, DL, Index.getValueType());
18472     }
18473   } else if (VecOp.getOpcode() == ISD::CONCAT_VECTORS && !BCNumEltsChanged &&
18474              VecVT.getVectorElementType() == ScalarVT &&
18475              (!LegalTypes ||
18476               TLI.isTypeLegal(
18477                   VecOp.getOperand(0).getValueType().getVectorElementType()))) {
18478     // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 0
18479     //      -> extract_vector_elt a, 0
18480     // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 1
18481     //      -> extract_vector_elt a, 1
18482     // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 2
18483     //      -> extract_vector_elt b, 0
18484     // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 3
18485     //      -> extract_vector_elt b, 1
18486     SDLoc SL(N);
18487     EVT ConcatVT = VecOp.getOperand(0).getValueType();
18488     unsigned ConcatNumElts = ConcatVT.getVectorNumElements();
18489     SDValue NewIdx = DAG.getConstant(Elt % ConcatNumElts, SL,
18490                                      Index.getValueType());
18491 
18492     SDValue ConcatOp = VecOp.getOperand(Elt / ConcatNumElts);
18493     SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL,
18494                               ConcatVT.getVectorElementType(),
18495                               ConcatOp, NewIdx);
18496     return DAG.getNode(ISD::BITCAST, SL, ScalarVT, Elt);
18497   }
18498 
18499   // Make sure we found a non-volatile load and the extractelement is
18500   // the only use.
18501   if (!LN0 || !LN0->hasNUsesOfValue(1,0) || !LN0->isSimple())
18502     return SDValue();
18503 
18504   // If Idx was -1 above, Elt is going to be -1, so just return undef.
18505   if (Elt == -1)
18506     return DAG.getUNDEF(LVT);
18507 
18508   return scalarizeExtractedVectorLoad(N, VecVT, Index, LN0);
18509 }
18510 
18511 // Simplify (build_vec (ext )) to (bitcast (build_vec ))
18512 SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) {
18513   // We perform this optimization post type-legalization because
18514   // the type-legalizer often scalarizes integer-promoted vectors.
18515   // Performing this optimization earlier may create bit-casts which
18516   // will be type-legalized into complex code sequences.
18517   // We perform this optimization only before the operation legalizer because we
18518   // may introduce illegal operations.
18519   if (Level != AfterLegalizeVectorOps && Level != AfterLegalizeTypes)
18520     return SDValue();
18521 
18522   unsigned NumInScalars = N->getNumOperands();
18523   SDLoc DL(N);
18524   EVT VT = N->getValueType(0);
18525 
18526   // Check to see if this is a BUILD_VECTOR of a bunch of values
18527   // which come from any_extend or zero_extend nodes. If so, we can create
18528   // a new BUILD_VECTOR using bit-casts which may enable other BUILD_VECTOR
18529   // optimizations. We do not handle sign-extend because we can't fill the sign
18530   // using shuffles.
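  // For example (illustrative, little-endian target):
  //   (v2i32 build_vector (zext i16:a to i32), (zext i16:b to i32))
  //     --> (bitcast (v4i16 build_vector a, 0, b, 0) to v2i32)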
18531   EVT SourceType = MVT::Other;
18532   bool AllAnyExt = true;
18533 
18534   for (unsigned i = 0; i != NumInScalars; ++i) {
18535     SDValue In = N->getOperand(i);
18536     // Ignore undef inputs.
18537     if (In.isUndef()) continue;
18538 
18539     bool AnyExt  = In.getOpcode() == ISD::ANY_EXTEND;
18540     bool ZeroExt = In.getOpcode() == ISD::ZERO_EXTEND;
18541 
18542     // Abort if the element is not an extension.
18543     if (!ZeroExt && !AnyExt) {
18544       SourceType = MVT::Other;
18545       break;
18546     }
18547 
18548     // The input is a ZeroExt or AnyExt. Check the original type.
18549     EVT InTy = In.getOperand(0).getValueType();
18550 
18551     // Check that all of the widened source types are the same.
18552     if (SourceType == MVT::Other)
18553       // First time.
18554       SourceType = InTy;
18555     else if (InTy != SourceType) {
18556       // Multiple incoming source types. Abort.
18557       SourceType = MVT::Other;
18558       break;
18559     }
18560 
18561     // Check if all of the extends are ANY_EXTENDs.
18562     AllAnyExt &= AnyExt;
18563   }
18564 
18565   // In order to have valid types, all of the inputs must be extended from the
18566   // same source type and all of the inputs must be any or zero extend.
18567   // Scalar sizes must be a power of two.
18568   EVT OutScalarTy = VT.getScalarType();
18569   bool ValidTypes = SourceType != MVT::Other &&
18570                  isPowerOf2_32(OutScalarTy.getSizeInBits()) &&
18571                  isPowerOf2_32(SourceType.getSizeInBits());
18572 
18573   // Create a new simpler BUILD_VECTOR sequence which other optimizations can
18574   // turn into a single shuffle instruction.
18575   if (!ValidTypes)
18576     return SDValue();
18577 
18578   // If we already have a splat buildvector, then don't fold it if it means
18579   // introducing zeros.
18580   if (!AllAnyExt && DAG.isSplatValue(SDValue(N, 0), /*AllowUndefs*/ true))
18581     return SDValue();
18582 
18583   bool isLE = DAG.getDataLayout().isLittleEndian();
18584   unsigned ElemRatio = OutScalarTy.getSizeInBits()/SourceType.getSizeInBits();
18585   assert(ElemRatio > 1 && "Invalid element size ratio");
18586   SDValue Filler = AllAnyExt ? DAG.getUNDEF(SourceType):
18587                                DAG.getConstant(0, DL, SourceType);
18588 
18589   unsigned NewBVElems = ElemRatio * VT.getVectorNumElements();
18590   SmallVector<SDValue, 8> Ops(NewBVElems, Filler);
18591 
18592   // Populate the new build_vector
18593   for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
18594     SDValue Cast = N->getOperand(i);
18595     assert((Cast.getOpcode() == ISD::ANY_EXTEND ||
18596             Cast.getOpcode() == ISD::ZERO_EXTEND ||
18597             Cast.isUndef()) && "Invalid cast opcode");
18598     SDValue In;
18599     if (Cast.isUndef())
18600       In = DAG.getUNDEF(SourceType);
18601     else
18602       In = Cast->getOperand(0);
18603     unsigned Index = isLE ? (i * ElemRatio) :
18604                             (i * ElemRatio + (ElemRatio - 1));
18605 
18606     assert(Index < Ops.size() && "Invalid index");
18607     Ops[Index] = In;
18608   }
18609 
18610   // The type of the new BUILD_VECTOR node.
18611   EVT VecVT = EVT::getVectorVT(*DAG.getContext(), SourceType, NewBVElems);
18612   assert(VecVT.getSizeInBits() == VT.getSizeInBits() &&
18613          "Invalid vector size");
18614   // Check if the new vector type is legal.
18615   if (!isTypeLegal(VecVT) ||
18616       (!TLI.isOperationLegal(ISD::BUILD_VECTOR, VecVT) &&
18617        TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)))
18618     return SDValue();
18619 
18620   // Make the new BUILD_VECTOR.
18621   SDValue BV = DAG.getBuildVector(VecVT, DL, Ops);
18622 
18623   // The new BUILD_VECTOR node has the potential to be further optimized.
18624   AddToWorklist(BV.getNode());
18625   // Bitcast to the desired type.
18626   return DAG.getBitcast(VT, BV);
18627 }
18628 
18629 // Simplify (build_vec (trunc $1)
18630 //                     (trunc (srl $1 half-width))
18631 //                     (trunc (srl $1 (2 * half-width))) …)
18632 // to (bitcast $1)
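// For example (illustrative, little-endian target), with half-width 16:
//   (v4i16 build_vector (trunc i64:x),
//                       (trunc (srl x, 16)),
//                       (trunc (srl x, 32)),
//                       (trunc (srl x, 48)))
//     --> (bitcast x to v4i16)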
18633 SDValue DAGCombiner::reduceBuildVecTruncToBitCast(SDNode *N) {
18634   assert(N->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");
18635 
18636   // Only for little endian
18637   if (!DAG.getDataLayout().isLittleEndian())
18638     return SDValue();
18639 
18640   SDLoc DL(N);
18641   EVT VT = N->getValueType(0);
18642   EVT OutScalarTy = VT.getScalarType();
18643   uint64_t ScalarTypeBitsize = OutScalarTy.getSizeInBits();
18644 
18645   // Only for power-of-two types, to be sure that the bitcast is valid.
18646   if (!isPowerOf2_64(ScalarTypeBitsize))
18647     return SDValue();
18648 
18649   unsigned NumInScalars = N->getNumOperands();
18650 
18651   // Look through bitcasts
18652   auto PeekThroughBitcast = [](SDValue Op) {
18653     if (Op.getOpcode() == ISD::BITCAST)
18654       return Op.getOperand(0);
18655     return Op;
18656   };
18657 
18658   // The source value where all the parts are extracted.
18659   SDValue Src;
18660   for (unsigned i = 0; i != NumInScalars; ++i) {
18661     SDValue In = PeekThroughBitcast(N->getOperand(i));
18662     // Ignore undef inputs.
18663     if (In.isUndef()) continue;
18664 
18665     if (In.getOpcode() != ISD::TRUNCATE)
18666       return SDValue();
18667 
18668     In = PeekThroughBitcast(In.getOperand(0));
18669 
18670     if (In.getOpcode() != ISD::SRL) {
18671       // For now only handle build_vec without shuffling; handle shifts
18672       // here in the future.
18673       if (i != 0)
18674         return SDValue();
18675 
18676       Src = In;
18677     } else {
18678       // In is SRL
18679       SDValue part = PeekThroughBitcast(In.getOperand(0));
18680 
18681       if (!Src) {
18682         Src = part;
18683       } else if (Src != part) {
18684         // Vector parts do not stem from the same variable
18685         return SDValue();
18686       }
18687 
18688       SDValue ShiftAmtVal = In.getOperand(1);
18689       if (!isa<ConstantSDNode>(ShiftAmtVal))
18690         return SDValue();
18691 
18692       uint64_t ShiftAmt = In.getNode()->getConstantOperandVal(1);
18693 
18694       // The extracted value is not extracted at the right position
18695       if (ShiftAmt != i * ScalarTypeBitsize)
18696         return SDValue();
18697     }
18698   }
18699 
18700   // Only cast if the size is the same
18701   if (Src.getValueType().getSizeInBits() != VT.getSizeInBits())
18702     return SDValue();
18703 
18704   return DAG.getBitcast(VT, Src);
18705 }
18706 
18707 SDValue DAGCombiner::createBuildVecShuffle(const SDLoc &DL, SDNode *N,
18708                                            ArrayRef<int> VectorMask,
18709                                            SDValue VecIn1, SDValue VecIn2,
18710                                            unsigned LeftIdx, bool DidSplitVec) {
18711   SDValue ZeroIdx = DAG.getVectorIdxConstant(0, DL);
18712 
18713   EVT VT = N->getValueType(0);
18714   EVT InVT1 = VecIn1.getValueType();
18715   EVT InVT2 = VecIn2.getNode() ? VecIn2.getValueType() : InVT1;
18716 
18717   unsigned NumElems = VT.getVectorNumElements();
18718   unsigned ShuffleNumElems = NumElems;
18719 
18720   // If we artificially split a vector in two already, then the offsets in the
18721   // operands will all be based off of VecIn1, even those in VecIn2.
18722   unsigned Vec2Offset = DidSplitVec ? 0 : InVT1.getVectorNumElements();
18723 
18724   uint64_t VTSize = VT.getFixedSizeInBits();
18725   uint64_t InVT1Size = InVT1.getFixedSizeInBits();
18726   uint64_t InVT2Size = InVT2.getFixedSizeInBits();
18727 
18728   // We can't generate a shuffle node with mismatched input and output types.
18729   // Try to make the types match the type of the output.
18730   if (InVT1 != VT || InVT2 != VT) {
18731     if ((VTSize % InVT1Size == 0) && InVT1 == InVT2) {
18732       // If the output vector length is a multiple of both input lengths,
18733       // we can concatenate them and pad the rest with undefs.
18734       unsigned NumConcats = VTSize / InVT1Size;
18735       assert(NumConcats >= 2 && "Concat needs at least two inputs!");
18736       SmallVector<SDValue, 2> ConcatOps(NumConcats, DAG.getUNDEF(InVT1));
18737       ConcatOps[0] = VecIn1;
18738       ConcatOps[1] = VecIn2 ? VecIn2 : DAG.getUNDEF(InVT1);
18739       VecIn1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
18740       VecIn2 = SDValue();
18741     } else if (InVT1Size == VTSize * 2) {
18742       if (!TLI.isExtractSubvectorCheap(VT, InVT1, NumElems))
18743         return SDValue();
18744 
18745       if (!VecIn2.getNode()) {
18746         // If we only have one input vector, and it's twice the size of the
18747         // output, split it in two.
18748         VecIn2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, VecIn1,
18749                              DAG.getVectorIdxConstant(NumElems, DL));
18750         VecIn1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, VecIn1, ZeroIdx);
18751         // Since we now have shorter input vectors, adjust the offset of the
18752         // second vector's start.
18753         Vec2Offset = NumElems;
18754       } else if (InVT2Size <= InVT1Size) {
18755         // VecIn1 is wider than the output, and we have another, possibly
18756         // smaller input. Pad the smaller input with undefs, shuffle at the
18757         // input vector width, and extract the output.
18758         // The shuffle type is different than VT, so check legality again.
18759         if (LegalOperations &&
18760             !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, InVT1))
18761           return SDValue();
18762 
18763         // Legalizing INSERT_SUBVECTOR is tricky - you basically have to
18764         // lower it back into a BUILD_VECTOR. So if the inserted type is
18765         // illegal, don't even try.
18766         if (InVT1 != InVT2) {
18767           if (!TLI.isTypeLegal(InVT2))
18768             return SDValue();
18769           VecIn2 = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, InVT1,
18770                                DAG.getUNDEF(InVT1), VecIn2, ZeroIdx);
18771         }
18772         ShuffleNumElems = NumElems * 2;
18773       } else {
18774         // Both VecIn1 and VecIn2 are wider than the output, and VecIn2 is wider
18775         // than VecIn1. We can't handle this for now - this case will disappear
18776         // when we start sorting the vectors by type.
18777         return SDValue();
18778       }
18779     } else if (InVT2Size * 2 == VTSize && InVT1Size == VTSize) {
18780       SmallVector<SDValue, 2> ConcatOps(2, DAG.getUNDEF(InVT2));
18781       ConcatOps[0] = VecIn2;
18782       VecIn2 = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
18783     } else {
18784       // TODO: Support cases where the length mismatch isn't exactly by a
18785       // factor of 2.
18786       // TODO: Move this check upwards, so that if we have bad type
18787       // mismatches, we don't create any DAG nodes.
18788       return SDValue();
18789     }
18790   }
18791 
18792   // Initialize mask to undef.
18793   SmallVector<int, 8> Mask(ShuffleNumElems, -1);
18794 
18795   // Only need to run up to the number of elements actually used, not the
18796   // total number of elements in the shuffle - if we are shuffling a wider
18797   // vector, the high lanes should be set to undef.
18798   for (unsigned i = 0; i != NumElems; ++i) {
18799     if (VectorMask[i] <= 0)
18800       continue;
18801 
18802     unsigned ExtIndex = N->getOperand(i).getConstantOperandVal(1);
18803     if (VectorMask[i] == (int)LeftIdx) {
18804       Mask[i] = ExtIndex;
18805     } else if (VectorMask[i] == (int)LeftIdx + 1) {
18806       Mask[i] = Vec2Offset + ExtIndex;
18807     }
18808   }
18809 
18810   // The types of the input vectors may have changed above.
18811   InVT1 = VecIn1.getValueType();
18812 
18813   // If we already have a VecIn2, it should have the same type as VecIn1.
18814   // If we don't, get an undef/zero vector of the appropriate type.
18815   VecIn2 = VecIn2.getNode() ? VecIn2 : DAG.getUNDEF(InVT1);
18816   assert(InVT1 == VecIn2.getValueType() && "Unexpected second input type.");
18817 
18818   SDValue Shuffle = DAG.getVectorShuffle(InVT1, DL, VecIn1, VecIn2, Mask);
18819   if (ShuffleNumElems > NumElems)
18820     Shuffle = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Shuffle, ZeroIdx);
18821 
18822   return Shuffle;
18823 }
18824 
18825 static SDValue reduceBuildVecToShuffleWithZero(SDNode *BV, SelectionDAG &DAG) {
18826   assert(BV->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");
18827 
18828   // First, determine where the build vector is not undef.
18829   // TODO: We could extend this to handle zero elements as well as undefs.
18830   int NumBVOps = BV->getNumOperands();
18831   int ZextElt = -1;
18832   for (int i = 0; i != NumBVOps; ++i) {
18833     SDValue Op = BV->getOperand(i);
18834     if (Op.isUndef())
18835       continue;
18836     if (ZextElt == -1)
18837       ZextElt = i;
18838     else
18839       return SDValue();
18840   }
18841   // Bail out if there's no non-undef element.
18842   if (ZextElt == -1)
18843     return SDValue();
18844 
18845   // The build vector contains some number of undef elements and exactly
18846   // one other element. That other element must be a zero-extended scalar
18847   // extracted from a vector at a constant index to turn this into a shuffle.
18848   // Also, require that the build vector does not implicitly truncate/extend
18849   // its elements.
18850   // TODO: This could be enhanced to allow ANY_EXTEND as well as ZERO_EXTEND.
18851   EVT VT = BV->getValueType(0);
18852   SDValue Zext = BV->getOperand(ZextElt);
18853   if (Zext.getOpcode() != ISD::ZERO_EXTEND || !Zext.hasOneUse() ||
18854       Zext.getOperand(0).getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
18855       !isa<ConstantSDNode>(Zext.getOperand(0).getOperand(1)) ||
18856       Zext.getValueSizeInBits() != VT.getScalarSizeInBits())
18857     return SDValue();
18858 
18859   // The zero-extend must be a multiple of the source size, and we must be
18860   // building a vector of the same size as the source of the extract element.
18861   SDValue Extract = Zext.getOperand(0);
18862   unsigned DestSize = Zext.getValueSizeInBits();
18863   unsigned SrcSize = Extract.getValueSizeInBits();
18864   if (DestSize % SrcSize != 0 ||
18865       Extract.getOperand(0).getValueSizeInBits() != VT.getSizeInBits())
18866     return SDValue();
18867 
18868   // Create a shuffle mask that will combine the extracted element with zeros
18869   // and undefs.
18870   int ZextRatio = DestSize / SrcSize;
18871   int NumMaskElts = NumBVOps * ZextRatio;
18872   SmallVector<int, 32> ShufMask(NumMaskElts, -1);
18873   for (int i = 0; i != NumMaskElts; ++i) {
18874     if (i / ZextRatio == ZextElt) {
18875       // The low bits of the (potentially translated) extracted element map to
18876       // the source vector. The high bits map to zero. We will use a zero vector
18877       // as the 2nd source operand of the shuffle, so use the 1st element of
18878       // that vector (mask value is number-of-elements) for the high bits.
18879       if (i % ZextRatio == 0)
18880         ShufMask[i] = Extract.getConstantOperandVal(1);
18881       else
18882         ShufMask[i] = NumMaskElts;
18883     }
18884 
18885     // Undef elements of the build vector remain undef because we initialize
18886     // the shuffle mask with -1.
18887   }
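  // For example, given:
  //   t1: i16 = extract_vector_elt (v8i16 t0), Constant:i64<3>
  //   t2: i32 = zero_extend t1
  //   t3: v4i32 = BUILD_VECTOR t2, undef, undef, undef
  // ZextRatio is 2 and NumMaskElts is 8, so the mask built above is
  // <3,8,u,u,u,u,u,u>: lane 3 of t0 supplies the low half of the extended
  // element and lane 0 of the zero vector (mask value 8) the high half.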
18888 
18889   // buildvec undef, ..., (zext (extractelt V, IndexC)), undef... -->
18890   // bitcast (shuffle V, ZeroVec, VectorMask)
18891   SDLoc DL(BV);
18892   EVT VecVT = Extract.getOperand(0).getValueType();
18893   SDValue ZeroVec = DAG.getConstant(0, DL, VecVT);
18894   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
18895   SDValue Shuf = TLI.buildLegalVectorShuffle(VecVT, DL, Extract.getOperand(0),
18896                                              ZeroVec, ShufMask, DAG);
18897   if (!Shuf)
18898     return SDValue();
18899   return DAG.getBitcast(VT, Shuf);
18900 }
18901 
18902 // Check to see if this is a BUILD_VECTOR of a bunch of EXTRACT_VECTOR_ELT
18903 // operations. If the types of the vectors we're extracting from allow it,
18904 // turn this into a vector_shuffle node.
18905 SDValue DAGCombiner::reduceBuildVecToShuffle(SDNode *N) {
18906   SDLoc DL(N);
18907   EVT VT = N->getValueType(0);
18908 
18909   // Only type-legal BUILD_VECTOR nodes are converted to shuffle nodes.
18910   if (!isTypeLegal(VT))
18911     return SDValue();
18912 
18913   if (SDValue V = reduceBuildVecToShuffleWithZero(N, DAG))
18914     return V;
18915 
18916   // May only combine to shuffle after legalize if shuffle is legal.
18917   if (LegalOperations && !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, VT))
18918     return SDValue();
18919 
18920   bool UsesZeroVector = false;
18921   unsigned NumElems = N->getNumOperands();
18922 
18923   // Record, for each element of the newly built vector, which input vector
18924   // that element comes from. -1 stands for undef, 0 for the zero vector,
18925   // and positive values for the input vectors.
18926   // VectorMask maps each element to its vector number, and VecIn maps vector
18927   // numbers to their initial SDValues.
18928 
18929   SmallVector<int, 8> VectorMask(NumElems, -1);
18930   SmallVector<SDValue, 8> VecIn;
18931   VecIn.push_back(SDValue());
18932 
18933   for (unsigned i = 0; i != NumElems; ++i) {
18934     SDValue Op = N->getOperand(i);
18935 
18936     if (Op.isUndef())
18937       continue;
18938 
18939     // See if we can use a blend with a zero vector.
18940     // TODO: Should we generalize this to a blend with an arbitrary constant
18941     // vector?
18942     if (isNullConstant(Op) || isNullFPConstant(Op)) {
18943       UsesZeroVector = true;
18944       VectorMask[i] = 0;
18945       continue;
18946     }
18947 
18948     // Not an undef or zero. If the input is something other than an
18949     // EXTRACT_VECTOR_ELT with an in-range constant index, bail out.
18950     if (Op.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
18951         !isa<ConstantSDNode>(Op.getOperand(1)))
18952       return SDValue();
18953     SDValue ExtractedFromVec = Op.getOperand(0);
18954 
18955     if (ExtractedFromVec.getValueType().isScalableVector())
18956       return SDValue();
18957 
18958     const APInt &ExtractIdx = Op.getConstantOperandAPInt(1);
18959     if (ExtractIdx.uge(ExtractedFromVec.getValueType().getVectorNumElements()))
18960       return SDValue();
18961 
18962     // All inputs must have the same element type as the output.
18963     if (VT.getVectorElementType() !=
18964         ExtractedFromVec.getValueType().getVectorElementType())
18965       return SDValue();
18966 
18967     // Have we seen this input vector before?
18968     // The vectors are expected to be tiny (usually 1 or 2 elements), so using
18969     // a map back from SDValues to numbers isn't worth it.
18970     unsigned Idx = std::distance(VecIn.begin(), find(VecIn, ExtractedFromVec));
18971     if (Idx == VecIn.size())
18972       VecIn.push_back(ExtractedFromVec);
18973 
18974     VectorMask[i] = Idx;
18975   }
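  // For example, given:
  //   t14: v4i32 = BUILD_VECTOR (extract_vector_elt t1, 0), (i32 0),
  //                             (extract_vector_elt t2, 1), undef
  // we end up with VectorMask = <1,0,2,-1>, VecIn = {SDValue(), t1, t2},
  // and UsesZeroVector set.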
18976 
18977   // If we didn't find at least one input vector, bail out.
18978   if (VecIn.size() < 2)
18979     return SDValue();
18980 
18981   // If all the operands of the BUILD_VECTOR extract from the same
18982   // vector, then split that vector efficiently based on the maximum
18983   // vector access index, and adjust VectorMask and
18984   // VecIn accordingly.
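  // For example, if all operands extract from a v16i32 t1 at indices
  // {0, 1, 14, 15}, then MaxIndex is 15 and NearestPow2 is 16, so t1 is
  // split into two v8i32 halves and the VectorMask entries become 1 (low
  // half) or 2 (high half).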
18985   bool DidSplitVec = false;
18986   if (VecIn.size() == 2) {
18987     unsigned MaxIndex = 0;
18988     unsigned NearestPow2 = 0;
18989     SDValue Vec = VecIn.back();
18990     EVT InVT = Vec.getValueType();
18991     SmallVector<unsigned, 8> IndexVec(NumElems, 0);
18992 
18993     for (unsigned i = 0; i < NumElems; i++) {
18994       if (VectorMask[i] <= 0)
18995         continue;
18996       unsigned Index = N->getOperand(i).getConstantOperandVal(1);
18997       IndexVec[i] = Index;
18998       MaxIndex = std::max(MaxIndex, Index);
18999     }
19000 
19001     NearestPow2 = PowerOf2Ceil(MaxIndex);
19002     if (InVT.isSimple() && NearestPow2 > 2 && MaxIndex < NearestPow2 &&
19003         NumElems * 2 < NearestPow2) {
19004       unsigned SplitSize = NearestPow2 / 2;
19005       EVT SplitVT = EVT::getVectorVT(*DAG.getContext(),
19006                                      InVT.getVectorElementType(), SplitSize);
19007       if (TLI.isTypeLegal(SplitVT)) {
19008         SDValue VecIn2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SplitVT, Vec,
19009                                      DAG.getVectorIdxConstant(SplitSize, DL));
19010         SDValue VecIn1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SplitVT, Vec,
19011                                      DAG.getVectorIdxConstant(0, DL));
19012         VecIn.pop_back();
19013         VecIn.push_back(VecIn1);
19014         VecIn.push_back(VecIn2);
19015         DidSplitVec = true;
19016 
19017         for (unsigned i = 0; i < NumElems; i++) {
19018           if (VectorMask[i] <= 0)
19019             continue;
19020           VectorMask[i] = (IndexVec[i] < SplitSize) ? 1 : 2;
19021         }
19022       }
19023     }
19024   }
19025 
19026   // TODO: We want to sort the vectors by descending length, so that adjacent
19027   // pairs have similar length, and the longer vector is always first in the
19028   // pair.
19029 
19030   // TODO: Should this fire if some of the input vectors have illegal types
19031   // (like it does now), or should we let legalization run its course first?
19032 
19033   // Shuffle phase:
19034   // Take pairs of vectors, and shuffle them so that the result has elements
19035   // from these vectors in the correct places.
19036   // For example, given:
19037   // t10: i32 = extract_vector_elt t1, Constant:i64<0>
19038   // t11: i32 = extract_vector_elt t2, Constant:i64<0>
19039   // t12: i32 = extract_vector_elt t3, Constant:i64<0>
19040   // t13: i32 = extract_vector_elt t1, Constant:i64<1>
19041   // t14: v4i32 = BUILD_VECTOR t10, t11, t12, t13
19042   // We will generate:
19043   // t20: v4i32 = vector_shuffle<0,4,u,1> t1, t2
19044   // t21: v4i32 = vector_shuffle<u,u,0,u> t3, undef
19045   SmallVector<SDValue, 4> Shuffles;
19046   for (unsigned In = 0, Len = (VecIn.size() / 2); In < Len; ++In) {
19047     unsigned LeftIdx = 2 * In + 1;
19048     SDValue VecLeft = VecIn[LeftIdx];
19049     SDValue VecRight =
19050         (LeftIdx + 1) < VecIn.size() ? VecIn[LeftIdx + 1] : SDValue();
19051 
19052     if (SDValue Shuffle = createBuildVecShuffle(DL, N, VectorMask, VecLeft,
19053                                                 VecRight, LeftIdx, DidSplitVec))
19054       Shuffles.push_back(Shuffle);
19055     else
19056       return SDValue();
19057   }
19058 
19059   // If we need the zero vector as an "ingredient" in the blend tree, add it
19060   // to the list of shuffles.
19061   if (UsesZeroVector)
19062     Shuffles.push_back(VT.isInteger() ? DAG.getConstant(0, DL, VT)
19063                                       : DAG.getConstantFP(0.0, DL, VT));
19064 
19065   // If we only have one shuffle, we're done.
19066   if (Shuffles.size() == 1)
19067     return Shuffles[0];
19068 
19069   // Update the vector mask to point to the post-shuffle vectors.
19070   for (int &Vec : VectorMask)
19071     if (Vec == 0)
19072       Vec = Shuffles.size() - 1;
19073     else
19074       Vec = (Vec - 1) / 2;
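  // For example, with four input vectors plus a zero vector, Shuffles is
  // {s0, s1, zero}: old vector numbers 1 and 2 map to s0, 3 and 4 map to
  // s1, and 0 (the zero vector) maps to the last entry, index 2.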
19075 
19076   // More than one shuffle. Generate a binary tree of blends, e.g. if from
19077   // the previous step we got the set of shuffles t10, t11, t12, t13, we will
19078   // generate:
19079   // t10: v8i32 = vector_shuffle<0,8,u,u,u,u,u,u> t1, t2
19080   // t11: v8i32 = vector_shuffle<u,u,0,8,u,u,u,u> t3, t4
19081   // t12: v8i32 = vector_shuffle<u,u,u,u,0,8,u,u> t5, t6
19082   // t13: v8i32 = vector_shuffle<u,u,u,u,u,u,0,8> t7, t8
19083   // t20: v8i32 = vector_shuffle<0,1,10,11,u,u,u,u> t10, t11
19084   // t21: v8i32 = vector_shuffle<u,u,u,u,4,5,14,15> t12, t13
19085   // t30: v8i32 = vector_shuffle<0,1,2,3,12,13,14,15> t20, t21
19086 
19087   // Make sure the initial size of the shuffle list is even.
19088   if (Shuffles.size() % 2)
19089     Shuffles.push_back(DAG.getUNDEF(VT));
19090 
19091   for (unsigned CurSize = Shuffles.size(); CurSize > 1; CurSize /= 2) {
19092     if (CurSize % 2) {
19093       Shuffles[CurSize] = DAG.getUNDEF(VT);
19094       CurSize++;
19095     }
19096     for (unsigned In = 0, Len = CurSize / 2; In < Len; ++In) {
19097       int Left = 2 * In;
19098       int Right = 2 * In + 1;
19099       SmallVector<int, 8> Mask(NumElems, -1);
19100       for (unsigned i = 0; i != NumElems; ++i) {
19101         if (VectorMask[i] == Left) {
19102           Mask[i] = i;
19103           VectorMask[i] = In;
19104         } else if (VectorMask[i] == Right) {
19105           Mask[i] = i + NumElems;
19106           VectorMask[i] = In;
19107         }
19108       }
19109 
19110       Shuffles[In] =
19111           DAG.getVectorShuffle(VT, DL, Shuffles[Left], Shuffles[Right], Mask);
19112     }
19113   }
19114   return Shuffles[0];
19115 }
19116 
19117 // Try to turn a build vector of zero extends of extract vector elts into
19118 // a vector zero extend and possibly an extract subvector.
19119 // TODO: Support sign extend?
19120 // TODO: Allow undef elements?
19121 SDValue DAGCombiner::convertBuildVecZextToZext(SDNode *N) {
19122   if (LegalOperations)
19123     return SDValue();
19124 
19125   EVT VT = N->getValueType(0);
19126 
19127   bool FoundZeroExtend = false;
19128   SDValue Op0 = N->getOperand(0);
19129   auto checkElem = [&](SDValue Op) -> int64_t {
19130     unsigned Opc = Op.getOpcode();
19131     FoundZeroExtend |= (Opc == ISD::ZERO_EXTEND);
19132     if ((Opc == ISD::ZERO_EXTEND || Opc == ISD::ANY_EXTEND) &&
19133         Op.getOperand(0).getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
19134         Op0.getOperand(0).getOperand(0) == Op.getOperand(0).getOperand(0))
19135       if (auto *C = dyn_cast<ConstantSDNode>(Op.getOperand(0).getOperand(1)))
19136         return C->getZExtValue();
19137     return -1;
19138   };
19139 
19140   // Make sure the first element matches
19141   // (zext (extract_vector_elt X, C))
19142   int64_t Offset = checkElem(Op0);
19143   if (Offset < 0)
19144     return SDValue();
19145 
19146   unsigned NumElems = N->getNumOperands();
19147   SDValue In = Op0.getOperand(0).getOperand(0);
19148   EVT InSVT = In.getValueType().getScalarType();
19149   EVT InVT = EVT::getVectorVT(*DAG.getContext(), InSVT, NumElems);
19150 
19151   // Don't create an illegal input type after type legalization.
19152   if (LegalTypes && !TLI.isTypeLegal(InVT))
19153     return SDValue();
19154 
19155   // Ensure all the elements come from the same vector and are adjacent.
19156   for (unsigned i = 1; i != NumElems; ++i) {
19157     if ((Offset + i) != checkElem(N->getOperand(i)))
19158       return SDValue();
19159   }
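  // For example:
  //   (v8i16 BUILD_VECTOR (zext (extract_vector_elt (v16i8 X), 8)), ...,
  //                       (zext (extract_vector_elt (v16i8 X), 15)))
  // becomes (v8i16 zero_extend (v8i8 extract_subvector X, 8)).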
19160 
19161   SDLoc DL(N);
19162   In = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InVT, In,
19163                    Op0.getOperand(0).getOperand(1));
19164   return DAG.getNode(FoundZeroExtend ? ISD::ZERO_EXTEND : ISD::ANY_EXTEND, DL,
19165                      VT, In);
19166 }
19167 
19168 SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) {
19169   EVT VT = N->getValueType(0);
19170 
19171   // A vector built entirely of undefs is undef.
19172   if (ISD::allOperandsUndef(N))
19173     return DAG.getUNDEF(VT);
19174 
19175   // If this is a splat of a bitcast from another vector, change to a
19176   // concat_vector.
19177   // For example:
19178   //   (build_vector (i64 (bitcast (v2i32 X))), (i64 (bitcast (v2i32 X)))) ->
19179   //     (v2i64 (bitcast (concat_vectors (v2i32 X), (v2i32 X))))
19180   //
19181   // If X is a build_vector itself, the concat can become a larger build_vector.
19182   // TODO: Maybe this is useful for non-splat too?
19183   if (!LegalOperations) {
19184     if (SDValue Splat = cast<BuildVectorSDNode>(N)->getSplatValue()) {
19185       Splat = peekThroughBitcasts(Splat);
19186       EVT SrcVT = Splat.getValueType();
19187       if (SrcVT.isVector()) {
19188         unsigned NumElts = N->getNumOperands() * SrcVT.getVectorNumElements();
19189         EVT NewVT = EVT::getVectorVT(*DAG.getContext(),
19190                                      SrcVT.getVectorElementType(), NumElts);
19191         if (!LegalTypes || TLI.isTypeLegal(NewVT)) {
19192           SmallVector<SDValue, 8> Ops(N->getNumOperands(), Splat);
19193           SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N),
19194                                        NewVT, Ops);
19195           return DAG.getBitcast(VT, Concat);
19196         }
19197       }
19198     }
19199   }
19200 
19201   // A splat of a single element is a SPLAT_VECTOR if supported on the target.
19202   if (TLI.getOperationAction(ISD::SPLAT_VECTOR, VT) != TargetLowering::Expand)
19203     if (SDValue V = cast<BuildVectorSDNode>(N)->getSplatValue()) {
19204       assert(!V.isUndef() && "Splat of undef should have been handled earlier");
19205       return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), VT, V);
19206     }
19207 
19208   // Check if we can express BUILD_VECTOR via subvector extract.
19209   if (!LegalTypes && (N->getNumOperands() > 1)) {
19210     SDValue Op0 = N->getOperand(0);
19211     auto checkElem = [&](SDValue Op) -> uint64_t {
19212       if ((Op.getOpcode() == ISD::EXTRACT_VECTOR_ELT) &&
19213           (Op0.getOperand(0) == Op.getOperand(0)))
19214         if (auto CNode = dyn_cast<ConstantSDNode>(Op.getOperand(1)))
19215           return CNode->getZExtValue();
19216       return -1;
19217     };
19218 
19219     int Offset = checkElem(Op0);
19220     for (unsigned i = 0; i < N->getNumOperands(); ++i) {
19221       if (Offset + i != checkElem(N->getOperand(i))) {
19222         Offset = -1;
19223         break;
19224       }
19225     }
19226 
19227     if ((Offset == 0) &&
19228         (Op0.getOperand(0).getValueType() == N->getValueType(0)))
19229       return Op0.getOperand(0);
19230     if ((Offset != -1) &&
19231         ((Offset % N->getValueType(0).getVectorNumElements()) ==
19232          0)) // IDX must be a multiple of the output size.
19233       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), N->getValueType(0),
19234                          Op0.getOperand(0), Op0.getOperand(1));
19235   }
19236 
19237   if (SDValue V = convertBuildVecZextToZext(N))
19238     return V;
19239 
19240   if (SDValue V = reduceBuildVecExtToExtBuildVec(N))
19241     return V;
19242 
19243   if (SDValue V = reduceBuildVecTruncToBitCast(N))
19244     return V;
19245 
19246   if (SDValue V = reduceBuildVecToShuffle(N))
19247     return V;
19248 
19249   return SDValue();
19250 }
19251 
19252 static SDValue combineConcatVectorOfScalars(SDNode *N, SelectionDAG &DAG) {
19253   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
19254   EVT OpVT = N->getOperand(0).getValueType();
19255 
19256   // If the operands are legal vectors, leave them alone.
19257   if (TLI.isTypeLegal(OpVT))
19258     return SDValue();
19259 
19260   SDLoc DL(N);
19261   EVT VT = N->getValueType(0);
19262   SmallVector<SDValue, 8> Ops;
19263 
19264   EVT SVT = EVT::getIntegerVT(*DAG.getContext(), OpVT.getSizeInBits());
19265   SDValue ScalarUndef = DAG.getNode(ISD::UNDEF, DL, SVT);
19266 
19267   // Keep track of what we encounter.
19268   bool AnyInteger = false;
19269   bool AnyFP = false;
19270   for (const SDValue &Op : N->ops()) {
19271     if (ISD::BITCAST == Op.getOpcode() &&
19272         !Op.getOperand(0).getValueType().isVector())
19273       Ops.push_back(Op.getOperand(0));
19274     else if (ISD::UNDEF == Op.getOpcode())
19275       Ops.push_back(ScalarUndef);
19276     else
19277       return SDValue();
19278 
19279     // Note whether we encounter an integer or floating point scalar.
19280     // If it's neither, bail out; it could be something weird like x86mmx.
19281     EVT LastOpVT = Ops.back().getValueType();
19282     if (LastOpVT.isFloatingPoint())
19283       AnyFP = true;
19284     else if (LastOpVT.isInteger())
19285       AnyInteger = true;
19286     else
19287       return SDValue();
19288   }
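  // For example, for (v4i32 concat_vectors (v2i32 (bitcast (i64 a))), undef),
  // Ops now holds {a, (i64 undef)} with AnyInteger set, and the code below
  // produces (v4i32 (bitcast (v2i64 BUILD_VECTOR a, undef))).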
19289 
19290   // If any of the operands is a floating point scalar bitcast to a vector,
19291   // use floating point types throughout, and bitcast everything.
19292   // Replace UNDEFs by another scalar UNDEF node, of the final desired type.
19293   if (AnyFP) {
19294     SVT = EVT::getFloatingPointVT(OpVT.getSizeInBits());
19295     ScalarUndef = DAG.getNode(ISD::UNDEF, DL, SVT);
19296     if (AnyInteger) {
19297       for (SDValue &Op : Ops) {
19298         if (Op.getValueType() == SVT)
19299           continue;
19300         if (Op.isUndef())
19301           Op = ScalarUndef;
19302         else
19303           Op = DAG.getBitcast(SVT, Op);
19304       }
19305     }
19306   }
19307 
19308   EVT VecVT = EVT::getVectorVT(*DAG.getContext(), SVT,
19309                                VT.getSizeInBits() / SVT.getSizeInBits());
19310   return DAG.getBitcast(VT, DAG.getBuildVector(VecVT, DL, Ops));
19311 }
19312 
19313 // Check to see if this is a CONCAT_VECTORS of a bunch of EXTRACT_SUBVECTOR
19314 // operations. If so, and if the EXTRACT_SUBVECTOR vector inputs come from at
19315 // most two distinct vectors the same size as the result, attempt to turn this
19316 // into a legal shuffle.
19317 static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) {
19318   EVT VT = N->getValueType(0);
19319   EVT OpVT = N->getOperand(0).getValueType();
19320 
19321   // We currently can't generate an appropriate shuffle for a scalable vector.
19322   if (VT.isScalableVector())
19323     return SDValue();
19324 
19325   int NumElts = VT.getVectorNumElements();
19326   int NumOpElts = OpVT.getVectorNumElements();
19327 
19328   SDValue SV0 = DAG.getUNDEF(VT), SV1 = DAG.getUNDEF(VT);
19329   SmallVector<int, 8> Mask;
19330 
19331   for (SDValue Op : N->ops()) {
19332     Op = peekThroughBitcasts(Op);
19333 
19334     // UNDEF nodes convert to UNDEF shuffle mask values.
19335     if (Op.isUndef()) {
19336       Mask.append((unsigned)NumOpElts, -1);
19337       continue;
19338     }
19339 
19340     if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
19341       return SDValue();
19342 
19343     // What vector are we extracting the subvector from and at what index?
19344     SDValue ExtVec = Op.getOperand(0);
19345     int ExtIdx = Op.getConstantOperandVal(1);
19346 
19347     // We want the EVT of the original extraction to correctly scale the
19348     // extraction index.
19349     EVT ExtVT = ExtVec.getValueType();
19350     ExtVec = peekThroughBitcasts(ExtVec);
19351 
19352     // UNDEF nodes convert to UNDEF shuffle mask values.
19353     if (ExtVec.isUndef()) {
19354       Mask.append((unsigned)NumOpElts, -1);
19355       continue;
19356     }
19357 
19358     // Ensure that we are extracting a subvector from a vector the same
19359     // size as the result.
19360     if (ExtVT.getSizeInBits() != VT.getSizeInBits())
19361       return SDValue();
19362 
19363     // Scale the subvector index to account for any bitcast.
19364     int NumExtElts = ExtVT.getVectorNumElements();
19365     if (0 == (NumExtElts % NumElts))
19366       ExtIdx /= (NumExtElts / NumElts);
19367     else if (0 == (NumElts % NumExtElts))
19368       ExtIdx *= (NumElts / NumExtElts);
19369     else
19370       return SDValue();
19371 
19372     // At most we can reference 2 inputs in the final shuffle.
19373     if (SV0.isUndef() || SV0 == ExtVec) {
19374       SV0 = ExtVec;
19375       for (int i = 0; i != NumOpElts; ++i)
19376         Mask.push_back(i + ExtIdx);
19377     } else if (SV1.isUndef() || SV1 == ExtVec) {
19378       SV1 = ExtVec;
19379       for (int i = 0; i != NumOpElts; ++i)
19380         Mask.push_back(i + ExtIdx + NumElts);
19381     } else {
19382       return SDValue();
19383     }
19384   }
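  // For example:
  //   (v4i32 concat_vectors (v2i32 extract_subvector (v4i32 A), 2),
  //                         (v2i32 extract_subvector (v4i32 B), 0))
  // gives SV0 = A, SV1 = B and Mask = <2,3,4,5>.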
19385 
19386   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
19387   return TLI.buildLegalVectorShuffle(VT, SDLoc(N), DAG.getBitcast(VT, SV0),
19388                                      DAG.getBitcast(VT, SV1), Mask, DAG);
19389 }
19390 
19391 static SDValue combineConcatVectorOfCasts(SDNode *N, SelectionDAG &DAG) {
19392   unsigned CastOpcode = N->getOperand(0).getOpcode();
19393   switch (CastOpcode) {
19394   case ISD::SINT_TO_FP:
19395   case ISD::UINT_TO_FP:
19396   case ISD::FP_TO_SINT:
19397   case ISD::FP_TO_UINT:
19398     // TODO: Allow more opcodes?
19399     //  case ISD::BITCAST:
19400     //  case ISD::TRUNCATE:
19401     //  case ISD::ZERO_EXTEND:
19402     //  case ISD::SIGN_EXTEND:
19403     //  case ISD::FP_EXTEND:
19404     break;
19405   default:
19406     return SDValue();
19407   }
19408 
19409   EVT SrcVT = N->getOperand(0).getOperand(0).getValueType();
19410   if (!SrcVT.isVector())
19411     return SDValue();
19412 
19413   // All operands of the concat must be the same kind of cast from the same
19414   // source type.
19415   SmallVector<SDValue, 4> SrcOps;
19416   for (SDValue Op : N->ops()) {
19417     if (Op.getOpcode() != CastOpcode || !Op.hasOneUse() ||
19418         Op.getOperand(0).getValueType() != SrcVT)
19419       return SDValue();
19420     SrcOps.push_back(Op.getOperand(0));
19421   }
19422 
19423   // The wider cast must be supported by the target. This is unusual because
19424   // the operation support type parameter depends on the opcode. In addition,
19425   // check the other type in the cast to make sure this is really legal.
19426   EVT VT = N->getValueType(0);
19427   EVT SrcEltVT = SrcVT.getVectorElementType();
19428   ElementCount NumElts = SrcVT.getVectorElementCount() * N->getNumOperands();
19429   EVT ConcatSrcVT = EVT::getVectorVT(*DAG.getContext(), SrcEltVT, NumElts);
19430   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
19431   switch (CastOpcode) {
19432   case ISD::SINT_TO_FP:
19433   case ISD::UINT_TO_FP:
19434     if (!TLI.isOperationLegalOrCustom(CastOpcode, ConcatSrcVT) ||
19435         !TLI.isTypeLegal(VT))
19436       return SDValue();
19437     break;
19438   case ISD::FP_TO_SINT:
19439   case ISD::FP_TO_UINT:
19440     if (!TLI.isOperationLegalOrCustom(CastOpcode, VT) ||
19441         !TLI.isTypeLegal(ConcatSrcVT))
19442       return SDValue();
19443     break;
19444   default:
19445     llvm_unreachable("Unexpected cast opcode");
19446   }
19447 
19448   // concat (cast X), (cast Y)... -> cast (concat X, Y...)
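  // For example:
  //   concat (v4f32 (sint_to_fp (v4i32 a))), (v4f32 (sint_to_fp (v4i32 b)))
  //     --> (v8f32 (sint_to_fp (v8i32 concat a, b)))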
19449   SDLoc DL(N);
19450   SDValue NewConcat = DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatSrcVT, SrcOps);
19451   return DAG.getNode(CastOpcode, DL, VT, NewConcat);
19452 }
19453 
19454 SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
19455   // If we only have one input vector, we don't need to do any concatenation.
19456   if (N->getNumOperands() == 1)
19457     return N->getOperand(0);
19458 
19459   // Check if all of the operands are undefs.
19460   EVT VT = N->getValueType(0);
19461   if (ISD::allOperandsUndef(N))
19462     return DAG.getUNDEF(VT);
19463 
19464   // Optimize concat_vectors where all but the first of the vectors are undef.
19465   if (all_of(drop_begin(N->ops()),
19466              [](const SDValue &Op) { return Op.isUndef(); })) {
19467     SDValue In = N->getOperand(0);
19468     assert(In.getValueType().isVector() && "Must concat vectors");
19469 
19470     // If the input is a concat_vectors, just make a larger concat by padding
19471     // with smaller undefs.
19472     if (In.getOpcode() == ISD::CONCAT_VECTORS && In.hasOneUse()) {
19473       unsigned NumOps = N->getNumOperands() * In.getNumOperands();
19474       SmallVector<SDValue, 4> Ops(In->op_begin(), In->op_end());
19475       Ops.resize(NumOps, DAG.getUNDEF(Ops[0].getValueType()));
19476       return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
19477     }
19478 
19479     SDValue Scalar = peekThroughOneUseBitcasts(In);
19480 
19481     // concat_vectors(scalar_to_vector(scalar), undef) ->
19482     //     scalar_to_vector(scalar)
19483     if (!LegalOperations && Scalar.getOpcode() == ISD::SCALAR_TO_VECTOR &&
19484          Scalar.hasOneUse()) {
19485       EVT SVT = Scalar.getValueType().getVectorElementType();
19486       if (SVT == Scalar.getOperand(0).getValueType())
19487         Scalar = Scalar.getOperand(0);
19488     }
19489 
19490     // concat_vectors(scalar, undef) -> scalar_to_vector(scalar)
19491     if (!Scalar.getValueType().isVector()) {
19492       // If the bitcast type isn't legal, it might be a trunc of a legal type;
19493       // look through the trunc so we can still do the transform:
19494       //   concat_vectors(trunc(scalar), undef) -> scalar_to_vector(scalar)
19495       if (Scalar->getOpcode() == ISD::TRUNCATE &&
19496           !TLI.isTypeLegal(Scalar.getValueType()) &&
19497           TLI.isTypeLegal(Scalar->getOperand(0).getValueType()))
19498         Scalar = Scalar->getOperand(0);
19499 
19500       EVT SclTy = Scalar.getValueType();
19501 
19502       if (!SclTy.isFloatingPoint() && !SclTy.isInteger())
19503         return SDValue();
19504 
19505       // Bail out if the vector size is not a multiple of the scalar size.
19506       if (VT.getSizeInBits() % SclTy.getSizeInBits())
19507         return SDValue();
19508 
19509       unsigned VNTNumElms = VT.getSizeInBits() / SclTy.getSizeInBits();
19510       if (VNTNumElms < 2)
19511         return SDValue();
19512 
19513       EVT NVT = EVT::getVectorVT(*DAG.getContext(), SclTy, VNTNumElms);
19514       if (!TLI.isTypeLegal(NVT) || !TLI.isTypeLegal(Scalar.getValueType()))
19515         return SDValue();
19516 
19517       SDValue Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), NVT, Scalar);
19518       return DAG.getBitcast(VT, Res);
19519     }
19520   }
19521 
19522   // Fold any combination of BUILD_VECTOR or UNDEF nodes into one BUILD_VECTOR.
19523   // We have already tested above for an UNDEF only concatenation.
19524   // fold (concat_vectors (BUILD_VECTOR A, B, ...), (BUILD_VECTOR C, D, ...))
19525   // -> (BUILD_VECTOR A, B, ..., C, D, ...)
19526   auto IsBuildVectorOrUndef = [](const SDValue &Op) {
19527     return ISD::UNDEF == Op.getOpcode() || ISD::BUILD_VECTOR == Op.getOpcode();
19528   };
19529   if (llvm::all_of(N->ops(), IsBuildVectorOrUndef)) {
19530     SmallVector<SDValue, 8> Opnds;
19531     EVT SVT = VT.getScalarType();
19532 
19533     EVT MinVT = SVT;
19534     if (!SVT.isFloatingPoint()) {
19535       // If the BUILD_VECTOR operands are integers, they may have different
19536       // types. Get the smallest type and truncate all operands to it.
19537       bool FoundMinVT = false;
19538       for (const SDValue &Op : N->ops())
19539         if (ISD::BUILD_VECTOR == Op.getOpcode()) {
19540           EVT OpSVT = Op.getOperand(0).getValueType();
19541           MinVT = (!FoundMinVT || OpSVT.bitsLE(MinVT)) ? OpSVT : MinVT;
19542           FoundMinVT = true;
19543         }
19544       assert(FoundMinVT && "Concat vector type mismatch");
19545     }
19546 
19547     for (const SDValue &Op : N->ops()) {
19548       EVT OpVT = Op.getValueType();
19549       unsigned NumElts = OpVT.getVectorNumElements();
19550 
19551       if (ISD::UNDEF == Op.getOpcode())
19552         Opnds.append(NumElts, DAG.getUNDEF(MinVT));
19553 
19554       if (ISD::BUILD_VECTOR == Op.getOpcode()) {
19555         if (SVT.isFloatingPoint()) {
19556           assert(SVT == OpVT.getScalarType() && "Concat vector type mismatch");
19557           Opnds.append(Op->op_begin(), Op->op_begin() + NumElts);
19558         } else {
19559           for (unsigned i = 0; i != NumElts; ++i)
19560             Opnds.push_back(
19561                 DAG.getNode(ISD::TRUNCATE, SDLoc(N), MinVT, Op.getOperand(i)));
19562         }
19563       }
19564     }
19565 
19566     assert(VT.getVectorNumElements() == Opnds.size() &&
19567            "Concat vector type mismatch");
19568     return DAG.getBuildVector(VT, SDLoc(N), Opnds);
19569   }
19570 
19571   // Fold CONCAT_VECTORS of only bitcast scalars (or undef) to BUILD_VECTOR.
19572   if (SDValue V = combineConcatVectorOfScalars(N, DAG))
19573     return V;
19574 
19575   // Fold CONCAT_VECTORS of EXTRACT_SUBVECTOR (or undef) to VECTOR_SHUFFLE.
19576   if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT))
19577     if (SDValue V = combineConcatVectorOfExtracts(N, DAG))
19578       return V;
19579 
19580   if (SDValue V = combineConcatVectorOfCasts(N, DAG))
19581     return V;
19582 
19583   // Type legalization of vectors and DAG canonicalization of SHUFFLE_VECTOR
19584   // nodes often generate nop CONCAT_VECTOR nodes. Scan the CONCAT_VECTOR
19585   // operands and look for CONCAT operations that place the incoming vectors
19586   // at the exact same location.
19587   //
19588   // For scalable vectors, EXTRACT_SUBVECTOR indexes are implicitly scaled.
19589   SDValue SingleSource = SDValue();
19590   unsigned PartNumElem =
19591       N->getOperand(0).getValueType().getVectorMinNumElements();
19592 
19593   for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
19594     SDValue Op = N->getOperand(i);
19595 
19596     if (Op.isUndef())
19597       continue;
19598 
19599     // Check if this is the identity extract:
19600     if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
19601       return SDValue();
19602 
19603     // Find the single incoming vector for the extract_subvector.
19604     if (SingleSource.getNode()) {
19605       if (Op.getOperand(0) != SingleSource)
19606         return SDValue();
19607     } else {
19608       SingleSource = Op.getOperand(0);
19609 
19610       // Check the source type is the same as the type of the result.
19611       // If not, this concat may extend the vector, so we cannot
19612       // optimize it away.
19613       if (SingleSource.getValueType() != N->getValueType(0))
19614         return SDValue();
19615     }
19616 
19617     // Check that we are reading from the identity index.
19618     unsigned IdentityIndex = i * PartNumElem;
19619     if (Op.getConstantOperandAPInt(1) != IdentityIndex)
19620       return SDValue();
19621   }
19622 
19623   if (SingleSource.getNode())
19624     return SingleSource;
19625 
19626   return SDValue();
19627 }
19628 
19629 // Helper that peeks through INSERT_SUBVECTOR/CONCAT_VECTORS to find
19630 // if the subvector can be sourced for free.
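// For example, with SubVT = v4i32, (v16i32 concat_vectors A, B, C, D) at
// Index 8 yields operand C (SubIdx = 8 / 4 = 2) without creating new nodes.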
19631 static SDValue getSubVectorSrc(SDValue V, SDValue Index, EVT SubVT) {
19632   if (V.getOpcode() == ISD::INSERT_SUBVECTOR &&
19633       V.getOperand(1).getValueType() == SubVT && V.getOperand(2) == Index) {
19634     return V.getOperand(1);
19635   }
19636   auto *IndexC = dyn_cast<ConstantSDNode>(Index);
19637   if (IndexC && V.getOpcode() == ISD::CONCAT_VECTORS &&
19638       V.getOperand(0).getValueType() == SubVT &&
19639       (IndexC->getZExtValue() % SubVT.getVectorMinNumElements()) == 0) {
19640     uint64_t SubIdx = IndexC->getZExtValue() / SubVT.getVectorMinNumElements();
19641     return V.getOperand(SubIdx);
19642   }
19643   return SDValue();
19644 }
19645 
19646 static SDValue narrowInsertExtractVectorBinOp(SDNode *Extract,
19647                                               SelectionDAG &DAG,
19648                                               bool LegalOperations) {
19649   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
19650   SDValue BinOp = Extract->getOperand(0);
19651   unsigned BinOpcode = BinOp.getOpcode();
19652   if (!TLI.isBinOp(BinOpcode) || BinOp.getNode()->getNumValues() != 1)
19653     return SDValue();
19654 
19655   EVT VecVT = BinOp.getValueType();
19656   SDValue Bop0 = BinOp.getOperand(0), Bop1 = BinOp.getOperand(1);
19657   if (VecVT != Bop0.getValueType() || VecVT != Bop1.getValueType())
19658     return SDValue();
19659 
19660   SDValue Index = Extract->getOperand(1);
19661   EVT SubVT = Extract->getValueType(0);
19662   if (!TLI.isOperationLegalOrCustom(BinOpcode, SubVT, LegalOperations))
19663     return SDValue();
19664 
19665   SDValue Sub0 = getSubVectorSrc(Bop0, Index, SubVT);
19666   SDValue Sub1 = getSubVectorSrc(Bop1, Index, SubVT);
19667 
19668   // TODO: We could handle the case where only 1 operand is being inserted by
19669   //       creating an extract of the other operand, but that requires checking
19670   //       number of uses and/or costs.
19671   if (!Sub0 || !Sub1)
19672     return SDValue();
19673 
19674   // We are inserting both operands of the wide binop only to extract back
19675   // to the narrow vector size. Eliminate all of the insert/extract:
19676   // ext (binop (ins ?, X, Index), (ins ?, Y, Index)), Index --> binop X, Y
19677   return DAG.getNode(BinOpcode, SDLoc(Extract), SubVT, Sub0, Sub1,
19678                      BinOp->getFlags());
19679 }
19680 
19681 /// If we are extracting a subvector produced by a wide binary operator try
19682 /// to use a narrow binary operator and/or avoid concatenation and extraction.
19683 static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG,
19684                                           bool LegalOperations) {
19685   // TODO: Refactor with the caller (visitEXTRACT_SUBVECTOR), so we can share
19686   // some of these bailouts with other transforms.
19687 
19688   if (SDValue V = narrowInsertExtractVectorBinOp(Extract, DAG, LegalOperations))
19689     return V;
19690 
19691   // The extract index must be a constant, so we can map it to a concat operand.
19692   auto *ExtractIndexC = dyn_cast<ConstantSDNode>(Extract->getOperand(1));
19693   if (!ExtractIndexC)
19694     return SDValue();
19695 
19696   // We are looking for an optionally bitcasted wide vector binary operator
19697   // feeding an extract subvector.
19698   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
19699   SDValue BinOp = peekThroughBitcasts(Extract->getOperand(0));
19700   unsigned BOpcode = BinOp.getOpcode();
19701   if (!TLI.isBinOp(BOpcode) || BinOp.getNode()->getNumValues() != 1)
19702     return SDValue();
19703 
19704   // Exclude the fake form of fneg (fsub -0.0, x) because that is likely to be
19705   // reduced to the unary fneg when it is visited, and we probably want to deal
19706   // with fneg in a target-specific way.
19707   if (BOpcode == ISD::FSUB) {
19708     auto *C = isConstOrConstSplatFP(BinOp.getOperand(0), /*AllowUndefs*/ true);
19709     if (C && C->getValueAPF().isNegZero())
19710       return SDValue();
19711   }
19712 
19713   // The binop must be a vector type, so we can extract some fraction of it.
19714   EVT WideBVT = BinOp.getValueType();
19715   // The optimisations below currently assume we are dealing with fixed length
19716   // vectors. It is possible to add support for scalable vectors, but at the
19717   // moment we've done no analysis to prove whether they are profitable or not.
19718   if (!WideBVT.isFixedLengthVector())
19719     return SDValue();
19720 
19721   EVT VT = Extract->getValueType(0);
19722   unsigned ExtractIndex = ExtractIndexC->getZExtValue();
19723   assert(ExtractIndex % VT.getVectorNumElements() == 0 &&
19724          "Extract index is not a multiple of the vector length.");
19725 
19726   // Bail out if this is not a proper multiple width extraction.
19727   unsigned WideWidth = WideBVT.getSizeInBits();
19728   unsigned NarrowWidth = VT.getSizeInBits();
19729   if (WideWidth % NarrowWidth != 0)
19730     return SDValue();
19731 
19732   // Bail out if we are extracting a fraction of a single operation. This can
19733   // occur because we potentially looked through a bitcast of the binop.
19734   unsigned NarrowingRatio = WideWidth / NarrowWidth;
19735   unsigned WideNumElts = WideBVT.getVectorNumElements();
19736   if (WideNumElts % NarrowingRatio != 0)
19737     return SDValue();
19738 
19739   // Bail out if the target does not support a narrower version of the binop.
19740   EVT NarrowBVT = EVT::getVectorVT(*DAG.getContext(), WideBVT.getScalarType(),
19741                                    WideNumElts / NarrowingRatio);
19742   if (!TLI.isOperationLegalOrCustomOrPromote(BOpcode, NarrowBVT))
19743     return SDValue();
19744 
19745   // If extraction is cheap, we don't need to look at the binop operands
19746   // for concat ops. The narrow binop alone makes this transform profitable.
19747   // We can't just reuse the original extract index operand because we may have
19748   // bitcasted.
19749   unsigned ConcatOpNum = ExtractIndex / VT.getVectorNumElements();
19750   unsigned ExtBOIdx = ConcatOpNum * NarrowBVT.getVectorNumElements();
19751   if (TLI.isExtractSubvectorCheap(NarrowBVT, WideBVT, ExtBOIdx) &&
19752       BinOp.hasOneUse() && Extract->getOperand(0)->hasOneUse()) {
19753     // extract (binop B0, B1), N --> binop (extract B0, N), (extract B1, N)
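    // For example:
    //   (v4i32 extract_subvector (v8i32 add A, B), 4)
    //     --> (v4i32 add (extract_subvector A, 4), (extract_subvector B, 4))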
19754     SDLoc DL(Extract);
19755     SDValue NewExtIndex = DAG.getVectorIdxConstant(ExtBOIdx, DL);
19756     SDValue X = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
19757                             BinOp.getOperand(0), NewExtIndex);
19758     SDValue Y = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
19759                             BinOp.getOperand(1), NewExtIndex);
19760     SDValue NarrowBinOp = DAG.getNode(BOpcode, DL, NarrowBVT, X, Y,
19761                                       BinOp.getNode()->getFlags());
19762     return DAG.getBitcast(VT, NarrowBinOp);
19763   }
19764 
19765   // Only handle the case where we are doubling and then halving. A larger ratio
19766   // may require more than two narrow binops to replace the wide binop.
19767   if (NarrowingRatio != 2)
19768     return SDValue();
19769 
19770   // TODO: The motivating case for this transform is an x86 AVX1 target. That
19771   // target has temptingly almost legal versions of bitwise logic ops in 256-bit
19772   // flavors, but no other 256-bit integer support. This could be extended to
19773   // handle any binop, but that may require fixing/adding other folds to avoid
19774   // codegen regressions.
19775   if (BOpcode != ISD::AND && BOpcode != ISD::OR && BOpcode != ISD::XOR)
19776     return SDValue();
19777 
19778   // We need at least one concatenation operation of a binop operand to make
19779   // this transform worthwhile. The concat must double the input vector sizes.
19780   auto GetSubVector = [ConcatOpNum](SDValue V) -> SDValue {
19781     if (V.getOpcode() == ISD::CONCAT_VECTORS && V.getNumOperands() == 2)
19782       return V.getOperand(ConcatOpNum);
19783     return SDValue();
19784   };
19785   SDValue SubVecL = GetSubVector(peekThroughBitcasts(BinOp.getOperand(0)));
19786   SDValue SubVecR = GetSubVector(peekThroughBitcasts(BinOp.getOperand(1)));
19787 
19788   if (SubVecL || SubVecR) {
19789     // If a binop operand was not the result of a concat, we must extract a
19790     // half-sized operand for our new narrow binop:
19791     // extract (binop (concat X1, X2), (concat Y1, Y2)), N --> binop XN, YN
19792     // extract (binop (concat X1, X2), Y), N --> binop XN, (extract Y, IndexC)
19793     // extract (binop X, (concat Y1, Y2)), N --> binop (extract X, IndexC), YN
19794     SDLoc DL(Extract);
19795     SDValue IndexC = DAG.getVectorIdxConstant(ExtBOIdx, DL);
19796     SDValue X = SubVecL ? DAG.getBitcast(NarrowBVT, SubVecL)
19797                         : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
19798                                       BinOp.getOperand(0), IndexC);
19799 
19800     SDValue Y = SubVecR ? DAG.getBitcast(NarrowBVT, SubVecR)
19801                         : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
19802                                       BinOp.getOperand(1), IndexC);
19803 
19804     SDValue NarrowBinOp = DAG.getNode(BOpcode, DL, NarrowBVT, X, Y);
19805     return DAG.getBitcast(VT, NarrowBinOp);
19806   }
19807 
19808   return SDValue();
19809 }
19810 
19811 /// If we are extracting a subvector from a wide vector load, convert to a
19812 /// narrow load to eliminate the extraction:
19813 /// (extract_subvector (load wide vector)) --> (load narrow vector)
19814 static SDValue narrowExtractedVectorLoad(SDNode *Extract, SelectionDAG &DAG) {
19815   // TODO: Add support for big-endian. The offset calculation must be adjusted.
19816   if (DAG.getDataLayout().isBigEndian())
19817     return SDValue();
19818 
19819   auto *Ld = dyn_cast<LoadSDNode>(Extract->getOperand(0));
19820   auto *ExtIdx = dyn_cast<ConstantSDNode>(Extract->getOperand(1));
19821   if (!Ld || Ld->getExtensionType() || !Ld->isSimple() ||
19822       !ExtIdx)
19823     return SDValue();
19824 
19825   // Allow targets to opt-out.
19826   EVT VT = Extract->getValueType(0);
19827 
19828   // We can only create byte sized loads.
19829   if (!VT.isByteSized())
19830     return SDValue();
19831 
19832   unsigned Index = ExtIdx->getZExtValue();
19833   unsigned NumElts = VT.getVectorMinNumElements();
19834 
19835   // The definition of EXTRACT_SUBVECTOR states that the index must be a
19836   // multiple of the minimum number of elements in the result type.
19837   assert(Index % NumElts == 0 && "The extract subvector index is not a "
19838                                  "multiple of the result's element count");
19839 
19840   // It's fine to use TypeSize here as we know the offset will not be negative.
19841   TypeSize Offset = VT.getStoreSize() * (Index / NumElts);
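  // For example, extracting the v2f64 at index 2 from a v8f64 load gives
  // Offset = 16 bytes * (2 / 2) = 16 bytes past the original base address.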
19842 
19843   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
19844   if (!TLI.shouldReduceLoadWidth(Ld, Ld->getExtensionType(), VT))
19845     return SDValue();
19846 
19847   // The narrow load will be offset from the base address of the old load if
19848   // we are extracting from something besides index 0 (little-endian).
19849   SDLoc DL(Extract);
19850 
19851   // TODO: Use "BaseIndexOffset" to make this more effective.
19852   SDValue NewAddr = DAG.getMemBasePlusOffset(Ld->getBasePtr(), Offset, DL);
19853 
19854   uint64_t StoreSize = MemoryLocation::getSizeOrUnknown(VT.getStoreSize());
19855   MachineFunction &MF = DAG.getMachineFunction();
19856   MachineMemOperand *MMO;
19857   if (Offset.isScalable()) {
19858     MachinePointerInfo MPI =
19859         MachinePointerInfo(Ld->getPointerInfo().getAddrSpace());
19860     MMO = MF.getMachineMemOperand(Ld->getMemOperand(), MPI, StoreSize);
19861   } else
19862     MMO = MF.getMachineMemOperand(Ld->getMemOperand(), Offset.getFixedSize(),
19863                                   StoreSize);
19864 
19865   SDValue NewLd = DAG.getLoad(VT, DL, Ld->getChain(), NewAddr, MMO);
19866   DAG.makeEquivalentMemoryOrdering(Ld, NewLd);
19867   return NewLd;
19868 }
19869 
19870 SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) {
19871   EVT NVT = N->getValueType(0);
19872   SDValue V = N->getOperand(0);
19873   uint64_t ExtIdx = N->getConstantOperandVal(1);
19874 
19875   // Extract from UNDEF is UNDEF.
19876   if (V.isUndef())
19877     return DAG.getUNDEF(NVT);
19878 
19879   if (TLI.isOperationLegalOrCustomOrPromote(ISD::LOAD, NVT))
19880     if (SDValue NarrowLoad = narrowExtractedVectorLoad(N, DAG))
19881       return NarrowLoad;
19882 
19883   // Combine an extract of an extract into a single extract_subvector.
19884   // ext (ext X, C), 0 --> ext X, C
19885   if (ExtIdx == 0 && V.getOpcode() == ISD::EXTRACT_SUBVECTOR && V.hasOneUse()) {
19886     if (TLI.isExtractSubvectorCheap(NVT, V.getOperand(0).getValueType(),
19887                                     V.getConstantOperandVal(1)) &&
19888         TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NVT)) {
19889       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), NVT, V.getOperand(0),
19890                          V.getOperand(1));
19891     }
19892   }
19893 
19894   // Try to move vector bitcast after extract_subv by scaling extraction index:
19895   // extract_subv (bitcast X), Index --> bitcast (extract_subv X, Index')
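  // For example:
  //   (v2i64 extract_subvector (v4i64 bitcast (v8i32 X)), 2)
  //     --> (v2i64 bitcast (v4i32 extract_subvector X, 4))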
19896   if (V.getOpcode() == ISD::BITCAST &&
19897       V.getOperand(0).getValueType().isVector()) {
19898     SDValue SrcOp = V.getOperand(0);
19899     EVT SrcVT = SrcOp.getValueType();
19900     unsigned SrcNumElts = SrcVT.getVectorMinNumElements();
19901     unsigned DestNumElts = V.getValueType().getVectorMinNumElements();
19902     if ((SrcNumElts % DestNumElts) == 0) {
19903       unsigned SrcDestRatio = SrcNumElts / DestNumElts;
19904       ElementCount NewExtEC = NVT.getVectorElementCount() * SrcDestRatio;
19905       EVT NewExtVT = EVT::getVectorVT(*DAG.getContext(), SrcVT.getScalarType(),
19906                                       NewExtEC);
19907       if (TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NewExtVT)) {
19908         SDLoc DL(N);
19909         SDValue NewIndex = DAG.getVectorIdxConstant(ExtIdx * SrcDestRatio, DL);
19910         SDValue NewExtract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewExtVT,
19911                                          V.getOperand(0), NewIndex);
19912         return DAG.getBitcast(NVT, NewExtract);
19913       }
19914     }
19915     if ((DestNumElts % SrcNumElts) == 0) {
19916       unsigned DestSrcRatio = DestNumElts / SrcNumElts;
19917       if (NVT.getVectorElementCount().isKnownMultipleOf(DestSrcRatio)) {
19918         ElementCount NewExtEC =
19919             NVT.getVectorElementCount().divideCoefficientBy(DestSrcRatio);
19920         EVT ScalarVT = SrcVT.getScalarType();
19921         if ((ExtIdx % DestSrcRatio) == 0) {
19922           SDLoc DL(N);
19923           unsigned IndexValScaled = ExtIdx / DestSrcRatio;
19924           EVT NewExtVT =
19925               EVT::getVectorVT(*DAG.getContext(), ScalarVT, NewExtEC);
19926           if (TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NewExtVT)) {
19927             SDValue NewIndex = DAG.getVectorIdxConstant(IndexValScaled, DL);
19928             SDValue NewExtract =
19929                 DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewExtVT,
19930                             V.getOperand(0), NewIndex);
19931             return DAG.getBitcast(NVT, NewExtract);
19932           }
19933           if (NewExtEC.isScalar() &&
19934               TLI.isOperationLegalOrCustom(ISD::EXTRACT_VECTOR_ELT, ScalarVT)) {
19935             SDValue NewIndex = DAG.getVectorIdxConstant(IndexValScaled, DL);
19936             SDValue NewExtract =
19937                 DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT,
19938                             V.getOperand(0), NewIndex);
19939             return DAG.getBitcast(NVT, NewExtract);
19940           }
19941         }
19942       }
19943     }
19944   }
19945 
19946   if (V.getOpcode() == ISD::CONCAT_VECTORS) {
19947     unsigned ExtNumElts = NVT.getVectorMinNumElements();
19948     EVT ConcatSrcVT = V.getOperand(0).getValueType();
19949     assert(ConcatSrcVT.getVectorElementType() == NVT.getVectorElementType() &&
19950            "Concat and extract subvector do not change element type");
19951     assert((ExtIdx % ExtNumElts) == 0 &&
19952            "Extract index is not a multiple of the input vector length.");
19953 
19954     unsigned ConcatSrcNumElts = ConcatSrcVT.getVectorMinNumElements();
19955     unsigned ConcatOpIdx = ExtIdx / ConcatSrcNumElts;
19956 
19957     // If the concatenated source types match this extract, it's a direct
19958     // simplification:
19959     // extract_subvec (concat V1, V2, ...), i --> Vi
19960     if (ConcatSrcNumElts == ExtNumElts)
19961       return V.getOperand(ConcatOpIdx);
19962 
19963     // If the concatenated source vectors are a multiple length of this extract,
19964     // then extract a fraction of one of those source vectors directly from a
19965     // concat operand. Example:
19966     //   v2i8 extract_subvec (v16i8 concat (v8i8 X), (v8i8 Y)), 14 -->
19967     //   v2i8 extract_subvec v8i8 Y, 6
19968     if (NVT.isFixedLengthVector() && ConcatSrcNumElts % ExtNumElts == 0) {
19969       SDLoc DL(N);
19970       unsigned NewExtIdx = ExtIdx - ConcatOpIdx * ConcatSrcNumElts;
19971       assert(NewExtIdx + ExtNumElts <= ConcatSrcNumElts &&
19972              "Trying to extract from >1 concat operand?");
19973       assert(NewExtIdx % ExtNumElts == 0 &&
19974              "Extract index is not a multiple of the input vector length.");
19975       SDValue NewIndexC = DAG.getVectorIdxConstant(NewExtIdx, DL);
19976       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NVT,
19977                          V.getOperand(ConcatOpIdx), NewIndexC);
19978     }
19979   }
19980 
19981   V = peekThroughBitcasts(V);
19982 
19983   // If the input is a build vector, try to make a smaller build vector.
19984   if (V.getOpcode() == ISD::BUILD_VECTOR) {
19985     EVT InVT = V.getValueType();
19986     unsigned ExtractSize = NVT.getSizeInBits();
19987     unsigned EltSize = InVT.getScalarSizeInBits();
19988     // Only do this if we won't split any elements.
19989     if (ExtractSize % EltSize == 0) {
19990       unsigned NumElems = ExtractSize / EltSize;
19991       EVT EltVT = InVT.getVectorElementType();
19992       EVT ExtractVT =
19993           NumElems == 1 ? EltVT
19994                         : EVT::getVectorVT(*DAG.getContext(), EltVT, NumElems);
19995       if ((Level < AfterLegalizeDAG ||
19996            (NumElems == 1 ||
19997             TLI.isOperationLegal(ISD::BUILD_VECTOR, ExtractVT))) &&
19998           (!LegalTypes || TLI.isTypeLegal(ExtractVT))) {
19999         unsigned IdxVal = (ExtIdx * NVT.getScalarSizeInBits()) / EltSize;
20000 
20001         if (NumElems == 1) {
20002           SDValue Src = V->getOperand(IdxVal);
20003           if (EltVT != Src.getValueType())
20004             Src = DAG.getNode(ISD::TRUNCATE, SDLoc(N), EltVT, Src);
20005           return DAG.getBitcast(NVT, Src);
20006         }
20007 
20008         // Extract the pieces from the original build_vector.
20009         SDValue BuildVec = DAG.getBuildVector(ExtractVT, SDLoc(N),
20010                                               V->ops().slice(IdxVal, NumElems));
20011         return DAG.getBitcast(NVT, BuildVec);
20012       }
20013     }
20014   }
20015 
20016   if (V.getOpcode() == ISD::INSERT_SUBVECTOR) {
20017     // Handle only the simple case where the vector being inserted and the
20018     // vector being extracted are the same size.
20019     EVT SmallVT = V.getOperand(1).getValueType();
20020     if (!NVT.bitsEq(SmallVT))
20021       return SDValue();
20022 
20023     // Combine:
20024     //    (extract_subvec (insert_subvec V1, V2, InsIdx), ExtIdx)
20025     // Into:
20026     //    indices are equal or bit offsets are equal => V1
20027     //    otherwise => (extract_subvec V1, ExtIdx)
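    // For example, (v2i64 extract_subvec (v4i64 insert_subvec V1, V2, 2), 2)
    // returns V2, since the bit offsets match (2 * 64 == 2 * 64).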
20028     uint64_t InsIdx = V.getConstantOperandVal(2);
20029     if (InsIdx * SmallVT.getScalarSizeInBits() ==
20030         ExtIdx * NVT.getScalarSizeInBits())
20031       return DAG.getBitcast(NVT, V.getOperand(1));
20032     return DAG.getNode(
20033         ISD::EXTRACT_SUBVECTOR, SDLoc(N), NVT,
20034         DAG.getBitcast(N->getOperand(0).getValueType(), V.getOperand(0)),
20035         N->getOperand(1));
20036   }
20037 
20038   if (SDValue NarrowBOp = narrowExtractedVectorBinOp(N, DAG, LegalOperations))
20039     return NarrowBOp;
20040 
20041   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
20042     return SDValue(N, 0);
20043 
20044   return SDValue();
20045 }
20046 
20047 /// Try to convert a wide shuffle of concatenated vectors into 2 narrow shuffles
20048 /// followed by concatenation. Narrow vector ops may have better performance
20049 /// than wide ops, and this can unlock further narrowing of other vector ops.
20050 /// Targets can invert this transform later if it is not profitable.
20051 static SDValue foldShuffleOfConcatUndefs(ShuffleVectorSDNode *Shuf,
20052                                          SelectionDAG &DAG) {
20053   SDValue N0 = Shuf->getOperand(0), N1 = Shuf->getOperand(1);
20054   if (N0.getOpcode() != ISD::CONCAT_VECTORS || N0.getNumOperands() != 2 ||
20055       N1.getOpcode() != ISD::CONCAT_VECTORS || N1.getNumOperands() != 2 ||
20056       !N0.getOperand(1).isUndef() || !N1.getOperand(1).isUndef())
20057     return SDValue();
20058 
20059   // Split the wide shuffle mask into halves. Any mask element that is accessing
20060   // operand 1 is offset down to account for narrowing of the vectors.
20061   ArrayRef<int> Mask = Shuf->getMask();
20062   EVT VT = Shuf->getValueType(0);
20063   unsigned NumElts = VT.getVectorNumElements();
20064   unsigned HalfNumElts = NumElts / 2;
20065   SmallVector<int, 16> Mask0(HalfNumElts, -1);
20066   SmallVector<int, 16> Mask1(HalfNumElts, -1);
20067   for (unsigned i = 0; i != NumElts; ++i) {
20068     if (Mask[i] == -1)
20069       continue;
20070     int M = Mask[i] < (int)NumElts ? Mask[i] : Mask[i] - (int)HalfNumElts;
20071     if (i < HalfNumElts)
20072       Mask0[i] = M;
20073     else
20074       Mask1[i - HalfNumElts] = M;
20075   }
20076 
20077   // Ask the target if this is a valid transform.
20078   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
20079   EVT HalfVT = EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(),
20080                                 HalfNumElts);
20081   if (!TLI.isShuffleMaskLegal(Mask0, HalfVT) ||
20082       !TLI.isShuffleMaskLegal(Mask1, HalfVT))
20083     return SDValue();
20084 
20085   // shuffle (concat X, undef), (concat Y, undef), Mask -->
20086   // concat (shuffle X, Y, Mask0), (shuffle X, Y, Mask1)
20087   SDValue X = N0.getOperand(0), Y = N1.getOperand(0);
20088   SDLoc DL(Shuf);
20089   SDValue Shuf0 = DAG.getVectorShuffle(HalfVT, DL, X, Y, Mask0);
20090   SDValue Shuf1 = DAG.getVectorShuffle(HalfVT, DL, X, Y, Mask1);
20091   return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Shuf0, Shuf1);
20092 }
20093 
20094 // Tries to turn a shuffle of two CONCAT_VECTORS into a single concat,
20095 // or turn a shuffle of a single concat into a simpler shuffle followed by a concat.
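// For instance (hypothetical types, as a sketch): with a v4i32 result,
// N0 = concat(A,B) and N1 = concat(C,D) of v2i32 halves, the mask <2,3,6,7>
// copies whole subvectors, so the shuffle becomes concat(B,D).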
20096 static SDValue partitionShuffleOfConcats(SDNode *N, SelectionDAG &DAG) {
20097   EVT VT = N->getValueType(0);
20098   unsigned NumElts = VT.getVectorNumElements();
20099 
20100   SDValue N0 = N->getOperand(0);
20101   SDValue N1 = N->getOperand(1);
20102   ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N);
20103   ArrayRef<int> Mask = SVN->getMask();
20104 
20105   SmallVector<SDValue, 4> Ops;
20106   EVT ConcatVT = N0.getOperand(0).getValueType();
20107   unsigned NumElemsPerConcat = ConcatVT.getVectorNumElements();
20108   unsigned NumConcats = NumElts / NumElemsPerConcat;
20109 
20110   auto IsUndefMaskElt = [](int i) { return i == -1; };
20111 
20112   // Special case: shuffle(concat(A,B)) can be more efficiently represented
20113   // as concat(shuffle(A,B),UNDEF) if the shuffle doesn't set any of the high
20114   // half vector elements.
20115   if (NumElemsPerConcat * 2 == NumElts && N1.isUndef() &&
20116       llvm::all_of(Mask.slice(NumElemsPerConcat, NumElemsPerConcat),
20117                    IsUndefMaskElt)) {
20118     N0 = DAG.getVectorShuffle(ConcatVT, SDLoc(N), N0.getOperand(0),
20119                               N0.getOperand(1),
20120                               Mask.slice(0, NumElemsPerConcat));
20121     N1 = DAG.getUNDEF(ConcatVT);
20122     return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, N0, N1);
20123   }
20124 
20125   // Look at every vector that's inserted. We're looking for exact
20126   // subvector-sized copies from a concatenated vector.
20127   for (unsigned I = 0; I != NumConcats; ++I) {
20128     unsigned Begin = I * NumElemsPerConcat;
20129     ArrayRef<int> SubMask = Mask.slice(Begin, NumElemsPerConcat);
20130 
20131     // Make sure we're dealing with a copy.
20132     if (llvm::all_of(SubMask, IsUndefMaskElt)) {
20133       Ops.push_back(DAG.getUNDEF(ConcatVT));
20134       continue;
20135     }
20136 
20137     int OpIdx = -1;
20138     for (int i = 0; i != (int)NumElemsPerConcat; ++i) {
20139       if (IsUndefMaskElt(SubMask[i]))
20140         continue;
20141       if ((SubMask[i] % (int)NumElemsPerConcat) != i)
20142         return SDValue();
20143       int EltOpIdx = SubMask[i] / NumElemsPerConcat;
20144       if (0 <= OpIdx && EltOpIdx != OpIdx)
20145         return SDValue();
20146       OpIdx = EltOpIdx;
20147     }
20148     assert(0 <= OpIdx && "Unknown concat_vectors op");
20149 
20150     if (OpIdx < (int)N0.getNumOperands())
20151       Ops.push_back(N0.getOperand(OpIdx));
20152     else
20153       Ops.push_back(N1.getOperand(OpIdx - N0.getNumOperands()));
20154   }
20155 
20156   return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
20157 }
20158 
20159 // Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
20160 // BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
20161 //
20162 // SHUFFLE(BUILD_VECTOR(), BUILD_VECTOR()) -> BUILD_VECTOR() is always
20163 // a simplification in some sense, but it isn't appropriate in general: some
20164 // BUILD_VECTORs are substantially cheaper than others. The general case
20165 // of a BUILD_VECTOR requires inserting each element individually (or
20166 // performing the equivalent in a temporary stack variable). A BUILD_VECTOR of
20167 // all constants is a single constant pool load.  A BUILD_VECTOR where each
20168 // element is identical is a splat.  A BUILD_VECTOR where most of the operands
20169 // are undef lowers to a small number of element insertions.
20170 //
20171 // To deal with this, we currently use a bunch of mostly arbitrary heuristics.
20172 // We don't fold shuffles where one side is a non-zero constant, and we don't
20173 // fold shuffles if the resulting (non-splat) BUILD_VECTOR would have duplicate
20174 // non-constant operands. This seems to work out reasonably well in practice.
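// For example (an illustrative sketch, assuming the one-use and
// constant/duplicate heuristics described above pass):
//   shuffle (build_vector A,B,C,D), (build_vector E,F,G,H), <0,4,1,5>
//   --> build_vector A,E,B,F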
20175 static SDValue combineShuffleOfScalars(ShuffleVectorSDNode *SVN,
20176                                        SelectionDAG &DAG,
20177                                        const TargetLowering &TLI) {
20178   EVT VT = SVN->getValueType(0);
20179   unsigned NumElts = VT.getVectorNumElements();
20180   SDValue N0 = SVN->getOperand(0);
20181   SDValue N1 = SVN->getOperand(1);
20182 
20183   if (!N0->hasOneUse())
20184     return SDValue();
20185 
20186   // If only one of N0,N1 is constant, bail out if it is not ALL_ZEROS as
20187   // discussed above.
20188   if (!N1.isUndef()) {
20189     if (!N1->hasOneUse())
20190       return SDValue();
20191 
20192     bool N0AnyConst = isAnyConstantBuildVector(N0);
20193     bool N1AnyConst = isAnyConstantBuildVector(N1);
20194     if (N0AnyConst && !N1AnyConst && !ISD::isBuildVectorAllZeros(N0.getNode()))
20195       return SDValue();
20196     if (!N0AnyConst && N1AnyConst && !ISD::isBuildVectorAllZeros(N1.getNode()))
20197       return SDValue();
20198   }
20199 
20200   // If both inputs are splats of the same value then we can safely merge this
20201   // to a single BUILD_VECTOR with undef elements based on the shuffle mask.
20202   bool IsSplat = false;
20203   auto *BV0 = dyn_cast<BuildVectorSDNode>(N0);
20204   auto *BV1 = dyn_cast<BuildVectorSDNode>(N1);
20205   if (BV0 && BV1)
20206     if (SDValue Splat0 = BV0->getSplatValue())
20207       IsSplat = (Splat0 == BV1->getSplatValue());
20208 
20209   SmallVector<SDValue, 8> Ops;
20210   SmallSet<SDValue, 16> DuplicateOps;
20211   for (int M : SVN->getMask()) {
20212     SDValue Op = DAG.getUNDEF(VT.getScalarType());
20213     if (M >= 0) {
20214       int Idx = M < (int)NumElts ? M : M - NumElts;
20215       SDValue &S = (M < (int)NumElts ? N0 : N1);
20216       if (S.getOpcode() == ISD::BUILD_VECTOR) {
20217         Op = S.getOperand(Idx);
20218       } else if (S.getOpcode() == ISD::SCALAR_TO_VECTOR) {
20219         SDValue Op0 = S.getOperand(0);
20220         Op = Idx == 0 ? Op0 : DAG.getUNDEF(Op0.getValueType());
20221       } else {
20222         // Operand can't be combined - bail out.
20223         return SDValue();
20224       }
20225     }
20226 
20227     // Don't duplicate a non-constant BUILD_VECTOR operand unless we're
20228     // generating a splat; semantically, this is fine, but it's likely to
20229     // generate low-quality code if the target can't reconstruct an appropriate
20230     // shuffle.
20231     if (!Op.isUndef() && !isa<ConstantSDNode>(Op) && !isa<ConstantFPSDNode>(Op))
20232       if (!IsSplat && !DuplicateOps.insert(Op).second)
20233         return SDValue();
20234 
20235     Ops.push_back(Op);
20236   }
20237 
20238   // BUILD_VECTOR requires all inputs to be of the same type; find the
20239   // maximum type and extend them all.
20240   EVT SVT = VT.getScalarType();
20241   if (SVT.isInteger())
20242     for (SDValue &Op : Ops)
20243       SVT = (SVT.bitsLT(Op.getValueType()) ? Op.getValueType() : SVT);
20244   if (SVT != VT.getScalarType())
20245     for (SDValue &Op : Ops)
20246       Op = TLI.isZExtFree(Op.getValueType(), SVT)
20247                ? DAG.getZExtOrTrunc(Op, SDLoc(SVN), SVT)
20248                : DAG.getSExtOrTrunc(Op, SDLoc(SVN), SVT);
20249   return DAG.getBuildVector(VT, SDLoc(SVN), Ops);
20250 }
20251 
20252 // Match shuffles that can be converted to any_vector_extend_in_reg.
20253 // This is often generated during legalization.
20254 // e.g. v4i32 <0,u,1,u> -> (v2i64 any_vector_extend_in_reg(v4i32 src))
20255 // TODO Add support for ZERO_EXTEND_VECTOR_INREG when we have a test case.
20256 static SDValue combineShuffleToVectorExtend(ShuffleVectorSDNode *SVN,
20257                                             SelectionDAG &DAG,
20258                                             const TargetLowering &TLI,
20259                                             bool LegalOperations) {
20260   EVT VT = SVN->getValueType(0);
20261   bool IsBigEndian = DAG.getDataLayout().isBigEndian();
20262 
20263   // TODO Add support for big-endian when we have a test case.
20264   if (!VT.isInteger() || IsBigEndian)
20265     return SDValue();
20266 
20267   unsigned NumElts = VT.getVectorNumElements();
20268   unsigned EltSizeInBits = VT.getScalarSizeInBits();
20269   ArrayRef<int> Mask = SVN->getMask();
20270   SDValue N0 = SVN->getOperand(0);
20271 
20272   // shuffle<0,-1,1,-1> == (v2i64 anyextend_vector_inreg(v4i32))
20273   auto isAnyExtend = [&Mask, &NumElts](unsigned Scale) {
20274     for (unsigned i = 0; i != NumElts; ++i) {
20275       if (Mask[i] < 0)
20276         continue;
20277       if ((i % Scale) == 0 && Mask[i] == (int)(i / Scale))
20278         continue;
20279       return false;
20280     }
20281     return true;
20282   };
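  // For example (a sketch, not from a test): for v8i16 with Scale = 2, the
  // mask <0,u,1,u,2,u,3,u> passes this check: every even lane i selects
  // element i/2 and every odd lane is undef, which matches
  // (v4i32 any_extend_vector_inreg (v8i16 src)) viewed as v8i16.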
20283 
  20284   // Attempt to match a '*_extend_vector_inreg' shuffle; we just search for
20285   // power-of-2 extensions as they are the most likely.
20286   for (unsigned Scale = 2; Scale < NumElts; Scale *= 2) {
20287     // Check for non power of 2 vector sizes
20288     if (NumElts % Scale != 0)
20289       continue;
20290     if (!isAnyExtend(Scale))
20291       continue;
20292 
20293     EVT OutSVT = EVT::getIntegerVT(*DAG.getContext(), EltSizeInBits * Scale);
20294     EVT OutVT = EVT::getVectorVT(*DAG.getContext(), OutSVT, NumElts / Scale);
20295     // Never create an illegal type. Only create unsupported operations if we
20296     // are pre-legalization.
20297     if (TLI.isTypeLegal(OutVT))
20298       if (!LegalOperations ||
20299           TLI.isOperationLegalOrCustom(ISD::ANY_EXTEND_VECTOR_INREG, OutVT))
20300         return DAG.getBitcast(VT,
20301                               DAG.getNode(ISD::ANY_EXTEND_VECTOR_INREG,
20302                                           SDLoc(SVN), OutVT, N0));
20303   }
20304 
20305   return SDValue();
20306 }
20307 
20308 // Detect 'truncate_vector_inreg' style shuffles that pack the lower parts of
20309 // each source element of a large type into the lowest elements of a smaller
20310 // destination type. This is often generated during legalization.
20311 // If the source node itself was a '*_extend_vector_inreg' node then we should
20312 // then be able to remove it.
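// As a sketch (hypothetical types): if N0 is
// (v2i64 zero_extend_vector_inreg (v4i32 X)) viewed through a bitcast as
// v4i32, the mask <0,2,-1,-1> truncates at the same scale (ExtScale = 2),
// so the whole pattern folds to a bitcast of X.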
20313 static SDValue combineTruncationShuffle(ShuffleVectorSDNode *SVN,
20314                                         SelectionDAG &DAG) {
20315   EVT VT = SVN->getValueType(0);
20316   bool IsBigEndian = DAG.getDataLayout().isBigEndian();
20317 
20318   // TODO Add support for big-endian when we have a test case.
20319   if (!VT.isInteger() || IsBigEndian)
20320     return SDValue();
20321 
20322   SDValue N0 = peekThroughBitcasts(SVN->getOperand(0));
20323 
20324   unsigned Opcode = N0.getOpcode();
20325   if (Opcode != ISD::ANY_EXTEND_VECTOR_INREG &&
20326       Opcode != ISD::SIGN_EXTEND_VECTOR_INREG &&
20327       Opcode != ISD::ZERO_EXTEND_VECTOR_INREG)
20328     return SDValue();
20329 
20330   SDValue N00 = N0.getOperand(0);
20331   ArrayRef<int> Mask = SVN->getMask();
20332   unsigned NumElts = VT.getVectorNumElements();
20333   unsigned EltSizeInBits = VT.getScalarSizeInBits();
20334   unsigned ExtSrcSizeInBits = N00.getScalarValueSizeInBits();
20335   unsigned ExtDstSizeInBits = N0.getScalarValueSizeInBits();
20336 
20337   if (ExtDstSizeInBits % ExtSrcSizeInBits != 0)
20338     return SDValue();
20339   unsigned ExtScale = ExtDstSizeInBits / ExtSrcSizeInBits;
20340 
20341   // (v4i32 truncate_vector_inreg(v2i64)) == shuffle<0,2,-1,-1>
20342   // (v8i16 truncate_vector_inreg(v4i32)) == shuffle<0,2,4,6,-1,-1,-1,-1>
20343   // (v8i16 truncate_vector_inreg(v2i64)) == shuffle<0,4,-1,-1,-1,-1,-1,-1>
20344   auto isTruncate = [&Mask, &NumElts](unsigned Scale) {
20345     for (unsigned i = 0; i != NumElts; ++i) {
20346       if (Mask[i] < 0)
20347         continue;
20348       if ((i * Scale) < NumElts && Mask[i] == (int)(i * Scale))
20349         continue;
20350       return false;
20351     }
20352     return true;
20353   };
20354 
20355   // At the moment we just handle the case where we've truncated back to the
20356   // same size as before the extension.
20357   // TODO: handle more extension/truncation cases as cases arise.
20358   if (EltSizeInBits != ExtSrcSizeInBits)
20359     return SDValue();
20360 
20361   // We can remove *extend_vector_inreg only if the truncation happens at
20362   // the same scale as the extension.
20363   if (isTruncate(ExtScale))
20364     return DAG.getBitcast(VT, N00);
20365 
20366   return SDValue();
20367 }
20368 
20369 // Combine shuffles of splat-shuffles of the form:
20370 // shuffle (shuffle V, undef, splat-mask), undef, M
20371 // If splat-mask contains undef elements, we need to be careful about
20372 // introducing undefs in the folded mask that are not the result of composing
20373 // the masks of the shuffles.
20374 static SDValue combineShuffleOfSplatVal(ShuffleVectorSDNode *Shuf,
20375                                         SelectionDAG &DAG) {
20376   if (!Shuf->getOperand(1).isUndef())
20377     return SDValue();
20378   auto *Splat = dyn_cast<ShuffleVectorSDNode>(Shuf->getOperand(0));
20379   if (!Splat || !Splat->isSplat())
20380     return SDValue();
20381 
20382   ArrayRef<int> ShufMask = Shuf->getMask();
20383   ArrayRef<int> SplatMask = Splat->getMask();
20384   assert(ShufMask.size() == SplatMask.size() && "Mask length mismatch");
20385 
20386   // Prefer simplifying to the splat-shuffle, if possible. This is legal if
20387   // every undef mask element in the splat-shuffle has a corresponding undef
20388   // element in the user-shuffle's mask or if the composition of mask elements
20389   // would result in undef.
20390   // Examples for (shuffle (shuffle v, undef, SplatMask), undef, UserMask):
20391   // * UserMask=[0,2,u,u], SplatMask=[2,u,2,u] -> [2,2,u,u]
20392   //   In this case it is not legal to simplify to the splat-shuffle because we
20393 //   may be exposing to the users of the shuffle an undef element at index 1
20394   //   which was not there before the combine.
20395   // * UserMask=[0,u,2,u], SplatMask=[2,u,2,u] -> [2,u,2,u]
20396   //   In this case the composition of masks yields SplatMask, so it's ok to
20397   //   simplify to the splat-shuffle.
20398   // * UserMask=[3,u,2,u], SplatMask=[2,u,2,u] -> [u,u,2,u]
20399   //   In this case the composed mask includes all undef elements of SplatMask
20400   //   and in addition sets element zero to undef. It is safe to simplify to
20401   //   the splat-shuffle.
20402   auto CanSimplifyToExistingSplat = [](ArrayRef<int> UserMask,
20403                                        ArrayRef<int> SplatMask) {
20404     for (unsigned i = 0, e = UserMask.size(); i != e; ++i)
20405       if (UserMask[i] != -1 && SplatMask[i] == -1 &&
20406           SplatMask[UserMask[i]] != -1)
20407         return false;
20408     return true;
20409   };
20410   if (CanSimplifyToExistingSplat(ShufMask, SplatMask))
20411     return Shuf->getOperand(0);
20412 
20413   // Create a new shuffle with a mask that is composed of the two shuffles'
20414   // masks.
20415   SmallVector<int, 32> NewMask;
20416   for (int Idx : ShufMask)
20417     NewMask.push_back(Idx == -1 ? -1 : SplatMask[Idx]);
20418 
20419   return DAG.getVectorShuffle(Splat->getValueType(0), SDLoc(Splat),
20420                               Splat->getOperand(0), Splat->getOperand(1),
20421                               NewMask);
20422 }
20423 
20424 /// Combine shuffle of shuffle of the form:
20425 /// shuf (shuf X, undef, InnerMask), undef, OuterMask --> splat X
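/// For example (an illustrative sketch): with InnerMask = <2,2,u,u> and
/// OuterMask = <0,u,1,u>, every defined lane traces back to element 2 of X,
/// so the result is the splat shuffle X, undef, <2,u,2,u>.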
20426 static SDValue formSplatFromShuffles(ShuffleVectorSDNode *OuterShuf,
20427                                      SelectionDAG &DAG) {
20428   if (!OuterShuf->getOperand(1).isUndef())
20429     return SDValue();
20430   auto *InnerShuf = dyn_cast<ShuffleVectorSDNode>(OuterShuf->getOperand(0));
20431   if (!InnerShuf || !InnerShuf->getOperand(1).isUndef())
20432     return SDValue();
20433 
20434   ArrayRef<int> OuterMask = OuterShuf->getMask();
20435   ArrayRef<int> InnerMask = InnerShuf->getMask();
20436   unsigned NumElts = OuterMask.size();
20437   assert(NumElts == InnerMask.size() && "Mask length mismatch");
20438   SmallVector<int, 32> CombinedMask(NumElts, -1);
20439   int SplatIndex = -1;
20440   for (unsigned i = 0; i != NumElts; ++i) {
20441     // Undef lanes remain undef.
20442     int OuterMaskElt = OuterMask[i];
20443     if (OuterMaskElt == -1)
20444       continue;
20445 
20446     // Peek through the shuffle masks to get the underlying source element.
20447     int InnerMaskElt = InnerMask[OuterMaskElt];
20448     if (InnerMaskElt == -1)
20449       continue;
20450 
20451     // Initialize the splatted element.
20452     if (SplatIndex == -1)
20453       SplatIndex = InnerMaskElt;
20454 
20455     // Non-matching index - this is not a splat.
20456     if (SplatIndex != InnerMaskElt)
20457       return SDValue();
20458 
20459     CombinedMask[i] = InnerMaskElt;
20460   }
20461   assert((all_of(CombinedMask, [](int M) { return M == -1; }) ||
20462           getSplatIndex(CombinedMask) != -1) &&
20463          "Expected a splat mask");
20464 
20465   // TODO: The transform may be a win even if the mask is not legal.
20466   EVT VT = OuterShuf->getValueType(0);
20467   assert(VT == InnerShuf->getValueType(0) && "Expected matching shuffle types");
20468   if (!DAG.getTargetLoweringInfo().isShuffleMaskLegal(CombinedMask, VT))
20469     return SDValue();
20470 
20471   return DAG.getVectorShuffle(VT, SDLoc(OuterShuf), InnerShuf->getOperand(0),
20472                               InnerShuf->getOperand(1), CombinedMask);
20473 }
20474 
20475 /// If the shuffle mask is taking exactly one element from the first vector
20476 /// operand and passing through all other elements from the second vector
20477 /// operand, return the index of the mask element that is choosing an element
20478 /// from the first operand. Otherwise, return -1.
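/// For example (a sketch): with MaskSize = 4, Mask = <4,1,6,7> takes only
/// element 1 from operand 0 and passes lanes 0, 2 and 3 through from
/// operand 1 unchanged, so this returns 1; Mask = <4,1,6,2> returns -1.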
20479 static int getShuffleMaskIndexOfOneElementFromOp0IntoOp1(ArrayRef<int> Mask) {
20480   int MaskSize = Mask.size();
20481   int EltFromOp0 = -1;
20482   // TODO: This does not match if there are undef elements in the shuffle mask.
20483   // Should we ignore undefs in the shuffle mask instead? The trade-off is
20484   // removing an instruction (a shuffle), but losing the knowledge that some
20485   // vector lanes are not needed.
20486   for (int i = 0; i != MaskSize; ++i) {
20487     if (Mask[i] >= 0 && Mask[i] < MaskSize) {
20488       // We're looking for a shuffle of exactly one element from operand 0.
20489       if (EltFromOp0 != -1)
20490         return -1;
20491       EltFromOp0 = i;
20492     } else if (Mask[i] != i + MaskSize) {
20493       // Nothing from operand 1 can change lanes.
20494       return -1;
20495     }
20496   }
20497   return EltFromOp0;
20498 }
20499 
20500 /// If a shuffle inserts exactly one element from a source vector operand into
20501 /// another vector operand and we can access the specified element as a scalar,
20502 /// then we can eliminate the shuffle.
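/// For example (an illustrative sketch, assuming the insert index below is
/// found to match the shuffle mask):
///   shuffle (insert_vector_elt V1, X, 1), V2, <4,1,6,7>
///   --> insert_vector_elt V2, X, 1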
20503 static SDValue replaceShuffleOfInsert(ShuffleVectorSDNode *Shuf,
20504                                       SelectionDAG &DAG) {
20505   // First, check if we are taking one element of a vector and shuffling that
20506   // element into another vector.
20507   ArrayRef<int> Mask = Shuf->getMask();
20508   SmallVector<int, 16> CommutedMask(Mask.begin(), Mask.end());
20509   SDValue Op0 = Shuf->getOperand(0);
20510   SDValue Op1 = Shuf->getOperand(1);
20511   int ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(Mask);
20512   if (ShufOp0Index == -1) {
20513     // Commute mask and check again.
20514     ShuffleVectorSDNode::commuteMask(CommutedMask);
20515     ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(CommutedMask);
20516     if (ShufOp0Index == -1)
20517       return SDValue();
20518     // Commute operands to match the commuted shuffle mask.
20519     std::swap(Op0, Op1);
20520     Mask = CommutedMask;
20521   }
20522 
20523   // The shuffle inserts exactly one element from operand 0 into operand 1.
20524   // Now see if we can access that element as a scalar via a real insert element
20525   // instruction.
20526   // TODO: We can try harder to locate the element as a scalar. Examples: it
20527   // could be an operand of SCALAR_TO_VECTOR, BUILD_VECTOR, or a constant.
20528   assert(Mask[ShufOp0Index] >= 0 && Mask[ShufOp0Index] < (int)Mask.size() &&
20529          "Shuffle mask value must be from operand 0");
20530   if (Op0.getOpcode() != ISD::INSERT_VECTOR_ELT)
20531     return SDValue();
20532 
20533   auto *InsIndexC = dyn_cast<ConstantSDNode>(Op0.getOperand(2));
20534   if (!InsIndexC || InsIndexC->getSExtValue() != Mask[ShufOp0Index])
20535     return SDValue();
20536 
20537   // There's an existing insertelement with constant insertion index, so we
20538   // don't need to check the legality/profitability of a replacement operation
20539   // that differs at most in the constant value. The target should be able to
20540   // lower any of those in a similar way. If not, legalization will expand this
20541   // to a scalar-to-vector plus shuffle.
20542   //
20543   // Note that the shuffle may move the scalar from the position that the insert
20544   // element used. Therefore, our new insert element occurs at the shuffle's
20545   // mask index value, not the insert's index value.
20546   // shuffle (insertelt v1, x, C), v2, mask --> insertelt v2, x, C'
20547   SDValue NewInsIndex = DAG.getVectorIdxConstant(ShufOp0Index, SDLoc(Shuf));
20548   return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(Shuf), Op0.getValueType(),
20549                      Op1, Op0.getOperand(1), NewInsIndex);
20550 }
20551 
20552 /// If we have a unary shuffle of a shuffle, see if it can be folded away
20553 /// completely. This has the potential to lose undef knowledge because the first
20554 /// shuffle may not have an undef mask element where the second one does. So
20555 /// only call this after doing simplifications based on demanded elements.
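/// For example (a sketch): shuf (shuf0 X, Y, <0,0,2,2>), undef, <1,0,3,2>
/// folds to the inner shuffle, since Mask0[Mask[i]] == Mask0[i] holds for
/// every lane: both masks choose the same underlying element per position.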
20556 static SDValue simplifyShuffleOfShuffle(ShuffleVectorSDNode *Shuf) {
20557   // shuf (shuf0 X, Y, Mask0), undef, Mask
20558   auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(Shuf->getOperand(0));
20559   if (!Shuf0 || !Shuf->getOperand(1).isUndef())
20560     return SDValue();
20561 
20562   ArrayRef<int> Mask = Shuf->getMask();
20563   ArrayRef<int> Mask0 = Shuf0->getMask();
20564   for (int i = 0, e = (int)Mask.size(); i != e; ++i) {
20565     // Ignore undef elements.
20566     if (Mask[i] == -1)
20567       continue;
20568     assert(Mask[i] >= 0 && Mask[i] < e && "Unexpected shuffle mask value");
20569 
20570     // Is the element of the shuffle operand chosen by this shuffle the same as
20571     // the element chosen by the shuffle operand itself?
20572     if (Mask0[Mask[i]] != Mask0[i])
20573       return SDValue();
20574   }
20575   // Every element of this shuffle is identical to the result of the previous
20576   // shuffle, so we can replace this value.
20577   return Shuf->getOperand(0);
20578 }
20579 
20580 SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
20581   EVT VT = N->getValueType(0);
20582   unsigned NumElts = VT.getVectorNumElements();
20583 
20584   SDValue N0 = N->getOperand(0);
20585   SDValue N1 = N->getOperand(1);
20586 
20587   assert(N0.getValueType() == VT && "Vector shuffle must be normalized in DAG");
20588 
20589   // Canonicalize shuffle undef, undef -> undef
20590   if (N0.isUndef() && N1.isUndef())
20591     return DAG.getUNDEF(VT);
20592 
20593   ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N);
20594 
20595   // Canonicalize shuffle v, v -> v, undef
20596   if (N0 == N1) {
20597     SmallVector<int, 8> NewMask;
20598     for (unsigned i = 0; i != NumElts; ++i) {
20599       int Idx = SVN->getMaskElt(i);
20600       if (Idx >= (int)NumElts) Idx -= NumElts;
20601       NewMask.push_back(Idx);
20602     }
20603     return DAG.getVectorShuffle(VT, SDLoc(N), N0, DAG.getUNDEF(VT), NewMask);
20604   }
20605 
20606   // Canonicalize shuffle undef, v -> v, undef.  Commute the shuffle mask.
20607   if (N0.isUndef())
20608     return DAG.getCommutedVectorShuffle(*SVN);
20609 
20610   // Remove references to rhs if it is undef
20611   if (N1.isUndef()) {
20612     bool Changed = false;
20613     SmallVector<int, 8> NewMask;
20614     for (unsigned i = 0; i != NumElts; ++i) {
20615       int Idx = SVN->getMaskElt(i);
20616       if (Idx >= (int)NumElts) {
20617         Idx = -1;
20618         Changed = true;
20619       }
20620       NewMask.push_back(Idx);
20621     }
20622     if (Changed)
20623       return DAG.getVectorShuffle(VT, SDLoc(N), N0, N1, NewMask);
20624   }
20625 
20626   if (SDValue InsElt = replaceShuffleOfInsert(SVN, DAG))
20627     return InsElt;
20628 
20629   // A shuffle of a single vector that is a splatted value can always be folded.
20630   if (SDValue V = combineShuffleOfSplatVal(SVN, DAG))
20631     return V;
20632 
20633   if (SDValue V = formSplatFromShuffles(SVN, DAG))
20634     return V;
20635 
20636   // If it is a splat, check if the argument vector is another splat or a
20637   // build_vector.
20638   if (SVN->isSplat() && SVN->getSplatIndex() < (int)NumElts) {
20639     int SplatIndex = SVN->getSplatIndex();
20640     if (N0.hasOneUse() && TLI.isExtractVecEltCheap(VT, SplatIndex) &&
20641         TLI.isBinOp(N0.getOpcode()) && N0.getNode()->getNumValues() == 1) {
20642       // splat (vector_bo L, R), Index -->
20643       // splat (scalar_bo (extelt L, Index), (extelt R, Index))
20644       SDValue L = N0.getOperand(0), R = N0.getOperand(1);
20645       SDLoc DL(N);
20646       EVT EltVT = VT.getScalarType();
20647       SDValue Index = DAG.getVectorIdxConstant(SplatIndex, DL);
20648       SDValue ExtL = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, L, Index);
20649       SDValue ExtR = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, R, Index);
20650       SDValue NewBO = DAG.getNode(N0.getOpcode(), DL, EltVT, ExtL, ExtR,
20651                                   N0.getNode()->getFlags());
20652       SDValue Insert = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT, NewBO);
20653       SmallVector<int, 16> ZeroMask(VT.getVectorNumElements(), 0);
20654       return DAG.getVectorShuffle(VT, DL, Insert, DAG.getUNDEF(VT), ZeroMask);
20655     }
20656 
20657     // If this is a bit convert that changes the element type of the vector but
20658     // not the number of vector elements, look through it.  Be careful not to
20659     // look through conversions that change things like v4f32 to v2f64.
20660     SDNode *V = N0.getNode();
20661     if (V->getOpcode() == ISD::BITCAST) {
20662       SDValue ConvInput = V->getOperand(0);
20663       if (ConvInput.getValueType().isVector() &&
20664           ConvInput.getValueType().getVectorNumElements() == NumElts)
20665         V = ConvInput.getNode();
20666     }
20667 
20668     if (V->getOpcode() == ISD::BUILD_VECTOR) {
20669       assert(V->getNumOperands() == NumElts &&
20670              "BUILD_VECTOR has wrong number of operands");
20671       SDValue Base;
20672       bool AllSame = true;
20673       for (unsigned i = 0; i != NumElts; ++i) {
20674         if (!V->getOperand(i).isUndef()) {
20675           Base = V->getOperand(i);
20676           break;
20677         }
20678       }
20679       // Splat of <u, u, u, u>, return <u, u, u, u>
20680       if (!Base.getNode())
20681         return N0;
20682       for (unsigned i = 0; i != NumElts; ++i) {
20683         if (V->getOperand(i) != Base) {
20684           AllSame = false;
20685           break;
20686         }
20687       }
20688       // Splat of <x, x, x, x>, return <x, x, x, x>
20689       if (AllSame)
20690         return N0;
20691 
20692       // Canonicalize any other splat as a build_vector.
20693       SDValue Splatted = V->getOperand(SplatIndex);
20694       SmallVector<SDValue, 8> Ops(NumElts, Splatted);
20695       SDValue NewBV = DAG.getBuildVector(V->getValueType(0), SDLoc(N), Ops);
20696 
20697       // We may have jumped through bitcasts, so the type of the
20698       // BUILD_VECTOR may not match the type of the shuffle.
20699       if (V->getValueType(0) != VT)
20700         NewBV = DAG.getBitcast(VT, NewBV);
20701       return NewBV;
20702     }
20703   }
20704 
20705   // Simplify source operands based on shuffle mask.
20706   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
20707     return SDValue(N, 0);
20708 
20709   // This is intentionally placed after demanded elements simplification because
20710   // it could eliminate knowledge of undef elements created by this shuffle.
20711   if (SDValue ShufOp = simplifyShuffleOfShuffle(SVN))
20712     return ShufOp;
20713 
20714   // Match shuffles that can be converted to any_vector_extend_in_reg.
20715   if (SDValue V = combineShuffleToVectorExtend(SVN, DAG, TLI, LegalOperations))
20716     return V;
20717 
20718   // Combine "truncate_vector_in_reg" style shuffles.
20719   if (SDValue V = combineTruncationShuffle(SVN, DAG))
20720     return V;
20721 
20722   if (N0.getOpcode() == ISD::CONCAT_VECTORS &&
20723       Level < AfterLegalizeVectorOps &&
20724       (N1.isUndef() ||
20725       (N1.getOpcode() == ISD::CONCAT_VECTORS &&
20726        N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType()))) {
20727     if (SDValue V = partitionShuffleOfConcats(N, DAG))
20728       return V;
20729   }
20730 
20731   // A shuffle of a concat of the same narrow vector can be reduced to use
20732   // only low-half elements of a concat with undef:
20733   // shuf (concat X, X), undef, Mask --> shuf (concat X, undef), undef, Mask'
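  // For instance (an illustrative sketch): with X of 2 elements, Mask =
  // <0,3,1,2> over concat(X, X) becomes Mask' = <0,1,1,0> over
  // concat(X, undef).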
20734   if (N0.getOpcode() == ISD::CONCAT_VECTORS && N1.isUndef() &&
20735       N0.getNumOperands() == 2 &&
20736       N0.getOperand(0) == N0.getOperand(1)) {
20737     int HalfNumElts = (int)NumElts / 2;
20738     SmallVector<int, 8> NewMask;
20739     for (unsigned i = 0; i != NumElts; ++i) {
20740       int Idx = SVN->getMaskElt(i);
20741       if (Idx >= HalfNumElts) {
20742         assert(Idx < (int)NumElts && "Shuffle mask chooses undef op");
20743         Idx -= HalfNumElts;
20744       }
20745       NewMask.push_back(Idx);
20746     }
20747     if (TLI.isShuffleMaskLegal(NewMask, VT)) {
20748       SDValue UndefVec = DAG.getUNDEF(N0.getOperand(0).getValueType());
20749       SDValue NewCat = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT,
20750                                    N0.getOperand(0), UndefVec);
20751       return DAG.getVectorShuffle(VT, SDLoc(N), NewCat, N1, NewMask);
20752     }
20753   }
20754 
20755   // Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
20756   // BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
20757   if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT))
20758     if (SDValue Res = combineShuffleOfScalars(SVN, DAG, TLI))
20759       return Res;
20760 
20761   // If this shuffle only has a single input that is a bitcasted shuffle,
20762   // attempt to merge the 2 shuffles and suitably bitcast the inputs/output
20763   // back to their original types.
20764   if (N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
20765       N1.isUndef() && Level < AfterLegalizeVectorOps &&
20766       TLI.isTypeLegal(VT)) {
20767 
20768     SDValue BC0 = peekThroughOneUseBitcasts(N0);
20769     if (BC0.getOpcode() == ISD::VECTOR_SHUFFLE && BC0.hasOneUse()) {
20770       EVT SVT = VT.getScalarType();
20771       EVT InnerVT = BC0->getValueType(0);
20772       EVT InnerSVT = InnerVT.getScalarType();
20773 
20774       // Determine which shuffle works with the smaller scalar type.
20775       EVT ScaleVT = SVT.bitsLT(InnerSVT) ? VT : InnerVT;
20776       EVT ScaleSVT = ScaleVT.getScalarType();
20777 
20778       if (TLI.isTypeLegal(ScaleVT) &&
20779           0 == (InnerSVT.getSizeInBits() % ScaleSVT.getSizeInBits()) &&
20780           0 == (SVT.getSizeInBits() % ScaleSVT.getSizeInBits())) {
20781         int InnerScale = InnerSVT.getSizeInBits() / ScaleSVT.getSizeInBits();
20782         int OuterScale = SVT.getSizeInBits() / ScaleSVT.getSizeInBits();
20783 
20784         // Scale the shuffle masks to the smaller scalar type.
20785         ShuffleVectorSDNode *InnerSVN = cast<ShuffleVectorSDNode>(BC0);
20786         SmallVector<int, 8> InnerMask;
20787         SmallVector<int, 8> OuterMask;
20788         narrowShuffleMaskElts(InnerScale, InnerSVN->getMask(), InnerMask);
20789         narrowShuffleMaskElts(OuterScale, SVN->getMask(), OuterMask);
20790 
20791         // Merge the shuffle masks.
20792         SmallVector<int, 8> NewMask;
20793         for (int M : OuterMask)
20794           NewMask.push_back(M < 0 ? -1 : InnerMask[M]);
20795 
20796         // Test for shuffle mask legality over both commutations.
20797         SDValue SV0 = BC0->getOperand(0);
20798         SDValue SV1 = BC0->getOperand(1);
20799         bool LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
20800         if (!LegalMask) {
20801           std::swap(SV0, SV1);
20802           ShuffleVectorSDNode::commuteMask(NewMask);
20803           LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
20804         }
20805 
20806         if (LegalMask) {
20807           SV0 = DAG.getBitcast(ScaleVT, SV0);
20808           SV1 = DAG.getBitcast(ScaleVT, SV1);
20809           return DAG.getBitcast(
20810               VT, DAG.getVectorShuffle(ScaleVT, SDLoc(N), SV0, SV1, NewMask));
20811         }
20812       }
20813     }
20814   }
20815 
20816   if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
20817     // Canonicalize shuffles according to rules:
20818     //  shuffle(A, shuffle(A, B)) -> shuffle(shuffle(A,B), A)
20819     //  shuffle(B, shuffle(A, B)) -> shuffle(shuffle(A,B), B)
20820     //  shuffle(B, shuffle(A, Undef)) -> shuffle(shuffle(A, Undef), B)
20821     if (N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
20822         N0.getOpcode() != ISD::VECTOR_SHUFFLE) {
20823       // The incoming shuffle must be of the same type as the result of the
20824       // current shuffle.
20825       assert(N1->getOperand(0).getValueType() == VT &&
20826              "Shuffle types don't match");
20827 
20828       SDValue SV0 = N1->getOperand(0);
20829       SDValue SV1 = N1->getOperand(1);
20830       bool HasSameOp0 = N0 == SV0;
20831       bool IsSV1Undef = SV1.isUndef();
20832       if (HasSameOp0 || IsSV1Undef || N0 == SV1)
20833         // Commute the operands of this shuffle so merging below will trigger.
20834         return DAG.getCommutedVectorShuffle(*SVN);
20835     }
20836 
20837     // Canonicalize splat shuffles to the RHS to improve merging below.
20838     //  shuffle(splat(A,u), shuffle(C,D)) -> shuffle'(shuffle(C,D), splat(A,u))
20839     if (N0.getOpcode() == ISD::VECTOR_SHUFFLE &&
20840         N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
20841         cast<ShuffleVectorSDNode>(N0)->isSplat() &&
20842         !cast<ShuffleVectorSDNode>(N1)->isSplat()) {
20843       return DAG.getCommutedVectorShuffle(*SVN);
20844     }
20845   }
20846 
20847   // Compute the combined shuffle mask for a shuffle with SV0 as the first
20848   // operand, and SV1 as the second operand.
20849   // i.e. Merge SVN(OtherSVN, N1) -> shuffle(SV0, SV1, Mask).
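  // For instance (hypothetical operands, as a sketch): merging
  //   shuffle (shuffle A, B, <0,5,2,7>), A, <1,0,6,7>
  // traces each outer index through the inner mask and produces SV0 = B,
  // SV1 = A, Mask = <1,4,6,7>, i.e. shuffle B, A, <1,4,6,7>.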
20850   auto MergeInnerShuffle = [NumElts](ShuffleVectorSDNode *SVN,
20851                                      ShuffleVectorSDNode *OtherSVN, SDValue N1,
20852                                      SDValue &SV0, SDValue &SV1,
20853                                      SmallVectorImpl<int> &Mask) -> bool {
20854     // Don't try to fold splats; they're likely to simplify somehow, or they
20855     // might be free.
20856     if (OtherSVN->isSplat())
20857       return false;
20858 
20859     SV0 = SV1 = SDValue();
20860     Mask.clear();
20861 
20862     for (unsigned i = 0; i != NumElts; ++i) {
20863       int Idx = SVN->getMaskElt(i);
20864       if (Idx < 0) {
20865         // Propagate Undef.
20866         Mask.push_back(Idx);
20867         continue;
20868       }
20869 
20870       SDValue CurrentVec;
20871       if (Idx < (int)NumElts) {
20872         // This shuffle index refers to the inner shuffle N0. Lookup the inner
20873         // shuffle mask to identify which vector is actually referenced.
20874         Idx = OtherSVN->getMaskElt(Idx);
20875         if (Idx < 0) {
20876           // Propagate Undef.
20877           Mask.push_back(Idx);
20878           continue;
20879         }
20880         CurrentVec = (Idx < (int)NumElts) ? OtherSVN->getOperand(0)
20881                                           : OtherSVN->getOperand(1);
20882       } else {
20883         // This shuffle index references an element within N1.
20884         CurrentVec = N1;
20885       }
20886 
20887       // Simple case where 'CurrentVec' is UNDEF.
20888       if (CurrentVec.isUndef()) {
20889         Mask.push_back(-1);
20890         continue;
20891       }
20892 
20893       // Canonicalize the shuffle index. We don't know yet if CurrentVec
20894       // will be the first or second operand of the combined shuffle.
20895       Idx = Idx % NumElts;
20896       if (!SV0.getNode() || SV0 == CurrentVec) {
20897         // Ok. CurrentVec is the left hand side.
20898         // Update the mask accordingly.
20899         SV0 = CurrentVec;
20900         Mask.push_back(Idx);
20901         continue;
20902       }
20903       if (!SV1.getNode() || SV1 == CurrentVec) {
20904         // Ok. CurrentVec is the right hand side.
20905         // Update the mask accordingly.
20906         SV1 = CurrentVec;
20907         Mask.push_back(Idx + NumElts);
20908         continue;
20909       }
20910 
20911       // Last chance - see if the vector is another shuffle and if it
20912       // uses one of the existing candidate shuffle ops.
20913       if (auto *CurrentSVN = dyn_cast<ShuffleVectorSDNode>(CurrentVec)) {
20914         int InnerIdx = CurrentSVN->getMaskElt(Idx);
20915         if (InnerIdx < 0) {
20916           Mask.push_back(-1);
20917           continue;
20918         }
20919         SDValue InnerVec = (InnerIdx < (int)NumElts)
20920                                ? CurrentSVN->getOperand(0)
20921                                : CurrentSVN->getOperand(1);
20922         if (InnerVec.isUndef()) {
20923           Mask.push_back(-1);
20924           continue;
20925         }
20926         InnerIdx %= NumElts;
20927         if (InnerVec == SV0) {
20928           Mask.push_back(InnerIdx);
20929           continue;
20930         }
20931         if (InnerVec == SV1) {
20932           Mask.push_back(InnerIdx + NumElts);
20933           continue;
20934         }
20935       }
20936 
20937       // Bail out if we cannot convert the shuffle pair into a single shuffle.
20938       return false;
20939     }
20940     return true;
20941   };
20942 
20943   // Try to fold according to rules:
20944   //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
20945   //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
20946   //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
20947   // Don't try to fold shuffles with illegal type.
20948   // Only fold if this shuffle is the only user of the other shuffle.
20949   if (N0.getOpcode() == ISD::VECTOR_SHUFFLE && N->isOnlyUserOf(N0.getNode()) &&
20950       Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
20951     ShuffleVectorSDNode *OtherSV = cast<ShuffleVectorSDNode>(N0);
20952 
20953     // The incoming shuffle must be of the same type as the result of the
20954     // current shuffle.
20955     assert(OtherSV->getOperand(0).getValueType() == VT &&
20956            "Shuffle types don't match");
20957 
20958     SDValue SV0, SV1;
20959     SmallVector<int, 4> Mask;
20960     if (MergeInnerShuffle(SVN, OtherSV, N1, SV0, SV1, Mask)) {
20961       // Check if all indices in Mask are Undef. If so, propagate Undef.
20962       if (llvm::all_of(Mask, [](int M) { return M < 0; }))
20963         return DAG.getUNDEF(VT);
20964 
20965       if (!SV0.getNode())
20966         SV0 = DAG.getUNDEF(VT);
20967       if (!SV1.getNode())
20968         SV1 = DAG.getUNDEF(VT);
20969 
20970       // Avoid introducing shuffles with illegal mask.
20971       //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
20972       //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
20973       //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
20974       //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, A, M2)
20975       //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, A, M2)
20976       //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, B, M2)
20977       return TLI.buildLegalVectorShuffle(VT, SDLoc(N), SV0, SV1, Mask, DAG);
20978     }
20979   }
20980 
20981   if (SDValue V = foldShuffleOfConcatUndefs(SVN, DAG))
20982     return V;
20983 
20984   return SDValue();
20985 }
20986 
20987 SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) {
20988   SDValue InVal = N->getOperand(0);
20989   EVT VT = N->getValueType(0);
20990 
20991   // Replace a SCALAR_TO_VECTOR(EXTRACT_VECTOR_ELT(V,C0)) pattern
20992   // with a VECTOR_SHUFFLE and possible truncate.
20993   if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
20994       VT.isFixedLengthVector() &&
20995       InVal->getOperand(0).getValueType().isFixedLengthVector()) {
20996     SDValue InVec = InVal->getOperand(0);
20997     SDValue EltNo = InVal->getOperand(1);
20998     auto InVecT = InVec.getValueType();
20999     if (ConstantSDNode *C0 = dyn_cast<ConstantSDNode>(EltNo)) {
21000       SmallVector<int, 8> NewMask(InVecT.getVectorNumElements(), -1);
21001       int Elt = C0->getZExtValue();
21002       NewMask[0] = Elt;
21003       // If we have an implicit truncate, do the truncate here as long as it's
21004       // legal; if it's not legal, this combine is simply skipped.
21005       if (VT.getScalarType() != InVal.getValueType() &&
21006           InVal.getValueType().isScalarInteger() &&
21007           isTypeLegal(VT.getScalarType())) {
21008         SDValue Val =
21009             DAG.getNode(ISD::TRUNCATE, SDLoc(InVal), VT.getScalarType(), InVal);
21010         return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), VT, Val);
21011       }
21012       if (VT.getScalarType() == InVecT.getScalarType() &&
21013           VT.getVectorNumElements() <= InVecT.getVectorNumElements()) {
21014         SDValue LegalShuffle =
21015           TLI.buildLegalVectorShuffle(InVecT, SDLoc(N), InVec,
21016                                       DAG.getUNDEF(InVecT), NewMask, DAG);
21017         if (LegalShuffle) {
21018           // If the initial vector is the correct size this shuffle is a
21019           // valid result.
21020           if (VT == InVecT)
21021             return LegalShuffle;
21022           // If not we must truncate the vector.
21023           if (VT.getVectorNumElements() != InVecT.getVectorNumElements()) {
21024             SDValue ZeroIdx = DAG.getVectorIdxConstant(0, SDLoc(N));
21025             EVT SubVT = EVT::getVectorVT(*DAG.getContext(),
21026                                          InVecT.getVectorElementType(),
21027                                          VT.getVectorNumElements());
21028             return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), SubVT,
21029                                LegalShuffle, ZeroIdx);
21030           }
21031         }
21032       }
21033     }
21034   }
21035 
21036   return SDValue();
21037 }
21038 
21039 SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) {
21040   EVT VT = N->getValueType(0);
21041   SDValue N0 = N->getOperand(0);
21042   SDValue N1 = N->getOperand(1);
21043   SDValue N2 = N->getOperand(2);
21044   uint64_t InsIdx = N->getConstantOperandVal(2);
21045 
21046   // If inserting an UNDEF, just return the original vector.
21047   if (N1.isUndef())
21048     return N0;
21049 
21050   // If this is an insert of an extracted vector into an undef vector, we can
21051   // just use the input to the extract.
21052   if (N0.isUndef() && N1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
21053       N1.getOperand(1) == N2 && N1.getOperand(0).getValueType() == VT)
21054     return N1.getOperand(0);
21055 
21056   // If we are inserting a bitcast value into an undef, with the same
21057   // number of elements, just use the bitcast input of the extract.
21058   // i.e. INSERT_SUBVECTOR UNDEF (BITCAST N1) N2 ->
21059   //        BITCAST (INSERT_SUBVECTOR UNDEF N1 N2)
21060   if (N0.isUndef() && N1.getOpcode() == ISD::BITCAST &&
21061       N1.getOperand(0).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
21062       N1.getOperand(0).getOperand(1) == N2 &&
21063       N1.getOperand(0).getOperand(0).getValueType().getVectorElementCount() ==
21064           VT.getVectorElementCount() &&
21065       N1.getOperand(0).getOperand(0).getValueType().getSizeInBits() ==
21066           VT.getSizeInBits()) {
21067     return DAG.getBitcast(VT, N1.getOperand(0).getOperand(0));
21068   }
21069 
21070   // If both N0 and N1 are bitcast values on which insert_subvector
21071   // would make sense, pull the bitcast through.
21072   // i.e. INSERT_SUBVECTOR (BITCAST N0) (BITCAST N1) N2 ->
21073   //        BITCAST (INSERT_SUBVECTOR N0 N1 N2)
21074   if (N0.getOpcode() == ISD::BITCAST && N1.getOpcode() == ISD::BITCAST) {
21075     SDValue CN0 = N0.getOperand(0);
21076     SDValue CN1 = N1.getOperand(0);
21077     EVT CN0VT = CN0.getValueType();
21078     EVT CN1VT = CN1.getValueType();
21079     if (CN0VT.isVector() && CN1VT.isVector() &&
21080         CN0VT.getVectorElementType() == CN1VT.getVectorElementType() &&
21081         CN0VT.getVectorElementCount() == VT.getVectorElementCount()) {
21082       SDValue NewINSERT = DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N),
21083                                       CN0.getValueType(), CN0, CN1, N2);
21084       return DAG.getBitcast(VT, NewINSERT);
21085     }
21086   }
21087 
21088   // Combine INSERT_SUBVECTORs where we are inserting to the same index.
21089   // INSERT_SUBVECTOR( INSERT_SUBVECTOR( Vec, SubOld, Idx ), SubNew, Idx )
21090   // --> INSERT_SUBVECTOR( Vec, SubNew, Idx )
21091   if (N0.getOpcode() == ISD::INSERT_SUBVECTOR &&
21092       N0.getOperand(1).getValueType() == N1.getValueType() &&
21093       N0.getOperand(2) == N2)
21094     return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0.getOperand(0),
21095                        N1, N2);
21096 
21097   // Eliminate an intermediate insert into an undef vector:
21098   // insert_subvector undef, (insert_subvector undef, X, 0), N2 -->
21099   // insert_subvector undef, X, N2
21100   if (N0.isUndef() && N1.getOpcode() == ISD::INSERT_SUBVECTOR &&
21101       N1.getOperand(0).isUndef() && isNullConstant(N1.getOperand(2)))
21102     return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0,
21103                        N1.getOperand(1), N2);
21104 
21105   // Push subvector bitcasts to the output, adjusting the index as we go.
21106   // insert_subvector(bitcast(v), bitcast(s), c1)
21107   // -> bitcast(insert_subvector(v, s, c2))
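  // For example (a sketch with hypothetical types): inserting
  // (bitcast v1i64 s) into (bitcast v4i32 v) at index 2 becomes
  // bitcast (insert_subvector (bitcast v2i64 v), s, 1), dividing the index
  // by the element-size ratio (here 2).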
21108   if ((N0.isUndef() || N0.getOpcode() == ISD::BITCAST) &&
21109       N1.getOpcode() == ISD::BITCAST) {
21110     SDValue N0Src = peekThroughBitcasts(N0);
21111     SDValue N1Src = peekThroughBitcasts(N1);
21112     EVT N0SrcSVT = N0Src.getValueType().getScalarType();
21113     EVT N1SrcSVT = N1Src.getValueType().getScalarType();
21114     if ((N0.isUndef() || N0SrcSVT == N1SrcSVT) &&
21115         N0Src.getValueType().isVector() && N1Src.getValueType().isVector()) {
21116       EVT NewVT;
21117       SDLoc DL(N);
21118       SDValue NewIdx;
21119       LLVMContext &Ctx = *DAG.getContext();
21120       ElementCount NumElts = VT.getVectorElementCount();
21121       unsigned EltSizeInBits = VT.getScalarSizeInBits();
21122       if ((EltSizeInBits % N1SrcSVT.getSizeInBits()) == 0) {
21123         unsigned Scale = EltSizeInBits / N1SrcSVT.getSizeInBits();
21124         NewVT = EVT::getVectorVT(Ctx, N1SrcSVT, NumElts * Scale);
21125         NewIdx = DAG.getVectorIdxConstant(InsIdx * Scale, DL);
21126       } else if ((N1SrcSVT.getSizeInBits() % EltSizeInBits) == 0) {
21127         unsigned Scale = N1SrcSVT.getSizeInBits() / EltSizeInBits;
21128         if (NumElts.isKnownMultipleOf(Scale) && (InsIdx % Scale) == 0) {
21129           NewVT = EVT::getVectorVT(Ctx, N1SrcSVT,
21130                                    NumElts.divideCoefficientBy(Scale));
21131           NewIdx = DAG.getVectorIdxConstant(InsIdx / Scale, DL);
21132         }
21133       }
21134       if (NewIdx && hasOperation(ISD::INSERT_SUBVECTOR, NewVT)) {
21135         SDValue Res = DAG.getBitcast(NewVT, N0Src);
21136         Res = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, NewVT, Res, N1Src, NewIdx);
21137         return DAG.getBitcast(VT, Res);
21138       }
21139     }
21140   }
21141 
21142   // Canonicalize insert_subvector dag nodes.
21143   // Example:
21144   // (insert_subvector (insert_subvector A, B, Idx1), C, Idx0)
21145   // -> (insert_subvector (insert_subvector A, C, Idx0), B, Idx1) when Idx0 < Idx1
21146   if (N0.getOpcode() == ISD::INSERT_SUBVECTOR && N0.hasOneUse() &&
21147       N1.getValueType() == N0.getOperand(1).getValueType()) {
21148     unsigned OtherIdx = N0.getConstantOperandVal(2);
21149     if (InsIdx < OtherIdx) {
21150       // Swap nodes.
21151       SDValue NewOp = DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT,
21152                                   N0.getOperand(0), N1, N2);
21153       AddToWorklist(NewOp.getNode());
21154       return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N0.getNode()),
21155                          VT, NewOp, N0.getOperand(1), N0.getOperand(2));
21156     }
21157   }
21158 
21159   // If the input vector is a concatenation, and the insert replaces
21160   // one of the pieces, we can optimize into a single concat_vectors.
21161   if (N0.getOpcode() == ISD::CONCAT_VECTORS && N0.hasOneUse() &&
21162       N0.getOperand(0).getValueType() == N1.getValueType() &&
21163       N0.getOperand(0).getValueType().isScalableVector() ==
21164           N1.getValueType().isScalableVector()) {
21165     unsigned Factor = N1.getValueType().getVectorMinNumElements();
21166     SmallVector<SDValue, 8> Ops(N0->op_begin(), N0->op_end());
21167     Ops[InsIdx / Factor] = N1;
21168     return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
21169   }
21170 
21171   // Simplify source operands based on insertion.
21172   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
21173     return SDValue(N, 0);
21174 
21175   return SDValue();
21176 }
21177 
21178 SDValue DAGCombiner::visitFP_TO_FP16(SDNode *N) {
21179   SDValue N0 = N->getOperand(0);
21180 
21181   // fold (fp_to_fp16 (fp16_to_fp op)) -> op
21182   if (N0->getOpcode() == ISD::FP16_TO_FP)
21183     return N0->getOperand(0);
21184 
21185   return SDValue();
21186 }
21187 
21188 SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
21189   SDValue N0 = N->getOperand(0);
21190 
21191   // fold fp16_to_fp(op & 0xffff) -> fp16_to_fp(op)
21192   if (!TLI.shouldKeepZExtForFP16Conv() && N0->getOpcode() == ISD::AND) {
21193     ConstantSDNode *AndConst = getAsNonOpaqueConstant(N0.getOperand(1));
21194     if (AndConst && AndConst->getAPIntValue() == 0xffff) {
21195       return DAG.getNode(ISD::FP16_TO_FP, SDLoc(N), N->getValueType(0),
21196                          N0.getOperand(0));
21197     }
21198   }
21199 
21200   return SDValue();
21201 }
21202 
21203 SDValue DAGCombiner::visitVECREDUCE(SDNode *N) {
21204   SDValue N0 = N->getOperand(0);
21205   EVT VT = N0.getValueType();
21206   unsigned Opcode = N->getOpcode();
21207 
21208   // VECREDUCE over 1-element vector is just an extract.
21209   if (VT.getVectorElementCount().isScalar()) {
21210     SDLoc dl(N);
21211     SDValue Res =
21212         DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT.getVectorElementType(), N0,
21213                     DAG.getVectorIdxConstant(0, dl));
21214     if (Res.getValueType() != N->getValueType(0))
21215       Res = DAG.getNode(ISD::ANY_EXTEND, dl, N->getValueType(0), Res);
21216     return Res;
21217   }
21218 
21219   // On a boolean vector an and/or reduction is the same as a umin/umax
21220   // reduction. Convert them if the latter is legal while the former isn't.
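  // For example (a sketch): if every element of a v4i32 is known to be 0 or
  // -1, then VECREDUCE_AND is VECREDUCE_UMIN (the unsigned min of 0/-1
  // values is 0 iff any lane is 0), and VECREDUCE_OR is VECREDUCE_UMAX.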
21221   if (Opcode == ISD::VECREDUCE_AND || Opcode == ISD::VECREDUCE_OR) {
21222     unsigned NewOpcode = Opcode == ISD::VECREDUCE_AND
21223         ? ISD::VECREDUCE_UMIN : ISD::VECREDUCE_UMAX;
21224     if (!TLI.isOperationLegalOrCustom(Opcode, VT) &&
21225         TLI.isOperationLegalOrCustom(NewOpcode, VT) &&
21226         DAG.ComputeNumSignBits(N0) == VT.getScalarSizeInBits())
21227       return DAG.getNode(NewOpcode, SDLoc(N), N->getValueType(0), N0);
21228   }
21229 
21230   return SDValue();
21231 }
21232 
21233 /// Returns a vector_shuffle if it is able to transform an AND to a
21234 /// with the destination vector and a zero vector.
21235 /// e.g. AND V, <0xffffffff, 0, 0xffffffff, 0>. ==>
21236 ///      vector_shuffle V, Zero, <0, 4, 2, 4>
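/// A sub-element split can also work (an illustrative sketch): for
///   AND (v2i64 V), <0x00000000FFFFFFFF, 0xFFFFFFFF00000000>
/// splitting into 32-bit sub-elements (little-endian) yields the clear mask
/// <0, 5, 6, 3> over (bitcast v4i32 V) and a v4i32 zero vector.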
21237 SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) {
21238   assert(N->getOpcode() == ISD::AND && "Unexpected opcode!");
21239 
21240   EVT VT = N->getValueType(0);
21241   SDValue LHS = N->getOperand(0);
21242   SDValue RHS = peekThroughBitcasts(N->getOperand(1));
21243   SDLoc DL(N);
21244 
21245   // Make sure we're not running after operation legalization where it
21246   // may have custom lowered the vector shuffles.
21247   if (LegalOperations)
21248     return SDValue();
21249 
21250   if (RHS.getOpcode() != ISD::BUILD_VECTOR)
21251     return SDValue();
21252 
21253   EVT RVT = RHS.getValueType();
21254   unsigned NumElts = RHS.getNumOperands();
21255 
21256   // Attempt to create a valid clear mask, splitting the mask into
21257   // sub-elements and checking to see if each is
21258   // all zeros or all ones, making it suitable for shuffle masking.
  auto BuildClearMask = [&](int Split) {
    int NumSubElts = NumElts * Split;
    int NumSubBits = RVT.getScalarSizeInBits() / Split;

    SmallVector<int, 8> Indices;
    for (int i = 0; i != NumSubElts; ++i) {
      int EltIdx = i / Split;
      int SubIdx = i % Split;
      SDValue Elt = RHS.getOperand(EltIdx);
      // X & undef --> 0 (not undef). So this lane must be converted to choose
      // from the zero constant vector (same as if the element had all 0-bits).
      if (Elt.isUndef()) {
        Indices.push_back(i + NumSubElts);
        continue;
      }

      APInt Bits;
      if (isa<ConstantSDNode>(Elt))
        Bits = cast<ConstantSDNode>(Elt)->getAPIntValue();
      else if (isa<ConstantFPSDNode>(Elt))
        Bits = cast<ConstantFPSDNode>(Elt)->getValueAPF().bitcastToAPInt();
      else
        return SDValue();

      // Extract the sub element from the constant bit mask.
      if (DAG.getDataLayout().isBigEndian())
        Bits = Bits.extractBits(NumSubBits, (Split - SubIdx - 1) * NumSubBits);
      else
        Bits = Bits.extractBits(NumSubBits, SubIdx * NumSubBits);

      if (Bits.isAllOnesValue())
        Indices.push_back(i);
      else if (Bits == 0)
        Indices.push_back(i + NumSubElts);
      else
        return SDValue();
    }

    // Let's see if the target supports this vector_shuffle.
    EVT ClearSVT = EVT::getIntegerVT(*DAG.getContext(), NumSubBits);
    EVT ClearVT = EVT::getVectorVT(*DAG.getContext(), ClearSVT, NumSubElts);
    if (!TLI.isVectorClearMaskLegal(Indices, ClearVT))
      return SDValue();

    SDValue Zero = DAG.getConstant(0, DL, ClearVT);
    return DAG.getBitcast(VT, DAG.getVectorShuffle(ClearVT, DL,
                                                   DAG.getBitcast(ClearVT, LHS),
                                                   Zero, Indices));
  };

  // Determine maximum split level (byte level masking).
  int MaxSplit = 1;
  if (RVT.getScalarSizeInBits() % 8 == 0)
    MaxSplit = RVT.getScalarSizeInBits() / 8;

  for (int Split = 1; Split <= MaxSplit; ++Split)
    if (RVT.getScalarSizeInBits() % Split == 0)
      if (SDValue S = BuildClearMask(Split))
        return S;

  return SDValue();
}

/// If a vector binop is performed on splat values, it may be profitable to
/// extract, scalarize, and insert/splat.
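/// For example: add (splat X), (splat Y) --> splat (add X, Y), trading a
/// full-width vector add for one scalar add plus a splat.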
static SDValue scalarizeBinOpOfSplats(SDNode *N, SelectionDAG &DAG) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  unsigned Opcode = N->getOpcode();
  EVT VT = N->getValueType(0);
  EVT EltVT = VT.getVectorElementType();
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();

  // TODO: Remove/replace the extract cost check? If the elements are available
  //       as scalars, then there may be no extract cost. Should we ask if
  //       inserting a scalar back into a vector is cheap instead?
  int Index0, Index1;
  SDValue Src0 = DAG.getSplatSourceVector(N0, Index0);
  SDValue Src1 = DAG.getSplatSourceVector(N1, Index1);
  if (!Src0 || !Src1 || Index0 != Index1 ||
      Src0.getValueType().getVectorElementType() != EltVT ||
      Src1.getValueType().getVectorElementType() != EltVT ||
      !TLI.isExtractVecEltCheap(VT, Index0) ||
      !TLI.isOperationLegalOrCustom(Opcode, EltVT))
    return SDValue();

  SDLoc DL(N);
  SDValue IndexC = DAG.getVectorIdxConstant(Index0, DL);
  SDValue X = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Src0, IndexC);
  SDValue Y = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Src1, IndexC);
  SDValue ScalarBO = DAG.getNode(Opcode, DL, EltVT, X, Y, N->getFlags());

  // If all lanes but 1 are undefined, no need to splat the scalar result.
  // TODO: Keep track of undefs and use that info in the general case.
  if (N0.getOpcode() == ISD::BUILD_VECTOR && N0.getOpcode() == N1.getOpcode() &&
      count_if(N0->ops(), [](SDValue V) { return !V.isUndef(); }) == 1 &&
      count_if(N1->ops(), [](SDValue V) { return !V.isUndef(); }) == 1) {
    // bo (build_vec ..undef, X, undef...), (build_vec ..undef, Y, undef...) -->
    // build_vec ..undef, (bo X, Y), undef...
    SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), DAG.getUNDEF(EltVT));
    Ops[Index0] = ScalarBO;
    return DAG.getBuildVector(VT, DL, Ops);
  }

  // bo (splat X, Index), (splat Y, Index) --> splat (bo X, Y), Index
  SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), ScalarBO);
  return DAG.getBuildVector(VT, DL, Ops);
}

/// Visit a binary vector operation, like ADD.
SDValue DAGCombiner::SimplifyVBinOp(SDNode *N) {
  assert(N->getValueType(0).isVector() &&
         "SimplifyVBinOp only works on vectors!");

  SDValue LHS = N->getOperand(0);
  SDValue RHS = N->getOperand(1);
  SDValue Ops[] = {LHS, RHS};
  EVT VT = N->getValueType(0);
  unsigned Opcode = N->getOpcode();
  SDNodeFlags Flags = N->getFlags();

  // See if we can constant fold the vector operation.
  if (SDValue Fold = DAG.FoldConstantVectorArithmetic(
          Opcode, SDLoc(LHS), LHS.getValueType(), Ops, N->getFlags()))
    return Fold;

  // Move unary shuffles with identical masks after a vector binop:
  // VBinOp (shuffle A, Undef, Mask), (shuffle B, Undef, Mask)
  //   --> shuffle (VBinOp A, B), Undef, Mask
  // This does not require type legality checks because we are creating the
  // same types of operations that are in the original sequence. We do have to
  // restrict ops like integer div that have immediate UB (eg, div-by-zero)
  // though. This code is adapted from the identical transform in instcombine.
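  // For example: add (shuffle A, undef, <1,1,3,3>),
  //                  (shuffle B, undef, <1,1,3,3>)
  //   --> shuffle (add A, B), undef, <1,1,3,3>
  // which replaces two shuffles and an add with one add and one shuffle.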
  if (Opcode != ISD::UDIV && Opcode != ISD::SDIV &&
      Opcode != ISD::UREM && Opcode != ISD::SREM &&
      Opcode != ISD::UDIVREM && Opcode != ISD::SDIVREM) {
    auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(LHS);
    auto *Shuf1 = dyn_cast<ShuffleVectorSDNode>(RHS);
    if (Shuf0 && Shuf1 && Shuf0->getMask().equals(Shuf1->getMask()) &&
        LHS.getOperand(1).isUndef() && RHS.getOperand(1).isUndef() &&
        (LHS.hasOneUse() || RHS.hasOneUse() || LHS == RHS)) {
      SDLoc DL(N);
      SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, LHS.getOperand(0),
                                     RHS.getOperand(0), Flags);
      SDValue UndefV = LHS.getOperand(1);
      return DAG.getVectorShuffle(VT, DL, NewBinOp, UndefV, Shuf0->getMask());
    }

    // Try to sink a splat shuffle after a binop with a uniform constant.
    // This is limited to cases where neither the shuffle nor the constant have
    // undefined elements because that could be poison-unsafe or inhibit
    // demanded elements analysis. It is further limited to not change a splat
    // of an inserted scalar because that may be optimized better by
    // load-folding or other target-specific behaviors.
    if (isConstOrConstSplat(RHS) && Shuf0 && is_splat(Shuf0->getMask()) &&
        Shuf0->hasOneUse() && Shuf0->getOperand(1).isUndef() &&
        Shuf0->getOperand(0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
      // binop (splat X), (splat C) --> splat (binop X, C)
      SDLoc DL(N);
      SDValue X = Shuf0->getOperand(0);
      SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, X, RHS, Flags);
      return DAG.getVectorShuffle(VT, DL, NewBinOp, DAG.getUNDEF(VT),
                                  Shuf0->getMask());
    }
    if (isConstOrConstSplat(LHS) && Shuf1 && is_splat(Shuf1->getMask()) &&
        Shuf1->hasOneUse() && Shuf1->getOperand(1).isUndef() &&
        Shuf1->getOperand(0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
      // binop (splat C), (splat X) --> splat (binop C, X)
      SDLoc DL(N);
      SDValue X = Shuf1->getOperand(0);
      SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, LHS, X, Flags);
      return DAG.getVectorShuffle(VT, DL, NewBinOp, DAG.getUNDEF(VT),
                                  Shuf1->getMask());
    }
  }

  // The following pattern is likely to emerge with vector reduction ops. Moving
  // the binary operation ahead of insertion may allow using a narrower vector
  // instruction that has better performance than the wide version of the op:
  // VBinOp (ins undef, X, Z), (ins undef, Y, Z) --> ins VecC, (VBinOp X, Y), Z
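  // For example (illustrative): two v4i32 values inserted into undef v8i32
  // at index 0 feeding a wide add become a single v4i32 add whose result is
  // inserted at index 0, letting the target use the narrower instruction.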
  if (LHS.getOpcode() == ISD::INSERT_SUBVECTOR && LHS.getOperand(0).isUndef() &&
      RHS.getOpcode() == ISD::INSERT_SUBVECTOR && RHS.getOperand(0).isUndef() &&
      LHS.getOperand(2) == RHS.getOperand(2) &&
      (LHS.hasOneUse() || RHS.hasOneUse())) {
    SDValue X = LHS.getOperand(1);
    SDValue Y = RHS.getOperand(1);
    SDValue Z = LHS.getOperand(2);
    EVT NarrowVT = X.getValueType();
    if (NarrowVT == Y.getValueType() &&
        TLI.isOperationLegalOrCustomOrPromote(Opcode, NarrowVT,
                                              LegalOperations)) {
      // (binop undef, undef) may not return undef, so compute that result.
      SDLoc DL(N);
      SDValue VecC =
          DAG.getNode(Opcode, DL, VT, DAG.getUNDEF(VT), DAG.getUNDEF(VT));
      SDValue NarrowBO = DAG.getNode(Opcode, DL, NarrowVT, X, Y);
      return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, VecC, NarrowBO, Z);
    }
  }

  // Make sure all but the first op are undef or constant.
  auto ConcatWithConstantOrUndef = [](SDValue Concat) {
    return Concat.getOpcode() == ISD::CONCAT_VECTORS &&
           all_of(drop_begin(Concat->ops()), [](const SDValue &Op) {
             return Op.isUndef() ||
                    ISD::isBuildVectorOfConstantSDNodes(Op.getNode());
           });
  };

  // The following pattern is likely to emerge with vector reduction ops. Moving
  // the binary operation ahead of the concat may allow using a narrower vector
  // instruction that has better performance than the wide version of the op:
  // VBinOp (concat X, undef/constant), (concat Y, undef/constant) -->
  //   concat (VBinOp X, Y), VecC
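  // For example: add (concat X, undef), (concat Y, undef) -->
  //   concat (add X, Y), (add undef, undef)
  // where the per-operand ops over the undef/constant tails constant-fold
  // as they are built, leaving only the narrow add as a real operation.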
  if (ConcatWithConstantOrUndef(LHS) && ConcatWithConstantOrUndef(RHS) &&
      (LHS.hasOneUse() || RHS.hasOneUse())) {
    EVT NarrowVT = LHS.getOperand(0).getValueType();
    if (NarrowVT == RHS.getOperand(0).getValueType() &&
        TLI.isOperationLegalOrCustomOrPromote(Opcode, NarrowVT)) {
      SDLoc DL(N);
      unsigned NumOperands = LHS.getNumOperands();
      SmallVector<SDValue, 4> ConcatOps;
      for (unsigned i = 0; i != NumOperands; ++i) {
        // These constant-fold for operands 1 and up.
        ConcatOps.push_back(DAG.getNode(Opcode, DL, NarrowVT, LHS.getOperand(i),
                                        RHS.getOperand(i)));
      }

      return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
    }
  }

  if (SDValue V = scalarizeBinOpOfSplats(N, DAG))
    return V;

  return SDValue();
}

SDValue DAGCombiner::SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1,
                                    SDValue N2) {
  assert(N0.getOpcode() == ISD::SETCC &&
         "First argument must be a SetCC node!");

  SDValue SCC = SimplifySelectCC(DL, N0.getOperand(0), N0.getOperand(1), N1, N2,
                                 cast<CondCodeSDNode>(N0.getOperand(2))->get());

  // If we got a simplified select_cc node back from SimplifySelectCC, then
  // break it down into a new SETCC node and a new SELECT node, and then return
  // the SELECT node, since we were called with a SELECT node.
  if (SCC.getNode()) {
    // Check to see if we got a select_cc back (to turn into setcc/select).
    // Otherwise, just return whatever node we got back, like fabs.
    if (SCC.getOpcode() == ISD::SELECT_CC) {
      const SDNodeFlags Flags = N0.getNode()->getFlags();
      SDValue SETCC = DAG.getNode(ISD::SETCC, SDLoc(N0),
                                  N0.getValueType(),
                                  SCC.getOperand(0), SCC.getOperand(1),
                                  SCC.getOperand(4), Flags);
      AddToWorklist(SETCC.getNode());
      SDValue SelectNode = DAG.getSelect(SDLoc(SCC), SCC.getValueType(), SETCC,
                                         SCC.getOperand(2), SCC.getOperand(3));
      SelectNode->setFlags(Flags);
      return SelectNode;
    }

    return SCC;
  }
  return SDValue();
}

/// Given a SELECT or a SELECT_CC node, where LHS and RHS are the two values
/// being selected between, see if we can simplify the select.  Callers of this
/// should assume that TheSelect is deleted if this returns true.  As such, they
/// should return the appropriate thing (e.g. the node) back to the top-level of
/// the DAG combiner loop to avoid it being looked at.
bool DAGCombiner::SimplifySelectOps(SDNode *TheSelect, SDValue LHS,
                                    SDValue RHS) {
  // fold (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
  // The select + setcc is redundant, because fsqrt returns NaN for X < 0.
  if (const ConstantFPSDNode *NaN = isConstOrConstSplatFP(LHS)) {
    if (NaN->isNaN() && RHS.getOpcode() == ISD::FSQRT) {
      // We have: (select (setcc ?, ?, ?), NaN, (fsqrt ?))
      SDValue Sqrt = RHS;
      ISD::CondCode CC;
      SDValue CmpLHS;
      const ConstantFPSDNode *Zero = nullptr;

      if (TheSelect->getOpcode() == ISD::SELECT_CC) {
        CC = cast<CondCodeSDNode>(TheSelect->getOperand(4))->get();
        CmpLHS = TheSelect->getOperand(0);
        Zero = isConstOrConstSplatFP(TheSelect->getOperand(1));
      } else {
        // SELECT or VSELECT
        SDValue Cmp = TheSelect->getOperand(0);
        if (Cmp.getOpcode() == ISD::SETCC) {
          CC = cast<CondCodeSDNode>(Cmp.getOperand(2))->get();
          CmpLHS = Cmp.getOperand(0);
          Zero = isConstOrConstSplatFP(Cmp.getOperand(1));
        }
      }
      if (Zero && Zero->isZero() &&
          Sqrt.getOperand(0) == CmpLHS && (CC == ISD::SETOLT ||
          CC == ISD::SETULT || CC == ISD::SETLT)) {
        // We have: (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
        CombineTo(TheSelect, Sqrt);
        return true;
      }
    }
  }
  // Cannot simplify select with vector condition
  if (TheSelect->getOperand(0).getValueType().isVector()) return false;

  // If this is a select from two identical things, try to pull the operation
  // through the select.
  if (LHS.getOpcode() != RHS.getOpcode() ||
      !LHS.hasOneUse() || !RHS.hasOneUse())
    return false;

  // If this is a load and the token chain is identical, replace the select
  // of two loads with a load through a select of the address to load from.
  // This triggers in things like "select bool X, 10.0, 123.0" after the FP
  // constants have been dropped into the constant pool.
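  // For example: (select C, (load p), (load q)) becomes
  // (load (select C, p, q)) when both loads hang off the same token chain,
  // so only one memory operation survives.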
  if (LHS.getOpcode() == ISD::LOAD) {
    LoadSDNode *LLD = cast<LoadSDNode>(LHS);
    LoadSDNode *RLD = cast<LoadSDNode>(RHS);

    // Token chains must be identical.
    if (LHS.getOperand(0) != RHS.getOperand(0) ||
        // Do not let this transformation reduce the number of volatile loads.
        // Be conservative for atomics for the moment.
        // TODO: This does appear to be legal for unordered atomics (see D66309)
        !LLD->isSimple() || !RLD->isSimple() ||
        // FIXME: If either is a pre/post inc/dec load,
        // we'd need to split out the address adjustment.
        LLD->isIndexed() || RLD->isIndexed() ||
        // If this is an EXTLOAD, the VTs must match.
        LLD->getMemoryVT() != RLD->getMemoryVT() ||
        // If this is an EXTLOAD, the kind of extension must match.
        (LLD->getExtensionType() != RLD->getExtensionType() &&
         // The only exception is if one of the extensions is anyext.
         LLD->getExtensionType() != ISD::EXTLOAD &&
         RLD->getExtensionType() != ISD::EXTLOAD) ||
        // FIXME: this discards src value information.  This is
        // over-conservative. It would be beneficial to be able to remember
        // both potential memory locations.  Since we are discarding
        // src value info, don't do the transformation if the memory
        // locations are not in the default address space.
        LLD->getPointerInfo().getAddrSpace() != 0 ||
        RLD->getPointerInfo().getAddrSpace() != 0 ||
        // We can't produce a CMOV of a TargetFrameIndex since we won't
        // generate the address generation required.
        LLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
        RLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
        !TLI.isOperationLegalOrCustom(TheSelect->getOpcode(),
                                      LLD->getBasePtr().getValueType()))
      return false;

    // The loads must not depend on one another.
    if (LLD->isPredecessorOf(RLD) || RLD->isPredecessorOf(LLD))
      return false;

    // Check that the select condition doesn't reach either load.  If so,
    // folding this will induce a cycle into the DAG.  If not, this is safe to
    // xform, so create a select of the addresses.

    SmallPtrSet<const SDNode *, 32> Visited;
    SmallVector<const SDNode *, 16> Worklist;

    // Always fail if LLD and RLD are not independent. TheSelect is a
    // predecessor to all Nodes in question so we need not search past it.

    Visited.insert(TheSelect);
    Worklist.push_back(LLD);
    Worklist.push_back(RLD);

    if (SDNode::hasPredecessorHelper(LLD, Visited, Worklist) ||
        SDNode::hasPredecessorHelper(RLD, Visited, Worklist))
      return false;

    SDValue Addr;
    if (TheSelect->getOpcode() == ISD::SELECT) {
      // We cannot do this optimization if any pair of {RLD, LLD} is a
      // predecessor to {RLD, LLD, CondNode}. As we've already compared the
      // Loads, we only need to check if CondNode is a successor to one of the
      // loads. We can further avoid this if there's no use of their chain
      // value.
      SDNode *CondNode = TheSelect->getOperand(0).getNode();
      Worklist.push_back(CondNode);

      if ((LLD->hasAnyUseOfValue(1) &&
           SDNode::hasPredecessorHelper(LLD, Visited, Worklist)) ||
          (RLD->hasAnyUseOfValue(1) &&
           SDNode::hasPredecessorHelper(RLD, Visited, Worklist)))
        return false;

      Addr = DAG.getSelect(SDLoc(TheSelect),
                           LLD->getBasePtr().getValueType(),
                           TheSelect->getOperand(0), LLD->getBasePtr(),
                           RLD->getBasePtr());
    } else {  // Otherwise SELECT_CC
      // We cannot do this optimization if any pair of {RLD, LLD} is a
      // predecessor to {RLD, LLD, CondLHS, CondRHS}. As we've already compared
      // the Loads, we only need to check if CondLHS/CondRHS is a successor to
      // one of the loads. We can further avoid this if there's no use of their
      // chain value.

      SDNode *CondLHS = TheSelect->getOperand(0).getNode();
      SDNode *CondRHS = TheSelect->getOperand(1).getNode();
      Worklist.push_back(CondLHS);
      Worklist.push_back(CondRHS);

      if ((LLD->hasAnyUseOfValue(1) &&
           SDNode::hasPredecessorHelper(LLD, Visited, Worklist)) ||
          (RLD->hasAnyUseOfValue(1) &&
           SDNode::hasPredecessorHelper(RLD, Visited, Worklist)))
        return false;

      Addr = DAG.getNode(ISD::SELECT_CC, SDLoc(TheSelect),
                         LLD->getBasePtr().getValueType(),
                         TheSelect->getOperand(0),
                         TheSelect->getOperand(1),
                         LLD->getBasePtr(), RLD->getBasePtr(),
                         TheSelect->getOperand(4));
    }

    SDValue Load;
    // It is safe to replace the two loads if they have different alignments,
    // but the new load must use the minimum (most restrictive) alignment of
    // the inputs.
    Align Alignment = std::min(LLD->getAlign(), RLD->getAlign());
    MachineMemOperand::Flags MMOFlags = LLD->getMemOperand()->getFlags();
    if (!RLD->isInvariant())
      MMOFlags &= ~MachineMemOperand::MOInvariant;
    if (!RLD->isDereferenceable())
      MMOFlags &= ~MachineMemOperand::MODereferenceable;
    if (LLD->getExtensionType() == ISD::NON_EXTLOAD) {
      // FIXME: Discards pointer and AA info.
      Load = DAG.getLoad(TheSelect->getValueType(0), SDLoc(TheSelect),
                         LLD->getChain(), Addr, MachinePointerInfo(), Alignment,
                         MMOFlags);
    } else {
      // FIXME: Discards pointer and AA info.
      Load = DAG.getExtLoad(
          LLD->getExtensionType() == ISD::EXTLOAD ? RLD->getExtensionType()
                                                  : LLD->getExtensionType(),
          SDLoc(TheSelect), TheSelect->getValueType(0), LLD->getChain(), Addr,
          MachinePointerInfo(), LLD->getMemoryVT(), Alignment, MMOFlags);
    }

    // Users of the select now use the result of the load.
    CombineTo(TheSelect, Load);

    // Users of the old loads now use the new load's chain.  We know the
    // old-load value is dead now.
    CombineTo(LHS.getNode(), Load.getValue(0), Load.getValue(1));
    CombineTo(RHS.getNode(), Load.getValue(0), Load.getValue(1));
    return true;
  }

  return false;
}

/// Try to fold an expression of the form (N0 cond N1) ? N2 : N3 to a shift and
/// bitwise 'and'.
SDValue DAGCombiner::foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0,
                                            SDValue N1, SDValue N2, SDValue N3,
                                            ISD::CondCode CC) {
  // If this is a select where the false operand is zero and the compare is a
  // check of the sign bit, see if we can perform the "gzip trick":
  // select_cc setlt X, 0, A, 0 -> and (sra X, size(X)-1), A
  // select_cc setgt X, 0, A, 0 -> and (not (sra X, size(X)-1)), A
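  // For example, with i32 X: sra X, 31 is all-ones exactly when X is
  // negative, so select_cc setlt X, 0, A, 0 becomes and (sra X, 31), A.
  // When A is a single-bit constant such as 16, the cheaper form below
  // uses and (srl X, 27), 16 to move the sign bit directly into bit 4.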
  EVT XType = N0.getValueType();
  EVT AType = N2.getValueType();
  if (!isNullConstant(N3) || !XType.bitsGE(AType))
    return SDValue();

  // If the comparison is testing for a positive value, we have to invert
  // the sign bit mask, so only do that transform if the target has a bitwise
  // 'and not' instruction (the invert is free).
  if (CC == ISD::SETGT && TLI.hasAndNot(N2)) {
    // (X > -1) ? A : 0
    // (X >  0) ? X : 0 <-- This is canonical signed max.
    if (!(isAllOnesConstant(N1) || (isNullConstant(N1) && N0 == N2)))
      return SDValue();
  } else if (CC == ISD::SETLT) {
    // (X <  0) ? A : 0
    // (X <  1) ? X : 0 <-- This is un-canonicalized signed min.
    if (!(isNullConstant(N1) || (isOneConstant(N1) && N0 == N2)))
      return SDValue();
  } else {
    return SDValue();
  }

  // and (sra X, size(X)-1), A -> "and (srl X, C2), A" iff A is a single-bit
  // constant.
  EVT ShiftAmtTy = getShiftAmountTy(N0.getValueType());
  auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
  if (N2C && ((N2C->getAPIntValue() & (N2C->getAPIntValue() - 1)) == 0)) {
    unsigned ShCt = XType.getSizeInBits() - N2C->getAPIntValue().logBase2() - 1;
    if (!TLI.shouldAvoidTransformToShift(XType, ShCt)) {
      SDValue ShiftAmt = DAG.getConstant(ShCt, DL, ShiftAmtTy);
      SDValue Shift = DAG.getNode(ISD::SRL, DL, XType, N0, ShiftAmt);
      AddToWorklist(Shift.getNode());

      if (XType.bitsGT(AType)) {
        Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift);
        AddToWorklist(Shift.getNode());
      }

      if (CC == ISD::SETGT)
        Shift = DAG.getNOT(DL, Shift, AType);

      return DAG.getNode(ISD::AND, DL, AType, Shift, N2);
    }
  }

  unsigned ShCt = XType.getSizeInBits() - 1;
  if (TLI.shouldAvoidTransformToShift(XType, ShCt))
    return SDValue();

  SDValue ShiftAmt = DAG.getConstant(ShCt, DL, ShiftAmtTy);
  SDValue Shift = DAG.getNode(ISD::SRA, DL, XType, N0, ShiftAmt);
  AddToWorklist(Shift.getNode());

  if (XType.bitsGT(AType)) {
    Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift);
    AddToWorklist(Shift.getNode());
  }

  if (CC == ISD::SETGT)
    Shift = DAG.getNOT(DL, Shift, AType);

  return DAG.getNode(ISD::AND, DL, AType, Shift, N2);
}

// Transform (fneg/fabs (bitconvert x)) to avoid loading constant pool values.
SDValue DAGCombiner::foldSignChangeInBitcast(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  bool IsFabs = N->getOpcode() == ISD::FABS;
  bool IsFree = IsFabs ? TLI.isFAbsFree(VT) : TLI.isFNegFree(VT);

  if (IsFree || N0.getOpcode() != ISD::BITCAST || !N0.hasOneUse())
    return SDValue();

  SDValue Int = N0.getOperand(0);
  EVT IntVT = Int.getValueType();

  // The operand of the cast should be a scalar integer.
  if (!IntVT.isInteger() || IntVT.isVector())
    return SDValue();

  // (fneg (bitconvert x)) -> (bitconvert (xor x sign))
  // (fabs (bitconvert x)) -> (bitconvert (and x ~sign))
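  // For example, for an f32 bitcast from i32: fneg x --> xor x, 0x80000000
  // flips only the sign bit, and fabs x --> and x, 0x7FFFFFFF clears it.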
  APInt SignMask;
  if (N0.getValueType().isVector()) {
    // For vector, create a sign mask (0x80...) or its inverse (for fabs,
    // 0x7f...) per element and splat it.
    SignMask = APInt::getSignMask(N0.getScalarValueSizeInBits());
    if (IsFabs)
      SignMask = ~SignMask;
    SignMask = APInt::getSplat(IntVT.getSizeInBits(), SignMask);
  } else {
    // For scalar, just use the sign mask (0x80... or the inverse, 0x7f...)
    SignMask = APInt::getSignMask(IntVT.getSizeInBits());
    if (IsFabs)
      SignMask = ~SignMask;
  }
  SDLoc DL(N0);
  Int = DAG.getNode(IsFabs ? ISD::AND : ISD::XOR, DL, IntVT, Int,
                    DAG.getConstant(SignMask, DL, IntVT));
  AddToWorklist(Int.getNode());
  return DAG.getBitcast(VT, Int);
}

/// Turn "(a cond b) ? 1.0f : 2.0f" into "load (tmp + ((a cond b) ? 0 : 4))"
/// where "tmp" is a constant pool entry containing an array with 1.0 and 2.0
/// in it. This may be a win when the constant is not otherwise available
/// because it replaces two constant pool loads with one.
SDValue DAGCombiner::convertSelectOfFPConstantsToLoadOffset(
    const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
    ISD::CondCode CC) {
  if (!TLI.reduceSelectOfFPConstantLoads(N0.getValueType()))
    return SDValue();

  // If we are before legalize types, we want the other legalization to happen
  // first (for example, to avoid messing with soft float).
  auto *TV = dyn_cast<ConstantFPSDNode>(N2);
  auto *FV = dyn_cast<ConstantFPSDNode>(N3);
  EVT VT = N2.getValueType();
  if (!TV || !FV || !TLI.isTypeLegal(VT))
    return SDValue();

  // If a constant can be materialized without loads, this does not make sense.
  if (TLI.getOperationAction(ISD::ConstantFP, VT) == TargetLowering::Legal ||
      TLI.isFPImmLegal(TV->getValueAPF(), TV->getValueType(0), ForCodeSize) ||
      TLI.isFPImmLegal(FV->getValueAPF(), FV->getValueType(0), ForCodeSize))
    return SDValue();

  // If both constants have multiple uses, then we won't need to do an extra
  // load. The values are likely around in registers for other users.
  if (!TV->hasOneUse() && !FV->hasOneUse())
    return SDValue();

  Constant *Elts[] = { const_cast<ConstantFP*>(FV->getConstantFPValue()),
                       const_cast<ConstantFP*>(TV->getConstantFPValue()) };
  Type *FPTy = Elts[0]->getType();
  const DataLayout &TD = DAG.getDataLayout();

  // Create a ConstantArray of the two constants.
  Constant *CA = ConstantArray::get(ArrayType::get(FPTy, 2), Elts);
  SDValue CPIdx = DAG.getConstantPool(CA, TLI.getPointerTy(DAG.getDataLayout()),
                                      TD.getPrefTypeAlign(FPTy));
  Align Alignment = cast<ConstantPoolSDNode>(CPIdx)->getAlign();

  // Get offsets to the 0 and 1 elements of the array, so we can select between
  // them.
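  // Note that Elts[0] is the false value and Elts[1] is the true value, so
  // for f32 (element size 4) the select below yields offset 4 when the
  // condition is true and offset 0 otherwise, matching the doc example.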
  SDValue Zero = DAG.getIntPtrConstant(0, DL);
  unsigned EltSize = (unsigned)TD.getTypeAllocSize(Elts[0]->getType());
  SDValue One = DAG.getIntPtrConstant(EltSize, SDLoc(FV));
  SDValue Cond =
      DAG.getSetCC(DL, getSetCCResultType(N0.getValueType()), N0, N1, CC);
  AddToWorklist(Cond.getNode());
  SDValue CstOffset = DAG.getSelect(DL, Zero.getValueType(), Cond, One, Zero);
  AddToWorklist(CstOffset.getNode());
  CPIdx = DAG.getNode(ISD::ADD, DL, CPIdx.getValueType(), CPIdx, CstOffset);
  AddToWorklist(CPIdx.getNode());
  return DAG.getLoad(TV->getValueType(0), DL, DAG.getEntryNode(), CPIdx,
                     MachinePointerInfo::getConstantPool(
                         DAG.getMachineFunction()), Alignment);
}

/// Simplify an expression of the form (N0 cond N1) ? N2 : N3
/// where 'cond' is the comparison specified by CC.
SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
                                      SDValue N2, SDValue N3, ISD::CondCode CC,
                                      bool NotExtCompare) {
  // (x ? y : y) -> y.
  if (N2 == N3) return N2;

  EVT CmpOpVT = N0.getValueType();
  EVT CmpResVT = getSetCCResultType(CmpOpVT);
  EVT VT = N2.getValueType();
  auto *N1C = dyn_cast<ConstantSDNode>(N1.getNode());
  auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
  auto *N3C = dyn_cast<ConstantSDNode>(N3.getNode());

  // Determine if the condition we're dealing with is constant.
  if (SDValue SCC = DAG.FoldSetCC(CmpResVT, N0, N1, CC, DL)) {
    AddToWorklist(SCC.getNode());
    if (auto *SCCC = dyn_cast<ConstantSDNode>(SCC)) {
      // fold select_cc true, x, y -> x
      // fold select_cc false, x, y -> y
      return !(SCCC->isNullValue()) ? N2 : N3;
    }
  }

  if (SDValue V =
          convertSelectOfFPConstantsToLoadOffset(DL, N0, N1, N2, N3, CC))
    return V;

  if (SDValue V = foldSelectCCToShiftAnd(DL, N0, N1, N2, N3, CC))
    return V;

  // fold (select_cc seteq (and x, y), 0, 0, A) -> (and (shr (shl x)) A)
  // where y has a single bit set.
  // Put plainly: we can turn the SELECT_CC into an AND when the condition
  // can be materialized as an all-ones register.  Any single bit-test can
  // be materialized as an all-ones register with shift-left and
  // shift-right-arith.
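  // For example, with i32 x and y == 4 (bit 2):
  //   select_cc seteq (and x, 4), 0, 0, A --> and (sra (shl x, 29), 31), A
  // since shifting left by countLeadingZeros(4) == 29 parks bit 2 in the
  // sign position and the arithmetic shift then smears it into an all-ones
  // or all-zeros mask.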
  if (CC == ISD::SETEQ && N0->getOpcode() == ISD::AND &&
      N0->getValueType(0) == VT && isNullConstant(N1) && isNullConstant(N2)) {
    SDValue AndLHS = N0->getOperand(0);
    auto *ConstAndRHS = dyn_cast<ConstantSDNode>(N0->getOperand(1));
    if (ConstAndRHS && ConstAndRHS->getAPIntValue().countPopulation() == 1) {
      // Shift the tested bit over the sign bit.
      const APInt &AndMask = ConstAndRHS->getAPIntValue();
      unsigned ShCt = AndMask.getBitWidth() - 1;
      if (!TLI.shouldAvoidTransformToShift(VT, ShCt)) {
        SDValue ShlAmt =
          DAG.getConstant(AndMask.countLeadingZeros(), SDLoc(AndLHS),
                          getShiftAmountTy(AndLHS.getValueType()));
        SDValue Shl = DAG.getNode(ISD::SHL, SDLoc(N0), VT, AndLHS, ShlAmt);

        // Now arithmetic right shift it all the way over, so the result is
        // either all-ones, or zero.
        SDValue ShrAmt =
          DAG.getConstant(ShCt, SDLoc(Shl),
                          getShiftAmountTy(Shl.getValueType()));
        SDValue Shr = DAG.getNode(ISD::SRA, SDLoc(N0), VT, Shl, ShrAmt);

        return DAG.getNode(ISD::AND, DL, VT, Shr, N3);
      }
    }
  }

  // fold select C, 16, 0 -> shl C, 4
  bool Fold = N2C && isNullConstant(N3) && N2C->getAPIntValue().isPowerOf2();
  bool Swap = N3C && isNullConstant(N2) && N3C->getAPIntValue().isPowerOf2();

  if ((Fold || Swap) &&
      TLI.getBooleanContents(CmpOpVT) ==
          TargetLowering::ZeroOrOneBooleanContent &&
      (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, CmpOpVT))) {

    if (Swap) {
      CC = ISD::getSetCCInverse(CC, CmpOpVT);
      std::swap(N2C, N3C);
    }

    // If the caller doesn't want us to simplify this into a zext of a compare,
    // don't do it.
    if (NotExtCompare && N2C->isOne())
      return SDValue();

    SDValue Temp, SCC;
    // zext (setcc n0, n1)
    if (LegalTypes) {
      SCC = DAG.getSetCC(DL, CmpResVT, N0, N1, CC);
      if (VT.bitsLT(SCC.getValueType()))
        Temp = DAG.getZeroExtendInReg(SCC, SDLoc(N2), VT);
      else
        Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2), VT, SCC);
    } else {
      SCC = DAG.getSetCC(SDLoc(N0), MVT::i1, N0, N1, CC);
      Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2), VT, SCC);
    }

    AddToWorklist(SCC.getNode());
    AddToWorklist(Temp.getNode());

    if (N2C->isOne())
      return Temp;

    unsigned ShCt = N2C->getAPIntValue().logBase2();
    if (TLI.shouldAvoidTransformToShift(VT, ShCt))
      return SDValue();

    // shl setcc result by log2 n2c
    return DAG.getNode(ISD::SHL, DL, N2.getValueType(), Temp,
                       DAG.getConstant(ShCt, SDLoc(Temp),
                                       getShiftAmountTy(Temp.getValueType())));
  }

  // select_cc seteq X, 0, sizeof(X), ctlz(X) -> ctlz(X)
  // select_cc seteq X, 0, sizeof(X), ctlz_zero_undef(X) -> ctlz(X)
  // select_cc seteq X, 0, sizeof(X), cttz(X) -> cttz(X)
  // select_cc seteq X, 0, sizeof(X), cttz_zero_undef(X) -> cttz(X)
  // select_cc setne X, 0, ctlz(X), sizeof(X) -> ctlz(X)
  // select_cc setne X, 0, ctlz_zero_undef(X), sizeof(X) -> ctlz(X)
  // select_cc setne X, 0, cttz(X), sizeof(X) -> cttz(X)
  // select_cc setne X, 0, cttz_zero_undef(X), sizeof(X) -> cttz(X)
  if (N1C && N1C->isNullValue() && (CC == ISD::SETEQ || CC == ISD::SETNE)) {
    SDValue ValueOnZero = N2;
    SDValue Count = N3;
    // If the condition is NE instead of EQ, swap the operands.
    if (CC == ISD::SETNE)
      std::swap(ValueOnZero, Count);
    // Check if the value on zero is a constant equal to the bits in the type.
    if (auto *ValueOnZeroC = dyn_cast<ConstantSDNode>(ValueOnZero)) {
      if (ValueOnZeroC->getAPIntValue() == VT.getSizeInBits()) {
        // If the other operand is cttz/cttz_zero_undef of N0, and cttz is
        // legal, combine to just cttz.
        if ((Count.getOpcode() == ISD::CTTZ ||
             Count.getOpcode() == ISD::CTTZ_ZERO_UNDEF) &&
            N0 == Count.getOperand(0) &&
            (!LegalOperations || TLI.isOperationLegal(ISD::CTTZ, VT)))
          return DAG.getNode(ISD::CTTZ, DL, VT, N0);
        // If the other operand is ctlz/ctlz_zero_undef of N0, and ctlz is
        // legal, combine to just ctlz.
        if ((Count.getOpcode() == ISD::CTLZ ||
             Count.getOpcode() == ISD::CTLZ_ZERO_UNDEF) &&
            N0 == Count.getOperand(0) &&
            (!LegalOperations || TLI.isOperationLegal(ISD::CTLZ, VT)))
          return DAG.getNode(ISD::CTLZ, DL, VT, N0);
      }
    }
  }

  return SDValue();
}

/// This is a stub for TargetLowering::SimplifySetCC.
SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
                                   ISD::CondCode Cond, const SDLoc &DL,
                                   bool foldBooleans) {
  TargetLowering::DAGCombinerInfo
    DagCombineInfo(DAG, Level, false, this);
  return TLI.SimplifySetCC(VT, N0, N1, Cond, foldBooleans, DagCombineInfo, DL);
}

/// Given an ISD::SDIV node expressing a divide by constant, return
/// a DAG expression that will generate the same value by multiplying
/// by a magic number.
/// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
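/// For example (illustrative), i32 sdiv X, 3 typically lowers to
/// mulhs X, 0x55555556 followed by adding the sign bit of X as the
/// rounding correction, per the Hacker's Delight recipe.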
SDValue DAGCombiner::BuildSDIV(SDNode *N) {
  // When optimizing for minimum size, we don't want to expand a div to a mul
  // and a shift.
  if (DAG.getMachineFunction().getFunction().hasMinSize())
    return SDValue();

  SmallVector<SDNode *, 8> Built;
  if (SDValue S = TLI.BuildSDIV(N, DAG, LegalOperations, Built)) {
    for (SDNode *N : Built)
      AddToWorklist(N);
    return S;
  }

  return SDValue();
}

/// Given an ISD::SDIV node expressing a divide by constant power of 2, return a
/// DAG expression that will generate the same value by right shifting.
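/// For example (illustrative), i32 sdiv X, 4 is typically
///   sra (add X, (srl (sra X, 31), 30)), 2
/// where the bias term adds divisor-1 for negative X so the shift rounds
/// toward zero.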
SDValue DAGCombiner::BuildSDIVPow2(SDNode *N) {
  ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1));
  if (!C)
    return SDValue();

  // Avoid division by zero.
  if (C->isNullValue())
    return SDValue();

  SmallVector<SDNode *, 8> Built;
  if (SDValue S = TLI.BuildSDIVPow2(N, C->getAPIntValue(), DAG, Built)) {
    for (SDNode *N : Built)
      AddToWorklist(N);
    return S;
  }

  return SDValue();
}

/// Given an ISD::UDIV node expressing a divide by constant, return a DAG
/// expression that will generate the same value by multiplying by a magic
/// number.
/// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
SDValue DAGCombiner::BuildUDIV(SDNode *N) {
  // When optimizing for minimum size, we don't want to expand a div to a mul
  // and a shift.
  if (DAG.getMachineFunction().getFunction().hasMinSize())
    return SDValue();

  SmallVector<SDNode *, 8> Built;
  if (SDValue S = TLI.BuildUDIV(N, DAG, LegalOperations, Built)) {
    for (SDNode *N : Built)
      AddToWorklist(N);
    return S;
  }

  return SDValue();
}

/// Determines the LogBase2 value for a non-zero input value using the
/// transform: LogBase2(V) = (EltBits - 1) - ctlz(V).
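/// For example, for i32 V == 16: ctlz(16) == 27, so (32 - 1) - 27 == 4.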
SDValue DAGCombiner::BuildLogBase2(SDValue V, const SDLoc &DL) {
  EVT VT = V.getValueType();
  SDValue Ctlz = DAG.getNode(ISD::CTLZ, DL, VT, V);
  SDValue Base = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
  SDValue LogBase2 = DAG.getNode(ISD::SUB, DL, VT, Base, Ctlz);
  return LogBase2;
}

/// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
/// For the reciprocal, we need to find the zero of the function:
///   F(X) = 1/X - A [which has a zero at X = 1/A]
///     =>
///   X_{i+1} = X_i (2 - A X_i) = X_i + X_i (1 - A X_i) [this second form
///     does not require additional intermediate precision]
/// For the last iteration, put numerator N into it to gain more precision:
///   Result = N X_i + X_i (N - N A X_i)
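/// Convergence is quadratic (the error roughly squares each step), so a
/// hardware estimate good to about 12 bits typically reaches full f32
/// accuracy after a single refinement step.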
SDValue DAGCombiner::BuildDivEstimate(SDValue N, SDValue Op,
                                      SDNodeFlags Flags) {
  if (LegalDAG)
    return SDValue();

  // TODO: Handle half and/or extended types?
  EVT VT = Op.getValueType();
  if (VT.getScalarType() != MVT::f32 && VT.getScalarType() != MVT::f64)
    return SDValue();

  // If estimates are explicitly disabled for this function, we're done.
  MachineFunction &MF = DAG.getMachineFunction();
  int Enabled = TLI.getRecipEstimateDivEnabled(VT, MF);
  if (Enabled == TLI.ReciprocalEstimate::Disabled)
    return SDValue();

  // Estimates may be explicitly enabled for this type with a custom number of
  // refinement steps.
  int Iterations = TLI.getDivRefinementSteps(VT, MF);
  if (SDValue Est = TLI.getRecipEstimate(Op, DAG, Enabled, Iterations)) {
    AddToWorklist(Est.getNode());

    SDLoc DL(Op);
    if (Iterations) {
      SDValue FPOne = DAG.getConstantFP(1.0, DL, VT);

      // Newton iterations: Est = Est + Est (N - Arg * Est)
      // If this is the last iteration, also multiply by the numerator.
      for (int i = 0; i < Iterations; ++i) {
        SDValue MulEst = Est;

        if (i == Iterations - 1) {
          MulEst = DAG.getNode(ISD::FMUL, DL, VT, N, Est, Flags);
          AddToWorklist(MulEst.getNode());
        }

        SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Op, MulEst, Flags);
        AddToWorklist(NewEst.getNode());

        NewEst = DAG.getNode(ISD::FSUB, DL, VT,
                             (i == Iterations - 1 ? N : FPOne), NewEst, Flags);
        AddToWorklist(NewEst.getNode());

        NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags);
        AddToWorklist(NewEst.getNode());

        Est = DAG.getNode(ISD::FADD, DL, VT, MulEst, NewEst, Flags);
        AddToWorklist(Est.getNode());
      }
    } else {
      // If no iterations are available, multiply with N.
      Est = DAG.getNode(ISD::FMUL, DL, VT, Est, N, Flags);
      AddToWorklist(Est.getNode());
    }

    return Est;
  }

  return SDValue();
}

/// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
/// For the reciprocal sqrt, we need to find the zero of the function:
///   F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
///     =>
///   X_{i+1} = X_i (1.5 - A X_i^2 / 2)
/// As a result, we precompute A/2 prior to the iteration loop.
SDValue DAGCombiner::buildSqrtNROneConst(SDValue Arg, SDValue Est,
                                         unsigned Iterations,
                                         SDNodeFlags Flags, bool Reciprocal) {
  EVT VT = Arg.getValueType();
  SDLoc DL(Arg);
  SDValue ThreeHalves = DAG.getConstantFP(1.5, DL, VT);

  // We now need 0.5 * Arg which we can write as (1.5 * Arg - Arg) so that
  // this entire sequence requires only one FP constant.
  SDValue HalfArg = DAG.getNode(ISD::FMUL, DL, VT, ThreeHalves, Arg, Flags);
  HalfArg = DAG.getNode(ISD::FSUB, DL, VT, HalfArg, Arg, Flags);

  // Newton iterations: Est = Est * (1.5 - HalfArg * Est * Est)
  for (unsigned i = 0; i < Iterations; ++i) {
    SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, Est, Flags);
    NewEst = DAG.getNode(ISD::FMUL, DL, VT, HalfArg, NewEst, Flags);
    NewEst = DAG.getNode(ISD::FSUB, DL, VT, ThreeHalves, NewEst, Flags);
    Est = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags);
  }

  // If non-reciprocal square root is requested, multiply the result by Arg.
  if (!Reciprocal)
    Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Arg, Flags);

  return Est;
}

/// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
/// For the reciprocal sqrt, we need to find the zero of the function:
///   F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
///     =>
///   X_{i+1} = (-0.5 * X_i) * (A * X_i * X_i + (-3.0))
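/// This is algebraically identical to X_i * (1.5 - 0.5 * A * X_i^2); the
/// rearranged form needs only the two constants -0.5 and -3.0.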
SDValue DAGCombiner::buildSqrtNRTwoConst(SDValue Arg, SDValue Est,
                                         unsigned Iterations,
                                         SDNodeFlags Flags, bool Reciprocal) {
  EVT VT = Arg.getValueType();
  SDLoc DL(Arg);
  SDValue MinusThree = DAG.getConstantFP(-3.0, DL, VT);
  SDValue MinusHalf = DAG.getConstantFP(-0.5, DL, VT);

  // This routine must enter the loop below to work correctly
  // when (Reciprocal == false).
  assert(Iterations > 0 && "sqrt mode requires at least one iteration");

  // Newton iterations for reciprocal square root:
  // E = (E * -0.5) * ((A * E) * E + -3.0)
  for (unsigned i = 0; i < Iterations; ++i) {
    SDValue AE = DAG.getNode(ISD::FMUL, DL, VT, Arg, Est, Flags);
    SDValue AEE = DAG.getNode(ISD::FMUL, DL, VT, AE, Est, Flags);
    SDValue RHS = DAG.getNode(ISD::FADD, DL, VT, AEE, MinusThree, Flags);

    // When calculating a square root at the last iteration build:
    // S = ((A * E) * -0.5) * ((A * E) * E + -3.0)
    // (notice a common subexpression)
    SDValue LHS;
    if (Reciprocal || (i + 1) < Iterations) {
      // RSQRT: LHS = (E * -0.5)
      LHS = DAG.getNode(ISD::FMUL, DL, VT, Est, MinusHalf, Flags);
    } else {
      // SQRT: LHS = (A * E) * -0.5
      LHS = DAG.getNode(ISD::FMUL, DL, VT, AE, MinusHalf, Flags);
    }

    Est = DAG.getNode(ISD::FMUL, DL, VT, LHS, RHS, Flags);
  }

  return Est;
}

/// Build code to calculate either rsqrt(Op) or sqrt(Op). In the latter case
/// Op*rsqrt(Op) is actually computed, so additional postprocessing is needed if
/// Op can be zero.
SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
                                           bool Reciprocal) {
  if (LegalDAG)
    return SDValue();

  // TODO: Handle half and/or extended types?
  EVT VT = Op.getValueType();
  if (VT.getScalarType() != MVT::f32 && VT.getScalarType() != MVT::f64)
    return SDValue();

  // If estimates are explicitly disabled for this function, we're done.
  MachineFunction &MF = DAG.getMachineFunction();
  int Enabled = TLI.getRecipEstimateSqrtEnabled(VT, MF);
  if (Enabled == TLI.ReciprocalEstimate::Disabled)
    return SDValue();

  // Estimates may be explicitly enabled for this type with a custom number of
  // refinement steps.
  int Iterations = TLI.getSqrtRefinementSteps(VT, MF);

  bool UseOneConstNR = false;
  if (SDValue Est =
      TLI.getSqrtEstimate(Op, DAG, Enabled, Iterations, UseOneConstNR,
                          Reciprocal)) {
    AddToWorklist(Est.getNode());

    if (Iterations)
      Est = UseOneConstNR
            ? buildSqrtNROneConst(Op, Est, Iterations, Flags, Reciprocal)
            : buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal);
    if (!Reciprocal) {
      SDLoc DL(Op);
      // Try the target specific test first.
      SDValue Test = TLI.getSqrtInputTest(Op, DAG, DAG.getDenormalMode(VT));

      // The estimate is now completely wrong if the input was exactly 0.0 or
      // possibly a denormal. Force the answer to 0.0 or the value provided by
      // the target for those cases.
      Est = DAG.getNode(
          Test.getValueType().isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT,
          Test, TLI.getSqrtResultForDenormInput(Op, DAG), Est);
    }
    return Est;
  }

  return SDValue();
}

SDValue DAGCombiner::buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags) {
  return buildSqrtEstimateImpl(Op, Flags, true);
}

SDValue DAGCombiner::buildSqrtEstimate(SDValue Op, SDNodeFlags Flags) {
  return buildSqrtEstimateImpl(Op, Flags, false);
}

/// Return true if there is any possibility that the two addresses overlap.
bool DAGCombiner::isAlias(SDNode *Op0, SDNode *Op1) const {

  struct MemUseCharacteristics {
    bool IsVolatile;
    bool IsAtomic;
    SDValue BasePtr;
    int64_t Offset;
    Optional<int64_t> NumBytes;
    MachineMemOperand *MMO;
  };

  auto getCharacteristics = [](SDNode *N) -> MemUseCharacteristics {
    if (const auto *LSN = dyn_cast<LSBaseSDNode>(N)) {
      int64_t Offset = 0;
      if (auto *C = dyn_cast<ConstantSDNode>(LSN->getOffset()))
        Offset = (LSN->getAddressingMode() == ISD::PRE_INC)
                     ? C->getSExtValue()
                     : (LSN->getAddressingMode() == ISD::PRE_DEC)
                           ? -1 * C->getSExtValue()
                           : 0;
      uint64_t Size =
          MemoryLocation::getSizeOrUnknown(LSN->getMemoryVT().getStoreSize());
      return {LSN->isVolatile(), LSN->isAtomic(), LSN->getBasePtr(),
              Offset /*base offset*/,
              Optional<int64_t>(Size),
              LSN->getMemOperand()};
    }
    if (const auto *LN = dyn_cast<LifetimeSDNode>(N))
      return {false /*isVolatile*/, /*isAtomic*/ false, LN->getOperand(1),
              (LN->hasOffset()) ? LN->getOffset() : 0,
              (LN->hasOffset()) ? Optional<int64_t>(LN->getSize())
                                : Optional<int64_t>(),
              (MachineMemOperand *)nullptr};
    // Default.
    return {false /*isvolatile*/, /*isAtomic*/ false, SDValue(),
            (int64_t)0 /*offset*/,
            Optional<int64_t>() /*size*/, (MachineMemOperand *)nullptr};
  };

  MemUseCharacteristics MUC0 = getCharacteristics(Op0),
                        MUC1 = getCharacteristics(Op1);

  // If they are to the same address, then they must be aliases.
  if (MUC0.BasePtr.getNode() && MUC0.BasePtr == MUC1.BasePtr &&
      MUC0.Offset == MUC1.Offset)
    return true;

  // If they are both volatile then they cannot be reordered.
  if (MUC0.IsVolatile && MUC1.IsVolatile)
    return true;

  // Be conservative about atomics for the moment.
  // TODO: This is way overconservative for unordered atomics (see D66309)
  if (MUC0.IsAtomic && MUC1.IsAtomic)
    return true;

  if (MUC0.MMO && MUC1.MMO) {
    if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
        (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
      return false;
  }

  // Try to prove that there is aliasing, or that there is no aliasing. Either
  // way, we can return now. If nothing can be proved, proceed with more tests.
  bool IsAlias;
  if (BaseIndexOffset::computeAliasing(Op0, MUC0.NumBytes, Op1, MUC1.NumBytes,
                                       DAG, IsAlias))
    return IsAlias;

  // The following all rely on MMO0 and MMO1 being valid. Fail conservatively if
  // either are not known.
  if (!MUC0.MMO || !MUC1.MMO)
    return true;

  // If one operation reads from invariant memory and the other may store, they
  // cannot alias. These should really be checking the equivalent of mayWrite,
  // but it only matters for memory nodes other than load/store.
  if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
      (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
    return false;

  // If the required SrcValue1 and SrcValue2 are known to have relatively large
  // alignment compared to the size and offset of the access, we may be able
  // to prove they do not alias. This check is conservative for now to catch
  // cases created by splitting vector types; it only works when the offsets
  // are multiples of the size of the data.
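  // For example, two 8-byte accesses with 16-byte base alignment at source
  // offsets 0 and 8 land in disjoint halves of every 16-byte window, so
  // they cannot overlap.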
22410   int64_t SrcValOffset0 = MUC0.MMO->getOffset();
22411   int64_t SrcValOffset1 = MUC1.MMO->getOffset();
22412   Align OrigAlignment0 = MUC0.MMO->getBaseAlign();
22413   Align OrigAlignment1 = MUC1.MMO->getBaseAlign();
22414   auto &Size0 = MUC0.NumBytes;
22415   auto &Size1 = MUC1.NumBytes;
22416   if (OrigAlignment0 == OrigAlignment1 && SrcValOffset0 != SrcValOffset1 &&
22417       Size0.hasValue() && Size1.hasValue() && *Size0 == *Size1 &&
22418       OrigAlignment0 > *Size0 && SrcValOffset0 % *Size0 == 0 &&
22419       SrcValOffset1 % *Size1 == 0) {
22420     int64_t OffAlign0 = SrcValOffset0 % OrigAlignment0.value();
22421     int64_t OffAlign1 = SrcValOffset1 % OrigAlignment1.value();
22422 
22423     // There is no overlap between these relatively aligned accesses of
22424     // similar size. Return no alias.
22425     if ((OffAlign0 + *Size0) <= OffAlign1 || (OffAlign1 + *Size1) <= OffAlign0)
22426       return false;
22427   }
22428 
22429   bool UseAA = CombinerGlobalAA.getNumOccurrences() > 0
22430                    ? CombinerGlobalAA
22431                    : DAG.getSubtarget().useAA();
22432 #ifndef NDEBUG
22433   if (CombinerAAOnlyFunc.getNumOccurrences() &&
22434       CombinerAAOnlyFunc != DAG.getMachineFunction().getName())
22435     UseAA = false;
22436 #endif
22437 
22438   if (UseAA && AA && MUC0.MMO->getValue() && MUC1.MMO->getValue() &&
22439       Size0.hasValue() && Size1.hasValue()) {
22440     // Use alias analysis information.
22441     int64_t MinOffset = std::min(SrcValOffset0, SrcValOffset1);
22442     int64_t Overlap0 = *Size0 + SrcValOffset0 - MinOffset;
22443     int64_t Overlap1 = *Size1 + SrcValOffset1 - MinOffset;
    AliasResult AAResult = AA->alias(
        MemoryLocation(MUC0.MMO->getValue(), Overlap0,
                       UseTBAA ? MUC0.MMO->getAAInfo() : AAMDNodes()),
        MemoryLocation(MUC1.MMO->getValue(), Overlap1,
                       UseTBAA ? MUC1.MMO->getAAInfo() : AAMDNodes()));
    if (AAResult == NoAlias)
      return false;
  }

  // Otherwise we have to assume they alias.
  return true;
}

/// Walk up chain skipping non-aliasing memory nodes,
/// looking for aliasing nodes and adding them to the Aliases vector.
void DAGCombiner::GatherAllAliases(SDNode *N, SDValue OriginalChain,
                                   SmallVectorImpl<SDValue> &Aliases) {
  SmallVector<SDValue, 8> Chains;     // List of chains to visit.
  SmallPtrSet<SDNode *, 16> Visited;  // Visited node set.

  // Get alias information for node.
  // TODO: relax aliasing for unordered atomics (see D66309)
  const bool IsLoad = isa<LoadSDNode>(N) && cast<LoadSDNode>(N)->isSimple();

  // Starting off.
  Chains.push_back(OriginalChain);
  unsigned Depth = 0;

  // Attempt to improve chain by a single step
  std::function<bool(SDValue &)> ImproveChain = [&](SDValue &C) -> bool {
    switch (C.getOpcode()) {
    case ISD::EntryToken:
      // No need to mark EntryToken.
      C = SDValue();
      return true;
    case ISD::LOAD:
    case ISD::STORE: {
      // Get alias information for C.
      // TODO: Relax aliasing for unordered atomics (see D66309)
      bool IsOpLoad = isa<LoadSDNode>(C.getNode()) &&
                      cast<LSBaseSDNode>(C.getNode())->isSimple();
      if ((IsLoad && IsOpLoad) || !isAlias(N, C.getNode())) {
        // Look further up the chain.
        C = C.getOperand(0);
        return true;
      }
      // Alias, so stop here.
      return false;
    }

    case ISD::CopyFromReg:
      // Always forward past CopyFromReg.
      C = C.getOperand(0);
      return true;

    case ISD::LIFETIME_START:
    case ISD::LIFETIME_END: {
      // We can forward past any lifetime start/end that can be proven not to
      // alias the memory access.
      if (!isAlias(N, C.getNode())) {
        // Look further up the chain.
        C = C.getOperand(0);
        return true;
      }
      return false;
    }
    default:
      return false;
    }
  };
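
  // As an illustrative (hypothetical) example: for a simple load N whose
  // chain is
  //   N -> store A -> CopyFromReg -> EntryToken
  // where A can be proven not to alias N, repeated ImproveChain steps move
  // the candidate chain past the store and past the CopyFromReg, and finally
  // clear it at EntryToken, so no aliasing node is recorded on this path.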

  // Look at each chain and determine if it is an alias.  If so, add it to the
  // aliases list.  If not, then continue up the chain looking for the next
  // candidate.
  while (!Chains.empty()) {
    SDValue Chain = Chains.pop_back_val();

    // Don't bother if we've seen Chain before.
    if (!Visited.insert(Chain.getNode()).second)
      continue;

    // For TokenFactor nodes, look at each operand and only continue up the
    // chain until we reach the depth limit.
    //
    // FIXME: The depth check could be made to return the last non-aliasing
    // chain we found before we hit a tokenfactor rather than the original
    // chain.
    if (Depth > TLI.getGatherAllAliasesMaxDepth()) {
      Aliases.clear();
      Aliases.push_back(OriginalChain);
      return;
    }
    if (Chain.getOpcode() == ISD::TokenFactor) {
      // We have to check each of the operands of the token factor for "small"
      // token factors, so we queue them up.  Adding the operands to the queue
      // (stack) in reverse order maintains the original order and increases the
      // likelihood that getNode will find a matching token factor (CSE).
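      //
      // E.g., for TokenFactor(A, B, C) the operands are pushed as C, B, A,
      // so they are popped and visited as A, B, C, matching their original
      // order.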
      if (Chain.getNumOperands() > 16) {
        Aliases.push_back(Chain);
        continue;
      }
      for (unsigned n = Chain.getNumOperands(); n;)
        Chains.push_back(Chain.getOperand(--n));
      ++Depth;
      continue;
    }
    // Everything else.
    if (ImproveChain(Chain)) {
      // Updated chain found; consider the new chain if one exists.
      if (Chain.getNode())
        Chains.push_back(Chain);
      ++Depth;
      continue;
    }
    // No improved chain possible; treat as an alias.
    Aliases.push_back(Chain);
  }
}

/// Walk up chain skipping non-aliasing memory nodes, looking for a better chain
/// (aliasing node).
SDValue DAGCombiner::FindBetterChain(SDNode *N, SDValue OldChain) {
  if (OptLevel == CodeGenOpt::None)
    return OldChain;

  // Ops for replacing token factor.
  SmallVector<SDValue, 8> Aliases;

  // Accumulate all the aliases to this node.
  GatherAllAliases(N, OldChain, Aliases);

  // If no operands then chain to entry token.
  if (Aliases.size() == 0)
    return DAG.getEntryNode();

  // If a single operand then chain to it.  We don't need to revisit it.
  if (Aliases.size() == 1)
    return Aliases[0];

  // Construct a custom tailored token factor.
  return DAG.getTokenFactor(SDLoc(N), Aliases);
}

namespace {
// TODO: Replace with std::monostate when we move to C++17.
struct UnitT { } Unit;
bool operator==(const UnitT &, const UnitT &) { return true; }
bool operator!=(const UnitT &, const UnitT &) { return false; }
} // namespace

// This function tries to collect a bunch of potentially interesting
// nodes to improve the chains of, all at once. This might seem
// redundant, as this function gets called when visiting every store
// node, so why not let the work be done on each store as it's visited?
//
// I believe this is mainly important because mergeConsecutiveStores
// is unable to deal with merging stores of different sizes, so unless
// we improve the chains of all the potential candidates up-front
// before running mergeConsecutiveStores, it might only see some of
// the nodes that will eventually be candidates, and then not be able
// to go from a partially-merged state to the desired final
// fully-merged state.
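//
// As an illustration (hypothetical DAG): a chain of disjoint stores
//   St3 -> St2 -> St1 -> Ch
// (each store's chain operand is the previous store) serializes the stores
// even though they never overlap. After parallelizeChainedStores each store
// is chained directly to Ch (or to a better chain found for it) and the
// stores are joined by a single TokenFactor, so later combines can treat
// them as independent.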

bool DAGCombiner::parallelizeChainedStores(StoreSDNode *St) {
  SmallVector<StoreSDNode *, 8> ChainedStores;
  StoreSDNode *STChain = St;
  // Intervals records which offsets from BaseIndex have been covered. In
  // the common case, each store writes to an address immediately adjacent
  // to the previous one, and its interval is thus merged with the previous
  // interval at insertion time.
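  //
  // For example (hypothetical offsets): if St covers [0, 4) and the stores
  // found walking up the chain cover [-4, 0) and then [-8, -4), each new
  // interval is adjacent to the existing one and IntervalMap coalesces them,
  // leaving the single half-open interval [-8, 4).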

  using IMap =
      llvm::IntervalMap<int64_t, UnitT, 8, IntervalMapHalfOpenInfo<int64_t>>;
  IMap::Allocator A;
  IMap Intervals(A);

  // This holds the base pointer, index, and the offset in bytes from the base
  // pointer.
  const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);

  // We must have a base and an offset.
  if (!BasePtr.getBase().getNode())
    return false;

  // Do not handle stores to undef base pointers.
  if (BasePtr.getBase().isUndef())
    return false;

  // BaseIndexOffset assumes that offsets are fixed-size, which
  // is not valid for scalable vectors where the offsets are
  // scaled by `vscale`, so bail out early.
  if (St->getMemoryVT().isScalableVector())
    return false;

  // Add ST's interval.
  Intervals.insert(0, (St->getMemoryVT().getSizeInBits() + 7) / 8, Unit);

  while (StoreSDNode *Chain = dyn_cast<StoreSDNode>(STChain->getChain())) {
    // If the chain has more than one use, then we can't reorder the mem ops.
    if (!SDValue(Chain, 0)->hasOneUse())
      break;
    // TODO: Relax for unordered atomics (see D66309)
    if (!Chain->isSimple() || Chain->isIndexed())
      break;

    // Find the base pointer and offset for this memory node.
    const BaseIndexOffset Ptr = BaseIndexOffset::match(Chain, DAG);
    // Check that the base pointer is the same as the original one.
    int64_t Offset;
    if (!BasePtr.equalBaseIndex(Ptr, DAG, Offset))
      break;
    int64_t Length = (Chain->getMemoryVT().getSizeInBits() + 7) / 8;
    // Make sure we don't overlap with other intervals by checking the ones to
    // the left or right before inserting.
    auto I = Intervals.find(Offset);
    // If there's a next interval, we should end before it.
    if (I != Intervals.end() && I.start() < (Offset + Length))
      break;
    // If there's a previous interval, we should start after it.
    if (I != Intervals.begin() && (--I).stop() <= Offset)
      break;
    Intervals.insert(Offset, Offset + Length, Unit);

    ChainedStores.push_back(Chain);
    STChain = Chain;
  }

  // If we didn't find a chained store, exit.
  if (ChainedStores.size() == 0)
    return false;

  // Improve all chained stores (St and ChainedStores members) starting from
  // where the store chain ended and return a single TokenFactor.
  SDValue NewChain = STChain->getChain();
  SmallVector<SDValue, 8> TFOps;
  for (unsigned I = ChainedStores.size(); I;) {
    StoreSDNode *S = ChainedStores[--I];
    SDValue BetterChain = FindBetterChain(S, NewChain);
    S = cast<StoreSDNode>(DAG.UpdateNodeOperands(
        S, BetterChain, S->getOperand(1), S->getOperand(2), S->getOperand(3)));
    TFOps.push_back(SDValue(S, 0));
    ChainedStores[I] = S;
  }

  // Improve St's chain. Use a new node to avoid creating a loop from CombineTo.
  SDValue BetterChain = FindBetterChain(St, NewChain);
  SDValue NewST;
  if (St->isTruncatingStore())
    NewST = DAG.getTruncStore(BetterChain, SDLoc(St), St->getValue(),
                              St->getBasePtr(), St->getMemoryVT(),
                              St->getMemOperand());
  else
    NewST = DAG.getStore(BetterChain, SDLoc(St), St->getValue(),
                         St->getBasePtr(), St->getMemOperand());

  TFOps.push_back(NewST);

  // If we improved every element of TFOps, then we've lost the dependence on
  // NewChain to successors of St and we need to add it back to TFOps. Do so at
  // the beginning to keep relative order consistent with FindBetterChain.
  auto hasImprovedChain = [&](SDValue ST) -> bool {
    return ST->getOperand(0) != NewChain;
  };
  bool AddNewChain = llvm::all_of(TFOps, hasImprovedChain);
  if (AddNewChain)
    TFOps.insert(TFOps.begin(), NewChain);

  SDValue TF = DAG.getTokenFactor(SDLoc(STChain), TFOps);
  CombineTo(St, TF);

  // Add TF and its operands to the worklist.
  AddToWorklist(TF.getNode());
  for (const SDValue &Op : TF->ops())
    AddToWorklist(Op.getNode());
  AddToWorklist(STChain);
  return true;
}

bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) {
  if (OptLevel == CodeGenOpt::None)
    return false;

  const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);

  // We must have a base and an offset.
  if (!BasePtr.getBase().getNode())
    return false;

  // Do not handle stores to undef base pointers.
  if (BasePtr.getBase().isUndef())
    return false;

  // Directly improve a chain of disjoint stores starting at St.
  if (parallelizeChainedStores(St))
    return true;

  // Improve St's chain.
  SDValue BetterChain = FindBetterChain(St, St->getChain());
  if (St->getChain() != BetterChain) {
    replaceStoreChain(St, BetterChain);
    return true;
  }
  return false;
}

/// This is the entry point for the file.
void SelectionDAG::Combine(CombineLevel Level, AliasAnalysis *AA,
                           CodeGenOpt::Level OptLevel) {
  /// This is the main entry point to this class.
  DAGCombiner(*this, AA, OptLevel).Run(Level);
}
