//===- DAGCombiner.cpp - Implement a DAG node combiner --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass combines dag nodes to form fewer, simpler DAG nodes.  It can be run
// both before and after the DAG is legalized.
//
// This pass is not a substitute for the LLVM IR instcombine pass. This pass is
// primarily intended to handle simplification opportunities that are implicit
// in the LLVM IR and exposed by the various codegen lowering phases.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/CodeGen/DAGCombine.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineMemOperand.h"
#include "llvm/CodeGen/RuntimeLibcalls.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/SelectionDAGTargetInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/MachineValueType.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <iterator>
#include <string>
#include <tuple>
#include <utility>

using namespace llvm;

#define DEBUG_TYPE "dagcombine"

STATISTIC(NodesCombined   , "Number of dag nodes combined");
STATISTIC(PreIndexedNodes , "Number of pre-indexed nodes created");
STATISTIC(PostIndexedNodes, "Number of post-indexed nodes created");
STATISTIC(OpsNarrowed     , "Number of load/op/store narrowed");
STATISTIC(LdStFP2Int      , "Number of fp load/store pairs transformed to int");
STATISTIC(SlicedLoads, "Number of loads sliced");
STATISTIC(NumFPLogicOpsConv, "Number of logic ops converted to fp ops");

static cl::opt<bool>
CombinerGlobalAA("combiner-global-alias-analysis", cl::Hidden,
                 cl::desc("Enable DAG combiner's use of IR alias analysis"));

static cl::opt<bool>
UseTBAA("combiner-use-tbaa", cl::Hidden, cl::init(true),
        cl::desc("Enable DAG combiner's use of TBAA"));

#ifndef NDEBUG
static cl::opt<std::string>
CombinerAAOnlyFunc("combiner-aa-only-func", cl::Hidden,
                   cl::desc("Only use DAG-combiner alias analysis in this"
                            " function"));
#endif

/// Hidden option to stress test load slicing, i.e., when this option
/// is enabled, load slicing bypasses most of its profitability guards.
static cl::opt<bool>
StressLoadSlicing("combiner-stress-load-slicing", cl::Hidden,
                  cl::desc("Bypass the profitability model of load slicing"),
                  cl::init(false));

static cl::opt<bool>
  MaySplitLoadIndex("combiner-split-load-index", cl::Hidden, cl::init(true),
                    cl::desc("DAG combiner may split indexing from loads"));

static cl::opt<bool>
    EnableStoreMerging("combiner-store-merging", cl::Hidden, cl::init(true),
                       cl::desc("DAG combiner enable merging multiple stores "
                                "into a wider store"));

static cl::opt<unsigned> TokenFactorInlineLimit(
    "combiner-tokenfactor-inline-limit", cl::Hidden, cl::init(2048),
    cl::desc("Limit the number of operands to inline for Token Factors"));

static cl::opt<unsigned> StoreMergeDependenceLimit(
    "combiner-store-merge-dependence-limit", cl::Hidden, cl::init(10),
    cl::desc("Limit the number of times the same StoreNode and RootNode pair "
             "may bail out of the store-merging dependence check"));

namespace {

  class DAGCombiner {
    SelectionDAG &DAG;
    const TargetLowering &TLI;
    CombineLevel Level;
    CodeGenOpt::Level OptLevel;
    bool LegalDAG = false;
    bool LegalOperations = false;
    bool LegalTypes = false;
    bool ForCodeSize;

    /// Worklist of all of the nodes that need to be simplified.
    ///
    /// This must behave as a stack -- new nodes to process are pushed onto the
    /// back and when processing we pop off of the back.
    ///
    /// The worklist will not contain duplicates but may contain null entries
    /// due to nodes being deleted from the underlying DAG.
    SmallVector<SDNode *, 64> Worklist;

    /// Mapping from an SDNode to its position on the worklist.
    ///
    /// This is used to find and remove nodes from the worklist (by nulling
    /// them) when they are deleted from the underlying DAG. It relies on
    /// stable indices of nodes within the worklist.
    DenseMap<SDNode *, unsigned> WorklistMap;
    /// This records all nodes attempted to be added to the worklist since we
    /// considered a new worklist entry. Because we do not add duplicate nodes
    /// to the worklist, this is different from the tail of the worklist.
    SmallSetVector<SDNode *, 32> PruningList;

    /// Set of nodes which have been combined (at least once).
    ///
    /// This is used to allow us to reliably add any operands of a DAG node
    /// which have not yet been combined to the worklist.
    SmallPtrSet<SDNode *, 32> CombinedNodes;

    /// Map from candidate StoreNode to the pair of RootNode and count.
    /// The count is used to track how many times we have seen the StoreNode
    /// with the same RootNode bail out in dependence check. If we have seen
    /// the bail out for the same pair many times over a limit, we won't
    /// consider the StoreNode with the same RootNode as store merging
    /// candidate again.
    DenseMap<SDNode *, std::pair<SDNode *, unsigned>> StoreRootCountMap;

    // AA - Used for DAG load/store alias analysis.
    AliasAnalysis *AA;

    /// When an instruction is simplified, add all users of the instruction to
    /// the work lists because they might get more simplified now.
    void AddUsersToWorklist(SDNode *N) {
      for (SDNode *Node : N->uses())
        AddToWorklist(Node);
    }

    /// Convenient shorthand to add a node and all of its users to the
    /// worklist.
    void AddToWorklistWithUsers(SDNode *N) {
      AddUsersToWorklist(N);
      AddToWorklist(N);
    }

    // Prune potentially dangling nodes. This is called after
    // any visit to a node, but should also be called during a visit after any
    // failed combine which may have created a DAG node.
    void clearAddedDanglingWorklistEntries() {
      // Check any nodes added to the worklist to see if they are prunable.
      while (!PruningList.empty()) {
        auto *N = PruningList.pop_back_val();
        if (N->use_empty())
          recursivelyDeleteUnusedNodes(N);
      }
    }

    SDNode *getNextWorklistEntry() {
      // Before we do any work, remove nodes that are not in use.
      clearAddedDanglingWorklistEntries();
      SDNode *N = nullptr;
      // The Worklist holds the SDNodes in order, but it may contain null
      // entries.
      while (!N && !Worklist.empty()) {
        N = Worklist.pop_back_val();
      }

      if (N) {
        bool GoodWorklistEntry = WorklistMap.erase(N);
        (void)GoodWorklistEntry;
        assert(GoodWorklistEntry &&
               "Found a worklist entry without a corresponding map entry!");
      }
      return N;
    }

    /// Call the node-specific routine that folds each particular type of node.
    SDValue visit(SDNode *N);

  public:
    DAGCombiner(SelectionDAG &D, AliasAnalysis *AA, CodeGenOpt::Level OL)
        : DAG(D), TLI(D.getTargetLoweringInfo()), Level(BeforeLegalizeTypes),
          OptLevel(OL), AA(AA) {
      ForCodeSize = DAG.shouldOptForSize();

      MaximumLegalStoreInBits = 0;
      // We use the minimum store size here, since that's all we can guarantee
      // for the scalable vector types.
      for (MVT VT : MVT::all_valuetypes())
        if (EVT(VT).isSimple() && VT != MVT::Other &&
            TLI.isTypeLegal(EVT(VT)) &&
            VT.getSizeInBits().getKnownMinSize() >= MaximumLegalStoreInBits)
          MaximumLegalStoreInBits = VT.getSizeInBits().getKnownMinSize();
    }

    void ConsiderForPruning(SDNode *N) {
      // Mark this for potential pruning.
      PruningList.insert(N);
    }

    /// Add to the worklist, making sure its instance is at the back (next to
    /// be processed).
    void AddToWorklist(SDNode *N) {
      assert(N->getOpcode() != ISD::DELETED_NODE &&
             "Deleted Node added to Worklist");

      // Skip handle nodes as they can't usefully be combined and confuse the
      // zero-use deletion strategy.
      if (N->getOpcode() == ISD::HANDLENODE)
        return;

      ConsiderForPruning(N);

      if (WorklistMap.insert(std::make_pair(N, Worklist.size())).second)
        Worklist.push_back(N);
    }

    /// Remove all instances of N from the worklist.
    void removeFromWorklist(SDNode *N) {
      CombinedNodes.erase(N);
      PruningList.remove(N);
      StoreRootCountMap.erase(N);

      auto It = WorklistMap.find(N);
      if (It == WorklistMap.end())
        return; // Not in the worklist.

      // Null out the entry rather than erasing it to avoid a linear operation.
      Worklist[It->second] = nullptr;
      WorklistMap.erase(It);
    }

    void deleteAndRecombine(SDNode *N);
    bool recursivelyDeleteUnusedNodes(SDNode *N);

    /// Replaces all uses of the results of one DAG node with new values.
    SDValue CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
                      bool AddTo = true);

    /// Replaces all uses of the results of one DAG node with new values.
    SDValue CombineTo(SDNode *N, SDValue Res, bool AddTo = true) {
      return CombineTo(N, &Res, 1, AddTo);
    }

    /// Replaces all uses of the results of one DAG node with new values.
    SDValue CombineTo(SDNode *N, SDValue Res0, SDValue Res1,
                      bool AddTo = true) {
      SDValue To[] = { Res0, Res1 };
      return CombineTo(N, To, 2, AddTo);
    }

    void CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO);

  private:
    unsigned MaximumLegalStoreInBits;

    /// Check the specified integer node value to see if it can be simplified or
    /// if things it uses can be simplified by bit propagation.
    /// If so, return true.
    bool SimplifyDemandedBits(SDValue Op) {
      unsigned BitWidth = Op.getScalarValueSizeInBits();
      APInt DemandedBits = APInt::getAllOnesValue(BitWidth);
      return SimplifyDemandedBits(Op, DemandedBits);
    }

    bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits) {
      EVT VT = Op.getValueType();
      unsigned NumElts = VT.isVector() ? VT.getVectorNumElements() : 1;
      APInt DemandedElts = APInt::getAllOnesValue(NumElts);
      return SimplifyDemandedBits(Op, DemandedBits, DemandedElts);
    }

    /// Check the specified vector node value to see if it can be simplified or
    /// if things it uses can be simplified as it only uses some of the
    /// elements. If so, return true.
    bool SimplifyDemandedVectorElts(SDValue Op) {
      unsigned NumElts = Op.getValueType().getVectorNumElements();
      APInt DemandedElts = APInt::getAllOnesValue(NumElts);
      return SimplifyDemandedVectorElts(Op, DemandedElts);
    }

    bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
                              const APInt &DemandedElts);
    bool SimplifyDemandedVectorElts(SDValue Op, const APInt &DemandedElts,
                                    bool AssumeSingleUse = false);

    bool CombineToPreIndexedLoadStore(SDNode *N);
    bool CombineToPostIndexedLoadStore(SDNode *N);
    SDValue SplitIndexingFromLoad(LoadSDNode *LD);
    bool SliceUpLoad(SDNode *N);

    // Scalars have size 0 to distinguish from singleton vectors.
    SDValue ForwardStoreValueToDirectLoad(LoadSDNode *LD);
    bool getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val);
    bool extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val);

    /// Replace an ISD::EXTRACT_VECTOR_ELT of a load with a narrowed
    /// load.
    ///
    /// \param EVE ISD::EXTRACT_VECTOR_ELT to be replaced.
    /// \param InVecVT type of the input vector to EVE with bitcasts resolved.
    /// \param EltNo index of the vector element to load.
    /// \param OriginalLoad load that EVE came from to be replaced.
    /// \returns EVE on success, SDValue() on failure.
    SDValue scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
                                         SDValue EltNo,
                                         LoadSDNode *OriginalLoad);
    void ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad);
    SDValue PromoteOperand(SDValue Op, EVT PVT, bool &Replace);
    SDValue SExtPromoteOperand(SDValue Op, EVT PVT);
    SDValue ZExtPromoteOperand(SDValue Op, EVT PVT);
    SDValue PromoteIntBinOp(SDValue Op);
    SDValue PromoteIntShiftOp(SDValue Op);
    SDValue PromoteExtend(SDValue Op);
    bool PromoteLoad(SDValue Op);

    /// Call the node-specific routine that knows how to fold each
    /// particular type of node. If that doesn't do anything, try the
    /// target-specific DAG combines.
    SDValue combine(SDNode *N);

    // Visitation implementation - Implement dag node combining for different
    // node types.  The semantics are as follows:
    // Return Value:
    //   SDValue.getNode() == 0 - No change was made
    //   SDValue.getNode() == N - N was replaced, is dead and has been handled.
    //   otherwise              - N should be replaced by the returned Operand.
    //
    SDValue visitTokenFactor(SDNode *N);
    SDValue visitMERGE_VALUES(SDNode *N);
    SDValue visitADD(SDNode *N);
    SDValue visitADDLike(SDNode *N);
    SDValue visitADDLikeCommutative(SDValue N0, SDValue N1, SDNode *LocReference);
    SDValue visitSUB(SDNode *N);
    SDValue visitADDSAT(SDNode *N);
    SDValue visitSUBSAT(SDNode *N);
    SDValue visitADDC(SDNode *N);
    SDValue visitADDO(SDNode *N);
    SDValue visitUADDOLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitSUBC(SDNode *N);
    SDValue visitSUBO(SDNode *N);
    SDValue visitADDE(SDNode *N);
    SDValue visitADDCARRY(SDNode *N);
    SDValue visitADDCARRYLike(SDValue N0, SDValue N1, SDValue CarryIn, SDNode *N);
    SDValue visitSUBE(SDNode *N);
    SDValue visitSUBCARRY(SDNode *N);
    SDValue visitMUL(SDNode *N);
    SDValue visitMULFIX(SDNode *N);
    SDValue useDivRem(SDNode *N);
    SDValue visitSDIV(SDNode *N);
    SDValue visitSDIVLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitUDIV(SDNode *N);
    SDValue visitUDIVLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitREM(SDNode *N);
    SDValue visitMULHU(SDNode *N);
    SDValue visitMULHS(SDNode *N);
    SDValue visitSMUL_LOHI(SDNode *N);
    SDValue visitUMUL_LOHI(SDNode *N);
    SDValue visitMULO(SDNode *N);
    SDValue visitIMINMAX(SDNode *N);
    SDValue visitAND(SDNode *N);
    SDValue visitANDLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitOR(SDNode *N);
    SDValue visitORLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitXOR(SDNode *N);
    SDValue SimplifyVBinOp(SDNode *N);
    SDValue visitSHL(SDNode *N);
    SDValue visitSRA(SDNode *N);
    SDValue visitSRL(SDNode *N);
    SDValue visitFunnelShift(SDNode *N);
    SDValue visitRotate(SDNode *N);
    SDValue visitABS(SDNode *N);
    SDValue visitBSWAP(SDNode *N);
    SDValue visitBITREVERSE(SDNode *N);
    SDValue visitCTLZ(SDNode *N);
    SDValue visitCTLZ_ZERO_UNDEF(SDNode *N);
    SDValue visitCTTZ(SDNode *N);
    SDValue visitCTTZ_ZERO_UNDEF(SDNode *N);
    SDValue visitCTPOP(SDNode *N);
    SDValue visitSELECT(SDNode *N);
    SDValue visitVSELECT(SDNode *N);
    SDValue visitSELECT_CC(SDNode *N);
    SDValue visitSETCC(SDNode *N);
    SDValue visitSETCCCARRY(SDNode *N);
    SDValue visitSIGN_EXTEND(SDNode *N);
    SDValue visitZERO_EXTEND(SDNode *N);
    SDValue visitANY_EXTEND(SDNode *N);
    SDValue visitAssertExt(SDNode *N);
    SDValue visitSIGN_EXTEND_INREG(SDNode *N);
    SDValue visitSIGN_EXTEND_VECTOR_INREG(SDNode *N);
    SDValue visitZERO_EXTEND_VECTOR_INREG(SDNode *N);
    SDValue visitTRUNCATE(SDNode *N);
    SDValue visitBITCAST(SDNode *N);
    SDValue visitBUILD_PAIR(SDNode *N);
    SDValue visitFADD(SDNode *N);
    SDValue visitFSUB(SDNode *N);
    SDValue visitFMUL(SDNode *N);
    SDValue visitFMA(SDNode *N);
    SDValue visitFDIV(SDNode *N);
    SDValue visitFREM(SDNode *N);
    SDValue visitFSQRT(SDNode *N);
    SDValue visitFCOPYSIGN(SDNode *N);
    SDValue visitFPOW(SDNode *N);
    SDValue visitSINT_TO_FP(SDNode *N);
    SDValue visitUINT_TO_FP(SDNode *N);
    SDValue visitFP_TO_SINT(SDNode *N);
    SDValue visitFP_TO_UINT(SDNode *N);
    SDValue visitFP_ROUND(SDNode *N);
    SDValue visitFP_EXTEND(SDNode *N);
    SDValue visitFNEG(SDNode *N);
    SDValue visitFABS(SDNode *N);
    SDValue visitFCEIL(SDNode *N);
    SDValue visitFTRUNC(SDNode *N);
    SDValue visitFFLOOR(SDNode *N);
    SDValue visitFMINNUM(SDNode *N);
    SDValue visitFMAXNUM(SDNode *N);
    SDValue visitFMINIMUM(SDNode *N);
    SDValue visitFMAXIMUM(SDNode *N);
    SDValue visitBRCOND(SDNode *N);
    SDValue visitBR_CC(SDNode *N);
    SDValue visitLOAD(SDNode *N);

    SDValue replaceStoreChain(StoreSDNode *ST, SDValue BetterChain);
    SDValue replaceStoreOfFPConstant(StoreSDNode *ST);

    SDValue visitSTORE(SDNode *N);
    SDValue visitLIFETIME_END(SDNode *N);
    SDValue visitINSERT_VECTOR_ELT(SDNode *N);
    SDValue visitEXTRACT_VECTOR_ELT(SDNode *N);
    SDValue visitBUILD_VECTOR(SDNode *N);
    SDValue visitCONCAT_VECTORS(SDNode *N);
    SDValue visitEXTRACT_SUBVECTOR(SDNode *N);
    SDValue visitVECTOR_SHUFFLE(SDNode *N);
    SDValue visitSCALAR_TO_VECTOR(SDNode *N);
    SDValue visitINSERT_SUBVECTOR(SDNode *N);
    SDValue visitMLOAD(SDNode *N);
    SDValue visitMSTORE(SDNode *N);
    SDValue visitMGATHER(SDNode *N);
    SDValue visitMSCATTER(SDNode *N);
    SDValue visitFP_TO_FP16(SDNode *N);
    SDValue visitFP16_TO_FP(SDNode *N);
    SDValue visitVECREDUCE(SDNode *N);

    SDValue visitFADDForFMACombine(SDNode *N);
    SDValue visitFSUBForFMACombine(SDNode *N);
    SDValue visitFMULForFMADistributiveCombine(SDNode *N);

    SDValue XformToShuffleWithZero(SDNode *N);
    bool reassociationCanBreakAddressingModePattern(unsigned Opc,
                                                    const SDLoc &DL, SDValue N0,
                                                    SDValue N1);
    SDValue reassociateOpsCommutative(unsigned Opc, const SDLoc &DL, SDValue N0,
                                      SDValue N1);
    SDValue reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
                           SDValue N1, SDNodeFlags Flags);

    SDValue visitShiftByConstant(SDNode *N);

    SDValue foldSelectOfConstants(SDNode *N);
    SDValue foldVSelectOfConstants(SDNode *N);
    SDValue foldBinOpIntoSelect(SDNode *BO);
    bool SimplifySelectOps(SDNode *SELECT, SDValue LHS, SDValue RHS);
    SDValue hoistLogicOpWithSameOpcodeHands(SDNode *N);
    SDValue SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2);
    SDValue SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
                             SDValue N2, SDValue N3, ISD::CondCode CC,
                             bool NotExtCompare = false);
    SDValue convertSelectOfFPConstantsToLoadOffset(
        const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
        ISD::CondCode CC);
    SDValue foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0, SDValue N1,
                                   SDValue N2, SDValue N3, ISD::CondCode CC);
    SDValue foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
                              const SDLoc &DL);
    SDValue unfoldMaskedMerge(SDNode *N);
    SDValue unfoldExtremeBitClearingToShifts(SDNode *N);
    SDValue SimplifySetCC(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
                          const SDLoc &DL, bool foldBooleans);
    SDValue rebuildSetCC(SDValue N);

    bool isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
                           SDValue &CC) const;
    bool isOneUseSetCC(SDValue N) const;
    bool isCheaperToUseNegatedFPOps(SDValue X, SDValue Y);

    SDValue SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
                                         unsigned HiOp);
    SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
    SDValue CombineExtLoad(SDNode *N);
    SDValue CombineZExtLogicopShiftLoad(SDNode *N);
    SDValue combineRepeatedFPDivisors(SDNode *N);
    SDValue combineInsertEltToShuffle(SDNode *N, unsigned InsIndex);
    SDValue ConstantFoldBITCASTofBUILD_VECTOR(SDNode *, EVT);
    SDValue BuildSDIV(SDNode *N);
    SDValue BuildSDIVPow2(SDNode *N);
    SDValue BuildUDIV(SDNode *N);
    SDValue BuildLogBase2(SDValue V, const SDLoc &DL);
    SDValue BuildDivEstimate(SDValue N, SDValue Op, SDNodeFlags Flags);
    SDValue buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags);
    SDValue buildSqrtEstimate(SDValue Op, SDNodeFlags Flags);
    SDValue buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags, bool Recip);
    SDValue buildSqrtNROneConst(SDValue Arg, SDValue Est, unsigned Iterations,
                                SDNodeFlags Flags, bool Reciprocal);
    SDValue buildSqrtNRTwoConst(SDValue Arg, SDValue Est, unsigned Iterations,
                                SDNodeFlags Flags, bool Reciprocal);
    SDValue MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
                               bool DemandHighBits = true);
    SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
    SDValue MatchRotatePosNeg(SDValue Shifted, SDValue Pos, SDValue Neg,
                              SDValue InnerPos, SDValue InnerNeg,
                              unsigned PosOpcode, unsigned NegOpcode,
                              const SDLoc &DL);
    SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
    SDValue MatchLoadCombine(SDNode *N);
    SDValue MatchStoreCombine(StoreSDNode *N);
    SDValue ReduceLoadWidth(SDNode *N);
    SDValue ReduceLoadOpStoreWidth(SDNode *N);
    SDValue splitMergedValStore(StoreSDNode *ST);
    SDValue TransformFPLoadStorePair(SDNode *N);
    SDValue convertBuildVecZextToZext(SDNode *N);
    SDValue reduceBuildVecExtToExtBuildVec(SDNode *N);
    SDValue reduceBuildVecToShuffle(SDNode *N);
    SDValue createBuildVecShuffle(const SDLoc &DL, SDNode *N,
                                  ArrayRef<int> VectorMask, SDValue VecIn1,
                                  SDValue VecIn2, unsigned LeftIdx,
                                  bool DidSplitVec);
    SDValue matchVSelectOpSizesWithSetCC(SDNode *Cast);

    /// Walk up chain skipping non-aliasing memory nodes,
    /// looking for aliasing nodes and adding them to the Aliases vector.
    void GatherAllAliases(SDNode *N, SDValue OriginalChain,
                          SmallVectorImpl<SDValue> &Aliases);

    /// Return true if there is any possibility that the two addresses overlap.
    bool isAlias(SDNode *Op0, SDNode *Op1) const;

    /// Walk up chain skipping non-aliasing memory nodes, looking for a better
    /// chain (aliasing node.)
    SDValue FindBetterChain(SDNode *N, SDValue Chain);

    /// Try to replace a store and any possibly adjacent stores on
    /// consecutive chains with better chains. Return true only if St is
    /// replaced.
    ///
    /// Notice that other chains may still be replaced even if the function
    /// returns false.
    bool findBetterNeighborChains(StoreSDNode *St);

    // Helper for findBetterNeighborChains. Walk up the store chain, adding
    // additional chained stores that do not overlap and can be parallelized.
    bool parallelizeChainedStores(StoreSDNode *St);

    /// Holds a pointer to an LSBaseSDNode as well as information on where it
    /// is located in a sequence of memory operations connected by a chain.
    struct MemOpLink {
      // Ptr to the mem node.
      LSBaseSDNode *MemNode;

      // Offset from the base ptr.
      int64_t OffsetFromBase;

      MemOpLink(LSBaseSDNode *N, int64_t Offset)
          : MemNode(N), OffsetFromBase(Offset) {}
    };

    /// This is a helper function for visitMUL to check the profitability
    /// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
    /// MulNode is the original multiply, AddNode is (add x, c1),
    /// and ConstNode is c2.
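    ///
    /// For example (assuming the target finds the fold profitable):
    ///   (mul (add x, 4), 3) -> (add (mul x, 3), 12)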
    bool isMulAddWithConstProfitable(SDNode *MulNode,
                                     SDValue &AddNode,
                                     SDValue &ConstNode);

    /// This is a helper function for visitAND and visitZERO_EXTEND.  Returns
    /// true if the (and (load x) c) pattern matches an extload.  ExtVT returns
    /// the type of the loaded value to be extended.
    bool isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
                          EVT LoadResultTy, EVT &ExtVT);

    /// Helper function to calculate whether the given Load/Store can have its
    /// width reduced to ExtVT.
    bool isLegalNarrowLdSt(LSBaseSDNode *LDSTN, ISD::LoadExtType ExtType,
                           EVT &MemVT, unsigned ShAmt = 0);

    /// Used by BackwardsPropagateMask to find suitable loads.
    bool SearchForAndLoads(SDNode *N, SmallVectorImpl<LoadSDNode*> &Loads,
                           SmallPtrSetImpl<SDNode*> &NodesWithConsts,
                           ConstantSDNode *Mask, SDNode *&NodeToMask);
    /// Attempt to propagate a given AND node back to load leaves so that they
    /// can be combined into narrow loads.
    bool BackwardsPropagateMask(SDNode *N);

    /// Helper function for MergeConsecutiveStores which merges the
    /// component store chains.
    SDValue getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
                                unsigned NumStores);

    /// This is a helper function for MergeConsecutiveStores. When the
    /// source elements of the consecutive stores are all constants or
    /// all extracted vector elements, try to merge them into one
    /// larger store introducing bitcasts if necessary.  \return True
    /// if a merged store was created.
    bool MergeStoresOfConstantsOrVecElts(SmallVectorImpl<MemOpLink> &StoreNodes,
                                         EVT MemVT, unsigned NumStores,
                                         bool IsConstantSrc, bool UseVector,
                                         bool UseTrunc);

    /// This is a helper function for MergeConsecutiveStores. Stores
    /// that potentially may be merged with St are placed in
    /// StoreNodes. RootNode is a chain predecessor to all store
    /// candidates.
    void getStoreMergeCandidates(StoreSDNode *St,
                                 SmallVectorImpl<MemOpLink> &StoreNodes,
                                 SDNode *&Root);

    /// Helper function for MergeConsecutiveStores. Checks if
    /// candidate stores have indirect dependency through their
    /// operands. RootNode is the predecessor to all stores calculated
    /// by getStoreMergeCandidates and is used to prune the dependency check.
    /// \return True if safe to merge.
    bool checkMergeStoreCandidatesForDependencies(
        SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
        SDNode *RootNode);

    /// Merge consecutive store operations into a wide store.
    /// This optimization uses wide integers or vectors when possible.
    /// \return true if any stores were merged.
    bool MergeConsecutiveStores(StoreSDNode *St);

    /// Try to transform a truncation where C is a constant:
    ///     (trunc (and X, C)) -> (and (trunc X), (trunc C))
    ///
    /// \p N needs to be a truncation and its first operand an AND. Other
    /// requirements are checked by the function (e.g. that trunc is
    /// single-use); if they are not met, an empty SDValue is returned.
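    ///
    /// For example, with i32 truncated to i16 (illustrative types):
    ///   (trunc:i16 (and:i32 X, 65535)) -> (and:i16 (trunc:i16 X), 65535)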
    SDValue distributeTruncateThroughAnd(SDNode *N);

    /// Helper function to determine whether the target supports the operation
    /// given by \p Opcode for type \p VT, that is, whether the operation
    /// is legal or custom before legalizing operations, and whether it is
    /// legal (but not custom) after legalization.
    bool hasOperation(unsigned Opcode, EVT VT) {
      if (LegalOperations)
        return TLI.isOperationLegal(Opcode, VT);
      return TLI.isOperationLegalOrCustom(Opcode, VT);
    }

  public:
    /// Runs the dag combiner on all nodes in the work list
    void Run(CombineLevel AtLevel);

    SelectionDAG &getDAG() const { return DAG; }

    /// Returns a type large enough to hold any valid shift amount - before type
    /// legalization these can be huge.
    EVT getShiftAmountTy(EVT LHSTy) {
      assert(LHSTy.isInteger() && "Shift amount is not an integer type!");
      return TLI.getShiftAmountTy(LHSTy, DAG.getDataLayout(), LegalTypes);
    }

    /// This method returns true if we are running before type legalization or
    /// if the specified VT is legal.
    bool isTypeLegal(const EVT &VT) {
      if (!LegalTypes) return true;
      return TLI.isTypeLegal(VT);
    }

    /// Convenience wrapper around TargetLowering::getSetCCResultType
    EVT getSetCCResultType(EVT VT) const {
      return TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
    }

    void ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
                         SDValue OrigLoad, SDValue ExtLoad,
                         ISD::NodeType ExtType);
  };

/// This class is a DAGUpdateListener that removes any deleted
/// nodes from the worklist.
class WorklistRemover : public SelectionDAG::DAGUpdateListener {
  DAGCombiner &DC;

public:
  explicit WorklistRemover(DAGCombiner &dc)
    : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}

  void NodeDeleted(SDNode *N, SDNode *E) override {
    DC.removeFromWorklist(N);
  }
};

class WorklistInserter : public SelectionDAG::DAGUpdateListener {
  DAGCombiner &DC;

public:
  explicit WorklistInserter(DAGCombiner &dc)
      : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}

  // FIXME: Ideally we could add N to the worklist, but this causes exponential
  //        compile time costs in large DAGs, e.g. Halide.
  void NodeInserted(SDNode *N) override { DC.ConsiderForPruning(N); }
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
//  TargetLowering::DAGCombinerInfo implementation
//===----------------------------------------------------------------------===//

void TargetLowering::DAGCombinerInfo::AddToWorklist(SDNode *N) {
  ((DAGCombiner*)DC)->AddToWorklist(N);
}

SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode *N, ArrayRef<SDValue> To, bool AddTo) {
  return ((DAGCombiner*)DC)->CombineTo(N, &To[0], To.size(), AddTo);
}

SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode *N, SDValue Res, bool AddTo) {
  return ((DAGCombiner*)DC)->CombineTo(N, Res, AddTo);
}

SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode *N, SDValue Res0, SDValue Res1, bool AddTo) {
  return ((DAGCombiner*)DC)->CombineTo(N, Res0, Res1, AddTo);
}

bool TargetLowering::DAGCombinerInfo::
recursivelyDeleteUnusedNodes(SDNode *N) {
  return ((DAGCombiner*)DC)->recursivelyDeleteUnusedNodes(N);
}

void TargetLowering::DAGCombinerInfo::
CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
  return ((DAGCombiner*)DC)->CommitTargetLoweringOpt(TLO);
}

//===----------------------------------------------------------------------===//
// Helper Functions
//===----------------------------------------------------------------------===//

void DAGCombiner::deleteAndRecombine(SDNode *N) {
  removeFromWorklist(N);

  // If the operands of this node are only used by the node, they will now be
  // dead. Make sure to re-visit them and recursively delete dead nodes.
  for (const SDValue &Op : N->ops())
    // For an operand generating multiple values, one of the values may
    // become dead allowing further simplification (e.g. split index
    // arithmetic from an indexed load).
    if (Op->hasOneUse() || Op->getNumValues() > 1)
      AddToWorklist(Op.getNode());

  DAG.DeleteNode(N);
}

// APInts must be the same size for most operations; this helper function zero
// extends the shorter of the pair so that they match. We provide an Offset so
// that we can create bitwidths that won't overflow.
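// For example, an LHS of width 8 and an RHS of width 16 with Offset = 1 are
// both zero-extended to 1 + max(8, 16) = 17 bits.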
static void zeroExtendToMatch(APInt &LHS, APInt &RHS, unsigned Offset = 0) {
  unsigned Bits = Offset + std::max(LHS.getBitWidth(), RHS.getBitWidth());
  LHS = LHS.zextOrSelf(Bits);
  RHS = RHS.zextOrSelf(Bits);
}

// Return true if this node is a setcc, or is a select_cc
// that selects between the target values used for true and false, making it
// equivalent to a setcc. Also, set the incoming LHS, RHS, and CC references to
// the appropriate nodes based on the type of node we are checking. This
// simplifies life a bit for the callers.
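// For example, (select_cc lhs, rhs, -1, 0, cc) is treated as (setcc lhs, rhs,
// cc) on a target whose boolean true and false contents are -1 and 0.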
bool DAGCombiner::isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
                                    SDValue &CC) const {
  if (N.getOpcode() == ISD::SETCC) {
    LHS = N.getOperand(0);
    RHS = N.getOperand(1);
    CC  = N.getOperand(2);
    return true;
  }

  if (N.getOpcode() != ISD::SELECT_CC ||
      !TLI.isConstTrueVal(N.getOperand(2).getNode()) ||
      !TLI.isConstFalseVal(N.getOperand(3).getNode()))
    return false;

  if (TLI.getBooleanContents(N.getValueType()) ==
      TargetLowering::UndefinedBooleanContent)
    return false;

  LHS = N.getOperand(0);
  RHS = N.getOperand(1);
  CC  = N.getOperand(4);
  return true;
}

/// Return true if this is a SetCC-equivalent operation with only one use.
/// If this is true, it allows the users to invert the operation for free when
/// it is profitable to do so.
bool DAGCombiner::isOneUseSetCC(SDValue N) const {
  SDValue N0, N1, N2;
  if (isSetCCEquivalent(N, N0, N1, N2) && N.getNode()->hasOneUse())
    return true;
  return false;
}

// Returns the SDNode if it is a constant float BuildVector
// or constant float.
static SDNode *isConstantFPBuildVectorOrConstantFP(SDValue N) {
  if (isa<ConstantFPSDNode>(N))
    return N.getNode();
  if (ISD::isBuildVectorOfConstantFPSDNodes(N.getNode()))
    return N.getNode();
  return nullptr;
}

// Determines if it is a constant integer or a build vector of constant
// integers (and undefs).
// Do not permit build vector implicit truncation.
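// For example, a v4i32 BUILD_VECTOR whose constant operands are i64 values
// would be rejected here, since those operands would be implicitly truncated
// to 32 bits.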
static bool isConstantOrConstantVector(SDValue N, bool NoOpaques = false) {
  if (ConstantSDNode *Const = dyn_cast<ConstantSDNode>(N))
    return !(Const->isOpaque() && NoOpaques);
  if (N.getOpcode() != ISD::BUILD_VECTOR)
    return false;
  unsigned BitWidth = N.getScalarValueSizeInBits();
  for (const SDValue &Op : N->op_values()) {
    if (Op.isUndef())
      continue;
    ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Op);
    if (!Const || Const->getAPIntValue().getBitWidth() != BitWidth ||
        (Const->isOpaque() && NoOpaques))
      return false;
  }
  return true;
}

// Determines if a BUILD_VECTOR is composed of all-constants possibly mixed with
// undef's.
static bool isAnyConstantBuildVector(SDValue V, bool NoOpaques = false) {
  if (V.getOpcode() != ISD::BUILD_VECTOR)
    return false;
  return isConstantOrConstantVector(V, NoOpaques) ||
         ISD::isBuildVectorOfConstantFPSDNodes(V.getNode());
}

// Determine if the index of this indexed load can be split from it, i.e.
// splitting is enabled and the index is not an opaque target constant.
static bool canSplitIdx(LoadSDNode *LD) {
  return MaySplitLoadIndex &&
         (LD->getOperand(2).getOpcode() != ISD::TargetConstant ||
          !cast<ConstantSDNode>(LD->getOperand(2))->isOpaque());
}

bool DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc,
                                                             const SDLoc &DL,
                                                             SDValue N0,
                                                             SDValue N1) {
  // Currently this only tries to ensure we don't undo the GEP splits done by
  // CodeGenPrepare when shouldConsiderGEPOffsetSplit is true. To ensure this,
  // we check if the following transformation would be problematic:
  // (load/store (add, (add, x, offset1), offset2)) ->
  // (load/store (add, x, offset1+offset2)).
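  //
  // For example, on a target whose addressing modes reach only, say, +/-4095
  // (illustrative limit), x[offset2 = 2000] is a legal access, but folding
  // in offset1 = 3000 gives x[5000], which no longer fits the addressing
  // mode; in that case the reassociation is reported as problematic.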

  if (Opc != ISD::ADD || N0.getOpcode() != ISD::ADD)
    return false;

  if (N0.hasOneUse())
    return false;

  auto *C1 = dyn_cast<ConstantSDNode>(N0.getOperand(1));
  auto *C2 = dyn_cast<ConstantSDNode>(N1);
  if (!C1 || !C2)
    return false;

  const APInt &C1APIntVal = C1->getAPIntValue();
  const APInt &C2APIntVal = C2->getAPIntValue();
  if (C1APIntVal.getBitWidth() > 64 || C2APIntVal.getBitWidth() > 64)
    return false;

  const APInt CombinedValueIntVal = C1APIntVal + C2APIntVal;
  if (CombinedValueIntVal.getBitWidth() > 64)
    return false;
  const int64_t CombinedValue = CombinedValueIntVal.getSExtValue();

  for (SDNode *Node : N0->uses()) {
    auto LoadStore = dyn_cast<MemSDNode>(Node);
    if (LoadStore) {
      // Is x[offset2] already not a legal addressing mode? If so then
      // reassociating the constants breaks nothing (we test offset2 because
      // that's the one we hope to fold into the load or store).
      TargetLoweringBase::AddrMode AM;
      AM.HasBaseReg = true;
      AM.BaseOffs = C2APIntVal.getSExtValue();
      EVT VT = LoadStore->getMemoryVT();
      unsigned AS = LoadStore->getAddressSpace();
      Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
      if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
        continue;

      // Would x[offset1+offset2] still be a legal addressing mode?
      AM.BaseOffs = CombinedValue;
      if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
        return true;
    }
  }

  return false;
}

// Helper for DAGCombiner::reassociateOps. Try to reassociate an expression
// such as (Opc N0, N1), if \p N0 is the same kind of operation as \p Opc.
SDValue DAGCombiner::reassociateOpsCommutative(unsigned Opc, const SDLoc &DL,
                                               SDValue N0, SDValue N1) {
  EVT VT = N0.getValueType();

  if (N0.getOpcode() != Opc)
    return SDValue();

  // Don't reassociate reductions.
  if (N0->getFlags().hasVectorReduction())
    return SDValue();

  if (SDNode *C1 = DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1))) {
    if (SDNode *C2 = DAG.isConstantIntBuildVectorOrConstantInt(N1)) {
      // Reassociate: (op (op x, c1), c2) -> (op x, (op c1, c2))
      if (SDValue OpNode = DAG.FoldConstantArithmetic(Opc, DL, VT, C1, C2))
        return DAG.getNode(Opc, DL, VT, N0.getOperand(0), OpNode);
      return SDValue();
    }
    if (N0.hasOneUse()) {
      // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1)
      //              iff (op x, c1) has one use
      SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N0.getOperand(0), N1);
      if (!OpNode.getNode())
        return SDValue();
      return DAG.getNode(Opc, DL, VT, OpNode, N0.getOperand(1));
    }
  }
  return SDValue();
}

// Try to reassociate commutative binops.
SDValue DAGCombiner::reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
                                    SDValue N1, SDNodeFlags Flags) {
  assert(TLI.isCommutativeBinOp(Opc) && "Operation not commutative.");
  // Don't reassociate reductions.
  if (Flags.hasVectorReduction())
    return SDValue();

  // Floating-point reassociation is not allowed without loose FP math.
  if (N0.getValueType().isFloatingPoint() ||
      N1.getValueType().isFloatingPoint())
    if (!Flags.hasAllowReassociation() || !Flags.hasNoSignedZeros())
      return SDValue();

  if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N0, N1))
    return Combined;
  if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N1, N0))
    return Combined;
  return SDValue();
}

SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
                               bool AddTo) {
  assert(N->getNumValues() == NumTo && "Broken CombineTo call!");
  ++NodesCombined;
  LLVM_DEBUG(dbgs() << "\nReplacing.1 "; N->dump(&DAG); dbgs() << "\nWith: ";
             To[0].getNode()->dump(&DAG);
             dbgs() << " and " << NumTo - 1 << " other values\n");
  for (unsigned i = 0, e = NumTo; i != e; ++i)
    assert((!To[i].getNode() ||
            N->getValueType(i) == To[i].getValueType()) &&
           "Cannot combine value to value of different type!");

  WorklistRemover DeadNodes(*this);
  DAG.ReplaceAllUsesWith(N, To);
  if (AddTo) {
    // Push the new nodes and any users onto the worklist
    for (unsigned i = 0, e = NumTo; i != e; ++i) {
      if (To[i].getNode()) {
        AddToWorklist(To[i].getNode());
        AddUsersToWorklist(To[i].getNode());
      }
    }
  }

  // Finally, if the node is now dead, remove it from the graph.  The node
  // may not be dead if the replacement process recursively simplified to
  // something else needing this node.
  if (N->use_empty())
    deleteAndRecombine(N);
  return SDValue(N, 0);
}

void DAGCombiner::
CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
  // Replace all uses.  If any nodes become isomorphic to other nodes and
  // are deleted, make sure to remove them from our worklist.
  WorklistRemover DeadNodes(*this);
  DAG.ReplaceAllUsesOfValueWith(TLO.Old, TLO.New);

  // Push the new node and any (possibly new) users onto the worklist.
  AddToWorklistWithUsers(TLO.New.getNode());

  // Finally, if the node is now dead, remove it from the graph.  The node
  // may not be dead if the replacement process recursively simplified to
  // something else needing this node.
  if (TLO.Old.getNode()->use_empty())
    deleteAndRecombine(TLO.Old.getNode());
}

/// Check the specified integer node value to see if it can be simplified or if
/// things it uses can be simplified by bit propagation. If so, return true.
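/// For example, if only the low 8 bits of (and X, 0xFF00) are demanded, every
/// demanded bit is known to be zero, so the value can be replaced by the
/// constant 0.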
bool DAGCombiner::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
                                       const APInt &DemandedElts) {
  TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
  KnownBits Known;
  if (!TLI.SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO))
    return false;

  // Revisit the node.
  AddToWorklist(Op.getNode());

  // Replace the old value with the new one.
  ++NodesCombined;
  LLVM_DEBUG(dbgs() << "\nReplacing.2 "; TLO.Old.getNode()->dump(&DAG);
             dbgs() << "\nWith: "; TLO.New.getNode()->dump(&DAG);
             dbgs() << '\n');

  CommitTargetLoweringOpt(TLO);
  return true;
}

/// Check the specified vector node value to see if it can be simplified or
/// if things it uses can be simplified as it only uses some of the elements.
/// If so, return true.
bool DAGCombiner::SimplifyDemandedVectorElts(SDValue Op,
                                             const APInt &DemandedElts,
                                             bool AssumeSingleUse) {
  TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
  APInt KnownUndef, KnownZero;
  if (!TLI.SimplifyDemandedVectorElts(Op, DemandedElts, KnownUndef, KnownZero,
                                      TLO, 0, AssumeSingleUse))
    return false;

  // Revisit the node.
  AddToWorklist(Op.getNode());

  // Replace the old value with the new one.
  ++NodesCombined;
  LLVM_DEBUG(dbgs() << "\nReplacing.2 "; TLO.Old.getNode()->dump(&DAG);
             dbgs() << "\nWith: "; TLO.New.getNode()->dump(&DAG);
             dbgs() << '\n');

  CommitTargetLoweringOpt(TLO);
  return true;
}

void DAGCombiner::ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad) {
  SDLoc DL(Load);
  EVT VT = Load->getValueType(0);
  SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, VT, SDValue(ExtLoad, 0));

  LLVM_DEBUG(dbgs() << "\nReplacing.9 "; Load->dump(&DAG); dbgs() << "\nWith: ";
             Trunc.getNode()->dump(&DAG); dbgs() << '\n');
  WorklistRemover DeadNodes(*this);
  DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), Trunc);
  DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), SDValue(ExtLoad, 1));
  deleteAndRecombine(Load);
  AddToWorklist(Trunc.getNode());
}

SDValue DAGCombiner::PromoteOperand(SDValue Op, EVT PVT, bool &Replace) {
  Replace = false;
  SDLoc DL(Op);
  if (ISD::isUNINDEXEDLoad(Op.getNode())) {
    LoadSDNode *LD = cast<LoadSDNode>(Op);
    EVT MemVT = LD->getMemoryVT();
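    // A plain (non-extending) load is promoted as EXTLOAD (any-extend), since
    // the high bits of the wider value are not observed by its users;
    // extending loads keep their original extension type.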
    ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) ? ISD::EXTLOAD
                                                      : LD->getExtensionType();
    Replace = true;
    return DAG.getExtLoad(ExtType, DL, PVT,
                          LD->getChain(), LD->getBasePtr(),
                          MemVT, LD->getMemOperand());
  }

  unsigned Opc = Op.getOpcode();
  switch (Opc) {
  default: break;
  case ISD::AssertSext:
    if (SDValue Op0 = SExtPromoteOperand(Op.getOperand(0), PVT))
      return DAG.getNode(ISD::AssertSext, DL, PVT, Op0, Op.getOperand(1));
    break;
  case ISD::AssertZext:
    if (SDValue Op0 = ZExtPromoteOperand(Op.getOperand(0), PVT))
      return DAG.getNode(ISD::AssertZext, DL, PVT, Op0, Op.getOperand(1));
    break;
  case ISD::Constant: {
    unsigned ExtOpc =
      Op.getValueType().isByteSized() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
    return DAG.getNode(ExtOpc, DL, PVT, Op);
  }
  }

  if (!TLI.isOperationLegal(ISD::ANY_EXTEND, PVT))
    return SDValue();
  return DAG.getNode(ISD::ANY_EXTEND, DL, PVT, Op);
}

SDValue DAGCombiner::SExtPromoteOperand(SDValue Op, EVT PVT) {
  if (!TLI.isOperationLegal(ISD::SIGN_EXTEND_INREG, PVT))
    return SDValue();
  EVT OldVT = Op.getValueType();
  SDLoc DL(Op);
  bool Replace = false;
  SDValue NewOp = PromoteOperand(Op, PVT, Replace);
  if (!NewOp.getNode())
    return SDValue();
  AddToWorklist(NewOp.getNode());

  if (Replace)
    ReplaceLoadWithPromotedLoad(Op.getNode(), NewOp.getNode());
  return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, NewOp.getValueType(), NewOp,
                     DAG.getValueType(OldVT));
}

SDValue DAGCombiner::ZExtPromoteOperand(SDValue Op, EVT PVT) {
  EVT OldVT = Op.getValueType();
  SDLoc DL(Op);
  bool Replace = false;
  SDValue NewOp = PromoteOperand(Op, PVT, Replace);
  if (!NewOp.getNode())
    return SDValue();
  AddToWorklist(NewOp.getNode());

  if (Replace)
    ReplaceLoadWithPromotedLoad(Op.getNode(), NewOp.getNode());
  return DAG.getZeroExtendInReg(NewOp, DL, OldVT);
}

/// Promote the specified integer binary operation if the target indicates it is
/// beneficial. e.g. On x86, it's usually better to promote i16 operations to
/// i32 since i16 instructions are longer.
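///
/// For example (assuming the target requests promotion to i32):
///   (add:i16 x, y) -> (trunc:i16 (add:i32 (anyext x), (anyext y)))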
SDValue DAGCombiner::PromoteIntBinOp(SDValue Op) {
  if (!LegalOperations)
    return SDValue();

  EVT VT = Op.getValueType();
  if (VT.isVector() || !VT.isInteger())
    return SDValue();

  // If operation type is 'undesirable', e.g. i16 on x86, consider
  // promoting it.
  unsigned Opc = Op.getOpcode();
  if (TLI.isTypeDesirableForOp(Opc, VT))
    return SDValue();

  EVT PVT = VT;
  // Consult target whether it is a good idea to promote this operation and
  // what's the right type to promote it to.
  if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
    assert(PVT != VT && "Don't know what type to promote to!");

    LLVM_DEBUG(dbgs() << "\nPromoting "; Op.getNode()->dump(&DAG));

    bool Replace0 = false;
    SDValue N0 = Op.getOperand(0);
    SDValue NN0 = PromoteOperand(N0, PVT, Replace0);

    bool Replace1 = false;
    SDValue N1 = Op.getOperand(1);
    SDValue NN1 = PromoteOperand(N1, PVT, Replace1);
    SDLoc DL(Op);

    SDValue RV =
        DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, NN0, NN1));

    // We are always replacing N0/N1's use in N and only need
    // additional replacements if there are additional uses.
    Replace0 &= !N0->hasOneUse();
    Replace1 &= (N0 != N1) && !N1->hasOneUse();

    // Combine Op here so it is preserved past replacements.
    CombineTo(Op.getNode(), RV);

    // If operands have a use ordering, make sure we deal with
    // predecessor first.
    if (Replace0 && Replace1 && N0.getNode()->isPredecessorOf(N1.getNode())) {
      std::swap(N0, N1);
      std::swap(NN0, NN1);
    }

    if (Replace0) {
      AddToWorklist(NN0.getNode());
      ReplaceLoadWithPromotedLoad(N0.getNode(), NN0.getNode());
    }
    if (Replace1) {
      AddToWorklist(NN1.getNode());
      ReplaceLoadWithPromotedLoad(N1.getNode(), NN1.getNode());
    }
    return Op;
  }
  return SDValue();
}

/// Promote the specified integer shift operation if the target indicates it is
/// beneficial. e.g. On x86, it's usually better to promote i16 operations to
/// i32 since i16 instructions are longer.
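/// The shifted operand is extended to match the shift kind (a sketch, again
/// assuming an i16 -> i32 promotion):
///   (i16 sra x, c) -> (i16 trunc (i32 sra (sext x), c))
///   (i16 srl x, c) -> (i16 trunc (i32 srl (zext x), c))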
SDValue DAGCombiner::PromoteIntShiftOp(SDValue Op) {
  if (!LegalOperations)
    return SDValue();

  EVT VT = Op.getValueType();
  if (VT.isVector() || !VT.isInteger())
    return SDValue();

  // If operation type is 'undesirable', e.g. i16 on x86, consider
  // promoting it.
  unsigned Opc = Op.getOpcode();
  if (TLI.isTypeDesirableForOp(Opc, VT))
    return SDValue();

  EVT PVT = VT;
  // Consult target whether it is a good idea to promote this operation and
  // what's the right type to promote it to.
  if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
    assert(PVT != VT && "Don't know what type to promote to!");

    LLVM_DEBUG(dbgs() << "\nPromoting "; Op.getNode()->dump(&DAG));

    bool Replace = false;
    SDValue N0 = Op.getOperand(0);
    SDValue N1 = Op.getOperand(1);
    if (Opc == ISD::SRA)
      N0 = SExtPromoteOperand(N0, PVT);
    else if (Opc == ISD::SRL)
      N0 = ZExtPromoteOperand(N0, PVT);
    else
      N0 = PromoteOperand(N0, PVT, Replace);

    if (!N0.getNode())
      return SDValue();

    SDLoc DL(Op);
    SDValue RV =
        DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, N0, N1));

    if (Replace)
      ReplaceLoadWithPromotedLoad(Op.getOperand(0).getNode(), N0.getNode());

    // Deal with Op being deleted.
    if (Op && Op.getOpcode() != ISD::DELETED_NODE)
      return RV;
  }
  return SDValue();
}

SDValue DAGCombiner::PromoteExtend(SDValue Op) {
  if (!LegalOperations)
    return SDValue();

  EVT VT = Op.getValueType();
  if (VT.isVector() || !VT.isInteger())
    return SDValue();

  // If operation type is 'undesirable', e.g. i16 on x86, consider
  // promoting it.
  unsigned Opc = Op.getOpcode();
  if (TLI.isTypeDesirableForOp(Opc, VT))
    return SDValue();

  EVT PVT = VT;
  // Consult target whether it is a good idea to promote this operation and
  // what's the right type to promote it to.
  if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
    assert(PVT != VT && "Don't know what type to promote to!");
    // fold (aext (aext x)) -> (aext x)
    // fold (aext (zext x)) -> (zext x)
    // fold (aext (sext x)) -> (sext x)
    LLVM_DEBUG(dbgs() << "\nPromoting "; Op.getNode()->dump(&DAG));
    return DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, Op.getOperand(0));
  }
  return SDValue();
}

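/// Promote an integer load of an undesirable type by widening it to an
/// extending load of the promoted type and truncating the result back to the
/// original type. A sketch, assuming i16 promotes to i32:
///   (i16 load p) -> (i16 trunc (i32 extload p))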
bool DAGCombiner::PromoteLoad(SDValue Op) {
  if (!LegalOperations)
    return false;

  if (!ISD::isUNINDEXEDLoad(Op.getNode()))
    return false;

  EVT VT = Op.getValueType();
  if (VT.isVector() || !VT.isInteger())
    return false;

  // If operation type is 'undesirable', e.g. i16 on x86, consider
  // promoting it.
  unsigned Opc = Op.getOpcode();
  if (TLI.isTypeDesirableForOp(Opc, VT))
    return false;

  EVT PVT = VT;
  // Consult target whether it is a good idea to promote this operation and
  // what's the right type to promote it to.
  if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
    assert(PVT != VT && "Don't know what type to promote to!");

    SDLoc DL(Op);
    SDNode *N = Op.getNode();
    LoadSDNode *LD = cast<LoadSDNode>(N);
    EVT MemVT = LD->getMemoryVT();
    ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) ? ISD::EXTLOAD
                                                      : LD->getExtensionType();
    SDValue NewLD = DAG.getExtLoad(ExtType, DL, PVT,
                                   LD->getChain(), LD->getBasePtr(),
                                   MemVT, LD->getMemOperand());
    SDValue Result = DAG.getNode(ISD::TRUNCATE, DL, VT, NewLD);

    LLVM_DEBUG(dbgs() << "\nPromoting "; N->dump(&DAG); dbgs() << "\nTo: ";
               Result.getNode()->dump(&DAG); dbgs() << '\n');
    WorklistRemover DeadNodes(*this);
    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result);
    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), NewLD.getValue(1));
    deleteAndRecombine(N);
    AddToWorklist(Result.getNode());
    return true;
  }
  return false;
}

/// Recursively delete a node which has no uses and any operands for
/// which it is the only use.
///
/// Note that this both deletes the nodes and removes them from the worklist.
/// It also adds any nodes that have had a user deleted to the worklist, as
/// they may now have only one use and be subject to other combines.
bool DAGCombiner::recursivelyDeleteUnusedNodes(SDNode *N) {
  if (!N->use_empty())
    return false;

  SmallSetVector<SDNode *, 16> Nodes;
  Nodes.insert(N);
  do {
    N = Nodes.pop_back_val();
    if (!N)
      continue;

    if (N->use_empty()) {
      for (const SDValue &ChildN : N->op_values())
        Nodes.insert(ChildN.getNode());

      removeFromWorklist(N);
      DAG.DeleteNode(N);
    } else {
      AddToWorklist(N);
    }
  } while (!Nodes.empty());
  return true;
}

//===----------------------------------------------------------------------===//
//  Main DAG Combiner implementation
//===----------------------------------------------------------------------===//

void DAGCombiner::Run(CombineLevel AtLevel) {
  // Set the instance variables, so that the various visit routines may use
  // them.
  Level = AtLevel;
  LegalDAG = Level >= AfterLegalizeDAG;
  LegalOperations = Level >= AfterLegalizeVectorOps;
  LegalTypes = Level >= AfterLegalizeTypes;

  WorklistInserter AddNodes(*this);

  // Add all the dag nodes to the worklist.
  for (SDNode &Node : DAG.allnodes())
    AddToWorklist(&Node);

  // Create a dummy node (which is not added to allnodes), that adds a reference
  // to the root node, preventing it from being deleted, and tracking any
  // changes of the root.
  HandleSDNode Dummy(DAG.getRoot());

  // While we have a valid worklist entry node, try to combine it.
  while (SDNode *N = getNextWorklistEntry()) {
    // If N has no uses, it is dead.  Make sure to revisit all N's operands once
    // N is deleted from the DAG, since they too may now be dead or may have a
    // reduced number of uses, allowing other xforms.
    if (recursivelyDeleteUnusedNodes(N))
      continue;

    WorklistRemover DeadNodes(*this);

    // If this combine is running after legalizing the DAG, re-legalize any
    // nodes pulled off the worklist.
    if (LegalDAG) {
      SmallSetVector<SDNode *, 16> UpdatedNodes;
      bool NIsValid = DAG.LegalizeOp(N, UpdatedNodes);

      for (SDNode *LN : UpdatedNodes)
        AddToWorklistWithUsers(LN);

      if (!NIsValid)
        continue;
    }

    LLVM_DEBUG(dbgs() << "\nCombining: "; N->dump(&DAG));

    // Add any operands of the new node which have not yet been combined to the
    // worklist as well. Because the worklist uniques things already, this
    // won't repeatedly process the same operand.
    CombinedNodes.insert(N);
    for (const SDValue &ChildN : N->op_values())
      if (!CombinedNodes.count(ChildN.getNode()))
        AddToWorklist(ChildN.getNode());

    SDValue RV = combine(N);

    if (!RV.getNode())
      continue;

    ++NodesCombined;

    // If we get back the same node we passed in, rather than a new node or
    // zero, we know that the node must have defined multiple values and
    // CombineTo was used.  Since CombineTo takes care of the worklist
    // mechanics for us, we have no work to do in this case.
    if (RV.getNode() == N)
      continue;

    assert(N->getOpcode() != ISD::DELETED_NODE &&
           RV.getOpcode() != ISD::DELETED_NODE &&
           "Node was deleted but visit returned new node!");

    LLVM_DEBUG(dbgs() << " ... into: "; RV.getNode()->dump(&DAG));

    if (N->getNumValues() == RV.getNode()->getNumValues())
      DAG.ReplaceAllUsesWith(N, RV.getNode());
    else {
      assert(N->getValueType(0) == RV.getValueType() &&
             N->getNumValues() == 1 && "Type mismatch");
      DAG.ReplaceAllUsesWith(N, &RV);
    }

    // Push the new node and any users onto the worklist.
    AddToWorklist(RV.getNode());
    AddUsersToWorklist(RV.getNode());

    // Finally, if the node is now dead, remove it from the graph.  The node
    // may not be dead if the replacement process recursively simplified to
    // something else needing this node. This will also take care of adding any
    // operands which have lost a user to the worklist.
    recursivelyDeleteUnusedNodes(N);
  }

  // If the root changed (e.g. it was a dead load), update the root.
  DAG.setRoot(Dummy.getValue());
  DAG.RemoveDeadNodes();
}

SDValue DAGCombiner::visit(SDNode *N) {
  switch (N->getOpcode()) {
  default: break;
  case ISD::TokenFactor:        return visitTokenFactor(N);
  case ISD::MERGE_VALUES:       return visitMERGE_VALUES(N);
  case ISD::ADD:                return visitADD(N);
  case ISD::SUB:                return visitSUB(N);
  case ISD::SADDSAT:
  case ISD::UADDSAT:            return visitADDSAT(N);
  case ISD::SSUBSAT:
  case ISD::USUBSAT:            return visitSUBSAT(N);
  case ISD::ADDC:               return visitADDC(N);
  case ISD::SADDO:
  case ISD::UADDO:              return visitADDO(N);
  case ISD::SUBC:               return visitSUBC(N);
  case ISD::SSUBO:
  case ISD::USUBO:              return visitSUBO(N);
  case ISD::ADDE:               return visitADDE(N);
  case ISD::ADDCARRY:           return visitADDCARRY(N);
  case ISD::SUBE:               return visitSUBE(N);
  case ISD::SUBCARRY:           return visitSUBCARRY(N);
  case ISD::SMULFIX:
  case ISD::SMULFIXSAT:
  case ISD::UMULFIX:
  case ISD::UMULFIXSAT:         return visitMULFIX(N);
  case ISD::MUL:                return visitMUL(N);
  case ISD::SDIV:               return visitSDIV(N);
  case ISD::UDIV:               return visitUDIV(N);
  case ISD::SREM:
  case ISD::UREM:               return visitREM(N);
  case ISD::MULHU:              return visitMULHU(N);
  case ISD::MULHS:              return visitMULHS(N);
  case ISD::SMUL_LOHI:          return visitSMUL_LOHI(N);
  case ISD::UMUL_LOHI:          return visitUMUL_LOHI(N);
  case ISD::SMULO:
  case ISD::UMULO:              return visitMULO(N);
  case ISD::SMIN:
  case ISD::SMAX:
  case ISD::UMIN:
  case ISD::UMAX:               return visitIMINMAX(N);
  case ISD::AND:                return visitAND(N);
  case ISD::OR:                 return visitOR(N);
  case ISD::XOR:                return visitXOR(N);
  case ISD::SHL:                return visitSHL(N);
  case ISD::SRA:                return visitSRA(N);
  case ISD::SRL:                return visitSRL(N);
  case ISD::ROTR:
  case ISD::ROTL:               return visitRotate(N);
  case ISD::FSHL:
  case ISD::FSHR:               return visitFunnelShift(N);
  case ISD::ABS:                return visitABS(N);
  case ISD::BSWAP:              return visitBSWAP(N);
  case ISD::BITREVERSE:         return visitBITREVERSE(N);
  case ISD::CTLZ:               return visitCTLZ(N);
  case ISD::CTLZ_ZERO_UNDEF:    return visitCTLZ_ZERO_UNDEF(N);
  case ISD::CTTZ:               return visitCTTZ(N);
  case ISD::CTTZ_ZERO_UNDEF:    return visitCTTZ_ZERO_UNDEF(N);
  case ISD::CTPOP:              return visitCTPOP(N);
  case ISD::SELECT:             return visitSELECT(N);
  case ISD::VSELECT:            return visitVSELECT(N);
  case ISD::SELECT_CC:          return visitSELECT_CC(N);
  case ISD::SETCC:              return visitSETCC(N);
  case ISD::SETCCCARRY:         return visitSETCCCARRY(N);
  case ISD::SIGN_EXTEND:        return visitSIGN_EXTEND(N);
  case ISD::ZERO_EXTEND:        return visitZERO_EXTEND(N);
  case ISD::ANY_EXTEND:         return visitANY_EXTEND(N);
  case ISD::AssertSext:
  case ISD::AssertZext:         return visitAssertExt(N);
  case ISD::SIGN_EXTEND_INREG:  return visitSIGN_EXTEND_INREG(N);
  case ISD::SIGN_EXTEND_VECTOR_INREG: return visitSIGN_EXTEND_VECTOR_INREG(N);
  case ISD::ZERO_EXTEND_VECTOR_INREG: return visitZERO_EXTEND_VECTOR_INREG(N);
  case ISD::TRUNCATE:           return visitTRUNCATE(N);
  case ISD::BITCAST:            return visitBITCAST(N);
  case ISD::BUILD_PAIR:         return visitBUILD_PAIR(N);
  case ISD::FADD:               return visitFADD(N);
  case ISD::FSUB:               return visitFSUB(N);
  case ISD::FMUL:               return visitFMUL(N);
  case ISD::FMA:                return visitFMA(N);
  case ISD::FDIV:               return visitFDIV(N);
  case ISD::FREM:               return visitFREM(N);
  case ISD::FSQRT:              return visitFSQRT(N);
  case ISD::FCOPYSIGN:          return visitFCOPYSIGN(N);
  case ISD::FPOW:               return visitFPOW(N);
  case ISD::SINT_TO_FP:         return visitSINT_TO_FP(N);
  case ISD::UINT_TO_FP:         return visitUINT_TO_FP(N);
  case ISD::FP_TO_SINT:         return visitFP_TO_SINT(N);
  case ISD::FP_TO_UINT:         return visitFP_TO_UINT(N);
  case ISD::FP_ROUND:           return visitFP_ROUND(N);
  case ISD::FP_EXTEND:          return visitFP_EXTEND(N);
  case ISD::FNEG:               return visitFNEG(N);
  case ISD::FABS:               return visitFABS(N);
  case ISD::FFLOOR:             return visitFFLOOR(N);
  case ISD::FMINNUM:            return visitFMINNUM(N);
  case ISD::FMAXNUM:            return visitFMAXNUM(N);
  case ISD::FMINIMUM:           return visitFMINIMUM(N);
  case ISD::FMAXIMUM:           return visitFMAXIMUM(N);
  case ISD::FCEIL:              return visitFCEIL(N);
  case ISD::FTRUNC:             return visitFTRUNC(N);
  case ISD::BRCOND:             return visitBRCOND(N);
  case ISD::BR_CC:              return visitBR_CC(N);
  case ISD::LOAD:               return visitLOAD(N);
  case ISD::STORE:              return visitSTORE(N);
  case ISD::INSERT_VECTOR_ELT:  return visitINSERT_VECTOR_ELT(N);
  case ISD::EXTRACT_VECTOR_ELT: return visitEXTRACT_VECTOR_ELT(N);
  case ISD::BUILD_VECTOR:       return visitBUILD_VECTOR(N);
  case ISD::CONCAT_VECTORS:     return visitCONCAT_VECTORS(N);
  case ISD::EXTRACT_SUBVECTOR:  return visitEXTRACT_SUBVECTOR(N);
  case ISD::VECTOR_SHUFFLE:     return visitVECTOR_SHUFFLE(N);
  case ISD::SCALAR_TO_VECTOR:   return visitSCALAR_TO_VECTOR(N);
  case ISD::INSERT_SUBVECTOR:   return visitINSERT_SUBVECTOR(N);
  case ISD::MGATHER:            return visitMGATHER(N);
  case ISD::MLOAD:              return visitMLOAD(N);
  case ISD::MSCATTER:           return visitMSCATTER(N);
  case ISD::MSTORE:             return visitMSTORE(N);
  case ISD::LIFETIME_END:       return visitLIFETIME_END(N);
  case ISD::FP_TO_FP16:         return visitFP_TO_FP16(N);
  case ISD::FP16_TO_FP:         return visitFP16_TO_FP(N);
  case ISD::VECREDUCE_FADD:
  case ISD::VECREDUCE_FMUL:
  case ISD::VECREDUCE_ADD:
  case ISD::VECREDUCE_MUL:
  case ISD::VECREDUCE_AND:
  case ISD::VECREDUCE_OR:
  case ISD::VECREDUCE_XOR:
  case ISD::VECREDUCE_SMAX:
  case ISD::VECREDUCE_SMIN:
  case ISD::VECREDUCE_UMAX:
  case ISD::VECREDUCE_UMIN:
  case ISD::VECREDUCE_FMAX:
  case ISD::VECREDUCE_FMIN:     return visitVECREDUCE(N);
  }
  return SDValue();
}

SDValue DAGCombiner::combine(SDNode *N) {
  SDValue RV = visit(N);

  // If nothing happened, try a target-specific DAG combine.
  if (!RV.getNode()) {
    assert(N->getOpcode() != ISD::DELETED_NODE &&
           "Node was deleted but visit returned NULL!");

    if (N->getOpcode() >= ISD::BUILTIN_OP_END ||
        TLI.hasTargetDAGCombine((ISD::NodeType)N->getOpcode())) {

      // Expose the DAG combiner to the target combiner impls.
      TargetLowering::DAGCombinerInfo
        DagCombineInfo(DAG, Level, false, this);

      RV = TLI.PerformDAGCombine(N, DagCombineInfo);
    }
  }

  // If nothing happened still, try promoting the operation.
  if (!RV.getNode()) {
    switch (N->getOpcode()) {
    default: break;
    case ISD::ADD:
    case ISD::SUB:
    case ISD::MUL:
    case ISD::AND:
    case ISD::OR:
    case ISD::XOR:
      RV = PromoteIntBinOp(SDValue(N, 0));
      break;
    case ISD::SHL:
    case ISD::SRA:
    case ISD::SRL:
      RV = PromoteIntShiftOp(SDValue(N, 0));
      break;
    case ISD::SIGN_EXTEND:
    case ISD::ZERO_EXTEND:
    case ISD::ANY_EXTEND:
      RV = PromoteExtend(SDValue(N, 0));
      break;
    case ISD::LOAD:
      if (PromoteLoad(SDValue(N, 0)))
        RV = SDValue(N, 0);
      break;
    }
  }

  // If N is a commutative binary node, try to eliminate it if the commuted
  // version is already present in the DAG.
  if (!RV.getNode() && TLI.isCommutativeBinOp(N->getOpcode()) &&
      N->getNumValues() == 1) {
    SDValue N0 = N->getOperand(0);
    SDValue N1 = N->getOperand(1);

    // Constant operands are canonicalized to RHS.
    if (N0 != N1 && (isa<ConstantSDNode>(N0) || !isa<ConstantSDNode>(N1))) {
      SDValue Ops[] = {N1, N0};
      SDNode *CSENode = DAG.getNodeIfExists(N->getOpcode(), N->getVTList(), Ops,
                                            N->getFlags());
      if (CSENode)
        return SDValue(CSENode, 0);
    }
  }

  return RV;
}

/// Given a node, return its input chain if it has one, otherwise return a null
/// SDValue.
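/// Chains are typically found at operand 0 (e.g. loads and stores) or at the
/// last operand, so those positions are checked before scanning the rest.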
static SDValue getInputChainForNode(SDNode *N) {
  if (unsigned NumOps = N->getNumOperands()) {
    if (N->getOperand(0).getValueType() == MVT::Other)
      return N->getOperand(0);
    if (N->getOperand(NumOps-1).getValueType() == MVT::Other)
      return N->getOperand(NumOps-1);
    for (unsigned i = 1; i < NumOps-1; ++i)
      if (N->getOperand(i).getValueType() == MVT::Other)
        return N->getOperand(i);
  }
  return SDValue();
}

SDValue DAGCombiner::visitTokenFactor(SDNode *N) {
  // If N has two operands, where one has an input chain equal to the other,
  // the 'other' chain is redundant.
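  // For example, TokenFactor(Ch, (load Ch, Ptr)) can be replaced by the load's
  // output chain, since the load is already ordered after Ch.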
  if (N->getNumOperands() == 2) {
    if (getInputChainForNode(N->getOperand(0).getNode()) == N->getOperand(1))
      return N->getOperand(0);
    if (getInputChainForNode(N->getOperand(1).getNode()) == N->getOperand(0))
      return N->getOperand(1);
  }

  // Don't simplify token factors if optnone.
  if (OptLevel == CodeGenOpt::None)
    return SDValue();

  // If the sole user is a token factor, we should make sure we have a
  // chance to merge them together. This prevents TF chains from inhibiting
  // optimizations.
  if (N->hasOneUse() && N->use_begin()->getOpcode() == ISD::TokenFactor)
    AddToWorklist(*(N->use_begin()));

  SmallVector<SDNode *, 8> TFs;     // List of token factors to visit.
  SmallVector<SDValue, 8> Ops;      // Ops for replacing token factor.
  SmallPtrSet<SDNode*, 16> SeenOps;
  bool Changed = false;             // If we should replace this token factor.

  // Start out with this token factor.
  TFs.push_back(N);

  // Iterate through token factors.  TFs grows as new token factors are
  // encountered.
  for (unsigned i = 0; i < TFs.size(); ++i) {
    // Limit number of nodes to inline, to avoid quadratic compile times.
    // We have to add the outstanding Token Factors to Ops, otherwise we might
    // drop Ops from the resulting Token Factors.
    if (Ops.size() > TokenFactorInlineLimit) {
      for (unsigned j = i; j < TFs.size(); j++)
        Ops.emplace_back(TFs[j], 0);
      // Drop unprocessed Token Factors from TFs, so we do not add them to the
      // combiner worklist later.
      TFs.resize(i);
      break;
    }

    SDNode *TF = TFs[i];
    // Check each of the operands.
    for (const SDValue &Op : TF->op_values()) {
      switch (Op.getOpcode()) {
      case ISD::EntryToken:
        // Entry tokens don't need to be added to the list. They are
        // redundant.
        Changed = true;
        break;

      case ISD::TokenFactor:
        if (Op.hasOneUse() && !is_contained(TFs, Op.getNode())) {
          // Queue up for processing.
          TFs.push_back(Op.getNode());
          Changed = true;
          break;
        }
        LLVM_FALLTHROUGH;

      default:
        // Only add if it isn't already in the list.
        if (SeenOps.insert(Op.getNode()).second)
          Ops.push_back(Op);
        else
          Changed = true;
        break;
      }
    }
  }

  // Re-visit inlined Token Factors, to clean them up in case they have been
  // removed. Skip the first Token Factor, as this is the current node.
  for (unsigned i = 1, e = TFs.size(); i < e; i++)
    AddToWorklist(TFs[i]);

  // Remove Nodes that are chained to another node in the list. Do so by
  // walking up chains breadth-first, stopping when we've seen another operand.
  // In general we must climb to the EntryNode, but we can exit early if we
  // find all remaining work is associated with just one operand as no further
  // pruning is possible.

  // List of nodes to search through and original Ops from which they originate.
  SmallVector<std::pair<SDNode *, unsigned>, 8> Worklist;
  SmallVector<unsigned, 8> OpWorkCount; // Count of work for each Op.
  SmallPtrSet<SDNode *, 16> SeenChains;
  bool DidPruneOps = false;

  unsigned NumLeftToConsider = 0;
  for (const SDValue &Op : Ops) {
    Worklist.push_back(std::make_pair(Op.getNode(), NumLeftToConsider++));
    OpWorkCount.push_back(1);
  }

  auto AddToWorklist = [&](unsigned CurIdx, SDNode *Op, unsigned OpNumber) {
    // If this is an Op, we can remove the op from the list. Re-mark any
    // search associated with it as coming from the current OpNumber.
    if (SeenOps.count(Op) != 0) {
      Changed = true;
      DidPruneOps = true;
      unsigned OrigOpNumber = 0;
      while (OrigOpNumber < Ops.size() && Ops[OrigOpNumber].getNode() != Op)
        OrigOpNumber++;
      assert((OrigOpNumber != Ops.size()) &&
             "expected to find TokenFactor Operand");
      // Re-mark worklist from OrigOpNumber to OpNumber
      for (unsigned i = CurIdx + 1; i < Worklist.size(); ++i) {
        if (Worklist[i].second == OrigOpNumber) {
          Worklist[i].second = OpNumber;
        }
      }
      OpWorkCount[OpNumber] += OpWorkCount[OrigOpNumber];
      OpWorkCount[OrigOpNumber] = 0;
      NumLeftToConsider--;
    }
    // Add if it's a new chain
    if (SeenChains.insert(Op).second) {
      OpWorkCount[OpNumber]++;
      Worklist.push_back(std::make_pair(Op, OpNumber));
    }
  };

  for (unsigned i = 0; i < Worklist.size() && i < 1024; ++i) {
    // We need to consider at least 2 Ops to prune.
    if (NumLeftToConsider <= 1)
      break;
    auto CurNode = Worklist[i].first;
    auto CurOpNumber = Worklist[i].second;
    assert((OpWorkCount[CurOpNumber] > 0) &&
           "Node should not appear in worklist");
    switch (CurNode->getOpcode()) {
    case ISD::EntryToken:
      // Hitting EntryToken is the only way for the search to terminate without
      // hitting another operand's search. Prevent us from marking this operand
      // considered.
      NumLeftToConsider++;
      break;
    case ISD::TokenFactor:
      for (const SDValue &Op : CurNode->op_values())
        AddToWorklist(i, Op.getNode(), CurOpNumber);
      break;
    case ISD::LIFETIME_START:
    case ISD::LIFETIME_END:
    case ISD::CopyFromReg:
    case ISD::CopyToReg:
      AddToWorklist(i, CurNode->getOperand(0).getNode(), CurOpNumber);
      break;
    default:
      if (auto *MemNode = dyn_cast<MemSDNode>(CurNode))
        AddToWorklist(i, MemNode->getChain().getNode(), CurOpNumber);
      break;
    }
    OpWorkCount[CurOpNumber]--;
    if (OpWorkCount[CurOpNumber] == 0)
      NumLeftToConsider--;
  }

  // If we've changed things around then replace token factor.
  if (Changed) {
    SDValue Result;
    if (Ops.empty()) {
      // The entry token is the only possible outcome.
      Result = DAG.getEntryNode();
    } else {
      if (DidPruneOps) {
        SmallVector<SDValue, 8> PrunedOps;
        for (const SDValue &Op : Ops) {
          if (SeenChains.count(Op.getNode()) == 0)
            PrunedOps.push_back(Op);
        }
        Result = DAG.getTokenFactor(SDLoc(N), PrunedOps);
      } else {
        Result = DAG.getTokenFactor(SDLoc(N), Ops);
      }
    }
    return Result;
  }
  return SDValue();
}

/// MERGE_VALUES can always be eliminated.
SDValue DAGCombiner::visitMERGE_VALUES(SDNode *N) {
  WorklistRemover DeadNodes(*this);
  // Replacing results may cause a different MERGE_VALUES to suddenly
  // be CSE'd with N, and carry its uses with it. Iterate until no
  // uses remain, to ensure that the node can be safely deleted.
  // First add the users of this node to the work list so that they
  // can be tried again once they have new operands.
  AddUsersToWorklist(N);
  do {
    // Do as a single replacement to avoid rewalking use lists.
    SmallVector<SDValue, 8> Ops;
    for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
      Ops.push_back(N->getOperand(i));
    DAG.ReplaceAllUsesWith(N, Ops.data());
  } while (!N->use_empty());
  deleteAndRecombine(N);
  return SDValue(N, 0);   // Return N so it doesn't get rechecked!
}

/// If \p N is a ConstantSDNode with isOpaque() == false return it cast to a
/// ConstantSDNode pointer, else nullptr.
static ConstantSDNode *getAsNonOpaqueConstant(SDValue N) {
  ConstantSDNode *Const = dyn_cast<ConstantSDNode>(N);
  return Const != nullptr && !Const->isOpaque() ? Const : nullptr;
}

SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {
  assert(TLI.isBinOp(BO->getOpcode()) && BO->getNumValues() == 1 &&
         "Unexpected binary operator");

  // Don't do this unless the old select is going away. We want to eliminate the
  // binary operator, not replace a binop with a select.
  // TODO: Handle ISD::SELECT_CC.
  unsigned SelOpNo = 0;
  SDValue Sel = BO->getOperand(0);
  if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse()) {
    SelOpNo = 1;
    Sel = BO->getOperand(1);
  }

  if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse())
    return SDValue();

  SDValue CT = Sel.getOperand(1);
  if (!isConstantOrConstantVector(CT, true) &&
      !isConstantFPBuildVectorOrConstantFP(CT))
    return SDValue();

  SDValue CF = Sel.getOperand(2);
  if (!isConstantOrConstantVector(CF, true) &&
      !isConstantFPBuildVectorOrConstantFP(CF))
    return SDValue();

  // Bail out if any constants are opaque because we can't constant fold those.
  // The exception is "and" and "or" with either 0 or -1, in which case we can
  // propagate non-constant operands into the select, i.e.:
  // and (select Cond, 0, -1), X --> select Cond, 0, X
  // or X, (select Cond, -1, 0) --> select Cond, -1, X
  auto BinOpcode = BO->getOpcode();
  bool CanFoldNonConst =
      (BinOpcode == ISD::AND || BinOpcode == ISD::OR) &&
      (isNullOrNullSplat(CT) || isAllOnesOrAllOnesSplat(CT)) &&
      (isNullOrNullSplat(CF) || isAllOnesOrAllOnesSplat(CF));

  SDValue CBO = BO->getOperand(SelOpNo ^ 1);
  if (!CanFoldNonConst &&
      !isConstantOrConstantVector(CBO, true) &&
      !isConstantFPBuildVectorOrConstantFP(CBO))
    return SDValue();

  EVT VT = Sel.getValueType();

  // For a shift, the value and the shift amount may have different VTs. For
  // instance, on x86 the shift amount is i8 regardless of the LHS type. Bail
  // out if the operands are swapped and the value types do not match. NB: x86
  // is fine if the operands are not swapped, with the shift amount VT being no
  // bigger than the shifted value.
  // TODO: it is possible to check for a shift operation, correct the VTs and
  // still perform the optimization on x86 if needed.
  if (SelOpNo && VT != CBO.getValueType())
    return SDValue();

  // We have a select-of-constants followed by a binary operator with a
  // constant. Eliminate the binop by pulling the constant math into the select.
  // Example: add (select Cond, CT, CF), CBO --> select Cond, CT + CBO, CF + CBO
  SDLoc DL(Sel);
  SDValue NewCT = SelOpNo ? DAG.getNode(BinOpcode, DL, VT, CBO, CT)
                          : DAG.getNode(BinOpcode, DL, VT, CT, CBO);
  if (!CanFoldNonConst && !NewCT.isUndef() &&
      !isConstantOrConstantVector(NewCT, true) &&
      !isConstantFPBuildVectorOrConstantFP(NewCT))
    return SDValue();

  SDValue NewCF = SelOpNo ? DAG.getNode(BinOpcode, DL, VT, CBO, CF)
                          : DAG.getNode(BinOpcode, DL, VT, CF, CBO);
  if (!CanFoldNonConst && !NewCF.isUndef() &&
      !isConstantOrConstantVector(NewCF, true) &&
      !isConstantFPBuildVectorOrConstantFP(NewCF))
    return SDValue();

  SDValue SelectOp = DAG.getSelect(DL, VT, Sel.getOperand(0), NewCT, NewCF);
  SelectOp->setFlags(BO->getFlags());
  return SelectOp;
}

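/// Try to convert an add/sub of a constant and an inverted (via seteq) low bit
/// into a sub/add of the low bit with an adjusted constant, e.g.:
///   add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1))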
static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, SelectionDAG &DAG) {
  assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
         "Expecting add or sub");

  // Match a constant operand and a zext operand for the math instruction:
  // add Z, C
  // sub C, Z
  bool IsAdd = N->getOpcode() == ISD::ADD;
  SDValue C = IsAdd ? N->getOperand(1) : N->getOperand(0);
  SDValue Z = IsAdd ? N->getOperand(0) : N->getOperand(1);
  auto *CN = dyn_cast<ConstantSDNode>(C);
  if (!CN || Z.getOpcode() != ISD::ZERO_EXTEND)
    return SDValue();

  // Match the zext operand as a setcc of a boolean.
  if (Z.getOperand(0).getOpcode() != ISD::SETCC ||
      Z.getOperand(0).getValueType() != MVT::i1)
    return SDValue();

  // Match the compare as: setcc (X & 1), 0, eq.
  SDValue SetCC = Z.getOperand(0);
  ISD::CondCode CC = cast<CondCodeSDNode>(SetCC->getOperand(2))->get();
  if (CC != ISD::SETEQ || !isNullConstant(SetCC.getOperand(1)) ||
      SetCC.getOperand(0).getOpcode() != ISD::AND ||
      !isOneConstant(SetCC.getOperand(0).getOperand(1)))
    return SDValue();

  // We are adding/subtracting a constant and an inverted low bit. Turn that
  // into a subtract/add of the low bit with incremented/decremented constant:
  // add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1))
  // sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1))
  EVT VT = C.getValueType();
  SDLoc DL(N);
  SDValue LowBit = DAG.getZExtOrTrunc(SetCC.getOperand(0), DL, VT);
  SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT) :
                       DAG.getConstant(CN->getAPIntValue() - 1, DL, VT);
  return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, C1, LowBit);
}

/// Try to fold a 'not' shifted sign-bit with add/sub with constant operand into
/// a shift and add with a different constant.
static SDValue foldAddSubOfSignBit(SDNode *N, SelectionDAG &DAG) {
  assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
         "Expecting add or sub");

  // We need a constant operand for the add/sub, and the other operand is a
  // logical shift right: add (srl), C or sub C, (srl).
  // TODO - support non-uniform vector amounts.
  bool IsAdd = N->getOpcode() == ISD::ADD;
  SDValue ConstantOp = IsAdd ? N->getOperand(1) : N->getOperand(0);
  SDValue ShiftOp = IsAdd ? N->getOperand(0) : N->getOperand(1);
  ConstantSDNode *C = isConstOrConstSplat(ConstantOp);
  if (!C || ShiftOp.getOpcode() != ISD::SRL)
    return SDValue();

  // The shift must be of a 'not' value.
  SDValue Not = ShiftOp.getOperand(0);
  if (!Not.hasOneUse() || !isBitwiseNot(Not))
    return SDValue();

  // The shift must be moving the sign bit to the least-significant-bit.
  EVT VT = ShiftOp.getValueType();
  SDValue ShAmt = ShiftOp.getOperand(1);
  ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt);
  if (!ShAmtC || ShAmtC->getAPIntValue() != (VT.getScalarSizeInBits() - 1))
    return SDValue();

  // Eliminate the 'not' by adjusting the shift and add/sub constant:
  // add (srl (not X), 31), C --> add (sra X, 31), (C + 1)
  // sub C, (srl (not X), 31) --> add (srl X, 31), (C - 1)
  SDLoc DL(N);
  auto ShOpcode = IsAdd ? ISD::SRA : ISD::SRL;
  SDValue NewShift = DAG.getNode(ShOpcode, DL, VT, Not.getOperand(0), ShAmt);
  APInt NewC = IsAdd ? C->getAPIntValue() + 1 : C->getAPIntValue() - 1;
  return DAG.getNode(ISD::ADD, DL, VT, NewShift, DAG.getConstant(NewC, DL, VT));
}

/// Try to fold a node that behaves like an ADD (note that N isn't necessarily
/// an ISD::ADD here, it could for example be an ISD::OR if we know that there
/// are no common bits set in the operands).
SDValue DAGCombiner::visitADDLike(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  SDLoc DL(N);

  // fold vector ops
  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N))
      return FoldedVOp;

    // fold (add x, 0) -> x, vector edition
    if (ISD::isBuildVectorAllZeros(N1.getNode()))
      return N0;
    if (ISD::isBuildVectorAllZeros(N0.getNode()))
      return N1;
  }

  // fold (add x, undef) -> undef
  if (N0.isUndef())
    return N0;

  if (N1.isUndef())
    return N1;

  if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) {
    // canonicalize constant to RHS
    if (!DAG.isConstantIntBuildVectorOrConstantInt(N1))
      return DAG.getNode(ISD::ADD, DL, VT, N1, N0);
    // fold (add c1, c2) -> c1+c2
    return DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, N0.getNode(),
                                      N1.getNode());
  }

  // fold (add x, 0) -> x
  if (isNullConstant(N1))
    return N0;

  if (isConstantOrConstantVector(N1, /* NoOpaque */ true)) {
    // fold ((A-c1)+c2) -> (A+(c2-c1))
    if (N0.getOpcode() == ISD::SUB &&
        isConstantOrConstantVector(N0.getOperand(1), /* NoOpaque */ true)) {
      SDValue Sub = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, N1.getNode(),
                                               N0.getOperand(1).getNode());
      assert(Sub && "Constant folding failed");
      return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Sub);
    }

    // fold ((c1-A)+c2) -> (c1+c2)-A
    if (N0.getOpcode() == ISD::SUB &&
        isConstantOrConstantVector(N0.getOperand(0), /* NoOpaque */ true)) {
      SDValue Add = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, N1.getNode(),
                                               N0.getOperand(0).getNode());
      assert(Add && "Constant folding failed");
      return DAG.getNode(ISD::SUB, DL, VT, Add, N0.getOperand(1));
    }

    // add (sext i1 X), 1 -> zext (not i1 X)
    // We don't transform this pattern:
    //   add (zext i1 X), -1 -> sext (not i1 X)
    // because most (?) targets generate better code for the zext form.
    if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
        isOneOrOneSplat(N1)) {
      SDValue X = N0.getOperand(0);
      if ((!LegalOperations ||
           (TLI.isOperationLegal(ISD::XOR, X.getValueType()) &&
            TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) &&
          X.getScalarValueSizeInBits() == 1) {
        SDValue Not = DAG.getNOT(DL, X, X.getValueType());
        return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Not);
      }
    }

    // Undo the add -> or combine to merge constant offsets from a frame index.
    if (N0.getOpcode() == ISD::OR &&
        isa<FrameIndexSDNode>(N0.getOperand(0)) &&
        isa<ConstantSDNode>(N0.getOperand(1)) &&
        DAG.haveNoCommonBitsSet(N0.getOperand(0), N0.getOperand(1))) {
      SDValue Add0 = DAG.getNode(ISD::ADD, DL, VT, N1, N0.getOperand(1));
      return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Add0);
    }
  }

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // reassociate add
  if (!reassociationCanBreakAddressingModePattern(ISD::ADD, DL, N0, N1)) {
    if (SDValue RADD = reassociateOps(ISD::ADD, DL, N0, N1, N->getFlags()))
      return RADD;
  }
  // fold ((0-A) + B) -> B-A
  if (N0.getOpcode() == ISD::SUB && isNullOrNullSplat(N0.getOperand(0)))
    return DAG.getNode(ISD::SUB, DL, VT, N1, N0.getOperand(1));

  // fold (A + (0-B)) -> A-B
  if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0)))
    return DAG.getNode(ISD::SUB, DL, VT, N0, N1.getOperand(1));

  // fold (A+(B-A)) -> B
  if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(1))
    return N1.getOperand(0);

  // fold ((B-A)+A) -> B
  if (N0.getOpcode() == ISD::SUB && N1 == N0.getOperand(1))
    return N0.getOperand(0);

  // fold ((A-B)+(C-A)) -> (C-B)
  if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
      N0.getOperand(0) == N1.getOperand(1))
    return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
                       N0.getOperand(1));

  // fold ((A-B)+(B-C)) -> (A-C)
  if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
      N0.getOperand(1) == N1.getOperand(0))
    return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0),
                       N1.getOperand(1));

  // fold (A+(B-(A+C))) to (B-C)
  if (N1.getOpcode() == ISD::SUB && N1.getOperand(1).getOpcode() == ISD::ADD &&
      N0 == N1.getOperand(1).getOperand(0))
    return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
                       N1.getOperand(1).getOperand(1));

  // fold (A+(B-(C+A))) to (B-C)
  if (N1.getOpcode() == ISD::SUB && N1.getOperand(1).getOpcode() == ISD::ADD &&
      N0 == N1.getOperand(1).getOperand(1))
    return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
                       N1.getOperand(1).getOperand(0));

  // fold (A+((B-A)+or-C)) to (B+or-C)
  if ((N1.getOpcode() == ISD::SUB || N1.getOpcode() == ISD::ADD) &&
      N1.getOperand(0).getOpcode() == ISD::SUB &&
      N0 == N1.getOperand(0).getOperand(1))
    return DAG.getNode(N1.getOpcode(), DL, VT, N1.getOperand(0).getOperand(0),
                       N1.getOperand(1));

  // fold (A-B)+(C-D) to (A+C)-(B+D) when A or C is constant
  if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB) {
    SDValue N00 = N0.getOperand(0);
    SDValue N01 = N0.getOperand(1);
    SDValue N10 = N1.getOperand(0);
    SDValue N11 = N1.getOperand(1);

    if (isConstantOrConstantVector(N00) || isConstantOrConstantVector(N10))
      return DAG.getNode(ISD::SUB, DL, VT,
                         DAG.getNode(ISD::ADD, SDLoc(N0), VT, N00, N10),
                         DAG.getNode(ISD::ADD, SDLoc(N1), VT, N01, N11));
  }

  // fold (add (umax X, C), -C) --> (usubsat X, C)
  if (N0.getOpcode() == ISD::UMAX && hasOperation(ISD::USUBSAT, VT)) {
    auto MatchUSUBSAT = [](ConstantSDNode *Max, ConstantSDNode *Op) {
      return (!Max && !Op) ||
             (Max && Op && Max->getAPIntValue() == (-Op->getAPIntValue()));
    };
    if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchUSUBSAT,
                                  /*AllowUndefs*/ true))
      return DAG.getNode(ISD::USUBSAT, DL, VT, N0.getOperand(0),
                         N0.getOperand(1));
  }

  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  if (isOneOrOneSplat(N1)) {
    // fold (add (xor a, -1), 1) -> (sub 0, a)
    if (isBitwiseNot(N0))
      return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
                         N0.getOperand(0));

    // fold (add (add (xor a, -1), b), 1) -> (sub b, a)
    if (N0.getOpcode() == ISD::ADD ||
        N0.getOpcode() == ISD::UADDO ||
        N0.getOpcode() == ISD::SADDO) {
      SDValue A, Xor;

      if (isBitwiseNot(N0.getOperand(0))) {
        A = N0.getOperand(1);
        Xor = N0.getOperand(0);
      } else if (isBitwiseNot(N0.getOperand(1))) {
        A = N0.getOperand(0);
        Xor = N0.getOperand(1);
      }

      if (Xor)
        return DAG.getNode(ISD::SUB, DL, VT, A, Xor.getOperand(0));
    }

    // Look for:
    //   add (add x, y), 1
    // And if the target does not like this form then turn into:
    //   sub y, (xor x, -1)
    if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.hasOneUse() &&
        N0.getOpcode() == ISD::ADD) {
      SDValue Not = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(0),
                                DAG.getAllOnesConstant(DL, VT));
      return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(1), Not);
    }
  }

  // (x - y) + -1  ->  add (xor y, -1), x
  if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB &&
      isAllOnesOrAllOnesSplat(N1)) {
    SDValue Xor = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(1), N1);
    return DAG.getNode(ISD::ADD, DL, VT, Xor, N0.getOperand(0));
  }

  if (SDValue Combined = visitADDLikeCommutative(N0, N1, N))
    return Combined;

  if (SDValue Combined = visitADDLikeCommutative(N1, N0, N))
    return Combined;

  return SDValue();
}

SDValue DAGCombiner::visitADD(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  SDLoc DL(N);

  if (SDValue Combined = visitADDLike(N))
    return Combined;

  if (SDValue V = foldAddSubBoolOfMaskedVal(N, DAG))
    return V;

  if (SDValue V = foldAddSubOfSignBit(N, DAG))
    return V;

  // fold (a+b) -> (a|b) iff a and b share no bits.
  if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
      DAG.haveNoCommonBitsSet(N0, N1))
    return DAG.getNode(ISD::OR, DL, VT, N0, N1);

  return SDValue();
}

SDValue DAGCombiner::visitADDSAT(SDNode *N) {
  unsigned Opcode = N->getOpcode();
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  SDLoc DL(N);

  // fold vector ops
  if (VT.isVector()) {
    // TODO SimplifyVBinOp

    // fold (add_sat x, 0) -> x, vector edition
    if (ISD::isBuildVectorAllZeros(N1.getNode()))
      return N0;
    if (ISD::isBuildVectorAllZeros(N0.getNode()))
      return N1;
  }

  // fold (add_sat x, undef) -> -1
  if (N0.isUndef() || N1.isUndef())
    return DAG.getAllOnesConstant(DL, VT);

  if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) {
    // canonicalize constant to RHS
    if (!DAG.isConstantIntBuildVectorOrConstantInt(N1))
      return DAG.getNode(Opcode, DL, VT, N1, N0);
    // fold (add_sat c1, c2) -> c3
    return DAG.FoldConstantArithmetic(Opcode, DL, VT, N0.getNode(),
                                      N1.getNode());
  }

  // fold (add_sat x, 0) -> x
  if (isNullConstant(N1))
    return N0;

  // If it cannot overflow, transform into an add.
  if (Opcode == ISD::UADDSAT)
    if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never)
      return DAG.getNode(ISD::ADD, DL, VT, N0, N1);

  return SDValue();
}

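/// If \p V is a (possibly truncated, zero-extended, or masked) carry-out value
/// of an ADDCARRY/SUBCARRY/UADDO/USUBO node, return that carry value;
/// otherwise return a null SDValue.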
static SDValue getAsCarry(const TargetLowering &TLI, SDValue V) {
  bool Masked = false;

  // First, peel away TRUNCATE/ZERO_EXTEND/AND nodes due to legalization.
  while (true) {
    if (V.getOpcode() == ISD::TRUNCATE || V.getOpcode() == ISD::ZERO_EXTEND) {
      V = V.getOperand(0);
      continue;
    }

    if (V.getOpcode() == ISD::AND && isOneConstant(V.getOperand(1))) {
      Masked = true;
      V = V.getOperand(0);
      continue;
    }

    break;
  }

  // If this is not a carry, return.
  if (V.getResNo() != 1)
    return SDValue();

  if (V.getOpcode() != ISD::ADDCARRY && V.getOpcode() != ISD::SUBCARRY &&
      V.getOpcode() != ISD::UADDO && V.getOpcode() != ISD::USUBO)
    return SDValue();

  EVT VT = V.getNode()->getValueType(0);
  if (!TLI.isOperationLegalOrCustom(V.getOpcode(), VT))
    return SDValue();

  // If the result is masked, then no matter what kind of bool it is we can
  // return. If it isn't, then we need to make sure the bool type is either 0 or
  // 1 and not other values.
  if (Masked ||
      TLI.getBooleanContents(V.getValueType()) ==
          TargetLoweringBase::ZeroOrOneBooleanContent)
    return V;

  return SDValue();
}

/// Given the operands of an add/sub operation, see if the 2nd operand is a
/// masked 0/1 whose source operand is actually known to be 0/-1. If so, invert
/// the opcode and bypass the mask operation.
static SDValue foldAddSubMasked1(bool IsAdd, SDValue N0, SDValue N1,
                                 SelectionDAG &DAG, const SDLoc &DL) {
  if (N1.getOpcode() != ISD::AND || !isOneOrOneSplat(N1->getOperand(1)))
    return SDValue();

  EVT VT = N0.getValueType();
  if (DAG.ComputeNumSignBits(N1.getOperand(0)) != VT.getScalarSizeInBits())
    return SDValue();

  // add N0, (and (AssertSext X, i1), 1) --> sub N0, X
  // sub N0, (and (AssertSext X, i1), 1) --> add N0, X
  return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, N0, N1.getOperand(0));
}

/// Helper for doing combines based on N0 and N1 being added to each other.
SDValue DAGCombiner::visitADDLikeCommutative(SDValue N0, SDValue N1,
                                             SDNode *LocReference) {
  EVT VT = N0.getValueType();
  SDLoc DL(LocReference);

  // fold (add x, shl(0 - y, n)) -> sub(x, shl(y, n))
  if (N1.getOpcode() == ISD::SHL && N1.getOperand(0).getOpcode() == ISD::SUB &&
      isNullOrNullSplat(N1.getOperand(0).getOperand(0)))
    return DAG.getNode(ISD::SUB, DL, VT, N0,
                       DAG.getNode(ISD::SHL, DL, VT,
                                   N1.getOperand(0).getOperand(1),
                                   N1.getOperand(1)));

  if (SDValue V = foldAddSubMasked1(true, N0, N1, DAG, DL))
    return V;

  // Look for:
  //   add (add x, 1), y
  // And if the target does not like this form then turn into:
  //   sub y, (xor x, -1)
  if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.hasOneUse() &&
      N0.getOpcode() == ISD::ADD && isOneOrOneSplat(N0.getOperand(1))) {
    SDValue Not = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(0),
                              DAG.getAllOnesConstant(DL, VT));
    return DAG.getNode(ISD::SUB, DL, VT, N1, Not);
  }

  // Hoist one-use subtraction by non-opaque constant:
  //   (x - C) + y  ->  (x + y) - C
  // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
  if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB &&
      isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
    SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), N1);
    return DAG.getNode(ISD::SUB, DL, VT, Add, N0.getOperand(1));
  }
  // Hoist one-use subtraction from non-opaque constant:
  //   (C - x) + y  ->  (y - x) + C
  if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB &&
      isConstantOrConstantVector(N0.getOperand(0), /*NoOpaques=*/true)) {
    SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N1, N0.getOperand(1));
    return DAG.getNode(ISD::ADD, DL, VT, Sub, N0.getOperand(0));
  }

  // If the target's bool is represented as 0/1, prefer to make this 'sub 0/1'
  // rather than 'add 0/-1' (the zext should get folded).
  // add (sext i1 Y), X --> sub X, (zext i1 Y)
  if (N0.getOpcode() == ISD::SIGN_EXTEND &&
      N0.getOperand(0).getScalarValueSizeInBits() == 1 &&
      TLI.getBooleanContents(VT) == TargetLowering::ZeroOrOneBooleanContent) {
    SDValue ZExt = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0));
    return DAG.getNode(ISD::SUB, DL, VT, N1, ZExt);
  }

  // add X, (sextinreg Y i1) -> sub X, (and Y 1)
  if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
    VTSDNode *TN = cast<VTSDNode>(N1.getOperand(1));
    if (TN->getVT() == MVT::i1) {
      SDValue ZExt = DAG.getNode(ISD::AND, DL, VT, N1.getOperand(0),
                                 DAG.getConstant(1, DL, VT));
      return DAG.getNode(ISD::SUB, DL, VT, N0, ZExt);
    }
  }

  // (add X, (addcarry Y, 0, Carry)) -> (addcarry X, Y, Carry)
  if (N1.getOpcode() == ISD::ADDCARRY && isNullConstant(N1.getOperand(1)) &&
      N1.getResNo() == 0)
    return DAG.getNode(ISD::ADDCARRY, DL, N1->getVTList(),
                       N0, N1.getOperand(0), N1.getOperand(2));

  // (add X, Carry) -> (addcarry X, 0, Carry)
  if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT))
    if (SDValue Carry = getAsCarry(TLI, N1))
      return DAG.getNode(ISD::ADDCARRY, DL,
                         DAG.getVTList(VT, Carry.getValueType()), N0,
                         DAG.getConstant(0, DL, VT), Carry);

  return SDValue();
}

SDValue DAGCombiner::visitADDC(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  SDLoc DL(N);

  // If the flag result is dead, turn this into an ADD.
  if (!N->hasAnyUseOfValue(1))
    return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
                     DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));

  // canonicalize constant to RHS.
  ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
  if (N0C && !N1C)
    return DAG.getNode(ISD::ADDC, DL, N->getVTList(), N1, N0);

  // fold (addc x, 0) -> x + no carry out
  if (isNullConstant(N1))
    return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE,
                                        DL, MVT::Glue));

  // If it cannot overflow, transform into an add.
  if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never)
    return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
                     DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));

  return SDValue();
}

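/// Invert a boolean value for the target's boolean representation: XOR with 1
/// for zero-or-one booleans, XOR with all-ones for zero-or-negative-one
/// booleans.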
static SDValue flipBoolean(SDValue V, const SDLoc &DL,
                           SelectionDAG &DAG, const TargetLowering &TLI) {
  EVT VT = V.getValueType();

  SDValue Cst;
  switch (TLI.getBooleanContents(VT)) {
  case TargetLowering::ZeroOrOneBooleanContent:
  case TargetLowering::UndefinedBooleanContent:
    Cst = DAG.getConstant(1, DL, VT);
    break;
  case TargetLowering::ZeroOrNegativeOneBooleanContent:
    Cst = DAG.getAllOnesConstant(DL, VT);
    break;
  }

  return DAG.getNode(ISD::XOR, DL, VT, V, Cst);
}

/**
 * Flips a boolean if it is cheaper to compute. If the Force parameter is set,
 * then the flip also occurs if computing the inverse is the same cost.
 * This function returns an empty SDValue in case it cannot flip the boolean
 * without increasing the cost of the computation. If you want to flip a boolean
 * no matter what, use flipBoolean.
 */
extractBooleanFlip(SDValue V,SelectionDAG & DAG,const TargetLowering & TLI,bool Force)2567 static SDValue extractBooleanFlip(SDValue V, SelectionDAG &DAG,
2568                                   const TargetLowering &TLI,
2569                                   bool Force) {
2570   if (Force && isa<ConstantSDNode>(V))
2571     return flipBoolean(V, SDLoc(V), DAG, TLI);
2572 
2573   if (V.getOpcode() != ISD::XOR)
2574     return SDValue();
2575 
2576   ConstantSDNode *Const = isConstOrConstSplat(V.getOperand(1), false);
2577   if (!Const)
2578     return SDValue();
2579 
2580   EVT VT = V.getValueType();
2581 
2582   bool IsFlip = false;
2583   switch(TLI.getBooleanContents(VT)) {
2584     case TargetLowering::ZeroOrOneBooleanContent:
2585       IsFlip = Const->isOne();
2586       break;
2587     case TargetLowering::ZeroOrNegativeOneBooleanContent:
2588       IsFlip = Const->isAllOnesValue();
2589       break;
2590     case TargetLowering::UndefinedBooleanContent:
2591       IsFlip = (Const->getAPIntValue() & 0x01) == 1;
2592       break;
2593   }
2594 
2595   if (IsFlip)
2596     return V.getOperand(0);
2597   if (Force)
2598     return flipBoolean(V, SDLoc(V), DAG, TLI);
2599   return SDValue();
2600 }
2601 
2602 SDValue DAGCombiner::visitADDO(SDNode *N) {
2603   SDValue N0 = N->getOperand(0);
2604   SDValue N1 = N->getOperand(1);
2605   EVT VT = N0.getValueType();
2606   bool IsSigned = (ISD::SADDO == N->getOpcode());
2607 
2608   EVT CarryVT = N->getValueType(1);
2609   SDLoc DL(N);
2610 
2611   // If the flag result is dead, turn this into an ADD.
2612   if (!N->hasAnyUseOfValue(1))
2613     return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
2614                      DAG.getUNDEF(CarryVT));
2615 
2616   // canonicalize constant to RHS.
2617   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
2618       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
2619     return DAG.getNode(N->getOpcode(), DL, N->getVTList(), N1, N0);
2620 
2621   // fold (addo x, 0) -> x + no carry out
2622   if (isNullOrNullSplat(N1))
2623     return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT));
2624 
2625   if (!IsSigned) {
2626     // If it cannot overflow, transform into an add.
2627     if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never)
2628       return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
2629                        DAG.getConstant(0, DL, CarryVT));
2630 
2631     // fold (uaddo (xor a, -1), 1) -> (usub 0, a) and flip carry.
2632     if (isBitwiseNot(N0) && isOneOrOneSplat(N1)) {
2633       SDValue Sub = DAG.getNode(ISD::USUBO, DL, N->getVTList(),
2634                                 DAG.getConstant(0, DL, VT), N0.getOperand(0));
2635       return CombineTo(N, Sub,
2636                        flipBoolean(Sub.getValue(1), DL, DAG, TLI));
2637     }
2638 
2639     if (SDValue Combined = visitUADDOLike(N0, N1, N))
2640       return Combined;
2641 
2642     if (SDValue Combined = visitUADDOLike(N1, N0, N))
2643       return Combined;
2644   }
2645 
2646   return SDValue();
2647 }
2648 
2649 SDValue DAGCombiner::visitUADDOLike(SDValue N0, SDValue N1, SDNode *N) {
2650   EVT VT = N0.getValueType();
2651   if (VT.isVector())
2652     return SDValue();
2653 
2654   // (uaddo X, (addcarry Y, 0, Carry)) -> (addcarry X, Y, Carry)
2655   // If Y + 1 cannot overflow.
2656   if (N1.getOpcode() == ISD::ADDCARRY && isNullConstant(N1.getOperand(1))) {
2657     SDValue Y = N1.getOperand(0);
2658     SDValue One = DAG.getConstant(1, SDLoc(N), Y.getValueType());
2659     if (DAG.computeOverflowKind(Y, One) == SelectionDAG::OFK_Never)
2660       return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(), N0, Y,
2661                          N1.getOperand(2));
2662   }
2663 
2664   // (uaddo X, Carry) -> (addcarry X, 0, Carry)
2665   if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT))
2666     if (SDValue Carry = getAsCarry(TLI, N1))
2667       return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(), N0,
2668                          DAG.getConstant(0, SDLoc(N), VT), Carry);
2669 
2670   return SDValue();
2671 }
2672 
2673 SDValue DAGCombiner::visitADDE(SDNode *N) {
2674   SDValue N0 = N->getOperand(0);
2675   SDValue N1 = N->getOperand(1);
2676   SDValue CarryIn = N->getOperand(2);
2677 
2678   // canonicalize constant to RHS
2679   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
2680   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
2681   if (N0C && !N1C)
2682     return DAG.getNode(ISD::ADDE, SDLoc(N), N->getVTList(),
2683                        N1, N0, CarryIn);
2684 
2685   // fold (adde x, y, false) -> (addc x, y)
2686   if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
2687     return DAG.getNode(ISD::ADDC, SDLoc(N), N->getVTList(), N0, N1);
2688 
2689   return SDValue();
2690 }
2691 
2692 SDValue DAGCombiner::visitADDCARRY(SDNode *N) {
2693   SDValue N0 = N->getOperand(0);
2694   SDValue N1 = N->getOperand(1);
2695   SDValue CarryIn = N->getOperand(2);
2696   SDLoc DL(N);
2697 
2698   // canonicalize constant to RHS
2699   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
2700   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
2701   if (N0C && !N1C)
2702     return DAG.getNode(ISD::ADDCARRY, DL, N->getVTList(), N1, N0, CarryIn);
2703 
2704   // fold (addcarry x, y, false) -> (uaddo x, y)
2705   if (isNullConstant(CarryIn)) {
2706     if (!LegalOperations ||
2707         TLI.isOperationLegalOrCustom(ISD::UADDO, N->getValueType(0)))
2708       return DAG.getNode(ISD::UADDO, DL, N->getVTList(), N0, N1);
2709   }
2710 
2711   // fold (addcarry 0, 0, X) -> (and (ext/trunc X), 1) and no carry.
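  // The boolean carry may be narrower or wider than VT, hence the ext/trunc;
  // the (and ..., 1) then masks the extended boolean down to its low bit so
  // the result is exactly 0 or 1 regardless of the target's boolean contents.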
2712   if (isNullConstant(N0) && isNullConstant(N1)) {
2713     EVT VT = N0.getValueType();
2714     EVT CarryVT = CarryIn.getValueType();
2715     SDValue CarryExt = DAG.getBoolExtOrTrunc(CarryIn, DL, VT, CarryVT);
2716     AddToWorklist(CarryExt.getNode());
2717     return CombineTo(N, DAG.getNode(ISD::AND, DL, VT, CarryExt,
2718                                     DAG.getConstant(1, DL, VT)),
2719                      DAG.getConstant(0, DL, CarryVT));
2720   }
2721 
2722   if (SDValue Combined = visitADDCARRYLike(N0, N1, CarryIn, N))
2723     return Combined;
2724 
2725   if (SDValue Combined = visitADDCARRYLike(N1, N0, CarryIn, N))
2726     return Combined;
2727 
2728   return SDValue();
2729 }
2730 
2731 /**
2732  * If we are facing some sort of diamond carry propagation pattern, try to
2733  * break it up to generate something like:
2734  *   (addcarry X, 0, (addcarry A, B, Z):Carry)
2735  *
2736  * The end result is usually an increase in the number of operations required, but
2737  * because the carry is now linearized, other transforms can kick in and optimize the DAG.
2738  *
2739  * Patterns typically look something like
2740  *            (uaddo A, B)
2741  *             /       \
2742  *          Carry      Sum
2743  *            |          \
2744  *            | (addcarry *, 0, Z)
2745  *            |       /
2746  *             \   Carry
2747  *              |   /
2748  * (addcarry X, *, *)
2749  *
2750  * But numerous variations exist. Our goal is to identify A, B, X and Z and
2751  * produce a combine with a single path for carry propagation.
2752  */
2753 static SDValue combineADDCARRYDiamond(DAGCombiner &Combiner, SelectionDAG &DAG,
2754                                       SDValue X, SDValue Carry0, SDValue Carry1,
2755                                       SDNode *N) {
2756   if (Carry1.getResNo() != 1 || Carry0.getResNo() != 1)
2757     return SDValue();
2758   if (Carry1.getOpcode() != ISD::UADDO)
2759     return SDValue();
2760 
2761   SDValue Z;
2762 
2763   /**
2764    * First look for a suitable Z. It will present itself in the form of
2765    * (addcarry Y, 0, Z) or its equivalent (uaddo Y, 1) for Z=true
2766    */
2767   if (Carry0.getOpcode() == ISD::ADDCARRY &&
2768       isNullConstant(Carry0.getOperand(1))) {
2769     Z = Carry0.getOperand(2);
2770   } else if (Carry0.getOpcode() == ISD::UADDO &&
2771              isOneConstant(Carry0.getOperand(1))) {
2772     EVT VT = Combiner.getSetCCResultType(Carry0.getValueType());
2773     Z = DAG.getConstant(1, SDLoc(Carry0.getOperand(1)), VT);
2774   } else {
2775     // We couldn't find a suitable Z.
2776     return SDValue();
2777   }
2778 
2779 
2780   auto cancelDiamond = [&](SDValue A, SDValue B) {
2781     SDLoc DL(N);
2782     SDValue NewY = DAG.getNode(ISD::ADDCARRY, DL, Carry0->getVTList(), A, B, Z);
2783     Combiner.AddToWorklist(NewY.getNode());
2784     return DAG.getNode(ISD::ADDCARRY, DL, N->getVTList(), X,
2785                        DAG.getConstant(0, DL, X.getValueType()),
2786                        NewY.getValue(1));
2787   };
2788 
2789   /**
2790    *      (uaddo A, B)
2791    *           |
2792    *          Sum
2793    *           |
2794    * (addcarry *, 0, Z)
2795    */
2796   if (Carry0.getOperand(0) == Carry1.getValue(0)) {
2797     return cancelDiamond(Carry1.getOperand(0), Carry1.getOperand(1));
2798   }
2799 
2800   /**
2801    * (addcarry A, 0, Z)
2802    *         |
2803    *        Sum
2804    *         |
2805    *  (uaddo *, B)
2806    */
2807   if (Carry1.getOperand(0) == Carry0.getValue(0)) {
2808     return cancelDiamond(Carry0.getOperand(0), Carry1.getOperand(1));
2809   }
2810 
2811   if (Carry1.getOperand(1) == Carry0.getValue(0)) {
2812     return cancelDiamond(Carry1.getOperand(0), Carry0.getOperand(0));
2813   }
2814 
2815   return SDValue();
2816 }
2817 
2818 // If we are facing some sort of diamond carry/borrow in/out pattern, try to
2819 // match patterns like:
2820 //
2821 //          (uaddo A, B)            CarryIn
2822 //            |  \                     |
2823 //            |   \                    |
2824 //    PartialSum   PartialCarryOutX   /
2825 //            |        |             /
2826 //            |    ____|____________/
2827 //            |   /    |
2828 //     (uaddo *, *)    \________
2829 //       |  \                   \
2830 //       |   \                   |
2831 //       |    PartialCarryOutY   |
2832 //       |        \              |
2833 //       |         \            /
2834 //   AddCarrySum    |    ______/
2835 //                  |   /
2836 //   CarryOut = (or *, *)
2837 //
2838 // And generate ADDCARRY (or SUBCARRY) with two result values:
2839 //
2840 //    {AddCarrySum, CarryOut} = (addcarry A, B, CarryIn)
2841 //
2842 // Our goal is to identify A, B, and CarryIn and produce ADDCARRY/SUBCARRY with
2843 // a single path for carry/borrow out propagation:
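// For example, a double-word addition expanded with i1 carries:
//
//   {Sum0, C0} = (uaddo A, B)
//   {Sum1, C1} = (uaddo Sum0, (zext CarryIn))
//   CarryOut   = (or C0, C1)
//
// matches this pattern and becomes
//
//   {Sum1, CarryOut} = (addcarry A, B, CarryIn)
//
// assuming ADDCARRY is legal or custom for the type (checked below).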
2844 static SDValue combineCarryDiamond(DAGCombiner &Combiner, SelectionDAG &DAG,
2845                                    const TargetLowering &TLI, SDValue Carry0,
2846                                    SDValue Carry1, SDNode *N) {
2847   if (Carry0.getResNo() != 1 || Carry1.getResNo() != 1)
2848     return SDValue();
2849   unsigned Opcode = Carry0.getOpcode();
2850   if (Opcode != Carry1.getOpcode())
2851     return SDValue();
2852   if (Opcode != ISD::UADDO && Opcode != ISD::USUBO)
2853     return SDValue();
2854 
2855   // Canonicalize the add/sub of A and B as Carry0 and the add/sub of the
2856   // carry/borrow in as Carry1. (The top and middle uaddo nodes respectively in
2857   // the above ASCII art.)
2858   if (Carry1.getOperand(0) != Carry0.getValue(0) &&
2859       Carry1.getOperand(1) != Carry0.getValue(0))
2860     std::swap(Carry0, Carry1);
2861   if (Carry1.getOperand(0) != Carry0.getValue(0) &&
2862       Carry1.getOperand(1) != Carry0.getValue(0))
2863     return SDValue();
2864 
2865   // The carry in value must be on the right-hand side for subtraction.
2866   unsigned CarryInOperandNum =
2867       Carry1.getOperand(0) == Carry0.getValue(0) ? 1 : 0;
2868   if (Opcode == ISD::USUBO && CarryInOperandNum != 1)
2869     return SDValue();
2870   SDValue CarryIn = Carry1.getOperand(CarryInOperandNum);
2871 
2872   unsigned NewOp = Opcode == ISD::UADDO ? ISD::ADDCARRY : ISD::SUBCARRY;
2873   if (!TLI.isOperationLegalOrCustom(NewOp, Carry0.getValue(0).getValueType()))
2874     return SDValue();
2875 
2876   // Verify that the carry/borrow in is plausibly a carry/borrow bit.
2877   // TODO: make getAsCarry() aware of how partial carries are merged.
2878   if (CarryIn.getOpcode() != ISD::ZERO_EXTEND)
2879     return SDValue();
2880   CarryIn = CarryIn.getOperand(0);
2881   if (CarryIn.getValueType() != MVT::i1)
2882     return SDValue();
2883 
2884   SDLoc DL(N);
2885   SDValue Merged =
2886       DAG.getNode(NewOp, DL, Carry1->getVTList(), Carry0.getOperand(0),
2887                   Carry0.getOperand(1), CarryIn);
2888 
2889   // Note that because we have proven that the result of the UADDO/USUBO of A
2890   // and B feeds into the UADDO/USUBO that consumes the carry/borrow in, we
2891   // know that if the first UADDO/USUBO overflows, the second UADDO/USUBO
2892   // cannot. For example, consider 8-bit numbers where 0xFF is the
2893   // maximum value.
2894   //
2895   //   0xFF + 0xFF == 0xFE with carry but 0xFE + 1 does not carry
2896   //   0x00 - 0xFF == 1 with a carry/borrow but 1 - 1 == 0 (no carry/borrow)
2897   //
2898   // This is important because it means that OR and XOR can be used to merge
2899   // carry flags; and that AND can return a constant zero.
2900   //
2901   // TODO: match other operations that can merge flags (ADD, etc)
2902   DAG.ReplaceAllUsesOfValueWith(Carry1.getValue(0), Merged.getValue(0));
2903   if (N->getOpcode() == ISD::AND)
2904     return DAG.getConstant(0, DL, MVT::i1);
2905   return Merged.getValue(1);
2906 }
2907 
2908 SDValue DAGCombiner::visitADDCARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
2909                                        SDNode *N) {
2910   // fold (addcarry (xor a, -1), b, c) -> (subcarry b, a, !c) and flip carry.
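  // This relies on the identity ~a + b + c == b - a - (1 - c): the subcarry
  // computes the same sum, and its borrow out is the inverse of the carry
  // out, which is why both the carry in and the carry out are flipped.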
2911   if (isBitwiseNot(N0))
2912     if (SDValue NotC = extractBooleanFlip(CarryIn, DAG, TLI, true)) {
2913       SDLoc DL(N);
2914       SDValue Sub = DAG.getNode(ISD::SUBCARRY, DL, N->getVTList(), N1,
2915                                 N0.getOperand(0), NotC);
2916       return CombineTo(N, Sub,
2917                        flipBoolean(Sub.getValue(1), DL, DAG, TLI));
2918     }
2919 
2920   // Iff the flag result is dead:
2921   // (addcarry (add|uaddo X, Y), 0, Carry) -> (addcarry X, Y, Carry)
2922   // Don't do this if the Carry comes from the uaddo. It won't remove the uaddo
2923   // or the dependency between the instructions.
2924   if ((N0.getOpcode() == ISD::ADD ||
2925        (N0.getOpcode() == ISD::UADDO && N0.getResNo() == 0 &&
2926         N0.getValue(1) != CarryIn)) &&
2927       isNullConstant(N1) && !N->hasAnyUseOfValue(1))
2928     return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(),
2929                        N0.getOperand(0), N0.getOperand(1), CarryIn);
2930 
2931   /**
2932    * When one of the addcarry arguments is itself a carry, we may be facing
2933    * a diamond carry propagation. In that case we try to transform the DAG
2934    * to ensure linear carry propagation if that is possible.
2935    */
2936   if (auto Y = getAsCarry(TLI, N1)) {
2937     // Because both are carries, Y and Z can be swapped.
2938     if (auto R = combineADDCARRYDiamond(*this, DAG, N0, Y, CarryIn, N))
2939       return R;
2940     if (auto R = combineADDCARRYDiamond(*this, DAG, N0, CarryIn, Y, N))
2941       return R;
2942   }
2943 
2944   return SDValue();
2945 }
2946 
2947 // Since it may not be valid to emit a fold to zero for vector initializers,
2948 // check if we can before folding.
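// For example, folding (sub x, x) on a <4 x i32> after operation legalization
// only succeeds when BUILD_VECTOR is legal for that type; otherwise an empty
// SDValue is returned and the original node is left alone.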
2949 static SDValue tryFoldToZero(const SDLoc &DL, const TargetLowering &TLI, EVT VT,
2950                              SelectionDAG &DAG, bool LegalOperations) {
2951   if (!VT.isVector())
2952     return DAG.getConstant(0, DL, VT);
2953   if (!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT))
2954     return DAG.getConstant(0, DL, VT);
2955   return SDValue();
2956 }
2957 
2958 SDValue DAGCombiner::visitSUB(SDNode *N) {
2959   SDValue N0 = N->getOperand(0);
2960   SDValue N1 = N->getOperand(1);
2961   EVT VT = N0.getValueType();
2962   SDLoc DL(N);
2963 
2964   // fold vector ops
2965   if (VT.isVector()) {
2966     if (SDValue FoldedVOp = SimplifyVBinOp(N))
2967       return FoldedVOp;
2968 
2969     // fold (sub x, 0) -> x, vector edition
2970     if (ISD::isBuildVectorAllZeros(N1.getNode()))
2971       return N0;
2972   }
2973 
2974   // fold (sub x, x) -> 0
2975   // FIXME: Refactor this and xor and other similar operations together.
2976   if (N0 == N1)
2977     return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
2978   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
2979       DAG.isConstantIntBuildVectorOrConstantInt(N1)) {
2980     // fold (sub c1, c2) -> c1-c2
2981     return DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, N0.getNode(),
2982                                       N1.getNode());
2983   }
2984 
2985   if (SDValue NewSel = foldBinOpIntoSelect(N))
2986     return NewSel;
2987 
2988   ConstantSDNode *N1C = getAsNonOpaqueConstant(N1);
2989 
2990   // fold (sub x, c) -> (add x, -c)
2991   if (N1C) {
2992     return DAG.getNode(ISD::ADD, DL, VT, N0,
2993                        DAG.getConstant(-N1C->getAPIntValue(), DL, VT));
2994   }
2995 
2996   if (isNullOrNullSplat(N0)) {
2997     unsigned BitWidth = VT.getScalarSizeInBits();
2998     // Right-shifting everything out but the sign bit followed by negation is
2999     // the same as flipping arithmetic/logical shift type without the negation:
3000     // -(X >>u 31) -> (X >>s 31)
3001     // -(X >>s 31) -> (X >>u 31)
3002     if (N1->getOpcode() == ISD::SRA || N1->getOpcode() == ISD::SRL) {
3003       ConstantSDNode *ShiftAmt = isConstOrConstSplat(N1.getOperand(1));
3004       if (ShiftAmt && ShiftAmt->getAPIntValue() == (BitWidth - 1)) {
3005         auto NewSh = N1->getOpcode() == ISD::SRA ? ISD::SRL : ISD::SRA;
3006         if (!LegalOperations || TLI.isOperationLegal(NewSh, VT))
3007           return DAG.getNode(NewSh, DL, VT, N1.getOperand(0), N1.getOperand(1));
3008       }
3009     }
3010 
3011     // 0 - X --> 0 if the sub is NUW.
3012     if (N->getFlags().hasNoUnsignedWrap())
3013       return N0;
3014 
3015     if (DAG.MaskedValueIsZero(N1, ~APInt::getSignMask(BitWidth))) {
3016       // N1 is either 0 or the minimum signed value. If the sub is NSW, then
3017       // N1 must be 0 because negating the minimum signed value is undefined.
3018       if (N->getFlags().hasNoSignedWrap())
3019         return N0;
3020 
3021       // 0 - X --> X if X is 0 or the minimum signed value.
3022       return N1;
3023     }
3024   }
3025 
3026   // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1)
3027   if (isAllOnesOrAllOnesSplat(N0))
3028     return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
3029 
3030   // fold (A - (0-B)) -> A+B
3031   if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0)))
3032     return DAG.getNode(ISD::ADD, DL, VT, N0, N1.getOperand(1));
3033 
3034   // fold A-(A-B) -> B
3035   if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(0))
3036     return N1.getOperand(1);
3037 
3038   // fold (A+B)-A -> B
3039   if (N0.getOpcode() == ISD::ADD && N0.getOperand(0) == N1)
3040     return N0.getOperand(1);
3041 
3042   // fold (A+B)-B -> A
3043   if (N0.getOpcode() == ISD::ADD && N0.getOperand(1) == N1)
3044     return N0.getOperand(0);
3045 
3046   // fold (A+C1)-C2 -> A+(C1-C2)
3047   if (N0.getOpcode() == ISD::ADD &&
3048       isConstantOrConstantVector(N1, /* NoOpaques */ true) &&
3049       isConstantOrConstantVector(N0.getOperand(1), /* NoOpaques */ true)) {
3050     SDValue NewC = DAG.FoldConstantArithmetic(
3051         ISD::SUB, DL, VT, N0.getOperand(1).getNode(), N1.getNode());
3052     assert(NewC && "Constant folding failed");
3053     return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), NewC);
3054   }
3055 
3056   // fold C2-(A+C1) -> (C2-C1)-A
3057   if (N1.getOpcode() == ISD::ADD) {
3058     SDValue N11 = N1.getOperand(1);
3059     if (isConstantOrConstantVector(N0, /* NoOpaques */ true) &&
3060         isConstantOrConstantVector(N11, /* NoOpaques */ true)) {
3061       SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, N0.getNode(),
3062                                                 N11.getNode());
3063       assert(NewC && "Constant folding failed");
3064       return DAG.getNode(ISD::SUB, DL, VT, NewC, N1.getOperand(0));
3065     }
3066   }
3067 
3068   // fold (A-C1)-C2 -> A-(C1+C2)
3069   if (N0.getOpcode() == ISD::SUB &&
3070       isConstantOrConstantVector(N1, /* NoOpaques */ true) &&
3071       isConstantOrConstantVector(N0.getOperand(1), /* NoOpaques */ true)) {
3072     SDValue NewC = DAG.FoldConstantArithmetic(
3073         ISD::ADD, DL, VT, N0.getOperand(1).getNode(), N1.getNode());
3074     assert(NewC && "Constant folding failed");
3075     return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), NewC);
3076   }
3077 
3078   // fold (c1-A)-c2 -> (c1-c2)-A
3079   if (N0.getOpcode() == ISD::SUB &&
3080       isConstantOrConstantVector(N1, /* NoOpaques */ true) &&
3081       isConstantOrConstantVector(N0.getOperand(0), /* NoOpaques */ true)) {
3082     SDValue NewC = DAG.FoldConstantArithmetic(
3083         ISD::SUB, DL, VT, N0.getOperand(0).getNode(), N1.getNode());
3084     assert(NewC && "Constant folding failed");
3085     return DAG.getNode(ISD::SUB, DL, VT, NewC, N0.getOperand(1));
3086   }
3087 
3088   // fold ((A+(B+or-C))-B) -> A+or-C
3089   if (N0.getOpcode() == ISD::ADD &&
3090       (N0.getOperand(1).getOpcode() == ISD::SUB ||
3091        N0.getOperand(1).getOpcode() == ISD::ADD) &&
3092       N0.getOperand(1).getOperand(0) == N1)
3093     return DAG.getNode(N0.getOperand(1).getOpcode(), DL, VT, N0.getOperand(0),
3094                        N0.getOperand(1).getOperand(1));
3095 
3096   // fold ((A+(C+B))-B) -> A+C
3097   if (N0.getOpcode() == ISD::ADD && N0.getOperand(1).getOpcode() == ISD::ADD &&
3098       N0.getOperand(1).getOperand(1) == N1)
3099     return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0),
3100                        N0.getOperand(1).getOperand(0));
3101 
3102   // fold ((A-(B-C))-C) -> A-B
3103   if (N0.getOpcode() == ISD::SUB && N0.getOperand(1).getOpcode() == ISD::SUB &&
3104       N0.getOperand(1).getOperand(1) == N1)
3105     return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0),
3106                        N0.getOperand(1).getOperand(0));
3107 
3108   // fold (A-(B-C)) -> A+(C-B)
3109   if (N1.getOpcode() == ISD::SUB && N1.hasOneUse())
3110     return DAG.getNode(ISD::ADD, DL, VT, N0,
3111                        DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(1),
3112                                    N1.getOperand(0)));
3113 
3114   // A - (A & B)  ->  A & (~B)
3115   if (N1.getOpcode() == ISD::AND) {
3116     SDValue A = N1.getOperand(0);
3117     SDValue B = N1.getOperand(1);
3118     if (A != N0)
3119       std::swap(A, B);
3120     if (A == N0 &&
3121         (N1.hasOneUse() || isConstantOrConstantVector(B, /*NoOpaques=*/true))) {
3122       SDValue InvB =
3123           DAG.getNode(ISD::XOR, DL, VT, B, DAG.getAllOnesConstant(DL, VT));
3124       return DAG.getNode(ISD::AND, DL, VT, A, InvB);
3125     }
3126   }
3127 
3128   // fold (X - (-Y * Z)) -> (X + (Y * Z))
3129   if (N1.getOpcode() == ISD::MUL && N1.hasOneUse()) {
3130     if (N1.getOperand(0).getOpcode() == ISD::SUB &&
3131         isNullOrNullSplat(N1.getOperand(0).getOperand(0))) {
3132       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT,
3133                                 N1.getOperand(0).getOperand(1),
3134                                 N1.getOperand(1));
3135       return DAG.getNode(ISD::ADD, DL, VT, N0, Mul);
3136     }
3137     if (N1.getOperand(1).getOpcode() == ISD::SUB &&
3138         isNullOrNullSplat(N1.getOperand(1).getOperand(0))) {
3139       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT,
3140                                 N1.getOperand(0),
3141                                 N1.getOperand(1).getOperand(1));
3142       return DAG.getNode(ISD::ADD, DL, VT, N0, Mul);
3143     }
3144   }
3145 
3146   // If either operand of a sub is undef, the result is undef
3147   if (N0.isUndef())
3148     return N0;
3149   if (N1.isUndef())
3150     return N1;
3151 
3152   if (SDValue V = foldAddSubBoolOfMaskedVal(N, DAG))
3153     return V;
3154 
3155   if (SDValue V = foldAddSubOfSignBit(N, DAG))
3156     return V;
3157 
3158   if (SDValue V = foldAddSubMasked1(false, N0, N1, DAG, SDLoc(N)))
3159     return V;
3160 
3161   // (x - y) - 1  ->  add (xor y, -1), x
3162   if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB && isOneOrOneSplat(N1)) {
3163     SDValue Xor = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(1),
3164                               DAG.getAllOnesConstant(DL, VT));
3165     return DAG.getNode(ISD::ADD, DL, VT, Xor, N0.getOperand(0));
3166   }
3167 
3168   // Look for:
3169   //   sub y, (xor x, -1)
3170   // And if the target does not like this form then turn into:
3171   //   add (add x, y), 1
3172   if (TLI.preferIncOfAddToSubOfNot(VT) && N1.hasOneUse() && isBitwiseNot(N1)) {
3173     SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, N1.getOperand(0));
3174     return DAG.getNode(ISD::ADD, DL, VT, Add, DAG.getConstant(1, DL, VT));
3175   }
3176 
3177   // Hoist one-use addition by non-opaque constant:
3178   //   (x + C) - y  ->  (x - y) + C
3179   if (N0.hasOneUse() && N0.getOpcode() == ISD::ADD &&
3180       isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
3181     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), N1);
3182     return DAG.getNode(ISD::ADD, DL, VT, Sub, N0.getOperand(1));
3183   }
3184   // y - (x + C)  ->  (y - x) - C
3185   if (N1.hasOneUse() && N1.getOpcode() == ISD::ADD &&
3186       isConstantOrConstantVector(N1.getOperand(1), /*NoOpaques=*/true)) {
3187     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, N1.getOperand(0));
3188     return DAG.getNode(ISD::SUB, DL, VT, Sub, N1.getOperand(1));
3189   }
3190   // (x - C) - y  ->  (x - y) - C
3191   // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
3192   if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB &&
3193       isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
3194     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), N1);
3195     return DAG.getNode(ISD::SUB, DL, VT, Sub, N0.getOperand(1));
3196   }
3197   // (C - x) - y  ->  C - (x + y)
3198   if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB &&
3199       isConstantOrConstantVector(N0.getOperand(0), /*NoOpaques=*/true)) {
3200     SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1), N1);
3201     return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), Add);
3202   }
3203 
3204   // If the target's bool is represented as 0/-1, prefer to make this 'add 0/-1'
3205   // rather than 'sub 0/1' (the sext should get folded).
3206   // sub X, (zext i1 Y) --> add X, (sext i1 Y)
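  // The two forms agree for both values of Y: if Y == 0 both addends are 0,
  // and if Y == 1 then (zext Y) == 1 while (sext Y) == -1, and subtracting 1
  // is the same as adding -1.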
3207   if (N1.getOpcode() == ISD::ZERO_EXTEND &&
3208       N1.getOperand(0).getScalarValueSizeInBits() == 1 &&
3209       TLI.getBooleanContents(VT) ==
3210           TargetLowering::ZeroOrNegativeOneBooleanContent) {
3211     SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N1.getOperand(0));
3212     return DAG.getNode(ISD::ADD, DL, VT, N0, SExt);
3213   }
3214 
3215   // fold Y = sra (X, size(X)-1); sub (xor (X, Y), Y) -> (abs X)
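  // Y is 0 when X is non-negative and -1 when X is negative, so
  // (xor X, Y) - Y evaluates to X in the first case and to ~X + 1 == -X in
  // the second, i.e. to abs(X).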
3216   if (TLI.isOperationLegalOrCustom(ISD::ABS, VT)) {
3217     if (N0.getOpcode() == ISD::XOR && N1.getOpcode() == ISD::SRA) {
3218       SDValue X0 = N0.getOperand(0), X1 = N0.getOperand(1);
3219       SDValue S0 = N1.getOperand(0);
3220       if ((X0 == S0 && X1 == N1) || (X0 == N1 && X1 == S0)) {
3221         unsigned OpSizeInBits = VT.getScalarSizeInBits();
3222         if (ConstantSDNode *C = isConstOrConstSplat(N1.getOperand(1)))
3223           if (C->getAPIntValue() == (OpSizeInBits - 1))
3224             return DAG.getNode(ISD::ABS, SDLoc(N), VT, S0);
3225       }
3226     }
3227   }
3228 
3229   // If the relocation model supports it, consider symbol offsets.
3230   if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(N0))
3231     if (!LegalOperations && TLI.isOffsetFoldingLegal(GA)) {
3232       // fold (sub Sym, c) -> Sym-c
3233       if (N1C && GA->getOpcode() == ISD::GlobalAddress)
3234         return DAG.getGlobalAddress(GA->getGlobal(), SDLoc(N1C), VT,
3235                                     GA->getOffset() -
3236                                         (uint64_t)N1C->getSExtValue());
3237       // fold (sub Sym+c1, Sym+c2) -> c1-c2
3238       if (GlobalAddressSDNode *GB = dyn_cast<GlobalAddressSDNode>(N1))
3239         if (GA->getGlobal() == GB->getGlobal())
3240           return DAG.getConstant((uint64_t)GA->getOffset() - GB->getOffset(),
3241                                  DL, VT);
3242     }
3243 
3244   // sub X, (sextinreg Y i1) -> add X, (and Y 1)
3245   if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
3246     VTSDNode *TN = cast<VTSDNode>(N1.getOperand(1));
3247     if (TN->getVT() == MVT::i1) {
3248       SDValue ZExt = DAG.getNode(ISD::AND, DL, VT, N1.getOperand(0),
3249                                  DAG.getConstant(1, DL, VT));
3250       return DAG.getNode(ISD::ADD, DL, VT, N0, ZExt);
3251     }
3252   }
3253 
3254   // Prefer an add for more folding potential and possibly better codegen:
3255   // sub N0, (lshr N10, width-1) --> add N0, (ashr N10, width-1)
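  // The forms are equivalent because (lshr N10, width-1) is 0 or 1 while
  // (ashr N10, width-1) is 0 or -1 for the same input, and subtracting 0 or 1
  // matches adding 0 or -1.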
3256   if (!LegalOperations && N1.getOpcode() == ISD::SRL && N1.hasOneUse()) {
3257     SDValue ShAmt = N1.getOperand(1);
3258     ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt);
3259     if (ShAmtC &&
3260         ShAmtC->getAPIntValue() == (N1.getScalarValueSizeInBits() - 1)) {
3261       SDValue SRA = DAG.getNode(ISD::SRA, DL, VT, N1.getOperand(0), ShAmt);
3262       return DAG.getNode(ISD::ADD, DL, VT, N0, SRA);
3263     }
3264   }
3265 
3266   if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT)) {
3267     // (sub Carry, X)  ->  (addcarry (sub 0, X), 0, Carry)
3268     if (SDValue Carry = getAsCarry(TLI, N0)) {
3269       SDValue X = N1;
3270       SDValue Zero = DAG.getConstant(0, DL, VT);
3271       SDValue NegX = DAG.getNode(ISD::SUB, DL, VT, Zero, X);
3272       return DAG.getNode(ISD::ADDCARRY, DL,
3273                          DAG.getVTList(VT, Carry.getValueType()), NegX, Zero,
3274                          Carry);
3275     }
3276   }
3277 
3278   return SDValue();
3279 }
3280 
3281 SDValue DAGCombiner::visitSUBSAT(SDNode *N) {
3282   SDValue N0 = N->getOperand(0);
3283   SDValue N1 = N->getOperand(1);
3284   EVT VT = N0.getValueType();
3285   SDLoc DL(N);
3286 
3287   // fold vector ops
3288   if (VT.isVector()) {
3289     // TODO SimplifyVBinOp
3290 
3291     // fold (sub_sat x, 0) -> x, vector edition
3292     if (ISD::isBuildVectorAllZeros(N1.getNode()))
3293       return N0;
3294   }
3295 
3296   // fold (sub_sat x, undef) -> 0
3297   if (N0.isUndef() || N1.isUndef())
3298     return DAG.getConstant(0, DL, VT);
3299 
3300   // fold (sub_sat x, x) -> 0
3301   if (N0 == N1)
3302     return DAG.getConstant(0, DL, VT);
3303 
3304   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
3305       DAG.isConstantIntBuildVectorOrConstantInt(N1)) {
3306     // fold (sub_sat c1, c2) -> c3
3307     return DAG.FoldConstantArithmetic(N->getOpcode(), DL, VT, N0.getNode(),
3308                                       N1.getNode());
3309   }
3310 
3311   // fold (sub_sat x, 0) -> x
3312   if (isNullConstant(N1))
3313     return N0;
3314 
3315   return SDValue();
3316 }
3317 
3318 SDValue DAGCombiner::visitSUBC(SDNode *N) {
3319   SDValue N0 = N->getOperand(0);
3320   SDValue N1 = N->getOperand(1);
3321   EVT VT = N0.getValueType();
3322   SDLoc DL(N);
3323 
3324   // If the flag result is dead, turn this into a SUB.
3325   if (!N->hasAnyUseOfValue(1))
3326     return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
3327                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3328 
3329   // fold (subc x, x) -> 0 + no borrow
3330   if (N0 == N1)
3331     return CombineTo(N, DAG.getConstant(0, DL, VT),
3332                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3333 
3334   // fold (subc x, 0) -> x + no borrow
3335   if (isNullConstant(N1))
3336     return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3337 
3338   // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1) + no borrow
3339   if (isAllOnesConstant(N0))
3340     return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0),
3341                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3342 
3343   return SDValue();
3344 }
3345 
3346 SDValue DAGCombiner::visitSUBO(SDNode *N) {
3347   SDValue N0 = N->getOperand(0);
3348   SDValue N1 = N->getOperand(1);
3349   EVT VT = N0.getValueType();
3350   bool IsSigned = (ISD::SSUBO == N->getOpcode());
3351 
3352   EVT CarryVT = N->getValueType(1);
3353   SDLoc DL(N);
3354 
3355   // If the flag result is dead, turn this into a SUB.
3356   if (!N->hasAnyUseOfValue(1))
3357     return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
3358                      DAG.getUNDEF(CarryVT));
3359 
3360   // fold (subo x, x) -> 0 + no borrow
3361   if (N0 == N1)
3362     return CombineTo(N, DAG.getConstant(0, DL, VT),
3363                      DAG.getConstant(0, DL, CarryVT));
3364 
3365   ConstantSDNode *N1C = getAsNonOpaqueConstant(N1);
3366 
3367   // fold (subo x, c) -> (addo x, -c)
3368   if (IsSigned && N1C && !N1C->getAPIntValue().isMinSignedValue()) {
3369     return DAG.getNode(ISD::SADDO, DL, N->getVTList(), N0,
3370                        DAG.getConstant(-N1C->getAPIntValue(), DL, VT));
3371   }
3372 
3373   // fold (subo x, 0) -> x + no borrow
3374   if (isNullOrNullSplat(N1))
3375     return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT));
3376 
3377   // Canonicalize (usubo -1, x) -> ~x, i.e. (xor x, -1) + no borrow
3378   if (!IsSigned && isAllOnesOrAllOnesSplat(N0))
3379     return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0),
3380                      DAG.getConstant(0, DL, CarryVT));
3381 
3382   return SDValue();
3383 }
3384 
3385 SDValue DAGCombiner::visitSUBE(SDNode *N) {
3386   SDValue N0 = N->getOperand(0);
3387   SDValue N1 = N->getOperand(1);
3388   SDValue CarryIn = N->getOperand(2);
3389 
3390   // fold (sube x, y, false) -> (subc x, y)
3391   if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
3392     return DAG.getNode(ISD::SUBC, SDLoc(N), N->getVTList(), N0, N1);
3393 
3394   return SDValue();
3395 }
3396 
3397 SDValue DAGCombiner::visitSUBCARRY(SDNode *N) {
3398   SDValue N0 = N->getOperand(0);
3399   SDValue N1 = N->getOperand(1);
3400   SDValue CarryIn = N->getOperand(2);
3401 
3402   // fold (subcarry x, y, false) -> (usubo x, y)
3403   if (isNullConstant(CarryIn)) {
3404     if (!LegalOperations ||
3405         TLI.isOperationLegalOrCustom(ISD::USUBO, N->getValueType(0)))
3406       return DAG.getNode(ISD::USUBO, SDLoc(N), N->getVTList(), N0, N1);
3407   }
3408 
3409   return SDValue();
3410 }
3411 
3412 // Notice that "mulfix" can be any of SMULFIX, SMULFIXSAT, UMULFIX and
3413 // UMULFIXSAT here.
3414 SDValue DAGCombiner::visitMULFIX(SDNode *N) {
3415   SDValue N0 = N->getOperand(0);
3416   SDValue N1 = N->getOperand(1);
3417   SDValue Scale = N->getOperand(2);
3418   EVT VT = N0.getValueType();
3419 
3420   // fold (mulfix x, undef, scale) -> 0
3421   if (N0.isUndef() || N1.isUndef())
3422     return DAG.getConstant(0, SDLoc(N), VT);
3423 
3424   // Canonicalize constant to RHS (vector doesn't have to splat)
3425   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
3426      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
3427     return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0, Scale);
3428 
3429   // fold (mulfix x, 0, scale) -> 0
3430   if (isNullConstant(N1))
3431     return DAG.getConstant(0, SDLoc(N), VT);
3432 
3433   return SDValue();
3434 }
3435 
3436 SDValue DAGCombiner::visitMUL(SDNode *N) {
3437   SDValue N0 = N->getOperand(0);
3438   SDValue N1 = N->getOperand(1);
3439   EVT VT = N0.getValueType();
3440 
3441   // fold (mul x, undef) -> 0
3442   if (N0.isUndef() || N1.isUndef())
3443     return DAG.getConstant(0, SDLoc(N), VT);
3444 
3445   bool N0IsConst = false;
3446   bool N1IsConst = false;
3447   bool N1IsOpaqueConst = false;
3448   bool N0IsOpaqueConst = false;
3449   APInt ConstValue0, ConstValue1;
3450   // fold vector ops
3451   if (VT.isVector()) {
3452     if (SDValue FoldedVOp = SimplifyVBinOp(N))
3453       return FoldedVOp;
3454 
3455     N0IsConst = ISD::isConstantSplatVector(N0.getNode(), ConstValue0);
3456     N1IsConst = ISD::isConstantSplatVector(N1.getNode(), ConstValue1);
3457     assert((!N0IsConst ||
3458             ConstValue0.getBitWidth() == VT.getScalarSizeInBits()) &&
3459            "Splat APInt should be element width");
3460     assert((!N1IsConst ||
3461             ConstValue1.getBitWidth() == VT.getScalarSizeInBits()) &&
3462            "Splat APInt should be element width");
3463   } else {
3464     N0IsConst = isa<ConstantSDNode>(N0);
3465     if (N0IsConst) {
3466       ConstValue0 = cast<ConstantSDNode>(N0)->getAPIntValue();
3467       N0IsOpaqueConst = cast<ConstantSDNode>(N0)->isOpaque();
3468     }
3469     N1IsConst = isa<ConstantSDNode>(N1);
3470     if (N1IsConst) {
3471       ConstValue1 = cast<ConstantSDNode>(N1)->getAPIntValue();
3472       N1IsOpaqueConst = cast<ConstantSDNode>(N1)->isOpaque();
3473     }
3474   }
3475 
3476   // fold (mul c1, c2) -> c1*c2
3477   if (N0IsConst && N1IsConst && !N0IsOpaqueConst && !N1IsOpaqueConst)
3478     return DAG.FoldConstantArithmetic(ISD::MUL, SDLoc(N), VT,
3479                                       N0.getNode(), N1.getNode());
3480 
3481   // canonicalize constant to RHS (vector doesn't have to splat)
3482   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
3483      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
3484     return DAG.getNode(ISD::MUL, SDLoc(N), VT, N1, N0);
3485   // fold (mul x, 0) -> 0
3486   if (N1IsConst && ConstValue1.isNullValue())
3487     return N1;
3488   // fold (mul x, 1) -> x
3489   if (N1IsConst && ConstValue1.isOneValue())
3490     return N0;
3491 
3492   if (SDValue NewSel = foldBinOpIntoSelect(N))
3493     return NewSel;
3494 
3495   // fold (mul x, -1) -> 0-x
3496   if (N1IsConst && ConstValue1.isAllOnesValue()) {
3497     SDLoc DL(N);
3498     return DAG.getNode(ISD::SUB, DL, VT,
3499                        DAG.getConstant(0, DL, VT), N0);
3500   }
3501   // fold (mul x, (1 << c)) -> x << c
3502   if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
3503       DAG.isKnownToBeAPowerOfTwo(N1) &&
3504       (!VT.isVector() || Level <= AfterLegalizeVectorOps)) {
3505     SDLoc DL(N);
3506     SDValue LogBase2 = BuildLogBase2(N1, DL);
3507     EVT ShiftVT = getShiftAmountTy(N0.getValueType());
3508     SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
3509     return DAG.getNode(ISD::SHL, DL, VT, N0, Trunc);
3510   }
3511   // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
3512   if (N1IsConst && !N1IsOpaqueConst && (-ConstValue1).isPowerOf2()) {
3513     unsigned Log2Val = (-ConstValue1).logBase2();
3514     SDLoc DL(N);
3515     // FIXME: If the input is something that is easily negated (e.g. a
3516     // single-use add), we should put the negate there.
3517     return DAG.getNode(ISD::SUB, DL, VT,
3518                        DAG.getConstant(0, DL, VT),
3519                        DAG.getNode(ISD::SHL, DL, VT, N0,
3520                             DAG.getConstant(Log2Val, DL,
3521                                       getShiftAmountTy(N0.getValueType()))));
3522   }
3523 
3524   // Try to transform multiply-by-(power-of-2 +/- 1) into shift and add/sub.
3525   // mul x, (2^N + 1) --> add (shl x, N), x
3526   // mul x, (2^N - 1) --> sub (shl x, N), x
3527   // Examples: x * 33 --> (x << 5) + x
3528   //           x * 15 --> (x << 4) - x
3529   //           x * -33 --> -((x << 5) + x)
3530   //           x * -15 --> -((x << 4) - x) ; this reduces --> x - (x << 4)
3531   if (N1IsConst && TLI.decomposeMulByConstant(*DAG.getContext(), VT, N1)) {
3532     // TODO: We could handle more general decomposition of any constant by
3533     //       having the target set a limit on number of ops and making a
3534     //       callback to determine that sequence (similar to sqrt expansion).
3535     unsigned MathOp = ISD::DELETED_NODE;
3536     APInt MulC = ConstValue1.abs();
3537     if ((MulC - 1).isPowerOf2())
3538       MathOp = ISD::ADD;
3539     else if ((MulC + 1).isPowerOf2())
3540       MathOp = ISD::SUB;
3541 
3542     if (MathOp != ISD::DELETED_NODE) {
3543       unsigned ShAmt =
3544           MathOp == ISD::ADD ? (MulC - 1).logBase2() : (MulC + 1).logBase2();
3545       assert(ShAmt < VT.getScalarSizeInBits() &&
3546              "multiply-by-constant generated out of bounds shift");
3547       SDLoc DL(N);
3548       SDValue Shl =
3549           DAG.getNode(ISD::SHL, DL, VT, N0, DAG.getConstant(ShAmt, DL, VT));
3550       SDValue R = DAG.getNode(MathOp, DL, VT, Shl, N0);
3551       if (ConstValue1.isNegative())
3552         R = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), R);
3553       return R;
3554     }
3555   }
3556 
3557   // (mul (shl X, c1), c2) -> (mul X, c2 << c1)
3558   if (N0.getOpcode() == ISD::SHL &&
3559       isConstantOrConstantVector(N1, /* NoOpaques */ true) &&
3560       isConstantOrConstantVector(N0.getOperand(1), /* NoOpaques */ true)) {
3561     SDValue C3 = DAG.getNode(ISD::SHL, SDLoc(N), VT, N1, N0.getOperand(1));
3562     if (isConstantOrConstantVector(C3))
3563       return DAG.getNode(ISD::MUL, SDLoc(N), VT, N0.getOperand(0), C3);
3564   }
3565 
3566   // Change (mul (shl X, C), Y) -> (shl (mul X, Y), C) when the shift has one
3567   // use.
3568   {
3569     SDValue Sh(nullptr, 0), Y(nullptr, 0);
3570 
3571     // Check for both (mul (shl X, C), Y)  and  (mul Y, (shl X, C)).
3572     if (N0.getOpcode() == ISD::SHL &&
3573         isConstantOrConstantVector(N0.getOperand(1)) &&
3574         N0.getNode()->hasOneUse()) {
3575       Sh = N0; Y = N1;
3576     } else if (N1.getOpcode() == ISD::SHL &&
3577                isConstantOrConstantVector(N1.getOperand(1)) &&
3578                N1.getNode()->hasOneUse()) {
3579       Sh = N1; Y = N0;
3580     }
3581 
3582     if (Sh.getNode()) {
3583       SDValue Mul = DAG.getNode(ISD::MUL, SDLoc(N), VT, Sh.getOperand(0), Y);
3584       return DAG.getNode(ISD::SHL, SDLoc(N), VT, Mul, Sh.getOperand(1));
3585     }
3586   }
3587 
3588   // fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2)
3589   if (DAG.isConstantIntBuildVectorOrConstantInt(N1) &&
3590       N0.getOpcode() == ISD::ADD &&
3591       DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1)) &&
3592       isMulAddWithConstProfitable(N, N0, N1))
3593       return DAG.getNode(ISD::ADD, SDLoc(N), VT,
3594                          DAG.getNode(ISD::MUL, SDLoc(N0), VT,
3595                                      N0.getOperand(0), N1),
3596                          DAG.getNode(ISD::MUL, SDLoc(N1), VT,
3597                                      N0.getOperand(1), N1));
3598 
3599   // reassociate mul
3600   if (SDValue RMUL = reassociateOps(ISD::MUL, SDLoc(N), N0, N1, N->getFlags()))
3601     return RMUL;
3602 
3603   return SDValue();
3604 }
3605 
3606 /// Return true if divmod libcall is available.
3607 static bool isDivRemLibcallAvailable(SDNode *Node, bool isSigned,
3608                                      const TargetLowering &TLI) {
3609   RTLIB::Libcall LC;
3610   EVT NodeType = Node->getValueType(0);
3611   if (!NodeType.isSimple())
3612     return false;
3613   switch (NodeType.getSimpleVT().SimpleTy) {
3614   default: return false; // No libcall for vector types.
3615   case MVT::i8:   LC= isSigned ? RTLIB::SDIVREM_I8  : RTLIB::UDIVREM_I8;  break;
3616   case MVT::i16:  LC= isSigned ? RTLIB::SDIVREM_I16 : RTLIB::UDIVREM_I16; break;
3617   case MVT::i32:  LC= isSigned ? RTLIB::SDIVREM_I32 : RTLIB::UDIVREM_I32; break;
3618   case MVT::i64:  LC= isSigned ? RTLIB::SDIVREM_I64 : RTLIB::UDIVREM_I64; break;
3619   case MVT::i128: LC= isSigned ? RTLIB::SDIVREM_I128:RTLIB::UDIVREM_I128; break;
3620   }
3621 
3622   return TLI.getLibcallName(LC) != nullptr;
3623 }
3624 
3625 /// Issue divrem if both quotient and remainder are needed.
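/// For example, if both (sdiv X, Y) and (srem X, Y) are live in the function,
/// a single (sdivrem X, Y) node can feed the users of both, so only one
/// division is actually performed.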
3626 SDValue DAGCombiner::useDivRem(SDNode *Node) {
3627   if (Node->use_empty())
3628     return SDValue(); // This is a dead node, leave it alone.
3629 
3630   unsigned Opcode = Node->getOpcode();
3631   bool isSigned = (Opcode == ISD::SDIV) || (Opcode == ISD::SREM);
3632   unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM;
3633 
3634   // DivMod lib calls can still work on non-legal types when expanded as lib-calls.
3635   EVT VT = Node->getValueType(0);
3636   if (VT.isVector() || !VT.isInteger())
3637     return SDValue();
3638 
3639   if (!TLI.isTypeLegal(VT) && !TLI.isOperationCustom(DivRemOpc, VT))
3640     return SDValue();
3641 
3642   // If DIVREM is going to get expanded into a libcall,
3643   // but there is no libcall available, then don't combine.
3644   if (!TLI.isOperationLegalOrCustom(DivRemOpc, VT) &&
3645       !isDivRemLibcallAvailable(Node, isSigned, TLI))
3646     return SDValue();
3647 
3648   // If div is legal, it's better to do the normal expansion
3649   unsigned OtherOpcode = 0;
3650   if ((Opcode == ISD::SDIV) || (Opcode == ISD::UDIV)) {
3651     OtherOpcode = isSigned ? ISD::SREM : ISD::UREM;
3652     if (TLI.isOperationLegalOrCustom(Opcode, VT))
3653       return SDValue();
3654   } else {
3655     OtherOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
3656     if (TLI.isOperationLegalOrCustom(OtherOpcode, VT))
3657       return SDValue();
3658   }
3659 
3660   SDValue Op0 = Node->getOperand(0);
3661   SDValue Op1 = Node->getOperand(1);
3662   SDValue combined;
3663   for (SDNode::use_iterator UI = Op0.getNode()->use_begin(),
3664          UE = Op0.getNode()->use_end(); UI != UE; ++UI) {
3665     SDNode *User = *UI;
3666     if (User == Node || User->getOpcode() == ISD::DELETED_NODE ||
3667         User->use_empty())
3668       continue;
3669     // Convert the other matching node(s), too;
3670     // otherwise, the DIVREM may get target-legalized into something
3671     // target-specific that we won't be able to recognize.
3672     unsigned UserOpc = User->getOpcode();
3673     if ((UserOpc == Opcode || UserOpc == OtherOpcode || UserOpc == DivRemOpc) &&
3674         User->getOperand(0) == Op0 &&
3675         User->getOperand(1) == Op1) {
3676       if (!combined) {
3677         if (UserOpc == OtherOpcode) {
3678           SDVTList VTs = DAG.getVTList(VT, VT);
3679           combined = DAG.getNode(DivRemOpc, SDLoc(Node), VTs, Op0, Op1);
3680         } else if (UserOpc == DivRemOpc) {
3681           combined = SDValue(User, 0);
3682         } else {
3683           assert(UserOpc == Opcode);
3684           continue;
3685         }
3686       }
3687       if (UserOpc == ISD::SDIV || UserOpc == ISD::UDIV)
3688         CombineTo(User, combined);
3689       else if (UserOpc == ISD::SREM || UserOpc == ISD::UREM)
3690         CombineTo(User, combined.getValue(1));
3691     }
3692   }
3693   return combined;
3694 }
3695 
3696 static SDValue simplifyDivRem(SDNode *N, SelectionDAG &DAG) {
3697   SDValue N0 = N->getOperand(0);
3698   SDValue N1 = N->getOperand(1);
3699   EVT VT = N->getValueType(0);
3700   SDLoc DL(N);
3701 
3702   unsigned Opc = N->getOpcode();
3703   bool IsDiv = (ISD::SDIV == Opc) || (ISD::UDIV == Opc);
3704   ConstantSDNode *N1C = isConstOrConstSplat(N1);
3705 
3706   // X / undef -> undef
3707   // X % undef -> undef
3708   // X / 0 -> undef
3709   // X % 0 -> undef
3710   // NOTE: This includes vectors where any divisor element is zero/undef.
3711   if (DAG.isUndef(Opc, {N0, N1}))
3712     return DAG.getUNDEF(VT);
3713 
3714   // undef / X -> 0
3715   // undef % X -> 0
3716   if (N0.isUndef())
3717     return DAG.getConstant(0, DL, VT);
3718 
3719   // 0 / X -> 0
3720   // 0 % X -> 0
3721   ConstantSDNode *N0C = isConstOrConstSplat(N0);
3722   if (N0C && N0C->isNullValue())
3723     return N0;
3724 
3725   // X / X -> 1
3726   // X % X -> 0
3727   if (N0 == N1)
3728     return DAG.getConstant(IsDiv ? 1 : 0, DL, VT);
3729 
3730   // X / 1 -> X
3731   // X % 1 -> 0
3732   // If this is a boolean op (single-bit element type), we can't have
3733   // division-by-zero or remainder-by-zero, so assume the divisor is 1.
3734   // TODO: Similarly, if we're zero-extending a boolean divisor, then assume
3735   // it's a 1.
3736   if ((N1C && N1C->isOne()) || (VT.getScalarType() == MVT::i1))
3737     return IsDiv ? N0 : DAG.getConstant(0, DL, VT);
3738 
3739   return SDValue();
3740 }
3741 
3742 SDValue DAGCombiner::visitSDIV(SDNode *N) {
3743   SDValue N0 = N->getOperand(0);
3744   SDValue N1 = N->getOperand(1);
3745   EVT VT = N->getValueType(0);
3746   EVT CCVT = getSetCCResultType(VT);
3747 
3748   // fold vector ops
3749   if (VT.isVector())
3750     if (SDValue FoldedVOp = SimplifyVBinOp(N))
3751       return FoldedVOp;
3752 
3753   SDLoc DL(N);
3754 
3755   // fold (sdiv c1, c2) -> c1/c2
3756   ConstantSDNode *N0C = isConstOrConstSplat(N0);
3757   ConstantSDNode *N1C = isConstOrConstSplat(N1);
3758   if (N0C && N1C && !N0C->isOpaque() && !N1C->isOpaque())
3759     return DAG.FoldConstantArithmetic(ISD::SDIV, DL, VT, N0C, N1C);
3760   // fold (sdiv X, -1) -> 0-X
3761   if (N1C && N1C->isAllOnesValue())
3762     return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), N0);
3763   // fold (sdiv X, MIN_SIGNED) -> select(X == MIN_SIGNED, 1, 0)
3764   if (N1C && N1C->getAPIntValue().isMinSignedValue())
3765     return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
3766                          DAG.getConstant(1, DL, VT),
3767                          DAG.getConstant(0, DL, VT));
3768 
3769   if (SDValue V = simplifyDivRem(N, DAG))
3770     return V;
3771 
3772   if (SDValue NewSel = foldBinOpIntoSelect(N))
3773     return NewSel;
3774 
3775   // If we know the sign bits of both operands are zero, strength reduce to a
3776   // udiv instead.  Handles (X&15) /s 4 -> X&15 >> 2
3777   if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
3778     return DAG.getNode(ISD::UDIV, DL, N1.getValueType(), N0, N1);
3779 
3780   if (SDValue V = visitSDIVLike(N0, N1, N)) {
3781     // If the corresponding remainder node exists, update its users with
3782     // (Dividend - (Quotient * Divisor)).
3783     if (SDNode *RemNode = DAG.getNodeIfExists(ISD::SREM, N->getVTList(),
3784                                               { N0, N1 })) {
3785       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1);
3786       SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
3787       AddToWorklist(Mul.getNode());
3788       AddToWorklist(Sub.getNode());
3789       CombineTo(RemNode, Sub);
3790     }
3791     return V;
3792   }
3793 
3794   // sdiv, srem -> sdivrem
3795   // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
3796   // true.  Otherwise, we break the simplification logic in visitREM().
3797   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
3798   if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
3799     if (SDValue DivRem = useDivRem(N))
3800         return DivRem;
3801 
3802   return SDValue();
3803 }
3804 
3805 SDValue DAGCombiner::visitSDIVLike(SDValue N0, SDValue N1, SDNode *N) {
3806   SDLoc DL(N);
3807   EVT VT = N->getValueType(0);
3808   EVT CCVT = getSetCCResultType(VT);
3809   unsigned BitWidth = VT.getScalarSizeInBits();
3810 
3811   // Helper for determining whether a value is a power-of-2 constant scalar or a
3812   // vector of such elements.
3813   auto IsPowerOfTwo = [](ConstantSDNode *C) {
3814     if (C->isNullValue() || C->isOpaque())
3815       return false;
3816     if (C->getAPIntValue().isPowerOf2())
3817       return true;
3818     if ((-C->getAPIntValue()).isPowerOf2())
3819       return true;
3820     return false;
3821   };
3822 
3823   // fold (sdiv X, pow2) -> simple ops after legalize
3824   // FIXME: We check for the exact bit here because the generic lowering gives
3825   // better results in that case. The target-specific lowering should learn how
3826   // to handle exact sdivs efficiently.
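  // The sequence built below rounds the arithmetic shift toward zero by
  // adding (pow2 - 1) to negative dividends first. For i32 and N1 == 4, for
  // instance, it computes roughly:
  //   Sign = N0 >>s 31            (0 or -1)
  //   Add  = N0 + (Sign >>u 30)   (adds 3 only when N0 < 0)
  //   Res  = Add >>s 2
  // followed by selects that handle divisors of 1 and -1 and negate the
  // result when the divisor is negative.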
3827   if (!N->getFlags().hasExact() && ISD::matchUnaryPredicate(N1, IsPowerOfTwo)) {
3828     // Target-specific implementation of sdiv x, pow2.
3829     if (SDValue Res = BuildSDIVPow2(N))
3830       return Res;
3831 
3832     // Create constants that are functions of the shift amount value.
3833     EVT ShiftAmtTy = getShiftAmountTy(N0.getValueType());
3834     SDValue Bits = DAG.getConstant(BitWidth, DL, ShiftAmtTy);
3835     SDValue C1 = DAG.getNode(ISD::CTTZ, DL, VT, N1);
3836     C1 = DAG.getZExtOrTrunc(C1, DL, ShiftAmtTy);
3837     SDValue Inexact = DAG.getNode(ISD::SUB, DL, ShiftAmtTy, Bits, C1);
3838     if (!isConstantOrConstantVector(Inexact))
3839       return SDValue();
3840 
3841     // Splat the sign bit into the register
3842     SDValue Sign = DAG.getNode(ISD::SRA, DL, VT, N0,
3843                                DAG.getConstant(BitWidth - 1, DL, ShiftAmtTy));
3844     AddToWorklist(Sign.getNode());
3845 
3846     // Add (N0 < 0) ? abs2 - 1 : 0;
3847     SDValue Srl = DAG.getNode(ISD::SRL, DL, VT, Sign, Inexact);
3848     AddToWorklist(Srl.getNode());
3849     SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, Srl);
3850     AddToWorklist(Add.getNode());
3851     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Add, C1);
3852     AddToWorklist(Sra.getNode());
3853 
3854     // Special case: (sdiv X, 1) -> X
3855     // Special case: (sdiv X, -1) -> 0-X
3856     SDValue One = DAG.getConstant(1, DL, VT);
3857     SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
3858     SDValue IsOne = DAG.getSetCC(DL, CCVT, N1, One, ISD::SETEQ);
3859     SDValue IsAllOnes = DAG.getSetCC(DL, CCVT, N1, AllOnes, ISD::SETEQ);
3860     SDValue IsOneOrAllOnes = DAG.getNode(ISD::OR, DL, CCVT, IsOne, IsAllOnes);
3861     Sra = DAG.getSelect(DL, VT, IsOneOrAllOnes, N0, Sra);
3862 
3863     // If dividing by a positive value, we're done. Otherwise, the result must
3864     // be negated.
3865     SDValue Zero = DAG.getConstant(0, DL, VT);
3866     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, Zero, Sra);
3867 
3868     // FIXME: Use SELECT_CC once we improve SELECT_CC constant-folding.
3869     SDValue IsNeg = DAG.getSetCC(DL, CCVT, N1, Zero, ISD::SETLT);
3870     SDValue Res = DAG.getSelect(DL, VT, IsNeg, Sub, Sra);
3871     return Res;
3872   }
3873 
3874   // If integer divide is expensive and we satisfy the requirements, emit an
3875   // alternate sequence.  Targets may check function attributes for size/speed
3876   // trade-offs.
3877   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
3878   if (isConstantOrConstantVector(N1) &&
3879       !TLI.isIntDivCheap(N->getValueType(0), Attr))
3880     if (SDValue Op = BuildSDIV(N))
3881       return Op;
3882 
3883   return SDValue();
3884 }
3885 
3886 SDValue DAGCombiner::visitUDIV(SDNode *N) {
3887   SDValue N0 = N->getOperand(0);
3888   SDValue N1 = N->getOperand(1);
3889   EVT VT = N->getValueType(0);
3890   EVT CCVT = getSetCCResultType(VT);
3891 
3892   // fold vector ops
3893   if (VT.isVector())
3894     if (SDValue FoldedVOp = SimplifyVBinOp(N))
3895       return FoldedVOp;
3896 
3897   SDLoc DL(N);
3898 
3899   // fold (udiv c1, c2) -> c1/c2
3900   ConstantSDNode *N0C = isConstOrConstSplat(N0);
3901   ConstantSDNode *N1C = isConstOrConstSplat(N1);
3902   if (N0C && N1C)
3903     if (SDValue Folded = DAG.FoldConstantArithmetic(ISD::UDIV, DL, VT,
3904                                                     N0C, N1C))
3905       return Folded;
3906   // fold (udiv X, -1) -> select(X == -1, 1, 0)
3907   if (N1C && N1C->getAPIntValue().isAllOnesValue())
3908     return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
3909                          DAG.getConstant(1, DL, VT),
3910                          DAG.getConstant(0, DL, VT));
3911 
3912   if (SDValue V = simplifyDivRem(N, DAG))
3913     return V;
3914 
3915   if (SDValue NewSel = foldBinOpIntoSelect(N))
3916     return NewSel;
3917 
3918   if (SDValue V = visitUDIVLike(N0, N1, N)) {
3919     // If the corresponding remainder node exists, update its users with
3920     // (Dividend - (Quotient * Divisor)).
3921     if (SDNode *RemNode = DAG.getNodeIfExists(ISD::UREM, N->getVTList(),
3922                                               { N0, N1 })) {
3923       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1);
3924       SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
3925       AddToWorklist(Mul.getNode());
3926       AddToWorklist(Sub.getNode());
3927       CombineTo(RemNode, Sub);
3928     }
3929     return V;
3930   }
3931 
3932   // udiv, urem -> udivrem
3933   // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
3934   // true.  Otherwise, we break the simplification logic in visitREM().
3935   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
3936   if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
3937     if (SDValue DivRem = useDivRem(N))
3938         return DivRem;
3939 
3940   return SDValue();
3941 }
3942 
3943 SDValue DAGCombiner::visitUDIVLike(SDValue N0, SDValue N1, SDNode *N) {
3944   SDLoc DL(N);
3945   EVT VT = N->getValueType(0);
3946 
3947   // fold (udiv x, (1 << c)) -> x >>u c
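       // (For example, (udiv x, 8) becomes (srl x, 3).)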
3948   if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
3949       DAG.isKnownToBeAPowerOfTwo(N1)) {
3950     SDValue LogBase2 = BuildLogBase2(N1, DL);
3951     AddToWorklist(LogBase2.getNode());
3952 
3953     EVT ShiftVT = getShiftAmountTy(N0.getValueType());
3954     SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
3955     AddToWorklist(Trunc.getNode());
3956     return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
3957   }
3958 
3959   // fold (udiv x, (shl c, y)) -> x >>u (log2(c)+y) iff c is power of 2
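       // (For example, (udiv x, (shl 4, y)) becomes x >>u (2 + y).)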
3960   if (N1.getOpcode() == ISD::SHL) {
3961     SDValue N10 = N1.getOperand(0);
3962     if (isConstantOrConstantVector(N10, /*NoOpaques*/ true) &&
3963         DAG.isKnownToBeAPowerOfTwo(N10)) {
3964       SDValue LogBase2 = BuildLogBase2(N10, DL);
3965       AddToWorklist(LogBase2.getNode());
3966 
3967       EVT ADDVT = N1.getOperand(1).getValueType();
3968       SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT);
3969       AddToWorklist(Trunc.getNode());
3970       SDValue Add = DAG.getNode(ISD::ADD, DL, ADDVT, N1.getOperand(1), Trunc);
3971       AddToWorklist(Add.getNode());
3972       return DAG.getNode(ISD::SRL, DL, VT, N0, Add);
3973     }
3974   }
3975 
3976   // fold (udiv x, c) -> alternate
3977   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
3978   if (isConstantOrConstantVector(N1) &&
3979       !TLI.isIntDivCheap(N->getValueType(0), Attr))
3980     if (SDValue Op = BuildUDIV(N))
3981       return Op;
3982 
3983   return SDValue();
3984 }
3985 
3986 // handles ISD::SREM and ISD::UREM
3987 SDValue DAGCombiner::visitREM(SDNode *N) {
3988   unsigned Opcode = N->getOpcode();
3989   SDValue N0 = N->getOperand(0);
3990   SDValue N1 = N->getOperand(1);
3991   EVT VT = N->getValueType(0);
3992   EVT CCVT = getSetCCResultType(VT);
3993 
3994   bool isSigned = (Opcode == ISD::SREM);
3995   SDLoc DL(N);
3996 
3997   // fold (rem c1, c2) -> c1%c2
3998   ConstantSDNode *N0C = isConstOrConstSplat(N0);
3999   ConstantSDNode *N1C = isConstOrConstSplat(N1);
4000   if (N0C && N1C)
4001     if (SDValue Folded = DAG.FoldConstantArithmetic(Opcode, DL, VT, N0C, N1C))
4002       return Folded;
4003   // fold (urem X, -1) -> select(X == -1, 0, X)
4004   if (!isSigned && N1C && N1C->getAPIntValue().isAllOnesValue())
4005     return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
4006                          DAG.getConstant(0, DL, VT), N0);
4007 
4008   if (SDValue V = simplifyDivRem(N, DAG))
4009     return V;
4010 
4011   if (SDValue NewSel = foldBinOpIntoSelect(N))
4012     return NewSel;
4013 
4014   if (isSigned) {
4015     // If we know the sign bits of both operands are zero, strength reduce to a
4016     // urem instead.  Handles (X & 0x0FFFFFFF) %s 16 -> X&15
4017     if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
4018       return DAG.getNode(ISD::UREM, DL, VT, N0, N1);
4019   } else {
4020     SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
4021     if (DAG.isKnownToBeAPowerOfTwo(N1)) {
4022       // fold (urem x, pow2) -> (and x, pow2-1)
4023       SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, NegOne);
4024       AddToWorklist(Add.getNode());
4025       return DAG.getNode(ISD::AND, DL, VT, N0, Add);
4026     }
4027     if (N1.getOpcode() == ISD::SHL &&
4028         DAG.isKnownToBeAPowerOfTwo(N1.getOperand(0))) {
4029       // fold (urem x, (shl pow2, y)) -> (and x, (add (shl pow2, y), -1))
4030       SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, NegOne);
4031       AddToWorklist(Add.getNode());
4032       return DAG.getNode(ISD::AND, DL, VT, N0, Add);
4033     }
4034   }
4035 
4036   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4037 
4038   // If X/C can be simplified by the division-by-constant logic, lower
4039   // X%C to the equivalent of X-X/C*C.
4040   // Reuse the SDIVLike/UDIVLike combines - to avoid mangling nodes, the
4041   // speculative DIV must not cause a DIVREM conversion.  We guard against this
4042   // by skipping the simplification if isIntDivCheap().  When div is not cheap,
4043   // combine will not return a DIVREM.  Regardless, checking cheapness here
4044   // makes sense since the simplification results in fatter code.
4045   if (DAG.isKnownNeverZero(N1) && !TLI.isIntDivCheap(VT, Attr)) {
4046     SDValue OptimizedDiv =
4047         isSigned ? visitSDIVLike(N0, N1, N) : visitUDIVLike(N0, N1, N);
4048     if (OptimizedDiv.getNode()) {
4049       // If the equivalent Div node also exists, update its users.
4050       unsigned DivOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
4051       if (SDNode *DivNode = DAG.getNodeIfExists(DivOpcode, N->getVTList(),
4052                                                 { N0, N1 }))
4053         CombineTo(DivNode, OptimizedDiv);
4054       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, OptimizedDiv, N1);
4055       SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
4056       AddToWorklist(OptimizedDiv.getNode());
4057       AddToWorklist(Mul.getNode());
4058       return Sub;
4059     }
4060   }
4061 
4062   // sdiv, srem -> sdivrem; udiv, urem -> udivrem
4063   if (SDValue DivRem = useDivRem(N))
4064     return DivRem.getValue(1);
4065 
4066   return SDValue();
4067 }
4068 
4069 SDValue DAGCombiner::visitMULHS(SDNode *N) {
4070   SDValue N0 = N->getOperand(0);
4071   SDValue N1 = N->getOperand(1);
4072   EVT VT = N->getValueType(0);
4073   SDLoc DL(N);
4074 
4075   if (VT.isVector()) {
4076     // fold (mulhs x, 0) -> 0
4077     // do not return N0/N1, because undef node may exist.
4078     if (ISD::isBuildVectorAllZeros(N0.getNode()) ||
4079         ISD::isBuildVectorAllZeros(N1.getNode()))
4080       return DAG.getConstant(0, DL, VT);
4081   }
4082 
4083   // fold (mulhs x, 0) -> 0
4084   if (isNullConstant(N1))
4085     return N1;
4086   // fold (mulhs x, 1) -> (sra x, size(x)-1)
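       // (sext(x) * 1 is just sext(x), whose high half is the sign bit of x
       // replicated, i.e. an arithmetic shift right by size(x)-1.)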
4087   if (isOneConstant(N1))
4088     return DAG.getNode(ISD::SRA, DL, N0.getValueType(), N0,
4089                        DAG.getConstant(N0.getScalarValueSizeInBits() - 1, DL,
4090                                        getShiftAmountTy(N0.getValueType())));
4091 
4092   // fold (mulhs x, undef) -> 0
4093   if (N0.isUndef() || N1.isUndef())
4094     return DAG.getConstant(0, DL, VT);
4095 
4096   // If the type twice as wide is legal, transform the mulhs to a wider multiply
4097   // plus a shift.
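       // (For example, an i16 mulhs where i32 MUL is legal: sign-extend both
       // operands to i32, multiply, shift the product right by 16, and
       // truncate the result back to i16.)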
4098   if (VT.isSimple() && !VT.isVector()) {
4099     MVT Simple = VT.getSimpleVT();
4100     unsigned SimpleSize = Simple.getSizeInBits();
4101     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
4102     if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
4103       N0 = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N0);
4104       N1 = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N1);
4105       N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
4106       N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
4107             DAG.getConstant(SimpleSize, DL,
4108                             getShiftAmountTy(N1.getValueType())));
4109       return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
4110     }
4111   }
4112 
4113   return SDValue();
4114 }
4115 
4116 SDValue DAGCombiner::visitMULHU(SDNode *N) {
4117   SDValue N0 = N->getOperand(0);
4118   SDValue N1 = N->getOperand(1);
4119   EVT VT = N->getValueType(0);
4120   SDLoc DL(N);
4121 
4122   if (VT.isVector()) {
4123     // fold (mulhu x, 0) -> 0
4124     // do not return N0/N1, because undef node may exist.
4125     if (ISD::isBuildVectorAllZeros(N0.getNode()) ||
4126         ISD::isBuildVectorAllZeros(N1.getNode()))
4127       return DAG.getConstant(0, DL, VT);
4128   }
4129 
4130   // fold (mulhu x, 0) -> 0
4131   if (isNullConstant(N1))
4132     return N1;
4133   // fold (mulhu x, 1) -> 0
4134   if (isOneConstant(N1))
4135     return DAG.getConstant(0, DL, N0.getValueType());
4136   // fold (mulhu x, undef) -> 0
4137   if (N0.isUndef() || N1.isUndef())
4138     return DAG.getConstant(0, DL, VT);
4139 
4140   // fold (mulhu x, (1 << c)) -> x >> (bitwidth - c)
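       // (For example, for i32 x, (mulhu x, 16) becomes (srl x, 28): the high
       // 32 bits of the 64-bit product x * 16 are exactly x >>u 28.)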
4141   if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
4142       DAG.isKnownToBeAPowerOfTwo(N1) && hasOperation(ISD::SRL, VT)) {
4143     unsigned NumEltBits = VT.getScalarSizeInBits();
4144     SDValue LogBase2 = BuildLogBase2(N1, DL);
4145     SDValue SRLAmt = DAG.getNode(
4146         ISD::SUB, DL, VT, DAG.getConstant(NumEltBits, DL, VT), LogBase2);
4147     EVT ShiftVT = getShiftAmountTy(N0.getValueType());
4148     SDValue Trunc = DAG.getZExtOrTrunc(SRLAmt, DL, ShiftVT);
4149     return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
4150   }
4151 
4152   // If the type twice as wide is legal, transform the mulhu to a wider multiply
4153   // plus a shift.
4154   if (VT.isSimple() && !VT.isVector()) {
4155     MVT Simple = VT.getSimpleVT();
4156     unsigned SimpleSize = Simple.getSizeInBits();
4157     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
4158     if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
4159       N0 = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N0);
4160       N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N1);
4161       N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
4162       N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
4163             DAG.getConstant(SimpleSize, DL,
4164                             getShiftAmountTy(N1.getValueType())));
4165       return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
4166     }
4167   }
4168 
4169   return SDValue();
4170 }
4171 
4172 /// Perform optimizations common to nodes that compute two values. LoOp and HiOp
4173 /// give the opcodes for the two computations that are being performed. Returns
4174 /// the simplified value, or an empty SDValue if no simplification was made.
4175 SDValue DAGCombiner::SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
4176                                                 unsigned HiOp) {
4177   // If the high half is not needed, just compute the low half.
4178   bool HiExists = N->hasAnyUseOfValue(1);
4179   if (!HiExists && (!LegalOperations ||
4180                     TLI.isOperationLegalOrCustom(LoOp, N->getValueType(0)))) {
4181     SDValue Res = DAG.getNode(LoOp, SDLoc(N), N->getValueType(0), N->ops());
4182     return CombineTo(N, Res, Res);
4183   }
4184 
4185   // If the low half is not needed, just compute the high half.
4186   bool LoExists = N->hasAnyUseOfValue(0);
4187   if (!LoExists && (!LegalOperations ||
4188                     TLI.isOperationLegalOrCustom(HiOp, N->getValueType(1)))) {
4189     SDValue Res = DAG.getNode(HiOp, SDLoc(N), N->getValueType(1), N->ops());
4190     return CombineTo(N, Res, Res);
4191   }
4192 
4193   // If both halves are used, there is nothing to simplify; leave the node as is.
4194   if (LoExists && HiExists)
4195     return SDValue();
4196 
4197   // If the two computed results can be simplified separately, separate them.
4198   if (LoExists) {
4199     SDValue Lo = DAG.getNode(LoOp, SDLoc(N), N->getValueType(0), N->ops());
4200     AddToWorklist(Lo.getNode());
4201     SDValue LoOpt = combine(Lo.getNode());
4202     if (LoOpt.getNode() && LoOpt.getNode() != Lo.getNode() &&
4203         (!LegalOperations ||
4204          TLI.isOperationLegalOrCustom(LoOpt.getOpcode(), LoOpt.getValueType())))
4205       return CombineTo(N, LoOpt, LoOpt);
4206   }
4207 
4208   if (HiExists) {
4209     SDValue Hi = DAG.getNode(HiOp, SDLoc(N), N->getValueType(1), N->ops());
4210     AddToWorklist(Hi.getNode());
4211     SDValue HiOpt = combine(Hi.getNode());
4212     if (HiOpt.getNode() && HiOpt != Hi &&
4213         (!LegalOperations ||
4214          TLI.isOperationLegalOrCustom(HiOpt.getOpcode(), HiOpt.getValueType())))
4215       return CombineTo(N, HiOpt, HiOpt);
4216   }
4217 
4218   return SDValue();
4219 }
4220 
4221 SDValue DAGCombiner::visitSMUL_LOHI(SDNode *N) {
4222   if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHS))
4223     return Res;
4224 
4225   EVT VT = N->getValueType(0);
4226   SDLoc DL(N);
4227 
4228   // If the type twice as wide is legal, transform the smul_lohi into a wider
4229   // multiply plus a shift.
4230   if (VT.isSimple() && !VT.isVector()) {
4231     MVT Simple = VT.getSimpleVT();
4232     unsigned SimpleSize = Simple.getSizeInBits();
4233     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
4234     if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
4235       SDValue Lo = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N->getOperand(0));
4236       SDValue Hi = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N->getOperand(1));
4237       Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
4238       // Compute the high part (result value 1).
4239       Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
4240             DAG.getConstant(SimpleSize, DL,
4241                             getShiftAmountTy(Lo.getValueType())));
4242       Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
4243       // Compute the low part (result value 0).
4244       Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
4245       return CombineTo(N, Lo, Hi);
4246     }
4247   }
4248 
4249   return SDValue();
4250 }
4251 
4252 SDValue DAGCombiner::visitUMUL_LOHI(SDNode *N) {
4253   if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHU))
4254     return Res;
4255 
4256   EVT VT = N->getValueType(0);
4257   SDLoc DL(N);
4258 
4259   // (umul_lohi N0, 0) -> (0, 0)
4260   if (isNullConstant(N->getOperand(1))) {
4261     SDValue Zero = DAG.getConstant(0, DL, VT);
4262     return CombineTo(N, Zero, Zero);
4263   }
4264 
4265   // (umul_lohi N0, 1) -> (N0, 0)
4266   if (isOneConstant(N->getOperand(1))) {
4267     SDValue Zero = DAG.getConstant(0, DL, VT);
4268     return CombineTo(N, N->getOperand(0), Zero);
4269   }
4270 
4271   // If the type twice as wide is legal, transform the umul_lohi into a wider
4272   // multiply plus a shift.
4273   if (VT.isSimple() && !VT.isVector()) {
4274     MVT Simple = VT.getSimpleVT();
4275     unsigned SimpleSize = Simple.getSizeInBits();
4276     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
4277     if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
4278       SDValue Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N->getOperand(0));
4279       SDValue Hi = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N->getOperand(1));
4280       Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
4281       // Compute the high part (result value 1).
4282       Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
4283             DAG.getConstant(SimpleSize, DL,
4284                             getShiftAmountTy(Lo.getValueType())));
4285       Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
4286       // Compute the low part (result value 0).
4287       Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
4288       return CombineTo(N, Lo, Hi);
4289     }
4290   }
4291 
4292   return SDValue();
4293 }
4294 
4295 SDValue DAGCombiner::visitMULO(SDNode *N) {
4296   SDValue N0 = N->getOperand(0);
4297   SDValue N1 = N->getOperand(1);
4298   EVT VT = N0.getValueType();
4299   bool IsSigned = (ISD::SMULO == N->getOpcode());
4300 
4301   EVT CarryVT = N->getValueType(1);
4302   SDLoc DL(N);
4303 
4304   // canonicalize constant to RHS.
4305   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4306       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4307     return DAG.getNode(N->getOpcode(), DL, N->getVTList(), N1, N0);
4308 
4309   // fold (mulo x, 0) -> 0 + no carry out
4310   if (isNullOrNullSplat(N1))
4311     return CombineTo(N, DAG.getConstant(0, DL, VT),
4312                      DAG.getConstant(0, DL, CarryVT));
4313 
4314   // (mulo x, 2) -> (addo x, x)
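       // (x * 2 and x + x overflow under exactly the same conditions, whether
       // signed or unsigned, so the carry result is preserved.)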
4315   if (ConstantSDNode *C2 = isConstOrConstSplat(N1))
4316     if (C2->getAPIntValue() == 2)
4317       return DAG.getNode(IsSigned ? ISD::SADDO : ISD::UADDO, DL,
4318                          N->getVTList(), N0, N0);
4319 
4320   return SDValue();
4321 }
4322 
4323 SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
4324   SDValue N0 = N->getOperand(0);
4325   SDValue N1 = N->getOperand(1);
4326   EVT VT = N0.getValueType();
4327 
4328   // fold vector ops
4329   if (VT.isVector())
4330     if (SDValue FoldedVOp = SimplifyVBinOp(N))
4331       return FoldedVOp;
4332 
4333   // fold operation with constant operands.
4334   ConstantSDNode *N0C = getAsNonOpaqueConstant(N0);
4335   ConstantSDNode *N1C = getAsNonOpaqueConstant(N1);
4336   if (N0C && N1C)
4337     return DAG.FoldConstantArithmetic(N->getOpcode(), SDLoc(N), VT, N0C, N1C);
4338 
4339   // canonicalize constant to RHS
4340   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4341      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4342     return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0);
4343 
4344   // If the sign bits are zero, flip between UMIN/UMAX and SMIN/SMAX.
4345   // Only do this if the current op isn't legal and the flipped is.
4346   unsigned Opcode = N->getOpcode();
4347   if (!TLI.isOperationLegal(Opcode, VT) &&
4348       (N0.isUndef() || DAG.SignBitIsZero(N0)) &&
4349       (N1.isUndef() || DAG.SignBitIsZero(N1))) {
4350     unsigned AltOpcode;
4351     switch (Opcode) {
4352     case ISD::SMIN: AltOpcode = ISD::UMIN; break;
4353     case ISD::SMAX: AltOpcode = ISD::UMAX; break;
4354     case ISD::UMIN: AltOpcode = ISD::SMIN; break;
4355     case ISD::UMAX: AltOpcode = ISD::SMAX; break;
4356     default: llvm_unreachable("Unknown MINMAX opcode");
4357     }
4358     if (TLI.isOperationLegal(AltOpcode, VT))
4359       return DAG.getNode(AltOpcode, SDLoc(N), VT, N0, N1);
4360   }
4361 
4362   return SDValue();
4363 }
4364 
4365 /// If this is a bitwise logic instruction and both operands have the same
4366 /// opcode, try to sink the other opcode after the logic instruction.
4367 SDValue DAGCombiner::hoistLogicOpWithSameOpcodeHands(SDNode *N) {
4368   SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
4369   EVT VT = N0.getValueType();
4370   unsigned LogicOpcode = N->getOpcode();
4371   unsigned HandOpcode = N0.getOpcode();
4372   assert((LogicOpcode == ISD::AND || LogicOpcode == ISD::OR ||
4373           LogicOpcode == ISD::XOR) && "Expected logic opcode");
4374   assert(HandOpcode == N1.getOpcode() && "Bad input!");
4375 
4376   // Bail early if none of these transforms apply.
4377   if (N0.getNumOperands() == 0)
4378     return SDValue();
4379 
4380   // FIXME: We should check number of uses of the operands to not increase
4381   //        the instruction count for all transforms.
4382 
4383   // Handle size-changing casts.
4384   SDValue X = N0.getOperand(0);
4385   SDValue Y = N1.getOperand(0);
4386   EVT XVT = X.getValueType();
4387   SDLoc DL(N);
4388   if (HandOpcode == ISD::ANY_EXTEND || HandOpcode == ISD::ZERO_EXTEND ||
4389       HandOpcode == ISD::SIGN_EXTEND) {
4390     // If both operands have other uses, this transform would create extra
4391     // instructions without eliminating anything.
4392     if (!N0.hasOneUse() && !N1.hasOneUse())
4393       return SDValue();
4394     // We need matching integer source types.
4395     if (XVT != Y.getValueType())
4396       return SDValue();
4397     // Don't create an illegal op during or after legalization. Don't ever
4398     // create an unsupported vector op.
4399     if ((VT.isVector() || LegalOperations) &&
4400         !TLI.isOperationLegalOrCustom(LogicOpcode, XVT))
4401       return SDValue();
4402     // Avoid infinite looping with PromoteIntBinOp.
4403     // TODO: Should we apply desirable/legal constraints to all opcodes?
4404     if (HandOpcode == ISD::ANY_EXTEND && LegalTypes &&
4405         !TLI.isTypeDesirableForOp(LogicOpcode, XVT))
4406       return SDValue();
4407     // logic_op (hand_op X), (hand_op Y) --> hand_op (logic_op X, Y)
4408     SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
4409     return DAG.getNode(HandOpcode, DL, VT, Logic);
4410   }
4411 
4412   // logic_op (truncate x), (truncate y) --> truncate (logic_op x, y)
4413   if (HandOpcode == ISD::TRUNCATE) {
4414     // If both operands have other uses, this transform would create extra
4415     // instructions without eliminating anything.
4416     if (!N0.hasOneUse() && !N1.hasOneUse())
4417       return SDValue();
4418     // We need matching source types.
4419     if (XVT != Y.getValueType())
4420       return SDValue();
4421     // Don't create an illegal op during or after legalization.
4422     if (LegalOperations && !TLI.isOperationLegal(LogicOpcode, XVT))
4423       return SDValue();
4424     // Be extra careful sinking truncate. If it's free, there's no benefit in
4425     // widening a binop. Also, don't create a logic op on an illegal type.
4426     if (TLI.isZExtFree(VT, XVT) && TLI.isTruncateFree(XVT, VT))
4427       return SDValue();
4428     if (!TLI.isTypeLegal(XVT))
4429       return SDValue();
4430     SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
4431     return DAG.getNode(HandOpcode, DL, VT, Logic);
4432   }
4433 
4434   // For binops SHL/SRL/SRA/AND:
4435   //   logic_op (OP x, z), (OP y, z) --> OP (logic_op x, y), z
4436   if ((HandOpcode == ISD::SHL || HandOpcode == ISD::SRL ||
4437        HandOpcode == ISD::SRA || HandOpcode == ISD::AND) &&
4438       N0.getOperand(1) == N1.getOperand(1)) {
4439     // If either operand has other uses, this transform is not an improvement.
4440     if (!N0.hasOneUse() || !N1.hasOneUse())
4441       return SDValue();
4442     SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
4443     return DAG.getNode(HandOpcode, DL, VT, Logic, N0.getOperand(1));
4444   }
4445 
4446   // Unary ops: logic_op (bswap x), (bswap y) --> bswap (logic_op x, y)
4447   if (HandOpcode == ISD::BSWAP) {
4448     // If either operand has other uses, this transform is not an improvement.
4449     if (!N0.hasOneUse() || !N1.hasOneUse())
4450       return SDValue();
4451     SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
4452     return DAG.getNode(HandOpcode, DL, VT, Logic);
4453   }
4454 
4455   // Simplify xor/and/or (bitcast(A), bitcast(B)) -> bitcast(op (A,B))
4456   // Only perform this optimization up until type legalization, before
4457   // LegalizeVectorOps. LegalizeVectorOps promotes vector operations by
4458   // adding bitcasts. For example (xor v4i32) is promoted to (v2i64), and
4459   // we don't want to undo this promotion.
4460   // We also handle SCALAR_TO_VECTOR because xor/or/and operations are cheaper
4461   // on scalars.
4462   if ((HandOpcode == ISD::BITCAST || HandOpcode == ISD::SCALAR_TO_VECTOR) &&
4463        Level <= AfterLegalizeTypes) {
4464     // Input types must be integer and the same.
4465     if (XVT.isInteger() && XVT == Y.getValueType() &&
4466         !(VT.isVector() && TLI.isTypeLegal(VT) &&
4467           !XVT.isVector() && !TLI.isTypeLegal(XVT))) {
4468       SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
4469       return DAG.getNode(HandOpcode, DL, VT, Logic);
4470     }
4471   }
4472 
4473   // Xor/and/or are indifferent to the swizzle operation (shuffle of one value).
4474   // Simplify xor/and/or (shuff(A), shuff(B)) -> shuff(op (A,B))
4475   // If both shuffles use the same mask, and both shuffle within a single
4476   // vector, then it is worthwhile to move the swizzle after the operation.
4477   // The type-legalizer generates this pattern when loading illegal
4478   // vector types from memory. In many cases this allows additional shuffle
4479   // optimizations.
4480   // There are other cases where moving the shuffle after the xor/and/or
4481   // is profitable even if shuffles don't perform a swizzle.
4482   // If both shuffles use the same mask, and both shuffles have the same first
4483   // or second operand, then it might still be profitable to move the shuffle
4484   // after the xor/and/or operation.
4485   if (HandOpcode == ISD::VECTOR_SHUFFLE && Level < AfterLegalizeDAG) {
4486     auto *SVN0 = cast<ShuffleVectorSDNode>(N0);
4487     auto *SVN1 = cast<ShuffleVectorSDNode>(N1);
4488     assert(X.getValueType() == Y.getValueType() &&
4489            "Inputs to shuffles are not the same type");
4490 
4491     // Check that both shuffles use the same mask. The masks are known to be of
4492     // the same length because the result vector type is the same.
4493     // Check also that shuffles have only one use to avoid introducing extra
4494     // instructions.
4495     if (!SVN0->hasOneUse() || !SVN1->hasOneUse() ||
4496         !SVN0->getMask().equals(SVN1->getMask()))
4497       return SDValue();
4498 
4499     // Don't try to fold this node if it requires introducing a
4500     // build vector of all zeros that might be illegal at this stage.
4501     SDValue ShOp = N0.getOperand(1);
4502     if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
4503       ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
4504 
4505     // (logic_op (shuf (A, C), shuf (B, C))) --> shuf (logic_op (A, B), C)
4506     if (N0.getOperand(1) == N1.getOperand(1) && ShOp.getNode()) {
4507       SDValue Logic = DAG.getNode(LogicOpcode, DL, VT,
4508                                   N0.getOperand(0), N1.getOperand(0));
4509       return DAG.getVectorShuffle(VT, DL, Logic, ShOp, SVN0->getMask());
4510     }
4511 
4512     // Don't try to fold this node if it requires introducing a
4513     // build vector of all zeros that might be illegal at this stage.
4514     ShOp = N0.getOperand(0);
4515     if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
4516       ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
4517 
4518     // (logic_op (shuf (C, A), shuf (C, B))) --> shuf (C, logic_op (A, B))
4519     if (N0.getOperand(0) == N1.getOperand(0) && ShOp.getNode()) {
4520       SDValue Logic = DAG.getNode(LogicOpcode, DL, VT, N0.getOperand(1),
4521                                   N1.getOperand(1));
4522       return DAG.getVectorShuffle(VT, DL, ShOp, Logic, SVN0->getMask());
4523     }
4524   }
4525 
4526   return SDValue();
4527 }
4528 
4529 /// Try to make (and/or setcc (LL, LR), setcc (RL, RR)) more efficient.
4530 SDValue DAGCombiner::foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
4531                                        const SDLoc &DL) {
4532   SDValue LL, LR, RL, RR, N0CC, N1CC;
4533   if (!isSetCCEquivalent(N0, LL, LR, N0CC) ||
4534       !isSetCCEquivalent(N1, RL, RR, N1CC))
4535     return SDValue();
4536 
4537   assert(N0.getValueType() == N1.getValueType() &&
4538          "Unexpected operand types for bitwise logic op");
4539   assert(LL.getValueType() == LR.getValueType() &&
4540          RL.getValueType() == RR.getValueType() &&
4541          "Unexpected operand types for setcc");
4542 
4543   // If we're here post-legalization or the logic op type is not i1, the logic
4544   // op type must match a setcc result type. Also, all folds require new
4545   // operations on the left and right operands, so those types must match.
4546   EVT VT = N0.getValueType();
4547   EVT OpVT = LL.getValueType();
4548   if (LegalOperations || VT.getScalarType() != MVT::i1)
4549     if (VT != getSetCCResultType(OpVT))
4550       return SDValue();
4551   if (OpVT != RL.getValueType())
4552     return SDValue();
4553 
4554   ISD::CondCode CC0 = cast<CondCodeSDNode>(N0CC)->get();
4555   ISD::CondCode CC1 = cast<CondCodeSDNode>(N1CC)->get();
4556   bool IsInteger = OpVT.isInteger();
4557   if (LR == RR && CC0 == CC1 && IsInteger) {
4558     bool IsZero = isNullOrNullSplat(LR);
4559     bool IsNeg1 = isAllOnesOrAllOnesSplat(LR);
4560 
4561     // All bits clear?
4562     bool AndEqZero = IsAnd && CC1 == ISD::SETEQ && IsZero;
4563     // All sign bits clear?
4564     bool AndGtNeg1 = IsAnd && CC1 == ISD::SETGT && IsNeg1;
4565     // Any bits set?
4566     bool OrNeZero = !IsAnd && CC1 == ISD::SETNE && IsZero;
4567     // Any sign bits set?
4568     bool OrLtZero = !IsAnd && CC1 == ISD::SETLT && IsZero;
4569 
4570     // (and (seteq X,  0), (seteq Y,  0)) --> (seteq (or X, Y),  0)
4571     // (and (setgt X, -1), (setgt Y, -1)) --> (setgt (or X, Y), -1)
4572     // (or  (setne X,  0), (setne Y,  0)) --> (setne (or X, Y),  0)
4573     // (or  (setlt X,  0), (setlt Y,  0)) --> (setlt (or X, Y),  0)
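         // (All four rows rely on the same fact: X | Y is zero iff both X and
         // Y are zero, and its sign bit is clear iff both sign bits are
         // clear.)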
4574     if (AndEqZero || AndGtNeg1 || OrNeZero || OrLtZero) {
4575       SDValue Or = DAG.getNode(ISD::OR, SDLoc(N0), OpVT, LL, RL);
4576       AddToWorklist(Or.getNode());
4577       return DAG.getSetCC(DL, VT, Or, LR, CC1);
4578     }
4579 
4580     // All bits set?
4581     bool AndEqNeg1 = IsAnd && CC1 == ISD::SETEQ && IsNeg1;
4582     // All sign bits set?
4583     bool AndLtZero = IsAnd && CC1 == ISD::SETLT && IsZero;
4584     // Any bits clear?
4585     bool OrNeNeg1 = !IsAnd && CC1 == ISD::SETNE && IsNeg1;
4586     // Any sign bits clear?
4587     bool OrGtNeg1 = !IsAnd && CC1 == ISD::SETGT && IsNeg1;
4588 
4589     // (and (seteq X, -1), (seteq Y, -1)) --> (seteq (and X, Y), -1)
4590     // (and (setlt X,  0), (setlt Y,  0)) --> (setlt (and X, Y),  0)
4591     // (or  (setne X, -1), (setne Y, -1)) --> (setne (and X, Y), -1)
4592     // (or  (setgt X, -1), (setgt Y, -1)) --> (setgt (and X, Y), -1)
4593     if (AndEqNeg1 || AndLtZero || OrNeNeg1 || OrGtNeg1) {
4594       SDValue And = DAG.getNode(ISD::AND, SDLoc(N0), OpVT, LL, RL);
4595       AddToWorklist(And.getNode());
4596       return DAG.getSetCC(DL, VT, And, LR, CC1);
4597     }
4598   }
4599 
4600   // TODO: What is the 'or' equivalent of this fold?
4601   // (and (setne X, 0), (setne X, -1)) --> (setuge (add X, 1), 2)
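       // (Unsigned wraparound maps X == -1 to 0 and X == 0 to 1, so exactly
       // the two excluded values land below 2 after the add.)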
4602   if (IsAnd && LL == RL && CC0 == CC1 && OpVT.getScalarSizeInBits() > 1 &&
4603       IsInteger && CC0 == ISD::SETNE &&
4604       ((isNullConstant(LR) && isAllOnesConstant(RR)) ||
4605        (isAllOnesConstant(LR) && isNullConstant(RR)))) {
4606     SDValue One = DAG.getConstant(1, DL, OpVT);
4607     SDValue Two = DAG.getConstant(2, DL, OpVT);
4608     SDValue Add = DAG.getNode(ISD::ADD, SDLoc(N0), OpVT, LL, One);
4609     AddToWorklist(Add.getNode());
4610     return DAG.getSetCC(DL, VT, Add, Two, ISD::SETUGE);
4611   }
4612 
4613   // Try more general transforms if the predicates match and the only user of
4614   // the compares is the 'and' or 'or'.
4615   if (IsInteger && TLI.convertSetCCLogicToBitwiseLogic(OpVT) && CC0 == CC1 &&
4616       N0.hasOneUse() && N1.hasOneUse()) {
4617     // and (seteq A, B), (seteq C, D) --> seteq (or (xor A, B), (xor C, D)), 0
4618     // or  (setne A, B), (setne C, D) --> setne (or (xor A, B), (xor C, D)), 0
4619     if ((IsAnd && CC1 == ISD::SETEQ) || (!IsAnd && CC1 == ISD::SETNE)) {
4620       SDValue XorL = DAG.getNode(ISD::XOR, SDLoc(N0), OpVT, LL, LR);
4621       SDValue XorR = DAG.getNode(ISD::XOR, SDLoc(N1), OpVT, RL, RR);
4622       SDValue Or = DAG.getNode(ISD::OR, DL, OpVT, XorL, XorR);
4623       SDValue Zero = DAG.getConstant(0, DL, OpVT);
4624       return DAG.getSetCC(DL, VT, Or, Zero, CC1);
4625     }
4626 
4627     // Turn compare of constants whose difference is 1 bit into add+and+setcc.
4628     // TODO - support non-uniform vector amounts.
4629     if ((IsAnd && CC1 == ISD::SETNE) || (!IsAnd && CC1 == ISD::SETEQ)) {
4630       // Match a shared variable operand and 2 non-opaque constant operands.
4631       ConstantSDNode *C0 = isConstOrConstSplat(LR);
4632       ConstantSDNode *C1 = isConstOrConstSplat(RR);
4633       if (LL == RL && C0 && C1 && !C0->isOpaque() && !C1->isOpaque()) {
4634         // Canonicalize larger constant as C0.
4635         if (C1->getAPIntValue().ugt(C0->getAPIntValue()))
4636           std::swap(C0, C1);
4637 
4638         // The difference of the constants must be a single bit.
4639         const APInt &C0Val = C0->getAPIntValue();
4640         const APInt &C1Val = C1->getAPIntValue();
4641         if ((C0Val - C1Val).isPowerOf2()) {
4642           // and/or (setcc X, C0, ne), (setcc X, C1, ne/eq) -->
4643           // setcc ((add X, -C1), ~(C0 - C1)), 0, ne/eq
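               // Illustrative worked example: (and (setcc X, 6, ne),
               // (setcc X, 4, ne)) becomes (setcc (and (add X, -4), ~2), 0,
               // ne), since (X - 4) & ~2 is zero exactly for X in {4, 6}.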
4644           SDValue OffsetC = DAG.getConstant(-C1Val, DL, OpVT);
4645           SDValue Add = DAG.getNode(ISD::ADD, DL, OpVT, LL, OffsetC);
4646           SDValue MaskC = DAG.getConstant(~(C0Val - C1Val), DL, OpVT);
4647           SDValue And = DAG.getNode(ISD::AND, DL, OpVT, Add, MaskC);
4648           SDValue Zero = DAG.getConstant(0, DL, OpVT);
4649           return DAG.getSetCC(DL, VT, And, Zero, CC0);
4650         }
4651       }
4652     }
4653   }
4654 
4655   // Canonicalize equivalent operands to LL == RL.
4656   if (LL == RR && LR == RL) {
4657     CC1 = ISD::getSetCCSwappedOperands(CC1);
4658     std::swap(RL, RR);
4659   }
4660 
4661   // (and (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
4662   // (or  (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
4663   if (LL == RL && LR == RR) {
4664     ISD::CondCode NewCC = IsAnd ? ISD::getSetCCAndOperation(CC0, CC1, OpVT)
4665                                 : ISD::getSetCCOrOperation(CC0, CC1, OpVT);
4666     if (NewCC != ISD::SETCC_INVALID &&
4667         (!LegalOperations ||
4668          (TLI.isCondCodeLegal(NewCC, LL.getSimpleValueType()) &&
4669           TLI.isOperationLegal(ISD::SETCC, OpVT))))
4670       return DAG.getSetCC(DL, VT, LL, LR, NewCC);
4671   }
4672 
4673   return SDValue();
4674 }
4675 
4676 /// This contains all DAGCombine rules which reduce two values combined by
4677 /// an And operation to a single value. This makes them reusable in the context
4678 /// of visitSELECT(). Rules involving constants are not included as
4679 /// visitSELECT() already handles those cases.
4680 SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1, SDNode *N) {
4681   EVT VT = N1.getValueType();
4682   SDLoc DL(N);
4683 
4684   // fold (and x, undef) -> 0
4685   if (N0.isUndef() || N1.isUndef())
4686     return DAG.getConstant(0, DL, VT);
4687 
4688   if (SDValue V = foldLogicOfSetCCs(true, N0, N1, DL))
4689     return V;
4690 
4691   if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::SRL &&
4692       VT.getSizeInBits() <= 64) {
4693     if (ConstantSDNode *ADDI = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
4694       if (ConstantSDNode *SRLI = dyn_cast<ConstantSDNode>(N1.getOperand(1))) {
4695         // Look for (and (add x, c1), (lshr y, c2)). If c1 isn't a legal
4696         // immediate for an add, but it becomes legal once its top c2 bits
4697         // are set, transform the ADD so the immediate doesn't need to be
4698         // materialized in a register.
4699         APInt ADDC = ADDI->getAPIntValue();
4700         APInt SRLC = SRLI->getAPIntValue();
4701         if (ADDC.getMinSignedBits() <= 64 &&
4702             SRLC.ult(VT.getSizeInBits()) &&
4703             !TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
4704           APInt Mask = APInt::getHighBitsSet(VT.getSizeInBits(),
4705                                              SRLC.getZExtValue());
4706           if (DAG.MaskedValueIsZero(N0.getOperand(1), Mask)) {
4707             ADDC |= Mask;
4708             if (TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
4709               SDLoc DL0(N0);
4710               SDValue NewAdd =
4711                 DAG.getNode(ISD::ADD, DL0, VT,
4712                             N0.getOperand(0), DAG.getConstant(ADDC, DL, VT));
4713               CombineTo(N0.getNode(), NewAdd);
4714               // Return N so it doesn't get rechecked!
4715               return SDValue(N, 0);
4716             }
4717           }
4718         }
4719       }
4720     }
4721   }
4722 
4723   // Reduce bit extract of low half of an integer to the narrower type.
4724   // (and (srl i64:x, K), KMask) ->
4725   //   (i64 zero_extend (and (srl (i32 (trunc i64:x)), K)), KMask)
4726   if (N0.getOpcode() == ISD::SRL && N0.hasOneUse()) {
4727     if (ConstantSDNode *CAnd = dyn_cast<ConstantSDNode>(N1)) {
4728       if (ConstantSDNode *CShift = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
4729         unsigned Size = VT.getSizeInBits();
4730         const APInt &AndMask = CAnd->getAPIntValue();
4731         unsigned ShiftBits = CShift->getZExtValue();
4732 
4733         // Bail out, this node will probably disappear anyway.
4734         if (ShiftBits == 0)
4735           return SDValue();
4736 
4737         unsigned MaskBits = AndMask.countTrailingOnes();
4738         EVT HalfVT = EVT::getIntegerVT(*DAG.getContext(), Size / 2);
4739 
4740         if (AndMask.isMask() &&
4741             // Required bits must not span the two halves of the integer and
4742             // must fit in the half size type.
4743             (ShiftBits + MaskBits <= Size / 2) &&
4744             TLI.isNarrowingProfitable(VT, HalfVT) &&
4745             TLI.isTypeDesirableForOp(ISD::AND, HalfVT) &&
4746             TLI.isTypeDesirableForOp(ISD::SRL, HalfVT) &&
4747             TLI.isTruncateFree(VT, HalfVT) &&
4748             TLI.isZExtFree(HalfVT, VT)) {
4749           // The isNarrowingProfitable is to avoid regressions on PPC and
4750           // AArch64 which match a few 64-bit bit insert / bit extract patterns
4751           // on downstream users of this. Those patterns could probably be
4752           // extended to handle extensions mixed in.
4753 
4754           SDLoc SL(N0);
4755           assert(MaskBits <= Size);
4756 
4757           // Extracting the highest bit of the low half.
4758           EVT ShiftVT = TLI.getShiftAmountTy(HalfVT, DAG.getDataLayout());
4759           SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, HalfVT,
4760                                       N0.getOperand(0));
4761 
4762           SDValue NewMask = DAG.getConstant(AndMask.trunc(Size / 2), SL, HalfVT);
4763           SDValue ShiftK = DAG.getConstant(ShiftBits, SL, ShiftVT);
4764           SDValue Shift = DAG.getNode(ISD::SRL, SL, HalfVT, Trunc, ShiftK);
4765           SDValue And = DAG.getNode(ISD::AND, SL, HalfVT, Shift, NewMask);
4766           return DAG.getNode(ISD::ZERO_EXTEND, SL, VT, And);
4767         }
4768       }
4769     }
4770   }
4771 
4772   return SDValue();
4773 }
4774 
4775 bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
4776                                    EVT LoadResultTy, EVT &ExtVT) {
4777   if (!AndC->getAPIntValue().isMask())
4778     return false;
4779 
4780   unsigned ActiveBits = AndC->getAPIntValue().countTrailingOnes();
4781 
4782   ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
4783   EVT LoadedVT = LoadN->getMemoryVT();
4784 
4785   if (ExtVT == LoadedVT &&
4786       (!LegalOperations ||
4787        TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))) {
4788     // ZEXTLOAD will match without needing to change the size of the value being
4789     // loaded.
4790     return true;
4791   }
4792 
4793   // Do not change the width of volatile or atomic loads.
4794   if (!LoadN->isSimple())
4795     return false;
4796 
4797   // Do not generate loads of non-round integer types since these can
4798   // be expensive (and would be wrong if the type is not byte sized).
4799   if (!LoadedVT.bitsGT(ExtVT) || !ExtVT.isRound())
4800     return false;
4801 
4802   if (LegalOperations &&
4803       !TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))
4804     return false;
4805 
4806   if (!TLI.shouldReduceLoadWidth(LoadN, ISD::ZEXTLOAD, ExtVT))
4807     return false;
4808 
4809   return true;
4810 }
4811 
4812 bool DAGCombiner::isLegalNarrowLdSt(LSBaseSDNode *LDST,
4813                                     ISD::LoadExtType ExtType, EVT &MemVT,
4814                                     unsigned ShAmt) {
4815   if (!LDST)
4816     return false;
4817   // Only allow byte offsets.
4818   if (ShAmt % 8)
4819     return false;
4820 
4821   // Do not generate loads of non-round integer types since these can
4822   // be expensive (and would be wrong if the type is not byte sized).
4823   if (!MemVT.isRound())
4824     return false;
4825 
4826   // Don't change the width of volatile or atomic loads.
4827   if (!LDST->isSimple())
4828     return false;
4829 
4830   // Verify that we are actually reducing a load width here.
4831   if (LDST->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits())
4832     return false;
4833 
4834   // Ensure that this isn't going to produce an unsupported memory access.
4835   if (ShAmt &&
4836       !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
4837                               LDST->getAddressSpace(), ShAmt / 8,
4838                               LDST->getMemOperand()->getFlags()))
4839     return false;
4840 
4841   // It's not possible to generate a constant of extended or untyped type.
4842   EVT PtrType = LDST->getBasePtr().getValueType();
4843   if (PtrType == MVT::Untyped || PtrType.isExtended())
4844     return false;
4845 
4846   if (isa<LoadSDNode>(LDST)) {
4847     LoadSDNode *Load = cast<LoadSDNode>(LDST);
4848     // Don't transform one with multiple uses; this would require adding a new
4849     // load.
4850     if (!SDValue(Load, 0).hasOneUse())
4851       return false;
4852 
4853     if (LegalOperations &&
4854         !TLI.isLoadExtLegal(ExtType, Load->getValueType(0), MemVT))
4855       return false;
4856 
4857     // For the transform to be legal, the load must produce only two values
4858     // (the value loaded and the chain).  Don't transform a pre-increment
4859     // load, for example, which produces an extra value.  Otherwise the
4860     // transformation is not equivalent, and the downstream logic to replace
4861     // uses gets things wrong.
4862     if (Load->getNumValues() > 2)
4863       return false;
4864 
4865     // If the load that we're shrinking is an extload and we're not just
4866     // discarding the extension we can't simply shrink the load. Bail.
4867     // TODO: It would be possible to merge the extensions in some cases.
4868     if (Load->getExtensionType() != ISD::NON_EXTLOAD &&
4869         Load->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
4870       return false;
4871 
4872     if (!TLI.shouldReduceLoadWidth(Load, ExtType, MemVT))
4873       return false;
4874   } else {
4875     assert(isa<StoreSDNode>(LDST) && "It is not a Load nor a Store SDNode");
4876     StoreSDNode *Store = cast<StoreSDNode>(LDST);
4877     // Can't write outside the original store
4878     if (Store->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
4879       return false;
4880 
4881     if (LegalOperations &&
4882         !TLI.isTruncStoreLegal(Store->getValue().getValueType(), MemVT))
4883       return false;
4884   }
4885   return true;
4886 }
4887 
4888 bool DAGCombiner::SearchForAndLoads(SDNode *N,
4889                                     SmallVectorImpl<LoadSDNode*> &Loads,
4890                                     SmallPtrSetImpl<SDNode*> &NodesWithConsts,
4891                                     ConstantSDNode *Mask,
4892                                     SDNode *&NodeToMask) {
4893   // Recursively search for the operands, looking for loads which can be
4894   // narrowed.
4895   for (SDValue Op : N->op_values()) {
4896     if (Op.getValueType().isVector())
4897       return false;
4898 
4899     // Some constants may need fixing up later if they are too large.
4900     if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
4901       if ((N->getOpcode() == ISD::OR || N->getOpcode() == ISD::XOR) &&
4902           (Mask->getAPIntValue() & C->getAPIntValue()) != C->getAPIntValue())
4903         NodesWithConsts.insert(N);
4904       continue;
4905     }
4906 
4907     if (!Op.hasOneUse())
4908       return false;
4909 
4910     switch(Op.getOpcode()) {
4911     case ISD::LOAD: {
4912       auto *Load = cast<LoadSDNode>(Op);
4913       EVT ExtVT;
4914       if (isAndLoadExtLoad(Mask, Load, Load->getValueType(0), ExtVT) &&
4915           isLegalNarrowLdSt(Load, ISD::ZEXTLOAD, ExtVT)) {
4916 
4917         // ZEXTLOAD is already small enough.
4918         if (Load->getExtensionType() == ISD::ZEXTLOAD &&
4919             ExtVT.bitsGE(Load->getMemoryVT()))
4920           continue;
4921 
4922         // Use LE to convert equal sized loads to zext.
4923         if (ExtVT.bitsLE(Load->getMemoryVT()))
4924           Loads.push_back(Load);
4925 
4926         continue;
4927       }
4928       return false;
4929     }
4930     case ISD::ZERO_EXTEND:
4931     case ISD::AssertZext: {
4932       unsigned ActiveBits = Mask->getAPIntValue().countTrailingOnes();
4933       EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
4934       EVT VT = Op.getOpcode() == ISD::AssertZext ?
4935         cast<VTSDNode>(Op.getOperand(1))->getVT() :
4936         Op.getOperand(0).getValueType();
4937 
4938       // We can accept extending nodes if the mask is wider than or equal
4939       // in width to the original type.
4940       if (ExtVT.bitsGE(VT))
4941         continue;
4942       break;
4943     }
4944     case ISD::OR:
4945     case ISD::XOR:
4946     case ISD::AND:
4947       if (!SearchForAndLoads(Op.getNode(), Loads, NodesWithConsts, Mask,
4948                              NodeToMask))
4949         return false;
4950       continue;
4951     }
4952 
4953     // Allow one node which will be masked along with any loads found.
4954     if (NodeToMask)
4955       return false;
4956 
4957     // Also ensure that the node to be masked only produces one data result.
4958     NodeToMask = Op.getNode();
4959     if (NodeToMask->getNumValues() > 1) {
4960       bool HasValue = false;
4961       for (unsigned i = 0, e = NodeToMask->getNumValues(); i < e; ++i) {
4962         MVT VT = SDValue(NodeToMask, i).getSimpleValueType();
4963         if (VT != MVT::Glue && VT != MVT::Other) {
4964           if (HasValue) {
4965             NodeToMask = nullptr;
4966             return false;
4967           }
4968           HasValue = true;
4969         }
4970       }
4971       assert(HasValue && "Node to be masked has no data result?");
4972     }
4973   }
4974   return true;
4975 }
4976 
4977 bool DAGCombiner::BackwardsPropagateMask(SDNode *N) {
4978   auto *Mask = dyn_cast<ConstantSDNode>(N->getOperand(1));
4979   if (!Mask)
4980     return false;
4981 
4982   if (!Mask->getAPIntValue().isMask())
4983     return false;
4984 
4985   // No need to do anything if the and directly uses a load.
4986   if (isa<LoadSDNode>(N->getOperand(0)))
4987     return false;
4988 
4989   SmallVector<LoadSDNode*, 8> Loads;
4990   SmallPtrSet<SDNode*, 2> NodesWithConsts;
4991   SDNode *FixupNode = nullptr;
4992   if (SearchForAndLoads(N, Loads, NodesWithConsts, Mask, FixupNode)) {
4993     if (Loads.size() == 0)
4994       return false;
4995 
4996     LLVM_DEBUG(dbgs() << "Backwards propagate AND: "; N->dump());
4997     SDValue MaskOp = N->getOperand(1);
4998 
4999     // If it exists, fixup the single node we allow in the tree that needs
5000     // masking.
5001     if (FixupNode) {
5002       LLVM_DEBUG(dbgs() << "First, need to fix up: "; FixupNode->dump());
5003       SDValue And = DAG.getNode(ISD::AND, SDLoc(FixupNode),
5004                                 FixupNode->getValueType(0),
5005                                 SDValue(FixupNode, 0), MaskOp);
5006       DAG.ReplaceAllUsesOfValueWith(SDValue(FixupNode, 0), And);
5007       if (And.getOpcode() == ISD::AND)
5008         DAG.UpdateNodeOperands(And.getNode(), SDValue(FixupNode, 0), MaskOp);
5009     }
5010 
5011     // Narrow any constants that need it.
5012     for (auto *LogicN : NodesWithConsts) {
5013       SDValue Op0 = LogicN->getOperand(0);
5014       SDValue Op1 = LogicN->getOperand(1);
5015 
5016       if (isa<ConstantSDNode>(Op0))
5017           std::swap(Op0, Op1);
5018 
5019       SDValue And = DAG.getNode(ISD::AND, SDLoc(Op1), Op1.getValueType(),
5020                                 Op1, MaskOp);
5021 
5022       DAG.UpdateNodeOperands(LogicN, Op0, And);
5023     }
5024 
5025     // Create narrow loads.
5026     for (auto *Load : Loads) {
5027       LLVM_DEBUG(dbgs() << "Propagate AND back to: "; Load->dump());
5028       SDValue And = DAG.getNode(ISD::AND, SDLoc(Load), Load->getValueType(0),
5029                                 SDValue(Load, 0), MaskOp);
5030       DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), And);
5031       if (And.getOpcode() == ISD::AND)
5032         And = SDValue(
5033             DAG.UpdateNodeOperands(And.getNode(), SDValue(Load, 0), MaskOp), 0);
5034       SDValue NewLoad = ReduceLoadWidth(And.getNode());
5035       assert(NewLoad &&
5036              "Shouldn't be masking the load if it can't be narrowed");
5037       CombineTo(Load, NewLoad, NewLoad.getValue(1));
5038     }
5039     DAG.ReplaceAllUsesWith(N, N->getOperand(0).getNode());
5040     return true;
5041   }
5042   return false;
5043 }
5044 
5045 // Unfold
5046 //    x &  (-1 'logical shift' y)
5047 // To
5048 //    (x 'opposite logical shift' y) 'logical shift' y
5049 // if it is better for performance.
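     // (For example, for i32 x, both x & (-1 << y) and ((x >>u y) << y) clear
     // the y lowest bits of x; the shift form avoids materializing the
     // variable mask.)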
5050 SDValue DAGCombiner::unfoldExtremeBitClearingToShifts(SDNode *N) {
5051   assert(N->getOpcode() == ISD::AND);
5052 
5053   SDValue N0 = N->getOperand(0);
5054   SDValue N1 = N->getOperand(1);
5055 
5056   // Do we actually prefer shifts over mask?
5057   if (!TLI.shouldFoldMaskToVariableShiftPair(N0))
5058     return SDValue();
5059 
5060   // Try to match  (-1 '[outer] logical shift' y)
5061   unsigned OuterShift;
5062   unsigned InnerShift; // The opposite direction to the OuterShift.
5063   SDValue Y;           // Shift amount.
5064   auto matchMask = [&OuterShift, &InnerShift, &Y](SDValue M) -> bool {
5065     if (!M.hasOneUse())
5066       return false;
5067     OuterShift = M->getOpcode();
5068     if (OuterShift == ISD::SHL)
5069       InnerShift = ISD::SRL;
5070     else if (OuterShift == ISD::SRL)
5071       InnerShift = ISD::SHL;
5072     else
5073       return false;
5074     if (!isAllOnesConstant(M->getOperand(0)))
5075       return false;
5076     Y = M->getOperand(1);
5077     return true;
5078   };
5079 
5080   SDValue X;
5081   if (matchMask(N1))
5082     X = N0;
5083   else if (matchMask(N0))
5084     X = N1;
5085   else
5086     return SDValue();
5087 
5088   SDLoc DL(N);
5089   EVT VT = N->getValueType(0);
5090 
5091   //     tmp = x   'opposite logical shift' y
5092   SDValue T0 = DAG.getNode(InnerShift, DL, VT, X, Y);
5093   //     ret = tmp 'logical shift' y
5094   SDValue T1 = DAG.getNode(OuterShift, DL, VT, T0, Y);
5095 
5096   return T1;
5097 }
5098 
5099 /// Try to replace shift/logic that tests if a bit is clear with mask + setcc.
5100 /// For a target with a bit test, this is expected to become test + set and save
5101 /// at least 1 instruction.
5102 static SDValue combineShiftAnd1ToBitTest(SDNode *And, SelectionDAG &DAG) {
5103   assert(And->getOpcode() == ISD::AND && "Expected an 'and' op");
5104 
5105   // This is probably not worthwhile without a supported type.
5106   EVT VT = And->getValueType(0);
5107   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
5108   if (!TLI.isTypeLegal(VT))
5109     return SDValue();
5110 
5111   // Look through an optional extension and find a 'not'.
5112   // TODO: Should we favor test+set even without the 'not' op?
5113   SDValue Not = And->getOperand(0), And1 = And->getOperand(1);
5114   if (Not.getOpcode() == ISD::ANY_EXTEND)
5115     Not = Not.getOperand(0);
5116   if (!isBitwiseNot(Not) || !Not.hasOneUse() || !isOneConstant(And1))
5117     return SDValue();
5118 
5119   // Look through an optional truncation. The source operand may not be the same
5120   // type as the original 'and', but that is ok because we are masking off
5121   // everything but the low bit.
5122   SDValue Srl = Not.getOperand(0);
5123   if (Srl.getOpcode() == ISD::TRUNCATE)
5124     Srl = Srl.getOperand(0);
5125 
5126   // Match a shift-right by constant.
5127   if (Srl.getOpcode() != ISD::SRL || !Srl.hasOneUse() ||
5128       !isa<ConstantSDNode>(Srl.getOperand(1)))
5129     return SDValue();
5130 
5131   // We might have looked through casts that make this transform invalid.
5132   // TODO: If the source type is wider than the result type, do the mask and
5133   //       compare in the source type.
5134   const APInt &ShiftAmt = Srl.getConstantOperandAPInt(1);
5135   unsigned VTBitWidth = VT.getSizeInBits();
5136   if (ShiftAmt.uge(VTBitWidth))
5137     return SDValue();
5138 
5139   // Turn this into a bit-test pattern using mask op + setcc:
5140   // and (not (srl X, C)), 1 --> (and X, 1<<C) == 0
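       // (The 'not' turns the test into "bit C of X is clear", which is
       // exactly (X & (1 << C)) == 0.)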
5141   SDLoc DL(And);
5142   SDValue X = DAG.getZExtOrTrunc(Srl.getOperand(0), DL, VT);
5143   EVT CCVT = TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
5144   SDValue Mask = DAG.getConstant(
5145       APInt::getOneBitSet(VTBitWidth, ShiftAmt.getZExtValue()), DL, VT);
5146   SDValue NewAnd = DAG.getNode(ISD::AND, DL, VT, X, Mask);
5147   SDValue Zero = DAG.getConstant(0, DL, VT);
5148   SDValue Setcc = DAG.getSetCC(DL, CCVT, NewAnd, Zero, ISD::SETEQ);
5149   return DAG.getZExtOrTrunc(Setcc, DL, VT);
5150 }
5151 
5152 SDValue DAGCombiner::visitAND(SDNode *N) {
5153   SDValue N0 = N->getOperand(0);
5154   SDValue N1 = N->getOperand(1);
5155   EVT VT = N1.getValueType();
5156 
5157   // x & x --> x
5158   if (N0 == N1)
5159     return N0;
5160 
5161   // fold vector ops
5162   if (VT.isVector()) {
5163     if (SDValue FoldedVOp = SimplifyVBinOp(N))
5164       return FoldedVOp;
5165 
5166     // fold (and x, 0) -> 0, vector edition
5167     if (ISD::isBuildVectorAllZeros(N0.getNode()))
5168       // do not return N0, because undef node may exist in N0
5169       return DAG.getConstant(APInt::getNullValue(N0.getScalarValueSizeInBits()),
5170                              SDLoc(N), N0.getValueType());
5171     if (ISD::isBuildVectorAllZeros(N1.getNode()))
5172       // do not return N1, because undef node may exist in N1
5173       return DAG.getConstant(APInt::getNullValue(N1.getScalarValueSizeInBits()),
5174                              SDLoc(N), N1.getValueType());
5175 
5176     // fold (and x, -1) -> x, vector edition
5177     if (ISD::isBuildVectorAllOnes(N0.getNode()))
5178       return N1;
5179     if (ISD::isBuildVectorAllOnes(N1.getNode()))
5180       return N0;
5181   }
5182 
5183   // fold (and c1, c2) -> c1&c2
5184   ConstantSDNode *N0C = getAsNonOpaqueConstant(N0);
5185   ConstantSDNode *N1C = isConstOrConstSplat(N1);
5186   if (N0C && N1C && !N1C->isOpaque())
5187     return DAG.FoldConstantArithmetic(ISD::AND, SDLoc(N), VT, N0C, N1C);
5188   // canonicalize constant to RHS
5189   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
5190       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
5191     return DAG.getNode(ISD::AND, SDLoc(N), VT, N1, N0);
5192   // fold (and x, -1) -> x
5193   if (isAllOnesConstant(N1))
5194     return N0;
5195   // if (and x, c) is known to be zero, return 0
5196   unsigned BitWidth = VT.getScalarSizeInBits();
5197   if (N1C && DAG.MaskedValueIsZero(SDValue(N, 0),
5198                                    APInt::getAllOnesValue(BitWidth)))
5199     return DAG.getConstant(0, SDLoc(N), VT);
5200 
5201   if (SDValue NewSel = foldBinOpIntoSelect(N))
5202     return NewSel;
5203 
5204   // reassociate and
5205   if (SDValue RAND = reassociateOps(ISD::AND, SDLoc(N), N0, N1, N->getFlags()))
5206     return RAND;
5207 
5208   // Try to convert a constant mask AND into a shuffle clear mask.
5209   if (VT.isVector())
5210     if (SDValue Shuffle = XformToShuffleWithZero(N))
5211       return Shuffle;
5212 
5213   if (SDValue Combined = combineCarryDiamond(*this, DAG, TLI, N0, N1, N))
5214     return Combined;
5215 
5216   // fold (and (or x, C), D) -> D if (C & D) == D
5217   auto MatchSubset = [](ConstantSDNode *LHS, ConstantSDNode *RHS) {
5218     return RHS->getAPIntValue().isSubsetOf(LHS->getAPIntValue());
5219   };
5220   if (N0.getOpcode() == ISD::OR &&
5221       ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchSubset))
5222     return N1;
5223   // fold (and (any_ext V), c) -> (zero_ext V) if 'and' only clears top bits.
5224   if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
5225     SDValue N0Op0 = N0.getOperand(0);
5226     APInt Mask = ~N1C->getAPIntValue();
5227     Mask = Mask.trunc(N0Op0.getScalarValueSizeInBits());
5228     if (DAG.MaskedValueIsZero(N0Op0, Mask)) {
5229       SDValue Zext = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N),
5230                                  N0.getValueType(), N0Op0);
5231 
5232       // Replace uses of the AND with uses of the Zero extend node.
5233       CombineTo(N, Zext);
5234 
5235       // We actually want to replace all uses of the any_extend with the
5236       // zero_extend, to avoid duplicating things.  This will later cause this
5237       // AND to be folded.
5238       CombineTo(N0.getNode(), Zext);
5239       return SDValue(N, 0);   // Return N so it doesn't get rechecked!
5240     }
5241   }
5242 
5243   // similarly fold (and (X (load ([non_ext|any_ext|zero_ext] V))), c) ->
5244   // (X (load ([non_ext|zero_ext] V))) if 'and' only clears top bits which must
5245   // already be zero by virtue of the width of the base type of the load.
5246   //
5247   // the 'X' node here can either be nothing or an extract_vector_elt to catch
5248   // more cases.
5249   if ((N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
5250        N0.getValueSizeInBits() == N0.getOperand(0).getScalarValueSizeInBits() &&
5251        N0.getOperand(0).getOpcode() == ISD::LOAD &&
5252        N0.getOperand(0).getResNo() == 0) ||
5253       (N0.getOpcode() == ISD::LOAD && N0.getResNo() == 0)) {
5254     LoadSDNode *Load = cast<LoadSDNode>( (N0.getOpcode() == ISD::LOAD) ?
5255                                          N0 : N0.getOperand(0) );
5256 
5257     // Get the constant (if applicable) the zero'th operand is being ANDed with.
5258     // This can be a pure constant or a vector splat, in which case we treat the
5259     // vector as a scalar and use the splat value.
5260     APInt Constant = APInt::getNullValue(1);
5261     if (const ConstantSDNode *C = dyn_cast<ConstantSDNode>(N1)) {
5262       Constant = C->getAPIntValue();
5263     } else if (BuildVectorSDNode *Vector = dyn_cast<BuildVectorSDNode>(N1)) {
5264       APInt SplatValue, SplatUndef;
5265       unsigned SplatBitSize;
5266       bool HasAnyUndefs;
5267       bool IsSplat = Vector->isConstantSplat(SplatValue, SplatUndef,
5268                                              SplatBitSize, HasAnyUndefs);
5269       if (IsSplat) {
5270         // Undef bits can contribute to a possible optimisation if set, so
5271         // set them.
5272         SplatValue |= SplatUndef;
5273 
5274         // The splat value may be something like "0x00FFFFFF", which means 0 for
5275         // the first vector value and FF for the rest, repeating. We need a mask
5276         // that will apply equally to all members of the vector, so AND all the
5277         // lanes of the constant together.
5278         unsigned EltBitWidth = Vector->getValueType(0).getScalarSizeInBits();
5279 
5280         // If the splat value has been compressed to a bitlength lower
5281         // than the size of the vector lane, we need to re-expand it to
5282         // the lane size.
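        // For example, an 8-bit splat of 0xFF in a vector with 16-bit lanes
        // is re-expanded to 0xFFFF by a single doubling step below.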
        if (EltBitWidth > SplatBitSize)
          for (SplatValue = SplatValue.zextOrTrunc(EltBitWidth);
               SplatBitSize < EltBitWidth; SplatBitSize = SplatBitSize * 2)
            SplatValue |= SplatValue.shl(SplatBitSize);

        // Make sure that variable 'Constant' is only set if 'SplatBitSize' is
        // a multiple of 'EltBitWidth'. Otherwise, we could propagate a wrong
        // value.
        if ((SplatBitSize % EltBitWidth) == 0) {
          Constant = APInt::getAllOnesValue(EltBitWidth);
          for (unsigned i = 0, n = (SplatBitSize / EltBitWidth); i < n; ++i)
            Constant &= SplatValue.extractBits(EltBitWidth, i * EltBitWidth);
        }
      }
    }

    // If we want to change an EXTLOAD to a ZEXTLOAD, ensure a ZEXTLOAD is
    // actually legal and isn't going to get expanded, else this is a false
    // optimisation.
    bool CanZextLoadProfitably = TLI.isLoadExtLegal(ISD::ZEXTLOAD,
                                                    Load->getValueType(0),
                                                    Load->getMemoryVT());

    // Resize the constant to the same size as the original memory access before
    // extension. If it is still the AllOnesValue then this AND is completely
    // unneeded.
    Constant = Constant.zextOrTrunc(Load->getMemoryVT().getScalarSizeInBits());

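    // 'B' records whether dropping the AND is safe for this kind of load: for
    // ZEXTLOAD and NON_EXTLOAD every result bit outside the memory type is
    // already zero (or does not exist), while an any-extending EXTLOAD
    // qualifies only when rewriting it as a ZEXTLOAD is legal and profitable.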
    bool B;
    switch (Load->getExtensionType()) {
    default: B = false; break;
    case ISD::EXTLOAD: B = CanZextLoadProfitably; break;
    case ISD::ZEXTLOAD:
    case ISD::NON_EXTLOAD: B = true; break;
    }

    if (B && Constant.isAllOnesValue()) {
      // If the load type was an EXTLOAD, convert to ZEXTLOAD in order to
      // preserve semantics once we get rid of the AND.
      SDValue NewLoad(Load, 0);

      // Fold the AND away. NewLoad may get replaced immediately.
      CombineTo(N, (N0.getNode() == Load) ? NewLoad : N0);

      if (Load->getExtensionType() == ISD::EXTLOAD) {
        NewLoad = DAG.getLoad(Load->getAddressingMode(), ISD::ZEXTLOAD,
                              Load->getValueType(0), SDLoc(Load),
                              Load->getChain(), Load->getBasePtr(),
                              Load->getOffset(), Load->getMemoryVT(),
                              Load->getMemOperand());
        // Replace uses of the EXTLOAD with the new ZEXTLOAD.
        if (Load->getNumValues() == 3) {
          // PRE/POST_INC loads have 3 values.
          SDValue To[] = { NewLoad.getValue(0), NewLoad.getValue(1),
                           NewLoad.getValue(2) };
          CombineTo(Load, To, 3, true);
        } else {
          CombineTo(Load, NewLoad.getValue(0), NewLoad.getValue(1));
        }
      }

      return SDValue(N, 0); // Return N so it doesn't get rechecked!
    }
  }

  // fold (and (load x), 255) -> (zextload x, i8)
  // fold (and (extload x, i16), 255) -> (zextload x, i8)
  // fold (and (any_ext (extload x, i16)), 255) -> (zextload x, i8)
  if (!VT.isVector() && N1C && (N0.getOpcode() == ISD::LOAD ||
                                (N0.getOpcode() == ISD::ANY_EXTEND &&
                                 N0.getOperand(0).getOpcode() == ISD::LOAD))) {
    if (SDValue Res = ReduceLoadWidth(N)) {
      LoadSDNode *LN0 = N0->getOpcode() == ISD::ANY_EXTEND
        ? cast<LoadSDNode>(N0.getOperand(0)) : cast<LoadSDNode>(N0);
      AddToWorklist(N);
      DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 0), Res);
      return SDValue(N, 0);
    }
  }

  if (LegalTypes) {
    // Attempt to propagate the AND back up to the leaves which, if they're
    // loads, can be combined to narrow loads and the AND node can be removed.
    // Perform after legalization so that extend nodes will already be
    // combined into the loads.
    if (BackwardsPropagateMask(N))
      return SDValue(N, 0);
  }

  if (SDValue Combined = visitANDLike(N0, N1, N))
    return Combined;

  // Simplify: (and (op x...), (op y...))  -> (op (and x, y))
  if (N0.getOpcode() == N1.getOpcode())
    if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
      return V;

  // Masking the negated extension of a boolean is just the zero-extended
  // boolean:
  // and (sub 0, zext(bool X)), 1 --> zext(bool X)
  // and (sub 0, sext(bool X)), 1 --> zext(bool X)
  //
  // Note: the SimplifyDemandedBits fold below can make an information-losing
  // transform, and then we have no way to find this better fold.
  if (N1C && N1C->isOne() && N0.getOpcode() == ISD::SUB) {
    if (isNullOrNullSplat(N0.getOperand(0))) {
      SDValue SubRHS = N0.getOperand(1);
      if (SubRHS.getOpcode() == ISD::ZERO_EXTEND &&
          SubRHS.getOperand(0).getScalarValueSizeInBits() == 1)
        return SubRHS;
      if (SubRHS.getOpcode() == ISD::SIGN_EXTEND &&
          SubRHS.getOperand(0).getScalarValueSizeInBits() == 1)
        return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, SubRHS.getOperand(0));
    }
  }

  // fold (and (sign_extend_inreg x, i16 to i32), 1) -> (and x, 1)
  // fold (and (sra)) -> (and (srl)) when possible.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  // fold (zext_inreg (extload x)) -> (zextload x)
  // fold (zext_inreg (sextload x)) -> (zextload x) iff load has one use
  if (ISD::isUNINDEXEDLoad(N0.getNode()) &&
      (ISD::isEXTLoad(N0.getNode()) ||
       (ISD::isSEXTLoad(N0.getNode()) && N0.hasOneUse()))) {
    LoadSDNode *LN0 = cast<LoadSDNode>(N0);
    EVT MemVT = LN0->getMemoryVT();
    // If we zero all the possible extended bits, then we can turn this into
    // a zextload if we are running before legalize or the operation is legal.
    unsigned ExtBitSize = N1.getScalarValueSizeInBits();
    unsigned MemBitSize = MemVT.getScalarSizeInBits();
    APInt ExtBits = APInt::getHighBitsSet(ExtBitSize, ExtBitSize - MemBitSize);
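    // If the mask is zero in every bit above MemBitSize, the AND clears
    // exactly the bits an in-register zero-extension would, so the extending
    // load can be rewritten as a ZEXTLOAD.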
    if (DAG.MaskedValueIsZero(N1, ExtBits) &&
        ((!LegalOperations && LN0->isSimple()) ||
         TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT))) {
      SDValue ExtLoad =
          DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(N0), VT, LN0->getChain(),
                         LN0->getBasePtr(), MemVT, LN0->getMemOperand());
      AddToWorklist(N);
      CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
      return SDValue(N, 0); // Return N so it doesn't get rechecked!
    }
  }

  // fold (and (or (srl N, 8), (shl N, 8)), 0xffff) -> (srl (bswap N), const)
  if (N1C && N1C->getAPIntValue() == 0xffff && N0.getOpcode() == ISD::OR) {
    if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
                                           N0.getOperand(1), false))
      return BSwap;
  }

  if (SDValue Shifts = unfoldExtremeBitClearingToShifts(N))
    return Shifts;

  if (TLI.hasBitTest(N0, N1))
    if (SDValue V = combineShiftAnd1ToBitTest(N, DAG))
      return V;

  return SDValue();
}

/// Match (a >> 8) | (a << 8) as (bswap a) >> 16.
SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
                                        bool DemandHighBits) {
  if (!LegalOperations)
    return SDValue();

  EVT VT = N->getValueType(0);
  if (VT != MVT::i64 && VT != MVT::i32 && VT != MVT::i16)
    return SDValue();
  if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
    return SDValue();

  // Recognize (and (shl a, 8), 0xff00), (and (srl a, 8), 0xff)
  bool LookPassAnd0 = false;
  bool LookPassAnd1 = false;
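  // Canonicalize: if an AND wraps an SRL, move it to N1; if an AND wraps a
  // SHL, move it to N0. The checks below then only handle one ordering.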
  if (N0.getOpcode() == ISD::AND && N0.getOperand(0).getOpcode() == ISD::SRL)
      std::swap(N0, N1);
  if (N1.getOpcode() == ISD::AND && N1.getOperand(0).getOpcode() == ISD::SHL)
      std::swap(N0, N1);
  if (N0.getOpcode() == ISD::AND) {
    if (!N0.getNode()->hasOneUse())
      return SDValue();
    ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
    // Also handle 0xffff since the LHS is guaranteed to have zeros there.
    // This is needed for X86.
    if (!N01C || (N01C->getZExtValue() != 0xFF00 &&
                  N01C->getZExtValue() != 0xFFFF))
      return SDValue();
    N0 = N0.getOperand(0);
    LookPassAnd0 = true;
  }

  if (N1.getOpcode() == ISD::AND) {
    if (!N1.getNode()->hasOneUse())
      return SDValue();
    ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(N1.getOperand(1));
    if (!N11C || N11C->getZExtValue() != 0xFF)
      return SDValue();
    N1 = N1.getOperand(0);
    LookPassAnd1 = true;
  }

  if (N0.getOpcode() == ISD::SRL && N1.getOpcode() == ISD::SHL)
    std::swap(N0, N1);
  if (N0.getOpcode() != ISD::SHL || N1.getOpcode() != ISD::SRL)
    return SDValue();
  if (!N0.getNode()->hasOneUse() || !N1.getNode()->hasOneUse())
    return SDValue();

  ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
  ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(N1.getOperand(1));
  if (!N01C || !N11C)
    return SDValue();
  if (N01C->getZExtValue() != 8 || N11C->getZExtValue() != 8)
    return SDValue();

  // Look for (shl (and a, 0xff), 8), (srl (and a, 0xff00), 8)
  SDValue N00 = N0->getOperand(0);
  if (!LookPassAnd0 && N00.getOpcode() == ISD::AND) {
    if (!N00.getNode()->hasOneUse())
      return SDValue();
    ConstantSDNode *N001C = dyn_cast<ConstantSDNode>(N00.getOperand(1));
    if (!N001C || N001C->getZExtValue() != 0xFF)
      return SDValue();
    N00 = N00.getOperand(0);
    LookPassAnd0 = true;
  }

  SDValue N10 = N1->getOperand(0);
  if (!LookPassAnd1 && N10.getOpcode() == ISD::AND) {
    if (!N10.getNode()->hasOneUse())
      return SDValue();
    ConstantSDNode *N101C = dyn_cast<ConstantSDNode>(N10.getOperand(1));
    // Also allow 0xFFFF since the bits will be shifted out. This is needed
    // for X86.
    if (!N101C || (N101C->getZExtValue() != 0xFF00 &&
                   N101C->getZExtValue() != 0xFFFF))
      return SDValue();
    N10 = N10.getOperand(0);
    LookPassAnd1 = true;
  }

  if (N00 != N10)
    return SDValue();

  // Make sure everything beyond the low halfword gets set to zero since the SRL
  // 16 will clear the top bits.
  unsigned OpSizeInBits = VT.getSizeInBits();
  if (DemandHighBits && OpSizeInBits > 16) {
    // If the left-shift isn't masked out then the only way this is a bswap is
    // if all bits beyond the low 8 are 0. In that case the entire pattern
    // reduces to a left shift anyway: leave it for other parts of the combiner.
    if (!LookPassAnd0)
      return SDValue();

    // However, if the right shift isn't masked out then it might be because
    // it's not needed. See if we can spot that too.
    if (!LookPassAnd1 &&
        !DAG.MaskedValueIsZero(
            N10, APInt::getHighBitsSet(OpSizeInBits, OpSizeInBits - 16)))
      return SDValue();
  }

  SDValue Res = DAG.getNode(ISD::BSWAP, SDLoc(N), VT, N00);
  if (OpSizeInBits > 16) {
    SDLoc DL(N);
    Res = DAG.getNode(ISD::SRL, DL, VT, Res,
                      DAG.getConstant(OpSizeInBits - 16, DL,
                                      getShiftAmountTy(VT)));
  }
  return Res;
}

/// Return true if the specified node is an element that makes up a 32-bit
/// packed halfword byteswap.
/// ((x & 0x000000ff) << 8) |
/// ((x & 0x0000ff00) >> 8) |
/// ((x & 0x00ff0000) << 8) |
/// ((x & 0xff000000) >> 8)
static bool isBSwapHWordElement(SDValue N, MutableArrayRef<SDNode *> Parts) {
  if (!N.getNode()->hasOneUse())
    return false;

  unsigned Opc = N.getOpcode();
  if (Opc != ISD::AND && Opc != ISD::SHL && Opc != ISD::SRL)
    return false;

  SDValue N0 = N.getOperand(0);
  unsigned Opc0 = N0.getOpcode();
  if (Opc0 != ISD::AND && Opc0 != ISD::SHL && Opc0 != ISD::SRL)
    return false;

  ConstantSDNode *N1C = nullptr;
  // SHL or SRL: look upstream for AND mask operand
  if (Opc == ISD::AND)
    N1C = dyn_cast<ConstantSDNode>(N.getOperand(1));
  else if (Opc0 == ISD::AND)
    N1C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
  if (!N1C)
    return false;

  unsigned MaskByteOffset;
  switch (N1C->getZExtValue()) {
  default:
    return false;
  case 0xFF:       MaskByteOffset = 0; break;
  case 0xFF00:     MaskByteOffset = 1; break;
  case 0xFFFF:
    // In case demanded bits didn't clear the bits that will be shifted out.
    // This is needed for X86.
    if (Opc == ISD::SRL || (Opc == ISD::AND && Opc0 == ISD::SHL)) {
      MaskByteOffset = 1;
      break;
    }
    return false;
  case 0xFF0000:   MaskByteOffset = 2; break;
  case 0xFF000000: MaskByteOffset = 3; break;
  }

  // Look for (x & 0xff) << 8 as well as ((x << 8) & 0xff00).
  if (Opc == ISD::AND) {
    if (MaskByteOffset == 0 || MaskByteOffset == 2) {
      // (x >> 8) & 0xff
      // (x >> 8) & 0xff0000
      if (Opc0 != ISD::SRL)
        return false;
      ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
      if (!C || C->getZExtValue() != 8)
        return false;
    } else {
      // (x << 8) & 0xff00
      // (x << 8) & 0xff000000
      if (Opc0 != ISD::SHL)
        return false;
      ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
      if (!C || C->getZExtValue() != 8)
        return false;
    }
  } else if (Opc == ISD::SHL) {
    // (x & 0xff) << 8
    // (x & 0xff0000) << 8
    if (MaskByteOffset != 0 && MaskByteOffset != 2)
      return false;
    ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1));
    if (!C || C->getZExtValue() != 8)
      return false;
  } else { // Opc == ISD::SRL
    // (x & 0xff00) >> 8
    // (x & 0xff000000) >> 8
    if (MaskByteOffset != 1 && MaskByteOffset != 3)
      return false;
    ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1));
    if (!C || C->getZExtValue() != 8)
      return false;
  }

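  // Each byte position may be claimed by exactly one element of the pattern;
  // a second claim means this is not a clean halfword byteswap.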
  if (Parts[MaskByteOffset])
    return false;

  Parts[MaskByteOffset] = N0.getOperand(0).getNode();
  return true;
}

// Match 2 elements of a packed halfword bswap.
static bool isBSwapHWordPair(SDValue N, MutableArrayRef<SDNode *> Parts) {
  if (N.getOpcode() == ISD::OR)
    return isBSwapHWordElement(N.getOperand(0), Parts) &&
           isBSwapHWordElement(N.getOperand(1), Parts);

  if (N.getOpcode() == ISD::SRL && N.getOperand(0).getOpcode() == ISD::BSWAP) {
    ConstantSDNode *C = isConstOrConstSplat(N.getOperand(1));
    if (!C || C->getAPIntValue() != 16)
      return false;
    Parts[0] = Parts[1] = N.getOperand(0).getOperand(0).getNode();
    return true;
  }

  return false;
}

/// Match a 32-bit packed halfword bswap. That is
/// ((x & 0x000000ff) << 8) |
/// ((x & 0x0000ff00) >> 8) |
/// ((x & 0x00ff0000) << 8) |
/// ((x & 0xff000000) >> 8)
/// => (rotl (bswap x), 16)
SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
  if (!LegalOperations)
    return SDValue();

  EVT VT = N->getValueType(0);
  if (VT != MVT::i32)
    return SDValue();
  if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
    return SDValue();

  // Look for either
  // (or (bswaphpair), (bswaphpair))
  // (or (or (bswaphpair), (and)), (and))
  // (or (or (and), (bswaphpair)), (and))
  SDNode *Parts[4] = {};
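  // Parts[i] is filled in with the node that supplies byte i of the result;
  // all four slots must end up pointing at the same source node.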

  if (isBSwapHWordPair(N0, Parts)) {
    // (or (or (and), (and)), (or (and), (and)))
    if (!isBSwapHWordPair(N1, Parts))
      return SDValue();
  } else if (N0.getOpcode() == ISD::OR) {
    // (or (or (or (and), (and)), (and)), (and))
    if (!isBSwapHWordElement(N1, Parts))
      return SDValue();
    SDValue N00 = N0.getOperand(0);
    SDValue N01 = N0.getOperand(1);
    if (!(isBSwapHWordElement(N01, Parts) && isBSwapHWordPair(N00, Parts)) &&
        !(isBSwapHWordElement(N00, Parts) && isBSwapHWordPair(N01, Parts)))
      return SDValue();
  } else
    return SDValue();

  // Make sure the parts are all coming from the same node.
  if (Parts[0] != Parts[1] || Parts[0] != Parts[2] || Parts[0] != Parts[3])
    return SDValue();

  SDLoc DL(N);
  SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT,
                              SDValue(Parts[0], 0));

  // Result of the bswap should be rotated by 16. If it's not legal, then
  // do (x << 16) | (x >> 16).
  SDValue ShAmt = DAG.getConstant(16, DL, getShiftAmountTy(VT));
  if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT))
    return DAG.getNode(ISD::ROTL, DL, VT, BSwap, ShAmt);
  if (TLI.isOperationLegalOrCustom(ISD::ROTR, VT))
    return DAG.getNode(ISD::ROTR, DL, VT, BSwap, ShAmt);
  return DAG.getNode(ISD::OR, DL, VT,
                     DAG.getNode(ISD::SHL, DL, VT, BSwap, ShAmt),
                     DAG.getNode(ISD::SRL, DL, VT, BSwap, ShAmt));
}

/// This contains all DAGCombine rules which reduce two values combined by
/// an Or operation to a single value \see visitANDLike().
SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, SDNode *N) {
  EVT VT = N1.getValueType();
  SDLoc DL(N);

  // fold (or x, undef) -> -1
  if (!LegalOperations && (N0.isUndef() || N1.isUndef()))
    return DAG.getAllOnesConstant(DL, VT);

  if (SDValue V = foldLogicOfSetCCs(false, N0, N1, DL))
    return V;

  // (or (and X, C1), (and Y, C2))  -> (and (or X, Y), C3) if possible.
  if (N0.getOpcode() == ISD::AND && N1.getOpcode() == ISD::AND &&
      // Don't increase # computations.
      (N0.getNode()->hasOneUse() || N1.getNode()->hasOneUse())) {
    // We can only do this xform if we know that bits from X that are set in C2
    // but not in C1 are already zero.  Likewise for Y.
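    // For example, (or (and X, 0xFF00), (and Y, 0x00FF)) becomes
    // (and (or X, Y), 0xFFFF) when X's low byte and Y's second byte are
    // already known to be zero.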
    if (const ConstantSDNode *N0O1C =
        getAsNonOpaqueConstant(N0.getOperand(1))) {
      if (const ConstantSDNode *N1O1C =
          getAsNonOpaqueConstant(N1.getOperand(1))) {
        // We can only do this xform if we know that bits from X that are set in
        // C2 but not in C1 are already zero.  Likewise for Y.
        const APInt &LHSMask = N0O1C->getAPIntValue();
        const APInt &RHSMask = N1O1C->getAPIntValue();

        if (DAG.MaskedValueIsZero(N0.getOperand(0), RHSMask&~LHSMask) &&
            DAG.MaskedValueIsZero(N1.getOperand(0), LHSMask&~RHSMask)) {
          SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
                                  N0.getOperand(0), N1.getOperand(0));
          return DAG.getNode(ISD::AND, DL, VT, X,
                             DAG.getConstant(LHSMask | RHSMask, DL, VT));
        }
      }
    }
  }

  // (or (and X, M), (and X, N)) -> (and X, (or M, N))
  if (N0.getOpcode() == ISD::AND &&
      N1.getOpcode() == ISD::AND &&
      N0.getOperand(0) == N1.getOperand(0) &&
      // Don't increase # computations.
      (N0.getNode()->hasOneUse() || N1.getNode()->hasOneUse())) {
    SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
                            N0.getOperand(1), N1.getOperand(1));
    return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), X);
  }

  return SDValue();
}

/// OR combines for which the commuted variant will be tried as well.
static SDValue visitORCommutative(
    SelectionDAG &DAG, SDValue N0, SDValue N1, SDNode *N) {
  EVT VT = N0.getValueType();
  if (N0.getOpcode() == ISD::AND) {
    // fold (or (and X, (xor Y, -1)), Y) -> (or X, Y)
    if (isBitwiseNot(N0.getOperand(1)) && N0.getOperand(1).getOperand(0) == N1)
      return DAG.getNode(ISD::OR, SDLoc(N), VT, N0.getOperand(0), N1);

    // fold (or (and (xor Y, -1), X), Y) -> (or X, Y)
    if (isBitwiseNot(N0.getOperand(0)) && N0.getOperand(0).getOperand(0) == N1)
      return DAG.getNode(ISD::OR, SDLoc(N), VT, N0.getOperand(1), N1);
  }

  return SDValue();
}

SDValue DAGCombiner::visitOR(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N1.getValueType();

  // x | x --> x
  if (N0 == N1)
    return N0;

  // fold vector ops
  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N))
      return FoldedVOp;

    // fold (or x, 0) -> x, vector edition
    if (ISD::isBuildVectorAllZeros(N0.getNode()))
      return N1;
    if (ISD::isBuildVectorAllZeros(N1.getNode()))
      return N0;

    // fold (or x, -1) -> -1, vector edition
    if (ISD::isBuildVectorAllOnes(N0.getNode()))
      // do not return N0, because undef node may exist in N0
      return DAG.getAllOnesConstant(SDLoc(N), N0.getValueType());
    if (ISD::isBuildVectorAllOnes(N1.getNode()))
      // do not return N1, because undef node may exist in N1
      return DAG.getAllOnesConstant(SDLoc(N), N1.getValueType());

    // fold (or (shuf A, V_0, MA), (shuf B, V_0, MB)) -> (shuf A, B, Mask)
    // Do this only if the resulting shuffle is legal.
    if (isa<ShuffleVectorSDNode>(N0) &&
        isa<ShuffleVectorSDNode>(N1) &&
        // Avoid folding a node with illegal type.
        TLI.isTypeLegal(VT)) {
      bool ZeroN00 = ISD::isBuildVectorAllZeros(N0.getOperand(0).getNode());
      bool ZeroN01 = ISD::isBuildVectorAllZeros(N0.getOperand(1).getNode());
      bool ZeroN10 = ISD::isBuildVectorAllZeros(N1.getOperand(0).getNode());
      bool ZeroN11 = ISD::isBuildVectorAllZeros(N1.getOperand(1).getNode());
      // Ensure both shuffles have a zero input.
      if ((ZeroN00 != ZeroN01) && (ZeroN10 != ZeroN11)) {
        assert((!ZeroN00 || !ZeroN01) && "Both inputs zero!");
        assert((!ZeroN10 || !ZeroN11) && "Both inputs zero!");
        const ShuffleVectorSDNode *SV0 = cast<ShuffleVectorSDNode>(N0);
        const ShuffleVectorSDNode *SV1 = cast<ShuffleVectorSDNode>(N1);
        bool CanFold = true;
        int NumElts = VT.getVectorNumElements();
        SmallVector<int, 4> Mask(NumElts);

        for (int i = 0; i != NumElts; ++i) {
          int M0 = SV0->getMaskElt(i);
          int M1 = SV1->getMaskElt(i);

          // Determine if either index is pointing to a zero vector.
          bool M0Zero = M0 < 0 || (ZeroN00 == (M0 < NumElts));
          bool M1Zero = M1 < 0 || (ZeroN10 == (M1 < NumElts));

          // If one element is zero and the other side is undef, keep undef.
          // This also handles the case that both are undef.
          if ((M0Zero && M1 < 0) || (M1Zero && M0 < 0)) {
            Mask[i] = -1;
            continue;
          }

          // Make sure only one of the elements is zero.
          if (M0Zero == M1Zero) {
            CanFold = false;
            break;
          }

          assert((M0 >= 0 || M1 >= 0) && "Undef index!");

          // We have a zero and non-zero element. If the non-zero came from
          // SV0 make the index a LHS index. If it came from SV1, make it
          // a RHS index. We need to mod by NumElts because we don't care
          // which operand it came from in the original shuffles.
          Mask[i] = M1Zero ? M0 % NumElts : (M1 % NumElts) + NumElts;
        }

        if (CanFold) {
          SDValue NewLHS = ZeroN00 ? N0.getOperand(1) : N0.getOperand(0);
          SDValue NewRHS = ZeroN10 ? N1.getOperand(1) : N1.getOperand(0);

          SDValue LegalShuffle =
              TLI.buildLegalVectorShuffle(VT, SDLoc(N), NewLHS, NewRHS,
                                          Mask, DAG);
          if (LegalShuffle)
            return LegalShuffle;
        }
      }
    }
  }

  // fold (or c1, c2) -> c1|c2
  ConstantSDNode *N0C = getAsNonOpaqueConstant(N0);
  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
  if (N0C && N1C && !N1C->isOpaque())
    return DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N), VT, N0C, N1C);
  // canonicalize constant to RHS
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
     !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(ISD::OR, SDLoc(N), VT, N1, N0);
  // fold (or x, 0) -> x
  if (isNullConstant(N1))
    return N0;
  // fold (or x, -1) -> -1
  if (isAllOnesConstant(N1))
    return N1;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // fold (or x, c) -> c iff (x & ~c) == 0
  if (N1C && DAG.MaskedValueIsZero(N0, ~N1C->getAPIntValue()))
    return N1;

  if (SDValue Combined = visitORLike(N0, N1, N))
    return Combined;

  if (SDValue Combined = combineCarryDiamond(*this, DAG, TLI, N0, N1, N))
    return Combined;

  // Recognize halfword bswaps as (bswap + rotl 16) or (bswap + shl 16)
  if (SDValue BSwap = MatchBSwapHWord(N, N0, N1))
    return BSwap;
  if (SDValue BSwap = MatchBSwapHWordLow(N, N0, N1))
    return BSwap;

  // reassociate or
  if (SDValue ROR = reassociateOps(ISD::OR, SDLoc(N), N0, N1, N->getFlags()))
    return ROR;

  // Canonicalize (or (and X, c1), c2) -> (and (or X, c2), c1|c2)
  // iff (c1 & c2) != 0 or c1/c2 are undef.
  auto MatchIntersect = [](ConstantSDNode *C1, ConstantSDNode *C2) {
    return !C1 || !C2 || C1->getAPIntValue().intersects(C2->getAPIntValue());
  };
  if (N0.getOpcode() == ISD::AND && N0.getNode()->hasOneUse() &&
      ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchIntersect, true)) {
    if (SDValue COR = DAG.FoldConstantArithmetic(
            ISD::OR, SDLoc(N1), VT, N1.getNode(), N0.getOperand(1).getNode())) {
      SDValue IOR = DAG.getNode(ISD::OR, SDLoc(N0), VT, N0.getOperand(0), N1);
      AddToWorklist(IOR.getNode());
      return DAG.getNode(ISD::AND, SDLoc(N), VT, COR, IOR);
    }
  }

  if (SDValue Combined = visitORCommutative(DAG, N0, N1, N))
    return Combined;
  if (SDValue Combined = visitORCommutative(DAG, N1, N0, N))
    return Combined;

  // Simplify: (or (op x...), (op y...))  -> (op (or x, y))
  if (N0.getOpcode() == N1.getOpcode())
    if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
      return V;

  // See if this is some rotate idiom.
  if (SDValue Rot = MatchRotate(N0, N1, SDLoc(N)))
    return Rot;

  if (SDValue Load = MatchLoadCombine(N))
    return Load;

  // Simplify the operands using demanded-bits information.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  // If OR can be rewritten into ADD, try combines based on ADD.
  if ((!LegalOperations || TLI.isOperationLegal(ISD::ADD, VT)) &&
      DAG.haveNoCommonBitsSet(N0, N1))
    if (SDValue Combined = visitADDLike(N))
      return Combined;

  return SDValue();
}

static SDValue stripConstantMask(SelectionDAG &DAG, SDValue Op, SDValue &Mask) {
  if (Op.getOpcode() == ISD::AND &&
      DAG.isConstantIntBuildVectorOrConstantInt(Op.getOperand(1))) {
    Mask = Op.getOperand(1);
    return Op.getOperand(0);
  }
  return Op;
}

/// Match "(X shl/srl V1) & V2" where V2 may not be present.
static bool matchRotateHalf(SelectionDAG &DAG, SDValue Op, SDValue &Shift,
                            SDValue &Mask) {
  Op = stripConstantMask(DAG, Op, Mask);
  if (Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SHL) {
    Shift = Op;
    return true;
  }
  return false;
}

/// Helper function for visitOR to extract the needed side of a rotate idiom
/// from a shl/srl/mul/udiv.  This is meant to handle cases where
/// InstCombine merged some outside op with one of the shifts from
/// the rotate pattern.
/// \returns An empty \c SDValue if the needed shift couldn't be extracted.
/// Otherwise, returns an expansion of \p ExtractFrom based on the following
/// patterns:
///
///   (or (add v v) (shrl v bitwidth-1)):
///     expands (add v v) -> (shl v 1)
///
///   (or (mul v c0) (shrl (mul v c1) c2)):
///     expands (mul v c0) -> (shl (mul v c1) c3)
///
///   (or (udiv v c0) (shl (udiv v c1) c2)):
///     expands (udiv v c0) -> (shrl (udiv v c1) c3)
///
///   (or (shl v c0) (shrl (shl v c1) c2)):
///     expands (shl v c0) -> (shl (shl v c1) c3)
///
///   (or (shrl v c0) (shl (shrl v c1) c2)):
///     expands (shrl v c0) -> (shrl (shrl v c1) c3)
///
/// Such that in all cases, c3+c2==bitwidth(op v c1).
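///
/// For example, with i32 values, (or (mul v 4), (shrl (mul v 2) 31)) expands
/// (mul v 4) into (shl (mul v 2) 1): here c2 == 31, so c3 == 1, and indeed
/// c0 == c1 << c3.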
extractShiftForRotate(SelectionDAG & DAG,SDValue OppShift,SDValue ExtractFrom,SDValue & Mask,const SDLoc & DL)6012 static SDValue extractShiftForRotate(SelectionDAG &DAG, SDValue OppShift,
6013                                      SDValue ExtractFrom, SDValue &Mask,
6014                                      const SDLoc &DL) {
6015   assert(OppShift && ExtractFrom && "Empty SDValue");
6016   assert(
6017       (OppShift.getOpcode() == ISD::SHL || OppShift.getOpcode() == ISD::SRL) &&
6018       "Existing shift must be valid as a rotate half");
6019 
6020   ExtractFrom = stripConstantMask(DAG, ExtractFrom, Mask);
6021 
6022   // Value and Type of the shift.
6023   SDValue OppShiftLHS = OppShift.getOperand(0);
6024   EVT ShiftedVT = OppShiftLHS.getValueType();
6025 
6026   // Amount of the existing shift.
6027   ConstantSDNode *OppShiftCst = isConstOrConstSplat(OppShift.getOperand(1));
6028 
6029   // (add v v) -> (shl v 1)
6030   if (OppShift.getOpcode() == ISD::SRL && OppShiftCst &&
6031       ExtractFrom.getOpcode() == ISD::ADD &&
6032       ExtractFrom.getOperand(0) == ExtractFrom.getOperand(1) &&
6033       ExtractFrom.getOperand(0) == OppShiftLHS &&
6034       OppShiftCst->getAPIntValue() == ShiftedVT.getScalarSizeInBits() - 1)
6035     return DAG.getNode(ISD::SHL, DL, ShiftedVT, OppShiftLHS,
6036                        DAG.getShiftAmountConstant(1, ShiftedVT, DL));
6037 
6038   // Preconditions:
6039   //    (or (op0 v c0) (shiftl/r (op0 v c1) c2))
6040   //
6041   // Find opcode of the needed shift to be extracted from (op0 v c0).
6042   unsigned Opcode = ISD::DELETED_NODE;
6043   bool IsMulOrDiv = false;
6044   // Set Opcode and IsMulOrDiv if the extract opcode matches the needed shift
6045   // opcode or its arithmetic (mul or udiv) variant.
6046   auto SelectOpcode = [&](unsigned NeededShift, unsigned MulOrDivVariant) {
6047     IsMulOrDiv = ExtractFrom.getOpcode() == MulOrDivVariant;
6048     if (!IsMulOrDiv && ExtractFrom.getOpcode() != NeededShift)
6049       return false;
6050     Opcode = NeededShift;
6051     return true;
6052   };
6053   // op0 must be either the needed shift opcode or the mul/udiv equivalent
6054   // that the needed shift can be extracted from.
6055   if ((OppShift.getOpcode() != ISD::SRL || !SelectOpcode(ISD::SHL, ISD::MUL)) &&
6056       (OppShift.getOpcode() != ISD::SHL || !SelectOpcode(ISD::SRL, ISD::UDIV)))
6057     return SDValue();
6058 
6059   // op0 must be the same opcode on both sides, have the same LHS argument,
6060   // and produce the same value type.
6061   if (OppShiftLHS.getOpcode() != ExtractFrom.getOpcode() ||
6062       OppShiftLHS.getOperand(0) != ExtractFrom.getOperand(0) ||
6063       ShiftedVT != ExtractFrom.getValueType())
6064     return SDValue();
6065 
6066   // Constant mul/udiv/shift amount from the RHS of the shift's LHS op.
6067   ConstantSDNode *OppLHSCst = isConstOrConstSplat(OppShiftLHS.getOperand(1));
6068   // Constant mul/udiv/shift amount from the RHS of the ExtractFrom op.
6069   ConstantSDNode *ExtractFromCst =
6070       isConstOrConstSplat(ExtractFrom.getOperand(1));
6071   // TODO: We should be able to handle non-uniform constant vectors for these values
6072   // Check that we have constant values.
6073   if (!OppShiftCst || !OppShiftCst->getAPIntValue() ||
6074       !OppLHSCst || !OppLHSCst->getAPIntValue() ||
6075       !ExtractFromCst || !ExtractFromCst->getAPIntValue())
6076     return SDValue();
6077 
6078   // Compute the shift amount we need to extract to complete the rotate.
6079   const unsigned VTWidth = ShiftedVT.getScalarSizeInBits();
6080   if (OppShiftCst->getAPIntValue().ugt(VTWidth))
6081     return SDValue();
6082   APInt NeededShiftAmt = VTWidth - OppShiftCst->getAPIntValue();
6083   // Normalize the bitwidth of the two mul/udiv/shift constant operands.
6084   APInt ExtractFromAmt = ExtractFromCst->getAPIntValue();
6085   APInt OppLHSAmt = OppLHSCst->getAPIntValue();
6086   zeroExtendToMatch(ExtractFromAmt, OppLHSAmt);
6087 
6088   // Now try extract the needed shift from the ExtractFrom op and see if the
6089   // result matches up with the existing shift's LHS op.
6090   if (IsMulOrDiv) {
6091     // Op to extract from is a mul or udiv by a constant.
6092     // Check:
6093     //     c2 / (1 << (bitwidth(op0 v c0) - c1)) == c0
6094     //     c2 % (1 << (bitwidth(op0 v c0) - c1)) == 0
6095     const APInt ExtractDiv = APInt::getOneBitSet(ExtractFromAmt.getBitWidth(),
6096                                                  NeededShiftAmt.getZExtValue());
6097     APInt ResultAmt;
6098     APInt Rem;
6099     APInt::udivrem(ExtractFromAmt, ExtractDiv, ResultAmt, Rem);
6100     if (Rem != 0 || ResultAmt != OppLHSAmt)
6101       return SDValue();
6102   } else {
6103     // Op to extract from is a shift by a constant.
6104     // Check:
6105     //      c2 - (bitwidth(op0 v c0) - c1) == c0
6106     if (OppLHSAmt != ExtractFromAmt - NeededShiftAmt.zextOrTrunc(
6107                                           ExtractFromAmt.getBitWidth()))
6108       return SDValue();
6109   }
6110 
6111   // Return the expanded shift op that should allow a rotate to be formed.
6112   EVT ShiftVT = OppShift.getOperand(1).getValueType();
6113   EVT ResVT = ExtractFrom.getValueType();
6114   SDValue NewShiftNode = DAG.getConstant(NeededShiftAmt, DL, ShiftVT);
6115   return DAG.getNode(Opcode, DL, ResVT, OppShiftLHS, NewShiftNode);
6116 }
6117 
6118 // Return true if we can prove that, whenever Neg and Pos are both in the
6119 // range [0, EltSize), Neg == (Pos == 0 ? 0 : EltSize - Pos).  This means that
6120 // for two opposing shifts shift1 and shift2 and a value X with OpBits bits:
6121 //
6122 //     (or (shift1 X, Neg), (shift2 X, Pos))
6123 //
6124 // reduces to a rotate in direction shift2 by Pos or (equivalently) a rotate
6125 // in direction shift1 by Neg.  The range [0, EltSize) means that we only need
6126 // to consider shift amounts with defined behavior.
matchRotateSub(SDValue Pos,SDValue Neg,unsigned EltSize,SelectionDAG & DAG)6127 static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
6128                            SelectionDAG &DAG) {
6129   // If EltSize is a power of 2 then:
6130   //
6131   //  (a) (Pos == 0 ? 0 : EltSize - Pos) == (EltSize - Pos) & (EltSize - 1)
6132   //  (b) Neg == Neg & (EltSize - 1) whenever Neg is in [0, EltSize).
6133   //
6134   // So if EltSize is a power of 2 and Neg is (and Neg', EltSize-1), we check
6135   // for the stronger condition:
6136   //
6137   //     Neg & (EltSize - 1) == (EltSize - Pos) & (EltSize - 1)    [A]
6138   //
6139   // for all Neg and Pos.  Since Neg & (EltSize - 1) == Neg' & (EltSize - 1)
6140   // we can just replace Neg with Neg' for the rest of the function.
6141   //
6142   // In other cases we check for the even stronger condition:
6143   //
6144   //     Neg == EltSize - Pos                                    [B]
6145   //
6146   // for all Neg and Pos.  Note that the (or ...) then invokes undefined
6147   // behavior if Pos == 0 (and consequently Neg == EltSize).
6148   //
6149   // We could actually use [A] whenever EltSize is a power of 2, but the
6150   // only extra cases that it would match are those uninteresting ones
6151   // where Neg and Pos are never in range at the same time.  E.g. for
6152   // EltSize == 32, using [A] would allow a Neg of the form (sub 64, Pos)
6153   // as well as (sub 32, Pos), but:
6154   //
6155   //     (or (shift1 X, (sub 64, Pos)), (shift2 X, Pos))
6156   //
6157   // always invokes undefined behavior for 32-bit X.
6158   //
6159   // Below, Mask == EltSize - 1 when using [A] and is all-ones otherwise.
6160   unsigned MaskLoBits = 0;
6161   if (Neg.getOpcode() == ISD::AND && isPowerOf2_64(EltSize)) {
6162     if (ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(1))) {
6163       KnownBits Known = DAG.computeKnownBits(Neg.getOperand(0));
6164       unsigned Bits = Log2_64(EltSize);
6165       if (NegC->getAPIntValue().getActiveBits() <= Bits &&
6166           ((NegC->getAPIntValue() | Known.Zero).countTrailingOnes() >= Bits)) {
6167         Neg = Neg.getOperand(0);
6168         MaskLoBits = Bits;
6169       }
6170     }
6171   }
6172 
6173   // Check whether Neg has the form (sub NegC, NegOp1) for some NegC and NegOp1.
6174   if (Neg.getOpcode() != ISD::SUB)
6175     return false;
6176   ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(0));
6177   if (!NegC)
6178     return false;
6179   SDValue NegOp1 = Neg.getOperand(1);
6180 
6181   // On the RHS of [A], if Pos is Pos' & (EltSize - 1), just replace Pos with
6182   // Pos'.  The truncation is redundant for the purpose of the equality.
6183   if (MaskLoBits && Pos.getOpcode() == ISD::AND) {
6184     if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1))) {
6185       KnownBits Known = DAG.computeKnownBits(Pos.getOperand(0));
6186       if (PosC->getAPIntValue().getActiveBits() <= MaskLoBits &&
6187           ((PosC->getAPIntValue() | Known.Zero).countTrailingOnes() >=
6188            MaskLoBits))
6189         Pos = Pos.getOperand(0);
6190     }
6191   }
6192 
6193   // The condition we need is now:
6194   //
6195   //     (NegC - NegOp1) & Mask == (EltSize - Pos) & Mask
6196   //
6197   // If NegOp1 == Pos then we need:
6198   //
6199   //              EltSize & Mask == NegC & Mask
6200   //
6201   // (because "x & Mask" is a truncation and distributes through subtraction).
6202   APInt Width;
6203   if (Pos == NegOp1)
6204     Width = NegC->getAPIntValue();
6205 
6206   // Check for cases where Pos has the form (add NegOp1, PosC) for some PosC.
6207   // Then the condition we want to prove becomes:
6208   //
6209   //     (NegC - NegOp1) & Mask == (EltSize - (NegOp1 + PosC)) & Mask
6210   //
6211   // which, again because "x & Mask" is a truncation, becomes:
6212   //
6213   //                NegC & Mask == (EltSize - PosC) & Mask
6214   //             EltSize & Mask == (NegC + PosC) & Mask
6215   else if (Pos.getOpcode() == ISD::ADD && Pos.getOperand(0) == NegOp1) {
6216     if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1)))
6217       Width = PosC->getAPIntValue() + NegC->getAPIntValue();
6218     else
6219       return false;
6220   } else
6221     return false;
6222 
6223   // Now we just need to check that EltSize & Mask == Width & Mask.
6224   if (MaskLoBits)
6225     // EltSize & Mask is 0 since Mask is EltSize - 1.
6226     return Width.getLoBits(MaskLoBits) == 0;
6227   return Width == EltSize;
6228 }
6229 
6230 // A subroutine of MatchRotate used once we have found an OR of two opposite
6231 // shifts of Shifted.  If Neg == <operand size> - Pos then the OR reduces
6232 // to both (PosOpcode Shifted, Pos) and (NegOpcode Shifted, Neg), with the
6233 // former being preferred if supported.  InnerPos and InnerNeg are Pos and
6234 // Neg with outer conversions stripped away.
MatchRotatePosNeg(SDValue Shifted,SDValue Pos,SDValue Neg,SDValue InnerPos,SDValue InnerNeg,unsigned PosOpcode,unsigned NegOpcode,const SDLoc & DL)6235 SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
6236                                        SDValue Neg, SDValue InnerPos,
6237                                        SDValue InnerNeg, unsigned PosOpcode,
6238                                        unsigned NegOpcode, const SDLoc &DL) {
6239   // fold (or (shl x, (*ext y)),
6240   //          (srl x, (*ext (sub 32, y)))) ->
6241   //   (rotl x, y) or (rotr x, (sub 32, y))
6242   //
6243   // fold (or (shl x, (*ext (sub 32, y))),
6244   //          (srl x, (*ext y))) ->
6245   //   (rotr x, y) or (rotl x, (sub 32, y))
6246   EVT VT = Shifted.getValueType();
6247   if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits(), DAG)) {
6248     bool HasPos = TLI.isOperationLegalOrCustom(PosOpcode, VT);
6249     return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, Shifted,
6250                        HasPos ? Pos : Neg);
6251   }
6252 
6253   return SDValue();
6254 }
6255 
6256 // MatchRotate - Handle an 'or' of two operands.  If this is one of the many
6257 // idioms for rotate, and if the target supports rotation instructions, generate
6258 // a rot[lr].
MatchRotate(SDValue LHS,SDValue RHS,const SDLoc & DL)6259 SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
6260   // Must be a legal type.  Expanded 'n promoted things won't work with rotates.
6261   EVT VT = LHS.getValueType();
6262   if (!TLI.isTypeLegal(VT))
6263     return SDValue();
6264 
6265   // The target must have at least one rotate flavor.
6266   bool HasROTL = hasOperation(ISD::ROTL, VT);
6267   bool HasROTR = hasOperation(ISD::ROTR, VT);
6268   if (!HasROTL && !HasROTR)
6269     return SDValue();
6270 
6271   // Check for truncated rotate.
6272   if (LHS.getOpcode() == ISD::TRUNCATE && RHS.getOpcode() == ISD::TRUNCATE &&
6273       LHS.getOperand(0).getValueType() == RHS.getOperand(0).getValueType()) {
6274     assert(LHS.getValueType() == RHS.getValueType());
6275     if (SDValue Rot = MatchRotate(LHS.getOperand(0), RHS.getOperand(0), DL)) {
6276       return DAG.getNode(ISD::TRUNCATE, SDLoc(LHS), LHS.getValueType(), Rot);
6277     }
6278   }
6279 
6280   // Match "(X shl/srl V1) & V2" where V2 may not be present.
6281   SDValue LHSShift;   // The shift.
6282   SDValue LHSMask;    // AND value if any.
6283   matchRotateHalf(DAG, LHS, LHSShift, LHSMask);
6284 
6285   SDValue RHSShift;   // The shift.
6286   SDValue RHSMask;    // AND value if any.
6287   matchRotateHalf(DAG, RHS, RHSShift, RHSMask);
6288 
6289   // If neither side matched a rotate half, bail
6290   if (!LHSShift && !RHSShift)
6291     return SDValue();
6292 
6293   // InstCombine may have combined a constant shl, srl, mul, or udiv with one
6294   // side of the rotate, so try to handle that here. In all cases we need to
6295   // pass the matched shift from the opposite side to compute the opcode and
6296   // needed shift amount to extract.  We still want to do this if both sides
6297   // matched a rotate half because one half may be a potential overshift that
6298   // can be broken down (ie if InstCombine merged two shl or srl ops into a
6299   // single one).
6300 
6301   // Have LHS side of the rotate, try to extract the needed shift from the RHS.
6302   if (LHSShift)
6303     if (SDValue NewRHSShift =
6304             extractShiftForRotate(DAG, LHSShift, RHS, RHSMask, DL))
6305       RHSShift = NewRHSShift;
6306   // Have RHS side of the rotate, try to extract the needed shift from the LHS.
6307   if (RHSShift)
6308     if (SDValue NewLHSShift =
6309             extractShiftForRotate(DAG, RHSShift, LHS, LHSMask, DL))
6310       LHSShift = NewLHSShift;
6311 
6312   // If a side is still missing, nothing else we can do.
6313   if (!RHSShift || !LHSShift)
6314     return SDValue();
6315 
6316   // At this point we've matched or extracted a shift op on each side.
6317 
6318   if (LHSShift.getOperand(0) != RHSShift.getOperand(0))
6319     return SDValue(); // Not shifting the same value.
6320 
6321   if (LHSShift.getOpcode() == RHSShift.getOpcode())
6322     return SDValue(); // Shifts must disagree.
6323 
6324   // Canonicalize shl to left side in a shl/srl pair.
6325   if (RHSShift.getOpcode() == ISD::SHL) {
6326     std::swap(LHS, RHS);
6327     std::swap(LHSShift, RHSShift);
6328     std::swap(LHSMask, RHSMask);
6329   }
6330 
6331   unsigned EltSizeInBits = VT.getScalarSizeInBits();
6332   SDValue LHSShiftArg = LHSShift.getOperand(0);
6333   SDValue LHSShiftAmt = LHSShift.getOperand(1);
6334   SDValue RHSShiftArg = RHSShift.getOperand(0);
6335   SDValue RHSShiftAmt = RHSShift.getOperand(1);
6336 
6337   // fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1)
6338   // fold (or (shl x, C1), (srl x, C2)) -> (rotr x, C2)
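  // e.g. for i32: (or (shl x, 24), (srl x, 8)) -> (rotl x, 24) or (rotr x, 8),
  // since the two constant shift amounts sum to the element width.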
  auto MatchRotateSum = [EltSizeInBits](ConstantSDNode *LHS,
                                        ConstantSDNode *RHS) {
    return (LHS->getAPIntValue() + RHS->getAPIntValue()) == EltSizeInBits;
  };
  if (ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) {
    SDValue Rot = DAG.getNode(HasROTL ? ISD::ROTL : ISD::ROTR, DL, VT,
                              LHSShiftArg, HasROTL ? LHSShiftAmt : RHSShiftAmt);

    // If there is an AND of either shifted operand, apply it to the result.
    if (LHSMask.getNode() || RHSMask.getNode()) {
      SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
      SDValue Mask = AllOnes;

      if (LHSMask.getNode()) {
        SDValue RHSBits = DAG.getNode(ISD::SRL, DL, VT, AllOnes, RHSShiftAmt);
        Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
                           DAG.getNode(ISD::OR, DL, VT, LHSMask, RHSBits));
      }
      if (RHSMask.getNode()) {
        SDValue LHSBits = DAG.getNode(ISD::SHL, DL, VT, AllOnes, LHSShiftAmt);
        Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
                           DAG.getNode(ISD::OR, DL, VT, RHSMask, LHSBits));
      }

      Rot = DAG.getNode(ISD::AND, DL, VT, Rot, Mask);
    }

    return Rot;
  }

  // If there is a mask here, and we have a variable shift, we can't be sure
  // that we're masking out the right stuff.
  if (LHSMask.getNode() || RHSMask.getNode())
    return SDValue();

  // If the shift amount is sign/zext/any-extended just peel it off.
  SDValue LExtOp0 = LHSShiftAmt;
  SDValue RExtOp0 = RHSShiftAmt;
  if ((LHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
       LHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
       LHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
       LHSShiftAmt.getOpcode() == ISD::TRUNCATE) &&
      (RHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
       RHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
       RHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
       RHSShiftAmt.getOpcode() == ISD::TRUNCATE)) {
    LExtOp0 = LHSShiftAmt.getOperand(0);
    RExtOp0 = RHSShiftAmt.getOperand(0);
  }

  SDValue TryL = MatchRotatePosNeg(LHSShiftArg, LHSShiftAmt, RHSShiftAmt,
                                   LExtOp0, RExtOp0, ISD::ROTL, ISD::ROTR, DL);
  if (TryL)
    return TryL;

  SDValue TryR = MatchRotatePosNeg(RHSShiftArg, RHSShiftAmt, LHSShiftAmt,
                                   RExtOp0, LExtOp0, ISD::ROTR, ISD::ROTL, DL);
  if (TryR)
    return TryR;

  return SDValue();
}

namespace {

/// Represents the known origin of an individual byte in a load combine
/// pattern. The value of the byte is either constant zero or comes from memory.
struct ByteProvider {
  // For constant zero providers Load is set to nullptr. For memory providers
  // Load represents the node which loads the byte from memory.
  // ByteOffset is the offset of the byte in the value produced by the load.
  LoadSDNode *Load = nullptr;
  unsigned ByteOffset = 0;

  ByteProvider() = default;

  static ByteProvider getMemory(LoadSDNode *Load, unsigned ByteOffset) {
    return ByteProvider(Load, ByteOffset);
  }

  static ByteProvider getConstantZero() { return ByteProvider(nullptr, 0); }

  bool isConstantZero() const { return !Load; }
  bool isMemory() const { return Load; }

  bool operator==(const ByteProvider &Other) const {
    return Other.Load == Load && Other.ByteOffset == ByteOffset;
  }

private:
  ByteProvider(LoadSDNode *Load, unsigned ByteOffset)
      : Load(Load), ByteOffset(ByteOffset) {}
};

} // end anonymous namespace

/// Recursively traverses the expression calculating the origin of the requested
/// byte of the given value. Returns None if the provider can't be calculated.
///
/// For every value except the root of the expression, this verifies that the
/// value has exactly one use; if not, it returns None. That way, whenever the
/// origin of a byte is returned, it is guaranteed that the values which
/// contribute to the byte are not used outside of this expression.
///
/// Because the parts of the expression are not allowed to have more than one
/// use this function iterates over trees, not DAGs. So it never visits the same
/// node more than once.
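///
/// For example, byte 1 of (shl x, 8) is byte 0 of x, while byte 0 of the
/// shift is known to be constant zero.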
6446 static const Optional<ByteProvider>
calculateByteProvider(SDValue Op,unsigned Index,unsigned Depth,bool Root=false)6447 calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
6448                       bool Root = false) {
6449   // Typical i64 by i8 pattern requires recursion up to 8 calls depth
6450   if (Depth == 10)
6451     return None;
6452 
6453   if (!Root && !Op.hasOneUse())
6454     return None;
6455 
6456   assert(Op.getValueType().isScalarInteger() && "can't handle other types");
6457   unsigned BitWidth = Op.getValueSizeInBits();
6458   if (BitWidth % 8 != 0)
6459     return None;
6460   unsigned ByteWidth = BitWidth / 8;
6461   assert(Index < ByteWidth && "invalid index requested");
6462   (void) ByteWidth;
6463 
6464   switch (Op.getOpcode()) {
6465   case ISD::OR: {
6466     auto LHS = calculateByteProvider(Op->getOperand(0), Index, Depth + 1);
6467     if (!LHS)
6468       return None;
6469     auto RHS = calculateByteProvider(Op->getOperand(1), Index, Depth + 1);
6470     if (!RHS)
6471       return None;
6472 
6473     if (LHS->isConstantZero())
6474       return RHS;
6475     if (RHS->isConstantZero())
6476       return LHS;
6477     return None;
6478   }
6479   case ISD::SHL: {
6480     auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
6481     if (!ShiftOp)
6482       return None;
6483 
6484     uint64_t BitShift = ShiftOp->getZExtValue();
6485     if (BitShift % 8 != 0)
6486       return None;
6487     uint64_t ByteShift = BitShift / 8;
6488 
6489     return Index < ByteShift
6490                ? ByteProvider::getConstantZero()
6491                : calculateByteProvider(Op->getOperand(0), Index - ByteShift,
6492                                        Depth + 1);
6493   }
6494   case ISD::ANY_EXTEND:
6495   case ISD::SIGN_EXTEND:
6496   case ISD::ZERO_EXTEND: {
6497     SDValue NarrowOp = Op->getOperand(0);
6498     unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
6499     if (NarrowBitWidth % 8 != 0)
6500       return None;
6501     uint64_t NarrowByteWidth = NarrowBitWidth / 8;
6502 
6503     if (Index >= NarrowByteWidth)
6504       return Op.getOpcode() == ISD::ZERO_EXTEND
6505                  ? Optional<ByteProvider>(ByteProvider::getConstantZero())
6506                  : None;
6507     return calculateByteProvider(NarrowOp, Index, Depth + 1);
6508   }
6509   case ISD::BSWAP:
6510     return calculateByteProvider(Op->getOperand(0), ByteWidth - Index - 1,
6511                                  Depth + 1);
6512   case ISD::LOAD: {
6513     auto L = cast<LoadSDNode>(Op.getNode());
6514     if (!L->isSimple() || L->isIndexed())
6515       return None;
6516 
6517     unsigned NarrowBitWidth = L->getMemoryVT().getSizeInBits();
6518     if (NarrowBitWidth % 8 != 0)
6519       return None;
6520     uint64_t NarrowByteWidth = NarrowBitWidth / 8;
6521 
6522     if (Index >= NarrowByteWidth)
6523       return L->getExtensionType() == ISD::ZEXTLOAD
6524                  ? Optional<ByteProvider>(ByteProvider::getConstantZero())
6525                  : None;
6526     return ByteProvider::getMemory(L, Index);
6527   }
6528   }
6529 
6530   return None;
6531 }
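// Worked example (illustrative): for the i32 expression
//   (or (zero_extend (load i8)), (shl (zero_extend (load i8)), 8))
// a query for byte 1 descends the OR; the plain zext branch yields constant
// zero for that byte, while the SHL branch (ByteShift == 1) re-asks for byte 0
// of its operand and resolves to the second load, so the OR returns the memory
// provider.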
6532 
6533 static unsigned LittleEndianByteAt(unsigned BW, unsigned i) {
6534   return i;
6535 }
6536 
6537 static unsigned BigEndianByteAt(unsigned BW, unsigned i) {
6538   return BW - i - 1;
6539 }
6540 
6541 // Check if the byte offsets we are looking at match either a big or little
6542 // endian value load. Return true for big endian, false for little endian,
6543 // and None if the match failed.
6544 static Optional<bool> isBigEndian(const ArrayRef<int64_t> ByteOffsets,
6545                                   int64_t FirstOffset) {
6546   // Endianness can only be determined when there are at least 2 bytes.
6547   unsigned Width = ByteOffsets.size();
6548   if (Width < 2)
6549     return None;
6550 
6551   bool BigEndian = true, LittleEndian = true;
6552   for (unsigned i = 0; i < Width; i++) {
6553     int64_t CurrentByteOffset = ByteOffsets[i] - FirstOffset;
6554     LittleEndian &= CurrentByteOffset == LittleEndianByteAt(Width, i);
6555     BigEndian &= CurrentByteOffset == BigEndianByteAt(Width, i);
6556     if (!BigEndian && !LittleEndian)
6557       return None;
6558   }
6559 
6560   assert((BigEndian != LittleEndian) && "It should be either big endian or "
6561                                         "little endian");
6562   return BigEndian;
6563 }
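// For example (illustrative), with FirstOffset == 16: offsets {16, 17, 18, 19}
// match the little-endian layout (returns false), {19, 18, 17, 16} match the
// big-endian layout (returns true), and {16, 18, 17, 19} match neither
// (returns None).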
6564 
6565 static SDValue stripTruncAndExt(SDValue Value) {
6566   switch (Value.getOpcode()) {
6567   case ISD::TRUNCATE:
6568   case ISD::ZERO_EXTEND:
6569   case ISD::SIGN_EXTEND:
6570   case ISD::ANY_EXTEND:
6571     return stripTruncAndExt(Value.getOperand(0));
6572   }
6573   return Value;
6574 }
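// E.g. stripTruncAndExt((trunc (zero_extend X))) peels both cast nodes and
// returns X, so stores fed from differently-cast copies of the same value are
// still recognized as sharing one combined value.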
6575 
6576 /// Match a pattern where a wide type scalar value is stored by several narrow
6577 /// stores. Fold it into a single store or a BSWAP and a store if the target
6578 /// supports it.
6579 ///
6580 /// Assuming little endian target:
6581 ///  i8 *p = ...
6582 ///  i32 val = ...
6583 ///  p[0] = (val >> 0) & 0xFF;
6584 ///  p[1] = (val >> 8) & 0xFF;
6585 ///  p[2] = (val >> 16) & 0xFF;
6586 ///  p[3] = (val >> 24) & 0xFF;
6587 /// =>
6588 ///  *((i32)p) = val;
6589 ///
6590 ///  i8 *p = ...
6591 ///  i32 val = ...
6592 ///  p[0] = (val >> 24) & 0xFF;
6593 ///  p[1] = (val >> 16) & 0xFF;
6594 ///  p[2] = (val >> 8) & 0xFF;
6595 ///  p[3] = (val >> 0) & 0xFF;
6596 /// =>
6597 ///  *((i32)p) = BSWAP(val);
6598 SDValue DAGCombiner::MatchStoreCombine(StoreSDNode *N) {
6599   // Collect all the stores in the chain.
6600   SDValue Chain;
6601   SmallVector<StoreSDNode *, 8> Stores;
6602   for (StoreSDNode *Store = N; Store; Store = dyn_cast<StoreSDNode>(Chain)) {
6603     // TODO: Allow unordered atomics when wider type is legal (see D66309)
6604     if (Store->getMemoryVT() != MVT::i8 ||
6605         !Store->isSimple() || Store->isIndexed())
6606       return SDValue();
6607     Stores.push_back(Store);
6608     Chain = Store->getChain();
6609   }
6610   // Handle simple types only.
6611   unsigned Width = Stores.size();
6612   EVT VT = EVT::getIntegerVT(
6613     *DAG.getContext(), Width * N->getMemoryVT().getSizeInBits());
6614   if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
6615     return SDValue();
6616 
6617   if (LegalOperations && !TLI.isOperationLegal(ISD::STORE, VT))
6618     return SDValue();
6619 
6620   // Check if all the bytes of the combined value we are looking at are stored
6621   // to the same base address. Collect byte offsets from Base address into
6622   // ByteOffsets.
6623   SDValue CombinedValue;
6624   SmallVector<int64_t, 8> ByteOffsets(Width, INT64_MAX);
6625   int64_t FirstOffset = INT64_MAX;
6626   StoreSDNode *FirstStore = nullptr;
6627   Optional<BaseIndexOffset> Base;
6628   for (auto Store : Stores) {
6629     // All the stores store a different byte of the CombinedValue. A truncate
6630     // is required to get that byte value.
6631     SDValue Trunc = Store->getValue();
6632     if (Trunc.getOpcode() != ISD::TRUNCATE)
6633       return SDValue();
6634     // A shift operation is required to get the right byte offset, except for
6635     // the first byte.
6636     int64_t Offset = 0;
6637     SDValue Value = Trunc.getOperand(0);
6638     if (Value.getOpcode() == ISD::SRL ||
6639         Value.getOpcode() == ISD::SRA) {
6640       ConstantSDNode *ShiftOffset =
6641         dyn_cast<ConstantSDNode>(Value.getOperand(1));
6642       // Trying to match the following pattern. The shift offset must be
6643       // a constant and a multiple of 8. It is the byte offset in "y".
6644       //
6645       // x = srl y, offset
6646       // i8 z = trunc x
6647       // store z, ...
6648       if (!ShiftOffset || (ShiftOffset->getSExtValue() % 8))
6649         return SDValue();
6650 
6651       Offset = ShiftOffset->getSExtValue() / 8;
6652       Value = Value.getOperand(0);
6653     }
6654 
6655     // Stores must share the same combined value with different offsets.
6656     if (!CombinedValue)
6657       CombinedValue = Value;
6658     else if (stripTruncAndExt(CombinedValue) != stripTruncAndExt(Value))
6659       return SDValue();
6660 
6661     // The trunc and all the extend operations should be stripped to get the
6662     // real value being stored.
6663     else if (CombinedValue.getValueType() != VT) {
6664       if (Value.getValueType() == VT ||
6665           Value.getValueSizeInBits() > CombinedValue.getValueSizeInBits())
6666         CombinedValue = Value;
6667       // Give up if the combined value type is smaller than the store size.
6668       if (CombinedValue.getValueSizeInBits() < VT.getSizeInBits())
6669         return SDValue();
6670     }
6671 
6672     // Stores must share the same base address
6673     BaseIndexOffset Ptr = BaseIndexOffset::match(Store, DAG);
6674     int64_t ByteOffsetFromBase = 0;
6675     if (!Base)
6676       Base = Ptr;
6677     else if (!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
6678       return SDValue();
6679 
6680     // Remember the first byte store
6681     if (ByteOffsetFromBase < FirstOffset) {
6682       FirstStore = Store;
6683       FirstOffset = ByteOffsetFromBase;
6684     }
6685     // Map the offset in the store and the offset in the combined value, and
6686     // early return if it has been set before.
6687     if (Offset < 0 || Offset >= Width || ByteOffsets[Offset] != INT64_MAX)
6688       return SDValue();
6689     ByteOffsets[Offset] = ByteOffsetFromBase;
6690   }
6691 
6692   assert(FirstOffset != INT64_MAX && "First byte offset must be set");
6693   assert(FirstStore && "First store must be set");
6694 
6695   // Check if the bytes of the combined value we are looking at match with
6696   // either big or little endian value store.
6697   Optional<bool> IsBigEndian = isBigEndian(ByteOffsets, FirstOffset);
6698   if (!IsBigEndian.hasValue())
6699     return SDValue();
6700 
6701   // The node we are looking at matches with the pattern, check if we can
6702   // replace it with a single bswap if needed and store.
6703 
6704   // If the store needs byte swap check if the target supports it
6705   bool NeedsBswap = DAG.getDataLayout().isBigEndian() != *IsBigEndian;
6706 
6707   // Before legalize we can introduce illegal bswaps which will be later
6708   // converted to an explicit bswap sequence. This way we end up with a single
6709   // store and byte shuffling instead of several stores and byte shuffling.
6710   if (NeedsBswap && LegalOperations && !TLI.isOperationLegal(ISD::BSWAP, VT))
6711     return SDValue();
6712 
6713   // Check that a store of the wide type is both allowed and fast on the target
6714   bool Fast = false;
6715   bool Allowed =
6716       TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
6717                              *FirstStore->getMemOperand(), &Fast);
6718   if (!Allowed || !Fast)
6719     return SDValue();
6720 
6721   if (VT != CombinedValue.getValueType()) {
6722     assert(CombinedValue.getValueType().getSizeInBits() > VT.getSizeInBits() &&
6723            "Get unexpected store value to combine");
6724     CombinedValue = DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, CombinedValue);
6726   }
6727 
6728   if (NeedsBswap)
6729     CombinedValue = DAG.getNode(ISD::BSWAP, SDLoc(N), VT, CombinedValue);
6730 
6731   SDValue NewStore =
6732     DAG.getStore(Chain, SDLoc(N),  CombinedValue, FirstStore->getBasePtr(),
6733                  FirstStore->getPointerInfo(), FirstStore->getAlignment());
6734 
6735   // Rely on other DAG combine rules to remove the other individual stores.
6736   DAG.ReplaceAllUsesWith(N, NewStore.getNode());
6737   return NewStore;
6738 }
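// Illustrative outcome of the endianness handling above: on a little-endian
// target, four i8 stores laid out in big-endian byte order are still combined,
// but into (store (bswap val)) rather than (store val), provided BSWAP on the
// wide type is legal or we are running before legalization.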
6739 
6740 /// Match a pattern where a wide type scalar value is loaded by several narrow
6741 /// loads and combined by shifts and ors. Fold it into a single load or a load
6742 /// and a BSWAP if the target supports it.
6743 ///
6744 /// Assuming little endian target:
6745 ///  i8 *a = ...
6746 ///  i32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
6747 /// =>
6748 ///  i32 val = *((i32)a)
6749 ///
6750 ///  i8 *a = ...
6751 ///  i32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
6752 /// =>
6753 ///  i32 val = BSWAP(*((i32)a))
6754 ///
6755 /// TODO: This rule matches complex patterns with OR node roots and doesn't
6756 /// interact well with the worklist mechanism. When a part of the pattern is
6757 /// updated (e.g. one of the loads) its direct users are put into the worklist,
6758 /// but the root node of the pattern which triggers the load combine is not
6759 /// necessarily a direct user of the changed node. For example, once the address
6760 /// of t28 load is reassociated load combine won't be triggered:
6761 ///             t25: i32 = add t4, Constant:i32<2>
6762 ///           t26: i64 = sign_extend t25
6763 ///        t27: i64 = add t2, t26
6764 ///       t28: i8,ch = load<LD1[%tmp9]> t0, t27, undef:i64
6765 ///     t29: i32 = zero_extend t28
6766 ///   t32: i32 = shl t29, Constant:i8<8>
6767 /// t33: i32 = or t23, t32
6768 /// As a possible fix visitLoad can check if the load can be a part of a load
6769 /// combine pattern and add corresponding OR roots to the worklist.
6770 SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
6771   assert(N->getOpcode() == ISD::OR &&
6772          "Can only match load combining against OR nodes");
6773 
6774   // Handles simple types only
6775   EVT VT = N->getValueType(0);
6776   if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
6777     return SDValue();
6778   unsigned ByteWidth = VT.getSizeInBits() / 8;
6779 
6780   bool IsBigEndianTarget = DAG.getDataLayout().isBigEndian();
6781   auto MemoryByteOffset = [&] (ByteProvider P) {
6782     assert(P.isMemory() && "Must be a memory byte provider");
6783     unsigned LoadBitWidth = P.Load->getMemoryVT().getSizeInBits();
6784     assert(LoadBitWidth % 8 == 0 &&
6785            "can only analyze providers for individual bytes, not bits");
6786     unsigned LoadByteWidth = LoadBitWidth / 8;
6787     return IsBigEndianTarget
6788             ? BigEndianByteAt(LoadByteWidth, P.ByteOffset)
6789             : LittleEndianByteAt(LoadByteWidth, P.ByteOffset);
6790   };
6791 
6792   Optional<BaseIndexOffset> Base;
6793   SDValue Chain;
6794 
6795   SmallPtrSet<LoadSDNode *, 8> Loads;
6796   Optional<ByteProvider> FirstByteProvider;
6797   int64_t FirstOffset = INT64_MAX;
6798 
6799   // Check if all the bytes of the OR we are looking at are loaded from the same
6800   // base address. Collect byte offsets from Base address in ByteOffsets.
6801   SmallVector<int64_t, 8> ByteOffsets(ByteWidth);
6802   unsigned ZeroExtendedBytes = 0;
6803   for (int i = ByteWidth - 1; i >= 0; --i) {
6804     auto P = calculateByteProvider(SDValue(N, 0), i, 0, /*Root=*/true);
6805     if (!P)
6806       return SDValue();
6807 
6808     if (P->isConstantZero()) {
6809       // It's OK for the N most significant bytes to be 0; we can just
6810       // zero-extend the load.
6811       if (++ZeroExtendedBytes != (ByteWidth - static_cast<unsigned>(i)))
6812         return SDValue();
6813       continue;
6814     }
6815     assert(P->isMemory() && "provenance should either be memory or zero");
6816 
6817     LoadSDNode *L = P->Load;
6818     assert(L->hasNUsesOfValue(1, 0) && L->isSimple() &&
6819            !L->isIndexed() &&
6820            "Must be enforced by calculateByteProvider");
6821     assert(L->getOffset().isUndef() && "Unindexed load must have undef offset");
6822 
6823     // All loads must share the same chain
6824     SDValue LChain = L->getChain();
6825     if (!Chain)
6826       Chain = LChain;
6827     else if (Chain != LChain)
6828       return SDValue();
6829 
6830     // Loads must share the same base address
6831     BaseIndexOffset Ptr = BaseIndexOffset::match(L, DAG);
6832     int64_t ByteOffsetFromBase = 0;
6833     if (!Base)
6834       Base = Ptr;
6835     else if (!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
6836       return SDValue();
6837 
6838     // Calculate the offset of the current byte from the base address
6839     ByteOffsetFromBase += MemoryByteOffset(*P);
6840     ByteOffsets[i] = ByteOffsetFromBase;
6841 
6842     // Remember the first byte load
6843     if (ByteOffsetFromBase < FirstOffset) {
6844       FirstByteProvider = P;
6845       FirstOffset = ByteOffsetFromBase;
6846     }
6847 
6848     Loads.insert(L);
6849   }
6850   assert(!Loads.empty() && "All the bytes of the value must be loaded from "
6851          "memory, so there must be at least one load which produces the value");
6852   assert(Base && "Base address of the accessed memory location must be set");
6853   assert(FirstOffset != INT64_MAX && "First byte offset must be set");
6854 
6855   bool NeedsZext = ZeroExtendedBytes > 0;
6856 
6857   EVT MemVT =
6858       EVT::getIntegerVT(*DAG.getContext(), (ByteWidth - ZeroExtendedBytes) * 8);
6859 
6860   if (!MemVT.isSimple())
6861     return SDValue();
6862 
6863   // Before legalize we can introduce too wide illegal loads which will be later
6864   // split into legal sized loads. This enables us to combine i64 load by i8
6865   // patterns to a couple of i32 loads on 32 bit targets.
6866   if (LegalOperations &&
6867       !TLI.isOperationLegal(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD,
6868                             MemVT))
6869     return SDValue();
6870 
6871   // Check if the bytes of the OR we are looking at match with either big or
6872   // little endian value load
6873   Optional<bool> IsBigEndian = isBigEndian(
6874       makeArrayRef(ByteOffsets).drop_back(ZeroExtendedBytes), FirstOffset);
6875   if (!IsBigEndian.hasValue())
6876     return SDValue();
6877 
6878   assert(FirstByteProvider && "must be set");
6879 
6880   // Ensure that the first byte is loaded from zero offset of the first load.
6881   // So the combined value can be loaded from the first load address.
6882   if (MemoryByteOffset(*FirstByteProvider) != 0)
6883     return SDValue();
6884   LoadSDNode *FirstLoad = FirstByteProvider->Load;
6885 
6886   // The node we are looking at matches with the pattern, check if we can
6887   // replace it with a single (possibly zero-extended) load and bswap + shift if
6888   // needed.
6889 
6890   // If the load needs byte swap check if the target supports it
6891   bool NeedsBswap = IsBigEndianTarget != *IsBigEndian;
6892 
6893   // Before legalize we can introduce illegal bswaps which will be later
6894   // converted to an explicit bswap sequence. This way we end up with a single
6895   // load and byte shuffling instead of several loads and byte shuffling.
6896   // We do not introduce illegal bswaps when zero-extending as this tends to
6897   // introduce too many arithmetic instructions.
6898   if (NeedsBswap && (LegalOperations || NeedsZext) &&
6899       !TLI.isOperationLegal(ISD::BSWAP, VT))
6900     return SDValue();
6901 
6902   // If we need to bswap and zero extend, we have to insert a shift. Check that
6903   // it is legal.
6904   if (NeedsBswap && NeedsZext && LegalOperations &&
6905       !TLI.isOperationLegal(ISD::SHL, VT))
6906     return SDValue();
6907 
6908   // Check that a load of the wide type is both allowed and fast on the target
6909   bool Fast = false;
6910   bool Allowed =
6911       TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
6912                              *FirstLoad->getMemOperand(), &Fast);
6913   if (!Allowed || !Fast)
6914     return SDValue();
6915 
6916   SDValue NewLoad = DAG.getExtLoad(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD,
6917                                    SDLoc(N), VT, Chain, FirstLoad->getBasePtr(),
6918                                    FirstLoad->getPointerInfo(), MemVT,
6919                                    FirstLoad->getAlignment());
6920 
6921   // Transfer chain users from old loads to the new load.
6922   for (LoadSDNode *L : Loads)
6923     DAG.ReplaceAllUsesOfValueWith(SDValue(L, 1), SDValue(NewLoad.getNode(), 1));
6924 
6925   if (!NeedsBswap)
6926     return NewLoad;
6927 
6928   SDValue ShiftedLoad =
6929       NeedsZext
6930           ? DAG.getNode(ISD::SHL, SDLoc(N), VT, NewLoad,
6931                         DAG.getShiftAmountConstant(ZeroExtendedBytes * 8, VT,
6932                                                    SDLoc(N), LegalOperations))
6933           : NewLoad;
6934   return DAG.getNode(ISD::BSWAP, SDLoc(N), VT, ShiftedLoad);
6935 }
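// Illustrative zero-extension case: i32 val = a[0] | (a[1] << 8), with the two
// most significant bytes known to be zero, combines into a zextload of i16; if
// the byte order is additionally flipped relative to the target, the loaded
// value is shifted left by the number of zero bytes times 8 and then
// byte-swapped, per the code above.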
6936 
6937 // If the target has andn, bsl, or a similar bit-select instruction,
6938 // we want to unfold masked merge, with the canonical pattern of:
6939 //   |        A  |  |B|
6940 //   ((x ^ y) & m) ^ y
6941 //    |  D  |
6942 // Into:
6943 //   (x & m) | (y & ~m)
6944 // If y is a constant, and the 'andn' does not work with immediates,
6945 // we unfold into a different pattern:
6946 //   ~(~x & m) & (m | y)
6947 // NOTE: we don't unfold the pattern if 'xor' is actually a 'not', because at
6948 //       the very least that breaks andnpd / andnps patterns, and because those
6949 //       patterns are simplified in IR and shouldn't be created in the DAG
6950 SDValue DAGCombiner::unfoldMaskedMerge(SDNode *N) {
6951   assert(N->getOpcode() == ISD::XOR);
6952 
6953   // Don't touch 'not' (i.e. where y = -1).
6954   if (isAllOnesOrAllOnesSplat(N->getOperand(1)))
6955     return SDValue();
6956 
6957   EVT VT = N->getValueType(0);
6958 
6959   // There are 3 commutable operators in the pattern,
6960   // so we have to deal with 8 possible variants of the basic pattern.
6961   SDValue X, Y, M;
6962   auto matchAndXor = [&X, &Y, &M](SDValue And, unsigned XorIdx, SDValue Other) {
6963     if (And.getOpcode() != ISD::AND || !And.hasOneUse())
6964       return false;
6965     SDValue Xor = And.getOperand(XorIdx);
6966     if (Xor.getOpcode() != ISD::XOR || !Xor.hasOneUse())
6967       return false;
6968     SDValue Xor0 = Xor.getOperand(0);
6969     SDValue Xor1 = Xor.getOperand(1);
6970     // Don't touch 'not' (i.e. where y = -1).
6971     if (isAllOnesOrAllOnesSplat(Xor1))
6972       return false;
6973     if (Other == Xor0)
6974       std::swap(Xor0, Xor1);
6975     if (Other != Xor1)
6976       return false;
6977     X = Xor0;
6978     Y = Xor1;
6979     M = And.getOperand(XorIdx ? 0 : 1);
6980     return true;
6981   };
6982 
6983   SDValue N0 = N->getOperand(0);
6984   SDValue N1 = N->getOperand(1);
6985   if (!matchAndXor(N0, 0, N1) && !matchAndXor(N0, 1, N1) &&
6986       !matchAndXor(N1, 0, N0) && !matchAndXor(N1, 1, N0))
6987     return SDValue();
6988 
6989   // Don't do anything if the mask is constant. This should not be reachable.
6990   // InstCombine should have already unfolded this pattern, and DAGCombiner
6991   // probably shouldn't produce it either.
6992   if (isa<ConstantSDNode>(M.getNode()))
6993     return SDValue();
6994 
6995   // We can transform if the target has AndNot
6996   if (!TLI.hasAndNot(M))
6997     return SDValue();
6998 
6999   SDLoc DL(N);
7000 
7001   // If Y is a constant, check that 'andn' works with immediates.
7002   if (!TLI.hasAndNot(Y)) {
7003     assert(TLI.hasAndNot(X) && "Only mask is a variable? Unreachable.");
7004     // If not, we need to do a bit more work to make sure andn is still used.
7005     SDValue NotX = DAG.getNOT(DL, X, VT);
7006     SDValue LHS = DAG.getNode(ISD::AND, DL, VT, NotX, M);
7007     SDValue NotLHS = DAG.getNOT(DL, LHS, VT);
7008     SDValue RHS = DAG.getNode(ISD::OR, DL, VT, M, Y);
7009     return DAG.getNode(ISD::AND, DL, VT, NotLHS, RHS);
7010   }
7011 
7012   SDValue LHS = DAG.getNode(ISD::AND, DL, VT, X, M);
7013   SDValue NotM = DAG.getNOT(DL, M, VT);
7014   SDValue RHS = DAG.getNode(ISD::AND, DL, VT, Y, NotM);
7015 
7016   return DAG.getNode(ISD::OR, DL, VT, LHS, RHS);
7017 }
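// Bit-level sanity check (illustrative): with x = 0b1100, y = 0b1010 and
// m = 0b0110, both ((x ^ y) & m) ^ y and (x & m) | (y & ~m) pick the m-bits
// from x and the ~m-bits from y, yielding 0b1100.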
7018 
7019 SDValue DAGCombiner::visitXOR(SDNode *N) {
7020   SDValue N0 = N->getOperand(0);
7021   SDValue N1 = N->getOperand(1);
7022   EVT VT = N0.getValueType();
7023 
7024   // fold vector ops
7025   if (VT.isVector()) {
7026     if (SDValue FoldedVOp = SimplifyVBinOp(N))
7027       return FoldedVOp;
7028 
7029     // fold (xor x, 0) -> x, vector edition
7030     if (ISD::isBuildVectorAllZeros(N0.getNode()))
7031       return N1;
7032     if (ISD::isBuildVectorAllZeros(N1.getNode()))
7033       return N0;
7034   }
7035 
7036   // fold (xor undef, undef) -> 0. This is a common idiom (misuse).
7037   SDLoc DL(N);
7038   if (N0.isUndef() && N1.isUndef())
7039     return DAG.getConstant(0, DL, VT);
7040   // fold (xor x, undef) -> undef
7041   if (N0.isUndef())
7042     return N0;
7043   if (N1.isUndef())
7044     return N1;
7045   // fold (xor c1, c2) -> c1^c2
7046   ConstantSDNode *N0C = getAsNonOpaqueConstant(N0);
7047   ConstantSDNode *N1C = getAsNonOpaqueConstant(N1);
7048   if (N0C && N1C)
7049     return DAG.FoldConstantArithmetic(ISD::XOR, DL, VT, N0C, N1C);
7050   // canonicalize constant to RHS
7051   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
7052      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
7053     return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
7054   // fold (xor x, 0) -> x
7055   if (isNullConstant(N1))
7056     return N0;
7057 
7058   if (SDValue NewSel = foldBinOpIntoSelect(N))
7059     return NewSel;
7060 
7061   // reassociate xor
7062   if (SDValue RXOR = reassociateOps(ISD::XOR, DL, N0, N1, N->getFlags()))
7063     return RXOR;
7064 
7065   // fold !(x cc y) -> (x !cc y)
7066   unsigned N0Opcode = N0.getOpcode();
7067   SDValue LHS, RHS, CC;
7068   if (TLI.isConstTrueVal(N1.getNode()) && isSetCCEquivalent(N0, LHS, RHS, CC)) {
7069     ISD::CondCode NotCC = ISD::getSetCCInverse(cast<CondCodeSDNode>(CC)->get(),
7070                                                LHS.getValueType());
7071     if (!LegalOperations ||
7072         TLI.isCondCodeLegal(NotCC, LHS.getSimpleValueType())) {
7073       switch (N0Opcode) {
7074       default:
7075         llvm_unreachable("Unhandled SetCC Equivalent!");
7076       case ISD::SETCC:
7077         return DAG.getSetCC(SDLoc(N0), VT, LHS, RHS, NotCC);
7078       case ISD::SELECT_CC:
7079         return DAG.getSelectCC(SDLoc(N0), LHS, RHS, N0.getOperand(2),
7080                                N0.getOperand(3), NotCC);
7081       }
7082     }
7083   }
7084 
7085   // fold (not (zext (setcc x, y))) -> (zext (not (setcc x, y)))
7086   if (isOneConstant(N1) && N0Opcode == ISD::ZERO_EXTEND && N0.hasOneUse() &&
7087       isSetCCEquivalent(N0.getOperand(0), LHS, RHS, CC)){
7088     SDValue V = N0.getOperand(0);
7089     SDLoc DL0(N0);
7090     V = DAG.getNode(ISD::XOR, DL0, V.getValueType(), V,
7091                     DAG.getConstant(1, DL0, V.getValueType()));
7092     AddToWorklist(V.getNode());
7093     return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, V);
7094   }
7095 
7096   // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are setcc
7097   if (isOneConstant(N1) && VT == MVT::i1 && N0.hasOneUse() &&
7098       (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
7099     SDValue N00 = N0.getOperand(0), N01 = N0.getOperand(1);
7100     if (isOneUseSetCC(N01) || isOneUseSetCC(N00)) {
7101       unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
7102       N00 = DAG.getNode(ISD::XOR, SDLoc(N00), VT, N00, N1); // N00 = ~N00
7103       N01 = DAG.getNode(ISD::XOR, SDLoc(N01), VT, N01, N1); // N01 = ~N01
7104       AddToWorklist(N00.getNode()); AddToWorklist(N01.getNode());
7105       return DAG.getNode(NewOpcode, DL, VT, N00, N01);
7106     }
7107   }
7108   // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are constants
7109   if (isAllOnesConstant(N1) && N0.hasOneUse() &&
7110       (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
7111     SDValue N00 = N0.getOperand(0), N01 = N0.getOperand(1);
7112     if (isa<ConstantSDNode>(N01) || isa<ConstantSDNode>(N00)) {
7113       unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
7114       N00 = DAG.getNode(ISD::XOR, SDLoc(N00), VT, N00, N1); // N00 = ~N00
7115       N01 = DAG.getNode(ISD::XOR, SDLoc(N01), VT, N01, N1); // N01 = ~N01
7116       AddToWorklist(N00.getNode()); AddToWorklist(N01.getNode());
7117       return DAG.getNode(NewOpcode, DL, VT, N00, N01);
7118     }
7119   }
7120 
7121   // fold (not (neg x)) -> (add X, -1)
7122   // FIXME: This can be generalized to (not (sub Y, X)) -> (add X, ~Y) if
7123   // Y is a constant or the subtract has a single use.
7124   if (isAllOnesConstant(N1) && N0.getOpcode() == ISD::SUB &&
7125       isNullConstant(N0.getOperand(0))) {
7126     return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1),
7127                        DAG.getAllOnesConstant(DL, VT));
7128   }
7129 
7130   // fold (not (add X, -1)) -> (neg X)
7131   if (isAllOnesConstant(N1) && N0.getOpcode() == ISD::ADD &&
7132       isAllOnesOrAllOnesSplat(N0.getOperand(1))) {
7133     return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
7134                        N0.getOperand(0));
7135   }
7136 
7137   // fold (xor (and x, y), y) -> (and (not x), y)
7138   if (N0Opcode == ISD::AND && N0.hasOneUse() && N0->getOperand(1) == N1) {
7139     SDValue X = N0.getOperand(0);
7140     SDValue NotX = DAG.getNOT(SDLoc(X), X, VT);
7141     AddToWorklist(NotX.getNode());
7142     return DAG.getNode(ISD::AND, DL, VT, NotX, N1);
7143   }
7144 
7145   if ((N0Opcode == ISD::SRL || N0Opcode == ISD::SHL) && N0.hasOneUse()) {
7146     ConstantSDNode *XorC = isConstOrConstSplat(N1);
7147     ConstantSDNode *ShiftC = isConstOrConstSplat(N0.getOperand(1));
7148     unsigned BitWidth = VT.getScalarSizeInBits();
7149     if (XorC && ShiftC) {
7150       // Don't crash on an oversized shift. We cannot guarantee that a bogus
7151       // shift has been simplified to undef.
7152       uint64_t ShiftAmt = ShiftC->getLimitedValue();
7153       if (ShiftAmt < BitWidth) {
7154         APInt Ones = APInt::getAllOnesValue(BitWidth);
7155         Ones = N0Opcode == ISD::SHL ? Ones.shl(ShiftAmt) : Ones.lshr(ShiftAmt);
7156         if (XorC->getAPIntValue() == Ones) {
7157           // If the xor constant is a shifted -1, do a 'not' before the shift:
7158           // xor (X << ShiftC), XorC --> (not X) << ShiftC
7159           // xor (X >> ShiftC), XorC --> (not X) >> ShiftC
7160           SDValue Not = DAG.getNOT(DL, N0.getOperand(0), VT);
7161           return DAG.getNode(N0Opcode, DL, VT, Not, N0.getOperand(1));
7162         }
7163       }
7164     }
7165   }
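  // E.g. for i8 (illustrative): xor (srl X, 4), 0x0F --> srl (not X), 4,
  // since 0x0F is exactly the all-ones value shifted right by 4.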
7166 
7167   // fold Y = sra (X, size(X)-1); xor (add (X, Y), Y) -> (abs X)
7168   if (TLI.isOperationLegalOrCustom(ISD::ABS, VT)) {
7169     SDValue A = N0Opcode == ISD::ADD ? N0 : N1;
7170     SDValue S = N0Opcode == ISD::SRA ? N0 : N1;
7171     if (A.getOpcode() == ISD::ADD && S.getOpcode() == ISD::SRA) {
7172       SDValue A0 = A.getOperand(0), A1 = A.getOperand(1);
7173       SDValue S0 = S.getOperand(0);
7174       if ((A0 == S && A1 == S0) || (A1 == S && A0 == S0)) {
7175         unsigned OpSizeInBits = VT.getScalarSizeInBits();
7176         if (ConstantSDNode *C = isConstOrConstSplat(S.getOperand(1)))
7177           if (C->getAPIntValue() == (OpSizeInBits - 1))
7178             return DAG.getNode(ISD::ABS, DL, VT, S0);
7179       }
7180     }
7181   }
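  // E.g. for i32 (illustrative): with Y = sra(X, 31), xor (add X, Y), Y is
  // the classic branchless absolute-value idiom, so it folds to (abs X).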
7182 
7183   // fold (xor x, x) -> 0
7184   if (N0 == N1)
7185     return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
7186 
7187   // fold (xor (shl 1, x), -1) -> (rotl ~1, x)
7188   // Here is a concrete example of this equivalence:
7189   // i16   x ==  14
7190   // i16 shl ==   1 << 14  == 16384 == 0b0100000000000000
7191   // i16 xor == ~(1 << 14) == 49151 == 0b1011111111111111
7192   //
7193   // =>
7194   //
7195   // i16     ~1      == 0b1111111111111110
7196   // i16 rol(~1, 14) == 0b1011111111111111
7197   //
7198   // Some additional tips to help conceptualize this transform:
7199   // - Try to see the operation as placing a single zero in a value of all ones.
7200   // - There exists no value for x which would allow the result to contain zero.
7201   // - Values of x larger than the bitwidth are undefined and do not require a
7202   //   consistent result.
7203   // - Pushing the zero left requires shifting one-bits in from the right.
7204   // A rotate left of ~1 is a nice way of achieving the desired result.
7205   if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT) && N0Opcode == ISD::SHL &&
7206       isAllOnesConstant(N1) && isOneConstant(N0.getOperand(0))) {
7207     return DAG.getNode(ISD::ROTL, DL, VT, DAG.getConstant(~1, DL, VT),
7208                        N0.getOperand(1));
7209   }
7210 
7211   // Simplify: xor (op x...), (op y...)  -> (op (xor x, y))
7212   if (N0Opcode == N1.getOpcode())
7213     if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
7214       return V;
7215 
7216   // Unfold  ((x ^ y) & m) ^ y  into  (x & m) | (y & ~m)  if profitable
7217   if (SDValue MM = unfoldMaskedMerge(N))
7218     return MM;
7219 
7220   // Simplify the expression using non-local knowledge.
7221   if (SimplifyDemandedBits(SDValue(N, 0)))
7222     return SDValue(N, 0);
7223 
7224   if (SDValue Combined = combineCarryDiamond(*this, DAG, TLI, N0, N1, N))
7225     return Combined;
7226 
7227   return SDValue();
7228 }
7229 
7230 /// If we have a shift-by-constant of a bitwise logic op that itself has a
7231 /// shift-by-constant operand with identical opcode, we may be able to convert
7232 /// that into 2 independent shifts followed by the logic op. This is a
7233 /// throughput improvement.
7234 static SDValue combineShiftOfShiftedLogic(SDNode *Shift, SelectionDAG &DAG) {
7235   // Match a one-use bitwise logic op.
7236   SDValue LogicOp = Shift->getOperand(0);
7237   if (!LogicOp.hasOneUse())
7238     return SDValue();
7239 
7240   unsigned LogicOpcode = LogicOp.getOpcode();
7241   if (LogicOpcode != ISD::AND && LogicOpcode != ISD::OR &&
7242       LogicOpcode != ISD::XOR)
7243     return SDValue();
7244 
7245   // Find a matching one-use shift by constant.
7246   unsigned ShiftOpcode = Shift->getOpcode();
7247   SDValue C1 = Shift->getOperand(1);
7248   ConstantSDNode *C1Node = isConstOrConstSplat(C1);
7249   assert(C1Node && "Expected a shift with constant operand");
7250   const APInt &C1Val = C1Node->getAPIntValue();
7251   auto matchFirstShift = [&](SDValue V, SDValue &ShiftOp,
7252                              const APInt *&ShiftAmtVal) {
7253     if (V.getOpcode() != ShiftOpcode || !V.hasOneUse())
7254       return false;
7255 
7256     ConstantSDNode *ShiftCNode = isConstOrConstSplat(V.getOperand(1));
7257     if (!ShiftCNode)
7258       return false;
7259 
7260     // Capture the shifted operand and shift amount value.
7261     ShiftOp = V.getOperand(0);
7262     ShiftAmtVal = &ShiftCNode->getAPIntValue();
7263 
7264     // Shift amount types do not have to match their operand type, so check that
7265     // the constants are the same width.
7266     if (ShiftAmtVal->getBitWidth() != C1Val.getBitWidth())
7267       return false;
7268 
7269     // The fold is not valid if the sum of the shift values exceeds bitwidth.
7270     if ((*ShiftAmtVal + C1Val).uge(V.getScalarValueSizeInBits()))
7271       return false;
7272 
7273     return true;
7274   };
7275 
7276   // Logic ops are commutative, so check each operand for a match.
7277   SDValue X, Y;
7278   const APInt *C0Val;
7279   if (matchFirstShift(LogicOp.getOperand(0), X, C0Val))
7280     Y = LogicOp.getOperand(1);
7281   else if (matchFirstShift(LogicOp.getOperand(1), X, C0Val))
7282     Y = LogicOp.getOperand(0);
7283   else
7284     return SDValue();
7285 
7286   // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1)
7287   SDLoc DL(Shift);
7288   EVT VT = Shift->getValueType(0);
7289   EVT ShiftAmtVT = Shift->getOperand(1).getValueType();
7290   SDValue ShiftSumC = DAG.getConstant(*C0Val + C1Val, DL, ShiftAmtVT);
7291   SDValue NewShift1 = DAG.getNode(ShiftOpcode, DL, VT, X, ShiftSumC);
7292   SDValue NewShift2 = DAG.getNode(ShiftOpcode, DL, VT, Y, C1);
7293   return DAG.getNode(LogicOpcode, DL, VT, NewShift1, NewShift2);
7294 }
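// Example of the rewrite (illustrative):
//   srl (xor (srl X, 3), Y), 2 --> xor (srl X, 5), (srl Y, 2)
// One serial shift is traded for two independent ones, so the XOR no longer
// waits on the inner shift.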
7295 
7296 /// Handle transforms common to the three shifts, when the shift amount is a
7297 /// constant.
7298 /// We are looking for: (shift being one of shl/sra/srl)
7299 ///   shift (binop X, C0), C1
7300 /// And want to transform into:
7301 ///   binop (shift X, C1), (shift C0, C1)
7302 SDValue DAGCombiner::visitShiftByConstant(SDNode *N) {
7303   assert(isConstOrConstSplat(N->getOperand(1)) && "Expected constant operand");
7304 
7305   // Do not turn a 'not' into a regular xor.
7306   if (isBitwiseNot(N->getOperand(0)))
7307     return SDValue();
7308 
7309   // The inner binop must be one-use, since we want to replace it.
7310   SDValue LHS = N->getOperand(0);
7311   if (!LHS.hasOneUse() || !TLI.isDesirableToCommuteWithShift(N, Level))
7312     return SDValue();
7313 
7314   // TODO: This is limited to early combining because it may reveal regressions
7315   //       otherwise. But since we just checked a target hook to see if this is
7316   //       desirable, that should have filtered out cases where this interferes
7317   //       with some other pattern matching.
7318   if (!LegalTypes)
7319     if (SDValue R = combineShiftOfShiftedLogic(N, DAG))
7320       return R;
7321 
7322   // We want to pull some binops through shifts, so that we have (and (shift))
7323   // instead of (shift (and)), likewise for add, or, xor, etc.  This sort of
7324   // thing happens with address calculations, so it's important to canonicalize
7325   // it.
7326   switch (LHS.getOpcode()) {
7327   default:
7328     return SDValue();
7329   case ISD::OR:
7330   case ISD::XOR:
7331   case ISD::AND:
7332     break;
7333   case ISD::ADD:
7334     if (N->getOpcode() != ISD::SHL)
7335       return SDValue(); // only shl(add) not sr[al](add).
7336     break;
7337   }
7338 
7339   // We require the RHS of the binop to be a constant and not opaque as well.
7340   ConstantSDNode *BinOpCst = getAsNonOpaqueConstant(LHS.getOperand(1));
7341   if (!BinOpCst)
7342     return SDValue();
7343 
7344   // FIXME: disable this unless the input to the binop is a shift by a constant
7345   // or is copy/select. Enable this in other cases when we figure out it's
7346   // exactly profitable.
7347   SDValue BinOpLHSVal = LHS.getOperand(0);
7348   bool IsShiftByConstant = (BinOpLHSVal.getOpcode() == ISD::SHL ||
7349                             BinOpLHSVal.getOpcode() == ISD::SRA ||
7350                             BinOpLHSVal.getOpcode() == ISD::SRL) &&
7351                            isa<ConstantSDNode>(BinOpLHSVal.getOperand(1));
7352   bool IsCopyOrSelect = BinOpLHSVal.getOpcode() == ISD::CopyFromReg ||
7353                         BinOpLHSVal.getOpcode() == ISD::SELECT;
7354 
7355   if (!IsShiftByConstant && !IsCopyOrSelect)
7356     return SDValue();
7357 
7358   if (IsCopyOrSelect && N->hasOneUse())
7359     return SDValue();
7360 
7361   // Fold the constants, shifting the binop RHS by the shift amount.
7362   SDLoc DL(N);
7363   EVT VT = N->getValueType(0);
7364   SDValue NewRHS = DAG.getNode(N->getOpcode(), DL, VT, LHS.getOperand(1),
7365                                N->getOperand(1));
7366   assert(isa<ConstantSDNode>(NewRHS) && "Folding was not successful!");
7367 
7368   SDValue NewShift = DAG.getNode(N->getOpcode(), DL, VT, LHS.getOperand(0),
7369                                  N->getOperand(1));
7370   return DAG.getNode(LHS.getOpcode(), DL, VT, NewShift, NewRHS);
7371 }
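// Illustrative instance, assuming the inner operand qualifies (e.g. it is
// itself a shift by a constant): for i32, shl (and X, 0xF0), 4 -->
// and (shl X, 4), 0xF00. The mask is folded at compile time and the AND now
// applies to the shifted value, matching the (binop (shift)) canonical form.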
7372 
7373 SDValue DAGCombiner::distributeTruncateThroughAnd(SDNode *N) {
7374   assert(N->getOpcode() == ISD::TRUNCATE);
7375   assert(N->getOperand(0).getOpcode() == ISD::AND);
7376 
7377   // (truncate:TruncVT (and N00, N01C)) -> (and (truncate:TruncVT N00), TruncC)
7378   EVT TruncVT = N->getValueType(0);
7379   if (N->hasOneUse() && N->getOperand(0).hasOneUse() &&
7380       TLI.isTypeDesirableForOp(ISD::AND, TruncVT)) {
7381     SDValue N01 = N->getOperand(0).getOperand(1);
7382     if (isConstantOrConstantVector(N01, /* NoOpaques */ true)) {
7383       SDLoc DL(N);
7384       SDValue N00 = N->getOperand(0).getOperand(0);
7385       SDValue Trunc00 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N00);
7386       SDValue Trunc01 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N01);
7387       AddToWorklist(Trunc00.getNode());
7388       AddToWorklist(Trunc01.getNode());
7389       return DAG.getNode(ISD::AND, DL, TruncVT, Trunc00, Trunc01);
7390     }
7391   }
7392 
7393   return SDValue();
7394 }
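// E.g. (truncate:i32 (and:i64 X, 255)) becomes
// (and:i32 (truncate:i32 X), 255), letting the rotate/shift combines below
// reason about the masked amount directly in the narrow type.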
7395 
7396 SDValue DAGCombiner::visitRotate(SDNode *N) {
7397   SDLoc dl(N);
7398   SDValue N0 = N->getOperand(0);
7399   SDValue N1 = N->getOperand(1);
7400   EVT VT = N->getValueType(0);
7401   unsigned Bitsize = VT.getScalarSizeInBits();
7402 
7403   // fold (rot x, 0) -> x
7404   if (isNullOrNullSplat(N1))
7405     return N0;
7406 
7407   // fold (rot x, c) -> x iff (c % BitSize) == 0
7408   if (isPowerOf2_32(Bitsize) && Bitsize > 1) {
7409     APInt ModuloMask(N1.getScalarValueSizeInBits(), Bitsize - 1);
7410     if (DAG.MaskedValueIsZero(N1, ModuloMask))
7411       return N0;
7412   }
7413 
7414   // fold (rot x, c) -> (rot x, c % BitSize)
7415   // TODO - support non-uniform vector amounts.
7416   if (ConstantSDNode *Cst = isConstOrConstSplat(N1)) {
7417     if (Cst->getAPIntValue().uge(Bitsize)) {
7418       uint64_t RotAmt = Cst->getAPIntValue().urem(Bitsize);
7419       return DAG.getNode(N->getOpcode(), dl, VT, N0,
7420                          DAG.getConstant(RotAmt, dl, N1.getValueType()));
7421     }
7422   }
7423 
7424   // fold (rot* x, (trunc (and y, c))) -> (rot* x, (and (trunc y), (trunc c))).
7425   if (N1.getOpcode() == ISD::TRUNCATE &&
7426       N1.getOperand(0).getOpcode() == ISD::AND) {
7427     if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
7428       return DAG.getNode(N->getOpcode(), dl, VT, N0, NewOp1);
7429   }
7430 
7431   unsigned NextOp = N0.getOpcode();
7432   // fold (rot* (rot* x, c2), c1) -> (rot* x, c1 +- c2 % bitsize)
7433   if (NextOp == ISD::ROTL || NextOp == ISD::ROTR) {
7434     SDNode *C1 = DAG.isConstantIntBuildVectorOrConstantInt(N1);
7435     SDNode *C2 = DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1));
7436     if (C1 && C2 && C1->getValueType(0) == C2->getValueType(0)) {
7437       EVT ShiftVT = C1->getValueType(0);
7438       bool SameSide = (N->getOpcode() == NextOp);
7439       unsigned CombineOp = SameSide ? ISD::ADD : ISD::SUB;
7440       if (SDValue CombinedShift =
7441               DAG.FoldConstantArithmetic(CombineOp, dl, ShiftVT, C1, C2)) {
7442         SDValue BitsizeC = DAG.getConstant(Bitsize, dl, ShiftVT);
7443         SDValue CombinedShiftNorm = DAG.FoldConstantArithmetic(
7444             ISD::SREM, dl, ShiftVT, CombinedShift.getNode(),
7445             BitsizeC.getNode());
7446         return DAG.getNode(N->getOpcode(), dl, VT, N0->getOperand(0),
7447                            CombinedShiftNorm);
7448       }
7449     }
7450   }
7451   return SDValue();
7452 }
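// Example of the rotate-merging fold above (illustrative), for i8
// (Bitsize == 8): rotl (rotl X, 6), 5 --> rotl X, (6 + 5) srem 8 = rotl X, 3.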
7453 
7454 SDValue DAGCombiner::visitSHL(SDNode *N) {
7455   SDValue N0 = N->getOperand(0);
7456   SDValue N1 = N->getOperand(1);
7457   if (SDValue V = DAG.simplifyShift(N0, N1))
7458     return V;
7459 
7460   EVT VT = N0.getValueType();
7461   EVT ShiftVT = N1.getValueType();
7462   unsigned OpSizeInBits = VT.getScalarSizeInBits();
7463 
7464   // fold vector ops
7465   if (VT.isVector()) {
7466     if (SDValue FoldedVOp = SimplifyVBinOp(N))
7467       return FoldedVOp;
7468 
7469     BuildVectorSDNode *N1CV = dyn_cast<BuildVectorSDNode>(N1);
7470     // If setcc produces an all-ones true value then:
7471     // (shl (and (setcc) N01CV) N1CV) -> (and (setcc) N01CV<<N1CV)
7472     if (N1CV && N1CV->isConstant()) {
7473       if (N0.getOpcode() == ISD::AND) {
7474         SDValue N00 = N0->getOperand(0);
7475         SDValue N01 = N0->getOperand(1);
7476         BuildVectorSDNode *N01CV = dyn_cast<BuildVectorSDNode>(N01);
7477 
7478         if (N01CV && N01CV->isConstant() && N00.getOpcode() == ISD::SETCC &&
7479             TLI.getBooleanContents(N00.getOperand(0).getValueType()) ==
7480                 TargetLowering::ZeroOrNegativeOneBooleanContent) {
7481           if (SDValue C = DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N), VT,
7482                                                      N01CV, N1CV))
7483             return DAG.getNode(ISD::AND, SDLoc(N), VT, N00, C);
7484         }
7485       }
7486     }
7487   }
7488 
7489   ConstantSDNode *N1C = isConstOrConstSplat(N1);
7490 
7491   // fold (shl c1, c2) -> c1<<c2
7492   // TODO - support non-uniform vector shift amounts.
7493   ConstantSDNode *N0C = getAsNonOpaqueConstant(N0);
7494   if (N0C && N1C && !N1C->isOpaque())
7495     return DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N), VT, N0C, N1C);
7496 
7497   if (SDValue NewSel = foldBinOpIntoSelect(N))
7498     return NewSel;
7499 
7500   // if (shl x, c) is known to be zero, return 0
7501   if (DAG.MaskedValueIsZero(SDValue(N, 0),
7502                             APInt::getAllOnesValue(OpSizeInBits)))
7503     return DAG.getConstant(0, SDLoc(N), VT);
7504 
7505   // fold (shl x, (trunc (and y, c))) -> (shl x, (and (trunc y), (trunc c))).
7506   if (N1.getOpcode() == ISD::TRUNCATE &&
7507       N1.getOperand(0).getOpcode() == ISD::AND) {
7508     if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
7509       return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, NewOp1);
7510   }
7511 
7512   // TODO - support non-uniform vector shift amounts.
7513   if (N1C && SimplifyDemandedBits(SDValue(N, 0)))
7514     return SDValue(N, 0);
7515 
7516   // fold (shl (shl x, c1), c2) -> 0 or (shl x, (add c1, c2))
7517   if (N0.getOpcode() == ISD::SHL) {
7518     auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
7519                                           ConstantSDNode *RHS) {
7520       APInt c1 = LHS->getAPIntValue();
7521       APInt c2 = RHS->getAPIntValue();
7522       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
7523       return (c1 + c2).uge(OpSizeInBits);
7524     };
7525     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
7526       return DAG.getConstant(0, SDLoc(N), VT);
7527 
7528     auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
7529                                        ConstantSDNode *RHS) {
7530       APInt c1 = LHS->getAPIntValue();
7531       APInt c2 = RHS->getAPIntValue();
7532       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
7533       return (c1 + c2).ult(OpSizeInBits);
7534     };
7535     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
7536       SDLoc DL(N);
7537       SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
7538       return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Sum);
7539     }
7540   }
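  // E.g. for i8 (illustrative): shl (shl X, 3), 2 --> shl X, 5, whereas
  // shl (shl X, 5), 4 --> 0 because 5 + 4 >= 8 shifts every bit out.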
7541 
7542   // fold (shl (ext (shl x, c1)), c2) -> (shl (ext x), (add c1, c2))
7543   // For this to be valid, the second form must not preserve any of the bits
7544   // that are shifted out by the inner shift in the first form.  This means
7545   // the outer shift size must be >= the number of bits added by the ext.
7546   // As a corollary, we don't care what kind of ext it is.
7547   if ((N0.getOpcode() == ISD::ZERO_EXTEND ||
7548        N0.getOpcode() == ISD::ANY_EXTEND ||
7549        N0.getOpcode() == ISD::SIGN_EXTEND) &&
7550       N0.getOperand(0).getOpcode() == ISD::SHL) {
7551     SDValue N0Op0 = N0.getOperand(0);
7552     SDValue InnerShiftAmt = N0Op0.getOperand(1);
7553     EVT InnerVT = N0Op0.getValueType();
7554     uint64_t InnerBitwidth = InnerVT.getScalarSizeInBits();
7555 
7556     auto MatchOutOfRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
7557                                                          ConstantSDNode *RHS) {
7558       APInt c1 = LHS->getAPIntValue();
7559       APInt c2 = RHS->getAPIntValue();
7560       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
7561       return c2.uge(OpSizeInBits - InnerBitwidth) &&
7562              (c1 + c2).uge(OpSizeInBits);
7563     };
7564     if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchOutOfRange,
7565                                   /*AllowUndefs*/ false,
7566                                   /*AllowTypeMismatch*/ true))
7567       return DAG.getConstant(0, SDLoc(N), VT);
7568 
7569     auto MatchInRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
7570                                                       ConstantSDNode *RHS) {
7571       APInt c1 = LHS->getAPIntValue();
7572       APInt c2 = RHS->getAPIntValue();
7573       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
7574       return c2.uge(OpSizeInBits - InnerBitwidth) &&
7575              (c1 + c2).ult(OpSizeInBits);
7576     };
7577     if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchInRange,
7578                                   /*AllowUndefs*/ false,
7579                                   /*AllowTypeMismatch*/ true)) {
7580       SDLoc DL(N);
7581       SDValue Ext = DAG.getNode(N0.getOpcode(), DL, VT, N0Op0.getOperand(0));
7582       SDValue Sum = DAG.getZExtOrTrunc(InnerShiftAmt, DL, ShiftVT);
7583       Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, Sum, N1);
7584       return DAG.getNode(ISD::SHL, DL, VT, Ext, Sum);
7585     }
7586   }
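  // E.g. for i32 (illustrative): shl (zext i8 (shl X, 5)), 26 -->
  // shl (zext i8 X), 31. The outer amount (26 >= 32 - 8) guarantees that
  // every bit the inner i8 shift could have discarded is shifted out anyway.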
7587 
7588   // fold (shl (zext (srl x, C)), C) -> (zext (shl (srl x, C), C))
7589   // Only fold this if the inner zext has no other uses to avoid increasing
7590   // the total number of instructions.
7591   if (N0.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse() &&
7592       N0.getOperand(0).getOpcode() == ISD::SRL) {
7593     SDValue N0Op0 = N0.getOperand(0);
7594     SDValue InnerShiftAmt = N0Op0.getOperand(1);
7595 
7596     auto MatchEqual = [VT](ConstantSDNode *LHS, ConstantSDNode *RHS) {
7597       APInt c1 = LHS->getAPIntValue();
7598       APInt c2 = RHS->getAPIntValue();
7599       zeroExtendToMatch(c1, c2);
7600       return c1.ult(VT.getScalarSizeInBits()) && (c1 == c2);
7601     };
7602     if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchEqual,
7603                                   /*AllowUndefs*/ false,
7604                                   /*AllowTypeMismatch*/ true)) {
7605       SDLoc DL(N);
7606       EVT InnerShiftAmtVT = N0Op0.getOperand(1).getValueType();
7607       SDValue NewSHL = DAG.getZExtOrTrunc(N1, DL, InnerShiftAmtVT);
7608       NewSHL = DAG.getNode(ISD::SHL, DL, N0Op0.getValueType(), N0Op0, NewSHL);
7609       AddToWorklist(NewSHL.getNode());
7610       return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N0), VT, NewSHL);
7611     }
7612   }
7613 
7614   // fold (shl (sr[la] exact X,  C1), C2) -> (shl    X, (C2-C1)) if C1 <= C2
7615   // fold (shl (sr[la] exact X,  C1), C2) -> (sr[la] X, (C2-C1)) if C1  > C2
7616   // TODO - support non-uniform vector shift amounts.
7617   if (N1C && (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SRA) &&
7618       N0->getFlags().hasExact()) {
7619     if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) {
7620       uint64_t C1 = N0C1->getZExtValue();
7621       uint64_t C2 = N1C->getZExtValue();
7622       SDLoc DL(N);
7623       if (C1 <= C2)
7624         return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0),
7625                            DAG.getConstant(C2 - C1, DL, ShiftVT));
7626       return DAG.getNode(N0.getOpcode(), DL, VT, N0.getOperand(0),
7627                          DAG.getConstant(C1 - C2, DL, ShiftVT));
7628     }
7629   }
7630 
7631   // fold (shl (srl x, c1), c2) -> (and (shl x, (sub c2, c1)), MASK) or
7632   //                               (and (srl x, (sub c1, c2)), MASK)
7633   // Only fold this if the inner shift has no other uses -- if it does, folding
7634   // this will increase the total number of instructions.
7635   // TODO - drop hasOneUse requirement if c1 == c2?
7636   // TODO - support non-uniform vector shift amounts.
7637   if (N1C && N0.getOpcode() == ISD::SRL && N0.hasOneUse() &&
7638       TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
7639     if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) {
7640       if (N0C1->getAPIntValue().ult(OpSizeInBits)) {
7641         uint64_t c1 = N0C1->getZExtValue();
7642         uint64_t c2 = N1C->getZExtValue();
7643         APInt Mask = APInt::getHighBitsSet(OpSizeInBits, OpSizeInBits - c1);
7644         SDValue Shift;
7645         if (c2 > c1) {
7646           Mask <<= c2 - c1;
7647           SDLoc DL(N);
7648           Shift = DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0),
7649                               DAG.getConstant(c2 - c1, DL, ShiftVT));
7650         } else {
7651           Mask.lshrInPlace(c1 - c2);
7652           SDLoc DL(N);
7653           Shift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0),
7654                               DAG.getConstant(c1 - c2, DL, ShiftVT));
7655         }
7656         SDLoc DL(N0);
7657         return DAG.getNode(ISD::AND, DL, VT, Shift,
7658                            DAG.getConstant(Mask, DL, VT));
7659       }
7660     }
7661   }
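  // Worked instance of the mask fold above (illustrative), for i8:
  //   shl (srl X, 4), 2 --> and (srl X, 2), 0x3C
  // The mask starts as the 4 high bits (0xF0) and is shifted right by
  // c1 - c2 == 2.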
7662 
7663   // fold (shl (sra x, c1), c1) -> (and x, (shl -1, c1))
7664   if (N0.getOpcode() == ISD::SRA && N1 == N0.getOperand(1) &&
7665       isConstantOrConstantVector(N1, /* No Opaques */ true)) {
7666     SDLoc DL(N);
7667     SDValue AllBits = DAG.getAllOnesConstant(DL, VT);
7668     SDValue HiBitsMask = DAG.getNode(ISD::SHL, DL, VT, AllBits, N1);
7669     return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), HiBitsMask);
7670   }
7671 
7672   // fold (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2)
7673   // fold (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2)
7674   // Variant of version done on multiply, except mul by a power of 2 is turned
7675   // into a shift.
7676   if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::OR) &&
7677       N0.getNode()->hasOneUse() &&
7678       isConstantOrConstantVector(N1, /* No Opaques */ true) &&
7679       isConstantOrConstantVector(N0.getOperand(1), /* No Opaques */ true) &&
7680       TLI.isDesirableToCommuteWithShift(N, Level)) {
7681     SDValue Shl0 = DAG.getNode(ISD::SHL, SDLoc(N0), VT, N0.getOperand(0), N1);
7682     SDValue Shl1 = DAG.getNode(ISD::SHL, SDLoc(N1), VT, N0.getOperand(1), N1);
7683     AddToWorklist(Shl0.getNode());
7684     AddToWorklist(Shl1.getNode());
7685     return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, Shl0, Shl1);
7686   }
7687 
7688   // fold (shl (mul x, c1), c2) -> (mul x, c1 << c2)
7689   if (N0.getOpcode() == ISD::MUL && N0.getNode()->hasOneUse() &&
7690       isConstantOrConstantVector(N1, /* No Opaques */ true) &&
7691       isConstantOrConstantVector(N0.getOperand(1), /* No Opaques */ true)) {
7692     SDValue Shl = DAG.getNode(ISD::SHL, SDLoc(N1), VT, N0.getOperand(1), N1);
7693     if (isConstantOrConstantVector(Shl))
7694       return DAG.getNode(ISD::MUL, SDLoc(N), VT, N0.getOperand(0), Shl);
7695   }
7696 
7697   if (N1C && !N1C->isOpaque())
7698     if (SDValue NewSHL = visitShiftByConstant(N))
7699       return NewSHL;
7700 
7701   return SDValue();
7702 }
7703 
7704 SDValue DAGCombiner::visitSRA(SDNode *N) {
7705   SDValue N0 = N->getOperand(0);
7706   SDValue N1 = N->getOperand(1);
7707   if (SDValue V = DAG.simplifyShift(N0, N1))
7708     return V;
7709 
7710   EVT VT = N0.getValueType();
7711   unsigned OpSizeInBits = VT.getScalarSizeInBits();
7712 
7713   // Arithmetic shifting an all-sign-bit value is a no-op.
7714   // fold (sra 0, x) -> 0
7715   // fold (sra -1, x) -> -1
7716   if (DAG.ComputeNumSignBits(N0) == OpSizeInBits)
7717     return N0;
7718 
7719   // fold vector ops
7720   if (VT.isVector())
7721     if (SDValue FoldedVOp = SimplifyVBinOp(N))
7722       return FoldedVOp;
7723 
7724   ConstantSDNode *N1C = isConstOrConstSplat(N1);
7725 
7726   // fold (sra c1, c2) -> c1>>c2 (arithmetic shift)
7727   // TODO - support non-uniform vector shift amounts.
7728   ConstantSDNode *N0C = getAsNonOpaqueConstant(N0);
7729   if (N0C && N1C && !N1C->isOpaque())
7730     return DAG.FoldConstantArithmetic(ISD::SRA, SDLoc(N), VT, N0C, N1C);
7731 
7732   if (SDValue NewSel = foldBinOpIntoSelect(N))
7733     return NewSel;
7734 
7735   // fold (sra (shl x, c1), c1) -> sext_inreg for some c1 if the target
7736   // supports sext_inreg.
7737   if (N1C && N0.getOpcode() == ISD::SHL && N1 == N0.getOperand(1)) {
7738     unsigned LowBits = OpSizeInBits - (unsigned)N1C->getZExtValue();
7739     EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), LowBits);
7740     if (VT.isVector())
7741       ExtVT = EVT::getVectorVT(*DAG.getContext(),
7742                                ExtVT, VT.getVectorNumElements());
7743     if (!LegalOperations ||
7744         TLI.getOperationAction(ISD::SIGN_EXTEND_INREG, ExtVT) ==
7745         TargetLowering::Legal)
7746       return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT,
7747                          N0.getOperand(0), DAG.getValueType(ExtVT));
7748   }
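  // E.g. for i32 (illustrative): sra (shl X, 24), 24 -->
  // sign_extend_inreg X, i8, when SIGN_EXTEND_INREG of i8 is legal.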
7749 
7750   // fold (sra (sra x, c1), c2) -> (sra x, (add c1, c2))
7751   // clamp (add c1, c2) to max shift.
7752   if (N0.getOpcode() == ISD::SRA) {
7753     SDLoc DL(N);
7754     EVT ShiftVT = N1.getValueType();
7755     EVT ShiftSVT = ShiftVT.getScalarType();
7756     SmallVector<SDValue, 16> ShiftValues;
7757 
7758     auto SumOfShifts = [&](ConstantSDNode *LHS, ConstantSDNode *RHS) {
7759       APInt c1 = LHS->getAPIntValue();
7760       APInt c2 = RHS->getAPIntValue();
7761       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
7762       APInt Sum = c1 + c2;
7763       unsigned ShiftSum =
7764           Sum.uge(OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue();
7765       ShiftValues.push_back(DAG.getConstant(ShiftSum, DL, ShiftSVT));
7766       return true;
7767     };
7768     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), SumOfShifts)) {
7769       SDValue ShiftValue;
7770       if (VT.isVector())
7771         ShiftValue = DAG.getBuildVector(ShiftVT, DL, ShiftValues);
7772       else
7773         ShiftValue = ShiftValues[0];
7774       return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0), ShiftValue);
7775     }
7776   }
7777 
7778   // fold (sra (shl X, m), (sub result_size, n))
7779   // -> (sign_extend (trunc (srl X, (sub (sub result_size, n), m)))) for
7780   // result_size - n != m.
7781   // If truncate is free for the target, the sext(trunc(srl)) form is likely
7782   // to result in better code than the original sra(shl).
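  // Illustrative sketch (assuming i32): with m = 8 and a shift amount of 24,
  // (sra (shl X, 8), 24) selects bits [16..23] of X and sign-extends them,
  // i.e. it becomes sext(trunc i8 (srl X, 16)).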
7783   if (N0.getOpcode() == ISD::SHL && N1C) {
7784     // Get the two constants of the shifts, CN0 = m, CN = n.
7785     const ConstantSDNode *N01C = isConstOrConstSplat(N0.getOperand(1));
7786     if (N01C) {
7787       LLVMContext &Ctx = *DAG.getContext();
7788       // Determine what the truncate's result bitsize and type would be.
7789       EVT TruncVT = EVT::getIntegerVT(Ctx, OpSizeInBits - N1C->getZExtValue());
7790 
7791       if (VT.isVector())
7792         TruncVT = EVT::getVectorVT(Ctx, TruncVT, VT.getVectorNumElements());
7793 
7794       // Determine the residual right-shift amount.
7795       int ShiftAmt = N1C->getZExtValue() - N01C->getZExtValue();
7796 
7797       // If the shift is not a no-op (in which case this should be just a sign
7798       // extend already), the truncated-to type is legal, sign_extend is legal
7799       // on that type, and the truncate to that type is both legal and free,
7800       // perform the transform.
7801       if ((ShiftAmt > 0) &&
7802           TLI.isOperationLegalOrCustom(ISD::SIGN_EXTEND, TruncVT) &&
7803           TLI.isOperationLegalOrCustom(ISD::TRUNCATE, VT) &&
7804           TLI.isTruncateFree(VT, TruncVT)) {
7805         SDLoc DL(N);
7806         SDValue Amt = DAG.getConstant(ShiftAmt, DL,
7807             getShiftAmountTy(N0.getOperand(0).getValueType()));
7808         SDValue Shift = DAG.getNode(ISD::SRL, DL, VT,
7809                                     N0.getOperand(0), Amt);
7810         SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, TruncVT,
7811                                     Shift);
7812         return DAG.getNode(ISD::SIGN_EXTEND, DL,
7813                            N->getValueType(0), Trunc);
7814       }
7815     }
7816   }
7817 
7818   // We convert trunc/ext to opposing shifts in IR, but casts may be cheaper.
7819   //   sra (add (shl X, N1C), AddC), N1C -->
7820   //   sext (add (trunc X to (width - N1C)), AddC')
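  // Illustrative sketch (assuming i32 and N1C = 16): with AddC = 0x30000,
  // (sra (add (shl X, 16), 0x30000), 16) becomes
  // sext(add (trunc X to i16), 3), since AddC' = 0x30000 >> 16 == 3.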
7821   if (!LegalTypes && N0.getOpcode() == ISD::ADD && N0.hasOneUse() && N1C &&
7822       N0.getOperand(0).getOpcode() == ISD::SHL &&
7823       N0.getOperand(0).getOperand(1) == N1 && N0.getOperand(0).hasOneUse()) {
7824     if (ConstantSDNode *AddC = isConstOrConstSplat(N0.getOperand(1))) {
7825       SDValue Shl = N0.getOperand(0);
7826       // Determine what the truncate's type would be and ask the target if that
7827       // is a free operation.
7828       LLVMContext &Ctx = *DAG.getContext();
7829       unsigned ShiftAmt = N1C->getZExtValue();
7830       EVT TruncVT = EVT::getIntegerVT(Ctx, OpSizeInBits - ShiftAmt);
7831       if (VT.isVector())
7832         TruncVT = EVT::getVectorVT(Ctx, TruncVT, VT.getVectorNumElements());
7833 
7834       // TODO: The simple type check probably belongs in the default hook
7835       //       implementation and/or target-specific overrides (because
7836       //       non-simple types likely require masking when legalized), but that
7837       //       restriction may conflict with other transforms.
7838       if (TruncVT.isSimple() && TLI.isTruncateFree(VT, TruncVT)) {
7839         SDLoc DL(N);
7840         SDValue Trunc = DAG.getZExtOrTrunc(Shl.getOperand(0), DL, TruncVT);
7841         SDValue ShiftC = DAG.getConstant(AddC->getAPIntValue().lshr(ShiftAmt).
7842                              trunc(TruncVT.getScalarSizeInBits()), DL, TruncVT);
7843         SDValue Add = DAG.getNode(ISD::ADD, DL, TruncVT, Trunc, ShiftC);
7844         return DAG.getSExtOrTrunc(Add, DL, VT);
7845       }
7846     }
7847   }
7848 
7849   // fold (sra x, (trunc (and y, c))) -> (sra x, (and (trunc y), (trunc c))).
7850   if (N1.getOpcode() == ISD::TRUNCATE &&
7851       N1.getOperand(0).getOpcode() == ISD::AND) {
7852     if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
7853       return DAG.getNode(ISD::SRA, SDLoc(N), VT, N0, NewOp1);
7854   }
7855 
7856   // fold (sra (trunc (sra x, c1)), c2) -> (trunc (sra x, c1 + c2))
7857   // fold (sra (trunc (srl x, c1)), c2) -> (trunc (sra x, c1 + c2))
7858   //      if c1 is equal to the number of bits the trunc removes
7859   // TODO - support non-uniform vector shift amounts.
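  // Illustrative sketch (i64 truncated to i32, so the trunc removes 32
  // bits): with c1 = 32 and c2 = 5, (sra (trunc (srl x, 32)), 5) becomes
  // (trunc (sra x, 37)).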
7860   if (N0.getOpcode() == ISD::TRUNCATE &&
7861       (N0.getOperand(0).getOpcode() == ISD::SRL ||
7862        N0.getOperand(0).getOpcode() == ISD::SRA) &&
7863       N0.getOperand(0).hasOneUse() &&
7864       N0.getOperand(0).getOperand(1).hasOneUse() && N1C) {
7865     SDValue N0Op0 = N0.getOperand(0);
7866     if (ConstantSDNode *LargeShift = isConstOrConstSplat(N0Op0.getOperand(1))) {
7867       EVT LargeVT = N0Op0.getValueType();
7868       unsigned TruncBits = LargeVT.getScalarSizeInBits() - OpSizeInBits;
7869       if (LargeShift->getAPIntValue() == TruncBits) {
7870         SDLoc DL(N);
7871         SDValue Amt = DAG.getConstant(N1C->getZExtValue() + TruncBits, DL,
7872                                       getShiftAmountTy(LargeVT));
7873         SDValue SRA =
7874             DAG.getNode(ISD::SRA, DL, LargeVT, N0Op0.getOperand(0), Amt);
7875         return DAG.getNode(ISD::TRUNCATE, DL, VT, SRA);
7876       }
7877     }
7878   }
7879 
7880   // Simplify, based on bits shifted out of the LHS.
7881   // TODO - support non-uniform vector shift amounts.
7882   if (N1C && SimplifyDemandedBits(SDValue(N, 0)))
7883     return SDValue(N, 0);
7884 
7885   // If the sign bit is known to be zero, switch this to a SRL.
7886   if (DAG.SignBitIsZero(N0))
7887     return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0, N1);
7888 
7889   if (N1C && !N1C->isOpaque())
7890     if (SDValue NewSRA = visitShiftByConstant(N))
7891       return NewSRA;
7892 
7893   return SDValue();
7894 }
7895 
7896 SDValue DAGCombiner::visitSRL(SDNode *N) {
7897   SDValue N0 = N->getOperand(0);
7898   SDValue N1 = N->getOperand(1);
7899   if (SDValue V = DAG.simplifyShift(N0, N1))
7900     return V;
7901 
7902   EVT VT = N0.getValueType();
7903   unsigned OpSizeInBits = VT.getScalarSizeInBits();
7904 
7905   // fold vector ops
7906   if (VT.isVector())
7907     if (SDValue FoldedVOp = SimplifyVBinOp(N))
7908       return FoldedVOp;
7909 
7910   ConstantSDNode *N1C = isConstOrConstSplat(N1);
7911 
7912   // fold (srl c1, c2) -> c1 >>u c2
7913   // TODO - support non-uniform vector shift amounts.
7914   ConstantSDNode *N0C = getAsNonOpaqueConstant(N0);
7915   if (N0C && N1C && !N1C->isOpaque())
7916     return DAG.FoldConstantArithmetic(ISD::SRL, SDLoc(N), VT, N0C, N1C);
7917 
7918   if (SDValue NewSel = foldBinOpIntoSelect(N))
7919     return NewSel;
7920 
7921   // if (srl x, c) is known to be zero, return 0
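  // E.g. (illustrative): (srl (and x, 15), 4) is always zero, because the
  // only bits that could survive the mask are shifted out.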
7922   if (N1C && DAG.MaskedValueIsZero(SDValue(N, 0),
7923                                    APInt::getAllOnesValue(OpSizeInBits)))
7924     return DAG.getConstant(0, SDLoc(N), VT);
7925 
7926   // fold (srl (srl x, c1), c2) -> 0 or (srl x, (add c1, c2))
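  // Illustrative sketch (assuming i32): (srl (srl x, 20), 20) folds to zero
  // because 20 + 20 >= 32, while (srl (srl x, 10), 5) becomes (srl x, 15).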
7927   if (N0.getOpcode() == ISD::SRL) {
7928     auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
7929                                           ConstantSDNode *RHS) {
7930       APInt c1 = LHS->getAPIntValue();
7931       APInt c2 = RHS->getAPIntValue();
7932       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
7933       return (c1 + c2).uge(OpSizeInBits);
7934     };
7935     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
7936       return DAG.getConstant(0, SDLoc(N), VT);
7937 
7938     auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
7939                                        ConstantSDNode *RHS) {
7940       APInt c1 = LHS->getAPIntValue();
7941       APInt c2 = RHS->getAPIntValue();
7942       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
7943       return (c1 + c2).ult(OpSizeInBits);
7944     };
7945     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
7946       SDLoc DL(N);
7947       EVT ShiftVT = N1.getValueType();
7948       SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
7949       return DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Sum);
7950     }
7951   }
7952 
7953   if (N1C && N0.getOpcode() == ISD::TRUNCATE &&
7954       N0.getOperand(0).getOpcode() == ISD::SRL) {
7955     SDValue InnerShift = N0.getOperand(0);
7956     // TODO - support non-uniform vector shift amounts.
7957     if (auto *N001C = isConstOrConstSplat(InnerShift.getOperand(1))) {
7958       uint64_t c1 = N001C->getZExtValue();
7959       uint64_t c2 = N1C->getZExtValue();
7960       EVT InnerShiftVT = InnerShift.getValueType();
7961       EVT ShiftAmtVT = InnerShift.getOperand(1).getValueType();
7962       uint64_t InnerShiftSize = InnerShiftVT.getScalarSizeInBits();
7963       // srl (trunc (srl x, c1)), c2 --> 0 or (trunc (srl x, (add c1, c2)))
7964       // This is only valid if the OpSizeInBits + c1 = size of inner shift.
7965       if (c1 + OpSizeInBits == InnerShiftSize) {
7966         SDLoc DL(N);
7967         if (c1 + c2 >= InnerShiftSize)
7968           return DAG.getConstant(0, DL, VT);
7969         SDValue NewShiftAmt = DAG.getConstant(c1 + c2, DL, ShiftAmtVT);
7970         SDValue NewShift = DAG.getNode(ISD::SRL, DL, InnerShiftVT,
7971                                        InnerShift.getOperand(0), NewShiftAmt);
7972         return DAG.getNode(ISD::TRUNCATE, DL, VT, NewShift);
7973       }
7974       // In the more general case, we can clear the high bits after the shift:
7975       // srl (trunc (srl x, c1)), c2 --> trunc (and (srl x, (c1+c2)), Mask)
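      // Illustrative sketch (i64 truncated to i32, c1 = 8, c2 = 4): the
      // combined shift is 12 and the mask keeps the low 32 - 4 = 28 bits,
      // exactly the bits that survive the truncate and the outer shift.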
7976       if (N0.hasOneUse() && InnerShift.hasOneUse() &&
7977           c1 + c2 < InnerShiftSize) {
7978         SDLoc DL(N);
7979         SDValue NewShiftAmt = DAG.getConstant(c1 + c2, DL, ShiftAmtVT);
7980         SDValue NewShift = DAG.getNode(ISD::SRL, DL, InnerShiftVT,
7981                                        InnerShift.getOperand(0), NewShiftAmt);
7982         SDValue Mask = DAG.getConstant(APInt::getLowBitsSet(InnerShiftSize,
7983                                                             OpSizeInBits - c2),
7984                                        DL, InnerShiftVT);
7985         SDValue And = DAG.getNode(ISD::AND, DL, InnerShiftVT, NewShift, Mask);
7986         return DAG.getNode(ISD::TRUNCATE, DL, VT, And);
7987       }
7988     }
7989   }
7990 
7991   // fold (srl (shl x, c), c) -> (and x, cst2)
7992   // TODO - (srl (shl x, c1), c2).
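  // E.g. (illustrative, assuming i32): for c = 24 the mask is
  // 0xFFFFFFFF >> 24 == 0xFF, so (srl (shl x, 24), 24) becomes (and x, 255).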
7993   if (N0.getOpcode() == ISD::SHL && N0.getOperand(1) == N1 &&
7994       isConstantOrConstantVector(N1, /* NoOpaques */ true)) {
7995     SDLoc DL(N);
7996     SDValue Mask =
7997         DAG.getNode(ISD::SRL, DL, VT, DAG.getAllOnesConstant(DL, VT), N1);
7998     AddToWorklist(Mask.getNode());
7999     return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), Mask);
8000   }
8001 
8002   // fold (srl (anyextend x), c) -> (and (anyextend (srl x, c)), mask)
8003   // TODO - support non-uniform vector shift amounts.
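  // Illustrative sketch (i16 any-extended to i32): a shift amount >= 16
  // shifts in only undef bits, so the node folds to undef below; for a
  // smaller amount c, the narrow srl is masked with the low (32 - c) bits to
  // clear the now-defined high bits.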
8004   if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
8005     // Shifting in all undef bits?
8006     EVT SmallVT = N0.getOperand(0).getValueType();
8007     unsigned BitSize = SmallVT.getScalarSizeInBits();
8008     if (N1C->getAPIntValue().uge(BitSize))
8009       return DAG.getUNDEF(VT);
8010 
8011     if (!LegalTypes || TLI.isTypeDesirableForOp(ISD::SRL, SmallVT)) {
8012       uint64_t ShiftAmt = N1C->getZExtValue();
8013       SDLoc DL0(N0);
8014       SDValue SmallShift = DAG.getNode(ISD::SRL, DL0, SmallVT,
8015                                        N0.getOperand(0),
8016                           DAG.getConstant(ShiftAmt, DL0,
8017                                           getShiftAmountTy(SmallVT)));
8018       AddToWorklist(SmallShift.getNode());
8019       APInt Mask = APInt::getLowBitsSet(OpSizeInBits, OpSizeInBits - ShiftAmt);
8020       SDLoc DL(N);
8021       return DAG.getNode(ISD::AND, DL, VT,
8022                          DAG.getNode(ISD::ANY_EXTEND, DL, VT, SmallShift),
8023                          DAG.getConstant(Mask, DL, VT));
8024     }
8025   }
8026 
8027   // fold (srl (sra X, Y), 31) -> (srl X, 31).  This srl only looks at the sign
8028   // bit, which is unmodified by sra.
8029   if (N1C && N1C->getAPIntValue() == (OpSizeInBits - 1)) {
8030     if (N0.getOpcode() == ISD::SRA)
8031       return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0.getOperand(0), N1);
8032   }
8033 
8034   // fold (srl (ctlz x), "5") -> (xor x, 1) iff x may have only the low bit set.
8035   if (N1C && N0.getOpcode() == ISD::CTLZ &&
8036       N1C->getAPIntValue() == Log2_32(OpSizeInBits)) {
8037     KnownBits Known = DAG.computeKnownBits(N0.getOperand(0));
8038 
8039     // If any of the input bits are KnownOne, then the input couldn't be all
8040     // zeros, thus the result of the srl will always be zero.
8041     if (Known.One.getBoolValue()) return DAG.getConstant(0, SDLoc(N0), VT);
8042 
8043     // If all of the bits input to the ctlz node are known to be zero, then
8044     // the result of the ctlz is "32" and the result of the shift is one.
8045     APInt UnknownBits = ~Known.Zero;
8046     if (UnknownBits == 0) return DAG.getConstant(1, SDLoc(N0), VT);
8047 
8048     // Otherwise, check to see if there is exactly one bit input to the ctlz.
8049     if (UnknownBits.isPowerOf2()) {
8050       // Okay, we know that only the single bit specified by UnknownBits
8051       // could be set on input to the CTLZ node. If this bit is set, the SRL
8052       // will return 0; if it is clear, it returns 1. Change the CTLZ/SRL pair
8053       // to an SRL/XOR pair, which is likely to simplify more.
8054       unsigned ShAmt = UnknownBits.countTrailingZeros();
8055       SDValue Op = N0.getOperand(0);
8056 
8057       if (ShAmt) {
8058         SDLoc DL(N0);
8059         Op = DAG.getNode(ISD::SRL, DL, VT, Op,
8060                   DAG.getConstant(ShAmt, DL,
8061                                   getShiftAmountTy(Op.getValueType())));
8062         AddToWorklist(Op.getNode());
8063       }
8064 
8065       SDLoc DL(N);
8066       return DAG.getNode(ISD::XOR, DL, VT,
8067                          Op, DAG.getConstant(1, DL, VT));
8068     }
8069   }
8070 
8071   // fold (srl x, (trunc (and y, c))) -> (srl x, (and (trunc y), (trunc c))).
8072   if (N1.getOpcode() == ISD::TRUNCATE &&
8073       N1.getOperand(0).getOpcode() == ISD::AND) {
8074     if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
8075       return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0, NewOp1);
8076   }
8077 
8078   // fold operands of srl based on knowledge that the low bits are not
8079   // demanded.
8080   // TODO - support non-uniform vector shift amounts.
8081   if (N1C && SimplifyDemandedBits(SDValue(N, 0)))
8082     return SDValue(N, 0);
8083 
8084   if (N1C && !N1C->isOpaque())
8085     if (SDValue NewSRL = visitShiftByConstant(N))
8086       return NewSRL;
8087 
8088   // Attempt to convert a srl of a load into a narrower zero-extending load.
8089   if (SDValue NarrowLoad = ReduceLoadWidth(N))
8090     return NarrowLoad;
8091 
8092   // Here is a common situation. We want to optimize:
8093   //
8094   //   %a = ...
8095   //   %b = and i32 %a, 2
8096   //   %c = srl i32 %b, 1
8097   //   brcond i32 %c ...
8098   //
8099   // into
8100   //
8101   //   %a = ...
8102   //   %b = and %a, 2
8103   //   %c = setcc eq %b, 0
8104   //   brcond %c ...
8105   //
8106   // However, after the source operand of the SRL is optimized into an AND, the
8107   // SRL itself may not be optimized further. Look for it and add the BRCOND
8108   // into the worklist.
8109   if (N->hasOneUse()) {
8110     SDNode *Use = *N->use_begin();
8111     if (Use->getOpcode() == ISD::BRCOND)
8112       AddToWorklist(Use);
8113     else if (Use->getOpcode() == ISD::TRUNCATE && Use->hasOneUse()) {
8114       // Also look past the truncate.
8115       Use = *Use->use_begin();
8116       if (Use->getOpcode() == ISD::BRCOND)
8117         AddToWorklist(Use);
8118     }
8119   }
8120 
8121   return SDValue();
8122 }
8123 
8124 SDValue DAGCombiner::visitFunnelShift(SDNode *N) {
8125   EVT VT = N->getValueType(0);
8126   SDValue N0 = N->getOperand(0);
8127   SDValue N1 = N->getOperand(1);
8128   SDValue N2 = N->getOperand(2);
8129   bool IsFSHL = N->getOpcode() == ISD::FSHL;
8130   unsigned BitWidth = VT.getScalarSizeInBits();
8131 
8132   // fold (fshl N0, N1, 0) -> N0
8133   // fold (fshr N0, N1, 0) -> N1
8134   if (isPowerOf2_32(BitWidth))
8135     if (DAG.MaskedValueIsZero(
8136             N2, APInt(N2.getScalarValueSizeInBits(), BitWidth - 1)))
8137       return IsFSHL ? N0 : N1;
8138 
8139   auto IsUndefOrZero = [](SDValue V) {
8140     return V.isUndef() || isNullOrNullSplat(V, /*AllowUndefs*/ true);
8141   };
8142 
8143   // TODO - support non-uniform vector shift amounts.
8144   if (ConstantSDNode *Cst = isConstOrConstSplat(N2)) {
8145     EVT ShAmtTy = N2.getValueType();
8146 
8147     // fold (fsh* N0, N1, c) -> (fsh* N0, N1, c % BitWidth)
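    // E.g. (illustrative, assuming i32): fshl(a, b, 37) is rewritten as
    // fshl(a, b, 37 % 32) == fshl(a, b, 5).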
8148     if (Cst->getAPIntValue().uge(BitWidth)) {
8149       uint64_t RotAmt = Cst->getAPIntValue().urem(BitWidth);
8150       return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N0, N1,
8151                          DAG.getConstant(RotAmt, SDLoc(N), ShAmtTy));
8152     }
8153 
8154     unsigned ShAmt = Cst->getZExtValue();
8155     if (ShAmt == 0)
8156       return IsFSHL ? N0 : N1;
8157 
8158     // fold fshl(undef_or_zero, N1, C) -> lshr(N1, BW-C)
8159     // fold fshr(undef_or_zero, N1, C) -> lshr(N1, C)
8160     // fold fshl(N0, undef_or_zero, C) -> shl(N0, C)
8161     // fold fshr(N0, undef_or_zero, C) -> shl(N0, BW-C)
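    // E.g. (illustrative, i32 with C = 8): fshl computes
    // (N0 << 8) | (N1 >> (32 - 8)), so a zero N0 leaves lshr(N1, 24) and a
    // zero N1 leaves shl(N0, 8).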
8162     if (IsUndefOrZero(N0))
8163       return DAG.getNode(ISD::SRL, SDLoc(N), VT, N1,
8164                          DAG.getConstant(IsFSHL ? BitWidth - ShAmt : ShAmt,
8165                                          SDLoc(N), ShAmtTy));
8166     if (IsUndefOrZero(N1))
8167       return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0,
8168                          DAG.getConstant(IsFSHL ? ShAmt : BitWidth - ShAmt,
8169                                          SDLoc(N), ShAmtTy));
8170   }
8171 
8172   // fold fshr(undef_or_zero, N1, N2) -> lshr(N1, N2)
8173   // fold fshl(N0, undef_or_zero, N2) -> shl(N0, N2)
8174   // iff we know the shift amount is in range.
8175   // TODO: when is it worth doing SUB(BW, N2) as well?
8176   if (isPowerOf2_32(BitWidth)) {
8177     APInt ModuloBits(N2.getScalarValueSizeInBits(), BitWidth - 1);
8178     if (IsUndefOrZero(N0) && !IsFSHL && DAG.MaskedValueIsZero(N2, ~ModuloBits))
8179       return DAG.getNode(ISD::SRL, SDLoc(N), VT, N1, N2);
8180     if (IsUndefOrZero(N1) && IsFSHL && DAG.MaskedValueIsZero(N2, ~ModuloBits))
8181       return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, N2);
8182   }
8183 
8184   // fold (fshl N0, N0, N2) -> (rotl N0, N2)
8185   // fold (fshr N0, N0, N2) -> (rotr N0, N2)
8186   // TODO: Investigate flipping this rotate if only one is legal; if funnel
8187   // shift is legal too, we might be better off avoiding non-constant (BW - N2).
8188   unsigned RotOpc = IsFSHL ? ISD::ROTL : ISD::ROTR;
8189   if (N0 == N1 && hasOperation(RotOpc, VT))
8190     return DAG.getNode(RotOpc, SDLoc(N), VT, N0, N2);
8191 
8192   // Simplify, based on bits shifted out of N0/N1.
8193   if (SimplifyDemandedBits(SDValue(N, 0)))
8194     return SDValue(N, 0);
8195 
8196   return SDValue();
8197 }
8198 
8199 SDValue DAGCombiner::visitABS(SDNode *N) {
8200   SDValue N0 = N->getOperand(0);
8201   EVT VT = N->getValueType(0);
8202 
8203   // fold (abs c1) -> c2
8204   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
8205     return DAG.getNode(ISD::ABS, SDLoc(N), VT, N0);
8206   // fold (abs (abs x)) -> (abs x)
8207   if (N0.getOpcode() == ISD::ABS)
8208     return N0;
8209   // fold (abs x) -> x iff not-negative
8210   if (DAG.SignBitIsZero(N0))
8211     return N0;
8212   return SDValue();
8213 }
8214 
8215 SDValue DAGCombiner::visitBSWAP(SDNode *N) {
8216   SDValue N0 = N->getOperand(0);
8217   EVT VT = N->getValueType(0);
8218 
8219   // fold (bswap c1) -> c2
8220   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
8221     return DAG.getNode(ISD::BSWAP, SDLoc(N), VT, N0);
8222   // fold (bswap (bswap x)) -> x
8223   if (N0.getOpcode() == ISD::BSWAP)
8224     return N0->getOperand(0);
8225   return SDValue();
8226 }
8227 
8228 SDValue DAGCombiner::visitBITREVERSE(SDNode *N) {
8229   SDValue N0 = N->getOperand(0);
8230   EVT VT = N->getValueType(0);
8231 
8232   // fold (bitreverse c1) -> c2
8233   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
8234     return DAG.getNode(ISD::BITREVERSE, SDLoc(N), VT, N0);
8235   // fold (bitreverse (bitreverse x)) -> x
8236   if (N0.getOpcode() == ISD::BITREVERSE)
8237     return N0.getOperand(0);
8238   return SDValue();
8239 }
8240 
8241 SDValue DAGCombiner::visitCTLZ(SDNode *N) {
8242   SDValue N0 = N->getOperand(0);
8243   EVT VT = N->getValueType(0);
8244 
8245   // fold (ctlz c1) -> c2
8246   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
8247     return DAG.getNode(ISD::CTLZ, SDLoc(N), VT, N0);
8248 
8249   // If the value is known never to be zero, switch to the undef version.
8250   if (!LegalOperations || TLI.isOperationLegal(ISD::CTLZ_ZERO_UNDEF, VT)) {
8251     if (DAG.isKnownNeverZero(N0))
8252       return DAG.getNode(ISD::CTLZ_ZERO_UNDEF, SDLoc(N), VT, N0);
8253   }
8254 
8255   return SDValue();
8256 }
8257 
8258 SDValue DAGCombiner::visitCTLZ_ZERO_UNDEF(SDNode *N) {
8259   SDValue N0 = N->getOperand(0);
8260   EVT VT = N->getValueType(0);
8261 
8262   // fold (ctlz_zero_undef c1) -> c2
8263   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
8264     return DAG.getNode(ISD::CTLZ_ZERO_UNDEF, SDLoc(N), VT, N0);
8265   return SDValue();
8266 }
8267 
8268 SDValue DAGCombiner::visitCTTZ(SDNode *N) {
8269   SDValue N0 = N->getOperand(0);
8270   EVT VT = N->getValueType(0);
8271 
8272   // fold (cttz c1) -> c2
8273   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
8274     return DAG.getNode(ISD::CTTZ, SDLoc(N), VT, N0);
8275 
8276   // If the value is known never to be zero, switch to the undef version.
8277   if (!LegalOperations || TLI.isOperationLegal(ISD::CTTZ_ZERO_UNDEF, VT)) {
8278     if (DAG.isKnownNeverZero(N0))
8279       return DAG.getNode(ISD::CTTZ_ZERO_UNDEF, SDLoc(N), VT, N0);
8280   }
8281 
8282   return SDValue();
8283 }
8284 
8285 SDValue DAGCombiner::visitCTTZ_ZERO_UNDEF(SDNode *N) {
8286   SDValue N0 = N->getOperand(0);
8287   EVT VT = N->getValueType(0);
8288 
8289   // fold (cttz_zero_undef c1) -> c2
8290   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
8291     return DAG.getNode(ISD::CTTZ_ZERO_UNDEF, SDLoc(N), VT, N0);
8292   return SDValue();
8293 }
8294 
8295 SDValue DAGCombiner::visitCTPOP(SDNode *N) {
8296   SDValue N0 = N->getOperand(0);
8297   EVT VT = N->getValueType(0);
8298 
8299   // fold (ctpop c1) -> c2
8300   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
8301     return DAG.getNode(ISD::CTPOP, SDLoc(N), VT, N0);
8302   return SDValue();
8303 }
8304 
8305 // FIXME: This should be checking for no signed zeros on individual operands, as
8306 // well as no nans.
8307 static bool isLegalToCombineMinNumMaxNum(SelectionDAG &DAG, SDValue LHS,
8308                                          SDValue RHS,
8309                                          const TargetLowering &TLI) {
8310   const TargetOptions &Options = DAG.getTarget().Options;
8311   EVT VT = LHS.getValueType();
8312 
8313   return Options.NoSignedZerosFPMath && VT.isFloatingPoint() &&
8314          TLI.isProfitableToCombineMinNumMaxNum(VT) &&
8315          DAG.isKnownNeverNaN(LHS) && DAG.isKnownNeverNaN(RHS);
8316 }
8317 
8318 /// Generate Min/Max node
8319 static SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
8320                                    SDValue RHS, SDValue True, SDValue False,
8321                                    ISD::CondCode CC, const TargetLowering &TLI,
8322                                    SelectionDAG &DAG) {
8323   if (!(LHS == True && RHS == False) && !(LHS == False && RHS == True))
8324     return SDValue();
8325 
8326   EVT TransformVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
8327   switch (CC) {
8328   case ISD::SETOLT:
8329   case ISD::SETOLE:
8330   case ISD::SETLT:
8331   case ISD::SETLE:
8332   case ISD::SETULT:
8333   case ISD::SETULE: {
8334     // Since the operands are already known never to be NaN here, either
8335     // fminnum or fminnum_ieee is OK. Try the ieee version first, since
8336     // fminnum is expanded in terms of it.
8337     unsigned IEEEOpcode = (LHS == True) ? ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE;
8338     if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT))
8339       return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS);
8340 
8341     unsigned Opcode = (LHS == True) ? ISD::FMINNUM : ISD::FMAXNUM;
8342     if (TLI.isOperationLegalOrCustom(Opcode, TransformVT))
8343       return DAG.getNode(Opcode, DL, VT, LHS, RHS);
8344     return SDValue();
8345   }
8346   case ISD::SETOGT:
8347   case ISD::SETOGE:
8348   case ISD::SETGT:
8349   case ISD::SETGE:
8350   case ISD::SETUGT:
8351   case ISD::SETUGE: {
8352     unsigned IEEEOpcode = (LHS == True) ? ISD::FMAXNUM_IEEE : ISD::FMINNUM_IEEE;
8353     if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT))
8354       return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS);
8355 
8356     unsigned Opcode = (LHS == True) ? ISD::FMAXNUM : ISD::FMINNUM;
8357     if (TLI.isOperationLegalOrCustom(Opcode, TransformVT))
8358       return DAG.getNode(Opcode, DL, VT, LHS, RHS);
8359     return SDValue();
8360   }
8361   default:
8362     return SDValue();
8363   }
8364 }
8365 
8366 /// If a (v)select has a condition value that is a sign-bit test, try to smear
8367 /// the condition operand sign-bit across the value width and use it as a mask.
8368 static SDValue foldSelectOfConstantsUsingSra(SDNode *N, SelectionDAG &DAG) {
8369   SDValue Cond = N->getOperand(0);
8370   SDValue C1 = N->getOperand(1);
8371   SDValue C2 = N->getOperand(2);
8372   assert(isConstantOrConstantVector(C1) && isConstantOrConstantVector(C2) &&
8373          "Expected select-of-constants");
8374 
8375   EVT VT = N->getValueType(0);
8376   if (Cond.getOpcode() != ISD::SETCC || !Cond.hasOneUse() ||
8377       VT != Cond.getOperand(0).getValueType())
8378     return SDValue();
8379 
8380   // The inverted-condition + commuted-select variants of these patterns are
8381   // canonicalized to these forms in IR.
8382   SDValue X = Cond.getOperand(0);
8383   SDValue CondC = Cond.getOperand(1);
8384   ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
8385   if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(CondC) &&
8386       isAllOnesOrAllOnesSplat(C2)) {
8387     // i32 X > -1 ? C1 : -1 --> (X >>s 31) | C1
8388     SDLoc DL(N);
8389     SDValue ShAmtC = DAG.getConstant(X.getScalarValueSizeInBits() - 1, DL, VT);
8390     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, X, ShAmtC);
8391     return DAG.getNode(ISD::OR, DL, VT, Sra, C1);
8392   }
8393   if (CC == ISD::SETLT && isNullOrNullSplat(CondC) && isNullOrNullSplat(C2)) {
8394     // i8 X < 0 ? C1 : 0 --> (X >>s 7) & C1
8395     SDLoc DL(N);
8396     SDValue ShAmtC = DAG.getConstant(X.getScalarValueSizeInBits() - 1, DL, VT);
8397     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, X, ShAmtC);
8398     return DAG.getNode(ISD::AND, DL, VT, Sra, C1);
8399   }
8400   return SDValue();
8401 }
8402 
8403 SDValue DAGCombiner::foldSelectOfConstants(SDNode *N) {
8404   SDValue Cond = N->getOperand(0);
8405   SDValue N1 = N->getOperand(1);
8406   SDValue N2 = N->getOperand(2);
8407   EVT VT = N->getValueType(0);
8408   EVT CondVT = Cond.getValueType();
8409   SDLoc DL(N);
8410 
8411   if (!VT.isInteger())
8412     return SDValue();
8413 
8414   auto *C1 = dyn_cast<ConstantSDNode>(N1);
8415   auto *C2 = dyn_cast<ConstantSDNode>(N2);
8416   if (!C1 || !C2)
8417     return SDValue();
8418 
8419   // Only do this before legalization to avoid conflicting with target-specific
8420   // transforms in the other direction (create a select from a zext/sext). There
8421   // is also a target-independent combine here in DAGCombiner in the other
8422   // direction for (select Cond, -1, 0) when the condition is not i1.
8423   if (CondVT == MVT::i1 && !LegalOperations) {
8424     if (C1->isNullValue() && C2->isOne()) {
8425       // select Cond, 0, 1 --> zext (!Cond)
8426       SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
8427       if (VT != MVT::i1)
8428         NotCond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, NotCond);
8429       return NotCond;
8430     }
8431     if (C1->isNullValue() && C2->isAllOnesValue()) {
8432       // select Cond, 0, -1 --> sext (!Cond)
8433       SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
8434       if (VT != MVT::i1)
8435         NotCond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, NotCond);
8436       return NotCond;
8437     }
8438     if (C1->isOne() && C2->isNullValue()) {
8439       // select Cond, 1, 0 --> zext (Cond)
8440       if (VT != MVT::i1)
8441         Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Cond);
8442       return Cond;
8443     }
8444     if (C1->isAllOnesValue() && C2->isNullValue()) {
8445       // select Cond, -1, 0 --> sext (Cond)
8446       if (VT != MVT::i1)
8447         Cond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Cond);
8448       return Cond;
8449     }
8450 
8451     // Use a target hook because some targets may prefer to transform in the
8452     // other direction.
8453     if (TLI.convertSelectOfConstantsToMath(VT)) {
8454       // For any constants that differ by 1, we can transform the select into an
8455       // extend and add.
8456       const APInt &C1Val = C1->getAPIntValue();
8457       const APInt &C2Val = C2->getAPIntValue();
8458       if (C1Val - 1 == C2Val) {
8459         // select Cond, C1, C1-1 --> add (zext Cond), C1-1
8460         if (VT != MVT::i1)
8461           Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Cond);
8462         return DAG.getNode(ISD::ADD, DL, VT, Cond, N2);
8463       }
8464       if (C1Val + 1 == C2Val) {
8465         // select Cond, C1, C1+1 --> add (sext Cond), C1+1
8466         if (VT != MVT::i1)
8467           Cond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Cond);
8468         return DAG.getNode(ISD::ADD, DL, VT, Cond, N2);
8469       }
8470 
8471       // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2)
8472       if (C1Val.isPowerOf2() && C2Val.isNullValue()) {
8473         if (VT != MVT::i1)
8474           Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Cond);
8475         SDValue ShAmtC = DAG.getConstant(C1Val.exactLogBase2(), DL, VT);
8476         return DAG.getNode(ISD::SHL, DL, VT, Cond, ShAmtC);
8477       }
8478 
8479       if (SDValue V = foldSelectOfConstantsUsingSra(N, DAG))
8480         return V;
8481     }
8482 
8483     return SDValue();
8484   }
8485 
8486   // fold (select Cond, 0, 1) -> (xor Cond, 1)
8487   // We can't do this reliably if integer based booleans have different contents
8488   // to floating point based booleans. This is because we can't tell whether we
8489   // have an integer-based boolean or a floating-point-based boolean unless we
8490   // can find the SETCC that produced it and inspect its operands. This is
8491   // fairly easy if C is the SETCC node, but it can potentially be
8492   // undiscoverable (or not reasonably discoverable). For example, it could be
8493   // in another basic block or it could require searching a complicated
8494   // expression.
8495   if (CondVT.isInteger() &&
8496       TLI.getBooleanContents(/*isVec*/false, /*isFloat*/true) ==
8497           TargetLowering::ZeroOrOneBooleanContent &&
8498       TLI.getBooleanContents(/*isVec*/false, /*isFloat*/false) ==
8499           TargetLowering::ZeroOrOneBooleanContent &&
8500       C1->isNullValue() && C2->isOne()) {
8501     SDValue NotCond =
8502         DAG.getNode(ISD::XOR, DL, CondVT, Cond, DAG.getConstant(1, DL, CondVT));
8503     if (VT.bitsEq(CondVT))
8504       return NotCond;
8505     return DAG.getZExtOrTrunc(NotCond, DL, VT);
8506   }
8507 
8508   return SDValue();
8509 }
8510 
8511 SDValue DAGCombiner::visitSELECT(SDNode *N) {
8512   SDValue N0 = N->getOperand(0);
8513   SDValue N1 = N->getOperand(1);
8514   SDValue N2 = N->getOperand(2);
8515   EVT VT = N->getValueType(0);
8516   EVT VT0 = N0.getValueType();
8517   SDLoc DL(N);
8518   SDNodeFlags Flags = N->getFlags();
8519 
8520   if (SDValue V = DAG.simplifySelect(N0, N1, N2))
8521     return V;
8522 
8523   // fold (select X, X, Y) -> (or X, Y)
8524   // fold (select X, 1, Y) -> (or C, Y)
8525   if (VT == VT0 && VT == MVT::i1 && (N0 == N1 || isOneConstant(N1)))
8526     return DAG.getNode(ISD::OR, DL, VT, N0, N2);
8527 
8528   if (SDValue V = foldSelectOfConstants(N))
8529     return V;
8530 
8531   // fold (select C, 0, X) -> (and (not C), X)
8532   if (VT == VT0 && VT == MVT::i1 && isNullConstant(N1)) {
8533     SDValue NOTNode = DAG.getNOT(SDLoc(N0), N0, VT);
8534     AddToWorklist(NOTNode.getNode());
8535     return DAG.getNode(ISD::AND, DL, VT, NOTNode, N2);
8536   }
8537   // fold (select C, X, 1) -> (or (not C), X)
8538   if (VT == VT0 && VT == MVT::i1 && isOneConstant(N2)) {
8539     SDValue NOTNode = DAG.getNOT(SDLoc(N0), N0, VT);
8540     AddToWorklist(NOTNode.getNode());
8541     return DAG.getNode(ISD::OR, DL, VT, NOTNode, N1);
8542   }
8543   // fold (select X, Y, X) -> (and X, Y)
8544   // fold (select X, Y, 0) -> (and X, Y)
8545   if (VT == VT0 && VT == MVT::i1 && (N0 == N2 || isNullConstant(N2)))
8546     return DAG.getNode(ISD::AND, DL, VT, N0, N1);
8547 
8548   // If we can fold this based on the true/false value, do so.
8549   if (SimplifySelectOps(N, N1, N2))
8550     return SDValue(N, 0); // Don't revisit N.
8551 
8552   if (VT0 == MVT::i1) {
8553     // The code in this block deals with the following 2 equivalences:
8554     //    select(C0|C1, x, y) <=> select(C0, x, select(C1, x, y))
8555     //    select(C0&C1, x, y) <=> select(C0, select(C1, x, y), y)
8556     // The target can specify its preferred form with the
8557     // shouldNormalizeToSelectSequence() callback. However, we always transform
8558     // to the right side if the inner select already exists in the DAG, and we
8559     // always transform to the left side if we know that we can further
8560     // optimize the combination of the conditions.
8561     bool normalizeToSequence =
8562         TLI.shouldNormalizeToSelectSequence(*DAG.getContext(), VT);
8563     // select (and Cond0, Cond1), X, Y
8564     //   -> select Cond0, (select Cond1, X, Y), Y
8565     if (N0->getOpcode() == ISD::AND && N0->hasOneUse()) {
8566       SDValue Cond0 = N0->getOperand(0);
8567       SDValue Cond1 = N0->getOperand(1);
8568       SDValue InnerSelect =
8569           DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond1, N1, N2, Flags);
8570       if (normalizeToSequence || !InnerSelect.use_empty())
8571         return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0,
8572                            InnerSelect, N2, Flags);
8573       // Cleanup on failure.
8574       if (InnerSelect.use_empty())
8575         recursivelyDeleteUnusedNodes(InnerSelect.getNode());
8576     }
8577     // select (or Cond0, Cond1), X, Y -> select Cond0, X, (select Cond1, X, Y)
8578     if (N0->getOpcode() == ISD::OR && N0->hasOneUse()) {
8579       SDValue Cond0 = N0->getOperand(0);
8580       SDValue Cond1 = N0->getOperand(1);
8581       SDValue InnerSelect = DAG.getNode(ISD::SELECT, DL, N1.getValueType(),
8582                                         Cond1, N1, N2, Flags);
8583       if (normalizeToSequence || !InnerSelect.use_empty())
8584         return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0, N1,
8585                            InnerSelect, Flags);
8586       // Cleanup on failure.
8587       if (InnerSelect.use_empty())
8588         recursivelyDeleteUnusedNodes(InnerSelect.getNode());
8589     }
8590 
8591     // select Cond0, (select Cond1, X, Y), Y -> select (and Cond0, Cond1), X, Y
8592     if (N1->getOpcode() == ISD::SELECT && N1->hasOneUse()) {
8593       SDValue N1_0 = N1->getOperand(0);
8594       SDValue N1_1 = N1->getOperand(1);
8595       SDValue N1_2 = N1->getOperand(2);
8596       if (N1_2 == N2 && N0.getValueType() == N1_0.getValueType()) {
8597         // Create the actual and node if we can generate good code for it.
8598         if (!normalizeToSequence) {
8599           SDValue And = DAG.getNode(ISD::AND, DL, N0.getValueType(), N0, N1_0);
8600           return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), And, N1_1,
8601                              N2, Flags);
8602         }
8603         // Otherwise see if we can optimize the "and" to a better pattern.
8604         if (SDValue Combined = visitANDLike(N0, N1_0, N)) {
8605           return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1_1,
8606                              N2, Flags);
8607         }
8608       }
8609     }
8610     // select Cond0, X, (select Cond1, X, Y) -> select (or Cond0, Cond1), X, Y
8611     if (N2->getOpcode() == ISD::SELECT && N2->hasOneUse()) {
8612       SDValue N2_0 = N2->getOperand(0);
8613       SDValue N2_1 = N2->getOperand(1);
8614       SDValue N2_2 = N2->getOperand(2);
8615       if (N2_1 == N1 && N0.getValueType() == N2_0.getValueType()) {
8616         // Create the actual or node if we can generate good code for it.
8617         if (!normalizeToSequence) {
8618           SDValue Or = DAG.getNode(ISD::OR, DL, N0.getValueType(), N0, N2_0);
8619           return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Or, N1,
8620                              N2_2, Flags);
8621         }
8622         // Otherwise see if we can optimize to a better pattern.
8623         if (SDValue Combined = visitORLike(N0, N2_0, N))
8624           return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1,
8625                              N2_2, Flags);
8626       }
8627     }
8628   }
8629 
8630   // select (not Cond), N1, N2 -> select Cond, N2, N1
8631   if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false)) {
8632     SDValue SelectOp = DAG.getSelect(DL, VT, F, N2, N1);
8633     SelectOp->setFlags(Flags);
8634     return SelectOp;
8635   }
8636 
8637   // Fold selects based on a setcc into other things, such as min/max/abs.
8638   if (N0.getOpcode() == ISD::SETCC) {
8639     SDValue Cond0 = N0.getOperand(0), Cond1 = N0.getOperand(1);
8640     ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
8641 
8642     // select (fcmp lt x, y), x, y -> fminnum x, y
8643     // select (fcmp gt x, y), x, y -> fmaxnum x, y
8644     //
8645     // This is OK if we don't care what happens if either operand is a NaN.
8646     if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, N1, N2, TLI))
8647       if (SDValue FMinMax = combineMinNumMaxNum(DL, VT, Cond0, Cond1, N1, N2,
8648                                                 CC, TLI, DAG))
8649         return FMinMax;
8650 
8651     // Use 'unsigned add with overflow' to optimize an unsigned saturating add.
8652     // This is conservatively limited to pre-legal-operations to give targets
8653     // a chance to reverse the transform if they want to do that. Also, it is
8654     // unlikely that the pattern would be formed late, so it's probably not
8655     // worth going through the other checks.
8656     if (!LegalOperations && TLI.isOperationLegalOrCustom(ISD::UADDO, VT) &&
8657         CC == ISD::SETUGT && N0.hasOneUse() && isAllOnesConstant(N1) &&
8658         N2.getOpcode() == ISD::ADD && Cond0 == N2.getOperand(0)) {
8659       auto *C = dyn_cast<ConstantSDNode>(N2.getOperand(1));
8660       auto *NotC = dyn_cast<ConstantSDNode>(Cond1);
8661       if (C && NotC && C->getAPIntValue() == ~NotC->getAPIntValue()) {
8662         // select (setcc Cond0, ~C, ugt), -1, (add Cond0, C) -->
8663         // uaddo Cond0, C; select uaddo.1, -1, uaddo.0
8664         //
8665         // The IR equivalent of this transform would have this form:
8666         //   %a = add %x, C
8667         //   %c = icmp ugt %x, ~C
8668         //   %r = select %c, -1, %a
8669         //   =>
8670         //   %u = call {iN,i1} llvm.uadd.with.overflow(%x, C)
8671         //   %u0 = extractvalue %u, 0
8672         //   %u1 = extractvalue %u, 1
8673         //   %r = select %u1, -1, %u0
8674         SDVTList VTs = DAG.getVTList(VT, VT0);
8675         SDValue UAO = DAG.getNode(ISD::UADDO, DL, VTs, Cond0, N2.getOperand(1));
8676         return DAG.getSelect(DL, VT, UAO.getValue(1), N1, UAO.getValue(0));
8677       }
8678     }
8679 
8680     if (TLI.isOperationLegal(ISD::SELECT_CC, VT) ||
8681         (!LegalOperations &&
8682          TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT))) {
8683       // Any flags available in a select/setcc fold will be on the setcc as they
8684       // migrated from fcmp.
8685       Flags = N0.getNode()->getFlags();
8686       SDValue SelectNode = DAG.getNode(ISD::SELECT_CC, DL, VT, Cond0, Cond1, N1,
8687                                        N2, N0.getOperand(2));
8688       SelectNode->setFlags(Flags);
8689       return SelectNode;
8690     }
8691 
8692     return SimplifySelect(DL, N0, N1, N2);
8693   }
8694 
8695   return SDValue();
8696 }
8697 
8698 // This function assumes all the vselect's arguments are CONCAT_VECTOR
8699 // nodes and that the condition is a BV of ConstantSDNodes (or undefs).
8700 static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
8701   SDLoc DL(N);
8702   SDValue Cond = N->getOperand(0);
8703   SDValue LHS = N->getOperand(1);
8704   SDValue RHS = N->getOperand(2);
8705   EVT VT = N->getValueType(0);
8706   int NumElems = VT.getVectorNumElements();
8707   assert(LHS.getOpcode() == ISD::CONCAT_VECTORS &&
8708          RHS.getOpcode() == ISD::CONCAT_VECTORS &&
8709          Cond.getOpcode() == ISD::BUILD_VECTOR);
8710 
8711   // CONCAT_VECTOR can take an arbitrary number of arguments. We only care about
8712   // binary ones here.
8713   if (LHS->getNumOperands() != 2 || RHS->getNumOperands() != 2)
8714     return SDValue();
8715 
8716   // We're sure we have an even number of elements due to the
8717   // concat_vectors we have as arguments to vselect.
8718   // Skip BV elements until we find one that's not an UNDEF.
8719   // After we find a non-UNDEF element, keep looping until we get to half the
8720   // length of the BV and see if all the non-undef nodes are the same.
8721   ConstantSDNode *BottomHalf = nullptr;
8722   for (int i = 0; i < NumElems / 2; ++i) {
8723     if (Cond->getOperand(i)->isUndef())
8724       continue;
8725 
8726     if (BottomHalf == nullptr)
8727       BottomHalf = cast<ConstantSDNode>(Cond.getOperand(i));
8728     else if (Cond->getOperand(i).getNode() != BottomHalf)
8729       return SDValue();
8730   }
8731 
8732   // Do the same for the second half of the BuildVector
8733   ConstantSDNode *TopHalf = nullptr;
8734   for (int i = NumElems / 2; i < NumElems; ++i) {
8735     if (Cond->getOperand(i)->isUndef())
8736       continue;
8737 
8738     if (TopHalf == nullptr)
8739       TopHalf = cast<ConstantSDNode>(Cond.getOperand(i));
8740     else if (Cond->getOperand(i).getNode() != TopHalf)
8741       return SDValue();
8742   }
8743 
8744   assert(TopHalf && BottomHalf &&
8745          "One half of the selector was all UNDEFs and the other was all the "
8746          "same value. This should have been addressed before this function.");
8747   return DAG.getNode(
8748       ISD::CONCAT_VECTORS, DL, VT,
8749       BottomHalf->isNullValue() ? RHS->getOperand(0) : LHS->getOperand(0),
8750       TopHalf->isNullValue() ? RHS->getOperand(1) : LHS->getOperand(1));
8751 }
8752 
8753 SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
8754   MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(N);
8755   SDValue Mask = MSC->getMask();
8756   SDValue Chain = MSC->getChain();
8757   SDLoc DL(N);
8758 
8759   // Zap scatters with a zero mask.
8760   if (ISD::isBuildVectorAllZeros(Mask.getNode()))
8761     return Chain;
8762 
8763   return SDValue();
8764 }
8765 
8766 SDValue DAGCombiner::visitMSTORE(SDNode *N) {
8767   MaskedStoreSDNode *MST = cast<MaskedStoreSDNode>(N);
8768   SDValue Mask = MST->getMask();
8769   SDValue Chain = MST->getChain();
8770   SDLoc DL(N);
8771 
8772   // Zap masked stores with a zero mask.
8773   if (ISD::isBuildVectorAllZeros(Mask.getNode()))
8774     return Chain;
8775 
8776   // Try transforming N to an indexed store.
8777   if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
8778     return SDValue(N, 0);
8779 
8780   return SDValue();
8781 }
8782 
8783 SDValue DAGCombiner::visitMGATHER(SDNode *N) {
8784   MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(N);
8785   SDValue Mask = MGT->getMask();
8786   SDLoc DL(N);
8787 
8788   // Zap gathers with a zero mask.
8789   if (ISD::isBuildVectorAllZeros(Mask.getNode()))
8790     return CombineTo(N, MGT->getPassThru(), MGT->getChain());
8791 
8792   return SDValue();
8793 }
8794 
8795 SDValue DAGCombiner::visitMLOAD(SDNode *N) {
8796   MaskedLoadSDNode *MLD = cast<MaskedLoadSDNode>(N);
8797   SDValue Mask = MLD->getMask();
8798   SDLoc DL(N);
8799 
8800   // Zap masked loads with a zero mask.
8801   if (ISD::isBuildVectorAllZeros(Mask.getNode()))
8802     return CombineTo(N, MLD->getPassThru(), MLD->getChain());
8803 
8804   // Try transforming N to an indexed load.
8805   if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
8806     return SDValue(N, 0);
8807 
8808   return SDValue();
8809 }
8810 
8811 /// A vector select of 2 constant vectors can be simplified to math/logic to
8812 /// avoid a variable select instruction and possibly avoid constant loads.
8813 SDValue DAGCombiner::foldVSelectOfConstants(SDNode *N) {
8814   SDValue Cond = N->getOperand(0);
8815   SDValue N1 = N->getOperand(1);
8816   SDValue N2 = N->getOperand(2);
8817   EVT VT = N->getValueType(0);
8818   if (!Cond.hasOneUse() || Cond.getScalarValueSizeInBits() != 1 ||
8819       !TLI.convertSelectOfConstantsToMath(VT) ||
8820       !ISD::isBuildVectorOfConstantSDNodes(N1.getNode()) ||
8821       !ISD::isBuildVectorOfConstantSDNodes(N2.getNode()))
8822     return SDValue();
8823 
8824   // Check if we can use the condition value to increment/decrement a single
8825   // constant value. This simplifies a select to an add and removes a constant
8826   // load/materialization from the general case.
8827   bool AllAddOne = true;
8828   bool AllSubOne = true;
8829   unsigned Elts = VT.getVectorNumElements();
8830   for (unsigned i = 0; i != Elts; ++i) {
8831     SDValue N1Elt = N1.getOperand(i);
8832     SDValue N2Elt = N2.getOperand(i);
8833     if (N1Elt.isUndef() || N2Elt.isUndef())
8834       continue;
8835 
8836     const APInt &C1 = cast<ConstantSDNode>(N1Elt)->getAPIntValue();
8837     const APInt &C2 = cast<ConstantSDNode>(N2Elt)->getAPIntValue();
8838     if (C1 != C2 + 1)
8839       AllAddOne = false;
8840     if (C1 != C2 - 1)
8841       AllSubOne = false;
8842   }
8843 
8844   // Further simplifications for the extra-special cases where the constants are
8845   // all 0 or all -1 should be implemented as folds of these patterns.
8846   SDLoc DL(N);
8847   if (AllAddOne || AllSubOne) {
8848     // vselect <N x i1> Cond, C+1, C --> add (zext Cond), C
8849     // vselect <N x i1> Cond, C-1, C --> add (sext Cond), C
8850     auto ExtendOpcode = AllAddOne ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND;
8851     SDValue ExtendedCond = DAG.getNode(ExtendOpcode, DL, VT, Cond);
8852     return DAG.getNode(ISD::ADD, DL, VT, ExtendedCond, N2);
8853   }
8854 
8855   // select Cond, Pow2C, 0 --> (zext Cond) << log2(Pow2C)
8856   APInt Pow2C;
8857   if (ISD::isConstantSplatVector(N1.getNode(), Pow2C) && Pow2C.isPowerOf2() &&
8858       isNullOrNullSplat(N2)) {
8859     SDValue ZextCond = DAG.getZExtOrTrunc(Cond, DL, VT);
8860     SDValue ShAmtC = DAG.getConstant(Pow2C.exactLogBase2(), DL, VT);
8861     return DAG.getNode(ISD::SHL, DL, VT, ZextCond, ShAmtC);
8862   }
8863 
8864   if (SDValue V = foldSelectOfConstantsUsingSra(N, DAG))
8865     return V;
8866 
8867   // The general case for select-of-constants:
8868   // vselect <N x i1> Cond, C1, C2 --> xor (and (sext Cond), (C1^C2)), C2
8869   // ...but that only makes sense if a vselect is slower than 2 logic ops, so
8870   // leave that to a machine-specific pass.
8871   return SDValue();
8872 }
8873 
8874 SDValue DAGCombiner::visitVSELECT(SDNode *N) {
8875   SDValue N0 = N->getOperand(0);
8876   SDValue N1 = N->getOperand(1);
8877   SDValue N2 = N->getOperand(2);
8878   EVT VT = N->getValueType(0);
8879   SDLoc DL(N);
8880 
8881   if (SDValue V = DAG.simplifySelect(N0, N1, N2))
8882     return V;
8883 
8884   // vselect (not Cond), N1, N2 -> vselect Cond, N2, N1
8885   if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false))
8886     return DAG.getSelect(DL, VT, F, N2, N1);
8887 
8888   // Canonicalize integer abs.
8889   // vselect (setg[te] X,  0),  X, -X ->
8890   // vselect (setgt    X, -1),  X, -X ->
8891   // vselect (setl[te] X,  0), -X,  X ->
8892   // Y = sra (X, size(X)-1); xor (add (X, Y), Y)
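  // Sanity check of the expansion (illustrative, i32): for X = -5,
  // Y = X >>s 31 == -1, X + Y == -6, and xor(-6, -1) == 5; for X >= 0,
  // Y == 0 and the add/xor are no-ops.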
8893   if (N0.getOpcode() == ISD::SETCC) {
8894     SDValue LHS = N0.getOperand(0), RHS = N0.getOperand(1);
8895     ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
8896     bool isAbs = false;
8897     bool RHSIsAllZeros = ISD::isBuildVectorAllZeros(RHS.getNode());
8898 
8899     if (((RHSIsAllZeros && (CC == ISD::SETGT || CC == ISD::SETGE)) ||
8900          (ISD::isBuildVectorAllOnes(RHS.getNode()) && CC == ISD::SETGT)) &&
8901         N1 == LHS && N2.getOpcode() == ISD::SUB && N1 == N2.getOperand(1))
8902       isAbs = ISD::isBuildVectorAllZeros(N2.getOperand(0).getNode());
8903     else if ((RHSIsAllZeros && (CC == ISD::SETLT || CC == ISD::SETLE)) &&
8904              N2 == LHS && N1.getOpcode() == ISD::SUB && N2 == N1.getOperand(1))
8905       isAbs = ISD::isBuildVectorAllZeros(N1.getOperand(0).getNode());
8906 
8907     if (isAbs) {
8908       if (TLI.isOperationLegalOrCustom(ISD::ABS, VT))
8909         return DAG.getNode(ISD::ABS, DL, VT, LHS);
8910 
8911       SDValue Shift = DAG.getNode(ISD::SRA, DL, VT, LHS,
8912                                   DAG.getConstant(VT.getScalarSizeInBits() - 1,
8913                                                   DL, getShiftAmountTy(VT)));
8914       SDValue Add = DAG.getNode(ISD::ADD, DL, VT, LHS, Shift);
8915       AddToWorklist(Shift.getNode());
8916       AddToWorklist(Add.getNode());
8917       return DAG.getNode(ISD::XOR, DL, VT, Add, Shift);
8918     }
8919 
8920     // vselect x, y (fcmp lt x, y) -> fminnum x, y
8921     // vselect x, y (fcmp gt x, y) -> fmaxnum x, y
8922     //
8923     // This is OK if we don't care about what happens if either operand is a
8924     // NaN.
8925     //
8926     if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, LHS, RHS, TLI)) {
8927       if (SDValue FMinMax =
8928               combineMinNumMaxNum(DL, VT, LHS, RHS, N1, N2, CC, TLI, DAG))
8929         return FMinMax;
8930     }
8931 
8932     // If this select has a condition (setcc) with narrower operands than the
8933     // select, try to widen the compare to match the select width.
8934     // TODO: This should be extended to handle any constant.
8935     // TODO: This could be extended to handle non-loading patterns, but that
8936     //       requires thorough testing to avoid regressions.
8937     if (isNullOrNullSplat(RHS)) {
8938       EVT NarrowVT = LHS.getValueType();
8939       EVT WideVT = N1.getValueType().changeVectorElementTypeToInteger();
8940       EVT SetCCVT = getSetCCResultType(LHS.getValueType());
8941       unsigned SetCCWidth = SetCCVT.getScalarSizeInBits();
8942       unsigned WideWidth = WideVT.getScalarSizeInBits();
8943       bool IsSigned = isSignedIntSetCC(CC);
8944       auto LoadExtOpcode = IsSigned ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
8945       if (LHS.getOpcode() == ISD::LOAD && LHS.hasOneUse() &&
8946           SetCCWidth != 1 && SetCCWidth < WideWidth &&
8947           TLI.isLoadExtLegalOrCustom(LoadExtOpcode, WideVT, NarrowVT) &&
8948           TLI.isOperationLegalOrCustom(ISD::SETCC, WideVT)) {
8949         // Both compare operands can be widened for free. The LHS can use an
8950         // extended load, and the RHS is a constant:
8951         //   vselect (ext (setcc load(X), C)), N1, N2 -->
8952         //   vselect (setcc extload(X), C'), N1, N2
8953         auto ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
8954         SDValue WideLHS = DAG.getNode(ExtOpcode, DL, WideVT, LHS);
8955         SDValue WideRHS = DAG.getNode(ExtOpcode, DL, WideVT, RHS);
8956         EVT WideSetCCVT = getSetCCResultType(WideVT);
8957         SDValue WideSetCC = DAG.getSetCC(DL, WideSetCCVT, WideLHS, WideRHS, CC);
8958         return DAG.getSelect(DL, N1.getValueType(), WideSetCC, N1, N2);
8959       }
8960     }
8961   }
8962 
8963   if (SimplifySelectOps(N, N1, N2))
8964     return SDValue(N, 0);  // Don't revisit N.
8965 
8966   // Fold (vselect (build_vector all_ones), N1, N2) -> N1
8967   if (ISD::isBuildVectorAllOnes(N0.getNode()))
8968     return N1;
8969   // Fold (vselect (build_vector all_zeros), N1, N2) -> N2
8970   if (ISD::isBuildVectorAllZeros(N0.getNode()))
8971     return N2;
8972 
8973   // The ConvertSelectToConcatVector function assumes both the above
8974   // checks for (vselect (build_vector all{ones,zeros}) ...) have been made
8975   // and addressed.
8976   if (N1.getOpcode() == ISD::CONCAT_VECTORS &&
8977       N2.getOpcode() == ISD::CONCAT_VECTORS &&
8978       ISD::isBuildVectorOfConstantSDNodes(N0.getNode())) {
8979     if (SDValue CV = ConvertSelectToConcatVector(N, DAG))
8980       return CV;
8981   }
8982 
8983   if (SDValue V = foldVSelectOfConstants(N))
8984     return V;
8985 
8986   return SDValue();
8987 }
8988 
8989 SDValue DAGCombiner::visitSELECT_CC(SDNode *N) {
8990   SDValue N0 = N->getOperand(0);
8991   SDValue N1 = N->getOperand(1);
8992   SDValue N2 = N->getOperand(2);
8993   SDValue N3 = N->getOperand(3);
8994   SDValue N4 = N->getOperand(4);
8995   ISD::CondCode CC = cast<CondCodeSDNode>(N4)->get();
8996 
8997   // fold select_cc lhs, rhs, x, x, cc -> x
8998   if (N2 == N3)
8999     return N2;
9000 
9001   // Determine if the condition we're dealing with is constant
9002   if (SDValue SCC = SimplifySetCC(getSetCCResultType(N0.getValueType()), N0, N1,
9003                                   CC, SDLoc(N), false)) {
9004     AddToWorklist(SCC.getNode());
9005 
9006     if (ConstantSDNode *SCCC = dyn_cast<ConstantSDNode>(SCC.getNode())) {
9007       if (!SCCC->isNullValue())
9008         return N2;    // cond always true -> true val
9009       else
9010         return N3;    // cond always false -> false val
9011     } else if (SCC->isUndef()) {
9012       // When the condition is UNDEF, just return the first operand. This is
9013       // consistent with DAG creation: no setcc node is created in this case.
9014       return N2;
9015     } else if (SCC.getOpcode() == ISD::SETCC) {
9016       // Fold to a simpler select_cc
9017       SDValue SelectOp = DAG.getNode(
9018           ISD::SELECT_CC, SDLoc(N), N2.getValueType(), SCC.getOperand(0),
9019           SCC.getOperand(1), N2, N3, SCC.getOperand(2));
9020       SelectOp->setFlags(SCC->getFlags());
9021       return SelectOp;
9022     }
9023   }
9024 
9025   // If we can fold this based on the true/false value, do so.
9026   if (SimplifySelectOps(N, N2, N3))
9027     return SDValue(N, 0);  // Don't revisit N.
9028 
9029   // fold select_cc into other things, such as min/max/abs
9030   return SimplifySelectCC(SDLoc(N), N0, N1, N2, N3, CC);
9031 }
9032 
9033 SDValue DAGCombiner::visitSETCC(SDNode *N) {
9034   // setcc is very commonly used as an argument to brcond. This pattern
9035   // also lend itself to numerous combines and, as a result, it is desired
9036   // we keep the argument to a brcond as a setcc as much as possible.
9037   bool PreferSetCC =
9038       N->hasOneUse() && N->use_begin()->getOpcode() == ISD::BRCOND;
9039 
9040   SDValue Combined = SimplifySetCC(
9041       N->getValueType(0), N->getOperand(0), N->getOperand(1),
9042       cast<CondCodeSDNode>(N->getOperand(2))->get(), SDLoc(N), !PreferSetCC);
9043 
9044   if (!Combined)
9045     return SDValue();
9046 
9047   // If we prefer to have a setcc, and we don't, we'll try our best to
9048   // recreate one using rebuildSetCC.
9049   if (PreferSetCC && Combined.getOpcode() != ISD::SETCC) {
9050     SDValue NewSetCC = rebuildSetCC(Combined);
9051 
9052     // We don't have anything interesting to combine to.
9053     if (NewSetCC.getNode() == N)
9054       return SDValue();
9055 
9056     if (NewSetCC)
9057       return NewSetCC;
9058   }
9059 
9060   return Combined;
9061 }
9062 
SDValue DAGCombiner::visitSETCCCARRY(SDNode *N) {
  SDValue LHS = N->getOperand(0);
  SDValue RHS = N->getOperand(1);
  SDValue Carry = N->getOperand(2);
  SDValue Cond = N->getOperand(3);

  // If Carry is false, fold to a regular SETCC.
  if (isNullConstant(Carry))
    return DAG.getNode(ISD::SETCC, SDLoc(N), N->getVTList(), LHS, RHS, Cond);

  return SDValue();
}
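// For example (illustrative operands, not from the original source), a
// known-zero carry contributes nothing to the comparison, so
//   t3: i1 = setcccarry t0, t1, Constant:i1<0>, setult
// becomes
//   t3: i1 = setcc t0, t1, setult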

/// Try to fold a sext/zext/aext dag node into a ConstantSDNode or
/// a build_vector of constants.
/// This function is called by the DAGCombiner when visiting sext/zext/aext
/// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
/// Vector extends are not folded if operations are legal; this is to
/// avoid introducing illegal build_vector dag nodes.
static SDValue tryToFoldExtendOfConstant(SDNode *N, const TargetLowering &TLI,
                                         SelectionDAG &DAG, bool LegalTypes) {
  unsigned Opcode = N->getOpcode();
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
         Opcode == ISD::ANY_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG ||
         Opcode == ISD::ZERO_EXTEND_VECTOR_INREG)
         && "Expected EXTEND dag node in input!");

  // fold (sext c1) -> c1
  // fold (zext c1) -> c1
  // fold (aext c1) -> c1
  if (isa<ConstantSDNode>(N0))
    return DAG.getNode(Opcode, DL, VT, N0);

  // fold (sext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
  // fold (zext (select cond, c1, c2)) -> (select cond, zext c1, zext c2)
  // fold (aext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
  if (N0->getOpcode() == ISD::SELECT) {
    SDValue Op1 = N0->getOperand(1);
    SDValue Op2 = N0->getOperand(2);
    if (isa<ConstantSDNode>(Op1) && isa<ConstantSDNode>(Op2) &&
        (Opcode != ISD::ZERO_EXTEND || !TLI.isZExtFree(N0.getValueType(), VT))) {
      // For any_extend, choose sign extension of the constants to allow a
      // possible further transform to sign_extend_inreg, i.e.:
      //
      // t1: i8 = select t0, Constant:i8<-1>, Constant:i8<0>
      // t2: i64 = any_extend t1
      // -->
      // t3: i64 = select t0, Constant:i64<-1>, Constant:i64<0>
      // -->
      // t4: i64 = sign_extend_inreg t3
      unsigned FoldOpc = Opcode;
      if (FoldOpc == ISD::ANY_EXTEND)
        FoldOpc = ISD::SIGN_EXTEND;
      return DAG.getSelect(DL, VT, N0->getOperand(0),
                           DAG.getNode(FoldOpc, DL, VT, Op1),
                           DAG.getNode(FoldOpc, DL, VT, Op2));
    }
  }

  // fold (sext (build_vector AllConstants)) -> (build_vector AllConstants)
  // fold (zext (build_vector AllConstants)) -> (build_vector AllConstants)
  // fold (aext (build_vector AllConstants)) -> (build_vector AllConstants)
  EVT SVT = VT.getScalarType();
  if (!(VT.isVector() && (!LegalTypes || TLI.isTypeLegal(SVT)) &&
      ISD::isBuildVectorOfConstantSDNodes(N0.getNode())))
    return SDValue();

  // We can fold this node into a build_vector.
  unsigned VTBits = SVT.getSizeInBits();
  unsigned EVTBits = N0->getValueType(0).getScalarSizeInBits();
  SmallVector<SDValue, 8> Elts;
  unsigned NumElts = VT.getVectorNumElements();

  // For zero-extensions, UNDEF elements are still guaranteed to have their
  // upper bits set to zero.
  bool IsZext =
      Opcode == ISD::ZERO_EXTEND || Opcode == ISD::ZERO_EXTEND_VECTOR_INREG;

  for (unsigned i = 0; i != NumElts; ++i) {
    SDValue Op = N0.getOperand(i);
    if (Op.isUndef()) {
      Elts.push_back(IsZext ? DAG.getConstant(0, DL, SVT) : DAG.getUNDEF(SVT));
      continue;
    }

    SDLoc DL(Op);
    // Get the constant value and truncate it to the size of the type if
    // needed. Nodes like build_vector might have constants wider than the
    // scalar type.
    APInt C = cast<ConstantSDNode>(Op)->getAPIntValue().zextOrTrunc(EVTBits);
    if (Opcode == ISD::SIGN_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG)
      Elts.push_back(DAG.getConstant(C.sext(VTBits), DL, SVT));
    else
      Elts.push_back(DAG.getConstant(C.zext(VTBits), DL, SVT));
  }

  return DAG.getBuildVector(VT, DL, Elts);
}
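// Illustrative instance of the build_vector fold above (assumed values):
//   t1: v4i16 = BUILD_VECTOR Constant:i16<1>, Constant:i16<-1>, undef, undef
//   t2: v4i32 = sign_extend t1
// becomes
//   t2: v4i32 = BUILD_VECTOR Constant:i32<1>, Constant:i32<-1>, undef, undef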

// ExtendUsesToFormExtLoad - Trying to extend uses of a load to enable this:
// "fold ({s|z|a}ext (load x)) -> ({s|z|a}ext (truncate ({s|z|a}extload x)))"
// transformation. Returns true if extension is possible and the above
// mentioned transformation is profitable.
static bool ExtendUsesToFormExtLoad(EVT VT, SDNode *N, SDValue N0,
                                    unsigned ExtOpc,
                                    SmallVectorImpl<SDNode *> &ExtendNodes,
                                    const TargetLowering &TLI) {
  bool HasCopyToRegUses = false;
  bool isTruncFree = TLI.isTruncateFree(VT, N0.getValueType());
  for (SDNode::use_iterator UI = N0.getNode()->use_begin(),
                            UE = N0.getNode()->use_end();
       UI != UE; ++UI) {
    SDNode *User = *UI;
    if (User == N)
      continue;
    if (UI.getUse().getResNo() != N0.getResNo())
      continue;
    // FIXME: Only extend SETCC N, N and SETCC N, c for now.
    if (ExtOpc != ISD::ANY_EXTEND && User->getOpcode() == ISD::SETCC) {
      ISD::CondCode CC = cast<CondCodeSDNode>(User->getOperand(2))->get();
      if (ExtOpc == ISD::ZERO_EXTEND && ISD::isSignedIntSetCC(CC))
        // Sign bits will be lost after a zext.
        return false;
      bool Add = false;
      for (unsigned i = 0; i != 2; ++i) {
        SDValue UseOp = User->getOperand(i);
        if (UseOp == N0)
          continue;
        if (!isa<ConstantSDNode>(UseOp))
          return false;
        Add = true;
      }
      if (Add)
        ExtendNodes.push_back(User);
      continue;
    }
    // If truncates aren't free and there are users we can't
    // extend, it isn't worthwhile.
    if (!isTruncFree)
      return false;
    // Remember if this value is live-out.
    if (User->getOpcode() == ISD::CopyToReg)
      HasCopyToRegUses = true;
  }

  if (HasCopyToRegUses) {
    bool BothLiveOut = false;
    for (SDNode::use_iterator UI = N->use_begin(), UE = N->use_end();
         UI != UE; ++UI) {
      SDUse &Use = UI.getUse();
      if (Use.getResNo() == 0 && Use.getUser()->getOpcode() == ISD::CopyToReg) {
        BothLiveOut = true;
        break;
      }
    }
    if (BothLiveOut)
      // Both unextended and extended values are live out. There had better be
      // a good reason for the transformation.
      return ExtendNodes.size();
  }
  return true;
}

void DAGCombiner::ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
                                  SDValue OrigLoad, SDValue ExtLoad,
                                  ISD::NodeType ExtType) {
  // Extend SetCC uses if necessary.
  SDLoc DL(ExtLoad);
  for (SDNode *SetCC : SetCCs) {
    SmallVector<SDValue, 4> Ops;

    for (unsigned j = 0; j != 2; ++j) {
      SDValue SOp = SetCC->getOperand(j);
      if (SOp == OrigLoad)
        Ops.push_back(ExtLoad);
      else
        Ops.push_back(DAG.getNode(ExtType, DL, ExtLoad->getValueType(0), SOp));
    }

    Ops.push_back(SetCC->getOperand(2));
    CombineTo(SetCC, DAG.getNode(ISD::SETCC, DL, SetCC->getValueType(0), Ops));
  }
}
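// For example (illustrative), after widening a load that also feeds a
// compare, ExtendSetCCUses rewrites
//   t2: i1 = setcc t1, Constant:i16<7>, seteq    ; t1 is the original load
// into
//   t2: i1 = setcc tExt, Constant:i32<7>, seteq  ; tExt is the new extload
// re-extending the non-load operand with the same extension kind.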

// FIXME: Bring more similar combines here, common to sext/zext (maybe aext?).
SDValue DAGCombiner::CombineExtLoad(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT DstVT = N->getValueType(0);
  EVT SrcVT = N0.getValueType();

  assert((N->getOpcode() == ISD::SIGN_EXTEND ||
          N->getOpcode() == ISD::ZERO_EXTEND) &&
         "Unexpected node type (not an extend)!");

  // fold (sext (load x)) to multiple smaller sextloads; same for zext.
  // For example, on a target with legal v4i32, but illegal v8i32, turn:
  //   (v8i32 (sext (v8i16 (load x))))
  // into:
  //   (v8i32 (concat_vectors (v4i32 (sextload x)),
  //                          (v4i32 (sextload (x + 16)))))
  // Where uses of the original load, i.e.:
  //   (v8i16 (load x))
  // are replaced with:
  //   (v8i16 (truncate
  //     (v8i32 (concat_vectors (v4i32 (sextload x)),
  //                            (v4i32 (sextload (x + 16)))))))
  //
  // This combine is only applicable to illegal, but splittable, vectors.
  // All legal types, and illegal non-vector types, are handled elsewhere.
  // This combine is controlled by TargetLowering::isVectorLoadExtDesirable.
  //
  if (N0->getOpcode() != ISD::LOAD)
    return SDValue();

  LoadSDNode *LN0 = cast<LoadSDNode>(N0);

  if (!ISD::isNON_EXTLoad(LN0) || !ISD::isUNINDEXEDLoad(LN0) ||
      !N0.hasOneUse() || !LN0->isSimple() ||
      !DstVT.isVector() || !DstVT.isPow2VectorType() ||
      !TLI.isVectorLoadExtDesirable(SDValue(N, 0)))
    return SDValue();

  SmallVector<SDNode *, 4> SetCCs;
  if (!ExtendUsesToFormExtLoad(DstVT, N, N0, N->getOpcode(), SetCCs, TLI))
    return SDValue();

  ISD::LoadExtType ExtType =
      N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SEXTLOAD : ISD::ZEXTLOAD;

  // Try to split the vector types to get down to legal types.
  EVT SplitSrcVT = SrcVT;
  EVT SplitDstVT = DstVT;
  while (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT) &&
         SplitSrcVT.getVectorNumElements() > 1) {
    SplitDstVT = DAG.GetSplitDestVTs(SplitDstVT).first;
    SplitSrcVT = DAG.GetSplitDestVTs(SplitSrcVT).first;
  }

  if (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT))
    return SDValue();

  assert(!DstVT.isScalableVector() && "Unexpected scalable vector type");

  SDLoc DL(N);
  const unsigned NumSplits =
      DstVT.getVectorNumElements() / SplitDstVT.getVectorNumElements();
  const unsigned Stride = SplitSrcVT.getStoreSize();
  SmallVector<SDValue, 4> Loads;
  SmallVector<SDValue, 4> Chains;

  SDValue BasePtr = LN0->getBasePtr();
  for (unsigned Idx = 0; Idx < NumSplits; Idx++) {
    const unsigned Offset = Idx * Stride;
    const unsigned Align = MinAlign(LN0->getAlignment(), Offset);

    SDValue SplitLoad = DAG.getExtLoad(
        ExtType, SDLoc(LN0), SplitDstVT, LN0->getChain(), BasePtr,
        LN0->getPointerInfo().getWithOffset(Offset), SplitSrcVT, Align,
        LN0->getMemOperand()->getFlags(), LN0->getAAInfo());

    BasePtr = DAG.getMemBasePlusOffset(BasePtr, Stride, DL);

    Loads.push_back(SplitLoad.getValue(0));
    Chains.push_back(SplitLoad.getValue(1));
  }

  SDValue NewChain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Chains);
  SDValue NewValue = DAG.getNode(ISD::CONCAT_VECTORS, DL, DstVT, Loads);

  // Simplify TF.
  AddToWorklist(NewChain.getNode());

  CombineTo(N, NewValue);

  // Replace uses of the original load (before extension)
  // with a truncate of the concatenated sextloaded vectors.
  SDValue Trunc =
      DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), NewValue);
  ExtendSetCCUses(SetCCs, N0, NewValue, (ISD::NodeType)N->getOpcode());
  CombineTo(N0.getNode(), Trunc, NewChain);
  return SDValue(N, 0); // Return N so it doesn't get rechecked!
}

// fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
//      (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
SDValue DAGCombiner::CombineZExtLogicopShiftLoad(SDNode *N) {
  assert(N->getOpcode() == ISD::ZERO_EXTEND);
  EVT VT = N->getValueType(0);
  EVT OrigVT = N->getOperand(0).getValueType();
  if (TLI.isZExtFree(OrigVT, VT))
    return SDValue();

  // and/or/xor
  SDValue N0 = N->getOperand(0);
  if (!(N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR ||
        N0.getOpcode() == ISD::XOR) ||
      N0.getOperand(1).getOpcode() != ISD::Constant ||
      (LegalOperations && !TLI.isOperationLegal(N0.getOpcode(), VT)))
    return SDValue();

  // shl/shr
  SDValue N1 = N0->getOperand(0);
  if (!(N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) ||
      N1.getOperand(1).getOpcode() != ISD::Constant ||
      (LegalOperations && !TLI.isOperationLegal(N1.getOpcode(), VT)))
    return SDValue();

  // load
  if (!isa<LoadSDNode>(N1.getOperand(0)))
    return SDValue();
  LoadSDNode *Load = cast<LoadSDNode>(N1.getOperand(0));
  EVT MemVT = Load->getMemoryVT();
  if (!TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) ||
      Load->getExtensionType() == ISD::SEXTLOAD || Load->isIndexed())
    return SDValue();

  // If the shift op is SHL, the logic op must be AND, otherwise the result
  // will be wrong.
  if (N1.getOpcode() == ISD::SHL && N0.getOpcode() != ISD::AND)
    return SDValue();

  if (!N0.hasOneUse() || !N1.hasOneUse())
    return SDValue();

  SmallVector<SDNode*, 4> SetCCs;
  if (!ExtendUsesToFormExtLoad(VT, N1.getNode(), N1.getOperand(0),
                               ISD::ZERO_EXTEND, SetCCs, TLI))
    return SDValue();

  // Actually do the transformation.
  SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(Load), VT,
                                   Load->getChain(), Load->getBasePtr(),
                                   Load->getMemoryVT(), Load->getMemOperand());

  SDLoc DL1(N1);
  SDValue Shift = DAG.getNode(N1.getOpcode(), DL1, VT, ExtLoad,
                              N1.getOperand(1));

  APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue();
  Mask = Mask.zext(VT.getSizeInBits());
  SDLoc DL0(N0);
  SDValue And = DAG.getNode(N0.getOpcode(), DL0, VT, Shift,
                            DAG.getConstant(Mask, DL0, VT));

  ExtendSetCCUses(SetCCs, N1.getOperand(0), ExtLoad, ISD::ZERO_EXTEND);
  CombineTo(N, And);
  if (SDValue(Load, 0).hasOneUse()) {
    DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), ExtLoad.getValue(1));
  } else {
    SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(Load),
                                Load->getValueType(0), ExtLoad);
    CombineTo(Load, Trunc, ExtLoad.getValue(1));
  }

  // N0 is dead at this point.
  recursivelyDeleteUnusedNodes(N0.getNode());

  return SDValue(N, 0); // Return N so it doesn't get rechecked!
}
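// A concrete (illustrative) instance of the fold above, with assumed types:
//   t1: i32 = load t0
//   t2: i32 = srl t1, Constant:i8<8>
//   t3: i32 = and t2, Constant:i32<255>
//   t4: i64 = zero_extend t3
// becomes
//   t1: i64 = zextload t0
//   t2: i64 = srl t1, Constant:i8<8>
//   t4: i64 = and t2, Constant:i64<255>
// which is valid because the zextload's known-zero high bits commute with
// the srl and the widened mask.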

/// If we're narrowing or widening the result of a vector select and the final
/// size is the same size as a setcc (compare) feeding the select, then try to
/// apply the cast operation to the select's operands because matching vector
/// sizes for a select condition and other operands should be more efficient.
SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
  unsigned CastOpcode = Cast->getOpcode();
  assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
          CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
          CastOpcode == ISD::FP_ROUND) &&
         "Unexpected opcode for vector select narrowing/widening");

  // We only do this transform before legal ops because the pattern may be
  // obfuscated by target-specific operations after legalization. Do not create
  // an illegal select op, however, because that may be difficult to lower.
  EVT VT = Cast->getValueType(0);
  if (LegalOperations || !TLI.isOperationLegalOrCustom(ISD::VSELECT, VT))
    return SDValue();

  SDValue VSel = Cast->getOperand(0);
  if (VSel.getOpcode() != ISD::VSELECT || !VSel.hasOneUse() ||
      VSel.getOperand(0).getOpcode() != ISD::SETCC)
    return SDValue();

  // Does the setcc have the same vector size as the casted select?
  SDValue SetCC = VSel.getOperand(0);
  EVT SetCCVT = getSetCCResultType(SetCC.getOperand(0).getValueType());
  if (SetCCVT.getSizeInBits() != VT.getSizeInBits())
    return SDValue();

  // cast (vsel (setcc X), A, B) --> vsel (setcc X), (cast A), (cast B)
  SDValue A = VSel.getOperand(1);
  SDValue B = VSel.getOperand(2);
  SDValue CastA, CastB;
  SDLoc DL(Cast);
  if (CastOpcode == ISD::FP_ROUND) {
    // FP_ROUND (fptrunc) has an extra flag operand to pass along.
    CastA = DAG.getNode(CastOpcode, DL, VT, A, Cast->getOperand(1));
    CastB = DAG.getNode(CastOpcode, DL, VT, B, Cast->getOperand(1));
  } else {
    CastA = DAG.getNode(CastOpcode, DL, VT, A);
    CastB = DAG.getNode(CastOpcode, DL, VT, B);
  }
  return DAG.getNode(ISD::VSELECT, DL, VT, SetCC, CastA, CastB);
}
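// Illustrative instance of the pattern above, assuming a target where the
// setcc result for a v4i32 compare is itself v4i32-sized:
//   tc: setcc t0, tC, setlt      ; compare of v4i32 operands
//   t2: v4i16 = vselect tc, tA, tB
//   t3: v4i32 = zero_extend t2
// becomes
//   t3: v4i32 = vselect tc, (zero_extend tA), (zero_extend tB)
// so the select condition and the selected values end up the same width.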

// fold ([s|z]ext ([s|z]extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
// fold ([s|z]ext (     extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
static SDValue tryToFoldExtOfExtload(SelectionDAG &DAG, DAGCombiner &Combiner,
                                     const TargetLowering &TLI, EVT VT,
                                     bool LegalOperations, SDNode *N,
                                     SDValue N0, ISD::LoadExtType ExtLoadType) {
  SDNode *N0Node = N0.getNode();
  bool isAExtLoad = (ExtLoadType == ISD::SEXTLOAD) ? ISD::isSEXTLoad(N0Node)
                                                   : ISD::isZEXTLoad(N0Node);
  if ((!isAExtLoad && !ISD::isEXTLoad(N0Node)) ||
      !ISD::isUNINDEXEDLoad(N0Node) || !N0.hasOneUse())
    return SDValue();

  LoadSDNode *LN0 = cast<LoadSDNode>(N0);
  EVT MemVT = LN0->getMemoryVT();
  if ((LegalOperations || !LN0->isSimple() ||
       VT.isVector()) &&
      !TLI.isLoadExtLegal(ExtLoadType, VT, MemVT))
    return SDValue();

  SDValue ExtLoad =
      DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(),
                     LN0->getBasePtr(), MemVT, LN0->getMemOperand());
  Combiner.CombineTo(N, ExtLoad);
  DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
  if (LN0->use_empty())
    Combiner.recursivelyDeleteUnusedNodes(LN0);
  return SDValue(N, 0); // Return N so it doesn't get rechecked!
}

// fold ([s|z]ext (load x)) -> ([s|z]ext (truncate ([s|z]extload x)))
// Only generate vector extloads when 1) they're legal, and 2) they are
// deemed desirable by the target.
static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner,
                                  const TargetLowering &TLI, EVT VT,
                                  bool LegalOperations, SDNode *N, SDValue N0,
                                  ISD::LoadExtType ExtLoadType,
                                  ISD::NodeType ExtOpc) {
  if (!ISD::isNON_EXTLoad(N0.getNode()) ||
      !ISD::isUNINDEXEDLoad(N0.getNode()) ||
      ((LegalOperations || VT.isVector() ||
        !cast<LoadSDNode>(N0)->isSimple()) &&
       !TLI.isLoadExtLegal(ExtLoadType, VT, N0.getValueType())))
    return {};

  bool DoXform = true;
  SmallVector<SDNode *, 4> SetCCs;
  if (!N0.hasOneUse())
    DoXform = ExtendUsesToFormExtLoad(VT, N, N0, ExtOpc, SetCCs, TLI);
  if (VT.isVector())
    DoXform &= TLI.isVectorLoadExtDesirable(SDValue(N, 0));
  if (!DoXform)
    return {};

  LoadSDNode *LN0 = cast<LoadSDNode>(N0);
  SDValue ExtLoad = DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(),
                                   LN0->getBasePtr(), N0.getValueType(),
                                   LN0->getMemOperand());
  Combiner.ExtendSetCCUses(SetCCs, N0, ExtLoad, ExtOpc);
  // If the load value is used only by N, replace it via CombineTo N.
  bool NoReplaceTrunc = SDValue(LN0, 0).hasOneUse();
  Combiner.CombineTo(N, ExtLoad);
  if (NoReplaceTrunc) {
    DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
    Combiner.recursivelyDeleteUnusedNodes(LN0);
  } else {
    SDValue Trunc =
        DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), ExtLoad);
    Combiner.CombineTo(LN0, Trunc, ExtLoad.getValue(1));
  }
  return SDValue(N, 0); // Return N so it doesn't get rechecked!
}
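// For example (illustrative), on a target with a legal i16->i32 sextload:
//   t1: i16,ch = load t0
//   t2: i32 = sign_extend t1
// becomes
//   t2: i32,ch = sextload t0
// and any other users of t1 are rewritten to use (truncate t2) instead.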

static SDValue tryToFoldExtOfMaskedLoad(SelectionDAG &DAG,
                                        const TargetLowering &TLI, EVT VT,
                                        SDNode *N, SDValue N0,
                                        ISD::LoadExtType ExtLoadType,
                                        ISD::NodeType ExtOpc) {
  if (!N0.hasOneUse())
    return SDValue();

  MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(N0);
  if (!Ld || Ld->getExtensionType() != ISD::NON_EXTLOAD)
    return SDValue();

  if (!TLI.isLoadExtLegal(ExtLoadType, VT, Ld->getValueType(0)))
    return SDValue();

  if (!TLI.isVectorLoadExtDesirable(SDValue(N, 0)))
    return SDValue();

  SDLoc dl(Ld);
  SDValue PassThru = DAG.getNode(ExtOpc, dl, VT, Ld->getPassThru());
  SDValue NewLoad = DAG.getMaskedLoad(
      VT, dl, Ld->getChain(), Ld->getBasePtr(), Ld->getOffset(), Ld->getMask(),
      PassThru, Ld->getMemoryVT(), Ld->getMemOperand(), Ld->getAddressingMode(),
      ExtLoadType, Ld->isExpandingLoad());
  DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), SDValue(NewLoad.getNode(), 1));
  return NewLoad;
}

static SDValue foldExtendedSignBitTest(SDNode *N, SelectionDAG &DAG,
                                       bool LegalOperations) {
  assert((N->getOpcode() == ISD::SIGN_EXTEND ||
          N->getOpcode() == ISD::ZERO_EXTEND) && "Expected sext or zext");

  SDValue SetCC = N->getOperand(0);
  if (LegalOperations || SetCC.getOpcode() != ISD::SETCC ||
      !SetCC.hasOneUse() || SetCC.getValueType() != MVT::i1)
    return SDValue();

  SDValue X = SetCC.getOperand(0);
  SDValue Ones = SetCC.getOperand(1);
  ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get();
  EVT VT = N->getValueType(0);
  EVT XVT = X.getValueType();
  // setge X, C is canonicalized to setgt, so we do not need to match that
  // pattern. The setlt sibling is folded in SimplifySelectCC() because it does
  // not require the 'not' op.
  if (CC == ISD::SETGT && isAllOnesConstant(Ones) && VT == XVT) {
    // Invert and smear/shift the sign bit:
    // sext i1 (setgt iN X, -1) --> sra (not X), (N - 1)
    // zext i1 (setgt iN X, -1) --> srl (not X), (N - 1)
    SDLoc DL(N);
    unsigned ShCt = VT.getSizeInBits() - 1;
    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
    if (!TLI.shouldAvoidTransformToShift(VT, ShCt)) {
      SDValue NotX = DAG.getNOT(DL, X, VT);
      SDValue ShiftAmount = DAG.getConstant(ShCt, DL, VT);
      auto ShiftOpcode =
        N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SRA : ISD::SRL;
      return DAG.getNode(ShiftOpcode, DL, VT, NotX, ShiftAmount);
    }
  }
  return SDValue();
}
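// Worked (illustrative) example of the sext form above with i32 X:
//   t1: i1 = setcc X, Constant:i32<-1>, setgt   ; true iff X >= 0
//   t2: i32 = sign_extend t1
// becomes
//   t3: i32 = xor X, Constant:i32<-1>           ; not X
//   t2: i32 = sra t3, Constant:i32<31>          ; smear inverted sign bit
// If X >= 0, the inverted sign bit is 1 and the sra yields -1; otherwise 0.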

SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
    return Res;

  // fold (sext (sext x)) -> (sext x)
  // fold (sext (aext x)) -> (sext x)
  if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND)
    return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N0.getOperand(0));

  if (N0.getOpcode() == ISD::TRUNCATE) {
    // fold (sext (truncate (load x))) -> (sext (smaller load x))
    // fold (sext (truncate (srl (load x), c))) -> (sext (smaller load (x+c/n)))
    if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
      SDNode *oye = N0.getOperand(0).getNode();
      if (NarrowLoad.getNode() != N0.getNode()) {
        CombineTo(N0.getNode(), NarrowLoad);
        // CombineTo deleted the truncate, if needed, but not what's under it.
        AddToWorklist(oye);
      }
      return SDValue(N, 0);   // Return N so it doesn't get rechecked!
    }

    // See if the value being truncated is already sign extended.  If so, just
    // eliminate the trunc/sext pair.
    SDValue Op = N0.getOperand(0);
    unsigned OpBits   = Op.getScalarValueSizeInBits();
    unsigned MidBits  = N0.getScalarValueSizeInBits();
    unsigned DestBits = VT.getScalarSizeInBits();
    unsigned NumSignBits = DAG.ComputeNumSignBits(Op);

    if (OpBits == DestBits) {
      // Op is i32, Mid is i8, and Dest is i32.  If Op has more than 24 sign
      // bits, it is already sign extended enough to use directly.
      if (NumSignBits > DestBits-MidBits)
        return Op;
    } else if (OpBits < DestBits) {
      // Op is i32, Mid is i8, and Dest is i64.  If Op has more than 24 sign
      // bits, just sext from i32.
      if (NumSignBits > OpBits-MidBits)
        return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op);
    } else {
      // Op is i64, Mid is i8, and Dest is i32.  If Op has more than 56 sign
      // bits, just truncate to i32.
      if (NumSignBits > OpBits-MidBits)
        return DAG.getNode(ISD::TRUNCATE, DL, VT, Op);
    }

    // fold (sext (truncate x)) -> (sextinreg x).
    if (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND_INREG,
                                                 N0.getValueType())) {
      if (OpBits < DestBits)
        Op = DAG.getNode(ISD::ANY_EXTEND, SDLoc(N0), VT, Op);
      else if (OpBits > DestBits)
        Op = DAG.getNode(ISD::TRUNCATE, SDLoc(N0), VT, Op);
      return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, Op,
                         DAG.getValueType(N0.getValueType()));
    }
  }

  // Try to simplify (sext (load x)).
  if (SDValue foldedExt =
          tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
                             ISD::SEXTLOAD, ISD::SIGN_EXTEND))
    return foldedExt;

  if (SDValue foldedExt =
      tryToFoldExtOfMaskedLoad(DAG, TLI, VT, N, N0, ISD::SEXTLOAD,
                               ISD::SIGN_EXTEND))
    return foldedExt;

  // fold (sext (load x)) to multiple smaller sextloads.
  // Only on illegal but splittable vectors.
  if (SDValue ExtLoad = CombineExtLoad(N))
    return ExtLoad;

  // Try to simplify (sext (sextload x)).
  if (SDValue foldedExt = tryToFoldExtOfExtload(
          DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::SEXTLOAD))
    return foldedExt;

  // fold (sext (and/or/xor (load x), cst)) ->
  //      (and/or/xor (sextload x), (sext cst))
  if ((N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR ||
       N0.getOpcode() == ISD::XOR) &&
      isa<LoadSDNode>(N0.getOperand(0)) &&
      N0.getOperand(1).getOpcode() == ISD::Constant &&
      (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
    LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
    EVT MemVT = LN00->getMemoryVT();
    if (TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, MemVT) &&
        LN00->getExtensionType() != ISD::ZEXTLOAD && LN00->isUnindexed()) {
      SmallVector<SDNode*, 4> SetCCs;
      bool DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
                                             ISD::SIGN_EXTEND, SetCCs, TLI);
      if (DoXform) {
        SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(LN00), VT,
                                         LN00->getChain(), LN00->getBasePtr(),
                                         LN00->getMemoryVT(),
                                         LN00->getMemOperand());
        APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue();
        Mask = Mask.sext(VT.getSizeInBits());
        SDValue And = DAG.getNode(N0.getOpcode(), DL, VT,
                                  ExtLoad, DAG.getConstant(Mask, DL, VT));
        ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::SIGN_EXTEND);
        bool NoReplaceTruncAnd = !N0.hasOneUse();
        bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
        CombineTo(N, And);
        // If N0 has multiple uses, change other uses as well.
        if (NoReplaceTruncAnd) {
          SDValue TruncAnd =
              DAG.getNode(ISD::TRUNCATE, DL, N0.getValueType(), And);
          CombineTo(N0.getNode(), TruncAnd);
        }
        if (NoReplaceTrunc) {
          DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1));
        } else {
          SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00),
                                      LN00->getValueType(0), ExtLoad);
          CombineTo(LN00, Trunc, ExtLoad.getValue(1));
        }
        return SDValue(N,0); // Return N so it doesn't get rechecked!
      }
    }
  }

  if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
    return V;

  if (N0.getOpcode() == ISD::SETCC) {
    SDValue N00 = N0.getOperand(0);
    SDValue N01 = N0.getOperand(1);
    ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
    EVT N00VT = N0.getOperand(0).getValueType();

    // sext(setcc) -> sext_in_reg(vsetcc) for vectors.
    // Only do this before legalize for now.
    if (VT.isVector() && !LegalOperations &&
        TLI.getBooleanContents(N00VT) ==
            TargetLowering::ZeroOrNegativeOneBooleanContent) {
      // On some architectures (such as SSE/NEON/etc) the SETCC result type is
      // of the same size as the compared operands. Only optimize sext(setcc())
      // if this is the case.
      EVT SVT = getSetCCResultType(N00VT);

      // If we already have the desired type, don't change it.
      if (SVT != N0.getValueType()) {
        // We know that the # elements of the results is the same as the
        // # elements of the compare (and the # elements of the compare result
        // for that matter).  Check to see that they are the same size.  If so,
        // we know that the element size of the sext'd result matches the
        // element size of the compare operands.
        if (VT.getSizeInBits() == SVT.getSizeInBits())
          return DAG.getSetCC(DL, VT, N00, N01, CC);

        // If the desired elements are smaller or larger than the source
        // elements, we can use a matching integer vector type and then
        // truncate/sign extend.
        EVT MatchingVecType = N00VT.changeVectorElementTypeToInteger();
        if (SVT == MatchingVecType) {
          SDValue VsetCC = DAG.getSetCC(DL, MatchingVecType, N00, N01, CC);
          return DAG.getSExtOrTrunc(VsetCC, DL, VT);
        }
      }
    }

    // sext(setcc x, y, cc) -> (select (setcc x, y, cc), T, 0)
    // Here, T can be 1 or -1, depending on the type of the setcc and
    // getBooleanContents().
    unsigned SetCCWidth = N0.getScalarValueSizeInBits();

    // To determine the "true" side of the select, we need to know the high bit
    // of the value returned by the setcc if it evaluates to true.
    // If the type of the setcc is i1, then the true case of the select is just
    // sext(i1 1), that is, -1.
    // If the type of the setcc is larger (say, i8) then the value of the high
    // bit depends on getBooleanContents(), so ask TLI for a real "true" value
    // of the appropriate width.
    SDValue ExtTrueVal = (SetCCWidth == 1)
                             ? DAG.getAllOnesConstant(DL, VT)
                             : DAG.getBoolConstant(true, DL, VT, N00VT);
    SDValue Zero = DAG.getConstant(0, DL, VT);
    if (SDValue SCC =
            SimplifySelectCC(DL, N00, N01, ExtTrueVal, Zero, CC, true))
      return SCC;

    if (!VT.isVector() && !TLI.convertSelectOfConstantsToMath(VT)) {
      EVT SetCCVT = getSetCCResultType(N00VT);
      // Don't do this transform for i1 because there's a select transform
      // that would reverse it.
      // TODO: We should not do this transform at all without a target hook
      // because a sext is likely cheaper than a select?
      if (SetCCVT.getScalarSizeInBits() != 1 &&
          (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, N00VT))) {
        SDValue SetCC = DAG.getSetCC(DL, SetCCVT, N00, N01, CC);
        return DAG.getSelect(DL, VT, SetCC, ExtTrueVal, Zero);
      }
    }
  }

  // fold (sext x) -> (zext x) if the sign bit is known zero.
  if ((!LegalOperations || TLI.isOperationLegal(ISD::ZERO_EXTEND, VT)) &&
      DAG.SignBitIsZero(N0))
    return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0);

  if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
    return NewVSel;

  // Eliminate this sign extend by doing a negation in the destination type:
  // sext i32 (0 - (zext i8 X to i32)) to i64 --> 0 - (zext i8 X to i64)
  if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
      isNullOrNullSplat(N0.getOperand(0)) &&
      N0.getOperand(1).getOpcode() == ISD::ZERO_EXTEND &&
      TLI.isOperationLegalOrCustom(ISD::SUB, VT)) {
    SDValue Zext = DAG.getZExtOrTrunc(N0.getOperand(1).getOperand(0), DL, VT);
    return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), Zext);
  }
  // Eliminate this sign extend by doing a decrement in the destination type:
  // sext i32 ((zext i8 X to i32) + (-1)) to i64 --> (zext i8 X to i64) + (-1)
  if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
      isAllOnesOrAllOnesSplat(N0.getOperand(1)) &&
      N0.getOperand(0).getOpcode() == ISD::ZERO_EXTEND &&
      TLI.isOperationLegalOrCustom(ISD::ADD, VT)) {
    SDValue Zext = DAG.getZExtOrTrunc(N0.getOperand(0).getOperand(0), DL, VT);
    return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT));
  }

  return SDValue();
}

// isTruncateOf - If N is a truncate of some other value, return true, record
// the value being truncated in Op and which of Op's bits are zero/one in
// Known. This function computes KnownBits to avoid a duplicated call to
// computeKnownBits in the caller.
static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op,
                         KnownBits &Known) {
  if (N->getOpcode() == ISD::TRUNCATE) {
    Op = N->getOperand(0);
    Known = DAG.computeKnownBits(Op);
    return true;
  }

  if (N.getOpcode() != ISD::SETCC ||
      N.getValueType().getScalarType() != MVT::i1 ||
      cast<CondCodeSDNode>(N.getOperand(2))->get() != ISD::SETNE)
    return false;

  SDValue Op0 = N->getOperand(0);
  SDValue Op1 = N->getOperand(1);
  assert(Op0.getValueType() == Op1.getValueType());

  if (isNullOrNullSplat(Op0))
    Op = Op1;
  else if (isNullOrNullSplat(Op1))
    Op = Op0;
  else
    return false;

  Known = DAG.computeKnownBits(Op);

  return (Known.Zero | 1).isAllOnesValue();
}

/// Given an extending node with a pop-count operand, if the target does not
/// support a pop-count in the narrow source type but does support it in the
/// destination type, widen the pop-count to the destination type.
static SDValue widenCtPop(SDNode *Extend, SelectionDAG &DAG) {
  assert((Extend->getOpcode() == ISD::ZERO_EXTEND ||
          Extend->getOpcode() == ISD::ANY_EXTEND) && "Expected extend op");

  SDValue CtPop = Extend->getOperand(0);
  if (CtPop.getOpcode() != ISD::CTPOP || !CtPop.hasOneUse())
    return SDValue();

  EVT VT = Extend->getValueType(0);
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  if (TLI.isOperationLegalOrCustom(ISD::CTPOP, CtPop.getValueType()) ||
      !TLI.isOperationLegalOrCustom(ISD::CTPOP, VT))
    return SDValue();

  // zext (ctpop X) --> ctpop (zext X)
  SDLoc DL(Extend);
  SDValue NewZext = DAG.getZExtOrTrunc(CtPop.getOperand(0), DL, VT);
  return DAG.getNode(ISD::CTPOP, DL, VT, NewZext);
}
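// Illustrative example, assuming i16 CTPOP is not legal but i32 CTPOP is:
//   t1: i16 = ctpop t0
//   t2: i32 = zero_extend t1
// becomes
//   t3: i32 = zero_extend t0
//   t2: i32 = ctpop t3
// This is sound because zero-extending the input adds only zero bits, which
// do not change the population count.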

SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
    return Res;

  // fold (zext (zext x)) -> (zext x)
  // fold (zext (aext x)) -> (zext x)
  if (N0.getOpcode() == ISD::ZERO_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND)
    return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT,
                       N0.getOperand(0));

  // fold (zext (truncate x)) -> (zext x) or
  //      (zext (truncate x)) -> (truncate x)
  // This is valid when the truncated bits of x are already zero.
  SDValue Op;
  KnownBits Known;
  if (isTruncateOf(DAG, N0, Op, Known)) {
    APInt TruncatedBits =
      (Op.getScalarValueSizeInBits() == N0.getScalarValueSizeInBits()) ?
      APInt(Op.getScalarValueSizeInBits(), 0) :
      APInt::getBitsSet(Op.getScalarValueSizeInBits(),
                        N0.getScalarValueSizeInBits(),
                        std::min(Op.getScalarValueSizeInBits(),
                                 VT.getScalarSizeInBits()));
    if (TruncatedBits.isSubsetOf(Known.Zero))
      return DAG.getZExtOrTrunc(Op, SDLoc(N), VT);
  }

  // fold (zext (truncate x)) -> (and x, mask)
  if (N0.getOpcode() == ISD::TRUNCATE) {
    // fold (zext (truncate (load x))) -> (zext (smaller load x))
    // fold (zext (truncate (srl (load x), c))) -> (zext (smaller load (x+c/n)))
    if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
      SDNode *oye = N0.getOperand(0).getNode();
      if (NarrowLoad.getNode() != N0.getNode()) {
        CombineTo(N0.getNode(), NarrowLoad);
        // CombineTo deleted the truncate, if needed, but not what's under it.
        AddToWorklist(oye);
      }
      return SDValue(N, 0); // Return N so it doesn't get rechecked!
    }

    EVT SrcVT = N0.getOperand(0).getValueType();
    EVT MinVT = N0.getValueType();

    // Try to mask before the extension to avoid having to generate a larger
    // mask, possibly over several sub-vectors.
    if (SrcVT.bitsLT(VT) && VT.isVector()) {
      if (!LegalOperations || (TLI.isOperationLegal(ISD::AND, SrcVT) &&
                               TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) {
        SDValue Op = N0.getOperand(0);
        Op = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT.getScalarType());
        AddToWorklist(Op.getNode());
        SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, SDLoc(N), VT);
        // Transfer the debug info; the new node is equivalent to N0.
        DAG.transferDbgValues(N0, ZExtOrTrunc);
        return ZExtOrTrunc;
      }
    }

    if (!LegalOperations || TLI.isOperationLegal(ISD::AND, VT)) {
      SDValue Op = DAG.getAnyExtOrTrunc(N0.getOperand(0), SDLoc(N), VT);
      AddToWorklist(Op.getNode());
      SDValue And = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT.getScalarType());
      // We may safely transfer the debug info describing the truncate node over
      // to the equivalent and operation.
      DAG.transferDbgValues(N0, And);
      return And;
    }
  }
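  // For example (illustrative types), the and-mask fold above turns
  //   t1: i8 = truncate t0        ; t0: i32
  //   t2: i32 = zero_extend t1
  // into
  //   t2: i32 = and t0, Constant:i32<255>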

  // Fold (zext (and (trunc x), cst)) -> (and x, cst),
  // if either of the casts is not free.
  if (N0.getOpcode() == ISD::AND &&
      N0.getOperand(0).getOpcode() == ISD::TRUNCATE &&
      N0.getOperand(1).getOpcode() == ISD::Constant &&
      (!TLI.isTruncateFree(N0.getOperand(0).getOperand(0).getValueType(),
                           N0.getValueType()) ||
       !TLI.isZExtFree(N0.getValueType(), VT))) {
    SDValue X = N0.getOperand(0).getOperand(0);
    X = DAG.getAnyExtOrTrunc(X, SDLoc(X), VT);
    APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue();
    Mask = Mask.zext(VT.getSizeInBits());
    SDLoc DL(N);
    return DAG.getNode(ISD::AND, DL, VT,
                       X, DAG.getConstant(Mask, DL, VT));
  }

  // Try to simplify (zext (load x)).
  if (SDValue foldedExt =
          tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
                             ISD::ZEXTLOAD, ISD::ZERO_EXTEND))
    return foldedExt;

  if (SDValue foldedExt =
      tryToFoldExtOfMaskedLoad(DAG, TLI, VT, N, N0, ISD::ZEXTLOAD,
                               ISD::ZERO_EXTEND))
    return foldedExt;

  // fold (zext (load x)) to multiple smaller zextloads.
  // Only on illegal but splittable vectors.
  if (SDValue ExtLoad = CombineExtLoad(N))
    return ExtLoad;

  // fold (zext (and/or/xor (load x), cst)) ->
  //      (and/or/xor (zextload x), (zext cst))
  // Unless (and (load x) cst) will match as a zextload already and has
  // additional users.
  if ((N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR ||
       N0.getOpcode() == ISD::XOR) &&
      isa<LoadSDNode>(N0.getOperand(0)) &&
      N0.getOperand(1).getOpcode() == ISD::Constant &&
      (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
    LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
    EVT MemVT = LN00->getMemoryVT();
    if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) &&
        LN00->getExtensionType() != ISD::SEXTLOAD && LN00->isUnindexed()) {
      bool DoXform = true;
      SmallVector<SDNode*, 4> SetCCs;
      if (!N0.hasOneUse()) {
        if (N0.getOpcode() == ISD::AND) {
          auto *AndC = cast<ConstantSDNode>(N0.getOperand(1));
          EVT LoadResultTy = AndC->getValueType(0);
          EVT ExtVT;
          if (isAndLoadExtLoad(AndC, LN00, LoadResultTy, ExtVT))
            DoXform = false;
        }
      }
      if (DoXform)
        DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
                                          ISD::ZERO_EXTEND, SetCCs, TLI);
      if (DoXform) {
        SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(LN00), VT,
                                         LN00->getChain(), LN00->getBasePtr(),
                                         LN00->getMemoryVT(),
                                         LN00->getMemOperand());
        APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue();
        Mask = Mask.zext(VT.getSizeInBits());
        SDLoc DL(N);
        SDValue And = DAG.getNode(N0.getOpcode(), DL, VT,
                                  ExtLoad, DAG.getConstant(Mask, DL, VT));
        ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::ZERO_EXTEND);
        bool NoReplaceTruncAnd = !N0.hasOneUse();
        bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
        CombineTo(N, And);
        // If N0 has multiple uses, change other uses as well.
        if (NoReplaceTruncAnd) {
          SDValue TruncAnd =
              DAG.getNode(ISD::TRUNCATE, DL, N0.getValueType(), And);
          CombineTo(N0.getNode(), TruncAnd);
        }
        if (NoReplaceTrunc) {
          DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1));
        } else {
          SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00),
                                      LN00->getValueType(0), ExtLoad);
          CombineTo(LN00, Trunc, ExtLoad.getValue(1));
        }
        return SDValue(N,0); // Return N so it doesn't get rechecked!
      }
    }
  }

  // fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
  //      (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
  if (SDValue ZExtLoad = CombineZExtLogicopShiftLoad(N))
    return ZExtLoad;

  // Try to simplify (zext (zextload x)).
  if (SDValue foldedExt = tryToFoldExtOfExtload(
          DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::ZEXTLOAD))
    return foldedExt;

  if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
    return V;

  if (N0.getOpcode() == ISD::SETCC) {
    // Only do this before legalize for now.
    if (!LegalOperations && VT.isVector() &&
        N0.getValueType().getVectorElementType() == MVT::i1) {
      EVT N00VT = N0.getOperand(0).getValueType();
      if (getSetCCResultType(N00VT) == N0.getValueType())
        return SDValue();

      // We know that the # elements of the results is the same as the #
      // elements of the compare (and the # elements of the compare result for
      // that matter). Check to see that they are the same size. If so, we know
      // that the element size of the sext'd result matches the element size of
      // the compare operands.
      SDLoc DL(N);
      SDValue VecOnes = DAG.getConstant(1, DL, VT);
      if (VT.getSizeInBits() == N00VT.getSizeInBits()) {
        // zext(setcc) -> (and (vsetcc), (1, 1, ...)) for vectors.
        SDValue VSetCC = DAG.getNode(ISD::SETCC, DL, VT, N0.getOperand(0),
                                     N0.getOperand(1), N0.getOperand(2));
        return DAG.getNode(ISD::AND, DL, VT, VSetCC, VecOnes);
      }

      // If the desired elements are smaller or larger than the source
      // elements we can use a matching integer vector type and then
      // truncate/sign extend.
      EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
      SDValue VsetCC =
          DAG.getNode(ISD::SETCC, DL, MatchingVectorType, N0.getOperand(0),
                      N0.getOperand(1), N0.getOperand(2));
      return DAG.getNode(ISD::AND, DL, VT, DAG.getSExtOrTrunc(VsetCC, DL, VT),
                         VecOnes);
    }

    // zext(setcc x,y,cc) -> select_cc x, y, 1, 0, cc
    SDLoc DL(N);
    if (SDValue SCC = SimplifySelectCC(
            DL, N0.getOperand(0), N0.getOperand(1), DAG.getConstant(1, DL, VT),
            DAG.getConstant(0, DL, VT),
            cast<CondCodeSDNode>(N0.getOperand(2))->get(), true))
      return SCC;
  }

  // (zext (shl (zext x), cst)) -> (shl (zext x), cst)
  if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
      isa<ConstantSDNode>(N0.getOperand(1)) &&
      N0.getOperand(0).getOpcode() == ISD::ZERO_EXTEND &&
      N0.hasOneUse()) {
    SDValue ShAmt = N0.getOperand(1);
    if (N0.getOpcode() == ISD::SHL) {
      SDValue InnerZExt = N0.getOperand(0);
      // If the original shl may be shifting out bits, do not perform this
      // transformation.
      unsigned KnownZeroBits = InnerZExt.getValueSizeInBits() -
        InnerZExt.getOperand(0).getValueSizeInBits();
      if (cast<ConstantSDNode>(ShAmt)->getAPIntValue().ugt(KnownZeroBits))
        return SDValue();
    }

    SDLoc DL(N);

    // Ensure that the shift amount is wide enough for the shifted value.
    if (VT.getSizeInBits() >= 256)
      ShAmt = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, ShAmt);

    return DAG.getNode(N0.getOpcode(), DL, VT,
                       DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0)),
                       ShAmt);
  }

  if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
    return NewVSel;

  if (SDValue NewCtPop = widenCtPop(N, DAG))
    return NewCtPop;

  return SDValue();
}

visitANY_EXTEND(SDNode * N)10154 SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
10155   SDValue N0 = N->getOperand(0);
10156   EVT VT = N->getValueType(0);
10157 
10158   if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
10159     return Res;
10160 
10161   // fold (aext (aext x)) -> (aext x)
10162   // fold (aext (zext x)) -> (zext x)
10163   // fold (aext (sext x)) -> (sext x)
10164   if (N0.getOpcode() == ISD::ANY_EXTEND  ||
10165       N0.getOpcode() == ISD::ZERO_EXTEND ||
10166       N0.getOpcode() == ISD::SIGN_EXTEND)
10167     return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, N0.getOperand(0));
10168 
  // fold (aext (truncate (load x))) -> (aext (smaller load x))
  // fold (aext (truncate (srl (load x), c))) -> (aext (small load (x+c/n)))
  if (N0.getOpcode() == ISD::TRUNCATE) {
    if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
      SDNode *oye = N0.getOperand(0).getNode();
      if (NarrowLoad.getNode() != N0.getNode()) {
        CombineTo(N0.getNode(), NarrowLoad);
        // CombineTo deleted the truncate, if needed, but not what's under it.
        AddToWorklist(oye);
      }
      return SDValue(N, 0);   // Return N so it doesn't get rechecked!
    }
  }

  // fold (aext (truncate x))
  if (N0.getOpcode() == ISD::TRUNCATE)
    return DAG.getAnyExtOrTrunc(N0.getOperand(0), SDLoc(N), VT);

  // fold (aext (and (trunc x), cst)) -> (and x, cst)
  // if the trunc is not free.
  if (N0.getOpcode() == ISD::AND &&
      N0.getOperand(0).getOpcode() == ISD::TRUNCATE &&
      N0.getOperand(1).getOpcode() == ISD::Constant &&
      !TLI.isTruncateFree(N0.getOperand(0).getOperand(0).getValueType(),
                          N0.getValueType())) {
    SDLoc DL(N);
    SDValue X = N0.getOperand(0).getOperand(0);
    X = DAG.getAnyExtOrTrunc(X, DL, VT);
    APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue();
    Mask = Mask.zext(VT.getSizeInBits());
    return DAG.getNode(ISD::AND, DL, VT,
                       X, DAG.getConstant(Mask, DL, VT));
  }

  // fold (aext (load x)) -> (aext (truncate (extload x)))
  // None of the supported targets knows how to perform load and any_ext
  // on vectors in one instruction.  We only perform this transformation on
  // scalars.
  if (ISD::isNON_EXTLoad(N0.getNode()) && !VT.isVector() &&
      ISD::isUNINDEXEDLoad(N0.getNode()) &&
      TLI.isLoadExtLegal(ISD::EXTLOAD, VT, N0.getValueType())) {
    bool DoXform = true;
    SmallVector<SDNode*, 4> SetCCs;
    if (!N0.hasOneUse())
      DoXform = ExtendUsesToFormExtLoad(VT, N, N0, ISD::ANY_EXTEND, SetCCs,
                                        TLI);
    if (DoXform) {
      LoadSDNode *LN0 = cast<LoadSDNode>(N0);
      SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, SDLoc(N), VT,
                                       LN0->getChain(),
                                       LN0->getBasePtr(), N0.getValueType(),
                                       LN0->getMemOperand());
      ExtendSetCCUses(SetCCs, N0, ExtLoad, ISD::ANY_EXTEND);
      // If the load value is used only by N, replace it via CombineTo N.
      bool NoReplaceTrunc = N0.hasOneUse();
      CombineTo(N, ExtLoad);
      if (NoReplaceTrunc) {
        DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
        recursivelyDeleteUnusedNodes(LN0);
      } else {
        SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(N0),
                                    N0.getValueType(), ExtLoad);
        CombineTo(LN0, Trunc, ExtLoad.getValue(1));
      }
      return SDValue(N, 0); // Return N so it doesn't get rechecked!
    }
  }

  // fold (aext (zextload x)) -> (aext (truncate (zextload x)))
  // fold (aext (sextload x)) -> (aext (truncate (sextload x)))
  // fold (aext ( extload x)) -> (aext (truncate (extload  x)))
  if (N0.getOpcode() == ISD::LOAD && !ISD::isNON_EXTLoad(N0.getNode()) &&
      ISD::isUNINDEXEDLoad(N0.getNode()) && N0.hasOneUse()) {
    LoadSDNode *LN0 = cast<LoadSDNode>(N0);
    ISD::LoadExtType ExtType = LN0->getExtensionType();
    EVT MemVT = LN0->getMemoryVT();
    if (!LegalOperations || TLI.isLoadExtLegal(ExtType, VT, MemVT)) {
      SDValue ExtLoad = DAG.getExtLoad(ExtType, SDLoc(N),
                                       VT, LN0->getChain(), LN0->getBasePtr(),
                                       MemVT, LN0->getMemOperand());
      CombineTo(N, ExtLoad);
      DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
      recursivelyDeleteUnusedNodes(LN0);
      return SDValue(N, 0);   // Return N so it doesn't get rechecked!
    }
  }

  if (N0.getOpcode() == ISD::SETCC) {
    // For vectors:
    // aext(setcc) -> vsetcc
    // aext(setcc) -> truncate(vsetcc)
    // aext(setcc) -> aext(vsetcc)
    // Only do this before legalize for now.
    if (VT.isVector() && !LegalOperations) {
      EVT N00VT = N0.getOperand(0).getValueType();
      if (getSetCCResultType(N00VT) == N0.getValueType())
        return SDValue();

      // We know that the # elements of the result is the same as the
      // # elements of the compare (and the # elements of the compare result
      // for that matter).  Check to see that they are the same size.  If so,
      // we know that the element size of the extended result matches the
      // element size of the compare operands.
      if (VT.getSizeInBits() == N00VT.getSizeInBits())
        return DAG.getSetCC(SDLoc(N), VT, N0.getOperand(0),
                            N0.getOperand(1),
                            cast<CondCodeSDNode>(N0.getOperand(2))->get());

      // If the desired elements are smaller or larger than the source
      // elements we can use a matching integer vector type and then
      // truncate/any-extend.
      EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
      SDValue VsetCC =
        DAG.getSetCC(SDLoc(N), MatchingVectorType, N0.getOperand(0),
                     N0.getOperand(1),
                     cast<CondCodeSDNode>(N0.getOperand(2))->get());
      return DAG.getAnyExtOrTrunc(VsetCC, SDLoc(N), VT);
    }

    // aext(setcc x,y,cc) -> select_cc x, y, 1, 0, cc
    SDLoc DL(N);
    if (SDValue SCC = SimplifySelectCC(
            DL, N0.getOperand(0), N0.getOperand(1), DAG.getConstant(1, DL, VT),
            DAG.getConstant(0, DL, VT),
            cast<CondCodeSDNode>(N0.getOperand(2))->get(), true))
      return SCC;
  }

  if (SDValue NewCtPop = widenCtPop(N, DAG))
    return NewCtPop;

  return SDValue();
}

SDValue DAGCombiner::visitAssertExt(SDNode *N) {
  unsigned Opcode = N->getOpcode();
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT AssertVT = cast<VTSDNode>(N1)->getVT();

  // fold (assert?ext (assert?ext x, vt), vt) -> (assert?ext x, vt)
  if (N0.getOpcode() == Opcode &&
      AssertVT == cast<VTSDNode>(N0.getOperand(1))->getVT())
    return N0;

  if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
      N0.getOperand(0).getOpcode() == Opcode) {
    // We have an assert, truncate, assert sandwich. Make one stronger assert
    // by asserting on the smallest asserted type to the larger source type.
    // This eliminates the later assert:
    // assert (trunc (assert X, i8) to iN), i1 --> trunc (assert X, i1) to iN
    // assert (trunc (assert X, i1) to iN), i8 --> trunc (assert X, i1) to iN
    SDValue BigA = N0.getOperand(0);
    EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT();
    assert(BigA_AssertVT.bitsLE(N0.getValueType()) &&
           "Asserting zero/sign-extended bits to a type larger than the "
           "truncated destination does not provide information");

    SDLoc DL(N);
    EVT MinAssertVT = AssertVT.bitsLT(BigA_AssertVT) ? AssertVT : BigA_AssertVT;
    SDValue MinAssertVTVal = DAG.getValueType(MinAssertVT);
    SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(),
                                    BigA.getOperand(0), MinAssertVTVal);
    return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert);
  }

  // If we have (AssertZext (truncate (AssertSext X, iX)), iY) and Y is smaller
  // than X, just move the AssertZext in front of the truncate and drop the
  // AssertSext.
  if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
      N0.getOperand(0).getOpcode() == ISD::AssertSext &&
      Opcode == ISD::AssertZext) {
    SDValue BigA = N0.getOperand(0);
    EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT();
    assert(BigA_AssertVT.bitsLE(N0.getValueType()) &&
           "Asserting zero/sign-extended bits to a type larger than the "
           "truncated destination does not provide information");

    if (AssertVT.bitsLT(BigA_AssertVT)) {
      SDLoc DL(N);
      SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(),
                                      BigA.getOperand(0), N1);
      return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert);
    }
  }

  return SDValue();
}

/// If the result of a wider load is shifted right by N bits and then
/// truncated to a narrower type, where N is a multiple of the number of bits
/// of the narrower type, transform it to a narrower load from address + N /
/// (number of bits of the new type). Also narrow the load if the result is
/// masked with an AND to effectively produce a smaller type. If the result is
/// to be extended, also fold the extension to form an extending load.
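/// For example, on a little-endian target:
///   (i32 (trunc (srl (i64 (load p)), 32))) -> (i32 (load (p + 4)))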
SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
  unsigned Opc = N->getOpcode();

  ISD::LoadExtType ExtType = ISD::NON_EXTLOAD;
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  EVT ExtVT = VT;

  // This transformation isn't valid for vector loads.
  if (VT.isVector())
    return SDValue();

  unsigned ShAmt = 0;
  bool HasShiftedOffset = false;
  // Special case: SIGN_EXTEND_INREG is basically truncating to ExtVT then
  // extended to VT.
  if (Opc == ISD::SIGN_EXTEND_INREG) {
    ExtType = ISD::SEXTLOAD;
    ExtVT = cast<VTSDNode>(N->getOperand(1))->getVT();
  } else if (Opc == ISD::SRL) {
    // Another special case: SRL is basically zero-extending a narrower value,
    // or it may be shifting a higher subword, half or byte into the lowest
    // bits.
    ExtType = ISD::ZEXTLOAD;
    N0 = SDValue(N, 0);

    auto *LN0 = dyn_cast<LoadSDNode>(N0.getOperand(0));
    auto *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1));
    if (!N01 || !LN0)
      return SDValue();

    uint64_t ShiftAmt = N01->getZExtValue();
    uint64_t MemoryWidth = LN0->getMemoryVT().getSizeInBits();
    if (LN0->getExtensionType() != ISD::SEXTLOAD && MemoryWidth > ShiftAmt)
      ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShiftAmt);
    else
      ExtVT = EVT::getIntegerVT(*DAG.getContext(),
                                VT.getSizeInBits() - ShiftAmt);
  } else if (Opc == ISD::AND) {
    // An AND with a constant mask is the same as a truncate + zero-extend.
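    // E.g. (and (load x), 0xffff) demands only the low 16 bits and can be
    // treated like (zextload_i16 x).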
    auto AndC = dyn_cast<ConstantSDNode>(N->getOperand(1));
    if (!AndC)
      return SDValue();

    const APInt &Mask = AndC->getAPIntValue();
    unsigned ActiveBits = 0;
    if (Mask.isMask()) {
      ActiveBits = Mask.countTrailingOnes();
    } else if (Mask.isShiftedMask()) {
      ShAmt = Mask.countTrailingZeros();
      APInt ShiftedMask = Mask.lshr(ShAmt);
      ActiveBits = ShiftedMask.countTrailingOnes();
      HasShiftedOffset = true;
    } else
      return SDValue();

    ExtType = ISD::ZEXTLOAD;
    ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
  }

  if (N0.getOpcode() == ISD::SRL && N0.hasOneUse()) {
    SDValue SRL = N0;
    if (auto *ConstShift = dyn_cast<ConstantSDNode>(SRL.getOperand(1))) {
      ShAmt = ConstShift->getZExtValue();
      unsigned EVTBits = ExtVT.getSizeInBits();
      // Is the shift amount a multiple of the size of ExtVT?
      if ((ShAmt & (EVTBits - 1)) == 0) {
        N0 = N0.getOperand(0);
        // Is the load width a multiple of the size of ExtVT?
        if ((N0.getValueSizeInBits() & (EVTBits - 1)) != 0)
          return SDValue();
      }

      // At this point, we must have a load or else we can't do the transform.
      if (!isa<LoadSDNode>(N0)) return SDValue();

      auto *LN0 = cast<LoadSDNode>(N0);

      // Because a SRL must be assumed to *need* to zero-extend the high bits
      // (as opposed to anyext the high bits), we can't combine the zextload
      // lowering of SRL and an sextload.
      if (LN0->getExtensionType() == ISD::SEXTLOAD)
        return SDValue();

      // If the shift amount is larger than the input type then we're not
      // accessing any of the loaded bytes.  If the load was a zextload/extload
      // then the result of the shift+trunc is zero/undef (handled elsewhere).
      if (ShAmt >= LN0->getMemoryVT().getSizeInBits())
        return SDValue();

      // If the SRL is only used by a masking AND, we may be able to adjust
      // the ExtVT to make the AND redundant.
      SDNode *Mask = *(SRL->use_begin());
      if (Mask->getOpcode() == ISD::AND &&
          isa<ConstantSDNode>(Mask->getOperand(1))) {
        const APInt &ShiftMask =
          cast<ConstantSDNode>(Mask->getOperand(1))->getAPIntValue();
        if (ShiftMask.isMask()) {
          EVT MaskedVT = EVT::getIntegerVT(*DAG.getContext(),
                                           ShiftMask.countTrailingOnes());
          // If the mask is smaller, recompute the type.
          if ((ExtVT.getSizeInBits() > MaskedVT.getSizeInBits()) &&
              TLI.isLoadExtLegal(ExtType, N0.getValueType(), MaskedVT))
            ExtVT = MaskedVT;
        }
      }
    }
  }

  // If the load is shifted left (and the result isn't shifted back right),
  // we can fold the truncate through the shift.
  unsigned ShLeftAmt = 0;
  if (ShAmt == 0 && N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
      ExtVT == VT && TLI.isNarrowingProfitable(N0.getValueType(), VT)) {
    if (ConstantSDNode *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
      ShLeftAmt = N01->getZExtValue();
      N0 = N0.getOperand(0);
    }
  }

  // If we haven't found a load, we can't narrow it.
  if (!isa<LoadSDNode>(N0))
    return SDValue();

  LoadSDNode *LN0 = cast<LoadSDNode>(N0);
  // Reducing the width of a volatile load is illegal.  For atomics, we may be
  // able to reduce the width provided we never widen again. (see D66309)
  if (!LN0->isSimple() ||
      !isLegalNarrowLdSt(LN0, ExtType, ExtVT, ShAmt))
    return SDValue();

  auto AdjustBigEndianShift = [&](unsigned ShAmt) {
    unsigned LVTStoreBits = LN0->getMemoryVT().getStoreSizeInBits();
    unsigned EVTStoreBits = ExtVT.getStoreSizeInBits();
    return LVTStoreBits - EVTStoreBits - ShAmt;
  };

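  // E.g. for an i64 load narrowed to i32 with ShAmt == 32, the lambda above
  // yields 64 - 32 - 32 == 0 on a big-endian target: the high half sits at
  // the lowest address, so no extra pointer offset is needed.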
  // For big endian targets, we need to adjust the offset to the pointer to
  // load the correct bytes.
  if (DAG.getDataLayout().isBigEndian())
    ShAmt = AdjustBigEndianShift(ShAmt);

  uint64_t PtrOff = ShAmt / 8;
  unsigned NewAlign = MinAlign(LN0->getAlignment(), PtrOff);
  SDLoc DL(LN0);
  // The original load itself didn't wrap, so an offset within it doesn't.
  SDNodeFlags Flags;
  Flags.setNoUnsignedWrap(true);
  SDValue NewPtr =
      DAG.getMemBasePlusOffset(LN0->getBasePtr(), PtrOff, DL, Flags);
  AddToWorklist(NewPtr.getNode());

  SDValue Load;
  if (ExtType == ISD::NON_EXTLOAD)
    Load = DAG.getLoad(VT, DL, LN0->getChain(), NewPtr,
                       LN0->getPointerInfo().getWithOffset(PtrOff), NewAlign,
                       LN0->getMemOperand()->getFlags(), LN0->getAAInfo());
  else
    Load = DAG.getExtLoad(ExtType, DL, VT, LN0->getChain(), NewPtr,
                          LN0->getPointerInfo().getWithOffset(PtrOff), ExtVT,
                          NewAlign, LN0->getMemOperand()->getFlags(),
                          LN0->getAAInfo());

  // Replace the old load's chain with the new load's chain.
  WorklistRemover DeadNodes(*this);
  DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));

  // Shift the result left, if we've swallowed a left shift.
  SDValue Result = Load;
  if (ShLeftAmt != 0) {
    EVT ShImmTy = getShiftAmountTy(Result.getValueType());
    if (!isUIntN(ShImmTy.getSizeInBits(), ShLeftAmt))
      ShImmTy = VT;
    // If the shift amount is as large as the result size (but, presumably,
    // no larger than the source) then the useful bits of the result are
    // zero; we can't simply return the shortened shift, because the result
    // of that operation is undefined.
    if (ShLeftAmt >= VT.getSizeInBits())
      Result = DAG.getConstant(0, DL, VT);
    else
      Result = DAG.getNode(ISD::SHL, DL, VT,
                           Result, DAG.getConstant(ShLeftAmt, DL, ShImmTy));
  }

  if (HasShiftedOffset) {
    // Recalculate the shift amount. On big-endian targets it was adjusted
    // above to compute the offset; applying the same adjustment again
    // recovers the original amount.
    if (DAG.getDataLayout().isBigEndian())
      ShAmt = AdjustBigEndianShift(ShAmt);

    // We're using a shifted mask, so the load now has an offset. This means
    // the data has been loaded into lower bytes than it otherwise would have
    // been, so we need to shift it left into the correct position in the
    // register.
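    // E.g. with mask 0xff00 on a little-endian target, the byte is reloaded
    // from offset 1 and must be shifted left by 8 to land back at bits 15:8.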
    SDValue ShiftC = DAG.getConstant(ShAmt, DL, VT);
    Result = DAG.getNode(ISD::SHL, DL, VT, Result, ShiftC);
    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result);
  }

  // Return the new loaded value.
  return Result;
}

SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  EVT ExtVT = cast<VTSDNode>(N1)->getVT();
  unsigned VTBits = VT.getScalarSizeInBits();
  unsigned ExtVTBits = ExtVT.getScalarSizeInBits();

  if (N0.isUndef())
    return DAG.getUNDEF(VT);

  // fold (sext_in_reg c1) -> c1
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
    return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, N0, N1);

  // If the input is already sign extended, just drop the extension.
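  // E.g. for VT == i32 and ExtVT == i8 this requires at least 25 sign bits
  // (32 - 8 + 1), i.e. the top 25 bits all equal the sign bit.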
  if (DAG.ComputeNumSignBits(N0) >= VTBits - ExtVTBits + 1)
    return N0;

  // fold (sext_in_reg (sext_in_reg x, VT2), VT1) -> (sext_in_reg x, minVT) pt2
  if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
      ExtVT.bitsLT(cast<VTSDNode>(N0.getOperand(1))->getVT()))
    return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT,
                       N0.getOperand(0), N1);

  // fold (sext_in_reg (sext x)) -> (sext x)
  // fold (sext_in_reg (aext x)) -> (sext x)
  // if x is small enough or if we know that x has more than 1 sign bit and the
  // sign_extend_inreg is extending from one of them.
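  // E.g. for x : i16 and ExtVT == i8, the fold applies when x has at least 9
  // sign bits: bit 7 then equals the sign bit of x, so sign-extending from
  // bit 7 reproduces sext(x).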
  if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) {
    SDValue N00 = N0.getOperand(0);
    unsigned N00Bits = N00.getScalarValueSizeInBits();
    if ((N00Bits <= ExtVTBits ||
         (N00Bits - DAG.ComputeNumSignBits(N00)) < ExtVTBits) &&
        (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT)))
      return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00);
  }

  // fold (sext_in_reg (*_extend_vector_inreg x)) -> (sext_vector_inreg x)
  if ((N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
       N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG ||
       N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG) &&
      N0.getOperand(0).getScalarValueSizeInBits() == ExtVTBits) {
    if (!LegalOperations ||
        TLI.isOperationLegal(ISD::SIGN_EXTEND_VECTOR_INREG, VT))
      return DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, SDLoc(N), VT,
                         N0.getOperand(0));
  }

  // fold (sext_in_reg (zext x)) -> (sext x)
  // iff we are extending the source sign bit.
  if (N0.getOpcode() == ISD::ZERO_EXTEND) {
    SDValue N00 = N0.getOperand(0);
    if (N00.getScalarValueSizeInBits() == ExtVTBits &&
        (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT)))
      return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00);
  }

  // fold (sext_in_reg x) -> (zext_in_reg x) if the sign bit is known zero.
  if (DAG.MaskedValueIsZero(N0, APInt::getOneBitSet(VTBits, ExtVTBits - 1)))
    return DAG.getZeroExtendInReg(N0, SDLoc(N), ExtVT.getScalarType());

  // fold operands of sext_in_reg based on knowledge that the top bits are not
  // demanded.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  // fold (sext_in_reg (load x)) -> (smaller sextload x)
  // fold (sext_in_reg (srl (load x), c)) -> (smaller sextload (x+c/evtbits))
  if (SDValue NarrowLoad = ReduceLoadWidth(N))
    return NarrowLoad;

  // fold (sext_in_reg (srl X, 24), i8) -> (sra X, 24)
  // fold (sext_in_reg (srl X, 23), i8) -> (sra X, 23) iff possible.
  // We already fold "(sext_in_reg (srl X, 25), i8) -> srl X, 25" above.
  if (N0.getOpcode() == ISD::SRL) {
    if (auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1)))
      if (ShAmt->getAPIntValue().ule(VTBits - ExtVTBits)) {
        // We can turn this into an SRA iff the input to the SRL is already sign
        // extended enough.
        unsigned InSignBits = DAG.ComputeNumSignBits(N0.getOperand(0));
        if (((VTBits - ExtVTBits) - ShAmt->getZExtValue()) < InSignBits)
          return DAG.getNode(ISD::SRA, SDLoc(N), VT, N0.getOperand(0),
                             N0.getOperand(1));
      }
  }

  // fold (sext_inreg (extload x)) -> (sextload x)
  // If sextload is not supported by the target, we can only do the combine
  // when the load has one use. Doing otherwise can block folding the extload
  // with other extends that the target does support.
  if (ISD::isEXTLoad(N0.getNode()) &&
      ISD::isUNINDEXEDLoad(N0.getNode()) &&
      ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
      ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple() &&
        N0.hasOneUse()) ||
       TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT))) {
    LoadSDNode *LN0 = cast<LoadSDNode>(N0);
    SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(N), VT,
                                     LN0->getChain(),
                                     LN0->getBasePtr(), ExtVT,
                                     LN0->getMemOperand());
    CombineTo(N, ExtLoad);
    CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
    AddToWorklist(ExtLoad.getNode());
    return SDValue(N, 0);   // Return N so it doesn't get rechecked!
  }
  // fold (sext_inreg (zextload x)) -> (sextload x) iff load has one use
  if (ISD::isZEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) &&
      N0.hasOneUse() &&
      ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
      ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple()) &&
       TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT))) {
    LoadSDNode *LN0 = cast<LoadSDNode>(N0);
    SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(N), VT,
                                     LN0->getChain(),
                                     LN0->getBasePtr(), ExtVT,
                                     LN0->getMemOperand());
    CombineTo(N, ExtLoad);
    CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
    return SDValue(N, 0);   // Return N so it doesn't get rechecked!
  }

  // Form (sext_inreg (bswap >> 16)) or (sext_inreg (rotl (bswap) 16))
  if (ExtVTBits <= 16 && N0.getOpcode() == ISD::OR) {
    if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
                                           N0.getOperand(1), false))
      return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT,
                         BSwap, N1);
  }

  return SDValue();
}

SDValue DAGCombiner::visitSIGN_EXTEND_VECTOR_INREG(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  if (N0.isUndef())
    return DAG.getUNDEF(VT);

  if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
    return Res;

  if (SimplifyDemandedVectorElts(SDValue(N, 0)))
    return SDValue(N, 0);

  return SDValue();
}

SDValue DAGCombiner::visitZERO_EXTEND_VECTOR_INREG(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  if (N0.isUndef())
    return DAG.getUNDEF(VT);

  if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
    return Res;

  if (SimplifyDemandedVectorElts(SDValue(N, 0)))
    return SDValue(N, 0);

  return SDValue();
}

SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  EVT SrcVT = N0.getValueType();
  bool isLE = DAG.getDataLayout().isLittleEndian();

  // noop truncate
  if (SrcVT == VT)
    return N0;

  // fold (truncate (truncate x)) -> (truncate x)
  if (N0.getOpcode() == ISD::TRUNCATE)
    return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0.getOperand(0));

  // fold (truncate c1) -> c1
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) {
    SDValue C = DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0);
    if (C.getNode() != N)
      return C;
  }

  // fold (truncate (ext x)) -> (ext x) or (truncate x) or x
  if (N0.getOpcode() == ISD::ZERO_EXTEND ||
      N0.getOpcode() == ISD::SIGN_EXTEND ||
      N0.getOpcode() == ISD::ANY_EXTEND) {
    // if the source is smaller than the dest, we still need an extend.
    if (N0.getOperand(0).getValueType().bitsLT(VT))
      return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, N0.getOperand(0));
    // if the source is larger than the dest, then we just need the truncate.
    if (N0.getOperand(0).getValueType().bitsGT(VT))
      return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0.getOperand(0));
    // if the source and dest are the same type, we can drop both the extend
    // and the truncate.
    return N0.getOperand(0);
  }

  // If this is anyext(trunc), don't fold it; allow ourselves to be folded.
  if (N->hasOneUse() && (N->use_begin()->getOpcode() == ISD::ANY_EXTEND))
    return SDValue();

  // Fold extract-and-trunc into a narrow extract. For example:
  //   i64 x = EXTRACT_VECTOR_ELT(v2i64 val, i32 1)
  //   i32 y = TRUNCATE(i64 x)
  //        -- becomes --
  //   v16i8 b = BITCAST (v2i64 val)
  //   i8 x = EXTRACT_VECTOR_ELT(v16i8 b, i32 8)
  //
  // Note: We only run this optimization after type legalization (which often
  // creates this pattern) and before operation legalization after which
  // we need to be more careful about the vector instructions that we generate.
  if (N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
      LegalTypes && !LegalOperations && N0->hasOneUse() && VT != MVT::i1) {
    EVT VecTy = N0.getOperand(0).getValueType();
    EVT ExTy = N0.getValueType();
    EVT TrTy = N->getValueType(0);

    unsigned NumElem = VecTy.getVectorNumElements();
    unsigned SizeRatio = ExTy.getSizeInBits() / TrTy.getSizeInBits();

    EVT NVT = EVT::getVectorVT(*DAG.getContext(), TrTy, SizeRatio * NumElem);
    assert(NVT.getSizeInBits() == VecTy.getSizeInBits() && "Invalid Size");

    SDValue EltNo = N0->getOperand(1);
    if (isa<ConstantSDNode>(EltNo) && isTypeLegal(NVT)) {
      int Elt = cast<ConstantSDNode>(EltNo)->getZExtValue();
      EVT IndexTy = TLI.getVectorIdxTy(DAG.getDataLayout());
      int Index = isLE ? (Elt*SizeRatio) : (Elt*SizeRatio + (SizeRatio-1));

      SDLoc DL(N);
      return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, TrTy,
                         DAG.getBitcast(NVT, N0.getOperand(0)),
                         DAG.getConstant(Index, DL, IndexTy));
    }
  }

  // trunc (select c, a, b) -> select c, (trunc a), (trunc b)
  if (N0.getOpcode() == ISD::SELECT && N0.hasOneUse()) {
    if ((!LegalOperations || TLI.isOperationLegal(ISD::SELECT, SrcVT)) &&
        TLI.isTruncateFree(SrcVT, VT)) {
      SDLoc SL(N0);
      SDValue Cond = N0.getOperand(0);
      SDValue TruncOp0 = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(1));
      SDValue TruncOp1 = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(2));
      return DAG.getNode(ISD::SELECT, SDLoc(N), VT, Cond, TruncOp0, TruncOp1);
    }
  }

  // trunc (shl x, K) -> shl (trunc x), K iff K < VT.getScalarSizeInBits()
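  // E.g. (i32 (trunc (shl (i64 x), 5))) -> (shl (i32 (trunc x)), 5), since
  // the shift amount 5 is below the i32 bit width.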
  if (N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
      (!LegalOperations || TLI.isOperationLegal(ISD::SHL, VT)) &&
      TLI.isTypeDesirableForOp(ISD::SHL, VT)) {
    SDValue Amt = N0.getOperand(1);
    KnownBits Known = DAG.computeKnownBits(Amt);
    unsigned Size = VT.getScalarSizeInBits();
    if (Known.getBitWidth() - Known.countMinLeadingZeros() <= Log2_32(Size)) {
      SDLoc SL(N);
      EVT AmtVT = TLI.getShiftAmountTy(VT, DAG.getDataLayout());

      SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(0));
      if (AmtVT != Amt.getValueType()) {
        Amt = DAG.getZExtOrTrunc(Amt, SL, AmtVT);
        AddToWorklist(Amt.getNode());
      }
      return DAG.getNode(ISD::SHL, SL, VT, Trunc, Amt);
    }
  }

  // Attempt to pre-truncate BUILD_VECTOR sources.
  if (N0.getOpcode() == ISD::BUILD_VECTOR && !LegalOperations &&
      TLI.isTruncateFree(SrcVT.getScalarType(), VT.getScalarType())) {
    SDLoc DL(N);
    EVT SVT = VT.getScalarType();
    SmallVector<SDValue, 8> TruncOps;
    for (const SDValue &Op : N0->op_values()) {
      SDValue TruncOp = DAG.getNode(ISD::TRUNCATE, DL, SVT, Op);
      TruncOps.push_back(TruncOp);
    }
    return DAG.getBuildVector(VT, DL, TruncOps);
  }

  // Fold a series of buildvector, bitcast, and truncate if possible.
  // For example, fold
  //   (2xi32 trunc (bitcast ((4xi32) buildvector x, x, y, y) 2xi64)) to
  //   (2xi32 (buildvector x, y)).
  if (Level == AfterLegalizeVectorOps && VT.isVector() &&
      N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
      N0.getOperand(0).getOpcode() == ISD::BUILD_VECTOR &&
      N0.getOperand(0).hasOneUse()) {
    SDValue BuildVect = N0.getOperand(0);
    EVT BuildVectEltTy = BuildVect.getValueType().getVectorElementType();
    EVT TruncVecEltTy = VT.getVectorElementType();

    // Check that the element types match.
    if (BuildVectEltTy == TruncVecEltTy) {
      // Now we only need to compute the offset of the truncated elements.
      unsigned BuildVecNumElts = BuildVect.getNumOperands();
      unsigned TruncVecNumElts = VT.getVectorNumElements();
      unsigned TruncEltOffset = BuildVecNumElts / TruncVecNumElts;

      assert((BuildVecNumElts % TruncVecNumElts) == 0 &&
             "Invalid number of elements");

      SmallVector<SDValue, 8> Opnds;
      for (unsigned i = 0, e = BuildVecNumElts; i != e; i += TruncEltOffset)
        Opnds.push_back(BuildVect.getOperand(i));

      return DAG.getBuildVector(VT, SDLoc(N), Opnds);
    }
  }

  // See if we can simplify the input to this truncate through knowledge that
  // only the low bits are being used.
  // For example "trunc (or (shl x, 8), y)" -> "trunc y" when truncating to
  // eight bits or fewer.
  // Currently we only perform this optimization on scalars because vectors
  // may have different active low bits.
  if (!VT.isVector()) {
    APInt Mask =
        APInt::getLowBitsSet(N0.getValueSizeInBits(), VT.getSizeInBits());
    if (SDValue Shorter = DAG.GetDemandedBits(N0, Mask))
      return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Shorter);
  }

  // fold (truncate (load x)) -> (smaller load x)
  // fold (truncate (srl (load x), c)) -> (smaller load (x+c/evtbits))
  if (!LegalTypes || TLI.isTypeDesirableForOp(N0.getOpcode(), VT)) {
    if (SDValue Reduced = ReduceLoadWidth(N))
      return Reduced;

    // Handle the case where the load remains an extending load even
    // after truncation.
    if (N0.hasOneUse() && ISD::isUNINDEXEDLoad(N0.getNode())) {
      LoadSDNode *LN0 = cast<LoadSDNode>(N0);
      if (LN0->isSimple() &&
          LN0->getMemoryVT().getStoreSizeInBits() < VT.getSizeInBits()) {
        SDValue NewLoad = DAG.getExtLoad(LN0->getExtensionType(), SDLoc(LN0),
                                         VT, LN0->getChain(), LN0->getBasePtr(),
                                         LN0->getMemoryVT(),
                                         LN0->getMemOperand());
        DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), NewLoad.getValue(1));
        return NewLoad;
      }
    }
  }

  // fold (trunc (concat ... x ...)) -> (concat ..., (trunc x), ...),
  // where ... are all 'undef'.
  if (N0.getOpcode() == ISD::CONCAT_VECTORS && !LegalTypes) {
    SmallVector<EVT, 8> VTs;
    SDValue V;
    unsigned Idx = 0;
    unsigned NumDefs = 0;

    for (unsigned i = 0, e = N0.getNumOperands(); i != e; ++i) {
      SDValue X = N0.getOperand(i);
      if (!X.isUndef()) {
        V = X;
        Idx = i;
        NumDefs++;
      }
      // Stop if more than one member is non-undef.
      if (NumDefs > 1)
        break;
      VTs.push_back(EVT::getVectorVT(*DAG.getContext(),
                                     VT.getVectorElementType(),
                                     X.getValueType().getVectorNumElements()));
    }

    if (NumDefs == 0)
      return DAG.getUNDEF(VT);

    if (NumDefs == 1) {
      assert(V.getNode() && "The single defined operand is empty!");
      SmallVector<SDValue, 8> Opnds;
      for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
        if (i != Idx) {
          Opnds.push_back(DAG.getUNDEF(VTs[i]));
          continue;
        }
        SDValue NV = DAG.getNode(ISD::TRUNCATE, SDLoc(V), VTs[i], V);
        AddToWorklist(NV.getNode());
        Opnds.push_back(NV);
      }
      return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Opnds);
    }
  }

  // Fold truncate of a bitcast of a vector to an extract of the low vector
  // element.
  //
  // e.g. trunc (i64 (bitcast v2i32:x)) -> extract_vector_elt v2i32:x, idx
  if (N0.getOpcode() == ISD::BITCAST && !VT.isVector()) {
    SDValue VecSrc = N0.getOperand(0);
    EVT VecSrcVT = VecSrc.getValueType();
    if (VecSrcVT.isVector() && VecSrcVT.getScalarType() == VT &&
        (!LegalOperations ||
         TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VecSrcVT))) {
      SDLoc SL(N);

      EVT IdxVT = TLI.getVectorIdxTy(DAG.getDataLayout());
      unsigned Idx = isLE ? 0 : VecSrcVT.getVectorNumElements() - 1;
      return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, VT, VecSrc,
                         DAG.getConstant(Idx, SL, IdxVT));
    }
  }

  // Simplify the operands using demanded-bits information.
  if (!VT.isVector() &&
      SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  // (trunc adde(X, Y, Carry)) -> (adde trunc(X), trunc(Y), Carry)
  // (trunc addcarry(X, Y, Carry)) -> (addcarry trunc(X), trunc(Y), Carry)
  // When the adde's carry is not used.
  if ((N0.getOpcode() == ISD::ADDE || N0.getOpcode() == ISD::ADDCARRY) &&
      N0.hasOneUse() && !N0.getNode()->hasAnyUseOfValue(1) &&
      // We only do this for ADDCARRY before operation legalization.
      ((!LegalOperations && N0.getOpcode() == ISD::ADDCARRY) ||
       TLI.isOperationLegal(N0.getOpcode(), VT))) {
    SDLoc SL(N);
    auto X = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(0));
    auto Y = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(1));
    auto VTs = DAG.getVTList(VT, N0->getValueType(1));
    return DAG.getNode(N0.getOpcode(), SL, VTs, X, Y, N0.getOperand(2));
  }

  // fold (truncate (extract_subvector(ext x))) ->
  //      (extract_subvector x)
  // TODO: This can be generalized to cover cases where the truncate and
  // extract do not fully cancel each other out.
  if (!LegalTypes && N0.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
    SDValue N00 = N0.getOperand(0);
    if (N00.getOpcode() == ISD::SIGN_EXTEND ||
        N00.getOpcode() == ISD::ZERO_EXTEND ||
        N00.getOpcode() == ISD::ANY_EXTEND) {
      if (N00.getOperand(0)->getValueType(0).getVectorElementType() ==
          VT.getVectorElementType())
        return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N0->getOperand(0)), VT,
                           N00.getOperand(0), N0.getOperand(1));
    }
  }

  if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
    return NewVSel;

  // Narrow a suitable binary operation with a non-opaque constant operand by
  // moving it ahead of the truncate. This is limited to pre-legalization
  // because targets may prefer a wider type during later combines and invert
  // this transform.
  switch (N0.getOpcode()) {
  case ISD::ADD:
  case ISD::SUB:
  case ISD::MUL:
  case ISD::AND:
  case ISD::OR:
  case ISD::XOR:
    if (!LegalOperations && N0.hasOneUse() &&
        (isConstantOrConstantVector(N0.getOperand(0), true) ||
         isConstantOrConstantVector(N0.getOperand(1), true))) {
      // TODO: We already restricted this to pre-legalization, but for vectors
      // we are extra cautious to not create an unsupported operation.
      // Target-specific changes are likely needed to avoid regressions here.
      if (VT.isScalarInteger() || TLI.isOperationLegal(N0.getOpcode(), VT)) {
        SDLoc DL(N);
        SDValue NarrowL = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
        SDValue NarrowR = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(1));
        return DAG.getNode(N0.getOpcode(), DL, VT, NarrowL, NarrowR);
      }
    }
  }

  return SDValue();
}

static SDNode *getBuildPairElt(SDNode *N, unsigned i) {
  SDValue Elt = N->getOperand(i);
  if (Elt.getOpcode() != ISD::MERGE_VALUES)
    return Elt.getNode();
  return Elt.getOperand(Elt.getResNo()).getNode();
}

/// build_pair (load, load) -> load
/// if load locations are consecutive.
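/// E.g. on a little-endian target, (build_pair (i32 load p), (i32 load p+4))
/// can become a single (i64 load p), alignment permitting.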
SDValue DAGCombiner::CombineConsecutiveLoads(SDNode *N, EVT VT) {
  assert(N->getOpcode() == ISD::BUILD_PAIR);

  LoadSDNode *LD1 = dyn_cast<LoadSDNode>(getBuildPairElt(N, 0));
  LoadSDNode *LD2 = dyn_cast<LoadSDNode>(getBuildPairElt(N, 1));

  // A BUILD_PAIR always has the least significant part in elt 0 and the
  // most significant part in elt 1, so when combining into one large load
  // we need to consider the endianness.
  if (DAG.getDataLayout().isBigEndian())
    std::swap(LD1, LD2);

  if (!LD1 || !LD2 || !ISD::isNON_EXTLoad(LD1) || !LD1->hasOneUse() ||
      LD1->getAddressSpace() != LD2->getAddressSpace())
    return SDValue();
  EVT LD1VT = LD1->getValueType(0);
  unsigned LD1Bytes = LD1VT.getStoreSize();
  if (ISD::isNON_EXTLoad(LD2) && LD2->hasOneUse() &&
      DAG.areNonVolatileConsecutiveLoads(LD2, LD1, LD1Bytes, 1)) {
    unsigned Align = LD1->getAlignment();
    unsigned NewAlign = DAG.getDataLayout().getABITypeAlignment(
        VT.getTypeForEVT(*DAG.getContext()));

    if (NewAlign <= Align &&
        (!LegalOperations || TLI.isOperationLegal(ISD::LOAD, VT)))
      return DAG.getLoad(VT, SDLoc(N), LD1->getChain(), LD1->getBasePtr(),
                         LD1->getPointerInfo(), Align);
  }

  return SDValue();
}

static unsigned getPPCf128HiElementSelector(const SelectionDAG &DAG) {
  // On little-endian machines, bitcasting from ppcf128 to i128 does swap the
  // Hi and Lo parts; on big-endian machines it doesn't.
  return DAG.getDataLayout().isBigEndian() ? 1 : 0;
}

static SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
                                    const TargetLowering &TLI) {
  // If this is not a bitcast to an FP type or if the target doesn't have
  // IEEE754-compliant FP logic, we're done.
  EVT VT = N->getValueType(0);
  if (!VT.isFloatingPoint() || !TLI.hasBitPreservingFPLogic(VT))
    return SDValue();

  // TODO: Handle cases where the integer constant is a different scalar
  // bitwidth to the FP.
  SDValue N0 = N->getOperand(0);
  EVT SourceVT = N0.getValueType();
  if (VT.getScalarSizeInBits() != SourceVT.getScalarSizeInBits())
    return SDValue();

  unsigned FPOpcode;
  APInt SignMask;
  switch (N0.getOpcode()) {
  case ISD::AND:
    FPOpcode = ISD::FABS;
    SignMask = ~APInt::getSignMask(SourceVT.getScalarSizeInBits());
    break;
  case ISD::XOR:
    FPOpcode = ISD::FNEG;
    SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits());
    break;
  case ISD::OR:
    FPOpcode = ISD::FABS;
    SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits());
    break;
  default:
    return SDValue();
  }

  // Fold (bitcast int (and (bitcast fp X to int), 0x7fff...) to fp) -> fabs X
  // Fold (bitcast int (xor (bitcast fp X to int), 0x8000...) to fp) -> fneg X
  // Fold (bitcast int (or (bitcast fp X to int), 0x8000...) to fp) ->
  //   fneg (fabs X)
  SDValue LogicOp0 = N0.getOperand(0);
  ConstantSDNode *LogicOp1 = isConstOrConstSplat(N0.getOperand(1), true);
  if (LogicOp1 && LogicOp1->getAPIntValue() == SignMask &&
      LogicOp0.getOpcode() == ISD::BITCAST &&
      LogicOp0.getOperand(0).getValueType() == VT) {
    SDValue FPOp = DAG.getNode(FPOpcode, SDLoc(N), VT, LogicOp0.getOperand(0));
    NumFPLogicOpsConv++;
    if (N0.getOpcode() == ISD::OR)
      return DAG.getNode(ISD::FNEG, SDLoc(N), VT, FPOp);
    return FPOp;
  }

  return SDValue();
}

SDValue DAGCombiner::visitBITCAST(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  if (N0.isUndef())
    return DAG.getUNDEF(VT);

  // If the input is a BUILD_VECTOR with all constant elements, fold this now.
  // Only do this before legalize types, unless both types are integer and the
  // scalar type is legal. Only do this before legalize ops, since the target
  // may be depending on the bitcast.
  // First check to see if this is all constant.
  // TODO: Support FP bitcasts after legalize types.
  if (VT.isVector() &&
      (!LegalTypes ||
       (!LegalOperations && VT.isInteger() && N0.getValueType().isInteger() &&
        TLI.isTypeLegal(VT.getVectorElementType()))) &&
      N0.getOpcode() == ISD::BUILD_VECTOR && N0.getNode()->hasOneUse() &&
      cast<BuildVectorSDNode>(N0)->isConstant())
    return ConstantFoldBITCASTofBUILD_VECTOR(N0.getNode(),
                                             VT.getVectorElementType());

  // If the input is a constant, let getNode fold it.
  if (isa<ConstantSDNode>(N0) || isa<ConstantFPSDNode>(N0)) {
    // If we can't allow illegal operations, we need to check that this is just
    // an fp -> int or int -> fp conversion and that the resulting operation
    // will be legal.
    if (!LegalOperations ||
        (isa<ConstantSDNode>(N0) && VT.isFloatingPoint() && !VT.isVector() &&
         TLI.isOperationLegal(ISD::ConstantFP, VT)) ||
        (isa<ConstantFPSDNode>(N0) && VT.isInteger() && !VT.isVector() &&
         TLI.isOperationLegal(ISD::Constant, VT))) {
      SDValue C = DAG.getBitcast(VT, N0);
      if (C.getNode() != N)
        return C;
    }
  }

  // (conv (conv x, t1), t2) -> (conv x, t2)
  if (N0.getOpcode() == ISD::BITCAST)
    return DAG.getBitcast(VT, N0.getOperand(0));

  // fold (conv (load x)) -> (load (conv*)x)
  // if the resulting load doesn't need a higher alignment than the original.
  if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
      // Do not remove the cast if the types differ in endian layout.
      TLI.hasBigEndianPartOrdering(N0.getValueType(), DAG.getDataLayout()) ==
          TLI.hasBigEndianPartOrdering(VT, DAG.getDataLayout()) &&
      // If the load is volatile, we only want to change the load type if the
      // resulting load is legal. Otherwise we might increase the number of
      // memory accesses. We don't care if the original type was legal or not
      // as we assume software couldn't rely on the number of accesses of an
      // illegal type.
      ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple()) ||
       TLI.isOperationLegal(ISD::LOAD, VT))) {
    LoadSDNode *LN0 = cast<LoadSDNode>(N0);

    if (TLI.isLoadBitCastBeneficial(N0.getValueType(), VT, DAG,
                                    *LN0->getMemOperand())) {
      SDValue Load =
          DAG.getLoad(VT, SDLoc(N), LN0->getChain(), LN0->getBasePtr(),
                      LN0->getPointerInfo(), LN0->getAlignment(),
                      LN0->getMemOperand()->getFlags(), LN0->getAAInfo());
      DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));
      return Load;
    }
  }

  if (SDValue V = foldBitcastedFPLogic(N, DAG, TLI))
    return V;

  // fold (bitconvert (fneg x)) -> (xor (bitconvert x), signbit)
  // fold (bitconvert (fabs x)) -> (and (bitconvert x), (not signbit))
  //
  // For ppc_fp128:
  // fold (bitcast (fneg x)) ->
  //     flipbit = signbit
  //     (xor (bitcast x) (build_pair flipbit, flipbit))
  //
  // fold (bitcast (fabs x)) ->
  //     flipbit = (and (extract_element (bitcast x), 0), signbit)
  //     (xor (bitcast x) (build_pair flipbit, flipbit))
  // This often reduces constant pool loads.
  if (((N0.getOpcode() == ISD::FNEG && !TLI.isFNegFree(N0.getValueType())) ||
       (N0.getOpcode() == ISD::FABS && !TLI.isFAbsFree(N0.getValueType()))) &&
      N0.getNode()->hasOneUse() && VT.isInteger() &&
      !VT.isVector() && !N0.getValueType().isVector()) {
    SDValue NewConv = DAG.getBitcast(VT, N0.getOperand(0));
    AddToWorklist(NewConv.getNode());

    SDLoc DL(N);
    if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
      assert(VT.getSizeInBits() == 128);
      SDValue SignBit = DAG.getConstant(
          APInt::getSignMask(VT.getSizeInBits() / 2), SDLoc(N0), MVT::i64);
      SDValue FlipBit;
      if (N0.getOpcode() == ISD::FNEG) {
        FlipBit = SignBit;
        AddToWorklist(FlipBit.getNode());
      } else {
        assert(N0.getOpcode() == ISD::FABS);
        SDValue Hi =
            DAG.getNode(ISD::EXTRACT_ELEMENT, SDLoc(NewConv), MVT::i64, NewConv,
                        DAG.getIntPtrConstant(getPPCf128HiElementSelector(DAG),
                                              SDLoc(NewConv)));
        AddToWorklist(Hi.getNode());
        FlipBit = DAG.getNode(ISD::AND, SDLoc(N0), MVT::i64, Hi, SignBit);
        AddToWorklist(FlipBit.getNode());
      }
      SDValue FlipBits =
          DAG.getNode(ISD::BUILD_PAIR, SDLoc(N0), VT, FlipBit, FlipBit);
      AddToWorklist(FlipBits.getNode());
      return DAG.getNode(ISD::XOR, DL, VT, NewConv, FlipBits);
    }
    APInt SignBit = APInt::getSignMask(VT.getSizeInBits());
    if (N0.getOpcode() == ISD::FNEG)
      return DAG.getNode(ISD::XOR, DL, VT,
                         NewConv, DAG.getConstant(SignBit, DL, VT));
    assert(N0.getOpcode() == ISD::FABS);
    return DAG.getNode(ISD::AND, DL, VT,
                       NewConv, DAG.getConstant(~SignBit, DL, VT));
  }

  // fold (bitconvert (fcopysign cst, x)) ->
  //         (or (and (bitconvert x), sign), (and cst, (not sign)))
  // Note that we don't handle (copysign x, cst) because this can always be
  // folded to an fneg or fabs.
  //
  // For ppc_fp128:
  // fold (bitcast (fcopysign cst, x)) ->
  //     flipbit = (and (extract_element
  //                     (xor (bitcast cst), (bitcast x)), 0),
  //                    signbit)
  //     (xor (bitcast cst) (build_pair flipbit, flipbit))
  if (N0.getOpcode() == ISD::FCOPYSIGN && N0.getNode()->hasOneUse() &&
      isa<ConstantFPSDNode>(N0.getOperand(0)) &&
      VT.isInteger() && !VT.isVector()) {
    unsigned OrigXWidth = N0.getOperand(1).getValueSizeInBits();
    EVT IntXVT = EVT::getIntegerVT(*DAG.getContext(), OrigXWidth);
    if (isTypeLegal(IntXVT)) {
      SDValue X = DAG.getBitcast(IntXVT, N0.getOperand(1));
      AddToWorklist(X.getNode());

      // If X has a different width than the result/lhs, sext it or truncate it.
      unsigned VTWidth = VT.getSizeInBits();
      if (OrigXWidth < VTWidth) {
        X = DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, X);
        AddToWorklist(X.getNode());
      } else if (OrigXWidth > VTWidth) {
        // To get the sign bit in the right place, we have to shift it right
        // before truncating.
        SDLoc DL(X);
        X = DAG.getNode(ISD::SRL, DL,
                        X.getValueType(), X,
                        DAG.getConstant(OrigXWidth-VTWidth, DL,
                                        X.getValueType()));
        AddToWorklist(X.getNode());
        X = DAG.getNode(ISD::TRUNCATE, SDLoc(X), VT, X);
        AddToWorklist(X.getNode());
      }

      if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
        APInt SignBit = APInt::getSignMask(VT.getSizeInBits() / 2);
        SDValue Cst = DAG.getBitcast(VT, N0.getOperand(0));
        AddToWorklist(Cst.getNode());
        SDValue X = DAG.getBitcast(VT, N0.getOperand(1));
        AddToWorklist(X.getNode());
        SDValue XorResult = DAG.getNode(ISD::XOR, SDLoc(N0), VT, Cst, X);
        AddToWorklist(XorResult.getNode());
        SDValue XorResult64 = DAG.getNode(
            ISD::EXTRACT_ELEMENT, SDLoc(XorResult), MVT::i64, XorResult,
            DAG.getIntPtrConstant(getPPCf128HiElementSelector(DAG),
                                  SDLoc(XorResult)));
        AddToWorklist(XorResult64.getNode());
        SDValue FlipBit =
            DAG.getNode(ISD::AND, SDLoc(XorResult64), MVT::i64, XorResult64,
                        DAG.getConstant(SignBit, SDLoc(XorResult64), MVT::i64));
        AddToWorklist(FlipBit.getNode());
        SDValue FlipBits =
            DAG.getNode(ISD::BUILD_PAIR, SDLoc(N0), VT, FlipBit, FlipBit);
        AddToWorklist(FlipBits.getNode());
        return DAG.getNode(ISD::XOR, SDLoc(N), VT, Cst, FlipBits);
      }
      APInt SignBit = APInt::getSignMask(VT.getSizeInBits());
      X = DAG.getNode(ISD::AND, SDLoc(X), VT,
                      X, DAG.getConstant(SignBit, SDLoc(X), VT));
      AddToWorklist(X.getNode());

      SDValue Cst = DAG.getBitcast(VT, N0.getOperand(0));
      Cst = DAG.getNode(ISD::AND, SDLoc(Cst), VT,
                        Cst, DAG.getConstant(~SignBit, SDLoc(Cst), VT));
      AddToWorklist(Cst.getNode());

      return DAG.getNode(ISD::OR, SDLoc(N), VT, X, Cst);
    }
  }

  // bitconvert(build_pair(ld, ld)) -> ld iff load locations are consecutive.
  if (N0.getOpcode() == ISD::BUILD_PAIR)
    if (SDValue CombineLD = CombineConsecutiveLoads(N0.getNode(), VT))
      return CombineLD;

  // Remove double bitcasts from shuffles - this is often a legacy of
  // XformToShuffleWithZero being used to combine bitmaskings (of
  // float vectors bitcast to integer vectors) into shuffles.
  // bitcast(shuffle(bitcast(s0),bitcast(s1))) -> shuffle(s0,s1)
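  // E.g. a <0,3> mask on v2i64 operands bitcast to v4i32 becomes the mask
  // <0,1,6,7> after scaling by MaskScale == 2 below.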
  if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT) && VT.isVector() &&
      N0->getOpcode() == ISD::VECTOR_SHUFFLE && N0.hasOneUse() &&
      VT.getVectorNumElements() >= N0.getValueType().getVectorNumElements() &&
      !(VT.getVectorNumElements() % N0.getValueType().getVectorNumElements())) {
    ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N0);

    // If operands are a bitcast, peek through if it casts the original VT.
    // If operands are a constant, just bitcast back to original VT.
    auto PeekThroughBitcast = [&](SDValue Op) {
      if (Op.getOpcode() == ISD::BITCAST &&
          Op.getOperand(0).getValueType() == VT)
        return SDValue(Op.getOperand(0));
      if (Op.isUndef() || ISD::isBuildVectorOfConstantSDNodes(Op.getNode()) ||
          ISD::isBuildVectorOfConstantFPSDNodes(Op.getNode()))
        return DAG.getBitcast(VT, Op);
      return SDValue();
    };

    // FIXME: If either input vector is bitcast, try to convert the shuffle to
    // the result type of this bitcast. This would eliminate at least one
    // bitcast. See the transform in InstCombine.
    SDValue SV0 = PeekThroughBitcast(N0->getOperand(0));
    SDValue SV1 = PeekThroughBitcast(N0->getOperand(1));
    if (!(SV0 && SV1))
      return SDValue();

    int MaskScale =
        VT.getVectorNumElements() / N0.getValueType().getVectorNumElements();
    SmallVector<int, 8> NewMask;
    for (int M : SVN->getMask())
      for (int i = 0; i != MaskScale; ++i)
        NewMask.push_back(M < 0 ? -1 : M * MaskScale + i);

    SDValue LegalShuffle =
        TLI.buildLegalVectorShuffle(VT, SDLoc(N), SV0, SV1, NewMask, DAG);
    if (LegalShuffle)
      return LegalShuffle;
  }

  return SDValue();
}

SDValue DAGCombiner::visitBUILD_PAIR(SDNode *N) {
  EVT VT = N->getValueType(0);
  return CombineConsecutiveLoads(N, VT);
}

/// We know that BV is a build_vector node with Constant, ConstantFP or Undef
/// operands. DstEltVT indicates the destination element value type.
SDValue DAGCombiner::
ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) {
  EVT SrcEltVT = BV->getValueType(0).getVectorElementType();

  // If this is already the right type, we're done.
  if (SrcEltVT == DstEltVT) return SDValue(BV, 0);

  unsigned SrcBitSize = SrcEltVT.getSizeInBits();
  unsigned DstBitSize = DstEltVT.getSizeInBits();

  // If this is a conversion of N elements of one type to N elements of another
  // type, convert each element.  This handles FP<->INT cases.
  if (SrcBitSize == DstBitSize) {
    SmallVector<SDValue, 8> Ops;
    for (SDValue Op : BV->op_values()) {
      // If the vector element type is not legal, the BUILD_VECTOR operands
      // are promoted and implicitly truncated.  Make that explicit here.
      if (Op.getValueType() != SrcEltVT)
        Op = DAG.getNode(ISD::TRUNCATE, SDLoc(BV), SrcEltVT, Op);
      Ops.push_back(DAG.getBitcast(DstEltVT, Op));
      AddToWorklist(Ops.back().getNode());
    }
    EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT,
                              BV->getValueType(0).getVectorNumElements());
    return DAG.getBuildVector(VT, SDLoc(BV), Ops);
  }

  // Otherwise, we're growing or shrinking the elements.  To avoid having to
  // handle annoying details of growing/shrinking FP values, we convert them to
  // int first.
  if (SrcEltVT.isFloatingPoint()) {
    // Convert the input float vector to an int vector whose elements are the
    // same size.
    EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), SrcEltVT.getSizeInBits());
    BV = ConstantFoldBITCASTofBUILD_VECTOR(BV, IntVT).getNode();
    SrcEltVT = IntVT;
  }

  // Now we know the input is an integer vector.  If the output is an FP type,
  // convert to integer first, then to FP of the right size.
  if (DstEltVT.isFloatingPoint()) {
    EVT TmpVT = EVT::getIntegerVT(*DAG.getContext(), DstEltVT.getSizeInBits());
    SDNode *Tmp = ConstantFoldBITCASTofBUILD_VECTOR(BV, TmpVT).getNode();

    // Next, convert to FP elements of the same size.
    return ConstantFoldBITCASTofBUILD_VECTOR(Tmp, DstEltVT);
  }

  SDLoc DL(BV);

  // Okay, we know the src/dst types are both integers of differing sizes.
  // Handle the growing case first.
  assert(SrcEltVT.isInteger() && DstEltVT.isInteger());
  if (SrcBitSize < DstBitSize) {
    unsigned NumInputsPerOutput = DstBitSize/SrcBitSize;

    SmallVector<SDValue, 8> Ops;
    for (unsigned i = 0, e = BV->getNumOperands(); i != e;
         i += NumInputsPerOutput) {
      bool isLE = DAG.getDataLayout().isLittleEndian();
      APInt NewBits = APInt(DstBitSize, 0);
      bool EltIsUndef = true;
      for (unsigned j = 0; j != NumInputsPerOutput; ++j) {
        // Shift the previously computed bits over.
        NewBits <<= SrcBitSize;
        SDValue Op = BV->getOperand(i + (isLE ? (NumInputsPerOutput-j-1) : j));
        if (Op.isUndef()) continue;
        EltIsUndef = false;

        NewBits |= cast<ConstantSDNode>(Op)->getAPIntValue().
                   zextOrTrunc(SrcBitSize).zext(DstBitSize);
      }

      if (EltIsUndef)
        Ops.push_back(DAG.getUNDEF(DstEltVT));
      else
        Ops.push_back(DAG.getConstant(NewBits, DL, DstEltVT));
    }
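
    // On a little-endian target the lower-indexed inputs land in the low
    // bits of the wider element; e.g. folding <4 x i16> into <2 x i32> above
    // gives output[0] = (in[1] << 16) | in[0]. Big-endian is the mirror
    // image, matching the memory layout in both cases.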

    EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT, Ops.size());
    return DAG.getBuildVector(VT, DL, Ops);
  }

  // Finally, this must be the case where we are shrinking elements: each input
  // turns into multiple outputs.
  unsigned NumOutputsPerInput = SrcBitSize/DstBitSize;
  EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT,
                            NumOutputsPerInput*BV->getNumOperands());
  SmallVector<SDValue, 8> Ops;

  for (const SDValue &Op : BV->op_values()) {
    if (Op.isUndef()) {
      Ops.append(NumOutputsPerInput, DAG.getUNDEF(DstEltVT));
      continue;
    }

    APInt OpVal = cast<ConstantSDNode>(Op)->
                  getAPIntValue().zextOrTrunc(SrcBitSize);

    for (unsigned j = 0; j != NumOutputsPerInput; ++j) {
      APInt ThisVal = OpVal.trunc(DstBitSize);
      Ops.push_back(DAG.getConstant(ThisVal, DL, DstEltVT));
      OpVal.lshrInPlace(DstBitSize);
    }

    // For big endian targets, swap the order of the pieces of each element.
    if (DAG.getDataLayout().isBigEndian())
      std::reverse(Ops.end()-NumOutputsPerInput, Ops.end());
  }
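
  // Example: shrinking one i32 constant 0x01020304 to four i8 pieces emits
  // them low part first (04, 03, 02, 01); the reverse above restores memory
  // order (01, 02, 03, 04) for big-endian targets.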

  return DAG.getBuildVector(VT, DL, Ops);
}

static bool isContractable(SDNode *N) {
  SDNodeFlags F = N->getFlags();
  return F.hasAllowContract() || F.hasAllowReassociation();
}

/// Try to perform FMA combining on a given FADD node.
SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  SDLoc SL(N);

  const TargetOptions &Options = DAG.getTarget().Options;

  // Floating-point multiply-add with intermediate rounding.
  bool HasFMAD = (LegalOperations && TLI.isFMADLegalForFAddFSub(DAG, N));

  // Floating-point multiply-add without intermediate rounding.
  bool HasFMA =
      TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
      (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));

  // No valid opcode, do not combine.
  if (!HasFMAD && !HasFMA)
    return SDValue();

  SDNodeFlags Flags = N->getFlags();
  bool CanFuse = Options.UnsafeFPMath || isContractable(N);
  bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
                              CanFuse || HasFMAD);
  // If the addition is not contractable, do not combine.
  if (!AllowFusionGlobally && !isContractable(N))
    return SDValue();

  const SelectionDAGTargetInfo *STI = DAG.getSubtarget().getSelectionDAGInfo();
  if (STI && STI->generateFMAsInMachineCombiner(OptLevel))
    return SDValue();

  // Always prefer FMAD to FMA for precision.
  unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
  bool Aggressive = TLI.enableAggressiveFMAFusion(VT);

  // Is the node an FMUL and contractable either due to global flags or
  // SDNodeFlags.
  auto isContractableFMUL = [AllowFusionGlobally](SDValue N) {
    if (N.getOpcode() != ISD::FMUL)
      return false;
    return AllowFusionGlobally || isContractable(N.getNode());
  };
  // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
  // prefer to fold the multiply with fewer uses.
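  // A multiply with fewer uses is more likely to become dead after the fold,
  // while one with additional users has to stay live alongside the new fused
  // node.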
  if (Aggressive && isContractableFMUL(N0) && isContractableFMUL(N1)) {
    if (N0.getNode()->use_size() > N1.getNode()->use_size())
      std::swap(N0, N1);
  }

  // fold (fadd (fmul x, y), z) -> (fma x, y, z)
  if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) {
    return DAG.getNode(PreferredFusedOpcode, SL, VT,
                       N0.getOperand(0), N0.getOperand(1), N1, Flags);
  }

  // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
  // Note: Commutes FADD operands.
  if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) {
    return DAG.getNode(PreferredFusedOpcode, SL, VT,
                       N1.getOperand(0), N1.getOperand(1), N0, Flags);
  }

  // Look through FP_EXTEND nodes to do more combining.

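  // Fusing across an fpext loses no precision: the extension itself is
  // exact, so performing the multiply-add in the wider type adds no rounding
  // steps beyond what the contract already permits. The target still has to
  // agree via TLI.isFPExtFoldable that the extension is free to fold.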
  // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
  if (N0.getOpcode() == ISD::FP_EXTEND) {
    SDValue N00 = N0.getOperand(0);
    if (isContractableFMUL(N00) &&
        TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                            N00.getValueType())) {
      return DAG.getNode(PreferredFusedOpcode, SL, VT,
                         DAG.getNode(ISD::FP_EXTEND, SL, VT,
                                     N00.getOperand(0)),
                         DAG.getNode(ISD::FP_EXTEND, SL, VT,
                                     N00.getOperand(1)), N1, Flags);
    }
  }

  // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x)
  // Note: Commutes FADD operands.
  if (N1.getOpcode() == ISD::FP_EXTEND) {
    SDValue N10 = N1.getOperand(0);
    if (isContractableFMUL(N10) &&
        TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                            N10.getValueType())) {
      return DAG.getNode(PreferredFusedOpcode, SL, VT,
                         DAG.getNode(ISD::FP_EXTEND, SL, VT,
                                     N10.getOperand(0)),
                         DAG.getNode(ISD::FP_EXTEND, SL, VT,
                                     N10.getOperand(1)), N0, Flags);
    }
  }

  // More folding opportunities when target permits.
  if (Aggressive) {
    // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y, (fma u, v, z))
    if (CanFuse &&
        N0.getOpcode() == PreferredFusedOpcode &&
        N0.getOperand(2).getOpcode() == ISD::FMUL &&
        N0->hasOneUse() && N0.getOperand(2)->hasOneUse()) {
      return DAG.getNode(PreferredFusedOpcode, SL, VT,
                         N0.getOperand(0), N0.getOperand(1),
                         DAG.getNode(PreferredFusedOpcode, SL, VT,
                                     N0.getOperand(2).getOperand(0),
                                     N0.getOperand(2).getOperand(1),
                                     N1, Flags), Flags);
    }

    // fold (fadd x, (fma y, z, (fmul u, v))) -> (fma y, z, (fma u, v, x))
    if (CanFuse &&
        N1->getOpcode() == PreferredFusedOpcode &&
        N1.getOperand(2).getOpcode() == ISD::FMUL &&
        N1->hasOneUse() && N1.getOperand(2)->hasOneUse()) {
      return DAG.getNode(PreferredFusedOpcode, SL, VT,
                         N1.getOperand(0), N1.getOperand(1),
                         DAG.getNode(PreferredFusedOpcode, SL, VT,
                                     N1.getOperand(2).getOperand(0),
                                     N1.getOperand(2).getOperand(1),
                                     N0, Flags), Flags);
    }

    // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
    //   -> (fma x, y, (fma (fpext u), (fpext v), z))
    auto FoldFAddFMAFPExtFMul = [&] (
      SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z,
      SDNodeFlags Flags) {
      return DAG.getNode(PreferredFusedOpcode, SL, VT, X, Y,
                         DAG.getNode(PreferredFusedOpcode, SL, VT,
                                     DAG.getNode(ISD::FP_EXTEND, SL, VT, U),
                                     DAG.getNode(ISD::FP_EXTEND, SL, VT, V),
                                     Z, Flags), Flags);
    };
    if (N0.getOpcode() == PreferredFusedOpcode) {
      SDValue N02 = N0.getOperand(2);
      if (N02.getOpcode() == ISD::FP_EXTEND) {
        SDValue N020 = N02.getOperand(0);
        if (isContractableFMUL(N020) &&
            TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                                N020.getValueType())) {
          return FoldFAddFMAFPExtFMul(N0.getOperand(0), N0.getOperand(1),
                                      N020.getOperand(0), N020.getOperand(1),
                                      N1, Flags);
        }
      }
    }

    // fold (fadd (fpext (fma x, y, (fmul u, v))), z)
    //   -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
    // FIXME: This turns two single-precision and one double-precision
    // operation into two double-precision operations, which might not be
    // interesting for all targets, especially GPUs.
    auto FoldFAddFPExtFMAFMul = [&] (
      SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z,
      SDNodeFlags Flags) {
      return DAG.getNode(PreferredFusedOpcode, SL, VT,
                         DAG.getNode(ISD::FP_EXTEND, SL, VT, X),
                         DAG.getNode(ISD::FP_EXTEND, SL, VT, Y),
                         DAG.getNode(PreferredFusedOpcode, SL, VT,
                                     DAG.getNode(ISD::FP_EXTEND, SL, VT, U),
                                     DAG.getNode(ISD::FP_EXTEND, SL, VT, V),
                                     Z, Flags), Flags);
    };
    if (N0.getOpcode() == ISD::FP_EXTEND) {
      SDValue N00 = N0.getOperand(0);
      if (N00.getOpcode() == PreferredFusedOpcode) {
        SDValue N002 = N00.getOperand(2);
        if (isContractableFMUL(N002) &&
            TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                                N00.getValueType())) {
          return FoldFAddFPExtFMAFMul(N00.getOperand(0), N00.getOperand(1),
                                      N002.getOperand(0), N002.getOperand(1),
                                      N1, Flags);
        }
      }
    }

    // fold (fadd x, (fma y, z, (fpext (fmul u, v))))
    //   -> (fma y, z, (fma (fpext u), (fpext v), x))
    if (N1.getOpcode() == PreferredFusedOpcode) {
      SDValue N12 = N1.getOperand(2);
      if (N12.getOpcode() == ISD::FP_EXTEND) {
        SDValue N120 = N12.getOperand(0);
        if (isContractableFMUL(N120) &&
            TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                                N120.getValueType())) {
          return FoldFAddFMAFPExtFMul(N1.getOperand(0), N1.getOperand(1),
                                      N120.getOperand(0), N120.getOperand(1),
                                      N0, Flags);
        }
      }
    }

    // fold (fadd x, (fpext (fma y, z, (fmul u, v))))
    //   -> (fma (fpext y), (fpext z), (fma (fpext u), (fpext v), x))
    // FIXME: This turns two single-precision and one double-precision
    // operation into two double-precision operations, which might not be
    // interesting for all targets, especially GPUs.
    if (N1.getOpcode() == ISD::FP_EXTEND) {
      SDValue N10 = N1.getOperand(0);
      if (N10.getOpcode() == PreferredFusedOpcode) {
        SDValue N102 = N10.getOperand(2);
        if (isContractableFMUL(N102) &&
            TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                                N10.getValueType())) {
          return FoldFAddFPExtFMAFMul(N10.getOperand(0), N10.getOperand(1),
                                      N102.getOperand(0), N102.getOperand(1),
                                      N0, Flags);
        }
      }
    }
  }

  return SDValue();
}

/// Try to perform FMA combining on a given FSUB node.
SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  SDLoc SL(N);

  const TargetOptions &Options = DAG.getTarget().Options;
  // Floating-point multiply-add with intermediate rounding.
  bool HasFMAD = (LegalOperations && TLI.isFMADLegalForFAddFSub(DAG, N));

  // Floating-point multiply-add without intermediate rounding.
  bool HasFMA =
      TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
      (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));

  // No valid opcode, do not combine.
  if (!HasFMAD && !HasFMA)
    return SDValue();

  const SDNodeFlags Flags = N->getFlags();
  bool CanFuse = Options.UnsafeFPMath || isContractable(N);
  bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
                              CanFuse || HasFMAD);

  // If the subtraction is not contractable, do not combine.
  if (!AllowFusionGlobally && !isContractable(N))
    return SDValue();

  const SelectionDAGTargetInfo *STI = DAG.getSubtarget().getSelectionDAGInfo();
  if (STI && STI->generateFMAsInMachineCombiner(OptLevel))
    return SDValue();

  // Always prefer FMAD to FMA for precision.
  unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
  bool Aggressive = TLI.enableAggressiveFMAFusion(VT);

  // Is the node an FMUL and contractable either due to global flags or
  // SDNodeFlags.
  auto isContractableFMUL = [AllowFusionGlobally](SDValue N) {
    if (N.getOpcode() != ISD::FMUL)
      return false;
    return AllowFusionGlobally || isContractable(N.getNode());
  };

  // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
  if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) {
    return DAG.getNode(PreferredFusedOpcode, SL, VT,
                       N0.getOperand(0), N0.getOperand(1),
                       DAG.getNode(ISD::FNEG, SL, VT, N1), Flags);
  }

  // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
  // Note: Commutes FSUB operands.
  if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) {
    return DAG.getNode(PreferredFusedOpcode, SL, VT,
                       DAG.getNode(ISD::FNEG, SL, VT,
                                   N1.getOperand(0)),
                       N1.getOperand(1), N0, Flags);
  }

  // fold (fsub (fneg (fmul x, y)), z) -> (fma (fneg x), y, (fneg z))
  if (N0.getOpcode() == ISD::FNEG && isContractableFMUL(N0.getOperand(0)) &&
      (Aggressive || (N0->hasOneUse() && N0.getOperand(0).hasOneUse()))) {
    SDValue N00 = N0.getOperand(0).getOperand(0);
    SDValue N01 = N0.getOperand(0).getOperand(1);
    return DAG.getNode(PreferredFusedOpcode, SL, VT,
                       DAG.getNode(ISD::FNEG, SL, VT, N00), N01,
                       DAG.getNode(ISD::FNEG, SL, VT, N1), Flags);
  }

  // Look through FP_EXTEND nodes to do more combining.

  // fold (fsub (fpext (fmul x, y)), z)
  //   -> (fma (fpext x), (fpext y), (fneg z))
  if (N0.getOpcode() == ISD::FP_EXTEND) {
    SDValue N00 = N0.getOperand(0);
    if (isContractableFMUL(N00) &&
        TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                            N00.getValueType())) {
      return DAG.getNode(PreferredFusedOpcode, SL, VT,
                         DAG.getNode(ISD::FP_EXTEND, SL, VT,
                                     N00.getOperand(0)),
                         DAG.getNode(ISD::FP_EXTEND, SL, VT,
                                     N00.getOperand(1)),
                         DAG.getNode(ISD::FNEG, SL, VT, N1), Flags);
    }
  }

  // fold (fsub x, (fpext (fmul y, z)))
  //   -> (fma (fneg (fpext y)), (fpext z), x)
  // Note: Commutes FSUB operands.
  if (N1.getOpcode() == ISD::FP_EXTEND) {
    SDValue N10 = N1.getOperand(0);
    if (isContractableFMUL(N10) &&
        TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                            N10.getValueType())) {
      return DAG.getNode(PreferredFusedOpcode, SL, VT,
                         DAG.getNode(ISD::FNEG, SL, VT,
                                     DAG.getNode(ISD::FP_EXTEND, SL, VT,
                                                 N10.getOperand(0))),
                         DAG.getNode(ISD::FP_EXTEND, SL, VT,
                                     N10.getOperand(1)),
                         N0, Flags);
    }
  }

  // fold (fsub (fpext (fneg (fmul x, y))), z)
  //   -> (fneg (fma (fpext x), (fpext y), z))
  // Note: This could be removed with appropriate canonicalization of the
  // input expression into (fneg (fadd (fpext (fmul x, y)), z)). However, the
  // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent
  // us from implementing the canonicalization in visitFSUB.
  if (N0.getOpcode() == ISD::FP_EXTEND) {
    SDValue N00 = N0.getOperand(0);
    if (N00.getOpcode() == ISD::FNEG) {
      SDValue N000 = N00.getOperand(0);
      if (isContractableFMUL(N000) &&
          TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                              N00.getValueType())) {
        return DAG.getNode(ISD::FNEG, SL, VT,
                           DAG.getNode(PreferredFusedOpcode, SL, VT,
                                       DAG.getNode(ISD::FP_EXTEND, SL, VT,
                                                   N000.getOperand(0)),
                                       DAG.getNode(ISD::FP_EXTEND, SL, VT,
                                                   N000.getOperand(1)),
                                       N1, Flags));
      }
    }
  }

  // fold (fsub (fneg (fpext (fmul x, y))), z)
  //   -> (fneg (fma (fpext x), (fpext y), z))
  // Note: This could be removed with appropriate canonicalization of the
  // input expression into (fneg (fadd (fpext (fmul x, y)), z)). However, the
  // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent
  // us from implementing the canonicalization in visitFSUB.
  if (N0.getOpcode() == ISD::FNEG) {
    SDValue N00 = N0.getOperand(0);
    if (N00.getOpcode() == ISD::FP_EXTEND) {
      SDValue N000 = N00.getOperand(0);
      if (isContractableFMUL(N000) &&
          TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                              N000.getValueType())) {
        return DAG.getNode(ISD::FNEG, SL, VT,
                           DAG.getNode(PreferredFusedOpcode, SL, VT,
                                       DAG.getNode(ISD::FP_EXTEND, SL, VT,
                                                   N000.getOperand(0)),
                                       DAG.getNode(ISD::FP_EXTEND, SL, VT,
                                                   N000.getOperand(1)),
                                       N1, Flags));
      }
    }
  }

  // More folding opportunities when target permits.
  if (Aggressive) {
    // fold (fsub (fma x, y, (fmul u, v)), z)
    //   -> (fma x, y, (fma u, v, (fneg z)))
    if (CanFuse && N0.getOpcode() == PreferredFusedOpcode &&
        isContractableFMUL(N0.getOperand(2)) && N0->hasOneUse() &&
        N0.getOperand(2)->hasOneUse()) {
      return DAG.getNode(PreferredFusedOpcode, SL, VT,
                         N0.getOperand(0), N0.getOperand(1),
                         DAG.getNode(PreferredFusedOpcode, SL, VT,
                                     N0.getOperand(2).getOperand(0),
                                     N0.getOperand(2).getOperand(1),
                                     DAG.getNode(ISD::FNEG, SL, VT,
                                                 N1), Flags), Flags);
    }

    // fold (fsub x, (fma y, z, (fmul u, v)))
    //   -> (fma (fneg y), z, (fma (fneg u), v, x))
    if (CanFuse && N1.getOpcode() == PreferredFusedOpcode &&
        isContractableFMUL(N1.getOperand(2)) &&
        N1->hasOneUse()) {
      SDValue N20 = N1.getOperand(2).getOperand(0);
      SDValue N21 = N1.getOperand(2).getOperand(1);
      return DAG.getNode(PreferredFusedOpcode, SL, VT,
                         DAG.getNode(ISD::FNEG, SL, VT,
                                     N1.getOperand(0)),
                         N1.getOperand(1),
                         DAG.getNode(PreferredFusedOpcode, SL, VT,
                                     DAG.getNode(ISD::FNEG, SL, VT, N20),
                                     N21, N0, Flags), Flags);
    }

    // fold (fsub (fma x, y, (fpext (fmul u, v))), z)
    //   -> (fma x, y, (fma (fpext u), (fpext v), (fneg z)))
    if (N0.getOpcode() == PreferredFusedOpcode &&
        N0->hasOneUse()) {
      SDValue N02 = N0.getOperand(2);
      if (N02.getOpcode() == ISD::FP_EXTEND) {
        SDValue N020 = N02.getOperand(0);
        if (isContractableFMUL(N020) &&
            TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                                N020.getValueType())) {
          return DAG.getNode(PreferredFusedOpcode, SL, VT,
                             N0.getOperand(0), N0.getOperand(1),
                             DAG.getNode(PreferredFusedOpcode, SL, VT,
                                         DAG.getNode(ISD::FP_EXTEND, SL, VT,
                                                     N020.getOperand(0)),
                                         DAG.getNode(ISD::FP_EXTEND, SL, VT,
                                                     N020.getOperand(1)),
                                         DAG.getNode(ISD::FNEG, SL, VT,
                                                     N1), Flags), Flags);
        }
      }
    }

    // fold (fsub (fpext (fma x, y, (fmul u, v))), z)
    //   -> (fma (fpext x), (fpext y),
    //           (fma (fpext u), (fpext v), (fneg z)))
    // FIXME: This turns two single-precision and one double-precision
    // operation into two double-precision operations, which might not be
    // interesting for all targets, especially GPUs.
    if (N0.getOpcode() == ISD::FP_EXTEND) {
      SDValue N00 = N0.getOperand(0);
      if (N00.getOpcode() == PreferredFusedOpcode) {
        SDValue N002 = N00.getOperand(2);
        if (isContractableFMUL(N002) &&
            TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                                N00.getValueType())) {
          return DAG.getNode(PreferredFusedOpcode, SL, VT,
                             DAG.getNode(ISD::FP_EXTEND, SL, VT,
                                         N00.getOperand(0)),
                             DAG.getNode(ISD::FP_EXTEND, SL, VT,
                                         N00.getOperand(1)),
                             DAG.getNode(PreferredFusedOpcode, SL, VT,
                                         DAG.getNode(ISD::FP_EXTEND, SL, VT,
                                                     N002.getOperand(0)),
                                         DAG.getNode(ISD::FP_EXTEND, SL, VT,
                                                     N002.getOperand(1)),
                                         DAG.getNode(ISD::FNEG, SL, VT,
                                                     N1), Flags), Flags);
        }
      }
    }

    // fold (fsub x, (fma y, z, (fpext (fmul u, v))))
    //   -> (fma (fneg y), z, (fma (fneg (fpext u)), (fpext v), x))
    if (N1.getOpcode() == PreferredFusedOpcode &&
        N1.getOperand(2).getOpcode() == ISD::FP_EXTEND &&
        N1->hasOneUse()) {
      SDValue N120 = N1.getOperand(2).getOperand(0);
      if (isContractableFMUL(N120) &&
          TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                              N120.getValueType())) {
        SDValue N1200 = N120.getOperand(0);
        SDValue N1201 = N120.getOperand(1);
        return DAG.getNode(PreferredFusedOpcode, SL, VT,
                           DAG.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)),
                           N1.getOperand(1),
                           DAG.getNode(PreferredFusedOpcode, SL, VT,
                                       DAG.getNode(ISD::FNEG, SL, VT,
                                                   DAG.getNode(ISD::FP_EXTEND, SL,
                                                               VT, N1200)),
                                       DAG.getNode(ISD::FP_EXTEND, SL, VT,
                                                   N1201),
                                       N0, Flags), Flags);
      }
    }

    // fold (fsub x, (fpext (fma y, z, (fmul u, v))))
    //   -> (fma (fneg (fpext y)), (fpext z),
    //           (fma (fneg (fpext u)), (fpext v), x))
    // FIXME: This turns two single-precision and one double-precision
    // operation into two double-precision operations, which might not be
    // interesting for all targets, especially GPUs.
    if (N1.getOpcode() == ISD::FP_EXTEND &&
        N1.getOperand(0).getOpcode() == PreferredFusedOpcode) {
      SDValue CvtSrc = N1.getOperand(0);
      SDValue N100 = CvtSrc.getOperand(0);
      SDValue N101 = CvtSrc.getOperand(1);
      SDValue N102 = CvtSrc.getOperand(2);
      if (isContractableFMUL(N102) &&
          TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                              CvtSrc.getValueType())) {
        SDValue N1020 = N102.getOperand(0);
        SDValue N1021 = N102.getOperand(1);
        return DAG.getNode(PreferredFusedOpcode, SL, VT,
                           DAG.getNode(ISD::FNEG, SL, VT,
                                       DAG.getNode(ISD::FP_EXTEND, SL, VT,
                                                   N100)),
                           DAG.getNode(ISD::FP_EXTEND, SL, VT, N101),
                           DAG.getNode(PreferredFusedOpcode, SL, VT,
                                       DAG.getNode(ISD::FNEG, SL, VT,
                                                   DAG.getNode(ISD::FP_EXTEND, SL,
                                                               VT, N1020)),
                                       DAG.getNode(ISD::FP_EXTEND, SL, VT,
                                                   N1021),
                                       N0, Flags), Flags);
      }
    }
  }

  return SDValue();
}

/// Try to perform FMA combining on a given FMUL node based on the distributive
/// law x * (y + 1) = x * y + x and variants thereof (commuted versions,
/// subtraction instead of addition).
SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  SDLoc SL(N);
  const SDNodeFlags Flags = N->getFlags();

  assert(N->getOpcode() == ISD::FMUL && "Expected FMUL Operation");

  const TargetOptions &Options = DAG.getTarget().Options;

  // The transforms below are incorrect when x == 0 and y == inf, because the
  // intermediate multiplication produces a NaN.
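  // For example, fusing (x + 1) * y into fma(x, y, y) with x == 0 and
  // y == inf turns the well-defined result inf into 0 * inf + inf == NaN.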
  if (!Options.NoInfsFPMath)
    return SDValue();

  // Floating-point multiply-add without intermediate rounding.
  bool HasFMA =
      (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath) &&
      TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
      (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));

  // Floating-point multiply-add with intermediate rounding. This can result
  // in a less precise result due to the changed rounding order.
  bool HasFMAD = Options.UnsafeFPMath &&
                 (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT));

  // No valid opcode, do not combine.
  if (!HasFMAD && !HasFMA)
    return SDValue();

  // Always prefer FMAD to FMA for precision.
  unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
  bool Aggressive = TLI.enableAggressiveFMAFusion(VT);

  // fold (fmul (fadd x0, +1.0), y) -> (fma x0, y, y)
  // fold (fmul (fadd x0, -1.0), y) -> (fma x0, y, (fneg y))
  auto FuseFADD = [&](SDValue X, SDValue Y, const SDNodeFlags Flags) {
    if (X.getOpcode() == ISD::FADD && (Aggressive || X->hasOneUse())) {
      if (auto *C = isConstOrConstSplatFP(X.getOperand(1), true)) {
        if (C->isExactlyValue(+1.0))
          return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
                             Y, Flags);
        if (C->isExactlyValue(-1.0))
          return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
                             DAG.getNode(ISD::FNEG, SL, VT, Y), Flags);
      }
    }
    return SDValue();
  };

  if (SDValue FMA = FuseFADD(N0, N1, Flags))
    return FMA;
  if (SDValue FMA = FuseFADD(N1, N0, Flags))
    return FMA;

  // fold (fmul (fsub +1.0, x1), y) -> (fma (fneg x1), y, y)
  // fold (fmul (fsub -1.0, x1), y) -> (fma (fneg x1), y, (fneg y))
  // fold (fmul (fsub x0, +1.0), y) -> (fma x0, y, (fneg y))
  // fold (fmul (fsub x0, -1.0), y) -> (fma x0, y, y)
  auto FuseFSUB = [&](SDValue X, SDValue Y, const SDNodeFlags Flags) {
    if (X.getOpcode() == ISD::FSUB && (Aggressive || X->hasOneUse())) {
      if (auto *C0 = isConstOrConstSplatFP(X.getOperand(0), true)) {
        if (C0->isExactlyValue(+1.0))
          return DAG.getNode(PreferredFusedOpcode, SL, VT,
                             DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
                             Y, Flags);
        if (C0->isExactlyValue(-1.0))
          return DAG.getNode(PreferredFusedOpcode, SL, VT,
                             DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
                             DAG.getNode(ISD::FNEG, SL, VT, Y), Flags);
      }
      if (auto *C1 = isConstOrConstSplatFP(X.getOperand(1), true)) {
        if (C1->isExactlyValue(+1.0))
          return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
                             DAG.getNode(ISD::FNEG, SL, VT, Y), Flags);
        if (C1->isExactlyValue(-1.0))
          return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
                             Y, Flags);
      }
    }
    return SDValue();
  };

  if (SDValue FMA = FuseFSUB(N0, N1, Flags))
    return FMA;
  if (SDValue FMA = FuseFSUB(N1, N0, Flags))
    return FMA;

  return SDValue();
}

SDValue DAGCombiner::visitFADD(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  bool N0CFP = isConstantFPBuildVectorOrConstantFP(N0);
  bool N1CFP = isConstantFPBuildVectorOrConstantFP(N1);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);
  const TargetOptions &Options = DAG.getTarget().Options;
  const SDNodeFlags Flags = N->getFlags();

  // fold vector ops
  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N))
      return FoldedVOp;

  // fold (fadd c1, c2) -> c1 + c2
  if (N0CFP && N1CFP)
    return DAG.getNode(ISD::FADD, DL, VT, N0, N1, Flags);

  // canonicalize constant to RHS
  if (N0CFP && !N1CFP)
    return DAG.getNode(ISD::FADD, DL, VT, N1, N0, Flags);

  // N0 + -0.0 --> N0 (also allowed with +0.0 and fast-math)
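  // Adding -0.0 is an identity for every value under the default rounding
  // mode (+0.0 + -0.0 == +0.0 and -0.0 + -0.0 == -0.0), while x + (+0.0)
  // turns x == -0.0 into +0.0, so the +0.0 form additionally needs nsz.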
  ConstantFPSDNode *N1C = isConstOrConstSplatFP(N1, true);
  if (N1C && N1C->isZero())
    if (N1C->isNegative() || Options.NoSignedZerosFPMath ||
        Flags.hasNoSignedZeros())
      return N0;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // fold (fadd A, (fneg B)) -> (fsub A, B)
  if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT)) &&
      TLI.isNegatibleForFree(N1, DAG, LegalOperations, ForCodeSize) == 2)
    return DAG.getNode(
        ISD::FSUB, DL, VT, N0,
        TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize), Flags);

  // fold (fadd (fneg A), B) -> (fsub B, A)
  if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT)) &&
      TLI.isNegatibleForFree(N0, DAG, LegalOperations, ForCodeSize) == 2)
    return DAG.getNode(
        ISD::FSUB, DL, VT, N1,
        TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize), Flags);

  auto isFMulNegTwo = [](SDValue FMul) {
    if (!FMul.hasOneUse() || FMul.getOpcode() != ISD::FMUL)
      return false;
    auto *C = isConstOrConstSplatFP(FMul.getOperand(1), true);
    return C && C->isExactlyValue(-2.0);
  };

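  // Rewriting B * -2.0 + A as A - (B + B) removes an FP constant and trades
  // the multiply for an add, which is usually at least as cheap.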
  // fadd (fmul B, -2.0), A --> fsub A, (fadd B, B)
  if (isFMulNegTwo(N0)) {
    SDValue B = N0.getOperand(0);
    SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B, Flags);
    return DAG.getNode(ISD::FSUB, DL, VT, N1, Add, Flags);
  }
  // fadd A, (fmul B, -2.0) --> fsub A, (fadd B, B)
  if (isFMulNegTwo(N1)) {
    SDValue B = N1.getOperand(0);
    SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B, Flags);
    return DAG.getNode(ISD::FSUB, DL, VT, N0, Add, Flags);
  }

  // No FP constant should be created after legalization as the Instruction
  // Selection pass has a hard time dealing with FP constants.
  bool AllowNewConst = (Level < AfterLegalizeDAG);

  // If nnan is enabled, fold lots of things.
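  // Without nnan, (fadd x, (fneg x)) is NaN rather than 0.0 whenever x is
  // NaN or an infinity, so these folds would be invalid.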
  if ((Options.NoNaNsFPMath || Flags.hasNoNaNs()) && AllowNewConst) {
    // If allowed, fold (fadd (fneg x), x) -> 0.0
    if (N0.getOpcode() == ISD::FNEG && N0.getOperand(0) == N1)
      return DAG.getConstantFP(0.0, DL, VT);

    // If allowed, fold (fadd x, (fneg x)) -> 0.0
    if (N1.getOpcode() == ISD::FNEG && N1.getOperand(0) == N0)
      return DAG.getConstantFP(0.0, DL, VT);
  }

  // If 'unsafe math' or reassoc and nsz, fold lots of things.
  // TODO: break out portions of the transformations below for which Unsafe is
  //       considered and which do not require both nsz and reassoc
  if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
       (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
      AllowNewConst) {
    // fadd (fadd x, c1), c2 -> fadd x, c1 + c2
    if (N1CFP && N0.getOpcode() == ISD::FADD &&
        isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) {
      SDValue NewC = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1), N1, Flags);
      return DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(0), NewC, Flags);
    }

    // We can fold chains of FADD's of the same value into multiplications.
    // This transform is not safe in general because we are reducing the number
    // of rounding steps.
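    // For example, folding (x * c) + x into x * (c + 1.0) replaces two
    // roundings (multiply then add) with one, so the results can differ in
    // the last ulp.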
    if (TLI.isOperationLegalOrCustom(ISD::FMUL, VT) && !N0CFP && !N1CFP) {
      if (N0.getOpcode() == ISD::FMUL) {
        bool CFP00 = isConstantFPBuildVectorOrConstantFP(N0.getOperand(0));
        bool CFP01 = isConstantFPBuildVectorOrConstantFP(N0.getOperand(1));

        // (fadd (fmul x, c), x) -> (fmul x, c+1)
        if (CFP01 && !CFP00 && N0.getOperand(0) == N1) {
          SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1),
                                       DAG.getConstantFP(1.0, DL, VT), Flags);
          return DAG.getNode(ISD::FMUL, DL, VT, N1, NewCFP, Flags);
        }

        // (fadd (fmul x, c), (fadd x, x)) -> (fmul x, c+2)
        if (CFP01 && !CFP00 && N1.getOpcode() == ISD::FADD &&
            N1.getOperand(0) == N1.getOperand(1) &&
            N0.getOperand(0) == N1.getOperand(0)) {
          SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1),
                                       DAG.getConstantFP(2.0, DL, VT), Flags);
          return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), NewCFP, Flags);
        }
      }

      if (N1.getOpcode() == ISD::FMUL) {
        bool CFP10 = isConstantFPBuildVectorOrConstantFP(N1.getOperand(0));
        bool CFP11 = isConstantFPBuildVectorOrConstantFP(N1.getOperand(1));

        // (fadd x, (fmul x, c)) -> (fmul x, c+1)
        if (CFP11 && !CFP10 && N1.getOperand(0) == N0) {
          SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N1.getOperand(1),
                                       DAG.getConstantFP(1.0, DL, VT), Flags);
          return DAG.getNode(ISD::FMUL, DL, VT, N0, NewCFP, Flags);
        }

        // (fadd (fadd x, x), (fmul x, c)) -> (fmul x, c+2)
        if (CFP11 && !CFP10 && N0.getOpcode() == ISD::FADD &&
            N0.getOperand(0) == N0.getOperand(1) &&
            N1.getOperand(0) == N0.getOperand(0)) {
          SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N1.getOperand(1),
                                       DAG.getConstantFP(2.0, DL, VT), Flags);
          return DAG.getNode(ISD::FMUL, DL, VT, N1.getOperand(0), NewCFP, Flags);
        }
      }

      if (N0.getOpcode() == ISD::FADD) {
        bool CFP00 = isConstantFPBuildVectorOrConstantFP(N0.getOperand(0));
        // (fadd (fadd x, x), x) -> (fmul x, 3.0)
        if (!CFP00 && N0.getOperand(0) == N0.getOperand(1) &&
            (N0.getOperand(0) == N1)) {
          return DAG.getNode(ISD::FMUL, DL, VT,
                             N1, DAG.getConstantFP(3.0, DL, VT), Flags);
        }
      }

      if (N1.getOpcode() == ISD::FADD) {
        bool CFP10 = isConstantFPBuildVectorOrConstantFP(N1.getOperand(0));
        // (fadd x, (fadd x, x)) -> (fmul x, 3.0)
        if (!CFP10 && N1.getOperand(0) == N1.getOperand(1) &&
            N1.getOperand(0) == N0) {
          return DAG.getNode(ISD::FMUL, DL, VT,
                             N0, DAG.getConstantFP(3.0, DL, VT), Flags);
        }
      }

      // (fadd (fadd x, x), (fadd x, x)) -> (fmul x, 4.0)
      if (N0.getOpcode() == ISD::FADD && N1.getOpcode() == ISD::FADD &&
          N0.getOperand(0) == N0.getOperand(1) &&
          N1.getOperand(0) == N1.getOperand(1) &&
          N0.getOperand(0) == N1.getOperand(0)) {
        return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0),
                           DAG.getConstantFP(4.0, DL, VT), Flags);
      }
    }
  } // enable-unsafe-fp-math

  // FADD -> FMA combines:
  if (SDValue Fused = visitFADDForFMACombine(N)) {
    AddToWorklist(Fused.getNode());
    return Fused;
  }
  return SDValue();
}

SDValue DAGCombiner::visitFSUB(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, true);
  ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);
  const TargetOptions &Options = DAG.getTarget().Options;
  const SDNodeFlags Flags = N->getFlags();

  // fold vector ops
  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N))
      return FoldedVOp;

  // fold (fsub c1, c2) -> c1-c2
  if (N0CFP && N1CFP)
    return DAG.getNode(ISD::FSUB, DL, VT, N0, N1, Flags);

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // (fsub A, 0) -> A
  if (N1CFP && N1CFP->isZero()) {
    if (!N1CFP->isNegative() || Options.NoSignedZerosFPMath ||
        Flags.hasNoSignedZeros()) {
      return N0;
    }
  }

  if (N0 == N1) {
    // (fsub x, x) -> 0.0
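    // Requires nnan: when x is NaN or an infinity, x - x is NaN, not 0.0.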
    if (Options.NoNaNsFPMath || Flags.hasNoNaNs())
      return DAG.getConstantFP(0.0f, DL, VT);
  }

  // (fsub -0.0, N1) -> -N1
  // NOTE: It is safe to transform an FSUB(-0.0,X) into an FNEG(X), since the
  //       FSUB does not specify the sign bit of a NaN. Also note that for
  //       the same reason, the inverse transform is not safe, unless fast math
  //       flags are in play.
  if (N0CFP && N0CFP->isZero()) {
    if (N0CFP->isNegative() ||
        (Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())) {
      if (TLI.isNegatibleForFree(N1, DAG, LegalOperations, ForCodeSize))
        return TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize);
      if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))
        return DAG.getNode(ISD::FNEG, DL, VT, N1, Flags);
    }
  }

  if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
       (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
      N1.getOpcode() == ISD::FADD) {
    // X - (X + Y) -> -Y
    if (N0 == N1->getOperand(0))
      return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(1), Flags);
    // X - (Y + X) -> -Y
    if (N0 == N1->getOperand(1))
      return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(0), Flags);
  }

  // fold (fsub A, (fneg B)) -> (fadd A, B)
  if (TLI.isNegatibleForFree(N1, DAG, LegalOperations, ForCodeSize))
    return DAG.getNode(
        ISD::FADD, DL, VT, N0,
        TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize), Flags);

  // FSUB -> FMA combines:
  if (SDValue Fused = visitFSUBForFMACombine(N)) {
    AddToWorklist(Fused.getNode());
    return Fused;
  }

  return SDValue();
}

/// Return true if both inputs are at least as cheap in negated form and at
/// least one input is strictly cheaper in negated form.
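/// (TLI.isNegatibleForFree returns 2 when the negated form is strictly
/// cheaper and 1 when it costs the same, hence the == 2 test below.)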
bool DAGCombiner::isCheaperToUseNegatedFPOps(SDValue X, SDValue Y) {
  if (char LHSNeg =
          TLI.isNegatibleForFree(X, DAG, LegalOperations, ForCodeSize))
    if (char RHSNeg =
            TLI.isNegatibleForFree(Y, DAG, LegalOperations, ForCodeSize))
      // Both negated operands are at least as cheap as their counterparts.
      // Check to see if at least one is cheaper negated.
      if (LHSNeg == 2 || RHSNeg == 2)
        return true;

  return false;
}

SDValue DAGCombiner::visitFMUL(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, true);
  ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);
  const TargetOptions &Options = DAG.getTarget().Options;
  const SDNodeFlags Flags = N->getFlags();

  // fold vector ops
  if (VT.isVector()) {
    // This just handles C1 * C2 for vectors. Other vector folds are below.
    if (SDValue FoldedVOp = SimplifyVBinOp(N))
      return FoldedVOp;
  }

  // fold (fmul c1, c2) -> c1*c2
  if (N0CFP && N1CFP)
    return DAG.getNode(ISD::FMUL, DL, VT, N0, N1, Flags);

  // canonicalize constant to RHS
  if (isConstantFPBuildVectorOrConstantFP(N0) &&
     !isConstantFPBuildVectorOrConstantFP(N1))
    return DAG.getNode(ISD::FMUL, DL, VT, N1, N0, Flags);

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  if ((Options.NoNaNsFPMath && Options.NoSignedZerosFPMath) ||
      (Flags.hasNoNaNs() && Flags.hasNoSignedZeros())) {
    // fold (fmul A, 0) -> 0
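    // (Needs nnan because inf * 0 is NaN, and nsz because -1.0 * 0.0 is
    // -0.0 rather than 0.0.)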
    if (N1CFP && N1CFP->isZero())
      return N1;
  }

  if (Options.UnsafeFPMath || Flags.hasAllowReassociation()) {
    // fmul (fmul X, C1), C2 -> fmul X, C1 * C2
    if (isConstantFPBuildVectorOrConstantFP(N1) &&
        N0.getOpcode() == ISD::FMUL) {
      SDValue N00 = N0.getOperand(0);
      SDValue N01 = N0.getOperand(1);
      // Avoid an infinite loop by making sure that N00 is not a constant
      // (the inner multiply has not been constant folded yet).
      if (isConstantFPBuildVectorOrConstantFP(N01) &&
          !isConstantFPBuildVectorOrConstantFP(N00)) {
        SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, N01, N1, Flags);
        return DAG.getNode(ISD::FMUL, DL, VT, N00, MulConsts, Flags);
      }
    }

    // Match a special-case: we convert X * 2.0 into fadd.
    // fmul (fadd X, X), C -> fmul X, 2.0 * C
    if (N0.getOpcode() == ISD::FADD && N0.hasOneUse() &&
        N0.getOperand(0) == N0.getOperand(1)) {
      const SDValue Two = DAG.getConstantFP(2.0, DL, VT);
      SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, Two, N1, Flags);
      return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), MulConsts, Flags);
    }
  }

  // fold (fmul X, 2.0) -> (fadd X, X)
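  // This is exact for all inputs, including NaN, infinities and signed
  // zeros, so no fast-math flags are required.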
  if (N1CFP && N1CFP->isExactlyValue(+2.0))
    return DAG.getNode(ISD::FADD, DL, VT, N0, N0, Flags);

  // fold (fmul X, -1.0) -> (fneg X)
  if (N1CFP && N1CFP->isExactlyValue(-1.0))
    if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))
      return DAG.getNode(ISD::FNEG, DL, VT, N0);

  // -N0 * -N1 --> N0 * N1
  if (isCheaperToUseNegatedFPOps(N0, N1)) {
    SDValue NegN0 =
        TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize);
    SDValue NegN1 =
        TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize);
    return DAG.getNode(ISD::FMUL, DL, VT, NegN0, NegN1, Flags);
  }

  // fold (fmul X, (select (fcmp X > 0.0), -1.0, 1.0)) -> (fneg (fabs X))
  // fold (fmul X, (select (fcmp X > 0.0), 1.0, -1.0)) -> (fabs X)
  if (Flags.hasNoNaNs() && Flags.hasNoSignedZeros() &&
      (N0.getOpcode() == ISD::SELECT || N1.getOpcode() == ISD::SELECT) &&
      TLI.isOperationLegal(ISD::FABS, VT)) {
    SDValue Select = N0, X = N1;
    if (Select.getOpcode() != ISD::SELECT)
      std::swap(Select, X);

    SDValue Cond = Select.getOperand(0);
    auto TrueOpnd  = dyn_cast<ConstantFPSDNode>(Select.getOperand(1));
    auto FalseOpnd = dyn_cast<ConstantFPSDNode>(Select.getOperand(2));

    if (TrueOpnd && FalseOpnd &&
        Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == X &&
        isa<ConstantFPSDNode>(Cond.getOperand(1)) &&
        cast<ConstantFPSDNode>(Cond.getOperand(1))->isExactlyValue(0.0)) {
      ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
      switch (CC) {
      default: break;
      case ISD::SETOLT:
      case ISD::SETULT:
      case ISD::SETOLE:
      case ISD::SETULE:
      case ISD::SETLT:
      case ISD::SETLE:
        std::swap(TrueOpnd, FalseOpnd);
        LLVM_FALLTHROUGH;
      case ISD::SETOGT:
      case ISD::SETUGT:
      case ISD::SETOGE:
      case ISD::SETUGE:
      case ISD::SETGT:
      case ISD::SETGE:
        if (TrueOpnd->isExactlyValue(-1.0) && FalseOpnd->isExactlyValue(1.0) &&
            TLI.isOperationLegal(ISD::FNEG, VT))
          return DAG.getNode(ISD::FNEG, DL, VT,
                   DAG.getNode(ISD::FABS, DL, VT, X));
        if (TrueOpnd->isExactlyValue(1.0) && FalseOpnd->isExactlyValue(-1.0))
          return DAG.getNode(ISD::FABS, DL, VT, X);

        break;
      }
    }
  }

  // FMUL -> FMA combines:
  if (SDValue Fused = visitFMULForFMADistributiveCombine(N)) {
    AddToWorklist(Fused.getNode());
    return Fused;
  }

  return SDValue();
}

SDValue DAGCombiner::visitFMA(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue N2 = N->getOperand(2);
  ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
  ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);
  const TargetOptions &Options = DAG.getTarget().Options;

  // FMA nodes have flags that propagate to the created nodes.
  const SDNodeFlags Flags = N->getFlags();
  bool UnsafeFPMath = Options.UnsafeFPMath || isContractable(N);

  // Constant fold FMA.
  if (isa<ConstantFPSDNode>(N0) &&
      isa<ConstantFPSDNode>(N1) &&
      isa<ConstantFPSDNode>(N2)) {
    return DAG.getNode(ISD::FMA, DL, VT, N0, N1, N2);
  }

  // (-N0 * -N1) + N2 --> (N0 * N1) + N2
  if (isCheaperToUseNegatedFPOps(N0, N1)) {
    SDValue NegN0 =
        TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize);
    SDValue NegN1 =
        TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize);
    return DAG.getNode(ISD::FMA, DL, VT, NegN0, NegN1, N2, Flags);
  }

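  // Folding fma(0, y, z) to z is unsafe in general: with y == inf the
  // product is NaN (so the result is NaN, not z), and with z == -0.0 the
  // sum +0.0 + -0.0 is +0.0, not -0.0.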
  if (UnsafeFPMath) {
    if (N0CFP && N0CFP->isZero())
      return N2;
    if (N1CFP && N1CFP->isZero())
      return N2;
  }
  // TODO: The FMA node should have flags that propagate to these nodes.
  if (N0CFP && N0CFP->isExactlyValue(1.0))
    return DAG.getNode(ISD::FADD, SDLoc(N), VT, N1, N2);
  if (N1CFP && N1CFP->isExactlyValue(1.0))
    return DAG.getNode(ISD::FADD, SDLoc(N), VT, N0, N2);

  // Canonicalize (fma c, x, y) -> (fma x, c, y)
  if (isConstantFPBuildVectorOrConstantFP(N0) &&
     !isConstantFPBuildVectorOrConstantFP(N1))
    return DAG.getNode(ISD::FMA, SDLoc(N), VT, N1, N0, N2);

  if (UnsafeFPMath) {
    // (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2)
    if (N2.getOpcode() == ISD::FMUL && N0 == N2.getOperand(0) &&
        isConstantFPBuildVectorOrConstantFP(N1) &&
        isConstantFPBuildVectorOrConstantFP(N2.getOperand(1))) {
      return DAG.getNode(ISD::FMUL, DL, VT, N0,
                         DAG.getNode(ISD::FADD, DL, VT, N1, N2.getOperand(1),
                                     Flags), Flags);
    }

    // (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y)
    if (N0.getOpcode() == ISD::FMUL &&
        isConstantFPBuildVectorOrConstantFP(N1) &&
        isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) {
      return DAG.getNode(ISD::FMA, DL, VT,
                         N0.getOperand(0),
                         DAG.getNode(ISD::FMUL, DL, VT, N1, N0.getOperand(1),
                                     Flags),
                         N2);
    }
  }

  // (fma x, 1, y) -> (fadd x, y)
  // (fma x, -1, y) -> (fadd (fneg x), y)
  if (N1CFP) {
    if (N1CFP->isExactlyValue(1.0))
      // TODO: The FMA node should have flags that propagate to this node.
      return DAG.getNode(ISD::FADD, DL, VT, N0, N2);

    if (N1CFP->isExactlyValue(-1.0) &&
        (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))) {
      SDValue RHSNeg = DAG.getNode(ISD::FNEG, DL, VT, N0);
      AddToWorklist(RHSNeg.getNode());
      // TODO: The FMA node should have flags that propagate to this node.
      return DAG.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
    }

    // fma (fneg x), K, y -> fma x, -K, y
    if (N0.getOpcode() == ISD::FNEG &&
        (TLI.isOperationLegal(ISD::ConstantFP, VT) ||
         (N1.hasOneUse() && !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT,
                                              ForCodeSize)))) {
      return DAG.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
                         DAG.getNode(ISD::FNEG, DL, VT, N1, Flags), N2);
    }
  }

  if (UnsafeFPMath) {
    // (fma x, c, x) -> (fmul x, (c+1))
    if (N1CFP && N0 == N2) {
      return DAG.getNode(ISD::FMUL, DL, VT, N0,
                         DAG.getNode(ISD::FADD, DL, VT, N1,
                                     DAG.getConstantFP(1.0, DL, VT), Flags),
                         Flags);
    }

    // (fma x, c, (fneg x)) -> (fmul x, (c-1))
    if (N1CFP && N2.getOpcode() == ISD::FNEG && N2.getOperand(0) == N0) {
      return DAG.getNode(ISD::FMUL, DL, VT, N0,
                         DAG.getNode(ISD::FADD, DL, VT, N1,
                                     DAG.getConstantFP(-1.0, DL, VT), Flags),
                         Flags);
    }
  }

  // fold ((fma (fneg X), Y, (fneg Z)) -> fneg (fma X, Y, Z))
  // fold ((fma X, (fneg Y), (fneg Z)) -> fneg (fma X, Y, Z))
  if (!TLI.isFNegFree(VT) &&
      TLI.isNegatibleForFree(SDValue(N, 0), DAG, LegalOperations,
                             ForCodeSize) == 2)
    return DAG.getNode(ISD::FNEG, DL, VT,
                       TLI.getNegatedExpression(SDValue(N, 0), DAG,
                                                LegalOperations, ForCodeSize),
                       Flags);
  return SDValue();
}
12660 
12661 // Combine multiple FDIVs with the same divisor into multiple FMULs by the
12662 // reciprocal.
12663 // E.g., (a / D; b / D;) -> (recip = 1.0 / D; a * recip; b * recip)
12664 // Notice that this is not always beneficial. One reason is different targets
12665 // may have different costs for FDIV and FMUL, so sometimes the cost of two
12666 // FDIVs may be lower than the cost of one FDIV and two FMULs. Another reason
12667 // is the critical path is increased from "one FDIV" to "one FDIV + one FMUL".
combineRepeatedFPDivisors(SDNode * N)12668 SDValue DAGCombiner::combineRepeatedFPDivisors(SDNode *N) {
12669   // TODO: Limit this transform based on optsize/minsize - it always creates at
12670   //       least 1 extra instruction. But the perf win may be substantial enough
12671   //       that only minsize should restrict this.
12672   bool UnsafeMath = DAG.getTarget().Options.UnsafeFPMath;
12673   const SDNodeFlags Flags = N->getFlags();
12674   if (!UnsafeMath && !Flags.hasAllowReciprocal())
12675     return SDValue();
12676 
12677   // Skip if current node is a reciprocal/fneg-reciprocal.
12678   SDValue N0 = N->getOperand(0);
12679   ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, /* AllowUndefs */ true);
12680   if (N0CFP && (N0CFP->isExactlyValue(1.0) || N0CFP->isExactlyValue(-1.0)))
12681     return SDValue();
12682 
12683   // Exit early if the target does not want this transform or if there can't
12684   // possibly be enough uses of the divisor to make the transform worthwhile.
12685   SDValue N1 = N->getOperand(1);
12686   unsigned MinUses = TLI.combineRepeatedFPDivisors();
12687 
12688   // For splat vectors, scale the number of uses by the splat factor. If we can
12689   // convert the division into a scalar op, that will likely be much faster.
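  // e.g. a single <4 x float> FDIV by a splatted divisor counts as four
  // scalar uses, so one vector divide can already meet the threshold.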
  unsigned NumElts = 1;
  EVT VT = N->getValueType(0);
  if (VT.isVector() && DAG.isSplatValue(N1))
    NumElts = VT.getVectorNumElements();

  if (!MinUses || (N1->use_size() * NumElts) < MinUses)
    return SDValue();

  // Find all FDIV users of the same divisor.
  // Use a set because duplicates may be present in the user list.
  SetVector<SDNode *> Users;
  for (auto *U : N1->uses()) {
    if (U->getOpcode() == ISD::FDIV && U->getOperand(1) == N1) {
      // This division is eligible for optimization only if global unsafe math
      // is enabled or if this division allows reciprocal formation.
      if (UnsafeMath || U->getFlags().hasAllowReciprocal())
        Users.insert(U);
    }
  }

  // Now that we have the actual number of divisor uses, make sure it meets
  // the minimum threshold specified by the target.
  if ((Users.size() * NumElts) < MinUses)
    return SDValue();

  SDLoc DL(N);
  SDValue FPOne = DAG.getConstantFP(1.0, DL, VT);
  SDValue Reciprocal = DAG.getNode(ISD::FDIV, DL, VT, FPOne, N1, Flags);

  // Dividend / Divisor -> Dividend * Reciprocal
  for (auto *U : Users) {
    SDValue Dividend = U->getOperand(0);
    if (Dividend != FPOne) {
      SDValue NewNode = DAG.getNode(ISD::FMUL, SDLoc(U), VT, Dividend,
                                    Reciprocal, Flags);
      CombineTo(U, NewNode);
    } else if (U != Reciprocal.getNode()) {
      // In the absence of fast-math-flags, this user node is always the
      // same node as Reciprocal, but with FMF they may be different nodes.
      CombineTo(U, Reciprocal);
    }
  }
  return SDValue(N, 0);  // N was replaced.
}

SDValue DAGCombiner::visitFDIV(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
  ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);
  const TargetOptions &Options = DAG.getTarget().Options;
  SDNodeFlags Flags = N->getFlags();

  // fold vector ops
  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N))
      return FoldedVOp;

  // fold (fdiv c1, c2) -> c1/c2
  if (N0CFP && N1CFP)
    return DAG.getNode(ISD::FDIV, SDLoc(N), VT, N0, N1, Flags);

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  if (SDValue V = combineRepeatedFPDivisors(N))
    return V;

  if (Options.UnsafeFPMath || Flags.hasAllowReciprocal()) {
    // fold (fdiv X, c2) -> fmul X, 1/c2 if losing precision is acceptable.
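    // e.g. (fdiv X, 4.0) -> (fmul X, 0.25): the reciprocal of a power of two
    // is exact (opOK), so only immediate legality can block the fold.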
    if (N1CFP) {
      // Compute the reciprocal 1.0 / c2.
      const APFloat &N1APF = N1CFP->getValueAPF();
      APFloat Recip(N1APF.getSemantics(), 1); // 1.0
      APFloat::opStatus st = Recip.divide(N1APF, APFloat::rmNearestTiesToEven);
      // Only do the transform if the reciprocal is a legal fp immediate that
      // isn't too nasty (e.g. NaN, denormal, ...).
      if ((st == APFloat::opOK || st == APFloat::opInexact) && // Not too nasty
          (!LegalOperations ||
           // FIXME: custom lowering of ConstantFP might fail (see e.g. ARM
           // backend)... we should handle this gracefully after Legalize.
           // TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT) ||
           TLI.isOperationLegal(ISD::ConstantFP, VT) ||
           TLI.isFPImmLegal(Recip, VT, ForCodeSize)))
        return DAG.getNode(ISD::FMUL, DL, VT, N0,
                           DAG.getConstantFP(Recip, DL, VT), Flags);
    }

    // If this FDIV is part of a reciprocal square root, it may be folded
    // into a target-specific square root estimate instruction.
    if (N1.getOpcode() == ISD::FSQRT) {
      if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0), Flags))
        return DAG.getNode(ISD::FMUL, DL, VT, N0, RV, Flags);
    } else if (N1.getOpcode() == ISD::FP_EXTEND &&
               N1.getOperand(0).getOpcode() == ISD::FSQRT) {
      if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0).getOperand(0),
                                          Flags)) {
        RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(N1), VT, RV);
        AddToWorklist(RV.getNode());
        return DAG.getNode(ISD::FMUL, DL, VT, N0, RV, Flags);
      }
    } else if (N1.getOpcode() == ISD::FP_ROUND &&
               N1.getOperand(0).getOpcode() == ISD::FSQRT) {
      if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0).getOperand(0),
                                          Flags)) {
        RV = DAG.getNode(ISD::FP_ROUND, SDLoc(N1), VT, RV, N1.getOperand(1));
        AddToWorklist(RV.getNode());
        return DAG.getNode(ISD::FMUL, DL, VT, N0, RV, Flags);
      }
    } else if (N1.getOpcode() == ISD::FMUL) {
      // Look through an FMUL. Even though this won't remove the FDIV directly,
      // it's still worthwhile to get rid of the FSQRT if possible.
      SDValue SqrtOp;
      SDValue OtherOp;
      if (N1.getOperand(0).getOpcode() == ISD::FSQRT) {
        SqrtOp = N1.getOperand(0);
        OtherOp = N1.getOperand(1);
      } else if (N1.getOperand(1).getOpcode() == ISD::FSQRT) {
        SqrtOp = N1.getOperand(1);
        OtherOp = N1.getOperand(0);
      }
      if (SqrtOp.getNode()) {
        // We found a FSQRT, so try to make this fold:
        // x / (y * sqrt(z)) -> x * (rsqrt(z) / y)
        if (SDValue RV = buildRsqrtEstimate(SqrtOp.getOperand(0), Flags)) {
          RV = DAG.getNode(ISD::FDIV, SDLoc(N1), VT, RV, OtherOp, Flags);
          AddToWorklist(RV.getNode());
          return DAG.getNode(ISD::FMUL, DL, VT, N0, RV, Flags);
        }
      }
    }

    // Fold into a reciprocal estimate and multiply instead of a real divide.
    if (SDValue RV = BuildDivEstimate(N0, N1, Flags))
      return RV;
  }

  // (fdiv (fneg X), (fneg Y)) -> (fdiv X, Y)
  if (isCheaperToUseNegatedFPOps(N0, N1))
    return DAG.getNode(
        ISD::FDIV, SDLoc(N), VT,
        TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize),
        TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize), Flags);

  return SDValue();
}

SDValue DAGCombiner::visitFREM(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
  ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
  EVT VT = N->getValueType(0);

  // fold (frem c1, c2) -> fmod(c1,c2)
  if (N0CFP && N1CFP)
    return DAG.getNode(ISD::FREM, SDLoc(N), VT, N0, N1, N->getFlags());

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  return SDValue();
}

SDValue DAGCombiner::visitFSQRT(SDNode *N) {
  SDNodeFlags Flags = N->getFlags();
  if (!DAG.getTarget().Options.UnsafeFPMath &&
      !Flags.hasApproximateFuncs())
    return SDValue();

  SDValue N0 = N->getOperand(0);
  if (TLI.isFsqrtCheap(N0, DAG))
    return SDValue();

  // FSQRT nodes have flags that propagate to the created nodes.
  return buildSqrtEstimate(N0, Flags);
}

/// copysign(x, fp_extend(y)) -> copysign(x, y)
/// copysign(x, fp_round(y)) -> copysign(x, y)
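/// Only the sign bit of the second operand matters, and fp_extend/fp_round
/// preserve the sign, so the conversion can usually be skipped entirely.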
static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(SDNode *N) {
  SDValue N1 = N->getOperand(1);
  if ((N1.getOpcode() == ISD::FP_EXTEND ||
       N1.getOpcode() == ISD::FP_ROUND)) {
    // Do not optimize out type conversion of f128 type yet.
    // For some targets like x86_64, configuration is changed to keep one f128
    // value in one SSE register, but instruction selection cannot handle
    // FCOPYSIGN on SSE registers yet.
    EVT N1VT = N1->getValueType(0);
    EVT N1Op0VT = N1->getOperand(0).getValueType();
    return (N1VT == N1Op0VT || N1Op0VT != MVT::f128);
  }
  return false;
}

SDValue DAGCombiner::visitFCOPYSIGN(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  bool N0CFP = isConstantFPBuildVectorOrConstantFP(N0);
  bool N1CFP = isConstantFPBuildVectorOrConstantFP(N1);
  EVT VT = N->getValueType(0);

  if (N0CFP && N1CFP) // Constant fold
    return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0, N1);

  if (ConstantFPSDNode *N1C = isConstOrConstSplatFP(N->getOperand(1))) {
    const APFloat &V = N1C->getValueAPF();
    // copysign(x, c1) -> fabs(x)       iff ispos(c1)
    // copysign(x, c1) -> fneg(fabs(x)) iff isneg(c1)
    if (!V.isNegative()) {
      if (!LegalOperations || TLI.isOperationLegal(ISD::FABS, VT))
        return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0);
    } else {
      if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))
        return DAG.getNode(ISD::FNEG, SDLoc(N), VT,
                           DAG.getNode(ISD::FABS, SDLoc(N0), VT, N0));
    }
  }

  // copysign(fabs(x), y) -> copysign(x, y)
  // copysign(fneg(x), y) -> copysign(x, y)
  // copysign(copysign(x,z), y) -> copysign(x, y)
  if (N0.getOpcode() == ISD::FABS || N0.getOpcode() == ISD::FNEG ||
      N0.getOpcode() == ISD::FCOPYSIGN)
    return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0.getOperand(0), N1);

  // copysign(x, abs(y)) -> abs(x)
  if (N1.getOpcode() == ISD::FABS)
    return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0);

  // copysign(x, copysign(y,z)) -> copysign(x, z)
  if (N1.getOpcode() == ISD::FCOPYSIGN)
    return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0, N1.getOperand(1));

  // copysign(x, fp_extend(y)) -> copysign(x, y)
  // copysign(x, fp_round(y)) -> copysign(x, y)
  if (CanCombineFCOPYSIGN_EXTEND_ROUND(N))
    return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0, N1.getOperand(0));

  return SDValue();
}

SDValue DAGCombiner::visitFPOW(SDNode *N) {
  ConstantFPSDNode *ExponentC = isConstOrConstSplatFP(N->getOperand(1));
  if (!ExponentC)
    return SDValue();

  // Try to convert x ** (1/3) into cube root.
  // TODO: Handle the various flavors of long double.
  // TODO: Since we're approximating, we don't need an exact 1/3 exponent.
  //       Some range near 1/3 should be fine.
  EVT VT = N->getValueType(0);
  if ((VT == MVT::f32 && ExponentC->getValueAPF().isExactlyValue(1.0f/3.0f)) ||
      (VT == MVT::f64 && ExponentC->getValueAPF().isExactlyValue(1.0/3.0))) {
    // pow(-0.0, 1/3) = +0.0; cbrt(-0.0) = -0.0.
    // pow(-inf, 1/3) = +inf; cbrt(-inf) = -inf.
    // pow(-val, 1/3) =  nan; cbrt(-val) = -cbrt(val).
    // For regular numbers, rounding may cause the results to differ.
    // Therefore, we require { nsz ninf nnan afn } for this transform.
    // TODO: We could select out the special cases if we don't have nsz/ninf.
    SDNodeFlags Flags = N->getFlags();
    if (!Flags.hasNoSignedZeros() || !Flags.hasNoInfs() || !Flags.hasNoNaNs() ||
        !Flags.hasApproximateFuncs())
      return SDValue();

    // Do not create a cbrt() libcall if the target does not have it, and do not
    // turn a pow that has lowering support into a cbrt() libcall.
    if (!DAG.getLibInfo().has(LibFunc_cbrt) ||
        (!DAG.getTargetLoweringInfo().isOperationExpand(ISD::FPOW, VT) &&
         DAG.getTargetLoweringInfo().isOperationExpand(ISD::FCBRT, VT)))
      return SDValue();

    return DAG.getNode(ISD::FCBRT, SDLoc(N), VT, N->getOperand(0), Flags);
  }

  // Try to convert x ** (1/4) and x ** (3/4) into square roots.
  // x ** (1/2) is canonicalized to sqrt, so we do not bother with that case.
  // TODO: This could be extended (using a target hook) to handle smaller
  // power-of-2 fractional exponents.
  bool ExponentIs025 = ExponentC->getValueAPF().isExactlyValue(0.25);
  bool ExponentIs075 = ExponentC->getValueAPF().isExactlyValue(0.75);
  if (ExponentIs025 || ExponentIs075) {
    // pow(-0.0, 0.25) = +0.0; sqrt(sqrt(-0.0)) = -0.0.
    // pow(-inf, 0.25) = +inf; sqrt(sqrt(-inf)) =  NaN.
    // pow(-0.0, 0.75) = +0.0; sqrt(-0.0) * sqrt(sqrt(-0.0)) = +0.0.
    // pow(-inf, 0.75) = +inf; sqrt(-inf) * sqrt(sqrt(-inf)) =  NaN.
    // For regular numbers, rounding may cause the results to differ.
    // Therefore, we require { nsz ninf afn } for this transform.
    // TODO: We could select out the special cases if we don't have nsz/ninf.
    SDNodeFlags Flags = N->getFlags();

    // We only need no signed zeros for the 0.25 case.
    if ((!Flags.hasNoSignedZeros() && ExponentIs025) || !Flags.hasNoInfs() ||
        !Flags.hasApproximateFuncs())
      return SDValue();

    // Don't double the number of libcalls. We are trying to inline fast code.
    if (!DAG.getTargetLoweringInfo().isOperationLegalOrCustom(ISD::FSQRT, VT))
      return SDValue();

    // Assume that libcalls are the smallest code.
    // TODO: This restriction should probably be lifted for vectors.
    if (ForCodeSize)
      return SDValue();

    // pow(X, 0.25) --> sqrt(sqrt(X))
    SDLoc DL(N);
    SDValue Sqrt = DAG.getNode(ISD::FSQRT, DL, VT, N->getOperand(0), Flags);
    SDValue SqrtSqrt = DAG.getNode(ISD::FSQRT, DL, VT, Sqrt, Flags);
    if (ExponentIs025)
      return SqrtSqrt;
    // pow(X, 0.75) --> sqrt(X) * sqrt(sqrt(X))
    return DAG.getNode(ISD::FMUL, DL, VT, Sqrt, SqrtSqrt, Flags);
  }

  return SDValue();
}

static SDValue foldFPToIntToFP(SDNode *N, SelectionDAG &DAG,
                               const TargetLowering &TLI) {
  // This optimization is guarded by a function attribute because it may produce
  // unexpected results. I.e., programs may be relying on the platform-specific
  // undefined behavior when the float-to-int conversion overflows.
  const Function &F = DAG.getMachineFunction().getFunction();
  Attribute StrictOverflow = F.getFnAttribute("strict-float-cast-overflow");
  if (StrictOverflow.getValueAsString().equals("false"))
    return SDValue();

  // We only do this if the target has legal ftrunc. Otherwise, we'd likely be
  // replacing casts with a libcall. We also must be allowed to ignore -0.0
  // because FTRUNC will return -0.0 for (-1.0, -0.0), but using integer
  // conversions would return +0.0.
  // FIXME: We should be able to use node-level FMF here.
  // TODO: If strict math, should we use FABS (+ range check for signed cast)?
  EVT VT = N->getValueType(0);
  if (!TLI.isOperationLegal(ISD::FTRUNC, VT) ||
      !DAG.getTarget().Options.NoSignedZerosFPMath)
    return SDValue();

  // fptosi/fptoui round towards zero, so converting from FP to integer and
  // back is the same as an 'ftrunc': [us]itofp (fpto[us]i X) --> ftrunc X
  SDValue N0 = N->getOperand(0);
  if (N->getOpcode() == ISD::SINT_TO_FP && N0.getOpcode() == ISD::FP_TO_SINT &&
      N0.getOperand(0).getValueType() == VT)
    return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0.getOperand(0));

  if (N->getOpcode() == ISD::UINT_TO_FP && N0.getOpcode() == ISD::FP_TO_UINT &&
      N0.getOperand(0).getValueType() == VT)
    return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0.getOperand(0));

  return SDValue();
}

SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  EVT OpVT = N0.getValueType();

  // [us]itofp(undef) = 0, because the result value is bounded.
  if (N0.isUndef())
    return DAG.getConstantFP(0.0, SDLoc(N), VT);

  // fold (sint_to_fp c1) -> c1fp
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      // ...but only if the target supports immediate floating-point values
      (!LegalOperations ||
       TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
    return DAG.getNode(ISD::SINT_TO_FP, SDLoc(N), VT, N0);

  // If the input is a legal type, and SINT_TO_FP is not legal on this target,
  // but UINT_TO_FP is legal on this target, try to convert.
  if (!hasOperation(ISD::SINT_TO_FP, OpVT) &&
      hasOperation(ISD::UINT_TO_FP, OpVT)) {
    // If the sign bit is known to be zero, we can change this to UINT_TO_FP.
    if (DAG.SignBitIsZero(N0))
      return DAG.getNode(ISD::UINT_TO_FP, SDLoc(N), VT, N0);
  }

  // The next optimizations are desirable only if SELECT_CC can be lowered.
  if (TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT) || !LegalOperations) {
    // fold (sint_to_fp (setcc x, y, cc)) -> (select_cc x, y, -1.0, 0.0, cc)
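    // An i1 setcc result is -1 when sign-extended (true) and 0 when false,
    // so the select arms are -1.0 and 0.0 here.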
    if (N0.getOpcode() == ISD::SETCC && N0.getValueType() == MVT::i1 &&
        !VT.isVector() &&
        (!LegalOperations ||
         TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
      SDLoc DL(N);
      SDValue Ops[] =
        { N0.getOperand(0), N0.getOperand(1),
          DAG.getConstantFP(-1.0, DL, VT), DAG.getConstantFP(0.0, DL, VT),
          N0.getOperand(2) };
      return DAG.getNode(ISD::SELECT_CC, DL, VT, Ops);
    }

    // fold (sint_to_fp (zext (setcc x, y, cc))) ->
    //      (select_cc x, y, 1.0, 0.0, cc)
    if (N0.getOpcode() == ISD::ZERO_EXTEND &&
        N0.getOperand(0).getOpcode() == ISD::SETCC && !VT.isVector() &&
        (!LegalOperations ||
         TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
      SDLoc DL(N);
      SDValue Ops[] =
        { N0.getOperand(0).getOperand(0), N0.getOperand(0).getOperand(1),
          DAG.getConstantFP(1.0, DL, VT), DAG.getConstantFP(0.0, DL, VT),
          N0.getOperand(0).getOperand(2) };
      return DAG.getNode(ISD::SELECT_CC, DL, VT, Ops);
    }
  }

  if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI))
    return FTrunc;

  return SDValue();
}

SDValue DAGCombiner::visitUINT_TO_FP(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  EVT OpVT = N0.getValueType();

  // [us]itofp(undef) = 0, because the result value is bounded.
  if (N0.isUndef())
    return DAG.getConstantFP(0.0, SDLoc(N), VT);

  // fold (uint_to_fp c1) -> c1fp
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      // ...but only if the target supports immediate floating-point values
      (!LegalOperations ||
       TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
    return DAG.getNode(ISD::UINT_TO_FP, SDLoc(N), VT, N0);

  // If the input is a legal type, and UINT_TO_FP is not legal on this target,
  // but SINT_TO_FP is legal on this target, try to convert.
  if (!hasOperation(ISD::UINT_TO_FP, OpVT) &&
      hasOperation(ISD::SINT_TO_FP, OpVT)) {
    // If the sign bit is known to be zero, we can change this to SINT_TO_FP.
    if (DAG.SignBitIsZero(N0))
      return DAG.getNode(ISD::SINT_TO_FP, SDLoc(N), VT, N0);
  }

  // The next optimizations are desirable only if SELECT_CC can be lowered.
  if (TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT) || !LegalOperations) {
    // fold (uint_to_fp (setcc x, y, cc)) -> (select_cc x, y, 1.0, 0.0, cc)
    if (N0.getOpcode() == ISD::SETCC && !VT.isVector() &&
        (!LegalOperations ||
         TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
      SDLoc DL(N);
      SDValue Ops[] =
        { N0.getOperand(0), N0.getOperand(1),
          DAG.getConstantFP(1.0, DL, VT), DAG.getConstantFP(0.0, DL, VT),
          N0.getOperand(2) };
      return DAG.getNode(ISD::SELECT_CC, DL, VT, Ops);
    }
  }

  if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI))
    return FTrunc;

  return SDValue();
}

// Fold (fp_to_{s/u}int ({s/u}int_to_fp x)) -> zext x, sext x, trunc x, or x
static SDValue FoldIntToFPToInt(SDNode *N, SelectionDAG &DAG) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  if (N0.getOpcode() != ISD::UINT_TO_FP && N0.getOpcode() != ISD::SINT_TO_FP)
    return SDValue();

  SDValue Src = N0.getOperand(0);
  EVT SrcVT = Src.getValueType();
  bool IsInputSigned = N0.getOpcode() == ISD::SINT_TO_FP;
  bool IsOutputSigned = N->getOpcode() == ISD::FP_TO_SINT;

  // We can safely assume the conversion won't overflow the output range,
  // because (for example) (uint8_t)18293.f is undefined behavior.

  // Since we can assume the conversion won't overflow, our decision as to
  // whether the input will fit in the float should depend on the minimum
  // of the input range and output range.

  // This means this is also safe for a signed input and unsigned output, since
  // a negative input would lead to undefined behavior.
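  // e.g. i16 -> f32 -> i32: f32 carries 24 bits of precision, which covers
  // all 16-bit inputs exactly, so the round trip reduces to a sign/zero
  // extension of the original integer.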
  unsigned InputSize = (int)SrcVT.getScalarSizeInBits() - IsInputSigned;
  unsigned OutputSize = (int)VT.getScalarSizeInBits() - IsOutputSigned;
  unsigned ActualSize = std::min(InputSize, OutputSize);
  const fltSemantics &sem = DAG.EVTToAPFloatSemantics(N0.getValueType());

  // We can only fold away the float conversion if the input range can be
  // represented exactly in the float range.
  if (APFloat::semanticsPrecision(sem) >= ActualSize) {
    if (VT.getScalarSizeInBits() > SrcVT.getScalarSizeInBits()) {
      unsigned ExtOp = IsInputSigned && IsOutputSigned ? ISD::SIGN_EXTEND
                                                       : ISD::ZERO_EXTEND;
      return DAG.getNode(ExtOp, SDLoc(N), VT, Src);
    }
    if (VT.getScalarSizeInBits() < SrcVT.getScalarSizeInBits())
      return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Src);
    return DAG.getBitcast(VT, Src);
  }
  return SDValue();
}

SDValue DAGCombiner::visitFP_TO_SINT(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  // fold (fp_to_sint undef) -> undef
  if (N0.isUndef())
    return DAG.getUNDEF(VT);

  // fold (fp_to_sint c1fp) -> c1
  if (isConstantFPBuildVectorOrConstantFP(N0))
    return DAG.getNode(ISD::FP_TO_SINT, SDLoc(N), VT, N0);

  return FoldIntToFPToInt(N, DAG);
}

SDValue DAGCombiner::visitFP_TO_UINT(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  // fold (fp_to_uint undef) -> undef
  if (N0.isUndef())
    return DAG.getUNDEF(VT);

  // fold (fp_to_uint c1fp) -> c1
  if (isConstantFPBuildVectorOrConstantFP(N0))
    return DAG.getNode(ISD::FP_TO_UINT, SDLoc(N), VT, N0);

  return FoldIntToFPToInt(N, DAG);
}

SDValue DAGCombiner::visitFP_ROUND(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
  EVT VT = N->getValueType(0);

  // fold (fp_round c1fp) -> c1fp
  if (N0CFP)
    return DAG.getNode(ISD::FP_ROUND, SDLoc(N), VT, N0, N1);

  // fold (fp_round (fp_extend x)) -> x
  if (N0.getOpcode() == ISD::FP_EXTEND && VT == N0.getOperand(0).getValueType())
    return N0.getOperand(0);

  // fold (fp_round (fp_round x)) -> (fp_round x)
  if (N0.getOpcode() == ISD::FP_ROUND) {
    const bool NIsTrunc = N->getConstantOperandVal(1) == 1;
    const bool N0IsTrunc = N0.getConstantOperandVal(1) == 1;

    // Skip this folding if it results in an fp_round from f80 to f16.
    //
    // f80 to f16 always generates an expensive (and as yet, unimplemented)
    // libcall to __truncxfhf2 instead of selecting native f16 conversion
    // instructions from f32 or f64.  Moreover, the first (value-preserving)
    // fp_round from f80 to either f32 or f64 may become a NOP in platforms like
    // x86.
    if (N0.getOperand(0).getValueType() == MVT::f80 && VT == MVT::f16)
      return SDValue();

    // If the first fp_round isn't a value preserving truncation, it might
    // introduce a tie in the second fp_round, that wouldn't occur in the
    // single-step fp_round we want to fold to.
    // In other words, double rounding isn't the same as rounding.
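    // e.g. rounding f64 -> f32 first can land exactly halfway between two f16
    // values, and that tie may then round differently than a single
    // f64 -> f16 rounding of the original value would.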
    // Also, this is a value preserving truncation iff both fp_round's are.
    if (DAG.getTarget().Options.UnsafeFPMath || N0IsTrunc) {
      SDLoc DL(N);
      return DAG.getNode(ISD::FP_ROUND, DL, VT, N0.getOperand(0),
                         DAG.getIntPtrConstant(NIsTrunc && N0IsTrunc, DL));
    }
  }

  // fold (fp_round (copysign X, Y)) -> (copysign (fp_round X), Y)
  if (N0.getOpcode() == ISD::FCOPYSIGN && N0.getNode()->hasOneUse()) {
    SDValue Tmp = DAG.getNode(ISD::FP_ROUND, SDLoc(N0), VT,
                              N0.getOperand(0), N1);
    AddToWorklist(Tmp.getNode());
    return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT,
                       Tmp, N0.getOperand(1));
  }

  if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
    return NewVSel;

  return SDValue();
}

SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  // If this is fp_round(fp_extend), don't fold it; allow ourselves to be folded.
  if (N->hasOneUse() &&
      N->use_begin()->getOpcode() == ISD::FP_ROUND)
    return SDValue();

  // fold (fp_extend c1fp) -> c1fp
  if (isConstantFPBuildVectorOrConstantFP(N0))
    return DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, N0);

  // fold (fp_extend (fp16_to_fp op)) -> (fp16_to_fp op)
  if (N0.getOpcode() == ISD::FP16_TO_FP &&
      TLI.getOperationAction(ISD::FP16_TO_FP, VT) == TargetLowering::Legal)
    return DAG.getNode(ISD::FP16_TO_FP, SDLoc(N), VT, N0.getOperand(0));

  // Turn fp_extend(fp_round(X, 1)) -> X since the fp_round doesn't affect the
  // value of X.
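  // The second operand of an FP_ROUND being 1 asserts that the value is
  // already exactly representable in the narrower type, so round-tripping it
  // through that type is lossless.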
  if (N0.getOpcode() == ISD::FP_ROUND
      && N0.getConstantOperandVal(1) == 1) {
    SDValue In = N0.getOperand(0);
    if (In.getValueType() == VT) return In;
    if (VT.bitsLT(In.getValueType()))
      return DAG.getNode(ISD::FP_ROUND, SDLoc(N), VT,
                         In, N0.getOperand(1));
    return DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, In);
  }

  // fold (fpext (load x)) -> (fpext (fptrunc (extload x)))
  if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
       TLI.isLoadExtLegal(ISD::EXTLOAD, VT, N0.getValueType())) {
    LoadSDNode *LN0 = cast<LoadSDNode>(N0);
    SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, SDLoc(N), VT,
                                     LN0->getChain(),
                                     LN0->getBasePtr(), N0.getValueType(),
                                     LN0->getMemOperand());
    CombineTo(N, ExtLoad);
    CombineTo(N0.getNode(),
              DAG.getNode(ISD::FP_ROUND, SDLoc(N0),
                          N0.getValueType(), ExtLoad,
                          DAG.getIntPtrConstant(1, SDLoc(N0))),
              ExtLoad.getValue(1));
    return SDValue(N, 0);   // Return N so it doesn't get rechecked!
  }

  if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
    return NewVSel;

  return SDValue();
}

SDValue DAGCombiner::visitFCEIL(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  // fold (fceil c1) -> fceil(c1)
  if (isConstantFPBuildVectorOrConstantFP(N0))
    return DAG.getNode(ISD::FCEIL, SDLoc(N), VT, N0);

  return SDValue();
}

SDValue DAGCombiner::visitFTRUNC(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  // fold (ftrunc c1) -> ftrunc(c1)
  if (isConstantFPBuildVectorOrConstantFP(N0))
    return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0);

  // fold ftrunc (known rounded int x) -> x
  // ftrunc is a part of fptosi/fptoui expansion on some targets, so this is
  // likely to be generated to extract integer from a rounded floating value.
  switch (N0.getOpcode()) {
  default: break;
  case ISD::FRINT:
  case ISD::FTRUNC:
  case ISD::FNEARBYINT:
  case ISD::FFLOOR:
  case ISD::FCEIL:
    return N0;
  }

  return SDValue();
}

SDValue DAGCombiner::visitFFLOOR(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  // fold (ffloor c1) -> ffloor(c1)
  if (isConstantFPBuildVectorOrConstantFP(N0))
    return DAG.getNode(ISD::FFLOOR, SDLoc(N), VT, N0);

  return SDValue();
}

// FIXME: FNEG and FABS have a lot in common; refactor.
SDValue DAGCombiner::visitFNEG(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  // Constant fold FNEG.
  if (isConstantFPBuildVectorOrConstantFP(N0))
    return DAG.getNode(ISD::FNEG, SDLoc(N), VT, N0);

  if (TLI.isNegatibleForFree(N0, DAG, LegalOperations, ForCodeSize))
    return TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize);

  // -(X-Y) -> (Y-X) is unsafe because when X==Y, -0.0 != +0.0.
  // FIXME: This is duplicated in isNegatibleForFree, but isNegatibleForFree
  // doesn't know it was called from a context with a nsz flag if the input
  // fsub does not.
  if (N0.getOpcode() == ISD::FSUB &&
      (DAG.getTarget().Options.NoSignedZerosFPMath ||
       N->getFlags().hasNoSignedZeros()) && N0.hasOneUse()) {
    return DAG.getNode(ISD::FSUB, SDLoc(N), VT, N0.getOperand(1),
                       N0.getOperand(0), N->getFlags());
  }

  // Transform fneg(bitconvert(x)) -> bitconvert(x ^ sign) to avoid loading
  // constant pool values.
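  // e.g. (fneg (bitcast i32 X to f32)) -> (bitcast (xor X, 0x80000000) to f32)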
  if (!TLI.isFNegFree(VT) &&
      N0.getOpcode() == ISD::BITCAST &&
      N0.getNode()->hasOneUse()) {
    SDValue Int = N0.getOperand(0);
    EVT IntVT = Int.getValueType();
    if (IntVT.isInteger() && !IntVT.isVector()) {
      APInt SignMask;
      if (N0.getValueType().isVector()) {
        // For a vector, get a mask such as 0x80... per scalar element
        // and splat it.
        SignMask = APInt::getSignMask(N0.getScalarValueSizeInBits());
        SignMask = APInt::getSplat(IntVT.getSizeInBits(), SignMask);
      } else {
        // For a scalar, just generate 0x80...
        SignMask = APInt::getSignMask(IntVT.getSizeInBits());
      }
      SDLoc DL0(N0);
      Int = DAG.getNode(ISD::XOR, DL0, IntVT, Int,
                        DAG.getConstant(SignMask, DL0, IntVT));
      AddToWorklist(Int.getNode());
      return DAG.getBitcast(VT, Int);
    }
  }

  // (fneg (fmul c, x)) -> (fmul -c, x)
  if (N0.getOpcode() == ISD::FMUL &&
      (N0.getNode()->hasOneUse() || !TLI.isFNegFree(VT))) {
    ConstantFPSDNode *CFP1 = dyn_cast<ConstantFPSDNode>(N0.getOperand(1));
    if (CFP1) {
      APFloat CVal = CFP1->getValueAPF();
      CVal.changeSign();
      if (LegalDAG && (TLI.isFPImmLegal(CVal, VT, ForCodeSize) ||
                       TLI.isOperationLegal(ISD::ConstantFP, VT)))
        return DAG.getNode(
            ISD::FMUL, SDLoc(N), VT, N0.getOperand(0),
            DAG.getNode(ISD::FNEG, SDLoc(N), VT, N0.getOperand(1)),
            N0->getFlags());
    }
  }

  return SDValue();
}

static SDValue visitFMinMax(SelectionDAG &DAG, SDNode *N,
                            APFloat (*Op)(const APFloat &, const APFloat &)) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  const ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0);
  const ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1);

  if (N0CFP && N1CFP) {
    const APFloat &C0 = N0CFP->getValueAPF();
    const APFloat &C1 = N1CFP->getValueAPF();
    return DAG.getConstantFP(Op(C0, C1), SDLoc(N), VT);
  }

  // Canonicalize to constant on RHS.
  if (isConstantFPBuildVectorOrConstantFP(N0) &&
      !isConstantFPBuildVectorOrConstantFP(N1))
    return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0);

  return SDValue();
}

SDValue DAGCombiner::visitFMINNUM(SDNode *N) {
  return visitFMinMax(DAG, N, minnum);
}

SDValue DAGCombiner::visitFMAXNUM(SDNode *N) {
  return visitFMinMax(DAG, N, maxnum);
}

SDValue DAGCombiner::visitFMINIMUM(SDNode *N) {
  return visitFMinMax(DAG, N, minimum);
}

SDValue DAGCombiner::visitFMAXIMUM(SDNode *N) {
  return visitFMinMax(DAG, N, maximum);
}

SDValue DAGCombiner::visitFABS(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  // fold (fabs c1) -> fabs(c1)
  if (isConstantFPBuildVectorOrConstantFP(N0))
    return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0);

  // fold (fabs (fabs x)) -> (fabs x)
  if (N0.getOpcode() == ISD::FABS)
    return N->getOperand(0);

  // fold (fabs (fneg x)) -> (fabs x)
  // fold (fabs (fcopysign x, y)) -> (fabs x)
  if (N0.getOpcode() == ISD::FNEG || N0.getOpcode() == ISD::FCOPYSIGN)
    return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0.getOperand(0));

  // fabs(bitcast(x)) -> bitcast(x & ~sign) to avoid constant pool loads.
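  // e.g. (fabs (bitcast i32 X to f32)) -> (bitcast (and X, 0x7fffffff) to f32)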
  if (!TLI.isFAbsFree(VT) && N0.getOpcode() == ISD::BITCAST && N0.hasOneUse()) {
    SDValue Int = N0.getOperand(0);
    EVT IntVT = Int.getValueType();
    if (IntVT.isInteger() && !IntVT.isVector()) {
      APInt SignMask;
      if (N0.getValueType().isVector()) {
        // For a vector, get a mask such as 0x7f... per scalar element
        // and splat it.
        SignMask = ~APInt::getSignMask(N0.getScalarValueSizeInBits());
        SignMask = APInt::getSplat(IntVT.getSizeInBits(), SignMask);
      } else {
        // For a scalar, just generate 0x7f...
        SignMask = ~APInt::getSignMask(IntVT.getSizeInBits());
      }
      SDLoc DL(N0);
      Int = DAG.getNode(ISD::AND, DL, IntVT, Int,
                        DAG.getConstant(SignMask, DL, IntVT));
      AddToWorklist(Int.getNode());
      return DAG.getBitcast(N->getValueType(0), Int);
    }
  }

  return SDValue();
}

SDValue DAGCombiner::visitBRCOND(SDNode *N) {
  SDValue Chain = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue N2 = N->getOperand(2);

  // If N is a constant we could fold this into a fallthrough or unconditional
  // branch. However that doesn't happen very often in normal code, because
  // Instcombine/SimplifyCFG should have handled the available opportunities.
  // If we did this folding here, it would be necessary to update the
  // MachineBasicBlock CFG, which is awkward.

  // fold a brcond with a setcc condition into a BR_CC node if BR_CC is legal
  // on the target.
  if (N1.getOpcode() == ISD::SETCC &&
      TLI.isOperationLegalOrCustom(ISD::BR_CC,
                                   N1.getOperand(0).getValueType())) {
    return DAG.getNode(ISD::BR_CC, SDLoc(N), MVT::Other,
                       Chain, N1.getOperand(2),
                       N1.getOperand(0), N1.getOperand(1), N2);
  }

  if (N1.hasOneUse()) {
    if (SDValue NewN1 = rebuildSetCC(N1))
      return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other, Chain, NewN1, N2);
  }

  return SDValue();
}

SDValue DAGCombiner::rebuildSetCC(SDValue N) {
  if (N.getOpcode() == ISD::SRL ||
      (N.getOpcode() == ISD::TRUNCATE &&
       (N.getOperand(0).hasOneUse() &&
        N.getOperand(0).getOpcode() == ISD::SRL))) {
    // Look past the truncate.
    if (N.getOpcode() == ISD::TRUNCATE)
      N = N.getOperand(0);

    // Match this pattern so that we can generate simpler code:
    //
    //   %a = ...
    //   %b = and i32 %a, 2
    //   %c = srl i32 %b, 1
    //   brcond i32 %c ...
    //
    // into
    //
    //   %a = ...
    //   %b = and i32 %a, 2
    //   %c = setcc eq %b, 0
    //   brcond %c ...
    //
    // This applies only when the AND constant value has one bit set and the
    // SRL constant is equal to the log2 of the AND constant. The back-end is
    // smart enough to convert the result into a TEST/JMP sequence.
    SDValue Op0 = N.getOperand(0);
    SDValue Op1 = N.getOperand(1);

    if (Op0.getOpcode() == ISD::AND && Op1.getOpcode() == ISD::Constant) {
      SDValue AndOp1 = Op0.getOperand(1);

      if (AndOp1.getOpcode() == ISD::Constant) {
        const APInt &AndConst = cast<ConstantSDNode>(AndOp1)->getAPIntValue();

        if (AndConst.isPowerOf2() &&
            cast<ConstantSDNode>(Op1)->getAPIntValue() == AndConst.logBase2()) {
          SDLoc DL(N);
          return DAG.getSetCC(DL, getSetCCResultType(Op0.getValueType()),
                              Op0, DAG.getConstant(0, DL, Op0.getValueType()),
                              ISD::SETNE);
        }
      }
    }
  }

  // Transform br(xor(x, y)) -> br(x != y)
  // Transform br(xor(xor(x,y), 1)) -> br(x == y)
  if (N.getOpcode() == ISD::XOR) {
    // Because we may call this on a speculatively constructed
    // SimplifiedSetCC Node, we need to simplify this node first.
    // Ideally this should be folded into SimplifySetCC and not
    // here. For now, grab a handle to N so we don't lose it from
    // replacements internal to the visit.
    HandleSDNode XORHandle(N);
    while (N.getOpcode() == ISD::XOR) {
      SDValue Tmp = visitXOR(N.getNode());
      // No simplification done.
      if (!Tmp.getNode())
        break;
      // Returning N is a form of in-visit replacement that may have
      // invalidated N. Grab the value from the handle.
      if (Tmp.getNode() == N.getNode())
        N = XORHandle.getValue();
      else // Node simplified. Try simplifying again.
        N = Tmp;
    }

    if (N.getOpcode() != ISD::XOR)
      return N;

    SDNode *TheXor = N.getNode();

    SDValue Op0 = TheXor->getOperand(0);
    SDValue Op1 = TheXor->getOperand(1);

    if (Op0.getOpcode() != ISD::SETCC && Op1.getOpcode() != ISD::SETCC) {
      bool Equal = false;
      if (isOneConstant(Op0) && Op0.hasOneUse() &&
          Op0.getOpcode() == ISD::XOR) {
        TheXor = Op0.getNode();
        Equal = true;
      }

      EVT SetCCVT = N.getValueType();
      if (LegalTypes)
        SetCCVT = getSetCCResultType(SetCCVT);
      // Replace the uses of XOR with SETCC
      return DAG.getSetCC(SDLoc(TheXor), SetCCVT, Op0, Op1,
                          Equal ? ISD::SETEQ : ISD::SETNE);
    }
  }

  return SDValue();
}

// Operand list for BR_CC: Chain, CondCC, CondLHS, CondRHS, DestBB.
SDValue DAGCombiner::visitBR_CC(SDNode *N) {
  CondCodeSDNode *CC = cast<CondCodeSDNode>(N->getOperand(1));
  SDValue CondLHS = N->getOperand(2), CondRHS = N->getOperand(3);

  // If N is a constant we could fold this into a fallthrough or unconditional
  // branch. However that doesn't happen very often in normal code, because
  // Instcombine/SimplifyCFG should have handled the available opportunities.
  // If we did this folding here, it would be necessary to update the
  // MachineBasicBlock CFG, which is awkward.

  // Use SimplifySetCC to simplify SETCC's.
  SDValue Simp = SimplifySetCC(getSetCCResultType(CondLHS.getValueType()),
                               CondLHS, CondRHS, CC->get(), SDLoc(N),
                               false);
  if (Simp.getNode()) AddToWorklist(Simp.getNode());

  // fold to a simpler setcc
  if (Simp.getNode() && Simp.getOpcode() == ISD::SETCC)
    return DAG.getNode(ISD::BR_CC, SDLoc(N), MVT::Other,
                       N->getOperand(0), Simp.getOperand(2),
                       Simp.getOperand(0), Simp.getOperand(1),
                       N->getOperand(4));

  return SDValue();
}

/// Return true if 'Use' is a load or a store that uses N as its base pointer
/// and that N may be folded in the load / store addressing mode.
static bool canFoldInAddressingMode(SDNode *N, SDNode *Use,
                                    SelectionDAG &DAG,
                                    const TargetLowering &TLI) {
  EVT VT;
  unsigned AS;

  if (LoadSDNode *LD = dyn_cast<LoadSDNode>(Use)) {
    if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
      return false;
    VT = LD->getMemoryVT();
    AS = LD->getAddressSpace();
  } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(Use)) {
    if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
      return false;
    VT = ST->getMemoryVT();
    AS = ST->getAddressSpace();
  } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(Use)) {
    if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
      return false;
    VT = LD->getMemoryVT();
    AS = LD->getAddressSpace();
  } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(Use)) {
    if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
      return false;
    VT = ST->getMemoryVT();
    AS = ST->getAddressSpace();
  } else
    return false;

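  // Translate the add/sub into an addressing-mode query: a constant operand
  // becomes an immediate offset ([reg +/- imm]); otherwise it is modeled as a
  // scaled register ([reg +/- reg] with Scale = 1).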
  TargetLowering::AddrMode AM;
  if (N->getOpcode() == ISD::ADD) {
    AM.HasBaseReg = true;
    ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(N->getOperand(1));
    if (Offset)
      // [reg +/- imm]
      AM.BaseOffs = Offset->getSExtValue();
    else
      // [reg +/- reg]
      AM.Scale = 1;
  } else if (N->getOpcode() == ISD::SUB) {
    AM.HasBaseReg = true;
    ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(N->getOperand(1));
    if (Offset)
      // [reg +/- imm]
      AM.BaseOffs = -Offset->getSExtValue();
    else
      // [reg +/- reg]
      AM.Scale = 1;
  } else
    return false;

  return TLI.isLegalAddressingMode(DAG.getDataLayout(), AM,
                                   VT.getTypeForEVT(*DAG.getContext()), AS);
}

static bool getCombineLoadStoreParts(SDNode *N, unsigned Inc, unsigned Dec,
                                     bool &IsLoad, bool &IsMasked, SDValue &Ptr,
                                     const TargetLowering &TLI) {
  if (LoadSDNode *LD = dyn_cast<LoadSDNode>(N)) {
    if (LD->isIndexed())
      return false;
    EVT VT = LD->getMemoryVT();
    if (!TLI.isIndexedLoadLegal(Inc, VT) && !TLI.isIndexedLoadLegal(Dec, VT))
      return false;
    Ptr = LD->getBasePtr();
  } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(N)) {
    if (ST->isIndexed())
      return false;
    EVT VT = ST->getMemoryVT();
    if (!TLI.isIndexedStoreLegal(Inc, VT) && !TLI.isIndexedStoreLegal(Dec, VT))
      return false;
    Ptr = ST->getBasePtr();
    IsLoad = false;
  } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(N)) {
    if (LD->isIndexed())
      return false;
    EVT VT = LD->getMemoryVT();
    if (!TLI.isIndexedMaskedLoadLegal(Inc, VT) &&
        !TLI.isIndexedMaskedLoadLegal(Dec, VT))
      return false;
    Ptr = LD->getBasePtr();
    IsMasked = true;
  } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(N)) {
    if (ST->isIndexed())
      return false;
    EVT VT = ST->getMemoryVT();
    if (!TLI.isIndexedMaskedStoreLegal(Inc, VT) &&
        !TLI.isIndexedMaskedStoreLegal(Dec, VT))
      return false;
    Ptr = ST->getBasePtr();
    IsLoad = false;
    IsMasked = true;
  } else {
    return false;
  }
  return true;
}

/// Try turning a load/store into a pre-indexed load/store when the base
/// pointer is an add or subtract and it has other uses besides the load/store.
/// After the transformation, the new indexed load/store has effectively folded
/// the add/subtract in and all of its other uses are redirected to the
/// new load/store.
bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) {
  if (Level < AfterLegalizeDAG)
    return false;

  bool IsLoad = true;
  bool IsMasked = false;
  SDValue Ptr;
  if (!getCombineLoadStoreParts(N, ISD::PRE_INC, ISD::PRE_DEC, IsLoad, IsMasked,
                                Ptr, TLI))
    return false;

  // If the pointer is not an add/sub, or if it doesn't have multiple uses, bail
  // out.  There is no reason to make this a preinc/predec.
  if ((Ptr.getOpcode() != ISD::ADD && Ptr.getOpcode() != ISD::SUB) ||
      Ptr.getNode()->hasOneUse())
    return false;

  // Ask the target to do addressing mode selection.
  SDValue BasePtr;
  SDValue Offset;
  ISD::MemIndexedMode AM = ISD::UNINDEXED;
  if (!TLI.getPreIndexedAddressParts(N, BasePtr, Offset, AM, DAG))
    return false;

  // Backends without true r+i pre-indexed forms may need to pass a
  // constant base with a variable offset so that constant coercion
  // will work with the patterns in canonical form.
  bool Swapped = false;
  if (isa<ConstantSDNode>(BasePtr)) {
    std::swap(BasePtr, Offset);
    Swapped = true;
  }

  // Don't create an indexed load / store with zero offset.
  if (isNullConstant(Offset))
    return false;

  // Try turning it into a pre-indexed load / store except when:
  // 1) The new base ptr is a frame index.
  // 2) If N is a store and the new base ptr is either the same as or is a
  //    predecessor of the value being stored.
  // 3) Another use of old base ptr is a predecessor of N. If ptr is folded
  //    that would create a cycle.
  // 4) All uses are load / store ops that use it as old base ptr.

  // Check #1.  Preinc'ing a frame index would require copying the stack pointer
  // (plus the implicit offset) to a register to preinc anyway.
  if (isa<FrameIndexSDNode>(BasePtr) || isa<RegisterSDNode>(BasePtr))
    return false;

  // Check #2.
  if (!IsLoad) {
    SDValue Val = IsMasked ? cast<MaskedStoreSDNode>(N)->getValue()
                           : cast<StoreSDNode>(N)->getValue();

    // Would require a copy.
    if (Val == BasePtr)
      return false;

    // Would create a cycle.
    if (Val == Ptr || Ptr->isPredecessorOf(Val.getNode()))
      return false;
  }

  // Caches for hasPredecessorHelper.
  SmallPtrSet<const SDNode *, 32> Visited;
  SmallVector<const SDNode *, 16> Worklist;
  Worklist.push_back(N);

  // If the offset is a constant, there may be other adds of constants that
  // can be folded with this one. We should do this to avoid having to keep
  // a copy of the original base pointer.
  SmallVector<SDNode *, 16> OtherUses;
  if (isa<ConstantSDNode>(Offset))
    for (SDNode::use_iterator UI = BasePtr.getNode()->use_begin(),
                              UE = BasePtr.getNode()->use_end();
         UI != UE; ++UI) {
      SDUse &Use = UI.getUse();
      // Skip the use that is Ptr and uses of other results from BasePtr's
      // node (important for nodes that return multiple results).
      if (Use.getUser() == Ptr.getNode() || Use != BasePtr)
        continue;

      if (SDNode::hasPredecessorHelper(Use.getUser(), Visited, Worklist))
        continue;

      if (Use.getUser()->getOpcode() != ISD::ADD &&
          Use.getUser()->getOpcode() != ISD::SUB) {
        OtherUses.clear();
        break;
      }

      SDValue Op1 = Use.getUser()->getOperand((UI.getOperandNo() + 1) & 1);
      if (!isa<ConstantSDNode>(Op1)) {
        OtherUses.clear();
        break;
      }

      // FIXME: In some cases, we can be smarter about this.
      if (Op1.getValueType() != Offset.getValueType()) {
        OtherUses.clear();
        break;
      }

      OtherUses.push_back(Use.getUser());
    }

  if (Swapped)
    std::swap(BasePtr, Offset);

  // Now check for #3 and #4.
  bool RealUse = false;

  for (SDNode *Use : Ptr.getNode()->uses()) {
    if (Use == N)
      continue;
    if (SDNode::hasPredecessorHelper(Use, Visited, Worklist))
      return false;

    // If Ptr can be folded into the addressing mode of another use, then it's
    // not profitable to do this transformation.
    if (!canFoldInAddressingMode(Ptr.getNode(), Use, DAG, TLI))
      RealUse = true;
  }

  if (!RealUse)
    return false;
13912 
13913   SDValue Result;
13914   if (!IsMasked) {
13915     if (IsLoad)
13916       Result = DAG.getIndexedLoad(SDValue(N, 0), SDLoc(N), BasePtr, Offset, AM);
13917     else
13918       Result =
13919           DAG.getIndexedStore(SDValue(N, 0), SDLoc(N), BasePtr, Offset, AM);
13920   } else {
13921     if (IsLoad)
13922       Result = DAG.getIndexedMaskedLoad(SDValue(N, 0), SDLoc(N), BasePtr,
13923                                         Offset, AM);
13924     else
13925       Result = DAG.getIndexedMaskedStore(SDValue(N, 0), SDLoc(N), BasePtr,
13926                                          Offset, AM);
13927   }
13928   ++PreIndexedNodes;
13929   ++NodesCombined;
13930   LLVM_DEBUG(dbgs() << "\nReplacing.4 "; N->dump(&DAG); dbgs() << "\nWith: ";
13931              Result.getNode()->dump(&DAG); dbgs() << '\n');
13932   WorklistRemover DeadNodes(*this);
13933   if (IsLoad) {
13934     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0));
13935     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Result.getValue(2));
13936   } else {
13937     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(1));
13938   }
13939 
13940   // Finally, since the node is now dead, remove it from the graph.
13941   deleteAndRecombine(N);
13942 
13943   if (Swapped)
13944     std::swap(BasePtr, Offset);
13945 
13946   // Replace other uses of BasePtr that can be updated to use Ptr
13947   for (unsigned i = 0, e = OtherUses.size(); i != e; ++i) {
13948     unsigned OffsetIdx = 1;
13949     if (OtherUses[i]->getOperand(OffsetIdx).getNode() == BasePtr.getNode())
13950       OffsetIdx = 0;
13951     assert(OtherUses[i]->getOperand(!OffsetIdx).getNode() ==
13952            BasePtr.getNode() && "Expected BasePtr operand");
13953 
13954     // We need to replace ptr0 in the following expression:
13955     //   x0 * offset0 + y0 * ptr0 = t0
13956     // knowing that
13957     //   x1 * offset1 + y1 * ptr0 = t1 (the indexed load/store)
13958     //
13959     // where x0, x1, y0 and y1 in {-1, 1} are given by the types of the
13960     // indexed load/store and the expression that needs to be re-written.
13961     //
13962     // Therefore, we have:
13963     //   t0 = (x0 * offset0 - x1 * y0 * y1 * offset1) + (y0 * y1) * t1
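    //
    // Illustrative example (values chosen for this comment, not from the
    // original source): for a PRE_DEC indexed store t1 = ptr0 - 4
    // (x1 = -1, y1 = 1) and another use t0 = ptr0 - 8 (x0 = -1, y0 = 1),
    // the formula gives CNV = -8 + 4 = -4 and Opcode = ISD::ADD, so t0 is
    // rebuilt as t1 + (-4), which indeed equals ptr0 - 8.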
13964 
13965     ConstantSDNode *CN =
13966       cast<ConstantSDNode>(OtherUses[i]->getOperand(OffsetIdx));
13967     int X0, X1, Y0, Y1;
13968     const APInt &Offset0 = CN->getAPIntValue();
13969     APInt Offset1 = cast<ConstantSDNode>(Offset)->getAPIntValue();
13970 
13971     X0 = (OtherUses[i]->getOpcode() == ISD::SUB && OffsetIdx == 1) ? -1 : 1;
13972     Y0 = (OtherUses[i]->getOpcode() == ISD::SUB && OffsetIdx == 0) ? -1 : 1;
13973     X1 = (AM == ISD::PRE_DEC && !Swapped) ? -1 : 1;
13974     Y1 = (AM == ISD::PRE_DEC && Swapped) ? -1 : 1;
13975 
13976     unsigned Opcode = (Y0 * Y1 < 0) ? ISD::SUB : ISD::ADD;
13977 
13978     APInt CNV = Offset0;
13979     if (X0 < 0) CNV = -CNV;
13980     if (X1 * Y0 * Y1 < 0) CNV = CNV + Offset1;
13981     else CNV = CNV - Offset1;
13982 
13983     SDLoc DL(OtherUses[i]);
13984 
13985     // We can now generate the new expression.
13986     SDValue NewOp1 = DAG.getConstant(CNV, DL, CN->getValueType(0));
13987     SDValue NewOp2 = Result.getValue(IsLoad ? 1 : 0);
13988 
13989     SDValue NewUse = DAG.getNode(Opcode,
13990                                  DL,
13991                                  OtherUses[i]->getValueType(0), NewOp1, NewOp2);
13992     DAG.ReplaceAllUsesOfValueWith(SDValue(OtherUses[i], 0), NewUse);
13993     deleteAndRecombine(OtherUses[i]);
13994   }
13995 
13996   // Replace the uses of Ptr with uses of the updated base value.
13997   DAG.ReplaceAllUsesOfValueWith(Ptr, Result.getValue(IsLoad ? 1 : 0));
13998   deleteAndRecombine(Ptr.getNode());
13999   AddToWorklist(Result.getNode());
14000 
14001   return true;
14002 }
14003 
14004 /// Try to combine a load/store with an add/sub of the base pointer node into
14005 /// a post-indexed load/store. The transformation effectively folds the
14006 /// add/subtract into the new indexed load/store, and all uses of the original
14007 /// load/store are redirected to the new one.
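///
/// For example (illustrative pseudo-IR, not from the original source): on a
/// target with post-increment addressing,
///   x  = load p
///   p2 = add p, 4
/// can become a single post-indexed load producing both x and the updated
/// pointer p2.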
14008 bool DAGCombiner::CombineToPostIndexedLoadStore(SDNode *N) {
14009   if (Level < AfterLegalizeDAG)
14010     return false;
14011 
14012   bool IsLoad = true;
14013   bool IsMasked = false;
14014   SDValue Ptr;
14015   if (!getCombineLoadStoreParts(N, ISD::POST_INC, ISD::POST_DEC, IsLoad, IsMasked,
14016                                 Ptr, TLI))
14017     return false;
14018 
14019   if (Ptr.getNode()->hasOneUse())
14020     return false;
14021 
14022   for (SDNode *Op : Ptr.getNode()->uses()) {
14023     if (Op == N ||
14024         (Op->getOpcode() != ISD::ADD && Op->getOpcode() != ISD::SUB))
14025       continue;
14026 
14027     SDValue BasePtr;
14028     SDValue Offset;
14029     ISD::MemIndexedMode AM = ISD::UNINDEXED;
14030     if (TLI.getPostIndexedAddressParts(N, Op, BasePtr, Offset, AM, DAG)) {
14031       // Don't create an indexed load / store with zero offset.
14032       if (isNullConstant(Offset))
14033         continue;
14034 
14035       // Try turning it into a post-indexed load / store except when
14036       // 1) All uses are load / store ops that use it as base ptr (and
14037       //    it may be folded as addressing mode).
14038       // 2) Op must be independent of N, i.e. Op is neither a predecessor
14039       //    nor a successor of N. Otherwise, if Op is folded that would
14040       //    create a cycle.
14041 
14042       if (isa<FrameIndexSDNode>(BasePtr) || isa<RegisterSDNode>(BasePtr))
14043         continue;
14044 
14045       // Check for #1.
14046       bool TryNext = false;
14047       for (SDNode *Use : BasePtr.getNode()->uses()) {
14048         if (Use == Ptr.getNode())
14049           continue;
14050 
14051         // If all the uses are load / store addresses, then don't do the
14052         // transformation.
14053         if (Use->getOpcode() == ISD::ADD || Use->getOpcode() == ISD::SUB) {
14054           bool RealUse = false;
14055           for (SDNode *UseUse : Use->uses()) {
14056             if (!canFoldInAddressingMode(Use, UseUse, DAG, TLI))
14057               RealUse = true;
14058           }
14059 
14060           if (!RealUse) {
14061             TryNext = true;
14062             break;
14063           }
14064         }
14065       }
14066 
14067       if (TryNext)
14068         continue;
14069 
14070       // Check for #2.
14071       SmallPtrSet<const SDNode *, 32> Visited;
14072       SmallVector<const SDNode *, 8> Worklist;
14073       // Ptr is predecessor to both N and Op.
14074       Visited.insert(Ptr.getNode());
14075       Worklist.push_back(N);
14076       Worklist.push_back(Op);
14077       if (!SDNode::hasPredecessorHelper(N, Visited, Worklist) &&
14078           !SDNode::hasPredecessorHelper(Op, Visited, Worklist)) {
14079         SDValue Result;
14080         if (!IsMasked)
14081           Result = IsLoad ? DAG.getIndexedLoad(SDValue(N, 0), SDLoc(N), BasePtr,
14082                                                Offset, AM)
14083                           : DAG.getIndexedStore(SDValue(N, 0), SDLoc(N),
14084                                                 BasePtr, Offset, AM);
14085         else
14086           Result = IsLoad ? DAG.getIndexedMaskedLoad(SDValue(N, 0), SDLoc(N),
14087                                                      BasePtr, Offset, AM)
14088                           : DAG.getIndexedMaskedStore(SDValue(N, 0), SDLoc(N),
14089                                                       BasePtr, Offset, AM);
14090         ++PostIndexedNodes;
14091         ++NodesCombined;
14092         LLVM_DEBUG(dbgs() << "\nReplacing.5 "; N->dump(&DAG);
14093                    dbgs() << "\nWith: "; Result.getNode()->dump(&DAG);
14094                    dbgs() << '\n');
14095         WorklistRemover DeadNodes(*this);
14096         if (IsLoad) {
14097           DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0));
14098           DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Result.getValue(2));
14099         } else {
14100           DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(1));
14101         }
14102 
14103         // Finally, since the node is now dead, remove it from the graph.
14104         deleteAndRecombine(N);
14105 
14106         // Replace the uses of Use with uses of the updated base value.
14107         DAG.ReplaceAllUsesOfValueWith(SDValue(Op, 0),
14108                                       Result.getValue(IsLoad ? 1 : 0));
14109         deleteAndRecombine(Op);
14110         return true;
14111       }
14112     }
14113   }
14114 
14115   return false;
14116 }
14117 
14118 /// Return the base-pointer arithmetic from an indexed \p LD.
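/// For example (illustrative values): a PRE_INC load with base BP and
/// increment 4 yields (add BP, 4), while a PRE_DEC/POST_DEC load yields
/// (sub BP, 4).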
14119 SDValue DAGCombiner::SplitIndexingFromLoad(LoadSDNode *LD) {
14120   ISD::MemIndexedMode AM = LD->getAddressingMode();
14121   assert(AM != ISD::UNINDEXED);
14122   SDValue BP = LD->getOperand(1);
14123   SDValue Inc = LD->getOperand(2);
14124 
14125   // Some backends use TargetConstants for load offsets, but don't expect
14126   // TargetConstants in general ADD nodes. We can convert these constants into
14127   // regular Constants (if the constant is not opaque).
14128   assert((Inc.getOpcode() != ISD::TargetConstant ||
14129           !cast<ConstantSDNode>(Inc)->isOpaque()) &&
14130          "Cannot split out indexing using opaque target constants");
14131   if (Inc.getOpcode() == ISD::TargetConstant) {
14132     ConstantSDNode *ConstInc = cast<ConstantSDNode>(Inc);
14133     Inc = DAG.getConstant(*ConstInc->getConstantIntValue(), SDLoc(Inc),
14134                           ConstInc->getValueType(0));
14135   }
14136 
14137   unsigned Opc =
14138       (AM == ISD::PRE_INC || AM == ISD::POST_INC ? ISD::ADD : ISD::SUB);
14139   return DAG.getNode(Opc, SDLoc(LD), BP.getSimpleValueType(), BP, Inc);
14140 }
14141 
14142 static inline int numVectorEltsOrZero(EVT T) {
14143   return T.isVector() ? T.getVectorNumElements() : 0;
14144 }
14145 
14146 bool DAGCombiner::getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val) {
14147   Val = ST->getValue();
14148   EVT STType = Val.getValueType();
14149   EVT STMemType = ST->getMemoryVT();
14150   if (STType == STMemType)
14151     return true;
14152   if (isTypeLegal(STMemType))
14153     return false; // fail.
14154   if (STType.isFloatingPoint() && STMemType.isFloatingPoint() &&
14155       TLI.isOperationLegal(ISD::FTRUNC, STMemType)) {
14156     Val = DAG.getNode(ISD::FTRUNC, SDLoc(ST), STMemType, Val);
14157     return true;
14158   }
14159   if (numVectorEltsOrZero(STType) == numVectorEltsOrZero(STMemType) &&
14160       STType.isInteger() && STMemType.isInteger()) {
14161     Val = DAG.getNode(ISD::TRUNCATE, SDLoc(ST), STMemType, Val);
14162     return true;
14163   }
14164   if (STType.getSizeInBits() == STMemType.getSizeInBits()) {
14165     Val = DAG.getBitcast(STMemType, Val);
14166     return true;
14167   }
14168   return false; // fail.
14169 }
14170 
14171 bool DAGCombiner::extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val) {
14172   EVT LDMemType = LD->getMemoryVT();
14173   EVT LDType = LD->getValueType(0);
14174   assert(Val.getValueType() == LDMemType &&
14175          "Attempting to extend value of non-matching type");
14176   if (LDType == LDMemType)
14177     return true;
14178   if (LDMemType.isInteger() && LDType.isInteger()) {
14179     switch (LD->getExtensionType()) {
14180     case ISD::NON_EXTLOAD:
14181       Val = DAG.getBitcast(LDType, Val);
14182       return true;
14183     case ISD::EXTLOAD:
14184       Val = DAG.getNode(ISD::ANY_EXTEND, SDLoc(LD), LDType, Val);
14185       return true;
14186     case ISD::SEXTLOAD:
14187       Val = DAG.getNode(ISD::SIGN_EXTEND, SDLoc(LD), LDType, Val);
14188       return true;
14189     case ISD::ZEXTLOAD:
14190       Val = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(LD), LDType, Val);
14191       return true;
14192     }
14193   }
14194   return false;
14195 }
14196 
14197 SDValue DAGCombiner::ForwardStoreValueToDirectLoad(LoadSDNode *LD) {
14198   if (OptLevel == CodeGenOpt::None || !LD->isSimple())
14199     return SDValue();
14200   SDValue Chain = LD->getOperand(0);
14201   StoreSDNode *ST = dyn_cast<StoreSDNode>(Chain.getNode());
14202   // TODO: Relax this restriction for unordered atomics (see D66309)
14203   if (!ST || !ST->isSimple())
14204     return SDValue();
14205 
14206   EVT LDType = LD->getValueType(0);
14207   EVT LDMemType = LD->getMemoryVT();
14208   EVT STMemType = ST->getMemoryVT();
14209   EVT STType = ST->getValue().getValueType();
14210 
14211   BaseIndexOffset BasePtrLD = BaseIndexOffset::match(LD, DAG);
14212   BaseIndexOffset BasePtrST = BaseIndexOffset::match(ST, DAG);
14213   int64_t Offset;
14214   if (!BasePtrST.equalBaseIndex(BasePtrLD, DAG, Offset))
14215     return SDValue();
14216 
14217   // Normalize for Endianness. After this Offset=0 will denote that the least
14218   // significant bit in the loaded value maps to the least significant bit in
14219   // the stored value. With Offset=n (for n > 0) the loaded value starts at the
14220   // n:th least significant byte of the stored value.
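  //
  // Illustrative example (hypothetical sizes): for a 32-bit store and a
  // 16-bit load of the same address on a big-endian target, the bytes at
  // the base address are the most significant ones, so the normalized
  // Offset becomes (32 - 16) / 8 - 0 = 2: the loaded value starts at the
  // 2nd least significant byte of the stored value.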
14221   if (DAG.getDataLayout().isBigEndian())
14222     Offset = ((int64_t)STMemType.getStoreSizeInBits() -
14223               (int64_t)LDMemType.getStoreSizeInBits()) / 8 - Offset;
14224 
14225   // Check that the stored value covers all bits that are loaded.
14226   bool STCoversLD =
14227       (Offset >= 0) &&
14228       (Offset * 8 + LDMemType.getSizeInBits() <= STMemType.getSizeInBits());
14229 
14230   auto ReplaceLd = [&](LoadSDNode *LD, SDValue Val, SDValue Chain) -> SDValue {
14231     if (LD->isIndexed()) {
14232       // Cannot handle opaque target constants and we must respect the user's
14233       // request not to split indexes from loads.
14234       if (!canSplitIdx(LD))
14235         return SDValue();
14236       SDValue Idx = SplitIndexingFromLoad(LD);
14237       SDValue Ops[] = {Val, Idx, Chain};
14238       return CombineTo(LD, Ops, 3);
14239     }
14240     return CombineTo(LD, Val, Chain);
14241   };
14242 
14243   if (!STCoversLD)
14244     return SDValue();
14245 
14246   // Memory as copy space (potentially masked).
14247   if (Offset == 0 && LDType == STType && STMemType == LDMemType) {
14248     // Simple case: Direct non-truncating forwarding
14249     if (LDType.getSizeInBits() == LDMemType.getSizeInBits())
14250       return ReplaceLd(LD, ST->getValue(), Chain);
14251     // Can we model the truncate and extension with an and mask?
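    // For example (an illustrative case, not from the original source): a
    // truncating i32->i16 store read back by an i16 zextload can forward the
    // stored i32 value masked with 0xFFFF; a sextload cannot be modeled this
    // way, hence its exclusion below.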
14252     if (STType.isInteger() && LDMemType.isInteger() && !STType.isVector() &&
14253         !LDMemType.isVector() && LD->getExtensionType() != ISD::SEXTLOAD) {
14254       // Mask to size of LDMemType
14255       auto Mask =
14256           DAG.getConstant(APInt::getLowBitsSet(STType.getSizeInBits(),
14257                                                STMemType.getSizeInBits()),
14258                           SDLoc(ST), STType);
14259       auto Val = DAG.getNode(ISD::AND, SDLoc(LD), LDType, ST->getValue(), Mask);
14260       return ReplaceLd(LD, Val, Chain);
14261     }
14262   }
14263 
14264   // TODO: Deal with nonzero offset.
14265   if (LD->getBasePtr().isUndef() || Offset != 0)
14266     return SDValue();
14267   // Model necessary truncations / extensions.
14268   SDValue Val;
14269   // Truncate the value to the stored memory size.
14270   do {
14271     if (!getTruncatedStoreValue(ST, Val))
14272       continue;
14273     if (!isTypeLegal(LDMemType))
14274       continue;
14275     if (STMemType != LDMemType) {
14276       // TODO: Support vectors? This requires extract_subvector/bitcast.
14277       if (!STMemType.isVector() && !LDMemType.isVector() &&
14278           STMemType.isInteger() && LDMemType.isInteger())
14279         Val = DAG.getNode(ISD::TRUNCATE, SDLoc(LD), LDMemType, Val);
14280       else
14281         continue;
14282     }
14283     if (!extendLoadedValueToExtension(LD, Val))
14284       continue;
14285     return ReplaceLd(LD, Val, Chain);
14286   } while (false);
14287 
14288   // On failure, cleanup dead nodes we may have created.
14289   if (Val->use_empty())
14290     deleteAndRecombine(Val.getNode());
14291   return SDValue();
14292 }
14293 
14294 SDValue DAGCombiner::visitLOAD(SDNode *N) {
14295   LoadSDNode *LD  = cast<LoadSDNode>(N);
14296   SDValue Chain = LD->getChain();
14297   SDValue Ptr   = LD->getBasePtr();
14298 
14299   // If load is not volatile and there are no uses of the loaded value (and
14300   // the updated indexed value in case of indexed loads), change uses of the
14301   // chain value into uses of the chain input (i.e. delete the dead load).
14302   // TODO: Allow this for unordered atomics (see D66309)
14303   if (LD->isSimple()) {
14304     if (N->getValueType(1) == MVT::Other) {
14305       // Unindexed loads.
14306       if (!N->hasAnyUseOfValue(0)) {
14307         // It's not safe to use the two value CombineTo variant here. e.g.
14308         // v1, chain2 = load chain1, loc
14309         // v2, chain3 = load chain2, loc
14310         // v3         = add v2, c
14311         // Now we replace use of chain2 with chain1.  This makes the second load
14312         // isomorphic to the one we are deleting, and thus makes this load live.
14313         LLVM_DEBUG(dbgs() << "\nReplacing.6 "; N->dump(&DAG);
14314                    dbgs() << "\nWith chain: "; Chain.getNode()->dump(&DAG);
14315                    dbgs() << "\n");
14316         WorklistRemover DeadNodes(*this);
14317         DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain);
14318         AddUsersToWorklist(Chain.getNode());
14319         if (N->use_empty())
14320           deleteAndRecombine(N);
14321 
14322         return SDValue(N, 0);   // Return N so it doesn't get rechecked!
14323       }
14324     } else {
14325       // Indexed loads.
14326       assert(N->getValueType(2) == MVT::Other && "Malformed indexed loads?");
14327 
14328       // If this load has an opaque TargetConstant offset, then we cannot split
14329       // the indexing into an add/sub directly (that TargetConstant may not be
14330       // valid for a different type of node, and we cannot convert an opaque
14331       // target constant into a regular constant).
14332       bool CanSplitIdx = canSplitIdx(LD);
14333 
14334       if (!N->hasAnyUseOfValue(0) && (CanSplitIdx || !N->hasAnyUseOfValue(1))) {
14335         SDValue Undef = DAG.getUNDEF(N->getValueType(0));
14336         SDValue Index;
14337         if (N->hasAnyUseOfValue(1) && CanSplitIdx) {
14338           Index = SplitIndexingFromLoad(LD);
14339           // Try to fold the base pointer arithmetic into subsequent loads and
14340           // stores.
14341           AddUsersToWorklist(N);
14342         } else
14343           Index = DAG.getUNDEF(N->getValueType(1));
14344         LLVM_DEBUG(dbgs() << "\nReplacing.7 "; N->dump(&DAG);
14345                    dbgs() << "\nWith: "; Undef.getNode()->dump(&DAG);
14346                    dbgs() << " and 2 other values\n");
14347         WorklistRemover DeadNodes(*this);
14348         DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Undef);
14349         DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Index);
14350         DAG.ReplaceAllUsesOfValueWith(SDValue(N, 2), Chain);
14351         deleteAndRecombine(N);
14352         return SDValue(N, 0);   // Return N so it doesn't get rechecked!
14353       }
14354     }
14355   }
14356 
14357   // If this load is directly stored, replace the load value with the stored
14358   // value.
14359   if (auto V = ForwardStoreValueToDirectLoad(LD))
14360     return V;
14361 
14362   // Try to infer better alignment information than the load already has.
14363   if (OptLevel != CodeGenOpt::None && LD->isUnindexed() && !LD->isAtomic()) {
14364     if (unsigned Align = DAG.InferPtrAlignment(Ptr)) {
14365       if (Align > LD->getAlignment() && LD->getSrcValueOffset() % Align == 0) {
14366         SDValue NewLoad = DAG.getExtLoad(
14367             LD->getExtensionType(), SDLoc(N), LD->getValueType(0), Chain, Ptr,
14368             LD->getPointerInfo(), LD->getMemoryVT(), Align,
14369             LD->getMemOperand()->getFlags(), LD->getAAInfo());
14370         // NewLoad will always be N as we are only refining the alignment
14371         assert(NewLoad.getNode() == N);
14372         (void)NewLoad;
14373       }
14374     }
14375   }
14376 
14377   if (LD->isUnindexed()) {
14378     // Walk up chain skipping non-aliasing memory nodes.
14379     SDValue BetterChain = FindBetterChain(LD, Chain);
14380 
14381     // If there is a better chain.
14382     if (Chain != BetterChain) {
14383       SDValue ReplLoad;
14384 
14385       // Replace the chain to avoid dependency.
14386       if (LD->getExtensionType() == ISD::NON_EXTLOAD) {
14387         ReplLoad = DAG.getLoad(N->getValueType(0), SDLoc(LD),
14388                                BetterChain, Ptr, LD->getMemOperand());
14389       } else {
14390         ReplLoad = DAG.getExtLoad(LD->getExtensionType(), SDLoc(LD),
14391                                   LD->getValueType(0),
14392                                   BetterChain, Ptr, LD->getMemoryVT(),
14393                                   LD->getMemOperand());
14394       }
14395 
14396       // Create token factor to keep old chain connected.
14397       SDValue Token = DAG.getNode(ISD::TokenFactor, SDLoc(N),
14398                                   MVT::Other, Chain, ReplLoad.getValue(1));
14399 
14400       // Replace uses with load result and token factor
14401       return CombineTo(N, ReplLoad.getValue(0), Token);
14402     }
14403   }
14404 
14405   // Try transforming N to an indexed load.
14406   if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
14407     return SDValue(N, 0);
14408 
14409   // Try to slice up N to more direct loads if the slices are mapped to
14410   // different register banks or pairing can take place.
14411   if (SliceUpLoad(N))
14412     return SDValue(N, 0);
14413 
14414   return SDValue();
14415 }
14416 
14417 namespace {
14418 
14419 /// Helper structure used to slice a load in smaller loads.
14420 /// Basically a slice is obtained from the following sequence:
14421 /// Origin = load Ty1, Base
14422 /// Shift = srl Ty1 Origin, CstTy Amount
14423 /// Inst = trunc Shift to Ty2
14424 ///
14425 /// Then, it will be rewritten into:
14426 /// Slice = load SliceTy, Base + SliceOffset
14427 /// [Inst = zext Slice to Ty2], only if SliceTy <> Ty2
14428 ///
14429 /// SliceTy is deduced from the number of bits that are actually used to
14430 /// build Inst.
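///
/// For illustration (hypothetical IR, not from the original source):
///   Origin = load i32, Base
///   A = trunc i32 Origin to i8                ; Shift = 0
///   B = trunc (srl i32 Origin, 16) to i8      ; Shift = 16
/// can be sliced into two i8 loads at Base + 0 and Base + 2 on a
/// little-endian target.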
14431 struct LoadedSlice {
14432   /// Helper structure used to compute the cost of a slice.
14433   struct Cost {
14434     /// Are we optimizing for code size.
14435     bool ForCodeSize = false;
14436 
14437     /// Various costs.
14438     unsigned Loads = 0;
14439     unsigned Truncates = 0;
14440     unsigned CrossRegisterBanksCopies = 0;
14441     unsigned ZExts = 0;
14442     unsigned Shift = 0;
14443 
14444     explicit Cost(bool ForCodeSize) : ForCodeSize(ForCodeSize) {}
14445 
14446     /// Get the cost of one isolated slice.
14447     Cost(const LoadedSlice &LS, bool ForCodeSize)
14448         : ForCodeSize(ForCodeSize), Loads(1) {
14449       EVT TruncType = LS.Inst->getValueType(0);
14450       EVT LoadedType = LS.getLoadedType();
14451       if (TruncType != LoadedType &&
14452           !LS.DAG->getTargetLoweringInfo().isZExtFree(LoadedType, TruncType))
14453         ZExts = 1;
14454     }
14455 
14456     /// Account for slicing gain in the current cost.
14457     /// Slicing provides a few gains, like removing a shift or a
14458     /// truncate. This method allows the cost of the original
14459     /// load to be grown by the gain from this slice.
14460     void addSliceGain(const LoadedSlice &LS) {
14461       // Each slice saves a truncate.
14462       const TargetLowering &TLI = LS.DAG->getTargetLoweringInfo();
14463       if (!TLI.isTruncateFree(LS.Inst->getOperand(0).getValueType(),
14464                               LS.Inst->getValueType(0)))
14465         ++Truncates;
14466       // If there is a shift amount, this slice gets rid of it.
14467       if (LS.Shift)
14468         ++Shift;
14469       // If this slice can merge a cross register bank copy, account for it.
14470       if (LS.canMergeExpensiveCrossRegisterBankCopy())
14471         ++CrossRegisterBanksCopies;
14472     }
14473 
14474     Cost &operator+=(const Cost &RHS) {
14475       Loads += RHS.Loads;
14476       Truncates += RHS.Truncates;
14477       CrossRegisterBanksCopies += RHS.CrossRegisterBanksCopies;
14478       ZExts += RHS.ZExts;
14479       Shift += RHS.Shift;
14480       return *this;
14481     }
14482 
14483     bool operator==(const Cost &RHS) const {
14484       return Loads == RHS.Loads && Truncates == RHS.Truncates &&
14485              CrossRegisterBanksCopies == RHS.CrossRegisterBanksCopies &&
14486              ZExts == RHS.ZExts && Shift == RHS.Shift;
14487     }
14488 
14489     bool operator!=(const Cost &RHS) const { return !(*this == RHS); }
14490 
14491     bool operator<(const Cost &RHS) const {
14492       // Assume cross register banks copies are as expensive as loads.
14493       // FIXME: Do we want some more target hooks?
14494       unsigned ExpensiveOpsLHS = Loads + CrossRegisterBanksCopies;
14495       unsigned ExpensiveOpsRHS = RHS.Loads + RHS.CrossRegisterBanksCopies;
14496       // Unless we are optimizing for code size, consider the
14497       // expensive operation first.
14498       if (!ForCodeSize && ExpensiveOpsLHS != ExpensiveOpsRHS)
14499         return ExpensiveOpsLHS < ExpensiveOpsRHS;
14500       return (Truncates + ZExts + Shift + ExpensiveOpsLHS) <
14501              (RHS.Truncates + RHS.ZExts + RHS.Shift + ExpensiveOpsRHS);
14502     }
14503 
14504     bool operator>(const Cost &RHS) const { return RHS < *this; }
14505 
14506     bool operator<=(const Cost &RHS) const { return !(RHS < *this); }
14507 
14508     bool operator>=(const Cost &RHS) const { return !(*this < RHS); }
14509   };
14510 
14511   // The last instruction that represents the slice. This should be a
14512   // truncate instruction.
14513   SDNode *Inst;
14514 
14515   // The original load instruction.
14516   LoadSDNode *Origin;
14517 
14518   // The right shift amount in bits from the original load.
14519   unsigned Shift;
14520 
14521   // The DAG from which Origin comes.
14522   // This is used to get some contextual information about legal types, etc.
14523   SelectionDAG *DAG;
14524 
14525   LoadedSlice(SDNode *Inst = nullptr, LoadSDNode *Origin = nullptr,
14526               unsigned Shift = 0, SelectionDAG *DAG = nullptr)
14527       : Inst(Inst), Origin(Origin), Shift(Shift), DAG(DAG) {}
14528 
14529   /// Get the bits used in a chunk of bits \p BitWidth large.
14530   /// \return Result is \p BitWidth bits wide and has used bits set to 1 and
14531   ///         unused bits set to 0.
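  ///
  /// For example (illustrative values): with a 32-bit original load, an i8
  /// truncate and Shift == 16, the result is 0x00FF0000.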
14532   APInt getUsedBits() const {
14533     // Reproduce the trunc(lshr) sequence:
14534     // - Start from the truncated value.
14535     // - Zero extend to the desired bit width.
14536     // - Shift left.
14537     assert(Origin && "No original load to compare against.");
14538     unsigned BitWidth = Origin->getValueSizeInBits(0);
14539     assert(Inst && "This slice is not bound to an instruction");
14540     assert(Inst->getValueSizeInBits(0) <= BitWidth &&
14541            "Extracted slice is bigger than the whole type!");
14542     APInt UsedBits(Inst->getValueSizeInBits(0), 0);
14543     UsedBits.setAllBits();
14544     UsedBits = UsedBits.zext(BitWidth);
14545     UsedBits <<= Shift;
14546     return UsedBits;
14547   }
14548 
14549   /// Get the size of the slice to be loaded in bytes.
14550   unsigned getLoadedSize() const {
14551     unsigned SliceSize = getUsedBits().countPopulation();
14552     assert(!(SliceSize & 0x7) && "Size is not a multiple of a byte.");
14553     return SliceSize / 8;
14554   }
14555 
14556   /// Get the type that will be loaded for this slice.
14557   /// Note: This may not be the final type for the slice.
14558   EVT getLoadedType() const {
14559     assert(DAG && "Missing context");
14560     LLVMContext &Ctxt = *DAG->getContext();
14561     return EVT::getIntegerVT(Ctxt, getLoadedSize() * 8);
14562   }
14563 
14564   /// Get the alignment of the load used for this slice.
14565   unsigned getAlignment() const {
14566     unsigned Alignment = Origin->getAlignment();
14567     uint64_t Offset = getOffsetFromBase();
14568     if (Offset != 0)
14569       Alignment = MinAlign(Alignment, Alignment + Offset);
14570     return Alignment;
14571   }
14572 
14573   /// Check if this slice can be rewritten with legal operations.
14574   bool isLegal() const {
14575     // An invalid slice is not legal.
14576     if (!Origin || !Inst || !DAG)
14577       return false;
14578 
14579     // Offsets are for indexed loads only; we do not handle that.
14580     if (!Origin->getOffset().isUndef())
14581       return false;
14582 
14583     const TargetLowering &TLI = DAG->getTargetLoweringInfo();
14584 
14585     // Check that the type is legal.
14586     EVT SliceType = getLoadedType();
14587     if (!TLI.isTypeLegal(SliceType))
14588       return false;
14589 
14590     // Check that the load is legal for this type.
14591     if (!TLI.isOperationLegal(ISD::LOAD, SliceType))
14592       return false;
14593 
14594     // Check that the offset can be computed.
14595     // 1. Check its type.
14596     EVT PtrType = Origin->getBasePtr().getValueType();
14597     if (PtrType == MVT::Untyped || PtrType.isExtended())
14598       return false;
14599 
14600     // 2. Check that it fits in the immediate.
14601     if (!TLI.isLegalAddImmediate(getOffsetFromBase()))
14602       return false;
14603 
14604     // 3. Check that the computation is legal.
14605     if (!TLI.isOperationLegal(ISD::ADD, PtrType))
14606       return false;
14607 
14608     // Check that the zext is legal if it needs one.
14609     EVT TruncateType = Inst->getValueType(0);
14610     if (TruncateType != SliceType &&
14611         !TLI.isOperationLegal(ISD::ZERO_EXTEND, TruncateType))
14612       return false;
14613 
14614     return true;
14615   }
14616 
14617   /// Get the offset in bytes of this slice in the original chunk of
14618   /// bits.
14619   /// \pre DAG != nullptr.
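  ///
  /// For example (illustrative values): slicing byte 2 (Shift == 16) out of
  /// a 4-byte load yields offset 2 on little-endian targets and
  /// 4 - 2 - 1 == 1 on big-endian targets.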
14620   uint64_t getOffsetFromBase() const {
14621     assert(DAG && "Missing context.");
14622     bool IsBigEndian = DAG->getDataLayout().isBigEndian();
14623     assert(!(Shift & 0x7) && "Shifts not aligned on Bytes are not supported.");
14624     uint64_t Offset = Shift / 8;
14625     unsigned TySizeInBytes = Origin->getValueSizeInBits(0) / 8;
14626     assert(!(Origin->getValueSizeInBits(0) & 0x7) &&
14627            "The size of the original loaded type is not a multiple of a"
14628            " byte.");
14629     // If Offset is bigger than TySizeInBytes, it means we are loading all
14630     // zeros. This should have been optimized away earlier in the process.
14631     assert(TySizeInBytes > Offset &&
14632            "Invalid shift amount for given loaded size");
14633     if (IsBigEndian)
14634       Offset = TySizeInBytes - Offset - getLoadedSize();
14635     return Offset;
14636   }
14637 
14638   /// Generate the sequence of instructions to load the slice
14639   /// represented by this object and redirect the uses of this slice to
14640   /// this new sequence of instructions.
14641   /// \pre this->Inst && this->Origin are valid Instructions and this
14642   /// object passed the legal check: LoadedSlice::isLegal returned true.
14643   /// \return The last instruction of the sequence used to load the slice.
14644   SDValue loadSlice() const {
14645     assert(Inst && Origin && "Unable to replace a non-existing slice.");
14646     const SDValue &OldBaseAddr = Origin->getBasePtr();
14647     SDValue BaseAddr = OldBaseAddr;
14648     // Get the offset in that chunk of bytes w.r.t. the endianness.
14649     int64_t Offset = static_cast<int64_t>(getOffsetFromBase());
14650     assert(Offset >= 0 && "Offset too big to fit in int64_t!");
14651     if (Offset) {
14652       // BaseAddr = BaseAddr + Offset.
14653       EVT ArithType = BaseAddr.getValueType();
14654       SDLoc DL(Origin);
14655       BaseAddr = DAG->getNode(ISD::ADD, DL, ArithType, BaseAddr,
14656                               DAG->getConstant(Offset, DL, ArithType));
14657     }
14658 
14659     // Create the type of the loaded slice according to its size.
14660     EVT SliceType = getLoadedType();
14661 
14662     // Create the load for the slice.
14663     SDValue LastInst =
14664         DAG->getLoad(SliceType, SDLoc(Origin), Origin->getChain(), BaseAddr,
14665                      Origin->getPointerInfo().getWithOffset(Offset),
14666                      getAlignment(), Origin->getMemOperand()->getFlags());
14667     // If the final type is not the same as the loaded type, this means that
14668     // we have to pad with zero. Create a zero extend for that.
14669     EVT FinalType = Inst->getValueType(0);
14670     if (SliceType != FinalType)
14671       LastInst =
14672           DAG->getNode(ISD::ZERO_EXTEND, SDLoc(LastInst), FinalType, LastInst);
14673     return LastInst;
14674   }
14675 
14676   /// Check if this slice can be merged with an expensive cross register
14677   /// bank copy. E.g.,
14678   /// i = load i32
14679   /// f = bitcast i32 i to float
14680   bool canMergeExpensiveCrossRegisterBankCopy() const {
14681     if (!Inst || !Inst->hasOneUse())
14682       return false;
14683     SDNode *Use = *Inst->use_begin();
14684     if (Use->getOpcode() != ISD::BITCAST)
14685       return false;
14686     assert(DAG && "Missing context");
14687     const TargetLowering &TLI = DAG->getTargetLoweringInfo();
14688     EVT ResVT = Use->getValueType(0);
14689     const TargetRegisterClass *ResRC =
14690         TLI.getRegClassFor(ResVT.getSimpleVT(), Use->isDivergent());
14691     const TargetRegisterClass *ArgRC =
14692         TLI.getRegClassFor(Use->getOperand(0).getValueType().getSimpleVT(),
14693                            Use->getOperand(0)->isDivergent());
14694     if (ArgRC == ResRC || !TLI.isOperationLegal(ISD::LOAD, ResVT))
14695       return false;
14696 
14697     // At this point, we know that we perform a cross-register-bank copy.
14698     // Check if it is expensive.
14699     const TargetRegisterInfo *TRI = DAG->getSubtarget().getRegisterInfo();
14700     // Assume bitcasts are cheap, unless both register classes do not
14701     // explicitly share a common sub class.
14702     if (!TRI || TRI->getCommonSubClass(ArgRC, ResRC))
14703       return false;
14704 
14705     // Check if it will be merged with the load.
14706     // 1. Check the alignment constraint.
14707     unsigned RequiredAlignment = DAG->getDataLayout().getABITypeAlignment(
14708         ResVT.getTypeForEVT(*DAG->getContext()));
14709 
14710     if (RequiredAlignment > getAlignment())
14711       return false;
14712 
14713     // 2. Check that the load is a legal operation for that type.
14714     if (!TLI.isOperationLegal(ISD::LOAD, ResVT))
14715       return false;
14716 
14717     // 3. Check that we do not have a zext in the way.
14718     if (Inst->getValueType(0) != getLoadedType())
14719       return false;
14720 
14721     return true;
14722   }
14723 };
14724 
14725 } // end anonymous namespace
14726 
14727 /// Check that all bits set in \p UsedBits form a dense region, i.e.,
14728 /// \p UsedBits looks like 0..0 1..1 0..0.
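/// For example (illustrative values): 0x00FF0000 is dense, while 0x00FF00FF
/// is not, because of the hole between the two used byte ranges.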
14729 static bool areUsedBitsDense(const APInt &UsedBits) {
14730   // If all the bits are one, this is dense!
14731   if (UsedBits.isAllOnesValue())
14732     return true;
14733 
14734   // Get rid of the unused bits on the right.
14735   APInt NarrowedUsedBits = UsedBits.lshr(UsedBits.countTrailingZeros());
14736   // Get rid of the unused bits on the left.
14737   if (NarrowedUsedBits.countLeadingZeros())
14738     NarrowedUsedBits = NarrowedUsedBits.trunc(NarrowedUsedBits.getActiveBits());
14739   // Check that the chunk of bits is completely used.
14740   return NarrowedUsedBits.isAllOnesValue();
14741 }
14742 
14743 /// Check whether or not \p First and \p Second are next to each other
14744 /// in memory. This means that there is no hole between the bits loaded
14745 /// by \p First and the bits loaded by \p Second.
14746 static bool areSlicesNextToEachOther(const LoadedSlice &First,
14747                                      const LoadedSlice &Second) {
14748   assert(First.Origin == Second.Origin && First.Origin &&
14749          "Unable to match different memory origins.");
14750   APInt UsedBits = First.getUsedBits();
14751   assert((UsedBits & Second.getUsedBits()) == 0 &&
14752          "Slices are not supposed to overlap.");
14753   UsedBits |= Second.getUsedBits();
14754   return areUsedBitsDense(UsedBits);
14755 }
14756 
14757 /// Adjust the \p GlobalLSCost according to the target
14758 /// pairing capabilities and the layout of the slices.
14759 /// \pre \p GlobalLSCost should account for at least as many loads as
14760 /// there are slices in \p LoadedSlices.
14761 static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices,
14762                                  LoadedSlice::Cost &GlobalLSCost) {
14763   unsigned NumberOfSlices = LoadedSlices.size();
14764   // If there are fewer than 2 elements, no pairing is possible.
14765   if (NumberOfSlices < 2)
14766     return;
14767 
14768   // Sort the slices so that elements that are likely to be next to each
14769   // other in memory are next to each other in the list.
14770   llvm::sort(LoadedSlices, [](const LoadedSlice &LHS, const LoadedSlice &RHS) {
14771     assert(LHS.Origin == RHS.Origin && "Different bases not implemented.");
14772     return LHS.getOffsetFromBase() < RHS.getOffsetFromBase();
14773   });
14774   const TargetLowering &TLI = LoadedSlices[0].DAG->getTargetLoweringInfo();
14775   // First (resp. Second) is the first (resp. second) potential candidate
14776   // to be placed in a paired load.
14777   const LoadedSlice *First = nullptr;
14778   const LoadedSlice *Second = nullptr;
14779   for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice,
14780                 // Set the beginning of the pair.
14781                                                            First = Second) {
14782     Second = &LoadedSlices[CurrSlice];
14783 
14784     // If First is NULL, it means we start a new pair.
14785     // Get to the next slice.
14786     if (!First)
14787       continue;
14788 
14789     EVT LoadedType = First->getLoadedType();
14790 
14791     // If the types of the slices are different, we cannot pair them.
14792     if (LoadedType != Second->getLoadedType())
14793       continue;
14794 
14795     // Check if the target supplies paired loads for this type.
14796     unsigned RequiredAlignment = 0;
14797     if (!TLI.hasPairedLoad(LoadedType, RequiredAlignment)) {
14798       // Move to the next pair; this type is hopeless.
14799       Second = nullptr;
14800       continue;
14801     }
14802     // Check if we meet the alignment requirement.
14803     if (RequiredAlignment > First->getAlignment())
14804       continue;
14805 
14806     // Check that both loads are next to each other in memory.
14807     if (!areSlicesNextToEachOther(*First, *Second))
14808       continue;
14809 
14810     assert(GlobalLSCost.Loads > 0 && "We save more loads than we created!");
14811     --GlobalLSCost.Loads;
14812     // Move to the next pair.
14813     Second = nullptr;
14814   }
14815 }
14816 
14817 /// Check the profitability of all involved LoadedSlice.
14818 /// Currently, it is considered profitable if there are exactly two
14819 /// involved slices (1) which are (2) next to each other in memory, and
14820 /// whose cost (\see LoadedSlice::Cost) is smaller than the original load (3).
14821 ///
14822 /// Note: The order of the elements in \p LoadedSlices may be modified, but not
14823 /// the elements themselves.
14824 ///
14825 /// FIXME: When the cost model will be mature enough, we can relax
14826 /// constraints (1) and (2).
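///
/// Illustrative example (hypothetical numbers): slicing one i32 load into two
/// i8 slices costs two loads, adjusted down to one if the target can pair
/// them, whereas keeping the original load costs one load plus the two
/// truncates and one shift that slicing would remove; in that case slicing
/// wins the comparison below.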
14827 static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices,
14828                                 const APInt &UsedBits, bool ForCodeSize) {
14829   unsigned NumberOfSlices = LoadedSlices.size();
14830   if (StressLoadSlicing)
14831     return NumberOfSlices > 1;
14832 
14833   // Check (1).
14834   if (NumberOfSlices != 2)
14835     return false;
14836 
14837   // Check (2).
14838   if (!areUsedBitsDense(UsedBits))
14839     return false;
14840 
14841   // Check (3).
14842   LoadedSlice::Cost OrigCost(ForCodeSize), GlobalSlicingCost(ForCodeSize);
14843   // The original code has one big load.
14844   OrigCost.Loads = 1;
14845   for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice) {
14846     const LoadedSlice &LS = LoadedSlices[CurrSlice];
14847     // Accumulate the cost of all the slices.
14848     LoadedSlice::Cost SliceCost(LS, ForCodeSize);
14849     GlobalSlicingCost += SliceCost;
14850 
14851     // Account as cost in the original configuration the gain obtained
14852     // with the current slices.
14853     OrigCost.addSliceGain(LS);
14854   }
14855 
14856   // If the target supports paired load, adjust the cost accordingly.
14857   adjustCostForPairing(LoadedSlices, GlobalSlicingCost);
14858   return OrigCost > GlobalSlicingCost;
14859 }
14860 
14861 /// If the given load, \p N, is used only by trunc or trunc(lshr)
14862 /// operations, split it into the various pieces being extracted.
14863 ///
14864 /// This sort of thing is introduced by SROA.
14865 /// This slicing takes care not to insert overlapping loads.
14866 /// \pre LI is a simple load (i.e., not an atomic or volatile load).
14867 bool DAGCombiner::SliceUpLoad(SDNode *N) {
14868   if (Level < AfterLegalizeDAG)
14869     return false;
14870 
14871   LoadSDNode *LD = cast<LoadSDNode>(N);
14872   if (!LD->isSimple() || !ISD::isNormalLoad(LD) ||
14873       !LD->getValueType(0).isInteger())
14874     return false;
14875 
14876   // Keep track of already used bits to detect overlapping values.
14877   // In that case, we will just abort the transformation.
14878   APInt UsedBits(LD->getValueSizeInBits(0), 0);
14879 
14880   SmallVector<LoadedSlice, 4> LoadedSlices;
14881 
14882   // Check if this load is used as several smaller chunks of bits.
14883   // Basically, look for uses in trunc or trunc(lshr) and record a new chain
14884   // of computation for each trunc.
14885   for (SDNode::use_iterator UI = LD->use_begin(), UIEnd = LD->use_end();
14886        UI != UIEnd; ++UI) {
14887     // Skip the uses of the chain.
14888     if (UI.getUse().getResNo() != 0)
14889       continue;
14890 
14891     SDNode *User = *UI;
14892     unsigned Shift = 0;
14893 
14894     // Check if this is a trunc(lshr).
14895     if (User->getOpcode() == ISD::SRL && User->hasOneUse() &&
14896         isa<ConstantSDNode>(User->getOperand(1))) {
14897       Shift = User->getConstantOperandVal(1);
14898       User = *User->use_begin();
14899     }
14900 
14901     // At this point, User is a truncate iff we encountered trunc or
14902     // trunc(lshr).
14903     if (User->getOpcode() != ISD::TRUNCATE)
14904       return false;
14905 
14906     // The width of the type must be a power of 2 and at least 8 bits.
14907     // Otherwise the load cannot be represented in LLVM IR.
14908     // Moreover, if the shift amount is not a multiple of 8 bits, the slice
14909     // would straddle byte boundaries. We do not support that.
14910     unsigned Width = User->getValueSizeInBits(0);
14911     if (Width < 8 || !isPowerOf2_32(Width) || (Shift & 0x7))
14912       return false;
14913 
14914     // Build the slice for this chain of computations.
14915     LoadedSlice LS(User, LD, Shift, &DAG);
14916     APInt CurrentUsedBits = LS.getUsedBits();
14917 
14918     // Check if this slice overlaps with another.
14919     if ((CurrentUsedBits & UsedBits) != 0)
14920       return false;
14921     // Update the bits used globally.
14922     UsedBits |= CurrentUsedBits;
14923 
14924     // Check if the new slice would be legal.
14925     if (!LS.isLegal())
14926       return false;
14927 
14928     // Record the slice.
14929     LoadedSlices.push_back(LS);
14930   }
14931 
14932   // Abort slicing if it does not seem to be profitable.
14933   if (!isSlicingProfitable(LoadedSlices, UsedBits, ForCodeSize))
14934     return false;
14935 
14936   ++SlicedLoads;
14937 
14938   // Rewrite each chain to use an independent load.
14939   // By construction, each chain can be represented by a unique load.
14940 
14941   // Prepare the argument for the new token factor for all the slices.
14942   SmallVector<SDValue, 8> ArgChains;
14943   for (SmallVectorImpl<LoadedSlice>::const_iterator
14944            LSIt = LoadedSlices.begin(),
14945            LSItEnd = LoadedSlices.end();
14946        LSIt != LSItEnd; ++LSIt) {
14947     SDValue SliceInst = LSIt->loadSlice();
14948     CombineTo(LSIt->Inst, SliceInst, true);
14949     if (SliceInst.getOpcode() != ISD::LOAD)
14950       SliceInst = SliceInst.getOperand(0);
14951     assert(SliceInst->getOpcode() == ISD::LOAD &&
14952            "It takes more than a zext to get to the loaded slice!!");
14953     ArgChains.push_back(SliceInst.getValue(1));
14954   }
14955 
14956   SDValue Chain = DAG.getNode(ISD::TokenFactor, SDLoc(LD), MVT::Other,
14957                               ArgChains);
14958   DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain);
14959   AddToWorklist(Chain.getNode());
14960   return true;
14961 }
14962 
14963 /// Check to see if V is (and load (ptr), imm), where the load is having
14964 /// specific bytes cleared out.  If so, return the byte size being masked out
14965 /// and the shift amount.
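///
/// Worked example (illustrative values, not from the original source): for
///   V = and (load i32 Ptr), 0xFFFF00FF
/// the inverted, sign-extended mask is NotMask = 0x000000000000FF00, so
/// NotMaskTZ == 8 and, after adjusting from i64 down to i32, NotMaskLZ == 16,
/// giving MaskedBytes == 1 and a byte shift of 1: byte 1 of the load is the
/// one being cleared out.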
14966 static std::pair<unsigned, unsigned>
14967 CheckForMaskedLoad(SDValue V, SDValue Ptr, SDValue Chain) {
14968   std::pair<unsigned, unsigned> Result(0, 0);
14969 
14970   // Check for the structure we're looking for.
14971   if (V->getOpcode() != ISD::AND ||
14972       !isa<ConstantSDNode>(V->getOperand(1)) ||
14973       !ISD::isNormalLoad(V->getOperand(0).getNode()))
14974     return Result;
14975 
14976   // Check the chain and pointer.
14977   LoadSDNode *LD = cast<LoadSDNode>(V->getOperand(0));
14978   if (LD->getBasePtr() != Ptr) return Result;  // Not from same pointer.
14979 
14980   // This only handles simple types.
14981   if (V.getValueType() != MVT::i16 &&
14982       V.getValueType() != MVT::i32 &&
14983       V.getValueType() != MVT::i64)
14984     return Result;
14985 
14986   // Check the constant mask.  Invert it so that the bits being masked out are
14987   // 0 and the bits being kept are 1.  Use getSExtValue so that leading bits
14988   // follow the sign bit for uniformity.
14989   uint64_t NotMask = ~cast<ConstantSDNode>(V->getOperand(1))->getSExtValue();
14990   unsigned NotMaskLZ = countLeadingZeros(NotMask);
14991   if (NotMaskLZ & 7) return Result;  // Must be multiple of a byte.
14992   unsigned NotMaskTZ = countTrailingZeros(NotMask);
14993   if (NotMaskTZ & 7) return Result;  // Must be multiple of a byte.
14994   if (NotMaskLZ == 64) return Result;  // All zero mask.
14995 
14996   // See if we have a continuous run of bits.  If so, we have 0*1+0*
14997   if (countTrailingOnes(NotMask >> NotMaskTZ) + NotMaskTZ + NotMaskLZ != 64)
14998     return Result;
14999 
15000   // Adjust NotMaskLZ down to be from the actual size of the int instead of i64.
15001   if (V.getValueType() != MVT::i64 && NotMaskLZ)
15002     NotMaskLZ -= 64-V.getValueSizeInBits();
15003 
15004   unsigned MaskedBytes = (V.getValueSizeInBits()-NotMaskLZ-NotMaskTZ)/8;
15005   switch (MaskedBytes) {
15006   case 1:
15007   case 2:
15008   case 4: break;
15009   default: return Result; // All one mask, or 5-byte mask.
15010   }
15011 
15012   // Verify that the masked region starts at a byte offset that is a multiple
15013   // of the mask width, so the access is aligned the same as the access width.
15014   if (NotMaskTZ && NotMaskTZ/8 % MaskedBytes) return Result;
15015 
15016   // For narrowing to be valid, it must be the case that the load is the
15017   // memory operation immediately preceding the store.
15018   if (LD == Chain.getNode())
15019     ; // ok.
15020   else if (Chain->getOpcode() == ISD::TokenFactor &&
15021            SDValue(LD, 1).hasOneUse()) {
15022     // LD has only 1 chain use so there are no indirect dependencies.
15023     if (!LD->isOperandOf(Chain.getNode()))
15024       return Result;
15025   } else
15026     return Result; // Fail.
15027 
15028   Result.first = MaskedBytes;
15029   Result.second = NotMaskTZ/8;
15030   return Result;
15031 }
15032 
15033 /// Check to see if IVal is something that provides a value as specified by
15034 /// MaskInfo. If so, replace the specified store with a narrower store of
15035 /// truncated IVal.
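///
/// For example (illustrative values): if MaskInfo says one byte at byte
/// offset 1 is masked out and IVal is known to be zero everywhere else, the
/// wide store can be replaced by an i8 store of (IVal >> 8) at Ptr + 1 on a
/// little-endian target.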
15036 static SDValue
15037 ShrinkLoadReplaceStoreWithStore(const std::pair<unsigned, unsigned> &MaskInfo,
15038                                 SDValue IVal, StoreSDNode *St,
15039                                 DAGCombiner *DC) {
15040   unsigned NumBytes = MaskInfo.first;
15041   unsigned ByteShift = MaskInfo.second;
15042   SelectionDAG &DAG = DC->getDAG();
15043 
15044   // Check to see if IVal is all zeros in the part being masked in by the 'or'
15045   // that uses this.  If not, this is not a replacement.
15046   APInt Mask = ~APInt::getBitsSet(IVal.getValueSizeInBits(),
15047                                   ByteShift*8, (ByteShift+NumBytes)*8);
15048   if (!DAG.MaskedValueIsZero(IVal, Mask)) return SDValue();
15049 
15050   // Check that it is legal on the target to do this.  It is legal if the new
15051   // VT we're shrinking to (i8/i16/i32) is legal or we're still before type
15052   // legalization (and the target doesn't explicitly think this is a bad idea).
15053   MVT VT = MVT::getIntegerVT(NumBytes * 8);
15054   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
15055   if (!DC->isTypeLegal(VT))
15056     return SDValue();
15057   if (St->getMemOperand() &&
15058       !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
15059                               *St->getMemOperand()))
15060     return SDValue();
15061 
15062   // Okay, we can do this!  Replace the 'St' store with a store of IVal that is
15063   // shifted by ByteShift and truncated down to NumBytes.
15064   if (ByteShift) {
15065     SDLoc DL(IVal);
15066     IVal = DAG.getNode(ISD::SRL, DL, IVal.getValueType(), IVal,
15067                        DAG.getConstant(ByteShift*8, DL,
15068                                     DC->getShiftAmountTy(IVal.getValueType())));
15069   }
15070 
15071   // Figure out the offset for the store and the alignment of the access.
15072   unsigned StOffset;
15073   unsigned NewAlign = St->getAlignment();
15074 
15075   if (DAG.getDataLayout().isLittleEndian())
15076     StOffset = ByteShift;
15077   else
15078     StOffset = IVal.getValueType().getStoreSize() - ByteShift - NumBytes;
15079 
15080   SDValue Ptr = St->getBasePtr();
15081   if (StOffset) {
15082     SDLoc DL(IVal);
15083     Ptr = DAG.getMemBasePlusOffset(Ptr, StOffset, DL);
15084     NewAlign = MinAlign(NewAlign, StOffset);
15085   }
15086 
15087   // Truncate down to the new size.
15088   IVal = DAG.getNode(ISD::TRUNCATE, SDLoc(IVal), VT, IVal);
15089 
15090   ++OpsNarrowed;
15091   return DAG
15092       .getStore(St->getChain(), SDLoc(St), IVal, Ptr,
15093                 St->getPointerInfo().getWithOffset(StOffset), NewAlign);
15094 }
15095 
15096 /// Look for sequence of load / op / store where op is one of 'or', 'xor', and
15097 /// 'and' of immediates. If 'op' is only touching some of the loaded bits, try
15098 /// narrowing the load and store if it would end up being a win for performance
15099 /// or code size.
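///
/// Worked example (illustrative values, not from the original source): for
///   x = load i32 Ptr
///   store (xor x, 0x00FF0000), Ptr
/// only bits [16, 24) change, so ShAmt == 16 and NewBW == 8, and the sequence
/// can be narrowed to an i8 load/xor/store at Ptr + 2 on a little-endian
/// target, subject to the legality and profitability checks below.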
15100 SDValue DAGCombiner::ReduceLoadOpStoreWidth(SDNode *N) {
15101   StoreSDNode *ST  = cast<StoreSDNode>(N);
15102   if (!ST->isSimple())
15103     return SDValue();
15104 
15105   SDValue Chain = ST->getChain();
15106   SDValue Value = ST->getValue();
15107   SDValue Ptr   = ST->getBasePtr();
15108   EVT VT = Value.getValueType();
15109 
15110   if (ST->isTruncatingStore() || VT.isVector() || !Value.hasOneUse())
15111     return SDValue();
15112 
15113   unsigned Opc = Value.getOpcode();
15114 
15115   // If this is "store (or X, Y), P" and X is "(and (load P), cst)", where cst
15116   // is a byte mask indicating a consecutive number of bytes, check to see if
15117   // Y is known to provide just those bytes.  If so, we try to replace the
15118   // load + replace + store sequence with a single (narrower) store, which makes
15119   // the load dead.
15120   if (Opc == ISD::OR) {
15121     std::pair<unsigned, unsigned> MaskedLoad;
15122     MaskedLoad = CheckForMaskedLoad(Value.getOperand(0), Ptr, Chain);
15123     if (MaskedLoad.first)
15124       if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskedLoad,
15125                                                   Value.getOperand(1), ST,this))
15126         return NewST;
15127 
15128     // Or is commutative, so try swapping X and Y.
15129     MaskedLoad = CheckForMaskedLoad(Value.getOperand(1), Ptr, Chain);
15130     if (MaskedLoad.first)
15131       if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskedLoad,
15132                                                   Value.getOperand(0), ST,this))
15133         return NewST;
15134   }
15135 
15136   if ((Opc != ISD::OR && Opc != ISD::XOR && Opc != ISD::AND) ||
15137       Value.getOperand(1).getOpcode() != ISD::Constant)
15138     return SDValue();
15139 
15140   SDValue N0 = Value.getOperand(0);
15141   if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
15142       Chain == SDValue(N0.getNode(), 1)) {
15143     LoadSDNode *LD = cast<LoadSDNode>(N0);
15144     if (LD->getBasePtr() != Ptr ||
15145         LD->getPointerInfo().getAddrSpace() !=
15146         ST->getPointerInfo().getAddrSpace())
15147       return SDValue();
15148 
15149     // Find the type to narrow it the load / op / store to.
15150     SDValue N1 = Value.getOperand(1);
15151     unsigned BitWidth = N1.getValueSizeInBits();
15152     APInt Imm = cast<ConstantSDNode>(N1)->getAPIntValue();
15153     if (Opc == ISD::AND)
15154       Imm ^= APInt::getAllOnesValue(BitWidth);
15155     if (Imm == 0 || Imm.isAllOnesValue())
15156       return SDValue();
15157     unsigned ShAmt = Imm.countTrailingZeros();
15158     unsigned MSB = BitWidth - Imm.countLeadingZeros() - 1;
15159     unsigned NewBW = NextPowerOf2(MSB - ShAmt);
15160     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), NewBW);
15161     // The narrowing should be profitable, the load/store operation should be
15162     // legal (or custom) and the store size should be equal to the NewVT width.
15163     while (NewBW < BitWidth &&
15164            (NewVT.getStoreSizeInBits() != NewBW ||
15165             !TLI.isOperationLegalOrCustom(Opc, NewVT) ||
15166             !TLI.isNarrowingProfitable(VT, NewVT))) {
15167       NewBW = NextPowerOf2(NewBW);
15168       NewVT = EVT::getIntegerVT(*DAG.getContext(), NewBW);
15169     }
15170     if (NewBW >= BitWidth)
15171       return SDValue();
15172 
15173     // If the changed lsb does not start at a NewBW-bit boundary,
15174     // start at the previous one.
15175     if (ShAmt % NewBW)
15176       ShAmt = (((ShAmt + NewBW - 1) / NewBW) * NewBW) - NewBW;
15177     APInt Mask = APInt::getBitsSet(BitWidth, ShAmt,
15178                                    std::min(BitWidth, ShAmt + NewBW));
15179     if ((Imm & Mask) == Imm) {
15180       APInt NewImm = (Imm & Mask).lshr(ShAmt).trunc(NewBW);
15181       if (Opc == ISD::AND)
15182         NewImm ^= APInt::getAllOnesValue(NewBW);
15183       uint64_t PtrOff = ShAmt / 8;
15184       // For big endian targets, we need to adjust the offset to the pointer to
15185       // load the correct bytes.
15186       if (DAG.getDataLayout().isBigEndian())
15187         PtrOff = (BitWidth + 7 - NewBW) / 8 - PtrOff;
15188 
15189       unsigned NewAlign = MinAlign(LD->getAlignment(), PtrOff);
15190       Type *NewVTTy = NewVT.getTypeForEVT(*DAG.getContext());
15191       if (NewAlign < DAG.getDataLayout().getABITypeAlignment(NewVTTy))
15192         return SDValue();
15193 
15194       SDValue NewPtr = DAG.getMemBasePlusOffset(Ptr, PtrOff, SDLoc(LD));
15195       SDValue NewLD =
15196           DAG.getLoad(NewVT, SDLoc(N0), LD->getChain(), NewPtr,
15197                       LD->getPointerInfo().getWithOffset(PtrOff), NewAlign,
15198                       LD->getMemOperand()->getFlags(), LD->getAAInfo());
15199       SDValue NewVal = DAG.getNode(Opc, SDLoc(Value), NewVT, NewLD,
15200                                    DAG.getConstant(NewImm, SDLoc(Value),
15201                                                    NewVT));
15202       SDValue NewST =
15203           DAG.getStore(Chain, SDLoc(N), NewVal, NewPtr,
15204                        ST->getPointerInfo().getWithOffset(PtrOff), NewAlign);
15205 
15206       AddToWorklist(NewPtr.getNode());
15207       AddToWorklist(NewLD.getNode());
15208       AddToWorklist(NewVal.getNode());
15209       WorklistRemover DeadNodes(*this);
15210       DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), NewLD.getValue(1));
15211       ++OpsNarrowed;
15212       return NewST;
15213     }
15214   }
15215 
15216   return SDValue();
15217 }
15218 
15219 /// For a given floating point load / store pair, if the load value isn't used
15220 /// by any other operations, then consider transforming the pair to integer
15221 /// load / store operations if the target deems the transformation profitable.
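/// For example, a copy such as "(store (f32 load P1), P2)" may be rewritten
/// as "(store (i32 load P1), P2)" on targets where the integer form avoids a
/// round trip through the floating point register file.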
SDValue DAGCombiner::TransformFPLoadStorePair(SDNode *N) {
  StoreSDNode *ST  = cast<StoreSDNode>(N);
  SDValue Value = ST->getValue();
  if (ISD::isNormalStore(ST) && ISD::isNormalLoad(Value.getNode()) &&
      Value.hasOneUse()) {
    LoadSDNode *LD = cast<LoadSDNode>(Value);
    EVT VT = LD->getMemoryVT();
    if (!VT.isFloatingPoint() ||
        VT != ST->getMemoryVT() ||
        LD->isNonTemporal() ||
        ST->isNonTemporal() ||
        LD->getPointerInfo().getAddrSpace() != 0 ||
        ST->getPointerInfo().getAddrSpace() != 0)
      return SDValue();

    EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits());
    if (!TLI.isOperationLegal(ISD::LOAD, IntVT) ||
        !TLI.isOperationLegal(ISD::STORE, IntVT) ||
        !TLI.isDesirableToTransformToIntegerOp(ISD::LOAD, VT) ||
        !TLI.isDesirableToTransformToIntegerOp(ISD::STORE, VT))
      return SDValue();

    unsigned LDAlign = LD->getAlignment();
    unsigned STAlign = ST->getAlignment();
    Type *IntVTTy = IntVT.getTypeForEVT(*DAG.getContext());
    unsigned ABIAlign = DAG.getDataLayout().getABITypeAlignment(IntVTTy);
    if (LDAlign < ABIAlign || STAlign < ABIAlign)
      return SDValue();

    SDValue NewLD =
        DAG.getLoad(IntVT, SDLoc(Value), LD->getChain(), LD->getBasePtr(),
                    LD->getPointerInfo(), LDAlign);

    SDValue NewST =
        DAG.getStore(ST->getChain(), SDLoc(N), NewLD, ST->getBasePtr(),
                     ST->getPointerInfo(), STAlign);

    AddToWorklist(NewLD.getNode());
    AddToWorklist(NewST.getNode());
    WorklistRemover DeadNodes(*this);
    DAG.ReplaceAllUsesOfValueWith(Value.getValue(1), NewLD.getValue(1));
    ++LdStFP2Int;
    return NewST;
  }

  return SDValue();
}

// This is a helper function for visitMUL to check the profitability
// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
// MulNode is the original multiply, AddNode is (add x, c1),
// and ConstNode is c2.
//
// If the (add x, c1) has multiple uses, we could increase
// the number of adds if we make this transformation.
// It would only be worth doing this if we can remove a
// multiply in the process. Check for that here.
// To illustrate:
//     (A + c1) * c3
//     (A + c2) * c3
// We're checking for cases where we have common "c3 * A" expressions.
bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode,
                                              SDValue &AddNode,
                                              SDValue &ConstNode) {
  APInt Val;

  // If the add only has one use, this would be OK to do.
  if (AddNode.getNode()->hasOneUse())
    return true;

  // Walk all the users of the constant with which we're multiplying.
  for (SDNode *Use : ConstNode->uses()) {
    if (Use == MulNode) // This use is the one we're on right now. Skip it.
      continue;

    if (Use->getOpcode() == ISD::MUL) { // We have another multiply use.
      SDNode *OtherOp;
      SDNode *MulVar = AddNode.getOperand(0).getNode();

      // OtherOp is what we're multiplying against the constant.
      if (Use->getOperand(0) == ConstNode)
        OtherOp = Use->getOperand(1).getNode();
      else
        OtherOp = Use->getOperand(0).getNode();

      // Check to see if the multiply is by the same operand as our "add".
      //
      //     ConstNode  = CONST
      //     Use = ConstNode * A  <-- visiting Use. OtherOp is A.
      //     ...
      //     AddNode  = (A + c1)  <-- MulVar is A.
      //         = AddNode * ConstNode   <-- current visiting instruction.
      //
      // If we make this transformation, we will have a common
      // multiply (ConstNode * A) that we can save.
      if (OtherOp == MulVar)
        return true;

      // Now check to see if a future expansion will give us a common
      // multiply.
      //
      //     ConstNode  = CONST
      //     AddNode    = (A + c1)
      //     ...   = AddNode * ConstNode <-- current visiting instruction.
      //     ...
      //     OtherOp = (A + c2)
      //     Use     = OtherOp * ConstNode <-- visiting Use.
      //
      // If we make this transformation, we will have a common
      // multiply (CONST * A) after we also do the same transformation
      // to the "Use" instruction.
      if (OtherOp->getOpcode() == ISD::ADD &&
          DAG.isConstantIntBuildVectorOrConstantInt(OtherOp->getOperand(1)) &&
          OtherOp->getOperand(0).getNode() == MulVar)
        return true;
    }
  }

  // Didn't find a case where this would be profitable.
  return false;
}

SDValue DAGCombiner::getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
                                         unsigned NumStores) {
  SmallVector<SDValue, 8> Chains;
  SmallPtrSet<const SDNode *, 8> Visited;
  SDLoc StoreDL(StoreNodes[0].MemNode);

  for (unsigned i = 0; i < NumStores; ++i) {
    Visited.insert(StoreNodes[i].MemNode);
  }

  // Don't include chain operands that are one of the stores being merged or
  // that repeat.
  for (unsigned i = 0; i < NumStores; ++i) {
    if (Visited.insert(StoreNodes[i].MemNode->getChain().getNode()).second)
      Chains.push_back(StoreNodes[i].MemNode->getChain());
  }

  assert(Chains.size() > 0 && "Chain should have generated a chain");
  return DAG.getTokenFactor(StoreDL, Chains);
}

bool DAGCombiner::MergeStoresOfConstantsOrVecElts(
    SmallVectorImpl<MemOpLink> &StoreNodes, EVT MemVT, unsigned NumStores,
    bool IsConstantSrc, bool UseVector, bool UseTrunc) {
  // Make sure we have something to merge.
  if (NumStores < 2)
    return false;

  // The latest Node in the DAG.
  SDLoc DL(StoreNodes[0].MemNode);

  TypeSize ElementSizeBits = MemVT.getStoreSizeInBits();
  unsigned SizeInBits = NumStores * ElementSizeBits;
  unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;

  EVT StoreTy;
  if (UseVector) {
    unsigned Elts = NumStores * NumMemElts;
    // Get the type for the merged vector store.
    StoreTy = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
  } else
    StoreTy = EVT::getIntegerVT(*DAG.getContext(), SizeInBits);

  SDValue StoredVal;
  if (UseVector) {
    if (IsConstantSrc) {
      SmallVector<SDValue, 8> BuildVector;
      for (unsigned I = 0; I != NumStores; ++I) {
        StoreSDNode *St = cast<StoreSDNode>(StoreNodes[I].MemNode);
        SDValue Val = St->getValue();
        // If the constant is of the wrong type, convert it now.
        if (MemVT != Val.getValueType()) {
          Val = peekThroughBitcasts(Val);
          // Deal with constants of the wrong size.
          if (ElementSizeBits != Val.getValueSizeInBits()) {
            EVT IntMemVT =
                EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits());
            if (isa<ConstantFPSDNode>(Val)) {
              // Not clear how to truncate FP values.
              return false;
            } else if (auto *C = dyn_cast<ConstantSDNode>(Val))
              Val = DAG.getConstant(C->getAPIntValue()
                                        .zextOrTrunc(Val.getValueSizeInBits())
                                        .zextOrTrunc(ElementSizeBits),
                                    SDLoc(C), IntMemVT);
          }
          // Make sure the correctly sized value is bitcast to the memory type.
          Val = DAG.getBitcast(MemVT, Val);
        }
        BuildVector.push_back(Val);
      }
      StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS
                                               : ISD::BUILD_VECTOR,
                              DL, StoreTy, BuildVector);
    } else {
      SmallVector<SDValue, 8> Ops;
      for (unsigned i = 0; i < NumStores; ++i) {
        StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
        SDValue Val = peekThroughBitcasts(St->getValue());
        // All operands of BUILD_VECTOR / CONCAT_VECTOR must be of
        // type MemVT. If the underlying value is not the correct
        // type, but it is an extraction of an appropriate vector we
        // can recast Val to be of the correct type. This may require
        // converting between EXTRACT_VECTOR_ELT and
        // EXTRACT_SUBVECTOR.
        if ((MemVT != Val.getValueType()) &&
            (Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
             Val.getOpcode() == ISD::EXTRACT_SUBVECTOR)) {
          EVT MemVTScalarTy = MemVT.getScalarType();
          // We may need to add a bitcast here to get types to line up.
          if (MemVTScalarTy != Val.getValueType().getScalarType()) {
            Val = DAG.getBitcast(MemVT, Val);
          } else {
            unsigned OpC = MemVT.isVector() ? ISD::EXTRACT_SUBVECTOR
                                            : ISD::EXTRACT_VECTOR_ELT;
            SDValue Vec = Val.getOperand(0);
            SDValue Idx = Val.getOperand(1);
            Val = DAG.getNode(OpC, SDLoc(Val), MemVT, Vec, Idx);
          }
        }
        Ops.push_back(Val);
      }

      // Build the extracted vector elements back into a vector.
      StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS
                                               : ISD::BUILD_VECTOR,
                              DL, StoreTy, Ops);
    }
  } else {
    // We should always use a vector store when merging extracted vector
    // elements, so this path implies a store of constants.
    assert(IsConstantSrc && "Merged vector elements should use vector store");

    APInt StoreInt(SizeInBits, 0);

    // Construct a single integer constant which is made of the smaller
    // constant inputs.
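    // For example, on a little-endian target, merging two i16 stores of
    // 0x1234 (offset 0) and 0x5678 (offset 2) packs them into the single i32
    // constant 0x56781234, stored once at offset 0.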
    bool IsLE = DAG.getDataLayout().isLittleEndian();
    for (unsigned i = 0; i < NumStores; ++i) {
      unsigned Idx = IsLE ? (NumStores - 1 - i) : i;
      StoreSDNode *St  = cast<StoreSDNode>(StoreNodes[Idx].MemNode);

      SDValue Val = St->getValue();
      Val = peekThroughBitcasts(Val);
      StoreInt <<= ElementSizeBits;
      if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val)) {
        StoreInt |= C->getAPIntValue()
                        .zextOrTrunc(ElementSizeBits)
                        .zextOrTrunc(SizeInBits);
      } else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Val)) {
        StoreInt |= C->getValueAPF()
                        .bitcastToAPInt()
                        .zextOrTrunc(ElementSizeBits)
                        .zextOrTrunc(SizeInBits);
        // If fp truncation is necessary give up for now.
        if (MemVT.getSizeInBits() != ElementSizeBits)
          return false;
      } else {
        llvm_unreachable("Invalid constant element type");
      }
    }

    // Create the new Load and Store operations.
    StoredVal = DAG.getConstant(StoreInt, DL, StoreTy);
  }

  LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
  SDValue NewChain = getMergeStoreChains(StoreNodes, NumStores);

  // Make sure we use a trunc store if it's necessary to be legal.
  SDValue NewStore;
  if (!UseTrunc) {
    NewStore = DAG.getStore(NewChain, DL, StoredVal, FirstInChain->getBasePtr(),
                            FirstInChain->getPointerInfo(),
                            FirstInChain->getAlignment());
  } else { // Must be realized as a trunc store
    EVT LegalizedStoredValTy =
        TLI.getTypeToTransformTo(*DAG.getContext(), StoredVal.getValueType());
    unsigned LegalizedStoreSize = LegalizedStoredValTy.getSizeInBits();
    ConstantSDNode *C = cast<ConstantSDNode>(StoredVal);
    SDValue ExtendedStoreVal =
        DAG.getConstant(C->getAPIntValue().zextOrTrunc(LegalizedStoreSize), DL,
                        LegalizedStoredValTy);
    NewStore = DAG.getTruncStore(
        NewChain, DL, ExtendedStoreVal, FirstInChain->getBasePtr(),
        FirstInChain->getPointerInfo(), StoredVal.getValueType() /*TVT*/,
        FirstInChain->getAlignment(),
        FirstInChain->getMemOperand()->getFlags());
  }

  // Replace all merged stores with the new store.
  for (unsigned i = 0; i < NumStores; ++i)
    CombineTo(StoreNodes[i].MemNode, NewStore);

  AddToWorklist(NewChain.getNode());
  return true;
}

void DAGCombiner::getStoreMergeCandidates(
    StoreSDNode *St, SmallVectorImpl<MemOpLink> &StoreNodes,
    SDNode *&RootNode) {
  // This holds the base pointer, index, and the offset in bytes from the base
  // pointer.
  BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
  EVT MemVT = St->getMemoryVT();

  SDValue Val = peekThroughBitcasts(St->getValue());
  // We must have a base and an offset.
  if (!BasePtr.getBase().getNode())
    return;

  // Do not handle stores to undef base pointers.
  if (BasePtr.getBase().isUndef())
    return;

  bool IsConstantSrc = isa<ConstantSDNode>(Val) || isa<ConstantFPSDNode>(Val);
  bool IsExtractVecSrc = (Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
                          Val.getOpcode() == ISD::EXTRACT_SUBVECTOR);
  bool IsLoadSrc = isa<LoadSDNode>(Val);
  BaseIndexOffset LBasePtr;
  // Match on the load's base pointer if relevant.
  EVT LoadVT;
  if (IsLoadSrc) {
    auto *Ld = cast<LoadSDNode>(Val);
    LBasePtr = BaseIndexOffset::match(Ld, DAG);
    LoadVT = Ld->getMemoryVT();
    // Load and store should be the same type.
    if (MemVT != LoadVT)
      return;
    // Loads must only have one use.
    if (!Ld->hasNUsesOfValue(1, 0))
      return;
    // The memory operands must not be volatile/indexed/atomic.
    // TODO: May be able to relax for unordered atomics (see D66309)
    if (!Ld->isSimple() || Ld->isIndexed())
      return;
  }
  auto CandidateMatch = [&](StoreSDNode *Other, BaseIndexOffset &Ptr,
                            int64_t &Offset) -> bool {
    // The memory operands must not be volatile/indexed/atomic.
    // TODO: May be able to relax for unordered atomics (see D66309)
    if (!Other->isSimple() || Other->isIndexed())
      return false;
    // Don't mix temporal stores with non-temporal stores.
    if (St->isNonTemporal() != Other->isNonTemporal())
      return false;
    SDValue OtherBC = peekThroughBitcasts(Other->getValue());
    // Allow merging constants of different types as integers.
    bool NoTypeMatch = (MemVT.isInteger()) ? !MemVT.bitsEq(Other->getMemoryVT())
                                           : Other->getMemoryVT() != MemVT;
    if (IsLoadSrc) {
      if (NoTypeMatch)
        return false;
      // The load's base pointer must also match.
      if (LoadSDNode *OtherLd = dyn_cast<LoadSDNode>(OtherBC)) {
        BaseIndexOffset LPtr = BaseIndexOffset::match(OtherLd, DAG);
        if (LoadVT != OtherLd->getMemoryVT())
          return false;
        // Loads must only have one use.
        if (!OtherLd->hasNUsesOfValue(1, 0))
          return false;
        // The memory operands must not be volatile/indexed/atomic.
        // TODO: May be able to relax for unordered atomics (see D66309)
        if (!OtherLd->isSimple() || OtherLd->isIndexed())
          return false;
        // Don't mix temporal loads with non-temporal loads.
        if (cast<LoadSDNode>(Val)->isNonTemporal() != OtherLd->isNonTemporal())
          return false;
        if (!(LBasePtr.equalBaseIndex(LPtr, DAG)))
          return false;
      } else
        return false;
    }
    if (IsConstantSrc) {
      if (NoTypeMatch)
        return false;
      if (!(isa<ConstantSDNode>(OtherBC) || isa<ConstantFPSDNode>(OtherBC)))
        return false;
    }
    if (IsExtractVecSrc) {
      // Do not merge truncated stores here.
      if (Other->isTruncatingStore())
        return false;
      if (!MemVT.bitsEq(OtherBC.getValueType()))
        return false;
      if (OtherBC.getOpcode() != ISD::EXTRACT_VECTOR_ELT &&
          OtherBC.getOpcode() != ISD::EXTRACT_SUBVECTOR)
        return false;
    }
    Ptr = BaseIndexOffset::match(Other, DAG);
    return (BasePtr.equalBaseIndex(Ptr, DAG, Offset));
  };

  // Check whether the pair of StoreNode and RootNode has already bailed out
  // of the dependence check more times than the limit allows.
  auto OverLimitInDependenceCheck = [&](SDNode *StoreNode,
                                        SDNode *RootNode) -> bool {
    auto RootCount = StoreRootCountMap.find(StoreNode);
    if (RootCount != StoreRootCountMap.end() &&
        RootCount->second.first == RootNode &&
        RootCount->second.second > StoreMergeDependenceLimit)
      return true;
    return false;
  };

  // We are looking for a root node which is an ancestor to all mergeable
  // stores. We search up through a load, to our root and then down
  // through all children. For instance we will find Store{1,2,3} if
  // St is Store1, Store2, or Store3 where the root is not a load,
  // which is always true for non-volatile ops. TODO: Expand
  // the search to find all valid candidates through multiple layers of loads.
  //
  // Root
  // |-------|-------|
  // Load    Load    Store3
  // |       |
  // Store1   Store2
  //
  // FIXME: We should be able to climb and
  // descend TokenFactors to find candidates as well.

  RootNode = St->getChain().getNode();

  unsigned NumNodesExplored = 0;
  if (LoadSDNode *Ldn = dyn_cast<LoadSDNode>(RootNode)) {
    RootNode = Ldn->getChain().getNode();
    for (auto I = RootNode->use_begin(), E = RootNode->use_end();
         I != E && NumNodesExplored < 1024; ++I, ++NumNodesExplored)
      if (I.getOperandNo() == 0 && isa<LoadSDNode>(*I)) // walk down chain
        for (auto I2 = (*I)->use_begin(), E2 = (*I)->use_end(); I2 != E2; ++I2)
          if (I2.getOperandNo() == 0)
            if (StoreSDNode *OtherST = dyn_cast<StoreSDNode>(*I2)) {
              BaseIndexOffset Ptr;
              int64_t PtrDiff;
              if (CandidateMatch(OtherST, Ptr, PtrDiff) &&
                  !OverLimitInDependenceCheck(OtherST, RootNode))
                StoreNodes.push_back(MemOpLink(OtherST, PtrDiff));
            }
  } else
    for (auto I = RootNode->use_begin(), E = RootNode->use_end();
         I != E && NumNodesExplored < 1024; ++I, ++NumNodesExplored)
      if (I.getOperandNo() == 0)
        if (StoreSDNode *OtherST = dyn_cast<StoreSDNode>(*I)) {
          BaseIndexOffset Ptr;
          int64_t PtrDiff;
          if (CandidateMatch(OtherST, Ptr, PtrDiff) &&
              !OverLimitInDependenceCheck(OtherST, RootNode))
            StoreNodes.push_back(MemOpLink(OtherST, PtrDiff));
        }
}

// We need to check that merging these stores does not cause a loop in
// the DAG. Any store candidate may depend on another candidate
// indirectly through its operand (we already consider dependencies
// through the chain). Check in parallel by searching up from
// non-chain operands of candidates.
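// For example, a candidate store's value operand may (transitively) use a
// load whose chain passes through another candidate store; merging both
// stores into one node would then make that node a predecessor of itself.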
bool DAGCombiner::checkMergeStoreCandidatesForDependencies(
    SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
    SDNode *RootNode) {
  // FIXME: We should be able to truncate a full search of
  // predecessors by doing a BFS and keeping tabs on the originating
  // stores from which worklist nodes come, in a similar way to
  // TokenFactor simplification.

  SmallPtrSet<const SDNode *, 32> Visited;
  SmallVector<const SDNode *, 8> Worklist;

  // RootNode is a predecessor to all candidates so we need not search
  // past it. Add RootNode (peeking through TokenFactors). Do not count
  // these towards the size check.

  Worklist.push_back(RootNode);
  while (!Worklist.empty()) {
    auto N = Worklist.pop_back_val();
    if (!Visited.insert(N).second)
      continue; // Already present in Visited.
    if (N->getOpcode() == ISD::TokenFactor) {
      for (SDValue Op : N->ops())
        Worklist.push_back(Op.getNode());
    }
  }

  // Don't count pruning nodes towards the max.
  unsigned int Max = 1024 + Visited.size();
  // Search the Ops of the store candidates.
  for (unsigned i = 0; i < NumStores; ++i) {
    SDNode *N = StoreNodes[i].MemNode;
    // Of the 4 Store Operands:
    //   * Chain (Op 0) -> We have already considered these
    //                     in candidate selection, so they can be
    //                     safely ignored.
    //   * Value (Op 1) -> Cycles may happen (e.g. through load chains).
    //   * Address (Op 2) -> Merged addresses may only vary by a fixed constant,
    //                       but aren't necessarily from the same base node, so
    //                       cycles are possible (e.g. via an indexed store).
    //   * (Op 3) -> Represents the pre- or post-indexing offset (or undef for
    //               non-indexed stores). Not constant on all targets (e.g. ARM)
    //               and so can participate in a cycle.
    for (unsigned j = 1; j < N->getNumOperands(); ++j)
      Worklist.push_back(N->getOperand(j).getNode());
  }
  // Search through the DAG. We can stop early if we find a store node.
  for (unsigned i = 0; i < NumStores; ++i)
    if (SDNode::hasPredecessorHelper(StoreNodes[i].MemNode, Visited, Worklist,
                                     Max)) {
      // If the search bails out, record the StoreNode and RootNode in the
      // StoreRootCountMap. If we have seen the pair more times than the
      // limit, we won't add the StoreNode into the StoreNodes set again.
      if (Visited.size() >= Max) {
        auto &RootCount = StoreRootCountMap[StoreNodes[i].MemNode];
        if (RootCount.first == RootNode)
          RootCount.second++;
        else
          RootCount = {RootNode, 1};
      }
      return false;
    }
  return true;
}

bool DAGCombiner::MergeConsecutiveStores(StoreSDNode *St) {
  if (OptLevel == CodeGenOpt::None || !EnableStoreMerging)
    return false;

  EVT MemVT = St->getMemoryVT();
  int64_t ElementSizeBytes = MemVT.getStoreSize();
  unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;

  if (MemVT.getSizeInBits() * 2 > MaximumLegalStoreInBits)
    return false;

  bool NoVectors = DAG.getMachineFunction().getFunction().hasFnAttribute(
      Attribute::NoImplicitFloat);

  // This function cannot currently deal with non-byte-sized memory sizes.
  if (ElementSizeBytes * 8 != (int64_t)MemVT.getSizeInBits())
    return false;

  if (!MemVT.isSimple())
    return false;

  // Perform an early exit check. Do not bother looking at stored values that
  // are not constants, loads, or extracted vector elements.
  SDValue StoredVal = peekThroughBitcasts(St->getValue());
  bool IsLoadSrc = isa<LoadSDNode>(StoredVal);
  bool IsConstantSrc = isa<ConstantSDNode>(StoredVal) ||
                       isa<ConstantFPSDNode>(StoredVal);
  bool IsExtractVecSrc = (StoredVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
                          StoredVal.getOpcode() == ISD::EXTRACT_SUBVECTOR);
  bool IsNonTemporalStore = St->isNonTemporal();
  bool IsNonTemporalLoad =
      IsLoadSrc && cast<LoadSDNode>(StoredVal)->isNonTemporal();

  if (!IsConstantSrc && !IsLoadSrc && !IsExtractVecSrc)
    return false;

  SmallVector<MemOpLink, 8> StoreNodes;
  SDNode *RootNode;
  // Find potential store merge candidates by searching through the chain
  // sub-DAG.
  getStoreMergeCandidates(St, StoreNodes, RootNode);

  // Check if there is anything to merge.
  if (StoreNodes.size() < 2)
    return false;

  // Sort the memory operands according to their distance from the
  // base pointer.
  llvm::sort(StoreNodes, [](MemOpLink LHS, MemOpLink RHS) {
    return LHS.OffsetFromBase < RHS.OffsetFromBase;
  });

  // Store merging attempts to merge the lowest stores. This generally
  // works out, as the remaining stores are checked after the first
  // collection of stores is merged. However, in the case that a
  // non-mergeable store is found first, e.g., {p[-2], p[0], p[1],
  // p[2], p[3]}, we would fail and miss the subsequent mergeable
  // cases. To prevent this, we prune such stores from the front of
  // StoreNodes here.
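  // In that example, p[-2] is not adjacent to p[0], so it is dropped and the
  // consecutive run {p[0], p[1], p[2], p[3]} is retried on its own.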

  bool RV = false;
  while (StoreNodes.size() > 1) {
    size_t StartIdx = 0;
    while ((StartIdx + 1 < StoreNodes.size()) &&
           StoreNodes[StartIdx].OffsetFromBase + ElementSizeBytes !=
               StoreNodes[StartIdx + 1].OffsetFromBase)
      ++StartIdx;

    // Bail if we don't have enough candidates to merge.
    if (StartIdx + 1 >= StoreNodes.size())
      return RV;

    if (StartIdx)
      StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + StartIdx);

    // Scan the memory operations on the chain and find the first
    // non-consecutive store memory address.
    unsigned NumConsecutiveStores = 1;
    int64_t StartAddress = StoreNodes[0].OffsetFromBase;
    // Check that the addresses are consecutive starting from the second
    // element in the list of stores.
    for (unsigned i = 1, e = StoreNodes.size(); i < e; ++i) {
      int64_t CurrAddress = StoreNodes[i].OffsetFromBase;
      if (CurrAddress - StartAddress != (ElementSizeBytes * i))
        break;
      NumConsecutiveStores = i + 1;
    }

    if (NumConsecutiveStores < 2) {
      StoreNodes.erase(StoreNodes.begin(),
                       StoreNodes.begin() + NumConsecutiveStores);
      continue;
    }

    // The node with the lowest store address.
    LLVMContext &Context = *DAG.getContext();
    const DataLayout &DL = DAG.getDataLayout();

    // Store the constants into memory as one consecutive store.
    if (IsConstantSrc) {
      while (NumConsecutiveStores >= 2) {
        LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
        unsigned FirstStoreAS = FirstInChain->getAddressSpace();
        unsigned FirstStoreAlign = FirstInChain->getAlignment();
        unsigned LastLegalType = 1;
        unsigned LastLegalVectorType = 1;
        bool LastIntegerTrunc = false;
        bool NonZero = false;
        unsigned FirstZeroAfterNonZero = NumConsecutiveStores;
        for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
          StoreSDNode *ST = cast<StoreSDNode>(StoreNodes[i].MemNode);
          SDValue StoredVal = ST->getValue();
          bool IsElementZero = false;
          if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(StoredVal))
            IsElementZero = C->isNullValue();
          else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(StoredVal))
            IsElementZero = C->getConstantFPValue()->isNullValue();
          if (IsElementZero) {
            if (NonZero && FirstZeroAfterNonZero == NumConsecutiveStores)
              FirstZeroAfterNonZero = i;
          }
          NonZero |= !IsElementZero;

          // Find a legal type for the constant store.
          unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
          EVT StoreTy = EVT::getIntegerVT(Context, SizeInBits);
          bool IsFast = false;

          // Break early when the size is too large to be legal.
          if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
            break;

          if (TLI.isTypeLegal(StoreTy) &&
              TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) &&
              TLI.allowsMemoryAccess(Context, DL, StoreTy,
                                     *FirstInChain->getMemOperand(), &IsFast) &&
              IsFast) {
            LastIntegerTrunc = false;
            LastLegalType = i + 1;
            // Or check whether a truncstore is legal.
          } else if (TLI.getTypeAction(Context, StoreTy) ==
                     TargetLowering::TypePromoteInteger) {
            EVT LegalizedStoredValTy =
                TLI.getTypeToTransformTo(Context, StoredVal.getValueType());
            if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) &&
                TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy, DAG) &&
                TLI.allowsMemoryAccess(Context, DL, StoreTy,
                                       *FirstInChain->getMemOperand(),
                                       &IsFast) &&
                IsFast) {
              LastIntegerTrunc = true;
              LastLegalType = i + 1;
            }
          }

          // We only use vectors if the constant is known to be zero or the
          // target allows it and the function is not marked with the
          // noimplicitfloat attribute.
          if ((!NonZero ||
               TLI.storeOfVectorConstantIsCheap(MemVT, i + 1, FirstStoreAS)) &&
              !NoVectors) {
            // Find a legal type for the vector store.
            unsigned Elts = (i + 1) * NumMemElts;
            EVT Ty = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
            if (TLI.isTypeLegal(Ty) && TLI.isTypeLegal(MemVT) &&
                TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG) &&
                TLI.allowsMemoryAccess(
                    Context, DL, Ty, *FirstInChain->getMemOperand(), &IsFast) &&
                IsFast)
              LastLegalVectorType = i + 1;
          }
        }

        bool UseVector = (LastLegalVectorType > LastLegalType) && !NoVectors;
        unsigned NumElem = (UseVector) ? LastLegalVectorType : LastLegalType;

        // Check if we found a legal integer type that creates a meaningful
        // merge.
        if (NumElem < 2) {
          // We know that candidate stores are in order and of correct
          // shape. While there is no mergeable sequence from the
          // beginning, one may start later in the sequence. The only
          // reason a merge of size N could have failed where another of
          // the same size would not have, is if the alignment has
          // improved or we've dropped a non-zero value. Drop as many
          // candidates as we can here.
          unsigned NumSkip = 1;
          while (
              (NumSkip < NumConsecutiveStores) &&
              (NumSkip < FirstZeroAfterNonZero) &&
              (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign))
            NumSkip++;

          StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
          NumConsecutiveStores -= NumSkip;
          continue;
        }

        // Check that we can merge these candidates without causing a cycle.
        if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem,
                                                      RootNode)) {
          StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
          NumConsecutiveStores -= NumElem;
          continue;
        }

        RV |= MergeStoresOfConstantsOrVecElts(StoreNodes, MemVT, NumElem, true,
                                              UseVector, LastIntegerTrunc);

        // Remove merged stores for the next iteration.
        StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
        NumConsecutiveStores -= NumElem;
      }
      continue;
    }

    // When extracting multiple vector elements, try to store them
    // in one vector store rather than a sequence of scalar stores.
    if (IsExtractVecSrc) {
      // Loop on consecutive stores on success.
      while (NumConsecutiveStores >= 2) {
        LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
        unsigned FirstStoreAS = FirstInChain->getAddressSpace();
        unsigned FirstStoreAlign = FirstInChain->getAlignment();
        unsigned NumStoresToMerge = 1;
        for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
          // Find a legal type for the vector store.
          unsigned Elts = (i + 1) * NumMemElts;
          EVT Ty =
              EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
          bool IsFast;

          // Break early when the size is too large to be legal.
          if (Ty.getSizeInBits() > MaximumLegalStoreInBits)
            break;

          if (TLI.isTypeLegal(Ty) &&
              TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG) &&
              TLI.allowsMemoryAccess(Context, DL, Ty,
                                     *FirstInChain->getMemOperand(), &IsFast) &&
              IsFast)
            NumStoresToMerge = i + 1;
        }

        // Check if we found a legal vector type that creates a meaningful
        // merge.
        if (NumStoresToMerge < 2) {
          // We know that candidate stores are in order and of correct
          // shape. While there is no mergeable sequence from the
          // beginning, one may start later in the sequence. The only
          // reason a merge of size N could have failed where another of
          // the same size would not have, is if the alignment has
          // improved. Drop as many candidates as we can here.
          unsigned NumSkip = 1;
          while (
              (NumSkip < NumConsecutiveStores) &&
              (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign))
            NumSkip++;

          StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
          NumConsecutiveStores -= NumSkip;
          continue;
        }

        // Check that we can merge these candidates without causing a cycle.
        if (!checkMergeStoreCandidatesForDependencies(
                StoreNodes, NumStoresToMerge, RootNode)) {
          StoreNodes.erase(StoreNodes.begin(),
                           StoreNodes.begin() + NumStoresToMerge);
          NumConsecutiveStores -= NumStoresToMerge;
          continue;
        }

        RV |= MergeStoresOfConstantsOrVecElts(
            StoreNodes, MemVT, NumStoresToMerge, false, true, false);

        StoreNodes.erase(StoreNodes.begin(),
                         StoreNodes.begin() + NumStoresToMerge);
        NumConsecutiveStores -= NumStoresToMerge;
      }
      continue;
    }

    // Below we handle the case of multiple consecutive stores that
    // come from multiple consecutive loads. We merge them into a single
    // wide load and a single wide store.
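    // For example, a byte-by-byte copy such as
    //   (store (i8 load p), q), (store (i8 load p+1), q+1), ...
    // may become a single wide load from p followed by a single wide store
    // to q, provided the wide type is legal and fast on the target.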

    // Look for load nodes which are used by the stored values.
    SmallVector<MemOpLink, 8> LoadNodes;

    // Find acceptable loads. Loads must have the same chain (token factor),
    // must not be zext, volatile, or indexed, and they must be consecutive.
    BaseIndexOffset LdBasePtr;

    for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
      StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
      SDValue Val = peekThroughBitcasts(St->getValue());
      LoadSDNode *Ld = cast<LoadSDNode>(Val);

      BaseIndexOffset LdPtr = BaseIndexOffset::match(Ld, DAG);
      // If this is not the first ptr that we check.
      int64_t LdOffset = 0;
      if (LdBasePtr.getBase().getNode()) {
        // The base ptr must be the same.
        if (!LdBasePtr.equalBaseIndex(LdPtr, DAG, LdOffset))
          break;
      } else {
        // Check that all other base pointers are the same as this one.
        LdBasePtr = LdPtr;
      }

      // We found a potential memory operand to merge.
      LoadNodes.push_back(MemOpLink(Ld, LdOffset));
    }

    while (NumConsecutiveStores >= 2 && LoadNodes.size() >= 2) {
      // If we have load/store pair instructions and we only have two values,
      // don't bother merging.
      unsigned RequiredAlignment;
      if (LoadNodes.size() == 2 &&
          TLI.hasPairedLoad(MemVT, RequiredAlignment) &&
          StoreNodes[0].MemNode->getAlignment() >= RequiredAlignment) {
        StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 2);
        LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + 2);
        break;
      }
      LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
      unsigned FirstStoreAS = FirstInChain->getAddressSpace();
      unsigned FirstStoreAlign = FirstInChain->getAlignment();
      LoadSDNode *FirstLoad = cast<LoadSDNode>(LoadNodes[0].MemNode);
      unsigned FirstLoadAlign = FirstLoad->getAlignment();

      // Scan the memory operations on the chain and find the first
      // non-consecutive load memory address. These variables hold the index in
      // the store node array.

      unsigned LastConsecutiveLoad = 1;

      // This variable refers to the size and not the index in the array.
      unsigned LastLegalVectorType = 1;
      unsigned LastLegalIntegerType = 1;
      bool isDereferenceable = true;
      bool DoIntegerTruncate = false;
      StartAddress = LoadNodes[0].OffsetFromBase;
      SDValue FirstChain = FirstLoad->getChain();
      for (unsigned i = 1; i < LoadNodes.size(); ++i) {
        // All loads must share the same chain.
        if (LoadNodes[i].MemNode->getChain() != FirstChain)
          break;

        int64_t CurrAddress = LoadNodes[i].OffsetFromBase;
        if (CurrAddress - StartAddress != (ElementSizeBytes * i))
          break;
        LastConsecutiveLoad = i;

        if (isDereferenceable && !LoadNodes[i].MemNode->isDereferenceable())
          isDereferenceable = false;

        // Find a legal type for the vector store.
        unsigned Elts = (i + 1) * NumMemElts;
        EVT StoreTy = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);

        // Break early when the size is too large to be legal.
        if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
          break;

        bool IsFastSt, IsFastLd;
        if (TLI.isTypeLegal(StoreTy) &&
            TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) &&
            TLI.allowsMemoryAccess(Context, DL, StoreTy,
                                   *FirstInChain->getMemOperand(), &IsFastSt) &&
            IsFastSt &&
            TLI.allowsMemoryAccess(Context, DL, StoreTy,
                                   *FirstLoad->getMemOperand(), &IsFastLd) &&
            IsFastLd) {
          LastLegalVectorType = i + 1;
        }

        // Find a legal type for the integer store.
        unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
        StoreTy = EVT::getIntegerVT(Context, SizeInBits);
        if (TLI.isTypeLegal(StoreTy) &&
            TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) &&
            TLI.allowsMemoryAccess(Context, DL, StoreTy,
                                   *FirstInChain->getMemOperand(), &IsFastSt) &&
            IsFastSt &&
            TLI.allowsMemoryAccess(Context, DL, StoreTy,
                                   *FirstLoad->getMemOperand(), &IsFastLd) &&
            IsFastLd) {
          LastLegalIntegerType = i + 1;
          DoIntegerTruncate = false;
          // Or check whether a truncstore and extload is legal.
        } else if (TLI.getTypeAction(Context, StoreTy) ==
                   TargetLowering::TypePromoteInteger) {
          EVT LegalizedStoredValTy = TLI.getTypeToTransformTo(Context, StoreTy);
          if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) &&
              TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy, DAG) &&
              TLI.isLoadExtLegal(ISD::ZEXTLOAD, LegalizedStoredValTy,
                                 StoreTy) &&
              TLI.isLoadExtLegal(ISD::SEXTLOAD, LegalizedStoredValTy,
                                 StoreTy) &&
              TLI.isLoadExtLegal(ISD::EXTLOAD, LegalizedStoredValTy, StoreTy) &&
              TLI.allowsMemoryAccess(Context, DL, StoreTy,
                                     *FirstInChain->getMemOperand(),
                                     &IsFastSt) &&
              IsFastSt &&
              TLI.allowsMemoryAccess(Context, DL, StoreTy,
                                     *FirstLoad->getMemOperand(), &IsFastLd) &&
              IsFastLd) {
            LastLegalIntegerType = i + 1;
            DoIntegerTruncate = true;
          }
        }
      }

      // Only use vector types if the vector type is larger than the integer
      // type. If they are the same, use integers.
      bool UseVectorTy =
          LastLegalVectorType > LastLegalIntegerType && !NoVectors;
      unsigned LastLegalType =
          std::max(LastLegalVectorType, LastLegalIntegerType);

      // We add +1 here because the LastXXX variables refer to an index while
      // NumElem refers to a count of elements.
      unsigned NumElem =
          std::min(NumConsecutiveStores, LastConsecutiveLoad + 1);
      NumElem = std::min(LastLegalType, NumElem);

      if (NumElem < 2) {
        // We know that candidate stores are in order and of correct
        // shape. While there is no mergeable sequence from the
        // beginning, one may start later in the sequence. The only
        // reason a merge of size N could have failed where another of
        // the same size would not have, is if the alignment of either
        // the load or store has improved. Drop as many candidates as we
        // can here.
        unsigned NumSkip = 1;
        while ((NumSkip < LoadNodes.size()) &&
               (LoadNodes[NumSkip].MemNode->getAlignment() <= FirstLoadAlign) &&
               (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign))
          NumSkip++;
        StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
        LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumSkip);
        NumConsecutiveStores -= NumSkip;
        continue;
      }

      // Check that we can merge these candidates without causing a cycle.
      if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem,
                                                    RootNode)) {
        StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
        LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
        NumConsecutiveStores -= NumElem;
        continue;
      }

      // Find out whether it is better to use vectors or integers to load and
      // store to memory.
      EVT JointMemOpVT;
      if (UseVectorTy) {
        // Find a legal type for the vector store.
        unsigned Elts = NumElem * NumMemElts;
        JointMemOpVT = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
      } else {
        unsigned SizeInBits = NumElem * ElementSizeBytes * 8;
        JointMemOpVT = EVT::getIntegerVT(Context, SizeInBits);
      }

      SDLoc LoadDL(LoadNodes[0].MemNode);
      SDLoc StoreDL(StoreNodes[0].MemNode);

      // The merged loads are required to have the same incoming chain, so
      // using the first's chain is acceptable.

      SDValue NewStoreChain = getMergeStoreChains(StoreNodes, NumElem);
      AddToWorklist(NewStoreChain.getNode());

      MachineMemOperand::Flags LdMMOFlags =
          isDereferenceable ? MachineMemOperand::MODereferenceable
                            : MachineMemOperand::MONone;
      if (IsNonTemporalLoad)
        LdMMOFlags |= MachineMemOperand::MONonTemporal;

      MachineMemOperand::Flags StMMOFlags =
          IsNonTemporalStore ? MachineMemOperand::MONonTemporal
                             : MachineMemOperand::MONone;

      SDValue NewLoad, NewStore;
      if (UseVectorTy || !DoIntegerTruncate) {
        NewLoad =
            DAG.getLoad(JointMemOpVT, LoadDL, FirstLoad->getChain(),
                        FirstLoad->getBasePtr(), FirstLoad->getPointerInfo(),
                        FirstLoadAlign, LdMMOFlags);
        NewStore = DAG.getStore(
            NewStoreChain, StoreDL, NewLoad, FirstInChain->getBasePtr(),
            FirstInChain->getPointerInfo(), FirstStoreAlign, StMMOFlags);
      } else { // This must be the truncstore/extload case
        EVT ExtendedTy =
            TLI.getTypeToTransformTo(*DAG.getContext(), JointMemOpVT);
        NewLoad = DAG.getExtLoad(ISD::EXTLOAD, LoadDL, ExtendedTy,
                                 FirstLoad->getChain(), FirstLoad->getBasePtr(),
                                 FirstLoad->getPointerInfo(), JointMemOpVT,
                                 FirstLoadAlign, LdMMOFlags);
        NewStore = DAG.getTruncStore(NewStoreChain, StoreDL, NewLoad,
                                     FirstInChain->getBasePtr(),
                                     FirstInChain->getPointerInfo(),
                                     JointMemOpVT, FirstInChain->getAlignment(),
                                     FirstInChain->getMemOperand()->getFlags());
      }

      // Transfer chain users from old loads to the new load.
      for (unsigned i = 0; i < NumElem; ++i) {
        LoadSDNode *Ld = cast<LoadSDNode>(LoadNodes[i].MemNode);
        DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1),
                                      SDValue(NewLoad.getNode(), 1));
      }

      // Replace all of the stores with the new store. Recursively remove the
      // corresponding value if it's no longer used.
      for (unsigned i = 0; i < NumElem; ++i) {
        SDValue Val = StoreNodes[i].MemNode->getOperand(1);
        CombineTo(StoreNodes[i].MemNode, NewStore);
        if (Val.getNode()->use_empty())
          recursivelyDeleteUnusedNodes(Val.getNode());
      }

      RV = true;
      StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
      LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
      NumConsecutiveStores -= NumElem;
    }
  }
  return RV;
}

SDValue DAGCombiner::replaceStoreChain(StoreSDNode *ST, SDValue BetterChain) {
  SDLoc SL(ST);
  SDValue ReplStore;

  // Replace the chain to avoid the dependency.
  if (ST->isTruncatingStore()) {
    ReplStore = DAG.getTruncStore(BetterChain, SL, ST->getValue(),
                                  ST->getBasePtr(), ST->getMemoryVT(),
                                  ST->getMemOperand());
  } else {
    ReplStore = DAG.getStore(BetterChain, SL, ST->getValue(), ST->getBasePtr(),
                             ST->getMemOperand());
  }

  // Create a token to keep both nodes around.
  SDValue Token = DAG.getNode(ISD::TokenFactor, SL,
                              MVT::Other, ST->getChain(), ReplStore);

  // Make sure the new and old chains are cleaned up.
  AddToWorklist(Token.getNode());

  // Don't add users to the work list.
  return CombineTo(ST, Token, false);
}

SDValue DAGCombiner::replaceStoreOfFPConstant(StoreSDNode *ST) {
  SDValue Value = ST->getValue();
  if (Value.getOpcode() == ISD::TargetConstantFP)
    return SDValue();

  if (!ISD::isNormalStore(ST))
    return SDValue();

  SDLoc DL(ST);

  SDValue Chain = ST->getChain();
  SDValue Ptr = ST->getBasePtr();

  const ConstantFPSDNode *CFP = cast<ConstantFPSDNode>(Value);

  // NOTE: If the original store is volatile, this transform must not increase
  // the number of stores.  For example, on x86-32 an f64 can be stored in one
  // processor operation but an i64 (which is not legal) requires two.  So the
  // transform should not be done in this case.

  SDValue Tmp;
  switch (CFP->getSimpleValueType(0).SimpleTy) {
  default:
    llvm_unreachable("Unknown FP type");
  case MVT::f16:    // We don't do this for these yet.
  case MVT::f80:
  case MVT::f128:
  case MVT::ppcf128:
    return SDValue();
  case MVT::f32:
    if ((isTypeLegal(MVT::i32) && !LegalOperations && ST->isSimple()) ||
        TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) {
      Tmp = DAG.getConstant((uint32_t)CFP->getValueAPF().
                            bitcastToAPInt().getZExtValue(), SDLoc(CFP),
                            MVT::i32);
      return DAG.getStore(Chain, DL, Tmp, Ptr, ST->getMemOperand());
    }

    return SDValue();
  case MVT::f64:
    if ((TLI.isTypeLegal(MVT::i64) && !LegalOperations &&
         ST->isSimple()) ||
        TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i64)) {
      Tmp = DAG.getConstant(CFP->getValueAPF().bitcastToAPInt().
                            getZExtValue(), SDLoc(CFP), MVT::i64);
      return DAG.getStore(Chain, DL, Tmp,
                          Ptr, ST->getMemOperand());
    }

    if (ST->isSimple() &&
        TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) {
      // Many FP stores are not made apparent until after legalize, e.g. for
      // argument passing.  Since this is so common, custom legalize the
      // 64-bit integer store into two 32-bit stores.
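      // For example, storing the f64 constant 1.0 (bit pattern
      // 0x3FF0000000000000) becomes an i32 store of 0x00000000 (Lo) and an
      // i32 store of 0x3FF00000 (Hi), with Lo/Hi swapped on big-endian
      // targets.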
16359       uint64_t Val = CFP->getValueAPF().bitcastToAPInt().getZExtValue();
16360       SDValue Lo = DAG.getConstant(Val & 0xFFFFFFFF, SDLoc(CFP), MVT::i32);
16361       SDValue Hi = DAG.getConstant(Val >> 32, SDLoc(CFP), MVT::i32);
16362       if (DAG.getDataLayout().isBigEndian())
16363         std::swap(Lo, Hi);
16364 
16365       unsigned Alignment = ST->getAlignment();
16366       MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
16367       AAMDNodes AAInfo = ST->getAAInfo();
16368 
16369       SDValue St0 = DAG.getStore(Chain, DL, Lo, Ptr, ST->getPointerInfo(),
16370                                  ST->getAlignment(), MMOFlags, AAInfo);
16371       Ptr = DAG.getMemBasePlusOffset(Ptr, 4, DL);
16372       Alignment = MinAlign(Alignment, 4U);
16373       SDValue St1 = DAG.getStore(Chain, DL, Hi, Ptr,
16374                                  ST->getPointerInfo().getWithOffset(4),
16375                                  Alignment, MMOFlags, AAInfo);
16376       return DAG.getNode(ISD::TokenFactor, DL, MVT::Other,
16377                          St0, St1);
16378     }
16379 
16380     return SDValue();
16381   }
16382 }
16383 
16384 SDValue DAGCombiner::visitSTORE(SDNode *N) {
16385   StoreSDNode *ST  = cast<StoreSDNode>(N);
16386   SDValue Chain = ST->getChain();
16387   SDValue Value = ST->getValue();
16388   SDValue Ptr   = ST->getBasePtr();
16389 
16390   // If this is a store of a bit convert, store the input value if the
16391   // resultant store does not need a higher alignment than the original.
16392   if (Value.getOpcode() == ISD::BITCAST && !ST->isTruncatingStore() &&
16393       ST->isUnindexed()) {
16394     EVT SVT = Value.getOperand(0).getValueType();
16395     // If the store is volatile, we only want to change the store type if the
16396     // resulting store is legal. Otherwise we might increase the number of
16397     // memory accesses. We don't care if the original type was legal or not
16398     // as we assume software couldn't rely on the number of accesses of an
16399     // illegal type.
16400     // TODO: May be able to relax for unordered atomics (see D66309)
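    // For example:
    //   (store (i32 (bitcast f32:x)), ptr) -> (store f32:x, ptr)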
16401     if (((!LegalOperations && ST->isSimple()) ||
16402          TLI.isOperationLegal(ISD::STORE, SVT)) &&
16403         TLI.isStoreBitCastBeneficial(Value.getValueType(), SVT,
16404                                      DAG, *ST->getMemOperand())) {
16405       return DAG.getStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
16406                           ST->getMemOperand());
16407     }
16408   }
16409 
16410   // Turn 'store undef, Ptr' -> nothing.
16411   if (Value.isUndef() && ST->isUnindexed())
16412     return Chain;
16413 
16414   // Try to infer better alignment information than the store already has.
16415   if (OptLevel != CodeGenOpt::None && ST->isUnindexed() && !ST->isAtomic()) {
16416     if (unsigned Align = DAG.InferPtrAlignment(Ptr)) {
16417       if (Align > ST->getAlignment() && ST->getSrcValueOffset() % Align == 0) {
16418         SDValue NewStore =
16419             DAG.getTruncStore(Chain, SDLoc(N), Value, Ptr, ST->getPointerInfo(),
16420                               ST->getMemoryVT(), Align,
16421                               ST->getMemOperand()->getFlags(), ST->getAAInfo());
16422         // NewStore will always be N as we are only refining the alignment
16423         assert(NewStore.getNode() == N);
16424         (void)NewStore;
16425       }
16426     }
16427   }
16428 
16429   // Try transforming a pair floating point load / store ops to integer
16430   // load / store ops.
16431   if (SDValue NewST = TransformFPLoadStorePair(N))
16432     return NewST;
16433 
16434   // Try transforming several stores into STORE (BSWAP).
16435   if (SDValue Store = MatchStoreCombine(ST))
16436     return Store;
16437 
16438   if (ST->isUnindexed()) {
16439     // Walk up chain skipping non-aliasing memory nodes, on this store and any
16440     // adjacent stores.
16441     if (findBetterNeighborChains(ST)) {
16442       // replaceStoreChain uses CombineTo, which handles all of the worklist
16443       // manipulation. Return the original node so nothing else is done.
16444       return SDValue(ST, 0);
16445     }
16446     Chain = ST->getChain();
16447   }
16448 
16449   // FIXME: is there such a thing as a truncating indexed store?
16450   if (ST->isTruncatingStore() && ST->isUnindexed() &&
16451       Value.getValueType().isInteger() &&
16452       (!isa<ConstantSDNode>(Value) ||
16453        !cast<ConstantSDNode>(Value)->isOpaque())) {
16454     APInt TruncDemandedBits =
16455         APInt::getLowBitsSet(Value.getScalarValueSizeInBits(),
16456                              ST->getMemoryVT().getScalarSizeInBits());
16457 
16458     // See if we can simplify the input to this truncstore with knowledge that
16459     // only the low bits are being used.  For example:
16460     // "truncstore (or (shl x, 8), y), i8"  -> "truncstore y, i8"
16461     AddToWorklist(Value.getNode());
16462     if (SDValue Shorter = DAG.GetDemandedBits(Value, TruncDemandedBits))
16463       return DAG.getTruncStore(Chain, SDLoc(N), Shorter, Ptr, ST->getMemoryVT(),
16464                                ST->getMemOperand());
16465 
16466     // Otherwise, see if we can simplify the operation with
16467     // SimplifyDemandedBits, which only works if the value has a single use.
16468     if (SimplifyDemandedBits(Value, TruncDemandedBits)) {
16469       // Re-visit the store if anything changed and the store hasn't been merged
16470       // with another node (N is deleted). SimplifyDemandedBits will add Value's
16471       // node back to the worklist if necessary, but we also need to re-visit
16472       // the Store node itself.
16473       if (N->getOpcode() != ISD::DELETED_NODE)
16474         AddToWorklist(N);
16475       return SDValue(N, 0);
16476     }
16477   }
16478 
16479   // If this is a load followed by a store to the same location, then the store
16480   // is dead/noop.
16481   // TODO: Can relax for unordered atomics (see D66309)
16482   if (LoadSDNode *Ld = dyn_cast<LoadSDNode>(Value)) {
16483     if (Ld->getBasePtr() == Ptr && ST->getMemoryVT() == Ld->getMemoryVT() &&
16484         ST->isUnindexed() && ST->isSimple() &&
16485         // There can't be any side effects between the load and store, such as
16486         // a call or store.
16487         Chain.reachesChainWithoutSideEffects(SDValue(Ld, 1))) {
16488       // The store is dead, remove it.
16489       return Chain;
16490     }
16491   }
16492 
16493   // TODO: Can relax for unordered atomics (see D66309)
16494   if (StoreSDNode *ST1 = dyn_cast<StoreSDNode>(Chain)) {
16495     if (ST->isUnindexed() && ST->isSimple() &&
16496         ST1->isUnindexed() && ST1->isSimple()) {
16497       if (ST1->getBasePtr() == Ptr && ST1->getValue() == Value &&
16498           ST->getMemoryVT() == ST1->getMemoryVT()) {
16499         // If this is a store followed by a store with the same value to the
16500         // same location, then the store is dead/noop.
16501         return Chain;
16502       }
16503 
16504       if (OptLevel != CodeGenOpt::None && ST1->hasOneUse() &&
16505           !ST1->getBasePtr().isUndef()) {
16506         const BaseIndexOffset STBase = BaseIndexOffset::match(ST, DAG);
16507         const BaseIndexOffset ChainBase = BaseIndexOffset::match(ST1, DAG);
16508         unsigned STBitSize = ST->getMemoryVT().getSizeInBits();
16509         unsigned ChainBitSize = ST1->getMemoryVT().getSizeInBits();
16510         // If the preceding store (ST1) writes to a subset of this store's
16511         // location and no other node is chained to it, the preceding store is
16512         // redundant and can be dropped. Do not remove stores to undef as they
16513         // may be used as data sinks.
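        // For example:
        //   ST1: store i8 x, ptr
        //   ST:  store i32 y, ptr
        // ST overwrites all of ST1's location, so ST1 can be deleted if
        // nothing else is chained to it.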
16514         if (STBase.contains(DAG, STBitSize, ChainBase, ChainBitSize)) {
16515           CombineTo(ST1, ST1->getChain());
16516           return SDValue();
16517         }
16518       }
16519     }
16520   }
16521 
16522   // If this is an FP_ROUND or TRUNC followed by a store, fold this into a
16523   // truncating store.  We can do this even if this is already a truncstore.
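  // For example:
  //   (store (truncate i32 x to i16), ptr) -> (truncstore i32 x, ptr, i16)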
16524   if ((Value.getOpcode() == ISD::FP_ROUND || Value.getOpcode() == ISD::TRUNCATE)
16525       && Value.getNode()->hasOneUse() && ST->isUnindexed() &&
16526       TLI.isTruncStoreLegal(Value.getOperand(0).getValueType(),
16527                             ST->getMemoryVT())) {
16528     return DAG.getTruncStore(Chain, SDLoc(N), Value.getOperand(0),
16529                              Ptr, ST->getMemoryVT(), ST->getMemOperand());
16530   }
16531 
16532   // Always perform this optimization before types are legal. If the target
16533   // prefers, also try this after legalization to catch stores that were created
16534   // by intrinsics or other nodes.
16535   if (!LegalTypes || (TLI.mergeStoresAfterLegalization(ST->getMemoryVT()))) {
16536     while (true) {
16537       // There can be multiple store sequences on the same chain.
16538       // Keep trying to merge store sequences until we are unable to do so
16539       // or until we merge the last store on the chain.
16540       bool Changed = MergeConsecutiveStores(ST);
16541       if (!Changed) break;
16542       // Return N as the merge only uses CombineTo and no worklist
16543       // cleanup is necessary.
16544       if (N->getOpcode() == ISD::DELETED_NODE || !isa<StoreSDNode>(N))
16545         return SDValue(N, 0);
16546     }
16547   }
16548 
16549   // Try transforming N to an indexed store.
16550   if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
16551     return SDValue(N, 0);
16552 
16553   // Turn 'store float 1.0, Ptr' -> 'store int 0x3F800000, Ptr'
16554   //
16555   // Make sure to do this only after attempting to merge stores in order to
16556   // avoid changing the types of some subset of stores due to visit order,
16557   // preventing their merging.
16558   if (isa<ConstantFPSDNode>(ST->getValue())) {
16559     if (SDValue NewSt = replaceStoreOfFPConstant(ST))
16560       return NewSt;
16561   }
16562 
16563   if (SDValue NewSt = splitMergedValStore(ST))
16564     return NewSt;
16565 
16566   return ReduceLoadOpStoreWidth(N);
16567 }
16568 
16569 SDValue DAGCombiner::visitLIFETIME_END(SDNode *N) {
16570   const auto *LifetimeEnd = cast<LifetimeSDNode>(N);
16571   if (!LifetimeEnd->hasOffset())
16572     return SDValue();
16573 
16574   const BaseIndexOffset LifetimeEndBase(N->getOperand(1), SDValue(),
16575                                         LifetimeEnd->getOffset(), false);
16576 
16577   // We walk up the chains to find stores.
16578   SmallVector<SDValue, 8> Chains = {N->getOperand(0)};
16579   while (!Chains.empty()) {
16580     SDValue Chain = Chains.back();
16581     Chains.pop_back();
16582     if (!Chain.hasOneUse())
16583       continue;
16584     switch (Chain.getOpcode()) {
16585     case ISD::TokenFactor:
16586       for (unsigned Nops = Chain.getNumOperands(); Nops;)
16587         Chains.push_back(Chain.getOperand(--Nops));
16588       break;
16589     case ISD::LIFETIME_START:
16590     case ISD::LIFETIME_END:
16591       // We can forward past any lifetime start/end that can be proven not to
16592       // alias the node.
16593       if (!isAlias(Chain.getNode(), N))
16594         Chains.push_back(Chain.getOperand(0));
16595       break;
16596     case ISD::STORE: {
16597       StoreSDNode *ST = cast<StoreSDNode>(Chain);
16598       // TODO: Can relax for unordered atomics (see D66309)
16599       if (!ST->isSimple() || ST->isIndexed())
16600         continue;
16601       const BaseIndexOffset StoreBase = BaseIndexOffset::match(ST, DAG);
16602       // If we store purely within object bounds just before its lifetime ends,
16603       // we can remove the store.
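      // For example, an i32 store at offset 8 into a 16-byte object whose
      // lifetime ends here is dead and can be removed.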
16604       if (LifetimeEndBase.contains(DAG, LifetimeEnd->getSize() * 8, StoreBase,
16605                                    ST->getMemoryVT().getStoreSizeInBits())) {
16606         LLVM_DEBUG(dbgs() << "\nRemoving store:"; StoreBase.dump();
16607                    dbgs() << "\nwithin LIFETIME_END of : ";
16608                    LifetimeEndBase.dump(); dbgs() << "\n");
16609         CombineTo(ST, ST->getChain());
16610         return SDValue(N, 0);
16611       }
16612     }
16613     }
16614   }
16615   return SDValue();
16616 }
16617 
16618 /// For the store-instruction sequence below, the F and I values
16619 /// are bundled together as an i64 value before being stored into memory.
16620 /// Sometimes it is more efficient to generate separate stores for F and I,
16621 /// which can remove the bitwise instructions or sink them to colder places.
16622 ///
16623 ///   (store (or (zext (bitcast F to i32) to i64),
16624 ///              (shl (zext I to i64), 32)), addr)  -->
16625 ///   (store F, addr) and (store I, addr+4)
16626 ///
16627 /// Similarly, splitting for other merged store can also be beneficial, like:
16628 /// For pair of {i32, i32}, i64 store --> two i32 stores.
16629 /// For pair of {i32, i16}, i64 store --> two i32 stores.
16630 /// For pair of {i16, i16}, i32 store --> two i16 stores.
16631 /// For pair of {i16, i8},  i32 store --> two i16 stores.
16632 /// For pair of {i8, i8},   i16 store --> two i8 stores.
16633 ///
16634 /// We allow each target to determine specifically which kind of splitting is
16635 /// supported.
16636 ///
16637 /// These store patterns are commonly seen in the simple code snippet below
16638 /// when only std::make_pair(...) is SROA-transformed before inlining into hoo.
16639 ///   void goo(const std::pair<int, float> &);
16640 ///   hoo() {
16641 ///     ...
16642 ///     goo(std::make_pair(tmp, ftmp));
16643 ///     ...
16644 ///   }
16645 ///
16646 SDValue DAGCombiner::splitMergedValStore(StoreSDNode *ST) {
16647   if (OptLevel == CodeGenOpt::None)
16648     return SDValue();
16649 
16650   // Can't change the number of memory accesses for a volatile store or break
16651   // atomicity for an atomic one.
16652   if (!ST->isSimple())
16653     return SDValue();
16654 
16655   SDValue Val = ST->getValue();
16656   SDLoc DL(ST);
16657 
16658   // Match OR operand.
16659   if (!Val.getValueType().isScalarInteger() || Val.getOpcode() != ISD::OR)
16660     return SDValue();
16661 
16662   // Match SHL operand and get Lower and Higher parts of Val.
16663   SDValue Op1 = Val.getOperand(0);
16664   SDValue Op2 = Val.getOperand(1);
16665   SDValue Lo, Hi;
16666   if (Op1.getOpcode() != ISD::SHL) {
16667     std::swap(Op1, Op2);
16668     if (Op1.getOpcode() != ISD::SHL)
16669       return SDValue();
16670   }
16671   Lo = Op2;
16672   Hi = Op1.getOperand(0);
16673   if (!Op1.hasOneUse())
16674     return SDValue();
16675 
16676   // Match shift amount to HalfValBitSize.
16677   unsigned HalfValBitSize = Val.getValueSizeInBits() / 2;
16678   ConstantSDNode *ShAmt = dyn_cast<ConstantSDNode>(Op1.getOperand(1));
16679   if (!ShAmt || ShAmt->getAPIntValue() != HalfValBitSize)
16680     return SDValue();
16681 
16682   // Lo and Hi must be zero-extended from scalar integers no wider than
16683   // half the stored value (e.g. i32 or narrower when storing an i64).
16684   if (Lo.getOpcode() != ISD::ZERO_EXTEND || !Lo.hasOneUse() ||
16685       !Lo.getOperand(0).getValueType().isScalarInteger() ||
16686       Lo.getOperand(0).getValueSizeInBits() > HalfValBitSize ||
16687       Hi.getOpcode() != ISD::ZERO_EXTEND || !Hi.hasOneUse() ||
16688       !Hi.getOperand(0).getValueType().isScalarInteger() ||
16689       Hi.getOperand(0).getValueSizeInBits() > HalfValBitSize)
16690     return SDValue();
16691 
16692   // Use the EVT of low and high parts before bitcast as the input
16693   // of target query.
16694   EVT LowTy = (Lo.getOperand(0).getOpcode() == ISD::BITCAST)
16695                   ? Lo.getOperand(0).getValueType()
16696                   : Lo.getValueType();
16697   EVT HighTy = (Hi.getOperand(0).getOpcode() == ISD::BITCAST)
16698                    ? Hi.getOperand(0).getValueType()
16699                    : Hi.getValueType();
16700   if (!TLI.isMultiStoresCheaperThanBitsMerge(LowTy, HighTy))
16701     return SDValue();
16702 
16703   // Start to split store.
16704   unsigned Alignment = ST->getAlignment();
16705   MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
16706   AAMDNodes AAInfo = ST->getAAInfo();
16707 
16708   // Change the sizes of Lo and Hi's value types to HalfValBitSize.
16709   EVT VT = EVT::getIntegerVT(*DAG.getContext(), HalfValBitSize);
16710   Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Lo.getOperand(0));
16711   Hi = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Hi.getOperand(0));
16712 
16713   SDValue Chain = ST->getChain();
16714   SDValue Ptr = ST->getBasePtr();
16715   // Lower value store.
16716   SDValue St0 = DAG.getStore(Chain, DL, Lo, Ptr, ST->getPointerInfo(),
16717                              ST->getAlignment(), MMOFlags, AAInfo);
16718   Ptr = DAG.getMemBasePlusOffset(Ptr, HalfValBitSize / 8, DL);
16719   // Higher value store.
16720   SDValue St1 =
16721       DAG.getStore(St0, DL, Hi, Ptr,
16722                    ST->getPointerInfo().getWithOffset(HalfValBitSize / 8),
16723                    Alignment / 2, MMOFlags, AAInfo);
16724   return St1;
16725 }
16726 
16727 /// Convert a disguised subvector insertion into a shuffle:
16728 SDValue DAGCombiner::combineInsertEltToShuffle(SDNode *N, unsigned InsIndex) {
16729   assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT &&
16730          "Expected insert_vector_elt");
16731   SDValue InsertVal = N->getOperand(1);
16732   SDValue Vec = N->getOperand(0);
16733 
16734   // (insert_vector_elt (vector_shuffle X, Y), (extract_vector_elt X, N),
16735   // InsIndex)
16736   //   --> (vector_shuffle X, Y) and variations where shuffle operands may be
16737   //   CONCAT_VECTORS.
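  // For example, with 4-element vectors:
  //   (insert_vector_elt (vector_shuffle X, Y, <4,1,6,7>),
  //                      (extract_vector_elt X, 2), 0)
  //   --> (vector_shuffle X, Y, <2,1,6,7>)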
16738   if (Vec.getOpcode() == ISD::VECTOR_SHUFFLE && Vec.hasOneUse() &&
16739       InsertVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
16740       isa<ConstantSDNode>(InsertVal.getOperand(1))) {
16741     ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Vec.getNode());
16742     ArrayRef<int> Mask = SVN->getMask();
16743 
16744     SDValue X = Vec.getOperand(0);
16745     SDValue Y = Vec.getOperand(1);
16746 
16747     // Vec's operand 0 uses indices from 0 to N-1 and its
16748     // operand 1 uses indices from N to 2N-1, where N is the number of
16749     // elements in the vectors.
16750     SDValue InsertVal0 = InsertVal.getOperand(0);
16751     int ElementOffset = -1;
16752 
16753     // We explore the inputs of the shuffle in order to see if we find the
16754     // source of the extract_vector_elt. If so, we can use it to modify the
16755     // shuffle rather than perform an insert_vector_elt.
16756     SmallVector<std::pair<int, SDValue>, 8> ArgWorkList;
16757     ArgWorkList.emplace_back(Mask.size(), Y);
16758     ArgWorkList.emplace_back(0, X);
16759 
16760     while (!ArgWorkList.empty()) {
16761       int ArgOffset;
16762       SDValue ArgVal;
16763       std::tie(ArgOffset, ArgVal) = ArgWorkList.pop_back_val();
16764 
16765       if (ArgVal == InsertVal0) {
16766         ElementOffset = ArgOffset;
16767         break;
16768       }
16769 
16770       // Peek through concat_vector.
16771       if (ArgVal.getOpcode() == ISD::CONCAT_VECTORS) {
16772         int CurrentArgOffset =
16773             ArgOffset + ArgVal.getValueType().getVectorNumElements();
16774         int Step = ArgVal.getOperand(0).getValueType().getVectorNumElements();
16775         for (SDValue Op : reverse(ArgVal->ops())) {
16776           CurrentArgOffset -= Step;
16777           ArgWorkList.emplace_back(CurrentArgOffset, Op);
16778         }
16779 
16780         // Make sure we went through all the elements and did not screw up index
16781         // computation.
16782         assert(CurrentArgOffset == ArgOffset);
16783       }
16784     }
16785 
16786     if (ElementOffset != -1) {
16787       SmallVector<int, 16> NewMask(Mask.begin(), Mask.end());
16788 
16789       auto *ExtrIndex = cast<ConstantSDNode>(InsertVal.getOperand(1));
16790       NewMask[InsIndex] = ElementOffset + ExtrIndex->getZExtValue();
16791       assert(NewMask[InsIndex] <
16792                  (int)(2 * Vec.getValueType().getVectorNumElements()) &&
16793              NewMask[InsIndex] >= 0 && "NewMask[InsIndex] is out of bound");
16794 
16795       SDValue LegalShuffle =
16796               TLI.buildLegalVectorShuffle(Vec.getValueType(), SDLoc(N), X,
16797                                           Y, NewMask, DAG);
16798       if (LegalShuffle)
16799         return LegalShuffle;
16800     }
16801   }
16802 
16803   // insert_vector_elt V, (bitcast X from vector type), IdxC -->
16804   // bitcast(shuffle (bitcast V), (extended X), Mask)
16805   // Note: We do not use an insert_subvector node because that requires a
16806   // legal subvector type.
16807   if (InsertVal.getOpcode() != ISD::BITCAST || !InsertVal.hasOneUse() ||
16808       !InsertVal.getOperand(0).getValueType().isVector())
16809     return SDValue();
16810 
16811   SDValue SubVec = InsertVal.getOperand(0);
16812   SDValue DestVec = N->getOperand(0);
16813   EVT SubVecVT = SubVec.getValueType();
16814   EVT VT = DestVec.getValueType();
16815   unsigned NumSrcElts = SubVecVT.getVectorNumElements();
16816   unsigned ExtendRatio = VT.getSizeInBits() / SubVecVT.getSizeInBits();
16817   unsigned NumMaskVals = ExtendRatio * NumSrcElts;
16818 
16819   // Step 1: Create a shuffle mask that implements this insert operation. The
16820   // vector that we are inserting into will be operand 0 of the shuffle, so
16821   // those elements are just 'i'. The inserted subvector is in the first
16822   // positions of operand 1 of the shuffle. Example:
16823   // insert v4i32 V, (v2i16 X), 2 --> shuffle v8i16 V', X', {0,1,2,3,8,9,6,7}
16824   SmallVector<int, 16> Mask(NumMaskVals);
16825   for (unsigned i = 0; i != NumMaskVals; ++i) {
16826     if (i / NumSrcElts == InsIndex)
16827       Mask[i] = (i % NumSrcElts) + NumMaskVals;
16828     else
16829       Mask[i] = i;
16830   }
16831 
16832   // Bail out if the target can not handle the shuffle we want to create.
16833   EVT SubVecEltVT = SubVecVT.getVectorElementType();
16834   EVT ShufVT = EVT::getVectorVT(*DAG.getContext(), SubVecEltVT, NumMaskVals);
16835   if (!TLI.isShuffleMaskLegal(Mask, ShufVT))
16836     return SDValue();
16837 
16838   // Step 2: Create a wide vector from the inserted source vector by appending
16839   // undefined elements. This is the same size as our destination vector.
16840   SDLoc DL(N);
16841   SmallVector<SDValue, 8> ConcatOps(ExtendRatio, DAG.getUNDEF(SubVecVT));
16842   ConcatOps[0] = SubVec;
16843   SDValue PaddedSubV = DAG.getNode(ISD::CONCAT_VECTORS, DL, ShufVT, ConcatOps);
16844 
16845   // Step 3: Shuffle in the padded subvector.
16846   SDValue DestVecBC = DAG.getBitcast(ShufVT, DestVec);
16847   SDValue Shuf = DAG.getVectorShuffle(ShufVT, DL, DestVecBC, PaddedSubV, Mask);
16848   AddToWorklist(PaddedSubV.getNode());
16849   AddToWorklist(DestVecBC.getNode());
16850   AddToWorklist(Shuf.getNode());
16851   return DAG.getBitcast(VT, Shuf);
16852 }
16853 
16854 SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
16855   SDValue InVec = N->getOperand(0);
16856   SDValue InVal = N->getOperand(1);
16857   SDValue EltNo = N->getOperand(2);
16858   SDLoc DL(N);
16859 
16860   EVT VT = InVec.getValueType();
16861   unsigned NumElts = VT.getVectorNumElements();
16862 
16863   // Insert into out-of-bounds element is undefined.
16864   if (auto *IndexC = dyn_cast<ConstantSDNode>(EltNo))
16865     if (IndexC->getZExtValue() >= VT.getVectorNumElements())
16866       return DAG.getUNDEF(VT);
16867 
16868   // Remove redundant insertions:
16869   // (insert_vector_elt x (extract_vector_elt x idx) idx) -> x
16870   if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
16871       InVec == InVal.getOperand(0) && EltNo == InVal.getOperand(1))
16872     return InVec;
16873 
16874   auto *IndexC = dyn_cast<ConstantSDNode>(EltNo);
16875   if (!IndexC) {
16876     // A variable insert into an undef vector may be better off as a splat:
16877     // inselt undef, InVal, EltNo --> build_vector < InVal, InVal, ... >
16878     if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT)) {
16879       SmallVector<SDValue, 8> Ops(NumElts, InVal);
16880       return DAG.getBuildVector(VT, DL, Ops);
16881     }
16882     return SDValue();
16883   }
16884 
16885   // We must know which element is being inserted for folds below here.
16886   unsigned Elt = IndexC->getZExtValue();
16887   if (SDValue Shuf = combineInsertEltToShuffle(N, Elt))
16888     return Shuf;
16889 
16890   // Canonicalize insert_vector_elt dag nodes.
16891   // Example:
16892   // (insert_vector_elt (insert_vector_elt A, Idx0), Idx1)
16893   // -> (insert_vector_elt (insert_vector_elt A, Idx1), Idx0)
16894   //
16895   // Do this only if the child insert_vector node has one use; also
16896   // do this only if indices are both constants and Idx1 < Idx0.
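  // For example, with constant indices 3 and 1:
  //   (insert_vector_elt (insert_vector_elt A, x, 3), y, 1)
  //   -> (insert_vector_elt (insert_vector_elt A, y, 1), x, 3)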
16897   if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT && InVec.hasOneUse()
16898       && isa<ConstantSDNode>(InVec.getOperand(2))) {
16899     unsigned OtherElt = InVec.getConstantOperandVal(2);
16900     if (Elt < OtherElt) {
16901       // Swap nodes.
16902       SDValue NewOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT,
16903                                   InVec.getOperand(0), InVal, EltNo);
16904       AddToWorklist(NewOp.getNode());
16905       return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(InVec.getNode()),
16906                          VT, NewOp, InVec.getOperand(1), InVec.getOperand(2));
16907     }
16908   }
16909 
16910   // If we can't generate a legal BUILD_VECTOR, exit
16911   if (LegalOperations && !TLI.isOperationLegal(ISD::BUILD_VECTOR, VT))
16912     return SDValue();
16913 
16914   // Check that the operand is a BUILD_VECTOR (or UNDEF, which can essentially
16915   // be converted to a BUILD_VECTOR).  Fill in the Ops vector with the
16916   // vector elements.
16917   SmallVector<SDValue, 8> Ops;
16918   // Do not combine these two vectors if the output vector will not replace
16919   // the input vector.
16920   if (InVec.getOpcode() == ISD::BUILD_VECTOR && InVec.hasOneUse()) {
16921     Ops.append(InVec.getNode()->op_begin(),
16922                InVec.getNode()->op_end());
16923   } else if (InVec.isUndef()) {
16924     Ops.append(NumElts, DAG.getUNDEF(InVal.getValueType()));
16925   } else {
16926     return SDValue();
16927   }
16928   assert(Ops.size() == NumElts && "Unexpected vector size");
16929 
16930   // Insert the element
16931   if (Elt < Ops.size()) {
16932     // All the operands of BUILD_VECTOR must have the same type;
16933     // we enforce that here.
16934     EVT OpVT = Ops[0].getValueType();
16935     Ops[Elt] = OpVT.isInteger() ? DAG.getAnyExtOrTrunc(InVal, DL, OpVT) : InVal;
16936   }
16937 
16938   // Return the new vector
16939   return DAG.getBuildVector(VT, DL, Ops);
16940 }
16941 
16942 SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
16943                                                   SDValue EltNo,
16944                                                   LoadSDNode *OriginalLoad) {
16945   assert(OriginalLoad->isSimple());
16946 
16947   EVT ResultVT = EVE->getValueType(0);
16948   EVT VecEltVT = InVecVT.getVectorElementType();
16949   unsigned Align = OriginalLoad->getAlignment();
16950   unsigned NewAlign = DAG.getDataLayout().getABITypeAlignment(
16951       VecEltVT.getTypeForEVT(*DAG.getContext()));
16952 
16953   if (NewAlign > Align || !TLI.isOperationLegalOrCustom(ISD::LOAD, VecEltVT))
16954     return SDValue();
16955 
16956   ISD::LoadExtType ExtTy = ResultVT.bitsGT(VecEltVT) ?
16957     ISD::NON_EXTLOAD : ISD::EXTLOAD;
16958   if (!TLI.shouldReduceLoadWidth(OriginalLoad, ExtTy, VecEltVT))
16959     return SDValue();
16960 
16961   Align = NewAlign;
16962 
16963   SDValue NewPtr = OriginalLoad->getBasePtr();
16964   SDValue Offset;
16965   EVT PtrType = NewPtr.getValueType();
16966   MachinePointerInfo MPI;
16967   SDLoc DL(EVE);
16968   if (auto *ConstEltNo = dyn_cast<ConstantSDNode>(EltNo)) {
16969     int Elt = ConstEltNo->getZExtValue();
16970     unsigned PtrOff = VecEltVT.getSizeInBits() * Elt / 8;
16971     Offset = DAG.getConstant(PtrOff, DL, PtrType);
16972     MPI = OriginalLoad->getPointerInfo().getWithOffset(PtrOff);
16973   } else {
16974     Offset = DAG.getZExtOrTrunc(EltNo, DL, PtrType);
16975     Offset = DAG.getNode(
16976         ISD::MUL, DL, PtrType, Offset,
16977         DAG.getConstant(VecEltVT.getStoreSize(), DL, PtrType));
16978     // Discard the pointer info except the address space because the memory
16979     // operand can't represent this new access since the offset is variable.
16980     MPI = MachinePointerInfo(OriginalLoad->getPointerInfo().getAddrSpace());
16981   }
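  // For example, extracting element 2 of a loaded v4f32 becomes an f32 load
  // from the base address plus 8 bytes (2 * the 4-byte element store size).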
16982   NewPtr = DAG.getMemBasePlusOffset(NewPtr, Offset, DL);
16983 
16984   // The replacement we need to do here is a little tricky: we need to
16985   // replace an extractelement of a load with a load.
16986   // Use ReplaceAllUsesOfValuesWith to do the replacement.
16987   // Note that this replacement assumes that the extract element is the only
16988   // use of the load; that's okay because we don't want to perform this
16989   // transformation in other cases anyway.
16990   SDValue Load;
16991   SDValue Chain;
16992   if (ResultVT.bitsGT(VecEltVT)) {
16993     // If the result type of vextract is wider than the load, then issue an
16994     // extending load instead.
16995     ISD::LoadExtType ExtType = TLI.isLoadExtLegal(ISD::ZEXTLOAD, ResultVT,
16996                                                   VecEltVT)
16997                                    ? ISD::ZEXTLOAD
16998                                    : ISD::EXTLOAD;
16999     Load = DAG.getExtLoad(ExtType, SDLoc(EVE), ResultVT,
17000                           OriginalLoad->getChain(), NewPtr, MPI, VecEltVT,
17001                           Align, OriginalLoad->getMemOperand()->getFlags(),
17002                           OriginalLoad->getAAInfo());
17003     Chain = Load.getValue(1);
17004   } else {
17005     Load = DAG.getLoad(VecEltVT, SDLoc(EVE), OriginalLoad->getChain(), NewPtr,
17006                        MPI, Align, OriginalLoad->getMemOperand()->getFlags(),
17007                        OriginalLoad->getAAInfo());
17008     Chain = Load.getValue(1);
17009     if (ResultVT.bitsLT(VecEltVT))
17010       Load = DAG.getNode(ISD::TRUNCATE, SDLoc(EVE), ResultVT, Load);
17011     else
17012       Load = DAG.getBitcast(ResultVT, Load);
17013   }
17014   WorklistRemover DeadNodes(*this);
17015   SDValue From[] = { SDValue(EVE, 0), SDValue(OriginalLoad, 1) };
17016   SDValue To[] = { Load, Chain };
17017   DAG.ReplaceAllUsesOfValuesWith(From, To, 2);
17018   // Make sure to revisit this node to clean it up; it will usually be dead.
17019   AddToWorklist(EVE);
17020   // Since we're explicitly calling ReplaceAllUses, add the new node to the
17021   // worklist explicitly as well.
17022   AddToWorklistWithUsers(Load.getNode());
17023   ++OpsNarrowed;
17024   return SDValue(EVE, 0);
17025 }
17026 
17027 /// Transform a vector binary operation into a scalar binary operation by moving
17028 /// the math/logic after an extract element of a vector.
17029 static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
17030                                        bool LegalOperations) {
17031   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
17032   SDValue Vec = ExtElt->getOperand(0);
17033   SDValue Index = ExtElt->getOperand(1);
17034   auto *IndexC = dyn_cast<ConstantSDNode>(Index);
17035   if (!IndexC || !TLI.isBinOp(Vec.getOpcode()) || !Vec.hasOneUse() ||
17036       Vec.getNode()->getNumValues() != 1)
17037     return SDValue();
17038 
17039   // Targets may want to avoid this to prevent an expensive register transfer.
17040   if (!TLI.shouldScalarizeBinop(Vec))
17041     return SDValue();
17042 
17043   // Extracting an element of a vector constant is constant-folded, so this
17044   // transform is just replacing a vector op with a scalar op while moving the
17045   // extract.
17046   SDValue Op0 = Vec.getOperand(0);
17047   SDValue Op1 = Vec.getOperand(1);
17048   if (isAnyConstantBuildVector(Op0, true) ||
17049       isAnyConstantBuildVector(Op1, true)) {
17050     // extractelt (binop X, C), IndexC --> binop (extractelt X, IndexC), C'
17051     // extractelt (binop C, X), IndexC --> binop C', (extractelt X, IndexC)
17052     SDLoc DL(ExtElt);
17053     EVT VT = ExtElt->getValueType(0);
17054     SDValue Ext0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Index);
17055     SDValue Ext1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op1, Index);
17056     return DAG.getNode(Vec.getOpcode(), DL, VT, Ext0, Ext1);
17057   }
17058 
17059   return SDValue();
17060 }
17061 
17062 SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
17063   SDValue VecOp = N->getOperand(0);
17064   SDValue Index = N->getOperand(1);
17065   EVT ScalarVT = N->getValueType(0);
17066   EVT VecVT = VecOp.getValueType();
17067   if (VecOp.isUndef())
17068     return DAG.getUNDEF(ScalarVT);
17069 
17070   // (extract_vector_elt (insert_vector_elt vec, val, idx), idx) -> val
17071   //
17072   // This only really matters if the index is non-constant since other combines
17073   // on the constant elements already work.
17074   SDLoc DL(N);
17075   if (VecOp.getOpcode() == ISD::INSERT_VECTOR_ELT &&
17076       Index == VecOp.getOperand(2)) {
17077     SDValue Elt = VecOp.getOperand(1);
17078     return VecVT.isInteger() ? DAG.getAnyExtOrTrunc(Elt, DL, ScalarVT) : Elt;
17079   }
17080 
17081   // (vextract (scalar_to_vector val), 0) -> val
17082   if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR) {
17083     // Check if the result type doesn't match the inserted element type. A
17084     // SCALAR_TO_VECTOR may truncate the inserted element and the
17085     // EXTRACT_VECTOR_ELT may widen the extracted vector.
17086     SDValue InOp = VecOp.getOperand(0);
17087     if (InOp.getValueType() != ScalarVT) {
17088       assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
17089       return DAG.getSExtOrTrunc(InOp, DL, ScalarVT);
17090     }
17091     return InOp;
17092   }
17093 
17094   // extract_vector_elt of out-of-bounds element -> UNDEF
17095   auto *IndexC = dyn_cast<ConstantSDNode>(Index);
17096   unsigned NumElts = VecVT.getVectorNumElements();
17097   if (IndexC && IndexC->getAPIntValue().uge(NumElts))
17098     return DAG.getUNDEF(ScalarVT);
17099 
17100   // extract_vector_elt (build_vector x, y), 1 -> y
17101   if (IndexC && VecOp.getOpcode() == ISD::BUILD_VECTOR &&
17102       TLI.isTypeLegal(VecVT) &&
17103       (VecOp.hasOneUse() || TLI.aggressivelyPreferBuildVectorSources(VecVT))) {
17104     SDValue Elt = VecOp.getOperand(IndexC->getZExtValue());
17105     EVT InEltVT = Elt.getValueType();
17106 
17107     // Sometimes a build_vector's scalar inputs do not match the result type.
17108     if (ScalarVT == InEltVT)
17109       return Elt;
17110 
17111     // TODO: It may be useful to truncate if free if the build_vector implicitly
17112     // converts.
17113   }
17114 
17115   // TODO: These transforms should not require the 'hasOneUse' restriction, but
17116   // there are regressions on multiple targets without it. We can end up with a
17117   // mess of scalar and vector code if we reduce only part of the DAG to scalar.
17118   if (IndexC && VecOp.getOpcode() == ISD::BITCAST && VecVT.isInteger() &&
17119       VecOp.hasOneUse()) {
17120     // The vector index of the LSBs of the source depends on the endianness.
17121     bool IsLE = DAG.getDataLayout().isLittleEndian();
17122     unsigned ExtractIndex = IndexC->getZExtValue();
17123     // extract_elt (v2i32 (bitcast i64:x)), BCTruncElt -> i32 (trunc i64:x)
17124     unsigned BCTruncElt = IsLE ? 0 : NumElts - 1;
17125     SDValue BCSrc = VecOp.getOperand(0);
17126     if (ExtractIndex == BCTruncElt && BCSrc.getValueType().isScalarInteger())
17127       return DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, BCSrc);
17128 
17129     if (LegalTypes && BCSrc.getValueType().isInteger() &&
17130         BCSrc.getOpcode() == ISD::SCALAR_TO_VECTOR) {
17131       // ext_elt (bitcast (scalar_to_vec i64 X to v2i64) to v4i32), TruncElt -->
17132       // trunc i64 X to i32
17133       SDValue X = BCSrc.getOperand(0);
17134       assert(X.getValueType().isScalarInteger() && ScalarVT.isScalarInteger() &&
17135              "Extract element and scalar to vector can't change element type "
17136              "from FP to integer.");
17137       unsigned XBitWidth = X.getValueSizeInBits();
17138       unsigned VecEltBitWidth = VecVT.getScalarSizeInBits();
17139       BCTruncElt = IsLE ? 0 : XBitWidth / VecEltBitWidth - 1;
17140 
17141       // An extract element return value type can be wider than its vector
17142       // operand element type. In that case, the high bits are undefined, so
17143       // it's possible that we may need to extend rather than truncate.
17144       if (ExtractIndex == BCTruncElt && XBitWidth > VecEltBitWidth) {
17145         assert(XBitWidth % VecEltBitWidth == 0 &&
17146                "Scalar bitwidth must be a multiple of vector element bitwidth");
17147         return DAG.getAnyExtOrTrunc(X, DL, ScalarVT);
17148       }
17149     }
17150   }
17151 
17152   if (SDValue BO = scalarizeExtractedBinop(N, DAG, LegalOperations))
17153     return BO;
17154 
17155   // Transform: (EXTRACT_VECTOR_ELT( VECTOR_SHUFFLE )) -> EXTRACT_VECTOR_ELT.
17156   // We only perform this optimization before the op legalization phase because
17157   // we may introduce new vector instructions which are not backed by TD
17158   // patterns (for example, on AVX, extracting elements from a wide vector
17159   // without using extract_subvector). However, if we can find an underlying
17160   // scalar value, then we can always use that.
17161   if (IndexC && VecOp.getOpcode() == ISD::VECTOR_SHUFFLE) {
17162     auto *Shuf = cast<ShuffleVectorSDNode>(VecOp);
17163     // Find the new index to extract from.
17164     int OrigElt = Shuf->getMaskElt(IndexC->getZExtValue());
17165 
17166     // Extracting an undef index is undef.
17167     if (OrigElt == -1)
17168       return DAG.getUNDEF(ScalarVT);
17169 
17170     // Select the right vector half to extract from.
17171     SDValue SVInVec;
17172     if (OrigElt < (int)NumElts) {
17173       SVInVec = VecOp.getOperand(0);
17174     } else {
17175       SVInVec = VecOp.getOperand(1);
17176       OrigElt -= NumElts;
17177     }
17178 
17179     if (SVInVec.getOpcode() == ISD::BUILD_VECTOR) {
17180       SDValue InOp = SVInVec.getOperand(OrigElt);
17181       if (InOp.getValueType() != ScalarVT) {
17182         assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
17183         InOp = DAG.getSExtOrTrunc(InOp, DL, ScalarVT);
17184       }
17185 
17186       return InOp;
17187     }
17188 
17189     // FIXME: We should handle recursing on other vector shuffles and
17190     // scalar_to_vector here as well.
17191 
17192     if (!LegalOperations ||
17193         // FIXME: Should really be just isOperationLegalOrCustom.
17194         TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VecVT) ||
17195         TLI.isOperationExpand(ISD::VECTOR_SHUFFLE, VecVT)) {
17196       EVT IndexTy = TLI.getVectorIdxTy(DAG.getDataLayout());
17197       return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, SVInVec,
17198                          DAG.getConstant(OrigElt, DL, IndexTy));
17199     }
17200   }
17201 
17202   // If only EXTRACT_VECTOR_ELT nodes use the source vector we can
17203   // simplify it based on the (valid) extraction indices.
17204   if (llvm::all_of(VecOp->uses(), [&](SDNode *Use) {
17205         return Use->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
17206                Use->getOperand(0) == VecOp &&
17207                isa<ConstantSDNode>(Use->getOperand(1));
17208       })) {
17209     APInt DemandedElts = APInt::getNullValue(NumElts);
17210     for (SDNode *Use : VecOp->uses()) {
17211       auto *CstElt = cast<ConstantSDNode>(Use->getOperand(1));
17212       if (CstElt->getAPIntValue().ult(NumElts))
17213         DemandedElts.setBit(CstElt->getZExtValue());
17214     }
17215     if (SimplifyDemandedVectorElts(VecOp, DemandedElts, true)) {
17216       // We simplified the vector operand of this extract element. If this
17217       // extract is not dead, visit it again so it is folded properly.
17218       if (N->getOpcode() != ISD::DELETED_NODE)
17219         AddToWorklist(N);
17220       return SDValue(N, 0);
17221     }
17222   }
17223 
17224   // Everything under here is trying to match an extract of a loaded value.
17225   // If the result of the load has to be truncated, then it's not necessarily
17226   // profitable.
17227   bool BCNumEltsChanged = false;
17228   EVT ExtVT = VecVT.getVectorElementType();
17229   EVT LVT = ExtVT;
17230   if (ScalarVT.bitsLT(LVT) && !TLI.isTruncateFree(LVT, ScalarVT))
17231     return SDValue();
17232 
17233   if (VecOp.getOpcode() == ISD::BITCAST) {
17234     // Don't duplicate a load with other uses.
17235     if (!VecOp.hasOneUse())
17236       return SDValue();
17237 
17238     EVT BCVT = VecOp.getOperand(0).getValueType();
17239     if (!BCVT.isVector() || ExtVT.bitsGT(BCVT.getVectorElementType()))
17240       return SDValue();
17241     if (NumElts != BCVT.getVectorNumElements())
17242       BCNumEltsChanged = true;
17243     VecOp = VecOp.getOperand(0);
17244     ExtVT = BCVT.getVectorElementType();
17245   }
17246 
17247   // extract (vector load $addr), i --> load $addr + i * size
17248   if (!LegalOperations && !IndexC && VecOp.hasOneUse() &&
17249       ISD::isNormalLoad(VecOp.getNode()) &&
17250       !Index->hasPredecessor(VecOp.getNode())) {
17251     auto *VecLoad = dyn_cast<LoadSDNode>(VecOp);
17252     if (VecLoad && VecLoad->isSimple())
17253       return scalarizeExtractedVectorLoad(N, VecVT, Index, VecLoad);
17254   }
17255 
17256   // Perform only after legalization to ensure build_vector / vector_shuffle
17257   // optimizations have already been done.
17258   if (!LegalOperations || !IndexC)
17259     return SDValue();
17260 
17261   // (vextract (v4f32 load $addr), c) -> (f32 load $addr+c*size)
17262   // (vextract (v4f32 s2v (f32 load $addr)), c) -> (f32 load $addr+c*size)
17263   // (vextract (v4f32 shuffle (load $addr), <1,u,u,u>), 0) -> (f32 load $addr)
17264   int Elt = IndexC->getZExtValue();
17265   LoadSDNode *LN0 = nullptr;
17266   if (ISD::isNormalLoad(VecOp.getNode())) {
17267     LN0 = cast<LoadSDNode>(VecOp);
17268   } else if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR &&
17269              VecOp.getOperand(0).getValueType() == ExtVT &&
17270              ISD::isNormalLoad(VecOp.getOperand(0).getNode())) {
17271     // Don't duplicate a load with other uses.
17272     if (!VecOp.hasOneUse())
17273       return SDValue();
17274 
17275     LN0 = cast<LoadSDNode>(VecOp.getOperand(0));
17276   }
17277   if (auto *Shuf = dyn_cast<ShuffleVectorSDNode>(VecOp)) {
17278     // (vextract (vector_shuffle (load $addr), v2, <1, u, u, u>), 1)
17279     // =>
17280     // (load $addr+1*size)
17281 
17282     // Don't duplicate a load with other uses.
17283     if (!VecOp.hasOneUse())
17284       return SDValue();
17285 
17286     // If the bit convert changed the number of elements, it is unsafe
17287     // to examine the mask.
17288     if (BCNumEltsChanged)
17289       return SDValue();
17290 
17291     // Select the input vector, guarding against out of range extract vector.
17292     int Idx = (Elt > (int)NumElts) ? -1 : Shuf->getMaskElt(Elt);
17293     VecOp = (Idx < (int)NumElts) ? VecOp.getOperand(0) : VecOp.getOperand(1);
17294 
17295     if (VecOp.getOpcode() == ISD::BITCAST) {
17296       // Don't duplicate a load with other uses.
17297       if (!VecOp.hasOneUse())
17298         return SDValue();
17299 
17300       VecOp = VecOp.getOperand(0);
17301     }
17302     if (ISD::isNormalLoad(VecOp.getNode())) {
17303       LN0 = cast<LoadSDNode>(VecOp);
17304       Elt = (Idx < (int)NumElts) ? Idx : Idx - (int)NumElts;
17305       Index = DAG.getConstant(Elt, DL, Index.getValueType());
17306     }
17307   }
17308 
17309   // Make sure we found a simple (non-volatile, non-atomic) load and that the
17310   // extractelement is its only use.
17311   if (!LN0 || !LN0->hasNUsesOfValue(1, 0) || !LN0->isSimple())
17312     return SDValue();
17313 
17314   // If Idx was -1 above, Elt is going to be -1, so just return undef.
17315   if (Elt == -1)
17316     return DAG.getUNDEF(LVT);
17317 
17318   return scalarizeExtractedVectorLoad(N, VecVT, Index, LN0);
17319 }
17320 
17321 // Simplify (build_vec (ext )) to (bitcast (build_vec ))
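// For example, on a little-endian target:
//   (v4i32 build_vector (zext i16 a), (zext i16 b),
//                       (zext i16 c), (zext i16 d))
//   --> (bitcast (v8i16 build_vector a, 0, b, 0, c, 0, d, 0) to v4i32)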
17322 SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) {
17323   // We perform this optimization post type-legalization because
17324   // the type-legalizer often scalarizes integer-promoted vectors.
17325   // Performing this optimization before may create bit-casts which
17326   // will be type-legalized to complex code sequences.
17327   // We perform this optimization only before the operation legalizer because we
17328   // may introduce illegal operations.
17329   if (Level != AfterLegalizeVectorOps && Level != AfterLegalizeTypes)
17330     return SDValue();
17331 
17332   unsigned NumInScalars = N->getNumOperands();
17333   SDLoc DL(N);
17334   EVT VT = N->getValueType(0);
17335 
17336   // Check to see if this is a BUILD_VECTOR of a bunch of values
17337   // which come from any_extend or zero_extend nodes. If so, we can create
17338   // a new BUILD_VECTOR using bit-casts which may enable other BUILD_VECTOR
17339   // optimizations. We do not handle sign-extend because we can't fill the sign
17340   // using shuffles.
17341   EVT SourceType = MVT::Other;
17342   bool AllAnyExt = true;
17343 
17344   for (unsigned i = 0; i != NumInScalars; ++i) {
17345     SDValue In = N->getOperand(i);
17346     // Ignore undef inputs.
17347     if (In.isUndef()) continue;
17348 
17349     bool AnyExt  = In.getOpcode() == ISD::ANY_EXTEND;
17350     bool ZeroExt = In.getOpcode() == ISD::ZERO_EXTEND;
17351 
17352     // Abort if the element is not an extension.
17353     if (!ZeroExt && !AnyExt) {
17354       SourceType = MVT::Other;
17355       break;
17356     }
17357 
17358     // The input is a ZeroExt or AnyExt. Check the original type.
17359     EVT InTy = In.getOperand(0).getValueType();
17360 
17361     // Check that all of the widened source types are the same.
17362     if (SourceType == MVT::Other)
17363       // First time.
17364       SourceType = InTy;
17365     else if (InTy != SourceType) {
17366       // Multiple input types. Abort.
17367       SourceType = MVT::Other;
17368       break;
17369     }
17370 
17371     // Check if all of the extends are ANY_EXTENDs.
17372     AllAnyExt &= AnyExt;
17373   }
17374 
17375   // In order to have valid types, all of the inputs must be extended from the
17376   // same source type and all of the inputs must be any or zero extend.
17377   // Scalar sizes must be a power of two.
17378   EVT OutScalarTy = VT.getScalarType();
17379   bool ValidTypes = SourceType != MVT::Other &&
17380                  isPowerOf2_32(OutScalarTy.getSizeInBits()) &&
17381                  isPowerOf2_32(SourceType.getSizeInBits());
17382 
17383   // Create a new simpler BUILD_VECTOR sequence which other optimizations can
17384   // turn into a single shuffle instruction.
17385   if (!ValidTypes)
17386     return SDValue();
17387 
17388   bool isLE = DAG.getDataLayout().isLittleEndian();
17389   unsigned ElemRatio = OutScalarTy.getSizeInBits()/SourceType.getSizeInBits();
17390   assert(ElemRatio > 1 && "Invalid element size ratio");
17391   SDValue Filler = AllAnyExt ? DAG.getUNDEF(SourceType):
17392                                DAG.getConstant(0, DL, SourceType);
17393 
17394   unsigned NewBVElems = ElemRatio * VT.getVectorNumElements();
17395   SmallVector<SDValue, 8> Ops(NewBVElems, Filler);
17396 
17397   // Populate the new build_vector
17398   for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
17399     SDValue Cast = N->getOperand(i);
17400     assert((Cast.getOpcode() == ISD::ANY_EXTEND ||
17401             Cast.getOpcode() == ISD::ZERO_EXTEND ||
17402             Cast.isUndef()) && "Invalid cast opcode");
17403     SDValue In;
17404     if (Cast.isUndef())
17405       In = DAG.getUNDEF(SourceType);
17406     else
17407       In = Cast->getOperand(0);
17408     unsigned Index = isLE ? (i * ElemRatio) :
17409                             (i * ElemRatio + (ElemRatio - 1));
17410 
17411     assert(Index < Ops.size() && "Invalid index");
17412     Ops[Index] = In;
17413   }
17414 
17415   // The type of the new BUILD_VECTOR node.
17416   EVT VecVT = EVT::getVectorVT(*DAG.getContext(), SourceType, NewBVElems);
17417   assert(VecVT.getSizeInBits() == VT.getSizeInBits() &&
17418          "Invalid vector size");
17419   // Check if the new vector type is legal.
17420   if (!isTypeLegal(VecVT) ||
17421       (!TLI.isOperationLegal(ISD::BUILD_VECTOR, VecVT) &&
17422        TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)))
17423     return SDValue();
17424 
17425   // Make the new BUILD_VECTOR.
17426   SDValue BV = DAG.getBuildVector(VecVT, DL, Ops);
17427 
17428   // The new BUILD_VECTOR node has the potential to be further optimized.
17429   AddToWorklist(BV.getNode());
17430   // Bitcast to the desired type.
17431   return DAG.getBitcast(VT, BV);
17432 }
17433 
17434 SDValue DAGCombiner::createBuildVecShuffle(const SDLoc &DL, SDNode *N,
17435                                            ArrayRef<int> VectorMask,
17436                                            SDValue VecIn1, SDValue VecIn2,
17437                                            unsigned LeftIdx, bool DidSplitVec) {
17438   MVT IdxTy = TLI.getVectorIdxTy(DAG.getDataLayout());
17439   SDValue ZeroIdx = DAG.getConstant(0, DL, IdxTy);
17440 
17441   EVT VT = N->getValueType(0);
17442   EVT InVT1 = VecIn1.getValueType();
17443   EVT InVT2 = VecIn2.getNode() ? VecIn2.getValueType() : InVT1;
17444 
17445   unsigned NumElems = VT.getVectorNumElements();
17446   unsigned ShuffleNumElems = NumElems;
17447 
17448   // If we artificially split a vector in two already, then the offsets in the
17449   // operands will all be based off of VecIn1, even those in VecIn2.
17450   unsigned Vec2Offset = DidSplitVec ? 0 : InVT1.getVectorNumElements();
17451 
17452   // We can't generate a shuffle node with mismatched input and output types.
17453   // Try to make the types match the type of the output.
17454   if (InVT1 != VT || InVT2 != VT) {
17455     if ((VT.getSizeInBits() % InVT1.getSizeInBits() == 0) && InVT1 == InVT2) {
17456       // If the output vector length is a multiple of both input lengths,
17457       // we can concatenate them and pad the rest with undefs.
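      // For example, two v2f32 inputs with a v8f32 output give NumConcats = 4:
      //   (concat_vectors VecIn1, VecIn2, undef, undef) : v8f32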
17458       unsigned NumConcats = VT.getSizeInBits() / InVT1.getSizeInBits();
17459       assert(NumConcats >= 2 && "Concat needs at least two inputs!");
17460       SmallVector<SDValue, 2> ConcatOps(NumConcats, DAG.getUNDEF(InVT1));
17461       ConcatOps[0] = VecIn1;
17462       ConcatOps[1] = VecIn2 ? VecIn2 : DAG.getUNDEF(InVT1);
17463       VecIn1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
17464       VecIn2 = SDValue();
17465     } else if (InVT1.getSizeInBits() == VT.getSizeInBits() * 2) {
17466       if (!TLI.isExtractSubvectorCheap(VT, InVT1, NumElems))
17467         return SDValue();
17468 
17469       if (!VecIn2.getNode()) {
17470         // If we only have one input vector, and it's twice the size of the
17471         // output, split it in two.
17472         VecIn2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, VecIn1,
17473                              DAG.getConstant(NumElems, DL, IdxTy));
17474         VecIn1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, VecIn1, ZeroIdx);
17475         // Since we now have shorter input vectors, adjust the offset of the
17476         // second vector's start.
17477         Vec2Offset = NumElems;
17478       } else if (InVT2.getSizeInBits() <= InVT1.getSizeInBits()) {
17479         // VecIn1 is wider than the output, and we have another, possibly
17480         // smaller input. Pad the smaller input with undefs, shuffle at the
17481         // input vector width, and extract the output.
17482         // The shuffle type is different than VT, so check legality again.
17483         if (LegalOperations &&
17484             !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, InVT1))
17485           return SDValue();
17486 
17487         // Legalizing INSERT_SUBVECTOR is tricky - you basically have to
17488         // lower it back into a BUILD_VECTOR. So if the inserted type is
17489         // illegal, don't even try.
17490         if (InVT1 != InVT2) {
17491           if (!TLI.isTypeLegal(InVT2))
17492             return SDValue();
17493           VecIn2 = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, InVT1,
17494                                DAG.getUNDEF(InVT1), VecIn2, ZeroIdx);
17495         }
17496         ShuffleNumElems = NumElems * 2;
17497       } else {
17498         // Both VecIn1 and VecIn2 are wider than the output, and VecIn2 is wider
17499         // than VecIn1. We can't handle this for now - this case will disappear
17500         // when we start sorting the vectors by type.
17501         return SDValue();
17502       }
17503     } else if (InVT2.getSizeInBits() * 2 == VT.getSizeInBits() &&
17504                InVT1.getSizeInBits() == VT.getSizeInBits()) {
17505       SmallVector<SDValue, 2> ConcatOps(2, DAG.getUNDEF(InVT2));
17506       ConcatOps[0] = VecIn2;
17507       VecIn2 = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
17508     } else {
17509       // TODO: Support cases where the length mismatch isn't exactly by a
17510       // factor of 2.
17511       // TODO: Move this check upwards, so that if we have bad type
17512       // mismatches, we don't create any DAG nodes.
17513       return SDValue();
17514     }
17515   }
17516 
17517   // Initialize mask to undef.
17518   SmallVector<int, 8> Mask(ShuffleNumElems, -1);
17519 
17520   // Only need to run up to the number of elements actually used, not the
17521   // total number of elements in the shuffle - if we are shuffling a wider
17522   // vector, the high lanes should be set to undef.
17523   for (unsigned i = 0; i != NumElems; ++i) {
17524     if (VectorMask[i] <= 0)
17525       continue;
17526 
17527     unsigned ExtIndex = N->getOperand(i).getConstantOperandVal(1);
17528     if (VectorMask[i] == (int)LeftIdx) {
17529       Mask[i] = ExtIndex;
17530     } else if (VectorMask[i] == (int)LeftIdx + 1) {
17531       Mask[i] = Vec2Offset + ExtIndex;
17532     }
17533   }
17534 
17535   // The type the input vectors may have changed above.
17536   InVT1 = VecIn1.getValueType();
17537 
17538   // If we already have a VecIn2, it should have the same type as VecIn1.
17539   // If we don't, get an undef/zero vector of the appropriate type.
17540   VecIn2 = VecIn2.getNode() ? VecIn2 : DAG.getUNDEF(InVT1);
17541   assert(InVT1 == VecIn2.getValueType() && "Unexpected second input type.");
17542 
17543   SDValue Shuffle = DAG.getVectorShuffle(InVT1, DL, VecIn1, VecIn2, Mask);
17544   if (ShuffleNumElems > NumElems)
17545     Shuffle = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Shuffle, ZeroIdx);
17546 
17547   return Shuffle;
17548 }
17549 
17550 static SDValue reduceBuildVecToShuffleWithZero(SDNode *BV, SelectionDAG &DAG) {
17551   assert(BV->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");
17552 
17553   // First, determine where the build vector is not undef.
17554   // TODO: We could extend this to handle zero elements as well as undefs.
17555   int NumBVOps = BV->getNumOperands();
17556   int ZextElt = -1;
17557   for (int i = 0; i != NumBVOps; ++i) {
17558     SDValue Op = BV->getOperand(i);
17559     if (Op.isUndef())
17560       continue;
17561     if (ZextElt == -1)
17562       ZextElt = i;
17563     else
17564       return SDValue();
17565   }
17566   // Bail out if there's no non-undef element.
17567   if (ZextElt == -1)
17568     return SDValue();
17569 
17570   // The build vector contains some number of undef elements and exactly
17571   // one other element. That other element must be a zero-extended scalar
17572   // extracted from a vector at a constant index to turn this into a shuffle.
17573   // Also, require that the build vector does not implicitly truncate/extend
17574   // its elements.
17575   // TODO: This could be enhanced to allow ANY_EXTEND as well as ZERO_EXTEND.
17576   EVT VT = BV->getValueType(0);
17577   SDValue Zext = BV->getOperand(ZextElt);
17578   if (Zext.getOpcode() != ISD::ZERO_EXTEND || !Zext.hasOneUse() ||
17579       Zext.getOperand(0).getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
17580       !isa<ConstantSDNode>(Zext.getOperand(0).getOperand(1)) ||
17581       Zext.getValueSizeInBits() != VT.getScalarSizeInBits())
17582     return SDValue();
17583 
17584   // The zero-extended size must be a multiple of the source size, and we
17585   // must be building a vector of the same size as the extract's source vector.
17586   SDValue Extract = Zext.getOperand(0);
17587   unsigned DestSize = Zext.getValueSizeInBits();
17588   unsigned SrcSize = Extract.getValueSizeInBits();
17589   if (DestSize % SrcSize != 0 ||
17590       Extract.getOperand(0).getValueSizeInBits() != VT.getSizeInBits())
17591     return SDValue();
17592 
17593   // Create a shuffle mask that will combine the extracted element with zeros
17594   // and undefs.
17595   int ZextRatio = DestSize / SrcSize;
17596   int NumMaskElts = NumBVOps * ZextRatio;
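  // For example (illustrative): for a v4i32 build vector whose only non-undef
  // element is at position 1 and is (zext i32 (extractelt (v8i16 V), 5)),
  // ZextRatio = 2 and NumMaskElts = 8, so the mask becomes
  // <u,u,5,8,u,u,u,u>: lane 2 takes element 5 of V and lane 3 takes the first
  // element of the zero vector (mask value 8) to clear the high bits.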
17597   SmallVector<int, 32> ShufMask(NumMaskElts, -1);
17598   for (int i = 0; i != NumMaskElts; ++i) {
17599     if (i / ZextRatio == ZextElt) {
17600       // The low bits of the (potentially translated) extracted element map to
17601       // the source vector. The high bits map to zero. We will use a zero vector
17602       // as the 2nd source operand of the shuffle, so use the 1st element of
17603       // that vector (mask value is number-of-elements) for the high bits.
17604       if (i % ZextRatio == 0)
17605         ShufMask[i] = Extract.getConstantOperandVal(1);
17606       else
17607         ShufMask[i] = NumMaskElts;
17608     }
17609 
17610     // Undef elements of the build vector remain undef because we initialize
17611     // the shuffle mask with -1.
17612   }
17613 
17614   // buildvec undef, ..., (zext (extractelt V, IndexC)), undef... -->
17615   // bitcast (shuffle V, ZeroVec, ShufMask)
17616   SDLoc DL(BV);
17617   EVT VecVT = Extract.getOperand(0).getValueType();
17618   SDValue ZeroVec = DAG.getConstant(0, DL, VecVT);
17619   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
17620   SDValue Shuf = TLI.buildLegalVectorShuffle(VecVT, DL, Extract.getOperand(0),
17621                                              ZeroVec, ShufMask, DAG);
17622   if (!Shuf)
17623     return SDValue();
17624   return DAG.getBitcast(VT, Shuf);
17625 }
17626 
17627 // Check to see if this is a BUILD_VECTOR of a bunch of EXTRACT_VECTOR_ELT
17628 // operations. If the types of the vectors we're extracting from allow it,
17629 // turn this into a vector_shuffle node.
17630 SDValue DAGCombiner::reduceBuildVecToShuffle(SDNode *N) {
17631   SDLoc DL(N);
17632   EVT VT = N->getValueType(0);
17633 
17634   // Only type-legal BUILD_VECTOR nodes are converted to shuffle nodes.
17635   if (!isTypeLegal(VT))
17636     return SDValue();
17637 
17638   if (SDValue V = reduceBuildVecToShuffleWithZero(N, DAG))
17639     return V;
17640 
17641   // May only combine to shuffle after legalize if shuffle is legal.
17642   if (LegalOperations && !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, VT))
17643     return SDValue();
17644 
17645   bool UsesZeroVector = false;
17646   unsigned NumElems = N->getNumOperands();
17647 
17648   // Record, for each element of the newly built vector, which input vector
17649   // that element comes from. -1 stands for undef, 0 for the zero vector,
17650   // and positive values for the input vectors.
17651   // VectorMask maps each element to its vector number, and VecIn maps vector
17652   // numbers to their initial SDValues.
17653 
17654   SmallVector<int, 8> VectorMask(NumElems, -1);
17655   SmallVector<SDValue, 8> VecIn;
17656   VecIn.push_back(SDValue());
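  // For example (illustrative), for
  //   (v4i32 build_vector undef, (extractelt A, 0), 0, (extractelt B, 1))
  // we end up with VectorMask = <-1, 1, 0, 2> and VecIn = <null, A, B>, and
  // UsesZeroVector is set because of the constant zero element.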
17657 
17658   for (unsigned i = 0; i != NumElems; ++i) {
17659     SDValue Op = N->getOperand(i);
17660 
17661     if (Op.isUndef())
17662       continue;
17663 
17664     // See if we can use a blend with a zero vector.
17665     // TODO: Should we generalize this to a blend with an arbitrary constant
17666     // vector?
17667     if (isNullConstant(Op) || isNullFPConstant(Op)) {
17668       UsesZeroVector = true;
17669       VectorMask[i] = 0;
17670       continue;
17671     }
17672 
17673     // Not an undef or zero. If the input is something other than an
17674     // EXTRACT_VECTOR_ELT with an in-range constant index, bail out.
17675     if (Op.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
17676         !isa<ConstantSDNode>(Op.getOperand(1)))
17677       return SDValue();
17678     SDValue ExtractedFromVec = Op.getOperand(0);
17679 
17680     const APInt &ExtractIdx = Op.getConstantOperandAPInt(1);
17681     if (ExtractIdx.uge(ExtractedFromVec.getValueType().getVectorNumElements()))
17682       return SDValue();
17683 
17684     // All inputs must have the same element type as the output.
17685     if (VT.getVectorElementType() !=
17686         ExtractedFromVec.getValueType().getVectorElementType())
17687       return SDValue();
17688 
17689     // Have we seen this input vector before?
17690     // The vectors are expected to be tiny (usually 1 or 2 elements), so using
17691     // a map back from SDValues to numbers isn't worth it.
17692     unsigned Idx = std::distance(
17693         VecIn.begin(), std::find(VecIn.begin(), VecIn.end(), ExtractedFromVec));
17694     if (Idx == VecIn.size())
17695       VecIn.push_back(ExtractedFromVec);
17696 
17697     VectorMask[i] = Idx;
17698   }
17699 
17700   // If we didn't find at least one input vector, bail out.
17701   if (VecIn.size() < 2)
17702     return SDValue();
17703 
17704   // If all the operands of the BUILD_VECTOR extract from the same
17705   // vector, then split that vector efficiently based on the maximum
17706   // vector access index and adjust VectorMask and
17707   // VecIn accordingly.
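  // For example (illustrative): if a two-element BUILD_VECTOR extracts
  // elements 0 and 5 of a v16i32 input, then MaxIndex = 5, NearestPow2 = 8,
  // and SplitSize = 4, so the input is split into two v4i32 subvectors
  // (elements 0-3 and 4-7) and the mask entries are remapped to vectors 1
  // and 2 respectively (assuming v4i32 is legal for the target).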
17708   bool DidSplitVec = false;
17709   if (VecIn.size() == 2) {
17710     unsigned MaxIndex = 0;
17711     unsigned NearestPow2 = 0;
17712     SDValue Vec = VecIn.back();
17713     EVT InVT = Vec.getValueType();
17714     MVT IdxTy = TLI.getVectorIdxTy(DAG.getDataLayout());
17715     SmallVector<unsigned, 8> IndexVec(NumElems, 0);
17716 
17717     for (unsigned i = 0; i < NumElems; i++) {
17718       if (VectorMask[i] <= 0)
17719         continue;
17720       unsigned Index = N->getOperand(i).getConstantOperandVal(1);
17721       IndexVec[i] = Index;
17722       MaxIndex = std::max(MaxIndex, Index);
17723     }
17724 
17725     NearestPow2 = PowerOf2Ceil(MaxIndex);
17726     if (InVT.isSimple() && NearestPow2 > 2 && MaxIndex < NearestPow2 &&
17727         NumElems * 2 < NearestPow2) {
17728       unsigned SplitSize = NearestPow2 / 2;
17729       EVT SplitVT = EVT::getVectorVT(*DAG.getContext(),
17730                                      InVT.getVectorElementType(), SplitSize);
17731       if (TLI.isTypeLegal(SplitVT)) {
17732         SDValue VecIn2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SplitVT, Vec,
17733                                      DAG.getConstant(SplitSize, DL, IdxTy));
17734         SDValue VecIn1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SplitVT, Vec,
17735                                      DAG.getConstant(0, DL, IdxTy));
17736         VecIn.pop_back();
17737         VecIn.push_back(VecIn1);
17738         VecIn.push_back(VecIn2);
17739         DidSplitVec = true;
17740 
17741         for (unsigned i = 0; i < NumElems; i++) {
17742           if (VectorMask[i] <= 0)
17743             continue;
17744           VectorMask[i] = (IndexVec[i] < SplitSize) ? 1 : 2;
17745         }
17746       }
17747     }
17748   }
17749 
17750   // TODO: We want to sort the vectors by descending length, so that adjacent
17751   // pairs have similar length, and the longer vector is always first in the
17752   // pair.
17753 
17754   // TODO: Should this fire if some of the input vectors have illegal types
17755   // (like it does now), or should we let legalization run its course first?
17756 
17757   // Shuffle phase:
17758   // Take pairs of vectors, and shuffle them so that the result has elements
17759   // from these vectors in the correct places.
17760   // For example, given:
17761   // t10: i32 = extract_vector_elt t1, Constant:i64<0>
17762   // t11: i32 = extract_vector_elt t2, Constant:i64<0>
17763   // t12: i32 = extract_vector_elt t3, Constant:i64<0>
17764   // t13: i32 = extract_vector_elt t1, Constant:i64<1>
17765   // t14: v4i32 = BUILD_VECTOR t10, t11, t12, t13
17766   // We will generate:
17767   // t20: v4i32 = vector_shuffle<0,4,u,1> t1, t2
17768   // t21: v4i32 = vector_shuffle<u,u,0,u> t3, undef
17769   SmallVector<SDValue, 4> Shuffles;
17770   for (unsigned In = 0, Len = (VecIn.size() / 2); In < Len; ++In) {
17771     unsigned LeftIdx = 2 * In + 1;
17772     SDValue VecLeft = VecIn[LeftIdx];
17773     SDValue VecRight =
17774         (LeftIdx + 1) < VecIn.size() ? VecIn[LeftIdx + 1] : SDValue();
17775 
17776     if (SDValue Shuffle = createBuildVecShuffle(DL, N, VectorMask, VecLeft,
17777                                                 VecRight, LeftIdx, DidSplitVec))
17778       Shuffles.push_back(Shuffle);
17779     else
17780       return SDValue();
17781   }
17782 
17783   // If we need the zero vector as an "ingredient" in the blend tree, add it
17784   // to the list of shuffles.
17785   if (UsesZeroVector)
17786     Shuffles.push_back(VT.isInteger() ? DAG.getConstant(0, DL, VT)
17787                                       : DAG.getConstantFP(0.0, DL, VT));
17788 
17789   // If we only have one shuffle, we're done.
17790   if (Shuffles.size() == 1)
17791     return Shuffles[0];
17792 
17793   // Update the vector mask to point to the post-shuffle vectors.
17794   for (int &Vec : VectorMask)
17795     if (Vec == 0)
17796       Vec = Shuffles.size() - 1;
17797     else
17798       Vec = (Vec - 1) / 2;
17799 
17800   // More than one shuffle. Generate a binary tree of blends, e.g. if from
17801   // the previous step we got the set of shuffles t10, t11, t12, t13, we will
17802   // generate:
17803   // t10: v8i32 = vector_shuffle<0,8,u,u,u,u,u,u> t1, t2
17804   // t11: v8i32 = vector_shuffle<u,u,0,8,u,u,u,u> t3, t4
17805   // t12: v8i32 = vector_shuffle<u,u,u,u,0,8,u,u> t5, t6
17806   // t13: v8i32 = vector_shuffle<u,u,u,u,u,u,0,8> t7, t8
17807   // t20: v8i32 = vector_shuffle<0,1,10,11,u,u,u,u> t10, t11
17808   // t21: v8i32 = vector_shuffle<u,u,u,u,4,5,14,15> t12, t13
17809   // t30: v8i32 = vector_shuffle<0,1,2,3,12,13,14,15> t20, t21
17810 
17811   // Make sure the initial size of the shuffle list is even.
17812   if (Shuffles.size() % 2)
17813     Shuffles.push_back(DAG.getUNDEF(VT));
17814 
17815   for (unsigned CurSize = Shuffles.size(); CurSize > 1; CurSize /= 2) {
17816     if (CurSize % 2) {
17817       Shuffles[CurSize] = DAG.getUNDEF(VT);
17818       CurSize++;
17819     }
17820     for (unsigned In = 0, Len = CurSize / 2; In < Len; ++In) {
17821       int Left = 2 * In;
17822       int Right = 2 * In + 1;
17823       SmallVector<int, 8> Mask(NumElems, -1);
17824       for (unsigned i = 0; i != NumElems; ++i) {
17825         if (VectorMask[i] == Left) {
17826           Mask[i] = i;
17827           VectorMask[i] = In;
17828         } else if (VectorMask[i] == Right) {
17829           Mask[i] = i + NumElems;
17830           VectorMask[i] = In;
17831         }
17832       }
17833 
17834       Shuffles[In] =
17835           DAG.getVectorShuffle(VT, DL, Shuffles[Left], Shuffles[Right], Mask);
17836     }
17837   }
17838   return Shuffles[0];
17839 }
17840 
17841 // Try to turn a build vector of zero extends of extract vector elts into a
17842 // vector zero extend and possibly an extract subvector.
17843 // TODO: Support sign extend?
17844 // TODO: Allow undef elements?
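// For example (illustrative):
//   (v4i32 build_vector (zext (extractelt (v8i16 X), 4)), ...,
//                       (zext (extractelt (v8i16 X), 7)))
//   --> (v4i32 zero_extend (v4i16 extract_subvector X, 4))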
17845 SDValue DAGCombiner::convertBuildVecZextToZext(SDNode *N) {
17846   if (LegalOperations)
17847     return SDValue();
17848 
17849   EVT VT = N->getValueType(0);
17850 
17851   bool FoundZeroExtend = false;
17852   SDValue Op0 = N->getOperand(0);
17853   auto checkElem = [&](SDValue Op) -> int64_t {
17854     unsigned Opc = Op.getOpcode();
17855     FoundZeroExtend |= (Opc == ISD::ZERO_EXTEND);
17856     if ((Opc == ISD::ZERO_EXTEND || Opc == ISD::ANY_EXTEND) &&
17857         Op.getOperand(0).getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
17858         Op0.getOperand(0).getOperand(0) == Op.getOperand(0).getOperand(0))
17859       if (auto *C = dyn_cast<ConstantSDNode>(Op.getOperand(0).getOperand(1)))
17860         return C->getZExtValue();
17861     return -1;
17862   };
17863 
17864   // Make sure the first element matches
17865   // (zext (extract_vector_elt X, C))
17866   int64_t Offset = checkElem(Op0);
17867   if (Offset < 0)
17868     return SDValue();
17869 
17870   unsigned NumElems = N->getNumOperands();
17871   SDValue In = Op0.getOperand(0).getOperand(0);
17872   EVT InSVT = In.getValueType().getScalarType();
17873   EVT InVT = EVT::getVectorVT(*DAG.getContext(), InSVT, NumElems);
17874 
17875   // Don't create an illegal input type after type legalization.
17876   if (LegalTypes && !TLI.isTypeLegal(InVT))
17877     return SDValue();
17878 
17879   // Ensure all the elements come from the same vector and are adjacent.
17880   for (unsigned i = 1; i != NumElems; ++i) {
17881     if ((Offset + i) != checkElem(N->getOperand(i)))
17882       return SDValue();
17883   }
17884 
17885   SDLoc DL(N);
17886   In = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InVT, In,
17887                    Op0.getOperand(0).getOperand(1));
17888   return DAG.getNode(FoundZeroExtend ? ISD::ZERO_EXTEND : ISD::ANY_EXTEND, DL,
17889                      VT, In);
17890 }
17891 
17892 SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) {
17893   EVT VT = N->getValueType(0);
17894 
17895   // A vector built entirely of undefs is undef.
17896   if (ISD::allOperandsUndef(N))
17897     return DAG.getUNDEF(VT);
17898 
17899   // If this is a splat of a bitcast from another vector, change to a
17900   // concat_vector.
17901   // For example:
17902   //   (build_vector (i64 (bitcast (v2i32 X))), (i64 (bitcast (v2i32 X)))) ->
17903   //     (v2i64 (bitcast (concat_vectors (v2i32 X), (v2i32 X))))
17904   //
17905   // If X is a build_vector itself, the concat can become a larger build_vector.
17906   // TODO: Maybe this is useful for non-splat too?
17907   if (!LegalOperations) {
17908     if (SDValue Splat = cast<BuildVectorSDNode>(N)->getSplatValue()) {
17909       Splat = peekThroughBitcasts(Splat);
17910       EVT SrcVT = Splat.getValueType();
17911       if (SrcVT.isVector()) {
17912         unsigned NumElts = N->getNumOperands() * SrcVT.getVectorNumElements();
17913         EVT NewVT = EVT::getVectorVT(*DAG.getContext(),
17914                                      SrcVT.getVectorElementType(), NumElts);
17915         if (!LegalTypes || TLI.isTypeLegal(NewVT)) {
17916           SmallVector<SDValue, 8> Ops(N->getNumOperands(), Splat);
17917           SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N),
17918                                        NewVT, Ops);
17919           return DAG.getBitcast(VT, Concat);
17920         }
17921       }
17922     }
17923   }
17924 
17925   // A splat of a single element is a SPLAT_VECTOR if supported on the target.
17926   if (TLI.getOperationAction(ISD::SPLAT_VECTOR, VT) != TargetLowering::Expand)
17927     if (SDValue V = cast<BuildVectorSDNode>(N)->getSplatValue()) {
17928       assert(!V.isUndef() && "Splat of undef should have been handled earlier");
17929       return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), VT, V);
17930     }
17931 
17932   // Check if we can express the BUILD_VECTOR via a subvector extract.
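  // For example (illustrative):
  //   (v4i32 build_vector (extractelt (v8i32 X), 4), ...,
  //                       (extractelt (v8i32 X), 7))
  //   --> (v4i32 extract_subvector (v8i32 X), 4)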
17933   if (!LegalTypes && (N->getNumOperands() > 1)) {
17934     SDValue Op0 = N->getOperand(0);
17935     auto checkElem = [&](SDValue Op) -> uint64_t {
17936       if ((Op.getOpcode() == ISD::EXTRACT_VECTOR_ELT) &&
17937           (Op0.getOperand(0) == Op.getOperand(0)))
17938         if (auto CNode = dyn_cast<ConstantSDNode>(Op.getOperand(1)))
17939           return CNode->getZExtValue();
17940       return -1;
17941     };
17942 
17943     int Offset = checkElem(Op0);
17944     for (unsigned i = 0; i < N->getNumOperands(); ++i) {
17945       if (Offset + i != checkElem(N->getOperand(i))) {
17946         Offset = -1;
17947         break;
17948       }
17949     }
17950 
17951     if ((Offset == 0) &&
17952         (Op0.getOperand(0).getValueType() == N->getValueType(0)))
17953       return Op0.getOperand(0);
17954     if ((Offset != -1) &&
17955         ((Offset % N->getValueType(0).getVectorNumElements()) ==
17956          0)) // The index must be a multiple of the output's element count.
17957       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), N->getValueType(0),
17958                          Op0.getOperand(0), Op0.getOperand(1));
17959   }
17960 
17961   if (SDValue V = convertBuildVecZextToZext(N))
17962     return V;
17963 
17964   if (SDValue V = reduceBuildVecExtToExtBuildVec(N))
17965     return V;
17966 
17967   if (SDValue V = reduceBuildVecToShuffle(N))
17968     return V;
17969 
17970   return SDValue();
17971 }
17972 
17973 static SDValue combineConcatVectorOfScalars(SDNode *N, SelectionDAG &DAG) {
17974   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
17975   EVT OpVT = N->getOperand(0).getValueType();
17976 
17977   // If the operands are legal vectors, leave them alone.
17978   if (TLI.isTypeLegal(OpVT))
17979     return SDValue();
17980 
17981   SDLoc DL(N);
17982   EVT VT = N->getValueType(0);
17983   SmallVector<SDValue, 8> Ops;
17984 
17985   EVT SVT = EVT::getIntegerVT(*DAG.getContext(), OpVT.getSizeInBits());
17986   SDValue ScalarUndef = DAG.getNode(ISD::UNDEF, DL, SVT);
17987 
17988   // Keep track of what we encounter.
17989   bool AnyInteger = false;
17990   bool AnyFP = false;
17991   for (const SDValue &Op : N->ops()) {
17992     if (ISD::BITCAST == Op.getOpcode() &&
17993         !Op.getOperand(0).getValueType().isVector())
17994       Ops.push_back(Op.getOperand(0));
17995     else if (ISD::UNDEF == Op.getOpcode())
17996       Ops.push_back(ScalarUndef);
17997     else
17998       return SDValue();
17999 
18000     // Note whether we encounter an integer or floating point scalar.
18001     // If it's neither, bail out; it could be something weird like x86mmx.
18002     EVT LastOpVT = Ops.back().getValueType();
18003     if (LastOpVT.isFloatingPoint())
18004       AnyFP = true;
18005     else if (LastOpVT.isInteger())
18006       AnyInteger = true;
18007     else
18008       return SDValue();
18009   }
18010 
18011   // If any of the operands is a floating point scalar bitcast to a vector,
18012   // use floating point types throughout, and bitcast everything.
18013   // Replace UNDEFs by another scalar UNDEF node, of the final desired type.
18014   if (AnyFP) {
18015     SVT = EVT::getFloatingPointVT(OpVT.getSizeInBits());
18016     ScalarUndef = DAG.getNode(ISD::UNDEF, DL, SVT);
18017     if (AnyInteger) {
18018       for (SDValue &Op : Ops) {
18019         if (Op.getValueType() == SVT)
18020           continue;
18021         if (Op.isUndef())
18022           Op = ScalarUndef;
18023         else
18024           Op = DAG.getBitcast(SVT, Op);
18025       }
18026     }
18027   }
18028 
18029   EVT VecVT = EVT::getVectorVT(*DAG.getContext(), SVT,
18030                                VT.getSizeInBits() / SVT.getSizeInBits());
18031   return DAG.getBitcast(VT, DAG.getBuildVector(VecVT, DL, Ops));
18032 }
18033 
18034 // Check to see if this is a CONCAT_VECTORS of a bunch of EXTRACT_SUBVECTOR
18035 // operations. If so, and if the EXTRACT_SUBVECTOR vector inputs come from at
18036 // most two distinct vectors of the same size as the result, attempt to turn this
18037 // into a legal shuffle.
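// For example (illustrative), with X and Y both v8i32:
//   (v8i32 concat_vectors (v4i32 extract_subvector X, 0),
//                         (v4i32 extract_subvector Y, 4))
//   --> (v8i32 vector_shuffle<0,1,2,3,12,13,14,15> X, Y)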
18038 static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) {
18039   EVT VT = N->getValueType(0);
18040   EVT OpVT = N->getOperand(0).getValueType();
18041   int NumElts = VT.getVectorNumElements();
18042   int NumOpElts = OpVT.getVectorNumElements();
18043 
18044   SDValue SV0 = DAG.getUNDEF(VT), SV1 = DAG.getUNDEF(VT);
18045   SmallVector<int, 8> Mask;
18046 
18047   for (SDValue Op : N->ops()) {
18048     Op = peekThroughBitcasts(Op);
18049 
18050     // UNDEF nodes convert to UNDEF shuffle mask values.
18051     if (Op.isUndef()) {
18052       Mask.append((unsigned)NumOpElts, -1);
18053       continue;
18054     }
18055 
18056     if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
18057       return SDValue();
18058 
18059     // What vector are we extracting the subvector from and at what index?
18060     SDValue ExtVec = Op.getOperand(0);
18061 
18062     // We want the EVT of the original extraction to correctly scale the
18063     // extraction index.
18064     EVT ExtVT = ExtVec.getValueType();
18065     ExtVec = peekThroughBitcasts(ExtVec);
18066 
18067     // UNDEF nodes convert to UNDEF shuffle mask values.
18068     if (ExtVec.isUndef()) {
18069       Mask.append((unsigned)NumOpElts, -1);
18070       continue;
18071     }
18072 
18073     if (!isa<ConstantSDNode>(Op.getOperand(1)))
18074       return SDValue();
18075     int ExtIdx = Op.getConstantOperandVal(1);
18076 
18077     // Ensure that we are extracting a subvector from a vector the same
18078     // size as the result.
18079     if (ExtVT.getSizeInBits() != VT.getSizeInBits())
18080       return SDValue();
18081 
18082     // Scale the subvector index to account for any bitcast.
18083     int NumExtElts = ExtVT.getVectorNumElements();
18084     if (0 == (NumExtElts % NumElts))
18085       ExtIdx /= (NumExtElts / NumElts);
18086     else if (0 == (NumElts % NumExtElts))
18087       ExtIdx *= (NumElts / NumExtElts);
18088     else
18089       return SDValue();
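    // For example (illustrative): if ExtVT is v4i64 but VT is v8i32, an
    // extract index of 1 in v4i64 units becomes index 2 in v8i32 units.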
18090 
18091     // At most we can reference 2 inputs in the final shuffle.
18092     if (SV0.isUndef() || SV0 == ExtVec) {
18093       SV0 = ExtVec;
18094       for (int i = 0; i != NumOpElts; ++i)
18095         Mask.push_back(i + ExtIdx);
18096     } else if (SV1.isUndef() || SV1 == ExtVec) {
18097       SV1 = ExtVec;
18098       for (int i = 0; i != NumOpElts; ++i)
18099         Mask.push_back(i + ExtIdx + NumElts);
18100     } else {
18101       return SDValue();
18102     }
18103   }
18104 
18105   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
18106   return TLI.buildLegalVectorShuffle(VT, SDLoc(N), DAG.getBitcast(VT, SV0),
18107                                      DAG.getBitcast(VT, SV1), Mask, DAG);
18108 }
18109 
18110 SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
18111   // If we only have one input vector, we don't need to do any concatenation.
18112   if (N->getNumOperands() == 1)
18113     return N->getOperand(0);
18114 
18115   // Check if all of the operands are undefs.
18116   EVT VT = N->getValueType(0);
18117   if (ISD::allOperandsUndef(N))
18118     return DAG.getUNDEF(VT);
18119 
18120   // Optimize concat_vectors where all but the first of the vectors are undef.
18121   if (std::all_of(std::next(N->op_begin()), N->op_end(), [](const SDValue &Op) {
18122         return Op.isUndef();
18123       })) {
18124     SDValue In = N->getOperand(0);
18125     assert(In.getValueType().isVector() && "Must concat vectors");
18126 
18127     // If the input is a concat_vectors, just make a larger concat by padding
18128     // with smaller undefs.
18129     if (In.getOpcode() == ISD::CONCAT_VECTORS && In.hasOneUse()) {
18130       unsigned NumOps = N->getNumOperands() * In.getNumOperands();
18131       SmallVector<SDValue, 4> Ops(In->op_begin(), In->op_end());
18132       Ops.resize(NumOps, DAG.getUNDEF(Ops[0].getValueType()));
18133       return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
18134     }
18135 
18136     SDValue Scalar = peekThroughOneUseBitcasts(In);
18137 
18138     // concat_vectors(scalar_to_vector(scalar), undef) ->
18139     //     scalar_to_vector(scalar)
18140     if (!LegalOperations && Scalar.getOpcode() == ISD::SCALAR_TO_VECTOR &&
18141          Scalar.hasOneUse()) {
18142       EVT SVT = Scalar.getValueType().getVectorElementType();
18143       if (SVT == Scalar.getOperand(0).getValueType())
18144         Scalar = Scalar.getOperand(0);
18145     }
18146 
18147     // concat_vectors(scalar, undef) -> scalar_to_vector(scalar)
18148     if (!Scalar.getValueType().isVector()) {
18149       // If the bitcast type isn't legal, it might be a trunc of a legal type;
18150       // look through the trunc so we can still do the transform:
18151       //   concat_vectors(trunc(scalar), undef) -> scalar_to_vector(scalar)
18152       if (Scalar->getOpcode() == ISD::TRUNCATE &&
18153           !TLI.isTypeLegal(Scalar.getValueType()) &&
18154           TLI.isTypeLegal(Scalar->getOperand(0).getValueType()))
18155         Scalar = Scalar->getOperand(0);
18156 
18157       EVT SclTy = Scalar.getValueType();
18158 
18159       if (!SclTy.isFloatingPoint() && !SclTy.isInteger())
18160         return SDValue();
18161 
18162       // Bail out if the vector size is not a multiple of the scalar size.
18163       if (VT.getSizeInBits() % SclTy.getSizeInBits())
18164         return SDValue();
18165 
18166       unsigned VNTNumElms = VT.getSizeInBits() / SclTy.getSizeInBits();
18167       if (VNTNumElms < 2)
18168         return SDValue();
18169 
18170       EVT NVT = EVT::getVectorVT(*DAG.getContext(), SclTy, VNTNumElms);
18171       if (!TLI.isTypeLegal(NVT) || !TLI.isTypeLegal(Scalar.getValueType()))
18172         return SDValue();
18173 
18174       SDValue Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), NVT, Scalar);
18175       return DAG.getBitcast(VT, Res);
18176     }
18177   }
18178 
18179   // Fold any combination of BUILD_VECTOR or UNDEF nodes into one BUILD_VECTOR.
18180   // We have already tested above for an UNDEF only concatenation.
18181   // fold (concat_vectors (BUILD_VECTOR A, B, ...), (BUILD_VECTOR C, D, ...))
18182   // -> (BUILD_VECTOR A, B, ..., C, D, ...)
18183   auto IsBuildVectorOrUndef = [](const SDValue &Op) {
18184     return ISD::UNDEF == Op.getOpcode() || ISD::BUILD_VECTOR == Op.getOpcode();
18185   };
18186   if (llvm::all_of(N->ops(), IsBuildVectorOrUndef)) {
18187     SmallVector<SDValue, 8> Opnds;
18188     EVT SVT = VT.getScalarType();
18189 
18190     EVT MinVT = SVT;
18191     if (!SVT.isFloatingPoint()) {
18192       // If the BUILD_VECTORs are built from integers, they may have different
18193       // operand types. Get the smallest type and truncate all operands to it.
18194       bool FoundMinVT = false;
18195       for (const SDValue &Op : N->ops())
18196         if (ISD::BUILD_VECTOR == Op.getOpcode()) {
18197           EVT OpSVT = Op.getOperand(0).getValueType();
18198           MinVT = (!FoundMinVT || OpSVT.bitsLE(MinVT)) ? OpSVT : MinVT;
18199           FoundMinVT = true;
18200         }
18201       assert(FoundMinVT && "Concat vector type mismatch");
18202     }
18203 
18204     for (const SDValue &Op : N->ops()) {
18205       EVT OpVT = Op.getValueType();
18206       unsigned NumElts = OpVT.getVectorNumElements();
18207 
18208       if (ISD::UNDEF == Op.getOpcode())
18209         Opnds.append(NumElts, DAG.getUNDEF(MinVT));
18210 
18211       if (ISD::BUILD_VECTOR == Op.getOpcode()) {
18212         if (SVT.isFloatingPoint()) {
18213           assert(SVT == OpVT.getScalarType() && "Concat vector type mismatch");
18214           Opnds.append(Op->op_begin(), Op->op_begin() + NumElts);
18215         } else {
18216           for (unsigned i = 0; i != NumElts; ++i)
18217             Opnds.push_back(
18218                 DAG.getNode(ISD::TRUNCATE, SDLoc(N), MinVT, Op.getOperand(i)));
18219         }
18220       }
18221     }
18222 
18223     assert(VT.getVectorNumElements() == Opnds.size() &&
18224            "Concat vector type mismatch");
18225     return DAG.getBuildVector(VT, SDLoc(N), Opnds);
18226   }
18227 
18228   // Fold CONCAT_VECTORS of only bitcast scalars (or undef) to BUILD_VECTOR.
18229   if (SDValue V = combineConcatVectorOfScalars(N, DAG))
18230     return V;
18231 
18232   // Fold CONCAT_VECTORS of EXTRACT_SUBVECTOR (or undef) to VECTOR_SHUFFLE.
18233   if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT))
18234     if (SDValue V = combineConcatVectorOfExtracts(N, DAG))
18235       return V;
18236 
18237   // Type legalization of vectors and DAG canonicalization of SHUFFLE_VECTOR
18238   // nodes often generate nop CONCAT_VECTOR nodes.
18239   // Scan the CONCAT_VECTOR operands and look for CONCAT operations that
18240   // place the incoming vectors at the exact same location.
18241   SDValue SingleSource = SDValue();
18242   unsigned PartNumElem = N->getOperand(0).getValueType().getVectorNumElements();
18243 
18244   for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
18245     SDValue Op = N->getOperand(i);
18246 
18247     if (Op.isUndef())
18248       continue;
18249 
18250     // Check if this is the identity extract:
18251     if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
18252       return SDValue();
18253 
18254     // Find the single incoming vector for the extract_subvector.
18255     if (SingleSource.getNode()) {
18256       if (Op.getOperand(0) != SingleSource)
18257         return SDValue();
18258     } else {
18259       SingleSource = Op.getOperand(0);
18260 
18261       // Check that the source type is the same as the type of the result.
18262       // If not, this concat may extend the vector, so we cannot
18263       // optimize it away.
18264       if (SingleSource.getValueType() != N->getValueType(0))
18265         return SDValue();
18266     }
18267 
18268     auto *CS = dyn_cast<ConstantSDNode>(Op.getOperand(1));
18269     // The extract index must be constant.
18270     if (!CS)
18271       return SDValue();
18272 
18273     // Check that we are reading from the identity index.
18274     unsigned IdentityIndex = i * PartNumElem;
18275     if (CS->getAPIntValue() != IdentityIndex)
18276       return SDValue();
18277   }
18278 
18279   if (SingleSource.getNode())
18280     return SingleSource;
18281 
18282   return SDValue();
18283 }
18284 
18285 // Helper that peeks through INSERT_SUBVECTOR/CONCAT_VECTORS to find
18286 // if the subvector can be sourced for free.
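// For example (illustrative), with SubVT = v4i32 and Index = 4:
//   (v8i32 insert_subvector ?, (v4i32 X), 4)    --> X
//   (v8i32 concat_vectors (v4i32 A), (v4i32 B)) --> B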
18287 static SDValue getSubVectorSrc(SDValue V, SDValue Index, EVT SubVT) {
18288   if (V.getOpcode() == ISD::INSERT_SUBVECTOR &&
18289       V.getOperand(1).getValueType() == SubVT && V.getOperand(2) == Index) {
18290     return V.getOperand(1);
18291   }
18292   auto *IndexC = dyn_cast<ConstantSDNode>(Index);
18293   if (IndexC && V.getOpcode() == ISD::CONCAT_VECTORS &&
18294       V.getOperand(0).getValueType() == SubVT &&
18295       (IndexC->getZExtValue() % SubVT.getVectorNumElements()) == 0) {
18296     uint64_t SubIdx = IndexC->getZExtValue() / SubVT.getVectorNumElements();
18297     return V.getOperand(SubIdx);
18298   }
18299   return SDValue();
18300 }
18301 
18302 static SDValue narrowInsertExtractVectorBinOp(SDNode *Extract,
18303                                               SelectionDAG &DAG) {
18304   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
18305   SDValue BinOp = Extract->getOperand(0);
18306   unsigned BinOpcode = BinOp.getOpcode();
18307   if (!TLI.isBinOp(BinOpcode) || BinOp.getNode()->getNumValues() != 1)
18308     return SDValue();
18309 
18310   EVT VecVT = BinOp.getValueType();
18311   SDValue Bop0 = BinOp.getOperand(0), Bop1 = BinOp.getOperand(1);
18312   if (VecVT != Bop0.getValueType() || VecVT != Bop1.getValueType())
18313     return SDValue();
18314 
18315   SDValue Index = Extract->getOperand(1);
18316   EVT SubVT = Extract->getValueType(0);
18317   if (!TLI.isOperationLegalOrCustom(BinOpcode, SubVT))
18318     return SDValue();
18319 
18320   SDValue Sub0 = getSubVectorSrc(Bop0, Index, SubVT);
18321   SDValue Sub1 = getSubVectorSrc(Bop1, Index, SubVT);
18322 
18323   // TODO: We could handle the case where only 1 operand is being inserted by
18324   //       creating an extract of the other operand, but that requires checking
18325   //       number of uses and/or costs.
18326   if (!Sub0 || !Sub1)
18327     return SDValue();
18328 
18329   // We are inserting both operands of the wide binop only to extract back
18330   // to the narrow vector size. Eliminate all of the insert/extract:
18331   // ext (binop (ins ?, X, Index), (ins ?, Y, Index)), Index --> binop X, Y
18332   return DAG.getNode(BinOpcode, SDLoc(Extract), SubVT, Sub0, Sub1,
18333                      BinOp->getFlags());
18334 }
18335 
18336 /// If we are extracting a subvector produced by a wide binary operator try
18337 /// to use a narrow binary operator and/or avoid concatenation and extraction.
18338 static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG) {
18339   // TODO: Refactor with the caller (visitEXTRACT_SUBVECTOR), so we can share
18340   // some of these bailouts with other transforms.
18341 
18342   if (SDValue V = narrowInsertExtractVectorBinOp(Extract, DAG))
18343     return V;
18344 
18345   // The extract index must be a constant, so we can map it to a concat operand.
18346   auto *ExtractIndexC = dyn_cast<ConstantSDNode>(Extract->getOperand(1));
18347   if (!ExtractIndexC)
18348     return SDValue();
18349 
18350   // We are looking for an optionally bitcasted wide vector binary operator
18351   // feeding an extract subvector.
18352   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
18353   SDValue BinOp = peekThroughBitcasts(Extract->getOperand(0));
18354   unsigned BOpcode = BinOp.getOpcode();
18355   if (!TLI.isBinOp(BOpcode) || BinOp.getNode()->getNumValues() != 1)
18356     return SDValue();
18357 
18358   // The binop must produce a vector type, so we can extract some fraction of it.
18359   EVT WideBVT = BinOp.getValueType();
18360   if (!WideBVT.isVector())
18361     return SDValue();
18362 
18363   EVT VT = Extract->getValueType(0);
18364   unsigned ExtractIndex = ExtractIndexC->getZExtValue();
18365   assert(ExtractIndex % VT.getVectorNumElements() == 0 &&
18366          "Extract index is not a multiple of the vector length.");
18367 
18368   // Bail out if this is not a proper multiple width extraction.
18369   unsigned WideWidth = WideBVT.getSizeInBits();
18370   unsigned NarrowWidth = VT.getSizeInBits();
18371   if (WideWidth % NarrowWidth != 0)
18372     return SDValue();
18373 
18374   // Bail out if we are extracting a fraction of a single operation. This can
18375   // occur because we potentially looked through a bitcast of the binop.
18376   unsigned NarrowingRatio = WideWidth / NarrowWidth;
18377   unsigned WideNumElts = WideBVT.getVectorNumElements();
18378   if (WideNumElts % NarrowingRatio != 0)
18379     return SDValue();
18380 
18381   // Bail out if the target does not support a narrower version of the binop.
18382   EVT NarrowBVT = EVT::getVectorVT(*DAG.getContext(), WideBVT.getScalarType(),
18383                                    WideNumElts / NarrowingRatio);
18384   if (!TLI.isOperationLegalOrCustomOrPromote(BOpcode, NarrowBVT))
18385     return SDValue();
18386 
18387   // If extraction is cheap, we don't need to look at the binop operands
18388   // for concat ops. The narrow binop alone makes this transform profitable.
18389   // We can't just reuse the original extract index operand because we may have
18390   // bitcasted.
18391   unsigned ConcatOpNum = ExtractIndex / VT.getVectorNumElements();
18392   unsigned ExtBOIdx = ConcatOpNum * NarrowBVT.getVectorNumElements();
18393   EVT ExtBOIdxVT = Extract->getOperand(1).getValueType();
18394   if (TLI.isExtractSubvectorCheap(NarrowBVT, WideBVT, ExtBOIdx) &&
18395       BinOp.hasOneUse() && Extract->getOperand(0)->hasOneUse()) {
18396     // extract (binop B0, B1), N --> binop (extract B0, N), (extract B1, N)
18397     SDLoc DL(Extract);
18398     SDValue NewExtIndex = DAG.getConstant(ExtBOIdx, DL, ExtBOIdxVT);
18399     SDValue X = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
18400                             BinOp.getOperand(0), NewExtIndex);
18401     SDValue Y = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
18402                             BinOp.getOperand(1), NewExtIndex);
18403     SDValue NarrowBinOp = DAG.getNode(BOpcode, DL, NarrowBVT, X, Y,
18404                                       BinOp.getNode()->getFlags());
18405     return DAG.getBitcast(VT, NarrowBinOp);
18406   }
18407 
18408   // Only handle the case where we are doubling and then halving. A larger ratio
18409   // may require more than two narrow binops to replace the wide binop.
18410   if (NarrowingRatio != 2)
18411     return SDValue();
18412 
18413   // TODO: The motivating case for this transform is an x86 AVX1 target. That
18414   // target has temptingly almost legal versions of bitwise logic ops in 256-bit
18415   // flavors, but no other 256-bit integer support. This could be extended to
18416   // handle any binop, but that may require fixing/adding other folds to avoid
18417   // codegen regressions.
18418   if (BOpcode != ISD::AND && BOpcode != ISD::OR && BOpcode != ISD::XOR)
18419     return SDValue();
18420 
18421   // We need at least one concatenation operation of a binop operand to make
18422   // this transform worthwhile. The concat must double the input vector sizes.
18423   auto GetSubVector = [ConcatOpNum](SDValue V) -> SDValue {
18424     if (V.getOpcode() == ISD::CONCAT_VECTORS && V.getNumOperands() == 2)
18425       return V.getOperand(ConcatOpNum);
18426     return SDValue();
18427   };
18428   SDValue SubVecL = GetSubVector(peekThroughBitcasts(BinOp.getOperand(0)));
18429   SDValue SubVecR = GetSubVector(peekThroughBitcasts(BinOp.getOperand(1)));
18430 
18431   if (SubVecL || SubVecR) {
18432     // If a binop operand was not the result of a concat, we must extract a
18433     // half-sized operand for our new narrow binop:
18434     // extract (binop (concat X1, X2), (concat Y1, Y2)), N --> binop XN, YN
18435     // extract (binop (concat X1, X2), Y), N --> binop XN, (extract Y, IndexC)
18436     // extract (binop X, (concat Y1, Y2)), N --> binop (extract X, IndexC), YN
18437     SDLoc DL(Extract);
18438     SDValue IndexC = DAG.getConstant(ExtBOIdx, DL, ExtBOIdxVT);
18439     SDValue X = SubVecL ? DAG.getBitcast(NarrowBVT, SubVecL)
18440                         : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
18441                                       BinOp.getOperand(0), IndexC);
18442 
18443     SDValue Y = SubVecR ? DAG.getBitcast(NarrowBVT, SubVecR)
18444                         : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
18445                                       BinOp.getOperand(1), IndexC);
18446 
18447     SDValue NarrowBinOp = DAG.getNode(BOpcode, DL, NarrowBVT, X, Y);
18448     return DAG.getBitcast(VT, NarrowBinOp);
18449   }
18450 
18451   return SDValue();
18452 }
18453 
18454 /// If we are extracting a subvector from a wide vector load, convert to a
18455 /// narrow load to eliminate the extraction:
18456 /// (extract_subvector (load wide vector)) --> (load narrow vector)
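/// For example (illustrative, little-endian):
///   (v2f64 extract_subvector (v8f64 load <ptr>), 4)
///   --> (v2f64 load <ptr + 32 bytes>)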
18457 static SDValue narrowExtractedVectorLoad(SDNode *Extract, SelectionDAG &DAG) {
18458   // TODO: Add support for big-endian. The offset calculation must be adjusted.
18459   if (DAG.getDataLayout().isBigEndian())
18460     return SDValue();
18461 
18462   auto *Ld = dyn_cast<LoadSDNode>(Extract->getOperand(0));
18463   auto *ExtIdx = dyn_cast<ConstantSDNode>(Extract->getOperand(1));
18464   if (!Ld || Ld->getExtensionType() || !Ld->isSimple() ||
18465       !ExtIdx)
18466     return SDValue();
18467 
18468   // Allow targets to opt-out.
18469   EVT VT = Extract->getValueType(0);
18470   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
18471   if (!TLI.shouldReduceLoadWidth(Ld, Ld->getExtensionType(), VT))
18472     return SDValue();
18473 
18474   // The narrow load will be offset from the base address of the old load if
18475   // we are extracting from something besides index 0 (little-endian).
18476   SDLoc DL(Extract);
18477   SDValue BaseAddr = Ld->getOperand(1);
18478   unsigned Offset = ExtIdx->getZExtValue() * VT.getScalarType().getStoreSize();
18479 
18480   // TODO: Use "BaseIndexOffset" to make this more effective.
18481   SDValue NewAddr = DAG.getMemBasePlusOffset(BaseAddr, Offset, DL);
18482   MachineFunction &MF = DAG.getMachineFunction();
18483   MachineMemOperand *MMO = MF.getMachineMemOperand(Ld->getMemOperand(), Offset,
18484                                                    VT.getStoreSize());
18485   SDValue NewLd = DAG.getLoad(VT, DL, Ld->getChain(), NewAddr, MMO);
18486   DAG.makeEquivalentMemoryOrdering(Ld, NewLd);
18487   return NewLd;
18488 }
18489 
18490 SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) {
18491   EVT NVT = N->getValueType(0);
18492   SDValue V = N->getOperand(0);
18493 
18494   // Extract from UNDEF is UNDEF.
18495   if (V.isUndef())
18496     return DAG.getUNDEF(NVT);
18497 
18498   if (TLI.isOperationLegalOrCustomOrPromote(ISD::LOAD, NVT))
18499     if (SDValue NarrowLoad = narrowExtractedVectorLoad(N, DAG))
18500       return NarrowLoad;
18501 
18502   // Combine an extract of an extract into a single extract_subvector.
18503   // ext (ext X, C), 0 --> ext X, C
18504   SDValue Index = N->getOperand(1);
18505   if (isNullConstant(Index) && V.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
18506       V.hasOneUse() && isa<ConstantSDNode>(V.getOperand(1))) {
18507     if (TLI.isExtractSubvectorCheap(NVT, V.getOperand(0).getValueType(),
18508                                     V.getConstantOperandVal(1)) &&
18509         TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NVT)) {
18510       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), NVT, V.getOperand(0),
18511                          V.getOperand(1));
18512     }
18513   }
18514 
18515   // Try to move vector bitcast after extract_subv by scaling extraction index:
18516   // extract_subv (bitcast X), Index --> bitcast (extract_subv X, Index')
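  // For example (illustrative):
  //   (v2i64 extract_subvector (v4i64 bitcast (v8i32 X)), 2)
  //   --> (v2i64 bitcast (v4i32 extract_subvector X, 4))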
18517   if (isa<ConstantSDNode>(Index) && V.getOpcode() == ISD::BITCAST &&
18518       V.getOperand(0).getValueType().isVector()) {
18519     SDValue SrcOp = V.getOperand(0);
18520     EVT SrcVT = SrcOp.getValueType();
18521     unsigned SrcNumElts = SrcVT.getVectorNumElements();
18522     unsigned DestNumElts = V.getValueType().getVectorNumElements();
18523     if ((SrcNumElts % DestNumElts) == 0) {
18524       unsigned SrcDestRatio = SrcNumElts / DestNumElts;
18525       unsigned NewExtNumElts = NVT.getVectorNumElements() * SrcDestRatio;
18526       EVT NewExtVT = EVT::getVectorVT(*DAG.getContext(), SrcVT.getScalarType(),
18527                                       NewExtNumElts);
18528       if (TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NewExtVT)) {
18529         unsigned IndexValScaled = N->getConstantOperandVal(1) * SrcDestRatio;
18530         SDLoc DL(N);
18531         SDValue NewIndex = DAG.getIntPtrConstant(IndexValScaled, DL);
18532         SDValue NewExtract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewExtVT,
18533                                          V.getOperand(0), NewIndex);
18534         return DAG.getBitcast(NVT, NewExtract);
18535       }
18536     }
18537     if ((DestNumElts % SrcNumElts) == 0) {
18538       unsigned DestSrcRatio = DestNumElts / SrcNumElts;
18539       if ((NVT.getVectorNumElements() % DestSrcRatio) == 0) {
18540         unsigned NewExtNumElts = NVT.getVectorNumElements() / DestSrcRatio;
18541         EVT NewExtVT = EVT::getVectorVT(*DAG.getContext(),
18542                                         SrcVT.getScalarType(), NewExtNumElts);
18543         if ((N->getConstantOperandVal(1) % DestSrcRatio) == 0 &&
18544             TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NewExtVT)) {
18545           unsigned IndexValScaled = N->getConstantOperandVal(1) / DestSrcRatio;
18546           SDLoc DL(N);
18547           SDValue NewIndex = DAG.getIntPtrConstant(IndexValScaled, DL);
18548           SDValue NewExtract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewExtVT,
18549                                            V.getOperand(0), NewIndex);
18550           return DAG.getBitcast(NVT, NewExtract);
18551         }
18552       }
18553     }
18554   }
18555 
18556   if (V.getOpcode() == ISD::CONCAT_VECTORS && isa<ConstantSDNode>(Index)) {
18557     EVT ConcatSrcVT = V.getOperand(0).getValueType();
18558     assert(ConcatSrcVT.getVectorElementType() == NVT.getVectorElementType() &&
18559            "Concat and extract subvector do not change element type");
18560 
18561     unsigned ExtIdx = N->getConstantOperandVal(1);
18562     unsigned ExtNumElts = NVT.getVectorNumElements();
18563     assert(ExtIdx % ExtNumElts == 0 &&
18564            "Extract index is not a multiple of the input vector length.");
18565 
18566     unsigned ConcatSrcNumElts = ConcatSrcVT.getVectorNumElements();
18567     unsigned ConcatOpIdx = ExtIdx / ConcatSrcNumElts;
18568 
18569     // If the concatenated source types match this extract, it's a direct
18570     // simplification:
18571     // extract_subvec (concat V1, V2, ...), i --> Vi
18572     if (ConcatSrcNumElts == ExtNumElts)
18573       return V.getOperand(ConcatOpIdx);
18574 
18575     // If the concatenated source vectors are a multiple length of this extract,
18576     // then extract a fraction of one of those source vectors directly from a
18577     // concat operand. Example:
18578     //   v2i8 extract_subvec (v16i8 concat (v8i8 X), (v8i8 Y)), 14 -->
18579     //   v2i8 extract_subvec v8i8 Y, 6
18580     if (ConcatSrcNumElts % ExtNumElts == 0) {
18581       SDLoc DL(N);
18582       unsigned NewExtIdx = ExtIdx - ConcatOpIdx * ConcatSrcNumElts;
18583       assert(NewExtIdx + ExtNumElts <= ConcatSrcNumElts &&
18584              "Trying to extract from >1 concat operand?");
18585       assert(NewExtIdx % ExtNumElts == 0 &&
18586              "Extract index is not a multiple of the input vector length.");
18587       MVT IdxTy = TLI.getVectorIdxTy(DAG.getDataLayout());
18588       SDValue NewIndexC = DAG.getConstant(NewExtIdx, DL, IdxTy);
18589       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NVT,
18590                          V.getOperand(ConcatOpIdx), NewIndexC);
18591     }
18592   }
18593 
18594   V = peekThroughBitcasts(V);
18595 
18596   // If the input is a build vector, try to make a smaller build vector.
18597   if (V.getOpcode() == ISD::BUILD_VECTOR) {
18598     if (auto *IdxC = dyn_cast<ConstantSDNode>(Index)) {
18599       EVT InVT = V.getValueType();
18600       unsigned ExtractSize = NVT.getSizeInBits();
18601       unsigned EltSize = InVT.getScalarSizeInBits();
18602       // Only do this if we won't split any elements.
18603       if (ExtractSize % EltSize == 0) {
18604         unsigned NumElems = ExtractSize / EltSize;
18605         EVT EltVT = InVT.getVectorElementType();
18606         EVT ExtractVT = NumElems == 1 ? EltVT
18607                                       : EVT::getVectorVT(*DAG.getContext(),
18608                                                          EltVT, NumElems);
18609         if ((Level < AfterLegalizeDAG ||
18610              (NumElems == 1 ||
18611               TLI.isOperationLegal(ISD::BUILD_VECTOR, ExtractVT))) &&
18612             (!LegalTypes || TLI.isTypeLegal(ExtractVT))) {
18613           unsigned IdxVal = IdxC->getZExtValue();
18614           IdxVal *= NVT.getScalarSizeInBits();
18615           IdxVal /= EltSize;
18616 
18617           if (NumElems == 1) {
18618             SDValue Src = V->getOperand(IdxVal);
18619             if (EltVT != Src.getValueType())
18620               Src = DAG.getNode(ISD::TRUNCATE, SDLoc(N), EltVT, Src);
18621             return DAG.getBitcast(NVT, Src);
18622           }
18623 
18624           // Extract the pieces from the original build_vector.
18625           SDValue BuildVec = DAG.getBuildVector(
18626               ExtractVT, SDLoc(N), V->ops().slice(IdxVal, NumElems));
18627           return DAG.getBitcast(NVT, BuildVec);
18628         }
18629       }
18630     }
18631   }
18632 
18633   if (V.getOpcode() == ISD::INSERT_SUBVECTOR) {
18634     // Handle only the simple case where the vector being inserted and the
18635     // vector being extracted are the same size.
18636     EVT SmallVT = V.getOperand(1).getValueType();
18637     if (!NVT.bitsEq(SmallVT))
18638       return SDValue();
18639 
18640     // Only handle cases where both indexes are constants.
18641     auto *ExtIdx = dyn_cast<ConstantSDNode>(Index);
18642     auto *InsIdx = dyn_cast<ConstantSDNode>(V.getOperand(2));
18643     if (InsIdx && ExtIdx) {
18644       // Combine:
18645       //    (extract_subvec (insert_subvec V1, V2, InsIdx), ExtIdx)
18646       // Into:
18647       //    indices are equal or bit offsets are equal => V1
18648       //    otherwise => (extract_subvec V1, ExtIdx)
18649       if (InsIdx->getZExtValue() * SmallVT.getScalarSizeInBits() ==
18650           ExtIdx->getZExtValue() * NVT.getScalarSizeInBits())
18651         return DAG.getBitcast(NVT, V.getOperand(1));
18652       return DAG.getNode(
18653           ISD::EXTRACT_SUBVECTOR, SDLoc(N), NVT,
18654           DAG.getBitcast(N->getOperand(0).getValueType(), V.getOperand(0)),
18655           Index);
18656     }
18657   }
18658 
18659   if (SDValue NarrowBOp = narrowExtractedVectorBinOp(N, DAG))
18660     return NarrowBOp;
18661 
18662   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
18663     return SDValue(N, 0);
18664 
18665   return SDValue();
18666 }
18667 
18668 /// Try to convert a wide shuffle of concatenated vectors into 2 narrow shuffles
18669 /// followed by concatenation. Narrow vector ops may have better performance
18670 /// than wide ops, and this can unlock further narrowing of other vector ops.
18671 /// Targets can invert this transform later if it is not profitable.
18672 static SDValue foldShuffleOfConcatUndefs(ShuffleVectorSDNode *Shuf,
18673                                          SelectionDAG &DAG) {
18674   SDValue N0 = Shuf->getOperand(0), N1 = Shuf->getOperand(1);
18675   if (N0.getOpcode() != ISD::CONCAT_VECTORS || N0.getNumOperands() != 2 ||
18676       N1.getOpcode() != ISD::CONCAT_VECTORS || N1.getNumOperands() != 2 ||
18677       !N0.getOperand(1).isUndef() || !N1.getOperand(1).isUndef())
18678     return SDValue();
18679 
18680   // Split the wide shuffle mask into halves. Any mask element that is accessing
18681   // operand 1 is offset down to account for narrowing of the vectors.
18682   ArrayRef<int> Mask = Shuf->getMask();
18683   EVT VT = Shuf->getValueType(0);
18684   unsigned NumElts = VT.getVectorNumElements();
18685   unsigned HalfNumElts = NumElts / 2;
18686   SmallVector<int, 16> Mask0(HalfNumElts, -1);
18687   SmallVector<int, 16> Mask1(HalfNumElts, -1);
18688   for (unsigned i = 0; i != NumElts; ++i) {
18689     if (Mask[i] == -1)
18690       continue;
18691     int M = Mask[i] < (int)NumElts ? Mask[i] : Mask[i] - (int)HalfNumElts;
18692     if (i < HalfNumElts)
18693       Mask0[i] = M;
18694     else
18695       Mask1[i - HalfNumElts] = M;
18696   }
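  // For example (illustrative): for a v8i32 shuffle with mask
  // <0,8,1,9,2,10,3,11>, HalfNumElts = 4, so Mask0 = <0,4,1,5> and
  // Mask1 = <2,6,3,7> over the narrowed v4i32 inputs.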
18697 
18698   // Ask the target if this is a valid transform.
18699   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
18700   EVT HalfVT = EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(),
18701                                 HalfNumElts);
18702   if (!TLI.isShuffleMaskLegal(Mask0, HalfVT) ||
18703       !TLI.isShuffleMaskLegal(Mask1, HalfVT))
18704     return SDValue();
18705 
18706   // shuffle (concat X, undef), (concat Y, undef), Mask -->
18707   // concat (shuffle X, Y, Mask0), (shuffle X, Y, Mask1)
18708   SDValue X = N0.getOperand(0), Y = N1.getOperand(0);
18709   SDLoc DL(Shuf);
18710   SDValue Shuf0 = DAG.getVectorShuffle(HalfVT, DL, X, Y, Mask0);
18711   SDValue Shuf1 = DAG.getVectorShuffle(HalfVT, DL, X, Y, Mask1);
18712   return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Shuf0, Shuf1);
18713 }
18714 
18715 // Tries to turn a shuffle of two CONCAT_VECTORS into a single concat,
18716 // or turn a shuffle of a single concat into a simpler shuffle then a concat.
18717 static SDValue partitionShuffleOfConcats(SDNode *N, SelectionDAG &DAG) {
18718   EVT VT = N->getValueType(0);
18719   unsigned NumElts = VT.getVectorNumElements();
18720 
18721   SDValue N0 = N->getOperand(0);
18722   SDValue N1 = N->getOperand(1);
18723   ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N);
18724   ArrayRef<int> Mask = SVN->getMask();
18725 
18726   SmallVector<SDValue, 4> Ops;
18727   EVT ConcatVT = N0.getOperand(0).getValueType();
18728   unsigned NumElemsPerConcat = ConcatVT.getVectorNumElements();
18729   unsigned NumConcats = NumElts / NumElemsPerConcat;
18730 
18731   auto IsUndefMaskElt = [](int i) { return i == -1; };
18732 
18733   // Special case: shuffle(concat(A,B)) can be more efficiently represented
18734   // as concat(shuffle(A,B),UNDEF) if the shuffle doesn't set any of the high
18735   // half vector elements.
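  // For example (illustrative), with v2i32 concat operands A and B:
  //   (v4i32 vector_shuffle<1,0,u,u> (concat_vectors A, B), undef)
  //   --> (v4i32 concat_vectors (v2i32 vector_shuffle<1,0> A, B), undef)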
18736   if (NumElemsPerConcat * 2 == NumElts && N1.isUndef() &&
18737       llvm::all_of(Mask.slice(NumElemsPerConcat, NumElemsPerConcat),
18738                    IsUndefMaskElt)) {
18739     N0 = DAG.getVectorShuffle(ConcatVT, SDLoc(N), N0.getOperand(0),
18740                               N0.getOperand(1),
18741                               Mask.slice(0, NumElemsPerConcat));
18742     N1 = DAG.getUNDEF(ConcatVT);
18743     return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, N0, N1);
18744   }
18745 
18746   // Look at every vector that's inserted. We're looking for exact
18747   // subvector-sized copies from a concatenated vector.
18748   for (unsigned I = 0; I != NumConcats; ++I) {
18749     unsigned Begin = I * NumElemsPerConcat;
18750     ArrayRef<int> SubMask = Mask.slice(Begin, NumElemsPerConcat);
18751 
18752     // Make sure we're dealing with a copy.
18753     if (llvm::all_of(SubMask, IsUndefMaskElt)) {
18754       Ops.push_back(DAG.getUNDEF(ConcatVT));
18755       continue;
18756     }
18757 
18758     int OpIdx = -1;
18759     for (int i = 0; i != (int)NumElemsPerConcat; ++i) {
18760       if (IsUndefMaskElt(SubMask[i]))
18761         continue;
18762       if ((SubMask[i] % (int)NumElemsPerConcat) != i)
18763         return SDValue();
18764       int EltOpIdx = SubMask[i] / NumElemsPerConcat;
18765       if (0 <= OpIdx && EltOpIdx != OpIdx)
18766         return SDValue();
18767       OpIdx = EltOpIdx;
18768     }
18769     assert(0 <= OpIdx && "Unknown concat_vectors op");
18770 
18771     if (OpIdx < (int)N0.getNumOperands())
18772       Ops.push_back(N0.getOperand(OpIdx));
18773     else
18774       Ops.push_back(N1.getOperand(OpIdx - N0.getNumOperands()));
18775   }
18776 
18777   return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
18778 }
18779 
18780 // Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
18781 // BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
18782 //
18783 // SHUFFLE(BUILD_VECTOR(), BUILD_VECTOR()) -> BUILD_VECTOR() is always
18784 // a simplification in some sense, but it isn't appropriate in general: some
18785 // BUILD_VECTORs are substantially cheaper than others. The general case
18786 // of a BUILD_VECTOR requires inserting each element individually (or
18787 // performing the equivalent in a temporary stack variable). A BUILD_VECTOR of
18788 // all constants is a single constant pool load.  A BUILD_VECTOR where each
18789 // element is identical is a splat.  A BUILD_VECTOR where most of the operands
18790 // are undef lowers to a small number of element insertions.
18791 //
18792 // To deal with this, we currently use a bunch of mostly arbitrary heuristics.
18793 // We don't fold shuffles where one side is a non-zero constant, and we don't
18794 // fold shuffles if the resulting (non-splat) BUILD_VECTOR would have duplicate
18795 // non-constant operands. This seems to work out reasonably well in practice.
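// For example (illustrative):
//   (v4i32 vector_shuffle<0,1,4,5> (build_vector a, b, c, d),
//                                  (build_vector e, f, g, h))
//   --> (v4i32 build_vector a, b, e, f)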
18796 static SDValue combineShuffleOfScalars(ShuffleVectorSDNode *SVN,
18797                                        SelectionDAG &DAG,
18798                                        const TargetLowering &TLI) {
18799   EVT VT = SVN->getValueType(0);
18800   unsigned NumElts = VT.getVectorNumElements();
18801   SDValue N0 = SVN->getOperand(0);
18802   SDValue N1 = SVN->getOperand(1);
18803 
18804   if (!N0->hasOneUse())
18805     return SDValue();
18806 
18807   // If only one of N0,N1 is constant, bail out if it is not ALL_ZEROS as
18808   // discussed above.
18809   if (!N1.isUndef()) {
18810     if (!N1->hasOneUse())
18811       return SDValue();
18812 
18813     bool N0AnyConst = isAnyConstantBuildVector(N0);
18814     bool N1AnyConst = isAnyConstantBuildVector(N1);
18815     if (N0AnyConst && !N1AnyConst && !ISD::isBuildVectorAllZeros(N0.getNode()))
18816       return SDValue();
18817     if (!N0AnyConst && N1AnyConst && !ISD::isBuildVectorAllZeros(N1.getNode()))
18818       return SDValue();
18819   }
18820 
18821   // If both inputs are splats of the same value then we can safely merge this
18822   // to a single BUILD_VECTOR with undef elements based on the shuffle mask.
18823   bool IsSplat = false;
18824   auto *BV0 = dyn_cast<BuildVectorSDNode>(N0);
18825   auto *BV1 = dyn_cast<BuildVectorSDNode>(N1);
18826   if (BV0 && BV1)
18827     if (SDValue Splat0 = BV0->getSplatValue())
18828       IsSplat = (Splat0 == BV1->getSplatValue());
18829 
18830   SmallVector<SDValue, 8> Ops;
18831   SmallSet<SDValue, 16> DuplicateOps;
18832   for (int M : SVN->getMask()) {
18833     SDValue Op = DAG.getUNDEF(VT.getScalarType());
18834     if (M >= 0) {
18835       int Idx = M < (int)NumElts ? M : M - NumElts;
18836       SDValue &S = (M < (int)NumElts ? N0 : N1);
18837       if (S.getOpcode() == ISD::BUILD_VECTOR) {
18838         Op = S.getOperand(Idx);
18839       } else if (S.getOpcode() == ISD::SCALAR_TO_VECTOR) {
18840         SDValue Op0 = S.getOperand(0);
18841         Op = Idx == 0 ? Op0 : DAG.getUNDEF(Op0.getValueType());
18842       } else {
18843         // Operand can't be combined - bail out.
18844         return SDValue();
18845       }
18846     }
18847 
18848     // Don't duplicate a non-constant BUILD_VECTOR operand unless we're
18849     // generating a splat; semantically, this is fine, but it's likely to
18850     // generate low-quality code if the target can't reconstruct an appropriate
18851     // shuffle.
18852     if (!Op.isUndef() && !isa<ConstantSDNode>(Op) && !isa<ConstantFPSDNode>(Op))
18853       if (!IsSplat && !DuplicateOps.insert(Op).second)
18854         return SDValue();
18855 
18856     Ops.push_back(Op);
18857   }
18858 
18859   // BUILD_VECTOR requires all inputs to be of the same type; find the
18860   // maximum type and extend them all.
18861   EVT SVT = VT.getScalarType();
18862   if (SVT.isInteger())
18863     for (SDValue &Op : Ops)
18864       SVT = (SVT.bitsLT(Op.getValueType()) ? Op.getValueType() : SVT);
18865   if (SVT != VT.getScalarType())
18866     for (SDValue &Op : Ops)
18867       Op = TLI.isZExtFree(Op.getValueType(), SVT)
18868                ? DAG.getZExtOrTrunc(Op, SDLoc(SVN), SVT)
18869                : DAG.getSExtOrTrunc(Op, SDLoc(SVN), SVT);
18870   return DAG.getBuildVector(VT, SDLoc(SVN), Ops);
18871 }
18872 
18873 // Match shuffles that can be converted to any_vector_extend_in_reg.
18874 // This is often generated during legalization.
18875 // e.g. v4i32 <0,u,1,u> -> (v2i64 any_vector_extend_in_reg(v4i32 src))
18876 // TODO Add support for ZERO_EXTEND_VECTOR_INREG when we have a test case.
18877 static SDValue combineShuffleToVectorExtend(ShuffleVectorSDNode *SVN,
18878                                             SelectionDAG &DAG,
18879                                             const TargetLowering &TLI,
18880                                             bool LegalOperations) {
18881   EVT VT = SVN->getValueType(0);
18882   bool IsBigEndian = DAG.getDataLayout().isBigEndian();
18883 
18884   // TODO Add support for big-endian when we have a test case.
18885   if (!VT.isInteger() || IsBigEndian)
18886     return SDValue();
18887 
18888   unsigned NumElts = VT.getVectorNumElements();
18889   unsigned EltSizeInBits = VT.getScalarSizeInBits();
18890   ArrayRef<int> Mask = SVN->getMask();
18891   SDValue N0 = SVN->getOperand(0);
18892 
18893   // shuffle<0,-1,1,-1> == (v2i64 any_extend_vector_inreg(v4i32))
18894   auto isAnyExtend = [&Mask, &NumElts](unsigned Scale) {
18895     for (unsigned i = 0; i != NumElts; ++i) {
18896       if (Mask[i] < 0)
18897         continue;
18898       if ((i % Scale) == 0 && Mask[i] == (int)(i / Scale))
18899         continue;
18900       return false;
18901     }
18902     return true;
18903   };
18904 
18905   // Attempt to match a '*_extend_vector_inreg' shuffle; we just search for
18906   // power-of-2 extensions as they are the most likely.
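        // For example (illustrative), with v8i16 and Scale == 2:
        //   shuffle<0,u,1,u,2,u,3,u> --> bitcast (v4i32 any_extend_vector_inreg src)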
18907   for (unsigned Scale = 2; Scale < NumElts; Scale *= 2) {
18908     // Check for non-power-of-2 vector sizes.
18909     if (NumElts % Scale != 0)
18910       continue;
18911     if (!isAnyExtend(Scale))
18912       continue;
18913 
18914     EVT OutSVT = EVT::getIntegerVT(*DAG.getContext(), EltSizeInBits * Scale);
18915     EVT OutVT = EVT::getVectorVT(*DAG.getContext(), OutSVT, NumElts / Scale);
18916     // Never create an illegal type. Only create unsupported operations if we
18917     // are pre-legalization.
18918     if (TLI.isTypeLegal(OutVT))
18919       if (!LegalOperations ||
18920           TLI.isOperationLegalOrCustom(ISD::ANY_EXTEND_VECTOR_INREG, OutVT))
18921         return DAG.getBitcast(VT,
18922                               DAG.getNode(ISD::ANY_EXTEND_VECTOR_INREG,
18923                                           SDLoc(SVN), OutVT, N0));
18924   }
18925 
18926   return SDValue();
18927 }
18928 
18929 // Detect 'truncate_vector_inreg' style shuffles that pack the lower parts of
18930 // each source element of a large type into the lowest elements of a smaller
18931 // destination type. This is often generated during legalization.
18932 // If the source node itself was a '*_extend_vector_inreg' node then we
18933 // should be able to remove it.
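      // For example (illustrative), where X is v4i32:
      //   shuffle<0,2,u,u> (v4i32 bitcast (v2i64 any_extend_vector_inreg X)) --> X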
18934 static SDValue combineTruncationShuffle(ShuffleVectorSDNode *SVN,
18935                                         SelectionDAG &DAG) {
18936   EVT VT = SVN->getValueType(0);
18937   bool IsBigEndian = DAG.getDataLayout().isBigEndian();
18938 
18939   // TODO Add support for big-endian when we have a test case.
18940   if (!VT.isInteger() || IsBigEndian)
18941     return SDValue();
18942 
18943   SDValue N0 = peekThroughBitcasts(SVN->getOperand(0));
18944 
18945   unsigned Opcode = N0.getOpcode();
18946   if (Opcode != ISD::ANY_EXTEND_VECTOR_INREG &&
18947       Opcode != ISD::SIGN_EXTEND_VECTOR_INREG &&
18948       Opcode != ISD::ZERO_EXTEND_VECTOR_INREG)
18949     return SDValue();
18950 
18951   SDValue N00 = N0.getOperand(0);
18952   ArrayRef<int> Mask = SVN->getMask();
18953   unsigned NumElts = VT.getVectorNumElements();
18954   unsigned EltSizeInBits = VT.getScalarSizeInBits();
18955   unsigned ExtSrcSizeInBits = N00.getScalarValueSizeInBits();
18956   unsigned ExtDstSizeInBits = N0.getScalarValueSizeInBits();
18957 
18958   if (ExtDstSizeInBits % ExtSrcSizeInBits != 0)
18959     return SDValue();
18960   unsigned ExtScale = ExtDstSizeInBits / ExtSrcSizeInBits;
18961 
18962   // (v4i32 truncate_vector_inreg(v2i64)) == shuffle<0,2,-1,-1>
18963   // (v8i16 truncate_vector_inreg(v4i32)) == shuffle<0,2,4,6,-1,-1,-1,-1>
18964   // (v8i16 truncate_vector_inreg(v2i64)) == shuffle<0,4,-1,-1,-1,-1,-1,-1>
18965   auto isTruncate = [&Mask, &NumElts](unsigned Scale) {
18966     for (unsigned i = 0; i != NumElts; ++i) {
18967       if (Mask[i] < 0)
18968         continue;
18969       if ((i * Scale) < NumElts && Mask[i] == (int)(i * Scale))
18970         continue;
18971       return false;
18972     }
18973     return true;
18974   };
18975 
18976   // At the moment we just handle the case where we've truncated back to the
18977   // same size as before the extension.
18978   // TODO: handle more extension/truncation cases as cases arise.
18979   if (EltSizeInBits != ExtSrcSizeInBits)
18980     return SDValue();
18981 
18982   // We can remove *extend_vector_inreg only if the truncation happens at
18983   // the same scale as the extension.
18984   if (isTruncate(ExtScale))
18985     return DAG.getBitcast(VT, N00);
18986 
18987   return SDValue();
18988 }
18989 
18990 // Combine shuffles of splat-shuffles of the form:
18991 // shuffle (shuffle V, undef, splat-mask), undef, M
18992 // If splat-mask contains undef elements, we need to be careful about
18993 // introducing undefs in the folded mask which are not the result of composing
18994 // the masks of the shuffles.
18995 static SDValue combineShuffleOfSplatVal(ShuffleVectorSDNode *Shuf,
18996                                         SelectionDAG &DAG) {
18997   if (!Shuf->getOperand(1).isUndef())
18998     return SDValue();
18999   auto *Splat = dyn_cast<ShuffleVectorSDNode>(Shuf->getOperand(0));
19000   if (!Splat || !Splat->isSplat())
19001     return SDValue();
19002 
19003   ArrayRef<int> ShufMask = Shuf->getMask();
19004   ArrayRef<int> SplatMask = Splat->getMask();
19005   assert(ShufMask.size() == SplatMask.size() && "Mask length mismatch");
19006 
19007   // Prefer simplifying to the splat-shuffle, if possible. This is legal if
19008   // every undef mask element in the splat-shuffle has a corresponding undef
19009   // element in the user-shuffle's mask or if the composition of mask elements
19010   // would result in undef.
19011   // Examples for (shuffle (shuffle v, undef, SplatMask), undef, UserMask):
19012   // * UserMask=[0,2,u,u], SplatMask=[2,u,2,u] -> [2,2,u,u]
19013   //   In this case it is not legal to simplify to the splat-shuffle because
19014   //   we may be exposing the shuffle's users to an undef element at index 1
19015   //   which was not there before the combine.
19016   // * UserMask=[0,u,2,u], SplatMask=[2,u,2,u] -> [2,u,2,u]
19017   //   In this case the composition of masks yields SplatMask, so it's ok to
19018   //   simplify to the splat-shuffle.
19019   // * UserMask=[3,u,2,u], SplatMask=[2,u,2,u] -> [u,u,2,u]
19020   //   In this case the composed mask includes all undef elements of SplatMask
19021   //   and in addition sets element zero to undef. It is safe to simplify to
19022   //   the splat-shuffle.
19023   auto CanSimplifyToExistingSplat = [](ArrayRef<int> UserMask,
19024                                        ArrayRef<int> SplatMask) {
19025     for (unsigned i = 0, e = UserMask.size(); i != e; ++i)
19026       if (UserMask[i] != -1 && SplatMask[i] == -1 &&
19027           SplatMask[UserMask[i]] != -1)
19028         return false;
19029     return true;
19030   };
19031   if (CanSimplifyToExistingSplat(ShufMask, SplatMask))
19032     return Shuf->getOperand(0);
19033 
19034   // Create a new shuffle with a mask that is composed of the two shuffles'
19035   // masks.
19036   SmallVector<int, 32> NewMask;
19037   for (int Idx : ShufMask)
19038     NewMask.push_back(Idx == -1 ? -1 : SplatMask[Idx]);
19039 
19040   return DAG.getVectorShuffle(Splat->getValueType(0), SDLoc(Splat),
19041                               Splat->getOperand(0), Splat->getOperand(1),
19042                               NewMask);
19043 }
19044 
19045 /// If the shuffle mask is taking exactly one element from the first vector
19046 /// operand and passing through all other elements from the second vector
19047 /// operand, return the index of the mask element that is choosing an element
19048 /// from the first operand. Otherwise, return -1.
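      /// For example (illustrative), with 4 elements, mask <4,5,0,7> returns 2:
      /// result element 2 takes element 0 of the first operand, while every
      /// other lane passes through the corresponding element of the second.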
19049 static int getShuffleMaskIndexOfOneElementFromOp0IntoOp1(ArrayRef<int> Mask) {
19050   int MaskSize = Mask.size();
19051   int EltFromOp0 = -1;
19052   // TODO: This does not match if there are undef elements in the shuffle mask.
19053   // Should we ignore undefs in the shuffle mask instead? The trade-off is
19054   // removing an instruction (a shuffle), but losing the knowledge that some
19055   // vector lanes are not needed.
19056   for (int i = 0; i != MaskSize; ++i) {
19057     if (Mask[i] >= 0 && Mask[i] < MaskSize) {
19058       // We're looking for a shuffle of exactly one element from operand 0.
19059       if (EltFromOp0 != -1)
19060         return -1;
19061       EltFromOp0 = i;
19062     } else if (Mask[i] != i + MaskSize) {
19063       // Nothing from operand 1 can change lanes.
19064       return -1;
19065     }
19066   }
19067   return EltFromOp0;
19068 }
19069 
19070 /// If a shuffle inserts exactly one element from a source vector operand into
19071 /// another vector operand and we can access the specified element as a scalar,
19072 /// then we can eliminate the shuffle.
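      /// For example (illustrative):
      ///   shuffle (insertelt v1, x, 0), v2, <4,5,0,7> --> insertelt v2, x, 2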
19073 static SDValue replaceShuffleOfInsert(ShuffleVectorSDNode *Shuf,
19074                                       SelectionDAG &DAG) {
19075   // First, check if we are taking one element of a vector and shuffling that
19076   // element into another vector.
19077   ArrayRef<int> Mask = Shuf->getMask();
19078   SmallVector<int, 16> CommutedMask(Mask.begin(), Mask.end());
19079   SDValue Op0 = Shuf->getOperand(0);
19080   SDValue Op1 = Shuf->getOperand(1);
19081   int ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(Mask);
19082   if (ShufOp0Index == -1) {
19083     // Commute mask and check again.
19084     ShuffleVectorSDNode::commuteMask(CommutedMask);
19085     ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(CommutedMask);
19086     if (ShufOp0Index == -1)
19087       return SDValue();
19088     // Commute operands to match the commuted shuffle mask.
19089     std::swap(Op0, Op1);
19090     Mask = CommutedMask;
19091   }
19092 
19093   // The shuffle inserts exactly one element from operand 0 into operand 1.
19094   // Now see if we can access that element as a scalar via a real insert element
19095   // instruction.
19096   // TODO: We can try harder to locate the element as a scalar. Examples: it
19097   // could be an operand of SCALAR_TO_VECTOR, BUILD_VECTOR, or a constant.
19098   assert(Mask[ShufOp0Index] >= 0 && Mask[ShufOp0Index] < (int)Mask.size() &&
19099          "Shuffle mask value must be from operand 0");
19100   if (Op0.getOpcode() != ISD::INSERT_VECTOR_ELT)
19101     return SDValue();
19102 
19103   auto *InsIndexC = dyn_cast<ConstantSDNode>(Op0.getOperand(2));
19104   if (!InsIndexC || InsIndexC->getSExtValue() != Mask[ShufOp0Index])
19105     return SDValue();
19106 
19107   // There's an existing insertelement with constant insertion index, so we
19108   // don't need to check the legality/profitability of a replacement operation
19109   // that differs at most in the constant value. The target should be able to
19110   // lower any of those in a similar way. If not, legalization will expand this
19111   // to a scalar-to-vector plus shuffle.
19112   //
19113   // Note that the shuffle may move the scalar from the position that the insert
19114   // element used. Therefore, our new insert element occurs at the shuffle's
19115   // mask index value, not the insert's index value.
19116   // shuffle (insertelt v1, x, C), v2, mask --> insertelt v2, x, C'
19117   SDValue NewInsIndex = DAG.getConstant(ShufOp0Index, SDLoc(Shuf),
19118                                         Op0.getOperand(2).getValueType());
19119   return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(Shuf), Op0.getValueType(),
19120                      Op1, Op0.getOperand(1), NewInsIndex);
19121 }
19122 
19123 /// If we have a unary shuffle of a shuffle, see if it can be folded away
19124 /// completely. This has the potential to lose undef knowledge because the first
19125 /// shuffle may not have an undef mask element where the second one does. So
19126 /// only call this after doing simplifications based on demanded elements.
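      /// For example (illustrative):
      ///   shuf (shuf0 X, Y, <0,0,2,2>), undef, <1,0,3,2> --> shuf0 X, Y, <0,0,2,2>
      /// since every lane of the outer mask selects an element identical to the
      /// one already in that lane of the inner shuffle's result.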
19127 static SDValue simplifyShuffleOfShuffle(ShuffleVectorSDNode *Shuf) {
19128   // shuf (shuf0 X, Y, Mask0), undef, Mask
19129   auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(Shuf->getOperand(0));
19130   if (!Shuf0 || !Shuf->getOperand(1).isUndef())
19131     return SDValue();
19132 
19133   ArrayRef<int> Mask = Shuf->getMask();
19134   ArrayRef<int> Mask0 = Shuf0->getMask();
19135   for (int i = 0, e = (int)Mask.size(); i != e; ++i) {
19136     // Ignore undef elements.
19137     if (Mask[i] == -1)
19138       continue;
19139     assert(Mask[i] >= 0 && Mask[i] < e && "Unexpected shuffle mask value");
19140 
19141     // Is the element of the shuffle operand chosen by this shuffle the same as
19142     // the element chosen by the shuffle operand itself?
19143     if (Mask0[Mask[i]] != Mask0[i])
19144       return SDValue();
19145   }
19146   // Every element of this shuffle is identical to the result of the previous
19147   // shuffle, so we can replace this value.
19148   return Shuf->getOperand(0);
19149 }
19150 
19151 SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
19152   EVT VT = N->getValueType(0);
19153   unsigned NumElts = VT.getVectorNumElements();
19154 
19155   SDValue N0 = N->getOperand(0);
19156   SDValue N1 = N->getOperand(1);
19157 
19158   assert(N0.getValueType() == VT && "Vector shuffle must be normalized in DAG");
19159 
19160   // Canonicalize shuffle undef, undef -> undef
19161   if (N0.isUndef() && N1.isUndef())
19162     return DAG.getUNDEF(VT);
19163 
19164   ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N);
19165 
19166   // Canonicalize shuffle v, v -> v, undef
19167   if (N0 == N1) {
19168     SmallVector<int, 8> NewMask;
19169     for (unsigned i = 0; i != NumElts; ++i) {
19170       int Idx = SVN->getMaskElt(i);
19171       if (Idx >= (int)NumElts) Idx -= NumElts;
19172       NewMask.push_back(Idx);
19173     }
19174     return DAG.getVectorShuffle(VT, SDLoc(N), N0, DAG.getUNDEF(VT), NewMask);
19175   }
19176 
19177   // Canonicalize shuffle undef, v -> v, undef.  Commute the shuffle mask.
19178   if (N0.isUndef())
19179     return DAG.getCommutedVectorShuffle(*SVN);
19180 
19181   // Remove references to rhs if it is undef
19182   if (N1.isUndef()) {
19183     bool Changed = false;
19184     SmallVector<int, 8> NewMask;
19185     for (unsigned i = 0; i != NumElts; ++i) {
19186       int Idx = SVN->getMaskElt(i);
19187       if (Idx >= (int)NumElts) {
19188         Idx = -1;
19189         Changed = true;
19190       }
19191       NewMask.push_back(Idx);
19192     }
19193     if (Changed)
19194       return DAG.getVectorShuffle(VT, SDLoc(N), N0, N1, NewMask);
19195   }
19196 
19197   if (SDValue InsElt = replaceShuffleOfInsert(SVN, DAG))
19198     return InsElt;
19199 
19200   // A shuffle of a single vector that is a splatted value can always be folded.
19201   if (SDValue V = combineShuffleOfSplatVal(SVN, DAG))
19202     return V;
19203 
19204   // If it is a splat, check if the argument vector is another splat or a
19205   // build_vector.
19206   if (SVN->isSplat() && SVN->getSplatIndex() < (int)NumElts) {
19207     int SplatIndex = SVN->getSplatIndex();
19208     if (N0.hasOneUse() && TLI.isExtractVecEltCheap(VT, SplatIndex) &&
19209         TLI.isBinOp(N0.getOpcode()) && N0.getNode()->getNumValues() == 1) {
19210       // splat (vector_bo L, R), Index -->
19211       // splat (scalar_bo (extelt L, Index), (extelt R, Index))
19212       SDValue L = N0.getOperand(0), R = N0.getOperand(1);
19213       SDLoc DL(N);
19214       EVT EltVT = VT.getScalarType();
19215       SDValue Index = DAG.getIntPtrConstant(SplatIndex, DL);
19216       SDValue ExtL = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, L, Index);
19217       SDValue ExtR = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, R, Index);
19218       SDValue NewBO = DAG.getNode(N0.getOpcode(), DL, EltVT, ExtL, ExtR,
19219                                   N0.getNode()->getFlags());
19220       SDValue Insert = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT, NewBO);
19221       SmallVector<int, 16> ZeroMask(VT.getVectorNumElements(), 0);
19222       return DAG.getVectorShuffle(VT, DL, Insert, DAG.getUNDEF(VT), ZeroMask);
19223     }
19224 
19225     // If this is a bit convert that changes the element type of the vector but
19226     // not the number of vector elements, look through it.  Be careful not to
19227     // look through conversions that change things like v4f32 to v2f64.
19228     SDNode *V = N0.getNode();
19229     if (V->getOpcode() == ISD::BITCAST) {
19230       SDValue ConvInput = V->getOperand(0);
19231       if (ConvInput.getValueType().isVector() &&
19232           ConvInput.getValueType().getVectorNumElements() == NumElts)
19233         V = ConvInput.getNode();
19234     }
19235 
19236     if (V->getOpcode() == ISD::BUILD_VECTOR) {
19237       assert(V->getNumOperands() == NumElts &&
19238              "BUILD_VECTOR has wrong number of operands");
19239       SDValue Base;
19240       bool AllSame = true;
19241       for (unsigned i = 0; i != NumElts; ++i) {
19242         if (!V->getOperand(i).isUndef()) {
19243           Base = V->getOperand(i);
19244           break;
19245         }
19246       }
19247       // Splat of <u, u, u, u>, return <u, u, u, u>
19248       if (!Base.getNode())
19249         return N0;
19250       for (unsigned i = 0; i != NumElts; ++i) {
19251         if (V->getOperand(i) != Base) {
19252           AllSame = false;
19253           break;
19254         }
19255       }
19256       // Splat of <x, x, x, x>, return <x, x, x, x>
19257       if (AllSame)
19258         return N0;
19259 
19260       // Canonicalize any other splat as a build_vector.
19261       SDValue Splatted = V->getOperand(SplatIndex);
19262       SmallVector<SDValue, 8> Ops(NumElts, Splatted);
19263       SDValue NewBV = DAG.getBuildVector(V->getValueType(0), SDLoc(N), Ops);
19264 
19265       // We may have jumped through bitcasts, so the type of the
19266       // BUILD_VECTOR may not match the type of the shuffle.
19267       if (V->getValueType(0) != VT)
19268         NewBV = DAG.getBitcast(VT, NewBV);
19269       return NewBV;
19270     }
19271   }
19272 
19273   // Simplify source operands based on shuffle mask.
19274   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
19275     return SDValue(N, 0);
19276 
19277   // This is intentionally placed after demanded elements simplification because
19278   // it could eliminate knowledge of undef elements created by this shuffle.
19279   if (SDValue ShufOp = simplifyShuffleOfShuffle(SVN))
19280     return ShufOp;
19281 
19282   // Match shuffles that can be converted to any_vector_extend_in_reg.
19283   if (SDValue V = combineShuffleToVectorExtend(SVN, DAG, TLI, LegalOperations))
19284     return V;
19285 
19286   // Combine "truncate_vector_in_reg" style shuffles.
19287   if (SDValue V = combineTruncationShuffle(SVN, DAG))
19288     return V;
19289 
19290   if (N0.getOpcode() == ISD::CONCAT_VECTORS &&
19291       Level < AfterLegalizeVectorOps &&
19292       (N1.isUndef() ||
19293       (N1.getOpcode() == ISD::CONCAT_VECTORS &&
19294        N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType()))) {
19295     if (SDValue V = partitionShuffleOfConcats(N, DAG))
19296       return V;
19297   }
19298 
19299   // A shuffle of a concat of the same narrow vector can be reduced to use
19300   // only low-half elements of a concat with undef:
19301   // shuf (concat X, X), undef, Mask --> shuf (concat X, undef), undef, Mask'
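        // For example (illustrative), with X being v2i32:
        //   shuf (concat X, X), undef, <3,1,0,2>
        //     --> shuf (concat X, undef), undef, <1,1,0,0>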
19302   if (N0.getOpcode() == ISD::CONCAT_VECTORS && N1.isUndef() &&
19303       N0.getNumOperands() == 2 &&
19304       N0.getOperand(0) == N0.getOperand(1)) {
19305     int HalfNumElts = (int)NumElts / 2;
19306     SmallVector<int, 8> NewMask;
19307     for (unsigned i = 0; i != NumElts; ++i) {
19308       int Idx = SVN->getMaskElt(i);
19309       if (Idx >= HalfNumElts) {
19310         assert(Idx < (int)NumElts && "Shuffle mask chooses undef op");
19311         Idx -= HalfNumElts;
19312       }
19313       NewMask.push_back(Idx);
19314     }
19315     if (TLI.isShuffleMaskLegal(NewMask, VT)) {
19316       SDValue UndefVec = DAG.getUNDEF(N0.getOperand(0).getValueType());
19317       SDValue NewCat = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT,
19318                                    N0.getOperand(0), UndefVec);
19319       return DAG.getVectorShuffle(VT, SDLoc(N), NewCat, N1, NewMask);
19320     }
19321   }
19322 
19323   // Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
19324   // BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
19325   if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT))
19326     if (SDValue Res = combineShuffleOfScalars(SVN, DAG, TLI))
19327       return Res;
19328 
19329   // If this shuffle only has a single input that is a bitcasted shuffle,
19330   // attempt to merge the 2 shuffles and suitably bitcast the inputs/output
19331   // back to their original types.
19332   if (N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
19333       N1.isUndef() && Level < AfterLegalizeVectorOps &&
19334       TLI.isTypeLegal(VT)) {
19335     auto ScaleShuffleMask = [](ArrayRef<int> Mask, int Scale) {
19336       if (Scale == 1)
19337         return SmallVector<int, 8>(Mask.begin(), Mask.end());
19338 
19339       SmallVector<int, 8> NewMask;
19340       for (int M : Mask)
19341         for (int s = 0; s != Scale; ++s)
19342           NewMask.push_back(M < 0 ? -1 : Scale * M + s);
19343       return NewMask;
19344     };
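          // For example (illustrative): ScaleShuffleMask(<1,u>, 2) --> <2,3,u,u>.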
19345 
19346     SDValue BC0 = peekThroughOneUseBitcasts(N0);
19347     if (BC0.getOpcode() == ISD::VECTOR_SHUFFLE && BC0.hasOneUse()) {
19348       EVT SVT = VT.getScalarType();
19349       EVT InnerVT = BC0->getValueType(0);
19350       EVT InnerSVT = InnerVT.getScalarType();
19351 
19352       // Determine which shuffle works with the smaller scalar type.
19353       EVT ScaleVT = SVT.bitsLT(InnerSVT) ? VT : InnerVT;
19354       EVT ScaleSVT = ScaleVT.getScalarType();
19355 
19356       if (TLI.isTypeLegal(ScaleVT) &&
19357           0 == (InnerSVT.getSizeInBits() % ScaleSVT.getSizeInBits()) &&
19358           0 == (SVT.getSizeInBits() % ScaleSVT.getSizeInBits())) {
19359         int InnerScale = InnerSVT.getSizeInBits() / ScaleSVT.getSizeInBits();
19360         int OuterScale = SVT.getSizeInBits() / ScaleSVT.getSizeInBits();
19361 
19362         // Scale the shuffle masks to the smaller scalar type.
19363         ShuffleVectorSDNode *InnerSVN = cast<ShuffleVectorSDNode>(BC0);
19364         SmallVector<int, 8> InnerMask =
19365             ScaleShuffleMask(InnerSVN->getMask(), InnerScale);
19366         SmallVector<int, 8> OuterMask =
19367             ScaleShuffleMask(SVN->getMask(), OuterScale);
19368 
19369         // Merge the shuffle masks.
19370         SmallVector<int, 8> NewMask;
19371         for (int M : OuterMask)
19372           NewMask.push_back(M < 0 ? -1 : InnerMask[M]);
19373 
19374         // Test for shuffle mask legality over both commutations.
19375         SDValue SV0 = BC0->getOperand(0);
19376         SDValue SV1 = BC0->getOperand(1);
19377         bool LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
19378         if (!LegalMask) {
19379           std::swap(SV0, SV1);
19380           ShuffleVectorSDNode::commuteMask(NewMask);
19381           LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
19382         }
19383 
19384         if (LegalMask) {
19385           SV0 = DAG.getBitcast(ScaleVT, SV0);
19386           SV1 = DAG.getBitcast(ScaleVT, SV1);
19387           return DAG.getBitcast(
19388               VT, DAG.getVectorShuffle(ScaleVT, SDLoc(N), SV0, SV1, NewMask));
19389         }
19390       }
19391     }
19392   }
19393 
19394   // Canonicalize shuffles according to rules:
19395   //  shuffle(A, shuffle(A, B)) -> shuffle(shuffle(A,B), A)
19396   //  shuffle(B, shuffle(A, B)) -> shuffle(shuffle(A,B), B)
19397   //  shuffle(B, shuffle(A, Undef)) -> shuffle(shuffle(A, Undef), B)
19398   if (N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
19399       N0.getOpcode() != ISD::VECTOR_SHUFFLE && Level < AfterLegalizeDAG &&
19400       TLI.isTypeLegal(VT)) {
19401     // The incoming shuffle must be of the same type as the result of the
19402     // current shuffle.
19403     assert(N1->getOperand(0).getValueType() == VT &&
19404            "Shuffle types don't match");
19405 
19406     SDValue SV0 = N1->getOperand(0);
19407     SDValue SV1 = N1->getOperand(1);
19408     bool HasSameOp0 = N0 == SV0;
19409     bool IsSV1Undef = SV1.isUndef();
19410     if (HasSameOp0 || IsSV1Undef || N0 == SV1)
19411       // Commute the operands of this shuffle so that the next rule
19412       // will trigger.
19413       return DAG.getCommutedVectorShuffle(*SVN);
19414   }
19415 
19416   // Try to fold according to rules:
19417   //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
19418   //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
19419   //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
19420   // Don't try to fold shuffles with illegal type.
19421   // Only fold if this shuffle is the only user of the other shuffle.
19422   if (N0.getOpcode() == ISD::VECTOR_SHUFFLE && N->isOnlyUserOf(N0.getNode()) &&
19423       Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
19424     ShuffleVectorSDNode *OtherSV = cast<ShuffleVectorSDNode>(N0);
19425 
19426     // Don't try to fold splats; they're likely to simplify somehow, or they
19427     // might be free.
19428     if (OtherSV->isSplat())
19429       return SDValue();
19430 
19431     // The incoming shuffle must be of the same type as the result of the
19432     // current shuffle.
19433     assert(OtherSV->getOperand(0).getValueType() == VT &&
19434            "Shuffle types don't match");
19435 
19436     SDValue SV0, SV1;
19437     SmallVector<int, 4> Mask;
19438     // Compute the combined shuffle mask for a shuffle with SV0 as the first
19439     // operand, and SV1 as the second operand.
19440     for (unsigned i = 0; i != NumElts; ++i) {
19441       int Idx = SVN->getMaskElt(i);
19442       if (Idx < 0) {
19443         // Propagate Undef.
19444         Mask.push_back(Idx);
19445         continue;
19446       }
19447 
19448       SDValue CurrentVec;
19449       if (Idx < (int)NumElts) {
19450         // This shuffle index refers to the inner shuffle N0. Lookup the inner
19451         // shuffle mask to identify which vector is actually referenced.
19452         Idx = OtherSV->getMaskElt(Idx);
19453         if (Idx < 0) {
19454           // Propagate Undef.
19455           Mask.push_back(Idx);
19456           continue;
19457         }
19458 
19459         CurrentVec = (Idx < (int) NumElts) ? OtherSV->getOperand(0)
19460                                            : OtherSV->getOperand(1);
19461       } else {
19462         // This shuffle index references an element within N1.
19463         CurrentVec = N1;
19464       }
19465 
19466       // Simple case where 'CurrentVec' is UNDEF.
19467       if (CurrentVec.isUndef()) {
19468         Mask.push_back(-1);
19469         continue;
19470       }
19471 
19472       // Canonicalize the shuffle index. We don't know yet if CurrentVec
19473       // will be the first or second operand of the combined shuffle.
19474       Idx = Idx % NumElts;
19475       if (!SV0.getNode() || SV0 == CurrentVec) {
19476         // Ok. CurrentVec is the left hand side.
19477         // Update the mask accordingly.
19478         SV0 = CurrentVec;
19479         Mask.push_back(Idx);
19480         continue;
19481       }
19482 
19483       // Bail out if we cannot convert the shuffle pair into a single shuffle.
19484       if (SV1.getNode() && SV1 != CurrentVec)
19485         return SDValue();
19486 
19487       // Ok. CurrentVec is the right hand side.
19488       // Update the mask accordingly.
19489       SV1 = CurrentVec;
19490       Mask.push_back(Idx + NumElts);
19491     }
19492 
19493     // Check if all indices in Mask are Undef. If so, propagate Undef.
19494     bool isUndefMask = true;
19495     for (unsigned i = 0; i != NumElts && isUndefMask; ++i)
19496       isUndefMask &= Mask[i] < 0;
19497 
19498     if (isUndefMask)
19499       return DAG.getUNDEF(VT);
19500 
19501     if (!SV0.getNode())
19502       SV0 = DAG.getUNDEF(VT);
19503     if (!SV1.getNode())
19504       SV1 = DAG.getUNDEF(VT);
19505 
19506     // Avoid introducing shuffles with illegal mask.
19507     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
19508     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
19509     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
19510     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, A, M2)
19511     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, A, M2)
19512     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, B, M2)
19513     return TLI.buildLegalVectorShuffle(VT, SDLoc(N), SV0, SV1, Mask, DAG);
19514   }
19515 
19516   if (SDValue V = foldShuffleOfConcatUndefs(SVN, DAG))
19517     return V;
19518 
19519   return SDValue();
19520 }
19521 
19522 SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) {
19523   SDValue InVal = N->getOperand(0);
19524   EVT VT = N->getValueType(0);
19525 
19526   // Replace a SCALAR_TO_VECTOR(EXTRACT_VECTOR_ELT(V,C0)) pattern
19527   // with a VECTOR_SHUFFLE and possible truncate.
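        // For example (illustrative), when VT matches the extracted vector type:
        //   scalar_to_vector (extract_vector_elt V, 2) --> shuffle<2,u,u,u> V, undef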
19528   if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
19529     SDValue InVec = InVal->getOperand(0);
19530     SDValue EltNo = InVal->getOperand(1);
19531     auto InVecT = InVec.getValueType();
19532     if (ConstantSDNode *C0 = dyn_cast<ConstantSDNode>(EltNo)) {
19533       SmallVector<int, 8> NewMask(InVecT.getVectorNumElements(), -1);
19534       int Elt = C0->getZExtValue();
19535       NewMask[0] = Elt;
19536       // If we have an implicit truncate, do the truncation here as long as
19537       // it's legal; if it's not legal, we simply bail out below.
19538       if (VT.getScalarType() != InVal.getValueType() &&
19539           InVal.getValueType().isScalarInteger() &&
19540           isTypeLegal(VT.getScalarType())) {
19541         SDValue Val =
19542             DAG.getNode(ISD::TRUNCATE, SDLoc(InVal), VT.getScalarType(), InVal);
19543         return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), VT, Val);
19544       }
19545       if (VT.getScalarType() == InVecT.getScalarType() &&
19546           VT.getVectorNumElements() <= InVecT.getVectorNumElements()) {
19547         SDValue LegalShuffle =
19548           TLI.buildLegalVectorShuffle(InVecT, SDLoc(N), InVec,
19549                                       DAG.getUNDEF(InVecT), NewMask, DAG);
19550         if (LegalShuffle) {
19551           // If the initial vector is the correct size this shuffle is a
19552           // valid result.
19553           if (VT == InVecT)
19554             return LegalShuffle;
19555           // If not, we must truncate the vector.
19556           if (VT.getVectorNumElements() != InVecT.getVectorNumElements()) {
19557             MVT IdxTy = TLI.getVectorIdxTy(DAG.getDataLayout());
19558             SDValue ZeroIdx = DAG.getConstant(0, SDLoc(N), IdxTy);
19559             EVT SubVT =
19560                 EVT::getVectorVT(*DAG.getContext(), InVecT.getVectorElementType(),
19561                                  VT.getVectorNumElements());
19562             return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), SubVT,
19563                                LegalShuffle, ZeroIdx);
19564           }
19565         }
19566       }
19567     }
19568   }
19569 
19570   return SDValue();
19571 }
19572 
19573 SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) {
19574   EVT VT = N->getValueType(0);
19575   SDValue N0 = N->getOperand(0);
19576   SDValue N1 = N->getOperand(1);
19577   SDValue N2 = N->getOperand(2);
19578 
19579   // If inserting an UNDEF, just return the original vector.
19580   if (N1.isUndef())
19581     return N0;
19582 
19583   // If this is an insert of an extracted vector into an undef vector, we can
19584   // just use the input to the extract.
19585   if (N0.isUndef() && N1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
19586       N1.getOperand(1) == N2 && N1.getOperand(0).getValueType() == VT)
19587     return N1.getOperand(0);
19588 
19589   // If we are inserting a bitcast value into an undef, with the same
19590   // number of elements, just use the bitcast input of the extract.
19591   // i.e. INSERT_SUBVECTOR UNDEF (BITCAST N1) N2 ->
19592   //        BITCAST (INSERT_SUBVECTOR UNDEF N1 N2)
19593   if (N0.isUndef() && N1.getOpcode() == ISD::BITCAST &&
19594       N1.getOperand(0).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
19595       N1.getOperand(0).getOperand(1) == N2 &&
19596       N1.getOperand(0).getOperand(0).getValueType().getVectorNumElements() ==
19597           VT.getVectorNumElements() &&
19598       N1.getOperand(0).getOperand(0).getValueType().getSizeInBits() ==
19599           VT.getSizeInBits()) {
19600     return DAG.getBitcast(VT, N1.getOperand(0).getOperand(0));
19601   }
19602 
19603   // If both N0 and N1 are bitcast values on which insert_subvector
19604   // would make sense, pull the bitcast through.
19605   // i.e. INSERT_SUBVECTOR (BITCAST N0) (BITCAST N1) N2 ->
19606   //        BITCAST (INSERT_SUBVECTOR N0 N1 N2)
19607   if (N0.getOpcode() == ISD::BITCAST && N1.getOpcode() == ISD::BITCAST) {
19608     SDValue CN0 = N0.getOperand(0);
19609     SDValue CN1 = N1.getOperand(0);
19610     EVT CN0VT = CN0.getValueType();
19611     EVT CN1VT = CN1.getValueType();
19612     if (CN0VT.isVector() && CN1VT.isVector() &&
19613         CN0VT.getVectorElementType() == CN1VT.getVectorElementType() &&
19614         CN0VT.getVectorNumElements() == VT.getVectorNumElements()) {
19615       SDValue NewINSERT = DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N),
19616                                       CN0.getValueType(), CN0, CN1, N2);
19617       return DAG.getBitcast(VT, NewINSERT);
19618     }
19619   }
19620 
19621   // Combine INSERT_SUBVECTORs where we are inserting to the same index.
19622   // INSERT_SUBVECTOR( INSERT_SUBVECTOR( Vec, SubOld, Idx ), SubNew, Idx )
19623   // --> INSERT_SUBVECTOR( Vec, SubNew, Idx )
19624   if (N0.getOpcode() == ISD::INSERT_SUBVECTOR &&
19625       N0.getOperand(1).getValueType() == N1.getValueType() &&
19626       N0.getOperand(2) == N2)
19627     return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0.getOperand(0),
19628                        N1, N2);
19629 
19630   // Eliminate an intermediate insert into an undef vector:
19631   // insert_subvector undef, (insert_subvector undef, X, 0), N2 -->
19632   // insert_subvector undef, X, N2
19633   if (N0.isUndef() && N1.getOpcode() == ISD::INSERT_SUBVECTOR &&
19634       N1.getOperand(0).isUndef() && isNullConstant(N1.getOperand(2)))
19635     return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0,
19636                        N1.getOperand(1), N2);
19637 
19638   if (!isa<ConstantSDNode>(N2))
19639     return SDValue();
19640 
19641   uint64_t InsIdx = cast<ConstantSDNode>(N2)->getZExtValue();
19642 
19643   // Push subvector bitcasts to the output, adjusting the index as we go.
19644   // insert_subvector(bitcast(v), bitcast(s), c1)
19645   // -> bitcast(insert_subvector(v, s, c2))
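        // For example (illustrative), with V v8i16 and S v4i16:
        //   insert_subvector (v4i32 bitcast V), (v2i32 bitcast S), 2
        //     --> v4i32 bitcast (insert_subvector V, S, 4)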
19646   if ((N0.isUndef() || N0.getOpcode() == ISD::BITCAST) &&
19647       N1.getOpcode() == ISD::BITCAST) {
19648     SDValue N0Src = peekThroughBitcasts(N0);
19649     SDValue N1Src = peekThroughBitcasts(N1);
19650     EVT N0SrcSVT = N0Src.getValueType().getScalarType();
19651     EVT N1SrcSVT = N1Src.getValueType().getScalarType();
19652     if ((N0.isUndef() || N0SrcSVT == N1SrcSVT) &&
19653         N0Src.getValueType().isVector() && N1Src.getValueType().isVector()) {
19654       EVT NewVT;
19655       SDLoc DL(N);
19656       SDValue NewIdx;
19657       MVT IdxVT = TLI.getVectorIdxTy(DAG.getDataLayout());
19658       LLVMContext &Ctx = *DAG.getContext();
19659       unsigned NumElts = VT.getVectorNumElements();
19660       unsigned EltSizeInBits = VT.getScalarSizeInBits();
19661       if ((EltSizeInBits % N1SrcSVT.getSizeInBits()) == 0) {
19662         unsigned Scale = EltSizeInBits / N1SrcSVT.getSizeInBits();
19663         NewVT = EVT::getVectorVT(Ctx, N1SrcSVT, NumElts * Scale);
19664         NewIdx = DAG.getConstant(InsIdx * Scale, DL, IdxVT);
19665       } else if ((N1SrcSVT.getSizeInBits() % EltSizeInBits) == 0) {
19666         unsigned Scale = N1SrcSVT.getSizeInBits() / EltSizeInBits;
19667         if ((NumElts % Scale) == 0 && (InsIdx % Scale) == 0) {
19668           NewVT = EVT::getVectorVT(Ctx, N1SrcSVT, NumElts / Scale);
19669           NewIdx = DAG.getConstant(InsIdx / Scale, DL, IdxVT);
19670         }
19671       }
19672       if (NewIdx && hasOperation(ISD::INSERT_SUBVECTOR, NewVT)) {
19673         SDValue Res = DAG.getBitcast(NewVT, N0Src);
19674         Res = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, NewVT, Res, N1Src, NewIdx);
19675         return DAG.getBitcast(VT, Res);
19676       }
19677     }
19678   }
19679 
19680   // Canonicalize insert_subvector dag nodes.
19681   // Example:
19682   // (insert_subvector (insert_subvector A, S0, Idx0), S1, Idx1)
19683   // -> (insert_subvector (insert_subvector A, S1, Idx1), S0, Idx0)
19684   if (N0.getOpcode() == ISD::INSERT_SUBVECTOR && N0.hasOneUse() &&
19685       N1.getValueType() == N0.getOperand(1).getValueType() &&
19686       isa<ConstantSDNode>(N0.getOperand(2))) {
19687     unsigned OtherIdx = N0.getConstantOperandVal(2);
19688     if (InsIdx < OtherIdx) {
19689       // Swap nodes.
19690       SDValue NewOp = DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT,
19691                                   N0.getOperand(0), N1, N2);
19692       AddToWorklist(NewOp.getNode());
19693       return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N0.getNode()),
19694                          VT, NewOp, N0.getOperand(1), N0.getOperand(2));
19695     }
19696   }
19697 
19698   // If the input vector is a concatenation, and the insert replaces
19699   // one of the pieces, we can optimize into a single concat_vectors.
19700   if (N0.getOpcode() == ISD::CONCAT_VECTORS && N0.hasOneUse() &&
19701       N0.getOperand(0).getValueType() == N1.getValueType()) {
19702     unsigned Factor = N1.getValueType().getVectorNumElements();
19703 
19704     SmallVector<SDValue, 8> Ops(N0->op_begin(), N0->op_end());
19705     Ops[cast<ConstantSDNode>(N2)->getZExtValue() / Factor] = N1;
19706 
19707     return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
19708   }
19709 
19710   // Simplify source operands based on insertion.
19711   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
19712     return SDValue(N, 0);
19713 
19714   return SDValue();
19715 }
19716 
19717 SDValue DAGCombiner::visitFP_TO_FP16(SDNode *N) {
19718   SDValue N0 = N->getOperand(0);
19719 
19720   // fold (fp_to_fp16 (fp16_to_fp op)) -> op
19721   if (N0->getOpcode() == ISD::FP16_TO_FP)
19722     return N0->getOperand(0);
19723 
19724   return SDValue();
19725 }
19726 
19727 SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
19728   SDValue N0 = N->getOperand(0);
19729 
19730   // fold fp16_to_fp(op & 0xffff) -> fp16_to_fp(op)
19731   if (N0->getOpcode() == ISD::AND) {
19732     ConstantSDNode *AndConst = getAsNonOpaqueConstant(N0.getOperand(1));
19733     if (AndConst && AndConst->getAPIntValue() == 0xffff) {
19734       return DAG.getNode(ISD::FP16_TO_FP, SDLoc(N), N->getValueType(0),
19735                          N0.getOperand(0));
19736     }
19737   }
19738 
19739   return SDValue();
19740 }
19741 
19742 SDValue DAGCombiner::visitVECREDUCE(SDNode *N) {
19743   SDValue N0 = N->getOperand(0);
19744   EVT VT = N0.getValueType();
19745   unsigned Opcode = N->getOpcode();
19746 
19747   // VECREDUCE over 1-element vector is just an extract.
19748   if (VT.getVectorNumElements() == 1) {
19749     SDLoc dl(N);
19750     SDValue Res = DAG.getNode(
19751         ISD::EXTRACT_VECTOR_ELT, dl, VT.getVectorElementType(), N0,
19752         DAG.getConstant(0, dl, TLI.getVectorIdxTy(DAG.getDataLayout())));
19753     if (Res.getValueType() != N->getValueType(0))
19754       Res = DAG.getNode(ISD::ANY_EXTEND, dl, N->getValueType(0), Res);
19755     return Res;
19756   }
19757 
19758   // On a boolean vector an and/or reduction is the same as a umin/umax
19759   // reduction. Convert them if the latter is legal while the former isn't.
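        // For example (illustrative), when every element is known to be 0 or -1:
        //   vecreduce_and x --> vecreduce_umin x
        // since the unsigned minimum of 0/-1 values equals their AND (and
        // likewise umax equals OR).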
19760   if (Opcode == ISD::VECREDUCE_AND || Opcode == ISD::VECREDUCE_OR) {
19761     unsigned NewOpcode = Opcode == ISD::VECREDUCE_AND
19762         ? ISD::VECREDUCE_UMIN : ISD::VECREDUCE_UMAX;
19763     if (!TLI.isOperationLegalOrCustom(Opcode, VT) &&
19764         TLI.isOperationLegalOrCustom(NewOpcode, VT) &&
19765         DAG.ComputeNumSignBits(N0) == VT.getScalarSizeInBits())
19766       return DAG.getNode(NewOpcode, SDLoc(N), N->getValueType(0), N0);
19767   }
19768 
19769   return SDValue();
19770 }
19771 
19772 /// Returns a vector_shuffle if it is able to transform an AND to a
19773 /// vector_shuffle with the destination vector and a zero vector.
19774 /// e.g. AND V, <0xffffffff, 0, 0xffffffff, 0> ==>
19775 ///      vector_shuffle V, Zero, <0, 4, 2, 4>
19776 SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) {
19777   assert(N->getOpcode() == ISD::AND && "Unexpected opcode!");
19778 
19779   EVT VT = N->getValueType(0);
19780   SDValue LHS = N->getOperand(0);
19781   SDValue RHS = peekThroughBitcasts(N->getOperand(1));
19782   SDLoc DL(N);
19783 
19784   // Make sure we're not running after operation legalization where it
19785   // may have custom lowered the vector shuffles.
19786   if (LegalOperations)
19787     return SDValue();
19788 
19789   if (RHS.getOpcode() != ISD::BUILD_VECTOR)
19790     return SDValue();
19791 
19792   EVT RVT = RHS.getValueType();
19793   unsigned NumElts = RHS.getNumOperands();
19794 
19795   // Attempt to create a valid clear mask, splitting the mask into
19796   // sub-elements and checking to see if each is all zeros or all ones -
19797   // suitable for shuffle masking.
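        // For example (illustrative, little-endian), with Split == 2 on
        //   and X, (v2i64 build_vector 0xFFFFFFFF00000000, 0x00000000FFFFFFFF)
        // the resulting v4i32 clear mask would be <4,1,2,7>.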
19798   auto BuildClearMask = [&](int Split) {
19799     int NumSubElts = NumElts * Split;
19800     int NumSubBits = RVT.getScalarSizeInBits() / Split;
19801 
19802     SmallVector<int, 8> Indices;
19803     for (int i = 0; i != NumSubElts; ++i) {
19804       int EltIdx = i / Split;
19805       int SubIdx = i % Split;
19806       SDValue Elt = RHS.getOperand(EltIdx);
19807       // X & undef --> 0 (not undef). So this lane must be converted to choose
19808       // from the zero constant vector (same as if the element had all 0-bits).
19809       if (Elt.isUndef()) {
19810         Indices.push_back(i + NumSubElts);
19811         continue;
19812       }
19813 
19814       APInt Bits;
19815       if (isa<ConstantSDNode>(Elt))
19816         Bits = cast<ConstantSDNode>(Elt)->getAPIntValue();
19817       else if (isa<ConstantFPSDNode>(Elt))
19818         Bits = cast<ConstantFPSDNode>(Elt)->getValueAPF().bitcastToAPInt();
19819       else
19820         return SDValue();
19821 
19822       // Extract the sub element from the constant bit mask.
19823       if (DAG.getDataLayout().isBigEndian())
19824         Bits = Bits.extractBits(NumSubBits, (Split - SubIdx - 1) * NumSubBits);
19825       else
19826         Bits = Bits.extractBits(NumSubBits, SubIdx * NumSubBits);
19827 
19828       if (Bits.isAllOnesValue())
19829         Indices.push_back(i);
19830       else if (Bits == 0)
19831         Indices.push_back(i + NumSubElts);
19832       else
19833         return SDValue();
19834     }
19835 
19836     // Let's see if the target supports this vector_shuffle.
19837     EVT ClearSVT = EVT::getIntegerVT(*DAG.getContext(), NumSubBits);
19838     EVT ClearVT = EVT::getVectorVT(*DAG.getContext(), ClearSVT, NumSubElts);
19839     if (!TLI.isVectorClearMaskLegal(Indices, ClearVT))
19840       return SDValue();
19841 
19842     SDValue Zero = DAG.getConstant(0, DL, ClearVT);
19843     return DAG.getBitcast(VT, DAG.getVectorShuffle(ClearVT, DL,
19844                                                    DAG.getBitcast(ClearVT, LHS),
19845                                                    Zero, Indices));
19846   };
19847 
19848   // Determine maximum split level (byte level masking).
19849   int MaxSplit = 1;
19850   if (RVT.getScalarSizeInBits() % 8 == 0)
19851     MaxSplit = RVT.getScalarSizeInBits() / 8;
19852 
19853   for (int Split = 1; Split <= MaxSplit; ++Split)
19854     if (RVT.getScalarSizeInBits() % Split == 0)
19855       if (SDValue S = BuildClearMask(Split))
19856         return S;
19857 
19858   return SDValue();
19859 }
19860 
19861 /// If a vector binop is performed on splat values, it may be profitable to
19862 /// extract, scalarize, and insert/splat.
19863 static SDValue scalarizeBinOpOfSplats(SDNode *N, SelectionDAG &DAG) {
19864   SDValue N0 = N->getOperand(0);
19865   SDValue N1 = N->getOperand(1);
19866   unsigned Opcode = N->getOpcode();
19867   EVT VT = N->getValueType(0);
19868   EVT EltVT = VT.getVectorElementType();
19869   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
19870 
19871   // TODO: Remove/replace the extract cost check? If the elements are available
19872   //       as scalars, then there may be no extract cost. Should we ask if
19873   //       inserting a scalar back into a vector is cheap instead?
19874   int Index0, Index1;
19875   SDValue Src0 = DAG.getSplatSourceVector(N0, Index0);
19876   SDValue Src1 = DAG.getSplatSourceVector(N1, Index1);
19877   if (!Src0 || !Src1 || Index0 != Index1 ||
19878       Src0.getValueType().getVectorElementType() != EltVT ||
19879       Src1.getValueType().getVectorElementType() != EltVT ||
19880       !TLI.isExtractVecEltCheap(VT, Index0) ||
19881       !TLI.isOperationLegalOrCustom(Opcode, EltVT))
19882     return SDValue();
19883 
19884   SDLoc DL(N);
19885   SDValue IndexC =
19886       DAG.getConstant(Index0, DL, TLI.getVectorIdxTy(DAG.getDataLayout()));
19887   SDValue X = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, N0, IndexC);
19888   SDValue Y = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, N1, IndexC);
19889   SDValue ScalarBO = DAG.getNode(Opcode, DL, EltVT, X, Y, N->getFlags());
19890 
19891   // If all lanes but 1 are undefined, no need to splat the scalar result.
19892   // TODO: Keep track of undefs and use that info in the general case.
19893   if (N0.getOpcode() == ISD::BUILD_VECTOR && N0.getOpcode() == N1.getOpcode() &&
19894       count_if(N0->ops(), [](SDValue V) { return !V.isUndef(); }) == 1 &&
19895       count_if(N1->ops(), [](SDValue V) { return !V.isUndef(); }) == 1) {
19896     // bo (build_vec ..undef, X, undef...), (build_vec ..undef, Y, undef...) -->
19897     // build_vec ..undef, (bo X, Y), undef...
19898     SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), DAG.getUNDEF(EltVT));
19899     Ops[Index0] = ScalarBO;
19900     return DAG.getBuildVector(VT, DL, Ops);
19901   }
19902 
19903   // bo (splat X, Index), (splat Y, Index) --> splat (bo X, Y), Index
19904   SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), ScalarBO);
19905   return DAG.getBuildVector(VT, DL, Ops);
19906 }
19907 
19908 /// Visit a binary vector operation, like ADD.
19909 SDValue DAGCombiner::SimplifyVBinOp(SDNode *N) {
19910   assert(N->getValueType(0).isVector() &&
19911          "SimplifyVBinOp only works on vectors!");
19912 
19913   SDValue LHS = N->getOperand(0);
19914   SDValue RHS = N->getOperand(1);
19915   SDValue Ops[] = {LHS, RHS};
19916   EVT VT = N->getValueType(0);
19917   unsigned Opcode = N->getOpcode();
19918 
19919   // See if we can constant fold the vector operation.
19920   if (SDValue Fold = DAG.FoldConstantVectorArithmetic(
19921           Opcode, SDLoc(LHS), LHS.getValueType(), Ops, N->getFlags()))
19922     return Fold;
19923 
19924   // Move unary shuffles with identical masks after a vector binop:
19925   // VBinOp (shuffle A, Undef, Mask), (shuffle B, Undef, Mask))
19926   //   --> shuffle (VBinOp A, B), Undef, Mask
19927   // This does not require type legality checks because we are creating the
19928   // same types of operations that are in the original sequence. We do have to
19929   // restrict ops like integer div that have immediate UB (eg, div-by-zero)
19930   // though. This code is adapted from the identical transform in instcombine.
19931   if (Opcode != ISD::UDIV && Opcode != ISD::SDIV &&
19932       Opcode != ISD::UREM && Opcode != ISD::SREM &&
19933       Opcode != ISD::UDIVREM && Opcode != ISD::SDIVREM) {
19934     auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(LHS);
19935     auto *Shuf1 = dyn_cast<ShuffleVectorSDNode>(RHS);
19936     if (Shuf0 && Shuf1 && Shuf0->getMask().equals(Shuf1->getMask()) &&
19937         LHS.getOperand(1).isUndef() && RHS.getOperand(1).isUndef() &&
19938         (LHS.hasOneUse() || RHS.hasOneUse() || LHS == RHS)) {
19939       SDLoc DL(N);
19940       SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, LHS.getOperand(0),
19941                                      RHS.getOperand(0), N->getFlags());
19942       SDValue UndefV = LHS.getOperand(1);
19943       return DAG.getVectorShuffle(VT, DL, NewBinOp, UndefV, Shuf0->getMask());
19944     }
19945   }
19946 
19947   // The following pattern is likely to emerge with vector reduction ops. Moving
19948   // the binary operation ahead of insertion may allow using a narrower vector
19949   // instruction that has better performance than the wide version of the op:
19950   // VBinOp (ins undef, X, Z), (ins undef, Y, Z) --> ins VecC, (VBinOp X, Y), Z
19951   if (LHS.getOpcode() == ISD::INSERT_SUBVECTOR && LHS.getOperand(0).isUndef() &&
19952       RHS.getOpcode() == ISD::INSERT_SUBVECTOR && RHS.getOperand(0).isUndef() &&
19953       LHS.getOperand(2) == RHS.getOperand(2) &&
19954       (LHS.hasOneUse() || RHS.hasOneUse())) {
19955     SDValue X = LHS.getOperand(1);
19956     SDValue Y = RHS.getOperand(1);
19957     SDValue Z = LHS.getOperand(2);
19958     EVT NarrowVT = X.getValueType();
19959     if (NarrowVT == Y.getValueType() &&
19960         TLI.isOperationLegalOrCustomOrPromote(Opcode, NarrowVT)) {
19961       // (binop undef, undef) may not return undef, so compute that result.
19962       SDLoc DL(N);
19963       SDValue VecC =
19964           DAG.getNode(Opcode, DL, VT, DAG.getUNDEF(VT), DAG.getUNDEF(VT));
19965       SDValue NarrowBO = DAG.getNode(Opcode, DL, NarrowVT, X, Y);
19966       return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, VecC, NarrowBO, Z);
19967     }
19968   }
19969 
19970   // Make sure all but the first op are undef or constant.
19971   auto ConcatWithConstantOrUndef = [](SDValue Concat) {
19972     return Concat.getOpcode() == ISD::CONCAT_VECTORS &&
19973            std::all_of(std::next(Concat->op_begin()), Concat->op_end(),
19974                      [](const SDValue &Op) {
19975                        return Op.isUndef() ||
19976                               ISD::isBuildVectorOfConstantSDNodes(Op.getNode());
19977                      });
19978   };
19979 
19980   // The following pattern is likely to emerge with vector reduction ops. Moving
19981   // the binary operation ahead of the concat may allow using a narrower vector
19982   // instruction that has better performance than the wide version of the op:
19983   // VBinOp (concat X, undef/constant), (concat Y, undef/constant) -->
19984   //   concat (VBinOp X, Y), VecC
19985   if (ConcatWithConstantOrUndef(LHS) && ConcatWithConstantOrUndef(RHS) &&
19986       (LHS.hasOneUse() || RHS.hasOneUse())) {
19987     EVT NarrowVT = LHS.getOperand(0).getValueType();
19988     if (NarrowVT == RHS.getOperand(0).getValueType() &&
19989         TLI.isOperationLegalOrCustomOrPromote(Opcode, NarrowVT)) {
19990       SDLoc DL(N);
19991       unsigned NumOperands = LHS.getNumOperands();
19992       SmallVector<SDValue, 4> ConcatOps;
19993       for (unsigned i = 0; i != NumOperands; ++i) {
19994         // These will constant fold for operands 1 and up.
19995         ConcatOps.push_back(DAG.getNode(Opcode, DL, NarrowVT, LHS.getOperand(i),
19996                                         RHS.getOperand(i)));
19997       }
19998 
19999       return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
20000     }
20001   }
20002 
20003   if (SDValue V = scalarizeBinOpOfSplats(N, DAG))
20004     return V;
20005 
20006   return SDValue();
20007 }
20008 
20009 SDValue DAGCombiner::SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1,
20010                                     SDValue N2) {
20011   assert(N0.getOpcode() == ISD::SETCC && "First argument must be a SetCC node!");
20012 
20013   SDValue SCC = SimplifySelectCC(DL, N0.getOperand(0), N0.getOperand(1), N1, N2,
20014                                  cast<CondCodeSDNode>(N0.getOperand(2))->get());
20015 
20016   // If we got a simplified select_cc node back from SimplifySelectCC, then
20017   // break it down into a new SETCC node, and a new SELECT node, and then return
20018   // the SELECT node, since we were called with a SELECT node.
20019   if (SCC.getNode()) {
20020     // Check to see if we got a select_cc back (to turn into setcc/select).
20021     // Otherwise, just return whatever node we got back, like fabs.
20022     if (SCC.getOpcode() == ISD::SELECT_CC) {
20023       const SDNodeFlags Flags = N0.getNode()->getFlags();
20024       SDValue SETCC = DAG.getNode(ISD::SETCC, SDLoc(N0),
20025                                   N0.getValueType(),
20026                                   SCC.getOperand(0), SCC.getOperand(1),
20027                                   SCC.getOperand(4), Flags);
20028       AddToWorklist(SETCC.getNode());
20029       SDValue SelectNode = DAG.getSelect(SDLoc(SCC), SCC.getValueType(), SETCC,
20030                                          SCC.getOperand(2), SCC.getOperand(3));
20031       SelectNode->setFlags(Flags);
20032       return SelectNode;
20033     }
20034 
20035     return SCC;
20036   }
20037   return SDValue();
20038 }
20039 
20040 /// Given a SELECT or a SELECT_CC node, where LHS and RHS are the two values
20041 /// being selected between, see if we can simplify the select.  Callers of this
20042 /// should assume that TheSelect is deleted if this returns true.  As such, they
20043 /// should return the appropriate thing (e.g. the node) back to the top-level of
20044 /// the DAG combiner loop to avoid it being looked at.
20045 bool DAGCombiner::SimplifySelectOps(SDNode *TheSelect, SDValue LHS,
20046                                     SDValue RHS) {
20047   // fold (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
20048   // The select + setcc is redundant, because fsqrt returns NaN for X < 0.
20049   if (const ConstantFPSDNode *NaN = isConstOrConstSplatFP(LHS)) {
20050     if (NaN->isNaN() && RHS.getOpcode() == ISD::FSQRT) {
20051       // We have: (select (setcc ?, ?, ?), NaN, (fsqrt ?))
20052       SDValue Sqrt = RHS;
20053       ISD::CondCode CC;
20054       SDValue CmpLHS;
20055       const ConstantFPSDNode *Zero = nullptr;
20056 
20057       if (TheSelect->getOpcode() == ISD::SELECT_CC) {
20058         CC = cast<CondCodeSDNode>(TheSelect->getOperand(4))->get();
20059         CmpLHS = TheSelect->getOperand(0);
20060         Zero = isConstOrConstSplatFP(TheSelect->getOperand(1));
20061       } else {
20062         // SELECT or VSELECT
20063         SDValue Cmp = TheSelect->getOperand(0);
20064         if (Cmp.getOpcode() == ISD::SETCC) {
20065           CC = cast<CondCodeSDNode>(Cmp.getOperand(2))->get();
20066           CmpLHS = Cmp.getOperand(0);
20067           Zero = isConstOrConstSplatFP(Cmp.getOperand(1));
20068         }
20069       }
20070       if (Zero && Zero->isZero() &&
20071           Sqrt.getOperand(0) == CmpLHS && (CC == ISD::SETOLT ||
20072           CC == ISD::SETULT || CC == ISD::SETLT)) {
20073         // We have: (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
20074         CombineTo(TheSelect, Sqrt);
20075         return true;
20076       }
20077     }
20078   }
20079   // Cannot simplify select with vector condition
20080   if (TheSelect->getOperand(0).getValueType().isVector()) return false;
20081 
20082   // If this is a select from two identical things, try to pull the operation
20083   // through the select.
20084   if (LHS.getOpcode() != RHS.getOpcode() ||
20085       !LHS.hasOneUse() || !RHS.hasOneUse())
20086     return false;
20087 
20088   // If this is a load and the token chain is identical, replace the select
20089   // of two loads with a load through a select of the address to load from.
20090   // This triggers in things like "select bool X, 10.0, 123.0" after the FP
20091   // constants have been dropped into the constant pool.
20092   if (LHS.getOpcode() == ISD::LOAD) {
20093     LoadSDNode *LLD = cast<LoadSDNode>(LHS);
20094     LoadSDNode *RLD = cast<LoadSDNode>(RHS);
20095 
20096     // Token chains must be identical.
20097     if (LHS.getOperand(0) != RHS.getOperand(0) ||
20098         // Do not let this transformation reduce the number of volatile loads.
20099         // Be conservative for atomics for the moment
20100         // TODO: This does appear to be legal for unordered atomics (see D66309)
20101         !LLD->isSimple() || !RLD->isSimple() ||
20102         // FIXME: If either is a pre/post inc/dec load,
20103         // we'd need to split out the address adjustment.
20104         LLD->isIndexed() || RLD->isIndexed() ||
20105         // If this is an EXTLOAD, the VTs must match.
20106         LLD->getMemoryVT() != RLD->getMemoryVT() ||
20107         // If this is an EXTLOAD, the kind of extension must match.
20108         (LLD->getExtensionType() != RLD->getExtensionType() &&
20109          // The only exception is if one of the extensions is anyext.
20110          LLD->getExtensionType() != ISD::EXTLOAD &&
20111          RLD->getExtensionType() != ISD::EXTLOAD) ||
20112         // FIXME: this discards src value information.  This is
20113         // over-conservative. It would be beneficial to be able to remember
20114         // both potential memory locations.  Since we are discarding
20115         // src value info, don't do the transformation if the memory
20116         // locations are not in the default address space.
20117         LLD->getPointerInfo().getAddrSpace() != 0 ||
20118         RLD->getPointerInfo().getAddrSpace() != 0 ||
20119         // We can't produce a CMOV of a TargetFrameIndex since we won't
20120         // generate the address generation required.
20121         LLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
20122         RLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
20123         !TLI.isOperationLegalOrCustom(TheSelect->getOpcode(),
20124                                       LLD->getBasePtr().getValueType()))
20125       return false;
20126 
20127     // The loads must not depend on one another.
20128     if (LLD->isPredecessorOf(RLD) || RLD->isPredecessorOf(LLD))
20129       return false;
20130 
20131     // Check that the select condition doesn't reach either load.  If so,
20132     // folding this will induce a cycle into the DAG.  If not, this is safe to
20133     // transform, so create a select of the addresses.
20134 
20135     SmallPtrSet<const SDNode *, 32> Visited;
20136     SmallVector<const SDNode *, 16> Worklist;
20137 
20138     // Always fail if LLD and RLD are not independent. TheSelect is a
20139     // predecessor to all Nodes in question so we need not search past it.
20140 
20141     Visited.insert(TheSelect);
20142     Worklist.push_back(LLD);
20143     Worklist.push_back(RLD);
20144 
20145     if (SDNode::hasPredecessorHelper(LLD, Visited, Worklist) ||
20146         SDNode::hasPredecessorHelper(RLD, Visited, Worklist))
20147       return false;
20148 
20149     SDValue Addr;
20150     if (TheSelect->getOpcode() == ISD::SELECT) {
20151       // We cannot do this optimization if any pair of {RLD, LLD} is a
20152       // predecessor to {RLD, LLD, CondNode}. As we've already compared the
20153       // Loads, we only need to check if CondNode is a successor to one of the
20154       // loads. We can further avoid this if there's no use of their chain
20155       // value.
20156       SDNode *CondNode = TheSelect->getOperand(0).getNode();
20157       Worklist.push_back(CondNode);
20158 
20159       if ((LLD->hasAnyUseOfValue(1) &&
20160            SDNode::hasPredecessorHelper(LLD, Visited, Worklist)) ||
20161           (RLD->hasAnyUseOfValue(1) &&
20162            SDNode::hasPredecessorHelper(RLD, Visited, Worklist)))
20163         return false;
20164 
20165       Addr = DAG.getSelect(SDLoc(TheSelect),
20166                            LLD->getBasePtr().getValueType(),
20167                            TheSelect->getOperand(0), LLD->getBasePtr(),
20168                            RLD->getBasePtr());
20169     } else {  // Otherwise SELECT_CC
20170       // We cannot do this optimization if any pair of {RLD, LLD} is a
20171       // predecessor to {RLD, LLD, CondLHS, CondRHS}. As we've already compared
20172       // the Loads, we only need to check if CondLHS/CondRHS is a successor to
20173       // one of the loads. We can further avoid this if there's no use of their
20174       // chain value.
20175 
20176       SDNode *CondLHS = TheSelect->getOperand(0).getNode();
20177       SDNode *CondRHS = TheSelect->getOperand(1).getNode();
20178       Worklist.push_back(CondLHS);
20179       Worklist.push_back(CondRHS);
20180 
20181       if ((LLD->hasAnyUseOfValue(1) &&
20182            SDNode::hasPredecessorHelper(LLD, Visited, Worklist)) ||
20183           (RLD->hasAnyUseOfValue(1) &&
20184            SDNode::hasPredecessorHelper(RLD, Visited, Worklist)))
20185         return false;
20186 
20187       Addr = DAG.getNode(ISD::SELECT_CC, SDLoc(TheSelect),
20188                          LLD->getBasePtr().getValueType(),
20189                          TheSelect->getOperand(0),
20190                          TheSelect->getOperand(1),
20191                          LLD->getBasePtr(), RLD->getBasePtr(),
20192                          TheSelect->getOperand(4));
20193     }
20194 
20195     SDValue Load;
20196     // It is safe to replace the two loads if they have different alignments,
20197     // but the new load must use the minimum (most restrictive) alignment of the
20198     // inputs.
20199     unsigned Alignment = std::min(LLD->getAlignment(), RLD->getAlignment());
20200     MachineMemOperand::Flags MMOFlags = LLD->getMemOperand()->getFlags();
20201     if (!RLD->isInvariant())
20202       MMOFlags &= ~MachineMemOperand::MOInvariant;
20203     if (!RLD->isDereferenceable())
20204       MMOFlags &= ~MachineMemOperand::MODereferenceable;
20205     if (LLD->getExtensionType() == ISD::NON_EXTLOAD) {
20206       // FIXME: Discards pointer and AA info.
20207       Load = DAG.getLoad(TheSelect->getValueType(0), SDLoc(TheSelect),
20208                          LLD->getChain(), Addr, MachinePointerInfo(), Alignment,
20209                          MMOFlags);
20210     } else {
20211       // FIXME: Discards pointer and AA info.
20212       Load = DAG.getExtLoad(
20213           LLD->getExtensionType() == ISD::EXTLOAD ? RLD->getExtensionType()
20214                                                   : LLD->getExtensionType(),
20215           SDLoc(TheSelect), TheSelect->getValueType(0), LLD->getChain(), Addr,
20216           MachinePointerInfo(), LLD->getMemoryVT(), Alignment, MMOFlags);
20217     }
20218 
20219     // Users of the select now use the result of the load.
20220     CombineTo(TheSelect, Load);
20221 
20222     // Users of the old loads now use the new load's chain.  We know the
20223     // old-load value is dead now.
20224     CombineTo(LHS.getNode(), Load.getValue(0), Load.getValue(1));
20225     CombineTo(RHS.getNode(), Load.getValue(0), Load.getValue(1));
20226     return true;
20227   }
20228 
20229   return false;
20230 }
20231 
20232 /// Try to fold an expression of the form (N0 cond N1) ? N2 : N3 to a shift and
20233 /// bitwise 'and'.
20234 SDValue DAGCombiner::foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0,
20235                                             SDValue N1, SDValue N2, SDValue N3,
20236                                             ISD::CondCode CC) {
20237   // If this is a select where the false operand is zero and the compare is a
20238   // check of the sign bit, see if we can perform the "gzip trick":
20239   // select_cc setlt X, 0, A, 0 -> and (sra X, size(X)-1), A
20240   // select_cc setgt X, 0, A, 0 -> and (not (sra X, size(X)-1)), A
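  // As a worked example with i8 X: (sra X, 7) is all-ones exactly when X is
  // negative, so (and (sra X, 7), A) produces A for X < 0 and 0 otherwise --
  // the same result as "select_cc setlt X, 0, A, 0" with no select or branch.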
20241   EVT XType = N0.getValueType();
20242   EVT AType = N2.getValueType();
20243   if (!isNullConstant(N3) || !XType.bitsGE(AType))
20244     return SDValue();
20245 
20246   // If the comparison is testing for a positive value, we have to invert
20247   // the sign bit mask, so only do that transform if the target has a bitwise
20248   // 'and not' instruction (the invert is free).
20249   if (CC == ISD::SETGT && TLI.hasAndNot(N2)) {
20250     // (X > -1) ? A : 0
20251     // (X >  0) ? X : 0 <-- This is canonical signed max.
20252     if (!(isAllOnesConstant(N1) || (isNullConstant(N1) && N0 == N2)))
20253       return SDValue();
20254   } else if (CC == ISD::SETLT) {
20255     // (X <  0) ? A : 0
20256     // (X <  1) ? X : 0 <-- This is un-canonicalized signed min.
20257     if (!(isNullConstant(N1) || (isOneConstant(N1) && N0 == N2)))
20258       return SDValue();
20259   } else {
20260     return SDValue();
20261   }
20262 
20263   // and (sra X, size(X)-1), A -> "and (srl X, C2), A" iff A is a single-bit
20264   // constant.
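  // As a worked example, take A = 8 (only bit 3 set) with i32 X: ShCt becomes
  // 32 - 3 - 1 = 28, and (and (srl X, 28), 8) relocates the sign bit to bit 3,
  // matching (and (sra X, 31), 8) in the setlt case while using a logical
  // shift, which may be cheaper on some targets.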
20265   EVT ShiftAmtTy = getShiftAmountTy(N0.getValueType());
20266   auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
20267   if (N2C && ((N2C->getAPIntValue() & (N2C->getAPIntValue() - 1)) == 0)) {
20268     unsigned ShCt = XType.getSizeInBits() - N2C->getAPIntValue().logBase2() - 1;
20269     if (!TLI.shouldAvoidTransformToShift(XType, ShCt)) {
20270       SDValue ShiftAmt = DAG.getConstant(ShCt, DL, ShiftAmtTy);
20271       SDValue Shift = DAG.getNode(ISD::SRL, DL, XType, N0, ShiftAmt);
20272       AddToWorklist(Shift.getNode());
20273 
20274       if (XType.bitsGT(AType)) {
20275         Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift);
20276         AddToWorklist(Shift.getNode());
20277       }
20278 
20279       if (CC == ISD::SETGT)
20280         Shift = DAG.getNOT(DL, Shift, AType);
20281 
20282       return DAG.getNode(ISD::AND, DL, AType, Shift, N2);
20283     }
20284   }
20285 
20286   unsigned ShCt = XType.getSizeInBits() - 1;
20287   if (TLI.shouldAvoidTransformToShift(XType, ShCt))
20288     return SDValue();
20289 
20290   SDValue ShiftAmt = DAG.getConstant(ShCt, DL, ShiftAmtTy);
20291   SDValue Shift = DAG.getNode(ISD::SRA, DL, XType, N0, ShiftAmt);
20292   AddToWorklist(Shift.getNode());
20293 
20294   if (XType.bitsGT(AType)) {
20295     Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift);
20296     AddToWorklist(Shift.getNode());
20297   }
20298 
20299   if (CC == ISD::SETGT)
20300     Shift = DAG.getNOT(DL, Shift, AType);
20301 
20302   return DAG.getNode(ISD::AND, DL, AType, Shift, N2);
20303 }
20304 
20305 /// Turn "(a cond b) ? 1.0f : 2.0f" into "load (tmp + ((a cond b) ? 0 : 4))"
20306 /// where "tmp" is a constant pool entry containing an array with 1.0 and 2.0
20307 /// in it. This may be a win when the constant is not otherwise available
20308 /// because it replaces two constant pool loads with one.
20309 SDValue DAGCombiner::convertSelectOfFPConstantsToLoadOffset(
20310     const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
20311     ISD::CondCode CC) {
20312   if (!TLI.reduceSelectOfFPConstantLoads(N0.getValueType()))
20313     return SDValue();
20314 
20315   // If we are before legalize types, we want the other legalization to happen
20316   // first (for example, to avoid messing with soft float).
20317   auto *TV = dyn_cast<ConstantFPSDNode>(N2);
20318   auto *FV = dyn_cast<ConstantFPSDNode>(N3);
20319   EVT VT = N2.getValueType();
20320   if (!TV || !FV || !TLI.isTypeLegal(VT))
20321     return SDValue();
20322 
20323   // If a constant can be materialized without loads, this does not make sense.
20324   if (TLI.getOperationAction(ISD::ConstantFP, VT) == TargetLowering::Legal ||
20325       TLI.isFPImmLegal(TV->getValueAPF(), TV->getValueType(0), ForCodeSize) ||
20326       TLI.isFPImmLegal(FV->getValueAPF(), FV->getValueType(0), ForCodeSize))
20327     return SDValue();
20328 
20329   // If both constants have multiple uses, then we won't need to do an extra
20330   // load. The values are likely around in registers for other users.
20331   if (!TV->hasOneUse() && !FV->hasOneUse())
20332     return SDValue();
20333 
20334   Constant *Elts[] = { const_cast<ConstantFP*>(FV->getConstantFPValue()),
20335                        const_cast<ConstantFP*>(TV->getConstantFPValue()) };
20336   Type *FPTy = Elts[0]->getType();
20337   const DataLayout &TD = DAG.getDataLayout();
20338 
20339   // Create a ConstantArray of the two constants.
20340   Constant *CA = ConstantArray::get(ArrayType::get(FPTy, 2), Elts);
20341   SDValue CPIdx = DAG.getConstantPool(CA, TLI.getPointerTy(DAG.getDataLayout()),
20342                                       TD.getPrefTypeAlignment(FPTy));
20343   unsigned Alignment = cast<ConstantPoolSDNode>(CPIdx)->getAlignment();
20344 
20345   // Get offsets to the 0 and 1 elements of the array, so we can select between
20346   // them.
20347   SDValue Zero = DAG.getIntPtrConstant(0, DL);
20348   unsigned EltSize = (unsigned)TD.getTypeAllocSize(Elts[0]->getType());
20349   SDValue One = DAG.getIntPtrConstant(EltSize, SDLoc(FV));
20350   SDValue Cond =
20351       DAG.getSetCC(DL, getSetCCResultType(N0.getValueType()), N0, N1, CC);
20352   AddToWorklist(Cond.getNode());
20353   SDValue CstOffset = DAG.getSelect(DL, Zero.getValueType(), Cond, One, Zero);
20354   AddToWorklist(CstOffset.getNode());
20355   CPIdx = DAG.getNode(ISD::ADD, DL, CPIdx.getValueType(), CPIdx, CstOffset);
20356   AddToWorklist(CPIdx.getNode());
20357   return DAG.getLoad(TV->getValueType(0), DL, DAG.getEntryNode(), CPIdx,
20358                      MachinePointerInfo::getConstantPool(
20359                          DAG.getMachineFunction()), Alignment);
20360 }
20361 
20362 /// Simplify an expression of the form (N0 cond N1) ? N2 : N3
20363 /// where 'cond' is the comparison specified by CC.
20364 SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
20365                                       SDValue N2, SDValue N3, ISD::CondCode CC,
20366                                       bool NotExtCompare) {
20367   // (x ? y : y) -> y.
20368   if (N2 == N3) return N2;
20369 
20370   EVT CmpOpVT = N0.getValueType();
20371   EVT CmpResVT = getSetCCResultType(CmpOpVT);
20372   EVT VT = N2.getValueType();
20373   auto *N1C = dyn_cast<ConstantSDNode>(N1.getNode());
20374   auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
20375   auto *N3C = dyn_cast<ConstantSDNode>(N3.getNode());
20376 
20377   // Determine if the condition we're dealing with is constant.
20378   if (SDValue SCC = DAG.FoldSetCC(CmpResVT, N0, N1, CC, DL)) {
20379     AddToWorklist(SCC.getNode());
20380     if (auto *SCCC = dyn_cast<ConstantSDNode>(SCC)) {
20381       // fold select_cc true, x, y -> x
20382       // fold select_cc false, x, y -> y
20383       return !(SCCC->isNullValue()) ? N2 : N3;
20384     }
20385   }
20386 
20387   if (SDValue V =
20388           convertSelectOfFPConstantsToLoadOffset(DL, N0, N1, N2, N3, CC))
20389     return V;
20390 
20391   if (SDValue V = foldSelectCCToShiftAnd(DL, N0, N1, N2, N3, CC))
20392     return V;
20393 
20394   // fold (select_cc seteq (and x, y), 0, 0, A) -> (and (sra (shl x)) A)
20395   // where y has a single bit set.
20396   // In plain terms: we can turn the SELECT_CC into an AND
20397   // when the condition can be materialized as an all-ones register.  Any
20398   // single bit-test can be materialized as an all-ones register with
20399   // shift-left and shift-right-arith.
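  // As a worked example, take y = 0x10 (only bit 4 set) with i32 x: shl by
  // countLeadingZeros(0x10) = 27 moves bit 4 into the sign bit, sra by 31
  // smears it into all-ones or zero, and the final 'and' yields A or 0 --
  // exactly the two arms of the original select.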
20400   if (CC == ISD::SETEQ && N0->getOpcode() == ISD::AND &&
20401       N0->getValueType(0) == VT && isNullConstant(N1) && isNullConstant(N2)) {
20402     SDValue AndLHS = N0->getOperand(0);
20403     auto *ConstAndRHS = dyn_cast<ConstantSDNode>(N0->getOperand(1));
20404     if (ConstAndRHS && ConstAndRHS->getAPIntValue().countPopulation() == 1) {
20405       // Shift the tested bit over the sign bit.
20406       const APInt &AndMask = ConstAndRHS->getAPIntValue();
20407       unsigned ShCt = AndMask.getBitWidth() - 1;
20408       if (!TLI.shouldAvoidTransformToShift(VT, ShCt)) {
20409         SDValue ShlAmt =
20410           DAG.getConstant(AndMask.countLeadingZeros(), SDLoc(AndLHS),
20411                           getShiftAmountTy(AndLHS.getValueType()));
20412         SDValue Shl = DAG.getNode(ISD::SHL, SDLoc(N0), VT, AndLHS, ShlAmt);
20413 
20414         // Now arithmetic right shift it all the way over, so the result is
20415         // either all-ones, or zero.
20416         SDValue ShrAmt =
20417           DAG.getConstant(ShCt, SDLoc(Shl),
20418                           getShiftAmountTy(Shl.getValueType()));
20419         SDValue Shr = DAG.getNode(ISD::SRA, SDLoc(N0), VT, Shl, ShrAmt);
20420 
20421         return DAG.getNode(ISD::AND, DL, VT, Shr, N3);
20422       }
20423     }
20424   }
20425 
20426   // fold select C, 16, 0 -> shl C, 4
20427   bool Fold = N2C && isNullConstant(N3) && N2C->getAPIntValue().isPowerOf2();
20428   bool Swap = N3C && isNullConstant(N2) && N3C->getAPIntValue().isPowerOf2();
20429 
20430   if ((Fold || Swap) &&
20431       TLI.getBooleanContents(CmpOpVT) ==
20432           TargetLowering::ZeroOrOneBooleanContent &&
20433       (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, CmpOpVT))) {
20434 
20435     if (Swap) {
20436       CC = ISD::getSetCCInverse(CC, CmpOpVT);
20437       std::swap(N2C, N3C);
20438     }
20439 
20440     // If the caller doesn't want us to simplify this into a zext of a compare,
20441     // don't do it.
20442     if (NotExtCompare && N2C->isOne())
20443       return SDValue();
20444 
20445     SDValue Temp, SCC;
20446     // zext (setcc n0, n1)
20447     if (LegalTypes) {
20448       SCC = DAG.getSetCC(DL, CmpResVT, N0, N1, CC);
20449       if (VT.bitsLT(SCC.getValueType()))
20450         Temp = DAG.getZeroExtendInReg(SCC, SDLoc(N2), VT);
20451       else
20452         Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2), VT, SCC);
20453     } else {
20454       SCC = DAG.getSetCC(SDLoc(N0), MVT::i1, N0, N1, CC);
20455       Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2), VT, SCC);
20456     }
20457 
20458     AddToWorklist(SCC.getNode());
20459     AddToWorklist(Temp.getNode());
20460 
20461     if (N2C->isOne())
20462       return Temp;
20463 
20464     unsigned ShCt = N2C->getAPIntValue().logBase2();
20465     if (TLI.shouldAvoidTransformToShift(VT, ShCt))
20466       return SDValue();
20467 
20468     // shl setcc result by log2 n2c
20469     return DAG.getNode(ISD::SHL, DL, N2.getValueType(), Temp,
20470                        DAG.getConstant(ShCt, SDLoc(Temp),
20471                                        getShiftAmountTy(Temp.getValueType())));
20472   }
20473 
20474   // select_cc seteq X, 0, sizeof(X), ctlz(X) -> ctlz(X)
20475   // select_cc seteq X, 0, sizeof(X), ctlz_zero_undef(X) -> ctlz(X)
20476   // select_cc seteq X, 0, sizeof(X), cttz(X) -> cttz(X)
20477   // select_cc seteq X, 0, sizeof(X), cttz_zero_undef(X) -> cttz(X)
20478   // select_cc setne X, 0, ctlz(X), sizeof(X) -> ctlz(X)
20479   // select_cc setne X, 0, ctlz_zero_undef(X), sizeof(X) -> ctlz(X)
20480   // select_cc setne X, 0, cttz(X), sizeof(X) -> cttz(X)
20481   // select_cc setne X, 0, cttz_zero_undef(X), sizeof(X) -> cttz(X)
20482   if (N1C && N1C->isNullValue() && (CC == ISD::SETEQ || CC == ISD::SETNE)) {
20483     SDValue ValueOnZero = N2;
20484     SDValue Count = N3;
20485     // If the condition is NE instead of EQ, swap the operands.
20486     if (CC == ISD::SETNE)
20487       std::swap(ValueOnZero, Count);
20488     // Check if the value on zero is a constant equal to the bits in the type.
20489     if (auto *ValueOnZeroC = dyn_cast<ConstantSDNode>(ValueOnZero)) {
20490       if (ValueOnZeroC->getAPIntValue() == VT.getSizeInBits()) {
20491         // If the other operand is cttz/cttz_zero_undef of N0, and cttz is
20492         // legal, combine to just cttz.
20493         if ((Count.getOpcode() == ISD::CTTZ ||
20494              Count.getOpcode() == ISD::CTTZ_ZERO_UNDEF) &&
20495             N0 == Count.getOperand(0) &&
20496             (!LegalOperations || TLI.isOperationLegal(ISD::CTTZ, VT)))
20497           return DAG.getNode(ISD::CTTZ, DL, VT, N0);
20498         // If the other operand is ctlz/ctlz_zero_undef of N0, and ctlz is
20499         // legal, combine to just ctlz.
20500         if ((Count.getOpcode() == ISD::CTLZ ||
20501              Count.getOpcode() == ISD::CTLZ_ZERO_UNDEF) &&
20502             N0 == Count.getOperand(0) &&
20503             (!LegalOperations || TLI.isOperationLegal(ISD::CTLZ, VT)))
20504           return DAG.getNode(ISD::CTLZ, DL, VT, N0);
20505       }
20506     }
20507   }
20508 
20509   return SDValue();
20510 }
20511 
20512 /// This is a stub for TargetLowering::SimplifySetCC.
20513 SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
20514                                    ISD::CondCode Cond, const SDLoc &DL,
20515                                    bool foldBooleans) {
20516   TargetLowering::DAGCombinerInfo
20517     DagCombineInfo(DAG, Level, false, this);
20518   return TLI.SimplifySetCC(VT, N0, N1, Cond, foldBooleans, DagCombineInfo, DL);
20519 }
20520 
20521 /// Given an ISD::SDIV node expressing a divide by constant, return
20522 /// a DAG expression that will generate the same value by multiplying
20523 /// by a magic number.
20524 /// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
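/// As a worked example (per Hacker's Delight; the exact expansion is chosen
/// by TLI.BuildSDIV), an i32 signed divide by 3 can become:
///   q = mulhs(0x55555556, n); q = q + (n >>u 31)
/// i.e. a high multiply plus a sign fix, with no division instruction.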
20525 SDValue DAGCombiner::BuildSDIV(SDNode *N) {
20526   // When optimizing for minimum size, we don't want to expand a div to a mul
20527   // and a shift.
20528   if (DAG.getMachineFunction().getFunction().hasMinSize())
20529     return SDValue();
20530 
20531   SmallVector<SDNode *, 8> Built;
20532   if (SDValue S = TLI.BuildSDIV(N, DAG, LegalOperations, Built)) {
20533     for (SDNode *N : Built)
20534       AddToWorklist(N);
20535     return S;
20536   }
20537 
20538   return SDValue();
20539 }
20540 
20541 /// Given an ISD::SDIV node expressing a divide by constant power of 2, return a
20542 /// DAG expression that will generate the same value by right shifting.
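/// As a worked example (one common expansion; the target hook decides), an
/// i32 signed divide by 8 becomes:
///   t = srl(sra(n, 31), 29); q = sra(add(n, t), 3)
/// where the added 7 makes the arithmetic shift round toward zero for
/// negative dividends.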
20543 SDValue DAGCombiner::BuildSDIVPow2(SDNode *N) {
20544   ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1));
20545   if (!C)
20546     return SDValue();
20547 
20548   // Avoid division by zero.
20549   if (C->isNullValue())
20550     return SDValue();
20551 
20552   SmallVector<SDNode *, 8> Built;
20553   if (SDValue S = TLI.BuildSDIVPow2(N, C->getAPIntValue(), DAG, Built)) {
20554     for (SDNode *N : Built)
20555       AddToWorklist(N);
20556     return S;
20557   }
20558 
20559   return SDValue();
20560 }
20561 
20562 /// Given an ISD::UDIV node expressing a divide by constant, return a DAG
20563 /// expression that will generate the same value by multiplying by a magic
20564 /// number.
20565 /// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
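/// As a worked example (again per Hacker's Delight), an i32 unsigned divide
/// by 3 can become:
///   q = srl(mulhu(0xAAAAAAAB, n), 1)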
20566 SDValue DAGCombiner::BuildUDIV(SDNode *N) {
20567   // When optimizing for minimum size, we don't want to expand a div to a mul
20568   // and a shift.
20569   if (DAG.getMachineFunction().getFunction().hasMinSize())
20570     return SDValue();
20571 
20572   SmallVector<SDNode *, 8> Built;
20573   if (SDValue S = TLI.BuildUDIV(N, DAG, LegalOperations, Built)) {
20574     for (SDNode *N : Built)
20575       AddToWorklist(N);
20576     return S;
20577   }
20578 
20579   return SDValue();
20580 }
20581 
20582 /// Determines the LogBase2 value for a non-null input value using the
20583 /// transform: LogBase2(V) = (EltBits - 1) - ctlz(V).
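/// For example, for an i32 element with V = 16: ctlz(16) = 27, so
/// LogBase2 = (32 - 1) - 27 = 4, matching 16 == 1 << 4.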
20584 SDValue DAGCombiner::BuildLogBase2(SDValue V, const SDLoc &DL) {
20585   EVT VT = V.getValueType();
20586   unsigned EltBits = VT.getScalarSizeInBits();
20587   SDValue Ctlz = DAG.getNode(ISD::CTLZ, DL, VT, V);
20588   SDValue Base = DAG.getConstant(EltBits - 1, DL, VT);
20589   SDValue LogBase2 = DAG.getNode(ISD::SUB, DL, VT, Base, Ctlz);
20590   return LogBase2;
20591 }
20592 
20593 /// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
20594 /// For the reciprocal, we need to find the zero of the function:
20595 ///   F(X) = 1/X - A [which has a zero at X = 1/A]
20596 ///     =>
20597 ///   X_{i+1} = X_i (2 - A X_i) = X_i + X_i (1 - A X_i) [this second form
20598 ///     does not require additional intermediate precision]
20599 /// For the last iteration, put numerator N into it to gain more precision:
20600 ///   Result = N X_i + X_i (N - N A X_i)
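/// As a numeric illustration of the convergence, take A = 4 with initial
/// estimate X_0 = 0.2:
///   X_1 = 0.2  * (2 - 4 * 0.2)  = 0.24
///   X_2 = 0.24 * (2 - 4 * 0.24) = 0.2496
/// approaching 1/A = 0.25 quadratically (the relative error squares per step).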
20601 SDValue DAGCombiner::BuildDivEstimate(SDValue N, SDValue Op,
20602                                       SDNodeFlags Flags) {
20603   if (LegalDAG)
20604     return SDValue();
20605 
20606   // TODO: Handle half and/or extended types?
20607   EVT VT = Op.getValueType();
20608   if (VT.getScalarType() != MVT::f32 && VT.getScalarType() != MVT::f64)
20609     return SDValue();
20610 
20611   // If estimates are explicitly disabled for this function, we're done.
20612   MachineFunction &MF = DAG.getMachineFunction();
20613   int Enabled = TLI.getRecipEstimateDivEnabled(VT, MF);
20614   if (Enabled == TLI.ReciprocalEstimate::Disabled)
20615     return SDValue();
20616 
20617   // Estimates may be explicitly enabled for this type with a custom number of
20618   // refinement steps.
20619   int Iterations = TLI.getDivRefinementSteps(VT, MF);
20620   if (SDValue Est = TLI.getRecipEstimate(Op, DAG, Enabled, Iterations)) {
20621     AddToWorklist(Est.getNode());
20622 
20623     SDLoc DL(Op);
20624     if (Iterations) {
20625       SDValue FPOne = DAG.getConstantFP(1.0, DL, VT);
20626 
20627       // Newton iterations: Est = Est + Est (N - Arg * Est)
20628       // If this is the last iteration, also multiply by the numerator.
20629       for (int i = 0; i < Iterations; ++i) {
20630         SDValue MulEst = Est;
20631 
20632         if (i == Iterations - 1) {
20633           MulEst = DAG.getNode(ISD::FMUL, DL, VT, N, Est, Flags);
20634           AddToWorklist(MulEst.getNode());
20635         }
20636 
20637         SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Op, MulEst, Flags);
20638         AddToWorklist(NewEst.getNode());
20639 
20640         NewEst = DAG.getNode(ISD::FSUB, DL, VT,
20641                              (i == Iterations - 1 ? N : FPOne), NewEst, Flags);
20642         AddToWorklist(NewEst.getNode());
20643 
20644         NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags);
20645         AddToWorklist(NewEst.getNode());
20646 
20647         Est = DAG.getNode(ISD::FADD, DL, VT, MulEst, NewEst, Flags);
20648         AddToWorklist(Est.getNode());
20649       }
20650     } else {
20651       // If no iterations are available, multiply with N.
20652       Est = DAG.getNode(ISD::FMUL, DL, VT, Est, N, Flags);
20653       AddToWorklist(Est.getNode());
20654     }
20655 
20656     return Est;
20657   }
20658 
20659   return SDValue();
20660 }
20661 
20662 /// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
20663 /// For the reciprocal sqrt, we need to find the zero of the function:
20664 ///   F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
20665 ///     =>
20666 ///   X_{i+1} = X_i (1.5 - A X_i^2 / 2)
20667 /// As a result, we precompute A/2 prior to the iteration loop.
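/// Derivation sketch: with F'(X) = -2/X^3, the Newton step expands to
///   X_{i+1} = X_i - (1/X_i^2 - A) / (-2/X_i^3) = X_i (1.5 - (A/2) X_i^2)
/// so only the constant 1.5 and the precomputed A/2 appear in the loop.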
20668 SDValue DAGCombiner::buildSqrtNROneConst(SDValue Arg, SDValue Est,
20669                                          unsigned Iterations,
20670                                          SDNodeFlags Flags, bool Reciprocal) {
20671   EVT VT = Arg.getValueType();
20672   SDLoc DL(Arg);
20673   SDValue ThreeHalves = DAG.getConstantFP(1.5, DL, VT);
20674 
20675   // We now need 0.5 * Arg which we can write as (1.5 * Arg - Arg) so that
20676   // this entire sequence requires only one FP constant.
20677   SDValue HalfArg = DAG.getNode(ISD::FMUL, DL, VT, ThreeHalves, Arg, Flags);
20678   HalfArg = DAG.getNode(ISD::FSUB, DL, VT, HalfArg, Arg, Flags);
20679 
20680   // Newton iterations: Est = Est * (1.5 - HalfArg * Est * Est)
20681   for (unsigned i = 0; i < Iterations; ++i) {
20682     SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, Est, Flags);
20683     NewEst = DAG.getNode(ISD::FMUL, DL, VT, HalfArg, NewEst, Flags);
20684     NewEst = DAG.getNode(ISD::FSUB, DL, VT, ThreeHalves, NewEst, Flags);
20685     Est = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags);
20686   }
20687 
20688   // If non-reciprocal square root is requested, multiply the result by Arg.
20689   if (!Reciprocal)
20690     Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Arg, Flags);
20691 
20692   return Est;
20693 }
20694 
20695 /// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
20696 /// For the reciprocal sqrt, we need to find the zero of the function:
20697 ///   F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
20698 ///     =>
20699 ///   X_{i+1} = (-0.5 * X_i) * (A * X_i * X_i + (-3.0))
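/// Note this is algebraically the same step as the one-constant form, since
///   (-0.5 * X_i) * (A * X_i^2 - 3.0) = X_i * (1.5 - 0.5 * A * X_i^2)
/// but refactored to use the two fixed constants -0.5 and -3.0 rather than
/// deriving A/2 from 1.5.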
20700 SDValue DAGCombiner::buildSqrtNRTwoConst(SDValue Arg, SDValue Est,
20701                                          unsigned Iterations,
20702                                          SDNodeFlags Flags, bool Reciprocal) {
20703   EVT VT = Arg.getValueType();
20704   SDLoc DL(Arg);
20705   SDValue MinusThree = DAG.getConstantFP(-3.0, DL, VT);
20706   SDValue MinusHalf = DAG.getConstantFP(-0.5, DL, VT);
20707 
20708   // This routine must enter the loop below to work correctly
20709   // when (Reciprocal == false).
20710   assert(Iterations > 0);
20711 
20712   // Newton iterations for reciprocal square root:
20713   // E = (E * -0.5) * ((A * E) * E + -3.0)
20714   for (unsigned i = 0; i < Iterations; ++i) {
20715     SDValue AE = DAG.getNode(ISD::FMUL, DL, VT, Arg, Est, Flags);
20716     SDValue AEE = DAG.getNode(ISD::FMUL, DL, VT, AE, Est, Flags);
20717     SDValue RHS = DAG.getNode(ISD::FADD, DL, VT, AEE, MinusThree, Flags);
20718 
20719     // When calculating a square root at the last iteration build:
20720     // S = ((A * E) * -0.5) * ((A * E) * E + -3.0)
20721     // (notice a common subexpression)
20722     SDValue LHS;
20723     if (Reciprocal || (i + 1) < Iterations) {
20724       // RSQRT: LHS = (E * -0.5)
20725       LHS = DAG.getNode(ISD::FMUL, DL, VT, Est, MinusHalf, Flags);
20726     } else {
20727       // SQRT: LHS = (A * E) * -0.5
20728       LHS = DAG.getNode(ISD::FMUL, DL, VT, AE, MinusHalf, Flags);
20729     }
20730 
20731     Est = DAG.getNode(ISD::FMUL, DL, VT, LHS, RHS, Flags);
20732   }
20733 
20734   return Est;
20735 }
20736 
20737 /// Build code to calculate either rsqrt(Op) or sqrt(Op). In the latter case
20738 /// Op*rsqrt(Op) is actually computed, so additional postprocessing is needed if
20739 /// Op can be zero.
20740 SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
20741                                            bool Reciprocal) {
20742   if (LegalDAG)
20743     return SDValue();
20744 
20745   // TODO: Handle half and/or extended types?
20746   EVT VT = Op.getValueType();
20747   if (VT.getScalarType() != MVT::f32 && VT.getScalarType() != MVT::f64)
20748     return SDValue();
20749 
20750   // If estimates are explicitly disabled for this function, we're done.
20751   MachineFunction &MF = DAG.getMachineFunction();
20752   int Enabled = TLI.getRecipEstimateSqrtEnabled(VT, MF);
20753   if (Enabled == TLI.ReciprocalEstimate::Disabled)
20754     return SDValue();
20755 
20756   // Estimates may be explicitly enabled for this type with a custom number of
20757   // refinement steps.
20758   int Iterations = TLI.getSqrtRefinementSteps(VT, MF);
20759 
20760   bool UseOneConstNR = false;
20761   if (SDValue Est =
20762       TLI.getSqrtEstimate(Op, DAG, Enabled, Iterations, UseOneConstNR,
20763                           Reciprocal)) {
20764     AddToWorklist(Est.getNode());
20765 
20766     if (Iterations) {
20767       Est = UseOneConstNR
20768             ? buildSqrtNROneConst(Op, Est, Iterations, Flags, Reciprocal)
20769             : buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal);
20770 
20771       if (!Reciprocal) {
20772         // The estimate is now completely wrong if the input was exactly 0.0 or
20773         // possibly a denormal. Force the answer to 0.0 for those cases.
20774         SDLoc DL(Op);
20775         EVT CCVT = getSetCCResultType(VT);
20776         ISD::NodeType SelOpcode = VT.isVector() ? ISD::VSELECT : ISD::SELECT;
20777         DenormalMode DenormMode = DAG.getDenormalMode(VT);
20778         if (DenormMode == DenormalMode::IEEE) {
20779           // fabs(X) < SmallestNormal ? 0.0 : Est
20780           const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(VT);
20781           APFloat SmallestNorm = APFloat::getSmallestNormalized(FltSem);
20782           SDValue NormC = DAG.getConstantFP(SmallestNorm, DL, VT);
20783           SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
20784           SDValue Fabs = DAG.getNode(ISD::FABS, DL, VT, Op);
20785           SDValue IsDenorm = DAG.getSetCC(DL, CCVT, Fabs, NormC, ISD::SETLT);
20786           Est = DAG.getNode(SelOpcode, DL, VT, IsDenorm, FPZero, Est);
20787         } else {
20788           // X == 0.0 ? 0.0 : Est
20789           SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
20790           SDValue IsZero = DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
20791           Est = DAG.getNode(SelOpcode, DL, VT, IsZero, FPZero, Est);
20792         }
20793       }
20794     }
20795     return Est;
20796   }
20797 
20798   return SDValue();
20799 }
20800 
20801 SDValue DAGCombiner::buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags) {
20802   return buildSqrtEstimateImpl(Op, Flags, true);
20803 }
20804 
20805 SDValue DAGCombiner::buildSqrtEstimate(SDValue Op, SDNodeFlags Flags) {
20806   return buildSqrtEstimateImpl(Op, Flags, false);
20807 }
20808 
20809 /// Return true if there is any possibility that the two addresses overlap.
20810 bool DAGCombiner::isAlias(SDNode *Op0, SDNode *Op1) const {
20811 
20812   struct MemUseCharacteristics {
20813     bool IsVolatile;
20814     bool IsAtomic;
20815     SDValue BasePtr;
20816     int64_t Offset;
20817     Optional<int64_t> NumBytes;
20818     MachineMemOperand *MMO;
20819   };
20820 
20821   auto getCharacteristics = [](SDNode *N) -> MemUseCharacteristics {
20822     if (const auto *LSN = dyn_cast<LSBaseSDNode>(N)) {
20823       int64_t Offset = 0;
20824       if (auto *C = dyn_cast<ConstantSDNode>(LSN->getOffset()))
20825         Offset = (LSN->getAddressingMode() == ISD::PRE_INC)
20826                      ? C->getSExtValue()
20827                      : (LSN->getAddressingMode() == ISD::PRE_DEC)
20828                            ? -1 * C->getSExtValue()
20829                            : 0;
20830       return {LSN->isVolatile(), LSN->isAtomic(), LSN->getBasePtr(),
20831               Offset /*base offset*/,
20832               Optional<int64_t>(LSN->getMemoryVT().getStoreSize()),
20833               LSN->getMemOperand()};
20834     }
20835     if (const auto *LN = cast<LifetimeSDNode>(N))
20836       return {false /*isVolatile*/, /*isAtomic*/ false, LN->getOperand(1),
20837               (LN->hasOffset()) ? LN->getOffset() : 0,
20838               (LN->hasOffset()) ? Optional<int64_t>(LN->getSize())
20839                                 : Optional<int64_t>(),
20840               (MachineMemOperand *)nullptr};
20841     // Default.
20842     return {false /*isvolatile*/, /*isAtomic*/ false, SDValue(),
20843             (int64_t)0 /*offset*/,
20844             Optional<int64_t>() /*size*/, (MachineMemOperand *)nullptr};
20845   };
20846 
20847   MemUseCharacteristics MUC0 = getCharacteristics(Op0),
20848                         MUC1 = getCharacteristics(Op1);
20849 
20850   // If they are to the same address, then they must be aliases.
20851   if (MUC0.BasePtr.getNode() && MUC0.BasePtr == MUC1.BasePtr &&
20852       MUC0.Offset == MUC1.Offset)
20853     return true;
20854 
20855   // If they are both volatile then they cannot be reordered.
20856   if (MUC0.IsVolatile && MUC1.IsVolatile)
20857     return true;
20858 
20859   // Be conservative about atomics for the moment
20860   // TODO: This is way overconservative for unordered atomics (see D66309)
20861   if (MUC0.IsAtomic && MUC1.IsAtomic)
20862     return true;
20863 
20864   if (MUC0.MMO && MUC1.MMO) {
20865     if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
20866         (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
20867       return false;
20868   }
20869 
20870   // Try to prove that there is aliasing, or that there is no aliasing. Either
20871   // way, we can return now. If nothing can be proved, proceed with more tests.
20872   bool IsAlias;
20873   if (BaseIndexOffset::computeAliasing(Op0, MUC0.NumBytes, Op1, MUC1.NumBytes,
20874                                        DAG, IsAlias))
20875     return IsAlias;
20876 
20877   // The following all rely on MMO0 and MMO1 being valid. Fail conservatively if
20878   // either are not known.
20879   if (!MUC0.MMO || !MUC1.MMO)
20880     return true;
20881 
20882   // If one operation reads from invariant memory, and the other may store, they
20883   // cannot alias. These should really be checking the equivalent of mayWrite,
20884   // but it only matters for memory nodes other than load/store.
20885   if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
20886       (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
20887     return false;
20888 
20889   // If we know that SrcValue1 and SrcValue2 have relatively large
20890   // alignment compared to the size and offset of the access, we may be able
20891   // to prove they do not alias. This check is conservative for now to catch
20892   // cases created by splitting vector types.
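  // As a worked example: two 4-byte accesses whose bases are both 16-byte
  // aligned, at source offsets 0 and 4, give OffAlign0 = 0 and OffAlign1 = 4;
  // since 0 + 4 <= 4, the accesses cannot overlap even though the base
  // pointers themselves are unknown.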
20893   int64_t SrcValOffset0 = MUC0.MMO->getOffset();
20894   int64_t SrcValOffset1 = MUC1.MMO->getOffset();
20895   unsigned OrigAlignment0 = MUC0.MMO->getBaseAlignment();
20896   unsigned OrigAlignment1 = MUC1.MMO->getBaseAlignment();
20897   if (OrigAlignment0 == OrigAlignment1 && SrcValOffset0 != SrcValOffset1 &&
20898       MUC0.NumBytes.hasValue() && MUC1.NumBytes.hasValue() &&
20899       *MUC0.NumBytes == *MUC1.NumBytes && OrigAlignment0 > *MUC0.NumBytes) {
20900     int64_t OffAlign0 = SrcValOffset0 % OrigAlignment0;
20901     int64_t OffAlign1 = SrcValOffset1 % OrigAlignment1;
20902 
20903     // There is no overlap between these relatively aligned accesses of
20904     // similar size. Return no alias.
20905     if ((OffAlign0 + *MUC0.NumBytes) <= OffAlign1 ||
20906         (OffAlign1 + *MUC1.NumBytes) <= OffAlign0)
20907       return false;
20908   }
20909 
20910   bool UseAA = CombinerGlobalAA.getNumOccurrences() > 0
20911                    ? CombinerGlobalAA
20912                    : DAG.getSubtarget().useAA();
20913 #ifndef NDEBUG
20914   if (CombinerAAOnlyFunc.getNumOccurrences() &&
20915       CombinerAAOnlyFunc != DAG.getMachineFunction().getName())
20916     UseAA = false;
20917 #endif
20918 
20919   if (UseAA && AA && MUC0.MMO->getValue() && MUC1.MMO->getValue()) {
20920     // Use alias analysis information.
20921     int64_t MinOffset = std::min(SrcValOffset0, SrcValOffset1);
20922     int64_t Overlap0 = *MUC0.NumBytes + SrcValOffset0 - MinOffset;
20923     int64_t Overlap1 = *MUC1.NumBytes + SrcValOffset1 - MinOffset;
20924     AliasResult AAResult = AA->alias(
20925         MemoryLocation(MUC0.MMO->getValue(), Overlap0,
20926                        UseTBAA ? MUC0.MMO->getAAInfo() : AAMDNodes()),
20927         MemoryLocation(MUC1.MMO->getValue(), Overlap1,
20928                        UseTBAA ? MUC1.MMO->getAAInfo() : AAMDNodes()));
20929     if (AAResult == NoAlias)
20930       return false;
20931   }
20932 
20933   // Otherwise we have to assume they alias.
20934   return true;
20935 }
20936 
20937 /// Walk up chain skipping non-aliasing memory nodes,
20938 /// looking for aliasing nodes and adding them to the Aliases vector.
20939 void DAGCombiner::GatherAllAliases(SDNode *N, SDValue OriginalChain,
20940                                    SmallVectorImpl<SDValue> &Aliases) {
20941   SmallVector<SDValue, 8> Chains;     // List of chains to visit.
20942   SmallPtrSet<SDNode *, 16> Visited;  // Visited node set.
20943 
20944   // Get alias information for node.
20945   // TODO: relax aliasing for unordered atomics (see D66309)
20946   const bool IsLoad = isa<LoadSDNode>(N) && cast<LoadSDNode>(N)->isSimple();
20947 
20948   // Starting off.
20949   Chains.push_back(OriginalChain);
20950   unsigned Depth = 0;
20951 
20952   // Attempt to improve chain by a single step
20953   std::function<bool(SDValue &)> ImproveChain = [&](SDValue &C) -> bool {
20954     switch (C.getOpcode()) {
20955     case ISD::EntryToken:
20956       // No need to mark EntryToken.
20957       C = SDValue();
20958       return true;
20959     case ISD::LOAD:
20960     case ISD::STORE: {
20961       // Get alias information for C.
20962       // TODO: Relax aliasing for unordered atomics (see D66309)
20963       bool IsOpLoad = isa<LoadSDNode>(C.getNode()) &&
20964                       cast<LSBaseSDNode>(C.getNode())->isSimple();
20965       if ((IsLoad && IsOpLoad) || !isAlias(N, C.getNode())) {
20966         // Look further up the chain.
20967         C = C.getOperand(0);
20968         return true;
20969       }
20970       // Alias, so stop here.
20971       return false;
20972     }
20973 
20974     case ISD::CopyFromReg:
20975       // Always forward past CopyFromReg.
20976       C = C.getOperand(0);
20977       return true;
20978 
20979     case ISD::LIFETIME_START:
20980     case ISD::LIFETIME_END: {
20981       // We can forward past any lifetime start/end that can be proven not to
20982       // alias the memory access.
20983       if (!isAlias(N, C.getNode())) {
20984         // Look further up the chain.
20985         C = C.getOperand(0);
20986         return true;
20987       }
20988       return false;
20989     }
20990     default:
20991       return false;
20992     }
20993   };
20994 
20995   // Look at each chain and determine if it is an alias.  If so, add it to the
20996   // aliases list.  If not, then continue up the chain looking for the next
20997   // candidate.
20998   while (!Chains.empty()) {
20999     SDValue Chain = Chains.pop_back_val();
21000 
21001     // Don't bother if we've seen Chain before.
21002     if (!Visited.insert(Chain.getNode()).second)
21003       continue;
21004 
21005     // For TokenFactor nodes, look at each operand and only continue up the
21006     // chain until we reach the depth limit.
21007     //
21008     // FIXME: The depth check could be made to return the last non-aliasing
21009     // chain we found before we hit a tokenfactor rather than the original
21010     // chain.
21011     if (Depth > TLI.getGatherAllAliasesMaxDepth()) {
21012       Aliases.clear();
21013       Aliases.push_back(OriginalChain);
21014       return;
21015     }
21016 
21017     if (Chain.getOpcode() == ISD::TokenFactor) {
21018       // We have to check each of the operands of the token factor for "small"
21019       // token factors, so we queue them up.  Adding the operands to the queue
21020       // (stack) in reverse order maintains the original order and increases the
21021       // likelihood that getNode will find a matching token factor (CSE).
21022       if (Chain.getNumOperands() > 16) {
21023         Aliases.push_back(Chain);
21024         continue;
21025       }
21026       for (unsigned n = Chain.getNumOperands(); n;)
21027         Chains.push_back(Chain.getOperand(--n));
21028       ++Depth;
21029       continue;
21030     }
21031     // Everything else
21032     if (ImproveChain(Chain)) {
21033       // Updated chain found; consider the new chain if one exists.
21034       if (Chain.getNode())
21035         Chains.push_back(Chain);
21036       ++Depth;
21037       continue;
21038     }
21039     // No improved chain possible; treat as an alias.
21040     Aliases.push_back(Chain);
21041   }
21042 }
21043 
21044 /// Walk up chain skipping non-aliasing memory nodes, looking for a better chain
21045 /// (aliasing node).
21046 SDValue DAGCombiner::FindBetterChain(SDNode *N, SDValue OldChain) {
21047   if (OptLevel == CodeGenOpt::None)
21048     return OldChain;
21049 
21050   // Ops for replacing token factor.
21051   SmallVector<SDValue, 8> Aliases;
21052 
21053   // Accumulate all the aliases to this node.
21054   GatherAllAliases(N, OldChain, Aliases);
21055 
21056   // If no operands then chain to entry token.
21057   if (Aliases.size() == 0)
21058     return DAG.getEntryNode();
21059 
21060   // If a single operand then chain to it.  We don't need to revisit it.
21061   if (Aliases.size() == 1)
21062     return Aliases[0];
21063 
21064   // Construct a custom tailored token factor.
21065   return DAG.getTokenFactor(SDLoc(N), Aliases);
21066 }
21067 
21068 namespace {
21069 // TODO: Replace with std::monostate when we move to C++17.
21070 struct UnitT { } Unit;
21071 bool operator==(const UnitT &, const UnitT &) { return true; }
21072 bool operator!=(const UnitT &, const UnitT &) { return false; }
21073 } // namespace
21074 
21075 // This function tries to collect a bunch of potentially interesting
21076 // nodes to improve the chains of, all at once. This might seem
21077 // redundant, as this function gets called when visiting every store
21078 // node, so why not let the work be done on each store as it's visited?
21079 //
21080 // I believe this is mainly important because MergeConsecutiveStores
21081 // is unable to deal with merging stores of different sizes, so unless
21082 // we improve the chains of all the potential candidates up-front
21083 // before running MergeConsecutiveStores, it might only see some of
21084 // the nodes that will eventually be candidates, and then not be able
21085 // to go from a partially-merged state to the desired final
21086 // fully-merged state.
21087 
21088 bool DAGCombiner::parallelizeChainedStores(StoreSDNode *St) {
21089   SmallVector<StoreSDNode *, 8> ChainedStores;
21090   StoreSDNode *STChain = St;
21091   // Intervals records which offsets from BaseIndex have been covered. In
21092   // the common case, every store writes to the immediately preceding
21093   // address and is thus merged with the previous interval at insertion time.
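  // For example, if St covers [0, 4) and its chain predecessors cover
  // [-4, 0) and then [-8, -4), each insertion coalesces with the existing
  // interval; a store covering, say, [-2, 2) would overlap and end the scan.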
21094 
21095   using IMap =
21096       llvm::IntervalMap<int64_t, UnitT, 8, IntervalMapHalfOpenInfo<int64_t>>;
21097   IMap::Allocator A;
21098   IMap Intervals(A);
21099 
21100   // This holds the base pointer, index, and the offset in bytes from the base
21101   // pointer.
21102   const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
21103 
21104   // We must have a base and an offset.
21105   if (!BasePtr.getBase().getNode())
21106     return false;
21107 
21108   // Do not handle stores to undef base pointers.
21109   if (BasePtr.getBase().isUndef())
21110     return false;
21111 
21112   // Add ST's interval.
21113   Intervals.insert(0, (St->getMemoryVT().getSizeInBits() + 7) / 8, Unit);
21114 
21115   while (StoreSDNode *Chain = dyn_cast<StoreSDNode>(STChain->getChain())) {
21116     // If the chain has more than one use, then we can't reorder the mem ops.
21117     if (!SDValue(Chain, 0)->hasOneUse())
21118       break;
21119     // TODO: Relax for unordered atomics (see D66309)
21120     if (!Chain->isSimple() || Chain->isIndexed())
21121       break;
21122 
21123     // Find the base pointer and offset for this memory node.
21124     const BaseIndexOffset Ptr = BaseIndexOffset::match(Chain, DAG);
21125     // Check that the base pointer is the same as the original one.
21126     int64_t Offset;
21127     if (!BasePtr.equalBaseIndex(Ptr, DAG, Offset))
21128       break;
21129     int64_t Length = (Chain->getMemoryVT().getSizeInBits() + 7) / 8;
21130     // Make sure we don't overlap with other intervals by checking the ones to
21131     // the left or right before inserting.
21132     auto I = Intervals.find(Offset);
21133     // If there's a next interval, we should end before it.
21134     if (I != Intervals.end() && I.start() < (Offset + Length))
21135       break;
21136     // If there's a previous interval, we should start after it.
21137     if (I != Intervals.begin() && (--I).stop() <= Offset)
21138       break;
21139     Intervals.insert(Offset, Offset + Length, Unit);
21140 
21141     ChainedStores.push_back(Chain);
21142     STChain = Chain;
21143   }
21144 
21145   // If we didn't find a chained store, exit.
21146   if (ChainedStores.size() == 0)
21147     return false;
21148 
21149   // Improve all chained stores (St and ChainedStores members) starting from
21150   // where the store chain ended and return a single TokenFactor.
21151   SDValue NewChain = STChain->getChain();
21152   SmallVector<SDValue, 8> TFOps;
21153   for (unsigned I = ChainedStores.size(); I;) {
21154     StoreSDNode *S = ChainedStores[--I];
21155     SDValue BetterChain = FindBetterChain(S, NewChain);
21156     S = cast<StoreSDNode>(DAG.UpdateNodeOperands(
21157         S, BetterChain, S->getOperand(1), S->getOperand(2), S->getOperand(3)));
21158     TFOps.push_back(SDValue(S, 0));
21159     ChainedStores[I] = S;
21160   }
21161 
21162   // Improve St's chain. Use a new node to avoid creating a loop from CombineTo.
21163   SDValue BetterChain = FindBetterChain(St, NewChain);
21164   SDValue NewST;
21165   if (St->isTruncatingStore())
21166     NewST = DAG.getTruncStore(BetterChain, SDLoc(St), St->getValue(),
21167                               St->getBasePtr(), St->getMemoryVT(),
21168                               St->getMemOperand());
21169   else
21170     NewST = DAG.getStore(BetterChain, SDLoc(St), St->getValue(),
21171                          St->getBasePtr(), St->getMemOperand());
21172 
21173   TFOps.push_back(NewST);
21174 
21175   // If we improved every element of TFOps, then we've lost the dependence on
21176   // NewChain to successors of St and we need to add it back to TFOps. Do so at
21177   // the beginning to keep relative order consistent with FindBetterChains.
21178   auto hasImprovedChain = [&](SDValue ST) -> bool {
21179     return ST->getOperand(0) != NewChain;
21180   };
21181   bool AddNewChain = llvm::all_of(TFOps, hasImprovedChain);
21182   if (AddNewChain)
21183     TFOps.insert(TFOps.begin(), NewChain);
21184 
21185   SDValue TF = DAG.getTokenFactor(SDLoc(STChain), TFOps);
21186   CombineTo(St, TF);
21187 
21188   // Add TF and its operands to the worklist.
21189   AddToWorklist(TF.getNode());
21190   for (const SDValue &Op : TF->ops())
21191     AddToWorklist(Op.getNode());
21192   AddToWorklist(STChain);
21193   return true;
21194 }
21195 
21196 bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) {
21197   if (OptLevel == CodeGenOpt::None)
21198     return false;
21199 
21200   const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
21201 
21202   // We must have a base and an offset.
21203   if (!BasePtr.getBase().getNode())
21204     return false;
21205 
21206   // Do not handle stores to undef base pointers.
21207   if (BasePtr.getBase().isUndef())
21208     return false;
21209 
21210   // Directly improve a chain of disjoint stores starting at St.
21211   if (parallelizeChainedStores(St))
21212     return true;
21213 
21214   // Improve St's chain.
21215   SDValue BetterChain = FindBetterChain(St, St->getChain());
21216   if (St->getChain() != BetterChain) {
21217     replaceStoreChain(St, BetterChain);
21218     return true;
21219   }
21220   return false;
21221 }
21222 
21223 /// This is the entry point for the file.
21224 void SelectionDAG::Combine(CombineLevel Level, AliasAnalysis *AA,
21225                            CodeGenOpt::Level OptLevel) {
21226   /// This is the main entry point to this class.
21227   DAGCombiner(*this, AA, OptLevel).Run(Level);
21228 }
21229