//===- DAGCombiner.cpp - Implement a DAG node combiner --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass combines dag nodes to form fewer, simpler DAG nodes.  It can be run
// both before and after the DAG is legalized.
//
// This pass is not a substitute for the LLVM IR instcombine pass. This pass is
// primarily intended to handle simplification opportunities that are implicit
// in the LLVM IR and exposed by the various codegen lowering phases.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/CodeGen/DAGCombine.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineMemOperand.h"
#include "llvm/CodeGen/RuntimeLibcalls.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/SelectionDAGTargetInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/MachineValueType.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <iterator>
#include <string>
#include <tuple>
#include <utility>

using namespace llvm;

#define DEBUG_TYPE "dagcombine"

STATISTIC(NodesCombined   , "Number of dag nodes combined");
STATISTIC(PreIndexedNodes , "Number of pre-indexed nodes created");
STATISTIC(PostIndexedNodes, "Number of post-indexed nodes created");
STATISTIC(OpsNarrowed     , "Number of load/op/store narrowed");
STATISTIC(LdStFP2Int      , "Number of fp load/store pairs transformed to int");
STATISTIC(SlicedLoads, "Number of loads sliced");
STATISTIC(NumFPLogicOpsConv, "Number of logic ops converted to fp ops");

static cl::opt<bool>
CombinerGlobalAA("combiner-global-alias-analysis", cl::Hidden,
                 cl::desc("Enable DAG combiner's use of IR alias analysis"));

static cl::opt<bool>
UseTBAA("combiner-use-tbaa", cl::Hidden, cl::init(true),
        cl::desc("Enable DAG combiner's use of TBAA"));

#ifndef NDEBUG
static cl::opt<std::string>
CombinerAAOnlyFunc("combiner-aa-only-func", cl::Hidden,
                   cl::desc("Only use DAG-combiner alias analysis in this"
                            " function"));
#endif

/// Hidden option to stress test load slicing, i.e., when this option
/// is enabled, load slicing bypasses most of its profitability guards.
static cl::opt<bool>
StressLoadSlicing("combiner-stress-load-slicing", cl::Hidden,
                  cl::desc("Bypass the profitability model of load slicing"),
                  cl::init(false));

static cl::opt<bool>
  MaySplitLoadIndex("combiner-split-load-index", cl::Hidden, cl::init(true),
                    cl::desc("DAG combiner may split indexing from loads"));

static cl::opt<bool>
    EnableStoreMerging("combiner-store-merging", cl::Hidden, cl::init(true),
                       cl::desc("DAG combiner enable merging multiple stores "
                                "into a wider store"));

static cl::opt<unsigned> TokenFactorInlineLimit(
    "combiner-tokenfactor-inline-limit", cl::Hidden, cl::init(2048),
    cl::desc("Limit the number of operands to inline for Token Factors"));
static cl::opt<unsigned> StoreMergeDependenceLimit(
    "combiner-store-merge-dependence-limit", cl::Hidden, cl::init(10),
    cl::desc("Limit the number of times the same StoreNode and RootNode "
             "may bail out in the store merging dependence check"));

namespace {

  class DAGCombiner {
    SelectionDAG &DAG;
    const TargetLowering &TLI;
    CombineLevel Level;
    CodeGenOpt::Level OptLevel;
    bool LegalDAG = false;
    bool LegalOperations = false;
    bool LegalTypes = false;
    bool ForCodeSize;

    /// Worklist of all of the nodes that need to be simplified.
    ///
    /// This must behave as a stack -- new nodes to process are pushed onto the
    /// back and when processing we pop off of the back.
    ///
    /// The worklist will not contain duplicates but may contain null entries
    /// due to nodes being deleted from the underlying DAG.
    SmallVector<SDNode *, 64> Worklist;

    /// Mapping from an SDNode to its position on the worklist.
    ///
    /// This is used to find and remove nodes from the worklist (by nulling
    /// them) when they are deleted from the underlying DAG. It relies on
    /// stable indices of nodes within the worklist.
    DenseMap<SDNode *, unsigned> WorklistMap;
    /// This records all nodes attempted to be added to the worklist since we
    /// last considered a new worklist entry. Because we do not add duplicate
    /// nodes to the worklist, this is different from the tail of the worklist.
    SmallSetVector<SDNode *, 32> PruningList;

    /// Set of nodes which have been combined (at least once).
    ///
    /// This is used to allow us to reliably add any operands of a DAG node
    /// which have not yet been combined to the worklist.
    SmallPtrSet<SDNode *, 32> CombinedNodes;

    /// Map from candidate StoreNode to the pair of RootNode and count.
    /// The count tracks how many times we have seen the StoreNode with the
    /// same RootNode bail out in the dependence check. If the same pair has
    /// bailed out more than a set limit, we no longer consider the StoreNode
    /// with that RootNode as a store merging candidate.
    DenseMap<SDNode *, std::pair<SDNode *, unsigned>> StoreRootCountMap;

    // AA - Used for DAG load/store alias analysis.
    AliasAnalysis *AA;

    /// When an instruction is simplified, add all users of the instruction to
    /// the worklist because they might now be simplified further.
    void AddUsersToWorklist(SDNode *N) {
      for (SDNode *Node : N->uses())
        AddToWorklist(Node);
    }

    /// Convenient shorthand to add a node and all of its users to the worklist.
    void AddToWorklistWithUsers(SDNode *N) {
      AddUsersToWorklist(N);
      AddToWorklist(N);
    }

    // Prune potentially dangling nodes. This is called after
    // any visit to a node, but should also be called during a visit after any
    // failed combine which may have created a DAG node.
    void clearAddedDanglingWorklistEntries() {
      // Check any nodes added to the worklist to see if they are prunable.
      while (!PruningList.empty()) {
        auto *N = PruningList.pop_back_val();
        if (N->use_empty())
          recursivelyDeleteUnusedNodes(N);
      }
    }

    SDNode *getNextWorklistEntry() {
      // Before we do any work, remove nodes that are not in use.
      clearAddedDanglingWorklistEntries();
      SDNode *N = nullptr;
      // The Worklist holds the SDNodes in order, but it may contain null
      // entries.
      while (!N && !Worklist.empty()) {
        N = Worklist.pop_back_val();
      }

      if (N) {
        bool GoodWorklistEntry = WorklistMap.erase(N);
        (void)GoodWorklistEntry;
        assert(GoodWorklistEntry &&
               "Found a worklist entry without a corresponding map entry!");
      }
      return N;
    }

    /// Call the node-specific routine that folds each particular type of node.
    SDValue visit(SDNode *N);

  public:
    DAGCombiner(SelectionDAG &D, AliasAnalysis *AA, CodeGenOpt::Level OL)
        : DAG(D), TLI(D.getTargetLoweringInfo()), Level(BeforeLegalizeTypes),
          OptLevel(OL), AA(AA) {
      ForCodeSize = DAG.shouldOptForSize();

      MaximumLegalStoreInBits = 0;
      // We use the minimum store size here, since that's all we can guarantee
      // for the scalable vector types.
      for (MVT VT : MVT::all_valuetypes())
        if (EVT(VT).isSimple() && VT != MVT::Other &&
            TLI.isTypeLegal(EVT(VT)) &&
            VT.getSizeInBits().getKnownMinSize() >= MaximumLegalStoreInBits)
          MaximumLegalStoreInBits = VT.getSizeInBits().getKnownMinSize();
    }

    void ConsiderForPruning(SDNode *N) {
      // Mark this for potential pruning.
      PruningList.insert(N);
    }

    /// Add to the worklist, making sure its instance is at the back (next to
    /// be processed).
    void AddToWorklist(SDNode *N) {
      assert(N->getOpcode() != ISD::DELETED_NODE &&
             "Deleted Node added to Worklist");

      // Skip handle nodes as they can't usefully be combined and confuse the
      // zero-use deletion strategy.
      if (N->getOpcode() == ISD::HANDLENODE)
        return;

      ConsiderForPruning(N);

      if (WorklistMap.insert(std::make_pair(N, Worklist.size())).second)
        Worklist.push_back(N);
    }

    /// Remove all instances of N from the worklist.
    void removeFromWorklist(SDNode *N) {
      CombinedNodes.erase(N);
      PruningList.remove(N);
      StoreRootCountMap.erase(N);

      auto It = WorklistMap.find(N);
      if (It == WorklistMap.end())
        return; // Not in the worklist.

      // Null out the entry rather than erasing it to avoid a linear operation.
      Worklist[It->second] = nullptr;
      WorklistMap.erase(It);
    }

    void deleteAndRecombine(SDNode *N);
    bool recursivelyDeleteUnusedNodes(SDNode *N);

    /// Replaces all uses of the results of one DAG node with new values.
    SDValue CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
                      bool AddTo = true);

    /// Replaces all uses of the results of one DAG node with new values.
    SDValue CombineTo(SDNode *N, SDValue Res, bool AddTo = true) {
      return CombineTo(N, &Res, 1, AddTo);
    }

    /// Replaces all uses of the results of one DAG node with new values.
    SDValue CombineTo(SDNode *N, SDValue Res0, SDValue Res1,
                      bool AddTo = true) {
      SDValue To[] = { Res0, Res1 };
      return CombineTo(N, To, 2, AddTo);
    }

    void CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO);

  private:
    unsigned MaximumLegalStoreInBits;

    /// Check the specified integer node value to see if it can be simplified or
    /// if things it uses can be simplified by bit propagation.
    /// If so, return true.
    bool SimplifyDemandedBits(SDValue Op) {
      unsigned BitWidth = Op.getScalarValueSizeInBits();
      APInt DemandedBits = APInt::getAllOnesValue(BitWidth);
      return SimplifyDemandedBits(Op, DemandedBits);
    }

    bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits) {
      EVT VT = Op.getValueType();
      unsigned NumElts = VT.isVector() ? VT.getVectorNumElements() : 1;
      APInt DemandedElts = APInt::getAllOnesValue(NumElts);
      return SimplifyDemandedBits(Op, DemandedBits, DemandedElts);
    }

    /// Check the specified vector node value to see if it can be simplified or
    /// if things it uses can be simplified as it only uses some of the
    /// elements. If so, return true.
    bool SimplifyDemandedVectorElts(SDValue Op) {
      unsigned NumElts = Op.getValueType().getVectorNumElements();
      APInt DemandedElts = APInt::getAllOnesValue(NumElts);
      return SimplifyDemandedVectorElts(Op, DemandedElts);
    }

    bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
                              const APInt &DemandedElts);
    bool SimplifyDemandedVectorElts(SDValue Op, const APInt &DemandedElts,
                                    bool AssumeSingleUse = false);

    bool CombineToPreIndexedLoadStore(SDNode *N);
    bool CombineToPostIndexedLoadStore(SDNode *N);
    SDValue SplitIndexingFromLoad(LoadSDNode *LD);
    bool SliceUpLoad(SDNode *N);

    // Scalars have size 0 to distinguish from singleton vectors.
    SDValue ForwardStoreValueToDirectLoad(LoadSDNode *LD);
    bool getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val);
    bool extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val);

    /// Replace an ISD::EXTRACT_VECTOR_ELT of a load with a narrowed
    ///   load.
    ///
    /// \param EVE ISD::EXTRACT_VECTOR_ELT to be replaced.
    /// \param InVecVT type of the input vector to EVE with bitcasts resolved.
    /// \param EltNo index of the vector element to load.
    /// \param OriginalLoad load that EVE came from to be replaced.
    /// \returns EVE on success, SDValue() on failure.
    SDValue scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
                                         SDValue EltNo,
                                         LoadSDNode *OriginalLoad);
    void ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad);
    SDValue PromoteOperand(SDValue Op, EVT PVT, bool &Replace);
    SDValue SExtPromoteOperand(SDValue Op, EVT PVT);
    SDValue ZExtPromoteOperand(SDValue Op, EVT PVT);
    SDValue PromoteIntBinOp(SDValue Op);
    SDValue PromoteIntShiftOp(SDValue Op);
    SDValue PromoteExtend(SDValue Op);
    bool PromoteLoad(SDValue Op);

    /// Call the node-specific routine that knows how to fold each
    /// particular type of node. If that doesn't do anything, try the
    /// target-specific DAG combines.
    SDValue combine(SDNode *N);

    // Visitation implementation - Implement dag node combining for different
    // node types.  The semantics are as follows:
    // Return Value:
    //   SDValue.getNode() == 0 - No change was made
    //   SDValue.getNode() == N - N was replaced, is dead and has been handled.
    //   otherwise              - N should be replaced by the returned Operand.
    //
    SDValue visitTokenFactor(SDNode *N);
    SDValue visitMERGE_VALUES(SDNode *N);
    SDValue visitADD(SDNode *N);
    SDValue visitADDLike(SDNode *N);
    SDValue visitADDLikeCommutative(SDValue N0, SDValue N1, SDNode *LocReference);
    SDValue visitSUB(SDNode *N);
    SDValue visitADDSAT(SDNode *N);
    SDValue visitSUBSAT(SDNode *N);
    SDValue visitADDC(SDNode *N);
    SDValue visitADDO(SDNode *N);
    SDValue visitUADDOLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitSUBC(SDNode *N);
    SDValue visitSUBO(SDNode *N);
    SDValue visitADDE(SDNode *N);
    SDValue visitADDCARRY(SDNode *N);
    SDValue visitADDCARRYLike(SDValue N0, SDValue N1, SDValue CarryIn, SDNode *N);
    SDValue visitSUBE(SDNode *N);
    SDValue visitSUBCARRY(SDNode *N);
    SDValue visitMUL(SDNode *N);
    SDValue visitMULFIX(SDNode *N);
    SDValue useDivRem(SDNode *N);
    SDValue visitSDIV(SDNode *N);
    SDValue visitSDIVLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitUDIV(SDNode *N);
    SDValue visitUDIVLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitREM(SDNode *N);
    SDValue visitMULHU(SDNode *N);
    SDValue visitMULHS(SDNode *N);
    SDValue visitSMUL_LOHI(SDNode *N);
    SDValue visitUMUL_LOHI(SDNode *N);
    SDValue visitMULO(SDNode *N);
    SDValue visitIMINMAX(SDNode *N);
    SDValue visitAND(SDNode *N);
    SDValue visitANDLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitOR(SDNode *N);
    SDValue visitORLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitXOR(SDNode *N);
    SDValue SimplifyVBinOp(SDNode *N);
    SDValue visitSHL(SDNode *N);
    SDValue visitSRA(SDNode *N);
    SDValue visitSRL(SDNode *N);
    SDValue visitFunnelShift(SDNode *N);
    SDValue visitRotate(SDNode *N);
    SDValue visitABS(SDNode *N);
    SDValue visitBSWAP(SDNode *N);
    SDValue visitBITREVERSE(SDNode *N);
    SDValue visitCTLZ(SDNode *N);
    SDValue visitCTLZ_ZERO_UNDEF(SDNode *N);
    SDValue visitCTTZ(SDNode *N);
    SDValue visitCTTZ_ZERO_UNDEF(SDNode *N);
    SDValue visitCTPOP(SDNode *N);
    SDValue visitSELECT(SDNode *N);
    SDValue visitVSELECT(SDNode *N);
    SDValue visitSELECT_CC(SDNode *N);
    SDValue visitSETCC(SDNode *N);
    SDValue visitSETCCCARRY(SDNode *N);
    SDValue visitSIGN_EXTEND(SDNode *N);
    SDValue visitZERO_EXTEND(SDNode *N);
    SDValue visitANY_EXTEND(SDNode *N);
    SDValue visitAssertExt(SDNode *N);
    SDValue visitSIGN_EXTEND_INREG(SDNode *N);
    SDValue visitSIGN_EXTEND_VECTOR_INREG(SDNode *N);
    SDValue visitZERO_EXTEND_VECTOR_INREG(SDNode *N);
    SDValue visitTRUNCATE(SDNode *N);
    SDValue visitBITCAST(SDNode *N);
    SDValue visitBUILD_PAIR(SDNode *N);
    SDValue visitFADD(SDNode *N);
    SDValue visitFSUB(SDNode *N);
    SDValue visitFMUL(SDNode *N);
    SDValue visitFMA(SDNode *N);
    SDValue visitFDIV(SDNode *N);
    SDValue visitFREM(SDNode *N);
    SDValue visitFSQRT(SDNode *N);
    SDValue visitFCOPYSIGN(SDNode *N);
    SDValue visitFPOW(SDNode *N);
    SDValue visitSINT_TO_FP(SDNode *N);
    SDValue visitUINT_TO_FP(SDNode *N);
    SDValue visitFP_TO_SINT(SDNode *N);
    SDValue visitFP_TO_UINT(SDNode *N);
    SDValue visitFP_ROUND(SDNode *N);
    SDValue visitFP_EXTEND(SDNode *N);
    SDValue visitFNEG(SDNode *N);
    SDValue visitFABS(SDNode *N);
    SDValue visitFCEIL(SDNode *N);
    SDValue visitFTRUNC(SDNode *N);
    SDValue visitFFLOOR(SDNode *N);
    SDValue visitFMINNUM(SDNode *N);
    SDValue visitFMAXNUM(SDNode *N);
    SDValue visitFMINIMUM(SDNode *N);
    SDValue visitFMAXIMUM(SDNode *N);
    SDValue visitBRCOND(SDNode *N);
    SDValue visitBR_CC(SDNode *N);
    SDValue visitLOAD(SDNode *N);

    SDValue replaceStoreChain(StoreSDNode *ST, SDValue BetterChain);
    SDValue replaceStoreOfFPConstant(StoreSDNode *ST);

    SDValue visitSTORE(SDNode *N);
    SDValue visitLIFETIME_END(SDNode *N);
    SDValue visitINSERT_VECTOR_ELT(SDNode *N);
    SDValue visitEXTRACT_VECTOR_ELT(SDNode *N);
    SDValue visitBUILD_VECTOR(SDNode *N);
    SDValue visitCONCAT_VECTORS(SDNode *N);
    SDValue visitEXTRACT_SUBVECTOR(SDNode *N);
    SDValue visitVECTOR_SHUFFLE(SDNode *N);
    SDValue visitSCALAR_TO_VECTOR(SDNode *N);
    SDValue visitINSERT_SUBVECTOR(SDNode *N);
    SDValue visitMLOAD(SDNode *N);
    SDValue visitMSTORE(SDNode *N);
    SDValue visitMGATHER(SDNode *N);
    SDValue visitMSCATTER(SDNode *N);
    SDValue visitFP_TO_FP16(SDNode *N);
    SDValue visitFP16_TO_FP(SDNode *N);
    SDValue visitVECREDUCE(SDNode *N);

    SDValue visitFADDForFMACombine(SDNode *N);
    SDValue visitFSUBForFMACombine(SDNode *N);
    SDValue visitFMULForFMADistributiveCombine(SDNode *N);

    SDValue XformToShuffleWithZero(SDNode *N);
    bool reassociationCanBreakAddressingModePattern(unsigned Opc,
                                                    const SDLoc &DL, SDValue N0,
                                                    SDValue N1);
    SDValue reassociateOpsCommutative(unsigned Opc, const SDLoc &DL, SDValue N0,
                                      SDValue N1);
    SDValue reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
                           SDValue N1, SDNodeFlags Flags);

    SDValue visitShiftByConstant(SDNode *N);

    SDValue foldSelectOfConstants(SDNode *N);
    SDValue foldVSelectOfConstants(SDNode *N);
    SDValue foldBinOpIntoSelect(SDNode *BO);
    bool SimplifySelectOps(SDNode *SELECT, SDValue LHS, SDValue RHS);
    SDValue hoistLogicOpWithSameOpcodeHands(SDNode *N);
    SDValue SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2);
    SDValue SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
                             SDValue N2, SDValue N3, ISD::CondCode CC,
                             bool NotExtCompare = false);
    SDValue convertSelectOfFPConstantsToLoadOffset(
        const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
        ISD::CondCode CC);
    SDValue foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0, SDValue N1,
                                   SDValue N2, SDValue N3, ISD::CondCode CC);
    SDValue foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
                              const SDLoc &DL);
    SDValue unfoldMaskedMerge(SDNode *N);
    SDValue unfoldExtremeBitClearingToShifts(SDNode *N);
    SDValue SimplifySetCC(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
                          const SDLoc &DL, bool foldBooleans);
    SDValue rebuildSetCC(SDValue N);

    bool isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
                           SDValue &CC) const;
    bool isOneUseSetCC(SDValue N) const;
    bool isCheaperToUseNegatedFPOps(SDValue X, SDValue Y);

    SDValue SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
                                         unsigned HiOp);
    SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
    SDValue CombineExtLoad(SDNode *N);
    SDValue CombineZExtLogicopShiftLoad(SDNode *N);
    SDValue combineRepeatedFPDivisors(SDNode *N);
    SDValue combineInsertEltToShuffle(SDNode *N, unsigned InsIndex);
    SDValue ConstantFoldBITCASTofBUILD_VECTOR(SDNode *, EVT);
    SDValue BuildSDIV(SDNode *N);
    SDValue BuildSDIVPow2(SDNode *N);
    SDValue BuildUDIV(SDNode *N);
    SDValue BuildLogBase2(SDValue V, const SDLoc &DL);
    SDValue BuildDivEstimate(SDValue N, SDValue Op, SDNodeFlags Flags);
    SDValue buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags);
    SDValue buildSqrtEstimate(SDValue Op, SDNodeFlags Flags);
    SDValue buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags, bool Recip);
    SDValue buildSqrtNROneConst(SDValue Arg, SDValue Est, unsigned Iterations,
                                SDNodeFlags Flags, bool Reciprocal);
    SDValue buildSqrtNRTwoConst(SDValue Arg, SDValue Est, unsigned Iterations,
                                SDNodeFlags Flags, bool Reciprocal);
    SDValue MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
                               bool DemandHighBits = true);
    SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
    SDValue MatchRotatePosNeg(SDValue Shifted, SDValue Pos, SDValue Neg,
                              SDValue InnerPos, SDValue InnerNeg,
                              unsigned PosOpcode, unsigned NegOpcode,
                              const SDLoc &DL);
    SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
    SDValue MatchLoadCombine(SDNode *N);
    SDValue MatchStoreCombine(StoreSDNode *N);
    SDValue ReduceLoadWidth(SDNode *N);
    SDValue ReduceLoadOpStoreWidth(SDNode *N);
    SDValue splitMergedValStore(StoreSDNode *ST);
    SDValue TransformFPLoadStorePair(SDNode *N);
    SDValue convertBuildVecZextToZext(SDNode *N);
    SDValue reduceBuildVecExtToExtBuildVec(SDNode *N);
    SDValue reduceBuildVecToShuffle(SDNode *N);
    SDValue createBuildVecShuffle(const SDLoc &DL, SDNode *N,
                                  ArrayRef<int> VectorMask, SDValue VecIn1,
                                  SDValue VecIn2, unsigned LeftIdx,
                                  bool DidSplitVec);
    SDValue matchVSelectOpSizesWithSetCC(SDNode *Cast);

    /// Walk up chain skipping non-aliasing memory nodes,
    /// looking for aliasing nodes and adding them to the Aliases vector.
    void GatherAllAliases(SDNode *N, SDValue OriginalChain,
                          SmallVectorImpl<SDValue> &Aliases);

    /// Return true if there is any possibility that the two addresses overlap.
    bool isAlias(SDNode *Op0, SDNode *Op1) const;

    /// Walk up chain skipping non-aliasing memory nodes, looking for a better
    /// chain (aliasing node.)
    SDValue FindBetterChain(SDNode *N, SDValue Chain);

    /// Try to replace a store and any possibly adjacent stores on
    /// consecutive chains with better chains. Return true only if St is
    /// replaced.
    ///
    /// Notice that other chains may still be replaced even if the function
    /// returns false.
    bool findBetterNeighborChains(StoreSDNode *St);

    // Helper for findBetterNeighborChains. Walk up the store chain, adding
    // additional chained stores that do not overlap and can be parallelized.
    bool parallelizeChainedStores(StoreSDNode *St);

    /// Holds a pointer to an LSBaseSDNode as well as information on where it
    /// is located in a sequence of memory operations connected by a chain.
    struct MemOpLink {
      // Ptr to the mem node.
      LSBaseSDNode *MemNode;

      // Offset from the base ptr.
      int64_t OffsetFromBase;

      MemOpLink(LSBaseSDNode *N, int64_t Offset)
          : MemNode(N), OffsetFromBase(Offset) {}
    };

    /// This is a helper function for visitMUL to check the profitability
    /// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
    /// MulNode is the original multiply, AddNode is (add x, c1),
    /// and ConstNode is c2.
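    /// For example, (mul (add x, 10), 3) becomes (add (mul x, 3), 30).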
    bool isMulAddWithConstProfitable(SDNode *MulNode,
                                     SDValue &AddNode,
                                     SDValue &ConstNode);

    /// This is a helper function for visitAND and visitZERO_EXTEND.  Returns
    /// true if the (and (load x) c) pattern matches an extload.  ExtVT returns
    /// the type of the loaded value to be extended.
    bool isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
                          EVT LoadResultTy, EVT &ExtVT);

    /// Helper function to calculate whether the given Load/Store can have its
    /// width reduced to ExtVT.
    bool isLegalNarrowLdSt(LSBaseSDNode *LDSTN, ISD::LoadExtType ExtType,
                           EVT &MemVT, unsigned ShAmt = 0);

    /// Used by BackwardsPropagateMask to find suitable loads.
    bool SearchForAndLoads(SDNode *N, SmallVectorImpl<LoadSDNode*> &Loads,
                           SmallPtrSetImpl<SDNode*> &NodesWithConsts,
                           ConstantSDNode *Mask, SDNode *&NodeToMask);
    /// Attempt to propagate a given AND node back to load leaves so that they
    /// can be combined into narrow loads.
    bool BackwardsPropagateMask(SDNode *N);

    /// Helper function for MergeConsecutiveStores which merges the
    /// component store chains.
    SDValue getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
                                unsigned NumStores);

    /// This is a helper function for MergeConsecutiveStores. When the
    /// source elements of the consecutive stores are all constants or
    /// all extracted vector elements, try to merge them into one
    /// larger store introducing bitcasts if necessary.  \return True
    /// if a merged store was created.
    bool MergeStoresOfConstantsOrVecElts(SmallVectorImpl<MemOpLink> &StoreNodes,
                                         EVT MemVT, unsigned NumStores,
                                         bool IsConstantSrc, bool UseVector,
                                         bool UseTrunc);

    /// This is a helper function for MergeConsecutiveStores. Stores
    /// that potentially may be merged with St are placed in
    /// StoreNodes. RootNode is a chain predecessor to all store
    /// candidates.
    void getStoreMergeCandidates(StoreSDNode *St,
                                 SmallVectorImpl<MemOpLink> &StoreNodes,
                                 SDNode *&Root);

    /// Helper function for MergeConsecutiveStores. Checks if
    /// candidate stores have indirect dependency through their
    /// operands. RootNode is the predecessor to all stores calculated
    /// by getStoreMergeCandidates and is used to prune the dependency check.
    /// \return True if safe to merge.
    bool checkMergeStoreCandidatesForDependencies(
        SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
        SDNode *RootNode);

    /// Merge consecutive store operations into a wide store.
    /// This optimization uses wide integers or vectors when possible.
    /// \return True if any stores were merged (the affected nodes are stored
    /// as a prefix of the locally collected StoreNodes list).
    bool MergeConsecutiveStores(StoreSDNode *St);

    /// Try to transform a truncation where C is a constant:
    ///     (trunc (and X, C)) -> (and (trunc X), (trunc C))
    ///
    /// \p N needs to be a truncation and its first operand an AND. Other
    /// requirements are checked by the function (e.g. that trunc is
    /// single-use) and if missed an empty SDValue is returned.
    SDValue distributeTruncateThroughAnd(SDNode *N);

    /// Helper function to determine whether the target supports the operation
    /// given by \p Opcode for type \p VT, that is, whether the operation
    /// is legal or custom before legalizing operations, and whether it is
    /// legal (but not custom) after legalization.
    bool hasOperation(unsigned Opcode, EVT VT) {
      if (LegalOperations)
        return TLI.isOperationLegal(Opcode, VT);
      return TLI.isOperationLegalOrCustom(Opcode, VT);
    }

  public:
    /// Runs the dag combiner on all nodes in the work list
    void Run(CombineLevel AtLevel);

    SelectionDAG &getDAG() const { return DAG; }

    /// Returns a type large enough to hold any valid shift amount - before type
    /// legalization these can be huge.
    EVT getShiftAmountTy(EVT LHSTy) {
      assert(LHSTy.isInteger() && "Shift amount is not an integer type!");
      return TLI.getShiftAmountTy(LHSTy, DAG.getDataLayout(), LegalTypes);
    }

    /// This method returns true if we are running before type legalization or
    /// if the specified VT is legal.
    bool isTypeLegal(const EVT &VT) {
      if (!LegalTypes) return true;
      return TLI.isTypeLegal(VT);
    }

    /// Convenience wrapper around TargetLowering::getSetCCResultType
    EVT getSetCCResultType(EVT VT) const {
      return TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
    }

    void ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
                         SDValue OrigLoad, SDValue ExtLoad,
                         ISD::NodeType ExtType);
  };

/// This class is a DAGUpdateListener that removes any deleted
/// nodes from the worklist.
class WorklistRemover : public SelectionDAG::DAGUpdateListener {
  DAGCombiner &DC;

public:
  explicit WorklistRemover(DAGCombiner &dc)
    : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}

  void NodeDeleted(SDNode *N, SDNode *E) override {
    DC.removeFromWorklist(N);
  }
};

class WorklistInserter : public SelectionDAG::DAGUpdateListener {
  DAGCombiner &DC;

public:
  explicit WorklistInserter(DAGCombiner &dc)
      : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}

  // FIXME: Ideally we could add N to the worklist, but this causes exponential
  //        compile time costs in large DAGs, e.g. Halide.
  void NodeInserted(SDNode *N) override { DC.ConsiderForPruning(N); }
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
//  TargetLowering::DAGCombinerInfo implementation
//===----------------------------------------------------------------------===//

void TargetLowering::DAGCombinerInfo::AddToWorklist(SDNode *N) {
  ((DAGCombiner*)DC)->AddToWorklist(N);
}

SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode *N, ArrayRef<SDValue> To, bool AddTo) {
  return ((DAGCombiner*)DC)->CombineTo(N, &To[0], To.size(), AddTo);
}

SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode *N, SDValue Res, bool AddTo) {
  return ((DAGCombiner*)DC)->CombineTo(N, Res, AddTo);
}

SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode *N, SDValue Res0, SDValue Res1, bool AddTo) {
  return ((DAGCombiner*)DC)->CombineTo(N, Res0, Res1, AddTo);
}

bool TargetLowering::DAGCombinerInfo::
recursivelyDeleteUnusedNodes(SDNode *N) {
  return ((DAGCombiner*)DC)->recursivelyDeleteUnusedNodes(N);
}

void TargetLowering::DAGCombinerInfo::
CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
  return ((DAGCombiner*)DC)->CommitTargetLoweringOpt(TLO);
}

//===----------------------------------------------------------------------===//
// Helper Functions
//===----------------------------------------------------------------------===//

void DAGCombiner::deleteAndRecombine(SDNode *N) {
  removeFromWorklist(N);

  // If the operands of this node are only used by the node, they will now be
  // dead. Make sure to re-visit them and recursively delete dead nodes.
  for (const SDValue &Op : N->ops())
    // For an operand generating multiple values, one of the values may
    // become dead allowing further simplification (e.g. split index
    // arithmetic from an indexed load).
    if (Op->hasOneUse() || Op->getNumValues() > 1)
      AddToWorklist(Op.getNode());

  DAG.DeleteNode(N);
}

// APInts must be the same size for most operations; this helper
// function zero-extends the shorter of the pair so that they match.
// We provide an Offset so that we can create bit widths that won't overflow.
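// For example, given a 7-bit LHS and a 12-bit RHS with Offset = 1, both are
// zero-extended to 13 bits (1 + max(7, 12)).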
static void zeroExtendToMatch(APInt &LHS, APInt &RHS, unsigned Offset = 0) {
  unsigned Bits = Offset + std::max(LHS.getBitWidth(), RHS.getBitWidth());
  LHS = LHS.zextOrSelf(Bits);
  RHS = RHS.zextOrSelf(Bits);
}

// Return true if this node is a setcc, or is a select_cc
// that selects between the target values used for true and false, making it
// equivalent to a setcc. Also, set the incoming LHS, RHS, and CC references to
// the appropriate nodes based on the type of node we are checking. This
// simplifies life a bit for the callers.
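// For example, on a target whose boolean true and false values are 1 and 0,
// (select_cc lhs, rhs, 1, 0, cc) is equivalent to (setcc lhs, rhs, cc).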
bool DAGCombiner::isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
                                    SDValue &CC) const {
  if (N.getOpcode() == ISD::SETCC) {
    LHS = N.getOperand(0);
    RHS = N.getOperand(1);
    CC  = N.getOperand(2);
    return true;
  }

  if (N.getOpcode() != ISD::SELECT_CC ||
      !TLI.isConstTrueVal(N.getOperand(2).getNode()) ||
      !TLI.isConstFalseVal(N.getOperand(3).getNode()))
    return false;

  if (TLI.getBooleanContents(N.getValueType()) ==
      TargetLowering::UndefinedBooleanContent)
    return false;

  LHS = N.getOperand(0);
  RHS = N.getOperand(1);
  CC  = N.getOperand(4);
  return true;
}

/// Return true if this is a SetCC-equivalent operation with only one use.
/// If this is true, it allows the users to invert the operation for free when
/// it is profitable to do so.
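/// For example, the user of a one-use integer (setcc x, y, setlt) can absorb
/// an inversion simply by rewriting it to (setcc x, y, setge), since no other
/// user depends on the original condition.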
bool DAGCombiner::isOneUseSetCC(SDValue N) const {
  SDValue N0, N1, N2;
  if (isSetCCEquivalent(N, N0, N1, N2) && N.getNode()->hasOneUse())
    return true;
  return false;
}

// Returns the SDNode if it is a constant float BuildVector
// or constant float.
static SDNode *isConstantFPBuildVectorOrConstantFP(SDValue N) {
  if (isa<ConstantFPSDNode>(N))
    return N.getNode();
  if (ISD::isBuildVectorOfConstantFPSDNodes(N.getNode()))
    return N.getNode();
  return nullptr;
}

// Determines if it is a constant integer or a build vector of constant
// integers (and undefs).
// Do not permit build vector implicit truncation.
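// For example, (build_vector 0, 1, undef, 2) of type v4i32 qualifies, but a
// BUILD_VECTOR whose constant operands would be implicitly truncated to the
// element type does not.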
static bool isConstantOrConstantVector(SDValue N, bool NoOpaques = false) {
  if (ConstantSDNode *Const = dyn_cast<ConstantSDNode>(N))
    return !(Const->isOpaque() && NoOpaques);
  if (N.getOpcode() != ISD::BUILD_VECTOR)
    return false;
  unsigned BitWidth = N.getScalarValueSizeInBits();
  for (const SDValue &Op : N->op_values()) {
    if (Op.isUndef())
      continue;
    ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Op);
    if (!Const || Const->getAPIntValue().getBitWidth() != BitWidth ||
        (Const->isOpaque() && NoOpaques))
      return false;
  }
  return true;
}

// Determines if a BUILD_VECTOR is composed of all-constants possibly mixed with
// undefs.
static bool isAnyConstantBuildVector(SDValue V, bool NoOpaques = false) {
  if (V.getOpcode() != ISD::BUILD_VECTOR)
    return false;
  return isConstantOrConstantVector(V, NoOpaques) ||
         ISD::isBuildVectorOfConstantFPSDNodes(V.getNode());
}

bool DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc,
                                                             const SDLoc &DL,
                                                             SDValue N0,
                                                             SDValue N1) {
  // Currently this only tries to ensure we don't undo the GEP splits done by
  // CodeGenPrepare when shouldConsiderGEPOffsetSplit is true. To ensure this,
  // we check if the following transformation would be problematic:
  // (load/store (add, (add, x, offset1), offset2)) ->
  // (load/store (add, x, offset1+offset2)).
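  // For example, offset2 alone may fit the target's immediate addressing
  // range while offset1+offset2 does not, in which case folding the constants
  // would force the combined offset to be materialized separately.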

  if (Opc != ISD::ADD || N0.getOpcode() != ISD::ADD)
    return false;

  if (N0.hasOneUse())
    return false;

  auto *C1 = dyn_cast<ConstantSDNode>(N0.getOperand(1));
  auto *C2 = dyn_cast<ConstantSDNode>(N1);
  if (!C1 || !C2)
    return false;

  const APInt &C1APIntVal = C1->getAPIntValue();
  const APInt &C2APIntVal = C2->getAPIntValue();
  if (C1APIntVal.getBitWidth() > 64 || C2APIntVal.getBitWidth() > 64)
    return false;

  const APInt CombinedValueIntVal = C1APIntVal + C2APIntVal;
  if (CombinedValueIntVal.getBitWidth() > 64)
    return false;
  const int64_t CombinedValue = CombinedValueIntVal.getSExtValue();

  for (SDNode *Node : N0->uses()) {
    auto LoadStore = dyn_cast<MemSDNode>(Node);
    if (LoadStore) {
      // Is x[offset2] already not a legal addressing mode? If so then
      // reassociating the constants breaks nothing (we test offset2 because
      // that's the one we hope to fold into the load or store).
      TargetLoweringBase::AddrMode AM;
      AM.HasBaseReg = true;
      AM.BaseOffs = C2APIntVal.getSExtValue();
      EVT VT = LoadStore->getMemoryVT();
      unsigned AS = LoadStore->getAddressSpace();
      Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
      if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
        continue;

      // Would x[offset1+offset2] still be a legal addressing mode?
      AM.BaseOffs = CombinedValue;
      if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
        return true;
    }
  }

  return false;
}

// Helper for DAGCombiner::reassociateOps. Try to reassociate an expression
// such as (Opc N0, N1), if \p N0 is the same kind of operation as \p Opc.
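// For example, with Opc == ISD::ADD, (add (add x, 1), 2) becomes (add x, 3).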
reassociateOpsCommutative(unsigned Opc,const SDLoc & DL,SDValue N0,SDValue N1)947 SDValue DAGCombiner::reassociateOpsCommutative(unsigned Opc, const SDLoc &DL,
948                                                SDValue N0, SDValue N1) {
949   EVT VT = N0.getValueType();
950 
951   if (N0.getOpcode() != Opc)
952     return SDValue();
953 
954   // Don't reassociate reductions.
955   if (N0->getFlags().hasVectorReduction())
956     return SDValue();
957 
958   if (SDNode *C1 = DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1))) {
959     if (SDNode *C2 = DAG.isConstantIntBuildVectorOrConstantInt(N1)) {
960       // Reassociate: (op (op x, c1), c2) -> (op x, (op c1, c2))
961       if (SDValue OpNode = DAG.FoldConstantArithmetic(Opc, DL, VT, C1, C2))
962         return DAG.getNode(Opc, DL, VT, N0.getOperand(0), OpNode);
963       return SDValue();
964     }
965     if (N0.hasOneUse()) {
966       // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1)
967       //              iff (op x, c1) has one use
968       SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N0.getOperand(0), N1);
969       if (!OpNode.getNode())
970         return SDValue();
971       return DAG.getNode(Opc, DL, VT, OpNode, N0.getOperand(1));
972     }
973   }
974   return SDValue();
975 }
976 
977 // Try to reassociate commutative binops.
reassociateOps(unsigned Opc,const SDLoc & DL,SDValue N0,SDValue N1,SDNodeFlags Flags)978 SDValue DAGCombiner::reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
979                                     SDValue N1, SDNodeFlags Flags) {
980   assert(TLI.isCommutativeBinOp(Opc) && "Operation not commutative.");
981   // Don't reassociate reductions.
982   if (Flags.hasVectorReduction())
983     return SDValue();
984 
985   // Floating-point reassociation is not allowed without loose FP math.
986   if (N0.getValueType().isFloatingPoint() ||
987       N1.getValueType().isFloatingPoint())
988     if (!Flags.hasAllowReassociation() || !Flags.hasNoSignedZeros())
989       return SDValue();
990 
991   if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N0, N1))
992     return Combined;
993   if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N1, N0))
994     return Combined;
995   return SDValue();
996 }
997 
CombineTo(SDNode * N,const SDValue * To,unsigned NumTo,bool AddTo)998 SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
999                                bool AddTo) {
1000   assert(N->getNumValues() == NumTo && "Broken CombineTo call!");
1001   ++NodesCombined;
1002   LLVM_DEBUG(dbgs() << "\nReplacing.1 "; N->dump(&DAG); dbgs() << "\nWith: ";
1003              To[0].getNode()->dump(&DAG);
1004              dbgs() << " and " << NumTo - 1 << " other values\n");
1005   for (unsigned i = 0, e = NumTo; i != e; ++i)
1006     assert((!To[i].getNode() ||
1007             N->getValueType(i) == To[i].getValueType()) &&
1008            "Cannot combine value to value of different type!");
1009 
1010   WorklistRemover DeadNodes(*this);
1011   DAG.ReplaceAllUsesWith(N, To);
1012   if (AddTo) {
1013     // Push the new nodes and any users onto the worklist
1014     for (unsigned i = 0, e = NumTo; i != e; ++i) {
1015       if (To[i].getNode()) {
1016         AddToWorklist(To[i].getNode());
1017         AddUsersToWorklist(To[i].getNode());
1018       }
1019     }
1020   }
1021 
1022   // Finally, if the node is now dead, remove it from the graph.  The node
1023   // may not be dead if the replacement process recursively simplified to
1024   // something else needing this node.
1025   if (N->use_empty())
1026     deleteAndRecombine(N);
1027   return SDValue(N, 0);
1028 }
1029 
1030 void DAGCombiner::
CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt & TLO)1031 CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
1032   // Replace all uses.  If any nodes become isomorphic to other nodes and
1033   // are deleted, make sure to remove them from our worklist.
1034   WorklistRemover DeadNodes(*this);
1035   DAG.ReplaceAllUsesOfValueWith(TLO.Old, TLO.New);
1036 
1037   // Push the new node and any (possibly new) users onto the worklist.
1038   AddToWorklistWithUsers(TLO.New.getNode());
1039 
1040   // Finally, if the node is now dead, remove it from the graph.  The node
1041   // may not be dead if the replacement process recursively simplified to
1042   // something else needing this node.
1043   if (TLO.Old.getNode()->use_empty())
1044     deleteAndRecombine(TLO.Old.getNode());
1045 }
1046 
1047 /// Check the specified integer node value to see if it can be simplified or if
1048 /// things it uses can be simplified by bit propagation. If so, return true.
SimplifyDemandedBits(SDValue Op,const APInt & DemandedBits,const APInt & DemandedElts)1049 bool DAGCombiner::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
1050                                        const APInt &DemandedElts) {
1051   TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
1052   KnownBits Known;
1053   if (!TLI.SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO))
1054     return false;
1055 
1056   // Revisit the node.
1057   AddToWorklist(Op.getNode());
1058 
1059   // Replace the old value with the new one.
1060   ++NodesCombined;
1061   LLVM_DEBUG(dbgs() << "\nReplacing.2 "; TLO.Old.getNode()->dump(&DAG);
1062              dbgs() << "\nWith: "; TLO.New.getNode()->dump(&DAG);
1063              dbgs() << '\n');
1064 
1065   CommitTargetLoweringOpt(TLO);
1066   return true;
1067 }
1068 
1069 /// Check the specified vector node value to see if it can be simplified or
1070 /// if things it uses can be simplified as it only uses some of the elements.
1071 /// If so, return true.
SimplifyDemandedVectorElts(SDValue Op,const APInt & DemandedElts,bool AssumeSingleUse)1072 bool DAGCombiner::SimplifyDemandedVectorElts(SDValue Op,
1073                                              const APInt &DemandedElts,
1074                                              bool AssumeSingleUse) {
1075   TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
1076   APInt KnownUndef, KnownZero;
1077   if (!TLI.SimplifyDemandedVectorElts(Op, DemandedElts, KnownUndef, KnownZero,
1078                                       TLO, 0, AssumeSingleUse))
1079     return false;
1080 
1081   // Revisit the node.
1082   AddToWorklist(Op.getNode());
1083 
1084   // Replace the old value with the new one.
1085   ++NodesCombined;
1086   LLVM_DEBUG(dbgs() << "\nReplacing.2 "; TLO.Old.getNode()->dump(&DAG);
1087              dbgs() << "\nWith: "; TLO.New.getNode()->dump(&DAG);
1088              dbgs() << '\n');
1089 
1090   CommitTargetLoweringOpt(TLO);
1091   return true;
1092 }
1093 
ReplaceLoadWithPromotedLoad(SDNode * Load,SDNode * ExtLoad)1094 void DAGCombiner::ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad) {
1095   SDLoc DL(Load);
1096   EVT VT = Load->getValueType(0);
1097   SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, VT, SDValue(ExtLoad, 0));
1098 
1099   LLVM_DEBUG(dbgs() << "\nReplacing.9 "; Load->dump(&DAG); dbgs() << "\nWith: ";
1100              Trunc.getNode()->dump(&DAG); dbgs() << '\n');
1101   WorklistRemover DeadNodes(*this);
1102   DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), Trunc);
1103   DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), SDValue(ExtLoad, 1));
1104   deleteAndRecombine(Load);
1105   AddToWorklist(Trunc.getNode());
1106 }
1107 
PromoteOperand(SDValue Op,EVT PVT,bool & Replace)1108 SDValue DAGCombiner::PromoteOperand(SDValue Op, EVT PVT, bool &Replace) {
1109   Replace = false;
1110   SDLoc DL(Op);
1111   if (ISD::isUNINDEXEDLoad(Op.getNode())) {
1112     LoadSDNode *LD = cast<LoadSDNode>(Op);
1113     EVT MemVT = LD->getMemoryVT();
1114     ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) ? ISD::EXTLOAD
1115                                                       : LD->getExtensionType();
1116     Replace = true;
1117     return DAG.getExtLoad(ExtType, DL, PVT,
1118                           LD->getChain(), LD->getBasePtr(),
1119                           MemVT, LD->getMemOperand());
1120   }
1121 
1122   unsigned Opc = Op.getOpcode();
1123   switch (Opc) {
1124   default: break;
1125   case ISD::AssertSext:
1126     if (SDValue Op0 = SExtPromoteOperand(Op.getOperand(0), PVT))
1127       return DAG.getNode(ISD::AssertSext, DL, PVT, Op0, Op.getOperand(1));
1128     break;
1129   case ISD::AssertZext:
1130     if (SDValue Op0 = ZExtPromoteOperand(Op.getOperand(0), PVT))
1131       return DAG.getNode(ISD::AssertZext, DL, PVT, Op0, Op.getOperand(1));
1132     break;
1133   case ISD::Constant: {
1134     unsigned ExtOpc =
1135       Op.getValueType().isByteSized() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
1136     return DAG.getNode(ExtOpc, DL, PVT, Op);
1137   }
1138   }
1139 
1140   if (!TLI.isOperationLegal(ISD::ANY_EXTEND, PVT))
1141     return SDValue();
1142   return DAG.getNode(ISD::ANY_EXTEND, DL, PVT, Op);
1143 }
1144 
SExtPromoteOperand(SDValue Op,EVT PVT)1145 SDValue DAGCombiner::SExtPromoteOperand(SDValue Op, EVT PVT) {
1146   if (!TLI.isOperationLegal(ISD::SIGN_EXTEND_INREG, PVT))
1147     return SDValue();
1148   EVT OldVT = Op.getValueType();
1149   SDLoc DL(Op);
1150   bool Replace = false;
1151   SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1152   if (!NewOp.getNode())
1153     return SDValue();
1154   AddToWorklist(NewOp.getNode());
1155 
1156   if (Replace)
1157     ReplaceLoadWithPromotedLoad(Op.getNode(), NewOp.getNode());
1158   return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, NewOp.getValueType(), NewOp,
1159                      DAG.getValueType(OldVT));
1160 }
1161 
ZExtPromoteOperand(SDValue Op,EVT PVT)1162 SDValue DAGCombiner::ZExtPromoteOperand(SDValue Op, EVT PVT) {
1163   EVT OldVT = Op.getValueType();
1164   SDLoc DL(Op);
1165   bool Replace = false;
1166   SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1167   if (!NewOp.getNode())
1168     return SDValue();
1169   AddToWorklist(NewOp.getNode());
1170 
1171   if (Replace)
1172     ReplaceLoadWithPromotedLoad(Op.getNode(), NewOp.getNode());
1173   return DAG.getZeroExtendInReg(NewOp, DL, OldVT);
1174 }
1175 
1176 /// Promote the specified integer binary operation if the target indicates it is
1177 /// beneficial. e.g. On x86, it's usually better to promote i16 operations to
1178 /// i32 since i16 instructions are longer.
PromoteIntBinOp(SDValue Op)1179 SDValue DAGCombiner::PromoteIntBinOp(SDValue Op) {
1180   if (!LegalOperations)
1181     return SDValue();
1182 
1183   EVT VT = Op.getValueType();
1184   if (VT.isVector() || !VT.isInteger())
1185     return SDValue();
1186 
1187   // If operation type is 'undesirable', e.g. i16 on x86, consider
1188   // promoting it.
1189   unsigned Opc = Op.getOpcode();
1190   if (TLI.isTypeDesirableForOp(Opc, VT))
1191     return SDValue();
1192 
1193   EVT PVT = VT;
1194   // Consult target whether it is a good idea to promote this operation and
1195   // what's the right type to promote it to.
1196   if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1197     assert(PVT != VT && "Don't know what type to promote to!");
1198 
1199     LLVM_DEBUG(dbgs() << "\nPromoting "; Op.getNode()->dump(&DAG));
1200 
1201     bool Replace0 = false;
1202     SDValue N0 = Op.getOperand(0);
1203     SDValue NN0 = PromoteOperand(N0, PVT, Replace0);
1204 
1205     bool Replace1 = false;
1206     SDValue N1 = Op.getOperand(1);
1207     SDValue NN1 = PromoteOperand(N1, PVT, Replace1);
1208     SDLoc DL(Op);
1209 
1210     SDValue RV =
1211         DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, NN0, NN1));
1212 
1213     // We are always replacing N0/N1's use in N and only need
1214     // additional replacements if there are additional uses.
1215     Replace0 &= !N0->hasOneUse();
1216     Replace1 &= (N0 != N1) && !N1->hasOneUse();
1217 
1218     // Combine Op here so it is preserved past replacements.
1219     CombineTo(Op.getNode(), RV);
1220 
1221     // If operands have a use ordering, make sure we deal with
1222     // predecessor first.
1223     if (Replace0 && Replace1 && N0.getNode()->isPredecessorOf(N1.getNode())) {
1224       std::swap(N0, N1);
1225       std::swap(NN0, NN1);
1226     }
1227 
1228     if (Replace0) {
1229       AddToWorklist(NN0.getNode());
1230       ReplaceLoadWithPromotedLoad(N0.getNode(), NN0.getNode());
1231     }
1232     if (Replace1) {
1233       AddToWorklist(NN1.getNode());
1234       ReplaceLoadWithPromotedLoad(N1.getNode(), NN1.getNode());
1235     }
1236     return Op;
1237   }
1238   return SDValue();
1239 }
1240 
1241 /// Promote the specified integer shift operation if the target indicates it is
1242 /// beneficial; e.g., on x86 it's usually better to promote i16 operations to
1243 /// i32, since i16 instructions are longer.
1244 SDValue DAGCombiner::PromoteIntShiftOp(SDValue Op) {
1245   if (!LegalOperations)
1246     return SDValue();
1247 
1248   EVT VT = Op.getValueType();
1249   if (VT.isVector() || !VT.isInteger())
1250     return SDValue();
1251 
1252   // If operation type is 'undesirable', e.g. i16 on x86, consider
1253   // promoting it.
1254   unsigned Opc = Op.getOpcode();
1255   if (TLI.isTypeDesirableForOp(Opc, VT))
1256     return SDValue();
1257 
1258   EVT PVT = VT;
1259   // Consult the target on whether it is a good idea to promote this
1260   // operation and what the right type is to promote it to.
1261   if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1262     assert(PVT != VT && "Don't know what type to promote to!");
1263 
1264     LLVM_DEBUG(dbgs() << "\nPromoting "; Op.getNode()->dump(&DAG));
1265 
1266     bool Replace = false;
1267     SDValue N0 = Op.getOperand(0);
1268     SDValue N1 = Op.getOperand(1);
1269     if (Opc == ISD::SRA)
1270       N0 = SExtPromoteOperand(N0, PVT);
1271     else if (Opc == ISD::SRL)
1272       N0 = ZExtPromoteOperand(N0, PVT);
1273     else
1274       N0 = PromoteOperand(N0, PVT, Replace);
1275 
1276     if (!N0.getNode())
1277       return SDValue();
1278 
1279     SDLoc DL(Op);
1280     SDValue RV =
1281         DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, N0, N1));
1282 
1283     if (Replace)
1284       ReplaceLoadWithPromotedLoad(Op.getOperand(0).getNode(), N0.getNode());
1285 
1286     // Deal with Op being deleted.
1287     if (Op && Op.getOpcode() != ISD::DELETED_NODE)
1288       return RV;
1289   }
1290   return SDValue();
1291 }
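
// Sketch for a shift, under the same i16 -> i32 assumption: (srl i16 %x, %c)
// becomes
//   (truncate (srl i32 (zext-promoted %x), %c) to i16)
// The zero-extended (or, for SRA, sign-extended) operand keeps the bits
// shifted in from above the original width well defined.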
1292 
1293 SDValue DAGCombiner::PromoteExtend(SDValue Op) {
1294   if (!LegalOperations)
1295     return SDValue();
1296 
1297   EVT VT = Op.getValueType();
1298   if (VT.isVector() || !VT.isInteger())
1299     return SDValue();
1300 
1301   // If operation type is 'undesirable', e.g. i16 on x86, consider
1302   // promoting it.
1303   unsigned Opc = Op.getOpcode();
1304   if (TLI.isTypeDesirableForOp(Opc, VT))
1305     return SDValue();
1306 
1307   EVT PVT = VT;
1308   // Consult the target on whether it is a good idea to promote this
1309   // operation and what the right type is to promote it to.
1310   if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1311     assert(PVT != VT && "Don't know what type to promote to!");
1312     // fold (aext (aext x)) -> (aext x)
1313     // fold (aext (zext x)) -> (zext x)
1314     // fold (aext (sext x)) -> (sext x)
1315     LLVM_DEBUG(dbgs() << "\nPromoting "; Op.getNode()->dump(&DAG));
1316     return DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, Op.getOperand(0));
1317   }
1318   return SDValue();
1319 }
1320 
1321 bool DAGCombiner::PromoteLoad(SDValue Op) {
1322   if (!LegalOperations)
1323     return false;
1324 
1325   if (!ISD::isUNINDEXEDLoad(Op.getNode()))
1326     return false;
1327 
1328   EVT VT = Op.getValueType();
1329   if (VT.isVector() || !VT.isInteger())
1330     return false;
1331 
1332   // If operation type is 'undesirable', e.g. i16 on x86, consider
1333   // promoting it.
1334   unsigned Opc = Op.getOpcode();
1335   if (TLI.isTypeDesirableForOp(Opc, VT))
1336     return false;
1337 
1338   EVT PVT = VT;
1339   // Consult the target on whether it is a good idea to promote this
1340   // operation and what the right type is to promote it to.
1341   if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1342     assert(PVT != VT && "Don't know what type to promote to!");
1343 
1344     SDLoc DL(Op);
1345     SDNode *N = Op.getNode();
1346     LoadSDNode *LD = cast<LoadSDNode>(N);
1347     EVT MemVT = LD->getMemoryVT();
1348     ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) ? ISD::EXTLOAD
1349                                                       : LD->getExtensionType();
1350     SDValue NewLD = DAG.getExtLoad(ExtType, DL, PVT,
1351                                    LD->getChain(), LD->getBasePtr(),
1352                                    MemVT, LD->getMemOperand());
1353     SDValue Result = DAG.getNode(ISD::TRUNCATE, DL, VT, NewLD);
1354 
1355     LLVM_DEBUG(dbgs() << "\nPromoting "; N->dump(&DAG); dbgs() << "\nTo: ";
1356                Result.getNode()->dump(&DAG); dbgs() << '\n');
1357     WorklistRemover DeadNodes(*this);
1358     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result);
1359     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), NewLD.getValue(1));
1360     deleteAndRecombine(N);
1361     AddToWorklist(Result.getNode());
1362     return true;
1363   }
1364   return false;
1365 }
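
// Sketch of the load promotion, under the same i16 -> i32 assumption:
//   t0: i16,ch = load ptr
// becomes
//   t1: i32,ch = extload ptr, i16   (or sextload/zextload, per ExtType)
//   t2: i16 = truncate t1
// with value uses of t0 rewritten to t2 and chain uses to t1's chain.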
1366 
1367 /// Recursively delete a node which has no uses and any operands for
1368 /// which it is the only use.
1369 ///
1370 /// Note that this both deletes the nodes and removes them from the worklist.
1371 /// It also adds any nodes that have had a user deleted to the worklist, as they
1372 /// may now have only one use and be subject to other combines.
1373 bool DAGCombiner::recursivelyDeleteUnusedNodes(SDNode *N) {
1374   if (!N->use_empty())
1375     return false;
1376 
1377   SmallSetVector<SDNode *, 16> Nodes;
1378   Nodes.insert(N);
1379   do {
1380     N = Nodes.pop_back_val();
1381     if (!N)
1382       continue;
1383 
1384     if (N->use_empty()) {
1385       for (const SDValue &ChildN : N->op_values())
1386         Nodes.insert(ChildN.getNode());
1387 
1388       removeFromWorklist(N);
1389       DAG.DeleteNode(N);
1390     } else {
1391       AddToWorklist(N);
1392     }
1393   } while (!Nodes.empty());
1394   return true;
1395 }
1396 
1397 //===----------------------------------------------------------------------===//
1398 //  Main DAG Combiner implementation
1399 //===----------------------------------------------------------------------===//
1400 
1401 void DAGCombiner::Run(CombineLevel AtLevel) {
1402   // Set the instance variables, so that the various visit routines may use them.
1403   Level = AtLevel;
1404   LegalDAG = Level >= AfterLegalizeDAG;
1405   LegalOperations = Level >= AfterLegalizeVectorOps;
1406   LegalTypes = Level >= AfterLegalizeTypes;
1407 
1408   WorklistInserter AddNodes(*this);
1409 
1410   // Add all the dag nodes to the worklist.
1411   for (SDNode &Node : DAG.allnodes())
1412     AddToWorklist(&Node);
1413 
1414   // Create a dummy node (which is not added to allnodes), that adds a reference
1415   // to the root node, preventing it from being deleted, and tracking any
1416   // changes of the root.
1417   HandleSDNode Dummy(DAG.getRoot());
1418 
1419   // While we have a valid worklist entry node, try to combine it.
1420   while (SDNode *N = getNextWorklistEntry()) {
1421     // If N has no uses, it is dead.  Make sure to revisit all N's operands once
1422     // N is deleted from the DAG, since they too may now be dead or may have a
1423     // reduced number of uses, allowing other xforms.
1424     if (recursivelyDeleteUnusedNodes(N))
1425       continue;
1426 
1427     WorklistRemover DeadNodes(*this);
1428 
1429     // If this combine is running after legalizing the DAG, re-legalize any
1430     // nodes pulled off the worklist.
1431     if (LegalDAG) {
1432       SmallSetVector<SDNode *, 16> UpdatedNodes;
1433       bool NIsValid = DAG.LegalizeOp(N, UpdatedNodes);
1434 
1435       for (SDNode *LN : UpdatedNodes)
1436         AddToWorklistWithUsers(LN);
1437 
1438       if (!NIsValid)
1439         continue;
1440     }
1441 
1442     LLVM_DEBUG(dbgs() << "\nCombining: "; N->dump(&DAG));
1443 
1444     // Add any operands of the new node which have not yet been combined to the
1445     // worklist as well. Because the worklist uniques things already, this
1446     // won't repeatedly process the same operand.
1447     CombinedNodes.insert(N);
1448     for (const SDValue &ChildN : N->op_values())
1449       if (!CombinedNodes.count(ChildN.getNode()))
1450         AddToWorklist(ChildN.getNode());
1451 
1452     SDValue RV = combine(N);
1453 
1454     if (!RV.getNode())
1455       continue;
1456 
1457     ++NodesCombined;
1458 
1459     // If we get back the same node we passed in, rather than a new node or
1460     // zero, we know that the node must have defined multiple values and
1461     // CombineTo was used.  Since CombineTo takes care of the worklist
1462     // mechanics for us, we have no work to do in this case.
1463     if (RV.getNode() == N)
1464       continue;
1465 
1466     assert(N->getOpcode() != ISD::DELETED_NODE &&
1467            RV.getOpcode() != ISD::DELETED_NODE &&
1468            "Node was deleted but visit returned new node!");
1469 
1470     LLVM_DEBUG(dbgs() << " ... into: "; RV.getNode()->dump(&DAG));
1471 
1472     if (N->getNumValues() == RV.getNode()->getNumValues())
1473       DAG.ReplaceAllUsesWith(N, RV.getNode());
1474     else {
1475       assert(N->getValueType(0) == RV.getValueType() &&
1476              N->getNumValues() == 1 && "Type mismatch");
1477       DAG.ReplaceAllUsesWith(N, &RV);
1478     }
1479 
1480     // Push the new node and any users onto the worklist
1481     AddToWorklist(RV.getNode());
1482     AddUsersToWorklist(RV.getNode());
1483 
1484     // Finally, if the node is now dead, remove it from the graph.  The node
1485     // may not be dead if the replacement process recursively simplified to
1486     // something else needing this node. This will also take care of adding any
1487     // operands which have lost a user to the worklist.
1488     recursivelyDeleteUnusedNodes(N);
1489   }
1490 
1491   // If the root changed (e.g. it was a dead load), update the root.
1492   DAG.setRoot(Dummy.getValue());
1493   DAG.RemoveDeadNodes();
1494 }
1495 
1496 SDValue DAGCombiner::visit(SDNode *N) {
1497   switch (N->getOpcode()) {
1498   default: break;
1499   case ISD::TokenFactor:        return visitTokenFactor(N);
1500   case ISD::MERGE_VALUES:       return visitMERGE_VALUES(N);
1501   case ISD::ADD:                return visitADD(N);
1502   case ISD::SUB:                return visitSUB(N);
1503   case ISD::SADDSAT:
1504   case ISD::UADDSAT:            return visitADDSAT(N);
1505   case ISD::SSUBSAT:
1506   case ISD::USUBSAT:            return visitSUBSAT(N);
1507   case ISD::ADDC:               return visitADDC(N);
1508   case ISD::SADDO:
1509   case ISD::UADDO:              return visitADDO(N);
1510   case ISD::SUBC:               return visitSUBC(N);
1511   case ISD::SSUBO:
1512   case ISD::USUBO:              return visitSUBO(N);
1513   case ISD::ADDE:               return visitADDE(N);
1514   case ISD::ADDCARRY:           return visitADDCARRY(N);
1515   case ISD::SUBE:               return visitSUBE(N);
1516   case ISD::SUBCARRY:           return visitSUBCARRY(N);
1517   case ISD::SMULFIX:
1518   case ISD::SMULFIXSAT:
1519   case ISD::UMULFIX:
1520   case ISD::UMULFIXSAT:         return visitMULFIX(N);
1521   case ISD::MUL:                return visitMUL(N);
1522   case ISD::SDIV:               return visitSDIV(N);
1523   case ISD::UDIV:               return visitUDIV(N);
1524   case ISD::SREM:
1525   case ISD::UREM:               return visitREM(N);
1526   case ISD::MULHU:              return visitMULHU(N);
1527   case ISD::MULHS:              return visitMULHS(N);
1528   case ISD::SMUL_LOHI:          return visitSMUL_LOHI(N);
1529   case ISD::UMUL_LOHI:          return visitUMUL_LOHI(N);
1530   case ISD::SMULO:
1531   case ISD::UMULO:              return visitMULO(N);
1532   case ISD::SMIN:
1533   case ISD::SMAX:
1534   case ISD::UMIN:
1535   case ISD::UMAX:               return visitIMINMAX(N);
1536   case ISD::AND:                return visitAND(N);
1537   case ISD::OR:                 return visitOR(N);
1538   case ISD::XOR:                return visitXOR(N);
1539   case ISD::SHL:                return visitSHL(N);
1540   case ISD::SRA:                return visitSRA(N);
1541   case ISD::SRL:                return visitSRL(N);
1542   case ISD::ROTR:
1543   case ISD::ROTL:               return visitRotate(N);
1544   case ISD::FSHL:
1545   case ISD::FSHR:               return visitFunnelShift(N);
1546   case ISD::ABS:                return visitABS(N);
1547   case ISD::BSWAP:              return visitBSWAP(N);
1548   case ISD::BITREVERSE:         return visitBITREVERSE(N);
1549   case ISD::CTLZ:               return visitCTLZ(N);
1550   case ISD::CTLZ_ZERO_UNDEF:    return visitCTLZ_ZERO_UNDEF(N);
1551   case ISD::CTTZ:               return visitCTTZ(N);
1552   case ISD::CTTZ_ZERO_UNDEF:    return visitCTTZ_ZERO_UNDEF(N);
1553   case ISD::CTPOP:              return visitCTPOP(N);
1554   case ISD::SELECT:             return visitSELECT(N);
1555   case ISD::VSELECT:            return visitVSELECT(N);
1556   case ISD::SELECT_CC:          return visitSELECT_CC(N);
1557   case ISD::SETCC:              return visitSETCC(N);
1558   case ISD::SETCCCARRY:         return visitSETCCCARRY(N);
1559   case ISD::SIGN_EXTEND:        return visitSIGN_EXTEND(N);
1560   case ISD::ZERO_EXTEND:        return visitZERO_EXTEND(N);
1561   case ISD::ANY_EXTEND:         return visitANY_EXTEND(N);
1562   case ISD::AssertSext:
1563   case ISD::AssertZext:         return visitAssertExt(N);
1564   case ISD::SIGN_EXTEND_INREG:  return visitSIGN_EXTEND_INREG(N);
1565   case ISD::SIGN_EXTEND_VECTOR_INREG: return visitSIGN_EXTEND_VECTOR_INREG(N);
1566   case ISD::ZERO_EXTEND_VECTOR_INREG: return visitZERO_EXTEND_VECTOR_INREG(N);
1567   case ISD::TRUNCATE:           return visitTRUNCATE(N);
1568   case ISD::BITCAST:            return visitBITCAST(N);
1569   case ISD::BUILD_PAIR:         return visitBUILD_PAIR(N);
1570   case ISD::FADD:               return visitFADD(N);
1571   case ISD::FSUB:               return visitFSUB(N);
1572   case ISD::FMUL:               return visitFMUL(N);
1573   case ISD::FMA:                return visitFMA(N);
1574   case ISD::FDIV:               return visitFDIV(N);
1575   case ISD::FREM:               return visitFREM(N);
1576   case ISD::FSQRT:              return visitFSQRT(N);
1577   case ISD::FCOPYSIGN:          return visitFCOPYSIGN(N);
1578   case ISD::FPOW:               return visitFPOW(N);
1579   case ISD::SINT_TO_FP:         return visitSINT_TO_FP(N);
1580   case ISD::UINT_TO_FP:         return visitUINT_TO_FP(N);
1581   case ISD::FP_TO_SINT:         return visitFP_TO_SINT(N);
1582   case ISD::FP_TO_UINT:         return visitFP_TO_UINT(N);
1583   case ISD::FP_ROUND:           return visitFP_ROUND(N);
1584   case ISD::FP_EXTEND:          return visitFP_EXTEND(N);
1585   case ISD::FNEG:               return visitFNEG(N);
1586   case ISD::FABS:               return visitFABS(N);
1587   case ISD::FFLOOR:             return visitFFLOOR(N);
1588   case ISD::FMINNUM:            return visitFMINNUM(N);
1589   case ISD::FMAXNUM:            return visitFMAXNUM(N);
1590   case ISD::FMINIMUM:           return visitFMINIMUM(N);
1591   case ISD::FMAXIMUM:           return visitFMAXIMUM(N);
1592   case ISD::FCEIL:              return visitFCEIL(N);
1593   case ISD::FTRUNC:             return visitFTRUNC(N);
1594   case ISD::BRCOND:             return visitBRCOND(N);
1595   case ISD::BR_CC:              return visitBR_CC(N);
1596   case ISD::LOAD:               return visitLOAD(N);
1597   case ISD::STORE:              return visitSTORE(N);
1598   case ISD::INSERT_VECTOR_ELT:  return visitINSERT_VECTOR_ELT(N);
1599   case ISD::EXTRACT_VECTOR_ELT: return visitEXTRACT_VECTOR_ELT(N);
1600   case ISD::BUILD_VECTOR:       return visitBUILD_VECTOR(N);
1601   case ISD::CONCAT_VECTORS:     return visitCONCAT_VECTORS(N);
1602   case ISD::EXTRACT_SUBVECTOR:  return visitEXTRACT_SUBVECTOR(N);
1603   case ISD::VECTOR_SHUFFLE:     return visitVECTOR_SHUFFLE(N);
1604   case ISD::SCALAR_TO_VECTOR:   return visitSCALAR_TO_VECTOR(N);
1605   case ISD::INSERT_SUBVECTOR:   return visitINSERT_SUBVECTOR(N);
1606   case ISD::MGATHER:            return visitMGATHER(N);
1607   case ISD::MLOAD:              return visitMLOAD(N);
1608   case ISD::MSCATTER:           return visitMSCATTER(N);
1609   case ISD::MSTORE:             return visitMSTORE(N);
1610   case ISD::LIFETIME_END:       return visitLIFETIME_END(N);
1611   case ISD::FP_TO_FP16:         return visitFP_TO_FP16(N);
1612   case ISD::FP16_TO_FP:         return visitFP16_TO_FP(N);
1613   case ISD::VECREDUCE_FADD:
1614   case ISD::VECREDUCE_FMUL:
1615   case ISD::VECREDUCE_ADD:
1616   case ISD::VECREDUCE_MUL:
1617   case ISD::VECREDUCE_AND:
1618   case ISD::VECREDUCE_OR:
1619   case ISD::VECREDUCE_XOR:
1620   case ISD::VECREDUCE_SMAX:
1621   case ISD::VECREDUCE_SMIN:
1622   case ISD::VECREDUCE_UMAX:
1623   case ISD::VECREDUCE_UMIN:
1624   case ISD::VECREDUCE_FMAX:
1625   case ISD::VECREDUCE_FMIN:     return visitVECREDUCE(N);
1626   }
1627   return SDValue();
1628 }
1629 
1630 SDValue DAGCombiner::combine(SDNode *N) {
1631   SDValue RV = visit(N);
1632 
1633   // If nothing happened, try a target-specific DAG combine.
1634   if (!RV.getNode()) {
1635     assert(N->getOpcode() != ISD::DELETED_NODE &&
1636            "Node was deleted but visit returned NULL!");
1637 
1638     if (N->getOpcode() >= ISD::BUILTIN_OP_END ||
1639         TLI.hasTargetDAGCombine((ISD::NodeType)N->getOpcode())) {
1640 
1641       // Expose the DAG combiner to the target combiner impls.
1642       TargetLowering::DAGCombinerInfo
1643         DagCombineInfo(DAG, Level, false, this);
1644 
1645       RV = TLI.PerformDAGCombine(N, DagCombineInfo);
1646     }
1647   }
1648 
1649   // If still nothing happened, try promoting the operation.
1650   if (!RV.getNode()) {
1651     switch (N->getOpcode()) {
1652     default: break;
1653     case ISD::ADD:
1654     case ISD::SUB:
1655     case ISD::MUL:
1656     case ISD::AND:
1657     case ISD::OR:
1658     case ISD::XOR:
1659       RV = PromoteIntBinOp(SDValue(N, 0));
1660       break;
1661     case ISD::SHL:
1662     case ISD::SRA:
1663     case ISD::SRL:
1664       RV = PromoteIntShiftOp(SDValue(N, 0));
1665       break;
1666     case ISD::SIGN_EXTEND:
1667     case ISD::ZERO_EXTEND:
1668     case ISD::ANY_EXTEND:
1669       RV = PromoteExtend(SDValue(N, 0));
1670       break;
1671     case ISD::LOAD:
1672       if (PromoteLoad(SDValue(N, 0)))
1673         RV = SDValue(N, 0);
1674       break;
1675     }
1676   }
1677 
1678   // If N is a commutative binary node, try to eliminate it if the commuted
1679   // version is already present in the DAG.
1680   if (!RV.getNode() && TLI.isCommutativeBinOp(N->getOpcode()) &&
1681       N->getNumValues() == 1) {
1682     SDValue N0 = N->getOperand(0);
1683     SDValue N1 = N->getOperand(1);
1684 
1685     // Constant operands are canonicalized to RHS.
1686     if (N0 != N1 && (isa<ConstantSDNode>(N0) || !isa<ConstantSDNode>(N1))) {
1687       SDValue Ops[] = {N1, N0};
1688       SDNode *CSENode = DAG.getNodeIfExists(N->getOpcode(), N->getVTList(), Ops,
1689                                             N->getFlags());
1690       if (CSENode)
1691         return SDValue(CSENode, 0);
1692     }
1693   }
1694 
1695   return RV;
1696 }
1697 
1698 /// Given a node, return its input chain if it has one, otherwise return a
1699 /// null SDValue.
1700 static SDValue getInputChainForNode(SDNode *N) {
1701   if (unsigned NumOps = N->getNumOperands()) {
1702     if (N->getOperand(0).getValueType() == MVT::Other)
1703       return N->getOperand(0);
1704     if (N->getOperand(NumOps-1).getValueType() == MVT::Other)
1705       return N->getOperand(NumOps-1);
1706     for (unsigned i = 1; i < NumOps-1; ++i)
1707       if (N->getOperand(i).getValueType() == MVT::Other)
1708         return N->getOperand(i);
1709   }
1710   return SDValue();
1711 }
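
// For example (a sketch): a store node (ch = store chain, val, ptr) returns
// its incoming chain operand, while a node with no MVT::Other operand yields
// a null SDValue.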
1712 
1713 SDValue DAGCombiner::visitTokenFactor(SDNode *N) {
1714   // If N has two operands, where one has an input chain equal to the other,
1715   // the 'other' chain is redundant.
1716   if (N->getNumOperands() == 2) {
1717     if (getInputChainForNode(N->getOperand(0).getNode()) == N->getOperand(1))
1718       return N->getOperand(0);
1719     if (getInputChainForNode(N->getOperand(1).getNode()) == N->getOperand(0))
1720       return N->getOperand(1);
1721   }
1722 
1723   // Don't simplify token factors if optnone.
1724   if (OptLevel == CodeGenOpt::None)
1725     return SDValue();
1726 
1727   // If the sole user is a token factor, we should make sure we have a
1728   // chance to merge them together. This prevents TF chains from inhibiting
1729   // optimizations.
1730   if (N->hasOneUse() && N->use_begin()->getOpcode() == ISD::TokenFactor)
1731     AddToWorklist(*(N->use_begin()));
1732 
1733   SmallVector<SDNode *, 8> TFs;     // List of token factors to visit.
1734   SmallVector<SDValue, 8> Ops;      // Ops for replacing token factor.
1735   SmallPtrSet<SDNode*, 16> SeenOps;
1736   bool Changed = false;             // If we should replace this token factor.
1737 
1738   // Start out with this token factor.
1739   TFs.push_back(N);
1740 
1741   // Iterate through token factors. The TFs list grows when new token factors
1742   // are encountered.
1743   for (unsigned i = 0; i < TFs.size(); ++i) {
1744     // Limit number of nodes to inline, to avoid quadratic compile times.
1745     // We have to add the outstanding Token Factors to Ops, otherwise we might
1746     // drop Ops from the resulting Token Factors.
1747     if (Ops.size() > TokenFactorInlineLimit) {
1748       for (unsigned j = i; j < TFs.size(); j++)
1749         Ops.emplace_back(TFs[j], 0);
1750       // Drop unprocessed Token Factors from TFs, so we do not add them to the
1751       // combiner worklist later.
1752       TFs.resize(i);
1753       break;
1754     }
1755 
1756     SDNode *TF = TFs[i];
1757     // Check each of the operands.
1758     for (const SDValue &Op : TF->op_values()) {
1759       switch (Op.getOpcode()) {
1760       case ISD::EntryToken:
1761         // Entry tokens don't need to be added to the list. They are
1762         // redundant.
1763         Changed = true;
1764         break;
1765 
1766       case ISD::TokenFactor:
1767         if (Op.hasOneUse() && !is_contained(TFs, Op.getNode())) {
1768           // Queue up for processing.
1769           TFs.push_back(Op.getNode());
1770           Changed = true;
1771           break;
1772         }
1773         LLVM_FALLTHROUGH;
1774 
1775       default:
1776         // Only add if it isn't already in the list.
1777         if (SeenOps.insert(Op.getNode()).second)
1778           Ops.push_back(Op);
1779         else
1780           Changed = true;
1781         break;
1782       }
1783     }
1784   }
1785 
1786   // Re-visit inlined Token Factors, to clean them up in case they have been
1787   // removed. Skip the first Token Factor, as this is the current node.
1788   for (unsigned i = 1, e = TFs.size(); i < e; i++)
1789     AddToWorklist(TFs[i]);
1790 
1791   // Remove Nodes that are chained to another node in the list. Do so
1792   // by walking up chains breadth-first, stopping when we've seen
1793   // another operand. In general we must climb to the EntryNode, but we can exit
1794   // early if we find all remaining work is associated with just one operand as
1795   // no further pruning is possible.
1796 
1797   // List of nodes to search through and original Ops from which they originate.
1798   SmallVector<std::pair<SDNode *, unsigned>, 8> Worklist;
1799   SmallVector<unsigned, 8> OpWorkCount; // Count of work for each Op.
1800   SmallPtrSet<SDNode *, 16> SeenChains;
1801   bool DidPruneOps = false;
1802 
1803   unsigned NumLeftToConsider = 0;
1804   for (const SDValue &Op : Ops) {
1805     Worklist.push_back(std::make_pair(Op.getNode(), NumLeftToConsider++));
1806     OpWorkCount.push_back(1);
1807   }
1808 
1809   auto AddToWorklist = [&](unsigned CurIdx, SDNode *Op, unsigned OpNumber) {
1810     // If this is an Op, we can remove the op from the list. Re-mark any
1811     // search associated with it as being from the current OpNumber.
1812     if (SeenOps.count(Op) != 0) {
1813       Changed = true;
1814       DidPruneOps = true;
1815       unsigned OrigOpNumber = 0;
1816       while (OrigOpNumber < Ops.size() && Ops[OrigOpNumber].getNode() != Op)
1817         OrigOpNumber++;
1818       assert((OrigOpNumber != Ops.size()) &&
1819              "expected to find TokenFactor Operand");
1820       // Re-mark worklist from OrigOpNumber to OpNumber
1821       for (unsigned i = CurIdx + 1; i < Worklist.size(); ++i) {
1822         if (Worklist[i].second == OrigOpNumber) {
1823           Worklist[i].second = OpNumber;
1824         }
1825       }
1826       OpWorkCount[OpNumber] += OpWorkCount[OrigOpNumber];
1827       OpWorkCount[OrigOpNumber] = 0;
1828       NumLeftToConsider--;
1829     }
1830     // Add if it's a new chain
1831     if (SeenChains.insert(Op).second) {
1832       OpWorkCount[OpNumber]++;
1833       Worklist.push_back(std::make_pair(Op, OpNumber));
1834     }
1835   };
1836 
1837   for (unsigned i = 0; i < Worklist.size() && i < 1024; ++i) {
1838     // We need to consider at least 2 Ops to prune.
1839     if (NumLeftToConsider <= 1)
1840       break;
1841     auto CurNode = Worklist[i].first;
1842     auto CurOpNumber = Worklist[i].second;
1843     assert((OpWorkCount[CurOpNumber] > 0) &&
1844            "Node should not appear in worklist");
1845     switch (CurNode->getOpcode()) {
1846     case ISD::EntryToken:
1847       // Hitting EntryToken is the only way for the search to terminate
1848       // without hitting another operand's search. Prevent us from
1849       // marking this operand considered.
1851       NumLeftToConsider++;
1852       break;
1853     case ISD::TokenFactor:
1854       for (const SDValue &Op : CurNode->op_values())
1855         AddToWorklist(i, Op.getNode(), CurOpNumber);
1856       break;
1857     case ISD::LIFETIME_START:
1858     case ISD::LIFETIME_END:
1859     case ISD::CopyFromReg:
1860     case ISD::CopyToReg:
1861       AddToWorklist(i, CurNode->getOperand(0).getNode(), CurOpNumber);
1862       break;
1863     default:
1864       if (auto *MemNode = dyn_cast<MemSDNode>(CurNode))
1865         AddToWorklist(i, MemNode->getChain().getNode(), CurOpNumber);
1866       break;
1867     }
1868     OpWorkCount[CurOpNumber]--;
1869     if (OpWorkCount[CurOpNumber] == 0)
1870       NumLeftToConsider--;
1871   }
1872 
1873   // If we've changed things around then replace token factor.
1874   if (Changed) {
1875     SDValue Result;
1876     if (Ops.empty()) {
1877       // The entry token is the only possible outcome.
1878       Result = DAG.getEntryNode();
1879     } else {
1880       if (DidPruneOps) {
1881         SmallVector<SDValue, 8> PrunedOps;
1883         for (const SDValue &Op : Ops) {
1884           if (SeenChains.count(Op.getNode()) == 0)
1885             PrunedOps.push_back(Op);
1886         }
1887         Result = DAG.getTokenFactor(SDLoc(N), PrunedOps);
1888       } else {
1889         Result = DAG.getTokenFactor(SDLoc(N), Ops);
1890       }
1891     }
1892     return Result;
1893   }
1894   return SDValue();
1895 }
1896 
1897 /// MERGE_VALUES can always be eliminated.
1898 SDValue DAGCombiner::visitMERGE_VALUES(SDNode *N) {
1899   WorklistRemover DeadNodes(*this);
1900   // Replacing results may cause a different MERGE_VALUES to suddenly
1901   // be CSE'd with N, and carry its uses with it. Iterate until no
1902   // uses remain, to ensure that the node can be safely deleted.
1903   // First add the users of this node to the work list so that they
1904   // can be tried again once they have new operands.
1905   AddUsersToWorklist(N);
1906   do {
1907     // Do as a single replacement to avoid rewalking use lists.
1908     SmallVector<SDValue, 8> Ops;
1909     for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
1910       Ops.push_back(N->getOperand(i));
1911     DAG.ReplaceAllUsesWith(N, Ops.data());
1912   } while (!N->use_empty());
1913   deleteAndRecombine(N);
1914   return SDValue(N, 0);   // Return N so it doesn't get rechecked!
1915 }
1916 
1917 /// If \p N is a ConstantSDNode with isOpaque() == false, return it cast to a
1918 /// ConstantSDNode pointer, else nullptr.
1919 static ConstantSDNode *getAsNonOpaqueConstant(SDValue N) {
1920   ConstantSDNode *Const = dyn_cast<ConstantSDNode>(N);
1921   return Const != nullptr && !Const->isOpaque() ? Const : nullptr;
1922 }
1923 
1924 SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {
1925   assert(TLI.isBinOp(BO->getOpcode()) && BO->getNumValues() == 1 &&
1926          "Unexpected binary operator");
1927 
1928   // Don't do this unless the old select is going away. We want to eliminate the
1929   // binary operator, not replace a binop with a select.
1930   // TODO: Handle ISD::SELECT_CC.
1931   unsigned SelOpNo = 0;
1932   SDValue Sel = BO->getOperand(0);
1933   if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse()) {
1934     SelOpNo = 1;
1935     Sel = BO->getOperand(1);
1936   }
1937 
1938   if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse())
1939     return SDValue();
1940 
1941   SDValue CT = Sel.getOperand(1);
1942   if (!isConstantOrConstantVector(CT, true) &&
1943       !isConstantFPBuildVectorOrConstantFP(CT))
1944     return SDValue();
1945 
1946   SDValue CF = Sel.getOperand(2);
1947   if (!isConstantOrConstantVector(CF, true) &&
1948       !isConstantFPBuildVectorOrConstantFP(CF))
1949     return SDValue();
1950 
1951   // Bail out if any constants are opaque because we can't constant fold those.
1952   // The exception is "and" and "or" with either 0 or -1 in which case we can
1953   // propagate non constant operands into select. I.e.:
1954   // and (select Cond, 0, -1), X --> select Cond, 0, X
1955   // or X, (select Cond, -1, 0) --> select Cond, -1, X
1956   auto BinOpcode = BO->getOpcode();
1957   bool CanFoldNonConst =
1958       (BinOpcode == ISD::AND || BinOpcode == ISD::OR) &&
1959       (isNullOrNullSplat(CT) || isAllOnesOrAllOnesSplat(CT)) &&
1960       (isNullOrNullSplat(CF) || isAllOnesOrAllOnesSplat(CF));
1961 
1962   SDValue CBO = BO->getOperand(SelOpNo ^ 1);
1963   if (!CanFoldNonConst &&
1964       !isConstantOrConstantVector(CBO, true) &&
1965       !isConstantFPBuildVectorOrConstantFP(CBO))
1966     return SDValue();
1967 
1968   EVT VT = Sel.getValueType();
1969 
1970   // In the case of a shift, the value and the shift amount may have
1971   // different VTs. For instance, on x86 the shift amount is i8 regardless of
1972   // the LHS type. Bail out if we have swapped operands and the value types
1973   // do not match. NB: x86 is fine if the operands are not swapped, with the
1974   // shift amount VT being no bigger than that of the shifted value.
1975   // TODO: check for a shift, correct the VTs, and still optimize on x86.
1976   if (SelOpNo && VT != CBO.getValueType())
1977     return SDValue();
1978 
1979   // We have a select-of-constants followed by a binary operator with a
1980   // constant. Eliminate the binop by pulling the constant math into the select.
1981   // Example: add (select Cond, CT, CF), CBO --> select Cond, CT + CBO, CF + CBO
1982   SDLoc DL(Sel);
1983   SDValue NewCT = SelOpNo ? DAG.getNode(BinOpcode, DL, VT, CBO, CT)
1984                           : DAG.getNode(BinOpcode, DL, VT, CT, CBO);
1985   if (!CanFoldNonConst && !NewCT.isUndef() &&
1986       !isConstantOrConstantVector(NewCT, true) &&
1987       !isConstantFPBuildVectorOrConstantFP(NewCT))
1988     return SDValue();
1989 
1990   SDValue NewCF = SelOpNo ? DAG.getNode(BinOpcode, DL, VT, CBO, CF)
1991                           : DAG.getNode(BinOpcode, DL, VT, CF, CBO);
1992   if (!CanFoldNonConst && !NewCF.isUndef() &&
1993       !isConstantOrConstantVector(NewCF, true) &&
1994       !isConstantFPBuildVectorOrConstantFP(NewCF))
1995     return SDValue();
1996 
1997   SDValue SelectOp = DAG.getSelect(DL, VT, Sel.getOperand(0), NewCT, NewCF);
1998   SelectOp->setFlags(BO->getFlags());
1999   return SelectOp;
2000 }
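
// Worked instances of the fold above (sketches; the constants are chosen
// freely for illustration):
//   add (select Cond, 8, 16), 4  --> select Cond, 12, 20
//   and (select Cond, 0, -1), %x --> select Cond, 0, %x  (CanFoldNonConst)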
2001 
2002 static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, SelectionDAG &DAG) {
2003   assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
2004          "Expecting add or sub");
2005 
2006   // Match a constant operand and a zext operand for the math instruction:
2007   // add Z, C
2008   // sub C, Z
2009   bool IsAdd = N->getOpcode() == ISD::ADD;
2010   SDValue C = IsAdd ? N->getOperand(1) : N->getOperand(0);
2011   SDValue Z = IsAdd ? N->getOperand(0) : N->getOperand(1);
2012   auto *CN = dyn_cast<ConstantSDNode>(C);
2013   if (!CN || Z.getOpcode() != ISD::ZERO_EXTEND)
2014     return SDValue();
2015 
2016   // Match the zext operand as a setcc of a boolean.
2017   if (Z.getOperand(0).getOpcode() != ISD::SETCC ||
2018       Z.getOperand(0).getValueType() != MVT::i1)
2019     return SDValue();
2020 
2021   // Match the compare as: setcc (X & 1), 0, eq.
2022   SDValue SetCC = Z.getOperand(0);
2023   ISD::CondCode CC = cast<CondCodeSDNode>(SetCC->getOperand(2))->get();
2024   if (CC != ISD::SETEQ || !isNullConstant(SetCC.getOperand(1)) ||
2025       SetCC.getOperand(0).getOpcode() != ISD::AND ||
2026       !isOneConstant(SetCC.getOperand(0).getOperand(1)))
2027     return SDValue();
2028 
2029   // We are adding/subtracting a constant and an inverted low bit. Turn that
2030   // into a subtract/add of the low bit with incremented/decremented constant:
2031   // add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1))
2032   // sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1))
2033   EVT VT = C.getValueType();
2034   SDLoc DL(N);
2035   SDValue LowBit = DAG.getZExtOrTrunc(SetCC.getOperand(0), DL, VT);
2036   SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT) :
2037                        DAG.getConstant(CN->getAPIntValue() - 1, DL, VT);
2038   return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, C1, LowBit);
2039 }
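
// Concrete instance of the fold above (a sketch, with C == 5):
//   add (zext i1 (seteq (and %x, 1), 0)), 5 --> sub 6, (zext (and %x, 1))
// which follows because the zext'd setcc equals 1 - (%x & 1).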
2040 
2041 /// Try to fold an add/sub of a constant with a logically shifted sign bit of
2042 /// a 'not' value into a shift and an add with a different constant.
2043 static SDValue foldAddSubOfSignBit(SDNode *N, SelectionDAG &DAG) {
2044   assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
2045          "Expecting add or sub");
2046 
2047   // We need a constant operand for the add/sub, and the other operand is a
2048   // logical shift right: add (srl), C or sub C, (srl).
2049   // TODO - support non-uniform vector amounts.
2050   bool IsAdd = N->getOpcode() == ISD::ADD;
2051   SDValue ConstantOp = IsAdd ? N->getOperand(1) : N->getOperand(0);
2052   SDValue ShiftOp = IsAdd ? N->getOperand(0) : N->getOperand(1);
2053   ConstantSDNode *C = isConstOrConstSplat(ConstantOp);
2054   if (!C || ShiftOp.getOpcode() != ISD::SRL)
2055     return SDValue();
2056 
2057   // The shift must be of a 'not' value.
2058   SDValue Not = ShiftOp.getOperand(0);
2059   if (!Not.hasOneUse() || !isBitwiseNot(Not))
2060     return SDValue();
2061 
2062   // The shift must be moving the sign bit to the least-significant-bit.
2063   EVT VT = ShiftOp.getValueType();
2064   SDValue ShAmt = ShiftOp.getOperand(1);
2065   ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt);
2066   if (!ShAmtC || ShAmtC->getAPIntValue() != (VT.getScalarSizeInBits() - 1))
2067     return SDValue();
2068 
2069   // Eliminate the 'not' by adjusting the shift and add/sub constant:
2070   // add (srl (not X), 31), C --> add (sra X, 31), (C + 1)
2071   // sub C, (srl (not X), 31) --> add (srl X, 31), (C - 1)
2072   SDLoc DL(N);
2073   auto ShOpcode = IsAdd ? ISD::SRA : ISD::SRL;
2074   SDValue NewShift = DAG.getNode(ShOpcode, DL, VT, Not.getOperand(0), ShAmt);
2075   APInt NewC = IsAdd ? C->getAPIntValue() + 1 : C->getAPIntValue() - 1;
2076   return DAG.getNode(ISD::ADD, DL, VT, NewShift, DAG.getConstant(NewC, DL, VT));
2077 }
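
// Concrete instance (a sketch, i32 with C == 5):
//   add (srl (xor %x, -1), 31), 5 --> add (sra %x, 31), 6
// using srl(not %x, 31) == sra(%x, 31) + 1 for 32-bit values.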
2078 
2079 /// Try to fold a node that behaves like an ADD (note that N isn't necessarily
2080 /// an ISD::ADD here, it could for example be an ISD::OR if we know that there
2081 /// are no common bits set in the operands).
2082 SDValue DAGCombiner::visitADDLike(SDNode *N) {
2083   SDValue N0 = N->getOperand(0);
2084   SDValue N1 = N->getOperand(1);
2085   EVT VT = N0.getValueType();
2086   SDLoc DL(N);
2087 
2088   // fold vector ops
2089   if (VT.isVector()) {
2090     if (SDValue FoldedVOp = SimplifyVBinOp(N))
2091       return FoldedVOp;
2092 
2093     // fold (add x, 0) -> x, vector edition
2094     if (ISD::isBuildVectorAllZeros(N1.getNode()))
2095       return N0;
2096     if (ISD::isBuildVectorAllZeros(N0.getNode()))
2097       return N1;
2098   }
2099 
2100   // fold (add x, undef) -> undef
2101   if (N0.isUndef())
2102     return N0;
2103 
2104   if (N1.isUndef())
2105     return N1;
2106 
2107   if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) {
2108     // canonicalize constant to RHS
2109     if (!DAG.isConstantIntBuildVectorOrConstantInt(N1))
2110       return DAG.getNode(ISD::ADD, DL, VT, N1, N0);
2111     // fold (add c1, c2) -> c1+c2
2112     return DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, N0.getNode(),
2113                                       N1.getNode());
2114   }
2115 
2116   // fold (add x, 0) -> x
2117   if (isNullConstant(N1))
2118     return N0;
2119 
2120   if (isConstantOrConstantVector(N1, /* NoOpaque */ true)) {
2121     // fold ((A-c1)+c2) -> (A+(c2-c1))
2122     if (N0.getOpcode() == ISD::SUB &&
2123         isConstantOrConstantVector(N0.getOperand(1), /* NoOpaque */ true)) {
2124       SDValue Sub = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, N1.getNode(),
2125                                                N0.getOperand(1).getNode());
2126       assert(Sub && "Constant folding failed");
2127       return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Sub);
2128     }
2129 
2130     // fold ((c1-A)+c2) -> (c1+c2)-A
2131     if (N0.getOpcode() == ISD::SUB &&
2132         isConstantOrConstantVector(N0.getOperand(0), /* NoOpaque */ true)) {
2133       SDValue Add = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, N1.getNode(),
2134                                                N0.getOperand(0).getNode());
2135       assert(Add && "Constant folding failed");
2136       return DAG.getNode(ISD::SUB, DL, VT, Add, N0.getOperand(1));
2137     }
2138 
2139     // add (sext i1 X), 1 -> zext (not i1 X)
2140     // We don't transform this pattern:
2141     //   add (zext i1 X), -1 -> sext (not i1 X)
2142     // because most (?) targets generate better code for the zext form.
2143     if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
2144         isOneOrOneSplat(N1)) {
2145       SDValue X = N0.getOperand(0);
2146       if ((!LegalOperations ||
2147            (TLI.isOperationLegal(ISD::XOR, X.getValueType()) &&
2148             TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) &&
2149           X.getScalarValueSizeInBits() == 1) {
2150         SDValue Not = DAG.getNOT(DL, X, X.getValueType());
2151         return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Not);
2152       }
2153     }
2154 
2155     // Undo the add -> or combine to merge constant offsets from a frame index.
2156     if (N0.getOpcode() == ISD::OR &&
2157         isa<FrameIndexSDNode>(N0.getOperand(0)) &&
2158         isa<ConstantSDNode>(N0.getOperand(1)) &&
2159         DAG.haveNoCommonBitsSet(N0.getOperand(0), N0.getOperand(1))) {
2160       SDValue Add0 = DAG.getNode(ISD::ADD, DL, VT, N1, N0.getOperand(1));
2161       return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Add0);
2162     }
2163   }
2164 
2165   if (SDValue NewSel = foldBinOpIntoSelect(N))
2166     return NewSel;
2167 
2168   // reassociate add
2169   if (!reassociationCanBreakAddressingModePattern(ISD::ADD, DL, N0, N1)) {
2170     if (SDValue RADD = reassociateOps(ISD::ADD, DL, N0, N1, N->getFlags()))
2171       return RADD;
2172   }
2173   // fold ((0-A) + B) -> B-A
2174   if (N0.getOpcode() == ISD::SUB && isNullOrNullSplat(N0.getOperand(0)))
2175     return DAG.getNode(ISD::SUB, DL, VT, N1, N0.getOperand(1));
2176 
2177   // fold (A + (0-B)) -> A-B
2178   if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0)))
2179     return DAG.getNode(ISD::SUB, DL, VT, N0, N1.getOperand(1));
2180 
2181   // fold (A+(B-A)) -> B
2182   if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(1))
2183     return N1.getOperand(0);
2184 
2185   // fold ((B-A)+A) -> B
2186   if (N0.getOpcode() == ISD::SUB && N1 == N0.getOperand(1))
2187     return N0.getOperand(0);
2188 
2189   // fold ((A-B)+(C-A)) -> (C-B)
2190   if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
2191       N0.getOperand(0) == N1.getOperand(1))
2192     return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
2193                        N0.getOperand(1));
2194 
2195   // fold ((A-B)+(B-C)) -> (A-C)
2196   if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
2197       N0.getOperand(1) == N1.getOperand(0))
2198     return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0),
2199                        N1.getOperand(1));
2200 
2201   // fold (A+(B-(A+C))) to (B-C)
2202   if (N1.getOpcode() == ISD::SUB && N1.getOperand(1).getOpcode() == ISD::ADD &&
2203       N0 == N1.getOperand(1).getOperand(0))
2204     return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
2205                        N1.getOperand(1).getOperand(1));
2206 
2207   // fold (A+(B-(C+A))) to (B-C)
2208   if (N1.getOpcode() == ISD::SUB && N1.getOperand(1).getOpcode() == ISD::ADD &&
2209       N0 == N1.getOperand(1).getOperand(1))
2210     return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
2211                        N1.getOperand(1).getOperand(0));
2212 
2213   // fold (A+((B-A)+or-C)) to (B+or-C)
2214   if ((N1.getOpcode() == ISD::SUB || N1.getOpcode() == ISD::ADD) &&
2215       N1.getOperand(0).getOpcode() == ISD::SUB &&
2216       N0 == N1.getOperand(0).getOperand(1))
2217     return DAG.getNode(N1.getOpcode(), DL, VT, N1.getOperand(0).getOperand(0),
2218                        N1.getOperand(1));
2219 
2220   // fold (A-B)+(C-D) to (A+C)-(B+D) when A or C is constant
2221   if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB) {
2222     SDValue N00 = N0.getOperand(0);
2223     SDValue N01 = N0.getOperand(1);
2224     SDValue N10 = N1.getOperand(0);
2225     SDValue N11 = N1.getOperand(1);
2226 
2227     if (isConstantOrConstantVector(N00) || isConstantOrConstantVector(N10))
2228       return DAG.getNode(ISD::SUB, DL, VT,
2229                          DAG.getNode(ISD::ADD, SDLoc(N0), VT, N00, N10),
2230                          DAG.getNode(ISD::ADD, SDLoc(N1), VT, N01, N11));
2231   }
2232 
2233   // fold (add (umax X, C), -C) --> (usubsat X, C)
2234   if (N0.getOpcode() == ISD::UMAX && hasOperation(ISD::USUBSAT, VT)) {
2235     auto MatchUSUBSAT = [](ConstantSDNode *Max, ConstantSDNode *Op) {
2236       return (!Max && !Op) ||
2237              (Max && Op && Max->getAPIntValue() == (-Op->getAPIntValue()));
2238     };
2239     if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchUSUBSAT,
2240                                   /*AllowUndefs*/ true))
2241       return DAG.getNode(ISD::USUBSAT, DL, VT, N0.getOperand(0),
2242                          N0.getOperand(1));
2243   }
2244 
2245   if (SimplifyDemandedBits(SDValue(N, 0)))
2246     return SDValue(N, 0);
2247 
2248   if (isOneOrOneSplat(N1)) {
2249     // fold (add (xor a, -1), 1) -> (sub 0, a)
2250     if (isBitwiseNot(N0))
2251       return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
2252                          N0.getOperand(0));
2253 
2254     // fold (add (add (xor a, -1), b), 1) -> (sub b, a)
2255     if (N0.getOpcode() == ISD::ADD ||
2256         N0.getOpcode() == ISD::UADDO ||
2257         N0.getOpcode() == ISD::SADDO) {
2258       SDValue A, Xor;
2259 
2260       if (isBitwiseNot(N0.getOperand(0))) {
2261         A = N0.getOperand(1);
2262         Xor = N0.getOperand(0);
2263       } else if (isBitwiseNot(N0.getOperand(1))) {
2264         A = N0.getOperand(0);
2265         Xor = N0.getOperand(1);
2266       }
2267 
2268       if (Xor)
2269         return DAG.getNode(ISD::SUB, DL, VT, A, Xor.getOperand(0));
2270     }
2271 
2272     // Look for:
2273     //   add (add x, y), 1
2274     // And if the target does not like this form then turn into:
2275     //   sub y, (xor x, -1)
2276     if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.hasOneUse() &&
2277         N0.getOpcode() == ISD::ADD) {
2278       SDValue Not = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(0),
2279                                 DAG.getAllOnesConstant(DL, VT));
2280       return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(1), Not);
2281     }
2282   }
2283 
2284   // (x - y) + -1  ->  add (xor y, -1), x
2285   if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB &&
2286       isAllOnesOrAllOnesSplat(N1)) {
2287     SDValue Xor = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(1), N1);
2288     return DAG.getNode(ISD::ADD, DL, VT, Xor, N0.getOperand(0));
2289   }
2290 
2291   if (SDValue Combined = visitADDLikeCommutative(N0, N1, N))
2292     return Combined;
2293 
2294   if (SDValue Combined = visitADDLikeCommutative(N1, N0, N))
2295     return Combined;
2296 
2297   return SDValue();
2298 }
2299 
2300 SDValue DAGCombiner::visitADD(SDNode *N) {
2301   SDValue N0 = N->getOperand(0);
2302   SDValue N1 = N->getOperand(1);
2303   EVT VT = N0.getValueType();
2304   SDLoc DL(N);
2305 
2306   if (SDValue Combined = visitADDLike(N))
2307     return Combined;
2308 
2309   if (SDValue V = foldAddSubBoolOfMaskedVal(N, DAG))
2310     return V;
2311 
2312   if (SDValue V = foldAddSubOfSignBit(N, DAG))
2313     return V;
2314 
2315   // fold (a+b) -> (a|b) iff a and b share no bits.
2316   if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
2317       DAG.haveNoCommonBitsSet(N0, N1))
2318     return DAG.getNode(ISD::OR, DL, VT, N0, N1);
2319 
2320   return SDValue();
2321 }
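
// Example of the disjoint-bits fold above (a sketch):
//   %hi = shl i32 %x, 4
//   %lo = and i32 %y, 15
//   (add %hi, %lo) --> (or %hi, %lo)   ; the operands share no set bits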
2322 
2323 SDValue DAGCombiner::visitADDSAT(SDNode *N) {
2324   unsigned Opcode = N->getOpcode();
2325   SDValue N0 = N->getOperand(0);
2326   SDValue N1 = N->getOperand(1);
2327   EVT VT = N0.getValueType();
2328   SDLoc DL(N);
2329 
2330   // fold vector ops
2331   if (VT.isVector()) {
2332     // TODO SimplifyVBinOp
2333 
2334     // fold (add_sat x, 0) -> x, vector edition
2335     if (ISD::isBuildVectorAllZeros(N1.getNode()))
2336       return N0;
2337     if (ISD::isBuildVectorAllZeros(N0.getNode()))
2338       return N1;
2339   }
2340 
2341   // fold (add_sat x, undef) -> -1
2342   if (N0.isUndef() || N1.isUndef())
2343     return DAG.getAllOnesConstant(DL, VT);
2344 
2345   if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) {
2346     // canonicalize constant to RHS
2347     if (!DAG.isConstantIntBuildVectorOrConstantInt(N1))
2348       return DAG.getNode(Opcode, DL, VT, N1, N0);
2349     // fold (add_sat c1, c2) -> c3
2350     return DAG.FoldConstantArithmetic(Opcode, DL, VT, N0.getNode(),
2351                                       N1.getNode());
2352   }
2353 
2354   // fold (add_sat x, 0) -> x
2355   if (isNullConstant(N1))
2356     return N0;
2357 
2358   // If it cannot overflow, transform into an add.
2359   if (Opcode == ISD::UADDSAT)
2360     if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never)
2361       return DAG.getNode(ISD::ADD, DL, VT, N0, N1);
2362 
2363   return SDValue();
2364 }
2365 
2366 static SDValue getAsCarry(const TargetLowering &TLI, SDValue V) {
2367   bool Masked = false;
2368 
2369   // First, peel away TRUNCATE/ZERO_EXTEND/AND nodes due to legalization.
2370   while (true) {
2371     if (V.getOpcode() == ISD::TRUNCATE || V.getOpcode() == ISD::ZERO_EXTEND) {
2372       V = V.getOperand(0);
2373       continue;
2374     }
2375 
2376     if (V.getOpcode() == ISD::AND && isOneConstant(V.getOperand(1))) {
2377       Masked = true;
2378       V = V.getOperand(0);
2379       continue;
2380     }
2381 
2382     break;
2383   }
2384 
2385   // If this is not a carry, return.
2386   if (V.getResNo() != 1)
2387     return SDValue();
2388 
2389   if (V.getOpcode() != ISD::ADDCARRY && V.getOpcode() != ISD::SUBCARRY &&
2390       V.getOpcode() != ISD::UADDO && V.getOpcode() != ISD::USUBO)
2391     return SDValue();
2392 
2393   EVT VT = V.getNode()->getValueType(0);
2394   if (!TLI.isOperationLegalOrCustom(V.getOpcode(), VT))
2395     return SDValue();
2396 
2397   // If the result is masked, then no matter what kind of bool it is we can
2398   // return. If it isn't, then we need to make sure the bool is either 0 or
2399   // 1 and not some other value.
2400   if (Masked ||
2401       TLI.getBooleanContents(V.getValueType()) ==
2402           TargetLoweringBase::ZeroOrOneBooleanContent)
2403     return V;
2404 
2405   return SDValue();
2406 }
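
// Sketch: given V = (and (trunc (uaddo %a, %b):1), 1), the peeling loop
// above strips the AND and TRUNCATE to reach result 1 (the carry) of the
// UADDO, with Masked recording that the value was already reduced to 0/1.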
2407 
2408 /// Given the operands of an add/sub operation, see if the 2nd operand is a
2409 /// masked 0/1 whose source operand is actually known to be 0/-1. If so, invert
2410 /// the opcode and bypass the mask operation.
2411 static SDValue foldAddSubMasked1(bool IsAdd, SDValue N0, SDValue N1,
2412                                  SelectionDAG &DAG, const SDLoc &DL) {
2413   if (N1.getOpcode() != ISD::AND || !isOneOrOneSplat(N1->getOperand(1)))
2414     return SDValue();
2415 
2416   EVT VT = N0.getValueType();
2417   if (DAG.ComputeNumSignBits(N1.getOperand(0)) != VT.getScalarSizeInBits())
2418     return SDValue();
2419 
2420   // add N0, (and (AssertSext X, i1), 1) --> sub N0, X
2421   // sub N0, (and (AssertSext X, i1), 1) --> add N0, X
2422   return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, N0, N1.getOperand(0));
2423 }
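
// Concrete instance (a sketch): if %x is known to be all sign bits (0 or
// -1), then (and %x, 1) == -%x, so:
//   add %n0, (and %x, 1) --> sub %n0, %x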
2424 
2425 /// Helper for doing combines based on N0 and N1 being added to each other.
2426 SDValue DAGCombiner::visitADDLikeCommutative(SDValue N0, SDValue N1,
2427                                           SDNode *LocReference) {
2428   EVT VT = N0.getValueType();
2429   SDLoc DL(LocReference);
2430 
2431   // fold (add x, shl(0 - y, n)) -> sub(x, shl(y, n))
2432   if (N1.getOpcode() == ISD::SHL && N1.getOperand(0).getOpcode() == ISD::SUB &&
2433       isNullOrNullSplat(N1.getOperand(0).getOperand(0)))
2434     return DAG.getNode(ISD::SUB, DL, VT, N0,
2435                        DAG.getNode(ISD::SHL, DL, VT,
2436                                    N1.getOperand(0).getOperand(1),
2437                                    N1.getOperand(1)));
2438 
2439   if (SDValue V = foldAddSubMasked1(true, N0, N1, DAG, DL))
2440     return V;
2441 
2442   // Look for:
2443   //   add (add x, 1), y
2444   // And if the target does not like this form then turn into:
2445   //   sub y, (xor x, -1)
2446   if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.hasOneUse() &&
2447       N0.getOpcode() == ISD::ADD && isOneOrOneSplat(N0.getOperand(1))) {
2448     SDValue Not = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(0),
2449                               DAG.getAllOnesConstant(DL, VT));
2450     return DAG.getNode(ISD::SUB, DL, VT, N1, Not);
2451   }
2452 
2453   // Hoist one-use subtraction by non-opaque constant:
2454   //   (x - C) + y  ->  (x + y) - C
2455   // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
2456   if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB &&
2457       isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
2458     SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), N1);
2459     return DAG.getNode(ISD::SUB, DL, VT, Add, N0.getOperand(1));
2460   }
2461   // Hoist one-use subtraction from non-opaque constant:
2462   //   (C - x) + y  ->  (y - x) + C
2463   if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB &&
2464       isConstantOrConstantVector(N0.getOperand(0), /*NoOpaques=*/true)) {
2465     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N1, N0.getOperand(1));
2466     return DAG.getNode(ISD::ADD, DL, VT, Sub, N0.getOperand(0));
2467   }
2468 
2469   // If the target's bool is represented as 0/1, prefer to make this 'sub 0/1'
2470   // rather than 'add 0/-1' (the zext should get folded).
2471   // add (sext i1 Y), X --> sub X, (zext i1 Y)
2472   if (N0.getOpcode() == ISD::SIGN_EXTEND &&
2473       N0.getOperand(0).getScalarValueSizeInBits() == 1 &&
2474       TLI.getBooleanContents(VT) == TargetLowering::ZeroOrOneBooleanContent) {
2475     SDValue ZExt = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0));
2476     return DAG.getNode(ISD::SUB, DL, VT, N1, ZExt);
2477   }
2478 
2479   // add X, (sextinreg Y i1) -> sub X, (and Y 1)
2480   if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
2481     VTSDNode *TN = cast<VTSDNode>(N1.getOperand(1));
2482     if (TN->getVT() == MVT::i1) {
2483       SDValue ZExt = DAG.getNode(ISD::AND, DL, VT, N1.getOperand(0),
2484                                  DAG.getConstant(1, DL, VT));
2485       return DAG.getNode(ISD::SUB, DL, VT, N0, ZExt);
2486     }
2487   }
2488 
2489   // (add X, (addcarry Y, 0, Carry)) -> (addcarry X, Y, Carry)
2490   if (N1.getOpcode() == ISD::ADDCARRY && isNullConstant(N1.getOperand(1)) &&
2491       N1.getResNo() == 0)
2492     return DAG.getNode(ISD::ADDCARRY, DL, N1->getVTList(),
2493                        N0, N1.getOperand(0), N1.getOperand(2));
2494 
2495   // (add X, Carry) -> (addcarry X, 0, Carry)
2496   if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT))
2497     if (SDValue Carry = getAsCarry(TLI, N1))
2498       return DAG.getNode(ISD::ADDCARRY, DL,
2499                          DAG.getVTList(VT, Carry.getValueType()), N0,
2500                          DAG.getConstant(0, DL, VT), Carry);
2501 
2502   return SDValue();
2503 }
2504 
2505 SDValue DAGCombiner::visitADDC(SDNode *N) {
2506   SDValue N0 = N->getOperand(0);
2507   SDValue N1 = N->getOperand(1);
2508   EVT VT = N0.getValueType();
2509   SDLoc DL(N);
2510 
2511   // If the flag result is dead, turn this into an ADD.
2512   if (!N->hasAnyUseOfValue(1))
2513     return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
2514                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
2515 
2516   // canonicalize constant to RHS.
2517   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
2518   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
2519   if (N0C && !N1C)
2520     return DAG.getNode(ISD::ADDC, DL, N->getVTList(), N1, N0);
2521 
2522   // fold (addc x, 0) -> x + no carry out
2523   if (isNullConstant(N1))
2524     return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE,
2525                                         DL, MVT::Glue));
2526 
2527   // If it cannot overflow, transform into an add.
2528   if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never)
2529     return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
2530                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
2531 
2532   return SDValue();
2533 }
2534 
2535 static SDValue flipBoolean(SDValue V, const SDLoc &DL,
2536                            SelectionDAG &DAG, const TargetLowering &TLI) {
2537   EVT VT = V.getValueType();
2538 
2539   SDValue Cst;
2540   switch (TLI.getBooleanContents(VT)) {
2541   case TargetLowering::ZeroOrOneBooleanContent:
2542   case TargetLowering::UndefinedBooleanContent:
2543     Cst = DAG.getConstant(1, DL, VT);
2544     break;
2545   case TargetLowering::ZeroOrNegativeOneBooleanContent:
2546     Cst = DAG.getAllOnesConstant(DL, VT);
2547     break;
2548   }
2549 
2550   return DAG.getNode(ISD::XOR, DL, VT, V, Cst);
2551 }
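
// e.g. (a sketch): with ZeroOrOneBooleanContent this emits (xor V, 1); with
// ZeroOrNegativeOneBooleanContent it emits (xor V, -1).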
2552 
2553 /**
2554  * Flips a boolean if it is cheaper to compute. If the Force parameter is set,
2555  * then the flip also occurs if computing the inverse is the same cost.
2556  * This function returns an empty SDValue in case it cannot flip the boolean
2557  * without increasing the cost of the computation. If you want to flip a boolean
2558  * no matter what, use flipBoolean.
2559  */
2560 static SDValue extractBooleanFlip(SDValue V, SelectionDAG &DAG,
2561                                   const TargetLowering &TLI,
2562                                   bool Force) {
2563   if (Force && isa<ConstantSDNode>(V))
2564     return flipBoolean(V, SDLoc(V), DAG, TLI);
2565 
2566   if (V.getOpcode() != ISD::XOR)
2567     return SDValue();
2568 
2569   ConstantSDNode *Const = isConstOrConstSplat(V.getOperand(1), false);
2570   if (!Const)
2571     return SDValue();
2572 
2573   EVT VT = V.getValueType();
2574 
2575   bool IsFlip = false;
2576   switch(TLI.getBooleanContents(VT)) {
2577     case TargetLowering::ZeroOrOneBooleanContent:
2578       IsFlip = Const->isOne();
2579       break;
2580     case TargetLowering::ZeroOrNegativeOneBooleanContent:
2581       IsFlip = Const->isAllOnesValue();
2582       break;
2583     case TargetLowering::UndefinedBooleanContent:
2584       IsFlip = (Const->getAPIntValue() & 0x01) == 1;
2585       break;
2586   }
2587 
2588   if (IsFlip)
2589     return V.getOperand(0);
2590   if (Force)
2591     return flipBoolean(V, SDLoc(V), DAG, TLI);
2592   return SDValue();
2593 }
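// Usage sketch (illustrative): with ZeroOrOneBooleanContent, a value
// V = (xor X, 1) is recognized as a flipped boolean and X is returned, so a
// caller can consume X directly and account for the inversion itself; with
// Force set, any other V is flipped explicitly via flipBoolean.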
2594 
2595 SDValue DAGCombiner::visitADDO(SDNode *N) {
2596   SDValue N0 = N->getOperand(0);
2597   SDValue N1 = N->getOperand(1);
2598   EVT VT = N0.getValueType();
2599   bool IsSigned = (ISD::SADDO == N->getOpcode());
2600 
2601   EVT CarryVT = N->getValueType(1);
2602   SDLoc DL(N);
2603 
2604   // If the flag result is dead, turn this into an ADD.
2605   if (!N->hasAnyUseOfValue(1))
2606     return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
2607                      DAG.getUNDEF(CarryVT));
2608 
2609   // canonicalize constant to RHS.
2610   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
2611       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
2612     return DAG.getNode(N->getOpcode(), DL, N->getVTList(), N1, N0);
2613 
2614   // fold (addo x, 0) -> x + no carry out
2615   if (isNullOrNullSplat(N1))
2616     return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT));
2617 
2618   if (!IsSigned) {
2619     // If it cannot overflow, transform into an add.
2620     if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never)
2621       return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
2622                        DAG.getConstant(0, DL, CarryVT));
2623 
2624     // fold (uaddo (xor a, -1), 1) -> (usub 0, a) and flip carry.
2625     if (isBitwiseNot(N0) && isOneOrOneSplat(N1)) {
2626       SDValue Sub = DAG.getNode(ISD::USUBO, DL, N->getVTList(),
2627                                 DAG.getConstant(0, DL, VT), N0.getOperand(0));
2628       return CombineTo(N, Sub,
2629                        flipBoolean(Sub.getValue(1), DL, DAG, TLI));
2630     }
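    // Worked example (illustrative, i8): for a = 5, (xor a, -1) = 0xFA and
    // 0xFA + 1 = 0xFB = -5 = (usubo 0, 5). The uaddo carries only when
    // a == 0 (~a + 1 wraps only for ~a == 0xFF), while the usubo borrows
    // exactly when a != 0, hence the carry flip above.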
2631 
2632     if (SDValue Combined = visitUADDOLike(N0, N1, N))
2633       return Combined;
2634 
2635     if (SDValue Combined = visitUADDOLike(N1, N0, N))
2636       return Combined;
2637   }
2638 
2639   return SDValue();
2640 }
2641 
2642 SDValue DAGCombiner::visitUADDOLike(SDValue N0, SDValue N1, SDNode *N) {
2643   EVT VT = N0.getValueType();
2644   if (VT.isVector())
2645     return SDValue();
2646 
2647   // (uaddo X, (addcarry Y, 0, Carry)) -> (addcarry X, Y, Carry)
2648   // If Y + 1 cannot overflow.
2649   if (N1.getOpcode() == ISD::ADDCARRY && isNullConstant(N1.getOperand(1))) {
2650     SDValue Y = N1.getOperand(0);
2651     SDValue One = DAG.getConstant(1, SDLoc(N), Y.getValueType());
2652     if (DAG.computeOverflowKind(Y, One) == SelectionDAG::OFK_Never)
2653       return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(), N0, Y,
2654                          N1.getOperand(2));
2655   }
2656 
2657   // (uaddo X, Carry) -> (addcarry X, 0, Carry)
2658   if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT))
2659     if (SDValue Carry = getAsCarry(TLI, N1))
2660       return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(), N0,
2661                          DAG.getConstant(0, SDLoc(N), VT), Carry);
2662 
2663   return SDValue();
2664 }
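// For example (illustrative): in (uaddo X, (addcarry Y, 0, C)), if Y's top
// bit is known zero then Y + 1 cannot wrap, so at most one carry can occur
// along the chain and (addcarry X, Y, C) produces the same sum and carry-out
// with one node fewer.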
2665 
2666 SDValue DAGCombiner::visitADDE(SDNode *N) {
2667   SDValue N0 = N->getOperand(0);
2668   SDValue N1 = N->getOperand(1);
2669   SDValue CarryIn = N->getOperand(2);
2670 
2671   // canonicalize constant to RHS
2672   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
2673   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
2674   if (N0C && !N1C)
2675     return DAG.getNode(ISD::ADDE, SDLoc(N), N->getVTList(),
2676                        N1, N0, CarryIn);
2677 
2678   // fold (adde x, y, false) -> (addc x, y)
2679   if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
2680     return DAG.getNode(ISD::ADDC, SDLoc(N), N->getVTList(), N0, N1);
2681 
2682   return SDValue();
2683 }
2684 
2685 SDValue DAGCombiner::visitADDCARRY(SDNode *N) {
2686   SDValue N0 = N->getOperand(0);
2687   SDValue N1 = N->getOperand(1);
2688   SDValue CarryIn = N->getOperand(2);
2689   SDLoc DL(N);
2690 
2691   // canonicalize constant to RHS
2692   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
2693   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
2694   if (N0C && !N1C)
2695     return DAG.getNode(ISD::ADDCARRY, DL, N->getVTList(), N1, N0, CarryIn);
2696 
2697   // fold (addcarry x, y, false) -> (uaddo x, y)
2698   if (isNullConstant(CarryIn)) {
2699     if (!LegalOperations ||
2700         TLI.isOperationLegalOrCustom(ISD::UADDO, N->getValueType(0)))
2701       return DAG.getNode(ISD::UADDO, DL, N->getVTList(), N0, N1);
2702   }
2703 
2704   // fold (addcarry 0, 0, X) -> (and (ext/trunc X), 1) and no carry.
2705   if (isNullConstant(N0) && isNullConstant(N1)) {
2706     EVT VT = N0.getValueType();
2707     EVT CarryVT = CarryIn.getValueType();
2708     SDValue CarryExt = DAG.getBoolExtOrTrunc(CarryIn, DL, VT, CarryVT);
2709     AddToWorklist(CarryExt.getNode());
2710     return CombineTo(N, DAG.getNode(ISD::AND, DL, VT, CarryExt,
2711                                     DAG.getConstant(1, DL, VT)),
2712                      DAG.getConstant(0, DL, CarryVT));
2713   }
2714 
2715   if (SDValue Combined = visitADDCARRYLike(N0, N1, CarryIn, N))
2716     return Combined;
2717 
2718   if (SDValue Combined = visitADDCARRYLike(N1, N0, CarryIn, N))
2719     return Combined;
2720 
2721   return SDValue();
2722 }
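// Worked example (illustrative): (addcarry 0, 0, X) computes 0 + 0 + X,
// i.e. the carry bit itself, and can never carry out; the fold above
// materializes it as (and (ext/trunc X), 1) in the sum type together with
// a constant-0 carry result.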
2723 
2724 /**
2725  * If we are facing some sort of diamond carry propagation pattern try to
2726  * break it up to generate something like:
2727  *   (addcarry X, 0, (addcarry A, B, Z):Carry)
2728  *
2729  * The end result is usually an increase in the number of operations required,
2730  * but because the carry is now linearized, other transforms can kick in and optimize the DAG.
2731  *
2732  * Patterns typically look something like
2733  *            (uaddo A, B)
2734  *             /       \
2735  *          Carry      Sum
2736  *            |          \
2737  *            | (addcarry *, 0, Z)
2738  *            |       /
2739  *             \   Carry
2740  *              |   /
2741  * (addcarry X, *, *)
2742  *
2743  * But numerous variations exist. Our goal is to identify A, B, X and Z and
2744  * produce a combine with a single path for carry propagation.
2745  */
2746 static SDValue combineADDCARRYDiamond(DAGCombiner &Combiner, SelectionDAG &DAG,
2747                                       SDValue X, SDValue Carry0, SDValue Carry1,
2748                                       SDNode *N) {
2749   if (Carry1.getResNo() != 1 || Carry0.getResNo() != 1)
2750     return SDValue();
2751   if (Carry1.getOpcode() != ISD::UADDO)
2752     return SDValue();
2753 
2754   SDValue Z;
2755 
2756   /**
2757    * First look for a suitable Z. It will present itself in the form of
2758    * (addcarry Y, 0, Z) or its equivalent (uaddo Y, 1) for Z=true
2759    */
2760   if (Carry0.getOpcode() == ISD::ADDCARRY &&
2761       isNullConstant(Carry0.getOperand(1))) {
2762     Z = Carry0.getOperand(2);
2763   } else if (Carry0.getOpcode() == ISD::UADDO &&
2764              isOneConstant(Carry0.getOperand(1))) {
2765     EVT VT = Combiner.getSetCCResultType(Carry0.getValueType());
2766     Z = DAG.getConstant(1, SDLoc(Carry0.getOperand(1)), VT);
2767   } else {
2768     // We couldn't find a suitable Z.
2769     return SDValue();
2770   }
2771 
2772 
2773   auto cancelDiamond = [&](SDValue A,SDValue B) {
2774     SDLoc DL(N);
2775     SDValue NewY = DAG.getNode(ISD::ADDCARRY, DL, Carry0->getVTList(), A, B, Z);
2776     Combiner.AddToWorklist(NewY.getNode());
2777     return DAG.getNode(ISD::ADDCARRY, DL, N->getVTList(), X,
2778                        DAG.getConstant(0, DL, X.getValueType()),
2779                        NewY.getValue(1));
2780   };
2781 
2782   /**
2783    *      (uaddo A, B)
2784    *           |
2785    *          Sum
2786    *           |
2787    * (addcarry *, 0, Z)
2788    */
2789   if (Carry0.getOperand(0) == Carry1.getValue(0)) {
2790     return cancelDiamond(Carry1.getOperand(0), Carry1.getOperand(1));
2791   }
2792 
2793   /**
2794    * (addcarry A, 0, Z)
2795    *         |
2796    *        Sum
2797    *         |
2798    *  (uaddo *, B)
2799    */
2800   if (Carry1.getOperand(0) == Carry0.getValue(0)) {
2801     return cancelDiamond(Carry0.getOperand(0), Carry1.getOperand(1));
2802   }
2803 
2804   if (Carry1.getOperand(1) == Carry0.getValue(0)) {
2805     return cancelDiamond(Carry1.getOperand(0), Carry0.getOperand(0));
2806   }
2807 
2808   return SDValue();
2809 }
2810 
2811 // If we are facing some sort of diamond carry/borrow in/out pattern try to
2812 // match patterns like:
2813 //
2814 //          (uaddo A, B)            CarryIn
2815 //            |  \                     |
2816 //            |   \                    |
2817 //    PartialSum   PartialCarryOutX   /
2818 //            |        |             /
2819 //            |    ____|____________/
2820 //            |   /    |
2821 //     (uaddo *, *)    \________
2822 //       |  \                   \
2823 //       |   \                   |
2824 //       |    PartialCarryOutY   |
2825 //       |        \              |
2826 //       |         \            /
2827 //   AddCarrySum    |    ______/
2828 //                  |   /
2829 //   CarryOut = (or *, *)
2830 //
2831 // And generate ADDCARRY (or SUBCARRY) with two result values:
2832 //
2833 //    {AddCarrySum, CarryOut} = (addcarry A, B, CarryIn)
2834 //
2835 // Our goal is to identify A, B, and CarryIn and produce ADDCARRY/SUBCARRY with
2836 // a single path for carry/borrow out propagation:
2837 static SDValue combineCarryDiamond(DAGCombiner &Combiner, SelectionDAG &DAG,
2838                                    const TargetLowering &TLI, SDValue Carry0,
2839                                    SDValue Carry1, SDNode *N) {
2840   if (Carry0.getResNo() != 1 || Carry1.getResNo() != 1)
2841     return SDValue();
2842   unsigned Opcode = Carry0.getOpcode();
2843   if (Opcode != Carry1.getOpcode())
2844     return SDValue();
2845   if (Opcode != ISD::UADDO && Opcode != ISD::USUBO)
2846     return SDValue();
2847 
2848   // Canonicalize the add/sub of A and B as Carry0 and the add/sub of the
2849   // carry/borrow in as Carry1. (The top and middle uaddo nodes respectively in
2850   // the above ASCII art.)
2851   if (Carry1.getOperand(0) != Carry0.getValue(0) &&
2852       Carry1.getOperand(1) != Carry0.getValue(0))
2853     std::swap(Carry0, Carry1);
2854   if (Carry1.getOperand(0) != Carry0.getValue(0) &&
2855       Carry1.getOperand(1) != Carry0.getValue(0))
2856     return SDValue();
2857 
2858   // The carry in value must be on the righthand side for subtraction.
2859   unsigned CarryInOperandNum =
2860       Carry1.getOperand(0) == Carry0.getValue(0) ? 1 : 0;
2861   if (Opcode == ISD::USUBO && CarryInOperandNum != 1)
2862     return SDValue();
2863   SDValue CarryIn = Carry1.getOperand(CarryInOperandNum);
2864 
2865   unsigned NewOp = Opcode == ISD::UADDO ? ISD::ADDCARRY : ISD::SUBCARRY;
2866   if (!TLI.isOperationLegalOrCustom(NewOp, Carry0.getValue(0).getValueType()))
2867     return SDValue();
2868 
2869   // Verify that the carry/borrow in is plausibly a carry/borrow bit.
2870   // TODO: make getAsCarry() aware of how partial carries are merged.
2871   if (CarryIn.getOpcode() != ISD::ZERO_EXTEND)
2872     return SDValue();
2873   CarryIn = CarryIn.getOperand(0);
2874   if (CarryIn.getValueType() != MVT::i1)
2875     return SDValue();
2876 
2877   SDLoc DL(N);
2878   SDValue Merged =
2879       DAG.getNode(NewOp, DL, Carry1->getVTList(), Carry0.getOperand(0),
2880                   Carry0.getOperand(1), CarryIn);
2881 
2882   // Note that because we have proven that the result of the UADDO/USUBO of
2883   // A and B feeds into the UADDO/USUBO that applies the carry/borrow in,
2884   // if the first UADDO/USUBO overflows, the second UADDO/USUBO cannot.
2885   // For example, consider 8-bit numbers where 0xFF is the
2886   // maximum value.
2887   //
2888   //   0xFF + 0xFF == 0xFE with carry but 0xFE + 1 does not carry
2889   //   0x00 - 0xFF == 1 with a carry/borrow but 1 - 1 == 0 (no carry/borrow)
2890   //
2891   // This is important because it means that OR and XOR can be used to merge
2892   // carry flags; and that AND can return a constant zero.
2893   //
2894   // TODO: match other operations that can merge flags (ADD, etc)
2895   DAG.ReplaceAllUsesOfValueWith(Carry1.getValue(0), Merged.getValue(0));
2896   if (N->getOpcode() == ISD::AND)
2897     return DAG.getConstant(0, DL, MVT::i1);
2898   return Merged.getValue(1);
2899 }
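// A typical source of this diamond (illustrative): adding the high words of
// a multi-word integer, where (uaddo A, B) produces the partial sum, a
// second (uaddo PartialSum, zext(CarryIn)) folds in the low-word carry, and
// the outgoing carry is (or PartialCarryOutX, PartialCarryOutY). The merge
// above rewrites this as a single (addcarry A, B, CarryIn).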
2900 
2901 SDValue DAGCombiner::visitADDCARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
2902                                        SDNode *N) {
2903   // fold (addcarry (xor a, -1), b, c) -> (subcarry b, a, !c) and flip carry.
2904   if (isBitwiseNot(N0))
2905     if (SDValue NotC = extractBooleanFlip(CarryIn, DAG, TLI, true)) {
2906       SDLoc DL(N);
2907       SDValue Sub = DAG.getNode(ISD::SUBCARRY, DL, N->getVTList(), N1,
2908                                 N0.getOperand(0), NotC);
2909       return CombineTo(N, Sub,
2910                        flipBoolean(Sub.getValue(1), DL, DAG, TLI));
2911     }
2912 
2913   // Iff the flag result is dead:
2914   // (addcarry (add|uaddo X, Y), 0, Carry) -> (addcarry X, Y, Carry)
2915   // Don't do this if the Carry comes from the uaddo. It won't remove the uaddo
2916   // or the dependency between the instructions.
2917   if ((N0.getOpcode() == ISD::ADD ||
2918        (N0.getOpcode() == ISD::UADDO && N0.getResNo() == 0 &&
2919         N0.getValue(1) != CarryIn)) &&
2920       isNullConstant(N1) && !N->hasAnyUseOfValue(1))
2921     return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(),
2922                        N0.getOperand(0), N0.getOperand(1), CarryIn);
2923 
2924   /**
2925    * When one of the addcarry arguments is itself a carry, we may be facing
2926    * a diamond carry propagation, in which case we try to transform the DAG
2927    * to ensure linear carry propagation if that is possible.
2928    */
2929   if (auto Y = getAsCarry(TLI, N1)) {
2930     // Because both are carries, Y and Z can be swapped.
2931     if (auto R = combineADDCARRYDiamond(*this, DAG, N0, Y, CarryIn, N))
2932       return R;
2933     if (auto R = combineADDCARRYDiamond(*this, DAG, N0, CarryIn, Y, N))
2934       return R;
2935   }
2936 
2937   return SDValue();
2938 }
2939 
2940 // Since it may not be valid to emit a fold to zero for vector initializers
2941 // check if we can before folding.
2942 static SDValue tryFoldToZero(const SDLoc &DL, const TargetLowering &TLI, EVT VT,
2943                              SelectionDAG &DAG, bool LegalOperations) {
2944   if (!VT.isVector())
2945     return DAG.getConstant(0, DL, VT);
2946   if (!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT))
2947     return DAG.getConstant(0, DL, VT);
2948   return SDValue();
2949 }
2950 
2951 SDValue DAGCombiner::visitSUB(SDNode *N) {
2952   SDValue N0 = N->getOperand(0);
2953   SDValue N1 = N->getOperand(1);
2954   EVT VT = N0.getValueType();
2955   SDLoc DL(N);
2956 
2957   // fold vector ops
2958   if (VT.isVector()) {
2959     if (SDValue FoldedVOp = SimplifyVBinOp(N))
2960       return FoldedVOp;
2961 
2962     // fold (sub x, 0) -> x, vector edition
2963     if (ISD::isBuildVectorAllZeros(N1.getNode()))
2964       return N0;
2965   }
2966 
2967   // fold (sub x, x) -> 0
2968   // FIXME: Refactor this and xor and other similar operations together.
2969   if (N0 == N1)
2970     return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
2971   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
2972       DAG.isConstantIntBuildVectorOrConstantInt(N1)) {
2973     // fold (sub c1, c2) -> c1-c2
2974     return DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, N0.getNode(),
2975                                       N1.getNode());
2976   }
2977 
2978   if (SDValue NewSel = foldBinOpIntoSelect(N))
2979     return NewSel;
2980 
2981   ConstantSDNode *N1C = getAsNonOpaqueConstant(N1);
2982 
2983   // fold (sub x, c) -> (add x, -c)
2984   if (N1C) {
2985     return DAG.getNode(ISD::ADD, DL, VT, N0,
2986                        DAG.getConstant(-N1C->getAPIntValue(), DL, VT));
2987   }
2988 
2989   if (isNullOrNullSplat(N0)) {
2990     unsigned BitWidth = VT.getScalarSizeInBits();
2991     // Right-shifting everything out but the sign bit followed by negation is
2992     // the same as flipping arithmetic/logical shift type without the negation:
2993     // -(X >>u 31) -> (X >>s 31)
2994     // -(X >>s 31) -> (X >>u 31)
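    // Worked example (illustrative, i32): if X < 0 then (X >>u 31) == 1 and
    // its negation is -1, which equals (X >>s 31); if X >= 0 both sides are
    // 0. The symmetric argument covers negating (X >>s 31).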
2995     if (N1->getOpcode() == ISD::SRA || N1->getOpcode() == ISD::SRL) {
2996       ConstantSDNode *ShiftAmt = isConstOrConstSplat(N1.getOperand(1));
2997       if (ShiftAmt && ShiftAmt->getAPIntValue() == (BitWidth - 1)) {
2998         auto NewSh = N1->getOpcode() == ISD::SRA ? ISD::SRL : ISD::SRA;
2999         if (!LegalOperations || TLI.isOperationLegal(NewSh, VT))
3000           return DAG.getNode(NewSh, DL, VT, N1.getOperand(0), N1.getOperand(1));
3001       }
3002     }
3003 
3004     // 0 - X --> 0 if the sub is NUW.
3005     if (N->getFlags().hasNoUnsignedWrap())
3006       return N0;
3007 
3008     if (DAG.MaskedValueIsZero(N1, ~APInt::getSignMask(BitWidth))) {
3009       // N1 is either 0 or the minimum signed value. If the sub is NSW, then
3010       // N1 must be 0 because negating the minimum signed value is undefined.
3011       if (N->getFlags().hasNoSignedWrap())
3012         return N0;
3013 
3014       // 0 - X --> X if X is 0 or the minimum signed value.
3015       return N1;
3016     }
3017   }
3018 
3019   // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1)
3020   if (isAllOnesOrAllOnesSplat(N0))
3021     return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
3022 
3023   // fold (A - (0-B)) -> A+B
3024   if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0)))
3025     return DAG.getNode(ISD::ADD, DL, VT, N0, N1.getOperand(1));
3026 
3027   // fold A-(A-B) -> B
3028   if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(0))
3029     return N1.getOperand(1);
3030 
3031   // fold (A+B)-A -> B
3032   if (N0.getOpcode() == ISD::ADD && N0.getOperand(0) == N1)
3033     return N0.getOperand(1);
3034 
3035   // fold (A+B)-B -> A
3036   if (N0.getOpcode() == ISD::ADD && N0.getOperand(1) == N1)
3037     return N0.getOperand(0);
3038 
3039   // fold (A+C1)-C2 -> A+(C1-C2)
3040   if (N0.getOpcode() == ISD::ADD &&
3041       isConstantOrConstantVector(N1, /* NoOpaques */ true) &&
3042       isConstantOrConstantVector(N0.getOperand(1), /* NoOpaques */ true)) {
3043     SDValue NewC = DAG.FoldConstantArithmetic(
3044         ISD::SUB, DL, VT, N0.getOperand(1).getNode(), N1.getNode());
3045     assert(NewC && "Constant folding failed");
3046     return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), NewC);
3047   }
3048 
3049   // fold C2-(A+C1) -> (C2-C1)-A
3050   if (N1.getOpcode() == ISD::ADD) {
3051     SDValue N11 = N1.getOperand(1);
3052     if (isConstantOrConstantVector(N0, /* NoOpaques */ true) &&
3053         isConstantOrConstantVector(N11, /* NoOpaques */ true)) {
3054       SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, N0.getNode(),
3055                                                 N11.getNode());
3056       assert(NewC && "Constant folding failed");
3057       return DAG.getNode(ISD::SUB, DL, VT, NewC, N1.getOperand(0));
3058     }
3059   }
3060 
3061   // fold (A-C1)-C2 -> A-(C1+C2)
3062   if (N0.getOpcode() == ISD::SUB &&
3063       isConstantOrConstantVector(N1, /* NoOpaques */ true) &&
3064       isConstantOrConstantVector(N0.getOperand(1), /* NoOpaques */ true)) {
3065     SDValue NewC = DAG.FoldConstantArithmetic(
3066         ISD::ADD, DL, VT, N0.getOperand(1).getNode(), N1.getNode());
3067     assert(NewC && "Constant folding failed");
3068     return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), NewC);
3069   }
3070 
3071   // fold (c1-A)-c2 -> (c1-c2)-A
3072   if (N0.getOpcode() == ISD::SUB &&
3073       isConstantOrConstantVector(N1, /* NoOpaques */ true) &&
3074       isConstantOrConstantVector(N0.getOperand(0), /* NoOpaques */ true)) {
3075     SDValue NewC = DAG.FoldConstantArithmetic(
3076         ISD::SUB, DL, VT, N0.getOperand(0).getNode(), N1.getNode());
3077     assert(NewC && "Constant folding failed");
3078     return DAG.getNode(ISD::SUB, DL, VT, NewC, N0.getOperand(1));
3079   }
3080 
3081   // fold ((A+(B+or-C))-B) -> A+or-C
3082   if (N0.getOpcode() == ISD::ADD &&
3083       (N0.getOperand(1).getOpcode() == ISD::SUB ||
3084        N0.getOperand(1).getOpcode() == ISD::ADD) &&
3085       N0.getOperand(1).getOperand(0) == N1)
3086     return DAG.getNode(N0.getOperand(1).getOpcode(), DL, VT, N0.getOperand(0),
3087                        N0.getOperand(1).getOperand(1));
3088 
3089   // fold ((A+(C+B))-B) -> A+C
3090   if (N0.getOpcode() == ISD::ADD && N0.getOperand(1).getOpcode() == ISD::ADD &&
3091       N0.getOperand(1).getOperand(1) == N1)
3092     return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0),
3093                        N0.getOperand(1).getOperand(0));
3094 
3095   // fold ((A-(B-C))-C) -> A-B
3096   if (N0.getOpcode() == ISD::SUB && N0.getOperand(1).getOpcode() == ISD::SUB &&
3097       N0.getOperand(1).getOperand(1) == N1)
3098     return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0),
3099                        N0.getOperand(1).getOperand(0));
3100 
3101   // fold (A-(B-C)) -> A+(C-B)
3102   if (N1.getOpcode() == ISD::SUB && N1.hasOneUse())
3103     return DAG.getNode(ISD::ADD, DL, VT, N0,
3104                        DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(1),
3105                                    N1.getOperand(0)));
3106 
3107   // A - (A & B)  ->  A & (~B)
3108   if (N1.getOpcode() == ISD::AND) {
3109     SDValue A = N1.getOperand(0);
3110     SDValue B = N1.getOperand(1);
3111     if (A != N0)
3112       std::swap(A, B);
3113     if (A == N0 &&
3114         (N1.hasOneUse() || isConstantOrConstantVector(B, /*NoOpaques=*/true))) {
3115       SDValue InvB =
3116           DAG.getNode(ISD::XOR, DL, VT, B, DAG.getAllOnesConstant(DL, VT));
3117       return DAG.getNode(ISD::AND, DL, VT, A, InvB);
3118     }
3119   }
3120 
3121   // fold (X - (-Y * Z)) -> (X + (Y * Z))
3122   if (N1.getOpcode() == ISD::MUL && N1.hasOneUse()) {
3123     if (N1.getOperand(0).getOpcode() == ISD::SUB &&
3124         isNullOrNullSplat(N1.getOperand(0).getOperand(0))) {
3125       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT,
3126                                 N1.getOperand(0).getOperand(1),
3127                                 N1.getOperand(1));
3128       return DAG.getNode(ISD::ADD, DL, VT, N0, Mul);
3129     }
3130     if (N1.getOperand(1).getOpcode() == ISD::SUB &&
3131         isNullOrNullSplat(N1.getOperand(1).getOperand(0))) {
3132       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT,
3133                                 N1.getOperand(0),
3134                                 N1.getOperand(1).getOperand(1));
3135       return DAG.getNode(ISD::ADD, DL, VT, N0, Mul);
3136     }
3137   }
3138 
3139   // If either operand of a sub is undef, the result is undef
3140   if (N0.isUndef())
3141     return N0;
3142   if (N1.isUndef())
3143     return N1;
3144 
3145   if (SDValue V = foldAddSubBoolOfMaskedVal(N, DAG))
3146     return V;
3147 
3148   if (SDValue V = foldAddSubOfSignBit(N, DAG))
3149     return V;
3150 
3151   if (SDValue V = foldAddSubMasked1(false, N0, N1, DAG, SDLoc(N)))
3152     return V;
3153 
3154   // (x - y) - 1  ->  add (xor y, -1), x
3155   if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB && isOneOrOneSplat(N1)) {
3156     SDValue Xor = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(1),
3157                               DAG.getAllOnesConstant(DL, VT));
3158     return DAG.getNode(ISD::ADD, DL, VT, Xor, N0.getOperand(0));
3159   }
3160 
3161   // Look for:
3162   //   sub y, (xor x, -1)
3163   // And if the target does not like this form then turn into:
3164   //   add (add x, y), 1
3165   if (TLI.preferIncOfAddToSubOfNot(VT) && N1.hasOneUse() && isBitwiseNot(N1)) {
3166     SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, N1.getOperand(0));
3167     return DAG.getNode(ISD::ADD, DL, VT, Add, DAG.getConstant(1, DL, VT));
3168   }
3169 
3170   // Hoist one-use addition by non-opaque constant:
3171   //   (x + C) - y  ->  (x - y) + C
3172   if (N0.hasOneUse() && N0.getOpcode() == ISD::ADD &&
3173       isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
3174     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), N1);
3175     return DAG.getNode(ISD::ADD, DL, VT, Sub, N0.getOperand(1));
3176   }
3177   // y - (x + C)  ->  (y - x) - C
3178   if (N1.hasOneUse() && N1.getOpcode() == ISD::ADD &&
3179       isConstantOrConstantVector(N1.getOperand(1), /*NoOpaques=*/true)) {
3180     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, N1.getOperand(0));
3181     return DAG.getNode(ISD::SUB, DL, VT, Sub, N1.getOperand(1));
3182   }
3183   // (x - C) - y  ->  (x - y) - C
3184   // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
3185   if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB &&
3186       isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
3187     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), N1);
3188     return DAG.getNode(ISD::SUB, DL, VT, Sub, N0.getOperand(1));
3189   }
3190   // (C - x) - y  ->  C - (x + y)
3191   if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB &&
3192       isConstantOrConstantVector(N0.getOperand(0), /*NoOpaques=*/true)) {
3193     SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1), N1);
3194     return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), Add);
3195   }
3196 
3197   // If the target's bool is represented as 0/-1, prefer to make this 'add 0/-1'
3198   // rather than 'sub 0/1' (the sext should get folded).
3199   // sub X, (zext i1 Y) --> add X, (sext i1 Y)
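  // Worked example (illustrative): for Y == 1, (sub X, (zext Y)) == X - 1
  // and (add X, (sext Y)) == X + (-1); for Y == 0 both are X. On targets
  // with 0/-1 booleans the sext typically folds away entirely.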
3200   if (N1.getOpcode() == ISD::ZERO_EXTEND &&
3201       N1.getOperand(0).getScalarValueSizeInBits() == 1 &&
3202       TLI.getBooleanContents(VT) ==
3203           TargetLowering::ZeroOrNegativeOneBooleanContent) {
3204     SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N1.getOperand(0));
3205     return DAG.getNode(ISD::ADD, DL, VT, N0, SExt);
3206   }
3207 
3208   // fold Y = sra (X, size(X)-1); sub (xor (X, Y), Y) -> (abs X)
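  // Worked example (illustrative, i8): for X = -5, Y = (X >>s 7) = -1,
  // (xor X, Y) = 4, and 4 - (-1) = 5 = |X|; for X >= 0, Y = 0 and the
  // expression reduces to X - 0 = X.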
3209   if (TLI.isOperationLegalOrCustom(ISD::ABS, VT)) {
3210     if (N0.getOpcode() == ISD::XOR && N1.getOpcode() == ISD::SRA) {
3211       SDValue X0 = N0.getOperand(0), X1 = N0.getOperand(1);
3212       SDValue S0 = N1.getOperand(0);
3213       if ((X0 == S0 && X1 == N1) || (X0 == N1 && X1 == S0)) {
3214         unsigned OpSizeInBits = VT.getScalarSizeInBits();
3215         if (ConstantSDNode *C = isConstOrConstSplat(N1.getOperand(1)))
3216           if (C->getAPIntValue() == (OpSizeInBits - 1))
3217             return DAG.getNode(ISD::ABS, SDLoc(N), VT, S0);
3218       }
3219     }
3220   }
3221 
3222   // If the relocation model supports it, consider symbol offsets.
3223   if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(N0))
3224     if (!LegalOperations && TLI.isOffsetFoldingLegal(GA)) {
3225       // fold (sub Sym, c) -> Sym-c
3226       if (N1C && GA->getOpcode() == ISD::GlobalAddress)
3227         return DAG.getGlobalAddress(GA->getGlobal(), SDLoc(N1C), VT,
3228                                     GA->getOffset() -
3229                                         (uint64_t)N1C->getSExtValue());
3230       // fold (sub Sym+c1, Sym+c2) -> c1-c2
3231       if (GlobalAddressSDNode *GB = dyn_cast<GlobalAddressSDNode>(N1))
3232         if (GA->getGlobal() == GB->getGlobal())
3233           return DAG.getConstant((uint64_t)GA->getOffset() - GB->getOffset(),
3234                                  DL, VT);
3235     }
3236 
3237   // sub X, (sextinreg Y i1) -> add X, (and Y 1)
3238   if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
3239     VTSDNode *TN = cast<VTSDNode>(N1.getOperand(1));
3240     if (TN->getVT() == MVT::i1) {
3241       SDValue ZExt = DAG.getNode(ISD::AND, DL, VT, N1.getOperand(0),
3242                                  DAG.getConstant(1, DL, VT));
3243       return DAG.getNode(ISD::ADD, DL, VT, N0, ZExt);
3244     }
3245   }
3246 
3247   // Prefer an add for more folding potential and possibly better codegen:
3248   // sub N0, (lshr N10, width-1) --> add N0, (ashr N10, width-1)
3249   if (!LegalOperations && N1.getOpcode() == ISD::SRL && N1.hasOneUse()) {
3250     SDValue ShAmt = N1.getOperand(1);
3251     ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt);
3252     if (ShAmtC &&
3253         ShAmtC->getAPIntValue() == (N1.getScalarValueSizeInBits() - 1)) {
3254       SDValue SRA = DAG.getNode(ISD::SRA, DL, VT, N1.getOperand(0), ShAmt);
3255       return DAG.getNode(ISD::ADD, DL, VT, N0, SRA);
3256     }
3257   }
3258 
3259   if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT)) {
3260     // (sub Carry, X)  ->  (addcarry (sub 0, X), 0, Carry)
3261     if (SDValue Carry = getAsCarry(TLI, N0)) {
3262       SDValue X = N1;
3263       SDValue Zero = DAG.getConstant(0, DL, VT);
3264       SDValue NegX = DAG.getNode(ISD::SUB, DL, VT, Zero, X);
3265       return DAG.getNode(ISD::ADDCARRY, DL,
3266                          DAG.getVTList(VT, Carry.getValueType()), NegX, Zero,
3267                          Carry);
3268     }
3269   }
3270 
3271   return SDValue();
3272 }
3273 
3274 SDValue DAGCombiner::visitSUBSAT(SDNode *N) {
3275   SDValue N0 = N->getOperand(0);
3276   SDValue N1 = N->getOperand(1);
3277   EVT VT = N0.getValueType();
3278   SDLoc DL(N);
3279 
3280   // fold vector ops
3281   if (VT.isVector()) {
3282     // TODO SimplifyVBinOp
3283 
3284     // fold (sub_sat x, 0) -> x, vector edition
3285     if (ISD::isBuildVectorAllZeros(N1.getNode()))
3286       return N0;
3287   }
3288 
3289   // fold (sub_sat x, undef) -> 0
3290   if (N0.isUndef() || N1.isUndef())
3291     return DAG.getConstant(0, DL, VT);
3292 
3293   // fold (sub_sat x, x) -> 0
3294   if (N0 == N1)
3295     return DAG.getConstant(0, DL, VT);
3296 
3297   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
3298       DAG.isConstantIntBuildVectorOrConstantInt(N1)) {
3299     // fold (sub_sat c1, c2) -> c3
3300     return DAG.FoldConstantArithmetic(N->getOpcode(), DL, VT, N0.getNode(),
3301                                       N1.getNode());
3302   }
3303 
3304   // fold (sub_sat x, 0) -> x
3305   if (isNullConstant(N1))
3306     return N0;
3307 
3308   return SDValue();
3309 }
3310 
3311 SDValue DAGCombiner::visitSUBC(SDNode *N) {
3312   SDValue N0 = N->getOperand(0);
3313   SDValue N1 = N->getOperand(1);
3314   EVT VT = N0.getValueType();
3315   SDLoc DL(N);
3316 
3317   // If the flag result is dead, turn this into a SUB.
3318   if (!N->hasAnyUseOfValue(1))
3319     return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
3320                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3321 
3322   // fold (subc x, x) -> 0 + no borrow
3323   if (N0 == N1)
3324     return CombineTo(N, DAG.getConstant(0, DL, VT),
3325                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3326 
3327   // fold (subc x, 0) -> x + no borrow
3328   if (isNullConstant(N1))
3329     return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3330 
3331   // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1) + no borrow
3332   if (isAllOnesConstant(N0))
3333     return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0),
3334                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3335 
3336   return SDValue();
3337 }
3338 
3339 SDValue DAGCombiner::visitSUBO(SDNode *N) {
3340   SDValue N0 = N->getOperand(0);
3341   SDValue N1 = N->getOperand(1);
3342   EVT VT = N0.getValueType();
3343   bool IsSigned = (ISD::SSUBO == N->getOpcode());
3344 
3345   EVT CarryVT = N->getValueType(1);
3346   SDLoc DL(N);
3347 
3348   // If the flag result is dead, turn this into a SUB.
3349   if (!N->hasAnyUseOfValue(1))
3350     return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
3351                      DAG.getUNDEF(CarryVT));
3352 
3353   // fold (subo x, x) -> 0 + no borrow
3354   if (N0 == N1)
3355     return CombineTo(N, DAG.getConstant(0, DL, VT),
3356                      DAG.getConstant(0, DL, CarryVT));
3357 
3358   ConstantSDNode *N1C = getAsNonOpaqueConstant(N1);
3359 
3360   // fold (subo x, c) -> (addo x, -c)
3361   if (IsSigned && N1C && !N1C->getAPIntValue().isMinSignedValue()) {
3362     return DAG.getNode(ISD::SADDO, DL, N->getVTList(), N0,
3363                        DAG.getConstant(-N1C->getAPIntValue(), DL, VT));
3364   }
3365 
3366   // fold (subo x, 0) -> x + no borrow
3367   if (isNullOrNullSplat(N1))
3368     return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT));
3369 
3370   // Canonicalize (usubo -1, x) -> ~x, i.e. (xor x, -1) + no borrow
3371   if (!IsSigned && isAllOnesOrAllOnesSplat(N0))
3372     return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0),
3373                      DAG.getConstant(0, DL, CarryVT));
3374 
3375   return SDValue();
3376 }
3377 
3378 SDValue DAGCombiner::visitSUBE(SDNode *N) {
3379   SDValue N0 = N->getOperand(0);
3380   SDValue N1 = N->getOperand(1);
3381   SDValue CarryIn = N->getOperand(2);
3382 
3383   // fold (sube x, y, false) -> (subc x, y)
3384   if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
3385     return DAG.getNode(ISD::SUBC, SDLoc(N), N->getVTList(), N0, N1);
3386 
3387   return SDValue();
3388 }
3389 
3390 SDValue DAGCombiner::visitSUBCARRY(SDNode *N) {
3391   SDValue N0 = N->getOperand(0);
3392   SDValue N1 = N->getOperand(1);
3393   SDValue CarryIn = N->getOperand(2);
3394 
3395   // fold (subcarry x, y, false) -> (usubo x, y)
3396   if (isNullConstant(CarryIn)) {
3397     if (!LegalOperations ||
3398         TLI.isOperationLegalOrCustom(ISD::USUBO, N->getValueType(0)))
3399       return DAG.getNode(ISD::USUBO, SDLoc(N), N->getVTList(), N0, N1);
3400   }
3401 
3402   return SDValue();
3403 }
3404 
3405 // Notice that "mulfix" can be any of SMULFIX, SMULFIXSAT, UMULFIX and
3406 // UMULFIXSAT here.
3407 SDValue DAGCombiner::visitMULFIX(SDNode *N) {
3408   SDValue N0 = N->getOperand(0);
3409   SDValue N1 = N->getOperand(1);
3410   SDValue Scale = N->getOperand(2);
3411   EVT VT = N0.getValueType();
3412 
3413   // fold (mulfix x, undef, scale) -> 0
3414   if (N0.isUndef() || N1.isUndef())
3415     return DAG.getConstant(0, SDLoc(N), VT);
3416 
3417   // Canonicalize constant to RHS (vector doesn't have to splat)
3418   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
3419      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
3420     return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0, Scale);
3421 
3422   // fold (mulfix x, 0, scale) -> 0
3423   if (isNullConstant(N1))
3424     return DAG.getConstant(0, SDLoc(N), VT);
3425 
3426   return SDValue();
3427 }
3428 
3429 SDValue DAGCombiner::visitMUL(SDNode *N) {
3430   SDValue N0 = N->getOperand(0);
3431   SDValue N1 = N->getOperand(1);
3432   EVT VT = N0.getValueType();
3433 
3434   // fold (mul x, undef) -> 0
3435   if (N0.isUndef() || N1.isUndef())
3436     return DAG.getConstant(0, SDLoc(N), VT);
3437 
3438   bool N0IsConst = false;
3439   bool N1IsConst = false;
3440   bool N1IsOpaqueConst = false;
3441   bool N0IsOpaqueConst = false;
3442   APInt ConstValue0, ConstValue1;
3443   // fold vector ops
3444   if (VT.isVector()) {
3445     if (SDValue FoldedVOp = SimplifyVBinOp(N))
3446       return FoldedVOp;
3447 
3448     N0IsConst = ISD::isConstantSplatVector(N0.getNode(), ConstValue0);
3449     N1IsConst = ISD::isConstantSplatVector(N1.getNode(), ConstValue1);
3450     assert((!N0IsConst ||
3451             ConstValue0.getBitWidth() == VT.getScalarSizeInBits()) &&
3452            "Splat APInt should be element width");
3453     assert((!N1IsConst ||
3454             ConstValue1.getBitWidth() == VT.getScalarSizeInBits()) &&
3455            "Splat APInt should be element width");
3456   } else {
3457     N0IsConst = isa<ConstantSDNode>(N0);
3458     if (N0IsConst) {
3459       ConstValue0 = cast<ConstantSDNode>(N0)->getAPIntValue();
3460       N0IsOpaqueConst = cast<ConstantSDNode>(N0)->isOpaque();
3461     }
3462     N1IsConst = isa<ConstantSDNode>(N1);
3463     if (N1IsConst) {
3464       ConstValue1 = cast<ConstantSDNode>(N1)->getAPIntValue();
3465       N1IsOpaqueConst = cast<ConstantSDNode>(N1)->isOpaque();
3466     }
3467   }
3468 
3469   // fold (mul c1, c2) -> c1*c2
3470   if (N0IsConst && N1IsConst && !N0IsOpaqueConst && !N1IsOpaqueConst)
3471     return DAG.FoldConstantArithmetic(ISD::MUL, SDLoc(N), VT,
3472                                       N0.getNode(), N1.getNode());
3473 
3474   // canonicalize constant to RHS (vector doesn't have to splat)
3475   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
3476      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
3477     return DAG.getNode(ISD::MUL, SDLoc(N), VT, N1, N0);
3478   // fold (mul x, 0) -> 0
3479   if (N1IsConst && ConstValue1.isNullValue())
3480     return N1;
3481   // fold (mul x, 1) -> x
3482   if (N1IsConst && ConstValue1.isOneValue())
3483     return N0;
3484 
3485   if (SDValue NewSel = foldBinOpIntoSelect(N))
3486     return NewSel;
3487 
3488   // fold (mul x, -1) -> 0-x
3489   if (N1IsConst && ConstValue1.isAllOnesValue()) {
3490     SDLoc DL(N);
3491     return DAG.getNode(ISD::SUB, DL, VT,
3492                        DAG.getConstant(0, DL, VT), N0);
3493   }
3494   // fold (mul x, (1 << c)) -> x << c
3495   if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
3496       DAG.isKnownToBeAPowerOfTwo(N1) &&
3497       (!VT.isVector() || Level <= AfterLegalizeVectorOps)) {
3498     SDLoc DL(N);
3499     SDValue LogBase2 = BuildLogBase2(N1, DL);
3500     EVT ShiftVT = getShiftAmountTy(N0.getValueType());
3501     SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
3502     return DAG.getNode(ISD::SHL, DL, VT, N0, Trunc);
3503   }
3504   // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
3505   if (N1IsConst && !N1IsOpaqueConst && (-ConstValue1).isPowerOf2()) {
3506     unsigned Log2Val = (-ConstValue1).logBase2();
3507     SDLoc DL(N);
3508     // FIXME: If the input is something that is easily negated (e.g. a
3509     // single-use add), we should put the negate there.
3510     return DAG.getNode(ISD::SUB, DL, VT,
3511                        DAG.getConstant(0, DL, VT),
3512                        DAG.getNode(ISD::SHL, DL, VT, N0,
3513                             DAG.getConstant(Log2Val, DL,
3514                                       getShiftAmountTy(N0.getValueType()))));
3515   }
3516 
3517   // Try to transform multiply-by-(power-of-2 +/- 1) into shift and add/sub.
3518   // mul x, (2^N + 1) --> add (shl x, N), x
3519   // mul x, (2^N - 1) --> sub (shl x, N), x
3520   // Examples: x * 33 --> (x << 5) + x
3521   //           x * 15 --> (x << 4) - x
3522   //           x * -33 --> -((x << 5) + x)
3523   //           x * -15 --> -((x << 4) - x) ; this reduces --> x - (x << 4)
3524   if (N1IsConst && TLI.decomposeMulByConstant(*DAG.getContext(), VT, N1)) {
3525     // TODO: We could handle more general decomposition of any constant by
3526     //       having the target set a limit on number of ops and making a
3527     //       callback to determine that sequence (similar to sqrt expansion).
3528     unsigned MathOp = ISD::DELETED_NODE;
3529     APInt MulC = ConstValue1.abs();
3530     if ((MulC - 1).isPowerOf2())
3531       MathOp = ISD::ADD;
3532     else if ((MulC + 1).isPowerOf2())
3533       MathOp = ISD::SUB;
3534 
3535     if (MathOp != ISD::DELETED_NODE) {
3536       unsigned ShAmt =
3537           MathOp == ISD::ADD ? (MulC - 1).logBase2() : (MulC + 1).logBase2();
3538       assert(ShAmt < VT.getScalarSizeInBits() &&
3539              "multiply-by-constant generated out of bounds shift");
3540       SDLoc DL(N);
3541       SDValue Shl =
3542           DAG.getNode(ISD::SHL, DL, VT, N0, DAG.getConstant(ShAmt, DL, VT));
3543       SDValue R = DAG.getNode(MathOp, DL, VT, Shl, N0);
3544       if (ConstValue1.isNegative())
3545         R = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), R);
3546       return R;
3547     }
3548   }
3549 
3550   // (mul (shl X, c1), c2) -> (mul X, c2 << c1)
3551   if (N0.getOpcode() == ISD::SHL &&
3552       isConstantOrConstantVector(N1, /* NoOpaques */ true) &&
3553       isConstantOrConstantVector(N0.getOperand(1), /* NoOpaques */ true)) {
3554     SDValue C3 = DAG.getNode(ISD::SHL, SDLoc(N), VT, N1, N0.getOperand(1));
3555     if (isConstantOrConstantVector(C3))
3556       return DAG.getNode(ISD::MUL, SDLoc(N), VT, N0.getOperand(0), C3);
3557   }
3558 
3559   // Change (mul (shl X, C), Y) -> (shl (mul X, Y), C) when the shift has one
3560   // use.
3561   {
3562     SDValue Sh(nullptr, 0), Y(nullptr, 0);
3563 
3564     // Check for both (mul (shl X, C), Y)  and  (mul Y, (shl X, C)).
3565     if (N0.getOpcode() == ISD::SHL &&
3566         isConstantOrConstantVector(N0.getOperand(1)) &&
3567         N0.getNode()->hasOneUse()) {
3568       Sh = N0; Y = N1;
3569     } else if (N1.getOpcode() == ISD::SHL &&
3570                isConstantOrConstantVector(N1.getOperand(1)) &&
3571                N1.getNode()->hasOneUse()) {
3572       Sh = N1; Y = N0;
3573     }
3574 
3575     if (Sh.getNode()) {
3576       SDValue Mul = DAG.getNode(ISD::MUL, SDLoc(N), VT, Sh.getOperand(0), Y);
3577       return DAG.getNode(ISD::SHL, SDLoc(N), VT, Mul, Sh.getOperand(1));
3578     }
3579   }
3580 
3581   // fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2)
3582   if (DAG.isConstantIntBuildVectorOrConstantInt(N1) &&
3583       N0.getOpcode() == ISD::ADD &&
3584       DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1)) &&
3585       isMulAddWithConstProfitable(N, N0, N1))
3586       return DAG.getNode(ISD::ADD, SDLoc(N), VT,
3587                          DAG.getNode(ISD::MUL, SDLoc(N0), VT,
3588                                      N0.getOperand(0), N1),
3589                          DAG.getNode(ISD::MUL, SDLoc(N1), VT,
3590                                      N0.getOperand(1), N1));
3591 
3592   // reassociate mul
3593   if (SDValue RMUL = reassociateOps(ISD::MUL, SDLoc(N), N0, N1, N->getFlags()))
3594     return RMUL;
3595 
3596   return SDValue();
3597 }
3598 
3599 /// Return true if divmod libcall is available.
3600 static bool isDivRemLibcallAvailable(SDNode *Node, bool isSigned,
3601                                      const TargetLowering &TLI) {
3602   RTLIB::Libcall LC;
3603   EVT NodeType = Node->getValueType(0);
3604   if (!NodeType.isSimple())
3605     return false;
3606   switch (NodeType.getSimpleVT().SimpleTy) {
3607   default: return false; // No libcall for vector types.
3608   case MVT::i8:   LC= isSigned ? RTLIB::SDIVREM_I8  : RTLIB::UDIVREM_I8;  break;
3609   case MVT::i16:  LC= isSigned ? RTLIB::SDIVREM_I16 : RTLIB::UDIVREM_I16; break;
3610   case MVT::i32:  LC= isSigned ? RTLIB::SDIVREM_I32 : RTLIB::UDIVREM_I32; break;
3611   case MVT::i64:  LC= isSigned ? RTLIB::SDIVREM_I64 : RTLIB::UDIVREM_I64; break;
3612   case MVT::i128: LC= isSigned ? RTLIB::SDIVREM_I128:RTLIB::UDIVREM_I128; break;
3613   }
3614 
3615   return TLI.getLibcallName(LC) != nullptr;
3616 }
3617 
3618 /// Issue divrem if both quotient and remainder are needed.
3619 SDValue DAGCombiner::useDivRem(SDNode *Node) {
3620   if (Node->use_empty())
3621     return SDValue(); // This is a dead node, leave it alone.
3622 
3623   unsigned Opcode = Node->getOpcode();
3624   bool isSigned = (Opcode == ISD::SDIV) || (Opcode == ISD::SREM);
3625   unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM;
3626 
3627   // DivMod lib calls can still work on non-legal types, since they are lowered to lib-calls.
3628   EVT VT = Node->getValueType(0);
3629   if (VT.isVector() || !VT.isInteger())
3630     return SDValue();
3631 
3632   if (!TLI.isTypeLegal(VT) && !TLI.isOperationCustom(DivRemOpc, VT))
3633     return SDValue();
3634 
3635   // If DIVREM is going to get expanded into a libcall,
3636   // but there is no libcall available, then don't combine.
3637   if (!TLI.isOperationLegalOrCustom(DivRemOpc, VT) &&
3638       !isDivRemLibcallAvailable(Node, isSigned, TLI))
3639     return SDValue();
3640 
3641   // If div is legal, it's better to do the normal expansion
3642   unsigned OtherOpcode = 0;
3643   if ((Opcode == ISD::SDIV) || (Opcode == ISD::UDIV)) {
3644     OtherOpcode = isSigned ? ISD::SREM : ISD::UREM;
3645     if (TLI.isOperationLegalOrCustom(Opcode, VT))
3646       return SDValue();
3647   } else {
3648     OtherOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
3649     if (TLI.isOperationLegalOrCustom(OtherOpcode, VT))
3650       return SDValue();
3651   }
3652 
3653   SDValue Op0 = Node->getOperand(0);
3654   SDValue Op1 = Node->getOperand(1);
3655   SDValue combined;
3656   for (SDNode::use_iterator UI = Op0.getNode()->use_begin(),
3657          UE = Op0.getNode()->use_end(); UI != UE; ++UI) {
3658     SDNode *User = *UI;
3659     if (User == Node || User->getOpcode() == ISD::DELETED_NODE ||
3660         User->use_empty())
3661       continue;
3662     // Convert the other matching node(s), too;
3663     // otherwise, the DIVREM may get target-legalized into something
3664     // target-specific that we won't be able to recognize.
3665     unsigned UserOpc = User->getOpcode();
3666     if ((UserOpc == Opcode || UserOpc == OtherOpcode || UserOpc == DivRemOpc) &&
3667         User->getOperand(0) == Op0 &&
3668         User->getOperand(1) == Op1) {
3669       if (!combined) {
3670         if (UserOpc == OtherOpcode) {
3671           SDVTList VTs = DAG.getVTList(VT, VT);
3672           combined = DAG.getNode(DivRemOpc, SDLoc(Node), VTs, Op0, Op1);
3673         } else if (UserOpc == DivRemOpc) {
3674           combined = SDValue(User, 0);
3675         } else {
3676           assert(UserOpc == Opcode);
3677           continue;
3678         }
3679       }
3680       if (UserOpc == ISD::SDIV || UserOpc == ISD::UDIV)
3681         CombineTo(User, combined);
3682       else if (UserOpc == ISD::SREM || UserOpc == ISD::UREM)
3683         CombineTo(User, combined.getValue(1));
3684     }
3685   }
3686   return combined;
3687 }
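// Usage sketch (illustrative): if a function computes both a / b and a % b,
// the matching SDIV and SREM nodes are replaced by the two results of one
// (sdivrem a, b), so a single instruction or libcall serves both users.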
3688 
3689 static SDValue simplifyDivRem(SDNode *N, SelectionDAG &DAG) {
3690   SDValue N0 = N->getOperand(0);
3691   SDValue N1 = N->getOperand(1);
3692   EVT VT = N->getValueType(0);
3693   SDLoc DL(N);
3694 
3695   unsigned Opc = N->getOpcode();
3696   bool IsDiv = (ISD::SDIV == Opc) || (ISD::UDIV == Opc);
3697   ConstantSDNode *N1C = isConstOrConstSplat(N1);
3698 
3699   // X / undef -> undef
3700   // X % undef -> undef
3701   // X / 0 -> undef
3702   // X % 0 -> undef
3703   // NOTE: This includes vectors where any divisor element is zero/undef.
3704   if (DAG.isUndef(Opc, {N0, N1}))
3705     return DAG.getUNDEF(VT);
3706 
3707   // undef / X -> 0
3708   // undef % X -> 0
3709   if (N0.isUndef())
3710     return DAG.getConstant(0, DL, VT);
3711 
3712   // 0 / X -> 0
3713   // 0 % X -> 0
3714   ConstantSDNode *N0C = isConstOrConstSplat(N0);
3715   if (N0C && N0C->isNullValue())
3716     return N0;
3717 
3718   // X / X -> 1
3719   // X % X -> 0
3720   if (N0 == N1)
3721     return DAG.getConstant(IsDiv ? 1 : 0, DL, VT);
3722 
3723   // X / 1 -> X
3724   // X % 1 -> 0
3725   // If this is a boolean op (single-bit element type), we can't have
3726   // division-by-zero or remainder-by-zero, so assume the divisor is 1.
3727   // TODO: Similarly, if we're zero-extending a boolean divisor, then assume
3728   // it's a 1.
3729   if ((N1C && N1C->isOne()) || (VT.getScalarType() == MVT::i1))
3730     return IsDiv ? N0 : DAG.getConstant(0, DL, VT);
3731 
3732   return SDValue();
3733 }
3734 
3735 SDValue DAGCombiner::visitSDIV(SDNode *N) {
3736   SDValue N0 = N->getOperand(0);
3737   SDValue N1 = N->getOperand(1);
3738   EVT VT = N->getValueType(0);
3739   EVT CCVT = getSetCCResultType(VT);
3740 
3741   // fold vector ops
3742   if (VT.isVector())
3743     if (SDValue FoldedVOp = SimplifyVBinOp(N))
3744       return FoldedVOp;
3745 
3746   SDLoc DL(N);
3747 
3748   // fold (sdiv c1, c2) -> c1/c2
3749   ConstantSDNode *N0C = isConstOrConstSplat(N0);
3750   ConstantSDNode *N1C = isConstOrConstSplat(N1);
3751   if (N0C && N1C && !N0C->isOpaque() && !N1C->isOpaque())
3752     return DAG.FoldConstantArithmetic(ISD::SDIV, DL, VT, N0C, N1C);
3753   // fold (sdiv X, -1) -> 0-X
3754   if (N1C && N1C->isAllOnesValue())
3755     return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), N0);
3756   // fold (sdiv X, MIN_SIGNED) -> select(X == MIN_SIGNED, 1, 0)
3757   if (N1C && N1C->getAPIntValue().isMinSignedValue())
3758     return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
3759                          DAG.getConstant(1, DL, VT),
3760                          DAG.getConstant(0, DL, VT));
3761 
3762   if (SDValue V = simplifyDivRem(N, DAG))
3763     return V;
3764 
3765   if (SDValue NewSel = foldBinOpIntoSelect(N))
3766     return NewSel;
3767 
3768   // If we know the sign bits of both operands are zero, strength reduce to a
3769   // udiv instead.  Handles (X&15) /s 4 -> X&15 >> 2
3770   if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
3771     return DAG.getNode(ISD::UDIV, DL, N1.getValueType(), N0, N1);
3772 
3773   if (SDValue V = visitSDIVLike(N0, N1, N)) {
3774     // If the corresponding remainder node exists, update its users with
3775     // (Dividend - (Quotient * Divisor)).
3776     if (SDNode *RemNode = DAG.getNodeIfExists(ISD::SREM, N->getVTList(),
3777                                               { N0, N1 })) {
3778       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1);
3779       SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
3780       AddToWorklist(Mul.getNode());
3781       AddToWorklist(Sub.getNode());
3782       CombineTo(RemNode, Sub);
3783     }
3784     return V;
3785   }
3786 
3787   // sdiv, srem -> sdivrem
3788   // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
3789   // true.  Otherwise, we break the simplification logic in visitREM().
3790   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
3791   if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
3792     if (SDValue DivRem = useDivRem(N))
3793         return DivRem;
3794 
3795   return SDValue();
3796 }
3797 
3798 SDValue DAGCombiner::visitSDIVLike(SDValue N0, SDValue N1, SDNode *N) {
3799   SDLoc DL(N);
3800   EVT VT = N->getValueType(0);
3801   EVT CCVT = getSetCCResultType(VT);
3802   unsigned BitWidth = VT.getScalarSizeInBits();
3803 
3804   // Helper for determining whether a value is a power-of-2 constant scalar or a
3805   // vector of such elements.
3806   auto IsPowerOfTwo = [](ConstantSDNode *C) {
3807     if (C->isNullValue() || C->isOpaque())
3808       return false;
3809     if (C->getAPIntValue().isPowerOf2())
3810       return true;
3811     if ((-C->getAPIntValue()).isPowerOf2())
3812       return true;
3813     return false;
3814   };
3815 
3816   // fold (sdiv X, pow2) -> simple ops after legalize
3817   // FIXME: We check for the exact bit here because the generic lowering gives
3818   // better results in that case. The target-specific lowering should learn how
3819   // to handle exact sdivs efficiently.
3820   if (!N->getFlags().hasExact() && ISD::matchUnaryPredicate(N1, IsPowerOfTwo)) {
3821     // Target-specific implementation of sdiv x, pow2.
3822     if (SDValue Res = BuildSDIVPow2(N))
3823       return Res;
3824 
3825     // Create constants that are functions of the shift amount value.
3826     EVT ShiftAmtTy = getShiftAmountTy(N0.getValueType());
3827     SDValue Bits = DAG.getConstant(BitWidth, DL, ShiftAmtTy);
3828     SDValue C1 = DAG.getNode(ISD::CTTZ, DL, VT, N1);
3829     C1 = DAG.getZExtOrTrunc(C1, DL, ShiftAmtTy);
3830     SDValue Inexact = DAG.getNode(ISD::SUB, DL, ShiftAmtTy, Bits, C1);
3831     if (!isConstantOrConstantVector(Inexact))
3832       return SDValue();
3833 
3834     // Splat the sign bit into the register
3835     SDValue Sign = DAG.getNode(ISD::SRA, DL, VT, N0,
3836                                DAG.getConstant(BitWidth - 1, DL, ShiftAmtTy));
3837     AddToWorklist(Sign.getNode());
3838 
3839     // Add (N0 < 0) ? abs2 - 1 : 0;
3840     SDValue Srl = DAG.getNode(ISD::SRL, DL, VT, Sign, Inexact);
3841     AddToWorklist(Srl.getNode());
3842     SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, Srl);
3843     AddToWorklist(Add.getNode());
3844     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Add, C1);
3845     AddToWorklist(Sra.getNode());
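    // Worked example (illustrative, i8, N1 = 4): C1 = cttz(4) = 2 and
    // Inexact = 8 - 2 = 6. For N0 = -7: Sign = (-7 >>s 7) = -1, Srl =
    // (0xFF >>u 6) = 3 (= |N1| - 1), Add = -7 + 3 = -4, and Sra =
    // (-4 >>s 2) = -1, matching round-toward-zero -7 / 4 == -1.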
3846 
3847     // Special case: (sdiv X, 1) -> X
3848     // Special Case: (sdiv X, -1) -> 0-X
3849     SDValue One = DAG.getConstant(1, DL, VT);
3850     SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
3851     SDValue IsOne = DAG.getSetCC(DL, CCVT, N1, One, ISD::SETEQ);
3852     SDValue IsAllOnes = DAG.getSetCC(DL, CCVT, N1, AllOnes, ISD::SETEQ);
3853     SDValue IsOneOrAllOnes = DAG.getNode(ISD::OR, DL, CCVT, IsOne, IsAllOnes);
3854     Sra = DAG.getSelect(DL, VT, IsOneOrAllOnes, N0, Sra);
3855 
3856     // If dividing by a positive value, we're done. Otherwise, the result must
3857     // be negated.
3858     SDValue Zero = DAG.getConstant(0, DL, VT);
3859     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, Zero, Sra);
3860 
3861     // FIXME: Use SELECT_CC once we improve SELECT_CC constant-folding.
3862     SDValue IsNeg = DAG.getSetCC(DL, CCVT, N1, Zero, ISD::SETLT);
3863     SDValue Res = DAG.getSelect(DL, VT, IsNeg, Sub, Sra);
3864     return Res;
3865   }
3866 
3867   // If integer divide is expensive and we satisfy the requirements, emit an
3868   // alternate sequence.  Targets may check function attributes for size/speed
3869   // trade-offs.
3870   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
3871   if (isConstantOrConstantVector(N1) &&
3872       !TLI.isIntDivCheap(N->getValueType(0), Attr))
3873     if (SDValue Op = BuildSDIV(N))
3874       return Op;
3875 
3876   return SDValue();
3877 }
3878 
3879 SDValue DAGCombiner::visitUDIV(SDNode *N) {
3880   SDValue N0 = N->getOperand(0);
3881   SDValue N1 = N->getOperand(1);
3882   EVT VT = N->getValueType(0);
3883   EVT CCVT = getSetCCResultType(VT);
3884 
3885   // fold vector ops
3886   if (VT.isVector())
3887     if (SDValue FoldedVOp = SimplifyVBinOp(N))
3888       return FoldedVOp;
3889 
3890   SDLoc DL(N);
3891 
3892   // fold (udiv c1, c2) -> c1/c2
3893   ConstantSDNode *N0C = isConstOrConstSplat(N0);
3894   ConstantSDNode *N1C = isConstOrConstSplat(N1);
3895   if (N0C && N1C)
3896     if (SDValue Folded = DAG.FoldConstantArithmetic(ISD::UDIV, DL, VT,
3897                                                     N0C, N1C))
3898       return Folded;
3899   // fold (udiv X, -1) -> select(X == -1, 1, 0)
3900   if (N1C && N1C->getAPIntValue().isAllOnesValue())
3901     return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
3902                          DAG.getConstant(1, DL, VT),
3903                          DAG.getConstant(0, DL, VT));
3904 
3905   if (SDValue V = simplifyDivRem(N, DAG))
3906     return V;
3907 
3908   if (SDValue NewSel = foldBinOpIntoSelect(N))
3909     return NewSel;
3910 
3911   if (SDValue V = visitUDIVLike(N0, N1, N)) {
3912     // If the corresponding remainder node exists, update its users with
3913     // (Dividend - (Quotient * Divisor)).
3914     if (SDNode *RemNode = DAG.getNodeIfExists(ISD::UREM, N->getVTList(),
3915                                               { N0, N1 })) {
3916       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1);
3917       SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
3918       AddToWorklist(Mul.getNode());
3919       AddToWorklist(Sub.getNode());
3920       CombineTo(RemNode, Sub);
3921     }
3922     return V;
3923   }
3924 
3925   // udiv, urem -> udivrem
3926   // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
3927   // true.  Otherwise, we break the simplification logic in visitREM().
3928   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
3929   if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
3930     if (SDValue DivRem = useDivRem(N))
3931         return DivRem;
3932 
3933   return SDValue();
3934 }
3935 
3936 SDValue DAGCombiner::visitUDIVLike(SDValue N0, SDValue N1, SDNode *N) {
3937   SDLoc DL(N);
3938   EVT VT = N->getValueType(0);
3939 
3940   // fold (udiv x, (1 << c)) -> x >>u c
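  // E.g. (udiv i32 x, 8) becomes (srl x, 3).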
3941   if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
3942       DAG.isKnownToBeAPowerOfTwo(N1)) {
3943     SDValue LogBase2 = BuildLogBase2(N1, DL);
3944     AddToWorklist(LogBase2.getNode());
3945 
3946     EVT ShiftVT = getShiftAmountTy(N0.getValueType());
3947     SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
3948     AddToWorklist(Trunc.getNode());
3949     return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
3950   }
3951 
3952   // fold (udiv x, (shl c, y)) -> x >>u (log2(c)+y) iff c is power of 2
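  // E.g. (udiv i32 x, (shl 4, y)) becomes (srl x, (add y, 2)), since
  // log2(4) == 2.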
3953   if (N1.getOpcode() == ISD::SHL) {
3954     SDValue N10 = N1.getOperand(0);
3955     if (isConstantOrConstantVector(N10, /*NoOpaques*/ true) &&
3956         DAG.isKnownToBeAPowerOfTwo(N10)) {
3957       SDValue LogBase2 = BuildLogBase2(N10, DL);
3958       AddToWorklist(LogBase2.getNode());
3959 
3960       EVT ADDVT = N1.getOperand(1).getValueType();
3961       SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT);
3962       AddToWorklist(Trunc.getNode());
3963       SDValue Add = DAG.getNode(ISD::ADD, DL, ADDVT, N1.getOperand(1), Trunc);
3964       AddToWorklist(Add.getNode());
3965       return DAG.getNode(ISD::SRL, DL, VT, N0, Add);
3966     }
3967   }
3968 
3969   // fold (udiv x, c) -> alternate
3970   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
3971   if (isConstantOrConstantVector(N1) &&
3972       !TLI.isIntDivCheap(N->getValueType(0), Attr))
3973     if (SDValue Op = BuildUDIV(N))
3974       return Op;
3975 
3976   return SDValue();
3977 }
3978 
3979 // handles ISD::SREM and ISD::UREM
3980 SDValue DAGCombiner::visitREM(SDNode *N) {
3981   unsigned Opcode = N->getOpcode();
3982   SDValue N0 = N->getOperand(0);
3983   SDValue N1 = N->getOperand(1);
3984   EVT VT = N->getValueType(0);
3985   EVT CCVT = getSetCCResultType(VT);
3986 
3987   bool isSigned = (Opcode == ISD::SREM);
3988   SDLoc DL(N);
3989 
3990   // fold (rem c1, c2) -> c1%c2
3991   ConstantSDNode *N0C = isConstOrConstSplat(N0);
3992   ConstantSDNode *N1C = isConstOrConstSplat(N1);
3993   if (N0C && N1C)
3994     if (SDValue Folded = DAG.FoldConstantArithmetic(Opcode, DL, VT, N0C, N1C))
3995       return Folded;
3996   // fold (urem X, -1) -> select(X == -1, 0, X)
3997   if (!isSigned && N1C && N1C->getAPIntValue().isAllOnesValue())
3998     return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
3999                          DAG.getConstant(0, DL, VT), N0);
4000 
4001   if (SDValue V = simplifyDivRem(N, DAG))
4002     return V;
4003 
4004   if (SDValue NewSel = foldBinOpIntoSelect(N))
4005     return NewSel;
4006 
4007   if (isSigned) {
4008     // If we know the sign bits of both operands are zero, strength reduce to a
4009     // urem instead.  Handles (X & 0x0FFFFFFF) %s 16 -> X&15
4010     if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
4011       return DAG.getNode(ISD::UREM, DL, VT, N0, N1);
4012   } else {
4013     SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
4014     if (DAG.isKnownToBeAPowerOfTwo(N1)) {
4015       // fold (urem x, pow2) -> (and x, pow2-1)
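      // E.g. (urem i32 x, 16) becomes (and x, 15).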
4016       SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, NegOne);
4017       AddToWorklist(Add.getNode());
4018       return DAG.getNode(ISD::AND, DL, VT, N0, Add);
4019     }
4020     if (N1.getOpcode() == ISD::SHL &&
4021         DAG.isKnownToBeAPowerOfTwo(N1.getOperand(0))) {
4022       // fold (urem x, (shl pow2, y)) -> (and x, (add (shl pow2, y), -1))
4023       SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, NegOne);
4024       AddToWorklist(Add.getNode());
4025       return DAG.getNode(ISD::AND, DL, VT, N0, Add);
4026     }
4027   }
4028 
4029   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4030 
4031   // If X/C can be simplified by the division-by-constant logic, lower
4032   // X%C to the equivalent of X-X/C*C.
4033   // Reuse the SDIVLike/UDIVLike combines - to avoid mangling nodes, the
4034   // speculative DIV must not cause a DIVREM conversion.  We guard against this
4035   // by skipping the simplification if isIntDivCheap().  When div is not cheap,
4036   // combine will not return a DIVREM.  Regardless, checking cheapness here
4037   // makes sense since the simplification results in fatter code.
4038   if (DAG.isKnownNeverZero(N1) && !TLI.isIntDivCheap(VT, Attr)) {
4039     SDValue OptimizedDiv =
4040         isSigned ? visitSDIVLike(N0, N1, N) : visitUDIVLike(N0, N1, N);
4041     if (OptimizedDiv.getNode()) {
4042       // If the equivalent Div node also exists, update its users.
4043       unsigned DivOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
4044       if (SDNode *DivNode = DAG.getNodeIfExists(DivOpcode, N->getVTList(),
4045                                                 { N0, N1 }))
4046         CombineTo(DivNode, OptimizedDiv);
4047       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, OptimizedDiv, N1);
4048       SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
4049       AddToWorklist(OptimizedDiv.getNode());
4050       AddToWorklist(Mul.getNode());
4051       return Sub;
4052     }
4053   }
4054 
4055   // sdiv, srem -> sdivrem (or udiv, urem -> udivrem)
4056   if (SDValue DivRem = useDivRem(N))
4057     return DivRem.getValue(1);
4058 
4059   return SDValue();
4060 }
4061 
4062 SDValue DAGCombiner::visitMULHS(SDNode *N) {
4063   SDValue N0 = N->getOperand(0);
4064   SDValue N1 = N->getOperand(1);
4065   EVT VT = N->getValueType(0);
4066   SDLoc DL(N);
4067 
4068   if (VT.isVector()) {
4069     // fold (mulhs x, 0) -> 0
4070     // do not return N0/N1, because an undef node may exist.
4071     if (ISD::isBuildVectorAllZeros(N0.getNode()) ||
4072         ISD::isBuildVectorAllZeros(N1.getNode()))
4073       return DAG.getConstant(0, DL, VT);
4074   }
4075 
4076   // fold (mulhs x, 0) -> 0
4077   if (isNullConstant(N1))
4078     return N1;
4079   // fold (mulhs x, 1) -> (sra x, size(x)-1)
4080   if (isOneConstant(N1))
4081     return DAG.getNode(ISD::SRA, DL, N0.getValueType(), N0,
4082                        DAG.getConstant(N0.getScalarValueSizeInBits() - 1, DL,
4083                                        getShiftAmountTy(N0.getValueType())));
4084 
4085   // fold (mulhs x, undef) -> 0
4086   if (N0.isUndef() || N1.isUndef())
4087     return DAG.getConstant(0, DL, VT);
4088 
4089   // If the type twice as wide is legal, transform the mulhs to a wider multiply
4090   // plus a shift.
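  // E.g. for i16 with a legal i32 multiply: (mulhs a, b) becomes
  // (trunc (srl (mul (sext a), (sext b)), 16)).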
4091   if (VT.isSimple() && !VT.isVector()) {
4092     MVT Simple = VT.getSimpleVT();
4093     unsigned SimpleSize = Simple.getSizeInBits();
4094     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
4095     if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
4096       N0 = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N0);
4097       N1 = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N1);
4098       N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
4099       N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
4100             DAG.getConstant(SimpleSize, DL,
4101                             getShiftAmountTy(N1.getValueType())));
4102       return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
4103     }
4104   }
4105 
4106   return SDValue();
4107 }
4108 
4109 SDValue DAGCombiner::visitMULHU(SDNode *N) {
4110   SDValue N0 = N->getOperand(0);
4111   SDValue N1 = N->getOperand(1);
4112   EVT VT = N->getValueType(0);
4113   SDLoc DL(N);
4114 
4115   if (VT.isVector()) {
4116     // fold (mulhu x, 0) -> 0
4117     // do not return N0/N1, because an undef node may exist.
4118     if (ISD::isBuildVectorAllZeros(N0.getNode()) ||
4119         ISD::isBuildVectorAllZeros(N1.getNode()))
4120       return DAG.getConstant(0, DL, VT);
4121   }
4122 
4123   // fold (mulhu x, 0) -> 0
4124   if (isNullConstant(N1))
4125     return N1;
4126   // fold (mulhu x, 1) -> 0
4127   if (isOneConstant(N1))
4128     return DAG.getConstant(0, DL, N0.getValueType());
4129   // fold (mulhu x, undef) -> 0
4130   if (N0.isUndef() || N1.isUndef())
4131     return DAG.getConstant(0, DL, VT);
4132 
4133   // fold (mulhu x, (1 << c)) -> x >> (bitwidth - c)
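  // E.g. (mulhu i32 x, 8) becomes (srl x, 29): the high 32 bits of the
  // 64-bit product x * 8 are x >> (32 - 3).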
4134   if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
4135       DAG.isKnownToBeAPowerOfTwo(N1) && hasOperation(ISD::SRL, VT)) {
4136     unsigned NumEltBits = VT.getScalarSizeInBits();
4137     SDValue LogBase2 = BuildLogBase2(N1, DL);
4138     SDValue SRLAmt = DAG.getNode(
4139         ISD::SUB, DL, VT, DAG.getConstant(NumEltBits, DL, VT), LogBase2);
4140     EVT ShiftVT = getShiftAmountTy(N0.getValueType());
4141     SDValue Trunc = DAG.getZExtOrTrunc(SRLAmt, DL, ShiftVT);
4142     return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
4143   }
4144 
4145   // If the type twice as wide is legal, transform the mulhu to a wider multiply
4146   // plus a shift.
4147   if (VT.isSimple() && !VT.isVector()) {
4148     MVT Simple = VT.getSimpleVT();
4149     unsigned SimpleSize = Simple.getSizeInBits();
4150     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
4151     if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
4152       N0 = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N0);
4153       N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N1);
4154       N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
4155       N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
4156             DAG.getConstant(SimpleSize, DL,
4157                             getShiftAmountTy(N1.getValueType())));
4158       return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
4159     }
4160   }
4161 
4162   return SDValue();
4163 }
4164 
4165 /// Perform optimizations common to nodes that compute two values. LoOp and HiOp
4166 /// give the opcodes for the two computations that are being performed. Return
4167 /// the simplified value if a simplification was made, or a null SDValue otherwise.
4168 SDValue DAGCombiner::SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
4169                                                 unsigned HiOp) {
4170   // If the high half is not needed, just compute the low half.
4171   bool HiExists = N->hasAnyUseOfValue(1);
4172   if (!HiExists && (!LegalOperations ||
4173                     TLI.isOperationLegalOrCustom(LoOp, N->getValueType(0)))) {
4174     SDValue Res = DAG.getNode(LoOp, SDLoc(N), N->getValueType(0), N->ops());
4175     return CombineTo(N, Res, Res);
4176   }
4177 
4178   // If the low half is not needed, just compute the high half.
4179   bool LoExists = N->hasAnyUseOfValue(0);
4180   if (!LoExists && (!LegalOperations ||
4181                     TLI.isOperationLegalOrCustom(HiOp, N->getValueType(1)))) {
4182     SDValue Res = DAG.getNode(HiOp, SDLoc(N), N->getValueType(1), N->ops());
4183     return CombineTo(N, Res, Res);
4184   }
4185 
4186   // If both halves are used, there is nothing to simplify here.
4187   if (LoExists && HiExists)
4188     return SDValue();
4189 
4190   // If the two computed results can be simplified separately, separate them.
4191   if (LoExists) {
4192     SDValue Lo = DAG.getNode(LoOp, SDLoc(N), N->getValueType(0), N->ops());
4193     AddToWorklist(Lo.getNode());
4194     SDValue LoOpt = combine(Lo.getNode());
4195     if (LoOpt.getNode() && LoOpt.getNode() != Lo.getNode() &&
4196         (!LegalOperations ||
4197          TLI.isOperationLegalOrCustom(LoOpt.getOpcode(), LoOpt.getValueType())))
4198       return CombineTo(N, LoOpt, LoOpt);
4199   }
4200 
4201   if (HiExists) {
4202     SDValue Hi = DAG.getNode(HiOp, SDLoc(N), N->getValueType(1), N->ops());
4203     AddToWorklist(Hi.getNode());
4204     SDValue HiOpt = combine(Hi.getNode());
4205     if (HiOpt.getNode() && HiOpt != Hi &&
4206         (!LegalOperations ||
4207          TLI.isOperationLegalOrCustom(HiOpt.getOpcode(), HiOpt.getValueType())))
4208       return CombineTo(N, HiOpt, HiOpt);
4209   }
4210 
4211   return SDValue();
4212 }
4213 
4214 SDValue DAGCombiner::visitSMUL_LOHI(SDNode *N) {
4215   if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHS))
4216     return Res;
4217 
4218   EVT VT = N->getValueType(0);
4219   SDLoc DL(N);
4220 
4221   // If the type twice as wide is legal, transform the smul_lohi to a wider
4222   // multiply plus a shift.
4223   if (VT.isSimple() && !VT.isVector()) {
4224     MVT Simple = VT.getSimpleVT();
4225     unsigned SimpleSize = Simple.getSizeInBits();
4226     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
4227     if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
4228       SDValue Lo = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N->getOperand(0));
4229       SDValue Hi = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N->getOperand(1));
4230       Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
4231       // Compute the high part (result value 1).
4232       Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
4233             DAG.getConstant(SimpleSize, DL,
4234                             getShiftAmountTy(Lo.getValueType())));
4235       Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
4236       // Compute the low part (result value 0).
4237       Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
4238       return CombineTo(N, Lo, Hi);
4239     }
4240   }
4241 
4242   return SDValue();
4243 }
4244 
4245 SDValue DAGCombiner::visitUMUL_LOHI(SDNode *N) {
4246   if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHU))
4247     return Res;
4248 
4249   EVT VT = N->getValueType(0);
4250   SDLoc DL(N);
4251 
4252   // (umul_lohi N0, 0) -> (0, 0)
4253   if (isNullConstant(N->getOperand(1))) {
4254     SDValue Zero = DAG.getConstant(0, DL, VT);
4255     return CombineTo(N, Zero, Zero);
4256   }
4257 
4258   // (umul_lohi N0, 1) -> (N0, 0)
4259   if (isOneConstant(N->getOperand(1))) {
4260     SDValue Zero = DAG.getConstant(0, DL, VT);
4261     return CombineTo(N, N->getOperand(0), Zero);
4262   }
4263 
4264   // If the type twice as wide is legal, transform the umul_lohi to a wider
4265   // multiply plus a shift.
4266   if (VT.isSimple() && !VT.isVector()) {
4267     MVT Simple = VT.getSimpleVT();
4268     unsigned SimpleSize = Simple.getSizeInBits();
4269     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
4270     if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
4271       SDValue Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N->getOperand(0));
4272       SDValue Hi = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N->getOperand(1));
4273       Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
4274       // Compute the high part (result value 1).
4275       Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
4276             DAG.getConstant(SimpleSize, DL,
4277                             getShiftAmountTy(Lo.getValueType())));
4278       Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
4279       // Compute the low part (result value 0).
4280       Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
4281       return CombineTo(N, Lo, Hi);
4282     }
4283   }
4284 
4285   return SDValue();
4286 }
4287 
4288 SDValue DAGCombiner::visitMULO(SDNode *N) {
4289   SDValue N0 = N->getOperand(0);
4290   SDValue N1 = N->getOperand(1);
4291   EVT VT = N0.getValueType();
4292   bool IsSigned = (ISD::SMULO == N->getOpcode());
4293 
4294   EVT CarryVT = N->getValueType(1);
4295   SDLoc DL(N);
4296 
4297   // canonicalize constant to RHS.
4298   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4299       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4300     return DAG.getNode(N->getOpcode(), DL, N->getVTList(), N1, N0);
4301 
4302   // fold (mulo x, 0) -> 0 + no carry out
4303   if (isNullOrNullSplat(N1))
4304     return CombineTo(N, DAG.getConstant(0, DL, VT),
4305                      DAG.getConstant(0, DL, CarryVT));
4306 
4307   // (mulo x, 2) -> (addo x, x)
4308   if (ConstantSDNode *C2 = isConstOrConstSplat(N1))
4309     if (C2->getAPIntValue() == 2)
4310       return DAG.getNode(IsSigned ? ISD::SADDO : ISD::UADDO, DL,
4311                          N->getVTList(), N0, N0);
4312 
4313   return SDValue();
4314 }
4315 
4316 SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
4317   SDValue N0 = N->getOperand(0);
4318   SDValue N1 = N->getOperand(1);
4319   EVT VT = N0.getValueType();
4320 
4321   // fold vector ops
4322   if (VT.isVector())
4323     if (SDValue FoldedVOp = SimplifyVBinOp(N))
4324       return FoldedVOp;
4325 
4326   // fold operation with constant operands.
4327   ConstantSDNode *N0C = getAsNonOpaqueConstant(N0);
4328   ConstantSDNode *N1C = getAsNonOpaqueConstant(N1);
4329   if (N0C && N1C)
4330     return DAG.FoldConstantArithmetic(N->getOpcode(), SDLoc(N), VT, N0C, N1C);
4331 
4332   // canonicalize constant to RHS
4333   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4334      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4335     return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0);
4336 
4337   // If the sign bits are zero, flip between UMIN/UMAX and SMIN/SMAX.
4338   // Only do this if the current op isn't legal and the flipped is.
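  // E.g. if SMAX is not legal but UMAX is, and the sign bits of both
  // operands are known zero, (smax x, y) can be rewritten as (umax x, y).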
4339   unsigned Opcode = N->getOpcode();
4340   if (!TLI.isOperationLegal(Opcode, VT) &&
4341       (N0.isUndef() || DAG.SignBitIsZero(N0)) &&
4342       (N1.isUndef() || DAG.SignBitIsZero(N1))) {
4343     unsigned AltOpcode;
4344     switch (Opcode) {
4345     case ISD::SMIN: AltOpcode = ISD::UMIN; break;
4346     case ISD::SMAX: AltOpcode = ISD::UMAX; break;
4347     case ISD::UMIN: AltOpcode = ISD::SMIN; break;
4348     case ISD::UMAX: AltOpcode = ISD::SMAX; break;
4349     default: llvm_unreachable("Unknown MINMAX opcode");
4350     }
4351     if (TLI.isOperationLegal(AltOpcode, VT))
4352       return DAG.getNode(AltOpcode, SDLoc(N), VT, N0, N1);
4353   }
4354 
4355   return SDValue();
4356 }
4357 
4358 /// If this is a bitwise logic instruction and both operands have the same
4359 /// opcode, try to sink the other opcode after the logic instruction.
4360 SDValue DAGCombiner::hoistLogicOpWithSameOpcodeHands(SDNode *N) {
4361   SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
4362   EVT VT = N0.getValueType();
4363   unsigned LogicOpcode = N->getOpcode();
4364   unsigned HandOpcode = N0.getOpcode();
4365   assert((LogicOpcode == ISD::AND || LogicOpcode == ISD::OR ||
4366           LogicOpcode == ISD::XOR) && "Expected logic opcode");
4367   assert(HandOpcode == N1.getOpcode() && "Bad input!");
4368 
4369   // Bail early if none of these transforms apply.
4370   if (N0.getNumOperands() == 0)
4371     return SDValue();
4372 
4373   // FIXME: We should check number of uses of the operands to not increase
4374   //        the instruction count for all transforms.
4375 
4376   // Handle size-changing casts.
4377   SDValue X = N0.getOperand(0);
4378   SDValue Y = N1.getOperand(0);
4379   EVT XVT = X.getValueType();
4380   SDLoc DL(N);
4381   if (HandOpcode == ISD::ANY_EXTEND || HandOpcode == ISD::ZERO_EXTEND ||
4382       HandOpcode == ISD::SIGN_EXTEND) {
4383     // If both operands have other uses, this transform would create extra
4384     // instructions without eliminating anything.
4385     if (!N0.hasOneUse() && !N1.hasOneUse())
4386       return SDValue();
4387     // We need matching integer source types.
4388     if (XVT != Y.getValueType())
4389       return SDValue();
4390     // Don't create an illegal op during or after legalization. Don't ever
4391     // create an unsupported vector op.
4392     if ((VT.isVector() || LegalOperations) &&
4393         !TLI.isOperationLegalOrCustom(LogicOpcode, XVT))
4394       return SDValue();
4395     // Avoid infinite looping with PromoteIntBinOp.
4396     // TODO: Should we apply desirable/legal constraints to all opcodes?
4397     if (HandOpcode == ISD::ANY_EXTEND && LegalTypes &&
4398         !TLI.isTypeDesirableForOp(LogicOpcode, XVT))
4399       return SDValue();
4400     // logic_op (hand_op X), (hand_op Y) --> hand_op (logic_op X, Y)
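    // E.g. (and (zext i8 x to i32), (zext i8 y to i32)) becomes
    // (zext (and x, y) to i32), performing the logic in the narrower type.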
4401     SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
4402     return DAG.getNode(HandOpcode, DL, VT, Logic);
4403   }
4404 
4405   // logic_op (truncate x), (truncate y) --> truncate (logic_op x, y)
4406   if (HandOpcode == ISD::TRUNCATE) {
4407     // If both operands have other uses, this transform would create extra
4408     // instructions without eliminating anything.
4409     if (!N0.hasOneUse() && !N1.hasOneUse())
4410       return SDValue();
4411     // We need matching source types.
4412     if (XVT != Y.getValueType())
4413       return SDValue();
4414     // Don't create an illegal op during or after legalization.
4415     if (LegalOperations && !TLI.isOperationLegal(LogicOpcode, XVT))
4416       return SDValue();
4417     // Be extra careful sinking truncate. If it's free, there's no benefit in
4418     // widening a binop. Also, don't create a logic op on an illegal type.
4419     if (TLI.isZExtFree(VT, XVT) && TLI.isTruncateFree(XVT, VT))
4420       return SDValue();
4421     if (!TLI.isTypeLegal(XVT))
4422       return SDValue();
4423     SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
4424     return DAG.getNode(HandOpcode, DL, VT, Logic);
4425   }
4426 
4427   // For binops SHL/SRL/SRA/AND:
4428   //   logic_op (OP x, z), (OP y, z) --> OP (logic_op x, y), z
4429   if ((HandOpcode == ISD::SHL || HandOpcode == ISD::SRL ||
4430        HandOpcode == ISD::SRA || HandOpcode == ISD::AND) &&
4431       N0.getOperand(1) == N1.getOperand(1)) {
4432     // If either operand has other uses, this transform is not an improvement.
4433     if (!N0.hasOneUse() || !N1.hasOneUse())
4434       return SDValue();
4435     SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
4436     return DAG.getNode(HandOpcode, DL, VT, Logic, N0.getOperand(1));
4437   }
4438 
4439   // Unary ops: logic_op (bswap x), (bswap y) --> bswap (logic_op x, y)
4440   if (HandOpcode == ISD::BSWAP) {
4441     // If either operand has other uses, this transform is not an improvement.
4442     if (!N0.hasOneUse() || !N1.hasOneUse())
4443       return SDValue();
4444     SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
4445     return DAG.getNode(HandOpcode, DL, VT, Logic);
4446   }
4447 
4448   // Simplify xor/and/or (bitcast(A), bitcast(B)) -> bitcast(op (A,B))
4449   // Only perform this optimization up until type legalization, before
4450   // LegalizeVectorOps. LegalizeVectorOps promotes vector operations by
4451   // adding bitcasts. For example (xor v4i32) is promoted to (v2i64), and
4452   // we don't want to undo this promotion.
4453   // We also handle SCALAR_TO_VECTOR because xor/or/and operations are cheaper
4454   // on scalars.
4455   if ((HandOpcode == ISD::BITCAST || HandOpcode == ISD::SCALAR_TO_VECTOR) &&
4456        Level <= AfterLegalizeTypes) {
4457     // Input types must be integer and the same.
4458     if (XVT.isInteger() && XVT == Y.getValueType() &&
4459         !(VT.isVector() && TLI.isTypeLegal(VT) &&
4460           !XVT.isVector() && !TLI.isTypeLegal(XVT))) {
4461       SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
4462       return DAG.getNode(HandOpcode, DL, VT, Logic);
4463     }
4464   }
4465 
4466   // Xor/and/or are indifferent to the swizzle operation (shuffle of one value).
4467   // Simplify xor/and/or (shuff(A), shuff(B)) -> shuff(op (A,B))
4468   // If both shuffles use the same mask, and both shuffle within a single
4469   // vector, then it is worthwhile to move the swizzle after the operation.
4470   // The type-legalizer generates this pattern when loading illegal
4471   // vector types from memory. In many cases this allows additional shuffle
4472   // optimizations.
4473   // There are other cases where moving the shuffle after the xor/and/or
4474   // is profitable even if shuffles don't perform a swizzle.
4475   // If both shuffles use the same mask, and both shuffles have the same first
4476   // or second operand, then it might still be profitable to move the shuffle
4477   // after the xor/and/or operation.
4478   if (HandOpcode == ISD::VECTOR_SHUFFLE && Level < AfterLegalizeDAG) {
4479     auto *SVN0 = cast<ShuffleVectorSDNode>(N0);
4480     auto *SVN1 = cast<ShuffleVectorSDNode>(N1);
4481     assert(X.getValueType() == Y.getValueType() &&
4482            "Inputs to shuffles are not the same type");
4483 
4484     // Check that both shuffles use the same mask. The masks are known to be of
4485     // the same length because the result vector type is the same.
4486     // Check also that shuffles have only one use to avoid introducing extra
4487     // instructions.
4488     if (!SVN0->hasOneUse() || !SVN1->hasOneUse() ||
4489         !SVN0->getMask().equals(SVN1->getMask()))
4490       return SDValue();
4491 
4492     // Don't try to fold this node if it requires introducing a
4493     // build vector of all zeros that might be illegal at this stage.
4494     SDValue ShOp = N0.getOperand(1);
4495     if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
4496       ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
4497 
4498     // (logic_op (shuf (A, C), shuf (B, C))) --> shuf (logic_op (A, B), C)
4499     if (N0.getOperand(1) == N1.getOperand(1) && ShOp.getNode()) {
4500       SDValue Logic = DAG.getNode(LogicOpcode, DL, VT,
4501                                   N0.getOperand(0), N1.getOperand(0));
4502       return DAG.getVectorShuffle(VT, DL, Logic, ShOp, SVN0->getMask());
4503     }
4504 
4505     // Don't try to fold this node if it requires introducing a
4506     // build vector of all zeros that might be illegal at this stage.
4507     ShOp = N0.getOperand(0);
4508     if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
4509       ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
4510 
4511     // (logic_op (shuf (C, A), shuf (C, B))) --> shuf (C, logic_op (A, B))
4512     if (N0.getOperand(0) == N1.getOperand(0) && ShOp.getNode()) {
4513       SDValue Logic = DAG.getNode(LogicOpcode, DL, VT, N0.getOperand(1),
4514                                   N1.getOperand(1));
4515       return DAG.getVectorShuffle(VT, DL, ShOp, Logic, SVN0->getMask());
4516     }
4517   }
4518 
4519   return SDValue();
4520 }
4521 
4522 /// Try to make (and/or setcc (LL, LR), setcc (RL, RR)) more efficient.
4523 SDValue DAGCombiner::foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
4524                                        const SDLoc &DL) {
4525   SDValue LL, LR, RL, RR, N0CC, N1CC;
4526   if (!isSetCCEquivalent(N0, LL, LR, N0CC) ||
4527       !isSetCCEquivalent(N1, RL, RR, N1CC))
4528     return SDValue();
4529 
4530   assert(N0.getValueType() == N1.getValueType() &&
4531          "Unexpected operand types for bitwise logic op");
4532   assert(LL.getValueType() == LR.getValueType() &&
4533          RL.getValueType() == RR.getValueType() &&
4534          "Unexpected operand types for setcc");
4535 
4536   // If we're here post-legalization or the logic op type is not i1, the logic
4537   // op type must match a setcc result type. Also, all folds require new
4538   // operations on the left and right operands, so those types must match.
4539   EVT VT = N0.getValueType();
4540   EVT OpVT = LL.getValueType();
4541   if (LegalOperations || VT.getScalarType() != MVT::i1)
4542     if (VT != getSetCCResultType(OpVT))
4543       return SDValue();
4544   if (OpVT != RL.getValueType())
4545     return SDValue();
4546 
4547   ISD::CondCode CC0 = cast<CondCodeSDNode>(N0CC)->get();
4548   ISD::CondCode CC1 = cast<CondCodeSDNode>(N1CC)->get();
4549   bool IsInteger = OpVT.isInteger();
4550   if (LR == RR && CC0 == CC1 && IsInteger) {
4551     bool IsZero = isNullOrNullSplat(LR);
4552     bool IsNeg1 = isAllOnesOrAllOnesSplat(LR);
4553 
4554     // All bits clear?
4555     bool AndEqZero = IsAnd && CC1 == ISD::SETEQ && IsZero;
4556     // All sign bits clear?
4557     bool AndGtNeg1 = IsAnd && CC1 == ISD::SETGT && IsNeg1;
4558     // Any bits set?
4559     bool OrNeZero = !IsAnd && CC1 == ISD::SETNE && IsZero;
4560     // Any sign bits set?
4561     bool OrLtZero = !IsAnd && CC1 == ISD::SETLT && IsZero;
4562 
4563     // (and (seteq X,  0), (seteq Y,  0)) --> (seteq (or X, Y),  0)
4564     // (and (setgt X, -1), (setgt Y, -1)) --> (setgt (or X, Y), -1)
4565     // (or  (setne X,  0), (setne Y,  0)) --> (setne (or X, Y),  0)
4566     // (or  (setlt X,  0), (setlt Y,  0)) --> (setlt (or X, Y),  0)
4567     if (AndEqZero || AndGtNeg1 || OrNeZero || OrLtZero) {
4568       SDValue Or = DAG.getNode(ISD::OR, SDLoc(N0), OpVT, LL, RL);
4569       AddToWorklist(Or.getNode());
4570       return DAG.getSetCC(DL, VT, Or, LR, CC1);
4571     }
4572 
4573     // All bits set?
4574     bool AndEqNeg1 = IsAnd && CC1 == ISD::SETEQ && IsNeg1;
4575     // All sign bits set?
4576     bool AndLtZero = IsAnd && CC1 == ISD::SETLT && IsZero;
4577     // Any bits clear?
4578     bool OrNeNeg1 = !IsAnd && CC1 == ISD::SETNE && IsNeg1;
4579     // Any sign bits clear?
4580     bool OrGtNeg1 = !IsAnd && CC1 == ISD::SETGT && IsNeg1;
4581 
4582     // (and (seteq X, -1), (seteq Y, -1)) --> (seteq (and X, Y), -1)
4583     // (and (setlt X,  0), (setlt Y,  0)) --> (setlt (and X, Y),  0)
4584     // (or  (setne X, -1), (setne Y, -1)) --> (setne (and X, Y), -1)
4585     // (or  (setgt X, -1), (setgt Y, -1)) --> (setgt (and X, Y), -1)
4586     if (AndEqNeg1 || AndLtZero || OrNeNeg1 || OrGtNeg1) {
4587       SDValue And = DAG.getNode(ISD::AND, SDLoc(N0), OpVT, LL, RL);
4588       AddToWorklist(And.getNode());
4589       return DAG.getSetCC(DL, VT, And, LR, CC1);
4590     }
4591   }
4592 
4593   // TODO: What is the 'or' equivalent of this fold?
4594   // (and (setne X, 0), (setne X, -1)) --> (setuge (add X, 1), 2)
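  // E.g. for i8: (X != 0 && X != -1) becomes ((X + 1) u>= 2); X == 0 yields
  // 1 and X == -1 wraps to 0, and both are u< 2.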
4595   if (IsAnd && LL == RL && CC0 == CC1 && OpVT.getScalarSizeInBits() > 1 &&
4596       IsInteger && CC0 == ISD::SETNE &&
4597       ((isNullConstant(LR) && isAllOnesConstant(RR)) ||
4598        (isAllOnesConstant(LR) && isNullConstant(RR)))) {
4599     SDValue One = DAG.getConstant(1, DL, OpVT);
4600     SDValue Two = DAG.getConstant(2, DL, OpVT);
4601     SDValue Add = DAG.getNode(ISD::ADD, SDLoc(N0), OpVT, LL, One);
4602     AddToWorklist(Add.getNode());
4603     return DAG.getSetCC(DL, VT, Add, Two, ISD::SETUGE);
4604   }
4605 
4606   // Try more general transforms if the predicates match and the only user of
4607   // the compares is the 'and' or 'or'.
4608   if (IsInteger && TLI.convertSetCCLogicToBitwiseLogic(OpVT) && CC0 == CC1 &&
4609       N0.hasOneUse() && N1.hasOneUse()) {
4610     // and (seteq A, B), (seteq C, D) --> seteq (or (xor A, B), (xor C, D)), 0
4611     // or  (setne A, B), (setne C, D) --> setne (or (xor A, B), (xor C, D)), 0
4612     if ((IsAnd && CC1 == ISD::SETEQ) || (!IsAnd && CC1 == ISD::SETNE)) {
4613       SDValue XorL = DAG.getNode(ISD::XOR, SDLoc(N0), OpVT, LL, LR);
4614       SDValue XorR = DAG.getNode(ISD::XOR, SDLoc(N1), OpVT, RL, RR);
4615       SDValue Or = DAG.getNode(ISD::OR, DL, OpVT, XorL, XorR);
4616       SDValue Zero = DAG.getConstant(0, DL, OpVT);
4617       return DAG.getSetCC(DL, VT, Or, Zero, CC1);
4618     }
4619 
4620     // Turn compare of constants whose difference is 1 bit into add+and+setcc.
4621     // TODO - support non-uniform vector amounts.
4622     if ((IsAnd && CC1 == ISD::SETNE) || (!IsAnd && CC1 == ISD::SETEQ)) {
4623       // Match a shared variable operand and 2 non-opaque constant operands.
4624       ConstantSDNode *C0 = isConstOrConstSplat(LR);
4625       ConstantSDNode *C1 = isConstOrConstSplat(RR);
4626       if (LL == RL && C0 && C1 && !C0->isOpaque() && !C1->isOpaque()) {
4627         // Canonicalize larger constant as C0.
4628         if (C1->getAPIntValue().ugt(C0->getAPIntValue()))
4629           std::swap(C0, C1);
4630 
4631         // The difference of the constants must be a single bit.
4632         const APInt &C0Val = C0->getAPIntValue();
4633         const APInt &C1Val = C1->getAPIntValue();
4634         if ((C0Val - C1Val).isPowerOf2()) {
4635           // and/or (setcc X, C0, ne), (setcc X, C1, ne/eq) -->
4636           // setcc ((add X, -C1), ~(C0 - C1)), 0, ne/eq
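          // E.g. (X != 12 && X != 8): C0 = 12, C1 = 8, C0 - C1 = 4, giving
          // (((X + -8) & ~4) != 0); X == 8 and X == 12 both mask to 0.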
4637           SDValue OffsetC = DAG.getConstant(-C1Val, DL, OpVT);
4638           SDValue Add = DAG.getNode(ISD::ADD, DL, OpVT, LL, OffsetC);
4639           SDValue MaskC = DAG.getConstant(~(C0Val - C1Val), DL, OpVT);
4640           SDValue And = DAG.getNode(ISD::AND, DL, OpVT, Add, MaskC);
4641           SDValue Zero = DAG.getConstant(0, DL, OpVT);
4642           return DAG.getSetCC(DL, VT, And, Zero, CC0);
4643         }
4644       }
4645     }
4646   }
4647 
4648   // Canonicalize equivalent operands to LL == RL.
4649   if (LL == RR && LR == RL) {
4650     CC1 = ISD::getSetCCSwappedOperands(CC1);
4651     std::swap(RL, RR);
4652   }
4653 
4654   // (and (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
4655   // (or  (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
4656   if (LL == RL && LR == RR) {
4657     ISD::CondCode NewCC = IsAnd ? ISD::getSetCCAndOperation(CC0, CC1, OpVT)
4658                                 : ISD::getSetCCOrOperation(CC0, CC1, OpVT);
4659     if (NewCC != ISD::SETCC_INVALID &&
4660         (!LegalOperations ||
4661          (TLI.isCondCodeLegal(NewCC, LL.getSimpleValueType()) &&
4662           TLI.isOperationLegal(ISD::SETCC, OpVT))))
4663       return DAG.getSetCC(DL, VT, LL, LR, NewCC);
4664   }
4665 
4666   return SDValue();
4667 }
4668 
4669 /// This contains all DAGCombine rules which reduce two values combined by
4670 /// an And operation to a single value. This makes them reusable in the context
4671 /// of visitSELECT(). Rules involving constants are not included as
4672 /// visitSELECT() already handles those cases.
4673 SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1, SDNode *N) {
4674   EVT VT = N1.getValueType();
4675   SDLoc DL(N);
4676 
4677   // fold (and x, undef) -> 0
4678   if (N0.isUndef() || N1.isUndef())
4679     return DAG.getConstant(0, DL, VT);
4680 
4681   if (SDValue V = foldLogicOfSetCCs(true, N0, N1, DL))
4682     return V;
4683 
4684   if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::SRL &&
4685       VT.getSizeInBits() <= 64) {
4686     if (ConstantSDNode *ADDI = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
4687       if (ConstantSDNode *SRLI = dyn_cast<ConstantSDNode>(N1.getOperand(1))) {
4688         // Look for (and (add x, c1), (lshr y, c2)). If C1 wasn't a legal
4689         // immediate for an add, but it is legal if its top c2 bits are set,
4690         // transform the ADD so the immediate doesn't need to be materialized
4691         // in a register.
4692         APInt ADDC = ADDI->getAPIntValue();
4693         APInt SRLC = SRLI->getAPIntValue();
4694         if (ADDC.getMinSignedBits() <= 64 &&
4695             SRLC.ult(VT.getSizeInBits()) &&
4696             !TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
4697           APInt Mask = APInt::getHighBitsSet(VT.getSizeInBits(),
4698                                              SRLC.getZExtValue());
4699           if (DAG.MaskedValueIsZero(N0.getOperand(1), Mask)) {
4700             ADDC |= Mask;
4701             if (TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
4702               SDLoc DL0(N0);
4703               SDValue NewAdd =
4704                 DAG.getNode(ISD::ADD, DL0, VT,
4705                             N0.getOperand(0), DAG.getConstant(ADDC, DL, VT));
4706               CombineTo(N0.getNode(), NewAdd);
4707               // Return N so it doesn't get rechecked!
4708               return SDValue(N, 0);
4709             }
4710           }
4711         }
4712       }
4713     }
4714   }
4715 
4716   // Reduce bit extract of low half of an integer to the narrower type.
4717   // (and (srl i64:x, K), KMask) ->
4718   //   (i64 zero_extend (and (srl (i32 (trunc i64:x)), K), KMask))
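  // E.g. (and (srl i64:x, 8), 255) can become
  // (zext (and (srl (trunc x to i32), 8), 255) to i64) when the target
  // prefers 32-bit operations.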
4719   if (N0.getOpcode() == ISD::SRL && N0.hasOneUse()) {
4720     if (ConstantSDNode *CAnd = dyn_cast<ConstantSDNode>(N1)) {
4721       if (ConstantSDNode *CShift = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
4722         unsigned Size = VT.getSizeInBits();
4723         const APInt &AndMask = CAnd->getAPIntValue();
4724         unsigned ShiftBits = CShift->getZExtValue();
4725 
4726         // Bail out; this node will probably disappear anyway.
4727         if (ShiftBits == 0)
4728           return SDValue();
4729 
4730         unsigned MaskBits = AndMask.countTrailingOnes();
4731         EVT HalfVT = EVT::getIntegerVT(*DAG.getContext(), Size / 2);
4732 
4733         if (AndMask.isMask() &&
4734             // Required bits must not span the two halves of the integer and
4735             // must fit in the half size type.
4736             (ShiftBits + MaskBits <= Size / 2) &&
4737             TLI.isNarrowingProfitable(VT, HalfVT) &&
4738             TLI.isTypeDesirableForOp(ISD::AND, HalfVT) &&
4739             TLI.isTypeDesirableForOp(ISD::SRL, HalfVT) &&
4740             TLI.isTruncateFree(VT, HalfVT) &&
4741             TLI.isZExtFree(HalfVT, VT)) {
4742           // The isNarrowingProfitable check is to avoid regressions on PPC
4743           // and AArch64, which match a few 64-bit bit insert / bit extract
4744           // patterns on downstream users of this. Those patterns could
4745           // probably be extended to handle extensions mixed in.
4746 
4747           SDValue SL(N0);
4748           assert(MaskBits <= Size);
4749 
4750           // Extracting the highest bit of the low half.
4751           EVT ShiftVT = TLI.getShiftAmountTy(HalfVT, DAG.getDataLayout());
4752           SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, HalfVT,
4753                                       N0.getOperand(0));
4754 
4755           SDValue NewMask = DAG.getConstant(AndMask.trunc(Size / 2), SL, HalfVT);
4756           SDValue ShiftK = DAG.getConstant(ShiftBits, SL, ShiftVT);
4757           SDValue Shift = DAG.getNode(ISD::SRL, SL, HalfVT, Trunc, ShiftK);
4758           SDValue And = DAG.getNode(ISD::AND, SL, HalfVT, Shift, NewMask);
4759           return DAG.getNode(ISD::ZERO_EXTEND, SL, VT, And);
4760         }
4761       }
4762     }
4763   }
4764 
4765   return SDValue();
4766 }
4767 
4768 bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
4769                                    EVT LoadResultTy, EVT &ExtVT) {
4770   if (!AndC->getAPIntValue().isMask())
4771     return false;
4772 
4773   unsigned ActiveBits = AndC->getAPIntValue().countTrailingOnes();
4774 
4775   ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
4776   EVT LoadedVT = LoadN->getMemoryVT();
4777 
4778   if (ExtVT == LoadedVT &&
4779       (!LegalOperations ||
4780        TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))) {
4781     // ZEXTLOAD will match without needing to change the size of the value being
4782     // loaded.
4783     return true;
4784   }
4785 
4786   // Do not change the width of a volatile or atomic load.
4787   if (!LoadN->isSimple())
4788     return false;
4789 
4790   // Do not generate loads of non-round integer types since these can
4791   // be expensive (and would be wrong if the type is not byte sized).
4792   if (!LoadedVT.bitsGT(ExtVT) || !ExtVT.isRound())
4793     return false;
4794 
4795   if (LegalOperations &&
4796       !TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))
4797     return false;
4798 
4799   if (!TLI.shouldReduceLoadWidth(LoadN, ISD::ZEXTLOAD, ExtVT))
4800     return false;
4801 
4802   return true;
4803 }
4804 
4805 bool DAGCombiner::isLegalNarrowLdSt(LSBaseSDNode *LDST,
4806                                     ISD::LoadExtType ExtType, EVT &MemVT,
4807                                     unsigned ShAmt) {
4808   if (!LDST)
4809     return false;
4810   // Only allow byte offsets.
4811   if (ShAmt % 8)
4812     return false;
4813 
4814   // Do not generate loads of non-round integer types since these can
4815   // be expensive (and would be wrong if the type is not byte sized).
4816   if (!MemVT.isRound())
4817     return false;
4818 
4819   // Don't change the width of a volatile or atomic load.
4820   if (!LDST->isSimple())
4821     return false;
4822 
4823   // Verify that we are actually reducing a load width here.
4824   if (LDST->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits())
4825     return false;
4826 
4827   // Ensure that this isn't going to produce an unsupported memory access.
4828   if (ShAmt &&
4829       !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
4830                               LDST->getAddressSpace(), ShAmt / 8,
4831                               LDST->getMemOperand()->getFlags()))
4832     return false;
4833 
4834   // It's not possible to generate a constant of extended or untyped type.
4835   EVT PtrType = LDST->getBasePtr().getValueType();
4836   if (PtrType == MVT::Untyped || PtrType.isExtended())
4837     return false;
4838 
4839   if (isa<LoadSDNode>(LDST)) {
4840     LoadSDNode *Load = cast<LoadSDNode>(LDST);
4841     // Don't transform a load with multiple uses; this would require adding
4842     // a new load.
4843     if (!SDValue(Load, 0).hasOneUse())
4844       return false;
4845 
4846     if (LegalOperations &&
4847         !TLI.isLoadExtLegal(ExtType, Load->getValueType(0), MemVT))
4848       return false;
4849 
4850     // For the transform to be legal, the load must produce only two values
4851     // (the value loaded and the chain).  Don't transform a pre-increment
4852     // load, for example, which produces an extra value.  Otherwise the
4853     // transformation is not equivalent, and the downstream logic to replace
4854     // uses gets things wrong.
4855     if (Load->getNumValues() > 2)
4856       return false;
4857 
4858     // If the load that we're shrinking is an extload and we're not just
4859     // discarding the extension we can't simply shrink the load. Bail.
4860     // TODO: It would be possible to merge the extensions in some cases.
4861     if (Load->getExtensionType() != ISD::NON_EXTLOAD &&
4862         Load->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
4863       return false;
4864 
4865     if (!TLI.shouldReduceLoadWidth(Load, ExtType, MemVT))
4866       return false;
4867   } else {
4868     assert(isa<StoreSDNode>(LDST) && "It is not a Load nor a Store SDNode");
4869     StoreSDNode *Store = cast<StoreSDNode>(LDST);
4870     // Can't write outside the original store
4871     if (Store->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
4872       return false;
4873 
4874     if (LegalOperations &&
4875         !TLI.isTruncStoreLegal(Store->getValue().getValueType(), MemVT))
4876       return false;
4877   }
4878   return true;
4879 }
4880 
4881 bool DAGCombiner::SearchForAndLoads(SDNode *N,
4882                                     SmallVectorImpl<LoadSDNode*> &Loads,
4883                                     SmallPtrSetImpl<SDNode*> &NodesWithConsts,
4884                                     ConstantSDNode *Mask,
4885                                     SDNode *&NodeToMask) {
4886   // Recursively search for the operands, looking for loads which can be
4887   // narrowed.
4888   for (SDValue Op : N->op_values()) {
4889     if (Op.getValueType().isVector())
4890       return false;
4891 
4892     // Some constants may need fixing up later if they are too large.
4893     if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
4894       if ((N->getOpcode() == ISD::OR || N->getOpcode() == ISD::XOR) &&
4895           (Mask->getAPIntValue() & C->getAPIntValue()) != C->getAPIntValue())
4896         NodesWithConsts.insert(N);
4897       continue;
4898     }
4899 
4900     if (!Op.hasOneUse())
4901       return false;
4902 
4903     switch(Op.getOpcode()) {
4904     case ISD::LOAD: {
4905       auto *Load = cast<LoadSDNode>(Op);
4906       EVT ExtVT;
4907       if (isAndLoadExtLoad(Mask, Load, Load->getValueType(0), ExtVT) &&
4908           isLegalNarrowLdSt(Load, ISD::ZEXTLOAD, ExtVT)) {
4909 
4910         // ZEXTLOAD is already small enough.
4911         if (Load->getExtensionType() == ISD::ZEXTLOAD &&
4912             ExtVT.bitsGE(Load->getMemoryVT()))
4913           continue;
4914 
4915         // Use LE to convert equal sized loads to zext.
4916         if (ExtVT.bitsLE(Load->getMemoryVT()))
4917           Loads.push_back(Load);
4918 
4919         continue;
4920       }
4921       return false;
4922     }
4923     case ISD::ZERO_EXTEND:
4924     case ISD::AssertZext: {
4925       unsigned ActiveBits = Mask->getAPIntValue().countTrailingOnes();
4926       EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
4927       EVT VT = Op.getOpcode() == ISD::AssertZext ?
4928         cast<VTSDNode>(Op.getOperand(1))->getVT() :
4929         Op.getOperand(0).getValueType();
4930 
4931       // We can accept extending nodes if the mask is wider than or of equal
4932       // width to the original type.
4933       if (ExtVT.bitsGE(VT))
4934         continue;
4935       break;
4936     }
4937     case ISD::OR:
4938     case ISD::XOR:
4939     case ISD::AND:
4940       if (!SearchForAndLoads(Op.getNode(), Loads, NodesWithConsts, Mask,
4941                              NodeToMask))
4942         return false;
4943       continue;
4944     }
4945 
4946     // Allow one node which will be masked along with any loads found.
4947     if (NodeToMask)
4948       return false;
4949 
4950     // Also ensure that the node to be masked only produces one data result.
4951     NodeToMask = Op.getNode();
4952     if (NodeToMask->getNumValues() > 1) {
4953       bool HasValue = false;
4954       for (unsigned i = 0, e = NodeToMask->getNumValues(); i < e; ++i) {
4955         MVT VT = SDValue(NodeToMask, i).getSimpleValueType();
4956         if (VT != MVT::Glue && VT != MVT::Other) {
4957           if (HasValue) {
4958             NodeToMask = nullptr;
4959             return false;
4960           }
4961           HasValue = true;
4962         }
4963       }
4964       assert(HasValue && "Node to be masked has no data result?");
4965     }
4966   }
4967   return true;
4968 }
4969 
4970 bool DAGCombiner::BackwardsPropagateMask(SDNode *N) {
4971   auto *Mask = dyn_cast<ConstantSDNode>(N->getOperand(1));
4972   if (!Mask)
4973     return false;
4974 
4975   if (!Mask->getAPIntValue().isMask())
4976     return false;
4977 
4978   // No need to do anything if the and directly uses a load.
4979   if (isa<LoadSDNode>(N->getOperand(0)))
4980     return false;
4981 
4982   SmallVector<LoadSDNode*, 8> Loads;
4983   SmallPtrSet<SDNode*, 2> NodesWithConsts;
4984   SDNode *FixupNode = nullptr;
4985   if (SearchForAndLoads(N, Loads, NodesWithConsts, Mask, FixupNode)) {
4986     if (Loads.size() == 0)
4987       return false;
4988 
4989     LLVM_DEBUG(dbgs() << "Backwards propagate AND: "; N->dump());
4990     SDValue MaskOp = N->getOperand(1);
4991 
4992     // If it exists, fixup the single node we allow in the tree that needs
4993     // masking.
4994     if (FixupNode) {
4995       LLVM_DEBUG(dbgs() << "First, need to fix up: "; FixupNode->dump());
4996       SDValue And = DAG.getNode(ISD::AND, SDLoc(FixupNode),
4997                                 FixupNode->getValueType(0),
4998                                 SDValue(FixupNode, 0), MaskOp);
4999       DAG.ReplaceAllUsesOfValueWith(SDValue(FixupNode, 0), And);
5000       if (And.getOpcode() == ISD::AND)
5001         DAG.UpdateNodeOperands(And.getNode(), SDValue(FixupNode, 0), MaskOp);
5002     }
5003 
5004     // Narrow any constants that need it.
5005     for (auto *LogicN : NodesWithConsts) {
5006       SDValue Op0 = LogicN->getOperand(0);
5007       SDValue Op1 = LogicN->getOperand(1);
5008 
5009       if (isa<ConstantSDNode>(Op0))
5010           std::swap(Op0, Op1);
5011 
5012       SDValue And = DAG.getNode(ISD::AND, SDLoc(Op1), Op1.getValueType(),
5013                                 Op1, MaskOp);
5014 
5015       DAG.UpdateNodeOperands(LogicN, Op0, And);
5016     }
5017 
5018     // Create narrow loads.
5019     for (auto *Load : Loads) {
5020       LLVM_DEBUG(dbgs() << "Propagate AND back to: "; Load->dump());
5021       SDValue And = DAG.getNode(ISD::AND, SDLoc(Load), Load->getValueType(0),
5022                                 SDValue(Load, 0), MaskOp);
5023       DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), And);
5024       if (And.getOpcode() == ISD::AND)
5025         And = SDValue(
5026             DAG.UpdateNodeOperands(And.getNode(), SDValue(Load, 0), MaskOp), 0);
5027       SDValue NewLoad = ReduceLoadWidth(And.getNode());
5028       assert(NewLoad &&
5029              "Shouldn't be masking the load if it can't be narrowed");
5030       CombineTo(Load, NewLoad, NewLoad.getValue(1));
5031     }
5032     DAG.ReplaceAllUsesWith(N, N->getOperand(0).getNode());
5033     return true;
5034   }
5035   return false;
5036 }
5037 
5038 // Unfold
5039 //    x &  (-1 'logical shift' y)
5040 // To
5041 //    (x 'opposite logical shift' y) 'logical shift' y
5042 // if it is better for performance.
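// E.g. for i32: (x & (-1 << y)) becomes ((x >> y) << y), clearing the low
// y bits with two shifts instead of materializing the mask.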
5043 SDValue DAGCombiner::unfoldExtremeBitClearingToShifts(SDNode *N) {
5044   assert(N->getOpcode() == ISD::AND);
5045 
5046   SDValue N0 = N->getOperand(0);
5047   SDValue N1 = N->getOperand(1);
5048 
5049   // Do we actually prefer shifts over a mask?
5050   if (!TLI.shouldFoldMaskToVariableShiftPair(N0))
5051     return SDValue();
5052 
5053   // Try to match  (-1 '[outer] logical shift' y)
5054   unsigned OuterShift;
5055   unsigned InnerShift; // The opposite direction to the OuterShift.
5056   SDValue Y;           // Shift amount.
5057   auto matchMask = [&OuterShift, &InnerShift, &Y](SDValue M) -> bool {
5058     if (!M.hasOneUse())
5059       return false;
5060     OuterShift = M->getOpcode();
5061     if (OuterShift == ISD::SHL)
5062       InnerShift = ISD::SRL;
5063     else if (OuterShift == ISD::SRL)
5064       InnerShift = ISD::SHL;
5065     else
5066       return false;
5067     if (!isAllOnesConstant(M->getOperand(0)))
5068       return false;
5069     Y = M->getOperand(1);
5070     return true;
5071   };
5072 
5073   SDValue X;
5074   if (matchMask(N1))
5075     X = N0;
5076   else if (matchMask(N0))
5077     X = N1;
5078   else
5079     return SDValue();
5080 
5081   SDLoc DL(N);
5082   EVT VT = N->getValueType(0);
5083 
5084   //     tmp = x   'opposite logical shift' y
5085   SDValue T0 = DAG.getNode(InnerShift, DL, VT, X, Y);
5086   //     ret = tmp 'logical shift' y
5087   SDValue T1 = DAG.getNode(OuterShift, DL, VT, T0, Y);
5088 
5089   return T1;
5090 }
5091 
5092 /// Try to replace shift/logic that tests if a bit is clear with mask + setcc.
5093 /// For a target with a bit test, this is expected to become test + set and save
5094 /// at least 1 instruction.
5095 static SDValue combineShiftAnd1ToBitTest(SDNode *And, SelectionDAG &DAG) {
5096   assert(And->getOpcode() == ISD::AND && "Expected an 'and' op");
5097 
5098   // This is probably not worthwhile without a supported type.
5099   EVT VT = And->getValueType(0);
5100   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
5101   if (!TLI.isTypeLegal(VT))
5102     return SDValue();
5103 
5104   // Look through an optional extension and find a 'not'.
5105   // TODO: Should we favor test+set even without the 'not' op?
5106   SDValue Not = And->getOperand(0), And1 = And->getOperand(1);
5107   if (Not.getOpcode() == ISD::ANY_EXTEND)
5108     Not = Not.getOperand(0);
5109   if (!isBitwiseNot(Not) || !Not.hasOneUse() || !isOneConstant(And1))
5110     return SDValue();
5111 
5112   // Look through an optional truncation. The source operand may not be the same
5113   // type as the original 'and', but that is ok because we are masking off
5114   // everything but the low bit.
5115   SDValue Srl = Not.getOperand(0);
5116   if (Srl.getOpcode() == ISD::TRUNCATE)
5117     Srl = Srl.getOperand(0);
5118 
5119   // Match a shift-right by constant.
5120   if (Srl.getOpcode() != ISD::SRL || !Srl.hasOneUse() ||
5121       !isa<ConstantSDNode>(Srl.getOperand(1)))
5122     return SDValue();
5123 
5124   // We might have looked through casts that make this transform invalid.
5125   // TODO: If the source type is wider than the result type, do the mask and
5126   //       compare in the source type.
5127   const APInt &ShiftAmt = Srl.getConstantOperandAPInt(1);
5128   unsigned VTBitWidth = VT.getSizeInBits();
5129   if (ShiftAmt.uge(VTBitWidth))
5130     return SDValue();
5131 
5132   // Turn this into a bit-test pattern using mask op + setcc:
5133   // and (not (srl X, C)), 1 --> (and X, 1<<C) == 0
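  // E.g. (and (not (srl X, 3)), 1) becomes setcc ((and X, 8), 0, eq),
  // zero-extended back to VT.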
5134   SDLoc DL(And);
5135   SDValue X = DAG.getZExtOrTrunc(Srl.getOperand(0), DL, VT);
5136   EVT CCVT = TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
5137   SDValue Mask = DAG.getConstant(
5138       APInt::getOneBitSet(VTBitWidth, ShiftAmt.getZExtValue()), DL, VT);
5139   SDValue NewAnd = DAG.getNode(ISD::AND, DL, VT, X, Mask);
5140   SDValue Zero = DAG.getConstant(0, DL, VT);
5141   SDValue Setcc = DAG.getSetCC(DL, CCVT, NewAnd, Zero, ISD::SETEQ);
5142   return DAG.getZExtOrTrunc(Setcc, DL, VT);
5143 }
5144 
5145 SDValue DAGCombiner::visitAND(SDNode *N) {
5146   SDValue N0 = N->getOperand(0);
5147   SDValue N1 = N->getOperand(1);
5148   EVT VT = N1.getValueType();
5149 
5150   // x & x --> x
5151   if (N0 == N1)
5152     return N0;
5153 
5154   // fold vector ops
5155   if (VT.isVector()) {
5156     if (SDValue FoldedVOp = SimplifyVBinOp(N))
5157       return FoldedVOp;
5158 
5159     // fold (and x, 0) -> 0, vector edition
5160     if (ISD::isBuildVectorAllZeros(N0.getNode()))
5161       // do not return N0, because an undef node may exist in N0
5162       return DAG.getConstant(APInt::getNullValue(N0.getScalarValueSizeInBits()),
5163                              SDLoc(N), N0.getValueType());
5164     if (ISD::isBuildVectorAllZeros(N1.getNode()))
5165       // do not return N1, because an undef node may exist in N1
5166       return DAG.getConstant(APInt::getNullValue(N1.getScalarValueSizeInBits()),
5167                              SDLoc(N), N1.getValueType());
5168 
5169     // fold (and x, -1) -> x, vector edition
5170     if (ISD::isBuildVectorAllOnes(N0.getNode()))
5171       return N1;
5172     if (ISD::isBuildVectorAllOnes(N1.getNode()))
5173       return N0;
5174   }
5175 
5176   // fold (and c1, c2) -> c1&c2
5177   ConstantSDNode *N0C = getAsNonOpaqueConstant(N0);
5178   ConstantSDNode *N1C = isConstOrConstSplat(N1);
5179   if (N0C && N1C && !N1C->isOpaque())
5180     return DAG.FoldConstantArithmetic(ISD::AND, SDLoc(N), VT, N0C, N1C);
5181   // canonicalize constant to RHS
5182   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
5183       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
5184     return DAG.getNode(ISD::AND, SDLoc(N), VT, N1, N0);
5185   // fold (and x, -1) -> x
5186   if (isAllOnesConstant(N1))
5187     return N0;
5188   // if (and x, c) is known to be zero, return 0
5189   unsigned BitWidth = VT.getScalarSizeInBits();
5190   if (N1C && DAG.MaskedValueIsZero(SDValue(N, 0),
5191                                    APInt::getAllOnesValue(BitWidth)))
5192     return DAG.getConstant(0, SDLoc(N), VT);
5193 
5194   if (SDValue NewSel = foldBinOpIntoSelect(N))
5195     return NewSel;
5196 
5197   // reassociate and
5198   if (SDValue RAND = reassociateOps(ISD::AND, SDLoc(N), N0, N1, N->getFlags()))
5199     return RAND;
5200 
5201   // Try to convert a constant mask AND into a shuffle clear mask.
5202   if (VT.isVector())
5203     if (SDValue Shuffle = XformToShuffleWithZero(N))
5204       return Shuffle;
5205 
5206   if (SDValue Combined = combineCarryDiamond(*this, DAG, TLI, N0, N1, N))
5207     return Combined;
5208 
5209   // fold (and (or x, C), D) -> D if (C & D) == D
5210   auto MatchSubset = [](ConstantSDNode *LHS, ConstantSDNode *RHS) {
5211     return RHS->getAPIntValue().isSubsetOf(LHS->getAPIntValue());
5212   };
5213   if (N0.getOpcode() == ISD::OR &&
5214       ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchSubset))
5215     return N1;
5216   // fold (and (any_ext V), c) -> (zero_ext V) if 'and' only clears top bits.
5217   if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
5218     SDValue N0Op0 = N0.getOperand(0);
5219     APInt Mask = ~N1C->getAPIntValue();
5220     Mask = Mask.trunc(N0Op0.getScalarValueSizeInBits());
5221     if (DAG.MaskedValueIsZero(N0Op0, Mask)) {
5222       SDValue Zext = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N),
5223                                  N0.getValueType(), N0Op0);
5224 
5225       // Replace uses of the AND with uses of the Zero extend node.
5226       CombineTo(N, Zext);
5227 
5228       // We actually want to replace all uses of the any_extend with the
5229       // zero_extend, to avoid duplicating things.  This will later cause this
5230       // AND to be folded.
5231       CombineTo(N0.getNode(), Zext);
5232       return SDValue(N, 0);   // Return N so it doesn't get rechecked!
5233     }
5234   }
5235 
5236   // Similarly, fold (and (X (load ([non_ext|any_ext|zero_ext] V))), c) ->
5237   // (X (load ([non_ext|zero_ext] V))) if 'and' only clears top bits which must
5238   // already be zero by virtue of the width of the base type of the load.
5239   //
5240   // The 'X' node here can either be nothing or an extract_vector_elt to catch
5241   // more cases.
5242   if ((N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
5243        N0.getValueSizeInBits() == N0.getOperand(0).getScalarValueSizeInBits() &&
5244        N0.getOperand(0).getOpcode() == ISD::LOAD &&
5245        N0.getOperand(0).getResNo() == 0) ||
5246       (N0.getOpcode() == ISD::LOAD && N0.getResNo() == 0)) {
5247     LoadSDNode *Load = cast<LoadSDNode>( (N0.getOpcode() == ISD::LOAD) ?
5248                                          N0 : N0.getOperand(0) );
5249 
5250     // Get the constant (if applicable) the zero'th operand is being ANDed with.
5251     // This can be a pure constant or a vector splat, in which case we treat the
5252     // vector as a scalar and use the splat value.
5253     APInt Constant = APInt::getNullValue(1);
5254     if (const ConstantSDNode *C = dyn_cast<ConstantSDNode>(N1)) {
5255       Constant = C->getAPIntValue();
5256     } else if (BuildVectorSDNode *Vector = dyn_cast<BuildVectorSDNode>(N1)) {
5257       APInt SplatValue, SplatUndef;
5258       unsigned SplatBitSize;
5259       bool HasAnyUndefs;
5260       bool IsSplat = Vector->isConstantSplat(SplatValue, SplatUndef,
5261                                              SplatBitSize, HasAnyUndefs);
5262       if (IsSplat) {
5263         // Undef bits can contribute to a possible optimisation if set, so
5264         // set them.
5265         SplatValue |= SplatUndef;
5266 
5267         // The splat value may be something like "0x00FFFFFF", which means 0 for
5268         // the first vector value and FF for the rest, repeating. We need a mask
5269         // that will apply equally to all members of the vector, so AND all the
5270         // lanes of the constant together.
5271         unsigned EltBitWidth = Vector->getValueType(0).getScalarSizeInBits();
5272 
5273         // If the splat value has been compressed to a bitlength lower
5274         // than the size of the vector lane, we need to re-expand it to
5275         // the lane size.
5276         if (EltBitWidth > SplatBitSize)
5277           for (SplatValue = SplatValue.zextOrTrunc(EltBitWidth);
5278                SplatBitSize < EltBitWidth; SplatBitSize = SplatBitSize * 2)
5279             SplatValue |= SplatValue.shl(SplatBitSize);
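        // E.g. a splat value compressed to 8 bits (0xFF) in 32-bit lanes
        // re-expands as 0x000000FF -> 0x0000FFFF -> 0xFFFFFFFF over the
        // iterations above, repeating the pattern across the whole lane.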
5280 
5281         // Make sure that variable 'Constant' is only set if 'SplatBitSize' is a
5282         // multiple of 'EltBitWidth'. Otherwise, we could propagate a wrong value.
5283         if ((SplatBitSize % EltBitWidth) == 0) {
5284           Constant = APInt::getAllOnesValue(EltBitWidth);
5285           for (unsigned i = 0, n = (SplatBitSize / EltBitWidth); i < n; ++i)
5286             Constant &= SplatValue.extractBits(EltBitWidth, i * EltBitWidth);
5287         }
5288       }
5289     }
5290 
5291     // If we want to change an EXTLOAD to a ZEXTLOAD, ensure a ZEXTLOAD is
5292     // actually legal and isn't going to get expanded, else this is a false
5293     // optimisation.
5294     bool CanZextLoadProfitably = TLI.isLoadExtLegal(ISD::ZEXTLOAD,
5295                                                     Load->getValueType(0),
5296                                                     Load->getMemoryVT());
5297 
5298     // Resize the constant to the same size as the original memory access before
5299     // extension. If it is still the AllOnesValue then this AND is completely
5300     // unneeded.
5301     Constant = Constant.zextOrTrunc(Load->getMemoryVT().getScalarSizeInBits());
5302 
5303     bool B;
5304     switch (Load->getExtensionType()) {
5305     default: B = false; break;
5306     case ISD::EXTLOAD: B = CanZextLoadProfitably; break;
5307     case ISD::ZEXTLOAD:
5308     case ISD::NON_EXTLOAD: B = true; break;
5309     }
5310 
5311     if (B && Constant.isAllOnesValue()) {
5312       // If the load type was an EXTLOAD, convert to ZEXTLOAD in order to
5313       // preserve semantics once we get rid of the AND.
5314       SDValue NewLoad(Load, 0);
5315 
5316       // Fold the AND away. NewLoad may get replaced immediately.
5317       CombineTo(N, (N0.getNode() == Load) ? NewLoad : N0);
5318 
5319       if (Load->getExtensionType() == ISD::EXTLOAD) {
5320         NewLoad = DAG.getLoad(Load->getAddressingMode(), ISD::ZEXTLOAD,
5321                               Load->getValueType(0), SDLoc(Load),
5322                               Load->getChain(), Load->getBasePtr(),
5323                               Load->getOffset(), Load->getMemoryVT(),
5324                               Load->getMemOperand());
5325         // Replace uses of the EXTLOAD with the new ZEXTLOAD.
5326         if (Load->getNumValues() == 3) {
5327           // PRE/POST_INC loads have 3 values.
5328           SDValue To[] = { NewLoad.getValue(0), NewLoad.getValue(1),
5329                            NewLoad.getValue(2) };
5330           CombineTo(Load, To, 3, true);
5331         } else {
5332           CombineTo(Load, NewLoad.getValue(0), NewLoad.getValue(1));
5333         }
5334       }
5335 
5336       return SDValue(N, 0); // Return N so it doesn't get rechecked!
5337     }
5338   }
5339 
5340   // fold (and (load x), 255) -> (zextload x, i8)
5341   // fold (and (extload x, i16), 255) -> (zextload x, i8)
5342   // fold (and (any_ext (extload x, i16)), 255) -> (zextload x, i8)
5343   if (!VT.isVector() && N1C && (N0.getOpcode() == ISD::LOAD ||
5344                                 (N0.getOpcode() == ISD::ANY_EXTEND &&
5345                                  N0.getOperand(0).getOpcode() == ISD::LOAD))) {
5346     if (SDValue Res = ReduceLoadWidth(N)) {
5347       LoadSDNode *LN0 = N0->getOpcode() == ISD::ANY_EXTEND
5348         ? cast<LoadSDNode>(N0.getOperand(0)) : cast<LoadSDNode>(N0);
5349       AddToWorklist(N);
5350       DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 0), Res);
5351       return SDValue(N, 0);
5352     }
5353   }
5354 
5355   if (LegalTypes) {
5356     // Attempt to propagate the AND back up to the leaves which, if they're
5357     // loads, can be combined to narrow loads and the AND node can be removed.
5358     // Perform after legalization so that extend nodes will already be
5359     // combined into the loads.
5360     if (BackwardsPropagateMask(N))
5361       return SDValue(N, 0);
5362   }
5363 
5364   if (SDValue Combined = visitANDLike(N0, N1, N))
5365     return Combined;
5366 
5367   // Simplify: (and (op x...), (op y...))  -> (op (and x, y))
5368   if (N0.getOpcode() == N1.getOpcode())
5369     if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
5370       return V;
5371 
5372   // Masking the negated extension of a boolean is just the zero-extended
5373   // boolean:
5374   // and (sub 0, zext(bool X)), 1 --> zext(bool X)
5375   // and (sub 0, sext(bool X)), 1 --> zext(bool X)
5376   //
5377   // Note: the SimplifyDemandedBits fold below can make an information-losing
5378   // transform, and then we have no way to find this better fold.
5379   if (N1C && N1C->isOne() && N0.getOpcode() == ISD::SUB) {
5380     if (isNullOrNullSplat(N0.getOperand(0))) {
5381       SDValue SubRHS = N0.getOperand(1);
5382       if (SubRHS.getOpcode() == ISD::ZERO_EXTEND &&
5383           SubRHS.getOperand(0).getScalarValueSizeInBits() == 1)
5384         return SubRHS;
5385       if (SubRHS.getOpcode() == ISD::SIGN_EXTEND &&
5386           SubRHS.getOperand(0).getScalarValueSizeInBits() == 1)
5387         return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, SubRHS.getOperand(0));
5388     }
5389   }
5390 
5391   // fold (and (sign_extend_inreg x, i16 to i32), 1) -> (and x, 1)
5392   // fold (and (sra)) -> (and (srl)) when possible.
5393   if (SimplifyDemandedBits(SDValue(N, 0)))
5394     return SDValue(N, 0);
5395 
5396   // fold (zext_inreg (extload x)) -> (zextload x)
5397   // fold (zext_inreg (sextload x)) -> (zextload x) iff load has one use
5398   if (ISD::isUNINDEXEDLoad(N0.getNode()) &&
5399       (ISD::isEXTLoad(N0.getNode()) ||
5400        (ISD::isSEXTLoad(N0.getNode()) && N0.hasOneUse()))) {
5401     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
5402     EVT MemVT = LN0->getMemoryVT();
5403     // If we zero all the possible extended bits, then we can turn this into
5404     // a zextload if we are running before legalize or the operation is legal.
5405     unsigned ExtBitSize = N1.getScalarValueSizeInBits();
5406     unsigned MemBitSize = MemVT.getScalarSizeInBits();
5407     APInt ExtBits = APInt::getHighBitsSet(ExtBitSize, ExtBitSize - MemBitSize);
5408     if (DAG.MaskedValueIsZero(N1, ExtBits) &&
5409         ((!LegalOperations && LN0->isSimple()) ||
5410          TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT))) {
5411       SDValue ExtLoad =
5412           DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(N0), VT, LN0->getChain(),
5413                          LN0->getBasePtr(), MemVT, LN0->getMemOperand());
5414       AddToWorklist(N);
5415       CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
5416       return SDValue(N, 0); // Return N so it doesn't get rechecked!
5417     }
5418   }
5419 
5420   // fold (and (or (srl N, 8), (shl N, 8)), 0xffff) -> (srl (bswap N), const)
5421   if (N1C && N1C->getAPIntValue() == 0xffff && N0.getOpcode() == ISD::OR) {
5422     if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
5423                                            N0.getOperand(1), false))
5424       return BSwap;
5425   }
5426 
5427   if (SDValue Shifts = unfoldExtremeBitClearingToShifts(N))
5428     return Shifts;
5429 
5430   if (TLI.hasBitTest(N0, N1))
5431     if (SDValue V = combineShiftAnd1ToBitTest(N, DAG))
5432       return V;
5433 
5434   return SDValue();
5435 }
5436 
5437 /// Match (a >> 8) | (a << 8) as (bswap a) >> 16.
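/// For example, assuming i32 and a == 0x00001234:
///   ((a << 8) & 0xff00) | ((a >> 8) & 0x00ff) == 0x00003412, and
///   (bswap a) >> 16 == 0x34120000 >> 16 == 0x00003412.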
5438 SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
5439                                         bool DemandHighBits) {
5440   if (!LegalOperations)
5441     return SDValue();
5442 
5443   EVT VT = N->getValueType(0);
5444   if (VT != MVT::i64 && VT != MVT::i32 && VT != MVT::i16)
5445     return SDValue();
5446   if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
5447     return SDValue();
5448 
5449   // Recognize (and (shl a, 8), 0xff00), (and (srl a, 8), 0xff)
5450   bool LookPassAnd0 = false;
5451   bool LookPassAnd1 = false;
5452   if (N0.getOpcode() == ISD::AND && N0.getOperand(0).getOpcode() == ISD::SRL)
5453       std::swap(N0, N1);
5454   if (N1.getOpcode() == ISD::AND && N1.getOperand(0).getOpcode() == ISD::SHL)
5455       std::swap(N0, N1);
5456   if (N0.getOpcode() == ISD::AND) {
5457     if (!N0.getNode()->hasOneUse())
5458       return SDValue();
5459     ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
5460     // Also handle 0xffff since the LHS is guaranteed to have zeros there.
5461     // This is needed for X86.
5462     if (!N01C || (N01C->getZExtValue() != 0xFF00 &&
5463                   N01C->getZExtValue() != 0xFFFF))
5464       return SDValue();
5465     N0 = N0.getOperand(0);
5466     LookPassAnd0 = true;
5467   }
5468 
5469   if (N1.getOpcode() == ISD::AND) {
5470     if (!N1.getNode()->hasOneUse())
5471       return SDValue();
5472     ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(N1.getOperand(1));
5473     if (!N11C || N11C->getZExtValue() != 0xFF)
5474       return SDValue();
5475     N1 = N1.getOperand(0);
5476     LookPassAnd1 = true;
5477   }
5478 
5479   if (N0.getOpcode() == ISD::SRL && N1.getOpcode() == ISD::SHL)
5480     std::swap(N0, N1);
5481   if (N0.getOpcode() != ISD::SHL || N1.getOpcode() != ISD::SRL)
5482     return SDValue();
5483   if (!N0.getNode()->hasOneUse() || !N1.getNode()->hasOneUse())
5484     return SDValue();
5485 
5486   ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
5487   ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(N1.getOperand(1));
5488   if (!N01C || !N11C)
5489     return SDValue();
5490   if (N01C->getZExtValue() != 8 || N11C->getZExtValue() != 8)
5491     return SDValue();
5492 
5493   // Look for (shl (and a, 0xff), 8), (srl (and a, 0xff00), 8)
5494   SDValue N00 = N0->getOperand(0);
5495   if (!LookPassAnd0 && N00.getOpcode() == ISD::AND) {
5496     if (!N00.getNode()->hasOneUse())
5497       return SDValue();
5498     ConstantSDNode *N001C = dyn_cast<ConstantSDNode>(N00.getOperand(1));
5499     if (!N001C || N001C->getZExtValue() != 0xFF)
5500       return SDValue();
5501     N00 = N00.getOperand(0);
5502     LookPassAnd0 = true;
5503   }
5504 
5505   SDValue N10 = N1->getOperand(0);
5506   if (!LookPassAnd1 && N10.getOpcode() == ISD::AND) {
5507     if (!N10.getNode()->hasOneUse())
5508       return SDValue();
5509     ConstantSDNode *N101C = dyn_cast<ConstantSDNode>(N10.getOperand(1));
5510     // Also allow 0xFFFF since the bits will be shifted out. This is needed
5511     // for X86.
5512     if (!N101C || (N101C->getZExtValue() != 0xFF00 &&
5513                    N101C->getZExtValue() != 0xFFFF))
5514       return SDValue();
5515     N10 = N10.getOperand(0);
5516     LookPassAnd1 = true;
5517   }
5518 
5519   if (N00 != N10)
5520     return SDValue();
5521 
5522   // Make sure everything beyond the low halfword gets set to zero since the
5523   // SRL by 16 will clear the top bits.
5524   unsigned OpSizeInBits = VT.getSizeInBits();
5525   if (DemandHighBits && OpSizeInBits > 16) {
5526     // If the left-shift isn't masked out then the only way this is a bswap is
5527     // if all bits beyond the low 8 are 0. In that case the entire pattern
5528     // reduces to a left shift anyway: leave it for other parts of the combiner.
5529     if (!LookPassAnd0)
5530       return SDValue();
5531 
5532     // However, if the right shift isn't masked out then it might be because
5533     // it's not needed. See if we can spot that too.
5534     if (!LookPassAnd1 &&
5535         !DAG.MaskedValueIsZero(
5536             N10, APInt::getHighBitsSet(OpSizeInBits, OpSizeInBits - 16)))
5537       return SDValue();
5538   }
5539 
5540   SDValue Res = DAG.getNode(ISD::BSWAP, SDLoc(N), VT, N00);
5541   if (OpSizeInBits > 16) {
5542     SDLoc DL(N);
5543     Res = DAG.getNode(ISD::SRL, DL, VT, Res,
5544                       DAG.getConstant(OpSizeInBits - 16, DL,
5545                                       getShiftAmountTy(VT)));
5546   }
5547   return Res;
5548 }
5549 
5550 /// Return true if the specified node is an element that makes up a 32-bit
5551 /// packed halfword byteswap.
5552 /// ((x & 0x000000ff) << 8) |
5553 /// ((x & 0x0000ff00) >> 8) |
5554 /// ((x & 0x00ff0000) << 8) |
5555 /// ((x & 0xff000000) >> 8)
5556 static bool isBSwapHWordElement(SDValue N, MutableArrayRef<SDNode *> Parts) {
5557   if (!N.getNode()->hasOneUse())
5558     return false;
5559 
5560   unsigned Opc = N.getOpcode();
5561   if (Opc != ISD::AND && Opc != ISD::SHL && Opc != ISD::SRL)
5562     return false;
5563 
5564   SDValue N0 = N.getOperand(0);
5565   unsigned Opc0 = N0.getOpcode();
5566   if (Opc0 != ISD::AND && Opc0 != ISD::SHL && Opc0 != ISD::SRL)
5567     return false;
5568 
5569   ConstantSDNode *N1C = nullptr;
5570   // SHL or SRL: look upstream for AND mask operand
5571   if (Opc == ISD::AND)
5572     N1C = dyn_cast<ConstantSDNode>(N.getOperand(1));
5573   else if (Opc0 == ISD::AND)
5574     N1C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
5575   if (!N1C)
5576     return false;
5577 
5578   unsigned MaskByteOffset;
5579   switch (N1C->getZExtValue()) {
5580   default:
5581     return false;
5582   case 0xFF:       MaskByteOffset = 0; break;
5583   case 0xFF00:     MaskByteOffset = 1; break;
5584   case 0xFFFF:
5585     // In case demanded bits didn't clear the bits that will be shifted out.
5586     // This is needed for X86.
5587     if (Opc == ISD::SRL || (Opc == ISD::AND && Opc0 == ISD::SHL)) {
5588       MaskByteOffset = 1;
5589       break;
5590     }
5591     return false;
5592   case 0xFF0000:   MaskByteOffset = 2; break;
5593   case 0xFF000000: MaskByteOffset = 3; break;
5594   }
5595 
5596   // Look for (x & 0xff) << 8 as well as ((x << 8) & 0xff00).
5597   if (Opc == ISD::AND) {
5598     if (MaskByteOffset == 0 || MaskByteOffset == 2) {
5599       // (x >> 8) & 0xff
5600       // (x >> 8) & 0xff0000
5601       if (Opc0 != ISD::SRL)
5602         return false;
5603       ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
5604       if (!C || C->getZExtValue() != 8)
5605         return false;
5606     } else {
5607       // (x << 8) & 0xff00
5608       // (x << 8) & 0xff000000
5609       if (Opc0 != ISD::SHL)
5610         return false;
5611       ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
5612       if (!C || C->getZExtValue() != 8)
5613         return false;
5614     }
5615   } else if (Opc == ISD::SHL) {
5616     // (x & 0xff) << 8
5617     // (x & 0xff0000) << 8
5618     if (MaskByteOffset != 0 && MaskByteOffset != 2)
5619       return false;
5620     ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1));
5621     if (!C || C->getZExtValue() != 8)
5622       return false;
5623   } else { // Opc == ISD::SRL
5624     // (x & 0xff00) >> 8
5625     // (x & 0xff000000) >> 8
5626     if (MaskByteOffset != 1 && MaskByteOffset != 3)
5627       return false;
5628     ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1));
5629     if (!C || C->getZExtValue() != 8)
5630       return false;
5631   }
5632 
5633   if (Parts[MaskByteOffset])
5634     return false;
5635 
5636   Parts[MaskByteOffset] = N0.getOperand(0).getNode();
5637   return true;
5638 }
5639 
5640 // Match 2 elements of a packed halfword bswap.
5641 static bool isBSwapHWordPair(SDValue N, MutableArrayRef<SDNode *> Parts) {
5642   if (N.getOpcode() == ISD::OR)
5643     return isBSwapHWordElement(N.getOperand(0), Parts) &&
5644            isBSwapHWordElement(N.getOperand(1), Parts);
5645 
5646   if (N.getOpcode() == ISD::SRL && N.getOperand(0).getOpcode() == ISD::BSWAP) {
5647     ConstantSDNode *C = isConstOrConstSplat(N.getOperand(1));
5648     if (!C || C->getAPIntValue() != 16)
5649       return false;
5650     Parts[0] = Parts[1] = N.getOperand(0).getOperand(0).getNode();
5651     return true;
5652   }
5653 
5654   return false;
5655 }
5656 
5657 /// Match a 32-bit packed halfword bswap. That is
5658 /// ((x & 0x000000ff) << 8) |
5659 /// ((x & 0x0000ff00) >> 8) |
5660 /// ((x & 0x00ff0000) << 8) |
5661 /// ((x & 0xff000000) >> 8)
5662 /// => (rotl (bswap x), 16)
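/// For example, x == 0xAABBCCDD gives 0xBBAADDCC: each halfword has its two
/// bytes swapped, which matches (rotl (bswap x), 16) since
/// bswap(0xAABBCCDD) == 0xDDCCBBAA and rotating that left by 16 yields
/// 0xBBAADDCC.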
5663 SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
5664   if (!LegalOperations)
5665     return SDValue();
5666 
5667   EVT VT = N->getValueType(0);
5668   if (VT != MVT::i32)
5669     return SDValue();
5670   if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
5671     return SDValue();
5672 
5673   // Look for either
5674   // (or (bswaphpair), (bswaphpair))
5675   // (or (or (bswaphpair), (and)), (and))
5676   // (or (or (and), (bswaphpair)), (and))
5677   SDNode *Parts[4] = {};
5678 
5679   if (isBSwapHWordPair(N0, Parts)) {
5680     // (or (or (and), (and)), (or (and), (and)))
5681     if (!isBSwapHWordPair(N1, Parts))
5682       return SDValue();
5683   } else if (N0.getOpcode() == ISD::OR) {
5684     // (or (or (or (and), (and)), (and)), (and))
5685     if (!isBSwapHWordElement(N1, Parts))
5686       return SDValue();
5687     SDValue N00 = N0.getOperand(0);
5688     SDValue N01 = N0.getOperand(1);
5689     if (!(isBSwapHWordElement(N01, Parts) && isBSwapHWordPair(N00, Parts)) &&
5690         !(isBSwapHWordElement(N00, Parts) && isBSwapHWordPair(N01, Parts)))
5691       return SDValue();
5692   } else
5693     return SDValue();
5694 
5695   // Make sure the parts are all coming from the same node.
5696   if (Parts[0] != Parts[1] || Parts[0] != Parts[2] || Parts[0] != Parts[3])
5697     return SDValue();
5698 
5699   SDLoc DL(N);
5700   SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT,
5701                               SDValue(Parts[0], 0));
5702 
5703   // Result of the bswap should be rotated by 16. If rotating is not legal,
5704   // then do (x << 16) | (x >> 16).
5705   SDValue ShAmt = DAG.getConstant(16, DL, getShiftAmountTy(VT));
5706   if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT))
5707     return DAG.getNode(ISD::ROTL, DL, VT, BSwap, ShAmt);
5708   if (TLI.isOperationLegalOrCustom(ISD::ROTR, VT))
5709     return DAG.getNode(ISD::ROTR, DL, VT, BSwap, ShAmt);
5710   return DAG.getNode(ISD::OR, DL, VT,
5711                      DAG.getNode(ISD::SHL, DL, VT, BSwap, ShAmt),
5712                      DAG.getNode(ISD::SRL, DL, VT, BSwap, ShAmt));
5713 }
5714 
5715 /// This contains all DAGCombine rules which reduce two values combined by
5716 /// an Or operation to a single value. \see visitANDLike().
5717 SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, SDNode *N) {
5718   EVT VT = N1.getValueType();
5719   SDLoc DL(N);
5720 
5721   // fold (or x, undef) -> -1
5722   if (!LegalOperations && (N0.isUndef() || N1.isUndef()))
5723     return DAG.getAllOnesConstant(DL, VT);
5724 
5725   if (SDValue V = foldLogicOfSetCCs(false, N0, N1, DL))
5726     return V;
5727 
5728   // (or (and X, C1), (and Y, C2))  -> (and (or X, Y), C3) if possible.
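  // For example, with C1 == 0xFF00 and C2 == 0x00FF, if the low byte of X and
  // the high byte of Y are known zero, the expression is equivalent to
  // (and (or X, Y), 0xFFFF).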
5729   if (N0.getOpcode() == ISD::AND && N1.getOpcode() == ISD::AND &&
5730       // Don't increase # computations.
5731       (N0.getNode()->hasOneUse() || N1.getNode()->hasOneUse())) {
5732     // We can only do this xform if we know that bits from X that are set in C2
5733     // but not in C1 are already zero.  Likewise for Y.
5734     if (const ConstantSDNode *N0O1C =
5735         getAsNonOpaqueConstant(N0.getOperand(1))) {
5736       if (const ConstantSDNode *N1O1C =
5737           getAsNonOpaqueConstant(N1.getOperand(1))) {
5740         const APInt &LHSMask = N0O1C->getAPIntValue();
5741         const APInt &RHSMask = N1O1C->getAPIntValue();
5742 
5743         if (DAG.MaskedValueIsZero(N0.getOperand(0), RHSMask&~LHSMask) &&
5744             DAG.MaskedValueIsZero(N1.getOperand(0), LHSMask&~RHSMask)) {
5745           SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
5746                                   N0.getOperand(0), N1.getOperand(0));
5747           return DAG.getNode(ISD::AND, DL, VT, X,
5748                              DAG.getConstant(LHSMask | RHSMask, DL, VT));
5749         }
5750       }
5751     }
5752   }
5753 
5754   // (or (and X, M), (and X, N)) -> (and X, (or M, N))
5755   if (N0.getOpcode() == ISD::AND &&
5756       N1.getOpcode() == ISD::AND &&
5757       N0.getOperand(0) == N1.getOperand(0) &&
5758       // Don't increase # computations.
5759       (N0.getNode()->hasOneUse() || N1.getNode()->hasOneUse())) {
5760     SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
5761                             N0.getOperand(1), N1.getOperand(1));
5762     return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), X);
5763   }
5764 
5765   return SDValue();
5766 }
5767 
5768 /// OR combines for which the commuted variant will be tried as well.
5769 static SDValue visitORCommutative(
5770     SelectionDAG &DAG, SDValue N0, SDValue N1, SDNode *N) {
5771   EVT VT = N0.getValueType();
5772   if (N0.getOpcode() == ISD::AND) {
5773     // fold (or (and X, (xor Y, -1)), Y) -> (or X, Y)
5774     if (isBitwiseNot(N0.getOperand(1)) && N0.getOperand(1).getOperand(0) == N1)
5775       return DAG.getNode(ISD::OR, SDLoc(N), VT, N0.getOperand(0), N1);
5776 
5777     // fold (or (and (xor Y, -1), X), Y) -> (or X, Y)
5778     if (isBitwiseNot(N0.getOperand(0)) && N0.getOperand(0).getOperand(0) == N1)
5779       return DAG.getNode(ISD::OR, SDLoc(N), VT, N0.getOperand(1), N1);
5780   }
5781 
5782   return SDValue();
5783 }
5784 
5785 SDValue DAGCombiner::visitOR(SDNode *N) {
5786   SDValue N0 = N->getOperand(0);
5787   SDValue N1 = N->getOperand(1);
5788   EVT VT = N1.getValueType();
5789 
5790   // x | x --> x
5791   if (N0 == N1)
5792     return N0;
5793 
5794   // fold vector ops
5795   if (VT.isVector()) {
5796     if (SDValue FoldedVOp = SimplifyVBinOp(N))
5797       return FoldedVOp;
5798 
5799     // fold (or x, 0) -> x, vector edition
5800     if (ISD::isBuildVectorAllZeros(N0.getNode()))
5801       return N1;
5802     if (ISD::isBuildVectorAllZeros(N1.getNode()))
5803       return N0;
5804 
5805     // fold (or x, -1) -> -1, vector edition
5806     if (ISD::isBuildVectorAllOnes(N0.getNode()))
5807       // do not return N0, because an undef node may exist in N0
5808       return DAG.getAllOnesConstant(SDLoc(N), N0.getValueType());
5809     if (ISD::isBuildVectorAllOnes(N1.getNode()))
5810       // do not return N1, because an undef node may exist in N1
5811       return DAG.getAllOnesConstant(SDLoc(N), N1.getValueType());
5812 
5813     // fold (or (shuf A, V_0, MA), (shuf B, V_0, MB)) -> (shuf A, B, Mask)
5814     // Do this only if the resulting shuffle is legal.
5815     if (isa<ShuffleVectorSDNode>(N0) &&
5816         isa<ShuffleVectorSDNode>(N1) &&
5817         // Avoid folding a node with illegal type.
5818         TLI.isTypeLegal(VT)) {
5819       bool ZeroN00 = ISD::isBuildVectorAllZeros(N0.getOperand(0).getNode());
5820       bool ZeroN01 = ISD::isBuildVectorAllZeros(N0.getOperand(1).getNode());
5821       bool ZeroN10 = ISD::isBuildVectorAllZeros(N1.getOperand(0).getNode());
5822       bool ZeroN11 = ISD::isBuildVectorAllZeros(N1.getOperand(1).getNode());
5823       // Ensure both shuffles have a zero input.
5824       if ((ZeroN00 != ZeroN01) && (ZeroN10 != ZeroN11)) {
5825         assert((!ZeroN00 || !ZeroN01) && "Both inputs zero!");
5826         assert((!ZeroN10 || !ZeroN11) && "Both inputs zero!");
5827         const ShuffleVectorSDNode *SV0 = cast<ShuffleVectorSDNode>(N0);
5828         const ShuffleVectorSDNode *SV1 = cast<ShuffleVectorSDNode>(N1);
5829         bool CanFold = true;
5830         int NumElts = VT.getVectorNumElements();
5831         SmallVector<int, 4> Mask(NumElts);
5832 
5833         for (int i = 0; i != NumElts; ++i) {
5834           int M0 = SV0->getMaskElt(i);
5835           int M1 = SV1->getMaskElt(i);
5836 
5837           // Determine if either index is pointing to a zero vector.
5838           bool M0Zero = M0 < 0 || (ZeroN00 == (M0 < NumElts));
5839           bool M1Zero = M1 < 0 || (ZeroN10 == (M1 < NumElts));
5840 
5841           // If one element is zero and the other side is undef, keep undef.
5842           // This also handles the case that both are undef.
5843           if ((M0Zero && M1 < 0) || (M1Zero && M0 < 0)) {
5844             Mask[i] = -1;
5845             continue;
5846           }
5847 
5848           // Make sure only one of the elements is zero.
5849           if (M0Zero == M1Zero) {
5850             CanFold = false;
5851             break;
5852           }
5853 
5854           assert((M0 >= 0 || M1 >= 0) && "Undef index!");
5855 
5856           // We have a zero and non-zero element. If the non-zero came from
5857           // SV0 make the index a LHS index. If it came from SV1, make it
5858           // a RHS index. We need to mod by NumElts because we don't care
5859           // which operand it came from in the original shuffles.
5860           Mask[i] = M1Zero ? M0 % NumElts : (M1 % NumElts) + NumElts;
5861         }
5862 
5863         if (CanFold) {
5864           SDValue NewLHS = ZeroN00 ? N0.getOperand(1) : N0.getOperand(0);
5865           SDValue NewRHS = ZeroN10 ? N1.getOperand(1) : N1.getOperand(0);
5866 
5867           SDValue LegalShuffle =
5868               TLI.buildLegalVectorShuffle(VT, SDLoc(N), NewLHS, NewRHS,
5869                                           Mask, DAG);
5870           if (LegalShuffle)
5871             return LegalShuffle;
5872         }
5873       }
5874     }
5875   }
5876 
5877   // fold (or c1, c2) -> c1|c2
5878   ConstantSDNode *N0C = getAsNonOpaqueConstant(N0);
5879   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
5880   if (N0C && N1C && !N1C->isOpaque())
5881     return DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N), VT, N0C, N1C);
5882   // canonicalize constant to RHS
5883   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
5884      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
5885     return DAG.getNode(ISD::OR, SDLoc(N), VT, N1, N0);
5886   // fold (or x, 0) -> x
5887   if (isNullConstant(N1))
5888     return N0;
5889   // fold (or x, -1) -> -1
5890   if (isAllOnesConstant(N1))
5891     return N1;
5892 
5893   if (SDValue NewSel = foldBinOpIntoSelect(N))
5894     return NewSel;
5895 
5896   // fold (or x, c) -> c iff (x & ~c) == 0
5897   if (N1C && DAG.MaskedValueIsZero(N0, ~N1C->getAPIntValue()))
5898     return N1;
5899 
5900   if (SDValue Combined = visitORLike(N0, N1, N))
5901     return Combined;
5902 
5903   if (SDValue Combined = combineCarryDiamond(*this, DAG, TLI, N0, N1, N))
5904     return Combined;
5905 
5906   // Recognize halfword bswaps as (bswap + rotl 16) or (bswap + shl 16)
5907   if (SDValue BSwap = MatchBSwapHWord(N, N0, N1))
5908     return BSwap;
5909   if (SDValue BSwap = MatchBSwapHWordLow(N, N0, N1))
5910     return BSwap;
5911 
5912   // reassociate or
5913   if (SDValue ROR = reassociateOps(ISD::OR, SDLoc(N), N0, N1, N->getFlags()))
5914     return ROR;
5915 
5916   // Canonicalize (or (and X, c1), c2) -> (and (or X, c2), c1|c2)
5917   // iff (c1 & c2) != 0 or c1/c2 are undef.
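  // For example, (or (and X, 0xFF), 0x0F) -> (and (or X, 0x0F), 0xFF), since
  // (X & 0xFF) | 0x0F == (X | 0x0F) & (0xFF | 0x0F) when the constants
  // intersect.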
5918   auto MatchIntersect = [](ConstantSDNode *C1, ConstantSDNode *C2) {
5919     return !C1 || !C2 || C1->getAPIntValue().intersects(C2->getAPIntValue());
5920   };
5921   if (N0.getOpcode() == ISD::AND && N0.getNode()->hasOneUse() &&
5922       ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchIntersect, true)) {
5923     if (SDValue COR = DAG.FoldConstantArithmetic(
5924             ISD::OR, SDLoc(N1), VT, N1.getNode(), N0.getOperand(1).getNode())) {
5925       SDValue IOR = DAG.getNode(ISD::OR, SDLoc(N0), VT, N0.getOperand(0), N1);
5926       AddToWorklist(IOR.getNode());
5927       return DAG.getNode(ISD::AND, SDLoc(N), VT, COR, IOR);
5928     }
5929   }
5930 
5931   if (SDValue Combined = visitORCommutative(DAG, N0, N1, N))
5932     return Combined;
5933   if (SDValue Combined = visitORCommutative(DAG, N1, N0, N))
5934     return Combined;
5935 
5936   // Simplify: (or (op x...), (op y...))  -> (op (or x, y))
5937   if (N0.getOpcode() == N1.getOpcode())
5938     if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
5939       return V;
5940 
5941   // See if this is some rotate idiom.
5942   if (SDValue Rot = MatchRotate(N0, N1, SDLoc(N)))
5943     return Rot;
5944 
5945   if (SDValue Load = MatchLoadCombine(N))
5946     return Load;
5947 
5948   // Simplify the operands using demanded-bits information.
5949   if (SimplifyDemandedBits(SDValue(N, 0)))
5950     return SDValue(N, 0);
5951 
5952   // If OR can be rewritten into ADD, try combines based on ADD.
5953   if ((!LegalOperations || TLI.isOperationLegal(ISD::ADD, VT)) &&
5954       DAG.haveNoCommonBitsSet(N0, N1))
5955     if (SDValue Combined = visitADDLike(N))
5956       return Combined;
5957 
5958   return SDValue();
5959 }
5960 
5961 static SDValue stripConstantMask(SelectionDAG &DAG, SDValue Op, SDValue &Mask) {
5962   if (Op.getOpcode() == ISD::AND &&
5963       DAG.isConstantIntBuildVectorOrConstantInt(Op.getOperand(1))) {
5964     Mask = Op.getOperand(1);
5965     return Op.getOperand(0);
5966   }
5967   return Op;
5968 }
5969 
5970 /// Match "(X shl/srl V1) & V2" where V2 may not be present.
5971 static bool matchRotateHalf(SelectionDAG &DAG, SDValue Op, SDValue &Shift,
5972                             SDValue &Mask) {
5973   Op = stripConstantMask(DAG, Op, Mask);
5974   if (Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SHL) {
5975     Shift = Op;
5976     return true;
5977   }
5978   return false;
5979 }
5980 
5981 /// Helper function for visitOR to extract the needed side of a rotate idiom
5982 /// from a shl/srl/mul/udiv.  This is meant to handle cases where
5983 /// InstCombine merged some outside op with one of the shifts from
5984 /// the rotate pattern.
5985 /// \returns An empty \c SDValue if the needed shift couldn't be extracted.
5986 /// Otherwise, returns an expansion of \p ExtractFrom based on the following
5987 /// patterns:
5988 ///
5989 ///   (or (add v v) (srl v bitwidth-1)):
5990 ///     expands (add v v) -> (shl v 1)
5991 ///
5992 ///   (or (mul v c0) (srl (mul v c1) c2)):
5993 ///     expands (mul v c0) -> (shl (mul v c1) c3)
5994 ///
5995 ///   (or (udiv v c0) (shl (udiv v c1) c2)):
5996 ///     expands (udiv v c0) -> (srl (udiv v c1) c3)
5997 ///
5998 ///   (or (shl v c0) (srl (shl v c1) c2)):
5999 ///     expands (shl v c0) -> (shl (shl v c1) c3)
6000 ///
6001 ///   (or (srl v c0) (shl (srl v c1) c2)):
6002 ///     expands (srl v c0) -> (srl (srl v c1) c3)
6003 ///
6004 /// Such that in all cases, c3+c2==bitwidth(op v c1).
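///
/// For example, for i32 in the pure shift case:
///   (or (shl v 7) (srl (shl v 3) 28))
/// needs c3 == 32 - 28 == 4, and since c0 == c1 + c3 (7 == 3 + 4), the
/// (shl v 7) expands to (shl (shl v 3) 4), exposing the rotate
/// (or (shl t 4) (srl t 28)) with t == (shl v 3).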
6005 static SDValue extractShiftForRotate(SelectionDAG &DAG, SDValue OppShift,
6006                                      SDValue ExtractFrom, SDValue &Mask,
6007                                      const SDLoc &DL) {
6008   assert(OppShift && ExtractFrom && "Empty SDValue");
6009   assert(
6010       (OppShift.getOpcode() == ISD::SHL || OppShift.getOpcode() == ISD::SRL) &&
6011       "Existing shift must be valid as a rotate half");
6012 
6013   ExtractFrom = stripConstantMask(DAG, ExtractFrom, Mask);
6014 
6015   // Value and Type of the shift.
6016   SDValue OppShiftLHS = OppShift.getOperand(0);
6017   EVT ShiftedVT = OppShiftLHS.getValueType();
6018 
6019   // Amount of the existing shift.
6020   ConstantSDNode *OppShiftCst = isConstOrConstSplat(OppShift.getOperand(1));
6021 
6022   // (add v v) -> (shl v 1)
6023   if (OppShift.getOpcode() == ISD::SRL && OppShiftCst &&
6024       ExtractFrom.getOpcode() == ISD::ADD &&
6025       ExtractFrom.getOperand(0) == ExtractFrom.getOperand(1) &&
6026       ExtractFrom.getOperand(0) == OppShiftLHS &&
6027       OppShiftCst->getAPIntValue() == ShiftedVT.getScalarSizeInBits() - 1)
6028     return DAG.getNode(ISD::SHL, DL, ShiftedVT, OppShiftLHS,
6029                        DAG.getShiftAmountConstant(1, ShiftedVT, DL));
6030 
6031   // Preconditions:
6032   //    (or (op0 v c0) (shiftl/r (op0 v c1) c2))
6033   //
6034   // Find opcode of the needed shift to be extracted from (op0 v c0).
6035   unsigned Opcode = ISD::DELETED_NODE;
6036   bool IsMulOrDiv = false;
6037   // Set Opcode and IsMulOrDiv if the extract opcode matches the needed shift
6038   // opcode or its arithmetic (mul or udiv) variant.
6039   auto SelectOpcode = [&](unsigned NeededShift, unsigned MulOrDivVariant) {
6040     IsMulOrDiv = ExtractFrom.getOpcode() == MulOrDivVariant;
6041     if (!IsMulOrDiv && ExtractFrom.getOpcode() != NeededShift)
6042       return false;
6043     Opcode = NeededShift;
6044     return true;
6045   };
6046   // op0 must be either the needed shift opcode or the mul/udiv equivalent
6047   // that the needed shift can be extracted from.
6048   if ((OppShift.getOpcode() != ISD::SRL || !SelectOpcode(ISD::SHL, ISD::MUL)) &&
6049       (OppShift.getOpcode() != ISD::SHL || !SelectOpcode(ISD::SRL, ISD::UDIV)))
6050     return SDValue();
6051 
6052   // op0 must be the same opcode on both sides, have the same LHS argument,
6053   // and produce the same value type.
6054   if (OppShiftLHS.getOpcode() != ExtractFrom.getOpcode() ||
6055       OppShiftLHS.getOperand(0) != ExtractFrom.getOperand(0) ||
6056       ShiftedVT != ExtractFrom.getValueType())
6057     return SDValue();
6058 
6059   // Constant mul/udiv/shift amount from the RHS of the shift's LHS op.
6060   ConstantSDNode *OppLHSCst = isConstOrConstSplat(OppShiftLHS.getOperand(1));
6061   // Constant mul/udiv/shift amount from the RHS of the ExtractFrom op.
6062   ConstantSDNode *ExtractFromCst =
6063       isConstOrConstSplat(ExtractFrom.getOperand(1));
6064   // TODO: We should be able to handle non-uniform constant vectors for these values
6065   // Check that we have constant values.
6066   if (!OppShiftCst || !OppShiftCst->getAPIntValue() ||
6067       !OppLHSCst || !OppLHSCst->getAPIntValue() ||
6068       !ExtractFromCst || !ExtractFromCst->getAPIntValue())
6069     return SDValue();
6070 
6071   // Compute the shift amount we need to extract to complete the rotate.
6072   const unsigned VTWidth = ShiftedVT.getScalarSizeInBits();
6073   if (OppShiftCst->getAPIntValue().ugt(VTWidth))
6074     return SDValue();
6075   APInt NeededShiftAmt = VTWidth - OppShiftCst->getAPIntValue();
6076   // Normalize the bitwidth of the two mul/udiv/shift constant operands.
6077   APInt ExtractFromAmt = ExtractFromCst->getAPIntValue();
6078   APInt OppLHSAmt = OppLHSCst->getAPIntValue();
6079   zeroExtendToMatch(ExtractFromAmt, OppLHSAmt);
6080 
6081   // Now try extract the needed shift from the ExtractFrom op and see if the
6082   // result matches up with the existing shift's LHS op.
6083   if (IsMulOrDiv) {
6084     // Op to extract from is a mul or udiv by a constant.
6085     // Check:
6086     //     c2 / (1 << (bitwidth(op0 v c0) - c1)) == c0
6087     //     c2 % (1 << (bitwidth(op0 v c0) - c1)) == 0
6088     const APInt ExtractDiv = APInt::getOneBitSet(ExtractFromAmt.getBitWidth(),
6089                                                  NeededShiftAmt.getZExtValue());
6090     APInt ResultAmt;
6091     APInt Rem;
6092     APInt::udivrem(ExtractFromAmt, ExtractDiv, ResultAmt, Rem);
6093     if (Rem != 0 || ResultAmt != OppLHSAmt)
6094       return SDValue();
6095   } else {
6096     // Op to extract from is a shift by a constant.
6097     // Check:
6098     //      c2 - (bitwidth(op0 v c0) - c1) == c0
6099     if (OppLHSAmt != ExtractFromAmt - NeededShiftAmt.zextOrTrunc(
6100                                           ExtractFromAmt.getBitWidth()))
6101       return SDValue();
6102   }
6103 
6104   // Return the expanded shift op that should allow a rotate to be formed.
6105   EVT ShiftVT = OppShift.getOperand(1).getValueType();
6106   EVT ResVT = ExtractFrom.getValueType();
6107   SDValue NewShiftNode = DAG.getConstant(NeededShiftAmt, DL, ShiftVT);
6108   return DAG.getNode(Opcode, DL, ResVT, OppShiftLHS, NewShiftNode);
6109 }
6110 
6111 // Return true if we can prove that, whenever Neg and Pos are both in the
6112 // range [0, EltSize), Neg == (Pos == 0 ? 0 : EltSize - Pos).  This means that
6113 // for two opposing shifts shift1 and shift2 and a value X with OpBits bits:
6114 //
6115 //     (or (shift1 X, Neg), (shift2 X, Pos))
6116 //
6117 // reduces to a rotate in direction shift2 by Pos or (equivalently) a rotate
6118 // in direction shift1 by Neg.  The range [0, EltSize) means that we only need
6119 // to consider shift amounts with defined behavior.
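// For example, with EltSize == 32, Neg == (sub 32, Pos) satisfies this, so
// (or (srl X, Neg), (shl X, Pos)) reduces to (rotl X, Pos).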
6120 static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
6121                            SelectionDAG &DAG) {
6122   // If EltSize is a power of 2 then:
6123   //
6124   //  (a) (Pos == 0 ? 0 : EltSize - Pos) == (EltSize - Pos) & (EltSize - 1)
6125   //  (b) Neg == Neg & (EltSize - 1) whenever Neg is in [0, EltSize).
6126   //
6127   // So if EltSize is a power of 2 and Neg is (and Neg', EltSize-1), we check
6128   // for the stronger condition:
6129   //
6130   //     Neg & (EltSize - 1) == (EltSize - Pos) & (EltSize - 1)    [A]
6131   //
6132   // for all Neg and Pos.  Since Neg & (EltSize - 1) == Neg' & (EltSize - 1)
6133   // we can just replace Neg with Neg' for the rest of the function.
6134   //
6135   // In other cases we check for the even stronger condition:
6136   //
6137   //     Neg == EltSize - Pos                                    [B]
6138   //
6139   // for all Neg and Pos.  Note that the (or ...) then invokes undefined
6140   // behavior if Pos == 0 (and consequently Neg == EltSize).
6141   //
6142   // We could actually use [A] whenever EltSize is a power of 2, but the
6143   // only extra cases that it would match are those uninteresting ones
6144   // where Neg and Pos are never in range at the same time.  E.g. for
6145   // EltSize == 32, using [A] would allow a Neg of the form (sub 64, Pos)
6146   // as well as (sub 32, Pos), but:
6147   //
6148   //     (or (shift1 X, (sub 64, Pos)), (shift2 X, Pos))
6149   //
6150   // always invokes undefined behavior for 32-bit X.
6151   //
6152   // Below, Mask == EltSize - 1 when using [A] and is all-ones otherwise.
6153   unsigned MaskLoBits = 0;
6154   if (Neg.getOpcode() == ISD::AND && isPowerOf2_64(EltSize)) {
6155     if (ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(1))) {
6156       KnownBits Known = DAG.computeKnownBits(Neg.getOperand(0));
6157       unsigned Bits = Log2_64(EltSize);
6158       if (NegC->getAPIntValue().getActiveBits() <= Bits &&
6159           ((NegC->getAPIntValue() | Known.Zero).countTrailingOnes() >= Bits)) {
6160         Neg = Neg.getOperand(0);
6161         MaskLoBits = Bits;
6162       }
6163     }
6164   }
6165 
6166   // Check whether Neg has the form (sub NegC, NegOp1) for some NegC and NegOp1.
6167   if (Neg.getOpcode() != ISD::SUB)
6168     return false;
6169   ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(0));
6170   if (!NegC)
6171     return false;
6172   SDValue NegOp1 = Neg.getOperand(1);
6173 
6174   // On the RHS of [A], if Pos is Pos' & (EltSize - 1), just replace Pos with
6175   // Pos'.  The truncation is redundant for the purpose of the equality.
6176   if (MaskLoBits && Pos.getOpcode() == ISD::AND) {
6177     if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1))) {
6178       KnownBits Known = DAG.computeKnownBits(Pos.getOperand(0));
6179       if (PosC->getAPIntValue().getActiveBits() <= MaskLoBits &&
6180           ((PosC->getAPIntValue() | Known.Zero).countTrailingOnes() >=
6181            MaskLoBits))
6182         Pos = Pos.getOperand(0);
6183     }
6184   }
6185 
6186   // The condition we need is now:
6187   //
6188   //     (NegC - NegOp1) & Mask == (EltSize - Pos) & Mask
6189   //
6190   // If NegOp1 == Pos then we need:
6191   //
6192   //              EltSize & Mask == NegC & Mask
6193   //
6194   // (because "x & Mask" is a truncation and distributes through subtraction).
6195   APInt Width;
6196   if (Pos == NegOp1)
6197     Width = NegC->getAPIntValue();
6198 
6199   // Check for cases where Pos has the form (add NegOp1, PosC) for some PosC.
6200   // Then the condition we want to prove becomes:
6201   //
6202   //     (NegC - NegOp1) & Mask == (EltSize - (NegOp1 + PosC)) & Mask
6203   //
6204   // which, again because "x & Mask" is a truncation, becomes:
6205   //
6206   //                NegC & Mask == (EltSize - PosC) & Mask
6207   //             EltSize & Mask == (NegC + PosC) & Mask
6208   else if (Pos.getOpcode() == ISD::ADD && Pos.getOperand(0) == NegOp1) {
6209     if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1)))
6210       Width = PosC->getAPIntValue() + NegC->getAPIntValue();
6211     else
6212       return false;
6213   } else
6214     return false;
6215 
6216   // Now we just need to check that EltSize & Mask == Width & Mask.
6217   if (MaskLoBits)
6218     // EltSize & Mask is 0 since Mask is EltSize - 1.
6219     return Width.getLoBits(MaskLoBits) == 0;
6220   return Width == EltSize;
6221 }
6222 
6223 // A subroutine of MatchRotate used once we have found an OR of two opposite
6224 // shifts of Shifted.  If Neg == <operand size> - Pos then the OR reduces
6225 // to both (PosOpcode Shifted, Pos) and (NegOpcode Shifted, Neg), with the
6226 // former being preferred if supported.  InnerPos and InnerNeg are Pos and
6227 // Neg with outer conversions stripped away.
6228 SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
6229                                        SDValue Neg, SDValue InnerPos,
6230                                        SDValue InnerNeg, unsigned PosOpcode,
6231                                        unsigned NegOpcode, const SDLoc &DL) {
6232   // fold (or (shl x, (*ext y)),
6233   //          (srl x, (*ext (sub 32, y)))) ->
6234   //   (rotl x, y) or (rotr x, (sub 32, y))
6235   //
6236   // fold (or (shl x, (*ext (sub 32, y))),
6237   //          (srl x, (*ext y))) ->
6238   //   (rotr x, y) or (rotl x, (sub 32, y))
6239   EVT VT = Shifted.getValueType();
6240   if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits(), DAG)) {
6241     bool HasPos = TLI.isOperationLegalOrCustom(PosOpcode, VT);
6242     return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, Shifted,
6243                        HasPos ? Pos : Neg);
6244   }
6245 
6246   return SDValue();
6247 }
6248 
6249 // MatchRotate - Handle an 'or' of two operands.  If this is one of the many
6250 // idioms for rotate, and if the target supports rotation instructions, generate
6251 // a rot[lr].
6252 SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
6253   // Must be a legal type.  Expanded 'n promoted things won't work with rotates.
6254   EVT VT = LHS.getValueType();
6255   if (!TLI.isTypeLegal(VT))
6256     return SDValue();
6257 
6258   // The target must have at least one rotate flavor.
6259   bool HasROTL = hasOperation(ISD::ROTL, VT);
6260   bool HasROTR = hasOperation(ISD::ROTR, VT);
6261   if (!HasROTL && !HasROTR)
6262     return SDValue();
6263 
6264   // Check for truncated rotate.
6265   if (LHS.getOpcode() == ISD::TRUNCATE && RHS.getOpcode() == ISD::TRUNCATE &&
6266       LHS.getOperand(0).getValueType() == RHS.getOperand(0).getValueType()) {
6267     assert(LHS.getValueType() == RHS.getValueType());
6268     if (SDValue Rot = MatchRotate(LHS.getOperand(0), RHS.getOperand(0), DL)) {
6269       return DAG.getNode(ISD::TRUNCATE, SDLoc(LHS), LHS.getValueType(), Rot);
6270     }
6271   }
6272 
6273   // Match "(X shl/srl V1) & V2" where V2 may not be present.
6274   SDValue LHSShift;   // The shift.
6275   SDValue LHSMask;    // AND value if any.
6276   matchRotateHalf(DAG, LHS, LHSShift, LHSMask);
6277 
6278   SDValue RHSShift;   // The shift.
6279   SDValue RHSMask;    // AND value if any.
6280   matchRotateHalf(DAG, RHS, RHSShift, RHSMask);
6281 
6282   // If neither side matched a rotate half, bail
6283   if (!LHSShift && !RHSShift)
6284     return SDValue();
6285 
6286   // InstCombine may have combined a constant shl, srl, mul, or udiv with one
6287   // side of the rotate, so try to handle that here. In all cases we need to
6288   // pass the matched shift from the opposite side to compute the opcode and
6289   // needed shift amount to extract.  We still want to do this if both sides
6290   // matched a rotate half because one half may be a potential overshift that
6291   // can be broken down (i.e. if InstCombine merged two shl or srl ops into a
6292   // single one).
6293 
6294   // Have LHS side of the rotate, try to extract the needed shift from the RHS.
6295   if (LHSShift)
6296     if (SDValue NewRHSShift =
6297             extractShiftForRotate(DAG, LHSShift, RHS, RHSMask, DL))
6298       RHSShift = NewRHSShift;
6299   // Have RHS side of the rotate, try to extract the needed shift from the LHS.
6300   if (RHSShift)
6301     if (SDValue NewLHSShift =
6302             extractShiftForRotate(DAG, RHSShift, LHS, LHSMask, DL))
6303       LHSShift = NewLHSShift;
6304 
6305   // If a side is still missing, nothing else we can do.
6306   if (!RHSShift || !LHSShift)
6307     return SDValue();
6308 
6309   // At this point we've matched or extracted a shift op on each side.
6310 
6311   if (LHSShift.getOperand(0) != RHSShift.getOperand(0))
6312     return SDValue(); // Not shifting the same value.
6313 
6314   if (LHSShift.getOpcode() == RHSShift.getOpcode())
6315     return SDValue(); // Shifts must disagree.
6316 
6317   // Canonicalize shl to left side in a shl/srl pair.
6318   if (RHSShift.getOpcode() == ISD::SHL) {
6319     std::swap(LHS, RHS);
6320     std::swap(LHSShift, RHSShift);
6321     std::swap(LHSMask, RHSMask);
6322   }
6323 
6324   unsigned EltSizeInBits = VT.getScalarSizeInBits();
6325   SDValue LHSShiftArg = LHSShift.getOperand(0);
6326   SDValue LHSShiftAmt = LHSShift.getOperand(1);
6327   SDValue RHSShiftArg = RHSShift.getOperand(0);
6328   SDValue RHSShiftAmt = RHSShift.getOperand(1);
6329 
6330   // fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1)
6331   // fold (or (shl x, C1), (srl x, C2)) -> (rotr x, C2)
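  // For example, assuming i32: (or (shl x, 8), (srl x, 24)) matches because
  // 8 + 24 == 32, giving (rotl x, 8), or equivalently (rotr x, 24).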
6332   auto MatchRotateSum = [EltSizeInBits](ConstantSDNode *LHS,
6333                                         ConstantSDNode *RHS) {
6334     return (LHS->getAPIntValue() + RHS->getAPIntValue()) == EltSizeInBits;
6335   };
6336   if (ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) {
6337     SDValue Rot = DAG.getNode(HasROTL ? ISD::ROTL : ISD::ROTR, DL, VT,
6338                               LHSShiftArg, HasROTL ? LHSShiftAmt : RHSShiftAmt);
6339 
6340     // If there is an AND of either shifted operand, apply it to the result.
6341     if (LHSMask.getNode() || RHSMask.getNode()) {
6342       SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
6343       SDValue Mask = AllOnes;
6344 
6345       if (LHSMask.getNode()) {
6346         SDValue RHSBits = DAG.getNode(ISD::SRL, DL, VT, AllOnes, RHSShiftAmt);
6347         Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
6348                            DAG.getNode(ISD::OR, DL, VT, LHSMask, RHSBits));
6349       }
6350       if (RHSMask.getNode()) {
6351         SDValue LHSBits = DAG.getNode(ISD::SHL, DL, VT, AllOnes, LHSShiftAmt);
6352         Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
6353                            DAG.getNode(ISD::OR, DL, VT, RHSMask, LHSBits));
6354       }
6355 
6356       Rot = DAG.getNode(ISD::AND, DL, VT, Rot, Mask);
6357     }
6358 
6359     return Rot;
6360   }
6361 
6362   // If there is a mask here, and we have a variable shift, we can't be sure
6363   // that we're masking out the right stuff.
6364   if (LHSMask.getNode() || RHSMask.getNode())
6365     return SDValue();
6366 
6367   // If the shift amount is sign/zext/any-extended just peel it off.
6368   SDValue LExtOp0 = LHSShiftAmt;
6369   SDValue RExtOp0 = RHSShiftAmt;
6370   if ((LHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
6371        LHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
6372        LHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
6373        LHSShiftAmt.getOpcode() == ISD::TRUNCATE) &&
6374       (RHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
6375        RHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
6376        RHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
6377        RHSShiftAmt.getOpcode() == ISD::TRUNCATE)) {
6378     LExtOp0 = LHSShiftAmt.getOperand(0);
6379     RExtOp0 = RHSShiftAmt.getOperand(0);
6380   }
6381 
6382   SDValue TryL = MatchRotatePosNeg(LHSShiftArg, LHSShiftAmt, RHSShiftAmt,
6383                                    LExtOp0, RExtOp0, ISD::ROTL, ISD::ROTR, DL);
6384   if (TryL)
6385     return TryL;
6386 
6387   SDValue TryR = MatchRotatePosNeg(RHSShiftArg, RHSShiftAmt, LHSShiftAmt,
6388                                    RExtOp0, LExtOp0, ISD::ROTR, ISD::ROTL, DL);
6389   if (TryR)
6390     return TryR;
6391 
6392   return SDValue();
6393 }
6394 
6395 namespace {
6396 
6397 /// Represents the known origin of an individual byte in a load combine
6398 /// pattern. The value of the byte is either constant zero or comes from memory.
6399 struct ByteProvider {
6400   // For constant zero providers Load is set to nullptr. For memory providers
6401   // Load represents the node which loads the byte from memory.
6402   // ByteOffset is the offset of the byte in the value produced by the load.
6403   LoadSDNode *Load = nullptr;
6404   unsigned ByteOffset = 0;
6405 
6406   ByteProvider() = default;
6407 
6408   static ByteProvider getMemory(LoadSDNode *Load, unsigned ByteOffset) {
6409     return ByteProvider(Load, ByteOffset);
6410   }
6411 
6412   static ByteProvider getConstantZero() { return ByteProvider(nullptr, 0); }
6413 
6414   bool isConstantZero() const { return !Load; }
6415   bool isMemory() const { return Load; }
6416 
6417   bool operator==(const ByteProvider &Other) const {
6418     return Other.Load == Load && Other.ByteOffset == ByteOffset;
6419   }
6420 
6421 private:
6422   ByteProvider(LoadSDNode *Load, unsigned ByteOffset)
6423       : Load(Load), ByteOffset(ByteOffset) {}
6424 };
6425 
6426 } // end anonymous namespace
6427 
6428 /// Recursively traverses the expression calculating the origin of the requested
6429 /// byte of the given value. Returns None if the provider can't be calculated.
6430 ///
6431 /// For all values except the root of the expression, verifies that the value
6432 /// has exactly one use, and returns None if that does not hold. This way, if
6433 /// the origin of the byte is returned, it is guaranteed that the values which
6434 /// contribute to the byte are not used outside of this expression.
6435 ///
6436 /// Because the parts of the expression are not allowed to have more than one
6437 /// use this function iterates over trees, not DAGs. So it never visits the same
6438 /// node more than once.
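///
/// For example, for an i16 Op == (or (zext i8 L0) (shl (zext i8 L1) 8)),
/// where L0 and L1 are one-byte loads, Index 0 resolves to the memory
/// provider (L0, offset 0) and Index 1 to (L1, offset 0).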
6439 static const Optional<ByteProvider>
6440 calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
6441                       bool Root = false) {
6442   // A typical i64-by-i8 pattern requires recursion up to 8 calls deep.
6443   if (Depth == 10)
6444     return None;
6445 
6446   if (!Root && !Op.hasOneUse())
6447     return None;
6448 
6449   assert(Op.getValueType().isScalarInteger() && "can't handle other types");
6450   unsigned BitWidth = Op.getValueSizeInBits();
6451   if (BitWidth % 8 != 0)
6452     return None;
6453   unsigned ByteWidth = BitWidth / 8;
6454   assert(Index < ByteWidth && "invalid index requested");
6455   (void) ByteWidth;
6456 
6457   switch (Op.getOpcode()) {
6458   case ISD::OR: {
6459     auto LHS = calculateByteProvider(Op->getOperand(0), Index, Depth + 1);
6460     if (!LHS)
6461       return None;
6462     auto RHS = calculateByteProvider(Op->getOperand(1), Index, Depth + 1);
6463     if (!RHS)
6464       return None;
6465 
6466     if (LHS->isConstantZero())
6467       return RHS;
6468     if (RHS->isConstantZero())
6469       return LHS;
6470     return None;
6471   }
6472   case ISD::SHL: {
6473     auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
6474     if (!ShiftOp)
6475       return None;
6476 
6477     uint64_t BitShift = ShiftOp->getZExtValue();
6478     if (BitShift % 8 != 0)
6479       return None;
6480     uint64_t ByteShift = BitShift / 8;
6481 
6482     return Index < ByteShift
6483                ? ByteProvider::getConstantZero()
6484                : calculateByteProvider(Op->getOperand(0), Index - ByteShift,
6485                                        Depth + 1);
6486   }
6487   case ISD::ANY_EXTEND:
6488   case ISD::SIGN_EXTEND:
6489   case ISD::ZERO_EXTEND: {
6490     SDValue NarrowOp = Op->getOperand(0);
6491     unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
6492     if (NarrowBitWidth % 8 != 0)
6493       return None;
6494     uint64_t NarrowByteWidth = NarrowBitWidth / 8;
6495 
6496     if (Index >= NarrowByteWidth)
6497       return Op.getOpcode() == ISD::ZERO_EXTEND
6498                  ? Optional<ByteProvider>(ByteProvider::getConstantZero())
6499                  : None;
6500     return calculateByteProvider(NarrowOp, Index, Depth + 1);
6501   }
6502   case ISD::BSWAP:
6503     return calculateByteProvider(Op->getOperand(0), ByteWidth - Index - 1,
6504                                  Depth + 1);
6505   case ISD::LOAD: {
6506     auto L = cast<LoadSDNode>(Op.getNode());
6507     if (!L->isSimple() || L->isIndexed())
6508       return None;
6509 
6510     unsigned NarrowBitWidth = L->getMemoryVT().getSizeInBits();
6511     if (NarrowBitWidth % 8 != 0)
6512       return None;
6513     uint64_t NarrowByteWidth = NarrowBitWidth / 8;
6514 
6515     if (Index >= NarrowByteWidth)
6516       return L->getExtensionType() == ISD::ZEXTLOAD
6517                  ? Optional<ByteProvider>(ByteProvider::getConstantZero())
6518                  : None;
6519     return ByteProvider::getMemory(L, Index);
6520   }
6521   }
6522 
6523   return None;
6524 }
6525 
6526 static unsigned LittleEndianByteAt(unsigned BW, unsigned i) {
6527   return i;
6528 }
6529 
6530 static unsigned BigEndianByteAt(unsigned BW, unsigned i) {
6531   return BW - i - 1;
6532 }
6533 
6534 // Check if the byte offsets we are looking at match either a big or a little
6535 // endian value load. Return true for big endian, false for little endian,
6536 // and None if the match failed.
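// For illustration (hypothetical offsets, relative to FirstOffset): the offsets
// {0, 1, 2, 3} match a little endian load (returns false), {3, 2, 1, 0} match a
// big endian load (returns true), and {0, 2, 1, 3} match neither (returns None).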
6537 static Optional<bool> isBigEndian(const ArrayRef<int64_t> ByteOffsets,
6538                                   int64_t FirstOffset) {
6539   // Endianness can only be decided when there are at least 2 bytes.
6540   unsigned Width = ByteOffsets.size();
6541   if (Width < 2)
6542     return None;
6543 
6544   bool BigEndian = true, LittleEndian = true;
6545   for (unsigned i = 0; i < Width; i++) {
6546     int64_t CurrentByteOffset = ByteOffsets[i] - FirstOffset;
6547     LittleEndian &= CurrentByteOffset == LittleEndianByteAt(Width, i);
6548     BigEndian &= CurrentByteOffset == BigEndianByteAt(Width, i);
6549     if (!BigEndian && !LittleEndian)
6550       return None;
6551   }
6552 
6553   assert((BigEndian != LittleEndian) && "It should be either big endian or "
6554                                         "little endian");
6555   return BigEndian;
6556 }
6557 
6558 static SDValue stripTruncAndExt(SDValue Value) {
6559   switch (Value.getOpcode()) {
6560   case ISD::TRUNCATE:
6561   case ISD::ZERO_EXTEND:
6562   case ISD::SIGN_EXTEND:
6563   case ISD::ANY_EXTEND:
6564     return stripTruncAndExt(Value.getOperand(0));
6565   }
6566   return Value;
6567 }
6568 
6569 /// Match a pattern where a wide type scalar value is stored by several narrow
6570 /// stores. Fold it into a single store or a BSWAP and a store if the target
6571 /// supports it.
6572 ///
6573 /// Assuming little endian target:
6574 ///  i8 *p = ...
6575 ///  i32 val = ...
6576 ///  p[0] = (val >> 0) & 0xFF;
6577 ///  p[1] = (val >> 8) & 0xFF;
6578 ///  p[2] = (val >> 16) & 0xFF;
6579 ///  p[3] = (val >> 24) & 0xFF;
6580 /// =>
6581 ///  *((i32)p) = val;
6582 ///
6583 ///  i8 *p = ...
6584 ///  i32 val = ...
6585 ///  p[0] = (val >> 24) & 0xFF;
6586 ///  p[1] = (val >> 16) & 0xFF;
6587 ///  p[2] = (val >> 8) & 0xFF;
6588 ///  p[3] = (val >> 0) & 0xFF;
6589 /// =>
6590 ///  *((i32)p) = BSWAP(val);
6591 SDValue DAGCombiner::MatchStoreCombine(StoreSDNode *N) {
6592   // Collect all the stores in the chain.
6593   SDValue Chain;
6594   SmallVector<StoreSDNode *, 8> Stores;
6595   for (StoreSDNode *Store = N; Store; Store = dyn_cast<StoreSDNode>(Chain)) {
6596     // TODO: Allow unordered atomics when wider type is legal (see D66309)
6597     if (Store->getMemoryVT() != MVT::i8 ||
6598         !Store->isSimple() || Store->isIndexed())
6599       return SDValue();
6600     Stores.push_back(Store);
6601     Chain = Store->getChain();
6602   }
6603   // Handle simple types only.
6604   unsigned Width = Stores.size();
6605   EVT VT = EVT::getIntegerVT(
6606     *DAG.getContext(), Width * N->getMemoryVT().getSizeInBits());
6607   if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
6608     return SDValue();
6609 
6610   if (LegalOperations && !TLI.isOperationLegal(ISD::STORE, VT))
6611     return SDValue();
6612 
6613   // Check if all the bytes of the combined value we are looking at are stored
6614   // to the same base address. Collect byte offsets from the Base address into
6615   // ByteOffsets.
6616   SDValue CombinedValue;
6617   SmallVector<int64_t, 8> ByteOffsets(Width, INT64_MAX);
6618   int64_t FirstOffset = INT64_MAX;
6619   StoreSDNode *FirstStore = nullptr;
6620   Optional<BaseIndexOffset> Base;
6621   for (auto Store : Stores) {
6622     // Each store stores a different byte of the CombinedValue. A truncate is
6623     // required to get that byte's value.
6624     SDValue Trunc = Store->getValue();
6625     if (Trunc.getOpcode() != ISD::TRUNCATE)
6626       return SDValue();
6627     // A shift operation is required to get the right byte offset, except for
6628     // the first byte.
6629     int64_t Offset = 0;
6630     SDValue Value = Trunc.getOperand(0);
6631     if (Value.getOpcode() == ISD::SRL ||
6632         Value.getOpcode() == ISD::SRA) {
6633       ConstantSDNode *ShiftOffset =
6634         dyn_cast<ConstantSDNode>(Value.getOperand(1));
6635       // Trying to match the following pattern. The shift offset must be
6636       // a constant and a multiple of 8. It is the byte offset in "y".
6637       //
6638       // x = srl y, offset
6639       // i8 z = trunc x
6640       // store z, ...
6641       if (!ShiftOffset || (ShiftOffset->getSExtValue() % 8))
6642         return SDValue();
6643 
6644       Offset = ShiftOffset->getSExtValue() / 8;
6645       Value = Value.getOperand(0);
6646     }
6647 
6648     // Stores must share the same combined value with different offsets.
6649     if (!CombinedValue)
6650       CombinedValue = Value;
6651     else if (stripTruncAndExt(CombinedValue) != stripTruncAndExt(Value))
6652       return SDValue();
6653 
6654     // The trunc and all the extend operations should be stripped to get the
6655     // real value being stored.
6656     else if (CombinedValue.getValueType() != VT) {
6657       if (Value.getValueType() == VT ||
6658           Value.getValueSizeInBits() > CombinedValue.getValueSizeInBits())
6659         CombinedValue = Value;
6660       // Give up if the combined value type is smaller than the store size.
6661       if (CombinedValue.getValueSizeInBits() < VT.getSizeInBits())
6662         return SDValue();
6663     }
6664 
6665     // Stores must share the same base address
6666     BaseIndexOffset Ptr = BaseIndexOffset::match(Store, DAG);
6667     int64_t ByteOffsetFromBase = 0;
6668     if (!Base)
6669       Base = Ptr;
6670     else if (!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
6671       return SDValue();
6672 
6673     // Remember the first byte store
6674     if (ByteOffsetFromBase < FirstOffset) {
6675       FirstStore = Store;
6676       FirstOffset = ByteOffsetFromBase;
6677     }
6678     // Map the offset in the store and the offset in the combined value, and
6679     // early return if it has been set before.
6680     if (Offset < 0 || Offset >= Width || ByteOffsets[Offset] != INT64_MAX)
6681       return SDValue();
6682     ByteOffsets[Offset] = ByteOffsetFromBase;
6683   }
6684 
6685   assert(FirstOffset != INT64_MAX && "First byte offset must be set");
6686   assert(FirstStore && "First store must be set");
6687 
6688   // Check if the bytes of the combined value we are looking at match with
6689   // either big or little endian value store.
6690   Optional<bool> IsBigEndian = isBigEndian(ByteOffsets, FirstOffset);
6691   if (!IsBigEndian.hasValue())
6692     return SDValue();
6693 
6694   // The node we are looking at matches with the pattern, check if we can
6695   // replace it with a single bswap if needed and store.
6696 
6697   // If the store needs a byte swap, check if the target supports it.
6698   bool NeedsBswap = DAG.getDataLayout().isBigEndian() != *IsBigEndian;
6699 
6700   // Before legalize we can introduce illegal bswaps which will later be
6701   // converted to an explicit bswap sequence. This way we end up with a single
6702   // store and byte shuffling instead of several stores and byte shuffling.
6703   if (NeedsBswap && LegalOperations && !TLI.isOperationLegal(ISD::BSWAP, VT))
6704     return SDValue();
6705 
6706   // Check that a store of the wide type is both allowed and fast on the target
6707   bool Fast = false;
6708   bool Allowed =
6709       TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
6710                              *FirstStore->getMemOperand(), &Fast);
6711   if (!Allowed || !Fast)
6712     return SDValue();
6713 
6714   if (VT != CombinedValue.getValueType()) {
6715     assert(CombinedValue.getValueType().getSizeInBits() > VT.getSizeInBits() &&
6716            "Get unexpected store value to combine");
6717     CombinedValue = DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT,
6718                              CombinedValue);
6719   }
6720 
6721   if (NeedsBswap)
6722     CombinedValue = DAG.getNode(ISD::BSWAP, SDLoc(N), VT, CombinedValue);
6723 
6724   SDValue NewStore =
6725     DAG.getStore(Chain, SDLoc(N), CombinedValue, FirstStore->getBasePtr(),
6726                  FirstStore->getPointerInfo(), FirstStore->getAlignment());
6727 
6728   // Rely on other DAG combine rules to remove the other individual stores.
6729   DAG.ReplaceAllUsesWith(N, NewStore.getNode());
6730   return NewStore;
6731 }
6732 
6733 /// Match a pattern where a wide type scalar value is loaded by several narrow
6734 /// loads and combined by shifts and ors. Fold it into a single load or a load
6735 /// and a BSWAP if the target supports it.
6736 ///
6737 /// Assuming little endian target:
6738 ///  i8 *a = ...
6739 ///  i32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
6740 /// =>
6741 ///  i32 val = *((i32)a)
6742 ///
6743 ///  i8 *a = ...
6744 ///  i32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
6745 /// =>
6746 ///  i32 val = BSWAP(*((i32)a))
6747 ///
6748 /// TODO: This rule matches complex patterns with OR node roots and doesn't
6749 /// interact well with the worklist mechanism. When a part of the pattern is
6750 /// updated (e.g. one of the loads) its direct users are put into the worklist,
6751 /// but the root node of the pattern which triggers the load combine is not
6752 /// necessarily a direct user of the changed node. For example, once the
6753 /// address of the t28 load is reassociated, load combine won't be triggered:
6754 ///             t25: i32 = add t4, Constant:i32<2>
6755 ///           t26: i64 = sign_extend t25
6756 ///        t27: i64 = add t2, t26
6757 ///       t28: i8,ch = load<LD1[%tmp9]> t0, t27, undef:i64
6758 ///     t29: i32 = zero_extend t28
6759 ///   t32: i32 = shl t29, Constant:i8<8>
6760 /// t33: i32 = or t23, t32
6761 /// As a possible fix visitLoad can check if the load can be a part of a load
6762 /// combine pattern and add corresponding OR roots to the worklist.
6763 SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
6764   assert(N->getOpcode() == ISD::OR &&
6765          "Can only match load combining against OR nodes");
6766 
6767   // Handles simple types only
6768   EVT VT = N->getValueType(0);
6769   if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
6770     return SDValue();
6771   unsigned ByteWidth = VT.getSizeInBits() / 8;
6772 
6773   bool IsBigEndianTarget = DAG.getDataLayout().isBigEndian();
6774   auto MemoryByteOffset = [&] (ByteProvider P) {
6775     assert(P.isMemory() && "Must be a memory byte provider");
6776     unsigned LoadBitWidth = P.Load->getMemoryVT().getSizeInBits();
6777     assert(LoadBitWidth % 8 == 0 &&
6778            "can only analyze providers for individual bytes not bit");
6779     unsigned LoadByteWidth = LoadBitWidth / 8;
6780     return IsBigEndianTarget
6781             ? BigEndianByteAt(LoadByteWidth, P.ByteOffset)
6782             : LittleEndianByteAt(LoadByteWidth, P.ByteOffset);
6783   };
6784 
6785   Optional<BaseIndexOffset> Base;
6786   SDValue Chain;
6787 
6788   SmallPtrSet<LoadSDNode *, 8> Loads;
6789   Optional<ByteProvider> FirstByteProvider;
6790   int64_t FirstOffset = INT64_MAX;
6791 
6792   // Check if all the bytes of the OR we are looking at are loaded from the same
6793   // base address. Collect byte offsets from the Base address in ByteOffsets.
6794   SmallVector<int64_t, 8> ByteOffsets(ByteWidth);
6795   unsigned ZeroExtendedBytes = 0;
6796   for (int i = ByteWidth - 1; i >= 0; --i) {
6797     auto P = calculateByteProvider(SDValue(N, 0), i, 0, /*Root=*/true);
6798     if (!P)
6799       return SDValue();
6800 
6801     if (P->isConstantZero()) {
6802       // It's OK for the N most significant bytes to be 0, we can just
6803       // zero-extend the load.
6804       if (++ZeroExtendedBytes != (ByteWidth - static_cast<unsigned>(i)))
6805         return SDValue();
6806       continue;
6807     }
6808     assert(P->isMemory() && "provenance should either be memory or zero");
6809 
6810     LoadSDNode *L = P->Load;
6811     assert(L->hasNUsesOfValue(1, 0) && L->isSimple() &&
6812            !L->isIndexed() &&
6813            "Must be enforced by calculateByteProvider");
6814     assert(L->getOffset().isUndef() && "Unindexed load must have undef offset");
6815 
6816     // All loads must share the same chain
6817     SDValue LChain = L->getChain();
6818     if (!Chain)
6819       Chain = LChain;
6820     else if (Chain != LChain)
6821       return SDValue();
6822 
6823     // Loads must share the same base address
6824     BaseIndexOffset Ptr = BaseIndexOffset::match(L, DAG);
6825     int64_t ByteOffsetFromBase = 0;
6826     if (!Base)
6827       Base = Ptr;
6828     else if (!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
6829       return SDValue();
6830 
6831     // Calculate the offset of the current byte from the base address
6832     ByteOffsetFromBase += MemoryByteOffset(*P);
6833     ByteOffsets[i] = ByteOffsetFromBase;
6834 
6835     // Remember the first byte load
6836     if (ByteOffsetFromBase < FirstOffset) {
6837       FirstByteProvider = P;
6838       FirstOffset = ByteOffsetFromBase;
6839     }
6840 
6841     Loads.insert(L);
6842   }
6843   assert(!Loads.empty() && "All the bytes of the value must be loaded from "
6844          "memory, so there must be at least one load which produces the value");
6845   assert(Base && "Base address of the accessed memory location must be set");
6846   assert(FirstOffset != INT64_MAX && "First byte offset must be set");
6847 
6848   bool NeedsZext = ZeroExtendedBytes > 0;
6849 
6850   EVT MemVT =
6851       EVT::getIntegerVT(*DAG.getContext(), (ByteWidth - ZeroExtendedBytes) * 8);
6852 
6853   if (!MemVT.isSimple())
6854     return SDValue();
6855 
6856   // Before legalize we can introduce overly wide illegal loads which will
6857   // later be split into legal-sized loads. This enables us to combine i64-by-i8
6858   // load patterns into a couple of i32 loads on 32-bit targets.
6859   if (LegalOperations &&
6860       !TLI.isOperationLegal(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD,
6861                             MemVT))
6862     return SDValue();
6863 
6864   // Check if the bytes of the OR we are looking at match with either big or
6865   // little endian value load
6866   Optional<bool> IsBigEndian = isBigEndian(
6867       makeArrayRef(ByteOffsets).drop_back(ZeroExtendedBytes), FirstOffset);
6868   if (!IsBigEndian.hasValue())
6869     return SDValue();
6870 
6871   assert(FirstByteProvider && "must be set");
6872 
6873   // Ensure that the first byte is loaded from offset zero of the first load,
6874   // so the combined value can be loaded from the first load's address.
6875   if (MemoryByteOffset(*FirstByteProvider) != 0)
6876     return SDValue();
6877   LoadSDNode *FirstLoad = FirstByteProvider->Load;
6878 
6879   // The node we are looking at matches with the pattern, check if we can
6880   // replace it with a single (possibly zero-extended) load and bswap + shift if
6881   // needed.
6882 
6883   // If the load needs a byte swap, check if the target supports it.
6884   bool NeedsBswap = IsBigEndianTarget != *IsBigEndian;
6885 
6886   // Before legalize we can introduce illegal bswaps which will later be
6887   // converted to an explicit bswap sequence. This way we end up with a single
6888   // load and byte shuffling instead of several loads and byte shuffling.
6889   // We do not introduce illegal bswaps when zero-extending as this tends to
6890   // introduce too many arithmetic instructions.
6891   if (NeedsBswap && (LegalOperations || NeedsZext) &&
6892       !TLI.isOperationLegal(ISD::BSWAP, VT))
6893     return SDValue();
6894 
6895   // If we need to bswap and zero extend, we have to insert a shift. Check that
6896   // it is legal.
6897   if (NeedsBswap && NeedsZext && LegalOperations &&
6898       !TLI.isOperationLegal(ISD::SHL, VT))
6899     return SDValue();
6900 
6901   // Check that a load of the wide type is both allowed and fast on the target
6902   bool Fast = false;
6903   bool Allowed =
6904       TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
6905                              *FirstLoad->getMemOperand(), &Fast);
6906   if (!Allowed || !Fast)
6907     return SDValue();
6908 
6909   SDValue NewLoad = DAG.getExtLoad(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD,
6910                                    SDLoc(N), VT, Chain, FirstLoad->getBasePtr(),
6911                                    FirstLoad->getPointerInfo(), MemVT,
6912                                    FirstLoad->getAlignment());
6913 
6914   // Transfer chain users from old loads to the new load.
6915   for (LoadSDNode *L : Loads)
6916     DAG.ReplaceAllUsesOfValueWith(SDValue(L, 1), SDValue(NewLoad.getNode(), 1));
6917 
6918   if (!NeedsBswap)
6919     return NewLoad;
6920 
6921   SDValue ShiftedLoad =
6922       NeedsZext
6923           ? DAG.getNode(ISD::SHL, SDLoc(N), VT, NewLoad,
6924                         DAG.getShiftAmountConstant(ZeroExtendedBytes * 8, VT,
6925                                                    SDLoc(N), LegalOperations))
6926           : NewLoad;
6927   return DAG.getNode(ISD::BSWAP, SDLoc(N), VT, ShiftedLoad);
6928 }
6929 
6930 // If the target has andn, bsl, or a similar bit-select instruction,
6931 // we want to unfold masked merge, with canonical pattern of:
6932 //   |        A  |  |B|
6933 //   ((x ^ y) & m) ^ y
6934 //    |  D  |
6935 // Into:
6936 //   (x & m) | (y & ~m)
6937 // If y is a constant, and the 'andn' does not work with immediates,
6938 // we unfold into a different pattern:
6939 //   ~(~x & m) & (m | y)
6940 // NOTE: we don't unfold the pattern if 'xor' is actually a 'not', because at
6941 //       the very least that breaks andnpd / andnps patterns, and because those
6942 //       patterns are simplified in IR and shouldn't be created in the DAG.
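//
// A small worked example with illustrative 4-bit values x = 0b1100, y = 0b1010,
// m = 0b0110: the canonical form gives
//   ((x ^ y) & m) ^ y = (0b0110 & 0b0110) ^ 0b1010 = 0b0110 ^ 0b1010 = 0b1100
// and the unfolded form agrees:
//   (x & m) | (y & ~m) = 0b0100 | (0b1010 & 0b1001) = 0b0100 | 0b1000 = 0b1100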
6943 SDValue DAGCombiner::unfoldMaskedMerge(SDNode *N) {
6944   assert(N->getOpcode() == ISD::XOR);
6945 
6946   // Don't touch 'not' (i.e. where y = -1).
6947   if (isAllOnesOrAllOnesSplat(N->getOperand(1)))
6948     return SDValue();
6949 
6950   EVT VT = N->getValueType(0);
6951 
6952   // There are 3 commutable operators in the pattern,
6953   // so we have to deal with 8 possible variants of the basic pattern.
6954   SDValue X, Y, M;
6955   auto matchAndXor = [&X, &Y, &M](SDValue And, unsigned XorIdx, SDValue Other) {
6956     if (And.getOpcode() != ISD::AND || !And.hasOneUse())
6957       return false;
6958     SDValue Xor = And.getOperand(XorIdx);
6959     if (Xor.getOpcode() != ISD::XOR || !Xor.hasOneUse())
6960       return false;
6961     SDValue Xor0 = Xor.getOperand(0);
6962     SDValue Xor1 = Xor.getOperand(1);
6963     // Don't touch 'not' (i.e. where y = -1).
6964     if (isAllOnesOrAllOnesSplat(Xor1))
6965       return false;
6966     if (Other == Xor0)
6967       std::swap(Xor0, Xor1);
6968     if (Other != Xor1)
6969       return false;
6970     X = Xor0;
6971     Y = Xor1;
6972     M = And.getOperand(XorIdx ? 0 : 1);
6973     return true;
6974   };
6975 
6976   SDValue N0 = N->getOperand(0);
6977   SDValue N1 = N->getOperand(1);
6978   if (!matchAndXor(N0, 0, N1) && !matchAndXor(N0, 1, N1) &&
6979       !matchAndXor(N1, 0, N0) && !matchAndXor(N1, 1, N0))
6980     return SDValue();
6981 
6982   // Don't do anything if the mask is constant. This should not be reachable.
6983   // InstCombine should have already unfolded this pattern, and DAGCombiner
6984   // probably shouldn't produce it, too.
6985   if (isa<ConstantSDNode>(M.getNode()))
6986     return SDValue();
6987 
6988   // We can transform if the target has AndNot
6989   if (!TLI.hasAndNot(M))
6990     return SDValue();
6991 
6992   SDLoc DL(N);
6993 
6994   // If Y is a constant, check that 'andn' works with immediates.
6995   if (!TLI.hasAndNot(Y)) {
6996     assert(TLI.hasAndNot(X) && "Only mask is a variable? Unreachable.");
6997     // If not, we need to do a bit more work to make sure andn is still used.
6998     SDValue NotX = DAG.getNOT(DL, X, VT);
6999     SDValue LHS = DAG.getNode(ISD::AND, DL, VT, NotX, M);
7000     SDValue NotLHS = DAG.getNOT(DL, LHS, VT);
7001     SDValue RHS = DAG.getNode(ISD::OR, DL, VT, M, Y);
7002     return DAG.getNode(ISD::AND, DL, VT, NotLHS, RHS);
7003   }
7004 
7005   SDValue LHS = DAG.getNode(ISD::AND, DL, VT, X, M);
7006   SDValue NotM = DAG.getNOT(DL, M, VT);
7007   SDValue RHS = DAG.getNode(ISD::AND, DL, VT, Y, NotM);
7008 
7009   return DAG.getNode(ISD::OR, DL, VT, LHS, RHS);
7010 }
7011 
7012 SDValue DAGCombiner::visitXOR(SDNode *N) {
7013   SDValue N0 = N->getOperand(0);
7014   SDValue N1 = N->getOperand(1);
7015   EVT VT = N0.getValueType();
7016 
7017   // fold vector ops
7018   if (VT.isVector()) {
7019     if (SDValue FoldedVOp = SimplifyVBinOp(N))
7020       return FoldedVOp;
7021 
7022     // fold (xor x, 0) -> x, vector edition
7023     if (ISD::isBuildVectorAllZeros(N0.getNode()))
7024       return N1;
7025     if (ISD::isBuildVectorAllZeros(N1.getNode()))
7026       return N0;
7027   }
7028 
7029   // fold (xor undef, undef) -> 0. This is a common idiom (misuse).
7030   SDLoc DL(N);
7031   if (N0.isUndef() && N1.isUndef())
7032     return DAG.getConstant(0, DL, VT);
7033   // fold (xor x, undef) -> undef
7034   if (N0.isUndef())
7035     return N0;
7036   if (N1.isUndef())
7037     return N1;
7038   // fold (xor c1, c2) -> c1^c2
7039   ConstantSDNode *N0C = getAsNonOpaqueConstant(N0);
7040   ConstantSDNode *N1C = getAsNonOpaqueConstant(N1);
7041   if (N0C && N1C)
7042     return DAG.FoldConstantArithmetic(ISD::XOR, DL, VT, N0C, N1C);
7043   // canonicalize constant to RHS
7044   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
7045      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
7046     return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
7047   // fold (xor x, 0) -> x
7048   if (isNullConstant(N1))
7049     return N0;
7050 
7051   if (SDValue NewSel = foldBinOpIntoSelect(N))
7052     return NewSel;
7053 
7054   // reassociate xor
7055   if (SDValue RXOR = reassociateOps(ISD::XOR, DL, N0, N1, N->getFlags()))
7056     return RXOR;
7057 
7058   // fold !(x cc y) -> (x !cc y)
7059   unsigned N0Opcode = N0.getOpcode();
7060   SDValue LHS, RHS, CC;
7061   if (TLI.isConstTrueVal(N1.getNode()) && isSetCCEquivalent(N0, LHS, RHS, CC)) {
7062     ISD::CondCode NotCC = ISD::getSetCCInverse(cast<CondCodeSDNode>(CC)->get(),
7063                                                LHS.getValueType());
7064     if (!LegalOperations ||
7065         TLI.isCondCodeLegal(NotCC, LHS.getSimpleValueType())) {
7066       switch (N0Opcode) {
7067       default:
7068         llvm_unreachable("Unhandled SetCC Equivalent!");
7069       case ISD::SETCC:
7070         return DAG.getSetCC(SDLoc(N0), VT, LHS, RHS, NotCC);
7071       case ISD::SELECT_CC:
7072         return DAG.getSelectCC(SDLoc(N0), LHS, RHS, N0.getOperand(2),
7073                                N0.getOperand(3), NotCC);
7074       }
7075     }
7076   }
7077 
7078   // fold (not (zext (setcc x, y))) -> (zext (not (setcc x, y)))
7079   if (isOneConstant(N1) && N0Opcode == ISD::ZERO_EXTEND && N0.hasOneUse() &&
7080       isSetCCEquivalent(N0.getOperand(0), LHS, RHS, CC)){
7081     SDValue V = N0.getOperand(0);
7082     SDLoc DL0(N0);
7083     V = DAG.getNode(ISD::XOR, DL0, V.getValueType(), V,
7084                     DAG.getConstant(1, DL0, V.getValueType()));
7085     AddToWorklist(V.getNode());
7086     return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, V);
7087   }
7088 
7089   // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are setcc
7090   if (isOneConstant(N1) && VT == MVT::i1 && N0.hasOneUse() &&
7091       (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
7092     SDValue N00 = N0.getOperand(0), N01 = N0.getOperand(1);
7093     if (isOneUseSetCC(N01) || isOneUseSetCC(N00)) {
7094       unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
7095       N00 = DAG.getNode(ISD::XOR, SDLoc(N00), VT, N00, N1); // N00 = ~N00
7096       N01 = DAG.getNode(ISD::XOR, SDLoc(N01), VT, N01, N1); // N01 = ~N01
7097       AddToWorklist(N00.getNode()); AddToWorklist(N01.getNode());
7098       return DAG.getNode(NewOpcode, DL, VT, N00, N01);
7099     }
7100   }
7101   // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are constants
7102   if (isAllOnesConstant(N1) && N0.hasOneUse() &&
7103       (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
7104     SDValue N00 = N0.getOperand(0), N01 = N0.getOperand(1);
7105     if (isa<ConstantSDNode>(N01) || isa<ConstantSDNode>(N00)) {
7106       unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
7107       N00 = DAG.getNode(ISD::XOR, SDLoc(N00), VT, N00, N1); // N00 = ~N00
7108       N01 = DAG.getNode(ISD::XOR, SDLoc(N01), VT, N01, N1); // N01 = ~N01
7109       AddToWorklist(N00.getNode()); AddToWorklist(N01.getNode());
7110       return DAG.getNode(NewOpcode, DL, VT, N00, N01);
7111     }
7112   }
7113 
7114   // fold (not (neg x)) -> (add X, -1)
7115   // FIXME: This can be generalized to (not (sub Y, X)) -> (add X, ~Y) if
7116   // Y is a constant or the subtract has a single use.
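  // (This holds by the two's complement identity ~A = -A - 1, so
  // ~(0 - X) = -(0 - X) - 1 = X - 1 = add X, -1.)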
7117   if (isAllOnesConstant(N1) && N0.getOpcode() == ISD::SUB &&
7118       isNullConstant(N0.getOperand(0))) {
7119     return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1),
7120                        DAG.getAllOnesConstant(DL, VT));
7121   }
7122 
7123   // fold (not (add X, -1)) -> (neg X)
7124   if (isAllOnesConstant(N1) && N0.getOpcode() == ISD::ADD &&
7125       isAllOnesOrAllOnesSplat(N0.getOperand(1))) {
7126     return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
7127                        N0.getOperand(0));
7128   }
7129 
7130   // fold (xor (and x, y), y) -> (and (not x), y)
7131   if (N0Opcode == ISD::AND && N0.hasOneUse() && N0->getOperand(1) == N1) {
7132     SDValue X = N0.getOperand(0);
7133     SDValue NotX = DAG.getNOT(SDLoc(X), X, VT);
7134     AddToWorklist(NotX.getNode());
7135     return DAG.getNode(ISD::AND, DL, VT, NotX, N1);
7136   }
7137 
7138   if ((N0Opcode == ISD::SRL || N0Opcode == ISD::SHL) && N0.hasOneUse()) {
7139     ConstantSDNode *XorC = isConstOrConstSplat(N1);
7140     ConstantSDNode *ShiftC = isConstOrConstSplat(N0.getOperand(1));
7141     unsigned BitWidth = VT.getScalarSizeInBits();
7142     if (XorC && ShiftC) {
7143       // Don't crash on an oversized shift. We cannot guarantee that a bogus
7144       // shift has been simplified to undef.
7145       uint64_t ShiftAmt = ShiftC->getLimitedValue();
7146       if (ShiftAmt < BitWidth) {
7147         APInt Ones = APInt::getAllOnesValue(BitWidth);
7148         Ones = N0Opcode == ISD::SHL ? Ones.shl(ShiftAmt) : Ones.lshr(ShiftAmt);
7149         if (XorC->getAPIntValue() == Ones) {
7150           // If the xor constant is a shifted -1, do a 'not' before the shift:
7151           // xor (X << ShiftC), XorC --> (not X) << ShiftC
7152           // xor (X >> ShiftC), XorC --> (not X) >> ShiftC
7153           SDValue Not = DAG.getNOT(DL, N0.getOperand(0), VT);
7154           return DAG.getNode(N0Opcode, DL, VT, Not, N0.getOperand(1));
7155         }
7156       }
7157     }
7158   }
7159 
7160   // fold Y = sra (X, size(X)-1); xor (add (X, Y), Y) -> (abs X)
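  // (Here Y is 0 when X >= 0 and -1 when X < 0. For Y == 0 the expression is
  // (X + 0) ^ 0 = X; for Y == -1 it is (X - 1) ^ -1 = ~(X - 1) = -X. Either
  // way the pattern computes abs(X).)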
7161   if (TLI.isOperationLegalOrCustom(ISD::ABS, VT)) {
7162     SDValue A = N0Opcode == ISD::ADD ? N0 : N1;
7163     SDValue S = N0Opcode == ISD::SRA ? N0 : N1;
7164     if (A.getOpcode() == ISD::ADD && S.getOpcode() == ISD::SRA) {
7165       SDValue A0 = A.getOperand(0), A1 = A.getOperand(1);
7166       SDValue S0 = S.getOperand(0);
7167       if ((A0 == S && A1 == S0) || (A1 == S && A0 == S0)) {
7168         unsigned OpSizeInBits = VT.getScalarSizeInBits();
7169         if (ConstantSDNode *C = isConstOrConstSplat(S.getOperand(1)))
7170           if (C->getAPIntValue() == (OpSizeInBits - 1))
7171             return DAG.getNode(ISD::ABS, DL, VT, S0);
7172       }
7173     }
7174   }
7175 
7176   // fold (xor x, x) -> 0
7177   if (N0 == N1)
7178     return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
7179 
7180   // fold (xor (shl 1, x), -1) -> (rotl ~1, x)
7181   // Here is a concrete example of this equivalence:
7182   // i16   x ==  14
7183   // i16 shl ==   1 << 14  == 16384 == 0b0100000000000000
7184   // i16 xor == ~(1 << 14) == 49151 == 0b1011111111111111
7185   //
7186   // =>
7187   //
7188   // i16     ~1      == 0b1111111111111110
7189   // i16 rol(~1, 14) == 0b1011111111111111
7190   //
7191   // Some additional tips to help conceptualize this transform:
7192   // - Try to see the operation as placing a single zero in a value of all ones.
7193   // - There exists no value for x which would allow the result to contain zero.
7194   // - Values of x larger than the bitwidth are undefined and do not require a
7195   //   consistent result.
7196   // - Pushing the zero left requires shifting one-bits in from the right.
7197   // A rotate left of ~1 is a nice way of achieving the desired result.
7198   if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT) && N0Opcode == ISD::SHL &&
7199       isAllOnesConstant(N1) && isOneConstant(N0.getOperand(0))) {
7200     return DAG.getNode(ISD::ROTL, DL, VT, DAG.getConstant(~1, DL, VT),
7201                        N0.getOperand(1));
7202   }
7203 
7204   // Simplify: xor (op x...), (op y...)  -> (op (xor x, y))
7205   if (N0Opcode == N1.getOpcode())
7206     if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
7207       return V;
7208 
7209   // Unfold  ((x ^ y) & m) ^ y  into  (x & m) | (y & ~m)  if profitable
7210   if (SDValue MM = unfoldMaskedMerge(N))
7211     return MM;
7212 
7213   // Simplify the expression using non-local knowledge.
7214   if (SimplifyDemandedBits(SDValue(N, 0)))
7215     return SDValue(N, 0);
7216 
7217   if (SDValue Combined = combineCarryDiamond(*this, DAG, TLI, N0, N1, N))
7218     return Combined;
7219 
7220   return SDValue();
7221 }
7222 
7223 /// If we have a shift-by-constant of a bitwise logic op that itself has a
7224 /// shift-by-constant operand with identical opcode, we may be able to convert
7225 /// that into 2 independent shifts followed by the logic op. This is a
7226 /// throughput improvement.
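///
/// For example (an illustrative instance of the generic fold below):
///   shl (xor (shl X, 2), Y), 3 --> xor (shl X, 5), (shl Y, 3)
/// The two new shifts are independent and can execute in parallel.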
7227 static SDValue combineShiftOfShiftedLogic(SDNode *Shift, SelectionDAG &DAG) {
7228   // Match a one-use bitwise logic op.
7229   SDValue LogicOp = Shift->getOperand(0);
7230   if (!LogicOp.hasOneUse())
7231     return SDValue();
7232 
7233   unsigned LogicOpcode = LogicOp.getOpcode();
7234   if (LogicOpcode != ISD::AND && LogicOpcode != ISD::OR &&
7235       LogicOpcode != ISD::XOR)
7236     return SDValue();
7237 
7238   // Find a matching one-use shift by constant.
7239   unsigned ShiftOpcode = Shift->getOpcode();
7240   SDValue C1 = Shift->getOperand(1);
7241   ConstantSDNode *C1Node = isConstOrConstSplat(C1);
7242   assert(C1Node && "Expected a shift with constant operand");
7243   const APInt &C1Val = C1Node->getAPIntValue();
7244   auto matchFirstShift = [&](SDValue V, SDValue &ShiftOp,
7245                              const APInt *&ShiftAmtVal) {
7246     if (V.getOpcode() != ShiftOpcode || !V.hasOneUse())
7247       return false;
7248 
7249     ConstantSDNode *ShiftCNode = isConstOrConstSplat(V.getOperand(1));
7250     if (!ShiftCNode)
7251       return false;
7252 
7253     // Capture the shifted operand and shift amount value.
7254     ShiftOp = V.getOperand(0);
7255     ShiftAmtVal = &ShiftCNode->getAPIntValue();
7256 
7257     // Shift amount types do not have to match their operand type, so check that
7258     // the constants are the same width.
7259     if (ShiftAmtVal->getBitWidth() != C1Val.getBitWidth())
7260       return false;
7261 
7262     // The fold is not valid if the sum of the shift values exceeds the bitwidth.
7263     if ((*ShiftAmtVal + C1Val).uge(V.getScalarValueSizeInBits()))
7264       return false;
7265 
7266     return true;
7267   };
7268 
7269   // Logic ops are commutative, so check each operand for a match.
7270   SDValue X, Y;
7271   const APInt *C0Val;
7272   if (matchFirstShift(LogicOp.getOperand(0), X, C0Val))
7273     Y = LogicOp.getOperand(1);
7274   else if (matchFirstShift(LogicOp.getOperand(1), X, C0Val))
7275     Y = LogicOp.getOperand(0);
7276   else
7277     return SDValue();
7278 
7279   // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1)
7280   SDLoc DL(Shift);
7281   EVT VT = Shift->getValueType(0);
7282   EVT ShiftAmtVT = Shift->getOperand(1).getValueType();
7283   SDValue ShiftSumC = DAG.getConstant(*C0Val + C1Val, DL, ShiftAmtVT);
7284   SDValue NewShift1 = DAG.getNode(ShiftOpcode, DL, VT, X, ShiftSumC);
7285   SDValue NewShift2 = DAG.getNode(ShiftOpcode, DL, VT, Y, C1);
7286   return DAG.getNode(LogicOpcode, DL, VT, NewShift1, NewShift2);
7287 }
7288 
7289 /// Handle transforms common to the three shifts, when the shift amount is a
7290 /// constant.
7291 /// We are looking for: (shift being one of shl/sra/srl)
7292 ///   shift (binop X, C0), C1
7293 /// And want to transform into:
7294 ///   binop (shift X, C1), (shift C0, C1)
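///
/// For example (illustrative, with the shifted constant folded):
///   shl (add X, 7), 2 --> add (shl X, 2), 28
/// since (shl 7, 2) folds to the constant 28.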
7295 SDValue DAGCombiner::visitShiftByConstant(SDNode *N) {
7296   assert(isConstOrConstSplat(N->getOperand(1)) && "Expected constant operand");
7297 
7298   // Do not turn a 'not' into a regular xor.
7299   if (isBitwiseNot(N->getOperand(0)))
7300     return SDValue();
7301 
7302   // The inner binop must be one-use, since we want to replace it.
7303   SDValue LHS = N->getOperand(0);
7304   if (!LHS.hasOneUse() || !TLI.isDesirableToCommuteWithShift(N, Level))
7305     return SDValue();
7306 
7307   // TODO: This is limited to early combining because it may reveal regressions
7308   //       otherwise. But since we just checked a target hook to see if this is
7309   //       desirable, that should have filtered out cases where this interferes
7310   //       with some other pattern matching.
7311   if (!LegalTypes)
7312     if (SDValue R = combineShiftOfShiftedLogic(N, DAG))
7313       return R;
7314 
7315   // We want to pull some binops through shifts, so that we have (and (shift))
7316   // instead of (shift (and)), likewise for add, or, xor, etc.  This sort of
7317   // thing happens with address calculations, so it's important to canonicalize
7318   // it.
7319   switch (LHS.getOpcode()) {
7320   default:
7321     return SDValue();
7322   case ISD::OR:
7323   case ISD::XOR:
7324   case ISD::AND:
7325     break;
7326   case ISD::ADD:
7327     if (N->getOpcode() != ISD::SHL)
7328       return SDValue(); // only shl(add) not sr[al](add).
7329     break;
7330   }
7331 
7332   // We require the RHS of the binop to be a constant and not opaque as well.
7333   ConstantSDNode *BinOpCst = getAsNonOpaqueConstant(LHS.getOperand(1));
7334   if (!BinOpCst)
7335     return SDValue();
7336 
7337   // FIXME: disable this unless the input to the binop is a shift by a constant
7338   // or is copy/select. Enable this in other cases when we figure out that it
7339   // is exactly profitable.
7340   SDValue BinOpLHSVal = LHS.getOperand(0);
7341   bool IsShiftByConstant = (BinOpLHSVal.getOpcode() == ISD::SHL ||
7342                             BinOpLHSVal.getOpcode() == ISD::SRA ||
7343                             BinOpLHSVal.getOpcode() == ISD::SRL) &&
7344                            isa<ConstantSDNode>(BinOpLHSVal.getOperand(1));
7345   bool IsCopyOrSelect = BinOpLHSVal.getOpcode() == ISD::CopyFromReg ||
7346                         BinOpLHSVal.getOpcode() == ISD::SELECT;
7347 
7348   if (!IsShiftByConstant && !IsCopyOrSelect)
7349     return SDValue();
7350 
7351   if (IsCopyOrSelect && N->hasOneUse())
7352     return SDValue();
7353 
7354   // Fold the constants, shifting the binop RHS by the shift amount.
7355   SDLoc DL(N);
7356   EVT VT = N->getValueType(0);
7357   SDValue NewRHS = DAG.getNode(N->getOpcode(), DL, VT, LHS.getOperand(1),
7358                                N->getOperand(1));
7359   assert(isa<ConstantSDNode>(NewRHS) && "Folding was not successful!");
7360 
7361   SDValue NewShift = DAG.getNode(N->getOpcode(), DL, VT, LHS.getOperand(0),
7362                                  N->getOperand(1));
7363   return DAG.getNode(LHS.getOpcode(), DL, VT, NewShift, NewRHS);
7364 }
7365 
7366 SDValue DAGCombiner::distributeTruncateThroughAnd(SDNode *N) {
7367   assert(N->getOpcode() == ISD::TRUNCATE);
7368   assert(N->getOperand(0).getOpcode() == ISD::AND);
7369 
7370   // (truncate:TruncVT (and N00, N01C)) -> (and (truncate:TruncVT N00), TruncC)
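  // For example (illustrative): truncating (and X:i64, 255) to i32 becomes
  // (and (truncate X to i32), 255), with both operands truncated.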
7371   EVT TruncVT = N->getValueType(0);
7372   if (N->hasOneUse() && N->getOperand(0).hasOneUse() &&
7373       TLI.isTypeDesirableForOp(ISD::AND, TruncVT)) {
7374     SDValue N01 = N->getOperand(0).getOperand(1);
7375     if (isConstantOrConstantVector(N01, /* NoOpaques */ true)) {
7376       SDLoc DL(N);
7377       SDValue N00 = N->getOperand(0).getOperand(0);
7378       SDValue Trunc00 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N00);
7379       SDValue Trunc01 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N01);
7380       AddToWorklist(Trunc00.getNode());
7381       AddToWorklist(Trunc01.getNode());
7382       return DAG.getNode(ISD::AND, DL, TruncVT, Trunc00, Trunc01);
7383     }
7384   }
7385 
7386   return SDValue();
7387 }
7388 
7389 SDValue DAGCombiner::visitRotate(SDNode *N) {
7390   SDLoc dl(N);
7391   SDValue N0 = N->getOperand(0);
7392   SDValue N1 = N->getOperand(1);
7393   EVT VT = N->getValueType(0);
7394   unsigned Bitsize = VT.getScalarSizeInBits();
7395 
7396   // fold (rot x, 0) -> x
7397   if (isNullOrNullSplat(N1))
7398     return N0;
7399 
7400   // fold (rot x, c) -> x iff (c % BitSize) == 0
7401   if (isPowerOf2_32(Bitsize) && Bitsize > 1) {
7402     APInt ModuloMask(N1.getScalarValueSizeInBits(), Bitsize - 1);
7403     if (DAG.MaskedValueIsZero(N1, ModuloMask))
7404       return N0;
7405   }
7406 
7407   // fold (rot x, c) -> (rot x, c % BitSize)
7408   // TODO - support non-uniform vector amounts.
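  // For example (illustrative): rotl X:i32, 37 --> rotl X:i32, 5, since a
  // rotate by the full bitwidth is the identity.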
7409   if (ConstantSDNode *Cst = isConstOrConstSplat(N1)) {
7410     if (Cst->getAPIntValue().uge(Bitsize)) {
7411       uint64_t RotAmt = Cst->getAPIntValue().urem(Bitsize);
7412       return DAG.getNode(N->getOpcode(), dl, VT, N0,
7413                          DAG.getConstant(RotAmt, dl, N1.getValueType()));
7414     }
7415   }
7416 
7417   // fold (rot* x, (trunc (and y, c))) -> (rot* x, (and (trunc y), (trunc c))).
7418   if (N1.getOpcode() == ISD::TRUNCATE &&
7419       N1.getOperand(0).getOpcode() == ISD::AND) {
7420     if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
7421       return DAG.getNode(N->getOpcode(), dl, VT, N0, NewOp1);
7422   }
7423 
7424   unsigned NextOp = N0.getOpcode();
7425   // fold (rot* (rot* x, c2), c1) -> (rot* x, c1 +- c2 % bitsize)
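  // For example (illustrative, same rotate direction):
  //   rotl (rotl X:i32, 10), 5 --> rotl X:i32, 15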
7426   if (NextOp == ISD::ROTL || NextOp == ISD::ROTR) {
7427     SDNode *C1 = DAG.isConstantIntBuildVectorOrConstantInt(N1);
7428     SDNode *C2 = DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1));
7429     if (C1 && C2 && C1->getValueType(0) == C2->getValueType(0)) {
7430       EVT ShiftVT = C1->getValueType(0);
7431       bool SameSide = (N->getOpcode() == NextOp);
7432       unsigned CombineOp = SameSide ? ISD::ADD : ISD::SUB;
7433       if (SDValue CombinedShift =
7434               DAG.FoldConstantArithmetic(CombineOp, dl, ShiftVT, C1, C2)) {
7435         SDValue BitsizeC = DAG.getConstant(Bitsize, dl, ShiftVT);
7436         SDValue CombinedShiftNorm = DAG.FoldConstantArithmetic(
7437             ISD::SREM, dl, ShiftVT, CombinedShift.getNode(),
7438             BitsizeC.getNode());
7439         return DAG.getNode(N->getOpcode(), dl, VT, N0->getOperand(0),
7440                            CombinedShiftNorm);
7441       }
7442     }
7443   }
7444   return SDValue();
7445 }
7446 
7447 SDValue DAGCombiner::visitSHL(SDNode *N) {
7448   SDValue N0 = N->getOperand(0);
7449   SDValue N1 = N->getOperand(1);
7450   if (SDValue V = DAG.simplifyShift(N0, N1))
7451     return V;
7452 
7453   EVT VT = N0.getValueType();
7454   EVT ShiftVT = N1.getValueType();
7455   unsigned OpSizeInBits = VT.getScalarSizeInBits();
7456 
7457   // fold vector ops
7458   if (VT.isVector()) {
7459     if (SDValue FoldedVOp = SimplifyVBinOp(N))
7460       return FoldedVOp;
7461 
7462     BuildVectorSDNode *N1CV = dyn_cast<BuildVectorSDNode>(N1);
7463     // If setcc produces an all-ones true value then:
7464     // (shl (and (setcc) N01CV) N1CV) -> (and (setcc) N01CV<<N1CV)
7465     if (N1CV && N1CV->isConstant()) {
7466       if (N0.getOpcode() == ISD::AND) {
7467         SDValue N00 = N0->getOperand(0);
7468         SDValue N01 = N0->getOperand(1);
7469         BuildVectorSDNode *N01CV = dyn_cast<BuildVectorSDNode>(N01);
7470 
7471         if (N01CV && N01CV->isConstant() && N00.getOpcode() == ISD::SETCC &&
7472             TLI.getBooleanContents(N00.getOperand(0).getValueType()) ==
7473                 TargetLowering::ZeroOrNegativeOneBooleanContent) {
7474           if (SDValue C = DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N), VT,
7475                                                      N01CV, N1CV))
7476             return DAG.getNode(ISD::AND, SDLoc(N), VT, N00, C);
7477         }
7478       }
7479     }
7480   }
7481 
7482   ConstantSDNode *N1C = isConstOrConstSplat(N1);
7483 
7484   // fold (shl c1, c2) -> c1<<c2
7485   // TODO - support non-uniform vector shift amounts.
7486   ConstantSDNode *N0C = getAsNonOpaqueConstant(N0);
7487   if (N0C && N1C && !N1C->isOpaque())
7488     return DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N), VT, N0C, N1C);
7489 
7490   if (SDValue NewSel = foldBinOpIntoSelect(N))
7491     return NewSel;
7492 
7493   // if (shl x, c) is known to be zero, return 0
7494   if (DAG.MaskedValueIsZero(SDValue(N, 0),
7495                             APInt::getAllOnesValue(OpSizeInBits)))
7496     return DAG.getConstant(0, SDLoc(N), VT);
7497 
7498   // fold (shl x, (trunc (and y, c))) -> (shl x, (and (trunc y), (trunc c))).
7499   if (N1.getOpcode() == ISD::TRUNCATE &&
7500       N1.getOperand(0).getOpcode() == ISD::AND) {
7501     if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
7502       return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, NewOp1);
7503   }
7504 
7505   // TODO - support non-uniform vector shift amounts.
7506   if (N1C && SimplifyDemandedBits(SDValue(N, 0)))
7507     return SDValue(N, 0);
7508 
7509   // fold (shl (shl x, c1), c2) -> 0 or (shl x, (add c1, c2))
7510   if (N0.getOpcode() == ISD::SHL) {
7511     auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
7512                                           ConstantSDNode *RHS) {
7513       APInt c1 = LHS->getAPIntValue();
7514       APInt c2 = RHS->getAPIntValue();
7515       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
7516       return (c1 + c2).uge(OpSizeInBits);
7517     };
7518     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
7519       return DAG.getConstant(0, SDLoc(N), VT);
7520 
7521     auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
7522                                        ConstantSDNode *RHS) {
7523       APInt c1 = LHS->getAPIntValue();
7524       APInt c2 = RHS->getAPIntValue();
7525       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
7526       return (c1 + c2).ult(OpSizeInBits);
7527     };
7528     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
7529       SDLoc DL(N);
7530       SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
7531       return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Sum);
7532     }
7533   }
7534 
7535   // fold (shl (ext (shl x, c1)), c2) -> (shl (ext x), (add c1, c2))
7536   // For this to be valid, the second form must not preserve any of the bits
7537   // that are shifted out by the inner shift in the first form.  This means
7538   // the outer shift size must be >= the number of bits added by the ext.
7539   // As a corollary, we don't care what kind of ext it is.
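  // For example (illustrative): shl (zext (shl X:i16, 2) to i32), 16
  // --> shl (zext X to i32), 18. The outer shift of 16 discards every bit
  // added by the extension, so the inner shift amount can be folded outward.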
7540   if ((N0.getOpcode() == ISD::ZERO_EXTEND ||
7541        N0.getOpcode() == ISD::ANY_EXTEND ||
7542        N0.getOpcode() == ISD::SIGN_EXTEND) &&
7543       N0.getOperand(0).getOpcode() == ISD::SHL) {
7544     SDValue N0Op0 = N0.getOperand(0);
7545     SDValue InnerShiftAmt = N0Op0.getOperand(1);
7546     EVT InnerVT = N0Op0.getValueType();
7547     uint64_t InnerBitwidth = InnerVT.getScalarSizeInBits();
7548 
7549     auto MatchOutOfRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
7550                                                          ConstantSDNode *RHS) {
7551       APInt c1 = LHS->getAPIntValue();
7552       APInt c2 = RHS->getAPIntValue();
7553       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
7554       return c2.uge(OpSizeInBits - InnerBitwidth) &&
7555              (c1 + c2).uge(OpSizeInBits);
7556     };
7557     if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchOutOfRange,
7558                                   /*AllowUndefs*/ false,
7559                                   /*AllowTypeMismatch*/ true))
7560       return DAG.getConstant(0, SDLoc(N), VT);
7561 
7562     auto MatchInRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
7563                                                       ConstantSDNode *RHS) {
7564       APInt c1 = LHS->getAPIntValue();
7565       APInt c2 = RHS->getAPIntValue();
7566       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
7567       return c2.uge(OpSizeInBits - InnerBitwidth) &&
7568              (c1 + c2).ult(OpSizeInBits);
7569     };
7570     if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchInRange,
7571                                   /*AllowUndefs*/ false,
7572                                   /*AllowTypeMismatch*/ true)) {
7573       SDLoc DL(N);
7574       SDValue Ext = DAG.getNode(N0.getOpcode(), DL, VT, N0Op0.getOperand(0));
7575       SDValue Sum = DAG.getZExtOrTrunc(InnerShiftAmt, DL, ShiftVT);
7576       Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, Sum, N1);
7577       return DAG.getNode(ISD::SHL, DL, VT, Ext, Sum);
7578     }
7579   }
7580 
7581   // fold (shl (zext (srl x, C)), C) -> (zext (shl (srl x, C), C))
7582   // Only fold this if the inner zext has no other uses to avoid increasing
7583   // the total number of instructions.
7584   if (N0.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse() &&
7585       N0.getOperand(0).getOpcode() == ISD::SRL) {
7586     SDValue N0Op0 = N0.getOperand(0);
7587     SDValue InnerShiftAmt = N0Op0.getOperand(1);
7588 
7589     auto MatchEqual = [VT](ConstantSDNode *LHS, ConstantSDNode *RHS) {
7590       APInt c1 = LHS->getAPIntValue();
7591       APInt c2 = RHS->getAPIntValue();
7592       zeroExtendToMatch(c1, c2);
7593       return c1.ult(VT.getScalarSizeInBits()) && (c1 == c2);
7594     };
7595     if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchEqual,
7596                                   /*AllowUndefs*/ false,
7597                                   /*AllowTypeMismatch*/ true)) {
7598       SDLoc DL(N);
7599       EVT InnerShiftAmtVT = N0Op0.getOperand(1).getValueType();
7600       SDValue NewSHL = DAG.getZExtOrTrunc(N1, DL, InnerShiftAmtVT);
7601       NewSHL = DAG.getNode(ISD::SHL, DL, N0Op0.getValueType(), N0Op0, NewSHL);
7602       AddToWorklist(NewSHL.getNode());
7603       return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N0), VT, NewSHL);
7604     }
7605   }
7606 
7607   // fold (shl (sr[la] exact X,  C1), C2) -> (shl    X, (C2-C1)) if C1 <= C2
7608   // fold (shl (sr[la] exact X,  C1), C2) -> (sr[la] X, (C1-C2)) if C1  > C2
7609   // TODO - support non-uniform vector shift amounts.
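  // For example (illustrative): shl (srl exact X, 3), 5 --> shl X, 2. The
  // 'exact' flag guarantees no set bits were shifted out, so only the net
  // shift amount remains.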
7610   if (N1C && (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SRA) &&
7611       N0->getFlags().hasExact()) {
7612     if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) {
7613       uint64_t C1 = N0C1->getZExtValue();
7614       uint64_t C2 = N1C->getZExtValue();
7615       SDLoc DL(N);
7616       if (C1 <= C2)
7617         return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0),
7618                            DAG.getConstant(C2 - C1, DL, ShiftVT));
7619       return DAG.getNode(N0.getOpcode(), DL, VT, N0.getOperand(0),
7620                          DAG.getConstant(C1 - C2, DL, ShiftVT));
7621     }
7622   }
7623 
7624   // fold (shl (srl x, c1), c2) -> (and (shl x, (sub c2, c1)), MASK) or
7625   //                               (and (srl x, (sub c1, c2)), MASK)
7626   // Only fold this if the inner shift has no other uses -- if it does, folding
7627   // this will increase the total number of instructions.
7628   // TODO - drop hasOneUse requirement if c1 == c2?
7629   // TODO - support non-uniform vector shift amounts.
7630   if (N1C && N0.getOpcode() == ISD::SRL && N0.hasOneUse() &&
7631       TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
7632     if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) {
7633       if (N0C1->getAPIntValue().ult(OpSizeInBits)) {
7634         uint64_t c1 = N0C1->getZExtValue();
7635         uint64_t c2 = N1C->getZExtValue();
7636         APInt Mask = APInt::getHighBitsSet(OpSizeInBits, OpSizeInBits - c1);
7637         SDValue Shift;
7638         if (c2 > c1) {
7639           Mask <<= c2 - c1;
7640           SDLoc DL(N);
7641           Shift = DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0),
7642                               DAG.getConstant(c2 - c1, DL, ShiftVT));
7643         } else {
7644           Mask.lshrInPlace(c1 - c2);
7645           SDLoc DL(N);
7646           Shift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0),
7647                               DAG.getConstant(c1 - c2, DL, ShiftVT));
7648         }
7649         SDLoc DL(N0);
7650         return DAG.getNode(ISD::AND, DL, VT, Shift,
7651                            DAG.getConstant(Mask, DL, VT));
7652       }
7653     }
7654   }
7655 
7656   // fold (shl (sra x, c1), c1) -> (and x, (shl -1, c1))
7657   if (N0.getOpcode() == ISD::SRA && N1 == N0.getOperand(1) &&
7658       isConstantOrConstantVector(N1, /* No Opaques */ true)) {
7659     SDLoc DL(N);
7660     SDValue AllBits = DAG.getAllOnesConstant(DL, VT);
7661     SDValue HiBitsMask = DAG.getNode(ISD::SHL, DL, VT, AllBits, N1);
7662     return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), HiBitsMask);
7663   }
7664 
7665   // fold (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2)
7666   // fold (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2)
7667   // Variant of version done on multiply, except mul by a power of 2 is turned
7668   // into a shift.
7669   if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::OR) &&
7670       N0.getNode()->hasOneUse() &&
7671       isConstantOrConstantVector(N1, /* No Opaques */ true) &&
7672       isConstantOrConstantVector(N0.getOperand(1), /* No Opaques */ true) &&
7673       TLI.isDesirableToCommuteWithShift(N, Level)) {
7674     SDValue Shl0 = DAG.getNode(ISD::SHL, SDLoc(N0), VT, N0.getOperand(0), N1);
7675     SDValue Shl1 = DAG.getNode(ISD::SHL, SDLoc(N1), VT, N0.getOperand(1), N1);
7676     AddToWorklist(Shl0.getNode());
7677     AddToWorklist(Shl1.getNode());
7678     return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, Shl0, Shl1);
7679   }
7680 
7681   // fold (shl (mul x, c1), c2) -> (mul x, c1 << c2)
7682   if (N0.getOpcode() == ISD::MUL && N0.getNode()->hasOneUse() &&
7683       isConstantOrConstantVector(N1, /* No Opaques */ true) &&
7684       isConstantOrConstantVector(N0.getOperand(1), /* No Opaques */ true)) {
7685     SDValue Shl = DAG.getNode(ISD::SHL, SDLoc(N1), VT, N0.getOperand(1), N1);
7686     if (isConstantOrConstantVector(Shl))
7687       return DAG.getNode(ISD::MUL, SDLoc(N), VT, N0.getOperand(0), Shl);
7688   }
7689 
7690   if (N1C && !N1C->isOpaque())
7691     if (SDValue NewSHL = visitShiftByConstant(N))
7692       return NewSHL;
7693 
7694   return SDValue();
7695 }
7696 
7697 SDValue DAGCombiner::visitSRA(SDNode *N) {
7698   SDValue N0 = N->getOperand(0);
7699   SDValue N1 = N->getOperand(1);
7700   if (SDValue V = DAG.simplifyShift(N0, N1))
7701     return V;
7702 
7703   EVT VT = N0.getValueType();
7704   unsigned OpSizeInBits = VT.getScalarSizeInBits();
7705 
7706   // Arithmetic shifting an all-sign-bit value is a no-op.
7707   // fold (sra 0, x) -> 0
7708   // fold (sra -1, x) -> -1
7709   if (DAG.ComputeNumSignBits(N0) == OpSizeInBits)
7710     return N0;
7711 
7712   // fold vector ops
7713   if (VT.isVector())
7714     if (SDValue FoldedVOp = SimplifyVBinOp(N))
7715       return FoldedVOp;
7716 
7717   ConstantSDNode *N1C = isConstOrConstSplat(N1);
7718 
7719   // fold (sra c1, c2) -> c1>>c2
7720   // TODO - support non-uniform vector shift amounts.
7721   ConstantSDNode *N0C = getAsNonOpaqueConstant(N0);
7722   if (N0C && N1C && !N1C->isOpaque())
7723     return DAG.FoldConstantArithmetic(ISD::SRA, SDLoc(N), VT, N0C, N1C);
7724 
7725   if (SDValue NewSel = foldBinOpIntoSelect(N))
7726     return NewSel;
7727 
7728   // fold (sra (shl x, c1), c1) -> sext_inreg for some c1 when the target
7729   // supports sext_inreg.
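  // For example (illustrative): sra (shl X:i32, 24), 24
  // --> sign_extend_inreg X, i8, i.e. sign-extend the low 8 bits in place.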
7730   if (N1C && N0.getOpcode() == ISD::SHL && N1 == N0.getOperand(1)) {
7731     unsigned LowBits = OpSizeInBits - (unsigned)N1C->getZExtValue();
7732     EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), LowBits);
7733     if (VT.isVector())
7734       ExtVT = EVT::getVectorVT(*DAG.getContext(),
7735                                ExtVT, VT.getVectorNumElements());
7736     if (!LegalOperations ||
7737         TLI.getOperationAction(ISD::SIGN_EXTEND_INREG, ExtVT) ==
7738         TargetLowering::Legal)
7739       return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT,
7740                          N0.getOperand(0), DAG.getValueType(ExtVT));
7741   }
7742 
7743   // fold (sra (sra x, c1), c2) -> (sra x, (add c1, c2))
7744   // clamp (add c1, c2) to max shift.
  if (N0.getOpcode() == ISD::SRA) {
    SDLoc DL(N);
    EVT ShiftVT = N1.getValueType();
    EVT ShiftSVT = ShiftVT.getScalarType();
    SmallVector<SDValue, 16> ShiftValues;

    auto SumOfShifts = [&](ConstantSDNode *LHS, ConstantSDNode *RHS) {
      APInt c1 = LHS->getAPIntValue();
      APInt c2 = RHS->getAPIntValue();
      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
      APInt Sum = c1 + c2;
      unsigned ShiftSum =
          Sum.uge(OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue();
      ShiftValues.push_back(DAG.getConstant(ShiftSum, DL, ShiftSVT));
      return true;
    };
    if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), SumOfShifts)) {
      SDValue ShiftValue;
      if (VT.isVector())
        ShiftValue = DAG.getBuildVector(ShiftVT, DL, ShiftValues);
      else
        ShiftValue = ShiftValues[0];
      return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0), ShiftValue);
    }
  }

  // fold (sra (shl X, m), (sub result_size, n))
  // -> (sign_extend (trunc (srl X, (sub (sub result_size, n), m)))) for
  // result_size - n != m.
  // If truncate is free for the target, the srl+trunc+sign_extend sequence
  // is likely to result in better code.
  if (N0.getOpcode() == ISD::SHL && N1C) {
    // Get the two constants of the shifts, CN0 = m, CN = n.
    const ConstantSDNode *N01C = isConstOrConstSplat(N0.getOperand(1));
    if (N01C) {
      LLVMContext &Ctx = *DAG.getContext();
      // Determine what the truncate's result bitsize and type would be.
      EVT TruncVT = EVT::getIntegerVT(Ctx, OpSizeInBits - N1C->getZExtValue());

      if (VT.isVector())
        TruncVT = EVT::getVectorVT(Ctx, TruncVT, VT.getVectorNumElements());

      // Determine the residual right-shift amount.
      int ShiftAmt = N1C->getZExtValue() - N01C->getZExtValue();

      // If the shift is not a no-op (in which case this should be just a sign
      // extend already), the truncate-to type is legal, sign_extend is legal
      // on that type, and the truncate to that type is both legal and free,
      // perform the transform.
      if ((ShiftAmt > 0) &&
          TLI.isOperationLegalOrCustom(ISD::SIGN_EXTEND, TruncVT) &&
          TLI.isOperationLegalOrCustom(ISD::TRUNCATE, VT) &&
          TLI.isTruncateFree(VT, TruncVT)) {
        SDLoc DL(N);
        SDValue Amt = DAG.getConstant(ShiftAmt, DL,
            getShiftAmountTy(N0.getOperand(0).getValueType()));
        SDValue Shift = DAG.getNode(ISD::SRL, DL, VT,
                                    N0.getOperand(0), Amt);
        SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, TruncVT,
                                    Shift);
        return DAG.getNode(ISD::SIGN_EXTEND, DL,
                           N->getValueType(0), Trunc);
      }
    }
  }

  // We convert trunc/ext to opposing shifts in IR, but casts may be cheaper.
  //   sra (add (shl X, N1C), AddC), N1C -->
  //   sext (add (trunc X to (width - N1C)), AddC')
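  // where AddC' is AddC shifted right by N1C and truncated to the narrow
  // type. For example, on i32 with N1C == 16:
  //   sra (add (shl X, 16), AddC), 16 -->
  //   sext (add (trunc X to i16), trunc (AddC >> 16) to i16)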
  if (!LegalTypes && N0.getOpcode() == ISD::ADD && N0.hasOneUse() && N1C &&
      N0.getOperand(0).getOpcode() == ISD::SHL &&
      N0.getOperand(0).getOperand(1) == N1 && N0.getOperand(0).hasOneUse()) {
    if (ConstantSDNode *AddC = isConstOrConstSplat(N0.getOperand(1))) {
      SDValue Shl = N0.getOperand(0);
      // Determine what the truncate's type would be and ask the target if that
      // is a free operation.
      LLVMContext &Ctx = *DAG.getContext();
      unsigned ShiftAmt = N1C->getZExtValue();
      EVT TruncVT = EVT::getIntegerVT(Ctx, OpSizeInBits - ShiftAmt);
      if (VT.isVector())
        TruncVT = EVT::getVectorVT(Ctx, TruncVT, VT.getVectorNumElements());

      // TODO: The simple type check probably belongs in the default hook
      //       implementation and/or target-specific overrides (because
      //       non-simple types likely require masking when legalized), but that
      //       restriction may conflict with other transforms.
      if (TruncVT.isSimple() && TLI.isTruncateFree(VT, TruncVT)) {
        SDLoc DL(N);
        SDValue Trunc = DAG.getZExtOrTrunc(Shl.getOperand(0), DL, TruncVT);
        SDValue ShiftC = DAG.getConstant(AddC->getAPIntValue().lshr(ShiftAmt).
                             trunc(TruncVT.getScalarSizeInBits()), DL, TruncVT);
        SDValue Add = DAG.getNode(ISD::ADD, DL, TruncVT, Trunc, ShiftC);
        return DAG.getSExtOrTrunc(Add, DL, VT);
      }
    }
  }

  // fold (sra x, (trunc (and y, c))) -> (sra x, (and (trunc y), (trunc c))).
  if (N1.getOpcode() == ISD::TRUNCATE &&
      N1.getOperand(0).getOpcode() == ISD::AND) {
    if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
      return DAG.getNode(ISD::SRA, SDLoc(N), VT, N0, NewOp1);
  }

  // fold (sra (trunc (sra x, c1)), c2) -> (trunc (sra x, c1 + c2))
  // fold (sra (trunc (srl x, c1)), c2) -> (trunc (sra x, c1 + c2))
  //      if c1 is equal to the number of bits the trunc removes
  // TODO - support non-uniform vector shift amounts.
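  // For example, with x: i64, (i32 sra (trunc (srl x, 32)), 5) becomes
  // (trunc (i64 sra x, 37)): the truncate discards exactly the 32 bits the
  // first shift brought down, so the two shifts compose.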
  if (N0.getOpcode() == ISD::TRUNCATE &&
      (N0.getOperand(0).getOpcode() == ISD::SRL ||
       N0.getOperand(0).getOpcode() == ISD::SRA) &&
      N0.getOperand(0).hasOneUse() &&
      N0.getOperand(0).getOperand(1).hasOneUse() && N1C) {
    SDValue N0Op0 = N0.getOperand(0);
    if (ConstantSDNode *LargeShift = isConstOrConstSplat(N0Op0.getOperand(1))) {
      EVT LargeVT = N0Op0.getValueType();
      unsigned TruncBits = LargeVT.getScalarSizeInBits() - OpSizeInBits;
      if (LargeShift->getAPIntValue() == TruncBits) {
        SDLoc DL(N);
        SDValue Amt = DAG.getConstant(N1C->getZExtValue() + TruncBits, DL,
                                      getShiftAmountTy(LargeVT));
        SDValue SRA =
            DAG.getNode(ISD::SRA, DL, LargeVT, N0Op0.getOperand(0), Amt);
        return DAG.getNode(ISD::TRUNCATE, DL, VT, SRA);
      }
    }
  }

  // Simplify, based on bits shifted out of the LHS.
  // TODO - support non-uniform vector shift amounts.
  if (N1C && SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  // If the sign bit is known to be zero, switch this to a SRL.
  if (DAG.SignBitIsZero(N0))
    return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0, N1);

  if (N1C && !N1C->isOpaque())
    if (SDValue NewSRA = visitShiftByConstant(N))
      return NewSRA;

  return SDValue();
}

SDValue DAGCombiner::visitSRL(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  if (SDValue V = DAG.simplifyShift(N0, N1))
    return V;

  EVT VT = N0.getValueType();
  unsigned OpSizeInBits = VT.getScalarSizeInBits();

  // fold vector ops
  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N))
      return FoldedVOp;

  ConstantSDNode *N1C = isConstOrConstSplat(N1);

  // fold (srl c1, c2) -> c1 >>u c2
  // TODO - support non-uniform vector shift amounts.
  ConstantSDNode *N0C = getAsNonOpaqueConstant(N0);
  if (N0C && N1C && !N1C->isOpaque())
    return DAG.FoldConstantArithmetic(ISD::SRL, SDLoc(N), VT, N0C, N1C);

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // if (srl x, c) is known to be zero, return 0
  if (N1C && DAG.MaskedValueIsZero(SDValue(N, 0),
                                   APInt::getAllOnesValue(OpSizeInBits)))
    return DAG.getConstant(0, SDLoc(N), VT);

  // fold (srl (srl x, c1), c2) -> 0 or (srl x, (add c1, c2))
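  // Unlike sra, srl cannot clamp: on i8, (srl (srl X, 3), 6) shifts out every
  // bit and folds to 0, while (srl (srl X, 3), 4) stays in range and folds to
  // (srl X, 7).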
  if (N0.getOpcode() == ISD::SRL) {
    auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
                                          ConstantSDNode *RHS) {
      APInt c1 = LHS->getAPIntValue();
      APInt c2 = RHS->getAPIntValue();
      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
      return (c1 + c2).uge(OpSizeInBits);
    };
    if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
      return DAG.getConstant(0, SDLoc(N), VT);

    auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
                                       ConstantSDNode *RHS) {
      APInt c1 = LHS->getAPIntValue();
      APInt c2 = RHS->getAPIntValue();
      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
      return (c1 + c2).ult(OpSizeInBits);
    };
    if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
      SDLoc DL(N);
      EVT ShiftVT = N1.getValueType();
      SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
      return DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Sum);
    }
  }

  if (N1C && N0.getOpcode() == ISD::TRUNCATE &&
      N0.getOperand(0).getOpcode() == ISD::SRL) {
    SDValue InnerShift = N0.getOperand(0);
    // TODO - support non-uniform vector shift amounts.
    if (auto *N001C = isConstOrConstSplat(InnerShift.getOperand(1))) {
      uint64_t c1 = N001C->getZExtValue();
      uint64_t c2 = N1C->getZExtValue();
      EVT InnerShiftVT = InnerShift.getValueType();
      EVT ShiftAmtVT = InnerShift.getOperand(1).getValueType();
      uint64_t InnerShiftSize = InnerShiftVT.getScalarSizeInBits();
      // srl (trunc (srl x, c1)), c2 --> 0 or (trunc (srl x, (add c1, c2)))
      // This is only valid if OpSizeInBits + c1 == the size of the inner shift.
      if (c1 + OpSizeInBits == InnerShiftSize) {
        SDLoc DL(N);
        if (c1 + c2 >= InnerShiftSize)
          return DAG.getConstant(0, DL, VT);
        SDValue NewShiftAmt = DAG.getConstant(c1 + c2, DL, ShiftAmtVT);
        SDValue NewShift = DAG.getNode(ISD::SRL, DL, InnerShiftVT,
                                       InnerShift.getOperand(0), NewShiftAmt);
        return DAG.getNode(ISD::TRUNCATE, DL, VT, NewShift);
      }
      // In the more general case, we can clear the high bits after the shift:
      // srl (trunc (srl x, c1)), c2 --> trunc (and (srl x, (c1+c2)), Mask)
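      // For example, with x: i64 and c1 == 16, c2 == 8:
      //   (i32 srl (trunc (srl x, 16)), 8)
      //     --> trunc (and (i64 srl x, 24), 0xFFFFFF)
      // where the mask keeps OpSizeInBits - c2 == 24 low bits.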
      if (N0.hasOneUse() && InnerShift.hasOneUse() &&
          c1 + c2 < InnerShiftSize) {
        SDLoc DL(N);
        SDValue NewShiftAmt = DAG.getConstant(c1 + c2, DL, ShiftAmtVT);
        SDValue NewShift = DAG.getNode(ISD::SRL, DL, InnerShiftVT,
                                       InnerShift.getOperand(0), NewShiftAmt);
        SDValue Mask = DAG.getConstant(APInt::getLowBitsSet(InnerShiftSize,
                                                            OpSizeInBits - c2),
                                       DL, InnerShiftVT);
        SDValue And = DAG.getNode(ISD::AND, DL, InnerShiftVT, NewShift, Mask);
        return DAG.getNode(ISD::TRUNCATE, DL, VT, And);
      }
    }
  }

  // fold (srl (shl x, c), c) -> (and x, cst2)
  // TODO - (srl (shl x, c1), c2).
  if (N0.getOpcode() == ISD::SHL && N0.getOperand(1) == N1 &&
      isConstantOrConstantVector(N1, /* NoOpaques */ true)) {
    SDLoc DL(N);
    SDValue Mask =
        DAG.getNode(ISD::SRL, DL, VT, DAG.getAllOnesConstant(DL, VT), N1);
    AddToWorklist(Mask.getNode());
    return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), Mask);
  }

  // fold (srl (anyextend x), c) -> (and (anyextend (srl x, c)), mask)
  // TODO - support non-uniform vector shift amounts.
  if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
    // Shifting in all undef bits?
    EVT SmallVT = N0.getOperand(0).getValueType();
    unsigned BitSize = SmallVT.getScalarSizeInBits();
    if (N1C->getAPIntValue().uge(BitSize))
      return DAG.getUNDEF(VT);

    if (!LegalTypes || TLI.isTypeDesirableForOp(ISD::SRL, SmallVT)) {
      uint64_t ShiftAmt = N1C->getZExtValue();
      SDLoc DL0(N0);
      SDValue SmallShift = DAG.getNode(ISD::SRL, DL0, SmallVT,
                                       N0.getOperand(0),
                          DAG.getConstant(ShiftAmt, DL0,
                                          getShiftAmountTy(SmallVT)));
      AddToWorklist(SmallShift.getNode());
      APInt Mask = APInt::getLowBitsSet(OpSizeInBits, OpSizeInBits - ShiftAmt);
      SDLoc DL(N);
      return DAG.getNode(ISD::AND, DL, VT,
                         DAG.getNode(ISD::ANY_EXTEND, DL, VT, SmallShift),
                         DAG.getConstant(Mask, DL, VT));
    }
  }

  // fold (srl (sra X, Y), 31) -> (srl X, 31).  This srl only looks at the sign
  // bit, which is unmodified by sra.
  if (N1C && N1C->getAPIntValue() == (OpSizeInBits - 1)) {
    if (N0.getOpcode() == ISD::SRA)
      return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0.getOperand(0), N1);
  }

  // fold (srl (ctlz x), "5") -> x  iff x has one bit set (the low bit).
  if (N1C && N0.getOpcode() == ISD::CTLZ &&
      N1C->getAPIntValue() == Log2_32(OpSizeInBits)) {
    KnownBits Known = DAG.computeKnownBits(N0.getOperand(0));

    // If any of the input bits are KnownOne, then the input couldn't be all
    // zeros, thus the result of the srl will always be zero.
    if (Known.One.getBoolValue()) return DAG.getConstant(0, SDLoc(N0), VT);

    // If all of the bits input to the ctlz node are known to be zero, then
    // the result of the ctlz is "32" and the result of the shift is one.
    APInt UnknownBits = ~Known.Zero;
    if (UnknownBits == 0) return DAG.getConstant(1, SDLoc(N0), VT);

    // Otherwise, check to see if there is exactly one bit input to the ctlz.
    if (UnknownBits.isPowerOf2()) {
      // Okay, we know that only the single bit specified by UnknownBits
      // could be set on input to the CTLZ node. If this bit is set, the SRL
      // will return 0, if it is clear, it returns 1. Change the CTLZ/SRL pair
      // to an SRL/XOR pair, which is likely to simplify more.
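      // For example, on i32, if only bit 3 of X can be nonzero, then
      //   (srl (ctlz X), 5) --> xor (srl X, 3), 1
      // since X == 8 gives ctlz == 28 (srl -> 0) and X == 0 gives
      // ctlz == 32 (srl -> 1).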
      unsigned ShAmt = UnknownBits.countTrailingZeros();
      SDValue Op = N0.getOperand(0);

      if (ShAmt) {
        SDLoc DL(N0);
        Op = DAG.getNode(ISD::SRL, DL, VT, Op,
                  DAG.getConstant(ShAmt, DL,
                                  getShiftAmountTy(Op.getValueType())));
        AddToWorklist(Op.getNode());
      }

      SDLoc DL(N);
      return DAG.getNode(ISD::XOR, DL, VT,
                         Op, DAG.getConstant(1, DL, VT));
    }
  }

  // fold (srl x, (trunc (and y, c))) -> (srl x, (and (trunc y), (trunc c))).
  if (N1.getOpcode() == ISD::TRUNCATE &&
      N1.getOperand(0).getOpcode() == ISD::AND) {
    if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
      return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0, NewOp1);
  }

  // fold operands of srl based on knowledge that the low bits are not
  // demanded.
  // TODO - support non-uniform vector shift amounts.
  if (N1C && SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  if (N1C && !N1C->isOpaque())
    if (SDValue NewSRL = visitShiftByConstant(N))
      return NewSRL;

  // Attempt to convert a srl of a load into a narrower zero-extending load.
  if (SDValue NarrowLoad = ReduceLoadWidth(N))
    return NarrowLoad;

  // Here is a common situation. We want to optimize:
  //
  //   %a = ...
  //   %b = and i32 %a, 2
  //   %c = srl i32 %b, 1
  //   brcond i32 %c ...
  //
  // into
  //
  //   %a = ...
  //   %b = and %a, 2
  //   %c = setcc eq %b, 0
  //   brcond %c ...
  //
  // However, after the source operand of SRL is optimized into AND, the SRL
  // itself may not be optimized further. Look for it and add the BRCOND into
  // the worklist.
  if (N->hasOneUse()) {
    SDNode *Use = *N->use_begin();
    if (Use->getOpcode() == ISD::BRCOND)
      AddToWorklist(Use);
    else if (Use->getOpcode() == ISD::TRUNCATE && Use->hasOneUse()) {
      // Also look past the truncate.
      Use = *Use->use_begin();
      if (Use->getOpcode() == ISD::BRCOND)
        AddToWorklist(Use);
    }
  }

  return SDValue();
}

SDValue DAGCombiner::visitFunnelShift(SDNode *N) {
  EVT VT = N->getValueType(0);
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue N2 = N->getOperand(2);
  bool IsFSHL = N->getOpcode() == ISD::FSHL;
  unsigned BitWidth = VT.getScalarSizeInBits();

  // fold (fshl N0, N1, 0) -> N0
  // fold (fshr N0, N1, 0) -> N1
  if (isPowerOf2_32(BitWidth))
    if (DAG.MaskedValueIsZero(
            N2, APInt(N2.getScalarValueSizeInBits(), BitWidth - 1)))
      return IsFSHL ? N0 : N1;

  auto IsUndefOrZero = [](SDValue V) {
    return V.isUndef() || isNullOrNullSplat(V, /*AllowUndefs*/ true);
  };

  // TODO - support non-uniform vector shift amounts.
  if (ConstantSDNode *Cst = isConstOrConstSplat(N2)) {
    EVT ShAmtTy = N2.getValueType();

    // fold (fsh* N0, N1, c) -> (fsh* N0, N1, c % BitWidth)
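    // For example, an i32 fshl by 37 is the same as an fshl by 37 % 32 == 5,
    // since a funnel shift amount is taken modulo the bit width.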
    if (Cst->getAPIntValue().uge(BitWidth)) {
      uint64_t RotAmt = Cst->getAPIntValue().urem(BitWidth);
      return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N0, N1,
                         DAG.getConstant(RotAmt, SDLoc(N), ShAmtTy));
    }

    unsigned ShAmt = Cst->getZExtValue();
    if (ShAmt == 0)
      return IsFSHL ? N0 : N1;

    // fold fshl(undef_or_zero, N1, C) -> lshr(N1, BW-C)
    // fold fshr(undef_or_zero, N1, C) -> lshr(N1, C)
    // fold fshl(N0, undef_or_zero, C) -> shl(N0, C)
    // fold fshr(N0, undef_or_zero, C) -> shl(N0, BW-C)
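    // These follow from the definition: fshl(X, Y, C) concatenates X:Y and
    // takes the top BW bits after shifting left by C, i.e.
    // (shl X, C) | (srl Y, BW - C), so zeroing one operand leaves a single
    // plain shift of the other.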
    if (IsUndefOrZero(N0))
      return DAG.getNode(ISD::SRL, SDLoc(N), VT, N1,
                         DAG.getConstant(IsFSHL ? BitWidth - ShAmt : ShAmt,
                                         SDLoc(N), ShAmtTy));
    if (IsUndefOrZero(N1))
      return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0,
                         DAG.getConstant(IsFSHL ? ShAmt : BitWidth - ShAmt,
                                         SDLoc(N), ShAmtTy));
  }

  // fold fshr(undef_or_zero, N1, N2) -> lshr(N1, N2)
  // fold fshl(N0, undef_or_zero, N2) -> shl(N0, N2)
  // iff we know the shift amount is in range.
  // TODO: when is it worth doing SUB(BW, N2) as well?
  if (isPowerOf2_32(BitWidth)) {
    APInt ModuloBits(N2.getScalarValueSizeInBits(), BitWidth - 1);
    if (IsUndefOrZero(N0) && !IsFSHL && DAG.MaskedValueIsZero(N2, ~ModuloBits))
      return DAG.getNode(ISD::SRL, SDLoc(N), VT, N1, N2);
    if (IsUndefOrZero(N1) && IsFSHL && DAG.MaskedValueIsZero(N2, ~ModuloBits))
      return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, N2);
  }

  // fold (fshl N0, N0, N2) -> (rotl N0, N2)
  // fold (fshr N0, N0, N2) -> (rotr N0, N2)
  // TODO: Investigate flipping this rotate if only one is legal; if the funnel
  // shift is legal as well, we might be better off avoiding non-constant
  // (BW - N2).
  unsigned RotOpc = IsFSHL ? ISD::ROTL : ISD::ROTR;
  if (N0 == N1 && hasOperation(RotOpc, VT))
    return DAG.getNode(RotOpc, SDLoc(N), VT, N0, N2);

  // Simplify, based on bits shifted out of N0/N1.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  return SDValue();
}

SDValue DAGCombiner::visitABS(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  // fold (abs c1) -> c2
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
    return DAG.getNode(ISD::ABS, SDLoc(N), VT, N0);
  // fold (abs (abs x)) -> (abs x)
  if (N0.getOpcode() == ISD::ABS)
    return N0;
  // fold (abs x) -> x iff not-negative
  if (DAG.SignBitIsZero(N0))
    return N0;
  return SDValue();
}

SDValue DAGCombiner::visitBSWAP(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  // fold (bswap c1) -> c2
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
    return DAG.getNode(ISD::BSWAP, SDLoc(N), VT, N0);
  // fold (bswap (bswap x)) -> x
  if (N0.getOpcode() == ISD::BSWAP)
    return N0->getOperand(0);
  return SDValue();
}

SDValue DAGCombiner::visitBITREVERSE(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  // fold (bitreverse c1) -> c2
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
    return DAG.getNode(ISD::BITREVERSE, SDLoc(N), VT, N0);
  // fold (bitreverse (bitreverse x)) -> x
  if (N0.getOpcode() == ISD::BITREVERSE)
    return N0.getOperand(0);
  return SDValue();
}

SDValue DAGCombiner::visitCTLZ(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  // fold (ctlz c1) -> c2
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
    return DAG.getNode(ISD::CTLZ, SDLoc(N), VT, N0);

  // If the value is known never to be zero, switch to the undef version.
  if (!LegalOperations || TLI.isOperationLegal(ISD::CTLZ_ZERO_UNDEF, VT)) {
    if (DAG.isKnownNeverZero(N0))
      return DAG.getNode(ISD::CTLZ_ZERO_UNDEF, SDLoc(N), VT, N0);
  }

  return SDValue();
}

SDValue DAGCombiner::visitCTLZ_ZERO_UNDEF(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  // fold (ctlz_zero_undef c1) -> c2
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
    return DAG.getNode(ISD::CTLZ_ZERO_UNDEF, SDLoc(N), VT, N0);
  return SDValue();
}

SDValue DAGCombiner::visitCTTZ(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  // fold (cttz c1) -> c2
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
    return DAG.getNode(ISD::CTTZ, SDLoc(N), VT, N0);

  // If the value is known never to be zero, switch to the undef version.
  if (!LegalOperations || TLI.isOperationLegal(ISD::CTTZ_ZERO_UNDEF, VT)) {
    if (DAG.isKnownNeverZero(N0))
      return DAG.getNode(ISD::CTTZ_ZERO_UNDEF, SDLoc(N), VT, N0);
  }

  return SDValue();
}

SDValue DAGCombiner::visitCTTZ_ZERO_UNDEF(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  // fold (cttz_zero_undef c1) -> c2
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
    return DAG.getNode(ISD::CTTZ_ZERO_UNDEF, SDLoc(N), VT, N0);
  return SDValue();
}

SDValue DAGCombiner::visitCTPOP(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  // fold (ctpop c1) -> c2
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
    return DAG.getNode(ISD::CTPOP, SDLoc(N), VT, N0);
  return SDValue();
}

// FIXME: This should be checking for no signed zeros on individual operands, as
// well as no nans.
static bool isLegalToCombineMinNumMaxNum(SelectionDAG &DAG, SDValue LHS,
                                         SDValue RHS,
                                         const TargetLowering &TLI) {
  const TargetOptions &Options = DAG.getTarget().Options;
  EVT VT = LHS.getValueType();

  return Options.NoSignedZerosFPMath && VT.isFloatingPoint() &&
         TLI.isProfitableToCombineMinNumMaxNum(VT) &&
         DAG.isKnownNeverNaN(LHS) && DAG.isKnownNeverNaN(RHS);
}

/// Generate Min/Max node
static SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
                                   SDValue RHS, SDValue True, SDValue False,
                                   ISD::CondCode CC, const TargetLowering &TLI,
                                   SelectionDAG &DAG) {
  if (!(LHS == True && RHS == False) && !(LHS == False && RHS == True))
    return SDValue();

  EVT TransformVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
  switch (CC) {
  case ISD::SETOLT:
  case ISD::SETOLE:
  case ISD::SETLT:
  case ISD::SETLE:
  case ISD::SETULT:
  case ISD::SETULE: {
    // Since the operands are known never to be NaN by the time we get here,
    // either fminnum or fminnum_ieee is OK. Try the ieee version first, since
    // fminnum is expanded in terms of it.
    unsigned IEEEOpcode = (LHS == True) ? ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE;
    if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT))
      return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS);

    unsigned Opcode = (LHS == True) ? ISD::FMINNUM : ISD::FMAXNUM;
    if (TLI.isOperationLegalOrCustom(Opcode, TransformVT))
      return DAG.getNode(Opcode, DL, VT, LHS, RHS);
    return SDValue();
  }
  case ISD::SETOGT:
  case ISD::SETOGE:
  case ISD::SETGT:
  case ISD::SETGE:
  case ISD::SETUGT:
  case ISD::SETUGE: {
    unsigned IEEEOpcode = (LHS == True) ? ISD::FMAXNUM_IEEE : ISD::FMINNUM_IEEE;
    if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT))
      return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS);

    unsigned Opcode = (LHS == True) ? ISD::FMAXNUM : ISD::FMINNUM;
    if (TLI.isOperationLegalOrCustom(Opcode, TransformVT))
      return DAG.getNode(Opcode, DL, VT, LHS, RHS);
    return SDValue();
  }
  default:
    return SDValue();
  }
}

/// If a (v)select has a condition value that is a sign-bit test, try to smear
/// the condition operand sign-bit across the value width and use it as a mask.
static SDValue foldSelectOfConstantsUsingSra(SDNode *N, SelectionDAG &DAG) {
  SDValue Cond = N->getOperand(0);
  SDValue C1 = N->getOperand(1);
  SDValue C2 = N->getOperand(2);
  assert(isConstantOrConstantVector(C1) && isConstantOrConstantVector(C2) &&
         "Expected select-of-constants");

  EVT VT = N->getValueType(0);
  if (Cond.getOpcode() != ISD::SETCC || !Cond.hasOneUse() ||
      VT != Cond.getOperand(0).getValueType())
    return SDValue();

  // The inverted-condition + commuted-select variants of these patterns are
  // canonicalized to these forms in IR.
  SDValue X = Cond.getOperand(0);
  SDValue CondC = Cond.getOperand(1);
  ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
  if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(CondC) &&
      isAllOnesOrAllOnesSplat(C2)) {
    // i32 X > -1 ? C1 : -1 --> (X >>s 31) | C1
    SDLoc DL(N);
    SDValue ShAmtC = DAG.getConstant(X.getScalarValueSizeInBits() - 1, DL, VT);
    SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, X, ShAmtC);
    return DAG.getNode(ISD::OR, DL, VT, Sra, C1);
  }
  if (CC == ISD::SETLT && isNullOrNullSplat(CondC) && isNullOrNullSplat(C2)) {
    // i8 X < 0 ? C1 : 0 --> (X >>s 7) & C1
    SDLoc DL(N);
    SDValue ShAmtC = DAG.getConstant(X.getScalarValueSizeInBits() - 1, DL, VT);
    SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, X, ShAmtC);
    return DAG.getNode(ISD::AND, DL, VT, Sra, C1);
  }
  return SDValue();
}

SDValue DAGCombiner::foldSelectOfConstants(SDNode *N) {
  SDValue Cond = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue N2 = N->getOperand(2);
  EVT VT = N->getValueType(0);
  EVT CondVT = Cond.getValueType();
  SDLoc DL(N);

  if (!VT.isInteger())
    return SDValue();

  auto *C1 = dyn_cast<ConstantSDNode>(N1);
  auto *C2 = dyn_cast<ConstantSDNode>(N2);
  if (!C1 || !C2)
    return SDValue();

  // Only do this before legalization to avoid conflicting with target-specific
  // transforms in the other direction (creating a select from a zext/sext).
  // There is also a target-independent combine here in DAGCombiner in the
  // other direction for (select Cond, -1, 0) when the condition is not i1.
  if (CondVT == MVT::i1 && !LegalOperations) {
    if (C1->isNullValue() && C2->isOne()) {
      // select Cond, 0, 1 --> zext (!Cond)
      SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
      if (VT != MVT::i1)
        NotCond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, NotCond);
      return NotCond;
    }
    if (C1->isNullValue() && C2->isAllOnesValue()) {
      // select Cond, 0, -1 --> sext (!Cond)
      SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
      if (VT != MVT::i1)
        NotCond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, NotCond);
      return NotCond;
    }
    if (C1->isOne() && C2->isNullValue()) {
      // select Cond, 1, 0 --> zext (Cond)
      if (VT != MVT::i1)
        Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Cond);
      return Cond;
    }
    if (C1->isAllOnesValue() && C2->isNullValue()) {
      // select Cond, -1, 0 --> sext (Cond)
      if (VT != MVT::i1)
        Cond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Cond);
      return Cond;
    }

    // Use a target hook because some targets may prefer to transform in the
    // other direction.
    if (TLI.convertSelectOfConstantsToMath(VT)) {
      // For any constants that differ by 1, we can transform the select into an
      // extend and add.
      const APInt &C1Val = C1->getAPIntValue();
      const APInt &C2Val = C2->getAPIntValue();
      if (C1Val - 1 == C2Val) {
        // select Cond, C1, C1-1 --> add (zext Cond), C1-1
        if (VT != MVT::i1)
          Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Cond);
        return DAG.getNode(ISD::ADD, DL, VT, Cond, N2);
      }
      if (C1Val + 1 == C2Val) {
        // select Cond, C1, C1+1 --> add (sext Cond), C1+1
        if (VT != MVT::i1)
          Cond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Cond);
        return DAG.getNode(ISD::ADD, DL, VT, Cond, N2);
      }

      // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2)
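      // For example, select Cond, 16, 0 --> (zext Cond) << 4, trading the
      // select for a shift.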
      if (C1Val.isPowerOf2() && C2Val.isNullValue()) {
        if (VT != MVT::i1)
          Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Cond);
        SDValue ShAmtC = DAG.getConstant(C1Val.exactLogBase2(), DL, VT);
        return DAG.getNode(ISD::SHL, DL, VT, Cond, ShAmtC);
      }

      if (SDValue V = foldSelectOfConstantsUsingSra(N, DAG))
        return V;
    }

    return SDValue();
  }

  // fold (select Cond, 0, 1) -> (xor Cond, 1)
  // We can't do this reliably if integer-based booleans have different contents
  // from floating-point-based booleans. This is because we can't tell whether
  // we have an integer-based boolean or a floating-point-based boolean unless
  // we can find the SETCC that produced it and inspect its operands. This is
  // fairly easy if C is the SETCC node, but it can potentially be
  // undiscoverable (or not reasonably discoverable). For example, it could be
  // in another basic block or it could require searching a complicated
  // expression.
  if (CondVT.isInteger() &&
      TLI.getBooleanContents(/*isVec*/false, /*isFloat*/true) ==
          TargetLowering::ZeroOrOneBooleanContent &&
      TLI.getBooleanContents(/*isVec*/false, /*isFloat*/false) ==
          TargetLowering::ZeroOrOneBooleanContent &&
      C1->isNullValue() && C2->isOne()) {
    SDValue NotCond =
        DAG.getNode(ISD::XOR, DL, CondVT, Cond, DAG.getConstant(1, DL, CondVT));
    if (VT.bitsEq(CondVT))
      return NotCond;
    return DAG.getZExtOrTrunc(NotCond, DL, VT);
  }

  return SDValue();
}

SDValue DAGCombiner::visitSELECT(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue N2 = N->getOperand(2);
  EVT VT = N->getValueType(0);
  EVT VT0 = N0.getValueType();
  SDLoc DL(N);
  SDNodeFlags Flags = N->getFlags();

  if (SDValue V = DAG.simplifySelect(N0, N1, N2))
    return V;

  // fold (select X, X, Y) -> (or X, Y)
  // fold (select X, 1, Y) -> (or X, Y)
  if (VT == VT0 && VT == MVT::i1 && (N0 == N1 || isOneConstant(N1)))
    return DAG.getNode(ISD::OR, DL, VT, N0, N2);

  if (SDValue V = foldSelectOfConstants(N))
    return V;

  // fold (select C, 0, X) -> (and (not C), X)
  if (VT == VT0 && VT == MVT::i1 && isNullConstant(N1)) {
    SDValue NOTNode = DAG.getNOT(SDLoc(N0), N0, VT);
    AddToWorklist(NOTNode.getNode());
    return DAG.getNode(ISD::AND, DL, VT, NOTNode, N2);
  }
  // fold (select C, X, 1) -> (or (not C), X)
  if (VT == VT0 && VT == MVT::i1 && isOneConstant(N2)) {
    SDValue NOTNode = DAG.getNOT(SDLoc(N0), N0, VT);
    AddToWorklist(NOTNode.getNode());
    return DAG.getNode(ISD::OR, DL, VT, NOTNode, N1);
  }
  // fold (select X, Y, X) -> (and X, Y)
  // fold (select X, Y, 0) -> (and X, Y)
  if (VT == VT0 && VT == MVT::i1 && (N0 == N2 || isNullConstant(N2)))
    return DAG.getNode(ISD::AND, DL, VT, N0, N1);

  // If we can fold this based on the true/false value, do so.
  if (SimplifySelectOps(N, N1, N2))
    return SDValue(N, 0); // Don't revisit N.

  if (VT0 == MVT::i1) {
    // The code in this block deals with the following 2 equivalences:
    //    select(C0|C1, x, y) <=> select(C0, x, select(C1, x, y))
    //    select(C0&C1, x, y) <=> select(C0, select(C1, x, y), y)
    // The target can specify its preferred form with the
    // shouldNormalizeToSelectSequence() callback. However, we always transform
    // to the right-hand form if the inner select already exists in the DAG,
    // and we always transform to the left-hand form if we know that we can
    // further optimize the combination of the conditions.
    bool normalizeToSequence =
        TLI.shouldNormalizeToSelectSequence(*DAG.getContext(), VT);
    // select (and Cond0, Cond1), X, Y
    //   -> select Cond0, (select Cond1, X, Y), Y
    if (N0->getOpcode() == ISD::AND && N0->hasOneUse()) {
      SDValue Cond0 = N0->getOperand(0);
      SDValue Cond1 = N0->getOperand(1);
      SDValue InnerSelect =
          DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond1, N1, N2, Flags);
      if (normalizeToSequence || !InnerSelect.use_empty())
        return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0,
                           InnerSelect, N2, Flags);
      // Cleanup on failure.
      if (InnerSelect.use_empty())
        recursivelyDeleteUnusedNodes(InnerSelect.getNode());
    }
    // select (or Cond0, Cond1), X, Y -> select Cond0, X, (select Cond1, X, Y)
    if (N0->getOpcode() == ISD::OR && N0->hasOneUse()) {
      SDValue Cond0 = N0->getOperand(0);
      SDValue Cond1 = N0->getOperand(1);
      SDValue InnerSelect = DAG.getNode(ISD::SELECT, DL, N1.getValueType(),
                                        Cond1, N1, N2, Flags);
      if (normalizeToSequence || !InnerSelect.use_empty())
        return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0, N1,
                           InnerSelect, Flags);
      // Cleanup on failure.
      if (InnerSelect.use_empty())
        recursivelyDeleteUnusedNodes(InnerSelect.getNode());
    }

    // select Cond0, (select Cond1, X, Y), Y -> select (and Cond0, Cond1), X, Y
    if (N1->getOpcode() == ISD::SELECT && N1->hasOneUse()) {
      SDValue N1_0 = N1->getOperand(0);
      SDValue N1_1 = N1->getOperand(1);
      SDValue N1_2 = N1->getOperand(2);
      if (N1_2 == N2 && N0.getValueType() == N1_0.getValueType()) {
        // Create the actual and node if we can generate good code for it.
        if (!normalizeToSequence) {
          SDValue And = DAG.getNode(ISD::AND, DL, N0.getValueType(), N0, N1_0);
          return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), And, N1_1,
                             N2, Flags);
        }
        // Otherwise see if we can optimize the "and" to a better pattern.
        if (SDValue Combined = visitANDLike(N0, N1_0, N)) {
          return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1_1,
                             N2, Flags);
        }
      }
    }
    // select Cond0, X, (select Cond1, X, Y) -> select (or Cond0, Cond1), X, Y
    if (N2->getOpcode() == ISD::SELECT && N2->hasOneUse()) {
      SDValue N2_0 = N2->getOperand(0);
      SDValue N2_1 = N2->getOperand(1);
      SDValue N2_2 = N2->getOperand(2);
      if (N2_1 == N1 && N0.getValueType() == N2_0.getValueType()) {
        // Create the actual or node if we can generate good code for it.
        if (!normalizeToSequence) {
          SDValue Or = DAG.getNode(ISD::OR, DL, N0.getValueType(), N0, N2_0);
          return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Or, N1,
                             N2_2, Flags);
        }
        // Otherwise see if we can optimize to a better pattern.
        if (SDValue Combined = visitORLike(N0, N2_0, N))
          return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1,
                             N2_2, Flags);
      }
    }
  }

  // select (not Cond), N1, N2 -> select Cond, N2, N1
  if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false)) {
    SDValue SelectOp = DAG.getSelect(DL, VT, F, N2, N1);
    SelectOp->setFlags(Flags);
    return SelectOp;
  }

  // Fold selects based on a setcc into other things, such as min/max/abs.
  if (N0.getOpcode() == ISD::SETCC) {
    SDValue Cond0 = N0.getOperand(0), Cond1 = N0.getOperand(1);
    ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();

    // select (fcmp lt x, y), x, y -> fminnum x, y
    // select (fcmp gt x, y), x, y -> fmaxnum x, y
    //
    // This is OK if we don't care what happens if either operand is a NaN.
    if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, N1, N2, TLI))
      if (SDValue FMinMax = combineMinNumMaxNum(DL, VT, Cond0, Cond1, N1, N2,
                                                CC, TLI, DAG))
        return FMinMax;

    // Use 'unsigned add with overflow' to optimize an unsigned saturating add.
    // This is conservatively limited to pre-legal-operations to give targets
    // a chance to reverse the transform if they want to do that. Also, it is
    // unlikely that the pattern would be formed late, so it's probably not
    // worth going through the other checks.
    if (!LegalOperations && TLI.isOperationLegalOrCustom(ISD::UADDO, VT) &&
        CC == ISD::SETUGT && N0.hasOneUse() && isAllOnesConstant(N1) &&
        N2.getOpcode() == ISD::ADD && Cond0 == N2.getOperand(0)) {
      auto *C = dyn_cast<ConstantSDNode>(N2.getOperand(1));
      auto *NotC = dyn_cast<ConstantSDNode>(Cond1);
      if (C && NotC && C->getAPIntValue() == ~NotC->getAPIntValue()) {
        // select (setcc Cond0, ~C, ugt), -1, (add Cond0, C) -->
        // uaddo Cond0, C; select uaddo.1, -1, uaddo.0
        //
        // The IR equivalent of this transform would have this form:
        //   %a = add %x, C
        //   %c = icmp ugt %x, ~C
        //   %r = select %c, -1, %a
        //   =>
        //   %u = call {iN,i1} llvm.uadd.with.overflow(%x, C)
        //   %u0 = extractvalue %u, 0
        //   %u1 = extractvalue %u, 1
        //   %r = select %u1, -1, %u0
        SDVTList VTs = DAG.getVTList(VT, VT0);
        SDValue UAO = DAG.getNode(ISD::UADDO, DL, VTs, Cond0, N2.getOperand(1));
        return DAG.getSelect(DL, VT, UAO.getValue(1), N1, UAO.getValue(0));
      }
    }

    if (TLI.isOperationLegal(ISD::SELECT_CC, VT) ||
        (!LegalOperations &&
         TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT))) {
      // Any flags available in a select/setcc fold will be on the setcc as they
      // migrated from fcmp.
      Flags = N0.getNode()->getFlags();
      SDValue SelectNode = DAG.getNode(ISD::SELECT_CC, DL, VT, Cond0, Cond1, N1,
                                       N2, N0.getOperand(2));
      SelectNode->setFlags(Flags);
      return SelectNode;
    }

    return SimplifySelect(DL, N0, N1, N2);
  }

  return SDValue();
}

// This function assumes all the vselect's arguments are CONCAT_VECTOR
// nodes and that the condition is a BV of ConstantSDNodes (or undefs).
static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
  SDLoc DL(N);
  SDValue Cond = N->getOperand(0);
  SDValue LHS = N->getOperand(1);
  SDValue RHS = N->getOperand(2);
  EVT VT = N->getValueType(0);
  int NumElems = VT.getVectorNumElements();
  assert(LHS.getOpcode() == ISD::CONCAT_VECTORS &&
         RHS.getOpcode() == ISD::CONCAT_VECTORS &&
         Cond.getOpcode() == ISD::BUILD_VECTOR);

  // CONCAT_VECTOR can take an arbitrary number of arguments. We only care about
  // binary ones here.
  if (LHS->getNumOperands() != 2 || RHS->getNumOperands() != 2)
    return SDValue();

  // We're sure we have an even number of elements due to the
  // concat_vectors we have as arguments to vselect.
  // Skip BV elements until we find one that's not an UNDEF.
  // Then keep looping until we get to half the length of the BV, checking
  // that all the non-undef elements are the same.
  ConstantSDNode *BottomHalf = nullptr;
  for (int i = 0; i < NumElems / 2; ++i) {
    if (Cond->getOperand(i)->isUndef())
      continue;

    if (BottomHalf == nullptr)
      BottomHalf = cast<ConstantSDNode>(Cond.getOperand(i));
    else if (Cond->getOperand(i).getNode() != BottomHalf)
      return SDValue();
  }

  // Do the same for the second half of the BuildVector.
  ConstantSDNode *TopHalf = nullptr;
  for (int i = NumElems / 2; i < NumElems; ++i) {
    if (Cond->getOperand(i)->isUndef())
      continue;

    if (TopHalf == nullptr)
      TopHalf = cast<ConstantSDNode>(Cond.getOperand(i));
    else if (Cond->getOperand(i).getNode() != TopHalf)
      return SDValue();
  }

  assert(TopHalf && BottomHalf &&
         "One half of the selector was all UNDEFs and the other was all the "
         "same value. This should have been addressed before this function.");
  return DAG.getNode(
      ISD::CONCAT_VECTORS, DL, VT,
      BottomHalf->isNullValue() ? RHS->getOperand(0) : LHS->getOperand(0),
      TopHalf->isNullValue() ? RHS->getOperand(1) : LHS->getOperand(1));
}

SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
  MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(N);
  SDValue Mask = MSC->getMask();
  SDValue Chain = MSC->getChain();
  SDLoc DL(N);

  // Zap scatters with a zero mask.
  if (ISD::isBuildVectorAllZeros(Mask.getNode()))
    return Chain;

  return SDValue();
}

SDValue DAGCombiner::visitMSTORE(SDNode *N) {
  MaskedStoreSDNode *MST = cast<MaskedStoreSDNode>(N);
  SDValue Mask = MST->getMask();
  SDValue Chain = MST->getChain();
  SDLoc DL(N);

  // Zap masked stores with a zero mask.
  if (ISD::isBuildVectorAllZeros(Mask.getNode()))
    return Chain;

  // Try transforming N to an indexed store.
  if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
    return SDValue(N, 0);

  return SDValue();
}

SDValue DAGCombiner::visitMGATHER(SDNode *N) {
  MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(N);
  SDValue Mask = MGT->getMask();
  SDLoc DL(N);

  // Zap gathers with a zero mask.
  if (ISD::isBuildVectorAllZeros(Mask.getNode()))
    return CombineTo(N, MGT->getPassThru(), MGT->getChain());

  return SDValue();
}

SDValue DAGCombiner::visitMLOAD(SDNode *N) {
  MaskedLoadSDNode *MLD = cast<MaskedLoadSDNode>(N);
  SDValue Mask = MLD->getMask();
  SDLoc DL(N);

  // Zap masked loads with a zero mask.
  if (ISD::isBuildVectorAllZeros(Mask.getNode()))
    return CombineTo(N, MLD->getPassThru(), MLD->getChain());

  // Try transforming N to an indexed load.
  if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
    return SDValue(N, 0);

  return SDValue();
}

/// A vector select of 2 constant vectors can be simplified to math/logic to
/// avoid a variable select instruction and possibly avoid constant loads.
SDValue DAGCombiner::foldVSelectOfConstants(SDNode *N) {
  SDValue Cond = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue N2 = N->getOperand(2);
  EVT VT = N->getValueType(0);
  if (!Cond.hasOneUse() || Cond.getScalarValueSizeInBits() != 1 ||
      !TLI.convertSelectOfConstantsToMath(VT) ||
      !ISD::isBuildVectorOfConstantSDNodes(N1.getNode()) ||
      !ISD::isBuildVectorOfConstantSDNodes(N2.getNode()))
    return SDValue();

  // Check if we can use the condition value to increment/decrement a single
  // constant value. This simplifies a select to an add and removes a constant
  // load/materialization from the general case.
  bool AllAddOne = true;
  bool AllSubOne = true;
  unsigned Elts = VT.getVectorNumElements();
  for (unsigned i = 0; i != Elts; ++i) {
    SDValue N1Elt = N1.getOperand(i);
    SDValue N2Elt = N2.getOperand(i);
    if (N1Elt.isUndef() || N2Elt.isUndef())
      continue;

    const APInt &C1 = cast<ConstantSDNode>(N1Elt)->getAPIntValue();
    const APInt &C2 = cast<ConstantSDNode>(N2Elt)->getAPIntValue();
    if (C1 != C2 + 1)
      AllAddOne = false;
    if (C1 != C2 - 1)
      AllSubOne = false;
  }

  // Further simplifications for the extra-special cases where the constants are
  // all 0 or all -1 should be implemented as folds of these patterns.
  SDLoc DL(N);
  if (AllAddOne || AllSubOne) {
    // vselect <N x i1> Cond, C+1, C --> add (zext Cond), C
    // vselect <N x i1> Cond, C-1, C --> add (sext Cond), C
    auto ExtendOpcode = AllAddOne ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND;
    SDValue ExtendedCond = DAG.getNode(ExtendOpcode, DL, VT, Cond);
    return DAG.getNode(ISD::ADD, DL, VT, ExtendedCond, N2);
  }

  // select Cond, Pow2C, 0 --> (zext Cond) << log2(Pow2C)
  APInt Pow2C;
  if (ISD::isConstantSplatVector(N1.getNode(), Pow2C) && Pow2C.isPowerOf2() &&
      isNullOrNullSplat(N2)) {
    SDValue ZextCond = DAG.getZExtOrTrunc(Cond, DL, VT);
    SDValue ShAmtC = DAG.getConstant(Pow2C.exactLogBase2(), DL, VT);
    return DAG.getNode(ISD::SHL, DL, VT, ZextCond, ShAmtC);
  }

  if (SDValue V = foldSelectOfConstantsUsingSra(N, DAG))
    return V;

  // The general case for select-of-constants:
  // vselect <N x i1> Cond, C1, C2 --> xor (and (sext Cond), (C1^C2)), C2
  // ...but that only makes sense if a vselect is slower than 2 logic ops, so
  // leave that to a machine-specific pass.
  return SDValue();
}

SDValue DAGCombiner::visitVSELECT(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue N2 = N->getOperand(2);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  if (SDValue V = DAG.simplifySelect(N0, N1, N2))
    return V;

  // vselect (not Cond), N1, N2 -> vselect Cond, N2, N1
  if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false))
    return DAG.getSelect(DL, VT, F, N2, N1);

  // Canonicalize integer abs.
  // vselect (setg[te] X,  0),  X, -X ->
  // vselect (setgt    X, -1),  X, -X ->
  // vselect (setl[te] X,  0), -X,  X ->
  // Y = sra (X, size(X)-1); xor (add (X, Y), Y)
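  // For example, with X == -5 on i32: Y = (X >>s 31) == -1, X + Y == -6, and
  // -6 ^ -1 == 5; for non-negative X, Y == 0 and the value is unchanged.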
  if (N0.getOpcode() == ISD::SETCC) {
    SDValue LHS = N0.getOperand(0), RHS = N0.getOperand(1);
    ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
    bool isAbs = false;
    bool RHSIsAllZeros = ISD::isBuildVectorAllZeros(RHS.getNode());

    if (((RHSIsAllZeros && (CC == ISD::SETGT || CC == ISD::SETGE)) ||
         (ISD::isBuildVectorAllOnes(RHS.getNode()) && CC == ISD::SETGT)) &&
        N1 == LHS && N2.getOpcode() == ISD::SUB && N1 == N2.getOperand(1))
      isAbs = ISD::isBuildVectorAllZeros(N2.getOperand(0).getNode());
    else if ((RHSIsAllZeros && (CC == ISD::SETLT || CC == ISD::SETLE)) &&
             N2 == LHS && N1.getOpcode() == ISD::SUB && N2 == N1.getOperand(1))
      isAbs = ISD::isBuildVectorAllZeros(N1.getOperand(0).getNode());

    if (isAbs) {
      if (TLI.isOperationLegalOrCustom(ISD::ABS, VT))
        return DAG.getNode(ISD::ABS, DL, VT, LHS);

      SDValue Shift = DAG.getNode(ISD::SRA, DL, VT, LHS,
                                  DAG.getConstant(VT.getScalarSizeInBits() - 1,
                                                  DL, getShiftAmountTy(VT)));
      SDValue Add = DAG.getNode(ISD::ADD, DL, VT, LHS, Shift);
      AddToWorklist(Shift.getNode());
      AddToWorklist(Add.getNode());
      return DAG.getNode(ISD::XOR, DL, VT, Add, Shift);
    }

    // vselect x, y (fcmp lt x, y) -> fminnum x, y
    // vselect x, y (fcmp gt x, y) -> fmaxnum x, y
    //
    // This is OK if we don't care about what happens if either operand is a
    // NaN.
    //
    if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, LHS, RHS, TLI)) {
      if (SDValue FMinMax =
              combineMinNumMaxNum(DL, VT, LHS, RHS, N1, N2, CC, TLI, DAG))
        return FMinMax;
    }

    // If this select has a condition (setcc) with narrower operands than the
    // select, try to widen the compare to match the select width.
    // TODO: This should be extended to handle any constant.
    // TODO: This could be extended to handle non-loading patterns, but that
    //       requires thorough testing to avoid regressions.
    if (isNullOrNullSplat(RHS)) {
      EVT NarrowVT = LHS.getValueType();
      EVT WideVT = N1.getValueType().changeVectorElementTypeToInteger();
      EVT SetCCVT = getSetCCResultType(LHS.getValueType());
      unsigned SetCCWidth = SetCCVT.getScalarSizeInBits();
      unsigned WideWidth = WideVT.getScalarSizeInBits();
      bool IsSigned = isSignedIntSetCC(CC);
      auto LoadExtOpcode = IsSigned ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
      if (LHS.getOpcode() == ISD::LOAD && LHS.hasOneUse() &&
          SetCCWidth != 1 && SetCCWidth < WideWidth &&
          TLI.isLoadExtLegalOrCustom(LoadExtOpcode, WideVT, NarrowVT) &&
          TLI.isOperationLegalOrCustom(ISD::SETCC, WideVT)) {
        // Both compare operands can be widened for free. The LHS can use an
        // extended load, and the RHS is a constant:
        //   vselect (ext (setcc load(X), C)), N1, N2 -->
        //   vselect (setcc extload(X), C'), N1, N2
        auto ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
        SDValue WideLHS = DAG.getNode(ExtOpcode, DL, WideVT, LHS);
        SDValue WideRHS = DAG.getNode(ExtOpcode, DL, WideVT, RHS);
        EVT WideSetCCVT = getSetCCResultType(WideVT);
        SDValue WideSetCC = DAG.getSetCC(DL, WideSetCCVT, WideLHS, WideRHS, CC);
        return DAG.getSelect(DL, N1.getValueType(), WideSetCC, N1, N2);
      }
    }
  }

  if (SimplifySelectOps(N, N1, N2))
    return SDValue(N, 0);  // Don't revisit N.

  // Fold (vselect (build_vector all_ones), N1, N2) -> N1
  if (ISD::isBuildVectorAllOnes(N0.getNode()))
    return N1;
  // Fold (vselect (build_vector all_zeros), N1, N2) -> N2
  if (ISD::isBuildVectorAllZeros(N0.getNode()))
    return N2;

  // The ConvertSelectToConcatVector function assumes both the above
  // checks for (vselect (build_vector all{ones,zeros}) ...) have been made
  // and addressed.
  if (N1.getOpcode() == ISD::CONCAT_VECTORS &&
      N2.getOpcode() == ISD::CONCAT_VECTORS &&
      ISD::isBuildVectorOfConstantSDNodes(N0.getNode())) {
    if (SDValue CV = ConvertSelectToConcatVector(N, DAG))
      return CV;
  }

  if (SDValue V = foldVSelectOfConstants(N))
    return V;

  return SDValue();
}

SDValue DAGCombiner::visitSELECT_CC(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue N2 = N->getOperand(2);
  SDValue N3 = N->getOperand(3);
  SDValue N4 = N->getOperand(4);
  ISD::CondCode CC = cast<CondCodeSDNode>(N4)->get();

  // fold select_cc lhs, rhs, x, x, cc -> x
  if (N2 == N3)
    return N2;

  // Determine if the condition we're dealing with is constant.
  if (SDValue SCC = SimplifySetCC(getSetCCResultType(N0.getValueType()), N0, N1,
                                  CC, SDLoc(N), false)) {
    AddToWorklist(SCC.getNode());

    if (ConstantSDNode *SCCC = dyn_cast<ConstantSDNode>(SCC.getNode())) {
      if (!SCCC->isNullValue())
        return N2;    // cond always true -> true val
      else
        return N3;    // cond always false -> false val
    } else if (SCC->isUndef()) {
      // When the condition is UNDEF, just return the first operand. This is
      // consistent with DAG creation: no setcc node is created in this case.
      return N2;
    } else if (SCC.getOpcode() == ISD::SETCC) {
      // Fold to a simpler select_cc.
      SDValue SelectOp = DAG.getNode(
          ISD::SELECT_CC, SDLoc(N), N2.getValueType(), SCC.getOperand(0),
          SCC.getOperand(1), N2, N3, SCC.getOperand(2));
      SelectOp->setFlags(SCC->getFlags());
      return SelectOp;
    }
  }

  // If we can fold this based on the true/false value, do so.
  if (SimplifySelectOps(N, N2, N3))
    return SDValue(N, 0);  // Don't revisit N.

  // fold select_cc into other things, such as min/max/abs
  return SimplifySelectCC(SDLoc(N), N0, N1, N2, N3, CC);
}
9025 
visitSETCC(SDNode * N)9026 SDValue DAGCombiner::visitSETCC(SDNode *N) {
9027   // setcc is very commonly used as an argument to brcond. This pattern
9028   // also lend itself to numerous combines and, as a result, it is desired
9029   // we keep the argument to a brcond as a setcc as much as possible.
9030   bool PreferSetCC =
9031       N->hasOneUse() && N->use_begin()->getOpcode() == ISD::BRCOND;
9032 
9033   SDValue Combined = SimplifySetCC(
9034       N->getValueType(0), N->getOperand(0), N->getOperand(1),
9035       cast<CondCodeSDNode>(N->getOperand(2))->get(), SDLoc(N), !PreferSetCC);
9036 
9037   if (!Combined)
9038     return SDValue();
9039 
9040   // If we prefer to have a setcc, and we don't, we'll try our best to
9041   // recreate one using rebuildSetCC.
9042   if (PreferSetCC && Combined.getOpcode() != ISD::SETCC) {
9043     SDValue NewSetCC = rebuildSetCC(Combined);
9044 
9045     // We don't have anything interesting to combine to.
9046     if (NewSetCC.getNode() == N)
9047       return SDValue();
9048 
9049     if (NewSetCC)
9050       return NewSetCC;
9051   }
9052 
9053   return Combined;
9054 }
9055 
9056 SDValue DAGCombiner::visitSETCCCARRY(SDNode *N) {
9057   SDValue LHS = N->getOperand(0);
9058   SDValue RHS = N->getOperand(1);
9059   SDValue Carry = N->getOperand(2);
9060   SDValue Cond = N->getOperand(3);
9061 
9062   // If Carry is false, fold to a regular SETCC.
9063   if (isNullConstant(Carry))
9064     return DAG.getNode(ISD::SETCC, SDLoc(N), N->getVTList(), LHS, RHS, Cond);
9065 
9066   return SDValue();
9067 }
9068 
9069 /// Try to fold a sext/zext/aext dag node into a ConstantSDNode or
9070 /// a build_vector of constants.
9071 /// This function is called by the DAGCombiner when visiting sext/zext/aext
9072 /// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
9073 /// Vector extends are not folded if operations are legal; this is to
9074 /// avoid introducing illegal build_vector dag nodes.
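/// For illustration (a sketch, not part of the original comment):
///   (i32 (zext Constant:i16<5>))                   --> Constant:i32<5>
///   (v4i32 (sext (v4i16 build_vector <c0,..,c3>))) --> (v4i32 build_vector <sext(c0),..,sext(c3)>)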
9075 static SDValue tryToFoldExtendOfConstant(SDNode *N, const TargetLowering &TLI,
9076                                          SelectionDAG &DAG, bool LegalTypes) {
9077   unsigned Opcode = N->getOpcode();
9078   SDValue N0 = N->getOperand(0);
9079   EVT VT = N->getValueType(0);
9080   SDLoc DL(N);
9081 
9082   assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
9083          Opcode == ISD::ANY_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG ||
9084          Opcode == ISD::ZERO_EXTEND_VECTOR_INREG)
9085          && "Expected EXTEND dag node in input!");
9086 
9087   // fold (sext c1) -> c1
9088   // fold (zext c1) -> c1
9089   // fold (aext c1) -> c1
9090   if (isa<ConstantSDNode>(N0))
9091     return DAG.getNode(Opcode, DL, VT, N0);
9092 
9093   // fold (sext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
9094   // fold (zext (select cond, c1, c2)) -> (select cond, zext c1, zext c2)
9095   // fold (aext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
9096   if (N0->getOpcode() == ISD::SELECT) {
9097     SDValue Op1 = N0->getOperand(1);
9098     SDValue Op2 = N0->getOperand(2);
9099     if (isa<ConstantSDNode>(Op1) && isa<ConstantSDNode>(Op2) &&
9100         (Opcode != ISD::ZERO_EXTEND || !TLI.isZExtFree(N0.getValueType(), VT))) {
9101       // For any_extend, choose sign extension of the constants to allow a
9102       // possible further transform to sign_extend_inreg, i.e.:
9103       //
9104       // t1: i8 = select t0, Constant:i8<-1>, Constant:i8<0>
9105       // t2: i64 = any_extend t1
9106       // -->
9107       // t3: i64 = select t0, Constant:i64<-1>, Constant:i64<0>
9108       // -->
9109       // t4: i64 = sign_extend_inreg t3
9110       unsigned FoldOpc = Opcode;
9111       if (FoldOpc == ISD::ANY_EXTEND)
9112         FoldOpc = ISD::SIGN_EXTEND;
9113       return DAG.getSelect(DL, VT, N0->getOperand(0),
9114                            DAG.getNode(FoldOpc, DL, VT, Op1),
9115                            DAG.getNode(FoldOpc, DL, VT, Op2));
9116     }
9117   }
9118 
9119   // fold (sext (build_vector AllConstants)) -> (build_vector AllConstants)
9120   // fold (zext (build_vector AllConstants)) -> (build_vector AllConstants)
9121   // fold (aext (build_vector AllConstants)) -> (build_vector AllConstants)
9122   EVT SVT = VT.getScalarType();
9123   if (!(VT.isVector() && (!LegalTypes || TLI.isTypeLegal(SVT)) &&
9124       ISD::isBuildVectorOfConstantSDNodes(N0.getNode())))
9125     return SDValue();
9126 
9127   // We can fold this node into a build_vector.
9128   unsigned VTBits = SVT.getSizeInBits();
9129   unsigned EVTBits = N0->getValueType(0).getScalarSizeInBits();
9130   SmallVector<SDValue, 8> Elts;
9131   unsigned NumElts = VT.getVectorNumElements();
9132 
9133   // For zero-extensions, UNDEF elements are still guaranteed to have their
9134   // upper bits set to zero.
9135   bool IsZext =
9136       Opcode == ISD::ZERO_EXTEND || Opcode == ISD::ZERO_EXTEND_VECTOR_INREG;
9137 
9138   for (unsigned i = 0; i != NumElts; ++i) {
9139     SDValue Op = N0.getOperand(i);
9140     if (Op.isUndef()) {
9141       Elts.push_back(IsZext ? DAG.getConstant(0, DL, SVT) : DAG.getUNDEF(SVT));
9142       continue;
9143     }
9144 
9145     SDLoc DL(Op);
9146     // Get the constant value and, if needed, truncate it to the size of the type.
9147     // Nodes like build_vector might have constants wider than the scalar type.
9148     APInt C = cast<ConstantSDNode>(Op)->getAPIntValue().zextOrTrunc(EVTBits);
9149     if (Opcode == ISD::SIGN_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG)
9150       Elts.push_back(DAG.getConstant(C.sext(VTBits), DL, SVT));
9151     else
9152       Elts.push_back(DAG.getConstant(C.zext(VTBits), DL, SVT));
9153   }
9154 
9155   return DAG.getBuildVector(VT, DL, Elts);
9156 }
9157 
9158 // ExtendUsesToFormExtLoad - Try to extend the uses of a load to enable this:
9159 // "fold ({s|z|a}ext (load x)) -> ({s|z|a}ext (truncate ({s|z|a}extload x)))"
9160 // transformation. Returns true if the extensions are possible and the
9161 // above-mentioned transformation is profitable.
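// Illustrative sketch (an assumption, not from the original comment): given
//   t0 = load x;  t1 = setcc t0, C, cc;  t2 = add t0, t3
// the setcc use can be rewritten against the extended load value, while the
// add is only tolerated when truncating the extended value back to the
// original type is free.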
9162 static bool ExtendUsesToFormExtLoad(EVT VT, SDNode *N, SDValue N0,
9163                                     unsigned ExtOpc,
9164                                     SmallVectorImpl<SDNode *> &ExtendNodes,
9165                                     const TargetLowering &TLI) {
9166   bool HasCopyToRegUses = false;
9167   bool isTruncFree = TLI.isTruncateFree(VT, N0.getValueType());
9168   for (SDNode::use_iterator UI = N0.getNode()->use_begin(),
9169                             UE = N0.getNode()->use_end();
9170        UI != UE; ++UI) {
9171     SDNode *User = *UI;
9172     if (User == N)
9173       continue;
9174     if (UI.getUse().getResNo() != N0.getResNo())
9175       continue;
9176     // FIXME: Only extend SETCC N, N and SETCC N, c for now.
9177     if (ExtOpc != ISD::ANY_EXTEND && User->getOpcode() == ISD::SETCC) {
9178       ISD::CondCode CC = cast<CondCodeSDNode>(User->getOperand(2))->get();
9179       if (ExtOpc == ISD::ZERO_EXTEND && ISD::isSignedIntSetCC(CC))
9180         // Sign bits will be lost after a zext.
9181         return false;
9182       bool Add = false;
9183       for (unsigned i = 0; i != 2; ++i) {
9184         SDValue UseOp = User->getOperand(i);
9185         if (UseOp == N0)
9186           continue;
9187         if (!isa<ConstantSDNode>(UseOp))
9188           return false;
9189         Add = true;
9190       }
9191       if (Add)
9192         ExtendNodes.push_back(User);
9193       continue;
9194     }
9195     // If truncates aren't free and there are users we can't
9196     // extend, it isn't worthwhile.
9197     if (!isTruncFree)
9198       return false;
9199     // Remember if this value is live-out.
9200     if (User->getOpcode() == ISD::CopyToReg)
9201       HasCopyToRegUses = true;
9202   }
9203 
9204   if (HasCopyToRegUses) {
9205     bool BothLiveOut = false;
9206     for (SDNode::use_iterator UI = N->use_begin(), UE = N->use_end();
9207          UI != UE; ++UI) {
9208       SDUse &Use = UI.getUse();
9209       if (Use.getResNo() == 0 && Use.getUser()->getOpcode() == ISD::CopyToReg) {
9210         BothLiveOut = true;
9211         break;
9212       }
9213     }
9214     if (BothLiveOut)
9215       // Both unextended and extended values are live out. There had better be
9216       // a good reason for the transformation.
9217       return ExtendNodes.size();
9218   }
9219   return true;
9220 }
9221 
9222 void DAGCombiner::ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
9223                                   SDValue OrigLoad, SDValue ExtLoad,
9224                                   ISD::NodeType ExtType) {
9225   // Extend SetCC uses if necessary.
9226   SDLoc DL(ExtLoad);
9227   for (SDNode *SetCC : SetCCs) {
9228     SmallVector<SDValue, 4> Ops;
9229 
9230     for (unsigned j = 0; j != 2; ++j) {
9231       SDValue SOp = SetCC->getOperand(j);
9232       if (SOp == OrigLoad)
9233         Ops.push_back(ExtLoad);
9234       else
9235         Ops.push_back(DAG.getNode(ExtType, DL, ExtLoad->getValueType(0), SOp));
9236     }
9237 
9238     Ops.push_back(SetCC->getOperand(2));
9239     CombineTo(SetCC, DAG.getNode(ISD::SETCC, DL, SetCC->getValueType(0), Ops));
9240   }
9241 }
9242 
9243 // FIXME: Bring more similar combines here, common to sext/zext (maybe aext?).
9244 SDValue DAGCombiner::CombineExtLoad(SDNode *N) {
9245   SDValue N0 = N->getOperand(0);
9246   EVT DstVT = N->getValueType(0);
9247   EVT SrcVT = N0.getValueType();
9248 
9249   assert((N->getOpcode() == ISD::SIGN_EXTEND ||
9250           N->getOpcode() == ISD::ZERO_EXTEND) &&
9251          "Unexpected node type (not an extend)!");
9252 
9253   // fold (sext (load x)) to multiple smaller sextloads; same for zext.
9254   // For example, on a target with legal v4i32, but illegal v8i32, turn:
9255   //   (v8i32 (sext (v8i16 (load x))))
9256   // into:
9257   //   (v8i32 (concat_vectors (v4i32 (sextload x)),
9258   //                          (v4i32 (sextload (x + 16)))))
9259   // Where uses of the original load, i.e.:
9260   //   (v8i16 (load x))
9261   // are replaced with:
9262   //   (v8i16 (truncate
9263   //     (v8i32 (concat_vectors (v4i32 (sextload x)),
9264   //                            (v4i32 (sextload (x + 16)))))))
9265   //
9266   // This combine is only applicable to illegal, but splittable, vectors.
9267   // All legal types, and illegal non-vector types, are handled elsewhere.
9268   // This combine is controlled by TargetLowering::isVectorLoadExtDesirable.
9269   //
9270   if (N0->getOpcode() != ISD::LOAD)
9271     return SDValue();
9272 
9273   LoadSDNode *LN0 = cast<LoadSDNode>(N0);
9274 
9275   if (!ISD::isNON_EXTLoad(LN0) || !ISD::isUNINDEXEDLoad(LN0) ||
9276       !N0.hasOneUse() || !LN0->isSimple() ||
9277       !DstVT.isVector() || !DstVT.isPow2VectorType() ||
9278       !TLI.isVectorLoadExtDesirable(SDValue(N, 0)))
9279     return SDValue();
9280 
9281   SmallVector<SDNode *, 4> SetCCs;
9282   if (!ExtendUsesToFormExtLoad(DstVT, N, N0, N->getOpcode(), SetCCs, TLI))
9283     return SDValue();
9284 
9285   ISD::LoadExtType ExtType =
9286       N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
9287 
9288   // Try to split the vector types to get down to legal types.
9289   EVT SplitSrcVT = SrcVT;
9290   EVT SplitDstVT = DstVT;
9291   while (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT) &&
9292          SplitSrcVT.getVectorNumElements() > 1) {
9293     SplitDstVT = DAG.GetSplitDestVTs(SplitDstVT).first;
9294     SplitSrcVT = DAG.GetSplitDestVTs(SplitSrcVT).first;
9295   }
9296 
9297   if (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT))
9298     return SDValue();
9299 
9300   assert(!DstVT.isScalableVector() && "Unexpected scalable vector type");
9301 
9302   SDLoc DL(N);
9303   const unsigned NumSplits =
9304       DstVT.getVectorNumElements() / SplitDstVT.getVectorNumElements();
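  // Each split load advances the base pointer by the store size (in bytes) of
  // the split source type.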
9305   const unsigned Stride = SplitSrcVT.getStoreSize();
9306   SmallVector<SDValue, 4> Loads;
9307   SmallVector<SDValue, 4> Chains;
9308 
9309   SDValue BasePtr = LN0->getBasePtr();
9310   for (unsigned Idx = 0; Idx < NumSplits; Idx++) {
9311     const unsigned Offset = Idx * Stride;
9312     const unsigned Align = MinAlign(LN0->getAlignment(), Offset);
9313 
9314     SDValue SplitLoad = DAG.getExtLoad(
9315         ExtType, SDLoc(LN0), SplitDstVT, LN0->getChain(), BasePtr,
9316         LN0->getPointerInfo().getWithOffset(Offset), SplitSrcVT, Align,
9317         LN0->getMemOperand()->getFlags(), LN0->getAAInfo());
9318 
9319     BasePtr = DAG.getMemBasePlusOffset(BasePtr, Stride, DL);
9320 
9321     Loads.push_back(SplitLoad.getValue(0));
9322     Chains.push_back(SplitLoad.getValue(1));
9323   }
9324 
9325   SDValue NewChain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Chains);
9326   SDValue NewValue = DAG.getNode(ISD::CONCAT_VECTORS, DL, DstVT, Loads);
9327 
9328   // Simplify TF.
9329   AddToWorklist(NewChain.getNode());
9330 
9331   CombineTo(N, NewValue);
9332 
9333   // Replace uses of the original load (before extension)
9334   // with a truncate of the concatenated sextloaded vectors.
9335   SDValue Trunc =
9336       DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), NewValue);
9337   ExtendSetCCUses(SetCCs, N0, NewValue, (ISD::NodeType)N->getOpcode());
9338   CombineTo(N0.getNode(), Trunc, NewChain);
9339   return SDValue(N, 0); // Return N so it doesn't get rechecked!
9340 }
9341 
9342 // fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
9343 //      (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
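// One concrete shape of this fold (illustrative, assuming an i32 -> i64 zext):
//   (i64 (zext (and (srl (i32 (load x)), 8), 255)))
//     --> (and (srl (i64 (zextload x)), 8), 255)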
9344 SDValue DAGCombiner::CombineZExtLogicopShiftLoad(SDNode *N) {
9345   assert(N->getOpcode() == ISD::ZERO_EXTEND);
9346   EVT VT = N->getValueType(0);
9347   EVT OrigVT = N->getOperand(0).getValueType();
9348   if (TLI.isZExtFree(OrigVT, VT))
9349     return SDValue();
9350 
9351   // and/or/xor
9352   SDValue N0 = N->getOperand(0);
9353   if (!(N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR ||
9354         N0.getOpcode() == ISD::XOR) ||
9355       N0.getOperand(1).getOpcode() != ISD::Constant ||
9356       (LegalOperations && !TLI.isOperationLegal(N0.getOpcode(), VT)))
9357     return SDValue();
9358 
9359   // shl/shr
9360   SDValue N1 = N0->getOperand(0);
9361   if (!(N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) ||
9362       N1.getOperand(1).getOpcode() != ISD::Constant ||
9363       (LegalOperations && !TLI.isOperationLegal(N1.getOpcode(), VT)))
9364     return SDValue();
9365 
9366   // load
9367   if (!isa<LoadSDNode>(N1.getOperand(0)))
9368     return SDValue();
9369   LoadSDNode *Load = cast<LoadSDNode>(N1.getOperand(0));
9370   EVT MemVT = Load->getMemoryVT();
9371   if (!TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) ||
9372       Load->getExtensionType() == ISD::SEXTLOAD || Load->isIndexed())
9373     return SDValue();
9374 
9375 
9376   // If the shift op is SHL, the logic op must be AND, otherwise the result
9377   // will be wrong.
9378   if (N1.getOpcode() == ISD::SHL && N0.getOpcode() != ISD::AND)
9379     return SDValue();
9380 
9381   if (!N0.hasOneUse() || !N1.hasOneUse())
9382     return SDValue();
9383 
9384   SmallVector<SDNode*, 4> SetCCs;
9385   if (!ExtendUsesToFormExtLoad(VT, N1.getNode(), N1.getOperand(0),
9386                                ISD::ZERO_EXTEND, SetCCs, TLI))
9387     return SDValue();
9388 
9389   // Actually do the transformation.
9390   SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(Load), VT,
9391                                    Load->getChain(), Load->getBasePtr(),
9392                                    Load->getMemoryVT(), Load->getMemOperand());
9393 
9394   SDLoc DL1(N1);
9395   SDValue Shift = DAG.getNode(N1.getOpcode(), DL1, VT, ExtLoad,
9396                               N1.getOperand(1));
9397 
9398   APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue();
9399   Mask = Mask.zext(VT.getSizeInBits());
9400   SDLoc DL0(N0);
9401   SDValue And = DAG.getNode(N0.getOpcode(), DL0, VT, Shift,
9402                             DAG.getConstant(Mask, DL0, VT));
9403 
9404   ExtendSetCCUses(SetCCs, N1.getOperand(0), ExtLoad, ISD::ZERO_EXTEND);
9405   CombineTo(N, And);
9406   if (SDValue(Load, 0).hasOneUse()) {
9407     DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), ExtLoad.getValue(1));
9408   } else {
9409     SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(Load),
9410                                 Load->getValueType(0), ExtLoad);
9411     CombineTo(Load, Trunc, ExtLoad.getValue(1));
9412   }
9413 
9414   // N0 is dead at this point.
9415   recursivelyDeleteUnusedNodes(N0.getNode());
9416 
9417   return SDValue(N,0); // Return N so it doesn't get rechecked!
9418 }
9419 
9420 /// If we're narrowing or widening the result of a vector select and the final
9421 /// size is the same size as a setcc (compare) feeding the select, then try to
9422 /// apply the cast operation to the select's operands because matching vector
9423 /// sizes for a select condition and other operands should be more efficient.
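/// A sketch of a matching shape (illustrative types only):
///   (v4i32 (trunc (vselect (setcc v4i32 X, Y), v4i64 A, v4i64 B)))
///     --> (vselect (setcc v4i32 X, Y), (v4i32 (trunc A)), (v4i32 (trunc B)))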
9424 SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
9425   unsigned CastOpcode = Cast->getOpcode();
9426   assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
9427           CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
9428           CastOpcode == ISD::FP_ROUND) &&
9429          "Unexpected opcode for vector select narrowing/widening");
9430 
9431   // We only do this transform before legal ops because the pattern may be
9432   // obfuscated by target-specific operations after legalization. Do not create
9433   // an illegal select op, however, because that may be difficult to lower.
9434   EVT VT = Cast->getValueType(0);
9435   if (LegalOperations || !TLI.isOperationLegalOrCustom(ISD::VSELECT, VT))
9436     return SDValue();
9437 
9438   SDValue VSel = Cast->getOperand(0);
9439   if (VSel.getOpcode() != ISD::VSELECT || !VSel.hasOneUse() ||
9440       VSel.getOperand(0).getOpcode() != ISD::SETCC)
9441     return SDValue();
9442 
9443   // Does the setcc have the same vector size as the casted select?
9444   SDValue SetCC = VSel.getOperand(0);
9445   EVT SetCCVT = getSetCCResultType(SetCC.getOperand(0).getValueType());
9446   if (SetCCVT.getSizeInBits() != VT.getSizeInBits())
9447     return SDValue();
9448 
9449   // cast (vsel (setcc X), A, B) --> vsel (setcc X), (cast A), (cast B)
9450   SDValue A = VSel.getOperand(1);
9451   SDValue B = VSel.getOperand(2);
9452   SDValue CastA, CastB;
9453   SDLoc DL(Cast);
9454   if (CastOpcode == ISD::FP_ROUND) {
9455     // FP_ROUND (fptrunc) has an extra flag operand to pass along.
9456     CastA = DAG.getNode(CastOpcode, DL, VT, A, Cast->getOperand(1));
9457     CastB = DAG.getNode(CastOpcode, DL, VT, B, Cast->getOperand(1));
9458   } else {
9459     CastA = DAG.getNode(CastOpcode, DL, VT, A);
9460     CastB = DAG.getNode(CastOpcode, DL, VT, B);
9461   }
9462   return DAG.getNode(ISD::VSELECT, DL, VT, SetCC, CastA, CastB);
9463 }
9464 
9465 // fold ([s|z]ext ([s|z]extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
9466 // fold ([s|z]ext (     extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
9467 static SDValue tryToFoldExtOfExtload(SelectionDAG &DAG, DAGCombiner &Combiner,
9468                                      const TargetLowering &TLI, EVT VT,
9469                                      bool LegalOperations, SDNode *N,
9470                                      SDValue N0, ISD::LoadExtType ExtLoadType) {
9471   SDNode *N0Node = N0.getNode();
9472   bool isAExtLoad = (ExtLoadType == ISD::SEXTLOAD) ? ISD::isSEXTLoad(N0Node)
9473                                                    : ISD::isZEXTLoad(N0Node);
9474   if ((!isAExtLoad && !ISD::isEXTLoad(N0Node)) ||
9475       !ISD::isUNINDEXEDLoad(N0Node) || !N0.hasOneUse())
9476     return SDValue();
9477 
9478   LoadSDNode *LN0 = cast<LoadSDNode>(N0);
9479   EVT MemVT = LN0->getMemoryVT();
9480   if ((LegalOperations || !LN0->isSimple() ||
9481        VT.isVector()) &&
9482       !TLI.isLoadExtLegal(ExtLoadType, VT, MemVT))
9483     return SDValue();
9484 
9485   SDValue ExtLoad =
9486       DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(),
9487                      LN0->getBasePtr(), MemVT, LN0->getMemOperand());
9488   Combiner.CombineTo(N, ExtLoad);
9489   DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
9490   if (LN0->use_empty())
9491     Combiner.recursivelyDeleteUnusedNodes(LN0);
9492   return SDValue(N, 0); // Return N so it doesn't get rechecked!
9493 }
9494 
9495 // fold ([s|z]ext (load x)) -> ([s|z]ext (truncate ([s|z]extload x)))
9496 // Only generate vector extloads when 1) they're legal, and 2) they are
9497 // deemed desirable by the target.
9498 static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner,
9499                                   const TargetLowering &TLI, EVT VT,
9500                                   bool LegalOperations, SDNode *N, SDValue N0,
9501                                   ISD::LoadExtType ExtLoadType,
9502                                   ISD::NodeType ExtOpc) {
9503   if (!ISD::isNON_EXTLoad(N0.getNode()) ||
9504       !ISD::isUNINDEXEDLoad(N0.getNode()) ||
9505       ((LegalOperations || VT.isVector() ||
9506         !cast<LoadSDNode>(N0)->isSimple()) &&
9507        !TLI.isLoadExtLegal(ExtLoadType, VT, N0.getValueType())))
9508     return {};
9509 
9510   bool DoXform = true;
9511   SmallVector<SDNode *, 4> SetCCs;
9512   if (!N0.hasOneUse())
9513     DoXform = ExtendUsesToFormExtLoad(VT, N, N0, ExtOpc, SetCCs, TLI);
9514   if (VT.isVector())
9515     DoXform &= TLI.isVectorLoadExtDesirable(SDValue(N, 0));
9516   if (!DoXform)
9517     return {};
9518 
9519   LoadSDNode *LN0 = cast<LoadSDNode>(N0);
9520   SDValue ExtLoad = DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(),
9521                                    LN0->getBasePtr(), N0.getValueType(),
9522                                    LN0->getMemOperand());
9523   Combiner.ExtendSetCCUses(SetCCs, N0, ExtLoad, ExtOpc);
9524   // If the load value is used only by N, replace it via CombineTo N.
9525   bool NoReplaceTrunc = SDValue(LN0, 0).hasOneUse();
9526   Combiner.CombineTo(N, ExtLoad);
9527   if (NoReplaceTrunc) {
9528     DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
9529     Combiner.recursivelyDeleteUnusedNodes(LN0);
9530   } else {
9531     SDValue Trunc =
9532         DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), ExtLoad);
9533     Combiner.CombineTo(LN0, Trunc, ExtLoad.getValue(1));
9534   }
9535   return SDValue(N, 0); // Return N so it doesn't get rechecked!
9536 }
9537 
9538 static SDValue tryToFoldExtOfMaskedLoad(SelectionDAG &DAG,
9539                                         const TargetLowering &TLI, EVT VT,
9540                                         SDNode *N, SDValue N0,
9541                                         ISD::LoadExtType ExtLoadType,
9542                                         ISD::NodeType ExtOpc) {
9543   if (!N0.hasOneUse())
9544     return SDValue();
9545 
9546   MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(N0);
9547   if (!Ld || Ld->getExtensionType() != ISD::NON_EXTLOAD)
9548     return SDValue();
9549 
9550   if (!TLI.isLoadExtLegal(ExtLoadType, VT, Ld->getValueType(0)))
9551     return SDValue();
9552 
9553   if (!TLI.isVectorLoadExtDesirable(SDValue(N, 0)))
9554     return SDValue();
9555 
9556   SDLoc dl(Ld);
9557   SDValue PassThru = DAG.getNode(ExtOpc, dl, VT, Ld->getPassThru());
9558   SDValue NewLoad = DAG.getMaskedLoad(
9559       VT, dl, Ld->getChain(), Ld->getBasePtr(), Ld->getOffset(), Ld->getMask(),
9560       PassThru, Ld->getMemoryVT(), Ld->getMemOperand(), Ld->getAddressingMode(),
9561       ExtLoadType, Ld->isExpandingLoad());
9562   DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), SDValue(NewLoad.getNode(), 1));
9563   return NewLoad;
9564 }
9565 
9566 static SDValue foldExtendedSignBitTest(SDNode *N, SelectionDAG &DAG,
9567                                        bool LegalOperations) {
9568   assert((N->getOpcode() == ISD::SIGN_EXTEND ||
9569           N->getOpcode() == ISD::ZERO_EXTEND) && "Expected sext or zext");
9570 
9571   SDValue SetCC = N->getOperand(0);
9572   if (LegalOperations || SetCC.getOpcode() != ISD::SETCC ||
9573       !SetCC.hasOneUse() || SetCC.getValueType() != MVT::i1)
9574     return SDValue();
9575 
9576   SDValue X = SetCC.getOperand(0);
9577   SDValue Ones = SetCC.getOperand(1);
9578   ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get();
9579   EVT VT = N->getValueType(0);
9580   EVT XVT = X.getValueType();
9581   // setge X, C is canonicalized to setgt, so we do not need to match that
9582   // pattern. The setlt sibling is folded in SimplifySelectCC() because it does
9583   // not require the 'not' op.
9584   if (CC == ISD::SETGT && isAllOnesConstant(Ones) && VT == XVT) {
9585     // Invert and smear/shift the sign bit:
9586     // sext i1 (setgt iN X, -1) --> sra (not X), (N - 1)
9587     // zext i1 (setgt iN X, -1) --> srl (not X), (N - 1)
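    // E.g. for i8 (a sketch): setgt X, -1 is true iff X's sign bit is clear,
    // so sra (not X), 7 yields all-ones/zero and srl (not X), 7 yields 1/0.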
9588     SDLoc DL(N);
9589     unsigned ShCt = VT.getSizeInBits() - 1;
9590     const TargetLowering &TLI = DAG.getTargetLoweringInfo();
9591     if (!TLI.shouldAvoidTransformToShift(VT, ShCt)) {
9592       SDValue NotX = DAG.getNOT(DL, X, VT);
9593       SDValue ShiftAmount = DAG.getConstant(ShCt, DL, VT);
9594       auto ShiftOpcode =
9595         N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SRA : ISD::SRL;
9596       return DAG.getNode(ShiftOpcode, DL, VT, NotX, ShiftAmount);
9597     }
9598   }
9599   return SDValue();
9600 }
9601 
9602 SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
9603   SDValue N0 = N->getOperand(0);
9604   EVT VT = N->getValueType(0);
9605   SDLoc DL(N);
9606 
9607   if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
9608     return Res;
9609 
9610   // fold (sext (sext x)) -> (sext x)
9611   // fold (sext (aext x)) -> (sext x)
9612   if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND)
9613     return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N0.getOperand(0));
9614 
9615   if (N0.getOpcode() == ISD::TRUNCATE) {
9616     // fold (sext (truncate (load x))) -> (sext (smaller load x))
9617     // fold (sext (truncate (srl (load x), c))) -> (sext (smaller load (x+c/n)))
9618     if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
9619       SDNode *oye = N0.getOperand(0).getNode();
9620       if (NarrowLoad.getNode() != N0.getNode()) {
9621         CombineTo(N0.getNode(), NarrowLoad);
9622         // CombineTo deleted the truncate, if needed, but not what's under it.
9623         AddToWorklist(oye);
9624       }
9625       return SDValue(N, 0);   // Return N so it doesn't get rechecked!
9626     }
9627 
9628     // See if the value being truncated is already sign extended.  If so, just
9629     // eliminate the trunc/sext pair.
9630     SDValue Op = N0.getOperand(0);
9631     unsigned OpBits   = Op.getScalarValueSizeInBits();
9632     unsigned MidBits  = N0.getScalarValueSizeInBits();
9633     unsigned DestBits = VT.getScalarSizeInBits();
9634     unsigned NumSignBits = DAG.ComputeNumSignBits(Op);
9635 
9636     if (OpBits == DestBits) {
9637       // Op is i32, Mid is i8, and Dest is i32.  If Op has more than 24 sign
9638       // bits, it is already ready.
9639       if (NumSignBits > DestBits-MidBits)
9640         return Op;
9641     } else if (OpBits < DestBits) {
9642       // Op is i32, Mid is i8, and Dest is i64.  If Op has more than 24 sign
9643       // bits, just sext from i32.
9644       if (NumSignBits > OpBits-MidBits)
9645         return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op);
9646     } else {
9647       // Op is i64, Mid is i8, and Dest is i32.  If Op has more than 56 sign
9648       // bits, just truncate to i32.
9649       if (NumSignBits > OpBits-MidBits)
9650         return DAG.getNode(ISD::TRUNCATE, DL, VT, Op);
9651     }
9652 
9653     // fold (sext (truncate x)) -> (sextinreg x).
9654     if (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND_INREG,
9655                                                  N0.getValueType())) {
9656       if (OpBits < DestBits)
9657         Op = DAG.getNode(ISD::ANY_EXTEND, SDLoc(N0), VT, Op);
9658       else if (OpBits > DestBits)
9659         Op = DAG.getNode(ISD::TRUNCATE, SDLoc(N0), VT, Op);
9660       return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, Op,
9661                          DAG.getValueType(N0.getValueType()));
9662     }
9663   }
9664 
9665   // Try to simplify (sext (load x)).
9666   if (SDValue foldedExt =
9667           tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
9668                              ISD::SEXTLOAD, ISD::SIGN_EXTEND))
9669     return foldedExt;
9670 
9671   if (SDValue foldedExt =
9672       tryToFoldExtOfMaskedLoad(DAG, TLI, VT, N, N0, ISD::SEXTLOAD,
9673                                ISD::SIGN_EXTEND))
9674     return foldedExt;
9675 
9676   // fold (sext (load x)) to multiple smaller sextloads.
9677   // Only on illegal but splittable vectors.
9678   if (SDValue ExtLoad = CombineExtLoad(N))
9679     return ExtLoad;
9680 
9681   // Try to simplify (sext (sextload x)).
9682   if (SDValue foldedExt = tryToFoldExtOfExtload(
9683           DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::SEXTLOAD))
9684     return foldedExt;
9685 
9686   // fold (sext (and/or/xor (load x), cst)) ->
9687   //      (and/or/xor (sextload x), (sext cst))
9688   if ((N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR ||
9689        N0.getOpcode() == ISD::XOR) &&
9690       isa<LoadSDNode>(N0.getOperand(0)) &&
9691       N0.getOperand(1).getOpcode() == ISD::Constant &&
9692       (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
9693     LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
9694     EVT MemVT = LN00->getMemoryVT();
9695     if (TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, MemVT) &&
9696       LN00->getExtensionType() != ISD::ZEXTLOAD && LN00->isUnindexed()) {
9697       SmallVector<SDNode*, 4> SetCCs;
9698       bool DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
9699                                              ISD::SIGN_EXTEND, SetCCs, TLI);
9700       if (DoXform) {
9701         SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(LN00), VT,
9702                                          LN00->getChain(), LN00->getBasePtr(),
9703                                          LN00->getMemoryVT(),
9704                                          LN00->getMemOperand());
9705         APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue();
9706         Mask = Mask.sext(VT.getSizeInBits());
9707         SDValue And = DAG.getNode(N0.getOpcode(), DL, VT,
9708                                   ExtLoad, DAG.getConstant(Mask, DL, VT));
9709         ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::SIGN_EXTEND);
9710         bool NoReplaceTruncAnd = !N0.hasOneUse();
9711         bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
9712         CombineTo(N, And);
9713         // If N0 has multiple uses, change other uses as well.
9714         if (NoReplaceTruncAnd) {
9715           SDValue TruncAnd =
9716               DAG.getNode(ISD::TRUNCATE, DL, N0.getValueType(), And);
9717           CombineTo(N0.getNode(), TruncAnd);
9718         }
9719         if (NoReplaceTrunc) {
9720           DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1));
9721         } else {
9722           SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00),
9723                                       LN00->getValueType(0), ExtLoad);
9724           CombineTo(LN00, Trunc, ExtLoad.getValue(1));
9725         }
9726         return SDValue(N,0); // Return N so it doesn't get rechecked!
9727       }
9728     }
9729   }
9730 
9731   if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
9732     return V;
9733 
9734   if (N0.getOpcode() == ISD::SETCC) {
9735     SDValue N00 = N0.getOperand(0);
9736     SDValue N01 = N0.getOperand(1);
9737     ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
9738     EVT N00VT = N0.getOperand(0).getValueType();
9739 
9740     // sext(setcc) -> sext_in_reg(vsetcc) for vectors.
9741     // Only do this before legalize for now.
9742     if (VT.isVector() && !LegalOperations &&
9743         TLI.getBooleanContents(N00VT) ==
9744             TargetLowering::ZeroOrNegativeOneBooleanContent) {
9745       // On some architectures (such as SSE/NEON/etc) the SETCC result type is
9746       // of the same size as the compared operands. Only optimize sext(setcc())
9747       // if this is the case.
9748       EVT SVT = getSetCCResultType(N00VT);
9749 
9750       // If we already have the desired type, don't change it.
9751       if (SVT != N0.getValueType()) {
9752         // We know that the # elements of the results is the same as the
9753         // # elements of the compare (and the # elements of the compare result
9754         // for that matter).  Check to see that they are the same size.  If so,
9755         // we know that the element size of the sext'd result matches the
9756         // element size of the compare operands.
9757         if (VT.getSizeInBits() == SVT.getSizeInBits())
9758           return DAG.getSetCC(DL, VT, N00, N01, CC);
9759 
9760         // If the desired elements are smaller or larger than the source
9761         // elements, we can use a matching integer vector type and then
9762         // truncate/sign extend.
9763         EVT MatchingVecType = N00VT.changeVectorElementTypeToInteger();
9764         if (SVT == MatchingVecType) {
9765           SDValue VsetCC = DAG.getSetCC(DL, MatchingVecType, N00, N01, CC);
9766           return DAG.getSExtOrTrunc(VsetCC, DL, VT);
9767         }
9768       }
9769     }
9770 
9771     // sext(setcc x, y, cc) -> (select (setcc x, y, cc), T, 0)
9772     // Here, T can be 1 or -1, depending on the type of the setcc and
9773     // getBooleanContents().
9774     unsigned SetCCWidth = N0.getScalarValueSizeInBits();
9775 
9776     // To determine the "true" side of the select, we need to know the high bit
9777     // of the value returned by the setcc if it evaluates to true.
9778     // If the type of the setcc is i1, then the true case of the select is just
9779     // sext(i1 1), that is, -1.
9780     // If the type of the setcc is larger (say, i8) then the value of the high
9781     // bit depends on getBooleanContents(), so ask TLI for a real "true" value
9782     // of the appropriate width.
9783     SDValue ExtTrueVal = (SetCCWidth == 1)
9784                              ? DAG.getAllOnesConstant(DL, VT)
9785                              : DAG.getBoolConstant(true, DL, VT, N00VT);
9786     SDValue Zero = DAG.getConstant(0, DL, VT);
9787     if (SDValue SCC =
9788             SimplifySelectCC(DL, N00, N01, ExtTrueVal, Zero, CC, true))
9789       return SCC;
9790 
9791     if (!VT.isVector() && !TLI.convertSelectOfConstantsToMath(VT)) {
9792       EVT SetCCVT = getSetCCResultType(N00VT);
9793       // Don't do this transform for i1 because there's a select transform
9794       // that would reverse it.
9795       // TODO: We should not do this transform at all without a target hook
9796       // because a sext is likely cheaper than a select?
9797       if (SetCCVT.getScalarSizeInBits() != 1 &&
9798           (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, N00VT))) {
9799         SDValue SetCC = DAG.getSetCC(DL, SetCCVT, N00, N01, CC);
9800         return DAG.getSelect(DL, VT, SetCC, ExtTrueVal, Zero);
9801       }
9802     }
9803   }
9804 
9805   // fold (sext x) -> (zext x) if the sign bit is known zero.
9806   if ((!LegalOperations || TLI.isOperationLegal(ISD::ZERO_EXTEND, VT)) &&
9807       DAG.SignBitIsZero(N0))
9808     return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0);
9809 
9810   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
9811     return NewVSel;
9812 
9813   // Eliminate this sign extend by doing a negation in the destination type:
9814   // sext i32 (0 - (zext i8 X to i32)) to i64 --> 0 - (zext i8 X to i64)
9815   if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
9816       isNullOrNullSplat(N0.getOperand(0)) &&
9817       N0.getOperand(1).getOpcode() == ISD::ZERO_EXTEND &&
9818       TLI.isOperationLegalOrCustom(ISD::SUB, VT)) {
9819     SDValue Zext = DAG.getZExtOrTrunc(N0.getOperand(1).getOperand(0), DL, VT);
9820     return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), Zext);
9821   }
9822   // Eliminate this sign extend by doing a decrement in the destination type:
9823   // sext i32 ((zext i8 X to i32) + (-1)) to i64 --> (zext i8 X to i64) + (-1)
9824   if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
9825       isAllOnesOrAllOnesSplat(N0.getOperand(1)) &&
9826       N0.getOperand(0).getOpcode() == ISD::ZERO_EXTEND &&
9827       TLI.isOperationLegalOrCustom(ISD::ADD, VT)) {
9828     SDValue Zext = DAG.getZExtOrTrunc(N0.getOperand(0).getOperand(0), DL, VT);
9829     return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT));
9830   }
9831 
9832   return SDValue();
9833 }
9834 
9835 // isTruncateOf - If N is a truncate of some other value, return true and
9836 // record the value being truncated in Op and which of Op's bits are zero/one in Known.
9837 // This function computes KnownBits to avoid a duplicated call to
9838 // computeKnownBits in the caller.
9839 static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op,
9840                          KnownBits &Known) {
9841   if (N->getOpcode() == ISD::TRUNCATE) {
9842     Op = N->getOperand(0);
9843     Known = DAG.computeKnownBits(Op);
9844     return true;
9845   }
9846 
9847   if (N.getOpcode() != ISD::SETCC ||
9848       N.getValueType().getScalarType() != MVT::i1 ||
9849       cast<CondCodeSDNode>(N.getOperand(2))->get() != ISD::SETNE)
9850     return false;
9851 
9852   SDValue Op0 = N->getOperand(0);
9853   SDValue Op1 = N->getOperand(1);
9854   assert(Op0.getValueType() == Op1.getValueType());
9855 
9856   if (isNullOrNullSplat(Op0))
9857     Op = Op1;
9858   else if (isNullOrNullSplat(Op1))
9859     Op = Op0;
9860   else
9861     return false;
9862 
9863   Known = DAG.computeKnownBits(Op);
9864 
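  // The setcc (setne Op, 0) can be viewed as a truncate of Op to i1 only when
  // every bit of Op above bit 0 is known zero, i.e. Op is already 0 or 1.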
9865   return (Known.Zero | 1).isAllOnesValue();
9866 }
9867 
9868 /// Given an extending node with a pop-count operand, if the target does not
9869 /// support a pop-count in the narrow source type but does support it in the
9870 /// destination type, widen the pop-count to the destination type.
9871 static SDValue widenCtPop(SDNode *Extend, SelectionDAG &DAG) {
9872   assert((Extend->getOpcode() == ISD::ZERO_EXTEND ||
9873           Extend->getOpcode() == ISD::ANY_EXTEND) && "Expected extend op");
9874 
9875   SDValue CtPop = Extend->getOperand(0);
9876   if (CtPop.getOpcode() != ISD::CTPOP || !CtPop.hasOneUse())
9877     return SDValue();
9878 
9879   EVT VT = Extend->getValueType(0);
9880   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
9881   if (TLI.isOperationLegalOrCustom(ISD::CTPOP, CtPop.getValueType()) ||
9882       !TLI.isOperationLegalOrCustom(ISD::CTPOP, VT))
9883     return SDValue();
9884 
9885   // zext (ctpop X) --> ctpop (zext X)
9886   SDLoc DL(Extend);
9887   SDValue NewZext = DAG.getZExtOrTrunc(CtPop.getOperand(0), DL, VT);
9888   return DAG.getNode(ISD::CTPOP, DL, VT, NewZext);
9889 }
9890 
9891 SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
9892   SDValue N0 = N->getOperand(0);
9893   EVT VT = N->getValueType(0);
9894 
9895   if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
9896     return Res;
9897 
9898   // fold (zext (zext x)) -> (zext x)
9899   // fold (zext (aext x)) -> (zext x)
9900   if (N0.getOpcode() == ISD::ZERO_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND)
9901     return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT,
9902                        N0.getOperand(0));
9903 
9904   // fold (zext (truncate x)) -> (zext x) or
9905   //      (zext (truncate x)) -> (truncate x)
9906   // This is valid when the truncated bits of x are already zero.
9907   SDValue Op;
9908   KnownBits Known;
9909   if (isTruncateOf(DAG, N0, Op, Known)) {
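    // The bits truncated away are those from N0's width up to the narrower of
    // Op's width and the result width; the fold below is safe only when all of
    // them are known zero.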
9910     APInt TruncatedBits =
9911       (Op.getScalarValueSizeInBits() == N0.getScalarValueSizeInBits()) ?
9912       APInt(Op.getScalarValueSizeInBits(), 0) :
9913       APInt::getBitsSet(Op.getScalarValueSizeInBits(),
9914                         N0.getScalarValueSizeInBits(),
9915                         std::min(Op.getScalarValueSizeInBits(),
9916                                  VT.getScalarSizeInBits()));
9917     if (TruncatedBits.isSubsetOf(Known.Zero))
9918       return DAG.getZExtOrTrunc(Op, SDLoc(N), VT);
9919   }
9920 
9921   // fold (zext (truncate x)) -> (and x, mask)
9922   if (N0.getOpcode() == ISD::TRUNCATE) {
9923     // fold (zext (truncate (load x))) -> (zext (smaller load x))
9924     // fold (zext (truncate (srl (load x), c))) -> (zext (smaller load (x+c/n)))
9925     if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
9926       SDNode *oye = N0.getOperand(0).getNode();
9927       if (NarrowLoad.getNode() != N0.getNode()) {
9928         CombineTo(N0.getNode(), NarrowLoad);
9929         // CombineTo deleted the truncate, if needed, but not what's under it.
9930         AddToWorklist(oye);
9931       }
9932       return SDValue(N, 0); // Return N so it doesn't get rechecked!
9933     }
9934 
9935     EVT SrcVT = N0.getOperand(0).getValueType();
9936     EVT MinVT = N0.getValueType();
9937 
9938     // Try to mask before the extension to avoid having to generate a larger mask,
9939     // possibly over several sub-vectors.
9940     if (SrcVT.bitsLT(VT) && VT.isVector()) {
9941       if (!LegalOperations || (TLI.isOperationLegal(ISD::AND, SrcVT) &&
9942                                TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) {
9943         SDValue Op = N0.getOperand(0);
9944         Op = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT.getScalarType());
9945         AddToWorklist(Op.getNode());
9946         SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, SDLoc(N), VT);
9947         // Transfer the debug info; the new node is equivalent to N0.
9948         DAG.transferDbgValues(N0, ZExtOrTrunc);
9949         return ZExtOrTrunc;
9950       }
9951     }
9952 
9953     if (!LegalOperations || TLI.isOperationLegal(ISD::AND, VT)) {
9954       SDValue Op = DAG.getAnyExtOrTrunc(N0.getOperand(0), SDLoc(N), VT);
9955       AddToWorklist(Op.getNode());
9956       SDValue And = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT.getScalarType());
9957       // We may safely transfer the debug info describing the truncate node over
9958       // to the equivalent and operation.
9959       DAG.transferDbgValues(N0, And);
9960       return And;
9961     }
9962   }
9963 
9964   // Fold (zext (and (trunc x), cst)) -> (and x, cst),
9965   // if either of the casts is not free.
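  // E.g. (a sketch): (i64 (zext (and (i32 (trunc i64:X)), 255))) --> (and X, 255)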
9966   if (N0.getOpcode() == ISD::AND &&
9967       N0.getOperand(0).getOpcode() == ISD::TRUNCATE &&
9968       N0.getOperand(1).getOpcode() == ISD::Constant &&
9969       (!TLI.isTruncateFree(N0.getOperand(0).getOperand(0).getValueType(),
9970                            N0.getValueType()) ||
9971        !TLI.isZExtFree(N0.getValueType(), VT))) {
9972     SDValue X = N0.getOperand(0).getOperand(0);
9973     X = DAG.getAnyExtOrTrunc(X, SDLoc(X), VT);
9974     APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue();
9975     Mask = Mask.zext(VT.getSizeInBits());
9976     SDLoc DL(N);
9977     return DAG.getNode(ISD::AND, DL, VT,
9978                        X, DAG.getConstant(Mask, DL, VT));
9979   }
9980 
9981   // Try to simplify (zext (load x)).
9982   if (SDValue foldedExt =
9983           tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
9984                              ISD::ZEXTLOAD, ISD::ZERO_EXTEND))
9985     return foldedExt;
9986 
9987   if (SDValue foldedExt =
9988       tryToFoldExtOfMaskedLoad(DAG, TLI, VT, N, N0, ISD::ZEXTLOAD,
9989                                ISD::ZERO_EXTEND))
9990     return foldedExt;
9991 
9992   // fold (zext (load x)) to multiple smaller zextloads.
9993   // Only on illegal but splittable vectors.
9994   if (SDValue ExtLoad = CombineExtLoad(N))
9995     return ExtLoad;
9996 
9997   // fold (zext (and/or/xor (load x), cst)) ->
9998   //      (and/or/xor (zextload x), (zext cst))
9999   // Unless (and (load x) cst) will match as a zextload already and has
10000   // additional users.
10001   if ((N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR ||
10002        N0.getOpcode() == ISD::XOR) &&
10003       isa<LoadSDNode>(N0.getOperand(0)) &&
10004       N0.getOperand(1).getOpcode() == ISD::Constant &&
10005       (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
10006     LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
10007     EVT MemVT = LN00->getMemoryVT();
10008     if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) &&
10009         LN00->getExtensionType() != ISD::SEXTLOAD && LN00->isUnindexed()) {
10010       bool DoXform = true;
10011       SmallVector<SDNode*, 4> SetCCs;
10012       if (!N0.hasOneUse()) {
10013         if (N0.getOpcode() == ISD::AND) {
10014           auto *AndC = cast<ConstantSDNode>(N0.getOperand(1));
10015           EVT LoadResultTy = AndC->getValueType(0);
10016           EVT ExtVT;
10017           if (isAndLoadExtLoad(AndC, LN00, LoadResultTy, ExtVT))
10018             DoXform = false;
10019         }
10020       }
10021       if (DoXform)
10022         DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
10023                                           ISD::ZERO_EXTEND, SetCCs, TLI);
10024       if (DoXform) {
10025         SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(LN00), VT,
10026                                          LN00->getChain(), LN00->getBasePtr(),
10027                                          LN00->getMemoryVT(),
10028                                          LN00->getMemOperand());
10029         APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue();
10030         Mask = Mask.zext(VT.getSizeInBits());
10031         SDLoc DL(N);
10032         SDValue And = DAG.getNode(N0.getOpcode(), DL, VT,
10033                                   ExtLoad, DAG.getConstant(Mask, DL, VT));
10034         ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::ZERO_EXTEND);
10035         bool NoReplaceTruncAnd = !N0.hasOneUse();
10036         bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
10037         CombineTo(N, And);
10038         // If N0 has multiple uses, change other uses as well.
10039         if (NoReplaceTruncAnd) {
10040           SDValue TruncAnd =
10041               DAG.getNode(ISD::TRUNCATE, DL, N0.getValueType(), And);
10042           CombineTo(N0.getNode(), TruncAnd);
10043         }
10044         if (NoReplaceTrunc) {
10045           DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1));
10046         } else {
10047           SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00),
10048                                       LN00->getValueType(0), ExtLoad);
10049           CombineTo(LN00, Trunc, ExtLoad.getValue(1));
10050         }
10051         return SDValue(N,0); // Return N so it doesn't get rechecked!
10052       }
10053     }
10054   }
10055 
10056   // fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
10057   //      (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
10058   if (SDValue ZExtLoad = CombineZExtLogicopShiftLoad(N))
10059     return ZExtLoad;
10060 
10061   // Try to simplify (zext (zextload x)).
10062   if (SDValue foldedExt = tryToFoldExtOfExtload(
10063           DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::ZEXTLOAD))
10064     return foldedExt;
10065 
10066   if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
10067     return V;
10068 
10069   if (N0.getOpcode() == ISD::SETCC) {
10070     // Only do this before legalize for now.
10071     if (!LegalOperations && VT.isVector() &&
10072         N0.getValueType().getVectorElementType() == MVT::i1) {
10073       EVT N00VT = N0.getOperand(0).getValueType();
10074       if (getSetCCResultType(N00VT) == N0.getValueType())
10075         return SDValue();
10076 
10077       // We know that the # elements of the results is the same as the #
10078       // elements of the compare (and the # elements of the compare result for
10079       // that matter). Check to see that they are the same size. If so, we know
10080       // that the element size of the sext'd result matches the element size of
10081       // the compare operands.
10082       SDLoc DL(N);
10083       SDValue VecOnes = DAG.getConstant(1, DL, VT);
10084       if (VT.getSizeInBits() == N00VT.getSizeInBits()) {
10085         // zext(setcc) -> (and (vsetcc), (1, 1, ...) for vectors.
10086         SDValue VSetCC = DAG.getNode(ISD::SETCC, DL, VT, N0.getOperand(0),
10087                                      N0.getOperand(1), N0.getOperand(2));
10088         return DAG.getNode(ISD::AND, DL, VT, VSetCC, VecOnes);
10089       }
10090 
10091       // If the desired elements are smaller or larger than the source
10092       // elements we can use a matching integer vector type and then
10093       // truncate/sign extend.
10094       EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
10095       SDValue VsetCC =
10096           DAG.getNode(ISD::SETCC, DL, MatchingVectorType, N0.getOperand(0),
10097                       N0.getOperand(1), N0.getOperand(2));
10098       return DAG.getNode(ISD::AND, DL, VT, DAG.getSExtOrTrunc(VsetCC, DL, VT),
10099                          VecOnes);
10100     }
10101 
10102     // zext(setcc x,y,cc) -> select_cc x, y, 1, 0, cc
10103     SDLoc DL(N);
10104     if (SDValue SCC = SimplifySelectCC(
10105             DL, N0.getOperand(0), N0.getOperand(1), DAG.getConstant(1, DL, VT),
10106             DAG.getConstant(0, DL, VT),
10107             cast<CondCodeSDNode>(N0.getOperand(2))->get(), true))
10108       return SCC;
10109   }
10110 
10111   // (zext (shl (zext x), cst)) -> (shl (zext x), cst)
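  // E.g. (illustrative): in (i64 zext (shl (i32 zext i8 X), 3)), the inner
  // zext guarantees 24 known-zero high bits, so shifting left by 3 cannot
  // discard set bits and the shift can be done in the wide type instead.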
10112   if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
10113       isa<ConstantSDNode>(N0.getOperand(1)) &&
10114       N0.getOperand(0).getOpcode() == ISD::ZERO_EXTEND &&
10115       N0.hasOneUse()) {
10116     SDValue ShAmt = N0.getOperand(1);
10117     if (N0.getOpcode() == ISD::SHL) {
10118       SDValue InnerZExt = N0.getOperand(0);
10119       // If the original shl may be shifting out bits, do not perform this
10120       // transformation.
10121       unsigned KnownZeroBits = InnerZExt.getValueSizeInBits() -
10122         InnerZExt.getOperand(0).getValueSizeInBits();
10123       if (cast<ConstantSDNode>(ShAmt)->getAPIntValue().ugt(KnownZeroBits))
10124         return SDValue();
10125     }
10126 
10127     SDLoc DL(N);
10128 
10129     // Ensure that the shift amount is wide enough for the shifted value.
10130     if (VT.getSizeInBits() >= 256)
10131       ShAmt = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, ShAmt);
10132 
10133     return DAG.getNode(N0.getOpcode(), DL, VT,
10134                        DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0)),
10135                        ShAmt);
10136   }
10137 
10138   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
10139     return NewVSel;
10140 
10141   if (SDValue NewCtPop = widenCtPop(N, DAG))
10142     return NewCtPop;
10143 
10144   return SDValue();
10145 }
10146 
10147 SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
10148   SDValue N0 = N->getOperand(0);
10149   EVT VT = N->getValueType(0);
10150 
10151   if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
10152     return Res;
10153 
10154   // fold (aext (aext x)) -> (aext x)
10155   // fold (aext (zext x)) -> (zext x)
10156   // fold (aext (sext x)) -> (sext x)
10157   if (N0.getOpcode() == ISD::ANY_EXTEND  ||
10158       N0.getOpcode() == ISD::ZERO_EXTEND ||
10159       N0.getOpcode() == ISD::SIGN_EXTEND)
10160     return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, N0.getOperand(0));
10161 
10162   // fold (aext (truncate (load x))) -> (aext (smaller load x))
10163   // fold (aext (truncate (srl (load x), c))) -> (aext (small load (x+c/n)))
10164   if (N0.getOpcode() == ISD::TRUNCATE) {
10165     if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
10166       SDNode *oye = N0.getOperand(0).getNode();
10167       if (NarrowLoad.getNode() != N0.getNode()) {
10168         CombineTo(N0.getNode(), NarrowLoad);
10169         // CombineTo deleted the truncate, if needed, but not what's under it.
10170         AddToWorklist(oye);
10171       }
10172       return SDValue(N, 0);   // Return N so it doesn't get rechecked!
10173     }
10174   }
10175 
10176   // fold (aext (truncate x))
10177   if (N0.getOpcode() == ISD::TRUNCATE)
10178     return DAG.getAnyExtOrTrunc(N0.getOperand(0), SDLoc(N), VT);
10179 
10180   // Fold (aext (and (trunc x), cst)) -> (and x, cst)
10181   // if the trunc is not free.
10182   if (N0.getOpcode() == ISD::AND &&
10183       N0.getOperand(0).getOpcode() == ISD::TRUNCATE &&
10184       N0.getOperand(1).getOpcode() == ISD::Constant &&
10185       !TLI.isTruncateFree(N0.getOperand(0).getOperand(0).getValueType(),
10186                           N0.getValueType())) {
10187     SDLoc DL(N);
10188     SDValue X = N0.getOperand(0).getOperand(0);
10189     X = DAG.getAnyExtOrTrunc(X, DL, VT);
10190     APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue();
10191     Mask = Mask.zext(VT.getSizeInBits());
10192     return DAG.getNode(ISD::AND, DL, VT,
10193                        X, DAG.getConstant(Mask, DL, VT));
10194   }
10195 
10196   // fold (aext (load x)) -> (aext (truncate (extload x)))
10197   // None of the supported targets knows how to perform load and any_ext
10198   // on vectors in one instruction.  We only perform this transformation on
10199   // scalars.
10200   if (ISD::isNON_EXTLoad(N0.getNode()) && !VT.isVector() &&
10201       ISD::isUNINDEXEDLoad(N0.getNode()) &&
10202       TLI.isLoadExtLegal(ISD::EXTLOAD, VT, N0.getValueType())) {
10203     bool DoXform = true;
10204     SmallVector<SDNode*, 4> SetCCs;
10205     if (!N0.hasOneUse())
10206       DoXform = ExtendUsesToFormExtLoad(VT, N, N0, ISD::ANY_EXTEND, SetCCs,
10207                                         TLI);
10208     if (DoXform) {
10209       LoadSDNode *LN0 = cast<LoadSDNode>(N0);
10210       SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, SDLoc(N), VT,
10211                                        LN0->getChain(),
10212                                        LN0->getBasePtr(), N0.getValueType(),
10213                                        LN0->getMemOperand());
10214       ExtendSetCCUses(SetCCs, N0, ExtLoad, ISD::ANY_EXTEND);
10215       // If the load value is used only by N, replace it via CombineTo N.
10216       bool NoReplaceTrunc = N0.hasOneUse();
10217       CombineTo(N, ExtLoad);
10218       if (NoReplaceTrunc) {
10219         DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
10220         recursivelyDeleteUnusedNodes(LN0);
10221       } else {
10222         SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(N0),
10223                                     N0.getValueType(), ExtLoad);
10224         CombineTo(LN0, Trunc, ExtLoad.getValue(1));
10225       }
10226       return SDValue(N, 0); // Return N so it doesn't get rechecked!
10227     }
10228   }

  // fold (aext (zextload x)) -> (aext (truncate (zextload x)))
  // fold (aext (sextload x)) -> (aext (truncate (sextload x)))
  // fold (aext ( extload x)) -> (aext (truncate (extload  x)))
  if (N0.getOpcode() == ISD::LOAD && !ISD::isNON_EXTLoad(N0.getNode()) &&
      ISD::isUNINDEXEDLoad(N0.getNode()) && N0.hasOneUse()) {
    LoadSDNode *LN0 = cast<LoadSDNode>(N0);
    ISD::LoadExtType ExtType = LN0->getExtensionType();
    EVT MemVT = LN0->getMemoryVT();
    if (!LegalOperations || TLI.isLoadExtLegal(ExtType, VT, MemVT)) {
      SDValue ExtLoad = DAG.getExtLoad(ExtType, SDLoc(N),
                                       VT, LN0->getChain(), LN0->getBasePtr(),
                                       MemVT, LN0->getMemOperand());
      CombineTo(N, ExtLoad);
      DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
      recursivelyDeleteUnusedNodes(LN0);
      return SDValue(N, 0);   // Return N so it doesn't get rechecked!
    }
  }

  if (N0.getOpcode() == ISD::SETCC) {
    // For vectors:
    // aext(setcc) -> vsetcc
    // aext(setcc) -> truncate(vsetcc)
    // aext(setcc) -> aext(vsetcc)
    // Only do this before legalize for now.
    if (VT.isVector() && !LegalOperations) {
      EVT N00VT = N0.getOperand(0).getValueType();
      if (getSetCCResultType(N00VT) == N0.getValueType())
        return SDValue();

      // We know that the # elements of the result is the same as the
      // # elements of the compare (and the # elements of the compare result
      // for that matter).  Check to see that they are the same size.  If so,
      // we know that the element size of the extended result matches the
      // element size of the compare operands.
      if (VT.getSizeInBits() == N00VT.getSizeInBits())
        return DAG.getSetCC(SDLoc(N), VT, N0.getOperand(0),
                            N0.getOperand(1),
                            cast<CondCodeSDNode>(N0.getOperand(2))->get());

      // If the desired elements are smaller or larger than the source
      // elements we can use a matching integer vector type and then
      // truncate/any extend.
      EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
      SDValue VsetCC =
        DAG.getSetCC(SDLoc(N), MatchingVectorType, N0.getOperand(0),
                     N0.getOperand(1),
                     cast<CondCodeSDNode>(N0.getOperand(2))->get());
      return DAG.getAnyExtOrTrunc(VsetCC, SDLoc(N), VT);
    }

    // aext(setcc x,y,cc) -> select_cc x, y, 1, 0, cc
    SDLoc DL(N);
    if (SDValue SCC = SimplifySelectCC(
            DL, N0.getOperand(0), N0.getOperand(1), DAG.getConstant(1, DL, VT),
            DAG.getConstant(0, DL, VT),
            cast<CondCodeSDNode>(N0.getOperand(2))->get(), true))
      return SCC;
  }

  if (SDValue NewCtPop = widenCtPop(N, DAG))
    return NewCtPop;

  return SDValue();
}

SDValue DAGCombiner::visitAssertExt(SDNode *N) {
  unsigned Opcode = N->getOpcode();
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT AssertVT = cast<VTSDNode>(N1)->getVT();

  // fold (assert?ext (assert?ext x, vt), vt) -> (assert?ext x, vt)
  if (N0.getOpcode() == Opcode &&
      AssertVT == cast<VTSDNode>(N0.getOperand(1))->getVT())
    return N0;

  if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
      N0.getOperand(0).getOpcode() == Opcode) {
    // We have an assert, truncate, assert sandwich. Make one stronger assert
    // by asserting the smaller of the two asserted types on the larger source
    // value. This eliminates the later assert:
    // assert (trunc (assert X, i8) to iN), i1 --> trunc (assert X, i1) to iN
    // assert (trunc (assert X, i1) to iN), i8 --> trunc (assert X, i1) to iN
    SDValue BigA = N0.getOperand(0);
    EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT();
    assert(BigA_AssertVT.bitsLE(N0.getValueType()) &&
           "Asserting zero/sign-extended bits to a type larger than the "
           "truncated destination does not provide information");

    SDLoc DL(N);
    EVT MinAssertVT = AssertVT.bitsLT(BigA_AssertVT) ? AssertVT : BigA_AssertVT;
    SDValue MinAssertVTVal = DAG.getValueType(MinAssertVT);
    SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(),
                                    BigA.getOperand(0), MinAssertVTVal);
    return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert);
  }

  // If we have (AssertZext (truncate (AssertSext X, iX)), iY) and Y is smaller
  // than X, just move the AssertZext in front of the truncate and drop the
  // AssertSext.
  if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
      N0.getOperand(0).getOpcode() == ISD::AssertSext &&
      Opcode == ISD::AssertZext) {
    SDValue BigA = N0.getOperand(0);
    EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT();
    assert(BigA_AssertVT.bitsLE(N0.getValueType()) &&
           "Asserting zero/sign-extended bits to a type larger than the "
           "truncated destination does not provide information");

    if (AssertVT.bitsLT(BigA_AssertVT)) {
      SDLoc DL(N);
      SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(),
                                      BigA.getOperand(0), N1);
      return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert);
    }
  }
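  // Sketch with concrete types for the fold above:
  //   (AssertZext (trunc (AssertSext X, i16) to i32), i8)
  // has i8 < i16, so it becomes (trunc (AssertZext X, i8) to i32); the
  // stronger zero-assertion makes the sign-assertion redundant.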

  return SDValue();
}

/// If the result of a wider load is shifted right by N bits and then
/// truncated to a narrower type, where N is a multiple of the number of bits
/// of the narrower type, transform it to a narrower load from address +
/// N / (number of bits of the new type). Also narrow the load if the result
/// is masked with an AND to effectively produce a smaller type. If the result
/// is to be extended, also fold the extension to form an extending load.
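/// For example, on a little-endian target (a sketch with hypothetical types),
///   (i8 (trunc (srl (i32 (load p)), 16)))
/// can become an i8 load from p + 2.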
SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
  unsigned Opc = N->getOpcode();

  ISD::LoadExtType ExtType = ISD::NON_EXTLOAD;
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  EVT ExtVT = VT;

  // This transformation isn't valid for vector loads.
  if (VT.isVector())
    return SDValue();

  unsigned ShAmt = 0;
  bool HasShiftedOffset = false;
  // Special case: SIGN_EXTEND_INREG is basically truncating to ExtVT then
  // extending to VT.
  if (Opc == ISD::SIGN_EXTEND_INREG) {
    ExtType = ISD::SEXTLOAD;
    ExtVT = cast<VTSDNode>(N->getOperand(1))->getVT();
  } else if (Opc == ISD::SRL) {
    // Another special-case: SRL is basically zero-extending a narrower value,
    // or it may be shifting a higher subword, half or byte into the lowest
    // bits.
    ExtType = ISD::ZEXTLOAD;
    N0 = SDValue(N, 0);

    auto *LN0 = dyn_cast<LoadSDNode>(N0.getOperand(0));
    auto *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1));
    if (!N01 || !LN0)
      return SDValue();

    uint64_t ShiftAmt = N01->getZExtValue();
    uint64_t MemoryWidth = LN0->getMemoryVT().getSizeInBits();
    if (LN0->getExtensionType() != ISD::SEXTLOAD && MemoryWidth > ShiftAmt)
      ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShiftAmt);
    else
      ExtVT = EVT::getIntegerVT(*DAG.getContext(),
                                VT.getSizeInBits() - ShiftAmt);
  } else if (Opc == ISD::AND) {
    // An AND with a constant mask is the same as a truncate + zero-extend.
    auto AndC = dyn_cast<ConstantSDNode>(N->getOperand(1));
    if (!AndC)
      return SDValue();

    const APInt &Mask = AndC->getAPIntValue();
    unsigned ActiveBits = 0;
    if (Mask.isMask()) {
      ActiveBits = Mask.countTrailingOnes();
    } else if (Mask.isShiftedMask()) {
      ShAmt = Mask.countTrailingZeros();
      APInt ShiftedMask = Mask.lshr(ShAmt);
      ActiveBits = ShiftedMask.countTrailingOnes();
      HasShiftedOffset = true;
    } else
      return SDValue();

    ExtType = ISD::ZEXTLOAD;
    ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
  }
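  // For instance, an AND mask of 0x00FFFF00 is a shifted mask, so the block
  // above records ShAmt = 8, ActiveBits = 16, ExtVT = i16: a zextload of the
  // 16 bits sitting one byte up. HasShiftedOffset makes the tail of this
  // function shift the loaded value back into its original position.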

  if (N0.getOpcode() == ISD::SRL && N0.hasOneUse()) {
    SDValue SRL = N0;
    if (auto *ConstShift = dyn_cast<ConstantSDNode>(SRL.getOperand(1))) {
      ShAmt = ConstShift->getZExtValue();
      unsigned EVTBits = ExtVT.getSizeInBits();
      // Is the shift amount a multiple of the size of ExtVT?
      if ((ShAmt & (EVTBits-1)) == 0) {
        N0 = N0.getOperand(0);
        // Is the load width a multiple of the size of ExtVT?
        if ((N0.getValueSizeInBits() & (EVTBits-1)) != 0)
          return SDValue();
      }

      // At this point, we must have a load or else we can't do the transform.
      if (!isa<LoadSDNode>(N0)) return SDValue();

      auto *LN0 = cast<LoadSDNode>(N0);

      // Because a SRL must be assumed to *need* to zero-extend the high bits
      // (as opposed to anyext the high bits), we can't combine the zextload
      // lowering of SRL and an sextload.
      if (LN0->getExtensionType() == ISD::SEXTLOAD)
        return SDValue();

      // If the shift amount is larger than the input type then we're not
      // accessing any of the loaded bytes.  If the load was a zextload/extload
      // then the result of the shift+trunc is zero/undef (handled elsewhere).
      if (ShAmt >= LN0->getMemoryVT().getSizeInBits())
        return SDValue();

      // If the SRL is only used by a masking AND, we may be able to adjust
      // the ExtVT to make the AND redundant.
      SDNode *Mask = *(SRL->use_begin());
      if (Mask->getOpcode() == ISD::AND &&
          isa<ConstantSDNode>(Mask->getOperand(1))) {
        const APInt &ShiftMask =
          cast<ConstantSDNode>(Mask->getOperand(1))->getAPIntValue();
        if (ShiftMask.isMask()) {
          EVT MaskedVT = EVT::getIntegerVT(*DAG.getContext(),
                                           ShiftMask.countTrailingOnes());
          // If the mask is smaller, recompute the type.
          if ((ExtVT.getSizeInBits() > MaskedVT.getSizeInBits()) &&
              TLI.isLoadExtLegal(ExtType, N0.getValueType(), MaskedVT))
            ExtVT = MaskedVT;
        }
      }
    }
  }

  // If the load is shifted left (and the result isn't shifted back right),
  // we can fold the truncate through the shift.
  unsigned ShLeftAmt = 0;
  if (ShAmt == 0 && N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
      ExtVT == VT && TLI.isNarrowingProfitable(N0.getValueType(), VT)) {
    if (ConstantSDNode *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
      ShLeftAmt = N01->getZExtValue();
      N0 = N0.getOperand(0);
    }
  }

  // If we haven't found a load, we can't narrow it.
  if (!isa<LoadSDNode>(N0))
    return SDValue();

  LoadSDNode *LN0 = cast<LoadSDNode>(N0);
  // Reducing the width of a volatile load is illegal.  For atomics, we may be
  // able to reduce the width provided we never widen again. (see D66309)
  if (!LN0->isSimple() ||
      !isLegalNarrowLdSt(LN0, ExtType, ExtVT, ShAmt))
    return SDValue();

  auto AdjustBigEndianShift = [&](unsigned ShAmt) {
    unsigned LVTStoreBits = LN0->getMemoryVT().getStoreSizeInBits();
    unsigned EVTStoreBits = ExtVT.getStoreSizeInBits();
    return LVTStoreBits - EVTStoreBits - ShAmt;
  };
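  // E.g. narrowing an i32 load (32 stored bits) to an i8 (8 bits) with
  // ShAmt == 8 remaps to 32 - 8 - 8 == 16 here, since on big-endian targets
  // the same bits sit at the opposite end of the stored word.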

  // For big endian targets, we need to adjust the offset to the pointer to
  // load the correct bytes.
  if (DAG.getDataLayout().isBigEndian())
    ShAmt = AdjustBigEndianShift(ShAmt);

  uint64_t PtrOff = ShAmt / 8;
  unsigned NewAlign = MinAlign(LN0->getAlignment(), PtrOff);
  SDLoc DL(LN0);
  // The original load itself didn't wrap, so an offset within it doesn't.
  SDNodeFlags Flags;
  Flags.setNoUnsignedWrap(true);
  SDValue NewPtr =
      DAG.getMemBasePlusOffset(LN0->getBasePtr(), PtrOff, DL, Flags);
  AddToWorklist(NewPtr.getNode());

  SDValue Load;
  if (ExtType == ISD::NON_EXTLOAD)
    Load = DAG.getLoad(VT, DL, LN0->getChain(), NewPtr,
                       LN0->getPointerInfo().getWithOffset(PtrOff), NewAlign,
                       LN0->getMemOperand()->getFlags(), LN0->getAAInfo());
  else
    Load = DAG.getExtLoad(ExtType, DL, VT, LN0->getChain(), NewPtr,
                          LN0->getPointerInfo().getWithOffset(PtrOff), ExtVT,
                          NewAlign, LN0->getMemOperand()->getFlags(),
                          LN0->getAAInfo());

  // Replace the old load's chain with the new load's chain.
  WorklistRemover DeadNodes(*this);
  DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));

  // Shift the result left, if we've swallowed a left shift.
  SDValue Result = Load;
  if (ShLeftAmt != 0) {
    EVT ShImmTy = getShiftAmountTy(Result.getValueType());
    if (!isUIntN(ShImmTy.getSizeInBits(), ShLeftAmt))
      ShImmTy = VT;
    // If the shift amount is as large as the result size (but, presumably,
    // no larger than the source) then the useful bits of the result are
    // zero; we can't simply return the shortened shift, because the result
    // of that operation is undefined.
    if (ShLeftAmt >= VT.getSizeInBits())
      Result = DAG.getConstant(0, DL, VT);
    else
      Result = DAG.getNode(ISD::SHL, DL, VT,
                          Result, DAG.getConstant(ShLeftAmt, DL, ShImmTy));
  }

  if (HasShiftedOffset) {
    // Recalculate the shift amount after it has been altered to calculate
    // the offset.
    if (DAG.getDataLayout().isBigEndian())
      ShAmt = AdjustBigEndianShift(ShAmt);

    // We're using a shifted mask, so the load now has an offset. This means
    // the data has landed in lower bytes of the register than it occupied
    // before, so we need to shl the loaded value back into its correct
    // position.
    SDValue ShiftC = DAG.getConstant(ShAmt, DL, VT);
    Result = DAG.getNode(ISD::SHL, DL, VT, Result, ShiftC);
    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result);
  }

  // Return the new loaded value.
  return Result;
}

SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  EVT EVT = cast<VTSDNode>(N1)->getVT();
  unsigned VTBits = VT.getScalarSizeInBits();
  unsigned EVTBits = EVT.getScalarSizeInBits();

  if (N0.isUndef())
    return DAG.getUNDEF(VT);

  // fold (sext_in_reg c1) -> c1
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
    return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, N0, N1);

  // If the input is already sign extended, just drop the extension.
  if (DAG.ComputeNumSignBits(N0) >= VTBits-EVTBits+1)
    return N0;

  // fold (sext_in_reg (sext_in_reg x, VT2), VT1) -> (sext_in_reg x, minVT) pt2
  if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
      EVT.bitsLT(cast<VTSDNode>(N0.getOperand(1))->getVT()))
    return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT,
                       N0.getOperand(0), N1);

  // fold (sext_in_reg (sext x)) -> (sext x)
  // fold (sext_in_reg (aext x)) -> (sext x)
  // if x is small enough or if we know that x has more than 1 sign bit and the
  // sign_extend_inreg is extending from one of them.
  if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) {
    SDValue N00 = N0.getOperand(0);
    unsigned N00Bits = N00.getScalarValueSizeInBits();
    if ((N00Bits <= EVTBits ||
         (N00Bits - DAG.ComputeNumSignBits(N00)) < EVTBits) &&
        (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT)))
      return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00);
  }
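  // A sketch of the fold above with hypothetical types: for X: i8,
  //   (sext_in_reg (any_extend X to i32), i8) -> (sext X to i32)
  // since every low bit covered by the in-reg extension comes straight
  // from X, a plain sign extension of X computes the same value.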

  // fold (sext_in_reg (*_extend_vector_inreg x)) -> (sext_vector_inreg x)
  if ((N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
       N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG ||
       N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG) &&
      N0.getOperand(0).getScalarValueSizeInBits() == EVTBits) {
    if (!LegalOperations ||
        TLI.isOperationLegal(ISD::SIGN_EXTEND_VECTOR_INREG, VT))
      return DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, SDLoc(N), VT,
                         N0.getOperand(0));
  }

  // fold (sext_in_reg (zext x)) -> (sext x)
  // iff we are extending the source sign bit.
  if (N0.getOpcode() == ISD::ZERO_EXTEND) {
    SDValue N00 = N0.getOperand(0);
    if (N00.getScalarValueSizeInBits() == EVTBits &&
        (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT)))
      return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00);
  }

  // fold (sext_in_reg x) -> (zext_in_reg x) if the sign bit is known zero.
  if (DAG.MaskedValueIsZero(N0, APInt::getOneBitSet(VTBits, EVTBits - 1)))
    return DAG.getZeroExtendInReg(N0, SDLoc(N), EVT.getScalarType());

  // fold operands of sext_in_reg based on knowledge that the top bits are not
  // demanded.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  // fold (sext_in_reg (load x)) -> (smaller sextload x)
  // fold (sext_in_reg (srl (load x), c)) -> (smaller sextload (x+c/evtbits))
  if (SDValue NarrowLoad = ReduceLoadWidth(N))
    return NarrowLoad;

  // fold (sext_in_reg (srl X, 24), i8) -> (sra X, 24)
  // fold (sext_in_reg (srl X, 23), i8) -> (sra X, 23) iff possible.
  // We already fold "(sext_in_reg (srl X, 25), i8) -> srl X, 25" above.
  if (N0.getOpcode() == ISD::SRL) {
    if (auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1)))
      if (ShAmt->getAPIntValue().ule(VTBits - EVTBits)) {
        // We can turn this into an SRA iff the input to the SRL is already sign
        // extended enough.
        unsigned InSignBits = DAG.ComputeNumSignBits(N0.getOperand(0));
        if (((VTBits - EVTBits) - ShAmt->getZExtValue()) < InSignBits)
          return DAG.getNode(ISD::SRA, SDLoc(N), VT, N0.getOperand(0),
                             N0.getOperand(1));
      }
  }
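  // Concrete instance of the SRL fold above, with VT = i32 and an asserted
  // type of i8:
  //   (sext_in_reg (srl X, 24), i8) -> (sra X, 24)
  // always holds, because the shift leaves exactly the top byte in the low
  // bits, and sra reproduces its sign extension directly.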

  // fold (sext_inreg (extload x)) -> (sextload x)
  // If sextload is not supported by the target, we can only do the combine
  // when the load has one use. Doing otherwise can block folding the extload
  // with other extends that the target does support.
  if (ISD::isEXTLoad(N0.getNode()) &&
      ISD::isUNINDEXEDLoad(N0.getNode()) &&
      EVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
      ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple() &&
        N0.hasOneUse()) ||
       TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, EVT))) {
    LoadSDNode *LN0 = cast<LoadSDNode>(N0);
    SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(N), VT,
                                     LN0->getChain(),
                                     LN0->getBasePtr(), EVT,
                                     LN0->getMemOperand());
    CombineTo(N, ExtLoad);
    CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
    AddToWorklist(ExtLoad.getNode());
    return SDValue(N, 0);   // Return N so it doesn't get rechecked!
  }
  // fold (sext_inreg (zextload x)) -> (sextload x) iff load has one use
  if (ISD::isZEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) &&
      N0.hasOneUse() &&
      EVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
      ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple()) &&
       TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, EVT))) {
    LoadSDNode *LN0 = cast<LoadSDNode>(N0);
    SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(N), VT,
                                     LN0->getChain(),
                                     LN0->getBasePtr(), EVT,
                                     LN0->getMemOperand());
    CombineTo(N, ExtLoad);
    CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
    return SDValue(N, 0);   // Return N so it doesn't get rechecked!
  }

  // Form (sext_inreg (bswap >> 16)) or (sext_inreg (rotl (bswap) 16))
  if (EVTBits <= 16 && N0.getOpcode() == ISD::OR) {
    if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
                                           N0.getOperand(1), false))
      return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT,
                         BSwap, N1);
  }

  return SDValue();
}

SDValue DAGCombiner::visitSIGN_EXTEND_VECTOR_INREG(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  if (N0.isUndef())
    return DAG.getUNDEF(VT);

  if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
    return Res;

  if (SimplifyDemandedVectorElts(SDValue(N, 0)))
    return SDValue(N, 0);

  return SDValue();
}

SDValue DAGCombiner::visitZERO_EXTEND_VECTOR_INREG(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  if (N0.isUndef())
    return DAG.getUNDEF(VT);

  if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
    return Res;

  if (SimplifyDemandedVectorElts(SDValue(N, 0)))
    return SDValue(N, 0);

  return SDValue();
}

SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  EVT SrcVT = N0.getValueType();
  bool isLE = DAG.getDataLayout().isLittleEndian();

  // noop truncate
  if (SrcVT == VT)
    return N0;

  // fold (truncate (truncate x)) -> (truncate x)
  if (N0.getOpcode() == ISD::TRUNCATE)
    return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0.getOperand(0));

  // fold (truncate c1) -> c1
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) {
    SDValue C = DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0);
    if (C.getNode() != N)
      return C;
  }

  // fold (truncate (ext x)) -> (ext x) or (truncate x) or x
  if (N0.getOpcode() == ISD::ZERO_EXTEND ||
      N0.getOpcode() == ISD::SIGN_EXTEND ||
      N0.getOpcode() == ISD::ANY_EXTEND) {
    // if the source is smaller than the dest, we still need an extend.
    if (N0.getOperand(0).getValueType().bitsLT(VT))
      return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, N0.getOperand(0));
    // if the source is larger than the dest, then we just need the truncate.
    if (N0.getOperand(0).getValueType().bitsGT(VT))
      return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0.getOperand(0));
    // if the source and dest are the same type, we can drop both the extend
    // and the truncate.
    return N0.getOperand(0);
  }
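  // The three cases above, sketched with VT = i32:
  //   trunc (zext X:i16 to i64) -> zext X to i32    (ext source narrower)
  //   trunc (zext X:i64 to i128) -> trunc X to i32  (ext source wider)
  //   trunc (zext X:i32 to i64) -> X                (same width, both cancel)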

  // If this is anyext(trunc), don't fold it; allow ourselves to be folded.
  if (N->hasOneUse() && (N->use_begin()->getOpcode() == ISD::ANY_EXTEND))
    return SDValue();

  // Fold extract-and-trunc into a narrow extract. For example:
  //   i64 x = EXTRACT_VECTOR_ELT(v2i64 val, i32 1)
  //   i32 y = TRUNCATE(i64 x)
  //        -- becomes --
  //   v4i32 b = BITCAST (v2i64 val)
  //   i32 y = EXTRACT_VECTOR_ELT(v4i32 b, i32 2)
  //
  // Note: We only run this optimization after type legalization (which often
  // creates this pattern) and before operation legalization after which
  // we need to be more careful about the vector instructions that we generate.
  if (N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
      LegalTypes && !LegalOperations && N0->hasOneUse() && VT != MVT::i1) {
    EVT VecTy = N0.getOperand(0).getValueType();
    EVT ExTy = N0.getValueType();
    EVT TrTy = N->getValueType(0);

    unsigned NumElem = VecTy.getVectorNumElements();
    unsigned SizeRatio = ExTy.getSizeInBits()/TrTy.getSizeInBits();

    EVT NVT = EVT::getVectorVT(*DAG.getContext(), TrTy, SizeRatio * NumElem);
    assert(NVT.getSizeInBits() == VecTy.getSizeInBits() && "Invalid Size");

    SDValue EltNo = N0->getOperand(1);
    if (isa<ConstantSDNode>(EltNo) && isTypeLegal(NVT)) {
      int Elt = cast<ConstantSDNode>(EltNo)->getZExtValue();
      EVT IndexTy = TLI.getVectorIdxTy(DAG.getDataLayout());
      int Index = isLE ? (Elt*SizeRatio) : (Elt*SizeRatio + (SizeRatio-1));

      SDLoc DL(N);
      return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, TrTy,
                         DAG.getBitcast(NVT, N0.getOperand(0)),
                         DAG.getConstant(Index, DL, IndexTy));
    }
  }

  // trunc (select c, a, b) -> select c, (trunc a), (trunc b)
  if (N0.getOpcode() == ISD::SELECT && N0.hasOneUse()) {
    if ((!LegalOperations || TLI.isOperationLegal(ISD::SELECT, SrcVT)) &&
        TLI.isTruncateFree(SrcVT, VT)) {
      SDLoc SL(N0);
      SDValue Cond = N0.getOperand(0);
      SDValue TruncOp0 = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(1));
      SDValue TruncOp1 = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(2));
      return DAG.getNode(ISD::SELECT, SDLoc(N), VT, Cond, TruncOp0, TruncOp1);
    }
  }

  // trunc (shl x, K) -> shl (trunc x), K, iff K < VT.getScalarSizeInBits()
  if (N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
      (!LegalOperations || TLI.isOperationLegal(ISD::SHL, VT)) &&
      TLI.isTypeDesirableForOp(ISD::SHL, VT)) {
    SDValue Amt = N0.getOperand(1);
    KnownBits Known = DAG.computeKnownBits(Amt);
    unsigned Size = VT.getScalarSizeInBits();
    if (Known.getBitWidth() - Known.countMinLeadingZeros() <= Log2_32(Size)) {
      SDLoc SL(N);
      EVT AmtVT = TLI.getShiftAmountTy(VT, DAG.getDataLayout());

      SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(0));
      if (AmtVT != Amt.getValueType()) {
        Amt = DAG.getZExtOrTrunc(Amt, SL, AmtVT);
        AddToWorklist(Amt.getNode());
      }
      return DAG.getNode(ISD::SHL, SL, VT, Trunc, Amt);
    }
  }

  // Attempt to pre-truncate BUILD_VECTOR sources.
  if (N0.getOpcode() == ISD::BUILD_VECTOR && !LegalOperations &&
      TLI.isTruncateFree(SrcVT.getScalarType(), VT.getScalarType())) {
    SDLoc DL(N);
    EVT SVT = VT.getScalarType();
    SmallVector<SDValue, 8> TruncOps;
    for (const SDValue &Op : N0->op_values()) {
      SDValue TruncOp = DAG.getNode(ISD::TRUNCATE, DL, SVT, Op);
      TruncOps.push_back(TruncOp);
    }
    return DAG.getBuildVector(VT, DL, TruncOps);
  }

  // Fold a series of buildvector, bitcast, and truncate if possible.
  // For example fold
  //   (2xi32 trunc (bitcast ((4xi32)buildvector x, x, y, y) 2xi64)) to
  //   (2xi32 (buildvector x, y)).
  if (Level == AfterLegalizeVectorOps && VT.isVector() &&
      N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
      N0.getOperand(0).getOpcode() == ISD::BUILD_VECTOR &&
      N0.getOperand(0).hasOneUse()) {
    SDValue BuildVect = N0.getOperand(0);
    EVT BuildVectEltTy = BuildVect.getValueType().getVectorElementType();
    EVT TruncVecEltTy = VT.getVectorElementType();

    // Check that the element types match.
    if (BuildVectEltTy == TruncVecEltTy) {
      // Now we only need to compute the offset of the truncated elements.
      unsigned BuildVecNumElts = BuildVect.getNumOperands();
      unsigned TruncVecNumElts = VT.getVectorNumElements();
      unsigned TruncEltOffset = BuildVecNumElts / TruncVecNumElts;

      assert((BuildVecNumElts % TruncVecNumElts) == 0 &&
             "Invalid number of elements");

      SmallVector<SDValue, 8> Opnds;
      for (unsigned i = 0, e = BuildVecNumElts; i != e; i += TruncEltOffset)
        Opnds.push_back(BuildVect.getOperand(i));

      return DAG.getBuildVector(VT, SDLoc(N), Opnds);
    }
  }

  // See if we can simplify the input to this truncate through knowledge that
  // only the low bits are being used.
  // For example "trunc (or (shl x, 8), y)" -> trunc y
  // Currently we only perform this optimization on scalars because vectors
  // may have different active low bits.
  if (!VT.isVector()) {
    APInt Mask =
        APInt::getLowBitsSet(N0.getValueSizeInBits(), VT.getSizeInBits());
    if (SDValue Shorter = DAG.GetDemandedBits(N0, Mask))
      return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Shorter);
  }

  // fold (truncate (load x)) -> (smaller load x)
  // fold (truncate (srl (load x), c)) -> (smaller load (x+c/evtbits))
  if (!LegalTypes || TLI.isTypeDesirableForOp(N0.getOpcode(), VT)) {
    if (SDValue Reduced = ReduceLoadWidth(N))
      return Reduced;

    // Handle the case where the load remains an extending load even
    // after truncation.
    if (N0.hasOneUse() && ISD::isUNINDEXEDLoad(N0.getNode())) {
      LoadSDNode *LN0 = cast<LoadSDNode>(N0);
      if (LN0->isSimple() &&
          LN0->getMemoryVT().getStoreSizeInBits() < VT.getSizeInBits()) {
        SDValue NewLoad = DAG.getExtLoad(LN0->getExtensionType(), SDLoc(LN0),
                                         VT, LN0->getChain(), LN0->getBasePtr(),
                                         LN0->getMemoryVT(),
                                         LN0->getMemOperand());
        DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), NewLoad.getValue(1));
        return NewLoad;
      }
    }
  }

  // fold (trunc (concat ... x ...)) -> (concat ..., (trunc x), ...),
  // where ... are all 'undef'.
  if (N0.getOpcode() == ISD::CONCAT_VECTORS && !LegalTypes) {
    SmallVector<EVT, 8> VTs;
    SDValue V;
    unsigned Idx = 0;
    unsigned NumDefs = 0;

    for (unsigned i = 0, e = N0.getNumOperands(); i != e; ++i) {
      SDValue X = N0.getOperand(i);
      if (!X.isUndef()) {
        V = X;
        Idx = i;
        NumDefs++;
      }
      // Stop if more than one member is non-undef.
      if (NumDefs > 1)
        break;
      VTs.push_back(EVT::getVectorVT(*DAG.getContext(),
                                     VT.getVectorElementType(),
                                     X.getValueType().getVectorNumElements()));
    }

    if (NumDefs == 0)
      return DAG.getUNDEF(VT);

    if (NumDefs == 1) {
      assert(V.getNode() && "The single defined operand is empty!");
      SmallVector<SDValue, 8> Opnds;
      for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
        if (i != Idx) {
          Opnds.push_back(DAG.getUNDEF(VTs[i]));
          continue;
        }
        SDValue NV = DAG.getNode(ISD::TRUNCATE, SDLoc(V), VTs[i], V);
        AddToWorklist(NV.getNode());
        Opnds.push_back(NV);
      }
      return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Opnds);
    }
  }

  // Fold truncate of a bitcast of a vector to an extract of the low vector
  // element.
  //
  // e.g. trunc (i64 (bitcast v2i32:x)) -> extract_vector_elt v2i32:x, idx
  if (N0.getOpcode() == ISD::BITCAST && !VT.isVector()) {
    SDValue VecSrc = N0.getOperand(0);
    EVT VecSrcVT = VecSrc.getValueType();
    if (VecSrcVT.isVector() && VecSrcVT.getScalarType() == VT &&
        (!LegalOperations ||
         TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VecSrcVT))) {
      SDLoc SL(N);

      EVT IdxVT = TLI.getVectorIdxTy(DAG.getDataLayout());
      unsigned Idx = isLE ? 0 : VecSrcVT.getVectorNumElements() - 1;
      return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, VT, VecSrc,
                         DAG.getConstant(Idx, SL, IdxVT));
    }
  }

  // Simplify the operands using demanded-bits information.
  if (!VT.isVector() &&
      SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  // (trunc adde(X, Y, Carry)) -> (adde trunc(X), trunc(Y), Carry)
  // (trunc addcarry(X, Y, Carry)) -> (addcarry trunc(X), trunc(Y), Carry)
  // When the adde's carry is not used.
  if ((N0.getOpcode() == ISD::ADDE || N0.getOpcode() == ISD::ADDCARRY) &&
      N0.hasOneUse() && !N0.getNode()->hasAnyUseOfValue(1) &&
      // We only do this for ADDCARRY before operation legalization.
      ((!LegalOperations && N0.getOpcode() == ISD::ADDCARRY) ||
       TLI.isOperationLegal(N0.getOpcode(), VT))) {
    SDLoc SL(N);
    auto X = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(0));
    auto Y = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(1));
    auto VTs = DAG.getVTList(VT, N0->getValueType(1));
    return DAG.getNode(N0.getOpcode(), SL, VTs, X, Y, N0.getOperand(2));
  }
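  // Sketch (hypothetical types): when only value 0 of an i64 addcarry is
  // used, truncating it to i32 becomes
  //   i32, carry = addcarry (trunc X), (trunc Y), Carry
  // narrowing the arithmetic while keeping the carry operand intact.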

  // fold (truncate (extract_subvector(ext x))) ->
  //      (extract_subvector x)
  // TODO: This can be generalized to cover cases where the truncate and
  // extract do not fully cancel each other out.
  if (!LegalTypes && N0.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
    SDValue N00 = N0.getOperand(0);
    if (N00.getOpcode() == ISD::SIGN_EXTEND ||
        N00.getOpcode() == ISD::ZERO_EXTEND ||
        N00.getOpcode() == ISD::ANY_EXTEND) {
      if (N00.getOperand(0)->getValueType(0).getVectorElementType() ==
          VT.getVectorElementType())
        return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N0->getOperand(0)), VT,
                           N00.getOperand(0), N0.getOperand(1));
    }
  }

  if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
    return NewVSel;

  // Narrow a suitable binary operation with a non-opaque constant operand by
  // moving it ahead of the truncate. This is limited to pre-legalization
  // because targets may prefer a wider type during later combines and invert
  // this transform.
  switch (N0.getOpcode()) {
  case ISD::ADD:
  case ISD::SUB:
  case ISD::MUL:
  case ISD::AND:
  case ISD::OR:
  case ISD::XOR:
    if (!LegalOperations && N0.hasOneUse() &&
        (isConstantOrConstantVector(N0.getOperand(0), true) ||
         isConstantOrConstantVector(N0.getOperand(1), true))) {
      // TODO: We already restricted this to pre-legalization, but for vectors
      // we are extra cautious to not create an unsupported operation.
      // Target-specific changes are likely needed to avoid regressions here.
      if (VT.isScalarInteger() || TLI.isOperationLegal(N0.getOpcode(), VT)) {
        SDLoc DL(N);
        SDValue NarrowL = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
        SDValue NarrowR = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(1));
        return DAG.getNode(N0.getOpcode(), DL, VT, NarrowL, NarrowR);
      }
    }
  }

  return SDValue();
}

static SDNode *getBuildPairElt(SDNode *N, unsigned i) {
  SDValue Elt = N->getOperand(i);
  if (Elt.getOpcode() != ISD::MERGE_VALUES)
    return Elt.getNode();
  return Elt.getOperand(Elt.getResNo()).getNode();
}

/// build_pair (load, load) -> load
/// if load locations are consecutive.
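/// For example, on a little-endian target and assuming alignment permits,
///   (i64 build_pair (i32 load p), (i32 load p+4)) -> (i64 load p)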
SDValue DAGCombiner::CombineConsecutiveLoads(SDNode *N, EVT VT) {
  assert(N->getOpcode() == ISD::BUILD_PAIR);

  LoadSDNode *LD1 = dyn_cast<LoadSDNode>(getBuildPairElt(N, 0));
  LoadSDNode *LD2 = dyn_cast<LoadSDNode>(getBuildPairElt(N, 1));

  // A BUILD_PAIR always has the least significant part in elt 0 and the
  // most significant part in elt 1. So when combining into one large load, we
  // need to consider the endianness.
  if (DAG.getDataLayout().isBigEndian())
    std::swap(LD1, LD2);

  if (!LD1 || !LD2 || !ISD::isNON_EXTLoad(LD1) || !LD1->hasOneUse() ||
      LD1->getAddressSpace() != LD2->getAddressSpace())
    return SDValue();
  EVT LD1VT = LD1->getValueType(0);
  unsigned LD1Bytes = LD1VT.getStoreSize();
  if (ISD::isNON_EXTLoad(LD2) && LD2->hasOneUse() &&
      DAG.areNonVolatileConsecutiveLoads(LD2, LD1, LD1Bytes, 1)) {
    unsigned Align = LD1->getAlignment();
    unsigned NewAlign = DAG.getDataLayout().getABITypeAlignment(
        VT.getTypeForEVT(*DAG.getContext()));

    if (NewAlign <= Align &&
        (!LegalOperations || TLI.isOperationLegal(ISD::LOAD, VT)))
      return DAG.getLoad(VT, SDLoc(N), LD1->getChain(), LD1->getBasePtr(),
                         LD1->getPointerInfo(), Align);
  }

  return SDValue();
}

static unsigned getPPCf128HiElementSelector(const SelectionDAG &DAG) {
  // On little-endian machines, bitcasting from ppcf128 to i128 does swap the Hi
  // and Lo parts; on big-endian machines it doesn't.
  return DAG.getDataLayout().isBigEndian() ? 1 : 0;
}

static SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
                                    const TargetLowering &TLI) {
  // If this is not a bitcast to an FP type or if the target doesn't have
  // IEEE754-compliant FP logic, we're done.
  EVT VT = N->getValueType(0);
  if (!VT.isFloatingPoint() || !TLI.hasBitPreservingFPLogic(VT))
    return SDValue();

  // TODO: Handle cases where the integer constant is a different scalar
  // bitwidth to the FP.
  SDValue N0 = N->getOperand(0);
  EVT SourceVT = N0.getValueType();
  if (VT.getScalarSizeInBits() != SourceVT.getScalarSizeInBits())
    return SDValue();

  unsigned FPOpcode;
  APInt SignMask;
  switch (N0.getOpcode()) {
  case ISD::AND:
    FPOpcode = ISD::FABS;
    SignMask = ~APInt::getSignMask(SourceVT.getScalarSizeInBits());
    break;
  case ISD::XOR:
    FPOpcode = ISD::FNEG;
    SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits());
    break;
  case ISD::OR:
    FPOpcode = ISD::FABS;
    SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits());
    break;
  default:
    return SDValue();
  }

  // Fold (bitcast int (and (bitcast fp X to int), 0x7fff...) to fp) -> fabs X
  // Fold (bitcast int (xor (bitcast fp X to int), 0x8000...) to fp) -> fneg X
  // Fold (bitcast int (or (bitcast fp X to int), 0x8000...) to fp) ->
  //   fneg (fabs X)
  SDValue LogicOp0 = N0.getOperand(0);
  ConstantSDNode *LogicOp1 = isConstOrConstSplat(N0.getOperand(1), true);
  if (LogicOp1 && LogicOp1->getAPIntValue() == SignMask &&
      LogicOp0.getOpcode() == ISD::BITCAST &&
      LogicOp0.getOperand(0).getValueType() == VT) {
    SDValue FPOp = DAG.getNode(FPOpcode, SDLoc(N), VT, LogicOp0.getOperand(0));
    NumFPLogicOpsConv++;
    if (N0.getOpcode() == ISD::OR)
      return DAG.getNode(ISD::FNEG, SDLoc(N), VT, FPOp);
    return FPOp;
  }

  return SDValue();
}

SDValue DAGCombiner::visitBITCAST(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  if (N0.isUndef())
    return DAG.getUNDEF(VT);

  // If the input is a BUILD_VECTOR with all constant elements, fold this now.
  // Only do this before legalize types, unless both types are integer and the
  // scalar type is legal. Only do this before legalize ops, since the target
  // may be depending on the bitcast.
  // First check to see if this is all constant.
  // TODO: Support FP bitcasts after legalize types.
  if (VT.isVector() &&
      (!LegalTypes ||
       (!LegalOperations && VT.isInteger() && N0.getValueType().isInteger() &&
        TLI.isTypeLegal(VT.getVectorElementType()))) &&
      N0.getOpcode() == ISD::BUILD_VECTOR && N0.getNode()->hasOneUse() &&
      cast<BuildVectorSDNode>(N0)->isConstant())
    return ConstantFoldBITCASTofBUILD_VECTOR(N0.getNode(),
                                             VT.getVectorElementType());

  // If the input is a constant, let getNode fold it.
  if (isa<ConstantSDNode>(N0) || isa<ConstantFPSDNode>(N0)) {
    // If we can't allow illegal operations, we need to check that this is just
    // an fp -> int or int -> fp conversion and that the resulting operation
    // will be legal.
    if (!LegalOperations ||
        (isa<ConstantSDNode>(N0) && VT.isFloatingPoint() && !VT.isVector() &&
         TLI.isOperationLegal(ISD::ConstantFP, VT)) ||
        (isa<ConstantFPSDNode>(N0) && VT.isInteger() && !VT.isVector() &&
         TLI.isOperationLegal(ISD::Constant, VT))) {
      SDValue C = DAG.getBitcast(VT, N0);
      if (C.getNode() != N)
        return C;
    }
  }

  // (conv (conv x, t1), t2) -> (conv x, t2)
  if (N0.getOpcode() == ISD::BITCAST)
    return DAG.getBitcast(VT, N0.getOperand(0));

  // fold (conv (load x)) -> (load (conv*)x)
  // If the resultant load doesn't need a higher alignment than the original!
  if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
      // Do not remove the cast if the types differ in endian layout.
      TLI.hasBigEndianPartOrdering(N0.getValueType(), DAG.getDataLayout()) ==
          TLI.hasBigEndianPartOrdering(VT, DAG.getDataLayout()) &&
      // If the load is volatile, we only want to change the load type if the
      // resulting load is legal. Otherwise we might increase the number of
      // memory accesses. We don't care if the original type was legal or not
      // as we assume software couldn't rely on the number of accesses of an
      // illegal type.
      ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple()) ||
       TLI.isOperationLegal(ISD::LOAD, VT))) {
    LoadSDNode *LN0 = cast<LoadSDNode>(N0);

    if (TLI.isLoadBitCastBeneficial(N0.getValueType(), VT, DAG,
                                    *LN0->getMemOperand())) {
      SDValue Load =
          DAG.getLoad(VT, SDLoc(N), LN0->getChain(), LN0->getBasePtr(),
                      LN0->getPointerInfo(), LN0->getAlignment(),
                      LN0->getMemOperand()->getFlags(), LN0->getAAInfo());
      DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));
      return Load;
    }
  }

  if (SDValue V = foldBitcastedFPLogic(N, DAG, TLI))
    return V;

  // fold (bitconvert (fneg x)) -> (xor (bitconvert x), signbit)
  // fold (bitconvert (fabs x)) -> (and (bitconvert x), (not signbit))
  //
  // For ppc_fp128:
  // fold (bitcast (fneg x)) ->
  //     flipbit = signbit
  //     (xor (bitcast x) (build_pair flipbit, flipbit))
  //
  // fold (bitcast (fabs x)) ->
  //     flipbit = (and (extract_element (bitcast x), 0), signbit)
  //     (xor (bitcast x) (build_pair flipbit, flipbit))
  // This often reduces constant pool loads.
  if (((N0.getOpcode() == ISD::FNEG && !TLI.isFNegFree(N0.getValueType())) ||
       (N0.getOpcode() == ISD::FABS && !TLI.isFAbsFree(N0.getValueType()))) &&
      N0.getNode()->hasOneUse() && VT.isInteger() &&
      !VT.isVector() && !N0.getValueType().isVector()) {
    SDValue NewConv = DAG.getBitcast(VT, N0.getOperand(0));
    AddToWorklist(NewConv.getNode());

    SDLoc DL(N);
    if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
      assert(VT.getSizeInBits() == 128);
      SDValue SignBit = DAG.getConstant(
          APInt::getSignMask(VT.getSizeInBits() / 2), SDLoc(N0), MVT::i64);
      SDValue FlipBit;
      if (N0.getOpcode() == ISD::FNEG) {
        FlipBit = SignBit;
        AddToWorklist(FlipBit.getNode());
      } else {
        assert(N0.getOpcode() == ISD::FABS);
        SDValue Hi =
            DAG.getNode(ISD::EXTRACT_ELEMENT, SDLoc(NewConv), MVT::i64, NewConv,
                        DAG.getIntPtrConstant(getPPCf128HiElementSelector(DAG),
                                              SDLoc(NewConv)));
        AddToWorklist(Hi.getNode());
        FlipBit = DAG.getNode(ISD::AND, SDLoc(N0), MVT::i64, Hi, SignBit);
        AddToWorklist(FlipBit.getNode());
      }
      SDValue FlipBits =
          DAG.getNode(ISD::BUILD_PAIR, SDLoc(N0), VT, FlipBit, FlipBit);
      AddToWorklist(FlipBits.getNode());
      return DAG.getNode(ISD::XOR, DL, VT, NewConv, FlipBits);
    }
    APInt SignBit = APInt::getSignMask(VT.getSizeInBits());
    if (N0.getOpcode() == ISD::FNEG)
      return DAG.getNode(ISD::XOR, DL, VT,
                         NewConv, DAG.getConstant(SignBit, DL, VT));
    assert(N0.getOpcode() == ISD::FABS);
    return DAG.getNode(ISD::AND, DL, VT,
                       NewConv, DAG.getConstant(~SignBit, DL, VT));
  }

  // fold (bitconvert (fcopysign cst, x)) ->
  //         (or (and (bitconvert x), sign), (and cst, (not sign)))
  // Note that we don't handle (copysign x, cst) because this can always be
  // folded to an fneg or fabs.
  //
  // For ppc_fp128:
  // fold (bitcast (fcopysign cst, x)) ->
  //     flipbit = (and (extract_element
  //                     (xor (bitcast cst), (bitcast x)), 0),
  //                    signbit)
  //     (xor (bitcast cst) (build_pair flipbit, flipbit))
  if (N0.getOpcode() == ISD::FCOPYSIGN && N0.getNode()->hasOneUse() &&
      isa<ConstantFPSDNode>(N0.getOperand(0)) &&
      VT.isInteger() && !VT.isVector()) {
    unsigned OrigXWidth = N0.getOperand(1).getValueSizeInBits();
    EVT IntXVT = EVT::getIntegerVT(*DAG.getContext(), OrigXWidth);
    if (isTypeLegal(IntXVT)) {
      SDValue X = DAG.getBitcast(IntXVT, N0.getOperand(1));
      AddToWorklist(X.getNode());

      // If X has a different width than the result/lhs, sext it or truncate it.
      unsigned VTWidth = VT.getSizeInBits();
      if (OrigXWidth < VTWidth) {
        X = DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, X);
        AddToWorklist(X.getNode());
      } else if (OrigXWidth > VTWidth) {
        // To get the sign bit in the right place, we have to shift it right
        // before truncating.
        SDLoc DL(X);
        X = DAG.getNode(ISD::SRL, DL,
                        X.getValueType(), X,
                        DAG.getConstant(OrigXWidth-VTWidth, DL,
                                        X.getValueType()));
        AddToWorklist(X.getNode());
        X = DAG.getNode(ISD::TRUNCATE, SDLoc(X), VT, X);
        AddToWorklist(X.getNode());
      }

      if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
        APInt SignBit = APInt::getSignMask(VT.getSizeInBits() / 2);
        SDValue Cst = DAG.getBitcast(VT, N0.getOperand(0));
        AddToWorklist(Cst.getNode());
        SDValue X = DAG.getBitcast(VT, N0.getOperand(1));
        AddToWorklist(X.getNode());
        SDValue XorResult = DAG.getNode(ISD::XOR, SDLoc(N0), VT, Cst, X);
        AddToWorklist(XorResult.getNode());
        SDValue XorResult64 = DAG.getNode(
            ISD::EXTRACT_ELEMENT, SDLoc(XorResult), MVT::i64, XorResult,
            DAG.getIntPtrConstant(getPPCf128HiElementSelector(DAG),
                                  SDLoc(XorResult)));
        AddToWorklist(XorResult64.getNode());
        SDValue FlipBit =
            DAG.getNode(ISD::AND, SDLoc(XorResult64), MVT::i64, XorResult64,
                        DAG.getConstant(SignBit, SDLoc(XorResult64), MVT::i64));
        AddToWorklist(FlipBit.getNode());
        SDValue FlipBits =
            DAG.getNode(ISD::BUILD_PAIR, SDLoc(N0), VT, FlipBit, FlipBit);
        AddToWorklist(FlipBits.getNode());
        return DAG.getNode(ISD::XOR, SDLoc(N), VT, Cst, FlipBits);
      }
      APInt SignBit = APInt::getSignMask(VT.getSizeInBits());
      X = DAG.getNode(ISD::AND, SDLoc(X), VT,
                      X, DAG.getConstant(SignBit, SDLoc(X), VT));
      AddToWorklist(X.getNode());

      SDValue Cst = DAG.getBitcast(VT, N0.getOperand(0));
      Cst = DAG.getNode(ISD::AND, SDLoc(Cst), VT,
                        Cst, DAG.getConstant(~SignBit, SDLoc(Cst), VT));
      AddToWorklist(Cst.getNode());

      return DAG.getNode(ISD::OR, SDLoc(N), VT, X, Cst);
    }
  }

  // bitconvert(build_pair(ld, ld)) -> ld iff load locations are consecutive.
  if (N0.getOpcode() == ISD::BUILD_PAIR)
    if (SDValue CombineLD = CombineConsecutiveLoads(N0.getNode(), VT))
      return CombineLD;

  // Remove double bitcasts from shuffles - this is often a legacy of
  // XformToShuffleWithZero being used to combine bitmaskings (of
  // float vectors bitcast to integer vectors) into shuffles.
  // bitcast(shuffle(bitcast(s0),bitcast(s1))) -> shuffle(s0,s1)
  if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT) && VT.isVector() &&
      N0->getOpcode() == ISD::VECTOR_SHUFFLE && N0.hasOneUse() &&
      VT.getVectorNumElements() >= N0.getValueType().getVectorNumElements() &&
      !(VT.getVectorNumElements() % N0.getValueType().getVectorNumElements())) {
    ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N0);

    // If an operand is a bitcast, peek through it if it casts from the
    // original VT. If an operand is a constant, just bitcast it back to the
    // original VT.
    auto PeekThroughBitcast = [&](SDValue Op) {
      if (Op.getOpcode() == ISD::BITCAST &&
          Op.getOperand(0).getValueType() == VT)
        return SDValue(Op.getOperand(0));
      if (Op.isUndef() || ISD::isBuildVectorOfConstantSDNodes(Op.getNode()) ||
          ISD::isBuildVectorOfConstantFPSDNodes(Op.getNode()))
        return DAG.getBitcast(VT, Op);
      return SDValue();
    };

    // FIXME: If either input vector is bitcast, try to convert the shuffle to
    // the result type of this bitcast. This would eliminate at least one
    // bitcast. See the transform in InstCombine.
    SDValue SV0 = PeekThroughBitcast(N0->getOperand(0));
    SDValue SV1 = PeekThroughBitcast(N0->getOperand(1));
    if (!(SV0 && SV1))
      return SDValue();

    int MaskScale =
        VT.getVectorNumElements() / N0.getValueType().getVectorNumElements();
    SmallVector<int, 8> NewMask;
    for (int M : SVN->getMask())
      for (int i = 0; i != MaskScale; ++i)
        NewMask.push_back(M < 0 ? -1 : M * MaskScale + i);

    SDValue LegalShuffle =
        TLI.buildLegalVectorShuffle(VT, SDLoc(N), SV0, SV1, NewMask, DAG);
    if (LegalShuffle)
      return LegalShuffle;
  }

  return SDValue();
}

SDValue DAGCombiner::visitBUILD_PAIR(SDNode *N) {
  EVT VT = N->getValueType(0);
  return CombineConsecutiveLoads(N, VT);
}

/// We know that BV is a build_vector node with Constant, ConstantFP or Undef
/// operands. DstEltVT indicates the destination element value type.
SDValue DAGCombiner::
ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) {
  EVT SrcEltVT = BV->getValueType(0).getVectorElementType();

  // If this is already the right type, we're done.
  if (SrcEltVT == DstEltVT) return SDValue(BV, 0);

  unsigned SrcBitSize = SrcEltVT.getSizeInBits();
  unsigned DstBitSize = DstEltVT.getSizeInBits();

  // If this is a conversion of N elements of one type to N elements of another
  // type, convert each element.  This handles FP<->INT cases.
  if (SrcBitSize == DstBitSize) {
    SmallVector<SDValue, 8> Ops;
    for (SDValue Op : BV->op_values()) {
      // If the vector element type is not legal, the BUILD_VECTOR operands
      // are promoted and implicitly truncated.  Make that explicit here.
      if (Op.getValueType() != SrcEltVT)
        Op = DAG.getNode(ISD::TRUNCATE, SDLoc(BV), SrcEltVT, Op);
      Ops.push_back(DAG.getBitcast(DstEltVT, Op));
      AddToWorklist(Ops.back().getNode());
    }
    EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT,
                              BV->getValueType(0).getVectorNumElements());
    return DAG.getBuildVector(VT, SDLoc(BV), Ops);
  }
11420 
11421   // Otherwise, we're growing or shrinking the elements.  To avoid having to
11422   // handle annoying details of growing/shrinking FP values, we convert them to
11423   // int first.
11424   if (SrcEltVT.isFloatingPoint()) {
11425     // Convert the input float vector to a int vector where the elements are the
11426     // same sizes.
11427     EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), SrcEltVT.getSizeInBits());
11428     BV = ConstantFoldBITCASTofBUILD_VECTOR(BV, IntVT).getNode();
11429     SrcEltVT = IntVT;
11430   }
11431 
11432   // Now we know the input is an integer vector.  If the output is a FP type,
11433   // convert to integer first, then to FP of the right size.
11434   if (DstEltVT.isFloatingPoint()) {
11435     EVT TmpVT = EVT::getIntegerVT(*DAG.getContext(), DstEltVT.getSizeInBits());
11436     SDNode *Tmp = ConstantFoldBITCASTofBUILD_VECTOR(BV, TmpVT).getNode();
11437 
11438     // Next, convert to FP elements of the same size.
11439     return ConstantFoldBITCASTofBUILD_VECTOR(Tmp, DstEltVT);
11440   }
11441 
11442   SDLoc DL(BV);
11443 
11444   // Okay, we know the src/dst types are both integers of differing types.
11445   // Handle growing first.
11446   assert(SrcEltVT.isInteger() && DstEltVT.isInteger());
11447   if (SrcBitSize < DstBitSize) {
11448     unsigned NumInputsPerOutput = DstBitSize/SrcBitSize;
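    // e.g. growing <4 x i8> <0x01,0x02,0x03,0x04> into <2 x i16> yields
    // <0x0201, 0x0403> on a little-endian target: earlier (lower-addressed)
    // elements land in the low bits of each wider element.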
11449 
11450     SmallVector<SDValue, 8> Ops;
11451     for (unsigned i = 0, e = BV->getNumOperands(); i != e;
11452          i += NumInputsPerOutput) {
11453       bool isLE = DAG.getDataLayout().isLittleEndian();
11454       APInt NewBits = APInt(DstBitSize, 0);
11455       bool EltIsUndef = true;
11456       for (unsigned j = 0; j != NumInputsPerOutput; ++j) {
11457         // Shift the previously computed bits over.
11458         NewBits <<= SrcBitSize;
11459         SDValue Op = BV->getOperand(i + (isLE ? (NumInputsPerOutput-j-1) : j));
11460         if (Op.isUndef()) continue;
11461         EltIsUndef = false;
11462 
11463         NewBits |= cast<ConstantSDNode>(Op)->getAPIntValue().
11464                    zextOrTrunc(SrcBitSize).zext(DstBitSize);
11465       }
11466 
11467       if (EltIsUndef)
11468         Ops.push_back(DAG.getUNDEF(DstEltVT));
11469       else
11470         Ops.push_back(DAG.getConstant(NewBits, DL, DstEltVT));
11471     }
11472 
11473     EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT, Ops.size());
11474     return DAG.getBuildVector(VT, DL, Ops);
11475   }
11476 
11477   // Finally, this must be the case where we are shrinking elements: each input
11478   // turns into multiple outputs.
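  // e.g. shrinking <2 x i16> <0x0201, 0x0403> into <4 x i8> yields
  // <0x01,0x02,0x03,0x04> on a little-endian target; the reversal below
  // handles big-endian targets.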
11479   unsigned NumOutputsPerInput = SrcBitSize/DstBitSize;
11480   EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT,
11481                             NumOutputsPerInput*BV->getNumOperands());
11482   SmallVector<SDValue, 8> Ops;
11483 
11484   for (const SDValue &Op : BV->op_values()) {
11485     if (Op.isUndef()) {
11486       Ops.append(NumOutputsPerInput, DAG.getUNDEF(DstEltVT));
11487       continue;
11488     }
11489 
11490     APInt OpVal = cast<ConstantSDNode>(Op)->
11491                   getAPIntValue().zextOrTrunc(SrcBitSize);
11492 
11493     for (unsigned j = 0; j != NumOutputsPerInput; ++j) {
11494       APInt ThisVal = OpVal.trunc(DstBitSize);
11495       Ops.push_back(DAG.getConstant(ThisVal, DL, DstEltVT));
11496       OpVal.lshrInPlace(DstBitSize);
11497     }
11498 
11499     // For big endian targets, swap the order of the pieces of each element.
11500     if (DAG.getDataLayout().isBigEndian())
11501       std::reverse(Ops.end()-NumOutputsPerInput, Ops.end());
11502   }
11503 
11504   return DAG.getBuildVector(VT, DL, Ops);
11505 }
11506 
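/// Return true if the node's fast-math flags permit contraction ('contract')
/// or reassociation ('reassoc'), either of which allows forming FMAs here.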
11507 static bool isContractable(SDNode *N) {
11508   SDNodeFlags F = N->getFlags();
11509   return F.hasAllowContract() || F.hasAllowReassociation();
11510 }
11511 
11512 /// Try to perform FMA combining on a given FADD node.
11513 SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
11514   SDValue N0 = N->getOperand(0);
11515   SDValue N1 = N->getOperand(1);
11516   EVT VT = N->getValueType(0);
11517   SDLoc SL(N);
11518 
11519   const TargetOptions &Options = DAG.getTarget().Options;
11520 
11521   // Floating-point multiply-add with intermediate rounding.
11522   bool HasFMAD = (LegalOperations && TLI.isFMADLegalForFAddFSub(DAG, N));
11523 
11524   // Floating-point multiply-add without intermediate rounding.
11525   bool HasFMA =
11526       TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
11527       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
11528 
11529   // No valid opcode, do not combine.
11530   if (!HasFMAD && !HasFMA)
11531     return SDValue();
11532 
11533   SDNodeFlags Flags = N->getFlags();
11534   bool CanFuse = Options.UnsafeFPMath || isContractable(N);
11535   bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
11536                               CanFuse || HasFMAD);
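  // i.e. fusion is globally allowed when fast FP-op fusion is in effect
  // (e.g. -ffp-contract=fast or unsafe math), when this node carries
  // contract/reassoc flags, or when the target provides FMAD.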
11537   // If the addition is not contractable, do not combine.
11538   if (!AllowFusionGlobally && !isContractable(N))
11539     return SDValue();
11540 
11541   const SelectionDAGTargetInfo *STI = DAG.getSubtarget().getSelectionDAGInfo();
11542   if (STI && STI->generateFMAsInMachineCombiner(OptLevel))
11543     return SDValue();
11544 
11545   // Always prefer FMAD to FMA for precision.
11546   unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
11547   bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
11548 
11549   // Check whether the node is an FMUL and contractable, either due to
11550   // global flags or its own SDNodeFlags.
11551   auto isContractableFMUL = [AllowFusionGlobally](SDValue N) {
11552     if (N.getOpcode() != ISD::FMUL)
11553       return false;
11554     return AllowFusionGlobally || isContractable(N.getNode());
11555   };
11556   // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
11557   // prefer to fold the multiply with fewer uses.
11558   if (Aggressive && isContractableFMUL(N0) && isContractableFMUL(N1)) {
11559     if (N0.getNode()->use_size() > N1.getNode()->use_size())
11560       std::swap(N0, N1);
11561   }
11562 
11563   // fold (fadd (fmul x, y), z) -> (fma x, y, z)
11564   if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) {
11565     return DAG.getNode(PreferredFusedOpcode, SL, VT,
11566                        N0.getOperand(0), N0.getOperand(1), N1, Flags);
11567   }
11568 
11569   // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
11570   // Note: Commutes FADD operands.
11571   if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) {
11572     return DAG.getNode(PreferredFusedOpcode, SL, VT,
11573                        N1.getOperand(0), N1.getOperand(1), N0, Flags);
11574   }
11575 
11576   // Look through FP_EXTEND nodes to do more combining.
11577 
11578   // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
11579   if (N0.getOpcode() == ISD::FP_EXTEND) {
11580     SDValue N00 = N0.getOperand(0);
11581     if (isContractableFMUL(N00) &&
11582         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
11583                             N00.getValueType())) {
11584       return DAG.getNode(PreferredFusedOpcode, SL, VT,
11585                          DAG.getNode(ISD::FP_EXTEND, SL, VT,
11586                                      N00.getOperand(0)),
11587                          DAG.getNode(ISD::FP_EXTEND, SL, VT,
11588                                      N00.getOperand(1)), N1, Flags);
11589     }
11590   }
11591 
11592   // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x)
11593   // Note: Commutes FADD operands.
11594   if (N1.getOpcode() == ISD::FP_EXTEND) {
11595     SDValue N10 = N1.getOperand(0);
11596     if (isContractableFMUL(N10) &&
11597         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
11598                             N10.getValueType())) {
11599       return DAG.getNode(PreferredFusedOpcode, SL, VT,
11600                          DAG.getNode(ISD::FP_EXTEND, SL, VT,
11601                                      N10.getOperand(0)),
11602                          DAG.getNode(ISD::FP_EXTEND, SL, VT,
11603                                      N10.getOperand(1)), N0, Flags);
11604     }
11605   }
11606 
11607   // More folding opportunities when target permits.
11608   if (Aggressive) {
11609     // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y, (fma u, v, z))
11610     if (CanFuse &&
11611         N0.getOpcode() == PreferredFusedOpcode &&
11612         N0.getOperand(2).getOpcode() == ISD::FMUL &&
11613         N0->hasOneUse() && N0.getOperand(2)->hasOneUse()) {
11614       return DAG.getNode(PreferredFusedOpcode, SL, VT,
11615                          N0.getOperand(0), N0.getOperand(1),
11616                          DAG.getNode(PreferredFusedOpcode, SL, VT,
11617                                      N0.getOperand(2).getOperand(0),
11618                                      N0.getOperand(2).getOperand(1),
11619                                      N1, Flags), Flags);
11620     }
11621 
11622     // fold (fadd x, (fma y, z, (fmul u, v))) -> (fma y, z, (fma u, v, x))
11623     if (CanFuse &&
11624         N1->getOpcode() == PreferredFusedOpcode &&
11625         N1.getOperand(2).getOpcode() == ISD::FMUL &&
11626         N1->hasOneUse() && N1.getOperand(2)->hasOneUse()) {
11627       return DAG.getNode(PreferredFusedOpcode, SL, VT,
11628                          N1.getOperand(0), N1.getOperand(1),
11629                          DAG.getNode(PreferredFusedOpcode, SL, VT,
11630                                      N1.getOperand(2).getOperand(0),
11631                                      N1.getOperand(2).getOperand(1),
11632                                      N0, Flags), Flags);
11633     }
11634
11636     // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
11637     //   -> (fma x, y, (fma (fpext u), (fpext v), z))
11638     auto FoldFAddFMAFPExtFMul = [&] (
11639       SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z,
11640       SDNodeFlags Flags) {
11641       return DAG.getNode(PreferredFusedOpcode, SL, VT, X, Y,
11642                          DAG.getNode(PreferredFusedOpcode, SL, VT,
11643                                      DAG.getNode(ISD::FP_EXTEND, SL, VT, U),
11644                                      DAG.getNode(ISD::FP_EXTEND, SL, VT, V),
11645                                      Z, Flags), Flags);
11646     };
11647     if (N0.getOpcode() == PreferredFusedOpcode) {
11648       SDValue N02 = N0.getOperand(2);
11649       if (N02.getOpcode() == ISD::FP_EXTEND) {
11650         SDValue N020 = N02.getOperand(0);
11651         if (isContractableFMUL(N020) &&
11652             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
11653                                 N020.getValueType())) {
11654           return FoldFAddFMAFPExtFMul(N0.getOperand(0), N0.getOperand(1),
11655                                       N020.getOperand(0), N020.getOperand(1),
11656                                       N1, Flags);
11657         }
11658       }
11659     }
11660 
11661     // fold (fadd (fpext (fma x, y, (fmul u, v))), z)
11662     //   -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
11663     // FIXME: This turns two single-precision and one double-precision
11664     // operation into two double-precision operations, which might not be
11665     // interesting for all targets, especially GPUs.
11666     auto FoldFAddFPExtFMAFMul = [&] (
11667       SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z,
11668       SDNodeFlags Flags) {
11669       return DAG.getNode(PreferredFusedOpcode, SL, VT,
11670                          DAG.getNode(ISD::FP_EXTEND, SL, VT, X),
11671                          DAG.getNode(ISD::FP_EXTEND, SL, VT, Y),
11672                          DAG.getNode(PreferredFusedOpcode, SL, VT,
11673                                      DAG.getNode(ISD::FP_EXTEND, SL, VT, U),
11674                                      DAG.getNode(ISD::FP_EXTEND, SL, VT, V),
11675                                      Z, Flags), Flags);
11676     };
11677     if (N0.getOpcode() == ISD::FP_EXTEND) {
11678       SDValue N00 = N0.getOperand(0);
11679       if (N00.getOpcode() == PreferredFusedOpcode) {
11680         SDValue N002 = N00.getOperand(2);
11681         if (isContractableFMUL(N002) &&
11682             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
11683                                 N00.getValueType())) {
11684           return FoldFAddFPExtFMAFMul(N00.getOperand(0), N00.getOperand(1),
11685                                       N002.getOperand(0), N002.getOperand(1),
11686                                       N1, Flags);
11687         }
11688       }
11689     }
11690 
11691     // fold (fadd x, (fma y, z, (fpext (fmul u, v))))
11692     //   -> (fma y, z, (fma (fpext u), (fpext v), x))
11693     if (N1.getOpcode() == PreferredFusedOpcode) {
11694       SDValue N12 = N1.getOperand(2);
11695       if (N12.getOpcode() == ISD::FP_EXTEND) {
11696         SDValue N120 = N12.getOperand(0);
11697         if (isContractableFMUL(N120) &&
11698             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
11699                                 N120.getValueType())) {
11700           return FoldFAddFMAFPExtFMul(N1.getOperand(0), N1.getOperand(1),
11701                                       N120.getOperand(0), N120.getOperand(1),
11702                                       N0, Flags);
11703         }
11704       }
11705     }
11706 
11707     // fold (fadd x, (fpext (fma y, z, (fmul u, v))))
11708     //   -> (fma (fpext y), (fpext z), (fma (fpext u), (fpext v), x))
11709     // FIXME: This turns two single-precision and one double-precision
11710     // operation into two double-precision operations, which might not be
11711     // interesting for all targets, especially GPUs.
11712     if (N1.getOpcode() == ISD::FP_EXTEND) {
11713       SDValue N10 = N1.getOperand(0);
11714       if (N10.getOpcode() == PreferredFusedOpcode) {
11715         SDValue N102 = N10.getOperand(2);
11716         if (isContractableFMUL(N102) &&
11717             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
11718                                 N10.getValueType())) {
11719           return FoldFAddFPExtFMAFMul(N10.getOperand(0), N10.getOperand(1),
11720                                       N102.getOperand(0), N102.getOperand(1),
11721                                       N0, Flags);
11722         }
11723       }
11724     }
11725   }
11726 
11727   return SDValue();
11728 }
11729 
11730 /// Try to perform FMA combining on a given FSUB node.
11731 SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
11732   SDValue N0 = N->getOperand(0);
11733   SDValue N1 = N->getOperand(1);
11734   EVT VT = N->getValueType(0);
11735   SDLoc SL(N);
11736 
11737   const TargetOptions &Options = DAG.getTarget().Options;
11738   // Floating-point multiply-add with intermediate rounding.
11739   bool HasFMAD = (LegalOperations && TLI.isFMADLegalForFAddFSub(DAG, N));
11740 
11741   // Floating-point multiply-add without intermediate rounding.
11742   bool HasFMA =
11743       TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
11744       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
11745 
11746   // No valid opcode, do not combine.
11747   if (!HasFMAD && !HasFMA)
11748     return SDValue();
11749 
11750   const SDNodeFlags Flags = N->getFlags();
11751   bool CanFuse = Options.UnsafeFPMath || isContractable(N);
11752   bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
11753                               CanFuse || HasFMAD);
11754 
11755   // If the subtraction is not contractable, do not combine.
11756   if (!AllowFusionGlobally && !isContractable(N))
11757     return SDValue();
11758 
11759   const SelectionDAGTargetInfo *STI = DAG.getSubtarget().getSelectionDAGInfo();
11760   if (STI && STI->generateFMAsInMachineCombiner(OptLevel))
11761     return SDValue();
11762 
11763   // Always prefer FMAD to FMA for precision.
11764   unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
11765   bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
11766 
11767   // Check whether the node is an FMUL and contractable, either due to
11768   // global flags or its own SDNodeFlags.
11769   auto isContractableFMUL = [AllowFusionGlobally](SDValue N) {
11770     if (N.getOpcode() != ISD::FMUL)
11771       return false;
11772     return AllowFusionGlobally || isContractable(N.getNode());
11773   };
11774 
11775   // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
11776   if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) {
11777     return DAG.getNode(PreferredFusedOpcode, SL, VT,
11778                        N0.getOperand(0), N0.getOperand(1),
11779                        DAG.getNode(ISD::FNEG, SL, VT, N1), Flags);
11780   }
11781 
11782   // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
11783   // Note: Commutes FSUB operands.
11784   if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) {
11785     return DAG.getNode(PreferredFusedOpcode, SL, VT,
11786                        DAG.getNode(ISD::FNEG, SL, VT,
11787                                    N1.getOperand(0)),
11788                        N1.getOperand(1), N0, Flags);
11789   }
11790 
11791   // fold (fsub (fneg (fmul x, y)), z) -> (fma (fneg x), y, (fneg z))
11792   if (N0.getOpcode() == ISD::FNEG && isContractableFMUL(N0.getOperand(0)) &&
11793       (Aggressive || (N0->hasOneUse() && N0.getOperand(0).hasOneUse()))) {
11794     SDValue N00 = N0.getOperand(0).getOperand(0);
11795     SDValue N01 = N0.getOperand(0).getOperand(1);
11796     return DAG.getNode(PreferredFusedOpcode, SL, VT,
11797                        DAG.getNode(ISD::FNEG, SL, VT, N00), N01,
11798                        DAG.getNode(ISD::FNEG, SL, VT, N1), Flags);
11799   }
11800 
11801   // Look through FP_EXTEND nodes to do more combining.
11802 
11803   // fold (fsub (fpext (fmul x, y)), z)
11804   //   -> (fma (fpext x), (fpext y), (fneg z))
11805   if (N0.getOpcode() == ISD::FP_EXTEND) {
11806     SDValue N00 = N0.getOperand(0);
11807     if (isContractableFMUL(N00) &&
11808         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
11809                             N00.getValueType())) {
11810       return DAG.getNode(PreferredFusedOpcode, SL, VT,
11811                          DAG.getNode(ISD::FP_EXTEND, SL, VT,
11812                                      N00.getOperand(0)),
11813                          DAG.getNode(ISD::FP_EXTEND, SL, VT,
11814                                      N00.getOperand(1)),
11815                          DAG.getNode(ISD::FNEG, SL, VT, N1), Flags);
11816     }
11817   }
11818 
11819   // fold (fsub x, (fpext (fmul y, z)))
11820   //   -> (fma (fneg (fpext y)), (fpext z), x)
11821   // Note: Commutes FSUB operands.
11822   if (N1.getOpcode() == ISD::FP_EXTEND) {
11823     SDValue N10 = N1.getOperand(0);
11824     if (isContractableFMUL(N10) &&
11825         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
11826                             N10.getValueType())) {
11827       return DAG.getNode(PreferredFusedOpcode, SL, VT,
11828                          DAG.getNode(ISD::FNEG, SL, VT,
11829                                      DAG.getNode(ISD::FP_EXTEND, SL, VT,
11830                                                  N10.getOperand(0))),
11831                          DAG.getNode(ISD::FP_EXTEND, SL, VT,
11832                                      N10.getOperand(1)),
11833                          N0, Flags);
11834     }
11835   }
11836 
11837   // fold (fsub (fpext (fneg (fmul x, y))), z)
11838   //   -> (fneg (fma (fpext x), (fpext y), z))
11839   // Note: This could be removed with appropriate canonicalization of the
11840   // input expression into (fneg (fadd (fpext (fmul x, y)), z)). However,
11841   // the orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math
11842   // prevent us from implementing the canonicalization in visitFSUB.
11843   if (N0.getOpcode() == ISD::FP_EXTEND) {
11844     SDValue N00 = N0.getOperand(0);
11845     if (N00.getOpcode() == ISD::FNEG) {
11846       SDValue N000 = N00.getOperand(0);
11847       if (isContractableFMUL(N000) &&
11848           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
11849                               N00.getValueType())) {
11850         return DAG.getNode(ISD::FNEG, SL, VT,
11851                            DAG.getNode(PreferredFusedOpcode, SL, VT,
11852                                        DAG.getNode(ISD::FP_EXTEND, SL, VT,
11853                                                    N000.getOperand(0)),
11854                                        DAG.getNode(ISD::FP_EXTEND, SL, VT,
11855                                                    N000.getOperand(1)),
11856                                        N1, Flags));
11857       }
11858     }
11859   }
11860 
11861   // fold (fsub (fneg (fpext (fmul x, y))), z)
11862   //   -> (fneg (fma (fpext x), (fpext y), z))
11863   // Note: This could be removed with appropriate canonicalization of the
11864   // input expression into (fneg (fadd (fpext (fmul x, y)), z)). However,
11865   // the orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math
11866   // prevent us from implementing the canonicalization in visitFSUB.
11867   if (N0.getOpcode() == ISD::FNEG) {
11868     SDValue N00 = N0.getOperand(0);
11869     if (N00.getOpcode() == ISD::FP_EXTEND) {
11870       SDValue N000 = N00.getOperand(0);
11871       if (isContractableFMUL(N000) &&
11872           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
11873                               N000.getValueType())) {
11874         return DAG.getNode(ISD::FNEG, SL, VT,
11875                            DAG.getNode(PreferredFusedOpcode, SL, VT,
11876                                        DAG.getNode(ISD::FP_EXTEND, SL, VT,
11877                                                    N000.getOperand(0)),
11878                                        DAG.getNode(ISD::FP_EXTEND, SL, VT,
11879                                                    N000.getOperand(1)),
11880                                        N1, Flags));
11881       }
11882     }
11883   }
11884 
11885   // More folding opportunities when target permits.
11886   if (Aggressive) {
11887     // fold (fsub (fma x, y, (fmul u, v)), z)
11888     //   -> (fma x, y, (fma u, v, (fneg z)))
11889     if (CanFuse && N0.getOpcode() == PreferredFusedOpcode &&
11890         isContractableFMUL(N0.getOperand(2)) && N0->hasOneUse() &&
11891         N0.getOperand(2)->hasOneUse()) {
11892       return DAG.getNode(PreferredFusedOpcode, SL, VT,
11893                          N0.getOperand(0), N0.getOperand(1),
11894                          DAG.getNode(PreferredFusedOpcode, SL, VT,
11895                                      N0.getOperand(2).getOperand(0),
11896                                      N0.getOperand(2).getOperand(1),
11897                                      DAG.getNode(ISD::FNEG, SL, VT,
11898                                                  N1), Flags), Flags);
11899     }
11900 
11901     // fold (fsub x, (fma y, z, (fmul u, v)))
11902     //   -> (fma (fneg y), z, (fma (fneg u), v, x))
11903     if (CanFuse && N1.getOpcode() == PreferredFusedOpcode &&
11904         isContractableFMUL(N1.getOperand(2)) &&
11905         N1->hasOneUse()) {
11906       SDValue N20 = N1.getOperand(2).getOperand(0);
11907       SDValue N21 = N1.getOperand(2).getOperand(1);
11908       return DAG.getNode(PreferredFusedOpcode, SL, VT,
11909                          DAG.getNode(ISD::FNEG, SL, VT,
11910                                      N1.getOperand(0)),
11911                          N1.getOperand(1),
11912                          DAG.getNode(PreferredFusedOpcode, SL, VT,
11913                                      DAG.getNode(ISD::FNEG, SL, VT, N20),
11914                                      N21, N0, Flags), Flags);
11915     }
11916
11918     // fold (fsub (fma x, y, (fpext (fmul u, v))), z)
11919     //   -> (fma x, y, (fma (fpext u), (fpext v), (fneg z)))
11920     if (N0.getOpcode() == PreferredFusedOpcode &&
11921         N0->hasOneUse()) {
11922       SDValue N02 = N0.getOperand(2);
11923       if (N02.getOpcode() == ISD::FP_EXTEND) {
11924         SDValue N020 = N02.getOperand(0);
11925         if (isContractableFMUL(N020) &&
11926             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
11927                                 N020.getValueType())) {
11928           return DAG.getNode(PreferredFusedOpcode, SL, VT,
11929                              N0.getOperand(0), N0.getOperand(1),
11930                              DAG.getNode(PreferredFusedOpcode, SL, VT,
11931                                          DAG.getNode(ISD::FP_EXTEND, SL, VT,
11932                                                      N020.getOperand(0)),
11933                                          DAG.getNode(ISD::FP_EXTEND, SL, VT,
11934                                                      N020.getOperand(1)),
11935                                          DAG.getNode(ISD::FNEG, SL, VT,
11936                                                      N1), Flags), Flags);
11937         }
11938       }
11939     }
11940 
11941     // fold (fsub (fpext (fma x, y, (fmul u, v))), z)
11942     //   -> (fma (fpext x), (fpext y),
11943     //           (fma (fpext u), (fpext v), (fneg z)))
11944     // FIXME: This turns two single-precision and one double-precision
11945     // operation into two double-precision operations, which might not be
11946     // interesting for all targets, especially GPUs.
11947     if (N0.getOpcode() == ISD::FP_EXTEND) {
11948       SDValue N00 = N0.getOperand(0);
11949       if (N00.getOpcode() == PreferredFusedOpcode) {
11950         SDValue N002 = N00.getOperand(2);
11951         if (isContractableFMUL(N002) &&
11952             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
11953                                 N00.getValueType())) {
11954           return DAG.getNode(PreferredFusedOpcode, SL, VT,
11955                              DAG.getNode(ISD::FP_EXTEND, SL, VT,
11956                                          N00.getOperand(0)),
11957                              DAG.getNode(ISD::FP_EXTEND, SL, VT,
11958                                          N00.getOperand(1)),
11959                              DAG.getNode(PreferredFusedOpcode, SL, VT,
11960                                          DAG.getNode(ISD::FP_EXTEND, SL, VT,
11961                                                      N002.getOperand(0)),
11962                                          DAG.getNode(ISD::FP_EXTEND, SL, VT,
11963                                                      N002.getOperand(1)),
11964                                          DAG.getNode(ISD::FNEG, SL, VT,
11965                                                      N1), Flags), Flags);
11966         }
11967       }
11968     }
11969 
11970     // fold (fsub x, (fma y, z, (fpext (fmul u, v))))
11971     //   -> (fma (fneg y), z, (fma (fneg (fpext u)), (fpext v), x))
11972     if (N1.getOpcode() == PreferredFusedOpcode &&
11973         N1.getOperand(2).getOpcode() == ISD::FP_EXTEND &&
11974         N1->hasOneUse()) {
11975       SDValue N120 = N1.getOperand(2).getOperand(0);
11976       if (isContractableFMUL(N120) &&
11977           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
11978                               N120.getValueType())) {
11979         SDValue N1200 = N120.getOperand(0);
11980         SDValue N1201 = N120.getOperand(1);
11981         return DAG.getNode(PreferredFusedOpcode, SL, VT,
11982                            DAG.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)),
11983                            N1.getOperand(1),
11984                            DAG.getNode(PreferredFusedOpcode, SL, VT,
11985                                        DAG.getNode(ISD::FNEG, SL, VT,
11986                                                    DAG.getNode(ISD::FP_EXTEND, SL,
11987                                                                VT, N1200)),
11988                                        DAG.getNode(ISD::FP_EXTEND, SL, VT,
11989                                                    N1201),
11990                                        N0, Flags), Flags);
11991       }
11992     }
11993 
11994     // fold (fsub x, (fpext (fma y, z, (fmul u, v))))
11995     //   -> (fma (fneg (fpext y)), (fpext z),
11996     //           (fma (fneg (fpext u)), (fpext v), x))
11997     // FIXME: This turns two single-precision and one double-precision
11998     // operation into two double-precision operations, which might not be
11999     // interesting for all targets, especially GPUs.
12000     if (N1.getOpcode() == ISD::FP_EXTEND &&
12001         N1.getOperand(0).getOpcode() == PreferredFusedOpcode) {
12002       SDValue CvtSrc = N1.getOperand(0);
12003       SDValue N100 = CvtSrc.getOperand(0);
12004       SDValue N101 = CvtSrc.getOperand(1);
12005       SDValue N102 = CvtSrc.getOperand(2);
12006       if (isContractableFMUL(N102) &&
12007           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
12008                               CvtSrc.getValueType())) {
12009         SDValue N1020 = N102.getOperand(0);
12010         SDValue N1021 = N102.getOperand(1);
12011         return DAG.getNode(PreferredFusedOpcode, SL, VT,
12012                            DAG.getNode(ISD::FNEG, SL, VT,
12013                                        DAG.getNode(ISD::FP_EXTEND, SL, VT,
12014                                                    N100)),
12015                            DAG.getNode(ISD::FP_EXTEND, SL, VT, N101),
12016                            DAG.getNode(PreferredFusedOpcode, SL, VT,
12017                                        DAG.getNode(ISD::FNEG, SL, VT,
12018                                                    DAG.getNode(ISD::FP_EXTEND, SL,
12019                                                                VT, N1020)),
12020                                        DAG.getNode(ISD::FP_EXTEND, SL, VT,
12021                                                    N1021),
12022                                        N0, Flags), Flags);
12023       }
12024     }
12025   }
12026 
12027   return SDValue();
12028 }
12029 
12030 /// Try to perform FMA combining on a given FMUL node based on the distributive
12031 /// law x * (y + 1) = x * y + x and variants thereof (commuted versions,
12032 /// subtraction instead of addition).
12033 SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) {
12034   SDValue N0 = N->getOperand(0);
12035   SDValue N1 = N->getOperand(1);
12036   EVT VT = N->getValueType(0);
12037   SDLoc SL(N);
12038   const SDNodeFlags Flags = N->getFlags();
12039 
12040   assert(N->getOpcode() == ISD::FMUL && "Expected FMUL Operation");
12041 
12042   const TargetOptions &Options = DAG.getTarget().Options;
12043 
12044   // The transforms below are incorrect when x == 0 and y == inf, because the
12045   // intermediate multiplication produces a nan.
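  // e.g. for x == 0 and y == +inf, (x + 1.0) * y is +inf, but the fused
  // form fma(x, y, y) evaluates 0 * inf and so yields nan.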
12046   if (!Options.NoInfsFPMath)
12047     return SDValue();
12048 
12049   // Floating-point multiply-add without intermediate rounding.
12050   bool HasFMA =
12051       (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath) &&
12052       TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
12053       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
12054 
12055   // Floating-point multiply-add with intermediate rounding. This can result
12056   // in a less precise result due to the changed rounding order.
12057   bool HasFMAD = Options.UnsafeFPMath &&
12058                  (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT));
12059 
12060   // No valid opcode, do not combine.
12061   if (!HasFMAD && !HasFMA)
12062     return SDValue();
12063 
12064   // Always prefer FMAD to FMA for precision.
12065   unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
12066   bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
12067 
12068   // fold (fmul (fadd x0, +1.0), y) -> (fma x0, y, y)
12069   // fold (fmul (fadd x0, -1.0), y) -> (fma x0, y, (fneg y))
12070   auto FuseFADD = [&](SDValue X, SDValue Y, const SDNodeFlags Flags) {
12071     if (X.getOpcode() == ISD::FADD && (Aggressive || X->hasOneUse())) {
12072       if (auto *C = isConstOrConstSplatFP(X.getOperand(1), true)) {
12073         if (C->isExactlyValue(+1.0))
12074           return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
12075                              Y, Flags);
12076         if (C->isExactlyValue(-1.0))
12077           return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
12078                              DAG.getNode(ISD::FNEG, SL, VT, Y), Flags);
12079       }
12080     }
12081     return SDValue();
12082   };
12083 
12084   if (SDValue FMA = FuseFADD(N0, N1, Flags))
12085     return FMA;
12086   if (SDValue FMA = FuseFADD(N1, N0, Flags))
12087     return FMA;
12088 
12089   // fold (fmul (fsub +1.0, x1), y) -> (fma (fneg x1), y, y)
12090   // fold (fmul (fsub -1.0, x1), y) -> (fma (fneg x1), y, (fneg y))
12091   // fold (fmul (fsub x0, +1.0), y) -> (fma x0, y, (fneg y))
12092   // fold (fmul (fsub x0, -1.0), y) -> (fma x0, y, y)
12093   auto FuseFSUB = [&](SDValue X, SDValue Y, const SDNodeFlags Flags) {
12094     if (X.getOpcode() == ISD::FSUB && (Aggressive || X->hasOneUse())) {
12095       if (auto *C0 = isConstOrConstSplatFP(X.getOperand(0), true)) {
12096         if (C0->isExactlyValue(+1.0))
12097           return DAG.getNode(PreferredFusedOpcode, SL, VT,
12098                              DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
12099                              Y, Flags);
12100         if (C0->isExactlyValue(-1.0))
12101           return DAG.getNode(PreferredFusedOpcode, SL, VT,
12102                              DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
12103                              DAG.getNode(ISD::FNEG, SL, VT, Y), Flags);
12104       }
12105       if (auto *C1 = isConstOrConstSplatFP(X.getOperand(1), true)) {
12106         if (C1->isExactlyValue(+1.0))
12107           return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
12108                              DAG.getNode(ISD::FNEG, SL, VT, Y), Flags);
12109         if (C1->isExactlyValue(-1.0))
12110           return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
12111                              Y, Flags);
12112       }
12113     }
12114     return SDValue();
12115   };
12116 
12117   if (SDValue FMA = FuseFSUB(N0, N1, Flags))
12118     return FMA;
12119   if (SDValue FMA = FuseFSUB(N1, N0, Flags))
12120     return FMA;
12121 
12122   return SDValue();
12123 }
12124 
12125 SDValue DAGCombiner::visitFADD(SDNode *N) {
12126   SDValue N0 = N->getOperand(0);
12127   SDValue N1 = N->getOperand(1);
12128   bool N0CFP = isConstantFPBuildVectorOrConstantFP(N0);
12129   bool N1CFP = isConstantFPBuildVectorOrConstantFP(N1);
12130   EVT VT = N->getValueType(0);
12131   SDLoc DL(N);
12132   const TargetOptions &Options = DAG.getTarget().Options;
12133   const SDNodeFlags Flags = N->getFlags();
12134 
12135   // fold vector ops
12136   if (VT.isVector())
12137     if (SDValue FoldedVOp = SimplifyVBinOp(N))
12138       return FoldedVOp;
12139 
12140   // fold (fadd c1, c2) -> c1 + c2
12141   if (N0CFP && N1CFP)
12142     return DAG.getNode(ISD::FADD, DL, VT, N0, N1, Flags);
12143 
12144   // canonicalize constant to RHS
12145   if (N0CFP && !N1CFP)
12146     return DAG.getNode(ISD::FADD, DL, VT, N1, N0, Flags);
12147 
12148   // N0 + -0.0 --> N0 (also allowed with +0.0 and fast-math)
12149   ConstantFPSDNode *N1C = isConstOrConstSplatFP(N1, true);
12150   if (N1C && N1C->isZero())
12151     if (N1C->isNegative() || Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())
12152       return N0;
12153 
12154   if (SDValue NewSel = foldBinOpIntoSelect(N))
12155     return NewSel;
12156 
12157   // fold (fadd A, (fneg B)) -> (fsub A, B)
12158   if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT)) &&
12159       TLI.isNegatibleForFree(N1, DAG, LegalOperations, ForCodeSize) == 2)
12160     return DAG.getNode(
12161         ISD::FSUB, DL, VT, N0,
12162         TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize), Flags);
12163 
12164   // fold (fadd (fneg A), B) -> (fsub B, A)
12165   if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT)) &&
12166       TLI.isNegatibleForFree(N0, DAG, LegalOperations, ForCodeSize) == 2)
12167     return DAG.getNode(
12168         ISD::FSUB, DL, VT, N1,
12169         TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize), Flags);
12170 
12171   auto isFMulNegTwo = [](SDValue FMul) {
12172     if (!FMul.hasOneUse() || FMul.getOpcode() != ISD::FMUL)
12173       return false;
12174     auto *C = isConstOrConstSplatFP(FMul.getOperand(1), true);
12175     return C && C->isExactlyValue(-2.0);
12176   };
12177 
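  // Rewriting A + (B * -2.0) as A - (B + B) removes the -2.0 constant and
  // trades the multiply for an add, which is usually no more expensive.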
12178   // fadd (fmul B, -2.0), A --> fsub A, (fadd B, B)
12179   if (isFMulNegTwo(N0)) {
12180     SDValue B = N0.getOperand(0);
12181     SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B, Flags);
12182     return DAG.getNode(ISD::FSUB, DL, VT, N1, Add, Flags);
12183   }
12184   // fadd A, (fmul B, -2.0) --> fsub A, (fadd B, B)
12185   if (isFMulNegTwo(N1)) {
12186     SDValue B = N1.getOperand(0);
12187     SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B, Flags);
12188     return DAG.getNode(ISD::FSUB, DL, VT, N0, Add, Flags);
12189   }
12190 
12191   // No FP constant should be created after legalization as the Instruction
12192   // Selection pass has a hard time dealing with FP constants.
12193   bool AllowNewConst = (Level < AfterLegalizeDAG);
12194 
12195   // If nnan is enabled, fold lots of things.
12196   if ((Options.NoNaNsFPMath || Flags.hasNoNaNs()) && AllowNewConst) {
12197     // If allowed, fold (fadd (fneg x), x) -> 0.0
12198     if (N0.getOpcode() == ISD::FNEG && N0.getOperand(0) == N1)
12199       return DAG.getConstantFP(0.0, DL, VT);
12200 
12201     // If allowed, fold (fadd x, (fneg x)) -> 0.0
12202     if (N1.getOpcode() == ISD::FNEG && N1.getOperand(0) == N0)
12203       return DAG.getConstantFP(0.0, DL, VT);
12204   }
12205 
12206   // If 'unsafe math' or reassoc and nsz, fold lots of things.
12207   // TODO: break out portions of the transformations below for which Unsafe is
12208   //       considered and which do not require both nsz and reassoc
12209   if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
12210        (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
12211       AllowNewConst) {
12212     // fadd (fadd x, c1), c2 -> fadd x, c1 + c2
12213     if (N1CFP && N0.getOpcode() == ISD::FADD &&
12214         isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) {
12215       SDValue NewC = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1), N1, Flags);
12216       return DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(0), NewC, Flags);
12217     }
12218 
12219     // We can fold chains of FADDs of the same value into multiplications.
12220     // This transform is not safe in general because we are reducing the number
12221     // of rounding steps.
12222     if (TLI.isOperationLegalOrCustom(ISD::FMUL, VT) && !N0CFP && !N1CFP) {
12223       if (N0.getOpcode() == ISD::FMUL) {
12224         bool CFP00 = isConstantFPBuildVectorOrConstantFP(N0.getOperand(0));
12225         bool CFP01 = isConstantFPBuildVectorOrConstantFP(N0.getOperand(1));
12226 
12227         // (fadd (fmul x, c), x) -> (fmul x, c+1)
12228         if (CFP01 && !CFP00 && N0.getOperand(0) == N1) {
12229           SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1),
12230                                        DAG.getConstantFP(1.0, DL, VT), Flags);
12231           return DAG.getNode(ISD::FMUL, DL, VT, N1, NewCFP, Flags);
12232         }
12233 
12234         // (fadd (fmul x, c), (fadd x, x)) -> (fmul x, c+2)
12235         if (CFP01 && !CFP00 && N1.getOpcode() == ISD::FADD &&
12236             N1.getOperand(0) == N1.getOperand(1) &&
12237             N0.getOperand(0) == N1.getOperand(0)) {
12238           SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1),
12239                                        DAG.getConstantFP(2.0, DL, VT), Flags);
12240           return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), NewCFP, Flags);
12241         }
12242       }
12243 
12244       if (N1.getOpcode() == ISD::FMUL) {
12245         bool CFP10 = isConstantFPBuildVectorOrConstantFP(N1.getOperand(0));
12246         bool CFP11 = isConstantFPBuildVectorOrConstantFP(N1.getOperand(1));
12247 
12248         // (fadd x, (fmul x, c)) -> (fmul x, c+1)
12249         if (CFP11 && !CFP10 && N1.getOperand(0) == N0) {
12250           SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N1.getOperand(1),
12251                                        DAG.getConstantFP(1.0, DL, VT), Flags);
12252           return DAG.getNode(ISD::FMUL, DL, VT, N0, NewCFP, Flags);
12253         }
12254 
12255         // (fadd (fadd x, x), (fmul x, c)) -> (fmul x, c+2)
12256         if (CFP11 && !CFP10 && N0.getOpcode() == ISD::FADD &&
12257             N0.getOperand(0) == N0.getOperand(1) &&
12258             N1.getOperand(0) == N0.getOperand(0)) {
12259           SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N1.getOperand(1),
12260                                        DAG.getConstantFP(2.0, DL, VT), Flags);
12261           return DAG.getNode(ISD::FMUL, DL, VT, N1.getOperand(0), NewCFP, Flags);
12262         }
12263       }
12264 
12265       if (N0.getOpcode() == ISD::FADD) {
12266         bool CFP00 = isConstantFPBuildVectorOrConstantFP(N0.getOperand(0));
12267         // (fadd (fadd x, x), x) -> (fmul x, 3.0)
12268         if (!CFP00 && N0.getOperand(0) == N0.getOperand(1) &&
12269             (N0.getOperand(0) == N1)) {
12270           return DAG.getNode(ISD::FMUL, DL, VT,
12271                              N1, DAG.getConstantFP(3.0, DL, VT), Flags);
12272         }
12273       }
12274 
12275       if (N1.getOpcode() == ISD::FADD) {
12276         bool CFP10 = isConstantFPBuildVectorOrConstantFP(N1.getOperand(0));
12277         // (fadd x, (fadd x, x)) -> (fmul x, 3.0)
12278         if (!CFP10 && N1.getOperand(0) == N1.getOperand(1) &&
12279             N1.getOperand(0) == N0) {
12280           return DAG.getNode(ISD::FMUL, DL, VT,
12281                              N0, DAG.getConstantFP(3.0, DL, VT), Flags);
12282         }
12283       }
12284 
12285       // (fadd (fadd x, x), (fadd x, x)) -> (fmul x, 4.0)
12286       if (N0.getOpcode() == ISD::FADD && N1.getOpcode() == ISD::FADD &&
12287           N0.getOperand(0) == N0.getOperand(1) &&
12288           N1.getOperand(0) == N1.getOperand(1) &&
12289           N0.getOperand(0) == N1.getOperand(0)) {
12290         return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0),
12291                            DAG.getConstantFP(4.0, DL, VT), Flags);
12292       }
12293     }
12294   } // enable-unsafe-fp-math
12295 
12296   // FADD -> FMA combines:
12297   if (SDValue Fused = visitFADDForFMACombine(N)) {
12298     AddToWorklist(Fused.getNode());
12299     return Fused;
12300   }
12301   return SDValue();
12302 }
12303 
12304 SDValue DAGCombiner::visitFSUB(SDNode *N) {
12305   SDValue N0 = N->getOperand(0);
12306   SDValue N1 = N->getOperand(1);
12307   ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, true);
12308   ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true);
12309   EVT VT = N->getValueType(0);
12310   SDLoc DL(N);
12311   const TargetOptions &Options = DAG.getTarget().Options;
12312   const SDNodeFlags Flags = N->getFlags();
12313 
12314   // fold vector ops
12315   if (VT.isVector())
12316     if (SDValue FoldedVOp = SimplifyVBinOp(N))
12317       return FoldedVOp;
12318 
12319   // fold (fsub c1, c2) -> c1-c2
12320   if (N0CFP && N1CFP)
12321     return DAG.getNode(ISD::FSUB, DL, VT, N0, N1, Flags);
12322 
12323   if (SDValue NewSel = foldBinOpIntoSelect(N))
12324     return NewSel;
12325 
12326   // (fsub A, 0) -> A
12327   if (N1CFP && N1CFP->isZero()) {
12328     if (!N1CFP->isNegative() || Options.NoSignedZerosFPMath ||
12329         Flags.hasNoSignedZeros()) {
12330       return N0;
12331     }
12332   }
12333 
12334   if (N0 == N1) {
12335     // (fsub x, x) -> 0.0
12336     if (Options.NoNaNsFPMath || Flags.hasNoNaNs())
12337       return DAG.getConstantFP(0.0f, DL, VT);
12338   }
12339 
12340   // (fsub -0.0, N1) -> -N1
12341   // NOTE: It is safe to transform an FSUB(-0.0,X) into an FNEG(X), since the
12342   //       FSUB does not specify the sign bit of a NaN. Also note that for
12343   //       the same reason, the inverse transform is not safe, unless fast math
12344   //       flags are in play.
12345   if (N0CFP && N0CFP->isZero()) {
12346     if (N0CFP->isNegative() ||
12347         (Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())) {
12348       if (TLI.isNegatibleForFree(N1, DAG, LegalOperations, ForCodeSize))
12349         return TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize);
12350       if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))
12351         return DAG.getNode(ISD::FNEG, DL, VT, N1, Flags);
12352     }
12353   }
12354 
12355   if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
12356        (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
12357       N1.getOpcode() == ISD::FADD) {
12358     // X - (X + Y) -> -Y
12359     if (N0 == N1->getOperand(0))
12360       return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(1), Flags);
12361     // X - (Y + X) -> -Y
12362     if (N0 == N1->getOperand(1))
12363       return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(0), Flags);
12364   }
12365 
12366   // fold (fsub A, (fneg B)) -> (fadd A, B)
12367   if (TLI.isNegatibleForFree(N1, DAG, LegalOperations, ForCodeSize))
12368     return DAG.getNode(
12369         ISD::FADD, DL, VT, N0,
12370         TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize), Flags);
12371 
12372   // FSUB -> FMA combines:
12373   if (SDValue Fused = visitFSUBForFMACombine(N)) {
12374     AddToWorklist(Fused.getNode());
12375     return Fused;
12376   }
12377 
12378   return SDValue();
12379 }
12380 
12381 /// Return true if both inputs are at least as cheap in negated form and at
12382 /// least one input is strictly cheaper in negated form.
12383 bool DAGCombiner::isCheaperToUseNegatedFPOps(SDValue X, SDValue Y) {
12384   if (char LHSNeg =
12385           TLI.isNegatibleForFree(X, DAG, LegalOperations, ForCodeSize))
12386     if (char RHSNeg =
12387             TLI.isNegatibleForFree(Y, DAG, LegalOperations, ForCodeSize))
12388       // Both negated operands are at least as cheap as their counterparts.
12389       // Check to see if at least one is cheaper negated.
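      // (isNegatibleForFree returns 1 when the negated form costs the same,
      // and 2 when it is strictly cheaper.)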
12390       if (LHSNeg == 2 || RHSNeg == 2)
12391         return true;
12392 
12393   return false;
12394 }
12395 
12396 SDValue DAGCombiner::visitFMUL(SDNode *N) {
12397   SDValue N0 = N->getOperand(0);
12398   SDValue N1 = N->getOperand(1);
12399   ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, true);
12400   ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true);
12401   EVT VT = N->getValueType(0);
12402   SDLoc DL(N);
12403   const TargetOptions &Options = DAG.getTarget().Options;
12404   const SDNodeFlags Flags = N->getFlags();
12405 
12406   // fold vector ops
12407   if (VT.isVector()) {
12408     // This just handles C1 * C2 for vectors. Other vector folds are below.
12409     if (SDValue FoldedVOp = SimplifyVBinOp(N))
12410       return FoldedVOp;
12411   }
12412 
12413   // fold (fmul c1, c2) -> c1*c2
12414   if (N0CFP && N1CFP)
12415     return DAG.getNode(ISD::FMUL, DL, VT, N0, N1, Flags);
12416 
12417   // canonicalize constant to RHS
12418   if (isConstantFPBuildVectorOrConstantFP(N0) &&
12419      !isConstantFPBuildVectorOrConstantFP(N1))
12420     return DAG.getNode(ISD::FMUL, DL, VT, N1, N0, Flags);
12421 
12422   if (SDValue NewSel = foldBinOpIntoSelect(N))
12423     return NewSel;
12424 
12425   if ((Options.NoNaNsFPMath && Options.NoSignedZerosFPMath) ||
12426       (Flags.hasNoNaNs() && Flags.hasNoSignedZeros())) {
12427     // fold (fmul A, 0) -> 0
12428     if (N1CFP && N1CFP->isZero())
12429       return N1;
12430   }
12431 
12432   if (Options.UnsafeFPMath || Flags.hasAllowReassociation()) {
12433     // fmul (fmul X, C1), C2 -> fmul X, C1 * C2
12434     if (isConstantFPBuildVectorOrConstantFP(N1) &&
12435         N0.getOpcode() == ISD::FMUL) {
12436       SDValue N00 = N0.getOperand(0);
12437       SDValue N01 = N0.getOperand(1);
12438       // Avoid an infinite loop by making sure that N00 is not a constant
12439       // (the inner multiply has not been constant folded yet).
12440       if (isConstantFPBuildVectorOrConstantFP(N01) &&
12441           !isConstantFPBuildVectorOrConstantFP(N00)) {
12442         SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, N01, N1, Flags);
12443         return DAG.getNode(ISD::FMUL, DL, VT, N00, MulConsts, Flags);
12444       }
12445     }
12446 
12447     // Match a special-case: we convert X * 2.0 into fadd.
12448     // fmul (fadd X, X), C -> fmul X, 2.0 * C
12449     if (N0.getOpcode() == ISD::FADD && N0.hasOneUse() &&
12450         N0.getOperand(0) == N0.getOperand(1)) {
12451       const SDValue Two = DAG.getConstantFP(2.0, DL, VT);
12452       SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, Two, N1, Flags);
12453       return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), MulConsts, Flags);
12454     }
12455   }
12456 
12457   // fold (fmul X, 2.0) -> (fadd X, X)
12458   if (N1CFP && N1CFP->isExactlyValue(+2.0))
12459     return DAG.getNode(ISD::FADD, DL, VT, N0, N0, Flags);
12460 
12461   // fold (fmul X, -1.0) -> (fneg X)
12462   if (N1CFP && N1CFP->isExactlyValue(-1.0))
12463     if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))
12464       return DAG.getNode(ISD::FNEG, DL, VT, N0);
12465 
12466   // -N0 * -N1 --> N0 * N1
12467   if (isCheaperToUseNegatedFPOps(N0, N1)) {
12468     SDValue NegN0 =
12469         TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize);
12470     SDValue NegN1 =
12471         TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize);
12472     return DAG.getNode(ISD::FMUL, DL, VT, NegN0, NegN1, Flags);
12473   }
12474 
12475   // fold (fmul X, (select (fcmp X > 0.0), -1.0, 1.0)) -> (fneg (fabs X))
12476   // fold (fmul X, (select (fcmp X > 0.0), 1.0, -1.0)) -> (fabs X)
12477   if (Flags.hasNoNaNs() && Flags.hasNoSignedZeros() &&
12478       (N0.getOpcode() == ISD::SELECT || N1.getOpcode() == ISD::SELECT) &&
12479       TLI.isOperationLegal(ISD::FABS, VT)) {
12480     SDValue Select = N0, X = N1;
12481     if (Select.getOpcode() != ISD::SELECT)
12482       std::swap(Select, X);
12483 
12484     SDValue Cond = Select.getOperand(0);
12485     auto TrueOpnd  = dyn_cast<ConstantFPSDNode>(Select.getOperand(1));
12486     auto FalseOpnd = dyn_cast<ConstantFPSDNode>(Select.getOperand(2));
12487 
12488     if (TrueOpnd && FalseOpnd &&
12489         Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == X &&
12490         isa<ConstantFPSDNode>(Cond.getOperand(1)) &&
12491         cast<ConstantFPSDNode>(Cond.getOperand(1))->isExactlyValue(0.0)) {
12492       ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
12493       switch (CC) {
12494       default: break;
12495       case ISD::SETOLT:
12496       case ISD::SETULT:
12497       case ISD::SETOLE:
12498       case ISD::SETULE:
12499       case ISD::SETLT:
12500       case ISD::SETLE:
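        // Canonicalize the "less than" cases to the "greater than" handling
        // below by swapping the select arms.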
12501         std::swap(TrueOpnd, FalseOpnd);
12502         LLVM_FALLTHROUGH;
12503       case ISD::SETOGT:
12504       case ISD::SETUGT:
12505       case ISD::SETOGE:
12506       case ISD::SETUGE:
12507       case ISD::SETGT:
12508       case ISD::SETGE:
12509         if (TrueOpnd->isExactlyValue(-1.0) && FalseOpnd->isExactlyValue(1.0) &&
12510             TLI.isOperationLegal(ISD::FNEG, VT))
12511           return DAG.getNode(ISD::FNEG, DL, VT,
12512                    DAG.getNode(ISD::FABS, DL, VT, X));
12513         if (TrueOpnd->isExactlyValue(1.0) && FalseOpnd->isExactlyValue(-1.0))
12514           return DAG.getNode(ISD::FABS, DL, VT, X);
12515 
12516         break;
12517       }
12518     }
12519   }
12520 
12521   // FMUL -> FMA combines:
12522   if (SDValue Fused = visitFMULForFMADistributiveCombine(N)) {
12523     AddToWorklist(Fused.getNode());
12524     return Fused;
12525   }
12526 
12527   return SDValue();
12528 }
12529 
12530 SDValue DAGCombiner::visitFMA(SDNode *N) {
12531   SDValue N0 = N->getOperand(0);
12532   SDValue N1 = N->getOperand(1);
12533   SDValue N2 = N->getOperand(2);
12534   ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
12535   ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
12536   EVT VT = N->getValueType(0);
12537   SDLoc DL(N);
12538   const TargetOptions &Options = DAG.getTarget().Options;
12539 
12540   // FMA nodes have flags that propagate to the created nodes.
12541   const SDNodeFlags Flags = N->getFlags();
12542   bool UnsafeFPMath = Options.UnsafeFPMath || isContractable(N);
12543 
12544   // Constant fold FMA.
12545   if (isa<ConstantFPSDNode>(N0) &&
12546       isa<ConstantFPSDNode>(N1) &&
12547       isa<ConstantFPSDNode>(N2)) {
12548     return DAG.getNode(ISD::FMA, DL, VT, N0, N1, N2);
12549   }
12550 
12551   // (-N0 * -N1) + N2 --> (N0 * N1) + N2
12552   if (isCheaperToUseNegatedFPOps(N0, N1)) {
12553     SDValue NegN0 =
12554         TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize);
12555     SDValue NegN1 =
12556         TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize);
12557     return DAG.getNode(ISD::FMA, DL, VT, NegN0, NegN1, N2, Flags);
12558   }
12559 
12560   if (UnsafeFPMath) {
12561     if (N0CFP && N0CFP->isZero())
12562       return N2;
12563     if (N1CFP && N1CFP->isZero())
12564       return N2;
12565   }
12566   // TODO: The FMA node should have flags that propagate to these nodes.
12567   if (N0CFP && N0CFP->isExactlyValue(1.0))
12568     return DAG.getNode(ISD::FADD, SDLoc(N), VT, N1, N2);
12569   if (N1CFP && N1CFP->isExactlyValue(1.0))
12570     return DAG.getNode(ISD::FADD, SDLoc(N), VT, N0, N2);
12571 
12572   // Canonicalize (fma c, x, y) -> (fma x, c, y)
12573   if (isConstantFPBuildVectorOrConstantFP(N0) &&
12574      !isConstantFPBuildVectorOrConstantFP(N1))
12575     return DAG.getNode(ISD::FMA, SDLoc(N), VT, N1, N0, N2);
12576 
12577   if (UnsafeFPMath) {
12578     // (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2)
12579     if (N2.getOpcode() == ISD::FMUL && N0 == N2.getOperand(0) &&
12580         isConstantFPBuildVectorOrConstantFP(N1) &&
12581         isConstantFPBuildVectorOrConstantFP(N2.getOperand(1))) {
12582       return DAG.getNode(ISD::FMUL, DL, VT, N0,
12583                          DAG.getNode(ISD::FADD, DL, VT, N1, N2.getOperand(1),
12584                                      Flags), Flags);
12585     }
12586 
12587     // (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y)
12588     if (N0.getOpcode() == ISD::FMUL &&
12589         isConstantFPBuildVectorOrConstantFP(N1) &&
12590         isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) {
12591       return DAG.getNode(ISD::FMA, DL, VT,
12592                          N0.getOperand(0),
12593                          DAG.getNode(ISD::FMUL, DL, VT, N1, N0.getOperand(1),
12594                                      Flags),
12595                          N2);
12596     }
12597   }
12598 
12599   // (fma x, 1, y) -> (fadd x, y)
12600   // (fma x, -1, y) -> (fadd (fneg x), y)
12601   if (N1CFP) {
12602     if (N1CFP->isExactlyValue(1.0))
12603       // TODO: The FMA node should have flags that propagate to this node.
12604       return DAG.getNode(ISD::FADD, DL, VT, N0, N2);
12605 
12606     if (N1CFP->isExactlyValue(-1.0) &&
12607         (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))) {
12608       SDValue RHSNeg = DAG.getNode(ISD::FNEG, DL, VT, N0);
12609       AddToWorklist(RHSNeg.getNode());
12610       // TODO: The FMA node should have flags that propagate to this node.
12611       return DAG.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
12612     }
12613 
12614     // fma (fneg x), K, y -> fma x, -K, y
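    // This is profitable if FP constants are legal, or if K is not a legal
    // FP immediate and has a single use (negating it then adds no cost).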
12615     if (N0.getOpcode() == ISD::FNEG &&
12616         (TLI.isOperationLegal(ISD::ConstantFP, VT) ||
12617          (N1.hasOneUse() && !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT,
12618                                               ForCodeSize)))) {
12619       return DAG.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
12620                          DAG.getNode(ISD::FNEG, DL, VT, N1, Flags), N2);
12621     }
12622   }
12623 
12624   if (UnsafeFPMath) {
12625     // (fma x, c, x) -> (fmul x, (c+1))
12626     if (N1CFP && N0 == N2) {
12627       return DAG.getNode(ISD::FMUL, DL, VT, N0,
12628                          DAG.getNode(ISD::FADD, DL, VT, N1,
12629                                      DAG.getConstantFP(1.0, DL, VT), Flags),
12630                          Flags);
12631     }
12632 
12633     // (fma x, c, (fneg x)) -> (fmul x, (c-1))
12634     if (N1CFP && N2.getOpcode() == ISD::FNEG && N2.getOperand(0) == N0) {
12635       return DAG.getNode(ISD::FMUL, DL, VT, N0,
12636                          DAG.getNode(ISD::FADD, DL, VT, N1,
12637                                      DAG.getConstantFP(-1.0, DL, VT), Flags),
12638                          Flags);
12639     }
12640   }
12641 
12642   // fold (fma (fneg X), Y, (fneg Z)) -> (fneg (fma X, Y, Z))
12643   // fold (fma X, (fneg Y), (fneg Z)) -> (fneg (fma X, Y, Z))
12644   if (!TLI.isFNegFree(VT) &&
12645       TLI.isNegatibleForFree(SDValue(N, 0), DAG, LegalOperations,
12646                              ForCodeSize) == 2)
12647     return DAG.getNode(ISD::FNEG, DL, VT,
12648                        TLI.getNegatedExpression(SDValue(N, 0), DAG,
12649                                                 LegalOperations, ForCodeSize),
12650                        Flags);
12651   return SDValue();
12652 }
12653 
12654 // Combine multiple FDIVs with the same divisor into multiple FMULs by the
12655 // reciprocal.
12656 // E.g., (a / D; b / D;) -> (recip = 1.0 / D; a * recip; b * recip)
12657 // Notice that this is not always beneficial. One reason is that different
12658 // targets may have different costs for FDIV and FMUL, so sometimes the cost of
12659 // two FDIVs may be lower than the cost of one FDIV and two FMULs. Another
12660 // reason is that the critical path grows from "one FDIV" to "one FDIV + one FMUL".
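// Illustrative sketch (operand names invented): with a target minimum of two
// uses, this rewrites
//   t1 = fdiv a, D
//   t2 = fdiv b, D
// into
//   r  = fdiv 1.0, D
//   t1 = fmul a, r
//   t2 = fmul b, r
// trading two divisions for one division plus two multiplies.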
12661 SDValue DAGCombiner::combineRepeatedFPDivisors(SDNode *N) {
12662   // TODO: Limit this transform based on optsize/minsize - it always creates at
12663   //       least 1 extra instruction. But the perf win may be substantial enough
12664   //       that only minsize should restrict this.
12665   bool UnsafeMath = DAG.getTarget().Options.UnsafeFPMath;
12666   const SDNodeFlags Flags = N->getFlags();
12667   if (!UnsafeMath && !Flags.hasAllowReciprocal())
12668     return SDValue();
12669 
12670   // Skip if current node is a reciprocal/fneg-reciprocal.
12671   SDValue N0 = N->getOperand(0);
12672   ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, /* AllowUndefs */ true);
12673   if (N0CFP && (N0CFP->isExactlyValue(1.0) || N0CFP->isExactlyValue(-1.0)))
12674     return SDValue();
12675 
12676   // Exit early if the target does not want this transform or if there can't
12677   // possibly be enough uses of the divisor to make the transform worthwhile.
12678   SDValue N1 = N->getOperand(1);
12679   unsigned MinUses = TLI.combineRepeatedFPDivisors();
12680 
12681   // For splat vectors, scale the number of uses by the splat factor. If we can
12682   // convert the division into a scalar op, that will likely be much faster.
12683   unsigned NumElts = 1;
12684   EVT VT = N->getValueType(0);
12685   if (VT.isVector() && DAG.isSplatValue(N1))
12686     NumElts = VT.getVectorNumElements();
12687 
12688   if (!MinUses || (N1->use_size() * NumElts) < MinUses)
12689     return SDValue();
12690 
12691   // Find all FDIV users of the same divisor.
12692   // Use a set because duplicates may be present in the user list.
12693   SetVector<SDNode *> Users;
12694   for (auto *U : N1->uses()) {
12695     if (U->getOpcode() == ISD::FDIV && U->getOperand(1) == N1) {
12696       // This division is eligible for optimization only if global unsafe math
12697       // is enabled or if this division allows reciprocal formation.
12698       if (UnsafeMath || U->getFlags().hasAllowReciprocal())
12699         Users.insert(U);
12700     }
12701   }
12702 
12703   // Now that we have the actual number of divisor uses, make sure it meets
12704   // the minimum threshold specified by the target.
12705   if ((Users.size() * NumElts) < MinUses)
12706     return SDValue();
12707 
12708   SDLoc DL(N);
12709   SDValue FPOne = DAG.getConstantFP(1.0, DL, VT);
12710   SDValue Reciprocal = DAG.getNode(ISD::FDIV, DL, VT, FPOne, N1, Flags);
12711 
12712   // Dividend / Divisor -> Dividend * Reciprocal
12713   for (auto *U : Users) {
12714     SDValue Dividend = U->getOperand(0);
12715     if (Dividend != FPOne) {
12716       SDValue NewNode = DAG.getNode(ISD::FMUL, SDLoc(U), VT, Dividend,
12717                                     Reciprocal, Flags);
12718       CombineTo(U, NewNode);
12719     } else if (U != Reciprocal.getNode()) {
12720       // In the absence of fast-math-flags, this user node is always the
12721       // same node as Reciprocal, but with FMF they may be different nodes.
12722       CombineTo(U, Reciprocal);
12723     }
12724   }
12725   return SDValue(N, 0);  // N was replaced.
12726 }
12727 
12728 SDValue DAGCombiner::visitFDIV(SDNode *N) {
12729   SDValue N0 = N->getOperand(0);
12730   SDValue N1 = N->getOperand(1);
12731   ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
12732   ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
12733   EVT VT = N->getValueType(0);
12734   SDLoc DL(N);
12735   const TargetOptions &Options = DAG.getTarget().Options;
12736   SDNodeFlags Flags = N->getFlags();
12737 
12738   // fold vector ops
12739   if (VT.isVector())
12740     if (SDValue FoldedVOp = SimplifyVBinOp(N))
12741       return FoldedVOp;
12742 
12743   // fold (fdiv c1, c2) -> c1/c2
12744   if (N0CFP && N1CFP)
12745     return DAG.getNode(ISD::FDIV, SDLoc(N), VT, N0, N1, Flags);
12746 
12747   if (SDValue NewSel = foldBinOpIntoSelect(N))
12748     return NewSel;
12749 
12750   if (SDValue V = combineRepeatedFPDivisors(N))
12751     return V;
12752 
12753   if (Options.UnsafeFPMath || Flags.hasAllowReciprocal()) {
12754     // fold (fdiv X, c2) -> fmul X, 1/c2 if losing precision is acceptable.
12755     if (N1CFP) {
12756       // Compute the reciprocal 1.0 / c2.
12757       const APFloat &N1APF = N1CFP->getValueAPF();
12758       APFloat Recip(N1APF.getSemantics(), 1); // 1.0
12759       APFloat::opStatus st = Recip.divide(N1APF, APFloat::rmNearestTiesToEven);
12760       // Only do the transform if the reciprocal is a legal fp immediate that
12761       // isn't too nasty (eg NaN, denormal, ...).
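      // For example, X / 4.0 has an exact reciprocal (opOK) and becomes
      // X * 0.25, while X / 3.0 produces a rounded reciprocal (opInexact),
      // which is still accepted under the relaxed-precision flags above.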
12762       if ((st == APFloat::opOK || st == APFloat::opInexact) && // Not too nasty
12763           (!LegalOperations ||
12764            // FIXME: custom lowering of ConstantFP might fail (see e.g. ARM
12765            // backend)... we should handle this gracefully after Legalize.
12766            // TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT) ||
12767            TLI.isOperationLegal(ISD::ConstantFP, VT) ||
12768            TLI.isFPImmLegal(Recip, VT, ForCodeSize)))
12769         return DAG.getNode(ISD::FMUL, DL, VT, N0,
12770                            DAG.getConstantFP(Recip, DL, VT), Flags);
12771     }
12772 
12773     // If this FDIV is part of a reciprocal square root, it may be folded
12774     // into a target-specific square root estimate instruction.
12775     if (N1.getOpcode() == ISD::FSQRT) {
12776       if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0), Flags))
12777         return DAG.getNode(ISD::FMUL, DL, VT, N0, RV, Flags);
12778     } else if (N1.getOpcode() == ISD::FP_EXTEND &&
12779                N1.getOperand(0).getOpcode() == ISD::FSQRT) {
12780       if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0).getOperand(0),
12781                                           Flags)) {
12782         RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(N1), VT, RV);
12783         AddToWorklist(RV.getNode());
12784         return DAG.getNode(ISD::FMUL, DL, VT, N0, RV, Flags);
12785       }
12786     } else if (N1.getOpcode() == ISD::FP_ROUND &&
12787                N1.getOperand(0).getOpcode() == ISD::FSQRT) {
12788       if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0).getOperand(0),
12789                                           Flags)) {
12790         RV = DAG.getNode(ISD::FP_ROUND, SDLoc(N1), VT, RV, N1.getOperand(1));
12791         AddToWorklist(RV.getNode());
12792         return DAG.getNode(ISD::FMUL, DL, VT, N0, RV, Flags);
12793       }
12794     } else if (N1.getOpcode() == ISD::FMUL) {
12795       // Look through an FMUL. Even though this won't remove the FDIV directly,
12796       // it's still worthwhile to get rid of the FSQRT if possible.
12797       SDValue SqrtOp;
12798       SDValue OtherOp;
12799       if (N1.getOperand(0).getOpcode() == ISD::FSQRT) {
12800         SqrtOp = N1.getOperand(0);
12801         OtherOp = N1.getOperand(1);
12802       } else if (N1.getOperand(1).getOpcode() == ISD::FSQRT) {
12803         SqrtOp = N1.getOperand(1);
12804         OtherOp = N1.getOperand(0);
12805       }
12806       if (SqrtOp.getNode()) {
12807         // We found a FSQRT, so try to make this fold:
12808         // x / (y * sqrt(z)) -> x * (rsqrt(z) / y)
12809         if (SDValue RV = buildRsqrtEstimate(SqrtOp.getOperand(0), Flags)) {
12810           RV = DAG.getNode(ISD::FDIV, SDLoc(N1), VT, RV, OtherOp, Flags);
12811           AddToWorklist(RV.getNode());
12812           return DAG.getNode(ISD::FMUL, DL, VT, N0, RV, Flags);
12813         }
12814       }
12815     }
12816 
12817     // Fold into a reciprocal estimate and multiply instead of a real divide.
12818     if (SDValue RV = BuildDivEstimate(N0, N1, Flags))
12819       return RV;
12820   }
12821 
12822   // (fdiv (fneg X), (fneg Y)) -> (fdiv X, Y)
12823   if (isCheaperToUseNegatedFPOps(N0, N1))
12824     return DAG.getNode(
12825         ISD::FDIV, SDLoc(N), VT,
12826         TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize),
12827         TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize), Flags);
12828 
12829   return SDValue();
12830 }
12831 
12832 SDValue DAGCombiner::visitFREM(SDNode *N) {
12833   SDValue N0 = N->getOperand(0);
12834   SDValue N1 = N->getOperand(1);
12835   ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
12836   ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
12837   EVT VT = N->getValueType(0);
12838 
12839   // fold (frem c1, c2) -> fmod(c1,c2)
12840   if (N0CFP && N1CFP)
12841     return DAG.getNode(ISD::FREM, SDLoc(N), VT, N0, N1, N->getFlags());
12842 
12843   if (SDValue NewSel = foldBinOpIntoSelect(N))
12844     return NewSel;
12845 
12846   return SDValue();
12847 }
12848 
12849 SDValue DAGCombiner::visitFSQRT(SDNode *N) {
12850   SDNodeFlags Flags = N->getFlags();
12851   if (!DAG.getTarget().Options.UnsafeFPMath &&
12852       !Flags.hasApproximateFuncs())
12853     return SDValue();
12854 
12855   SDValue N0 = N->getOperand(0);
12856   if (TLI.isFsqrtCheap(N0, DAG))
12857     return SDValue();
12858 
12859   // FSQRT nodes have flags that propagate to the created nodes.
12860   return buildSqrtEstimate(N0, Flags);
12861 }
12862 
12863 /// copysign(x, fp_extend(y)) -> copysign(x, y)
12864 /// copysign(x, fp_round(y)) -> copysign(x, y)
12865 static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(SDNode *N) {
12866   SDValue N1 = N->getOperand(1);
12867   if ((N1.getOpcode() == ISD::FP_EXTEND ||
12868        N1.getOpcode() == ISD::FP_ROUND)) {
12869     // Do not optimize out type conversion of f128 type yet.
12870     // For some targets like x86_64, configuration is changed to keep one f128
12871     // value in one SSE register, but instruction selection cannot handle
12872     // FCOPYSIGN on SSE registers yet.
12873     EVT N1VT = N1->getValueType(0);
12874     EVT N1Op0VT = N1->getOperand(0).getValueType();
12875     return (N1VT == N1Op0VT || N1Op0VT != MVT::f128);
12876   }
12877   return false;
12878 }
12879 
12880 SDValue DAGCombiner::visitFCOPYSIGN(SDNode *N) {
12881   SDValue N0 = N->getOperand(0);
12882   SDValue N1 = N->getOperand(1);
12883   bool N0CFP = isConstantFPBuildVectorOrConstantFP(N0);
12884   bool N1CFP = isConstantFPBuildVectorOrConstantFP(N1);
12885   EVT VT = N->getValueType(0);
12886 
12887   if (N0CFP && N1CFP) // Constant fold
12888     return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0, N1);
12889 
12890   if (ConstantFPSDNode *N1C = isConstOrConstSplatFP(N->getOperand(1))) {
12891     const APFloat &V = N1C->getValueAPF();
12892     // copysign(x, c1) -> fabs(x)       iff ispos(c1)
12893     // copysign(x, c1) -> fneg(fabs(x)) iff isneg(c1)
12894     if (!V.isNegative()) {
12895       if (!LegalOperations || TLI.isOperationLegal(ISD::FABS, VT))
12896         return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0);
12897     } else {
12898       if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))
12899         return DAG.getNode(ISD::FNEG, SDLoc(N), VT,
12900                            DAG.getNode(ISD::FABS, SDLoc(N0), VT, N0));
12901     }
12902   }
12903 
12904   // copysign(fabs(x), y) -> copysign(x, y)
12905   // copysign(fneg(x), y) -> copysign(x, y)
12906   // copysign(copysign(x,z), y) -> copysign(x, y)
12907   if (N0.getOpcode() == ISD::FABS || N0.getOpcode() == ISD::FNEG ||
12908       N0.getOpcode() == ISD::FCOPYSIGN)
12909     return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0.getOperand(0), N1);
12910 
12911   // copysign(x, abs(y)) -> abs(x)
12912   if (N1.getOpcode() == ISD::FABS)
12913     return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0);
12914 
12915   // copysign(x, copysign(y,z)) -> copysign(x, z)
12916   if (N1.getOpcode() == ISD::FCOPYSIGN)
12917     return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0, N1.getOperand(1));
12918 
12919   // copysign(x, fp_extend(y)) -> copysign(x, y)
12920   // copysign(x, fp_round(y)) -> copysign(x, y)
12921   if (CanCombineFCOPYSIGN_EXTEND_ROUND(N))
12922     return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0, N1.getOperand(0));
12923 
12924   return SDValue();
12925 }
12926 
12927 SDValue DAGCombiner::visitFPOW(SDNode *N) {
12928   ConstantFPSDNode *ExponentC = isConstOrConstSplatFP(N->getOperand(1));
12929   if (!ExponentC)
12930     return SDValue();
12931 
12932   // Try to convert x ** (1/3) into cube root.
12933   // TODO: Handle the various flavors of long double.
12934   // TODO: Since we're approximating, we don't need an exact 1/3 exponent.
12935   //       Some range near 1/3 should be fine.
12936   EVT VT = N->getValueType(0);
12937   if ((VT == MVT::f32 && ExponentC->getValueAPF().isExactlyValue(1.0f/3.0f)) ||
12938       (VT == MVT::f64 && ExponentC->getValueAPF().isExactlyValue(1.0/3.0))) {
12939     // pow(-0.0, 1/3) = +0.0; cbrt(-0.0) = -0.0.
12940     // pow(-inf, 1/3) = +inf; cbrt(-inf) = -inf.
12941     // pow(-val, 1/3) =  nan; cbrt(-val) = -cbrt(val).
12942     // For regular numbers, rounding may cause the results to differ.
12943     // Therefore, we require { nsz ninf nnan afn } for this transform.
12944     // TODO: We could select out the special cases if we don't have nsz/ninf.
12945     SDNodeFlags Flags = N->getFlags();
12946     if (!Flags.hasNoSignedZeros() || !Flags.hasNoInfs() || !Flags.hasNoNaNs() ||
12947         !Flags.hasApproximateFuncs())
12948       return SDValue();
12949 
12950     // Do not create a cbrt() libcall if the target does not have it, and do not
12951     // turn a pow that has lowering support into a cbrt() libcall.
12952     if (!DAG.getLibInfo().has(LibFunc_cbrt) ||
12953         (!DAG.getTargetLoweringInfo().isOperationExpand(ISD::FPOW, VT) &&
12954          DAG.getTargetLoweringInfo().isOperationExpand(ISD::FCBRT, VT)))
12955       return SDValue();
12956 
12957     return DAG.getNode(ISD::FCBRT, SDLoc(N), VT, N->getOperand(0), Flags);
12958   }
12959 
12960   // Try to convert x ** (1/4) and x ** (3/4) into square roots.
12961   // x ** (1/2) is canonicalized to sqrt, so we do not bother with that case.
12962   // TODO: This could be extended (using a target hook) to handle smaller
12963   // power-of-2 fractional exponents.
12964   bool ExponentIs025 = ExponentC->getValueAPF().isExactlyValue(0.25);
12965   bool ExponentIs075 = ExponentC->getValueAPF().isExactlyValue(0.75);
12966   if (ExponentIs025 || ExponentIs075) {
12967     // pow(-0.0, 0.25) = +0.0; sqrt(sqrt(-0.0)) = -0.0.
12968     // pow(-inf, 0.25) = +inf; sqrt(sqrt(-inf)) =  NaN.
12969     // pow(-0.0, 0.75) = +0.0; sqrt(-0.0) * sqrt(sqrt(-0.0)) = +0.0.
12970     // pow(-inf, 0.75) = +inf; sqrt(-inf) * sqrt(sqrt(-inf)) =  NaN.
12971     // For regular numbers, rounding may cause the results to differ.
12972     // Therefore, we require { nsz ninf afn } for this transform.
12973     // TODO: We could select out the special cases if we don't have nsz/ninf.
12974     SDNodeFlags Flags = N->getFlags();
12975 
12976     // We only need no signed zeros for the 0.25 case.
12977     if ((!Flags.hasNoSignedZeros() && ExponentIs025) || !Flags.hasNoInfs() ||
12978         !Flags.hasApproximateFuncs())
12979       return SDValue();
12980 
12981     // Don't double the number of libcalls. We are trying to inline fast code.
12982     if (!DAG.getTargetLoweringInfo().isOperationLegalOrCustom(ISD::FSQRT, VT))
12983       return SDValue();
12984 
12985     // Assume that libcalls are the smallest code.
12986     // TODO: This restriction should probably be lifted for vectors.
12987     if (ForCodeSize)
12988       return SDValue();
12989 
12990     // pow(X, 0.25) --> sqrt(sqrt(X))
12991     SDLoc DL(N);
12992     SDValue Sqrt = DAG.getNode(ISD::FSQRT, DL, VT, N->getOperand(0), Flags);
12993     SDValue SqrtSqrt = DAG.getNode(ISD::FSQRT, DL, VT, Sqrt, Flags);
12994     if (ExponentIs025)
12995       return SqrtSqrt;
12996     // pow(X, 0.75) --> sqrt(X) * sqrt(sqrt(X))
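    // (X**0.75 == X**0.5 * X**0.25, so reusing Sqrt and SqrtSqrt costs only
    // one extra multiply.)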
12997     return DAG.getNode(ISD::FMUL, DL, VT, Sqrt, SqrtSqrt, Flags);
12998   }
12999 
13000   return SDValue();
13001 }
13002 
13003 static SDValue foldFPToIntToFP(SDNode *N, SelectionDAG &DAG,
13004                                const TargetLowering &TLI) {
13005   // This optimization is guarded by a function attribute because it may produce
13006   // unexpected results. I.e., programs may be relying on the platform-specific
13007   // undefined behavior when the float-to-int conversion overflows.
13008   const Function &F = DAG.getMachineFunction().getFunction();
13009   Attribute StrictOverflow = F.getFnAttribute("strict-float-cast-overflow");
13010   if (StrictOverflow.getValueAsString().equals("false"))
13011     return SDValue();
13012 
13013   // We only do this if the target has legal ftrunc. Otherwise, we'd likely be
13014   // replacing casts with a libcall. We also must be allowed to ignore -0.0
13015   // because FTRUNC will return -0.0 for (-1.0, -0.0), but using integer
13016   // conversions would return +0.0.
13017   // FIXME: We should be able to use node-level FMF here.
13018   // TODO: If strict math, should we use FABS (+ range check for signed cast)?
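  // For example (illustrative): ftrunc(-0.5) is -0.0, but the round trip
  // (double)(int64_t)-0.5 is +0.0, so the fold is only sound when signed
  // zeros can be ignored.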
13019   EVT VT = N->getValueType(0);
13020   if (!TLI.isOperationLegal(ISD::FTRUNC, VT) ||
13021       !DAG.getTarget().Options.NoSignedZerosFPMath)
13022     return SDValue();
13023 
13024   // fptosi/fptoui round towards zero, so converting from FP to integer and
13025   // back is the same as an 'ftrunc': [us]itofp (fpto[us]i X) --> ftrunc X
13026   SDValue N0 = N->getOperand(0);
13027   if (N->getOpcode() == ISD::SINT_TO_FP && N0.getOpcode() == ISD::FP_TO_SINT &&
13028       N0.getOperand(0).getValueType() == VT)
13029     return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0.getOperand(0));
13030 
13031   if (N->getOpcode() == ISD::UINT_TO_FP && N0.getOpcode() == ISD::FP_TO_UINT &&
13032       N0.getOperand(0).getValueType() == VT)
13033     return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0.getOperand(0));
13034 
13035   return SDValue();
13036 }
13037 
13038 SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) {
13039   SDValue N0 = N->getOperand(0);
13040   EVT VT = N->getValueType(0);
13041   EVT OpVT = N0.getValueType();
13042 
13043   // [us]itofp(undef) = 0, because the result value is bounded.
13044   if (N0.isUndef())
13045     return DAG.getConstantFP(0.0, SDLoc(N), VT);
13046 
13047   // fold (sint_to_fp c1) -> c1fp
13048   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
13049       // ...but only if the target supports immediate floating-point values
13050       (!LegalOperations ||
13051        TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
13052     return DAG.getNode(ISD::SINT_TO_FP, SDLoc(N), VT, N0);
13053 
13054   // If the input is a legal type, and SINT_TO_FP is not legal on this target,
13055   // but UINT_TO_FP is legal on this target, try to convert.
13056   if (!hasOperation(ISD::SINT_TO_FP, OpVT) &&
13057       hasOperation(ISD::UINT_TO_FP, OpVT)) {
13058     // If the sign bit is known to be zero, we can change this to UINT_TO_FP.
13059     if (DAG.SignBitIsZero(N0))
13060       return DAG.getNode(ISD::UINT_TO_FP, SDLoc(N), VT, N0);
13061   }
13062 
13063   // The next optimizations are desirable only if SELECT_CC can be lowered.
13064   if (TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT) || !LegalOperations) {
13065     // fold (sint_to_fp (setcc x, y, cc)) -> (select_cc x, y, -1.0, 0.0, cc)
13066     if (N0.getOpcode() == ISD::SETCC && N0.getValueType() == MVT::i1 &&
13067         !VT.isVector() &&
13068         (!LegalOperations ||
13069          TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
13070       SDLoc DL(N);
13071       SDValue Ops[] =
13072         { N0.getOperand(0), N0.getOperand(1),
13073           DAG.getConstantFP(-1.0, DL, VT), DAG.getConstantFP(0.0, DL, VT),
13074           N0.getOperand(2) };
13075       return DAG.getNode(ISD::SELECT_CC, DL, VT, Ops);
13076     }
13077 
13078     // fold (sint_to_fp (zext (setcc x, y, cc))) ->
13079     //      (select_cc x, y, 1.0, 0.0, cc)
13080     if (N0.getOpcode() == ISD::ZERO_EXTEND &&
13081         N0.getOperand(0).getOpcode() == ISD::SETCC && !VT.isVector() &&
13082         (!LegalOperations ||
13083          TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
13084       SDLoc DL(N);
13085       SDValue Ops[] =
13086         { N0.getOperand(0).getOperand(0), N0.getOperand(0).getOperand(1),
13087           DAG.getConstantFP(1.0, DL, VT), DAG.getConstantFP(0.0, DL, VT),
13088           N0.getOperand(0).getOperand(2) };
13089       return DAG.getNode(ISD::SELECT_CC, DL, VT, Ops);
13090     }
13091   }
13092 
13093   if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI))
13094     return FTrunc;
13095 
13096   return SDValue();
13097 }
13098 
13099 SDValue DAGCombiner::visitUINT_TO_FP(SDNode *N) {
13100   SDValue N0 = N->getOperand(0);
13101   EVT VT = N->getValueType(0);
13102   EVT OpVT = N0.getValueType();
13103 
13104   // [us]itofp(undef) = 0, because the result value is bounded.
13105   if (N0.isUndef())
13106     return DAG.getConstantFP(0.0, SDLoc(N), VT);
13107 
13108   // fold (uint_to_fp c1) -> c1fp
13109   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
13110       // ...but only if the target supports immediate floating-point values
13111       (!LegalOperations ||
13112        TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
13113     return DAG.getNode(ISD::UINT_TO_FP, SDLoc(N), VT, N0);
13114 
13115   // If the input is a legal type, and UINT_TO_FP is not legal on this target,
13116   // but SINT_TO_FP is legal on this target, try to convert.
13117   if (!hasOperation(ISD::UINT_TO_FP, OpVT) &&
13118       hasOperation(ISD::SINT_TO_FP, OpVT)) {
13119     // If the sign bit is known to be zero, we can change this to SINT_TO_FP.
13120     if (DAG.SignBitIsZero(N0))
13121       return DAG.getNode(ISD::SINT_TO_FP, SDLoc(N), VT, N0);
13122   }
13123 
13124   // The next optimizations are desirable only if SELECT_CC can be lowered.
13125   if (TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT) || !LegalOperations) {
13126     // fold (uint_to_fp (setcc x, y, cc)) -> (select_cc x, y, 1.0, 0.0, cc)
13127     if (N0.getOpcode() == ISD::SETCC && !VT.isVector() &&
13128         (!LegalOperations ||
13129          TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
13130       SDLoc DL(N);
13131       SDValue Ops[] =
13132         { N0.getOperand(0), N0.getOperand(1),
13133           DAG.getConstantFP(1.0, DL, VT), DAG.getConstantFP(0.0, DL, VT),
13134           N0.getOperand(2) };
13135       return DAG.getNode(ISD::SELECT_CC, DL, VT, Ops);
13136     }
13137   }
13138 
13139   if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI))
13140     return FTrunc;
13141 
13142   return SDValue();
13143 }
13144 
13145 // Fold (fp_to_{s/u}int ({s/u}int_to_fp x)) -> zext x, sext x, trunc x, or x
13146 static SDValue FoldIntToFPToInt(SDNode *N, SelectionDAG &DAG) {
13147   SDValue N0 = N->getOperand(0);
13148   EVT VT = N->getValueType(0);
13149 
13150   if (N0.getOpcode() != ISD::UINT_TO_FP && N0.getOpcode() != ISD::SINT_TO_FP)
13151     return SDValue();
13152 
13153   SDValue Src = N0.getOperand(0);
13154   EVT SrcVT = Src.getValueType();
13155   bool IsInputSigned = N0.getOpcode() == ISD::SINT_TO_FP;
13156   bool IsOutputSigned = N->getOpcode() == ISD::FP_TO_SINT;
13157 
13158   // We can safely assume the conversion won't overflow the output range,
13159   // because (for example) (uint8_t)18293.f is undefined behavior.
13160 
13161   // Since we can assume the conversion won't overflow, our decision as to
13162   // whether the input will fit in the float should depend on the minimum
13163   // of the input range and output range.
13164 
13165   // This means this is also safe for a signed input and unsigned output, since
13166   // a negative input would lead to undefined behavior.
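  // Worked example (illustrative): for i32 -> f32 -> i32 with a signed input,
  // InputSize is 31 bits but f32 carries only 24 bits of precision, so the
  // fold is rejected; for i16 -> f32 -> i32 it is safe and becomes a sext.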
13167   unsigned InputSize = (int)SrcVT.getScalarSizeInBits() - IsInputSigned;
13168   unsigned OutputSize = (int)VT.getScalarSizeInBits() - IsOutputSigned;
13169   unsigned ActualSize = std::min(InputSize, OutputSize);
13170   const fltSemantics &sem = DAG.EVTToAPFloatSemantics(N0.getValueType());
13171 
13172   // We can only fold away the float conversion if the input range can be
13173   // represented exactly in the float range.
13174   if (APFloat::semanticsPrecision(sem) >= ActualSize) {
13175     if (VT.getScalarSizeInBits() > SrcVT.getScalarSizeInBits()) {
13176       unsigned ExtOp = IsInputSigned && IsOutputSigned ? ISD::SIGN_EXTEND
13177                                                        : ISD::ZERO_EXTEND;
13178       return DAG.getNode(ExtOp, SDLoc(N), VT, Src);
13179     }
13180     if (VT.getScalarSizeInBits() < SrcVT.getScalarSizeInBits())
13181       return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Src);
13182     return DAG.getBitcast(VT, Src);
13183   }
13184   return SDValue();
13185 }
13186 
13187 SDValue DAGCombiner::visitFP_TO_SINT(SDNode *N) {
13188   SDValue N0 = N->getOperand(0);
13189   EVT VT = N->getValueType(0);
13190 
13191   // fold (fp_to_sint undef) -> undef
13192   if (N0.isUndef())
13193     return DAG.getUNDEF(VT);
13194 
13195   // fold (fp_to_sint c1fp) -> c1
13196   if (isConstantFPBuildVectorOrConstantFP(N0))
13197     return DAG.getNode(ISD::FP_TO_SINT, SDLoc(N), VT, N0);
13198 
13199   return FoldIntToFPToInt(N, DAG);
13200 }
13201 
13202 SDValue DAGCombiner::visitFP_TO_UINT(SDNode *N) {
13203   SDValue N0 = N->getOperand(0);
13204   EVT VT = N->getValueType(0);
13205 
13206   // fold (fp_to_uint undef) -> undef
13207   if (N0.isUndef())
13208     return DAG.getUNDEF(VT);
13209 
13210   // fold (fp_to_uint c1fp) -> c1
13211   if (isConstantFPBuildVectorOrConstantFP(N0))
13212     return DAG.getNode(ISD::FP_TO_UINT, SDLoc(N), VT, N0);
13213 
13214   return FoldIntToFPToInt(N, DAG);
13215 }
13216 
13217 SDValue DAGCombiner::visitFP_ROUND(SDNode *N) {
13218   SDValue N0 = N->getOperand(0);
13219   SDValue N1 = N->getOperand(1);
13220   ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
13221   EVT VT = N->getValueType(0);
13222 
13223   // fold (fp_round c1fp) -> c1fp
13224   if (N0CFP)
13225     return DAG.getNode(ISD::FP_ROUND, SDLoc(N), VT, N0, N1);
13226 
13227   // fold (fp_round (fp_extend x)) -> x
13228   if (N0.getOpcode() == ISD::FP_EXTEND && VT == N0.getOperand(0).getValueType())
13229     return N0.getOperand(0);
13230 
13231   // fold (fp_round (fp_round x)) -> (fp_round x)
13232   if (N0.getOpcode() == ISD::FP_ROUND) {
13233     const bool NIsTrunc = N->getConstantOperandVal(1) == 1;
13234     const bool N0IsTrunc = N0.getConstantOperandVal(1) == 1;
13235 
13236     // Skip this folding if it results in an fp_round from f80 to f16.
13237     //
13238     // f80 to f16 always generates an expensive (and as yet, unimplemented)
13239     // libcall to __truncxfhf2 instead of selecting native f16 conversion
13240     // instructions from f32 or f64.  Moreover, the first (value-preserving)
13241     // fp_round from f80 to either f32 or f64 may become a NOP on platforms like
13242     // x86.
13243     if (N0.getOperand(0).getValueType() == MVT::f80 && VT == MVT::f16)
13244       return SDValue();
13245 
13246     // If the first fp_round isn't a value preserving truncation, it might
13247     // introduce a tie in the second fp_round, that wouldn't occur in the
13248     // single-step fp_round we want to fold to.
13249     // In other words, double rounding isn't the same as rounding.
13250     // Also, this is a value preserving truncation iff both fp_round's are.
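    // Illustrative example: an f64 value rounded first to f32 can land exactly
    // halfway between two adjacent f16 values; the second rounding then breaks
    // a tie that a direct f64 -> f16 rounding would never encounter.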
13251     if (DAG.getTarget().Options.UnsafeFPMath || N0IsTrunc) {
13252       SDLoc DL(N);
13253       return DAG.getNode(ISD::FP_ROUND, DL, VT, N0.getOperand(0),
13254                          DAG.getIntPtrConstant(NIsTrunc && N0IsTrunc, DL));
13255     }
13256   }
13257 
13258   // fold (fp_round (copysign X, Y)) -> (copysign (fp_round X), Y)
13259   if (N0.getOpcode() == ISD::FCOPYSIGN && N0.getNode()->hasOneUse()) {
13260     SDValue Tmp = DAG.getNode(ISD::FP_ROUND, SDLoc(N0), VT,
13261                               N0.getOperand(0), N1);
13262     AddToWorklist(Tmp.getNode());
13263     return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT,
13264                        Tmp, N0.getOperand(1));
13265   }
13266 
13267   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
13268     return NewVSel;
13269 
13270   return SDValue();
13271 }
13272 
13273 SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) {
13274   SDValue N0 = N->getOperand(0);
13275   EVT VT = N->getValueType(0);
13276 
13277   // If this is fp_round(fp_extend), don't fold it; allow ourselves to be folded.
13278   if (N->hasOneUse() &&
13279       N->use_begin()->getOpcode() == ISD::FP_ROUND)
13280     return SDValue();
13281 
13282   // fold (fp_extend c1fp) -> c1fp
13283   if (isConstantFPBuildVectorOrConstantFP(N0))
13284     return DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, N0);
13285 
13286   // fold (fp_extend (fp16_to_fp op)) -> (fp16_to_fp op)
13287   if (N0.getOpcode() == ISD::FP16_TO_FP &&
13288       TLI.getOperationAction(ISD::FP16_TO_FP, VT) == TargetLowering::Legal)
13289     return DAG.getNode(ISD::FP16_TO_FP, SDLoc(N), VT, N0.getOperand(0));
13290 
13291   // Turn fp_extend(fp_round(X, 1)) -> X since the fp_round doesn't affect the
13292   // value of X.
13293   if (N0.getOpcode() == ISD::FP_ROUND &&
13294       N0.getConstantOperandVal(1) == 1) {
13295     SDValue In = N0.getOperand(0);
13296     if (In.getValueType() == VT) return In;
13297     if (VT.bitsLT(In.getValueType()))
13298       return DAG.getNode(ISD::FP_ROUND, SDLoc(N), VT,
13299                          In, N0.getOperand(1));
13300     return DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, In);
13301   }
13302 
13303   // fold (fpext (load x)) -> (fpext (fptrunc (extload x)))
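  // The extending load produces the wide value for this fp_extend directly;
  // other users of the original narrow load are handed fp_round(ExtLoad) (a
  // value-preserving truncation), and chain users follow the new load's chain.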
13304   if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
13305        TLI.isLoadExtLegal(ISD::EXTLOAD, VT, N0.getValueType())) {
13306     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
13307     SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, SDLoc(N), VT,
13308                                      LN0->getChain(),
13309                                      LN0->getBasePtr(), N0.getValueType(),
13310                                      LN0->getMemOperand());
13311     CombineTo(N, ExtLoad);
13312     CombineTo(N0.getNode(),
13313               DAG.getNode(ISD::FP_ROUND, SDLoc(N0),
13314                           N0.getValueType(), ExtLoad,
13315                           DAG.getIntPtrConstant(1, SDLoc(N0))),
13316               ExtLoad.getValue(1));
13317     return SDValue(N, 0);   // Return N so it doesn't get rechecked!
13318   }
13319 
13320   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
13321     return NewVSel;
13322 
13323   return SDValue();
13324 }
13325 
13326 SDValue DAGCombiner::visitFCEIL(SDNode *N) {
13327   SDValue N0 = N->getOperand(0);
13328   EVT VT = N->getValueType(0);
13329 
13330   // fold (fceil c1) -> fceil(c1)
13331   if (isConstantFPBuildVectorOrConstantFP(N0))
13332     return DAG.getNode(ISD::FCEIL, SDLoc(N), VT, N0);
13333 
13334   return SDValue();
13335 }
13336 
13337 SDValue DAGCombiner::visitFTRUNC(SDNode *N) {
13338   SDValue N0 = N->getOperand(0);
13339   EVT VT = N->getValueType(0);
13340 
13341   // fold (ftrunc c1) -> ftrunc(c1)
13342   if (isConstantFPBuildVectorOrConstantFP(N0))
13343     return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0);
13344 
13345   // fold ftrunc (known rounded int x) -> x
13346   // ftrunc is part of the fptosi/fptoui expansion on some targets, so it is
13347   // likely to be generated when extracting an integer from a rounded value.
13348   switch (N0.getOpcode()) {
13349   default: break;
13350   case ISD::FRINT:
13351   case ISD::FTRUNC:
13352   case ISD::FNEARBYINT:
13353   case ISD::FFLOOR:
13354   case ISD::FCEIL:
13355     return N0;
13356   }
13357 
13358   return SDValue();
13359 }
13360 
13361 SDValue DAGCombiner::visitFFLOOR(SDNode *N) {
13362   SDValue N0 = N->getOperand(0);
13363   EVT VT = N->getValueType(0);
13364 
13365   // fold (ffloor c1) -> ffloor(c1)
13366   if (isConstantFPBuildVectorOrConstantFP(N0))
13367     return DAG.getNode(ISD::FFLOOR, SDLoc(N), VT, N0);
13368 
13369   return SDValue();
13370 }
13371 
13372 // FIXME: FNEG and FABS have a lot in common; refactor.
13373 SDValue DAGCombiner::visitFNEG(SDNode *N) {
13374   SDValue N0 = N->getOperand(0);
13375   EVT VT = N->getValueType(0);
13376 
13377   // Constant fold FNEG.
13378   if (isConstantFPBuildVectorOrConstantFP(N0))
13379     return DAG.getNode(ISD::FNEG, SDLoc(N), VT, N0);
13380 
13381   if (TLI.isNegatibleForFree(N0, DAG, LegalOperations, ForCodeSize))
13382     return TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize);
13383 
13384   // -(X-Y) -> (Y-X) is unsafe because when X==Y, -0.0 != +0.0.
13385   // FIXME: This is duplicated in isNegatibleForFree, but isNegatibleForFree doesn't
13386   // know it was called from a context with a nsz flag if the input fsub does not.
13387   if (N0.getOpcode() == ISD::FSUB &&
13388       (DAG.getTarget().Options.NoSignedZerosFPMath ||
13389        N->getFlags().hasNoSignedZeros()) && N0.hasOneUse()) {
13390     return DAG.getNode(ISD::FSUB, SDLoc(N), VT, N0.getOperand(1),
13391                        N0.getOperand(0), N->getFlags());
13392   }
13393 
13394   // Transform fneg(bitconvert(x)) -> bitconvert(x ^ sign) to avoid loading
13395   // constant pool values.
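  // E.g. (illustrative), for a scalar f32:
  //   fneg(bitcast i32 x) -> bitcast (xor x, 0x80000000)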
13396   if (!TLI.isFNegFree(VT) &&
13397       N0.getOpcode() == ISD::BITCAST &&
13398       N0.getNode()->hasOneUse()) {
13399     SDValue Int = N0.getOperand(0);
13400     EVT IntVT = Int.getValueType();
13401     if (IntVT.isInteger() && !IntVT.isVector()) {
13402       APInt SignMask;
13403       if (N0.getValueType().isVector()) {
13404         // For a vector, get a mask such as 0x80... per scalar element
13405         // and splat it.
13406         SignMask = APInt::getSignMask(N0.getScalarValueSizeInBits());
13407         SignMask = APInt::getSplat(IntVT.getSizeInBits(), SignMask);
13408       } else {
13409         // For a scalar, just generate 0x80...
13410         SignMask = APInt::getSignMask(IntVT.getSizeInBits());
13411       }
13412       SDLoc DL0(N0);
13413       Int = DAG.getNode(ISD::XOR, DL0, IntVT, Int,
13414                         DAG.getConstant(SignMask, DL0, IntVT));
13415       AddToWorklist(Int.getNode());
13416       return DAG.getBitcast(VT, Int);
13417     }
13418   }
13419 
13420   // (fneg (fmul c, x)) -> (fmul -c, x)
13421   if (N0.getOpcode() == ISD::FMUL &&
13422       (N0.getNode()->hasOneUse() || !TLI.isFNegFree(VT))) {
13423     ConstantFPSDNode *CFP1 = dyn_cast<ConstantFPSDNode>(N0.getOperand(1));
13424     if (CFP1) {
13425       APFloat CVal = CFP1->getValueAPF();
13426       CVal.changeSign();
13427       if (LegalDAG && (TLI.isFPImmLegal(CVal, VT, ForCodeSize) ||
13428                        TLI.isOperationLegal(ISD::ConstantFP, VT)))
13429         return DAG.getNode(
13430             ISD::FMUL, SDLoc(N), VT, N0.getOperand(0),
13431             DAG.getNode(ISD::FNEG, SDLoc(N), VT, N0.getOperand(1)),
13432             N0->getFlags());
13433     }
13434   }
13435 
13436   return SDValue();
13437 }
13438 
13439 static SDValue visitFMinMax(SelectionDAG &DAG, SDNode *N,
13440                             APFloat (*Op)(const APFloat &, const APFloat &)) {
13441   SDValue N0 = N->getOperand(0);
13442   SDValue N1 = N->getOperand(1);
13443   EVT VT = N->getValueType(0);
13444   const ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0);
13445   const ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1);
13446 
13447   if (N0CFP && N1CFP) {
13448     const APFloat &C0 = N0CFP->getValueAPF();
13449     const APFloat &C1 = N1CFP->getValueAPF();
13450     return DAG.getConstantFP(Op(C0, C1), SDLoc(N), VT);
13451   }
13452 
13453   // Canonicalize to constant on RHS.
13454   if (isConstantFPBuildVectorOrConstantFP(N0) &&
13455       !isConstantFPBuildVectorOrConstantFP(N1))
13456     return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0);
13457 
13458   return SDValue();
13459 }
13460 
13461 SDValue DAGCombiner::visitFMINNUM(SDNode *N) {
13462   return visitFMinMax(DAG, N, minnum);
13463 }
13464 
13465 SDValue DAGCombiner::visitFMAXNUM(SDNode *N) {
13466   return visitFMinMax(DAG, N, maxnum);
13467 }
13468 
13469 SDValue DAGCombiner::visitFMINIMUM(SDNode *N) {
13470   return visitFMinMax(DAG, N, minimum);
13471 }
13472 
13473 SDValue DAGCombiner::visitFMAXIMUM(SDNode *N) {
13474   return visitFMinMax(DAG, N, maximum);
13475 }
13476 
13477 SDValue DAGCombiner::visitFABS(SDNode *N) {
13478   SDValue N0 = N->getOperand(0);
13479   EVT VT = N->getValueType(0);
13480 
13481   // fold (fabs c1) -> fabs(c1)
13482   if (isConstantFPBuildVectorOrConstantFP(N0))
13483     return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0);
13484 
13485   // fold (fabs (fabs x)) -> (fabs x)
13486   if (N0.getOpcode() == ISD::FABS)
13487     return N->getOperand(0);
13488 
13489   // fold (fabs (fneg x)) -> (fabs x)
13490   // fold (fabs (fcopysign x, y)) -> (fabs x)
13491   if (N0.getOpcode() == ISD::FNEG || N0.getOpcode() == ISD::FCOPYSIGN)
13492     return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0.getOperand(0));
13493 
13494   // fabs(bitcast(x)) -> bitcast(x & ~sign) to avoid constant pool loads.
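  // E.g. (illustrative), for a scalar f32:
  //   fabs(bitcast i32 x) -> bitcast (and x, 0x7fffffff)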
13495   if (!TLI.isFAbsFree(VT) && N0.getOpcode() == ISD::BITCAST && N0.hasOneUse()) {
13496     SDValue Int = N0.getOperand(0);
13497     EVT IntVT = Int.getValueType();
13498     if (IntVT.isInteger() && !IntVT.isVector()) {
13499       APInt SignMask;
13500       if (N0.getValueType().isVector()) {
13501         // For a vector, get a mask such as 0x7f... per scalar element
13502         // and splat it.
13503         SignMask = ~APInt::getSignMask(N0.getScalarValueSizeInBits());
13504         SignMask = APInt::getSplat(IntVT.getSizeInBits(), SignMask);
13505       } else {
13506         // For a scalar, just generate 0x7f...
13507         SignMask = ~APInt::getSignMask(IntVT.getSizeInBits());
13508       }
13509       SDLoc DL(N0);
13510       Int = DAG.getNode(ISD::AND, DL, IntVT, Int,
13511                         DAG.getConstant(SignMask, DL, IntVT));
13512       AddToWorklist(Int.getNode());
13513       return DAG.getBitcast(N->getValueType(0), Int);
13514     }
13515   }
13516 
13517   return SDValue();
13518 }
13519 
13520 SDValue DAGCombiner::visitBRCOND(SDNode *N) {
13521   SDValue Chain = N->getOperand(0);
13522   SDValue N1 = N->getOperand(1);
13523   SDValue N2 = N->getOperand(2);
13524 
13525   // If N is a constant we could fold this into a fallthrough or unconditional
13526   // branch. However that doesn't happen very often in normal code, because
13527   // Instcombine/SimplifyCFG should have handled the available opportunities.
13528   // If we did this folding here, it would be necessary to update the
13529   // MachineBasicBlock CFG, which is awkward.
13530 
13531   // fold a brcond with a setcc condition into a BR_CC node if BR_CC is legal
13532   // on the target.
13533   if (N1.getOpcode() == ISD::SETCC &&
13534       TLI.isOperationLegalOrCustom(ISD::BR_CC,
13535                                    N1.getOperand(0).getValueType())) {
13536     return DAG.getNode(ISD::BR_CC, SDLoc(N), MVT::Other,
13537                        Chain, N1.getOperand(2),
13538                        N1.getOperand(0), N1.getOperand(1), N2);
13539   }
13540 
13541   if (N1.hasOneUse()) {
13542     if (SDValue NewN1 = rebuildSetCC(N1))
13543       return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other, Chain, NewN1, N2);
13544   }
13545 
13546   return SDValue();
13547 }
13548 
13549 SDValue DAGCombiner::rebuildSetCC(SDValue N) {
13550   if (N.getOpcode() == ISD::SRL ||
13551       (N.getOpcode() == ISD::TRUNCATE &&
13552        (N.getOperand(0).hasOneUse() &&
13553         N.getOperand(0).getOpcode() == ISD::SRL))) {
13554     // Look past the truncate.
13555     if (N.getOpcode() == ISD::TRUNCATE)
13556       N = N.getOperand(0);
13557 
13558     // Match this pattern so that we can generate simpler code:
13559     //
13560     //   %a = ...
13561     //   %b = and i32 %a, 2
13562     //   %c = srl i32 %b, 1
13563     //   brcond i32 %c ...
13564     //
13565     // into
13566     //
13567     //   %a = ...
13568     //   %b = and i32 %a, 2
13569     //   %c = setcc eq %b, 0
13570     //   brcond %c ...
13571     //
13572     // This applies only when the AND constant value has one bit set and the
13573     // SRL constant is equal to the log2 of the AND constant. The back-end is
13574     // smart enough to convert the result into a TEST/JMP sequence.
13575     SDValue Op0 = N.getOperand(0);
13576     SDValue Op1 = N.getOperand(1);
13577 
13578     if (Op0.getOpcode() == ISD::AND && Op1.getOpcode() == ISD::Constant) {
13579       SDValue AndOp1 = Op0.getOperand(1);
13580 
13581       if (AndOp1.getOpcode() == ISD::Constant) {
13582         const APInt &AndConst = cast<ConstantSDNode>(AndOp1)->getAPIntValue();
13583 
13584         if (AndConst.isPowerOf2() &&
13585             cast<ConstantSDNode>(Op1)->getAPIntValue() == AndConst.logBase2()) {
13586           SDLoc DL(N);
13587           return DAG.getSetCC(DL, getSetCCResultType(Op0.getValueType()),
13588                               Op0, DAG.getConstant(0, DL, Op0.getValueType()),
13589                               ISD::SETNE);
13590         }
13591       }
13592     }
13593   }
13594 
13595   // Transform br(xor(x, y)) -> br(x != y)
13596   // Transform br(xor(xor(x,y), 1)) -> br (x == y)
13597   if (N.getOpcode() == ISD::XOR) {
13598     // Because we may call this on a speculatively constructed
13599     // SimplifiedSetCC Node, we need to simplify this node first.
13600     // Ideally this should be folded into SimplifySetCC and not
13601     // here. For now, grab a handle to N so we don't lose it from
13602     // replacements internal to the visit.
13603     HandleSDNode XORHandle(N);
13604     while (N.getOpcode() == ISD::XOR) {
13605       SDValue Tmp = visitXOR(N.getNode());
13606       // No simplification done.
13607       if (!Tmp.getNode())
13608         break;
13609       // Returning N is a form of in-visit replacement that may have
13610       // invalidated N. Grab the value from the handle.
13611       if (Tmp.getNode() == N.getNode())
13612         N = XORHandle.getValue();
13613       else // Node simplified. Try simplifying again.
13614         N = Tmp;
13615     }
13616 
13617     if (N.getOpcode() != ISD::XOR)
13618       return N;
13619 
13620     SDNode *TheXor = N.getNode();
13621 
13622     SDValue Op0 = TheXor->getOperand(0);
13623     SDValue Op1 = TheXor->getOperand(1);
13624 
13625     if (Op0.getOpcode() != ISD::SETCC && Op1.getOpcode() != ISD::SETCC) {
13626       bool Equal = false;
13627       if (isOneConstant(Op0) && Op0.hasOneUse() &&
13628           Op0.getOpcode() == ISD::XOR) {
13629         TheXor = Op0.getNode();
13630         Equal = true;
13631       }
13632 
13633       EVT SetCCVT = N.getValueType();
13634       if (LegalTypes)
13635         SetCCVT = getSetCCResultType(SetCCVT);
13636       // Replace the uses of XOR with SETCC
13637       return DAG.getSetCC(SDLoc(TheXor), SetCCVT, Op0, Op1,
13638                           Equal ? ISD::SETEQ : ISD::SETNE);
13639     }
13640   }
13641 
13642   return SDValue();
13643 }
13644 
13645 // Operand List for BR_CC: Chain, CondCC, CondLHS, CondRHS, DestBB.
13646 //
13647 SDValue DAGCombiner::visitBR_CC(SDNode *N) {
13648   CondCodeSDNode *CC = cast<CondCodeSDNode>(N->getOperand(1));
13649   SDValue CondLHS = N->getOperand(2), CondRHS = N->getOperand(3);
13650 
13651   // If N is a constant we could fold this into a fallthrough or unconditional
13652   // branch. However that doesn't happen very often in normal code, because
13653   // Instcombine/SimplifyCFG should have handled the available opportunities.
13654   // If we did this folding here, it would be necessary to update the
13655   // MachineBasicBlock CFG, which is awkward.
13656 
13657   // Use SimplifySetCC to simplify SETCC's.
13658   SDValue Simp = SimplifySetCC(getSetCCResultType(CondLHS.getValueType()),
13659                                CondLHS, CondRHS, CC->get(), SDLoc(N),
13660                                false);
13661   if (Simp.getNode()) AddToWorklist(Simp.getNode());
13662 
13663   // fold to a simpler setcc
13664   if (Simp.getNode() && Simp.getOpcode() == ISD::SETCC)
13665     return DAG.getNode(ISD::BR_CC, SDLoc(N), MVT::Other,
13666                        N->getOperand(0), Simp.getOperand(2),
13667                        Simp.getOperand(0), Simp.getOperand(1),
13668                        N->getOperand(4));
13669 
13670   return SDValue();
13671 }
13672 
13673 /// Return true if 'Use' is a load or a store that uses N as its base pointer
13674 /// and that N may be folded in the load / store addressing mode.
13675 static bool canFoldInAddressingMode(SDNode *N, SDNode *Use,
13676                                     SelectionDAG &DAG,
13677                                     const TargetLowering &TLI) {
13678   EVT VT;
13679   unsigned AS;
13680 
13681   if (LoadSDNode *LD = dyn_cast<LoadSDNode>(Use)) {
13682     if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
13683       return false;
13684     VT = LD->getMemoryVT();
13685     AS = LD->getAddressSpace();
13686   } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(Use)) {
13687     if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
13688       return false;
13689     VT = ST->getMemoryVT();
13690     AS = ST->getAddressSpace();
13691   } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(Use)) {
13692     if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
13693       return false;
13694     VT = LD->getMemoryVT();
13695     AS = LD->getAddressSpace();
13696   } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(Use)) {
13697     if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
13698       return false;
13699     VT = ST->getMemoryVT();
13700     AS = ST->getAddressSpace();
13701   } else
13702     return false;
13703 
13704   TargetLowering::AddrMode AM;
13705   if (N->getOpcode() == ISD::ADD) {
13706     AM.HasBaseReg = true;
13707     ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(N->getOperand(1));
13708     if (Offset)
13709       // [reg +/- imm]
13710       AM.BaseOffs = Offset->getSExtValue();
13711     else
13712       // [reg +/- reg]
13713       AM.Scale = 1;
13714   } else if (N->getOpcode() == ISD::SUB) {
13715     AM.HasBaseReg = true;
13716     ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(N->getOperand(1));
13717     if (Offset)
13718       // [reg +/- imm]
13719       AM.BaseOffs = -Offset->getSExtValue();
13720     else
13721       // [reg +/- reg]
13722       AM.Scale = 1;
13723   } else
13724     return false;
13725 
13726   return TLI.isLegalAddressingMode(DAG.getDataLayout(), AM,
13727                                    VT.getTypeForEVT(*DAG.getContext()), AS);
13728 }
13729 
13730 static bool getCombineLoadStoreParts(SDNode *N, unsigned Inc, unsigned Dec,
13731                                      bool &IsLoad, bool &IsMasked, SDValue &Ptr,
13732                                      const TargetLowering &TLI) {
13733   if (LoadSDNode *LD = dyn_cast<LoadSDNode>(N)) {
13734     if (LD->isIndexed())
13735       return false;
13736     EVT VT = LD->getMemoryVT();
13737     if (!TLI.isIndexedLoadLegal(Inc, VT) && !TLI.isIndexedLoadLegal(Dec, VT))
13738       return false;
13739     Ptr = LD->getBasePtr();
13740   } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(N)) {
13741     if (ST->isIndexed())
13742       return false;
13743     EVT VT = ST->getMemoryVT();
13744     if (!TLI.isIndexedStoreLegal(Inc, VT) && !TLI.isIndexedStoreLegal(Dec, VT))
13745       return false;
13746     Ptr = ST->getBasePtr();
13747     IsLoad = false;
13748   } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(N)) {
13749     if (LD->isIndexed())
13750       return false;
13751     EVT VT = LD->getMemoryVT();
13752     if (!TLI.isIndexedMaskedLoadLegal(Inc, VT) &&
13753         !TLI.isIndexedMaskedLoadLegal(Dec, VT))
13754       return false;
13755     Ptr = LD->getBasePtr();
13756     IsMasked = true;
13757   } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(N)) {
13758     if (ST->isIndexed())
13759       return false;
13760     EVT VT = ST->getMemoryVT();
13761     if (!TLI.isIndexedMaskedStoreLegal(Inc, VT) &&
13762         !TLI.isIndexedMaskedStoreLegal(Dec, VT))
13763       return false;
13764     Ptr = ST->getBasePtr();
13765     IsLoad = false;
13766     IsMasked = true;
13767   } else {
13768     return false;
13769   }
13770   return true;
13771 }
13772 
13773 /// Try turning a load/store into a pre-indexed load/store when the base
13774 /// pointer is an add or subtract and it has other uses besides the load/store.
13775 /// After the transformation, the new indexed load/store has effectively folded
13776 /// the add/subtract in and all of its other uses are redirected to the
13777 /// new load/store.
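/// Illustrative sketch (target-dependent): on a machine with pre-indexed
/// stores,
///   t0 = add x, 8
///   store v -> t0            ; plus other uses of t0
/// becomes a single pre-indexed store that writes v to [x + 8] and also
/// produces the updated address for t0's remaining uses.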
13778 bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) {
13779   if (Level < AfterLegalizeDAG)
13780     return false;
13781 
13782   bool IsLoad = true;
13783   bool IsMasked = false;
13784   SDValue Ptr;
13785   if (!getCombineLoadStoreParts(N, ISD::PRE_INC, ISD::PRE_DEC, IsLoad, IsMasked,
13786                                 Ptr, TLI))
13787     return false;
13788 
13789   // If the pointer is not an add/sub, or if it doesn't have multiple uses, bail
13790   // out.  There is no reason to make this a preinc/predec.
13791   if ((Ptr.getOpcode() != ISD::ADD && Ptr.getOpcode() != ISD::SUB) ||
13792       Ptr.getNode()->hasOneUse())
13793     return false;
13794 
13795   // Ask the target to do addressing mode selection.
13796   SDValue BasePtr;
13797   SDValue Offset;
13798   ISD::MemIndexedMode AM = ISD::UNINDEXED;
13799   if (!TLI.getPreIndexedAddressParts(N, BasePtr, Offset, AM, DAG))
13800     return false;
13801 
13802   // Backends without true r+i pre-indexed forms may need to pass a
13803   // constant base with a variable offset so that constant coercion
13804   // will work with the patterns in canonical form.
13805   bool Swapped = false;
13806   if (isa<ConstantSDNode>(BasePtr)) {
13807     std::swap(BasePtr, Offset);
13808     Swapped = true;
13809   }
13810 
13811   // Don't create an indexed load / store with zero offset.
13812   if (isNullConstant(Offset))
13813     return false;
13814 
13815   // Try turning it into a pre-indexed load / store except when:
13816   // 1) The new base ptr is a frame index.
13817   // 2) If N is a store and the new base ptr is either the same as or is a
13818   //    predecessor of the value being stored.
13819   // 3) Another use of the old base ptr is a predecessor of N. If ptr were
13820   //    folded, that would create a cycle.
13821   // 4) All uses are load / store ops that use it as old base ptr.
13822 
13823   // Check #1.  Preinc'ing a frame index would require copying the stack pointer
13824   // (plus the implicit offset) to a register to preinc anyway.
13825   if (isa<FrameIndexSDNode>(BasePtr) || isa<RegisterSDNode>(BasePtr))
13826     return false;
13827 
13828   // Check #2.
13829   if (!IsLoad) {
13830     SDValue Val = IsMasked ? cast<MaskedStoreSDNode>(N)->getValue()
13831                            : cast<StoreSDNode>(N)->getValue();
13832 
13833     // Would require a copy.
13834     if (Val == BasePtr)
13835       return false;
13836 
13837     // Would create a cycle.
13838     if (Val == Ptr || Ptr->isPredecessorOf(Val.getNode()))
13839       return false;
13840   }
13841 
13842   // Caches for hasPredecessorHelper.
13843   SmallPtrSet<const SDNode *, 32> Visited;
13844   SmallVector<const SDNode *, 16> Worklist;
13845   Worklist.push_back(N);
13846 
13847   // If the offset is a constant, there may be other adds of constants that
13848   // can be folded with this one. We should do this to avoid having to keep
13849   // a copy of the original base pointer.
13850   SmallVector<SDNode *, 16> OtherUses;
13851   if (isa<ConstantSDNode>(Offset))
13852     for (SDNode::use_iterator UI = BasePtr.getNode()->use_begin(),
13853                               UE = BasePtr.getNode()->use_end();
13854          UI != UE; ++UI) {
13855       SDUse &Use = UI.getUse();
13856       // Skip the use that is Ptr and uses of other results from BasePtr's
13857       // node (important for nodes that return multiple results).
13858       if (Use.getUser() == Ptr.getNode() || Use != BasePtr)
13859         continue;
13860 
13861       if (SDNode::hasPredecessorHelper(Use.getUser(), Visited, Worklist))
13862         continue;
13863 
13864       if (Use.getUser()->getOpcode() != ISD::ADD &&
13865           Use.getUser()->getOpcode() != ISD::SUB) {
13866         OtherUses.clear();
13867         break;
13868       }
13869 
13870       SDValue Op1 = Use.getUser()->getOperand((UI.getOperandNo() + 1) & 1);
13871       if (!isa<ConstantSDNode>(Op1)) {
13872         OtherUses.clear();
13873         break;
13874       }
13875 
13876       // FIXME: In some cases, we can be smarter about this.
13877       if (Op1.getValueType() != Offset.getValueType()) {
13878         OtherUses.clear();
13879         break;
13880       }
13881 
13882       OtherUses.push_back(Use.getUser());
13883     }
13884 
13885   if (Swapped)
13886     std::swap(BasePtr, Offset);
13887 
13888   // Now check for #3 and #4.
13889   bool RealUse = false;
13890 
13891   for (SDNode *Use : Ptr.getNode()->uses()) {
13892     if (Use == N)
13893       continue;
13894     if (SDNode::hasPredecessorHelper(Use, Visited, Worklist))
13895       return false;
13896 
13897     // A use is "real" only if Ptr cannot be folded into its addressing mode;
13898     // if every use can fold Ptr, this transformation is not profitable.
13899     if (!canFoldInAddressingMode(Ptr.getNode(), Use, DAG, TLI))
13900       RealUse = true;
13901   }
13902 
13903   if (!RealUse)
13904     return false;
13905 
13906   SDValue Result;
13907   if (!IsMasked) {
13908     if (IsLoad)
13909       Result = DAG.getIndexedLoad(SDValue(N, 0), SDLoc(N), BasePtr, Offset, AM);
13910     else
13911       Result =
13912           DAG.getIndexedStore(SDValue(N, 0), SDLoc(N), BasePtr, Offset, AM);
13913   } else {
13914     if (IsLoad)
13915       Result = DAG.getIndexedMaskedLoad(SDValue(N, 0), SDLoc(N), BasePtr,
13916                                         Offset, AM);
13917     else
13918       Result = DAG.getIndexedMaskedStore(SDValue(N, 0), SDLoc(N), BasePtr,
13919                                          Offset, AM);
13920   }
13921   ++PreIndexedNodes;
13922   ++NodesCombined;
13923   LLVM_DEBUG(dbgs() << "\nReplacing.4 "; N->dump(&DAG); dbgs() << "\nWith: ";
13924              Result.getNode()->dump(&DAG); dbgs() << '\n');
13925   WorklistRemover DeadNodes(*this);
13926   if (IsLoad) {
13927     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0));
13928     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Result.getValue(2));
13929   } else {
13930     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(1));
13931   }
13932 
13933   // Finally, since the node is now dead, remove it from the graph.
13934   deleteAndRecombine(N);
13935 
13936   if (Swapped)
13937     std::swap(BasePtr, Offset);
13938 
13939   // Replace other uses of BasePtr that can be updated to use Ptr.
13940   for (unsigned i = 0, e = OtherUses.size(); i != e; ++i) {
13941     unsigned OffsetIdx = 1;
13942     if (OtherUses[i]->getOperand(OffsetIdx).getNode() == BasePtr.getNode())
13943       OffsetIdx = 0;
13944     assert(OtherUses[i]->getOperand(!OffsetIdx).getNode() ==
13945            BasePtr.getNode() && "Expected BasePtr operand");
13946 
13947     // We need to replace ptr0 in the following expression:
13948     //   x0 * offset0 + y0 * ptr0 = t0
13949     // knowing that
13950     //   x1 * offset1 + y1 * ptr0 = t1 (the indexed load/store)
13951     //
13952     // where x0, x1, y0 and y1 in {-1, 1} are given by the types of the
13953     // indexed load/store and the expression that needs to be rewritten.
13954     //
13955     // Therefore, we have:
13956     //   t0 = (x0 * offset0 - x1 * y0 * y1 * offset1) + (y0 * y1) * t1
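    // For example (an illustrative case): if the indexed access computes
    //   t1 = ptr0 + offset1   (x1 = 1, y1 = 1)
    // and the other use is
    //   t0 = ptr0 - offset0   (x0 = -1, y0 = 1),
    // the formula gives t0 = (-offset0 - offset1) + t1, i.e. an ADD of t1
    // and the folded constant -(offset0 + offset1).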
13957 
13958     ConstantSDNode *CN =
13959       cast<ConstantSDNode>(OtherUses[i]->getOperand(OffsetIdx));
13960     int X0, X1, Y0, Y1;
13961     const APInt &Offset0 = CN->getAPIntValue();
13962     APInt Offset1 = cast<ConstantSDNode>(Offset)->getAPIntValue();
13963 
13964     X0 = (OtherUses[i]->getOpcode() == ISD::SUB && OffsetIdx == 1) ? -1 : 1;
13965     Y0 = (OtherUses[i]->getOpcode() == ISD::SUB && OffsetIdx == 0) ? -1 : 1;
13966     X1 = (AM == ISD::PRE_DEC && !Swapped) ? -1 : 1;
13967     Y1 = (AM == ISD::PRE_DEC && Swapped) ? -1 : 1;
13968 
13969     unsigned Opcode = (Y0 * Y1 < 0) ? ISD::SUB : ISD::ADD;
13970 
13971     APInt CNV = Offset0;
13972     if (X0 < 0) CNV = -CNV;
13973     if (X1 * Y0 * Y1 < 0) CNV = CNV + Offset1;
13974     else CNV = CNV - Offset1;
13975 
13976     SDLoc DL(OtherUses[i]);
13977 
13978     // We can now generate the new expression.
13979     SDValue NewOp1 = DAG.getConstant(CNV, DL, CN->getValueType(0));
13980     SDValue NewOp2 = Result.getValue(IsLoad ? 1 : 0);
13981 
13982     SDValue NewUse = DAG.getNode(Opcode,
13983                                  DL,
13984                                  OtherUses[i]->getValueType(0), NewOp1, NewOp2);
13985     DAG.ReplaceAllUsesOfValueWith(SDValue(OtherUses[i], 0), NewUse);
13986     deleteAndRecombine(OtherUses[i]);
13987   }
13988 
13989   // Replace the uses of Ptr with uses of the updated base value.
13990   DAG.ReplaceAllUsesOfValueWith(Ptr, Result.getValue(IsLoad ? 1 : 0));
13991   deleteAndRecombine(Ptr.getNode());
13992   AddToWorklist(Result.getNode());
13993 
13994   return true;
13995 }
13996 
13997 /// Try to combine a load/store with an add/sub of the base pointer node into
13998 /// a post-indexed load/store. The transformation effectively folds the
13999 /// add/subtract into the new indexed load/store, and all uses of the add/sub
14000 /// are redirected to the new load/store.
14001 bool DAGCombiner::CombineToPostIndexedLoadStore(SDNode *N) {
14002   if (Level < AfterLegalizeDAG)
14003     return false;
14004 
14005   bool IsLoad = true;
14006   bool IsMasked = false;
14007   SDValue Ptr;
14008   if (!getCombineLoadStoreParts(N, ISD::POST_INC, ISD::POST_DEC, IsLoad, IsMasked,
14009                                 Ptr, TLI))
14010     return false;
14011 
14012   if (Ptr.getNode()->hasOneUse())
14013     return false;
14014 
14015   for (SDNode *Op : Ptr.getNode()->uses()) {
14016     if (Op == N ||
14017         (Op->getOpcode() != ISD::ADD && Op->getOpcode() != ISD::SUB))
14018       continue;
14019 
14020     SDValue BasePtr;
14021     SDValue Offset;
14022     ISD::MemIndexedMode AM = ISD::UNINDEXED;
14023     if (TLI.getPostIndexedAddressParts(N, Op, BasePtr, Offset, AM, DAG)) {
14024       // Don't create an indexed load / store with a zero offset.
14025       if (isNullConstant(Offset))
14026         continue;
14027 
14028       // Try turning it into a post-indexed load / store except when
14029       // 1) All uses are load / store ops that use it as base ptr (and
14030       //    it may be folded in the addressing mode).
14031       // 2) Op must be independent of N, i.e. Op is neither a predecessor
14032       //    nor a successor of N. Otherwise, if Op is folded, that would
14033       //    create a cycle.
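      // For example (an illustrative case), on a target with post-increment
      // addressing:
      //   x  = load p        ; N
      //   p' = add p, 4      ; Op
      // can become a single post-indexed load that produces both x and
      // p' = p + 4.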
14034 
14035       if (isa<FrameIndexSDNode>(BasePtr) || isa<RegisterSDNode>(BasePtr))
14036         continue;
14037 
14038       // Check for #1.
14039       bool TryNext = false;
14040       for (SDNode *Use : BasePtr.getNode()->uses()) {
14041         if (Use == Ptr.getNode())
14042           continue;
14043 
14044         // If all the uses are load / store addresses, then don't do the
14045         // transformation.
14046         if (Use->getOpcode() == ISD::ADD || Use->getOpcode() == ISD::SUB) {
14047           bool RealUse = false;
14048           for (SDNode *UseUse : Use->uses()) {
14049             if (!canFoldInAddressingMode(Use, UseUse, DAG, TLI))
14050               RealUse = true;
14051           }
14052 
14053           if (!RealUse) {
14054             TryNext = true;
14055             break;
14056           }
14057         }
14058       }
14059 
14060       if (TryNext)
14061         continue;
14062 
14063       // Check for #2.
14064       SmallPtrSet<const SDNode *, 32> Visited;
14065       SmallVector<const SDNode *, 8> Worklist;
14066       // Ptr is a predecessor of both N and Op.
14067       Visited.insert(Ptr.getNode());
14068       Worklist.push_back(N);
14069       Worklist.push_back(Op);
14070       if (!SDNode::hasPredecessorHelper(N, Visited, Worklist) &&
14071           !SDNode::hasPredecessorHelper(Op, Visited, Worklist)) {
14072         SDValue Result;
14073         if (!IsMasked)
14074           Result = IsLoad ? DAG.getIndexedLoad(SDValue(N, 0), SDLoc(N), BasePtr,
14075                                                Offset, AM)
14076                           : DAG.getIndexedStore(SDValue(N, 0), SDLoc(N),
14077                                                 BasePtr, Offset, AM);
14078         else
14079           Result = IsLoad ? DAG.getIndexedMaskedLoad(SDValue(N, 0), SDLoc(N),
14080                                                      BasePtr, Offset, AM)
14081                           : DAG.getIndexedMaskedStore(SDValue(N, 0), SDLoc(N),
14082                                                       BasePtr, Offset, AM);
14083         ++PostIndexedNodes;
14084         ++NodesCombined;
14085         LLVM_DEBUG(dbgs() << "\nReplacing.5 "; N->dump(&DAG);
14086                    dbgs() << "\nWith: "; Result.getNode()->dump(&DAG);
14087                    dbgs() << '\n');
14088         WorklistRemover DeadNodes(*this);
14089         if (IsLoad) {
14090           DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0));
14091           DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Result.getValue(2));
14092         } else {
14093           DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(1));
14094         }
14095 
14096         // Finally, since the node is now dead, remove it from the graph.
14097         deleteAndRecombine(N);
14098 
14099         // Replace the uses of Op with uses of the updated base value.
14100         DAG.ReplaceAllUsesOfValueWith(SDValue(Op, 0),
14101                                       Result.getValue(IsLoad ? 1 : 0));
14102         deleteAndRecombine(Op);
14103         return true;
14104       }
14105     }
14106   }
14107 
14108   return false;
14109 }
14110 
14111 /// Return the base-pointer arithmetic from an indexed \p LD.
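/// For example (an illustrative case), for a PRE_INC load with base BP and
/// increment Inc, this returns the node (add BP, Inc) so that the writeback
/// arithmetic can be reused on its own.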
14112 SDValue DAGCombiner::SplitIndexingFromLoad(LoadSDNode *LD) {
14113   ISD::MemIndexedMode AM = LD->getAddressingMode();
14114   assert(AM != ISD::UNINDEXED);
14115   SDValue BP = LD->getOperand(1);
14116   SDValue Inc = LD->getOperand(2);
14117 
14118   // Some backends use TargetConstants for load offsets, but don't expect
14119   // TargetConstants in general ADD nodes. We can convert these constants into
14120   // regular Constants (if the constant is not opaque).
14121   assert((Inc.getOpcode() != ISD::TargetConstant ||
14122           !cast<ConstantSDNode>(Inc)->isOpaque()) &&
14123          "Cannot split out indexing using opaque target constants");
14124   if (Inc.getOpcode() == ISD::TargetConstant) {
14125     ConstantSDNode *ConstInc = cast<ConstantSDNode>(Inc);
14126     Inc = DAG.getConstant(*ConstInc->getConstantIntValue(), SDLoc(Inc),
14127                           ConstInc->getValueType(0));
14128   }
14129 
14130   unsigned Opc =
14131       (AM == ISD::PRE_INC || AM == ISD::POST_INC ? ISD::ADD : ISD::SUB);
14132   return DAG.getNode(Opc, SDLoc(LD), BP.getSimpleValueType(), BP, Inc);
14133 }
14134 
14135 static inline int numVectorEltsOrZero(EVT T) {
14136   return T.isVector() ? T.getVectorNumElements() : 0;
14137 }
14138 
14139 bool DAGCombiner::getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val) {
14140   Val = ST->getValue();
14141   EVT STType = Val.getValueType();
14142   EVT STMemType = ST->getMemoryVT();
14143   if (STType == STMemType)
14144     return true;
14145   if (isTypeLegal(STMemType))
14146     return false; // fail.
14147   if (STType.isFloatingPoint() && STMemType.isFloatingPoint() &&
14148       TLI.isOperationLegal(ISD::FTRUNC, STMemType)) {
14149     Val = DAG.getNode(ISD::FTRUNC, SDLoc(ST), STMemType, Val);
14150     return true;
14151   }
14152   if (numVectorEltsOrZero(STType) == numVectorEltsOrZero(STMemType) &&
14153       STType.isInteger() && STMemType.isInteger()) {
14154     Val = DAG.getNode(ISD::TRUNCATE, SDLoc(ST), STMemType, Val);
14155     return true;
14156   }
14157   if (STType.getSizeInBits() == STMemType.getSizeInBits()) {
14158     Val = DAG.getBitcast(STMemType, Val);
14159     return true;
14160   }
14161   return false; // fail.
14162 }
14163 
14164 bool DAGCombiner::extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val) {
14165   EVT LDMemType = LD->getMemoryVT();
14166   EVT LDType = LD->getValueType(0);
14167   assert(Val.getValueType() == LDMemType &&
14168          "Attempting to extend value of non-matching type");
14169   if (LDType == LDMemType)
14170     return true;
14171   if (LDMemType.isInteger() && LDType.isInteger()) {
14172     switch (LD->getExtensionType()) {
14173     case ISD::NON_EXTLOAD:
14174       Val = DAG.getBitcast(LDType, Val);
14175       return true;
14176     case ISD::EXTLOAD:
14177       Val = DAG.getNode(ISD::ANY_EXTEND, SDLoc(LD), LDType, Val);
14178       return true;
14179     case ISD::SEXTLOAD:
14180       Val = DAG.getNode(ISD::SIGN_EXTEND, SDLoc(LD), LDType, Val);
14181       return true;
14182     case ISD::ZEXTLOAD:
14183       Val = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(LD), LDType, Val);
14184       return true;
14185     }
14186   }
14187   return false;
14188 }
14189 
14190 SDValue DAGCombiner::ForwardStoreValueToDirectLoad(LoadSDNode *LD) {
14191   if (OptLevel == CodeGenOpt::None || !LD->isSimple())
14192     return SDValue();
14193   SDValue Chain = LD->getOperand(0);
14194   StoreSDNode *ST = dyn_cast<StoreSDNode>(Chain.getNode());
14195   // TODO: Relax this restriction for unordered atomics (see D66309)
14196   if (!ST || !ST->isSimple())
14197     return SDValue();
14198 
14199   EVT LDType = LD->getValueType(0);
14200   EVT LDMemType = LD->getMemoryVT();
14201   EVT STMemType = ST->getMemoryVT();
14202   EVT STType = ST->getValue().getValueType();
14203 
14204   BaseIndexOffset BasePtrLD = BaseIndexOffset::match(LD, DAG);
14205   BaseIndexOffset BasePtrST = BaseIndexOffset::match(ST, DAG);
14206   int64_t Offset;
14207   if (!BasePtrST.equalBaseIndex(BasePtrLD, DAG, Offset))
14208     return SDValue();
14209 
14210   // Normalize for endianness. After this, Offset=0 will denote that the least
14211   // significant bit in the loaded value maps to the least significant bit in
14212   // the stored value. With Offset=n (for n > 0) the loaded value starts at the
14213   // n-th least significant byte of the stored value.
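  // For example (an illustrative case): for a 4-byte store and a 1-byte load
  // at the same address, a big-endian target reads the most significant byte,
  // so Offset is normalized to (32 - 8) / 8 - 0 = 3, i.e. the 3rd least
  // significant byte of the stored value.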
14214   if (DAG.getDataLayout().isBigEndian())
14215     Offset = ((int64_t)STMemType.getStoreSizeInBits() -
14216               (int64_t)LDMemType.getStoreSizeInBits()) / 8 - Offset;
14217 
14218   // Check that the stored value covers all bits that are loaded.
14219   bool STCoversLD =
14220       (Offset >= 0) &&
14221       (Offset * 8 + LDMemType.getSizeInBits() <= STMemType.getSizeInBits());
14222 
14223   auto ReplaceLd = [&](LoadSDNode *LD, SDValue Val, SDValue Chain) -> SDValue {
14224     if (LD->isIndexed()) {
14225       bool IsSub = (LD->getAddressingMode() == ISD::PRE_DEC ||
14226                     LD->getAddressingMode() == ISD::POST_DEC);
14227       unsigned Opc = IsSub ? ISD::SUB : ISD::ADD;
14228       SDValue Idx = DAG.getNode(Opc, SDLoc(LD), LD->getOperand(1).getValueType(),
14229                              LD->getOperand(1), LD->getOperand(2));
14230       SDValue Ops[] = {Val, Idx, Chain};
14231       return CombineTo(LD, Ops, 3);
14232     }
14233     return CombineTo(LD, Val, Chain);
14234   };
14235 
14236   if (!STCoversLD)
14237     return SDValue();
14238 
14239   // Memory as copy space (potentially masked).
14240   if (Offset == 0 && LDType == STType && STMemType == LDMemType) {
14241     // Simple case: Direct non-truncating forwarding
14242     if (LDType.getSizeInBits() == LDMemType.getSizeInBits())
14243       return ReplaceLd(LD, ST->getValue(), Chain);
14244     // Can we model the truncate and extension with an and mask?
14245     if (STType.isInteger() && LDMemType.isInteger() && !STType.isVector() &&
14246         !LDMemType.isVector() && LD->getExtensionType() != ISD::SEXTLOAD) {
14247       // Mask to size of LDMemType
14248       auto Mask =
14249           DAG.getConstant(APInt::getLowBitsSet(STType.getSizeInBits(),
14250                                                STMemType.getSizeInBits()),
14251                           SDLoc(ST), STType);
14252       auto Val = DAG.getNode(ISD::AND, SDLoc(LD), LDType, ST->getValue(), Mask);
14253       return ReplaceLd(LD, Val, Chain);
14254     }
14255   }
14256 
14257   // TODO: Deal with nonzero offset.
14258   if (LD->getBasePtr().isUndef() || Offset != 0)
14259     return SDValue();
14260   // Model necessary truncations / extensions.
14261   SDValue Val;
14262   // Truncate the value to the stored memory size.
14263   do {
14264     if (!getTruncatedStoreValue(ST, Val))
14265       continue;
14266     if (!isTypeLegal(LDMemType))
14267       continue;
14268     if (STMemType != LDMemType) {
14269       // TODO: Support vectors? This requires extract_subvector/bitcast.
14270       if (!STMemType.isVector() && !LDMemType.isVector() &&
14271           STMemType.isInteger() && LDMemType.isInteger())
14272         Val = DAG.getNode(ISD::TRUNCATE, SDLoc(LD), LDMemType, Val);
14273       else
14274         continue;
14275     }
14276     if (!extendLoadedValueToExtension(LD, Val))
14277       continue;
14278     return ReplaceLd(LD, Val, Chain);
14279   } while (false);
14280 
14281   // On failure, cleanup dead nodes we may have created.
14282   if (Val->use_empty())
14283     deleteAndRecombine(Val.getNode());
14284   return SDValue();
14285 }
14286 
14287 SDValue DAGCombiner::visitLOAD(SDNode *N) {
14288   LoadSDNode *LD  = cast<LoadSDNode>(N);
14289   SDValue Chain = LD->getChain();
14290   SDValue Ptr   = LD->getBasePtr();
14291 
14292   // If load is not volatile and there are no uses of the loaded value (and
14293   // the updated indexed value in case of indexed loads), change uses of the
14294   // chain value into uses of the chain input (i.e. delete the dead load).
14295   // TODO: Allow this for unordered atomics (see D66309)
14296   if (LD->isSimple()) {
14297     if (N->getValueType(1) == MVT::Other) {
14298       // Unindexed loads.
14299       if (!N->hasAnyUseOfValue(0)) {
14300         // It's not safe to use the two value CombineTo variant here. e.g.
14301         // v1, chain2 = load chain1, loc
14302         // v2, chain3 = load chain2, loc
14303         // v3         = add v2, c
14304         // Now we replace use of chain2 with chain1.  This makes the second load
14305         // isomorphic to the one we are deleting, and thus makes this load live.
14306         LLVM_DEBUG(dbgs() << "\nReplacing.6 "; N->dump(&DAG);
14307                    dbgs() << "\nWith chain: "; Chain.getNode()->dump(&DAG);
14308                    dbgs() << "\n");
14309         WorklistRemover DeadNodes(*this);
14310         DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain);
14311         AddUsersToWorklist(Chain.getNode());
14312         if (N->use_empty())
14313           deleteAndRecombine(N);
14314 
14315         return SDValue(N, 0);   // Return N so it doesn't get rechecked!
14316       }
14317     } else {
14318       // Indexed loads.
14319       assert(N->getValueType(2) == MVT::Other && "Malformed indexed loads?");
14320 
14321       // If this load has an opaque TargetConstant offset, then we cannot split
14322       // the indexing into an add/sub directly (that TargetConstant may not be
14323       // valid for a different type of node, and we cannot convert an opaque
14324       // target constant into a regular constant).
14325       bool HasOTCInc = LD->getOperand(2).getOpcode() == ISD::TargetConstant &&
14326                        cast<ConstantSDNode>(LD->getOperand(2))->isOpaque();
14327 
14328       if (!N->hasAnyUseOfValue(0) &&
14329           ((MaySplitLoadIndex && !HasOTCInc) || !N->hasAnyUseOfValue(1))) {
14330         SDValue Undef = DAG.getUNDEF(N->getValueType(0));
14331         SDValue Index;
14332         if (N->hasAnyUseOfValue(1) && MaySplitLoadIndex && !HasOTCInc) {
14333           Index = SplitIndexingFromLoad(LD);
14334           // Try to fold the base pointer arithmetic into subsequent loads and
14335           // stores.
14336           AddUsersToWorklist(N);
14337         } else
14338           Index = DAG.getUNDEF(N->getValueType(1));
14339         LLVM_DEBUG(dbgs() << "\nReplacing.7 "; N->dump(&DAG);
14340                    dbgs() << "\nWith: "; Undef.getNode()->dump(&DAG);
14341                    dbgs() << " and 2 other values\n");
14342         WorklistRemover DeadNodes(*this);
14343         DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Undef);
14344         DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Index);
14345         DAG.ReplaceAllUsesOfValueWith(SDValue(N, 2), Chain);
14346         deleteAndRecombine(N);
14347         return SDValue(N, 0);   // Return N so it doesn't get rechecked!
14348       }
14349     }
14350   }
14351 
14352   // If this load is directly stored, replace the load value with the stored
14353   // value.
14354   if (auto V = ForwardStoreValueToDirectLoad(LD))
14355     return V;
14356 
14357   // Try to infer better alignment information than the load already has.
14358   if (OptLevel != CodeGenOpt::None && LD->isUnindexed() && !LD->isAtomic()) {
14359     if (unsigned Align = DAG.InferPtrAlignment(Ptr)) {
14360       if (Align > LD->getAlignment() && LD->getSrcValueOffset() % Align == 0) {
14361         SDValue NewLoad = DAG.getExtLoad(
14362             LD->getExtensionType(), SDLoc(N), LD->getValueType(0), Chain, Ptr,
14363             LD->getPointerInfo(), LD->getMemoryVT(), Align,
14364             LD->getMemOperand()->getFlags(), LD->getAAInfo());
14365         // NewLoad will always be N as we are only refining the alignment
14366         assert(NewLoad.getNode() == N);
14367         (void)NewLoad;
14368       }
14369     }
14370   }
14371 
14372   if (LD->isUnindexed()) {
14373     // Walk up chain skipping non-aliasing memory nodes.
14374     SDValue BetterChain = FindBetterChain(LD, Chain);
14375 
14376     // If there is a better chain.
14377     if (Chain != BetterChain) {
14378       SDValue ReplLoad;
14379 
14380       // Replace the chain to avoid the dependency.
14381       if (LD->getExtensionType() == ISD::NON_EXTLOAD) {
14382         ReplLoad = DAG.getLoad(N->getValueType(0), SDLoc(LD),
14383                                BetterChain, Ptr, LD->getMemOperand());
14384       } else {
14385         ReplLoad = DAG.getExtLoad(LD->getExtensionType(), SDLoc(LD),
14386                                   LD->getValueType(0),
14387                                   BetterChain, Ptr, LD->getMemoryVT(),
14388                                   LD->getMemOperand());
14389       }
14390 
14391       // Create token factor to keep old chain connected.
14392       SDValue Token = DAG.getNode(ISD::TokenFactor, SDLoc(N),
14393                                   MVT::Other, Chain, ReplLoad.getValue(1));
14394 
14395       // Replace uses with load result and token factor
14396       return CombineTo(N, ReplLoad.getValue(0), Token);
14397     }
14398   }
14399 
14400   // Try transforming N to an indexed load.
14401   if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
14402     return SDValue(N, 0);
14403 
14404   // Try to slice up N to more direct loads if the slices are mapped to
14405   // different register banks or pairing can take place.
14406   if (SliceUpLoad(N))
14407     return SDValue(N, 0);
14408 
14409   return SDValue();
14410 }
14411 
14412 namespace {
14413 
14414 /// Helper structure used to slice a load in smaller loads.
14415 /// Basically a slice is obtained from the following sequence:
14416 /// Origin = load Ty1, Base
14417 /// Shift = srl Ty1 Origin, CstTy Amount
14418 /// Inst = trunc Shift to Ty2
14419 ///
14420 /// Then, it will be rewritten into:
14421 /// Slice = load SliceTy, Base + SliceOffset
14422 /// [Inst = zext Slice to Ty2], only if SliceTy != Ty2
14423 ///
14424 /// SliceTy is deduced from the number of bits that are actually used to
14425 /// build Inst.
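///
/// For example (an illustrative case), on a little-endian target:
///   Origin = load i32, Base
///   Shift  = srl i32 Origin, 16
///   Inst   = trunc i32 Shift to i8
/// uses bits [16, 24) of Origin and is rewritten into:
///   Slice  = load i8, Base + 2
/// with no zext needed, since SliceTy and Ty2 are both i8.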
14426 struct LoadedSlice {
14427   /// Helper structure used to compute the cost of a slice.
14428   struct Cost {
14429     /// Are we optimizing for code size.
14430     bool ForCodeSize = false;
14431 
14432     /// Various costs.
14433     unsigned Loads = 0;
14434     unsigned Truncates = 0;
14435     unsigned CrossRegisterBanksCopies = 0;
14436     unsigned ZExts = 0;
14437     unsigned Shift = 0;
14438 
14439     explicit Cost(bool ForCodeSize) : ForCodeSize(ForCodeSize) {}
14440 
14441     /// Get the cost of one isolated slice.
14442     Cost(const LoadedSlice &LS, bool ForCodeSize)
14443         : ForCodeSize(ForCodeSize), Loads(1) {
14444       EVT TruncType = LS.Inst->getValueType(0);
14445       EVT LoadedType = LS.getLoadedType();
14446       if (TruncType != LoadedType &&
14447           !LS.DAG->getTargetLoweringInfo().isZExtFree(LoadedType, TruncType))
14448         ZExts = 1;
14449     }
14450 
14451     /// Account for slicing gain in the current cost.
14452     /// Slicing provides a few gains, like removing a shift or a
14453     /// truncate. This method allows growing the cost of the original
14454     /// load with the gain from this slice.
14455     void addSliceGain(const LoadedSlice &LS) {
14456       // Each slice saves a truncate.
14457       const TargetLowering &TLI = LS.DAG->getTargetLoweringInfo();
14458       if (!TLI.isTruncateFree(LS.Inst->getOperand(0).getValueType(),
14459                               LS.Inst->getValueType(0)))
14460         ++Truncates;
14461       // If there is a shift amount, this slice gets rid of it.
14462       if (LS.Shift)
14463         ++Shift;
14464       // If this slice can merge a cross register bank copy, account for it.
14465       if (LS.canMergeExpensiveCrossRegisterBankCopy())
14466         ++CrossRegisterBanksCopies;
14467     }
14468 
14469     Cost &operator+=(const Cost &RHS) {
14470       Loads += RHS.Loads;
14471       Truncates += RHS.Truncates;
14472       CrossRegisterBanksCopies += RHS.CrossRegisterBanksCopies;
14473       ZExts += RHS.ZExts;
14474       Shift += RHS.Shift;
14475       return *this;
14476     }
14477 
14478     bool operator==(const Cost &RHS) const {
14479       return Loads == RHS.Loads && Truncates == RHS.Truncates &&
14480              CrossRegisterBanksCopies == RHS.CrossRegisterBanksCopies &&
14481              ZExts == RHS.ZExts && Shift == RHS.Shift;
14482     }
14483 
14484     bool operator!=(const Cost &RHS) const { return !(*this == RHS); }
14485 
14486     bool operator<(const Cost &RHS) const {
14487       // Assume cross register banks copies are as expensive as loads.
14488       // FIXME: Do we want some more target hooks?
14489       unsigned ExpensiveOpsLHS = Loads + CrossRegisterBanksCopies;
14490       unsigned ExpensiveOpsRHS = RHS.Loads + RHS.CrossRegisterBanksCopies;
14491       // Unless we are optimizing for code size, consider the
14492       // expensive operation first.
14493       if (!ForCodeSize && ExpensiveOpsLHS != ExpensiveOpsRHS)
14494         return ExpensiveOpsLHS < ExpensiveOpsRHS;
14495       return (Truncates + ZExts + Shift + ExpensiveOpsLHS) <
14496              (RHS.Truncates + RHS.ZExts + RHS.Shift + ExpensiveOpsRHS);
14497     }
14498 
14499     bool operator>(const Cost &RHS) const { return RHS < *this; }
14500 
14501     bool operator<=(const Cost &RHS) const { return !(RHS < *this); }
14502 
14503     bool operator>=(const Cost &RHS) const { return !(*this < RHS); }
14504   };
14505 
14506   // The last instruction that represents the slice. This should be a
14507   // truncate instruction.
14508   SDNode *Inst;
14509 
14510   // The original load instruction.
14511   LoadSDNode *Origin;
14512 
14513   // The right shift amount in bits from the original load.
14514   unsigned Shift;
14515 
14516   // The DAG from which Origin came.
14517   // This is used to get some contextual information about legal types, etc.
14518   SelectionDAG *DAG;
14519 
14520   LoadedSlice(SDNode *Inst = nullptr, LoadSDNode *Origin = nullptr,
14521               unsigned Shift = 0, SelectionDAG *DAG = nullptr)
14522       : Inst(Inst), Origin(Origin), Shift(Shift), DAG(DAG) {}
14523 
14524   /// Get the bits used in a chunk of bits \p BitWidth large.
14525   /// \return Result is \p BitWidth bits wide, with used bits set to 1 and
14526   ///         unused bits set to 0.
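  /// For example (an illustrative case): an i8 slice of an i32 load with
  /// Shift = 16 yields UsedBits = 0x00FF0000.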
14527   APInt getUsedBits() const {
14528     // Reproduce the trunc(lshr) sequence:
14529     // - Start from the truncated value.
14530     // - Zero extend to the desired bit width.
14531     // - Shift left.
14532     assert(Origin && "No original load to compare against.");
14533     unsigned BitWidth = Origin->getValueSizeInBits(0);
14534     assert(Inst && "This slice is not bound to an instruction");
14535     assert(Inst->getValueSizeInBits(0) <= BitWidth &&
14536            "Extracted slice is bigger than the whole type!");
14537     APInt UsedBits(Inst->getValueSizeInBits(0), 0);
14538     UsedBits.setAllBits();
14539     UsedBits = UsedBits.zext(BitWidth);
14540     UsedBits <<= Shift;
14541     return UsedBits;
14542   }
14543 
14544   /// Get the size of the slice to be loaded in bytes.
14545   unsigned getLoadedSize() const {
14546     unsigned SliceSize = getUsedBits().countPopulation();
14547     assert(!(SliceSize & 0x7) && "Size is not a multiple of a byte.");
14548     return SliceSize / 8;
14549   }
14550 
14551   /// Get the type that will be loaded for this slice.
14552   /// Note: This may not be the final type for the slice.
14553   EVT getLoadedType() const {
14554     assert(DAG && "Missing context");
14555     LLVMContext &Ctxt = *DAG->getContext();
14556     return EVT::getIntegerVT(Ctxt, getLoadedSize() * 8);
14557   }
14558 
14559   /// Get the alignment of the load used for this slice.
14560   unsigned getAlignment() const {
14561     unsigned Alignment = Origin->getAlignment();
14562     uint64_t Offset = getOffsetFromBase();
14563     if (Offset != 0)
14564       Alignment = MinAlign(Alignment, Alignment + Offset);
14565     return Alignment;
14566   }
14567 
14568   /// Check if this slice can be rewritten with legal operations.
14569   bool isLegal() const {
14570     // An invalid slice is not legal.
14571     if (!Origin || !Inst || !DAG)
14572       return false;
14573 
14574     // Offsets are for indexed loads only; we do not handle that.
14575     if (!Origin->getOffset().isUndef())
14576       return false;
14577 
14578     const TargetLowering &TLI = DAG->getTargetLoweringInfo();
14579 
14580     // Check that the type is legal.
14581     EVT SliceType = getLoadedType();
14582     if (!TLI.isTypeLegal(SliceType))
14583       return false;
14584 
14585     // Check that the load is legal for this type.
14586     if (!TLI.isOperationLegal(ISD::LOAD, SliceType))
14587       return false;
14588 
14589     // Check that the offset can be computed.
14590     // 1. Check its type.
14591     EVT PtrType = Origin->getBasePtr().getValueType();
14592     if (PtrType == MVT::Untyped || PtrType.isExtended())
14593       return false;
14594 
14595     // 2. Check that it fits in the immediate.
14596     if (!TLI.isLegalAddImmediate(getOffsetFromBase()))
14597       return false;
14598 
14599     // 3. Check that the computation is legal.
14600     if (!TLI.isOperationLegal(ISD::ADD, PtrType))
14601       return false;
14602 
14603     // Check that the zext is legal if it needs one.
14604     EVT TruncateType = Inst->getValueType(0);
14605     if (TruncateType != SliceType &&
14606         !TLI.isOperationLegal(ISD::ZERO_EXTEND, TruncateType))
14607       return false;
14608 
14609     return true;
14610   }
14611 
14612   /// Get the offset in bytes of this slice in the original chunk of
14613   /// bits.
14614   /// \pre DAG != nullptr.
14615   uint64_t getOffsetFromBase() const {
14616     assert(DAG && "Missing context.");
14617     bool IsBigEndian = DAG->getDataLayout().isBigEndian();
14618     assert(!(Shift & 0x7) && "Shifts not aligned on Bytes are not supported.");
14619     uint64_t Offset = Shift / 8;
14620     unsigned TySizeInBytes = Origin->getValueSizeInBits(0) / 8;
14621     assert(!(Origin->getValueSizeInBits(0) & 0x7) &&
14622            "The size of the original loaded type is not a multiple of a"
14623            " byte.");
14624     // If Offset is bigger than TySizeInBytes, it means we are loading all
14625     // zeros. This should have been optimized away earlier in the process.
14626     assert(TySizeInBytes > Offset &&
14627            "Invalid shift amount for given loaded size");
14628     if (IsBigEndian)
14629       Offset = TySizeInBytes - Offset - getLoadedSize();
14630     return Offset;
14631   }
14632 
14633   /// Generate the sequence of instructions to load the slice
14634   /// represented by this object and redirect the uses of this slice to
14635   /// this new sequence of instructions.
14636   /// \pre this->Inst && this->Origin are valid Instructions and this
14637   /// object passed the legal check: LoadedSlice::isLegal returned true.
14638   /// \return The last instruction of the sequence used to load the slice.
14639   SDValue loadSlice() const {
14640     assert(Inst && Origin && "Unable to replace a non-existing slice.");
14641     const SDValue &OldBaseAddr = Origin->getBasePtr();
14642     SDValue BaseAddr = OldBaseAddr;
14643     // Get the offset in that chunk of bytes w.r.t. the endianness.
14644     int64_t Offset = static_cast<int64_t>(getOffsetFromBase());
14645     assert(Offset >= 0 && "Offset too big to fit in int64_t!");
14646     if (Offset) {
14647       // BaseAddr = BaseAddr + Offset.
14648       EVT ArithType = BaseAddr.getValueType();
14649       SDLoc DL(Origin);
14650       BaseAddr = DAG->getNode(ISD::ADD, DL, ArithType, BaseAddr,
14651                               DAG->getConstant(Offset, DL, ArithType));
14652     }
14653 
14654     // Create the type of the loaded slice according to its size.
14655     EVT SliceType = getLoadedType();
14656 
14657     // Create the load for the slice.
14658     SDValue LastInst =
14659         DAG->getLoad(SliceType, SDLoc(Origin), Origin->getChain(), BaseAddr,
14660                      Origin->getPointerInfo().getWithOffset(Offset),
14661                      getAlignment(), Origin->getMemOperand()->getFlags());
14662     // If the final type is not the same as the loaded type, this means that
14663     // we have to pad with zero. Create a zero extend for that.
14664     EVT FinalType = Inst->getValueType(0);
14665     if (SliceType != FinalType)
14666       LastInst =
14667           DAG->getNode(ISD::ZERO_EXTEND, SDLoc(LastInst), FinalType, LastInst);
14668     return LastInst;
14669   }
14670 
14671   /// Check if this slice can be merged with an expensive cross register
14672   /// bank copy. E.g.,
14673   /// i = load i32
14674   /// f = bitcast i32 i to float
14675   bool canMergeExpensiveCrossRegisterBankCopy() const {
14676     if (!Inst || !Inst->hasOneUse())
14677       return false;
14678     SDNode *Use = *Inst->use_begin();
14679     if (Use->getOpcode() != ISD::BITCAST)
14680       return false;
14681     assert(DAG && "Missing context");
14682     const TargetLowering &TLI = DAG->getTargetLoweringInfo();
14683     EVT ResVT = Use->getValueType(0);
14684     const TargetRegisterClass *ResRC =
14685         TLI.getRegClassFor(ResVT.getSimpleVT(), Use->isDivergent());
14686     const TargetRegisterClass *ArgRC =
14687         TLI.getRegClassFor(Use->getOperand(0).getValueType().getSimpleVT(),
14688                            Use->getOperand(0)->isDivergent());
14689     if (ArgRC == ResRC || !TLI.isOperationLegal(ISD::LOAD, ResVT))
14690       return false;
14691 
14692     // At this point, we know that we perform a cross-register-bank copy.
14693     // Check if it is expensive.
14694     const TargetRegisterInfo *TRI = DAG->getSubtarget().getRegisterInfo();
14695     // Assume bitcasts are cheap, unless the two register classes do not
14696     // explicitly share a common subclass.
14697     if (!TRI || TRI->getCommonSubClass(ArgRC, ResRC))
14698       return false;
14699 
14700     // Check if it will be merged with the load.
14701     // 1. Check the alignment constraint.
14702     unsigned RequiredAlignment = DAG->getDataLayout().getABITypeAlignment(
14703         ResVT.getTypeForEVT(*DAG->getContext()));
14704 
14705     if (RequiredAlignment > getAlignment())
14706       return false;
14707 
14708     // 2. Check that the load is a legal operation for that type.
14709     if (!TLI.isOperationLegal(ISD::LOAD, ResVT))
14710       return false;
14711 
14712     // 3. Check that we do not have a zext in the way.
14713     if (Inst->getValueType(0) != getLoadedType())
14714       return false;
14715 
14716     return true;
14717   }
14718 };
14719 
14720 } // end anonymous namespace
14721 
14722 /// Check that all bits set in \p UsedBits form a dense region, i.e.,
14723 /// \p UsedBits looks like 0..0 1..1 0..0.
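/// For example (an illustrative case), 0x00FF0000 is dense, while 0x00FF00F0
/// is not: stripping its trailing and leading zeros leaves 0xFF00F, which
/// still contains unused bits.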
14724 static bool areUsedBitsDense(const APInt &UsedBits) {
14725   // If all the bits are one, this is dense!
14726   if (UsedBits.isAllOnesValue())
14727     return true;
14728 
14729   // Get rid of the unused bits on the right.
14730   APInt NarrowedUsedBits = UsedBits.lshr(UsedBits.countTrailingZeros());
14731   // Get rid of the unused bits on the left.
14732   if (NarrowedUsedBits.countLeadingZeros())
14733     NarrowedUsedBits = NarrowedUsedBits.trunc(NarrowedUsedBits.getActiveBits());
14734   // Check that the chunk of bits is completely used.
14735   return NarrowedUsedBits.isAllOnesValue();
14736 }
14737 
14738 /// Check whether or not \p First and \p Second are next to each other
14739 /// in memory. This means that there is no hole between the bits loaded
14740 /// by \p First and the bits loaded by \p Second.
14741 static bool areSlicesNextToEachOther(const LoadedSlice &First,
14742                                      const LoadedSlice &Second) {
14743   assert(First.Origin == Second.Origin && First.Origin &&
14744          "Unable to match different memory origins.");
14745   APInt UsedBits = First.getUsedBits();
14746   assert((UsedBits & Second.getUsedBits()) == 0 &&
14747          "Slices are not supposed to overlap.");
14748   UsedBits |= Second.getUsedBits();
14749   return areUsedBitsDense(UsedBits);
14750 }
14751 
14752 /// Adjust the \p GlobalLSCost according to the target
14753 /// pairing capabilities and the layout of the slices.
14754 /// \pre \p GlobalLSCost should account for at least as many loads as
14755 /// there are in the slices in \p LoadedSlices.
14756 static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices,
14757                                  LoadedSlice::Cost &GlobalLSCost) {
14758   unsigned NumberOfSlices = LoadedSlices.size();
14759   // If there are fewer than 2 elements, no pairing is possible.
14760   if (NumberOfSlices < 2)
14761     return;
14762 
14763   // Sort the slices so that elements that are likely to be next to each
14764   // other in memory are next to each other in the list.
14765   llvm::sort(LoadedSlices, [](const LoadedSlice &LHS, const LoadedSlice &RHS) {
14766     assert(LHS.Origin == RHS.Origin && "Different bases not implemented.");
14767     return LHS.getOffsetFromBase() < RHS.getOffsetFromBase();
14768   });
14769   const TargetLowering &TLI = LoadedSlices[0].DAG->getTargetLoweringInfo();
14770   // First (resp. Second) is the first (resp. second) potential candidate
14771   // to be placed in a paired load.
14772   const LoadedSlice *First = nullptr;
14773   const LoadedSlice *Second = nullptr;
14774   for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice,
14775                 // Set the beginning of the pair.
14776                                                            First = Second) {
14777     Second = &LoadedSlices[CurrSlice];
14778 
14779     // If First is NULL, it means we start a new pair.
14780     // Get to the next slice.
14781     if (!First)
14782       continue;
14783 
14784     EVT LoadedType = First->getLoadedType();
14785 
14786     // If the types of the slices are different, we cannot pair them.
14787     if (LoadedType != Second->getLoadedType())
14788       continue;
14789 
14790     // Check if the target supplies paired loads for this type.
14791     unsigned RequiredAlignment = 0;
14792     if (!TLI.hasPairedLoad(LoadedType, RequiredAlignment)) {
14793       // Move to the next pair; this type is hopeless.
14794       Second = nullptr;
14795       continue;
14796     }
14797     // Check if we meet the alignment requirement.
14798     if (RequiredAlignment > First->getAlignment())
14799       continue;
14800 
14801     // Check that both loads are next to each other in memory.
14802     if (!areSlicesNextToEachOther(*First, *Second))
14803       continue;
14804 
14805     assert(GlobalLSCost.Loads > 0 && "We save more loads than we created!");
14806     --GlobalLSCost.Loads;
14807     // Move to the next pair.
14808     Second = nullptr;
14809   }
14810 }
14811 
14812 /// Check the profitability of all involved LoadedSlice.
14813 /// Currently, it is considered profitable if there are exactly two
14814 /// involved slices (1) which are (2) next to each other in memory, and
14815 /// whose cost (\see LoadedSlice::Cost) is smaller than the original load (3).
14816 ///
14817 /// Note: The order of the elements in \p LoadedSlices may be modified, but not
14818 /// the elements themselves.
14819 ///
14820 /// FIXME: When the cost model is mature enough, we can relax
14821 /// constraints (1) and (2).
14822 static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices,
14823                                 const APInt &UsedBits, bool ForCodeSize) {
14824   unsigned NumberOfSlices = LoadedSlices.size();
14825   if (StressLoadSlicing)
14826     return NumberOfSlices > 1;
14827 
14828   // Check (1).
14829   if (NumberOfSlices != 2)
14830     return false;
14831 
14832   // Check (2).
14833   if (!areUsedBitsDense(UsedBits))
14834     return false;
14835 
14836   // Check (3).
14837   LoadedSlice::Cost OrigCost(ForCodeSize), GlobalSlicingCost(ForCodeSize);
14838   // The original code has one big load.
14839   OrigCost.Loads = 1;
14840   for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice) {
14841     const LoadedSlice &LS = LoadedSlices[CurrSlice];
14842     // Accumulate the cost of all the slices.
14843     LoadedSlice::Cost SliceCost(LS, ForCodeSize);
14844     GlobalSlicingCost += SliceCost;
14845 
14846     // Account the gain obtained with the current slices as a cost in
14847     // the original configuration.
14848     OrigCost.addSliceGain(LS);
14849   }
14850 
14851   // If the target supports paired load, adjust the cost accordingly.
14852   adjustCostForPairing(LoadedSlices, GlobalSlicingCost);
14853   return OrigCost > GlobalSlicingCost;
14854 }
14855 
14856 /// If the given load, \p LI, is used only by trunc or trunc(lshr)
14857 /// operations, split it in the various pieces being extracted.
14858 ///
14859 /// This sort of thing is introduced by SROA.
14860 /// This slicing takes care not to insert overlapping loads.
14861 /// \pre LI is a simple load (i.e., not an atomic or volatile load).
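///
/// For example (an illustrative case), on a little-endian target:
///   x = load i64, p
///   a = trunc i64 x to i32
///   b = trunc i64 (srl i64 x, 32) to i32
/// may be sliced into two independent i32 loads at p and p + 4, when the
/// cost model finds that profitable.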
14862 bool DAGCombiner::SliceUpLoad(SDNode *N) {
14863   if (Level < AfterLegalizeDAG)
14864     return false;
14865 
14866   LoadSDNode *LD = cast<LoadSDNode>(N);
14867   if (!LD->isSimple() || !ISD::isNormalLoad(LD) ||
14868       !LD->getValueType(0).isInteger())
14869     return false;
14870 
14871   // Keep track of already used bits to detect overlapping values.
14872   // In that case, we will just abort the transformation.
14873   APInt UsedBits(LD->getValueSizeInBits(0), 0);
14874 
14875   SmallVector<LoadedSlice, 4> LoadedSlices;
14876 
14877   // Check if this load is used as several smaller chunks of bits.
14878   // Basically, look for uses in trunc or trunc(lshr) and record a new chain
14879   // of computation for each trunc.
14880   for (SDNode::use_iterator UI = LD->use_begin(), UIEnd = LD->use_end();
14881        UI != UIEnd; ++UI) {
14882     // Skip the uses of the chain.
14883     if (UI.getUse().getResNo() != 0)
14884       continue;
14885 
14886     SDNode *User = *UI;
14887     unsigned Shift = 0;
14888 
14889     // Check if this is a trunc(lshr).
14890     if (User->getOpcode() == ISD::SRL && User->hasOneUse() &&
14891         isa<ConstantSDNode>(User->getOperand(1))) {
14892       Shift = User->getConstantOperandVal(1);
14893       User = *User->use_begin();
14894     }
14895 
14896     // At this point, User is a truncate iff we encountered trunc or
14897     // trunc(lshr).
14898     if (User->getOpcode() != ISD::TRUNCATE)
14899       return false;
14900 
14901     // The width of the type must be a power of 2 and at least 8 bits.
14902     // Otherwise the load cannot be represented in LLVM IR.
14903     // Moreover, if we shifted by an amount that is not a multiple of 8
14904     // bits, the slice would span several bytes. We do not support that.
14905     unsigned Width = User->getValueSizeInBits(0);
14906     if (Width < 8 || !isPowerOf2_32(Width) || (Shift & 0x7))
14907       return false;
14908 
14909     // Build the slice for this chain of computations.
14910     LoadedSlice LS(User, LD, Shift, &DAG);
14911     APInt CurrentUsedBits = LS.getUsedBits();
14912 
14913     // Check if this slice overlaps with another.
14914     if ((CurrentUsedBits & UsedBits) != 0)
14915       return false;
14916     // Update the bits used globally.
14917     UsedBits |= CurrentUsedBits;
14918 
14919     // Check if the new slice would be legal.
14920     if (!LS.isLegal())
14921       return false;
14922 
14923     // Record the slice.
14924     LoadedSlices.push_back(LS);
14925   }
14926 
14927   // Abort slicing if it does not seem to be profitable.
14928   if (!isSlicingProfitable(LoadedSlices, UsedBits, ForCodeSize))
14929     return false;
14930 
14931   ++SlicedLoads;
14932 
14933   // Rewrite each chain to use an independent load.
14934   // By construction, each chain can be represented by a unique load.
14935 
14936   // Prepare the argument for the new token factor for all the slices.
14937   SmallVector<SDValue, 8> ArgChains;
14938   for (SmallVectorImpl<LoadedSlice>::const_iterator
14939            LSIt = LoadedSlices.begin(),
14940            LSItEnd = LoadedSlices.end();
14941        LSIt != LSItEnd; ++LSIt) {
14942     SDValue SliceInst = LSIt->loadSlice();
14943     CombineTo(LSIt->Inst, SliceInst, true);
14944     if (SliceInst.getOpcode() != ISD::LOAD)
14945       SliceInst = SliceInst.getOperand(0);
14946     assert(SliceInst->getOpcode() == ISD::LOAD &&
14947            "It takes more than a zext to get to the loaded slice!!");
14948     ArgChains.push_back(SliceInst.getValue(1));
14949   }
14950 
14951   SDValue Chain = DAG.getNode(ISD::TokenFactor, SDLoc(LD), MVT::Other,
14952                               ArgChains);
14953   DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain);
14954   AddToWorklist(Chain.getNode());
14955   return true;
14956 }
14957 
14958 /// Check to see if V is (and load (ptr), imm), where the load has
14959 /// specific bytes cleared out.  If so, return the byte size being masked out
14960 /// and the shift amount.
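///
/// For example (an illustrative case):
///   V = and (load i32 Ptr), 0xFFFF0000
/// clears the two low bytes, so the result is {2, 0}: two bytes masked out
/// starting at byte offset 0.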
14961 static std::pair<unsigned, unsigned>
14962 CheckForMaskedLoad(SDValue V, SDValue Ptr, SDValue Chain) {
14963   std::pair<unsigned, unsigned> Result(0, 0);
14964 
14965   // Check for the structure we're looking for.
14966   if (V->getOpcode() != ISD::AND ||
14967       !isa<ConstantSDNode>(V->getOperand(1)) ||
14968       !ISD::isNormalLoad(V->getOperand(0).getNode()))
14969     return Result;
14970 
14971   // Check the chain and pointer.
14972   LoadSDNode *LD = cast<LoadSDNode>(V->getOperand(0));
14973   if (LD->getBasePtr() != Ptr) return Result;  // Not from same pointer.
14974 
14975   // This only handles simple types.
14976   if (V.getValueType() != MVT::i16 &&
14977       V.getValueType() != MVT::i32 &&
14978       V.getValueType() != MVT::i64)
14979     return Result;
14980 
14981   // Check the constant mask.  Invert it so that the bits being masked out are
14982   // 0 and the bits being kept are 1.  Use getSExtValue so that leading bits
14983   // follow the sign bit for uniformity.
14984   uint64_t NotMask = ~cast<ConstantSDNode>(V->getOperand(1))->getSExtValue();
14985   unsigned NotMaskLZ = countLeadingZeros(NotMask);
14986   if (NotMaskLZ & 7) return Result;  // Must be multiple of a byte.
14987   unsigned NotMaskTZ = countTrailingZeros(NotMask);
14988   if (NotMaskTZ & 7) return Result;  // Must be multiple of a byte.
14989   if (NotMaskLZ == 64) return Result;  // All zero mask.
14990 
14991   // See if we have a contiguous run of bits.  If so, we have 0*1+0*
14992   if (countTrailingOnes(NotMask >> NotMaskTZ) + NotMaskTZ + NotMaskLZ != 64)
14993     return Result;
14994 
14995   // Adjust NotMaskLZ down to be from the actual size of the int instead of i64.
14996   if (V.getValueType() != MVT::i64 && NotMaskLZ)
14997     NotMaskLZ -= 64-V.getValueSizeInBits();
14998 
14999   unsigned MaskedBytes = (V.getValueSizeInBits()-NotMaskLZ-NotMaskTZ)/8;
15000   switch (MaskedBytes) {
15001   case 1:
15002   case 2:
15003   case 4: break;
15004   default: return Result; // All one mask, or 5-byte mask.
15005   }
15006 
15007   // Verify that the masked region starts at a multiple of the access width
15008   // so that the narrow access is aligned the same as the access width.
15009   if (NotMaskTZ && NotMaskTZ/8 % MaskedBytes) return Result;
15010 
15011   // For narrowing to be valid, the load must be the memory operation
15012   // immediately preceding the store.
15013   if (LD == Chain.getNode())
15014     ; // ok.
15015   else if (Chain->getOpcode() == ISD::TokenFactor &&
15016            SDValue(LD, 1).hasOneUse()) {
15017     // LD has only 1 chain use, so there are no indirect dependencies.
15018     if (!LD->isOperandOf(Chain.getNode()))
15019       return Result;
15020   } else
15021     return Result; // Fail.
15022 
15023   Result.first = MaskedBytes;
15024   Result.second = NotMaskTZ/8;
15025   return Result;
15026 }
15027 
15028 /// Check to see if IVal is something that provides a value as specified by
15029 /// MaskInfo. If so, replace the specified store with a narrower store of
15030 /// truncated IVal.
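///
/// For example (an illustrative case): if MaskInfo is {2, 0} and only the two
/// low bytes of IVal may be nonzero, the store is replaced on a little-endian
/// target by:
///   store i16 (trunc IVal), Ptr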
15031 static SDValue
15032 ShrinkLoadReplaceStoreWithStore(const std::pair<unsigned, unsigned> &MaskInfo,
15033                                 SDValue IVal, StoreSDNode *St,
15034                                 DAGCombiner *DC) {
15035   unsigned NumBytes = MaskInfo.first;
15036   unsigned ByteShift = MaskInfo.second;
15037   SelectionDAG &DAG = DC->getDAG();
15038 
15039   // Check to see if IVal is all zeros in the part being masked in by the 'or'
15040   // that uses this.  If not, this is not a replacement.
15041   APInt Mask = ~APInt::getBitsSet(IVal.getValueSizeInBits(),
15042                                   ByteShift*8, (ByteShift+NumBytes)*8);
15043   if (!DAG.MaskedValueIsZero(IVal, Mask)) return SDValue();
15044 
15045   // Check that it is legal on the target to do this.  It is legal if the new
15046   // VT we're shrinking to (i8/i16/i32) is legal or we're still before type
15047   // legalization (and the target doesn't explicitly think this is a bad idea).
15048   MVT VT = MVT::getIntegerVT(NumBytes * 8);
15049   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
15050   if (!DC->isTypeLegal(VT))
15051     return SDValue();
15052   if (St->getMemOperand() &&
15053       !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
15054                               *St->getMemOperand()))
15055     return SDValue();
15056 
15057   // Okay, we can do this!  Replace the 'St' store with a store of IVal that is
15058   // shifted by ByteShift and truncated down to NumBytes.
15059   if (ByteShift) {
15060     SDLoc DL(IVal);
15061     IVal = DAG.getNode(ISD::SRL, DL, IVal.getValueType(), IVal,
15062                        DAG.getConstant(ByteShift*8, DL,
15063                                     DC->getShiftAmountTy(IVal.getValueType())));
15064   }
15065 
15066   // Figure out the offset for the store and the alignment of the access.
15067   unsigned StOffset;
15068   unsigned NewAlign = St->getAlignment();
15069 
15070   if (DAG.getDataLayout().isLittleEndian())
15071     StOffset = ByteShift;
15072   else
15073     StOffset = IVal.getValueType().getStoreSize() - ByteShift - NumBytes;
15074 
15075   SDValue Ptr = St->getBasePtr();
15076   if (StOffset) {
15077     SDLoc DL(IVal);
15078     Ptr = DAG.getMemBasePlusOffset(Ptr, StOffset, DL);
15079     NewAlign = MinAlign(NewAlign, StOffset);
15080   }
15081 
15082   // Truncate down to the new size.
15083   IVal = DAG.getNode(ISD::TRUNCATE, SDLoc(IVal), VT, IVal);
15084 
15085   ++OpsNarrowed;
15086   return DAG
15087       .getStore(St->getChain(), SDLoc(St), IVal, Ptr,
15088                 St->getPointerInfo().getWithOffset(StOffset), NewAlign);
15089 }
15090 
15091 /// Look for a sequence of load / op / store where op is one of 'or', 'xor',
15092 /// or 'and' of immediates. If 'op' only touches some of the loaded bits, try
15093 /// narrowing the load and store if it would end up being a win for performance
15094 /// or code size.
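///
/// For example (an illustrative case), on a little-endian target:
///   store (or (load i32 p), 0x00FF0000), p
/// only changes byte 2 of the loaded value, so it can be narrowed to an
/// 8-bit load / or / store at p + 2, when the target deems that legal and
/// profitable.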
15095 SDValue DAGCombiner::ReduceLoadOpStoreWidth(SDNode *N) {
15096   StoreSDNode *ST  = cast<StoreSDNode>(N);
15097   if (!ST->isSimple())
15098     return SDValue();
15099 
15100   SDValue Chain = ST->getChain();
15101   SDValue Value = ST->getValue();
15102   SDValue Ptr   = ST->getBasePtr();
15103   EVT VT = Value.getValueType();
15104 
15105   if (ST->isTruncatingStore() || VT.isVector() || !Value.hasOneUse())
15106     return SDValue();
15107 
15108   unsigned Opc = Value.getOpcode();
15109 
15110   // If this is "store (or X, Y), P" and X is "(and (load P), cst)", where cst
15111   // is a byte mask indicating a consecutive number of bytes, check to see if
15112   // Y is known to provide just those bytes.  If so, we try to replace the
15113   // load / or / store sequence with a single (narrower) store, which makes
15114   // the load dead.
15115   if (Opc == ISD::OR) {
15116     std::pair<unsigned, unsigned> MaskedLoad;
15117     MaskedLoad = CheckForMaskedLoad(Value.getOperand(0), Ptr, Chain);
15118     if (MaskedLoad.first)
15119       if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskedLoad,
15120                                                   Value.getOperand(1), ST,this))
15121         return NewST;
15122 
15123     // Or is commutative, so try swapping X and Y.
15124     MaskedLoad = CheckForMaskedLoad(Value.getOperand(1), Ptr, Chain);
15125     if (MaskedLoad.first)
15126       if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskedLoad,
15127                                                   Value.getOperand(0), ST,this))
15128         return NewST;
15129   }
15130 
15131   if ((Opc != ISD::OR && Opc != ISD::XOR && Opc != ISD::AND) ||
15132       Value.getOperand(1).getOpcode() != ISD::Constant)
15133     return SDValue();
15134 
15135   SDValue N0 = Value.getOperand(0);
15136   if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
15137       Chain == SDValue(N0.getNode(), 1)) {
15138     LoadSDNode *LD = cast<LoadSDNode>(N0);
15139     if (LD->getBasePtr() != Ptr ||
15140         LD->getPointerInfo().getAddrSpace() !=
15141         ST->getPointerInfo().getAddrSpace())
15142       return SDValue();
15143 
15144     // Find the type to narrow the load / op / store to.
15145     SDValue N1 = Value.getOperand(1);
15146     unsigned BitWidth = N1.getValueSizeInBits();
15147     APInt Imm = cast<ConstantSDNode>(N1)->getAPIntValue();
15148     if (Opc == ISD::AND)
15149       Imm ^= APInt::getAllOnesValue(BitWidth);
15150     if (Imm == 0 || Imm.isAllOnesValue())
15151       return SDValue();
15152     unsigned ShAmt = Imm.countTrailingZeros();
15153     unsigned MSB = BitWidth - Imm.countLeadingZeros() - 1;
15154     unsigned NewBW = NextPowerOf2(MSB - ShAmt);
15155     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), NewBW);
15156     // The narrowing should be profitable, the load/store operation should be
15157     // legal (or custom) and the store size should be equal to the NewVT width.
15158     while (NewBW < BitWidth &&
15159            (NewVT.getStoreSizeInBits() != NewBW ||
15160             !TLI.isOperationLegalOrCustom(Opc, NewVT) ||
15161             !TLI.isNarrowingProfitable(VT, NewVT))) {
15162       NewBW = NextPowerOf2(NewBW);
15163       NewVT = EVT::getIntegerVT(*DAG.getContext(), NewBW);
15164     }
15165     if (NewBW >= BitWidth)
15166       return SDValue();
15167 
15168     // If the changed lsb region does not start at a NewBW-bit boundary,
15169     // start at the previous boundary.
15170     if (ShAmt % NewBW)
15171       ShAmt = (((ShAmt + NewBW - 1) / NewBW) * NewBW) - NewBW;
15172     APInt Mask = APInt::getBitsSet(BitWidth, ShAmt,
15173                                    std::min(BitWidth, ShAmt + NewBW));
15174     if ((Imm & Mask) == Imm) {
15175       APInt NewImm = (Imm & Mask).lshr(ShAmt).trunc(NewBW);
15176       if (Opc == ISD::AND)
15177         NewImm ^= APInt::getAllOnesValue(NewBW);
15178       uint64_t PtrOff = ShAmt / 8;
15179       // For big endian targets, we need to adjust the offset to the pointer to
15180       // load the correct bytes.
15181       if (DAG.getDataLayout().isBigEndian())
15182         PtrOff = (BitWidth + 7 - NewBW) / 8 - PtrOff;
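      // E.g. with BitWidth = 32, NewBW = 8, ShAmt = 8: PtrOff is 1 on
      // little-endian targets, while on big-endian the same byte lives at
      // (32 + 7 - 8) / 8 - 1 = 2 bytes from the base pointer.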
15183 
15184       unsigned NewAlign = MinAlign(LD->getAlignment(), PtrOff);
15185       Type *NewVTTy = NewVT.getTypeForEVT(*DAG.getContext());
15186       if (NewAlign < DAG.getDataLayout().getABITypeAlignment(NewVTTy))
15187         return SDValue();
15188 
15189       SDValue NewPtr = DAG.getMemBasePlusOffset(Ptr, PtrOff, SDLoc(LD));
15190       SDValue NewLD =
15191           DAG.getLoad(NewVT, SDLoc(N0), LD->getChain(), NewPtr,
15192                       LD->getPointerInfo().getWithOffset(PtrOff), NewAlign,
15193                       LD->getMemOperand()->getFlags(), LD->getAAInfo());
15194       SDValue NewVal = DAG.getNode(Opc, SDLoc(Value), NewVT, NewLD,
15195                                    DAG.getConstant(NewImm, SDLoc(Value),
15196                                                    NewVT));
15197       SDValue NewST =
15198           DAG.getStore(Chain, SDLoc(N), NewVal, NewPtr,
15199                        ST->getPointerInfo().getWithOffset(PtrOff), NewAlign);
15200 
15201       AddToWorklist(NewPtr.getNode());
15202       AddToWorklist(NewLD.getNode());
15203       AddToWorklist(NewVal.getNode());
15204       WorklistRemover DeadNodes(*this);
15205       DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), NewLD.getValue(1));
15206       ++OpsNarrowed;
15207       return NewST;
15208     }
15209   }
15210 
15211   return SDValue();
15212 }
15213 
15214 /// For a given floating point load / store pair, if the load value isn't used
15215 /// by any other operations, then consider transforming the pair to integer
15216 /// load / store operations if the target deems the transformation profitable.
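/// For example, an f32 load whose only use is an f32 store can become an
/// i32 load / store pair, avoiding a round-trip through the FP register
/// file on targets where that is more expensive.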
15217 SDValue DAGCombiner::TransformFPLoadStorePair(SDNode *N) {
15218   StoreSDNode *ST  = cast<StoreSDNode>(N);
15219   SDValue Value = ST->getValue();
15220   if (ISD::isNormalStore(ST) && ISD::isNormalLoad(Value.getNode()) &&
15221       Value.hasOneUse()) {
15222     LoadSDNode *LD = cast<LoadSDNode>(Value);
15223     EVT VT = LD->getMemoryVT();
15224     if (!VT.isFloatingPoint() ||
15225         VT != ST->getMemoryVT() ||
15226         LD->isNonTemporal() ||
15227         ST->isNonTemporal() ||
15228         LD->getPointerInfo().getAddrSpace() != 0 ||
15229         ST->getPointerInfo().getAddrSpace() != 0)
15230       return SDValue();
15231 
15232     EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits());
15233     if (!TLI.isOperationLegal(ISD::LOAD, IntVT) ||
15234         !TLI.isOperationLegal(ISD::STORE, IntVT) ||
15235         !TLI.isDesirableToTransformToIntegerOp(ISD::LOAD, VT) ||
15236         !TLI.isDesirableToTransformToIntegerOp(ISD::STORE, VT))
15237       return SDValue();
15238 
15239     unsigned LDAlign = LD->getAlignment();
15240     unsigned STAlign = ST->getAlignment();
15241     Type *IntVTTy = IntVT.getTypeForEVT(*DAG.getContext());
15242     unsigned ABIAlign = DAG.getDataLayout().getABITypeAlignment(IntVTTy);
15243     if (LDAlign < ABIAlign || STAlign < ABIAlign)
15244       return SDValue();
15245 
15246     SDValue NewLD =
15247         DAG.getLoad(IntVT, SDLoc(Value), LD->getChain(), LD->getBasePtr(),
15248                     LD->getPointerInfo(), LDAlign);
15249 
15250     SDValue NewST =
15251         DAG.getStore(ST->getChain(), SDLoc(N), NewLD, ST->getBasePtr(),
15252                      ST->getPointerInfo(), STAlign);
15253 
15254     AddToWorklist(NewLD.getNode());
15255     AddToWorklist(NewST.getNode());
15256     WorklistRemover DeadNodes(*this);
15257     DAG.ReplaceAllUsesOfValueWith(Value.getValue(1), NewLD.getValue(1));
15258     ++LdStFP2Int;
15259     return NewST;
15260   }
15261 
15262   return SDValue();
15263 }
15264 
15265 // This is a helper function for visitMUL to check the profitability
15266 // of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
15267 // MulNode is the original multiply, AddNode is (add x, c1),
15268 // and ConstNode is c2.
15269 //
15270 // If the (add x, c1) has multiple uses, we could increase
15271 // the number of adds if we make this transformation.
15272 // It would only be worth doing this if we can remove a
15273 // multiply in the process. Check for that here.
15274 // To illustrate:
15275 //     (A + c1) * c3
15276 //     (A + c2) * c3
15277 // We're checking for cases where we have common "c3 * A" expressions.
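// After folding both, the expressions become
//     (A * c3) + (c1 * c3)
//     (A * c3) + (c2 * c3)
// and the "A * c3" multiply is shared, so no multiply is added overall.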
15278 bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode,
15279                                               SDValue &AddNode,
15280                                               SDValue &ConstNode) {
15281   APInt Val;
15282 
15283   // If the add only has one use, this would be OK to do.
15284   if (AddNode.getNode()->hasOneUse())
15285     return true;
15286 
15287   // Walk all the users of the constant with which we're multiplying.
15288   for (SDNode *Use : ConstNode->uses()) {
15289     if (Use == MulNode) // This use is the one we're on right now. Skip it.
15290       continue;
15291 
15292     if (Use->getOpcode() == ISD::MUL) { // We have another multiply use.
15293       SDNode *OtherOp;
15294       SDNode *MulVar = AddNode.getOperand(0).getNode();
15295 
15296       // OtherOp is what we're multiplying against the constant.
15297       if (Use->getOperand(0) == ConstNode)
15298         OtherOp = Use->getOperand(1).getNode();
15299       else
15300         OtherOp = Use->getOperand(0).getNode();
15301 
15302       // Check to see if multiply is with the same operand of our "add".
15303       //
15304       //     ConstNode  = CONST
15305       //     Use = ConstNode * A  <-- visiting Use. OtherOp is A.
15306       //     ...
15307       //     AddNode  = (A + c1)  <-- MulVar is A.
15308       //         = AddNode * ConstNode   <-- current visiting instruction.
15309       //
15310       // If we make this transformation, we will have a common
15311       // multiply (ConstNode * A) that we can save.
15312       if (OtherOp == MulVar)
15313         return true;
15314 
15315       // Now check to see if a future expansion will give us a common
15316       // multiply.
15317       //
15318       //     ConstNode  = CONST
15319       //     AddNode    = (A + c1)
15320       //     ...   = AddNode * ConstNode <-- current visiting instruction.
15321       //     ...
15322       //     OtherOp = (A + c2)
15323       //     Use     = OtherOp * ConstNode <-- visiting Use.
15324       //
15325       // If we make this transformation, we will have a common
15326       // multiply (CONST * A) after we also do the same transformation
15327       // to the "Use" instruction.
15328       if (OtherOp->getOpcode() == ISD::ADD &&
15329           DAG.isConstantIntBuildVectorOrConstantInt(OtherOp->getOperand(1)) &&
15330           OtherOp->getOperand(0).getNode() == MulVar)
15331         return true;
15332     }
15333   }
15334 
15335   // Didn't find a case where this would be profitable.
15336   return false;
15337 }
15338 
15339 SDValue DAGCombiner::getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
15340                                          unsigned NumStores) {
15341   SmallVector<SDValue, 8> Chains;
15342   SmallPtrSet<const SDNode *, 8> Visited;
15343   SDLoc StoreDL(StoreNodes[0].MemNode);
15344 
15345   for (unsigned i = 0; i < NumStores; ++i) {
15346     Visited.insert(StoreNodes[i].MemNode);
15347   }
15348 
15349   // Don't include chains that are children (i.e. other candidate stores) or duplicates.
15350   for (unsigned i = 0; i < NumStores; ++i) {
15351     if (Visited.insert(StoreNodes[i].MemNode->getChain().getNode()).second)
15352       Chains.push_back(StoreNodes[i].MemNode->getChain());
15353   }
15354 
15355   assert(Chains.size() > 0 && "Chain should have generated a chain");
15356   return DAG.getTokenFactor(StoreDL, Chains);
15357 }
15358 
15359 bool DAGCombiner::MergeStoresOfConstantsOrVecElts(
15360     SmallVectorImpl<MemOpLink> &StoreNodes, EVT MemVT, unsigned NumStores,
15361     bool IsConstantSrc, bool UseVector, bool UseTrunc) {
15362   // Make sure we have something to merge.
15363   if (NumStores < 2)
15364     return false;
15365 
15366   // The latest Node in the DAG.
15367   SDLoc DL(StoreNodes[0].MemNode);
15368 
15369   TypeSize ElementSizeBits = MemVT.getStoreSizeInBits();
15370   unsigned SizeInBits = NumStores * ElementSizeBits;
15371   unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
15372 
15373   EVT StoreTy;
15374   if (UseVector) {
15375     unsigned Elts = NumStores * NumMemElts;
15376     // Get the type for the merged vector store.
15377     StoreTy = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
15378   } else
15379     StoreTy = EVT::getIntegerVT(*DAG.getContext(), SizeInBits);
15380 
15381   SDValue StoredVal;
15382   if (UseVector) {
15383     if (IsConstantSrc) {
15384       SmallVector<SDValue, 8> BuildVector;
15385       for (unsigned I = 0; I != NumStores; ++I) {
15386         StoreSDNode *St = cast<StoreSDNode>(StoreNodes[I].MemNode);
15387         SDValue Val = St->getValue();
15388         // If constant is of the wrong type, convert it now.
15389         if (MemVT != Val.getValueType()) {
15390           Val = peekThroughBitcasts(Val);
15391           // Deal with constants of wrong size.
15392           if (ElementSizeBits != Val.getValueSizeInBits()) {
15393             EVT IntMemVT =
15394                 EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits());
15395             if (isa<ConstantFPSDNode>(Val)) {
15396               // Not clear how to truncate FP values.
15397               return false;
15398             } else if (auto *C = dyn_cast<ConstantSDNode>(Val))
15399               Val = DAG.getConstant(C->getAPIntValue()
15400                                         .zextOrTrunc(Val.getValueSizeInBits())
15401                                         .zextOrTrunc(ElementSizeBits),
15402                                     SDLoc(C), IntMemVT);
15403           }
15404           // Make sure the correctly sized value is bitcast to the correct type.
15405           Val = DAG.getBitcast(MemVT, Val);
15406         }
15407         BuildVector.push_back(Val);
15408       }
15409       StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS
15410                                                : ISD::BUILD_VECTOR,
15411                               DL, StoreTy, BuildVector);
15412     } else {
15413       SmallVector<SDValue, 8> Ops;
15414       for (unsigned i = 0; i < NumStores; ++i) {
15415         StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
15416         SDValue Val = peekThroughBitcasts(St->getValue());
15417         // All operands of BUILD_VECTOR / CONCAT_VECTOR must be of
15418         // type MemVT. If the underlying value is not the correct
15419         // type, but it is an extraction of an appropriate vector we
15420         // can recast Val to be of the correct type. This may require
15421         // converting between EXTRACT_VECTOR_ELT and
15422         // EXTRACT_SUBVECTOR.
15423         if ((MemVT != Val.getValueType()) &&
15424             (Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
15425              Val.getOpcode() == ISD::EXTRACT_SUBVECTOR)) {
15426           EVT MemVTScalarTy = MemVT.getScalarType();
15427           // We may need to add a bitcast here to get types to line up.
15428           if (MemVTScalarTy != Val.getValueType().getScalarType()) {
15429             Val = DAG.getBitcast(MemVT, Val);
15430           } else {
15431             unsigned OpC = MemVT.isVector() ? ISD::EXTRACT_SUBVECTOR
15432                                             : ISD::EXTRACT_VECTOR_ELT;
15433             SDValue Vec = Val.getOperand(0);
15434             SDValue Idx = Val.getOperand(1);
15435             Val = DAG.getNode(OpC, SDLoc(Val), MemVT, Vec, Idx);
15436           }
15437         }
15438         Ops.push_back(Val);
15439       }
15440 
15441       // Build the extracted vector elements back into a vector.
15442       StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS
15443                                                : ISD::BUILD_VECTOR,
15444                               DL, StoreTy, Ops);
15445     }
15446   } else {
15447     // We should always use a vector store when merging extracted vector
15448     // elements, so this path implies a store of constants.
15449     assert(IsConstantSrc && "Merged vector elements should use vector store");
15450 
15451     APInt StoreInt(SizeInBits, 0);
15452 
15453     // Construct a single integer constant which is made of the smaller
15454     // constant inputs.
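    // E.g. merging four i8 stores of 0x11, 0x22, 0x33, 0x44 at offsets
    // 0..3 on a little-endian target produces the single i32 constant
    // 0x44332211, so the byte at offset 0 still reads back as 0x11.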
15455     bool IsLE = DAG.getDataLayout().isLittleEndian();
15456     for (unsigned i = 0; i < NumStores; ++i) {
15457       unsigned Idx = IsLE ? (NumStores - 1 - i) : i;
15458       StoreSDNode *St  = cast<StoreSDNode>(StoreNodes[Idx].MemNode);
15459 
15460       SDValue Val = St->getValue();
15461       Val = peekThroughBitcasts(Val);
15462       StoreInt <<= ElementSizeBits;
15463       if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val)) {
15464         StoreInt |= C->getAPIntValue()
15465                         .zextOrTrunc(ElementSizeBits)
15466                         .zextOrTrunc(SizeInBits);
15467       } else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Val)) {
15468         StoreInt |= C->getValueAPF()
15469                         .bitcastToAPInt()
15470                         .zextOrTrunc(ElementSizeBits)
15471                         .zextOrTrunc(SizeInBits);
15472         // If fp truncation is necessary give up for now.
15473         if (MemVT.getSizeInBits() != ElementSizeBits)
15474           return false;
15475       } else {
15476         llvm_unreachable("Invalid constant element type");
15477       }
15478     }
15479 
15480     // Create the new Load and Store operations.
15481     StoredVal = DAG.getConstant(StoreInt, DL, StoreTy);
15482   }
15483 
15484   LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
15485   SDValue NewChain = getMergeStoreChains(StoreNodes, NumStores);
15486 
15487   // Make sure we use a truncating store when it is necessary for legality.
15488   SDValue NewStore;
15489   if (!UseTrunc) {
15490     NewStore = DAG.getStore(NewChain, DL, StoredVal, FirstInChain->getBasePtr(),
15491                             FirstInChain->getPointerInfo(),
15492                             FirstInChain->getAlignment());
15493   } else { // Must be realized as a trunc store
15494     EVT LegalizedStoredValTy =
15495         TLI.getTypeToTransformTo(*DAG.getContext(), StoredVal.getValueType());
15496     unsigned LegalizedStoreSize = LegalizedStoredValTy.getSizeInBits();
15497     ConstantSDNode *C = cast<ConstantSDNode>(StoredVal);
15498     SDValue ExtendedStoreVal =
15499         DAG.getConstant(C->getAPIntValue().zextOrTrunc(LegalizedStoreSize), DL,
15500                         LegalizedStoredValTy);
15501     NewStore = DAG.getTruncStore(
15502         NewChain, DL, ExtendedStoreVal, FirstInChain->getBasePtr(),
15503         FirstInChain->getPointerInfo(), StoredVal.getValueType() /*TVT*/,
15504         FirstInChain->getAlignment(),
15505         FirstInChain->getMemOperand()->getFlags());
15506   }
15507 
15508   // Replace all merged stores with the new store.
15509   for (unsigned i = 0; i < NumStores; ++i)
15510     CombineTo(StoreNodes[i].MemNode, NewStore);
15511 
15512   AddToWorklist(NewChain.getNode());
15513   return true;
15514 }
15515 
15516 void DAGCombiner::getStoreMergeCandidates(
15517     StoreSDNode *St, SmallVectorImpl<MemOpLink> &StoreNodes,
15518     SDNode *&RootNode) {
15519   // This holds the base pointer, index, and the offset in bytes from the base
15520   // pointer.
15521   BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
15522   EVT MemVT = St->getMemoryVT();
15523 
15524   SDValue Val = peekThroughBitcasts(St->getValue());
15525   // We must have a base and an offset.
15526   if (!BasePtr.getBase().getNode())
15527     return;
15528 
15529   // Do not handle stores to undef base pointers.
15530   if (BasePtr.getBase().isUndef())
15531     return;
15532 
15533   bool IsConstantSrc = isa<ConstantSDNode>(Val) || isa<ConstantFPSDNode>(Val);
15534   bool IsExtractVecSrc = (Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
15535                           Val.getOpcode() == ISD::EXTRACT_SUBVECTOR);
15536   bool IsLoadSrc = isa<LoadSDNode>(Val);
15537   BaseIndexOffset LBasePtr;
15538   // Match on loadbaseptr if relevant.
15539   EVT LoadVT;
15540   if (IsLoadSrc) {
15541     auto *Ld = cast<LoadSDNode>(Val);
15542     LBasePtr = BaseIndexOffset::match(Ld, DAG);
15543     LoadVT = Ld->getMemoryVT();
15544     // Load and store should be the same type.
15545     if (MemVT != LoadVT)
15546       return;
15547     // Loads must only have one use.
15548     if (!Ld->hasNUsesOfValue(1, 0))
15549       return;
15550     // The memory operands must not be volatile/indexed/atomic.
15551     // TODO: May be able to relax for unordered atomics (see D66309)
15552     if (!Ld->isSimple() || Ld->isIndexed())
15553       return;
15554   }
15555   auto CandidateMatch = [&](StoreSDNode *Other, BaseIndexOffset &Ptr,
15556                             int64_t &Offset) -> bool {
15557     // The memory operands must not be volatile/indexed/atomic.
15558     // TODO: May be able to relax for unordered atomics (see D66309)
15559     if (!Other->isSimple() ||  Other->isIndexed())
15560       return false;
15561     // Don't mix temporal stores with non-temporal stores.
15562     if (St->isNonTemporal() != Other->isNonTemporal())
15563       return false;
15564     SDValue OtherBC = peekThroughBitcasts(Other->getValue());
15565     // Allow merging constants of different types as integers.
15566     bool NoTypeMatch = (MemVT.isInteger()) ? !MemVT.bitsEq(Other->getMemoryVT())
15567                                            : Other->getMemoryVT() != MemVT;
15568     if (IsLoadSrc) {
15569       if (NoTypeMatch)
15570         return false;
15571       // The Load's Base Ptr must also match
15572       if (LoadSDNode *OtherLd = dyn_cast<LoadSDNode>(OtherBC)) {
15573         BaseIndexOffset LPtr = BaseIndexOffset::match(OtherLd, DAG);
15574         if (LoadVT != OtherLd->getMemoryVT())
15575           return false;
15576         // Loads must only have one use.
15577         if (!OtherLd->hasNUsesOfValue(1, 0))
15578           return false;
15579         // The memory operands must not be volatile/indexed/atomic.
15580         // TODO: May be able to relax for unordered atomics (see D66309)
15581         if (!OtherLd->isSimple() ||
15582             OtherLd->isIndexed())
15583           return false;
15584         // Don't mix temporal loads with non-temporal loads.
15585         if (cast<LoadSDNode>(Val)->isNonTemporal() != OtherLd->isNonTemporal())
15586           return false;
15587         if (!(LBasePtr.equalBaseIndex(LPtr, DAG)))
15588           return false;
15589       } else
15590         return false;
15591     }
15592     if (IsConstantSrc) {
15593       if (NoTypeMatch)
15594         return false;
15595       if (!(isa<ConstantSDNode>(OtherBC) || isa<ConstantFPSDNode>(OtherBC)))
15596         return false;
15597     }
15598     if (IsExtractVecSrc) {
15599       // Do not merge truncated stores here.
15600       if (Other->isTruncatingStore())
15601         return false;
15602       if (!MemVT.bitsEq(OtherBC.getValueType()))
15603         return false;
15604       if (OtherBC.getOpcode() != ISD::EXTRACT_VECTOR_ELT &&
15605           OtherBC.getOpcode() != ISD::EXTRACT_SUBVECTOR)
15606         return false;
15607     }
15608     Ptr = BaseIndexOffset::match(Other, DAG);
15609     return (BasePtr.equalBaseIndex(Ptr, DAG, Offset));
15610   };
15611 
15612   // Check whether the pair of StoreNode and RootNode has already bailed
15613   // out of the dependence check more times than the limit allows.
15614   auto OverLimitInDependenceCheck = [&](SDNode *StoreNode,
15615                                         SDNode *RootNode) -> bool {
15616     auto RootCount = StoreRootCountMap.find(StoreNode);
15617     if (RootCount != StoreRootCountMap.end() &&
15618         RootCount->second.first == RootNode &&
15619         RootCount->second.second > StoreMergeDependenceLimit)
15620       return true;
15621     return false;
15622   };
15623 
15624   // We are looking for a root node which is an ancestor to all mergeable
15625   // stores. We search up through a load, to our root and then down
15626   // through all children. For instance we will find Store{1,2,3} if
15627   // St is Store1, Store2, or Store3 where the root is not a load,
15628   // which is always true for non-volatile ops. TODO: Expand
15629   // the search to find all valid candidates through multiple layers of loads.
15630   //
15631   // Root
15632   // |-------|-------|
15633   // Load    Load    Store3
15634   // |       |
15635   // Store1   Store2
15636   //
15637   // FIXME: We should be able to climb and
15638   // descend TokenFactors to find candidates as well.
15639 
15640   RootNode = St->getChain().getNode();
15641 
15642   unsigned NumNodesExplored = 0;
15643   if (LoadSDNode *Ldn = dyn_cast<LoadSDNode>(RootNode)) {
15644     RootNode = Ldn->getChain().getNode();
15645     for (auto I = RootNode->use_begin(), E = RootNode->use_end();
15646          I != E && NumNodesExplored < 1024; ++I, ++NumNodesExplored)
15647       if (I.getOperandNo() == 0 && isa<LoadSDNode>(*I)) // walk down chain
15648         for (auto I2 = (*I)->use_begin(), E2 = (*I)->use_end(); I2 != E2; ++I2)
15649           if (I2.getOperandNo() == 0)
15650             if (StoreSDNode *OtherST = dyn_cast<StoreSDNode>(*I2)) {
15651               BaseIndexOffset Ptr;
15652               int64_t PtrDiff;
15653               if (CandidateMatch(OtherST, Ptr, PtrDiff) &&
15654                   !OverLimitInDependenceCheck(OtherST, RootNode))
15655                 StoreNodes.push_back(MemOpLink(OtherST, PtrDiff));
15656             }
15657   } else
15658     for (auto I = RootNode->use_begin(), E = RootNode->use_end();
15659          I != E && NumNodesExplored < 1024; ++I, ++NumNodesExplored)
15660       if (I.getOperandNo() == 0)
15661         if (StoreSDNode *OtherST = dyn_cast<StoreSDNode>(*I)) {
15662           BaseIndexOffset Ptr;
15663           int64_t PtrDiff;
15664           if (CandidateMatch(OtherST, Ptr, PtrDiff) &&
15665               !OverLimitInDependenceCheck(OtherST, RootNode))
15666             StoreNodes.push_back(MemOpLink(OtherST, PtrDiff));
15667         }
15668 }
15669 
15670 // We need to check that merging these stores does not cause a loop in
15671 // the DAG. Any store candidate may depend on another candidate
15672 // indirectly through its operand (we already consider dependencies
15673 // through the chain). Check in parallel by searching up from
15674 // non-chain operands of candidates.
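// For example, a candidate store's value may be a load whose chain
// reaches another candidate; merging the two stores into one node would
// then make the merged store a predecessor of itself.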
15675 bool DAGCombiner::checkMergeStoreCandidatesForDependencies(
15676     SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
15677     SDNode *RootNode) {
15678   // FIXME: We should be able to truncate a full search of
15679   // predecessors by doing a BFS and keeping tabs on the originating
15680   // stores from which worklist nodes come, in a similar way to
15681   // TokenFactor simplification.
15682 
15683   SmallPtrSet<const SDNode *, 32> Visited;
15684   SmallVector<const SDNode *, 8> Worklist;
15685 
15686   // RootNode is a predecessor to all candidates so we need not search
15687   // past it. Add RootNode (peeking through TokenFactors). Do not count
15688   // these towards size check.
15689 
15690   Worklist.push_back(RootNode);
15691   while (!Worklist.empty()) {
15692     auto N = Worklist.pop_back_val();
15693     if (!Visited.insert(N).second)
15694       continue; // Already present in Visited.
15695     if (N->getOpcode() == ISD::TokenFactor) {
15696       for (SDValue Op : N->ops())
15697         Worklist.push_back(Op.getNode());
15698     }
15699   }
15700 
15701   // Don't count pruning nodes towards max.
15702   unsigned int Max = 1024 + Visited.size();
15703   // Search Ops of store candidates.
15704   for (unsigned i = 0; i < NumStores; ++i) {
15705     SDNode *N = StoreNodes[i].MemNode;
15706     // Of the 4 Store Operands:
15707     //   * Chain (Op 0) -> We have already considered these
15708     //                    in candidate selection and can be
15709     //                    safely ignored
15710     //   * Value (Op 1) -> Cycles may happen (e.g. through load chains)
15711     //   * Address (Op 2) -> Merged addresses may only vary by a fixed constant,
15712     //                       but aren't necessarily from the same base node, so
15713     //                       cycles possible (e.g. via indexed store).
15714     //   * (Op 3) -> Represents the pre or post-indexing offset (or undef for
15715     //               non-indexed stores). Not constant on all targets (e.g. ARM)
15716     //               and so can participate in a cycle.
15717     for (unsigned j = 1; j < N->getNumOperands(); ++j)
15718       Worklist.push_back(N->getOperand(j).getNode());
15719   }
15720   // Search through DAG. We can stop early if we find a store node.
15721   for (unsigned i = 0; i < NumStores; ++i)
15722     if (SDNode::hasPredecessorHelper(StoreNodes[i].MemNode, Visited, Worklist,
15723                                      Max)) {
15724       // If the search bails out, record the StoreNode and RootNode in the
15725       // StoreRootCountMap. If we have seen the pair more times than the limit,
15726       // we won't add the StoreNode into the StoreNodes set again.
15727       if (Visited.size() >= Max) {
15728         auto &RootCount = StoreRootCountMap[StoreNodes[i].MemNode];
15729         if (RootCount.first == RootNode)
15730           RootCount.second++;
15731         else
15732           RootCount = {RootNode, 1};
15733       }
15734       return false;
15735     }
15736   return true;
15737 }
15738 
15739 bool DAGCombiner::MergeConsecutiveStores(StoreSDNode *St) {
15740   if (OptLevel == CodeGenOpt::None || !EnableStoreMerging)
15741     return false;
15742 
15743   EVT MemVT = St->getMemoryVT();
15744   int64_t ElementSizeBytes = MemVT.getStoreSize();
15745   unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
15746 
15747   if (MemVT.getSizeInBits() * 2 > MaximumLegalStoreInBits)
15748     return false;
15749 
15750   bool NoVectors = DAG.getMachineFunction().getFunction().hasFnAttribute(
15751       Attribute::NoImplicitFloat);
15752 
15753   // This function cannot currently deal with non-byte-sized memory sizes.
15754   if (ElementSizeBytes * 8 != (int64_t)MemVT.getSizeInBits())
15755     return false;
15756 
15757   if (!MemVT.isSimple())
15758     return false;
15759 
15760   // Perform an early exit check. Do not bother looking at stored values that
15761   // are not constants, loads, or extracted vector elements.
15762   SDValue StoredVal = peekThroughBitcasts(St->getValue());
15763   bool IsLoadSrc = isa<LoadSDNode>(StoredVal);
15764   bool IsConstantSrc = isa<ConstantSDNode>(StoredVal) ||
15765                        isa<ConstantFPSDNode>(StoredVal);
15766   bool IsExtractVecSrc = (StoredVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
15767                           StoredVal.getOpcode() == ISD::EXTRACT_SUBVECTOR);
15768   bool IsNonTemporalStore = St->isNonTemporal();
15769   bool IsNonTemporalLoad =
15770       IsLoadSrc && cast<LoadSDNode>(StoredVal)->isNonTemporal();
15771 
15772   if (!IsConstantSrc && !IsLoadSrc && !IsExtractVecSrc)
15773     return false;
15774 
15775   SmallVector<MemOpLink, 8> StoreNodes;
15776   SDNode *RootNode;
15777   // Find potential store merge candidates by searching through chain sub-DAG
15778   getStoreMergeCandidates(St, StoreNodes, RootNode);
15779 
15780   // Check if there is anything to merge.
15781   if (StoreNodes.size() < 2)
15782     return false;
15783 
15784   // Sort the memory operands according to their distance from the
15785   // base pointer.
15786   llvm::sort(StoreNodes, [](MemOpLink LHS, MemOpLink RHS) {
15787     return LHS.OffsetFromBase < RHS.OffsetFromBase;
15788   });
15789 
15790   // Store Merge attempts to merge the lowest stores first. This
15791   // generally works out: when the merge succeeds, the remaining stores
15792   // are checked after the first collection of stores is merged. However,
15793   // in the case that a non-mergeable store is found first, e.g., {p[-2],
15794   // p[0], p[1], p[2], p[3]}, we would fail and miss the subsequent
15795   // mergeable cases. To prevent this, we prune such stores from the
15796   // front of StoreNodes here.
15797 
15798   bool RV = false;
15799   while (StoreNodes.size() > 1) {
15800     size_t StartIdx = 0;
15801     while ((StartIdx + 1 < StoreNodes.size()) &&
15802            StoreNodes[StartIdx].OffsetFromBase + ElementSizeBytes !=
15803                StoreNodes[StartIdx + 1].OffsetFromBase)
15804       ++StartIdx;
15805 
15806     // Bail if we don't have enough candidates to merge.
15807     if (StartIdx + 1 >= StoreNodes.size())
15808       return RV;
15809 
15810     if (StartIdx)
15811       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + StartIdx);
15812 
15813     // Scan the memory operations on the chain and find the first
15814     // non-consecutive store memory address.
15815     unsigned NumConsecutiveStores = 1;
15816     int64_t StartAddress = StoreNodes[0].OffsetFromBase;
15817     // Check that the addresses are consecutive starting from the second
15818     // element in the list of stores.
15819     for (unsigned i = 1, e = StoreNodes.size(); i < e; ++i) {
15820       int64_t CurrAddress = StoreNodes[i].OffsetFromBase;
15821       if (CurrAddress - StartAddress != (ElementSizeBytes * i))
15822         break;
15823       NumConsecutiveStores = i + 1;
15824     }
15825 
15826     if (NumConsecutiveStores < 2) {
15827       StoreNodes.erase(StoreNodes.begin(),
15828                        StoreNodes.begin() + NumConsecutiveStores);
15829       continue;
15830     }
15831 
15832     // The node with the lowest store address.
15833     LLVMContext &Context = *DAG.getContext();
15834     const DataLayout &DL = DAG.getDataLayout();
15835 
15836     // Store the constants into memory as one consecutive store.
15837     if (IsConstantSrc) {
15838       while (NumConsecutiveStores >= 2) {
15839         LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
15840         unsigned FirstStoreAS = FirstInChain->getAddressSpace();
15841         unsigned FirstStoreAlign = FirstInChain->getAlignment();
15842         unsigned LastLegalType = 1;
15843         unsigned LastLegalVectorType = 1;
15844         bool LastIntegerTrunc = false;
15845         bool NonZero = false;
15846         unsigned FirstZeroAfterNonZero = NumConsecutiveStores;
15847         for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
15848           StoreSDNode *ST = cast<StoreSDNode>(StoreNodes[i].MemNode);
15849           SDValue StoredVal = ST->getValue();
15850           bool IsElementZero = false;
15851           if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(StoredVal))
15852             IsElementZero = C->isNullValue();
15853           else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(StoredVal))
15854             IsElementZero = C->getConstantFPValue()->isNullValue();
15855           if (IsElementZero) {
15856             if (NonZero && FirstZeroAfterNonZero == NumConsecutiveStores)
15857               FirstZeroAfterNonZero = i;
15858           }
15859           NonZero |= !IsElementZero;
15860 
15861           // Find a legal type for the constant store.
15862           unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
15863           EVT StoreTy = EVT::getIntegerVT(Context, SizeInBits);
15864           bool IsFast = false;
15865 
15866           // Break early when size is too large to be legal.
15867           if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
15868             break;
15869 
15870           if (TLI.isTypeLegal(StoreTy) &&
15871               TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) &&
15872               TLI.allowsMemoryAccess(Context, DL, StoreTy,
15873                                      *FirstInChain->getMemOperand(), &IsFast) &&
15874               IsFast) {
15875             LastIntegerTrunc = false;
15876             LastLegalType = i + 1;
15877             // Or check whether a truncstore is legal.
15878           } else if (TLI.getTypeAction(Context, StoreTy) ==
15879                      TargetLowering::TypePromoteInteger) {
15880             EVT LegalizedStoredValTy =
15881                 TLI.getTypeToTransformTo(Context, StoredVal.getValueType());
15882             if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) &&
15883                 TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy, DAG) &&
15884                 TLI.allowsMemoryAccess(Context, DL, StoreTy,
15885                                        *FirstInChain->getMemOperand(),
15886                                        &IsFast) &&
15887                 IsFast) {
15888               LastIntegerTrunc = true;
15889               LastLegalType = i + 1;
15890             }
15891           }
15892 
15893           // We only use vectors if the constant is known to be zero or the
15894           // target allows it and the function is not marked with the
15895           // noimplicitfloat attribute.
15896           if ((!NonZero ||
15897                TLI.storeOfVectorConstantIsCheap(MemVT, i + 1, FirstStoreAS)) &&
15898               !NoVectors) {
15899             // Find a legal type for the vector store.
15900             unsigned Elts = (i + 1) * NumMemElts;
15901             EVT Ty = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
15902             if (TLI.isTypeLegal(Ty) && TLI.isTypeLegal(MemVT) &&
15903                 TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG) &&
15904                 TLI.allowsMemoryAccess(
15905                     Context, DL, Ty, *FirstInChain->getMemOperand(), &IsFast) &&
15906                 IsFast)
15907               LastLegalVectorType = i + 1;
15908           }
15909         }
15910 
15911         bool UseVector = (LastLegalVectorType > LastLegalType) && !NoVectors;
15912         unsigned NumElem = (UseVector) ? LastLegalVectorType : LastLegalType;
15913 
15914         // Check if we found a legal integer type that creates a meaningful
15915         // merge.
15916         if (NumElem < 2) {
15917           // We know that candidate stores are in order and of correct
15918           // shape. While there is no mergeable sequence from the
15919           // beginning, one may start later in the sequence. The only
15920           // reason a merge of size N could have failed where another of
15921           // the same size would not have, is if the alignment has
15922           // improved or we've dropped a non-zero value. Drop as many
15923           // candidates as we can here.
15924           unsigned NumSkip = 1;
15925           while (
15926               (NumSkip < NumConsecutiveStores) &&
15927               (NumSkip < FirstZeroAfterNonZero) &&
15928               (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign))
15929             NumSkip++;
15930 
15931           StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
15932           NumConsecutiveStores -= NumSkip;
15933           continue;
15934         }
15935 
15936         // Check that we can merge these candidates without causing a cycle.
15937         if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem,
15938                                                       RootNode)) {
15939           StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
15940           NumConsecutiveStores -= NumElem;
15941           continue;
15942         }
15943 
15944         RV |= MergeStoresOfConstantsOrVecElts(StoreNodes, MemVT, NumElem, true,
15945                                               UseVector, LastIntegerTrunc);
15946 
15947         // Remove merged stores for next iteration.
15948         StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
15949         NumConsecutiveStores -= NumElem;
15950       }
15951       continue;
15952     }
15953 
15954     // When extracting multiple vector elements, try to store them
15955     // in one vector store rather than a sequence of scalar stores.
15956     if (IsExtractVecSrc) {
15957       // Loop on Consecutive Stores on success.
15958       while (NumConsecutiveStores >= 2) {
15959         LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
15960         unsigned FirstStoreAS = FirstInChain->getAddressSpace();
15961         unsigned FirstStoreAlign = FirstInChain->getAlignment();
15962         unsigned NumStoresToMerge = 1;
15963         for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
15964           // Find a legal type for the vector store.
15965           unsigned Elts = (i + 1) * NumMemElts;
15966           EVT Ty =
15967               EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
15968           bool IsFast;
15969 
15970           // Break early when size is too large to be legal.
15971           if (Ty.getSizeInBits() > MaximumLegalStoreInBits)
15972             break;
15973 
15974           if (TLI.isTypeLegal(Ty) &&
15975               TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG) &&
15976               TLI.allowsMemoryAccess(Context, DL, Ty,
15977                                      *FirstInChain->getMemOperand(), &IsFast) &&
15978               IsFast)
15979             NumStoresToMerge = i + 1;
15980         }
15981 
15982         // Check if we found a legal integer type creating a meaningful
15983         // merge.
15984         if (NumStoresToMerge < 2) {
15985           // We know that candidate stores are in order and of correct
15986           // shape. While there is no mergeable sequence from the
15987           // beginning, one may start later in the sequence. The only
15988           // reason a merge of size N could have failed where another of
15989           // the same size would not have, is if the alignment has
15990           // improved. Drop as many candidates as we can here.
15991           unsigned NumSkip = 1;
15992           while (
15993               (NumSkip < NumConsecutiveStores) &&
15994               (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign))
15995             NumSkip++;
15996 
15997           StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
15998           NumConsecutiveStores -= NumSkip;
15999           continue;
16000         }
16001 
16002         // Check that we can merge these candidates without causing a cycle.
16003         if (!checkMergeStoreCandidatesForDependencies(
16004                 StoreNodes, NumStoresToMerge, RootNode)) {
16005           StoreNodes.erase(StoreNodes.begin(),
16006                            StoreNodes.begin() + NumStoresToMerge);
16007           NumConsecutiveStores -= NumStoresToMerge;
16008           continue;
16009         }
16010 
16011         RV |= MergeStoresOfConstantsOrVecElts(
16012             StoreNodes, MemVT, NumStoresToMerge, false, true, false);
16013 
16014         StoreNodes.erase(StoreNodes.begin(),
16015                          StoreNodes.begin() + NumStoresToMerge);
16016         NumConsecutiveStores -= NumStoresToMerge;
16017       }
16018       continue;
16019     }
16020 
16021     // Below we handle the case of multiple consecutive stores that
16022     // come from multiple consecutive loads. We merge them into a single
16023     // wide load and a single wide store.
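    // For example, four consecutive i8 copies of the form dst[i] = src[i]
    // can become a single i32 load from src followed by a single i32
    // store to dst, when the target allows the wider access.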
16024 
16025     // Look for load nodes which are used by the stored values.
16026     SmallVector<MemOpLink, 8> LoadNodes;
16027 
16028     // Find acceptable loads. Loads need to have the same chain (token factor),
16029     // must not be zext, volatile, or indexed, and they must be consecutive.
16030     BaseIndexOffset LdBasePtr;
16031 
16032     for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
16033       StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
16034       SDValue Val = peekThroughBitcasts(St->getValue());
16035       LoadSDNode *Ld = cast<LoadSDNode>(Val);
16036 
16037       BaseIndexOffset LdPtr = BaseIndexOffset::match(Ld, DAG);
16038       // If this is not the first ptr that we check.
16039       int64_t LdOffset = 0;
16040       if (LdBasePtr.getBase().getNode()) {
16041         // The base ptr must be the same.
16042         if (!LdBasePtr.equalBaseIndex(LdPtr, DAG, LdOffset))
16043           break;
16044       } else {
16045         // Record this base pointer; all other base pointers must match it.
16046         LdBasePtr = LdPtr;
16047       }
16048 
16049       // We found a potential memory operand to merge.
16050       LoadNodes.push_back(MemOpLink(Ld, LdOffset));
16051     }
16052 
16053     while (NumConsecutiveStores >= 2 && LoadNodes.size() >= 2) {
16054       // If we have load/store pair instructions and we only have two values,
16055       // don't bother merging.
16056       unsigned RequiredAlignment;
16057       if (LoadNodes.size() == 2 &&
16058           TLI.hasPairedLoad(MemVT, RequiredAlignment) &&
16059           StoreNodes[0].MemNode->getAlignment() >= RequiredAlignment) {
16060         StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 2);
16061         LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + 2);
16062         break;
16063       }
16064       LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
16065       unsigned FirstStoreAS = FirstInChain->getAddressSpace();
16066       unsigned FirstStoreAlign = FirstInChain->getAlignment();
16067       LoadSDNode *FirstLoad = cast<LoadSDNode>(LoadNodes[0].MemNode);
16068       unsigned FirstLoadAlign = FirstLoad->getAlignment();
16069 
16070       // Scan the memory operations on the chain and find the first
16071       // non-consecutive load memory address. These variables hold the index in
16072       // the store node array.
16073 
16074       unsigned LastConsecutiveLoad = 1;
16075 
16076       // These variables refer to a size, not an index in the array.
16077       unsigned LastLegalVectorType = 1;
16078       unsigned LastLegalIntegerType = 1;
16079       bool isDereferenceable = true;
16080       bool DoIntegerTruncate = false;
16081       StartAddress = LoadNodes[0].OffsetFromBase;
16082       SDValue FirstChain = FirstLoad->getChain();
16083       for (unsigned i = 1; i < LoadNodes.size(); ++i) {
16084         // All loads must share the same chain.
16085         if (LoadNodes[i].MemNode->getChain() != FirstChain)
16086           break;
16087 
16088         int64_t CurrAddress = LoadNodes[i].OffsetFromBase;
16089         if (CurrAddress - StartAddress != (ElementSizeBytes * i))
16090           break;
16091         LastConsecutiveLoad = i;
16092 
16093         if (isDereferenceable && !LoadNodes[i].MemNode->isDereferenceable())
16094           isDereferenceable = false;
16095 
16096         // Find a legal type for the vector store.
16097         unsigned Elts = (i + 1) * NumMemElts;
16098         EVT StoreTy = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
16099 
16100         // Break early when size is too large to be legal.
16101         if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
16102           break;
16103 
16104         bool IsFastSt, IsFastLd;
16105         if (TLI.isTypeLegal(StoreTy) &&
16106             TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) &&
16107             TLI.allowsMemoryAccess(Context, DL, StoreTy,
16108                                    *FirstInChain->getMemOperand(), &IsFastSt) &&
16109             IsFastSt &&
16110             TLI.allowsMemoryAccess(Context, DL, StoreTy,
16111                                    *FirstLoad->getMemOperand(), &IsFastLd) &&
16112             IsFastLd) {
16113           LastLegalVectorType = i + 1;
16114         }
16115 
16116         // Find a legal type for the integer store.
16117         unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
16118         StoreTy = EVT::getIntegerVT(Context, SizeInBits);
16119         if (TLI.isTypeLegal(StoreTy) &&
16120             TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) &&
16121             TLI.allowsMemoryAccess(Context, DL, StoreTy,
16122                                    *FirstInChain->getMemOperand(), &IsFastSt) &&
16123             IsFastSt &&
16124             TLI.allowsMemoryAccess(Context, DL, StoreTy,
16125                                    *FirstLoad->getMemOperand(), &IsFastLd) &&
16126             IsFastLd) {
16127           LastLegalIntegerType = i + 1;
16128           DoIntegerTruncate = false;
16129           // Or check whether a truncstore and extload is legal.
16130         } else if (TLI.getTypeAction(Context, StoreTy) ==
16131                    TargetLowering::TypePromoteInteger) {
16132           EVT LegalizedStoredValTy = TLI.getTypeToTransformTo(Context, StoreTy);
16133           if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) &&
16134               TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy, DAG) &&
16135               TLI.isLoadExtLegal(ISD::ZEXTLOAD, LegalizedStoredValTy,
16136                                  StoreTy) &&
16137               TLI.isLoadExtLegal(ISD::SEXTLOAD, LegalizedStoredValTy,
16138                                  StoreTy) &&
16139               TLI.isLoadExtLegal(ISD::EXTLOAD, LegalizedStoredValTy, StoreTy) &&
16140               TLI.allowsMemoryAccess(Context, DL, StoreTy,
16141                                      *FirstInChain->getMemOperand(),
16142                                      &IsFastSt) &&
16143               IsFastSt &&
16144               TLI.allowsMemoryAccess(Context, DL, StoreTy,
16145                                      *FirstLoad->getMemOperand(), &IsFastLd) &&
16146               IsFastLd) {
16147             LastLegalIntegerType = i + 1;
16148             DoIntegerTruncate = true;
16149           }
16150         }
16151       }
16152 
16153       // Only use vector types if the vector type is larger than the integer
16154       // type. If they are the same, use integers.
16155       bool UseVectorTy =
16156           LastLegalVectorType > LastLegalIntegerType && !NoVectors;
16157       unsigned LastLegalType =
16158           std::max(LastLegalVectorType, LastLegalIntegerType);
16159 
16160       // We add +1 here because the LastXXX variables refer to a location
16161       // (an index) while NumElem refers to a count (an array size).
16162       unsigned NumElem =
16163           std::min(NumConsecutiveStores, LastConsecutiveLoad + 1);
16164       NumElem = std::min(LastLegalType, NumElem);
16165 
16166       if (NumElem < 2) {
16167         // We know that candidate stores are in order and of correct
16168         // shape. While there is no mergeable sequence from the
16169         // beginning, one may start later in the sequence. The only
16170         // reason a merge of size N could have failed where another of
16171         // the same size would not have is if the alignment or either
16172         // the load or store has improved. Drop as many candidates as we
16173         // can here.
16174         unsigned NumSkip = 1;
16175         while ((NumSkip < LoadNodes.size()) &&
16176                (LoadNodes[NumSkip].MemNode->getAlignment() <= FirstLoadAlign) &&
16177                (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign))
16178           NumSkip++;
16179         StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
16180         LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumSkip);
16181         NumConsecutiveStores -= NumSkip;
16182         continue;
16183       }
16184 
16185       // Check that we can merge these candidates without causing a cycle.
16186       if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem,
16187                                                     RootNode)) {
16188         StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
16189         LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
16190         NumConsecutiveStores -= NumElem;
16191         continue;
16192       }
16193 
16194       // Find if it is better to use vectors or integers to load and store
16195       // to memory.
16196       EVT JointMemOpVT;
16197       if (UseVectorTy) {
16198         // Find a legal type for the vector store.
16199         unsigned Elts = NumElem * NumMemElts;
16200         JointMemOpVT = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
16201       } else {
16202         unsigned SizeInBits = NumElem * ElementSizeBytes * 8;
16203         JointMemOpVT = EVT::getIntegerVT(Context, SizeInBits);
16204       }
16205 
16206       SDLoc LoadDL(LoadNodes[0].MemNode);
16207       SDLoc StoreDL(StoreNodes[0].MemNode);
16208 
16209       // The merged loads are required to have the same incoming chain, so
16210       // using the first's chain is acceptable.
16211 
16212       SDValue NewStoreChain = getMergeStoreChains(StoreNodes, NumElem);
16213       AddToWorklist(NewStoreChain.getNode());
16214 
16215       MachineMemOperand::Flags LdMMOFlags =
16216           isDereferenceable ? MachineMemOperand::MODereferenceable
16217                             : MachineMemOperand::MONone;
16218       if (IsNonTemporalLoad)
16219         LdMMOFlags |= MachineMemOperand::MONonTemporal;
16220 
16221       MachineMemOperand::Flags StMMOFlags =
16222           IsNonTemporalStore ? MachineMemOperand::MONonTemporal
16223                              : MachineMemOperand::MONone;
16224 
16225       SDValue NewLoad, NewStore;
16226       if (UseVectorTy || !DoIntegerTruncate) {
16227         NewLoad =
16228             DAG.getLoad(JointMemOpVT, LoadDL, FirstLoad->getChain(),
16229                         FirstLoad->getBasePtr(), FirstLoad->getPointerInfo(),
16230                         FirstLoadAlign, LdMMOFlags);
16231         NewStore = DAG.getStore(
16232             NewStoreChain, StoreDL, NewLoad, FirstInChain->getBasePtr(),
16233             FirstInChain->getPointerInfo(), FirstStoreAlign, StMMOFlags);
16234       } else { // This must be the truncstore/extload case
16235         EVT ExtendedTy =
16236             TLI.getTypeToTransformTo(*DAG.getContext(), JointMemOpVT);
16237         NewLoad = DAG.getExtLoad(ISD::EXTLOAD, LoadDL, ExtendedTy,
16238                                  FirstLoad->getChain(), FirstLoad->getBasePtr(),
16239                                  FirstLoad->getPointerInfo(), JointMemOpVT,
16240                                  FirstLoadAlign, LdMMOFlags);
16241         NewStore = DAG.getTruncStore(NewStoreChain, StoreDL, NewLoad,
16242                                      FirstInChain->getBasePtr(),
16243                                      FirstInChain->getPointerInfo(),
16244                                      JointMemOpVT, FirstInChain->getAlignment(),
16245                                      FirstInChain->getMemOperand()->getFlags());
16246       }
16247 
16248       // Transfer chain users from old loads to the new load.
16249       for (unsigned i = 0; i < NumElem; ++i) {
16250         LoadSDNode *Ld = cast<LoadSDNode>(LoadNodes[i].MemNode);
16251         DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1),
16252                                       SDValue(NewLoad.getNode(), 1));
16253       }
16254 
16255       // Replace all of the stores with the new store. Recursively remove
16256       // the corresponding value if it is no longer used.
16257       for (unsigned i = 0; i < NumElem; ++i) {
16258         SDValue Val = StoreNodes[i].MemNode->getOperand(1);
16259         CombineTo(StoreNodes[i].MemNode, NewStore);
16260         if (Val.getNode()->use_empty())
16261           recursivelyDeleteUnusedNodes(Val.getNode());
16262       }
16263 
16264       RV = true;
16265       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
16266       LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
16267       NumConsecutiveStores -= NumElem;
16268     }
16269   }
16270   return RV;
16271 }
16272 
16273 SDValue DAGCombiner::replaceStoreChain(StoreSDNode *ST, SDValue BetterChain) {
16274   SDLoc SL(ST);
16275   SDValue ReplStore;
16276 
16277   // Replace the chain to avoid dependency.
16278   if (ST->isTruncatingStore()) {
16279     ReplStore = DAG.getTruncStore(BetterChain, SL, ST->getValue(),
16280                                   ST->getBasePtr(), ST->getMemoryVT(),
16281                                   ST->getMemOperand());
16282   } else {
16283     ReplStore = DAG.getStore(BetterChain, SL, ST->getValue(), ST->getBasePtr(),
16284                              ST->getMemOperand());
16285   }
16286 
16287   // Create token to keep both nodes around.
16288   SDValue Token = DAG.getNode(ISD::TokenFactor, SL,
16289                               MVT::Other, ST->getChain(), ReplStore);
16290 
16291   // Make sure the new and old chains are cleaned up.
16292   AddToWorklist(Token.getNode());
16293 
16294   // Don't add users to work list.
16295   return CombineTo(ST, Token, false);
16296 }
16297 
16298 SDValue DAGCombiner::replaceStoreOfFPConstant(StoreSDNode *ST) {
16299   SDValue Value = ST->getValue();
16300   if (Value.getOpcode() == ISD::TargetConstantFP)
16301     return SDValue();
16302 
16303   if (!ISD::isNormalStore(ST))
16304     return SDValue();
16305 
16306   SDLoc DL(ST);
16307 
16308   SDValue Chain = ST->getChain();
16309   SDValue Ptr = ST->getBasePtr();
16310 
16311   const ConstantFPSDNode *CFP = cast<ConstantFPSDNode>(Value);
16312 
16313   // NOTE: If the original store is volatile, this transform must not increase
16314   // the number of stores.  For example, on x86-32 an f64 can be stored in one
16315   // processor operation but an i64 (which is not legal) requires two.  So the
16316   // transform should not be done in this case.
16317 
16318   SDValue Tmp;
16319   switch (CFP->getSimpleValueType(0).SimpleTy) {
16320   default:
16321     llvm_unreachable("Unknown FP type");
16322   case MVT::f16:    // We don't do this for these yet.
16323   case MVT::f80:
16324   case MVT::f128:
16325   case MVT::ppcf128:
16326     return SDValue();
16327   case MVT::f32:
16328     if ((isTypeLegal(MVT::i32) && !LegalOperations && ST->isSimple()) ||
16329         TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) {
16331       Tmp = DAG.getConstant((uint32_t)CFP->getValueAPF().
16332                             bitcastToAPInt().getZExtValue(), SDLoc(CFP),
16333                             MVT::i32);
16334       return DAG.getStore(Chain, DL, Tmp, Ptr, ST->getMemOperand());
16335     }
16336 
16337     return SDValue();
16338   case MVT::f64:
16339     if ((TLI.isTypeLegal(MVT::i64) && !LegalOperations &&
16340          ST->isSimple()) ||
16341         TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i64)) {
16343       Tmp = DAG.getConstant(CFP->getValueAPF().bitcastToAPInt().
16344                             getZExtValue(), SDLoc(CFP), MVT::i64);
16345       return DAG.getStore(Chain, DL, Tmp,
16346                           Ptr, ST->getMemOperand());
16347     }
16348 
16349     if (ST->isSimple() &&
16350         TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) {
16351       // Many FP stores are not made apparent until after legalize, e.g. for
16352       // argument passing.  Since this is so common, custom legalize the
16353       // 64-bit integer store into two 32-bit stores.
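      // E.g. storing the f64 constant 1.0 (bit pattern 0x3FF0000000000000)
      // becomes an i32 store of 0x00000000 (Lo) and an i32 store of
      // 0x3FF00000 (Hi), with the two swapped on big-endian targets.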
16354       uint64_t Val = CFP->getValueAPF().bitcastToAPInt().getZExtValue();
16355       SDValue Lo = DAG.getConstant(Val & 0xFFFFFFFF, SDLoc(CFP), MVT::i32);
16356       SDValue Hi = DAG.getConstant(Val >> 32, SDLoc(CFP), MVT::i32);
16357       if (DAG.getDataLayout().isBigEndian())
16358         std::swap(Lo, Hi);
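      // (On big-endian targets the most-significant word lives at the lower
      // address, hence the swap above.)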

      unsigned Alignment = ST->getAlignment();
      MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
      AAMDNodes AAInfo = ST->getAAInfo();

      SDValue St0 = DAG.getStore(Chain, DL, Lo, Ptr, ST->getPointerInfo(),
                                 ST->getAlignment(), MMOFlags, AAInfo);
      Ptr = DAG.getMemBasePlusOffset(Ptr, 4, DL);
      Alignment = MinAlign(Alignment, 4U);
      SDValue St1 = DAG.getStore(Chain, DL, Hi, Ptr,
                                 ST->getPointerInfo().getWithOffset(4),
                                 Alignment, MMOFlags, AAInfo);
      return DAG.getNode(ISD::TokenFactor, DL, MVT::Other,
                         St0, St1);
    }

    return SDValue();
  }
}

SDValue DAGCombiner::visitSTORE(SDNode *N) {
  StoreSDNode *ST  = cast<StoreSDNode>(N);
  SDValue Chain = ST->getChain();
  SDValue Value = ST->getValue();
  SDValue Ptr   = ST->getBasePtr();

  // If this is a store of a bit convert, store the input value if the
  // resultant store does not need a higher alignment than the original.
  if (Value.getOpcode() == ISD::BITCAST && !ST->isTruncatingStore() &&
      ST->isUnindexed()) {
    EVT SVT = Value.getOperand(0).getValueType();
    // If the store is volatile, we only want to change the store type if the
    // resulting store is legal. Otherwise we might increase the number of
    // memory accesses. We don't care if the original type was legal or not
    // as we assume software couldn't rely on the number of accesses of an
    // illegal type.
    // TODO: May be able to relax for unordered atomics (see D66309)
    if (((!LegalOperations && ST->isSimple()) ||
         TLI.isOperationLegal(ISD::STORE, SVT)) &&
        TLI.isStoreBitCastBeneficial(Value.getValueType(), SVT,
                                     DAG, *ST->getMemOperand())) {
      return DAG.getStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
                          ST->getMemOperand());
    }
  }

  // Turn 'store undef, Ptr' -> nothing.
  if (Value.isUndef() && ST->isUnindexed())
    return Chain;

  // Try to infer better alignment information than the store already has.
  if (OptLevel != CodeGenOpt::None && ST->isUnindexed() && !ST->isAtomic()) {
    if (unsigned Align = DAG.InferPtrAlignment(Ptr)) {
      if (Align > ST->getAlignment() && ST->getSrcValueOffset() % Align == 0) {
        SDValue NewStore =
            DAG.getTruncStore(Chain, SDLoc(N), Value, Ptr, ST->getPointerInfo(),
                              ST->getMemoryVT(), Align,
                              ST->getMemOperand()->getFlags(), ST->getAAInfo());
        // NewStore will always be N as we are only refining the alignment.
        assert(NewStore.getNode() == N);
        (void)NewStore;
      }
    }
  }

  // Try transforming a pair of floating point load / store ops to integer
  // load / store ops.
  if (SDValue NewST = TransformFPLoadStorePair(N))
    return NewST;

  // Try transforming several stores into STORE (BSWAP).
  if (SDValue Store = MatchStoreCombine(ST))
    return Store;

  if (ST->isUnindexed()) {
    // Walk up chain skipping non-aliasing memory nodes, on this store and any
    // adjacent stores.
    if (findBetterNeighborChains(ST)) {
      // replaceStoreChain uses CombineTo, which handles all of the worklist
      // manipulation. Return the original node to not do anything else.
      return SDValue(ST, 0);
    }
    Chain = ST->getChain();
  }

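  // For a truncating store, only the low bits of the stored value matter.
  // Try to simplify the value under that assumption. Opaque constants are
  // skipped because they are deliberately protected from folding.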
  // FIXME: is there such a thing as a truncating indexed store?
  if (ST->isTruncatingStore() && ST->isUnindexed() &&
      Value.getValueType().isInteger() &&
      (!isa<ConstantSDNode>(Value) ||
       !cast<ConstantSDNode>(Value)->isOpaque())) {
    APInt TruncDemandedBits =
        APInt::getLowBitsSet(Value.getScalarValueSizeInBits(),
                             ST->getMemoryVT().getScalarSizeInBits());

    // See if we can simplify the input to this truncstore with knowledge that
    // only the low bits are being used.  For example:
    // "truncstore (or (shl x, 8), y), i8"  -> "truncstore y, i8"
    AddToWorklist(Value.getNode());
    if (SDValue Shorter = DAG.GetDemandedBits(Value, TruncDemandedBits))
      return DAG.getTruncStore(Chain, SDLoc(N), Shorter, Ptr, ST->getMemoryVT(),
                               ST->getMemOperand());

    // Otherwise, see if we can simplify the operation with
    // SimplifyDemandedBits, which only works if the value has a single use.
    if (SimplifyDemandedBits(Value, TruncDemandedBits)) {
      // Re-visit the store if anything changed and the store hasn't been
      // merged with another node (N is deleted). SimplifyDemandedBits will add
      // Value's node back to the worklist if necessary, but we also need to
      // re-visit the store node itself.
      if (N->getOpcode() != ISD::DELETED_NODE)
        AddToWorklist(N);
      return SDValue(N, 0);
    }
  }

  // If this is a load followed by a store to the same location, then the store
  // is dead/noop.
  // TODO: Can relax for unordered atomics (see D66309)
  if (LoadSDNode *Ld = dyn_cast<LoadSDNode>(Value)) {
    if (Ld->getBasePtr() == Ptr && ST->getMemoryVT() == Ld->getMemoryVT() &&
        ST->isUnindexed() && ST->isSimple() &&
        // There can't be any side effects between the load and store, such as
        // a call or store.
        Chain.reachesChainWithoutSideEffects(SDValue(Ld, 1))) {
      // The store is dead, remove it.
      return Chain;
    }
  }

  // TODO: Can relax for unordered atomics (see D66309)
  if (StoreSDNode *ST1 = dyn_cast<StoreSDNode>(Chain)) {
    if (ST->isUnindexed() && ST->isSimple() &&
        ST1->isUnindexed() && ST1->isSimple()) {
      if (ST1->getBasePtr() == Ptr && ST1->getValue() == Value &&
          ST->getMemoryVT() == ST1->getMemoryVT()) {
        // If this is a store followed by a store with the same value to the
        // same location, then the store is dead/noop.
        return Chain;
      }

      if (OptLevel != CodeGenOpt::None && ST1->hasOneUse() &&
          !ST1->getBasePtr().isUndef()) {
        const BaseIndexOffset STBase = BaseIndexOffset::match(ST, DAG);
        const BaseIndexOffset ChainBase = BaseIndexOffset::match(ST1, DAG);
        unsigned STBitSize = ST->getMemoryVT().getSizeInBits();
        unsigned ChainBitSize = ST1->getMemoryVT().getSizeInBits();
        // If the preceding store writes to a subset of the current store's
        // location and no other node is chained to it, we can effectively drop
        // the preceding store. Do not remove stores to undef as they may be
        // used as data sinks.
        if (STBase.contains(DAG, STBitSize, ChainBase, ChainBitSize)) {
          CombineTo(ST1, ST1->getChain());
          return SDValue();
        }
      }
    }
  }

  // If this is an FP_ROUND or TRUNC followed by a store, fold this into a
  // truncating store.  We can do this even if this is already a truncstore.
  if ((Value.getOpcode() == ISD::FP_ROUND || Value.getOpcode() == ISD::TRUNCATE)
      && Value.getNode()->hasOneUse() && ST->isUnindexed() &&
      TLI.isTruncStoreLegal(Value.getOperand(0).getValueType(),
                            ST->getMemoryVT())) {
    return DAG.getTruncStore(Chain, SDLoc(N), Value.getOperand(0),
                             Ptr, ST->getMemoryVT(), ST->getMemOperand());
  }

  // Always perform this optimization before types are legal. If the target
  // prefers, also try this after legalization to catch stores that were created
  // by intrinsics or other nodes.
  if (!LegalTypes || (TLI.mergeStoresAfterLegalization(ST->getMemoryVT()))) {
    while (true) {
      // There can be multiple store sequences on the same chain.
      // Keep trying to merge store sequences until we are unable to do so
      // or until we merge the last store on the chain.
      bool Changed = MergeConsecutiveStores(ST);
      if (!Changed) break;
      // Return N as the merge only uses CombineTo; no worklist cleanup
      // is necessary.
      if (N->getOpcode() == ISD::DELETED_NODE || !isa<StoreSDNode>(N))
        return SDValue(N, 0);
    }
  }

  // Try transforming N to an indexed store.
  if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
    return SDValue(N, 0);

  // Turn 'store float 1.0, Ptr' -> 'store int 0x12345678, Ptr'
  //
  // Make sure to do this only after attempting to merge stores in order to
  //  avoid changing the types of some subset of stores due to visit order,
  //  preventing their merging.
  if (isa<ConstantFPSDNode>(ST->getValue())) {
    if (SDValue NewSt = replaceStoreOfFPConstant(ST))
      return NewSt;
  }

  if (SDValue NewSt = splitMergedValStore(ST))
    return NewSt;

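  // As a last resort, try to narrow a load/op/store sequence so that it
  // operates on a smaller width.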
  return ReduceLoadOpStoreWidth(N);
}

SDValue DAGCombiner::visitLIFETIME_END(SDNode *N) {
  const auto *LifetimeEnd = cast<LifetimeSDNode>(N);
  if (!LifetimeEnd->hasOffset())
    return SDValue();

  const BaseIndexOffset LifetimeEndBase(N->getOperand(1), SDValue(),
                                        LifetimeEnd->getOffset(), false);

  // We walk up the chains to find stores.
  SmallVector<SDValue, 8> Chains = {N->getOperand(0)};
  while (!Chains.empty()) {
    SDValue Chain = Chains.back();
    Chains.pop_back();
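    // Conservatively stop at any chain value that has more than one use.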
    if (!Chain.hasOneUse())
      continue;
    switch (Chain.getOpcode()) {
    case ISD::TokenFactor:
      for (unsigned Nops = Chain.getNumOperands(); Nops;)
        Chains.push_back(Chain.getOperand(--Nops));
      break;
    case ISD::LIFETIME_START:
    case ISD::LIFETIME_END:
      // We can forward past any lifetime start/end that can be proven not to
      // alias the node.
      if (!isAlias(Chain.getNode(), N))
        Chains.push_back(Chain.getOperand(0));
      break;
    case ISD::STORE: {
      StoreSDNode *ST = dyn_cast<StoreSDNode>(Chain);
      // TODO: Can relax for unordered atomics (see D66309)
      if (!ST->isSimple() || ST->isIndexed())
        continue;
      const BaseIndexOffset StoreBase = BaseIndexOffset::match(ST, DAG);
      // If we store purely within object bounds just before its lifetime ends,
      // we can remove the store.
      if (LifetimeEndBase.contains(DAG, LifetimeEnd->getSize() * 8, StoreBase,
                                   ST->getMemoryVT().getStoreSizeInBits())) {
        LLVM_DEBUG(dbgs() << "\nRemoving store:"; StoreBase.dump();
                   dbgs() << "\nwithin LIFETIME_END of : ";
                   LifetimeEndBase.dump(); dbgs() << "\n");
        CombineTo(ST, ST->getChain());
        return SDValue(N, 0);
      }
    }
    }
  }
  return SDValue();
}

/// For the store instruction sequence below, the F and I values
/// are bundled together as an i64 value before being stored into memory.
/// Sometimes it is more efficient to generate separate stores for F and I,
/// which can remove the bitwise instructions or sink them to colder places.
///
///   (store (or (zext (bitcast F to i32) to i64),
///              (shl (zext I to i64), 32)), addr)  -->
///   (store F, addr) and (store I, addr+4)
///
/// Similarly, splitting other merged stores can also be beneficial, like:
/// For pair of {i32, i32}, i64 store --> two i32 stores.
/// For pair of {i32, i16}, i64 store --> two i32 stores.
/// For pair of {i16, i16}, i32 store --> two i16 stores.
/// For pair of {i16, i8},  i32 store --> two i16 stores.
/// For pair of {i8, i8},   i16 store --> two i8 stores.
///
/// We allow each target to determine specifically which kind of splitting is
/// supported.
///
/// The store patterns are commonly seen from the simple code snippet below
/// if std::make_pair(...) is SROA-transformed before being inlined into hoo.
///   void goo(const std::pair<int, float> &);
///   hoo() {
///     ...
///     goo(std::make_pair(tmp, ftmp));
///     ...
///   }
///
SDValue DAGCombiner::splitMergedValStore(StoreSDNode *ST) {
  if (OptLevel == CodeGenOpt::None)
    return SDValue();

  // Can't change the number of memory accesses for a volatile store or break
  // atomicity for an atomic one.
  if (!ST->isSimple())
    return SDValue();

  SDValue Val = ST->getValue();
  SDLoc DL(ST);

  // Match OR operand.
  if (!Val.getValueType().isScalarInteger() || Val.getOpcode() != ISD::OR)
    return SDValue();

  // Match SHL operand and get Lower and Higher parts of Val.
  SDValue Op1 = Val.getOperand(0);
  SDValue Op2 = Val.getOperand(1);
  SDValue Lo, Hi;
  if (Op1.getOpcode() != ISD::SHL) {
    std::swap(Op1, Op2);
    if (Op1.getOpcode() != ISD::SHL)
      return SDValue();
  }
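  // At this point Op1 is the SHL feeding the high half and Op2 supplies the
  // low half.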
  Lo = Op2;
  Hi = Op1.getOperand(0);
  if (!Op1.hasOneUse())
    return SDValue();

  // Match shift amount to HalfValBitSize.
  unsigned HalfValBitSize = Val.getValueSizeInBits() / 2;
  ConstantSDNode *ShAmt = dyn_cast<ConstantSDNode>(Op1.getOperand(1));
  if (!ShAmt || ShAmt->getAPIntValue() != HalfValBitSize)
    return SDValue();

  // Lo and Hi must be zero-extended from scalar integers no wider than
  // HalfValBitSize.
  if (Lo.getOpcode() != ISD::ZERO_EXTEND || !Lo.hasOneUse() ||
      !Lo.getOperand(0).getValueType().isScalarInteger() ||
      Lo.getOperand(0).getValueSizeInBits() > HalfValBitSize ||
      Hi.getOpcode() != ISD::ZERO_EXTEND || !Hi.hasOneUse() ||
      !Hi.getOperand(0).getValueType().isScalarInteger() ||
      Hi.getOperand(0).getValueSizeInBits() > HalfValBitSize)
    return SDValue();

  // Use the EVT of low and high parts before bitcast as the input
  // of target query.
  EVT LowTy = (Lo.getOperand(0).getOpcode() == ISD::BITCAST)
                  ? Lo.getOperand(0).getValueType()
                  : Lo.getValueType();
  EVT HighTy = (Hi.getOperand(0).getOpcode() == ISD::BITCAST)
                   ? Hi.getOperand(0).getValueType()
                   : Hi.getValueType();
  if (!TLI.isMultiStoresCheaperThanBitsMerge(LowTy, HighTy))
    return SDValue();

  // Start to split store.
  unsigned Alignment = ST->getAlignment();
  MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
  AAMDNodes AAInfo = ST->getAAInfo();

  // Change the sizes of Lo and Hi's value types to HalfValBitSize.
  EVT VT = EVT::getIntegerVT(*DAG.getContext(), HalfValBitSize);
  Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Lo.getOperand(0));
  Hi = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Hi.getOperand(0));

  SDValue Chain = ST->getChain();
  SDValue Ptr = ST->getBasePtr();
  // Lower value store.
  SDValue St0 = DAG.getStore(Chain, DL, Lo, Ptr, ST->getPointerInfo(),
                             ST->getAlignment(), MMOFlags, AAInfo);
  Ptr = DAG.getMemBasePlusOffset(Ptr, HalfValBitSize / 8, DL);
  // Higher value store.
  SDValue St1 =
      DAG.getStore(St0, DL, Hi, Ptr,
                   ST->getPointerInfo().getWithOffset(HalfValBitSize / 8),
                   Alignment / 2, MMOFlags, AAInfo);
  return St1;
}

/// Convert a disguised subvector insertion into a shuffle:
SDValue DAGCombiner::combineInsertEltToShuffle(SDNode *N, unsigned InsIndex) {
  assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT &&
         "Expected insert_vector_elt");
  SDValue InsertVal = N->getOperand(1);
  SDValue Vec = N->getOperand(0);

  // (insert_vector_elt (vector_shuffle X, Y), (extract_vector_elt X, N),
  // InsIndex)
  //   --> (vector_shuffle X, Y) and variations where shuffle operands may be
  //   CONCAT_VECTORS.
  if (Vec.getOpcode() == ISD::VECTOR_SHUFFLE && Vec.hasOneUse() &&
      InsertVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
      isa<ConstantSDNode>(InsertVal.getOperand(1))) {
    ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Vec.getNode());
    ArrayRef<int> Mask = SVN->getMask();

    SDValue X = Vec.getOperand(0);
    SDValue Y = Vec.getOperand(1);

    // Vec's operand 0 is using indices from 0 to N-1 and
    // operand 1 from N to 2N - 1, where N is the number of
    // elements in the vectors.
    SDValue InsertVal0 = InsertVal.getOperand(0);
    int ElementOffset = -1;

    // We explore the inputs of the shuffle in order to see if we find the
    // source of the extract_vector_elt. If so, we can use it to modify the
    // shuffle rather than perform an insert_vector_elt.
    SmallVector<std::pair<int, SDValue>, 8> ArgWorkList;
    ArgWorkList.emplace_back(Mask.size(), Y);
    ArgWorkList.emplace_back(0, X);

    while (!ArgWorkList.empty()) {
      int ArgOffset;
      SDValue ArgVal;
      std::tie(ArgOffset, ArgVal) = ArgWorkList.pop_back_val();

      if (ArgVal == InsertVal0) {
        ElementOffset = ArgOffset;
        break;
      }

      // Peek through concat_vectors.
      if (ArgVal.getOpcode() == ISD::CONCAT_VECTORS) {
        int CurrentArgOffset =
            ArgOffset + ArgVal.getValueType().getVectorNumElements();
        int Step = ArgVal.getOperand(0).getValueType().getVectorNumElements();
        for (SDValue Op : reverse(ArgVal->ops())) {
          CurrentArgOffset -= Step;
          ArgWorkList.emplace_back(CurrentArgOffset, Op);
        }

        // Make sure we went through all the elements and did not screw up
        // index computation.
        assert(CurrentArgOffset == ArgOffset);
      }
    }

    if (ElementOffset != -1) {
      SmallVector<int, 16> NewMask(Mask.begin(), Mask.end());

      auto *ExtrIndex = cast<ConstantSDNode>(InsertVal.getOperand(1));
      NewMask[InsIndex] = ElementOffset + ExtrIndex->getZExtValue();
      assert(NewMask[InsIndex] <
                 (int)(2 * Vec.getValueType().getVectorNumElements()) &&
             NewMask[InsIndex] >= 0 && "NewMask[InsIndex] is out of bounds");

      SDValue LegalShuffle =
              TLI.buildLegalVectorShuffle(Vec.getValueType(), SDLoc(N), X,
                                          Y, NewMask, DAG);
      if (LegalShuffle)
        return LegalShuffle;
    }
  }

  // insert_vector_elt V, (bitcast X from vector type), IdxC -->
  // bitcast(shuffle (bitcast V), (extended X), Mask)
  // Note: We do not use an insert_subvector node because that requires a
  // legal subvector type.
  if (InsertVal.getOpcode() != ISD::BITCAST || !InsertVal.hasOneUse() ||
      !InsertVal.getOperand(0).getValueType().isVector())
    return SDValue();

  SDValue SubVec = InsertVal.getOperand(0);
  SDValue DestVec = N->getOperand(0);
  EVT SubVecVT = SubVec.getValueType();
  EVT VT = DestVec.getValueType();
  unsigned NumSrcElts = SubVecVT.getVectorNumElements();
  unsigned ExtendRatio = VT.getSizeInBits() / SubVecVT.getSizeInBits();
  unsigned NumMaskVals = ExtendRatio * NumSrcElts;

  // Step 1: Create a shuffle mask that implements this insert operation. The
  // vector that we are inserting into will be operand 0 of the shuffle, so
  // those elements are just 'i'. The inserted subvector is in the first
  // positions of operand 1 of the shuffle. Example:
  // insert v4i32 V, (v2i16 X), 2 --> shuffle v8i16 V', X', {0,1,2,3,8,9,6,7}
  SmallVector<int, 16> Mask(NumMaskVals);
  for (unsigned i = 0; i != NumMaskVals; ++i) {
    if (i / NumSrcElts == InsIndex)
      Mask[i] = (i % NumSrcElts) + NumMaskVals;
    else
      Mask[i] = i;
  }

  // Bail out if the target can not handle the shuffle we want to create.
  EVT SubVecEltVT = SubVecVT.getVectorElementType();
  EVT ShufVT = EVT::getVectorVT(*DAG.getContext(), SubVecEltVT, NumMaskVals);
  if (!TLI.isShuffleMaskLegal(Mask, ShufVT))
    return SDValue();

  // Step 2: Create a wide vector from the inserted source vector by appending
  // undefined elements. This is the same size as our destination vector.
  SDLoc DL(N);
  SmallVector<SDValue, 8> ConcatOps(ExtendRatio, DAG.getUNDEF(SubVecVT));
  ConcatOps[0] = SubVec;
  SDValue PaddedSubV = DAG.getNode(ISD::CONCAT_VECTORS, DL, ShufVT, ConcatOps);

  // Step 3: Shuffle in the padded subvector.
  SDValue DestVecBC = DAG.getBitcast(ShufVT, DestVec);
  SDValue Shuf = DAG.getVectorShuffle(ShufVT, DL, DestVecBC, PaddedSubV, Mask);
  AddToWorklist(PaddedSubV.getNode());
  AddToWorklist(DestVecBC.getNode());
  AddToWorklist(Shuf.getNode());
  return DAG.getBitcast(VT, Shuf);
}

SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
  SDValue InVec = N->getOperand(0);
  SDValue InVal = N->getOperand(1);
  SDValue EltNo = N->getOperand(2);
  SDLoc DL(N);

  EVT VT = InVec.getValueType();
  unsigned NumElts = VT.getVectorNumElements();

  // Inserting into an out-of-bounds element is undefined.
  if (auto *IndexC = dyn_cast<ConstantSDNode>(EltNo))
    if (IndexC->getZExtValue() >= VT.getVectorNumElements())
      return DAG.getUNDEF(VT);

  // Remove redundant insertions:
  // (insert_vector_elt x (extract_vector_elt x idx) idx) -> x
  if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
      InVec == InVal.getOperand(0) && EltNo == InVal.getOperand(1))
    return InVec;

  auto *IndexC = dyn_cast<ConstantSDNode>(EltNo);
  if (!IndexC) {
    // If this is a variable insert into an undef vector, it might be better to
    // splat:
    // inselt undef, InVal, EltNo --> build_vector < InVal, InVal, ... >
    if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT)) {
      SmallVector<SDValue, 8> Ops(NumElts, InVal);
      return DAG.getBuildVector(VT, DL, Ops);
    }
    return SDValue();
  }

  // We must know which element is being inserted for folds below here.
  unsigned Elt = IndexC->getZExtValue();
  if (SDValue Shuf = combineInsertEltToShuffle(N, Elt))
    return Shuf;

  // Canonicalize insert_vector_elt dag nodes.
  // Example:
  // (insert_vector_elt (insert_vector_elt A, Idx0), Idx1)
  // -> (insert_vector_elt (insert_vector_elt A, Idx1), Idx0)
  //
  // Do this only if the child insert_vector_elt node has one use; also
  // do this only if indices are both constants and Idx1 < Idx0.
  if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT && InVec.hasOneUse()
      && isa<ConstantSDNode>(InVec.getOperand(2))) {
    unsigned OtherElt = InVec.getConstantOperandVal(2);
    if (Elt < OtherElt) {
      // Swap nodes.
      SDValue NewOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT,
                                  InVec.getOperand(0), InVal, EltNo);
      AddToWorklist(NewOp.getNode());
      return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(InVec.getNode()),
                         VT, NewOp, InVec.getOperand(1), InVec.getOperand(2));
    }
  }

  // If we can't generate a legal BUILD_VECTOR, exit.
  if (LegalOperations && !TLI.isOperationLegal(ISD::BUILD_VECTOR, VT))
    return SDValue();

  // Check that the operand is a BUILD_VECTOR (or UNDEF, which can essentially
  // be converted to a BUILD_VECTOR).  Fill in the Ops vector with the
  // vector elements.
  SmallVector<SDValue, 8> Ops;
  // Do not combine these two vectors if the output vector will not replace
  // the input vector.
  if (InVec.getOpcode() == ISD::BUILD_VECTOR && InVec.hasOneUse()) {
    Ops.append(InVec.getNode()->op_begin(),
               InVec.getNode()->op_end());
  } else if (InVec.isUndef()) {
    Ops.append(NumElts, DAG.getUNDEF(InVal.getValueType()));
  } else {
    return SDValue();
  }
  assert(Ops.size() == NumElts && "Unexpected vector size");

  // Insert the element.
  if (Elt < Ops.size()) {
    // All the operands of BUILD_VECTOR must have the same type;
    // we enforce that here.
    EVT OpVT = Ops[0].getValueType();
    Ops[Elt] = OpVT.isInteger() ? DAG.getAnyExtOrTrunc(InVal, DL, OpVT) : InVal;
  }

  // Return the new vector.
  return DAG.getBuildVector(VT, DL, Ops);
}

SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
                                                  SDValue EltNo,
                                                  LoadSDNode *OriginalLoad) {
  assert(OriginalLoad->isSimple());

  EVT ResultVT = EVE->getValueType(0);
  EVT VecEltVT = InVecVT.getVectorElementType();
  unsigned Align = OriginalLoad->getAlignment();
  unsigned NewAlign = DAG.getDataLayout().getABITypeAlignment(
      VecEltVT.getTypeForEVT(*DAG.getContext()));

  if (NewAlign > Align || !TLI.isOperationLegalOrCustom(ISD::LOAD, VecEltVT))
    return SDValue();

  ISD::LoadExtType ExtTy = ResultVT.bitsGT(VecEltVT) ?
    ISD::NON_EXTLOAD : ISD::EXTLOAD;
  if (!TLI.shouldReduceLoadWidth(OriginalLoad, ExtTy, VecEltVT))
    return SDValue();

  Align = NewAlign;

  SDValue NewPtr = OriginalLoad->getBasePtr();
  SDValue Offset;
  EVT PtrType = NewPtr.getValueType();
  MachinePointerInfo MPI;
  SDLoc DL(EVE);
  if (auto *ConstEltNo = dyn_cast<ConstantSDNode>(EltNo)) {
    int Elt = ConstEltNo->getZExtValue();
    unsigned PtrOff = VecEltVT.getSizeInBits() * Elt / 8;
    Offset = DAG.getConstant(PtrOff, DL, PtrType);
    MPI = OriginalLoad->getPointerInfo().getWithOffset(PtrOff);
  } else {
    Offset = DAG.getZExtOrTrunc(EltNo, DL, PtrType);
    Offset = DAG.getNode(
        ISD::MUL, DL, PtrType, Offset,
        DAG.getConstant(VecEltVT.getStoreSize(), DL, PtrType));
    // Discard the pointer info except the address space because the memory
    // operand can't represent this new access since the offset is variable.
    MPI = MachinePointerInfo(OriginalLoad->getPointerInfo().getAddrSpace());
  }
  NewPtr = DAG.getMemBasePlusOffset(NewPtr, Offset, DL);

  // The replacement we need to do here is a little tricky: we need to
  // replace an extractelement of a load with a load.
  // Use ReplaceAllUsesOfValuesWith to do the replacement.
  // Note that this replacement assumes that the extractelement is the only
  // use of the load; that's okay because we don't want to perform this
  // transformation in other cases anyway.
  SDValue Load;
  SDValue Chain;
  if (ResultVT.bitsGT(VecEltVT)) {
    // If the result type of vextract is wider than the load, then issue an
    // extending load instead.
    ISD::LoadExtType ExtType = TLI.isLoadExtLegal(ISD::ZEXTLOAD, ResultVT,
                                                  VecEltVT)
                                   ? ISD::ZEXTLOAD
                                   : ISD::EXTLOAD;
    Load = DAG.getExtLoad(ExtType, SDLoc(EVE), ResultVT,
                          OriginalLoad->getChain(), NewPtr, MPI, VecEltVT,
                          Align, OriginalLoad->getMemOperand()->getFlags(),
                          OriginalLoad->getAAInfo());
    Chain = Load.getValue(1);
  } else {
    Load = DAG.getLoad(VecEltVT, SDLoc(EVE), OriginalLoad->getChain(), NewPtr,
                       MPI, Align, OriginalLoad->getMemOperand()->getFlags(),
                       OriginalLoad->getAAInfo());
    Chain = Load.getValue(1);
    if (ResultVT.bitsLT(VecEltVT))
      Load = DAG.getNode(ISD::TRUNCATE, SDLoc(EVE), ResultVT, Load);
    else
      Load = DAG.getBitcast(ResultVT, Load);
  }
  WorklistRemover DeadNodes(*this);
  SDValue From[] = { SDValue(EVE, 0), SDValue(OriginalLoad, 1) };
  SDValue To[] = { Load, Chain };
  DAG.ReplaceAllUsesOfValuesWith(From, To, 2);
  // Make sure to revisit this node to clean it up; it will usually be dead.
  AddToWorklist(EVE);
  // Since we're explicitly calling ReplaceAllUses, add the new node to the
  // worklist explicitly as well.
  AddToWorklistWithUsers(Load.getNode());
  ++OpsNarrowed;
  return SDValue(EVE, 0);
}

/// Transform a vector binary operation into a scalar binary operation by
/// moving the math/logic after an extract element of a vector.
static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
                                       bool LegalOperations) {
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  SDValue Vec = ExtElt->getOperand(0);
  SDValue Index = ExtElt->getOperand(1);
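  // Only handle a single-use, single-result binop that is extracted at a
  // constant index.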
  auto *IndexC = dyn_cast<ConstantSDNode>(Index);
  if (!IndexC || !TLI.isBinOp(Vec.getOpcode()) || !Vec.hasOneUse() ||
      Vec.getNode()->getNumValues() != 1)
    return SDValue();

  // Targets may want to avoid this to prevent an expensive register transfer.
  if (!TLI.shouldScalarizeBinop(Vec))
    return SDValue();

  // Extracting an element of a vector constant is constant-folded, so this
  // transform is just replacing a vector op with a scalar op while moving the
  // extract.
  SDValue Op0 = Vec.getOperand(0);
  SDValue Op1 = Vec.getOperand(1);
  if (isAnyConstantBuildVector(Op0, true) ||
      isAnyConstantBuildVector(Op1, true)) {
    // extractelt (binop X, C), IndexC --> binop (extractelt X, IndexC), C'
    // extractelt (binop C, X), IndexC --> binop C', (extractelt X, IndexC)
    SDLoc DL(ExtElt);
    EVT VT = ExtElt->getValueType(0);
    SDValue Ext0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Index);
    SDValue Ext1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op1, Index);
    return DAG.getNode(Vec.getOpcode(), DL, VT, Ext0, Ext1);
  }

  return SDValue();
}

SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
  SDValue VecOp = N->getOperand(0);
  SDValue Index = N->getOperand(1);
  EVT ScalarVT = N->getValueType(0);
  EVT VecVT = VecOp.getValueType();
  if (VecOp.isUndef())
    return DAG.getUNDEF(ScalarVT);

  // (extract_vector_elt (insert_vector_elt vec, val, idx), idx) -> val
  //
  // This only really matters if the index is non-constant since other combines
  // on the constant elements already work.
  SDLoc DL(N);
  if (VecOp.getOpcode() == ISD::INSERT_VECTOR_ELT &&
      Index == VecOp.getOperand(2)) {
    SDValue Elt = VecOp.getOperand(1);
    return VecVT.isInteger() ? DAG.getAnyExtOrTrunc(Elt, DL, ScalarVT) : Elt;
  }

  // (vextract (scalar_to_vector val), 0) -> val
  if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR) {
    // Check if the result type doesn't match the inserted element type. A
    // SCALAR_TO_VECTOR may truncate the inserted element and the
    // EXTRACT_VECTOR_ELT may widen the extracted vector.
    SDValue InOp = VecOp.getOperand(0);
    if (InOp.getValueType() != ScalarVT) {
      assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
      return DAG.getSExtOrTrunc(InOp, DL, ScalarVT);
    }
    return InOp;
  }

  // extract_vector_elt of out-of-bounds element -> UNDEF
  auto *IndexC = dyn_cast<ConstantSDNode>(Index);
  unsigned NumElts = VecVT.getVectorNumElements();
  if (IndexC && IndexC->getAPIntValue().uge(NumElts))
    return DAG.getUNDEF(ScalarVT);

  // extract_vector_elt (build_vector x, y), 1 -> y
  if (IndexC && VecOp.getOpcode() == ISD::BUILD_VECTOR &&
      TLI.isTypeLegal(VecVT) &&
      (VecOp.hasOneUse() || TLI.aggressivelyPreferBuildVectorSources(VecVT))) {
    SDValue Elt = VecOp.getOperand(IndexC->getZExtValue());
    EVT InEltVT = Elt.getValueType();

    // Sometimes build_vector's scalar input types do not match result type.
    if (ScalarVT == InEltVT)
      return Elt;

    // TODO: It may be useful to truncate if free if the build_vector implicitly
    // converts.
  }

  // TODO: These transforms should not require the 'hasOneUse' restriction, but
  // there are regressions on multiple targets without it. We can end up with a
  // mess of scalar and vector code if we reduce only part of the DAG to scalar.
  if (IndexC && VecOp.getOpcode() == ISD::BITCAST && VecVT.isInteger() &&
      VecOp.hasOneUse()) {
    // The vector index of the LSBs of the source depends on the endianness.
    bool IsLE = DAG.getDataLayout().isLittleEndian();
    unsigned ExtractIndex = IndexC->getZExtValue();
    // extract_elt (v2i32 (bitcast i64:x)), BCTruncElt -> i32 (trunc i64:x)
    unsigned BCTruncElt = IsLE ? 0 : NumElts - 1;
    SDValue BCSrc = VecOp.getOperand(0);
    if (ExtractIndex == BCTruncElt && BCSrc.getValueType().isScalarInteger())
      return DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, BCSrc);

    if (LegalTypes && BCSrc.getValueType().isInteger() &&
        BCSrc.getOpcode() == ISD::SCALAR_TO_VECTOR) {
      // ext_elt (bitcast (scalar_to_vec i64 X to v2i64) to v4i32), TruncElt -->
      // trunc i64 X to i32
      SDValue X = BCSrc.getOperand(0);
      assert(X.getValueType().isScalarInteger() && ScalarVT.isScalarInteger() &&
             "Extract element and scalar to vector can't change element type "
             "from FP to integer.");
      unsigned XBitWidth = X.getValueSizeInBits();
      unsigned VecEltBitWidth = VecVT.getScalarSizeInBits();
      BCTruncElt = IsLE ? 0 : XBitWidth / VecEltBitWidth - 1;

      // An extract element return value type can be wider than its vector
      // operand element type. In that case, the high bits are undefined, so
      // it's possible that we may need to extend rather than truncate.
      if (ExtractIndex == BCTruncElt && XBitWidth > VecEltBitWidth) {
        assert(XBitWidth % VecEltBitWidth == 0 &&
               "Scalar bitwidth must be a multiple of vector element bitwidth");
        return DAG.getAnyExtOrTrunc(X, DL, ScalarVT);
      }
    }
  }

  if (SDValue BO = scalarizeExtractedBinop(N, DAG, LegalOperations))
    return BO;

  // Transform: (EXTRACT_VECTOR_ELT( VECTOR_SHUFFLE )) -> EXTRACT_VECTOR_ELT.
  // We only perform this optimization before the op legalization phase because
  // we may introduce new vector instructions which are not backed by TD
  // patterns; for example, on AVX, extracting elements from a wide vector
  // without using extract_subvector. However, if we can find an underlying
  // scalar value, then we can always use that.
  if (IndexC && VecOp.getOpcode() == ISD::VECTOR_SHUFFLE) {
    auto *Shuf = cast<ShuffleVectorSDNode>(VecOp);
    // Find the new index to extract from.
    int OrigElt = Shuf->getMaskElt(IndexC->getZExtValue());

    // Extracting an undef index is undef.
    if (OrigElt == -1)
      return DAG.getUNDEF(ScalarVT);

    // Select the right vector half to extract from.
    SDValue SVInVec;
    if (OrigElt < (int)NumElts) {
      SVInVec = VecOp.getOperand(0);
    } else {
      SVInVec = VecOp.getOperand(1);
      OrigElt -= NumElts;
    }

    if (SVInVec.getOpcode() == ISD::BUILD_VECTOR) {
      SDValue InOp = SVInVec.getOperand(OrigElt);
      if (InOp.getValueType() != ScalarVT) {
        assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
        InOp = DAG.getSExtOrTrunc(InOp, DL, ScalarVT);
      }

      return InOp;
    }

    // FIXME: We should handle recursing on other vector shuffles and
    // scalar_to_vector here as well.

    if (!LegalOperations ||
        // FIXME: Should really be just isOperationLegalOrCustom.
        TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VecVT) ||
        TLI.isOperationExpand(ISD::VECTOR_SHUFFLE, VecVT)) {
      EVT IndexTy = TLI.getVectorIdxTy(DAG.getDataLayout());
      return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, SVInVec,
                         DAG.getConstant(OrigElt, DL, IndexTy));
    }
  }

  // If only EXTRACT_VECTOR_ELT nodes use the source vector we can
  // simplify it based on the (valid) extraction indices.
  if (llvm::all_of(VecOp->uses(), [&](SDNode *Use) {
        return Use->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
               Use->getOperand(0) == VecOp &&
               isa<ConstantSDNode>(Use->getOperand(1));
      })) {
    APInt DemandedElts = APInt::getNullValue(NumElts);
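    // Accumulate the union of all constant extraction indices across the
    // vector's users.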
    for (SDNode *Use : VecOp->uses()) {
      auto *CstElt = cast<ConstantSDNode>(Use->getOperand(1));
      if (CstElt->getAPIntValue().ult(NumElts))
        DemandedElts.setBit(CstElt->getZExtValue());
    }
    if (SimplifyDemandedVectorElts(VecOp, DemandedElts, true)) {
      // We simplified the vector operand of this extract element. If this
      // extract is not dead, visit it again so it is folded properly.
      if (N->getOpcode() != ISD::DELETED_NODE)
        AddToWorklist(N);
      return SDValue(N, 0);
    }
  }

  // Everything under here is trying to match an extract of a loaded value.
  // If the result of load has to be truncated, then it's not necessarily
  // profitable.
  bool BCNumEltsChanged = false;
  EVT ExtVT = VecVT.getVectorElementType();
  EVT LVT = ExtVT;
  if (ScalarVT.bitsLT(LVT) && !TLI.isTruncateFree(LVT, ScalarVT))
    return SDValue();

  if (VecOp.getOpcode() == ISD::BITCAST) {
    // Don't duplicate a load with other uses.
    if (!VecOp.hasOneUse())
      return SDValue();

    EVT BCVT = VecOp.getOperand(0).getValueType();
    if (!BCVT.isVector() || ExtVT.bitsGT(BCVT.getVectorElementType()))
      return SDValue();
    if (NumElts != BCVT.getVectorNumElements())
      BCNumEltsChanged = true;
    VecOp = VecOp.getOperand(0);
    ExtVT = BCVT.getVectorElementType();
  }

  // extract (vector load $addr), i --> load $addr + i * size
  if (!LegalOperations && !IndexC && VecOp.hasOneUse() &&
      ISD::isNormalLoad(VecOp.getNode()) &&
      !Index->hasPredecessor(VecOp.getNode())) {
    auto *VecLoad = dyn_cast<LoadSDNode>(VecOp);
    if (VecLoad && VecLoad->isSimple())
      return scalarizeExtractedVectorLoad(N, VecVT, Index, VecLoad);
  }

  // Perform only after legalization to ensure build_vector / vector_shuffle
  // optimizations have already been done.
  if (!LegalOperations || !IndexC)
    return SDValue();

  // (vextract (v4f32 load $addr), c) -> (f32 load $addr+c*size)
  // (vextract (v4f32 s2v (f32 load $addr)), c) -> (f32 load $addr+c*size)
  // (vextract (v4f32 shuffle (load $addr), <1,u,u,u>), 0) -> (f32 load $addr)
  int Elt = IndexC->getZExtValue();
  LoadSDNode *LN0 = nullptr;
  if (ISD::isNormalLoad(VecOp.getNode())) {
    LN0 = cast<LoadSDNode>(VecOp);
  } else if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR &&
             VecOp.getOperand(0).getValueType() == ExtVT &&
             ISD::isNormalLoad(VecOp.getOperand(0).getNode())) {
    // Don't duplicate a load with other uses.
    if (!VecOp.hasOneUse())
      return SDValue();

    LN0 = cast<LoadSDNode>(VecOp.getOperand(0));
  }
  if (auto *Shuf = dyn_cast<ShuffleVectorSDNode>(VecOp)) {
    // (vextract (vector_shuffle (load $addr), v2, <1, u, u, u>), 1)
    // =>
    // (load $addr+1*size)

    // Don't duplicate a load with other uses.
    if (!VecOp.hasOneUse())
      return SDValue();

    // If the bit convert changed the number of elements, it is unsafe
    // to examine the mask.
    if (BCNumEltsChanged)
      return SDValue();

    // Select the input vector, guarding against out of range extract vector.
    int Idx = (Elt > (int)NumElts) ? -1 : Shuf->getMaskElt(Elt);
    VecOp = (Idx < (int)NumElts) ? VecOp.getOperand(0) : VecOp.getOperand(1);

    if (VecOp.getOpcode() == ISD::BITCAST) {
      // Don't duplicate a load with other uses.
      if (!VecOp.hasOneUse())
        return SDValue();

      VecOp = VecOp.getOperand(0);
    }
    if (ISD::isNormalLoad(VecOp.getNode())) {
      LN0 = cast<LoadSDNode>(VecOp);
      Elt = (Idx < (int)NumElts) ? Idx : Idx - (int)NumElts;
      Index = DAG.getConstant(Elt, DL, Index.getValueType());
    }
  }

  // Make sure we found a non-volatile load and the extractelement is
  // the only use.
  if (!LN0 || !LN0->hasNUsesOfValue(1, 0) || !LN0->isSimple())
    return SDValue();

  // If Idx was -1 above, Elt is going to be -1, so just return undef.
  if (Elt == -1)
    return DAG.getUNDEF(LVT);

  return scalarizeExtractedVectorLoad(N, VecVT, Index, LN0);
}

// Simplify (build_vec (ext )) to (bitcast (build_vec ))
SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) {
  // We perform this optimization post type-legalization because
  // the type-legalizer often scalarizes integer-promoted vectors.
  // Performing this optimization before may create bit-casts which
  // will be type-legalized to complex code sequences.
  // We perform this optimization only before the operation legalizer because we
  // may introduce illegal operations.
  if (Level != AfterLegalizeVectorOps && Level != AfterLegalizeTypes)
    return SDValue();

  unsigned NumInScalars = N->getNumOperands();
  SDLoc DL(N);
  EVT VT = N->getValueType(0);

  // Check to see if this is a BUILD_VECTOR of a bunch of values
  // which come from any_extend or zero_extend nodes. If so, we can create
  // a new BUILD_VECTOR using bit-casts which may enable other BUILD_VECTOR
  // optimizations. We do not handle sign-extend because we can't fill the sign
  // using shuffles.
  EVT SourceType = MVT::Other;
  bool AllAnyExt = true;

  for (unsigned i = 0; i != NumInScalars; ++i) {
    SDValue In = N->getOperand(i);
    // Ignore undef inputs.
    if (In.isUndef()) continue;

    bool AnyExt  = In.getOpcode() == ISD::ANY_EXTEND;
    bool ZeroExt = In.getOpcode() == ISD::ZERO_EXTEND;

    // Abort if the element is not an extension.
    if (!ZeroExt && !AnyExt) {
      SourceType = MVT::Other;
      break;
    }

    // The input is a ZeroExt or AnyExt. Check the original type.
    EVT InTy = In.getOperand(0).getValueType();

    // Check that all of the widened source types are the same.
    if (SourceType == MVT::Other)
      // First time.
      SourceType = InTy;
    else if (InTy != SourceType) {
      // Multiple incoming types. Abort.
      SourceType = MVT::Other;
      break;
    }

    // Check if all of the extends are ANY_EXTENDs.
    AllAnyExt &= AnyExt;
  }

  // In order to have valid types, all of the inputs must be extended from the
  // same source type and all of the inputs must be any or zero extend.
  // Scalar sizes must be a power of two.
  EVT OutScalarTy = VT.getScalarType();
  bool ValidTypes = SourceType != MVT::Other &&
                 isPowerOf2_32(OutScalarTy.getSizeInBits()) &&
                 isPowerOf2_32(SourceType.getSizeInBits());

  // Create a new simpler BUILD_VECTOR sequence which other optimizations can
  // turn into a single shuffle instruction.
  if (!ValidTypes)
    return SDValue();

  bool isLE = DAG.getDataLayout().isLittleEndian();
  unsigned ElemRatio = OutScalarTy.getSizeInBits()/SourceType.getSizeInBits();
  assert(ElemRatio > 1 && "Invalid element size ratio");
  SDValue Filler = AllAnyExt ? DAG.getUNDEF(SourceType):
                               DAG.getConstant(0, DL, SourceType);

  unsigned NewBVElems = ElemRatio * VT.getVectorNumElements();
  SmallVector<SDValue, 8> Ops(NewBVElems, Filler);

  // Populate the new build_vector
  for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
    SDValue Cast = N->getOperand(i);
    assert((Cast.getOpcode() == ISD::ANY_EXTEND ||
            Cast.getOpcode() == ISD::ZERO_EXTEND ||
            Cast.isUndef()) && "Invalid cast opcode");
    SDValue In;
    if (Cast.isUndef())
      In = DAG.getUNDEF(SourceType);
    else
      In = Cast->getOperand(0);
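    // Place each source value in the sub-element that holds the wide
    // element's low bits: the first sub-element on little-endian targets,
    // the last one on big-endian targets.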
    unsigned Index = isLE ? (i * ElemRatio) :
                            (i * ElemRatio + (ElemRatio - 1));

    assert(Index < Ops.size() && "Invalid index");
    Ops[Index] = In;
  }

  // The type of the new BUILD_VECTOR node.
  EVT VecVT = EVT::getVectorVT(*DAG.getContext(), SourceType, NewBVElems);
  assert(VecVT.getSizeInBits() == VT.getSizeInBits() &&
         "Invalid vector size");
  // Check if the new vector type is legal.
  if (!isTypeLegal(VecVT) ||
      (!TLI.isOperationLegal(ISD::BUILD_VECTOR, VecVT) &&
       TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)))
    return SDValue();

  // Make the new BUILD_VECTOR.
  SDValue BV = DAG.getBuildVector(VecVT, DL, Ops);

  // The new BUILD_VECTOR node has the potential to be further optimized.
  AddToWorklist(BV.getNode());
  // Bitcast to the desired type.
  return DAG.getBitcast(VT, BV);
}

SDValue DAGCombiner::createBuildVecShuffle(const SDLoc &DL, SDNode *N,
                                           ArrayRef<int> VectorMask,
                                           SDValue VecIn1, SDValue VecIn2,
                                           unsigned LeftIdx, bool DidSplitVec) {
  MVT IdxTy = TLI.getVectorIdxTy(DAG.getDataLayout());
  SDValue ZeroIdx = DAG.getConstant(0, DL, IdxTy);

  EVT VT = N->getValueType(0);
  EVT InVT1 = VecIn1.getValueType();
  EVT InVT2 = VecIn2.getNode() ? VecIn2.getValueType() : InVT1;
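  // (If there is no second input vector, treat it as having the same type as
  // the first.)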

  unsigned NumElems = VT.getVectorNumElements();
  unsigned ShuffleNumElems = NumElems;

  // If we artificially split a vector in two already, then the offsets in the
  // operands will all be based off of VecIn1, even those in VecIn2.
  unsigned Vec2Offset = DidSplitVec ? 0 : InVT1.getVectorNumElements();

  // We can't generate a shuffle node with mismatched input and output types.
  // Try to make the types match the type of the output.
  if (InVT1 != VT || InVT2 != VT) {
    if ((VT.getSizeInBits() % InVT1.getSizeInBits() == 0) && InVT1 == InVT2) {
      // If the output vector length is a multiple of both input lengths,
      // we can concatenate them and pad the rest with undefs.
      unsigned NumConcats = VT.getSizeInBits() / InVT1.getSizeInBits();
      assert(NumConcats >= 2 && "Concat needs at least two inputs!");
      SmallVector<SDValue, 2> ConcatOps(NumConcats, DAG.getUNDEF(InVT1));
      ConcatOps[0] = VecIn1;
      ConcatOps[1] = VecIn2 ? VecIn2 : DAG.getUNDEF(InVT1);
      VecIn1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
      VecIn2 = SDValue();
    } else if (InVT1.getSizeInBits() == VT.getSizeInBits() * 2) {
      if (!TLI.isExtractSubvectorCheap(VT, InVT1, NumElems))
        return SDValue();

      if (!VecIn2.getNode()) {
        // If we only have one input vector, and it's twice the size of the
        // output, split it in two.
        VecIn2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, VecIn1,
                             DAG.getConstant(NumElems, DL, IdxTy));
        VecIn1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, VecIn1, ZeroIdx);
        // Since we now have shorter input vectors, adjust the offset of the
        // second vector's start.
        Vec2Offset = NumElems;
      } else if (InVT2.getSizeInBits() <= InVT1.getSizeInBits()) {
        // VecIn1 is wider than the output, and we have another, possibly
        // smaller input. Pad the smaller input with undefs, shuffle at the
        // input vector width, and extract the output.
        // The shuffle type is different than VT, so check legality again.
        if (LegalOperations &&
            !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, InVT1))
          return SDValue();

        // Legalizing INSERT_SUBVECTOR is tricky - you basically have to
        // lower it back into a BUILD_VECTOR. So if the inserted type is
        // illegal, don't even try.
        if (InVT1 != InVT2) {
          if (!TLI.isTypeLegal(InVT2))
            return SDValue();
          VecIn2 = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, InVT1,
                               DAG.getUNDEF(InVT1), VecIn2, ZeroIdx);
        }
        ShuffleNumElems = NumElems * 2;
      } else {
        // Both VecIn1 and VecIn2 are wider than the output, and VecIn2 is wider
        // than VecIn1. We can't handle this for now - this case will disappear
        // when we start sorting the vectors by type.
        return SDValue();
      }
    } else if (InVT2.getSizeInBits() * 2 == VT.getSizeInBits() &&
               InVT1.getSizeInBits() == VT.getSizeInBits()) {
      SmallVector<SDValue, 2> ConcatOps(2, DAG.getUNDEF(InVT2));
      ConcatOps[0] = VecIn2;
      VecIn2 = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
    } else {
      // TODO: Support cases where the length mismatch isn't exactly by a
      // factor of 2.
      // TODO: Move this check upwards, so that if we have bad type
      // mismatches, we don't create any DAG nodes.
      return SDValue();
    }
  }

  // Initialize mask to undef.
  SmallVector<int, 8> Mask(ShuffleNumElems, -1);

  // Only need to run up to the number of elements actually used, not the
  // total number of elements in the shuffle - if we are shuffling a wider
  // vector, the high lanes should be set to undef.
  for (unsigned i = 0; i != NumElems; ++i) {
    if (VectorMask[i] <= 0)
      continue;

    unsigned ExtIndex = N->getOperand(i).getConstantOperandVal(1);
    if (VectorMask[i] == (int)LeftIdx) {
      Mask[i] = ExtIndex;
    } else if (VectorMask[i] == (int)LeftIdx + 1) {
      Mask[i] = Vec2Offset + ExtIndex;
    }
  }

  // The types of the input vectors may have changed above.
  InVT1 = VecIn1.getValueType();

  // If we already have a VecIn2, it should have the same type as VecIn1.
  // If we don't, get an undef/zero vector of the appropriate type.
  VecIn2 = VecIn2.getNode() ? VecIn2 : DAG.getUNDEF(InVT1);
  assert(InVT1 == VecIn2.getValueType() && "Unexpected second input type.");

  SDValue Shuffle = DAG.getVectorShuffle(InVT1, DL, VecIn1, VecIn2, Mask);
  if (ShuffleNumElems > NumElems)
    Shuffle = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Shuffle, ZeroIdx);

  return Shuffle;
}

reduceBuildVecToShuffleWithZero(SDNode * BV,SelectionDAG & DAG)17545 static SDValue reduceBuildVecToShuffleWithZero(SDNode *BV, SelectionDAG &DAG) {
17546   assert(BV->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");
17547 
17548   // First, determine where the build vector is not undef.
17549   // TODO: We could extend this to handle zero elements as well as undefs.
17550   int NumBVOps = BV->getNumOperands();
17551   int ZextElt = -1;
17552   for (int i = 0; i != NumBVOps; ++i) {
17553     SDValue Op = BV->getOperand(i);
17554     if (Op.isUndef())
17555       continue;
17556     if (ZextElt == -1)
17557       ZextElt = i;
17558     else
17559       return SDValue();
17560   }
17561   // Bail out if there's no non-undef element.
17562   if (ZextElt == -1)
17563     return SDValue();
17564 
17565   // The build vector contains some number of undef elements and exactly
17566   // one other element. That other element must be a zero-extended scalar
17567   // extracted from a vector at a constant index to turn this into a shuffle.
17568   // Also, require that the build vector does not implicitly truncate/extend
17569   // its elements.
17570   // TODO: This could be enhanced to allow ANY_EXTEND as well as ZERO_EXTEND.
17571   EVT VT = BV->getValueType(0);
17572   SDValue Zext = BV->getOperand(ZextElt);
17573   if (Zext.getOpcode() != ISD::ZERO_EXTEND || !Zext.hasOneUse() ||
17574       Zext.getOperand(0).getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
17575       !isa<ConstantSDNode>(Zext.getOperand(0).getOperand(1)) ||
17576       Zext.getValueSizeInBits() != VT.getScalarSizeInBits())
17577     return SDValue();
17578 
17579   // The zero-extend's result size must be a multiple of the source size, and
17580   // we must be building a vector the same size as the extract's source vector.
17581   SDValue Extract = Zext.getOperand(0);
17582   unsigned DestSize = Zext.getValueSizeInBits();
17583   unsigned SrcSize = Extract.getValueSizeInBits();
17584   if (DestSize % SrcSize != 0 ||
17585       Extract.getOperand(0).getValueSizeInBits() != VT.getSizeInBits())
17586     return SDValue();
17587 
17588   // Create a shuffle mask that will combine the extracted element with zeros
17589   // and undefs.
17590   int ZextRatio = DestSize / SrcSize;
17591   int NumMaskElts = NumBVOps * ZextRatio;
17592   SmallVector<int, 32> ShufMask(NumMaskElts, -1);
17593   for (int i = 0; i != NumMaskElts; ++i) {
17594     if (i / ZextRatio == ZextElt) {
17595       // The low bits of the (potentially translated) extracted element map to
17596       // the source vector. The high bits map to zero. We will use a zero vector
17597       // as the 2nd source operand of the shuffle, so use the 1st element of
17598       // that vector (mask value is number-of-elements) for the high bits.
17599       if (i % ZextRatio == 0)
17600         ShufMask[i] = Extract.getConstantOperandVal(1);
17601       else
17602         ShufMask[i] = NumMaskElts;
17603     }
17604 
17605     // Undef elements of the build vector remain undef because we initialize
17606     // the shuffle mask with -1.
17607   }
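  // Illustrative example (hypothetical types): for a v2i64 build vector whose
  // element 1 is (zext i32 (extractelt V, C) to i64), ZextRatio is 2 and the
  // mask becomes <u, u, C, 4>: the low half of the extended element selects
  // the extracted source element, and the high half selects element 0 of the
  // all-zeros second shuffle operand.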
17608 
17609   // buildvec undef, ..., (zext (extractelt V, IndexC)), undef... -->
17610   // bitcast (shuffle V, ZeroVec, VectorMask)
17611   SDLoc DL(BV);
17612   EVT VecVT = Extract.getOperand(0).getValueType();
17613   SDValue ZeroVec = DAG.getConstant(0, DL, VecVT);
17614   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
17615   SDValue Shuf = TLI.buildLegalVectorShuffle(VecVT, DL, Extract.getOperand(0),
17616                                              ZeroVec, ShufMask, DAG);
17617   if (!Shuf)
17618     return SDValue();
17619   return DAG.getBitcast(VT, Shuf);
17620 }
17621 
17622 // Check to see if this is a BUILD_VECTOR of a bunch of EXTRACT_VECTOR_ELT
17623 // operations. If the types of the vectors we're extracting from allow it,
17624 // turn this into a vector_shuffle node.
17625 SDValue DAGCombiner::reduceBuildVecToShuffle(SDNode *N) {
17626   SDLoc DL(N);
17627   EVT VT = N->getValueType(0);
17628 
17629   // Only type-legal BUILD_VECTOR nodes are converted to shuffle nodes.
17630   if (!isTypeLegal(VT))
17631     return SDValue();
17632 
17633   if (SDValue V = reduceBuildVecToShuffleWithZero(N, DAG))
17634     return V;
17635 
17636   // May only combine to shuffle after legalize if shuffle is legal.
17637   if (LegalOperations && !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, VT))
17638     return SDValue();
17639 
17640   bool UsesZeroVector = false;
17641   unsigned NumElems = N->getNumOperands();
17642 
17643   // Record, for each element of the newly built vector, which input vector
17644   // that element comes from. -1 stands for undef, 0 for the zero vector,
17645   // and positive values for the input vectors.
17646   // VectorMask maps each element to its vector number, and VecIn maps vector
17647   // numbers to their initial SDValues.
17648 
17649   SmallVector<int, 8> VectorMask(NumElems, -1);
17650   SmallVector<SDValue, 8> VecIn;
17651   VecIn.push_back(SDValue());
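  // For example (illustrative), for
  //   (build_vector (extractelt t1, 0), 0, undef, (extractelt t2, 1))
  // we end up with VectorMask = [1, 0, -1, 2] and
  // VecIn = [SDValue(), t1, t2].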
17652 
17653   for (unsigned i = 0; i != NumElems; ++i) {
17654     SDValue Op = N->getOperand(i);
17655 
17656     if (Op.isUndef())
17657       continue;
17658 
17659     // See if we can use a blend with a zero vector.
17660     // TODO: Should we generalize this to a blend with an arbitrary constant
17661     // vector?
17662     if (isNullConstant(Op) || isNullFPConstant(Op)) {
17663       UsesZeroVector = true;
17664       VectorMask[i] = 0;
17665       continue;
17666     }
17667 
17668     // Not an undef or zero. If the input is something other than an
17669     // EXTRACT_VECTOR_ELT with an in-range constant index, bail out.
17670     if (Op.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
17671         !isa<ConstantSDNode>(Op.getOperand(1)))
17672       return SDValue();
17673     SDValue ExtractedFromVec = Op.getOperand(0);
17674 
17675     const APInt &ExtractIdx = Op.getConstantOperandAPInt(1);
17676     if (ExtractIdx.uge(ExtractedFromVec.getValueType().getVectorNumElements()))
17677       return SDValue();
17678 
17679     // All inputs must have the same element type as the output.
17680     if (VT.getVectorElementType() !=
17681         ExtractedFromVec.getValueType().getVectorElementType())
17682       return SDValue();
17683 
17684     // Have we seen this input vector before?
17685     // The vectors are expected to be tiny (usually 1 or 2 elements), so using
17686     // a map back from SDValues to numbers isn't worth it.
17687     unsigned Idx = std::distance(
17688         VecIn.begin(), std::find(VecIn.begin(), VecIn.end(), ExtractedFromVec));
17689     if (Idx == VecIn.size())
17690       VecIn.push_back(ExtractedFromVec);
17691 
17692     VectorMask[i] = Idx;
17693   }
17694 
17695   // If we didn't find at least one input vector, bail out.
17696   if (VecIn.size() < 2)
17697     return SDValue();
17698 
17699   // If all the operands of the BUILD_VECTOR extract from the same
17700   // vector, then split the vector efficiently based on the maximum
17701   // vector access index and adjust the VectorMask and
17702   // VecIn accordingly.
17703   bool DidSplitVec = false;
17704   if (VecIn.size() == 2) {
17705     unsigned MaxIndex = 0;
17706     unsigned NearestPow2 = 0;
17707     SDValue Vec = VecIn.back();
17708     EVT InVT = Vec.getValueType();
17709     MVT IdxTy = TLI.getVectorIdxTy(DAG.getDataLayout());
17710     SmallVector<unsigned, 8> IndexVec(NumElems, 0);
17711 
17712     for (unsigned i = 0; i < NumElems; i++) {
17713       if (VectorMask[i] <= 0)
17714         continue;
17715       unsigned Index = N->getOperand(i).getConstantOperandVal(1);
17716       IndexVec[i] = Index;
17717       MaxIndex = std::max(MaxIndex, Index);
17718     }
17719 
17720     NearestPow2 = PowerOf2Ceil(MaxIndex);
17721     if (InVT.isSimple() && NearestPow2 > 2 && MaxIndex < NearestPow2 &&
17722         NumElems * 2 < NearestPow2) {
17723       unsigned SplitSize = NearestPow2 / 2;
17724       EVT SplitVT = EVT::getVectorVT(*DAG.getContext(),
17725                                      InVT.getVectorElementType(), SplitSize);
17726       if (TLI.isTypeLegal(SplitVT)) {
17727         SDValue VecIn2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SplitVT, Vec,
17728                                      DAG.getConstant(SplitSize, DL, IdxTy));
17729         SDValue VecIn1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SplitVT, Vec,
17730                                      DAG.getConstant(0, DL, IdxTy));
17731         VecIn.pop_back();
17732         VecIn.push_back(VecIn1);
17733         VecIn.push_back(VecIn2);
17734         DidSplitVec = true;
17735 
17736         for (unsigned i = 0; i < NumElems; i++) {
17737           if (VectorMask[i] <= 0)
17738             continue;
17739           VectorMask[i] = (IndexVec[i] < SplitSize) ? 1 : 2;
17740         }
17741       }
17742     }
17743   }
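  // Illustrative example of the split above (assuming the types are legal):
  // when building a v4i32 from a single v32i32 source with a maximum extract
  // index of 15, NearestPow2 is 16, so the source is split into two v8i32
  // halves at indices 0 and 8, and each mask entry is remapped to vector 1 or
  // 2 depending on whether its index falls below SplitSize (8).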
17744 
17745   // TODO: We want to sort the vectors by descending length, so that adjacent
17746   // pairs have similar length, and the longer vector is always first in the
17747   // pair.
17748 
17749   // TODO: Should this fire if some of the input vectors have illegal types
17750   // (like it does now), or should we let legalization run its course first?
17751 
17752   // Shuffle phase:
17753   // Take pairs of vectors, and shuffle them so that the result has elements
17754   // from these vectors in the correct places.
17755   // For example, given:
17756   // t10: i32 = extract_vector_elt t1, Constant:i64<0>
17757   // t11: i32 = extract_vector_elt t2, Constant:i64<0>
17758   // t12: i32 = extract_vector_elt t3, Constant:i64<0>
17759   // t13: i32 = extract_vector_elt t1, Constant:i64<1>
17760   // t14: v4i32 = BUILD_VECTOR t10, t11, t12, t13
17761   // We will generate:
17762   // t20: v4i32 = vector_shuffle<0,4,u,1> t1, t2
17763   // t21: v4i32 = vector_shuffle<u,u,0,u> t3, undef
17764   SmallVector<SDValue, 4> Shuffles;
17765   for (unsigned In = 0, Len = (VecIn.size() / 2); In < Len; ++In) {
17766     unsigned LeftIdx = 2 * In + 1;
17767     SDValue VecLeft = VecIn[LeftIdx];
17768     SDValue VecRight =
17769         (LeftIdx + 1) < VecIn.size() ? VecIn[LeftIdx + 1] : SDValue();
17770 
17771     if (SDValue Shuffle = createBuildVecShuffle(DL, N, VectorMask, VecLeft,
17772                                                 VecRight, LeftIdx, DidSplitVec))
17773       Shuffles.push_back(Shuffle);
17774     else
17775       return SDValue();
17776   }
17777 
17778   // If we need the zero vector as an "ingredient" in the blend tree, add it
17779   // to the list of shuffles.
17780   if (UsesZeroVector)
17781     Shuffles.push_back(VT.isInteger() ? DAG.getConstant(0, DL, VT)
17782                                       : DAG.getConstantFP(0.0, DL, VT));
17783 
17784   // If we only have one shuffle, we're done.
17785   if (Shuffles.size() == 1)
17786     return Shuffles[0];
17787 
17788   // Update the vector mask to point to the post-shuffle vectors.
17789   for (int &Vec : VectorMask)
17790     if (Vec == 0)
17791       Vec = Shuffles.size() - 1;
17792     else
17793       Vec = (Vec - 1) / 2;
17794 
17795   // More than one shuffle. Generate a binary tree of blends, e.g. if from
17796   // the previous step we got the set of shuffles t10, t11, t12, t13, we will
17797   // generate:
17798   // t10: v8i32 = vector_shuffle<0,8,u,u,u,u,u,u> t1, t2
17799   // t11: v8i32 = vector_shuffle<u,u,0,8,u,u,u,u> t3, t4
17800   // t12: v8i32 = vector_shuffle<u,u,u,u,0,8,u,u> t5, t6
17801   // t13: v8i32 = vector_shuffle<u,u,u,u,u,u,0,8> t7, t8
17802   // t20: v8i32 = vector_shuffle<0,1,10,11,u,u,u,u> t10, t11
17803   // t21: v8i32 = vector_shuffle<u,u,u,u,4,5,14,15> t12, t13
17804   // t30: v8i32 = vector_shuffle<0,1,2,3,12,13,14,15> t20, t21
17805 
17806   // Make sure the initial size of the shuffle list is even.
17807   if (Shuffles.size() % 2)
17808     Shuffles.push_back(DAG.getUNDEF(VT));
17809 
17810   for (unsigned CurSize = Shuffles.size(); CurSize > 1; CurSize /= 2) {
17811     if (CurSize % 2) {
17812       Shuffles[CurSize] = DAG.getUNDEF(VT);
17813       CurSize++;
17814     }
17815     for (unsigned In = 0, Len = CurSize / 2; In < Len; ++In) {
17816       int Left = 2 * In;
17817       int Right = 2 * In + 1;
17818       SmallVector<int, 8> Mask(NumElems, -1);
17819       for (unsigned i = 0; i != NumElems; ++i) {
17820         if (VectorMask[i] == Left) {
17821           Mask[i] = i;
17822           VectorMask[i] = In;
17823         } else if (VectorMask[i] == Right) {
17824           Mask[i] = i + NumElems;
17825           VectorMask[i] = In;
17826         }
17827       }
17828 
17829       Shuffles[In] =
17830           DAG.getVectorShuffle(VT, DL, Shuffles[Left], Shuffles[Right], Mask);
17831     }
17832   }
17833   return Shuffles[0];
17834 }
17835 
17836 // Try to turn a build vector of zero extends of extract vector elts into
17837 // a vector zero extend and possibly an extract subvector.
17838 // TODO: Support sign extend?
17839 // TODO: Allow undef elements?
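// As an illustrative sketch of the fold (types hypothetical):
//   (v4i32 build_vector (zext (extractelt (v8i16 X), 0)),
//                       (zext (extractelt X, 1)),
//                       (zext (extractelt X, 2)),
//                       (zext (extractelt X, 3)))
//     --> (v4i32 zero_extend (v4i16 extract_subvector X, 0))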
17840 SDValue DAGCombiner::convertBuildVecZextToZext(SDNode *N) {
17841   if (LegalOperations)
17842     return SDValue();
17843 
17844   EVT VT = N->getValueType(0);
17845 
17846   bool FoundZeroExtend = false;
17847   SDValue Op0 = N->getOperand(0);
17848   auto checkElem = [&](SDValue Op) -> int64_t {
17849     unsigned Opc = Op.getOpcode();
17850     FoundZeroExtend |= (Opc == ISD::ZERO_EXTEND);
17851     if ((Opc == ISD::ZERO_EXTEND || Opc == ISD::ANY_EXTEND) &&
17852         Op.getOperand(0).getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
17853         Op0.getOperand(0).getOperand(0) == Op.getOperand(0).getOperand(0))
17854       if (auto *C = dyn_cast<ConstantSDNode>(Op.getOperand(0).getOperand(1)))
17855         return C->getZExtValue();
17856     return -1;
17857   };
17858 
17859   // Make sure the first element matches
17860   // (zext (extract_vector_elt X, C))
17861   int64_t Offset = checkElem(Op0);
17862   if (Offset < 0)
17863     return SDValue();
17864 
17865   unsigned NumElems = N->getNumOperands();
17866   SDValue In = Op0.getOperand(0).getOperand(0);
17867   EVT InSVT = In.getValueType().getScalarType();
17868   EVT InVT = EVT::getVectorVT(*DAG.getContext(), InSVT, NumElems);
17869 
17870   // Don't create an illegal input type after type legalization.
17871   if (LegalTypes && !TLI.isTypeLegal(InVT))
17872     return SDValue();
17873 
17874   // Ensure all the elements come from the same vector and are adjacent.
17875   for (unsigned i = 1; i != NumElems; ++i) {
17876     if ((Offset + i) != checkElem(N->getOperand(i)))
17877       return SDValue();
17878   }
17879 
17880   SDLoc DL(N);
17881   In = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InVT, In,
17882                    Op0.getOperand(0).getOperand(1));
17883   return DAG.getNode(FoundZeroExtend ? ISD::ZERO_EXTEND : ISD::ANY_EXTEND, DL,
17884                      VT, In);
17885 }
17886 
17887 SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) {
17888   EVT VT = N->getValueType(0);
17889 
17890   // A vector built entirely of undefs is undef.
17891   if (ISD::allOperandsUndef(N))
17892     return DAG.getUNDEF(VT);
17893 
17894   // If this is a splat of a bitcast from another vector, change to a
17895   // concat_vector.
17896   // For example:
17897   //   (build_vector (i64 (bitcast (v2i32 X))), (i64 (bitcast (v2i32 X)))) ->
17898   //     (v2i64 (bitcast (concat_vectors (v2i32 X), (v2i32 X))))
17899   //
17900   // If X is a build_vector itself, the concat can become a larger build_vector.
17901   // TODO: Maybe this is useful for non-splat too?
17902   if (!LegalOperations) {
17903     if (SDValue Splat = cast<BuildVectorSDNode>(N)->getSplatValue()) {
17904       Splat = peekThroughBitcasts(Splat);
17905       EVT SrcVT = Splat.getValueType();
17906       if (SrcVT.isVector()) {
17907         unsigned NumElts = N->getNumOperands() * SrcVT.getVectorNumElements();
17908         EVT NewVT = EVT::getVectorVT(*DAG.getContext(),
17909                                      SrcVT.getVectorElementType(), NumElts);
17910         if (!LegalTypes || TLI.isTypeLegal(NewVT)) {
17911           SmallVector<SDValue, 8> Ops(N->getNumOperands(), Splat);
17912           SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N),
17913                                        NewVT, Ops);
17914           return DAG.getBitcast(VT, Concat);
17915         }
17916       }
17917     }
17918   }
17919 
17920   // A splat of a single element is a SPLAT_VECTOR if supported on the target.
17921   if (TLI.getOperationAction(ISD::SPLAT_VECTOR, VT) != TargetLowering::Expand)
17922     if (SDValue V = cast<BuildVectorSDNode>(N)->getSplatValue()) {
17923       assert(!V.isUndef() && "Splat of undef should have been handled earlier");
17924       return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), VT, V);
17925     }
17926 
17927   // Check if we can express the BUILD_VECTOR via a subvector extract.
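  // For example (illustrative), with V of type v8i32 and a v4i32 result:
  //   (build_vector (extractelt V, 4), (extractelt V, 5),
  //                 (extractelt V, 6), (extractelt V, 7))
  //     --> (extract_subvector V, 4)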
17928   if (!LegalTypes && (N->getNumOperands() > 1)) {
17929     SDValue Op0 = N->getOperand(0);
17930     auto checkElem = [&](SDValue Op) -> uint64_t {
17931       if ((Op.getOpcode() == ISD::EXTRACT_VECTOR_ELT) &&
17932           (Op0.getOperand(0) == Op.getOperand(0)))
17933         if (auto CNode = dyn_cast<ConstantSDNode>(Op.getOperand(1)))
17934           return CNode->getZExtValue();
17935       return -1;
17936     };
17937 
17938     int Offset = checkElem(Op0);
17939     for (unsigned i = 0; i < N->getNumOperands(); ++i) {
17940       if (Offset + i != checkElem(N->getOperand(i))) {
17941         Offset = -1;
17942         break;
17943       }
17944     }
17945 
17946     if ((Offset == 0) &&
17947         (Op0.getOperand(0).getValueType() == N->getValueType(0)))
17948       return Op0.getOperand(0);
17949     if ((Offset != -1) &&
17950         ((Offset % N->getValueType(0).getVectorNumElements()) ==
17951          0)) // IDX must be multiple of output size.
17952       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), N->getValueType(0),
17953                          Op0.getOperand(0), Op0.getOperand(1));
17954   }
17955 
17956   if (SDValue V = convertBuildVecZextToZext(N))
17957     return V;
17958 
17959   if (SDValue V = reduceBuildVecExtToExtBuildVec(N))
17960     return V;
17961 
17962   if (SDValue V = reduceBuildVecToShuffle(N))
17963     return V;
17964 
17965   return SDValue();
17966 }
17967 
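// Fold a CONCAT_VECTORS whose operands are all bitcasts from scalars (or
// undef) into a BUILD_VECTOR of those scalars, bitcast to the result type.
// For example (illustrative, assuming the v1i64 operand type is not legal,
// so the fold fires):
//   concat_vectors (v1i64 (bitcast i64 X)), (v1i64 (bitcast i64 Y))
//     --> v2i64 (bitcast (v2i64 build_vector X, Y))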
17968 static SDValue combineConcatVectorOfScalars(SDNode *N, SelectionDAG &DAG) {
17969   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
17970   EVT OpVT = N->getOperand(0).getValueType();
17971 
17972   // If the operands are legal vectors, leave them alone.
17973   if (TLI.isTypeLegal(OpVT))
17974     return SDValue();
17975 
17976   SDLoc DL(N);
17977   EVT VT = N->getValueType(0);
17978   SmallVector<SDValue, 8> Ops;
17979 
17980   EVT SVT = EVT::getIntegerVT(*DAG.getContext(), OpVT.getSizeInBits());
17981   SDValue ScalarUndef = DAG.getNode(ISD::UNDEF, DL, SVT);
17982 
17983   // Keep track of what we encounter.
17984   bool AnyInteger = false;
17985   bool AnyFP = false;
17986   for (const SDValue &Op : N->ops()) {
17987     if (ISD::BITCAST == Op.getOpcode() &&
17988         !Op.getOperand(0).getValueType().isVector())
17989       Ops.push_back(Op.getOperand(0));
17990     else if (ISD::UNDEF == Op.getOpcode())
17991       Ops.push_back(ScalarUndef);
17992     else
17993       return SDValue();
17994 
17995     // Note whether we encounter an integer or floating point scalar.
17996     // If it's neither, bail out, it could be something weird like x86mmx.
17997     EVT LastOpVT = Ops.back().getValueType();
17998     if (LastOpVT.isFloatingPoint())
17999       AnyFP = true;
18000     else if (LastOpVT.isInteger())
18001       AnyInteger = true;
18002     else
18003       return SDValue();
18004   }
18005 
18006   // If any of the operands is a floating point scalar bitcast to a vector,
18007   // use floating point types throughout, and bitcast everything.
18008   // Replace UNDEFs by another scalar UNDEF node, of the final desired type.
18009   if (AnyFP) {
18010     SVT = EVT::getFloatingPointVT(OpVT.getSizeInBits());
18011     ScalarUndef = DAG.getNode(ISD::UNDEF, DL, SVT);
18012     if (AnyInteger) {
18013       for (SDValue &Op : Ops) {
18014         if (Op.getValueType() == SVT)
18015           continue;
18016         if (Op.isUndef())
18017           Op = ScalarUndef;
18018         else
18019           Op = DAG.getBitcast(SVT, Op);
18020       }
18021     }
18022   }
18023 
18024   EVT VecVT = EVT::getVectorVT(*DAG.getContext(), SVT,
18025                                VT.getSizeInBits() / SVT.getSizeInBits());
18026   return DAG.getBitcast(VT, DAG.getBuildVector(VecVT, DL, Ops));
18027 }
18028 
18029 // Check to see if this is a CONCAT_VECTORS of a bunch of EXTRACT_SUBVECTOR
18030 // operations. If so, and if the EXTRACT_SUBVECTOR vector inputs come from at
18031 // most two distinct vectors the same size as the result, attempt to turn this
18032 // into a legal shuffle.
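// For example (illustrative), with X and Y both v4i32 and a v4i32 result:
//   concat_vectors (v2i32 extract_subvector X, 2),
//                  (v2i32 extract_subvector Y, 0)
//     --> vector_shuffle<2,3,4,5> X, Y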
18033 static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) {
18034   EVT VT = N->getValueType(0);
18035   EVT OpVT = N->getOperand(0).getValueType();
18036   int NumElts = VT.getVectorNumElements();
18037   int NumOpElts = OpVT.getVectorNumElements();
18038 
18039   SDValue SV0 = DAG.getUNDEF(VT), SV1 = DAG.getUNDEF(VT);
18040   SmallVector<int, 8> Mask;
18041 
18042   for (SDValue Op : N->ops()) {
18043     Op = peekThroughBitcasts(Op);
18044 
18045     // UNDEF nodes convert to UNDEF shuffle mask values.
18046     if (Op.isUndef()) {
18047       Mask.append((unsigned)NumOpElts, -1);
18048       continue;
18049     }
18050 
18051     if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
18052       return SDValue();
18053 
18054     // What vector are we extracting the subvector from and at what index?
18055     SDValue ExtVec = Op.getOperand(0);
18056 
18057     // We want the EVT of the original extraction to correctly scale the
18058     // extraction index.
18059     EVT ExtVT = ExtVec.getValueType();
18060     ExtVec = peekThroughBitcasts(ExtVec);
18061 
18062     // UNDEF nodes convert to UNDEF shuffle mask values.
18063     if (ExtVec.isUndef()) {
18064       Mask.append((unsigned)NumOpElts, -1);
18065       continue;
18066     }
18067 
18068     if (!isa<ConstantSDNode>(Op.getOperand(1)))
18069       return SDValue();
18070     int ExtIdx = Op.getConstantOperandVal(1);
18071 
18072     // Ensure that we are extracting a subvector from a vector the same
18073     // size as the result.
18074     if (ExtVT.getSizeInBits() != VT.getSizeInBits())
18075       return SDValue();
18076 
18077     // Scale the subvector index to account for any bitcast.
18078     int NumExtElts = ExtVT.getVectorNumElements();
18079     if (0 == (NumExtElts % NumElts))
18080       ExtIdx /= (NumExtElts / NumElts);
18081     else if (0 == (NumElts % NumExtElts))
18082       ExtIdx *= (NumElts / NumExtElts);
18083     else
18084       return SDValue();
18085 
18086     // At most we can reference 2 inputs in the final shuffle.
18087     if (SV0.isUndef() || SV0 == ExtVec) {
18088       SV0 = ExtVec;
18089       for (int i = 0; i != NumOpElts; ++i)
18090         Mask.push_back(i + ExtIdx);
18091     } else if (SV1.isUndef() || SV1 == ExtVec) {
18092       SV1 = ExtVec;
18093       for (int i = 0; i != NumOpElts; ++i)
18094         Mask.push_back(i + ExtIdx + NumElts);
18095     } else {
18096       return SDValue();
18097     }
18098   }
18099 
18100   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
18101   return TLI.buildLegalVectorShuffle(VT, SDLoc(N), DAG.getBitcast(VT, SV0),
18102                                      DAG.getBitcast(VT, SV1), Mask, DAG);
18103 }
18104 
18105 SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
18106   // If we only have one input vector, we don't need to do any concatenation.
18107   if (N->getNumOperands() == 1)
18108     return N->getOperand(0);
18109 
18110   // Check if all of the operands are undefs.
18111   EVT VT = N->getValueType(0);
18112   if (ISD::allOperandsUndef(N))
18113     return DAG.getUNDEF(VT);
18114 
18115   // Optimize concat_vectors where all but the first of the vectors are undef.
18116   if (std::all_of(std::next(N->op_begin()), N->op_end(), [](const SDValue &Op) {
18117         return Op.isUndef();
18118       })) {
18119     SDValue In = N->getOperand(0);
18120     assert(In.getValueType().isVector() && "Must concat vectors");
18121 
18122     // If the input is a concat_vectors, just make a larger concat by padding
18123     // with smaller undefs.
18124     if (In.getOpcode() == ISD::CONCAT_VECTORS && In.hasOneUse()) {
18125       unsigned NumOps = N->getNumOperands() * In.getNumOperands();
18126       SmallVector<SDValue, 4> Ops(In->op_begin(), In->op_end());
18127       Ops.resize(NumOps, DAG.getUNDEF(Ops[0].getValueType()));
18128       return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
18129     }
18130 
18131     SDValue Scalar = peekThroughOneUseBitcasts(In);
18132 
18133     // concat_vectors(scalar_to_vector(scalar), undef) ->
18134     //     scalar_to_vector(scalar)
18135     if (!LegalOperations && Scalar.getOpcode() == ISD::SCALAR_TO_VECTOR &&
18136          Scalar.hasOneUse()) {
18137       EVT SVT = Scalar.getValueType().getVectorElementType();
18138       if (SVT == Scalar.getOperand(0).getValueType())
18139         Scalar = Scalar.getOperand(0);
18140     }
18141 
18142     // concat_vectors(scalar, undef) -> scalar_to_vector(scalar)
18143     if (!Scalar.getValueType().isVector()) {
18144       // If the bitcast type isn't legal, it might be a trunc of a legal type;
18145       // look through the trunc so we can still do the transform:
18146       //   concat_vectors(trunc(scalar), undef) -> scalar_to_vector(scalar)
18147       if (Scalar->getOpcode() == ISD::TRUNCATE &&
18148           !TLI.isTypeLegal(Scalar.getValueType()) &&
18149           TLI.isTypeLegal(Scalar->getOperand(0).getValueType()))
18150         Scalar = Scalar->getOperand(0);
18151 
18152       EVT SclTy = Scalar.getValueType();
18153 
18154       if (!SclTy.isFloatingPoint() && !SclTy.isInteger())
18155         return SDValue();
18156 
18157       // Bail out if the vector size is not a multiple of the scalar size.
18158       if (VT.getSizeInBits() % SclTy.getSizeInBits())
18159         return SDValue();
18160 
18161       unsigned VNTNumElms = VT.getSizeInBits() / SclTy.getSizeInBits();
18162       if (VNTNumElms < 2)
18163         return SDValue();
18164 
18165       EVT NVT = EVT::getVectorVT(*DAG.getContext(), SclTy, VNTNumElms);
18166       if (!TLI.isTypeLegal(NVT) || !TLI.isTypeLegal(Scalar.getValueType()))
18167         return SDValue();
18168 
18169       SDValue Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), NVT, Scalar);
18170       return DAG.getBitcast(VT, Res);
18171     }
18172   }
18173 
18174   // Fold any combination of BUILD_VECTOR or UNDEF nodes into one BUILD_VECTOR.
18175   // We have already tested above for an UNDEF only concatenation.
18176   // fold (concat_vectors (BUILD_VECTOR A, B, ...), (BUILD_VECTOR C, D, ...))
18177   // -> (BUILD_VECTOR A, B, ..., C, D, ...)
18178   auto IsBuildVectorOrUndef = [](const SDValue &Op) {
18179     return ISD::UNDEF == Op.getOpcode() || ISD::BUILD_VECTOR == Op.getOpcode();
18180   };
18181   if (llvm::all_of(N->ops(), IsBuildVectorOrUndef)) {
18182     SmallVector<SDValue, 8> Opnds;
18183     EVT SVT = VT.getScalarType();
18184 
18185     EVT MinVT = SVT;
18186     if (!SVT.isFloatingPoint()) {
18187       // If the BUILD_VECTORs are built from integers, they may have different
18188       // operand types. Get the smallest type and truncate all operands to it.
18189       bool FoundMinVT = false;
18190       for (const SDValue &Op : N->ops())
18191         if (ISD::BUILD_VECTOR == Op.getOpcode()) {
18192           EVT OpSVT = Op.getOperand(0).getValueType();
18193           MinVT = (!FoundMinVT || OpSVT.bitsLE(MinVT)) ? OpSVT : MinVT;
18194           FoundMinVT = true;
18195         }
18196       assert(FoundMinVT && "Concat vector type mismatch");
18197     }
18198 
18199     for (const SDValue &Op : N->ops()) {
18200       EVT OpVT = Op.getValueType();
18201       unsigned NumElts = OpVT.getVectorNumElements();
18202 
18203       if (ISD::UNDEF == Op.getOpcode())
18204         Opnds.append(NumElts, DAG.getUNDEF(MinVT));
18205 
18206       if (ISD::BUILD_VECTOR == Op.getOpcode()) {
18207         if (SVT.isFloatingPoint()) {
18208           assert(SVT == OpVT.getScalarType() && "Concat vector type mismatch");
18209           Opnds.append(Op->op_begin(), Op->op_begin() + NumElts);
18210         } else {
18211           for (unsigned i = 0; i != NumElts; ++i)
18212             Opnds.push_back(
18213                 DAG.getNode(ISD::TRUNCATE, SDLoc(N), MinVT, Op.getOperand(i)));
18214         }
18215       }
18216     }
18217 
18218     assert(VT.getVectorNumElements() == Opnds.size() &&
18219            "Concat vector type mismatch");
18220     return DAG.getBuildVector(VT, SDLoc(N), Opnds);
18221   }
18222 
18223   // Fold CONCAT_VECTORS of only bitcast scalars (or undef) to BUILD_VECTOR.
18224   if (SDValue V = combineConcatVectorOfScalars(N, DAG))
18225     return V;
18226 
18227   // Fold CONCAT_VECTORS of EXTRACT_SUBVECTOR (or undef) to VECTOR_SHUFFLE.
18228   if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT))
18229     if (SDValue V = combineConcatVectorOfExtracts(N, DAG))
18230       return V;
18231 
18232   // Type legalization of vectors and DAG canonicalization of SHUFFLE_VECTOR
18233   // nodes often generate nop CONCAT_VECTOR nodes.
18234   // Scan the CONCAT_VECTOR operands and look for a CONCAT operation that
18235   // places the incoming vectors at the exact same location.
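  // For example (illustrative), with X of type v8i32:
  //   concat_vectors (v4i32 extract_subvector X, 0),
  //                  (v4i32 extract_subvector X, 4)
  //     --> X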
18236   SDValue SingleSource = SDValue();
18237   unsigned PartNumElem = N->getOperand(0).getValueType().getVectorNumElements();
18238 
18239   for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
18240     SDValue Op = N->getOperand(i);
18241 
18242     if (Op.isUndef())
18243       continue;
18244 
18245     // Check if this is the identity extract:
18246     if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
18247       return SDValue();
18248 
18249     // Find the single incoming vector for the extract_subvector.
18250     if (SingleSource.getNode()) {
18251       if (Op.getOperand(0) != SingleSource)
18252         return SDValue();
18253     } else {
18254       SingleSource = Op.getOperand(0);
18255 
18256       // Check that the source type is the same as the type of the result.
18257       // If not, this concat may extend the vector, so we cannot
18258       // optimize it away.
18259       if (SingleSource.getValueType() != N->getValueType(0))
18260         return SDValue();
18261     }
18262 
18263     auto *CS = dyn_cast<ConstantSDNode>(Op.getOperand(1));
18264     // The extract index must be constant.
18265     if (!CS)
18266       return SDValue();
18267 
18268     // Check that we are reading from the identity index.
18269     unsigned IdentityIndex = i * PartNumElem;
18270     if (CS->getAPIntValue() != IdentityIndex)
18271       return SDValue();
18272   }
18273 
18274   if (SingleSource.getNode())
18275     return SingleSource;
18276 
18277   return SDValue();
18278 }
18279 
18280 // Helper that peeks through INSERT_SUBVECTOR/CONCAT_VECTORS to find
18281 // if the subvector can be sourced for free.
18282 static SDValue getSubVectorSrc(SDValue V, SDValue Index, EVT SubVT) {
18283   if (V.getOpcode() == ISD::INSERT_SUBVECTOR &&
18284       V.getOperand(1).getValueType() == SubVT && V.getOperand(2) == Index) {
18285     return V.getOperand(1);
18286   }
18287   auto *IndexC = dyn_cast<ConstantSDNode>(Index);
18288   if (IndexC && V.getOpcode() == ISD::CONCAT_VECTORS &&
18289       V.getOperand(0).getValueType() == SubVT &&
18290       (IndexC->getZExtValue() % SubVT.getVectorNumElements()) == 0) {
18291     uint64_t SubIdx = IndexC->getZExtValue() / SubVT.getVectorNumElements();
18292     return V.getOperand(SubIdx);
18293   }
18294   return SDValue();
18295 }
18296 
18297 static SDValue narrowInsertExtractVectorBinOp(SDNode *Extract,
18298                                               SelectionDAG &DAG) {
18299   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
18300   SDValue BinOp = Extract->getOperand(0);
18301   unsigned BinOpcode = BinOp.getOpcode();
18302   if (!TLI.isBinOp(BinOpcode) || BinOp.getNode()->getNumValues() != 1)
18303     return SDValue();
18304 
18305   EVT VecVT = BinOp.getValueType();
18306   SDValue Bop0 = BinOp.getOperand(0), Bop1 = BinOp.getOperand(1);
18307   if (VecVT != Bop0.getValueType() || VecVT != Bop1.getValueType())
18308     return SDValue();
18309 
18310   SDValue Index = Extract->getOperand(1);
18311   EVT SubVT = Extract->getValueType(0);
18312   if (!TLI.isOperationLegalOrCustom(BinOpcode, SubVT))
18313     return SDValue();
18314 
18315   SDValue Sub0 = getSubVectorSrc(Bop0, Index, SubVT);
18316   SDValue Sub1 = getSubVectorSrc(Bop1, Index, SubVT);
18317 
18318   // TODO: We could handle the case where only 1 operand is being inserted by
18319   //       creating an extract of the other operand, but that requires checking
18320   //       number of uses and/or costs.
18321   if (!Sub0 || !Sub1)
18322     return SDValue();
18323 
18324   // We are inserting both operands of the wide binop only to extract back
18325   // to the narrow vector size. Eliminate all of the insert/extract:
18326   // ext (binop (ins ?, X, Index), (ins ?, Y, Index)), Index --> binop X, Y
18327   return DAG.getNode(BinOpcode, SDLoc(Extract), SubVT, Sub0, Sub1,
18328                      BinOp->getFlags());
18329 }
18330 
18331 /// If we are extracting a subvector produced by a wide binary operator try
18332 /// to use a narrow binary operator and/or avoid concatenation and extraction.
18333 static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG) {
18334   // TODO: Refactor with the caller (visitEXTRACT_SUBVECTOR), so we can share
18335   // some of these bailouts with other transforms.
18336 
18337   if (SDValue V = narrowInsertExtractVectorBinOp(Extract, DAG))
18338     return V;
18339 
18340   // The extract index must be a constant, so we can map it to a concat operand.
18341   auto *ExtractIndexC = dyn_cast<ConstantSDNode>(Extract->getOperand(1));
18342   if (!ExtractIndexC)
18343     return SDValue();
18344 
18345   // We are looking for an optionally bitcasted wide vector binary operator
18346   // feeding an extract subvector.
18347   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
18348   SDValue BinOp = peekThroughBitcasts(Extract->getOperand(0));
18349   unsigned BOpcode = BinOp.getOpcode();
18350   if (!TLI.isBinOp(BOpcode) || BinOp.getNode()->getNumValues() != 1)
18351     return SDValue();
18352 
18353   // The binop must be a vector type, so we can extract some fraction of it.
18354   EVT WideBVT = BinOp.getValueType();
18355   if (!WideBVT.isVector())
18356     return SDValue();
18357 
18358   EVT VT = Extract->getValueType(0);
18359   unsigned ExtractIndex = ExtractIndexC->getZExtValue();
18360   assert(ExtractIndex % VT.getVectorNumElements() == 0 &&
18361          "Extract index is not a multiple of the vector length.");
18362 
18363   // Bail out if this is not a proper multiple width extraction.
18364   unsigned WideWidth = WideBVT.getSizeInBits();
18365   unsigned NarrowWidth = VT.getSizeInBits();
18366   if (WideWidth % NarrowWidth != 0)
18367     return SDValue();
18368 
18369   // Bail out if we are extracting a fraction of a single operation. This can
18370   // occur because we potentially looked through a bitcast of the binop.
18371   unsigned NarrowingRatio = WideWidth / NarrowWidth;
18372   unsigned WideNumElts = WideBVT.getVectorNumElements();
18373   if (WideNumElts % NarrowingRatio != 0)
18374     return SDValue();
18375 
18376   // Bail out if the target does not support a narrower version of the binop.
18377   EVT NarrowBVT = EVT::getVectorVT(*DAG.getContext(), WideBVT.getScalarType(),
18378                                    WideNumElts / NarrowingRatio);
18379   if (!TLI.isOperationLegalOrCustomOrPromote(BOpcode, NarrowBVT))
18380     return SDValue();
18381 
18382   // If extraction is cheap, we don't need to look at the binop operands
18383   // for concat ops. The narrow binop alone makes this transform profitable.
18384   // We can't just reuse the original extract index operand because we may have
18385   // bitcasted.
18386   unsigned ConcatOpNum = ExtractIndex / VT.getVectorNumElements();
18387   unsigned ExtBOIdx = ConcatOpNum * NarrowBVT.getVectorNumElements();
18388   EVT ExtBOIdxVT = Extract->getOperand(1).getValueType();
18389   if (TLI.isExtractSubvectorCheap(NarrowBVT, WideBVT, ExtBOIdx) &&
18390       BinOp.hasOneUse() && Extract->getOperand(0)->hasOneUse()) {
18391     // extract (binop B0, B1), N --> binop (extract B0, N), (extract B1, N)
18392     SDLoc DL(Extract);
18393     SDValue NewExtIndex = DAG.getConstant(ExtBOIdx, DL, ExtBOIdxVT);
18394     SDValue X = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
18395                             BinOp.getOperand(0), NewExtIndex);
18396     SDValue Y = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
18397                             BinOp.getOperand(1), NewExtIndex);
18398     SDValue NarrowBinOp = DAG.getNode(BOpcode, DL, NarrowBVT, X, Y,
18399                                       BinOp.getNode()->getFlags());
18400     return DAG.getBitcast(VT, NarrowBinOp);
18401   }
18402 
18403   // Only handle the case where we are doubling and then halving. A larger ratio
18404   // may require more than two narrow binops to replace the wide binop.
18405   if (NarrowingRatio != 2)
18406     return SDValue();
18407 
18408   // TODO: The motivating case for this transform is an x86 AVX1 target. That
18409   // target has temptingly almost legal versions of bitwise logic ops in 256-bit
18410   // flavors, but no other 256-bit integer support. This could be extended to
18411   // handle any binop, but that may require fixing/adding other folds to avoid
18412   // codegen regressions.
18413   if (BOpcode != ISD::AND && BOpcode != ISD::OR && BOpcode != ISD::XOR)
18414     return SDValue();
18415 
18416   // We need at least one concatenation operation of a binop operand to make
18417   // this transform worthwhile. The concat must double the input vector sizes.
18418   auto GetSubVector = [ConcatOpNum](SDValue V) -> SDValue {
18419     if (V.getOpcode() == ISD::CONCAT_VECTORS && V.getNumOperands() == 2)
18420       return V.getOperand(ConcatOpNum);
18421     return SDValue();
18422   };
18423   SDValue SubVecL = GetSubVector(peekThroughBitcasts(BinOp.getOperand(0)));
18424   SDValue SubVecR = GetSubVector(peekThroughBitcasts(BinOp.getOperand(1)));
18425 
18426   if (SubVecL || SubVecR) {
18427     // If a binop operand was not the result of a concat, we must extract a
18428     // half-sized operand for our new narrow binop:
18429     // extract (binop (concat X1, X2), (concat Y1, Y2)), N --> binop XN, YN
18430     // extract (binop (concat X1, X2), Y), N --> binop XN, (extract Y, IndexC)
18431     // extract (binop X, (concat Y1, Y2)), N --> binop (extract X, IndexC), YN
18432     SDLoc DL(Extract);
18433     SDValue IndexC = DAG.getConstant(ExtBOIdx, DL, ExtBOIdxVT);
18434     SDValue X = SubVecL ? DAG.getBitcast(NarrowBVT, SubVecL)
18435                         : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
18436                                       BinOp.getOperand(0), IndexC);
18437 
18438     SDValue Y = SubVecR ? DAG.getBitcast(NarrowBVT, SubVecR)
18439                         : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
18440                                       BinOp.getOperand(1), IndexC);
18441 
18442     SDValue NarrowBinOp = DAG.getNode(BOpcode, DL, NarrowBVT, X, Y);
18443     return DAG.getBitcast(VT, NarrowBinOp);
18444   }
18445 
18446   return SDValue();
18447 }
18448 
18449 /// If we are extracting a subvector from a wide vector load, convert to a
18450 /// narrow load to eliminate the extraction:
18451 /// (extract_subvector (load wide vector)) --> (load narrow vector)
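/// For example (illustrative, little-endian):
///   (v2f32 (extract_subvector (v8f32 (load %ptr)), 2))
///     --> (v2f32 (load %ptr + 8 bytes))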
18452 static SDValue narrowExtractedVectorLoad(SDNode *Extract, SelectionDAG &DAG) {
18453   // TODO: Add support for big-endian. The offset calculation must be adjusted.
18454   if (DAG.getDataLayout().isBigEndian())
18455     return SDValue();
18456 
18457   auto *Ld = dyn_cast<LoadSDNode>(Extract->getOperand(0));
18458   auto *ExtIdx = dyn_cast<ConstantSDNode>(Extract->getOperand(1));
18459   if (!Ld || Ld->getExtensionType() || !Ld->isSimple() ||
18460       !ExtIdx)
18461     return SDValue();
18462 
18463   // Allow targets to opt-out.
18464   EVT VT = Extract->getValueType(0);
18465   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
18466   if (!TLI.shouldReduceLoadWidth(Ld, Ld->getExtensionType(), VT))
18467     return SDValue();
18468 
18469   // The narrow load will be offset from the base address of the old load if
18470   // we are extracting from something besides index 0 (little-endian).
18471   SDLoc DL(Extract);
18472   SDValue BaseAddr = Ld->getOperand(1);
18473   unsigned Offset = ExtIdx->getZExtValue() * VT.getScalarType().getStoreSize();
18474 
18475   // TODO: Use "BaseIndexOffset" to make this more effective.
18476   SDValue NewAddr = DAG.getMemBasePlusOffset(BaseAddr, Offset, DL);
18477   MachineFunction &MF = DAG.getMachineFunction();
18478   MachineMemOperand *MMO = MF.getMachineMemOperand(Ld->getMemOperand(), Offset,
18479                                                    VT.getStoreSize());
18480   SDValue NewLd = DAG.getLoad(VT, DL, Ld->getChain(), NewAddr, MMO);
18481   DAG.makeEquivalentMemoryOrdering(Ld, NewLd);
18482   return NewLd;
18483 }
18484 
18485 SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) {
18486   EVT NVT = N->getValueType(0);
18487   SDValue V = N->getOperand(0);
18488 
18489   // Extract from UNDEF is UNDEF.
18490   if (V.isUndef())
18491     return DAG.getUNDEF(NVT);
18492 
18493   if (TLI.isOperationLegalOrCustomOrPromote(ISD::LOAD, NVT))
18494     if (SDValue NarrowLoad = narrowExtractedVectorLoad(N, DAG))
18495       return NarrowLoad;
18496 
18497   // Combine an extract of an extract into a single extract_subvector.
18498   // ext (ext X, C), 0 --> ext X, C
18499   SDValue Index = N->getOperand(1);
18500   if (isNullConstant(Index) && V.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
18501       V.hasOneUse() && isa<ConstantSDNode>(V.getOperand(1))) {
18502     if (TLI.isExtractSubvectorCheap(NVT, V.getOperand(0).getValueType(),
18503                                     V.getConstantOperandVal(1)) &&
18504         TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NVT)) {
18505       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), NVT, V.getOperand(0),
18506                          V.getOperand(1));
18507     }
18508   }
18509 
18510   // Try to move vector bitcast after extract_subv by scaling extraction index:
18511   // extract_subv (bitcast X), Index --> bitcast (extract_subv X, Index')
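  // For example (illustrative), with X of type v4i32:
  //   v4i16 extract_subv (v8i16 bitcast X), 4
  //     --> bitcast (v2i32 extract_subv X, 2)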
18512   if (isa<ConstantSDNode>(Index) && V.getOpcode() == ISD::BITCAST &&
18513       V.getOperand(0).getValueType().isVector()) {
18514     SDValue SrcOp = V.getOperand(0);
18515     EVT SrcVT = SrcOp.getValueType();
18516     unsigned SrcNumElts = SrcVT.getVectorNumElements();
18517     unsigned DestNumElts = V.getValueType().getVectorNumElements();
18518     if ((SrcNumElts % DestNumElts) == 0) {
18519       unsigned SrcDestRatio = SrcNumElts / DestNumElts;
18520       unsigned NewExtNumElts = NVT.getVectorNumElements() * SrcDestRatio;
18521       EVT NewExtVT = EVT::getVectorVT(*DAG.getContext(), SrcVT.getScalarType(),
18522                                       NewExtNumElts);
18523       if (TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NewExtVT)) {
18524         unsigned IndexValScaled = N->getConstantOperandVal(1) * SrcDestRatio;
18525         SDLoc DL(N);
18526         SDValue NewIndex = DAG.getIntPtrConstant(IndexValScaled, DL);
18527         SDValue NewExtract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewExtVT,
18528                                          V.getOperand(0), NewIndex);
18529         return DAG.getBitcast(NVT, NewExtract);
18530       }
18531     }
18532     if ((DestNumElts % SrcNumElts) == 0) {
18533       unsigned DestSrcRatio = DestNumElts / SrcNumElts;
18534       if ((NVT.getVectorNumElements() % DestSrcRatio) == 0) {
18535         unsigned NewExtNumElts = NVT.getVectorNumElements() / DestSrcRatio;
18536         EVT NewExtVT = EVT::getVectorVT(*DAG.getContext(),
18537                                         SrcVT.getScalarType(), NewExtNumElts);
18538         if ((N->getConstantOperandVal(1) % DestSrcRatio) == 0 &&
18539             TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NewExtVT)) {
18540           unsigned IndexValScaled = N->getConstantOperandVal(1) / DestSrcRatio;
18541           SDLoc DL(N);
18542           SDValue NewIndex = DAG.getIntPtrConstant(IndexValScaled, DL);
18543           SDValue NewExtract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewExtVT,
18544                                            V.getOperand(0), NewIndex);
18545           return DAG.getBitcast(NVT, NewExtract);
18546         }
18547       }
18548     }
18549   }
18550 
18551   if (V.getOpcode() == ISD::CONCAT_VECTORS && isa<ConstantSDNode>(Index)) {
18552     EVT ConcatSrcVT = V.getOperand(0).getValueType();
18553     assert(ConcatSrcVT.getVectorElementType() == NVT.getVectorElementType() &&
18554            "Concat and extract subvector do not change element type");
18555 
18556     unsigned ExtIdx = N->getConstantOperandVal(1);
18557     unsigned ExtNumElts = NVT.getVectorNumElements();
18558     assert(ExtIdx % ExtNumElts == 0 &&
18559            "Extract index is not a multiple of the input vector length.");
18560 
18561     unsigned ConcatSrcNumElts = ConcatSrcVT.getVectorNumElements();
18562     unsigned ConcatOpIdx = ExtIdx / ConcatSrcNumElts;
18563 
18564     // If the concatenated source types match this extract, it's a direct
18565     // simplification:
18566     // extract_subvec (concat V1, V2, ...), i --> Vi
18567     if (ConcatSrcNumElts == ExtNumElts)
18568       return V.getOperand(ConcatOpIdx);
18569 
18570     // If the concatenated source vectors are a multiple length of this extract,
18571     // then extract a fraction of one of those source vectors directly from a
18572     // concat operand. Example:
18573     //   v2i8 extract_subvec (v16i8 concat (v8i8 X), (v8i8 Y), 14 -->
18574     //   v2i8 extract_subvec v8i8 Y, 6
18575     if (ConcatSrcNumElts % ExtNumElts == 0) {
18576       SDLoc DL(N);
18577       unsigned NewExtIdx = ExtIdx - ConcatOpIdx * ConcatSrcNumElts;
18578       assert(NewExtIdx + ExtNumElts <= ConcatSrcNumElts &&
18579              "Trying to extract from >1 concat operand?");
18580       assert(NewExtIdx % ExtNumElts == 0 &&
18581              "Extract index is not a multiple of the input vector length.");
18582       MVT IdxTy = TLI.getVectorIdxTy(DAG.getDataLayout());
18583       SDValue NewIndexC = DAG.getConstant(NewExtIdx, DL, IdxTy);
18584       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NVT,
18585                          V.getOperand(ConcatOpIdx), NewIndexC);
18586     }
18587   }
18588 
18589   V = peekThroughBitcasts(V);
18590 
18591   // If the input is a build vector, try to make a smaller build vector.
18592   if (V.getOpcode() == ISD::BUILD_VECTOR) {
18593     if (auto *IdxC = dyn_cast<ConstantSDNode>(Index)) {
18594       EVT InVT = V.getValueType();
18595       unsigned ExtractSize = NVT.getSizeInBits();
18596       unsigned EltSize = InVT.getScalarSizeInBits();
18597       // Only do this if we won't split any elements.
18598       if (ExtractSize % EltSize == 0) {
18599         unsigned NumElems = ExtractSize / EltSize;
18600         EVT EltVT = InVT.getVectorElementType();
18601         EVT ExtractVT = NumElems == 1 ? EltVT
18602                                       : EVT::getVectorVT(*DAG.getContext(),
18603                                                          EltVT, NumElems);
18604         if ((Level < AfterLegalizeDAG ||
18605              (NumElems == 1 ||
18606               TLI.isOperationLegal(ISD::BUILD_VECTOR, ExtractVT))) &&
18607             (!LegalTypes || TLI.isTypeLegal(ExtractVT))) {
18608           unsigned IdxVal = IdxC->getZExtValue();
18609           IdxVal *= NVT.getScalarSizeInBits();
18610           IdxVal /= EltSize;
18611 
18612           if (NumElems == 1) {
18613             SDValue Src = V->getOperand(IdxVal);
18614             if (EltVT != Src.getValueType())
18615               Src = DAG.getNode(ISD::TRUNCATE, SDLoc(N), EltVT, Src);
18616             return DAG.getBitcast(NVT, Src);
18617           }
18618 
18619           // Extract the pieces from the original build_vector.
18620           SDValue BuildVec = DAG.getBuildVector(
18621               ExtractVT, SDLoc(N), V->ops().slice(IdxVal, NumElems));
18622           return DAG.getBitcast(NVT, BuildVec);
18623         }
18624       }
18625     }
18626   }
18627 
18628   if (V.getOpcode() == ISD::INSERT_SUBVECTOR) {
18629     // Handle only the simple case where the vector being inserted and the
18630     // vector being extracted are of the same size.
18631     EVT SmallVT = V.getOperand(1).getValueType();
18632     if (!NVT.bitsEq(SmallVT))
18633       return SDValue();
18634 
18635     // Only handle cases where both indexes are constants.
18636     auto *ExtIdx = dyn_cast<ConstantSDNode>(Index);
18637     auto *InsIdx = dyn_cast<ConstantSDNode>(V.getOperand(2));
18638     if (InsIdx && ExtIdx) {
18639       // Combine:
18640       //    (extract_subvec (insert_subvec V1, V2, InsIdx), ExtIdx)
18641       // Into:
18642       //    indices are equal or bit offsets are equal => V1
18643       //    otherwise => (extract_subvec V1, ExtIdx)
18644       if (InsIdx->getZExtValue() * SmallVT.getScalarSizeInBits() ==
18645           ExtIdx->getZExtValue() * NVT.getScalarSizeInBits())
18646         return DAG.getBitcast(NVT, V.getOperand(1));
18647       return DAG.getNode(
18648           ISD::EXTRACT_SUBVECTOR, SDLoc(N), NVT,
18649           DAG.getBitcast(N->getOperand(0).getValueType(), V.getOperand(0)),
18650           Index);
18651     }
18652   }
18653 
18654   if (SDValue NarrowBOp = narrowExtractedVectorBinOp(N, DAG))
18655     return NarrowBOp;
18656 
18657   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
18658     return SDValue(N, 0);
18659 
18660   return SDValue();
18661 }
18662 
18663 /// Try to convert a wide shuffle of concatenated vectors into 2 narrow shuffles
18664 /// followed by concatenation. Narrow vector ops may have better performance
18665 /// than wide ops, and this can unlock further narrowing of other vector ops.
18666 /// Targets can invert this transform later if it is not profitable.
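/// For example (illustrative), with X and Y of type v4i32:
///   v8i32 shuffle<0,1,8,9,2,3,10,11> (concat X, undef), (concat Y, undef)
///     --> concat (shuffle<0,1,4,5> X, Y), (shuffle<2,3,6,7> X, Y)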
18667 static SDValue foldShuffleOfConcatUndefs(ShuffleVectorSDNode *Shuf,
18668                                          SelectionDAG &DAG) {
18669   SDValue N0 = Shuf->getOperand(0), N1 = Shuf->getOperand(1);
18670   if (N0.getOpcode() != ISD::CONCAT_VECTORS || N0.getNumOperands() != 2 ||
18671       N1.getOpcode() != ISD::CONCAT_VECTORS || N1.getNumOperands() != 2 ||
18672       !N0.getOperand(1).isUndef() || !N1.getOperand(1).isUndef())
18673     return SDValue();
18674 
18675   // Split the wide shuffle mask into halves. Any mask element that is accessing
18676   // operand 1 is offset down to account for narrowing of the vectors.
18677   ArrayRef<int> Mask = Shuf->getMask();
18678   EVT VT = Shuf->getValueType(0);
18679   unsigned NumElts = VT.getVectorNumElements();
18680   unsigned HalfNumElts = NumElts / 2;
18681   SmallVector<int, 16> Mask0(HalfNumElts, -1);
18682   SmallVector<int, 16> Mask1(HalfNumElts, -1);
18683   for (unsigned i = 0; i != NumElts; ++i) {
18684     if (Mask[i] == -1)
18685       continue;
18686     int M = Mask[i] < (int)NumElts ? Mask[i] : Mask[i] - (int)HalfNumElts;
18687     if (i < HalfNumElts)
18688       Mask0[i] = M;
18689     else
18690       Mask1[i - HalfNumElts] = M;
18691   }
18692 
18693   // Ask the target if this is a valid transform.
18694   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
18695   EVT HalfVT = EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(),
18696                                 HalfNumElts);
18697   if (!TLI.isShuffleMaskLegal(Mask0, HalfVT) ||
18698       !TLI.isShuffleMaskLegal(Mask1, HalfVT))
18699     return SDValue();
18700 
18701   // shuffle (concat X, undef), (concat Y, undef), Mask -->
18702   // concat (shuffle X, Y, Mask0), (shuffle X, Y, Mask1)
18703   SDValue X = N0.getOperand(0), Y = N1.getOperand(0);
18704   SDLoc DL(Shuf);
18705   SDValue Shuf0 = DAG.getVectorShuffle(HalfVT, DL, X, Y, Mask0);
18706   SDValue Shuf1 = DAG.getVectorShuffle(HalfVT, DL, X, Y, Mask1);
18707   return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Shuf0, Shuf1);
18708 }
18709 
18710 // Tries to turn a shuffle of two CONCAT_VECTORS into a single concat, or to
18711 // turn a shuffle of a single concat into a simpler shuffle followed by a concat.
18712 static SDValue partitionShuffleOfConcats(SDNode *N, SelectionDAG &DAG) {
18713   EVT VT = N->getValueType(0);
18714   unsigned NumElts = VT.getVectorNumElements();
18715 
18716   SDValue N0 = N->getOperand(0);
18717   SDValue N1 = N->getOperand(1);
18718   ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N);
18719   ArrayRef<int> Mask = SVN->getMask();
18720 
18721   SmallVector<SDValue, 4> Ops;
18722   EVT ConcatVT = N0.getOperand(0).getValueType();
18723   unsigned NumElemsPerConcat = ConcatVT.getVectorNumElements();
18724   unsigned NumConcats = NumElts / NumElemsPerConcat;
18725 
18726   auto IsUndefMaskElt = [](int i) { return i == -1; };
18727 
18728   // Special case: shuffle(concat(A,B)) can be more efficiently represented
18729   // as concat(shuffle(A,B),UNDEF) if the shuffle doesn't set any of the high
18730   // half vector elements.
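  // For example (illustrative), with A and B of type v4i32:
  //   v8i32 shuffle<1,0,7,6,u,u,u,u> (concat A, B), undef
  //     --> concat (v4i32 shuffle<1,0,7,6> A, B), (v4i32 undef)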
18731   if (NumElemsPerConcat * 2 == NumElts && N1.isUndef() &&
18732       llvm::all_of(Mask.slice(NumElemsPerConcat, NumElemsPerConcat),
18733                    IsUndefMaskElt)) {
18734     N0 = DAG.getVectorShuffle(ConcatVT, SDLoc(N), N0.getOperand(0),
18735                               N0.getOperand(1),
18736                               Mask.slice(0, NumElemsPerConcat));
18737     N1 = DAG.getUNDEF(ConcatVT);
18738     return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, N0, N1);
18739   }
18740 
18741   // Look at every vector that's inserted. We're looking for exact
18742   // subvector-sized copies from a concatenated vector.
18743   for (unsigned I = 0; I != NumConcats; ++I) {
18744     unsigned Begin = I * NumElemsPerConcat;
18745     ArrayRef<int> SubMask = Mask.slice(Begin, NumElemsPerConcat);
18746 
18747     // Make sure we're dealing with a copy.
18748     if (llvm::all_of(SubMask, IsUndefMaskElt)) {
18749       Ops.push_back(DAG.getUNDEF(ConcatVT));
18750       continue;
18751     }
18752 
18753     int OpIdx = -1;
18754     for (int i = 0; i != (int)NumElemsPerConcat; ++i) {
18755       if (IsUndefMaskElt(SubMask[i]))
18756         continue;
18757       if ((SubMask[i] % (int)NumElemsPerConcat) != i)
18758         return SDValue();
18759       int EltOpIdx = SubMask[i] / NumElemsPerConcat;
18760       if (0 <= OpIdx && EltOpIdx != OpIdx)
18761         return SDValue();
18762       OpIdx = EltOpIdx;
18763     }
18764     assert(0 <= OpIdx && "Unknown concat_vectors op");
18765 
18766     if (OpIdx < (int)N0.getNumOperands())
18767       Ops.push_back(N0.getOperand(OpIdx));
18768     else
18769       Ops.push_back(N1.getOperand(OpIdx - N0.getNumOperands()));
18770   }
18771 
18772   return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
18773 }
18774 
18775 // Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
18776 // BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
18777 //
18778 // SHUFFLE(BUILD_VECTOR(), BUILD_VECTOR()) -> BUILD_VECTOR() is always
18779 // a simplification in some sense, but it isn't appropriate in general: some
18780 // BUILD_VECTORs are substantially cheaper than others. The general case
18781 // of a BUILD_VECTOR requires inserting each element individually (or
18782 // performing the equivalent in a temporary stack variable). A BUILD_VECTOR of
18783 // all constants is a single constant pool load.  A BUILD_VECTOR where each
18784 // element is identical is a splat.  A BUILD_VECTOR where most of the operands
18785 // are undef lowers to a small number of element insertions.
18786 //
18787 // To deal with this, we currently use a bunch of mostly arbitrary heuristics.
18788 // We don't fold shuffles where one side is a non-zero constant, and we don't
18789 // fold shuffles if the resulting (non-splat) BUILD_VECTOR would have duplicate
18790 // non-constant operands. This seems to work out reasonably well in practice.
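// For example (illustrative):
//   shuffle<0,4,1,5> (build_vector A, B, ...), (build_vector C, D, ...)
//     --> build_vector A, C, B, D
// whereas a shuffle that would duplicate the non-constant operand A into
// several lanes of a non-splat result is rejected.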
18791 static SDValue combineShuffleOfScalars(ShuffleVectorSDNode *SVN,
18792                                        SelectionDAG &DAG,
18793                                        const TargetLowering &TLI) {
18794   EVT VT = SVN->getValueType(0);
18795   unsigned NumElts = VT.getVectorNumElements();
18796   SDValue N0 = SVN->getOperand(0);
18797   SDValue N1 = SVN->getOperand(1);
18798 
18799   if (!N0->hasOneUse())
18800     return SDValue();
18801 
18802   // If only one of N0,N1 is constant, bail out if it is not ALL_ZEROS as
18803   // discussed above.
18804   if (!N1.isUndef()) {
18805     if (!N1->hasOneUse())
18806       return SDValue();
18807 
18808     bool N0AnyConst = isAnyConstantBuildVector(N0);
18809     bool N1AnyConst = isAnyConstantBuildVector(N1);
18810     if (N0AnyConst && !N1AnyConst && !ISD::isBuildVectorAllZeros(N0.getNode()))
18811       return SDValue();
18812     if (!N0AnyConst && N1AnyConst && !ISD::isBuildVectorAllZeros(N1.getNode()))
18813       return SDValue();
18814   }
18815 
18816   // If both inputs are splats of the same value then we can safely merge this
18817   // to a single BUILD_VECTOR with undef elements based on the shuffle mask.
18818   bool IsSplat = false;
18819   auto *BV0 = dyn_cast<BuildVectorSDNode>(N0);
18820   auto *BV1 = dyn_cast<BuildVectorSDNode>(N1);
18821   if (BV0 && BV1)
18822     if (SDValue Splat0 = BV0->getSplatValue())
18823       IsSplat = (Splat0 == BV1->getSplatValue());
18824 
18825   SmallVector<SDValue, 8> Ops;
18826   SmallSet<SDValue, 16> DuplicateOps;
18827   for (int M : SVN->getMask()) {
18828     SDValue Op = DAG.getUNDEF(VT.getScalarType());
18829     if (M >= 0) {
18830       int Idx = M < (int)NumElts ? M : M - NumElts;
18831       SDValue &S = (M < (int)NumElts ? N0 : N1);
18832       if (S.getOpcode() == ISD::BUILD_VECTOR) {
18833         Op = S.getOperand(Idx);
18834       } else if (S.getOpcode() == ISD::SCALAR_TO_VECTOR) {
18835         SDValue Op0 = S.getOperand(0);
18836         Op = Idx == 0 ? Op0 : DAG.getUNDEF(Op0.getValueType());
18837       } else {
18838         // Operand can't be combined - bail out.
18839         return SDValue();
18840       }
18841     }
18842 
18843     // Don't duplicate a non-constant BUILD_VECTOR operand unless we're
18844     // generating a splat; semantically, this is fine, but it's likely to
18845     // generate low-quality code if the target can't reconstruct an appropriate
18846     // shuffle.
18847     if (!Op.isUndef() && !isa<ConstantSDNode>(Op) && !isa<ConstantFPSDNode>(Op))
18848       if (!IsSplat && !DuplicateOps.insert(Op).second)
18849         return SDValue();
18850 
18851     Ops.push_back(Op);
18852   }
18853 
18854   // BUILD_VECTOR requires all inputs to be of the same type; find the
18855   // maximum type and extend them all.
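  // e.g. (illustrative) Ops = {i8 a, i16 b} gives SVT = i16, so 'a' is
  // widened to i16 below, via zext when the target reports zext as free
  // and sext otherwise.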
18856   EVT SVT = VT.getScalarType();
18857   if (SVT.isInteger())
18858     for (SDValue &Op : Ops)
18859       SVT = (SVT.bitsLT(Op.getValueType()) ? Op.getValueType() : SVT);
18860   if (SVT != VT.getScalarType())
18861     for (SDValue &Op : Ops)
18862       Op = TLI.isZExtFree(Op.getValueType(), SVT)
18863                ? DAG.getZExtOrTrunc(Op, SDLoc(SVN), SVT)
18864                : DAG.getSExtOrTrunc(Op, SDLoc(SVN), SVT);
18865   return DAG.getBuildVector(VT, SDLoc(SVN), Ops);
18866 }
18867 
18868 // Match shuffles that can be converted to any_vector_extend_in_reg.
18869 // This is often generated during legalization.
18870 // e.g. v4i32 <0,u,1,u> -> (v2i64 any_vector_extend_in_reg(v4i32 src))
18871 // TODO Add support for ZERO_EXTEND_VECTOR_INREG when we have a test case.
18872 static SDValue combineShuffleToVectorExtend(ShuffleVectorSDNode *SVN,
18873                                             SelectionDAG &DAG,
18874                                             const TargetLowering &TLI,
18875                                             bool LegalOperations) {
18876   EVT VT = SVN->getValueType(0);
18877   bool IsBigEndian = DAG.getDataLayout().isBigEndian();
18878 
18879   // TODO Add support for big-endian when we have a test case.
18880   if (!VT.isInteger() || IsBigEndian)
18881     return SDValue();
18882 
18883   unsigned NumElts = VT.getVectorNumElements();
18884   unsigned EltSizeInBits = VT.getScalarSizeInBits();
18885   ArrayRef<int> Mask = SVN->getMask();
18886   SDValue N0 = SVN->getOperand(0);
18887 
18888   // shuffle<0,-1,1,-1> == (v2i64 anyextend_vector_inreg(v4i32))
18889   auto isAnyExtend = [&Mask, &NumElts](unsigned Scale) {
18890     for (unsigned i = 0; i != NumElts; ++i) {
18891       if (Mask[i] < 0)
18892         continue;
18893       if ((i % Scale) == 0 && Mask[i] == (int)(i / Scale))
18894         continue;
18895       return false;
18896     }
18897     return true;
18898   };
18899 
18900   // Attempt to match a '*_extend_vector_inreg' shuffle; we just search for
18901   // power-of-2 extensions as they are the most likely.
18902   for (unsigned Scale = 2; Scale < NumElts; Scale *= 2) {
18903     // Check for non-power-of-2 vector sizes.
18904     if (NumElts % Scale != 0)
18905       continue;
18906     if (!isAnyExtend(Scale))
18907       continue;
18908 
18909     EVT OutSVT = EVT::getIntegerVT(*DAG.getContext(), EltSizeInBits * Scale);
18910     EVT OutVT = EVT::getVectorVT(*DAG.getContext(), OutSVT, NumElts / Scale);
18911     // Never create an illegal type. Only create unsupported operations if we
18912     // are pre-legalization.
18913     if (TLI.isTypeLegal(OutVT))
18914       if (!LegalOperations ||
18915           TLI.isOperationLegalOrCustom(ISD::ANY_EXTEND_VECTOR_INREG, OutVT))
18916         return DAG.getBitcast(VT,
18917                               DAG.getNode(ISD::ANY_EXTEND_VECTOR_INREG,
18918                                           SDLoc(SVN), OutVT, N0));
18919   }
18920 
18921   return SDValue();
18922 }
18923 
18924 // Detect 'truncate_vector_inreg' style shuffles that pack the lower parts of
18925 // each source element of a large type into the lowest elements of a smaller
18926 // destination type. This is often generated during legalization.
18927 // If the source node itself was a '*_extend_vector_inreg' node then we
18928 // should be able to remove it.
18929 static SDValue combineTruncationShuffle(ShuffleVectorSDNode *SVN,
18930                                         SelectionDAG &DAG) {
18931   EVT VT = SVN->getValueType(0);
18932   bool IsBigEndian = DAG.getDataLayout().isBigEndian();
18933 
18934   // TODO Add support for big-endian when we have a test case.
18935   if (!VT.isInteger() || IsBigEndian)
18936     return SDValue();
18937 
18938   SDValue N0 = peekThroughBitcasts(SVN->getOperand(0));
18939 
18940   unsigned Opcode = N0.getOpcode();
18941   if (Opcode != ISD::ANY_EXTEND_VECTOR_INREG &&
18942       Opcode != ISD::SIGN_EXTEND_VECTOR_INREG &&
18943       Opcode != ISD::ZERO_EXTEND_VECTOR_INREG)
18944     return SDValue();
18945 
18946   SDValue N00 = N0.getOperand(0);
18947   ArrayRef<int> Mask = SVN->getMask();
18948   unsigned NumElts = VT.getVectorNumElements();
18949   unsigned EltSizeInBits = VT.getScalarSizeInBits();
18950   unsigned ExtSrcSizeInBits = N00.getScalarValueSizeInBits();
18951   unsigned ExtDstSizeInBits = N0.getScalarValueSizeInBits();
18952 
18953   if (ExtDstSizeInBits % ExtSrcSizeInBits != 0)
18954     return SDValue();
18955   unsigned ExtScale = ExtDstSizeInBits / ExtSrcSizeInBits;
18956 
18957   // (v4i32 truncate_vector_inreg(v2i64)) == shuffle<0,2,-1,-1>
18958   // (v8i16 truncate_vector_inreg(v4i32)) == shuffle<0,2,4,6,-1,-1,-1,-1>
18959   // (v8i16 truncate_vector_inreg(v2i64)) == shuffle<0,4,-1,-1,-1,-1,-1,-1>
18960   auto isTruncate = [&Mask, &NumElts](unsigned Scale) {
18961     for (unsigned i = 0; i != NumElts; ++i) {
18962       if (Mask[i] < 0)
18963         continue;
18964       if ((i * Scale) < NumElts && Mask[i] == (int)(i * Scale))
18965         continue;
18966       return false;
18967     }
18968     return true;
18969   };
18970 
18971   // At the moment we just handle the case where we've truncated back to the
18972   // same size as before the extension.
18973   // TODO: handle more extension/truncation cases as cases arise.
18974   if (EltSizeInBits != ExtSrcSizeInBits)
18975     return SDValue();
18976 
18977   // We can remove *extend_vector_inreg only if the truncation happens at
18978   // the same scale as the extension.
18979   if (isTruncate(ExtScale))
18980     return DAG.getBitcast(VT, N00);
18981 
18982   return SDValue();
18983 }
18984 
18985 // Combine shuffles of splat-shuffles of the form:
18986 // shuffle (shuffle V, undef, splat-mask), undef, M
18987 // If splat-mask contains undef elements, we need to be careful about
18988 // introducing undef's in the folded mask which are not the result of composing
18989 // the masks of the shuffles.
18990 static SDValue combineShuffleOfSplatVal(ShuffleVectorSDNode *Shuf,
18991                                         SelectionDAG &DAG) {
18992   if (!Shuf->getOperand(1).isUndef())
18993     return SDValue();
18994   auto *Splat = dyn_cast<ShuffleVectorSDNode>(Shuf->getOperand(0));
18995   if (!Splat || !Splat->isSplat())
18996     return SDValue();
18997 
18998   ArrayRef<int> ShufMask = Shuf->getMask();
18999   ArrayRef<int> SplatMask = Splat->getMask();
19000   assert(ShufMask.size() == SplatMask.size() && "Mask length mismatch");
19001 
19002   // Prefer simplifying to the splat-shuffle, if possible. This is legal if
19003   // every undef mask element in the splat-shuffle has a corresponding undef
19004   // element in the user-shuffle's mask or if the composition of mask elements
19005   // would result in undef.
19006   // Examples for (shuffle (shuffle v, undef, SplatMask), undef, UserMask):
19007   // * UserMask=[0,2,u,u], SplatMask=[2,u,2,u] -> [2,2,u,u]
19008   //   In this case it is not legal to simplify to the splat-shuffle because we
19009   //   may expose to the users of the shuffle an undef element at index 1
19010   //   which was not there before the combine.
19011   // * UserMask=[0,u,2,u], SplatMask=[2,u,2,u] -> [2,u,2,u]
19012   //   In this case the composition of masks yields SplatMask, so it's ok to
19013   //   simplify to the splat-shuffle.
19014   // * UserMask=[3,u,2,u], SplatMask=[2,u,2,u] -> [u,u,2,u]
19015   //   In this case the composed mask includes all undef elements of SplatMask
19016   //   and in addition sets element zero to undef. It is safe to simplify to
19017   //   the splat-shuffle.
19018   auto CanSimplifyToExistingSplat = [](ArrayRef<int> UserMask,
19019                                        ArrayRef<int> SplatMask) {
19020     for (unsigned i = 0, e = UserMask.size(); i != e; ++i)
19021       if (UserMask[i] != -1 && SplatMask[i] == -1 &&
19022           SplatMask[UserMask[i]] != -1)
19023         return false;
19024     return true;
19025   };
19026   if (CanSimplifyToExistingSplat(ShufMask, SplatMask))
19027     return Shuf->getOperand(0);
19028 
19029   // Create a new shuffle with a mask that is composed of the two shuffles'
19030   // masks.
19031   SmallVector<int, 32> NewMask;
19032   for (int Idx : ShufMask)
19033     NewMask.push_back(Idx == -1 ? -1 : SplatMask[Idx]);
19034 
19035   return DAG.getVectorShuffle(Splat->getValueType(0), SDLoc(Splat),
19036                               Splat->getOperand(0), Splat->getOperand(1),
19037                               NewMask);
19038 }
19039 
19040 /// If the shuffle mask is taking exactly one element from the first vector
19041 /// operand and passing through all other elements from the second vector
19042 /// operand, return the index of the mask element that is choosing an element
19043 /// from the first operand. Otherwise, return -1.
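/// Illustrative example: for 4-element operands, Mask = <4,1,6,7> takes only
/// element 1 from operand 0 (Mask[1] == 1) and passes elements 0, 2 and 3 of
/// operand 1 through in place, so this returns 1.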
19044 static int getShuffleMaskIndexOfOneElementFromOp0IntoOp1(ArrayRef<int> Mask) {
19045   int MaskSize = Mask.size();
19046   int EltFromOp0 = -1;
19047   // TODO: This does not match if there are undef elements in the shuffle mask.
19048   // Should we ignore undefs in the shuffle mask instead? The trade-off is
19049   // removing an instruction (a shuffle), but losing the knowledge that some
19050   // vector lanes are not needed.
19051   for (int i = 0; i != MaskSize; ++i) {
19052     if (Mask[i] >= 0 && Mask[i] < MaskSize) {
19053       // We're looking for a shuffle of exactly one element from operand 0.
19054       if (EltFromOp0 != -1)
19055         return -1;
19056       EltFromOp0 = i;
19057     } else if (Mask[i] != i + MaskSize) {
19058       // Nothing from operand 1 can change lanes.
19059       return -1;
19060     }
19061   }
19062   return EltFromOp0;
19063 }
19064 
19065 /// If a shuffle inserts exactly one element from a source vector operand into
19066 /// another vector operand and we can access the specified element as a scalar,
19067 /// then we can eliminate the shuffle.
19068 static SDValue replaceShuffleOfInsert(ShuffleVectorSDNode *Shuf,
19069                                       SelectionDAG &DAG) {
19070   // First, check if we are taking one element of a vector and shuffling that
19071   // element into another vector.
19072   ArrayRef<int> Mask = Shuf->getMask();
19073   SmallVector<int, 16> CommutedMask(Mask.begin(), Mask.end());
19074   SDValue Op0 = Shuf->getOperand(0);
19075   SDValue Op1 = Shuf->getOperand(1);
19076   int ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(Mask);
19077   if (ShufOp0Index == -1) {
19078     // Commute mask and check again.
19079     ShuffleVectorSDNode::commuteMask(CommutedMask);
19080     ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(CommutedMask);
19081     if (ShufOp0Index == -1)
19082       return SDValue();
19083     // Commute operands to match the commuted shuffle mask.
19084     std::swap(Op0, Op1);
19085     Mask = CommutedMask;
19086   }
19087 
19088   // The shuffle inserts exactly one element from operand 0 into operand 1.
19089   // Now see if we can access that element as a scalar via a real insert element
19090   // instruction.
19091   // TODO: We can try harder to locate the element as a scalar. Examples: it
19092   // could be an operand of SCALAR_TO_VECTOR, BUILD_VECTOR, or a constant.
19093   assert(Mask[ShufOp0Index] >= 0 && Mask[ShufOp0Index] < (int)Mask.size() &&
19094          "Shuffle mask value must be from operand 0");
19095   if (Op0.getOpcode() != ISD::INSERT_VECTOR_ELT)
19096     return SDValue();
19097 
19098   auto *InsIndexC = dyn_cast<ConstantSDNode>(Op0.getOperand(2));
19099   if (!InsIndexC || InsIndexC->getSExtValue() != Mask[ShufOp0Index])
19100     return SDValue();
19101 
19102   // There's an existing insertelement with constant insertion index, so we
19103   // don't need to check the legality/profitability of a replacement operation
19104   // that differs at most in the constant value. The target should be able to
19105   // lower any of those in a similar way. If not, legalization will expand this
19106   // to a scalar-to-vector plus shuffle.
19107   //
19108   // Note that the shuffle may move the scalar from the position that the insert
19109   // element used. Therefore, our new insert element occurs at the shuffle's
19110   // mask index value, not the insert's index value.
19111   // shuffle (insertelt v1, x, C), v2, mask --> insertelt v2, x, C'
19112   SDValue NewInsIndex = DAG.getConstant(ShufOp0Index, SDLoc(Shuf),
19113                                         Op0.getOperand(2).getValueType());
19114   return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(Shuf), Op0.getValueType(),
19115                      Op1, Op0.getOperand(1), NewInsIndex);
19116 }
19117 
19118 /// If we have a unary shuffle of a shuffle, see if it can be folded away
19119 /// completely. This has the potential to lose undef knowledge because the first
19120 /// shuffle may not have an undef mask element where the second one does. So
19121 /// only call this after doing simplifications based on demanded elements.
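/// Illustrative example (X and Y are hypothetical operands):
///   shuf (shuf0 X, Y, <0,0,2,2>), undef, <1,0,3,2> --> shuf0 X, Y, <0,0,2,2>
/// because composing the outer mask with <0,0,2,2> reproduces <0,0,2,2>.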
19122 static SDValue simplifyShuffleOfShuffle(ShuffleVectorSDNode *Shuf) {
19123   // shuf (shuf0 X, Y, Mask0), undef, Mask
19124   auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(Shuf->getOperand(0));
19125   if (!Shuf0 || !Shuf->getOperand(1).isUndef())
19126     return SDValue();
19127 
19128   ArrayRef<int> Mask = Shuf->getMask();
19129   ArrayRef<int> Mask0 = Shuf0->getMask();
19130   for (int i = 0, e = (int)Mask.size(); i != e; ++i) {
19131     // Ignore undef elements.
19132     if (Mask[i] == -1)
19133       continue;
19134     assert(Mask[i] >= 0 && Mask[i] < e && "Unexpected shuffle mask value");
19135 
19136     // Is the element of the shuffle operand chosen by this shuffle the same as
19137     // the element chosen by the shuffle operand itself?
19138     if (Mask0[Mask[i]] != Mask0[i])
19139       return SDValue();
19140   }
19141   // Every element of this shuffle is identical to the result of the previous
19142   // shuffle, so we can replace this value.
19143   return Shuf->getOperand(0);
19144 }
19145 
19146 SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
19147   EVT VT = N->getValueType(0);
19148   unsigned NumElts = VT.getVectorNumElements();
19149 
19150   SDValue N0 = N->getOperand(0);
19151   SDValue N1 = N->getOperand(1);
19152 
19153   assert(N0.getValueType() == VT && "Vector shuffle must be normalized in DAG");
19154 
19155   // Canonicalize shuffle undef, undef -> undef
19156   if (N0.isUndef() && N1.isUndef())
19157     return DAG.getUNDEF(VT);
19158 
19159   ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N);
19160 
19161   // Canonicalize shuffle v, v -> v, undef
19162   if (N0 == N1) {
19163     SmallVector<int, 8> NewMask;
19164     for (unsigned i = 0; i != NumElts; ++i) {
19165       int Idx = SVN->getMaskElt(i);
19166       if (Idx >= (int)NumElts) Idx -= NumElts;
19167       NewMask.push_back(Idx);
19168     }
19169     return DAG.getVectorShuffle(VT, SDLoc(N), N0, DAG.getUNDEF(VT), NewMask);
19170   }
19171 
19172   // Canonicalize shuffle undef, v -> v, undef.  Commute the shuffle mask.
19173   if (N0.isUndef())
19174     return DAG.getCommutedVectorShuffle(*SVN);
19175 
19176   // Remove references to rhs if it is undef
19177   if (N1.isUndef()) {
19178     bool Changed = false;
19179     SmallVector<int, 8> NewMask;
19180     for (unsigned i = 0; i != NumElts; ++i) {
19181       int Idx = SVN->getMaskElt(i);
19182       if (Idx >= (int)NumElts) {
19183         Idx = -1;
19184         Changed = true;
19185       }
19186       NewMask.push_back(Idx);
19187     }
19188     if (Changed)
19189       return DAG.getVectorShuffle(VT, SDLoc(N), N0, N1, NewMask);
19190   }
19191 
19192   if (SDValue InsElt = replaceShuffleOfInsert(SVN, DAG))
19193     return InsElt;
19194 
19195   // A shuffle of a single vector that is a splatted value can always be folded.
19196   if (SDValue V = combineShuffleOfSplatVal(SVN, DAG))
19197     return V;
19198 
19199   // If it is a splat, check if the argument vector is another splat or a
19200   // build_vector.
19201   if (SVN->isSplat() && SVN->getSplatIndex() < (int)NumElts) {
19202     int SplatIndex = SVN->getSplatIndex();
19203     if (N0.hasOneUse() && TLI.isExtractVecEltCheap(VT, SplatIndex) &&
19204         TLI.isBinOp(N0.getOpcode()) && N0.getNode()->getNumValues() == 1) {
19205       // splat (vector_bo L, R), Index -->
19206       // splat (scalar_bo (extelt L, Index), (extelt R, Index))
19207       SDValue L = N0.getOperand(0), R = N0.getOperand(1);
19208       SDLoc DL(N);
19209       EVT EltVT = VT.getScalarType();
19210       SDValue Index = DAG.getIntPtrConstant(SplatIndex, DL);
19211       SDValue ExtL = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, L, Index);
19212       SDValue ExtR = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, R, Index);
19213       SDValue NewBO = DAG.getNode(N0.getOpcode(), DL, EltVT, ExtL, ExtR,
19214                                   N0.getNode()->getFlags());
19215       SDValue Insert = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT, NewBO);
19216       SmallVector<int, 16> ZeroMask(VT.getVectorNumElements(), 0);
19217       return DAG.getVectorShuffle(VT, DL, Insert, DAG.getUNDEF(VT), ZeroMask);
19218     }
19219 
19220     // If this is a bit convert that changes the element type of the vector but
19221     // not the number of vector elements, look through it.  Be careful not to
19222     // look through conversions that change things like v4f32 to v2f64.
19223     SDNode *V = N0.getNode();
19224     if (V->getOpcode() == ISD::BITCAST) {
19225       SDValue ConvInput = V->getOperand(0);
19226       if (ConvInput.getValueType().isVector() &&
19227           ConvInput.getValueType().getVectorNumElements() == NumElts)
19228         V = ConvInput.getNode();
19229     }
19230 
19231     if (V->getOpcode() == ISD::BUILD_VECTOR) {
19232       assert(V->getNumOperands() == NumElts &&
19233              "BUILD_VECTOR has wrong number of operands");
19234       SDValue Base;
19235       bool AllSame = true;
19236       for (unsigned i = 0; i != NumElts; ++i) {
19237         if (!V->getOperand(i).isUndef()) {
19238           Base = V->getOperand(i);
19239           break;
19240         }
19241       }
19242       // Splat of <u, u, u, u>, return <u, u, u, u>
19243       if (!Base.getNode())
19244         return N0;
19245       for (unsigned i = 0; i != NumElts; ++i) {
19246         if (V->getOperand(i) != Base) {
19247           AllSame = false;
19248           break;
19249         }
19250       }
19251       // Splat of <x, x, x, x>, return <x, x, x, x>
19252       if (AllSame)
19253         return N0;
19254 
19255       // Canonicalize any other splat as a build_vector.
19256       SDValue Splatted = V->getOperand(SplatIndex);
19257       SmallVector<SDValue, 8> Ops(NumElts, Splatted);
19258       SDValue NewBV = DAG.getBuildVector(V->getValueType(0), SDLoc(N), Ops);
19259 
19260       // We may have jumped through bitcasts, so the type of the
19261       // BUILD_VECTOR may not match the type of the shuffle.
19262       if (V->getValueType(0) != VT)
19263         NewBV = DAG.getBitcast(VT, NewBV);
19264       return NewBV;
19265     }
19266   }
19267 
19268   // Simplify source operands based on shuffle mask.
19269   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
19270     return SDValue(N, 0);
19271 
19272   // This is intentionally placed after demanded elements simplification because
19273   // it could eliminate knowledge of undef elements created by this shuffle.
19274   if (SDValue ShufOp = simplifyShuffleOfShuffle(SVN))
19275     return ShufOp;
19276 
19277   // Match shuffles that can be converted to any_vector_extend_in_reg.
19278   if (SDValue V = combineShuffleToVectorExtend(SVN, DAG, TLI, LegalOperations))
19279     return V;
19280 
19281   // Combine "truncate_vector_in_reg" style shuffles.
19282   if (SDValue V = combineTruncationShuffle(SVN, DAG))
19283     return V;
19284 
19285   if (N0.getOpcode() == ISD::CONCAT_VECTORS &&
19286       Level < AfterLegalizeVectorOps &&
19287       (N1.isUndef() ||
19288       (N1.getOpcode() == ISD::CONCAT_VECTORS &&
19289        N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType()))) {
19290     if (SDValue V = partitionShuffleOfConcats(N, DAG))
19291       return V;
19292   }
19293 
19294   // A shuffle of a concat of the same narrow vector can be reduced to use
19295   // only low-half elements of a concat with undef:
19296   // shuf (concat X, X), undef, Mask --> shuf (concat X, undef), undef, Mask'
19297   if (N0.getOpcode() == ISD::CONCAT_VECTORS && N1.isUndef() &&
19298       N0.getNumOperands() == 2 &&
19299       N0.getOperand(0) == N0.getOperand(1)) {
19300     int HalfNumElts = (int)NumElts / 2;
19301     SmallVector<int, 8> NewMask;
19302     for (unsigned i = 0; i != NumElts; ++i) {
19303       int Idx = SVN->getMaskElt(i);
19304       if (Idx >= HalfNumElts) {
19305         assert(Idx < (int)NumElts && "Shuffle mask chooses undef op");
19306         Idx -= HalfNumElts;
19307       }
19308       NewMask.push_back(Idx);
19309     }
19310     if (TLI.isShuffleMaskLegal(NewMask, VT)) {
19311       SDValue UndefVec = DAG.getUNDEF(N0.getOperand(0).getValueType());
19312       SDValue NewCat = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT,
19313                                    N0.getOperand(0), UndefVec);
19314       return DAG.getVectorShuffle(VT, SDLoc(N), NewCat, N1, NewMask);
19315     }
19316   }
19317 
19318   // Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
19319   // BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
19320   if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT))
19321     if (SDValue Res = combineShuffleOfScalars(SVN, DAG, TLI))
19322       return Res;
19323 
19324   // If this shuffle only has a single input that is a bitcasted shuffle,
19325   // attempt to merge the 2 shuffles and suitably bitcast the inputs/output
19326   // back to their original types.
19327   if (N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
19328       N1.isUndef() && Level < AfterLegalizeVectorOps &&
19329       TLI.isTypeLegal(VT)) {
19330     auto ScaleShuffleMask = [](ArrayRef<int> Mask, int Scale) {
19331       if (Scale == 1)
19332         return SmallVector<int, 8>(Mask.begin(), Mask.end());
19333 
19334       SmallVector<int, 8> NewMask;
19335       for (int M : Mask)
19336         for (int s = 0; s != Scale; ++s)
19337           NewMask.push_back(M < 0 ? -1 : Scale * M + s);
19338       return NewMask;
19339     };
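    // e.g. (illustrative) ScaleShuffleMask(<1,-1>, 2) -> <2,3,-1,-1>: each
    // wide index M expands to the Scale narrow indices Scale*M..Scale*M+Scale-1.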
19340 
19341     SDValue BC0 = peekThroughOneUseBitcasts(N0);
19342     if (BC0.getOpcode() == ISD::VECTOR_SHUFFLE && BC0.hasOneUse()) {
19343       EVT SVT = VT.getScalarType();
19344       EVT InnerVT = BC0->getValueType(0);
19345       EVT InnerSVT = InnerVT.getScalarType();
19346 
19347       // Determine which shuffle works with the smaller scalar type.
19348       EVT ScaleVT = SVT.bitsLT(InnerSVT) ? VT : InnerVT;
19349       EVT ScaleSVT = ScaleVT.getScalarType();
19350 
19351       if (TLI.isTypeLegal(ScaleVT) &&
19352           0 == (InnerSVT.getSizeInBits() % ScaleSVT.getSizeInBits()) &&
19353           0 == (SVT.getSizeInBits() % ScaleSVT.getSizeInBits())) {
19354         int InnerScale = InnerSVT.getSizeInBits() / ScaleSVT.getSizeInBits();
19355         int OuterScale = SVT.getSizeInBits() / ScaleSVT.getSizeInBits();
19356 
19357         // Scale the shuffle masks to the smaller scalar type.
19358         ShuffleVectorSDNode *InnerSVN = cast<ShuffleVectorSDNode>(BC0);
19359         SmallVector<int, 8> InnerMask =
19360             ScaleShuffleMask(InnerSVN->getMask(), InnerScale);
19361         SmallVector<int, 8> OuterMask =
19362             ScaleShuffleMask(SVN->getMask(), OuterScale);
19363 
19364         // Merge the shuffle masks.
19365         SmallVector<int, 8> NewMask;
19366         for (int M : OuterMask)
19367           NewMask.push_back(M < 0 ? -1 : InnerMask[M]);
19368 
19369         // Test for shuffle mask legality over both commutations.
19370         SDValue SV0 = BC0->getOperand(0);
19371         SDValue SV1 = BC0->getOperand(1);
19372         bool LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
19373         if (!LegalMask) {
19374           std::swap(SV0, SV1);
19375           ShuffleVectorSDNode::commuteMask(NewMask);
19376           LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
19377         }
19378 
19379         if (LegalMask) {
19380           SV0 = DAG.getBitcast(ScaleVT, SV0);
19381           SV1 = DAG.getBitcast(ScaleVT, SV1);
19382           return DAG.getBitcast(
19383               VT, DAG.getVectorShuffle(ScaleVT, SDLoc(N), SV0, SV1, NewMask));
19384         }
19385       }
19386     }
19387   }
19388 
19389   // Canonicalize shuffles according to rules:
19390   //  shuffle(A, shuffle(A, B)) -> shuffle(shuffle(A,B), A)
19391   //  shuffle(B, shuffle(A, B)) -> shuffle(shuffle(A,B), B)
19392   //  shuffle(B, shuffle(A, Undef)) -> shuffle(shuffle(A, Undef), B)
19393   if (N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
19394       N0.getOpcode() != ISD::VECTOR_SHUFFLE && Level < AfterLegalizeDAG &&
19395       TLI.isTypeLegal(VT)) {
19396     // The incoming shuffle must be of the same type as the result of the
19397     // current shuffle.
19398     assert(N1->getOperand(0).getValueType() == VT &&
19399            "Shuffle types don't match");
19400 
19401     SDValue SV0 = N1->getOperand(0);
19402     SDValue SV1 = N1->getOperand(1);
19403     bool HasSameOp0 = N0 == SV0;
19404     bool IsSV1Undef = SV1.isUndef();
19405     if (HasSameOp0 || IsSV1Undef || N0 == SV1)
19406       // Commute the operands of this shuffle so that the next rule
19407       // will trigger.
19408       return DAG.getCommutedVectorShuffle(*SVN);
19409   }
19410 
19411   // Try to fold according to rules:
19412   //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
19413   //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
19414   //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
19415   // Don't try to fold shuffles with illegal type.
19416   // Only fold if this shuffle is the only user of the other shuffle.
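  // Worked example (illustrative; A, B, C are hypothetical 4-element
  // operands): shuffle(shuffle(A, B, <0,5,1,4>), C, <0,2,4,6>) combines to
  // shuffle(A, C, <0,1,4,6>); B ends up unreferenced, matching the second
  // rule above.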
19417   if (N0.getOpcode() == ISD::VECTOR_SHUFFLE && N->isOnlyUserOf(N0.getNode()) &&
19418       Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
19419     ShuffleVectorSDNode *OtherSV = cast<ShuffleVectorSDNode>(N0);
19420 
19421     // Don't try to fold splats; they're likely to simplify somehow, or they
19422     // might be free.
19423     if (OtherSV->isSplat())
19424       return SDValue();
19425 
19426     // The incoming shuffle must be of the same type as the result of the
19427     // current shuffle.
19428     assert(OtherSV->getOperand(0).getValueType() == VT &&
19429            "Shuffle types don't match");
19430 
19431     SDValue SV0, SV1;
19432     SmallVector<int, 4> Mask;
19433     // Compute the combined shuffle mask for a shuffle with SV0 as the first
19434     // operand, and SV1 as the second operand.
19435     for (unsigned i = 0; i != NumElts; ++i) {
19436       int Idx = SVN->getMaskElt(i);
19437       if (Idx < 0) {
19438         // Propagate Undef.
19439         Mask.push_back(Idx);
19440         continue;
19441       }
19442 
19443       SDValue CurrentVec;
19444       if (Idx < (int)NumElts) {
19445         // This shuffle index refers to the inner shuffle N0. Lookup the inner
19446         // shuffle mask to identify which vector is actually referenced.
19447         Idx = OtherSV->getMaskElt(Idx);
19448         if (Idx < 0) {
19449           // Propagate Undef.
19450           Mask.push_back(Idx);
19451           continue;
19452         }
19453 
19454         CurrentVec = (Idx < (int) NumElts) ? OtherSV->getOperand(0)
19455                                            : OtherSV->getOperand(1);
19456       } else {
19457         // This shuffle index references an element within N1.
19458         CurrentVec = N1;
19459       }
19460 
19461       // Simple case where 'CurrentVec' is UNDEF.
19462       if (CurrentVec.isUndef()) {
19463         Mask.push_back(-1);
19464         continue;
19465       }
19466 
19467       // Canonicalize the shuffle index. We don't know yet if CurrentVec
19468       // will be the first or second operand of the combined shuffle.
19469       Idx = Idx % NumElts;
19470       if (!SV0.getNode() || SV0 == CurrentVec) {
19471         // Ok. CurrentVec is the left hand side.
19472         // Update the mask accordingly.
19473         SV0 = CurrentVec;
19474         Mask.push_back(Idx);
19475         continue;
19476       }
19477 
19478       // Bail out if we cannot convert the shuffle pair into a single shuffle.
19479       if (SV1.getNode() && SV1 != CurrentVec)
19480         return SDValue();
19481 
19482       // Ok. CurrentVec is the right hand side.
19483       // Update the mask accordingly.
19484       SV1 = CurrentVec;
19485       Mask.push_back(Idx + NumElts);
19486     }
19487 
19488     // Check if all indices in Mask are Undef. If so, propagate Undef.
19489     bool isUndefMask = true;
19490     for (unsigned i = 0; i != NumElts && isUndefMask; ++i)
19491       isUndefMask &= Mask[i] < 0;
19492 
19493     if (isUndefMask)
19494       return DAG.getUNDEF(VT);
19495 
19496     if (!SV0.getNode())
19497       SV0 = DAG.getUNDEF(VT);
19498     if (!SV1.getNode())
19499       SV1 = DAG.getUNDEF(VT);
19500 
19501     // Avoid introducing shuffles with illegal mask.
19502     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
19503     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
19504     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
19505     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, A, M2)
19506     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, A, M2)
19507     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, B, M2)
19508     return TLI.buildLegalVectorShuffle(VT, SDLoc(N), SV0, SV1, Mask, DAG);
19509   }
19510 
19511   if (SDValue V = foldShuffleOfConcatUndefs(SVN, DAG))
19512     return V;
19513 
19514   return SDValue();
19515 }
19516 
19517 SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) {
19518   SDValue InVal = N->getOperand(0);
19519   EVT VT = N->getValueType(0);
19520 
19521   // Replace a SCALAR_TO_VECTOR(EXTRACT_VECTOR_ELT(V,C0)) pattern
19522   // with a VECTOR_SHUFFLE and possible truncate.
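  // e.g. (illustrative; V is a hypothetical v4i32 vector)
  //   scalar_to_vector(extract_vector_elt(V, 2)) -> shuffle<2,u,u,u>(V, undef)
  // when the scalar and vector element types line up, possibly followed by an
  // extract_subvector if the result type is narrower than V's type.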
19523   if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
19524     SDValue InVec = InVal->getOperand(0);
19525     SDValue EltNo = InVal->getOperand(1);
19526     auto InVecT = InVec.getValueType();
19527     if (ConstantSDNode *C0 = dyn_cast<ConstantSDNode>(EltNo)) {
19528       SmallVector<int, 8> NewMask(InVecT.getVectorNumElements(), -1);
19529       int Elt = C0->getZExtValue();
19530       NewMask[0] = Elt;
19531       // If we have an implicit truncate, do the truncate here as long as it's
19532       // legal; if it's not, fall through and try to form a shuffle below.
19533       if (VT.getScalarType() != InVal.getValueType() &&
19534           InVal.getValueType().isScalarInteger() &&
19535           isTypeLegal(VT.getScalarType())) {
19536         SDValue Val =
19537             DAG.getNode(ISD::TRUNCATE, SDLoc(InVal), VT.getScalarType(), InVal);
19538         return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), VT, Val);
19539       }
19540       if (VT.getScalarType() == InVecT.getScalarType() &&
19541           VT.getVectorNumElements() <= InVecT.getVectorNumElements()) {
19542         SDValue LegalShuffle =
19543           TLI.buildLegalVectorShuffle(InVecT, SDLoc(N), InVec,
19544                                       DAG.getUNDEF(InVecT), NewMask, DAG);
19545         if (LegalShuffle) {
19546           // If the initial vector is the correct size this shuffle is a
19547           // valid result.
19548           if (VT == InVecT)
19549             return LegalShuffle;
19550           // If not, we must truncate the vector.
19551           if (VT.getVectorNumElements() != InVecT.getVectorNumElements()) {
19552             MVT IdxTy = TLI.getVectorIdxTy(DAG.getDataLayout());
19553             SDValue ZeroIdx = DAG.getConstant(0, SDLoc(N), IdxTy);
19554             EVT SubVT =
19555                 EVT::getVectorVT(*DAG.getContext(), InVecT.getVectorElementType(),
19556                                  VT.getVectorNumElements());
19557             return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), SubVT,
19558                                LegalShuffle, ZeroIdx);
19559           }
19560         }
19561       }
19562     }
19563   }
19564 
19565   return SDValue();
19566 }
19567 
19568 SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) {
19569   EVT VT = N->getValueType(0);
19570   SDValue N0 = N->getOperand(0);
19571   SDValue N1 = N->getOperand(1);
19572   SDValue N2 = N->getOperand(2);
19573 
19574   // If inserting an UNDEF, just return the original vector.
19575   if (N1.isUndef())
19576     return N0;
19577 
19578   // If this is an insert of an extracted vector into an undef vector, we can
19579   // just use the input to the extract.
19580   if (N0.isUndef() && N1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
19581       N1.getOperand(1) == N2 && N1.getOperand(0).getValueType() == VT)
19582     return N1.getOperand(0);
19583 
19584   // If we are inserting a bitcast value into an undef, with the same
19585   // number of elements, just use the bitcast input of the extract.
19586   // i.e. INSERT_SUBVECTOR UNDEF (BITCAST N1) N2 ->
19587   //        BITCAST (INSERT_SUBVECTOR UNDEF N1 N2)
19588   if (N0.isUndef() && N1.getOpcode() == ISD::BITCAST &&
19589       N1.getOperand(0).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
19590       N1.getOperand(0).getOperand(1) == N2 &&
19591       N1.getOperand(0).getOperand(0).getValueType().getVectorNumElements() ==
19592           VT.getVectorNumElements() &&
19593       N1.getOperand(0).getOperand(0).getValueType().getSizeInBits() ==
19594           VT.getSizeInBits()) {
19595     return DAG.getBitcast(VT, N1.getOperand(0).getOperand(0));
19596   }
19597 
19598   // If both N0 and N1 are bitcast values on which insert_subvector
19599   // would make sense, pull the bitcast through.
19600   // i.e. INSERT_SUBVECTOR (BITCAST N0) (BITCAST N1) N2 ->
19601   //        BITCAST (INSERT_SUBVECTOR N0 N1 N2)
19602   if (N0.getOpcode() == ISD::BITCAST && N1.getOpcode() == ISD::BITCAST) {
19603     SDValue CN0 = N0.getOperand(0);
19604     SDValue CN1 = N1.getOperand(0);
19605     EVT CN0VT = CN0.getValueType();
19606     EVT CN1VT = CN1.getValueType();
19607     if (CN0VT.isVector() && CN1VT.isVector() &&
19608         CN0VT.getVectorElementType() == CN1VT.getVectorElementType() &&
19609         CN0VT.getVectorNumElements() == VT.getVectorNumElements()) {
19610       SDValue NewINSERT = DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N),
19611                                       CN0.getValueType(), CN0, CN1, N2);
19612       return DAG.getBitcast(VT, NewINSERT);
19613     }
19614   }
19615 
19616   // Combine INSERT_SUBVECTORs where we are inserting to the same index.
19617   // INSERT_SUBVECTOR( INSERT_SUBVECTOR( Vec, SubOld, Idx ), SubNew, Idx )
19618   // --> INSERT_SUBVECTOR( Vec, SubNew, Idx )
19619   if (N0.getOpcode() == ISD::INSERT_SUBVECTOR &&
19620       N0.getOperand(1).getValueType() == N1.getValueType() &&
19621       N0.getOperand(2) == N2)
19622     return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0.getOperand(0),
19623                        N1, N2);
19624 
19625   // Eliminate an intermediate insert into an undef vector:
19626   // insert_subvector undef, (insert_subvector undef, X, 0), N2 -->
19627   // insert_subvector undef, X, N2
19628   if (N0.isUndef() && N1.getOpcode() == ISD::INSERT_SUBVECTOR &&
19629       N1.getOperand(0).isUndef() && isNullConstant(N1.getOperand(2)))
19630     return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0,
19631                        N1.getOperand(1), N2);
19632 
19633   if (!isa<ConstantSDNode>(N2))
19634     return SDValue();
19635 
19636   uint64_t InsIdx = cast<ConstantSDNode>(N2)->getZExtValue();
19637 
19638   // Push subvector bitcasts to the output, adjusting the index as we go.
19639   // insert_subvector(bitcast(v), bitcast(s), c1)
19640   // -> bitcast(insert_subvector(v, s, c2))
19641   if ((N0.isUndef() || N0.getOpcode() == ISD::BITCAST) &&
19642       N1.getOpcode() == ISD::BITCAST) {
19643     SDValue N0Src = peekThroughBitcasts(N0);
19644     SDValue N1Src = peekThroughBitcasts(N1);
19645     EVT N0SrcSVT = N0Src.getValueType().getScalarType();
19646     EVT N1SrcSVT = N1Src.getValueType().getScalarType();
19647     if ((N0.isUndef() || N0SrcSVT == N1SrcSVT) &&
19648         N0Src.getValueType().isVector() && N1Src.getValueType().isVector()) {
19649       EVT NewVT;
19650       SDLoc DL(N);
19651       SDValue NewIdx;
19652       MVT IdxVT = TLI.getVectorIdxTy(DAG.getDataLayout());
19653       LLVMContext &Ctx = *DAG.getContext();
19654       unsigned NumElts = VT.getVectorNumElements();
19655       unsigned EltSizeInBits = VT.getScalarSizeInBits();
19656       if ((EltSizeInBits % N1SrcSVT.getSizeInBits()) == 0) {
19657         unsigned Scale = EltSizeInBits / N1SrcSVT.getSizeInBits();
19658         NewVT = EVT::getVectorVT(Ctx, N1SrcSVT, NumElts * Scale);
19659         NewIdx = DAG.getConstant(InsIdx * Scale, DL, IdxVT);
19660       } else if ((N1SrcSVT.getSizeInBits() % EltSizeInBits) == 0) {
19661         unsigned Scale = N1SrcSVT.getSizeInBits() / EltSizeInBits;
19662         if ((NumElts % Scale) == 0 && (InsIdx % Scale) == 0) {
19663           NewVT = EVT::getVectorVT(Ctx, N1SrcSVT, NumElts / Scale);
19664           NewIdx = DAG.getConstant(InsIdx / Scale, DL, IdxVT);
19665         }
19666       }
19667       if (NewIdx && hasOperation(ISD::INSERT_SUBVECTOR, NewVT)) {
19668         SDValue Res = DAG.getBitcast(NewVT, N0Src);
19669         Res = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, NewVT, Res, N1Src, NewIdx);
19670         return DAG.getBitcast(VT, Res);
19671       }
19672     }
19673   }
19674 
19675   // Canonicalize insert_subvector dag nodes.
19676   // Example:
19677   // (insert_subvector (insert_subvector A, Idx0), Idx1)
19678   // -> (insert_subvector (insert_subvector A, Idx1), Idx0)
19679   if (N0.getOpcode() == ISD::INSERT_SUBVECTOR && N0.hasOneUse() &&
19680       N1.getValueType() == N0.getOperand(1).getValueType() &&
19681       isa<ConstantSDNode>(N0.getOperand(2))) {
19682     unsigned OtherIdx = N0.getConstantOperandVal(2);
19683     if (InsIdx < OtherIdx) {
19684       // Swap nodes.
19685       SDValue NewOp = DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT,
19686                                   N0.getOperand(0), N1, N2);
19687       AddToWorklist(NewOp.getNode());
19688       return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N0.getNode()),
19689                          VT, NewOp, N0.getOperand(1), N0.getOperand(2));
19690     }
19691   }
19692 
19693   // If the input vector is a concatenation, and the insert replaces
19694   // one of the pieces, we can optimize into a single concat_vectors.
19695   if (N0.getOpcode() == ISD::CONCAT_VECTORS && N0.hasOneUse() &&
19696       N0.getOperand(0).getValueType() == N1.getValueType()) {
19697     unsigned Factor = N1.getValueType().getVectorNumElements();
19698 
19699     SmallVector<SDValue, 8> Ops(N0->op_begin(), N0->op_end());
19700     Ops[cast<ConstantSDNode>(N2)->getZExtValue() / Factor] = N1;
19701 
19702     return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
19703   }
19704 
19705   // Simplify source operands based on insertion.
19706   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
19707     return SDValue(N, 0);
19708 
19709   return SDValue();
19710 }
19711 
19712 SDValue DAGCombiner::visitFP_TO_FP16(SDNode *N) {
19713   SDValue N0 = N->getOperand(0);
19714 
19715   // fold (fp_to_fp16 (fp16_to_fp op)) -> op
19716   if (N0->getOpcode() == ISD::FP16_TO_FP)
19717     return N0->getOperand(0);
19718 
19719   return SDValue();
19720 }
19721 
19722 SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
19723   SDValue N0 = N->getOperand(0);
19724 
19725   // fold fp16_to_fp(op & 0xffff) -> fp16_to_fp(op)
19726   if (N0->getOpcode() == ISD::AND) {
19727     ConstantSDNode *AndConst = getAsNonOpaqueConstant(N0.getOperand(1));
19728     if (AndConst && AndConst->getAPIntValue() == 0xffff) {
19729       return DAG.getNode(ISD::FP16_TO_FP, SDLoc(N), N->getValueType(0),
19730                          N0.getOperand(0));
19731     }
19732   }
19733 
19734   return SDValue();
19735 }
19736 
19737 SDValue DAGCombiner::visitVECREDUCE(SDNode *N) {
19738   SDValue N0 = N->getOperand(0);
19739   EVT VT = N0.getValueType();
19740   unsigned Opcode = N->getOpcode();
19741 
19742   // VECREDUCE over 1-element vector is just an extract.
19743   if (VT.getVectorNumElements() == 1) {
19744     SDLoc dl(N);
19745     SDValue Res = DAG.getNode(
19746         ISD::EXTRACT_VECTOR_ELT, dl, VT.getVectorElementType(), N0,
19747         DAG.getConstant(0, dl, TLI.getVectorIdxTy(DAG.getDataLayout())));
19748     if (Res.getValueType() != N->getValueType(0))
19749       Res = DAG.getNode(ISD::ANY_EXTEND, dl, N->getValueType(0), Res);
19750     return Res;
19751   }
19752 
19753   // On a boolean vector an and/or reduction is the same as a umin/umax
19754   // reduction. Convert them if the latter is legal while the former isn't.
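  // e.g. (illustrative) with every lane known to be all-zeros or all-ones:
  //   VECREDUCE_AND == VECREDUCE_UMIN (any 0 lane forces the result to 0)
  //   VECREDUCE_OR  == VECREDUCE_UMAX (any -1 lane forces the result to -1)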
19755   if (Opcode == ISD::VECREDUCE_AND || Opcode == ISD::VECREDUCE_OR) {
19756     unsigned NewOpcode = Opcode == ISD::VECREDUCE_AND
19757         ? ISD::VECREDUCE_UMIN : ISD::VECREDUCE_UMAX;
19758     if (!TLI.isOperationLegalOrCustom(Opcode, VT) &&
19759         TLI.isOperationLegalOrCustom(NewOpcode, VT) &&
19760         DAG.ComputeNumSignBits(N0) == VT.getScalarSizeInBits())
19761       return DAG.getNode(NewOpcode, SDLoc(N), N->getValueType(0), N0);
19762   }
19763 
19764   return SDValue();
19765 }
19766 
19767 /// Returns a vector_shuffle if it is able to transform an AND to a
19768 /// vector_shuffle with the destination vector and a zero vector.
19769 /// e.g. AND V, <0xffffffff, 0, 0xffffffff, 0>. ==>
19770 ///      vector_shuffle V, Zero, <0, 4, 2, 4>
19771 SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) {
19772   assert(N->getOpcode() == ISD::AND && "Unexpected opcode!");
19773 
19774   EVT VT = N->getValueType(0);
19775   SDValue LHS = N->getOperand(0);
19776   SDValue RHS = peekThroughBitcasts(N->getOperand(1));
19777   SDLoc DL(N);
19778 
19779   // Make sure we're not running after operation legalization where it
19780   // may have custom lowered the vector shuffles.
19781   if (LegalOperations)
19782     return SDValue();
19783 
19784   if (RHS.getOpcode() != ISD::BUILD_VECTOR)
19785     return SDValue();
19786 
19787   EVT RVT = RHS.getValueType();
19788   unsigned NumElts = RHS.getNumOperands();
19789 
19790   // Attempt to create a valid clear mask, splitting the mask into
19791   // sub elements and checking to see if each is
19792   // all zeros or all ones - suitable for shuffle masking.
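  // Worked example (illustrative, little-endian, Split == 2):
  //   AND v2i64 X, <0x00000000FFFFFFFF, 0>
  //     -> bitcast(shuffle<0,5,6,7>(bitcast X to v4i32, v4i32 zero))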
19793   auto BuildClearMask = [&](int Split) {
19794     int NumSubElts = NumElts * Split;
19795     int NumSubBits = RVT.getScalarSizeInBits() / Split;
19796 
19797     SmallVector<int, 8> Indices;
19798     for (int i = 0; i != NumSubElts; ++i) {
19799       int EltIdx = i / Split;
19800       int SubIdx = i % Split;
19801       SDValue Elt = RHS.getOperand(EltIdx);
19802       // X & undef --> 0 (not undef). So this lane must be converted to choose
19803       // from the zero constant vector (same as if the element had all 0-bits).
19804       if (Elt.isUndef()) {
19805         Indices.push_back(i + NumSubElts);
19806         continue;
19807       }
19808 
19809       APInt Bits;
19810       if (isa<ConstantSDNode>(Elt))
19811         Bits = cast<ConstantSDNode>(Elt)->getAPIntValue();
19812       else if (isa<ConstantFPSDNode>(Elt))
19813         Bits = cast<ConstantFPSDNode>(Elt)->getValueAPF().bitcastToAPInt();
19814       else
19815         return SDValue();
19816 
19817       // Extract the sub element from the constant bit mask.
19818       if (DAG.getDataLayout().isBigEndian())
19819         Bits = Bits.extractBits(NumSubBits, (Split - SubIdx - 1) * NumSubBits);
19820       else
19821         Bits = Bits.extractBits(NumSubBits, SubIdx * NumSubBits);
19822 
19823       if (Bits.isAllOnesValue())
19824         Indices.push_back(i);
19825       else if (Bits == 0)
19826         Indices.push_back(i + NumSubElts);
19827       else
19828         return SDValue();
19829     }
19830 
19831     // Let's see if the target supports this vector_shuffle.
19832     EVT ClearSVT = EVT::getIntegerVT(*DAG.getContext(), NumSubBits);
19833     EVT ClearVT = EVT::getVectorVT(*DAG.getContext(), ClearSVT, NumSubElts);
19834     if (!TLI.isVectorClearMaskLegal(Indices, ClearVT))
19835       return SDValue();
19836 
19837     SDValue Zero = DAG.getConstant(0, DL, ClearVT);
19838     return DAG.getBitcast(VT, DAG.getVectorShuffle(ClearVT, DL,
19839                                                    DAG.getBitcast(ClearVT, LHS),
19840                                                    Zero, Indices));
19841   };
19842 
19843   // Determine maximum split level (byte level masking).
19844   int MaxSplit = 1;
19845   if (RVT.getScalarSizeInBits() % 8 == 0)
19846     MaxSplit = RVT.getScalarSizeInBits() / 8;
19847 
19848   for (int Split = 1; Split <= MaxSplit; ++Split)
19849     if (RVT.getScalarSizeInBits() % Split == 0)
19850       if (SDValue S = BuildClearMask(Split))
19851         return S;
19852 
19853   return SDValue();
19854 }
19855 
19856 /// If a vector binop is performed on splat values, it may be profitable to
19857 /// extract, scalarize, and insert/splat.
19858 static SDValue scalarizeBinOpOfSplats(SDNode *N, SelectionDAG &DAG) {
19859   SDValue N0 = N->getOperand(0);
19860   SDValue N1 = N->getOperand(1);
19861   unsigned Opcode = N->getOpcode();
19862   EVT VT = N->getValueType(0);
19863   EVT EltVT = VT.getVectorElementType();
19864   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
19865 
19866   // TODO: Remove/replace the extract cost check? If the elements are available
19867   //       as scalars, then there may be no extract cost. Should we ask if
19868   //       inserting a scalar back into a vector is cheap instead?
19869   int Index0, Index1;
19870   SDValue Src0 = DAG.getSplatSourceVector(N0, Index0);
19871   SDValue Src1 = DAG.getSplatSourceVector(N1, Index1);
19872   if (!Src0 || !Src1 || Index0 != Index1 ||
19873       Src0.getValueType().getVectorElementType() != EltVT ||
19874       Src1.getValueType().getVectorElementType() != EltVT ||
19875       !TLI.isExtractVecEltCheap(VT, Index0) ||
19876       !TLI.isOperationLegalOrCustom(Opcode, EltVT))
19877     return SDValue();
19878 
19879   SDLoc DL(N);
19880   SDValue IndexC =
19881       DAG.getConstant(Index0, DL, TLI.getVectorIdxTy(DAG.getDataLayout()));
19882   SDValue X = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, N0, IndexC);
19883   SDValue Y = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, N1, IndexC);
19884   SDValue ScalarBO = DAG.getNode(Opcode, DL, EltVT, X, Y, N->getFlags());
19885 
19886   // If all lanes but 1 are undefined, no need to splat the scalar result.
19887   // TODO: Keep track of undefs and use that info in the general case.
19888   if (N0.getOpcode() == ISD::BUILD_VECTOR && N0.getOpcode() == N1.getOpcode() &&
19889       count_if(N0->ops(), [](SDValue V) { return !V.isUndef(); }) == 1 &&
19890       count_if(N1->ops(), [](SDValue V) { return !V.isUndef(); }) == 1) {
19891     // bo (build_vec ..undef, X, undef...), (build_vec ..undef, Y, undef...) -->
19892     // build_vec ..undef, (bo X, Y), undef...
19893     SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), DAG.getUNDEF(EltVT));
19894     Ops[Index0] = ScalarBO;
19895     return DAG.getBuildVector(VT, DL, Ops);
19896   }
19897 
19898   // bo (splat X, Index), (splat Y, Index) --> splat (bo X, Y), Index
19899   SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), ScalarBO);
19900   return DAG.getBuildVector(VT, DL, Ops);
19901 }
19902 
19903 /// Visit a binary vector operation, like ADD.
19904 SDValue DAGCombiner::SimplifyVBinOp(SDNode *N) {
19905   assert(N->getValueType(0).isVector() &&
19906          "SimplifyVBinOp only works on vectors!");
19907 
19908   SDValue LHS = N->getOperand(0);
19909   SDValue RHS = N->getOperand(1);
19910   SDValue Ops[] = {LHS, RHS};
19911   EVT VT = N->getValueType(0);
19912   unsigned Opcode = N->getOpcode();
19913 
19914   // See if we can constant fold the vector operation.
19915   if (SDValue Fold = DAG.FoldConstantVectorArithmetic(
19916           Opcode, SDLoc(LHS), LHS.getValueType(), Ops, N->getFlags()))
19917     return Fold;
19918 
19919   // Move unary shuffles with identical masks after a vector binop:
19920   // VBinOp (shuffle A, Undef, Mask), (shuffle B, Undef, Mask))
19921   //   --> shuffle (VBinOp A, B), Undef, Mask
19922   // This does not require type legality checks because we are creating the
19923   // same types of operations that are in the original sequence. We do have to
19924   // restrict ops like integer div that have immediate UB (eg, div-by-zero)
19925   // though. This code is adapted from the identical transform in instcombine.
19926   if (Opcode != ISD::UDIV && Opcode != ISD::SDIV &&
19927       Opcode != ISD::UREM && Opcode != ISD::SREM &&
19928       Opcode != ISD::UDIVREM && Opcode != ISD::SDIVREM) {
19929     auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(LHS);
19930     auto *Shuf1 = dyn_cast<ShuffleVectorSDNode>(RHS);
19931     if (Shuf0 && Shuf1 && Shuf0->getMask().equals(Shuf1->getMask()) &&
19932         LHS.getOperand(1).isUndef() && RHS.getOperand(1).isUndef() &&
19933         (LHS.hasOneUse() || RHS.hasOneUse() || LHS == RHS)) {
19934       SDLoc DL(N);
19935       SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, LHS.getOperand(0),
19936                                      RHS.getOperand(0), N->getFlags());
19937       SDValue UndefV = LHS.getOperand(1);
19938       return DAG.getVectorShuffle(VT, DL, NewBinOp, UndefV, Shuf0->getMask());
19939     }
19940   }
19941 
19942   // The following pattern is likely to emerge with vector reduction ops. Moving
19943   // the binary operation ahead of insertion may allow using a narrower vector
19944   // instruction that has better performance than the wide version of the op:
19945   // VBinOp (ins undef, X, Z), (ins undef, Y, Z) --> ins VecC, (VBinOp X, Y), Z
19946   if (LHS.getOpcode() == ISD::INSERT_SUBVECTOR && LHS.getOperand(0).isUndef() &&
19947       RHS.getOpcode() == ISD::INSERT_SUBVECTOR && RHS.getOperand(0).isUndef() &&
19948       LHS.getOperand(2) == RHS.getOperand(2) &&
19949       (LHS.hasOneUse() || RHS.hasOneUse())) {
19950     SDValue X = LHS.getOperand(1);
19951     SDValue Y = RHS.getOperand(1);
19952     SDValue Z = LHS.getOperand(2);
19953     EVT NarrowVT = X.getValueType();
19954     if (NarrowVT == Y.getValueType() &&
19955         TLI.isOperationLegalOrCustomOrPromote(Opcode, NarrowVT)) {
19956       // (binop undef, undef) may not return undef, so compute that result.
19957       SDLoc DL(N);
19958       SDValue VecC =
19959           DAG.getNode(Opcode, DL, VT, DAG.getUNDEF(VT), DAG.getUNDEF(VT));
19960       SDValue NarrowBO = DAG.getNode(Opcode, DL, NarrowVT, X, Y);
19961       return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, VecC, NarrowBO, Z);
19962     }
19963   }
19964 
19965   // Make sure all but the first op are undef or constant.
19966   auto ConcatWithConstantOrUndef = [](SDValue Concat) {
19967     return Concat.getOpcode() == ISD::CONCAT_VECTORS &&
19968            std::all_of(std::next(Concat->op_begin()), Concat->op_end(),
19969                      [](const SDValue &Op) {
19970                        return Op.isUndef() ||
19971                               ISD::isBuildVectorOfConstantSDNodes(Op.getNode());
19972                      });
19973   };
19974 
19975   // The following pattern is likely to emerge with vector reduction ops. Moving
19976   // the binary operation ahead of the concat may allow using a narrower vector
19977   // instruction that has better performance than the wide version of the op:
19978   // VBinOp (concat X, undef/constant), (concat Y, undef/constant) -->
19979   //   concat (VBinOp X, Y), VecC
19980   if (ConcatWithConstantOrUndef(LHS) && ConcatWithConstantOrUndef(RHS) &&
19981       (LHS.hasOneUse() || RHS.hasOneUse())) {
19982     EVT NarrowVT = LHS.getOperand(0).getValueType();
19983     if (NarrowVT == RHS.getOperand(0).getValueType() &&
19984         TLI.isOperationLegalOrCustomOrPromote(Opcode, NarrowVT)) {
19985       SDLoc DL(N);
19986       unsigned NumOperands = LHS.getNumOperands();
19987       SmallVector<SDValue, 4> ConcatOps;
19988       for (unsigned i = 0; i != NumOperands; ++i) {
19989         // This constant folds for operands 1 and up.
19990         ConcatOps.push_back(DAG.getNode(Opcode, DL, NarrowVT, LHS.getOperand(i),
19991                                         RHS.getOperand(i)));
19992       }
19993 
19994       return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
19995     }
19996   }
19997 
19998   if (SDValue V = scalarizeBinOpOfSplats(N, DAG))
19999     return V;
20000 
20001   return SDValue();
20002 }
20003 
20004 SDValue DAGCombiner::SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1,
20005                                     SDValue N2) {
20006   assert(N0.getOpcode() ==ISD::SETCC && "First argument must be a SetCC node!");
20007 
20008   SDValue SCC = SimplifySelectCC(DL, N0.getOperand(0), N0.getOperand(1), N1, N2,
20009                                  cast<CondCodeSDNode>(N0.getOperand(2))->get());
20010 
20011   // If we got a simplified select_cc node back from SimplifySelectCC, then
20012   // break it down into a new SETCC node, and a new SELECT node, and then return
20013   // the SELECT node, since we were called with a SELECT node.
20014   if (SCC.getNode()) {
20015     // Check to see if we got a select_cc back (to turn into setcc/select).
20016     // Otherwise, just return whatever node we got back, like fabs.
20017     if (SCC.getOpcode() == ISD::SELECT_CC) {
20018       const SDNodeFlags Flags = N0.getNode()->getFlags();
20019       SDValue SETCC = DAG.getNode(ISD::SETCC, SDLoc(N0),
20020                                   N0.getValueType(),
20021                                   SCC.getOperand(0), SCC.getOperand(1),
20022                                   SCC.getOperand(4), Flags);
20023       AddToWorklist(SETCC.getNode());
20024       SDValue SelectNode = DAG.getSelect(SDLoc(SCC), SCC.getValueType(), SETCC,
20025                                          SCC.getOperand(2), SCC.getOperand(3));
20026       SelectNode->setFlags(Flags);
20027       return SelectNode;
20028     }
20029 
20030     return SCC;
20031   }
20032   return SDValue();
20033 }

/// Given a SELECT or a SELECT_CC node, where LHS and RHS are the two values
/// being selected between, see if we can simplify the select.  Callers of this
/// should assume that TheSelect is deleted if this returns true.  As such, they
/// should return the appropriate thing (e.g. the node) back to the top-level of
/// the DAG combiner loop to avoid it being looked at.
bool DAGCombiner::SimplifySelectOps(SDNode *TheSelect, SDValue LHS,
                                    SDValue RHS) {
  // fold (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
  // The select + setcc is redundant, because fsqrt returns NaN for X < 0.
  if (const ConstantFPSDNode *NaN = isConstOrConstSplatFP(LHS)) {
    if (NaN->isNaN() && RHS.getOpcode() == ISD::FSQRT) {
      // We have: (select (setcc ?, ?, ?), NaN, (fsqrt ?))
      SDValue Sqrt = RHS;
      ISD::CondCode CC;
      SDValue CmpLHS;
      const ConstantFPSDNode *Zero = nullptr;

      if (TheSelect->getOpcode() == ISD::SELECT_CC) {
        CC = cast<CondCodeSDNode>(TheSelect->getOperand(4))->get();
        CmpLHS = TheSelect->getOperand(0);
        Zero = isConstOrConstSplatFP(TheSelect->getOperand(1));
      } else {
        // SELECT or VSELECT
        SDValue Cmp = TheSelect->getOperand(0);
        if (Cmp.getOpcode() == ISD::SETCC) {
          CC = cast<CondCodeSDNode>(Cmp.getOperand(2))->get();
          CmpLHS = Cmp.getOperand(0);
          Zero = isConstOrConstSplatFP(Cmp.getOperand(1));
        }
      }
      if (Zero && Zero->isZero() &&
          Sqrt.getOperand(0) == CmpLHS && (CC == ISD::SETOLT ||
          CC == ISD::SETULT || CC == ISD::SETLT)) {
        // We have: (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
        CombineTo(TheSelect, Sqrt);
        return true;
      }
    }
  }
  // Cannot simplify a select with a vector condition.
  if (TheSelect->getOperand(0).getValueType().isVector()) return false;

  // If this is a select from two identical things, try to pull the operation
  // through the select.
  if (LHS.getOpcode() != RHS.getOpcode() ||
      !LHS.hasOneUse() || !RHS.hasOneUse())
    return false;

  // If this is a load and the token chain is identical, replace the select
  // of two loads with a load through a select of the address to load from.
  // This triggers in things like "select bool X, 10.0, 123.0" after the FP
  // constants have been dropped into the constant pool.
  if (LHS.getOpcode() == ISD::LOAD) {
    LoadSDNode *LLD = cast<LoadSDNode>(LHS);
    LoadSDNode *RLD = cast<LoadSDNode>(RHS);

    // Token chains must be identical.
    if (LHS.getOperand(0) != RHS.getOperand(0) ||
        // Do not let this transformation reduce the number of volatile loads.
        // Be conservative for atomics for the moment.
        // TODO: This does appear to be legal for unordered atomics (see D66309)
        !LLD->isSimple() || !RLD->isSimple() ||
        // FIXME: If either is a pre/post inc/dec load,
        // we'd need to split out the address adjustment.
        LLD->isIndexed() || RLD->isIndexed() ||
        // If this is an EXTLOAD, the VTs must match.
        LLD->getMemoryVT() != RLD->getMemoryVT() ||
        // If this is an EXTLOAD, the kind of extension must match.
        (LLD->getExtensionType() != RLD->getExtensionType() &&
         // The only exception is if one of the extensions is anyext.
         LLD->getExtensionType() != ISD::EXTLOAD &&
         RLD->getExtensionType() != ISD::EXTLOAD) ||
        // FIXME: this discards src value information.  This is
        // over-conservative. It would be beneficial to be able to remember
        // both potential memory locations.  Since we are discarding
        // src value info, don't do the transformation if the memory
        // locations are not in the default address space.
        LLD->getPointerInfo().getAddrSpace() != 0 ||
        RLD->getPointerInfo().getAddrSpace() != 0 ||
        // We can't produce a CMOV of a TargetFrameIndex since we won't
        // generate the address generation required.
        LLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
        RLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
        !TLI.isOperationLegalOrCustom(TheSelect->getOpcode(),
                                      LLD->getBasePtr().getValueType()))
      return false;

    // The loads must not depend on one another.
    if (LLD->isPredecessorOf(RLD) || RLD->isPredecessorOf(LLD))
      return false;

    // Check that the select condition doesn't reach either load.  If so,
    // folding this will induce a cycle into the DAG.  If not, this is safe to
    // xform, so create a select of the addresses.

    SmallPtrSet<const SDNode *, 32> Visited;
    SmallVector<const SDNode *, 16> Worklist;

    // Always fail if LLD and RLD are not independent. TheSelect is a
    // predecessor to all nodes in question, so we need not search past it.

    Visited.insert(TheSelect);
    Worklist.push_back(LLD);
    Worklist.push_back(RLD);

    if (SDNode::hasPredecessorHelper(LLD, Visited, Worklist) ||
        SDNode::hasPredecessorHelper(RLD, Visited, Worklist))
      return false;

    SDValue Addr;
    if (TheSelect->getOpcode() == ISD::SELECT) {
      // We cannot do this optimization if any pair of {RLD, LLD} is a
      // predecessor to {RLD, LLD, CondNode}. As we've already compared the
      // loads, we only need to check if CondNode is a successor to one of the
      // loads. We can further avoid this if there's no use of their chain
      // value.
      SDNode *CondNode = TheSelect->getOperand(0).getNode();
      Worklist.push_back(CondNode);

      if ((LLD->hasAnyUseOfValue(1) &&
           SDNode::hasPredecessorHelper(LLD, Visited, Worklist)) ||
          (RLD->hasAnyUseOfValue(1) &&
           SDNode::hasPredecessorHelper(RLD, Visited, Worklist)))
        return false;

      Addr = DAG.getSelect(SDLoc(TheSelect),
                           LLD->getBasePtr().getValueType(),
                           TheSelect->getOperand(0), LLD->getBasePtr(),
                           RLD->getBasePtr());
    } else {  // Otherwise SELECT_CC
      // We cannot do this optimization if any pair of {RLD, LLD} is a
      // predecessor to {RLD, LLD, CondLHS, CondRHS}. As we've already compared
      // the loads, we only need to check if CondLHS/CondRHS is a successor to
      // one of the loads. We can further avoid this if there's no use of their
      // chain value.

      SDNode *CondLHS = TheSelect->getOperand(0).getNode();
      SDNode *CondRHS = TheSelect->getOperand(1).getNode();
      Worklist.push_back(CondLHS);
      Worklist.push_back(CondRHS);

      if ((LLD->hasAnyUseOfValue(1) &&
           SDNode::hasPredecessorHelper(LLD, Visited, Worklist)) ||
          (RLD->hasAnyUseOfValue(1) &&
           SDNode::hasPredecessorHelper(RLD, Visited, Worklist)))
        return false;

      Addr = DAG.getNode(ISD::SELECT_CC, SDLoc(TheSelect),
                         LLD->getBasePtr().getValueType(),
                         TheSelect->getOperand(0),
                         TheSelect->getOperand(1),
                         LLD->getBasePtr(), RLD->getBasePtr(),
                         TheSelect->getOperand(4));
    }

    SDValue Load;
    // It is safe to replace the two loads if they have different alignments,
    // but the new load must be the minimum (most restrictive) alignment of the
    // inputs.
    unsigned Alignment = std::min(LLD->getAlignment(), RLD->getAlignment());
    MachineMemOperand::Flags MMOFlags = LLD->getMemOperand()->getFlags();
    if (!RLD->isInvariant())
      MMOFlags &= ~MachineMemOperand::MOInvariant;
    if (!RLD->isDereferenceable())
      MMOFlags &= ~MachineMemOperand::MODereferenceable;
    if (LLD->getExtensionType() == ISD::NON_EXTLOAD) {
      // FIXME: Discards pointer and AA info.
      Load = DAG.getLoad(TheSelect->getValueType(0), SDLoc(TheSelect),
                         LLD->getChain(), Addr, MachinePointerInfo(), Alignment,
                         MMOFlags);
    } else {
      // FIXME: Discards pointer and AA info.
      Load = DAG.getExtLoad(
          LLD->getExtensionType() == ISD::EXTLOAD ? RLD->getExtensionType()
                                                  : LLD->getExtensionType(),
          SDLoc(TheSelect), TheSelect->getValueType(0), LLD->getChain(), Addr,
          MachinePointerInfo(), LLD->getMemoryVT(), Alignment, MMOFlags);
    }

    // Users of the select now use the result of the load.
    CombineTo(TheSelect, Load);

    // Users of the old loads now use the new load's chain.  We know the
    // old-load value is dead now.
    CombineTo(LHS.getNode(), Load.getValue(0), Load.getValue(1));
    CombineTo(RHS.getNode(), Load.getValue(0), Load.getValue(1));
    return true;
  }

  return false;
}

/// Try to fold an expression of the form (N0 cond N1) ? N2 : N3 to a shift and
/// bitwise 'and'.
SDValue DAGCombiner::foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0,
                                            SDValue N1, SDValue N2, SDValue N3,
                                            ISD::CondCode CC) {
  // If this is a select where the false operand is zero and the compare is a
  // check of the sign bit, see if we can perform the "gzip trick":
  // select_cc setlt X, 0, A, 0 -> and (sra X, size(X)-1), A
  // select_cc setgt X, 0, A, 0 -> and (not (sra X, size(X)-1)), A
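  // For example (illustrative, i32): select_cc setlt X, 0, A, 0 becomes
  // and (sra X, 31), A, because (sra X, 31) is all-ones when X is negative
  // and all-zeros otherwise, so the AND yields either A or 0.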
  EVT XType = N0.getValueType();
  EVT AType = N2.getValueType();
  if (!isNullConstant(N3) || !XType.bitsGE(AType))
    return SDValue();

  // If the comparison is testing for a positive value, we have to invert
  // the sign bit mask, so only do that transform if the target has a bitwise
  // 'and not' instruction (the invert is free).
  if (CC == ISD::SETGT && TLI.hasAndNot(N2)) {
    // (X > -1) ? A : 0
    // (X >  0) ? X : 0 <-- This is canonical signed max.
    if (!(isAllOnesConstant(N1) || (isNullConstant(N1) && N0 == N2)))
      return SDValue();
  } else if (CC == ISD::SETLT) {
    // (X <  0) ? A : 0
    // (X <  1) ? X : 0 <-- This is un-canonicalized signed min.
    if (!(isNullConstant(N1) || (isOneConstant(N1) && N0 == N2)))
      return SDValue();
  } else {
    return SDValue();
  }

  // and (sra X, size(X)-1), A -> "and (srl X, C2), A" iff A is a single-bit
  // constant.
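  // For example (illustrative, i32 with A == 8): ShCt = 32 - 3 - 1 = 28, so
  // this produces and (srl X, 28), 8, which moves the sign bit of X down
  // into bit 3 and masks off everything else.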
  EVT ShiftAmtTy = getShiftAmountTy(N0.getValueType());
  auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
  if (N2C && ((N2C->getAPIntValue() & (N2C->getAPIntValue() - 1)) == 0)) {
    unsigned ShCt = XType.getSizeInBits() - N2C->getAPIntValue().logBase2() - 1;
    if (!TLI.shouldAvoidTransformToShift(XType, ShCt)) {
      SDValue ShiftAmt = DAG.getConstant(ShCt, DL, ShiftAmtTy);
      SDValue Shift = DAG.getNode(ISD::SRL, DL, XType, N0, ShiftAmt);
      AddToWorklist(Shift.getNode());

      if (XType.bitsGT(AType)) {
        Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift);
        AddToWorklist(Shift.getNode());
      }

      if (CC == ISD::SETGT)
        Shift = DAG.getNOT(DL, Shift, AType);

      return DAG.getNode(ISD::AND, DL, AType, Shift, N2);
    }
  }

  unsigned ShCt = XType.getSizeInBits() - 1;
  if (TLI.shouldAvoidTransformToShift(XType, ShCt))
    return SDValue();

  SDValue ShiftAmt = DAG.getConstant(ShCt, DL, ShiftAmtTy);
  SDValue Shift = DAG.getNode(ISD::SRA, DL, XType, N0, ShiftAmt);
  AddToWorklist(Shift.getNode());

  if (XType.bitsGT(AType)) {
    Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift);
    AddToWorklist(Shift.getNode());
  }

  if (CC == ISD::SETGT)
    Shift = DAG.getNOT(DL, Shift, AType);

  return DAG.getNode(ISD::AND, DL, AType, Shift, N2);
}

/// Turn "(a cond b) ? 1.0f : 2.0f" into "load (tmp + ((a cond b) ? 0 : 4))"
/// where "tmp" is a constant pool entry containing an array with 1.0 and 2.0
/// in it. This may be a win when the constant is not otherwise available
/// because it replaces two constant pool loads with one.
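/// Concretely (an illustrative f32 case, assuming 4-byte elements): the pool
/// built below is laid out as {FV, TV}, and the address offset is selected as
/// (N0 cond N1) ? 4 : 0, so the true value is loaded from byte offset 4.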
SDValue DAGCombiner::convertSelectOfFPConstantsToLoadOffset(
    const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
    ISD::CondCode CC) {
  if (!TLI.reduceSelectOfFPConstantLoads(N0.getValueType()))
    return SDValue();

  // If we are before legalize types, we want the other legalization to happen
  // first (for example, to avoid messing with soft float).
  auto *TV = dyn_cast<ConstantFPSDNode>(N2);
  auto *FV = dyn_cast<ConstantFPSDNode>(N3);
  EVT VT = N2.getValueType();
  if (!TV || !FV || !TLI.isTypeLegal(VT))
    return SDValue();

  // If a constant can be materialized without loads, this does not make sense.
  if (TLI.getOperationAction(ISD::ConstantFP, VT) == TargetLowering::Legal ||
      TLI.isFPImmLegal(TV->getValueAPF(), TV->getValueType(0), ForCodeSize) ||
      TLI.isFPImmLegal(FV->getValueAPF(), FV->getValueType(0), ForCodeSize))
    return SDValue();

  // If both constants have multiple uses, then we won't need to do an extra
  // load. The values are likely around in registers for other users.
  if (!TV->hasOneUse() && !FV->hasOneUse())
    return SDValue();

  Constant *Elts[] = { const_cast<ConstantFP*>(FV->getConstantFPValue()),
                       const_cast<ConstantFP*>(TV->getConstantFPValue()) };
  Type *FPTy = Elts[0]->getType();
  const DataLayout &TD = DAG.getDataLayout();

  // Create a ConstantArray of the two constants.
  Constant *CA = ConstantArray::get(ArrayType::get(FPTy, 2), Elts);
  SDValue CPIdx = DAG.getConstantPool(CA, TLI.getPointerTy(DAG.getDataLayout()),
                                      TD.getPrefTypeAlignment(FPTy));
  unsigned Alignment = cast<ConstantPoolSDNode>(CPIdx)->getAlignment();

  // Get offsets to the 0 and 1 elements of the array, so we can select between
  // them.
  SDValue Zero = DAG.getIntPtrConstant(0, DL);
  unsigned EltSize = (unsigned)TD.getTypeAllocSize(Elts[0]->getType());
  SDValue One = DAG.getIntPtrConstant(EltSize, SDLoc(FV));
  SDValue Cond =
      DAG.getSetCC(DL, getSetCCResultType(N0.getValueType()), N0, N1, CC);
  AddToWorklist(Cond.getNode());
  SDValue CstOffset = DAG.getSelect(DL, Zero.getValueType(), Cond, One, Zero);
  AddToWorklist(CstOffset.getNode());
  CPIdx = DAG.getNode(ISD::ADD, DL, CPIdx.getValueType(), CPIdx, CstOffset);
  AddToWorklist(CPIdx.getNode());
  return DAG.getLoad(TV->getValueType(0), DL, DAG.getEntryNode(), CPIdx,
                     MachinePointerInfo::getConstantPool(
                         DAG.getMachineFunction()), Alignment);
}

/// Simplify an expression of the form (N0 cond N1) ? N2 : N3
/// where 'cond' is the comparison specified by CC.
SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
                                      SDValue N2, SDValue N3, ISD::CondCode CC,
                                      bool NotExtCompare) {
  // (x ? y : y) -> y.
  if (N2 == N3) return N2;

  EVT CmpOpVT = N0.getValueType();
  EVT CmpResVT = getSetCCResultType(CmpOpVT);
  EVT VT = N2.getValueType();
  auto *N1C = dyn_cast<ConstantSDNode>(N1.getNode());
  auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
  auto *N3C = dyn_cast<ConstantSDNode>(N3.getNode());

  // Determine if the condition we're dealing with is constant.
  if (SDValue SCC = DAG.FoldSetCC(CmpResVT, N0, N1, CC, DL)) {
    AddToWorklist(SCC.getNode());
    if (auto *SCCC = dyn_cast<ConstantSDNode>(SCC)) {
      // fold select_cc true, x, y -> x
      // fold select_cc false, x, y -> y
      return !(SCCC->isNullValue()) ? N2 : N3;
    }
  }

  if (SDValue V =
          convertSelectOfFPConstantsToLoadOffset(DL, N0, N1, N2, N3, CC))
    return V;

  if (SDValue V = foldSelectCCToShiftAnd(DL, N0, N1, N2, N3, CC))
    return V;

  // fold (select_cc seteq (and x, y), 0, 0, A) -> (and (shr (shl x)) A)
  // where y has a single bit set. In plain terms, we can turn the SELECT_CC
  // into an AND when the condition can be materialized as an all-ones
  // register.  Any single bit-test can be materialized as an all-ones register
  // with shift-left and shift-right-arith.
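  // For example (illustrative, i32 with y == 4): shl X by 29 moves bit 2 into
  // the sign bit, sra by 31 then yields all-ones iff bit 2 was set, and that
  // mask ANDed with A selects either A or zero.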
  if (CC == ISD::SETEQ && N0->getOpcode() == ISD::AND &&
      N0->getValueType(0) == VT && isNullConstant(N1) && isNullConstant(N2)) {
    SDValue AndLHS = N0->getOperand(0);
    auto *ConstAndRHS = dyn_cast<ConstantSDNode>(N0->getOperand(1));
    if (ConstAndRHS && ConstAndRHS->getAPIntValue().countPopulation() == 1) {
      // Shift the tested bit over the sign bit.
      const APInt &AndMask = ConstAndRHS->getAPIntValue();
      unsigned ShCt = AndMask.getBitWidth() - 1;
      if (!TLI.shouldAvoidTransformToShift(VT, ShCt)) {
        SDValue ShlAmt =
          DAG.getConstant(AndMask.countLeadingZeros(), SDLoc(AndLHS),
                          getShiftAmountTy(AndLHS.getValueType()));
        SDValue Shl = DAG.getNode(ISD::SHL, SDLoc(N0), VT, AndLHS, ShlAmt);

        // Now arithmetic right shift it all the way over, so the result is
        // either all-ones, or zero.
        SDValue ShrAmt =
          DAG.getConstant(ShCt, SDLoc(Shl),
                          getShiftAmountTy(Shl.getValueType()));
        SDValue Shr = DAG.getNode(ISD::SRA, SDLoc(N0), VT, Shl, ShrAmt);

        return DAG.getNode(ISD::AND, DL, VT, Shr, N3);
      }
    }
  }

  // fold select C, 16, 0 -> shl C, 4
  bool Fold = N2C && isNullConstant(N3) && N2C->getAPIntValue().isPowerOf2();
  bool Swap = N3C && isNullConstant(N2) && N3C->getAPIntValue().isPowerOf2();

  if ((Fold || Swap) &&
      TLI.getBooleanContents(CmpOpVT) ==
          TargetLowering::ZeroOrOneBooleanContent &&
      (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, CmpOpVT))) {

    if (Swap) {
      CC = ISD::getSetCCInverse(CC, CmpOpVT);
      std::swap(N2C, N3C);
    }

    // If the caller doesn't want us to simplify this into a zext of a compare,
    // don't do it.
    if (NotExtCompare && N2C->isOne())
      return SDValue();

    SDValue Temp, SCC;
    // zext (setcc n0, n1)
    if (LegalTypes) {
      SCC = DAG.getSetCC(DL, CmpResVT, N0, N1, CC);
      if (VT.bitsLT(SCC.getValueType()))
        Temp = DAG.getZeroExtendInReg(SCC, SDLoc(N2), VT);
      else
        Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2), VT, SCC);
    } else {
      SCC = DAG.getSetCC(SDLoc(N0), MVT::i1, N0, N1, CC);
      Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2), VT, SCC);
    }

    AddToWorklist(SCC.getNode());
    AddToWorklist(Temp.getNode());

    if (N2C->isOne())
      return Temp;

    unsigned ShCt = N2C->getAPIntValue().logBase2();
    if (TLI.shouldAvoidTransformToShift(VT, ShCt))
      return SDValue();

    // shl setcc result by log2 n2c
    return DAG.getNode(ISD::SHL, DL, N2.getValueType(), Temp,
                       DAG.getConstant(ShCt, SDLoc(Temp),
                                       getShiftAmountTy(Temp.getValueType())));
  }

  // select_cc seteq X, 0, sizeof(X), ctlz(X) -> ctlz(X)
  // select_cc seteq X, 0, sizeof(X), ctlz_zero_undef(X) -> ctlz(X)
  // select_cc seteq X, 0, sizeof(X), cttz(X) -> cttz(X)
  // select_cc seteq X, 0, sizeof(X), cttz_zero_undef(X) -> cttz(X)
  // select_cc setne X, 0, ctlz(X), sizeof(X) -> ctlz(X)
  // select_cc setne X, 0, ctlz_zero_undef(X), sizeof(X) -> ctlz(X)
  // select_cc setne X, 0, cttz(X), sizeof(X) -> cttz(X)
  // select_cc setne X, 0, cttz_zero_undef(X), sizeof(X) -> cttz(X)
  if (N1C && N1C->isNullValue() && (CC == ISD::SETEQ || CC == ISD::SETNE)) {
    SDValue ValueOnZero = N2;
    SDValue Count = N3;
    // If the condition is NE instead of EQ, swap the operands.
    if (CC == ISD::SETNE)
      std::swap(ValueOnZero, Count);
    // Check if the value on zero is a constant equal to the number of bits in
    // the type.
    if (auto *ValueOnZeroC = dyn_cast<ConstantSDNode>(ValueOnZero)) {
      if (ValueOnZeroC->getAPIntValue() == VT.getSizeInBits()) {
        // If the other operand is cttz/cttz_zero_undef of N0, and cttz is
        // legal, combine to just cttz.
        if ((Count.getOpcode() == ISD::CTTZ ||
             Count.getOpcode() == ISD::CTTZ_ZERO_UNDEF) &&
            N0 == Count.getOperand(0) &&
            (!LegalOperations || TLI.isOperationLegal(ISD::CTTZ, VT)))
          return DAG.getNode(ISD::CTTZ, DL, VT, N0);
        // If the other operand is ctlz/ctlz_zero_undef of N0, and ctlz is
        // legal, combine to just ctlz.
        if ((Count.getOpcode() == ISD::CTLZ ||
             Count.getOpcode() == ISD::CTLZ_ZERO_UNDEF) &&
            N0 == Count.getOperand(0) &&
            (!LegalOperations || TLI.isOperationLegal(ISD::CTLZ, VT)))
          return DAG.getNode(ISD::CTLZ, DL, VT, N0);
      }
    }
  }

  return SDValue();
}

/// This is a stub for TargetLowering::SimplifySetCC.
SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
                                   ISD::CondCode Cond, const SDLoc &DL,
                                   bool foldBooleans) {
  TargetLowering::DAGCombinerInfo
    DagCombineInfo(DAG, Level, false, this);
  return TLI.SimplifySetCC(VT, N0, N1, Cond, foldBooleans, DagCombineInfo, DL);
}

/// Given an ISD::SDIV node expressing a divide by constant, return
/// a DAG expression that will generate the same value by multiplying
/// by a magic number.
/// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
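/// For example (an illustrative i32 case from Hacker's Delight): X /s 3 can
/// become mulhs(X, 0x55555556) plus an add of the product's logical shift
/// right by 31 to fix up negative inputs; the exact sequence is chosen by
/// TLI.BuildSDIV below.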
SDValue DAGCombiner::BuildSDIV(SDNode *N) {
  // When optimizing for minimum size, we don't want to expand a div to a mul
  // and a shift.
  if (DAG.getMachineFunction().getFunction().hasMinSize())
    return SDValue();

  SmallVector<SDNode *, 8> Built;
  if (SDValue S = TLI.BuildSDIV(N, DAG, LegalOperations, Built)) {
    for (SDNode *N : Built)
      AddToWorklist(N);
    return S;
  }

  return SDValue();
}

/// Given an ISD::SDIV node expressing a divide by constant power of 2, return a
/// DAG expression that will generate the same value by right shifting.
SDValue DAGCombiner::BuildSDIVPow2(SDNode *N) {
  ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1));
  if (!C)
    return SDValue();

  // Avoid division by zero.
  if (C->isNullValue())
    return SDValue();

  SmallVector<SDNode *, 8> Built;
  if (SDValue S = TLI.BuildSDIVPow2(N, C->getAPIntValue(), DAG, Built)) {
    for (SDNode *N : Built)
      AddToWorklist(N);
    return S;
  }

  return SDValue();
}

/// Given an ISD::UDIV node expressing a divide by constant, return a DAG
/// expression that will generate the same value by multiplying by a magic
/// number.
/// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
SDValue DAGCombiner::BuildUDIV(SDNode *N) {
  // When optimizing for minimum size, we don't want to expand a div to a mul
  // and a shift.
  if (DAG.getMachineFunction().getFunction().hasMinSize())
    return SDValue();

  SmallVector<SDNode *, 8> Built;
  if (SDValue S = TLI.BuildUDIV(N, DAG, LegalOperations, Built)) {
    for (SDNode *N : Built)
      AddToWorklist(N);
    return S;
  }

  return SDValue();
}

/// Determines the LogBase2 value for a non-null input value using the
/// transform: LogBase2(V) = (EltBits - 1) - ctlz(V).
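/// For example (illustrative, i32): for V == 8, ctlz(8) == 28, so
/// LogBase2 == 31 - 28 == 3.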
SDValue DAGCombiner::BuildLogBase2(SDValue V, const SDLoc &DL) {
  EVT VT = V.getValueType();
  unsigned EltBits = VT.getScalarSizeInBits();
  SDValue Ctlz = DAG.getNode(ISD::CTLZ, DL, VT, V);
  SDValue Base = DAG.getConstant(EltBits - 1, DL, VT);
  SDValue LogBase2 = DAG.getNode(ISD::SUB, DL, VT, Base, Ctlz);
  return LogBase2;
}

/// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
/// For the reciprocal, we need to find the zero of the function:
///   F(X) = A X - 1 [which has a zero at X = 1/A]
///     =>
///   X_{i+1} = X_i (2 - A X_i) = X_i + X_i (1 - A X_i) [this second form
///     does not require additional intermediate precision]
/// For the last iteration, put numerator N into it to gain more precision:
///   Result = N X_i + X_i (N - N A X_i)
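/// For example (an illustrative scalar trace with A = 4 and initial estimate
/// X_0 = 0.2): X_1 = 0.2 * (2 - 4 * 0.2) = 0.24, then
/// X_2 = 0.24 * (2 - 4 * 0.24) = 0.2496, converging toward 1/A = 0.25.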
SDValue DAGCombiner::BuildDivEstimate(SDValue N, SDValue Op,
                                      SDNodeFlags Flags) {
  if (LegalDAG)
    return SDValue();

  // TODO: Handle half and/or extended types?
  EVT VT = Op.getValueType();
  if (VT.getScalarType() != MVT::f32 && VT.getScalarType() != MVT::f64)
    return SDValue();

  // If estimates are explicitly disabled for this function, we're done.
  MachineFunction &MF = DAG.getMachineFunction();
  int Enabled = TLI.getRecipEstimateDivEnabled(VT, MF);
  if (Enabled == TLI.ReciprocalEstimate::Disabled)
    return SDValue();

  // Estimates may be explicitly enabled for this type with a custom number of
  // refinement steps.
  int Iterations = TLI.getDivRefinementSteps(VT, MF);
  if (SDValue Est = TLI.getRecipEstimate(Op, DAG, Enabled, Iterations)) {
    AddToWorklist(Est.getNode());

    SDLoc DL(Op);
    if (Iterations) {
      SDValue FPOne = DAG.getConstantFP(1.0, DL, VT);

      // Newton iterations: Est = Est + Est (N - Arg * Est)
      // If this is the last iteration, also multiply by the numerator.
      for (int i = 0; i < Iterations; ++i) {
        SDValue MulEst = Est;

        if (i == Iterations - 1) {
          MulEst = DAG.getNode(ISD::FMUL, DL, VT, N, Est, Flags);
          AddToWorklist(MulEst.getNode());
        }

        SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Op, MulEst, Flags);
        AddToWorklist(NewEst.getNode());

        NewEst = DAG.getNode(ISD::FSUB, DL, VT,
                             (i == Iterations - 1 ? N : FPOne), NewEst, Flags);
        AddToWorklist(NewEst.getNode());

        NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags);
        AddToWorklist(NewEst.getNode());

        Est = DAG.getNode(ISD::FADD, DL, VT, MulEst, NewEst, Flags);
        AddToWorklist(Est.getNode());
      }
    } else {
      // If no iterations are available, multiply with N.
      Est = DAG.getNode(ISD::FMUL, DL, VT, Est, N, Flags);
      AddToWorklist(Est.getNode());
    }

    return Est;
  }

  return SDValue();
}

/// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
/// For the reciprocal sqrt, we need to find the zero of the function:
///   F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
///     =>
///   X_{i+1} = X_i (1.5 - A X_i^2 / 2)
/// As a result, we precompute A/2 prior to the iteration loop.
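/// For example (an illustrative trace with A = 4 and initial estimate
/// X_0 = 0.6): X_1 = 0.6 * (1.5 - (4 * 0.36) / 2) = 0.6 * 0.78 = 0.468,
/// converging toward 1/sqrt(4) = 0.5.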
SDValue DAGCombiner::buildSqrtNROneConst(SDValue Arg, SDValue Est,
                                         unsigned Iterations,
                                         SDNodeFlags Flags, bool Reciprocal) {
  EVT VT = Arg.getValueType();
  SDLoc DL(Arg);
  SDValue ThreeHalves = DAG.getConstantFP(1.5, DL, VT);

  // We now need 0.5 * Arg which we can write as (1.5 * Arg - Arg) so that
  // this entire sequence requires only one FP constant.
  SDValue HalfArg = DAG.getNode(ISD::FMUL, DL, VT, ThreeHalves, Arg, Flags);
  HalfArg = DAG.getNode(ISD::FSUB, DL, VT, HalfArg, Arg, Flags);

  // Newton iterations: Est = Est * (1.5 - HalfArg * Est * Est)
  for (unsigned i = 0; i < Iterations; ++i) {
    SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, Est, Flags);
    NewEst = DAG.getNode(ISD::FMUL, DL, VT, HalfArg, NewEst, Flags);
    NewEst = DAG.getNode(ISD::FSUB, DL, VT, ThreeHalves, NewEst, Flags);
    Est = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags);
  }

  // If non-reciprocal square root is requested, multiply the result by Arg.
  if (!Reciprocal)
    Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Arg, Flags);

  return Est;
}

/// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
/// For the reciprocal sqrt, we need to find the zero of the function:
///   F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
///     =>
///   X_{i+1} = (-0.5 * X_i) * (A * X_i * X_i + (-3.0))
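/// This is algebraically the same update as the one-constant form above; for
/// example (illustrative, A = 4, X_0 = 0.6):
///   X_1 = (-0.5 * 0.6) * (4 * 0.36 - 3.0) = -0.3 * -1.56 = 0.468.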
SDValue DAGCombiner::buildSqrtNRTwoConst(SDValue Arg, SDValue Est,
                                         unsigned Iterations,
                                         SDNodeFlags Flags, bool Reciprocal) {
  EVT VT = Arg.getValueType();
  SDLoc DL(Arg);
  SDValue MinusThree = DAG.getConstantFP(-3.0, DL, VT);
  SDValue MinusHalf = DAG.getConstantFP(-0.5, DL, VT);

  // This routine must enter the loop below to work correctly
  // when (Reciprocal == false).
  assert(Iterations > 0);

  // Newton iterations for reciprocal square root:
  // E = (E * -0.5) * ((A * E) * E + -3.0)
  for (unsigned i = 0; i < Iterations; ++i) {
    SDValue AE = DAG.getNode(ISD::FMUL, DL, VT, Arg, Est, Flags);
    SDValue AEE = DAG.getNode(ISD::FMUL, DL, VT, AE, Est, Flags);
    SDValue RHS = DAG.getNode(ISD::FADD, DL, VT, AEE, MinusThree, Flags);

    // When calculating a square root at the last iteration build:
    // S = ((A * E) * -0.5) * ((A * E) * E + -3.0)
    // (notice a common subexpression)
    SDValue LHS;
    if (Reciprocal || (i + 1) < Iterations) {
      // RSQRT: LHS = (E * -0.5)
      LHS = DAG.getNode(ISD::FMUL, DL, VT, Est, MinusHalf, Flags);
    } else {
      // SQRT: LHS = (A * E) * -0.5
      LHS = DAG.getNode(ISD::FMUL, DL, VT, AE, MinusHalf, Flags);
    }

    Est = DAG.getNode(ISD::FMUL, DL, VT, LHS, RHS, Flags);
  }

  return Est;
}

/// Build code to calculate either rsqrt(Op) or sqrt(Op). In the latter case,
/// Op*rsqrt(Op) is actually computed, so additional postprocessing is needed if
/// Op can be zero.
SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
                                           bool Reciprocal) {
  if (LegalDAG)
    return SDValue();

  // TODO: Handle half and/or extended types?
  EVT VT = Op.getValueType();
  if (VT.getScalarType() != MVT::f32 && VT.getScalarType() != MVT::f64)
    return SDValue();

  // If estimates are explicitly disabled for this function, we're done.
  MachineFunction &MF = DAG.getMachineFunction();
  int Enabled = TLI.getRecipEstimateSqrtEnabled(VT, MF);
  if (Enabled == TLI.ReciprocalEstimate::Disabled)
    return SDValue();

  // Estimates may be explicitly enabled for this type with a custom number of
  // refinement steps.
  int Iterations = TLI.getSqrtRefinementSteps(VT, MF);

  bool UseOneConstNR = false;
  if (SDValue Est =
      TLI.getSqrtEstimate(Op, DAG, Enabled, Iterations, UseOneConstNR,
                          Reciprocal)) {
    AddToWorklist(Est.getNode());

    if (Iterations) {
      Est = UseOneConstNR
            ? buildSqrtNROneConst(Op, Est, Iterations, Flags, Reciprocal)
            : buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal);

      if (!Reciprocal) {
        // The estimate is now completely wrong if the input was exactly 0.0 or
        // possibly a denormal. Force the answer to 0.0 for those cases.
        SDLoc DL(Op);
        EVT CCVT = getSetCCResultType(VT);
        ISD::NodeType SelOpcode = VT.isVector() ? ISD::VSELECT : ISD::SELECT;
        DenormalMode DenormMode = DAG.getDenormalMode(VT);
        if (DenormMode == DenormalMode::IEEE) {
          // fabs(X) < SmallestNormal ? 0.0 : Est
          const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(VT);
          APFloat SmallestNorm = APFloat::getSmallestNormalized(FltSem);
          SDValue NormC = DAG.getConstantFP(SmallestNorm, DL, VT);
          SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
          SDValue Fabs = DAG.getNode(ISD::FABS, DL, VT, Op);
          SDValue IsDenorm = DAG.getSetCC(DL, CCVT, Fabs, NormC, ISD::SETLT);
          Est = DAG.getNode(SelOpcode, DL, VT, IsDenorm, FPZero, Est);
        } else {
          // X == 0.0 ? 0.0 : Est
          SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
          SDValue IsZero = DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
          Est = DAG.getNode(SelOpcode, DL, VT, IsZero, FPZero, Est);
        }
      }
    }
    return Est;
  }

  return SDValue();
}

SDValue DAGCombiner::buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags) {
  return buildSqrtEstimateImpl(Op, Flags, true);
}

SDValue DAGCombiner::buildSqrtEstimate(SDValue Op, SDNodeFlags Flags) {
  return buildSqrtEstimateImpl(Op, Flags, false);
}

/// Return true if there is any possibility that the two addresses overlap.
bool DAGCombiner::isAlias(SDNode *Op0, SDNode *Op1) const {

  struct MemUseCharacteristics {
    bool IsVolatile;
    bool IsAtomic;
    SDValue BasePtr;
    int64_t Offset;
    Optional<int64_t> NumBytes;
    MachineMemOperand *MMO;
  };

  auto getCharacteristics = [](SDNode *N) -> MemUseCharacteristics {
    if (const auto *LSN = dyn_cast<LSBaseSDNode>(N)) {
      int64_t Offset = 0;
      if (auto *C = dyn_cast<ConstantSDNode>(LSN->getOffset()))
        Offset = (LSN->getAddressingMode() == ISD::PRE_INC)
                     ? C->getSExtValue()
                     : (LSN->getAddressingMode() == ISD::PRE_DEC)
                           ? -1 * C->getSExtValue()
                           : 0;
      return {LSN->isVolatile(), LSN->isAtomic(), LSN->getBasePtr(),
              Offset /*base offset*/,
              Optional<int64_t>(LSN->getMemoryVT().getStoreSize()),
              LSN->getMemOperand()};
    }
    if (const auto *LN = cast<LifetimeSDNode>(N))
      return {false /*isVolatile*/, /*isAtomic*/ false, LN->getOperand(1),
              (LN->hasOffset()) ? LN->getOffset() : 0,
              (LN->hasOffset()) ? Optional<int64_t>(LN->getSize())
                                : Optional<int64_t>(),
              (MachineMemOperand *)nullptr};
    // Default.
    return {false /*isVolatile*/, /*isAtomic*/ false, SDValue(),
            (int64_t)0 /*offset*/,
            Optional<int64_t>() /*size*/, (MachineMemOperand *)nullptr};
  };

  MemUseCharacteristics MUC0 = getCharacteristics(Op0),
                        MUC1 = getCharacteristics(Op1);

  // If they are to the same address, then they must be aliases.
  if (MUC0.BasePtr.getNode() && MUC0.BasePtr == MUC1.BasePtr &&
      MUC0.Offset == MUC1.Offset)
    return true;

  // If they are both volatile then they cannot be reordered.
  if (MUC0.IsVolatile && MUC1.IsVolatile)
    return true;

  // Be conservative about atomics for the moment.
  // TODO: This is way overconservative for unordered atomics (see D66309)
  if (MUC0.IsAtomic && MUC1.IsAtomic)
    return true;

  if (MUC0.MMO && MUC1.MMO) {
    if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
        (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
      return false;
  }

  // Try to prove that there is aliasing, or that there is no aliasing. Either
  // way, we can return now. If nothing can be proved, proceed with more tests.
  bool IsAlias;
  if (BaseIndexOffset::computeAliasing(Op0, MUC0.NumBytes, Op1, MUC1.NumBytes,
                                       DAG, IsAlias))
    return IsAlias;

  // The following all rely on MMO0 and MMO1 being valid. Fail conservatively if
  // either are not known.
  if (!MUC0.MMO || !MUC1.MMO)
    return true;

  // If one operation reads from invariant memory, and the other may store, they
  // cannot alias. These should really be checking the equivalent of mayWrite,
  // but it only matters for memory nodes other than load/store.
  if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
      (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
    return false;

  // If we know that SrcValue1 and SrcValue2 have relatively large alignment
  // compared to the size and offset of the access, we may be able to prove
  // they do not alias. This check is conservative for now to catch cases
  // created by splitting vector types.
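  // For example (illustrative): two 4-byte accesses on 8-byte-aligned base
  // objects at source offsets 0 and 4 give OffAlign values of 0 and 4, so
  // 0 + 4 <= 4 proves they cannot overlap.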
  int64_t SrcValOffset0 = MUC0.MMO->getOffset();
  int64_t SrcValOffset1 = MUC1.MMO->getOffset();
  unsigned OrigAlignment0 = MUC0.MMO->getBaseAlignment();
  unsigned OrigAlignment1 = MUC1.MMO->getBaseAlignment();
  if (OrigAlignment0 == OrigAlignment1 && SrcValOffset0 != SrcValOffset1 &&
      MUC0.NumBytes.hasValue() && MUC1.NumBytes.hasValue() &&
      *MUC0.NumBytes == *MUC1.NumBytes && OrigAlignment0 > *MUC0.NumBytes) {
    int64_t OffAlign0 = SrcValOffset0 % OrigAlignment0;
    int64_t OffAlign1 = SrcValOffset1 % OrigAlignment1;

    // There is no overlap between these relatively aligned accesses of
    // similar size. Return no alias.
    if ((OffAlign0 + *MUC0.NumBytes) <= OffAlign1 ||
        (OffAlign1 + *MUC1.NumBytes) <= OffAlign0)
      return false;
  }

  bool UseAA = CombinerGlobalAA.getNumOccurrences() > 0
                   ? CombinerGlobalAA
                   : DAG.getSubtarget().useAA();
#ifndef NDEBUG
  if (CombinerAAOnlyFunc.getNumOccurrences() &&
      CombinerAAOnlyFunc != DAG.getMachineFunction().getName())
    UseAA = false;
#endif

  if (UseAA && AA && MUC0.MMO->getValue() && MUC1.MMO->getValue()) {
    // Use alias analysis information.
    int64_t MinOffset = std::min(SrcValOffset0, SrcValOffset1);
    int64_t Overlap0 = *MUC0.NumBytes + SrcValOffset0 - MinOffset;
    int64_t Overlap1 = *MUC1.NumBytes + SrcValOffset1 - MinOffset;
    AliasResult AAResult = AA->alias(
        MemoryLocation(MUC0.MMO->getValue(), Overlap0,
                       UseTBAA ? MUC0.MMO->getAAInfo() : AAMDNodes()),
        MemoryLocation(MUC1.MMO->getValue(), Overlap1,
                       UseTBAA ? MUC1.MMO->getAAInfo() : AAMDNodes()));
    if (AAResult == NoAlias)
      return false;
  }

  // Otherwise we have to assume they alias.
  return true;
}

/// Walk up chain skipping non-aliasing memory nodes,
/// looking for aliasing nodes and adding them to the Aliases vector.
void DAGCombiner::GatherAllAliases(SDNode *N, SDValue OriginalChain,
                                   SmallVectorImpl<SDValue> &Aliases) {
  SmallVector<SDValue, 8> Chains;     // List of chains to visit.
  SmallPtrSet<SDNode *, 16> Visited;  // Visited node set.

  // Get alias information for node.
  // TODO: relax aliasing for unordered atomics (see D66309)
  const bool IsLoad = isa<LoadSDNode>(N) && cast<LoadSDNode>(N)->isSimple();

  // Starting off.
  Chains.push_back(OriginalChain);
  unsigned Depth = 0;

  // Attempt to improve chain by a single step.
  std::function<bool(SDValue &)> ImproveChain = [&](SDValue &C) -> bool {
    switch (C.getOpcode()) {
    case ISD::EntryToken:
      // No need to mark EntryToken.
      C = SDValue();
      return true;
    case ISD::LOAD:
    case ISD::STORE: {
      // Get alias information for C.
      // TODO: Relax aliasing for unordered atomics (see D66309)
      bool IsOpLoad = isa<LoadSDNode>(C.getNode()) &&
                      cast<LSBaseSDNode>(C.getNode())->isSimple();
      if ((IsLoad && IsOpLoad) || !isAlias(N, C.getNode())) {
        // Look further up the chain.
        C = C.getOperand(0);
        return true;
      }
      // Alias, so stop here.
      return false;
    }

    case ISD::CopyFromReg:
      // Always forward past CopyFromReg.
      C = C.getOperand(0);
      return true;

    case ISD::LIFETIME_START:
    case ISD::LIFETIME_END: {
      // We can forward past any lifetime start/end that can be proven not to
      // alias the memory access.
      if (!isAlias(N, C.getNode())) {
        // Look further up the chain.
        C = C.getOperand(0);
        return true;
      }
      return false;
    }
    default:
      return false;
    }
  };

  // Look at each chain and determine if it is an alias.  If so, add it to the
  // aliases list.  If not, then continue up the chain looking for the next
  // candidate.
  while (!Chains.empty()) {
    SDValue Chain = Chains.pop_back_val();

    // Don't bother if we've seen Chain before.
    if (!Visited.insert(Chain.getNode()).second)
      continue;

    // For TokenFactor nodes, look at each operand and only continue up the
    // chain until we reach the depth limit.
    //
    // FIXME: The depth check could be made to return the last non-aliasing
    // chain we found before we hit a tokenfactor rather than the original
    // chain.
    if (Depth > TLI.getGatherAllAliasesMaxDepth()) {
      Aliases.clear();
      Aliases.push_back(OriginalChain);
      return;
    }

    if (Chain.getOpcode() == ISD::TokenFactor) {
      // We have to check each of the operands of the token factor for "small"
      // token factors, so we queue them up.  Adding the operands to the queue
      // (stack) in reverse order maintains the original order and increases the
      // likelihood that getNode will find a matching token factor (CSE).
      if (Chain.getNumOperands() > 16) {
        Aliases.push_back(Chain);
        continue;
      }
      for (unsigned n = Chain.getNumOperands(); n;)
        Chains.push_back(Chain.getOperand(--n));
      ++Depth;
      continue;
    }
    // Everything else.
    if (ImproveChain(Chain)) {
      // Updated chain found; consider the new chain if one exists.
      if (Chain.getNode())
        Chains.push_back(Chain);
      ++Depth;
      continue;
    }
    // No improved chain possible, treat as alias.
    Aliases.push_back(Chain);
  }
}

/// Walk up chain skipping non-aliasing memory nodes, looking for a better chain
/// (aliasing node).
SDValue DAGCombiner::FindBetterChain(SDNode *N, SDValue OldChain) {
  if (OptLevel == CodeGenOpt::None)
    return OldChain;

  // Ops for replacing token factor.
  SmallVector<SDValue, 8> Aliases;

  // Accumulate all the aliases to this node.
  GatherAllAliases(N, OldChain, Aliases);

  // If no operands then chain to entry token.
  if (Aliases.size() == 0)
    return DAG.getEntryNode();

  // If a single operand then chain to it.  We don't need to revisit it.
  if (Aliases.size() == 1)
    return Aliases[0];

  // Construct a custom tailored token factor.
  return DAG.getTokenFactor(SDLoc(N), Aliases);
}

namespace {
// TODO: Replace with std::monostate when we move to C++17.
struct UnitT { } Unit;
bool operator==(const UnitT &, const UnitT &) { return true; }
bool operator!=(const UnitT &, const UnitT &) { return false; }
} // namespace

// This function tries to collect a bunch of potentially interesting
// nodes to improve the chains of, all at once. This might seem
// redundant, as this function gets called when visiting every store
// node, so why not let the work be done on each store as it's visited?
//
// I believe this is mainly important because MergeConsecutiveStores
// is unable to deal with merging stores of different sizes, so unless
// we improve the chains of all the potential candidates up-front
// before running MergeConsecutiveStores, it might only see some of
// the nodes that will eventually be candidates, and then not be able
// to go from a partially-merged state to the desired final
// fully-merged state.

bool DAGCombiner::parallelizeChainedStores(StoreSDNode *St) {
  SmallVector<StoreSDNode *, 8> ChainedStores;
  StoreSDNode *STChain = St;
  // Intervals records which offsets from BaseIndex have been covered. In
  // the common case, every store writes immediately before the previous one,
  // so it is merged with the previous interval at insertion time.
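  // For example (illustrative): chained i32 stores at offsets 0, -4, and -8
  // should coalesce into the single interval [-8, 4) as they are inserted.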

  using IMap =
      llvm::IntervalMap<int64_t, UnitT, 8, IntervalMapHalfOpenInfo<int64_t>>;
  IMap::Allocator A;
  IMap Intervals(A);

  // This holds the base pointer, index, and the offset in bytes from the base
  // pointer.
  const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);

  // We must have a base and an offset.
  if (!BasePtr.getBase().getNode())
    return false;

  // Do not handle stores to undef base pointers.
  if (BasePtr.getBase().isUndef())
    return false;

  // Add ST's interval.
  Intervals.insert(0, (St->getMemoryVT().getSizeInBits() + 7) / 8, Unit);

  while (StoreSDNode *Chain = dyn_cast<StoreSDNode>(STChain->getChain())) {
    // If the chain has more than one use, then we can't reorder the mem ops.
    if (!SDValue(Chain, 0)->hasOneUse())
      break;
    // TODO: Relax for unordered atomics (see D66309)
    if (!Chain->isSimple() || Chain->isIndexed())
      break;

    // Find the base pointer and offset for this memory node.
    const BaseIndexOffset Ptr = BaseIndexOffset::match(Chain, DAG);
    // Check that the base pointer is the same as the original one.
    int64_t Offset;
    if (!BasePtr.equalBaseIndex(Ptr, DAG, Offset))
      break;
    int64_t Length = (Chain->getMemoryVT().getSizeInBits() + 7) / 8;
    // Make sure we don't overlap with other intervals by checking the ones to
    // the left or right before inserting.
    auto I = Intervals.find(Offset);
    // If there's a next interval, we should end before it.
    if (I != Intervals.end() && I.start() < (Offset + Length))
      break;
    // If there's a previous interval, we should start after it.
    if (I != Intervals.begin() && (--I).stop() <= Offset)
      break;
    Intervals.insert(Offset, Offset + Length, Unit);

    ChainedStores.push_back(Chain);
    STChain = Chain;
  }

  // If we didn't find a chained store, exit.
  if (ChainedStores.size() == 0)
    return false;

  // Improve all chained stores (St and ChainedStores members) starting from
  // where the store chain ended, and return a single TokenFactor.
  SDValue NewChain = STChain->getChain();
  SmallVector<SDValue, 8> TFOps;
  for (unsigned I = ChainedStores.size(); I;) {
    StoreSDNode *S = ChainedStores[--I];
    SDValue BetterChain = FindBetterChain(S, NewChain);
    S = cast<StoreSDNode>(DAG.UpdateNodeOperands(
        S, BetterChain, S->getOperand(1), S->getOperand(2), S->getOperand(3)));
    TFOps.push_back(SDValue(S, 0));
    ChainedStores[I] = S;
  }

  // Improve St's chain. Use a new node to avoid creating a loop from CombineTo.
  SDValue BetterChain = FindBetterChain(St, NewChain);
  SDValue NewST;
  if (St->isTruncatingStore())
    NewST = DAG.getTruncStore(BetterChain, SDLoc(St), St->getValue(),
                              St->getBasePtr(), St->getMemoryVT(),
                              St->getMemOperand());
  else
    NewST = DAG.getStore(BetterChain, SDLoc(St), St->getValue(),
                         St->getBasePtr(), St->getMemOperand());

  TFOps.push_back(NewST);

  // If we improved every element of TFOps, then we've lost the dependence on
  // NewChain to successors of St and we need to add it back to TFOps. Do so at
  // the beginning to keep relative order consistent with FindBetterChains.
  auto hasImprovedChain = [&](SDValue ST) -> bool {
    return ST->getOperand(0) != NewChain;
  };
  bool AddNewChain = llvm::all_of(TFOps, hasImprovedChain);
  if (AddNewChain)
    TFOps.insert(TFOps.begin(), NewChain);

  SDValue TF = DAG.getTokenFactor(SDLoc(STChain), TFOps);
  CombineTo(St, TF);

  // Add TF and its operands to the worklist.
  AddToWorklist(TF.getNode());
  for (const SDValue &Op : TF->ops())
    AddToWorklist(Op.getNode());
  AddToWorklist(STChain);
  return true;
}

bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) {
  if (OptLevel == CodeGenOpt::None)
    return false;

  const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);

  // We must have a base and an offset.
  if (!BasePtr.getBase().getNode())
    return false;

  // Do not handle stores to undef base pointers.
  if (BasePtr.getBase().isUndef())
    return false;

  // Directly improve a chain of disjoint stores starting at St.
  if (parallelizeChainedStores(St))
    return true;

  // Improve St's chain.
  SDValue BetterChain = FindBetterChain(St, St->getChain());
  if (St->getChain() != BetterChain) {
    replaceStoreChain(St, BetterChain);
    return true;
  }
  return false;
}

/// This is the entry point for the file.
void SelectionDAG::Combine(CombineLevel Level, AliasAnalysis *AA,
                           CodeGenOpt::Level OptLevel) {
  /// This is the main entry point to this class.
  DAGCombiner(*this, AA, OptLevel).Run(Level);
}