//===- DAGCombiner.cpp - Implement a DAG node combiner --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass combines dag nodes to form fewer, simpler DAG nodes.  It can be run
// both before and after the DAG is legalized.
//
// This pass is not a substitute for the LLVM IR instcombine pass. This pass is
// primarily intended to handle simplification opportunities that are implicit
// in the LLVM IR and exposed by the various codegen lowering phases.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/DAGCombine.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineMemOperand.h"
#include "llvm/CodeGen/RuntimeLibcalls.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/SelectionDAGTargetInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/MachineValueType.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <iterator>
#include <string>
#include <tuple>
#include <utility>

using namespace llvm;

#define DEBUG_TYPE "dagcombine"

STATISTIC(NodesCombined   , "Number of dag nodes combined");
STATISTIC(PreIndexedNodes , "Number of pre-indexed nodes created");
STATISTIC(PostIndexedNodes, "Number of post-indexed nodes created");
STATISTIC(OpsNarrowed     , "Number of load/op/store narrowed");
STATISTIC(LdStFP2Int      , "Number of fp load/store pairs transformed to int");
STATISTIC(SlicedLoads, "Number of loads sliced");
STATISTIC(NumFPLogicOpsConv, "Number of logic ops converted to fp ops");

static cl::opt<bool>
CombinerGlobalAA("combiner-global-alias-analysis", cl::Hidden,
                 cl::desc("Enable DAG combiner's use of IR alias analysis"));

static cl::opt<bool>
UseTBAA("combiner-use-tbaa", cl::Hidden, cl::init(true),
        cl::desc("Enable DAG combiner's use of TBAA"));

#ifndef NDEBUG
static cl::opt<std::string>
CombinerAAOnlyFunc("combiner-aa-only-func", cl::Hidden,
                   cl::desc("Only use DAG-combiner alias analysis in this"
                            " function"));
#endif

/// Hidden option to stress test load slicing, i.e., when this option
/// is enabled, load slicing bypasses most of its profitability guards.
static cl::opt<bool>
StressLoadSlicing("combiner-stress-load-slicing", cl::Hidden,
                  cl::desc("Bypass the profitability model of load slicing"),
                  cl::init(false));

static cl::opt<bool>
  MaySplitLoadIndex("combiner-split-load-index", cl::Hidden, cl::init(true),
                    cl::desc("DAG combiner may split indexing from loads"));

static cl::opt<bool>
    EnableStoreMerging("combiner-store-merging", cl::Hidden, cl::init(true),
                       cl::desc("DAG combiner enable merging multiple stores "
                                "into a wider store"));

static cl::opt<unsigned> TokenFactorInlineLimit(
    "combiner-tokenfactor-inline-limit", cl::Hidden, cl::init(2048),
    cl::desc("Limit the number of operands to inline for Token Factors"));

static cl::opt<unsigned> StoreMergeDependenceLimit(
    "combiner-store-merge-dependence-limit", cl::Hidden, cl::init(10),
    cl::desc("Limit the number of times for the same StoreNode and RootNode "
             "to bail out in store merging dependence check"));

static cl::opt<bool> EnableReduceLoadOpStoreWidth(
    "combiner-reduce-load-op-store-width", cl::Hidden, cl::init(true),
    cl::desc("DAG combiner enable reducing the width of load/op/store "
             "sequence"));

static cl::opt<bool> EnableShrinkLoadReplaceStoreWithStore(
    "combiner-shrink-load-replace-store-with-store", cl::Hidden, cl::init(true),
    cl::desc("DAG combiner enable load/<replace bytes>/store with "
             "a narrower store"));
namespace {

  class DAGCombiner {
    SelectionDAG &DAG;
    const TargetLowering &TLI;
    const SelectionDAGTargetInfo *STI;
    CombineLevel Level;
    CodeGenOpt::Level OptLevel;
    bool LegalDAG = false;
    bool LegalOperations = false;
    bool LegalTypes = false;
    bool ForCodeSize;
    bool DisableGenericCombines;

    /// Worklist of all of the nodes that need to be simplified.
    ///
    /// This must behave as a stack -- new nodes to process are pushed onto the
    /// back and when processing we pop off of the back.
    ///
    /// The worklist will not contain duplicates but may contain null entries
    /// due to nodes being deleted from the underlying DAG.
    SmallVector<SDNode *, 64> Worklist;

    /// Mapping from an SDNode to its position on the worklist.
    ///
    /// This is used to find and remove nodes from the worklist (by nulling
    /// them) when they are deleted from the underlying DAG. It relies on
    /// stable indices of nodes within the worklist.
    DenseMap<SDNode *, unsigned> WorklistMap;
    /// This records all nodes attempted to be added to the worklist since we
    /// last considered a new worklist entry. Because we do not add duplicate
    /// nodes to the worklist, this is different from the tail of the worklist.
    SmallSetVector<SDNode *, 32> PruningList;

    /// Set of nodes which have been combined (at least once).
    ///
    /// This is used to allow us to reliably add any operands of a DAG node
    /// which have not yet been combined to the worklist.
    SmallPtrSet<SDNode *, 32> CombinedNodes;

    /// Map from candidate StoreNode to the pair of RootNode and count.
    /// The count is used to track how many times we have seen the StoreNode
    /// with the same RootNode bail out in dependence check. If we have seen
    /// the bail out for the same pair many times over a limit, we won't
    /// consider the StoreNode with the same RootNode as store merging
    /// candidate again.
    DenseMap<SDNode *, std::pair<SDNode *, unsigned>> StoreRootCountMap;

    // AA - Used for DAG load/store alias analysis.
    AliasAnalysis *AA;

    /// When an instruction is simplified, add all users of the instruction to
    /// the work lists because they might get more simplified now.
    void AddUsersToWorklist(SDNode *N) {
      for (SDNode *Node : N->uses())
        AddToWorklist(Node);
    }

    /// Convenient shorthand to add a node and all of its users to the
    /// worklist.
    void AddToWorklistWithUsers(SDNode *N) {
      AddUsersToWorklist(N);
      AddToWorklist(N);
    }

    // Prune potentially dangling nodes. This is called after
    // any visit to a node, but should also be called during a visit after any
    // failed combine which may have created a DAG node.
    void clearAddedDanglingWorklistEntries() {
      // Check any nodes added to the worklist to see if they are prunable.
      while (!PruningList.empty()) {
        auto *N = PruningList.pop_back_val();
        if (N->use_empty())
          recursivelyDeleteUnusedNodes(N);
      }
    }

    SDNode *getNextWorklistEntry() {
      // Before we do any work, remove nodes that are not in use.
      clearAddedDanglingWorklistEntries();
      SDNode *N = nullptr;
      // The Worklist holds the SDNodes in order, but it may contain null
      // entries.
      while (!N && !Worklist.empty()) {
        N = Worklist.pop_back_val();
      }

      if (N) {
        bool GoodWorklistEntry = WorklistMap.erase(N);
        (void)GoodWorklistEntry;
        assert(GoodWorklistEntry &&
               "Found a worklist entry without a corresponding map entry!");
      }
      return N;
    }

    /// Call the node-specific routine that folds each particular type of node.
    SDValue visit(SDNode *N);

  public:
    DAGCombiner(SelectionDAG &D, AliasAnalysis *AA, CodeGenOpt::Level OL)
        : DAG(D), TLI(D.getTargetLoweringInfo()),
          STI(D.getSubtarget().getSelectionDAGInfo()),
          Level(BeforeLegalizeTypes), OptLevel(OL), AA(AA) {
      ForCodeSize = DAG.shouldOptForSize();
      DisableGenericCombines = STI && STI->disableGenericCombines(OptLevel);

      MaximumLegalStoreInBits = 0;
      // We use the minimum store size here, since that's all we can guarantee
      // for the scalable vector types.
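      // For example, for <vscale x 4 x i32> only the known minimum size of
      // 128 bits is guaranteed at compile time; the actual size may be any
      // multiple of that.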
      for (MVT VT : MVT::all_valuetypes())
        if (EVT(VT).isSimple() && VT != MVT::Other &&
            TLI.isTypeLegal(EVT(VT)) &&
            VT.getSizeInBits().getKnownMinSize() >= MaximumLegalStoreInBits)
          MaximumLegalStoreInBits = VT.getSizeInBits().getKnownMinSize();
    }

    void ConsiderForPruning(SDNode *N) {
      // Mark this for potential pruning.
      PruningList.insert(N);
    }

    /// Add to the worklist making sure its instance is at the back (next to be
    /// processed).
    void AddToWorklist(SDNode *N) {
      assert(N->getOpcode() != ISD::DELETED_NODE &&
             "Deleted Node added to Worklist");

      // Skip handle nodes as they can't usefully be combined and confuse the
      // zero-use deletion strategy.
      if (N->getOpcode() == ISD::HANDLENODE)
        return;

      ConsiderForPruning(N);

      if (WorklistMap.insert(std::make_pair(N, Worklist.size())).second)
        Worklist.push_back(N);
    }

    /// Remove all instances of N from the worklist.
    void removeFromWorklist(SDNode *N) {
      CombinedNodes.erase(N);
      PruningList.remove(N);
      StoreRootCountMap.erase(N);

      auto It = WorklistMap.find(N);
      if (It == WorklistMap.end())
        return; // Not in the worklist.

      // Null out the entry rather than erasing it to avoid a linear operation.
      Worklist[It->second] = nullptr;
      WorklistMap.erase(It);
    }

    void deleteAndRecombine(SDNode *N);
    bool recursivelyDeleteUnusedNodes(SDNode *N);

    /// Replaces all uses of the results of one DAG node with new values.
    SDValue CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
                      bool AddTo = true);

    /// Replaces all uses of the results of one DAG node with new values.
    SDValue CombineTo(SDNode *N, SDValue Res, bool AddTo = true) {
      return CombineTo(N, &Res, 1, AddTo);
    }

    /// Replaces all uses of the results of one DAG node with new values.
    SDValue CombineTo(SDNode *N, SDValue Res0, SDValue Res1,
                      bool AddTo = true) {
      SDValue To[] = { Res0, Res1 };
      return CombineTo(N, To, 2, AddTo);
    }

    void CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO);

  private:
    unsigned MaximumLegalStoreInBits;

    /// Check the specified integer node value to see if it can be simplified or
    /// if things it uses can be simplified by bit propagation.
    /// If so, return true.
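    /// For example, if \p Op is (and X, 0xFF), only the low 8 bits of X are
    /// demanded, so nodes that only compute the high bits of X may be
    /// simplified away by TLI.SimplifyDemandedBits.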
    bool SimplifyDemandedBits(SDValue Op) {
      unsigned BitWidth = Op.getScalarValueSizeInBits();
      APInt DemandedBits = APInt::getAllOnesValue(BitWidth);
      return SimplifyDemandedBits(Op, DemandedBits);
    }

    bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits) {
      TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
      KnownBits Known;
      if (!TLI.SimplifyDemandedBits(Op, DemandedBits, Known, TLO, 0, false))
        return false;

      // Revisit the node.
      AddToWorklist(Op.getNode());

      CommitTargetLoweringOpt(TLO);
      return true;
    }

    /// Check the specified vector node value to see if it can be simplified or
    /// if things it uses can be simplified as it only uses some of the
    /// elements. If so, return true.
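    /// For example, if \p Op is (extract_vector_elt V, 0), only element 0 of
    /// V is demanded, and work feeding the remaining lanes may be removed.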
    bool SimplifyDemandedVectorElts(SDValue Op) {
      // TODO: For now just pretend it cannot be simplified.
      if (Op.getValueType().isScalableVector())
        return false;

      unsigned NumElts = Op.getValueType().getVectorNumElements();
      APInt DemandedElts = APInt::getAllOnesValue(NumElts);
      return SimplifyDemandedVectorElts(Op, DemandedElts);
    }

    bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
                              const APInt &DemandedElts,
                              bool AssumeSingleUse = false);
    bool SimplifyDemandedVectorElts(SDValue Op, const APInt &DemandedElts,
                                    bool AssumeSingleUse = false);

    bool CombineToPreIndexedLoadStore(SDNode *N);
    bool CombineToPostIndexedLoadStore(SDNode *N);
    SDValue SplitIndexingFromLoad(LoadSDNode *LD);
    bool SliceUpLoad(SDNode *N);

    // Scalars have size 0 to distinguish from singleton vectors.
    SDValue ForwardStoreValueToDirectLoad(LoadSDNode *LD);
    bool getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val);
    bool extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val);

    /// Replace an ISD::EXTRACT_VECTOR_ELT of a load with a narrowed
    ///   load.
    ///
    /// \param EVE ISD::EXTRACT_VECTOR_ELT to be replaced.
    /// \param InVecVT type of the input vector to EVE with bitcasts resolved.
    /// \param EltNo index of the vector element to load.
    /// \param OriginalLoad load that EVE came from to be replaced.
    /// \returns EVE on success, SDValue() on failure.
    SDValue scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
                                         SDValue EltNo,
                                         LoadSDNode *OriginalLoad);
    void ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad);
    SDValue PromoteOperand(SDValue Op, EVT PVT, bool &Replace);
    SDValue SExtPromoteOperand(SDValue Op, EVT PVT);
    SDValue ZExtPromoteOperand(SDValue Op, EVT PVT);
    SDValue PromoteIntBinOp(SDValue Op);
    SDValue PromoteIntShiftOp(SDValue Op);
    SDValue PromoteExtend(SDValue Op);
    bool PromoteLoad(SDValue Op);

    /// Call the node-specific routine that knows how to fold each
    /// particular type of node. If that doesn't do anything, try the
    /// target-specific DAG combines.
    SDValue combine(SDNode *N);

    // Visitation implementation - Implement dag node combining for different
    // node types.  The semantics are as follows:
    // Return Value:
    //   SDValue.getNode() == 0 - No change was made
    //   SDValue.getNode() == N - N was replaced, is dead and has been handled.
    //   otherwise              - N should be replaced by the returned Operand.
    //
    SDValue visitTokenFactor(SDNode *N);
    SDValue visitMERGE_VALUES(SDNode *N);
    SDValue visitADD(SDNode *N);
    SDValue visitADDLike(SDNode *N);
    SDValue visitADDLikeCommutative(SDValue N0, SDValue N1, SDNode *LocReference);
    SDValue visitSUB(SDNode *N);
    SDValue visitADDSAT(SDNode *N);
    SDValue visitSUBSAT(SDNode *N);
    SDValue visitADDC(SDNode *N);
    SDValue visitADDO(SDNode *N);
    SDValue visitUADDOLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitSUBC(SDNode *N);
    SDValue visitSUBO(SDNode *N);
    SDValue visitADDE(SDNode *N);
    SDValue visitADDCARRY(SDNode *N);
    SDValue visitSADDO_CARRY(SDNode *N);
    SDValue visitADDCARRYLike(SDValue N0, SDValue N1, SDValue CarryIn, SDNode *N);
    SDValue visitSUBE(SDNode *N);
    SDValue visitSUBCARRY(SDNode *N);
    SDValue visitSSUBO_CARRY(SDNode *N);
    SDValue visitMUL(SDNode *N);
    SDValue visitMULFIX(SDNode *N);
    SDValue useDivRem(SDNode *N);
    SDValue visitSDIV(SDNode *N);
    SDValue visitSDIVLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitUDIV(SDNode *N);
    SDValue visitUDIVLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitREM(SDNode *N);
    SDValue visitMULHU(SDNode *N);
    SDValue visitMULHS(SDNode *N);
    SDValue visitSMUL_LOHI(SDNode *N);
    SDValue visitUMUL_LOHI(SDNode *N);
    SDValue visitMULO(SDNode *N);
    SDValue visitIMINMAX(SDNode *N);
    SDValue visitAND(SDNode *N);
    SDValue visitANDLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitOR(SDNode *N);
    SDValue visitORLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitXOR(SDNode *N);
    SDValue SimplifyVBinOp(SDNode *N);
    SDValue visitSHL(SDNode *N);
    SDValue visitSRA(SDNode *N);
    SDValue visitSRL(SDNode *N);
    SDValue visitFunnelShift(SDNode *N);
    SDValue visitRotate(SDNode *N);
    SDValue visitABS(SDNode *N);
    SDValue visitBSWAP(SDNode *N);
    SDValue visitBITREVERSE(SDNode *N);
    SDValue visitCTLZ(SDNode *N);
    SDValue visitCTLZ_ZERO_UNDEF(SDNode *N);
    SDValue visitCTTZ(SDNode *N);
    SDValue visitCTTZ_ZERO_UNDEF(SDNode *N);
    SDValue visitCTPOP(SDNode *N);
    SDValue visitSELECT(SDNode *N);
    SDValue visitVSELECT(SDNode *N);
    SDValue visitSELECT_CC(SDNode *N);
    SDValue visitSETCC(SDNode *N);
    SDValue visitSETCCCARRY(SDNode *N);
    SDValue visitSIGN_EXTEND(SDNode *N);
    SDValue visitZERO_EXTEND(SDNode *N);
    SDValue visitANY_EXTEND(SDNode *N);
    SDValue visitAssertExt(SDNode *N);
    SDValue visitAssertAlign(SDNode *N);
    SDValue visitSIGN_EXTEND_INREG(SDNode *N);
    SDValue visitEXTEND_VECTOR_INREG(SDNode *N);
    SDValue visitTRUNCATE(SDNode *N);
    SDValue visitBITCAST(SDNode *N);
    SDValue visitFREEZE(SDNode *N);
    SDValue visitBUILD_PAIR(SDNode *N);
    SDValue visitFADD(SDNode *N);
    SDValue visitSTRICT_FADD(SDNode *N);
    SDValue visitFSUB(SDNode *N);
    SDValue visitFMUL(SDNode *N);
    SDValue visitFMA(SDNode *N);
    SDValue visitFDIV(SDNode *N);
    SDValue visitFREM(SDNode *N);
    SDValue visitFSQRT(SDNode *N);
    SDValue visitFCOPYSIGN(SDNode *N);
    SDValue visitFPOW(SDNode *N);
    SDValue visitSINT_TO_FP(SDNode *N);
    SDValue visitUINT_TO_FP(SDNode *N);
    SDValue visitFP_TO_SINT(SDNode *N);
    SDValue visitFP_TO_UINT(SDNode *N);
    SDValue visitFP_ROUND(SDNode *N);
    SDValue visitFP_EXTEND(SDNode *N);
    SDValue visitFNEG(SDNode *N);
    SDValue visitFABS(SDNode *N);
    SDValue visitFCEIL(SDNode *N);
    SDValue visitFTRUNC(SDNode *N);
    SDValue visitFFLOOR(SDNode *N);
    SDValue visitFMINNUM(SDNode *N);
    SDValue visitFMAXNUM(SDNode *N);
    SDValue visitFMINIMUM(SDNode *N);
    SDValue visitFMAXIMUM(SDNode *N);
    SDValue visitBRCOND(SDNode *N);
    SDValue visitBR_CC(SDNode *N);
    SDValue visitLOAD(SDNode *N);

    SDValue replaceStoreChain(StoreSDNode *ST, SDValue BetterChain);
    SDValue replaceStoreOfFPConstant(StoreSDNode *ST);

    SDValue visitSTORE(SDNode *N);
    SDValue visitLIFETIME_END(SDNode *N);
    SDValue visitINSERT_VECTOR_ELT(SDNode *N);
    SDValue visitEXTRACT_VECTOR_ELT(SDNode *N);
    SDValue visitBUILD_VECTOR(SDNode *N);
    SDValue visitCONCAT_VECTORS(SDNode *N);
    SDValue visitEXTRACT_SUBVECTOR(SDNode *N);
    SDValue visitVECTOR_SHUFFLE(SDNode *N);
    SDValue visitSCALAR_TO_VECTOR(SDNode *N);
    SDValue visitINSERT_SUBVECTOR(SDNode *N);
    SDValue visitMLOAD(SDNode *N);
    SDValue visitMSTORE(SDNode *N);
    SDValue visitMGATHER(SDNode *N);
    SDValue visitMSCATTER(SDNode *N);
    SDValue visitFP_TO_FP16(SDNode *N);
    SDValue visitFP16_TO_FP(SDNode *N);
    SDValue visitVECREDUCE(SDNode *N);

    SDValue visitFADDForFMACombine(SDNode *N);
    SDValue visitFSUBForFMACombine(SDNode *N);
    SDValue visitFMULForFMADistributiveCombine(SDNode *N);

    SDValue XformToShuffleWithZero(SDNode *N);
    bool reassociationCanBreakAddressingModePattern(unsigned Opc,
                                                    const SDLoc &DL, SDValue N0,
                                                    SDValue N1);
    SDValue reassociateOpsCommutative(unsigned Opc, const SDLoc &DL, SDValue N0,
                                      SDValue N1);
    SDValue reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
                           SDValue N1, SDNodeFlags Flags);

    SDValue visitShiftByConstant(SDNode *N);

    SDValue foldSelectOfConstants(SDNode *N);
    SDValue foldVSelectOfConstants(SDNode *N);
    SDValue foldBinOpIntoSelect(SDNode *BO);
    bool SimplifySelectOps(SDNode *SELECT, SDValue LHS, SDValue RHS);
    SDValue hoistLogicOpWithSameOpcodeHands(SDNode *N);
    SDValue SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2);
    SDValue SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
                             SDValue N2, SDValue N3, ISD::CondCode CC,
                             bool NotExtCompare = false);
    SDValue convertSelectOfFPConstantsToLoadOffset(
        const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
        ISD::CondCode CC);
    SDValue foldSignChangeInBitcast(SDNode *N);
    SDValue foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0, SDValue N1,
                                   SDValue N2, SDValue N3, ISD::CondCode CC);
    SDValue foldSelectOfBinops(SDNode *N);
    SDValue foldSextSetcc(SDNode *N);
    SDValue foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
                              const SDLoc &DL);
    SDValue foldSubToUSubSat(EVT DstVT, SDNode *N);
    SDValue unfoldMaskedMerge(SDNode *N);
    SDValue unfoldExtremeBitClearingToShifts(SDNode *N);
    SDValue SimplifySetCC(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
                          const SDLoc &DL, bool foldBooleans);
    SDValue rebuildSetCC(SDValue N);

    bool isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
                           SDValue &CC, bool MatchStrict = false) const;
    bool isOneUseSetCC(SDValue N) const;

    SDValue SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
                                       unsigned HiOp);
    SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
    SDValue CombineExtLoad(SDNode *N);
    SDValue CombineZExtLogicopShiftLoad(SDNode *N);
    SDValue combineRepeatedFPDivisors(SDNode *N);
    SDValue combineInsertEltToShuffle(SDNode *N, unsigned InsIndex);
    SDValue ConstantFoldBITCASTofBUILD_VECTOR(SDNode *, EVT);
    SDValue BuildSDIV(SDNode *N);
    SDValue BuildSDIVPow2(SDNode *N);
    SDValue BuildUDIV(SDNode *N);
    SDValue BuildLogBase2(SDValue V, const SDLoc &DL);
    SDValue BuildDivEstimate(SDValue N, SDValue Op, SDNodeFlags Flags);
    SDValue buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags);
    SDValue buildSqrtEstimate(SDValue Op, SDNodeFlags Flags);
    SDValue buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags, bool Recip);
    SDValue buildSqrtNROneConst(SDValue Arg, SDValue Est, unsigned Iterations,
                                SDNodeFlags Flags, bool Reciprocal);
    SDValue buildSqrtNRTwoConst(SDValue Arg, SDValue Est, unsigned Iterations,
                                SDNodeFlags Flags, bool Reciprocal);
    SDValue MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
                               bool DemandHighBits = true);
    SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
    SDValue MatchRotatePosNeg(SDValue Shifted, SDValue Pos, SDValue Neg,
                              SDValue InnerPos, SDValue InnerNeg,
                              unsigned PosOpcode, unsigned NegOpcode,
                              const SDLoc &DL);
    SDValue MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos, SDValue Neg,
                              SDValue InnerPos, SDValue InnerNeg,
                              unsigned PosOpcode, unsigned NegOpcode,
                              const SDLoc &DL);
    SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
    SDValue MatchLoadCombine(SDNode *N);
    SDValue mergeTruncStores(StoreSDNode *N);
    SDValue ReduceLoadWidth(SDNode *N);
    SDValue ReduceLoadOpStoreWidth(SDNode *N);
    SDValue splitMergedValStore(StoreSDNode *ST);
    SDValue TransformFPLoadStorePair(SDNode *N);
    SDValue convertBuildVecZextToZext(SDNode *N);
    SDValue reduceBuildVecExtToExtBuildVec(SDNode *N);
    SDValue reduceBuildVecTruncToBitCast(SDNode *N);
    SDValue reduceBuildVecToShuffle(SDNode *N);
    SDValue createBuildVecShuffle(const SDLoc &DL, SDNode *N,
                                  ArrayRef<int> VectorMask, SDValue VecIn1,
                                  SDValue VecIn2, unsigned LeftIdx,
                                  bool DidSplitVec);
    SDValue matchVSelectOpSizesWithSetCC(SDNode *Cast);

    /// Walk up chain skipping non-aliasing memory nodes,
    /// looking for aliasing nodes and adding them to the Aliases vector.
    void GatherAllAliases(SDNode *N, SDValue OriginalChain,
                          SmallVectorImpl<SDValue> &Aliases);

    /// Return true if there is any possibility that the two addresses overlap.
    bool isAlias(SDNode *Op0, SDNode *Op1) const;

    /// Walk up chain skipping non-aliasing memory nodes, looking for a better
    /// chain (aliasing node.)
    SDValue FindBetterChain(SDNode *N, SDValue Chain);

    /// Try to replace a store and any possibly adjacent stores on
    /// consecutive chains with better chains. Return true only if St is
    /// replaced.
    ///
    /// Notice that other chains may still be replaced even if the function
    /// returns false.
    bool findBetterNeighborChains(StoreSDNode *St);

    // Helper for findBetterNeighborChains. Walk up the store chain and add
    // additional chained stores that do not overlap and can be parallelized.
    bool parallelizeChainedStores(StoreSDNode *St);

    /// Holds a pointer to an LSBaseSDNode as well as information on where it
    /// is located in a sequence of memory operations connected by a chain.
    struct MemOpLink {
      // Ptr to the mem node.
      LSBaseSDNode *MemNode;

      // Offset from the base ptr.
      int64_t OffsetFromBase;

      MemOpLink(LSBaseSDNode *N, int64_t Offset)
          : MemNode(N), OffsetFromBase(Offset) {}
    };

    // Classify the origin of a stored value.
    enum class StoreSource { Unknown, Constant, Extract, Load };
    StoreSource getStoreSource(SDValue StoreVal) {
      switch (StoreVal.getOpcode()) {
      case ISD::Constant:
      case ISD::ConstantFP:
        return StoreSource::Constant;
      case ISD::EXTRACT_VECTOR_ELT:
      case ISD::EXTRACT_SUBVECTOR:
        return StoreSource::Extract;
      case ISD::LOAD:
        return StoreSource::Load;
      default:
        return StoreSource::Unknown;
      }
    }

    /// This is a helper function for visitMUL to check the profitability
    /// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
    /// MulNode is the original multiply, AddNode is (add x, c1),
    /// and ConstNode is c2.
    bool isMulAddWithConstProfitable(SDNode *MulNode,
                                     SDValue &AddNode,
                                     SDValue &ConstNode);

    /// This is a helper function for visitAND and visitZERO_EXTEND.  Returns
    /// true if the (and (load x) c) pattern matches an extload.  ExtVT returns
    /// the type of the loaded value to be extended.
    bool isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
                          EVT LoadResultTy, EVT &ExtVT);

    /// Helper function to calculate whether the given Load/Store can have its
    /// width reduced to ExtVT.
    bool isLegalNarrowLdSt(LSBaseSDNode *LDSTN, ISD::LoadExtType ExtType,
                           EVT &MemVT, unsigned ShAmt = 0);

    /// Used by BackwardsPropagateMask to find suitable loads.
    bool SearchForAndLoads(SDNode *N, SmallVectorImpl<LoadSDNode*> &Loads,
                           SmallPtrSetImpl<SDNode*> &NodesWithConsts,
                           ConstantSDNode *Mask, SDNode *&NodeToMask);
    /// Attempt to propagate a given AND node back to load leaves so that they
    /// can be combined into narrow loads.
    bool BackwardsPropagateMask(SDNode *N);

    /// Helper function for mergeConsecutiveStores which merges the component
    /// store chains.
    SDValue getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
                                unsigned NumStores);

    /// This is a helper function for mergeConsecutiveStores. When the source
    /// elements of the consecutive stores are all constants or all extracted
    /// vector elements, try to merge them into one larger store introducing
    /// bitcasts if necessary.  \return True if a merged store was created.
    bool mergeStoresOfConstantsOrVecElts(SmallVectorImpl<MemOpLink> &StoreNodes,
                                         EVT MemVT, unsigned NumStores,
                                         bool IsConstantSrc, bool UseVector,
                                         bool UseTrunc);

    /// This is a helper function for mergeConsecutiveStores. Stores that
    /// potentially may be merged with St are placed in StoreNodes. RootNode is
    /// a chain predecessor to all store candidates.
    void getStoreMergeCandidates(StoreSDNode *St,
                                 SmallVectorImpl<MemOpLink> &StoreNodes,
                                 SDNode *&Root);

    /// Helper function for mergeConsecutiveStores. Checks if candidate stores
    /// have indirect dependency through their operands. RootNode is the
    /// predecessor to all stores calculated by getStoreMergeCandidates and is
    /// used to prune the dependency check. \return True if safe to merge.
    bool checkMergeStoreCandidatesForDependencies(
        SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
        SDNode *RootNode);

    /// This is a helper function for mergeConsecutiveStores. Given a list of
    /// store candidates, find the first N that are consecutive in memory.
    /// Returns 0 if there are not at least 2 consecutive stores to try merging.
    unsigned getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
                                  int64_t ElementSizeBytes) const;

    /// This is a helper function for mergeConsecutiveStores. It is used for
    /// store chains that are composed entirely of constant values.
    bool tryStoreMergeOfConstants(SmallVectorImpl<MemOpLink> &StoreNodes,
                                  unsigned NumConsecutiveStores,
                                  EVT MemVT, SDNode *Root, bool AllowVectors);

    /// This is a helper function for mergeConsecutiveStores. It is used for
    /// store chains that are composed entirely of extracted vector elements.
    /// When extracting multiple vector elements, try to store them in one
    /// vector store rather than a sequence of scalar stores.
    bool tryStoreMergeOfExtracts(SmallVectorImpl<MemOpLink> &StoreNodes,
                                 unsigned NumConsecutiveStores, EVT MemVT,
                                 SDNode *Root);

    /// This is a helper function for mergeConsecutiveStores. It is used for
    /// store chains that are composed entirely of loaded values.
    bool tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
                              unsigned NumConsecutiveStores, EVT MemVT,
                              SDNode *Root, bool AllowVectors,
                              bool IsNonTemporalStore, bool IsNonTemporalLoad);

    /// Merge consecutive store operations into a wide store.
    /// This optimization uses wide integers or vectors when possible.
    /// \return true if stores were merged.
    bool mergeConsecutiveStores(StoreSDNode *St);

    /// Try to transform a truncation where C is a constant:
    ///     (trunc (and X, C)) -> (and (trunc X), (trunc C))
    ///
    /// \p N needs to be a truncation and its first operand an AND. Other
    /// requirements are checked by the function (e.g. that trunc is
    /// single-use) and if they are not met, an empty SDValue is returned.
    SDValue distributeTruncateThroughAnd(SDNode *N);

    /// Helper function to determine whether the target supports operation
    /// given by \p Opcode for type \p VT, that is, whether the operation
    /// is legal or custom before legalizing operations, and whether it is
    /// legal (but not custom) after legalization.
    bool hasOperation(unsigned Opcode, EVT VT) {
      return TLI.isOperationLegalOrCustom(Opcode, VT, LegalOperations);
    }

  public:
    /// Runs the dag combiner on all nodes in the work list.
    void Run(CombineLevel AtLevel);

    SelectionDAG &getDAG() const { return DAG; }

    /// Returns a type large enough to hold any valid shift amount - before type
    /// legalization these can be huge.
    EVT getShiftAmountTy(EVT LHSTy) {
      assert(LHSTy.isInteger() && "Shift amount is not an integer type!");
      return TLI.getShiftAmountTy(LHSTy, DAG.getDataLayout(), LegalTypes);
    }

    /// This method returns true if we are running before type legalization or
    /// if the specified VT is legal.
    bool isTypeLegal(const EVT &VT) {
      if (!LegalTypes) return true;
      return TLI.isTypeLegal(VT);
    }

    /// Convenience wrapper around TargetLowering::getSetCCResultType
    EVT getSetCCResultType(EVT VT) const {
      return TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
    }

    void ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
                         SDValue OrigLoad, SDValue ExtLoad,
                         ISD::NodeType ExtType);
  };

/// This class is a DAGUpdateListener that removes any deleted
/// nodes from the worklist.
class WorklistRemover : public SelectionDAG::DAGUpdateListener {
  DAGCombiner &DC;

public:
  explicit WorklistRemover(DAGCombiner &dc)
    : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}

  void NodeDeleted(SDNode *N, SDNode *E) override {
    DC.removeFromWorklist(N);
  }
};

class WorklistInserter : public SelectionDAG::DAGUpdateListener {
  DAGCombiner &DC;

public:
  explicit WorklistInserter(DAGCombiner &dc)
      : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}

  // FIXME: Ideally we could add N to the worklist, but this causes exponential
  //        compile time costs in large DAGs, e.g. Halide.
  void NodeInserted(SDNode *N) override { DC.ConsiderForPruning(N); }
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
//  TargetLowering::DAGCombinerInfo implementation
//===----------------------------------------------------------------------===//

void TargetLowering::DAGCombinerInfo::AddToWorklist(SDNode *N) {
  ((DAGCombiner*)DC)->AddToWorklist(N);
}

SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode *N, ArrayRef<SDValue> To, bool AddTo) {
  return ((DAGCombiner*)DC)->CombineTo(N, &To[0], To.size(), AddTo);
}

SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode *N, SDValue Res, bool AddTo) {
  return ((DAGCombiner*)DC)->CombineTo(N, Res, AddTo);
}

SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode *N, SDValue Res0, SDValue Res1, bool AddTo) {
  return ((DAGCombiner*)DC)->CombineTo(N, Res0, Res1, AddTo);
}

bool TargetLowering::DAGCombinerInfo::
recursivelyDeleteUnusedNodes(SDNode *N) {
  return ((DAGCombiner*)DC)->recursivelyDeleteUnusedNodes(N);
}

void TargetLowering::DAGCombinerInfo::
CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
  return ((DAGCombiner*)DC)->CommitTargetLoweringOpt(TLO);
}

//===----------------------------------------------------------------------===//
// Helper Functions
//===----------------------------------------------------------------------===//

void DAGCombiner::deleteAndRecombine(SDNode *N) {
  removeFromWorklist(N);

  // If the operands of this node are only used by the node, they will now be
  // dead. Make sure to re-visit them and recursively delete dead nodes.
  for (const SDValue &Op : N->ops())
    // For an operand generating multiple values, one of the values may
    // become dead allowing further simplification (e.g. split index
    // arithmetic from an indexed load).
    if (Op->hasOneUse() || Op->getNumValues() > 1)
      AddToWorklist(Op.getNode());

  DAG.DeleteNode(N);
}

// APInts must be the same size for most operations; this helper
// function zero-extends the shorter of the pair so that they match.
// We provide an Offset so that we can create bitwidths that won't overflow.
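// For example, with Offset == 1, an i8 LHS and an i16 RHS both become 17 bits
// wide (1 + max(8, 16)), leaving one bit of headroom.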
static void zeroExtendToMatch(APInt &LHS, APInt &RHS, unsigned Offset = 0) {
  unsigned Bits = Offset + std::max(LHS.getBitWidth(), RHS.getBitWidth());
  LHS = LHS.zextOrSelf(Bits);
  RHS = RHS.zextOrSelf(Bits);
}

// Return true if this node is a setcc, or is a select_cc
// that selects between the target values used for true and false, making it
// equivalent to a setcc. Also, set the incoming LHS, RHS, and CC references to
// the appropriate nodes based on the type of node we are checking. This
// simplifies life a bit for the callers.
bool DAGCombiner::isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
                                    SDValue &CC, bool MatchStrict) const {
  if (N.getOpcode() == ISD::SETCC) {
    LHS = N.getOperand(0);
    RHS = N.getOperand(1);
    CC  = N.getOperand(2);
    return true;
  }

  if (MatchStrict &&
      (N.getOpcode() == ISD::STRICT_FSETCC ||
       N.getOpcode() == ISD::STRICT_FSETCCS)) {
    LHS = N.getOperand(1);
    RHS = N.getOperand(2);
    CC  = N.getOperand(3);
    return true;
  }

  if (N.getOpcode() != ISD::SELECT_CC ||
      !TLI.isConstTrueVal(N.getOperand(2).getNode()) ||
      !TLI.isConstFalseVal(N.getOperand(3).getNode()))
    return false;

  if (TLI.getBooleanContents(N.getValueType()) ==
      TargetLowering::UndefinedBooleanContent)
    return false;

  LHS = N.getOperand(0);
  RHS = N.getOperand(1);
  CC  = N.getOperand(4);
  return true;
}

/// Return true if this is a SetCC-equivalent operation with only one use.
/// If this is true, it allows the users to invert the operation for free when
/// it is profitable to do so.
bool DAGCombiner::isOneUseSetCC(SDValue N) const {
  SDValue N0, N1, N2;
  if (isSetCCEquivalent(N, N0, N1, N2) && N.getNode()->hasOneUse())
    return true;
  return false;
}

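// Return true if N is a splat of the all-ones mask for the given scalar type,
// e.g. a v4i32 splat of 0xFFFFFFFF when ScalarTy is i32; only i8, i16, and
// i32 masks are recognized below.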
static bool isConstantSplatVectorMaskForType(SDNode *N, EVT ScalarTy) {
  if (!ScalarTy.isSimple())
    return false;

  uint64_t MaskForTy = 0ULL;
  switch (ScalarTy.getSimpleVT().SimpleTy) {
  case MVT::i8:
    MaskForTy = 0xFFULL;
    break;
  case MVT::i16:
    MaskForTy = 0xFFFFULL;
    break;
  case MVT::i32:
    MaskForTy = 0xFFFFFFFFULL;
    break;
  default:
    return false;
  }

  APInt Val;
  if (ISD::isConstantSplatVector(N, Val))
    return Val.getLimitedValue() == MaskForTy;

  return false;
}

// Determines if it is a constant integer or a splat/build vector of constant
// integers (and undefs).
// Do not permit build vector implicit truncation.
static bool isConstantOrConstantVector(SDValue N, bool NoOpaques = false) {
  if (ConstantSDNode *Const = dyn_cast<ConstantSDNode>(N))
    return !(Const->isOpaque() && NoOpaques);
  if (N.getOpcode() != ISD::BUILD_VECTOR && N.getOpcode() != ISD::SPLAT_VECTOR)
    return false;
  unsigned BitWidth = N.getScalarValueSizeInBits();
  for (const SDValue &Op : N->op_values()) {
    if (Op.isUndef())
      continue;
    ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Op);
    if (!Const || Const->getAPIntValue().getBitWidth() != BitWidth ||
        (Const->isOpaque() && NoOpaques))
      return false;
  }
  return true;
}

// Determines if a BUILD_VECTOR is composed of all-constants possibly mixed with
// undef's.
static bool isAnyConstantBuildVector(SDValue V, bool NoOpaques = false) {
  if (V.getOpcode() != ISD::BUILD_VECTOR)
    return false;
  return isConstantOrConstantVector(V, NoOpaques) ||
         ISD::isBuildVectorOfConstantFPSDNodes(V.getNode());
}

// Determine if this is an indexed load with an opaque target constant index.
static bool canSplitIdx(LoadSDNode *LD) {
  return MaySplitLoadIndex &&
         (LD->getOperand(2).getOpcode() != ISD::TargetConstant ||
          !cast<ConstantSDNode>(LD->getOperand(2))->isOpaque());
}

bool DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc,
                                                             const SDLoc &DL,
                                                             SDValue N0,
                                                             SDValue N1) {
  // Currently this only tries to ensure we don't undo the GEP splits done by
  // CodeGenPrepare when shouldConsiderGEPOffsetSplit is true. To ensure this,
  // we check if the following transformation would be problematic:
  // (load/store (add, (add, x, offset1), offset2)) ->
  // (load/store (add, x, offset1+offset2)).
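  // For example, if x[offset2] fits the target's addressing modes but
  // x[offset1+offset2] does not, reassociating would force the combined
  // offset out of the addressing mode and undo the earlier split.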

  if (Opc != ISD::ADD || N0.getOpcode() != ISD::ADD)
    return false;

  if (N0.hasOneUse())
    return false;

  auto *C1 = dyn_cast<ConstantSDNode>(N0.getOperand(1));
  auto *C2 = dyn_cast<ConstantSDNode>(N1);
  if (!C1 || !C2)
    return false;

  const APInt &C1APIntVal = C1->getAPIntValue();
  const APInt &C2APIntVal = C2->getAPIntValue();
  if (C1APIntVal.getBitWidth() > 64 || C2APIntVal.getBitWidth() > 64)
    return false;

  const APInt CombinedValueIntVal = C1APIntVal + C2APIntVal;
  if (CombinedValueIntVal.getBitWidth() > 64)
    return false;
  const int64_t CombinedValue = CombinedValueIntVal.getSExtValue();

  for (SDNode *Node : N0->uses()) {
    auto LoadStore = dyn_cast<MemSDNode>(Node);
    if (LoadStore) {
      // Is x[offset2] already not a legal addressing mode? If so then
      // reassociating the constants breaks nothing (we test offset2 because
      // that's the one we hope to fold into the load or store).
      TargetLoweringBase::AddrMode AM;
      AM.HasBaseReg = true;
      AM.BaseOffs = C2APIntVal.getSExtValue();
      EVT VT = LoadStore->getMemoryVT();
      unsigned AS = LoadStore->getAddressSpace();
      Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
      if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
        continue;

      // Would x[offset1+offset2] still be a legal addressing mode?
      AM.BaseOffs = CombinedValue;
      if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
        return true;
    }
  }

  return false;
}

// Helper for DAGCombiner::reassociateOps. Try to reassociate an expression
// such as (Opc N0, N1), if \p N0 is the same kind of operation as \p Opc.
SDValue DAGCombiner::reassociateOpsCommutative(unsigned Opc, const SDLoc &DL,
                                               SDValue N0, SDValue N1) {
  EVT VT = N0.getValueType();

  if (N0.getOpcode() != Opc)
    return SDValue();

  if (DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1))) {
    if (DAG.isConstantIntBuildVectorOrConstantInt(N1)) {
      // Reassociate: (op (op x, c1), c2) -> (op x, (op c1, c2))
      if (SDValue OpNode =
              DAG.FoldConstantArithmetic(Opc, DL, VT, {N0.getOperand(1), N1}))
        return DAG.getNode(Opc, DL, VT, N0.getOperand(0), OpNode);
      return SDValue();
    }
    if (N0.hasOneUse()) {
      // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1)
      //              iff (op x, c1) has one use
      SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N0.getOperand(0), N1);
      if (!OpNode.getNode())
        return SDValue();
      return DAG.getNode(Opc, DL, VT, OpNode, N0.getOperand(1));
    }
  }
  return SDValue();
}

// Try to reassociate commutative binops.
SDValue DAGCombiner::reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
                                    SDValue N1, SDNodeFlags Flags) {
  assert(TLI.isCommutativeBinOp(Opc) && "Operation not commutative.");

  // Floating-point reassociation is not allowed without loose FP math.
  if (N0.getValueType().isFloatingPoint() ||
      N1.getValueType().isFloatingPoint())
    if (!Flags.hasAllowReassociation() || !Flags.hasNoSignedZeros())
      return SDValue();

  if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N0, N1))
    return Combined;
  if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N1, N0))
    return Combined;
  return SDValue();
}

SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
                               bool AddTo) {
  assert(N->getNumValues() == NumTo && "Broken CombineTo call!");
  ++NodesCombined;
  LLVM_DEBUG(dbgs() << "\nReplacing.1 "; N->dump(&DAG); dbgs() << "\nWith: ";
             To[0].getNode()->dump(&DAG);
             dbgs() << " and " << NumTo - 1 << " other values\n");
  for (unsigned i = 0, e = NumTo; i != e; ++i)
    assert((!To[i].getNode() ||
            N->getValueType(i) == To[i].getValueType()) &&
           "Cannot combine value to value of different type!");

  WorklistRemover DeadNodes(*this);
  DAG.ReplaceAllUsesWith(N, To);
  if (AddTo) {
    // Push the new nodes and any users onto the worklist
    for (unsigned i = 0, e = NumTo; i != e; ++i) {
      if (To[i].getNode()) {
        AddToWorklist(To[i].getNode());
        AddUsersToWorklist(To[i].getNode());
      }
    }
  }

  // Finally, if the node is now dead, remove it from the graph.  The node
  // may not be dead if the replacement process recursively simplified to
  // something else needing this node.
  if (N->use_empty())
    deleteAndRecombine(N);
  return SDValue(N, 0);
}

void DAGCombiner::
CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
  // Replace the old value with the new one.
  ++NodesCombined;
  LLVM_DEBUG(dbgs() << "\nReplacing.2 "; TLO.Old.getNode()->dump(&DAG);
             dbgs() << "\nWith: "; TLO.New.getNode()->dump(&DAG);
             dbgs() << '\n');

  // Replace all uses.  If any nodes become isomorphic to other nodes and
  // are deleted, make sure to remove them from our worklist.
  WorklistRemover DeadNodes(*this);
  DAG.ReplaceAllUsesOfValueWith(TLO.Old, TLO.New);

  // Push the new node and any (possibly new) users onto the worklist.
  AddToWorklistWithUsers(TLO.New.getNode());

  // Finally, if the node is now dead, remove it from the graph.  The node
  // may not be dead if the replacement process recursively simplified to
  // something else needing this node.
  if (TLO.Old.getNode()->use_empty())
    deleteAndRecombine(TLO.Old.getNode());
}

/// Check the specified integer node value to see if it can be simplified or if
/// things it uses can be simplified by bit propagation. If so, return true.
bool DAGCombiner::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
                                       const APInt &DemandedElts,
                                       bool AssumeSingleUse) {
  TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
  KnownBits Known;
  if (!TLI.SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO, 0,
                                AssumeSingleUse))
    return false;

  // Revisit the node.
  AddToWorklist(Op.getNode());

  CommitTargetLoweringOpt(TLO);
  return true;
}

/// Check the specified vector node value to see if it can be simplified or
/// if things it uses can be simplified as it only uses some of the elements.
/// If so, return true.
bool DAGCombiner::SimplifyDemandedVectorElts(SDValue Op,
                                             const APInt &DemandedElts,
                                             bool AssumeSingleUse) {
  TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
  APInt KnownUndef, KnownZero;
  if (!TLI.SimplifyDemandedVectorElts(Op, DemandedElts, KnownUndef, KnownZero,
                                      TLO, 0, AssumeSingleUse))
    return false;

  // Revisit the node.
  AddToWorklist(Op.getNode());

  CommitTargetLoweringOpt(TLO);
  return true;
}

void DAGCombiner::ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad) {
  SDLoc DL(Load);
  EVT VT = Load->getValueType(0);
  SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, VT, SDValue(ExtLoad, 0));

  LLVM_DEBUG(dbgs() << "\nReplacing.9 "; Load->dump(&DAG); dbgs() << "\nWith: ";
             Trunc.getNode()->dump(&DAG); dbgs() << '\n');
  WorklistRemover DeadNodes(*this);
  DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), Trunc);
  DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), SDValue(ExtLoad, 1));
  deleteAndRecombine(Load);
  AddToWorklist(Trunc.getNode());
}

SDValue DAGCombiner::PromoteOperand(SDValue Op, EVT PVT, bool &Replace) {
  Replace = false;
  SDLoc DL(Op);
  if (ISD::isUNINDEXEDLoad(Op.getNode())) {
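    // A plain (non-extending) load is rebuilt as an any-extending load of the
    // same memory type; a load that already extends keeps its extension kind.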
1213     LoadSDNode *LD = cast<LoadSDNode>(Op);
1214     EVT MemVT = LD->getMemoryVT();
1215     ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) ? ISD::EXTLOAD
1216                                                       : LD->getExtensionType();
1217     Replace = true;
1218     return DAG.getExtLoad(ExtType, DL, PVT,
1219                           LD->getChain(), LD->getBasePtr(),
1220                           MemVT, LD->getMemOperand());
1221   }
1222 
1223   unsigned Opc = Op.getOpcode();
1224   switch (Opc) {
1225   default: break;
1226   case ISD::AssertSext:
1227     if (SDValue Op0 = SExtPromoteOperand(Op.getOperand(0), PVT))
1228       return DAG.getNode(ISD::AssertSext, DL, PVT, Op0, Op.getOperand(1));
1229     break;
1230   case ISD::AssertZext:
1231     if (SDValue Op0 = ZExtPromoteOperand(Op.getOperand(0), PVT))
1232       return DAG.getNode(ISD::AssertZext, DL, PVT, Op0, Op.getOperand(1));
1233     break;
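  // Editorial note: constants fold immediately through the extension below;
  // e.g. the i16 constant 0xFF00 becomes the i32 constant 0xFFFFFF00, since
  // a byte-sized source type selects sign-extension.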
1234   case ISD::Constant: {
1235     unsigned ExtOpc =
1236       Op.getValueType().isByteSized() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
1237     return DAG.getNode(ExtOpc, DL, PVT, Op);
1238   }
1239   }
1240 
1241   if (!TLI.isOperationLegal(ISD::ANY_EXTEND, PVT))
1242     return SDValue();
1243   return DAG.getNode(ISD::ANY_EXTEND, DL, PVT, Op);
1244 }
1245 
1246 SDValue DAGCombiner::SExtPromoteOperand(SDValue Op, EVT PVT) {
1247   if (!TLI.isOperationLegal(ISD::SIGN_EXTEND_INREG, PVT))
1248     return SDValue();
1249   EVT OldVT = Op.getValueType();
1250   SDLoc DL(Op);
1251   bool Replace = false;
1252   SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1253   if (!NewOp.getNode())
1254     return SDValue();
1255   AddToWorklist(NewOp.getNode());
1256 
1257   if (Replace)
1258     ReplaceLoadWithPromotedLoad(Op.getNode(), NewOp.getNode());
1259   return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, NewOp.getValueType(), NewOp,
1260                      DAG.getValueType(OldVT));
1261 }
1262 
1263 SDValue DAGCombiner::ZExtPromoteOperand(SDValue Op, EVT PVT) {
1264   EVT OldVT = Op.getValueType();
1265   SDLoc DL(Op);
1266   bool Replace = false;
1267   SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1268   if (!NewOp.getNode())
1269     return SDValue();
1270   AddToWorklist(NewOp.getNode());
1271 
1272   if (Replace)
1273     ReplaceLoadWithPromotedLoad(Op.getNode(), NewOp.getNode());
1274   return DAG.getZeroExtendInReg(NewOp, DL, OldVT);
1275 }
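// Editorial note: getZeroExtendInReg typically materializes as an 'and' with
// a low-bits mask, so zext-promoting an i16 value to i32 yields something
// like
//   (and (any_extend x), 0xFFFF)
// clearing the high bits that the wider operation may have clobbered.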
1276 
1277 /// Promote the specified integer binary operation if the target indicates it is
1278 /// beneficial. e.g. On x86, it's usually better to promote i16 operations to
1279 /// i32 since i16 instructions are longer.
1280 SDValue DAGCombiner::PromoteIntBinOp(SDValue Op) {
1281   if (!LegalOperations)
1282     return SDValue();
1283 
1284   EVT VT = Op.getValueType();
1285   if (VT.isVector() || !VT.isInteger())
1286     return SDValue();
1287 
1288   // If the operation type is 'undesirable', e.g. i16 on x86, consider
1289   // promoting it.
1290   unsigned Opc = Op.getOpcode();
1291   if (TLI.isTypeDesirableForOp(Opc, VT))
1292     return SDValue();
1293 
1294   EVT PVT = VT;
1295   // Ask the target whether it is a good idea to promote this operation and,
1296   // if so, what type to promote it to.
1297   if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1298     assert(PVT != VT && "Don't know what type to promote to!");
1299 
1300     LLVM_DEBUG(dbgs() << "\nPromoting "; Op.getNode()->dump(&DAG));
1301 
1302     bool Replace0 = false;
1303     SDValue N0 = Op.getOperand(0);
1304     SDValue NN0 = PromoteOperand(N0, PVT, Replace0);
1305 
1306     bool Replace1 = false;
1307     SDValue N1 = Op.getOperand(1);
1308     SDValue NN1 = PromoteOperand(N1, PVT, Replace1);
1309     SDLoc DL(Op);
1310 
1311     SDValue RV =
1312         DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, NN0, NN1));
1313 
1314     // We are always replacing N0/N1's use in N and only need additional
1315     // replacements if there are additional uses.
1316     // Note: We are checking uses of the *nodes* (SDNode) rather than values
1317     //       (SDValue) here because the node may reference multiple values
1318     //       (for example, the chain value of a load node).
1319     Replace0 &= !N0->hasOneUse();
1320     Replace1 &= (N0 != N1) && !N1->hasOneUse();
1321 
1322     // Combine Op here so it is preserved past replacements.
1323     CombineTo(Op.getNode(), RV);
1324 
1325     // If the operands have a use ordering, make sure we deal with the
1326     // predecessor first.
1327     if (Replace0 && Replace1 && N0.getNode()->isPredecessorOf(N1.getNode())) {
1328       std::swap(N0, N1);
1329       std::swap(NN0, NN1);
1330     }
1331 
1332     if (Replace0) {
1333       AddToWorklist(NN0.getNode());
1334       ReplaceLoadWithPromotedLoad(N0.getNode(), NN0.getNode());
1335     }
1336     if (Replace1) {
1337       AddToWorklist(NN1.getNode());
1338       ReplaceLoadWithPromotedLoad(N1.getNode(), NN1.getNode());
1339     }
1340     return Op;
1341   }
1342   return SDValue();
1343 }
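// Editorial sketch of a successful promotion, assuming the target reports
// i16 as undesirable and asks for i32:
//   (add:i16 x, y)
//     --> (truncate:i16 (add:i32 (any_extend x), (any_extend y)))
// with loads feeding x/y rewritten to i32 extending loads where profitable.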
1344 
1345 /// Promote the specified integer shift operation if the target indicates it is
1346 /// beneficial. e.g. On x86, it's usually better to promote i16 operations to
1347 /// i32 since i16 instructions are longer.
1348 SDValue DAGCombiner::PromoteIntShiftOp(SDValue Op) {
1349   if (!LegalOperations)
1350     return SDValue();
1351 
1352   EVT VT = Op.getValueType();
1353   if (VT.isVector() || !VT.isInteger())
1354     return SDValue();
1355 
1356   // If the operation type is 'undesirable', e.g. i16 on x86, consider
1357   // promoting it.
1358   unsigned Opc = Op.getOpcode();
1359   if (TLI.isTypeDesirableForOp(Opc, VT))
1360     return SDValue();
1361 
1362   EVT PVT = VT;
1363   // Ask the target whether it is a good idea to promote this operation and,
1364   // if so, what type to promote it to.
1365   if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1366     assert(PVT != VT && "Don't know what type to promote to!");
1367 
1368     LLVM_DEBUG(dbgs() << "\nPromoting "; Op.getNode()->dump(&DAG));
1369 
1370     bool Replace = false;
1371     SDValue N0 = Op.getOperand(0);
1372     SDValue N1 = Op.getOperand(1);
1373     if (Opc == ISD::SRA)
1374       N0 = SExtPromoteOperand(N0, PVT);
1375     else if (Opc == ISD::SRL)
1376       N0 = ZExtPromoteOperand(N0, PVT);
1377     else
1378       N0 = PromoteOperand(N0, PVT, Replace);
1379 
1380     if (!N0.getNode())
1381       return SDValue();
1382 
1383     SDLoc DL(Op);
1384     SDValue RV =
1385         DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, N0, N1));
1386 
1387     if (Replace)
1388       ReplaceLoadWithPromotedLoad(Op.getOperand(0).getNode(), N0.getNode());
1389 
1390     // Deal with Op being deleted.
1391     if (Op && Op.getOpcode() != ISD::DELETED_NODE)
1392       return RV;
1393   }
1394   return SDValue();
1395 }
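// Editorial note on the asymmetry above: the promoted shift must see the
// correct high bits, so for sra the source is sign-extend-promoted and for
// srl it is zero-extend-promoted. Only shl can take a plain any_extend,
// because its result bits below the original width do not depend on the
// extended bits.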
1396 
1397 SDValue DAGCombiner::PromoteExtend(SDValue Op) {
1398   if (!LegalOperations)
1399     return SDValue();
1400 
1401   EVT VT = Op.getValueType();
1402   if (VT.isVector() || !VT.isInteger())
1403     return SDValue();
1404 
1405   // If the operation type is 'undesirable', e.g. i16 on x86, consider
1406   // promoting it.
1407   unsigned Opc = Op.getOpcode();
1408   if (TLI.isTypeDesirableForOp(Opc, VT))
1409     return SDValue();
1410 
1411   EVT PVT = VT;
1412   // Ask the target whether it is a good idea to promote this operation and,
1413   // if so, what type to promote it to.
1414   if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1415     assert(PVT != VT && "Don't know what type to promote to!");
1416     // fold (aext (aext x)) -> (aext x)
1417     // fold (aext (zext x)) -> (zext x)
1418     // fold (aext (sext x)) -> (sext x)
1419     LLVM_DEBUG(dbgs() << "\nPromoting "; Op.getNode()->dump(&DAG));
1420     return DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, Op.getOperand(0));
1421   }
1422   return SDValue();
1423 }
1424 
1425 bool DAGCombiner::PromoteLoad(SDValue Op) {
1426   if (!LegalOperations)
1427     return false;
1428 
1429   if (!ISD::isUNINDEXEDLoad(Op.getNode()))
1430     return false;
1431 
1432   EVT VT = Op.getValueType();
1433   if (VT.isVector() || !VT.isInteger())
1434     return false;
1435 
1436   // If the operation type is 'undesirable', e.g. i16 on x86, consider
1437   // promoting it.
1438   unsigned Opc = Op.getOpcode();
1439   if (TLI.isTypeDesirableForOp(Opc, VT))
1440     return false;
1441 
1442   EVT PVT = VT;
1443   // Ask the target whether it is a good idea to promote this operation and,
1444   // if so, what type to promote it to.
1445   if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1446     assert(PVT != VT && "Don't know what type to promote to!");
1447 
1448     SDLoc DL(Op);
1449     SDNode *N = Op.getNode();
1450     LoadSDNode *LD = cast<LoadSDNode>(N);
1451     EVT MemVT = LD->getMemoryVT();
1452     ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) ? ISD::EXTLOAD
1453                                                       : LD->getExtensionType();
1454     SDValue NewLD = DAG.getExtLoad(ExtType, DL, PVT,
1455                                    LD->getChain(), LD->getBasePtr(),
1456                                    MemVT, LD->getMemOperand());
1457     SDValue Result = DAG.getNode(ISD::TRUNCATE, DL, VT, NewLD);
1458 
1459     LLVM_DEBUG(dbgs() << "\nPromoting "; N->dump(&DAG); dbgs() << "\nTo: ";
1460                Result.getNode()->dump(&DAG); dbgs() << '\n');
1461     WorklistRemover DeadNodes(*this);
1462     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result);
1463     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), NewLD.getValue(1));
1464     deleteAndRecombine(N);
1465     AddToWorklist(Result.getNode());
1466     return true;
1467   }
1468   return false;
1469 }
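// Editorial sketch, assuming i16 is undesirable and PVT is i32:
//   t1: i16, ch  = load ptr
// becomes
//   t2: i32, ch2 = extload<i16> ptr;  t3: i16 = truncate t2
// so users keep seeing an i16 value while the memory access itself runs at
// the wider type.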
1470 
1471 /// Recursively delete a node which has no uses and any operands for
1472 /// which it is the only use.
1473 ///
1474 /// Note that this both deletes the nodes and removes them from the worklist.
1475 /// It also adds any nodes that have had a user deleted to the worklist, as
1476 /// they may now have only one use and be subject to other combines.
1477 bool DAGCombiner::recursivelyDeleteUnusedNodes(SDNode *N) {
1478   if (!N->use_empty())
1479     return false;
1480 
1481   SmallSetVector<SDNode *, 16> Nodes;
1482   Nodes.insert(N);
1483   do {
1484     N = Nodes.pop_back_val();
1485     if (!N)
1486       continue;
1487 
1488     if (N->use_empty()) {
1489       for (const SDValue &ChildN : N->op_values())
1490         Nodes.insert(ChildN.getNode());
1491 
1492       removeFromWorklist(N);
1493       DAG.DeleteNode(N);
1494     } else {
1495       AddToWorklist(N);
1496     }
1497   } while (!Nodes.empty());
1498   return true;
1499 }
1500 
1501 //===----------------------------------------------------------------------===//
1502 //  Main DAG Combiner implementation
1503 //===----------------------------------------------------------------------===//
1504 
1505 void DAGCombiner::Run(CombineLevel AtLevel) {
1506   // Set the instance variables so that the various visit routines can use them.
1507   Level = AtLevel;
1508   LegalDAG = Level >= AfterLegalizeDAG;
1509   LegalOperations = Level >= AfterLegalizeVectorOps;
1510   LegalTypes = Level >= AfterLegalizeTypes;
1511 
1512   WorklistInserter AddNodes(*this);
1513 
1514   // Add all the dag nodes to the worklist.
1515   for (SDNode &Node : DAG.allnodes())
1516     AddToWorklist(&Node);
1517 
1518   // Create a dummy node (which is not added to allnodes) that adds a reference
1519   // to the root node, preventing it from being deleted, and that tracks any
1520   // changes of the root.
1521   HandleSDNode Dummy(DAG.getRoot());
1522 
1523   // While we have a valid worklist entry node, try to combine it.
1524   while (SDNode *N = getNextWorklistEntry()) {
1525     // If N has no uses, it is dead.  Make sure to revisit all N's operands once
1526     // N is deleted from the DAG, since they too may now be dead or may have a
1527     // reduced number of uses, allowing other xforms.
1528     if (recursivelyDeleteUnusedNodes(N))
1529       continue;
1530 
1531     WorklistRemover DeadNodes(*this);
1532 
1533     // If this combine is running after legalizing the DAG, re-legalize any
1534     // nodes pulled off the worklist.
1535     if (LegalDAG) {
1536       SmallSetVector<SDNode *, 16> UpdatedNodes;
1537       bool NIsValid = DAG.LegalizeOp(N, UpdatedNodes);
1538 
1539       for (SDNode *LN : UpdatedNodes)
1540         AddToWorklistWithUsers(LN);
1541 
1542       if (!NIsValid)
1543         continue;
1544     }
1545 
1546     LLVM_DEBUG(dbgs() << "\nCombining: "; N->dump(&DAG));
1547 
1548     // Add any operands of the current node which have not yet been combined
1549     // to the worklist as well. Because the worklist uniques things already,
1550     // this won't repeatedly process the same operand.
1551     CombinedNodes.insert(N);
1552     for (const SDValue &ChildN : N->op_values())
1553       if (!CombinedNodes.count(ChildN.getNode()))
1554         AddToWorklist(ChildN.getNode());
1555 
1556     SDValue RV = combine(N);
1557 
1558     if (!RV.getNode())
1559       continue;
1560 
1561     ++NodesCombined;
1562 
1563     // If we get back the same node we passed in, rather than a new node or
1564     // zero, we know that the node must have defined multiple values and
1565     // CombineTo was used.  Since CombineTo takes care of the worklist
1566     // mechanics for us, we have no work to do in this case.
1567     if (RV.getNode() == N)
1568       continue;
1569 
1570     assert(N->getOpcode() != ISD::DELETED_NODE &&
1571            RV.getOpcode() != ISD::DELETED_NODE &&
1572            "Node was deleted but visit returned new node!");
1573 
1574     LLVM_DEBUG(dbgs() << " ... into: "; RV.getNode()->dump(&DAG));
1575 
1576     if (N->getNumValues() == RV.getNode()->getNumValues())
1577       DAG.ReplaceAllUsesWith(N, RV.getNode());
1578     else {
1579       assert(N->getValueType(0) == RV.getValueType() &&
1580              N->getNumValues() == 1 && "Type mismatch");
1581       DAG.ReplaceAllUsesWith(N, &RV);
1582     }
1583 
1584     // Push the new node and any users onto the worklist.  Omit this if the
1585     // new node is the EntryToken (e.g. if a store managed to get optimized
1586     // out), because re-visiting the EntryToken and its users will not uncover
1587     // any additional opportunities, but there may be a large number of such
1588     // users, potentially causing compile time explosion.
1589     if (RV.getOpcode() != ISD::EntryToken) {
1590       AddToWorklist(RV.getNode());
1591       AddUsersToWorklist(RV.getNode());
1592     }
1593 
1594     // Finally, if the node is now dead, remove it from the graph.  The node
1595     // may not be dead if the replacement process recursively simplified to
1596     // something else needing this node. This will also take care of adding any
1597     // operands which have lost a user to the worklist.
1598     recursivelyDeleteUnusedNodes(N);
1599   }
1600 
1601   // If the root changed (e.g. it was a dead load), update the root.
1602   DAG.setRoot(Dummy.getValue());
1603   DAG.RemoveDeadNodes();
1604 }
1605 
1606 SDValue DAGCombiner::visit(SDNode *N) {
1607   switch (N->getOpcode()) {
1608   default: break;
1609   case ISD::TokenFactor:        return visitTokenFactor(N);
1610   case ISD::MERGE_VALUES:       return visitMERGE_VALUES(N);
1611   case ISD::ADD:                return visitADD(N);
1612   case ISD::SUB:                return visitSUB(N);
1613   case ISD::SADDSAT:
1614   case ISD::UADDSAT:            return visitADDSAT(N);
1615   case ISD::SSUBSAT:
1616   case ISD::USUBSAT:            return visitSUBSAT(N);
1617   case ISD::ADDC:               return visitADDC(N);
1618   case ISD::SADDO:
1619   case ISD::UADDO:              return visitADDO(N);
1620   case ISD::SUBC:               return visitSUBC(N);
1621   case ISD::SSUBO:
1622   case ISD::USUBO:              return visitSUBO(N);
1623   case ISD::ADDE:               return visitADDE(N);
1624   case ISD::ADDCARRY:           return visitADDCARRY(N);
1625   case ISD::SADDO_CARRY:        return visitSADDO_CARRY(N);
1626   case ISD::SUBE:               return visitSUBE(N);
1627   case ISD::SUBCARRY:           return visitSUBCARRY(N);
1628   case ISD::SSUBO_CARRY:        return visitSSUBO_CARRY(N);
1629   case ISD::SMULFIX:
1630   case ISD::SMULFIXSAT:
1631   case ISD::UMULFIX:
1632   case ISD::UMULFIXSAT:         return visitMULFIX(N);
1633   case ISD::MUL:                return visitMUL(N);
1634   case ISD::SDIV:               return visitSDIV(N);
1635   case ISD::UDIV:               return visitUDIV(N);
1636   case ISD::SREM:
1637   case ISD::UREM:               return visitREM(N);
1638   case ISD::MULHU:              return visitMULHU(N);
1639   case ISD::MULHS:              return visitMULHS(N);
1640   case ISD::SMUL_LOHI:          return visitSMUL_LOHI(N);
1641   case ISD::UMUL_LOHI:          return visitUMUL_LOHI(N);
1642   case ISD::SMULO:
1643   case ISD::UMULO:              return visitMULO(N);
1644   case ISD::SMIN:
1645   case ISD::SMAX:
1646   case ISD::UMIN:
1647   case ISD::UMAX:               return visitIMINMAX(N);
1648   case ISD::AND:                return visitAND(N);
1649   case ISD::OR:                 return visitOR(N);
1650   case ISD::XOR:                return visitXOR(N);
1651   case ISD::SHL:                return visitSHL(N);
1652   case ISD::SRA:                return visitSRA(N);
1653   case ISD::SRL:                return visitSRL(N);
1654   case ISD::ROTR:
1655   case ISD::ROTL:               return visitRotate(N);
1656   case ISD::FSHL:
1657   case ISD::FSHR:               return visitFunnelShift(N);
1658   case ISD::ABS:                return visitABS(N);
1659   case ISD::BSWAP:              return visitBSWAP(N);
1660   case ISD::BITREVERSE:         return visitBITREVERSE(N);
1661   case ISD::CTLZ:               return visitCTLZ(N);
1662   case ISD::CTLZ_ZERO_UNDEF:    return visitCTLZ_ZERO_UNDEF(N);
1663   case ISD::CTTZ:               return visitCTTZ(N);
1664   case ISD::CTTZ_ZERO_UNDEF:    return visitCTTZ_ZERO_UNDEF(N);
1665   case ISD::CTPOP:              return visitCTPOP(N);
1666   case ISD::SELECT:             return visitSELECT(N);
1667   case ISD::VSELECT:            return visitVSELECT(N);
1668   case ISD::SELECT_CC:          return visitSELECT_CC(N);
1669   case ISD::SETCC:              return visitSETCC(N);
1670   case ISD::SETCCCARRY:         return visitSETCCCARRY(N);
1671   case ISD::SIGN_EXTEND:        return visitSIGN_EXTEND(N);
1672   case ISD::ZERO_EXTEND:        return visitZERO_EXTEND(N);
1673   case ISD::ANY_EXTEND:         return visitANY_EXTEND(N);
1674   case ISD::AssertSext:
1675   case ISD::AssertZext:         return visitAssertExt(N);
1676   case ISD::AssertAlign:        return visitAssertAlign(N);
1677   case ISD::SIGN_EXTEND_INREG:  return visitSIGN_EXTEND_INREG(N);
1678   case ISD::SIGN_EXTEND_VECTOR_INREG:
1679   case ISD::ZERO_EXTEND_VECTOR_INREG: return visitEXTEND_VECTOR_INREG(N);
1680   case ISD::TRUNCATE:           return visitTRUNCATE(N);
1681   case ISD::BITCAST:            return visitBITCAST(N);
1682   case ISD::BUILD_PAIR:         return visitBUILD_PAIR(N);
1683   case ISD::FADD:               return visitFADD(N);
1684   case ISD::STRICT_FADD:        return visitSTRICT_FADD(N);
1685   case ISD::FSUB:               return visitFSUB(N);
1686   case ISD::FMUL:               return visitFMUL(N);
1687   case ISD::FMA:                return visitFMA(N);
1688   case ISD::FDIV:               return visitFDIV(N);
1689   case ISD::FREM:               return visitFREM(N);
1690   case ISD::FSQRT:              return visitFSQRT(N);
1691   case ISD::FCOPYSIGN:          return visitFCOPYSIGN(N);
1692   case ISD::FPOW:               return visitFPOW(N);
1693   case ISD::SINT_TO_FP:         return visitSINT_TO_FP(N);
1694   case ISD::UINT_TO_FP:         return visitUINT_TO_FP(N);
1695   case ISD::FP_TO_SINT:         return visitFP_TO_SINT(N);
1696   case ISD::FP_TO_UINT:         return visitFP_TO_UINT(N);
1697   case ISD::FP_ROUND:           return visitFP_ROUND(N);
1698   case ISD::FP_EXTEND:          return visitFP_EXTEND(N);
1699   case ISD::FNEG:               return visitFNEG(N);
1700   case ISD::FABS:               return visitFABS(N);
1701   case ISD::FFLOOR:             return visitFFLOOR(N);
1702   case ISD::FMINNUM:            return visitFMINNUM(N);
1703   case ISD::FMAXNUM:            return visitFMAXNUM(N);
1704   case ISD::FMINIMUM:           return visitFMINIMUM(N);
1705   case ISD::FMAXIMUM:           return visitFMAXIMUM(N);
1706   case ISD::FCEIL:              return visitFCEIL(N);
1707   case ISD::FTRUNC:             return visitFTRUNC(N);
1708   case ISD::BRCOND:             return visitBRCOND(N);
1709   case ISD::BR_CC:              return visitBR_CC(N);
1710   case ISD::LOAD:               return visitLOAD(N);
1711   case ISD::STORE:              return visitSTORE(N);
1712   case ISD::INSERT_VECTOR_ELT:  return visitINSERT_VECTOR_ELT(N);
1713   case ISD::EXTRACT_VECTOR_ELT: return visitEXTRACT_VECTOR_ELT(N);
1714   case ISD::BUILD_VECTOR:       return visitBUILD_VECTOR(N);
1715   case ISD::CONCAT_VECTORS:     return visitCONCAT_VECTORS(N);
1716   case ISD::EXTRACT_SUBVECTOR:  return visitEXTRACT_SUBVECTOR(N);
1717   case ISD::VECTOR_SHUFFLE:     return visitVECTOR_SHUFFLE(N);
1718   case ISD::SCALAR_TO_VECTOR:   return visitSCALAR_TO_VECTOR(N);
1719   case ISD::INSERT_SUBVECTOR:   return visitINSERT_SUBVECTOR(N);
1720   case ISD::MGATHER:            return visitMGATHER(N);
1721   case ISD::MLOAD:              return visitMLOAD(N);
1722   case ISD::MSCATTER:           return visitMSCATTER(N);
1723   case ISD::MSTORE:             return visitMSTORE(N);
1724   case ISD::LIFETIME_END:       return visitLIFETIME_END(N);
1725   case ISD::FP_TO_FP16:         return visitFP_TO_FP16(N);
1726   case ISD::FP16_TO_FP:         return visitFP16_TO_FP(N);
1727   case ISD::FREEZE:             return visitFREEZE(N);
1728   case ISD::VECREDUCE_FADD:
1729   case ISD::VECREDUCE_FMUL:
1730   case ISD::VECREDUCE_ADD:
1731   case ISD::VECREDUCE_MUL:
1732   case ISD::VECREDUCE_AND:
1733   case ISD::VECREDUCE_OR:
1734   case ISD::VECREDUCE_XOR:
1735   case ISD::VECREDUCE_SMAX:
1736   case ISD::VECREDUCE_SMIN:
1737   case ISD::VECREDUCE_UMAX:
1738   case ISD::VECREDUCE_UMIN:
1739   case ISD::VECREDUCE_FMAX:
1740   case ISD::VECREDUCE_FMIN:     return visitVECREDUCE(N);
1741   }
1742   return SDValue();
1743 }
1744 
1745 SDValue DAGCombiner::combine(SDNode *N) {
1746   SDValue RV;
1747   if (!DisableGenericCombines)
1748     RV = visit(N);
1749 
1750   // If nothing happened, try a target-specific DAG combine.
1751   if (!RV.getNode()) {
1752     assert(N->getOpcode() != ISD::DELETED_NODE &&
1753            "Node was deleted but visit returned NULL!");
1754 
1755     if (N->getOpcode() >= ISD::BUILTIN_OP_END ||
1756         TLI.hasTargetDAGCombine((ISD::NodeType)N->getOpcode())) {
1757 
1758       // Expose the DAG combiner to the target combiner impls.
1759       TargetLowering::DAGCombinerInfo
1760         DagCombineInfo(DAG, Level, false, this);
1761 
1762       RV = TLI.PerformDAGCombine(N, DagCombineInfo);
1763     }
1764   }
1765 
1766   // If still nothing happened, try promoting the operation.
1767   if (!RV.getNode()) {
1768     switch (N->getOpcode()) {
1769     default: break;
1770     case ISD::ADD:
1771     case ISD::SUB:
1772     case ISD::MUL:
1773     case ISD::AND:
1774     case ISD::OR:
1775     case ISD::XOR:
1776       RV = PromoteIntBinOp(SDValue(N, 0));
1777       break;
1778     case ISD::SHL:
1779     case ISD::SRA:
1780     case ISD::SRL:
1781       RV = PromoteIntShiftOp(SDValue(N, 0));
1782       break;
1783     case ISD::SIGN_EXTEND:
1784     case ISD::ZERO_EXTEND:
1785     case ISD::ANY_EXTEND:
1786       RV = PromoteExtend(SDValue(N, 0));
1787       break;
1788     case ISD::LOAD:
1789       if (PromoteLoad(SDValue(N, 0)))
1790         RV = SDValue(N, 0);
1791       break;
1792     }
1793   }
1794 
1795   // If N is a commutative binary node, try to eliminate it if the commuted
1796   // version is already present in the DAG.
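  // For example, if both (add x, y) and (add y, x) exist in the DAG, all
  // uses of this node can be redirected to the commuted copy, letting this
  // node die.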
1797   if (!RV.getNode() && TLI.isCommutativeBinOp(N->getOpcode()) &&
1798       N->getNumValues() == 1) {
1799     SDValue N0 = N->getOperand(0);
1800     SDValue N1 = N->getOperand(1);
1801 
1802     // Constant operands are canonicalized to RHS.
1803     if (N0 != N1 && (isa<ConstantSDNode>(N0) || !isa<ConstantSDNode>(N1))) {
1804       SDValue Ops[] = {N1, N0};
1805       SDNode *CSENode = DAG.getNodeIfExists(N->getOpcode(), N->getVTList(), Ops,
1806                                             N->getFlags());
1807       if (CSENode)
1808         return SDValue(CSENode, 0);
1809     }
1810   }
1811 
1812   return RV;
1813 }
1814 
1815 /// Given a node, return its input chain if it has one, otherwise return a
1816 /// null SDValue.
1817 static SDValue getInputChainForNode(SDNode *N) {
1818   if (unsigned NumOps = N->getNumOperands()) {
1819     if (N->getOperand(0).getValueType() == MVT::Other)
1820       return N->getOperand(0);
1821     if (N->getOperand(NumOps-1).getValueType() == MVT::Other)
1822       return N->getOperand(NumOps-1);
1823     for (unsigned i = 1; i < NumOps-1; ++i)
1824       if (N->getOperand(i).getValueType() == MVT::Other)
1825         return N->getOperand(i);
1826   }
1827   return SDValue();
1828 }
1829 
1830 SDValue DAGCombiner::visitTokenFactor(SDNode *N) {
1831   // If N has two operands, where one has an input chain equal to the other,
1832   // the 'other' chain is redundant.
1833   if (N->getNumOperands() == 2) {
1834     if (getInputChainForNode(N->getOperand(0).getNode()) == N->getOperand(1))
1835       return N->getOperand(0);
1836     if (getInputChainForNode(N->getOperand(1).getNode()) == N->getOperand(0))
1837       return N->getOperand(1);
1838   }
1839 
1840   // Don't simplify token factors if optnone.
1841   if (OptLevel == CodeGenOpt::None)
1842     return SDValue();
1843 
1844   // Don't simplify the token factor if the node itself has too many operands.
1845   if (N->getNumOperands() > TokenFactorInlineLimit)
1846     return SDValue();
1847 
1848   // If the sole user is a token factor, we should make sure we have a
1849   // chance to merge them together. This prevents TF chains from inhibiting
1850   // optimizations.
1851   if (N->hasOneUse() && N->use_begin()->getOpcode() == ISD::TokenFactor)
1852     AddToWorklist(*(N->use_begin()));
1853 
1854   SmallVector<SDNode *, 8> TFs;     // List of token factors to visit.
1855   SmallVector<SDValue, 8> Ops;      // Ops for replacing token factor.
1856   SmallPtrSet<SDNode*, 16> SeenOps;
1857   bool Changed = false;             // If we should replace this token factor.
1858 
1859   // Start out with this token factor.
1860   TFs.push_back(N);
1861 
1862   // Iterate through token factors.  The TFs list grows when new token factors
1863   // are encountered.
1864   for (unsigned i = 0; i < TFs.size(); ++i) {
1865     // Limit number of nodes to inline, to avoid quadratic compile times.
1866     // We have to add the outstanding Token Factors to Ops, otherwise we might
1867     // drop Ops from the resulting Token Factors.
1868     if (Ops.size() > TokenFactorInlineLimit) {
1869       for (unsigned j = i; j < TFs.size(); j++)
1870         Ops.emplace_back(TFs[j], 0);
1871       // Drop unprocessed Token Factors from TFs, so we do not add them to the
1872       // combiner worklist later.
1873       TFs.resize(i);
1874       break;
1875     }
1876 
1877     SDNode *TF = TFs[i];
1878     // Check each of the operands.
1879     for (const SDValue &Op : TF->op_values()) {
1880       switch (Op.getOpcode()) {
1881       case ISD::EntryToken:
1882         // Entry tokens don't need to be added to the list. They are
1883         // redundant.
1884         Changed = true;
1885         break;
1886 
1887       case ISD::TokenFactor:
1888         if (Op.hasOneUse() && !is_contained(TFs, Op.getNode())) {
1889           // Queue up for processing.
1890           TFs.push_back(Op.getNode());
1891           Changed = true;
1892           break;
1893         }
1894         LLVM_FALLTHROUGH;
1895 
1896       default:
1897         // Only add if it isn't already in the list.
1898         if (SeenOps.insert(Op.getNode()).second)
1899           Ops.push_back(Op);
1900         else
1901           Changed = true;
1902         break;
1903       }
1904     }
1905   }
1906 
1907   // Re-visit inlined Token Factors, to clean them up in case they have been
1908   // removed. Skip the first Token Factor, as this is the current node.
1909   for (unsigned i = 1, e = TFs.size(); i < e; i++)
1910     AddToWorklist(TFs[i]);
1911 
1912   // Remove Nodes that are chained to another node in the list. Do so
1913   // by walking up chains breadth-first, stopping when we've seen
1914   // another operand. In general we must climb to the EntryNode, but we can exit
1915   // early if we find all remaining work is associated with just one operand as
1916   // no further pruning is possible.
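  // For example, in TokenFactor(t1, t2) where t2 is a store whose chain
  // operand (transitively) reaches t1, the t1 operand is redundant: any
  // ordering it provides is already enforced through t2. (Editorial example.)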
1917 
1918   // List of nodes to search through and original Ops from which they originate.
1919   SmallVector<std::pair<SDNode *, unsigned>, 8> Worklist;
1920   SmallVector<unsigned, 8> OpWorkCount; // Count of work for each Op.
1921   SmallPtrSet<SDNode *, 16> SeenChains;
1922   bool DidPruneOps = false;
1923 
1924   unsigned NumLeftToConsider = 0;
1925   for (const SDValue &Op : Ops) {
1926     Worklist.push_back(std::make_pair(Op.getNode(), NumLeftToConsider++));
1927     OpWorkCount.push_back(1);
1928   }
1929 
1930   auto AddToWorklist = [&](unsigned CurIdx, SDNode *Op, unsigned OpNumber) {
1931     // If this is an Op, we can remove the op from the list. Re-mark any
1932     // search associated with it as coming from the current OpNumber.
1933     if (SeenOps.contains(Op)) {
1934       Changed = true;
1935       DidPruneOps = true;
1936       unsigned OrigOpNumber = 0;
1937       while (OrigOpNumber < Ops.size() && Ops[OrigOpNumber].getNode() != Op)
1938         OrigOpNumber++;
1939       assert((OrigOpNumber != Ops.size()) &&
1940              "expected to find TokenFactor Operand");
1941       // Re-mark worklist from OrigOpNumber to OpNumber
1942       for (unsigned i = CurIdx + 1; i < Worklist.size(); ++i) {
1943         if (Worklist[i].second == OrigOpNumber) {
1944           Worklist[i].second = OpNumber;
1945         }
1946       }
1947       OpWorkCount[OpNumber] += OpWorkCount[OrigOpNumber];
1948       OpWorkCount[OrigOpNumber] = 0;
1949       NumLeftToConsider--;
1950     }
1951     // Add if it's a new chain
1952     if (SeenChains.insert(Op).second) {
1953       OpWorkCount[OpNumber]++;
1954       Worklist.push_back(std::make_pair(Op, OpNumber));
1955     }
1956   };
1957 
1958   for (unsigned i = 0; i < Worklist.size() && i < 1024; ++i) {
1959     // We need to consider at least 2 Ops to prune.
1960     if (NumLeftToConsider <= 1)
1961       break;
1962     auto CurNode = Worklist[i].first;
1963     auto CurOpNumber = Worklist[i].second;
1964     assert((OpWorkCount[CurOpNumber] > 0) &&
1965            "Node should not appear in worklist");
1966     switch (CurNode->getOpcode()) {
1967     case ISD::EntryToken:
1968       // Hitting EntryToken is the only way for the search to terminate
1969       // without hitting another operand's search. Prevent us from marking
1970       // this operand considered.
1972       NumLeftToConsider++;
1973       break;
1974     case ISD::TokenFactor:
1975       for (const SDValue &Op : CurNode->op_values())
1976         AddToWorklist(i, Op.getNode(), CurOpNumber);
1977       break;
1978     case ISD::LIFETIME_START:
1979     case ISD::LIFETIME_END:
1980     case ISD::CopyFromReg:
1981     case ISD::CopyToReg:
1982       AddToWorklist(i, CurNode->getOperand(0).getNode(), CurOpNumber);
1983       break;
1984     default:
1985       if (auto *MemNode = dyn_cast<MemSDNode>(CurNode))
1986         AddToWorklist(i, MemNode->getChain().getNode(), CurOpNumber);
1987       break;
1988     }
1989     OpWorkCount[CurOpNumber]--;
1990     if (OpWorkCount[CurOpNumber] == 0)
1991       NumLeftToConsider--;
1992   }
1993 
1994   // If we've changed things around, replace the token factor.
1995   if (Changed) {
1996     SDValue Result;
1997     if (Ops.empty()) {
1998       // The entry token is the only possible outcome.
1999       Result = DAG.getEntryNode();
2000     } else {
2001       if (DidPruneOps) {
2002         SmallVector<SDValue, 8> PrunedOps;
2003         // Keep only the ops that were not reached through another op's chain.
2004         for (const SDValue &Op : Ops) {
2005           if (SeenChains.count(Op.getNode()) == 0)
2006             PrunedOps.push_back(Op);
2007         }
2008         Result = DAG.getTokenFactor(SDLoc(N), PrunedOps);
2009       } else {
2010         Result = DAG.getTokenFactor(SDLoc(N), Ops);
2011       }
2012     }
2013     return Result;
2014   }
2015   return SDValue();
2016 }
2017 
2018 /// MERGE_VALUES can always be eliminated.
2019 SDValue DAGCombiner::visitMERGE_VALUES(SDNode *N) {
2020   WorklistRemover DeadNodes(*this);
2021   // Replacing results may cause a different MERGE_VALUES to suddenly
2022   // be CSE'd with N, and carry its uses with it. Iterate until no
2023   // uses remain, to ensure that the node can be safely deleted.
2024   // First add the users of this node to the work list so that they
2025   // can be tried again once they have new operands.
2026   AddUsersToWorklist(N);
2027   do {
2028     // Do as a single replacement to avoid rewalking use lists.
2029     SmallVector<SDValue, 8> Ops;
2030     for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
2031       Ops.push_back(N->getOperand(i));
2032     DAG.ReplaceAllUsesWith(N, Ops.data());
2033   } while (!N->use_empty());
2034   deleteAndRecombine(N);
2035   return SDValue(N, 0);   // Return N so it doesn't get rechecked!
2036 }
2037 
2038 /// If \p N is a ConstantSDNode with isOpaque() == false, return it cast to a
2039 /// ConstantSDNode pointer else nullptr.
2040 static ConstantSDNode *getAsNonOpaqueConstant(SDValue N) {
2041   ConstantSDNode *Const = dyn_cast<ConstantSDNode>(N);
2042   return Const != nullptr && !Const->isOpaque() ? Const : nullptr;
2043 }
2044 
2045 /// Return true if 'Use' is a load or a store that uses N as its base pointer
2046 /// and that N may be folded in the load / store addressing mode.
2047 static bool canFoldInAddressingMode(SDNode *N, SDNode *Use, SelectionDAG &DAG,
2048                                     const TargetLowering &TLI) {
2049   EVT VT;
2050   unsigned AS;
2051 
2052   if (LoadSDNode *LD = dyn_cast<LoadSDNode>(Use)) {
2053     if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2054       return false;
2055     VT = LD->getMemoryVT();
2056     AS = LD->getAddressSpace();
2057   } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(Use)) {
2058     if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2059       return false;
2060     VT = ST->getMemoryVT();
2061     AS = ST->getAddressSpace();
2062   } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(Use)) {
2063     if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2064       return false;
2065     VT = LD->getMemoryVT();
2066     AS = LD->getAddressSpace();
2067   } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(Use)) {
2068     if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2069       return false;
2070     VT = ST->getMemoryVT();
2071     AS = ST->getAddressSpace();
2072   } else
2073     return false;
2074 
2075   TargetLowering::AddrMode AM;
2076   if (N->getOpcode() == ISD::ADD) {
2077     AM.HasBaseReg = true;
2078     ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(N->getOperand(1));
2079     if (Offset)
2080       // [reg +/- imm]
2081       AM.BaseOffs = Offset->getSExtValue();
2082     else
2083       // [reg +/- reg]
2084       AM.Scale = 1;
2085   } else if (N->getOpcode() == ISD::SUB) {
2086     AM.HasBaseReg = true;
2087     ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(N->getOperand(1));
2088     if (Offset)
2089       // [reg +/- imm]
2090       AM.BaseOffs = -Offset->getSExtValue();
2091     else
2092       // [reg +/- reg]
2093       AM.Scale = 1;
2094   } else
2095     return false;
2096 
2097   return TLI.isLegalAddressingMode(DAG.getDataLayout(), AM,
2098                                    VT.getTypeForEVT(*DAG.getContext()), AS);
2099 }
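// Editorial example: for (add basereg, 16) used as the pointer of an i32
// load, this asks the target whether [reg + 16] is a legal addressing mode
// for i32 in that address space; if so, the add can fold into the access.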
2100 
2101 SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {
2102   assert(TLI.isBinOp(BO->getOpcode()) && BO->getNumValues() == 1 &&
2103          "Unexpected binary operator");
2104 
2105   // Don't do this unless the old select is going away. We want to eliminate the
2106   // binary operator, not replace a binop with a select.
2107   // TODO: Handle ISD::SELECT_CC.
2108   unsigned SelOpNo = 0;
2109   SDValue Sel = BO->getOperand(0);
2110   if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse()) {
2111     SelOpNo = 1;
2112     Sel = BO->getOperand(1);
2113   }
2114 
2115   if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse())
2116     return SDValue();
2117 
2118   SDValue CT = Sel.getOperand(1);
2119   if (!isConstantOrConstantVector(CT, true) &&
2120       !DAG.isConstantFPBuildVectorOrConstantFP(CT))
2121     return SDValue();
2122 
2123   SDValue CF = Sel.getOperand(2);
2124   if (!isConstantOrConstantVector(CF, true) &&
2125       !DAG.isConstantFPBuildVectorOrConstantFP(CF))
2126     return SDValue();
2127 
2128   // Bail out if any constants are opaque because we can't constant fold those.
2129   // The exception is "and" and "or" with either 0 or -1, in which case we can
2130   // propagate non-constant operands into the select. I.e.:
2131   // and (select Cond, 0, -1), X --> select Cond, 0, X
2132   // or X, (select Cond, -1, 0) --> select Cond, -1, X
2133   auto BinOpcode = BO->getOpcode();
2134   bool CanFoldNonConst =
2135       (BinOpcode == ISD::AND || BinOpcode == ISD::OR) &&
2136       (isNullOrNullSplat(CT) || isAllOnesOrAllOnesSplat(CT)) &&
2137       (isNullOrNullSplat(CF) || isAllOnesOrAllOnesSplat(CF));
2138 
2139   SDValue CBO = BO->getOperand(SelOpNo ^ 1);
2140   if (!CanFoldNonConst &&
2141       !isConstantOrConstantVector(CBO, true) &&
2142       !DAG.isConstantFPBuildVectorOrConstantFP(CBO))
2143     return SDValue();
2144 
2145   EVT VT = BO->getValueType(0);
2146 
2147   // We have a select-of-constants followed by a binary operator with a
2148   // constant. Eliminate the binop by pulling the constant math into the select.
2149   // Example: add (select Cond, CT, CF), CBO --> select Cond, CT + CBO, CF + CBO
2150   SDLoc DL(Sel);
2151   SDValue NewCT = SelOpNo ? DAG.getNode(BinOpcode, DL, VT, CBO, CT)
2152                           : DAG.getNode(BinOpcode, DL, VT, CT, CBO);
2153   if (!CanFoldNonConst && !NewCT.isUndef() &&
2154       !isConstantOrConstantVector(NewCT, true) &&
2155       !DAG.isConstantFPBuildVectorOrConstantFP(NewCT))
2156     return SDValue();
2157 
2158   SDValue NewCF = SelOpNo ? DAG.getNode(BinOpcode, DL, VT, CBO, CF)
2159                           : DAG.getNode(BinOpcode, DL, VT, CF, CBO);
2160   if (!CanFoldNonConst && !NewCF.isUndef() &&
2161       !isConstantOrConstantVector(NewCF, true) &&
2162       !DAG.isConstantFPBuildVectorOrConstantFP(NewCF))
2163     return SDValue();
2164 
2165   SDValue SelectOp = DAG.getSelect(DL, VT, Sel.getOperand(0), NewCT, NewCF);
2166   SelectOp->setFlags(BO->getFlags());
2167   return SelectOp;
2168 }
2169 
2170 static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, SelectionDAG &DAG) {
2171   assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
2172          "Expecting add or sub");
2173 
2174   // Match a constant operand and a zext operand for the math instruction:
2175   // add Z, C
2176   // sub C, Z
2177   bool IsAdd = N->getOpcode() == ISD::ADD;
2178   SDValue C = IsAdd ? N->getOperand(1) : N->getOperand(0);
2179   SDValue Z = IsAdd ? N->getOperand(0) : N->getOperand(1);
2180   auto *CN = dyn_cast<ConstantSDNode>(C);
2181   if (!CN || Z.getOpcode() != ISD::ZERO_EXTEND)
2182     return SDValue();
2183 
2184   // Match the zext operand as a setcc of a boolean.
2185   if (Z.getOperand(0).getOpcode() != ISD::SETCC ||
2186       Z.getOperand(0).getValueType() != MVT::i1)
2187     return SDValue();
2188 
2189   // Match the compare as: setcc (X & 1), 0, eq.
2190   SDValue SetCC = Z.getOperand(0);
2191   ISD::CondCode CC = cast<CondCodeSDNode>(SetCC->getOperand(2))->get();
2192   if (CC != ISD::SETEQ || !isNullConstant(SetCC.getOperand(1)) ||
2193       SetCC.getOperand(0).getOpcode() != ISD::AND ||
2194       !isOneConstant(SetCC.getOperand(0).getOperand(1)))
2195     return SDValue();
2196 
2197   // We are adding/subtracting a constant and an inverted low bit. Turn that
2198   // into a subtract/add of the low bit with incremented/decremented constant:
2199   // add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1))
2200   // sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1))
2201   EVT VT = C.getValueType();
2202   SDLoc DL(N);
2203   SDValue LowBit = DAG.getZExtOrTrunc(SetCC.getOperand(0), DL, VT);
2204   SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT) :
2205                        DAG.getConstant(CN->getAPIntValue() - 1, DL, VT);
2206   return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, C1, LowBit);
2207 }
2208 
2209 /// Try to fold a 'not' of a shifted sign-bit with an add/sub of a constant
2210 /// operand into a shift and add with a different constant.
2211 static SDValue foldAddSubOfSignBit(SDNode *N, SelectionDAG &DAG) {
2212   assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
2213          "Expecting add or sub");
2214 
2215   // We need a constant operand for the add/sub, and the other operand is a
2216   // logical shift right: add (srl), C or sub C, (srl).
2217   bool IsAdd = N->getOpcode() == ISD::ADD;
2218   SDValue ConstantOp = IsAdd ? N->getOperand(1) : N->getOperand(0);
2219   SDValue ShiftOp = IsAdd ? N->getOperand(0) : N->getOperand(1);
2220   if (!DAG.isConstantIntBuildVectorOrConstantInt(ConstantOp) ||
2221       ShiftOp.getOpcode() != ISD::SRL)
2222     return SDValue();
2223 
2224   // The shift must be of a 'not' value.
2225   SDValue Not = ShiftOp.getOperand(0);
2226   if (!Not.hasOneUse() || !isBitwiseNot(Not))
2227     return SDValue();
2228 
2229   // The shift must be moving the sign bit to the least-significant-bit.
2230   EVT VT = ShiftOp.getValueType();
2231   SDValue ShAmt = ShiftOp.getOperand(1);
2232   ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt);
2233   if (!ShAmtC || ShAmtC->getAPIntValue() != (VT.getScalarSizeInBits() - 1))
2234     return SDValue();
2235 
2236   // Eliminate the 'not' by adjusting the shift and add/sub constant:
2237   // add (srl (not X), 31), C --> add (sra X, 31), (C + 1)
2238   // sub C, (srl (not X), 31) --> add (srl X, 31), (C - 1)
2239   SDLoc DL(N);
2240   auto ShOpcode = IsAdd ? ISD::SRA : ISD::SRL;
2241   SDValue NewShift = DAG.getNode(ShOpcode, DL, VT, Not.getOperand(0), ShAmt);
2242   if (SDValue NewC =
2243           DAG.FoldConstantArithmetic(IsAdd ? ISD::ADD : ISD::SUB, DL, VT,
2244                                      {ConstantOp, DAG.getConstant(1, DL, VT)}))
2245     return DAG.getNode(ISD::ADD, DL, VT, NewShift, NewC);
2246   return SDValue();
2247 }
2248 
2249 /// Try to fold a node that behaves like an ADD (note that N isn't necessarily
2250 /// an ISD::ADD here, it could for example be an ISD::OR if we know that there
2251 /// are no common bits set in the operands).
2252 SDValue DAGCombiner::visitADDLike(SDNode *N) {
2253   SDValue N0 = N->getOperand(0);
2254   SDValue N1 = N->getOperand(1);
2255   EVT VT = N0.getValueType();
2256   SDLoc DL(N);
2257 
2258   // fold vector ops
2259   if (VT.isVector()) {
2260     if (SDValue FoldedVOp = SimplifyVBinOp(N))
2261       return FoldedVOp;
2262 
2263     // fold (add x, 0) -> x, vector edition
2264     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
2265       return N0;
2266     if (ISD::isConstantSplatVectorAllZeros(N0.getNode()))
2267       return N1;
2268   }
2269 
2270   // fold (add x, undef) -> undef
2271   if (N0.isUndef())
2272     return N0;
2273 
2274   if (N1.isUndef())
2275     return N1;
2276 
2277   if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) {
2278     // canonicalize constant to RHS
2279     if (!DAG.isConstantIntBuildVectorOrConstantInt(N1))
2280       return DAG.getNode(ISD::ADD, DL, VT, N1, N0);
2281     // fold (add c1, c2) -> c1+c2
2282     return DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N0, N1});
2283   }
2284 
2285   // fold (add x, 0) -> x
2286   if (isNullConstant(N1))
2287     return N0;
2288 
2289   if (isConstantOrConstantVector(N1, /* NoOpaque */ true)) {
2290     // fold ((A-c1)+c2) -> (A+(c2-c1))
2291     if (N0.getOpcode() == ISD::SUB &&
2292         isConstantOrConstantVector(N0.getOperand(1), /* NoOpaque */ true)) {
2293       SDValue Sub =
2294           DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N1, N0.getOperand(1)});
2295       assert(Sub && "Constant folding failed");
2296       return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Sub);
2297     }
2298 
2299     // fold ((c1-A)+c2) -> (c1+c2)-A
2300     if (N0.getOpcode() == ISD::SUB &&
2301         isConstantOrConstantVector(N0.getOperand(0), /* NoOpaque */ true)) {
2302       SDValue Add =
2303           DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N1, N0.getOperand(0)});
2304       assert(Add && "Constant folding failed");
2305       return DAG.getNode(ISD::SUB, DL, VT, Add, N0.getOperand(1));
2306     }
2307 
2308     // add (sext i1 X), 1 -> zext (not i1 X)
2309     // We don't transform this pattern:
2310     //   add (zext i1 X), -1 -> sext (not i1 X)
2311     // because most (?) targets generate better code for the zext form.
2312     if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
2313         isOneOrOneSplat(N1)) {
2314       SDValue X = N0.getOperand(0);
2315       if ((!LegalOperations ||
2316            (TLI.isOperationLegal(ISD::XOR, X.getValueType()) &&
2317             TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) &&
2318           X.getScalarValueSizeInBits() == 1) {
2319         SDValue Not = DAG.getNOT(DL, X, X.getValueType());
2320         return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Not);
2321       }
2322     }
2323 
2324     // Fold (add (or x, c0), c1) -> (add x, (c0 + c1)) if (or x, c0) is
2325     // equivalent to (add x, c0).
2326     if (N0.getOpcode() == ISD::OR &&
2327         isConstantOrConstantVector(N0.getOperand(1), /* NoOpaque */ true) &&
2328         DAG.haveNoCommonBitsSet(N0.getOperand(0), N0.getOperand(1))) {
2329       if (SDValue Add0 = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT,
2330                                                     {N1, N0.getOperand(1)}))
2331         return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Add0);
2332     }
2333   }
2334 
2335   if (SDValue NewSel = foldBinOpIntoSelect(N))
2336     return NewSel;
2337 
2338   // reassociate add
2339   if (!reassociationCanBreakAddressingModePattern(ISD::ADD, DL, N0, N1)) {
2340     if (SDValue RADD = reassociateOps(ISD::ADD, DL, N0, N1, N->getFlags()))
2341       return RADD;
2342 
2343     // Reassociate (add (or x, c), y) -> (add (add x, y), c) if (or x, c) is
2344     // equivalent to (add x, c).
2345     auto ReassociateAddOr = [&](SDValue N0, SDValue N1) {
2346       if (N0.getOpcode() == ISD::OR && N0.hasOneUse() &&
2347           isConstantOrConstantVector(N0.getOperand(1), /* NoOpaque */ true) &&
2348           DAG.haveNoCommonBitsSet(N0.getOperand(0), N0.getOperand(1))) {
2349         return DAG.getNode(ISD::ADD, DL, VT,
2350                            DAG.getNode(ISD::ADD, DL, VT, N1, N0.getOperand(0)),
2351                            N0.getOperand(1));
2352       }
2353       return SDValue();
2354     };
2355     if (SDValue Add = ReassociateAddOr(N0, N1))
2356       return Add;
2357     if (SDValue Add = ReassociateAddOr(N1, N0))
2358       return Add;
2359   }
2360   // fold ((0-A) + B) -> B-A
2361   if (N0.getOpcode() == ISD::SUB && isNullOrNullSplat(N0.getOperand(0)))
2362     return DAG.getNode(ISD::SUB, DL, VT, N1, N0.getOperand(1));
2363 
2364   // fold (A + (0-B)) -> A-B
2365   if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0)))
2366     return DAG.getNode(ISD::SUB, DL, VT, N0, N1.getOperand(1));
2367 
2368   // fold (A+(B-A)) -> B
2369   if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(1))
2370     return N1.getOperand(0);
2371 
2372   // fold ((B-A)+A) -> B
2373   if (N0.getOpcode() == ISD::SUB && N1 == N0.getOperand(1))
2374     return N0.getOperand(0);
2375 
2376   // fold ((A-B)+(C-A)) -> (C-B)
2377   if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
2378       N0.getOperand(0) == N1.getOperand(1))
2379     return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
2380                        N0.getOperand(1));
2381 
2382   // fold ((A-B)+(B-C)) -> (A-C)
2383   if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
2384       N0.getOperand(1) == N1.getOperand(0))
2385     return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0),
2386                        N1.getOperand(1));
2387 
2388   // fold (A+(B-(A+C))) to (B-C)
2389   if (N1.getOpcode() == ISD::SUB && N1.getOperand(1).getOpcode() == ISD::ADD &&
2390       N0 == N1.getOperand(1).getOperand(0))
2391     return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
2392                        N1.getOperand(1).getOperand(1));
2393 
2394   // fold (A+(B-(C+A))) to (B-C)
2395   if (N1.getOpcode() == ISD::SUB && N1.getOperand(1).getOpcode() == ISD::ADD &&
2396       N0 == N1.getOperand(1).getOperand(1))
2397     return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
2398                        N1.getOperand(1).getOperand(0));
2399 
2400   // fold (A+((B-A)+or-C)) to (B+or-C)
2401   if ((N1.getOpcode() == ISD::SUB || N1.getOpcode() == ISD::ADD) &&
2402       N1.getOperand(0).getOpcode() == ISD::SUB &&
2403       N0 == N1.getOperand(0).getOperand(1))
2404     return DAG.getNode(N1.getOpcode(), DL, VT, N1.getOperand(0).getOperand(0),
2405                        N1.getOperand(1));
2406 
2407   // fold (A-B)+(C-D) to (A+C)-(B+D) when A or C is constant
2408   if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB) {
2409     SDValue N00 = N0.getOperand(0);
2410     SDValue N01 = N0.getOperand(1);
2411     SDValue N10 = N1.getOperand(0);
2412     SDValue N11 = N1.getOperand(1);
2413 
2414     if (isConstantOrConstantVector(N00) || isConstantOrConstantVector(N10))
2415       return DAG.getNode(ISD::SUB, DL, VT,
2416                          DAG.getNode(ISD::ADD, SDLoc(N0), VT, N00, N10),
2417                          DAG.getNode(ISD::ADD, SDLoc(N1), VT, N01, N11));
2418   }
2419 
2420   // fold (add (umax X, C), -C) --> (usubsat X, C)
2421   if (N0.getOpcode() == ISD::UMAX && hasOperation(ISD::USUBSAT, VT)) {
2422     auto MatchUSUBSAT = [](ConstantSDNode *Max, ConstantSDNode *Op) {
2423       return (!Max && !Op) ||
2424              (Max && Op && Max->getAPIntValue() == (-Op->getAPIntValue()));
2425     };
2426     if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchUSUBSAT,
2427                                   /*AllowUndefs*/ true))
2428       return DAG.getNode(ISD::USUBSAT, DL, VT, N0.getOperand(0),
2429                          N0.getOperand(1));
2430   }
2431 
2432   if (SimplifyDemandedBits(SDValue(N, 0)))
2433     return SDValue(N, 0);
2434 
2435   if (isOneOrOneSplat(N1)) {
2436     // fold (add (xor a, -1), 1) -> (sub 0, a)
2437     if (isBitwiseNot(N0))
2438       return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
2439                          N0.getOperand(0));
2440 
2441     // fold (add (add (xor a, -1), b), 1) -> (sub b, a)
2442     if (N0.getOpcode() == ISD::ADD) {
2443       SDValue A, Xor;
2444 
2445       if (isBitwiseNot(N0.getOperand(0))) {
2446         A = N0.getOperand(1);
2447         Xor = N0.getOperand(0);
2448       } else if (isBitwiseNot(N0.getOperand(1))) {
2449         A = N0.getOperand(0);
2450         Xor = N0.getOperand(1);
2451       }
2452 
2453       if (Xor)
2454         return DAG.getNode(ISD::SUB, DL, VT, A, Xor.getOperand(0));
2455     }
2456 
2457     // Look for:
2458     //   add (add x, y), 1
2459     // And if the target does not like this form then turn it into:
2460     //   sub y, (xor x, -1)
2461     if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.hasOneUse() &&
2462         N0.getOpcode() == ISD::ADD) {
2463       SDValue Not = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(0),
2464                                 DAG.getAllOnesConstant(DL, VT));
2465       return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(1), Not);
2466     }
2467   }
2468 
2469   // (x - y) + -1  ->  add (xor y, -1), x
2470   if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB &&
2471       isAllOnesOrAllOnesSplat(N1)) {
2472     SDValue Xor = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(1), N1);
2473     return DAG.getNode(ISD::ADD, DL, VT, Xor, N0.getOperand(0));
2474   }
2475 
2476   if (SDValue Combined = visitADDLikeCommutative(N0, N1, N))
2477     return Combined;
2478 
2479   if (SDValue Combined = visitADDLikeCommutative(N1, N0, N))
2480     return Combined;
2481 
2482   return SDValue();
2483 }
2484 
2485 SDValue DAGCombiner::visitADD(SDNode *N) {
2486   SDValue N0 = N->getOperand(0);
2487   SDValue N1 = N->getOperand(1);
2488   EVT VT = N0.getValueType();
2489   SDLoc DL(N);
2490 
2491   if (SDValue Combined = visitADDLike(N))
2492     return Combined;
2493 
2494   if (SDValue V = foldAddSubBoolOfMaskedVal(N, DAG))
2495     return V;
2496 
2497   if (SDValue V = foldAddSubOfSignBit(N, DAG))
2498     return V;
2499 
2500   // fold (a+b) -> (a|b) iff a and b share no bits.
2501   if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
2502       DAG.haveNoCommonBitsSet(N0, N1))
2503     return DAG.getNode(ISD::OR, DL, VT, N0, N1);
2504 
2505   // Fold (add (vscale * C0), (vscale * C1)) to (vscale * (C0 + C1)).
2506   if (N0.getOpcode() == ISD::VSCALE && N1.getOpcode() == ISD::VSCALE) {
2507     const APInt &C0 = N0->getConstantOperandAPInt(0);
2508     const APInt &C1 = N1->getConstantOperandAPInt(0);
2509     return DAG.getVScale(DL, VT, C0 + C1);
2510   }
2511 
2512   // fold a+vscale(c1)+vscale(c2) -> a+vscale(c1+c2)
2513   if ((N0.getOpcode() == ISD::ADD) &&
2514       (N0.getOperand(1).getOpcode() == ISD::VSCALE) &&
2515       (N1.getOpcode() == ISD::VSCALE)) {
2516     const APInt &VS0 = N0.getOperand(1)->getConstantOperandAPInt(0);
2517     const APInt &VS1 = N1->getConstantOperandAPInt(0);
2518     SDValue VS = DAG.getVScale(DL, VT, VS0 + VS1);
2519     return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), VS);
2520   }
2521 
2522   // Fold (add step_vector(c1), step_vector(c2)) to step_vector(c1+c2).
2523   if (N0.getOpcode() == ISD::STEP_VECTOR &&
2524       N1.getOpcode() == ISD::STEP_VECTOR) {
2525     const APInt &C0 = N0->getConstantOperandAPInt(0);
2526     const APInt &C1 = N1->getConstantOperandAPInt(0);
2527     APInt NewStep = C0 + C1;
2528     return DAG.getStepVector(DL, VT, NewStep);
2529   }
2530 
2531   // Fold a + step_vector(c1) + step_vector(c2) to a + step_vector(c1+c2)
2532   if ((N0.getOpcode() == ISD::ADD) &&
2533       (N0.getOperand(1).getOpcode() == ISD::STEP_VECTOR) &&
2534       (N1.getOpcode() == ISD::STEP_VECTOR)) {
2535     const APInt &SV0 = N0.getOperand(1)->getConstantOperandAPInt(0);
2536     const APInt &SV1 = N1->getConstantOperandAPInt(0);
2537     APInt NewStep = SV0 + SV1;
2538     SDValue SV = DAG.getStepVector(DL, VT, NewStep);
2539     return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), SV);
2540   }
2541 
2542   return SDValue();
2543 }
2544 
2545 SDValue DAGCombiner::visitADDSAT(SDNode *N) {
2546   unsigned Opcode = N->getOpcode();
2547   SDValue N0 = N->getOperand(0);
2548   SDValue N1 = N->getOperand(1);
2549   EVT VT = N0.getValueType();
2550   SDLoc DL(N);
2551 
2552   // fold vector ops
2553   if (VT.isVector()) {
2554     // TODO SimplifyVBinOp
2555 
2556     // fold (add_sat x, 0) -> x, vector edition
2557     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
2558       return N0;
2559     if (ISD::isConstantSplatVectorAllZeros(N0.getNode()))
2560       return N1;
2561   }
2562 
2563   // fold (add_sat x, undef) -> -1
2564   if (N0.isUndef() || N1.isUndef())
2565     return DAG.getAllOnesConstant(DL, VT);
2566 
2567   if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) {
2568     // canonicalize constant to RHS
2569     if (!DAG.isConstantIntBuildVectorOrConstantInt(N1))
2570       return DAG.getNode(Opcode, DL, VT, N1, N0);
2571     // fold (add_sat c1, c2) -> c3
2572     return DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1});
2573   }
2574 
2575   // fold (add_sat x, 0) -> x
2576   if (isNullConstant(N1))
2577     return N0;
2578 
2579   // If it cannot overflow, transform into an add.
2580   if (Opcode == ISD::UADDSAT)
2581     if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never)
2582       return DAG.getNode(ISD::ADD, DL, VT, N0, N1);
2583 
2584   return SDValue();
2585 }
2586 
2587 static SDValue getAsCarry(const TargetLowering &TLI, SDValue V) {
2588   bool Masked = false;
2589 
2590   // First, peel away TRUNCATE/ZERO_EXTEND/AND nodes due to legalization.
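  // For example, (and (trunc (uaddo X, Y):1), 1) peels down to the uaddo's
  // carry result, with Masked recording the AND with 1.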
2591   while (true) {
2592     if (V.getOpcode() == ISD::TRUNCATE || V.getOpcode() == ISD::ZERO_EXTEND) {
2593       V = V.getOperand(0);
2594       continue;
2595     }
2596 
2597     if (V.getOpcode() == ISD::AND && isOneConstant(V.getOperand(1))) {
2598       Masked = true;
2599       V = V.getOperand(0);
2600       continue;
2601     }
2602 
2603     break;
2604   }
2605 
2606   // If this is not a carry, return.
2607   if (V.getResNo() != 1)
2608     return SDValue();
2609 
2610   if (V.getOpcode() != ISD::ADDCARRY && V.getOpcode() != ISD::SUBCARRY &&
2611       V.getOpcode() != ISD::UADDO && V.getOpcode() != ISD::USUBO)
2612     return SDValue();
2613 
2614   EVT VT = V.getNode()->getValueType(0);
2615   if (!TLI.isOperationLegalOrCustom(V.getOpcode(), VT))
2616     return SDValue();
2617 
2618   // If the result is masked, then no matter what kind of bool it is we can
2619   // return it. If it isn't, then we need to make sure the bool is known to
2620   // be either 0 or 1 and not some other value.
2621   if (Masked ||
2622       TLI.getBooleanContents(V.getValueType()) ==
2623           TargetLoweringBase::ZeroOrOneBooleanContent)
2624     return V;
2625 
2626   return SDValue();
2627 }
2628 
2629 /// Given the operands of an add/sub operation, see if the 2nd operand is a
2630 /// masked 0/1 whose source operand is actually known to be 0/-1. If so, invert
2631 /// the opcode and bypass the mask operation.
2632 static SDValue foldAddSubMasked1(bool IsAdd, SDValue N0, SDValue N1,
2633                                  SelectionDAG &DAG, const SDLoc &DL) {
2634   if (N1.getOpcode() != ISD::AND || !isOneOrOneSplat(N1->getOperand(1)))
2635     return SDValue();
2636 
2637   EVT VT = N0.getValueType();
2638   if (DAG.ComputeNumSignBits(N1.getOperand(0)) != VT.getScalarSizeInBits())
2639     return SDValue();
2640 
2641   // add N0, (and (AssertSext X, i1), 1) --> sub N0, X
2642   // sub N0, (and (AssertSext X, i1), 1) --> add N0, X
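  // This holds because X is known to be all sign bits (0 or -1): when X is -1
  // the masked value is 1 and N0 + 1 == N0 - (-1); when X is 0 both sides are
  // simply N0.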
2643   return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, N0, N1.getOperand(0));
2644 }
2645 
2646 /// Helper for doing combines based on N0 and N1 being added to each other.
2647 SDValue DAGCombiner::visitADDLikeCommutative(SDValue N0, SDValue N1,
2648                                           SDNode *LocReference) {
2649   EVT VT = N0.getValueType();
2650   SDLoc DL(LocReference);
2651 
2652   // fold (add x, shl(0 - y, n)) -> sub(x, shl(y, n))
2653   if (N1.getOpcode() == ISD::SHL && N1.getOperand(0).getOpcode() == ISD::SUB &&
2654       isNullOrNullSplat(N1.getOperand(0).getOperand(0)))
2655     return DAG.getNode(ISD::SUB, DL, VT, N0,
2656                        DAG.getNode(ISD::SHL, DL, VT,
2657                                    N1.getOperand(0).getOperand(1),
2658                                    N1.getOperand(1)));
2659 
2660   if (SDValue V = foldAddSubMasked1(true, N0, N1, DAG, DL))
2661     return V;
2662 
2663   // Look for:
2664   //   add (add x, 1), y
2665   // And if the target does not like this form then turn into:
2666   //   sub y, (xor x, -1)
2667   if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.hasOneUse() &&
2668       N0.getOpcode() == ISD::ADD && isOneOrOneSplat(N0.getOperand(1))) {
2669     SDValue Not = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(0),
2670                               DAG.getAllOnesConstant(DL, VT));
2671     return DAG.getNode(ISD::SUB, DL, VT, N1, Not);
2672   }
2673 
2674   // Hoist one-use subtraction by non-opaque constant:
2675   //   (x - C) + y  ->  (x + y) - C
2676   // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
2677   if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB &&
2678       isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
2679     SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), N1);
2680     return DAG.getNode(ISD::SUB, DL, VT, Add, N0.getOperand(1));
2681   }
2682   // Hoist one-use subtraction from non-opaque constant:
2683   //   (C - x) + y  ->  (y - x) + C
2684   if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB &&
2685       isConstantOrConstantVector(N0.getOperand(0), /*NoOpaques=*/true)) {
2686     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N1, N0.getOperand(1));
2687     return DAG.getNode(ISD::ADD, DL, VT, Sub, N0.getOperand(0));
2688   }
2689 
2690   // If the target's bool is represented as 0/1, prefer to make this 'sub 0/1'
2691   // rather than 'add 0/-1' (the zext should get folded).
2692   // add (sext i1 Y), X --> sub X, (zext i1 Y)
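  // ((sext i1 Y) is 0 or -1 while (zext i1 Y) is 0 or 1, so adding the former
  // is the same as subtracting the latter.)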
2693   if (N0.getOpcode() == ISD::SIGN_EXTEND &&
2694       N0.getOperand(0).getScalarValueSizeInBits() == 1 &&
2695       TLI.getBooleanContents(VT) == TargetLowering::ZeroOrOneBooleanContent) {
2696     SDValue ZExt = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0));
2697     return DAG.getNode(ISD::SUB, DL, VT, N1, ZExt);
2698   }
2699 
2700   // add X, (sextinreg Y i1) -> sub X, (and Y 1)
2701   if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
2702     VTSDNode *TN = cast<VTSDNode>(N1.getOperand(1));
2703     if (TN->getVT() == MVT::i1) {
2704       SDValue ZExt = DAG.getNode(ISD::AND, DL, VT, N1.getOperand(0),
2705                                  DAG.getConstant(1, DL, VT));
2706       return DAG.getNode(ISD::SUB, DL, VT, N0, ZExt);
2707     }
2708   }
2709 
2710   // (add X, (addcarry Y, 0, Carry)) -> (addcarry X, Y, Carry)
2711   if (N1.getOpcode() == ISD::ADDCARRY && isNullConstant(N1.getOperand(1)) &&
2712       N1.getResNo() == 0)
2713     return DAG.getNode(ISD::ADDCARRY, DL, N1->getVTList(),
2714                        N0, N1.getOperand(0), N1.getOperand(2));
2715 
2716   // (add X, Carry) -> (addcarry X, 0, Carry)
2717   if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT))
2718     if (SDValue Carry = getAsCarry(TLI, N1))
2719       return DAG.getNode(ISD::ADDCARRY, DL,
2720                          DAG.getVTList(VT, Carry.getValueType()), N0,
2721                          DAG.getConstant(0, DL, VT), Carry);
2722 
2723   return SDValue();
2724 }
2725 
2726 SDValue DAGCombiner::visitADDC(SDNode *N) {
2727   SDValue N0 = N->getOperand(0);
2728   SDValue N1 = N->getOperand(1);
2729   EVT VT = N0.getValueType();
2730   SDLoc DL(N);
2731 
2732   // If the flag result is dead, turn this into an ADD.
2733   if (!N->hasAnyUseOfValue(1))
2734     return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
2735                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
2736 
2737   // canonicalize constant to RHS.
2738   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
2739   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
2740   if (N0C && !N1C)
2741     return DAG.getNode(ISD::ADDC, DL, N->getVTList(), N1, N0);
2742 
2743   // fold (addc x, 0) -> x + no carry out
2744   if (isNullConstant(N1))
2745     return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE,
2746                                         DL, MVT::Glue));
2747 
2748   // If it cannot overflow, transform into an add.
2749   if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never)
2750     return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
2751                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
2752 
2753   return SDValue();
2754 }
2755 
2756 /**
2757  * Flips a boolean if it is cheaper to compute. If the Force parameter is set,
2758  * then the flip also occurs if computing the inverse has the same cost.
2759  * This function returns an empty SDValue in case it cannot flip the boolean
2760  * without increasing the cost of the computation. If you want to flip a boolean
2761  * no matter what, use DAG.getLogicalNOT.
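 * For example, with ZeroOrOneBooleanContent, if V is (xor X, 1) then X is
 * already the flipped boolean, so X is returned without creating new nodes.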
2762  */
2763 static SDValue extractBooleanFlip(SDValue V, SelectionDAG &DAG,
2764                                   const TargetLowering &TLI,
2765                                   bool Force) {
2766   if (Force && isa<ConstantSDNode>(V))
2767     return DAG.getLogicalNOT(SDLoc(V), V, V.getValueType());
2768 
2769   if (V.getOpcode() != ISD::XOR)
2770     return SDValue();
2771 
2772   ConstantSDNode *Const = isConstOrConstSplat(V.getOperand(1), false);
2773   if (!Const)
2774     return SDValue();
2775 
2776   EVT VT = V.getValueType();
2777 
2778   bool IsFlip = false;
2779   switch(TLI.getBooleanContents(VT)) {
2780     case TargetLowering::ZeroOrOneBooleanContent:
2781       IsFlip = Const->isOne();
2782       break;
2783     case TargetLowering::ZeroOrNegativeOneBooleanContent:
2784       IsFlip = Const->isAllOnesValue();
2785       break;
2786     case TargetLowering::UndefinedBooleanContent:
2787       IsFlip = (Const->getAPIntValue() & 0x01) == 1;
2788       break;
2789   }
2790 
2791   if (IsFlip)
2792     return V.getOperand(0);
2793   if (Force)
2794     return DAG.getLogicalNOT(SDLoc(V), V, V.getValueType());
2795   return SDValue();
2796 }
2797 
2798 SDValue DAGCombiner::visitADDO(SDNode *N) {
2799   SDValue N0 = N->getOperand(0);
2800   SDValue N1 = N->getOperand(1);
2801   EVT VT = N0.getValueType();
2802   bool IsSigned = (ISD::SADDO == N->getOpcode());
2803 
2804   EVT CarryVT = N->getValueType(1);
2805   SDLoc DL(N);
2806 
2807   // If the flag result is dead, turn this into an ADD.
2808   if (!N->hasAnyUseOfValue(1))
2809     return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
2810                      DAG.getUNDEF(CarryVT));
2811 
2812   // canonicalize constant to RHS.
2813   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
2814       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
2815     return DAG.getNode(N->getOpcode(), DL, N->getVTList(), N1, N0);
2816 
2817   // fold (addo x, 0) -> x + no carry out
2818   if (isNullOrNullSplat(N1))
2819     return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT));
2820 
2821   if (!IsSigned) {
2822     // If it cannot overflow, transform into an add.
2823     if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never)
2824       return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
2825                        DAG.getConstant(0, DL, CarryVT));
2826 
2827     // fold (uaddo (xor a, -1), 1) -> (usubo 0, a) and flip carry.
2828     if (isBitwiseNot(N0) && isOneOrOneSplat(N1)) {
2829       SDValue Sub = DAG.getNode(ISD::USUBO, DL, N->getVTList(),
2830                                 DAG.getConstant(0, DL, VT), N0.getOperand(0));
2831       return CombineTo(
2832           N, Sub, DAG.getLogicalNOT(DL, Sub.getValue(1), Sub->getValueType(1)));
2833     }
2834 
2835     if (SDValue Combined = visitUADDOLike(N0, N1, N))
2836       return Combined;
2837 
2838     if (SDValue Combined = visitUADDOLike(N1, N0, N))
2839       return Combined;
2840   }
2841 
2842   return SDValue();
2843 }
2844 
2845 SDValue DAGCombiner::visitUADDOLike(SDValue N0, SDValue N1, SDNode *N) {
2846   EVT VT = N0.getValueType();
2847   if (VT.isVector())
2848     return SDValue();
2849 
2850   // (uaddo X, (addcarry Y, 0, Carry)) -> (addcarry X, Y, Carry)
2851   // If Y + 1 cannot overflow.
2852   if (N1.getOpcode() == ISD::ADDCARRY && isNullConstant(N1.getOperand(1))) {
2853     SDValue Y = N1.getOperand(0);
2854     SDValue One = DAG.getConstant(1, SDLoc(N), Y.getValueType());
2855     if (DAG.computeOverflowKind(Y, One) == SelectionDAG::OFK_Never)
2856       return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(), N0, Y,
2857                          N1.getOperand(2));
2858   }
2859 
2860   // (uaddo X, Carry) -> (addcarry X, 0, Carry)
2861   if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT))
2862     if (SDValue Carry = getAsCarry(TLI, N1))
2863       return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(), N0,
2864                          DAG.getConstant(0, SDLoc(N), VT), Carry);
2865 
2866   return SDValue();
2867 }
2868 
2869 SDValue DAGCombiner::visitADDE(SDNode *N) {
2870   SDValue N0 = N->getOperand(0);
2871   SDValue N1 = N->getOperand(1);
2872   SDValue CarryIn = N->getOperand(2);
2873 
2874   // canonicalize constant to RHS
2875   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
2876   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
2877   if (N0C && !N1C)
2878     return DAG.getNode(ISD::ADDE, SDLoc(N), N->getVTList(),
2879                        N1, N0, CarryIn);
2880 
2881   // fold (adde x, y, false) -> (addc x, y)
2882   if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
2883     return DAG.getNode(ISD::ADDC, SDLoc(N), N->getVTList(), N0, N1);
2884 
2885   return SDValue();
2886 }
2887 
2888 SDValue DAGCombiner::visitADDCARRY(SDNode *N) {
2889   SDValue N0 = N->getOperand(0);
2890   SDValue N1 = N->getOperand(1);
2891   SDValue CarryIn = N->getOperand(2);
2892   SDLoc DL(N);
2893 
2894   // canonicalize constant to RHS
2895   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
2896   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
2897   if (N0C && !N1C)
2898     return DAG.getNode(ISD::ADDCARRY, DL, N->getVTList(), N1, N0, CarryIn);
2899 
2900   // fold (addcarry x, y, false) -> (uaddo x, y)
2901   if (isNullConstant(CarryIn)) {
2902     if (!LegalOperations ||
2903         TLI.isOperationLegalOrCustom(ISD::UADDO, N->getValueType(0)))
2904       return DAG.getNode(ISD::UADDO, DL, N->getVTList(), N0, N1);
2905   }
2906 
2907   // fold (addcarry 0, 0, X) -> (and (ext/trunc X), 1) and no carry.
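  // (The sum 0 + 0 + X is just the carry-in bit itself, and adding two zeros
  // can never produce a carry out.)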
2908   if (isNullConstant(N0) && isNullConstant(N1)) {
2909     EVT VT = N0.getValueType();
2910     EVT CarryVT = CarryIn.getValueType();
2911     SDValue CarryExt = DAG.getBoolExtOrTrunc(CarryIn, DL, VT, CarryVT);
2912     AddToWorklist(CarryExt.getNode());
2913     return CombineTo(N, DAG.getNode(ISD::AND, DL, VT, CarryExt,
2914                                     DAG.getConstant(1, DL, VT)),
2915                      DAG.getConstant(0, DL, CarryVT));
2916   }
2917 
2918   if (SDValue Combined = visitADDCARRYLike(N0, N1, CarryIn, N))
2919     return Combined;
2920 
2921   if (SDValue Combined = visitADDCARRYLike(N1, N0, CarryIn, N))
2922     return Combined;
2923 
2924   return SDValue();
2925 }
2926 
2927 SDValue DAGCombiner::visitSADDO_CARRY(SDNode *N) {
2928   SDValue N0 = N->getOperand(0);
2929   SDValue N1 = N->getOperand(1);
2930   SDValue CarryIn = N->getOperand(2);
2931   SDLoc DL(N);
2932 
2933   // canonicalize constant to RHS
2934   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
2935   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
2936   if (N0C && !N1C)
2937     return DAG.getNode(ISD::SADDO_CARRY, DL, N->getVTList(), N1, N0, CarryIn);
2938 
2939   // fold (saddo_carry x, y, false) -> (saddo x, y)
2940   if (isNullConstant(CarryIn)) {
2941     if (!LegalOperations ||
2942         TLI.isOperationLegalOrCustom(ISD::SADDO, N->getValueType(0)))
2943       return DAG.getNode(ISD::SADDO, DL, N->getVTList(), N0, N1);
2944   }
2945 
2946   return SDValue();
2947 }
2948 
2949 /**
2950  * If we are facing some sort of diamond carry propagation pattern try to
2951  * break it up to generate something like:
2952  *   (addcarry X, 0, (addcarry A, B, Z):Carry)
2953  *
2954  * The end result is usually an increase in the number of operations required,
2955  * but since the carry is now linearized, other transforms can optimize the DAG.
2956  *
2957  * Patterns typically look something like
2958  *            (uaddo A, B)
2959  *             /       \
2960  *          Carry      Sum
2961  *            |          \
2962  *            | (addcarry *, 0, Z)
2963  *            |       /
2964  *             \   Carry
2965  *              |   /
2966  * (addcarry X, *, *)
2967  *
2968  * But numerous variations exist. Our goal is to identify A, B, X and Z and
2969  * produce a combine with a single path for carry propagation.
2970  */
2971 static SDValue combineADDCARRYDiamond(DAGCombiner &Combiner, SelectionDAG &DAG,
2972                                       SDValue X, SDValue Carry0, SDValue Carry1,
2973                                       SDNode *N) {
2974   if (Carry1.getResNo() != 1 || Carry0.getResNo() != 1)
2975     return SDValue();
2976   if (Carry1.getOpcode() != ISD::UADDO)
2977     return SDValue();
2978 
2979   SDValue Z;
2980 
2981   /**
2982    * First look for a suitable Z. It will present itself in the form of
2983    * (addcarry Y, 0, Z) or its equivalent (uaddo Y, 1) for Z=true
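   * ((uaddo Y, 1) adds a known carry of one, i.e. it behaves like
   * (addcarry Y, 0, 1), so in that case Z is the constant true.)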
2984    */
2985   if (Carry0.getOpcode() == ISD::ADDCARRY &&
2986       isNullConstant(Carry0.getOperand(1))) {
2987     Z = Carry0.getOperand(2);
2988   } else if (Carry0.getOpcode() == ISD::UADDO &&
2989              isOneConstant(Carry0.getOperand(1))) {
2990     EVT VT = Combiner.getSetCCResultType(Carry0.getValueType());
2991     Z = DAG.getConstant(1, SDLoc(Carry0.getOperand(1)), VT);
2992   } else {
2993     // We couldn't find a suitable Z.
2994     return SDValue();
2995   }
2996 
2997 
2998   auto cancelDiamond = [&](SDValue A, SDValue B) {
2999     SDLoc DL(N);
3000     SDValue NewY = DAG.getNode(ISD::ADDCARRY, DL, Carry0->getVTList(), A, B, Z);
3001     Combiner.AddToWorklist(NewY.getNode());
3002     return DAG.getNode(ISD::ADDCARRY, DL, N->getVTList(), X,
3003                        DAG.getConstant(0, DL, X.getValueType()),
3004                        NewY.getValue(1));
3005   };
3006 
3007   /**
3008    *      (uaddo A, B)
3009    *           |
3010    *          Sum
3011    *           |
3012    * (addcarry *, 0, Z)
3013    */
3014   if (Carry0.getOperand(0) == Carry1.getValue(0)) {
3015     return cancelDiamond(Carry1.getOperand(0), Carry1.getOperand(1));
3016   }
3017 
3018   /**
3019    * (addcarry A, 0, Z)
3020    *         |
3021    *        Sum
3022    *         |
3023    *  (uaddo *, B)
3024    */
3025   if (Carry1.getOperand(0) == Carry0.getValue(0)) {
3026     return cancelDiamond(Carry0.getOperand(0), Carry1.getOperand(1));
3027   }
3028 
3029   if (Carry1.getOperand(1) == Carry0.getValue(0)) {
3030     return cancelDiamond(Carry1.getOperand(0), Carry0.getOperand(0));
3031   }
3032 
3033   return SDValue();
3034 }
3035 
3036 // If we are facing some sort of diamond carry/borrow in/out pattern try to
3037 // match patterns like:
3038 //
3039 //          (uaddo A, B)            CarryIn
3040 //            |  \                     |
3041 //            |   \                    |
3042 //    PartialSum   PartialCarryOutX   /
3043 //            |        |             /
3044 //            |    ____|____________/
3045 //            |   /    |
3046 //     (uaddo *, *)    \________
3047 //       |  \                   \
3048 //       |   \                   |
3049 //       |    PartialCarryOutY   |
3050 //       |        \              |
3051 //       |         \            /
3052 //   AddCarrySum    |    ______/
3053 //                  |   /
3054 //   CarryOut = (or *, *)
3055 //
3056 // And generate ADDCARRY (or SUBCARRY) with two result values:
3057 //
3058 //    {AddCarrySum, CarryOut} = (addcarry A, B, CarryIn)
3059 //
3060 // Our goal is to identify A, B, and CarryIn and produce ADDCARRY/SUBCARRY with
3061 // a single path for carry/borrow out propagation:
3062 static SDValue combineCarryDiamond(DAGCombiner &Combiner, SelectionDAG &DAG,
3063                                    const TargetLowering &TLI, SDValue Carry0,
3064                                    SDValue Carry1, SDNode *N) {
3065   if (Carry0.getResNo() != 1 || Carry1.getResNo() != 1)
3066     return SDValue();
3067   unsigned Opcode = Carry0.getOpcode();
3068   if (Opcode != Carry1.getOpcode())
3069     return SDValue();
3070   if (Opcode != ISD::UADDO && Opcode != ISD::USUBO)
3071     return SDValue();
3072 
3073   // Canonicalize the add/sub of A and B as Carry0 and the add/sub of the
3074   // carry/borrow in as Carry1. (The top and middle uaddo nodes respectively in
3075   // the above ASCII art.)
3076   if (Carry1.getOperand(0) != Carry0.getValue(0) &&
3077       Carry1.getOperand(1) != Carry0.getValue(0))
3078     std::swap(Carry0, Carry1);
3079   if (Carry1.getOperand(0) != Carry0.getValue(0) &&
3080       Carry1.getOperand(1) != Carry0.getValue(0))
3081     return SDValue();
3082 
3083   // The carry-in value must be on the right-hand side for subtraction.
3084   unsigned CarryInOperandNum =
3085       Carry1.getOperand(0) == Carry0.getValue(0) ? 1 : 0;
3086   if (Opcode == ISD::USUBO && CarryInOperandNum != 1)
3087     return SDValue();
3088   SDValue CarryIn = Carry1.getOperand(CarryInOperandNum);
3089 
3090   unsigned NewOp = Opcode == ISD::UADDO ? ISD::ADDCARRY : ISD::SUBCARRY;
3091   if (!TLI.isOperationLegalOrCustom(NewOp, Carry0.getValue(0).getValueType()))
3092     return SDValue();
3093 
3094   // Verify that the carry/borrow in is plausibly a carry/borrow bit.
3095   // TODO: make getAsCarry() aware of how partial carries are merged.
3096   if (CarryIn.getOpcode() != ISD::ZERO_EXTEND)
3097     return SDValue();
3098   CarryIn = CarryIn.getOperand(0);
3099   if (CarryIn.getValueType() != MVT::i1)
3100     return SDValue();
3101 
3102   SDLoc DL(N);
3103   SDValue Merged =
3104       DAG.getNode(NewOp, DL, Carry1->getVTList(), Carry0.getOperand(0),
3105                   Carry0.getOperand(1), CarryIn);
3106 
3107   // Note that because we have proven that the result of the UADDO/USUBO of A
3108   // and B feeds into the UADDO/USUBO that does the carry/borrow in, we can
3109   // prove that if the first UADDO/USUBO overflows, the second UADDO/USUBO
3110   // cannot. For example, consider 8-bit numbers where 0xFF is the maximum
3111   // value.
3112   //
3113   //   0xFF + 0xFF == 0xFE with carry but 0xFE + 1 does not carry
3114   //   0x00 - 0xFF == 1 with a carry/borrow but 1 - 1 == 0 (no carry/borrow)
3115   //
3116   // This is important because it means that OR and XOR can be used to merge
3117   // carry flags; and that AND can return a constant zero.
3118   //
3119   // TODO: match other operations that can merge flags (ADD, etc)
3120   DAG.ReplaceAllUsesOfValueWith(Carry1.getValue(0), Merged.getValue(0));
3121   if (N->getOpcode() == ISD::AND)
3122     return DAG.getConstant(0, DL, MVT::i1);
3123   return Merged.getValue(1);
3124 }
3125 
3126 SDValue DAGCombiner::visitADDCARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
3127                                        SDNode *N) {
3128   // fold (addcarry (xor a, -1), b, c) -> (subcarry b, a, !c) and flip carry.
3129   if (isBitwiseNot(N0))
3130     if (SDValue NotC = extractBooleanFlip(CarryIn, DAG, TLI, true)) {
3131       SDLoc DL(N);
3132       SDValue Sub = DAG.getNode(ISD::SUBCARRY, DL, N->getVTList(), N1,
3133                                 N0.getOperand(0), NotC);
3134       return CombineTo(
3135           N, Sub, DAG.getLogicalNOT(DL, Sub.getValue(1), Sub->getValueType(1)));
3136     }
3137 
3138   // Iff the flag result is dead:
3139   // (addcarry (add|uaddo X, Y), 0, Carry) -> (addcarry X, Y, Carry)
3140   // Don't do this if the Carry comes from the uaddo. It won't remove the uaddo
3141   // or the dependency between the instructions.
3142   if ((N0.getOpcode() == ISD::ADD ||
3143        (N0.getOpcode() == ISD::UADDO && N0.getResNo() == 0 &&
3144         N0.getValue(1) != CarryIn)) &&
3145       isNullConstant(N1) && !N->hasAnyUseOfValue(1))
3146     return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(),
3147                        N0.getOperand(0), N0.getOperand(1), CarryIn);
3148 
3149   /**
3150    * When one of the addcarry arguments is itself a carry, we may be facing
3151    * a diamond carry propagation, in which case we try to transform the DAG
3152    * to ensure linear carry propagation if that is possible.
3153    */
3154   if (auto Y = getAsCarry(TLI, N1)) {
3155     // Because both are carries, Y and Z can be swapped.
3156     if (auto R = combineADDCARRYDiamond(*this, DAG, N0, Y, CarryIn, N))
3157       return R;
3158     if (auto R = combineADDCARRYDiamond(*this, DAG, N0, CarryIn, Y, N))
3159       return R;
3160   }
3161 
3162   return SDValue();
3163 }
3164 
3165 // Attempt to create a USUBSAT(LHS, RHS) node with DstVT, performing a
3166 // clamp/truncation if necessary.
3167 static SDValue getTruncatedUSUBSAT(EVT DstVT, EVT SrcVT, SDValue LHS,
3168                                    SDValue RHS, SelectionDAG &DAG,
3169                                    const SDLoc &DL) {
3170   assert(DstVT.getScalarSizeInBits() <= SrcVT.getScalarSizeInBits() &&
3171          "Illegal truncation");
3172 
3173   if (DstVT == SrcVT)
3174     return DAG.getNode(ISD::USUBSAT, DL, DstVT, LHS, RHS);
3175 
3176   // If the LHS is zero-extended then we can perform the USUBSAT as DstVT by
3177   // clamping RHS.
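  // For example, for an i32 -> i16 truncation: if the top 16 bits of LHS are
  // known zero, clamping RHS to 0xFFFF with UMIN and truncating both operands
  // yields an i16 usubsat with the same result as the original i32 one.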
3178   APInt UpperBits = APInt::getBitsSetFrom(SrcVT.getScalarSizeInBits(),
3179                                           DstVT.getScalarSizeInBits());
3180   if (!DAG.MaskedValueIsZero(LHS, UpperBits))
3181     return SDValue();
3182 
3183   SDValue SatLimit =
3184       DAG.getConstant(APInt::getLowBitsSet(SrcVT.getScalarSizeInBits(),
3185                                            DstVT.getScalarSizeInBits()),
3186                       DL, SrcVT);
3187   RHS = DAG.getNode(ISD::UMIN, DL, SrcVT, RHS, SatLimit);
3188   RHS = DAG.getNode(ISD::TRUNCATE, DL, DstVT, RHS);
3189   LHS = DAG.getNode(ISD::TRUNCATE, DL, DstVT, LHS);
3190   return DAG.getNode(ISD::USUBSAT, DL, DstVT, LHS, RHS);
3191 }
3192 
3193 // Try to find umax(a,b) - b or a - umin(a,b) patterns that may be converted to
3194 // usubsat(a,b), optionally as a truncated type.
3195 SDValue DAGCombiner::foldSubToUSubSat(EVT DstVT, SDNode *N) {
3196   if (N->getOpcode() != ISD::SUB ||
3197       !(!LegalOperations || hasOperation(ISD::USUBSAT, DstVT)))
3198     return SDValue();
3199 
3200   EVT SubVT = N->getValueType(0);
3201   SDValue Op0 = N->getOperand(0);
3202   SDValue Op1 = N->getOperand(1);
3203 
3204   // Try to find umax(a,b) - b or a - umin(a,b) patterns
3205   // that may be converted to usubsat(a,b).
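  // For example, umax(a,b) - b is a - b when a >= b and 0 otherwise, which is
  // exactly the behaviour of usubsat(a,b).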
3206   if (Op0.getOpcode() == ISD::UMAX && Op0.hasOneUse()) {
3207     SDValue MaxLHS = Op0.getOperand(0);
3208     SDValue MaxRHS = Op0.getOperand(1);
3209     if (MaxLHS == Op1)
3210       return getTruncatedUSUBSAT(DstVT, SubVT, MaxRHS, Op1, DAG, SDLoc(N));
3211     if (MaxRHS == Op1)
3212       return getTruncatedUSUBSAT(DstVT, SubVT, MaxLHS, Op1, DAG, SDLoc(N));
3213   }
3214 
3215   if (Op1.getOpcode() == ISD::UMIN && Op1.hasOneUse()) {
3216     SDValue MinLHS = Op1.getOperand(0);
3217     SDValue MinRHS = Op1.getOperand(1);
3218     if (MinLHS == Op0)
3219       return getTruncatedUSUBSAT(DstVT, SubVT, Op0, MinRHS, DAG, SDLoc(N));
3220     if (MinRHS == Op0)
3221       return getTruncatedUSUBSAT(DstVT, SubVT, Op0, MinLHS, DAG, SDLoc(N));
3222   }
3223 
3224   // sub(a,trunc(umin(zext(a),b))) -> usubsat(a,trunc(umin(b,SatLimit)))
3225   if (Op1.getOpcode() == ISD::TRUNCATE &&
3226       Op1.getOperand(0).getOpcode() == ISD::UMIN &&
3227       Op1.getOperand(0).hasOneUse()) {
3228     SDValue MinLHS = Op1.getOperand(0).getOperand(0);
3229     SDValue MinRHS = Op1.getOperand(0).getOperand(1);
3230     if (MinLHS.getOpcode() == ISD::ZERO_EXTEND && MinLHS.getOperand(0) == Op0)
3231       return getTruncatedUSUBSAT(DstVT, MinLHS.getValueType(), MinLHS, MinRHS,
3232                                  DAG, SDLoc(N));
3233     if (MinRHS.getOpcode() == ISD::ZERO_EXTEND && MinRHS.getOperand(0) == Op0)
3234       return getTruncatedUSUBSAT(DstVT, MinLHS.getValueType(), MinRHS, MinLHS,
3235                                  DAG, SDLoc(N));
3236   }
3237 
3238   return SDValue();
3239 }
3240 
3241 // Since it may not be valid to emit a fold to zero for vector initializers
3242 // check if we can before folding.
3243 static SDValue tryFoldToZero(const SDLoc &DL, const TargetLowering &TLI, EVT VT,
3244                              SelectionDAG &DAG, bool LegalOperations) {
3245   if (!VT.isVector())
3246     return DAG.getConstant(0, DL, VT);
3247   if (!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT))
3248     return DAG.getConstant(0, DL, VT);
3249   return SDValue();
3250 }
3251 
3252 SDValue DAGCombiner::visitSUB(SDNode *N) {
3253   SDValue N0 = N->getOperand(0);
3254   SDValue N1 = N->getOperand(1);
3255   EVT VT = N0.getValueType();
3256   SDLoc DL(N);
3257 
3258   // fold vector ops
3259   if (VT.isVector()) {
3260     if (SDValue FoldedVOp = SimplifyVBinOp(N))
3261       return FoldedVOp;
3262 
3263     // fold (sub x, 0) -> x, vector edition
3264     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
3265       return N0;
3266   }
3267 
3268   // fold (sub x, x) -> 0
3269   // FIXME: Refactor this and xor and other similar operations together.
3270   if (N0 == N1)
3271     return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
3272 
3273   // fold (sub c1, c2) -> c3
3274   if (SDValue C = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0, N1}))
3275     return C;
3276 
3277   if (SDValue NewSel = foldBinOpIntoSelect(N))
3278     return NewSel;
3279 
3280   ConstantSDNode *N1C = getAsNonOpaqueConstant(N1);
3281 
3282   // fold (sub x, c) -> (add x, -c)
3283   if (N1C) {
3284     return DAG.getNode(ISD::ADD, DL, VT, N0,
3285                        DAG.getConstant(-N1C->getAPIntValue(), DL, VT));
3286   }
3287 
3288   if (isNullOrNullSplat(N0)) {
3289     unsigned BitWidth = VT.getScalarSizeInBits();
3290     // Right-shifting everything out but the sign bit followed by negation is
3291     // the same as flipping arithmetic/logical shift type without the negation:
3292     // -(X >>u 31) -> (X >>s 31)
3293     // -(X >>s 31) -> (X >>u 31)
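    // ((X >>u 31) is 0 or 1, so its negation is 0 or -1, which is exactly
    // (X >>s 31); the same reasoning applies in the other direction.)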
3294     if (N1->getOpcode() == ISD::SRA || N1->getOpcode() == ISD::SRL) {
3295       ConstantSDNode *ShiftAmt = isConstOrConstSplat(N1.getOperand(1));
3296       if (ShiftAmt && ShiftAmt->getAPIntValue() == (BitWidth - 1)) {
3297         auto NewSh = N1->getOpcode() == ISD::SRA ? ISD::SRL : ISD::SRA;
3298         if (!LegalOperations || TLI.isOperationLegal(NewSh, VT))
3299           return DAG.getNode(NewSh, DL, VT, N1.getOperand(0), N1.getOperand(1));
3300       }
3301     }
3302 
3303     // 0 - X --> 0 if the sub is NUW.
3304     if (N->getFlags().hasNoUnsignedWrap())
3305       return N0;
3306 
3307     if (DAG.MaskedValueIsZero(N1, ~APInt::getSignMask(BitWidth))) {
3308       // N1 is either 0 or the minimum signed value. If the sub is NSW, then
3309       // N1 must be 0 because negating the minimum signed value is undefined.
3310       if (N->getFlags().hasNoSignedWrap())
3311         return N0;
3312 
3313       // 0 - X --> X if X is 0 or the minimum signed value.
3314       return N1;
3315     }
3316 
3317     // Convert 0 - abs(x).
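    // (Passing true for the IsNegative argument asks expandABS to emit the
    // negated form directly.)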
3318     SDValue Result;
3319     if (N1->getOpcode() == ISD::ABS &&
3320         !TLI.isOperationLegalOrCustom(ISD::ABS, VT) &&
3321         TLI.expandABS(N1.getNode(), Result, DAG, true))
3322       return Result;
3323 
3324     // Fold neg(splat(neg(x))) -> splat(x)
3325     if (VT.isVector()) {
3326       SDValue N1S = DAG.getSplatValue(N1, true);
3327       if (N1S && N1S.getOpcode() == ISD::SUB &&
3328           isNullConstant(N1S.getOperand(0))) {
3329         if (VT.isScalableVector())
3330           return DAG.getSplatVector(VT, DL, N1S.getOperand(1));
3331         return DAG.getSplatBuildVector(VT, DL, N1S.getOperand(1));
3332       }
3333     }
3334   }
3335 
3336   // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1)
3337   if (isAllOnesOrAllOnesSplat(N0))
3338     return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
3339 
3340   // fold (A - (0-B)) -> A+B
3341   if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0)))
3342     return DAG.getNode(ISD::ADD, DL, VT, N0, N1.getOperand(1));
3343 
3344   // fold A-(A-B) -> B
3345   if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(0))
3346     return N1.getOperand(1);
3347 
3348   // fold (A+B)-A -> B
3349   if (N0.getOpcode() == ISD::ADD && N0.getOperand(0) == N1)
3350     return N0.getOperand(1);
3351 
3352   // fold (A+B)-B -> A
3353   if (N0.getOpcode() == ISD::ADD && N0.getOperand(1) == N1)
3354     return N0.getOperand(0);
3355 
3356   // fold (A+C1)-C2 -> A+(C1-C2)
3357   if (N0.getOpcode() == ISD::ADD &&
3358       isConstantOrConstantVector(N1, /* NoOpaques */ true) &&
3359       isConstantOrConstantVector(N0.getOperand(1), /* NoOpaques */ true)) {
3360     SDValue NewC =
3361         DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0.getOperand(1), N1});
3362     assert(NewC && "Constant folding failed");
3363     return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), NewC);
3364   }
3365 
3366   // fold C2-(A+C1) -> (C2-C1)-A
3367   if (N1.getOpcode() == ISD::ADD) {
3368     SDValue N11 = N1.getOperand(1);
3369     if (isConstantOrConstantVector(N0, /* NoOpaques */ true) &&
3370         isConstantOrConstantVector(N11, /* NoOpaques */ true)) {
3371       SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0, N11});
3372       assert(NewC && "Constant folding failed");
3373       return DAG.getNode(ISD::SUB, DL, VT, NewC, N1.getOperand(0));
3374     }
3375   }
3376 
3377   // fold (A-C1)-C2 -> A-(C1+C2)
3378   if (N0.getOpcode() == ISD::SUB &&
3379       isConstantOrConstantVector(N1, /* NoOpaques */ true) &&
3380       isConstantOrConstantVector(N0.getOperand(1), /* NoOpaques */ true)) {
3381     SDValue NewC =
3382         DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N0.getOperand(1), N1});
3383     assert(NewC && "Constant folding failed");
3384     return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), NewC);
3385   }
3386 
3387   // fold (c1-A)-c2 -> (c1-c2)-A
3388   if (N0.getOpcode() == ISD::SUB &&
3389       isConstantOrConstantVector(N1, /* NoOpaques */ true) &&
3390       isConstantOrConstantVector(N0.getOperand(0), /* NoOpaques */ true)) {
3391     SDValue NewC =
3392         DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0.getOperand(0), N1});
3393     assert(NewC && "Constant folding failed");
3394     return DAG.getNode(ISD::SUB, DL, VT, NewC, N0.getOperand(1));
3395   }
3396 
3397   // fold ((A+(B+or-C))-B) -> A+or-C
3398   if (N0.getOpcode() == ISD::ADD &&
3399       (N0.getOperand(1).getOpcode() == ISD::SUB ||
3400        N0.getOperand(1).getOpcode() == ISD::ADD) &&
3401       N0.getOperand(1).getOperand(0) == N1)
3402     return DAG.getNode(N0.getOperand(1).getOpcode(), DL, VT, N0.getOperand(0),
3403                        N0.getOperand(1).getOperand(1));
3404 
3405   // fold ((A+(C+B))-B) -> A+C
3406   if (N0.getOpcode() == ISD::ADD && N0.getOperand(1).getOpcode() == ISD::ADD &&
3407       N0.getOperand(1).getOperand(1) == N1)
3408     return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0),
3409                        N0.getOperand(1).getOperand(0));
3410 
3411   // fold ((A-(B-C))-C) -> A-B
3412   if (N0.getOpcode() == ISD::SUB && N0.getOperand(1).getOpcode() == ISD::SUB &&
3413       N0.getOperand(1).getOperand(1) == N1)
3414     return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0),
3415                        N0.getOperand(1).getOperand(0));
3416 
3417   // fold (A-(B-C)) -> A+(C-B)
3418   if (N1.getOpcode() == ISD::SUB && N1.hasOneUse())
3419     return DAG.getNode(ISD::ADD, DL, VT, N0,
3420                        DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(1),
3421                                    N1.getOperand(0)));
3422 
3423   // A - (A & B)  ->  A & (~B)
3424   if (N1.getOpcode() == ISD::AND) {
3425     SDValue A = N1.getOperand(0);
3426     SDValue B = N1.getOperand(1);
3427     if (A != N0)
3428       std::swap(A, B);
3429     if (A == N0 &&
3430         (N1.hasOneUse() || isConstantOrConstantVector(B, /*NoOpaques=*/true))) {
3431       SDValue InvB =
3432           DAG.getNode(ISD::XOR, DL, VT, B, DAG.getAllOnesConstant(DL, VT));
3433       return DAG.getNode(ISD::AND, DL, VT, A, InvB);
3434     }
3435   }
3436 
3437   // fold (X - (-Y * Z)) -> (X + (Y * Z))
3438   if (N1.getOpcode() == ISD::MUL && N1.hasOneUse()) {
3439     if (N1.getOperand(0).getOpcode() == ISD::SUB &&
3440         isNullOrNullSplat(N1.getOperand(0).getOperand(0))) {
3441       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT,
3442                                 N1.getOperand(0).getOperand(1),
3443                                 N1.getOperand(1));
3444       return DAG.getNode(ISD::ADD, DL, VT, N0, Mul);
3445     }
3446     if (N1.getOperand(1).getOpcode() == ISD::SUB &&
3447         isNullOrNullSplat(N1.getOperand(1).getOperand(0))) {
3448       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT,
3449                                 N1.getOperand(0),
3450                                 N1.getOperand(1).getOperand(1));
3451       return DAG.getNode(ISD::ADD, DL, VT, N0, Mul);
3452     }
3453   }
3454 
3455   // If either operand of a sub is undef, the result is undef
3456   if (N0.isUndef())
3457     return N0;
3458   if (N1.isUndef())
3459     return N1;
3460 
3461   if (SDValue V = foldAddSubBoolOfMaskedVal(N, DAG))
3462     return V;
3463 
3464   if (SDValue V = foldAddSubOfSignBit(N, DAG))
3465     return V;
3466 
3467   if (SDValue V = foldAddSubMasked1(false, N0, N1, DAG, SDLoc(N)))
3468     return V;
3469 
3470   if (SDValue V = foldSubToUSubSat(VT, N))
3471     return V;
3472 
3473   // (x - y) - 1  ->  add (xor y, -1), x
3474   if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB && isOneOrOneSplat(N1)) {
3475     SDValue Xor = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(1),
3476                               DAG.getAllOnesConstant(DL, VT));
3477     return DAG.getNode(ISD::ADD, DL, VT, Xor, N0.getOperand(0));
3478   }
3479 
3480   // Look for:
3481   //   sub y, (xor x, -1)
3482   // And if the target does not like this form then turn into:
3483   //   add (add x, y), 1
3484   if (TLI.preferIncOfAddToSubOfNot(VT) && N1.hasOneUse() && isBitwiseNot(N1)) {
3485     SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, N1.getOperand(0));
3486     return DAG.getNode(ISD::ADD, DL, VT, Add, DAG.getConstant(1, DL, VT));
3487   }
3488 
3489   // Hoist one-use addition by non-opaque constant:
3490   //   (x + C) - y  ->  (x - y) + C
3491   if (N0.hasOneUse() && N0.getOpcode() == ISD::ADD &&
3492       isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
3493     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), N1);
3494     return DAG.getNode(ISD::ADD, DL, VT, Sub, N0.getOperand(1));
3495   }
3496   // y - (x + C)  ->  (y - x) - C
3497   if (N1.hasOneUse() && N1.getOpcode() == ISD::ADD &&
3498       isConstantOrConstantVector(N1.getOperand(1), /*NoOpaques=*/true)) {
3499     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, N1.getOperand(0));
3500     return DAG.getNode(ISD::SUB, DL, VT, Sub, N1.getOperand(1));
3501   }
3502   // (x - C) - y  ->  (x - y) - C
3503   // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
3504   if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB &&
3505       isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
3506     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), N1);
3507     return DAG.getNode(ISD::SUB, DL, VT, Sub, N0.getOperand(1));
3508   }
3509   // (C - x) - y  ->  C - (x + y)
3510   if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB &&
3511       isConstantOrConstantVector(N0.getOperand(0), /*NoOpaques=*/true)) {
3512     SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1), N1);
3513     return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), Add);
3514   }
3515 
3516   // If the target's bool is represented as 0/-1, prefer to make this 'add 0/-1'
3517   // rather than 'sub 0/1' (the sext should get folded).
3518   // sub X, (zext i1 Y) --> add X, (sext i1 Y)
3519   if (N1.getOpcode() == ISD::ZERO_EXTEND &&
3520       N1.getOperand(0).getScalarValueSizeInBits() == 1 &&
3521       TLI.getBooleanContents(VT) ==
3522           TargetLowering::ZeroOrNegativeOneBooleanContent) {
3523     SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N1.getOperand(0));
3524     return DAG.getNode(ISD::ADD, DL, VT, N0, SExt);
3525   }
3526 
3527   // fold Y = sra (X, size(X)-1); sub (xor (X, Y), Y) -> (abs X)
3528   if (TLI.isOperationLegalOrCustom(ISD::ABS, VT)) {
3529     if (N0.getOpcode() == ISD::XOR && N1.getOpcode() == ISD::SRA) {
3530       SDValue X0 = N0.getOperand(0), X1 = N0.getOperand(1);
3531       SDValue S0 = N1.getOperand(0);
3532       if ((X0 == S0 && X1 == N1) || (X0 == N1 && X1 == S0))
3533         if (ConstantSDNode *C = isConstOrConstSplat(N1.getOperand(1)))
3534           if (C->getAPIntValue() == (VT.getScalarSizeInBits() - 1))
3535             return DAG.getNode(ISD::ABS, SDLoc(N), VT, S0);
3536     }
3537   }
3538 
3539   // If the relocation model supports it, consider symbol offsets.
3540   if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(N0))
3541     if (!LegalOperations && TLI.isOffsetFoldingLegal(GA)) {
3542       // fold (sub Sym, c) -> Sym-c
3543       if (N1C && GA->getOpcode() == ISD::GlobalAddress)
3544         return DAG.getGlobalAddress(GA->getGlobal(), SDLoc(N1C), VT,
3545                                     GA->getOffset() -
3546                                         (uint64_t)N1C->getSExtValue());
3547       // fold (sub Sym+c1, Sym+c2) -> c1-c2
3548       if (GlobalAddressSDNode *GB = dyn_cast<GlobalAddressSDNode>(N1))
3549         if (GA->getGlobal() == GB->getGlobal())
3550           return DAG.getConstant((uint64_t)GA->getOffset() - GB->getOffset(),
3551                                  DL, VT);
3552     }
3553 
3554   // sub X, (sextinreg Y i1) -> add X, (and Y 1)
3555   if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
3556     VTSDNode *TN = cast<VTSDNode>(N1.getOperand(1));
3557     if (TN->getVT() == MVT::i1) {
3558       SDValue ZExt = DAG.getNode(ISD::AND, DL, VT, N1.getOperand(0),
3559                                  DAG.getConstant(1, DL, VT));
3560       return DAG.getNode(ISD::ADD, DL, VT, N0, ZExt);
3561     }
3562   }
3563 
3564   // canonicalize (sub X, (vscale * C)) to (add X, (vscale * -C))
3565   if (N1.getOpcode() == ISD::VSCALE) {
3566     const APInt &IntVal = N1.getConstantOperandAPInt(0);
3567     return DAG.getNode(ISD::ADD, DL, VT, N0, DAG.getVScale(DL, VT, -IntVal));
3568   }
3569 
3570   // canonicalize (sub X, step_vector(C)) to (add X, step_vector(-C))
3571   if (N1.getOpcode() == ISD::STEP_VECTOR && N1.hasOneUse()) {
3572     APInt NewStep = -N1.getConstantOperandAPInt(0);
3573     return DAG.getNode(ISD::ADD, DL, VT, N0,
3574                        DAG.getStepVector(DL, VT, NewStep));
3575   }
3576 
3577   // Prefer an add for more folding potential and possibly better codegen:
3578   // sub N0, (lshr N10, width-1) --> add N0, (ashr N10, width-1)
3579   if (!LegalOperations && N1.getOpcode() == ISD::SRL && N1.hasOneUse()) {
3580     SDValue ShAmt = N1.getOperand(1);
3581     ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt);
3582     if (ShAmtC &&
3583         ShAmtC->getAPIntValue() == (N1.getScalarValueSizeInBits() - 1)) {
3584       SDValue SRA = DAG.getNode(ISD::SRA, DL, VT, N1.getOperand(0), ShAmt);
3585       return DAG.getNode(ISD::ADD, DL, VT, N0, SRA);
3586     }
3587   }
3588 
3589   if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT)) {
3590     // (sub Carry, X)  ->  (addcarry (sub 0, X), 0, Carry)
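    // (Carry - X == (0 - X) + 0 + carry-in, which is what ADDCARRY computes.)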
3591     if (SDValue Carry = getAsCarry(TLI, N0)) {
3592       SDValue X = N1;
3593       SDValue Zero = DAG.getConstant(0, DL, VT);
3594       SDValue NegX = DAG.getNode(ISD::SUB, DL, VT, Zero, X);
3595       return DAG.getNode(ISD::ADDCARRY, DL,
3596                          DAG.getVTList(VT, Carry.getValueType()), NegX, Zero,
3597                          Carry);
3598     }
3599   }
3600 
3601   return SDValue();
3602 }
3603 
3604 SDValue DAGCombiner::visitSUBSAT(SDNode *N) {
3605   SDValue N0 = N->getOperand(0);
3606   SDValue N1 = N->getOperand(1);
3607   EVT VT = N0.getValueType();
3608   SDLoc DL(N);
3609 
3610   // fold vector ops
3611   if (VT.isVector()) {
3612     // TODO SimplifyVBinOp
3613 
3614     // fold (sub_sat x, 0) -> x, vector edition
3615     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
3616       return N0;
3617   }
3618 
3619   // fold (sub_sat x, undef) -> 0
3620   if (N0.isUndef() || N1.isUndef())
3621     return DAG.getConstant(0, DL, VT);
3622 
3623   // fold (sub_sat x, x) -> 0
3624   if (N0 == N1)
3625     return DAG.getConstant(0, DL, VT);
3626 
3627   // fold (sub_sat c1, c2) -> c3
3628   if (SDValue C = DAG.FoldConstantArithmetic(N->getOpcode(), DL, VT, {N0, N1}))
3629     return C;
3630 
3631   // fold (sub_sat x, 0) -> x
3632   if (isNullConstant(N1))
3633     return N0;
3634 
3635   return SDValue();
3636 }
3637 
3638 SDValue DAGCombiner::visitSUBC(SDNode *N) {
3639   SDValue N0 = N->getOperand(0);
3640   SDValue N1 = N->getOperand(1);
3641   EVT VT = N0.getValueType();
3642   SDLoc DL(N);
3643 
3644   // If the flag result is dead, turn this into a SUB.
3645   if (!N->hasAnyUseOfValue(1))
3646     return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
3647                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3648 
3649   // fold (subc x, x) -> 0 + no borrow
3650   if (N0 == N1)
3651     return CombineTo(N, DAG.getConstant(0, DL, VT),
3652                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3653 
3654   // fold (subc x, 0) -> x + no borrow
3655   if (isNullConstant(N1))
3656     return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3657 
3658   // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1) + no borrow
3659   if (isAllOnesConstant(N0))
3660     return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0),
3661                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3662 
3663   return SDValue();
3664 }
3665 
3666 SDValue DAGCombiner::visitSUBO(SDNode *N) {
3667   SDValue N0 = N->getOperand(0);
3668   SDValue N1 = N->getOperand(1);
3669   EVT VT = N0.getValueType();
3670   bool IsSigned = (ISD::SSUBO == N->getOpcode());
3671 
3672   EVT CarryVT = N->getValueType(1);
3673   SDLoc DL(N);
3674 
3675   // If the flag result is dead, turn this into a SUB.
3676   if (!N->hasAnyUseOfValue(1))
3677     return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
3678                      DAG.getUNDEF(CarryVT));
3679 
3680   // fold (subo x, x) -> 0 + no borrow
3681   if (N0 == N1)
3682     return CombineTo(N, DAG.getConstant(0, DL, VT),
3683                      DAG.getConstant(0, DL, CarryVT));
3684 
3685   ConstantSDNode *N1C = getAsNonOpaqueConstant(N1);
3686 
3687   // fold (subo x, c) -> (addo x, -c)
3688   if (IsSigned && N1C && !N1C->getAPIntValue().isMinSignedValue()) {
3689     return DAG.getNode(ISD::SADDO, DL, N->getVTList(), N0,
3690                        DAG.getConstant(-N1C->getAPIntValue(), DL, VT));
3691   }
3692 
3693   // fold (subo x, 0) -> x + no borrow
3694   if (isNullOrNullSplat(N1))
3695     return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT));
3696 
3697   // Canonicalize (usubo -1, x) -> ~x, i.e. (xor x, -1) + no borrow
3698   if (!IsSigned && isAllOnesOrAllOnesSplat(N0))
3699     return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0),
3700                      DAG.getConstant(0, DL, CarryVT));
3701 
3702   return SDValue();
3703 }
3704 
3705 SDValue DAGCombiner::visitSUBE(SDNode *N) {
3706   SDValue N0 = N->getOperand(0);
3707   SDValue N1 = N->getOperand(1);
3708   SDValue CarryIn = N->getOperand(2);
3709 
3710   // fold (sube x, y, false) -> (subc x, y)
3711   if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
3712     return DAG.getNode(ISD::SUBC, SDLoc(N), N->getVTList(), N0, N1);
3713 
3714   return SDValue();
3715 }
3716 
3717 SDValue DAGCombiner::visitSUBCARRY(SDNode *N) {
3718   SDValue N0 = N->getOperand(0);
3719   SDValue N1 = N->getOperand(1);
3720   SDValue CarryIn = N->getOperand(2);
3721 
3722   // fold (subcarry x, y, false) -> (usubo x, y)
3723   if (isNullConstant(CarryIn)) {
3724     if (!LegalOperations ||
3725         TLI.isOperationLegalOrCustom(ISD::USUBO, N->getValueType(0)))
3726       return DAG.getNode(ISD::USUBO, SDLoc(N), N->getVTList(), N0, N1);
3727   }
3728 
3729   return SDValue();
3730 }
3731 
3732 SDValue DAGCombiner::visitSSUBO_CARRY(SDNode *N) {
3733   SDValue N0 = N->getOperand(0);
3734   SDValue N1 = N->getOperand(1);
3735   SDValue CarryIn = N->getOperand(2);
3736 
3737   // fold (ssubo_carry x, y, false) -> (ssubo x, y)
3738   if (isNullConstant(CarryIn)) {
3739     if (!LegalOperations ||
3740         TLI.isOperationLegalOrCustom(ISD::SSUBO, N->getValueType(0)))
3741       return DAG.getNode(ISD::SSUBO, SDLoc(N), N->getVTList(), N0, N1);
3742   }
3743 
3744   return SDValue();
3745 }
3746 
3747 // Notice that "mulfix" can be any of SMULFIX, SMULFIXSAT, UMULFIX and
3748 // UMULFIXSAT here.
3749 SDValue DAGCombiner::visitMULFIX(SDNode *N) {
3750   SDValue N0 = N->getOperand(0);
3751   SDValue N1 = N->getOperand(1);
3752   SDValue Scale = N->getOperand(2);
3753   EVT VT = N0.getValueType();
3754 
3755   // fold (mulfix x, undef, scale) -> 0
3756   if (N0.isUndef() || N1.isUndef())
3757     return DAG.getConstant(0, SDLoc(N), VT);
3758 
3759   // Canonicalize constant to RHS (vector doesn't have to splat)
3760   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
3761      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
3762     return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0, Scale);
3763 
3764   // fold (mulfix x, 0, scale) -> 0
3765   if (isNullConstant(N1))
3766     return DAG.getConstant(0, SDLoc(N), VT);
3767 
3768   return SDValue();
3769 }
3770 
3771 SDValue DAGCombiner::visitMUL(SDNode *N) {
3772   SDValue N0 = N->getOperand(0);
3773   SDValue N1 = N->getOperand(1);
3774   EVT VT = N0.getValueType();
3775 
3776   // fold (mul x, undef) -> 0
3777   if (N0.isUndef() || N1.isUndef())
3778     return DAG.getConstant(0, SDLoc(N), VT);
3779 
3780   bool N1IsConst = false;
3781   bool N1IsOpaqueConst = false;
3782   APInt ConstValue1;
3783 
3784   // fold vector ops
3785   if (VT.isVector()) {
3786     if (SDValue FoldedVOp = SimplifyVBinOp(N))
3787       return FoldedVOp;
3788 
3789     N1IsConst = ISD::isConstantSplatVector(N1.getNode(), ConstValue1);
3790     assert((!N1IsConst ||
3791             ConstValue1.getBitWidth() == VT.getScalarSizeInBits()) &&
3792            "Splat APInt should be element width");
3793   } else {
3794     N1IsConst = isa<ConstantSDNode>(N1);
3795     if (N1IsConst) {
3796       ConstValue1 = cast<ConstantSDNode>(N1)->getAPIntValue();
3797       N1IsOpaqueConst = cast<ConstantSDNode>(N1)->isOpaque();
3798     }
3799   }
3800 
3801   // fold (mul c1, c2) -> c1*c2
3802   if (SDValue C = DAG.FoldConstantArithmetic(ISD::MUL, SDLoc(N), VT, {N0, N1}))
3803     return C;
3804 
3805   // canonicalize constant to RHS (vector doesn't have to splat)
3806   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
3807      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
3808     return DAG.getNode(ISD::MUL, SDLoc(N), VT, N1, N0);
3809 
3810   // fold (mul x, 0) -> 0
3811   if (N1IsConst && ConstValue1.isNullValue())
3812     return N1;
3813 
3814   // fold (mul x, 1) -> x
3815   if (N1IsConst && ConstValue1.isOneValue())
3816     return N0;
3817 
3818   if (SDValue NewSel = foldBinOpIntoSelect(N))
3819     return NewSel;
3820 
3821   // fold (mul x, -1) -> 0-x
3822   if (N1IsConst && ConstValue1.isAllOnesValue()) {
3823     SDLoc DL(N);
3824     return DAG.getNode(ISD::SUB, DL, VT,
3825                        DAG.getConstant(0, DL, VT), N0);
3826   }
3827 
3828   // fold (mul x, (1 << c)) -> x << c
3829   if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
3830       DAG.isKnownToBeAPowerOfTwo(N1) &&
3831       (!VT.isVector() || Level <= AfterLegalizeVectorOps)) {
3832     SDLoc DL(N);
3833     SDValue LogBase2 = BuildLogBase2(N1, DL);
3834     EVT ShiftVT = getShiftAmountTy(N0.getValueType());
3835     SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
3836     return DAG.getNode(ISD::SHL, DL, VT, N0, Trunc);
3837   }
3838 
3839   // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
3840   if (N1IsConst && !N1IsOpaqueConst && (-ConstValue1).isPowerOf2()) {
3841     unsigned Log2Val = (-ConstValue1).logBase2();
3842     SDLoc DL(N);
3843     // FIXME: If the input is something that is easily negated (e.g. a
3844     // single-use add), we should put the negate there.
3845     return DAG.getNode(ISD::SUB, DL, VT,
3846                        DAG.getConstant(0, DL, VT),
3847                        DAG.getNode(ISD::SHL, DL, VT, N0,
3848                             DAG.getConstant(Log2Val, DL,
3849                                       getShiftAmountTy(N0.getValueType()))));
3850   }
3851 
3852   // Try to transform:
3853   // (1) multiply-by-(power-of-2 +/- 1) into shift and add/sub.
3854   // mul x, (2^N + 1) --> add (shl x, N), x
3855   // mul x, (2^N - 1) --> sub (shl x, N), x
3856   // Examples: x * 33 --> (x << 5) + x
3857   //           x * 15 --> (x << 4) - x
3858   //           x * -33 --> -((x << 5) + x)
3859   //           x * -15 --> -((x << 4) - x) ; this reduces --> x - (x << 4)
3860   // (2) multiply-by-(power-of-2 +/- power-of-2) into shifts and add/sub.
3861   // mul x, (2^N + 2^M) --> (add (shl x, N), (shl x, M))
3862   // mul x, (2^N - 2^M) --> (sub (shl x, N), (shl x, M))
3863   // Examples: x * 0x8800 --> (x << 15) + (x << 11)
3864   //           x * 0xf800 --> (x << 16) - (x << 11)
3865   //           x * -0x8800 --> -((x << 15) + (x << 11))
3866   //           x * -0xf800 --> -((x << 16) - (x << 11)) ; (x << 11) - (x << 16)
3867   if (N1IsConst && TLI.decomposeMulByConstant(*DAG.getContext(), VT, N1)) {
3868     // TODO: We could handle more general decomposition of any constant by
3869     //       having the target set a limit on number of ops and making a
3870     //       callback to determine that sequence (similar to sqrt expansion).
3871     unsigned MathOp = ISD::DELETED_NODE;
3872     APInt MulC = ConstValue1.abs();
3873     // The constant `2` should be treated as (2^0 + 1).
3874     unsigned TZeros = MulC == 2 ? 0 : MulC.countTrailingZeros();
3875     MulC.lshrInPlace(TZeros);
3876     if ((MulC - 1).isPowerOf2())
3877       MathOp = ISD::ADD;
3878     else if ((MulC + 1).isPowerOf2())
3879       MathOp = ISD::SUB;
3880 
3881     if (MathOp != ISD::DELETED_NODE) {
3882       unsigned ShAmt =
3883           MathOp == ISD::ADD ? (MulC - 1).logBase2() : (MulC + 1).logBase2();
3884       ShAmt += TZeros;
3885       assert(ShAmt < VT.getScalarSizeInBits() &&
3886              "multiply-by-constant generated out of bounds shift");
3887       SDLoc DL(N);
3888       SDValue Shl =
3889           DAG.getNode(ISD::SHL, DL, VT, N0, DAG.getConstant(ShAmt, DL, VT));
3890       SDValue R =
3891           TZeros ? DAG.getNode(MathOp, DL, VT, Shl,
3892                                DAG.getNode(ISD::SHL, DL, VT, N0,
3893                                            DAG.getConstant(TZeros, DL, VT)))
3894                  : DAG.getNode(MathOp, DL, VT, Shl, N0);
3895       if (ConstValue1.isNegative())
3896         R = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), R);
3897       return R;
3898     }
3899   }
3900 
3901   // (mul (shl X, c1), c2) -> (mul X, c2 << c1)
3902   if (N0.getOpcode() == ISD::SHL &&
3903       isConstantOrConstantVector(N1, /* NoOpaques */ true) &&
3904       isConstantOrConstantVector(N0.getOperand(1), /* NoOpaques */ true)) {
3905     SDValue C3 = DAG.getNode(ISD::SHL, SDLoc(N), VT, N1, N0.getOperand(1));
3906     if (isConstantOrConstantVector(C3))
3907       return DAG.getNode(ISD::MUL, SDLoc(N), VT, N0.getOperand(0), C3);
3908   }
3909 
3910   // Change (mul (shl X, C), Y) -> (shl (mul X, Y), C) when the shift has one
3911   // use.
3912   {
3913     SDValue Sh(nullptr, 0), Y(nullptr, 0);
3914 
3915     // Check for both (mul (shl X, C), Y)  and  (mul Y, (shl X, C)).
3916     if (N0.getOpcode() == ISD::SHL &&
3917         isConstantOrConstantVector(N0.getOperand(1)) &&
3918         N0.getNode()->hasOneUse()) {
3919       Sh = N0; Y = N1;
3920     } else if (N1.getOpcode() == ISD::SHL &&
3921                isConstantOrConstantVector(N1.getOperand(1)) &&
3922                N1.getNode()->hasOneUse()) {
3923       Sh = N1; Y = N0;
3924     }
3925 
3926     if (Sh.getNode()) {
3927       SDValue Mul = DAG.getNode(ISD::MUL, SDLoc(N), VT, Sh.getOperand(0), Y);
3928       return DAG.getNode(ISD::SHL, SDLoc(N), VT, Mul, Sh.getOperand(1));
3929     }
3930   }
3931 
3932   // fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2)
3933   if (DAG.isConstantIntBuildVectorOrConstantInt(N1) &&
3934       N0.getOpcode() == ISD::ADD &&
3935       DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1)) &&
3936       isMulAddWithConstProfitable(N, N0, N1))
3937       return DAG.getNode(ISD::ADD, SDLoc(N), VT,
3938                          DAG.getNode(ISD::MUL, SDLoc(N0), VT,
3939                                      N0.getOperand(0), N1),
3940                          DAG.getNode(ISD::MUL, SDLoc(N1), VT,
3941                                      N0.getOperand(1), N1));
3942 
3943   // Fold (mul (vscale * C0), C1) to (vscale * (C0 * C1)).
3944   if (N0.getOpcode() == ISD::VSCALE)
3945     if (ConstantSDNode *NC1 = isConstOrConstSplat(N1)) {
3946       const APInt &C0 = N0.getConstantOperandAPInt(0);
3947       const APInt &C1 = NC1->getAPIntValue();
3948       return DAG.getVScale(SDLoc(N), VT, C0 * C1);
3949     }
3950 
3951   // Fold (mul step_vector(C0), C1) to (step_vector(C0 * C1)).
3952   APInt MulVal;
3953   if (N0.getOpcode() == ISD::STEP_VECTOR)
3954     if (ISD::isConstantSplatVector(N1.getNode(), MulVal)) {
3955       const APInt &C0 = N0.getConstantOperandAPInt(0);
3956       APInt NewStep = C0 * MulVal;
3957       return DAG.getStepVector(SDLoc(N), VT, NewStep);
3958     }
3959 
3960   // Fold (mul x, 0/undef) -> 0,
3961   //      (mul x, 1) -> x
3962   // -> and(x, mask)
3963   // We can replace vectors with '0' and '1' factors with a clearing mask.
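  // E.g. (illustrative), (mul x, <1, 0, undef, 1>) becomes
  // (and x, <-1, 0, 0, -1>).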
3964   if (VT.isFixedLengthVector()) {
3965     unsigned NumElts = VT.getVectorNumElements();
3966     SmallBitVector ClearMask;
3967     ClearMask.reserve(NumElts);
3968     auto IsClearMask = [&ClearMask](ConstantSDNode *V) {
3969       if (!V || V->isNullValue()) {
3970         ClearMask.push_back(true);
3971         return true;
3972       }
3973       ClearMask.push_back(false);
3974       return V->isOne();
3975     };
3976     if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::AND, VT)) &&
3977         ISD::matchUnaryPredicate(N1, IsClearMask, /*AllowUndefs*/ true)) {
3978       assert(N1.getOpcode() == ISD::BUILD_VECTOR && "Unknown constant vector");
3979       SDLoc DL(N);
3980       EVT LegalSVT = N1.getOperand(0).getValueType();
3981       SDValue Zero = DAG.getConstant(0, DL, LegalSVT);
3982       SDValue AllOnes = DAG.getAllOnesConstant(DL, LegalSVT);
3983       SmallVector<SDValue, 16> Mask(NumElts, AllOnes);
3984       for (unsigned I = 0; I != NumElts; ++I)
3985         if (ClearMask[I])
3986           Mask[I] = Zero;
3987       return DAG.getNode(ISD::AND, DL, VT, N0, DAG.getBuildVector(VT, DL, Mask));
3988     }
3989   }
3990 
3991   // reassociate mul
3992   if (SDValue RMUL = reassociateOps(ISD::MUL, SDLoc(N), N0, N1, N->getFlags()))
3993     return RMUL;
3994 
3995   return SDValue();
3996 }
3997 
3998 /// Return true if divmod libcall is available.
3999 static bool isDivRemLibcallAvailable(SDNode *Node, bool isSigned,
4000                                      const TargetLowering &TLI) {
4001   RTLIB::Libcall LC;
4002   EVT NodeType = Node->getValueType(0);
4003   if (!NodeType.isSimple())
4004     return false;
4005   switch (NodeType.getSimpleVT().SimpleTy) {
4006   default: return false; // No libcall for vector types.
4007   case MVT::i8:   LC= isSigned ? RTLIB::SDIVREM_I8  : RTLIB::UDIVREM_I8;  break;
4008   case MVT::i16:  LC= isSigned ? RTLIB::SDIVREM_I16 : RTLIB::UDIVREM_I16; break;
4009   case MVT::i32:  LC= isSigned ? RTLIB::SDIVREM_I32 : RTLIB::UDIVREM_I32; break;
4010   case MVT::i64:  LC= isSigned ? RTLIB::SDIVREM_I64 : RTLIB::UDIVREM_I64; break;
4011   case MVT::i128: LC= isSigned ? RTLIB::SDIVREM_I128:RTLIB::UDIVREM_I128; break;
4012   }
4013 
4014   return TLI.getLibcallName(LC) != nullptr;
4015 }
4016 
4017 /// Issue divrem if both quotient and remainder are needed.
4018 SDValue DAGCombiner::useDivRem(SDNode *Node) {
4019   if (Node->use_empty())
4020     return SDValue(); // This is a dead node, leave it alone.
4021 
4022   unsigned Opcode = Node->getOpcode();
4023   bool isSigned = (Opcode == ISD::SDIV) || (Opcode == ISD::SREM);
4024   unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM;
4025 
4026   // DivMod libcalls can still work on non-legal types.
4027   EVT VT = Node->getValueType(0);
4028   if (VT.isVector() || !VT.isInteger())
4029     return SDValue();
4030 
4031   if (!TLI.isTypeLegal(VT) && !TLI.isOperationCustom(DivRemOpc, VT))
4032     return SDValue();
4033 
4034   // If DIVREM is going to get expanded into a libcall,
4035   // but there is no libcall available, then don't combine.
4036   if (!TLI.isOperationLegalOrCustom(DivRemOpc, VT) &&
4037       !isDivRemLibcallAvailable(Node, isSigned, TLI))
4038     return SDValue();
4039 
4040   // If div is legal, it's better to do the normal expansion
4041   unsigned OtherOpcode = 0;
4042   if ((Opcode == ISD::SDIV) || (Opcode == ISD::UDIV)) {
4043     OtherOpcode = isSigned ? ISD::SREM : ISD::UREM;
4044     if (TLI.isOperationLegalOrCustom(Opcode, VT))
4045       return SDValue();
4046   } else {
4047     OtherOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
4048     if (TLI.isOperationLegalOrCustom(OtherOpcode, VT))
4049       return SDValue();
4050   }
4051 
4052   SDValue Op0 = Node->getOperand(0);
4053   SDValue Op1 = Node->getOperand(1);
4054   SDValue combined;
4055   for (SDNode::use_iterator UI = Op0.getNode()->use_begin(),
4056          UE = Op0.getNode()->use_end(); UI != UE; ++UI) {
4057     SDNode *User = *UI;
4058     if (User == Node || User->getOpcode() == ISD::DELETED_NODE ||
4059         User->use_empty())
4060       continue;
4061     // Convert the other matching node(s), too;
4062     // otherwise, the DIVREM may get target-legalized into something
4063     // target-specific that we won't be able to recognize.
4064     unsigned UserOpc = User->getOpcode();
4065     if ((UserOpc == Opcode || UserOpc == OtherOpcode || UserOpc == DivRemOpc) &&
4066         User->getOperand(0) == Op0 &&
4067         User->getOperand(1) == Op1) {
4068       if (!combined) {
4069         if (UserOpc == OtherOpcode) {
4070           SDVTList VTs = DAG.getVTList(VT, VT);
4071           combined = DAG.getNode(DivRemOpc, SDLoc(Node), VTs, Op0, Op1);
4072         } else if (UserOpc == DivRemOpc) {
4073           combined = SDValue(User, 0);
4074         } else {
4075           assert(UserOpc == Opcode);
4076           continue;
4077         }
4078       }
4079       if (UserOpc == ISD::SDIV || UserOpc == ISD::UDIV)
4080         CombineTo(User, combined);
4081       else if (UserOpc == ISD::SREM || UserOpc == ISD::UREM)
4082         CombineTo(User, combined.getValue(1));
4083     }
4084   }
4085   return combined;
4086 }
4087 
4088 static SDValue simplifyDivRem(SDNode *N, SelectionDAG &DAG) {
4089   SDValue N0 = N->getOperand(0);
4090   SDValue N1 = N->getOperand(1);
4091   EVT VT = N->getValueType(0);
4092   SDLoc DL(N);
4093 
4094   unsigned Opc = N->getOpcode();
4095   bool IsDiv = (ISD::SDIV == Opc) || (ISD::UDIV == Opc);
4096   ConstantSDNode *N1C = isConstOrConstSplat(N1);
4097 
4098   // X / undef -> undef
4099   // X % undef -> undef
4100   // X / 0 -> undef
4101   // X % 0 -> undef
4102   // NOTE: This includes vectors where any divisor element is zero/undef.
4103   if (DAG.isUndef(Opc, {N0, N1}))
4104     return DAG.getUNDEF(VT);
4105 
4106   // undef / X -> 0
4107   // undef % X -> 0
4108   if (N0.isUndef())
4109     return DAG.getConstant(0, DL, VT);
4110 
4111   // 0 / X -> 0
4112   // 0 % X -> 0
4113   ConstantSDNode *N0C = isConstOrConstSplat(N0);
4114   if (N0C && N0C->isNullValue())
4115     return N0;
4116 
4117   // X / X -> 1
4118   // X % X -> 0
4119   if (N0 == N1)
4120     return DAG.getConstant(IsDiv ? 1 : 0, DL, VT);
4121 
4122   // X / 1 -> X
4123   // X % 1 -> 0
4124   // If this is a boolean op (single-bit element type), we can't have
4125   // division-by-zero or remainder-by-zero, so assume the divisor is 1.
4126   // TODO: Similarly, if we're zero-extending a boolean divisor, then assume
4127   // it's a 1.
4128   if ((N1C && N1C->isOne()) || (VT.getScalarType() == MVT::i1))
4129     return IsDiv ? N0 : DAG.getConstant(0, DL, VT);
4130 
4131   return SDValue();
4132 }
4133 
4134 SDValue DAGCombiner::visitSDIV(SDNode *N) {
4135   SDValue N0 = N->getOperand(0);
4136   SDValue N1 = N->getOperand(1);
4137   EVT VT = N->getValueType(0);
4138   EVT CCVT = getSetCCResultType(VT);
4139 
4140   // fold vector ops
4141   if (VT.isVector())
4142     if (SDValue FoldedVOp = SimplifyVBinOp(N))
4143       return FoldedVOp;
4144 
4145   SDLoc DL(N);
4146 
4147   // fold (sdiv c1, c2) -> c1/c2
4148   ConstantSDNode *N1C = isConstOrConstSplat(N1);
4149   if (SDValue C = DAG.FoldConstantArithmetic(ISD::SDIV, DL, VT, {N0, N1}))
4150     return C;
4151 
4152   // fold (sdiv X, -1) -> 0-X
4153   if (N1C && N1C->isAllOnesValue())
4154     return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), N0);
4155 
4156   // fold (sdiv X, MIN_SIGNED) -> select(X == MIN_SIGNED, 1, 0)
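  // (Only X == MIN_SIGNED has magnitude at least |MIN_SIGNED|; every other
  // dividend yields a zero quotient.)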
4157   if (N1C && N1C->getAPIntValue().isMinSignedValue())
4158     return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
4159                          DAG.getConstant(1, DL, VT),
4160                          DAG.getConstant(0, DL, VT));
4161 
4162   if (SDValue V = simplifyDivRem(N, DAG))
4163     return V;
4164 
4165   if (SDValue NewSel = foldBinOpIntoSelect(N))
4166     return NewSel;
4167 
4168   // If we know the sign bits of both operands are zero, strength reduce to a
4169   // udiv instead.  Handles (X&15) /s 4 -> X&15 >> 2
4170   if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
4171     return DAG.getNode(ISD::UDIV, DL, N1.getValueType(), N0, N1);
4172 
4173   if (SDValue V = visitSDIVLike(N0, N1, N)) {
4174     // If the corresponding remainder node exists, update its users with
4175     // (Dividend - (Quotient * Divisor)).
4176     if (SDNode *RemNode = DAG.getNodeIfExists(ISD::SREM, N->getVTList(),
4177                                               { N0, N1 })) {
4178       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1);
4179       SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
4180       AddToWorklist(Mul.getNode());
4181       AddToWorklist(Sub.getNode());
4182       CombineTo(RemNode, Sub);
4183     }
4184     return V;
4185   }
4186 
4187   // sdiv, srem -> sdivrem
4188   // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
4189   // true.  Otherwise, we break the simplification logic in visitREM().
4190   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4191   if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
4192     if (SDValue DivRem = useDivRem(N))
4193         return DivRem;
4194 
4195   return SDValue();
4196 }
4197 
4198 SDValue DAGCombiner::visitSDIVLike(SDValue N0, SDValue N1, SDNode *N) {
4199   SDLoc DL(N);
4200   EVT VT = N->getValueType(0);
4201   EVT CCVT = getSetCCResultType(VT);
4202   unsigned BitWidth = VT.getScalarSizeInBits();
4203 
4204   // Helper for determining whether a value is a power-of-2 constant scalar,
4205   // or the negation of one, or a vector of such elements.
4206   auto IsPowerOfTwo = [](ConstantSDNode *C) {
4207     if (C->isNullValue() || C->isOpaque())
4208       return false;
4209     if (C->getAPIntValue().isPowerOf2())
4210       return true;
4211     if ((-C->getAPIntValue()).isPowerOf2())
4212       return true;
4213     return false;
4214   };
4215 
4216   // fold (sdiv X, pow2) -> simple ops after legalize
4217   // FIXME: We check for the exact bit here because the generic lowering gives
4218   // better results in that case. The target-specific lowering should learn how
4219   // to handle exact sdivs efficiently.
4220   if (!N->getFlags().hasExact() && ISD::matchUnaryPredicate(N1, IsPowerOfTwo)) {
4221     // Target-specific implementation of sdiv x, pow2.
4222     if (SDValue Res = BuildSDIVPow2(N))
4223       return Res;
4224 
4225     // Create constants that are functions of the shift amount value.
4226     EVT ShiftAmtTy = getShiftAmountTy(N0.getValueType());
4227     SDValue Bits = DAG.getConstant(BitWidth, DL, ShiftAmtTy);
4228     SDValue C1 = DAG.getNode(ISD::CTTZ, DL, VT, N1);
4229     C1 = DAG.getZExtOrTrunc(C1, DL, ShiftAmtTy);
4230     SDValue Inexact = DAG.getNode(ISD::SUB, DL, ShiftAmtTy, Bits, C1);
4231     if (!isConstantOrConstantVector(Inexact))
4232       return SDValue();
4233 
4234     // Splat the sign bit into the register
4235     SDValue Sign = DAG.getNode(ISD::SRA, DL, VT, N0,
4236                                DAG.getConstant(BitWidth - 1, DL, ShiftAmtTy));
4237     AddToWorklist(Sign.getNode());
4238 
4239     // Add (N0 < 0) ? abs2 - 1 : 0;
4240     SDValue Srl = DAG.getNode(ISD::SRL, DL, VT, Sign, Inexact);
4241     AddToWorklist(Srl.getNode());
4242     SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, Srl);
4243     AddToWorklist(Add.getNode());
4244     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Add, C1);
4245     AddToWorklist(Sra.getNode());
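    // Worked example (illustrative): for i32 X sdiv 8, C1 = 3 and
    // Inexact = 29, so Srl = (X >>s 31) >>u 29 is 7 for negative X and 0
    // otherwise; adding it before the arithmetic shift by 3 rounds the
    // quotient toward zero.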
4246 
4247     // Special case: (sdiv X, 1) -> X
4248     // Special Case: (sdiv X, -1) -> 0-X
4249     SDValue One = DAG.getConstant(1, DL, VT);
4250     SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
4251     SDValue IsOne = DAG.getSetCC(DL, CCVT, N1, One, ISD::SETEQ);
4252     SDValue IsAllOnes = DAG.getSetCC(DL, CCVT, N1, AllOnes, ISD::SETEQ);
4253     SDValue IsOneOrAllOnes = DAG.getNode(ISD::OR, DL, CCVT, IsOne, IsAllOnes);
4254     Sra = DAG.getSelect(DL, VT, IsOneOrAllOnes, N0, Sra);
4255 
4256     // If dividing by a positive value, we're done. Otherwise, the result must
4257     // be negated.
4258     SDValue Zero = DAG.getConstant(0, DL, VT);
4259     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, Zero, Sra);
4260 
4261     // FIXME: Use SELECT_CC once we improve SELECT_CC constant-folding.
4262     SDValue IsNeg = DAG.getSetCC(DL, CCVT, N1, Zero, ISD::SETLT);
4263     SDValue Res = DAG.getSelect(DL, VT, IsNeg, Sub, Sra);
4264     return Res;
4265   }
4266 
4267   // If integer divide is expensive and we satisfy the requirements, emit an
4268   // alternate sequence.  Targets may check function attributes for size/speed
4269   // trade-offs.
4270   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4271   if (isConstantOrConstantVector(N1) &&
4272       !TLI.isIntDivCheap(N->getValueType(0), Attr))
4273     if (SDValue Op = BuildSDIV(N))
4274       return Op;
4275 
4276   return SDValue();
4277 }
4278 
4279 SDValue DAGCombiner::visitUDIV(SDNode *N) {
4280   SDValue N0 = N->getOperand(0);
4281   SDValue N1 = N->getOperand(1);
4282   EVT VT = N->getValueType(0);
4283   EVT CCVT = getSetCCResultType(VT);
4284 
4285   // fold vector ops
4286   if (VT.isVector())
4287     if (SDValue FoldedVOp = SimplifyVBinOp(N))
4288       return FoldedVOp;
4289 
4290   SDLoc DL(N);
4291 
4292   // fold (udiv c1, c2) -> c1/c2
4293   ConstantSDNode *N1C = isConstOrConstSplat(N1);
4294   if (SDValue C = DAG.FoldConstantArithmetic(ISD::UDIV, DL, VT, {N0, N1}))
4295     return C;
4296 
4297   // fold (udiv X, -1) -> select(X == -1, 1, 0)
4298   if (N1C && N1C->getAPIntValue().isAllOnesValue())
4299     return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
4300                          DAG.getConstant(1, DL, VT),
4301                          DAG.getConstant(0, DL, VT));
4302 
4303   if (SDValue V = simplifyDivRem(N, DAG))
4304     return V;
4305 
4306   if (SDValue NewSel = foldBinOpIntoSelect(N))
4307     return NewSel;
4308 
4309   if (SDValue V = visitUDIVLike(N0, N1, N)) {
4310     // If the corresponding remainder node exists, update its users with
4311     // (Dividend - (Quotient * Divisor)).
4312     if (SDNode *RemNode = DAG.getNodeIfExists(ISD::UREM, N->getVTList(),
4313                                               { N0, N1 })) {
4314       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1);
4315       SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
4316       AddToWorklist(Mul.getNode());
4317       AddToWorklist(Sub.getNode());
4318       CombineTo(RemNode, Sub);
4319     }
4320     return V;
4321   }
4322 
4323   // udiv, urem -> udivrem
4324   // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
4325   // true.  Otherwise, we break the simplification logic in visitREM().
4326   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4327   if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
4328     if (SDValue DivRem = useDivRem(N))
4329         return DivRem;
4330 
4331   return SDValue();
4332 }
4333 
4334 SDValue DAGCombiner::visitUDIVLike(SDValue N0, SDValue N1, SDNode *N) {
4335   SDLoc DL(N);
4336   EVT VT = N->getValueType(0);
4337 
4338   // fold (udiv x, (1 << c)) -> x >>u c
4339   if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
4340       DAG.isKnownToBeAPowerOfTwo(N1)) {
4341     SDValue LogBase2 = BuildLogBase2(N1, DL);
4342     AddToWorklist(LogBase2.getNode());
4343 
4344     EVT ShiftVT = getShiftAmountTy(N0.getValueType());
4345     SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
4346     AddToWorklist(Trunc.getNode());
4347     return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
4348   }
4349 
4350   // fold (udiv x, (shl c, y)) -> x >>u (log2(c)+y) iff c is power of 2
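  // E.g. (illustrative), (udiv x, (shl 4, y)) becomes (srl x, (add y, 2)),
  // since 4 << y == 1 << (y + 2).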
4351   if (N1.getOpcode() == ISD::SHL) {
4352     SDValue N10 = N1.getOperand(0);
4353     if (isConstantOrConstantVector(N10, /*NoOpaques*/ true) &&
4354         DAG.isKnownToBeAPowerOfTwo(N10)) {
4355       SDValue LogBase2 = BuildLogBase2(N10, DL);
4356       AddToWorklist(LogBase2.getNode());
4357 
4358       EVT ADDVT = N1.getOperand(1).getValueType();
4359       SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT);
4360       AddToWorklist(Trunc.getNode());
4361       SDValue Add = DAG.getNode(ISD::ADD, DL, ADDVT, N1.getOperand(1), Trunc);
4362       AddToWorklist(Add.getNode());
4363       return DAG.getNode(ISD::SRL, DL, VT, N0, Add);
4364     }
4365   }
4366 
4367   // fold (udiv x, c) -> alternate
4368   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4369   if (isConstantOrConstantVector(N1) &&
4370       !TLI.isIntDivCheap(N->getValueType(0), Attr))
4371     if (SDValue Op = BuildUDIV(N))
4372       return Op;
4373 
4374   return SDValue();
4375 }
4376 
4377 // handles ISD::SREM and ISD::UREM
4378 SDValue DAGCombiner::visitREM(SDNode *N) {
4379   unsigned Opcode = N->getOpcode();
4380   SDValue N0 = N->getOperand(0);
4381   SDValue N1 = N->getOperand(1);
4382   EVT VT = N->getValueType(0);
4383   EVT CCVT = getSetCCResultType(VT);
4384 
4385   bool isSigned = (Opcode == ISD::SREM);
4386   SDLoc DL(N);
4387 
4388   // fold (rem c1, c2) -> c1%c2
4389   ConstantSDNode *N1C = isConstOrConstSplat(N1);
4390   if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
4391     return C;
4392 
4393   // fold (urem X, -1) -> select(X == -1, 0, x)
4394   if (!isSigned && N1C && N1C->getAPIntValue().isAllOnesValue())
4395     return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
4396                          DAG.getConstant(0, DL, VT), N0);
4397 
4398   if (SDValue V = simplifyDivRem(N, DAG))
4399     return V;
4400 
4401   if (SDValue NewSel = foldBinOpIntoSelect(N))
4402     return NewSel;
4403 
4404   if (isSigned) {
4405     // If we know the sign bits of both operands are zero, strength reduce to a
4406     // urem instead.  Handles (X & 0x0FFFFFFF) %s 16 -> X&15
4407     if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
4408       return DAG.getNode(ISD::UREM, DL, VT, N0, N1);
4409   } else {
4410     if (DAG.isKnownToBeAPowerOfTwo(N1)) {
4411       // fold (urem x, pow2) -> (and x, pow2-1)
4412       SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
4413       SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, NegOne);
4414       AddToWorklist(Add.getNode());
4415       return DAG.getNode(ISD::AND, DL, VT, N0, Add);
4416     }
4417     if (N1.getOpcode() == ISD::SHL &&
4418         DAG.isKnownToBeAPowerOfTwo(N1.getOperand(0))) {
4419       // fold (urem x, (shl pow2, y)) -> (and x, (add (shl pow2, y), -1))
4420       SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
4421       SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, NegOne);
4422       AddToWorklist(Add.getNode());
4423       return DAG.getNode(ISD::AND, DL, VT, N0, Add);
4424     }
4425   }
4426 
4427   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4428 
4429   // If X/C can be simplified by the division-by-constant logic, lower
4430   // X%C to the equivalent of X-X/C*C.
4431   // Reuse the SDIVLike/UDIVLike combines - to avoid mangling nodes, the
4432   // speculative DIV must not cause a DIVREM conversion.  We guard against this
4433   // by skipping the simplification if isIntDivCheap().  When div is not cheap,
4434   // combine will not return a DIVREM.  Regardless, checking cheapness here
4435   // makes sense since the simplification results in fatter code.
4436   if (DAG.isKnownNeverZero(N1) && !TLI.isIntDivCheap(VT, Attr)) {
4437     SDValue OptimizedDiv =
4438         isSigned ? visitSDIVLike(N0, N1, N) : visitUDIVLike(N0, N1, N);
4439     if (OptimizedDiv.getNode()) {
4440       // If the equivalent Div node also exists, update its users.
4441       unsigned DivOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
4442       if (SDNode *DivNode = DAG.getNodeIfExists(DivOpcode, N->getVTList(),
4443                                                 { N0, N1 }))
4444         CombineTo(DivNode, OptimizedDiv);
4445       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, OptimizedDiv, N1);
4446       SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
4447       AddToWorklist(OptimizedDiv.getNode());
4448       AddToWorklist(Mul.getNode());
4449       return Sub;
4450     }
4451   }
4452 
4453   // sdiv, srem -> sdivrem
4454   if (SDValue DivRem = useDivRem(N))
4455     return DivRem.getValue(1);
4456 
4457   return SDValue();
4458 }
4459 
4460 SDValue DAGCombiner::visitMULHS(SDNode *N) {
4461   SDValue N0 = N->getOperand(0);
4462   SDValue N1 = N->getOperand(1);
4463   EVT VT = N->getValueType(0);
4464   SDLoc DL(N);
4465 
4466   if (VT.isVector()) {
4467     // fold (mulhs x, 0) -> 0
4468     // do not return N0/N1, because an undef node may exist.
4469     if (ISD::isConstantSplatVectorAllZeros(N0.getNode()) ||
4470         ISD::isConstantSplatVectorAllZeros(N1.getNode()))
4471       return DAG.getConstant(0, DL, VT);
4472   }
4473 
4474   // fold (mulhs c1, c2)
4475   if (SDValue C = DAG.FoldConstantArithmetic(ISD::MULHS, DL, VT, {N0, N1}))
4476     return C;
4477 
4478   // fold (mulhs x, 0) -> 0
4479   if (isNullConstant(N1))
4480     return N1;
4481   // fold (mulhs x, 1) -> (sra x, size(x)-1)
4482   if (isOneConstant(N1))
4483     return DAG.getNode(ISD::SRA, DL, N0.getValueType(), N0,
4484                        DAG.getConstant(N0.getScalarValueSizeInBits() - 1, DL,
4485                                        getShiftAmountTy(N0.getValueType())));
4486 
4487   // fold (mulhs x, undef) -> 0
4488   if (N0.isUndef() || N1.isUndef())
4489     return DAG.getConstant(0, DL, VT);
4490 
4491   // If the type twice as wide is legal, transform the mulhs to a wider multiply
4492   // plus a shift.
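  // E.g. (illustrative), an i16 mulhs with a legal i32 multiply becomes
  // (trunc (srl (mul (sext x), (sext y)), 16)).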
4493   if (!TLI.isOperationLegalOrCustom(ISD::MULHS, VT) && VT.isSimple() &&
4494       !VT.isVector()) {
4495     MVT Simple = VT.getSimpleVT();
4496     unsigned SimpleSize = Simple.getSizeInBits();
4497     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
4498     if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
4499       N0 = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N0);
4500       N1 = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N1);
4501       N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
4502       N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
4503             DAG.getConstant(SimpleSize, DL,
4504                             getShiftAmountTy(N1.getValueType())));
4505       return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
4506     }
4507   }
4508 
4509   return SDValue();
4510 }
4511 
4512 SDValue DAGCombiner::visitMULHU(SDNode *N) {
4513   SDValue N0 = N->getOperand(0);
4514   SDValue N1 = N->getOperand(1);
4515   EVT VT = N->getValueType(0);
4516   SDLoc DL(N);
4517 
4518   if (VT.isVector()) {
4519     // fold (mulhu x, 0) -> 0
4520     // do not return N0/N1, because an undef node may exist.
4521     if (ISD::isConstantSplatVectorAllZeros(N0.getNode()) ||
4522         ISD::isConstantSplatVectorAllZeros(N1.getNode()))
4523       return DAG.getConstant(0, DL, VT);
4524   }
4525 
4526   // fold (mulhu c1, c2)
4527   if (SDValue C = DAG.FoldConstantArithmetic(ISD::MULHU, DL, VT, {N0, N1}))
4528     return C;
4529 
4530   // fold (mulhu x, 0) -> 0
4531   if (isNullConstant(N1))
4532     return N1;
4533   // fold (mulhu x, 1) -> 0
4534   if (isOneConstant(N1))
4535     return DAG.getConstant(0, DL, N0.getValueType());
4536   // fold (mulhu x, undef) -> 0
4537   if (N0.isUndef() || N1.isUndef())
4538     return DAG.getConstant(0, DL, VT);
4539 
4540   // fold (mulhu x, (1 << c)) -> x >> (bitwidth - c)
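  // E.g. (illustrative), for i32, (mulhu x, 16) is the high half of x * 2^4,
  // i.e. (srl x, 28).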
4541   if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
4542       DAG.isKnownToBeAPowerOfTwo(N1) && hasOperation(ISD::SRL, VT)) {
4543     unsigned NumEltBits = VT.getScalarSizeInBits();
4544     SDValue LogBase2 = BuildLogBase2(N1, DL);
4545     SDValue SRLAmt = DAG.getNode(
4546         ISD::SUB, DL, VT, DAG.getConstant(NumEltBits, DL, VT), LogBase2);
4547     EVT ShiftVT = getShiftAmountTy(N0.getValueType());
4548     SDValue Trunc = DAG.getZExtOrTrunc(SRLAmt, DL, ShiftVT);
4549     return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
4550   }
4551 
4552   // If the type twice as wide is legal, transform the mulhu to a wider multiply
4553   // plus a shift.
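  // E.g. (illustrative), an i16 mulhu with a legal i32 multiply becomes
  // (trunc (srl (mul (zext x), (zext y)), 16)).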
4554   if (!TLI.isOperationLegalOrCustom(ISD::MULHU, VT) && VT.isSimple() &&
4555       !VT.isVector()) {
4556     MVT Simple = VT.getSimpleVT();
4557     unsigned SimpleSize = Simple.getSizeInBits();
4558     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
4559     if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
4560       N0 = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N0);
4561       N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N1);
4562       N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
4563       N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
4564             DAG.getConstant(SimpleSize, DL,
4565                             getShiftAmountTy(N1.getValueType())));
4566       return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
4567     }
4568   }
4569 
4570   return SDValue();
4571 }
4572 
4573 /// Perform optimizations common to nodes that compute two values. LoOp and
4574 /// HiOp give the opcodes for the two computations that are being performed.
4575 /// Return the combined result if a simplification was made.
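/// For example (illustrative), an SMUL_LOHI whose high half is unused can be
/// replaced by a plain MUL.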
4576 SDValue DAGCombiner::SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
4577                                                 unsigned HiOp) {
4578   // If the high half is not needed, just compute the low half.
4579   bool HiExists = N->hasAnyUseOfValue(1);
4580   if (!HiExists && (!LegalOperations ||
4581                     TLI.isOperationLegalOrCustom(LoOp, N->getValueType(0)))) {
4582     SDValue Res = DAG.getNode(LoOp, SDLoc(N), N->getValueType(0), N->ops());
4583     return CombineTo(N, Res, Res);
4584   }
4585 
4586   // If the low half is not needed, just compute the high half.
4587   bool LoExists = N->hasAnyUseOfValue(0);
4588   if (!LoExists && (!LegalOperations ||
4589                     TLI.isOperationLegalOrCustom(HiOp, N->getValueType(1)))) {
4590     SDValue Res = DAG.getNode(HiOp, SDLoc(N), N->getValueType(1), N->ops());
4591     return CombineTo(N, Res, Res);
4592   }
4593 
4594   // If both halves are used, return as it is.
4595   if (LoExists && HiExists)
4596     return SDValue();
4597 
4598   // If the two computed results can be simplified separately, separate them.
4599   if (LoExists) {
4600     SDValue Lo = DAG.getNode(LoOp, SDLoc(N), N->getValueType(0), N->ops());
4601     AddToWorklist(Lo.getNode());
4602     SDValue LoOpt = combine(Lo.getNode());
4603     if (LoOpt.getNode() && LoOpt.getNode() != Lo.getNode() &&
4604         (!LegalOperations ||
4605          TLI.isOperationLegalOrCustom(LoOpt.getOpcode(), LoOpt.getValueType())))
4606       return CombineTo(N, LoOpt, LoOpt);
4607   }
4608 
4609   if (HiExists) {
4610     SDValue Hi = DAG.getNode(HiOp, SDLoc(N), N->getValueType(1), N->ops());
4611     AddToWorklist(Hi.getNode());
4612     SDValue HiOpt = combine(Hi.getNode());
4613     if (HiOpt.getNode() && HiOpt != Hi &&
4614         (!LegalOperations ||
4615          TLI.isOperationLegalOrCustom(HiOpt.getOpcode(), HiOpt.getValueType())))
4616       return CombineTo(N, HiOpt, HiOpt);
4617   }
4618 
4619   return SDValue();
4620 }
4621 
4622 SDValue DAGCombiner::visitSMUL_LOHI(SDNode *N) {
4623   if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHS))
4624     return Res;
4625 
4626   EVT VT = N->getValueType(0);
4627   SDLoc DL(N);
4628 
4629   // If the type twice as wide is legal, transform this into a wider
4630   // multiply plus a shift.
4631   if (VT.isSimple() && !VT.isVector()) {
4632     MVT Simple = VT.getSimpleVT();
4633     unsigned SimpleSize = Simple.getSizeInBits();
4634     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
4635     if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
4636       SDValue Lo = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N->getOperand(0));
4637       SDValue Hi = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N->getOperand(1));
4638       Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
4639       // Compute the high part as N1.
4640       Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
4641             DAG.getConstant(SimpleSize, DL,
4642                             getShiftAmountTy(Lo.getValueType())));
4643       Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
4644       // Compute the low part as N0.
4645       Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
4646       return CombineTo(N, Lo, Hi);
4647     }
4648   }
4649 
4650   return SDValue();
4651 }
4652 
4653 SDValue DAGCombiner::visitUMUL_LOHI(SDNode *N) {
4654   if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHU))
4655     return Res;
4656 
4657   EVT VT = N->getValueType(0);
4658   SDLoc DL(N);
4659 
4660   // (umul_lohi N0, 0) -> (0, 0)
4661   if (isNullConstant(N->getOperand(1))) {
4662     SDValue Zero = DAG.getConstant(0, DL, VT);
4663     return CombineTo(N, Zero, Zero);
4664   }
4665 
4666   // (umul_lohi N0, 1) -> (N0, 0)
4667   if (isOneConstant(N->getOperand(1))) {
4668     SDValue Zero = DAG.getConstant(0, DL, VT);
4669     return CombineTo(N, N->getOperand(0), Zero);
4670   }
4671 
4672   // If the type twice as wide is legal, transform this into a wider
4673   // multiply plus a shift.
4674   if (VT.isSimple() && !VT.isVector()) {
4675     MVT Simple = VT.getSimpleVT();
4676     unsigned SimpleSize = Simple.getSizeInBits();
4677     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
4678     if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
4679       SDValue Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N->getOperand(0));
4680       SDValue Hi = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N->getOperand(1));
4681       Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
4682       // Compute the high part as N1.
4683       Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
4684             DAG.getConstant(SimpleSize, DL,
4685                             getShiftAmountTy(Lo.getValueType())));
4686       Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
4687       // Compute the low part as N0.
4688       Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
4689       return CombineTo(N, Lo, Hi);
4690     }
4691   }
4692 
4693   return SDValue();
4694 }
4695 
4696 SDValue DAGCombiner::visitMULO(SDNode *N) {
4697   SDValue N0 = N->getOperand(0);
4698   SDValue N1 = N->getOperand(1);
4699   EVT VT = N0.getValueType();
4700   bool IsSigned = (ISD::SMULO == N->getOpcode());
4701 
4702   EVT CarryVT = N->getValueType(1);
4703   SDLoc DL(N);
4704 
4705   ConstantSDNode *N0C = isConstOrConstSplat(N0);
4706   ConstantSDNode *N1C = isConstOrConstSplat(N1);
4707 
4708   // fold operation with constant operands.
4709   // TODO: Move this to FoldConstantArithmetic when it supports nodes with
4710   // multiple results.
4711   if (N0C && N1C) {
4712     bool Overflow;
4713     APInt Result =
4714         IsSigned ? N0C->getAPIntValue().smul_ov(N1C->getAPIntValue(), Overflow)
4715                  : N0C->getAPIntValue().umul_ov(N1C->getAPIntValue(), Overflow);
4716     return CombineTo(N, DAG.getConstant(Result, DL, VT),
4717                      DAG.getBoolConstant(Overflow, DL, CarryVT, CarryVT));
4718   }
4719 
4720   // canonicalize constant to RHS.
4721   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4722       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4723     return DAG.getNode(N->getOpcode(), DL, N->getVTList(), N1, N0);
4724 
4725   // fold (mulo x, 0) -> 0 + no carry out
4726   if (isNullOrNullSplat(N1))
4727     return CombineTo(N, DAG.getConstant(0, DL, VT),
4728                      DAG.getConstant(0, DL, CarryVT));
4729 
4730   // (mulo x, 2) -> (addo x, x)
4731   if (N1C && N1C->getAPIntValue() == 2)
4732     return DAG.getNode(IsSigned ? ISD::SADDO : ISD::UADDO, DL,
4733                        N->getVTList(), N0, N0);
4734 
4735   if (IsSigned) {
4736     // A 1 bit SMULO overflows if both inputs are 1.
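    // (In i1, the only nonzero value is -1, and (-1) * (-1) = +1 is not
    // representable, so the product overflows exactly when both bits are set.)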
4737     if (VT.getScalarSizeInBits() == 1) {
4738       SDValue And = DAG.getNode(ISD::AND, DL, VT, N0, N1);
4739       return CombineTo(N, And,
4740                        DAG.getSetCC(DL, CarryVT, And,
4741                                     DAG.getConstant(0, DL, VT), ISD::SETNE));
4742     }
4743 
4744     // Multiplying n * m significant bits yields a result of n + m significant
4745     // bits. If the total number of significant bits does not exceed the
4746     // result bit width (minus 1), there is no overflow.
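    // E.g. (illustrative), two i32 operands that each fit in i16 have at
    // least 17 sign bits apiece; 17 + 17 = 34 > 33, so no overflow can occur.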
4747     unsigned SignBits = DAG.ComputeNumSignBits(N0);
4748     if (SignBits > 1)
4749       SignBits += DAG.ComputeNumSignBits(N1);
4750     if (SignBits > VT.getScalarSizeInBits() + 1)
4751       return CombineTo(N, DAG.getNode(ISD::MUL, DL, VT, N0, N1),
4752                        DAG.getConstant(0, DL, CarryVT));
4753   } else {
4754     KnownBits N1Known = DAG.computeKnownBits(N1);
4755     KnownBits N0Known = DAG.computeKnownBits(N0);
4756     bool Overflow;
4757     (void)N0Known.getMaxValue().umul_ov(N1Known.getMaxValue(), Overflow);
4758     if (!Overflow)
4759       return CombineTo(N, DAG.getNode(ISD::MUL, DL, VT, N0, N1),
4760                        DAG.getConstant(0, DL, CarryVT));
4761   }
4762 
4763   return SDValue();
4764 }
4765 
4766 SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
4767   SDValue N0 = N->getOperand(0);
4768   SDValue N1 = N->getOperand(1);
4769   EVT VT = N0.getValueType();
4770   unsigned Opcode = N->getOpcode();
4771 
4772   // fold vector ops
4773   if (VT.isVector())
4774     if (SDValue FoldedVOp = SimplifyVBinOp(N))
4775       return FoldedVOp;
4776 
4777   // fold operation with constant operands.
4778   if (SDValue C = DAG.FoldConstantArithmetic(Opcode, SDLoc(N), VT, {N0, N1}))
4779     return C;
4780 
4781   // canonicalize constant to RHS
4782   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4783       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4784     return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0);
4785 
4786   // If the sign bits are zero, flip between UMIN/UMAX and SMIN/SMAX.
4787   // Only do this if the current op isn't legal and the flipped is.
4788   if (!TLI.isOperationLegal(Opcode, VT) &&
4789       (N0.isUndef() || DAG.SignBitIsZero(N0)) &&
4790       (N1.isUndef() || DAG.SignBitIsZero(N1))) {
4791     unsigned AltOpcode;
4792     switch (Opcode) {
4793     case ISD::SMIN: AltOpcode = ISD::UMIN; break;
4794     case ISD::SMAX: AltOpcode = ISD::UMAX; break;
4795     case ISD::UMIN: AltOpcode = ISD::SMIN; break;
4796     case ISD::UMAX: AltOpcode = ISD::SMAX; break;
4797     default: llvm_unreachable("Unknown MINMAX opcode");
4798     }
4799     if (TLI.isOperationLegal(AltOpcode, VT))
4800       return DAG.getNode(AltOpcode, SDLoc(N), VT, N0, N1);
4801   }
4802 
4803   // Simplify the operands using demanded-bits information.
4804   if (SimplifyDemandedBits(SDValue(N, 0)))
4805     return SDValue(N, 0);
4806 
4807   return SDValue();
4808 }
4809 
4810 /// If this is a bitwise logic instruction and both operands have the same
4811 /// opcode, try to sink the other opcode after the logic instruction.
4812 SDValue DAGCombiner::hoistLogicOpWithSameOpcodeHands(SDNode *N) {
4813   SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
4814   EVT VT = N0.getValueType();
4815   unsigned LogicOpcode = N->getOpcode();
4816   unsigned HandOpcode = N0.getOpcode();
4817   assert((LogicOpcode == ISD::AND || LogicOpcode == ISD::OR ||
4818           LogicOpcode == ISD::XOR) && "Expected logic opcode");
4819   assert(HandOpcode == N1.getOpcode() && "Bad input!");
4820 
4821   // Bail early if none of these transforms apply.
4822   if (N0.getNumOperands() == 0)
4823     return SDValue();
4824 
4825   // FIXME: We should check number of uses of the operands to not increase
4826   //        the instruction count for all transforms.
4827 
4828   // Handle size-changing casts.
4829   SDValue X = N0.getOperand(0);
4830   SDValue Y = N1.getOperand(0);
4831   EVT XVT = X.getValueType();
4832   SDLoc DL(N);
4833   if (HandOpcode == ISD::ANY_EXTEND || HandOpcode == ISD::ZERO_EXTEND ||
4834       HandOpcode == ISD::SIGN_EXTEND) {
4835     // If both operands have other uses, this transform would create extra
4836     // instructions without eliminating anything.
4837     if (!N0.hasOneUse() && !N1.hasOneUse())
4838       return SDValue();
4839     // We need matching integer source types.
4840     if (XVT != Y.getValueType())
4841       return SDValue();
4842     // Don't create an illegal op during or after legalization. Don't ever
4843     // create an unsupported vector op.
4844     if ((VT.isVector() || LegalOperations) &&
4845         !TLI.isOperationLegalOrCustom(LogicOpcode, XVT))
4846       return SDValue();
4847     // Avoid infinite looping with PromoteIntBinOp.
4848     // TODO: Should we apply desirable/legal constraints to all opcodes?
4849     if (HandOpcode == ISD::ANY_EXTEND && LegalTypes &&
4850         !TLI.isTypeDesirableForOp(LogicOpcode, XVT))
4851       return SDValue();
4852     // logic_op (hand_op X), (hand_op Y) --> hand_op (logic_op X, Y)
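    // E.g. (illustrative), (and (zext x), (zext y)) becomes
    // (zext (and x, y)), performing the logic in the narrower source type.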
4853     SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
4854     return DAG.getNode(HandOpcode, DL, VT, Logic);
4855   }
4856 
4857   // logic_op (truncate x), (truncate y) --> truncate (logic_op x, y)
4858   if (HandOpcode == ISD::TRUNCATE) {
4859     // If both operands have other uses, this transform would create extra
4860     // instructions without eliminating anything.
4861     if (!N0.hasOneUse() && !N1.hasOneUse())
4862       return SDValue();
4863     // We need matching source types.
4864     if (XVT != Y.getValueType())
4865       return SDValue();
4866     // Don't create an illegal op during or after legalization.
4867     if (LegalOperations && !TLI.isOperationLegal(LogicOpcode, XVT))
4868       return SDValue();
4869     // Be extra careful sinking truncate. If it's free, there's no benefit in
4870     // widening a binop. Also, don't create a logic op on an illegal type.
4871     if (TLI.isZExtFree(VT, XVT) && TLI.isTruncateFree(XVT, VT))
4872       return SDValue();
4873     if (!TLI.isTypeLegal(XVT))
4874       return SDValue();
4875     SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
4876     return DAG.getNode(HandOpcode, DL, VT, Logic);
4877   }
4878 
4879   // For binops SHL/SRL/SRA/AND:
4880   //   logic_op (OP x, z), (OP y, z) --> OP (logic_op x, y), z
4881   if ((HandOpcode == ISD::SHL || HandOpcode == ISD::SRL ||
4882        HandOpcode == ISD::SRA || HandOpcode == ISD::AND) &&
4883       N0.getOperand(1) == N1.getOperand(1)) {
4884     // If either operand has other uses, this transform is not an improvement.
4885     if (!N0.hasOneUse() || !N1.hasOneUse())
4886       return SDValue();
4887     SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
4888     return DAG.getNode(HandOpcode, DL, VT, Logic, N0.getOperand(1));
4889   }
4890 
4891   // Unary ops: logic_op (bswap x), (bswap y) --> bswap (logic_op x, y)
4892   if (HandOpcode == ISD::BSWAP) {
4893     // If either operand has other uses, this transform is not an improvement.
4894     if (!N0.hasOneUse() || !N1.hasOneUse())
4895       return SDValue();
4896     SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
4897     return DAG.getNode(HandOpcode, DL, VT, Logic);
4898   }
4899 
4900   // Simplify xor/and/or (bitcast(A), bitcast(B)) -> bitcast(op (A,B))
4901   // Only perform this optimization up until type legalization, before
4902   // LegalizeVectorOps. LegalizeVectorOps promotes vector operations by
4903   // adding bitcasts. For example (xor v4i32) is promoted to (v2i64), and
4904   // we don't want to undo this promotion.
4905   // We also handle SCALAR_TO_VECTOR because xor/or/and operations are cheaper
4906   // on scalars.
4907   if ((HandOpcode == ISD::BITCAST || HandOpcode == ISD::SCALAR_TO_VECTOR) &&
4908        Level <= AfterLegalizeTypes) {
4909     // Input types must be integer and the same.
4910     if (XVT.isInteger() && XVT == Y.getValueType() &&
4911         !(VT.isVector() && TLI.isTypeLegal(VT) &&
4912           !XVT.isVector() && !TLI.isTypeLegal(XVT))) {
4913       SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
4914       return DAG.getNode(HandOpcode, DL, VT, Logic);
4915     }
4916   }
4917 
4918   // Xor/and/or are indifferent to the swizzle operation (shuffle of one value).
4919   // Simplify xor/and/or (shuff(A), shuff(B)) -> shuff(op (A,B))
4920   // If both shuffles use the same mask, and both shuffle within a single
4921   // vector, then it is worthwhile to move the swizzle after the operation.
4922   // The type-legalizer generates this pattern when loading illegal
4923   // vector types from memory. In many cases this allows additional shuffle
4924   // optimizations.
4925   // There are other cases where moving the shuffle after the xor/and/or
4926   // is profitable even if shuffles don't perform a swizzle.
4927   // If both shuffles use the same mask, and both shuffles have the same first
4928   // or second operand, then it might still be profitable to move the shuffle
4929   // after the xor/and/or operation.
4930   if (HandOpcode == ISD::VECTOR_SHUFFLE && Level < AfterLegalizeDAG) {
4931     auto *SVN0 = cast<ShuffleVectorSDNode>(N0);
4932     auto *SVN1 = cast<ShuffleVectorSDNode>(N1);
4933     assert(X.getValueType() == Y.getValueType() &&
4934            "Inputs to shuffles are not the same type");
4935 
4936     // Check that both shuffles use the same mask. The masks are known to be of
4937     // the same length because the result vector type is the same.
4938     // Check also that shuffles have only one use to avoid introducing extra
4939     // instructions.
4940     if (!SVN0->hasOneUse() || !SVN1->hasOneUse() ||
4941         !SVN0->getMask().equals(SVN1->getMask()))
4942       return SDValue();
4943 
4944     // Don't try to fold this node if it requires introducing a
4945     // build vector of all zeros that might be illegal at this stage.
4946     SDValue ShOp = N0.getOperand(1);
4947     if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
4948       ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
4949 
4950     // (logic_op (shuf (A, C), shuf (B, C))) --> shuf (logic_op (A, B), C)
4951     if (N0.getOperand(1) == N1.getOperand(1) && ShOp.getNode()) {
4952       SDValue Logic = DAG.getNode(LogicOpcode, DL, VT,
4953                                   N0.getOperand(0), N1.getOperand(0));
4954       return DAG.getVectorShuffle(VT, DL, Logic, ShOp, SVN0->getMask());
4955     }
4956 
4957     // Don't try to fold this node if it requires introducing a
4958     // build vector of all zeros that might be illegal at this stage.
4959     ShOp = N0.getOperand(0);
4960     if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
4961       ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
4962 
4963     // (logic_op (shuf (C, A), shuf (C, B))) --> shuf (C, logic_op (A, B))
4964     if (N0.getOperand(0) == N1.getOperand(0) && ShOp.getNode()) {
4965       SDValue Logic = DAG.getNode(LogicOpcode, DL, VT, N0.getOperand(1),
4966                                   N1.getOperand(1));
4967       return DAG.getVectorShuffle(VT, DL, ShOp, Logic, SVN0->getMask());
4968     }
4969   }
4970 
4971   return SDValue();
4972 }
4973 
4974 /// Try to make (and/or setcc (LL, LR), setcc (RL, RR)) more efficient.
4975 SDValue DAGCombiner::foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
4976                                        const SDLoc &DL) {
4977   SDValue LL, LR, RL, RR, N0CC, N1CC;
4978   if (!isSetCCEquivalent(N0, LL, LR, N0CC) ||
4979       !isSetCCEquivalent(N1, RL, RR, N1CC))
4980     return SDValue();
4981 
4982   assert(N0.getValueType() == N1.getValueType() &&
4983          "Unexpected operand types for bitwise logic op");
4984   assert(LL.getValueType() == LR.getValueType() &&
4985          RL.getValueType() == RR.getValueType() &&
4986          "Unexpected operand types for setcc");
4987 
4988   // If we're here post-legalization or the logic op type is not i1, the logic
4989   // op type must match a setcc result type. Also, all folds require new
4990   // operations on the left and right operands, so those types must match.
4991   EVT VT = N0.getValueType();
4992   EVT OpVT = LL.getValueType();
4993   if (LegalOperations || VT.getScalarType() != MVT::i1)
4994     if (VT != getSetCCResultType(OpVT))
4995       return SDValue();
4996   if (OpVT != RL.getValueType())
4997     return SDValue();
4998 
4999   ISD::CondCode CC0 = cast<CondCodeSDNode>(N0CC)->get();
5000   ISD::CondCode CC1 = cast<CondCodeSDNode>(N1CC)->get();
5001   bool IsInteger = OpVT.isInteger();
5002   if (LR == RR && CC0 == CC1 && IsInteger) {
5003     bool IsZero = isNullOrNullSplat(LR);
5004     bool IsNeg1 = isAllOnesOrAllOnesSplat(LR);
5005 
5006     // All bits clear?
5007     bool AndEqZero = IsAnd && CC1 == ISD::SETEQ && IsZero;
5008     // All sign bits clear?
5009     bool AndGtNeg1 = IsAnd && CC1 == ISD::SETGT && IsNeg1;
5010     // Any bits set?
5011     bool OrNeZero = !IsAnd && CC1 == ISD::SETNE && IsZero;
5012     // Any sign bits set?
5013     bool OrLtZero = !IsAnd && CC1 == ISD::SETLT && IsZero;
5014 
5015     // (and (seteq X,  0), (seteq Y,  0)) --> (seteq (or X, Y),  0)
5016     // (and (setgt X, -1), (setgt Y, -1)) --> (setgt (or X, Y), -1)
5017     // (or  (setne X,  0), (setne Y,  0)) --> (setne (or X, Y),  0)
5018     // (or  (setlt X,  0), (setlt Y,  0)) --> (setlt (or X, Y),  0)
5019     if (AndEqZero || AndGtNeg1 || OrNeZero || OrLtZero) {
5020       SDValue Or = DAG.getNode(ISD::OR, SDLoc(N0), OpVT, LL, RL);
5021       AddToWorklist(Or.getNode());
5022       return DAG.getSetCC(DL, VT, Or, LR, CC1);
5023     }
5024 
5025     // All bits set?
5026     bool AndEqNeg1 = IsAnd && CC1 == ISD::SETEQ && IsNeg1;
5027     // All sign bits set?
5028     bool AndLtZero = IsAnd && CC1 == ISD::SETLT && IsZero;
5029     // Any bits clear?
5030     bool OrNeNeg1 = !IsAnd && CC1 == ISD::SETNE && IsNeg1;
5031     // Any sign bits clear?
5032     bool OrGtNeg1 = !IsAnd && CC1 == ISD::SETGT && IsNeg1;
5033 
5034     // (and (seteq X, -1), (seteq Y, -1)) --> (seteq (and X, Y), -1)
5035     // (and (setlt X,  0), (setlt Y,  0)) --> (setlt (and X, Y),  0)
5036     // (or  (setne X, -1), (setne Y, -1)) --> (setne (and X, Y), -1)
5037     // (or  (setgt X, -1), (setgt Y, -1)) --> (setgt (and X, Y), -1)
5038     if (AndEqNeg1 || AndLtZero || OrNeNeg1 || OrGtNeg1) {
5039       SDValue And = DAG.getNode(ISD::AND, SDLoc(N0), OpVT, LL, RL);
5040       AddToWorklist(And.getNode());
5041       return DAG.getSetCC(DL, VT, And, LR, CC1);
5042     }
5043   }
5044 
5045   // TODO: What is the 'or' equivalent of this fold?
5046   // (and (setne X, 0), (setne X, -1)) --> (setuge (add X, 1), 2)
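  // (Adding 1 maps X == -1 to 0 and X == 0 to 1; those, and only those,
  // fail the unsigned >= 2 test.)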
5047   if (IsAnd && LL == RL && CC0 == CC1 && OpVT.getScalarSizeInBits() > 1 &&
5048       IsInteger && CC0 == ISD::SETNE &&
5049       ((isNullConstant(LR) && isAllOnesConstant(RR)) ||
5050        (isAllOnesConstant(LR) && isNullConstant(RR)))) {
5051     SDValue One = DAG.getConstant(1, DL, OpVT);
5052     SDValue Two = DAG.getConstant(2, DL, OpVT);
5053     SDValue Add = DAG.getNode(ISD::ADD, SDLoc(N0), OpVT, LL, One);
5054     AddToWorklist(Add.getNode());
5055     return DAG.getSetCC(DL, VT, Add, Two, ISD::SETUGE);
5056   }
5057 
5058   // Try more general transforms if the predicates match and the only user of
5059   // the compares is the 'and' or 'or'.
5060   if (IsInteger && TLI.convertSetCCLogicToBitwiseLogic(OpVT) && CC0 == CC1 &&
5061       N0.hasOneUse() && N1.hasOneUse()) {
5062     // and (seteq A, B), (seteq C, D) --> seteq (or (xor A, B), (xor C, D)), 0
5063     // or  (setne A, B), (setne C, D) --> setne (or (xor A, B), (xor C, D)), 0
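    // ((xor A, B) is zero iff A == B, so the or of the two xors is zero iff
    // both equalities hold, and nonzero iff either disequality holds.)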
5064     if ((IsAnd && CC1 == ISD::SETEQ) || (!IsAnd && CC1 == ISD::SETNE)) {
5065       SDValue XorL = DAG.getNode(ISD::XOR, SDLoc(N0), OpVT, LL, LR);
5066       SDValue XorR = DAG.getNode(ISD::XOR, SDLoc(N1), OpVT, RL, RR);
5067       SDValue Or = DAG.getNode(ISD::OR, DL, OpVT, XorL, XorR);
5068       SDValue Zero = DAG.getConstant(0, DL, OpVT);
5069       return DAG.getSetCC(DL, VT, Or, Zero, CC1);
5070     }
5071 
5072     // Turn compare of constants whose difference is 1 bit into add+and+setcc.
5073     // TODO - support non-uniform vector amounts.
5074     if ((IsAnd && CC1 == ISD::SETNE) || (!IsAnd && CC1 == ISD::SETEQ)) {
5075       // Match a shared variable operand and 2 non-opaque constant operands.
5076       ConstantSDNode *C0 = isConstOrConstSplat(LR);
5077       ConstantSDNode *C1 = isConstOrConstSplat(RR);
5078       if (LL == RL && C0 && C1 && !C0->isOpaque() && !C1->isOpaque()) {
5079         const APInt &CMax =
5080             APIntOps::umax(C0->getAPIntValue(), C1->getAPIntValue());
5081         const APInt &CMin =
5082             APIntOps::umin(C0->getAPIntValue(), C1->getAPIntValue());
5083         // The difference of the constants must be a single bit.
5084         if ((CMax - CMin).isPowerOf2()) {
5085           // and/or (setcc X, CMax, ne), (setcc X, CMin, ne/eq) -->
5086           // setcc (and (sub X, CMin), ~(CMax - CMin)), 0, ne/eq
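          // Worked example (illustrative): (X != 4) & (X != 6) has CMin = 4,
          // CMax = 6, so this emits setcc (and (sub X, 4), ~2), 0, ne, which
          // is false exactly for X == 4 and X == 6.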
5087           SDValue Max = DAG.getNode(ISD::UMAX, DL, OpVT, LR, RR);
5088           SDValue Min = DAG.getNode(ISD::UMIN, DL, OpVT, LR, RR);
5089           SDValue Offset = DAG.getNode(ISD::SUB, DL, OpVT, LL, Min);
5090           SDValue Diff = DAG.getNode(ISD::SUB, DL, OpVT, Max, Min);
5091           SDValue Mask = DAG.getNOT(DL, Diff, OpVT);
5092           SDValue And = DAG.getNode(ISD::AND, DL, OpVT, Offset, Mask);
5093           SDValue Zero = DAG.getConstant(0, DL, OpVT);
5094           return DAG.getSetCC(DL, VT, And, Zero, CC0);
5095         }
5096       }
5097     }
5098   }
5099 
5100   // Canonicalize equivalent operands to LL == RL.
5101   if (LL == RR && LR == RL) {
5102     CC1 = ISD::getSetCCSwappedOperands(CC1);
5103     std::swap(RL, RR);
5104   }
5105 
5106   // (and (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
5107   // (or  (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
5108   if (LL == RL && LR == RR) {
5109     ISD::CondCode NewCC = IsAnd ? ISD::getSetCCAndOperation(CC0, CC1, OpVT)
5110                                 : ISD::getSetCCOrOperation(CC0, CC1, OpVT);
5111     if (NewCC != ISD::SETCC_INVALID &&
5112         (!LegalOperations ||
5113          (TLI.isCondCodeLegal(NewCC, LL.getSimpleValueType()) &&
5114           TLI.isOperationLegal(ISD::SETCC, OpVT))))
5115       return DAG.getSetCC(DL, VT, LL, LR, NewCC);
5116   }
5117 
5118   return SDValue();
5119 }
5120 
5121 /// This contains all DAGCombine rules which reduce two values combined by
5122 /// an And operation to a single value. This makes them reusable in the context
5123 /// of visitSELECT(). Rules involving constants are not included as
5124 /// visitSELECT() already handles those cases.
5125 SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1, SDNode *N) {
5126   EVT VT = N1.getValueType();
5127   SDLoc DL(N);
5128 
5129   // fold (and x, undef) -> 0
5130   if (N0.isUndef() || N1.isUndef())
5131     return DAG.getConstant(0, DL, VT);
5132 
5133   if (SDValue V = foldLogicOfSetCCs(true, N0, N1, DL))
5134     return V;
5135 
5136   // TODO: Rewrite this to return a new 'AND' instead of using CombineTo.
5137   if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::SRL &&
5138       VT.getSizeInBits() <= 64 && N0->hasOneUse()) {
5139     if (ConstantSDNode *ADDI = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
5140       if (ConstantSDNode *SRLI = dyn_cast<ConstantSDNode>(N1.getOperand(1))) {
5141         // Look for (and (add x, c1), (lshr y, c2)). If c1 isn't a legal
5142         // immediate for an add, but becomes legal once its top c2 bits are
5143         // set, transform the ADD so the immediate doesn't need to be
5144         // materialized in a register.
5145         APInt ADDC = ADDI->getAPIntValue();
5146         APInt SRLC = SRLI->getAPIntValue();
5147         if (ADDC.getMinSignedBits() <= 64 &&
5148             SRLC.ult(VT.getSizeInBits()) &&
5149             !TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
5150           APInt Mask = APInt::getHighBitsSet(VT.getSizeInBits(),
5151                                              SRLC.getZExtValue());
5152           if (DAG.MaskedValueIsZero(N0.getOperand(1), Mask)) {
5153             ADDC |= Mask;
5154             if (TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
5155               SDLoc DL0(N0);
5156               SDValue NewAdd =
5157                 DAG.getNode(ISD::ADD, DL0, VT,
5158                             N0.getOperand(0), DAG.getConstant(ADDC, DL, VT));
5159               CombineTo(N0.getNode(), NewAdd);
5160               // Return N so it doesn't get rechecked!
5161               return SDValue(N, 0);
5162             }
5163           }
5164         }
5165       }
5166     }
5167   }
5168 
5169   // Reduce bit extract of low half of an integer to the narrower type.
5170   // (and (srl i64:x, K), KMask) ->
5171   //   (i64 zero_extend (and (srl (i32 (trunc i64:x)), K), KMask))
5172   if (N0.getOpcode() == ISD::SRL && N0.hasOneUse()) {
5173     if (ConstantSDNode *CAnd = dyn_cast<ConstantSDNode>(N1)) {
5174       if (ConstantSDNode *CShift = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
5175         unsigned Size = VT.getSizeInBits();
5176         const APInt &AndMask = CAnd->getAPIntValue();
5177         unsigned ShiftBits = CShift->getZExtValue();
5178 
5179         // Bail out, this node will probably disappear anyway.
5180         if (ShiftBits == 0)
5181           return SDValue();
5182 
5183         unsigned MaskBits = AndMask.countTrailingOnes();
5184         EVT HalfVT = EVT::getIntegerVT(*DAG.getContext(), Size / 2);
5185 
5186         if (AndMask.isMask() &&
5187             // Required bits must not span the two halves of the integer and
5188             // must fit in the half size type.
5189             (ShiftBits + MaskBits <= Size / 2) &&
5190             TLI.isNarrowingProfitable(VT, HalfVT) &&
5191             TLI.isTypeDesirableForOp(ISD::AND, HalfVT) &&
5192             TLI.isTypeDesirableForOp(ISD::SRL, HalfVT) &&
5193             TLI.isTruncateFree(VT, HalfVT) &&
5194             TLI.isZExtFree(HalfVT, VT)) {
          // The isNarrowingProfitable check is to avoid regressions on PPC and
5196           // AArch64 which match a few 64-bit bit insert / bit extract patterns
5197           // on downstream users of this. Those patterns could probably be
5198           // extended to handle extensions mixed in.
5199 
          SDLoc SL(N0);
5201           assert(MaskBits <= Size);
5202 
5203           // Extracting the highest bit of the low half.
5204           EVT ShiftVT = TLI.getShiftAmountTy(HalfVT, DAG.getDataLayout());
5205           SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, HalfVT,
5206                                       N0.getOperand(0));
5207 
5208           SDValue NewMask = DAG.getConstant(AndMask.trunc(Size / 2), SL, HalfVT);
5209           SDValue ShiftK = DAG.getConstant(ShiftBits, SL, ShiftVT);
5210           SDValue Shift = DAG.getNode(ISD::SRL, SL, HalfVT, Trunc, ShiftK);
5211           SDValue And = DAG.getNode(ISD::AND, SL, HalfVT, Shift, NewMask);
5212           return DAG.getNode(ISD::ZERO_EXTEND, SL, VT, And);
5213         }
5214       }
5215     }
5216   }
5217 
5218   return SDValue();
5219 }
5220 
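/// Check whether the "and" mask \p AndC applied to load \p LoadN (which
/// produces \p LoadResultTy) can instead be implemented by a narrower
/// zero-extending load; on success, \p ExtVT is set to that narrower type.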
bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
                                   EVT LoadResultTy, EVT &ExtVT) {
5223   if (!AndC->getAPIntValue().isMask())
5224     return false;
5225 
5226   unsigned ActiveBits = AndC->getAPIntValue().countTrailingOnes();
5227 
5228   ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
5229   EVT LoadedVT = LoadN->getMemoryVT();
5230 
5231   if (ExtVT == LoadedVT &&
5232       (!LegalOperations ||
5233        TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))) {
5234     // ZEXTLOAD will match without needing to change the size of the value being
5235     // loaded.
5236     return true;
5237   }
5238 
  // Do not change the width of a volatile or atomic load.
5240   if (!LoadN->isSimple())
5241     return false;
5242 
5243   // Do not generate loads of non-round integer types since these can
5244   // be expensive (and would be wrong if the type is not byte sized).
5245   if (!LoadedVT.bitsGT(ExtVT) || !ExtVT.isRound())
5246     return false;
5247 
5248   if (LegalOperations &&
5249       !TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))
5250     return false;
5251 
5252   if (!TLI.shouldReduceLoadWidth(LoadN, ISD::ZEXTLOAD, ExtVT))
5253     return false;
5254 
5255   return true;
5256 }
5257 
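/// Check whether it is legal and profitable to narrow the memory access of
/// \p LDST to \p MemVT when the access is offset by \p ShAmt bits, taking
/// alignment, load-extension legality and target preferences into account.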
bool DAGCombiner::isLegalNarrowLdSt(LSBaseSDNode *LDST,
                                    ISD::LoadExtType ExtType, EVT &MemVT,
                                    unsigned ShAmt) {
5261   if (!LDST)
5262     return false;
5263   // Only allow byte offsets.
5264   if (ShAmt % 8)
5265     return false;
5266 
5267   // Do not generate loads of non-round integer types since these can
5268   // be expensive (and would be wrong if the type is not byte sized).
5269   if (!MemVT.isRound())
5270     return false;
5271 
  // Don't change the width of a volatile or atomic load.
5273   if (!LDST->isSimple())
5274     return false;
5275 
5276   EVT LdStMemVT = LDST->getMemoryVT();
5277 
5278   // Bail out when changing the scalable property, since we can't be sure that
5279   // we're actually narrowing here.
5280   if (LdStMemVT.isScalableVector() != MemVT.isScalableVector())
5281     return false;
5282 
5283   // Verify that we are actually reducing a load width here.
5284   if (LdStMemVT.bitsLT(MemVT))
5285     return false;
5286 
5287   // Ensure that this isn't going to produce an unsupported memory access.
5288   if (ShAmt) {
5289     assert(ShAmt % 8 == 0 && "ShAmt is byte offset");
5290     const unsigned ByteShAmt = ShAmt / 8;
5291     const Align LDSTAlign = LDST->getAlign();
5292     const Align NarrowAlign = commonAlignment(LDSTAlign, ByteShAmt);
5293     if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
5294                                 LDST->getAddressSpace(), NarrowAlign,
5295                                 LDST->getMemOperand()->getFlags()))
5296       return false;
5297   }
5298 
5299   // It's not possible to generate a constant of extended or untyped type.
5300   EVT PtrType = LDST->getBasePtr().getValueType();
5301   if (PtrType == MVT::Untyped || PtrType.isExtended())
5302     return false;
5303 
5304   if (isa<LoadSDNode>(LDST)) {
5305     LoadSDNode *Load = cast<LoadSDNode>(LDST);
5306     // Don't transform one with multiple uses, this would require adding a new
5307     // load.
5308     if (!SDValue(Load, 0).hasOneUse())
5309       return false;
5310 
5311     if (LegalOperations &&
5312         !TLI.isLoadExtLegal(ExtType, Load->getValueType(0), MemVT))
5313       return false;
5314 
5315     // For the transform to be legal, the load must produce only two values
5316     // (the value loaded and the chain).  Don't transform a pre-increment
5317     // load, for example, which produces an extra value.  Otherwise the
5318     // transformation is not equivalent, and the downstream logic to replace
5319     // uses gets things wrong.
5320     if (Load->getNumValues() > 2)
5321       return false;
5322 
5323     // If the load that we're shrinking is an extload and we're not just
5324     // discarding the extension we can't simply shrink the load. Bail.
5325     // TODO: It would be possible to merge the extensions in some cases.
5326     if (Load->getExtensionType() != ISD::NON_EXTLOAD &&
5327         Load->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
5328       return false;
5329 
5330     if (!TLI.shouldReduceLoadWidth(Load, ExtType, MemVT))
5331       return false;
5332   } else {
5333     assert(isa<StoreSDNode>(LDST) && "It is not a Load nor a Store SDNode");
5334     StoreSDNode *Store = cast<StoreSDNode>(LDST);
5335     // Can't write outside the original store
5336     if (Store->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
5337       return false;
5338 
5339     if (LegalOperations &&
5340         !TLI.isTruncStoreLegal(Store->getValue().getValueType(), MemVT))
5341       return false;
5342   }
5343   return true;
5344 }
5345 
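/// Recursively walk an and/or/xor tree rooted at \p N, collecting in \p Loads
/// the loads that could be narrowed under \p Mask. Nodes whose constant
/// operands would need re-masking are recorded in \p NodesWithConsts, and at
/// most one additional node may be masked explicitly via \p NodeToMask.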
bool DAGCombiner::SearchForAndLoads(SDNode *N,
                                    SmallVectorImpl<LoadSDNode*> &Loads,
                                    SmallPtrSetImpl<SDNode*> &NodesWithConsts,
                                    ConstantSDNode *Mask,
                                    SDNode *&NodeToMask) {
5351   // Recursively search for the operands, looking for loads which can be
5352   // narrowed.
5353   for (SDValue Op : N->op_values()) {
5354     if (Op.getValueType().isVector())
5355       return false;
5356 
5357     // Some constants may need fixing up later if they are too large.
5358     if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
5359       if ((N->getOpcode() == ISD::OR || N->getOpcode() == ISD::XOR) &&
5360           (Mask->getAPIntValue() & C->getAPIntValue()) != C->getAPIntValue())
5361         NodesWithConsts.insert(N);
5362       continue;
5363     }
5364 
5365     if (!Op.hasOneUse())
5366       return false;
5367 
    switch (Op.getOpcode()) {
5369     case ISD::LOAD: {
5370       auto *Load = cast<LoadSDNode>(Op);
5371       EVT ExtVT;
5372       if (isAndLoadExtLoad(Mask, Load, Load->getValueType(0), ExtVT) &&
5373           isLegalNarrowLdSt(Load, ISD::ZEXTLOAD, ExtVT)) {
5374 
5375         // ZEXTLOAD is already small enough.
5376         if (Load->getExtensionType() == ISD::ZEXTLOAD &&
5377             ExtVT.bitsGE(Load->getMemoryVT()))
5378           continue;
5379 
5380         // Use LE to convert equal sized loads to zext.
5381         if (ExtVT.bitsLE(Load->getMemoryVT()))
5382           Loads.push_back(Load);
5383 
5384         continue;
5385       }
5386       return false;
5387     }
5388     case ISD::ZERO_EXTEND:
5389     case ISD::AssertZext: {
5390       unsigned ActiveBits = Mask->getAPIntValue().countTrailingOnes();
5391       EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
5392       EVT VT = Op.getOpcode() == ISD::AssertZext ?
5393         cast<VTSDNode>(Op.getOperand(1))->getVT() :
5394         Op.getOperand(0).getValueType();
5395 
      // We can accept extending nodes if the mask is wider than or equal in
      // width to the original type.
5398       if (ExtVT.bitsGE(VT))
5399         continue;
5400       break;
5401     }
5402     case ISD::OR:
5403     case ISD::XOR:
5404     case ISD::AND:
5405       if (!SearchForAndLoads(Op.getNode(), Loads, NodesWithConsts, Mask,
5406                              NodeToMask))
5407         return false;
5408       continue;
5409     }
5410 
    // Allow one node which will be masked along with any loads found.
5412     if (NodeToMask)
5413       return false;
5414 
5415     // Also ensure that the node to be masked only produces one data result.
5416     NodeToMask = Op.getNode();
5417     if (NodeToMask->getNumValues() > 1) {
5418       bool HasValue = false;
5419       for (unsigned i = 0, e = NodeToMask->getNumValues(); i < e; ++i) {
5420         MVT VT = SDValue(NodeToMask, i).getSimpleValueType();
5421         if (VT != MVT::Glue && VT != MVT::Other) {
5422           if (HasValue) {
5423             NodeToMask = nullptr;
5424             return false;
5425           }
5426           HasValue = true;
5427         }
5428       }
5429       assert(HasValue && "Node to be masked has no data result?");
5430     }
5431   }
5432   return true;
5433 }
5434 
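/// Try to push the mask of an (and x, constant-mask) node back through the
/// tree feeding x, so that the loads at the leaves become narrow
/// zero-extending loads and the 'and' itself becomes redundant.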
bool DAGCombiner::BackwardsPropagateMask(SDNode *N) {
5436   auto *Mask = dyn_cast<ConstantSDNode>(N->getOperand(1));
5437   if (!Mask)
5438     return false;
5439 
5440   if (!Mask->getAPIntValue().isMask())
5441     return false;
5442 
5443   // No need to do anything if the and directly uses a load.
5444   if (isa<LoadSDNode>(N->getOperand(0)))
5445     return false;
5446 
5447   SmallVector<LoadSDNode*, 8> Loads;
5448   SmallPtrSet<SDNode*, 2> NodesWithConsts;
5449   SDNode *FixupNode = nullptr;
5450   if (SearchForAndLoads(N, Loads, NodesWithConsts, Mask, FixupNode)) {
    if (Loads.empty())
5452       return false;
5453 
5454     LLVM_DEBUG(dbgs() << "Backwards propagate AND: "; N->dump());
5455     SDValue MaskOp = N->getOperand(1);
5456 
    // If it exists, fix up the single node we allow in the tree that needs
    // masking.
5459     if (FixupNode) {
5460       LLVM_DEBUG(dbgs() << "First, need to fix up: "; FixupNode->dump());
5461       SDValue And = DAG.getNode(ISD::AND, SDLoc(FixupNode),
5462                                 FixupNode->getValueType(0),
5463                                 SDValue(FixupNode, 0), MaskOp);
5464       DAG.ReplaceAllUsesOfValueWith(SDValue(FixupNode, 0), And);
      if (And.getOpcode() == ISD::AND)
5466         DAG.UpdateNodeOperands(And.getNode(), SDValue(FixupNode, 0), MaskOp);
5467     }
5468 
5469     // Narrow any constants that need it.
5470     for (auto *LogicN : NodesWithConsts) {
5471       SDValue Op0 = LogicN->getOperand(0);
5472       SDValue Op1 = LogicN->getOperand(1);
5473 
      if (isa<ConstantSDNode>(Op0))
        std::swap(Op0, Op1);
5476 
5477       SDValue And = DAG.getNode(ISD::AND, SDLoc(Op1), Op1.getValueType(),
5478                                 Op1, MaskOp);
5479 
5480       DAG.UpdateNodeOperands(LogicN, Op0, And);
5481     }
5482 
5483     // Create narrow loads.
5484     for (auto *Load : Loads) {
5485       LLVM_DEBUG(dbgs() << "Propagate AND back to: "; Load->dump());
5486       SDValue And = DAG.getNode(ISD::AND, SDLoc(Load), Load->getValueType(0),
5487                                 SDValue(Load, 0), MaskOp);
5488       DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), And);
      if (And.getOpcode() == ISD::AND)
5490         And = SDValue(
5491             DAG.UpdateNodeOperands(And.getNode(), SDValue(Load, 0), MaskOp), 0);
5492       SDValue NewLoad = ReduceLoadWidth(And.getNode());
5493       assert(NewLoad &&
5494              "Shouldn't be masking the load if it can't be narrowed");
5495       CombineTo(Load, NewLoad, NewLoad.getValue(1));
5496     }
5497     DAG.ReplaceAllUsesWith(N, N->getOperand(0).getNode());
5498     return true;
5499   }
5500   return false;
5501 }
5502 
5503 // Unfold
5504 //    x &  (-1 'logical shift' y)
5505 // To
5506 //    (x 'opposite logical shift' y) 'logical shift' y
5507 // if it is better for performance.
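// E.g.:
//    (x & (-1 << y))  -->  ((x >> y) << y)
//    (x & (-1 >> y))  -->  ((x << y) >> y)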
SDValue DAGCombiner::unfoldExtremeBitClearingToShifts(SDNode *N) {
5509   assert(N->getOpcode() == ISD::AND);
5510 
5511   SDValue N0 = N->getOperand(0);
5512   SDValue N1 = N->getOperand(1);
5513 
5514   // Do we actually prefer shifts over mask?
5515   if (!TLI.shouldFoldMaskToVariableShiftPair(N0))
5516     return SDValue();
5517 
5518   // Try to match  (-1 '[outer] logical shift' y)
5519   unsigned OuterShift;
5520   unsigned InnerShift; // The opposite direction to the OuterShift.
5521   SDValue Y;           // Shift amount.
5522   auto matchMask = [&OuterShift, &InnerShift, &Y](SDValue M) -> bool {
5523     if (!M.hasOneUse())
5524       return false;
5525     OuterShift = M->getOpcode();
5526     if (OuterShift == ISD::SHL)
5527       InnerShift = ISD::SRL;
5528     else if (OuterShift == ISD::SRL)
5529       InnerShift = ISD::SHL;
5530     else
5531       return false;
5532     if (!isAllOnesConstant(M->getOperand(0)))
5533       return false;
5534     Y = M->getOperand(1);
5535     return true;
5536   };
5537 
5538   SDValue X;
5539   if (matchMask(N1))
5540     X = N0;
5541   else if (matchMask(N0))
5542     X = N1;
5543   else
5544     return SDValue();
5545 
5546   SDLoc DL(N);
5547   EVT VT = N->getValueType(0);
5548 
5549   //     tmp = x   'opposite logical shift' y
5550   SDValue T0 = DAG.getNode(InnerShift, DL, VT, X, Y);
5551   //     ret = tmp 'logical shift' y
5552   SDValue T1 = DAG.getNode(OuterShift, DL, VT, T0, Y);
5553 
5554   return T1;
5555 }
5556 
5557 /// Try to replace shift/logic that tests if a bit is clear with mask + setcc.
5558 /// For a target with a bit test, this is expected to become test + set and save
5559 /// at least 1 instruction.
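/// e.g. (and (not (srl X, 5)), 1) --> zext ((and X, (1 << 5)) == 0)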
static SDValue combineShiftAnd1ToBitTest(SDNode *And, SelectionDAG &DAG) {
5561   assert(And->getOpcode() == ISD::AND && "Expected an 'and' op");
5562 
5563   // This is probably not worthwhile without a supported type.
5564   EVT VT = And->getValueType(0);
5565   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
5566   if (!TLI.isTypeLegal(VT))
5567     return SDValue();
5568 
5569   // Look through an optional extension and find a 'not'.
5570   // TODO: Should we favor test+set even without the 'not' op?
5571   SDValue Not = And->getOperand(0), And1 = And->getOperand(1);
5572   if (Not.getOpcode() == ISD::ANY_EXTEND)
5573     Not = Not.getOperand(0);
5574   if (!isBitwiseNot(Not) || !Not.hasOneUse() || !isOneConstant(And1))
5575     return SDValue();
5576 
  // Look through an optional truncation. The source operand may not be the same
5578   // type as the original 'and', but that is ok because we are masking off
5579   // everything but the low bit.
5580   SDValue Srl = Not.getOperand(0);
5581   if (Srl.getOpcode() == ISD::TRUNCATE)
5582     Srl = Srl.getOperand(0);
5583 
5584   // Match a shift-right by constant.
5585   if (Srl.getOpcode() != ISD::SRL || !Srl.hasOneUse() ||
5586       !isa<ConstantSDNode>(Srl.getOperand(1)))
5587     return SDValue();
5588 
5589   // We might have looked through casts that make this transform invalid.
5590   // TODO: If the source type is wider than the result type, do the mask and
5591   //       compare in the source type.
5592   const APInt &ShiftAmt = Srl.getConstantOperandAPInt(1);
5593   unsigned VTBitWidth = VT.getSizeInBits();
5594   if (ShiftAmt.uge(VTBitWidth))
5595     return SDValue();
5596 
5597   // Turn this into a bit-test pattern using mask op + setcc:
5598   // and (not (srl X, C)), 1 --> (and X, 1<<C) == 0
5599   SDLoc DL(And);
5600   SDValue X = DAG.getZExtOrTrunc(Srl.getOperand(0), DL, VT);
5601   EVT CCVT = TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
5602   SDValue Mask = DAG.getConstant(
5603       APInt::getOneBitSet(VTBitWidth, ShiftAmt.getZExtValue()), DL, VT);
5604   SDValue NewAnd = DAG.getNode(ISD::AND, DL, VT, X, Mask);
5605   SDValue Zero = DAG.getConstant(0, DL, VT);
5606   SDValue Setcc = DAG.getSetCC(DL, CCVT, NewAnd, Zero, ISD::SETEQ);
5607   return DAG.getZExtOrTrunc(Setcc, DL, VT);
5608 }
5609 
SDValue DAGCombiner::visitAND(SDNode *N) {
5611   SDValue N0 = N->getOperand(0);
5612   SDValue N1 = N->getOperand(1);
5613   EVT VT = N1.getValueType();
5614 
5615   // x & x --> x
5616   if (N0 == N1)
5617     return N0;
5618 
5619   // fold vector ops
5620   if (VT.isVector()) {
5621     if (SDValue FoldedVOp = SimplifyVBinOp(N))
5622       return FoldedVOp;
5623 
5624     // fold (and x, 0) -> 0, vector edition
5625     if (ISD::isConstantSplatVectorAllZeros(N0.getNode()))
      // do not return N0, because an undef node may exist in N0
5627       return DAG.getConstant(APInt::getNullValue(N0.getScalarValueSizeInBits()),
5628                              SDLoc(N), N0.getValueType());
5629     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
      // do not return N1, because an undef node may exist in N1
5631       return DAG.getConstant(APInt::getNullValue(N1.getScalarValueSizeInBits()),
5632                              SDLoc(N), N1.getValueType());
5633 
5634     // fold (and x, -1) -> x, vector edition
5635     if (ISD::isConstantSplatVectorAllOnes(N0.getNode()))
5636       return N1;
5637     if (ISD::isConstantSplatVectorAllOnes(N1.getNode()))
5638       return N0;
5639 
5640     // fold (and (masked_load) (build_vec (x, ...))) to zext_masked_load
5641     auto *MLoad = dyn_cast<MaskedLoadSDNode>(N0);
5642     auto *BVec = dyn_cast<BuildVectorSDNode>(N1);
5643     if (MLoad && BVec && MLoad->getExtensionType() == ISD::EXTLOAD &&
5644         N0.hasOneUse() && N1.hasOneUse()) {
5645       EVT LoadVT = MLoad->getMemoryVT();
5646       EVT ExtVT = VT;
5647       if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, ExtVT, LoadVT)) {
        // For this AND to be a zero extension of the masked load, the
        // elements of the BuildVec must mask the bottom bits of the extended
        // element type.
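        // For example, when extending from i8 elements the splat must be 0xFF.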
5651         if (ConstantSDNode *Splat = BVec->getConstantSplatNode()) {
5652           uint64_t ElementSize =
5653               LoadVT.getVectorElementType().getScalarSizeInBits();
5654           if (Splat->getAPIntValue().isMask(ElementSize)) {
5655             return DAG.getMaskedLoad(
5656                 ExtVT, SDLoc(N), MLoad->getChain(), MLoad->getBasePtr(),
5657                 MLoad->getOffset(), MLoad->getMask(), MLoad->getPassThru(),
5658                 LoadVT, MLoad->getMemOperand(), MLoad->getAddressingMode(),
5659                 ISD::ZEXTLOAD, MLoad->isExpandingLoad());
5660           }
5661         }
5662       }
5663     }
5664   }
5665 
5666   // fold (and c1, c2) -> c1&c2
5667   ConstantSDNode *N1C = isConstOrConstSplat(N1);
5668   if (SDValue C = DAG.FoldConstantArithmetic(ISD::AND, SDLoc(N), VT, {N0, N1}))
5669     return C;
5670 
5671   // canonicalize constant to RHS
5672   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
5673       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
5674     return DAG.getNode(ISD::AND, SDLoc(N), VT, N1, N0);
5675 
5676   // fold (and x, -1) -> x
5677   if (isAllOnesConstant(N1))
5678     return N0;
5679 
5680   // if (and x, c) is known to be zero, return 0
5681   unsigned BitWidth = VT.getScalarSizeInBits();
5682   if (N1C && DAG.MaskedValueIsZero(SDValue(N, 0),
5683                                    APInt::getAllOnesValue(BitWidth)))
5684     return DAG.getConstant(0, SDLoc(N), VT);
5685 
5686   if (SDValue NewSel = foldBinOpIntoSelect(N))
5687     return NewSel;
5688 
5689   // reassociate and
5690   if (SDValue RAND = reassociateOps(ISD::AND, SDLoc(N), N0, N1, N->getFlags()))
5691     return RAND;
5692 
5693   // Try to convert a constant mask AND into a shuffle clear mask.
5694   if (VT.isVector())
5695     if (SDValue Shuffle = XformToShuffleWithZero(N))
5696       return Shuffle;
5697 
5698   if (SDValue Combined = combineCarryDiamond(*this, DAG, TLI, N0, N1, N))
5699     return Combined;
5700 
5701   // fold (and (or x, C), D) -> D if (C & D) == D
5702   auto MatchSubset = [](ConstantSDNode *LHS, ConstantSDNode *RHS) {
5703     return RHS->getAPIntValue().isSubsetOf(LHS->getAPIntValue());
5704   };
5705   if (N0.getOpcode() == ISD::OR &&
5706       ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchSubset))
5707     return N1;
5708   // fold (and (any_ext V), c) -> (zero_ext V) if 'and' only clears top bits.
5709   if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
5710     SDValue N0Op0 = N0.getOperand(0);
5711     APInt Mask = ~N1C->getAPIntValue();
5712     Mask = Mask.trunc(N0Op0.getScalarValueSizeInBits());
5713     if (DAG.MaskedValueIsZero(N0Op0, Mask)) {
5714       SDValue Zext = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N),
5715                                  N0.getValueType(), N0Op0);
5716 
5717       // Replace uses of the AND with uses of the Zero extend node.
5718       CombineTo(N, Zext);
5719 
5720       // We actually want to replace all uses of the any_extend with the
5721       // zero_extend, to avoid duplicating things.  This will later cause this
5722       // AND to be folded.
5723       CombineTo(N0.getNode(), Zext);
5724       return SDValue(N, 0);   // Return N so it doesn't get rechecked!
5725     }
5726   }
5727 
5728   // similarly fold (and (X (load ([non_ext|any_ext|zero_ext] V))), c) ->
5729   // (X (load ([non_ext|zero_ext] V))) if 'and' only clears top bits which must
5730   // already be zero by virtue of the width of the base type of the load.
5731   //
5732   // the 'X' node here can either be nothing or an extract_vector_elt to catch
5733   // more cases.
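  // e.g. (and (extload x, i8), 0xFF) -> (zextload x, i8): the mask only keeps
  // bits that an i8 zero-extending load produces anyway.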
5734   if ((N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
5735        N0.getValueSizeInBits() == N0.getOperand(0).getScalarValueSizeInBits() &&
5736        N0.getOperand(0).getOpcode() == ISD::LOAD &&
5737        N0.getOperand(0).getResNo() == 0) ||
5738       (N0.getOpcode() == ISD::LOAD && N0.getResNo() == 0)) {
5739     LoadSDNode *Load = cast<LoadSDNode>( (N0.getOpcode() == ISD::LOAD) ?
5740                                          N0 : N0.getOperand(0) );
5741 
5742     // Get the constant (if applicable) the zero'th operand is being ANDed with.
5743     // This can be a pure constant or a vector splat, in which case we treat the
5744     // vector as a scalar and use the splat value.
5745     APInt Constant = APInt::getNullValue(1);
5746     if (const ConstantSDNode *C = dyn_cast<ConstantSDNode>(N1)) {
5747       Constant = C->getAPIntValue();
5748     } else if (BuildVectorSDNode *Vector = dyn_cast<BuildVectorSDNode>(N1)) {
5749       APInt SplatValue, SplatUndef;
5750       unsigned SplatBitSize;
5751       bool HasAnyUndefs;
5752       bool IsSplat = Vector->isConstantSplat(SplatValue, SplatUndef,
5753                                              SplatBitSize, HasAnyUndefs);
5754       if (IsSplat) {
5755         // Undef bits can contribute to a possible optimisation if set, so
5756         // set them.
5757         SplatValue |= SplatUndef;
5758 
5759         // The splat value may be something like "0x00FFFFFF", which means 0 for
5760         // the first vector value and FF for the rest, repeating. We need a mask
5761         // that will apply equally to all members of the vector, so AND all the
5762         // lanes of the constant together.
5763         unsigned EltBitWidth = Vector->getValueType(0).getScalarSizeInBits();
5764 
5765         // If the splat value has been compressed to a bitlength lower
5766         // than the size of the vector lane, we need to re-expand it to
5767         // the lane size.
5768         if (EltBitWidth > SplatBitSize)
5769           for (SplatValue = SplatValue.zextOrTrunc(EltBitWidth);
5770                SplatBitSize < EltBitWidth; SplatBitSize = SplatBitSize * 2)
5771             SplatValue |= SplatValue.shl(SplatBitSize);
5772 
        // Make sure that variable 'Constant' is only set if 'SplatBitSize' is
        // a multiple of 'EltBitWidth'. Otherwise, we could propagate a wrong
        // value.
5775         if ((SplatBitSize % EltBitWidth) == 0) {
5776           Constant = APInt::getAllOnesValue(EltBitWidth);
5777           for (unsigned i = 0, n = (SplatBitSize / EltBitWidth); i < n; ++i)
5778             Constant &= SplatValue.extractBits(EltBitWidth, i * EltBitWidth);
5779         }
5780       }
5781     }
5782 
5783     // If we want to change an EXTLOAD to a ZEXTLOAD, ensure a ZEXTLOAD is
5784     // actually legal and isn't going to get expanded, else this is a false
5785     // optimisation.
5786     bool CanZextLoadProfitably = TLI.isLoadExtLegal(ISD::ZEXTLOAD,
5787                                                     Load->getValueType(0),
5788                                                     Load->getMemoryVT());
5789 
5790     // Resize the constant to the same size as the original memory access before
5791     // extension. If it is still the AllOnesValue then this AND is completely
5792     // unneeded.
5793     Constant = Constant.zextOrTrunc(Load->getMemoryVT().getScalarSizeInBits());
5794 
5795     bool B;
5796     switch (Load->getExtensionType()) {
5797     default: B = false; break;
5798     case ISD::EXTLOAD: B = CanZextLoadProfitably; break;
5799     case ISD::ZEXTLOAD:
5800     case ISD::NON_EXTLOAD: B = true; break;
5801     }
5802 
5803     if (B && Constant.isAllOnesValue()) {
5804       // If the load type was an EXTLOAD, convert to ZEXTLOAD in order to
5805       // preserve semantics once we get rid of the AND.
5806       SDValue NewLoad(Load, 0);
5807 
5808       // Fold the AND away. NewLoad may get replaced immediately.
5809       CombineTo(N, (N0.getNode() == Load) ? NewLoad : N0);
5810 
5811       if (Load->getExtensionType() == ISD::EXTLOAD) {
5812         NewLoad = DAG.getLoad(Load->getAddressingMode(), ISD::ZEXTLOAD,
5813                               Load->getValueType(0), SDLoc(Load),
5814                               Load->getChain(), Load->getBasePtr(),
5815                               Load->getOffset(), Load->getMemoryVT(),
5816                               Load->getMemOperand());
5817         // Replace uses of the EXTLOAD with the new ZEXTLOAD.
5818         if (Load->getNumValues() == 3) {
5819           // PRE/POST_INC loads have 3 values.
5820           SDValue To[] = { NewLoad.getValue(0), NewLoad.getValue(1),
5821                            NewLoad.getValue(2) };
5822           CombineTo(Load, To, 3, true);
5823         } else {
5824           CombineTo(Load, NewLoad.getValue(0), NewLoad.getValue(1));
5825         }
5826       }
5827 
5828       return SDValue(N, 0); // Return N so it doesn't get rechecked!
5829     }
5830   }
5831 
5832   // fold (and (masked_gather x)) -> (zext_masked_gather x)
5833   if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(N0)) {
5834     EVT MemVT = GN0->getMemoryVT();
5835     EVT ScalarVT = MemVT.getScalarType();
5836 
5837     if (SDValue(GN0, 0).hasOneUse() &&
5838         isConstantSplatVectorMaskForType(N1.getNode(), ScalarVT) &&
        TLI.isVectorLoadExtDesirable(SDValue(GN0, 0))) {
5840       SDValue Ops[] = {GN0->getChain(),   GN0->getPassThru(), GN0->getMask(),
5841                        GN0->getBasePtr(), GN0->getIndex(),    GN0->getScale()};
5842 
5843       SDValue ZExtLoad = DAG.getMaskedGather(
5844           DAG.getVTList(VT, MVT::Other), MemVT, SDLoc(N), Ops,
5845           GN0->getMemOperand(), GN0->getIndexType(), ISD::ZEXTLOAD);
5846 
5847       CombineTo(N, ZExtLoad);
5848       AddToWorklist(ZExtLoad.getNode());
5849       // Avoid recheck of N.
5850       return SDValue(N, 0);
5851     }
5852   }
5853 
5854   // fold (and (load x), 255) -> (zextload x, i8)
5855   // fold (and (extload x, i16), 255) -> (zextload x, i8)
5856   // fold (and (any_ext (extload x, i16)), 255) -> (zextload x, i8)
5857   if (!VT.isVector() && N1C && (N0.getOpcode() == ISD::LOAD ||
5858                                 (N0.getOpcode() == ISD::ANY_EXTEND &&
5859                                  N0.getOperand(0).getOpcode() == ISD::LOAD))) {
5860     if (SDValue Res = ReduceLoadWidth(N)) {
5861       LoadSDNode *LN0 = N0->getOpcode() == ISD::ANY_EXTEND
5862         ? cast<LoadSDNode>(N0.getOperand(0)) : cast<LoadSDNode>(N0);
5863       AddToWorklist(N);
5864       DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 0), Res);
5865       return SDValue(N, 0);
5866     }
5867   }
5868 
5869   if (LegalTypes) {
5870     // Attempt to propagate the AND back up to the leaves which, if they're
5871     // loads, can be combined to narrow loads and the AND node can be removed.
5872     // Perform after legalization so that extend nodes will already be
5873     // combined into the loads.
5874     if (BackwardsPropagateMask(N))
5875       return SDValue(N, 0);
5876   }
5877 
5878   if (SDValue Combined = visitANDLike(N0, N1, N))
5879     return Combined;
5880 
5881   // Simplify: (and (op x...), (op y...))  -> (op (and x, y))
5882   if (N0.getOpcode() == N1.getOpcode())
5883     if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
5884       return V;
5885 
5886   // Masking the negated extension of a boolean is just the zero-extended
5887   // boolean:
5888   // and (sub 0, zext(bool X)), 1 --> zext(bool X)
5889   // and (sub 0, sext(bool X)), 1 --> zext(bool X)
5890   //
5891   // Note: the SimplifyDemandedBits fold below can make an information-losing
5892   // transform, and then we have no way to find this better fold.
5893   if (N1C && N1C->isOne() && N0.getOpcode() == ISD::SUB) {
5894     if (isNullOrNullSplat(N0.getOperand(0))) {
5895       SDValue SubRHS = N0.getOperand(1);
5896       if (SubRHS.getOpcode() == ISD::ZERO_EXTEND &&
5897           SubRHS.getOperand(0).getScalarValueSizeInBits() == 1)
5898         return SubRHS;
5899       if (SubRHS.getOpcode() == ISD::SIGN_EXTEND &&
5900           SubRHS.getOperand(0).getScalarValueSizeInBits() == 1)
5901         return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, SubRHS.getOperand(0));
5902     }
5903   }
5904 
5905   // fold (and (sign_extend_inreg x, i16 to i32), 1) -> (and x, 1)
5906   // fold (and (sra)) -> (and (srl)) when possible.
5907   if (SimplifyDemandedBits(SDValue(N, 0)))
5908     return SDValue(N, 0);
5909 
5910   // fold (zext_inreg (extload x)) -> (zextload x)
5911   // fold (zext_inreg (sextload x)) -> (zextload x) iff load has one use
5912   if (ISD::isUNINDEXEDLoad(N0.getNode()) &&
5913       (ISD::isEXTLoad(N0.getNode()) ||
5914        (ISD::isSEXTLoad(N0.getNode()) && N0.hasOneUse()))) {
5915     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
5916     EVT MemVT = LN0->getMemoryVT();
5917     // If we zero all the possible extended bits, then we can turn this into
5918     // a zextload if we are running before legalize or the operation is legal.
5919     unsigned ExtBitSize = N1.getScalarValueSizeInBits();
5920     unsigned MemBitSize = MemVT.getScalarSizeInBits();
5921     APInt ExtBits = APInt::getHighBitsSet(ExtBitSize, ExtBitSize - MemBitSize);
5922     if (DAG.MaskedValueIsZero(N1, ExtBits) &&
5923         ((!LegalOperations && LN0->isSimple()) ||
5924          TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT))) {
5925       SDValue ExtLoad =
5926           DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(N0), VT, LN0->getChain(),
5927                          LN0->getBasePtr(), MemVT, LN0->getMemOperand());
5928       AddToWorklist(N);
5929       CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
5930       return SDValue(N, 0); // Return N so it doesn't get rechecked!
5931     }
5932   }
5933 
5934   // fold (and (or (srl N, 8), (shl N, 8)), 0xffff) -> (srl (bswap N), const)
5935   if (N1C && N1C->getAPIntValue() == 0xffff && N0.getOpcode() == ISD::OR) {
5936     if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
5937                                            N0.getOperand(1), false))
5938       return BSwap;
5939   }
5940 
5941   if (SDValue Shifts = unfoldExtremeBitClearingToShifts(N))
5942     return Shifts;
5943 
5944   if (TLI.hasBitTest(N0, N1))
5945     if (SDValue V = combineShiftAnd1ToBitTest(N, DAG))
5946       return V;
5947 
5948   // Recognize the following pattern:
5949   //
5950   // AndVT = (and (sign_extend NarrowVT to AndVT) #bitmask)
5951   //
  // where bitmask is a mask that clears the upper bits of AndVT. The number
  // of set bits in bitmask must equal the bit width of NarrowVT.
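  // e.g. (and (sign_extend i8:x to i32), 0xFF) --> (zero_extend i8:x to i32)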
5954   auto IsAndZeroExtMask = [](SDValue LHS, SDValue RHS) {
5955     if (LHS->getOpcode() != ISD::SIGN_EXTEND)
5956       return false;
5957 
5958     auto *C = dyn_cast<ConstantSDNode>(RHS);
5959     if (!C)
5960       return false;
5961 
5962     if (!C->getAPIntValue().isMask(
5963             LHS.getOperand(0).getValueType().getFixedSizeInBits()))
5964       return false;
5965 
5966     return true;
5967   };
5968 
5969   // Replace (and (sign_extend ...) #bitmask) with (zero_extend ...).
5970   if (IsAndZeroExtMask(N0, N1))
5971     return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, N0.getOperand(0));
5972 
5973   return SDValue();
5974 }
5975 
5976 /// Match (a >> 8) | (a << 8) as (bswap a) >> 16.
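/// e.g. for i32:
///   (or (and (srl a, 8), 0xff), (and (shl a, 8), 0xff00))
///     --> (srl (bswap a), 16)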
SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
                                        bool DemandHighBits) {
5979   if (!LegalOperations)
5980     return SDValue();
5981 
5982   EVT VT = N->getValueType(0);
5983   if (VT != MVT::i64 && VT != MVT::i32 && VT != MVT::i16)
5984     return SDValue();
5985   if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
5986     return SDValue();
5987 
5988   // Recognize (and (shl a, 8), 0xff00), (and (srl a, 8), 0xff)
5989   bool LookPassAnd0 = false;
5990   bool LookPassAnd1 = false;
5991   if (N0.getOpcode() == ISD::AND && N0.getOperand(0).getOpcode() == ISD::SRL)
5992       std::swap(N0, N1);
5993   if (N1.getOpcode() == ISD::AND && N1.getOperand(0).getOpcode() == ISD::SHL)
5994       std::swap(N0, N1);
5995   if (N0.getOpcode() == ISD::AND) {
5996     if (!N0.getNode()->hasOneUse())
5997       return SDValue();
5998     ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
5999     // Also handle 0xffff since the LHS is guaranteed to have zeros there.
6000     // This is needed for X86.
6001     if (!N01C || (N01C->getZExtValue() != 0xFF00 &&
6002                   N01C->getZExtValue() != 0xFFFF))
6003       return SDValue();
6004     N0 = N0.getOperand(0);
6005     LookPassAnd0 = true;
6006   }
6007 
6008   if (N1.getOpcode() == ISD::AND) {
6009     if (!N1.getNode()->hasOneUse())
6010       return SDValue();
6011     ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(N1.getOperand(1));
6012     if (!N11C || N11C->getZExtValue() != 0xFF)
6013       return SDValue();
6014     N1 = N1.getOperand(0);
6015     LookPassAnd1 = true;
6016   }
6017 
6018   if (N0.getOpcode() == ISD::SRL && N1.getOpcode() == ISD::SHL)
6019     std::swap(N0, N1);
6020   if (N0.getOpcode() != ISD::SHL || N1.getOpcode() != ISD::SRL)
6021     return SDValue();
6022   if (!N0.getNode()->hasOneUse() || !N1.getNode()->hasOneUse())
6023     return SDValue();
6024 
6025   ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
6026   ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(N1.getOperand(1));
6027   if (!N01C || !N11C)
6028     return SDValue();
6029   if (N01C->getZExtValue() != 8 || N11C->getZExtValue() != 8)
6030     return SDValue();
6031 
6032   // Look for (shl (and a, 0xff), 8), (srl (and a, 0xff00), 8)
6033   SDValue N00 = N0->getOperand(0);
6034   if (!LookPassAnd0 && N00.getOpcode() == ISD::AND) {
6035     if (!N00.getNode()->hasOneUse())
6036       return SDValue();
6037     ConstantSDNode *N001C = dyn_cast<ConstantSDNode>(N00.getOperand(1));
6038     if (!N001C || N001C->getZExtValue() != 0xFF)
6039       return SDValue();
6040     N00 = N00.getOperand(0);
6041     LookPassAnd0 = true;
6042   }
6043 
6044   SDValue N10 = N1->getOperand(0);
6045   if (!LookPassAnd1 && N10.getOpcode() == ISD::AND) {
6046     if (!N10.getNode()->hasOneUse())
6047       return SDValue();
6048     ConstantSDNode *N101C = dyn_cast<ConstantSDNode>(N10.getOperand(1));
6049     // Also allow 0xFFFF since the bits will be shifted out. This is needed
6050     // for X86.
6051     if (!N101C || (N101C->getZExtValue() != 0xFF00 &&
6052                    N101C->getZExtValue() != 0xFFFF))
6053       return SDValue();
6054     N10 = N10.getOperand(0);
6055     LookPassAnd1 = true;
6056   }
6057 
6058   if (N00 != N10)
6059     return SDValue();
6060 
6061   // Make sure everything beyond the low halfword gets set to zero since the SRL
6062   // 16 will clear the top bits.
6063   unsigned OpSizeInBits = VT.getSizeInBits();
6064   if (DemandHighBits && OpSizeInBits > 16) {
6065     // If the left-shift isn't masked out then the only way this is a bswap is
6066     // if all bits beyond the low 8 are 0. In that case the entire pattern
6067     // reduces to a left shift anyway: leave it for other parts of the combiner.
6068     if (!LookPassAnd0)
6069       return SDValue();
6070 
6071     // However, if the right shift isn't masked out then it might be because
6072     // it's not needed. See if we can spot that too.
6073     if (!LookPassAnd1 &&
6074         !DAG.MaskedValueIsZero(
6075             N10, APInt::getHighBitsSet(OpSizeInBits, OpSizeInBits - 16)))
6076       return SDValue();
6077   }
6078 
6079   SDValue Res = DAG.getNode(ISD::BSWAP, SDLoc(N), VT, N00);
6080   if (OpSizeInBits > 16) {
6081     SDLoc DL(N);
6082     Res = DAG.getNode(ISD::SRL, DL, VT, Res,
6083                       DAG.getConstant(OpSizeInBits - 16, DL,
6084                                       getShiftAmountTy(VT)));
6085   }
6086   return Res;
6087 }
6088 
6089 /// Return true if the specified node is an element that makes up a 32-bit
6090 /// packed halfword byteswap.
6091 /// ((x & 0x000000ff) << 8) |
6092 /// ((x & 0x0000ff00) >> 8) |
6093 /// ((x & 0x00ff0000) << 8) |
6094 /// ((x & 0xff000000) >> 8)
static bool isBSwapHWordElement(SDValue N, MutableArrayRef<SDNode *> Parts) {
6096   if (!N.getNode()->hasOneUse())
6097     return false;
6098 
6099   unsigned Opc = N.getOpcode();
6100   if (Opc != ISD::AND && Opc != ISD::SHL && Opc != ISD::SRL)
6101     return false;
6102 
6103   SDValue N0 = N.getOperand(0);
6104   unsigned Opc0 = N0.getOpcode();
6105   if (Opc0 != ISD::AND && Opc0 != ISD::SHL && Opc0 != ISD::SRL)
6106     return false;
6107 
6108   ConstantSDNode *N1C = nullptr;
6109   // SHL or SRL: look upstream for AND mask operand
6110   if (Opc == ISD::AND)
6111     N1C = dyn_cast<ConstantSDNode>(N.getOperand(1));
6112   else if (Opc0 == ISD::AND)
6113     N1C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
6114   if (!N1C)
6115     return false;
6116 
6117   unsigned MaskByteOffset;
6118   switch (N1C->getZExtValue()) {
6119   default:
6120     return false;
6121   case 0xFF:       MaskByteOffset = 0; break;
6122   case 0xFF00:     MaskByteOffset = 1; break;
6123   case 0xFFFF:
6124     // In case demanded bits didn't clear the bits that will be shifted out.
6125     // This is needed for X86.
6126     if (Opc == ISD::SRL || (Opc == ISD::AND && Opc0 == ISD::SHL)) {
6127       MaskByteOffset = 1;
6128       break;
6129     }
6130     return false;
6131   case 0xFF0000:   MaskByteOffset = 2; break;
6132   case 0xFF000000: MaskByteOffset = 3; break;
6133   }
6134 
6135   // Look for (x & 0xff) << 8 as well as ((x << 8) & 0xff00).
6136   if (Opc == ISD::AND) {
6137     if (MaskByteOffset == 0 || MaskByteOffset == 2) {
6138       // (x >> 8) & 0xff
6139       // (x >> 8) & 0xff0000
6140       if (Opc0 != ISD::SRL)
6141         return false;
6142       ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
6143       if (!C || C->getZExtValue() != 8)
6144         return false;
6145     } else {
6146       // (x << 8) & 0xff00
6147       // (x << 8) & 0xff000000
6148       if (Opc0 != ISD::SHL)
6149         return false;
6150       ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
6151       if (!C || C->getZExtValue() != 8)
6152         return false;
6153     }
6154   } else if (Opc == ISD::SHL) {
6155     // (x & 0xff) << 8
6156     // (x & 0xff0000) << 8
6157     if (MaskByteOffset != 0 && MaskByteOffset != 2)
6158       return false;
6159     ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1));
6160     if (!C || C->getZExtValue() != 8)
6161       return false;
6162   } else { // Opc == ISD::SRL
6163     // (x & 0xff00) >> 8
6164     // (x & 0xff000000) >> 8
6165     if (MaskByteOffset != 1 && MaskByteOffset != 3)
6166       return false;
6167     ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1));
6168     if (!C || C->getZExtValue() != 8)
6169       return false;
6170   }
6171 
6172   if (Parts[MaskByteOffset])
6173     return false;
6174 
6175   Parts[MaskByteOffset] = N0.getOperand(0).getNode();
6176   return true;
6177 }
6178 
6179 // Match 2 elements of a packed halfword bswap.
static bool isBSwapHWordPair(SDValue N, MutableArrayRef<SDNode *> Parts) {
6181   if (N.getOpcode() == ISD::OR)
6182     return isBSwapHWordElement(N.getOperand(0), Parts) &&
6183            isBSwapHWordElement(N.getOperand(1), Parts);
6184 
6185   if (N.getOpcode() == ISD::SRL && N.getOperand(0).getOpcode() == ISD::BSWAP) {
6186     ConstantSDNode *C = isConstOrConstSplat(N.getOperand(1));
6187     if (!C || C->getAPIntValue() != 16)
6188       return false;
6189     Parts[0] = Parts[1] = N.getOperand(0).getOperand(0).getNode();
6190     return true;
6191   }
6192 
6193   return false;
6194 }
6195 
6196 // Match this pattern:
//   (or (and (shl A, 8), 0xff00ff00), (and (srl A, 8), 0x00ff00ff))
6198 // And rewrite this to:
6199 //   (rotr (bswap A), 16)
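// The and/shift pattern swaps the two bytes within each halfword of A; that
// is exactly (bswap A) rotated by 16, since the bswap reverses all four bytes
// and the rotate then restores the order of the two halfwords.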
static SDValue matchBSwapHWordOrAndAnd(const TargetLowering &TLI,
                                       SelectionDAG &DAG, SDNode *N, SDValue N0,
                                       SDValue N1, EVT VT, EVT ShiftAmountTy) {
6203   assert(N->getOpcode() == ISD::OR && VT == MVT::i32 &&
         "matchBSwapHWordOrAndAnd: expecting i32");
6205   if (!TLI.isOperationLegalOrCustom(ISD::ROTR, VT))
6206     return SDValue();
6207   if (N0.getOpcode() != ISD::AND || N1.getOpcode() != ISD::AND)
6208     return SDValue();
6209   // TODO: this is too restrictive; lifting this restriction requires more tests
6210   if (!N0->hasOneUse() || !N1->hasOneUse())
6211     return SDValue();
6212   ConstantSDNode *Mask0 = isConstOrConstSplat(N0.getOperand(1));
6213   ConstantSDNode *Mask1 = isConstOrConstSplat(N1.getOperand(1));
6214   if (!Mask0 || !Mask1)
6215     return SDValue();
6216   if (Mask0->getAPIntValue() != 0xff00ff00 ||
6217       Mask1->getAPIntValue() != 0x00ff00ff)
6218     return SDValue();
6219   SDValue Shift0 = N0.getOperand(0);
6220   SDValue Shift1 = N1.getOperand(0);
6221   if (Shift0.getOpcode() != ISD::SHL || Shift1.getOpcode() != ISD::SRL)
6222     return SDValue();
6223   ConstantSDNode *ShiftAmt0 = isConstOrConstSplat(Shift0.getOperand(1));
6224   ConstantSDNode *ShiftAmt1 = isConstOrConstSplat(Shift1.getOperand(1));
6225   if (!ShiftAmt0 || !ShiftAmt1)
6226     return SDValue();
6227   if (ShiftAmt0->getAPIntValue() != 8 || ShiftAmt1->getAPIntValue() != 8)
6228     return SDValue();
6229   if (Shift0.getOperand(0) != Shift1.getOperand(0))
6230     return SDValue();
6231 
6232   SDLoc DL(N);
6233   SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT, Shift0.getOperand(0));
6234   SDValue ShAmt = DAG.getConstant(16, DL, ShiftAmountTy);
6235   return DAG.getNode(ISD::ROTR, DL, VT, BSwap, ShAmt);
6236 }
6237 
6238 /// Match a 32-bit packed halfword bswap. That is
6239 /// ((x & 0x000000ff) << 8) |
6240 /// ((x & 0x0000ff00) >> 8) |
6241 /// ((x & 0x00ff0000) << 8) |
6242 /// ((x & 0xff000000) >> 8)
6243 /// => (rotl (bswap x), 16)
SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
6245   if (!LegalOperations)
6246     return SDValue();
6247 
6248   EVT VT = N->getValueType(0);
6249   if (VT != MVT::i32)
6250     return SDValue();
6251   if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
6252     return SDValue();
6253 
6254   if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N0, N1, VT,
6255                                               getShiftAmountTy(VT)))
    return BSwap;
6257 
6258   // Try again with commuted operands.
6259   if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N1, N0, VT,
6260                                               getShiftAmountTy(VT)))
    return BSwap;

6264   // Look for either
6265   // (or (bswaphpair), (bswaphpair))
6266   // (or (or (bswaphpair), (and)), (and))
6267   // (or (or (and), (bswaphpair)), (and))
6268   SDNode *Parts[4] = {};
6269 
6270   if (isBSwapHWordPair(N0, Parts)) {
6271     // (or (or (and), (and)), (or (and), (and)))
6272     if (!isBSwapHWordPair(N1, Parts))
6273       return SDValue();
6274   } else if (N0.getOpcode() == ISD::OR) {
6275     // (or (or (or (and), (and)), (and)), (and))
6276     if (!isBSwapHWordElement(N1, Parts))
6277       return SDValue();
6278     SDValue N00 = N0.getOperand(0);
6279     SDValue N01 = N0.getOperand(1);
6280     if (!(isBSwapHWordElement(N01, Parts) && isBSwapHWordPair(N00, Parts)) &&
6281         !(isBSwapHWordElement(N00, Parts) && isBSwapHWordPair(N01, Parts)))
6282       return SDValue();
6283   } else
6284     return SDValue();
6285 
6286   // Make sure the parts are all coming from the same node.
6287   if (Parts[0] != Parts[1] || Parts[0] != Parts[2] || Parts[0] != Parts[3])
6288     return SDValue();
6289 
6290   SDLoc DL(N);
6291   SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT,
6292                               SDValue(Parts[0], 0));
6293 
6294   // Result of the bswap should be rotated by 16. If it's not legal, then
6295   // do  (x << 16) | (x >> 16).
6296   SDValue ShAmt = DAG.getConstant(16, DL, getShiftAmountTy(VT));
6297   if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT))
6298     return DAG.getNode(ISD::ROTL, DL, VT, BSwap, ShAmt);
6299   if (TLI.isOperationLegalOrCustom(ISD::ROTR, VT))
6300     return DAG.getNode(ISD::ROTR, DL, VT, BSwap, ShAmt);
6301   return DAG.getNode(ISD::OR, DL, VT,
6302                      DAG.getNode(ISD::SHL, DL, VT, BSwap, ShAmt),
6303                      DAG.getNode(ISD::SRL, DL, VT, BSwap, ShAmt));
6304 }
6305 
6306 /// This contains all DAGCombine rules which reduce two values combined by
/// an Or operation to a single value. \see visitANDLike().
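/// e.g. (or (and X, 0xFF00), (and Y, 0x00FF)) --> (and (or X, Y), 0xFFFF)
/// when X's bits inside 0x00FF and Y's bits inside 0xFF00 are known zero.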
SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, SDNode *N) {
6309   EVT VT = N1.getValueType();
6310   SDLoc DL(N);
6311 
6312   // fold (or x, undef) -> -1
6313   if (!LegalOperations && (N0.isUndef() || N1.isUndef()))
6314     return DAG.getAllOnesConstant(DL, VT);
6315 
6316   if (SDValue V = foldLogicOfSetCCs(false, N0, N1, DL))
6317     return V;
6318 
6319   // (or (and X, C1), (and Y, C2))  -> (and (or X, Y), C3) if possible.
6320   if (N0.getOpcode() == ISD::AND && N1.getOpcode() == ISD::AND &&
6321       // Don't increase # computations.
6322       (N0.getNode()->hasOneUse() || N1.getNode()->hasOneUse())) {
6325     if (const ConstantSDNode *N0O1C =
6326         getAsNonOpaqueConstant(N0.getOperand(1))) {
6327       if (const ConstantSDNode *N1O1C =
6328           getAsNonOpaqueConstant(N1.getOperand(1))) {
6329         // We can only do this xform if we know that bits from X that are set in
6330         // C2 but not in C1 are already zero.  Likewise for Y.
6331         const APInt &LHSMask = N0O1C->getAPIntValue();
6332         const APInt &RHSMask = N1O1C->getAPIntValue();
6333 
6334         if (DAG.MaskedValueIsZero(N0.getOperand(0), RHSMask&~LHSMask) &&
6335             DAG.MaskedValueIsZero(N1.getOperand(0), LHSMask&~RHSMask)) {
6336           SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
6337                                   N0.getOperand(0), N1.getOperand(0));
6338           return DAG.getNode(ISD::AND, DL, VT, X,
6339                              DAG.getConstant(LHSMask | RHSMask, DL, VT));
6340         }
6341       }
6342     }
6343   }
6344 
6345   // (or (and X, M), (and X, N)) -> (and X, (or M, N))
6346   if (N0.getOpcode() == ISD::AND &&
6347       N1.getOpcode() == ISD::AND &&
6348       N0.getOperand(0) == N1.getOperand(0) &&
6349       // Don't increase # computations.
6350       (N0.getNode()->hasOneUse() || N1.getNode()->hasOneUse())) {
6351     SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
6352                             N0.getOperand(1), N1.getOperand(1));
6353     return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), X);
6354   }
6355 
6356   return SDValue();
6357 }
6358 
6359 /// OR combines for which the commuted variant will be tried as well.
static SDValue visitORCommutative(
    SelectionDAG &DAG, SDValue N0, SDValue N1, SDNode *N) {
6362   EVT VT = N0.getValueType();
6363   if (N0.getOpcode() == ISD::AND) {
6364     // fold (or (and X, (xor Y, -1)), Y) -> (or X, Y)
6365     if (isBitwiseNot(N0.getOperand(1)) && N0.getOperand(1).getOperand(0) == N1)
6366       return DAG.getNode(ISD::OR, SDLoc(N), VT, N0.getOperand(0), N1);
6367 
6368     // fold (or (and (xor Y, -1), X), Y) -> (or X, Y)
6369     if (isBitwiseNot(N0.getOperand(0)) && N0.getOperand(0).getOperand(0) == N1)
6370       return DAG.getNode(ISD::OR, SDLoc(N), VT, N0.getOperand(1), N1);
6371   }
6372 
6373   return SDValue();
6374 }
6375 
SDValue DAGCombiner::visitOR(SDNode *N) {
6377   SDValue N0 = N->getOperand(0);
6378   SDValue N1 = N->getOperand(1);
6379   EVT VT = N1.getValueType();
6380 
6381   // x | x --> x
6382   if (N0 == N1)
6383     return N0;
6384 
6385   // fold vector ops
6386   if (VT.isVector()) {
6387     if (SDValue FoldedVOp = SimplifyVBinOp(N))
6388       return FoldedVOp;
6389 
6390     // fold (or x, 0) -> x, vector edition
6391     if (ISD::isConstantSplatVectorAllZeros(N0.getNode()))
6392       return N1;
6393     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
6394       return N0;
6395 
6396     // fold (or x, -1) -> -1, vector edition
6397     if (ISD::isConstantSplatVectorAllOnes(N0.getNode()))
      // do not return N0, because an undef node may exist in N0
6399       return DAG.getAllOnesConstant(SDLoc(N), N0.getValueType());
6400     if (ISD::isConstantSplatVectorAllOnes(N1.getNode()))
      // do not return N1, because an undef node may exist in N1
6402       return DAG.getAllOnesConstant(SDLoc(N), N1.getValueType());
6403 
6404     // fold (or (shuf A, V_0, MA), (shuf B, V_0, MB)) -> (shuf A, B, Mask)
6405     // Do this only if the resulting shuffle is legal.
6406     if (isa<ShuffleVectorSDNode>(N0) &&
6407         isa<ShuffleVectorSDNode>(N1) &&
6408         // Avoid folding a node with illegal type.
6409         TLI.isTypeLegal(VT)) {
6410       bool ZeroN00 = ISD::isBuildVectorAllZeros(N0.getOperand(0).getNode());
6411       bool ZeroN01 = ISD::isBuildVectorAllZeros(N0.getOperand(1).getNode());
6412       bool ZeroN10 = ISD::isBuildVectorAllZeros(N1.getOperand(0).getNode());
6413       bool ZeroN11 = ISD::isBuildVectorAllZeros(N1.getOperand(1).getNode());
6414       // Ensure both shuffles have a zero input.
6415       if ((ZeroN00 != ZeroN01) && (ZeroN10 != ZeroN11)) {
6416         assert((!ZeroN00 || !ZeroN01) && "Both inputs zero!");
6417         assert((!ZeroN10 || !ZeroN11) && "Both inputs zero!");
6418         const ShuffleVectorSDNode *SV0 = cast<ShuffleVectorSDNode>(N0);
6419         const ShuffleVectorSDNode *SV1 = cast<ShuffleVectorSDNode>(N1);
6420         bool CanFold = true;
6421         int NumElts = VT.getVectorNumElements();
6422         SmallVector<int, 4> Mask(NumElts);
6423 
6424         for (int i = 0; i != NumElts; ++i) {
6425           int M0 = SV0->getMaskElt(i);
6426           int M1 = SV1->getMaskElt(i);
6427 
6428           // Determine if either index is pointing to a zero vector.
6429           bool M0Zero = M0 < 0 || (ZeroN00 == (M0 < NumElts));
6430           bool M1Zero = M1 < 0 || (ZeroN10 == (M1 < NumElts));
6431 
          // If one element is zero and the other side is undef, keep undef.
6433           // This also handles the case that both are undef.
6434           if ((M0Zero && M1 < 0) || (M1Zero && M0 < 0)) {
6435             Mask[i] = -1;
6436             continue;
6437           }
6438 
6439           // Make sure only one of the elements is zero.
6440           if (M0Zero == M1Zero) {
6441             CanFold = false;
6442             break;
6443           }
6444 
6445           assert((M0 >= 0 || M1 >= 0) && "Undef index!");
6446 
6447           // We have a zero and non-zero element. If the non-zero came from
6448           // SV0 make the index a LHS index. If it came from SV1, make it
6449           // a RHS index. We need to mod by NumElts because we don't care
6450           // which operand it came from in the original shuffles.
6451           Mask[i] = M1Zero ? M0 % NumElts : (M1 % NumElts) + NumElts;
6452         }
6453 
6454         if (CanFold) {
6455           SDValue NewLHS = ZeroN00 ? N0.getOperand(1) : N0.getOperand(0);
6456           SDValue NewRHS = ZeroN10 ? N1.getOperand(1) : N1.getOperand(0);
6457 
6458           SDValue LegalShuffle =
6459               TLI.buildLegalVectorShuffle(VT, SDLoc(N), NewLHS, NewRHS,
6460                                           Mask, DAG);
6461           if (LegalShuffle)
6462             return LegalShuffle;
6463         }
6464       }
6465     }
6466   }
6467 
6468   // fold (or c1, c2) -> c1|c2
6469   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
6470   if (SDValue C = DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N), VT, {N0, N1}))
6471     return C;
6472 
6473   // canonicalize constant to RHS
6474   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
6475      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
6476     return DAG.getNode(ISD::OR, SDLoc(N), VT, N1, N0);
6477 
6478   // fold (or x, 0) -> x
6479   if (isNullConstant(N1))
6480     return N0;
6481 
6482   // fold (or x, -1) -> -1
6483   if (isAllOnesConstant(N1))
6484     return N1;
6485 
6486   if (SDValue NewSel = foldBinOpIntoSelect(N))
6487     return NewSel;
6488 
6489   // fold (or x, c) -> c iff (x & ~c) == 0
6490   if (N1C && DAG.MaskedValueIsZero(N0, ~N1C->getAPIntValue()))
6491     return N1;
6492 
6493   if (SDValue Combined = visitORLike(N0, N1, N))
6494     return Combined;
6495 
6496   if (SDValue Combined = combineCarryDiamond(*this, DAG, TLI, N0, N1, N))
6497     return Combined;
6498 
6499   // Recognize halfword bswaps as (bswap + rotl 16) or (bswap + shl 16)
6500   if (SDValue BSwap = MatchBSwapHWord(N, N0, N1))
6501     return BSwap;
6502   if (SDValue BSwap = MatchBSwapHWordLow(N, N0, N1))
6503     return BSwap;
6504 
6505   // reassociate or
6506   if (SDValue ROR = reassociateOps(ISD::OR, SDLoc(N), N0, N1, N->getFlags()))
6507     return ROR;
6508 
6509   // Canonicalize (or (and X, c1), c2) -> (and (or X, c2), c1|c2)
6510   // iff (c1 & c2) != 0 or c1/c2 are undef.
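       // e.g. (or (and X, 0x0F0F), 0x00FF) -> (and (or X, 0x00FF), 0x0FFF).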
6511   auto MatchIntersect = [](ConstantSDNode *C1, ConstantSDNode *C2) {
6512     return !C1 || !C2 || C1->getAPIntValue().intersects(C2->getAPIntValue());
6513   };
6514   if (N0.getOpcode() == ISD::AND && N0.getNode()->hasOneUse() &&
6515       ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchIntersect, true)) {
6516     if (SDValue COR = DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N1), VT,
6517                                                  {N1, N0.getOperand(1)})) {
6518       SDValue IOR = DAG.getNode(ISD::OR, SDLoc(N0), VT, N0.getOperand(0), N1);
6519       AddToWorklist(IOR.getNode());
6520       return DAG.getNode(ISD::AND, SDLoc(N), VT, COR, IOR);
6521     }
6522   }
6523 
6524   if (SDValue Combined = visitORCommutative(DAG, N0, N1, N))
6525     return Combined;
6526   if (SDValue Combined = visitORCommutative(DAG, N1, N0, N))
6527     return Combined;
6528 
6529   // Simplify: (or (op x...), (op y...))  -> (op (or x, y))
6530   if (N0.getOpcode() == N1.getOpcode())
6531     if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
6532       return V;
6533 
6534   // See if this is some rotate idiom.
6535   if (SDValue Rot = MatchRotate(N0, N1, SDLoc(N)))
6536     return Rot;
6537 
6538   if (SDValue Load = MatchLoadCombine(N))
6539     return Load;
6540 
6541   // Simplify the operands using demanded-bits information.
6542   if (SimplifyDemandedBits(SDValue(N, 0)))
6543     return SDValue(N, 0);
6544 
6545   // If OR can be rewritten into ADD, try combines based on ADD.
6546   if ((!LegalOperations || TLI.isOperationLegal(ISD::ADD, VT)) &&
6547       DAG.haveNoCommonBitsSet(N0, N1))
6548     if (SDValue Combined = visitADDLike(N))
6549       return Combined;
6550 
6551   return SDValue();
6552 }
6553 
6554 static SDValue stripConstantMask(SelectionDAG &DAG, SDValue Op, SDValue &Mask) {
6555   if (Op.getOpcode() == ISD::AND &&
6556       DAG.isConstantIntBuildVectorOrConstantInt(Op.getOperand(1))) {
6557     Mask = Op.getOperand(1);
6558     return Op.getOperand(0);
6559   }
6560   return Op;
6561 }
6562 
6563 /// Match "(X shl/srl V1) & V2" where V2 may not be present.
6564 static bool matchRotateHalf(SelectionDAG &DAG, SDValue Op, SDValue &Shift,
6565                             SDValue &Mask) {
6566   Op = stripConstantMask(DAG, Op, Mask);
6567   if (Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SHL) {
6568     Shift = Op;
6569     return true;
6570   }
6571   return false;
6572 }
6573 
6574 /// Helper function for visitOR to extract the needed side of a rotate idiom
6575 /// from a shl/srl/mul/udiv.  This is meant to handle cases where
6576 /// InstCombine merged some outside op with one of the shifts from
6577 /// the rotate pattern.
6578 /// \returns An empty \c SDValue if the needed shift couldn't be extracted.
6579 /// Otherwise, returns an expansion of \p ExtractFrom based on the following
6580 /// patterns:
6581 ///
6582 ///   (or (add v v) (shrl v bitwidth-1)):
6583 ///     expands (add v v) -> (shl v 1)
6584 ///
6585 ///   (or (mul v c0) (shrl (mul v c1) c2)):
6586 ///     expands (mul v c0) -> (shl (mul v c1) c3)
6587 ///
6588 ///   (or (udiv v c0) (shl (udiv v c1) c2)):
6589 ///     expands (udiv v c0) -> (shrl (udiv v c1) c3)
6590 ///
6591 ///   (or (shl v c0) (shrl (shl v c1) c2)):
6592 ///     expands (shl v c0) -> (shl (shl v c1) c3)
6593 ///
6594 ///   (or (shrl v c0) (shl (shrl v c1) c2)):
6595 ///     expands (shrl v c0) -> (shrl (shrl v c1) c3)
6596 ///
6597 /// Such that in all cases, c3+c2==bitwidth(op v c1).
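     /// For example (illustrative, i32): in (or (mul v 8) (shrl (mul v 2) 30)),
     /// (mul v 8) expands to (shl (mul v 2) 2), since 8 == 2 << 2 and
     /// 2 + 30 == 32.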
6598 static SDValue extractShiftForRotate(SelectionDAG &DAG, SDValue OppShift,
6599                                      SDValue ExtractFrom, SDValue &Mask,
6600                                      const SDLoc &DL) {
6601   assert(OppShift && ExtractFrom && "Empty SDValue");
6602   assert(
6603       (OppShift.getOpcode() == ISD::SHL || OppShift.getOpcode() == ISD::SRL) &&
6604       "Existing shift must be valid as a rotate half");
6605 
6606   ExtractFrom = stripConstantMask(DAG, ExtractFrom, Mask);
6607 
6608   // Value and Type of the shift.
6609   SDValue OppShiftLHS = OppShift.getOperand(0);
6610   EVT ShiftedVT = OppShiftLHS.getValueType();
6611 
6612   // Amount of the existing shift.
6613   ConstantSDNode *OppShiftCst = isConstOrConstSplat(OppShift.getOperand(1));
6614 
6615   // (add v v) -> (shl v 1)
6616   // TODO: Should this be a general DAG canonicalization?
6617   if (OppShift.getOpcode() == ISD::SRL && OppShiftCst &&
6618       ExtractFrom.getOpcode() == ISD::ADD &&
6619       ExtractFrom.getOperand(0) == ExtractFrom.getOperand(1) &&
6620       ExtractFrom.getOperand(0) == OppShiftLHS &&
6621       OppShiftCst->getAPIntValue() == ShiftedVT.getScalarSizeInBits() - 1)
6622     return DAG.getNode(ISD::SHL, DL, ShiftedVT, OppShiftLHS,
6623                        DAG.getShiftAmountConstant(1, ShiftedVT, DL));
6624 
6625   // Preconditions:
6626   //    (or (op0 v c0) (shiftl/r (op0 v c1) c2))
6627   //
6628   // Find opcode of the needed shift to be extracted from (op0 v c0).
6629   unsigned Opcode = ISD::DELETED_NODE;
6630   bool IsMulOrDiv = false;
6631   // Set Opcode and IsMulOrDiv if the extract opcode matches the needed shift
6632   // opcode or its arithmetic (mul or udiv) variant.
6633   auto SelectOpcode = [&](unsigned NeededShift, unsigned MulOrDivVariant) {
6634     IsMulOrDiv = ExtractFrom.getOpcode() == MulOrDivVariant;
6635     if (!IsMulOrDiv && ExtractFrom.getOpcode() != NeededShift)
6636       return false;
6637     Opcode = NeededShift;
6638     return true;
6639   };
6640   // op0 must be either the needed shift opcode or the mul/udiv equivalent
6641   // that the needed shift can be extracted from.
6642   if ((OppShift.getOpcode() != ISD::SRL || !SelectOpcode(ISD::SHL, ISD::MUL)) &&
6643       (OppShift.getOpcode() != ISD::SHL || !SelectOpcode(ISD::SRL, ISD::UDIV)))
6644     return SDValue();
6645 
6646   // op0 must be the same opcode on both sides, have the same LHS argument,
6647   // and produce the same value type.
6648   if (OppShiftLHS.getOpcode() != ExtractFrom.getOpcode() ||
6649       OppShiftLHS.getOperand(0) != ExtractFrom.getOperand(0) ||
6650       ShiftedVT != ExtractFrom.getValueType())
6651     return SDValue();
6652 
6653   // Constant mul/udiv/shift amount from the RHS of the shift's LHS op.
6654   ConstantSDNode *OppLHSCst = isConstOrConstSplat(OppShiftLHS.getOperand(1));
6655   // Constant mul/udiv/shift amount from the RHS of the ExtractFrom op.
6656   ConstantSDNode *ExtractFromCst =
6657       isConstOrConstSplat(ExtractFrom.getOperand(1));
6658   // TODO: We should be able to handle non-uniform constant vectors for these values
6659   // Check that we have constant values.
6660   if (!OppShiftCst || !OppShiftCst->getAPIntValue() ||
6661       !OppLHSCst || !OppLHSCst->getAPIntValue() ||
6662       !ExtractFromCst || !ExtractFromCst->getAPIntValue())
6663     return SDValue();
6664 
6665   // Compute the shift amount we need to extract to complete the rotate.
6666   const unsigned VTWidth = ShiftedVT.getScalarSizeInBits();
6667   if (OppShiftCst->getAPIntValue().ugt(VTWidth))
6668     return SDValue();
6669   APInt NeededShiftAmt = VTWidth - OppShiftCst->getAPIntValue();
6670   // Normalize the bitwidth of the two mul/udiv/shift constant operands.
6671   APInt ExtractFromAmt = ExtractFromCst->getAPIntValue();
6672   APInt OppLHSAmt = OppLHSCst->getAPIntValue();
6673   zeroExtendToMatch(ExtractFromAmt, OppLHSAmt);
6674 
6675   // Now try extract the needed shift from the ExtractFrom op and see if the
6676   // result matches up with the existing shift's LHS op.
6677   if (IsMulOrDiv) {
6678     // Op to extract from is a mul or udiv by a constant.
6679     // Check:
6680     //     c2 / (1 << (bitwidth(op0 v c0) - c1)) == c0
6681     //     c2 % (1 << (bitwidth(op0 v c0) - c1)) == 0
6682     const APInt ExtractDiv = APInt::getOneBitSet(ExtractFromAmt.getBitWidth(),
6683                                                  NeededShiftAmt.getZExtValue());
6684     APInt ResultAmt;
6685     APInt Rem;
6686     APInt::udivrem(ExtractFromAmt, ExtractDiv, ResultAmt, Rem);
6687     if (Rem != 0 || ResultAmt != OppLHSAmt)
6688       return SDValue();
6689   } else {
6690     // Op to extract from is a shift by a constant.
6691     // Check:
6692     //      c2 - (bitwidth(op0 v c0) - c1) == c0
6693     if (OppLHSAmt != ExtractFromAmt - NeededShiftAmt.zextOrTrunc(
6694                                           ExtractFromAmt.getBitWidth()))
6695       return SDValue();
6696   }
6697 
6698   // Return the expanded shift op that should allow a rotate to be formed.
6699   EVT ShiftVT = OppShift.getOperand(1).getValueType();
6700   EVT ResVT = ExtractFrom.getValueType();
6701   SDValue NewShiftNode = DAG.getConstant(NeededShiftAmt, DL, ShiftVT);
6702   return DAG.getNode(Opcode, DL, ResVT, OppShiftLHS, NewShiftNode);
6703 }
6704 
6705 // Return true if we can prove that, whenever Neg and Pos are both in the
6706 // range [0, EltSize), Neg == (Pos == 0 ? 0 : EltSize - Pos).  This means that
6707 // for two opposing shifts shift1 and shift2 and a value X with OpBits bits:
6708 //
6709 //     (or (shift1 X, Neg), (shift2 X, Pos))
6710 //
6711 // reduces to a rotate in direction shift2 by Pos or (equivalently) a rotate
6712 // in direction shift1 by Neg.  The range [0, EltSize) means that we only need
6713 // to consider shift amounts with defined behavior.
6714 //
6715 // The IsRotate flag should be set when the LHS of both shifts is the same.
6716 // Otherwise if matching a general funnel shift, it should be clear.
6717 static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
6718                            SelectionDAG &DAG, bool IsRotate) {
6719   // If EltSize is a power of 2 then:
6720   //
6721   //  (a) (Pos == 0 ? 0 : EltSize - Pos) == (EltSize - Pos) & (EltSize - 1)
6722   //  (b) Neg == Neg & (EltSize - 1) whenever Neg is in [0, EltSize).
6723   //
6724   // So if EltSize is a power of 2 and Neg is (and Neg', EltSize-1), we check
6725   // for the stronger condition:
6726   //
6727   //     Neg & (EltSize - 1) == (EltSize - Pos) & (EltSize - 1)    [A]
6728   //
6729   // for all Neg and Pos.  Since Neg & (EltSize - 1) == Neg' & (EltSize - 1)
6730   // we can just replace Neg with Neg' for the rest of the function.
6731   //
6732   // In other cases we check for the even stronger condition:
6733   //
6734   //     Neg == EltSize - Pos                                    [B]
6735   //
6736   // for all Neg and Pos.  Note that the (or ...) then invokes undefined
6737   // behavior if Pos == 0 (and consequently Neg == EltSize).
6738   //
6739   // We could actually use [A] whenever EltSize is a power of 2, but the
6740   // only extra cases that it would match are those uninteresting ones
6741   // where Neg and Pos are never in range at the same time.  E.g. for
6742   // EltSize == 32, using [A] would allow a Neg of the form (sub 64, Pos)
6743   // as well as (sub 32, Pos), but:
6744   //
6745   //     (or (shift1 X, (sub 64, Pos)), (shift2 X, Pos))
6746   //
6747   // always invokes undefined behavior for 32-bit X.
6748   //
6749   // Below, Mask == EltSize - 1 when using [A] and is all-ones otherwise.
6750   //
6751   // NOTE: We can only do this when matching an AND and not a general
6752   // funnel shift.
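       // For example, for EltSize == 32: Neg == (and (sub 32, Pos), 31) is
       // stripped to (sub 32, Pos), so Width == 32 below and Width & 31 == 0.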
6753   unsigned MaskLoBits = 0;
6754   if (IsRotate && Neg.getOpcode() == ISD::AND && isPowerOf2_64(EltSize)) {
6755     if (ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(1))) {
6756       KnownBits Known = DAG.computeKnownBits(Neg.getOperand(0));
6757       unsigned Bits = Log2_64(EltSize);
6758       if (NegC->getAPIntValue().getActiveBits() <= Bits &&
6759           ((NegC->getAPIntValue() | Known.Zero).countTrailingOnes() >= Bits)) {
6760         Neg = Neg.getOperand(0);
6761         MaskLoBits = Bits;
6762       }
6763     }
6764   }
6765 
6766   // Check whether Neg has the form (sub NegC, NegOp1) for some NegC and NegOp1.
6767   if (Neg.getOpcode() != ISD::SUB)
6768     return false;
6769   ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(0));
6770   if (!NegC)
6771     return false;
6772   SDValue NegOp1 = Neg.getOperand(1);
6773 
6774   // On the RHS of [A], if Pos is Pos' & (EltSize - 1), just replace Pos with
6775   // Pos'.  The truncation is redundant for the purpose of the equality.
6776   if (MaskLoBits && Pos.getOpcode() == ISD::AND) {
6777     if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1))) {
6778       KnownBits Known = DAG.computeKnownBits(Pos.getOperand(0));
6779       if (PosC->getAPIntValue().getActiveBits() <= MaskLoBits &&
6780           ((PosC->getAPIntValue() | Known.Zero).countTrailingOnes() >=
6781            MaskLoBits))
6782         Pos = Pos.getOperand(0);
6783     }
6784   }
6785 
6786   // The condition we need is now:
6787   //
6788   //     (NegC - NegOp1) & Mask == (EltSize - Pos) & Mask
6789   //
6790   // If NegOp1 == Pos then we need:
6791   //
6792   //              EltSize & Mask == NegC & Mask
6793   //
6794   // (because "x & Mask" is a truncation and distributes through subtraction).
6795   //
6796   // We also need to account for a potential truncation of NegOp1 if the amount
6797   // has already been legalized to a shift amount type.
6798   APInt Width;
6799   if ((Pos == NegOp1) ||
6800       (NegOp1.getOpcode() == ISD::TRUNCATE && Pos == NegOp1.getOperand(0)))
6801     Width = NegC->getAPIntValue();
6802 
6803   // Check for cases where Pos has the form (add NegOp1, PosC) for some PosC.
6804   // Then the condition we want to prove becomes:
6805   //
6806   //     (NegC - NegOp1) & Mask == (EltSize - (NegOp1 + PosC)) & Mask
6807   //
6808   // which, again because "x & Mask" is a truncation, becomes:
6809   //
6810   //                NegC & Mask == (EltSize - PosC) & Mask
6811   //             EltSize & Mask == (NegC + PosC) & Mask
6812   else if (Pos.getOpcode() == ISD::ADD && Pos.getOperand(0) == NegOp1) {
6813     if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1)))
6814       Width = PosC->getAPIntValue() + NegC->getAPIntValue();
6815     else
6816       return false;
6817   } else
6818     return false;
6819 
6820   // Now we just need to check that EltSize & Mask == Width & Mask.
6821   if (MaskLoBits)
6822     // EltSize & Mask is 0 since Mask is EltSize - 1.
6823     return Width.getLoBits(MaskLoBits) == 0;
6824   return Width == EltSize;
6825 }
6826 
6827 // A subroutine of MatchRotate used once we have found an OR of two opposite
6828 // shifts of Shifted.  If Neg == <operand size> - Pos then the OR reduces
6829 // to both (PosOpcode Shifted, Pos) and (NegOpcode Shifted, Neg), with the
6830 // former being preferred if supported.  InnerPos and InnerNeg are Pos and
6831 // Neg with outer conversions stripped away.
6832 SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
6833                                        SDValue Neg, SDValue InnerPos,
6834                                        SDValue InnerNeg, unsigned PosOpcode,
6835                                        unsigned NegOpcode, const SDLoc &DL) {
6836   // fold (or (shl x, (*ext y)),
6837   //          (srl x, (*ext (sub 32, y)))) ->
6838   //   (rotl x, y) or (rotr x, (sub 32, y))
6839   //
6840   // fold (or (shl x, (*ext (sub 32, y))),
6841   //          (srl x, (*ext y))) ->
6842   //   (rotr x, y) or (rotl x, (sub 32, y))
6843   EVT VT = Shifted.getValueType();
6844   if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits(), DAG,
6845                      /*IsRotate*/ true)) {
6846     bool HasPos = TLI.isOperationLegalOrCustom(PosOpcode, VT);
6847     return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, Shifted,
6848                        HasPos ? Pos : Neg);
6849   }
6850 
6851   return SDValue();
6852 }
6853 
6854 // A subroutine of MatchRotate used once we have found an OR of two opposite
6855 // shifts of N0 + N1.  If Neg == <operand size> - Pos then the OR reduces
6856 // to both (PosOpcode N0, N1, Pos) and (NegOpcode N0, N1, Neg), with the
6857 // former being preferred if supported.  InnerPos and InnerNeg are Pos and
6858 // Neg with outer conversions stripped away.
6859 // TODO: Merge with MatchRotatePosNeg.
6860 SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
6861                                        SDValue Neg, SDValue InnerPos,
6862                                        SDValue InnerNeg, unsigned PosOpcode,
6863                                        unsigned NegOpcode, const SDLoc &DL) {
6864   EVT VT = N0.getValueType();
6865   unsigned EltBits = VT.getScalarSizeInBits();
6866 
6867   // fold (or (shl x0, (*ext y)),
6868   //          (srl x1, (*ext (sub 32, y)))) ->
6869   //   (fshl x0, x1, y) or (fshr x0, x1, (sub 32, y))
6870   //
6871   // fold (or (shl x0, (*ext (sub 32, y))),
6872   //          (srl x1, (*ext y))) ->
6873   //   (fshr x0, x1, y) or (fshl x0, x1, (sub 32, y))
6874   if (matchRotateSub(InnerPos, InnerNeg, EltBits, DAG, /*IsRotate*/ N0 == N1)) {
6875     bool HasPos = TLI.isOperationLegalOrCustom(PosOpcode, VT);
6876     return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, N0, N1,
6877                        HasPos ? Pos : Neg);
6878   }
6879 
6880   // Matching the shift+xor cases, we can't easily use the xor'd shift amount
6881   // so for now just use the PosOpcode case if it's legal.
6882   // TODO: When can we use the NegOpcode case?
6883   if (PosOpcode == ISD::FSHL && isPowerOf2_32(EltBits)) {
6884     auto IsBinOpImm = [](SDValue Op, unsigned BinOpc, unsigned Imm) {
6885       if (Op.getOpcode() != BinOpc)
6886         return false;
6887       ConstantSDNode *Cst = isConstOrConstSplat(Op.getOperand(1));
6888       return Cst && (Cst->getAPIntValue() == Imm);
6889     };
6890 
6891     // fold (or (shl x0, y), (srl (srl x1, 1), (xor y, 31)))
6892     //   -> (fshl x0, x1, y)
6893     if (IsBinOpImm(N1, ISD::SRL, 1) &&
6894         IsBinOpImm(InnerNeg, ISD::XOR, EltBits - 1) &&
6895         InnerPos == InnerNeg.getOperand(0) &&
6896         TLI.isOperationLegalOrCustom(ISD::FSHL, VT)) {
6897       return DAG.getNode(ISD::FSHL, DL, VT, N0, N1.getOperand(0), Pos);
6898     }
6899 
6900     // fold (or (shl (shl x0, 1), (xor y, 31)), (srl x1, y))
6901     //   -> (fshr x0, x1, y)
6902     if (IsBinOpImm(N0, ISD::SHL, 1) &&
6903         IsBinOpImm(InnerPos, ISD::XOR, EltBits - 1) &&
6904         InnerNeg == InnerPos.getOperand(0) &&
6905         TLI.isOperationLegalOrCustom(ISD::FSHR, VT)) {
6906       return DAG.getNode(ISD::FSHR, DL, VT, N0.getOperand(0), N1, Neg);
6907     }
6908 
6909     // fold (or (shl (add x0, x0), (xor y, 31)), (srl x1, y))
6910     //   -> (fshr x0, x1, y)
6911     // TODO: Should add(x,x) -> shl(x,1) be a general DAG canonicalization?
6912     if (N0.getOpcode() == ISD::ADD && N0.getOperand(0) == N0.getOperand(1) &&
6913         IsBinOpImm(InnerPos, ISD::XOR, EltBits - 1) &&
6914         InnerNeg == InnerPos.getOperand(0) &&
6915         TLI.isOperationLegalOrCustom(ISD::FSHR, VT)) {
6916       return DAG.getNode(ISD::FSHR, DL, VT, N0.getOperand(0), N1, Neg);
6917     }
6918   }
6919 
6920   return SDValue();
6921 }
6922 
6923 // MatchRotate - Handle an 'or' of two operands.  If this is one of the many
6924 // idioms for rotate, and if the target supports rotation instructions, generate
6925 // a rot[lr]. This also matches funnel shift patterns, similar to rotation but
6926 // with different shifted sources.
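     // For example, for i32 x: (or (shl x, 3), (srl x, 29)) -> (rotl x, 3),
     // assuming ISD::ROTL is legal for the type.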
6927 SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
6928   // Must be a legal type. Expanded and promoted types won't work with rotates.
6929   EVT VT = LHS.getValueType();
6930   if (!TLI.isTypeLegal(VT))
6931     return SDValue();
6932 
6933   // The target must have at least one rotate/funnel flavor.
6934   bool HasROTL = hasOperation(ISD::ROTL, VT);
6935   bool HasROTR = hasOperation(ISD::ROTR, VT);
6936   bool HasFSHL = hasOperation(ISD::FSHL, VT);
6937   bool HasFSHR = hasOperation(ISD::FSHR, VT);
6938   if (!HasROTL && !HasROTR && !HasFSHL && !HasFSHR)
6939     return SDValue();
6940 
6941   // Check for truncated rotate.
6942   if (LHS.getOpcode() == ISD::TRUNCATE && RHS.getOpcode() == ISD::TRUNCATE &&
6943       LHS.getOperand(0).getValueType() == RHS.getOperand(0).getValueType()) {
6944     assert(LHS.getValueType() == RHS.getValueType());
6945     if (SDValue Rot = MatchRotate(LHS.getOperand(0), RHS.getOperand(0), DL)) {
6946       return DAG.getNode(ISD::TRUNCATE, SDLoc(LHS), LHS.getValueType(), Rot);
6947     }
6948   }
6949 
6950   // Match "(X shl/srl V1) & V2" where V2 may not be present.
6951   SDValue LHSShift;   // The shift.
6952   SDValue LHSMask;    // AND value if any.
6953   matchRotateHalf(DAG, LHS, LHSShift, LHSMask);
6954 
6955   SDValue RHSShift;   // The shift.
6956   SDValue RHSMask;    // AND value if any.
6957   matchRotateHalf(DAG, RHS, RHSShift, RHSMask);
6958 
6959   // If neither side matched a rotate half, bail
6960   if (!LHSShift && !RHSShift)
6961     return SDValue();
6962 
6963   // InstCombine may have combined a constant shl, srl, mul, or udiv with one
6964   // side of the rotate, so try to handle that here. In all cases we need to
6965   // pass the matched shift from the opposite side to compute the opcode and
6966   // needed shift amount to extract.  We still want to do this if both sides
6967   // matched a rotate half because one half may be a potential overshift that
6968   // can be broken down (i.e. if InstCombine merged two shl or srl ops into a
6969   // single one).
6970 
6971   // Have LHS side of the rotate, try to extract the needed shift from the RHS.
6972   if (LHSShift)
6973     if (SDValue NewRHSShift =
6974             extractShiftForRotate(DAG, LHSShift, RHS, RHSMask, DL))
6975       RHSShift = NewRHSShift;
6976   // Have RHS side of the rotate, try to extract the needed shift from the LHS.
6977   if (RHSShift)
6978     if (SDValue NewLHSShift =
6979             extractShiftForRotate(DAG, RHSShift, LHS, LHSMask, DL))
6980       LHSShift = NewLHSShift;
6981 
6982   // If a side is still missing, nothing else we can do.
6983   if (!RHSShift || !LHSShift)
6984     return SDValue();
6985 
6986   // At this point we've matched or extracted a shift op on each side.
6987 
6988   if (LHSShift.getOpcode() == RHSShift.getOpcode())
6989     return SDValue(); // Shifts must disagree.
6990 
6991   bool IsRotate = LHSShift.getOperand(0) == RHSShift.getOperand(0);
6992   if (!IsRotate && !(HasFSHL || HasFSHR))
6993     return SDValue(); // Requires funnel shift support.
6994 
6995   // Canonicalize shl to left side in a shl/srl pair.
6996   if (RHSShift.getOpcode() == ISD::SHL) {
6997     std::swap(LHS, RHS);
6998     std::swap(LHSShift, RHSShift);
6999     std::swap(LHSMask, RHSMask);
7000   }
7001 
7002   unsigned EltSizeInBits = VT.getScalarSizeInBits();
7003   SDValue LHSShiftArg = LHSShift.getOperand(0);
7004   SDValue LHSShiftAmt = LHSShift.getOperand(1);
7005   SDValue RHSShiftArg = RHSShift.getOperand(0);
7006   SDValue RHSShiftAmt = RHSShift.getOperand(1);
7007 
7008   // fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1)
7009   // fold (or (shl x, C1), (srl x, C2)) -> (rotr x, C2)
7010   // fold (or (shl x, C1), (srl y, C2)) -> (fshl x, y, C1)
7011   // fold (or (shl x, C1), (srl y, C2)) -> (fshr x, y, C2)
7012   // iff C1+C2 == EltSizeInBits
7013   auto MatchRotateSum = [EltSizeInBits](ConstantSDNode *LHS,
7014                                         ConstantSDNode *RHS) {
7015     return (LHS->getAPIntValue() + RHS->getAPIntValue()) == EltSizeInBits;
7016   };
7017   if (ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) {
7018     SDValue Res;
7019     if (IsRotate && (HasROTL || HasROTR))
7020       Res = DAG.getNode(HasROTL ? ISD::ROTL : ISD::ROTR, DL, VT, LHSShiftArg,
7021                         HasROTL ? LHSShiftAmt : RHSShiftAmt);
7022     else
7023       Res = DAG.getNode(HasFSHL ? ISD::FSHL : ISD::FSHR, DL, VT, LHSShiftArg,
7024                         RHSShiftArg, HasFSHL ? LHSShiftAmt : RHSShiftAmt);
7025 
7026     // If there is an AND of either shifted operand, apply it to the result.
7027     if (LHSMask.getNode() || RHSMask.getNode()) {
7028       SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
7029       SDValue Mask = AllOnes;
7030 
7031       if (LHSMask.getNode()) {
7032         SDValue RHSBits = DAG.getNode(ISD::SRL, DL, VT, AllOnes, RHSShiftAmt);
7033         Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
7034                            DAG.getNode(ISD::OR, DL, VT, LHSMask, RHSBits));
7035       }
7036       if (RHSMask.getNode()) {
7037         SDValue LHSBits = DAG.getNode(ISD::SHL, DL, VT, AllOnes, LHSShiftAmt);
7038         Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
7039                            DAG.getNode(ISD::OR, DL, VT, RHSMask, LHSBits));
7040       }
7041 
7042       Res = DAG.getNode(ISD::AND, DL, VT, Res, Mask);
7043     }
7044 
7045     return Res;
7046   }
7047 
7048   // If there is a mask here, and we have a variable shift, we can't be sure
7049   // that we're masking off the correct bits.
7050   if (LHSMask.getNode() || RHSMask.getNode())
7051     return SDValue();
7052 
7053   // If the shift amount is sign/zext/any-extended just peel it off.
7054   SDValue LExtOp0 = LHSShiftAmt;
7055   SDValue RExtOp0 = RHSShiftAmt;
7056   if ((LHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
7057        LHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
7058        LHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
7059        LHSShiftAmt.getOpcode() == ISD::TRUNCATE) &&
7060       (RHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
7061        RHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
7062        RHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
7063        RHSShiftAmt.getOpcode() == ISD::TRUNCATE)) {
7064     LExtOp0 = LHSShiftAmt.getOperand(0);
7065     RExtOp0 = RHSShiftAmt.getOperand(0);
7066   }
7067 
7068   if (IsRotate && (HasROTL || HasROTR)) {
7069     SDValue TryL =
7070         MatchRotatePosNeg(LHSShiftArg, LHSShiftAmt, RHSShiftAmt, LExtOp0,
7071                           RExtOp0, ISD::ROTL, ISD::ROTR, DL);
7072     if (TryL)
7073       return TryL;
7074 
7075     SDValue TryR =
7076         MatchRotatePosNeg(RHSShiftArg, RHSShiftAmt, LHSShiftAmt, RExtOp0,
7077                           LExtOp0, ISD::ROTR, ISD::ROTL, DL);
7078     if (TryR)
7079       return TryR;
7080   }
7081 
7082   SDValue TryL =
7083       MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, LHSShiftAmt, RHSShiftAmt,
7084                         LExtOp0, RExtOp0, ISD::FSHL, ISD::FSHR, DL);
7085   if (TryL)
7086     return TryL;
7087 
7088   SDValue TryR =
7089       MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, RHSShiftAmt, LHSShiftAmt,
7090                         RExtOp0, LExtOp0, ISD::FSHR, ISD::FSHL, DL);
7091   if (TryR)
7092     return TryR;
7093 
7094   return SDValue();
7095 }
7096 
7097 namespace {
7098 
7099 /// Represents known origin of an individual byte in load combine pattern. The
7100 /// value of the byte is either constant zero or comes from memory.
7101 struct ByteProvider {
7102   // For constant zero providers Load is set to nullptr. For memory providers
7103   // Load represents the node which loads the byte from memory.
7104   // ByteOffset is the offset of the byte in the value produced by the load.
7105   LoadSDNode *Load = nullptr;
7106   unsigned ByteOffset = 0;
7107 
7108   ByteProvider() = default;
7109 
7110   static ByteProvider getMemory(LoadSDNode *Load, unsigned ByteOffset) {
7111     return ByteProvider(Load, ByteOffset);
7112   }
7113 
7114   static ByteProvider getConstantZero() { return ByteProvider(nullptr, 0); }
7115 
7116   bool isConstantZero() const { return !Load; }
7117   bool isMemory() const { return Load; }
7118 
7119   bool operator==(const ByteProvider &Other) const {
7120     return Other.Load == Load && Other.ByteOffset == ByteOffset;
7121   }
7122 
7123 private:
7124   ByteProvider(LoadSDNode *Load, unsigned ByteOffset)
7125       : Load(Load), ByteOffset(ByteOffset) {}
7126 };
7127 
7128 } // end anonymous namespace
7129 
7130 /// Recursively traverses the expression calculating the origin of the requested
7131 /// byte of the given value. Returns None if the provider can't be calculated.
7132 ///
7133 /// For all the values except the root of the expression, verifies that the
7134 /// value has exactly one use; if not, returns None. This way, if the origin
7135 /// of the byte is returned it's guaranteed that the values which contribute to
7136 /// the byte are not used outside of this expression.
7137 ///
7138 /// Because the parts of the expression are not allowed to have more than one
7139 /// use this function iterates over trees, not DAGs. So it never visits the same
7140 /// node more than once.
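     /// For example, for i32 (or (zextload i8 [p]) (shl (zextload i8 [p+1]), 8)),
     /// byte 0 is provided by the load at p, byte 1 by the load at p+1, and
     /// bytes 2 and 3 are constant zero.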
7141 static const Optional<ByteProvider>
7142 calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
7143                       bool Root = false) {
7144   // A typical i64-by-i8 pattern requires recursion up to a depth of 8 calls.
7145   if (Depth == 10)
7146     return None;
7147 
7148   if (!Root && !Op.hasOneUse())
7149     return None;
7150 
7151   assert(Op.getValueType().isScalarInteger() && "can't handle other types");
7152   unsigned BitWidth = Op.getValueSizeInBits();
7153   if (BitWidth % 8 != 0)
7154     return None;
7155   unsigned ByteWidth = BitWidth / 8;
7156   assert(Index < ByteWidth && "invalid index requested");
7157   (void) ByteWidth;
7158 
7159   switch (Op.getOpcode()) {
7160   case ISD::OR: {
7161     auto LHS = calculateByteProvider(Op->getOperand(0), Index, Depth + 1);
7162     if (!LHS)
7163       return None;
7164     auto RHS = calculateByteProvider(Op->getOperand(1), Index, Depth + 1);
7165     if (!RHS)
7166       return None;
7167 
7168     if (LHS->isConstantZero())
7169       return RHS;
7170     if (RHS->isConstantZero())
7171       return LHS;
7172     return None;
7173   }
7174   case ISD::SHL: {
7175     auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
7176     if (!ShiftOp)
7177       return None;
7178 
7179     uint64_t BitShift = ShiftOp->getZExtValue();
7180     if (BitShift % 8 != 0)
7181       return None;
7182     uint64_t ByteShift = BitShift / 8;
7183 
7184     return Index < ByteShift
7185                ? ByteProvider::getConstantZero()
7186                : calculateByteProvider(Op->getOperand(0), Index - ByteShift,
7187                                        Depth + 1);
7188   }
7189   case ISD::ANY_EXTEND:
7190   case ISD::SIGN_EXTEND:
7191   case ISD::ZERO_EXTEND: {
7192     SDValue NarrowOp = Op->getOperand(0);
7193     unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
7194     if (NarrowBitWidth % 8 != 0)
7195       return None;
7196     uint64_t NarrowByteWidth = NarrowBitWidth / 8;
7197 
7198     if (Index >= NarrowByteWidth)
7199       return Op.getOpcode() == ISD::ZERO_EXTEND
7200                  ? Optional<ByteProvider>(ByteProvider::getConstantZero())
7201                  : None;
7202     return calculateByteProvider(NarrowOp, Index, Depth + 1);
7203   }
7204   case ISD::BSWAP:
7205     return calculateByteProvider(Op->getOperand(0), ByteWidth - Index - 1,
7206                                  Depth + 1);
7207   case ISD::LOAD: {
7208     auto L = cast<LoadSDNode>(Op.getNode());
7209     if (!L->isSimple() || L->isIndexed())
7210       return None;
7211 
7212     unsigned NarrowBitWidth = L->getMemoryVT().getSizeInBits();
7213     if (NarrowBitWidth % 8 != 0)
7214       return None;
7215     uint64_t NarrowByteWidth = NarrowBitWidth / 8;
7216 
7217     if (Index >= NarrowByteWidth)
7218       return L->getExtensionType() == ISD::ZEXTLOAD
7219                  ? Optional<ByteProvider>(ByteProvider::getConstantZero())
7220                  : None;
7221     return ByteProvider::getMemory(L, Index);
7222   }
7223   }
7224 
7225   return None;
7226 }
7227 
7228 static unsigned littleEndianByteAt(unsigned BW, unsigned i) {
7229   return i;
7230 }
7231 
7232 static unsigned bigEndianByteAt(unsigned BW, unsigned i) {
7233   return BW - i - 1;
7234 }
7235 
7236 // Check if the byte offsets we are looking at match with either big or
7237 // little endian value loaded. Return true for big endian, false for little
7238 // endian, and None if match failed.
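     // For example, ByteOffsets {3, 2, 1, 0} with FirstOffset 0 matches
     // bigEndianByteAt(4, i) for every i, so this returns true (big endian).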
7239 static Optional<bool> isBigEndian(const ArrayRef<int64_t> ByteOffsets,
7240                                   int64_t FirstOffset) {
7241   // Endianness can only be decided when there are at least 2 bytes.
7242   unsigned Width = ByteOffsets.size();
7243   if (Width < 2)
7244     return None;
7245 
7246   bool BigEndian = true, LittleEndian = true;
7247   for (unsigned i = 0; i < Width; i++) {
7248     int64_t CurrentByteOffset = ByteOffsets[i] - FirstOffset;
7249     LittleEndian &= CurrentByteOffset == littleEndianByteAt(Width, i);
7250     BigEndian &= CurrentByteOffset == bigEndianByteAt(Width, i);
7251     if (!BigEndian && !LittleEndian)
7252       return None;
7253   }
7254 
7255   assert((BigEndian != LittleEndian) && "It should be either big endian or "
7256                                         "little endian");
7257   return BigEndian;
7258 }
7259 
7260 static SDValue stripTruncAndExt(SDValue Value) {
7261   switch (Value.getOpcode()) {
7262   case ISD::TRUNCATE:
7263   case ISD::ZERO_EXTEND:
7264   case ISD::SIGN_EXTEND:
7265   case ISD::ANY_EXTEND:
7266     return stripTruncAndExt(Value.getOperand(0));
7267   }
7268   return Value;
7269 }
7270 
7271 /// Match a pattern where a wide type scalar value is stored by several narrow
7272 /// stores. Fold it into a single store or a BSWAP and a store if the target
7273 /// supports it.
7274 ///
7275 /// Assuming little endian target:
7276 ///  i8 *p = ...
7277 ///  i32 val = ...
7278 ///  p[0] = (val >> 0) & 0xFF;
7279 ///  p[1] = (val >> 8) & 0xFF;
7280 ///  p[2] = (val >> 16) & 0xFF;
7281 ///  p[3] = (val >> 24) & 0xFF;
7282 /// =>
7283 ///  *((i32)p) = val;
7284 ///
7285 ///  i8 *p = ...
7286 ///  i32 val = ...
7287 ///  p[0] = (val >> 24) & 0xFF;
7288 ///  p[1] = (val >> 16) & 0xFF;
7289 ///  p[2] = (val >> 8) & 0xFF;
7290 ///  p[3] = (val >> 0) & 0xFF;
7291 /// =>
7292 ///  *((i32)p) = BSWAP(val);
7293 SDValue DAGCombiner::mergeTruncStores(StoreSDNode *N) {
7294   // The matching looks for "store (trunc x)" patterns that appear early but are
7295   // likely to be replaced by truncating store nodes during combining.
7296   // TODO: If there is evidence that running this later would help, this
7297   //       limitation could be removed. Legality checks may need to be added
7298   //       for the created store and optional bswap/rotate.
7299   if (LegalOperations)
7300     return SDValue();
7301 
7302   // We only handle merging simple stores of 1-4 bytes.
7303   // TODO: Allow unordered atomics when wider type is legal (see D66309)
7304   EVT MemVT = N->getMemoryVT();
7305   if (!(MemVT == MVT::i8 || MemVT == MVT::i16 || MemVT == MVT::i32) ||
7306       !N->isSimple() || N->isIndexed())
7307     return SDValue();
7308 
7309   // Collect all of the stores in the chain.
7310   SDValue Chain = N->getChain();
7311   SmallVector<StoreSDNode *, 8> Stores = {N};
7312   while (auto *Store = dyn_cast<StoreSDNode>(Chain)) {
7313     // All stores must be the same size to ensure that we are writing all of the
7314     // bytes in the wide value.
7315     // TODO: We could allow multiple sizes by tracking each stored byte.
7316     if (Store->getMemoryVT() != MemVT || !Store->isSimple() ||
7317         Store->isIndexed())
7318       return SDValue();
7319     Stores.push_back(Store);
7320     Chain = Store->getChain();
7321   }
7322   // There is no reason to continue if we do not have at least a pair of stores.
7323   if (Stores.size() < 2)
7324     return SDValue();
7325 
7326   // Handle simple types only.
7327   LLVMContext &Context = *DAG.getContext();
7328   unsigned NumStores = Stores.size();
7329   unsigned NarrowNumBits = N->getMemoryVT().getScalarSizeInBits();
7330   unsigned WideNumBits = NumStores * NarrowNumBits;
7331   EVT WideVT = EVT::getIntegerVT(Context, WideNumBits);
7332   if (WideVT != MVT::i16 && WideVT != MVT::i32 && WideVT != MVT::i64)
7333     return SDValue();
7334 
7335   // Check if all bytes of the source value that we are looking at are stored
7336   // to the same base address. Collect offsets from Base address into OffsetMap.
7337   SDValue SourceValue;
7338   SmallVector<int64_t, 8> OffsetMap(NumStores, INT64_MAX);
7339   int64_t FirstOffset = INT64_MAX;
7340   StoreSDNode *FirstStore = nullptr;
7341   Optional<BaseIndexOffset> Base;
7342   for (auto Store : Stores) {
7343     // All the stores store different parts of the combined value. A truncate is
7344     // required to get the partial value.
7345     SDValue Trunc = Store->getValue();
7346     if (Trunc.getOpcode() != ISD::TRUNCATE)
7347       return SDValue();
7348     // Other than the first/last part, a shift operation is required to get the
7349     // offset.
7350     int64_t Offset = 0;
7351     SDValue WideVal = Trunc.getOperand(0);
7352     if ((WideVal.getOpcode() == ISD::SRL || WideVal.getOpcode() == ISD::SRA) &&
7353         isa<ConstantSDNode>(WideVal.getOperand(1))) {
7354       // The shift amount must be a constant multiple of the narrow type.
7355       // It is translated to the offset address in the wide source value "y".
7356       //
7357       // x = srl y, ShiftAmtC
7358       // i8 z = trunc x
7359       // store z, ...
7360       uint64_t ShiftAmtC = WideVal.getConstantOperandVal(1);
7361       if (ShiftAmtC % NarrowNumBits != 0)
7362         return SDValue();
7363 
7364       Offset = ShiftAmtC / NarrowNumBits;
7365       WideVal = WideVal.getOperand(0);
7366     }
7367 
7368     // Stores must share the same source value with different offsets.
7369     // Truncates and extends should be stripped to get the single source value.
7370     if (!SourceValue)
7371       SourceValue = WideVal;
7372     else if (stripTruncAndExt(SourceValue) != stripTruncAndExt(WideVal))
7373       return SDValue();
7374     else if (SourceValue.getValueType() != WideVT) {
7375       if (WideVal.getValueType() == WideVT ||
7376           WideVal.getScalarValueSizeInBits() >
7377               SourceValue.getScalarValueSizeInBits())
7378         SourceValue = WideVal;
7379       // Give up if the source value type is smaller than the store size.
7380       if (SourceValue.getScalarValueSizeInBits() < WideVT.getScalarSizeInBits())
7381         return SDValue();
7382     }
7383 
7384     // Stores must share the same base address.
7385     BaseIndexOffset Ptr = BaseIndexOffset::match(Store, DAG);
7386     int64_t ByteOffsetFromBase = 0;
7387     if (!Base)
7388       Base = Ptr;
7389     else if (!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
7390       return SDValue();
7391 
7392     // Remember the first store.
7393     if (ByteOffsetFromBase < FirstOffset) {
7394       FirstStore = Store;
7395       FirstOffset = ByteOffsetFromBase;
7396     }
7397     // Map the offset in the store and the offset in the combined value, and
7398     // early return if it has been set before.
7399     if (Offset < 0 || Offset >= NumStores || OffsetMap[Offset] != INT64_MAX)
7400       return SDValue();
7401     OffsetMap[Offset] = ByteOffsetFromBase;
7402   }
7403 
7404   assert(FirstOffset != INT64_MAX && "First byte offset must be set");
7405   assert(FirstStore && "First store must be set");
7406 
7407   // Check that a store of the wide type is both allowed and fast on the target
7408   const DataLayout &Layout = DAG.getDataLayout();
7409   bool Fast = false;
7410   bool Allowed = TLI.allowsMemoryAccess(Context, Layout, WideVT,
7411                                         *FirstStore->getMemOperand(), &Fast);
7412   if (!Allowed || !Fast)
7413     return SDValue();
7414 
7415   // Check if the pieces of the value are going to the expected places in memory
7416   // to merge the stores.
7417   auto checkOffsets = [&](bool MatchLittleEndian) {
7418     if (MatchLittleEndian) {
7419       for (unsigned i = 0; i != NumStores; ++i)
7420         if (OffsetMap[i] != i * (NarrowNumBits / 8) + FirstOffset)
7421           return false;
7422     } else { // MatchBigEndian by reversing loop counter.
7423       for (unsigned i = 0, j = NumStores - 1; i != NumStores; ++i, --j)
7424         if (OffsetMap[j] != i * (NarrowNumBits / 8) + FirstOffset)
7425           return false;
7426     }
7427     return true;
7428   };
7429 
7430   // Check if the offsets line up for the native data layout of this target.
7431   bool NeedBswap = false;
7432   bool NeedRotate = false;
7433   if (!checkOffsets(Layout.isLittleEndian())) {
7434     // Special-case: check if byte offsets line up for the opposite endian.
7435     if (NarrowNumBits == 8 && checkOffsets(Layout.isBigEndian()))
7436       NeedBswap = true;
7437     else if (NumStores == 2 && checkOffsets(Layout.isBigEndian()))
7438       NeedRotate = true;
7439     else
7440       return SDValue();
7441   }
7442 
7443   SDLoc DL(N);
7444   if (WideVT != SourceValue.getValueType()) {
7445     assert(SourceValue.getValueType().getScalarSizeInBits() > WideNumBits &&
7446            "Unexpected store value to merge");
7447     SourceValue = DAG.getNode(ISD::TRUNCATE, DL, WideVT, SourceValue);
7448   }
7449 
7450   // Before legalization we can introduce illegal bswaps/rotates, which will be
7451   // converted to an explicit bswap sequence later. This way we end up with a
7452   // single store and byte shuffling instead of several stores and byte shuffling.
7453   if (NeedBswap) {
7454     SourceValue = DAG.getNode(ISD::BSWAP, DL, WideVT, SourceValue);
7455   } else if (NeedRotate) {
7456     assert(WideNumBits % 2 == 0 && "Unexpected type for rotate");
7457     SDValue RotAmt = DAG.getConstant(WideNumBits / 2, DL, WideVT);
7458     SourceValue = DAG.getNode(ISD::ROTR, DL, WideVT, SourceValue, RotAmt);
7459   }
7460 
7461   SDValue NewStore =
7462       DAG.getStore(Chain, DL, SourceValue, FirstStore->getBasePtr(),
7463                    FirstStore->getPointerInfo(), FirstStore->getAlign());
7464 
7465   // Rely on other DAG combine rules to remove the other individual stores.
7466   DAG.ReplaceAllUsesWith(N, NewStore.getNode());
7467   return NewStore;
7468 }
7469 
7470 /// Match a pattern where a wide type scalar value is loaded by several narrow
7471 /// loads and combined by shifts and ors. Fold it into a single load or a load
7472 /// and a BSWAP if the targets supports it.
7473 /// and a BSWAP if the target supports it.
7474 /// Assuming little endian target:
7475 ///  i8 *a = ...
7476 ///  i32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
7477 /// =>
7478 ///  i32 val = *((i32)a)
7479 ///
7480 ///  i8 *a = ...
7481 ///  i32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
7482 /// =>
7483 ///  i32 val = BSWAP(*((i32)a))
7484 ///
7485 /// TODO: This rule matches complex patterns with OR node roots and doesn't
7486 /// interact well with the worklist mechanism. When a part of the pattern is
7487 /// updated (e.g. one of the loads) its direct users are put into the worklist,
7488 /// but the root node of the pattern which triggers the load combine is not
7489 /// necessarily a direct user of the changed node. For example, once the address
7490 /// of t28 load is reassociated load combine won't be triggered:
7491 ///             t25: i32 = add t4, Constant:i32<2>
7492 ///           t26: i64 = sign_extend t25
7493 ///        t27: i64 = add t2, t26
7494 ///       t28: i8,ch = load<LD1[%tmp9]> t0, t27, undef:i64
7495 ///     t29: i32 = zero_extend t28
7496 ///   t32: i32 = shl t29, Constant:i8<8>
7497 /// t33: i32 = or t23, t32
7498 /// As a possible fix visitLoad can check if the load can be a part of a load
7499 /// combine pattern and add corresponding OR roots to the worklist.
7500 SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
7501   assert(N->getOpcode() == ISD::OR &&
7502          "Can only match load combining against OR nodes");
7503 
7504   // Handles simple types only
7505   EVT VT = N->getValueType(0);
7506   if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
7507     return SDValue();
7508   unsigned ByteWidth = VT.getSizeInBits() / 8;
7509 
7510   bool IsBigEndianTarget = DAG.getDataLayout().isBigEndian();
7511   auto MemoryByteOffset = [&] (ByteProvider P) {
7512     assert(P.isMemory() && "Must be a memory byte provider");
7513     unsigned LoadBitWidth = P.Load->getMemoryVT().getSizeInBits();
7514     assert(LoadBitWidth % 8 == 0 &&
7515            "can only analyze providers for individual bytes, not bits");
7516     unsigned LoadByteWidth = LoadBitWidth / 8;
7517     return IsBigEndianTarget
7518             ? bigEndianByteAt(LoadByteWidth, P.ByteOffset)
7519             : littleEndianByteAt(LoadByteWidth, P.ByteOffset);
7520   };
7521 
7522   Optional<BaseIndexOffset> Base;
7523   SDValue Chain;
7524 
7525   SmallPtrSet<LoadSDNode *, 8> Loads;
7526   Optional<ByteProvider> FirstByteProvider;
7527   int64_t FirstOffset = INT64_MAX;
7528 
7529   // Check if all the bytes of the OR we are looking at are loaded from the same
7530   // base address. Collect byte offsets from Base address in ByteOffsets.
7531   SmallVector<int64_t, 8> ByteOffsets(ByteWidth);
7532   unsigned ZeroExtendedBytes = 0;
7533   for (int i = ByteWidth - 1; i >= 0; --i) {
7534     auto P = calculateByteProvider(SDValue(N, 0), i, 0, /*Root=*/true);
7535     if (!P)
7536       return SDValue();
7537 
7538     if (P->isConstantZero()) {
7539       // It's OK for the N most significant bytes to be 0; we can just
7540       // zero-extend the load.
7541       if (++ZeroExtendedBytes != (ByteWidth - static_cast<unsigned>(i)))
7542         return SDValue();
7543       continue;
7544     }
7545     assert(P->isMemory() && "provenance should either be memory or zero");
7546 
7547     LoadSDNode *L = P->Load;
7548     assert(L->hasNUsesOfValue(1, 0) && L->isSimple() &&
7549            !L->isIndexed() &&
7550            "Must be enforced by calculateByteProvider");
7551     assert(L->getOffset().isUndef() && "Unindexed load must have undef offset");
7552 
7553     // All loads must share the same chain
7554     SDValue LChain = L->getChain();
7555     if (!Chain)
7556       Chain = LChain;
7557     else if (Chain != LChain)
7558       return SDValue();
7559 
7560     // Loads must share the same base address
7561     BaseIndexOffset Ptr = BaseIndexOffset::match(L, DAG);
7562     int64_t ByteOffsetFromBase = 0;
7563     if (!Base)
7564       Base = Ptr;
7565     else if (!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
7566       return SDValue();
7567 
7568     // Calculate the offset of the current byte from the base address
7569     ByteOffsetFromBase += MemoryByteOffset(*P);
7570     ByteOffsets[i] = ByteOffsetFromBase;
7571 
7572     // Remember the first byte load
7573     if (ByteOffsetFromBase < FirstOffset) {
7574       FirstByteProvider = P;
7575       FirstOffset = ByteOffsetFromBase;
7576     }
7577 
7578     Loads.insert(L);
7579   }
7580   assert(!Loads.empty() && "All the bytes of the value must be loaded from "
7581          "memory, so there must be at least one load which produces the value");
7582   assert(Base && "Base address of the accessed memory location must be set");
7583   assert(FirstOffset != INT64_MAX && "First byte offset must be set");
7584 
7585   bool NeedsZext = ZeroExtendedBytes > 0;
7586 
7587   EVT MemVT =
7588       EVT::getIntegerVT(*DAG.getContext(), (ByteWidth - ZeroExtendedBytes) * 8);
7589 
7590   if (!MemVT.isSimple())
7591     return SDValue();
7592 
7593   // Before legalization we can introduce illegal loads that are too wide; they
7594   // will later be split into legal-sized loads. This enables us to combine
7595   // i64-by-i8 load patterns into a couple of i32 loads on 32-bit targets.
7596   if (LegalOperations &&
7597       !TLI.isOperationLegal(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD,
7598                             MemVT))
7599     return SDValue();
7600 
7601   // Check if the bytes of the OR we are looking at match with either big or
7602   // little endian value load
7603   Optional<bool> IsBigEndian = isBigEndian(
7604       makeArrayRef(ByteOffsets).drop_back(ZeroExtendedBytes), FirstOffset);
7605   if (!IsBigEndian.hasValue())
7606     return SDValue();
7607 
7608   assert(FirstByteProvider && "must be set");
7609 
7610   // Ensure that the first byte is loaded from zero offset of the first load.
7611   // So the combined value can be loaded from the first load address.
7612   if (MemoryByteOffset(*FirstByteProvider) != 0)
7613     return SDValue();
7614   LoadSDNode *FirstLoad = FirstByteProvider->Load;
7615 
7616   // The node we are looking at matches with the pattern, check if we can
7617   // replace it with a single (possibly zero-extended) load and bswap + shift if
7618   // needed.
7619 
7620   // If the load needs byte swap check if the target supports it
7621   bool NeedsBswap = IsBigEndianTarget != *IsBigEndian;
7622 
7623   // Before legalization we can introduce illegal bswaps which will later be
7624   // converted to an explicit bswap sequence. This way we end up with a single
7625   // load and byte shuffling instead of several loads and byte shuffling.
7626   // We do not introduce illegal bswaps when zero-extending as this tends to
7627   // introduce too many arithmetic instructions.
7628   if (NeedsBswap && (LegalOperations || NeedsZext) &&
7629       !TLI.isOperationLegal(ISD::BSWAP, VT))
7630     return SDValue();
7631 
7632   // If we need to bswap and zero extend, we have to insert a shift. Check that
7633   // it is legal.
7634   if (NeedsBswap && NeedsZext && LegalOperations &&
7635       !TLI.isOperationLegal(ISD::SHL, VT))
7636     return SDValue();
7637 
7638   // Check that a load of the wide type is both allowed and fast on the target
7639   bool Fast = false;
7640   bool Allowed =
7641       TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
7642                              *FirstLoad->getMemOperand(), &Fast);
7643   if (!Allowed || !Fast)
7644     return SDValue();
7645 
7646   SDValue NewLoad =
7647       DAG.getExtLoad(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, SDLoc(N), VT,
7648                      Chain, FirstLoad->getBasePtr(),
7649                      FirstLoad->getPointerInfo(), MemVT, FirstLoad->getAlign());
7650 
7651   // Transfer chain users from old loads to the new load.
7652   for (LoadSDNode *L : Loads)
7653     DAG.ReplaceAllUsesOfValueWith(SDValue(L, 1), SDValue(NewLoad.getNode(), 1));
7654 
7655   if (!NeedsBswap)
7656     return NewLoad;
7657 
7658   SDValue ShiftedLoad =
7659       NeedsZext
7660           ? DAG.getNode(ISD::SHL, SDLoc(N), VT, NewLoad,
7661                         DAG.getShiftAmountConstant(ZeroExtendedBytes * 8, VT,
7662                                                    SDLoc(N), LegalOperations))
7663           : NewLoad;
7664   return DAG.getNode(ISD::BSWAP, SDLoc(N), VT, ShiftedLoad);
7665 }
7666 
7667 // If the target has andn, bsl, or a similar bit-select instruction,
7668 // we want to unfold masked merge, with canonical pattern of:
7669 //   |        A  |  |B|
7670 //   ((x ^ y) & m) ^ y
7671 //    |  D  |
7672 // Into:
7673 //   (x & m) | (y & ~m)
7674 // If y is a constant, and the 'andn' does not work with immediates,
7675 // we unfold into a different pattern:
7676 //   ~(~x & m) & (m | y)
7677 // NOTE: we don't unfold the pattern if 'xor' is actually a 'not', because at
7678 //       the very least that breaks andnpd / andnps patterns, and because those
7679 //       patterns are simplified in IR and shouldn't be created in the DAG
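     // For example, with x = 0b1100, y = 0b0110, m = 0b1010:
     //   ((x ^ y) & m) ^ y  == (0b1010 & 0b1010) ^ 0b0110 == 0b1100
     //   (x & m) | (y & ~m) == 0b1000 | 0b0100            == 0b1100
     // i.e. both forms select x's bits where m is set and y's bits elsewhere.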
7680 SDValue DAGCombiner::unfoldMaskedMerge(SDNode *N) {
7681   assert(N->getOpcode() == ISD::XOR);
7682 
7683   // Don't touch 'not' (i.e. where y = -1).
7684   if (isAllOnesOrAllOnesSplat(N->getOperand(1)))
7685     return SDValue();
7686 
7687   EVT VT = N->getValueType(0);
7688 
7689   // There are 3 commutable operators in the pattern,
7690   // so we have to deal with 8 possible variants of the basic pattern.
7691   SDValue X, Y, M;
7692   auto matchAndXor = [&X, &Y, &M](SDValue And, unsigned XorIdx, SDValue Other) {
7693     if (And.getOpcode() != ISD::AND || !And.hasOneUse())
7694       return false;
7695     SDValue Xor = And.getOperand(XorIdx);
7696     if (Xor.getOpcode() != ISD::XOR || !Xor.hasOneUse())
7697       return false;
7698     SDValue Xor0 = Xor.getOperand(0);
7699     SDValue Xor1 = Xor.getOperand(1);
7700     // Don't touch 'not' (i.e. where y = -1).
7701     if (isAllOnesOrAllOnesSplat(Xor1))
7702       return false;
7703     if (Other == Xor0)
7704       std::swap(Xor0, Xor1);
7705     if (Other != Xor1)
7706       return false;
7707     X = Xor0;
7708     Y = Xor1;
7709     M = And.getOperand(XorIdx ? 0 : 1);
7710     return true;
7711   };
7712 
7713   SDValue N0 = N->getOperand(0);
7714   SDValue N1 = N->getOperand(1);
7715   if (!matchAndXor(N0, 0, N1) && !matchAndXor(N0, 1, N1) &&
7716       !matchAndXor(N1, 0, N0) && !matchAndXor(N1, 1, N0))
7717     return SDValue();
7718 
7719   // Don't do anything if the mask is constant. This should not be reachable.
7720   // InstCombine should have already unfolded this pattern, and DAGCombiner
7721   // probably shouldn't produce it, too.
7722   if (isa<ConstantSDNode>(M.getNode()))
7723     return SDValue();
7724 
7725   // We can transform if the target has AndNot
7726   if (!TLI.hasAndNot(M))
7727     return SDValue();
7728 
7729   SDLoc DL(N);
7730 
7731   // If Y is a constant, check that 'andn' works with immediates.
7732   if (!TLI.hasAndNot(Y)) {
7733     assert(TLI.hasAndNot(X) && "Only mask is a variable? Unreachable.");
7734     // If not, we need to do a bit more work to make sure andn is still used.
7735     SDValue NotX = DAG.getNOT(DL, X, VT);
7736     SDValue LHS = DAG.getNode(ISD::AND, DL, VT, NotX, M);
7737     SDValue NotLHS = DAG.getNOT(DL, LHS, VT);
7738     SDValue RHS = DAG.getNode(ISD::OR, DL, VT, M, Y);
7739     return DAG.getNode(ISD::AND, DL, VT, NotLHS, RHS);
7740   }
7741 
7742   SDValue LHS = DAG.getNode(ISD::AND, DL, VT, X, M);
7743   SDValue NotM = DAG.getNOT(DL, M, VT);
7744   SDValue RHS = DAG.getNode(ISD::AND, DL, VT, Y, NotM);
7745 
7746   return DAG.getNode(ISD::OR, DL, VT, LHS, RHS);
7747 }
7748 
7749 SDValue DAGCombiner::visitXOR(SDNode *N) {
7750   SDValue N0 = N->getOperand(0);
7751   SDValue N1 = N->getOperand(1);
7752   EVT VT = N0.getValueType();
7753 
7754   // fold vector ops
7755   if (VT.isVector()) {
7756     if (SDValue FoldedVOp = SimplifyVBinOp(N))
7757       return FoldedVOp;
7758 
7759     // fold (xor x, 0) -> x, vector edition
7760     if (ISD::isConstantSplatVectorAllZeros(N0.getNode()))
7761       return N1;
7762     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
7763       return N0;
7764   }
7765 
7766   // fold (xor undef, undef) -> 0. This is a common idiom (misuse).
7767   SDLoc DL(N);
7768   if (N0.isUndef() && N1.isUndef())
7769     return DAG.getConstant(0, DL, VT);
7770 
7771   // fold (xor x, undef) -> undef
7772   if (N0.isUndef())
7773     return N0;
7774   if (N1.isUndef())
7775     return N1;
7776 
7777   // fold (xor c1, c2) -> c1^c2
7778   if (SDValue C = DAG.FoldConstantArithmetic(ISD::XOR, DL, VT, {N0, N1}))
7779     return C;
7780 
7781   // canonicalize constant to RHS
7782   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
7783      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
7784     return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
7785 
7786   // fold (xor x, 0) -> x
7787   if (isNullConstant(N1))
7788     return N0;
7789 
7790   if (SDValue NewSel = foldBinOpIntoSelect(N))
7791     return NewSel;
7792 
7793   // reassociate xor
7794   if (SDValue RXOR = reassociateOps(ISD::XOR, DL, N0, N1, N->getFlags()))
7795     return RXOR;
7796 
7797   // fold !(x cc y) -> (x !cc y)
7798   unsigned N0Opcode = N0.getOpcode();
7799   SDValue LHS, RHS, CC;
7800   if (TLI.isConstTrueVal(N1.getNode()) &&
7801       isSetCCEquivalent(N0, LHS, RHS, CC, /*MatchStrict*/true)) {
7802     ISD::CondCode NotCC = ISD::getSetCCInverse(cast<CondCodeSDNode>(CC)->get(),
7803                                                LHS.getValueType());
7804     if (!LegalOperations ||
7805         TLI.isCondCodeLegal(NotCC, LHS.getSimpleValueType())) {
7806       switch (N0Opcode) {
7807       default:
7808         llvm_unreachable("Unhandled SetCC Equivalent!");
7809       case ISD::SETCC:
7810         return DAG.getSetCC(SDLoc(N0), VT, LHS, RHS, NotCC);
7811       case ISD::SELECT_CC:
7812         return DAG.getSelectCC(SDLoc(N0), LHS, RHS, N0.getOperand(2),
7813                                N0.getOperand(3), NotCC);
7814       case ISD::STRICT_FSETCC:
7815       case ISD::STRICT_FSETCCS: {
7816         if (N0.hasOneUse()) {
7817           // FIXME Can we handle multiple uses? Could we token factor the chain
7818           // results from the new/old setcc?
7819           SDValue SetCC =
7820               DAG.getSetCC(SDLoc(N0), VT, LHS, RHS, NotCC,
7821                            N0.getOperand(0), N0Opcode == ISD::STRICT_FSETCCS);
7822           CombineTo(N, SetCC);
7823           DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), SetCC.getValue(1));
7824           recursivelyDeleteUnusedNodes(N0.getNode());
7825           return SDValue(N, 0); // Return N so it doesn't get rechecked!
7826         }
7827         break;
7828       }
7829       }
7830     }
7831   }
7832 
7833   // fold (not (zext (setcc x, y))) -> (zext (not (setcc x, y)))
7834   if (isOneConstant(N1) && N0Opcode == ISD::ZERO_EXTEND && N0.hasOneUse() &&
7835       isSetCCEquivalent(N0.getOperand(0), LHS, RHS, CC)){
7836     SDValue V = N0.getOperand(0);
7837     SDLoc DL0(N0);
7838     V = DAG.getNode(ISD::XOR, DL0, V.getValueType(), V,
7839                     DAG.getConstant(1, DL0, V.getValueType()));
7840     AddToWorklist(V.getNode());
7841     return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, V);
7842   }
7843 
7844   // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are setcc
7845   if (isOneConstant(N1) && VT == MVT::i1 && N0.hasOneUse() &&
7846       (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
7847     SDValue N00 = N0.getOperand(0), N01 = N0.getOperand(1);
7848     if (isOneUseSetCC(N01) || isOneUseSetCC(N00)) {
7849       unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
7850       N00 = DAG.getNode(ISD::XOR, SDLoc(N00), VT, N00, N1); // N00 = ~N00
7851       N01 = DAG.getNode(ISD::XOR, SDLoc(N01), VT, N01, N1); // N01 = ~N01
7852       AddToWorklist(N00.getNode()); AddToWorklist(N01.getNode());
7853       return DAG.getNode(NewOpcode, DL, VT, N00, N01);
7854     }
7855   }
7856   // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are constants
7857   if (isAllOnesConstant(N1) && N0.hasOneUse() &&
7858       (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
7859     SDValue N00 = N0.getOperand(0), N01 = N0.getOperand(1);
7860     if (isa<ConstantSDNode>(N01) || isa<ConstantSDNode>(N00)) {
7861       unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
7862       N00 = DAG.getNode(ISD::XOR, SDLoc(N00), VT, N00, N1); // N00 = ~N00
7863       N01 = DAG.getNode(ISD::XOR, SDLoc(N01), VT, N01, N1); // N01 = ~N01
7864       AddToWorklist(N00.getNode()); AddToWorklist(N01.getNode());
7865       return DAG.getNode(NewOpcode, DL, VT, N00, N01);
7866     }
7867   }
7868 
7869   // fold (not (neg x)) -> (add X, -1)
7870   // FIXME: This can be generalized to (not (sub Y, X)) -> (add X, ~Y) if
7871   // Y is a constant or the subtract has a single use.
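  // (The generalized identity checks out: ~(Y - X) == -(Y - X) - 1
  //  == X - Y - 1 == X + (-Y - 1) == X + ~Y; neg x is the Y == 0 case.)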
  if (isAllOnesConstant(N1) && N0.getOpcode() == ISD::SUB &&
      isNullConstant(N0.getOperand(0))) {
    return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1),
                       DAG.getAllOnesConstant(DL, VT));
  }

  // fold (not (add X, -1)) -> (neg X)
  if (isAllOnesConstant(N1) && N0.getOpcode() == ISD::ADD &&
      isAllOnesOrAllOnesSplat(N0.getOperand(1))) {
    return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
                       N0.getOperand(0));
  }

  // fold (xor (and x, y), y) -> (and (not x), y)
  if (N0Opcode == ISD::AND && N0.hasOneUse() && N0->getOperand(1) == N1) {
    SDValue X = N0.getOperand(0);
    SDValue NotX = DAG.getNOT(SDLoc(X), X, VT);
    AddToWorklist(NotX.getNode());
    return DAG.getNode(ISD::AND, DL, VT, NotX, N1);
  }

  if ((N0Opcode == ISD::SRL || N0Opcode == ISD::SHL) && N0.hasOneUse()) {
    ConstantSDNode *XorC = isConstOrConstSplat(N1);
    ConstantSDNode *ShiftC = isConstOrConstSplat(N0.getOperand(1));
    unsigned BitWidth = VT.getScalarSizeInBits();
    if (XorC && ShiftC) {
      // Don't crash on an oversized shift. We cannot guarantee that a bogus
      // shift has been simplified to undef.
      uint64_t ShiftAmt = ShiftC->getLimitedValue();
      if (ShiftAmt < BitWidth) {
        APInt Ones = APInt::getAllOnesValue(BitWidth);
        Ones = N0Opcode == ISD::SHL ? Ones.shl(ShiftAmt) : Ones.lshr(ShiftAmt);
        if (XorC->getAPIntValue() == Ones) {
          // If the xor constant is a shifted -1, do a 'not' before the shift:
          // xor (X << ShiftC), XorC --> (not X) << ShiftC
          // xor (X >> ShiftC), XorC --> (not X) >> ShiftC
          SDValue Not = DAG.getNOT(DL, N0.getOperand(0), VT);
          return DAG.getNode(N0Opcode, DL, VT, Not, N0.getOperand(1));
        }
      }
    }
  }

  // fold Y = sra (X, size(X)-1); xor (add (X, Y), Y) -> (abs X)
  if (TLI.isOperationLegalOrCustom(ISD::ABS, VT)) {
    SDValue A = N0Opcode == ISD::ADD ? N0 : N1;
    SDValue S = N0Opcode == ISD::SRA ? N0 : N1;
    if (A.getOpcode() == ISD::ADD && S.getOpcode() == ISD::SRA) {
      SDValue A0 = A.getOperand(0), A1 = A.getOperand(1);
      SDValue S0 = S.getOperand(0);
      if ((A0 == S && A1 == S0) || (A1 == S && A0 == S0))
        if (ConstantSDNode *C = isConstOrConstSplat(S.getOperand(1)))
          if (C->getAPIntValue() == (VT.getScalarSizeInBits() - 1))
            return DAG.getNode(ISD::ABS, DL, VT, S0);
    }
  }

  // fold (xor x, x) -> 0
  if (N0 == N1)
    return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);

  // fold (xor (shl 1, x), -1) -> (rotl ~1, x)
  // Here is a concrete example of this equivalence:
  // i16   x ==  14
  // i16 shl ==   1 << 14  == 16384 == 0b0100000000000000
  // i16 xor == ~(1 << 14) == 49151 == 0b1011111111111111
  //
  // =>
  //
  // i16     ~1      == 0b1111111111111110
  // i16 rol(~1, 14) == 0b1011111111111111
  //
  // Some additional tips to help conceptualize this transform:
  // - Try to see the operation as placing a single zero in a value of all ones.
  // - There exists no value for x which would allow the result to contain zero.
  // - Values of x larger than the bitwidth are undefined and do not require a
  //   consistent result.
  // - Pushing the zero left requires shifting one-bits in from the right.
  // A rotate left of ~1 is a nice way of achieving the desired result.
  if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT) && N0Opcode == ISD::SHL &&
      isAllOnesConstant(N1) && isOneConstant(N0.getOperand(0))) {
    return DAG.getNode(ISD::ROTL, DL, VT, DAG.getConstant(~1, DL, VT),
                       N0.getOperand(1));
  }

  // Simplify: xor (op x...), (op y...)  -> (op (xor x, y))
  if (N0Opcode == N1.getOpcode())
    if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
      return V;

  // Unfold  ((x ^ y) & m) ^ y  into  (x & m) | (y & ~m)  if profitable
  if (SDValue MM = unfoldMaskedMerge(N))
    return MM;

  // Simplify the expression using non-local knowledge.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  if (SDValue Combined = combineCarryDiamond(*this, DAG, TLI, N0, N1, N))
    return Combined;

  return SDValue();
}

/// If we have a shift-by-constant of a bitwise logic op that itself has a
/// shift-by-constant operand with identical opcode, we may be able to convert
/// that into 2 independent shifts followed by the logic op. This is a
/// throughput improvement.
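/// For example (illustrative): shl (xor (shl X, C0), Y), C1
/// --> xor (shl X, C0 + C1), (shl Y, C1), so the two new shifts are
/// independent and can execute in parallel.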
static SDValue combineShiftOfShiftedLogic(SDNode *Shift, SelectionDAG &DAG) {
  // Match a one-use bitwise logic op.
  SDValue LogicOp = Shift->getOperand(0);
  if (!LogicOp.hasOneUse())
    return SDValue();

  unsigned LogicOpcode = LogicOp.getOpcode();
  if (LogicOpcode != ISD::AND && LogicOpcode != ISD::OR &&
      LogicOpcode != ISD::XOR)
    return SDValue();

  // Find a matching one-use shift by constant.
  unsigned ShiftOpcode = Shift->getOpcode();
  SDValue C1 = Shift->getOperand(1);
  ConstantSDNode *C1Node = isConstOrConstSplat(C1);
  assert(C1Node && "Expected a shift with constant operand");
  const APInt &C1Val = C1Node->getAPIntValue();
  auto matchFirstShift = [&](SDValue V, SDValue &ShiftOp,
                             const APInt *&ShiftAmtVal) {
    if (V.getOpcode() != ShiftOpcode || !V.hasOneUse())
      return false;

    ConstantSDNode *ShiftCNode = isConstOrConstSplat(V.getOperand(1));
    if (!ShiftCNode)
      return false;

    // Capture the shifted operand and shift amount value.
    ShiftOp = V.getOperand(0);
    ShiftAmtVal = &ShiftCNode->getAPIntValue();

    // Shift amount types do not have to match their operand type, so check that
    // the constants are the same width.
    if (ShiftAmtVal->getBitWidth() != C1Val.getBitWidth())
      return false;

    // The fold is not valid if the sum of the shift values exceeds bitwidth.
    if ((*ShiftAmtVal + C1Val).uge(V.getScalarValueSizeInBits()))
      return false;

    return true;
  };

  // Logic ops are commutative, so check each operand for a match.
  SDValue X, Y;
  const APInt *C0Val;
  if (matchFirstShift(LogicOp.getOperand(0), X, C0Val))
    Y = LogicOp.getOperand(1);
  else if (matchFirstShift(LogicOp.getOperand(1), X, C0Val))
    Y = LogicOp.getOperand(0);
  else
    return SDValue();

  // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1)
  SDLoc DL(Shift);
  EVT VT = Shift->getValueType(0);
  EVT ShiftAmtVT = Shift->getOperand(1).getValueType();
  SDValue ShiftSumC = DAG.getConstant(*C0Val + C1Val, DL, ShiftAmtVT);
  SDValue NewShift1 = DAG.getNode(ShiftOpcode, DL, VT, X, ShiftSumC);
  SDValue NewShift2 = DAG.getNode(ShiftOpcode, DL, VT, Y, C1);
  return DAG.getNode(LogicOpcode, DL, VT, NewShift1, NewShift2);
}

/// Handle transforms common to the three shifts, when the shift amount is a
/// constant.
/// We are looking for: (shift being one of shl/sra/srl)
///   shift (binop X, C0), C1
/// And want to transform into:
///   binop (shift X, C1), (shift C0, C1)
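/// For example: shl (add X, 7), 2 --> add (shl X, 2), 28.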
SDValue DAGCombiner::visitShiftByConstant(SDNode *N) {
  assert(isConstOrConstSplat(N->getOperand(1)) && "Expected constant operand");

  // Do not turn a 'not' into a regular xor.
  if (isBitwiseNot(N->getOperand(0)))
    return SDValue();

  // The inner binop must be one-use, since we want to replace it.
  SDValue LHS = N->getOperand(0);
  if (!LHS.hasOneUse() || !TLI.isDesirableToCommuteWithShift(N, Level))
    return SDValue();

  // TODO: This is limited to early combining because it may reveal regressions
  //       otherwise. But since we just checked a target hook to see if this is
  //       desirable, that should have filtered out cases where this interferes
  //       with some other pattern matching.
  if (!LegalTypes)
    if (SDValue R = combineShiftOfShiftedLogic(N, DAG))
      return R;

  // We want to pull some binops through shifts, so that we have (and (shift))
  // instead of (shift (and)), likewise for add, or, xor, etc.  This sort of
  // thing happens with address calculations, so it's important to canonicalize
  // it.
  switch (LHS.getOpcode()) {
  default:
    return SDValue();
  case ISD::OR:
  case ISD::XOR:
  case ISD::AND:
    break;
  case ISD::ADD:
    if (N->getOpcode() != ISD::SHL)
      return SDValue(); // only shl(add) not sr[al](add).
    break;
  }

  // We require the RHS of the binop to be a constant and not opaque as well.
  ConstantSDNode *BinOpCst = getAsNonOpaqueConstant(LHS.getOperand(1));
  if (!BinOpCst)
    return SDValue();

  // FIXME: disable this unless the input to the binop is a shift by a constant
  // or is copy/select. Enable this in other cases when we figure out that it's
  // exactly profitable.
  SDValue BinOpLHSVal = LHS.getOperand(0);
  bool IsShiftByConstant = (BinOpLHSVal.getOpcode() == ISD::SHL ||
                            BinOpLHSVal.getOpcode() == ISD::SRA ||
                            BinOpLHSVal.getOpcode() == ISD::SRL) &&
                           isa<ConstantSDNode>(BinOpLHSVal.getOperand(1));
  bool IsCopyOrSelect = BinOpLHSVal.getOpcode() == ISD::CopyFromReg ||
                        BinOpLHSVal.getOpcode() == ISD::SELECT;

  if (!IsShiftByConstant && !IsCopyOrSelect)
    return SDValue();

  if (IsCopyOrSelect && N->hasOneUse())
    return SDValue();

  // Fold the constants, shifting the binop RHS by the shift amount.
  SDLoc DL(N);
  EVT VT = N->getValueType(0);
  SDValue NewRHS = DAG.getNode(N->getOpcode(), DL, VT, LHS.getOperand(1),
                               N->getOperand(1));
  assert(isa<ConstantSDNode>(NewRHS) && "Folding was not successful!");

  SDValue NewShift = DAG.getNode(N->getOpcode(), DL, VT, LHS.getOperand(0),
                                 N->getOperand(1));
  return DAG.getNode(LHS.getOpcode(), DL, VT, NewShift, NewRHS);
}

SDValue DAGCombiner::distributeTruncateThroughAnd(SDNode *N) {
  assert(N->getOpcode() == ISD::TRUNCATE);
  assert(N->getOperand(0).getOpcode() == ISD::AND);

  // (truncate:TruncVT (and N00, N01C)) -> (and (truncate:TruncVT N00), TruncC)
  EVT TruncVT = N->getValueType(0);
  if (N->hasOneUse() && N->getOperand(0).hasOneUse() &&
      TLI.isTypeDesirableForOp(ISD::AND, TruncVT)) {
    SDValue N01 = N->getOperand(0).getOperand(1);
    if (isConstantOrConstantVector(N01, /* NoOpaques */ true)) {
      SDLoc DL(N);
      SDValue N00 = N->getOperand(0).getOperand(0);
      SDValue Trunc00 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N00);
      SDValue Trunc01 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N01);
      AddToWorklist(Trunc00.getNode());
      AddToWorklist(Trunc01.getNode());
      return DAG.getNode(ISD::AND, DL, TruncVT, Trunc00, Trunc01);
    }
  }

  return SDValue();
}

SDValue DAGCombiner::visitRotate(SDNode *N) {
  SDLoc dl(N);
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  unsigned Bitsize = VT.getScalarSizeInBits();

  // fold (rot x, 0) -> x
  if (isNullOrNullSplat(N1))
    return N0;

  // fold (rot x, c) -> x iff (c % BitSize) == 0
  if (isPowerOf2_32(Bitsize) && Bitsize > 1) {
    APInt ModuloMask(N1.getScalarValueSizeInBits(), Bitsize - 1);
    if (DAG.MaskedValueIsZero(N1, ModuloMask))
      return N0;
  }

  // fold (rot x, c) -> (rot x, c % BitSize)
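  // e.g. (rotl i8 X, 11) -> (rotl i8 X, 3)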
  bool OutOfRange = false;
  auto MatchOutOfRange = [Bitsize, &OutOfRange](ConstantSDNode *C) {
    OutOfRange |= C->getAPIntValue().uge(Bitsize);
    return true;
  };
  if (ISD::matchUnaryPredicate(N1, MatchOutOfRange) && OutOfRange) {
    EVT AmtVT = N1.getValueType();
    SDValue Bits = DAG.getConstant(Bitsize, dl, AmtVT);
    if (SDValue Amt =
            DAG.FoldConstantArithmetic(ISD::UREM, dl, AmtVT, {N1, Bits}))
      return DAG.getNode(N->getOpcode(), dl, VT, N0, Amt);
  }

  // rot i16 X, 8 --> bswap X
  auto *RotAmtC = isConstOrConstSplat(N1);
  if (RotAmtC && RotAmtC->getAPIntValue() == 8 &&
      VT.getScalarSizeInBits() == 16 && hasOperation(ISD::BSWAP, VT))
    return DAG.getNode(ISD::BSWAP, dl, VT, N0);

  // Simplify the operands using demanded-bits information.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  // fold (rot* x, (trunc (and y, c))) -> (rot* x, (and (trunc y), (trunc c))).
  if (N1.getOpcode() == ISD::TRUNCATE &&
      N1.getOperand(0).getOpcode() == ISD::AND) {
    if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
      return DAG.getNode(N->getOpcode(), dl, VT, N0, NewOp1);
  }

  unsigned NextOp = N0.getOpcode();
  // fold (rot* (rot* x, c2), c1) -> (rot* x, c1 +- c2 % bitsize)
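  // e.g. (rotl (rotl X, 2), 5) -> (rotl X, 7), and
  //      (rotl (rotr X, 2), 5) -> (rotl X, 3) (opposite directions subtract).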
  if (NextOp == ISD::ROTL || NextOp == ISD::ROTR) {
    SDNode *C1 = DAG.isConstantIntBuildVectorOrConstantInt(N1);
    SDNode *C2 = DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1));
    if (C1 && C2 && C1->getValueType(0) == C2->getValueType(0)) {
      EVT ShiftVT = C1->getValueType(0);
      bool SameSide = (N->getOpcode() == NextOp);
      unsigned CombineOp = SameSide ? ISD::ADD : ISD::SUB;
      if (SDValue CombinedShift = DAG.FoldConstantArithmetic(
              CombineOp, dl, ShiftVT, {N1, N0.getOperand(1)})) {
        SDValue BitsizeC = DAG.getConstant(Bitsize, dl, ShiftVT);
        SDValue CombinedShiftNorm = DAG.FoldConstantArithmetic(
            ISD::SREM, dl, ShiftVT, {CombinedShift, BitsizeC});
        return DAG.getNode(N->getOpcode(), dl, VT, N0->getOperand(0),
                           CombinedShiftNorm);
      }
    }
  }
  return SDValue();
}

SDValue DAGCombiner::visitSHL(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  if (SDValue V = DAG.simplifyShift(N0, N1))
    return V;

  EVT VT = N0.getValueType();
  EVT ShiftVT = N1.getValueType();
  unsigned OpSizeInBits = VT.getScalarSizeInBits();

  // fold vector ops
  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N))
      return FoldedVOp;

    BuildVectorSDNode *N1CV = dyn_cast<BuildVectorSDNode>(N1);
    // If setcc produces all-one true value then:
    // (shl (and (setcc) N01CV) N1CV) -> (and (setcc) N01CV<<N1CV)
    if (N1CV && N1CV->isConstant()) {
      if (N0.getOpcode() == ISD::AND) {
        SDValue N00 = N0->getOperand(0);
        SDValue N01 = N0->getOperand(1);
        BuildVectorSDNode *N01CV = dyn_cast<BuildVectorSDNode>(N01);

        if (N01CV && N01CV->isConstant() && N00.getOpcode() == ISD::SETCC &&
            TLI.getBooleanContents(N00.getOperand(0).getValueType()) ==
                TargetLowering::ZeroOrNegativeOneBooleanContent) {
          if (SDValue C =
                  DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N), VT, {N01, N1}))
            return DAG.getNode(ISD::AND, SDLoc(N), VT, N00, C);
        }
      }
    }
  }

  ConstantSDNode *N1C = isConstOrConstSplat(N1);

  // fold (shl c1, c2) -> c1<<c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N), VT, {N0, N1}))
    return C;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // if (shl x, c) is known to be zero, return 0
  if (DAG.MaskedValueIsZero(SDValue(N, 0),
                            APInt::getAllOnesValue(OpSizeInBits)))
    return DAG.getConstant(0, SDLoc(N), VT);

  // fold (shl x, (trunc (and y, c))) -> (shl x, (and (trunc y), (trunc c))).
  if (N1.getOpcode() == ISD::TRUNCATE &&
      N1.getOperand(0).getOpcode() == ISD::AND) {
    if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
      return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, NewOp1);
  }

  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  // fold (shl (shl x, c1), c2) -> 0 or (shl x, (add c1, c2))
  if (N0.getOpcode() == ISD::SHL) {
    auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
                                          ConstantSDNode *RHS) {
      APInt c1 = LHS->getAPIntValue();
      APInt c2 = RHS->getAPIntValue();
      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
      return (c1 + c2).uge(OpSizeInBits);
    };
    if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
      return DAG.getConstant(0, SDLoc(N), VT);

    auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
                                       ConstantSDNode *RHS) {
      APInt c1 = LHS->getAPIntValue();
      APInt c2 = RHS->getAPIntValue();
      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
      return (c1 + c2).ult(OpSizeInBits);
    };
    if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
      SDLoc DL(N);
      SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
      return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Sum);
    }
  }

  // fold (shl (ext (shl x, c1)), c2) -> (shl (ext x), (add c1, c2))
  // For this to be valid, the second form must not preserve any of the bits
  // that are shifted out by the inner shift in the first form.  This means
  // the outer shift size must be >= the number of bits added by the ext.
  // As a corollary, we don't care what kind of ext it is.
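  // e.g. for i8 zext'd to i16: shl (zext (shl X, 5)), 8 --> shl (zext X), 13.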
  if ((N0.getOpcode() == ISD::ZERO_EXTEND ||
       N0.getOpcode() == ISD::ANY_EXTEND ||
       N0.getOpcode() == ISD::SIGN_EXTEND) &&
      N0.getOperand(0).getOpcode() == ISD::SHL) {
    SDValue N0Op0 = N0.getOperand(0);
    SDValue InnerShiftAmt = N0Op0.getOperand(1);
    EVT InnerVT = N0Op0.getValueType();
    uint64_t InnerBitwidth = InnerVT.getScalarSizeInBits();

    auto MatchOutOfRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
                                                         ConstantSDNode *RHS) {
      APInt c1 = LHS->getAPIntValue();
      APInt c2 = RHS->getAPIntValue();
      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
      return c2.uge(OpSizeInBits - InnerBitwidth) &&
             (c1 + c2).uge(OpSizeInBits);
    };
    if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchOutOfRange,
                                  /*AllowUndefs*/ false,
                                  /*AllowTypeMismatch*/ true))
      return DAG.getConstant(0, SDLoc(N), VT);

    auto MatchInRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
                                                      ConstantSDNode *RHS) {
      APInt c1 = LHS->getAPIntValue();
      APInt c2 = RHS->getAPIntValue();
      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
      return c2.uge(OpSizeInBits - InnerBitwidth) &&
             (c1 + c2).ult(OpSizeInBits);
    };
    if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchInRange,
                                  /*AllowUndefs*/ false,
                                  /*AllowTypeMismatch*/ true)) {
      SDLoc DL(N);
      SDValue Ext = DAG.getNode(N0.getOpcode(), DL, VT, N0Op0.getOperand(0));
      SDValue Sum = DAG.getZExtOrTrunc(InnerShiftAmt, DL, ShiftVT);
      Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, Sum, N1);
      return DAG.getNode(ISD::SHL, DL, VT, Ext, Sum);
    }
  }

  // fold (shl (zext (srl x, C)), C) -> (zext (shl (srl x, C), C))
  // Only fold this if the inner zext has no other uses to avoid increasing
  // the total number of instructions.
  if (N0.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse() &&
      N0.getOperand(0).getOpcode() == ISD::SRL) {
    SDValue N0Op0 = N0.getOperand(0);
    SDValue InnerShiftAmt = N0Op0.getOperand(1);

    auto MatchEqual = [VT](ConstantSDNode *LHS, ConstantSDNode *RHS) {
      APInt c1 = LHS->getAPIntValue();
      APInt c2 = RHS->getAPIntValue();
      zeroExtendToMatch(c1, c2);
      return c1.ult(VT.getScalarSizeInBits()) && (c1 == c2);
    };
    if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchEqual,
                                  /*AllowUndefs*/ false,
                                  /*AllowTypeMismatch*/ true)) {
      SDLoc DL(N);
      EVT InnerShiftAmtVT = N0Op0.getOperand(1).getValueType();
      SDValue NewSHL = DAG.getZExtOrTrunc(N1, DL, InnerShiftAmtVT);
      NewSHL = DAG.getNode(ISD::SHL, DL, N0Op0.getValueType(), N0Op0, NewSHL);
      AddToWorklist(NewSHL.getNode());
      return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N0), VT, NewSHL);
    }
  }

  // fold (shl (sr[la] exact X,  C1), C2) -> (shl    X, (C2-C1)) if C1 <= C2
  // fold (shl (sr[la] exact X,  C1), C2) -> (sr[la] X, (C1-C2)) if C1  > C2
  // TODO - support non-uniform vector shift amounts.
  if (N1C && (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SRA) &&
      N0->getFlags().hasExact()) {
    if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) {
      uint64_t C1 = N0C1->getZExtValue();
      uint64_t C2 = N1C->getZExtValue();
      SDLoc DL(N);
      if (C1 <= C2)
        return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0),
                           DAG.getConstant(C2 - C1, DL, ShiftVT));
      return DAG.getNode(N0.getOpcode(), DL, VT, N0.getOperand(0),
                         DAG.getConstant(C1 - C2, DL, ShiftVT));
    }
  }

  // fold (shl (srl x, c1), c2) -> (and (shl x, (sub c2, c1)), MASK) or
  //                               (and (srl x, (sub c1, c2)), MASK)
  // Only fold this if the inner shift has no other uses -- if it does, folding
  // this will increase the total number of instructions.
  // TODO - drop hasOneUse requirement if c1 == c2?
  // TODO - support non-uniform vector shift amounts.
  if (N1C && N0.getOpcode() == ISD::SRL && N0.hasOneUse() &&
      TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
    if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) {
      if (N0C1->getAPIntValue().ult(OpSizeInBits)) {
        uint64_t c1 = N0C1->getZExtValue();
        uint64_t c2 = N1C->getZExtValue();
        APInt Mask = APInt::getHighBitsSet(OpSizeInBits, OpSizeInBits - c1);
        SDValue Shift;
        if (c2 > c1) {
          Mask <<= c2 - c1;
          SDLoc DL(N);
          Shift = DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0),
                              DAG.getConstant(c2 - c1, DL, ShiftVT));
        } else {
          Mask.lshrInPlace(c1 - c2);
          SDLoc DL(N);
          Shift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0),
                              DAG.getConstant(c1 - c2, DL, ShiftVT));
        }
        SDLoc DL(N0);
        return DAG.getNode(ISD::AND, DL, VT, Shift,
                           DAG.getConstant(Mask, DL, VT));
      }
    }
  }

  // fold (shl (sra x, c1), c1) -> (and x, (shl -1, c1))
  if (N0.getOpcode() == ISD::SRA && N1 == N0.getOperand(1) &&
      isConstantOrConstantVector(N1, /* No Opaques */ true)) {
    SDLoc DL(N);
    SDValue AllBits = DAG.getAllOnesConstant(DL, VT);
    SDValue HiBitsMask = DAG.getNode(ISD::SHL, DL, VT, AllBits, N1);
    return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), HiBitsMask);
  }

  // fold (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2)
  // fold (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2)
  // Variant of version done on multiply, except mul by a power of 2 is turned
  // into a shift.
  if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::OR) &&
      N0.getNode()->hasOneUse() &&
      isConstantOrConstantVector(N1, /* No Opaques */ true) &&
      isConstantOrConstantVector(N0.getOperand(1), /* No Opaques */ true) &&
      TLI.isDesirableToCommuteWithShift(N, Level)) {
    SDValue Shl0 = DAG.getNode(ISD::SHL, SDLoc(N0), VT, N0.getOperand(0), N1);
    SDValue Shl1 = DAG.getNode(ISD::SHL, SDLoc(N1), VT, N0.getOperand(1), N1);
    AddToWorklist(Shl0.getNode());
    AddToWorklist(Shl1.getNode());
    return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, Shl0, Shl1);
  }

  // fold (shl (mul x, c1), c2) -> (mul x, c1 << c2)
  if (N0.getOpcode() == ISD::MUL && N0.getNode()->hasOneUse() &&
      isConstantOrConstantVector(N1, /* No Opaques */ true) &&
      isConstantOrConstantVector(N0.getOperand(1), /* No Opaques */ true)) {
    SDValue Shl = DAG.getNode(ISD::SHL, SDLoc(N1), VT, N0.getOperand(1), N1);
    if (isConstantOrConstantVector(Shl))
      return DAG.getNode(ISD::MUL, SDLoc(N), VT, N0.getOperand(0), Shl);
  }

  if (N1C && !N1C->isOpaque())
    if (SDValue NewSHL = visitShiftByConstant(N))
      return NewSHL;

  // Fold (shl (vscale * C0), C1) to (vscale * (C0 << C1)).
  if (N0.getOpcode() == ISD::VSCALE)
    if (ConstantSDNode *NC1 = isConstOrConstSplat(N->getOperand(1))) {
      const APInt &C0 = N0.getConstantOperandAPInt(0);
      const APInt &C1 = NC1->getAPIntValue();
      return DAG.getVScale(SDLoc(N), VT, C0 << C1);
    }

  // Fold (shl step_vector(C0), C1) to (step_vector(C0 << C1)).
  APInt ShlVal;
  if (N0.getOpcode() == ISD::STEP_VECTOR)
    if (ISD::isConstantSplatVector(N1.getNode(), ShlVal)) {
      const APInt &C0 = N0.getConstantOperandAPInt(0);
      if (ShlVal.ult(C0.getBitWidth())) {
        APInt NewStep = C0 << ShlVal;
        return DAG.getStepVector(SDLoc(N), VT, NewStep);
      }
    }

  return SDValue();
}

// Transform a right shift of a multiply into a multiply-high.
// Examples:
// (srl (mul (zext i32:$a to i64), (zext i32:$b to i64)), 32) -> (mulhu $a, $b)
// (sra (mul (sext i32:$a to i64), (sext i32:$b to i64)), 32) -> (mulhs $a, $b)
static SDValue combineShiftToMULH(SDNode *N, SelectionDAG &DAG,
                                  const TargetLowering &TLI) {
  assert((N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) &&
         "SRL or SRA node is required here!");

  // Check the shift amount. Proceed with the transformation if the shift
  // amount is constant.
  ConstantSDNode *ShiftAmtSrc = isConstOrConstSplat(N->getOperand(1));
  if (!ShiftAmtSrc)
    return SDValue();

  SDLoc DL(N);

  // The operation feeding into the shift must be a multiply.
  SDValue ShiftOperand = N->getOperand(0);
  if (ShiftOperand.getOpcode() != ISD::MUL)
    return SDValue();

  // Both operands must be equivalent extend nodes.
  SDValue LeftOp = ShiftOperand.getOperand(0);
  SDValue RightOp = ShiftOperand.getOperand(1);
  bool IsSignExt = LeftOp.getOpcode() == ISD::SIGN_EXTEND;
  bool IsZeroExt = LeftOp.getOpcode() == ISD::ZERO_EXTEND;

  if ((!(IsSignExt || IsZeroExt)) || LeftOp.getOpcode() != RightOp.getOpcode())
    return SDValue();

  EVT WideVT1 = LeftOp.getValueType();
  EVT WideVT2 = RightOp.getValueType();
  (void)WideVT2;
  // Proceed with the transformation if the wide types match.
  assert((WideVT1 == WideVT2) &&
         "Cannot have a multiply node with two different operand types.");

  EVT NarrowVT = LeftOp.getOperand(0).getValueType();
  // Check that the two extend nodes are the same type.
  if (NarrowVT != RightOp.getOperand(0).getValueType())
    return SDValue();

  // Proceed with the transformation if the wide type is twice as large
  // as the narrow type.
  unsigned NarrowVTSize = NarrowVT.getScalarSizeInBits();
  if (WideVT1.getScalarSizeInBits() != 2 * NarrowVTSize)
    return SDValue();

  // Check the shift amount with the narrow type size.
  // Proceed with the transformation if the shift amount is the width
  // of the narrow type.
  unsigned ShiftAmt = ShiftAmtSrc->getZExtValue();
  if (ShiftAmt != NarrowVTSize)
    return SDValue();

  // If the operation feeding into the MUL is a sign extend (sext),
  // we use mulhs. Otherwise, zero extends (zext) use mulhu.
  unsigned MulhOpcode = IsSignExt ? ISD::MULHS : ISD::MULHU;

  // Combine to mulh if mulh is legal/custom for the narrow type on the target.
  if (!TLI.isOperationLegalOrCustom(MulhOpcode, NarrowVT))
    return SDValue();

  SDValue Result = DAG.getNode(MulhOpcode, DL, NarrowVT, LeftOp.getOperand(0),
                               RightOp.getOperand(0));
  return (N->getOpcode() == ISD::SRA ? DAG.getSExtOrTrunc(Result, DL, WideVT1)
                                     : DAG.getZExtOrTrunc(Result, DL, WideVT1));
}

SDValue DAGCombiner::visitSRA(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  if (SDValue V = DAG.simplifyShift(N0, N1))
    return V;

  EVT VT = N0.getValueType();
  unsigned OpSizeInBits = VT.getScalarSizeInBits();

  // Arithmetic shifting an all-sign-bit value is a no-op.
  // fold (sra 0, x) -> 0
  // fold (sra -1, x) -> -1
  if (DAG.ComputeNumSignBits(N0) == OpSizeInBits)
    return N0;

  // fold vector ops
  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N))
      return FoldedVOp;

  ConstantSDNode *N1C = isConstOrConstSplat(N1);

  // fold (sra c1, c2) -> c1 >>s c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::SRA, SDLoc(N), VT, {N0, N1}))
    return C;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // fold (sra (shl x, c1), c1) -> sext_inreg for some c1, if the target
  // supports sext_inreg.
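  // e.g. i32: sra (shl X, 24), 24 --> sign_extend_inreg X, i8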
  if (N1C && N0.getOpcode() == ISD::SHL && N1 == N0.getOperand(1)) {
    unsigned LowBits = OpSizeInBits - (unsigned)N1C->getZExtValue();
    EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), LowBits);
    if (VT.isVector())
      ExtVT = EVT::getVectorVT(*DAG.getContext(), ExtVT,
                               VT.getVectorElementCount());
    if (!LegalOperations ||
        TLI.getOperationAction(ISD::SIGN_EXTEND_INREG, ExtVT) ==
        TargetLowering::Legal)
      return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT,
                         N0.getOperand(0), DAG.getValueType(ExtVT));
    // Even if we can't convert to sext_inreg, we might be able to remove
    // this shift pair if the input is already sign extended.
    if (DAG.ComputeNumSignBits(N0.getOperand(0)) > N1C->getZExtValue())
      return N0.getOperand(0);
  }

  // fold (sra (sra x, c1), c2) -> (sra x, (add c1, c2))
  // clamp (add c1, c2) to max shift.
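  // e.g. i32: sra (sra X, 20), 20 --> sra X, 31 (the sum 40 clamps to 31).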
  if (N0.getOpcode() == ISD::SRA) {
    SDLoc DL(N);
    EVT ShiftVT = N1.getValueType();
    EVT ShiftSVT = ShiftVT.getScalarType();
    SmallVector<SDValue, 16> ShiftValues;

    auto SumOfShifts = [&](ConstantSDNode *LHS, ConstantSDNode *RHS) {
      APInt c1 = LHS->getAPIntValue();
      APInt c2 = RHS->getAPIntValue();
      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
      APInt Sum = c1 + c2;
      unsigned ShiftSum =
          Sum.uge(OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue();
      ShiftValues.push_back(DAG.getConstant(ShiftSum, DL, ShiftSVT));
      return true;
    };
    if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), SumOfShifts)) {
      SDValue ShiftValue;
      if (N1.getOpcode() == ISD::BUILD_VECTOR)
        ShiftValue = DAG.getBuildVector(ShiftVT, DL, ShiftValues);
      else if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
        assert(ShiftValues.size() == 1 &&
               "Expected matchBinaryPredicate to return one element for "
               "SPLAT_VECTORs");
        ShiftValue = DAG.getSplatVector(ShiftVT, DL, ShiftValues[0]);
      } else
        ShiftValue = ShiftValues[0];
      return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0), ShiftValue);
    }
  }

  // fold (sra (shl X, m), (sub result_size, n))
  // -> (sign_extend (trunc (shl X, (sub (sub result_size, n), m)))) for
  // result_size - n != m.
  // If truncate is free for the target, sext(shl) is likely to result in
  // better code.
  if (N0.getOpcode() == ISD::SHL && N1C) {
    // Get the two constants of the shifts, CN0 = m, CN = n.
    const ConstantSDNode *N01C = isConstOrConstSplat(N0.getOperand(1));
    if (N01C) {
      LLVMContext &Ctx = *DAG.getContext();
      // Determine what the truncate's result bitsize and type would be.
      EVT TruncVT = EVT::getIntegerVT(Ctx, OpSizeInBits - N1C->getZExtValue());

      if (VT.isVector())
        TruncVT = EVT::getVectorVT(Ctx, TruncVT, VT.getVectorElementCount());

      // Determine the residual right-shift amount.
      int ShiftAmt = N1C->getZExtValue() - N01C->getZExtValue();

      // If the shift is not a no-op (in which case this should be just a sign
      // extend already), the truncated-to type is legal, sign_extend is legal
      // on that type, and the truncate to that type is both legal and free,
      // perform the transform.
      if ((ShiftAmt > 0) &&
          TLI.isOperationLegalOrCustom(ISD::SIGN_EXTEND, TruncVT) &&
          TLI.isOperationLegalOrCustom(ISD::TRUNCATE, VT) &&
          TLI.isTruncateFree(VT, TruncVT)) {
        SDLoc DL(N);
        SDValue Amt = DAG.getConstant(ShiftAmt, DL,
            getShiftAmountTy(N0.getOperand(0).getValueType()));
        SDValue Shift = DAG.getNode(ISD::SRL, DL, VT,
                                    N0.getOperand(0), Amt);
        SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, TruncVT,
                                    Shift);
        return DAG.getNode(ISD::SIGN_EXTEND, DL,
                           N->getValueType(0), Trunc);
      }
    }
  }

  // We convert trunc/ext to opposing shifts in IR, but casts may be cheaper.
  //   sra (add (shl X, N1C), AddC), N1C -->
  //   sext (add (trunc X to (width - N1C)), AddC')
  if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() && N1C &&
      N0.getOperand(0).getOpcode() == ISD::SHL &&
      N0.getOperand(0).getOperand(1) == N1 && N0.getOperand(0).hasOneUse()) {
    if (ConstantSDNode *AddC = isConstOrConstSplat(N0.getOperand(1))) {
      SDValue Shl = N0.getOperand(0);
      // Determine what the truncate's type would be and ask the target if that
      // is a free operation.
      LLVMContext &Ctx = *DAG.getContext();
      unsigned ShiftAmt = N1C->getZExtValue();
      EVT TruncVT = EVT::getIntegerVT(Ctx, OpSizeInBits - ShiftAmt);
      if (VT.isVector())
        TruncVT = EVT::getVectorVT(Ctx, TruncVT, VT.getVectorElementCount());

      // TODO: The simple type check probably belongs in the default hook
      //       implementation and/or target-specific overrides (because
      //       non-simple types likely require masking when legalized), but that
      //       restriction may conflict with other transforms.
      if (TruncVT.isSimple() && isTypeLegal(TruncVT) &&
          TLI.isTruncateFree(VT, TruncVT)) {
        SDLoc DL(N);
        SDValue Trunc = DAG.getZExtOrTrunc(Shl.getOperand(0), DL, TruncVT);
        SDValue ShiftC = DAG.getConstant(AddC->getAPIntValue().lshr(ShiftAmt).
                             trunc(TruncVT.getScalarSizeInBits()), DL, TruncVT);
        SDValue Add = DAG.getNode(ISD::ADD, DL, TruncVT, Trunc, ShiftC);
        return DAG.getSExtOrTrunc(Add, DL, VT);
      }
    }
  }

  // fold (sra x, (trunc (and y, c))) -> (sra x, (and (trunc y), (trunc c))).
  if (N1.getOpcode() == ISD::TRUNCATE &&
      N1.getOperand(0).getOpcode() == ISD::AND) {
    if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
      return DAG.getNode(ISD::SRA, SDLoc(N), VT, N0, NewOp1);
  }

  // fold (sra (trunc (sra x, c1)), c2) -> (trunc (sra x, c1 + c2))
  // fold (sra (trunc (srl x, c1)), c2) -> (trunc (sra x, c1 + c2))
  //      if c1 is equal to the number of bits the trunc removes
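  // e.g. i64 -> i32: sra (trunc (srl X, 32)), 5 --> trunc (sra X, 37)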
  // TODO - support non-uniform vector shift amounts.
  if (N0.getOpcode() == ISD::TRUNCATE &&
      (N0.getOperand(0).getOpcode() == ISD::SRL ||
       N0.getOperand(0).getOpcode() == ISD::SRA) &&
      N0.getOperand(0).hasOneUse() &&
      N0.getOperand(0).getOperand(1).hasOneUse() && N1C) {
    SDValue N0Op0 = N0.getOperand(0);
    if (ConstantSDNode *LargeShift = isConstOrConstSplat(N0Op0.getOperand(1))) {
      EVT LargeVT = N0Op0.getValueType();
      unsigned TruncBits = LargeVT.getScalarSizeInBits() - OpSizeInBits;
      if (LargeShift->getAPIntValue() == TruncBits) {
        SDLoc DL(N);
        SDValue Amt = DAG.getConstant(N1C->getZExtValue() + TruncBits, DL,
                                      getShiftAmountTy(LargeVT));
        SDValue SRA =
            DAG.getNode(ISD::SRA, DL, LargeVT, N0Op0.getOperand(0), Amt);
        return DAG.getNode(ISD::TRUNCATE, DL, VT, SRA);
      }
    }
  }

  // Simplify, based on bits shifted out of the LHS.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  // If the sign bit is known to be zero, switch this to a SRL.
  if (DAG.SignBitIsZero(N0))
    return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0, N1);

  if (N1C && !N1C->isOpaque())
    if (SDValue NewSRA = visitShiftByConstant(N))
      return NewSRA;

  // Try to transform this shift into a multiply-high if
  // it matches the appropriate pattern detected in combineShiftToMULH.
  if (SDValue MULH = combineShiftToMULH(N, DAG, TLI))
    return MULH;

  return SDValue();
}

SDValue DAGCombiner::visitSRL(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  if (SDValue V = DAG.simplifyShift(N0, N1))
    return V;

  EVT VT = N0.getValueType();
  unsigned OpSizeInBits = VT.getScalarSizeInBits();

  // fold vector ops
  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N))
      return FoldedVOp;

  ConstantSDNode *N1C = isConstOrConstSplat(N1);

  // fold (srl c1, c2) -> c1 >>u c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::SRL, SDLoc(N), VT, {N0, N1}))
    return C;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // if (srl x, c) is known to be zero, return 0
  if (N1C && DAG.MaskedValueIsZero(SDValue(N, 0),
                                   APInt::getAllOnesValue(OpSizeInBits)))
    return DAG.getConstant(0, SDLoc(N), VT);

  // fold (srl (srl x, c1), c2) -> 0 or (srl x, (add c1, c2))
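  // e.g. i8: srl (srl X, 3), 2 --> srl X, 5, but srl (srl X, 3), 6 --> 0.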
  if (N0.getOpcode() == ISD::SRL) {
    auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
                                          ConstantSDNode *RHS) {
      APInt c1 = LHS->getAPIntValue();
      APInt c2 = RHS->getAPIntValue();
      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
      return (c1 + c2).uge(OpSizeInBits);
    };
    if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
      return DAG.getConstant(0, SDLoc(N), VT);

    auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
                                       ConstantSDNode *RHS) {
      APInt c1 = LHS->getAPIntValue();
      APInt c2 = RHS->getAPIntValue();
      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
      return (c1 + c2).ult(OpSizeInBits);
    };
    if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
      SDLoc DL(N);
      EVT ShiftVT = N1.getValueType();
      SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
      return DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Sum);
    }
  }

  if (N1C && N0.getOpcode() == ISD::TRUNCATE &&
      N0.getOperand(0).getOpcode() == ISD::SRL) {
    SDValue InnerShift = N0.getOperand(0);
    // TODO - support non-uniform vector shift amounts.
    if (auto *N001C = isConstOrConstSplat(InnerShift.getOperand(1))) {
      uint64_t c1 = N001C->getZExtValue();
      uint64_t c2 = N1C->getZExtValue();
      EVT InnerShiftVT = InnerShift.getValueType();
      EVT ShiftAmtVT = InnerShift.getOperand(1).getValueType();
      uint64_t InnerShiftSize = InnerShiftVT.getScalarSizeInBits();
      // srl (trunc (srl x, c1)), c2 --> 0 or (trunc (srl x, (add c1, c2)))
      // This is only valid if the OpSizeInBits + c1 = size of inner shift.
      if (c1 + OpSizeInBits == InnerShiftSize) {
        SDLoc DL(N);
        if (c1 + c2 >= InnerShiftSize)
          return DAG.getConstant(0, DL, VT);
        SDValue NewShiftAmt = DAG.getConstant(c1 + c2, DL, ShiftAmtVT);
        SDValue NewShift = DAG.getNode(ISD::SRL, DL, InnerShiftVT,
                                       InnerShift.getOperand(0), NewShiftAmt);
        return DAG.getNode(ISD::TRUNCATE, DL, VT, NewShift);
      }
      // In the more general case, we can clear the high bits after the shift:
      // srl (trunc (srl x, c1)), c2 --> trunc (and (srl x, (c1+c2)), Mask)
      if (N0.hasOneUse() && InnerShift.hasOneUse() &&
          c1 + c2 < InnerShiftSize) {
        SDLoc DL(N);
        SDValue NewShiftAmt = DAG.getConstant(c1 + c2, DL, ShiftAmtVT);
        SDValue NewShift = DAG.getNode(ISD::SRL, DL, InnerShiftVT,
                                       InnerShift.getOperand(0), NewShiftAmt);
        SDValue Mask = DAG.getConstant(APInt::getLowBitsSet(InnerShiftSize,
                                                            OpSizeInBits - c2),
                                       DL, InnerShiftVT);
        SDValue And = DAG.getNode(ISD::AND, DL, InnerShiftVT, NewShift, Mask);
        return DAG.getNode(ISD::TRUNCATE, DL, VT, And);
      }
    }
  }

  // fold (srl (shl x, c), c) -> (and x, cst2)
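  // e.g. i8: srl (shl X, 3), 3 --> and X, 0x1F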
  // TODO - (srl (shl x, c1), c2).
  if (N0.getOpcode() == ISD::SHL && N0.getOperand(1) == N1 &&
      isConstantOrConstantVector(N1, /* NoOpaques */ true)) {
    SDLoc DL(N);
    SDValue Mask =
        DAG.getNode(ISD::SRL, DL, VT, DAG.getAllOnesConstant(DL, VT), N1);
    AddToWorklist(Mask.getNode());
    return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), Mask);
  }

  // fold (srl (anyextend x), c) -> (and (anyextend (srl x, c)), mask)
  // TODO - support non-uniform vector shift amounts.
  if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
    // Shifting in all undef bits?
    EVT SmallVT = N0.getOperand(0).getValueType();
    unsigned BitSize = SmallVT.getScalarSizeInBits();
    if (N1C->getAPIntValue().uge(BitSize))
      return DAG.getUNDEF(VT);

    if (!LegalTypes || TLI.isTypeDesirableForOp(ISD::SRL, SmallVT)) {
      uint64_t ShiftAmt = N1C->getZExtValue();
      SDLoc DL0(N0);
      SDValue SmallShift = DAG.getNode(ISD::SRL, DL0, SmallVT,
                                       N0.getOperand(0),
                          DAG.getConstant(ShiftAmt, DL0,
                                          getShiftAmountTy(SmallVT)));
      AddToWorklist(SmallShift.getNode());
      APInt Mask = APInt::getLowBitsSet(OpSizeInBits, OpSizeInBits - ShiftAmt);
      SDLoc DL(N);
      return DAG.getNode(ISD::AND, DL, VT,
                         DAG.getNode(ISD::ANY_EXTEND, DL, VT, SmallShift),
                         DAG.getConstant(Mask, DL, VT));
    }
  }

  // fold (srl (sra X, Y), 31) -> (srl X, 31).  This srl only looks at the sign
  // bit, which is unmodified by sra.
  if (N1C && N1C->getAPIntValue() == (OpSizeInBits - 1)) {
    if (N0.getOpcode() == ISD::SRA)
      return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0.getOperand(0), N1);
  }

  // fold (srl (ctlz x), "5") -> x  iff x has one bit set (the low bit).
  if (N1C && N0.getOpcode() == ISD::CTLZ &&
      N1C->getAPIntValue() == Log2_32(OpSizeInBits)) {
    KnownBits Known = DAG.computeKnownBits(N0.getOperand(0));

    // If any of the input bits are KnownOne, then the input couldn't be all
    // zeros, thus the result of the srl will always be zero.
    if (Known.One.getBoolValue()) return DAG.getConstant(0, SDLoc(N0), VT);

    // If all of the bits input to the ctlz node are known to be zero, then
    // the result of the ctlz is "32" and the result of the shift is one.
    APInt UnknownBits = ~Known.Zero;
    if (UnknownBits == 0) return DAG.getConstant(1, SDLoc(N0), VT);

    // Otherwise, check to see if there is exactly one bit input to the ctlz.
    if (UnknownBits.isPowerOf2()) {
      // Okay, we know that only the single bit specified by UnknownBits could
      // be set on input to the CTLZ node. If this bit is set, the SRL will
      // return 0; if it is clear, it returns 1. Change the CTLZ/SRL pair
      // to an SRL/XOR pair, which is likely to simplify more.
      unsigned ShAmt = UnknownBits.countTrailingZeros();
      SDValue Op = N0.getOperand(0);

      if (ShAmt) {
        SDLoc DL(N0);
        Op = DAG.getNode(ISD::SRL, DL, VT, Op,
                  DAG.getConstant(ShAmt, DL,
                                  getShiftAmountTy(Op.getValueType())));
        AddToWorklist(Op.getNode());
      }

      SDLoc DL(N);
      return DAG.getNode(ISD::XOR, DL, VT,
                         Op, DAG.getConstant(1, DL, VT));
    }
  }

  // fold (srl x, (trunc (and y, c))) -> (srl x, (and (trunc y), (trunc c))).
  if (N1.getOpcode() == ISD::TRUNCATE &&
      N1.getOperand(0).getOpcode() == ISD::AND) {
    if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
      return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0, NewOp1);
  }

  // fold operands of srl based on knowledge that the low bits are not
  // demanded.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  if (N1C && !N1C->isOpaque())
    if (SDValue NewSRL = visitShiftByConstant(N))
      return NewSRL;

  // Attempt to convert a srl of a load into a narrower zero-extending load.
  if (SDValue NarrowLoad = ReduceLoadWidth(N))
    return NarrowLoad;

  // Here is a common situation. We want to optimize:
  //
  //   %a = ...
  //   %b = and i32 %a, 2
  //   %c = srl i32 %b, 1
  //   brcond i32 %c ...
  //
  // into
  //
  //   %a = ...
  //   %b = and %a, 2
  //   %c = setcc eq %b, 0
  //   brcond %c ...
  //
  // However, after the source operand of the SRL is optimized into an AND, the
  // SRL itself may not be optimized further. Look for it and add the BRCOND
  // into the worklist.
  if (N->hasOneUse()) {
    SDNode *Use = *N->use_begin();
    if (Use->getOpcode() == ISD::BRCOND)
      AddToWorklist(Use);
    else if (Use->getOpcode() == ISD::TRUNCATE && Use->hasOneUse()) {
      // Also look past the truncate.
      Use = *Use->use_begin();
      if (Use->getOpcode() == ISD::BRCOND)
        AddToWorklist(Use);
    }
  }

  // Try to transform this shift into a multiply-high if
  // it matches the appropriate pattern detected in combineShiftToMULH.
  if (SDValue MULH = combineShiftToMULH(N, DAG, TLI))
    return MULH;

  return SDValue();
}

SDValue DAGCombiner::visitFunnelShift(SDNode *N) {
  EVT VT = N->getValueType(0);
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue N2 = N->getOperand(2);
  bool IsFSHL = N->getOpcode() == ISD::FSHL;
  unsigned BitWidth = VT.getScalarSizeInBits();

  // fold (fshl N0, N1, 0) -> N0
  // fold (fshr N0, N1, 0) -> N1
  if (isPowerOf2_32(BitWidth))
    if (DAG.MaskedValueIsZero(
            N2, APInt(N2.getScalarValueSizeInBits(), BitWidth - 1)))
      return IsFSHL ? N0 : N1;

  auto IsUndefOrZero = [](SDValue V) {
    return V.isUndef() || isNullOrNullSplat(V, /*AllowUndefs*/ true);
  };

  // TODO - support non-uniform vector shift amounts.
  if (ConstantSDNode *Cst = isConstOrConstSplat(N2)) {
    EVT ShAmtTy = N2.getValueType();

    // fold (fsh* N0, N1, c) -> (fsh* N0, N1, c % BitWidth)
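    // e.g. i8: fshl N0, N1, 11 --> fshl N0, N1, 3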
9007     if (Cst->getAPIntValue().uge(BitWidth)) {
9008       uint64_t RotAmt = Cst->getAPIntValue().urem(BitWidth);
9009       return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N0, N1,
9010                          DAG.getConstant(RotAmt, SDLoc(N), ShAmtTy));
9011     }
9012 
9013     unsigned ShAmt = Cst->getZExtValue();
9014     if (ShAmt == 0)
9015       return IsFSHL ? N0 : N1;
9016 
9017     // fold fshl(undef_or_zero, N1, C) -> lshr(N1, BW-C)
9018     // fold fshr(undef_or_zero, N1, C) -> lshr(N1, C)
9019     // fold fshl(N0, undef_or_zero, C) -> shl(N0, C)
9020     // fold fshr(N0, undef_or_zero, C) -> shl(N0, BW-C)
9021     if (IsUndefOrZero(N0))
9022       return DAG.getNode(ISD::SRL, SDLoc(N), VT, N1,
9023                          DAG.getConstant(IsFSHL ? BitWidth - ShAmt : ShAmt,
9024                                          SDLoc(N), ShAmtTy));
9025     if (IsUndefOrZero(N1))
9026       return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0,
9027                          DAG.getConstant(IsFSHL ? ShAmt : BitWidth - ShAmt,
9028                                          SDLoc(N), ShAmtTy));
9029 
9030     // fold (fshl ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
9031     // fold (fshr ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
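    // For example (little-endian): if ld0 and ld1 are adjacent i32 loads,
    // (fshr ld1, ld0, 8) is the unaligned i32 value one byte past ld0, so it
    // can be replaced with a single load at that offset.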
9032     // TODO - bigendian support once we have test coverage.
9033     // TODO - can we merge this with CombineConsecutiveLoads/MatchLoadCombine?
9034     // TODO - permit LHS EXTLOAD if extensions are shifted out.
9035     if ((BitWidth % 8) == 0 && (ShAmt % 8) == 0 && !VT.isVector() &&
9036         !DAG.getDataLayout().isBigEndian()) {
9037       auto *LHS = dyn_cast<LoadSDNode>(N0);
9038       auto *RHS = dyn_cast<LoadSDNode>(N1);
9039       if (LHS && RHS && LHS->isSimple() && RHS->isSimple() &&
9040           LHS->getAddressSpace() == RHS->getAddressSpace() &&
9041           (LHS->hasOneUse() || RHS->hasOneUse()) && ISD::isNON_EXTLoad(RHS) &&
9042           ISD::isNON_EXTLoad(LHS)) {
9043         if (DAG.areNonVolatileConsecutiveLoads(LHS, RHS, BitWidth / 8, 1)) {
9044           SDLoc DL(RHS);
9045           uint64_t PtrOff =
9046               IsFSHL ? (((BitWidth - ShAmt) % BitWidth) / 8) : (ShAmt / 8);
9047           Align NewAlign = commonAlignment(RHS->getAlign(), PtrOff);
9048           bool Fast = false;
9049           if (TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
9050                                      RHS->getAddressSpace(), NewAlign,
9051                                      RHS->getMemOperand()->getFlags(), &Fast) &&
9052               Fast) {
9053             SDValue NewPtr = DAG.getMemBasePlusOffset(
9054                 RHS->getBasePtr(), TypeSize::Fixed(PtrOff), DL);
9055             AddToWorklist(NewPtr.getNode());
9056             SDValue Load = DAG.getLoad(
9057                 VT, DL, RHS->getChain(), NewPtr,
9058                 RHS->getPointerInfo().getWithOffset(PtrOff), NewAlign,
9059                 RHS->getMemOperand()->getFlags(), RHS->getAAInfo());
9060             // Replace the old load's chain with the new load's chain.
9061             WorklistRemover DeadNodes(*this);
9062             DAG.ReplaceAllUsesOfValueWith(N1.getValue(1), Load.getValue(1));
9063             return Load;
9064           }
9065         }
9066       }
9067     }
9068   }
9069 
9070   // fold fshr(undef_or_zero, N1, N2) -> lshr(N1, N2)
9071   // fold fshl(N0, undef_or_zero, N2) -> shl(N0, N2)
9072   // iff we know the shift amount is in range.
9073   // TODO: when is it worth doing SUB(BW, N2) as well?
9074   if (isPowerOf2_32(BitWidth)) {
9075     APInt ModuloBits(N2.getScalarValueSizeInBits(), BitWidth - 1);
9076     if (IsUndefOrZero(N0) && !IsFSHL && DAG.MaskedValueIsZero(N2, ~ModuloBits))
9077       return DAG.getNode(ISD::SRL, SDLoc(N), VT, N1, N2);
9078     if (IsUndefOrZero(N1) && IsFSHL && DAG.MaskedValueIsZero(N2, ~ModuloBits))
9079       return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, N2);
9080   }
9081 
9082   // fold (fshl N0, N0, N2) -> (rotl N0, N2)
9083   // fold (fshr N0, N0, N2) -> (rotr N0, N2)
9084   // TODO: Investigate flipping this rotate if only one is legal; if funnel shift
9085   // is legal too, we might be better off avoiding the non-constant (BW - N2).
9086   unsigned RotOpc = IsFSHL ? ISD::ROTL : ISD::ROTR;
9087   if (N0 == N1 && hasOperation(RotOpc, VT))
9088     return DAG.getNode(RotOpc, SDLoc(N), VT, N0, N2);
9089 
9090   // Simplify, based on bits shifted out of N0/N1.
9091   if (SimplifyDemandedBits(SDValue(N, 0)))
9092     return SDValue(N, 0);
9093 
9094   return SDValue();
9095 }
9096 
9097 // Given an ABS node, detect the following pattern:
9098 // (ABS (SUB (EXTEND a), (EXTEND b))).
9099 // Generate a UABD/SABD instruction.
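// For example (illustrative types):
//   (abs (sub (zext i8 a to i32), (zext i8 b to i32))) --> (zext (abdu a, b))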
9100 static SDValue combineABSToABD(SDNode *N, SelectionDAG &DAG,
9101                                const TargetLowering &TLI) {
9102   SDValue AbsOp1 = N->getOperand(0);
9103   SDValue Op0, Op1;
9104 
9105   if (AbsOp1.getOpcode() != ISD::SUB)
9106     return SDValue();
9107 
9108   Op0 = AbsOp1.getOperand(0);
9109   Op1 = AbsOp1.getOperand(1);
9110 
9111   unsigned Opc0 = Op0.getOpcode();
9112   // Check if the operands of the sub are (zero|sign)-extended.
9113   if (Opc0 != Op1.getOpcode() ||
9114       (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND))
9115     return SDValue();
9116 
9117   EVT VT1 = Op0.getOperand(0).getValueType();
9118   EVT VT2 = Op1.getOperand(0).getValueType();
9119   // Check if the operands are of the same type and a valid size.
9120   unsigned ABDOpcode = (Opc0 == ISD::SIGN_EXTEND) ? ISD::ABDS : ISD::ABDU;
9121   if (VT1 != VT2 || !TLI.isOperationLegalOrCustom(ABDOpcode, VT1))
9122     return SDValue();
9123 
9124   Op0 = Op0.getOperand(0);
9125   Op1 = Op1.getOperand(0);
9126   SDValue ABD =
9127       DAG.getNode(ABDOpcode, SDLoc(N), Op0->getValueType(0), Op0, Op1);
9128   return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), ABD);
9129 }
9130 
9131 SDValue DAGCombiner::visitABS(SDNode *N) {
9132   SDValue N0 = N->getOperand(0);
9133   EVT VT = N->getValueType(0);
9134 
9135   // fold (abs c1) -> c2
9136   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
9137     return DAG.getNode(ISD::ABS, SDLoc(N), VT, N0);
9138   // fold (abs (abs x)) -> (abs x)
9139   if (N0.getOpcode() == ISD::ABS)
9140     return N0;
9141   // fold (abs x) -> x iff not-negative
9142   if (DAG.SignBitIsZero(N0))
9143     return N0;
9144 
9145   if (SDValue ABD = combineABSToABD(N, DAG, TLI))
9146     return ABD;
9147 
9148   return SDValue();
9149 }
9150 
9151 SDValue DAGCombiner::visitBSWAP(SDNode *N) {
9152   SDValue N0 = N->getOperand(0);
9153   EVT VT = N->getValueType(0);
9154 
9155   // fold (bswap c1) -> c2
9156   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
9157     return DAG.getNode(ISD::BSWAP, SDLoc(N), VT, N0);
9158   // fold (bswap (bswap x)) -> x
9159   if (N0.getOpcode() == ISD::BSWAP)
9160     return N0->getOperand(0);
9161   return SDValue();
9162 }
9163 
9164 SDValue DAGCombiner::visitBITREVERSE(SDNode *N) {
9165   SDValue N0 = N->getOperand(0);
9166   EVT VT = N->getValueType(0);
9167 
9168   // fold (bitreverse c1) -> c2
9169   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
9170     return DAG.getNode(ISD::BITREVERSE, SDLoc(N), VT, N0);
9171   // fold (bitreverse (bitreverse x)) -> x
9172   if (N0.getOpcode() == ISD::BITREVERSE)
9173     return N0.getOperand(0);
9174   return SDValue();
9175 }
9176 
9177 SDValue DAGCombiner::visitCTLZ(SDNode *N) {
9178   SDValue N0 = N->getOperand(0);
9179   EVT VT = N->getValueType(0);
9180 
9181   // fold (ctlz c1) -> c2
9182   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
9183     return DAG.getNode(ISD::CTLZ, SDLoc(N), VT, N0);
9184 
9185   // If the value is known never to be zero, switch to the undef version.
9186   if (!LegalOperations || TLI.isOperationLegal(ISD::CTLZ_ZERO_UNDEF, VT)) {
9187     if (DAG.isKnownNeverZero(N0))
9188       return DAG.getNode(ISD::CTLZ_ZERO_UNDEF, SDLoc(N), VT, N0);
9189   }
9190 
9191   return SDValue();
9192 }
9193 
9194 SDValue DAGCombiner::visitCTLZ_ZERO_UNDEF(SDNode *N) {
9195   SDValue N0 = N->getOperand(0);
9196   EVT VT = N->getValueType(0);
9197 
9198   // fold (ctlz_zero_undef c1) -> c2
9199   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
9200     return DAG.getNode(ISD::CTLZ_ZERO_UNDEF, SDLoc(N), VT, N0);
9201   return SDValue();
9202 }
9203 
9204 SDValue DAGCombiner::visitCTTZ(SDNode *N) {
9205   SDValue N0 = N->getOperand(0);
9206   EVT VT = N->getValueType(0);
9207 
9208   // fold (cttz c1) -> c2
9209   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
9210     return DAG.getNode(ISD::CTTZ, SDLoc(N), VT, N0);
9211 
9212   // If the value is known never to be zero, switch to the undef version.
9213   if (!LegalOperations || TLI.isOperationLegal(ISD::CTTZ_ZERO_UNDEF, VT)) {
9214     if (DAG.isKnownNeverZero(N0))
9215       return DAG.getNode(ISD::CTTZ_ZERO_UNDEF, SDLoc(N), VT, N0);
9216   }
9217 
9218   return SDValue();
9219 }
9220 
9221 SDValue DAGCombiner::visitCTTZ_ZERO_UNDEF(SDNode *N) {
9222   SDValue N0 = N->getOperand(0);
9223   EVT VT = N->getValueType(0);
9224 
9225   // fold (cttz_zero_undef c1) -> c2
9226   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
9227     return DAG.getNode(ISD::CTTZ_ZERO_UNDEF, SDLoc(N), VT, N0);
9228   return SDValue();
9229 }
9230 
9231 SDValue DAGCombiner::visitCTPOP(SDNode *N) {
9232   SDValue N0 = N->getOperand(0);
9233   EVT VT = N->getValueType(0);
9234 
9235   // fold (ctpop c1) -> c2
9236   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
9237     return DAG.getNode(ISD::CTPOP, SDLoc(N), VT, N0);
9238   return SDValue();
9239 }
9240 
9241 // FIXME: This should be checking for no signed zeros on individual operands, as
9242 // well as no nans.
9243 static bool isLegalToCombineMinNumMaxNum(SelectionDAG &DAG, SDValue LHS,
9244                                          SDValue RHS,
9245                                          const TargetLowering &TLI) {
9246   const TargetOptions &Options = DAG.getTarget().Options;
9247   EVT VT = LHS.getValueType();
9248 
9249   return Options.NoSignedZerosFPMath && VT.isFloatingPoint() &&
9250          TLI.isProfitableToCombineMinNumMaxNum(VT) &&
9251          DAG.isKnownNeverNaN(LHS) && DAG.isKnownNeverNaN(RHS);
9252 }
9253 
9254 /// Generate Min/Max node
9255 static SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
9256                                    SDValue RHS, SDValue True, SDValue False,
9257                                    ISD::CondCode CC, const TargetLowering &TLI,
9258                                    SelectionDAG &DAG) {
9259   if (!(LHS == True && RHS == False) && !(LHS == False && RHS == True))
9260     return SDValue();
9261 
9262   EVT TransformVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
9263   switch (CC) {
9264   case ISD::SETOLT:
9265   case ISD::SETOLE:
9266   case ISD::SETLT:
9267   case ISD::SETLE:
9268   case ISD::SETULT:
9269   case ISD::SETULE: {
9270     // Since the operands are already known never to be NaN here, either fminnum
9271     // or fminnum_ieee is OK. Try the IEEE version first, since fminnum is
9272     // expanded in terms of it.
9273     unsigned IEEEOpcode = (LHS == True) ? ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE;
9274     if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT))
9275       return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS);
9276 
9277     unsigned Opcode = (LHS == True) ? ISD::FMINNUM : ISD::FMAXNUM;
9278     if (TLI.isOperationLegalOrCustom(Opcode, TransformVT))
9279       return DAG.getNode(Opcode, DL, VT, LHS, RHS);
9280     return SDValue();
9281   }
9282   case ISD::SETOGT:
9283   case ISD::SETOGE:
9284   case ISD::SETGT:
9285   case ISD::SETGE:
9286   case ISD::SETUGT:
9287   case ISD::SETUGE: {
9288     unsigned IEEEOpcode = (LHS == True) ? ISD::FMAXNUM_IEEE : ISD::FMINNUM_IEEE;
9289     if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT))
9290       return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS);
9291 
9292     unsigned Opcode = (LHS == True) ? ISD::FMAXNUM : ISD::FMINNUM;
9293     if (TLI.isOperationLegalOrCustom(Opcode, TransformVT))
9294       return DAG.getNode(Opcode, DL, VT, LHS, RHS);
9295     return SDValue();
9296   }
9297   default:
9298     return SDValue();
9299   }
9300 }
9301 
9302 /// If a (v)select has a condition value that is a sign-bit test, try to smear
9303 /// the condition operand sign-bit across the value width and use it as a mask.
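/// For example, for i8 the smeared mask is (X >>s 7), which is all-zeros when
/// X is non-negative and all-ones when X is negative.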
9304 static SDValue foldSelectOfConstantsUsingSra(SDNode *N, SelectionDAG &DAG) {
9305   SDValue Cond = N->getOperand(0);
9306   SDValue C1 = N->getOperand(1);
9307   SDValue C2 = N->getOperand(2);
9308   if (!isConstantOrConstantVector(C1) || !isConstantOrConstantVector(C2))
9309     return SDValue();
9310 
9311   EVT VT = N->getValueType(0);
9312   if (Cond.getOpcode() != ISD::SETCC || !Cond.hasOneUse() ||
9313       VT != Cond.getOperand(0).getValueType())
9314     return SDValue();
9315 
9316   // The inverted-condition + commuted-select variants of these patterns are
9317   // canonicalized to these forms in IR.
9318   SDValue X = Cond.getOperand(0);
9319   SDValue CondC = Cond.getOperand(1);
9320   ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
9321   if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(CondC) &&
9322       isAllOnesOrAllOnesSplat(C2)) {
9323     // i32 X > -1 ? C1 : -1 --> (X >>s 31) | C1
9324     SDLoc DL(N);
9325     SDValue ShAmtC = DAG.getConstant(X.getScalarValueSizeInBits() - 1, DL, VT);
9326     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, X, ShAmtC);
9327     return DAG.getNode(ISD::OR, DL, VT, Sra, C1);
9328   }
9329   if (CC == ISD::SETLT && isNullOrNullSplat(CondC) && isNullOrNullSplat(C2)) {
9330     // i8 X < 0 ? C1 : 0 --> (X >>s 7) & C1
9331     SDLoc DL(N);
9332     SDValue ShAmtC = DAG.getConstant(X.getScalarValueSizeInBits() - 1, DL, VT);
9333     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, X, ShAmtC);
9334     return DAG.getNode(ISD::AND, DL, VT, Sra, C1);
9335   }
9336   return SDValue();
9337 }
9338 
9339 SDValue DAGCombiner::foldSelectOfConstants(SDNode *N) {
9340   SDValue Cond = N->getOperand(0);
9341   SDValue N1 = N->getOperand(1);
9342   SDValue N2 = N->getOperand(2);
9343   EVT VT = N->getValueType(0);
9344   EVT CondVT = Cond.getValueType();
9345   SDLoc DL(N);
9346 
9347   if (!VT.isInteger())
9348     return SDValue();
9349 
9350   auto *C1 = dyn_cast<ConstantSDNode>(N1);
9351   auto *C2 = dyn_cast<ConstantSDNode>(N2);
9352   if (!C1 || !C2)
9353     return SDValue();
9354 
9355   // Only do this before legalization to avoid conflicting with target-specific
9356   // transforms in the other direction (create a select from a zext/sext). There
9357   // is also a target-independent combine here in DAGCombiner in the other
9358   // direction for (select Cond, -1, 0) when the condition is not i1.
9359   if (CondVT == MVT::i1 && !LegalOperations) {
9360     if (C1->isNullValue() && C2->isOne()) {
9361       // select Cond, 0, 1 --> zext (!Cond)
9362       SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
9363       if (VT != MVT::i1)
9364         NotCond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, NotCond);
9365       return NotCond;
9366     }
9367     if (C1->isNullValue() && C2->isAllOnesValue()) {
9368       // select Cond, 0, -1 --> sext (!Cond)
9369       SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
9370       if (VT != MVT::i1)
9371         NotCond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, NotCond);
9372       return NotCond;
9373     }
9374     if (C1->isOne() && C2->isNullValue()) {
9375       // select Cond, 1, 0 --> zext (Cond)
9376       if (VT != MVT::i1)
9377         Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Cond);
9378       return Cond;
9379     }
9380     if (C1->isAllOnesValue() && C2->isNullValue()) {
9381       // select Cond, -1, 0 --> sext (Cond)
9382       if (VT != MVT::i1)
9383         Cond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Cond);
9384       return Cond;
9385     }
9386 
9387     // Use a target hook because some targets may prefer to transform in the
9388     // other direction.
9389     if (TLI.convertSelectOfConstantsToMath(VT)) {
9390       // For any constants that differ by 1, we can transform the select into an
9391       // extend and add.
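      // e.g. (select Cond, 5, 4) --> (add (zext Cond), 4)
      //      (select Cond, 4, 5) --> (add (sext Cond), 5)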
9392       const APInt &C1Val = C1->getAPIntValue();
9393       const APInt &C2Val = C2->getAPIntValue();
9394       if (C1Val - 1 == C2Val) {
9395         // select Cond, C1, C1-1 --> add (zext Cond), C1-1
9396         if (VT != MVT::i1)
9397           Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Cond);
9398         return DAG.getNode(ISD::ADD, DL, VT, Cond, N2);
9399       }
9400       if (C1Val + 1 == C2Val) {
9401         // select Cond, C1, C1+1 --> add (sext Cond), C1+1
9402         if (VT != MVT::i1)
9403           Cond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Cond);
9404         return DAG.getNode(ISD::ADD, DL, VT, Cond, N2);
9405       }
9406 
9407       // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2)
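      // e.g. (select Cond, 8, 0) --> (shl (zext Cond), 3)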
9408       if (C1Val.isPowerOf2() && C2Val.isNullValue()) {
9409         if (VT != MVT::i1)
9410           Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Cond);
9411         SDValue ShAmtC = DAG.getConstant(C1Val.exactLogBase2(), DL, VT);
9412         return DAG.getNode(ISD::SHL, DL, VT, Cond, ShAmtC);
9413       }
9414 
9415       if (SDValue V = foldSelectOfConstantsUsingSra(N, DAG))
9416         return V;
9417     }
9418 
9419     return SDValue();
9420   }
9421 
9422   // fold (select Cond, 0, 1) -> (xor Cond, 1)
9423   // We can't do this reliably if integer based booleans have different contents
9424   // to floating point based booleans. This is because we can't tell whether we
9425   // have an integer-based boolean or a floating-point-based boolean unless we
9426   // can find the SETCC that produced it and inspect its operands. This is
9427   // fairly easy if C is the SETCC node, but it can potentially be
9428   // undiscoverable (or not reasonably discoverable). For example, it could be
9429   // in another basic block or it could require searching a complicated
9430   // expression.
9431   if (CondVT.isInteger() &&
9432       TLI.getBooleanContents(/*isVec*/false, /*isFloat*/true) ==
9433           TargetLowering::ZeroOrOneBooleanContent &&
9434       TLI.getBooleanContents(/*isVec*/false, /*isFloat*/false) ==
9435           TargetLowering::ZeroOrOneBooleanContent &&
9436       C1->isNullValue() && C2->isOne()) {
9437     SDValue NotCond =
9438         DAG.getNode(ISD::XOR, DL, CondVT, Cond, DAG.getConstant(1, DL, CondVT));
9439     if (VT.bitsEq(CondVT))
9440       return NotCond;
9441     return DAG.getZExtOrTrunc(NotCond, DL, VT);
9442   }
9443 
9444   return SDValue();
9445 }
9446 
9447 static SDValue foldBoolSelectToLogic(SDNode *N, SelectionDAG &DAG) {
9448   assert((N->getOpcode() == ISD::SELECT || N->getOpcode() == ISD::VSELECT) &&
9449          "Expected a (v)select");
9450   SDValue Cond = N->getOperand(0);
9451   SDValue T = N->getOperand(1), F = N->getOperand(2);
9452   EVT VT = N->getValueType(0);
9453   if (VT != Cond.getValueType() || VT.getScalarSizeInBits() != 1)
9454     return SDValue();
9455 
9456   // select Cond, Cond, F --> or Cond, F
9457   // select Cond, 1, F    --> or Cond, F
9458   if (Cond == T || isOneOrOneSplat(T, /* AllowUndefs */ true))
9459     return DAG.getNode(ISD::OR, SDLoc(N), VT, Cond, F);
9460 
9461   // select Cond, T, Cond --> and Cond, T
9462   // select Cond, T, 0    --> and Cond, T
9463   if (Cond == F || isNullOrNullSplat(F, /* AllowUndefs */ true))
9464     return DAG.getNode(ISD::AND, SDLoc(N), VT, Cond, T);
9465 
9466   // select Cond, T, 1 --> or (not Cond), T
9467   if (isOneOrOneSplat(F, /* AllowUndefs */ true)) {
9468     SDValue NotCond = DAG.getNOT(SDLoc(N), Cond, VT);
9469     return DAG.getNode(ISD::OR, SDLoc(N), VT, NotCond, T);
9470   }
9471 
9472   // select Cond, 0, F --> and (not Cond), F
9473   if (isNullOrNullSplat(T, /* AllowUndefs */ true)) {
9474     SDValue NotCond = DAG.getNOT(SDLoc(N), Cond, VT);
9475     return DAG.getNode(ISD::AND, SDLoc(N), VT, NotCond, F);
9476   }
9477 
9478   return SDValue();
9479 }
9480 
9481 SDValue DAGCombiner::visitSELECT(SDNode *N) {
9482   SDValue N0 = N->getOperand(0);
9483   SDValue N1 = N->getOperand(1);
9484   SDValue N2 = N->getOperand(2);
9485   EVT VT = N->getValueType(0);
9486   EVT VT0 = N0.getValueType();
9487   SDLoc DL(N);
9488   SDNodeFlags Flags = N->getFlags();
9489 
9490   if (SDValue V = DAG.simplifySelect(N0, N1, N2))
9491     return V;
9492 
9493   if (SDValue V = foldSelectOfConstants(N))
9494     return V;
9495 
9496   if (SDValue V = foldBoolSelectToLogic(N, DAG))
9497     return V;
9498 
9499   // If we can fold this based on the true/false value, do so.
9500   if (SimplifySelectOps(N, N1, N2))
9501     return SDValue(N, 0); // Don't revisit N.
9502 
9503   if (VT0 == MVT::i1) {
9504     // The code in this block deals with the following 2 equivalences:
9505     //    select(C0|C1, x, y) <=> select(C0, x, select(C1, x, y))
9506     //    select(C0&C1, x, y) <=> select(C0, select(C1, x, y), y)
9507     // The target can specify its preferred form with the
9508     // shouldNormalizeToSelectSequence() callback. However, we always transform
9509     // to the right-hand form if the inner select already exists in the DAG, and
9510     // we always transform to the left-hand form if we know that we can further
9511     // optimize the combination of the conditions.
9512     bool normalizeToSequence =
9513         TLI.shouldNormalizeToSelectSequence(*DAG.getContext(), VT);
9514     // select (and Cond0, Cond1), X, Y
9515     //   -> select Cond0, (select Cond1, X, Y), Y
9516     if (N0->getOpcode() == ISD::AND && N0->hasOneUse()) {
9517       SDValue Cond0 = N0->getOperand(0);
9518       SDValue Cond1 = N0->getOperand(1);
9519       SDValue InnerSelect =
9520           DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond1, N1, N2, Flags);
9521       if (normalizeToSequence || !InnerSelect.use_empty())
9522         return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0,
9523                            InnerSelect, N2, Flags);
9524       // Cleanup on failure.
9525       if (InnerSelect.use_empty())
9526         recursivelyDeleteUnusedNodes(InnerSelect.getNode());
9527     }
9528     // select (or Cond0, Cond1), X, Y -> select Cond0, X, (select Cond1, X, Y)
9529     if (N0->getOpcode() == ISD::OR && N0->hasOneUse()) {
9530       SDValue Cond0 = N0->getOperand(0);
9531       SDValue Cond1 = N0->getOperand(1);
9532       SDValue InnerSelect = DAG.getNode(ISD::SELECT, DL, N1.getValueType(),
9533                                         Cond1, N1, N2, Flags);
9534       if (normalizeToSequence || !InnerSelect.use_empty())
9535         return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0, N1,
9536                            InnerSelect, Flags);
9537       // Cleanup on failure.
9538       if (InnerSelect.use_empty())
9539         recursivelyDeleteUnusedNodes(InnerSelect.getNode());
9540     }
9541 
9542     // select Cond0, (select Cond1, X, Y), Y -> select (and Cond0, Cond1), X, Y
9543     if (N1->getOpcode() == ISD::SELECT && N1->hasOneUse()) {
9544       SDValue N1_0 = N1->getOperand(0);
9545       SDValue N1_1 = N1->getOperand(1);
9546       SDValue N1_2 = N1->getOperand(2);
9547       if (N1_2 == N2 && N0.getValueType() == N1_0.getValueType()) {
9548         // Create the actual and node if we can generate good code for it.
9549         if (!normalizeToSequence) {
9550           SDValue And = DAG.getNode(ISD::AND, DL, N0.getValueType(), N0, N1_0);
9551           return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), And, N1_1,
9552                              N2, Flags);
9553         }
9554         // Otherwise see if we can optimize the "and" to a better pattern.
9555         if (SDValue Combined = visitANDLike(N0, N1_0, N)) {
9556           return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1_1,
9557                              N2, Flags);
9558         }
9559       }
9560     }
9561     // select Cond0, X, (select Cond1, X, Y) -> select (or Cond0, Cond1), X, Y
9562     if (N2->getOpcode() == ISD::SELECT && N2->hasOneUse()) {
9563       SDValue N2_0 = N2->getOperand(0);
9564       SDValue N2_1 = N2->getOperand(1);
9565       SDValue N2_2 = N2->getOperand(2);
9566       if (N2_1 == N1 && N0.getValueType() == N2_0.getValueType()) {
9567         // Create the actual or node if we can generate good code for it.
9568         if (!normalizeToSequence) {
9569           SDValue Or = DAG.getNode(ISD::OR, DL, N0.getValueType(), N0, N2_0);
9570           return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Or, N1,
9571                              N2_2, Flags);
9572         }
9573         // Otherwise see if we can optimize to a better pattern.
9574         if (SDValue Combined = visitORLike(N0, N2_0, N))
9575           return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1,
9576                              N2_2, Flags);
9577       }
9578     }
9579   }
9580 
9581   // select (not Cond), N1, N2 -> select Cond, N2, N1
9582   if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false)) {
9583     SDValue SelectOp = DAG.getSelect(DL, VT, F, N2, N1);
9584     SelectOp->setFlags(Flags);
9585     return SelectOp;
9586   }
9587 
9588   // Fold selects based on a setcc into other things, such as min/max/abs.
9589   if (N0.getOpcode() == ISD::SETCC) {
9590     SDValue Cond0 = N0.getOperand(0), Cond1 = N0.getOperand(1);
9591     ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
9592 
9593     // select (fcmp lt x, y), x, y -> fminnum x, y
9594     // select (fcmp gt x, y), x, y -> fmaxnum x, y
9595     //
9596     // This is OK if we don't care what happens if either operand is a NaN.
9597     if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, N1, N2, TLI))
9598       if (SDValue FMinMax = combineMinNumMaxNum(DL, VT, Cond0, Cond1, N1, N2,
9599                                                 CC, TLI, DAG))
9600         return FMinMax;
9601 
9602     // Use 'unsigned add with overflow' to optimize an unsigned saturating add.
9603     // This is conservatively limited to pre-legal-operations to give targets
9604     // a chance to reverse the transform if they want to do that. Also, it is
9605     // unlikely that the pattern would be formed late, so it's probably not
9606     // worth going through the other checks.
9607     if (!LegalOperations && TLI.isOperationLegalOrCustom(ISD::UADDO, VT) &&
9608         CC == ISD::SETUGT && N0.hasOneUse() && isAllOnesConstant(N1) &&
9609         N2.getOpcode() == ISD::ADD && Cond0 == N2.getOperand(0)) {
9610       auto *C = dyn_cast<ConstantSDNode>(N2.getOperand(1));
9611       auto *NotC = dyn_cast<ConstantSDNode>(Cond1);
9612       if (C && NotC && C->getAPIntValue() == ~NotC->getAPIntValue()) {
9613         // select (setcc Cond0, ~C, ugt), -1, (add Cond0, C) -->
9614         // uaddo Cond0, C; select uaddo.1, -1, uaddo.0
9615         //
9616         // The IR equivalent of this transform would have this form:
9617         //   %a = add %x, C
9618         //   %c = icmp ugt %x, ~C
9619         //   %r = select %c, -1, %a
9620         //   =>
9621         //   %u = call {iN,i1} llvm.uadd.with.overflow(%x, C)
9622         //   %u0 = extractvalue %u, 0
9623         //   %u1 = extractvalue %u, 1
9624         //   %r = select %u1, -1, %u0
9625         SDVTList VTs = DAG.getVTList(VT, VT0);
9626         SDValue UAO = DAG.getNode(ISD::UADDO, DL, VTs, Cond0, N2.getOperand(1));
9627         return DAG.getSelect(DL, VT, UAO.getValue(1), N1, UAO.getValue(0));
9628       }
9629     }
9630 
9631     if (TLI.isOperationLegal(ISD::SELECT_CC, VT) ||
9632         (!LegalOperations &&
9633          TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT))) {
9634       // Any flags available in a select/setcc fold will be on the setcc as they
9635       // migrated from fcmp
9636       Flags = N0.getNode()->getFlags();
9637       SDValue SelectNode = DAG.getNode(ISD::SELECT_CC, DL, VT, Cond0, Cond1, N1,
9638                                        N2, N0.getOperand(2));
9639       SelectNode->setFlags(Flags);
9640       return SelectNode;
9641     }
9642 
9643     if (SDValue NewSel = SimplifySelect(DL, N0, N1, N2))
9644       return NewSel;
9645   }
9646 
9647   if (!VT.isVector())
9648     if (SDValue BinOp = foldSelectOfBinops(N))
9649       return BinOp;
9650 
9651   return SDValue();
9652 }
9653 
9654 // This function assumes all the vselect's arguments are CONCAT_VECTOR
9655 // nodes and that the condition is a BV of ConstantSDNodes (or undefs).
9656 static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
9657   SDLoc DL(N);
9658   SDValue Cond = N->getOperand(0);
9659   SDValue LHS = N->getOperand(1);
9660   SDValue RHS = N->getOperand(2);
9661   EVT VT = N->getValueType(0);
9662   int NumElems = VT.getVectorNumElements();
9663   assert(LHS.getOpcode() == ISD::CONCAT_VECTORS &&
9664          RHS.getOpcode() == ISD::CONCAT_VECTORS &&
9665          Cond.getOpcode() == ISD::BUILD_VECTOR);
9666 
9667   // CONCAT_VECTORS can take an arbitrary number of arguments. We only care
9668   // about binary (two-operand) ones here.
9669   if (LHS->getNumOperands() != 2 || RHS->getNumOperands() != 2)
9670     return SDValue();
9671 
9672   // We're sure we have an even number of elements due to the
9673   // concat_vectors we have as arguments to vselect.
9674   // Skip BV elements until we find one that's not an UNDEF.
9675   // After we find a non-UNDEF element, keep looping until we get to half the
9676   // length of the BV and check that all the non-undef elements are the same.
9677   ConstantSDNode *BottomHalf = nullptr;
9678   for (int i = 0; i < NumElems / 2; ++i) {
9679     if (Cond->getOperand(i)->isUndef())
9680       continue;
9681 
9682     if (BottomHalf == nullptr)
9683       BottomHalf = cast<ConstantSDNode>(Cond.getOperand(i));
9684     else if (Cond->getOperand(i).getNode() != BottomHalf)
9685       return SDValue();
9686   }
9687 
9688   // Do the same for the second half of the BuildVector
9689   ConstantSDNode *TopHalf = nullptr;
9690   for (int i = NumElems / 2; i < NumElems; ++i) {
9691     if (Cond->getOperand(i)->isUndef())
9692       continue;
9693 
9694     if (TopHalf == nullptr)
9695       TopHalf = cast<ConstantSDNode>(Cond.getOperand(i));
9696     else if (Cond->getOperand(i).getNode() != TopHalf)
9697       return SDValue();
9698   }
9699 
9700   assert(TopHalf && BottomHalf &&
9701          "One half of the selector was all UNDEFs and the other was all the "
9702          "same value. This should have been addressed before this function.");
9703   return DAG.getNode(
9704       ISD::CONCAT_VECTORS, DL, VT,
9705       BottomHalf->isNullValue() ? RHS->getOperand(0) : LHS->getOperand(0),
9706       TopHalf->isNullValue() ? RHS->getOperand(1) : LHS->getOperand(1));
9707 }
9708 
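// For example (illustrative): a gather/scatter with BasePtr = 0 and
// Index = (add (splat %p), %offsets) can instead use BasePtr = %p with
// Index = %offsets.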
9709 bool refineUniformBase(SDValue &BasePtr, SDValue &Index, SelectionDAG &DAG) {
9710   if (!isNullConstant(BasePtr) || Index.getOpcode() != ISD::ADD)
9711     return false;
9712 
9713   // For now we check only the LHS of the add.
9714   SDValue LHS = Index.getOperand(0);
9715   SDValue SplatVal = DAG.getSplatValue(LHS);
9716   if (!SplatVal)
9717     return false;
9718 
9719   BasePtr = SplatVal;
9720   Index = Index.getOperand(1);
9721   return true;
9722 }
9723 
9724 // Fold sext/zext of index into index type.
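// For example (illustrative): an index of (zext <N x i32> %idx to <N x i64>)
// can become %idx with an unsigned index type, when the target says the
// extend is removable.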
9725 bool refineIndexType(MaskedGatherScatterSDNode *MGS, SDValue &Index,
9726                      bool Scaled, SelectionDAG &DAG) {
9727   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
9728 
9729   if (Index.getOpcode() == ISD::ZERO_EXTEND) {
9730     SDValue Op = Index.getOperand(0);
9731     MGS->setIndexType(Scaled ? ISD::UNSIGNED_SCALED : ISD::UNSIGNED_UNSCALED);
9732     if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType())) {
9733       Index = Op;
9734       return true;
9735     }
9736   }
9737 
9738   if (Index.getOpcode() == ISD::SIGN_EXTEND) {
9739     SDValue Op = Index.getOperand(0);
9740     MGS->setIndexType(Scaled ? ISD::SIGNED_SCALED : ISD::SIGNED_UNSCALED);
9741     if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType())) {
9742       Index = Op;
9743       return true;
9744     }
9745   }
9746 
9747   return false;
9748 }
9749 
9750 SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
9751   MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(N);
9752   SDValue Mask = MSC->getMask();
9753   SDValue Chain = MSC->getChain();
9754   SDValue Index = MSC->getIndex();
9755   SDValue Scale = MSC->getScale();
9756   SDValue StoreVal = MSC->getValue();
9757   SDValue BasePtr = MSC->getBasePtr();
9758   SDLoc DL(N);
9759 
9760   // Zap scatters with a zero mask.
9761   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
9762     return Chain;
9763 
9764   if (refineUniformBase(BasePtr, Index, DAG)) {
9765     SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
9766     return DAG.getMaskedScatter(
9767         DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL, Ops,
9768         MSC->getMemOperand(), MSC->getIndexType(), MSC->isTruncatingStore());
9769   }
9770 
9771   if (refineIndexType(MSC, Index, MSC->isIndexScaled(), DAG)) {
9772     SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
9773     return DAG.getMaskedScatter(
9774         DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL, Ops,
9775         MSC->getMemOperand(), MSC->getIndexType(), MSC->isTruncatingStore());
9776   }
9777 
9778   return SDValue();
9779 }
9780 
9781 SDValue DAGCombiner::visitMSTORE(SDNode *N) {
9782   MaskedStoreSDNode *MST = cast<MaskedStoreSDNode>(N);
9783   SDValue Mask = MST->getMask();
9784   SDValue Chain = MST->getChain();
9785   SDLoc DL(N);
9786 
9787   // Zap masked stores with a zero mask.
9788   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
9789     return Chain;
9790 
9791   // If this is a masked store with an all ones mask, we can use an unmasked store.
9792   // FIXME: Can we do this for indexed, compressing, or truncating stores?
9793   if (ISD::isConstantSplatVectorAllOnes(Mask.getNode()) &&
9794       MST->isUnindexed() && !MST->isCompressingStore() &&
9795       !MST->isTruncatingStore())
9796     return DAG.getStore(MST->getChain(), SDLoc(N), MST->getValue(),
9797                         MST->getBasePtr(), MST->getMemOperand());
9798 
9799   // Try transforming N to an indexed store.
9800   if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
9801     return SDValue(N, 0);
9802 
9803   return SDValue();
9804 }
9805 
9806 SDValue DAGCombiner::visitMGATHER(SDNode *N) {
9807   MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(N);
9808   SDValue Mask = MGT->getMask();
9809   SDValue Chain = MGT->getChain();
9810   SDValue Index = MGT->getIndex();
9811   SDValue Scale = MGT->getScale();
9812   SDValue PassThru = MGT->getPassThru();
9813   SDValue BasePtr = MGT->getBasePtr();
9814   SDLoc DL(N);
9815 
9816   // Zap gathers with a zero mask.
9817   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
9818     return CombineTo(N, PassThru, MGT->getChain());
9819 
9820   if (refineUniformBase(BasePtr, Index, DAG)) {
9821     SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
9822     return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
9823                                MGT->getMemoryVT(), DL, Ops,
9824                                MGT->getMemOperand(), MGT->getIndexType(),
9825                                MGT->getExtensionType());
9826   }
9827 
9828   if (refineIndexType(MGT, Index, MGT->isIndexScaled(), DAG)) {
9829     SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
9830     return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
9831                                MGT->getMemoryVT(), DL, Ops,
9832                                MGT->getMemOperand(), MGT->getIndexType(),
9833                                MGT->getExtensionType());
9834   }
9835 
9836   return SDValue();
9837 }
9838 
9839 SDValue DAGCombiner::visitMLOAD(SDNode *N) {
9840   MaskedLoadSDNode *MLD = cast<MaskedLoadSDNode>(N);
9841   SDValue Mask = MLD->getMask();
9842   SDLoc DL(N);
9843 
9844   // Zap masked loads with a zero mask.
9845   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
9846     return CombineTo(N, MLD->getPassThru(), MLD->getChain());
9847 
9848   // If this is a masked load with an all ones mask, we can use an unmasked load.
9849   // FIXME: Can we do this for indexed, expanding, or extending loads?
9850   if (ISD::isConstantSplatVectorAllOnes(Mask.getNode()) &&
9851       MLD->isUnindexed() && !MLD->isExpandingLoad() &&
9852       MLD->getExtensionType() == ISD::NON_EXTLOAD) {
9853     SDValue NewLd = DAG.getLoad(N->getValueType(0), SDLoc(N), MLD->getChain(),
9854                                 MLD->getBasePtr(), MLD->getMemOperand());
9855     return CombineTo(N, NewLd, NewLd.getValue(1));
9856   }
9857 
9858   // Try transforming N to an indexed load.
9859   if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
9860     return SDValue(N, 0);
9861 
9862   return SDValue();
9863 }
9864 
9865 /// A vector select of 2 constant vectors can be simplified to math/logic to
9866 /// avoid a variable select instruction and possibly avoid constant loads.
9867 SDValue DAGCombiner::foldVSelectOfConstants(SDNode *N) {
9868   SDValue Cond = N->getOperand(0);
9869   SDValue N1 = N->getOperand(1);
9870   SDValue N2 = N->getOperand(2);
9871   EVT VT = N->getValueType(0);
9872   if (!Cond.hasOneUse() || Cond.getScalarValueSizeInBits() != 1 ||
9873       !TLI.convertSelectOfConstantsToMath(VT) ||
9874       !ISD::isBuildVectorOfConstantSDNodes(N1.getNode()) ||
9875       !ISD::isBuildVectorOfConstantSDNodes(N2.getNode()))
9876     return SDValue();
9877 
9878   // Check if we can use the condition value to increment/decrement a single
9879   // constant value. This simplifies a select to an add and removes a constant
9880   // load/materialization from the general case.
9881   bool AllAddOne = true;
9882   bool AllSubOne = true;
9883   unsigned Elts = VT.getVectorNumElements();
9884   for (unsigned i = 0; i != Elts; ++i) {
9885     SDValue N1Elt = N1.getOperand(i);
9886     SDValue N2Elt = N2.getOperand(i);
9887     if (N1Elt.isUndef() || N2Elt.isUndef())
9888       continue;
9889     if (N1Elt.getValueType() != N2Elt.getValueType())
9890       continue;
9891 
9892     const APInt &C1 = cast<ConstantSDNode>(N1Elt)->getAPIntValue();
9893     const APInt &C2 = cast<ConstantSDNode>(N2Elt)->getAPIntValue();
9894     if (C1 != C2 + 1)
9895       AllAddOne = false;
9896     if (C1 != C2 - 1)
9897       AllSubOne = false;
9898   }
9899 
9900   // Further simplifications for the extra-special cases where the constants are
9901   // all 0 or all -1 should be implemented as folds of these patterns.
9902   SDLoc DL(N);
9903   if (AllAddOne || AllSubOne) {
9904     // vselect <N x i1> Cond, C+1, C --> add (zext Cond), C
9905     // vselect <N x i1> Cond, C-1, C --> add (sext Cond), C
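    // e.g. (vselect Cond, <4,4,4,4>, <3,3,3,3>) --> (add (zext Cond), <3,3,3,3>)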
9906     auto ExtendOpcode = AllAddOne ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND;
9907     SDValue ExtendedCond = DAG.getNode(ExtendOpcode, DL, VT, Cond);
9908     return DAG.getNode(ISD::ADD, DL, VT, ExtendedCond, N2);
9909   }
9910 
9911   // select Cond, Pow2C, 0 --> (zext Cond) << log2(Pow2C)
9912   APInt Pow2C;
9913   if (ISD::isConstantSplatVector(N1.getNode(), Pow2C) && Pow2C.isPowerOf2() &&
9914       isNullOrNullSplat(N2)) {
9915     SDValue ZextCond = DAG.getZExtOrTrunc(Cond, DL, VT);
9916     SDValue ShAmtC = DAG.getConstant(Pow2C.exactLogBase2(), DL, VT);
9917     return DAG.getNode(ISD::SHL, DL, VT, ZextCond, ShAmtC);
9918   }
9919 
9920   if (SDValue V = foldSelectOfConstantsUsingSra(N, DAG))
9921     return V;
9922 
9923   // The general case for select-of-constants:
9924   // vselect <N x i1> Cond, C1, C2 --> xor (and (sext Cond), (C1^C2)), C2
9925   // ...but that only makes sense if a vselect is slower than 2 logic ops, so
9926   // leave that to a machine-specific pass.
9927   return SDValue();
9928 }
9929 
9930 SDValue DAGCombiner::visitVSELECT(SDNode *N) {
9931   SDValue N0 = N->getOperand(0);
9932   SDValue N1 = N->getOperand(1);
9933   SDValue N2 = N->getOperand(2);
9934   EVT VT = N->getValueType(0);
9935   SDLoc DL(N);
9936 
9937   if (SDValue V = DAG.simplifySelect(N0, N1, N2))
9938     return V;
9939 
9940   if (SDValue V = foldBoolSelectToLogic(N, DAG))
9941     return V;
9942 
9943   // vselect (not Cond), N1, N2 -> vselect Cond, N2, N1
9944   if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false))
9945     return DAG.getSelect(DL, VT, F, N2, N1);
9946 
9947   // Canonicalize integer abs.
9948   // vselect (setg[te] X,  0),  X, -X ->
9949   // vselect (setgt    X, -1),  X, -X ->
9950   // vselect (setl[te] X,  0), -X,  X ->
9951   // Y = sra (X, size(X)-1); xor (add (X, Y), Y)
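  // e.g. for i32: Y = sra(X, 31) is 0 or -1, and xor(add(X, Y), Y) yields X
  // when X >= 0 and -X (two's complement negation) when X < 0.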
9952   if (N0.getOpcode() == ISD::SETCC) {
9953     SDValue LHS = N0.getOperand(0), RHS = N0.getOperand(1);
9954     ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
9955     bool isAbs = false;
9956     bool RHSIsAllZeros = ISD::isBuildVectorAllZeros(RHS.getNode());
9957 
9958     if (((RHSIsAllZeros && (CC == ISD::SETGT || CC == ISD::SETGE)) ||
9959          (ISD::isBuildVectorAllOnes(RHS.getNode()) && CC == ISD::SETGT)) &&
9960         N1 == LHS && N2.getOpcode() == ISD::SUB && N1 == N2.getOperand(1))
9961       isAbs = ISD::isBuildVectorAllZeros(N2.getOperand(0).getNode());
9962     else if ((RHSIsAllZeros && (CC == ISD::SETLT || CC == ISD::SETLE)) &&
9963              N2 == LHS && N1.getOpcode() == ISD::SUB && N2 == N1.getOperand(1))
9964       isAbs = ISD::isBuildVectorAllZeros(N1.getOperand(0).getNode());
9965 
9966     if (isAbs) {
9967       if (TLI.isOperationLegalOrCustom(ISD::ABS, VT))
9968         return DAG.getNode(ISD::ABS, DL, VT, LHS);
9969 
9970       SDValue Shift = DAG.getNode(ISD::SRA, DL, VT, LHS,
9971                                   DAG.getConstant(VT.getScalarSizeInBits() - 1,
9972                                                   DL, getShiftAmountTy(VT)));
9973       SDValue Add = DAG.getNode(ISD::ADD, DL, VT, LHS, Shift);
9974       AddToWorklist(Shift.getNode());
9975       AddToWorklist(Add.getNode());
9976       return DAG.getNode(ISD::XOR, DL, VT, Add, Shift);
9977     }
9978 
9979     // vselect x, y (fcmp lt x, y) -> fminnum x, y
9980     // vselect x, y (fcmp gt x, y) -> fmaxnum x, y
9981     //
9982     // This is OK if we don't care about what happens if either operand is a
9983     // NaN.
9984     //
9985     if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, LHS, RHS, TLI)) {
9986       if (SDValue FMinMax =
9987               combineMinNumMaxNum(DL, VT, LHS, RHS, N1, N2, CC, TLI, DAG))
9988         return FMinMax;
9989     }
9990 
9991     // If this select has a condition (setcc) with narrower operands than the
9992     // select, try to widen the compare to match the select width.
9993     // TODO: This should be extended to handle any constant.
9994     // TODO: This could be extended to handle non-loading patterns, but that
9995     //       requires thorough testing to avoid regressions.
9996     if (isNullOrNullSplat(RHS)) {
9997       EVT NarrowVT = LHS.getValueType();
9998       EVT WideVT = N1.getValueType().changeVectorElementTypeToInteger();
9999       EVT SetCCVT = getSetCCResultType(LHS.getValueType());
10000       unsigned SetCCWidth = SetCCVT.getScalarSizeInBits();
10001       unsigned WideWidth = WideVT.getScalarSizeInBits();
10002       bool IsSigned = isSignedIntSetCC(CC);
10003       auto LoadExtOpcode = IsSigned ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
10004       if (LHS.getOpcode() == ISD::LOAD && LHS.hasOneUse() &&
10005           SetCCWidth != 1 && SetCCWidth < WideWidth &&
10006           TLI.isLoadExtLegalOrCustom(LoadExtOpcode, WideVT, NarrowVT) &&
10007           TLI.isOperationLegalOrCustom(ISD::SETCC, WideVT)) {
10008         // Both compare operands can be widened for free. The LHS can use an
10009         // extended load, and the RHS is a constant:
10010         //   vselect (ext (setcc load(X), C)), N1, N2 -->
10011         //   vselect (setcc extload(X), C'), N1, N2
10012         auto ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
10013         SDValue WideLHS = DAG.getNode(ExtOpcode, DL, WideVT, LHS);
10014         SDValue WideRHS = DAG.getNode(ExtOpcode, DL, WideVT, RHS);
10015         EVT WideSetCCVT = getSetCCResultType(WideVT);
10016         SDValue WideSetCC = DAG.getSetCC(DL, WideSetCCVT, WideLHS, WideRHS, CC);
10017         return DAG.getSelect(DL, N1.getValueType(), WideSetCC, N1, N2);
10018       }
10019     }
10020 
10021     // Match VSELECTs into add with unsigned saturation.
10022     if (hasOperation(ISD::UADDSAT, VT)) {
10023       // Check if one of the arms of the VSELECT is a vector with all bits set.
10024       // If it's on the left side, invert the predicate to simplify logic below.
10025       SDValue Other;
10026       ISD::CondCode SatCC = CC;
10027       if (ISD::isConstantSplatVectorAllOnes(N1.getNode())) {
10028         Other = N2;
10029         SatCC = ISD::getSetCCInverse(SatCC, VT.getScalarType());
10030       } else if (ISD::isConstantSplatVectorAllOnes(N2.getNode())) {
10031         Other = N1;
10032       }
10033 
10034       if (Other && Other.getOpcode() == ISD::ADD) {
10035         SDValue CondLHS = LHS, CondRHS = RHS;
10036         SDValue OpLHS = Other.getOperand(0), OpRHS = Other.getOperand(1);
10037 
10038         // Canonicalize condition operands.
10039         if (SatCC == ISD::SETUGE) {
10040           std::swap(CondLHS, CondRHS);
10041           SatCC = ISD::SETULE;
10042         }
10043 
10044         // We can test against either of the addition operands.
10045         // x <= x+y ? x+y : ~0 --> uaddsat x, y
10046         // x+y >= x ? x+y : ~0 --> uaddsat x, y
10047         if (SatCC == ISD::SETULE && Other == CondRHS &&
10048             (OpLHS == CondLHS || OpRHS == CondLHS))
10049           return DAG.getNode(ISD::UADDSAT, DL, VT, OpLHS, OpRHS);
10050 
10051         if (OpRHS.getOpcode() == CondRHS.getOpcode() &&
10052             (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
10053              OpRHS.getOpcode() == ISD::SPLAT_VECTOR) &&
10054             CondLHS == OpLHS) {
10055           // If the RHS is a constant we have to reverse the const
10056           // canonicalization.
10057           // x >= ~C ? x+C : ~0 --> uaddsat x, C
10058           auto MatchUADDSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
10059             return Cond->getAPIntValue() == ~Op->getAPIntValue();
10060           };
10061           if (SatCC == ISD::SETULE &&
10062               ISD::matchBinaryPredicate(OpRHS, CondRHS, MatchUADDSAT))
10063             return DAG.getNode(ISD::UADDSAT, DL, VT, OpLHS, OpRHS);
10064         }
10065       }
10066     }
10067 
10068     // Match VSELECTs into sub with unsigned saturation.
10069     if (hasOperation(ISD::USUBSAT, VT)) {
10070       // Check if one of the arms of the VSELECT is a zero vector. If it's on
10071       // the left side, invert the predicate to simplify logic below.
10072       SDValue Other;
10073       ISD::CondCode SatCC = CC;
10074       if (ISD::isConstantSplatVectorAllZeros(N1.getNode())) {
10075         Other = N2;
10076         SatCC = ISD::getSetCCInverse(SatCC, VT.getScalarType());
10077       } else if (ISD::isConstantSplatVectorAllZeros(N2.getNode())) {
10078         Other = N1;
10079       }
10080 
10081       if (Other && Other.getNumOperands() == 2) {
10082         SDValue CondRHS = RHS;
10083         SDValue OpLHS = Other.getOperand(0), OpRHS = Other.getOperand(1);
10084 
10085         if (Other.getOpcode() == ISD::SUB &&
10086             LHS.getOpcode() == ISD::ZERO_EXTEND && LHS.getOperand(0) == OpLHS &&
10087             OpRHS.getOpcode() == ISD::TRUNCATE && OpRHS.getOperand(0) == RHS) {
10088           // Look for a general sub with unsigned saturation first.
10089           // zext(x) >= y ? x - trunc(y) : 0
10090           // --> usubsat(x,trunc(umin(y,SatLimit)))
10091           // zext(x) >  y ? x - trunc(y) : 0
10092           // --> usubsat(x,trunc(umin(y,SatLimit)))
10093           if (SatCC == ISD::SETUGE || SatCC == ISD::SETUGT)
10094             return getTruncatedUSUBSAT(VT, LHS.getValueType(), LHS, RHS, DAG,
10095                                        DL);
10096         }
10097 
10098         if (OpLHS == LHS) {
10099           // Look for a general sub with unsigned saturation first.
10100           // x >= y ? x-y : 0 --> usubsat x, y
10101           // x >  y ? x-y : 0 --> usubsat x, y
10102           if ((SatCC == ISD::SETUGE || SatCC == ISD::SETUGT) &&
10103               Other.getOpcode() == ISD::SUB && OpRHS == CondRHS)
10104             return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
10105 
10106           if (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
10107               OpRHS.getOpcode() == ISD::SPLAT_VECTOR) {
10108             if (CondRHS.getOpcode() == ISD::BUILD_VECTOR ||
10109                 CondRHS.getOpcode() == ISD::SPLAT_VECTOR) {
10110               // If the RHS is a constant we have to reverse the const
10111               // canonicalization.
10112               // x > C-1 ? x+-C : 0 --> usubsat x, C
10113               auto MatchUSUBSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
10114                 return (!Op && !Cond) ||
10115                        (Op && Cond &&
10116                         Cond->getAPIntValue() == (-Op->getAPIntValue() - 1));
10117               };
10118               if (SatCC == ISD::SETUGT && Other.getOpcode() == ISD::ADD &&
10119                   ISD::matchBinaryPredicate(OpRHS, CondRHS, MatchUSUBSAT,
10120                                             /*AllowUndefs*/ true)) {
10121                 OpRHS = DAG.getNode(ISD::SUB, DL, VT,
10122                                     DAG.getConstant(0, DL, VT), OpRHS);
10123                 return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
10124               }
10125 
10126               // Another special case: If C was a sign bit, the sub has been
10127               // canonicalized into a xor.
10128               // FIXME: Would it be better to use computeKnownBits to determine
10129               //        whether it's safe to decanonicalize the xor?
10130               // x s< 0 ? x^C : 0 --> usubsat x, C
10131               APInt SplatValue;
10132               if (SatCC == ISD::SETLT && Other.getOpcode() == ISD::XOR &&
10133                   ISD::isConstantSplatVector(OpRHS.getNode(), SplatValue) &&
10134                   ISD::isConstantSplatVectorAllZeros(CondRHS.getNode()) &&
10135                   SplatValue.isSignMask()) {
10136                 // Note that we have to rebuild the RHS constant here to
10137                 // ensure we don't rely on particular values of undef lanes.
10138                 OpRHS = DAG.getConstant(SplatValue, DL, VT);
10139                 return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
10140               }
10141             }
10142           }
10143         }
10144       }
10145     }
10146   }
10147 
10148   if (SimplifySelectOps(N, N1, N2))
10149     return SDValue(N, 0);  // Don't revisit N.
10150 
10151   // Fold (vselect all_ones, N1, N2) -> N1
10152   if (ISD::isConstantSplatVectorAllOnes(N0.getNode()))
10153     return N1;
10154   // Fold (vselect all_zeros, N1, N2) -> N2
10155   if (ISD::isConstantSplatVectorAllZeros(N0.getNode()))
10156     return N2;
10157 
10158   // The ConvertSelectToConcatVector function is assuming both the above
10159   // checks for (vselect (build_vector all{ones,zeros}) ...) have been made
10160   // and addressed.
10161   if (N1.getOpcode() == ISD::CONCAT_VECTORS &&
10162       N2.getOpcode() == ISD::CONCAT_VECTORS &&
10163       ISD::isBuildVectorOfConstantSDNodes(N0.getNode())) {
10164     if (SDValue CV = ConvertSelectToConcatVector(N, DAG))
10165       return CV;
10166   }
10167 
10168   if (SDValue V = foldVSelectOfConstants(N))
10169     return V;
10170 
10171   return SDValue();
10172 }
10173 
10174 SDValue DAGCombiner::visitSELECT_CC(SDNode *N) {
10175   SDValue N0 = N->getOperand(0);
10176   SDValue N1 = N->getOperand(1);
10177   SDValue N2 = N->getOperand(2);
10178   SDValue N3 = N->getOperand(3);
10179   SDValue N4 = N->getOperand(4);
10180   ISD::CondCode CC = cast<CondCodeSDNode>(N4)->get();
10181 
10182   // fold select_cc lhs, rhs, x, x, cc -> x
10183   if (N2 == N3)
10184     return N2;
10185 
10186   // Determine if the condition we're dealing with is constant
10187   if (SDValue SCC = SimplifySetCC(getSetCCResultType(N0.getValueType()), N0, N1,
10188                                   CC, SDLoc(N), false)) {
10189     AddToWorklist(SCC.getNode());
10190 
10191     if (ConstantSDNode *SCCC = dyn_cast<ConstantSDNode>(SCC.getNode())) {
10192       if (!SCCC->isNullValue())
10193         return N2;    // cond always true -> true val
10194       else
10195         return N3;    // cond always false -> false val
10196     } else if (SCC->isUndef()) {
10197       // When the condition is UNDEF, just return the first operand. This is
10198       // coherent with DAG creation; no setcc node is created in this case.
10199       return N2;
10200     } else if (SCC.getOpcode() == ISD::SETCC) {
10201       // Fold to a simpler select_cc
10202       SDValue SelectOp = DAG.getNode(
10203           ISD::SELECT_CC, SDLoc(N), N2.getValueType(), SCC.getOperand(0),
10204           SCC.getOperand(1), N2, N3, SCC.getOperand(2));
10205       SelectOp->setFlags(SCC->getFlags());
10206       return SelectOp;
10207     }
10208   }
10209 
10210   // If we can fold this based on the true/false value, do so.
10211   if (SimplifySelectOps(N, N2, N3))
10212     return SDValue(N, 0);  // Don't revisit N.
10213 
10214   // fold select_cc into other things, such as min/max/abs
10215   return SimplifySelectCC(SDLoc(N), N0, N1, N2, N3, CC);
10216 }
10217 
10218 SDValue DAGCombiner::visitSETCC(SDNode *N) {
10219   // setcc is very commonly used as an argument to brcond. This pattern
10220   // also lends itself to numerous combines and, as a result, it is desirable
10221   // that we keep the argument to a brcond as a setcc as much as possible.
10222   bool PreferSetCC =
10223       N->hasOneUse() && N->use_begin()->getOpcode() == ISD::BRCOND;
10224 
10225   ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get();
10226   EVT VT = N->getValueType(0);
10227 
10228   //   SETCC(FREEZE(X), CONST, Cond)
10229   // =>
10230   //   FREEZE(SETCC(X, CONST, Cond))
10231   // This is correct if FREEZE(X) has one use and SETCC(FREEZE(X), CONST, Cond)
10232   // isn't equivalent to true or false.
10233   // For example, SETCC(FREEZE(X), -128, SETULT) cannot be folded to
10234   // FREEZE(SETCC(X, -128, SETULT)) because X can be poison.
10235   //
10236   // This transformation is beneficial because visitBRCOND can fold
10237   // BRCOND(FREEZE(X)) to BRCOND(X).
10238 
10239   // Conservatively optimize integer comparisons only.
10240   if (PreferSetCC) {
10241     // Do this only when SETCC is going to be used by BRCOND.
10242 
10243     SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
10244     ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
10245     ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
10246     bool Updated = false;
10247 
10248     // Is 'X Cond C' always true or false?
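    // e.g. (setult X, 0) is always false and (setuge X, 0) is always true.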
10249     auto IsAlwaysTrueOrFalse = [](ISD::CondCode Cond, ConstantSDNode *C) {
10250       bool False = (Cond == ISD::SETULT && C->isNullValue()) ||
10251                    (Cond == ISD::SETLT  && C->isMinSignedValue()) ||
10252                    (Cond == ISD::SETUGT && C->isAllOnesValue()) ||
10253                    (Cond == ISD::SETGT  && C->isMaxSignedValue());
10254       bool True =  (Cond == ISD::SETULE && C->isAllOnesValue()) ||
10255                    (Cond == ISD::SETLE  && C->isMaxSignedValue()) ||
10256                    (Cond == ISD::SETUGE && C->isNullValue()) ||
10257                    (Cond == ISD::SETGE  && C->isMinSignedValue());
10258       return True || False;
10259     };

    if (N0->getOpcode() == ISD::FREEZE && N0.hasOneUse() && N1C) {
      if (!IsAlwaysTrueOrFalse(Cond, N1C)) {
        N0 = N0->getOperand(0);
        Updated = true;
      }
    }
    if (N1->getOpcode() == ISD::FREEZE && N1.hasOneUse() && N0C) {
      if (!IsAlwaysTrueOrFalse(ISD::getSetCCSwappedOperands(Cond),
                               N0C)) {
        N1 = N1->getOperand(0);
        Updated = true;
      }
    }

    if (Updated)
      return DAG.getFreeze(DAG.getSetCC(SDLoc(N), VT, N0, N1, Cond));
  }

  SDValue Combined = SimplifySetCC(VT, N->getOperand(0), N->getOperand(1), Cond,
                                   SDLoc(N), !PreferSetCC);

  if (!Combined)
    return SDValue();

  // If we prefer to have a setcc, and we don't, we'll try our best to
  // recreate one using rebuildSetCC.
  if (PreferSetCC && Combined.getOpcode() != ISD::SETCC) {
    SDValue NewSetCC = rebuildSetCC(Combined);

    // We don't have anything interesting to combine to.
    if (NewSetCC.getNode() == N)
      return SDValue();

    if (NewSetCC)
      return NewSetCC;
  }

  return Combined;
}

SDValue DAGCombiner::visitSETCCCARRY(SDNode *N) {
  SDValue LHS = N->getOperand(0);
  SDValue RHS = N->getOperand(1);
  SDValue Carry = N->getOperand(2);
  SDValue Cond = N->getOperand(3);

  // If Carry is false, fold to a regular SETCC.
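  // Illustrative instance (assumed for exposition):
  //   t1: i1 = setcccarry X, Y, Constant:i1<0>, setult
  // degenerates to
  //   t2: i1 = setcc X, Y, setult
  // because a zero carry-in contributes nothing to the comparison.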
  if (isNullConstant(Carry))
    return DAG.getNode(ISD::SETCC, SDLoc(N), N->getVTList(), LHS, RHS, Cond);

  return SDValue();
}

/// Check whether N satisfies all of the following:
///   N has a single use.
///   N is a load.
///   The load is compatible with ExtOpcode: if the load performs an explicit
///   zero/sign extension, ExtOpcode must be the matching extension;
///   otherwise any ExtOpcode is acceptable.
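/// Illustrative instances (assumed for exposition): a sextload is compatible
/// with ISD::SIGN_EXTEND but not with ISD::ZERO_EXTEND; a plain
/// (non-extending) load or an anyext load is compatible with any ExtOpcode.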
static bool isCompatibleLoad(SDValue N, unsigned ExtOpcode) {
  if (!N.hasOneUse())
    return false;

  if (!isa<LoadSDNode>(N))
    return false;

  LoadSDNode *Load = cast<LoadSDNode>(N);
  ISD::LoadExtType LoadExt = Load->getExtensionType();
  if (LoadExt == ISD::NON_EXTLOAD || LoadExt == ISD::EXTLOAD)
    return true;

  // Now LoadExt is either SEXTLOAD or ZEXTLOAD, ExtOpcode must have the same
  // extension.
  if ((LoadExt == ISD::SEXTLOAD && ExtOpcode != ISD::SIGN_EXTEND) ||
      (LoadExt == ISD::ZEXTLOAD && ExtOpcode != ISD::ZERO_EXTEND))
    return false;

  return true;
}

/// Fold
///   (sext (select c, load x, load y)) -> (select c, sextload x, sextload y)
///   (zext (select c, load x, load y)) -> (select c, zextload x, zextload y)
///   (aext (select c, load x, load y)) -> (select c, extload x, extload y)
/// This function is called by the DAGCombiner when visiting sext/zext/aext
/// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
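/// Illustrative sketch with assumed types (not from the original comment):
///   t1: i8 = load x
///   t2: i8 = load y
///   t3: i8 = select cond, t1, t2
///   t4: i32 = sext t3
/// becomes
///   t5: i32 = sextload x
///   t6: i32 = sextload y
///   t7: i32 = select cond, t5, t6
/// provided both loads have a single use and the extload is legal for i32.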
static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI,
                                         SelectionDAG &DAG) {
  unsigned Opcode = N->getOpcode();
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
          Opcode == ISD::ANY_EXTEND) &&
         "Expected EXTEND dag node in input!");

  if (!(N0->getOpcode() == ISD::SELECT || N0->getOpcode() == ISD::VSELECT) ||
      !N0.hasOneUse())
    return SDValue();

  SDValue Op1 = N0->getOperand(1);
  SDValue Op2 = N0->getOperand(2);
  if (!isCompatibleLoad(Op1, Opcode) || !isCompatibleLoad(Op2, Opcode))
    return SDValue();

  auto ExtLoadOpcode = ISD::EXTLOAD;
  if (Opcode == ISD::SIGN_EXTEND)
    ExtLoadOpcode = ISD::SEXTLOAD;
  else if (Opcode == ISD::ZERO_EXTEND)
    ExtLoadOpcode = ISD::ZEXTLOAD;

  LoadSDNode *Load1 = cast<LoadSDNode>(Op1);
  LoadSDNode *Load2 = cast<LoadSDNode>(Op2);
  if (!TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load1->getMemoryVT()) ||
      !TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load2->getMemoryVT()))
    return SDValue();

  SDValue Ext1 = DAG.getNode(Opcode, DL, VT, Op1);
  SDValue Ext2 = DAG.getNode(Opcode, DL, VT, Op2);
  return DAG.getSelect(DL, VT, N0->getOperand(0), Ext1, Ext2);
}

/// Try to fold a sext/zext/aext dag node into a ConstantSDNode or
/// a build_vector of constants.
/// This function is called by the DAGCombiner when visiting sext/zext/aext
/// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
/// Vector extends are not folded if operations are legal; this is to
/// avoid introducing illegal build_vector dag nodes.
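/// Illustrative instance (assumed for exposition):
///   (v2i32 zext (v2i16 build_vector Constant:i16<3>, Constant:i16<7>))
/// folds to
///   (v2i32 build_vector Constant:i32<3>, Constant:i32<7>)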
static SDValue tryToFoldExtendOfConstant(SDNode *N, const TargetLowering &TLI,
                                         SelectionDAG &DAG, bool LegalTypes) {
  unsigned Opcode = N->getOpcode();
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
         Opcode == ISD::ANY_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG ||
         Opcode == ISD::ZERO_EXTEND_VECTOR_INREG)
         && "Expected EXTEND dag node in input!");

  // fold (sext c1) -> c1
  // fold (zext c1) -> c1
  // fold (aext c1) -> c1
  if (isa<ConstantSDNode>(N0))
    return DAG.getNode(Opcode, DL, VT, N0);

  // fold (sext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
  // fold (zext (select cond, c1, c2)) -> (select cond, zext c1, zext c2)
  // fold (aext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
  if (N0->getOpcode() == ISD::SELECT) {
    SDValue Op1 = N0->getOperand(1);
    SDValue Op2 = N0->getOperand(2);
    if (isa<ConstantSDNode>(Op1) && isa<ConstantSDNode>(Op2) &&
        (Opcode != ISD::ZERO_EXTEND || !TLI.isZExtFree(N0.getValueType(), VT))) {
      // For any_extend, choose sign extension of the constants to allow a
      // possible further transform to sign_extend_inreg, i.e.
      //
      // t1: i8 = select t0, Constant:i8<-1>, Constant:i8<0>
      // t2: i64 = any_extend t1
      // -->
      // t3: i64 = select t0, Constant:i64<-1>, Constant:i64<0>
      // -->
      // t4: i64 = sign_extend_inreg t3
      unsigned FoldOpc = Opcode;
      if (FoldOpc == ISD::ANY_EXTEND)
        FoldOpc = ISD::SIGN_EXTEND;
      return DAG.getSelect(DL, VT, N0->getOperand(0),
                           DAG.getNode(FoldOpc, DL, VT, Op1),
                           DAG.getNode(FoldOpc, DL, VT, Op2));
    }
  }

  // fold (sext (build_vector AllConstants)) -> (build_vector AllConstants)
  // fold (zext (build_vector AllConstants)) -> (build_vector AllConstants)
  // fold (aext (build_vector AllConstants)) -> (build_vector AllConstants)
  EVT SVT = VT.getScalarType();
  if (!(VT.isVector() && (!LegalTypes || TLI.isTypeLegal(SVT)) &&
      ISD::isBuildVectorOfConstantSDNodes(N0.getNode())))
    return SDValue();

  // We can fold this node into a build_vector.
  unsigned VTBits = SVT.getSizeInBits();
  unsigned EVTBits = N0->getValueType(0).getScalarSizeInBits();
  SmallVector<SDValue, 8> Elts;
  unsigned NumElts = VT.getVectorNumElements();

  // For zero-extensions, UNDEF elements are still guaranteed to have their
  // upper bits set to zero.
  bool IsZext =
      Opcode == ISD::ZERO_EXTEND || Opcode == ISD::ZERO_EXTEND_VECTOR_INREG;

  for (unsigned i = 0; i != NumElts; ++i) {
    SDValue Op = N0.getOperand(i);
    if (Op.isUndef()) {
      Elts.push_back(IsZext ? DAG.getConstant(0, DL, SVT) : DAG.getUNDEF(SVT));
      continue;
    }

    SDLoc DL(Op);
    // Get the constant value and if needed trunc it to the size of the type.
    // Nodes like build_vector might have constants wider than the scalar type.
    APInt C = cast<ConstantSDNode>(Op)->getAPIntValue().zextOrTrunc(EVTBits);
    if (Opcode == ISD::SIGN_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG)
      Elts.push_back(DAG.getConstant(C.sext(VTBits), DL, SVT));
    else
      Elts.push_back(DAG.getConstant(C.zext(VTBits), DL, SVT));
  }

  return DAG.getBuildVector(VT, DL, Elts);
}

// ExtendUsesToFormExtLoad - Trying to extend uses of a load to enable this:
// "fold ({s|z|a}ext (load x)) -> ({s|z|a}ext (truncate ({s|z|a}extload x)))"
// transformation. Returns true if extensions are possible and the
// above-mentioned transformation is profitable.
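// Illustrative sketch (assumed for exposition): if (load x) also feeds a
// setcc against a constant, that use can be rewritten to compare the extended
// value against the extended constant, so the qualifying setcc users are
// collected in ExtendNodes for the caller to rewrite via ExtendSetCCUses.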
static bool ExtendUsesToFormExtLoad(EVT VT, SDNode *N, SDValue N0,
                                    unsigned ExtOpc,
                                    SmallVectorImpl<SDNode *> &ExtendNodes,
                                    const TargetLowering &TLI) {
  bool HasCopyToRegUses = false;
  bool isTruncFree = TLI.isTruncateFree(VT, N0.getValueType());
  for (SDNode::use_iterator UI = N0.getNode()->use_begin(),
                            UE = N0.getNode()->use_end();
       UI != UE; ++UI) {
    SDNode *User = *UI;
    if (User == N)
      continue;
    if (UI.getUse().getResNo() != N0.getResNo())
      continue;
    // FIXME: Only extend SETCC N, N and SETCC N, c for now.
    if (ExtOpc != ISD::ANY_EXTEND && User->getOpcode() == ISD::SETCC) {
      ISD::CondCode CC = cast<CondCodeSDNode>(User->getOperand(2))->get();
      if (ExtOpc == ISD::ZERO_EXTEND && ISD::isSignedIntSetCC(CC))
        // Sign bits will be lost after a zext.
        return false;
      bool Add = false;
      for (unsigned i = 0; i != 2; ++i) {
        SDValue UseOp = User->getOperand(i);
        if (UseOp == N0)
          continue;
        if (!isa<ConstantSDNode>(UseOp))
          return false;
        Add = true;
      }
      if (Add)
        ExtendNodes.push_back(User);
      continue;
    }
    // If truncates aren't free and there are users we can't
    // extend, it isn't worthwhile.
    if (!isTruncFree)
      return false;
    // Remember if this value is live-out.
    if (User->getOpcode() == ISD::CopyToReg)
      HasCopyToRegUses = true;
  }

  if (HasCopyToRegUses) {
    bool BothLiveOut = false;
    for (SDNode::use_iterator UI = N->use_begin(), UE = N->use_end();
         UI != UE; ++UI) {
      SDUse &Use = UI.getUse();
      if (Use.getResNo() == 0 && Use.getUser()->getOpcode() == ISD::CopyToReg) {
        BothLiveOut = true;
        break;
      }
    }
    if (BothLiveOut)
      // Both unextended and extended values are live out. There had better be
      // a good reason for the transformation.
      return ExtendNodes.size();
  }
  return true;
}

void DAGCombiner::ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
                                  SDValue OrigLoad, SDValue ExtLoad,
                                  ISD::NodeType ExtType) {
  // Extend SetCC uses if necessary.
  SDLoc DL(ExtLoad);
  for (SDNode *SetCC : SetCCs) {
    SmallVector<SDValue, 4> Ops;

    for (unsigned j = 0; j != 2; ++j) {
      SDValue SOp = SetCC->getOperand(j);
      if (SOp == OrigLoad)
        Ops.push_back(ExtLoad);
      else
        Ops.push_back(DAG.getNode(ExtType, DL, ExtLoad->getValueType(0), SOp));
    }

    Ops.push_back(SetCC->getOperand(2));
    CombineTo(SetCC, DAG.getNode(ISD::SETCC, DL, SetCC->getValueType(0), Ops));
  }
}

// FIXME: Bring more similar combines here, common to sext/zext (maybe aext?).
SDValue DAGCombiner::CombineExtLoad(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT DstVT = N->getValueType(0);
  EVT SrcVT = N0.getValueType();

  assert((N->getOpcode() == ISD::SIGN_EXTEND ||
          N->getOpcode() == ISD::ZERO_EXTEND) &&
         "Unexpected node type (not an extend)!");

  // fold (sext (load x)) to multiple smaller sextloads; same for zext.
  // For example, on a target with legal v4i32, but illegal v8i32, turn:
  //   (v8i32 (sext (v8i16 (load x))))
  // into:
  //   (v8i32 (concat_vectors (v4i32 (sextload x)),
  //                          (v4i32 (sextload (x + 16)))))
  // Where uses of the original load, i.e.:
  //   (v8i16 (load x))
  // are replaced with:
  //   (v8i16 (truncate
  //     (v8i32 (concat_vectors (v4i32 (sextload x)),
  //                            (v4i32 (sextload (x + 16)))))))
  //
  // This combine is only applicable to illegal, but splittable, vectors.
  // All legal types, and illegal non-vector types, are handled elsewhere.
  // This combine is controlled by TargetLowering::isVectorLoadExtDesirable.
  //
  if (N0->getOpcode() != ISD::LOAD)
    return SDValue();

  LoadSDNode *LN0 = cast<LoadSDNode>(N0);

  if (!ISD::isNON_EXTLoad(LN0) || !ISD::isUNINDEXEDLoad(LN0) ||
      !N0.hasOneUse() || !LN0->isSimple() ||
      !DstVT.isVector() || !DstVT.isPow2VectorType() ||
      !TLI.isVectorLoadExtDesirable(SDValue(N, 0)))
    return SDValue();

  SmallVector<SDNode *, 4> SetCCs;
  if (!ExtendUsesToFormExtLoad(DstVT, N, N0, N->getOpcode(), SetCCs, TLI))
    return SDValue();

  ISD::LoadExtType ExtType =
      N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SEXTLOAD : ISD::ZEXTLOAD;

  // Try to split the vector types to get down to legal types.
  EVT SplitSrcVT = SrcVT;
  EVT SplitDstVT = DstVT;
  while (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT) &&
         SplitSrcVT.getVectorNumElements() > 1) {
    SplitDstVT = DAG.GetSplitDestVTs(SplitDstVT).first;
    SplitSrcVT = DAG.GetSplitDestVTs(SplitSrcVT).first;
  }

  if (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT))
    return SDValue();

  assert(!DstVT.isScalableVector() && "Unexpected scalable vector type");

  SDLoc DL(N);
  const unsigned NumSplits =
      DstVT.getVectorNumElements() / SplitDstVT.getVectorNumElements();
  const unsigned Stride = SplitSrcVT.getStoreSize();
  SmallVector<SDValue, 4> Loads;
  SmallVector<SDValue, 4> Chains;

  SDValue BasePtr = LN0->getBasePtr();
  for (unsigned Idx = 0; Idx < NumSplits; Idx++) {
    const unsigned Offset = Idx * Stride;
    const Align Align = commonAlignment(LN0->getAlign(), Offset);

    SDValue SplitLoad = DAG.getExtLoad(
        ExtType, SDLoc(LN0), SplitDstVT, LN0->getChain(), BasePtr,
        LN0->getPointerInfo().getWithOffset(Offset), SplitSrcVT, Align,
        LN0->getMemOperand()->getFlags(), LN0->getAAInfo());

    BasePtr = DAG.getMemBasePlusOffset(BasePtr, TypeSize::Fixed(Stride), DL);

    Loads.push_back(SplitLoad.getValue(0));
    Chains.push_back(SplitLoad.getValue(1));
  }

  SDValue NewChain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Chains);
  SDValue NewValue = DAG.getNode(ISD::CONCAT_VECTORS, DL, DstVT, Loads);

  // Simplify TF.
  AddToWorklist(NewChain.getNode());

  CombineTo(N, NewValue);

  // Replace uses of the original load (before extension)
  // with a truncate of the concatenated sextloaded vectors.
  SDValue Trunc =
      DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), NewValue);
  ExtendSetCCUses(SetCCs, N0, NewValue, (ISD::NodeType)N->getOpcode());
  CombineTo(N0.getNode(), Trunc, NewChain);
  return SDValue(N, 0); // Return N so it doesn't get rechecked!
}

// fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
//      (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
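// Illustrative sketch with assumed types (not from the original comment):
//   t1: i16 = load x
//   t2: i16 = srl t1, Constant:i16<4>
//   t3: i16 = and t2, Constant:i16<255>
//   t4: i32 = zext t3
// becomes
//   t5: i32 = zextload x (from i16)
//   t6: i32 = srl t5, Constant:i32<4>
//   t7: i32 = and t6, Constant:i32<255>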
SDValue DAGCombiner::CombineZExtLogicopShiftLoad(SDNode *N) {
  assert(N->getOpcode() == ISD::ZERO_EXTEND);
  EVT VT = N->getValueType(0);
  EVT OrigVT = N->getOperand(0).getValueType();
  if (TLI.isZExtFree(OrigVT, VT))
    return SDValue();

  // and/or/xor
  SDValue N0 = N->getOperand(0);
  if (!(N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR ||
        N0.getOpcode() == ISD::XOR) ||
      N0.getOperand(1).getOpcode() != ISD::Constant ||
      (LegalOperations && !TLI.isOperationLegal(N0.getOpcode(), VT)))
    return SDValue();

  // shl/shr
  SDValue N1 = N0->getOperand(0);
  if (!(N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) ||
      N1.getOperand(1).getOpcode() != ISD::Constant ||
      (LegalOperations && !TLI.isOperationLegal(N1.getOpcode(), VT)))
    return SDValue();

  // load
  if (!isa<LoadSDNode>(N1.getOperand(0)))
    return SDValue();
  LoadSDNode *Load = cast<LoadSDNode>(N1.getOperand(0));
  EVT MemVT = Load->getMemoryVT();
  if (!TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) ||
      Load->getExtensionType() == ISD::SEXTLOAD || Load->isIndexed())
    return SDValue();

  // If the shift op is SHL, the logic op must be AND, otherwise the result
  // will be wrong.
  if (N1.getOpcode() == ISD::SHL && N0.getOpcode() != ISD::AND)
    return SDValue();

  if (!N0.hasOneUse() || !N1.hasOneUse())
    return SDValue();

  SmallVector<SDNode*, 4> SetCCs;
  if (!ExtendUsesToFormExtLoad(VT, N1.getNode(), N1.getOperand(0),
                               ISD::ZERO_EXTEND, SetCCs, TLI))
    return SDValue();

  // Actually do the transformation.
  SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(Load), VT,
                                   Load->getChain(), Load->getBasePtr(),
                                   Load->getMemoryVT(), Load->getMemOperand());

  SDLoc DL1(N1);
  SDValue Shift = DAG.getNode(N1.getOpcode(), DL1, VT, ExtLoad,
                              N1.getOperand(1));

  APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
  SDLoc DL0(N0);
  SDValue And = DAG.getNode(N0.getOpcode(), DL0, VT, Shift,
                            DAG.getConstant(Mask, DL0, VT));

  ExtendSetCCUses(SetCCs, N1.getOperand(0), ExtLoad, ISD::ZERO_EXTEND);
  CombineTo(N, And);
  if (SDValue(Load, 0).hasOneUse()) {
    DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), ExtLoad.getValue(1));
  } else {
    SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(Load),
                                Load->getValueType(0), ExtLoad);
    CombineTo(Load, Trunc, ExtLoad.getValue(1));
  }

  // N0 is dead at this point.
  recursivelyDeleteUnusedNodes(N0.getNode());

  return SDValue(N, 0); // Return N so it doesn't get rechecked!
}

/// If we're narrowing or widening the result of a vector select and the final
/// size is the same size as a setcc (compare) feeding the select, then try to
/// apply the cast operation to the select's operands because matching vector
/// sizes for a select condition and other operands should be more efficient.
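/// Illustrative sketch with assumed types (not from the original comment),
/// where getSetCCResultType(v4i32) is v4i32:
///   t1 = setcc X:v4i32, Y:v4i32, setlt
///   t2: v4i64 = vselect t1, A:v4i64, B:v4i64
///   t3: v4i32 = truncate t2
/// becomes
///   t4: v4i32 = vselect t1, (truncate A), (truncate B)
/// so the select condition and result have matching vector sizes.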
SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
  unsigned CastOpcode = Cast->getOpcode();
  assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
          CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
          CastOpcode == ISD::FP_ROUND) &&
         "Unexpected opcode for vector select narrowing/widening");

  // We only do this transform before legal ops because the pattern may be
  // obfuscated by target-specific operations after legalization. Do not create
  // an illegal select op, however, because that may be difficult to lower.
  EVT VT = Cast->getValueType(0);
  if (LegalOperations || !TLI.isOperationLegalOrCustom(ISD::VSELECT, VT))
    return SDValue();

  SDValue VSel = Cast->getOperand(0);
  if (VSel.getOpcode() != ISD::VSELECT || !VSel.hasOneUse() ||
      VSel.getOperand(0).getOpcode() != ISD::SETCC)
    return SDValue();

  // Does the setcc have the same vector size as the casted select?
  SDValue SetCC = VSel.getOperand(0);
  EVT SetCCVT = getSetCCResultType(SetCC.getOperand(0).getValueType());
  if (SetCCVT.getSizeInBits() != VT.getSizeInBits())
    return SDValue();

  // cast (vsel (setcc X), A, B) --> vsel (setcc X), (cast A), (cast B)
  SDValue A = VSel.getOperand(1);
  SDValue B = VSel.getOperand(2);
  SDValue CastA, CastB;
  SDLoc DL(Cast);
  if (CastOpcode == ISD::FP_ROUND) {
    // FP_ROUND (fptrunc) has an extra flag operand to pass along.
    CastA = DAG.getNode(CastOpcode, DL, VT, A, Cast->getOperand(1));
    CastB = DAG.getNode(CastOpcode, DL, VT, B, Cast->getOperand(1));
  } else {
    CastA = DAG.getNode(CastOpcode, DL, VT, A);
    CastB = DAG.getNode(CastOpcode, DL, VT, B);
  }
  return DAG.getNode(ISD::VSELECT, DL, VT, SetCC, CastA, CastB);
}

// fold ([s|z]ext ([s|z]extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
// fold ([s|z]ext (     extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
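// Illustrative instance (assumed for exposition):
//   t1: i32 = sextload x (from i16)
//   t2: i64 = sext t1
// is rewritten so the load itself produces the wide type:
//   t3: i64 = sextload x (from i16)
// with the old load's chain uses moved to the new load.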
static SDValue tryToFoldExtOfExtload(SelectionDAG &DAG, DAGCombiner &Combiner,
                                     const TargetLowering &TLI, EVT VT,
                                     bool LegalOperations, SDNode *N,
                                     SDValue N0, ISD::LoadExtType ExtLoadType) {
  SDNode *N0Node = N0.getNode();
  bool isAExtLoad = (ExtLoadType == ISD::SEXTLOAD) ? ISD::isSEXTLoad(N0Node)
                                                   : ISD::isZEXTLoad(N0Node);
  if ((!isAExtLoad && !ISD::isEXTLoad(N0Node)) ||
      !ISD::isUNINDEXEDLoad(N0Node) || !N0.hasOneUse())
    return SDValue();

  LoadSDNode *LN0 = cast<LoadSDNode>(N0);
  EVT MemVT = LN0->getMemoryVT();
  if ((LegalOperations || !LN0->isSimple() ||
       VT.isVector()) &&
      !TLI.isLoadExtLegal(ExtLoadType, VT, MemVT))
    return SDValue();

  SDValue ExtLoad =
      DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(),
                     LN0->getBasePtr(), MemVT, LN0->getMemOperand());
  Combiner.CombineTo(N, ExtLoad);
  DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
  if (LN0->use_empty())
    Combiner.recursivelyDeleteUnusedNodes(LN0);
  return SDValue(N, 0); // Return N so it doesn't get rechecked!
}

// fold ([s|z]ext (load x)) -> ([s|z]ext (truncate ([s|z]extload x)))
// Only generate vector extloads when 1) they're legal, and 2) they are
// deemed desirable by the target.
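// Illustrative instance (assumed for exposition):
//   t1: i16 = load x
//   t2: i32 = zext t1
// becomes
//   t3: i32 = zextload x (from i16)
// with other users of t1, if any, rewritten to use (truncate t3) or, for
// setcc users, widened via ExtendSetCCUses.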
static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner,
                                  const TargetLowering &TLI, EVT VT,
                                  bool LegalOperations, SDNode *N, SDValue N0,
                                  ISD::LoadExtType ExtLoadType,
                                  ISD::NodeType ExtOpc) {
  if (!ISD::isNON_EXTLoad(N0.getNode()) ||
      !ISD::isUNINDEXEDLoad(N0.getNode()) ||
      ((LegalOperations || VT.isVector() ||
        !cast<LoadSDNode>(N0)->isSimple()) &&
       !TLI.isLoadExtLegal(ExtLoadType, VT, N0.getValueType())))
    return {};

  bool DoXform = true;
  SmallVector<SDNode *, 4> SetCCs;
  if (!N0.hasOneUse())
    DoXform = ExtendUsesToFormExtLoad(VT, N, N0, ExtOpc, SetCCs, TLI);
  if (VT.isVector())
    DoXform &= TLI.isVectorLoadExtDesirable(SDValue(N, 0));
  if (!DoXform)
    return {};

  LoadSDNode *LN0 = cast<LoadSDNode>(N0);
  SDValue ExtLoad = DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(),
                                   LN0->getBasePtr(), N0.getValueType(),
                                   LN0->getMemOperand());
  Combiner.ExtendSetCCUses(SetCCs, N0, ExtLoad, ExtOpc);
  // If the load value is used only by N, replace it via CombineTo N.
  bool NoReplaceTrunc = SDValue(LN0, 0).hasOneUse();
  Combiner.CombineTo(N, ExtLoad);
  if (NoReplaceTrunc) {
    DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
    Combiner.recursivelyDeleteUnusedNodes(LN0);
  } else {
    SDValue Trunc =
        DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), ExtLoad);
    Combiner.CombineTo(LN0, Trunc, ExtLoad.getValue(1));
  }
  return SDValue(N, 0); // Return N so it doesn't get rechecked!
}

static SDValue tryToFoldExtOfMaskedLoad(SelectionDAG &DAG,
                                        const TargetLowering &TLI, EVT VT,
                                        SDNode *N, SDValue N0,
                                        ISD::LoadExtType ExtLoadType,
                                        ISD::NodeType ExtOpc) {
  if (!N0.hasOneUse())
    return SDValue();

  MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(N0);
  if (!Ld || Ld->getExtensionType() != ISD::NON_EXTLOAD)
    return SDValue();

  if (!TLI.isLoadExtLegal(ExtLoadType, VT, Ld->getValueType(0)))
    return SDValue();

  if (!TLI.isVectorLoadExtDesirable(SDValue(N, 0)))
    return SDValue();

  SDLoc dl(Ld);
  SDValue PassThru = DAG.getNode(ExtOpc, dl, VT, Ld->getPassThru());
  SDValue NewLoad = DAG.getMaskedLoad(
      VT, dl, Ld->getChain(), Ld->getBasePtr(), Ld->getOffset(), Ld->getMask(),
      PassThru, Ld->getMemoryVT(), Ld->getMemOperand(), Ld->getAddressingMode(),
      ExtLoadType, Ld->isExpandingLoad());
  DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), SDValue(NewLoad.getNode(), 1));
  return NewLoad;
}

static SDValue foldExtendedSignBitTest(SDNode *N, SelectionDAG &DAG,
                                       bool LegalOperations) {
  assert((N->getOpcode() == ISD::SIGN_EXTEND ||
          N->getOpcode() == ISD::ZERO_EXTEND) && "Expected sext or zext");

  SDValue SetCC = N->getOperand(0);
  if (LegalOperations || SetCC.getOpcode() != ISD::SETCC ||
      !SetCC.hasOneUse() || SetCC.getValueType() != MVT::i1)
    return SDValue();

  SDValue X = SetCC.getOperand(0);
  SDValue Ones = SetCC.getOperand(1);
  ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get();
  EVT VT = N->getValueType(0);
  EVT XVT = X.getValueType();
  // setge X, C is canonicalized to setgt, so we do not need to match that
  // pattern. The setlt sibling is folded in SimplifySelectCC() because it does
  // not require the 'not' op.
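  // Illustrative instance (width assumed for exposition): with X : i32,
  //   sext (i1 (setgt X, -1)) --> sra (xor X, -1), 31
  // i.e. the inverted sign bit is smeared across all 32 result bits.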
  if (CC == ISD::SETGT && isAllOnesConstant(Ones) && VT == XVT) {
    // Invert and smear/shift the sign bit:
    // sext i1 (setgt iN X, -1) --> sra (not X), (N - 1)
    // zext i1 (setgt iN X, -1) --> srl (not X), (N - 1)
    SDLoc DL(N);
    unsigned ShCt = VT.getSizeInBits() - 1;
    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
    if (!TLI.shouldAvoidTransformToShift(VT, ShCt)) {
      SDValue NotX = DAG.getNOT(DL, X, VT);
      SDValue ShiftAmount = DAG.getConstant(ShCt, DL, VT);
      auto ShiftOpcode =
        N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SRA : ISD::SRL;
      return DAG.getNode(ShiftOpcode, DL, VT, NotX, ShiftAmount);
    }
  }
  return SDValue();
}

SDValue DAGCombiner::foldSextSetcc(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  if (N0.getOpcode() != ISD::SETCC)
    return SDValue();

  SDValue N00 = N0.getOperand(0);
  SDValue N01 = N0.getOperand(1);
  ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
  EVT VT = N->getValueType(0);
  EVT N00VT = N00.getValueType();
  SDLoc DL(N);

  // On some architectures (such as SSE/NEON/etc) the SETCC result type is
  // the same size as the compared operands. Try to optimize sext(setcc())
  // if this is the case.
  if (VT.isVector() && !LegalOperations &&
      TLI.getBooleanContents(N00VT) ==
          TargetLowering::ZeroOrNegativeOneBooleanContent) {
    EVT SVT = getSetCCResultType(N00VT);

    // If we already have the desired type, don't change it.
    if (SVT != N0.getValueType()) {
      // We know that the # elements of the results is the same as the
      // # elements of the compare (and the # elements of the compare result
      // for that matter).  Check to see that they are the same size.  If so,
      // we know that the element size of the sext'd result matches the
      // element size of the compare operands.
      if (VT.getSizeInBits() == SVT.getSizeInBits())
        return DAG.getSetCC(DL, VT, N00, N01, CC);

      // If the desired elements are smaller or larger than the source
      // elements, we can use a matching integer vector type and then
      // truncate/sign extend.
      EVT MatchingVecType = N00VT.changeVectorElementTypeToInteger();
      if (SVT == MatchingVecType) {
        SDValue VsetCC = DAG.getSetCC(DL, MatchingVecType, N00, N01, CC);
        return DAG.getSExtOrTrunc(VsetCC, DL, VT);
      }
    }

    // Try to eliminate the sext of a setcc by zexting the compare operands.
    if (N0.hasOneUse() && TLI.isOperationLegalOrCustom(ISD::SETCC, VT) &&
        !TLI.isOperationLegalOrCustom(ISD::SETCC, SVT)) {
      bool IsSignedCmp = ISD::isSignedIntSetCC(CC);
      unsigned LoadOpcode = IsSignedCmp ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
      unsigned ExtOpcode = IsSignedCmp ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;

      // We have an unsupported narrow vector compare op that would be legal
      // if extended to the destination type. See if the compare operands
      // can be freely extended to the destination type.
      auto IsFreeToExtend = [&](SDValue V) {
        if (isConstantOrConstantVector(V, /*NoOpaques*/ true))
          return true;
        // Match a simple, non-extended load that can be converted to a
        // legal {z/s}ext-load.
        // TODO: Allow widening of an existing {z/s}ext-load?
        if (!(ISD::isNON_EXTLoad(V.getNode()) &&
              ISD::isUNINDEXEDLoad(V.getNode()) &&
              cast<LoadSDNode>(V)->isSimple() &&
              TLI.isLoadExtLegal(LoadOpcode, VT, V.getValueType())))
          return false;

        // Non-chain users of this value must either be the setcc in this
        // sequence or extends that can be folded into the new {z/s}ext-load.
        for (SDNode::use_iterator UI = V->use_begin(), UE = V->use_end();
             UI != UE; ++UI) {
          // Skip uses of the chain and the setcc.
          SDNode *User = *UI;
          if (UI.getUse().getResNo() != 0 || User == N0.getNode())
            continue;
          // Extra users must have exactly the same cast we are about to create.
          // TODO: This restriction could be eased if ExtendUsesToFormExtLoad()
          //       is enhanced similarly.
          if (User->getOpcode() != ExtOpcode || User->getValueType(0) != VT)
            return false;
        }
        return true;
      };

      if (IsFreeToExtend(N00) && IsFreeToExtend(N01)) {
        SDValue Ext0 = DAG.getNode(ExtOpcode, DL, VT, N00);
        SDValue Ext1 = DAG.getNode(ExtOpcode, DL, VT, N01);
        return DAG.getSetCC(DL, VT, Ext0, Ext1, CC);
      }
    }
  }

  // sext(setcc x, y, cc) -> (select (setcc x, y, cc), T, 0)
  // Here, T can be 1 or -1, depending on the type of the setcc and
  // getBooleanContents().
  unsigned SetCCWidth = N0.getScalarValueSizeInBits();

  // To determine the "true" side of the select, we need to know the high bit
  // of the value returned by the setcc if it evaluates to true.
  // If the type of the setcc is i1, then the true case of the select is just
  // sext(i1 1), that is, -1.
  // If the type of the setcc is larger (say, i8) then the value of the high
  // bit depends on getBooleanContents(), so ask TLI for a real "true" value
  // of the appropriate width.
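  // Illustrative instances (assumed for exposition): for an i1 setcc
  // sign-extended to i32, ExtTrueVal is the i32 constant -1; for a wider
  // setcc type, ExtTrueVal is whatever getBooleanContents() reports as
  // "true" for the source type, produced at the destination width.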
  SDValue ExtTrueVal = (SetCCWidth == 1)
                           ? DAG.getAllOnesConstant(DL, VT)
                           : DAG.getBoolConstant(true, DL, VT, N00VT);
  SDValue Zero = DAG.getConstant(0, DL, VT);
  if (SDValue SCC = SimplifySelectCC(DL, N00, N01, ExtTrueVal, Zero, CC, true))
    return SCC;

  if (!VT.isVector() && !TLI.convertSelectOfConstantsToMath(VT)) {
    EVT SetCCVT = getSetCCResultType(N00VT);
    // Don't do this transform for i1 because there's a select transform
    // that would reverse it.
    // TODO: We should not do this transform at all without a target hook
    // because a sext is likely cheaper than a select?
    if (SetCCVT.getScalarSizeInBits() != 1 &&
        (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, N00VT))) {
      SDValue SetCC = DAG.getSetCC(DL, SetCCVT, N00, N01, CC);
      return DAG.getSelect(DL, VT, SetCC, ExtTrueVal, Zero);
    }
  }

  return SDValue();
}

SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
    return Res;

  // fold (sext (sext x)) -> (sext x)
  // fold (sext (aext x)) -> (sext x)
  if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND)
    return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N0.getOperand(0));

  if (N0.getOpcode() == ISD::TRUNCATE) {
    // fold (sext (truncate (load x))) -> (sext (smaller load x))
    // fold (sext (truncate (srl (load x), c))) -> (sext (smaller load (x+c/n)))
    if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
      SDNode *oye = N0.getOperand(0).getNode();
      if (NarrowLoad.getNode() != N0.getNode()) {
        CombineTo(N0.getNode(), NarrowLoad);
        // CombineTo deleted the truncate, if needed, but not what's under it.
        AddToWorklist(oye);
      }
      return SDValue(N, 0);   // Return N so it doesn't get rechecked!
    }

    // See if the value being truncated is already sign extended.  If so, just
    // eliminate the trunc/sext pair.
    SDValue Op = N0.getOperand(0);
    unsigned OpBits   = Op.getScalarValueSizeInBits();
    unsigned MidBits  = N0.getScalarValueSizeInBits();
    unsigned DestBits = VT.getScalarSizeInBits();
    unsigned NumSignBits = DAG.ComputeNumSignBits(Op);

    if (OpBits == DestBits) {
      // Op is i32, Mid is i8, and Dest is i32.  If Op has more than 24 sign
      // bits, it is already sign-extended enough.
      if (NumSignBits > DestBits-MidBits)
        return Op;
    } else if (OpBits < DestBits) {
      // Op is i32, Mid is i8, and Dest is i64.  If Op has more than 24 sign
      // bits, just sext from i32.
      if (NumSignBits > OpBits-MidBits)
        return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op);
    } else {
      // Op is i64, Mid is i8, and Dest is i32.  If Op has more than 56 sign
      // bits, just truncate to i32.
      if (NumSignBits > OpBits-MidBits)
        return DAG.getNode(ISD::TRUNCATE, DL, VT, Op);
    }

    // fold (sext (truncate x)) -> (sextinreg x).
    if (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND_INREG,
                                                 N0.getValueType())) {
      if (OpBits < DestBits)
        Op = DAG.getNode(ISD::ANY_EXTEND, SDLoc(N0), VT, Op);
      else if (OpBits > DestBits)
        Op = DAG.getNode(ISD::TRUNCATE, SDLoc(N0), VT, Op);
      return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, Op,
                         DAG.getValueType(N0.getValueType()));
    }
  }

  // Try to simplify (sext (load x)).
  if (SDValue foldedExt =
          tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
                             ISD::SEXTLOAD, ISD::SIGN_EXTEND))
    return foldedExt;

  if (SDValue foldedExt =
      tryToFoldExtOfMaskedLoad(DAG, TLI, VT, N, N0, ISD::SEXTLOAD,
                               ISD::SIGN_EXTEND))
    return foldedExt;

  // fold (sext (load x)) to multiple smaller sextloads.
  // Only on illegal but splittable vectors.
  if (SDValue ExtLoad = CombineExtLoad(N))
    return ExtLoad;

  // Try to simplify (sext (sextload x)).
  if (SDValue foldedExt = tryToFoldExtOfExtload(
          DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::SEXTLOAD))
    return foldedExt;

  // fold (sext (and/or/xor (load x), cst)) ->
  //      (and/or/xor (sextload x), (sext cst))
  if ((N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR ||
       N0.getOpcode() == ISD::XOR) &&
      isa<LoadSDNode>(N0.getOperand(0)) &&
      N0.getOperand(1).getOpcode() == ISD::Constant &&
      (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
    LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
    EVT MemVT = LN00->getMemoryVT();
    if (TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, MemVT) &&
        LN00->getExtensionType() != ISD::ZEXTLOAD && LN00->isUnindexed()) {
      SmallVector<SDNode*, 4> SetCCs;
      bool DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
                                             ISD::SIGN_EXTEND, SetCCs, TLI);
      if (DoXform) {
        SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(LN00), VT,
                                         LN00->getChain(), LN00->getBasePtr(),
                                         LN00->getMemoryVT(),
                                         LN00->getMemOperand());
        APInt Mask = N0.getConstantOperandAPInt(1).sext(VT.getSizeInBits());
        SDValue And = DAG.getNode(N0.getOpcode(), DL, VT,
                                  ExtLoad, DAG.getConstant(Mask, DL, VT));
        ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::SIGN_EXTEND);
        bool NoReplaceTruncAnd = !N0.hasOneUse();
        bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
        CombineTo(N, And);
        // If N0 has multiple uses, change other uses as well.
        if (NoReplaceTruncAnd) {
          SDValue TruncAnd =
              DAG.getNode(ISD::TRUNCATE, DL, N0.getValueType(), And);
          CombineTo(N0.getNode(), TruncAnd);
        }
        if (NoReplaceTrunc) {
          DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1));
        } else {
          SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00),
                                      LN00->getValueType(0), ExtLoad);
          CombineTo(LN00, Trunc, ExtLoad.getValue(1));
        }
        return SDValue(N, 0); // Return N so it doesn't get rechecked!
      }
    }
  }

  if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
    return V;

  if (SDValue V = foldSextSetcc(N))
    return V;

  // fold (sext x) -> (zext x) if the sign bit is known zero.
  if ((!LegalOperations || TLI.isOperationLegal(ISD::ZERO_EXTEND, VT)) &&
      DAG.SignBitIsZero(N0))
    return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0);

  if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
    return NewVSel;

  // Eliminate this sign extend by doing a negation in the destination type:
  // sext i32 (0 - (zext i8 X to i32)) to i64 --> 0 - (zext i8 X to i64)
  if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
      isNullOrNullSplat(N0.getOperand(0)) &&
      N0.getOperand(1).getOpcode() == ISD::ZERO_EXTEND &&
      TLI.isOperationLegalOrCustom(ISD::SUB, VT)) {
    SDValue Zext = DAG.getZExtOrTrunc(N0.getOperand(1).getOperand(0), DL, VT);
    return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), Zext);
  }
  // Eliminate this sign extend by doing a decrement in the destination type:
  // sext i32 ((zext i8 X to i32) + (-1)) to i64 --> (zext i8 X to i64) + (-1)
  if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
      isAllOnesOrAllOnesSplat(N0.getOperand(1)) &&
      N0.getOperand(0).getOpcode() == ISD::ZERO_EXTEND &&
      TLI.isOperationLegalOrCustom(ISD::ADD, VT)) {
    SDValue Zext = DAG.getZExtOrTrunc(N0.getOperand(0).getOperand(0), DL, VT);
    return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT));
  }

  // fold sext (not i1 X) -> add (zext i1 X), -1
  // TODO: This could be extended to handle bool vectors.
  if (N0.getValueType() == MVT::i1 && isBitwiseNot(N0) && N0.hasOneUse() &&
      (!LegalOperations || (TLI.isOperationLegal(ISD::ZERO_EXTEND, VT) &&
                            TLI.isOperationLegal(ISD::ADD, VT)))) {
    // If we can eliminate the 'not', the sext form should be better.
    if (SDValue NewXor = visitXOR(N0.getNode())) {
      // Returning N0 is a form of in-visit replacement that may have
      // invalidated N0.
      if (NewXor.getNode() == N0.getNode()) {
        // Return SDValue here as the xor should have already been replaced in
        // this sext.
        return SDValue();
      } else {
        // Return a new sext with the new xor.
        return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, NewXor);
      }
    }

    SDValue Zext = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0));
    return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT));
  }

  if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG))
    return Res;

  return SDValue();
}

// isTruncateOf - If N is a truncate of some other value, return true, record
// the value being truncated in Op and which of Op's bits are zero/one in Known.
// This function computes KnownBits to avoid a duplicated call to
// computeKnownBits in the caller.
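// Illustrative instance (assumed for exposition): (setne X, 0), where every
// bit of X except bit 0 is known zero, behaves exactly like a truncate of X
// to i1 and is reported as one here.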
static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op,
                         KnownBits &Known) {
  if (N->getOpcode() == ISD::TRUNCATE) {
    Op = N->getOperand(0);
    Known = DAG.computeKnownBits(Op);
    return true;
  }

  if (N.getOpcode() != ISD::SETCC ||
      N.getValueType().getScalarType() != MVT::i1 ||
      cast<CondCodeSDNode>(N.getOperand(2))->get() != ISD::SETNE)
    return false;

  SDValue Op0 = N->getOperand(0);
  SDValue Op1 = N->getOperand(1);
  assert(Op0.getValueType() == Op1.getValueType());

  if (isNullOrNullSplat(Op0))
    Op = Op1;
  else if (isNullOrNullSplat(Op1))
    Op = Op0;
  else
    return false;

  Known = DAG.computeKnownBits(Op);

  return (Known.Zero | 1).isAllOnesValue();
}

/// Given an extending node with a pop-count operand, if the target does not
/// support a pop-count in the narrow source type but does support it in the
/// destination type, widen the pop-count to the destination type.
static SDValue widenCtPop(SDNode *Extend, SelectionDAG &DAG) {
  assert((Extend->getOpcode() == ISD::ZERO_EXTEND ||
          Extend->getOpcode() == ISD::ANY_EXTEND) && "Expected extend op");

  SDValue CtPop = Extend->getOperand(0);
  if (CtPop.getOpcode() != ISD::CTPOP || !CtPop.hasOneUse())
    return SDValue();

  EVT VT = Extend->getValueType(0);
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  if (TLI.isOperationLegalOrCustom(ISD::CTPOP, CtPop.getValueType()) ||
      !TLI.isOperationLegalOrCustom(ISD::CTPOP, VT))
    return SDValue();

  // zext (ctpop X) --> ctpop (zext X)
  SDLoc DL(Extend);
  SDValue NewZext = DAG.getZExtOrTrunc(CtPop.getOperand(0), DL, VT);
  return DAG.getNode(ISD::CTPOP, DL, VT, NewZext);
}

SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
    return Res;

  // fold (zext (zext x)) -> (zext x)
  // fold (zext (aext x)) -> (zext x)
  if (N0.getOpcode() == ISD::ZERO_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND)
    return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT,
                       N0.getOperand(0));

  // fold (zext (truncate x)) -> (zext x) or
  //      (zext (truncate x)) -> (truncate x)
  // This is valid when the truncated bits of x are already zero.
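  // Illustrative instance (widths assumed for exposition): with X : i64
  // whose upper 32 bits are known zero,
  //   (zext (trunc X to i32) to i64)
  // folds directly back to X, since the truncated bits were already zero.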
  SDValue Op;
  KnownBits Known;
  if (isTruncateOf(DAG, N0, Op, Known)) {
    APInt TruncatedBits =
      (Op.getScalarValueSizeInBits() == N0.getScalarValueSizeInBits()) ?
      APInt(Op.getScalarValueSizeInBits(), 0) :
      APInt::getBitsSet(Op.getScalarValueSizeInBits(),
                        N0.getScalarValueSizeInBits(),
                        std::min(Op.getScalarValueSizeInBits(),
                                 VT.getScalarSizeInBits()));
    if (TruncatedBits.isSubsetOf(Known.Zero))
      return DAG.getZExtOrTrunc(Op, SDLoc(N), VT);
  }

  // fold (zext (truncate x)) -> (and x, mask)
  if (N0.getOpcode() == ISD::TRUNCATE) {
    // fold (zext (truncate (load x))) -> (zext (smaller load x))
    // fold (zext (truncate (srl (load x), c))) -> (zext (smaller load (x+c/n)))
    if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
      SDNode *oye = N0.getOperand(0).getNode();
      if (NarrowLoad.getNode() != N0.getNode()) {
        CombineTo(N0.getNode(), NarrowLoad);
        // CombineTo deleted the truncate, if needed, but not what's under it.
        AddToWorklist(oye);
      }
      return SDValue(N, 0); // Return N so it doesn't get rechecked!
    }

    EVT SrcVT = N0.getOperand(0).getValueType();
    EVT MinVT = N0.getValueType();

    // Try to mask before the extension to avoid having to generate a larger
    // mask, possibly over several sub-vectors.
    if (SrcVT.bitsLT(VT) && VT.isVector()) {
      if (!LegalOperations || (TLI.isOperationLegal(ISD::AND, SrcVT) &&
                               TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) {
        SDValue Op = N0.getOperand(0);
        Op = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT);
        AddToWorklist(Op.getNode());
        SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, SDLoc(N), VT);
        // Transfer the debug info; the new node is equivalent to N0.
        DAG.transferDbgValues(N0, ZExtOrTrunc);
        return ZExtOrTrunc;
      }
    }

    if (!LegalOperations || TLI.isOperationLegal(ISD::AND, VT)) {
      SDValue Op = DAG.getAnyExtOrTrunc(N0.getOperand(0), SDLoc(N), VT);
      AddToWorklist(Op.getNode());
      SDValue And = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT);
      // We may safely transfer the debug info describing the truncate node over
      // to the equivalent and operation.
      DAG.transferDbgValues(N0, And);
      return And;
    }
  }

  // Fold (zext (and (trunc x), cst)) -> (and x, cst),
  // if either of the casts is not free.
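  // Illustrative instance (widths assumed for exposition):
  //   (zext (and (trunc X:i64 to i32), Constant:i32<255>) to i64)
  // --> (and X, Constant:i64<255>)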
  if (N0.getOpcode() == ISD::AND &&
      N0.getOperand(0).getOpcode() == ISD::TRUNCATE &&
      N0.getOperand(1).getOpcode() == ISD::Constant &&
      (!TLI.isTruncateFree(N0.getOperand(0).getOperand(0).getValueType(),
                           N0.getValueType()) ||
       !TLI.isZExtFree(N0.getValueType(), VT))) {
    SDValue X = N0.getOperand(0).getOperand(0);
    X = DAG.getAnyExtOrTrunc(X, SDLoc(X), VT);
    APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
    SDLoc DL(N);
    return DAG.getNode(ISD::AND, DL, VT,
                       X, DAG.getConstant(Mask, DL, VT));
  }

  // Try to simplify (zext (load x)).
  if (SDValue foldedExt =
          tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
                             ISD::ZEXTLOAD, ISD::ZERO_EXTEND))
    return foldedExt;

  if (SDValue foldedExt =
      tryToFoldExtOfMaskedLoad(DAG, TLI, VT, N, N0, ISD::ZEXTLOAD,
                               ISD::ZERO_EXTEND))
    return foldedExt;

  // fold (zext (load x)) to multiple smaller zextloads.
  // Only on illegal but splittable vectors.
  if (SDValue ExtLoad = CombineExtLoad(N))
    return ExtLoad;

  // fold (zext (and/or/xor (load x), cst)) ->
  //      (and/or/xor (zextload x), (zext cst))
  // Unless (and (load x) cst) will match as a zextload already and has
  // additional users.
  if ((N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR ||
       N0.getOpcode() == ISD::XOR) &&
      isa<LoadSDNode>(N0.getOperand(0)) &&
      N0.getOperand(1).getOpcode() == ISD::Constant &&
      (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
    LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
    EVT MemVT = LN00->getMemoryVT();
    if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) &&
        LN00->getExtensionType() != ISD::SEXTLOAD && LN00->isUnindexed()) {
      bool DoXform = true;
      SmallVector<SDNode*, 4> SetCCs;
      if (!N0.hasOneUse()) {
        if (N0.getOpcode() == ISD::AND) {
          auto *AndC = cast<ConstantSDNode>(N0.getOperand(1));
          EVT LoadResultTy = AndC->getValueType(0);
          EVT ExtVT;
          if (isAndLoadExtLoad(AndC, LN00, LoadResultTy, ExtVT))
            DoXform = false;
        }
      }
      if (DoXform)
        DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
                                          ISD::ZERO_EXTEND, SetCCs, TLI);
      if (DoXform) {
        SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(LN00), VT,
                                         LN00->getChain(), LN00->getBasePtr(),
                                         LN00->getMemoryVT(),
                                         LN00->getMemOperand());
        APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
        SDLoc DL(N);
        SDValue And = DAG.getNode(N0.getOpcode(), DL, VT,
                                  ExtLoad, DAG.getConstant(Mask, DL, VT));
        ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::ZERO_EXTEND);
        bool NoReplaceTruncAnd = !N0.hasOneUse();
        bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
        CombineTo(N, And);
        // If N0 has multiple uses, change other uses as well.
        if (NoReplaceTruncAnd) {
          SDValue TruncAnd =
              DAG.getNode(ISD::TRUNCATE, DL, N0.getValueType(), And);
          CombineTo(N0.getNode(), TruncAnd);
        }
        if (NoReplaceTrunc) {
          DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1));
        } else {
          SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00),
                                      LN00->getValueType(0), ExtLoad);
          CombineTo(LN00, Trunc, ExtLoad.getValue(1));
        }
        return SDValue(N, 0); // Return N so it doesn't get rechecked!
      }
    }
  }

  // fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
  //      (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
  if (SDValue ZExtLoad = CombineZExtLogicopShiftLoad(N))
    return ZExtLoad;

  // Try to simplify (zext (zextload x)).
  if (SDValue foldedExt = tryToFoldExtOfExtload(
          DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::ZEXTLOAD))
    return foldedExt;
11457 
11458   if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
11459     return V;
11460 
11461   if (N0.getOpcode() == ISD::SETCC) {
11462     // Only do this before legalize for now.
11463     if (!LegalOperations && VT.isVector() &&
11464         N0.getValueType().getVectorElementType() == MVT::i1) {
11465       EVT N00VT = N0.getOperand(0).getValueType();
11466       if (getSetCCResultType(N00VT) == N0.getValueType())
11467         return SDValue();
11468 
11469       // We know that the number of elements of the result is the same as the
11470       // number of elements of the compare (and the number of elements of the
11471       // compare result, for that matter). Check that they are the same size.
11472       // If so, we know that the element size of the zext'd result matches the
11473       // element size of the compare operands.
11474       SDLoc DL(N);
11475       if (VT.getSizeInBits() == N00VT.getSizeInBits()) {
11476         // zext(setcc) -> zext_in_reg(vsetcc) for vectors.
11477         SDValue VSetCC = DAG.getNode(ISD::SETCC, DL, VT, N0.getOperand(0),
11478                                      N0.getOperand(1), N0.getOperand(2));
11479         return DAG.getZeroExtendInReg(VSetCC, DL, N0.getValueType());
11480       }
11481 
11482       // If the desired elements are smaller or larger than the source
11483       // elements we can use a matching integer vector type and then
11484       // truncate/any extend followed by zext_in_reg.
11485       EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
11486       SDValue VsetCC =
11487           DAG.getNode(ISD::SETCC, DL, MatchingVectorType, N0.getOperand(0),
11488                       N0.getOperand(1), N0.getOperand(2));
11489       return DAG.getZeroExtendInReg(DAG.getAnyExtOrTrunc(VsetCC, DL, VT), DL,
11490                                     N0.getValueType());
11491     }
11492 
11493     // zext(setcc x,y,cc) -> zext(select_cc x, y, true, false, cc)
11494     SDLoc DL(N);
11495     EVT N0VT = N0.getValueType();
11496     EVT N00VT = N0.getOperand(0).getValueType();
11497     if (SDValue SCC = SimplifySelectCC(
11498             DL, N0.getOperand(0), N0.getOperand(1),
11499             DAG.getBoolConstant(true, DL, N0VT, N00VT),
11500             DAG.getBoolConstant(false, DL, N0VT, N00VT),
11501             cast<CondCodeSDNode>(N0.getOperand(2))->get(), true))
11502       return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, SCC);
11503   }
11504 
11505   // (zext (shl/srl (zext x), cst)) -> (shl/srl (zext x), cst)
11506   if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
11507       isa<ConstantSDNode>(N0.getOperand(1)) &&
11508       N0.getOperand(0).getOpcode() == ISD::ZERO_EXTEND &&
11509       N0.hasOneUse()) {
11510     SDValue ShAmt = N0.getOperand(1);
11511     if (N0.getOpcode() == ISD::SHL) {
11512       SDValue InnerZExt = N0.getOperand(0);
11513       // If the original shl may be shifting out bits, do not perform this
11514       // transformation.
11515       unsigned KnownZeroBits = InnerZExt.getValueSizeInBits() -
11516         InnerZExt.getOperand(0).getValueSizeInBits();
11517       if (cast<ConstantSDNode>(ShAmt)->getAPIntValue().ugt(KnownZeroBits))
11518         return SDValue();
11519     }
11520 
11521     SDLoc DL(N);
11522 
11523     // Ensure that the shift amount is wide enough for the shifted value.
11524     if (Log2_32_Ceil(VT.getSizeInBits()) > ShAmt.getValueSizeInBits())
11525       ShAmt = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, ShAmt);
11526 
11527     return DAG.getNode(N0.getOpcode(), DL, VT,
11528                        DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0)),
11529                        ShAmt);
11530   }
11531 
11532   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
11533     return NewVSel;
11534 
11535   if (SDValue NewCtPop = widenCtPop(N, DAG))
11536     return NewCtPop;
11537 
11538   if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG))
11539     return Res;
11540 
11541   return SDValue();
11542 }
11543 
11544 SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
11545   SDValue N0 = N->getOperand(0);
11546   EVT VT = N->getValueType(0);
11547 
11548   if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
11549     return Res;
11550 
11551   // fold (aext (aext x)) -> (aext x)
11552   // fold (aext (zext x)) -> (zext x)
11553   // fold (aext (sext x)) -> (sext x)
11554   if (N0.getOpcode() == ISD::ANY_EXTEND  ||
11555       N0.getOpcode() == ISD::ZERO_EXTEND ||
11556       N0.getOpcode() == ISD::SIGN_EXTEND)
11557     return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, N0.getOperand(0));
11558 
11559   // fold (aext (truncate (load x))) -> (aext (smaller load x))
11560   // fold (aext (truncate (srl (load x), c))) -> (aext (small load (x+c/n)))
11561   if (N0.getOpcode() == ISD::TRUNCATE) {
11562     if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
11563       SDNode *oye = N0.getOperand(0).getNode();
11564       if (NarrowLoad.getNode() != N0.getNode()) {
11565         CombineTo(N0.getNode(), NarrowLoad);
11566         // CombineTo deleted the truncate, if needed, but not what's under it.
11567         AddToWorklist(oye);
11568       }
11569       return SDValue(N, 0);   // Return N so it doesn't get rechecked!
11570     }
11571   }
11572 
11573   // fold (aext (truncate x))
11574   if (N0.getOpcode() == ISD::TRUNCATE)
11575     return DAG.getAnyExtOrTrunc(N0.getOperand(0), SDLoc(N), VT);
11576 
11577   // Fold (aext (and (trunc x), cst)) -> (and x, cst)
11578   // if the trunc is not free.
11579   if (N0.getOpcode() == ISD::AND &&
11580       N0.getOperand(0).getOpcode() == ISD::TRUNCATE &&
11581       N0.getOperand(1).getOpcode() == ISD::Constant &&
11582       !TLI.isTruncateFree(N0.getOperand(0).getOperand(0).getValueType(),
11583                           N0.getValueType())) {
11584     SDLoc DL(N);
11585     SDValue X = N0.getOperand(0).getOperand(0);
11586     X = DAG.getAnyExtOrTrunc(X, DL, VT);
11587     APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
11588     return DAG.getNode(ISD::AND, DL, VT,
11589                        X, DAG.getConstant(Mask, DL, VT));
11590   }
11591 
11592   // fold (aext (load x)) -> (aext (truncate (extload x)))
11593   // None of the supported targets knows how to perform load and any_ext
11594   // on vectors in one instruction, so attempt to fold to zext instead.
11595   if (VT.isVector()) {
11596     // Try to simplify (zext (load x)).
11597     if (SDValue foldedExt =
11598             tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
11599                                ISD::ZEXTLOAD, ISD::ZERO_EXTEND))
11600       return foldedExt;
11601   } else if (ISD::isNON_EXTLoad(N0.getNode()) &&
11602              ISD::isUNINDEXEDLoad(N0.getNode()) &&
11603              TLI.isLoadExtLegal(ISD::EXTLOAD, VT, N0.getValueType())) {
11604     bool DoXform = true;
11605     SmallVector<SDNode *, 4> SetCCs;
11606     if (!N0.hasOneUse())
11607       DoXform =
11608           ExtendUsesToFormExtLoad(VT, N, N0, ISD::ANY_EXTEND, SetCCs, TLI);
11609     if (DoXform) {
11610       LoadSDNode *LN0 = cast<LoadSDNode>(N0);
11611       SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, SDLoc(N), VT,
11612                                        LN0->getChain(), LN0->getBasePtr(),
11613                                        N0.getValueType(), LN0->getMemOperand());
11614       ExtendSetCCUses(SetCCs, N0, ExtLoad, ISD::ANY_EXTEND);
11615       // If the load value is used only by N, replace it via CombineTo N.
11616       bool NoReplaceTrunc = N0.hasOneUse();
11617       CombineTo(N, ExtLoad);
11618       if (NoReplaceTrunc) {
11619         DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
11620         recursivelyDeleteUnusedNodes(LN0);
11621       } else {
11622         SDValue Trunc =
11623             DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), ExtLoad);
11624         CombineTo(LN0, Trunc, ExtLoad.getValue(1));
11625       }
11626       return SDValue(N, 0); // Return N so it doesn't get rechecked!
11627     }
11628   }
11629 
11630   // fold (aext (zextload x)) -> (aext (truncate (zextload x)))
11631   // fold (aext (sextload x)) -> (aext (truncate (sextload x)))
11632   // fold (aext ( extload x)) -> (aext (truncate (extload  x)))
11633   if (N0.getOpcode() == ISD::LOAD && !ISD::isNON_EXTLoad(N0.getNode()) &&
11634       ISD::isUNINDEXEDLoad(N0.getNode()) && N0.hasOneUse()) {
11635     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
11636     ISD::LoadExtType ExtType = LN0->getExtensionType();
11637     EVT MemVT = LN0->getMemoryVT();
11638     if (!LegalOperations || TLI.isLoadExtLegal(ExtType, VT, MemVT)) {
11639       SDValue ExtLoad = DAG.getExtLoad(ExtType, SDLoc(N),
11640                                        VT, LN0->getChain(), LN0->getBasePtr(),
11641                                        MemVT, LN0->getMemOperand());
11642       CombineTo(N, ExtLoad);
11643       DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
11644       recursivelyDeleteUnusedNodes(LN0);
11645       return SDValue(N, 0);   // Return N so it doesn't get rechecked!
11646     }
11647   }
11648 
11649   if (N0.getOpcode() == ISD::SETCC) {
11650     // For vectors:
11651     // aext(setcc) -> vsetcc
11652     // aext(setcc) -> truncate(vsetcc)
11653     // aext(setcc) -> aext(vsetcc)
11654     // Only do this before legalize for now.
11655     if (VT.isVector() && !LegalOperations) {
11656       EVT N00VT = N0.getOperand(0).getValueType();
11657       if (getSetCCResultType(N00VT) == N0.getValueType())
11658         return SDValue();
11659 
11660       // We know that the number of elements of the result is the same as the
11661       // number of elements of the compare (and the number of elements of the
11662       // compare result, for that matter). Check that they are the same size.
11663       // If so, we know that the element size of the any-extended result
11664       // matches the element size of the compare operands.
11665       if (VT.getSizeInBits() == N00VT.getSizeInBits())
11666         return DAG.getSetCC(SDLoc(N), VT, N0.getOperand(0),
11667                              N0.getOperand(1),
11668                              cast<CondCodeSDNode>(N0.getOperand(2))->get());
11669 
11670       // If the desired elements are smaller or larger than the source
11671       // elements we can use a matching integer vector type and then
11672       // truncate/any extend
11673       EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
11674       SDValue VsetCC =
11675         DAG.getSetCC(SDLoc(N), MatchingVectorType, N0.getOperand(0),
11676                       N0.getOperand(1),
11677                       cast<CondCodeSDNode>(N0.getOperand(2))->get());
11678       return DAG.getAnyExtOrTrunc(VsetCC, SDLoc(N), VT);
11679     }
11680 
11681     // aext(setcc x,y,cc) -> select_cc x, y, 1, 0, cc
11682     SDLoc DL(N);
11683     if (SDValue SCC = SimplifySelectCC(
11684             DL, N0.getOperand(0), N0.getOperand(1), DAG.getConstant(1, DL, VT),
11685             DAG.getConstant(0, DL, VT),
11686             cast<CondCodeSDNode>(N0.getOperand(2))->get(), true))
11687       return SCC;
11688   }
11689 
11690   if (SDValue NewCtPop = widenCtPop(N, DAG))
11691     return NewCtPop;
11692 
11693   if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG))
11694     return Res;
11695 
11696   return SDValue();
11697 }
11698 
11699 SDValue DAGCombiner::visitAssertExt(SDNode *N) {
11700   unsigned Opcode = N->getOpcode();
11701   SDValue N0 = N->getOperand(0);
11702   SDValue N1 = N->getOperand(1);
11703   EVT AssertVT = cast<VTSDNode>(N1)->getVT();
11704 
11705   // fold (assert?ext (assert?ext x, vt), vt) -> (assert?ext x, vt)
11706   if (N0.getOpcode() == Opcode &&
11707       AssertVT == cast<VTSDNode>(N0.getOperand(1))->getVT())
11708     return N0;
11709 
11710   if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
11711       N0.getOperand(0).getOpcode() == Opcode) {
11712     // We have an assert, truncate, assert sandwich. Fold them into a single,
11713     // stronger assert by applying the smaller of the two asserted types to
11714     // the larger source value. This eliminates the later assert:
11715     // assert (trunc (assert X, i8) to iN), i1 --> trunc (assert X, i1) to iN
11716     // assert (trunc (assert X, i1) to iN), i8 --> trunc (assert X, i1) to iN
11717     SDValue BigA = N0.getOperand(0);
11718     EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT();
11719     assert(BigA_AssertVT.bitsLE(N0.getValueType()) &&
11720            "Asserting zero/sign-extended bits to a type larger than the "
11721            "truncated destination does not provide information");
11722 
11723     SDLoc DL(N);
11724     EVT MinAssertVT = AssertVT.bitsLT(BigA_AssertVT) ? AssertVT : BigA_AssertVT;
11725     SDValue MinAssertVTVal = DAG.getValueType(MinAssertVT);
11726     SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(),
11727                                     BigA.getOperand(0), MinAssertVTVal);
11728     return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert);
11729   }
11730 
11731   // If we have (AssertZext (truncate (AssertSext X, iX)), iY) and Y is smaller
11732   // than X, just move the AssertZext in front of the truncate and drop the
11733   // AssertSext.
11734   if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
11735       N0.getOperand(0).getOpcode() == ISD::AssertSext &&
11736       Opcode == ISD::AssertZext) {
11737     SDValue BigA = N0.getOperand(0);
11738     EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT();
11739     assert(BigA_AssertVT.bitsLE(N0.getValueType()) &&
11740            "Asserting zero/sign-extended bits to a type larger than the "
11741            "truncated destination does not provide information");
11742 
11743     if (AssertVT.bitsLT(BigA_AssertVT)) {
11744       SDLoc DL(N);
11745       SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(),
11746                                       BigA.getOperand(0), N1);
11747       return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert);
11748     }
11749   }
11750 
11751   return SDValue();
11752 }
11753 
11754 SDValue DAGCombiner::visitAssertAlign(SDNode *N) {
11755   SDLoc DL(N);
11756 
11757   Align AL = cast<AssertAlignSDNode>(N)->getAlign();
11758   SDValue N0 = N->getOperand(0);
11759 
11760   // Fold (assertalign (assertalign x, AL0), AL1) ->
11761   // (assertalign x, max(AL0, AL1))
11762   if (auto *AAN = dyn_cast<AssertAlignSDNode>(N0))
11763     return DAG.getAssertAlign(DL, N0.getOperand(0),
11764                               std::max(AL, AAN->getAlign()));
11765 
11766   // In rare cases, there are trivial arithmetic ops in source operands. Sink
11767   // this assert down to source operands so that those arithmetic ops could be
11768   // exposed to the DAG combining.
11769   switch (N0.getOpcode()) {
11770   default:
11771     break;
11772   case ISD::ADD:
11773   case ISD::SUB: {
11774     unsigned AlignShift = Log2(AL);
11775     SDValue LHS = N0.getOperand(0);
11776     SDValue RHS = N0.getOperand(1);
11777     unsigned LHSAlignShift = DAG.computeKnownBits(LHS).countMinTrailingZeros();
11778     unsigned RHSAlignShift = DAG.computeKnownBits(RHS).countMinTrailingZeros();
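    // The add/sub result is asserted to be AL-aligned, so if one operand is
    // already known to be AL-aligned the other operand must be AL-aligned as
    // well. Attach the assertion to whichever operand lacks it and rebuild
    // the arithmetic so both operands expose the alignment.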
11779     if (LHSAlignShift >= AlignShift || RHSAlignShift >= AlignShift) {
11780       if (LHSAlignShift < AlignShift)
11781         LHS = DAG.getAssertAlign(DL, LHS, AL);
11782       if (RHSAlignShift < AlignShift)
11783         RHS = DAG.getAssertAlign(DL, RHS, AL);
11784       return DAG.getNode(N0.getOpcode(), DL, N0.getValueType(), LHS, RHS);
11785     }
11786     break;
11787   }
11788   }
11789 
11790   return SDValue();
11791 }
11792 
11793 /// If the result of a wider load is shifted right by N bits and then
11794 /// truncated to a narrower type, where N is a multiple of the number of bits
11795 /// in the narrower type, transform it to a narrower load from address + N /
11796 /// (number of bits in the new type). Also narrow the load if the result is
11797 /// masked with an AND to effectively produce a smaller type. If the result is
11798 /// to be extended, also fold the extension to form an extending load.
11799 SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
11800   unsigned Opc = N->getOpcode();
11801 
11802   ISD::LoadExtType ExtType = ISD::NON_EXTLOAD;
11803   SDValue N0 = N->getOperand(0);
11804   EVT VT = N->getValueType(0);
11805   EVT ExtVT = VT;
11806 
11807   // This transformation isn't valid for vector loads.
11808   if (VT.isVector())
11809     return SDValue();
11810 
11811   unsigned ShAmt = 0;
11812   bool HasShiftedOffset = false;
11813   // Special case: SIGN_EXTEND_INREG is basically truncating to ExtVT then
11814   // extended to VT.
11815   if (Opc == ISD::SIGN_EXTEND_INREG) {
11816     ExtType = ISD::SEXTLOAD;
11817     ExtVT = cast<VTSDNode>(N->getOperand(1))->getVT();
11818   } else if (Opc == ISD::SRL) {
11819     // Another special-case: SRL is basically zero-extending a narrower value,
11820     // or it may be shifting a higher subword, half or byte into the lowest
11821     // bits.
11822     ExtType = ISD::ZEXTLOAD;
11823     N0 = SDValue(N, 0);
11824 
11825     auto *LN0 = dyn_cast<LoadSDNode>(N0.getOperand(0));
11826     auto *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1));
11827     if (!N01 || !LN0)
11828       return SDValue();
11829 
11830     uint64_t ShiftAmt = N01->getZExtValue();
11831     uint64_t MemoryWidth = LN0->getMemoryVT().getScalarSizeInBits();
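    // If a zext/any-ext load is wider than the shift amount, then
    // MemoryWidth - ShiftAmt meaningful bits remain after the shift and we
    // can narrow to that; otherwise derive the surviving width from VT.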
11832     if (LN0->getExtensionType() != ISD::SEXTLOAD && MemoryWidth > ShiftAmt)
11833       ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShiftAmt);
11834     else
11835       ExtVT = EVT::getIntegerVT(*DAG.getContext(),
11836                                 VT.getScalarSizeInBits() - ShiftAmt);
11837   } else if (Opc == ISD::AND) {
11838     // An AND with a constant mask is the same as a truncate + zero-extend.
11839     auto AndC = dyn_cast<ConstantSDNode>(N->getOperand(1));
11840     if (!AndC)
11841       return SDValue();
11842 
11843     const APInt &Mask = AndC->getAPIntValue();
11844     unsigned ActiveBits = 0;
11845     if (Mask.isMask()) {
11846       ActiveBits = Mask.countTrailingOnes();
11847     } else if (Mask.isShiftedMask()) {
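      // E.g. a mask of 0x00ff00 yields ShAmt == 8 and ActiveBits == 8: load
      // one byte at offset ShAmt / 8 and shift the result back up by ShAmt
      // afterwards (tracked by HasShiftedOffset).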
11848       ShAmt = Mask.countTrailingZeros();
11849       APInt ShiftedMask = Mask.lshr(ShAmt);
11850       ActiveBits = ShiftedMask.countTrailingOnes();
11851       HasShiftedOffset = true;
11852     } else
11853       return SDValue();
11854 
11855     ExtType = ISD::ZEXTLOAD;
11856     ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
11857   }
11858 
11859   if (N0.getOpcode() == ISD::SRL && N0.hasOneUse()) {
11860     SDValue SRL = N0;
11861     if (auto *ConstShift = dyn_cast<ConstantSDNode>(SRL.getOperand(1))) {
11862       ShAmt = ConstShift->getZExtValue();
11863       unsigned EVTBits = ExtVT.getScalarSizeInBits();
11864       // Is the shift amount a multiple of the size of ExtVT?
11865       if ((ShAmt & (EVTBits-1)) == 0) {
11866         N0 = N0.getOperand(0);
11867         // Is the load width a multiple of the size of ExtVT?
11868         if ((N0.getScalarValueSizeInBits() & (EVTBits - 1)) != 0)
11869           return SDValue();
11870       }
11871 
11872       // At this point, we must have a load or else we can't do the transform.
11873       auto *LN0 = dyn_cast<LoadSDNode>(N0);
11874       if (!LN0) return SDValue();
11875 
11876       // Because an SRL must be assumed to *need* to zero-extend the high bits
11877       // (as opposed to anyext the high bits), we can't combine the zextload
11878       // lowering of SRL and an sextload.
11879       if (LN0->getExtensionType() == ISD::SEXTLOAD)
11880         return SDValue();
11881 
11882       // If the shift amount is larger than the input type then we're not
11883       // accessing any of the loaded bytes.  If the load was a zextload/extload
11884       // then the result of the shift+trunc is zero/undef (handled elsewhere).
11885       if (ShAmt >= LN0->getMemoryVT().getSizeInBits())
11886         return SDValue();
11887 
11888       // If the SRL is only used by a masking AND, we may be able to adjust
11889       // the ExtVT to make the AND redundant.
11890       SDNode *Mask = *(SRL->use_begin());
11891       if (Mask->getOpcode() == ISD::AND &&
11892           isa<ConstantSDNode>(Mask->getOperand(1))) {
11893         const APInt& ShiftMask = Mask->getConstantOperandAPInt(1);
11894         if (ShiftMask.isMask()) {
11895           EVT MaskedVT = EVT::getIntegerVT(*DAG.getContext(),
11896                                            ShiftMask.countTrailingOnes());
11897           // If the mask is smaller, recompute the type.
11898           if ((ExtVT.getScalarSizeInBits() > MaskedVT.getScalarSizeInBits()) &&
11899               TLI.isLoadExtLegal(ExtType, N0.getValueType(), MaskedVT))
11900             ExtVT = MaskedVT;
11901         }
11902       }
11903     }
11904   }
11905 
11906   // If the load is shifted left (and the result isn't shifted back right),
11907   // we can fold the truncate through the shift.
11908   unsigned ShLeftAmt = 0;
11909   if (ShAmt == 0 && N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
11910       ExtVT == VT && TLI.isNarrowingProfitable(N0.getValueType(), VT)) {
11911     if (ConstantSDNode *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
11912       ShLeftAmt = N01->getZExtValue();
11913       N0 = N0.getOperand(0);
11914     }
11915   }
11916 
11917   // If we haven't found a load, we can't narrow it.
11918   if (!isa<LoadSDNode>(N0))
11919     return SDValue();
11920 
11921   LoadSDNode *LN0 = cast<LoadSDNode>(N0);
11922   // Reducing the width of a volatile load is illegal.  For atomics, we may be
11923   // able to reduce the width provided we never widen again. (see D66309)
11924   if (!LN0->isSimple() ||
11925       !isLegalNarrowLdSt(LN0, ExtType, ExtVT, ShAmt))
11926     return SDValue();
11927 
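  // Convert a shift amount counted from the least significant end into the
  // equivalent offset for big-endian layouts, where the narrow value's bytes
  // sit at the opposite end of the wide loaded value.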
11928   auto AdjustBigEndianShift = [&](unsigned ShAmt) {
11929     unsigned LVTStoreBits =
11930         LN0->getMemoryVT().getStoreSizeInBits().getFixedSize();
11931     unsigned EVTStoreBits = ExtVT.getStoreSizeInBits().getFixedSize();
11932     return LVTStoreBits - EVTStoreBits - ShAmt;
11933   };
11934 
11935   // For big endian targets, we need to adjust the offset to the pointer to
11936   // load the correct bytes.
11937   if (DAG.getDataLayout().isBigEndian())
11938     ShAmt = AdjustBigEndianShift(ShAmt);
11939 
11940   uint64_t PtrOff = ShAmt / 8;
11941   Align NewAlign = commonAlignment(LN0->getAlign(), PtrOff);
11942   SDLoc DL(LN0);
11943   // The original load itself didn't wrap, so an offset within it doesn't.
11944   SDNodeFlags Flags;
11945   Flags.setNoUnsignedWrap(true);
11946   SDValue NewPtr = DAG.getMemBasePlusOffset(LN0->getBasePtr(),
11947                                             TypeSize::Fixed(PtrOff), DL, Flags);
11948   AddToWorklist(NewPtr.getNode());
11949 
11950   SDValue Load;
11951   if (ExtType == ISD::NON_EXTLOAD)
11952     Load = DAG.getLoad(VT, DL, LN0->getChain(), NewPtr,
11953                        LN0->getPointerInfo().getWithOffset(PtrOff), NewAlign,
11954                        LN0->getMemOperand()->getFlags(), LN0->getAAInfo());
11955   else
11956     Load = DAG.getExtLoad(ExtType, DL, VT, LN0->getChain(), NewPtr,
11957                           LN0->getPointerInfo().getWithOffset(PtrOff), ExtVT,
11958                           NewAlign, LN0->getMemOperand()->getFlags(),
11959                           LN0->getAAInfo());
11960 
11961   // Replace the old load's chain with the new load's chain.
11962   WorklistRemover DeadNodes(*this);
11963   DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));
11964 
11965   // Shift the result left, if we've swallowed a left shift.
11966   SDValue Result = Load;
11967   if (ShLeftAmt != 0) {
11968     EVT ShImmTy = getShiftAmountTy(Result.getValueType());
11969     if (!isUIntN(ShImmTy.getScalarSizeInBits(), ShLeftAmt))
11970       ShImmTy = VT;
11971     // If the shift amount is as large as the result size (but, presumably,
11972     // no larger than the source) then the useful bits of the result are
11973     // zero; we can't simply return the shortened shift, because the result
11974     // of that operation is undefined.
11975     if (ShLeftAmt >= VT.getScalarSizeInBits())
11976       Result = DAG.getConstant(0, DL, VT);
11977     else
11978       Result = DAG.getNode(ISD::SHL, DL, VT,
11979                           Result, DAG.getConstant(ShLeftAmt, DL, ShImmTy));
11980   }
11981 
11982   if (HasShiftedOffset) {
11983     // Recalculate the shift amount after it has been altered to calculate
11984     // the offset.
11985     if (DAG.getDataLayout().isBigEndian())
11986       ShAmt = AdjustBigEndianShift(ShAmt);
11987 
11988     // We're using a shifted mask, so the load now has an offset. This means
11989     // that the data has been loaded into lower bits of the register than it
11990     // would have been before, so we need to shl the loaded data back into the
11991     // correct position.
11992     SDValue ShiftC = DAG.getConstant(ShAmt, DL, VT);
11993     Result = DAG.getNode(ISD::SHL, DL, VT, Result, ShiftC);
11994     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result);
11995   }
11996 
11997   // Return the new loaded value.
11998   return Result;
11999 }
12000 
12001 SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
12002   SDValue N0 = N->getOperand(0);
12003   SDValue N1 = N->getOperand(1);
12004   EVT VT = N->getValueType(0);
12005   EVT ExtVT = cast<VTSDNode>(N1)->getVT();
12006   unsigned VTBits = VT.getScalarSizeInBits();
12007   unsigned ExtVTBits = ExtVT.getScalarSizeInBits();
12008 
12009   // sext_in_reg(undef) = 0 because the top bits will all be the same.
12010   if (N0.isUndef())
12011     return DAG.getConstant(0, SDLoc(N), VT);
12012 
12013   // fold (sext_in_reg c1) -> c1
12014   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
12015     return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, N0, N1);
12016 
12017   // If the input is already sign extended, just drop the extension.
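  // sext_in_reg replicates bit ExtVTBits - 1 into the top VTBits - ExtVTBits
  // bits, so it is a no-op when at least VTBits - ExtVTBits + 1 high bits are
  // already known to be copies of the sign bit.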
12018   if (DAG.ComputeNumSignBits(N0) >= (VTBits - ExtVTBits + 1))
12019     return N0;
12020 
12021   // fold (sext_in_reg (sext_in_reg x, VT2), VT1) -> (sext_in_reg x, minVT) pt2
12022   if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
12023       ExtVT.bitsLT(cast<VTSDNode>(N0.getOperand(1))->getVT()))
12024     return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, N0.getOperand(0),
12025                        N1);
12026 
12027   // fold (sext_in_reg (sext x)) -> (sext x)
12028   // fold (sext_in_reg (aext x)) -> (sext x)
12029   // if x is small enough or if we know that x has more than 1 sign bit and the
12030   // sign_extend_inreg is extending from one of them.
12031   if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) {
12032     SDValue N00 = N0.getOperand(0);
12033     unsigned N00Bits = N00.getScalarValueSizeInBits();
12034     if ((N00Bits <= ExtVTBits ||
12035          (N00Bits - DAG.ComputeNumSignBits(N00)) < ExtVTBits) &&
12036         (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT)))
12037       return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00);
12038   }
12039 
12040   // fold (sext_in_reg (*_extend_vector_inreg x)) -> (sext_vector_inreg x)
12041   // if x is small enough or if we know that x has more than 1 sign bit and the
12042   // sign_extend_inreg is extending from one of them.
12043   if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
12044       N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG ||
12045       N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG) {
12046     SDValue N00 = N0.getOperand(0);
12047     unsigned N00Bits = N00.getScalarValueSizeInBits();
12048     unsigned DstElts = N0.getValueType().getVectorMinNumElements();
12049     unsigned SrcElts = N00.getValueType().getVectorMinNumElements();
12050     bool IsZext = N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG;
12051     APInt DemandedSrcElts = APInt::getLowBitsSet(SrcElts, DstElts);
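    // Only the low DstElts elements of the source survive the
    // *_extend_vector_inreg, so sign-bit information is only required for
    // those lanes.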
12052     if ((N00Bits == ExtVTBits ||
12053          (!IsZext && (N00Bits < ExtVTBits ||
12054                       (N00Bits - DAG.ComputeNumSignBits(N00, DemandedSrcElts)) <
12055                           ExtVTBits))) &&
12056         (!LegalOperations ||
12057          TLI.isOperationLegal(ISD::SIGN_EXTEND_VECTOR_INREG, VT)))
12058       return DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, SDLoc(N), VT, N00);
12059   }
12060 
12061   // fold (sext_in_reg (zext x)) -> (sext x)
12062   // iff we are extending the source sign bit.
12063   if (N0.getOpcode() == ISD::ZERO_EXTEND) {
12064     SDValue N00 = N0.getOperand(0);
12065     if (N00.getScalarValueSizeInBits() == ExtVTBits &&
12066         (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT)))
12067       return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00);
12068   }
12069 
12070   // fold (sext_in_reg x) -> (zext_in_reg x) if the sign bit is known zero.
12071   if (DAG.MaskedValueIsZero(N0, APInt::getOneBitSet(VTBits, ExtVTBits - 1)))
12072     return DAG.getZeroExtendInReg(N0, SDLoc(N), ExtVT);
12073 
12074   // fold operands of sext_in_reg based on knowledge that the top bits are not
12075   // demanded.
12076   if (SimplifyDemandedBits(SDValue(N, 0)))
12077     return SDValue(N, 0);
12078 
12079   // fold (sext_in_reg (load x)) -> (smaller sextload x)
12080   // fold (sext_in_reg (srl (load x), c)) -> (smaller sextload (x+c/evtbits))
12081   if (SDValue NarrowLoad = ReduceLoadWidth(N))
12082     return NarrowLoad;
12083 
12084   // fold (sext_in_reg (srl X, 24), i8) -> (sra X, 24)
12085   // fold (sext_in_reg (srl X, 23), i8) -> (sra X, 23) iff possible.
12086   // We already fold "(sext_in_reg (srl X, 25), i8) -> srl X, 25" above.
12087   if (N0.getOpcode() == ISD::SRL) {
12088     if (auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1)))
12089       if (ShAmt->getAPIntValue().ule(VTBits - ExtVTBits)) {
12090         // We can turn this into an SRA iff the input to the SRL is already sign
12091         // extended enough.
12092         unsigned InSignBits = DAG.ComputeNumSignBits(N0.getOperand(0));
12093         if (((VTBits - ExtVTBits) - ShAmt->getZExtValue()) < InSignBits)
12094           return DAG.getNode(ISD::SRA, SDLoc(N), VT, N0.getOperand(0),
12095                              N0.getOperand(1));
12096       }
12097   }
12098 
12099   // fold (sext_inreg (extload x)) -> (sextload x)
12100   // If sextload is not supported by target, we can only do the combine when
12101   // load has one use. Doing otherwise can block folding the extload with other
12102   // extends that the target does support.
12103   if (ISD::isEXTLoad(N0.getNode()) &&
12104       ISD::isUNINDEXEDLoad(N0.getNode()) &&
12105       ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
12106       ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple() &&
12107         N0.hasOneUse()) ||
12108        TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT))) {
12109     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
12110     SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(N), VT,
12111                                      LN0->getChain(),
12112                                      LN0->getBasePtr(), ExtVT,
12113                                      LN0->getMemOperand());
12114     CombineTo(N, ExtLoad);
12115     CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
12116     AddToWorklist(ExtLoad.getNode());
12117     return SDValue(N, 0);   // Return N so it doesn't get rechecked!
12118   }
12119 
12120   // fold (sext_inreg (zextload x)) -> (sextload x) iff load has one use
12121   if (ISD::isZEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) &&
12122       N0.hasOneUse() &&
12123       ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
12124       ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple()) &&
12125        TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT))) {
12126     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
12127     SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(N), VT,
12128                                      LN0->getChain(),
12129                                      LN0->getBasePtr(), ExtVT,
12130                                      LN0->getMemOperand());
12131     CombineTo(N, ExtLoad);
12132     CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
12133     return SDValue(N, 0);   // Return N so it doesn't get rechecked!
12134   }
12135 
12136   // fold (sext_inreg (masked_load x)) -> (sext_masked_load x)
12137   // ignore it if the masked load is already sign extended
12138   if (MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(N0)) {
12139     if (ExtVT == Ld->getMemoryVT() && N0.hasOneUse() &&
12140         Ld->getExtensionType() != ISD::LoadExtType::NON_EXTLOAD &&
12141         TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT)) {
12142       SDValue ExtMaskedLoad = DAG.getMaskedLoad(
12143           VT, SDLoc(N), Ld->getChain(), Ld->getBasePtr(), Ld->getOffset(),
12144           Ld->getMask(), Ld->getPassThru(), ExtVT, Ld->getMemOperand(),
12145           Ld->getAddressingMode(), ISD::SEXTLOAD, Ld->isExpandingLoad());
12146       CombineTo(N, ExtMaskedLoad);
12147       CombineTo(N0.getNode(), ExtMaskedLoad, ExtMaskedLoad.getValue(1));
12148       return SDValue(N, 0); // Return N so it doesn't get rechecked!
12149     }
12150   }
12151 
12152   // fold (sext_inreg (masked_gather x)) -> (sext_masked_gather x)
12153   if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(N0)) {
12154     if (SDValue(GN0, 0).hasOneUse() &&
12155         ExtVT == GN0->getMemoryVT() &&
12156         TLI.isVectorLoadExtDesirable(SDValue(GN0, 0))) {
12157       SDValue Ops[] = {GN0->getChain(),   GN0->getPassThru(), GN0->getMask(),
12158                        GN0->getBasePtr(), GN0->getIndex(),    GN0->getScale()};
12159 
12160       SDValue ExtLoad = DAG.getMaskedGather(
12161           DAG.getVTList(VT, MVT::Other), ExtVT, SDLoc(N), Ops,
12162           GN0->getMemOperand(), GN0->getIndexType(), ISD::SEXTLOAD);
12163 
12164       CombineTo(N, ExtLoad);
12165       CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
12166       AddToWorklist(ExtLoad.getNode());
12167       return SDValue(N, 0); // Return N so it doesn't get rechecked!
12168     }
12169   }
12170 
12171   // Form (sext_inreg (bswap >> 16)) or (sext_inreg (rotl (bswap) 16))
12172   if (ExtVTBits <= 16 && N0.getOpcode() == ISD::OR) {
12173     if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
12174                                            N0.getOperand(1), false))
12175       return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, BSwap, N1);
12176   }
12177 
12178   return SDValue();
12179 }
12180 
12181 SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
12182   SDValue N0 = N->getOperand(0);
12183   EVT VT = N->getValueType(0);
12184 
12185   // {s/z}ext_vector_inreg(undef) = 0 because the top bits must be the same.
12186   if (N0.isUndef())
12187     return DAG.getConstant(0, SDLoc(N), VT);
12188 
12189   if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
12190     return Res;
12191 
12192   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
12193     return SDValue(N, 0);
12194 
12195   return SDValue();
12196 }
12197 
12198 SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
12199   SDValue N0 = N->getOperand(0);
12200   EVT VT = N->getValueType(0);
12201   EVT SrcVT = N0.getValueType();
12202   bool isLE = DAG.getDataLayout().isLittleEndian();
12203 
12204   // noop truncate
12205   if (SrcVT == VT)
12206     return N0;
12207 
12208   // fold (truncate (truncate x)) -> (truncate x)
12209   if (N0.getOpcode() == ISD::TRUNCATE)
12210     return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0.getOperand(0));
12211 
12212   // fold (truncate c1) -> c1
12213   if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) {
12214     SDValue C = DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0);
12215     if (C.getNode() != N)
12216       return C;
12217   }
12218 
12219   // fold (truncate (ext x)) -> (ext x) or (truncate x) or x
12220   if (N0.getOpcode() == ISD::ZERO_EXTEND ||
12221       N0.getOpcode() == ISD::SIGN_EXTEND ||
12222       N0.getOpcode() == ISD::ANY_EXTEND) {
12223     // if the source is smaller than the dest, we still need an extend.
12224     if (N0.getOperand(0).getValueType().bitsLT(VT))
12225       return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, N0.getOperand(0));
12226     // if the source is larger than the dest, then we just need the truncate.
12227     if (N0.getOperand(0).getValueType().bitsGT(VT))
12228       return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0.getOperand(0));
12229     // if the source and dest are the same type, we can drop both the extend
12230     // and the truncate.
12231     return N0.getOperand(0);
12232   }
12233 
12234   // If this is anyext(trunc), don't fold it, allow ourselves to be folded.
12235   if (N->hasOneUse() && (N->use_begin()->getOpcode() == ISD::ANY_EXTEND))
12236     return SDValue();
12237 
12238   // Fold extract-and-trunc into a narrow extract. For example:
12239   //   i64 x = EXTRACT_VECTOR_ELT(v2i64 val, i32 1)
12240   //   i32 y = TRUNCATE(i64 x)
12241   //        -- becomes --
12242   //   v16i8 b = BITCAST (v2i64 val)
12243   //   i8 x = EXTRACT_VECTOR_ELT(v16i8 b, i32 8)
12244   //
12245   // Note: We only run this optimization after type legalization (which often
12246   // creates this pattern) and before operation legalization after which
12247   // we need to be more careful about the vector instructions that we generate.
12248   if (N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
12249       LegalTypes && !LegalOperations && N0->hasOneUse() && VT != MVT::i1) {
12250     EVT VecTy = N0.getOperand(0).getValueType();
12251     EVT ExTy = N0.getValueType();
12252     EVT TrTy = N->getValueType(0);
12253 
12254     auto EltCnt = VecTy.getVectorElementCount();
12255     unsigned SizeRatio = ExTy.getSizeInBits()/TrTy.getSizeInBits();
12256     auto NewEltCnt = EltCnt * SizeRatio;
12257 
12258     EVT NVT = EVT::getVectorVT(*DAG.getContext(), TrTy, NewEltCnt);
12259     assert(NVT.getSizeInBits() == VecTy.getSizeInBits() && "Invalid Size");
12260 
12261     SDValue EltNo = N0->getOperand(1);
12262     if (isa<ConstantSDNode>(EltNo) && isTypeLegal(NVT)) {
12263       int Elt = cast<ConstantSDNode>(EltNo)->getZExtValue();
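      // Each wide element becomes SizeRatio narrow elements after the bitcast;
      // the truncated (low) part is the first of those on little-endian
      // targets and the last on big-endian targets.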
12264       int Index = isLE ? (Elt*SizeRatio) : (Elt*SizeRatio + (SizeRatio-1));
12265 
12266       SDLoc DL(N);
12267       return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, TrTy,
12268                          DAG.getBitcast(NVT, N0.getOperand(0)),
12269                          DAG.getVectorIdxConstant(Index, DL));
12270     }
12271   }
12272 
12273   // trunc (select c, a, b) -> select c, (trunc a), (trunc b)
12274   if (N0.getOpcode() == ISD::SELECT && N0.hasOneUse()) {
12275     if ((!LegalOperations || TLI.isOperationLegal(ISD::SELECT, SrcVT)) &&
12276         TLI.isTruncateFree(SrcVT, VT)) {
12277       SDLoc SL(N0);
12278       SDValue Cond = N0.getOperand(0);
12279       SDValue TruncOp0 = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(1));
12280       SDValue TruncOp1 = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(2));
12281       return DAG.getNode(ISD::SELECT, SDLoc(N), VT, Cond, TruncOp0, TruncOp1);
12282     }
12283   }
12284 
12285   // trunc (shl x, K) -> shl (trunc x), K => K < VT.getScalarSizeInBits()
12286   if (N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
12287       (!LegalOperations || TLI.isOperationLegal(ISD::SHL, VT)) &&
12288       TLI.isTypeDesirableForOp(ISD::SHL, VT)) {
12289     SDValue Amt = N0.getOperand(1);
12290     KnownBits Known = DAG.computeKnownBits(Amt);
12291     unsigned Size = VT.getScalarSizeInBits();
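      // The fold is only safe when the shift amount is provably smaller than
      // the narrow bit width: every possibly-set bit of Amt must fit within
      // Log2_32(Size) bits.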
12292     if (Known.getBitWidth() - Known.countMinLeadingZeros() <= Log2_32(Size)) {
12293       SDLoc SL(N);
12294       EVT AmtVT = TLI.getShiftAmountTy(VT, DAG.getDataLayout());
12295 
12296       SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(0));
12297       if (AmtVT != Amt.getValueType()) {
12298         Amt = DAG.getZExtOrTrunc(Amt, SL, AmtVT);
12299         AddToWorklist(Amt.getNode());
12300       }
12301       return DAG.getNode(ISD::SHL, SL, VT, Trunc, Amt);
12302     }
12303   }
12304 
12305   if (SDValue V = foldSubToUSubSat(VT, N0.getNode()))
12306     return V;
12307 
12308   // Attempt to pre-truncate BUILD_VECTOR sources.
12309   if (N0.getOpcode() == ISD::BUILD_VECTOR && !LegalOperations &&
12310       TLI.isTruncateFree(SrcVT.getScalarType(), VT.getScalarType()) &&
12311       // Avoid creating illegal types if running after type legalizer.
12312       (!LegalTypes || TLI.isTypeLegal(VT.getScalarType()))) {
12313     SDLoc DL(N);
12314     EVT SVT = VT.getScalarType();
12315     SmallVector<SDValue, 8> TruncOps;
12316     for (const SDValue &Op : N0->op_values()) {
12317       SDValue TruncOp = DAG.getNode(ISD::TRUNCATE, DL, SVT, Op);
12318       TruncOps.push_back(TruncOp);
12319     }
12320     return DAG.getBuildVector(VT, DL, TruncOps);
12321   }
12322 
12323   // Fold a series of buildvector, bitcast, and truncate if possible.
12324   // For example fold
12325   //   (2xi32 trunc (bitcast ((4xi32)buildvector x, x, y, y) 2xi64)) to
12326   //   (2xi32 (buildvector x, y)).
12327   if (Level == AfterLegalizeVectorOps && VT.isVector() &&
12328       N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
12329       N0.getOperand(0).getOpcode() == ISD::BUILD_VECTOR &&
12330       N0.getOperand(0).hasOneUse()) {
12331     SDValue BuildVect = N0.getOperand(0);
12332     EVT BuildVectEltTy = BuildVect.getValueType().getVectorElementType();
12333     EVT TruncVecEltTy = VT.getVectorElementType();
12334 
12335     // Check that the element types match.
12336     if (BuildVectEltTy == TruncVecEltTy) {
12337       // Now we only need to compute the offset of the truncated elements.
12338       unsigned BuildVecNumElts =  BuildVect.getNumOperands();
12339       unsigned TruncVecNumElts = VT.getVectorNumElements();
12340       unsigned TruncEltOffset = BuildVecNumElts / TruncVecNumElts;
12341 
12342       assert((BuildVecNumElts % TruncVecNumElts) == 0 &&
12343              "Invalid number of elements");
12344 
12345       SmallVector<SDValue, 8> Opnds;
12346       for (unsigned i = 0, e = BuildVecNumElts; i != e; i += TruncEltOffset)
12347         Opnds.push_back(BuildVect.getOperand(i));
12348 
12349       return DAG.getBuildVector(VT, SDLoc(N), Opnds);
12350     }
12351   }
12352 
12353   // See if we can simplify the input to this truncate through knowledge that
12354   // only the low bits are being used.
12355   // For example "trunc (or (shl x, 8), y)" -> trunc y
12356   // Currently we only perform this optimization on scalars because vectors
12357   // may have different active low bits.
12358   if (!VT.isVector()) {
12359     APInt Mask =
12360         APInt::getLowBitsSet(N0.getValueSizeInBits(), VT.getSizeInBits());
12361     if (SDValue Shorter = DAG.GetDemandedBits(N0, Mask))
12362       return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Shorter);
12363   }
12364 
12365   // fold (truncate (load x)) -> (smaller load x)
12366   // fold (truncate (srl (load x), c)) -> (smaller load (x+c/evtbits))
12367   if (!LegalTypes || TLI.isTypeDesirableForOp(N0.getOpcode(), VT)) {
12368     if (SDValue Reduced = ReduceLoadWidth(N))
12369       return Reduced;
12370 
12371     // Handle the case where the load remains an extending load even
12372     // after truncation.
12373     if (N0.hasOneUse() && ISD::isUNINDEXEDLoad(N0.getNode())) {
12374       LoadSDNode *LN0 = cast<LoadSDNode>(N0);
12375       if (LN0->isSimple() && LN0->getMemoryVT().bitsLT(VT)) {
12376         SDValue NewLoad = DAG.getExtLoad(LN0->getExtensionType(), SDLoc(LN0),
12377                                          VT, LN0->getChain(), LN0->getBasePtr(),
12378                                          LN0->getMemoryVT(),
12379                                          LN0->getMemOperand());
12380         DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), NewLoad.getValue(1));
12381         return NewLoad;
12382       }
12383     }
12384   }
12385 
12386   // fold (trunc (concat ... x ...)) -> (concat ..., (trunc x), ...),
12387   // where ... are all 'undef'.
12388   if (N0.getOpcode() == ISD::CONCAT_VECTORS && !LegalTypes) {
12389     SmallVector<EVT, 8> VTs;
12390     SDValue V;
12391     unsigned Idx = 0;
12392     unsigned NumDefs = 0;
12393 
12394     for (unsigned i = 0, e = N0.getNumOperands(); i != e; ++i) {
12395       SDValue X = N0.getOperand(i);
12396       if (!X.isUndef()) {
12397         V = X;
12398         Idx = i;
12399         NumDefs++;
12400       }
12401       // Stop if more than one member is non-undef.
12402       if (NumDefs > 1)
12403         break;
12404 
12405       VTs.push_back(EVT::getVectorVT(*DAG.getContext(),
12406                                      VT.getVectorElementType(),
12407                                      X.getValueType().getVectorElementCount()));
12408     }
12409 
12410     if (NumDefs == 0)
12411       return DAG.getUNDEF(VT);
12412 
12413     if (NumDefs == 1) {
12414       assert(V.getNode() && "The single defined operand is empty!");
12415       SmallVector<SDValue, 8> Opnds;
12416       for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
12417         if (i != Idx) {
12418           Opnds.push_back(DAG.getUNDEF(VTs[i]));
12419           continue;
12420         }
12421         SDValue NV = DAG.getNode(ISD::TRUNCATE, SDLoc(V), VTs[i], V);
12422         AddToWorklist(NV.getNode());
12423         Opnds.push_back(NV);
12424       }
12425       return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Opnds);
12426     }
12427   }
12428 
12429   // Fold truncate of a bitcast of a vector to an extract of the low vector
12430   // element.
12431   //
12432   // e.g. trunc (i64 (bitcast v2i32:x)) -> extract_vector_elt v2i32:x, idx
12433   if (N0.getOpcode() == ISD::BITCAST && !VT.isVector()) {
12434     SDValue VecSrc = N0.getOperand(0);
12435     EVT VecSrcVT = VecSrc.getValueType();
12436     if (VecSrcVT.isVector() && VecSrcVT.getScalarType() == VT &&
12437         (!LegalOperations ||
12438          TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VecSrcVT))) {
12439       SDLoc SL(N);
12440 
12441       unsigned Idx = isLE ? 0 : VecSrcVT.getVectorNumElements() - 1;
12442       return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, VT, VecSrc,
12443                          DAG.getVectorIdxConstant(Idx, SL));
12444     }
12445   }
12446 
12447   // Simplify the operands using demanded-bits information.
12448   if (SimplifyDemandedBits(SDValue(N, 0)))
12449     return SDValue(N, 0);
12450 
12451   // (trunc adde(X, Y, Carry)) -> (adde trunc(X), trunc(Y), Carry)
12452   // (trunc addcarry(X, Y, Carry)) -> (addcarry trunc(X), trunc(Y), Carry)
12453   // When the adde's carry is not used.
12454   if ((N0.getOpcode() == ISD::ADDE || N0.getOpcode() == ISD::ADDCARRY) &&
12455       N0.hasOneUse() && !N0.getNode()->hasAnyUseOfValue(1) &&
12456       // We only do this for ADDCARRY before operation legalization.
12457       ((!LegalOperations && N0.getOpcode() == ISD::ADDCARRY) ||
12458        TLI.isOperationLegal(N0.getOpcode(), VT))) {
12459     SDLoc SL(N);
12460     auto X = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(0));
12461     auto Y = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(1));
12462     auto VTs = DAG.getVTList(VT, N0->getValueType(1));
12463     return DAG.getNode(N0.getOpcode(), SL, VTs, X, Y, N0.getOperand(2));
12464   }
12465 
12466   // fold (truncate (extract_subvector(ext x))) ->
12467   //      (extract_subvector x)
12468   // TODO: This can be generalized to cover cases where the truncate and extract
12469   // do not fully cancel each other out.
12470   if (!LegalTypes && N0.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
12471     SDValue N00 = N0.getOperand(0);
12472     if (N00.getOpcode() == ISD::SIGN_EXTEND ||
12473         N00.getOpcode() == ISD::ZERO_EXTEND ||
12474         N00.getOpcode() == ISD::ANY_EXTEND) {
12475       if (N00.getOperand(0)->getValueType(0).getVectorElementType() ==
12476           VT.getVectorElementType())
12477         return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N0->getOperand(0)), VT,
12478                            N00.getOperand(0), N0.getOperand(1));
12479     }
12480   }
12481 
12482   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
12483     return NewVSel;
12484 
12485   // Narrow a suitable binary operation with a non-opaque constant operand by
12486   // moving it ahead of the truncate. This is limited to pre-legalization
12487   // because targets may prefer a wider type during later combines and invert
12488   // this transform.
12489   switch (N0.getOpcode()) {
12490   case ISD::ADD:
12491   case ISD::SUB:
12492   case ISD::MUL:
12493   case ISD::AND:
12494   case ISD::OR:
12495   case ISD::XOR:
12496     if (!LegalOperations && N0.hasOneUse() &&
12497         (isConstantOrConstantVector(N0.getOperand(0), true) ||
12498          isConstantOrConstantVector(N0.getOperand(1), true))) {
12499       // TODO: We already restricted this to pre-legalization, but for vectors
12500       // we are extra cautious to not create an unsupported operation.
12501       // Target-specific changes are likely needed to avoid regressions here.
12502       if (VT.isScalarInteger() || TLI.isOperationLegal(N0.getOpcode(), VT)) {
12503         SDLoc DL(N);
12504         SDValue NarrowL = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
12505         SDValue NarrowR = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(1));
12506         return DAG.getNode(N0.getOpcode(), DL, VT, NarrowL, NarrowR);
12507       }
12508     }
12509     break;
12510   case ISD::USUBSAT:
12511     // Truncate the USUBSAT only if LHS is a known zero-extension; it's not
12512     // enough to know that the upper bits are zero, we must also ensure that
12513     // we don't introduce an extra truncate.
12514     if (!LegalOperations && N0.hasOneUse() &&
12515         N0.getOperand(0).getOpcode() == ISD::ZERO_EXTEND &&
12516         N0.getOperand(0).getOperand(0).getScalarValueSizeInBits() <=
12517             VT.getScalarSizeInBits() &&
12518         hasOperation(N0.getOpcode(), VT)) {
12519       return getTruncatedUSUBSAT(VT, SrcVT, N0.getOperand(0), N0.getOperand(1),
12520                                  DAG, SDLoc(N));
12521     }
12522     break;
12523   }
12524 
12525   return SDValue();
12526 }
12527 
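/// Return the node that defines operand \p i of a BUILD_PAIR, looking through
/// any MERGE_VALUES wrapper to the underlying value.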
12528 static SDNode *getBuildPairElt(SDNode *N, unsigned i) {
12529   SDValue Elt = N->getOperand(i);
12530   if (Elt.getOpcode() != ISD::MERGE_VALUES)
12531     return Elt.getNode();
12532   return Elt.getOperand(Elt.getResNo()).getNode();
12533 }
12534 
12535 /// build_pair (load, load) -> load
12536 /// if load locations are consecutive.
12537 SDValue DAGCombiner::CombineConsecutiveLoads(SDNode *N, EVT VT) {
12538   assert(N->getOpcode() == ISD::BUILD_PAIR);
12539 
12540   LoadSDNode *LD1 = dyn_cast<LoadSDNode>(getBuildPairElt(N, 0));
12541   LoadSDNode *LD2 = dyn_cast<LoadSDNode>(getBuildPairElt(N, 1));
12542 
12543   // A BUILD_PAIR always has the least significant part in elt 0 and the most
12544   // significant part in elt 1, so when combining into one large load we need
12545   // to consider the endianness.
12546   if (DAG.getDataLayout().isBigEndian())
12547     std::swap(LD1, LD2);
12548 
12549   if (!LD1 || !LD2 || !ISD::isNON_EXTLoad(LD1) || !LD1->hasOneUse() ||
12550       LD1->getAddressSpace() != LD2->getAddressSpace())
12551     return SDValue();
12552   EVT LD1VT = LD1->getValueType(0);
12553   unsigned LD1Bytes = LD1VT.getStoreSize();
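  // Combine only when LD2 starts exactly LD1Bytes past LD1, i.e. the two
  // loads cover adjacent memory that a single load of type VT spans.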
12554   if (ISD::isNON_EXTLoad(LD2) && LD2->hasOneUse() &&
12555       DAG.areNonVolatileConsecutiveLoads(LD2, LD1, LD1Bytes, 1)) {
12556     Align Alignment = LD1->getAlign();
12557     Align NewAlign = DAG.getDataLayout().getABITypeAlign(
12558         VT.getTypeForEVT(*DAG.getContext()));
12559 
12560     if (NewAlign <= Alignment &&
12561         (!LegalOperations || TLI.isOperationLegal(ISD::LOAD, VT)))
12562       return DAG.getLoad(VT, SDLoc(N), LD1->getChain(), LD1->getBasePtr(),
12563                          LD1->getPointerInfo(), Alignment);
12564   }
12565 
12566   return SDValue();
12567 }
12568 
12569 static unsigned getPPCf128HiElementSelector(const SelectionDAG &DAG) {
12570   // On little-endian machines, bitcasting from ppcf128 to i128 does swap the Hi
12571   // and Lo parts; on big-endian machines it doesn't.
12572   return DAG.getDataLayout().isBigEndian() ? 1 : 0;
12573 }
12574 
12575 static SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
12576                                     const TargetLowering &TLI) {
12577   // If this is not a bitcast to an FP type or if the target doesn't have
12578   // IEEE754-compliant FP logic, we're done.
12579   EVT VT = N->getValueType(0);
12580   if (!VT.isFloatingPoint() || !TLI.hasBitPreservingFPLogic(VT))
12581     return SDValue();
12582 
12583   // TODO: Handle cases where the integer constant is a different scalar
12584   // bitwidth to the FP.
12585   SDValue N0 = N->getOperand(0);
12586   EVT SourceVT = N0.getValueType();
12587   if (VT.getScalarSizeInBits() != SourceVT.getScalarSizeInBits())
12588     return SDValue();
12589 
12590   unsigned FPOpcode;
12591   APInt SignMask;
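  // Map the integer logic op to the FP op whose bit pattern it implements:
  // AND with ~SignMask clears the sign bit (fabs), XOR with SignMask flips it
  // (fneg), and OR with SignMask sets it unconditionally (fneg (fabs X)).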
12592   switch (N0.getOpcode()) {
12593   case ISD::AND:
12594     FPOpcode = ISD::FABS;
12595     SignMask = ~APInt::getSignMask(SourceVT.getScalarSizeInBits());
12596     break;
12597   case ISD::XOR:
12598     FPOpcode = ISD::FNEG;
12599     SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits());
12600     break;
12601   case ISD::OR:
12602     FPOpcode = ISD::FABS;
12603     SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits());
12604     break;
12605   default:
12606     return SDValue();
12607   }
12608 
12609   // Fold (bitcast int (and (bitcast fp X to int), 0x7fff...) to fp) -> fabs X
12610   // Fold (bitcast int (xor (bitcast fp X to int), 0x8000...) to fp) -> fneg X
12611   // Fold (bitcast int (or (bitcast fp X to int), 0x8000...) to fp) ->
12612   //   fneg (fabs X)
12613   SDValue LogicOp0 = N0.getOperand(0);
12614   ConstantSDNode *LogicOp1 = isConstOrConstSplat(N0.getOperand(1), true);
12615   if (LogicOp1 && LogicOp1->getAPIntValue() == SignMask &&
12616       LogicOp0.getOpcode() == ISD::BITCAST &&
12617       LogicOp0.getOperand(0).getValueType() == VT) {
12618     SDValue FPOp = DAG.getNode(FPOpcode, SDLoc(N), VT, LogicOp0.getOperand(0));
12619     NumFPLogicOpsConv++;
12620     if (N0.getOpcode() == ISD::OR)
12621       return DAG.getNode(ISD::FNEG, SDLoc(N), VT, FPOp);
12622     return FPOp;
12623   }
12624 
12625   return SDValue();
12626 }
12627 
12628 SDValue DAGCombiner::visitBITCAST(SDNode *N) {
12629   SDValue N0 = N->getOperand(0);
12630   EVT VT = N->getValueType(0);
12631 
12632   if (N0.isUndef())
12633     return DAG.getUNDEF(VT);
12634 
12635   // If the input is a BUILD_VECTOR with all constant elements, fold this now.
12636   // Only do this before legalize types, unless both types are integer and the
12637   // scalar type is legal. Only do this before legalize ops, since the target
12638   // may be depending on the bitcast.
12639   // First check to see if this is all constant.
12640   // TODO: Support FP bitcasts after legalize types.
12641   if (VT.isVector() &&
12642       (!LegalTypes ||
12643        (!LegalOperations && VT.isInteger() && N0.getValueType().isInteger() &&
12644         TLI.isTypeLegal(VT.getVectorElementType()))) &&
12645       N0.getOpcode() == ISD::BUILD_VECTOR && N0.getNode()->hasOneUse() &&
12646       cast<BuildVectorSDNode>(N0)->isConstant())
12647     return ConstantFoldBITCASTofBUILD_VECTOR(N0.getNode(),
12648                                              VT.getVectorElementType());
12649 
12650   // If the input is a constant, let getNode fold it.
12651   if (isIntOrFPConstant(N0)) {
12652     // If we can't allow illegal operations, we need to check that this is just
12653     // an fp -> int or int -> fp conversion and that the resulting operation
12654     // will be legal.
12655     if (!LegalOperations ||
12656         (isa<ConstantSDNode>(N0) && VT.isFloatingPoint() && !VT.isVector() &&
12657          TLI.isOperationLegal(ISD::ConstantFP, VT)) ||
12658         (isa<ConstantFPSDNode>(N0) && VT.isInteger() && !VT.isVector() &&
12659          TLI.isOperationLegal(ISD::Constant, VT))) {
12660       SDValue C = DAG.getBitcast(VT, N0);
12661       if (C.getNode() != N)
12662         return C;
12663     }
12664   }
12665 
12666   // (conv (conv x, t1), t2) -> (conv x, t2)
12667   if (N0.getOpcode() == ISD::BITCAST)
12668     return DAG.getBitcast(VT, N0.getOperand(0));
12669 
12670   // fold (conv (load x)) -> (load (conv*)x)
12671   // If the resultant load doesn't need a higher alignment than the original!
12672   if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
12673       // Do not remove the cast if the types differ in endian layout.
12674       TLI.hasBigEndianPartOrdering(N0.getValueType(), DAG.getDataLayout()) ==
12675           TLI.hasBigEndianPartOrdering(VT, DAG.getDataLayout()) &&
12676       // If the load is volatile, we only want to change the load type if the
12677       // resulting load is legal. Otherwise we might increase the number of
12678       // memory accesses. We don't care if the original type was legal or not
12679       // as we assume software couldn't rely on the number of accesses of an
12680       // illegal type.
12681       ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple()) ||
12682        TLI.isOperationLegal(ISD::LOAD, VT))) {
12683     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
12684 
12685     if (TLI.isLoadBitCastBeneficial(N0.getValueType(), VT, DAG,
12686                                     *LN0->getMemOperand())) {
12687       SDValue Load =
12688           DAG.getLoad(VT, SDLoc(N), LN0->getChain(), LN0->getBasePtr(),
12689                       LN0->getPointerInfo(), LN0->getAlign(),
12690                       LN0->getMemOperand()->getFlags(), LN0->getAAInfo());
12691       DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));
12692       return Load;
12693     }
12694   }
12695 
12696   if (SDValue V = foldBitcastedFPLogic(N, DAG, TLI))
12697     return V;
12698 
12699   // fold (bitconvert (fneg x)) -> (xor (bitconvert x), signbit)
12700   // fold (bitconvert (fabs x)) -> (and (bitconvert x), (not signbit))
12701   //
12702   // For ppc_fp128:
12703   // fold (bitcast (fneg x)) ->
12704   //     flipbit = signbit
12705   //     (xor (bitcast x) (build_pair flipbit, flipbit))
12706   //
12707   // fold (bitcast (fabs x)) ->
12708   //     flipbit = (and (extract_element (bitcast x), 0), signbit)
12709   //     (xor (bitcast x) (build_pair flipbit, flipbit))
12710   // This often reduces constant pool loads.
12711   if (((N0.getOpcode() == ISD::FNEG && !TLI.isFNegFree(N0.getValueType())) ||
12712        (N0.getOpcode() == ISD::FABS && !TLI.isFAbsFree(N0.getValueType()))) &&
12713       N0.getNode()->hasOneUse() && VT.isInteger() &&
12714       !VT.isVector() && !N0.getValueType().isVector()) {
12715     SDValue NewConv = DAG.getBitcast(VT, N0.getOperand(0));
12716     AddToWorklist(NewConv.getNode());
12717 
12718     SDLoc DL(N);
12719     if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
12720       assert(VT.getSizeInBits() == 128);
12721       SDValue SignBit = DAG.getConstant(
12722           APInt::getSignMask(VT.getSizeInBits() / 2), SDLoc(N0), MVT::i64);
12723       SDValue FlipBit;
12724       if (N0.getOpcode() == ISD::FNEG) {
12725         FlipBit = SignBit;
12726         AddToWorklist(FlipBit.getNode());
12727       } else {
12728         assert(N0.getOpcode() == ISD::FABS);
12729         SDValue Hi =
12730             DAG.getNode(ISD::EXTRACT_ELEMENT, SDLoc(NewConv), MVT::i64, NewConv,
12731                         DAG.getIntPtrConstant(getPPCf128HiElementSelector(DAG),
12732                                               SDLoc(NewConv)));
12733         AddToWorklist(Hi.getNode());
12734         FlipBit = DAG.getNode(ISD::AND, SDLoc(N0), MVT::i64, Hi, SignBit);
12735         AddToWorklist(FlipBit.getNode());
12736       }
12737       SDValue FlipBits =
12738           DAG.getNode(ISD::BUILD_PAIR, SDLoc(N0), VT, FlipBit, FlipBit);
12739       AddToWorklist(FlipBits.getNode());
12740       return DAG.getNode(ISD::XOR, DL, VT, NewConv, FlipBits);
12741     }
12742     APInt SignBit = APInt::getSignMask(VT.getSizeInBits());
12743     if (N0.getOpcode() == ISD::FNEG)
12744       return DAG.getNode(ISD::XOR, DL, VT,
12745                          NewConv, DAG.getConstant(SignBit, DL, VT));
12746     assert(N0.getOpcode() == ISD::FABS);
12747     return DAG.getNode(ISD::AND, DL, VT,
12748                        NewConv, DAG.getConstant(~SignBit, DL, VT));
12749   }
12750 
12751   // fold (bitconvert (fcopysign cst, x)) ->
12752   //         (or (and (bitconvert x), sign), (and cst, (not sign)))
12753   // Note that we don't handle (copysign x, cst) because this can always be
12754   // folded to an fneg or fabs.
12755   //
12756   // For ppc_fp128:
12757   // fold (bitcast (fcopysign cst, x)) ->
12758   //     flipbit = (and (extract_element
12759   //                     (xor (bitcast cst), (bitcast x)), 0),
12760   //                    signbit)
12761   //     (xor (bitcast cst) (build_pair flipbit, flipbit))
12762   if (N0.getOpcode() == ISD::FCOPYSIGN && N0.getNode()->hasOneUse() &&
12763       isa<ConstantFPSDNode>(N0.getOperand(0)) &&
12764       VT.isInteger() && !VT.isVector()) {
12765     unsigned OrigXWidth = N0.getOperand(1).getValueSizeInBits();
12766     EVT IntXVT = EVT::getIntegerVT(*DAG.getContext(), OrigXWidth);
12767     if (isTypeLegal(IntXVT)) {
12768       SDValue X = DAG.getBitcast(IntXVT, N0.getOperand(1));
12769       AddToWorklist(X.getNode());
12770 
12771       // If X has a different width than the result/lhs, sext it or truncate it.
12772       unsigned VTWidth = VT.getSizeInBits();
12773       if (OrigXWidth < VTWidth) {
12774         X = DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, X);
12775         AddToWorklist(X.getNode());
12776       } else if (OrigXWidth > VTWidth) {
12777         // To get the sign bit in the right place, we have to shift it right
12778         // before truncating.
12779         SDLoc DL(X);
12780         X = DAG.getNode(ISD::SRL, DL,
12781                         X.getValueType(), X,
12782                         DAG.getConstant(OrigXWidth-VTWidth, DL,
12783                                         X.getValueType()));
12784         AddToWorklist(X.getNode());
12785         X = DAG.getNode(ISD::TRUNCATE, SDLoc(X), VT, X);
12786         AddToWorklist(X.getNode());
12787       }
12788 
12789       if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
12790         APInt SignBit = APInt::getSignMask(VT.getSizeInBits() / 2);
12791         SDValue Cst = DAG.getBitcast(VT, N0.getOperand(0));
12792         AddToWorklist(Cst.getNode());
12793         SDValue X = DAG.getBitcast(VT, N0.getOperand(1));
12794         AddToWorklist(X.getNode());
12795         SDValue XorResult = DAG.getNode(ISD::XOR, SDLoc(N0), VT, Cst, X);
12796         AddToWorklist(XorResult.getNode());
12797         SDValue XorResult64 = DAG.getNode(
12798             ISD::EXTRACT_ELEMENT, SDLoc(XorResult), MVT::i64, XorResult,
12799             DAG.getIntPtrConstant(getPPCf128HiElementSelector(DAG),
12800                                   SDLoc(XorResult)));
12801         AddToWorklist(XorResult64.getNode());
12802         SDValue FlipBit =
12803             DAG.getNode(ISD::AND, SDLoc(XorResult64), MVT::i64, XorResult64,
12804                         DAG.getConstant(SignBit, SDLoc(XorResult64), MVT::i64));
12805         AddToWorklist(FlipBit.getNode());
12806         SDValue FlipBits =
12807             DAG.getNode(ISD::BUILD_PAIR, SDLoc(N0), VT, FlipBit, FlipBit);
12808         AddToWorklist(FlipBits.getNode());
12809         return DAG.getNode(ISD::XOR, SDLoc(N), VT, Cst, FlipBits);
12810       }
12811       APInt SignBit = APInt::getSignMask(VT.getSizeInBits());
12812       X = DAG.getNode(ISD::AND, SDLoc(X), VT,
12813                       X, DAG.getConstant(SignBit, SDLoc(X), VT));
12814       AddToWorklist(X.getNode());
12815 
12816       SDValue Cst = DAG.getBitcast(VT, N0.getOperand(0));
12817       Cst = DAG.getNode(ISD::AND, SDLoc(Cst), VT,
12818                         Cst, DAG.getConstant(~SignBit, SDLoc(Cst), VT));
12819       AddToWorklist(Cst.getNode());
12820 
12821       return DAG.getNode(ISD::OR, SDLoc(N), VT, X, Cst);
12822     }
12823   }
12824 
12825   // bitconvert(build_pair(ld, ld)) -> ld iff load locations are consecutive.
12826   if (N0.getOpcode() == ISD::BUILD_PAIR)
12827     if (SDValue CombineLD = CombineConsecutiveLoads(N0.getNode(), VT))
12828       return CombineLD;
12829 
12830   // Remove double bitcasts from shuffles - this is often a legacy of
12831   // XformToShuffleWithZero being used to combine bitmaskings (of
12832   // float vectors bitcast to integer vectors) into shuffles.
12833   // bitcast(shuffle(bitcast(s0),bitcast(s1))) -> shuffle(s0,s1)
12834   if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT) && VT.isVector() &&
12835       N0->getOpcode() == ISD::VECTOR_SHUFFLE && N0.hasOneUse() &&
12836       VT.getVectorNumElements() >= N0.getValueType().getVectorNumElements() &&
12837       !(VT.getVectorNumElements() % N0.getValueType().getVectorNumElements())) {
12838     ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N0);
12839 
12840     // If operands are a bitcast, peek through if it casts the original VT.
12841     // If operands are a constant, just bitcast back to original VT.
12842     auto PeekThroughBitcast = [&](SDValue Op) {
12843       if (Op.getOpcode() == ISD::BITCAST &&
12844           Op.getOperand(0).getValueType() == VT)
12845         return SDValue(Op.getOperand(0));
12846       if (Op.isUndef() || ISD::isBuildVectorOfConstantSDNodes(Op.getNode()) ||
12847           ISD::isBuildVectorOfConstantFPSDNodes(Op.getNode()))
12848         return DAG.getBitcast(VT, Op);
12849       return SDValue();
12850     };
12851 
12852     // FIXME: If either input vector is bitcast, try to convert the shuffle to
12853     // the result type of this bitcast. This would eliminate at least one
12854     // bitcast. See the transform in InstCombine.
12855     SDValue SV0 = PeekThroughBitcast(N0->getOperand(0));
12856     SDValue SV1 = PeekThroughBitcast(N0->getOperand(1));
12857     if (!(SV0 && SV1))
12858       return SDValue();
12859 
12860     int MaskScale =
12861         VT.getVectorNumElements() / N0.getValueType().getVectorNumElements();
12862     SmallVector<int, 8> NewMask;
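    // e.g. bitcasting a v4i32 shuffle to v8i16 gives MaskScale == 2, so each
    // source mask element M expands to the pair {2*M, 2*M+1} (undef lanes
    // stay -1).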
12863     for (int M : SVN->getMask())
12864       for (int i = 0; i != MaskScale; ++i)
12865         NewMask.push_back(M < 0 ? -1 : M * MaskScale + i);
12866 
12867     SDValue LegalShuffle =
12868         TLI.buildLegalVectorShuffle(VT, SDLoc(N), SV0, SV1, NewMask, DAG);
12869     if (LegalShuffle)
12870       return LegalShuffle;
12871   }
12872 
12873   return SDValue();
12874 }
12875 
12876 SDValue DAGCombiner::visitBUILD_PAIR(SDNode *N) {
12877   EVT VT = N->getValueType(0);
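  // As with the bitconvert(build_pair) fold above, try to merge a BUILD_PAIR
  // of two consecutive loads into a single wider load.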
12878   return CombineConsecutiveLoads(N, VT);
12879 }
12880 
12881 SDValue DAGCombiner::visitFREEZE(SDNode *N) {
12882   SDValue N0 = N->getOperand(0);
12883 
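  // freeze(X) is a no-op when X can be neither undef nor poison, so forward
  // the operand directly.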
12884   if (DAG.isGuaranteedNotToBeUndefOrPoison(N0, /*PoisonOnly*/ false))
12885     return N0;
12886 
12887   return SDValue();
12888 }
12889 
12890 /// We know that BV is a build_vector node with Constant, ConstantFP or Undef
12891 /// operands. DstEltVT indicates the destination element value type.
12892 SDValue DAGCombiner::
12893 ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) {
12894   EVT SrcEltVT = BV->getValueType(0).getVectorElementType();
12895 
12896   // If this is already the right type, we're done.
12897   if (SrcEltVT == DstEltVT) return SDValue(BV, 0);
12898 
12899   unsigned SrcBitSize = SrcEltVT.getSizeInBits();
12900   unsigned DstBitSize = DstEltVT.getSizeInBits();
12901 
12902   // If this is a conversion of N elements of one type to N elements of another
12903   // type, convert each element.  This handles FP<->INT cases.
12904   if (SrcBitSize == DstBitSize) {
12905     SmallVector<SDValue, 8> Ops;
12906     for (SDValue Op : BV->op_values()) {
12907       // If the vector element type is not legal, the BUILD_VECTOR operands
12908       // are promoted and implicitly truncated.  Make that explicit here.
12909       if (Op.getValueType() != SrcEltVT)
12910         Op = DAG.getNode(ISD::TRUNCATE, SDLoc(BV), SrcEltVT, Op);
12911       Ops.push_back(DAG.getBitcast(DstEltVT, Op));
12912       AddToWorklist(Ops.back().getNode());
12913     }
12914     EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT,
12915                               BV->getValueType(0).getVectorNumElements());
12916     return DAG.getBuildVector(VT, SDLoc(BV), Ops);
12917   }
12918 
12919   // Otherwise, we're growing or shrinking the elements.  To avoid having to
12920   // handle annoying details of growing/shrinking FP values, we convert them to
12921   // int first.
12922   if (SrcEltVT.isFloatingPoint()) {
12923     // Convert the input float vector to an int vector where the elements
12924     // are the same size.
12925     EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), SrcEltVT.getSizeInBits());
12926     BV = ConstantFoldBITCASTofBUILD_VECTOR(BV, IntVT).getNode();
12927     SrcEltVT = IntVT;
12928   }
12929 
12930   // Now we know the input is an integer vector.  If the output is a FP type,
12931   // convert to integer first, then to FP of the right size.
12932   if (DstEltVT.isFloatingPoint()) {
12933     EVT TmpVT = EVT::getIntegerVT(*DAG.getContext(), DstEltVT.getSizeInBits());
12934     SDNode *Tmp = ConstantFoldBITCASTofBUILD_VECTOR(BV, TmpVT).getNode();
12935 
12936     // Next, convert to FP elements of the same size.
12937     return ConstantFoldBITCASTofBUILD_VECTOR(Tmp, DstEltVT);
12938   }
12939 
12940   SDLoc DL(BV);
12941 
12942   // Okay, we know the src/dst types are both integers of differing types.
12943   // Handling growing first.
12944   assert(SrcEltVT.isInteger() && DstEltVT.isInteger());
12945   if (SrcBitSize < DstBitSize) {
12946     unsigned NumInputsPerOutput = DstBitSize/SrcBitSize;
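    // e.g. for v4i16 -> v2i32, NumInputsPerOutput == 2: each output element
    // packs two adjacent i16 constants, with the lower-indexed input landing
    // in the low half on little-endian targets.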
12947 
12948     SmallVector<SDValue, 8> Ops;
12949     for (unsigned i = 0, e = BV->getNumOperands(); i != e;
12950          i += NumInputsPerOutput) {
12951       bool isLE = DAG.getDataLayout().isLittleEndian();
12952       APInt NewBits = APInt(DstBitSize, 0);
12953       bool EltIsUndef = true;
12954       for (unsigned j = 0; j != NumInputsPerOutput; ++j) {
12955         // Shift the previously computed bits over.
12956         NewBits <<= SrcBitSize;
12957         SDValue Op = BV->getOperand(i+ (isLE ? (NumInputsPerOutput-j-1) : j));
12958         if (Op.isUndef()) continue;
12959         EltIsUndef = false;
12960 
12961         NewBits |= cast<ConstantSDNode>(Op)->getAPIntValue().
12962                    zextOrTrunc(SrcBitSize).zext(DstBitSize);
12963       }
12964 
12965       if (EltIsUndef)
12966         Ops.push_back(DAG.getUNDEF(DstEltVT));
12967       else
12968         Ops.push_back(DAG.getConstant(NewBits, DL, DstEltVT));
12969     }
12970 
12971     EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT, Ops.size());
12972     return DAG.getBuildVector(VT, DL, Ops);
12973   }
12974 
12975   // Finally, this must be the case where we are shrinking elements: each input
12976   // turns into multiple outputs.
12977   unsigned NumOutputsPerInput = SrcBitSize/DstBitSize;
12978   EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT,
12979                             NumOutputsPerInput*BV->getNumOperands());
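  // e.g. for v2i32 -> v4i16, NumOutputsPerInput == 2: each i32 constant is
  // split into two i16 pieces, low piece first, then swapped below for
  // big-endian targets.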
12980   SmallVector<SDValue, 8> Ops;
12981 
12982   for (const SDValue &Op : BV->op_values()) {
12983     if (Op.isUndef()) {
12984       Ops.append(NumOutputsPerInput, DAG.getUNDEF(DstEltVT));
12985       continue;
12986     }
12987 
12988     APInt OpVal = cast<ConstantSDNode>(Op)->
12989                   getAPIntValue().zextOrTrunc(SrcBitSize);
12990 
12991     for (unsigned j = 0; j != NumOutputsPerInput; ++j) {
12992       APInt ThisVal = OpVal.trunc(DstBitSize);
12993       Ops.push_back(DAG.getConstant(ThisVal, DL, DstEltVT));
12994       OpVal.lshrInPlace(DstBitSize);
12995     }
12996 
12997     // For big endian targets, swap the order of the pieces of each element.
12998     if (DAG.getDataLayout().isBigEndian())
12999       std::reverse(Ops.end()-NumOutputsPerInput, Ops.end());
13000   }
13001 
13002   return DAG.getBuildVector(VT, DL, Ops);
13003 }
13004 
13005 /// Try to perform FMA combining on a given FADD node.
13006 SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
13007   SDValue N0 = N->getOperand(0);
13008   SDValue N1 = N->getOperand(1);
13009   EVT VT = N->getValueType(0);
13010   SDLoc SL(N);
13011 
13012   const TargetOptions &Options = DAG.getTarget().Options;
13013 
13014   // Floating-point multiply-add with intermediate rounding.
13015   bool HasFMAD = (LegalOperations && TLI.isFMADLegal(DAG, N));
13016 
13017   // Floating-point multiply-add without intermediate rounding.
13018   bool HasFMA =
13019       TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
13020       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
13021 
13022   // No valid opcode, do not combine.
13023   if (!HasFMAD && !HasFMA)
13024     return SDValue();
13025 
13026   bool CanReassociate =
13027       Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
13028   bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
13029                               Options.UnsafeFPMath || HasFMAD);
13030   // If the addition is not contractable, do not combine.
13031   if (!AllowFusionGlobally && !N->getFlags().hasAllowContract())
13032     return SDValue();
13033 
13034   if (TLI.generateFMAsInMachineCombiner(VT, OptLevel))
13035     return SDValue();
13036 
13037   // Always prefer FMAD to FMA for precision.
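  // (FMAD keeps the intermediate rounding step, so it matches the result of
  // the original separate fmul/fadd exactly.)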
13038   unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
13039   bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
13040 
13041   // Is the node an FMUL and contractable either due to global flags or
13042   // SDNodeFlags.
13043   auto isContractableFMUL = [AllowFusionGlobally](SDValue N) {
13044     if (N.getOpcode() != ISD::FMUL)
13045       return false;
13046     return AllowFusionGlobally || N->getFlags().hasAllowContract();
13047   };
13048   // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
13049   // prefer to fold the multiply with fewer uses.
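  // (An FMUL with remaining uses must still be computed separately, so the
  // multiply with fewer uses is the one more likely to become dead once
  // fused.)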
13050   if (Aggressive && isContractableFMUL(N0) && isContractableFMUL(N1)) {
13051     if (N0.getNode()->use_size() > N1.getNode()->use_size())
13052       std::swap(N0, N1);
13053   }
13054 
13055   // fold (fadd (fmul x, y), z) -> (fma x, y, z)
13056   if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) {
13057     return DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0),
13058                        N0.getOperand(1), N1);
13059   }
13060 
13061   // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
13062   // Note: Commutes FADD operands.
13063   if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) {
13064     return DAG.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0),
13065                        N1.getOperand(1), N0);
13066   }
13067 
13068   // fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E)
13069   // fadd E, (fma A, B, (fmul C, D)) --> fma A, B, (fma C, D, E)
13070   // This requires reassociation because it changes the order of operations.
13071   SDValue FMA, E;
13072   if (CanReassociate && N0.getOpcode() == PreferredFusedOpcode &&
13073       N0.getOperand(2).getOpcode() == ISD::FMUL && N0.hasOneUse() &&
13074       N0.getOperand(2).hasOneUse()) {
13075     FMA = N0;
13076     E = N1;
13077   } else if (CanReassociate && N1.getOpcode() == PreferredFusedOpcode &&
13078              N1.getOperand(2).getOpcode() == ISD::FMUL && N1.hasOneUse() &&
13079              N1.getOperand(2).hasOneUse()) {
13080     FMA = N1;
13081     E = N0;
13082   }
13083   if (FMA && E) {
13084     SDValue A = FMA.getOperand(0);
13085     SDValue B = FMA.getOperand(1);
13086     SDValue C = FMA.getOperand(2).getOperand(0);
13087     SDValue D = FMA.getOperand(2).getOperand(1);
13088     SDValue CDE = DAG.getNode(PreferredFusedOpcode, SL, VT, C, D, E);
13089     return DAG.getNode(PreferredFusedOpcode, SL, VT, A, B, CDE);
13090   }
13091 
13092   // Look through FP_EXTEND nodes to do more combining.
13093 
13094   // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
13095   if (N0.getOpcode() == ISD::FP_EXTEND) {
13096     SDValue N00 = N0.getOperand(0);
13097     if (isContractableFMUL(N00) &&
13098         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
13099                             N00.getValueType())) {
13100       return DAG.getNode(PreferredFusedOpcode, SL, VT,
13101                          DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
13102                          DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
13103                          N1);
13104     }
13105   }
13106 
13107   // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x)
13108   // Note: Commutes FADD operands.
13109   if (N1.getOpcode() == ISD::FP_EXTEND) {
13110     SDValue N10 = N1.getOperand(0);
13111     if (isContractableFMUL(N10) &&
13112         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
13113                             N10.getValueType())) {
13114       return DAG.getNode(PreferredFusedOpcode, SL, VT,
13115                          DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0)),
13116                          DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)),
13117                          N0);
13118     }
13119   }
13120 
13121   // More folding opportunities when target permits.
13122   if (Aggressive) {
13123     // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
13124     //   -> (fma x, y, (fma (fpext u), (fpext v), z))
13125     auto FoldFAddFMAFPExtFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
13126                                     SDValue Z) {
13127       return DAG.getNode(PreferredFusedOpcode, SL, VT, X, Y,
13128                          DAG.getNode(PreferredFusedOpcode, SL, VT,
13129                                      DAG.getNode(ISD::FP_EXTEND, SL, VT, U),
13130                                      DAG.getNode(ISD::FP_EXTEND, SL, VT, V),
13131                                      Z));
13132     };
13133     if (N0.getOpcode() == PreferredFusedOpcode) {
13134       SDValue N02 = N0.getOperand(2);
13135       if (N02.getOpcode() == ISD::FP_EXTEND) {
13136         SDValue N020 = N02.getOperand(0);
13137         if (isContractableFMUL(N020) &&
13138             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
13139                                 N020.getValueType())) {
13140           return FoldFAddFMAFPExtFMul(N0.getOperand(0), N0.getOperand(1),
13141                                       N020.getOperand(0), N020.getOperand(1),
13142                                       N1);
13143         }
13144       }
13145     }
13146 
13147     // fold (fadd (fpext (fma x, y, (fmul u, v))), z)
13148     //   -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
13149     // FIXME: This turns two single-precision and one double-precision
13150     // operation into two double-precision operations, which might not be
13151     // interesting for all targets, especially GPUs.
13152     auto FoldFAddFPExtFMAFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
13153                                     SDValue Z) {
13154       return DAG.getNode(
13155           PreferredFusedOpcode, SL, VT, DAG.getNode(ISD::FP_EXTEND, SL, VT, X),
13156           DAG.getNode(ISD::FP_EXTEND, SL, VT, Y),
13157           DAG.getNode(PreferredFusedOpcode, SL, VT,
13158                       DAG.getNode(ISD::FP_EXTEND, SL, VT, U),
13159                       DAG.getNode(ISD::FP_EXTEND, SL, VT, V), Z));
13160     };
13161     if (N0.getOpcode() == ISD::FP_EXTEND) {
13162       SDValue N00 = N0.getOperand(0);
13163       if (N00.getOpcode() == PreferredFusedOpcode) {
13164         SDValue N002 = N00.getOperand(2);
13165         if (isContractableFMUL(N002) &&
13166             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
13167                                 N00.getValueType())) {
13168           return FoldFAddFPExtFMAFMul(N00.getOperand(0), N00.getOperand(1),
13169                                       N002.getOperand(0), N002.getOperand(1),
13170                                       N1);
13171         }
13172       }
13173     }
13174 
13175     // fold (fadd x, (fma y, z, (fpext (fmul u, v))))
13176     //   -> (fma y, z, (fma (fpext u), (fpext v), x))
13177     if (N1.getOpcode() == PreferredFusedOpcode) {
13178       SDValue N12 = N1.getOperand(2);
13179       if (N12.getOpcode() == ISD::FP_EXTEND) {
13180         SDValue N120 = N12.getOperand(0);
13181         if (isContractableFMUL(N120) &&
13182             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
13183                                 N120.getValueType())) {
13184           return FoldFAddFMAFPExtFMul(N1.getOperand(0), N1.getOperand(1),
13185                                       N120.getOperand(0), N120.getOperand(1),
13186                                       N0);
13187         }
13188       }
13189     }
13190 
13191     // fold (fadd x, (fpext (fma y, z, (fmul u, v))))
13192     //   -> (fma (fpext y), (fpext z), (fma (fpext u), (fpext v), x))
13193     // FIXME: This turns two single-precision and one double-precision
13194     // operation into two double-precision operations, which might not be
13195     // interesting for all targets, especially GPUs.
13196     if (N1.getOpcode() == ISD::FP_EXTEND) {
13197       SDValue N10 = N1.getOperand(0);
13198       if (N10.getOpcode() == PreferredFusedOpcode) {
13199         SDValue N102 = N10.getOperand(2);
13200         if (isContractableFMUL(N102) &&
13201             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
13202                                 N10.getValueType())) {
13203           return FoldFAddFPExtFMAFMul(N10.getOperand(0), N10.getOperand(1),
13204                                       N102.getOperand(0), N102.getOperand(1),
13205                                       N0);
13206         }
13207       }
13208     }
13209   }
13210 
13211   return SDValue();
13212 }
13213 
13214 /// Try to perform FMA combining on a given FSUB node.
13215 SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
13216   SDValue N0 = N->getOperand(0);
13217   SDValue N1 = N->getOperand(1);
13218   EVT VT = N->getValueType(0);
13219   SDLoc SL(N);
13220 
13221   const TargetOptions &Options = DAG.getTarget().Options;
13222   // Floating-point multiply-add with intermediate rounding.
13223   bool HasFMAD = (LegalOperations && TLI.isFMADLegal(DAG, N));
13224 
13225   // Floating-point multiply-add without intermediate rounding.
13226   bool HasFMA =
13227       TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
13228       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
13229 
13230   // No valid opcode, do not combine.
13231   if (!HasFMAD && !HasFMA)
13232     return SDValue();
13233 
13234   const SDNodeFlags Flags = N->getFlags();
13235   bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
13236                               Options.UnsafeFPMath || HasFMAD);
13237 
13238   // If the subtraction is not contractable, do not combine.
13239   if (!AllowFusionGlobally && !N->getFlags().hasAllowContract())
13240     return SDValue();
13241 
13242   if (TLI.generateFMAsInMachineCombiner(VT, OptLevel))
13243     return SDValue();
13244 
13245   // Always prefer FMAD to FMA for precision.
13246   unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
13247   bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
13248   bool NoSignedZero = Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros();
13249 
13250   // Is the node an FMUL and contractable either due to global flags or
13251   // SDNodeFlags.
13252   auto isContractableFMUL = [AllowFusionGlobally](SDValue N) {
13253     if (N.getOpcode() != ISD::FMUL)
13254       return false;
13255     return AllowFusionGlobally || N->getFlags().hasAllowContract();
13256   };
13257 
13258   // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
13259   auto tryToFoldXYSubZ = [&](SDValue XY, SDValue Z) {
13260     if (isContractableFMUL(XY) && (Aggressive || XY->hasOneUse())) {
13261       return DAG.getNode(PreferredFusedOpcode, SL, VT, XY.getOperand(0),
13262                          XY.getOperand(1), DAG.getNode(ISD::FNEG, SL, VT, Z));
13263     }
13264     return SDValue();
13265   };
13266 
13267   // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
13268   // Note: Commutes FSUB operands.
13269   auto tryToFoldXSubYZ = [&](SDValue X, SDValue YZ) {
13270     if (isContractableFMUL(YZ) && (Aggressive || YZ->hasOneUse())) {
13271       return DAG.getNode(PreferredFusedOpcode, SL, VT,
13272                          DAG.getNode(ISD::FNEG, SL, VT, YZ.getOperand(0)),
13273                          YZ.getOperand(1), X);
13274     }
13275     return SDValue();
13276   };
13277 
13278   // If we have two choices trying to fold (fsub (fmul u, v), (fmul x, y)),
13279   // prefer to fold the multiply with fewer uses.
13280   if (isContractableFMUL(N0) && isContractableFMUL(N1) &&
13281       (N0.getNode()->use_size() > N1.getNode()->use_size())) {
13282     // fold (fsub (fmul a, b), (fmul c, d)) -> (fma (fneg c), d, (fmul a, b))
13283     if (SDValue V = tryToFoldXSubYZ(N0, N1))
13284       return V;
13285     // fold (fsub (fmul a, b), (fmul c, d)) -> (fma a, b, (fneg (fmul c, d)))
13286     if (SDValue V = tryToFoldXYSubZ(N0, N1))
13287       return V;
13288   } else {
13289     // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
13290     if (SDValue V = tryToFoldXYSubZ(N0, N1))
13291       return V;
13292     // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
13293     if (SDValue V = tryToFoldXSubYZ(N0, N1))
13294       return V;
13295   }
13296 
13297   // fold (fsub (fneg (fmul x, y)), z) -> (fma (fneg x), y, (fneg z))
13298   if (N0.getOpcode() == ISD::FNEG && isContractableFMUL(N0.getOperand(0)) &&
13299       (Aggressive || (N0->hasOneUse() && N0.getOperand(0).hasOneUse()))) {
13300     SDValue N00 = N0.getOperand(0).getOperand(0);
13301     SDValue N01 = N0.getOperand(0).getOperand(1);
13302     return DAG.getNode(PreferredFusedOpcode, SL, VT,
13303                        DAG.getNode(ISD::FNEG, SL, VT, N00), N01,
13304                        DAG.getNode(ISD::FNEG, SL, VT, N1));
13305   }
13306 
13307   // Look through FP_EXTEND nodes to do more combining.
13308 
13309   // fold (fsub (fpext (fmul x, y)), z)
13310   //   -> (fma (fpext x), (fpext y), (fneg z))
13311   if (N0.getOpcode() == ISD::FP_EXTEND) {
13312     SDValue N00 = N0.getOperand(0);
13313     if (isContractableFMUL(N00) &&
13314         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
13315                             N00.getValueType())) {
13316       return DAG.getNode(PreferredFusedOpcode, SL, VT,
13317                          DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
13318                          DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
13319                          DAG.getNode(ISD::FNEG, SL, VT, N1));
13320     }
13321   }
13322 
13323   // fold (fsub x, (fpext (fmul y, z)))
13324   //   -> (fma (fneg (fpext y)), (fpext z), x)
13325   // Note: Commutes FSUB operands.
13326   if (N1.getOpcode() == ISD::FP_EXTEND) {
13327     SDValue N10 = N1.getOperand(0);
13328     if (isContractableFMUL(N10) &&
13329         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
13330                             N10.getValueType())) {
13331       return DAG.getNode(
13332           PreferredFusedOpcode, SL, VT,
13333           DAG.getNode(ISD::FNEG, SL, VT,
13334                       DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0))),
13335           DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), N0);
13336     }
13337   }
13338 
13339   // fold (fsub (fpext (fneg (fmul x, y))), z)
13340   //   -> (fneg (fma (fpext x), (fpext y), z))
13341   // Note: This could be removed with appropriate canonicalization of the
13342   // input expression into (fneg (fadd (fpext (fmul x, y)), z)). However, the
13343   // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent
13344   // us from implementing the canonicalization in visitFSUB.
13345   if (N0.getOpcode() == ISD::FP_EXTEND) {
13346     SDValue N00 = N0.getOperand(0);
13347     if (N00.getOpcode() == ISD::FNEG) {
13348       SDValue N000 = N00.getOperand(0);
13349       if (isContractableFMUL(N000) &&
13350           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
13351                               N00.getValueType())) {
13352         return DAG.getNode(
13353             ISD::FNEG, SL, VT,
13354             DAG.getNode(PreferredFusedOpcode, SL, VT,
13355                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)),
13356                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)),
13357                         N1));
13358       }
13359     }
13360   }
13361 
13362   // fold (fsub (fneg (fpext (fmul x, y))), z)
13363   //   -> (fneg (fma (fpext x), (fpext y), z))
13364   // Note: This could be removed with appropriate canonicalization of the
13365   // input expression into (fneg (fadd (fpext (fmul x, y)), z)). However, the
13366   // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent
13367   // us from implementing the canonicalization in visitFSUB.
13368   if (N0.getOpcode() == ISD::FNEG) {
13369     SDValue N00 = N0.getOperand(0);
13370     if (N00.getOpcode() == ISD::FP_EXTEND) {
13371       SDValue N000 = N00.getOperand(0);
13372       if (isContractableFMUL(N000) &&
13373           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
13374                               N000.getValueType())) {
13375         return DAG.getNode(
13376             ISD::FNEG, SL, VT,
13377             DAG.getNode(PreferredFusedOpcode, SL, VT,
13378                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)),
13379                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)),
13380                         N1));
13381       }
13382     }
13383   }
13384 
13385   auto isReassociable = [Options](SDNode *N) {
13386     return Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
13387   };
13388 
13389   auto isContractableAndReassociableFMUL = [isContractableFMUL,
13390                                             isReassociable](SDValue N) {
13391     return isContractableFMUL(N) && isReassociable(N.getNode());
13392   };
13393 
13394   // More folding opportunities when target permits.
13395   if (Aggressive && isReassociable(N)) {
13396     bool CanFuse = Options.UnsafeFPMath || N->getFlags().hasAllowContract();
13397     // fold (fsub (fma x, y, (fmul u, v)), z)
13398     //   -> (fma x, y, (fma u, v, (fneg z)))
13399     if (CanFuse && N0.getOpcode() == PreferredFusedOpcode &&
13400         isContractableAndReassociableFMUL(N0.getOperand(2)) &&
13401         N0->hasOneUse() && N0.getOperand(2)->hasOneUse()) {
13402       return DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0),
13403                          N0.getOperand(1),
13404                          DAG.getNode(PreferredFusedOpcode, SL, VT,
13405                                      N0.getOperand(2).getOperand(0),
13406                                      N0.getOperand(2).getOperand(1),
13407                                      DAG.getNode(ISD::FNEG, SL, VT, N1)));
13408     }
13409 
13410     // fold (fsub x, (fma y, z, (fmul u, v)))
13411     //   -> (fma (fneg y), z, (fma (fneg u), v, x))
13412     if (CanFuse && N1.getOpcode() == PreferredFusedOpcode &&
13413         isContractableAndReassociableFMUL(N1.getOperand(2)) &&
13414         N1->hasOneUse() && NoSignedZero) {
13415       SDValue N20 = N1.getOperand(2).getOperand(0);
13416       SDValue N21 = N1.getOperand(2).getOperand(1);
13417       return DAG.getNode(
13418           PreferredFusedOpcode, SL, VT,
13419           DAG.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)), N1.getOperand(1),
13420           DAG.getNode(PreferredFusedOpcode, SL, VT,
13421                       DAG.getNode(ISD::FNEG, SL, VT, N20), N21, N0));
13422     }
13423 
13424     // fold (fsub (fma x, y, (fpext (fmul u, v))), z)
13425     //   -> (fma x, y, (fma (fpext u), (fpext v), (fneg z)))
13426     if (N0.getOpcode() == PreferredFusedOpcode &&
13427         N0->hasOneUse()) {
13428       SDValue N02 = N0.getOperand(2);
13429       if (N02.getOpcode() == ISD::FP_EXTEND) {
13430         SDValue N020 = N02.getOperand(0);
13431         if (isContractableAndReassociableFMUL(N020) &&
13432             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
13433                                 N020.getValueType())) {
13434           return DAG.getNode(
13435               PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1),
13436               DAG.getNode(
13437                   PreferredFusedOpcode, SL, VT,
13438                   DAG.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(0)),
13439                   DAG.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(1)),
13440                   DAG.getNode(ISD::FNEG, SL, VT, N1)));
13441         }
13442       }
13443     }
13444 
13445     // fold (fsub (fpext (fma x, y, (fmul u, v))), z)
13446     //   -> (fma (fpext x), (fpext y),
13447     //           (fma (fpext u), (fpext v), (fneg z)))
13448     // FIXME: This turns two single-precision and one double-precision
13449     // operation into two double-precision operations, which might not be
13450     // interesting for all targets, especially GPUs.
13451     if (N0.getOpcode() == ISD::FP_EXTEND) {
13452       SDValue N00 = N0.getOperand(0);
13453       if (N00.getOpcode() == PreferredFusedOpcode) {
13454         SDValue N002 = N00.getOperand(2);
13455         if (isContractableAndReassociableFMUL(N002) &&
13456             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
13457                                 N00.getValueType())) {
13458           return DAG.getNode(
13459               PreferredFusedOpcode, SL, VT,
13460               DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
13461               DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
13462               DAG.getNode(
13463                   PreferredFusedOpcode, SL, VT,
13464                   DAG.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(0)),
13465                   DAG.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(1)),
13466                   DAG.getNode(ISD::FNEG, SL, VT, N1)));
13467         }
13468       }
13469     }
13470 
13471     // fold (fsub x, (fma y, z, (fpext (fmul u, v))))
13472     //   -> (fma (fneg y), z, (fma (fneg (fpext u)), (fpext v), x))
13473     if (N1.getOpcode() == PreferredFusedOpcode &&
13474         N1.getOperand(2).getOpcode() == ISD::FP_EXTEND &&
13475         N1->hasOneUse()) {
13476       SDValue N120 = N1.getOperand(2).getOperand(0);
13477       if (isContractableAndReassociableFMUL(N120) &&
13478           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
13479                               N120.getValueType())) {
13480         SDValue N1200 = N120.getOperand(0);
13481         SDValue N1201 = N120.getOperand(1);
13482         return DAG.getNode(
13483             PreferredFusedOpcode, SL, VT,
13484             DAG.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)), N1.getOperand(1),
13485             DAG.getNode(PreferredFusedOpcode, SL, VT,
13486                         DAG.getNode(ISD::FNEG, SL, VT,
13487                                     DAG.getNode(ISD::FP_EXTEND, SL, VT, N1200)),
13488                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N1201), N0));
13489       }
13490     }
13491 
13492     // fold (fsub x, (fpext (fma y, z, (fmul u, v))))
13493     //   -> (fma (fneg (fpext y)), (fpext z),
13494     //           (fma (fneg (fpext u)), (fpext v), x))
13495     // FIXME: This turns two single-precision and one double-precision
13496     // operation into two double-precision operations, which might not be
13497     // interesting for all targets, especially GPUs.
13498     if (N1.getOpcode() == ISD::FP_EXTEND &&
13499         N1.getOperand(0).getOpcode() == PreferredFusedOpcode) {
13500       SDValue CvtSrc = N1.getOperand(0);
13501       SDValue N100 = CvtSrc.getOperand(0);
13502       SDValue N101 = CvtSrc.getOperand(1);
13503       SDValue N102 = CvtSrc.getOperand(2);
13504       if (isContractableAndReassociableFMUL(N102) &&
13505           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
13506                               CvtSrc.getValueType())) {
13507         SDValue N1020 = N102.getOperand(0);
13508         SDValue N1021 = N102.getOperand(1);
13509         return DAG.getNode(
13510             PreferredFusedOpcode, SL, VT,
13511             DAG.getNode(ISD::FNEG, SL, VT,
13512                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N100)),
13513             DAG.getNode(ISD::FP_EXTEND, SL, VT, N101),
13514             DAG.getNode(PreferredFusedOpcode, SL, VT,
13515                         DAG.getNode(ISD::FNEG, SL, VT,
13516                                     DAG.getNode(ISD::FP_EXTEND, SL, VT, N1020)),
13517                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N1021), N0));
13518       }
13519     }
13520   }
13521 
13522   return SDValue();
13523 }
13524 
13525 /// Try to perform FMA combining on a given FMUL node based on the distributive
13526 /// law x * (y + 1) = x * y + x and variants thereof (commuted versions,
13527 /// subtraction instead of addition).
13528 SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) {
13529   SDValue N0 = N->getOperand(0);
13530   SDValue N1 = N->getOperand(1);
13531   EVT VT = N->getValueType(0);
13532   SDLoc SL(N);
13533 
13534   assert(N->getOpcode() == ISD::FMUL && "Expected FMUL Operation");
13535 
13536   const TargetOptions &Options = DAG.getTarget().Options;
13537 
13538   // The transforms below are incorrect when x == 0 and y == inf, because the
13539   // intermediate multiplication produces a nan.
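  // e.g. with x0 == 0.0 and y == inf, (fmul (fadd x0, 1.0), y) evaluates to
  // inf, but the fused (fma x0, y, y) evaluates to 0.0 * inf + inf = nan.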
13540   if (!Options.NoInfsFPMath)
13541     return SDValue();
13542 
13543   // Floating-point multiply-add without intermediate rounding.
13544   bool HasFMA =
13545       (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath) &&
13546       TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
13547       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
13548 
13549   // Floating-point multiply-add with intermediate rounding. This can result
13550   // in a less precise result due to the changed rounding order.
13551   bool HasFMAD = Options.UnsafeFPMath &&
13552                  (LegalOperations && TLI.isFMADLegal(DAG, N));
13553 
13554   // No valid opcode, do not combine.
13555   if (!HasFMAD && !HasFMA)
13556     return SDValue();
13557 
13558   // Always prefer FMAD to FMA for precision.
13559   unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
13560   bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
13561 
13562   // fold (fmul (fadd x0, +1.0), y) -> (fma x0, y, y)
13563   // fold (fmul (fadd x0, -1.0), y) -> (fma x0, y, (fneg y))
13564   auto FuseFADD = [&](SDValue X, SDValue Y) {
13565     if (X.getOpcode() == ISD::FADD && (Aggressive || X->hasOneUse())) {
13566       if (auto *C = isConstOrConstSplatFP(X.getOperand(1), true)) {
13567         if (C->isExactlyValue(+1.0))
13568           return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
13569                              Y);
13570         if (C->isExactlyValue(-1.0))
13571           return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
13572                              DAG.getNode(ISD::FNEG, SL, VT, Y));
13573       }
13574     }
13575     return SDValue();
13576   };
13577 
13578   if (SDValue FMA = FuseFADD(N0, N1))
13579     return FMA;
13580   if (SDValue FMA = FuseFADD(N1, N0))
13581     return FMA;
13582 
13583   // fold (fmul (fsub +1.0, x1), y) -> (fma (fneg x1), y, y)
13584   // fold (fmul (fsub -1.0, x1), y) -> (fma (fneg x1), y, (fneg y))
13585   // fold (fmul (fsub x0, +1.0), y) -> (fma x0, y, (fneg y))
13586   // fold (fmul (fsub x0, -1.0), y) -> (fma x0, y, y)
13587   auto FuseFSUB = [&](SDValue X, SDValue Y) {
13588     if (X.getOpcode() == ISD::FSUB && (Aggressive || X->hasOneUse())) {
13589       if (auto *C0 = isConstOrConstSplatFP(X.getOperand(0), true)) {
13590         if (C0->isExactlyValue(+1.0))
13591           return DAG.getNode(PreferredFusedOpcode, SL, VT,
13592                              DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
13593                              Y);
13594         if (C0->isExactlyValue(-1.0))
13595           return DAG.getNode(PreferredFusedOpcode, SL, VT,
13596                              DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
13597                              DAG.getNode(ISD::FNEG, SL, VT, Y));
13598       }
13599       if (auto *C1 = isConstOrConstSplatFP(X.getOperand(1), true)) {
13600         if (C1->isExactlyValue(+1.0))
13601           return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
13602                              DAG.getNode(ISD::FNEG, SL, VT, Y));
13603         if (C1->isExactlyValue(-1.0))
13604           return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
13605                              Y);
13606       }
13607     }
13608     return SDValue();
13609   };
13610 
13611   if (SDValue FMA = FuseFSUB(N0, N1))
13612     return FMA;
13613   if (SDValue FMA = FuseFSUB(N1, N0))
13614     return FMA;
13615 
13616   return SDValue();
13617 }
13618 
13619 SDValue DAGCombiner::visitFADD(SDNode *N) {
13620   SDValue N0 = N->getOperand(0);
13621   SDValue N1 = N->getOperand(1);
13622   bool N0CFP = DAG.isConstantFPBuildVectorOrConstantFP(N0);
13623   bool N1CFP = DAG.isConstantFPBuildVectorOrConstantFP(N1);
13624   EVT VT = N->getValueType(0);
13625   SDLoc DL(N);
13626   const TargetOptions &Options = DAG.getTarget().Options;
13627   SDNodeFlags Flags = N->getFlags();
13628   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
13629 
13630   if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
13631     return R;
13632 
13633   // fold vector ops
13634   if (VT.isVector())
13635     if (SDValue FoldedVOp = SimplifyVBinOp(N))
13636       return FoldedVOp;
13637 
13638   // fold (fadd c1, c2) -> c1 + c2
13639   if (N0CFP && N1CFP)
13640     return DAG.getNode(ISD::FADD, DL, VT, N0, N1);
13641 
13642   // canonicalize constant to RHS
13643   if (N0CFP && !N1CFP)
13644     return DAG.getNode(ISD::FADD, DL, VT, N1, N0);
13645 
13646   // N0 + -0.0 --> N0 (also allowed with +0.0 and fast-math)
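  // (Adding +0.0 is not a no-op without nsz: if N0 == -0.0, then N0 + 0.0
  // yields +0.0, flipping the sign of the result.)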
13647   ConstantFPSDNode *N1C = isConstOrConstSplatFP(N1, true);
13648   if (N1C && N1C->isZero())
13649     if (N1C->isNegative() || Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())
13650       return N0;
13651 
13652   if (SDValue NewSel = foldBinOpIntoSelect(N))
13653     return NewSel;
13654 
13655   // fold (fadd A, (fneg B)) -> (fsub A, B)
13656   if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT))
13657     if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
13658             N1, DAG, LegalOperations, ForCodeSize))
13659       return DAG.getNode(ISD::FSUB, DL, VT, N0, NegN1);
13660 
13661   // fold (fadd (fneg A), B) -> (fsub B, A)
13662   if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT))
13663     if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
13664             N0, DAG, LegalOperations, ForCodeSize))
13665       return DAG.getNode(ISD::FSUB, DL, VT, N1, NegN0);
13666 
13667   auto isFMulNegTwo = [](SDValue FMul) {
13668     if (!FMul.hasOneUse() || FMul.getOpcode() != ISD::FMUL)
13669       return false;
13670     auto *C = isConstOrConstSplatFP(FMul.getOperand(1), true);
13671     return C && C->isExactlyValue(-2.0);
13672   };
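  // These folds use the identity A + B * -2.0 == A - (B + B), trading the
  // multiply and its constant for an add and a subtract.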
13673 
13674   // fadd (fmul B, -2.0), A --> fsub A, (fadd B, B)
13675   if (isFMulNegTwo(N0)) {
13676     SDValue B = N0.getOperand(0);
13677     SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B);
13678     return DAG.getNode(ISD::FSUB, DL, VT, N1, Add);
13679   }
13680   // fadd A, (fmul B, -2.0) --> fsub A, (fadd B, B)
13681   if (isFMulNegTwo(N1)) {
13682     SDValue B = N1.getOperand(0);
13683     SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B);
13684     return DAG.getNode(ISD::FSUB, DL, VT, N0, Add);
13685   }
13686 
13687   // No FP constant should be created after legalization as the Instruction
13688   // Selection pass has a hard time dealing with FP constants.
13689   bool AllowNewConst = (Level < AfterLegalizeDAG);
13690 
13691   // If nnan is enabled, fold lots of things.
13692   if ((Options.NoNaNsFPMath || Flags.hasNoNaNs()) && AllowNewConst) {
13693     // If allowed, fold (fadd (fneg x), x) -> 0.0
13694     if (N0.getOpcode() == ISD::FNEG && N0.getOperand(0) == N1)
13695       return DAG.getConstantFP(0.0, DL, VT);
13696 
13697     // If allowed, fold (fadd x, (fneg x)) -> 0.0
13698     if (N1.getOpcode() == ISD::FNEG && N1.getOperand(0) == N0)
13699       return DAG.getConstantFP(0.0, DL, VT);
13700   }
13701 
13702   // If 'unsafe math' or reassoc and nsz, fold lots of things.
13703   // TODO: break out portions of the transformations below for which Unsafe is
13704   //       considered and which do not require both nsz and reassoc
13705   if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
13706        (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
13707       AllowNewConst) {
13708     // fadd (fadd x, c1), c2 -> fadd x, c1 + c2
13709     if (N1CFP && N0.getOpcode() == ISD::FADD &&
13710         DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) {
13711       SDValue NewC = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1), N1);
13712       return DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(0), NewC);
13713     }
13714 
13715     // We can fold chains of FADD's of the same value into multiplications.
13716     // This transform is not safe in general because we are reducing the number
13717     // of rounding steps.
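    // e.g. (fadd (fmul x, c), x) evaluates round(round(c * x) + x), while the
    // combined (fmul x, c+1) evaluates round((c + 1) * x); the two can differ
    // in the last bit.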
13718     if (TLI.isOperationLegalOrCustom(ISD::FMUL, VT) && !N0CFP && !N1CFP) {
13719       if (N0.getOpcode() == ISD::FMUL) {
13720         bool CFP00 = DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(0));
13721         bool CFP01 = DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1));
13722 
13723         // (fadd (fmul x, c), x) -> (fmul x, c+1)
13724         if (CFP01 && !CFP00 && N0.getOperand(0) == N1) {
13725           SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1),
13726                                        DAG.getConstantFP(1.0, DL, VT));
13727           return DAG.getNode(ISD::FMUL, DL, VT, N1, NewCFP);
13728         }
13729 
13730         // (fadd (fmul x, c), (fadd x, x)) -> (fmul x, c+2)
13731         if (CFP01 && !CFP00 && N1.getOpcode() == ISD::FADD &&
13732             N1.getOperand(0) == N1.getOperand(1) &&
13733             N0.getOperand(0) == N1.getOperand(0)) {
13734           SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1),
13735                                        DAG.getConstantFP(2.0, DL, VT));
13736           return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), NewCFP);
13737         }
13738       }
13739 
13740       if (N1.getOpcode() == ISD::FMUL) {
13741         bool CFP10 = DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(0));
13742         bool CFP11 = DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(1));
13743 
13744         // (fadd x, (fmul x, c)) -> (fmul x, c+1)
13745         if (CFP11 && !CFP10 && N1.getOperand(0) == N0) {
13746           SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N1.getOperand(1),
13747                                        DAG.getConstantFP(1.0, DL, VT));
13748           return DAG.getNode(ISD::FMUL, DL, VT, N0, NewCFP);
13749         }
13750 
13751         // (fadd (fadd x, x), (fmul x, c)) -> (fmul x, c+2)
13752         if (CFP11 && !CFP10 && N0.getOpcode() == ISD::FADD &&
13753             N0.getOperand(0) == N0.getOperand(1) &&
13754             N1.getOperand(0) == N0.getOperand(0)) {
13755           SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N1.getOperand(1),
13756                                        DAG.getConstantFP(2.0, DL, VT));
13757           return DAG.getNode(ISD::FMUL, DL, VT, N1.getOperand(0), NewCFP);
13758         }
13759       }
13760 
13761       if (N0.getOpcode() == ISD::FADD) {
13762         bool CFP00 = DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(0));
13763         // (fadd (fadd x, x), x) -> (fmul x, 3.0)
13764         if (!CFP00 && N0.getOperand(0) == N0.getOperand(1) &&
13765             (N0.getOperand(0) == N1)) {
13766           return DAG.getNode(ISD::FMUL, DL, VT, N1,
13767                              DAG.getConstantFP(3.0, DL, VT));
13768         }
13769       }
13770 
13771       if (N1.getOpcode() == ISD::FADD) {
13772         bool CFP10 = DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(0));
13773         // (fadd x, (fadd x, x)) -> (fmul x, 3.0)
13774         if (!CFP10 && N1.getOperand(0) == N1.getOperand(1) &&
13775             N1.getOperand(0) == N0) {
13776           return DAG.getNode(ISD::FMUL, DL, VT, N0,
13777                              DAG.getConstantFP(3.0, DL, VT));
13778         }
13779       }
13780 
13781       // (fadd (fadd x, x), (fadd x, x)) -> (fmul x, 4.0)
13782       if (N0.getOpcode() == ISD::FADD && N1.getOpcode() == ISD::FADD &&
13783           N0.getOperand(0) == N0.getOperand(1) &&
13784           N1.getOperand(0) == N1.getOperand(1) &&
13785           N0.getOperand(0) == N1.getOperand(0)) {
13786         return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0),
13787                            DAG.getConstantFP(4.0, DL, VT));
13788       }
13789     }
13790   } // enable-unsafe-fp-math
13791 
13792   // FADD -> FMA combines:
13793   if (SDValue Fused = visitFADDForFMACombine(N)) {
13794     AddToWorklist(Fused.getNode());
13795     return Fused;
13796   }
13797   return SDValue();
13798 }
13799 
13800 SDValue DAGCombiner::visitSTRICT_FADD(SDNode *N) {
13801   SDValue Chain = N->getOperand(0);
13802   SDValue N0 = N->getOperand(1);
13803   SDValue N1 = N->getOperand(2);
13804   EVT VT = N->getValueType(0);
13805   EVT ChainVT = N->getValueType(1);
13806   SDLoc DL(N);
13807   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
13808 
13809   // fold (strict_fadd A, (fneg B)) -> (strict_fsub A, B)
13810   if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::STRICT_FSUB, VT))
13811     if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
13812             N1, DAG, LegalOperations, ForCodeSize)) {
13813       return DAG.getNode(ISD::STRICT_FSUB, DL, DAG.getVTList(VT, ChainVT),
13814                          {Chain, N0, NegN1});
13815     }
13816 
13817   // fold (strict_fadd (fneg A), B) -> (strict_fsub B, A)
13818   if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::STRICT_FSUB, VT))
13819     if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
13820             N0, DAG, LegalOperations, ForCodeSize)) {
13821       return DAG.getNode(ISD::STRICT_FSUB, DL, DAG.getVTList(VT, ChainVT),
13822                          {Chain, N1, NegN0});
13823     }
13824   return SDValue();
13825 }
13826 
13827 SDValue DAGCombiner::visitFSUB(SDNode *N) {
13828   SDValue N0 = N->getOperand(0);
13829   SDValue N1 = N->getOperand(1);
13830   ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, true);
13831   ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true);
13832   EVT VT = N->getValueType(0);
13833   SDLoc DL(N);
13834   const TargetOptions &Options = DAG.getTarget().Options;
13835   const SDNodeFlags Flags = N->getFlags();
13836   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
13837 
13838   if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
13839     return R;
13840 
13841   // fold vector ops
13842   if (VT.isVector())
13843     if (SDValue FoldedVOp = SimplifyVBinOp(N))
13844       return FoldedVOp;
13845 
13846   // fold (fsub c1, c2) -> c1-c2
13847   if (N0CFP && N1CFP)
13848     return DAG.getNode(ISD::FSUB, DL, VT, N0, N1);
13849 
13850   if (SDValue NewSel = foldBinOpIntoSelect(N))
13851     return NewSel;
13852 
13853   // (fsub A, 0) -> A
13854   if (N1CFP && N1CFP->isZero()) {
13855     if (!N1CFP->isNegative() || Options.NoSignedZerosFPMath ||
13856         Flags.hasNoSignedZeros()) {
13857       return N0;
13858     }
13859   }
13860 
13861   if (N0 == N1) {
13862     // (fsub x, x) -> 0.0
13863     if (Options.NoNaNsFPMath || Flags.hasNoNaNs())
13864       return DAG.getConstantFP(0.0f, DL, VT);
13865   }
13866 
13867   // (fsub -0.0, N1) -> -N1
13868   if (N0CFP && N0CFP->isZero()) {
13869     if (N0CFP->isNegative() ||
13870         (Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())) {
13871       // We cannot replace an FSUB(+-0.0,X) with FNEG(X) when denormals are
13872       // flushed to zero, unless all users treat denorms as zero (DAZ).
13873       // FIXME: This transform will change the sign of a NaN and the behavior
13874       // of a signaling NaN. It is only valid when a NoNaN flag is present.
13875       DenormalMode DenormMode = DAG.getDenormalMode(VT);
13876       if (DenormMode == DenormalMode::getIEEE()) {
13877         if (SDValue NegN1 =
13878                 TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize))
13879           return NegN1;
13880         if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))
13881           return DAG.getNode(ISD::FNEG, DL, VT, N1);
13882       }
13883     }
13884   }
13885 
13886   if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
13887        (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
13888       N1.getOpcode() == ISD::FADD) {
13889     // X - (X + Y) -> -Y
13890     if (N0 == N1->getOperand(0))
13891       return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(1));
13892     // X - (Y + X) -> -Y
13893     if (N0 == N1->getOperand(1))
13894       return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(0));
13895   }
13896 
13897   // fold (fsub A, (fneg B)) -> (fadd A, B)
13898   if (SDValue NegN1 =
13899           TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize))
13900     return DAG.getNode(ISD::FADD, DL, VT, N0, NegN1);
13901 
13902   // FSUB -> FMA combines:
13903   if (SDValue Fused = visitFSUBForFMACombine(N)) {
13904     AddToWorklist(Fused.getNode());
13905     return Fused;
13906   }
13907 
13908   return SDValue();
13909 }
13910 
visitFMUL(SDNode * N)13911 SDValue DAGCombiner::visitFMUL(SDNode *N) {
13912   SDValue N0 = N->getOperand(0);
13913   SDValue N1 = N->getOperand(1);
13914   ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, true);
13915   ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true);
13916   EVT VT = N->getValueType(0);
13917   SDLoc DL(N);
13918   const TargetOptions &Options = DAG.getTarget().Options;
13919   const SDNodeFlags Flags = N->getFlags();
13920   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
13921 
13922   if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
13923     return R;
13924 
13925   // fold vector ops
13926   if (VT.isVector()) {
13927     // This just handles C1 * C2 for vectors. Other vector folds are below.
13928     if (SDValue FoldedVOp = SimplifyVBinOp(N))
13929       return FoldedVOp;
13930   }
13931 
13932   // fold (fmul c1, c2) -> c1*c2
13933   if (N0CFP && N1CFP)
13934     return DAG.getNode(ISD::FMUL, DL, VT, N0, N1);
13935 
13936   // canonicalize constant to RHS
13937   if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
13938      !DAG.isConstantFPBuildVectorOrConstantFP(N1))
13939     return DAG.getNode(ISD::FMUL, DL, VT, N1, N0);
13940 
13941   if (SDValue NewSel = foldBinOpIntoSelect(N))
13942     return NewSel;
13943 
13944   if (Options.UnsafeFPMath || Flags.hasAllowReassociation()) {
13945     // fmul (fmul X, C1), C2 -> fmul X, C1 * C2
13946     if (DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
13947         N0.getOpcode() == ISD::FMUL) {
13948       SDValue N00 = N0.getOperand(0);
13949       SDValue N01 = N0.getOperand(1);
13950       // Avoid an infinite loop by making sure that N00 is not a constant
13951       // (the inner multiply has not been constant folded yet).
13952       if (DAG.isConstantFPBuildVectorOrConstantFP(N01) &&
13953           !DAG.isConstantFPBuildVectorOrConstantFP(N00)) {
13954         SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, N01, N1);
13955         return DAG.getNode(ISD::FMUL, DL, VT, N00, MulConsts);
13956       }
13957     }
13958 
13959     // Match a special-case: we convert X * 2.0 into fadd.
13960     // fmul (fadd X, X), C -> fmul X, 2.0 * C
13961     if (N0.getOpcode() == ISD::FADD && N0.hasOneUse() &&
13962         N0.getOperand(0) == N0.getOperand(1)) {
13963       const SDValue Two = DAG.getConstantFP(2.0, DL, VT);
13964       SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, Two, N1);
13965       return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), MulConsts);
13966     }
13967   }
13968 
13969   // fold (fmul X, 2.0) -> (fadd X, X)
13970   if (N1CFP && N1CFP->isExactlyValue(+2.0))
13971     return DAG.getNode(ISD::FADD, DL, VT, N0, N0);
13972 
13973   // fold (fmul X, -1.0) -> (fneg X)
13974   if (N1CFP && N1CFP->isExactlyValue(-1.0))
13975     if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))
13976       return DAG.getNode(ISD::FNEG, DL, VT, N0);
13977 
13978   // -N0 * -N1 --> N0 * N1
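  // (This fires only when at least one operand is strictly cheaper to
  // materialize in negated form; if both negations merely break even, the
  // rewrite would not shrink the DAG.)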
  TargetLowering::NegatibleCost CostN0 =
      TargetLowering::NegatibleCost::Expensive;
  TargetLowering::NegatibleCost CostN1 =
      TargetLowering::NegatibleCost::Expensive;
  SDValue NegN0 =
      TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
  SDValue NegN1 =
      TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
  if (NegN0 && NegN1 &&
      (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
       CostN1 == TargetLowering::NegatibleCost::Cheaper))
    return DAG.getNode(ISD::FMUL, DL, VT, NegN0, NegN1);

  // fold (fmul X, (select (fcmp X > 0.0), -1.0, 1.0)) -> (fneg (fabs X))
  // fold (fmul X, (select (fcmp X > 0.0), 1.0, -1.0)) -> (fabs X)
  if (Flags.hasNoNaNs() && Flags.hasNoSignedZeros() &&
      (N0.getOpcode() == ISD::SELECT || N1.getOpcode() == ISD::SELECT) &&
      TLI.isOperationLegal(ISD::FABS, VT)) {
    SDValue Select = N0, X = N1;
    if (Select.getOpcode() != ISD::SELECT)
      std::swap(Select, X);

    SDValue Cond = Select.getOperand(0);
    auto TrueOpnd  = dyn_cast<ConstantFPSDNode>(Select.getOperand(1));
    auto FalseOpnd = dyn_cast<ConstantFPSDNode>(Select.getOperand(2));

    if (TrueOpnd && FalseOpnd &&
        Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == X &&
        isa<ConstantFPSDNode>(Cond.getOperand(1)) &&
        cast<ConstantFPSDNode>(Cond.getOperand(1))->isExactlyValue(0.0)) {
      ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
      switch (CC) {
      default: break;
      case ISD::SETOLT:
      case ISD::SETULT:
      case ISD::SETOLE:
      case ISD::SETULE:
      case ISD::SETLT:
      case ISD::SETLE:
        std::swap(TrueOpnd, FalseOpnd);
        LLVM_FALLTHROUGH;
      case ISD::SETOGT:
      case ISD::SETUGT:
      case ISD::SETOGE:
      case ISD::SETUGE:
      case ISD::SETGT:
      case ISD::SETGE:
        if (TrueOpnd->isExactlyValue(-1.0) && FalseOpnd->isExactlyValue(1.0) &&
            TLI.isOperationLegal(ISD::FNEG, VT))
          return DAG.getNode(ISD::FNEG, DL, VT,
                   DAG.getNode(ISD::FABS, DL, VT, X));
        if (TrueOpnd->isExactlyValue(1.0) && FalseOpnd->isExactlyValue(-1.0))
          return DAG.getNode(ISD::FABS, DL, VT, X);

        break;
      }
    }
  }

  // FMUL -> FMA combines:
  if (SDValue Fused = visitFMULForFMADistributiveCombine(N)) {
    AddToWorklist(Fused.getNode());
    return Fused;
  }

  return SDValue();
}

SDValue DAGCombiner::visitFMA(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue N2 = N->getOperand(2);
  ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
  ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);
  const TargetOptions &Options = DAG.getTarget().Options;
  // FMA nodes have flags that propagate to the created nodes.
  SelectionDAG::FlagInserter FlagsInserter(DAG, N);

  bool UnsafeFPMath =
      Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();

  // Constant fold FMA.
  if (isa<ConstantFPSDNode>(N0) &&
      isa<ConstantFPSDNode>(N1) &&
      isa<ConstantFPSDNode>(N2)) {
    return DAG.getNode(ISD::FMA, DL, VT, N0, N1, N2);
  }
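  // (When all three operands are constants, getNode is expected to
  // constant-fold the FMA itself, e.g. fma(2.0, 3.0, 1.0) -> 7.0, so the
  // call above returns a folded constant rather than a new FMA node.)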

  // (-N0 * -N1) + N2 --> (N0 * N1) + N2
  TargetLowering::NegatibleCost CostN0 =
      TargetLowering::NegatibleCost::Expensive;
  TargetLowering::NegatibleCost CostN1 =
      TargetLowering::NegatibleCost::Expensive;
  SDValue NegN0 =
      TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
  SDValue NegN1 =
      TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
  if (NegN0 && NegN1 &&
      (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
       CostN1 == TargetLowering::NegatibleCost::Cheaper))
    return DAG.getNode(ISD::FMA, DL, VT, NegN0, NegN1, N2);

  if (UnsafeFPMath) {
    if (N0CFP && N0CFP->isZero())
      return N2;
    if (N1CFP && N1CFP->isZero())
      return N2;
  }

  if (N0CFP && N0CFP->isExactlyValue(1.0))
    return DAG.getNode(ISD::FADD, SDLoc(N), VT, N1, N2);
  if (N1CFP && N1CFP->isExactlyValue(1.0))
    return DAG.getNode(ISD::FADD, SDLoc(N), VT, N0, N2);

  // Canonicalize (fma c, x, y) -> (fma x, c, y)
  if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
     !DAG.isConstantFPBuildVectorOrConstantFP(N1))
    return DAG.getNode(ISD::FMA, SDLoc(N), VT, N1, N0, N2);

  if (UnsafeFPMath) {
    // (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2)
    if (N2.getOpcode() == ISD::FMUL && N0 == N2.getOperand(0) &&
        DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
        DAG.isConstantFPBuildVectorOrConstantFP(N2.getOperand(1))) {
      return DAG.getNode(ISD::FMUL, DL, VT, N0,
                         DAG.getNode(ISD::FADD, DL, VT, N1, N2.getOperand(1)));
    }

    // (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y)
    if (N0.getOpcode() == ISD::FMUL &&
        DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
        DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) {
      return DAG.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
                         DAG.getNode(ISD::FMUL, DL, VT, N1, N0.getOperand(1)),
                         N2);
    }
  }

  // (fma x, 1, y) -> (fadd x, y)
  // (fma x, -1, y) -> (fadd (fneg x), y)
  if (N1CFP) {
    if (N1CFP->isExactlyValue(1.0))
      return DAG.getNode(ISD::FADD, DL, VT, N0, N2);

    if (N1CFP->isExactlyValue(-1.0) &&
        (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))) {
      SDValue RHSNeg = DAG.getNode(ISD::FNEG, DL, VT, N0);
      AddToWorklist(RHSNeg.getNode());
      return DAG.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
    }

    // fma (fneg x), K, y -> fma x, -K, y
    if (N0.getOpcode() == ISD::FNEG &&
        (TLI.isOperationLegal(ISD::ConstantFP, VT) ||
         (N1.hasOneUse() && !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT,
                                              ForCodeSize)))) {
      return DAG.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
                         DAG.getNode(ISD::FNEG, DL, VT, N1), N2);
    }
  }

  if (UnsafeFPMath) {
    // (fma x, c, x) -> (fmul x, (c+1))
    if (N1CFP && N0 == N2) {
      return DAG.getNode(
          ISD::FMUL, DL, VT, N0,
          DAG.getNode(ISD::FADD, DL, VT, N1, DAG.getConstantFP(1.0, DL, VT)));
    }

    // (fma x, c, (fneg x)) -> (fmul x, (c-1))
    if (N1CFP && N2.getOpcode() == ISD::FNEG && N2.getOperand(0) == N0) {
      return DAG.getNode(
          ISD::FMUL, DL, VT, N0,
          DAG.getNode(ISD::FADD, DL, VT, N1, DAG.getConstantFP(-1.0, DL, VT)));
    }
  }

  // fold ((fma (fneg X), Y, (fneg Z)) -> fneg (fma X, Y, Z))
  // fold ((fma X, (fneg Y), (fneg Z)) -> fneg (fma X, Y, Z))
  if (!TLI.isFNegFree(VT))
    if (SDValue Neg = TLI.getCheaperNegatedExpression(
            SDValue(N, 0), DAG, LegalOperations, ForCodeSize))
      return DAG.getNode(ISD::FNEG, DL, VT, Neg);
  return SDValue();
}

// Combine multiple FDIVs with the same divisor into multiple FMULs by the
// reciprocal.
// E.g., (a / D; b / D;) -> (recip = 1.0 / D; a * recip; b * recip)
// Notice that this is not always beneficial. One reason is that different
// targets may have different costs for FDIV and FMUL, so sometimes the cost
// of two FDIVs may be lower than the cost of one FDIV and two FMULs. Another
// reason is that the critical path is increased from "one FDIV" to
// "one FDIV + one FMUL".
SDValue DAGCombiner::combineRepeatedFPDivisors(SDNode *N) {
  // TODO: Limit this transform based on optsize/minsize - it always creates at
  //       least 1 extra instruction. But the perf win may be substantial enough
  //       that only minsize should restrict this.
  bool UnsafeMath = DAG.getTarget().Options.UnsafeFPMath;
  const SDNodeFlags Flags = N->getFlags();
  if (LegalDAG || (!UnsafeMath && !Flags.hasAllowReciprocal()))
    return SDValue();

  // Skip if the current node is a reciprocal/fneg-reciprocal.
  SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
  ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, /* AllowUndefs */ true);
  if (N0CFP && (N0CFP->isExactlyValue(1.0) || N0CFP->isExactlyValue(-1.0)))
    return SDValue();

  // Exit early if the target does not want this transform or if there can't
  // possibly be enough uses of the divisor to make the transform worthwhile.
  unsigned MinUses = TLI.combineRepeatedFPDivisors();

  // For splat vectors, scale the number of uses by the splat factor. If we can
  // convert the division into a scalar op, that will likely be much faster.
  unsigned NumElts = 1;
  EVT VT = N->getValueType(0);
  if (VT.isVector() && DAG.isSplatValue(N1))
    NumElts = VT.getVectorNumElements();

  if (!MinUses || (N1->use_size() * NumElts) < MinUses)
    return SDValue();

  // Find all FDIV users of the same divisor.
  // Use a set because duplicates may be present in the user list.
  SetVector<SDNode *> Users;
  for (auto *U : N1->uses()) {
    if (U->getOpcode() == ISD::FDIV && U->getOperand(1) == N1) {
      // Skip X/sqrt(X) that has not been simplified to sqrt(X) yet.
      if (U->getOperand(1).getOpcode() == ISD::FSQRT &&
          U->getOperand(0) == U->getOperand(1).getOperand(0) &&
          U->getFlags().hasAllowReassociation() &&
          U->getFlags().hasNoSignedZeros())
        continue;

      // This division is eligible for optimization only if global unsafe math
      // is enabled or if this division allows reciprocal formation.
      if (UnsafeMath || U->getFlags().hasAllowReciprocal())
        Users.insert(U);
    }
  }

  // Now that we have the actual number of divisor uses, make sure it meets
  // the minimum threshold specified by the target.
  if ((Users.size() * NumElts) < MinUses)
    return SDValue();

  SDLoc DL(N);
  SDValue FPOne = DAG.getConstantFP(1.0, DL, VT);
  SDValue Reciprocal = DAG.getNode(ISD::FDIV, DL, VT, FPOne, N1, Flags);

  // Dividend / Divisor -> Dividend * Reciprocal
  for (auto *U : Users) {
    SDValue Dividend = U->getOperand(0);
    if (Dividend != FPOne) {
      SDValue NewNode = DAG.getNode(ISD::FMUL, SDLoc(U), VT, Dividend,
                                    Reciprocal, Flags);
      CombineTo(U, NewNode);
    } else if (U != Reciprocal.getNode()) {
      // In the absence of fast-math-flags, this user node is always the
      // same node as Reciprocal, but with FMF they may be different nodes.
      CombineTo(U, Reciprocal);
    }
  }
  return SDValue(N, 0);  // N was replaced.
}

SDValue DAGCombiner::visitFDIV(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
  ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);
  const TargetOptions &Options = DAG.getTarget().Options;
  SDNodeFlags Flags = N->getFlags();
  SelectionDAG::FlagInserter FlagsInserter(DAG, N);

  if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
    return R;

  // fold vector ops
  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N))
      return FoldedVOp;

  // fold (fdiv c1, c2) -> c1/c2
  if (N0CFP && N1CFP)
    return DAG.getNode(ISD::FDIV, SDLoc(N), VT, N0, N1);

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  if (SDValue V = combineRepeatedFPDivisors(N))
    return V;

  if (Options.UnsafeFPMath || Flags.hasAllowReciprocal()) {
    // fold (fdiv X, c2) -> fmul X, 1/c2 if losing precision is acceptable.
    if (N1CFP) {
      // Compute the reciprocal 1.0 / c2.
      const APFloat &N1APF = N1CFP->getValueAPF();
      APFloat Recip(N1APF.getSemantics(), 1); // 1.0
      APFloat::opStatus st = Recip.divide(N1APF, APFloat::rmNearestTiesToEven);
      // Only do the transform if the reciprocal is a legal fp immediate that
      // isn't too nasty (eg NaN, denormal, ...).
      if ((st == APFloat::opOK || st == APFloat::opInexact) && // Not too nasty
          (!LegalOperations ||
           // FIXME: custom lowering of ConstantFP might fail (see e.g. ARM
           // backend)... we should handle this gracefully after Legalize.
           // TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT) ||
           TLI.isOperationLegal(ISD::ConstantFP, VT) ||
           TLI.isFPImmLegal(Recip, VT, ForCodeSize)))
        return DAG.getNode(ISD::FMUL, DL, VT, N0,
                           DAG.getConstantFP(Recip, DL, VT));
    }
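    // For example, with reciprocal math enabled, X / 4.0 becomes X * 0.25
    // exactly, and X / 3.0 becomes X * (1.0 / 3.0) with only opInexact
    // status, so both pass the check above (subject to the FP immediate
    // being legal).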

    // If this FDIV is part of a reciprocal square root, it may be folded
    // into a target-specific square root estimate instruction.
    if (N1.getOpcode() == ISD::FSQRT) {
      if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0), Flags))
        return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
    } else if (N1.getOpcode() == ISD::FP_EXTEND &&
               N1.getOperand(0).getOpcode() == ISD::FSQRT) {
      if (SDValue RV =
              buildRsqrtEstimate(N1.getOperand(0).getOperand(0), Flags)) {
        RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(N1), VT, RV);
        AddToWorklist(RV.getNode());
        return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
      }
    } else if (N1.getOpcode() == ISD::FP_ROUND &&
               N1.getOperand(0).getOpcode() == ISD::FSQRT) {
      if (SDValue RV =
              buildRsqrtEstimate(N1.getOperand(0).getOperand(0), Flags)) {
        RV = DAG.getNode(ISD::FP_ROUND, SDLoc(N1), VT, RV, N1.getOperand(1));
        AddToWorklist(RV.getNode());
        return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
      }
    } else if (N1.getOpcode() == ISD::FMUL) {
      // Look through an FMUL. Even though this won't remove the FDIV directly,
      // it's still worthwhile to get rid of the FSQRT if possible.
      SDValue Sqrt, Y;
      if (N1.getOperand(0).getOpcode() == ISD::FSQRT) {
        Sqrt = N1.getOperand(0);
        Y = N1.getOperand(1);
      } else if (N1.getOperand(1).getOpcode() == ISD::FSQRT) {
        Sqrt = N1.getOperand(1);
        Y = N1.getOperand(0);
      }
      if (Sqrt.getNode()) {
        // If the other multiply operand is known positive, pull it into the
        // sqrt. That will eliminate the division if we convert to an estimate.
        if (Flags.hasAllowReassociation() && N1.hasOneUse() &&
            N1->getFlags().hasAllowReassociation() && Sqrt.hasOneUse()) {
          SDValue A;
          if (Y.getOpcode() == ISD::FABS && Y.hasOneUse())
            A = Y.getOperand(0);
          else if (Y == Sqrt.getOperand(0))
            A = Y;
          if (A) {
            // X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z)
            // X / (A * sqrt(A))       --> X / sqrt(A*A*A) --> X * rsqrt(A*A*A)
            SDValue AA = DAG.getNode(ISD::FMUL, DL, VT, A, A);
            SDValue AAZ =
                DAG.getNode(ISD::FMUL, DL, VT, AA, Sqrt.getOperand(0));
            if (SDValue Rsqrt = buildRsqrtEstimate(AAZ, Flags))
              return DAG.getNode(ISD::FMUL, DL, VT, N0, Rsqrt);

            // Estimate creation failed. Clean up speculatively created nodes.
            recursivelyDeleteUnusedNodes(AAZ.getNode());
          }
        }

        // We found a FSQRT, so try to make this fold:
        // X / (Y * sqrt(Z)) -> X * (rsqrt(Z) / Y)
        if (SDValue Rsqrt = buildRsqrtEstimate(Sqrt.getOperand(0), Flags)) {
          SDValue Div = DAG.getNode(ISD::FDIV, SDLoc(N1), VT, Rsqrt, Y);
          AddToWorklist(Div.getNode());
          return DAG.getNode(ISD::FMUL, DL, VT, N0, Div);
        }
      }
    }

    // Fold into a reciprocal estimate and multiply instead of a real divide.
    if (Options.NoInfsFPMath || Flags.hasNoInfs())
      if (SDValue RV = BuildDivEstimate(N0, N1, Flags))
        return RV;
  }

  // Fold X/Sqrt(X) -> Sqrt(X)
  if ((Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros()) &&
      (Options.UnsafeFPMath || Flags.hasAllowReassociation()))
    if (N1.getOpcode() == ISD::FSQRT && N0 == N1.getOperand(0))
      return N1;

  // (fdiv (fneg X), (fneg Y)) -> (fdiv X, Y)
  TargetLowering::NegatibleCost CostN0 =
      TargetLowering::NegatibleCost::Expensive;
  TargetLowering::NegatibleCost CostN1 =
      TargetLowering::NegatibleCost::Expensive;
  SDValue NegN0 =
      TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
  SDValue NegN1 =
      TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
  if (NegN0 && NegN1 &&
      (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
       CostN1 == TargetLowering::NegatibleCost::Cheaper))
    return DAG.getNode(ISD::FDIV, SDLoc(N), VT, NegN0, NegN1);

  return SDValue();
}

SDValue DAGCombiner::visitFREM(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
  ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
  EVT VT = N->getValueType(0);
  SDNodeFlags Flags = N->getFlags();
  SelectionDAG::FlagInserter FlagsInserter(DAG, N);

  if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
    return R;

  // fold (frem c1, c2) -> fmod(c1,c2)
  if (N0CFP && N1CFP)
    return DAG.getNode(ISD::FREM, SDLoc(N), VT, N0, N1);

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  return SDValue();
}

SDValue DAGCombiner::visitFSQRT(SDNode *N) {
  SDNodeFlags Flags = N->getFlags();
  const TargetOptions &Options = DAG.getTarget().Options;

  // Require 'ninf' flag since sqrt(+Inf) = +Inf, but the estimation goes as:
  // sqrt(+Inf) == rsqrt(+Inf) * +Inf = 0 * +Inf = NaN
  if (!Flags.hasApproximateFuncs() ||
      (!Options.NoInfsFPMath && !Flags.hasNoInfs()))
    return SDValue();

  SDValue N0 = N->getOperand(0);
  if (TLI.isFsqrtCheap(N0, DAG))
    return SDValue();

  // FSQRT nodes have flags that propagate to the created nodes.
  // TODO: If this is N0/sqrt(N0), and we reach this node before trying to
  //       transform the fdiv, we may produce a sub-optimal estimate sequence
  //       because the reciprocal calculation may not have to filter out a
  //       0.0 input.
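  // buildSqrtEstimate returns a target-supplied estimate sequence (typically
  // a reciprocal-sqrt estimate refined by Newton-Raphson iterations), or an
  // empty SDValue if the target provides no estimate for this type.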
  return buildSqrtEstimate(N0, Flags);
}

/// copysign(x, fp_extend(y)) -> copysign(x, y)
/// copysign(x, fp_round(y)) -> copysign(x, y)
static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(SDNode *N) {
  SDValue N1 = N->getOperand(1);
  if ((N1.getOpcode() == ISD::FP_EXTEND ||
       N1.getOpcode() == ISD::FP_ROUND)) {
    EVT N1VT = N1->getValueType(0);
    EVT N1Op0VT = N1->getOperand(0).getValueType();

    // Always fold no-op FP casts.
    if (N1VT == N1Op0VT)
      return true;

    // Do not optimize out type conversion of f128 type yet.
    // For some targets like x86_64, configuration is changed to keep one f128
    // value in one SSE register, but instruction selection cannot handle
    // FCOPYSIGN on SSE registers yet.
    if (N1Op0VT == MVT::f128)
      return false;

    // Avoid mismatched vector operand types, for better instruction selection.
    if (N1Op0VT.isVector())
      return false;

    return true;
  }
  return false;
}

SDValue DAGCombiner::visitFCOPYSIGN(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  bool N0CFP = DAG.isConstantFPBuildVectorOrConstantFP(N0);
  bool N1CFP = DAG.isConstantFPBuildVectorOrConstantFP(N1);
  EVT VT = N->getValueType(0);

  if (N0CFP && N1CFP) // Constant fold
    return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0, N1);

  if (ConstantFPSDNode *N1C = isConstOrConstSplatFP(N->getOperand(1))) {
    const APFloat &V = N1C->getValueAPF();
    // copysign(x, c1) -> fabs(x)       iff ispos(c1)
    // copysign(x, c1) -> fneg(fabs(x)) iff isneg(c1)
    if (!V.isNegative()) {
      if (!LegalOperations || TLI.isOperationLegal(ISD::FABS, VT))
        return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0);
    } else {
      if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))
        return DAG.getNode(ISD::FNEG, SDLoc(N), VT,
                           DAG.getNode(ISD::FABS, SDLoc(N0), VT, N0));
    }
  }

  // copysign(fabs(x), y) -> copysign(x, y)
  // copysign(fneg(x), y) -> copysign(x, y)
  // copysign(copysign(x,z), y) -> copysign(x, y)
  if (N0.getOpcode() == ISD::FABS || N0.getOpcode() == ISD::FNEG ||
      N0.getOpcode() == ISD::FCOPYSIGN)
    return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0.getOperand(0), N1);

  // copysign(x, abs(y)) -> abs(x)
  if (N1.getOpcode() == ISD::FABS)
    return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0);

  // copysign(x, copysign(y,z)) -> copysign(x, z)
  if (N1.getOpcode() == ISD::FCOPYSIGN)
    return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0, N1.getOperand(1));

  // copysign(x, fp_extend(y)) -> copysign(x, y)
  // copysign(x, fp_round(y)) -> copysign(x, y)
  if (CanCombineFCOPYSIGN_EXTEND_ROUND(N))
    return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0, N1.getOperand(0));

  return SDValue();
}

SDValue DAGCombiner::visitFPOW(SDNode *N) {
  ConstantFPSDNode *ExponentC = isConstOrConstSplatFP(N->getOperand(1));
  if (!ExponentC)
    return SDValue();
  SelectionDAG::FlagInserter FlagsInserter(DAG, N);

  // Try to convert x ** (1/3) into cube root.
  // TODO: Handle the various flavors of long double.
  // TODO: Since we're approximating, we don't need an exact 1/3 exponent.
  //       Some range near 1/3 should be fine.
  EVT VT = N->getValueType(0);
  if ((VT == MVT::f32 && ExponentC->getValueAPF().isExactlyValue(1.0f/3.0f)) ||
      (VT == MVT::f64 && ExponentC->getValueAPF().isExactlyValue(1.0/3.0))) {
    // pow(-0.0, 1/3) = +0.0; cbrt(-0.0) = -0.0.
    // pow(-inf, 1/3) = +inf; cbrt(-inf) = -inf.
    // pow(-val, 1/3) =  nan; cbrt(-val) = -cbrt(val).
    // For regular numbers, rounding may cause the results to differ.
    // Therefore, we require { nsz ninf nnan afn } for this transform.
    // TODO: We could select out the special cases if we don't have nsz/ninf.
    SDNodeFlags Flags = N->getFlags();
    if (!Flags.hasNoSignedZeros() || !Flags.hasNoInfs() || !Flags.hasNoNaNs() ||
        !Flags.hasApproximateFuncs())
      return SDValue();

    // Do not create a cbrt() libcall if the target does not have it, and do not
    // turn a pow that has lowering support into a cbrt() libcall.
    if (!DAG.getLibInfo().has(LibFunc_cbrt) ||
        (!DAG.getTargetLoweringInfo().isOperationExpand(ISD::FPOW, VT) &&
         DAG.getTargetLoweringInfo().isOperationExpand(ISD::FCBRT, VT)))
      return SDValue();

    return DAG.getNode(ISD::FCBRT, SDLoc(N), VT, N->getOperand(0));
  }

  // Try to convert x ** (1/4) and x ** (3/4) into square roots.
  // x ** (1/2) is canonicalized to sqrt, so we do not bother with that case.
  // TODO: This could be extended (using a target hook) to handle smaller
  // power-of-2 fractional exponents.
  bool ExponentIs025 = ExponentC->getValueAPF().isExactlyValue(0.25);
  bool ExponentIs075 = ExponentC->getValueAPF().isExactlyValue(0.75);
  if (ExponentIs025 || ExponentIs075) {
    // pow(-0.0, 0.25) = +0.0; sqrt(sqrt(-0.0)) = -0.0.
    // pow(-inf, 0.25) = +inf; sqrt(sqrt(-inf)) =  NaN.
    // pow(-0.0, 0.75) = +0.0; sqrt(-0.0) * sqrt(sqrt(-0.0)) = +0.0.
    // pow(-inf, 0.75) = +inf; sqrt(-inf) * sqrt(sqrt(-inf)) =  NaN.
    // For regular numbers, rounding may cause the results to differ.
    // Therefore, we require { nsz ninf afn } for this transform.
    // TODO: We could select out the special cases if we don't have nsz/ninf.
    SDNodeFlags Flags = N->getFlags();

    // We only need no signed zeros for the 0.25 case.
    if ((!Flags.hasNoSignedZeros() && ExponentIs025) || !Flags.hasNoInfs() ||
        !Flags.hasApproximateFuncs())
      return SDValue();

    // Don't double the number of libcalls. We are trying to inline fast code.
    if (!DAG.getTargetLoweringInfo().isOperationLegalOrCustom(ISD::FSQRT, VT))
      return SDValue();

    // Assume that libcalls are the smallest code.
    // TODO: This restriction should probably be lifted for vectors.
    if (ForCodeSize)
      return SDValue();

    // pow(X, 0.25) --> sqrt(sqrt(X))
    SDLoc DL(N);
    SDValue Sqrt = DAG.getNode(ISD::FSQRT, DL, VT, N->getOperand(0));
    SDValue SqrtSqrt = DAG.getNode(ISD::FSQRT, DL, VT, Sqrt);
    if (ExponentIs025)
      return SqrtSqrt;
    // pow(X, 0.75) --> sqrt(X) * sqrt(sqrt(X))
    return DAG.getNode(ISD::FMUL, DL, VT, Sqrt, SqrtSqrt);
  }

  return SDValue();
}

static SDValue foldFPToIntToFP(SDNode *N, SelectionDAG &DAG,
                               const TargetLowering &TLI) {
  // This optimization is guarded by a function attribute because it may produce
  // unexpected results. I.e., programs may be relying on the platform-specific
  // undefined behavior when the float-to-int conversion overflows.
  const Function &F = DAG.getMachineFunction().getFunction();
  Attribute StrictOverflow = F.getFnAttribute("strict-float-cast-overflow");
  if (StrictOverflow.getValueAsString().equals("false"))
    return SDValue();

  // We only do this if the target has legal ftrunc. Otherwise, we'd likely be
  // replacing casts with a libcall. We also must be allowed to ignore -0.0
  // because FTRUNC will return -0.0 for (-1.0, -0.0), but using integer
  // conversions would return +0.0.
  // FIXME: We should be able to use node-level FMF here.
  // TODO: If strict math, should we use FABS (+ range check for signed cast)?
  EVT VT = N->getValueType(0);
  if (!TLI.isOperationLegal(ISD::FTRUNC, VT) ||
      !DAG.getTarget().Options.NoSignedZerosFPMath)
    return SDValue();

  // fptosi/fptoui round towards zero, so converting from FP to integer and
  // back is the same as an 'ftrunc': [us]itofp (fpto[us]i X) --> ftrunc X
  SDValue N0 = N->getOperand(0);
  if (N->getOpcode() == ISD::SINT_TO_FP && N0.getOpcode() == ISD::FP_TO_SINT &&
      N0.getOperand(0).getValueType() == VT)
    return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0.getOperand(0));

  if (N->getOpcode() == ISD::UINT_TO_FP && N0.getOpcode() == ISD::FP_TO_UINT &&
      N0.getOperand(0).getValueType() == VT)
    return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0.getOperand(0));

  return SDValue();
}

SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  EVT OpVT = N0.getValueType();

  // [us]itofp(undef) = 0, because the result value is bounded.
  if (N0.isUndef())
    return DAG.getConstantFP(0.0, SDLoc(N), VT);

  // fold (sint_to_fp c1) -> c1fp
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      // ...but only if the target supports immediate floating-point values
      (!LegalOperations ||
       TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
    return DAG.getNode(ISD::SINT_TO_FP, SDLoc(N), VT, N0);

  // If the input is a legal type, and SINT_TO_FP is not legal on this target,
  // but UINT_TO_FP is legal on this target, try to convert.
  if (!hasOperation(ISD::SINT_TO_FP, OpVT) &&
      hasOperation(ISD::UINT_TO_FP, OpVT)) {
    // If the sign bit is known to be zero, we can change this to UINT_TO_FP.
    if (DAG.SignBitIsZero(N0))
      return DAG.getNode(ISD::UINT_TO_FP, SDLoc(N), VT, N0);
  }

  // The next optimizations are desirable only if SELECT_CC can be lowered.
  // fold (sint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), -1.0, 0.0)
  if (N0.getOpcode() == ISD::SETCC && N0.getValueType() == MVT::i1 &&
      !VT.isVector() &&
      (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
    SDLoc DL(N);
    return DAG.getSelect(DL, VT, N0, DAG.getConstantFP(-1.0, DL, VT),
                         DAG.getConstantFP(0.0, DL, VT));
  }

  // fold (sint_to_fp (zext (setcc x, y, cc))) ->
  //      (select (setcc x, y, cc), 1.0, 0.0)
  if (N0.getOpcode() == ISD::ZERO_EXTEND &&
      N0.getOperand(0).getOpcode() == ISD::SETCC && !VT.isVector() &&
      (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
    SDLoc DL(N);
    return DAG.getSelect(DL, VT, N0.getOperand(0),
                         DAG.getConstantFP(1.0, DL, VT),
                         DAG.getConstantFP(0.0, DL, VT));
  }

  if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI))
    return FTrunc;

  return SDValue();
}

SDValue DAGCombiner::visitUINT_TO_FP(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  EVT OpVT = N0.getValueType();

  // [us]itofp(undef) = 0, because the result value is bounded.
  if (N0.isUndef())
    return DAG.getConstantFP(0.0, SDLoc(N), VT);

  // fold (uint_to_fp c1) -> c1fp
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      // ...but only if the target supports immediate floating-point values
      (!LegalOperations ||
       TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
    return DAG.getNode(ISD::UINT_TO_FP, SDLoc(N), VT, N0);

  // If the input is a legal type, and UINT_TO_FP is not legal on this target,
  // but SINT_TO_FP is legal on this target, try to convert.
  if (!hasOperation(ISD::UINT_TO_FP, OpVT) &&
      hasOperation(ISD::SINT_TO_FP, OpVT)) {
    // If the sign bit is known to be zero, we can change this to SINT_TO_FP.
    if (DAG.SignBitIsZero(N0))
      return DAG.getNode(ISD::SINT_TO_FP, SDLoc(N), VT, N0);
  }

  // fold (uint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), 1.0, 0.0)
  if (N0.getOpcode() == ISD::SETCC && !VT.isVector() &&
      (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
    SDLoc DL(N);
    return DAG.getSelect(DL, VT, N0, DAG.getConstantFP(1.0, DL, VT),
                         DAG.getConstantFP(0.0, DL, VT));
  }

  if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI))
    return FTrunc;

  return SDValue();
}

// Fold (fp_to_{s/u}int ({s/u}int_to_fp x)) -> zext x, sext x, trunc x, or x
static SDValue FoldIntToFPToInt(SDNode *N, SelectionDAG &DAG) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  if (N0.getOpcode() != ISD::UINT_TO_FP && N0.getOpcode() != ISD::SINT_TO_FP)
    return SDValue();

  SDValue Src = N0.getOperand(0);
  EVT SrcVT = Src.getValueType();
  bool IsInputSigned = N0.getOpcode() == ISD::SINT_TO_FP;
  bool IsOutputSigned = N->getOpcode() == ISD::FP_TO_SINT;

  // We can safely assume the conversion won't overflow the output range,
  // because (for example) (uint8_t)18293.f is undefined behavior.

  // Since we can assume the conversion won't overflow, our decision as to
  // whether the input will fit in the float should depend on the minimum
  // of the input range and output range.

  // This means this is also safe for a signed input and unsigned output, since
  // a negative input would lead to undefined behavior.
  unsigned InputSize = (int)SrcVT.getScalarSizeInBits() - IsInputSigned;
  unsigned OutputSize = (int)VT.getScalarSizeInBits() - IsOutputSigned;
  unsigned ActualSize = std::min(InputSize, OutputSize);
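  // For example, with f32 (24-bit significand): a signed i16 round-trip uses
  // at most 15 value bits, so it can be folded away; a full i32 round-trip
  // would need up to 31 bits and is rejected by the precision check below.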
  const fltSemantics &sem = DAG.EVTToAPFloatSemantics(N0.getValueType());

  // We can only fold away the float conversion if the input range can be
  // represented exactly in the float range.
  if (APFloat::semanticsPrecision(sem) >= ActualSize) {
    if (VT.getScalarSizeInBits() > SrcVT.getScalarSizeInBits()) {
      unsigned ExtOp = IsInputSigned && IsOutputSigned ? ISD::SIGN_EXTEND
                                                       : ISD::ZERO_EXTEND;
      return DAG.getNode(ExtOp, SDLoc(N), VT, Src);
    }
    if (VT.getScalarSizeInBits() < SrcVT.getScalarSizeInBits())
      return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Src);
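    // Same width: the original integer value can be reused directly.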
    return DAG.getBitcast(VT, Src);
  }
  return SDValue();
}

SDValue DAGCombiner::visitFP_TO_SINT(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  // fold (fp_to_sint undef) -> undef
  if (N0.isUndef())
    return DAG.getUNDEF(VT);

  // fold (fp_to_sint c1fp) -> c1
  if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
    return DAG.getNode(ISD::FP_TO_SINT, SDLoc(N), VT, N0);

  return FoldIntToFPToInt(N, DAG);
}

SDValue DAGCombiner::visitFP_TO_UINT(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  // fold (fp_to_uint undef) -> undef
  if (N0.isUndef())
    return DAG.getUNDEF(VT);

  // fold (fp_to_uint c1fp) -> c1
  if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
    return DAG.getNode(ISD::FP_TO_UINT, SDLoc(N), VT, N0);

  return FoldIntToFPToInt(N, DAG);
}

SDValue DAGCombiner::visitFP_ROUND(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
  EVT VT = N->getValueType(0);

  // fold (fp_round c1fp) -> c1fp
  if (N0CFP)
    return DAG.getNode(ISD::FP_ROUND, SDLoc(N), VT, N0, N1);

  // fold (fp_round (fp_extend x)) -> x
  if (N0.getOpcode() == ISD::FP_EXTEND && VT == N0.getOperand(0).getValueType())
    return N0.getOperand(0);

  // fold (fp_round (fp_round x)) -> (fp_round x)
  if (N0.getOpcode() == ISD::FP_ROUND) {
    const bool NIsTrunc = N->getConstantOperandVal(1) == 1;
    const bool N0IsTrunc = N0.getConstantOperandVal(1) == 1;

    // Skip this folding if it results in an fp_round from f80 to f16.
    //
    // f80 to f16 always generates an expensive (and as yet, unimplemented)
    // libcall to __truncxfhf2 instead of selecting native f16 conversion
    // instructions from f32 or f64.  Moreover, the first (value-preserving)
    // fp_round from f80 to either f32 or f64 may become a NOP in platforms like
    // x86.
    if (N0.getOperand(0).getValueType() == MVT::f80 && VT == MVT::f16)
      return SDValue();

    // If the first fp_round isn't a value preserving truncation, it might
    // introduce a tie in the second fp_round, that wouldn't occur in the
    // single-step fp_round we want to fold to.
    // In other words, double rounding isn't the same as rounding.
    // Also, this is a value preserving truncation iff both fp_round's are.
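    // (A first rounding that is not value-preserving can create a new
    // halfway case, so rounding in two steps, e.g. f64 -> f32 -> f16, may
    // produce a different result than a single f64 -> f16 rounding.)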
    if (DAG.getTarget().Options.UnsafeFPMath || N0IsTrunc) {
      SDLoc DL(N);
      return DAG.getNode(ISD::FP_ROUND, DL, VT, N0.getOperand(0),
                         DAG.getIntPtrConstant(NIsTrunc && N0IsTrunc, DL));
    }
  }

  // fold (fp_round (copysign X, Y)) -> (copysign (fp_round X), Y)
  if (N0.getOpcode() == ISD::FCOPYSIGN && N0.getNode()->hasOneUse()) {
    SDValue Tmp = DAG.getNode(ISD::FP_ROUND, SDLoc(N0), VT,
                              N0.getOperand(0), N1);
    AddToWorklist(Tmp.getNode());
    return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT,
                       Tmp, N0.getOperand(1));
  }

  if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
    return NewVSel;

  return SDValue();
}

SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  // If this is fp_round(fpextend), don't fold it, allow ourselves to be folded.
  if (N->hasOneUse() &&
      N->use_begin()->getOpcode() == ISD::FP_ROUND)
    return SDValue();

  // fold (fp_extend c1fp) -> c1fp
  if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
    return DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, N0);

  // fold (fp_extend (fp16_to_fp op)) -> (fp16_to_fp op)
  if (N0.getOpcode() == ISD::FP16_TO_FP &&
      TLI.getOperationAction(ISD::FP16_TO_FP, VT) == TargetLowering::Legal)
    return DAG.getNode(ISD::FP16_TO_FP, SDLoc(N), VT, N0.getOperand(0));

  // Turn fp_extend(fp_round(X, 1)) -> X since the fp_round doesn't affect the
  // value of X.
  if (N0.getOpcode() == ISD::FP_ROUND
      && N0.getConstantOperandVal(1) == 1) {
    SDValue In = N0.getOperand(0);
    if (In.getValueType() == VT) return In;
    if (VT.bitsLT(In.getValueType()))
      return DAG.getNode(ISD::FP_ROUND, SDLoc(N), VT,
                         In, N0.getOperand(1));
    return DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, In);
  }

  // fold (fpext (load x)) -> (fpext (fptrunc (extload x)))
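  // This rewrites both nodes at once: the fpext becomes the extending load,
  // while remaining users of the original load are given an fp_round of the
  // extended value (with the value-preserving trunc flag set) and are
  // redirected to the extload's chain.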
14872   if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
14873        TLI.isLoadExtLegal(ISD::EXTLOAD, VT, N0.getValueType())) {
14874     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
14875     SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, SDLoc(N), VT,
14876                                      LN0->getChain(),
14877                                      LN0->getBasePtr(), N0.getValueType(),
14878                                      LN0->getMemOperand());
14879     CombineTo(N, ExtLoad);
14880     CombineTo(N0.getNode(),
14881               DAG.getNode(ISD::FP_ROUND, SDLoc(N0),
14882                           N0.getValueType(), ExtLoad,
14883                           DAG.getIntPtrConstant(1, SDLoc(N0))),
14884               ExtLoad.getValue(1));
14885     return SDValue(N, 0);   // Return N so it doesn't get rechecked!
14886   }
14887 
14888   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
14889     return NewVSel;
14890 
14891   return SDValue();
14892 }
14893 
visitFCEIL(SDNode * N)14894 SDValue DAGCombiner::visitFCEIL(SDNode *N) {
14895   SDValue N0 = N->getOperand(0);
14896   EVT VT = N->getValueType(0);
14897 
14898   // fold (fceil c1) -> fceil(c1)
14899   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
14900     return DAG.getNode(ISD::FCEIL, SDLoc(N), VT, N0);
14901 
14902   return SDValue();
14903 }
14904 
visitFTRUNC(SDNode * N)14905 SDValue DAGCombiner::visitFTRUNC(SDNode *N) {
14906   SDValue N0 = N->getOperand(0);
14907   EVT VT = N->getValueType(0);
14908 
14909   // fold (ftrunc c1) -> ftrunc(c1)
14910   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
14911     return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0);
14912 
14913   // fold ftrunc (known rounded int x) -> x
14914   // ftrunc is a part of fptosi/fptoui expansion on some targets, so this is
14915   // likely to be generated to extract integer from a rounded floating value.
14916   switch (N0.getOpcode()) {
14917   default: break;
14918   case ISD::FRINT:
14919   case ISD::FTRUNC:
14920   case ISD::FNEARBYINT:
14921   case ISD::FFLOOR:
14922   case ISD::FCEIL:
14923     return N0;
14924   }
14925 
14926   return SDValue();
14927 }
14928 
visitFFLOOR(SDNode * N)14929 SDValue DAGCombiner::visitFFLOOR(SDNode *N) {
14930   SDValue N0 = N->getOperand(0);
14931   EVT VT = N->getValueType(0);
14932 
14933   // fold (ffloor c1) -> ffloor(c1)
14934   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
14935     return DAG.getNode(ISD::FFLOOR, SDLoc(N), VT, N0);
14936 
14937   return SDValue();
14938 }
14939 
visitFNEG(SDNode * N)14940 SDValue DAGCombiner::visitFNEG(SDNode *N) {
14941   SDValue N0 = N->getOperand(0);
14942   EVT VT = N->getValueType(0);
14943   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
14944 
14945   // Constant fold FNEG.
14946   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
14947     return DAG.getNode(ISD::FNEG, SDLoc(N), VT, N0);
14948 
14949   if (SDValue NegN0 =
14950           TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize))
14951     return NegN0;
14952 
14953   // -(X-Y) -> (Y-X) is unsafe because when X==Y, -0.0 != +0.0
14954   // FIXME: This is duplicated in getNegatibleCost, but getNegatibleCost doesn't
14955   // know it was called from a context with a nsz flag if the input fsub does
14956   // not.
14957   if (N0.getOpcode() == ISD::FSUB &&
14958       (DAG.getTarget().Options.NoSignedZerosFPMath ||
14959        N->getFlags().hasNoSignedZeros()) && N0.hasOneUse()) {
14960     return DAG.getNode(ISD::FSUB, SDLoc(N), VT, N0.getOperand(1),
14961                        N0.getOperand(0));
14962   }
14963 
14964   if (SDValue Cast = foldSignChangeInBitcast(N))
14965     return Cast;
14966 
14967   return SDValue();
14968 }
14969 
visitFMinMax(SelectionDAG & DAG,SDNode * N,APFloat (* Op)(const APFloat &,const APFloat &))14970 static SDValue visitFMinMax(SelectionDAG &DAG, SDNode *N,
14971                             APFloat (*Op)(const APFloat &, const APFloat &)) {
14972   SDValue N0 = N->getOperand(0);
14973   SDValue N1 = N->getOperand(1);
14974   EVT VT = N->getValueType(0);
14975   const ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0);
14976   const ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1);
14977   const SDNodeFlags Flags = N->getFlags();
14978   unsigned Opc = N->getOpcode();
14979   bool PropagatesNaN = Opc == ISD::FMINIMUM || Opc == ISD::FMAXIMUM;
14980   bool IsMin = Opc == ISD::FMINNUM || Opc == ISD::FMINIMUM;
14981   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
14982 
14983   if (N0CFP && N1CFP) {
14984     const APFloat &C0 = N0CFP->getValueAPF();
14985     const APFloat &C1 = N1CFP->getValueAPF();
14986     return DAG.getConstantFP(Op(C0, C1), SDLoc(N), VT);
14987   }
14988 
14989   // Canonicalize to constant on RHS.
14990   if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
14991       !DAG.isConstantFPBuildVectorOrConstantFP(N1))
14992     return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0);
14993 
14994   if (N1CFP) {
14995     const APFloat &AF = N1CFP->getValueAPF();
14996 
14997     // minnum(X, nan) -> X
14998     // maxnum(X, nan) -> X
14999     // minimum(X, nan) -> nan
15000     // maximum(X, nan) -> nan
15001     if (AF.isNaN())
15002       return PropagatesNaN ? N->getOperand(1) : N->getOperand(0);
15003 
15004     // In the following folds, inf can be replaced with the largest finite
15005     // float, if the ninf flag is set.
15006     if (AF.isInfinity() || (Flags.hasNoInfs() && AF.isLargest())) {
15007       // minnum(X, -inf) -> -inf
15008       // maxnum(X, +inf) -> +inf
15009       // minimum(X, -inf) -> -inf if nnan
15010       // maximum(X, +inf) -> +inf if nnan
15011       if (IsMin == AF.isNegative() && (!PropagatesNaN || Flags.hasNoNaNs()))
15012         return N->getOperand(1);
15013 
15014       // minnum(X, +inf) -> X if nnan
15015       // maxnum(X, -inf) -> X if nnan
15016       // minimum(X, +inf) -> X
15017       // maximum(X, -inf) -> X
15018       if (IsMin != AF.isNegative() && (PropagatesNaN || Flags.hasNoNaNs()))
15019         return N->getOperand(0);
15020     }
15021   }
15022 
15023   return SDValue();
15024 }
15025 
visitFMINNUM(SDNode * N)15026 SDValue DAGCombiner::visitFMINNUM(SDNode *N) {
15027   return visitFMinMax(DAG, N, minnum);
15028 }
15029 
visitFMAXNUM(SDNode * N)15030 SDValue DAGCombiner::visitFMAXNUM(SDNode *N) {
15031   return visitFMinMax(DAG, N, maxnum);
15032 }
15033 
visitFMINIMUM(SDNode * N)15034 SDValue DAGCombiner::visitFMINIMUM(SDNode *N) {
15035   return visitFMinMax(DAG, N, minimum);
15036 }
15037 
visitFMAXIMUM(SDNode * N)15038 SDValue DAGCombiner::visitFMAXIMUM(SDNode *N) {
15039   return visitFMinMax(DAG, N, maximum);
15040 }
15041 
visitFABS(SDNode * N)15042 SDValue DAGCombiner::visitFABS(SDNode *N) {
15043   SDValue N0 = N->getOperand(0);
15044   EVT VT = N->getValueType(0);
15045 
15046   // fold (fabs c1) -> fabs(c1)
15047   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
15048     return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0);
15049 
15050   // fold (fabs (fabs x)) -> (fabs x)
15051   if (N0.getOpcode() == ISD::FABS)
15052     return N->getOperand(0);
15053 
15054   // fold (fabs (fneg x)) -> (fabs x)
15055   // fold (fabs (fcopysign x, y)) -> (fabs x)
15056   if (N0.getOpcode() == ISD::FNEG || N0.getOpcode() == ISD::FCOPYSIGN)
15057     return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0.getOperand(0));
15058 
15059   if (SDValue Cast = foldSignChangeInBitcast(N))
15060     return Cast;
15061 
15062   return SDValue();
15063 }
15064 
visitBRCOND(SDNode * N)15065 SDValue DAGCombiner::visitBRCOND(SDNode *N) {
15066   SDValue Chain = N->getOperand(0);
15067   SDValue N1 = N->getOperand(1);
15068   SDValue N2 = N->getOperand(2);
15069 
15070   // BRCOND(FREEZE(cond)) is equivalent to BRCOND(cond) (both are
15071   // nondeterministic jumps).
15072   if (N1->getOpcode() == ISD::FREEZE && N1.hasOneUse()) {
15073     return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other, Chain,
15074                        N1->getOperand(0), N2);
15075   }
15076 
15077   // If N is a constant we could fold this into a fallthrough or unconditional
15078   // branch. However that doesn't happen very often in normal code, because
15079   // Instcombine/SimplifyCFG should have handled the available opportunities.
15080   // If we did this folding here, it would be necessary to update the
15081   // MachineBasicBlock CFG, which is awkward.
15082 
15083   // fold a brcond with a setcc condition into a BR_CC node if BR_CC is legal
15084   // on the target.
15085   if (N1.getOpcode() == ISD::SETCC &&
15086       TLI.isOperationLegalOrCustom(ISD::BR_CC,
15087                                    N1.getOperand(0).getValueType())) {
15088     return DAG.getNode(ISD::BR_CC, SDLoc(N), MVT::Other,
15089                        Chain, N1.getOperand(2),
15090                        N1.getOperand(0), N1.getOperand(1), N2);
15091   }
15092 
15093   if (N1.hasOneUse()) {
15094     // rebuildSetCC calls visitXor which may change the Chain when there is a
15095     // STRICT_FSETCC/STRICT_FSETCCS involved. Use a handle to track changes.
15096     HandleSDNode ChainHandle(Chain);
15097     if (SDValue NewN1 = rebuildSetCC(N1))
15098       return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other,
15099                          ChainHandle.getValue(), NewN1, N2);
15100   }
15101 
15102   return SDValue();
15103 }
15104 
SDValue DAGCombiner::rebuildSetCC(SDValue N) {
  if (N.getOpcode() == ISD::SRL ||
      (N.getOpcode() == ISD::TRUNCATE &&
       (N.getOperand(0).hasOneUse() &&
        N.getOperand(0).getOpcode() == ISD::SRL))) {
    // Look past the truncate.
    if (N.getOpcode() == ISD::TRUNCATE)
      N = N.getOperand(0);

    // Match this pattern so that we can generate simpler code:
    //
    //   %a = ...
    //   %b = and i32 %a, 2
    //   %c = srl i32 %b, 1
    //   brcond i32 %c ...
    //
    // into
    //
    //   %a = ...
    //   %b = and i32 %a, 2
    //   %c = setcc eq %b, 0
    //   brcond %c ...
    //
    // This applies only when the AND constant value has one bit set and the
    // SRL constant is equal to the log2 of the AND constant. The back-end is
    // smart enough to convert the result into a TEST/JMP sequence.
    SDValue Op0 = N.getOperand(0);
    SDValue Op1 = N.getOperand(1);

    if (Op0.getOpcode() == ISD::AND && Op1.getOpcode() == ISD::Constant) {
      SDValue AndOp1 = Op0.getOperand(1);

      if (AndOp1.getOpcode() == ISD::Constant) {
        const APInt &AndConst = cast<ConstantSDNode>(AndOp1)->getAPIntValue();

        if (AndConst.isPowerOf2() &&
            cast<ConstantSDNode>(Op1)->getAPIntValue() == AndConst.logBase2()) {
          SDLoc DL(N);
          return DAG.getSetCC(DL, getSetCCResultType(Op0.getValueType()),
                              Op0, DAG.getConstant(0, DL, Op0.getValueType()),
                              ISD::SETNE);
        }
      }
    }
  }

  // Transform (brcond (xor x, y)) -> (brcond (setcc, x, y, ne))
  // Transform (brcond (xor (xor x, y), -1)) -> (brcond (setcc, x, y, eq))
  if (N.getOpcode() == ISD::XOR) {
    // Because we may call this on a speculatively constructed
    // SimplifiedSetCC Node, we need to simplify this node first.
    // Ideally this should be folded into SimplifySetCC and not
    // here. For now, grab a handle to N so we don't lose it from
    // replacements internal to the visit.
    HandleSDNode XORHandle(N);
    while (N.getOpcode() == ISD::XOR) {
      SDValue Tmp = visitXOR(N.getNode());
      // No simplification done.
      if (!Tmp.getNode())
        break;
      // Returning N is a form of in-visit replacement that may have
      // invalidated N. Grab the value from the handle.
      if (Tmp.getNode() == N.getNode())
        N = XORHandle.getValue();
      else // Node simplified. Try simplifying again.
        N = Tmp;
    }

    if (N.getOpcode() != ISD::XOR)
      return N;

    SDValue Op0 = N->getOperand(0);
    SDValue Op1 = N->getOperand(1);

    if (Op0.getOpcode() != ISD::SETCC && Op1.getOpcode() != ISD::SETCC) {
      bool Equal = false;
      // (brcond (xor (xor x, y), -1)) -> (brcond (setcc x, y, eq))
      if (isBitwiseNot(N) && Op0.hasOneUse() && Op0.getOpcode() == ISD::XOR &&
          Op0.getValueType() == MVT::i1) {
        N = Op0;
        Op0 = N->getOperand(0);
        Op1 = N->getOperand(1);
        Equal = true;
      }

      EVT SetCCVT = N.getValueType();
      if (LegalTypes)
        SetCCVT = getSetCCResultType(SetCCVT);
      // Replace the uses of XOR with SETCC.
      return DAG.getSetCC(SDLoc(N), SetCCVT, Op0, Op1,
                          Equal ? ISD::SETEQ : ISD::SETNE);
    }
  }

  return SDValue();
}

// Operand List for BR_CC: Chain, CondCC, CondLHS, CondRHS, DestBB.
//
SDValue DAGCombiner::visitBR_CC(SDNode *N) {
  CondCodeSDNode *CC = cast<CondCodeSDNode>(N->getOperand(1));
  SDValue CondLHS = N->getOperand(2), CondRHS = N->getOperand(3);

  // If N is a constant we could fold this into a fallthrough or unconditional
  // branch. However that doesn't happen very often in normal code, because
  // Instcombine/SimplifyCFG should have handled the available opportunities.
  // If we did this folding here, it would be necessary to update the
  // MachineBasicBlock CFG, which is awkward.

  // Use SimplifySetCC to simplify SETCC's.
  SDValue Simp = SimplifySetCC(getSetCCResultType(CondLHS.getValueType()),
                               CondLHS, CondRHS, CC->get(), SDLoc(N),
                               false);
  if (Simp.getNode()) AddToWorklist(Simp.getNode());

  // fold to a simpler setcc
  if (Simp.getNode() && Simp.getOpcode() == ISD::SETCC)
    return DAG.getNode(ISD::BR_CC, SDLoc(N), MVT::Other,
                       N->getOperand(0), Simp.getOperand(2),
                       Simp.getOperand(0), Simp.getOperand(1),
                       N->getOperand(4));

  return SDValue();
}

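/// Extract the base pointer from a (masked) load or store, provided the node
/// is unindexed and the target supports the requested indexed addressing mode
/// (increment or decrement) for its memory VT. IsLoad and IsMasked are
/// updated to reflect the kind of node that was matched.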
static bool getCombineLoadStoreParts(SDNode *N, unsigned Inc, unsigned Dec,
                                     bool &IsLoad, bool &IsMasked, SDValue &Ptr,
                                     const TargetLowering &TLI) {
  if (LoadSDNode *LD = dyn_cast<LoadSDNode>(N)) {
    if (LD->isIndexed())
      return false;
    EVT VT = LD->getMemoryVT();
    if (!TLI.isIndexedLoadLegal(Inc, VT) && !TLI.isIndexedLoadLegal(Dec, VT))
      return false;
    Ptr = LD->getBasePtr();
  } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(N)) {
    if (ST->isIndexed())
      return false;
    EVT VT = ST->getMemoryVT();
    if (!TLI.isIndexedStoreLegal(Inc, VT) && !TLI.isIndexedStoreLegal(Dec, VT))
      return false;
    Ptr = ST->getBasePtr();
    IsLoad = false;
  } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(N)) {
    if (LD->isIndexed())
      return false;
    EVT VT = LD->getMemoryVT();
    if (!TLI.isIndexedMaskedLoadLegal(Inc, VT) &&
        !TLI.isIndexedMaskedLoadLegal(Dec, VT))
      return false;
    Ptr = LD->getBasePtr();
    IsMasked = true;
  } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(N)) {
    if (ST->isIndexed())
      return false;
    EVT VT = ST->getMemoryVT();
    if (!TLI.isIndexedMaskedStoreLegal(Inc, VT) &&
        !TLI.isIndexedMaskedStoreLegal(Dec, VT))
      return false;
    Ptr = ST->getBasePtr();
    IsLoad = false;
    IsMasked = true;
  } else {
    return false;
  }
  return true;
}

/// Try turning a load/store into a pre-indexed load/store when the base
/// pointer is an add or subtract and it has other uses besides the load/store.
/// After the transformation, the new indexed load/store has effectively folded
/// the add/subtract in and all of its other uses are redirected to the
/// new load/store.
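///
/// For example (a sketch, assuming a target with pre-increment addressing):
///   t0 = add x, 8
///   v  = load t0
/// becomes a single pre-indexed load that produces both v and the updated
/// pointer value x + 8; the remaining uses of t0 are redirected to that
/// updated pointer.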
bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) {
  if (Level < AfterLegalizeDAG)
    return false;

  bool IsLoad = true;
  bool IsMasked = false;
  SDValue Ptr;
  if (!getCombineLoadStoreParts(N, ISD::PRE_INC, ISD::PRE_DEC, IsLoad, IsMasked,
                                Ptr, TLI))
    return false;

  // If the pointer is not an add/sub, or if it doesn't have multiple uses, bail
  // out.  There is no reason to make this a preinc/predec.
  if ((Ptr.getOpcode() != ISD::ADD && Ptr.getOpcode() != ISD::SUB) ||
      Ptr.getNode()->hasOneUse())
    return false;

  // Ask the target to do addressing mode selection.
  SDValue BasePtr;
  SDValue Offset;
  ISD::MemIndexedMode AM = ISD::UNINDEXED;
  if (!TLI.getPreIndexedAddressParts(N, BasePtr, Offset, AM, DAG))
    return false;

  // Backends without true r+i pre-indexed forms may need to pass a
  // constant base with a variable offset so that constant coercion
  // will work with the patterns in canonical form.
  bool Swapped = false;
  if (isa<ConstantSDNode>(BasePtr)) {
    std::swap(BasePtr, Offset);
    Swapped = true;
  }

  // Don't create an indexed load / store with zero offset.
  if (isNullConstant(Offset))
    return false;

  // Try turning it into a pre-indexed load / store except when:
  // 1) The new base ptr is a frame index.
  // 2) If N is a store and the new base ptr is either the same as or is a
  //    predecessor of the value being stored.
  // 3) Another use of old base ptr is a predecessor of N. If ptr is folded
  //    that would create a cycle.
  // 4) All uses are load / store ops that use it as old base ptr.

  // Check #1.  Preinc'ing a frame index would require copying the stack pointer
  // (plus the implicit offset) to a register to preinc anyway.
  if (isa<FrameIndexSDNode>(BasePtr) || isa<RegisterSDNode>(BasePtr))
    return false;

  // Check #2.
  if (!IsLoad) {
    SDValue Val = IsMasked ? cast<MaskedStoreSDNode>(N)->getValue()
                           : cast<StoreSDNode>(N)->getValue();

    // Would require a copy.
    if (Val == BasePtr)
      return false;

    // Would create a cycle.
    if (Val == Ptr || Ptr->isPredecessorOf(Val.getNode()))
      return false;
  }

  // Caches for hasPredecessorHelper.
  SmallPtrSet<const SDNode *, 32> Visited;
  SmallVector<const SDNode *, 16> Worklist;
  Worklist.push_back(N);

  // If the offset is a constant, there may be other adds of constants that
  // can be folded with this one. We should do this to avoid having to keep
  // a copy of the original base pointer.
  SmallVector<SDNode *, 16> OtherUses;
  if (isa<ConstantSDNode>(Offset))
    for (SDNode::use_iterator UI = BasePtr.getNode()->use_begin(),
                              UE = BasePtr.getNode()->use_end();
         UI != UE; ++UI) {
      SDUse &Use = UI.getUse();
      // Skip the use that is Ptr and uses of other results from BasePtr's
      // node (important for nodes that return multiple results).
      if (Use.getUser() == Ptr.getNode() || Use != BasePtr)
        continue;

      if (SDNode::hasPredecessorHelper(Use.getUser(), Visited, Worklist))
        continue;

      if (Use.getUser()->getOpcode() != ISD::ADD &&
          Use.getUser()->getOpcode() != ISD::SUB) {
        OtherUses.clear();
        break;
      }

      SDValue Op1 = Use.getUser()->getOperand((UI.getOperandNo() + 1) & 1);
      if (!isa<ConstantSDNode>(Op1)) {
        OtherUses.clear();
        break;
      }

      // FIXME: In some cases, we can be smarter about this.
      if (Op1.getValueType() != Offset.getValueType()) {
        OtherUses.clear();
        break;
      }

      OtherUses.push_back(Use.getUser());
    }

  if (Swapped)
    std::swap(BasePtr, Offset);

  // Now check for #3 and #4.
  bool RealUse = false;

  for (SDNode *Use : Ptr.getNode()->uses()) {
    if (Use == N)
      continue;
    if (SDNode::hasPredecessorHelper(Use, Visited, Worklist))
      return false;

    // If Ptr may be folded into the addressing mode of another use, then it's
    // not profitable to do this transformation.
    if (!canFoldInAddressingMode(Ptr.getNode(), Use, DAG, TLI))
      RealUse = true;
  }

  if (!RealUse)
    return false;

  SDValue Result;
  if (!IsMasked) {
    if (IsLoad)
      Result = DAG.getIndexedLoad(SDValue(N, 0), SDLoc(N), BasePtr, Offset, AM);
    else
      Result =
          DAG.getIndexedStore(SDValue(N, 0), SDLoc(N), BasePtr, Offset, AM);
  } else {
    if (IsLoad)
      Result = DAG.getIndexedMaskedLoad(SDValue(N, 0), SDLoc(N), BasePtr,
                                        Offset, AM);
    else
      Result = DAG.getIndexedMaskedStore(SDValue(N, 0), SDLoc(N), BasePtr,
                                         Offset, AM);
  }
  ++PreIndexedNodes;
  ++NodesCombined;
  LLVM_DEBUG(dbgs() << "\nReplacing.4 "; N->dump(&DAG); dbgs() << "\nWith: ";
             Result.getNode()->dump(&DAG); dbgs() << '\n');
  WorklistRemover DeadNodes(*this);
  if (IsLoad) {
    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0));
    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Result.getValue(2));
  } else {
    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(1));
  }

  // Finally, since the node is now dead, remove it from the graph.
  deleteAndRecombine(N);

  if (Swapped)
    std::swap(BasePtr, Offset);

  // Replace other uses of BasePtr that can be updated to use Ptr.
  for (unsigned i = 0, e = OtherUses.size(); i != e; ++i) {
    unsigned OffsetIdx = 1;
    if (OtherUses[i]->getOperand(OffsetIdx).getNode() == BasePtr.getNode())
      OffsetIdx = 0;
    assert(OtherUses[i]->getOperand(!OffsetIdx).getNode() ==
           BasePtr.getNode() && "Expected BasePtr operand");

    // We need to replace ptr0 in the following expression:
    //   x0 * offset0 + y0 * ptr0 = t0
    // knowing that
    //   x1 * offset1 + y1 * ptr0 = t1 (the indexed load/store)
    //
    // where x0, x1, y0 and y1 in {-1, 1} are given by the types of the
    // indexed load/store and the expression that needs to be re-written.
    //
    // Therefore, we have:
    //   t0 = (x0 * offset0 - x1 * y0 * y1 * offset1) + (y0 * y1) * t1
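    //
    // E.g. in the common all-positive case (x0 = x1 = y0 = y1 = 1, i.e. both
    // expressions are plain adds), this reduces to:
    //   t0 = (offset0 - offset1) + t1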

    auto *CN = cast<ConstantSDNode>(OtherUses[i]->getOperand(OffsetIdx));
    const APInt &Offset0 = CN->getAPIntValue();
    const APInt &Offset1 = cast<ConstantSDNode>(Offset)->getAPIntValue();
    int X0 = (OtherUses[i]->getOpcode() == ISD::SUB && OffsetIdx == 1) ? -1 : 1;
    int Y0 = (OtherUses[i]->getOpcode() == ISD::SUB && OffsetIdx == 0) ? -1 : 1;
    int X1 = (AM == ISD::PRE_DEC && !Swapped) ? -1 : 1;
    int Y1 = (AM == ISD::PRE_DEC && Swapped) ? -1 : 1;

    unsigned Opcode = (Y0 * Y1 < 0) ? ISD::SUB : ISD::ADD;

    APInt CNV = Offset0;
    if (X0 < 0) CNV = -CNV;
    if (X1 * Y0 * Y1 < 0) CNV = CNV + Offset1;
    else CNV = CNV - Offset1;

    SDLoc DL(OtherUses[i]);

    // We can now generate the new expression.
    SDValue NewOp1 = DAG.getConstant(CNV, DL, CN->getValueType(0));
    SDValue NewOp2 = Result.getValue(IsLoad ? 1 : 0);

    SDValue NewUse = DAG.getNode(Opcode,
                                 DL,
                                 OtherUses[i]->getValueType(0), NewOp1, NewOp2);
    DAG.ReplaceAllUsesOfValueWith(SDValue(OtherUses[i], 0), NewUse);
    deleteAndRecombine(OtherUses[i]);
  }

  // Replace the uses of Ptr with uses of the updated base value.
  DAG.ReplaceAllUsesOfValueWith(Ptr, Result.getValue(IsLoad ? 1 : 0));
  deleteAndRecombine(Ptr.getNode());
  AddToWorklist(Result.getNode());

  return true;
}

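/// Check whether \p PtrUse, an add/sub user of the pointer produced by \p N,
/// can be folded into \p N as a post-indexed addressing update: the target
/// must supply post-indexed address parts for it, and no other user of the
/// base pointer should be better placed to absorb the update.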
static bool shouldCombineToPostInc(SDNode *N, SDValue Ptr, SDNode *PtrUse,
                                   SDValue &BasePtr, SDValue &Offset,
                                   ISD::MemIndexedMode &AM,
                                   SelectionDAG &DAG,
                                   const TargetLowering &TLI) {
  if (PtrUse == N ||
      (PtrUse->getOpcode() != ISD::ADD && PtrUse->getOpcode() != ISD::SUB))
    return false;

  if (!TLI.getPostIndexedAddressParts(N, PtrUse, BasePtr, Offset, AM, DAG))
    return false;

  // Don't create an indexed load / store with zero offset.
  if (isNullConstant(Offset))
    return false;

  if (isa<FrameIndexSDNode>(BasePtr) || isa<RegisterSDNode>(BasePtr))
    return false;

  SmallPtrSet<const SDNode *, 32> Visited;
  for (SDNode *Use : BasePtr.getNode()->uses()) {
    if (Use == Ptr.getNode())
      continue;

    // Bail out if there's a later user which could perform the index instead.
    if (isa<MemSDNode>(Use)) {
      bool IsLoad = true;
      bool IsMasked = false;
      SDValue OtherPtr;
      if (getCombineLoadStoreParts(Use, ISD::POST_INC, ISD::POST_DEC, IsLoad,
                                   IsMasked, OtherPtr, TLI)) {
        SmallVector<const SDNode *, 2> Worklist;
        Worklist.push_back(Use);
        if (SDNode::hasPredecessorHelper(N, Visited, Worklist))
          return false;
      }
    }

    // If all the uses are load / store addresses, then don't do the
    // transformation.
    if (Use->getOpcode() == ISD::ADD || Use->getOpcode() == ISD::SUB) {
      for (SDNode *UseUse : Use->uses())
        if (canFoldInAddressingMode(Use, UseUse, DAG, TLI))
          return false;
    }
  }
  return true;
}

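/// Find an add/sub user of \p N's base pointer that can be folded into \p N
/// as a post-indexed update. Returns the matching user node, or nullptr if
/// none qualifies.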
static SDNode *getPostIndexedLoadStoreOp(SDNode *N, bool &IsLoad,
                                         bool &IsMasked, SDValue &Ptr,
                                         SDValue &BasePtr, SDValue &Offset,
                                         ISD::MemIndexedMode &AM,
                                         SelectionDAG &DAG,
                                         const TargetLowering &TLI) {
  if (!getCombineLoadStoreParts(N, ISD::POST_INC, ISD::POST_DEC, IsLoad,
                                IsMasked, Ptr, TLI) ||
      Ptr.getNode()->hasOneUse())
    return nullptr;

  // Try turning it into a post-indexed load / store except when
  // 1) All uses are load / store ops that use it as base ptr (and
  //    it may be folded as addressing mode).
  // 2) Op must be independent of N, i.e. Op is neither a predecessor
  //    nor a successor of N. Otherwise, if Op is folded that would
  //    create a cycle.
  for (SDNode *Op : Ptr->uses()) {
    // Check for #1.
    if (!shouldCombineToPostInc(N, Ptr, Op, BasePtr, Offset, AM, DAG, TLI))
      continue;

    // Check for #2.
    SmallPtrSet<const SDNode *, 32> Visited;
    SmallVector<const SDNode *, 8> Worklist;
    // Ptr is predecessor to both N and Op.
    Visited.insert(Ptr.getNode());
    Worklist.push_back(N);
    Worklist.push_back(Op);
    if (!SDNode::hasPredecessorHelper(N, Visited, Worklist) &&
        !SDNode::hasPredecessorHelper(Op, Visited, Worklist))
      return Op;
  }
  return nullptr;
}

/// Try to combine a load/store with an add/sub of the base pointer node into a
/// post-indexed load/store. The transformation effectively folds the
/// add/subtract into the new indexed load/store, and all of its uses are
/// redirected to the new load/store.
bool DAGCombiner::CombineToPostIndexedLoadStore(SDNode *N) {
  if (Level < AfterLegalizeDAG)
    return false;

  bool IsLoad = true;
  bool IsMasked = false;
  SDValue Ptr;
  SDValue BasePtr;
  SDValue Offset;
  ISD::MemIndexedMode AM = ISD::UNINDEXED;
  SDNode *Op = getPostIndexedLoadStoreOp(N, IsLoad, IsMasked, Ptr, BasePtr,
                                         Offset, AM, DAG, TLI);
  if (!Op)
    return false;

  SDValue Result;
  if (!IsMasked)
    Result = IsLoad ? DAG.getIndexedLoad(SDValue(N, 0), SDLoc(N), BasePtr,
                                         Offset, AM)
                    : DAG.getIndexedStore(SDValue(N, 0), SDLoc(N),
                                          BasePtr, Offset, AM);
  else
    Result = IsLoad ? DAG.getIndexedMaskedLoad(SDValue(N, 0), SDLoc(N),
                                               BasePtr, Offset, AM)
                    : DAG.getIndexedMaskedStore(SDValue(N, 0), SDLoc(N),
                                                BasePtr, Offset, AM);
  ++PostIndexedNodes;
  ++NodesCombined;
  LLVM_DEBUG(dbgs() << "\nReplacing.5 "; N->dump(&DAG);
             dbgs() << "\nWith: "; Result.getNode()->dump(&DAG);
             dbgs() << '\n');
  WorklistRemover DeadNodes(*this);
  if (IsLoad) {
    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0));
    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Result.getValue(2));
  } else {
    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(1));
  }

  // Finally, since the node is now dead, remove it from the graph.
  deleteAndRecombine(N);

  // Replace the uses of Op with uses of the updated base value.
  DAG.ReplaceAllUsesOfValueWith(SDValue(Op, 0),
                                Result.getValue(IsLoad ? 1 : 0));
  deleteAndRecombine(Op);
  return true;
}

/// Return the base-pointer arithmetic from an indexed \p LD.
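/// E.g. for a pre/post-increment load this rebuilds BasePtr + Offset as an
/// ISD::ADD node (and BasePtr - Offset as an ISD::SUB for the decrement
/// forms), converting a TargetConstant offset to a regular constant first.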
SDValue DAGCombiner::SplitIndexingFromLoad(LoadSDNode *LD) {
  ISD::MemIndexedMode AM = LD->getAddressingMode();
  assert(AM != ISD::UNINDEXED);
  SDValue BP = LD->getOperand(1);
  SDValue Inc = LD->getOperand(2);

  // Some backends use TargetConstants for load offsets, but don't expect
  // TargetConstants in general ADD nodes. We can convert these constants into
  // regular Constants (if the constant is not opaque).
  assert((Inc.getOpcode() != ISD::TargetConstant ||
          !cast<ConstantSDNode>(Inc)->isOpaque()) &&
         "Cannot split out indexing using opaque target constants");
  if (Inc.getOpcode() == ISD::TargetConstant) {
    ConstantSDNode *ConstInc = cast<ConstantSDNode>(Inc);
    Inc = DAG.getConstant(*ConstInc->getConstantIntValue(), SDLoc(Inc),
                          ConstInc->getValueType(0));
  }

  unsigned Opc =
      (AM == ISD::PRE_INC || AM == ISD::POST_INC ? ISD::ADD : ISD::SUB);
  return DAG.getNode(Opc, SDLoc(LD), BP.getSimpleValueType(), BP, Inc);
}

static inline ElementCount numVectorEltsOrZero(EVT T) {
  return T.isVector() ? T.getVectorElementCount() : ElementCount::getFixed(0);
}

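/// Try to express \p ST's stored value in the store's memory type, using
/// FTRUNC for floating point, TRUNCATE for integers, or a plain bitcast for
/// same-sized types. On success, \p Val holds the converted value.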
bool DAGCombiner::getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val) {
  Val = ST->getValue();
  EVT STType = Val.getValueType();
  EVT STMemType = ST->getMemoryVT();
  if (STType == STMemType)
    return true;
  if (isTypeLegal(STMemType))
    return false; // fail.
  if (STType.isFloatingPoint() && STMemType.isFloatingPoint() &&
      TLI.isOperationLegal(ISD::FTRUNC, STMemType)) {
    Val = DAG.getNode(ISD::FTRUNC, SDLoc(ST), STMemType, Val);
    return true;
  }
  if (numVectorEltsOrZero(STType) == numVectorEltsOrZero(STMemType) &&
      STType.isInteger() && STMemType.isInteger()) {
    Val = DAG.getNode(ISD::TRUNCATE, SDLoc(ST), STMemType, Val);
    return true;
  }
  if (STType.getSizeInBits() == STMemType.getSizeInBits()) {
    Val = DAG.getBitcast(STMemType, Val);
    return true;
  }
  return false; // fail.
}

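/// Extend \p Val, which has \p LD's memory type, to \p LD's result type
/// according to the load's extension kind (any/sign/zero-extend, or a
/// bitcast for a non-extending load).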
bool DAGCombiner::extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val) {
  EVT LDMemType = LD->getMemoryVT();
  EVT LDType = LD->getValueType(0);
  assert(Val.getValueType() == LDMemType &&
         "Attempting to extend value of non-matching type");
  if (LDType == LDMemType)
    return true;
  if (LDMemType.isInteger() && LDType.isInteger()) {
    switch (LD->getExtensionType()) {
    case ISD::NON_EXTLOAD:
      Val = DAG.getBitcast(LDType, Val);
      return true;
    case ISD::EXTLOAD:
      Val = DAG.getNode(ISD::ANY_EXTEND, SDLoc(LD), LDType, Val);
      return true;
    case ISD::SEXTLOAD:
      Val = DAG.getNode(ISD::SIGN_EXTEND, SDLoc(LD), LDType, Val);
      return true;
    case ISD::ZEXTLOAD:
      Val = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(LD), LDType, Val);
      return true;
    }
  }
  return false;
}

SDValue DAGCombiner::ForwardStoreValueToDirectLoad(LoadSDNode *LD) {
  if (OptLevel == CodeGenOpt::None || !LD->isSimple())
    return SDValue();
  SDValue Chain = LD->getOperand(0);
  StoreSDNode *ST = dyn_cast<StoreSDNode>(Chain.getNode());
  // TODO: Relax this restriction for unordered atomics (see D66309)
  if (!ST || !ST->isSimple())
    return SDValue();

  EVT LDType = LD->getValueType(0);
  EVT LDMemType = LD->getMemoryVT();
  EVT STMemType = ST->getMemoryVT();
  EVT STType = ST->getValue().getValueType();

  // There are two cases to consider here:
  //  1. The store is fixed width and the load is scalable. In this case we
  //     don't know at compile time if the store completely envelops the load
  //     so we abandon the optimisation.
  //  2. The store is scalable and the load is fixed width. We could
  //     potentially support a limited number of cases here, but there has been
  //     no cost-benefit analysis to prove it's worth it.
  bool LdStScalable = LDMemType.isScalableVector();
  if (LdStScalable != STMemType.isScalableVector())
    return SDValue();

  // If we are dealing with scalable vectors on a big endian platform the
  // calculation of offsets below becomes trickier, since we do not know at
  // compile time the absolute size of the vector. Until we've done more
  // analysis on big-endian platforms it seems better to bail out for now.
  if (LdStScalable && DAG.getDataLayout().isBigEndian())
    return SDValue();

  BaseIndexOffset BasePtrLD = BaseIndexOffset::match(LD, DAG);
  BaseIndexOffset BasePtrST = BaseIndexOffset::match(ST, DAG);
  int64_t Offset;
  if (!BasePtrST.equalBaseIndex(BasePtrLD, DAG, Offset))
    return SDValue();

  // Normalize for endianness. After this, Offset=0 denotes that the least
  // significant bit in the loaded value maps to the least significant bit in
  // the stored value. With Offset=n (for n > 0) the loaded value starts at the
  // n-th least significant byte of the stored value.
  if (DAG.getDataLayout().isBigEndian())
    Offset = ((int64_t)STMemType.getStoreSizeInBits().getFixedSize() -
              (int64_t)LDMemType.getStoreSizeInBits().getFixedSize()) /
                 8 -
             Offset;

  // Check that the stored value covers all bits that are loaded.
  bool STCoversLD;

  TypeSize LdMemSize = LDMemType.getSizeInBits();
  TypeSize StMemSize = STMemType.getSizeInBits();
  if (LdStScalable)
    STCoversLD = (Offset == 0) && LdMemSize == StMemSize;
  else
    STCoversLD = (Offset >= 0) && (Offset * 8 + LdMemSize.getFixedSize() <=
                                   StMemSize.getFixedSize());

  auto ReplaceLd = [&](LoadSDNode *LD, SDValue Val, SDValue Chain) -> SDValue {
    if (LD->isIndexed()) {
      // Cannot handle opaque target constants and we must respect the user's
      // request not to split indexes from loads.
      if (!canSplitIdx(LD))
        return SDValue();
      SDValue Idx = SplitIndexingFromLoad(LD);
      SDValue Ops[] = {Val, Idx, Chain};
      return CombineTo(LD, Ops, 3);
    }
    return CombineTo(LD, Val, Chain);
  };

  if (!STCoversLD)
    return SDValue();

  // Memory as copy space (potentially masked).
  if (Offset == 0 && LDType == STType && STMemType == LDMemType) {
    // Simple case: Direct non-truncating forwarding.
    if (LDType.getSizeInBits() == LdMemSize)
      return ReplaceLd(LD, ST->getValue(), Chain);
    // Can we model the truncate and extension with an and mask?
    if (STType.isInteger() && LDMemType.isInteger() && !STType.isVector() &&
        !LDMemType.isVector() && LD->getExtensionType() != ISD::SEXTLOAD) {
      // Mask to size of LDMemType.
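      // E.g. a truncating store of an i32 value to i16 memory followed by a
      // zero-extending i16 load of the same address can forward
      // (StoredValue & 0xFFFF) instead of reloading from memory.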
      auto Mask =
          DAG.getConstant(APInt::getLowBitsSet(STType.getFixedSizeInBits(),
                                               StMemSize.getFixedSize()),
                          SDLoc(ST), STType);
      auto Val = DAG.getNode(ISD::AND, SDLoc(LD), LDType, ST->getValue(), Mask);
      return ReplaceLd(LD, Val, Chain);
    }
  }

  // TODO: Deal with nonzero offset.
  if (LD->getBasePtr().isUndef() || Offset != 0)
    return SDValue();
  // Model necessary truncations / extensions.
  SDValue Val;
  // Truncate the value to the stored memory size.
  do {
    if (!getTruncatedStoreValue(ST, Val))
      continue;
    if (!isTypeLegal(LDMemType))
      continue;
    if (STMemType != LDMemType) {
      // TODO: Support vectors? This requires extract_subvector/bitcast.
      if (!STMemType.isVector() && !LDMemType.isVector() &&
          STMemType.isInteger() && LDMemType.isInteger())
        Val = DAG.getNode(ISD::TRUNCATE, SDLoc(LD), LDMemType, Val);
      else
        continue;
    }
    if (!extendLoadedValueToExtension(LD, Val))
      continue;
    return ReplaceLd(LD, Val, Chain);
  } while (false);

  // On failure, clean up dead nodes we may have created.
  if (Val->use_empty())
    deleteAndRecombine(Val.getNode());
  return SDValue();
}

SDValue DAGCombiner::visitLOAD(SDNode *N) {
  LoadSDNode *LD  = cast<LoadSDNode>(N);
  SDValue Chain = LD->getChain();
  SDValue Ptr   = LD->getBasePtr();

  // If load is not volatile and there are no uses of the loaded value (and
  // the updated indexed value in case of indexed loads), change uses of the
  // chain value into uses of the chain input (i.e. delete the dead load).
  // TODO: Allow this for unordered atomics (see D66309)
  if (LD->isSimple()) {
    if (N->getValueType(1) == MVT::Other) {
      // Unindexed loads.
      if (!N->hasAnyUseOfValue(0)) {
        // It's not safe to use the two value CombineTo variant here. e.g.
        // v1, chain2 = load chain1, loc
        // v2, chain3 = load chain2, loc
        // v3         = add v2, c
        // Now we replace use of chain2 with chain1.  This makes the second load
        // isomorphic to the one we are deleting, and thus makes this load live.
        LLVM_DEBUG(dbgs() << "\nReplacing.6 "; N->dump(&DAG);
                   dbgs() << "\nWith chain: "; Chain.getNode()->dump(&DAG);
                   dbgs() << "\n");
        WorklistRemover DeadNodes(*this);
        DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain);
        AddUsersToWorklist(Chain.getNode());
        if (N->use_empty())
          deleteAndRecombine(N);

        return SDValue(N, 0);   // Return N so it doesn't get rechecked!
      }
    } else {
      // Indexed loads.
      assert(N->getValueType(2) == MVT::Other && "Malformed indexed loads?");

      // If this load has an opaque TargetConstant offset, then we cannot split
      // the indexing into an add/sub directly (that TargetConstant may not be
      // valid for a different type of node, and we cannot convert an opaque
      // target constant into a regular constant).
      bool CanSplitIdx = canSplitIdx(LD);

      if (!N->hasAnyUseOfValue(0) && (CanSplitIdx || !N->hasAnyUseOfValue(1))) {
        SDValue Undef = DAG.getUNDEF(N->getValueType(0));
        SDValue Index;
        if (N->hasAnyUseOfValue(1) && CanSplitIdx) {
          Index = SplitIndexingFromLoad(LD);
          // Try to fold the base pointer arithmetic into subsequent loads and
          // stores.
          AddUsersToWorklist(N);
        } else
          Index = DAG.getUNDEF(N->getValueType(1));
        LLVM_DEBUG(dbgs() << "\nReplacing.7 "; N->dump(&DAG);
                   dbgs() << "\nWith: "; Undef.getNode()->dump(&DAG);
                   dbgs() << " and 2 other values\n");
        WorklistRemover DeadNodes(*this);
        DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Undef);
        DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Index);
        DAG.ReplaceAllUsesOfValueWith(SDValue(N, 2), Chain);
        deleteAndRecombine(N);
        return SDValue(N, 0);   // Return N so it doesn't get rechecked!
      }
    }
  }

  // If this load is directly stored, replace the load value with the stored
  // value.
  if (auto V = ForwardStoreValueToDirectLoad(LD))
    return V;

  // Try to infer better alignment information than the load already has.
  if (OptLevel != CodeGenOpt::None && LD->isUnindexed() && !LD->isAtomic()) {
    if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) {
      if (*Alignment > LD->getAlign() &&
          isAligned(*Alignment, LD->getSrcValueOffset())) {
        SDValue NewLoad = DAG.getExtLoad(
            LD->getExtensionType(), SDLoc(N), LD->getValueType(0), Chain, Ptr,
            LD->getPointerInfo(), LD->getMemoryVT(), *Alignment,
            LD->getMemOperand()->getFlags(), LD->getAAInfo());
        // NewLoad will always be N as we are only refining the alignment.
        assert(NewLoad.getNode() == N);
        (void)NewLoad;
      }
    }
  }

  if (LD->isUnindexed()) {
    // Walk up chain skipping non-aliasing memory nodes.
    SDValue BetterChain = FindBetterChain(LD, Chain);

    // If there is a better chain.
    if (Chain != BetterChain) {
      SDValue ReplLoad;

      // Replace the chain to avoid the dependency.
      if (LD->getExtensionType() == ISD::NON_EXTLOAD) {
        ReplLoad = DAG.getLoad(N->getValueType(0), SDLoc(LD),
                               BetterChain, Ptr, LD->getMemOperand());
      } else {
        ReplLoad = DAG.getExtLoad(LD->getExtensionType(), SDLoc(LD),
                                  LD->getValueType(0),
                                  BetterChain, Ptr, LD->getMemoryVT(),
                                  LD->getMemOperand());
      }

      // Create token factor to keep old chain connected.
      SDValue Token = DAG.getNode(ISD::TokenFactor, SDLoc(N),
                                  MVT::Other, Chain, ReplLoad.getValue(1));

      // Replace uses with load result and token factor.
      return CombineTo(N, ReplLoad.getValue(0), Token);
    }
  }

  // Try transforming N to an indexed load.
  if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
    return SDValue(N, 0);

  // Try to slice up N into more direct loads if the slices are mapped to
  // different register banks or pairing can take place.
  if (SliceUpLoad(N))
    return SDValue(N, 0);

  return SDValue();
}

namespace {

/// Helper structure used to slice a load into smaller loads.
/// Basically a slice is obtained from the following sequence:
/// Origin = load Ty1, Base
/// Shift = srl Ty1 Origin, CstTy Amount
/// Inst = trunc Shift to Ty2
///
/// Then, it will be rewritten into:
/// Slice = load SliceTy, Base + SliceOffset
/// [Inst = zext Slice to Ty2], only if SliceTy <> Ty2
///
/// SliceTy is deduced from the number of bits that are actually used to
/// build Inst.
struct LoadedSlice {
  /// Helper structure used to compute the cost of a slice.
  struct Cost {
    /// Are we optimizing for code size.
    bool ForCodeSize = false;

    /// Various costs.
    unsigned Loads = 0;
    unsigned Truncates = 0;
    unsigned CrossRegisterBanksCopies = 0;
    unsigned ZExts = 0;
    unsigned Shift = 0;

    explicit Cost(bool ForCodeSize) : ForCodeSize(ForCodeSize) {}

    /// Get the cost of one isolated slice.
    Cost(const LoadedSlice &LS, bool ForCodeSize)
        : ForCodeSize(ForCodeSize), Loads(1) {
      EVT TruncType = LS.Inst->getValueType(0);
      EVT LoadedType = LS.getLoadedType();
      if (TruncType != LoadedType &&
          !LS.DAG->getTargetLoweringInfo().isZExtFree(LoadedType, TruncType))
        ZExts = 1;
    }

    /// Account for slicing gain in the current cost.
    /// Slicing provides a few gains, like removing a shift or a
    /// truncate. This method allows growing the cost of the original
    /// load with the gain from this slice.
    void addSliceGain(const LoadedSlice &LS) {
      // Each slice saves a truncate.
      const TargetLowering &TLI = LS.DAG->getTargetLoweringInfo();
      if (!TLI.isTruncateFree(LS.Inst->getOperand(0).getValueType(),
                              LS.Inst->getValueType(0)))
        ++Truncates;
      // If there is a shift amount, this slice gets rid of it.
      if (LS.Shift)
        ++Shift;
      // If this slice can merge a cross register bank copy, account for it.
      if (LS.canMergeExpensiveCrossRegisterBankCopy())
        ++CrossRegisterBanksCopies;
    }

    Cost &operator+=(const Cost &RHS) {
      Loads += RHS.Loads;
      Truncates += RHS.Truncates;
      CrossRegisterBanksCopies += RHS.CrossRegisterBanksCopies;
      ZExts += RHS.ZExts;
      Shift += RHS.Shift;
      return *this;
    }

    bool operator==(const Cost &RHS) const {
      return Loads == RHS.Loads && Truncates == RHS.Truncates &&
             CrossRegisterBanksCopies == RHS.CrossRegisterBanksCopies &&
             ZExts == RHS.ZExts && Shift == RHS.Shift;
    }

    bool operator!=(const Cost &RHS) const { return !(*this == RHS); }

    bool operator<(const Cost &RHS) const {
      // Assume cross register banks copies are as expensive as loads.
      // FIXME: Do we want some more target hooks?
      unsigned ExpensiveOpsLHS = Loads + CrossRegisterBanksCopies;
      unsigned ExpensiveOpsRHS = RHS.Loads + RHS.CrossRegisterBanksCopies;
      // Unless we are optimizing for code size, consider the
      // expensive operation first.
      if (!ForCodeSize && ExpensiveOpsLHS != ExpensiveOpsRHS)
        return ExpensiveOpsLHS < ExpensiveOpsRHS;
      return (Truncates + ZExts + Shift + ExpensiveOpsLHS) <
             (RHS.Truncates + RHS.ZExts + RHS.Shift + ExpensiveOpsRHS);
    }

    bool operator>(const Cost &RHS) const { return RHS < *this; }

    bool operator<=(const Cost &RHS) const { return !(RHS < *this); }

    bool operator>=(const Cost &RHS) const { return !(*this < RHS); }
  };

  // The last instruction that represents the slice. This should be a
  // truncate instruction.
  SDNode *Inst;

  // The original load instruction.
  LoadSDNode *Origin;

  // The right shift amount in bits from the original load.
  unsigned Shift;

  // The DAG from which Origin comes.
  // This is used to get some contextual information about legal types, etc.
  SelectionDAG *DAG;

  LoadedSlice(SDNode *Inst = nullptr, LoadSDNode *Origin = nullptr,
              unsigned Shift = 0, SelectionDAG *DAG = nullptr)
      : Inst(Inst), Origin(Origin), Shift(Shift), DAG(DAG) {}

  /// Get the bits used in a chunk of bits \p BitWidth large.
  /// \return Result is \p BitWidth bits wide and has used bits set to 1 and
  ///         unused bits set to 0.
  APInt getUsedBits() const {
    // Reproduce the trunc(lshr) sequence:
    // - Start from the truncated value.
    // - Zero extend to the desired bit width.
    // - Shift left.
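    //
    // E.g. for an i32 origin sliced as trunc-to-i8 of (origin lshr 16), the
    // used bits are 0x00FF0000.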
    assert(Origin && "No original load to compare against.");
    unsigned BitWidth = Origin->getValueSizeInBits(0);
    assert(Inst && "This slice is not bound to an instruction");
    assert(Inst->getValueSizeInBits(0) <= BitWidth &&
           "Extracted slice is bigger than the whole type!");
    APInt UsedBits(Inst->getValueSizeInBits(0), 0);
    UsedBits.setAllBits();
    UsedBits = UsedBits.zext(BitWidth);
    UsedBits <<= Shift;
    return UsedBits;
  }

  /// Get the size of the slice to be loaded in bytes.
  unsigned getLoadedSize() const {
    unsigned SliceSize = getUsedBits().countPopulation();
    assert(!(SliceSize & 0x7) && "Size is not a multiple of a byte.");
    return SliceSize / 8;
  }

  /// Get the type that will be loaded for this slice.
  /// Note: This may not be the final type for the slice.
  EVT getLoadedType() const {
    assert(DAG && "Missing context");
    LLVMContext &Ctxt = *DAG->getContext();
    return EVT::getIntegerVT(Ctxt, getLoadedSize() * 8);
  }

  /// Get the alignment of the load used for this slice.
  Align getAlign() const {
    Align Alignment = Origin->getAlign();
    uint64_t Offset = getOffsetFromBase();
    if (Offset != 0)
      Alignment = commonAlignment(Alignment, Alignment.value() + Offset);
    return Alignment;
  }

  /// Check if this slice can be rewritten with legal operations.
  bool isLegal() const {
    // An invalid slice is not legal.
    if (!Origin || !Inst || !DAG)
      return false;

    // Offsets are for indexed loads only; we do not handle that.
    if (!Origin->getOffset().isUndef())
      return false;

    const TargetLowering &TLI = DAG->getTargetLoweringInfo();

    // Check that the type is legal.
    EVT SliceType = getLoadedType();
    if (!TLI.isTypeLegal(SliceType))
      return false;

    // Check that the load is legal for this type.
    if (!TLI.isOperationLegal(ISD::LOAD, SliceType))
      return false;

    // Check that the offset can be computed.
    // 1. Check its type.
    EVT PtrType = Origin->getBasePtr().getValueType();
    if (PtrType == MVT::Untyped || PtrType.isExtended())
      return false;

    // 2. Check that it fits in the immediate.
    if (!TLI.isLegalAddImmediate(getOffsetFromBase()))
      return false;

    // 3. Check that the computation is legal.
    if (!TLI.isOperationLegal(ISD::ADD, PtrType))
      return false;

    // Check that the zext is legal if it needs one.
    EVT TruncateType = Inst->getValueType(0);
    if (TruncateType != SliceType &&
        !TLI.isOperationLegal(ISD::ZERO_EXTEND, TruncateType))
      return false;

    return true;
  }

  /// Get the offset in bytes of this slice in the original chunk of
  /// bits.
  /// \pre DAG != nullptr.
  uint64_t getOffsetFromBase() const {
    assert(DAG && "Missing context.");
    bool IsBigEndian = DAG->getDataLayout().isBigEndian();
    assert(!(Shift & 0x7) && "Shifts not aligned on Bytes are not supported.");
    uint64_t Offset = Shift / 8;
    unsigned TySizeInBytes = Origin->getValueSizeInBits(0) / 8;
    assert(!(Origin->getValueSizeInBits(0) & 0x7) &&
           "The size of the original loaded type is not a multiple of a"
           " byte.");
    // If Offset is bigger than TySizeInBytes, it means we are loading all
    // zeros. This should have been optimized earlier in the process.
    assert(TySizeInBytes > Offset &&
           "Invalid shift amount for given loaded size");
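    // E.g. for an i32 origin with Shift == 16 and a 1-byte slice: the
    // little-endian offset is 2, while on big-endian it becomes
    // 4 - 2 - 1 = 1.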
    if (IsBigEndian)
      Offset = TySizeInBytes - Offset - getLoadedSize();
    return Offset;
  }

  /// Generate the sequence of instructions to load the slice
  /// represented by this object and redirect the uses of this slice to
  /// this new sequence of instructions.
  /// \pre this->Inst && this->Origin are valid Instructions and this
  /// object passed the legal check: LoadedSlice::isLegal returned true.
  /// \return The last instruction of the sequence used to load the slice.
  SDValue loadSlice() const {
    assert(Inst && Origin && "Unable to replace a non-existing slice.");
    const SDValue &OldBaseAddr = Origin->getBasePtr();
    SDValue BaseAddr = OldBaseAddr;
    // Get the offset in that chunk of bytes w.r.t. the endianness.
    int64_t Offset = static_cast<int64_t>(getOffsetFromBase());
    assert(Offset >= 0 && "Offset too big to fit in int64_t!");
    if (Offset) {
      // BaseAddr = BaseAddr + Offset.
      EVT ArithType = BaseAddr.getValueType();
      SDLoc DL(Origin);
      BaseAddr = DAG->getNode(ISD::ADD, DL, ArithType, BaseAddr,
                              DAG->getConstant(Offset, DL, ArithType));
    }

    // Create the type of the loaded slice according to its size.
    EVT SliceType = getLoadedType();

    // Create the load for the slice.
    SDValue LastInst =
        DAG->getLoad(SliceType, SDLoc(Origin), Origin->getChain(), BaseAddr,
                     Origin->getPointerInfo().getWithOffset(Offset), getAlign(),
                     Origin->getMemOperand()->getFlags());
    // If the final type is not the same as the loaded type, this means that
    // we have to pad with zero. Create a zero extend for that.
    EVT FinalType = Inst->getValueType(0);
    if (SliceType != FinalType)
      LastInst =
          DAG->getNode(ISD::ZERO_EXTEND, SDLoc(LastInst), FinalType, LastInst);
    return LastInst;
  }

  /// Check if this slice can be merged with an expensive cross register
  /// bank copy. E.g.,
  /// i = load i32
  /// f = bitcast i32 i to float
  bool canMergeExpensiveCrossRegisterBankCopy() const {
    if (!Inst || !Inst->hasOneUse())
      return false;
    SDNode *Use = *Inst->use_begin();
    if (Use->getOpcode() != ISD::BITCAST)
      return false;
    assert(DAG && "Missing context");
    const TargetLowering &TLI = DAG->getTargetLoweringInfo();
    EVT ResVT = Use->getValueType(0);
    const TargetRegisterClass *ResRC =
        TLI.getRegClassFor(ResVT.getSimpleVT(), Use->isDivergent());
    const TargetRegisterClass *ArgRC =
        TLI.getRegClassFor(Use->getOperand(0).getValueType().getSimpleVT(),
                           Use->getOperand(0)->isDivergent());
    if (ArgRC == ResRC || !TLI.isOperationLegal(ISD::LOAD, ResVT))
      return false;

    // At this point, we know that we perform a cross-register-bank copy.
    // Check if it is expensive.
    const TargetRegisterInfo *TRI = DAG->getSubtarget().getRegisterInfo();
    // Assume bitcasts are cheap, unless the register classes do not
    // explicitly share a common subclass.
    if (!TRI || TRI->getCommonSubClass(ArgRC, ResRC))
      return false;

    // Check if it will be merged with the load.
    // 1. Check the alignment constraint.
    Align RequiredAlignment = DAG->getDataLayout().getABITypeAlign(
        ResVT.getTypeForEVT(*DAG->getContext()));

    if (RequiredAlignment > getAlign())
      return false;

    // 2. Check that the load is a legal operation for that type.
    if (!TLI.isOperationLegal(ISD::LOAD, ResVT))
      return false;

    // 3. Check that we do not have a zext in the way.
    if (Inst->getValueType(0) != getLoadedType())
      return false;

    return true;
  }
};

} // end anonymous namespace

/// Check that all bits set in \p UsedBits form a dense region, i.e.,
/// \p UsedBits looks like 0..0 1..1 0..0.
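/// E.g. 0x00FF0000 is dense, while 0x00FF0001 is not (there is a hole
/// between bit 0 and bits 16-23).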
areUsedBitsDense(const APInt & UsedBits)16270 static bool areUsedBitsDense(const APInt &UsedBits) {
16271   // If all the bits are one, this is dense!
16272   if (UsedBits.isAllOnesValue())
16273     return true;
16274 
16275   // Get rid of the unused bits on the right.
16276   APInt NarrowedUsedBits = UsedBits.lshr(UsedBits.countTrailingZeros());
16277   // Get rid of the unused bits on the left.
16278   if (NarrowedUsedBits.countLeadingZeros())
16279     NarrowedUsedBits = NarrowedUsedBits.trunc(NarrowedUsedBits.getActiveBits());
16280   // Check that the chunk of bits is completely used.
16281   return NarrowedUsedBits.isAllOnesValue();
16282 }
16283 
16284 /// Check whether or not \p First and \p Second are next to each other
16285 /// in memory. This means that there is no hole between the bits loaded
16286 /// by \p First and the bits loaded by \p Second.
areSlicesNextToEachOther(const LoadedSlice & First,const LoadedSlice & Second)16287 static bool areSlicesNextToEachOther(const LoadedSlice &First,
16288                                      const LoadedSlice &Second) {
16289   assert(First.Origin == Second.Origin && First.Origin &&
16290          "Unable to match different memory origins.");
16291   APInt UsedBits = First.getUsedBits();
16292   assert((UsedBits & Second.getUsedBits()) == 0 &&
16293          "Slices are not supposed to overlap.");
16294   UsedBits |= Second.getUsedBits();
16295   return areUsedBitsDense(UsedBits);
16296 }
16297 
16298 /// Adjust the \p GlobalLSCost according to the target
16299 /// paring capabilities and the layout of the slices.
16300 /// \pre \p GlobalLSCost should account for at least as many loads as
16301 /// there is in the slices in \p LoadedSlices.
adjustCostForPairing(SmallVectorImpl<LoadedSlice> & LoadedSlices,LoadedSlice::Cost & GlobalLSCost)16302 static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices,
16303                                  LoadedSlice::Cost &GlobalLSCost) {
16304   unsigned NumberOfSlices = LoadedSlices.size();
16305   // If there is less than 2 elements, no pairing is possible.
16306   if (NumberOfSlices < 2)
16307     return;
16308 
16309   // Sort the slices so that elements that are likely to be next to each
16310   // other in memory are next to each other in the list.
16311   llvm::sort(LoadedSlices, [](const LoadedSlice &LHS, const LoadedSlice &RHS) {
16312     assert(LHS.Origin == RHS.Origin && "Different bases not implemented.");
16313     return LHS.getOffsetFromBase() < RHS.getOffsetFromBase();
16314   });
16315   const TargetLowering &TLI = LoadedSlices[0].DAG->getTargetLoweringInfo();
16316   // First (resp. Second) is the first (resp. Second) potentially candidate
16317   // to be placed in a paired load.
16318   const LoadedSlice *First = nullptr;
16319   const LoadedSlice *Second = nullptr;
16320   for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice,
16321                 // Set the beginning of the pair.
16322                                                            First = Second) {
16323     Second = &LoadedSlices[CurrSlice];
16324 
16325     // If First is NULL, it means we start a new pair.
16326     // Get to the next slice.
16327     if (!First)
16328       continue;
16329 
16330     EVT LoadedType = First->getLoadedType();
16331 
16332     // If the types of the slices are different, we cannot pair them.
16333     if (LoadedType != Second->getLoadedType())
16334       continue;
16335 
16336     // Check if the target supplies paired loads for this type.
16337     Align RequiredAlignment;
16338     if (!TLI.hasPairedLoad(LoadedType, RequiredAlignment)) {
16339       // move to the next pair, this type is hopeless.
16340       Second = nullptr;
16341       continue;
16342     }
16343     // Check if we meet the alignment requirement.
16344     if (First->getAlign() < RequiredAlignment)
16345       continue;
16346 
16347     // Check that both loads are next to each other in memory.
16348     if (!areSlicesNextToEachOther(*First, *Second))
16349       continue;
16350 
16351     assert(GlobalLSCost.Loads > 0 && "We save more loads than we created!");
16352     --GlobalLSCost.Loads;
16353     // Move to the next pair.
16354     Second = nullptr;
16355   }
16356 }
16357 
16358 /// Check the profitability of all involved LoadedSlice.
16359 /// Currently, it is considered profitable if there is exactly two
16360 /// involved slices (1) which are (2) next to each other in memory, and
16361 /// whose cost (\see LoadedSlice::Cost) is smaller than the original load (3).
16362 ///
16363 /// Note: The order of the elements in \p LoadedSlices may be modified, but not
16364 /// the elements themselves.
16365 ///
16366 /// FIXME: When the cost model will be mature enough, we can relax
16367 /// constraints (1) and (2).
isSlicingProfitable(SmallVectorImpl<LoadedSlice> & LoadedSlices,const APInt & UsedBits,bool ForCodeSize)16368 static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices,
16369                                 const APInt &UsedBits, bool ForCodeSize) {
16370   unsigned NumberOfSlices = LoadedSlices.size();
16371   if (StressLoadSlicing)
16372     return NumberOfSlices > 1;
16373 
16374   // Check (1).
16375   if (NumberOfSlices != 2)
16376     return false;
16377 
16378   // Check (2).
16379   if (!areUsedBitsDense(UsedBits))
16380     return false;
16381 
16382   // Check (3).
16383   LoadedSlice::Cost OrigCost(ForCodeSize), GlobalSlicingCost(ForCodeSize);
16384   // The original code has one big load.
16385   OrigCost.Loads = 1;
16386   for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice) {
16387     const LoadedSlice &LS = LoadedSlices[CurrSlice];
16388     // Accumulate the cost of all the slices.
16389     LoadedSlice::Cost SliceCost(LS, ForCodeSize);
16390     GlobalSlicingCost += SliceCost;
16391 
16392     // Account as cost in the original configuration the gain obtained
16393     // with the current slices.
16394     OrigCost.addSliceGain(LS);
16395   }
16396 
16397   // If the target supports paired load, adjust the cost accordingly.
16398   adjustCostForPairing(LoadedSlices, GlobalSlicingCost);
16399   return OrigCost > GlobalSlicingCost;
16400 }
16401 
16402 /// If the given load, \p N, is used only by trunc or trunc(lshr)
16403 /// operations, split it into the various pieces being extracted.
16404 ///
16405 /// This sort of thing is introduced by SROA.
16406 /// This slicing takes care not to insert overlapping loads.
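///
/// As an illustrative sketch (%p and little-endian layout are assumptions
/// of this example): an i32 load whose only uses are
///   (trunc t0 to i16)  and  (trunc (srl t0, 16) to i16)
/// may be rewritten as two independent i16 loads from %p and %p+2, when
/// the cost model below deems the slicing profitable.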
16407 /// \pre \p N is a simple load (i.e., not an atomic or volatile load).
16408 bool DAGCombiner::SliceUpLoad(SDNode *N) {
16409   if (Level < AfterLegalizeDAG)
16410     return false;
16411 
16412   LoadSDNode *LD = cast<LoadSDNode>(N);
16413   if (!LD->isSimple() || !ISD::isNormalLoad(LD) ||
16414       !LD->getValueType(0).isInteger())
16415     return false;
16416 
16417   // The algorithm to split up a load of a scalable vector into individual
16418   // elements currently requires knowing the length of the loaded type,
16419   // so will need adjusting to work on scalable vectors.
16420   if (LD->getValueType(0).isScalableVector())
16421     return false;
16422 
16423   // Keep track of already used bits to detect overlapping values.
16424   // In that case, we will just abort the transformation.
16425   APInt UsedBits(LD->getValueSizeInBits(0), 0);
16426 
16427   SmallVector<LoadedSlice, 4> LoadedSlices;
16428 
16429   // Check if this load is used as several smaller chunks of bits.
16430   // Basically, look for uses in trunc or trunc(lshr) and record a new chain
16431   // of computation for each trunc.
16432   for (SDNode::use_iterator UI = LD->use_begin(), UIEnd = LD->use_end();
16433        UI != UIEnd; ++UI) {
16434     // Skip the uses of the chain.
16435     if (UI.getUse().getResNo() != 0)
16436       continue;
16437 
16438     SDNode *User = *UI;
16439     unsigned Shift = 0;
16440 
16441     // Check if this is a trunc(lshr).
16442     if (User->getOpcode() == ISD::SRL && User->hasOneUse() &&
16443         isa<ConstantSDNode>(User->getOperand(1))) {
16444       Shift = User->getConstantOperandVal(1);
16445       User = *User->use_begin();
16446     }
16447 
16448     // At this point, User is a truncate iff we encountered trunc or
16449     // trunc(lshr).
16450     if (User->getOpcode() != ISD::TRUNCATE)
16451       return false;
16452 
16453     // The width of the type must be a power of 2 and at least 8 bits.
16454     // Otherwise the load cannot be represented in LLVM IR.
16455     // Moreover, if we shifted by a non-multiple of 8 bits, the slice
16456     // would straddle byte boundaries. We do not support that.
16457     unsigned Width = User->getValueSizeInBits(0);
16458     if (Width < 8 || !isPowerOf2_32(Width) || (Shift & 0x7))
16459       return false;
16460 
16461     // Build the slice for this chain of computations.
16462     LoadedSlice LS(User, LD, Shift, &DAG);
16463     APInt CurrentUsedBits = LS.getUsedBits();
16464 
16465     // Check if this slice overlaps with another.
16466     if ((CurrentUsedBits & UsedBits) != 0)
16467       return false;
16468     // Update the bits used globally.
16469     UsedBits |= CurrentUsedBits;
16470 
16471     // Check if the new slice would be legal.
16472     if (!LS.isLegal())
16473       return false;
16474 
16475     // Record the slice.
16476     LoadedSlices.push_back(LS);
16477   }
16478 
16479   // Abort slicing if it does not seem to be profitable.
16480   if (!isSlicingProfitable(LoadedSlices, UsedBits, ForCodeSize))
16481     return false;
16482 
16483   ++SlicedLoads;
16484 
16485   // Rewrite each chain to use an independent load.
16486   // By construction, each chain can be represented by a unique load.
16487 
16488   // Prepare the arguments for the new token factor covering all the slices.
16489   SmallVector<SDValue, 8> ArgChains;
16490   for (const LoadedSlice &LS : LoadedSlices) {
16491     SDValue SliceInst = LS.loadSlice();
16492     CombineTo(LS.Inst, SliceInst, true);
16493     if (SliceInst.getOpcode() != ISD::LOAD)
16494       SliceInst = SliceInst.getOperand(0);
16495     assert(SliceInst->getOpcode() == ISD::LOAD &&
16496            "It takes more than a zext to get to the loaded slice!!");
16497     ArgChains.push_back(SliceInst.getValue(1));
16498   }
16499 
16500   SDValue Chain = DAG.getNode(ISD::TokenFactor, SDLoc(LD), MVT::Other,
16501                               ArgChains);
16502   DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain);
16503   AddToWorklist(Chain.getNode());
16504   return true;
16505 }
16506 
16507 /// Check to see if V is (and (load ptr), imm), where the load has specific
16508 /// bytes cleared out.  If so, return the number of bytes being masked out
16509 /// and the shift amount (in bytes).
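///
/// For example (a sketch; %p is a placeholder name), with an i32 load:
///   (and (load %p), 0xFFFFFF00)
/// only the low byte is cleared, so this returns {1 /*byte masked out*/,
/// 0 /*byte shift*/}.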
16510 static std::pair<unsigned, unsigned>
16511 CheckForMaskedLoad(SDValue V, SDValue Ptr, SDValue Chain) {
16512   std::pair<unsigned, unsigned> Result(0, 0);
16513 
16514   // Check for the structure we're looking for.
16515   if (V->getOpcode() != ISD::AND ||
16516       !isa<ConstantSDNode>(V->getOperand(1)) ||
16517       !ISD::isNormalLoad(V->getOperand(0).getNode()))
16518     return Result;
16519 
16520   // Check the chain and pointer.
16521   LoadSDNode *LD = cast<LoadSDNode>(V->getOperand(0));
16522   if (LD->getBasePtr() != Ptr) return Result;  // Not from same pointer.
16523 
16524   // This only handles simple types.
16525   if (V.getValueType() != MVT::i16 &&
16526       V.getValueType() != MVT::i32 &&
16527       V.getValueType() != MVT::i64)
16528     return Result;
16529 
16530   // Check the constant mask.  Invert it so that the bits being masked out are
16531   // 0 and the bits being kept are 1.  Use getSExtValue so that leading bits
16532   // follow the sign bit for uniformity.
16533   uint64_t NotMask = ~cast<ConstantSDNode>(V->getOperand(1))->getSExtValue();
16534   unsigned NotMaskLZ = countLeadingZeros(NotMask);
16535   if (NotMaskLZ & 7) return Result;  // Must be multiple of a byte.
16536   unsigned NotMaskTZ = countTrailingZeros(NotMask);
16537   if (NotMaskTZ & 7) return Result;  // Must be multiple of a byte.
16538   if (NotMaskLZ == 64) return Result;  // All zero mask.
16539 
16540   // See if we have a contiguous run of bits.  If so, NotMask has the form 0*1+0*.
16541   if (countTrailingOnes(NotMask >> NotMaskTZ) + NotMaskTZ + NotMaskLZ != 64)
16542     return Result;
16543 
16544   // Adjust NotMaskLZ down to be from the actual size of the int instead of i64.
16545   if (V.getValueType() != MVT::i64 && NotMaskLZ)
16546     NotMaskLZ -= 64-V.getValueSizeInBits();
16547 
16548   unsigned MaskedBytes = (V.getValueSizeInBits()-NotMaskLZ-NotMaskTZ)/8;
16549   switch (MaskedBytes) {
16550   case 1:
16551   case 2:
16552   case 4: break;
16553   default: return Result; // All one mask, or 5-byte mask.
16554   }
16555 
16556   // Verify that the masked bytes start at a multiple of the mask size so that
16557   // the access is aligned the same as the access width.
16558   if (NotMaskTZ && NotMaskTZ/8 % MaskedBytes) return Result;
16559 
16560   // For narrowing to be valid, it must be the case that the load is the
16561   // immediately preceding memory operation before the store.
16562   if (LD == Chain.getNode())
16563     ; // ok.
16564   else if (Chain->getOpcode() == ISD::TokenFactor &&
16565            SDValue(LD, 1).hasOneUse()) {
16566     // LD has only 1 chain use, so there are no indirect dependencies.
16567     if (!LD->isOperandOf(Chain.getNode()))
16568       return Result;
16569   } else
16570     return Result; // Fail.
16571 
16572   Result.first = MaskedBytes;
16573   Result.second = NotMaskTZ/8;
16574   return Result;
16575 }
16576 
16577 /// Check to see if IVal is something that provides a value as specified by
16578 /// MaskInfo. If so, replace the specified store with a narrower store of
16579 /// truncated IVal.
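///
/// As a sketch (little-endian layout assumed): for MaskInfo = {1 /*byte*/,
/// 2 /*byte shift*/} on an i32 IVal, if bits [16,24) are the only bits of
/// IVal that may be nonzero, the store can be replaced by an i8 store of
/// (trunc (srl IVal, 16)) at byte offset 2.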
16580 static SDValue
16581 ShrinkLoadReplaceStoreWithStore(const std::pair<unsigned, unsigned> &MaskInfo,
16582                                 SDValue IVal, StoreSDNode *St,
16583                                 DAGCombiner *DC) {
16584   unsigned NumBytes = MaskInfo.first;
16585   unsigned ByteShift = MaskInfo.second;
16586   SelectionDAG &DAG = DC->getDAG();
16587 
16588   // Check to see if IVal is all zeros in the part being masked in by the 'or'
16589   // that uses this.  If not, this is not a replacement.
16590   APInt Mask = ~APInt::getBitsSet(IVal.getValueSizeInBits(),
16591                                   ByteShift*8, (ByteShift+NumBytes)*8);
16592   if (!DAG.MaskedValueIsZero(IVal, Mask)) return SDValue();
16593 
16594   // Check that it is legal on the target to do this.  It is legal if the new
16595   // VT we're shrinking to (i8/i16/i32) is legal or we're still before type
16596   // legalization (and the target doesn't explicitly think this is a bad idea).
16597   MVT VT = MVT::getIntegerVT(NumBytes * 8);
16598   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
16599   if (!DC->isTypeLegal(VT))
16600     return SDValue();
16601   if (St->getMemOperand() &&
16602       !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
16603                               *St->getMemOperand()))
16604     return SDValue();
16605 
16606   // Okay, we can do this!  Replace the 'St' store with a store of IVal that is
16607   // shifted by ByteShift and truncated down to NumBytes.
16608   if (ByteShift) {
16609     SDLoc DL(IVal);
16610     IVal = DAG.getNode(ISD::SRL, DL, IVal.getValueType(), IVal,
16611                        DAG.getConstant(ByteShift*8, DL,
16612                                     DC->getShiftAmountTy(IVal.getValueType())));
16613   }
16614 
16615   // Figure out the offset for the store and the alignment of the access.
16616   unsigned StOffset;
16617   if (DAG.getDataLayout().isLittleEndian())
16618     StOffset = ByteShift;
16619   else
16620     StOffset = IVal.getValueType().getStoreSize() - ByteShift - NumBytes;
16621 
16622   SDValue Ptr = St->getBasePtr();
16623   if (StOffset) {
16624     SDLoc DL(IVal);
16625     Ptr = DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(StOffset), DL);
16626   }
16627 
16628   // Truncate down to the new size.
16629   IVal = DAG.getNode(ISD::TRUNCATE, SDLoc(IVal), VT, IVal);
16630 
16631   ++OpsNarrowed;
16632   return DAG
16633       .getStore(St->getChain(), SDLoc(St), IVal, Ptr,
16634                 St->getPointerInfo().getWithOffset(StOffset),
16635                 St->getOriginalAlign());
16636 }
16637 
16638 /// Look for a sequence of load / op / store where op is one of 'or', 'xor',
16639 /// or 'and' with an immediate. If 'op' is only touching some of the loaded bits, try
16640 /// narrowing the load and store if it would end up being a win for performance
16641 /// or code size.
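///
/// As a sketch (little-endian target and the %p name are assumptions):
///   (store (or (load i32 %p), 0x00A50000), %p)
/// may be narrowed to an i8 load from %p+2, an 'or' with 0xA5, and an i8
/// store back to %p+2, leaving the original wide load and store dead.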
16642 SDValue DAGCombiner::ReduceLoadOpStoreWidth(SDNode *N) {
16643   StoreSDNode *ST  = cast<StoreSDNode>(N);
16644   if (!ST->isSimple())
16645     return SDValue();
16646 
16647   SDValue Chain = ST->getChain();
16648   SDValue Value = ST->getValue();
16649   SDValue Ptr   = ST->getBasePtr();
16650   EVT VT = Value.getValueType();
16651 
16652   if (ST->isTruncatingStore() || VT.isVector() || !Value.hasOneUse())
16653     return SDValue();
16654 
16655   unsigned Opc = Value.getOpcode();
16656 
16657   // If this is "store (or X, Y), P" and X is "(and (load P), cst)", where cst
16658   // is a byte mask indicating a consecutive number of bytes, check to see if
16659   // Y is known to provide just those bytes.  If so, we try to replace the
16660   // load / or / store sequence with a single (narrower) store, which makes
16661   // the load dead.
16662   if (Opc == ISD::OR && EnableShrinkLoadReplaceStoreWithStore) {
16663     std::pair<unsigned, unsigned> MaskedLoad;
16664     MaskedLoad = CheckForMaskedLoad(Value.getOperand(0), Ptr, Chain);
16665     if (MaskedLoad.first)
16666       if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskedLoad,
16667                                                   Value.getOperand(1), ST,this))
16668         return NewST;
16669 
16670     // Or is commutative, so try swapping X and Y.
16671     MaskedLoad = CheckForMaskedLoad(Value.getOperand(1), Ptr, Chain);
16672     if (MaskedLoad.first)
16673       if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskedLoad,
16674                                                   Value.getOperand(0), ST,this))
16675         return NewST;
16676   }
16677 
16678   if (!EnableReduceLoadOpStoreWidth)
16679     return SDValue();
16680 
16681   if ((Opc != ISD::OR && Opc != ISD::XOR && Opc != ISD::AND) ||
16682       Value.getOperand(1).getOpcode() != ISD::Constant)
16683     return SDValue();
16684 
16685   SDValue N0 = Value.getOperand(0);
16686   if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
16687       Chain == SDValue(N0.getNode(), 1)) {
16688     LoadSDNode *LD = cast<LoadSDNode>(N0);
16689     if (LD->getBasePtr() != Ptr ||
16690         LD->getPointerInfo().getAddrSpace() !=
16691         ST->getPointerInfo().getAddrSpace())
16692       return SDValue();
16693 
16694     // Find the type to which to narrow the load / op / store.
16695     SDValue N1 = Value.getOperand(1);
16696     unsigned BitWidth = N1.getValueSizeInBits();
16697     APInt Imm = cast<ConstantSDNode>(N1)->getAPIntValue();
16698     if (Opc == ISD::AND)
16699       Imm ^= APInt::getAllOnesValue(BitWidth);
16700     if (Imm == 0 || Imm.isAllOnesValue())
16701       return SDValue();
16702     unsigned ShAmt = Imm.countTrailingZeros();
16703     unsigned MSB = BitWidth - Imm.countLeadingZeros() - 1;
16704     unsigned NewBW = NextPowerOf2(MSB - ShAmt);
16705     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), NewBW);
16706     // The narrowing should be profitable, the load/store operation should be
16707     // legal (or custom) and the store size should be equal to the NewVT width.
16708     while (NewBW < BitWidth &&
16709            (NewVT.getStoreSizeInBits() != NewBW ||
16710             !TLI.isOperationLegalOrCustom(Opc, NewVT) ||
16711             !TLI.isNarrowingProfitable(VT, NewVT))) {
16712       NewBW = NextPowerOf2(NewBW);
16713       NewVT = EVT::getIntegerVT(*DAG.getContext(), NewBW);
16714     }
16715     if (NewBW >= BitWidth)
16716       return SDValue();
16717 
16718     // If the changed lsb does not start at a NewBW-bit boundary,
16719     // start at the previous one.
16720     if (ShAmt % NewBW)
16721       ShAmt = (((ShAmt + NewBW - 1) / NewBW) * NewBW) - NewBW;
16722     APInt Mask = APInt::getBitsSet(BitWidth, ShAmt,
16723                                    std::min(BitWidth, ShAmt + NewBW));
16724     if ((Imm & Mask) == Imm) {
16725       APInt NewImm = (Imm & Mask).lshr(ShAmt).trunc(NewBW);
16726       if (Opc == ISD::AND)
16727         NewImm ^= APInt::getAllOnesValue(NewBW);
16728       uint64_t PtrOff = ShAmt / 8;
16729       // For big endian targets, we need to adjust the offset to the pointer to
16730       // load the correct bytes.
16731       if (DAG.getDataLayout().isBigEndian())
16732         PtrOff = (BitWidth + 7 - NewBW) / 8 - PtrOff;
16733 
16734       Align NewAlign = commonAlignment(LD->getAlign(), PtrOff);
16735       Type *NewVTTy = NewVT.getTypeForEVT(*DAG.getContext());
16736       if (NewAlign < DAG.getDataLayout().getABITypeAlign(NewVTTy))
16737         return SDValue();
16738 
16739       SDValue NewPtr =
16740           DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(PtrOff), SDLoc(LD));
16741       SDValue NewLD =
16742           DAG.getLoad(NewVT, SDLoc(N0), LD->getChain(), NewPtr,
16743                       LD->getPointerInfo().getWithOffset(PtrOff), NewAlign,
16744                       LD->getMemOperand()->getFlags(), LD->getAAInfo());
16745       SDValue NewVal = DAG.getNode(Opc, SDLoc(Value), NewVT, NewLD,
16746                                    DAG.getConstant(NewImm, SDLoc(Value),
16747                                                    NewVT));
16748       SDValue NewST =
16749           DAG.getStore(Chain, SDLoc(N), NewVal, NewPtr,
16750                        ST->getPointerInfo().getWithOffset(PtrOff), NewAlign);
16751 
16752       AddToWorklist(NewPtr.getNode());
16753       AddToWorklist(NewLD.getNode());
16754       AddToWorklist(NewVal.getNode());
16755       WorklistRemover DeadNodes(*this);
16756       DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), NewLD.getValue(1));
16757       ++OpsNarrowed;
16758       return NewST;
16759     }
16760   }
16761 
16762   return SDValue();
16763 }
16764 
16765 /// For a given floating point load / store pair, if the load value isn't used
16766 /// by any other operations, then consider transforming the pair to integer
16767 /// load / store operations if the target deems the transformation profitable.
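///
/// As a sketch: (store (f32 load %p), %q), where the loaded value has no
/// other users, may become (store (i32 load %p), %q) on targets that deem
/// the integer form cheaper (e.g. to avoid moving values through FP
/// registers).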
16768 SDValue DAGCombiner::TransformFPLoadStorePair(SDNode *N) {
16769   StoreSDNode *ST  = cast<StoreSDNode>(N);
16770   SDValue Value = ST->getValue();
16771   if (ISD::isNormalStore(ST) && ISD::isNormalLoad(Value.getNode()) &&
16772       Value.hasOneUse()) {
16773     LoadSDNode *LD = cast<LoadSDNode>(Value);
16774     EVT VT = LD->getMemoryVT();
16775     if (!VT.isFloatingPoint() ||
16776         VT != ST->getMemoryVT() ||
16777         LD->isNonTemporal() ||
16778         ST->isNonTemporal() ||
16779         LD->getPointerInfo().getAddrSpace() != 0 ||
16780         ST->getPointerInfo().getAddrSpace() != 0)
16781       return SDValue();
16782 
16783     TypeSize VTSize = VT.getSizeInBits();
16784 
16785     // We don't know the size of scalable types at compile time so we cannot
16786     // create an integer of the equivalent size.
16787     if (VTSize.isScalable())
16788       return SDValue();
16789 
16790     EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), VTSize.getFixedSize());
16791     if (!TLI.isOperationLegal(ISD::LOAD, IntVT) ||
16792         !TLI.isOperationLegal(ISD::STORE, IntVT) ||
16793         !TLI.isDesirableToTransformToIntegerOp(ISD::LOAD, VT) ||
16794         !TLI.isDesirableToTransformToIntegerOp(ISD::STORE, VT))
16795       return SDValue();
16796 
16797     Align LDAlign = LD->getAlign();
16798     Align STAlign = ST->getAlign();
16799     Type *IntVTTy = IntVT.getTypeForEVT(*DAG.getContext());
16800     Align ABIAlign = DAG.getDataLayout().getABITypeAlign(IntVTTy);
16801     if (LDAlign < ABIAlign || STAlign < ABIAlign)
16802       return SDValue();
16803 
16804     SDValue NewLD =
16805         DAG.getLoad(IntVT, SDLoc(Value), LD->getChain(), LD->getBasePtr(),
16806                     LD->getPointerInfo(), LDAlign);
16807 
16808     SDValue NewST =
16809         DAG.getStore(ST->getChain(), SDLoc(N), NewLD, ST->getBasePtr(),
16810                      ST->getPointerInfo(), STAlign);
16811 
16812     AddToWorklist(NewLD.getNode());
16813     AddToWorklist(NewST.getNode());
16814     WorklistRemover DeadNodes(*this);
16815     DAG.ReplaceAllUsesOfValueWith(Value.getValue(1), NewLD.getValue(1));
16816     ++LdStFP2Int;
16817     return NewST;
16818   }
16819 
16820   return SDValue();
16821 }
16822 
16823 // This is a helper function for visitMUL to check the profitability
16824 // of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
16825 // MulNode is the original multiply, AddNode is (add x, c1),
16826 // and ConstNode is c2.
16827 //
16828 // If the (add x, c1) has multiple uses, we could increase
16829 // the number of adds if we make this transformation.
16830 // It would only be worth doing this if we can remove a
16831 // multiply in the process. Check for that here.
16832 // To illustrate:
16833 //     (A + c1) * c3
16834 //     (A + c2) * c3
16835 // We're checking for cases where we have common "c3 * A" expressions.
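// As a concrete sketch: for (A + 5) * C and (A + 7) * C, distributing both
// multiplies exposes a common "A * C" subexpression, so a multiply is saved
// even though the adds are duplicated.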
16836 bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode,
16837                                               SDValue &AddNode,
16838                                               SDValue &ConstNode) {
16839   APInt Val;
16840 
16841   // If the add only has one use, this would be OK to do.
16842   if (AddNode.getNode()->hasOneUse())
16843     return true;
16844 
16845   // Walk all the users of the constant with which we're multiplying.
16846   for (SDNode *Use : ConstNode->uses()) {
16847     if (Use == MulNode) // This use is the one we're on right now. Skip it.
16848       continue;
16849 
16850     if (Use->getOpcode() == ISD::MUL) { // We have another multiply use.
16851       SDNode *OtherOp;
16852       SDNode *MulVar = AddNode.getOperand(0).getNode();
16853 
16854       // OtherOp is what we're multiplying against the constant.
16855       if (Use->getOperand(0) == ConstNode)
16856         OtherOp = Use->getOperand(1).getNode();
16857       else
16858         OtherOp = Use->getOperand(0).getNode();
16859 
16860       // Check to see if the multiply is with the same operand as our "add".
16861       //
16862       //     ConstNode  = CONST
16863       //     Use = ConstNode * A  <-- visiting Use. OtherOp is A.
16864       //     ...
16865       //     AddNode  = (A + c1)  <-- MulVar is A.
16866       //         = AddNode * ConstNode   <-- current visiting instruction.
16867       //
16868       // If we make this transformation, we will have a common
16869       // multiply (ConstNode * A) that we can save.
16870       if (OtherOp == MulVar)
16871         return true;
16872 
16873       // Now check to see if a future expansion will give us a common
16874       // multiply.
16875       //
16876       //     ConstNode  = CONST
16877       //     AddNode    = (A + c1)
16878       //     ...   = AddNode * ConstNode <-- current visiting instruction.
16879       //     ...
16880       //     OtherOp = (A + c2)
16881       //     Use     = OtherOp * ConstNode <-- visiting Use.
16882       //
16883       // If we make this transformation, we will have a common
16884       // multiply (CONST * A) after we also do the same transformation
16885       // to the "Use" instruction.
16886       if (OtherOp->getOpcode() == ISD::ADD &&
16887           DAG.isConstantIntBuildVectorOrConstantInt(OtherOp->getOperand(1)) &&
16888           OtherOp->getOperand(0).getNode() == MulVar)
16889         return true;
16890     }
16891   }
16892 
16893   // Didn't find a case where this would be profitable.
16894   return false;
16895 }
16896 
16897 SDValue DAGCombiner::getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
16898                                          unsigned NumStores) {
16899   SmallVector<SDValue, 8> Chains;
16900   SmallPtrSet<const SDNode *, 8> Visited;
16901   SDLoc StoreDL(StoreNodes[0].MemNode);
16902 
16903   for (unsigned i = 0; i < NumStores; ++i) {
16904     Visited.insert(StoreNodes[i].MemNode);
16905   }
16906 
16907   // Don't include chains that are other candidate stores or repeated nodes.
16908   for (unsigned i = 0; i < NumStores; ++i) {
16909     if (Visited.insert(StoreNodes[i].MemNode->getChain().getNode()).second)
16910       Chains.push_back(StoreNodes[i].MemNode->getChain());
16911   }
16912 
16913   assert(Chains.size() > 0 && "Chain should have generated a chain");
16914   return DAG.getTokenFactor(StoreDL, Chains);
16915 }
16916 
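// A sketch of the constant-merging path below (little-endian layout and
// the %p name assumed): four adjacent i8 stores of 1, 2, 3 and 4 to
// %p .. %p+3 may be combined into a single i32 store of 0x04030201.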
16917 bool DAGCombiner::mergeStoresOfConstantsOrVecElts(
16918     SmallVectorImpl<MemOpLink> &StoreNodes, EVT MemVT, unsigned NumStores,
16919     bool IsConstantSrc, bool UseVector, bool UseTrunc) {
16920   // Make sure we have something to merge.
16921   if (NumStores < 2)
16922     return false;
16923 
16924   assert((!UseTrunc || !UseVector) &&
16925          "This optimization cannot emit a vector truncating store");
16926 
16927   // The latest Node in the DAG.
16928   SDLoc DL(StoreNodes[0].MemNode);
16929 
16930   TypeSize ElementSizeBits = MemVT.getStoreSizeInBits();
16931   unsigned SizeInBits = NumStores * ElementSizeBits;
16932   unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
16933 
16934   EVT StoreTy;
16935   if (UseVector) {
16936     unsigned Elts = NumStores * NumMemElts;
16937     // Get the type for the merged vector store.
16938     StoreTy = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
16939   } else
16940     StoreTy = EVT::getIntegerVT(*DAG.getContext(), SizeInBits);
16941 
16942   SDValue StoredVal;
16943   if (UseVector) {
16944     if (IsConstantSrc) {
16945       SmallVector<SDValue, 8> BuildVector;
16946       for (unsigned I = 0; I != NumStores; ++I) {
16947         StoreSDNode *St = cast<StoreSDNode>(StoreNodes[I].MemNode);
16948         SDValue Val = St->getValue();
16949         // If constant is of the wrong type, convert it now.
16950         if (MemVT != Val.getValueType()) {
16951           Val = peekThroughBitcasts(Val);
16952           // Deal with constants of wrong size.
16953           if (ElementSizeBits != Val.getValueSizeInBits()) {
16954             EVT IntMemVT =
16955                 EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits());
16956             if (isa<ConstantFPSDNode>(Val)) {
16957               // Not clear how to truncate FP values.
16958               return false;
16959             } else if (auto *C = dyn_cast<ConstantSDNode>(Val))
16960               Val = DAG.getConstant(C->getAPIntValue()
16961                                         .zextOrTrunc(Val.getValueSizeInBits())
16962                                         .zextOrTrunc(ElementSizeBits),
16963                                     SDLoc(C), IntMemVT);
16964           }
16965           // Make sure the correctly-sized value is bitcast to the memory type.
16966           Val = DAG.getBitcast(MemVT, Val);
16967         }
16968         BuildVector.push_back(Val);
16969       }
16970       StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS
16971                                                : ISD::BUILD_VECTOR,
16972                               DL, StoreTy, BuildVector);
16973     } else {
16974       SmallVector<SDValue, 8> Ops;
16975       for (unsigned i = 0; i < NumStores; ++i) {
16976         StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
16977         SDValue Val = peekThroughBitcasts(St->getValue());
16978         // All operands of BUILD_VECTOR / CONCAT_VECTOR must be of
16979         // type MemVT. If the underlying value is not the correct
16980         // type, but it is an extraction of an appropriate vector we
16981         // can recast Val to be of the correct type. This may require
16982         // converting between EXTRACT_VECTOR_ELT and
16983         // EXTRACT_SUBVECTOR.
16984         if ((MemVT != Val.getValueType()) &&
16985             (Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
16986              Val.getOpcode() == ISD::EXTRACT_SUBVECTOR)) {
16987           EVT MemVTScalarTy = MemVT.getScalarType();
16988           // We may need to add a bitcast here to get types to line up.
16989           if (MemVTScalarTy != Val.getValueType().getScalarType()) {
16990             Val = DAG.getBitcast(MemVT, Val);
16991           } else {
16992             unsigned OpC = MemVT.isVector() ? ISD::EXTRACT_SUBVECTOR
16993                                             : ISD::EXTRACT_VECTOR_ELT;
16994             SDValue Vec = Val.getOperand(0);
16995             SDValue Idx = Val.getOperand(1);
16996             Val = DAG.getNode(OpC, SDLoc(Val), MemVT, Vec, Idx);
16997           }
16998         }
16999         Ops.push_back(Val);
17000       }
17001 
17002       // Build the extracted vector elements back into a vector.
17003       StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS
17004                                                : ISD::BUILD_VECTOR,
17005                               DL, StoreTy, Ops);
17006     }
17007   } else {
17008     // We should always use a vector store when merging extracted vector
17009     // elements, so this path implies a store of constants.
17010     assert(IsConstantSrc && "Merged vector elements should use vector store");
17011 
17012     APInt StoreInt(SizeInBits, 0);
17013 
17014     // Construct a single integer constant which is made of the smaller
17015     // constant inputs.
17016     bool IsLE = DAG.getDataLayout().isLittleEndian();
17017     for (unsigned i = 0; i < NumStores; ++i) {
17018       unsigned Idx = IsLE ? (NumStores - 1 - i) : i;
17019       StoreSDNode *St  = cast<StoreSDNode>(StoreNodes[Idx].MemNode);
17020 
17021       SDValue Val = St->getValue();
17022       Val = peekThroughBitcasts(Val);
17023       StoreInt <<= ElementSizeBits;
17024       if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val)) {
17025         StoreInt |= C->getAPIntValue()
17026                         .zextOrTrunc(ElementSizeBits)
17027                         .zextOrTrunc(SizeInBits);
17028       } else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Val)) {
17029         StoreInt |= C->getValueAPF()
17030                         .bitcastToAPInt()
17031                         .zextOrTrunc(ElementSizeBits)
17032                         .zextOrTrunc(SizeInBits);
17033         // If fp truncation is necessary, give up for now.
17034         if (MemVT.getSizeInBits() != ElementSizeBits)
17035           return false;
17036       } else {
17037         llvm_unreachable("Invalid constant element type");
17038       }
17039     }
17040 
17041     // Create the new Load and Store operations.
17042     StoredVal = DAG.getConstant(StoreInt, DL, StoreTy);
17043   }
17044 
17045   LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
17046   SDValue NewChain = getMergeStoreChains(StoreNodes, NumStores);
17047 
17048   // Make sure we use a truncating store if that is necessary for legality.
17049   SDValue NewStore;
17050   if (!UseTrunc) {
17051     NewStore =
17052         DAG.getStore(NewChain, DL, StoredVal, FirstInChain->getBasePtr(),
17053                      FirstInChain->getPointerInfo(), FirstInChain->getAlign());
17054   } else { // Must be realized as a trunc store
17055     EVT LegalizedStoredValTy =
17056         TLI.getTypeToTransformTo(*DAG.getContext(), StoredVal.getValueType());
17057     unsigned LegalizedStoreSize = LegalizedStoredValTy.getSizeInBits();
17058     ConstantSDNode *C = cast<ConstantSDNode>(StoredVal);
17059     SDValue ExtendedStoreVal =
17060         DAG.getConstant(C->getAPIntValue().zextOrTrunc(LegalizedStoreSize), DL,
17061                         LegalizedStoredValTy);
17062     NewStore = DAG.getTruncStore(
17063         NewChain, DL, ExtendedStoreVal, FirstInChain->getBasePtr(),
17064         FirstInChain->getPointerInfo(), StoredVal.getValueType() /*TVT*/,
17065         FirstInChain->getAlign(), FirstInChain->getMemOperand()->getFlags());
17066   }
17067 
17068   // Replace all merged stores with the new store.
17069   for (unsigned i = 0; i < NumStores; ++i)
17070     CombineTo(StoreNodes[i].MemNode, NewStore);
17071 
17072   AddToWorklist(NewChain.getNode());
17073   return true;
17074 }
17075 
17076 void DAGCombiner::getStoreMergeCandidates(
17077     StoreSDNode *St, SmallVectorImpl<MemOpLink> &StoreNodes,
17078     SDNode *&RootNode) {
17079   // This holds the base pointer, index, and the offset in bytes from the base
17080   // pointer. We must have a base and an offset. Do not handle stores to undef
17081   // base pointers.
17082   BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
17083   if (!BasePtr.getBase().getNode() || BasePtr.getBase().isUndef())
17084     return;
17085 
17086   SDValue Val = peekThroughBitcasts(St->getValue());
17087   StoreSource StoreSrc = getStoreSource(Val);
17088   assert(StoreSrc != StoreSource::Unknown && "Expected known source for store");
17089 
17090   // Match on the load's base pointer if relevant.
17091   EVT MemVT = St->getMemoryVT();
17092   BaseIndexOffset LBasePtr;
17093   EVT LoadVT;
17094   if (StoreSrc == StoreSource::Load) {
17095     auto *Ld = cast<LoadSDNode>(Val);
17096     LBasePtr = BaseIndexOffset::match(Ld, DAG);
17097     LoadVT = Ld->getMemoryVT();
17098     // Load and store should be the same type.
17099     if (MemVT != LoadVT)
17100       return;
17101     // Loads must only have one use.
17102     if (!Ld->hasNUsesOfValue(1, 0))
17103       return;
17104     // The memory operands must not be volatile/indexed/atomic.
17105     // TODO: May be able to relax for unordered atomics (see D66309)
17106     if (!Ld->isSimple() || Ld->isIndexed())
17107       return;
17108   }
17109   auto CandidateMatch = [&](StoreSDNode *Other, BaseIndexOffset &Ptr,
17110                             int64_t &Offset) -> bool {
17111     // The memory operands must not be volatile/indexed/atomic.
17112     // TODO: May be able to relax for unordered atomics (see D66309)
17113     if (!Other->isSimple() || Other->isIndexed())
17114       return false;
17115     // Don't mix temporal stores with non-temporal stores.
17116     if (St->isNonTemporal() != Other->isNonTemporal())
17117       return false;
17118     SDValue OtherBC = peekThroughBitcasts(Other->getValue());
17119     // Allow merging constants of different types as integers.
17120     bool NoTypeMatch = (MemVT.isInteger()) ? !MemVT.bitsEq(Other->getMemoryVT())
17121                                            : Other->getMemoryVT() != MemVT;
17122     switch (StoreSrc) {
17123     case StoreSource::Load: {
17124       if (NoTypeMatch)
17125         return false;
17126       // The Load's Base Ptr must also match.
17127       auto *OtherLd = dyn_cast<LoadSDNode>(OtherBC);
17128       if (!OtherLd)
17129         return false;
17130       BaseIndexOffset LPtr = BaseIndexOffset::match(OtherLd, DAG);
17131       if (LoadVT != OtherLd->getMemoryVT())
17132         return false;
17133       // Loads must only have one use.
17134       if (!OtherLd->hasNUsesOfValue(1, 0))
17135         return false;
17136       // The memory operands must not be volatile/indexed/atomic.
17137       // TODO: May be able to relax for unordered atomics (see D66309)
17138       if (!OtherLd->isSimple() || OtherLd->isIndexed())
17139         return false;
17140       // Don't mix temporal loads with non-temporal loads.
17141       if (cast<LoadSDNode>(Val)->isNonTemporal() != OtherLd->isNonTemporal())
17142         return false;
17143       if (!(LBasePtr.equalBaseIndex(LPtr, DAG)))
17144         return false;
17145       break;
17146     }
17147     case StoreSource::Constant:
17148       if (NoTypeMatch)
17149         return false;
17150       if (!isIntOrFPConstant(OtherBC))
17151         return false;
17152       break;
17153     case StoreSource::Extract:
17154       // Do not merge truncated stores here.
17155       if (Other->isTruncatingStore())
17156         return false;
17157       if (!MemVT.bitsEq(OtherBC.getValueType()))
17158         return false;
17159       if (OtherBC.getOpcode() != ISD::EXTRACT_VECTOR_ELT &&
17160           OtherBC.getOpcode() != ISD::EXTRACT_SUBVECTOR)
17161         return false;
17162       break;
17163     default:
17164       llvm_unreachable("Unhandled store source for merging");
17165     }
17166     Ptr = BaseIndexOffset::match(Other, DAG);
17167     return (BasePtr.equalBaseIndex(Ptr, DAG, Offset));
17168   };
17169 
17170   // Check whether this pair of StoreNode and RootNode has already bailed out
17171   // of the dependence check more times than the limit allows.
17172   auto OverLimitInDependenceCheck = [&](SDNode *StoreNode,
17173                                         SDNode *RootNode) -> bool {
17174     auto RootCount = StoreRootCountMap.find(StoreNode);
17175     return RootCount != StoreRootCountMap.end() &&
17176            RootCount->second.first == RootNode &&
17177            RootCount->second.second > StoreMergeDependenceLimit;
17178   };
17179 
17180   auto TryToAddCandidate = [&](SDNode::use_iterator UseIter) {
17181     // This must be a chain use.
17182     if (UseIter.getOperandNo() != 0)
17183       return;
17184     if (auto *OtherStore = dyn_cast<StoreSDNode>(*UseIter)) {
17185       BaseIndexOffset Ptr;
17186       int64_t PtrDiff;
17187       if (CandidateMatch(OtherStore, Ptr, PtrDiff) &&
17188           !OverLimitInDependenceCheck(OtherStore, RootNode))
17189         StoreNodes.push_back(MemOpLink(OtherStore, PtrDiff));
17190     }
17191   };
17192 
17193   // We are looking for a root node that is an ancestor to all mergeable
17194   // stores. We search up through a load, to our root, and then down
17195   // through all children. For instance, we will find Store{1,2,3} if
17196   // St is Store1, Store2, or Store3 where the root is not a load,
17197   // which is always true for nonvolatile ops. TODO: Expand
17198   // the search to find all valid candidates through multiple layers of loads.
17199   //
17200   // Root
17201   // |-------|-------|
17202   // Load    Load    Store3
17203   // |       |
17204   // Store1   Store2
17205   //
17206   // FIXME: We should be able to climb and
17207   // descend TokenFactors to find candidates as well.
17208 
17209   RootNode = St->getChain().getNode();
17210 
17211   unsigned NumNodesExplored = 0;
17212   const unsigned MaxSearchNodes = 1024;
17213   if (auto *Ldn = dyn_cast<LoadSDNode>(RootNode)) {
17214     RootNode = Ldn->getChain().getNode();
17215     for (auto I = RootNode->use_begin(), E = RootNode->use_end();
17216          I != E && NumNodesExplored < MaxSearchNodes; ++I, ++NumNodesExplored) {
17217       if (I.getOperandNo() == 0 && isa<LoadSDNode>(*I)) { // walk down chain
17218         for (auto I2 = (*I)->use_begin(), E2 = (*I)->use_end(); I2 != E2; ++I2)
17219           TryToAddCandidate(I2);
17220       }
17221     }
17222   } else {
17223     for (auto I = RootNode->use_begin(), E = RootNode->use_end();
17224          I != E && NumNodesExplored < MaxSearchNodes; ++I, ++NumNodesExplored)
17225       TryToAddCandidate(I);
17226   }
17227 }
17228 
17229 // We need to check that merging these stores does not cause a loop in
17230 // the DAG. Any store candidate may depend on another candidate
17231 // indirectly through its operand (we already consider dependencies
17232 // through the chain). Check in parallel by searching up from
17233 // non-chain operands of candidates.
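// As a sketch of such a cycle (hypothetical nodes): if candidate St2 stores
// a value loaded via a chain that follows candidate St1,
//   St1 = store chain, x, %p
//   L   = load St1:1, %q
//   St2 = store chain', L, %p+4
// then a single merged store would need to execute both before L (for St1)
// and after L (for St2), so the merge must be rejected.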
17234 bool DAGCombiner::checkMergeStoreCandidatesForDependencies(
17235     SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
17236     SDNode *RootNode) {
17237   // FIXME: We should be able to truncate a full search of
17238   // predecessors by doing a BFS and keeping tabs on the originating
17239   // stores from which worklist nodes come, in a similar way to
17240   // TokenFactor simplification.
17241 
17242   SmallPtrSet<const SDNode *, 32> Visited;
17243   SmallVector<const SDNode *, 8> Worklist;
17244 
17245   // RootNode is a predecessor to all candidates, so we need not search
17246   // past it. Add RootNode (peeking through TokenFactors). Do not count
17247   // these towards the size check.
17248 
17249   Worklist.push_back(RootNode);
17250   while (!Worklist.empty()) {
17251     auto N = Worklist.pop_back_val();
17252     if (!Visited.insert(N).second)
17253       continue; // Already present in Visited.
17254     if (N->getOpcode() == ISD::TokenFactor) {
17255       for (SDValue Op : N->ops())
17256         Worklist.push_back(Op.getNode());
17257     }
17258   }
17259 
17260   // Don't count pruning nodes towards max.
17261   unsigned int Max = 1024 + Visited.size();
17262   // Search Ops of store candidates.
17263   for (unsigned i = 0; i < NumStores; ++i) {
17264     SDNode *N = StoreNodes[i].MemNode;
17265     // Of the 4 Store Operands:
17266     //   * Chain (Op 0) -> We have already considered these
17267     //                    in candidate selection and can be
17268     //                    safely ignored
17269     //   * Value (Op 1) -> Cycles may happen (e.g. through load chains)
17270     //   * Address (Op 2) -> Merged addresses may only vary by a fixed constant,
17271     //                       but aren't necessarily from the same base node, so
17272     //                       cycles are possible (e.g. via an indexed store).
17273     //   * (Op 3) -> Represents the pre or post-indexing offset (or undef for
17274     //               non-indexed stores). Not constant on all targets (e.g. ARM)
17275     //               and so can participate in a cycle.
17276     for (unsigned j = 1; j < N->getNumOperands(); ++j)
17277       Worklist.push_back(N->getOperand(j).getNode());
17278   }
17279   // Search through DAG. We can stop early if we find a store node.
17280   for (unsigned i = 0; i < NumStores; ++i)
17281     if (SDNode::hasPredecessorHelper(StoreNodes[i].MemNode, Visited, Worklist,
17282                                      Max)) {
17283       // If the search bails out, record the StoreNode and RootNode in the
17284       // StoreRootCountMap. If we have seen the pair more times than the limit,
17285       // we won't add the StoreNode into the StoreNodes set again.
17286       if (Visited.size() >= Max) {
17287         auto &RootCount = StoreRootCountMap[StoreNodes[i].MemNode];
17288         if (RootCount.first == RootNode)
17289           RootCount.second++;
17290         else
17291           RootCount = {RootNode, 1};
17292       }
17293       return false;
17294     }
17295   return true;
17296 }
17297 
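// Returns the length of the longest run of candidate stores whose offsets
// increase by exactly ElementSizeBytes from the first store. For example
// (a sketch), offsets {0, 4, 8, 16} with 4-byte elements yield a run of 3.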
17298 unsigned
17299 DAGCombiner::getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
17300                                   int64_t ElementSizeBytes) const {
17301   while (true) {
17302     // Find a store past the width of the first store.
17303     size_t StartIdx = 0;
17304     while ((StartIdx + 1 < StoreNodes.size()) &&
17305            StoreNodes[StartIdx].OffsetFromBase + ElementSizeBytes !=
17306               StoreNodes[StartIdx + 1].OffsetFromBase)
17307       ++StartIdx;
17308 
17309     // Bail if we don't have enough candidates to merge.
17310     if (StartIdx + 1 >= StoreNodes.size())
17311       return 0;
17312 
17313     // Trim stores that overlapped with the first store.
17314     if (StartIdx)
17315       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + StartIdx);
17316 
17317     // Scan the memory operations on the chain and find the first
17318     // non-consecutive store memory address.
17319     unsigned NumConsecutiveStores = 1;
17320     int64_t StartAddress = StoreNodes[0].OffsetFromBase;
17321     // Check that the addresses are consecutive starting from the second
17322     // element in the list of stores.
17323     for (unsigned i = 1, e = StoreNodes.size(); i < e; ++i) {
17324       int64_t CurrAddress = StoreNodes[i].OffsetFromBase;
17325       if (CurrAddress - StartAddress != (ElementSizeBytes * i))
17326         break;
17327       NumConsecutiveStores = i + 1;
17328     }
17329     if (NumConsecutiveStores > 1)
17330       return NumConsecutiveStores;
17331 
17332     // There are no consecutive stores at the start of the list.
17333     // Remove the first store and try again.
17334     StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 1);
17335   }
17336 }
17337 
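// As a sketch of the loop below: eight consecutive i8 stores of constant
// zero may merge into a single i64 store of 0 where i64 stores are legal
// and fast; if only i32 qualifies, the same run may instead merge as two
// i32 stores over two iterations of the outer loop.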
17338 bool DAGCombiner::tryStoreMergeOfConstants(
17339     SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores,
17340     EVT MemVT, SDNode *RootNode, bool AllowVectors) {
17341   LLVMContext &Context = *DAG.getContext();
17342   const DataLayout &DL = DAG.getDataLayout();
17343   int64_t ElementSizeBytes = MemVT.getStoreSize();
17344   unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
17345   bool MadeChange = false;
17346 
17347   // Store the constants into memory as one consecutive store.
17348   while (NumConsecutiveStores >= 2) {
17349     LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
17350     unsigned FirstStoreAS = FirstInChain->getAddressSpace();
17351     unsigned FirstStoreAlign = FirstInChain->getAlignment();
17352     unsigned LastLegalType = 1;
17353     unsigned LastLegalVectorType = 1;
17354     bool LastIntegerTrunc = false;
17355     bool NonZero = false;
17356     unsigned FirstZeroAfterNonZero = NumConsecutiveStores;
17357     for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
17358       StoreSDNode *ST = cast<StoreSDNode>(StoreNodes[i].MemNode);
17359       SDValue StoredVal = ST->getValue();
17360       bool IsElementZero = false;
17361       if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(StoredVal))
17362         IsElementZero = C->isNullValue();
17363       else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(StoredVal))
17364         IsElementZero = C->getConstantFPValue()->isNullValue();
17365       if (IsElementZero) {
17366         if (NonZero && FirstZeroAfterNonZero == NumConsecutiveStores)
17367           FirstZeroAfterNonZero = i;
17368       }
17369       NonZero |= !IsElementZero;
17370 
17371       // Find a legal type for the constant store.
17372       unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
17373       EVT StoreTy = EVT::getIntegerVT(Context, SizeInBits);
17374       bool IsFast = false;
17375 
17376       // Break early when size is too large to be legal.
17377       if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
17378         break;
17379 
17380       if (TLI.isTypeLegal(StoreTy) &&
17381           TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) &&
17382           TLI.allowsMemoryAccess(Context, DL, StoreTy,
17383                                  *FirstInChain->getMemOperand(), &IsFast) &&
17384           IsFast) {
17385         LastIntegerTrunc = false;
17386         LastLegalType = i + 1;
17387         // Or check whether a truncstore is legal.
17388       } else if (TLI.getTypeAction(Context, StoreTy) ==
17389                  TargetLowering::TypePromoteInteger) {
17390         EVT LegalizedStoredValTy =
17391             TLI.getTypeToTransformTo(Context, StoredVal.getValueType());
17392         if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) &&
17393             TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy, DAG) &&
17394             TLI.allowsMemoryAccess(Context, DL, StoreTy,
17395                                    *FirstInChain->getMemOperand(), &IsFast) &&
17396             IsFast) {
17397           LastIntegerTrunc = true;
17398           LastLegalType = i + 1;
17399         }
17400       }
17401 
17402       // We only use vectors if the constant is known to be zero or the
17403       // target allows it and the function is not marked with the
17404       // noimplicitfloat attribute.
17405       if ((!NonZero ||
17406            TLI.storeOfVectorConstantIsCheap(MemVT, i + 1, FirstStoreAS)) &&
17407           AllowVectors) {
17408         // Find a legal type for the vector store.
17409         unsigned Elts = (i + 1) * NumMemElts;
17410         EVT Ty = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
17411         if (TLI.isTypeLegal(Ty) && TLI.isTypeLegal(MemVT) &&
17412             TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG) &&
17413             TLI.allowsMemoryAccess(Context, DL, Ty,
17414                                    *FirstInChain->getMemOperand(), &IsFast) &&
17415             IsFast)
17416           LastLegalVectorType = i + 1;
17417       }
17418     }
17419 
17420     bool UseVector = (LastLegalVectorType > LastLegalType) && AllowVectors;
17421     unsigned NumElem = (UseVector) ? LastLegalVectorType : LastLegalType;
17422     bool UseTrunc = LastIntegerTrunc && !UseVector;
17423 
17424     // Check if we found a legal integer type that creates a meaningful
17425     // merge.
17426     if (NumElem < 2) {
17427       // We know that candidate stores are in order and of correct
17428       // shape. While there is no mergeable sequence from the
17429       // beginning one may start later in the sequence. The only
17430       // reason a merge of size N could have failed where another of
17431       // the same size would not have, is if the alignment has
17432       // improved or we've dropped a non-zero value. Drop as many
17433       // candidates as we can here.
17434       unsigned NumSkip = 1;
17435       while ((NumSkip < NumConsecutiveStores) &&
17436              (NumSkip < FirstZeroAfterNonZero) &&
17437              (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign))
17438         NumSkip++;
17439 
17440       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
17441       NumConsecutiveStores -= NumSkip;
17442       continue;
17443     }
17444 
17445     // Check that we can merge these candidates without causing a cycle.
17446     if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem,
17447                                                   RootNode)) {
17448       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
17449       NumConsecutiveStores -= NumElem;
17450       continue;
17451     }
17452 
17453     MadeChange |= mergeStoresOfConstantsOrVecElts(StoreNodes, MemVT, NumElem,
17454                                                   /*IsConstantSrc*/ true,
17455                                                   UseVector, UseTrunc);
17456 
17457     // Remove merged stores for next iteration.
17458     StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
17459     NumConsecutiveStores -= NumElem;
17460   }
17461   return MadeChange;
17462 }
17463 
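// A sketch of the extract-merging path below: stores of (extractelt V, 0)
// and (extractelt V, 1) to adjacent addresses may be rebuilt into a single
// store of a two-element vector, provided the wider vector store is legal
// and the target reports the access as fast.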
17464 bool DAGCombiner::tryStoreMergeOfExtracts(
17465     SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores,
17466     EVT MemVT, SDNode *RootNode) {
17467   LLVMContext &Context = *DAG.getContext();
17468   const DataLayout &DL = DAG.getDataLayout();
17469   unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
17470   bool MadeChange = false;
17471 
17472   // Loop on Consecutive Stores on success.
17473   while (NumConsecutiveStores >= 2) {
17474     LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
17475     unsigned FirstStoreAS = FirstInChain->getAddressSpace();
17476     unsigned FirstStoreAlign = FirstInChain->getAlignment();
17477     unsigned NumStoresToMerge = 1;
17478     for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
17479       // Find a legal type for the vector store.
17480       unsigned Elts = (i + 1) * NumMemElts;
17481       EVT Ty = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
17482       bool IsFast = false;
17483 
17484       // Break early when size is too large to be legal.
17485       if (Ty.getSizeInBits() > MaximumLegalStoreInBits)
17486         break;
17487 
17488       if (TLI.isTypeLegal(Ty) && TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG) &&
17489           TLI.allowsMemoryAccess(Context, DL, Ty,
17490                                  *FirstInChain->getMemOperand(), &IsFast) &&
17491           IsFast)
17492         NumStoresToMerge = i + 1;
17493     }
17494 
17495     // Check if we found a legal vector type creating a meaningful
17496     // merge.
17497     if (NumStoresToMerge < 2) {
17498       // We know that candidate stores are in order and of correct
17499       // shape. While there is no mergeable sequence from the
17500       // beginning one may start later in the sequence. The only
17501       // reason a merge of size N could have failed where another of
17502       // the same size would not have, is if the alignment has
17503       // improved. Drop as many candidates as we can here.
17504       unsigned NumSkip = 1;
17505       while ((NumSkip < NumConsecutiveStores) &&
17506              (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign))
17507         NumSkip++;
17508 
17509       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
17510       NumConsecutiveStores -= NumSkip;
17511       continue;
17512     }
17513 
17514     // Check that we can merge these candidates without causing a cycle.
17515     if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumStoresToMerge,
17516                                                   RootNode)) {
17517       StoreNodes.erase(StoreNodes.begin(),
17518                        StoreNodes.begin() + NumStoresToMerge);
17519       NumConsecutiveStores -= NumStoresToMerge;
17520       continue;
17521     }
17522 
17523     MadeChange |= mergeStoresOfConstantsOrVecElts(
17524         StoreNodes, MemVT, NumStoresToMerge, /*IsConstantSrc*/ false,
17525         /*UseVector*/ true, /*UseTrunc*/ false);
17526 
17527     StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumStoresToMerge);
17528     NumConsecutiveStores -= NumStoresToMerge;
17529   }
17530   return MadeChange;
17531 }
17532 
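// A sketch of the load/store-merging path below: a memcpy-like pattern of
// four consecutive i8 loads from one base feeding four consecutive i8
// stores to another may become a single i32 load plus a single i32 store,
// subject to the chain, legality, alignment, and fast-access checks below.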
17533 bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
17534                                        unsigned NumConsecutiveStores, EVT MemVT,
17535                                        SDNode *RootNode, bool AllowVectors,
17536                                        bool IsNonTemporalStore,
17537                                        bool IsNonTemporalLoad) {
17538   LLVMContext &Context = *DAG.getContext();
17539   const DataLayout &DL = DAG.getDataLayout();
17540   int64_t ElementSizeBytes = MemVT.getStoreSize();
17541   unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
17542   bool MadeChange = false;
17543 
17544   // Look for load nodes which are used by the stored values.
17545   SmallVector<MemOpLink, 8> LoadNodes;
17546 
17547   // Find acceptable loads. Loads must have the same chain (token factor),
17548   // must not be zext, volatile, or indexed, and they must be consecutive.
17549   BaseIndexOffset LdBasePtr;
17550 
17551   for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
17552     StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
17553     SDValue Val = peekThroughBitcasts(St->getValue());
17554     LoadSDNode *Ld = cast<LoadSDNode>(Val);
17555 
17556     BaseIndexOffset LdPtr = BaseIndexOffset::match(Ld, DAG);
17557     // If this is not the first ptr that we check.
17558     int64_t LdOffset = 0;
17559     if (LdBasePtr.getBase().getNode()) {
17560       // The base ptr must be the same.
17561       if (!LdBasePtr.equalBaseIndex(LdPtr, DAG, LdOffset))
17562         break;
17563     } else {
17564       // Check that all other base pointers are the same as this one.
17565       LdBasePtr = LdPtr;
17566     }
17567 
17568     // We found a potential memory operand to merge.
17569     LoadNodes.push_back(MemOpLink(Ld, LdOffset));
17570   }
17571 
17572   while (NumConsecutiveStores >= 2 && LoadNodes.size() >= 2) {
17573     Align RequiredAlignment;
17574     bool NeedRotate = false;
17575     if (LoadNodes.size() == 2) {
17576       // If we have load/store pair instructions and we only have two values,
17577       // don't bother merging.
17578       if (TLI.hasPairedLoad(MemVT, RequiredAlignment) &&
17579           StoreNodes[0].MemNode->getAlign() >= RequiredAlignment) {
17580         StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 2);
17581         LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + 2);
17582         break;
17583       }
17584       // If the loads are reversed, see if we can rotate the halves into place.
17585       int64_t Offset0 = LoadNodes[0].OffsetFromBase;
17586       int64_t Offset1 = LoadNodes[1].OffsetFromBase;
17587       EVT PairVT = EVT::getIntegerVT(Context, ElementSizeBytes * 8 * 2);
17588       if (Offset0 - Offset1 == ElementSizeBytes &&
17589           (hasOperation(ISD::ROTL, PairVT) ||
17590            hasOperation(ISD::ROTR, PairVT))) {
17591         std::swap(LoadNodes[0], LoadNodes[1]);
17592         NeedRotate = true;
17593       }
17594     }
17595     LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
17596     unsigned FirstStoreAS = FirstInChain->getAddressSpace();
17597     Align FirstStoreAlign = FirstInChain->getAlign();
17598     LoadSDNode *FirstLoad = cast<LoadSDNode>(LoadNodes[0].MemNode);
17599 
    // Scan the memory operations on the chain and find the first
    // non-consecutive load memory address. This variable holds the index in
    // the load/store node arrays.

    unsigned LastConsecutiveLoad = 1;

    // These variables refer to sizes, not indices, in the array.
    unsigned LastLegalVectorType = 1;
    unsigned LastLegalIntegerType = 1;
    bool isDereferenceable = true;
    bool DoIntegerTruncate = false;
    int64_t StartAddress = LoadNodes[0].OffsetFromBase;
    SDValue LoadChain = FirstLoad->getChain();
    for (unsigned i = 1; i < LoadNodes.size(); ++i) {
      // All loads must share the same chain.
      if (LoadNodes[i].MemNode->getChain() != LoadChain)
        break;

      int64_t CurrAddress = LoadNodes[i].OffsetFromBase;
      if (CurrAddress - StartAddress != (ElementSizeBytes * i))
        break;
      LastConsecutiveLoad = i;

      if (isDereferenceable && !LoadNodes[i].MemNode->isDereferenceable())
        isDereferenceable = false;

      // Find a legal type for the vector store.
      unsigned Elts = (i + 1) * NumMemElts;
      EVT StoreTy = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);

      // Break early when size is too large to be legal.
      if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
        break;

      bool IsFastSt = false;
      bool IsFastLd = false;
      if (TLI.isTypeLegal(StoreTy) &&
          TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) &&
          TLI.allowsMemoryAccess(Context, DL, StoreTy,
                                 *FirstInChain->getMemOperand(), &IsFastSt) &&
          IsFastSt &&
          TLI.allowsMemoryAccess(Context, DL, StoreTy,
                                 *FirstLoad->getMemOperand(), &IsFastLd) &&
          IsFastLd) {
        LastLegalVectorType = i + 1;
      }

      // Find a legal type for the integer store.
      unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
      StoreTy = EVT::getIntegerVT(Context, SizeInBits);
      if (TLI.isTypeLegal(StoreTy) &&
          TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) &&
          TLI.allowsMemoryAccess(Context, DL, StoreTy,
                                 *FirstInChain->getMemOperand(), &IsFastSt) &&
          IsFastSt &&
          TLI.allowsMemoryAccess(Context, DL, StoreTy,
                                 *FirstLoad->getMemOperand(), &IsFastLd) &&
          IsFastLd) {
        LastLegalIntegerType = i + 1;
        DoIntegerTruncate = false;
        // Or check whether a truncstore and extload is legal.
      } else if (TLI.getTypeAction(Context, StoreTy) ==
                 TargetLowering::TypePromoteInteger) {
        EVT LegalizedStoredValTy = TLI.getTypeToTransformTo(Context, StoreTy);
        if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) &&
            TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy, DAG) &&
            TLI.isLoadExtLegal(ISD::ZEXTLOAD, LegalizedStoredValTy, StoreTy) &&
            TLI.isLoadExtLegal(ISD::SEXTLOAD, LegalizedStoredValTy, StoreTy) &&
            TLI.isLoadExtLegal(ISD::EXTLOAD, LegalizedStoredValTy, StoreTy) &&
            TLI.allowsMemoryAccess(Context, DL, StoreTy,
                                   *FirstInChain->getMemOperand(), &IsFastSt) &&
            IsFastSt &&
            TLI.allowsMemoryAccess(Context, DL, StoreTy,
                                   *FirstLoad->getMemOperand(), &IsFastLd) &&
            IsFastLd) {
          LastLegalIntegerType = i + 1;
          DoIntegerTruncate = true;
        }
      }
    }

    // Only use vector types if the vector type is larger than the integer
    // type. If they are the same, use integers.
    bool UseVectorTy =
        LastLegalVectorType > LastLegalIntegerType && AllowVectors;
    unsigned LastLegalType =
        std::max(LastLegalVectorType, LastLegalIntegerType);

    // We add +1 here because LastConsecutiveLoad holds a zero-based index
    // into the array, while NumElem is a count of elements.
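    // For instance, if LastConsecutiveLoad == 3, loads 0..3 are consecutive,
    // so at most four load/store pairs can participate in this merge.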
    unsigned NumElem = std::min(NumConsecutiveStores, LastConsecutiveLoad + 1);
    NumElem = std::min(LastLegalType, NumElem);
    Align FirstLoadAlign = FirstLoad->getAlign();

    if (NumElem < 2) {
      // We know that candidate stores are in order and of correct
      // shape. While there is no mergeable sequence from the
      // beginning, one may start later in the sequence. The only
      // reason a merge of size N could have failed where another of
      // the same size would not have is if the alignment of either
      // the load or store has improved. Drop as many candidates as we
      // can here.
      unsigned NumSkip = 1;
      while ((NumSkip < LoadNodes.size()) &&
             (LoadNodes[NumSkip].MemNode->getAlign() <= FirstLoadAlign) &&
             (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
        NumSkip++;
      StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
      LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumSkip);
      NumConsecutiveStores -= NumSkip;
      continue;
    }

    // Check that we can merge these candidates without causing a cycle.
    if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem,
                                                  RootNode)) {
      StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
      LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
      NumConsecutiveStores -= NumElem;
      continue;
    }

    // Find if it is better to use vectors or integers to load and store
    // to memory.
    EVT JointMemOpVT;
    if (UseVectorTy) {
      // Find a legal type for the vector store.
      unsigned Elts = NumElem * NumMemElts;
      JointMemOpVT = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
    } else {
      unsigned SizeInBits = NumElem * ElementSizeBytes * 8;
      JointMemOpVT = EVT::getIntegerVT(Context, SizeInBits);
    }

    SDLoc LoadDL(LoadNodes[0].MemNode);
    SDLoc StoreDL(StoreNodes[0].MemNode);

    // The merged loads are required to have the same incoming chain, so
    // using the first's chain is acceptable.

    SDValue NewStoreChain = getMergeStoreChains(StoreNodes, NumElem);
    AddToWorklist(NewStoreChain.getNode());

    MachineMemOperand::Flags LdMMOFlags =
        isDereferenceable ? MachineMemOperand::MODereferenceable
                          : MachineMemOperand::MONone;
    if (IsNonTemporalLoad)
      LdMMOFlags |= MachineMemOperand::MONonTemporal;

    MachineMemOperand::Flags StMMOFlags = IsNonTemporalStore
                                              ? MachineMemOperand::MONonTemporal
                                              : MachineMemOperand::MONone;

    SDValue NewLoad, NewStore;
    if (UseVectorTy || !DoIntegerTruncate) {
      NewLoad = DAG.getLoad(
          JointMemOpVT, LoadDL, FirstLoad->getChain(), FirstLoad->getBasePtr(),
          FirstLoad->getPointerInfo(), FirstLoadAlign, LdMMOFlags);
      SDValue StoreOp = NewLoad;
      if (NeedRotate) {
        unsigned LoadWidth = ElementSizeBytes * 8 * 2;
        assert(JointMemOpVT == EVT::getIntegerVT(Context, LoadWidth) &&
               "Unexpected type for rotate-able load pair");
        SDValue RotAmt =
            DAG.getShiftAmountConstant(LoadWidth / 2, JointMemOpVT, LoadDL);
        // Target can convert to the identical ROTR if it does not have ROTL.
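        // (rotl x, W/2) and (rotr x, W/2) are equivalent for a W-bit value,
        // so either rotate direction produces the half-swap we need here.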
        StoreOp = DAG.getNode(ISD::ROTL, LoadDL, JointMemOpVT, NewLoad, RotAmt);
      }
      NewStore = DAG.getStore(
          NewStoreChain, StoreDL, StoreOp, FirstInChain->getBasePtr(),
          FirstInChain->getPointerInfo(), FirstStoreAlign, StMMOFlags);
    } else { // This must be the truncstore/extload case
      EVT ExtendedTy =
          TLI.getTypeToTransformTo(*DAG.getContext(), JointMemOpVT);
      NewLoad = DAG.getExtLoad(ISD::EXTLOAD, LoadDL, ExtendedTy,
                               FirstLoad->getChain(), FirstLoad->getBasePtr(),
                               FirstLoad->getPointerInfo(), JointMemOpVT,
                               FirstLoadAlign, LdMMOFlags);
      NewStore = DAG.getTruncStore(
          NewStoreChain, StoreDL, NewLoad, FirstInChain->getBasePtr(),
          FirstInChain->getPointerInfo(), JointMemOpVT,
          FirstInChain->getAlign(), FirstInChain->getMemOperand()->getFlags());
    }

    // Transfer chain users from old loads to the new load.
    for (unsigned i = 0; i < NumElem; ++i) {
      LoadSDNode *Ld = cast<LoadSDNode>(LoadNodes[i].MemNode);
      DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1),
                                    SDValue(NewLoad.getNode(), 1));
    }

    // Replace all stores with the new store. Recursively remove corresponding
    // values if they are no longer used.
    for (unsigned i = 0; i < NumElem; ++i) {
      SDValue Val = StoreNodes[i].MemNode->getOperand(1);
      CombineTo(StoreNodes[i].MemNode, NewStore);
      if (Val.getNode()->use_empty())
        recursivelyDeleteUnusedNodes(Val.getNode());
    }

    MadeChange = true;
    StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
    LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
    NumConsecutiveStores -= NumElem;
  }
  return MadeChange;
}

bool DAGCombiner::mergeConsecutiveStores(StoreSDNode *St) {
  if (OptLevel == CodeGenOpt::None || !EnableStoreMerging)
    return false;

  // TODO: Extend this function to merge stores of scalable vectors.
  // (i.e. two <vscale x 8 x i8> stores can be merged to one <vscale x 16 x i8>
  // store since we know <vscale x 16 x i8> is exactly twice as large as
  // <vscale x 8 x i8>). Until then, bail out for scalable vectors.
  EVT MemVT = St->getMemoryVT();
  if (MemVT.isScalableVector())
    return false;
  if (!MemVT.isSimple() || MemVT.getSizeInBits() * 2 > MaximumLegalStoreInBits)
    return false;

  // This function cannot currently deal with non-byte-sized memory sizes.
  int64_t ElementSizeBytes = MemVT.getStoreSize();
  if (ElementSizeBytes * 8 != (int64_t)MemVT.getSizeInBits())
    return false;

  // Do not bother looking at stored values that are not constants, loads, or
  // extracted vector elements.
  SDValue StoredVal = peekThroughBitcasts(St->getValue());
  const StoreSource StoreSrc = getStoreSource(StoredVal);
  if (StoreSrc == StoreSource::Unknown)
    return false;

  SmallVector<MemOpLink, 8> StoreNodes;
  SDNode *RootNode;
  // Find potential store merge candidates by searching through the chain
  // sub-DAG.
  getStoreMergeCandidates(St, StoreNodes, RootNode);

  // Check if there is anything to merge.
  if (StoreNodes.size() < 2)
    return false;

  // Sort the memory operands according to their distance from the
  // base pointer.
  llvm::sort(StoreNodes, [](MemOpLink LHS, MemOpLink RHS) {
    return LHS.OffsetFromBase < RHS.OffsetFromBase;
  });

  bool AllowVectors = !DAG.getMachineFunction().getFunction().hasFnAttribute(
      Attribute::NoImplicitFloat);
  bool IsNonTemporalStore = St->isNonTemporal();
  bool IsNonTemporalLoad = StoreSrc == StoreSource::Load &&
                           cast<LoadSDNode>(StoredVal)->isNonTemporal();

  // Store merging attempts to merge the lowest-addressed stores first. This
  // generally works out: when a merge succeeds, the remaining stores are
  // checked after the first collection of stores is merged. However, if a
  // non-mergeable store is found first, e.g., {p[-2], p[0], p[1], p[2],
  // p[3]}, we would fail and miss the subsequent mergeable cases. To prevent
  // this, we prune such stores from the front of StoreNodes here.
  bool MadeChange = false;
  while (StoreNodes.size() > 1) {
    unsigned NumConsecutiveStores =
        getConsecutiveStores(StoreNodes, ElementSizeBytes);
    // There are no more stores in the list to examine.
    if (NumConsecutiveStores == 0)
      return MadeChange;

    // We have at least 2 consecutive stores. Try to merge them.
    assert(NumConsecutiveStores >= 2 && "Expected at least 2 stores");
    switch (StoreSrc) {
    case StoreSource::Constant:
      MadeChange |= tryStoreMergeOfConstants(StoreNodes, NumConsecutiveStores,
                                             MemVT, RootNode, AllowVectors);
      break;

    case StoreSource::Extract:
      MadeChange |= tryStoreMergeOfExtracts(StoreNodes, NumConsecutiveStores,
                                            MemVT, RootNode);
      break;

    case StoreSource::Load:
      MadeChange |= tryStoreMergeOfLoads(StoreNodes, NumConsecutiveStores,
                                         MemVT, RootNode, AllowVectors,
                                         IsNonTemporalStore, IsNonTemporalLoad);
      break;

    default:
      llvm_unreachable("Unhandled store source type");
    }
  }
  return MadeChange;
}

SDValue DAGCombiner::replaceStoreChain(StoreSDNode *ST, SDValue BetterChain) {
  SDLoc SL(ST);
  SDValue ReplStore;

  // Replace the chain to avoid dependency.
  if (ST->isTruncatingStore()) {
    ReplStore = DAG.getTruncStore(BetterChain, SL, ST->getValue(),
                                  ST->getBasePtr(), ST->getMemoryVT(),
                                  ST->getMemOperand());
  } else {
    ReplStore = DAG.getStore(BetterChain, SL, ST->getValue(), ST->getBasePtr(),
                             ST->getMemOperand());
  }

  // Create token to keep both nodes around.
  SDValue Token = DAG.getNode(ISD::TokenFactor, SL,
                              MVT::Other, ST->getChain(), ReplStore);

  // Make sure the new and old chains are cleaned up.
  AddToWorklist(Token.getNode());

  // Don't add users to work list.
  return CombineTo(ST, Token, false);
}

SDValue DAGCombiner::replaceStoreOfFPConstant(StoreSDNode *ST) {
  SDValue Value = ST->getValue();
  if (Value.getOpcode() == ISD::TargetConstantFP)
    return SDValue();

  if (!ISD::isNormalStore(ST))
    return SDValue();

  SDLoc DL(ST);

  SDValue Chain = ST->getChain();
  SDValue Ptr = ST->getBasePtr();

  const ConstantFPSDNode *CFP = cast<ConstantFPSDNode>(Value);

  // NOTE: If the original store is volatile, this transform must not increase
  // the number of stores.  For example, on x86-32 an f64 can be stored in one
  // processor operation but an i64 (which is not legal) requires two.  So the
  // transform should not be done in this case.

  SDValue Tmp;
  switch (CFP->getSimpleValueType(0).SimpleTy) {
  default:
    llvm_unreachable("Unknown FP type");
  case MVT::f16:    // We don't do this for these yet.
  case MVT::f80:
  case MVT::f128:
  case MVT::ppcf128:
    return SDValue();
  case MVT::f32:
    if ((isTypeLegal(MVT::i32) && !LegalOperations && ST->isSimple()) ||
        TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) {
      Tmp = DAG.getConstant((uint32_t)CFP->getValueAPF().
                            bitcastToAPInt().getZExtValue(), SDLoc(CFP),
                            MVT::i32);
      return DAG.getStore(Chain, DL, Tmp, Ptr, ST->getMemOperand());
    }

    return SDValue();
  case MVT::f64:
    if ((TLI.isTypeLegal(MVT::i64) && !LegalOperations &&
         ST->isSimple()) ||
        TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i64)) {
      Tmp = DAG.getConstant(CFP->getValueAPF().bitcastToAPInt().
                            getZExtValue(), SDLoc(CFP), MVT::i64);
      return DAG.getStore(Chain, DL, Tmp,
                          Ptr, ST->getMemOperand());
    }

    if (ST->isSimple() &&
        TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) {
      // Many FP stores are not made apparent until after legalize, e.g. for
      // argument passing.  Since this is so common, custom legalize the
      // 64-bit integer store into two 32-bit stores.
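      // As an illustrative sketch: on a little-endian target,
      //   store double 1.0, p
      // becomes
      //   store i32 0x00000000, p ; store i32 0x3FF00000, p+4
      // since 1.0 bitcasts to the i64 value 0x3FF0000000000000.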
      uint64_t Val = CFP->getValueAPF().bitcastToAPInt().getZExtValue();
      SDValue Lo = DAG.getConstant(Val & 0xFFFFFFFF, SDLoc(CFP), MVT::i32);
      SDValue Hi = DAG.getConstant(Val >> 32, SDLoc(CFP), MVT::i32);
      if (DAG.getDataLayout().isBigEndian())
        std::swap(Lo, Hi);

      MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
      AAMDNodes AAInfo = ST->getAAInfo();

      SDValue St0 = DAG.getStore(Chain, DL, Lo, Ptr, ST->getPointerInfo(),
                                 ST->getOriginalAlign(), MMOFlags, AAInfo);
      Ptr = DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(4), DL);
      SDValue St1 = DAG.getStore(Chain, DL, Hi, Ptr,
                                 ST->getPointerInfo().getWithOffset(4),
                                 ST->getOriginalAlign(), MMOFlags, AAInfo);
      return DAG.getNode(ISD::TokenFactor, DL, MVT::Other,
                         St0, St1);
    }

    return SDValue();
  }
}

SDValue DAGCombiner::visitSTORE(SDNode *N) {
  StoreSDNode *ST  = cast<StoreSDNode>(N);
  SDValue Chain = ST->getChain();
  SDValue Value = ST->getValue();
  SDValue Ptr   = ST->getBasePtr();

  // If this is a store of a bit convert, store the input value if the
  // resultant store does not need a higher alignment than the original.
  if (Value.getOpcode() == ISD::BITCAST && !ST->isTruncatingStore() &&
      ST->isUnindexed()) {
    EVT SVT = Value.getOperand(0).getValueType();
    // If the store is volatile, we only want to change the store type if the
    // resulting store is legal. Otherwise we might increase the number of
    // memory accesses. We don't care if the original type was legal or not
    // as we assume software couldn't rely on the number of accesses of an
    // illegal type.
    // TODO: May be able to relax for unordered atomics (see D66309)
    if (((!LegalOperations && ST->isSimple()) ||
         TLI.isOperationLegal(ISD::STORE, SVT)) &&
        TLI.isStoreBitCastBeneficial(Value.getValueType(), SVT,
                                     DAG, *ST->getMemOperand())) {
      return DAG.getStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
                          ST->getMemOperand());
    }
  }

  // Turn 'store undef, Ptr' -> nothing.
  if (Value.isUndef() && ST->isUnindexed())
    return Chain;

  // Try to infer better alignment information than the store already has.
  if (OptLevel != CodeGenOpt::None && ST->isUnindexed() && !ST->isAtomic()) {
    if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) {
      if (*Alignment > ST->getAlign() &&
          isAligned(*Alignment, ST->getSrcValueOffset())) {
        SDValue NewStore =
            DAG.getTruncStore(Chain, SDLoc(N), Value, Ptr, ST->getPointerInfo(),
                              ST->getMemoryVT(), *Alignment,
                              ST->getMemOperand()->getFlags(), ST->getAAInfo());
        // NewStore will always be N, as we are only refining the alignment.
        assert(NewStore.getNode() == N);
        (void)NewStore;
      }
    }
  }

  // Try transforming a pair of floating point load / store ops to integer
  // load / store ops.
  if (SDValue NewST = TransformFPLoadStorePair(N))
    return NewST;

  // Try transforming several stores into STORE (BSWAP).
  if (SDValue Store = mergeTruncStores(ST))
    return Store;

  if (ST->isUnindexed()) {
    // Walk up chain skipping non-aliasing memory nodes, on this store and any
    // adjacent stores.
    if (findBetterNeighborChains(ST)) {
      // replaceStoreChain uses CombineTo, which handles all of the worklist
      // manipulation. Return the original node to not do anything else.
      return SDValue(ST, 0);
    }
    Chain = ST->getChain();
  }

  // FIXME: is there such a thing as a truncating indexed store?
  if (ST->isTruncatingStore() && ST->isUnindexed() &&
      Value.getValueType().isInteger() &&
      (!isa<ConstantSDNode>(Value) ||
       !cast<ConstantSDNode>(Value)->isOpaque())) {
    APInt TruncDemandedBits =
        APInt::getLowBitsSet(Value.getScalarValueSizeInBits(),
                             ST->getMemoryVT().getScalarSizeInBits());

    // See if we can simplify the input to this truncstore with knowledge that
    // only the low bits are being used.  For example:
    // "truncstore (or (shl x, 8), y), i8"  -> "truncstore y, i8"
    AddToWorklist(Value.getNode());
    if (SDValue Shorter = DAG.GetDemandedBits(Value, TruncDemandedBits))
      return DAG.getTruncStore(Chain, SDLoc(N), Shorter, Ptr, ST->getMemoryVT(),
                               ST->getMemOperand());

    // Otherwise, see if we can simplify the operation with
    // SimplifyDemandedBits, which only works if the value has a single use.
    if (SimplifyDemandedBits(Value, TruncDemandedBits)) {
      // Re-visit the store if anything changed and the store hasn't been
      // merged with another node (in which case N is deleted).
      // SimplifyDemandedBits will add Value's node back to the worklist if
      // necessary, but we also need to re-visit the store node itself.
      if (N->getOpcode() != ISD::DELETED_NODE)
        AddToWorklist(N);
      return SDValue(N, 0);
    }
  }

  // If this is a load followed by a store to the same location, then the store
  // is dead/noop.
  // TODO: Can relax for unordered atomics (see D66309)
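  // For example, in the sequence (x = load p; ...; store x, p) with no
  // intervening side effects, the store can be removed entirely.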
  if (LoadSDNode *Ld = dyn_cast<LoadSDNode>(Value)) {
    if (Ld->getBasePtr() == Ptr && ST->getMemoryVT() == Ld->getMemoryVT() &&
        ST->isUnindexed() && ST->isSimple() &&
        Ld->getAddressSpace() == ST->getAddressSpace() &&
        // There can't be any side effects between the load and store, such as
        // a call or store.
        Chain.reachesChainWithoutSideEffects(SDValue(Ld, 1))) {
      // The store is dead, remove it.
      return Chain;
    }
  }

  // TODO: Can relax for unordered atomics (see D66309)
  if (StoreSDNode *ST1 = dyn_cast<StoreSDNode>(Chain)) {
    if (ST->isUnindexed() && ST->isSimple() &&
        ST1->isUnindexed() && ST1->isSimple()) {
      if (ST1->getBasePtr() == Ptr && ST1->getValue() == Value &&
          ST->getMemoryVT() == ST1->getMemoryVT() &&
          ST->getAddressSpace() == ST1->getAddressSpace()) {
        // If this is a store followed by a store with the same value to the
        // same location, then the store is dead/noop.
        return Chain;
      }

      if (OptLevel != CodeGenOpt::None && ST1->hasOneUse() &&
          !ST1->getBasePtr().isUndef() &&
          // BaseIndexOffset and the code below require knowing the size
          // of a vector, so bail out if MemoryVT is scalable.
          !ST->getMemoryVT().isScalableVector() &&
          !ST1->getMemoryVT().isScalableVector() &&
          ST->getAddressSpace() == ST1->getAddressSpace()) {
        const BaseIndexOffset STBase = BaseIndexOffset::match(ST, DAG);
        const BaseIndexOffset ChainBase = BaseIndexOffset::match(ST1, DAG);
        unsigned STBitSize = ST->getMemoryVT().getFixedSizeInBits();
        unsigned ChainBitSize = ST1->getMemoryVT().getFixedSizeInBits();
        // If this store's preceding store writes to a subset of the current
        // location and no other node is chained to that store, we can
        // effectively drop the preceding store. Do not remove stores to undef
        // as they may be used as data sinks.
        if (STBase.contains(DAG, STBitSize, ChainBase, ChainBitSize)) {
          CombineTo(ST1, ST1->getChain());
          return SDValue();
        }
      }
    }
  }

  // If this is an FP_ROUND or TRUNC followed by a store, fold this into a
  // truncating store.  We can do this even if this is already a truncstore.
  if ((Value.getOpcode() == ISD::FP_ROUND ||
       Value.getOpcode() == ISD::TRUNCATE) &&
      Value.getNode()->hasOneUse() && ST->isUnindexed() &&
      TLI.canCombineTruncStore(Value.getOperand(0).getValueType(),
                               ST->getMemoryVT(), LegalOperations)) {
    return DAG.getTruncStore(Chain, SDLoc(N), Value.getOperand(0),
                             Ptr, ST->getMemoryVT(), ST->getMemOperand());
  }

  // Always perform this optimization before types are legal. If the target
  // prefers, also try this after legalization to catch stores that were created
  // by intrinsics or other nodes.
  if (!LegalTypes || (TLI.mergeStoresAfterLegalization(ST->getMemoryVT()))) {
    while (true) {
      // There can be multiple store sequences on the same chain.
      // Keep trying to merge store sequences until we are unable to do so
      // or until we merge the last store on the chain.
      bool Changed = mergeConsecutiveStores(ST);
      if (!Changed) break;
      // Return N, since merging only uses CombineTo, so no worklist cleanup
      // is necessary.
      if (N->getOpcode() == ISD::DELETED_NODE || !isa<StoreSDNode>(N))
        return SDValue(N, 0);
    }
  }

  // Try transforming N to an indexed store.
  if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
    return SDValue(N, 0);

  // Turn 'store float 1.0, Ptr' -> 'store int 0x3F800000, Ptr'
  //
  // Make sure to do this only after attempting to merge stores in order to
  //  avoid changing the types of some subset of stores due to visit order,
  //  preventing their merging.
  if (isa<ConstantFPSDNode>(ST->getValue())) {
    if (SDValue NewSt = replaceStoreOfFPConstant(ST))
      return NewSt;
  }

  if (SDValue NewSt = splitMergedValStore(ST))
    return NewSt;

  return ReduceLoadOpStoreWidth(N);
}

SDValue DAGCombiner::visitLIFETIME_END(SDNode *N) {
  const auto *LifetimeEnd = cast<LifetimeSDNode>(N);
  if (!LifetimeEnd->hasOffset())
    return SDValue();

  const BaseIndexOffset LifetimeEndBase(N->getOperand(1), SDValue(),
                                        LifetimeEnd->getOffset(), false);

  // We walk up the chains to find stores.
  SmallVector<SDValue, 8> Chains = {N->getOperand(0)};
  while (!Chains.empty()) {
    SDValue Chain = Chains.pop_back_val();
    if (!Chain.hasOneUse())
      continue;
    switch (Chain.getOpcode()) {
    case ISD::TokenFactor:
      for (unsigned Nops = Chain.getNumOperands(); Nops;)
        Chains.push_back(Chain.getOperand(--Nops));
      break;
    case ISD::LIFETIME_START:
    case ISD::LIFETIME_END:
      // We can forward past any lifetime start/end that can be proven not to
      // alias the node.
      if (!isAlias(Chain.getNode(), N))
        Chains.push_back(Chain.getOperand(0));
      break;
    case ISD::STORE: {
      // The opcode guarantees this is a StoreSDNode, so cast<> is safe here.
      StoreSDNode *ST = cast<StoreSDNode>(Chain);
      // TODO: Can relax for unordered atomics (see D66309)
      if (!ST->isSimple() || ST->isIndexed())
        continue;
      const TypeSize StoreSize = ST->getMemoryVT().getStoreSize();
      // The bounds of a scalable store are not known until runtime, so this
      // store cannot be elided.
      if (StoreSize.isScalable())
        continue;
      const BaseIndexOffset StoreBase = BaseIndexOffset::match(ST, DAG);
      // If we store purely within object bounds just before its lifetime ends,
      // we can remove the store.
      if (LifetimeEndBase.contains(DAG, LifetimeEnd->getSize() * 8, StoreBase,
                                   StoreSize.getFixedSize() * 8)) {
        LLVM_DEBUG(dbgs() << "\nRemoving store:"; StoreBase.dump();
                   dbgs() << "\nwithin LIFETIME_END of : ";
                   LifetimeEndBase.dump(); dbgs() << "\n");
        CombineTo(ST, ST->getChain());
        return SDValue(N, 0);
      }
    }
    }
  }
  return SDValue();
}

/// For the store instruction sequence below, the F and I values
/// are bundled together as an i64 value before being stored into memory.
/// Sometimes it is more efficient to generate separate stores for F and I,
/// which can remove the bitwise instructions or sink them to colder places.
///
///   (store (or (zext (bitcast F to i32) to i64),
///              (shl (zext I to i64), 32)), addr)  -->
///   (store F, addr) and (store I, addr+4)
///
/// Similarly, splitting for other merged stores can also be beneficial, like:
/// For pair of {i32, i32}, i64 store --> two i32 stores.
/// For pair of {i32, i16}, i64 store --> two i32 stores.
/// For pair of {i16, i16}, i32 store --> two i16 stores.
/// For pair of {i16, i8},  i32 store --> two i16 stores.
/// For pair of {i8, i8},   i16 store --> two i8 stores.
///
/// We allow each target to determine specifically which kind of splitting is
/// supported.
///
/// The store patterns are commonly seen from the simple code snippet below
/// if only std::make_pair(...) is SROA-transformed before being inlined into
/// hoo:
///   void goo(const std::pair<int, float> &);
///   void hoo() {
///     ...
///     goo(std::make_pair(tmp, ftmp));
///     ...
///   }
///
SDValue DAGCombiner::splitMergedValStore(StoreSDNode *ST) {
  if (OptLevel == CodeGenOpt::None)
    return SDValue();

  // Can't change the number of memory accesses for a volatile store or break
  // atomicity for an atomic one.
  if (!ST->isSimple())
    return SDValue();

  SDValue Val = ST->getValue();
  SDLoc DL(ST);

  // Match OR operand.
  if (!Val.getValueType().isScalarInteger() || Val.getOpcode() != ISD::OR)
    return SDValue();

  // Match SHL operand and get the lower and higher parts of Val.
  SDValue Op1 = Val.getOperand(0);
  SDValue Op2 = Val.getOperand(1);
  SDValue Lo, Hi;
  if (Op1.getOpcode() != ISD::SHL) {
    std::swap(Op1, Op2);
    if (Op1.getOpcode() != ISD::SHL)
      return SDValue();
  }
  Lo = Op2;
  Hi = Op1.getOperand(0);
  if (!Op1.hasOneUse())
    return SDValue();

  // Match shift amount to HalfValBitSize.
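  // For an i64 stored value, HalfValBitSize is 32, so the SHL amount must be
  // exactly 32 for the match to succeed.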
  unsigned HalfValBitSize = Val.getValueSizeInBits() / 2;
  ConstantSDNode *ShAmt = dyn_cast<ConstantSDNode>(Op1.getOperand(1));
  if (!ShAmt || ShAmt->getAPIntValue() != HalfValBitSize)
    return SDValue();

  // Lo and Hi must be zero-extended from integer types no wider than
  // HalfValBitSize (e.g., from i32 or narrower when Val is i64).
  if (Lo.getOpcode() != ISD::ZERO_EXTEND || !Lo.hasOneUse() ||
      !Lo.getOperand(0).getValueType().isScalarInteger() ||
      Lo.getOperand(0).getValueSizeInBits() > HalfValBitSize ||
      Hi.getOpcode() != ISD::ZERO_EXTEND || !Hi.hasOneUse() ||
      !Hi.getOperand(0).getValueType().isScalarInteger() ||
      Hi.getOperand(0).getValueSizeInBits() > HalfValBitSize)
    return SDValue();

  // Use the EVT of the low and high parts before the bitcast as the inputs
  // to the target query.
  EVT LowTy = (Lo.getOperand(0).getOpcode() == ISD::BITCAST)
                  ? Lo.getOperand(0).getValueType()
                  : Lo.getValueType();
  EVT HighTy = (Hi.getOperand(0).getOpcode() == ISD::BITCAST)
                   ? Hi.getOperand(0).getValueType()
                   : Hi.getValueType();
  if (!TLI.isMultiStoresCheaperThanBitsMerge(LowTy, HighTy))
    return SDValue();

  // Start to split store.
  MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
  AAMDNodes AAInfo = ST->getAAInfo();

  // Change the sizes of Lo and Hi's value types to HalfValBitSize.
  EVT VT = EVT::getIntegerVT(*DAG.getContext(), HalfValBitSize);
  Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Lo.getOperand(0));
  Hi = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Hi.getOperand(0));

  SDValue Chain = ST->getChain();
  SDValue Ptr = ST->getBasePtr();
  // Lower value store.
  SDValue St0 = DAG.getStore(Chain, DL, Lo, Ptr, ST->getPointerInfo(),
                             ST->getOriginalAlign(), MMOFlags, AAInfo);
  Ptr = DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(HalfValBitSize / 8), DL);
  // Higher value store.
  SDValue St1 = DAG.getStore(
      St0, DL, Hi, Ptr, ST->getPointerInfo().getWithOffset(HalfValBitSize / 8),
      ST->getOriginalAlign(), MMOFlags, AAInfo);
  return St1;
}

/// Convert a disguised subvector insertion into a shuffle:
SDValue DAGCombiner::combineInsertEltToShuffle(SDNode *N, unsigned InsIndex) {
  assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT &&
         "Expected insert_vector_elt");
  SDValue InsertVal = N->getOperand(1);
  SDValue Vec = N->getOperand(0);

  // (insert_vector_elt (vector_shuffle X, Y), (extract_vector_elt X, N),
  // InsIndex)
  //   --> (vector_shuffle X, Y) and variations where shuffle operands may be
  //   CONCAT_VECTORS.
  if (Vec.getOpcode() == ISD::VECTOR_SHUFFLE && Vec.hasOneUse() &&
      InsertVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
      isa<ConstantSDNode>(InsertVal.getOperand(1))) {
    ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Vec.getNode());
    ArrayRef<int> Mask = SVN->getMask();

    SDValue X = Vec.getOperand(0);
    SDValue Y = Vec.getOperand(1);
    // Shuffle indices 0 to N-1 select elements from Vec's operand 0, and
    // indices N to 2N-1 select elements from operand 1, where N is the
    // number of elements in each vector.
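    // For example, with 4-element inputs, mask value 1 selects X[1] while
    // mask value 5 selects Y[1].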
    SDValue InsertVal0 = InsertVal.getOperand(0);
    int ElementOffset = -1;

    // We explore the inputs of the shuffle in order to see if we find the
    // source of the extract_vector_elt. If so, we can use it to modify the
    // shuffle rather than perform an insert_vector_elt.
    SmallVector<std::pair<int, SDValue>, 8> ArgWorkList;
    ArgWorkList.emplace_back(Mask.size(), Y);
    ArgWorkList.emplace_back(0, X);

    while (!ArgWorkList.empty()) {
      int ArgOffset;
      SDValue ArgVal;
      std::tie(ArgOffset, ArgVal) = ArgWorkList.pop_back_val();

      if (ArgVal == InsertVal0) {
        ElementOffset = ArgOffset;
        break;
      }

      // Peek through concat_vector.
      if (ArgVal.getOpcode() == ISD::CONCAT_VECTORS) {
        int CurrentArgOffset =
            ArgOffset + ArgVal.getValueType().getVectorNumElements();
        int Step = ArgVal.getOperand(0).getValueType().getVectorNumElements();
        for (SDValue Op : reverse(ArgVal->ops())) {
          CurrentArgOffset -= Step;
          ArgWorkList.emplace_back(CurrentArgOffset, Op);
        }

        // Make sure we went through all the elements and did not screw up index
        // computation.
        assert(CurrentArgOffset == ArgOffset);
      }
    }

    if (ElementOffset != -1) {
      SmallVector<int, 16> NewMask(Mask.begin(), Mask.end());

      auto *ExtrIndex = cast<ConstantSDNode>(InsertVal.getOperand(1));
      NewMask[InsIndex] = ElementOffset + ExtrIndex->getZExtValue();
      assert(NewMask[InsIndex] <
                 (int)(2 * Vec.getValueType().getVectorNumElements()) &&
             NewMask[InsIndex] >= 0 && "NewMask[InsIndex] is out of bound");

      SDValue LegalShuffle =
              TLI.buildLegalVectorShuffle(Vec.getValueType(), SDLoc(N), X,
                                          Y, NewMask, DAG);
      if (LegalShuffle)
        return LegalShuffle;
    }
  }

  // insert_vector_elt V, (bitcast X from vector type), IdxC -->
  // bitcast(shuffle (bitcast V), (extended X), Mask)
  // Note: We do not use an insert_subvector node because that requires a
  // legal subvector type.
  if (InsertVal.getOpcode() != ISD::BITCAST || !InsertVal.hasOneUse() ||
      !InsertVal.getOperand(0).getValueType().isVector())
    return SDValue();

  SDValue SubVec = InsertVal.getOperand(0);
  SDValue DestVec = N->getOperand(0);
  EVT SubVecVT = SubVec.getValueType();
  EVT VT = DestVec.getValueType();
  unsigned NumSrcElts = SubVecVT.getVectorNumElements();
  // If the source only has a single vector element, the cost of creating a
  // vector from it is likely to exceed the cost of an insert_vector_elt.
  if (NumSrcElts == 1)
    return SDValue();
  unsigned ExtendRatio = VT.getSizeInBits() / SubVecVT.getSizeInBits();
  unsigned NumMaskVals = ExtendRatio * NumSrcElts;

  // Step 1: Create a shuffle mask that implements this insert operation. The
  // vector that we are inserting into will be operand 0 of the shuffle, so
  // those elements are just 'i'. The inserted subvector is in the first
  // positions of operand 1 of the shuffle. Example:
  // insert v4i32 V, (v2i16 X), 2 --> shuffle v8i16 V', X', {0,1,2,3,8,9,6,7}
  SmallVector<int, 16> Mask(NumMaskVals);
  for (unsigned i = 0; i != NumMaskVals; ++i) {
    if (i / NumSrcElts == InsIndex)
      Mask[i] = (i % NumSrcElts) + NumMaskVals;
    else
      Mask[i] = i;
  }

  // Bail out if the target can not handle the shuffle we want to create.
  EVT SubVecEltVT = SubVecVT.getVectorElementType();
  EVT ShufVT = EVT::getVectorVT(*DAG.getContext(), SubVecEltVT, NumMaskVals);
  if (!TLI.isShuffleMaskLegal(Mask, ShufVT))
    return SDValue();

  // Step 2: Create a wide vector from the inserted source vector by appending
  // undefined elements. This is the same size as our destination vector.
  SDLoc DL(N);
  SmallVector<SDValue, 8> ConcatOps(ExtendRatio, DAG.getUNDEF(SubVecVT));
  ConcatOps[0] = SubVec;
  SDValue PaddedSubV = DAG.getNode(ISD::CONCAT_VECTORS, DL, ShufVT, ConcatOps);

  // Step 3: Shuffle in the padded subvector.
  SDValue DestVecBC = DAG.getBitcast(ShufVT, DestVec);
  SDValue Shuf = DAG.getVectorShuffle(ShufVT, DL, DestVecBC, PaddedSubV, Mask);
  AddToWorklist(PaddedSubV.getNode());
  AddToWorklist(DestVecBC.getNode());
  AddToWorklist(Shuf.getNode());
  return DAG.getBitcast(VT, Shuf);
}

SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
  SDValue InVec = N->getOperand(0);
  SDValue InVal = N->getOperand(1);
  SDValue EltNo = N->getOperand(2);
  SDLoc DL(N);

  EVT VT = InVec.getValueType();
  auto *IndexC = dyn_cast<ConstantSDNode>(EltNo);

  // Inserting into an out-of-bounds element yields undef.
  if (IndexC && VT.isFixedLengthVector() &&
      IndexC->getZExtValue() >= VT.getVectorNumElements())
    return DAG.getUNDEF(VT);

  // Remove redundant insertions:
  // (insert_vector_elt x (extract_vector_elt x idx) idx) -> x
  if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
      InVec == InVal.getOperand(0) && EltNo == InVal.getOperand(1))
    return InVec;

  if (!IndexC) {
    // If this is a variable insert into an undef vector, it might be better
    // to splat:
    // inselt undef, InVal, EltNo --> build_vector < InVal, InVal, ... >
    if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT)) {
      if (VT.isScalableVector())
        return DAG.getSplatVector(VT, DL, InVal);
      else {
        SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), InVal);
        return DAG.getBuildVector(VT, DL, Ops);
      }
    }
    return SDValue();
  }

  if (VT.isScalableVector())
    return SDValue();

  unsigned NumElts = VT.getVectorNumElements();

  // We must know which element is being inserted for folds below here.
  unsigned Elt = IndexC->getZExtValue();
  if (SDValue Shuf = combineInsertEltToShuffle(N, Elt))
    return Shuf;

  // Canonicalize insert_vector_elt dag nodes.
  // Example:
  // (insert_vector_elt (insert_vector_elt A, Idx0), Idx1)
  // -> (insert_vector_elt (insert_vector_elt A, Idx1), Idx0)
  //
  // Do this only if the child insert_vector node has one use; also
  // do this only if indices are both constants and Idx1 < Idx0.
  if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT && InVec.hasOneUse()
      && isa<ConstantSDNode>(InVec.getOperand(2))) {
    unsigned OtherElt = InVec.getConstantOperandVal(2);
    if (Elt < OtherElt) {
      // Swap nodes.
      SDValue NewOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT,
                                  InVec.getOperand(0), InVal, EltNo);
      AddToWorklist(NewOp.getNode());
      return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(InVec.getNode()),
                         VT, NewOp, InVec.getOperand(1), InVec.getOperand(2));
    }
  }

  // If we can't generate a legal BUILD_VECTOR, exit
  if (LegalOperations && !TLI.isOperationLegal(ISD::BUILD_VECTOR, VT))
    return SDValue();

  // Check that the operand is a BUILD_VECTOR (or UNDEF, which can essentially
  // be converted to a BUILD_VECTOR).  Fill in the Ops vector with the
  // vector elements.
  SmallVector<SDValue, 8> Ops;
  // Do not combine these two vectors if the output vector will not replace
  // the input vector.
  if (InVec.getOpcode() == ISD::BUILD_VECTOR && InVec.hasOneUse()) {
    Ops.append(InVec.getNode()->op_begin(),
               InVec.getNode()->op_end());
  } else if (InVec.isUndef()) {
    Ops.append(NumElts, DAG.getUNDEF(InVal.getValueType()));
  } else {
    return SDValue();
  }
  assert(Ops.size() == NumElts && "Unexpected vector size");

  // Insert the element
  if (Elt < Ops.size()) {
    // All the operands of BUILD_VECTOR must have the same type;
    // we enforce that here.
    EVT OpVT = Ops[0].getValueType();
    Ops[Elt] = OpVT.isInteger() ? DAG.getAnyExtOrTrunc(InVal, DL, OpVT) : InVal;
  }

  // Return the new vector
  return DAG.getBuildVector(VT, DL, Ops);
}

SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
                                                  SDValue EltNo,
                                                  LoadSDNode *OriginalLoad) {
  assert(OriginalLoad->isSimple());

  EVT ResultVT = EVE->getValueType(0);
  EVT VecEltVT = InVecVT.getVectorElementType();

  // If the vector element type is not a multiple of a byte then we are unable
  // to correctly compute an address to load only the extracted element as a
  // scalar.
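  // (For example, i1 or i4 vector elements have no byte-addressable offset.)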
  if (!VecEltVT.isByteSized())
    return SDValue();

  Align Alignment = OriginalLoad->getAlign();
  Align NewAlign = DAG.getDataLayout().getABITypeAlign(
      VecEltVT.getTypeForEVT(*DAG.getContext()));

  if (NewAlign > Alignment ||
      !TLI.isOperationLegalOrCustom(ISD::LOAD, VecEltVT))
    return SDValue();

  ISD::LoadExtType ExtTy = ResultVT.bitsGT(VecEltVT) ?
    ISD::NON_EXTLOAD : ISD::EXTLOAD;
  if (!TLI.shouldReduceLoadWidth(OriginalLoad, ExtTy, VecEltVT))
    return SDValue();

  Alignment = NewAlign;

  MachinePointerInfo MPI;
  SDLoc DL(EVE);
  if (auto *ConstEltNo = dyn_cast<ConstantSDNode>(EltNo)) {
    int Elt = ConstEltNo->getZExtValue();
    unsigned PtrOff = VecEltVT.getSizeInBits() * Elt / 8;
    MPI = OriginalLoad->getPointerInfo().getWithOffset(PtrOff);
  } else {
    // Discard the pointer info except the address space because the memory
    // operand can't represent this new access since the offset is variable.
    MPI = MachinePointerInfo(OriginalLoad->getPointerInfo().getAddrSpace());
  }
  SDValue NewPtr = TLI.getVectorElementPointer(DAG, OriginalLoad->getBasePtr(),
                                               InVecVT, EltNo);

  // The replacement we need to do here is a little tricky: we need to
  // replace an extractelement of a load with a load.
  // Use ReplaceAllUsesOfValuesWith to do the replacement.
  // Note that this replacement assumes that the extractelement is the only
  // use of the load; that's okay because we don't want to perform this
  // transformation in other cases anyway.
  SDValue Load;
  SDValue Chain;
  if (ResultVT.bitsGT(VecEltVT)) {
    // If the result type of vextract is wider than the load, then issue an
    // extending load instead.
    ISD::LoadExtType ExtType = TLI.isLoadExtLegal(ISD::ZEXTLOAD, ResultVT,
                                                  VecEltVT)
                                   ? ISD::ZEXTLOAD
                                   : ISD::EXTLOAD;
    Load = DAG.getExtLoad(ExtType, SDLoc(EVE), ResultVT,
                          OriginalLoad->getChain(), NewPtr, MPI, VecEltVT,
                          Alignment, OriginalLoad->getMemOperand()->getFlags(),
                          OriginalLoad->getAAInfo());
    Chain = Load.getValue(1);
  } else {
    Load = DAG.getLoad(
        VecEltVT, SDLoc(EVE), OriginalLoad->getChain(), NewPtr, MPI, Alignment,
        OriginalLoad->getMemOperand()->getFlags(), OriginalLoad->getAAInfo());
    Chain = Load.getValue(1);
    if (ResultVT.bitsLT(VecEltVT))
      Load = DAG.getNode(ISD::TRUNCATE, SDLoc(EVE), ResultVT, Load);
    else
      Load = DAG.getBitcast(ResultVT, Load);
  }
  WorklistRemover DeadNodes(*this);
  SDValue From[] = { SDValue(EVE, 0), SDValue(OriginalLoad, 1) };
  SDValue To[] = { Load, Chain };
  DAG.ReplaceAllUsesOfValuesWith(From, To, 2);
  // Make sure to revisit this node to clean it up; it will usually be dead.
  AddToWorklist(EVE);
  // Since we're explicitly calling ReplaceAllUses, add the new node to the
  // worklist explicitly as well.
  AddToWorklistWithUsers(Load.getNode());
  ++OpsNarrowed;
  return SDValue(EVE, 0);
}

/// Transform a vector binary operation into a scalar binary operation by moving
/// the math/logic after an extract element of a vector.
static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
                                       bool LegalOperations) {
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  SDValue Vec = ExtElt->getOperand(0);
  SDValue Index = ExtElt->getOperand(1);
  auto *IndexC = dyn_cast<ConstantSDNode>(Index);
  if (!IndexC || !TLI.isBinOp(Vec.getOpcode()) || !Vec.hasOneUse() ||
      Vec.getNode()->getNumValues() != 1)
    return SDValue();

  // Targets may want to avoid this to prevent an expensive register transfer.
  if (!TLI.shouldScalarizeBinop(Vec))
    return SDValue();

  // Extracting an element of a vector constant is constant-folded, so this
  // transform is just replacing a vector op with a scalar op while moving the
  // extract.
  SDValue Op0 = Vec.getOperand(0);
  SDValue Op1 = Vec.getOperand(1);
  if (isAnyConstantBuildVector(Op0, true) ||
      isAnyConstantBuildVector(Op1, true)) {
    // extractelt (binop X, C), IndexC --> binop (extractelt X, IndexC), C'
    // extractelt (binop C, X), IndexC --> binop C', (extractelt X, IndexC)
    SDLoc DL(ExtElt);
    EVT VT = ExtElt->getValueType(0);
    SDValue Ext0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Index);
    SDValue Ext1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op1, Index);
    return DAG.getNode(Vec.getOpcode(), DL, VT, Ext0, Ext1);
  }

  return SDValue();
}

SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
  SDValue VecOp = N->getOperand(0);
  SDValue Index = N->getOperand(1);
  EVT ScalarVT = N->getValueType(0);
  EVT VecVT = VecOp.getValueType();
  if (VecOp.isUndef())
    return DAG.getUNDEF(ScalarVT);

  // (extract_vector_elt (insert_vector_elt vec, val, idx), idx) -> val
  //
  // This only really matters if the index is non-constant since other combines
  // on the constant elements already work.
  SDLoc DL(N);
  if (VecOp.getOpcode() == ISD::INSERT_VECTOR_ELT &&
      Index == VecOp.getOperand(2)) {
    SDValue Elt = VecOp.getOperand(1);
    return VecVT.isInteger() ? DAG.getAnyExtOrTrunc(Elt, DL, ScalarVT) : Elt;
  }

  // (vextract (scalar_to_vector val), 0) -> val
  if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR) {
    // Only the 0'th element of SCALAR_TO_VECTOR is defined.
    if (DAG.isKnownNeverZero(Index))
      return DAG.getUNDEF(ScalarVT);

    // Check if the result type doesn't match the inserted element type. A
    // SCALAR_TO_VECTOR may truncate the inserted element and the
    // EXTRACT_VECTOR_ELT may widen the extracted vector.
    SDValue InOp = VecOp.getOperand(0);
    if (InOp.getValueType() != ScalarVT) {
      assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
      return DAG.getSExtOrTrunc(InOp, DL, ScalarVT);
    }
    return InOp;
  }

  // extract_vector_elt of out-of-bounds element -> UNDEF
  auto *IndexC = dyn_cast<ConstantSDNode>(Index);
  if (IndexC && VecVT.isFixedLengthVector() &&
      IndexC->getAPIntValue().uge(VecVT.getVectorNumElements()))
    return DAG.getUNDEF(ScalarVT);

  // extract_vector_elt (build_vector x, y), 1 -> y
  if (((IndexC && VecOp.getOpcode() == ISD::BUILD_VECTOR) ||
       VecOp.getOpcode() == ISD::SPLAT_VECTOR) &&
      TLI.isTypeLegal(VecVT) &&
      (VecOp.hasOneUse() || TLI.aggressivelyPreferBuildVectorSources(VecVT))) {
    assert((VecOp.getOpcode() != ISD::BUILD_VECTOR ||
            VecVT.isFixedLengthVector()) &&
           "BUILD_VECTOR used for scalable vectors");
    unsigned IndexVal =
        VecOp.getOpcode() == ISD::BUILD_VECTOR ? IndexC->getZExtValue() : 0;
    SDValue Elt = VecOp.getOperand(IndexVal);
    EVT InEltVT = Elt.getValueType();

    // Sometimes build_vector's scalar input types do not match result type.
    if (ScalarVT == InEltVT)
      return Elt;

    // TODO: It may be useful to truncate (if the truncate is free) when the
    // build_vector implicitly converts.
  }

  if (VecVT.isScalableVector())
    return SDValue();

  // All the code from this point onwards assumes fixed width vectors, but it's
  // possible that some of the combinations could be made to work for scalable
  // vectors too.
  unsigned NumElts = VecVT.getVectorNumElements();
  unsigned VecEltBitWidth = VecVT.getScalarSizeInBits();

  // TODO: These transforms should not require the 'hasOneUse' restriction, but
  // there are regressions on multiple targets without it. We can end up with a
  // mess of scalar and vector code if we reduce only part of the DAG to scalar.
  if (IndexC && VecOp.getOpcode() == ISD::BITCAST && VecVT.isInteger() &&
      VecOp.hasOneUse()) {
    // The vector index of the LSBs of the source depends on the endianness.
18780     bool IsLE = DAG.getDataLayout().isLittleEndian();
18781     unsigned ExtractIndex = IndexC->getZExtValue();
18782     // extract_elt (v2i32 (bitcast i64:x)), BCTruncElt -> i32 (trunc i64:x)
18783     unsigned BCTruncElt = IsLE ? 0 : NumElts - 1;
18784     SDValue BCSrc = VecOp.getOperand(0);
18785     if (ExtractIndex == BCTruncElt && BCSrc.getValueType().isScalarInteger())
18786       return DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, BCSrc);
18787 
18788     if (LegalTypes && BCSrc.getValueType().isInteger() &&
18789         BCSrc.getOpcode() == ISD::SCALAR_TO_VECTOR) {
18790       // ext_elt (bitcast (scalar_to_vec i64 X to v2i64) to v4i32), TruncElt -->
18791       // trunc i64 X to i32
18792       SDValue X = BCSrc.getOperand(0);
18793       assert(X.getValueType().isScalarInteger() && ScalarVT.isScalarInteger() &&
18794              "Extract element and scalar to vector can't change element type "
18795              "from FP to integer.");
18796       unsigned XBitWidth = X.getValueSizeInBits();
18797       BCTruncElt = IsLE ? 0 : XBitWidth / VecEltBitWidth - 1;
18798 
18799       // An extract element return value type can be wider than its vector
18800       // operand element type. In that case, the high bits are undefined, so
18801       // it's possible that we may need to extend rather than truncate.
18802       if (ExtractIndex == BCTruncElt && XBitWidth > VecEltBitWidth) {
18803         assert(XBitWidth % VecEltBitWidth == 0 &&
18804                "Scalar bitwidth must be a multiple of vector element bitwidth");
18805         return DAG.getAnyExtOrTrunc(X, DL, ScalarVT);
18806       }
18807     }
18808   }
18809 
18810   if (SDValue BO = scalarizeExtractedBinop(N, DAG, LegalOperations))
18811     return BO;
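  // Illustrative sketch of the kind of fold attempted above; the exact
  // preconditions (constant operand, legality, profitability) live in
  // scalarizeExtractedBinop:
  //   extractelt (binop X, C), IndexC --> binop (extractelt X, IndexC), C'
  // e.g. (i32 extract_vector_elt (add v4i32:X, <1,2,3,4>), 2)
  //        --> (i32 add (extract_vector_elt X, 2), 3)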
18812 
18813   // Transform: (EXTRACT_VECTOR_ELT( VECTOR_SHUFFLE )) -> EXTRACT_VECTOR_ELT.
18814   // We only perform this optimization before the op legalization phase because
18815   // we may introduce new vector instructions which are not backed by TD
18816   // patterns; for example, on AVX it could create element extracts from a
18817   // wide vector without using extract_subvector. However, if we can find an
18818   // underlying scalar value, then we can always use that.
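  // Illustrative example of the fold below, assuming a hypothetical shuffle
  // that sources lane 0 from lane 2 of its first input:
  //   (extract_vector_elt (vector_shuffle<2,u,u,u> X, Y), 0)
  //     --> (extract_vector_elt X, 2)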
18819   if (IndexC && VecOp.getOpcode() == ISD::VECTOR_SHUFFLE) {
18820     auto *Shuf = cast<ShuffleVectorSDNode>(VecOp);
18821     // Find the new index to extract from.
18822     int OrigElt = Shuf->getMaskElt(IndexC->getZExtValue());
18823 
18824     // Extracting an undef index is undef.
18825     if (OrigElt == -1)
18826       return DAG.getUNDEF(ScalarVT);
18827 
18828     // Select the right vector half to extract from.
18829     SDValue SVInVec;
18830     if (OrigElt < (int)NumElts) {
18831       SVInVec = VecOp.getOperand(0);
18832     } else {
18833       SVInVec = VecOp.getOperand(1);
18834       OrigElt -= NumElts;
18835     }
18836 
18837     if (SVInVec.getOpcode() == ISD::BUILD_VECTOR) {
18838       SDValue InOp = SVInVec.getOperand(OrigElt);
18839       if (InOp.getValueType() != ScalarVT) {
18840         assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
18841         InOp = DAG.getSExtOrTrunc(InOp, DL, ScalarVT);
18842       }
18843 
18844       return InOp;
18845     }
18846 
18847     // FIXME: We should handle recursing on other vector shuffles and
18848     // scalar_to_vector here as well.
18849 
18850     if (!LegalOperations ||
18851         // FIXME: Should really be just isOperationLegalOrCustom.
18852         TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VecVT) ||
18853         TLI.isOperationExpand(ISD::VECTOR_SHUFFLE, VecVT)) {
18854       return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, SVInVec,
18855                          DAG.getVectorIdxConstant(OrigElt, DL));
18856     }
18857   }
18858 
18859   // If only EXTRACT_VECTOR_ELT nodes use the source vector we can
18860   // simplify it based on the (valid) extraction indices.
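  // Illustrative sketch: if a v4i32 source is only ever read at constant
  // indices 0 and 2, DemandedElts becomes 0b0101 and the remaining lanes
  // may be turned into undef, possibly simplifying the source node.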
18861   if (llvm::all_of(VecOp->uses(), [&](SDNode *Use) {
18862         return Use->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
18863                Use->getOperand(0) == VecOp &&
18864                isa<ConstantSDNode>(Use->getOperand(1));
18865       })) {
18866     APInt DemandedElts = APInt::getNullValue(NumElts);
18867     for (SDNode *Use : VecOp->uses()) {
18868       auto *CstElt = cast<ConstantSDNode>(Use->getOperand(1));
18869       if (CstElt->getAPIntValue().ult(NumElts))
18870         DemandedElts.setBit(CstElt->getZExtValue());
18871     }
18872     if (SimplifyDemandedVectorElts(VecOp, DemandedElts, true)) {
18873       // We simplified the vector operand of this extract element. If this
18874       // extract is not dead, visit it again so it is folded properly.
18875       if (N->getOpcode() != ISD::DELETED_NODE)
18876         AddToWorklist(N);
18877       return SDValue(N, 0);
18878     }
18879     APInt DemandedBits = APInt::getAllOnesValue(VecEltBitWidth);
18880     if (SimplifyDemandedBits(VecOp, DemandedBits, DemandedElts, true)) {
18881       // We simplified the vector operand of this extract element. If this
18882       // extract is not dead, visit it again so it is folded properly.
18883       if (N->getOpcode() != ISD::DELETED_NODE)
18884         AddToWorklist(N);
18885       return SDValue(N, 0);
18886     }
18887   }
18888 
18889   // Everything under here is trying to match an extract of a loaded value.
18890   // If the result of the load has to be truncated, then it's not necessarily
18891   // profitable.
18892   bool BCNumEltsChanged = false;
18893   EVT ExtVT = VecVT.getVectorElementType();
18894   EVT LVT = ExtVT;
18895   if (ScalarVT.bitsLT(LVT) && !TLI.isTruncateFree(LVT, ScalarVT))
18896     return SDValue();
18897 
18898   if (VecOp.getOpcode() == ISD::BITCAST) {
18899     // Don't duplicate a load with other uses.
18900     if (!VecOp.hasOneUse())
18901       return SDValue();
18902 
18903     EVT BCVT = VecOp.getOperand(0).getValueType();
18904     if (!BCVT.isVector() || ExtVT.bitsGT(BCVT.getVectorElementType()))
18905       return SDValue();
18906     if (NumElts != BCVT.getVectorNumElements())
18907       BCNumEltsChanged = true;
18908     VecOp = VecOp.getOperand(0);
18909     ExtVT = BCVT.getVectorElementType();
18910   }
18911 
18912   // extract (vector load $addr), i --> load $addr + i * size
18913   if (!LegalOperations && !IndexC && VecOp.hasOneUse() &&
18914       ISD::isNormalLoad(VecOp.getNode()) &&
18915       !Index->hasPredecessor(VecOp.getNode())) {
18916     auto *VecLoad = dyn_cast<LoadSDNode>(VecOp);
18917     if (VecLoad && VecLoad->isSimple())
18918       return scalarizeExtractedVectorLoad(N, VecVT, Index, VecLoad);
18919   }
18920 
18921   // Perform only after legalization to ensure build_vector / vector_shuffle
18922   // optimizations have already been done.
18923   if (!LegalOperations || !IndexC)
18924     return SDValue();
18925 
18926   // (vextract (v4f32 load $addr), c) -> (f32 load $addr+c*size)
18927   // (vextract (v4f32 s2v (f32 load $addr)), c) -> (f32 load $addr+c*size)
18928   // (vextract (v4f32 shuffle (load $addr), <1,u,u,u>), 0) -> (f32 load $addr)
18929   int Elt = IndexC->getZExtValue();
18930   LoadSDNode *LN0 = nullptr;
18931   if (ISD::isNormalLoad(VecOp.getNode())) {
18932     LN0 = cast<LoadSDNode>(VecOp);
18933   } else if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR &&
18934              VecOp.getOperand(0).getValueType() == ExtVT &&
18935              ISD::isNormalLoad(VecOp.getOperand(0).getNode())) {
18936     // Don't duplicate a load with other uses.
18937     if (!VecOp.hasOneUse())
18938       return SDValue();
18939 
18940     LN0 = cast<LoadSDNode>(VecOp.getOperand(0));
18941   }
18942   if (auto *Shuf = dyn_cast<ShuffleVectorSDNode>(VecOp)) {
18943     // (vextract (vector_shuffle (load $addr), v2, <1, u, u, u>), 1)
18944     // =>
18945     // (load $addr+1*size)
18946 
18947     // Don't duplicate a load with other uses.
18948     if (!VecOp.hasOneUse())
18949       return SDValue();
18950 
18951     // If the bit convert changed the number of elements, it is unsafe
18952     // to examine the mask.
18953     if (BCNumEltsChanged)
18954       return SDValue();
18955 
18956     // Select the input vector, guarding against an out-of-range extract index.
18957     int Idx = (Elt > (int)NumElts) ? -1 : Shuf->getMaskElt(Elt);
18958     VecOp = (Idx < (int)NumElts) ? VecOp.getOperand(0) : VecOp.getOperand(1);
18959 
18960     if (VecOp.getOpcode() == ISD::BITCAST) {
18961       // Don't duplicate a load with other uses.
18962       if (!VecOp.hasOneUse())
18963         return SDValue();
18964 
18965       VecOp = VecOp.getOperand(0);
18966     }
18967     if (ISD::isNormalLoad(VecOp.getNode())) {
18968       LN0 = cast<LoadSDNode>(VecOp);
18969       Elt = (Idx < (int)NumElts) ? Idx : Idx - (int)NumElts;
18970       Index = DAG.getConstant(Elt, DL, Index.getValueType());
18971     }
18972   } else if (VecOp.getOpcode() == ISD::CONCAT_VECTORS && !BCNumEltsChanged &&
18973              VecVT.getVectorElementType() == ScalarVT &&
18974              (!LegalTypes ||
18975               TLI.isTypeLegal(
18976                   VecOp.getOperand(0).getValueType().getVectorElementType()))) {
18977     // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 0
18978     //      -> extract_vector_elt a, 0
18979     // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 1
18980     //      -> extract_vector_elt a, 1
18981     // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 2
18982     //      -> extract_vector_elt b, 0
18983     // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 3
18984     //      -> extract_vector_elt b, 1
18985     SDLoc SL(N);
18986     EVT ConcatVT = VecOp.getOperand(0).getValueType();
18987     unsigned ConcatNumElts = ConcatVT.getVectorNumElements();
18988     SDValue NewIdx = DAG.getConstant(Elt % ConcatNumElts, SL,
18989                                      Index.getValueType());
18990 
18991     SDValue ConcatOp = VecOp.getOperand(Elt / ConcatNumElts);
18992     SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL,
18993                               ConcatVT.getVectorElementType(),
18994                               ConcatOp, NewIdx);
18995     return DAG.getNode(ISD::BITCAST, SL, ScalarVT, Elt);
18996   }
18997 
18998   // Make sure we found a non-volatile load and the extractelement is
18999   // the only use.
19000   if (!LN0 || !LN0->hasNUsesOfValue(1, 0) || !LN0->isSimple())
19001     return SDValue();
19002 
19003   // If Idx was -1 above, Elt is going to be -1, so just return undef.
19004   if (Elt == -1)
19005     return DAG.getUNDEF(LVT);
19006 
19007   return scalarizeExtractedVectorLoad(N, VecVT, Index, LN0);
19008 }
19009 
19010 // Simplify (build_vec (ext )) to (bitcast (build_vec ))
19011 SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) {
19012   // We perform this optimization post type-legalization because
19013   // the type-legalizer often scalarizes integer-promoted vectors.
19014   // Performing this optimization before may create bit-casts which
19015   // will be type-legalized to complex code sequences.
19016   // We perform this optimization only before the operation legalizer because we
19017   // may introduce illegal operations.
19018   if (Level != AfterLegalizeVectorOps && Level != AfterLegalizeTypes)
19019     return SDValue();
19020 
19021   unsigned NumInScalars = N->getNumOperands();
19022   SDLoc DL(N);
19023   EVT VT = N->getValueType(0);
19024 
19025   // Check to see if this is a BUILD_VECTOR of a bunch of values
19026   // which come from any_extend or zero_extend nodes. If so, we can create
19027   // a new BUILD_VECTOR using bit-casts which may enable other BUILD_VECTOR
19028   // optimizations. We do not handle sign-extend because we can't fill the sign
19029   // using shuffles.
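  // Illustrative example (little-endian, hypothetical values):
  //   (v4i32 build_vector (zext i16:a), (zext i16:b),
  //                       (zext i16:c), (zext i16:d))
  //     --> (v4i32 bitcast (v8i16 build_vector a, 0, b, 0, c, 0, d, 0))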
19030   EVT SourceType = MVT::Other;
19031   bool AllAnyExt = true;
19032 
19033   for (unsigned i = 0; i != NumInScalars; ++i) {
19034     SDValue In = N->getOperand(i);
19035     // Ignore undef inputs.
19036     if (In.isUndef()) continue;
19037 
19038     bool AnyExt  = In.getOpcode() == ISD::ANY_EXTEND;
19039     bool ZeroExt = In.getOpcode() == ISD::ZERO_EXTEND;
19040 
19041     // Abort if the element is not an extension.
19042     if (!ZeroExt && !AnyExt) {
19043       SourceType = MVT::Other;
19044       break;
19045     }
19046 
19047     // The input is a ZeroExt or AnyExt. Check the original type.
19048     EVT InTy = In.getOperand(0).getValueType();
19049 
19050     // Check that all of the widened source types are the same.
19051     if (SourceType == MVT::Other)
19052       // First time.
19053       SourceType = InTy;
19054     else if (InTy != SourceType) {
19055       // Multiple incoming types. Abort.
19056       SourceType = MVT::Other;
19057       break;
19058     }
19059 
19060     // Check if all of the extends are ANY_EXTENDs.
19061     AllAnyExt &= AnyExt;
19062   }
19063 
19064   // In order to have valid types, all of the inputs must be extended from the
19065   // same source type and all of the inputs must be any or zero extend.
19066   // Scalar sizes must be a power of two.
19067   EVT OutScalarTy = VT.getScalarType();
19068   bool ValidTypes = SourceType != MVT::Other &&
19069                     isPowerOf2_32(OutScalarTy.getSizeInBits()) &&
19070                     isPowerOf2_32(SourceType.getSizeInBits());
19071 
19072   // Create a new simpler BUILD_VECTOR sequence which other optimizations can
19073   // turn into a single shuffle instruction.
19074   if (!ValidTypes)
19075     return SDValue();
19076 
19077   // If we already have a splat buildvector, then don't fold it if it means
19078   // introducing zeros.
19079   if (!AllAnyExt && DAG.isSplatValue(SDValue(N, 0), /*AllowUndefs*/ true))
19080     return SDValue();
19081 
19082   bool isLE = DAG.getDataLayout().isLittleEndian();
19083   unsigned ElemRatio = OutScalarTy.getSizeInBits()/SourceType.getSizeInBits();
19084   assert(ElemRatio > 1 && "Invalid element size ratio");
19085   SDValue Filler = AllAnyExt ? DAG.getUNDEF(SourceType)
19086                              : DAG.getConstant(0, DL, SourceType);
19087 
19088   unsigned NewBVElems = ElemRatio * VT.getVectorNumElements();
19089   SmallVector<SDValue, 8> Ops(NewBVElems, Filler);
19090 
19091   // Populate the new build_vector
19092   for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
19093     SDValue Cast = N->getOperand(i);
19094     assert((Cast.getOpcode() == ISD::ANY_EXTEND ||
19095             Cast.getOpcode() == ISD::ZERO_EXTEND ||
19096             Cast.isUndef()) && "Invalid cast opcode");
19097     SDValue In;
19098     if (Cast.isUndef())
19099       In = DAG.getUNDEF(SourceType);
19100     else
19101       In = Cast->getOperand(0);
19102     unsigned Index = isLE ? (i * ElemRatio) :
19103                             (i * ElemRatio + (ElemRatio - 1));
19104 
19105     assert(Index < Ops.size() && "Invalid index");
19106     Ops[Index] = In;
19107   }
19108 
19109   // The type of the new BUILD_VECTOR node.
19110   EVT VecVT = EVT::getVectorVT(*DAG.getContext(), SourceType, NewBVElems);
19111   assert(VecVT.getSizeInBits() == VT.getSizeInBits() &&
19112          "Invalid vector size");
19113   // Check if the new vector type is legal.
19114   if (!isTypeLegal(VecVT) ||
19115       (!TLI.isOperationLegal(ISD::BUILD_VECTOR, VecVT) &&
19116        TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)))
19117     return SDValue();
19118 
19119   // Make the new BUILD_VECTOR.
19120   SDValue BV = DAG.getBuildVector(VecVT, DL, Ops);
19121 
19122   // The new BUILD_VECTOR node has the potential to be further optimized.
19123   AddToWorklist(BV.getNode());
19124   // Bitcast to the desired type.
19125   return DAG.getBitcast(VT, BV);
19126 }
19127 
19128 // Simplify (build_vec (trunc $1)
19129 //                     (trunc (srl $1 half-width))
19130 //                     (trunc (srl $1 (2 * half-width))) …)
19131 // to (bitcast $1)
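//
// Illustrative example (little-endian, hypothetical values):
//   (v2i32 build_vector (trunc i64:x), (trunc (srl i64:x, 32)))
//     --> (v2i32 bitcast i64:x)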
19132 SDValue DAGCombiner::reduceBuildVecTruncToBitCast(SDNode *N) {
19133   assert(N->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");
19134 
19135   // Only for little endian
19136   if (!DAG.getDataLayout().isLittleEndian())
19137     return SDValue();
19138 
19139   SDLoc DL(N);
19140   EVT VT = N->getValueType(0);
19141   EVT OutScalarTy = VT.getScalarType();
19142   uint64_t ScalarTypeBitsize = OutScalarTy.getSizeInBits();
19143 
19144   // Only for power of two types to be sure that bitcast works well
19145   if (!isPowerOf2_64(ScalarTypeBitsize))
19146     return SDValue();
19147 
19148   unsigned NumInScalars = N->getNumOperands();
19149 
19150   // Look through bitcasts
19151   auto PeekThroughBitcast = [](SDValue Op) {
19152     if (Op.getOpcode() == ISD::BITCAST)
19153       return Op.getOperand(0);
19154     return Op;
19155   };
19156 
19157   // The source value where all the parts are extracted.
19158   SDValue Src;
19159   for (unsigned i = 0; i != NumInScalars; ++i) {
19160     SDValue In = PeekThroughBitcast(N->getOperand(i));
19161     // Ignore undef inputs.
19162     if (In.isUndef()) continue;
19163 
19164     if (In.getOpcode() != ISD::TRUNCATE)
19165       return SDValue();
19166 
19167     In = PeekThroughBitcast(In.getOperand(0));
19168 
19169     if (In.getOpcode() != ISD::SRL) {
19170       // For now, handle only build_vec without shuffling; handle shifts here
19171       // in the future.
19172       if (i != 0)
19173         return SDValue();
19174 
19175       Src = In;
19176     } else {
19177       // In is SRL
19178       SDValue part = PeekThroughBitcast(In.getOperand(0));
19179 
19180       if (!Src) {
19181         Src = part;
19182       } else if (Src != part) {
19183         // Vector parts do not stem from the same variable
19184         return SDValue();
19185       }
19186 
19187       SDValue ShiftAmtVal = In.getOperand(1);
19188       if (!isa<ConstantSDNode>(ShiftAmtVal))
19189         return SDValue();
19190 
19191       uint64_t ShiftAmt = In.getNode()->getConstantOperandVal(1);
19192 
19193       // The extracted value is not extracted at the right position
19194       if (ShiftAmt != i * ScalarTypeBitsize)
19195         return SDValue();
19196     }
19197   }
19198 
19199   // Only cast if the size is the same
19200   if (Src.getValueType().getSizeInBits() != VT.getSizeInBits())
19201     return SDValue();
19202 
19203   return DAG.getBitcast(VT, Src);
19204 }
19205 
19206 SDValue DAGCombiner::createBuildVecShuffle(const SDLoc &DL, SDNode *N,
19207                                            ArrayRef<int> VectorMask,
19208                                            SDValue VecIn1, SDValue VecIn2,
19209                                            unsigned LeftIdx, bool DidSplitVec) {
19210   SDValue ZeroIdx = DAG.getVectorIdxConstant(0, DL);
19211 
19212   EVT VT = N->getValueType(0);
19213   EVT InVT1 = VecIn1.getValueType();
19214   EVT InVT2 = VecIn2.getNode() ? VecIn2.getValueType() : InVT1;
19215 
19216   unsigned NumElems = VT.getVectorNumElements();
19217   unsigned ShuffleNumElems = NumElems;
19218 
19219   // If we artificially split a vector in two already, then the offsets in the
19220   // operands will all be based off of VecIn1, even those in VecIn2.
19221   unsigned Vec2Offset = DidSplitVec ? 0 : InVT1.getVectorNumElements();
19222 
19223   uint64_t VTSize = VT.getFixedSizeInBits();
19224   uint64_t InVT1Size = InVT1.getFixedSizeInBits();
19225   uint64_t InVT2Size = InVT2.getFixedSizeInBits();
19226 
19227   assert(InVT2Size <= InVT1Size &&
19228          "Inputs must be sorted to be in non-increasing vector size order.");
19229 
19230   // We can't generate a shuffle node with mismatched input and output types.
19231   // Try to make the types match the type of the output.
19232   if (InVT1 != VT || InVT2 != VT) {
19233     if ((VTSize % InVT1Size == 0) && InVT1 == InVT2) {
19234       // If the output vector length is a multiple of both input lengths,
19235       // we can concatenate them and pad the rest with undefs.
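      // Illustrative example: VT = v8i32 with InVT1 = InVT2 = v2i32 gives
      // NumConcats = 4 and produces
      //   (v8i32 concat_vectors VecIn1, VecIn2, undef, undef).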
19236       unsigned NumConcats = VTSize / InVT1Size;
19237       assert(NumConcats >= 2 && "Concat needs at least two inputs!");
19238       SmallVector<SDValue, 2> ConcatOps(NumConcats, DAG.getUNDEF(InVT1));
19239       ConcatOps[0] = VecIn1;
19240       ConcatOps[1] = VecIn2 ? VecIn2 : DAG.getUNDEF(InVT1);
19241       VecIn1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
19242       VecIn2 = SDValue();
19243     } else if (InVT1Size == VTSize * 2) {
19244       if (!TLI.isExtractSubvectorCheap(VT, InVT1, NumElems))
19245         return SDValue();
19246 
19247       if (!VecIn2.getNode()) {
19248         // If we only have one input vector, and it's twice the size of the
19249         // output, split it in two.
19250         VecIn2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, VecIn1,
19251                              DAG.getVectorIdxConstant(NumElems, DL));
19252         VecIn1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, VecIn1, ZeroIdx);
19253         // Since we now have shorter input vectors, adjust the offset of the
19254         // second vector's start.
19255         Vec2Offset = NumElems;
19256       } else {
19257         assert(InVT2Size <= InVT1Size &&
19258                "Second input is not going to be larger than the first one.");
19259 
19260         // VecIn1 is wider than the output, and we have another, possibly
19261         // smaller input. Pad the smaller input with undefs, shuffle at the
19262         // input vector width, and extract the output.
19263         // The shuffle type is different than VT, so check legality again.
19264         if (LegalOperations &&
19265             !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, InVT1))
19266           return SDValue();
19267 
19268         // Legalizing INSERT_SUBVECTOR is tricky - you basically have to
19269         // lower it back into a BUILD_VECTOR. So if the inserted type is
19270         // illegal, don't even try.
19271         if (InVT1 != InVT2) {
19272           if (!TLI.isTypeLegal(InVT2))
19273             return SDValue();
19274           VecIn2 = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, InVT1,
19275                                DAG.getUNDEF(InVT1), VecIn2, ZeroIdx);
19276         }
19277         ShuffleNumElems = NumElems * 2;
19278       }
19279     } else if (InVT2Size * 2 == VTSize && InVT1Size == VTSize) {
19280       SmallVector<SDValue, 2> ConcatOps(2, DAG.getUNDEF(InVT2));
19281       ConcatOps[0] = VecIn2;
19282       VecIn2 = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
19283     } else {
19284       // TODO: Support cases where the length mismatch isn't exactly by a
19285       // factor of 2.
19286       // TODO: Move this check upwards, so that if we have bad type
19287       // mismatches, we don't create any DAG nodes.
19288       return SDValue();
19289     }
19290   }
19291 
19292   // Initialize mask to undef.
19293   SmallVector<int, 8> Mask(ShuffleNumElems, -1);
19294 
19295   // Only need to run up to the number of elements actually used, not the
19296   // total number of elements in the shuffle - if we are shuffling a wider
19297   // vector, the high lanes should be set to undef.
19298   for (unsigned i = 0; i != NumElems; ++i) {
19299     if (VectorMask[i] <= 0)
19300       continue;
19301 
19302     unsigned ExtIndex = N->getOperand(i).getConstantOperandVal(1);
19303     if (VectorMask[i] == (int)LeftIdx) {
19304       Mask[i] = ExtIndex;
19305     } else if (VectorMask[i] == (int)LeftIdx + 1) {
19306       Mask[i] = Vec2Offset + ExtIndex;
19307     }
19308   }
19309 
19310   // The types of the input vectors may have changed above.
19311   InVT1 = VecIn1.getValueType();
19312 
19313   // If we already have a VecIn2, it should have the same type as VecIn1.
19314   // If we don't, get an undef/zero vector of the appropriate type.
19315   VecIn2 = VecIn2.getNode() ? VecIn2 : DAG.getUNDEF(InVT1);
19316   assert(InVT1 == VecIn2.getValueType() && "Unexpected second input type.");
19317 
19318   SDValue Shuffle = DAG.getVectorShuffle(InVT1, DL, VecIn1, VecIn2, Mask);
19319   if (ShuffleNumElems > NumElems)
19320     Shuffle = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Shuffle, ZeroIdx);
19321 
19322   return Shuffle;
19323 }
19324 
19325 static SDValue reduceBuildVecToShuffleWithZero(SDNode *BV, SelectionDAG &DAG) {
19326   assert(BV->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");
19327 
19328   // First, determine where the build vector is not undef.
19329   // TODO: We could extend this to handle zero elements as well as undefs.
19330   int NumBVOps = BV->getNumOperands();
19331   int ZextElt = -1;
19332   for (int i = 0; i != NumBVOps; ++i) {
19333     SDValue Op = BV->getOperand(i);
19334     if (Op.isUndef())
19335       continue;
19336     if (ZextElt == -1)
19337       ZextElt = i;
19338     else
19339       return SDValue();
19340   }
19341   // Bail out if there's no non-undef element.
19342   if (ZextElt == -1)
19343     return SDValue();
19344 
19345   // The build vector contains some number of undef elements and exactly
19346   // one other element. That other element must be a zero-extended scalar
19347   // extracted from a vector at a constant index to turn this into a shuffle.
19348   // Also, require that the build vector does not implicitly truncate/extend
19349   // its elements.
19350   // TODO: This could be enhanced to allow ANY_EXTEND as well as ZERO_EXTEND.
19351   EVT VT = BV->getValueType(0);
19352   SDValue Zext = BV->getOperand(ZextElt);
19353   if (Zext.getOpcode() != ISD::ZERO_EXTEND || !Zext.hasOneUse() ||
19354       Zext.getOperand(0).getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
19355       !isa<ConstantSDNode>(Zext.getOperand(0).getOperand(1)) ||
19356       Zext.getValueSizeInBits() != VT.getScalarSizeInBits())
19357     return SDValue();
19358 
19359   // The zero-extend must be a multiple of the source size, and we must be
19360   // building a vector of the same size as the source of the extract element.
19361   SDValue Extract = Zext.getOperand(0);
19362   unsigned DestSize = Zext.getValueSizeInBits();
19363   unsigned SrcSize = Extract.getValueSizeInBits();
19364   if (DestSize % SrcSize != 0 ||
19365       Extract.getOperand(0).getValueSizeInBits() != VT.getSizeInBits())
19366     return SDValue();
19367 
19368   // Create a shuffle mask that will combine the extracted element with zeros
19369   // and undefs.
19370   int ZextRatio = DestSize / SrcSize;
19371   int NumMaskElts = NumBVOps * ZextRatio;
19372   SmallVector<int, 32> ShufMask(NumMaskElts, -1);
19373   for (int i = 0; i != NumMaskElts; ++i) {
19374     if (i / ZextRatio == ZextElt) {
19375       // The low bits of the (potentially translated) extracted element map to
19376       // the source vector. The high bits map to zero. We will use a zero vector
19377       // as the 2nd source operand of the shuffle, so use the 1st element of
19378       // that vector (mask value is number-of-elements) for the high bits.
19379       if (i % ZextRatio == 0)
19380         ShufMask[i] = Extract.getConstantOperandVal(1);
19381       else
19382         ShufMask[i] = NumMaskElts;
19383     }
19384 
19385     // Undef elements of the build vector remain undef because we initialize
19386     // the shuffle mask with -1.
19387   }
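  // Illustrative example with hypothetical values: NumBVOps = 2,
  // ZextRatio = 2, and ZextElt = 1 with extract index C yield
  // ShufMask = <-1,-1,C,4>, where 4 (== NumMaskElts) selects element 0 of
  // the zero vector used as the second shuffle operand.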
19388 
19389   // buildvec undef, ..., (zext (extractelt V, IndexC)), undef... -->
19390   // bitcast (shuffle V, ZeroVec, VectorMask)
19391   SDLoc DL(BV);
19392   EVT VecVT = Extract.getOperand(0).getValueType();
19393   SDValue ZeroVec = DAG.getConstant(0, DL, VecVT);
19394   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
19395   SDValue Shuf = TLI.buildLegalVectorShuffle(VecVT, DL, Extract.getOperand(0),
19396                                              ZeroVec, ShufMask, DAG);
19397   if (!Shuf)
19398     return SDValue();
19399   return DAG.getBitcast(VT, Shuf);
19400 }
19401 
19402 // FIXME: promote to STLExtras.
19403 template <typename R, typename T>
19404 static auto getFirstIndexOf(R &&Range, const T &Val) {
19405   auto I = find(Range, Val);
19406   if (I == Range.end())
19407     return static_cast<decltype(std::distance(Range.begin(), I))>(-1);
19408   return std::distance(Range.begin(), I);
19409 }
19410 
19411 // Check to see if this is a BUILD_VECTOR of a bunch of EXTRACT_VECTOR_ELT
19412 // operations. If the types of the vectors we're extracting from allow it,
19413 // turn this into a vector_shuffle node.
19414 SDValue DAGCombiner::reduceBuildVecToShuffle(SDNode *N) {
19415   SDLoc DL(N);
19416   EVT VT = N->getValueType(0);
19417 
19418   // Only type-legal BUILD_VECTOR nodes are converted to shuffle nodes.
19419   if (!isTypeLegal(VT))
19420     return SDValue();
19421 
19422   if (SDValue V = reduceBuildVecToShuffleWithZero(N, DAG))
19423     return V;
19424 
19425   // May only combine to shuffle after legalize if shuffle is legal.
19426   if (LegalOperations && !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, VT))
19427     return SDValue();
19428 
19429   bool UsesZeroVector = false;
19430   unsigned NumElems = N->getNumOperands();
19431 
19432   // Record, for each element of the newly built vector, which input vector
19433   // that element comes from. -1 stands for undef, 0 for the zero vector,
19434   // and positive values for the input vectors.
19435   // VectorMask maps each element to its vector number, and VecIn maps vector
19436   // numbers to their initial SDValues.
19437 
19438   SmallVector<int, 8> VectorMask(NumElems, -1);
19439   SmallVector<SDValue, 8> VecIn;
19440   VecIn.push_back(SDValue());
19441 
19442   for (unsigned i = 0; i != NumElems; ++i) {
19443     SDValue Op = N->getOperand(i);
19444 
19445     if (Op.isUndef())
19446       continue;
19447 
19448     // See if we can use a blend with a zero vector.
19449     // TODO: Should we generalize this to a blend with an arbitrary constant
19450     // vector?
19451     if (isNullConstant(Op) || isNullFPConstant(Op)) {
19452       UsesZeroVector = true;
19453       VectorMask[i] = 0;
19454       continue;
19455     }
19456 
19457     // Not an undef or zero. If the input is something other than an
19458     // EXTRACT_VECTOR_ELT with an in-range constant index, bail out.
19459     if (Op.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
19460         !isa<ConstantSDNode>(Op.getOperand(1)))
19461       return SDValue();
19462     SDValue ExtractedFromVec = Op.getOperand(0);
19463 
19464     if (ExtractedFromVec.getValueType().isScalableVector())
19465       return SDValue();
19466 
19467     const APInt &ExtractIdx = Op.getConstantOperandAPInt(1);
19468     if (ExtractIdx.uge(ExtractedFromVec.getValueType().getVectorNumElements()))
19469       return SDValue();
19470 
19471     // All inputs must have the same element type as the output.
19472     if (VT.getVectorElementType() !=
19473         ExtractedFromVec.getValueType().getVectorElementType())
19474       return SDValue();
19475 
19476     // Have we seen this input vector before?
19477     // The vectors are expected to be tiny (usually 1 or 2 elements), so using
19478     // a map back from SDValues to numbers isn't worth it.
19479     int Idx = getFirstIndexOf(VecIn, ExtractedFromVec);
19480     if (Idx == -1) { // A new source vector?
19481       Idx = VecIn.size();
19482       VecIn.push_back(ExtractedFromVec);
19483     }
19484 
19485     VectorMask[i] = Idx;
19486   }
19487 
19488   // If we didn't find at least one input vector, bail out.
19489   if (VecIn.size() < 2)
19490     return SDValue();
19491 
19492   // If all the operands of the BUILD_VECTOR extract from the same
19493   // vector, then split that vector efficiently based on the maximum
19494   // vector access index and adjust the VectorMask and
19495   // VecIn accordingly.
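  // Illustrative example: a v2i32 build_vector extracting elements 0 and 5
  // of a v16i32 source gives MaxIndex = 5, NearestPow2 = 8, and
  // SplitSize = 4, so the source is split into two v4i32 halves and the
  // VectorMask entries are remapped to point at them.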
19496   bool DidSplitVec = false;
19497   if (VecIn.size() == 2) {
19498     unsigned MaxIndex = 0;
19499     unsigned NearestPow2 = 0;
19500     SDValue Vec = VecIn.back();
19501     EVT InVT = Vec.getValueType();
19502     SmallVector<unsigned, 8> IndexVec(NumElems, 0);
19503 
19504     for (unsigned i = 0; i < NumElems; i++) {
19505       if (VectorMask[i] <= 0)
19506         continue;
19507       unsigned Index = N->getOperand(i).getConstantOperandVal(1);
19508       IndexVec[i] = Index;
19509       MaxIndex = std::max(MaxIndex, Index);
19510     }
19511 
19512     NearestPow2 = PowerOf2Ceil(MaxIndex);
19513     if (InVT.isSimple() && NearestPow2 > 2 && MaxIndex < NearestPow2 &&
19514         NumElems * 2 < NearestPow2) {
19515       unsigned SplitSize = NearestPow2 / 2;
19516       EVT SplitVT = EVT::getVectorVT(*DAG.getContext(),
19517                                      InVT.getVectorElementType(), SplitSize);
19518       if (TLI.isTypeLegal(SplitVT) &&
19519           SplitSize + SplitVT.getVectorNumElements() <=
19520               InVT.getVectorNumElements()) {
19521         SDValue VecIn2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SplitVT, Vec,
19522                                      DAG.getVectorIdxConstant(SplitSize, DL));
19523         SDValue VecIn1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SplitVT, Vec,
19524                                      DAG.getVectorIdxConstant(0, DL));
19525         VecIn.pop_back();
19526         VecIn.push_back(VecIn1);
19527         VecIn.push_back(VecIn2);
19528         DidSplitVec = true;
19529 
19530         for (unsigned i = 0; i < NumElems; i++) {
19531           if (VectorMask[i] <= 0)
19532             continue;
19533           VectorMask[i] = (IndexVec[i] < SplitSize) ? 1 : 2;
19534         }
19535       }
19536     }
19537   }
19538 
19539   // Sort input vectors by decreasing vector element count,
19540   // while preserving the relative order of equally-sized vectors.
19541   // Note that we keep the first "implicit" zero vector as-is.
19542   SmallVector<SDValue, 8> SortedVecIn(VecIn);
19543   llvm::stable_sort(MutableArrayRef<SDValue>(SortedVecIn).drop_front(),
19544                     [](const SDValue &a, const SDValue &b) {
19545                       return a.getValueType().getVectorNumElements() >
19546                              b.getValueType().getVectorNumElements();
19547                     });
19548 
19549   // We now also need to rebuild the VectorMask, because it referenced element
19550   // order in VecIn, and we just sorted them.
19551   for (int &SourceVectorIndex : VectorMask) {
19552     if (SourceVectorIndex <= 0)
19553       continue;
19554     unsigned Idx = getFirstIndexOf(SortedVecIn, VecIn[SourceVectorIndex]);
19555     assert(Idx > 0 && Idx < SortedVecIn.size() &&
19556            VecIn[SourceVectorIndex] == SortedVecIn[Idx] && "Remapping failure");
19557     SourceVectorIndex = Idx;
19558   }
19559 
19560   VecIn = std::move(SortedVecIn);
19561 
19562   // TODO: Should this fire if some of the input vectors has illegal type (like
19563   // it does now), or should we let legalization run its course first?
19564 
19565   // Shuffle phase:
19566   // Take pairs of vectors, and shuffle them so that the result has elements
19567   // from these vectors in the correct places.
19568   // For example, given:
19569   // t10: i32 = extract_vector_elt t1, Constant:i64<0>
19570   // t11: i32 = extract_vector_elt t2, Constant:i64<0>
19571   // t12: i32 = extract_vector_elt t3, Constant:i64<0>
19572   // t13: i32 = extract_vector_elt t1, Constant:i64<1>
19573   // t14: v4i32 = BUILD_VECTOR t10, t11, t12, t13
19574   // We will generate:
19575   // t20: v4i32 = vector_shuffle<0,4,u,1> t1, t2
19576   // t21: v4i32 = vector_shuffle<u,u,0,u> t3, undef
19577   SmallVector<SDValue, 4> Shuffles;
19578   for (unsigned In = 0, Len = (VecIn.size() / 2); In < Len; ++In) {
19579     unsigned LeftIdx = 2 * In + 1;
19580     SDValue VecLeft = VecIn[LeftIdx];
19581     SDValue VecRight =
19582         (LeftIdx + 1) < VecIn.size() ? VecIn[LeftIdx + 1] : SDValue();
19583 
19584     if (SDValue Shuffle = createBuildVecShuffle(DL, N, VectorMask, VecLeft,
19585                                                 VecRight, LeftIdx, DidSplitVec))
19586       Shuffles.push_back(Shuffle);
19587     else
19588       return SDValue();
19589   }
19590 
19591   // If we need the zero vector as an "ingredient" in the blend tree, add it
19592   // to the list of shuffles.
19593   if (UsesZeroVector)
19594     Shuffles.push_back(VT.isInteger() ? DAG.getConstant(0, DL, VT)
19595                                       : DAG.getConstantFP(0.0, DL, VT));
19596 
19597   // If we only have one shuffle, we're done.
19598   if (Shuffles.size() == 1)
19599     return Shuffles[0];
19600 
19601   // Update the vector mask to point to the post-shuffle vectors.
19602   for (int &Vec : VectorMask)
19603     if (Vec == 0)
19604       Vec = Shuffles.size() - 1;
19605     else
19606       Vec = (Vec - 1) / 2;
19607 
19608   // More than one shuffle. Generate a binary tree of blends, e.g. if from
19609   // the previous step we got the set of shuffles t10, t11, t12, t13, we will
19610   // generate:
19611   // t10: v8i32 = vector_shuffle<0,8,u,u,u,u,u,u> t1, t2
19612   // t11: v8i32 = vector_shuffle<u,u,0,8,u,u,u,u> t3, t4
19613   // t12: v8i32 = vector_shuffle<u,u,u,u,0,8,u,u> t5, t6
19614   // t13: v8i32 = vector_shuffle<u,u,u,u,u,u,0,8> t7, t8
19615   // t20: v8i32 = vector_shuffle<0,1,10,11,u,u,u,u> t10, t11
19616   // t21: v8i32 = vector_shuffle<u,u,u,u,4,5,14,15> t12, t13
19617   // t30: v8i32 = vector_shuffle<0,1,2,3,12,13,14,15> t20, t21
19618 
19619   // Make sure the initial size of the shuffle list is even.
19620   if (Shuffles.size() % 2)
19621     Shuffles.push_back(DAG.getUNDEF(VT));
19622 
19623   for (unsigned CurSize = Shuffles.size(); CurSize > 1; CurSize /= 2) {
19624     if (CurSize % 2) {
19625       Shuffles[CurSize] = DAG.getUNDEF(VT);
19626       CurSize++;
19627     }
19628     for (unsigned In = 0, Len = CurSize / 2; In < Len; ++In) {
19629       int Left = 2 * In;
19630       int Right = 2 * In + 1;
19631       SmallVector<int, 8> Mask(NumElems, -1);
19632       for (unsigned i = 0; i != NumElems; ++i) {
19633         if (VectorMask[i] == Left) {
19634           Mask[i] = i;
19635           VectorMask[i] = In;
19636         } else if (VectorMask[i] == Right) {
19637           Mask[i] = i + NumElems;
19638           VectorMask[i] = In;
19639         }
19640       }
19641 
19642       Shuffles[In] =
19643           DAG.getVectorShuffle(VT, DL, Shuffles[Left], Shuffles[Right], Mask);
19644     }
19645   }
19646   return Shuffles[0];
19647 }
19648 
19649 // Try to turn a build vector of zero extends of extract vector elts into a
19650 // vector zero extend and possibly an extract subvector.
19651 // TODO: Support sign extend?
19652 // TODO: Allow undef elements?
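// Illustrative example (hypothetical values):
//   (v4i32 build_vector (zext (extract_vector_elt v8i16:X, 4)),
//                       (zext (extractelt X, 5)), (zext (extractelt X, 6)),
//                       (zext (extractelt X, 7)))
//     --> (v4i32 zero_extend (v4i16 extract_subvector X, 4))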
19653 SDValue DAGCombiner::convertBuildVecZextToZext(SDNode *N) {
19654   if (LegalOperations)
19655     return SDValue();
19656 
19657   EVT VT = N->getValueType(0);
19658 
19659   bool FoundZeroExtend = false;
19660   SDValue Op0 = N->getOperand(0);
19661   auto checkElem = [&](SDValue Op) -> int64_t {
19662     unsigned Opc = Op.getOpcode();
19663     FoundZeroExtend |= (Opc == ISD::ZERO_EXTEND);
19664     if ((Opc == ISD::ZERO_EXTEND || Opc == ISD::ANY_EXTEND) &&
19665         Op.getOperand(0).getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
19666         Op0.getOperand(0).getOperand(0) == Op.getOperand(0).getOperand(0))
19667       if (auto *C = dyn_cast<ConstantSDNode>(Op.getOperand(0).getOperand(1)))
19668         return C->getZExtValue();
19669     return -1;
19670   };
19671 
19672   // Make sure the first element matches
19673   // (zext (extract_vector_elt X, C))
19674   int64_t Offset = checkElem(Op0);
19675   if (Offset < 0)
19676     return SDValue();
19677 
19678   unsigned NumElems = N->getNumOperands();
19679   SDValue In = Op0.getOperand(0).getOperand(0);
19680   EVT InSVT = In.getValueType().getScalarType();
19681   EVT InVT = EVT::getVectorVT(*DAG.getContext(), InSVT, NumElems);
19682 
19683   // Don't create an illegal input type after type legalization.
19684   if (LegalTypes && !TLI.isTypeLegal(InVT))
19685     return SDValue();
19686 
19687   // Ensure all the elements come from the same vector and are adjacent.
19688   for (unsigned i = 1; i != NumElems; ++i) {
19689     if ((Offset + i) != checkElem(N->getOperand(i)))
19690       return SDValue();
19691   }
19692 
19693   SDLoc DL(N);
19694   In = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InVT, In,
19695                    Op0.getOperand(0).getOperand(1));
19696   return DAG.getNode(FoundZeroExtend ? ISD::ZERO_EXTEND : ISD::ANY_EXTEND, DL,
19697                      VT, In);
19698 }
19699 
19700 SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) {
19701   EVT VT = N->getValueType(0);
19702 
19703   // A vector built entirely of undefs is undef.
19704   if (ISD::allOperandsUndef(N))
19705     return DAG.getUNDEF(VT);
19706 
19707   // If this is a splat of a bitcast from another vector, change to a
19708   // concat_vector.
19709   // For example:
19710   //   (build_vector (i64 (bitcast (v2i32 X))), (i64 (bitcast (v2i32 X)))) ->
19711   //     (v2i64 (bitcast (concat_vectors (v2i32 X), (v2i32 X))))
19712   //
19713   // If X is a build_vector itself, the concat can become a larger build_vector.
19714   // TODO: Maybe this is useful for non-splat too?
19715   if (!LegalOperations) {
19716     if (SDValue Splat = cast<BuildVectorSDNode>(N)->getSplatValue()) {
19717       Splat = peekThroughBitcasts(Splat);
19718       EVT SrcVT = Splat.getValueType();
19719       if (SrcVT.isVector()) {
19720         unsigned NumElts = N->getNumOperands() * SrcVT.getVectorNumElements();
19721         EVT NewVT = EVT::getVectorVT(*DAG.getContext(),
19722                                      SrcVT.getVectorElementType(), NumElts);
19723         if (!LegalTypes || TLI.isTypeLegal(NewVT)) {
19724           SmallVector<SDValue, 8> Ops(N->getNumOperands(), Splat);
19725           SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N),
19726                                        NewVT, Ops);
19727           return DAG.getBitcast(VT, Concat);
19728         }
19729       }
19730     }
19731   }
19732 
19733   // Check if we can express BUILD VECTOR via subvector extract.
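  // Illustrative example of the fold below (hypothetical values):
  //   (v2i32 build_vector (extractelt v4i32:X, 2), (extractelt X, 3))
  //     --> (v2i32 extract_subvector X, 2)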
19734   if (!LegalTypes && (N->getNumOperands() > 1)) {
19735     SDValue Op0 = N->getOperand(0);
19736     auto checkElem = [&](SDValue Op) -> uint64_t {
19737       if ((Op.getOpcode() == ISD::EXTRACT_VECTOR_ELT) &&
19738           (Op0.getOperand(0) == Op.getOperand(0)))
19739         if (auto CNode = dyn_cast<ConstantSDNode>(Op.getOperand(1)))
19740           return CNode->getZExtValue();
19741       return -1;
19742     };
19743 
19744     int Offset = checkElem(Op0);
19745     for (unsigned i = 0; i < N->getNumOperands(); ++i) {
19746       if (Offset + i != checkElem(N->getOperand(i))) {
19747         Offset = -1;
19748         break;
19749       }
19750     }
19751 
19752     if ((Offset == 0) &&
19753         (Op0.getOperand(0).getValueType() == N->getValueType(0)))
19754       return Op0.getOperand(0);
19755     if ((Offset != -1) &&
19756         ((Offset % N->getValueType(0).getVectorNumElements()) ==
19757          0)) // IDX must be multiple of output size.
19758       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), N->getValueType(0),
19759                          Op0.getOperand(0), Op0.getOperand(1));
19760   }
19761 
19762   if (SDValue V = convertBuildVecZextToZext(N))
19763     return V;
19764 
19765   if (SDValue V = reduceBuildVecExtToExtBuildVec(N))
19766     return V;
19767 
19768   if (SDValue V = reduceBuildVecTruncToBitCast(N))
19769     return V;
19770 
19771   if (SDValue V = reduceBuildVecToShuffle(N))
19772     return V;
19773 
19774   // A splat of a single element is a SPLAT_VECTOR if supported on the target.
19775   // Do this late as some of the above may replace the splat.
19776   if (TLI.getOperationAction(ISD::SPLAT_VECTOR, VT) != TargetLowering::Expand)
19777     if (SDValue V = cast<BuildVectorSDNode>(N)->getSplatValue()) {
19778       assert(!V.isUndef() && "Splat of undef should have been handled earlier");
19779       return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), VT, V);
19780     }
19781 
19782   return SDValue();
19783 }
19784 
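// Fold a CONCAT_VECTORS whose operands are all bitcasts from scalars (or
// undef) into a bitcast of a wider BUILD_VECTOR. Illustrative example
// (hypothetical values):
//   (v4i32 concat_vectors (v2i32 bitcast i64:a), (v2i32 bitcast i64:b))
//     --> (v4i32 bitcast (v2i64 build_vector a, b))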
19785 static SDValue combineConcatVectorOfScalars(SDNode *N, SelectionDAG &DAG) {
19786   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
19787   EVT OpVT = N->getOperand(0).getValueType();
19788 
19789   // If the operands are legal vectors, leave them alone.
19790   if (TLI.isTypeLegal(OpVT))
19791     return SDValue();
19792 
19793   SDLoc DL(N);
19794   EVT VT = N->getValueType(0);
19795   SmallVector<SDValue, 8> Ops;
19796 
19797   EVT SVT = EVT::getIntegerVT(*DAG.getContext(), OpVT.getSizeInBits());
19798   SDValue ScalarUndef = DAG.getNode(ISD::UNDEF, DL, SVT);
19799 
19800   // Keep track of what we encounter.
19801   bool AnyInteger = false;
19802   bool AnyFP = false;
19803   for (const SDValue &Op : N->ops()) {
19804     if (ISD::BITCAST == Op.getOpcode() &&
19805         !Op.getOperand(0).getValueType().isVector())
19806       Ops.push_back(Op.getOperand(0));
19807     else if (ISD::UNDEF == Op.getOpcode())
19808       Ops.push_back(ScalarUndef);
19809     else
19810       return SDValue();
19811 
19812     // Note whether we encounter an integer or floating point scalar.
19813     // If it's neither, bail out; it could be something weird like x86mmx.
19814     EVT LastOpVT = Ops.back().getValueType();
19815     if (LastOpVT.isFloatingPoint())
19816       AnyFP = true;
19817     else if (LastOpVT.isInteger())
19818       AnyInteger = true;
19819     else
19820       return SDValue();
19821   }
19822 
19823   // If any of the operands is a floating point scalar bitcast to a vector,
19824   // use floating point types throughout, and bitcast everything.
19825   // Replace UNDEFs by another scalar UNDEF node, of the final desired type.
19826   if (AnyFP) {
19827     SVT = EVT::getFloatingPointVT(OpVT.getSizeInBits());
19828     ScalarUndef = DAG.getNode(ISD::UNDEF, DL, SVT);
19829     if (AnyInteger) {
19830       for (SDValue &Op : Ops) {
19831         if (Op.getValueType() == SVT)
19832           continue;
19833         if (Op.isUndef())
19834           Op = ScalarUndef;
19835         else
19836           Op = DAG.getBitcast(SVT, Op);
19837       }
19838     }
19839   }
19840 
19841   EVT VecVT = EVT::getVectorVT(*DAG.getContext(), SVT,
19842                                VT.getSizeInBits() / SVT.getSizeInBits());
19843   return DAG.getBitcast(VT, DAG.getBuildVector(VecVT, DL, Ops));
19844 }
19845 
19846 // Check to see if this is a CONCAT_VECTORS of a bunch of EXTRACT_SUBVECTOR
19847 // operations. If so, and if the EXTRACT_SUBVECTOR vector inputs come from at
19848 // most two distinct vectors the same size as the result, attempt to turn this
19849 // into a legal shuffle.
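// Illustrative example (hypothetical values):
//   (v4i32 concat_vectors (v2i32 extract_subvector v4i32:X, 0),
//                         (v2i32 extract_subvector v4i32:Y, 2))
//     --> (v4i32 vector_shuffle<0,1,6,7> X, Y)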
19850 static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) {
19851   EVT VT = N->getValueType(0);
19852   EVT OpVT = N->getOperand(0).getValueType();
19853 
19854   // We currently can't generate an appropriate shuffle for a scalable vector.
19855   if (VT.isScalableVector())
19856     return SDValue();
19857 
19858   int NumElts = VT.getVectorNumElements();
19859   int NumOpElts = OpVT.getVectorNumElements();
19860 
19861   SDValue SV0 = DAG.getUNDEF(VT), SV1 = DAG.getUNDEF(VT);
19862   SmallVector<int, 8> Mask;
19863 
19864   for (SDValue Op : N->ops()) {
19865     Op = peekThroughBitcasts(Op);
19866 
19867     // UNDEF nodes convert to UNDEF shuffle mask values.
19868     if (Op.isUndef()) {
19869       Mask.append((unsigned)NumOpElts, -1);
19870       continue;
19871     }
19872 
19873     if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
19874       return SDValue();
19875 
19876     // What vector are we extracting the subvector from and at what index?
19877     SDValue ExtVec = Op.getOperand(0);
19878     int ExtIdx = Op.getConstantOperandVal(1);
19879 
19880     // We want the EVT of the original extraction to correctly scale the
19881     // extraction index.
19882     EVT ExtVT = ExtVec.getValueType();
19883     ExtVec = peekThroughBitcasts(ExtVec);
19884 
19885     // UNDEF nodes convert to UNDEF shuffle mask values.
19886     if (ExtVec.isUndef()) {
19887       Mask.append((unsigned)NumOpElts, -1);
19888       continue;
19889     }
19890 
19891     // Ensure that we are extracting a subvector from a vector the same
19892     // size as the result.
19893     if (ExtVT.getSizeInBits() != VT.getSizeInBits())
19894       return SDValue();
19895 
19896     // Scale the subvector index to account for any bitcast.
19897     int NumExtElts = ExtVT.getVectorNumElements();
19898     if (0 == (NumExtElts % NumElts))
19899       ExtIdx /= (NumExtElts / NumElts);
19900     else if (0 == (NumElts % NumExtElts))
19901       ExtIdx *= (NumElts / NumExtElts);
19902     else
19903       return SDValue();
19904 
19905     // At most we can reference 2 inputs in the final shuffle.
19906     if (SV0.isUndef() || SV0 == ExtVec) {
19907       SV0 = ExtVec;
19908       for (int i = 0; i != NumOpElts; ++i)
19909         Mask.push_back(i + ExtIdx);
19910     } else if (SV1.isUndef() || SV1 == ExtVec) {
19911       SV1 = ExtVec;
19912       for (int i = 0; i != NumOpElts; ++i)
19913         Mask.push_back(i + ExtIdx + NumElts);
19914     } else {
19915       return SDValue();
19916     }
19917   }
19918 
19919   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
19920   return TLI.buildLegalVectorShuffle(VT, SDLoc(N), DAG.getBitcast(VT, SV0),
19921                                      DAG.getBitcast(VT, SV1), Mask, DAG);
19922 }
19923 
19924 static SDValue combineConcatVectorOfCasts(SDNode *N, SelectionDAG &DAG) {
19925   unsigned CastOpcode = N->getOperand(0).getOpcode();
19926   switch (CastOpcode) {
19927   case ISD::SINT_TO_FP:
19928   case ISD::UINT_TO_FP:
19929   case ISD::FP_TO_SINT:
19930   case ISD::FP_TO_UINT:
19931     // TODO: Allow more opcodes?
19932     //  case ISD::BITCAST:
19933     //  case ISD::TRUNCATE:
19934     //  case ISD::ZERO_EXTEND:
19935     //  case ISD::SIGN_EXTEND:
19936     //  case ISD::FP_EXTEND:
19937     break;
19938   default:
19939     return SDValue();
19940   }
19941 
19942   EVT SrcVT = N->getOperand(0).getOperand(0).getValueType();
19943   if (!SrcVT.isVector())
19944     return SDValue();
19945 
19946   // All operands of the concat must be the same kind of cast from the same
19947   // source type.
19948   SmallVector<SDValue, 4> SrcOps;
19949   for (SDValue Op : N->ops()) {
19950     if (Op.getOpcode() != CastOpcode || !Op.hasOneUse() ||
19951         Op.getOperand(0).getValueType() != SrcVT)
19952       return SDValue();
19953     SrcOps.push_back(Op.getOperand(0));
19954   }
19955 
19956   // The wider cast must be supported by the target. This is unusual because
19957   // the operation support type parameter depends on the opcode. In addition,
19958   // check the other type in the cast to make sure this is really legal.
19959   EVT VT = N->getValueType(0);
19960   EVT SrcEltVT = SrcVT.getVectorElementType();
19961   ElementCount NumElts = SrcVT.getVectorElementCount() * N->getNumOperands();
19962   EVT ConcatSrcVT = EVT::getVectorVT(*DAG.getContext(), SrcEltVT, NumElts);
19963   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
19964   switch (CastOpcode) {
19965   case ISD::SINT_TO_FP:
19966   case ISD::UINT_TO_FP:
19967     if (!TLI.isOperationLegalOrCustom(CastOpcode, ConcatSrcVT) ||
19968         !TLI.isTypeLegal(VT))
19969       return SDValue();
19970     break;
19971   case ISD::FP_TO_SINT:
19972   case ISD::FP_TO_UINT:
19973     if (!TLI.isOperationLegalOrCustom(CastOpcode, VT) ||
19974         !TLI.isTypeLegal(ConcatSrcVT))
19975       return SDValue();
19976     break;
19977   default:
19978     llvm_unreachable("Unexpected cast opcode");
19979   }
19980 
19981   // concat (cast X), (cast Y)... -> cast (concat X, Y...)
19982   SDLoc DL(N);
19983   SDValue NewConcat = DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatSrcVT, SrcOps);
19984   return DAG.getNode(CastOpcode, DL, VT, NewConcat);
19985 }
19986 
19987 SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
19988   // If we only have one input vector, we don't need to do any concatenation.
19989   if (N->getNumOperands() == 1)
19990     return N->getOperand(0);
19991 
19992   // Check if all of the operands are undefs.
19993   EVT VT = N->getValueType(0);
19994   if (ISD::allOperandsUndef(N))
19995     return DAG.getUNDEF(VT);
19996 
19997   // Optimize concat_vectors where all but the first of the vectors are undef.
19998   if (all_of(drop_begin(N->ops()),
19999              [](const SDValue &Op) { return Op.isUndef(); })) {
20000     SDValue In = N->getOperand(0);
20001     assert(In.getValueType().isVector() && "Must concat vectors");
20002 
20003     // If the input is a concat_vectors, just make a larger concat by padding
20004     // with smaller undefs.
20005     if (In.getOpcode() == ISD::CONCAT_VECTORS && In.hasOneUse()) {
20006       unsigned NumOps = N->getNumOperands() * In.getNumOperands();
20007       SmallVector<SDValue, 4> Ops(In->op_begin(), In->op_end());
20008       Ops.resize(NumOps, DAG.getUNDEF(Ops[0].getValueType()));
20009       return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
20010     }
20011 
20012     SDValue Scalar = peekThroughOneUseBitcasts(In);
20013 
20014     // concat_vectors(scalar_to_vector(scalar), undef) ->
20015     //     scalar_to_vector(scalar)
20016     if (!LegalOperations && Scalar.getOpcode() == ISD::SCALAR_TO_VECTOR &&
20017          Scalar.hasOneUse()) {
20018       EVT SVT = Scalar.getValueType().getVectorElementType();
20019       if (SVT == Scalar.getOperand(0).getValueType())
20020         Scalar = Scalar.getOperand(0);
20021     }
20022 
20023     // concat_vectors(scalar, undef) -> scalar_to_vector(scalar)
20024     if (!Scalar.getValueType().isVector()) {
20025       // If the bitcast type isn't legal, it might be a trunc of a legal type;
20026       // look through the trunc so we can still do the transform:
20027       //   concat_vectors(trunc(scalar), undef) -> scalar_to_vector(scalar)
20028       if (Scalar->getOpcode() == ISD::TRUNCATE &&
20029           !TLI.isTypeLegal(Scalar.getValueType()) &&
20030           TLI.isTypeLegal(Scalar->getOperand(0).getValueType()))
20031         Scalar = Scalar->getOperand(0);
20032 
20033       EVT SclTy = Scalar.getValueType();
20034 
20035       if (!SclTy.isFloatingPoint() && !SclTy.isInteger())
20036         return SDValue();
20037 
20038       // Bail out if the vector size is not a multiple of the scalar size.
20039       if (VT.getSizeInBits() % SclTy.getSizeInBits())
20040         return SDValue();
20041 
20042       unsigned VNTNumElms = VT.getSizeInBits() / SclTy.getSizeInBits();
20043       if (VNTNumElms < 2)
20044         return SDValue();
20045 
20046       EVT NVT = EVT::getVectorVT(*DAG.getContext(), SclTy, VNTNumElms);
20047       if (!TLI.isTypeLegal(NVT) || !TLI.isTypeLegal(Scalar.getValueType()))
20048         return SDValue();
20049 
20050       SDValue Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), NVT, Scalar);
20051       return DAG.getBitcast(VT, Res);
20052     }
20053   }
20054 
20055   // Fold any combination of BUILD_VECTOR or UNDEF nodes into one BUILD_VECTOR.
20056   // We have already tested above for an UNDEF-only concatenation.
20057   // fold (concat_vectors (BUILD_VECTOR A, B, ...), (BUILD_VECTOR C, D, ...))
20058   // -> (BUILD_VECTOR A, B, ..., C, D, ...)
20059   auto IsBuildVectorOrUndef = [](const SDValue &Op) {
20060     return ISD::UNDEF == Op.getOpcode() || ISD::BUILD_VECTOR == Op.getOpcode();
20061   };
20062   if (llvm::all_of(N->ops(), IsBuildVectorOrUndef)) {
20063     SmallVector<SDValue, 8> Opnds;
20064     EVT SVT = VT.getScalarType();
20065 
20066     EVT MinVT = SVT;
20067     if (!SVT.isFloatingPoint()) {
20068       // If the BUILD_VECTOR nodes are built from integers, they may have different
20069       // operand types. Get the smallest type and truncate all operands to it.
      bool FoundMinVT = false;
      for (const SDValue &Op : N->ops())
        if (ISD::BUILD_VECTOR == Op.getOpcode()) {
          EVT OpSVT = Op.getOperand(0).getValueType();
          MinVT = (!FoundMinVT || OpSVT.bitsLE(MinVT)) ? OpSVT : MinVT;
          FoundMinVT = true;
        }
      assert(FoundMinVT && "Concat vector type mismatch");
    }

    for (const SDValue &Op : N->ops()) {
      EVT OpVT = Op.getValueType();
      unsigned NumElts = OpVT.getVectorNumElements();

      if (ISD::UNDEF == Op.getOpcode())
        Opnds.append(NumElts, DAG.getUNDEF(MinVT));

      if (ISD::BUILD_VECTOR == Op.getOpcode()) {
        if (SVT.isFloatingPoint()) {
          assert(SVT == OpVT.getScalarType() && "Concat vector type mismatch");
          Opnds.append(Op->op_begin(), Op->op_begin() + NumElts);
        } else {
          for (unsigned i = 0; i != NumElts; ++i)
            Opnds.push_back(
                DAG.getNode(ISD::TRUNCATE, SDLoc(N), MinVT, Op.getOperand(i)));
        }
      }
    }

    assert(VT.getVectorNumElements() == Opnds.size() &&
           "Concat vector type mismatch");
    return DAG.getBuildVector(VT, SDLoc(N), Opnds);
  }

  // Fold CONCAT_VECTORS of only bitcast scalars (or undef) to BUILD_VECTOR.
  if (SDValue V = combineConcatVectorOfScalars(N, DAG))
    return V;

  // Fold CONCAT_VECTORS of EXTRACT_SUBVECTOR (or undef) to VECTOR_SHUFFLE.
  if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT))
    if (SDValue V = combineConcatVectorOfExtracts(N, DAG))
      return V;

  if (SDValue V = combineConcatVectorOfCasts(N, DAG))
    return V;

  // Type legalization of vectors and DAG canonicalization of SHUFFLE_VECTOR
  // nodes often generate nop CONCAT_VECTORS nodes. Scan the CONCAT_VECTORS
  // operands and look for CONCAT operations that place the incoming vectors
  // at the exact same location.
  //
  // For scalable vectors, EXTRACT_SUBVECTOR indexes are implicitly scaled.
  SDValue SingleSource = SDValue();
  unsigned PartNumElem =
      N->getOperand(0).getValueType().getVectorMinNumElements();

  for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
    SDValue Op = N->getOperand(i);

    if (Op.isUndef())
      continue;

    // Check if this is the identity extract:
    if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
      return SDValue();

    // Find the single incoming vector for the extract_subvector.
    if (SingleSource.getNode()) {
      if (Op.getOperand(0) != SingleSource)
        return SDValue();
    } else {
      SingleSource = Op.getOperand(0);

      // Check the source type is the same as the type of the result.
      // If not, this concat may extend the vector, so we cannot
      // optimize it away.
      if (SingleSource.getValueType() != N->getValueType(0))
        return SDValue();
    }

    // Check that we are reading from the identity index.
    unsigned IdentityIndex = i * PartNumElem;
    if (Op.getConstantOperandAPInt(1) != IdentityIndex)
      return SDValue();
  }

  if (SingleSource.getNode())
    return SingleSource;

  return SDValue();
}

// Helper that peeks through INSERT_SUBVECTOR/CONCAT_VECTORS to find
// if the subvector can be sourced for free.
static SDValue getSubVectorSrc(SDValue V, SDValue Index, EVT SubVT) {
  if (V.getOpcode() == ISD::INSERT_SUBVECTOR &&
      V.getOperand(1).getValueType() == SubVT && V.getOperand(2) == Index) {
    return V.getOperand(1);
  }
  auto *IndexC = dyn_cast<ConstantSDNode>(Index);
  if (IndexC && V.getOpcode() == ISD::CONCAT_VECTORS &&
      V.getOperand(0).getValueType() == SubVT &&
      (IndexC->getZExtValue() % SubVT.getVectorMinNumElements()) == 0) {
    uint64_t SubIdx = IndexC->getZExtValue() / SubVT.getVectorMinNumElements();
    return V.getOperand(SubIdx);
  }
  return SDValue();
}

static SDValue narrowInsertExtractVectorBinOp(SDNode *Extract,
                                              SelectionDAG &DAG,
                                              bool LegalOperations) {
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  SDValue BinOp = Extract->getOperand(0);
  unsigned BinOpcode = BinOp.getOpcode();
  if (!TLI.isBinOp(BinOpcode) || BinOp.getNode()->getNumValues() != 1)
    return SDValue();

  EVT VecVT = BinOp.getValueType();
  SDValue Bop0 = BinOp.getOperand(0), Bop1 = BinOp.getOperand(1);
  if (VecVT != Bop0.getValueType() || VecVT != Bop1.getValueType())
    return SDValue();

  SDValue Index = Extract->getOperand(1);
  EVT SubVT = Extract->getValueType(0);
  if (!TLI.isOperationLegalOrCustom(BinOpcode, SubVT, LegalOperations))
    return SDValue();

  SDValue Sub0 = getSubVectorSrc(Bop0, Index, SubVT);
  SDValue Sub1 = getSubVectorSrc(Bop1, Index, SubVT);

  // TODO: We could handle the case where only 1 operand is being inserted by
  //       creating an extract of the other operand, but that requires checking
  //       number of uses and/or costs.
  if (!Sub0 || !Sub1)
    return SDValue();

  // We are inserting both operands of the wide binop only to extract back
  // to the narrow vector size. Eliminate all of the insert/extract:
  // ext (binop (ins ?, X, Index), (ins ?, Y, Index)), Index --> binop X, Y
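  // e.g. (v2f64 extract_subvector
  //         (fadd (insert_subvector ?, X, 2), (insert_subvector ?, Y, 2)), 2)
  //      --> (v2f64 fadd X, Y)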
  return DAG.getNode(BinOpcode, SDLoc(Extract), SubVT, Sub0, Sub1,
                     BinOp->getFlags());
}

/// If we are extracting a subvector produced by a wide binary operator try
/// to use a narrow binary operator and/or avoid concatenation and extraction.
static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG,
                                          bool LegalOperations) {
  // TODO: Refactor with the caller (visitEXTRACT_SUBVECTOR), so we can share
  // some of these bailouts with other transforms.

  if (SDValue V = narrowInsertExtractVectorBinOp(Extract, DAG, LegalOperations))
    return V;

  // The extract index must be a constant, so we can map it to a concat operand.
  auto *ExtractIndexC = dyn_cast<ConstantSDNode>(Extract->getOperand(1));
  if (!ExtractIndexC)
    return SDValue();

  // We are looking for an optionally bitcasted wide vector binary operator
  // feeding an extract subvector.
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  SDValue BinOp = peekThroughBitcasts(Extract->getOperand(0));
  unsigned BOpcode = BinOp.getOpcode();
  if (!TLI.isBinOp(BOpcode) || BinOp.getNode()->getNumValues() != 1)
    return SDValue();

  // Exclude the fake form of fneg (fsub -0.0, x) because that is likely to be
  // reduced to the unary fneg when it is visited, and we probably want to deal
  // with fneg in a target-specific way.
  if (BOpcode == ISD::FSUB) {
    auto *C = isConstOrConstSplatFP(BinOp.getOperand(0), /*AllowUndefs*/ true);
    if (C && C->getValueAPF().isNegZero())
      return SDValue();
  }

  // The binop must be a vector type, so we can extract some fraction of it.
  EVT WideBVT = BinOp.getValueType();
  // The optimisations below currently assume we are dealing with fixed length
  // vectors. It is possible to add support for scalable vectors, but at the
  // moment we've done no analysis to prove whether they are profitable or not.
  if (!WideBVT.isFixedLengthVector())
    return SDValue();

  EVT VT = Extract->getValueType(0);
  unsigned ExtractIndex = ExtractIndexC->getZExtValue();
  assert(ExtractIndex % VT.getVectorNumElements() == 0 &&
         "Extract index is not a multiple of the vector length.");

  // Bail out if this is not a proper multiple width extraction.
  unsigned WideWidth = WideBVT.getSizeInBits();
  unsigned NarrowWidth = VT.getSizeInBits();
  if (WideWidth % NarrowWidth != 0)
    return SDValue();

  // Bail out if we are extracting a fraction of a single operation. This can
  // occur because we potentially looked through a bitcast of the binop.
  unsigned NarrowingRatio = WideWidth / NarrowWidth;
  unsigned WideNumElts = WideBVT.getVectorNumElements();
  if (WideNumElts % NarrowingRatio != 0)
    return SDValue();

  // Bail out if the target does not support a narrower version of the binop.
  EVT NarrowBVT = EVT::getVectorVT(*DAG.getContext(), WideBVT.getScalarType(),
                                   WideNumElts / NarrowingRatio);
  if (!TLI.isOperationLegalOrCustomOrPromote(BOpcode, NarrowBVT))
    return SDValue();

  // If extraction is cheap, we don't need to look at the binop operands
  // for concat ops. The narrow binop alone makes this transform profitable.
  // We can't just reuse the original extract index operand because we may have
  // bitcasted.
  unsigned ConcatOpNum = ExtractIndex / VT.getVectorNumElements();
  unsigned ExtBOIdx = ConcatOpNum * NarrowBVT.getVectorNumElements();
  if (TLI.isExtractSubvectorCheap(NarrowBVT, WideBVT, ExtBOIdx) &&
      BinOp.hasOneUse() && Extract->getOperand(0)->hasOneUse()) {
    // extract (binop B0, B1), N --> binop (extract B0, N), (extract B1, N)
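    // e.g. (v2i64 extract_subvector (v4i64 add B0, B1), 2)
    //      --> (v2i64 add (extract_subvector B0, 2), (extract_subvector B1, 2))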
    SDLoc DL(Extract);
    SDValue NewExtIndex = DAG.getVectorIdxConstant(ExtBOIdx, DL);
    SDValue X = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
                            BinOp.getOperand(0), NewExtIndex);
    SDValue Y = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
                            BinOp.getOperand(1), NewExtIndex);
    SDValue NarrowBinOp = DAG.getNode(BOpcode, DL, NarrowBVT, X, Y,
                                      BinOp.getNode()->getFlags());
    return DAG.getBitcast(VT, NarrowBinOp);
  }

  // Only handle the case where we are doubling and then halving. A larger ratio
  // may require more than two narrow binops to replace the wide binop.
  if (NarrowingRatio != 2)
    return SDValue();

  // TODO: The motivating case for this transform is an x86 AVX1 target. That
  // target has temptingly almost legal versions of bitwise logic ops in 256-bit
  // flavors, but no other 256-bit integer support. This could be extended to
  // handle any binop, but that may require fixing/adding other folds to avoid
  // codegen regressions.
  if (BOpcode != ISD::AND && BOpcode != ISD::OR && BOpcode != ISD::XOR)
    return SDValue();

  // We need at least one concatenation operation of a binop operand to make
  // this transform worthwhile. The concat must double the input vector sizes.
  auto GetSubVector = [ConcatOpNum](SDValue V) -> SDValue {
    if (V.getOpcode() == ISD::CONCAT_VECTORS && V.getNumOperands() == 2)
      return V.getOperand(ConcatOpNum);
    return SDValue();
  };
  SDValue SubVecL = GetSubVector(peekThroughBitcasts(BinOp.getOperand(0)));
  SDValue SubVecR = GetSubVector(peekThroughBitcasts(BinOp.getOperand(1)));

  if (SubVecL || SubVecR) {
    // If a binop operand was not the result of a concat, we must extract a
    // half-sized operand for our new narrow binop:
    // extract (binop (concat X1, X2), (concat Y1, Y2)), N --> binop XN, YN
    // extract (binop (concat X1, X2), Y), N --> binop XN, (extract Y, IndexC)
    // extract (binop X, (concat Y1, Y2)), N --> binop (extract X, IndexC), YN
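    // e.g. (v4i32 extract_subvector (v8i32 and (concat X1, X2), Y), 4)
    //      --> (v4i32 and X2, (v4i32 extract_subvector Y, 4))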
    SDLoc DL(Extract);
    SDValue IndexC = DAG.getVectorIdxConstant(ExtBOIdx, DL);
    SDValue X = SubVecL ? DAG.getBitcast(NarrowBVT, SubVecL)
                        : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
                                      BinOp.getOperand(0), IndexC);

    SDValue Y = SubVecR ? DAG.getBitcast(NarrowBVT, SubVecR)
                        : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
                                      BinOp.getOperand(1), IndexC);

    SDValue NarrowBinOp = DAG.getNode(BOpcode, DL, NarrowBVT, X, Y);
    return DAG.getBitcast(VT, NarrowBinOp);
  }

  return SDValue();
}

/// If we are extracting a subvector from a wide vector load, convert to a
/// narrow load to eliminate the extraction:
/// (extract_subvector (load wide vector)) --> (load narrow vector)
static SDValue narrowExtractedVectorLoad(SDNode *Extract, SelectionDAG &DAG) {
  // TODO: Add support for big-endian. The offset calculation must be adjusted.
  if (DAG.getDataLayout().isBigEndian())
    return SDValue();

  auto *Ld = dyn_cast<LoadSDNode>(Extract->getOperand(0));
  auto *ExtIdx = dyn_cast<ConstantSDNode>(Extract->getOperand(1));
  if (!Ld || Ld->getExtensionType() || !Ld->isSimple() || !ExtIdx)
    return SDValue();

  // Allow targets to opt-out.
  EVT VT = Extract->getValueType(0);

  // We can only create byte sized loads.
  if (!VT.isByteSized())
    return SDValue();

  unsigned Index = ExtIdx->getZExtValue();
  unsigned NumElts = VT.getVectorMinNumElements();

  // The definition of EXTRACT_SUBVECTOR states that the index must be a
  // multiple of the minimum number of elements in the result type.
  assert(Index % NumElts == 0 && "The extract subvector index is not a "
                                 "multiple of the result's element count");

  // It's fine to use TypeSize here as we know the offset will not be negative.
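  // e.g. extracting the upper v2i64 half (Index == 2) of a loaded v4i64 gives
  // Offset == 16 bytes (VT.getStoreSize() == 16, Index / NumElts == 1).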
  TypeSize Offset = VT.getStoreSize() * (Index / NumElts);

  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  if (!TLI.shouldReduceLoadWidth(Ld, Ld->getExtensionType(), VT))
    return SDValue();

  // The narrow load will be offset from the base address of the old load if
  // we are extracting from something besides index 0 (little-endian).
  SDLoc DL(Extract);

  // TODO: Use "BaseIndexOffset" to make this more effective.
  SDValue NewAddr = DAG.getMemBasePlusOffset(Ld->getBasePtr(), Offset, DL);

  uint64_t StoreSize = MemoryLocation::getSizeOrUnknown(VT.getStoreSize());
  MachineFunction &MF = DAG.getMachineFunction();
  MachineMemOperand *MMO;
  if (Offset.isScalable()) {
    MachinePointerInfo MPI =
        MachinePointerInfo(Ld->getPointerInfo().getAddrSpace());
    MMO = MF.getMachineMemOperand(Ld->getMemOperand(), MPI, StoreSize);
  } else
    MMO = MF.getMachineMemOperand(Ld->getMemOperand(), Offset.getFixedSize(),
                                  StoreSize);

  SDValue NewLd = DAG.getLoad(VT, DL, Ld->getChain(), NewAddr, MMO);
  DAG.makeEquivalentMemoryOrdering(Ld, NewLd);
  return NewLd;
}

SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) {
  EVT NVT = N->getValueType(0);
  SDValue V = N->getOperand(0);
  uint64_t ExtIdx = N->getConstantOperandVal(1);

  // Extract from UNDEF is UNDEF.
  if (V.isUndef())
    return DAG.getUNDEF(NVT);

  if (TLI.isOperationLegalOrCustomOrPromote(ISD::LOAD, NVT))
    if (SDValue NarrowLoad = narrowExtractedVectorLoad(N, DAG))
      return NarrowLoad;

  // Combine an extract of an extract into a single extract_subvector.
  // ext (ext X, C), 0 --> ext X, C
  if (ExtIdx == 0 && V.getOpcode() == ISD::EXTRACT_SUBVECTOR && V.hasOneUse()) {
    if (TLI.isExtractSubvectorCheap(NVT, V.getOperand(0).getValueType(),
                                    V.getConstantOperandVal(1)) &&
        TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NVT)) {
      return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), NVT, V.getOperand(0),
                         V.getOperand(1));
    }
  }

  // Try to move vector bitcast after extract_subv by scaling extraction index:
  // extract_subv (bitcast X), Index --> bitcast (extract_subv X, Index')
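  // e.g. (v1i64 extract_subvector (v2i64 bitcast (v4i32 X)), 1)
  //      --> (v1i64 bitcast (v2i32 extract_subvector X, 2))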
  if (V.getOpcode() == ISD::BITCAST &&
      V.getOperand(0).getValueType().isVector() &&
      (!LegalOperations || TLI.isOperationLegal(ISD::BITCAST, NVT))) {
    SDValue SrcOp = V.getOperand(0);
    EVT SrcVT = SrcOp.getValueType();
    unsigned SrcNumElts = SrcVT.getVectorMinNumElements();
    unsigned DestNumElts = V.getValueType().getVectorMinNumElements();
    if ((SrcNumElts % DestNumElts) == 0) {
      unsigned SrcDestRatio = SrcNumElts / DestNumElts;
      ElementCount NewExtEC = NVT.getVectorElementCount() * SrcDestRatio;
      EVT NewExtVT = EVT::getVectorVT(*DAG.getContext(), SrcVT.getScalarType(),
                                      NewExtEC);
      if (TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NewExtVT)) {
        SDLoc DL(N);
        SDValue NewIndex = DAG.getVectorIdxConstant(ExtIdx * SrcDestRatio, DL);
        SDValue NewExtract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewExtVT,
                                         V.getOperand(0), NewIndex);
        return DAG.getBitcast(NVT, NewExtract);
      }
    }
    if ((DestNumElts % SrcNumElts) == 0) {
      unsigned DestSrcRatio = DestNumElts / SrcNumElts;
      if (NVT.getVectorElementCount().isKnownMultipleOf(DestSrcRatio)) {
        ElementCount NewExtEC =
            NVT.getVectorElementCount().divideCoefficientBy(DestSrcRatio);
        EVT ScalarVT = SrcVT.getScalarType();
        if ((ExtIdx % DestSrcRatio) == 0) {
          SDLoc DL(N);
          unsigned IndexValScaled = ExtIdx / DestSrcRatio;
          EVT NewExtVT =
              EVT::getVectorVT(*DAG.getContext(), ScalarVT, NewExtEC);
          if (TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NewExtVT)) {
            SDValue NewIndex = DAG.getVectorIdxConstant(IndexValScaled, DL);
            SDValue NewExtract =
                DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewExtVT,
                            V.getOperand(0), NewIndex);
            return DAG.getBitcast(NVT, NewExtract);
          }
          if (NewExtEC.isScalar() &&
              TLI.isOperationLegalOrCustom(ISD::EXTRACT_VECTOR_ELT, ScalarVT)) {
            SDValue NewIndex = DAG.getVectorIdxConstant(IndexValScaled, DL);
            SDValue NewExtract =
                DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT,
                            V.getOperand(0), NewIndex);
            return DAG.getBitcast(NVT, NewExtract);
          }
        }
      }
    }
  }

  if (V.getOpcode() == ISD::CONCAT_VECTORS) {
    unsigned ExtNumElts = NVT.getVectorMinNumElements();
    EVT ConcatSrcVT = V.getOperand(0).getValueType();
    assert(ConcatSrcVT.getVectorElementType() == NVT.getVectorElementType() &&
           "Concat and extract subvector do not change element type");
    assert((ExtIdx % ExtNumElts) == 0 &&
           "Extract index is not a multiple of the input vector length.");

    unsigned ConcatSrcNumElts = ConcatSrcVT.getVectorMinNumElements();
    unsigned ConcatOpIdx = ExtIdx / ConcatSrcNumElts;

    // If the concatenated source types match this extract, it's a direct
    // simplification:
    // extract_subvec (concat V1, V2, ...), i --> Vi
    if (ConcatSrcNumElts == ExtNumElts)
      return V.getOperand(ConcatOpIdx);

    // If the concatenated source vectors are a multiple length of this extract,
    // then extract a fraction of one of those source vectors directly from a
    // concat operand. Example:
    //   v2i8 extract_subvec (v16i8 concat (v8i8 X), (v8i8 Y)), 14 -->
    //   v2i8 extract_subvec v8i8 Y, 6
    if (NVT.isFixedLengthVector() && ConcatSrcNumElts % ExtNumElts == 0) {
      SDLoc DL(N);
      unsigned NewExtIdx = ExtIdx - ConcatOpIdx * ConcatSrcNumElts;
      assert(NewExtIdx + ExtNumElts <= ConcatSrcNumElts &&
             "Trying to extract from >1 concat operand?");
      assert(NewExtIdx % ExtNumElts == 0 &&
             "Extract index is not a multiple of the input vector length.");
      SDValue NewIndexC = DAG.getVectorIdxConstant(NewExtIdx, DL);
      return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NVT,
                         V.getOperand(ConcatOpIdx), NewIndexC);
    }
  }

  V = peekThroughBitcasts(V);

  // If the input is a build vector, try to make a smaller build vector.
  if (V.getOpcode() == ISD::BUILD_VECTOR) {
    EVT InVT = V.getValueType();
    unsigned ExtractSize = NVT.getSizeInBits();
    unsigned EltSize = InVT.getScalarSizeInBits();
    // Only do this if we won't split any elements.
    if (ExtractSize % EltSize == 0) {
      unsigned NumElems = ExtractSize / EltSize;
      EVT EltVT = InVT.getVectorElementType();
      EVT ExtractVT =
          NumElems == 1 ? EltVT
                        : EVT::getVectorVT(*DAG.getContext(), EltVT, NumElems);
      if ((Level < AfterLegalizeDAG ||
           (NumElems == 1 ||
            TLI.isOperationLegal(ISD::BUILD_VECTOR, ExtractVT))) &&
          (!LegalTypes || TLI.isTypeLegal(ExtractVT))) {
        unsigned IdxVal = (ExtIdx * NVT.getScalarSizeInBits()) / EltSize;

        if (NumElems == 1) {
          SDValue Src = V->getOperand(IdxVal);
          if (EltVT != Src.getValueType())
            Src = DAG.getNode(ISD::TRUNCATE, SDLoc(N), InVT, Src);
          return DAG.getBitcast(NVT, Src);
        }

        // Extract the pieces from the original build_vector.
        SDValue BuildVec = DAG.getBuildVector(ExtractVT, SDLoc(N),
                                              V->ops().slice(IdxVal, NumElems));
        return DAG.getBitcast(NVT, BuildVec);
      }
    }
  }

  if (V.getOpcode() == ISD::INSERT_SUBVECTOR) {
    // Handle only the simple case where the vector being inserted and the
    // vector being extracted are of the same size.
    EVT SmallVT = V.getOperand(1).getValueType();
    if (!NVT.bitsEq(SmallVT))
      return SDValue();

    // Combine:
    //    (extract_subvec (insert_subvec V1, V2, InsIdx), ExtIdx)
    // Into:
    //    indices are equal or bit offsets are equal => V2
    //    otherwise => (extract_subvec V1, ExtIdx)
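    // e.g. (v2i32 extract_subvec (v8i32 insert_subvec V1, (v2i32 V2), 2), 2)
    //      --> V2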
    uint64_t InsIdx = V.getConstantOperandVal(2);
    if (InsIdx * SmallVT.getScalarSizeInBits() ==
        ExtIdx * NVT.getScalarSizeInBits()) {
      if (LegalOperations && !TLI.isOperationLegal(ISD::BITCAST, NVT))
        return SDValue();

      return DAG.getBitcast(NVT, V.getOperand(1));
    }
    return DAG.getNode(
        ISD::EXTRACT_SUBVECTOR, SDLoc(N), NVT,
        DAG.getBitcast(N->getOperand(0).getValueType(), V.getOperand(0)),
        N->getOperand(1));
  }

  if (SDValue NarrowBOp = narrowExtractedVectorBinOp(N, DAG, LegalOperations))
    return NarrowBOp;

  if (SimplifyDemandedVectorElts(SDValue(N, 0)))
    return SDValue(N, 0);

  return SDValue();
}

/// Try to convert a wide shuffle of concatenated vectors into 2 narrow shuffles
/// followed by concatenation. Narrow vector ops may have better performance
/// than wide ops, and this can unlock further narrowing of other vector ops.
/// Targets can invert this transform later if it is not profitable.
static SDValue foldShuffleOfConcatUndefs(ShuffleVectorSDNode *Shuf,
                                         SelectionDAG &DAG) {
  SDValue N0 = Shuf->getOperand(0), N1 = Shuf->getOperand(1);
  if (N0.getOpcode() != ISD::CONCAT_VECTORS || N0.getNumOperands() != 2 ||
      N1.getOpcode() != ISD::CONCAT_VECTORS || N1.getNumOperands() != 2 ||
      !N0.getOperand(1).isUndef() || !N1.getOperand(1).isUndef())
    return SDValue();

  // Split the wide shuffle mask into halves. Any mask element that is accessing
  // operand 1 is offset down to account for narrowing of the vectors.
  ArrayRef<int> Mask = Shuf->getMask();
  EVT VT = Shuf->getValueType(0);
  unsigned NumElts = VT.getVectorNumElements();
  unsigned HalfNumElts = NumElts / 2;
  SmallVector<int, 16> Mask0(HalfNumElts, -1);
  SmallVector<int, 16> Mask1(HalfNumElts, -1);
  for (unsigned i = 0; i != NumElts; ++i) {
    if (Mask[i] == -1)
      continue;
    // If we reference the upper (undef) subvector then the element is undef.
    if ((Mask[i] % NumElts) >= HalfNumElts)
      continue;
    int M = Mask[i] < (int)NumElts ? Mask[i] : Mask[i] - (int)HalfNumElts;
    if (i < HalfNumElts)
      Mask0[i] = M;
    else
      Mask1[i - HalfNumElts] = M;
  }

  // Ask the target if this is a valid transform.
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  EVT HalfVT = EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(),
                                HalfNumElts);
  if (!TLI.isShuffleMaskLegal(Mask0, HalfVT) ||
      !TLI.isShuffleMaskLegal(Mask1, HalfVT))
    return SDValue();

  // shuffle (concat X, undef), (concat Y, undef), Mask -->
  // concat (shuffle X, Y, Mask0), (shuffle X, Y, Mask1)
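  // e.g. (v8i16 shuffle (concat X, undef), (concat Y, undef),
  //                     <0,1,8,9,2,3,10,11>)
  //      --> concat (v4i16 shuffle X, Y, <0,1,4,5>),
  //                 (v4i16 shuffle X, Y, <2,3,6,7>)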
  SDValue X = N0.getOperand(0), Y = N1.getOperand(0);
  SDLoc DL(Shuf);
  SDValue Shuf0 = DAG.getVectorShuffle(HalfVT, DL, X, Y, Mask0);
  SDValue Shuf1 = DAG.getVectorShuffle(HalfVT, DL, X, Y, Mask1);
  return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Shuf0, Shuf1);
}

// Tries to turn a shuffle of two CONCAT_VECTORS into a single concat, or to
// turn a shuffle of a single concat into a simpler shuffle followed by a
// concat.
static SDValue partitionShuffleOfConcats(SDNode *N, SelectionDAG &DAG) {
  EVT VT = N->getValueType(0);
  unsigned NumElts = VT.getVectorNumElements();

  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N);
  ArrayRef<int> Mask = SVN->getMask();

  SmallVector<SDValue, 4> Ops;
  EVT ConcatVT = N0.getOperand(0).getValueType();
  unsigned NumElemsPerConcat = ConcatVT.getVectorNumElements();
  unsigned NumConcats = NumElts / NumElemsPerConcat;

  auto IsUndefMaskElt = [](int i) { return i == -1; };

  // Special case: shuffle(concat(A,B)) can be more efficiently represented
  // as concat(shuffle(A,B),UNDEF) if the shuffle doesn't set any of the high
  // half vector elements.
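  // e.g. (v4i32 shuffle (concat A, B), undef, <1,2,u,u>)
  //      --> concat (v2i32 shuffle A, B, <1,2>), undef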
  if (NumElemsPerConcat * 2 == NumElts && N1.isUndef() &&
      llvm::all_of(Mask.slice(NumElemsPerConcat, NumElemsPerConcat),
                   IsUndefMaskElt)) {
    N0 = DAG.getVectorShuffle(ConcatVT, SDLoc(N), N0.getOperand(0),
                              N0.getOperand(1),
                              Mask.slice(0, NumElemsPerConcat));
    N1 = DAG.getUNDEF(ConcatVT);
    return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, N0, N1);
  }

  // Look at every vector that's inserted. We're looking for exact
  // subvector-sized copies from a concatenated vector.
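  // e.g. (v8i32 shuffle (concat A, B, C, D), (concat E, F, G, H),
  //                     <2,3,8,9,u,u,14,15>)
  //      --> (v8i32 concat B, E, undef, H)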
  for (unsigned I = 0; I != NumConcats; ++I) {
    unsigned Begin = I * NumElemsPerConcat;
    ArrayRef<int> SubMask = Mask.slice(Begin, NumElemsPerConcat);

    // Make sure we're dealing with a copy.
    if (llvm::all_of(SubMask, IsUndefMaskElt)) {
      Ops.push_back(DAG.getUNDEF(ConcatVT));
      continue;
    }

    int OpIdx = -1;
    for (int i = 0; i != (int)NumElemsPerConcat; ++i) {
      if (IsUndefMaskElt(SubMask[i]))
        continue;
      if ((SubMask[i] % (int)NumElemsPerConcat) != i)
        return SDValue();
      int EltOpIdx = SubMask[i] / NumElemsPerConcat;
      if (0 <= OpIdx && EltOpIdx != OpIdx)
        return SDValue();
      OpIdx = EltOpIdx;
    }
    assert(0 <= OpIdx && "Unknown concat_vectors op");

    if (OpIdx < (int)N0.getNumOperands())
      Ops.push_back(N0.getOperand(OpIdx));
    else
      Ops.push_back(N1.getOperand(OpIdx - N0.getNumOperands()));
  }

  return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
}

// Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
// BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
//
// SHUFFLE(BUILD_VECTOR(), BUILD_VECTOR()) -> BUILD_VECTOR() is always
// a simplification in some sense, but it isn't appropriate in general: some
// BUILD_VECTORs are substantially cheaper than others. The general case
// of a BUILD_VECTOR requires inserting each element individually (or
// performing the equivalent in a temporary stack variable). A BUILD_VECTOR of
// all constants is a single constant pool load.  A BUILD_VECTOR where each
// element is identical is a splat.  A BUILD_VECTOR where most of the operands
// are undef lowers to a small number of element insertions.
//
// To deal with this, we currently use a bunch of mostly arbitrary heuristics.
// We don't fold shuffles where one side is a non-zero constant, and we don't
// fold shuffles if the resulting (non-splat) BUILD_VECTOR would have duplicate
// non-constant operands. This seems to work out reasonably well in practice.
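// e.g. (v4i32 shuffle (build_vector A, B, C, D), undef, <0,u,3,1>)
//      --> (v4i32 build_vector A, undef, D, B)
// (illustrative, assuming A-D are distinct non-constant scalars, so no
// operand is duplicated).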
static SDValue combineShuffleOfScalars(ShuffleVectorSDNode *SVN,
                                       SelectionDAG &DAG,
                                       const TargetLowering &TLI) {
  EVT VT = SVN->getValueType(0);
  unsigned NumElts = VT.getVectorNumElements();
  SDValue N0 = SVN->getOperand(0);
  SDValue N1 = SVN->getOperand(1);

  if (!N0->hasOneUse())
    return SDValue();

  // If only one of N0,N1 is constant, bail out if it is not ALL_ZEROS as
  // discussed above.
  if (!N1.isUndef()) {
    if (!N1->hasOneUse())
      return SDValue();

    bool N0AnyConst = isAnyConstantBuildVector(N0);
    bool N1AnyConst = isAnyConstantBuildVector(N1);
    if (N0AnyConst && !N1AnyConst && !ISD::isBuildVectorAllZeros(N0.getNode()))
      return SDValue();
    if (!N0AnyConst && N1AnyConst && !ISD::isBuildVectorAllZeros(N1.getNode()))
      return SDValue();
  }

  // If both inputs are splats of the same value then we can safely merge this
  // to a single BUILD_VECTOR with undef elements based on the shuffle mask.
  bool IsSplat = false;
  auto *BV0 = dyn_cast<BuildVectorSDNode>(N0);
  auto *BV1 = dyn_cast<BuildVectorSDNode>(N1);
  if (BV0 && BV1)
    if (SDValue Splat0 = BV0->getSplatValue())
      IsSplat = (Splat0 == BV1->getSplatValue());

  SmallVector<SDValue, 8> Ops;
  SmallSet<SDValue, 16> DuplicateOps;
  for (int M : SVN->getMask()) {
    SDValue Op = DAG.getUNDEF(VT.getScalarType());
    if (M >= 0) {
      int Idx = M < (int)NumElts ? M : M - NumElts;
      SDValue &S = (M < (int)NumElts ? N0 : N1);
      if (S.getOpcode() == ISD::BUILD_VECTOR) {
        Op = S.getOperand(Idx);
      } else if (S.getOpcode() == ISD::SCALAR_TO_VECTOR) {
        SDValue Op0 = S.getOperand(0);
        Op = Idx == 0 ? Op0 : DAG.getUNDEF(Op0.getValueType());
      } else {
        // Operand can't be combined - bail out.
        return SDValue();
      }
    }

    // Don't duplicate a non-constant BUILD_VECTOR operand unless we're
    // generating a splat; semantically, this is fine, but it's likely to
    // generate low-quality code if the target can't reconstruct an appropriate
    // shuffle.
    if (!Op.isUndef() && !isIntOrFPConstant(Op))
      if (!IsSplat && !DuplicateOps.insert(Op).second)
        return SDValue();

    Ops.push_back(Op);
  }

  // BUILD_VECTOR requires all inputs to be of the same type, find the
  // maximum type and extend them all.
  EVT SVT = VT.getScalarType();
  if (SVT.isInteger())
    for (SDValue &Op : Ops)
      SVT = (SVT.bitsLT(Op.getValueType()) ? Op.getValueType() : SVT);
  if (SVT != VT.getScalarType())
    for (SDValue &Op : Ops)
      Op = TLI.isZExtFree(Op.getValueType(), SVT)
               ? DAG.getZExtOrTrunc(Op, SDLoc(SVN), SVT)
               : DAG.getSExtOrTrunc(Op, SDLoc(SVN), SVT);
  return DAG.getBuildVector(VT, SDLoc(SVN), Ops);
}

// Match shuffles that can be converted to any_vector_extend_in_reg.
// This is often generated during legalization.
// e.g. v4i32 <0,u,1,u> -> (v2i64 any_vector_extend_in_reg(v4i32 src))
// TODO Add support for ZERO_EXTEND_VECTOR_INREG when we have a test case.
static SDValue combineShuffleToVectorExtend(ShuffleVectorSDNode *SVN,
                                            SelectionDAG &DAG,
                                            const TargetLowering &TLI,
                                            bool LegalOperations) {
  EVT VT = SVN->getValueType(0);
  bool IsBigEndian = DAG.getDataLayout().isBigEndian();

  // TODO Add support for big-endian when we have a test case.
  if (!VT.isInteger() || IsBigEndian)
    return SDValue();

  unsigned NumElts = VT.getVectorNumElements();
  unsigned EltSizeInBits = VT.getScalarSizeInBits();
  ArrayRef<int> Mask = SVN->getMask();
  SDValue N0 = SVN->getOperand(0);

  // shuffle<0,-1,1,-1> == (v2i64 anyextend_vector_inreg(v4i32))
  auto isAnyExtend = [&Mask, &NumElts](unsigned Scale) {
    for (unsigned i = 0; i != NumElts; ++i) {
      if (Mask[i] < 0)
        continue;
      if ((i % Scale) == 0 && Mask[i] == (int)(i / Scale))
        continue;
      return false;
    }
    return true;
  };

  // Attempt to match a '*_extend_vector_inreg' shuffle, we just search for
  // power-of-2 extensions as they are the most likely.
  for (unsigned Scale = 2; Scale < NumElts; Scale *= 2) {
    // Check for non power of 2 vector sizes
    if (NumElts % Scale != 0)
      continue;
    if (!isAnyExtend(Scale))
      continue;

    EVT OutSVT = EVT::getIntegerVT(*DAG.getContext(), EltSizeInBits * Scale);
    EVT OutVT = EVT::getVectorVT(*DAG.getContext(), OutSVT, NumElts / Scale);
    // Never create an illegal type. Only create unsupported operations if we
    // are pre-legalization.
    if (TLI.isTypeLegal(OutVT))
      if (!LegalOperations ||
          TLI.isOperationLegalOrCustom(ISD::ANY_EXTEND_VECTOR_INREG, OutVT))
        return DAG.getBitcast(VT,
                              DAG.getNode(ISD::ANY_EXTEND_VECTOR_INREG,
                                          SDLoc(SVN), OutVT, N0));
  }

  return SDValue();
}

// Detect 'truncate_vector_inreg' style shuffles that pack the lower parts of
// each source element of a large type into the lowest elements of a smaller
// destination type. This is often generated during legalization.
// If the source node itself was a '*_extend_vector_inreg' node then we should
// be able to remove it.
static SDValue combineTruncationShuffle(ShuffleVectorSDNode *SVN,
                                        SelectionDAG &DAG) {
  EVT VT = SVN->getValueType(0);
  bool IsBigEndian = DAG.getDataLayout().isBigEndian();

  // TODO Add support for big-endian when we have a test case.
  if (!VT.isInteger() || IsBigEndian)
    return SDValue();

  SDValue N0 = peekThroughBitcasts(SVN->getOperand(0));

  unsigned Opcode = N0.getOpcode();
  if (Opcode != ISD::ANY_EXTEND_VECTOR_INREG &&
      Opcode != ISD::SIGN_EXTEND_VECTOR_INREG &&
      Opcode != ISD::ZERO_EXTEND_VECTOR_INREG)
    return SDValue();

  SDValue N00 = N0.getOperand(0);
  ArrayRef<int> Mask = SVN->getMask();
  unsigned NumElts = VT.getVectorNumElements();
  unsigned EltSizeInBits = VT.getScalarSizeInBits();
  unsigned ExtSrcSizeInBits = N00.getScalarValueSizeInBits();
  unsigned ExtDstSizeInBits = N0.getScalarValueSizeInBits();

  if (ExtDstSizeInBits % ExtSrcSizeInBits != 0)
    return SDValue();
  unsigned ExtScale = ExtDstSizeInBits / ExtSrcSizeInBits;

  // (v4i32 truncate_vector_inreg(v2i64)) == shuffle<0,2,-1,-1>
  // (v8i16 truncate_vector_inreg(v4i32)) == shuffle<0,2,4,6,-1,-1,-1,-1>
  // (v8i16 truncate_vector_inreg(v2i64)) == shuffle<0,4,-1,-1,-1,-1,-1,-1>
  auto isTruncate = [&Mask, &NumElts](unsigned Scale) {
    for (unsigned i = 0; i != NumElts; ++i) {
      if (Mask[i] < 0)
        continue;
      if ((i * Scale) < NumElts && Mask[i] == (int)(i * Scale))
        continue;
      return false;
    }
    return true;
  };

  // At the moment we just handle the case where we've truncated back to the
  // same size as before the extension.
  // TODO: handle more extension/truncation cases as cases arise.
  if (EltSizeInBits != ExtSrcSizeInBits)
    return SDValue();

  // We can remove *extend_vector_inreg only if the truncation happens at
  // the same scale as the extension.
  if (isTruncate(ExtScale))
    return DAG.getBitcast(VT, N00);

  return SDValue();
}

// Combine shuffles of splat-shuffles of the form:
// shuffle (shuffle V, undef, splat-mask), undef, M
// If splat-mask contains undef elements, we need to be careful about
// introducing undef's in the folded mask which are not the result of composing
// the masks of the shuffles.
static SDValue combineShuffleOfSplatVal(ShuffleVectorSDNode *Shuf,
                                        SelectionDAG &DAG) {
  if (!Shuf->getOperand(1).isUndef())
    return SDValue();
  auto *Splat = dyn_cast<ShuffleVectorSDNode>(Shuf->getOperand(0));
  if (!Splat || !Splat->isSplat())
    return SDValue();

  ArrayRef<int> ShufMask = Shuf->getMask();
  ArrayRef<int> SplatMask = Splat->getMask();
  assert(ShufMask.size() == SplatMask.size() && "Mask length mismatch");

  // Prefer simplifying to the splat-shuffle, if possible. This is legal if
  // every undef mask element in the splat-shuffle has a corresponding undef
  // element in the user-shuffle's mask or if the composition of mask elements
  // would result in undef.
  // Examples for (shuffle (shuffle v, undef, SplatMask), undef, UserMask):
  // * UserMask=[0,2,u,u], SplatMask=[2,u,2,u] -> [2,2,u,u]
  //   In this case it is not legal to simplify to the splat-shuffle because we
  //   may be exposing to the users of the shuffle an undef element at index 1
  //   which was not there before the combine.
  // * UserMask=[0,u,2,u], SplatMask=[2,u,2,u] -> [2,u,2,u]
  //   In this case the composition of masks yields SplatMask, so it's ok to
  //   simplify to the splat-shuffle.
  // * UserMask=[3,u,2,u], SplatMask=[2,u,2,u] -> [u,u,2,u]
  //   In this case the composed mask includes all undef elements of SplatMask
  //   and in addition sets element zero to undef. It is safe to simplify to
  //   the splat-shuffle.
  auto CanSimplifyToExistingSplat = [](ArrayRef<int> UserMask,
                                       ArrayRef<int> SplatMask) {
    for (unsigned i = 0, e = UserMask.size(); i != e; ++i)
      if (UserMask[i] != -1 && SplatMask[i] == -1 &&
          SplatMask[UserMask[i]] != -1)
        return false;
    return true;
  };
  if (CanSimplifyToExistingSplat(ShufMask, SplatMask))
    return Shuf->getOperand(0);

  // Create a new shuffle with a mask that is composed of the two shuffles'
  // masks.
  SmallVector<int, 32> NewMask;
  for (int Idx : ShufMask)
    NewMask.push_back(Idx == -1 ? -1 : SplatMask[Idx]);

  return DAG.getVectorShuffle(Splat->getValueType(0), SDLoc(Splat),
                              Splat->getOperand(0), Splat->getOperand(1),
                              NewMask);
}

/// Combine shuffle of shuffle of the form:
/// shuf (shuf X, undef, InnerMask), undef, OuterMask --> splat X
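/// e.g. shuf (shuf X, undef, <2,u,2,u>), undef, <0,2,0,2>
///      --> shuf X, undef, <2,2,2,2>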
static SDValue formSplatFromShuffles(ShuffleVectorSDNode *OuterShuf,
                                     SelectionDAG &DAG) {
  if (!OuterShuf->getOperand(1).isUndef())
    return SDValue();
  auto *InnerShuf = dyn_cast<ShuffleVectorSDNode>(OuterShuf->getOperand(0));
  if (!InnerShuf || !InnerShuf->getOperand(1).isUndef())
    return SDValue();

  ArrayRef<int> OuterMask = OuterShuf->getMask();
  ArrayRef<int> InnerMask = InnerShuf->getMask();
  unsigned NumElts = OuterMask.size();
  assert(NumElts == InnerMask.size() && "Mask length mismatch");
  SmallVector<int, 32> CombinedMask(NumElts, -1);
  int SplatIndex = -1;
  for (unsigned i = 0; i != NumElts; ++i) {
    // Undef lanes remain undef.
    int OuterMaskElt = OuterMask[i];
    if (OuterMaskElt == -1)
      continue;

    // Peek through the shuffle masks to get the underlying source element.
    int InnerMaskElt = InnerMask[OuterMaskElt];
    if (InnerMaskElt == -1)
      continue;

    // Initialize the splatted element.
    if (SplatIndex == -1)
      SplatIndex = InnerMaskElt;

    // Non-matching index - this is not a splat.
    if (SplatIndex != InnerMaskElt)
      return SDValue();

    CombinedMask[i] = InnerMaskElt;
  }
  assert((all_of(CombinedMask, [](int M) { return M == -1; }) ||
          getSplatIndex(CombinedMask) != -1) &&
         "Expected a splat mask");

  // TODO: The transform may be a win even if the mask is not legal.
  EVT VT = OuterShuf->getValueType(0);
  assert(VT == InnerShuf->getValueType(0) && "Expected matching shuffle types");
  if (!DAG.getTargetLoweringInfo().isShuffleMaskLegal(CombinedMask, VT))
    return SDValue();

  return DAG.getVectorShuffle(VT, SDLoc(OuterShuf), InnerShuf->getOperand(0),
                              InnerShuf->getOperand(1), CombinedMask);
}

/// If the shuffle mask is taking exactly one element from the first vector
/// operand and passing through all other elements from the second vector
/// operand, return the index of the mask element that is choosing an element
/// from the first operand. Otherwise, return -1.
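/// e.g. Mask = <4,5,2,7> (size 4) returns 2; Mask = <4,1,2,7> returns -1
/// because two elements (indices 1 and 2) come from operand 0.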
static int getShuffleMaskIndexOfOneElementFromOp0IntoOp1(ArrayRef<int> Mask) {
  int MaskSize = Mask.size();
  int EltFromOp0 = -1;
  // TODO: This does not match if there are undef elements in the shuffle mask.
  // Should we ignore undefs in the shuffle mask instead? The trade-off is
  // removing an instruction (a shuffle), but losing the knowledge that some
  // vector lanes are not needed.
  for (int i = 0; i != MaskSize; ++i) {
    if (Mask[i] >= 0 && Mask[i] < MaskSize) {
      // We're looking for a shuffle of exactly one element from operand 0.
      if (EltFromOp0 != -1)
        return -1;
      EltFromOp0 = i;
    } else if (Mask[i] != i + MaskSize) {
      // Nothing from operand 1 can change lanes.
      return -1;
    }
  }
  return EltFromOp0;
}

/// If a shuffle inserts exactly one element from a source vector operand into
/// another vector operand and we can access the specified element as a scalar,
/// then we can eliminate the shuffle.
static SDValue replaceShuffleOfInsert(ShuffleVectorSDNode *Shuf,
                                      SelectionDAG &DAG) {
  // First, check if we are taking one element of a vector and shuffling that
  // element into another vector.
  ArrayRef<int> Mask = Shuf->getMask();
  SmallVector<int, 16> CommutedMask(Mask.begin(), Mask.end());
  SDValue Op0 = Shuf->getOperand(0);
  SDValue Op1 = Shuf->getOperand(1);
  int ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(Mask);
  if (ShufOp0Index == -1) {
    // Commute mask and check again.
    ShuffleVectorSDNode::commuteMask(CommutedMask);
    ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(CommutedMask);
    if (ShufOp0Index == -1)
      return SDValue();
    // Commute operands to match the commuted shuffle mask.
    std::swap(Op0, Op1);
    Mask = CommutedMask;
  }

  // The shuffle inserts exactly one element from operand 0 into operand 1.
  // Now see if we can access that element as a scalar via a real insert element
  // instruction.
  // TODO: We can try harder to locate the element as a scalar. Examples: it
  // could be an operand of SCALAR_TO_VECTOR, BUILD_VECTOR, or a constant.
  assert(Mask[ShufOp0Index] >= 0 && Mask[ShufOp0Index] < (int)Mask.size() &&
         "Shuffle mask value must be from operand 0");
  if (Op0.getOpcode() != ISD::INSERT_VECTOR_ELT)
    return SDValue();

  auto *InsIndexC = dyn_cast<ConstantSDNode>(Op0.getOperand(2));
  if (!InsIndexC || InsIndexC->getSExtValue() != Mask[ShufOp0Index])
    return SDValue();

  // There's an existing insertelement with constant insertion index, so we
  // don't need to check the legality/profitability of a replacement operation
  // that differs at most in the constant value. The target should be able to
  // lower any of those in a similar way. If not, legalization will expand this
  // to a scalar-to-vector plus shuffle.
  //
  // Note that the shuffle may move the scalar from the position that the insert
  // element used. Therefore, our new insert element occurs at the shuffle's
  // mask index value, not the insert's index value.
  // shuffle (insertelt v1, x, C), v2, mask --> insertelt v2, x, C'
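  // e.g. shuffle (insertelt v1, x, 0), v2, <4,5,0,7> --> insertelt v2, x, 2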
  SDValue NewInsIndex = DAG.getVectorIdxConstant(ShufOp0Index, SDLoc(Shuf));
  return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(Shuf), Op0.getValueType(),
                     Op1, Op0.getOperand(1), NewInsIndex);
}

/// If we have a unary shuffle of a shuffle, see if it can be folded away
/// completely. This has the potential to lose undef knowledge because the first
/// shuffle may not have an undef mask element where the second one does. So
/// only call this after doing simplifications based on demanded elements.
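/// e.g. shuf (shuf0 X, Y, <0,0,2,2>), undef, <1,0,3,2> --> shuf0, because
/// every lane of the outer shuffle already matches what shuf0 produces.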
static SDValue simplifyShuffleOfShuffle(ShuffleVectorSDNode *Shuf) {
  // shuf (shuf0 X, Y, Mask0), undef, Mask
  auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(Shuf->getOperand(0));
  if (!Shuf0 || !Shuf->getOperand(1).isUndef())
    return SDValue();

  ArrayRef<int> Mask = Shuf->getMask();
  ArrayRef<int> Mask0 = Shuf0->getMask();
  for (int i = 0, e = (int)Mask.size(); i != e; ++i) {
    // Ignore undef elements.
    if (Mask[i] == -1)
      continue;
    assert(Mask[i] >= 0 && Mask[i] < e && "Unexpected shuffle mask value");

    // Is the element of the shuffle operand chosen by this shuffle the same as
    // the element chosen by the shuffle operand itself?
    if (Mask0[Mask[i]] != Mask0[i])
      return SDValue();
  }
  // Every element of this shuffle is identical to the result of the previous
  // shuffle, so we can replace this value.
  return Shuf->getOperand(0);
}

SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
  EVT VT = N->getValueType(0);
  unsigned NumElts = VT.getVectorNumElements();

  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);

  assert(N0.getValueType() == VT && "Vector shuffle must be normalized in DAG");

  // Canonicalize shuffle undef, undef -> undef
  if (N0.isUndef() && N1.isUndef())
    return DAG.getUNDEF(VT);

  ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N);

  // Canonicalize shuffle v, v -> v, undef
  if (N0 == N1) {
    SmallVector<int, 8> NewMask;
    for (unsigned i = 0; i != NumElts; ++i) {
      int Idx = SVN->getMaskElt(i);
      if (Idx >= (int)NumElts) Idx -= NumElts;
      NewMask.push_back(Idx);
    }
    return DAG.getVectorShuffle(VT, SDLoc(N), N0, DAG.getUNDEF(VT), NewMask);
  }

  // Canonicalize shuffle undef, v -> v, undef.  Commute the shuffle mask.
  if (N0.isUndef())
    return DAG.getCommutedVectorShuffle(*SVN);

  // Remove references to rhs if it is undef
  if (N1.isUndef()) {
    bool Changed = false;
    SmallVector<int, 8> NewMask;
    for (unsigned i = 0; i != NumElts; ++i) {
      int Idx = SVN->getMaskElt(i);
      if (Idx >= (int)NumElts) {
        Idx = -1;
        Changed = true;
      }
      NewMask.push_back(Idx);
    }
    if (Changed)
      return DAG.getVectorShuffle(VT, SDLoc(N), N0, N1, NewMask);
  }

  if (SDValue InsElt = replaceShuffleOfInsert(SVN, DAG))
    return InsElt;

  // A shuffle of a single vector that is a splatted value can always be folded.
  if (SDValue V = combineShuffleOfSplatVal(SVN, DAG))
    return V;

  if (SDValue V = formSplatFromShuffles(SVN, DAG))
    return V;

  // If it is a splat, check if the argument vector is another splat or a
  // build_vector.
  if (SVN->isSplat() && SVN->getSplatIndex() < (int)NumElts) {
    int SplatIndex = SVN->getSplatIndex();
    if (N0.hasOneUse() && TLI.isExtractVecEltCheap(VT, SplatIndex) &&
        TLI.isBinOp(N0.getOpcode()) && N0.getNode()->getNumValues() == 1) {
      // splat (vector_bo L, R), Index -->
      // splat (scalar_bo (extelt L, Index), (extelt R, Index))
      SDValue L = N0.getOperand(0), R = N0.getOperand(1);
      SDLoc DL(N);
      EVT EltVT = VT.getScalarType();
      SDValue Index = DAG.getVectorIdxConstant(SplatIndex, DL);
      SDValue ExtL = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, L, Index);
      SDValue ExtR = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, R, Index);
      SDValue NewBO = DAG.getNode(N0.getOpcode(), DL, EltVT, ExtL, ExtR,
                                  N0.getNode()->getFlags());
      SDValue Insert = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT, NewBO);
      SmallVector<int, 16> ZeroMask(VT.getVectorNumElements(), 0);
      return DAG.getVectorShuffle(VT, DL, Insert, DAG.getUNDEF(VT), ZeroMask);
    }

    // If this is a bit convert that changes the element type of the vector but
    // not the number of vector elements, look through it.  Be careful not to
    // look through conversions that change things like v4f32 to v2f64.
    SDNode *V = N0.getNode();
    if (V->getOpcode() == ISD::BITCAST) {
      SDValue ConvInput = V->getOperand(0);
      if (ConvInput.getValueType().isVector() &&
          ConvInput.getValueType().getVectorNumElements() == NumElts)
        V = ConvInput.getNode();
    }

    if (V->getOpcode() == ISD::BUILD_VECTOR) {
      assert(V->getNumOperands() == NumElts &&
             "BUILD_VECTOR has wrong number of operands");
      SDValue Base;
      bool AllSame = true;
      for (unsigned i = 0; i != NumElts; ++i) {
        if (!V->getOperand(i).isUndef()) {
          Base = V->getOperand(i);
          break;
        }
      }
      // Splat of <u, u, u, u>, return <u, u, u, u>
      if (!Base.getNode())
        return N0;
      for (unsigned i = 0; i != NumElts; ++i) {
        if (V->getOperand(i) != Base) {
          AllSame = false;
          break;
        }
      }
      // Splat of <x, x, x, x>, return <x, x, x, x>
      if (AllSame)
        return N0;

      // Canonicalize any other splat as a build_vector.
      SDValue Splatted = V->getOperand(SplatIndex);
      SmallVector<SDValue, 8> Ops(NumElts, Splatted);
      SDValue NewBV = DAG.getBuildVector(V->getValueType(0), SDLoc(N), Ops);

      // We may have jumped through bitcasts, so the type of the
      // BUILD_VECTOR may not match the type of the shuffle.
      if (V->getValueType(0) != VT)
        NewBV = DAG.getBitcast(VT, NewBV);
      return NewBV;
    }
  }

  // Simplify source operands based on shuffle mask.
  if (SimplifyDemandedVectorElts(SDValue(N, 0)))
    return SDValue(N, 0);

  // This is intentionally placed after demanded elements simplification because
  // it could eliminate knowledge of undef elements created by this shuffle.
  if (SDValue ShufOp = simplifyShuffleOfShuffle(SVN))
    return ShufOp;

  // Match shuffles that can be converted to any_vector_extend_in_reg.
  if (SDValue V = combineShuffleToVectorExtend(SVN, DAG, TLI, LegalOperations))
    return V;

  // Combine "truncate_vector_in_reg" style shuffles.
  if (SDValue V = combineTruncationShuffle(SVN, DAG))
    return V;

  if (N0.getOpcode() == ISD::CONCAT_VECTORS &&
      Level < AfterLegalizeVectorOps &&
      (N1.isUndef() ||
      (N1.getOpcode() == ISD::CONCAT_VECTORS &&
       N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType()))) {
    if (SDValue V = partitionShuffleOfConcats(N, DAG))
      return V;
  }

  // A shuffle of a concat of the same narrow vector can be reduced to use
  // only low-half elements of a concat with undef:
  // shuf (concat X, X), undef, Mask --> shuf (concat X, undef), undef, Mask'
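  // e.g. shuf (v4i32 concat X, X), undef, <0,2,1,3>
  //      --> shuf (v4i32 concat X, undef), undef, <0,0,1,1>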
21275   if (N0.getOpcode() == ISD::CONCAT_VECTORS && N1.isUndef() &&
21276       N0.getNumOperands() == 2 &&
21277       N0.getOperand(0) == N0.getOperand(1)) {
21278     int HalfNumElts = (int)NumElts / 2;
21279     SmallVector<int, 8> NewMask;
21280     for (unsigned i = 0; i != NumElts; ++i) {
21281       int Idx = SVN->getMaskElt(i);
21282       if (Idx >= HalfNumElts) {
21283         assert(Idx < (int)NumElts && "Shuffle mask chooses undef op");
21284         Idx -= HalfNumElts;
21285       }
21286       NewMask.push_back(Idx);
21287     }
21288     if (TLI.isShuffleMaskLegal(NewMask, VT)) {
21289       SDValue UndefVec = DAG.getUNDEF(N0.getOperand(0).getValueType());
21290       SDValue NewCat = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT,
21291                                    N0.getOperand(0), UndefVec);
21292       return DAG.getVectorShuffle(VT, SDLoc(N), NewCat, N1, NewMask);
21293     }
21294   }
21295 
21296   // Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
21297   // BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
21298   if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT))
21299     if (SDValue Res = combineShuffleOfScalars(SVN, DAG, TLI))
21300       return Res;
21301 
21302   // If this shuffle only has a single input that is a bitcasted shuffle,
21303   // attempt to merge the 2 shuffles and suitably bitcast the inputs/output
21304   // back to their original types.
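  // A sketch of the rescaling (illustrative): for a v4i32 shuffle of
  // (bitcast (v8i16 shuffle A, B, <0,8,1,9,2,10,3,11>)) with mask <0,2,-1,3>,
  // the outer mask widens to the i16 scale as <0,1,4,5,-1,-1,6,7>, and the
  // merged result is bitcast (v8i16 shuffle A, B, <0,8,2,10,-1,-1,3,11>).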
21305   if (N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
21306       N1.isUndef() && Level < AfterLegalizeVectorOps &&
21307       TLI.isTypeLegal(VT)) {
21308 
21309     SDValue BC0 = peekThroughOneUseBitcasts(N0);
21310     if (BC0.getOpcode() == ISD::VECTOR_SHUFFLE && BC0.hasOneUse()) {
21311       EVT SVT = VT.getScalarType();
21312       EVT InnerVT = BC0->getValueType(0);
21313       EVT InnerSVT = InnerVT.getScalarType();
21314 
21315       // Determine which shuffle works with the smaller scalar type.
21316       EVT ScaleVT = SVT.bitsLT(InnerSVT) ? VT : InnerVT;
21317       EVT ScaleSVT = ScaleVT.getScalarType();
21318 
21319       if (TLI.isTypeLegal(ScaleVT) &&
21320           0 == (InnerSVT.getSizeInBits() % ScaleSVT.getSizeInBits()) &&
21321           0 == (SVT.getSizeInBits() % ScaleSVT.getSizeInBits())) {
21322         int InnerScale = InnerSVT.getSizeInBits() / ScaleSVT.getSizeInBits();
21323         int OuterScale = SVT.getSizeInBits() / ScaleSVT.getSizeInBits();
21324 
21325         // Scale the shuffle masks to the smaller scalar type.
21326         ShuffleVectorSDNode *InnerSVN = cast<ShuffleVectorSDNode>(BC0);
21327         SmallVector<int, 8> InnerMask;
21328         SmallVector<int, 8> OuterMask;
21329         narrowShuffleMaskElts(InnerScale, InnerSVN->getMask(), InnerMask);
21330         narrowShuffleMaskElts(OuterScale, SVN->getMask(), OuterMask);
21331 
21332         // Merge the shuffle masks.
21333         SmallVector<int, 8> NewMask;
21334         for (int M : OuterMask)
21335           NewMask.push_back(M < 0 ? -1 : InnerMask[M]);
21336 
21337         // Test for shuffle mask legality over both commutations.
21338         SDValue SV0 = BC0->getOperand(0);
21339         SDValue SV1 = BC0->getOperand(1);
21340         bool LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
21341         if (!LegalMask) {
21342           std::swap(SV0, SV1);
21343           ShuffleVectorSDNode::commuteMask(NewMask);
21344           LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
21345         }
21346 
21347         if (LegalMask) {
21348           SV0 = DAG.getBitcast(ScaleVT, SV0);
21349           SV1 = DAG.getBitcast(ScaleVT, SV1);
21350           return DAG.getBitcast(
21351               VT, DAG.getVectorShuffle(ScaleVT, SDLoc(N), SV0, SV1, NewMask));
21352         }
21353       }
21354     }
21355   }
21356 
21357   // Compute the combined shuffle mask for a shuffle with SV0 as the first
21358   // operand, and SV1 as the second operand.
21359   // i.e. Merge SVN(OtherSVN, N1) -> shuffle(SV0, SV1, Mask) iff Commute = false
21360   //      Merge SVN(N1, OtherSVN) -> shuffle(SV0, SV1, Mask') iff Commute = true
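  // For example (illustrative), with NumElts = 4 and Commute = false:
  //   Merge shuffle(shuffle(A, B, <0,4,1,5>), B, <0,2,4,6>)
  //     -> shuffle(A, B, <0,1,4,6>)
  // since outer indices 0 and 2 look through the inner mask to A[0] and A[1],
  // while indices 4 and 6 select B[0] and B[2] directly from N1.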
21361   auto MergeInnerShuffle =
21362       [NumElts, &VT](bool Commute, ShuffleVectorSDNode *SVN,
21363                      ShuffleVectorSDNode *OtherSVN, SDValue N1,
21364                      const TargetLowering &TLI, SDValue &SV0, SDValue &SV1,
21365                      SmallVectorImpl<int> &Mask) -> bool {
21366     // Don't try to fold splats; they're likely to simplify somehow, or they
21367     // might be free.
21368     if (OtherSVN->isSplat())
21369       return false;
21370 
21371     SV0 = SV1 = SDValue();
21372     Mask.clear();
21373 
21374     for (unsigned i = 0; i != NumElts; ++i) {
21375       int Idx = SVN->getMaskElt(i);
21376       if (Idx < 0) {
21377         // Propagate Undef.
21378         Mask.push_back(Idx);
21379         continue;
21380       }
21381 
21382       if (Commute)
21383         Idx = (Idx < (int)NumElts) ? (Idx + NumElts) : (Idx - NumElts);
21384 
21385       SDValue CurrentVec;
21386       if (Idx < (int)NumElts) {
21387         // This shuffle index refers to the inner shuffle N0. Lookup the inner
21388         // shuffle mask to identify which vector is actually referenced.
21389         Idx = OtherSVN->getMaskElt(Idx);
21390         if (Idx < 0) {
21391           // Propagate Undef.
21392           Mask.push_back(Idx);
21393           continue;
21394         }
21395         CurrentVec = (Idx < (int)NumElts) ? OtherSVN->getOperand(0)
21396                                           : OtherSVN->getOperand(1);
21397       } else {
21398         // This shuffle index references an element within N1.
21399         CurrentVec = N1;
21400       }
21401 
21402       // Simple case where 'CurrentVec' is UNDEF.
21403       if (CurrentVec.isUndef()) {
21404         Mask.push_back(-1);
21405         continue;
21406       }
21407 
21408       // Canonicalize the shuffle index. We don't know yet if CurrentVec
21409       // will be the first or second operand of the combined shuffle.
21410       Idx = Idx % NumElts;
21411       if (!SV0.getNode() || SV0 == CurrentVec) {
21412         // Ok. CurrentVec is the left hand side.
21413         // Update the mask accordingly.
21414         SV0 = CurrentVec;
21415         Mask.push_back(Idx);
21416         continue;
21417       }
21418       if (!SV1.getNode() || SV1 == CurrentVec) {
21419         // Ok. CurrentVec is the right hand side.
21420         // Update the mask accordingly.
21421         SV1 = CurrentVec;
21422         Mask.push_back(Idx + NumElts);
21423         continue;
21424       }
21425 
21426       // Last chance - see if the vector is another shuffle and if it
21427       // uses one of the existing candidate shuffle ops.
21428       if (auto *CurrentSVN = dyn_cast<ShuffleVectorSDNode>(CurrentVec)) {
21429         int InnerIdx = CurrentSVN->getMaskElt(Idx);
21430         if (InnerIdx < 0) {
21431           Mask.push_back(-1);
21432           continue;
21433         }
21434         SDValue InnerVec = (InnerIdx < (int)NumElts)
21435                                ? CurrentSVN->getOperand(0)
21436                                : CurrentSVN->getOperand(1);
21437         if (InnerVec.isUndef()) {
21438           Mask.push_back(-1);
21439           continue;
21440         }
21441         InnerIdx %= NumElts;
21442         if (InnerVec == SV0) {
21443           Mask.push_back(InnerIdx);
21444           continue;
21445         }
21446         if (InnerVec == SV1) {
21447           Mask.push_back(InnerIdx + NumElts);
21448           continue;
21449         }
21450       }
21451 
21452       // Bail out if we cannot convert the shuffle pair into a single shuffle.
21453       return false;
21454     }
21455 
21456     if (llvm::all_of(Mask, [](int M) { return M < 0; }))
21457       return true;
21458 
21459     // Avoid introducing shuffles with illegal mask.
21460     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
21461     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
21462     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
21463     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, A, M2)
21464     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, A, M2)
21465     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, B, M2)
21466     if (TLI.isShuffleMaskLegal(Mask, VT))
21467       return true;
21468 
21469     std::swap(SV0, SV1);
21470     ShuffleVectorSDNode::commuteMask(Mask);
21471     return TLI.isShuffleMaskLegal(Mask, VT);
21472   };
21473 
21474   if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
21475     // Canonicalize shuffles according to rules:
21476     //  shuffle(A, shuffle(A, B)) -> shuffle(shuffle(A,B), A)
21477     //  shuffle(B, shuffle(A, B)) -> shuffle(shuffle(A,B), B)
21478     //  shuffle(B, shuffle(A, Undef)) -> shuffle(shuffle(A, Undef), B)
21479     if (N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
21480         N0.getOpcode() != ISD::VECTOR_SHUFFLE) {
21481       // The incoming shuffle must be of the same type as the result of the
21482       // current shuffle.
21483       assert(N1->getOperand(0).getValueType() == VT &&
21484              "Shuffle types don't match");
21485 
21486       SDValue SV0 = N1->getOperand(0);
21487       SDValue SV1 = N1->getOperand(1);
21488       bool HasSameOp0 = N0 == SV0;
21489       bool IsSV1Undef = SV1.isUndef();
21490       if (HasSameOp0 || IsSV1Undef || N0 == SV1)
21491         // Commute the operands of this shuffle so merging below will trigger.
21492         return DAG.getCommutedVectorShuffle(*SVN);
21493     }
21494 
21495     // Canonicalize splat shuffles to the RHS to improve merging below.
21496     //  shuffle(splat(A,u), shuffle(C,D)) -> shuffle'(shuffle(C,D), splat(A,u))
21497     if (N0.getOpcode() == ISD::VECTOR_SHUFFLE &&
21498         N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
21499         cast<ShuffleVectorSDNode>(N0)->isSplat() &&
21500         !cast<ShuffleVectorSDNode>(N1)->isSplat()) {
21501       return DAG.getCommutedVectorShuffle(*SVN);
21502     }
21503 
21504     // Try to fold according to rules:
21505     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
21506     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
21507     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
21508     // Don't try to fold shuffles with illegal type.
21509     // Only fold if this shuffle is the only user of the other shuffle.
21510     // Try matching shuffle(C,shuffle(A,B)) commuted patterns as well.
21511     for (int i = 0; i != 2; ++i) {
21512       if (N->getOperand(i).getOpcode() == ISD::VECTOR_SHUFFLE &&
21513           N->isOnlyUserOf(N->getOperand(i).getNode())) {
21514         // The incoming shuffle must be of the same type as the result of the
21515         // current shuffle.
21516         auto *OtherSV = cast<ShuffleVectorSDNode>(N->getOperand(i));
21517         assert(OtherSV->getOperand(0).getValueType() == VT &&
21518                "Shuffle types don't match");
21519 
21520         SDValue SV0, SV1;
21521         SmallVector<int, 4> Mask;
21522         if (MergeInnerShuffle(i != 0, SVN, OtherSV, N->getOperand(1 - i), TLI,
21523                               SV0, SV1, Mask)) {
21524           // Check if all indices in Mask are Undef. If so, propagate Undef.
21525           if (llvm::all_of(Mask, [](int M) { return M < 0; }))
21526             return DAG.getUNDEF(VT);
21527 
21528           return DAG.getVectorShuffle(VT, SDLoc(N),
21529                                       SV0 ? SV0 : DAG.getUNDEF(VT),
21530                                       SV1 ? SV1 : DAG.getUNDEF(VT), Mask);
21531         }
21532       }
21533     }
21534 
21535     // Merge shuffles through binops if we are able to merge them with at least
21536     // one other shuffle.
21537     // shuffle(bop(shuffle(x,y),shuffle(z,w)),undef)
21538     // shuffle(bop(shuffle(x,y),shuffle(z,w)),bop(shuffle(a,b),shuffle(c,d)))
21539     unsigned SrcOpcode = N0.getOpcode();
21540     if (TLI.isBinOp(SrcOpcode) && N->isOnlyUserOf(N0.getNode()) &&
21541         (N1.isUndef() ||
21542          (SrcOpcode == N1.getOpcode() && N->isOnlyUserOf(N1.getNode())))) {
21543       // Get binop source ops, or just pass on the undef.
21544       SDValue Op00 = N0.getOperand(0);
21545       SDValue Op01 = N0.getOperand(1);
21546       SDValue Op10 = N1.isUndef() ? N1 : N1.getOperand(0);
21547       SDValue Op11 = N1.isUndef() ? N1 : N1.getOperand(1);
21548       // TODO: We might be able to relax the VT check but we don't currently
21549       // have any isBinOp() that has different result/ops VTs so play safe until
21550       // we have test coverage.
21551       if (Op00.getValueType() == VT && Op10.getValueType() == VT &&
21552           Op01.getValueType() == VT && Op11.getValueType() == VT &&
21553           (Op00.getOpcode() == ISD::VECTOR_SHUFFLE ||
21554            Op10.getOpcode() == ISD::VECTOR_SHUFFLE ||
21555            Op01.getOpcode() == ISD::VECTOR_SHUFFLE ||
21556            Op11.getOpcode() == ISD::VECTOR_SHUFFLE)) {
21557         auto CanMergeInnerShuffle = [&](SDValue &SV0, SDValue &SV1,
21558                                         SmallVectorImpl<int> &Mask, bool LeftOp,
21559                                         bool Commute) {
21560           SDValue InnerN = Commute ? N1 : N0;
21561           SDValue Op0 = LeftOp ? Op00 : Op01;
21562           SDValue Op1 = LeftOp ? Op10 : Op11;
21563           if (Commute)
21564             std::swap(Op0, Op1);
21565           // Only accept the merged shuffle if we don't introduce undef elements,
21566           // or the inner shuffle already contained undef elements.
21567           auto *SVN0 = dyn_cast<ShuffleVectorSDNode>(Op0);
21568           return SVN0 && InnerN->isOnlyUserOf(SVN0) &&
21569                  MergeInnerShuffle(Commute, SVN, SVN0, Op1, TLI, SV0, SV1,
21570                                    Mask) &&
21571                  (llvm::any_of(SVN0->getMask(), [](int M) { return M < 0; }) ||
21572                   llvm::none_of(Mask, [](int M) { return M < 0; }));
21573         };
21574 
21575         // Ensure we don't increase the number of shuffles - we must merge a
21576         // shuffle from at least one of the LHS and RHS ops.
21577         bool MergedLeft = false;
21578         SDValue LeftSV0, LeftSV1;
21579         SmallVector<int, 4> LeftMask;
21580         if (CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, false) ||
21581             CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, true)) {
21582           MergedLeft = true;
21583         } else {
21584           LeftMask.assign(SVN->getMask().begin(), SVN->getMask().end());
21585           LeftSV0 = Op00, LeftSV1 = Op10;
21586         }
21587 
21588         bool MergedRight = false;
21589         SDValue RightSV0, RightSV1;
21590         SmallVector<int, 4> RightMask;
21591         if (CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, false) ||
21592             CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, true)) {
21593           MergedRight = true;
21594         } else {
21595           RightMask.assign(SVN->getMask().begin(), SVN->getMask().end());
21596           RightSV0 = Op01, RightSV1 = Op11;
21597         }
21598 
21599         if (MergedLeft || MergedRight) {
21600           SDLoc DL(N);
21601           SDValue LHS = DAG.getVectorShuffle(
21602               VT, DL, LeftSV0 ? LeftSV0 : DAG.getUNDEF(VT),
21603               LeftSV1 ? LeftSV1 : DAG.getUNDEF(VT), LeftMask);
21604           SDValue RHS = DAG.getVectorShuffle(
21605               VT, DL, RightSV0 ? RightSV0 : DAG.getUNDEF(VT),
21606               RightSV1 ? RightSV1 : DAG.getUNDEF(VT), RightMask);
21607           return DAG.getNode(SrcOpcode, DL, VT, LHS, RHS);
21608         }
21609       }
21610     }
21611   }
21612 
21613   if (SDValue V = foldShuffleOfConcatUndefs(SVN, DAG))
21614     return V;
21615 
21616   return SDValue();
21617 }
21618 
21619 SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) {
21620   SDValue InVal = N->getOperand(0);
21621   EVT VT = N->getValueType(0);
21622 
21623   // Replace a SCALAR_TO_VECTOR(EXTRACT_VECTOR_ELT(V,C0)) pattern
21624   // with a VECTOR_SHUFFLE and possible truncate.
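  // e.g. (illustrative, for v4i32 V):
  //   scalar_to_vector (extract_vector_elt V, 2)
  //     -> vector_shuffle V, undef, <2, -1, -1, -1>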
21625   if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
21626       VT.isFixedLengthVector() &&
21627       InVal->getOperand(0).getValueType().isFixedLengthVector()) {
21628     SDValue InVec = InVal->getOperand(0);
21629     SDValue EltNo = InVal->getOperand(1);
21630     auto InVecT = InVec.getValueType();
21631     if (ConstantSDNode *C0 = dyn_cast<ConstantSDNode>(EltNo)) {
21632       SmallVector<int, 8> NewMask(InVecT.getVectorNumElements(), -1);
21633       int Elt = C0->getZExtValue();
21634       NewMask[0] = Elt;
21635       // If we have an implicit truncate, do the truncate here as long as it's
21636       // legal; if it's not, fall through and try to build a legal shuffle below.
21637       if (VT.getScalarType() != InVal.getValueType() &&
21638           InVal.getValueType().isScalarInteger() &&
21639           isTypeLegal(VT.getScalarType())) {
21640         SDValue Val =
21641             DAG.getNode(ISD::TRUNCATE, SDLoc(InVal), VT.getScalarType(), InVal);
21642         return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), VT, Val);
21643       }
21644       if (VT.getScalarType() == InVecT.getScalarType() &&
21645           VT.getVectorNumElements() <= InVecT.getVectorNumElements()) {
21646         SDValue LegalShuffle =
21647           TLI.buildLegalVectorShuffle(InVecT, SDLoc(N), InVec,
21648                                       DAG.getUNDEF(InVecT), NewMask, DAG);
21649         if (LegalShuffle) {
21650           // If the initial vector is the correct size this shuffle is a
21651           // valid result.
21652           if (VT == InVecT)
21653             return LegalShuffle;
21654           // If not we must truncate the vector.
21655           if (VT.getVectorNumElements() != InVecT.getVectorNumElements()) {
21656             SDValue ZeroIdx = DAG.getVectorIdxConstant(0, SDLoc(N));
21657             EVT SubVT = EVT::getVectorVT(*DAG.getContext(),
21658                                          InVecT.getVectorElementType(),
21659                                          VT.getVectorNumElements());
21660             return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), SubVT,
21661                                LegalShuffle, ZeroIdx);
21662           }
21663         }
21664       }
21665     }
21666   }
21667 
21668   return SDValue();
21669 }
21670 
21671 SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) {
21672   EVT VT = N->getValueType(0);
21673   SDValue N0 = N->getOperand(0);
21674   SDValue N1 = N->getOperand(1);
21675   SDValue N2 = N->getOperand(2);
21676   uint64_t InsIdx = N->getConstantOperandVal(2);
21677 
21678   // If inserting an UNDEF, just return the original vector.
21679   if (N1.isUndef())
21680     return N0;
21681 
21682   // If this is an insert of an extracted vector into an undef vector, we can
21683   // just use the input to the extract.
21684   if (N0.isUndef() && N1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
21685       N1.getOperand(1) == N2 && N1.getOperand(0).getValueType() == VT)
21686     return N1.getOperand(0);
21687 
21688   // If we are inserting a bitcast value into an undef, with the same
21689   // number of elements, just use the bitcast input of the extract.
21690   // i.e. INSERT_SUBVECTOR UNDEF (BITCAST N1) N2 ->
21691   //        BITCAST (INSERT_SUBVECTOR UNDEF N1 N2)
21692   if (N0.isUndef() && N1.getOpcode() == ISD::BITCAST &&
21693       N1.getOperand(0).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
21694       N1.getOperand(0).getOperand(1) == N2 &&
21695       N1.getOperand(0).getOperand(0).getValueType().getVectorElementCount() ==
21696           VT.getVectorElementCount() &&
21697       N1.getOperand(0).getOperand(0).getValueType().getSizeInBits() ==
21698           VT.getSizeInBits()) {
21699     return DAG.getBitcast(VT, N1.getOperand(0).getOperand(0));
21700   }
21701 
21702   // If both N0 and N1 are bitcast values on which insert_subvector
21703   // would make sense, pull the bitcast through.
21704   // i.e. INSERT_SUBVECTOR (BITCAST N0) (BITCAST N1) N2 ->
21705   //        BITCAST (INSERT_SUBVECTOR N0 N1 N2)
21706   if (N0.getOpcode() == ISD::BITCAST && N1.getOpcode() == ISD::BITCAST) {
21707     SDValue CN0 = N0.getOperand(0);
21708     SDValue CN1 = N1.getOperand(0);
21709     EVT CN0VT = CN0.getValueType();
21710     EVT CN1VT = CN1.getValueType();
21711     if (CN0VT.isVector() && CN1VT.isVector() &&
21712         CN0VT.getVectorElementType() == CN1VT.getVectorElementType() &&
21713         CN0VT.getVectorElementCount() == VT.getVectorElementCount()) {
21714       SDValue NewINSERT = DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N),
21715                                       CN0.getValueType(), CN0, CN1, N2);
21716       return DAG.getBitcast(VT, NewINSERT);
21717     }
21718   }
21719 
21720   // Combine INSERT_SUBVECTORs where we are inserting to the same index.
21721   // INSERT_SUBVECTOR( INSERT_SUBVECTOR( Vec, SubOld, Idx ), SubNew, Idx )
21722   // --> INSERT_SUBVECTOR( Vec, SubNew, Idx )
21723   if (N0.getOpcode() == ISD::INSERT_SUBVECTOR &&
21724       N0.getOperand(1).getValueType() == N1.getValueType() &&
21725       N0.getOperand(2) == N2)
21726     return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0.getOperand(0),
21727                        N1, N2);
21728 
21729   // Eliminate an intermediate insert into an undef vector:
21730   // insert_subvector undef, (insert_subvector undef, X, 0), N2 -->
21731   // insert_subvector undef, X, N2
21732   if (N0.isUndef() && N1.getOpcode() == ISD::INSERT_SUBVECTOR &&
21733       N1.getOperand(0).isUndef() && isNullConstant(N1.getOperand(2)))
21734     return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0,
21735                        N1.getOperand(1), N2);
21736 
21737   // Push subvector bitcasts to the output, adjusting the index as we go.
21738   // insert_subvector(bitcast(v), bitcast(s), c1)
21739   // -> bitcast(insert_subvector(v, s, c2))
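  // For example (illustrative): inserting (bitcast v1i64 S) into
  // (bitcast v4i64 V) as v2i32 at index 2 rescales to the i64 type:
  //   insert_subvector (bitcast V), (bitcast S), 2
  //     -> bitcast (insert_subvector V, S, 1)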
21740   if ((N0.isUndef() || N0.getOpcode() == ISD::BITCAST) &&
21741       N1.getOpcode() == ISD::BITCAST) {
21742     SDValue N0Src = peekThroughBitcasts(N0);
21743     SDValue N1Src = peekThroughBitcasts(N1);
21744     EVT N0SrcSVT = N0Src.getValueType().getScalarType();
21745     EVT N1SrcSVT = N1Src.getValueType().getScalarType();
21746     if ((N0.isUndef() || N0SrcSVT == N1SrcSVT) &&
21747         N0Src.getValueType().isVector() && N1Src.getValueType().isVector()) {
21748       EVT NewVT;
21749       SDLoc DL(N);
21750       SDValue NewIdx;
21751       LLVMContext &Ctx = *DAG.getContext();
21752       ElementCount NumElts = VT.getVectorElementCount();
21753       unsigned EltSizeInBits = VT.getScalarSizeInBits();
21754       if ((EltSizeInBits % N1SrcSVT.getSizeInBits()) == 0) {
21755         unsigned Scale = EltSizeInBits / N1SrcSVT.getSizeInBits();
21756         NewVT = EVT::getVectorVT(Ctx, N1SrcSVT, NumElts * Scale);
21757         NewIdx = DAG.getVectorIdxConstant(InsIdx * Scale, DL);
21758       } else if ((N1SrcSVT.getSizeInBits() % EltSizeInBits) == 0) {
21759         unsigned Scale = N1SrcSVT.getSizeInBits() / EltSizeInBits;
21760         if (NumElts.isKnownMultipleOf(Scale) && (InsIdx % Scale) == 0) {
21761           NewVT = EVT::getVectorVT(Ctx, N1SrcSVT,
21762                                    NumElts.divideCoefficientBy(Scale));
21763           NewIdx = DAG.getVectorIdxConstant(InsIdx / Scale, DL);
21764         }
21765       }
21766       if (NewIdx && hasOperation(ISD::INSERT_SUBVECTOR, NewVT)) {
21767         SDValue Res = DAG.getBitcast(NewVT, N0Src);
21768         Res = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, NewVT, Res, N1Src, NewIdx);
21769         return DAG.getBitcast(VT, Res);
21770       }
21771     }
21772   }
21773 
21774   // Canonicalize insert_subvector dag nodes.
21775   // Example:
21776   // (insert_subvector (insert_subvector A, S0, Idx0), S1, Idx1)
21777   // -> (insert_subvector (insert_subvector A, S1, Idx1), S0, Idx0), iff Idx1 < Idx0
21778   if (N0.getOpcode() == ISD::INSERT_SUBVECTOR && N0.hasOneUse() &&
21779       N1.getValueType() == N0.getOperand(1).getValueType()) {
21780     unsigned OtherIdx = N0.getConstantOperandVal(2);
21781     if (InsIdx < OtherIdx) {
21782       // Swap nodes.
21783       SDValue NewOp = DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT,
21784                                   N0.getOperand(0), N1, N2);
21785       AddToWorklist(NewOp.getNode());
21786       return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N0.getNode()),
21787                          VT, NewOp, N0.getOperand(1), N0.getOperand(2));
21788     }
21789   }
21790 
21791   // If the input vector is a concatenation, and the insert replaces
21792   // one of the pieces, we can optimize into a single concat_vectors.
21793   if (N0.getOpcode() == ISD::CONCAT_VECTORS && N0.hasOneUse() &&
21794       N0.getOperand(0).getValueType() == N1.getValueType() &&
21795       N0.getOperand(0).getValueType().isScalableVector() ==
21796           N1.getValueType().isScalableVector()) {
21797     unsigned Factor = N1.getValueType().getVectorMinNumElements();
21798     SmallVector<SDValue, 8> Ops(N0->op_begin(), N0->op_end());
21799     Ops[InsIdx / Factor] = N1;
21800     return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
21801   }
21802 
21803   // Simplify source operands based on insertion.
21804   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
21805     return SDValue(N, 0);
21806 
21807   return SDValue();
21808 }
21809 
21810 SDValue DAGCombiner::visitFP_TO_FP16(SDNode *N) {
21811   SDValue N0 = N->getOperand(0);
21812 
21813   // fold (fp_to_fp16 (fp16_to_fp op)) -> op
21814   if (N0->getOpcode() == ISD::FP16_TO_FP)
21815     return N0->getOperand(0);
21816 
21817   return SDValue();
21818 }
21819 
21820 SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
21821   SDValue N0 = N->getOperand(0);
21822 
21823   // fold fp16_to_fp(op & 0xffff) -> fp16_to_fp(op)
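  // (The conversion only consumes the low 16 bits of its operand, so the
  // mask is redundant unless the target asks to keep it.)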
21824   if (!TLI.shouldKeepZExtForFP16Conv() && N0->getOpcode() == ISD::AND) {
21825     ConstantSDNode *AndConst = getAsNonOpaqueConstant(N0.getOperand(1));
21826     if (AndConst && AndConst->getAPIntValue() == 0xffff) {
21827       return DAG.getNode(ISD::FP16_TO_FP, SDLoc(N), N->getValueType(0),
21828                          N0.getOperand(0));
21829     }
21830   }
21831 
21832   return SDValue();
21833 }
21834 
21835 SDValue DAGCombiner::visitVECREDUCE(SDNode *N) {
21836   SDValue N0 = N->getOperand(0);
21837   EVT VT = N0.getValueType();
21838   unsigned Opcode = N->getOpcode();
21839 
21840   // VECREDUCE over 1-element vector is just an extract.
21841   if (VT.getVectorElementCount().isScalar()) {
21842     SDLoc dl(N);
21843     SDValue Res =
21844         DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT.getVectorElementType(), N0,
21845                     DAG.getVectorIdxConstant(0, dl));
21846     if (Res.getValueType() != N->getValueType(0))
21847       Res = DAG.getNode(ISD::ANY_EXTEND, dl, N->getValueType(0), Res);
21848     return Res;
21849   }
21850 
21851   // On a boolean vector an and/or reduction is the same as a umin/umax
21852   // reduction. Convert them if the latter is legal while the former isn't.
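  // For 0/-1 boolean lanes (all sign bits), this holds because -1 is the
  // unsigned max and 0 the unsigned min, e.g. for lanes <0, -1>:
  //   umin(0, -1) == 0 == and(0, -1) and umax(0, -1) == -1 == or(0, -1).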
21853   if (Opcode == ISD::VECREDUCE_AND || Opcode == ISD::VECREDUCE_OR) {
21854     unsigned NewOpcode = Opcode == ISD::VECREDUCE_AND
21855         ? ISD::VECREDUCE_UMIN : ISD::VECREDUCE_UMAX;
21856     if (!TLI.isOperationLegalOrCustom(Opcode, VT) &&
21857         TLI.isOperationLegalOrCustom(NewOpcode, VT) &&
21858         DAG.ComputeNumSignBits(N0) == VT.getScalarSizeInBits())
21859       return DAG.getNode(NewOpcode, SDLoc(N), N->getValueType(0), N0);
21860   }
21861 
21862   return SDValue();
21863 }
21864 
21865 /// Returns a vector_shuffle if it is able to transform an AND to a vector_shuffle
21866 /// with the destination vector and a zero vector.
21867 /// e.g. AND V, <0xffffffff, 0, 0xffffffff, 0>. ==>
21868 ///      vector_shuffle V, Zero, <0, 4, 2, 4>
21869 SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) {
21870   assert(N->getOpcode() == ISD::AND && "Unexpected opcode!");
21871 
21872   EVT VT = N->getValueType(0);
21873   SDValue LHS = N->getOperand(0);
21874   SDValue RHS = peekThroughBitcasts(N->getOperand(1));
21875   SDLoc DL(N);
21876 
21877   // Make sure we're not running after operation legalization where it
21878   // may have custom lowered the vector shuffles.
21879   if (LegalOperations)
21880     return SDValue();
21881 
21882   if (RHS.getOpcode() != ISD::BUILD_VECTOR)
21883     return SDValue();
21884 
21885   EVT RVT = RHS.getValueType();
21886   unsigned NumElts = RHS.getNumOperands();
21887 
21888   // Attempt to create a valid clear mask, splitting the mask into
21889   // sub elements and checking to see if each is
21890   // all zeros or all ones - suitable for shuffle masking.
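  // For example (an illustrative little-endian v2i32 case): for the mask
  // <0xFFFF0000, 0x0000FFFF> and Split = 2, the i16 sub-elements are
  // <0, -1, -1, 0>, giving clear-mask indices <4, 1, 2, 7> over v4i16.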
21891   auto BuildClearMask = [&](int Split) {
21892     int NumSubElts = NumElts * Split;
21893     int NumSubBits = RVT.getScalarSizeInBits() / Split;
21894 
21895     SmallVector<int, 8> Indices;
21896     for (int i = 0; i != NumSubElts; ++i) {
21897       int EltIdx = i / Split;
21898       int SubIdx = i % Split;
21899       SDValue Elt = RHS.getOperand(EltIdx);
21900       // X & undef --> 0 (not undef). So this lane must be converted to choose
21901       // from the zero constant vector (same as if the element had all 0-bits).
21902       if (Elt.isUndef()) {
21903         Indices.push_back(i + NumSubElts);
21904         continue;
21905       }
21906 
21907       APInt Bits;
21908       if (isa<ConstantSDNode>(Elt))
21909         Bits = cast<ConstantSDNode>(Elt)->getAPIntValue();
21910       else if (isa<ConstantFPSDNode>(Elt))
21911         Bits = cast<ConstantFPSDNode>(Elt)->getValueAPF().bitcastToAPInt();
21912       else
21913         return SDValue();
21914 
21915       // Extract the sub element from the constant bit mask.
21916       if (DAG.getDataLayout().isBigEndian())
21917         Bits = Bits.extractBits(NumSubBits, (Split - SubIdx - 1) * NumSubBits);
21918       else
21919         Bits = Bits.extractBits(NumSubBits, SubIdx * NumSubBits);
21920 
21921       if (Bits.isAllOnesValue())
21922         Indices.push_back(i);
21923       else if (Bits == 0)
21924         Indices.push_back(i + NumSubElts);
21925       else
21926         return SDValue();
21927     }
21928 
21929     // Let's see if the target supports this vector_shuffle.
21930     EVT ClearSVT = EVT::getIntegerVT(*DAG.getContext(), NumSubBits);
21931     EVT ClearVT = EVT::getVectorVT(*DAG.getContext(), ClearSVT, NumSubElts);
21932     if (!TLI.isVectorClearMaskLegal(Indices, ClearVT))
21933       return SDValue();
21934 
21935     SDValue Zero = DAG.getConstant(0, DL, ClearVT);
21936     return DAG.getBitcast(VT, DAG.getVectorShuffle(ClearVT, DL,
21937                                                    DAG.getBitcast(ClearVT, LHS),
21938                                                    Zero, Indices));
21939   };
21940 
21941   // Determine maximum split level (byte level masking).
21942   int MaxSplit = 1;
21943   if (RVT.getScalarSizeInBits() % 8 == 0)
21944     MaxSplit = RVT.getScalarSizeInBits() / 8;
21945 
21946   for (int Split = 1; Split <= MaxSplit; ++Split)
21947     if (RVT.getScalarSizeInBits() % Split == 0)
21948       if (SDValue S = BuildClearMask(Split))
21949         return S;
21950 
21951   return SDValue();
21952 }
21953 
21954 /// If a vector binop is performed on splat values, it may be profitable to
21955 /// extract, scalarize, and insert/splat.
21956 static SDValue scalarizeBinOpOfSplats(SDNode *N, SelectionDAG &DAG) {
21957   SDValue N0 = N->getOperand(0);
21958   SDValue N1 = N->getOperand(1);
21959   unsigned Opcode = N->getOpcode();
21960   EVT VT = N->getValueType(0);
21961   EVT EltVT = VT.getVectorElementType();
21962   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
21963 
21964   // TODO: Remove/replace the extract cost check? If the elements are available
21965   //       as scalars, then there may be no extract cost. Should we ask if
21966   //       inserting a scalar back into a vector is cheap instead?
21967   int Index0, Index1;
21968   SDValue Src0 = DAG.getSplatSourceVector(N0, Index0);
21969   SDValue Src1 = DAG.getSplatSourceVector(N1, Index1);
21970   if (!Src0 || !Src1 || Index0 != Index1 ||
21971       Src0.getValueType().getVectorElementType() != EltVT ||
21972       Src1.getValueType().getVectorElementType() != EltVT ||
21973       !TLI.isExtractVecEltCheap(VT, Index0) ||
21974       !TLI.isOperationLegalOrCustom(Opcode, EltVT))
21975     return SDValue();
21976 
21977   SDLoc DL(N);
21978   SDValue IndexC = DAG.getVectorIdxConstant(Index0, DL);
21979   SDValue X = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Src0, IndexC);
21980   SDValue Y = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Src1, IndexC);
21981   SDValue ScalarBO = DAG.getNode(Opcode, DL, EltVT, X, Y, N->getFlags());
21982 
21983   // If all lanes but 1 are undefined, no need to splat the scalar result.
21984   // TODO: Keep track of undefs and use that info in the general case.
21985   if (N0.getOpcode() == ISD::BUILD_VECTOR && N0.getOpcode() == N1.getOpcode() &&
21986       count_if(N0->ops(), [](SDValue V) { return !V.isUndef(); }) == 1 &&
21987       count_if(N1->ops(), [](SDValue V) { return !V.isUndef(); }) == 1) {
21988     // bo (build_vec ..undef, X, undef...), (build_vec ..undef, Y, undef...) -->
21989     // build_vec ..undef, (bo X, Y), undef...
21990     SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), DAG.getUNDEF(EltVT));
21991     Ops[Index0] = ScalarBO;
21992     return DAG.getBuildVector(VT, DL, Ops);
21993   }
21994 
21995   // bo (splat X, Index), (splat Y, Index) --> splat (bo X, Y), Index
21996   SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), ScalarBO);
21997   return DAG.getBuildVector(VT, DL, Ops);
21998 }
21999 
22000 /// Visit a binary vector operation, like ADD.
22001 SDValue DAGCombiner::SimplifyVBinOp(SDNode *N) {
22002   assert(N->getValueType(0).isVector() &&
22003          "SimplifyVBinOp only works on vectors!");
22004 
22005   SDValue LHS = N->getOperand(0);
22006   SDValue RHS = N->getOperand(1);
22007   SDValue Ops[] = {LHS, RHS};
22008   EVT VT = N->getValueType(0);
22009   unsigned Opcode = N->getOpcode();
22010   SDNodeFlags Flags = N->getFlags();
22011 
22012   // See if we can constant fold the vector operation.
22013   if (SDValue Fold = DAG.FoldConstantVectorArithmetic(
22014           Opcode, SDLoc(LHS), LHS.getValueType(), Ops, N->getFlags()))
22015     return Fold;
22016 
22017   // Move unary shuffles with identical masks after a vector binop:
22018   // VBinOp (shuffle A, Undef, Mask), (shuffle B, Undef, Mask))
22019   //   --> shuffle (VBinOp A, B), Undef, Mask
22020   // This does not require type legality checks because we are creating the
22021   // same types of operations that are in the original sequence. We do have to
22022   // restrict ops like integer div that have immediate UB (eg, div-by-zero)
22023   // though. This code is adapted from the identical transform in instcombine.
22024   if (Opcode != ISD::UDIV && Opcode != ISD::SDIV &&
22025       Opcode != ISD::UREM && Opcode != ISD::SREM &&
22026       Opcode != ISD::UDIVREM && Opcode != ISD::SDIVREM) {
22027     auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(LHS);
22028     auto *Shuf1 = dyn_cast<ShuffleVectorSDNode>(RHS);
22029     if (Shuf0 && Shuf1 && Shuf0->getMask().equals(Shuf1->getMask()) &&
22030         LHS.getOperand(1).isUndef() && RHS.getOperand(1).isUndef() &&
22031         (LHS.hasOneUse() || RHS.hasOneUse() || LHS == RHS)) {
22032       SDLoc DL(N);
22033       SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, LHS.getOperand(0),
22034                                      RHS.getOperand(0), Flags);
22035       SDValue UndefV = LHS.getOperand(1);
22036       return DAG.getVectorShuffle(VT, DL, NewBinOp, UndefV, Shuf0->getMask());
22037     }
22038 
22039     // Try to sink a splat shuffle after a binop with a uniform constant.
22040     // This is limited to cases where neither the shuffle nor the constant have
22041     // undefined elements because that could be poison-unsafe or inhibit
22042     // demanded elements analysis. It is further limited to not change a splat
22043     // of an inserted scalar because that may be optimized better by
22044     // load-folding or other target-specific behaviors.
22045     if (isConstOrConstSplat(RHS) && Shuf0 && is_splat(Shuf0->getMask()) &&
22046         Shuf0->hasOneUse() && Shuf0->getOperand(1).isUndef() &&
22047         Shuf0->getOperand(0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
22048       // binop (splat X), (splat C) --> splat (binop X, C)
22049       SDLoc DL(N);
22050       SDValue X = Shuf0->getOperand(0);
22051       SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, X, RHS, Flags);
22052       return DAG.getVectorShuffle(VT, DL, NewBinOp, DAG.getUNDEF(VT),
22053                                   Shuf0->getMask());
22054     }
22055     if (isConstOrConstSplat(LHS) && Shuf1 && is_splat(Shuf1->getMask()) &&
22056         Shuf1->hasOneUse() && Shuf1->getOperand(1).isUndef() &&
22057         Shuf1->getOperand(0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
22058       // binop (splat C), (splat X) --> splat (binop C, X)
22059       SDLoc DL(N);
22060       SDValue X = Shuf1->getOperand(0);
22061       SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, LHS, X, Flags);
22062       return DAG.getVectorShuffle(VT, DL, NewBinOp, DAG.getUNDEF(VT),
22063                                   Shuf1->getMask());
22064     }
22065   }
22066 
22067   // The following pattern is likely to emerge with vector reduction ops. Moving
22068   // the binary operation ahead of insertion may allow using a narrower vector
22069   // instruction that has better performance than the wide version of the op:
22070   // VBinOp (ins undef, X, Z), (ins undef, Y, Z) --> ins VecC, (VBinOp X, Y), Z
22071   if (LHS.getOpcode() == ISD::INSERT_SUBVECTOR && LHS.getOperand(0).isUndef() &&
22072       RHS.getOpcode() == ISD::INSERT_SUBVECTOR && RHS.getOperand(0).isUndef() &&
22073       LHS.getOperand(2) == RHS.getOperand(2) &&
22074       (LHS.hasOneUse() || RHS.hasOneUse())) {
22075     SDValue X = LHS.getOperand(1);
22076     SDValue Y = RHS.getOperand(1);
22077     SDValue Z = LHS.getOperand(2);
22078     EVT NarrowVT = X.getValueType();
22079     if (NarrowVT == Y.getValueType() &&
22080         TLI.isOperationLegalOrCustomOrPromote(Opcode, NarrowVT,
22081                                               LegalOperations)) {
22082       // (binop undef, undef) may not return undef, so compute that result.
22083       SDLoc DL(N);
22084       SDValue VecC =
22085           DAG.getNode(Opcode, DL, VT, DAG.getUNDEF(VT), DAG.getUNDEF(VT));
22086       SDValue NarrowBO = DAG.getNode(Opcode, DL, NarrowVT, X, Y);
22087       return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, VecC, NarrowBO, Z);
22088     }
22089   }
22090 
22091   // Make sure all but the first op are undef or constant.
22092   auto ConcatWithConstantOrUndef = [](SDValue Concat) {
22093     return Concat.getOpcode() == ISD::CONCAT_VECTORS &&
22094            all_of(drop_begin(Concat->ops()), [](const SDValue &Op) {
22095              return Op.isUndef() ||
22096                     ISD::isBuildVectorOfConstantSDNodes(Op.getNode());
22097            });
22098   };
22099 
22100   // The following pattern is likely to emerge with vector reduction ops. Moving
22101   // the binary operation ahead of the concat may allow using a narrower vector
22102   // instruction that has better performance than the wide version of the op:
22103   // VBinOp (concat X, undef/constant), (concat Y, undef/constant) -->
22104   //   concat (VBinOp X, Y), VecC
22105   if (ConcatWithConstantOrUndef(LHS) && ConcatWithConstantOrUndef(RHS) &&
22106       (LHS.hasOneUse() || RHS.hasOneUse())) {
22107     EVT NarrowVT = LHS.getOperand(0).getValueType();
22108     if (NarrowVT == RHS.getOperand(0).getValueType() &&
22109         TLI.isOperationLegalOrCustomOrPromote(Opcode, NarrowVT)) {
22110       SDLoc DL(N);
22111       unsigned NumOperands = LHS.getNumOperands();
22112       SmallVector<SDValue, 4> ConcatOps;
22113       for (unsigned i = 0; i != NumOperands; ++i) {
22114         // This constant folds for operands 1 and up.
22115         ConcatOps.push_back(DAG.getNode(Opcode, DL, NarrowVT, LHS.getOperand(i),
22116                                         RHS.getOperand(i)));
22117       }
22118 
22119       return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
22120     }
22121   }
22122 
22123   if (SDValue V = scalarizeBinOpOfSplats(N, DAG))
22124     return V;
22125 
22126   return SDValue();
22127 }
22128 
22129 SDValue DAGCombiner::SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1,
22130                                     SDValue N2) {
22131   assert(N0.getOpcode() == ISD::SETCC && "First argument must be a SetCC node!");
22132 
22133   SDValue SCC = SimplifySelectCC(DL, N0.getOperand(0), N0.getOperand(1), N1, N2,
22134                                  cast<CondCodeSDNode>(N0.getOperand(2))->get());
22135 
22136   // If we got a simplified select_cc node back from SimplifySelectCC, then
22137   // break it down into a new SETCC node, and a new SELECT node, and then return
22138   // the SELECT node, since we were called with a SELECT node.
22139   if (SCC.getNode()) {
22140     // Check to see if we got a select_cc back (to turn into setcc/select).
22141     // Otherwise, just return whatever node we got back, like fabs.
22142     if (SCC.getOpcode() == ISD::SELECT_CC) {
22143       const SDNodeFlags Flags = N0.getNode()->getFlags();
22144       SDValue SETCC = DAG.getNode(ISD::SETCC, SDLoc(N0),
22145                                   N0.getValueType(),
22146                                   SCC.getOperand(0), SCC.getOperand(1),
22147                                   SCC.getOperand(4), Flags);
22148       AddToWorklist(SETCC.getNode());
22149       SDValue SelectNode = DAG.getSelect(SDLoc(SCC), SCC.getValueType(), SETCC,
22150                                          SCC.getOperand(2), SCC.getOperand(3));
22151       SelectNode->setFlags(Flags);
22152       return SelectNode;
22153     }
22154 
22155     return SCC;
22156   }
22157   return SDValue();
22158 }
22159 
22160 /// Given a SELECT or a SELECT_CC node, where LHS and RHS are the two values
22161 /// being selected between, see if we can simplify the select.  Callers of this
22162 /// should assume that TheSelect is deleted if this returns true.  As such, they
22163 /// should return the appropriate thing (e.g. the node) back to the top-level of
22164 /// the DAG combiner loop to avoid it being looked at.
22165 bool DAGCombiner::SimplifySelectOps(SDNode *TheSelect, SDValue LHS,
22166                                     SDValue RHS) {
22167   // fold (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
22168   // The select + setcc is redundant, because fsqrt returns NaN for X < 0.
22169   if (const ConstantFPSDNode *NaN = isConstOrConstSplatFP(LHS)) {
22170     if (NaN->isNaN() && RHS.getOpcode() == ISD::FSQRT) {
22171       // We have: (select (setcc ?, ?, ?), NaN, (fsqrt ?))
22172       SDValue Sqrt = RHS;
22173       ISD::CondCode CC;
22174       SDValue CmpLHS;
22175       const ConstantFPSDNode *Zero = nullptr;
22176 
22177       if (TheSelect->getOpcode() == ISD::SELECT_CC) {
22178         CC = cast<CondCodeSDNode>(TheSelect->getOperand(4))->get();
22179         CmpLHS = TheSelect->getOperand(0);
22180         Zero = isConstOrConstSplatFP(TheSelect->getOperand(1));
22181       } else {
22182         // SELECT or VSELECT
22183         SDValue Cmp = TheSelect->getOperand(0);
22184         if (Cmp.getOpcode() == ISD::SETCC) {
22185           CC = cast<CondCodeSDNode>(Cmp.getOperand(2))->get();
22186           CmpLHS = Cmp.getOperand(0);
22187           Zero = isConstOrConstSplatFP(Cmp.getOperand(1));
22188         }
22189       }
22190       if (Zero && Zero->isZero() &&
22191           Sqrt.getOperand(0) == CmpLHS && (CC == ISD::SETOLT ||
22192           CC == ISD::SETULT || CC == ISD::SETLT)) {
22193         // We have: (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
22194         CombineTo(TheSelect, Sqrt);
22195         return true;
22196       }
22197     }
22198   }
22199   // Cannot simplify select with vector condition
22200   if (TheSelect->getOperand(0).getValueType().isVector()) return false;
22201 
22202   // If this is a select from two identical things, try to pull the operation
22203   // through the select.
22204   if (LHS.getOpcode() != RHS.getOpcode() ||
22205       !LHS.hasOneUse() || !RHS.hasOneUse())
22206     return false;
22207 
22208   // If this is a load and the token chain is identical, replace the select
22209   // of two loads with a load through a select of the address to load from.
22210   // This triggers in things like "select bool X, 10.0, 123.0" after the FP
22211   // constants have been dropped into the constant pool.
22212   if (LHS.getOpcode() == ISD::LOAD) {
22213     LoadSDNode *LLD = cast<LoadSDNode>(LHS);
22214     LoadSDNode *RLD = cast<LoadSDNode>(RHS);
22215 
22216     // Token chains must be identical.
22217     if (LHS.getOperand(0) != RHS.getOperand(0) ||
22218         // Do not let this transformation reduce the number of volatile loads.
22219         // Be conservative for atomics for the moment
22220         // TODO: This does appear to be legal for unordered atomics (see D66309)
22221         !LLD->isSimple() || !RLD->isSimple() ||
22222         // FIXME: If either is a pre/post inc/dec load,
22223         // we'd need to split out the address adjustment.
22224         LLD->isIndexed() || RLD->isIndexed() ||
22225         // If this is an EXTLOAD, the VT's must match.
22226         LLD->getMemoryVT() != RLD->getMemoryVT() ||
22227         // If this is an EXTLOAD, the kind of extension must match.
22228         (LLD->getExtensionType() != RLD->getExtensionType() &&
22229          // The only exception is if one of the extensions is anyext.
22230          LLD->getExtensionType() != ISD::EXTLOAD &&
22231          RLD->getExtensionType() != ISD::EXTLOAD) ||
22232         // FIXME: this discards src value information.  This is
22233         // over-conservative. It would be beneficial to be able to remember
22234         // both potential memory locations.  Since we are discarding
22235         // src value info, don't do the transformation if the memory
22236         // locations are not in the default address space.
22237         LLD->getPointerInfo().getAddrSpace() != 0 ||
22238         RLD->getPointerInfo().getAddrSpace() != 0 ||
22239         // We can't produce a CMOV of a TargetFrameIndex since we won't
22240         // generate the address generation required.
22241         LLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
22242         RLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
22243         !TLI.isOperationLegalOrCustom(TheSelect->getOpcode(),
22244                                       LLD->getBasePtr().getValueType()))
22245       return false;
22246 
22247     // The loads must not depend on one another.
22248     if (LLD->isPredecessorOf(RLD) || RLD->isPredecessorOf(LLD))
22249       return false;
22250 
22251     // Check that the select condition doesn't reach either load.  If so,
22252     // folding this will induce a cycle into the DAG.  If not, this is safe to
22253     // xform, so create a select of the addresses.
22254 
22255     SmallPtrSet<const SDNode *, 32> Visited;
22256     SmallVector<const SDNode *, 16> Worklist;
22257 
22258     // Always fail if LLD and RLD are not independent. TheSelect is a
22259     // predecessor to all Nodes in question so we need not search past it.
22260 
22261     Visited.insert(TheSelect);
22262     Worklist.push_back(LLD);
22263     Worklist.push_back(RLD);
22264 
22265     if (SDNode::hasPredecessorHelper(LLD, Visited, Worklist) ||
22266         SDNode::hasPredecessorHelper(RLD, Visited, Worklist))
22267       return false;
22268 
22269     SDValue Addr;
22270     if (TheSelect->getOpcode() == ISD::SELECT) {
22271       // We cannot do this optimization if any pair of {RLD, LLD} is a
22272       // predecessor to {RLD, LLD, CondNode}. As we've already compared the
22273       // Loads, we only need to check if CondNode is a successor to one of the
22274       // loads. We can further avoid this if there's no use of their chain
22275       // value.
22276       SDNode *CondNode = TheSelect->getOperand(0).getNode();
22277       Worklist.push_back(CondNode);
22278 
22279       if ((LLD->hasAnyUseOfValue(1) &&
22280            SDNode::hasPredecessorHelper(LLD, Visited, Worklist)) ||
22281           (RLD->hasAnyUseOfValue(1) &&
22282            SDNode::hasPredecessorHelper(RLD, Visited, Worklist)))
22283         return false;
22284 
22285       Addr = DAG.getSelect(SDLoc(TheSelect),
22286                            LLD->getBasePtr().getValueType(),
22287                            TheSelect->getOperand(0), LLD->getBasePtr(),
22288                            RLD->getBasePtr());
22289     } else {  // Otherwise SELECT_CC
22290       // We cannot do this optimization if any pair of {RLD, LLD} is a
22291       // predecessor to {RLD, LLD, CondLHS, CondRHS}. As we've already compared
22292       // the Loads, we only need to check if CondLHS/CondRHS is a successor to
22293       // one of the loads. We can further avoid this if there's no use of their
22294       // chain value.
22295 
22296       SDNode *CondLHS = TheSelect->getOperand(0).getNode();
22297       SDNode *CondRHS = TheSelect->getOperand(1).getNode();
22298       Worklist.push_back(CondLHS);
22299       Worklist.push_back(CondRHS);
22300 
22301       if ((LLD->hasAnyUseOfValue(1) &&
22302            SDNode::hasPredecessorHelper(LLD, Visited, Worklist)) ||
22303           (RLD->hasAnyUseOfValue(1) &&
22304            SDNode::hasPredecessorHelper(RLD, Visited, Worklist)))
22305         return false;
22306 
22307       Addr = DAG.getNode(ISD::SELECT_CC, SDLoc(TheSelect),
22308                          LLD->getBasePtr().getValueType(),
22309                          TheSelect->getOperand(0),
22310                          TheSelect->getOperand(1),
22311                          LLD->getBasePtr(), RLD->getBasePtr(),
22312                          TheSelect->getOperand(4));
22313     }
22314 
22315     SDValue Load;
22316     // It is safe to replace the two loads if they have different alignments,
22317     // but the new load must be the minimum (most restrictive) alignment of the
22318     // inputs.
22319     Align Alignment = std::min(LLD->getAlign(), RLD->getAlign());
22320     MachineMemOperand::Flags MMOFlags = LLD->getMemOperand()->getFlags();
22321     if (!RLD->isInvariant())
22322       MMOFlags &= ~MachineMemOperand::MOInvariant;
22323     if (!RLD->isDereferenceable())
22324       MMOFlags &= ~MachineMemOperand::MODereferenceable;
22325     if (LLD->getExtensionType() == ISD::NON_EXTLOAD) {
22326       // FIXME: Discards pointer and AA info.
22327       Load = DAG.getLoad(TheSelect->getValueType(0), SDLoc(TheSelect),
22328                          LLD->getChain(), Addr, MachinePointerInfo(), Alignment,
22329                          MMOFlags);
22330     } else {
22331       // FIXME: Discards pointer and AA info.
22332       Load = DAG.getExtLoad(
22333           LLD->getExtensionType() == ISD::EXTLOAD ? RLD->getExtensionType()
22334                                                   : LLD->getExtensionType(),
22335           SDLoc(TheSelect), TheSelect->getValueType(0), LLD->getChain(), Addr,
22336           MachinePointerInfo(), LLD->getMemoryVT(), Alignment, MMOFlags);
22337     }
22338 
22339     // Users of the select now use the result of the load.
22340     CombineTo(TheSelect, Load);
22341 
22342     // Users of the old loads now use the new load's chain.  We know the
22343     // old-load value is dead now.
22344     CombineTo(LHS.getNode(), Load.getValue(0), Load.getValue(1));
22345     CombineTo(RHS.getNode(), Load.getValue(0), Load.getValue(1));
22346     return true;
22347   }
22348 
22349   return false;
22350 }
22351 
22352 /// Try to fold an expression of the form (N0 cond N1) ? N2 : N3 to a shift and
22353 /// bitwise 'and'.
22354 SDValue DAGCombiner::foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0,
22355                                             SDValue N1, SDValue N2, SDValue N3,
22356                                             ISD::CondCode CC) {
22357   // If this is a select where the false operand is zero and the compare is a
22358   // check of the sign bit, see if we can perform the "gzip trick":
22359   // select_cc setlt X, 0, A, 0 -> and (sra X, size(X)-1), A
22360   // select_cc setgt X, 0, A, 0 -> and (not (sra X, size(X)-1)), A
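  // For example (illustrative, i32): select_cc setlt X, 0, A, 0 becomes
  // and (sra X, 31), A, since (sra X, 31) is all-ones exactly when X < 0.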
22361   EVT XType = N0.getValueType();
22362   EVT AType = N2.getValueType();
22363   if (!isNullConstant(N3) || !XType.bitsGE(AType))
22364     return SDValue();
22365 
22366   // If the comparison is testing for a positive value, we have to invert
22367   // the sign bit mask, so only do that transform if the target has a bitwise
22368   // 'and not' instruction (the invert is free).
22369   if (CC == ISD::SETGT && TLI.hasAndNot(N2)) {
22370     // (X > -1) ? A : 0
22371     // (X >  0) ? X : 0 <-- This is canonical signed max.
22372     if (!(isAllOnesConstant(N1) || (isNullConstant(N1) && N0 == N2)))
22373       return SDValue();
22374   } else if (CC == ISD::SETLT) {
22375     // (X <  0) ? A : 0
22376     // (X <  1) ? X : 0 <-- This is un-canonicalized signed min.
22377     if (!(isNullConstant(N1) || (isOneConstant(N1) && N0 == N2)))
22378       return SDValue();
22379   } else {
22380     return SDValue();
22381   }
22382 
22383   // and (sra X, size(X)-1), A -> "and (srl X, C2), A" iff A is a single-bit
22384   // constant.
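  // e.g. (illustrative, i32 with A = 8): the sign bit only needs to reach
  // bit 3, so this becomes and (srl X, 28), 8.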
22385   EVT ShiftAmtTy = getShiftAmountTy(N0.getValueType());
22386   auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
22387   if (N2C && ((N2C->getAPIntValue() & (N2C->getAPIntValue() - 1)) == 0)) {
22388     unsigned ShCt = XType.getSizeInBits() - N2C->getAPIntValue().logBase2() - 1;
22389     if (!TLI.shouldAvoidTransformToShift(XType, ShCt)) {
22390       SDValue ShiftAmt = DAG.getConstant(ShCt, DL, ShiftAmtTy);
22391       SDValue Shift = DAG.getNode(ISD::SRL, DL, XType, N0, ShiftAmt);
22392       AddToWorklist(Shift.getNode());
22393 
22394       if (XType.bitsGT(AType)) {
22395         Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift);
22396         AddToWorklist(Shift.getNode());
22397       }
22398 
22399       if (CC == ISD::SETGT)
22400         Shift = DAG.getNOT(DL, Shift, AType);
22401 
22402       return DAG.getNode(ISD::AND, DL, AType, Shift, N2);
22403     }
22404   }
22405 
22406   unsigned ShCt = XType.getSizeInBits() - 1;
22407   if (TLI.shouldAvoidTransformToShift(XType, ShCt))
22408     return SDValue();
22409 
22410   SDValue ShiftAmt = DAG.getConstant(ShCt, DL, ShiftAmtTy);
22411   SDValue Shift = DAG.getNode(ISD::SRA, DL, XType, N0, ShiftAmt);
22412   AddToWorklist(Shift.getNode());
22413 
22414   if (XType.bitsGT(AType)) {
22415     Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift);
22416     AddToWorklist(Shift.getNode());
22417   }
22418 
22419   if (CC == ISD::SETGT)
22420     Shift = DAG.getNOT(DL, Shift, AType);
22421 
22422   return DAG.getNode(ISD::AND, DL, AType, Shift, N2);
22423 }
22424 
22425 // Fold select(cc, binop(), binop()) -> binop(select(), select()) etc.
22426 SDValue DAGCombiner::foldSelectOfBinops(SDNode *N) {
22427   SDValue N0 = N->getOperand(0);
22428   SDValue N1 = N->getOperand(1);
22429   SDValue N2 = N->getOperand(2);
22430   EVT VT = N->getValueType(0);
22431   SDLoc DL(N);
22432 
22433   unsigned BinOpc = N1.getOpcode();
22434   if (!TLI.isBinOp(BinOpc) || (N2.getOpcode() != BinOpc))
22435     return SDValue();
22436 
22437   if (!N->isOnlyUserOf(N0.getNode()) || !N->isOnlyUserOf(N1.getNode()))
22438     return SDValue();
22439 
22440   // Fold select(cond, binop(x, y), binop(z, y))
22441   //  --> binop(select(cond, x, z), y)
22442   if (N1.getOperand(1) == N2.getOperand(1)) {
22443     SDValue NewSel =
22444         DAG.getSelect(DL, VT, N0, N1.getOperand(0), N2.getOperand(0));
22445     SDValue NewBinOp = DAG.getNode(BinOpc, DL, VT, NewSel, N1.getOperand(1));
22446     NewBinOp->setFlags(N1->getFlags());
22447     NewBinOp->intersectFlagsWith(N2->getFlags());
22448     return NewBinOp;
22449   }
22450 
22451   // Fold select(cond, binop(x, y), binop(x, z))
22452   //  --> binop(x, select(cond, y, z))
22453   // Second op VT might be different (e.g. shift amount type)
22454   if (N1.getOperand(0) == N2.getOperand(0) &&
22455       VT == N1.getOperand(1).getValueType() &&
22456       VT == N2.getOperand(1).getValueType()) {
22457     SDValue NewSel =
22458         DAG.getSelect(DL, VT, N0, N1.getOperand(1), N2.getOperand(1));
22459     SDValue NewBinOp = DAG.getNode(BinOpc, DL, VT, N1.getOperand(0), NewSel);
22460     NewBinOp->setFlags(N1->getFlags());
22461     NewBinOp->intersectFlagsWith(N2->getFlags());
22462     return NewBinOp;
22463   }
22464 
22465   // TODO: Handle isCommutativeBinOp patterns as well?
22466   return SDValue();
22467 }
22468 
22469 // Transform (fneg/fabs (bitconvert x)) to avoid loading constant pool values.
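// For example, with f32 bitcast from i32 (a sketch; profitability is decided
// by the isFNegFree/isFAbsFree hooks below):
//   (fneg (bitcast i32 X to f32)) -> (bitcast (xor X, 0x80000000) to f32)
//   (fabs (bitcast i32 X to f32)) -> (bitcast (and X, 0x7fffffff) to f32)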
22470 SDValue DAGCombiner::foldSignChangeInBitcast(SDNode *N) {
22471   SDValue N0 = N->getOperand(0);
22472   EVT VT = N->getValueType(0);
22473   bool IsFabs = N->getOpcode() == ISD::FABS;
22474   bool IsFree = IsFabs ? TLI.isFAbsFree(VT) : TLI.isFNegFree(VT);
22475 
22476   if (IsFree || N0.getOpcode() != ISD::BITCAST || !N0.hasOneUse())
22477     return SDValue();
22478 
22479   SDValue Int = N0.getOperand(0);
22480   EVT IntVT = Int.getValueType();
22481 
22482   // The operand to cast should be integer.
22483   if (!IntVT.isInteger() || IntVT.isVector())
22484     return SDValue();
22485 
22486   // (fneg (bitconvert x)) -> (bitconvert (xor x sign))
22487   // (fabs (bitconvert x)) -> (bitconvert (and x ~sign))
22488   APInt SignMask;
22489   if (N0.getValueType().isVector()) {
22490     // For vector, create a sign mask (0x80...) or its inverse (for fabs,
22491     // 0x7f...) per element and splat it.
22492     SignMask = APInt::getSignMask(N0.getScalarValueSizeInBits());
22493     if (IsFabs)
22494       SignMask = ~SignMask;
22495     SignMask = APInt::getSplat(IntVT.getSizeInBits(), SignMask);
22496   } else {
22497     // For scalar, just use the sign mask (0x80... or the inverse, 0x7f...)
22498     SignMask = APInt::getSignMask(IntVT.getSizeInBits());
22499     if (IsFabs)
22500       SignMask = ~SignMask;
22501   }
22502   SDLoc DL(N0);
22503   Int = DAG.getNode(IsFabs ? ISD::AND : ISD::XOR, DL, IntVT, Int,
22504                     DAG.getConstant(SignMask, DL, IntVT));
22505   AddToWorklist(Int.getNode());
22506   return DAG.getBitcast(VT, Int);
22507 }
22508 
22509 /// Turn "(a cond b) ? 1.0f : 2.0f" into "load (tmp + ((a cond b) ? 0 : 4)"
22510 /// where "tmp" is a constant pool entry containing an array with 1.0 and 2.0
22511 /// in it. This may be a win when the constant is not otherwise available
22512 /// because it replaces two constant pool loads with one.
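/// A sketch of the DAG this produces (value names are illustrative only):
///   Cond      = setcc a, b, cc
///   CstOffset = select Cond, sizeof(elt), 0
///   Addr      = add CPIdx, CstOffset
///   Result    = load Addr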
22513 SDValue DAGCombiner::convertSelectOfFPConstantsToLoadOffset(
22514     const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
22515     ISD::CondCode CC) {
22516   if (!TLI.reduceSelectOfFPConstantLoads(N0.getValueType()))
22517     return SDValue();
22518 
22519   // If we are before legalize types, we want the other legalization to happen
22520   // first (for example, to avoid messing with soft float).
22521   auto *TV = dyn_cast<ConstantFPSDNode>(N2);
22522   auto *FV = dyn_cast<ConstantFPSDNode>(N3);
22523   EVT VT = N2.getValueType();
22524   if (!TV || !FV || !TLI.isTypeLegal(VT))
22525     return SDValue();
22526 
22527   // If a constant can be materialized without loads, this does not make sense.
22528   if (TLI.getOperationAction(ISD::ConstantFP, VT) == TargetLowering::Legal ||
22529       TLI.isFPImmLegal(TV->getValueAPF(), TV->getValueType(0), ForCodeSize) ||
22530       TLI.isFPImmLegal(FV->getValueAPF(), FV->getValueType(0), ForCodeSize))
22531     return SDValue();
22532 
22533   // If both constants have multiple uses, then we won't need to do an extra
22534   // load. The values are likely around in registers for other users.
22535   if (!TV->hasOneUse() && !FV->hasOneUse())
22536     return SDValue();
22537 
22538   Constant *Elts[] = { const_cast<ConstantFP*>(FV->getConstantFPValue()),
22539                        const_cast<ConstantFP*>(TV->getConstantFPValue()) };
22540   Type *FPTy = Elts[0]->getType();
22541   const DataLayout &TD = DAG.getDataLayout();
22542 
22543   // Create a ConstantArray of the two constants.
22544   Constant *CA = ConstantArray::get(ArrayType::get(FPTy, 2), Elts);
22545   SDValue CPIdx = DAG.getConstantPool(CA, TLI.getPointerTy(DAG.getDataLayout()),
22546                                       TD.getPrefTypeAlign(FPTy));
22547   Align Alignment = cast<ConstantPoolSDNode>(CPIdx)->getAlign();
22548 
22549   // Get offsets to the 0 and 1 elements of the array, so we can select between
22550   // them.
22551   SDValue Zero = DAG.getIntPtrConstant(0, DL);
22552   unsigned EltSize = (unsigned)TD.getTypeAllocSize(Elts[0]->getType());
22553   SDValue One = DAG.getIntPtrConstant(EltSize, SDLoc(FV));
22554   SDValue Cond =
22555       DAG.getSetCC(DL, getSetCCResultType(N0.getValueType()), N0, N1, CC);
22556   AddToWorklist(Cond.getNode());
22557   SDValue CstOffset = DAG.getSelect(DL, Zero.getValueType(), Cond, One, Zero);
22558   AddToWorklist(CstOffset.getNode());
22559   CPIdx = DAG.getNode(ISD::ADD, DL, CPIdx.getValueType(), CPIdx, CstOffset);
22560   AddToWorklist(CPIdx.getNode());
22561   return DAG.getLoad(TV->getValueType(0), DL, DAG.getEntryNode(), CPIdx,
22562                      MachinePointerInfo::getConstantPool(
22563                          DAG.getMachineFunction()), Alignment);
22564 }
22565 
22566 /// Simplify an expression of the form (N0 cond N1) ? N2 : N3
22567 /// where 'cond' is the comparison specified by CC.
22568 SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
22569                                       SDValue N2, SDValue N3, ISD::CondCode CC,
22570                                       bool NotExtCompare) {
22571   // (x ? y : y) -> y.
22572   if (N2 == N3) return N2;
22573 
22574   EVT CmpOpVT = N0.getValueType();
22575   EVT CmpResVT = getSetCCResultType(CmpOpVT);
22576   EVT VT = N2.getValueType();
22577   auto *N1C = dyn_cast<ConstantSDNode>(N1.getNode());
22578   auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
22579   auto *N3C = dyn_cast<ConstantSDNode>(N3.getNode());
22580 
22581   // Determine if the condition we're dealing with is constant.
22582   if (SDValue SCC = DAG.FoldSetCC(CmpResVT, N0, N1, CC, DL)) {
22583     AddToWorklist(SCC.getNode());
22584     if (auto *SCCC = dyn_cast<ConstantSDNode>(SCC)) {
22585       // fold select_cc true, x, y -> x
22586       // fold select_cc false, x, y -> y
22587       return !(SCCC->isNullValue()) ? N2 : N3;
22588     }
22589   }
22590 
22591   if (SDValue V =
22592           convertSelectOfFPConstantsToLoadOffset(DL, N0, N1, N2, N3, CC))
22593     return V;
22594 
22595   if (SDValue V = foldSelectCCToShiftAnd(DL, N0, N1, N2, N3, CC))
22596     return V;
22597 
22598   // fold (select_cc seteq (and x, y), 0, 0, A) -> (and (shr (shl x)) A)
22599   // where y has a single bit set.
22600   // In plain terms, we can turn the SELECT_CC into an AND when the
22601   // condition can be materialized as an all-ones register.  Any
22602   // single bit-test can be materialized as an all-ones register with
22603   // shift-left and shift-right-arith.
22604   if (CC == ISD::SETEQ && N0->getOpcode() == ISD::AND &&
22605       N0->getValueType(0) == VT && isNullConstant(N1) && isNullConstant(N2)) {
22606     SDValue AndLHS = N0->getOperand(0);
22607     auto *ConstAndRHS = dyn_cast<ConstantSDNode>(N0->getOperand(1));
22608     if (ConstAndRHS && ConstAndRHS->getAPIntValue().countPopulation() == 1) {
22609       // Shift the tested bit over the sign bit.
22610       const APInt &AndMask = ConstAndRHS->getAPIntValue();
22611       unsigned ShCt = AndMask.getBitWidth() - 1;
22612       if (!TLI.shouldAvoidTransformToShift(VT, ShCt)) {
22613         SDValue ShlAmt =
22614           DAG.getConstant(AndMask.countLeadingZeros(), SDLoc(AndLHS),
22615                           getShiftAmountTy(AndLHS.getValueType()));
22616         SDValue Shl = DAG.getNode(ISD::SHL, SDLoc(N0), VT, AndLHS, ShlAmt);
22617 
22618         // Now arithmetic right shift it all the way over, so the result is
22619         // either all-ones, or zero.
22620         SDValue ShrAmt =
22621           DAG.getConstant(ShCt, SDLoc(Shl),
22622                           getShiftAmountTy(Shl.getValueType()));
22623         SDValue Shr = DAG.getNode(ISD::SRA, SDLoc(N0), VT, Shl, ShrAmt);
22624 
22625         return DAG.getNode(ISD::AND, DL, VT, Shr, N3);
22626       }
22627     }
22628   }
22629 
22630   // fold select C, 16, 0 -> shl C, 4
22631   bool Fold = N2C && isNullConstant(N3) && N2C->getAPIntValue().isPowerOf2();
22632   bool Swap = N3C && isNullConstant(N2) && N3C->getAPIntValue().isPowerOf2();
22633 
22634   if ((Fold || Swap) &&
22635       TLI.getBooleanContents(CmpOpVT) ==
22636           TargetLowering::ZeroOrOneBooleanContent &&
22637       (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, CmpOpVT))) {
22638 
22639     if (Swap) {
22640       CC = ISD::getSetCCInverse(CC, CmpOpVT);
22641       std::swap(N2C, N3C);
22642     }
22643 
22644     // If the caller doesn't want us to simplify this into a zext of a compare,
22645     // don't do it.
22646     if (NotExtCompare && N2C->isOne())
22647       return SDValue();
22648 
22649     SDValue Temp, SCC;
22650     // zext (setcc n0, n1)
22651     if (LegalTypes) {
22652       SCC = DAG.getSetCC(DL, CmpResVT, N0, N1, CC);
22653       if (VT.bitsLT(SCC.getValueType()))
22654         Temp = DAG.getZeroExtendInReg(SCC, SDLoc(N2), VT);
22655       else
22656         Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2), VT, SCC);
22657     } else {
22658       SCC = DAG.getSetCC(SDLoc(N0), MVT::i1, N0, N1, CC);
22659       Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2), VT, SCC);
22660     }
22661 
22662     AddToWorklist(SCC.getNode());
22663     AddToWorklist(Temp.getNode());
22664 
22665     if (N2C->isOne())
22666       return Temp;
22667 
22668     unsigned ShCt = N2C->getAPIntValue().logBase2();
22669     if (TLI.shouldAvoidTransformToShift(VT, ShCt))
22670       return SDValue();
22671 
22672     // shl setcc result by log2 n2c
22673     return DAG.getNode(ISD::SHL, DL, N2.getValueType(), Temp,
22674                        DAG.getConstant(ShCt, SDLoc(Temp),
22675                                        getShiftAmountTy(Temp.getValueType())));
22676   }
22677 
22678   // select_cc seteq X, 0, sizeof(X), ctlz(X) -> ctlz(X)
22679   // select_cc seteq X, 0, sizeof(X), ctlz_zero_undef(X) -> ctlz(X)
22680   // select_cc seteq X, 0, sizeof(X), cttz(X) -> cttz(X)
22681   // select_cc seteq X, 0, sizeof(X), cttz_zero_undef(X) -> cttz(X)
22682   // select_cc setne X, 0, ctlz(X), sizeof(X) -> ctlz(X)
22683   // select_cc setne X, 0, ctlz_zero_undef(X), sizeof(X) -> ctlz(X)
22684   // select_cc setne X, 0, cttz(X), sizeof(X) -> cttz(X)
22685   // select_cc setne X, 0, cttz_zero_undef(X), sizeof(X) -> cttz(X)
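  // These hold because plain ISD::CTLZ/CTTZ are defined to return the bit
  // width for a zero input, which is exactly the value the select supplies
  // on the X == 0 path; e.g. for i32, select_cc seteq X, 0, 32, (ctlz X)
  // computes ctlz(X) for every X.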
22686   if (N1C && N1C->isNullValue() && (CC == ISD::SETEQ || CC == ISD::SETNE)) {
22687     SDValue ValueOnZero = N2;
22688     SDValue Count = N3;
22689     // If the condition is NE instead of E, swap the operands.
22690     if (CC == ISD::SETNE)
22691       std::swap(ValueOnZero, Count);
22692     // Check if the value on zero is a constant equal to the bits in the type.
22693     if (auto *ValueOnZeroC = dyn_cast<ConstantSDNode>(ValueOnZero)) {
22694       if (ValueOnZeroC->getAPIntValue() == VT.getSizeInBits()) {
22695         // If the other operand is cttz/cttz_zero_undef of N0, and cttz is
22696         // legal, combine to just cttz.
22697         if ((Count.getOpcode() == ISD::CTTZ ||
22698              Count.getOpcode() == ISD::CTTZ_ZERO_UNDEF) &&
22699             N0 == Count.getOperand(0) &&
22700             (!LegalOperations || TLI.isOperationLegal(ISD::CTTZ, VT)))
22701           return DAG.getNode(ISD::CTTZ, DL, VT, N0);
22702         // If the other operand is ctlz/ctlz_zero_undef of N0, and ctlz is
22703         // legal, combine to just ctlz.
22704         if ((Count.getOpcode() == ISD::CTLZ ||
22705              Count.getOpcode() == ISD::CTLZ_ZERO_UNDEF) &&
22706             N0 == Count.getOperand(0) &&
22707             (!LegalOperations || TLI.isOperationLegal(ISD::CTLZ, VT)))
22708           return DAG.getNode(ISD::CTLZ, DL, VT, N0);
22709       }
22710     }
22711   }
22712 
22713   return SDValue();
22714 }
22715 
22716 /// This is a stub for TargetLowering::SimplifySetCC.
22717 SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
22718                                    ISD::CondCode Cond, const SDLoc &DL,
22719                                    bool foldBooleans) {
22720   TargetLowering::DAGCombinerInfo
22721     DagCombineInfo(DAG, Level, false, this);
22722   return TLI.SimplifySetCC(VT, N0, N1, Cond, foldBooleans, DagCombineInfo, DL);
22723 }
22724 
22725 /// Given an ISD::SDIV node expressing a divide by constant, return
22726 /// a DAG expression to select that will generate the same value by multiplying
22727 /// by a magic number.
22728 /// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
22729 SDValue DAGCombiner::BuildSDIV(SDNode *N) {
22730   // When optimizing for minimum size, we don't want to expand a div to a mul
22731   // and a shift.
22732   if (DAG.getMachineFunction().getFunction().hasMinSize())
22733     return SDValue();
22734 
22735   SmallVector<SDNode *, 8> Built;
22736   if (SDValue S = TLI.BuildSDIV(N, DAG, LegalOperations, Built)) {
22737     for (SDNode *N : Built)
22738       AddToWorklist(N);
22739     return S;
22740   }
22741 
22742   return SDValue();
22743 }
22744 
22745 /// Given an ISD::SDIV node expressing a divide by constant power of 2, return a
22746 /// DAG expression that will generate the same value by right shifting.
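/// For example (illustrative), an i32 sdiv by 8 can be emitted as
///   sra (add X, (srl (sra X, 31), 29)), 3
/// i.e. bias negative values by 7 before the arithmetic shift.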
22747 SDValue DAGCombiner::BuildSDIVPow2(SDNode *N) {
22748   ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1));
22749   if (!C)
22750     return SDValue();
22751 
22752   // Avoid division by zero.
22753   if (C->isNullValue())
22754     return SDValue();
22755 
22756   SmallVector<SDNode *, 8> Built;
22757   if (SDValue S = TLI.BuildSDIVPow2(N, C->getAPIntValue(), DAG, Built)) {
22758     for (SDNode *N : Built)
22759       AddToWorklist(N);
22760     return S;
22761   }
22762 
22763   return SDValue();
22764 }
22765 
22766 /// Given an ISD::UDIV node expressing a divide by constant, return a DAG
22767 /// expression that will generate the same value by multiplying by a magic
22768 /// number.
22769 /// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
22770 SDValue DAGCombiner::BuildUDIV(SDNode *N) {
22771   // When optimizing for minimum size, we don't want to expand a div to a mul
22772   // and a shift.
22773   if (DAG.getMachineFunction().getFunction().hasMinSize())
22774     return SDValue();
22775 
22776   SmallVector<SDNode *, 8> Built;
22777   if (SDValue S = TLI.BuildUDIV(N, DAG, LegalOperations, Built)) {
22778     for (SDNode *N : Built)
22779       AddToWorklist(N);
22780     return S;
22781   }
22782 
22783   return SDValue();
22784 }
22785 
22786 /// Determines the LogBase2 value for a non-null input value using the
22787 /// transform: LogBase2(V) = (EltBits - 1) - ctlz(V).
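/// For example, for V = 16 in i32: ctlz(16) = 27, so
/// LogBase2 = (32 - 1) - 27 = 4.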
22788 SDValue DAGCombiner::BuildLogBase2(SDValue V, const SDLoc &DL) {
22789   EVT VT = V.getValueType();
22790   SDValue Ctlz = DAG.getNode(ISD::CTLZ, DL, VT, V);
22791   SDValue Base = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
22792   SDValue LogBase2 = DAG.getNode(ISD::SUB, DL, VT, Base, Ctlz);
22793   return LogBase2;
22794 }
22795 
22796 /// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
22797 /// For the reciprocal, we need to find the zero of the function:
22798 ///   F(X) = A X - 1 [which has a zero at X = 1/A]
22799 ///     =>
22800 ///   X_{i+1} = X_i (2 - A X_i) = X_i + X_i (1 - A X_i) [this second form
22801 ///     does not require additional intermediate precision]
22802 /// For the last iteration, put numerator N into it to gain more precision:
22803 ///   Result = N X_i + X_i (N - N A X_i)
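/// With a single refinement step this emits, schematically:
///   Est = an initial estimate of 1/A from the target
///   T   = fmul N, Est
///   Err = fsub N, (fmul A, T)
///   Res = fadd T, (fmul Est, Err)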
22804 SDValue DAGCombiner::BuildDivEstimate(SDValue N, SDValue Op,
22805                                       SDNodeFlags Flags) {
22806   if (LegalDAG)
22807     return SDValue();
22808 
22809   // TODO: Handle half and/or extended types?
22810   EVT VT = Op.getValueType();
22811   if (VT.getScalarType() != MVT::f32 && VT.getScalarType() != MVT::f64)
22812     return SDValue();
22813 
22814   // If estimates are explicitly disabled for this function, we're done.
22815   MachineFunction &MF = DAG.getMachineFunction();
22816   int Enabled = TLI.getRecipEstimateDivEnabled(VT, MF);
22817   if (Enabled == TLI.ReciprocalEstimate::Disabled)
22818     return SDValue();
22819 
22820   // Estimates may be explicitly enabled for this type with a custom number of
22821   // refinement steps.
22822   int Iterations = TLI.getDivRefinementSteps(VT, MF);
22823   if (SDValue Est = TLI.getRecipEstimate(Op, DAG, Enabled, Iterations)) {
22824     AddToWorklist(Est.getNode());
22825 
22826     SDLoc DL(Op);
22827     if (Iterations) {
22828       SDValue FPOne = DAG.getConstantFP(1.0, DL, VT);
22829 
22830       // Newton iterations: Est = Est + Est (N - Arg * Est)
22831       // If this is the last iteration, also multiply by the numerator.
22832       for (int i = 0; i < Iterations; ++i) {
22833         SDValue MulEst = Est;
22834 
22835         if (i == Iterations - 1) {
22836           MulEst = DAG.getNode(ISD::FMUL, DL, VT, N, Est, Flags);
22837           AddToWorklist(MulEst.getNode());
22838         }
22839 
22840         SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Op, MulEst, Flags);
22841         AddToWorklist(NewEst.getNode());
22842 
22843         NewEst = DAG.getNode(ISD::FSUB, DL, VT,
22844                              (i == Iterations - 1 ? N : FPOne), NewEst, Flags);
22845         AddToWorklist(NewEst.getNode());
22846 
22847         NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags);
22848         AddToWorklist(NewEst.getNode());
22849 
22850         Est = DAG.getNode(ISD::FADD, DL, VT, MulEst, NewEst, Flags);
22851         AddToWorklist(Est.getNode());
22852       }
22853     } else {
22854       // If no iterations are available, multiply with N.
22855       Est = DAG.getNode(ISD::FMUL, DL, VT, Est, N, Flags);
22856       AddToWorklist(Est.getNode());
22857     }
22858 
22859     return Est;
22860   }
22861 
22862   return SDValue();
22863 }
22864 
22865 /// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
22866 /// For the reciprocal sqrt, we need to find the zero of the function:
22867 ///   F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
22868 ///     =>
22869 ///   X_{i+1} = X_i (1.5 - A X_i^2 / 2)
22870 /// As a result, we precompute A/2 prior to the iteration loop.
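/// One iteration therefore emits, schematically:
///   HalfArg = fsub (fmul 1.5, Arg), Arg   // Arg/2, reusing the constant
///   Est     = fmul Est, (fsub 1.5, (fmul HalfArg, (fmul Est, Est)))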
22871 SDValue DAGCombiner::buildSqrtNROneConst(SDValue Arg, SDValue Est,
22872                                          unsigned Iterations,
22873                                          SDNodeFlags Flags, bool Reciprocal) {
22874   EVT VT = Arg.getValueType();
22875   SDLoc DL(Arg);
22876   SDValue ThreeHalves = DAG.getConstantFP(1.5, DL, VT);
22877 
22878   // We now need 0.5 * Arg which we can write as (1.5 * Arg - Arg) so that
22879   // this entire sequence requires only one FP constant.
22880   SDValue HalfArg = DAG.getNode(ISD::FMUL, DL, VT, ThreeHalves, Arg, Flags);
22881   HalfArg = DAG.getNode(ISD::FSUB, DL, VT, HalfArg, Arg, Flags);
22882 
22883   // Newton iterations: Est = Est * (1.5 - HalfArg * Est * Est)
22884   for (unsigned i = 0; i < Iterations; ++i) {
22885     SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, Est, Flags);
22886     NewEst = DAG.getNode(ISD::FMUL, DL, VT, HalfArg, NewEst, Flags);
22887     NewEst = DAG.getNode(ISD::FSUB, DL, VT, ThreeHalves, NewEst, Flags);
22888     Est = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags);
22889   }
22890 
22891   // If non-reciprocal square root is requested, multiply the result by Arg.
22892   if (!Reciprocal)
22893     Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Arg, Flags);
22894 
22895   return Est;
22896 }
22897 
22898 /// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
22899 /// For the reciprocal sqrt, we need to find the zero of the function:
22900 ///   F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
22901 ///     =>
22902 ///   X_{i+1} = (-0.5 * X_i) * (A * X_i * X_i + (-3.0))
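/// This is algebraically the same update as X_{i+1} = X_i (1.5 - A X_i^2 / 2),
/// refactored so that only the constants -0.5 and -3.0 are required.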
22903 SDValue DAGCombiner::buildSqrtNRTwoConst(SDValue Arg, SDValue Est,
22904                                          unsigned Iterations,
22905                                          SDNodeFlags Flags, bool Reciprocal) {
22906   EVT VT = Arg.getValueType();
22907   SDLoc DL(Arg);
22908   SDValue MinusThree = DAG.getConstantFP(-3.0, DL, VT);
22909   SDValue MinusHalf = DAG.getConstantFP(-0.5, DL, VT);
22910 
22911   // This routine must enter the loop below to work correctly
22912   // when (Reciprocal == false).
22913   assert(Iterations > 0);
22914 
22915   // Newton iterations for reciprocal square root:
22916   // E = (E * -0.5) * ((A * E) * E + -3.0)
22917   for (unsigned i = 0; i < Iterations; ++i) {
22918     SDValue AE = DAG.getNode(ISD::FMUL, DL, VT, Arg, Est, Flags);
22919     SDValue AEE = DAG.getNode(ISD::FMUL, DL, VT, AE, Est, Flags);
22920     SDValue RHS = DAG.getNode(ISD::FADD, DL, VT, AEE, MinusThree, Flags);
22921 
22922     // When calculating a square root at the last iteration build:
22923     // S = ((A * E) * -0.5) * ((A * E) * E + -3.0)
22924     // (notice a common subexpression)
22925     SDValue LHS;
22926     if (Reciprocal || (i + 1) < Iterations) {
22927       // RSQRT: LHS = (E * -0.5)
22928       LHS = DAG.getNode(ISD::FMUL, DL, VT, Est, MinusHalf, Flags);
22929     } else {
22930       // SQRT: LHS = (A * E) * -0.5
22931       LHS = DAG.getNode(ISD::FMUL, DL, VT, AE, MinusHalf, Flags);
22932     }
22933 
22934     Est = DAG.getNode(ISD::FMUL, DL, VT, LHS, RHS, Flags);
22935   }
22936 
22937   return Est;
22938 }
22939 
22940 /// Build code to calculate either rsqrt(Op) or sqrt(Op). In the latter case
22941 /// Op*rsqrt(Op) is actually computed, so additional postprocessing is needed if
22942 /// Op can be zero.
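/// For example, sqrt(0.0) computed as 0.0 * rsqrt(0.0) = 0.0 * Inf yields NaN,
/// so a select keyed on the input test below forces such inputs to 0.0 (or to
/// a target-provided value).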
22943 SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
22944                                            bool Reciprocal) {
22945   if (LegalDAG)
22946     return SDValue();
22947 
22948   // TODO: Handle half and/or extended types?
22949   EVT VT = Op.getValueType();
22950   if (VT.getScalarType() != MVT::f32 && VT.getScalarType() != MVT::f64)
22951     return SDValue();
22952 
22953   // If estimates are explicitly disabled for this function, we're done.
22954   MachineFunction &MF = DAG.getMachineFunction();
22955   int Enabled = TLI.getRecipEstimateSqrtEnabled(VT, MF);
22956   if (Enabled == TLI.ReciprocalEstimate::Disabled)
22957     return SDValue();
22958 
22959   // Estimates may be explicitly enabled for this type with a custom number of
22960   // refinement steps.
22961   int Iterations = TLI.getSqrtRefinementSteps(VT, MF);
22962 
22963   bool UseOneConstNR = false;
22964   if (SDValue Est =
22965       TLI.getSqrtEstimate(Op, DAG, Enabled, Iterations, UseOneConstNR,
22966                           Reciprocal)) {
22967     AddToWorklist(Est.getNode());
22968 
22969     if (Iterations)
22970       Est = UseOneConstNR
22971             ? buildSqrtNROneConst(Op, Est, Iterations, Flags, Reciprocal)
22972             : buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal);
22973     if (!Reciprocal) {
22974       SDLoc DL(Op);
22975       // Try the target specific test first.
22976       SDValue Test = TLI.getSqrtInputTest(Op, DAG, DAG.getDenormalMode(VT));
22977 
22978       // The estimate is now completely wrong if the input was exactly 0.0 or
22979       // possibly a denormal. Force the answer to 0.0 or the value provided
22980       // by the target for those cases.
22981       Est = DAG.getNode(
22982           Test.getValueType().isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT,
22983           Test, TLI.getSqrtResultForDenormInput(Op, DAG), Est);
22984     }
22985     return Est;
22986   }
22987 
22988   return SDValue();
22989 }
22990 
22991 SDValue DAGCombiner::buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags) {
22992   return buildSqrtEstimateImpl(Op, Flags, true);
22993 }
22994 
22995 SDValue DAGCombiner::buildSqrtEstimate(SDValue Op, SDNodeFlags Flags) {
22996   return buildSqrtEstimateImpl(Op, Flags, false);
22997 }
22998 
22999 /// Return true if there is any possibility that the two addresses overlap.
23000 bool DAGCombiner::isAlias(SDNode *Op0, SDNode *Op1) const {
23001 
23002   struct MemUseCharacteristics {
23003     bool IsVolatile;
23004     bool IsAtomic;
23005     SDValue BasePtr;
23006     int64_t Offset;
23007     Optional<int64_t> NumBytes;
23008     MachineMemOperand *MMO;
23009   };
23010 
23011   auto getCharacteristics = [](SDNode *N) -> MemUseCharacteristics {
23012     if (const auto *LSN = dyn_cast<LSBaseSDNode>(N)) {
23013       int64_t Offset = 0;
23014       if (auto *C = dyn_cast<ConstantSDNode>(LSN->getOffset()))
23015         Offset = (LSN->getAddressingMode() == ISD::PRE_INC)
23016                      ? C->getSExtValue()
23017                      : (LSN->getAddressingMode() == ISD::PRE_DEC)
23018                            ? -1 * C->getSExtValue()
23019                            : 0;
23020       uint64_t Size =
23021           MemoryLocation::getSizeOrUnknown(LSN->getMemoryVT().getStoreSize());
23022       return {LSN->isVolatile(), LSN->isAtomic(), LSN->getBasePtr(),
23023               Offset /*base offset*/,
23024               Optional<int64_t>(Size),
23025               LSN->getMemOperand()};
23026     }
23027     if (const auto *LN = dyn_cast<LifetimeSDNode>(N))
23028       return {false /*isVolatile*/, /*isAtomic*/ false, LN->getOperand(1),
23029               (LN->hasOffset()) ? LN->getOffset() : 0,
23030               (LN->hasOffset()) ? Optional<int64_t>(LN->getSize())
23031                                 : Optional<int64_t>(),
23032               (MachineMemOperand *)nullptr};
23033     // Default.
23034     return {false /*isvolatile*/, /*isAtomic*/ false, SDValue(),
23035             (int64_t)0 /*offset*/,
23036             Optional<int64_t>() /*size*/, (MachineMemOperand *)nullptr};
23037   };
23038 
23039   MemUseCharacteristics MUC0 = getCharacteristics(Op0),
23040                         MUC1 = getCharacteristics(Op1);
23041 
23042   // If they are to the same address, then they must be aliases.
23043   if (MUC0.BasePtr.getNode() && MUC0.BasePtr == MUC1.BasePtr &&
23044       MUC0.Offset == MUC1.Offset)
23045     return true;
23046 
23047   // If they are both volatile then they cannot be reordered.
23048   if (MUC0.IsVolatile && MUC1.IsVolatile)
23049     return true;
23050 
23051   // Be conservative about atomics for the moment.
23052   // TODO: This is way overconservative for unordered atomics (see D66309)
23053   if (MUC0.IsAtomic && MUC1.IsAtomic)
23054     return true;
23055 
23056   if (MUC0.MMO && MUC1.MMO) {
23057     if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
23058         (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
23059       return false;
23060   }
23061 
23062   // Try to prove that there is aliasing, or that there is no aliasing. Either
23063   // way, we can return now. If nothing can be proved, proceed with more tests.
23064   bool IsAlias;
23065   if (BaseIndexOffset::computeAliasing(Op0, MUC0.NumBytes, Op1, MUC1.NumBytes,
23066                                        DAG, IsAlias))
23067     return IsAlias;
23068 
23069   // The following all rely on MMO0 and MMO1 being valid. Fail conservatively if
23070   // either are not known.
23071   if (!MUC0.MMO || !MUC1.MMO)
23072     return true;
23073 
23074   // If one operation reads from invariant memory, and the other may store, they
23075   // cannot alias. These should really be checking the equivalent of mayWrite,
23076   // but it only matters for memory nodes other than load/store.
23077   if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
23078       (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
23079     return false;
23080 
23081   // If we know required SrcValue1 and SrcValue2 have relatively large
23082   // alignment compared to the size and offset of the access, we may be able
23083   // to prove they do not alias. This check is conservative for now to catch
23084   // cases created by splitting vector types, it only works when the offsets are
23085   // multiples of the size of the data.
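  // For example (illustrative): two 4-byte accesses whose base values are
  // both 8-byte aligned, at offsets 0 and 4, land in different slots modulo
  // the alignment (0 and 4), so they cannot overlap.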
23086   int64_t SrcValOffset0 = MUC0.MMO->getOffset();
23087   int64_t SrcValOffset1 = MUC1.MMO->getOffset();
23088   Align OrigAlignment0 = MUC0.MMO->getBaseAlign();
23089   Align OrigAlignment1 = MUC1.MMO->getBaseAlign();
23090   auto &Size0 = MUC0.NumBytes;
23091   auto &Size1 = MUC1.NumBytes;
23092   if (OrigAlignment0 == OrigAlignment1 && SrcValOffset0 != SrcValOffset1 &&
23093       Size0.hasValue() && Size1.hasValue() && *Size0 == *Size1 &&
23094       OrigAlignment0 > *Size0 && SrcValOffset0 % *Size0 == 0 &&
23095       SrcValOffset1 % *Size1 == 0) {
23096     int64_t OffAlign0 = SrcValOffset0 % OrigAlignment0.value();
23097     int64_t OffAlign1 = SrcValOffset1 % OrigAlignment1.value();
23098 
23099     // There is no overlap between these relatively aligned accesses of
23100     // similar size. Return no alias.
23101     if ((OffAlign0 + *Size0) <= OffAlign1 || (OffAlign1 + *Size1) <= OffAlign0)
23102       return false;
23103   }
23104 
23105   bool UseAA = CombinerGlobalAA.getNumOccurrences() > 0
23106                    ? CombinerGlobalAA
23107                    : DAG.getSubtarget().useAA();
23108 #ifndef NDEBUG
23109   if (CombinerAAOnlyFunc.getNumOccurrences() &&
23110       CombinerAAOnlyFunc != DAG.getMachineFunction().getName())
23111     UseAA = false;
23112 #endif
23113 
23114   if (UseAA && AA && MUC0.MMO->getValue() && MUC1.MMO->getValue() &&
23115       Size0.hasValue() && Size1.hasValue()) {
23116     // Use alias analysis information.
23117     int64_t MinOffset = std::min(SrcValOffset0, SrcValOffset1);
23118     int64_t Overlap0 = *Size0 + SrcValOffset0 - MinOffset;
23119     int64_t Overlap1 = *Size1 + SrcValOffset1 - MinOffset;
23120     if (AA->isNoAlias(
23121             MemoryLocation(MUC0.MMO->getValue(), Overlap0,
23122                            UseTBAA ? MUC0.MMO->getAAInfo() : AAMDNodes()),
23123             MemoryLocation(MUC1.MMO->getValue(), Overlap1,
23124                            UseTBAA ? MUC1.MMO->getAAInfo() : AAMDNodes())))
23125       return false;
23126   }
23127 
23128   // Otherwise we have to assume they alias.
23129   return true;
23130 }
23131 
23132 /// Walk up chain skipping non-aliasing memory nodes,
23133 /// looking for aliasing nodes and adding them to the Aliases vector.
23134 void DAGCombiner::GatherAllAliases(SDNode *N, SDValue OriginalChain,
23135                                    SmallVectorImpl<SDValue> &Aliases) {
23136   SmallVector<SDValue, 8> Chains;     // List of chains to visit.
23137   SmallPtrSet<SDNode *, 16> Visited;  // Visited node set.
23138 
23139   // Get alias information for node.
23140   // TODO: relax aliasing for unordered atomics (see D66309)
23141   const bool IsLoad = isa<LoadSDNode>(N) && cast<LoadSDNode>(N)->isSimple();
23142 
23143   // Starting off.
23144   Chains.push_back(OriginalChain);
23145   unsigned Depth = 0;
23146 
23147   // Attempt to improve chain by a single step
23148   std::function<bool(SDValue &)> ImproveChain = [&](SDValue &C) -> bool {
23149     switch (C.getOpcode()) {
23150     case ISD::EntryToken:
23151       // No need to mark EntryToken.
23152       C = SDValue();
23153       return true;
23154     case ISD::LOAD:
23155     case ISD::STORE: {
23156       // Get alias information for C.
23157       // TODO: Relax aliasing for unordered atomics (see D66309)
23158       bool IsOpLoad = isa<LoadSDNode>(C.getNode()) &&
23159                       cast<LSBaseSDNode>(C.getNode())->isSimple();
23160       if ((IsLoad && IsOpLoad) || !isAlias(N, C.getNode())) {
23161         // Look further up the chain.
23162         C = C.getOperand(0);
23163         return true;
23164       }
23165       // Alias, so stop here.
23166       return false;
23167     }
23168 
23169     case ISD::CopyFromReg:
23170       // Always forward past CopyFromReg.
23171       C = C.getOperand(0);
23172       return true;
23173 
23174     case ISD::LIFETIME_START:
23175     case ISD::LIFETIME_END: {
23176       // We can forward past any lifetime start/end that can be proven not to
23177       // alias the memory access.
23178       if (!isAlias(N, C.getNode())) {
23179         // Look further up the chain.
23180         C = C.getOperand(0);
23181         return true;
23182       }
23183       return false;
23184     }
23185     default:
23186       return false;
23187     }
23188   };
23189 
23190   // Look at each chain and determine if it is an alias.  If so, add it to the
23191   // aliases list.  If not, then continue up the chain looking for the next
23192   // candidate.
23193   while (!Chains.empty()) {
23194     SDValue Chain = Chains.pop_back_val();
23195 
23196     // Don't bother if we've seen Chain before.
23197     if (!Visited.insert(Chain.getNode()).second)
23198       continue;
23199 
23200     // For TokenFactor nodes, look at each operand and only continue up the
23201     // chain until we reach the depth limit.
23202     //
23203     // FIXME: The depth check could be made to return the last non-aliasing
23204     // chain we found before we hit a tokenfactor rather than the original
23205     // chain.
23206     if (Depth > TLI.getGatherAllAliasesMaxDepth()) {
23207       Aliases.clear();
23208       Aliases.push_back(OriginalChain);
23209       return;
23210     }
23211 
23212     if (Chain.getOpcode() == ISD::TokenFactor) {
23213       // We have to check each of the operands of the token factor for "small"
23214       // token factors, so we queue them up.  Adding the operands to the queue
23215       // (stack) in reverse order maintains the original order and increases the
23216     // likelihood that getNode will find a matching token factor (CSE).
23217       if (Chain.getNumOperands() > 16) {
23218         Aliases.push_back(Chain);
23219         continue;
23220       }
23221       for (unsigned n = Chain.getNumOperands(); n;)
23222         Chains.push_back(Chain.getOperand(--n));
23223       ++Depth;
23224       continue;
23225     }
23226     // Everything else
23227     if (ImproveChain(Chain)) {
23228       // Updated chain found; consider the new chain if one exists.
23229       if (Chain.getNode())
23230         Chains.push_back(Chain);
23231       ++Depth;
23232       continue;
23233     }
23234     // No improved chain possible, so treat it as an alias.
23235     Aliases.push_back(Chain);
23236   }
23237 }
23238 
23239 /// Walk up chain skipping non-aliasing memory nodes, looking for a better chain
23240 /// (aliasing node.)
23241 SDValue DAGCombiner::FindBetterChain(SDNode *N, SDValue OldChain) {
23242   if (OptLevel == CodeGenOpt::None)
23243     return OldChain;
23244 
23245   // Ops for replacing token factor.
23246   SmallVector<SDValue, 8> Aliases;
23247 
23248   // Accumulate all the aliases to this node.
23249   GatherAllAliases(N, OldChain, Aliases);
23250 
23251   // If no operands then chain to entry token.
23252   if (Aliases.size() == 0)
23253     return DAG.getEntryNode();
23254 
23255   // If a single operand then chain to it.  We don't need to revisit it.
23256   if (Aliases.size() == 1)
23257     return Aliases[0];
23258 
23259   // Construct a custom tailored token factor.
23260   return DAG.getTokenFactor(SDLoc(N), Aliases);
23261 }
23262 
23263 namespace {
23264 // TODO: Replace with std::monostate when we move to C++17.
23265 struct UnitT { } Unit;
23266 bool operator==(const UnitT &, const UnitT &) { return true; }
23267 bool operator!=(const UnitT &, const UnitT &) { return false; }
23268 } // namespace
23269 
23270 // This function tries to collect a bunch of potentially interesting
23271 // nodes to improve the chains of, all at once. This might seem
23272 // redundant, as this function gets called when visiting every store
23273 // node, so why not let the work be done on each store as it's visited?
23274 //
23275 // I believe this is mainly important because mergeConsecutiveStores
23276 // is unable to deal with merging stores of different sizes, so unless
23277 // we improve the chains of all the potential candidates up-front
23278 // before running mergeConsecutiveStores, it might only see some of
23279 // the nodes that will eventually be candidates, and then not be able
23280 // to go from a partially-merged state to the desired final
23281 // fully-merged state.
23282 
23283 bool DAGCombiner::parallelizeChainedStores(StoreSDNode *St) {
23284   SmallVector<StoreSDNode *, 8> ChainedStores;
23285   StoreSDNode *STChain = St;
23286   // Intervals records which offsets from BaseIndex have been covered. In
23287   // the common case, every store writes to the address immediately before
23288   // the previous one, and is thus merged with the previous interval.
23289 
23290   using IMap =
23291       llvm::IntervalMap<int64_t, UnitT, 8, IntervalMapHalfOpenInfo<int64_t>>;
23292   IMap::Allocator A;
23293   IMap Intervals(A);
23294 
23295   // This holds the base pointer, index, and the offset in bytes from the base
23296   // pointer.
23297   const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
23298 
23299   // We must have a base and an offset.
23300   if (!BasePtr.getBase().getNode())
23301     return false;
23302 
23303   // Do not handle stores to undef base pointers.
23304   if (BasePtr.getBase().isUndef())
23305     return false;
23306 
23307   // Do not handle stores to opaque types.
23308   if (St->getMemoryVT().isZeroSized())
23309     return false;
23310 
23311   // BaseIndexOffset assumes that offsets are fixed-size, which
23312   // is not valid for scalable vectors where the offsets are
23313   // scaled by `vscale`, so bail out early.
23314   if (St->getMemoryVT().isScalableVector())
23315     return false;
23316 
23317   // Add ST's interval.
23318   Intervals.insert(0, (St->getMemoryVT().getSizeInBits() + 7) / 8, Unit);
23319 
23320   while (StoreSDNode *Chain = dyn_cast<StoreSDNode>(STChain->getChain())) {
23321     if (Chain->getMemoryVT().isScalableVector())
23322       return false;
23323 
23324     // If the chain has more than one use, then we can't reorder the mem ops.
23325     if (!SDValue(Chain, 0)->hasOneUse())
23326       break;
23327     // TODO: Relax for unordered atomics (see D66309)
23328     if (!Chain->isSimple() || Chain->isIndexed())
23329       break;
23330 
23331     // Find the base pointer and offset for this memory node.
23332     const BaseIndexOffset Ptr = BaseIndexOffset::match(Chain, DAG);
23333     // Check that the base pointer is the same as the original one.
23334     int64_t Offset;
23335     if (!BasePtr.equalBaseIndex(Ptr, DAG, Offset))
23336       break;
23337     int64_t Length = (Chain->getMemoryVT().getSizeInBits() + 7) / 8;
23338     // Make sure we don't overlap with other intervals by checking the ones to
23339     // the left or right before inserting.
23340     auto I = Intervals.find(Offset);
23341     // If there's a next interval, we should end before it.
23342     if (I != Intervals.end() && I.start() < (Offset + Length))
23343       break;
23344     // If there's a previous interval, we should start after it.
23345     if (I != Intervals.begin() && (--I).stop() <= Offset)
23346       break;
23347     Intervals.insert(Offset, Offset + Length, Unit);
23348 
23349     ChainedStores.push_back(Chain);
23350     STChain = Chain;
23351   }
23352 
23353   // If we didn't find a chained store, exit.
23354   if (ChainedStores.size() == 0)
23355     return false;
23356 
23357   // Improve all chained stores (St and ChainedStores members) starting from
23358   // where the store chain ended and return single TokenFactor.
23359   SDValue NewChain = STChain->getChain();
23360   SmallVector<SDValue, 8> TFOps;
23361   for (unsigned I = ChainedStores.size(); I;) {
23362     StoreSDNode *S = ChainedStores[--I];
23363     SDValue BetterChain = FindBetterChain(S, NewChain);
23364     S = cast<StoreSDNode>(DAG.UpdateNodeOperands(
23365         S, BetterChain, S->getOperand(1), S->getOperand(2), S->getOperand(3)));
23366     TFOps.push_back(SDValue(S, 0));
23367     ChainedStores[I] = S;
23368   }
23369 
23370   // Improve St's chain. Use a new node to avoid creating a loop from CombineTo.
23371   SDValue BetterChain = FindBetterChain(St, NewChain);
23372   SDValue NewST;
23373   if (St->isTruncatingStore())
23374     NewST = DAG.getTruncStore(BetterChain, SDLoc(St), St->getValue(),
23375                               St->getBasePtr(), St->getMemoryVT(),
23376                               St->getMemOperand());
23377   else
23378     NewST = DAG.getStore(BetterChain, SDLoc(St), St->getValue(),
23379                          St->getBasePtr(), St->getMemOperand());
23380 
23381   TFOps.push_back(NewST);
23382 
23383   // If we improved every element of TFOps, then we've lost the dependence on
23384   // NewChain to successors of St and we need to add it back to TFOps. Do so at
23385   // the beginning to keep relative order consistent with FindBetterChains.
23386   auto hasImprovedChain = [&](SDValue ST) -> bool {
23387     return ST->getOperand(0) != NewChain;
23388   };
23389   bool AddNewChain = llvm::all_of(TFOps, hasImprovedChain);
23390   if (AddNewChain)
23391     TFOps.insert(TFOps.begin(), NewChain);
23392 
23393   SDValue TF = DAG.getTokenFactor(SDLoc(STChain), TFOps);
23394   CombineTo(St, TF);
23395 
23396   // Add TF and its operands to the worklist.
23397   AddToWorklist(TF.getNode());
23398   for (const SDValue &Op : TF->ops())
23399     AddToWorklist(Op.getNode());
23400   AddToWorklist(STChain);
23401   return true;
23402 }
23403 
23404 bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) {
23405   if (OptLevel == CodeGenOpt::None)
23406     return false;
23407 
23408   const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
23409 
23410   // We must have a base and an offset.
23411   if (!BasePtr.getBase().getNode())
23412     return false;
23413 
23414   // Do not handle stores to undef base pointers.
23415   if (BasePtr.getBase().isUndef())
23416     return false;
23417 
23418   // Directly improve a chain of disjoint stores starting at St.
23419   if (parallelizeChainedStores(St))
23420     return true;
23421 
23422   // Improve St's chain.
23423   SDValue BetterChain = FindBetterChain(St, St->getChain());
23424   if (St->getChain() != BetterChain) {
23425     replaceStoreChain(St, BetterChain);
23426     return true;
23427   }
23428   return false;
23429 }
23430 
23431 /// This is the entry point for the file.
23432 void SelectionDAG::Combine(CombineLevel Level, AliasAnalysis *AA,
23433                            CodeGenOpt::Level OptLevel) {
23434   // This is the main entry point to this class.
23435   DAGCombiner(*this, AA, OptLevel).Run(Level);
23436 }
23437