1 //===-- RISCVISelLowering.cpp - RISC-V DAG Lowering Implementation  -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the interfaces that RISC-V uses to lower LLVM code into a
10 // selection DAG.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "RISCVISelLowering.h"
15 #include "MCTargetDesc/RISCVMatInt.h"
16 #include "RISCV.h"
17 #include "RISCVMachineFunctionInfo.h"
18 #include "RISCVRegisterInfo.h"
19 #include "RISCVSubtarget.h"
20 #include "RISCVTargetMachine.h"
21 #include "llvm/ADT/SmallSet.h"
22 #include "llvm/ADT/Statistic.h"
23 #include "llvm/Analysis/MemoryLocation.h"
24 #include "llvm/Analysis/VectorUtils.h"
25 #include "llvm/CodeGen/MachineFrameInfo.h"
26 #include "llvm/CodeGen/MachineFunction.h"
27 #include "llvm/CodeGen/MachineInstrBuilder.h"
28 #include "llvm/CodeGen/MachineJumpTableInfo.h"
29 #include "llvm/CodeGen/MachineRegisterInfo.h"
30 #include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
31 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
32 #include "llvm/CodeGen/ValueTypes.h"
33 #include "llvm/IR/DiagnosticInfo.h"
34 #include "llvm/IR/DiagnosticPrinter.h"
35 #include "llvm/IR/IRBuilder.h"
36 #include "llvm/IR/Instructions.h"
37 #include "llvm/IR/IntrinsicsRISCV.h"
38 #include "llvm/IR/PatternMatch.h"
39 #include "llvm/Support/CommandLine.h"
40 #include "llvm/Support/Debug.h"
41 #include "llvm/Support/ErrorHandling.h"
42 #include "llvm/Support/InstructionCost.h"
43 #include "llvm/Support/KnownBits.h"
44 #include "llvm/Support/MathExtras.h"
45 #include "llvm/Support/raw_ostream.h"
46 #include <optional>
47 
48 using namespace llvm;
49 
50 #define DEBUG_TYPE "riscv-lower"
51 
52 STATISTIC(NumTailCalls, "Number of tail calls");
53 
54 static cl::opt<unsigned> ExtensionMaxWebSize(
55     DEBUG_TYPE "-ext-max-web-size", cl::Hidden,
56     cl::desc("Give the maximum size (in number of nodes) of the web of "
57              "instructions that we will consider for VW expansion"),
58     cl::init(18));
59 
60 static cl::opt<bool>
61     AllowSplatInVW_W(DEBUG_TYPE "-form-vw-w-with-splat", cl::Hidden,
62                      cl::desc("Allow the formation of VW_W operations (e.g., "
63                               "VWADD_W) with splat constants"),
64                      cl::init(false));
65 
66 static cl::opt<unsigned> NumRepeatedDivisors(
67     DEBUG_TYPE "-fp-repeated-divisors", cl::Hidden,
68     cl::desc("Set the minimum number of repetitions of a divisor to allow "
69              "transformation to multiplications by the reciprocal"),
70     cl::init(2));
71 
72 static cl::opt<int>
73     FPImmCost(DEBUG_TYPE "-fpimm-cost", cl::Hidden,
74               cl::desc("Give the maximum number of instructions that we will "
75                        "use for creating a floating-point immediate value"),
76               cl::init(2));
77 
78 static cl::opt<bool>
79     RV64LegalI32("riscv-experimental-rv64-legal-i32", cl::ReallyHidden,
80                  cl::desc("Make i32 a legal type for SelectionDAG on RV64."));
81 
82 RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
83                                          const RISCVSubtarget &STI)
84     : TargetLowering(TM), Subtarget(STI) {
85 
86   RISCVABI::ABI ABI = Subtarget.getTargetABI();
87   assert(ABI != RISCVABI::ABI_Unknown && "Improperly initialised target ABI");
88 
89   if ((ABI == RISCVABI::ABI_ILP32F || ABI == RISCVABI::ABI_LP64F) &&
90       !Subtarget.hasStdExtF()) {
91     errs() << "Hard-float 'f' ABI can't be used for a target that "
92               "doesn't support the F instruction set extension (ignoring "
93               "target-abi)\n";
94     ABI = Subtarget.is64Bit() ? RISCVABI::ABI_LP64 : RISCVABI::ABI_ILP32;
95   } else if ((ABI == RISCVABI::ABI_ILP32D || ABI == RISCVABI::ABI_LP64D) &&
96              !Subtarget.hasStdExtD()) {
97     errs() << "Hard-float 'd' ABI can't be used for a target that "
98               "doesn't support the D instruction set extension (ignoring "
99               "target-abi)\n";
100     ABI = Subtarget.is64Bit() ? RISCVABI::ABI_LP64 : RISCVABI::ABI_ILP32;
101   }
102 
103   switch (ABI) {
104   default:
105     report_fatal_error("Don't know how to lower this ABI");
106   case RISCVABI::ABI_ILP32:
107   case RISCVABI::ABI_ILP32E:
108   case RISCVABI::ABI_LP64E:
109   case RISCVABI::ABI_ILP32F:
110   case RISCVABI::ABI_ILP32D:
111   case RISCVABI::ABI_LP64:
112   case RISCVABI::ABI_LP64F:
113   case RISCVABI::ABI_LP64D:
114     break;
115   }
116 
117   MVT XLenVT = Subtarget.getXLenVT();
118 
119   // Set up the register classes.
120   addRegisterClass(XLenVT, &RISCV::GPRRegClass);
121   if (Subtarget.is64Bit() && RV64LegalI32)
122     addRegisterClass(MVT::i32, &RISCV::GPRRegClass);
123 
124   if (Subtarget.hasStdExtZfhmin())
125     addRegisterClass(MVT::f16, &RISCV::FPR16RegClass);
126   if (Subtarget.hasStdExtZfbfmin())
127     addRegisterClass(MVT::bf16, &RISCV::FPR16RegClass);
128   if (Subtarget.hasStdExtF())
129     addRegisterClass(MVT::f32, &RISCV::FPR32RegClass);
130   if (Subtarget.hasStdExtD())
131     addRegisterClass(MVT::f64, &RISCV::FPR64RegClass);
132   if (Subtarget.hasStdExtZhinxmin())
133     addRegisterClass(MVT::f16, &RISCV::GPRF16RegClass);
134   if (Subtarget.hasStdExtZfinx())
135     addRegisterClass(MVT::f32, &RISCV::GPRF32RegClass);
136   if (Subtarget.hasStdExtZdinx()) {
137     if (Subtarget.is64Bit())
138       addRegisterClass(MVT::f64, &RISCV::GPRRegClass);
139     else
140       addRegisterClass(MVT::f64, &RISCV::GPRPairRegClass);
141   }
142 
143   static const MVT::SimpleValueType BoolVecVTs[] = {
144       MVT::nxv1i1,  MVT::nxv2i1,  MVT::nxv4i1, MVT::nxv8i1,
145       MVT::nxv16i1, MVT::nxv32i1, MVT::nxv64i1};
146   static const MVT::SimpleValueType IntVecVTs[] = {
147       MVT::nxv1i8,  MVT::nxv2i8,   MVT::nxv4i8,   MVT::nxv8i8,  MVT::nxv16i8,
148       MVT::nxv32i8, MVT::nxv64i8,  MVT::nxv1i16,  MVT::nxv2i16, MVT::nxv4i16,
149       MVT::nxv8i16, MVT::nxv16i16, MVT::nxv32i16, MVT::nxv1i32, MVT::nxv2i32,
150       MVT::nxv4i32, MVT::nxv8i32,  MVT::nxv16i32, MVT::nxv1i64, MVT::nxv2i64,
151       MVT::nxv4i64, MVT::nxv8i64};
152   static const MVT::SimpleValueType F16VecVTs[] = {
153       MVT::nxv1f16, MVT::nxv2f16,  MVT::nxv4f16,
154       MVT::nxv8f16, MVT::nxv16f16, MVT::nxv32f16};
155   static const MVT::SimpleValueType BF16VecVTs[] = {
156       MVT::nxv1bf16, MVT::nxv2bf16,  MVT::nxv4bf16,
157       MVT::nxv8bf16, MVT::nxv16bf16, MVT::nxv32bf16};
158   static const MVT::SimpleValueType F32VecVTs[] = {
159       MVT::nxv1f32, MVT::nxv2f32, MVT::nxv4f32, MVT::nxv8f32, MVT::nxv16f32};
160   static const MVT::SimpleValueType F64VecVTs[] = {
161       MVT::nxv1f64, MVT::nxv2f64, MVT::nxv4f64, MVT::nxv8f64};
162 
163   if (Subtarget.hasVInstructions()) {
164     auto addRegClassForRVV = [this](MVT VT) {
165       // Disable the smallest fractional LMUL types if ELEN is less than
166       // RVVBitsPerBlock.
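      // (Illustrative example: with ELEN=32, MinElts below is 64/32 = 2, so
      // the nxv1 vector types, whose minimum element count is 1, are skipped
      // and never get a register class.)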
167       unsigned MinElts = RISCV::RVVBitsPerBlock / Subtarget.getELen();
168       if (VT.getVectorMinNumElements() < MinElts)
169         return;
170 
171       unsigned Size = VT.getSizeInBits().getKnownMinValue();
172       const TargetRegisterClass *RC;
173       if (Size <= RISCV::RVVBitsPerBlock)
174         RC = &RISCV::VRRegClass;
175       else if (Size == 2 * RISCV::RVVBitsPerBlock)
176         RC = &RISCV::VRM2RegClass;
177       else if (Size == 4 * RISCV::RVVBitsPerBlock)
178         RC = &RISCV::VRM4RegClass;
179       else if (Size == 8 * RISCV::RVVBitsPerBlock)
180         RC = &RISCV::VRM8RegClass;
181       else
182         llvm_unreachable("Unexpected size");
183 
184       addRegisterClass(VT, RC);
185     };
186 
187     for (MVT VT : BoolVecVTs)
188       addRegClassForRVV(VT);
189     for (MVT VT : IntVecVTs) {
190       if (VT.getVectorElementType() == MVT::i64 &&
191           !Subtarget.hasVInstructionsI64())
192         continue;
193       addRegClassForRVV(VT);
194     }
195 
196     if (Subtarget.hasVInstructionsF16Minimal())
197       for (MVT VT : F16VecVTs)
198         addRegClassForRVV(VT);
199 
200     if (Subtarget.hasVInstructionsBF16())
201       for (MVT VT : BF16VecVTs)
202         addRegClassForRVV(VT);
203 
204     if (Subtarget.hasVInstructionsF32())
205       for (MVT VT : F32VecVTs)
206         addRegClassForRVV(VT);
207 
208     if (Subtarget.hasVInstructionsF64())
209       for (MVT VT : F64VecVTs)
210         addRegClassForRVV(VT);
211 
212     if (Subtarget.useRVVForFixedLengthVectors()) {
213       auto addRegClassForFixedVectors = [this](MVT VT) {
214         MVT ContainerVT = getContainerForFixedLengthVector(VT);
215         unsigned RCID = getRegClassIDForVecVT(ContainerVT);
216         const RISCVRegisterInfo &TRI = *Subtarget.getRegisterInfo();
217         addRegisterClass(VT, TRI.getRegClass(RCID));
218       };
219       for (MVT VT : MVT::integer_fixedlen_vector_valuetypes())
220         if (useRVVForFixedLengthVectorVT(VT))
221           addRegClassForFixedVectors(VT);
222 
223       for (MVT VT : MVT::fp_fixedlen_vector_valuetypes())
224         if (useRVVForFixedLengthVectorVT(VT))
225           addRegClassForFixedVectors(VT);
226     }
227   }
228 
229   // Compute derived properties from the register classes.
230   computeRegisterProperties(STI.getRegisterInfo());
231 
232   setStackPointerRegisterToSaveRestore(RISCV::X2);
233 
234   setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, XLenVT,
235                    MVT::i1, Promote);
236   // DAGCombiner can call isLoadExtLegal for types that aren't legal.
237   setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, MVT::i32,
238                    MVT::i1, Promote);
239 
240   // TODO: add all necessary setOperationAction calls.
241   setOperationAction(ISD::DYNAMIC_STACKALLOC, XLenVT, Expand);
242 
243   setOperationAction(ISD::BR_JT, MVT::Other, Expand);
244   setOperationAction(ISD::BR_CC, XLenVT, Expand);
245   if (RV64LegalI32 && Subtarget.is64Bit())
246     setOperationAction(ISD::BR_CC, MVT::i32, Expand);
247   setOperationAction(ISD::BRCOND, MVT::Other, Custom);
248   setOperationAction(ISD::SELECT_CC, XLenVT, Expand);
249   if (RV64LegalI32 && Subtarget.is64Bit())
250     setOperationAction(ISD::SELECT_CC, MVT::i32, Expand);
251 
252   setCondCodeAction(ISD::SETLE, XLenVT, Expand);
253   setCondCodeAction(ISD::SETGT, XLenVT, Custom);
254   setCondCodeAction(ISD::SETGE, XLenVT, Expand);
255   setCondCodeAction(ISD::SETULE, XLenVT, Expand);
256   setCondCodeAction(ISD::SETUGT, XLenVT, Custom);
257   setCondCodeAction(ISD::SETUGE, XLenVT, Expand);
258 
259   if (RV64LegalI32 && Subtarget.is64Bit())
260     setOperationAction(ISD::SETCC, MVT::i32, Promote);
261 
262   setOperationAction({ISD::STACKSAVE, ISD::STACKRESTORE}, MVT::Other, Expand);
263 
264   setOperationAction(ISD::VASTART, MVT::Other, Custom);
265   setOperationAction({ISD::VAARG, ISD::VACOPY, ISD::VAEND}, MVT::Other, Expand);
266 
267   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand);
268 
269   setOperationAction(ISD::EH_DWARF_CFA, MVT::i32, Custom);
270 
271   if (!Subtarget.hasStdExtZbb() && !Subtarget.hasVendorXTHeadBb())
272     setOperationAction(ISD::SIGN_EXTEND_INREG, {MVT::i8, MVT::i16}, Expand);
273 
274   if (Subtarget.is64Bit()) {
275     setOperationAction(ISD::EH_DWARF_CFA, MVT::i64, Custom);
276 
277     if (!RV64LegalI32) {
278       setOperationAction(ISD::LOAD, MVT::i32, Custom);
279       setOperationAction({ISD::ADD, ISD::SUB, ISD::SHL, ISD::SRA, ISD::SRL},
280                          MVT::i32, Custom);
281       setOperationAction(ISD::SADDO, MVT::i32, Custom);
282       setOperationAction({ISD::UADDO, ISD::USUBO, ISD::UADDSAT, ISD::USUBSAT},
283                          MVT::i32, Custom);
284     }
285   } else {
286     setLibcallName(
287         {RTLIB::SHL_I128, RTLIB::SRL_I128, RTLIB::SRA_I128, RTLIB::MUL_I128},
288         nullptr);
289     setLibcallName(RTLIB::MULO_I64, nullptr);
290   }
291 
292   if (!Subtarget.hasStdExtM() && !Subtarget.hasStdExtZmmul()) {
293     setOperationAction({ISD::MUL, ISD::MULHS, ISD::MULHU}, XLenVT, Expand);
294     if (RV64LegalI32 && Subtarget.is64Bit())
295       setOperationAction(ISD::MUL, MVT::i32, Promote);
296   } else if (Subtarget.is64Bit()) {
297     setOperationAction(ISD::MUL, MVT::i128, Custom);
298     if (!RV64LegalI32)
299       setOperationAction(ISD::MUL, MVT::i32, Custom);
300   } else {
301     setOperationAction(ISD::MUL, MVT::i64, Custom);
302   }
303 
304   if (!Subtarget.hasStdExtM()) {
305     setOperationAction({ISD::SDIV, ISD::UDIV, ISD::SREM, ISD::UREM},
306                        XLenVT, Expand);
307     if (RV64LegalI32 && Subtarget.is64Bit())
308       setOperationAction({ISD::SDIV, ISD::UDIV, ISD::SREM, ISD::UREM}, MVT::i32,
309                          Promote);
310   } else if (Subtarget.is64Bit()) {
311     if (!RV64LegalI32)
312       setOperationAction({ISD::SDIV, ISD::UDIV, ISD::UREM},
313                          {MVT::i8, MVT::i16, MVT::i32}, Custom);
314   }
315 
316   if (RV64LegalI32 && Subtarget.is64Bit()) {
317     setOperationAction({ISD::MULHS, ISD::MULHU}, MVT::i32, Expand);
318     setOperationAction(
319         {ISD::SDIVREM, ISD::UDIVREM, ISD::SMUL_LOHI, ISD::UMUL_LOHI}, MVT::i32,
320         Expand);
321   }
322 
323   setOperationAction(
324       {ISD::SDIVREM, ISD::UDIVREM, ISD::SMUL_LOHI, ISD::UMUL_LOHI}, XLenVT,
325       Expand);
326 
327   setOperationAction({ISD::SHL_PARTS, ISD::SRL_PARTS, ISD::SRA_PARTS}, XLenVT,
328                      Custom);
329 
330   if (Subtarget.hasStdExtZbb() || Subtarget.hasStdExtZbkb()) {
331     if (!RV64LegalI32 && Subtarget.is64Bit())
332       setOperationAction({ISD::ROTL, ISD::ROTR}, MVT::i32, Custom);
333   } else if (Subtarget.hasVendorXTHeadBb()) {
334     if (Subtarget.is64Bit())
335       setOperationAction({ISD::ROTL, ISD::ROTR}, MVT::i32, Custom);
336     setOperationAction({ISD::ROTL, ISD::ROTR}, XLenVT, Custom);
337   } else if (Subtarget.hasVendorXCVbitmanip()) {
338     setOperationAction(ISD::ROTL, XLenVT, Expand);
339   } else {
340     setOperationAction({ISD::ROTL, ISD::ROTR}, XLenVT, Expand);
341     if (RV64LegalI32 && Subtarget.is64Bit())
342       setOperationAction({ISD::ROTL, ISD::ROTR}, MVT::i32, Expand);
343   }
344 
345   // With Zbb we have an XLen rev8 instruction, but not GREVI. So we'll
346   // pattern match it directly in isel.
347   setOperationAction(ISD::BSWAP, XLenVT,
348                      (Subtarget.hasStdExtZbb() || Subtarget.hasStdExtZbkb() ||
349                       Subtarget.hasVendorXTHeadBb())
350                          ? Legal
351                          : Expand);
352   if (RV64LegalI32 && Subtarget.is64Bit())
353     setOperationAction(ISD::BSWAP, MVT::i32,
354                        (Subtarget.hasStdExtZbb() || Subtarget.hasStdExtZbkb() ||
355                         Subtarget.hasVendorXTHeadBb())
356                            ? Promote
357                            : Expand);
358 
359 
360   if (Subtarget.hasVendorXCVbitmanip()) {
361     setOperationAction(ISD::BITREVERSE, XLenVT, Legal);
362   } else {
363     // Zbkb can use rev8+brev8 to implement bitreverse.
364     setOperationAction(ISD::BITREVERSE, XLenVT,
365                        Subtarget.hasStdExtZbkb() ? Custom : Expand);
366   }
367 
368   if (Subtarget.hasStdExtZbb()) {
369     setOperationAction({ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX}, XLenVT,
370                        Legal);
371     if (RV64LegalI32 && Subtarget.is64Bit())
372       setOperationAction({ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX}, MVT::i32,
373                          Promote);
374 
375     if (Subtarget.is64Bit()) {
376       if (RV64LegalI32)
377         setOperationAction(ISD::CTTZ, MVT::i32, Legal);
378       else
379         setOperationAction({ISD::CTTZ, ISD::CTTZ_ZERO_UNDEF}, MVT::i32, Custom);
380     }
381   } else if (!Subtarget.hasVendorXCVbitmanip()) {
382     setOperationAction({ISD::CTTZ, ISD::CTPOP}, XLenVT, Expand);
383     if (RV64LegalI32 && Subtarget.is64Bit())
384       setOperationAction({ISD::CTTZ, ISD::CTPOP}, MVT::i32, Expand);
385   }
386 
387   if (Subtarget.hasStdExtZbb() || Subtarget.hasVendorXTHeadBb() ||
388       Subtarget.hasVendorXCVbitmanip()) {
389     // We need the custom lowering to make sure that the resulting sequence
390     // for the 32-bit case is efficient on 64-bit targets.
391     if (Subtarget.is64Bit()) {
392       if (RV64LegalI32) {
393         setOperationAction(ISD::CTLZ, MVT::i32,
394                            Subtarget.hasStdExtZbb() ? Legal : Promote);
395         if (!Subtarget.hasStdExtZbb())
396           setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i32, Promote);
397       } else
398         setOperationAction({ISD::CTLZ, ISD::CTLZ_ZERO_UNDEF}, MVT::i32, Custom);
399     }
400   } else {
401     setOperationAction(ISD::CTLZ, XLenVT, Expand);
402     if (RV64LegalI32 && Subtarget.is64Bit())
403       setOperationAction(ISD::CTLZ, MVT::i32, Expand);
404   }
405 
406   if (!RV64LegalI32 && Subtarget.is64Bit() &&
407       !Subtarget.hasShortForwardBranchOpt())
408     setOperationAction(ISD::ABS, MVT::i32, Custom);
409 
410   // We can use PseudoCCSUB to implement ABS.
411   if (Subtarget.hasShortForwardBranchOpt())
412     setOperationAction(ISD::ABS, XLenVT, Legal);
413 
414   if (!Subtarget.hasVendorXTHeadCondMov())
415     setOperationAction(ISD::SELECT, XLenVT, Custom);
416 
417   if (RV64LegalI32 && Subtarget.is64Bit())
418     setOperationAction(ISD::SELECT, MVT::i32, Promote);
419 
420   static const unsigned FPLegalNodeTypes[] = {
421       ISD::FMINNUM,        ISD::FMAXNUM,       ISD::LRINT,
422       ISD::LLRINT,         ISD::LROUND,        ISD::LLROUND,
423       ISD::STRICT_LRINT,   ISD::STRICT_LLRINT, ISD::STRICT_LROUND,
424       ISD::STRICT_LLROUND, ISD::STRICT_FMA,    ISD::STRICT_FADD,
425       ISD::STRICT_FSUB,    ISD::STRICT_FMUL,   ISD::STRICT_FDIV,
426       ISD::STRICT_FSQRT,   ISD::STRICT_FSETCC, ISD::STRICT_FSETCCS};
427 
428   static const ISD::CondCode FPCCToExpand[] = {
429       ISD::SETOGT, ISD::SETOGE, ISD::SETONE, ISD::SETUEQ, ISD::SETUGT,
430       ISD::SETUGE, ISD::SETULT, ISD::SETULE, ISD::SETUNE, ISD::SETGT,
431       ISD::SETGE,  ISD::SETNE,  ISD::SETO,   ISD::SETUO};
432 
433   static const unsigned FPOpToExpand[] = {
434       ISD::FSIN, ISD::FCOS,       ISD::FSINCOS,   ISD::FPOW,
435       ISD::FREM};
436 
437   static const unsigned FPRndMode[] = {
438       ISD::FCEIL, ISD::FFLOOR, ISD::FTRUNC, ISD::FRINT, ISD::FROUND,
439       ISD::FROUNDEVEN};
440 
441   if (Subtarget.hasStdExtZfhminOrZhinxmin())
442     setOperationAction(ISD::BITCAST, MVT::i16, Custom);
443 
444   static const unsigned ZfhminZfbfminPromoteOps[] = {
445       ISD::FMINNUM,      ISD::FMAXNUM,       ISD::FADD,
446       ISD::FSUB,         ISD::FMUL,          ISD::FMA,
447       ISD::FDIV,         ISD::FSQRT,         ISD::FABS,
448       ISD::FNEG,         ISD::STRICT_FMA,    ISD::STRICT_FADD,
449       ISD::STRICT_FSUB,  ISD::STRICT_FMUL,   ISD::STRICT_FDIV,
450       ISD::STRICT_FSQRT, ISD::STRICT_FSETCC, ISD::STRICT_FSETCCS,
451       ISD::SETCC,        ISD::FCEIL,         ISD::FFLOOR,
452       ISD::FTRUNC,       ISD::FRINT,         ISD::FROUND,
453       ISD::FROUNDEVEN,   ISD::SELECT};
454 
455   if (Subtarget.hasStdExtZfbfmin()) {
456     setOperationAction(ISD::BITCAST, MVT::i16, Custom);
457     setOperationAction(ISD::BITCAST, MVT::bf16, Custom);
458     setOperationAction(ISD::FP_ROUND, MVT::bf16, Custom);
459     setOperationAction(ISD::FP_EXTEND, MVT::f32, Custom);
460     setOperationAction(ISD::FP_EXTEND, MVT::f64, Custom);
461     setOperationAction(ISD::ConstantFP, MVT::bf16, Expand);
462     setOperationAction(ISD::SELECT_CC, MVT::bf16, Expand);
463     setOperationAction(ISD::BR_CC, MVT::bf16, Expand);
464     setOperationAction(ZfhminZfbfminPromoteOps, MVT::bf16, Promote);
465     setOperationAction(ISD::FREM, MVT::bf16, Promote);
466     // FIXME: Need to promote bf16 FCOPYSIGN to f32, but the
467     // DAGCombiner::visitFP_ROUND probably needs improvements first.
468     setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Expand);
469   }
470 
471   if (Subtarget.hasStdExtZfhminOrZhinxmin()) {
472     if (Subtarget.hasStdExtZfhOrZhinx()) {
473       setOperationAction(FPLegalNodeTypes, MVT::f16, Legal);
474       setOperationAction(FPRndMode, MVT::f16,
475                          Subtarget.hasStdExtZfa() ? Legal : Custom);
476       setOperationAction(ISD::SELECT, MVT::f16, Custom);
477       setOperationAction(ISD::IS_FPCLASS, MVT::f16, Custom);
478     } else {
479       setOperationAction(ZfhminZfbfminPromoteOps, MVT::f16, Promote);
480       setOperationAction({ISD::STRICT_LRINT, ISD::STRICT_LLRINT,
481                           ISD::STRICT_LROUND, ISD::STRICT_LLROUND},
482                          MVT::f16, Legal);
483       // FIXME: Need to promote f16 FCOPYSIGN to f32, but the
484       // DAGCombiner::visitFP_ROUND probably needs improvements first.
485       setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand);
486     }
487 
488     setOperationAction(ISD::STRICT_FP_ROUND, MVT::f16, Legal);
489     setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f32, Legal);
490     setCondCodeAction(FPCCToExpand, MVT::f16, Expand);
491     setOperationAction(ISD::SELECT_CC, MVT::f16, Expand);
492     setOperationAction(ISD::BR_CC, MVT::f16, Expand);
493 
494     setOperationAction(ISD::FNEARBYINT, MVT::f16,
495                        Subtarget.hasStdExtZfa() ? Legal : Promote);
496     setOperationAction({ISD::FREM, ISD::FPOW, ISD::FPOWI,
497                         ISD::FCOS, ISD::FSIN, ISD::FSINCOS, ISD::FEXP,
498                         ISD::FEXP2, ISD::FEXP10, ISD::FLOG, ISD::FLOG2,
499                         ISD::FLOG10},
500                        MVT::f16, Promote);
501 
502     // FIXME: Need to promote f16 STRICT_* to f32 libcalls, but we don't have
503     // complete support for all operations in LegalizeDAG.
504     setOperationAction({ISD::STRICT_FCEIL, ISD::STRICT_FFLOOR,
505                         ISD::STRICT_FNEARBYINT, ISD::STRICT_FRINT,
506                         ISD::STRICT_FROUND, ISD::STRICT_FROUNDEVEN,
507                         ISD::STRICT_FTRUNC},
508                        MVT::f16, Promote);
509 
510     // We need to custom promote this.
511     if (Subtarget.is64Bit())
512       setOperationAction(ISD::FPOWI, MVT::i32, Custom);
513 
514     if (!Subtarget.hasStdExtZfa())
515       setOperationAction({ISD::FMAXIMUM, ISD::FMINIMUM}, MVT::f16, Custom);
516   }
517 
518   if (Subtarget.hasStdExtFOrZfinx()) {
519     setOperationAction(FPLegalNodeTypes, MVT::f32, Legal);
520     setOperationAction(FPRndMode, MVT::f32,
521                        Subtarget.hasStdExtZfa() ? Legal : Custom);
522     setCondCodeAction(FPCCToExpand, MVT::f32, Expand);
523     setOperationAction(ISD::SELECT_CC, MVT::f32, Expand);
524     setOperationAction(ISD::SELECT, MVT::f32, Custom);
525     setOperationAction(ISD::BR_CC, MVT::f32, Expand);
526     setOperationAction(FPOpToExpand, MVT::f32, Expand);
527     setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
528     setTruncStoreAction(MVT::f32, MVT::f16, Expand);
529     setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand);
530     setTruncStoreAction(MVT::f32, MVT::bf16, Expand);
531     setOperationAction(ISD::IS_FPCLASS, MVT::f32, Custom);
532     setOperationAction(ISD::BF16_TO_FP, MVT::f32, Custom);
533     setOperationAction(ISD::FP_TO_BF16, MVT::f32,
534                        Subtarget.isSoftFPABI() ? LibCall : Custom);
535     setOperationAction(ISD::FP_TO_FP16, MVT::f32, Custom);
536     setOperationAction(ISD::FP16_TO_FP, MVT::f32, Custom);
537 
538     if (Subtarget.hasStdExtZfa())
539       setOperationAction(ISD::FNEARBYINT, MVT::f32, Legal);
540     else
541       setOperationAction({ISD::FMAXIMUM, ISD::FMINIMUM}, MVT::f32, Custom);
542   }
543 
544   if (Subtarget.hasStdExtFOrZfinx() && Subtarget.is64Bit())
545     setOperationAction(ISD::BITCAST, MVT::i32, Custom);
546 
547   if (Subtarget.hasStdExtDOrZdinx()) {
548     setOperationAction(FPLegalNodeTypes, MVT::f64, Legal);
549 
550     if (Subtarget.hasStdExtZfa()) {
551       setOperationAction(FPRndMode, MVT::f64, Legal);
552       setOperationAction(ISD::FNEARBYINT, MVT::f64, Legal);
553       setOperationAction(ISD::BITCAST, MVT::i64, Custom);
554       setOperationAction(ISD::BITCAST, MVT::f64, Custom);
555     } else {
556       if (Subtarget.is64Bit())
557         setOperationAction(FPRndMode, MVT::f64, Custom);
558 
559       setOperationAction({ISD::FMAXIMUM, ISD::FMINIMUM}, MVT::f64, Custom);
560     }
561 
562     setOperationAction(ISD::STRICT_FP_ROUND, MVT::f32, Legal);
563     setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f64, Legal);
564     setCondCodeAction(FPCCToExpand, MVT::f64, Expand);
565     setOperationAction(ISD::SELECT_CC, MVT::f64, Expand);
566     setOperationAction(ISD::SELECT, MVT::f64, Custom);
567     setOperationAction(ISD::BR_CC, MVT::f64, Expand);
568     setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
569     setTruncStoreAction(MVT::f64, MVT::f32, Expand);
570     setOperationAction(FPOpToExpand, MVT::f64, Expand);
571     setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
572     setTruncStoreAction(MVT::f64, MVT::f16, Expand);
573     setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand);
574     setTruncStoreAction(MVT::f64, MVT::bf16, Expand);
575     setOperationAction(ISD::IS_FPCLASS, MVT::f64, Custom);
576     setOperationAction(ISD::BF16_TO_FP, MVT::f64, Custom);
577     setOperationAction(ISD::FP_TO_BF16, MVT::f64,
578                        Subtarget.isSoftFPABI() ? LibCall : Custom);
579     setOperationAction(ISD::FP_TO_FP16, MVT::f64, Custom);
580     setOperationAction(ISD::FP16_TO_FP, MVT::f64, Expand);
581   }
582 
583   if (Subtarget.is64Bit()) {
584     setOperationAction({ISD::FP_TO_UINT, ISD::FP_TO_SINT,
585                         ISD::STRICT_FP_TO_UINT, ISD::STRICT_FP_TO_SINT},
586                        MVT::i32, Custom);
587     setOperationAction(ISD::LROUND, MVT::i32, Custom);
588   }
589 
590   if (Subtarget.hasStdExtFOrZfinx()) {
591     setOperationAction({ISD::FP_TO_UINT_SAT, ISD::FP_TO_SINT_SAT}, XLenVT,
592                        Custom);
593 
594     setOperationAction({ISD::STRICT_FP_TO_UINT, ISD::STRICT_FP_TO_SINT,
595                         ISD::STRICT_UINT_TO_FP, ISD::STRICT_SINT_TO_FP},
596                        XLenVT, Legal);
597 
598     if (RV64LegalI32 && Subtarget.is64Bit())
599       setOperationAction({ISD::STRICT_FP_TO_UINT, ISD::STRICT_FP_TO_SINT,
600                           ISD::STRICT_UINT_TO_FP, ISD::STRICT_SINT_TO_FP},
601                          MVT::i32, Legal);
602 
603     setOperationAction(ISD::GET_ROUNDING, XLenVT, Custom);
604     setOperationAction(ISD::SET_ROUNDING, MVT::Other, Custom);
605   }
606 
607   setOperationAction({ISD::GlobalAddress, ISD::BlockAddress, ISD::ConstantPool,
608                       ISD::JumpTable},
609                      XLenVT, Custom);
610 
611   setOperationAction(ISD::GlobalTLSAddress, XLenVT, Custom);
612 
613   if (Subtarget.is64Bit())
614     setOperationAction(ISD::Constant, MVT::i64, Custom);
615 
616   // TODO: On M-mode only targets, the cycle[h] CSR may not be present.
617   // Unfortunately this can't be determined just from the ISA naming string.
618   setOperationAction(ISD::READCYCLECOUNTER, MVT::i64,
619                      Subtarget.is64Bit() ? Legal : Custom);
620 
621   setOperationAction({ISD::TRAP, ISD::DEBUGTRAP}, MVT::Other, Legal);
622   setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);
623   if (Subtarget.is64Bit())
624     setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i32, Custom);
625 
626   if (Subtarget.hasStdExtZicbop()) {
627     setOperationAction(ISD::PREFETCH, MVT::Other, Legal);
628   }
629 
630   if (Subtarget.hasStdExtA()) {
631     setMaxAtomicSizeInBitsSupported(Subtarget.getXLen());
632     setMinCmpXchgSizeInBits(32);
633   } else if (Subtarget.hasForcedAtomics()) {
634     setMaxAtomicSizeInBitsSupported(Subtarget.getXLen());
635   } else {
636     setMaxAtomicSizeInBitsSupported(0);
637   }
638 
639   setOperationAction(ISD::ATOMIC_FENCE, MVT::Other, Custom);
640 
641   setBooleanContents(ZeroOrOneBooleanContent);
642 
643   if (Subtarget.hasVInstructions()) {
644     setBooleanVectorContents(ZeroOrOneBooleanContent);
645 
646     setOperationAction(ISD::VSCALE, XLenVT, Custom);
647     if (RV64LegalI32 && Subtarget.is64Bit())
648       setOperationAction(ISD::VSCALE, MVT::i32, Custom);
649 
650     // RVV intrinsics may have illegal operands.
651     // We also need to custom legalize vmv.x.s.
652     setOperationAction({ISD::INTRINSIC_WO_CHAIN, ISD::INTRINSIC_W_CHAIN,
653                         ISD::INTRINSIC_VOID},
654                        {MVT::i8, MVT::i16}, Custom);
655     if (Subtarget.is64Bit())
656       setOperationAction({ISD::INTRINSIC_W_CHAIN, ISD::INTRINSIC_VOID},
657                          MVT::i32, Custom);
658     else
659       setOperationAction({ISD::INTRINSIC_WO_CHAIN, ISD::INTRINSIC_W_CHAIN},
660                          MVT::i64, Custom);
661 
662     setOperationAction({ISD::INTRINSIC_W_CHAIN, ISD::INTRINSIC_VOID},
663                        MVT::Other, Custom);
664 
665     static const unsigned IntegerVPOps[] = {
666         ISD::VP_ADD,         ISD::VP_SUB,         ISD::VP_MUL,
667         ISD::VP_SDIV,        ISD::VP_UDIV,        ISD::VP_SREM,
668         ISD::VP_UREM,        ISD::VP_AND,         ISD::VP_OR,
669         ISD::VP_XOR,         ISD::VP_ASHR,        ISD::VP_LSHR,
670         ISD::VP_SHL,         ISD::VP_REDUCE_ADD,  ISD::VP_REDUCE_AND,
671         ISD::VP_REDUCE_OR,   ISD::VP_REDUCE_XOR,  ISD::VP_REDUCE_SMAX,
672         ISD::VP_REDUCE_SMIN, ISD::VP_REDUCE_UMAX, ISD::VP_REDUCE_UMIN,
673         ISD::VP_MERGE,       ISD::VP_SELECT,      ISD::VP_FP_TO_SINT,
674         ISD::VP_FP_TO_UINT,  ISD::VP_SETCC,       ISD::VP_SIGN_EXTEND,
675         ISD::VP_ZERO_EXTEND, ISD::VP_TRUNCATE,    ISD::VP_SMIN,
676         ISD::VP_SMAX,        ISD::VP_UMIN,        ISD::VP_UMAX,
677         ISD::VP_ABS, ISD::EXPERIMENTAL_VP_REVERSE, ISD::EXPERIMENTAL_VP_SPLICE};
678 
679     static const unsigned FloatingPointVPOps[] = {
680         ISD::VP_FADD,        ISD::VP_FSUB,        ISD::VP_FMUL,
681         ISD::VP_FDIV,        ISD::VP_FNEG,        ISD::VP_FABS,
682         ISD::VP_FMA,         ISD::VP_REDUCE_FADD, ISD::VP_REDUCE_SEQ_FADD,
683         ISD::VP_REDUCE_FMIN, ISD::VP_REDUCE_FMAX, ISD::VP_MERGE,
684         ISD::VP_SELECT,      ISD::VP_SINT_TO_FP,  ISD::VP_UINT_TO_FP,
685         ISD::VP_SETCC,       ISD::VP_FP_ROUND,    ISD::VP_FP_EXTEND,
686         ISD::VP_SQRT,        ISD::VP_FMINNUM,     ISD::VP_FMAXNUM,
687         ISD::VP_FCEIL,       ISD::VP_FFLOOR,      ISD::VP_FROUND,
688         ISD::VP_FROUNDEVEN,  ISD::VP_FCOPYSIGN,   ISD::VP_FROUNDTOZERO,
689         ISD::VP_FRINT,       ISD::VP_FNEARBYINT,  ISD::VP_IS_FPCLASS,
690         ISD::VP_FMINIMUM,    ISD::VP_FMAXIMUM,    ISD::EXPERIMENTAL_VP_REVERSE,
691         ISD::EXPERIMENTAL_VP_SPLICE};
692 
693     static const unsigned IntegerVecReduceOps[] = {
694         ISD::VECREDUCE_ADD,  ISD::VECREDUCE_AND,  ISD::VECREDUCE_OR,
695         ISD::VECREDUCE_XOR,  ISD::VECREDUCE_SMAX, ISD::VECREDUCE_SMIN,
696         ISD::VECREDUCE_UMAX, ISD::VECREDUCE_UMIN};
697 
698     static const unsigned FloatingPointVecReduceOps[] = {
699         ISD::VECREDUCE_FADD, ISD::VECREDUCE_SEQ_FADD, ISD::VECREDUCE_FMIN,
700         ISD::VECREDUCE_FMAX};
701 
702     if (!Subtarget.is64Bit()) {
703       // We must custom-lower certain vXi64 operations on RV32 due to the vector
704       // element type being illegal.
705       setOperationAction({ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT},
706                          MVT::i64, Custom);
707 
708       setOperationAction(IntegerVecReduceOps, MVT::i64, Custom);
709 
710       setOperationAction({ISD::VP_REDUCE_ADD, ISD::VP_REDUCE_AND,
711                           ISD::VP_REDUCE_OR, ISD::VP_REDUCE_XOR,
712                           ISD::VP_REDUCE_SMAX, ISD::VP_REDUCE_SMIN,
713                           ISD::VP_REDUCE_UMAX, ISD::VP_REDUCE_UMIN},
714                          MVT::i64, Custom);
715     }
716 
717     for (MVT VT : BoolVecVTs) {
718       if (!isTypeLegal(VT))
719         continue;
720 
721       setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
722 
723       // Mask VTs are custom-expanded into a series of standard nodes
724       setOperationAction({ISD::TRUNCATE, ISD::CONCAT_VECTORS,
725                           ISD::INSERT_SUBVECTOR, ISD::EXTRACT_SUBVECTOR,
726                           ISD::SCALAR_TO_VECTOR},
727                          VT, Custom);
728 
729       setOperationAction({ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT}, VT,
730                          Custom);
731 
732       setOperationAction(ISD::SELECT, VT, Custom);
733       setOperationAction(
734           {ISD::SELECT_CC, ISD::VSELECT, ISD::VP_MERGE, ISD::VP_SELECT}, VT,
735           Expand);
736 
737       setOperationAction({ISD::VP_AND, ISD::VP_OR, ISD::VP_XOR}, VT, Custom);
738 
739       setOperationAction(
740           {ISD::VECREDUCE_AND, ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR}, VT,
741           Custom);
742 
743       setOperationAction(
744           {ISD::VP_REDUCE_AND, ISD::VP_REDUCE_OR, ISD::VP_REDUCE_XOR}, VT,
745           Custom);
746 
747       // RVV has native int->float & float->int conversions where the
748       // element type sizes are within one power-of-two of each other. Any
749       // wider distances between type sizes have to be lowered as sequences
750       // which progressively narrow the gap in stages.
751       setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT,
752                           ISD::FP_TO_UINT, ISD::STRICT_SINT_TO_FP,
753                           ISD::STRICT_UINT_TO_FP, ISD::STRICT_FP_TO_SINT,
754                           ISD::STRICT_FP_TO_UINT},
755                          VT, Custom);
756       setOperationAction({ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT}, VT,
757                          Custom);
758 
759       // Expand all extending loads to types larger than this, and truncating
760       // stores from types larger than this.
761       for (MVT OtherVT : MVT::integer_scalable_vector_valuetypes()) {
762         setTruncStoreAction(VT, OtherVT, Expand);
763         setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, VT,
764                          OtherVT, Expand);
765       }
766 
767       setOperationAction({ISD::VP_FP_TO_SINT, ISD::VP_FP_TO_UINT,
768                           ISD::VP_TRUNCATE, ISD::VP_SETCC},
769                          VT, Custom);
770 
771       setOperationAction(ISD::VECTOR_DEINTERLEAVE, VT, Custom);
772       setOperationAction(ISD::VECTOR_INTERLEAVE, VT, Custom);
773 
774       setOperationAction(ISD::VECTOR_REVERSE, VT, Custom);
775 
776       setOperationAction(ISD::EXPERIMENTAL_VP_SPLICE, VT, Custom);
777       setOperationAction(ISD::EXPERIMENTAL_VP_REVERSE, VT, Custom);
778 
779       setOperationPromotedToType(
780           ISD::VECTOR_SPLICE, VT,
781           MVT::getVectorVT(MVT::i8, VT.getVectorElementCount()));
782     }
783 
784     for (MVT VT : IntVecVTs) {
785       if (!isTypeLegal(VT))
786         continue;
787 
788       setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
789       setOperationAction(ISD::SPLAT_VECTOR_PARTS, VT, Custom);
790 
791       // Vectors implement MULHS/MULHU.
792       setOperationAction({ISD::SMUL_LOHI, ISD::UMUL_LOHI}, VT, Expand);
793 
794       // nxvXi64 MULHS/MULHU requires the V extension instead of Zve64*.
795       if (VT.getVectorElementType() == MVT::i64 && !Subtarget.hasStdExtV())
796         setOperationAction({ISD::MULHU, ISD::MULHS}, VT, Expand);
797 
798       setOperationAction({ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX}, VT,
799                          Legal);
800 
801       // Custom-lower extensions and truncations from/to mask types.
802       setOperationAction({ISD::ANY_EXTEND, ISD::SIGN_EXTEND, ISD::ZERO_EXTEND},
803                          VT, Custom);
804 
805       // RVV has native int->float & float->int conversions where the
806       // element type sizes are within one power-of-two of each other. Any
807       // wider distances between type sizes have to be lowered as sequences
808       // which progressively narrow the gap in stages.
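      // (Illustrative example, intermediate type chosen for illustration: an
      // i8 -> f64 element conversion is first brought within one power-of-two,
      // e.g. by extending i8 to i32, and then converted in a single native
      // step.)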
809       setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT,
810                           ISD::FP_TO_UINT, ISD::STRICT_SINT_TO_FP,
811                           ISD::STRICT_UINT_TO_FP, ISD::STRICT_FP_TO_SINT,
812                           ISD::STRICT_FP_TO_UINT},
813                          VT, Custom);
814       setOperationAction({ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT}, VT,
815                          Custom);
816       setOperationAction({ISD::LRINT, ISD::LLRINT}, VT, Custom);
817       setOperationAction({ISD::AVGFLOORU, ISD::AVGCEILU, ISD::SADDSAT,
818                           ISD::UADDSAT, ISD::SSUBSAT, ISD::USUBSAT},
819                          VT, Legal);
820 
821       // Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR_VL"
822       // nodes which truncate by one power of two at a time.
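      // (Illustrative example: truncating nxv2i64 to nxv2i8 becomes a chain of
      // three such nodes, narrowing i64 -> i32 -> i16 -> i8.)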
823       setOperationAction(ISD::TRUNCATE, VT, Custom);
824 
825       // Custom-lower insert/extract operations to simplify patterns.
826       setOperationAction({ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT}, VT,
827                          Custom);
828 
829       // Custom-lower reduction operations to set up the corresponding custom
830       // nodes' operands.
831       setOperationAction(IntegerVecReduceOps, VT, Custom);
832 
833       setOperationAction(IntegerVPOps, VT, Custom);
834 
835       setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom);
836 
837       setOperationAction({ISD::MLOAD, ISD::MSTORE, ISD::MGATHER, ISD::MSCATTER},
838                          VT, Custom);
839 
840       setOperationAction(
841           {ISD::VP_LOAD, ISD::VP_STORE, ISD::EXPERIMENTAL_VP_STRIDED_LOAD,
842            ISD::EXPERIMENTAL_VP_STRIDED_STORE, ISD::VP_GATHER, ISD::VP_SCATTER},
843           VT, Custom);
844 
845       setOperationAction({ISD::CONCAT_VECTORS, ISD::INSERT_SUBVECTOR,
846                           ISD::EXTRACT_SUBVECTOR, ISD::SCALAR_TO_VECTOR},
847                          VT, Custom);
848 
849       setOperationAction(ISD::SELECT, VT, Custom);
850       setOperationAction(ISD::SELECT_CC, VT, Expand);
851 
852       setOperationAction({ISD::STEP_VECTOR, ISD::VECTOR_REVERSE}, VT, Custom);
853 
854       for (MVT OtherVT : MVT::integer_scalable_vector_valuetypes()) {
855         setTruncStoreAction(VT, OtherVT, Expand);
856         setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, VT,
857                          OtherVT, Expand);
858       }
859 
860       setOperationAction(ISD::VECTOR_DEINTERLEAVE, VT, Custom);
861       setOperationAction(ISD::VECTOR_INTERLEAVE, VT, Custom);
862 
863       // Splice
864       setOperationAction(ISD::VECTOR_SPLICE, VT, Custom);
865 
866       if (Subtarget.hasStdExtZvkb()) {
867         setOperationAction(ISD::BSWAP, VT, Legal);
868         setOperationAction(ISD::VP_BSWAP, VT, Custom);
869       } else {
870         setOperationAction({ISD::BSWAP, ISD::VP_BSWAP}, VT, Expand);
871         setOperationAction({ISD::ROTL, ISD::ROTR}, VT, Expand);
872       }
873 
874       if (Subtarget.hasStdExtZvbb()) {
875         setOperationAction(ISD::BITREVERSE, VT, Legal);
876         setOperationAction(ISD::VP_BITREVERSE, VT, Custom);
877         setOperationAction({ISD::VP_CTLZ, ISD::VP_CTLZ_ZERO_UNDEF, ISD::VP_CTTZ,
878                             ISD::VP_CTTZ_ZERO_UNDEF, ISD::VP_CTPOP},
879                            VT, Custom);
880       } else {
881         setOperationAction({ISD::BITREVERSE, ISD::VP_BITREVERSE}, VT, Expand);
882         setOperationAction({ISD::CTLZ, ISD::CTTZ, ISD::CTPOP}, VT, Expand);
883         setOperationAction({ISD::VP_CTLZ, ISD::VP_CTLZ_ZERO_UNDEF, ISD::VP_CTTZ,
884                             ISD::VP_CTTZ_ZERO_UNDEF, ISD::VP_CTPOP},
885                            VT, Expand);
886 
887         // Lower CTLZ_ZERO_UNDEF and CTTZ_ZERO_UNDEF if the element type of VT
888         // is in the range of f32.
889         EVT FloatVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount());
890         if (isTypeLegal(FloatVT)) {
891           setOperationAction({ISD::CTLZ, ISD::CTLZ_ZERO_UNDEF,
892                               ISD::CTTZ_ZERO_UNDEF, ISD::VP_CTLZ,
893                               ISD::VP_CTLZ_ZERO_UNDEF, ISD::VP_CTTZ_ZERO_UNDEF},
894                              VT, Custom);
895         }
896       }
897     }
898 
899     // Expand various CCs to best match the RVV ISA, which natively supports UNE
900     // but no other unordered comparisons, and supports all ordered comparisons
901     // except ONE. Additionally, we expand GT,OGT,GE,OGE for optimization
902     // purposes; they are expanded to their swapped-operand CCs (LT,OLT,LE,OLE),
903     // and we pattern-match those back to the "original", swapping operands once
904     // more. This way we catch both operations and both "vf" and "fv" forms with
905     // fewer patterns.
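    // (Illustrative example: a SETOGT comparison is expanded to SETOLT with
    // the operands swapped, and the swapped form is then matched back in isel,
    // so one set of patterns covers both directions and both "vf" and "fv"
    // operand orders.)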
906     static const ISD::CondCode VFPCCToExpand[] = {
907         ISD::SETO,   ISD::SETONE, ISD::SETUEQ, ISD::SETUGT,
908         ISD::SETUGE, ISD::SETULT, ISD::SETULE, ISD::SETUO,
909         ISD::SETGT,  ISD::SETOGT, ISD::SETGE,  ISD::SETOGE,
910     };
911 
912     // TODO: support more ops.
913     static const unsigned ZvfhminPromoteOps[] = {
914         ISD::FMINNUM,     ISD::FMAXNUM,      ISD::FADD,        ISD::FSUB,
915         ISD::FMUL,        ISD::FMA,          ISD::FDIV,        ISD::FSQRT,
916         ISD::FABS,        ISD::FNEG,         ISD::FCOPYSIGN,   ISD::FCEIL,
917         ISD::FFLOOR,      ISD::FROUND,       ISD::FROUNDEVEN,  ISD::FRINT,
918         ISD::FNEARBYINT,  ISD::IS_FPCLASS,   ISD::SETCC,       ISD::FMAXIMUM,
919         ISD::FMINIMUM,    ISD::STRICT_FADD,  ISD::STRICT_FSUB, ISD::STRICT_FMUL,
920         ISD::STRICT_FDIV, ISD::STRICT_FSQRT, ISD::STRICT_FMA};
921 
922     // TODO: support more vp ops.
923     static const unsigned ZvfhminPromoteVPOps[] = {
924         ISD::VP_FADD,        ISD::VP_FSUB,         ISD::VP_FMUL,
925         ISD::VP_FDIV,        ISD::VP_FNEG,         ISD::VP_FABS,
926         ISD::VP_FMA,         ISD::VP_REDUCE_FADD,  ISD::VP_REDUCE_SEQ_FADD,
927         ISD::VP_REDUCE_FMIN, ISD::VP_REDUCE_FMAX,  ISD::VP_SQRT,
928         ISD::VP_FMINNUM,     ISD::VP_FMAXNUM,      ISD::VP_FCEIL,
929         ISD::VP_FFLOOR,      ISD::VP_FROUND,       ISD::VP_FROUNDEVEN,
930         ISD::VP_FCOPYSIGN,   ISD::VP_FROUNDTOZERO, ISD::VP_FRINT,
931         ISD::VP_FNEARBYINT,  ISD::VP_SETCC,        ISD::VP_FMINIMUM,
932         ISD::VP_FMAXIMUM};
933 
934     // Sets common operation actions on RVV floating-point vector types.
935     const auto SetCommonVFPActions = [&](MVT VT) {
936       setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
937       // RVV has native FP_ROUND & FP_EXTEND conversions where the element type
938       // sizes are within one power-of-two of each other. Therefore conversions
939       // between vXf16 and vXf64 must be lowered as sequences which convert via
940       // vXf32.
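      // (Illustrative example: an nxv2f16 -> nxv2f64 FP_EXTEND is emitted as
      // nxv2f16 -> nxv2f32 -> nxv2f64, and the corresponding FP_ROUND goes
      // through nxv2f32 in the opposite direction.)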
941       setOperationAction({ISD::FP_ROUND, ISD::FP_EXTEND}, VT, Custom);
942       // Custom-lower insert/extract operations to simplify patterns.
943       setOperationAction({ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT}, VT,
944                          Custom);
945       // Expand various condition codes (explained above).
946       setCondCodeAction(VFPCCToExpand, VT, Expand);
947 
948       setOperationAction({ISD::FMINNUM, ISD::FMAXNUM}, VT, Legal);
949       setOperationAction({ISD::FMAXIMUM, ISD::FMINIMUM}, VT, Custom);
950 
951       setOperationAction({ISD::FTRUNC, ISD::FCEIL, ISD::FFLOOR, ISD::FROUND,
952                           ISD::FROUNDEVEN, ISD::FRINT, ISD::FNEARBYINT,
953                           ISD::IS_FPCLASS},
954                          VT, Custom);
955 
956       setOperationAction(FloatingPointVecReduceOps, VT, Custom);
957 
958       // Expand FP operations that need libcalls.
959       setOperationAction(ISD::FREM, VT, Expand);
960       setOperationAction(ISD::FPOW, VT, Expand);
961       setOperationAction(ISD::FCOS, VT, Expand);
962       setOperationAction(ISD::FSIN, VT, Expand);
963       setOperationAction(ISD::FSINCOS, VT, Expand);
964       setOperationAction(ISD::FEXP, VT, Expand);
965       setOperationAction(ISD::FEXP2, VT, Expand);
966       setOperationAction(ISD::FEXP10, VT, Expand);
967       setOperationAction(ISD::FLOG, VT, Expand);
968       setOperationAction(ISD::FLOG2, VT, Expand);
969       setOperationAction(ISD::FLOG10, VT, Expand);
970 
971       setOperationAction(ISD::FCOPYSIGN, VT, Legal);
972 
973       setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom);
974 
975       setOperationAction({ISD::MLOAD, ISD::MSTORE, ISD::MGATHER, ISD::MSCATTER},
976                          VT, Custom);
977 
978       setOperationAction(
979           {ISD::VP_LOAD, ISD::VP_STORE, ISD::EXPERIMENTAL_VP_STRIDED_LOAD,
980            ISD::EXPERIMENTAL_VP_STRIDED_STORE, ISD::VP_GATHER, ISD::VP_SCATTER},
981           VT, Custom);
982 
983       setOperationAction(ISD::SELECT, VT, Custom);
984       setOperationAction(ISD::SELECT_CC, VT, Expand);
985 
986       setOperationAction({ISD::CONCAT_VECTORS, ISD::INSERT_SUBVECTOR,
987                           ISD::EXTRACT_SUBVECTOR, ISD::SCALAR_TO_VECTOR},
988                          VT, Custom);
989 
990       setOperationAction(ISD::VECTOR_DEINTERLEAVE, VT, Custom);
991       setOperationAction(ISD::VECTOR_INTERLEAVE, VT, Custom);
992 
993       setOperationAction({ISD::VECTOR_REVERSE, ISD::VECTOR_SPLICE}, VT, Custom);
994 
995       setOperationAction(FloatingPointVPOps, VT, Custom);
996 
997       setOperationAction({ISD::STRICT_FP_EXTEND, ISD::STRICT_FP_ROUND}, VT,
998                          Custom);
999       setOperationAction({ISD::STRICT_FADD, ISD::STRICT_FSUB, ISD::STRICT_FMUL,
1000                           ISD::STRICT_FDIV, ISD::STRICT_FSQRT, ISD::STRICT_FMA},
1001                          VT, Legal);
1002       setOperationAction({ISD::STRICT_FSETCC, ISD::STRICT_FSETCCS,
1003                           ISD::STRICT_FTRUNC, ISD::STRICT_FCEIL,
1004                           ISD::STRICT_FFLOOR, ISD::STRICT_FROUND,
1005                           ISD::STRICT_FROUNDEVEN, ISD::STRICT_FNEARBYINT},
1006                          VT, Custom);
1007     };
1008 
1009     // Sets common extload/truncstore actions on RVV floating-point vector
1010     // types.
1011     const auto SetCommonVFPExtLoadTruncStoreActions =
1012         [&](MVT VT, ArrayRef<MVT::SimpleValueType> SmallerVTs) {
1013           for (auto SmallVT : SmallerVTs) {
1014             setTruncStoreAction(VT, SmallVT, Expand);
1015             setLoadExtAction(ISD::EXTLOAD, VT, SmallVT, Expand);
1016           }
1017         };
1018 
1019     if (Subtarget.hasVInstructionsF16()) {
1020       for (MVT VT : F16VecVTs) {
1021         if (!isTypeLegal(VT))
1022           continue;
1023         SetCommonVFPActions(VT);
1024       }
1025     } else if (Subtarget.hasVInstructionsF16Minimal()) {
1026       for (MVT VT : F16VecVTs) {
1027         if (!isTypeLegal(VT))
1028           continue;
1029         setOperationAction({ISD::FP_ROUND, ISD::FP_EXTEND}, VT, Custom);
1030         setOperationAction({ISD::STRICT_FP_ROUND, ISD::STRICT_FP_EXTEND}, VT,
1031                            Custom);
1032         setOperationAction({ISD::VP_FP_ROUND, ISD::VP_FP_EXTEND}, VT, Custom);
1033         setOperationAction({ISD::VP_MERGE, ISD::VP_SELECT, ISD::SELECT}, VT,
1034                            Custom);
1035         setOperationAction(ISD::SELECT_CC, VT, Expand);
1036         setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP,
1037                             ISD::VP_SINT_TO_FP, ISD::VP_UINT_TO_FP},
1038                            VT, Custom);
1039         setOperationAction({ISD::CONCAT_VECTORS, ISD::INSERT_SUBVECTOR,
1040                             ISD::EXTRACT_SUBVECTOR, ISD::SCALAR_TO_VECTOR},
1041                            VT, Custom);
1042         setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
1043         // load/store
1044         setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom);
1045 
1046         // Custom split nxv32f16 since nxv32f32 is not legal.
1047         if (VT == MVT::nxv32f16) {
1048           setOperationAction(ZvfhminPromoteOps, VT, Custom);
1049           setOperationAction(ZvfhminPromoteVPOps, VT, Custom);
1050           continue;
1051         }
1052         // Add more promote ops.
1053         MVT F32VecVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount());
1054         setOperationPromotedToType(ZvfhminPromoteOps, VT, F32VecVT);
1055         setOperationPromotedToType(ZvfhminPromoteVPOps, VT, F32VecVT);
1056       }
1057     }
1058 
1059     if (Subtarget.hasVInstructionsF32()) {
1060       for (MVT VT : F32VecVTs) {
1061         if (!isTypeLegal(VT))
1062           continue;
1063         SetCommonVFPActions(VT);
1064         SetCommonVFPExtLoadTruncStoreActions(VT, F16VecVTs);
1065       }
1066     }
1067 
1068     if (Subtarget.hasVInstructionsF64()) {
1069       for (MVT VT : F64VecVTs) {
1070         if (!isTypeLegal(VT))
1071           continue;
1072         SetCommonVFPActions(VT);
1073         SetCommonVFPExtLoadTruncStoreActions(VT, F16VecVTs);
1074         SetCommonVFPExtLoadTruncStoreActions(VT, F32VecVTs);
1075       }
1076     }
1077 
1078     if (Subtarget.useRVVForFixedLengthVectors()) {
1079       for (MVT VT : MVT::integer_fixedlen_vector_valuetypes()) {
1080         if (!useRVVForFixedLengthVectorVT(VT))
1081           continue;
1082 
1083         // By default everything must be expanded.
1084         for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op)
1085           setOperationAction(Op, VT, Expand);
1086         for (MVT OtherVT : MVT::integer_fixedlen_vector_valuetypes()) {
1087           setTruncStoreAction(VT, OtherVT, Expand);
1088           setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, VT,
1089                            OtherVT, Expand);
1090         }
1091 
1092         // Custom lower fixed vector undefs to scalable vector undefs to avoid
1093         // expansion to a build_vector of 0s.
1094         setOperationAction(ISD::UNDEF, VT, Custom);
1095 
1096         // We use EXTRACT_SUBVECTOR as a "cast" from scalable to fixed.
1097         setOperationAction({ISD::INSERT_SUBVECTOR, ISD::EXTRACT_SUBVECTOR}, VT,
1098                            Custom);
1099 
1100         setOperationAction({ISD::BUILD_VECTOR, ISD::CONCAT_VECTORS}, VT,
1101                            Custom);
1102 
1103         setOperationAction({ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT},
1104                            VT, Custom);
1105 
1106         setOperationAction(ISD::SCALAR_TO_VECTOR, VT, Custom);
1107 
1108         setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom);
1109 
1110         setOperationAction(ISD::SETCC, VT, Custom);
1111 
1112         setOperationAction(ISD::SELECT, VT, Custom);
1113 
1114         setOperationAction(ISD::TRUNCATE, VT, Custom);
1115 
1116         setOperationAction(ISD::BITCAST, VT, Custom);
1117 
1118         setOperationAction(
1119             {ISD::VECREDUCE_AND, ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR}, VT,
1120             Custom);
1121 
1122         setOperationAction(
1123             {ISD::VP_REDUCE_AND, ISD::VP_REDUCE_OR, ISD::VP_REDUCE_XOR}, VT,
1124             Custom);
1125 
1126         setOperationAction(
1127             {
1128                 ISD::SINT_TO_FP,
1129                 ISD::UINT_TO_FP,
1130                 ISD::FP_TO_SINT,
1131                 ISD::FP_TO_UINT,
1132                 ISD::STRICT_SINT_TO_FP,
1133                 ISD::STRICT_UINT_TO_FP,
1134                 ISD::STRICT_FP_TO_SINT,
1135                 ISD::STRICT_FP_TO_UINT,
1136             },
1137             VT, Custom);
1138         setOperationAction({ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT}, VT,
1139                            Custom);
1140 
1141         setOperationAction(ISD::VECTOR_SHUFFLE, VT, Custom);
1142 
1143         // Operations below are different between masks and other vectors.
1144         if (VT.getVectorElementType() == MVT::i1) {
1145           setOperationAction({ISD::VP_AND, ISD::VP_OR, ISD::VP_XOR, ISD::AND,
1146                               ISD::OR, ISD::XOR},
1147                              VT, Custom);
1148 
1149           setOperationAction({ISD::VP_FP_TO_SINT, ISD::VP_FP_TO_UINT,
1150                               ISD::VP_SETCC, ISD::VP_TRUNCATE},
1151                              VT, Custom);
1152 
1153           setOperationAction(ISD::EXPERIMENTAL_VP_SPLICE, VT, Custom);
1154           setOperationAction(ISD::EXPERIMENTAL_VP_REVERSE, VT, Custom);
1155           continue;
1156         }
1157 
1158         // Make SPLAT_VECTOR Legal so DAGCombine will convert splat vectors to
1159         // it before type legalization for i64 vectors on RV32. It will then be
1160         // type legalized to SPLAT_VECTOR_PARTS which we need to Custom handle.
1161         // FIXME: Use SPLAT_VECTOR for all types? DAGCombine probably needs
1162         // improvements first.
1163         if (!Subtarget.is64Bit() && VT.getVectorElementType() == MVT::i64) {
1164           setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
1165           setOperationAction(ISD::SPLAT_VECTOR_PARTS, VT, Custom);
1166         }
1167 
1168         setOperationAction(
1169             {ISD::MLOAD, ISD::MSTORE, ISD::MGATHER, ISD::MSCATTER}, VT, Custom);
1170 
1171         setOperationAction({ISD::VP_LOAD, ISD::VP_STORE,
1172                             ISD::EXPERIMENTAL_VP_STRIDED_LOAD,
1173                             ISD::EXPERIMENTAL_VP_STRIDED_STORE, ISD::VP_GATHER,
1174                             ISD::VP_SCATTER},
1175                            VT, Custom);
1176 
1177         setOperationAction({ISD::ADD, ISD::MUL, ISD::SUB, ISD::AND, ISD::OR,
1178                             ISD::XOR, ISD::SDIV, ISD::SREM, ISD::UDIV,
1179                             ISD::UREM, ISD::SHL, ISD::SRA, ISD::SRL},
1180                            VT, Custom);
1181 
1182         setOperationAction(
1183             {ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX, ISD::ABS}, VT, Custom);
1184 
1185         // vXi64 MULHS/MULHU requires the V extension instead of Zve64*.
1186         if (VT.getVectorElementType() != MVT::i64 || Subtarget.hasStdExtV())
1187           setOperationAction({ISD::MULHS, ISD::MULHU}, VT, Custom);
1188 
1189         setOperationAction({ISD::AVGFLOORU, ISD::AVGCEILU, ISD::SADDSAT,
1190                             ISD::UADDSAT, ISD::SSUBSAT, ISD::USUBSAT},
1191                            VT, Custom);
1192 
1193         setOperationAction(ISD::VSELECT, VT, Custom);
1194         setOperationAction(ISD::SELECT_CC, VT, Expand);
1195 
1196         setOperationAction(
1197             {ISD::ANY_EXTEND, ISD::SIGN_EXTEND, ISD::ZERO_EXTEND}, VT, Custom);
1198 
1199         // Custom-lower reduction operations to set up the corresponding custom
1200         // nodes' operands.
1201         setOperationAction({ISD::VECREDUCE_ADD, ISD::VECREDUCE_SMAX,
1202                             ISD::VECREDUCE_SMIN, ISD::VECREDUCE_UMAX,
1203                             ISD::VECREDUCE_UMIN},
1204                            VT, Custom);
1205 
1206         setOperationAction(IntegerVPOps, VT, Custom);
1207 
1208         if (Subtarget.hasStdExtZvkb())
1209           setOperationAction({ISD::BSWAP, ISD::ROTL, ISD::ROTR}, VT, Custom);
1210 
1211         if (Subtarget.hasStdExtZvbb()) {
1212           setOperationAction({ISD::BITREVERSE, ISD::CTLZ, ISD::CTLZ_ZERO_UNDEF,
1213                               ISD::CTTZ, ISD::CTTZ_ZERO_UNDEF, ISD::CTPOP},
1214                              VT, Custom);
1215         } else {
1216           // Lower CTLZ_ZERO_UNDEF and CTTZ_ZERO_UNDEF if the element type of VT
1217           // is in the range of f32.
1218           EVT FloatVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount());
1219           if (isTypeLegal(FloatVT))
1220             setOperationAction(
1221                 {ISD::CTLZ, ISD::CTLZ_ZERO_UNDEF, ISD::CTTZ_ZERO_UNDEF}, VT,
1222                 Custom);
1223         }
1224       }
1225 
1226       for (MVT VT : MVT::fp_fixedlen_vector_valuetypes()) {
1227         // There are no extending loads or truncating stores.
1228         for (MVT InnerVT : MVT::fp_fixedlen_vector_valuetypes()) {
1229           setLoadExtAction(ISD::EXTLOAD, VT, InnerVT, Expand);
1230           setTruncStoreAction(VT, InnerVT, Expand);
1231         }
1232 
1233         if (!useRVVForFixedLengthVectorVT(VT))
1234           continue;
1235 
1236         // By default everything must be expanded.
1237         for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op)
1238           setOperationAction(Op, VT, Expand);
1239 
1240         // Custom lower fixed vector undefs to scalable vector undefs to avoid
1241         // expansion to a build_vector of 0s.
1242         setOperationAction(ISD::UNDEF, VT, Custom);
1243 
1244         if (VT.getVectorElementType() == MVT::f16 &&
1245             !Subtarget.hasVInstructionsF16()) {
1246           setOperationAction({ISD::FP_ROUND, ISD::FP_EXTEND}, VT, Custom);
1247           setOperationAction({ISD::STRICT_FP_ROUND, ISD::STRICT_FP_EXTEND}, VT,
1248                              Custom);
1249           setOperationAction({ISD::VP_FP_ROUND, ISD::VP_FP_EXTEND}, VT, Custom);
1250           setOperationAction(
1251               {ISD::VP_MERGE, ISD::VP_SELECT, ISD::VSELECT, ISD::SELECT}, VT,
1252               Custom);
1253           setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP,
1254                               ISD::VP_SINT_TO_FP, ISD::VP_UINT_TO_FP},
1255                              VT, Custom);
1256           setOperationAction({ISD::CONCAT_VECTORS, ISD::INSERT_SUBVECTOR,
1257                               ISD::EXTRACT_SUBVECTOR, ISD::SCALAR_TO_VECTOR},
1258                              VT, Custom);
1259           setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom);
1260           setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
1261           MVT F32VecVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount());
1262           // Don't promote f16 vector operations to f32 if the f32 vector type
1263           // is not legal.
1264           // TODO: could split the f16 vector into two vectors and do promotion.
1265           if (!isTypeLegal(F32VecVT))
1266             continue;
1267           setOperationPromotedToType(ZvfhminPromoteOps, VT, F32VecVT);
1268           setOperationPromotedToType(ZvfhminPromoteVPOps, VT, F32VecVT);
1269           continue;
1270         }
1271 
1272         // We use EXTRACT_SUBVECTOR as a "cast" from scalable to fixed.
1273         setOperationAction({ISD::INSERT_SUBVECTOR, ISD::EXTRACT_SUBVECTOR}, VT,
1274                            Custom);
1275 
1276         setOperationAction({ISD::BUILD_VECTOR, ISD::CONCAT_VECTORS,
1277                             ISD::VECTOR_SHUFFLE, ISD::INSERT_VECTOR_ELT,
1278                             ISD::EXTRACT_VECTOR_ELT},
1279                            VT, Custom);
1280 
1281         setOperationAction({ISD::LOAD, ISD::STORE, ISD::MLOAD, ISD::MSTORE,
1282                             ISD::MGATHER, ISD::MSCATTER},
1283                            VT, Custom);
1284 
1285         setOperationAction({ISD::VP_LOAD, ISD::VP_STORE,
1286                             ISD::EXPERIMENTAL_VP_STRIDED_LOAD,
1287                             ISD::EXPERIMENTAL_VP_STRIDED_STORE, ISD::VP_GATHER,
1288                             ISD::VP_SCATTER},
1289                            VT, Custom);
1290 
1291         setOperationAction({ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FDIV,
1292                             ISD::FNEG, ISD::FABS, ISD::FCOPYSIGN, ISD::FSQRT,
1293                             ISD::FMA, ISD::FMINNUM, ISD::FMAXNUM,
1294                             ISD::IS_FPCLASS, ISD::FMAXIMUM, ISD::FMINIMUM},
1295                            VT, Custom);
1296 
1297         setOperationAction({ISD::FP_ROUND, ISD::FP_EXTEND}, VT, Custom);
1298 
1299         setOperationAction({ISD::FTRUNC, ISD::FCEIL, ISD::FFLOOR, ISD::FROUND,
1300                             ISD::FROUNDEVEN, ISD::FRINT, ISD::FNEARBYINT},
1301                            VT, Custom);
1302 
1303         setCondCodeAction(VFPCCToExpand, VT, Expand);
1304 
1305         setOperationAction(ISD::SETCC, VT, Custom);
1306         setOperationAction({ISD::VSELECT, ISD::SELECT}, VT, Custom);
1307         setOperationAction(ISD::SELECT_CC, VT, Expand);
1308 
1309         setOperationAction(ISD::BITCAST, VT, Custom);
1310 
1311         setOperationAction(FloatingPointVecReduceOps, VT, Custom);
1312 
1313         setOperationAction(FloatingPointVPOps, VT, Custom);
1314 
1315         setOperationAction({ISD::STRICT_FP_EXTEND, ISD::STRICT_FP_ROUND}, VT,
1316                            Custom);
1317         setOperationAction(
1318             {ISD::STRICT_FADD, ISD::STRICT_FSUB, ISD::STRICT_FMUL,
1319              ISD::STRICT_FDIV, ISD::STRICT_FSQRT, ISD::STRICT_FMA,
1320              ISD::STRICT_FSETCC, ISD::STRICT_FSETCCS, ISD::STRICT_FTRUNC,
1321              ISD::STRICT_FCEIL, ISD::STRICT_FFLOOR, ISD::STRICT_FROUND,
1322              ISD::STRICT_FROUNDEVEN, ISD::STRICT_FNEARBYINT},
1323             VT, Custom);
1324       }
1325 
1326       // Custom-legalize bitcasts from fixed-length vectors to scalar types.
1327       setOperationAction(ISD::BITCAST, {MVT::i8, MVT::i16, MVT::i32, MVT::i64},
1328                          Custom);
1329       if (Subtarget.hasStdExtZfhminOrZhinxmin())
1330         setOperationAction(ISD::BITCAST, MVT::f16, Custom);
1331       if (Subtarget.hasStdExtFOrZfinx())
1332         setOperationAction(ISD::BITCAST, MVT::f32, Custom);
1333       if (Subtarget.hasStdExtDOrZdinx())
1334         setOperationAction(ISD::BITCAST, MVT::f64, Custom);
1335     }
1336   }
1337 
1338   if (Subtarget.hasStdExtA()) {
1339     setOperationAction(ISD::ATOMIC_LOAD_SUB, XLenVT, Expand);
1340     if (RV64LegalI32 && Subtarget.is64Bit())
1341       setOperationAction(ISD::ATOMIC_LOAD_SUB, MVT::i32, Expand);
1342   }
1343 
1344   if (Subtarget.hasForcedAtomics()) {
1345     // Force __sync libcalls to be emitted for atomic rmw/cas operations.
1346     setOperationAction(
1347         {ISD::ATOMIC_CMP_SWAP, ISD::ATOMIC_SWAP, ISD::ATOMIC_LOAD_ADD,
1348          ISD::ATOMIC_LOAD_SUB, ISD::ATOMIC_LOAD_AND, ISD::ATOMIC_LOAD_OR,
1349          ISD::ATOMIC_LOAD_XOR, ISD::ATOMIC_LOAD_NAND, ISD::ATOMIC_LOAD_MIN,
1350          ISD::ATOMIC_LOAD_MAX, ISD::ATOMIC_LOAD_UMIN, ISD::ATOMIC_LOAD_UMAX},
1351         XLenVT, LibCall);
1352   }
1353 
1354   if (Subtarget.hasVendorXTHeadMemIdx()) {
1355     for (unsigned im : {ISD::PRE_INC, ISD::POST_INC}) {
1356       setIndexedLoadAction(im, MVT::i8, Legal);
1357       setIndexedStoreAction(im, MVT::i8, Legal);
1358       setIndexedLoadAction(im, MVT::i16, Legal);
1359       setIndexedStoreAction(im, MVT::i16, Legal);
1360       setIndexedLoadAction(im, MVT::i32, Legal);
1361       setIndexedStoreAction(im, MVT::i32, Legal);
1362 
1363       if (Subtarget.is64Bit()) {
1364         setIndexedLoadAction(im, MVT::i64, Legal);
1365         setIndexedStoreAction(im, MVT::i64, Legal);
1366       }
1367     }
1368   }
1369 
1370   // Function alignments.
1371   const Align FunctionAlignment(Subtarget.hasStdExtCOrZca() ? 2 : 4);
1372   setMinFunctionAlignment(FunctionAlignment);
1373   // Set preferred alignments.
1374   setPrefFunctionAlignment(Subtarget.getPrefFunctionAlignment());
1375   setPrefLoopAlignment(Subtarget.getPrefLoopAlignment());
1376 
1377   setTargetDAGCombine({ISD::INTRINSIC_VOID, ISD::INTRINSIC_W_CHAIN,
1378                        ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::MUL,
1379                        ISD::AND, ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT});
1380   if (Subtarget.is64Bit())
1381     setTargetDAGCombine(ISD::SRA);
1382 
1383   if (Subtarget.hasStdExtFOrZfinx())
1384     setTargetDAGCombine({ISD::FADD, ISD::FMAXNUM, ISD::FMINNUM});
1385 
1386   if (Subtarget.hasStdExtZbb())
1387     setTargetDAGCombine({ISD::UMAX, ISD::UMIN, ISD::SMAX, ISD::SMIN});
1388 
1389   if (Subtarget.hasStdExtZbs() && Subtarget.is64Bit())
1390     setTargetDAGCombine(ISD::TRUNCATE);
1391 
1392   if (Subtarget.hasStdExtZbkb())
1393     setTargetDAGCombine(ISD::BITREVERSE);
1394   if (Subtarget.hasStdExtZfhminOrZhinxmin())
1395     setTargetDAGCombine(ISD::SIGN_EXTEND_INREG);
1396   if (Subtarget.hasStdExtFOrZfinx())
1397     setTargetDAGCombine({ISD::ZERO_EXTEND, ISD::FP_TO_SINT, ISD::FP_TO_UINT,
1398                          ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT});
1399   if (Subtarget.hasVInstructions())
1400     setTargetDAGCombine({ISD::FCOPYSIGN, ISD::MGATHER, ISD::MSCATTER,
1401                          ISD::VP_GATHER, ISD::VP_SCATTER, ISD::SRA, ISD::SRL,
1402                          ISD::SHL, ISD::STORE, ISD::SPLAT_VECTOR,
1403                          ISD::BUILD_VECTOR, ISD::CONCAT_VECTORS,
1404                          ISD::EXPERIMENTAL_VP_REVERSE, ISD::MUL,
1405                          ISD::INSERT_VECTOR_ELT});
1406   if (Subtarget.hasVendorXTHeadMemPair())
1407     setTargetDAGCombine({ISD::LOAD, ISD::STORE});
1408   if (Subtarget.useRVVForFixedLengthVectors())
1409     setTargetDAGCombine(ISD::BITCAST);
1410 
1411   setLibcallName(RTLIB::FPEXT_F16_F32, "__extendhfsf2");
1412   setLibcallName(RTLIB::FPROUND_F32_F16, "__truncsfhf2");
1413 
1414   // Disable strict node mutation.
1415   IsStrictFPEnabled = true;
1416 }
1417 
1418 EVT RISCVTargetLowering::getSetCCResultType(const DataLayout &DL,
1419                                             LLVMContext &Context,
1420                                             EVT VT) const {
1421   if (!VT.isVector())
1422     return getPointerTy(DL);
1423   if (Subtarget.hasVInstructions() &&
1424       (VT.isScalableVector() || Subtarget.useRVVForFixedLengthVectors()))
1425     return EVT::getVectorVT(Context, MVT::i1, VT.getVectorElementCount());
1426   return VT.changeVectorElementTypeToInteger();
1427 }
1428 
1429 MVT RISCVTargetLowering::getVPExplicitVectorLengthTy() const {
1430   return Subtarget.getXLenVT();
1431 }
1432 
1433 // Return false if we can lower get_vector_length to a vsetvli intrinsic.
1434 bool RISCVTargetLowering::shouldExpandGetVectorLength(EVT TripCountVT,
1435                                                       unsigned VF,
1436                                                       bool IsScalable) const {
1437   if (!Subtarget.hasVInstructions())
1438     return true;
1439 
1440   if (!IsScalable)
1441     return true;
1442 
1443   if (TripCountVT != MVT::i32 && TripCountVT != Subtarget.getXLenVT())
1444     return true;
1445 
1446   // Don't allow VF=1 if those types aren't legal.
1447   if (VF < RISCV::RVVBitsPerBlock / Subtarget.getELen())
1448     return true;
1449 
1450   // VLEN=32 support is incomplete.
1451   if (Subtarget.getRealMinVLen() < RISCV::RVVBitsPerBlock)
1452     return true;
1453 
1454   // The maximum VF is for the smallest element width with LMUL=8.
1455   // VF must be a power of 2.
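       // e.g. with RISCV::RVVBitsPerBlock == 64, the SEW=8/LMUL=8 configuration
       // gives MaxVF = (64 / 8) * 8 = 64.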
1456   unsigned MaxVF = (RISCV::RVVBitsPerBlock / 8) * 8;
1457   return VF > MaxVF || !isPowerOf2_32(VF);
1458 }
1459 
1460 bool RISCVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
1461                                              const CallInst &I,
1462                                              MachineFunction &MF,
1463                                              unsigned Intrinsic) const {
1464   auto &DL = I.getModule()->getDataLayout();
1465 
1466   auto SetRVVLoadStoreInfo = [&](unsigned PtrOp, bool IsStore,
1467                                  bool IsUnitStrided) {
1468     Info.opc = IsStore ? ISD::INTRINSIC_VOID : ISD::INTRINSIC_W_CHAIN;
1469     Info.ptrVal = I.getArgOperand(PtrOp);
1470     Type *MemTy;
1471     if (IsStore) {
1472       // Store value is the first operand.
1473       MemTy = I.getArgOperand(0)->getType();
1474     } else {
1475       // Use the return type. If it's a segment load, the return type is a struct.
1476       MemTy = I.getType();
1477       if (MemTy->isStructTy())
1478         MemTy = MemTy->getStructElementType(0);
1479     }
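         // For non-unit-strided accesses (strided, indexed, segment) the
         // accessed memory is not contiguous, so describe the access with the
         // scalar element type rather than the full vector type.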
1480     if (!IsUnitStrided)
1481       MemTy = MemTy->getScalarType();
1482 
1483     Info.memVT = getValueType(DL, MemTy);
1484     Info.align = Align(DL.getTypeSizeInBits(MemTy->getScalarType()) / 8);
1485     Info.size = MemoryLocation::UnknownSize;
1486     Info.flags |=
1487         IsStore ? MachineMemOperand::MOStore : MachineMemOperand::MOLoad;
1488     return true;
1489   };
1490 
1491   if (I.getMetadata(LLVMContext::MD_nontemporal) != nullptr)
1492     Info.flags |= MachineMemOperand::MONonTemporal;
1493 
1494   Info.flags |= RISCVTargetLowering::getTargetMMOFlags(I);
1495   switch (Intrinsic) {
1496   default:
1497     return false;
1498   case Intrinsic::riscv_masked_atomicrmw_xchg_i32:
1499   case Intrinsic::riscv_masked_atomicrmw_add_i32:
1500   case Intrinsic::riscv_masked_atomicrmw_sub_i32:
1501   case Intrinsic::riscv_masked_atomicrmw_nand_i32:
1502   case Intrinsic::riscv_masked_atomicrmw_max_i32:
1503   case Intrinsic::riscv_masked_atomicrmw_min_i32:
1504   case Intrinsic::riscv_masked_atomicrmw_umax_i32:
1505   case Intrinsic::riscv_masked_atomicrmw_umin_i32:
1506   case Intrinsic::riscv_masked_cmpxchg_i32:
1507     Info.opc = ISD::INTRINSIC_W_CHAIN;
1508     Info.memVT = MVT::i32;
1509     Info.ptrVal = I.getArgOperand(0);
1510     Info.offset = 0;
1511     Info.align = Align(4);
1512     Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore |
1513                  MachineMemOperand::MOVolatile;
1514     return true;
1515   case Intrinsic::riscv_masked_strided_load:
1516     return SetRVVLoadStoreInfo(/*PtrOp*/ 1, /*IsStore*/ false,
1517                                /*IsUnitStrided*/ false);
1518   case Intrinsic::riscv_masked_strided_store:
1519     return SetRVVLoadStoreInfo(/*PtrOp*/ 1, /*IsStore*/ true,
1520                                /*IsUnitStrided*/ false);
1521   case Intrinsic::riscv_seg2_load:
1522   case Intrinsic::riscv_seg3_load:
1523   case Intrinsic::riscv_seg4_load:
1524   case Intrinsic::riscv_seg5_load:
1525   case Intrinsic::riscv_seg6_load:
1526   case Intrinsic::riscv_seg7_load:
1527   case Intrinsic::riscv_seg8_load:
1528     return SetRVVLoadStoreInfo(/*PtrOp*/ 0, /*IsStore*/ false,
1529                                /*IsUnitStrided*/ false);
1530   case Intrinsic::riscv_seg2_store:
1531   case Intrinsic::riscv_seg3_store:
1532   case Intrinsic::riscv_seg4_store:
1533   case Intrinsic::riscv_seg5_store:
1534   case Intrinsic::riscv_seg6_store:
1535   case Intrinsic::riscv_seg7_store:
1536   case Intrinsic::riscv_seg8_store:
1537     // Operands are (vec, ..., vec, ptr, vl)
1538     return SetRVVLoadStoreInfo(/*PtrOp*/ I.arg_size() - 2,
1539                                /*IsStore*/ true,
1540                                /*IsUnitStrided*/ false);
1541   case Intrinsic::riscv_vle:
1542   case Intrinsic::riscv_vle_mask:
1543   case Intrinsic::riscv_vleff:
1544   case Intrinsic::riscv_vleff_mask:
1545     return SetRVVLoadStoreInfo(/*PtrOp*/ 1,
1546                                /*IsStore*/ false,
1547                                /*IsUnitStrided*/ true);
1548   case Intrinsic::riscv_vse:
1549   case Intrinsic::riscv_vse_mask:
1550     return SetRVVLoadStoreInfo(/*PtrOp*/ 1,
1551                                /*IsStore*/ true,
1552                                /*IsUnitStrided*/ true);
1553   case Intrinsic::riscv_vlse:
1554   case Intrinsic::riscv_vlse_mask:
1555   case Intrinsic::riscv_vloxei:
1556   case Intrinsic::riscv_vloxei_mask:
1557   case Intrinsic::riscv_vluxei:
1558   case Intrinsic::riscv_vluxei_mask:
1559     return SetRVVLoadStoreInfo(/*PtrOp*/ 1,
1560                                /*IsStore*/ false,
1561                                /*IsUnitStrided*/ false);
1562   case Intrinsic::riscv_vsse:
1563   case Intrinsic::riscv_vsse_mask:
1564   case Intrinsic::riscv_vsoxei:
1565   case Intrinsic::riscv_vsoxei_mask:
1566   case Intrinsic::riscv_vsuxei:
1567   case Intrinsic::riscv_vsuxei_mask:
1568     return SetRVVLoadStoreInfo(/*PtrOp*/ 1,
1569                                /*IsStore*/ true,
1570                                /*IsUnitStrided*/ false);
1571   case Intrinsic::riscv_vlseg2:
1572   case Intrinsic::riscv_vlseg3:
1573   case Intrinsic::riscv_vlseg4:
1574   case Intrinsic::riscv_vlseg5:
1575   case Intrinsic::riscv_vlseg6:
1576   case Intrinsic::riscv_vlseg7:
1577   case Intrinsic::riscv_vlseg8:
1578   case Intrinsic::riscv_vlseg2ff:
1579   case Intrinsic::riscv_vlseg3ff:
1580   case Intrinsic::riscv_vlseg4ff:
1581   case Intrinsic::riscv_vlseg5ff:
1582   case Intrinsic::riscv_vlseg6ff:
1583   case Intrinsic::riscv_vlseg7ff:
1584   case Intrinsic::riscv_vlseg8ff:
1585     return SetRVVLoadStoreInfo(/*PtrOp*/ I.arg_size() - 2,
1586                                /*IsStore*/ false,
1587                                /*IsUnitStrided*/ false);
1588   case Intrinsic::riscv_vlseg2_mask:
1589   case Intrinsic::riscv_vlseg3_mask:
1590   case Intrinsic::riscv_vlseg4_mask:
1591   case Intrinsic::riscv_vlseg5_mask:
1592   case Intrinsic::riscv_vlseg6_mask:
1593   case Intrinsic::riscv_vlseg7_mask:
1594   case Intrinsic::riscv_vlseg8_mask:
1595   case Intrinsic::riscv_vlseg2ff_mask:
1596   case Intrinsic::riscv_vlseg3ff_mask:
1597   case Intrinsic::riscv_vlseg4ff_mask:
1598   case Intrinsic::riscv_vlseg5ff_mask:
1599   case Intrinsic::riscv_vlseg6ff_mask:
1600   case Intrinsic::riscv_vlseg7ff_mask:
1601   case Intrinsic::riscv_vlseg8ff_mask:
1602     return SetRVVLoadStoreInfo(/*PtrOp*/ I.arg_size() - 4,
1603                                /*IsStore*/ false,
1604                                /*IsUnitStrided*/ false);
1605   case Intrinsic::riscv_vlsseg2:
1606   case Intrinsic::riscv_vlsseg3:
1607   case Intrinsic::riscv_vlsseg4:
1608   case Intrinsic::riscv_vlsseg5:
1609   case Intrinsic::riscv_vlsseg6:
1610   case Intrinsic::riscv_vlsseg7:
1611   case Intrinsic::riscv_vlsseg8:
1612   case Intrinsic::riscv_vloxseg2:
1613   case Intrinsic::riscv_vloxseg3:
1614   case Intrinsic::riscv_vloxseg4:
1615   case Intrinsic::riscv_vloxseg5:
1616   case Intrinsic::riscv_vloxseg6:
1617   case Intrinsic::riscv_vloxseg7:
1618   case Intrinsic::riscv_vloxseg8:
1619   case Intrinsic::riscv_vluxseg2:
1620   case Intrinsic::riscv_vluxseg3:
1621   case Intrinsic::riscv_vluxseg4:
1622   case Intrinsic::riscv_vluxseg5:
1623   case Intrinsic::riscv_vluxseg6:
1624   case Intrinsic::riscv_vluxseg7:
1625   case Intrinsic::riscv_vluxseg8:
1626     return SetRVVLoadStoreInfo(/*PtrOp*/ I.arg_size() - 3,
1627                                /*IsStore*/ false,
1628                                /*IsUnitStrided*/ false);
1629   case Intrinsic::riscv_vlsseg2_mask:
1630   case Intrinsic::riscv_vlsseg3_mask:
1631   case Intrinsic::riscv_vlsseg4_mask:
1632   case Intrinsic::riscv_vlsseg5_mask:
1633   case Intrinsic::riscv_vlsseg6_mask:
1634   case Intrinsic::riscv_vlsseg7_mask:
1635   case Intrinsic::riscv_vlsseg8_mask:
1636   case Intrinsic::riscv_vloxseg2_mask:
1637   case Intrinsic::riscv_vloxseg3_mask:
1638   case Intrinsic::riscv_vloxseg4_mask:
1639   case Intrinsic::riscv_vloxseg5_mask:
1640   case Intrinsic::riscv_vloxseg6_mask:
1641   case Intrinsic::riscv_vloxseg7_mask:
1642   case Intrinsic::riscv_vloxseg8_mask:
1643   case Intrinsic::riscv_vluxseg2_mask:
1644   case Intrinsic::riscv_vluxseg3_mask:
1645   case Intrinsic::riscv_vluxseg4_mask:
1646   case Intrinsic::riscv_vluxseg5_mask:
1647   case Intrinsic::riscv_vluxseg6_mask:
1648   case Intrinsic::riscv_vluxseg7_mask:
1649   case Intrinsic::riscv_vluxseg8_mask:
1650     return SetRVVLoadStoreInfo(/*PtrOp*/ I.arg_size() - 5,
1651                                /*IsStore*/ false,
1652                                /*IsUnitStrided*/ false);
1653   case Intrinsic::riscv_vsseg2:
1654   case Intrinsic::riscv_vsseg3:
1655   case Intrinsic::riscv_vsseg4:
1656   case Intrinsic::riscv_vsseg5:
1657   case Intrinsic::riscv_vsseg6:
1658   case Intrinsic::riscv_vsseg7:
1659   case Intrinsic::riscv_vsseg8:
1660     return SetRVVLoadStoreInfo(/*PtrOp*/ I.arg_size() - 2,
1661                                /*IsStore*/ true,
1662                                /*IsUnitStrided*/ false);
1663   case Intrinsic::riscv_vsseg2_mask:
1664   case Intrinsic::riscv_vsseg3_mask:
1665   case Intrinsic::riscv_vsseg4_mask:
1666   case Intrinsic::riscv_vsseg5_mask:
1667   case Intrinsic::riscv_vsseg6_mask:
1668   case Intrinsic::riscv_vsseg7_mask:
1669   case Intrinsic::riscv_vsseg8_mask:
1670     return SetRVVLoadStoreInfo(/*PtrOp*/ I.arg_size() - 3,
1671                                /*IsStore*/ true,
1672                                /*IsUnitStrided*/ false);
1673   case Intrinsic::riscv_vssseg2:
1674   case Intrinsic::riscv_vssseg3:
1675   case Intrinsic::riscv_vssseg4:
1676   case Intrinsic::riscv_vssseg5:
1677   case Intrinsic::riscv_vssseg6:
1678   case Intrinsic::riscv_vssseg7:
1679   case Intrinsic::riscv_vssseg8:
1680   case Intrinsic::riscv_vsoxseg2:
1681   case Intrinsic::riscv_vsoxseg3:
1682   case Intrinsic::riscv_vsoxseg4:
1683   case Intrinsic::riscv_vsoxseg5:
1684   case Intrinsic::riscv_vsoxseg6:
1685   case Intrinsic::riscv_vsoxseg7:
1686   case Intrinsic::riscv_vsoxseg8:
1687   case Intrinsic::riscv_vsuxseg2:
1688   case Intrinsic::riscv_vsuxseg3:
1689   case Intrinsic::riscv_vsuxseg4:
1690   case Intrinsic::riscv_vsuxseg5:
1691   case Intrinsic::riscv_vsuxseg6:
1692   case Intrinsic::riscv_vsuxseg7:
1693   case Intrinsic::riscv_vsuxseg8:
1694     return SetRVVLoadStoreInfo(/*PtrOp*/ I.arg_size() - 3,
1695                                /*IsStore*/ true,
1696                                /*IsUnitStrided*/ false);
1697   case Intrinsic::riscv_vssseg2_mask:
1698   case Intrinsic::riscv_vssseg3_mask:
1699   case Intrinsic::riscv_vssseg4_mask:
1700   case Intrinsic::riscv_vssseg5_mask:
1701   case Intrinsic::riscv_vssseg6_mask:
1702   case Intrinsic::riscv_vssseg7_mask:
1703   case Intrinsic::riscv_vssseg8_mask:
1704   case Intrinsic::riscv_vsoxseg2_mask:
1705   case Intrinsic::riscv_vsoxseg3_mask:
1706   case Intrinsic::riscv_vsoxseg4_mask:
1707   case Intrinsic::riscv_vsoxseg5_mask:
1708   case Intrinsic::riscv_vsoxseg6_mask:
1709   case Intrinsic::riscv_vsoxseg7_mask:
1710   case Intrinsic::riscv_vsoxseg8_mask:
1711   case Intrinsic::riscv_vsuxseg2_mask:
1712   case Intrinsic::riscv_vsuxseg3_mask:
1713   case Intrinsic::riscv_vsuxseg4_mask:
1714   case Intrinsic::riscv_vsuxseg5_mask:
1715   case Intrinsic::riscv_vsuxseg6_mask:
1716   case Intrinsic::riscv_vsuxseg7_mask:
1717   case Intrinsic::riscv_vsuxseg8_mask:
1718     return SetRVVLoadStoreInfo(/*PtrOp*/ I.arg_size() - 4,
1719                                /*IsStore*/ true,
1720                                /*IsUnitStrided*/ false);
1721   }
1722 }
1723 
1724 bool RISCVTargetLowering::isLegalAddressingMode(const DataLayout &DL,
1725                                                 const AddrMode &AM, Type *Ty,
1726                                                 unsigned AS,
1727                                                 Instruction *I) const {
1728   // No global is ever allowed as a base.
1729   if (AM.BaseGV)
1730     return false;
1731 
1732   // RVV instructions only support register addressing.
1733   if (Subtarget.hasVInstructions() && isa<VectorType>(Ty))
1734     return AM.HasBaseReg && AM.Scale == 0 && !AM.BaseOffs;
1735 
1736   // Require a 12-bit signed offset.
1737   if (!isInt<12>(AM.BaseOffs))
1738     return false;
1739 
1740   switch (AM.Scale) {
1741   case 0: // "r+i" or just "i", depending on HasBaseReg.
1742     break;
1743   case 1:
1744     if (!AM.HasBaseReg) // allow "r+i".
1745       break;
1746     return false; // disallow "r+r" or "r+r+i".
1747   default:
1748     return false;
1749   }
1750 
1751   return true;
1752 }
1753 
1754 bool RISCVTargetLowering::isLegalICmpImmediate(int64_t Imm) const {
1755   return isInt<12>(Imm);
1756 }
1757 
1758 bool RISCVTargetLowering::isLegalAddImmediate(int64_t Imm) const {
1759   return isInt<12>(Imm);
1760 }
1761 
1762 // On RV32, 64-bit integers are split into their high and low parts and held
1763 // in two different registers, so the trunc is free since the low register can
1764 // just be used.
1765 // FIXME: Should we consider i64->i32 free on RV64 to match the EVT version of
1766 // isTruncateFree?
1767 bool RISCVTargetLowering::isTruncateFree(Type *SrcTy, Type *DstTy) const {
1768   if (Subtarget.is64Bit() || !SrcTy->isIntegerTy() || !DstTy->isIntegerTy())
1769     return false;
1770   unsigned SrcBits = SrcTy->getPrimitiveSizeInBits();
1771   unsigned DestBits = DstTy->getPrimitiveSizeInBits();
1772   return (SrcBits == 64 && DestBits == 32);
1773 }
1774 
1775 bool RISCVTargetLowering::isTruncateFree(EVT SrcVT, EVT DstVT) const {
1776   // We consider i64->i32 free on RV64 since we have good selection of W
1777   // instructions that make promoting operations back to i64 free in many cases.
1778   if (SrcVT.isVector() || DstVT.isVector() || !SrcVT.isInteger() ||
1779       !DstVT.isInteger())
1780     return false;
1781   unsigned SrcBits = SrcVT.getSizeInBits();
1782   unsigned DestBits = DstVT.getSizeInBits();
1783   return (SrcBits == 64 && DestBits == 32);
1784 }
1785 
1786 bool RISCVTargetLowering::isZExtFree(SDValue Val, EVT VT2) const {
1787   // Zexts are free if they can be combined with a load.
1788   // Don't advertise i32->i64 zextload as being free for RV64. It interacts
1789   // poorly with type legalization of compares preferring sext.
1790   if (auto *LD = dyn_cast<LoadSDNode>(Val)) {
1791     EVT MemVT = LD->getMemoryVT();
1792     if ((MemVT == MVT::i8 || MemVT == MVT::i16) &&
1793         (LD->getExtensionType() == ISD::NON_EXTLOAD ||
1794          LD->getExtensionType() == ISD::ZEXTLOAD))
1795       return true;
1796   }
1797 
1798   return TargetLowering::isZExtFree(Val, VT2);
1799 }
1800 
1801 bool RISCVTargetLowering::isSExtCheaperThanZExt(EVT SrcVT, EVT DstVT) const {
1802   return Subtarget.is64Bit() && SrcVT == MVT::i32 && DstVT == MVT::i64;
1803 }
1804 
1805 bool RISCVTargetLowering::signExtendConstant(const ConstantInt *CI) const {
1806   return Subtarget.is64Bit() && CI->getType()->isIntegerTy(32);
1807 }
1808 
1809 bool RISCVTargetLowering::isCheapToSpeculateCttz(Type *Ty) const {
1810   return Subtarget.hasStdExtZbb() || Subtarget.hasVendorXCVbitmanip();
1811 }
1812 
1813 bool RISCVTargetLowering::isCheapToSpeculateCtlz(Type *Ty) const {
1814   return Subtarget.hasStdExtZbb() || Subtarget.hasVendorXTHeadBb() ||
1815          Subtarget.hasVendorXCVbitmanip();
1816 }
1817 
1818 bool RISCVTargetLowering::isMaskAndCmp0FoldingBeneficial(
1819     const Instruction &AndI) const {
1820   // We expect to be able to match a bit extraction instruction if the Zbs
1821   // extension is supported and the mask is a power of two. However, we
1822   // conservatively return false if the mask would fit in an ANDI instruction,
1823   // on the basis that it's possible the sinking+duplication of the AND in
1824   // CodeGenPrepare triggered by this hook wouldn't decrease the instruction
1825   // count and would increase code size (e.g. ANDI+BNEZ => BEXTI+BNEZ).
1826   if (!Subtarget.hasStdExtZbs() && !Subtarget.hasVendorXTHeadBs())
1827     return false;
1828   ConstantInt *Mask = dyn_cast<ConstantInt>(AndI.getOperand(1));
1829   if (!Mask)
1830     return false;
1831   return !Mask->getValue().isSignedIntN(12) && Mask->getValue().isPowerOf2();
1832 }
1833 
1834 bool RISCVTargetLowering::hasAndNotCompare(SDValue Y) const {
1835   EVT VT = Y.getValueType();
1836 
1837   // FIXME: Support vectors once we have tests.
1838   if (VT.isVector())
1839     return false;
1840 
1841   return (Subtarget.hasStdExtZbb() || Subtarget.hasStdExtZbkb()) &&
1842          !isa<ConstantSDNode>(Y);
1843 }
1844 
1845 bool RISCVTargetLowering::hasBitTest(SDValue X, SDValue Y) const {
1846   // Zbs provides BEXT[_I], which can be used with SEQZ/SNEZ as a bit test.
1847   if (Subtarget.hasStdExtZbs())
1848     return X.getValueType().isScalarInteger();
1849   auto *C = dyn_cast<ConstantSDNode>(Y);
1850   // XTheadBs provides th.tst (similar to bexti) if Y is a constant.
1851   if (Subtarget.hasVendorXTHeadBs())
1852     return C != nullptr;
1853   // We can use ANDI+SEQZ/SNEZ as a bit test. Y contains the bit position.
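       // ANDI takes a 12-bit signed immediate, so the mask (1 << Y) only fits
       // for Y <= 10 (1 << 11 is 2048, which exceeds the +2047 maximum).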
1854   return C && C->getAPIntValue().ule(10);
1855 }
1856 
1857 bool RISCVTargetLowering::shouldFoldSelectWithIdentityConstant(unsigned Opcode,
1858                                                                EVT VT) const {
1859   // Only enable for rvv.
1860   if (!VT.isVector() || !Subtarget.hasVInstructions())
1861     return false;
1862 
1863   if (VT.isFixedLengthVector() && !isTypeLegal(VT))
1864     return false;
1865 
1866   return true;
1867 }
1868 
1869 bool RISCVTargetLowering::shouldConvertConstantLoadToIntImm(const APInt &Imm,
1870                                                             Type *Ty) const {
1871   assert(Ty->isIntegerTy());
1872 
1873   unsigned BitSize = Ty->getIntegerBitWidth();
1874   if (BitSize > Subtarget.getXLen())
1875     return false;
1876 
1877   // Fast path, assume 32-bit immediates are cheap.
1878   int64_t Val = Imm.getSExtValue();
1879   if (isInt<32>(Val))
1880     return true;
1881 
1882   // A constant pool entry may be more aligned than the load we're trying to
1883   // replace. If we don't support unaligned scalar mem, prefer the constant
1884   // pool.
1885   // TODO: Can the caller pass down the alignment?
1886   if (!Subtarget.hasFastUnalignedAccess() &&
1887       !Subtarget.enableUnalignedScalarMem())
1888     return true;
1889 
1890   // Prefer to keep the load if it would require many instructions.
1891   // This uses the same threshold we use for constant pools but doesn't
1892   // check useConstantPoolForLargeInts.
1893   // TODO: Should we keep the load only when we're definitely going to emit a
1894   // constant pool?
1895 
1896   RISCVMatInt::InstSeq Seq = RISCVMatInt::generateInstSeq(Val, Subtarget);
1897   return Seq.size() <= Subtarget.getMaxBuildIntsCost();
1898 }
1899 
1900 bool RISCVTargetLowering::
1901     shouldProduceAndByConstByHoistingConstFromShiftsLHSOfAnd(
1902         SDValue X, ConstantSDNode *XC, ConstantSDNode *CC, SDValue Y,
1903         unsigned OldShiftOpcode, unsigned NewShiftOpcode,
1904         SelectionDAG &DAG) const {
1905   // One interesting pattern that we'd want to form is 'bit extract':
1906   //   ((1 >> Y) & 1) ==/!= 0
1907   // But we also need to be careful not to try to reverse that fold.
1908 
1909   // Is this '((1 >> Y) & 1)'?
1910   if (XC && OldShiftOpcode == ISD::SRL && XC->isOne())
1911     return false; // Keep the 'bit extract' pattern.
1912 
1913   // Will this be '((1 >> Y) & 1)' after the transform?
1914   if (NewShiftOpcode == ISD::SRL && CC->isOne())
1915     return true; // Do form the 'bit extract' pattern.
1916 
1917   // If 'X' is a constant, and we transform, then we will immediately
1918   // try to undo the fold, thus causing endless combine loop.
1919   // So only do the transform if X is not a constant. This matches the default
1920   // implementation of this function.
1921   return !XC;
1922 }
1923 
1924 bool RISCVTargetLowering::canSplatOperand(unsigned Opcode, int Operand) const {
1925   switch (Opcode) {
1926   case Instruction::Add:
1927   case Instruction::Sub:
1928   case Instruction::Mul:
1929   case Instruction::And:
1930   case Instruction::Or:
1931   case Instruction::Xor:
1932   case Instruction::FAdd:
1933   case Instruction::FSub:
1934   case Instruction::FMul:
1935   case Instruction::FDiv:
1936   case Instruction::ICmp:
1937   case Instruction::FCmp:
1938     return true;
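         // For shifts, division and remainder only the second operand has a
         // scalar (.vx) form, e.g. vsll.vx and vdiv.vx, so only a splat of
         // operand 1 can be folded.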
1939   case Instruction::Shl:
1940   case Instruction::LShr:
1941   case Instruction::AShr:
1942   case Instruction::UDiv:
1943   case Instruction::SDiv:
1944   case Instruction::URem:
1945   case Instruction::SRem:
1946     return Operand == 1;
1947   default:
1948     return false;
1949   }
1950 }
1951 
1952 
1953 bool RISCVTargetLowering::canSplatOperand(Instruction *I, int Operand) const {
1954   if (!I->getType()->isVectorTy() || !Subtarget.hasVInstructions())
1955     return false;
1956 
1957   if (canSplatOperand(I->getOpcode(), Operand))
1958     return true;
1959 
1960   auto *II = dyn_cast<IntrinsicInst>(I);
1961   if (!II)
1962     return false;
1963 
1964   switch (II->getIntrinsicID()) {
1965   case Intrinsic::fma:
1966   case Intrinsic::vp_fma:
1967     return Operand == 0 || Operand == 1;
1968   case Intrinsic::vp_shl:
1969   case Intrinsic::vp_lshr:
1970   case Intrinsic::vp_ashr:
1971   case Intrinsic::vp_udiv:
1972   case Intrinsic::vp_sdiv:
1973   case Intrinsic::vp_urem:
1974   case Intrinsic::vp_srem:
1975     return Operand == 1;
1976     // These intrinsics are commutative.
1977   case Intrinsic::vp_add:
1978   case Intrinsic::vp_mul:
1979   case Intrinsic::vp_and:
1980   case Intrinsic::vp_or:
1981   case Intrinsic::vp_xor:
1982   case Intrinsic::vp_fadd:
1983   case Intrinsic::vp_fmul:
1984   case Intrinsic::vp_icmp:
1985   case Intrinsic::vp_fcmp:
1986     // These intrinsics have 'vr' versions.
1987   case Intrinsic::vp_sub:
1988   case Intrinsic::vp_fsub:
1989   case Intrinsic::vp_fdiv:
1990     return Operand == 0 || Operand == 1;
1991   default:
1992     return false;
1993   }
1994 }
1995 
1996 /// Check if sinking \p I's operands to I's basic block is profitable, because
1997 /// the operands can be folded into a target instruction, e.g.
1998 /// splats of scalars can fold into vector instructions.
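     /// For example (illustrative IR), a splat such as
     ///   %ins = insertelement <4 x i32> poison, i32 %x, i64 0
     ///   %splat = shufflevector <4 x i32> %ins, <4 x i32> poison,
     ///            <4 x i32> zeroinitializer
     /// feeding an add can be sunk so the add selects to vadd.vx.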
1999 bool RISCVTargetLowering::shouldSinkOperands(
2000     Instruction *I, SmallVectorImpl<Use *> &Ops) const {
2001   using namespace llvm::PatternMatch;
2002 
2003   if (!I->getType()->isVectorTy() || !Subtarget.hasVInstructions())
2004     return false;
2005 
2006   for (auto OpIdx : enumerate(I->operands())) {
2007     if (!canSplatOperand(I, OpIdx.index()))
2008       continue;
2009 
2010     Instruction *Op = dyn_cast<Instruction>(OpIdx.value().get());
2011     // Make sure we are not already sinking this operand
2012     if (!Op || any_of(Ops, [&](Use *U) { return U->get() == Op; }))
2013       continue;
2014 
2015     // We are looking for a splat that can be sunk.
2016     if (!match(Op, m_Shuffle(m_InsertElt(m_Undef(), m_Value(), m_ZeroInt()),
2017                              m_Undef(), m_ZeroMask())))
2018       continue;
2019 
2020     // Don't sink i1 splats.
2021     if (cast<VectorType>(Op->getType())->getElementType()->isIntegerTy(1))
2022       continue;
2023 
2024     // All uses of the shuffle should be sunk to avoid duplicating it across gpr
2025     // and vector registers
2026     for (Use &U : Op->uses()) {
2027       Instruction *Insn = cast<Instruction>(U.getUser());
2028       if (!canSplatOperand(Insn, U.getOperandNo()))
2029         return false;
2030     }
2031 
2032     Ops.push_back(&Op->getOperandUse(0));
2033     Ops.push_back(&OpIdx.value());
2034   }
2035   return true;
2036 }
2037 
2038 bool RISCVTargetLowering::shouldScalarizeBinop(SDValue VecOp) const {
2039   unsigned Opc = VecOp.getOpcode();
2040 
2041   // Assume target opcodes can't be scalarized.
2042   // TODO - do we have any exceptions?
2043   if (Opc >= ISD::BUILTIN_OP_END)
2044     return false;
2045 
2046   // If the vector op is not supported, try to convert to scalar.
2047   EVT VecVT = VecOp.getValueType();
2048   if (!isOperationLegalOrCustomOrPromote(Opc, VecVT))
2049     return true;
2050 
2051   // If the vector op is supported, but the scalar op is not, the transform may
2052   // not be worthwhile.
2053   // We do permit converting a vector binary operation to a scalar binary
2054   // operation that is custom lowered with an illegal type.
2055   EVT ScalarVT = VecVT.getScalarType();
2056   return isOperationLegalOrCustomOrPromote(Opc, ScalarVT) ||
2057          isOperationCustom(Opc, ScalarVT);
2058 }
2059 
2060 bool RISCVTargetLowering::isOffsetFoldingLegal(
2061     const GlobalAddressSDNode *GA) const {
2062   // In order to maximise the opportunity for common subexpression elimination,
2063   // keep a separate ADD node for the global address offset instead of folding
2064   // it in the global address node. Later peephole optimisations may choose to
2065   // fold it back in when profitable.
2066   return false;
2067 }
2068 
2069 // Return one of the following:
2070 // (1) `{0-31 value, false}` if FLI is available for Imm's type and FP value.
2071 // (2) `{0-31 value, true}` if Imm is negative and FLI is available for its
2072 // positive counterpart, which will be materialized from the first returned
2073 // element. The second returned element indicates that an FNEG should be
2074 // emitted afterwards.
2075 // (3) `{-1, _}` if there is no way FLI can be used to materialize Imm.
2076 std::pair<int, bool> RISCVTargetLowering::getLegalZfaFPImm(const APFloat &Imm,
2077                                                            EVT VT) const {
2078   if (!Subtarget.hasStdExtZfa())
2079     return std::make_pair(-1, false);
2080 
2081   bool IsSupportedVT = false;
2082   if (VT == MVT::f16) {
2083     IsSupportedVT = Subtarget.hasStdExtZfh() || Subtarget.hasStdExtZvfh();
2084   } else if (VT == MVT::f32) {
2085     IsSupportedVT = true;
2086   } else if (VT == MVT::f64) {
2087     assert(Subtarget.hasStdExtD() && "Expect D extension");
2088     IsSupportedVT = true;
2089   }
2090 
2091   if (!IsSupportedVT)
2092     return std::make_pair(-1, false);
2093 
2094   int Index = RISCVLoadFPImm::getLoadFPImm(Imm);
2095   if (Index < 0 && Imm.isNegative())
2096     // Try the combination of its positive counterpart + FNEG.
2097     return std::make_pair(RISCVLoadFPImm::getLoadFPImm(-Imm), true);
2098   else
2099     return std::make_pair(Index, false);
2100 }
2101 
2102 bool RISCVTargetLowering::isFPImmLegal(const APFloat &Imm, EVT VT,
2103                                        bool ForCodeSize) const {
2104   bool IsLegalVT = false;
2105   if (VT == MVT::f16)
2106     IsLegalVT = Subtarget.hasStdExtZfhminOrZhinxmin();
2107   else if (VT == MVT::f32)
2108     IsLegalVT = Subtarget.hasStdExtFOrZfinx();
2109   else if (VT == MVT::f64)
2110     IsLegalVT = Subtarget.hasStdExtDOrZdinx();
2111   else if (VT == MVT::bf16)
2112     IsLegalVT = Subtarget.hasStdExtZfbfmin();
2113 
2114   if (!IsLegalVT)
2115     return false;
2116 
2117   if (getLegalZfaFPImm(Imm, VT).first >= 0)
2118     return true;
2119 
2120   // Cannot create a 64 bit floating-point immediate value for rv32.
2121   if (Subtarget.getXLen() < VT.getScalarSizeInBits()) {
2122     // td can handle +0.0 or -0.0 already.
2123     // -0.0 can be created by fmv + fneg.
2124     return Imm.isZero();
2125   }
2126 
2127   // Special case: fmv + fneg
2128   if (Imm.isNegZero())
2129     return true;
2130 
2131   // Building an integer and then converting requires a fmv at the end of
2132   // the integer sequence.
2133   const int Cost =
2134       1 + RISCVMatInt::getIntMatCost(Imm.bitcastToAPInt(), Subtarget.getXLen(),
2135                                      Subtarget);
2136   return Cost <= FPImmCost;
2137 }
2138 
2139 // TODO: This is very conservative.
2140 bool RISCVTargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
2141                                                   unsigned Index) const {
2142   if (!isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, ResVT))
2143     return false;
2144 
2145   // Only support extracting a fixed from a fixed vector for now.
2146   if (ResVT.isScalableVector() || SrcVT.isScalableVector())
2147     return false;
2148 
2149   unsigned ResElts = ResVT.getVectorNumElements();
2150   unsigned SrcElts = SrcVT.getVectorNumElements();
2151 
2152   // Conservatively only handle extracting half of a vector.
2153   // TODO: Relax this.
2154   if ((ResElts * 2) != SrcElts)
2155     return false;
2156 
2157   // The smallest type we can slide is i8.
2158   // TODO: We can extract index 0 from a mask vector without a slide.
2159   if (ResVT.getVectorElementType() == MVT::i1)
2160     return false;
2161 
2162   // Slide can support arbitrary index, but we only treat vslidedown.vi as
2163   // cheap.
2164   if (Index >= 32)
2165     return false;
2166 
2167   // TODO: We can do arbitrary slidedowns, but for now only support extracting
2168   // the upper half of a vector until we have more test coverage.
2169   return Index == 0 || Index == ResElts;
2170 }
2171 
2172 MVT RISCVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
2173                                                       CallingConv::ID CC,
2174                                                       EVT VT) const {
2175   // Use f32 to pass f16 if it is legal and Zfh/Zfhmin is not enabled.
2176   // We might still end up using a GPR but that will be decided based on ABI.
2177   if (VT == MVT::f16 && Subtarget.hasStdExtFOrZfinx() &&
2178       !Subtarget.hasStdExtZfhminOrZhinxmin())
2179     return MVT::f32;
2180 
2181   MVT PartVT = TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT);
2182 
2183   if (RV64LegalI32 && Subtarget.is64Bit() && PartVT == MVT::i32)
2184     return MVT::i64;
2185 
2186   return PartVT;
2187 }
2188 
2189 unsigned RISCVTargetLowering::getNumRegistersForCallingConv(LLVMContext &Context,
2190                                                            CallingConv::ID CC,
2191                                                            EVT VT) const {
2192   // Use f32 to pass f16 if it is legal and Zfh/Zfhmin is not enabled.
2193   // We might still end up using a GPR but that will be decided based on ABI.
2194   if (VT == MVT::f16 && Subtarget.hasStdExtFOrZfinx() &&
2195       !Subtarget.hasStdExtZfhminOrZhinxmin())
2196     return 1;
2197 
2198   return TargetLowering::getNumRegistersForCallingConv(Context, CC, VT);
2199 }
2200 
2201 unsigned RISCVTargetLowering::getVectorTypeBreakdownForCallingConv(
2202     LLVMContext &Context, CallingConv::ID CC, EVT VT, EVT &IntermediateVT,
2203     unsigned &NumIntermediates, MVT &RegisterVT) const {
2204   unsigned NumRegs = TargetLowering::getVectorTypeBreakdownForCallingConv(
2205       Context, CC, VT, IntermediateVT, NumIntermediates, RegisterVT);
2206 
2207   if (RV64LegalI32 && Subtarget.is64Bit() && IntermediateVT == MVT::i32)
2208     IntermediateVT = MVT::i64;
2209 
2210   if (RV64LegalI32 && Subtarget.is64Bit() && RegisterVT == MVT::i32)
2211     RegisterVT = MVT::i64;
2212 
2213   return NumRegs;
2214 }
2215 
2216 // Changes the condition code and swaps operands if necessary, so the SetCC
2217 // operation matches one of the comparisons supported directly by branches
2218 // in the RISC-V ISA. May adjust compares to favor compare with 0 over compare
2219 // with 1/-1.
2220 static void translateSetCCForBranch(const SDLoc &DL, SDValue &LHS, SDValue &RHS,
2221                                     ISD::CondCode &CC, SelectionDAG &DAG) {
2222   // If this is a single bit test that can't be handled by ANDI, shift the
2223   // bit to be tested to the MSB and perform a signed compare with 0.
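       // For example (illustrative, RV64): (X & 0x8000) == 0 has a mask that
       // doesn't fit in 12 bits, so it becomes (X << 48) >= 0 (signed), which a
       // branch such as bgez can test directly.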
2224   if (isIntEqualitySetCC(CC) && isNullConstant(RHS) &&
2225       LHS.getOpcode() == ISD::AND && LHS.hasOneUse() &&
2226       isa<ConstantSDNode>(LHS.getOperand(1))) {
2227     uint64_t Mask = LHS.getConstantOperandVal(1);
2228     if ((isPowerOf2_64(Mask) || isMask_64(Mask)) && !isInt<12>(Mask)) {
2229       unsigned ShAmt = 0;
2230       if (isPowerOf2_64(Mask)) {
2231         CC = CC == ISD::SETEQ ? ISD::SETGE : ISD::SETLT;
2232         ShAmt = LHS.getValueSizeInBits() - 1 - Log2_64(Mask);
2233       } else {
2234         ShAmt = LHS.getValueSizeInBits() - llvm::bit_width(Mask);
2235       }
2236 
2237       LHS = LHS.getOperand(0);
2238       if (ShAmt != 0)
2239         LHS = DAG.getNode(ISD::SHL, DL, LHS.getValueType(), LHS,
2240                           DAG.getConstant(ShAmt, DL, LHS.getValueType()));
2241       return;
2242     }
2243   }
2244 
2245   if (auto *RHSC = dyn_cast<ConstantSDNode>(RHS)) {
2246     int64_t C = RHSC->getSExtValue();
2247     switch (CC) {
2248     default: break;
2249     case ISD::SETGT:
2250       // Convert X > -1 to X >= 0.
2251       if (C == -1) {
2252         RHS = DAG.getConstant(0, DL, RHS.getValueType());
2253         CC = ISD::SETGE;
2254         return;
2255       }
2256       break;
2257     case ISD::SETLT:
2258       // Convert X < 1 to 0 >= X.
2259       if (C == 1) {
2260         RHS = LHS;
2261         LHS = DAG.getConstant(0, DL, RHS.getValueType());
2262         CC = ISD::SETGE;
2263         return;
2264       }
2265       break;
2266     }
2267   }
2268 
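       // The remaining codes are canonicalized by swapping operands so the
       // branch can use BLT/BGE (and their unsigned forms), e.g. x > y is
       // rewritten as y < x.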
2269   switch (CC) {
2270   default:
2271     break;
2272   case ISD::SETGT:
2273   case ISD::SETLE:
2274   case ISD::SETUGT:
2275   case ISD::SETULE:
2276     CC = ISD::getSetCCSwappedOperands(CC);
2277     std::swap(LHS, RHS);
2278     break;
2279   }
2280 }
2281 
2282 RISCVII::VLMUL RISCVTargetLowering::getLMUL(MVT VT) {
2283   assert(VT.isScalableVector() && "Expecting a scalable vector type");
2284   unsigned KnownSize = VT.getSizeInBits().getKnownMinValue();
2285   if (VT.getVectorElementType() == MVT::i1)
2286     KnownSize *= 8;
2287 
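       // Illustrative mapping (with RVVBitsPerBlock == 64): nxv1i32 has a known
       // minimum size of 32 bits and maps to LMUL_F2, nxv2i32 (64 bits) maps to
       // LMUL_1, and nxv8i1 is scaled by 8 to 64 bits, so it also maps to LMUL_1.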
2288   switch (KnownSize) {
2289   default:
2290     llvm_unreachable("Invalid LMUL.");
2291   case 8:
2292     return RISCVII::VLMUL::LMUL_F8;
2293   case 16:
2294     return RISCVII::VLMUL::LMUL_F4;
2295   case 32:
2296     return RISCVII::VLMUL::LMUL_F2;
2297   case 64:
2298     return RISCVII::VLMUL::LMUL_1;
2299   case 128:
2300     return RISCVII::VLMUL::LMUL_2;
2301   case 256:
2302     return RISCVII::VLMUL::LMUL_4;
2303   case 512:
2304     return RISCVII::VLMUL::LMUL_8;
2305   }
2306 }
2307 
2308 unsigned RISCVTargetLowering::getRegClassIDForLMUL(RISCVII::VLMUL LMul) {
2309   switch (LMul) {
2310   default:
2311     llvm_unreachable("Invalid LMUL.");
2312   case RISCVII::VLMUL::LMUL_F8:
2313   case RISCVII::VLMUL::LMUL_F4:
2314   case RISCVII::VLMUL::LMUL_F2:
2315   case RISCVII::VLMUL::LMUL_1:
2316     return RISCV::VRRegClassID;
2317   case RISCVII::VLMUL::LMUL_2:
2318     return RISCV::VRM2RegClassID;
2319   case RISCVII::VLMUL::LMUL_4:
2320     return RISCV::VRM4RegClassID;
2321   case RISCVII::VLMUL::LMUL_8:
2322     return RISCV::VRM8RegClassID;
2323   }
2324 }
2325 
2326 unsigned RISCVTargetLowering::getSubregIndexByMVT(MVT VT, unsigned Index) {
2327   RISCVII::VLMUL LMUL = getLMUL(VT);
2328   if (LMUL == RISCVII::VLMUL::LMUL_F8 ||
2329       LMUL == RISCVII::VLMUL::LMUL_F4 ||
2330       LMUL == RISCVII::VLMUL::LMUL_F2 ||
2331       LMUL == RISCVII::VLMUL::LMUL_1) {
2332     static_assert(RISCV::sub_vrm1_7 == RISCV::sub_vrm1_0 + 7,
2333                   "Unexpected subreg numbering");
2334     return RISCV::sub_vrm1_0 + Index;
2335   }
2336   if (LMUL == RISCVII::VLMUL::LMUL_2) {
2337     static_assert(RISCV::sub_vrm2_3 == RISCV::sub_vrm2_0 + 3,
2338                   "Unexpected subreg numbering");
2339     return RISCV::sub_vrm2_0 + Index;
2340   }
2341   if (LMUL == RISCVII::VLMUL::LMUL_4) {
2342     static_assert(RISCV::sub_vrm4_1 == RISCV::sub_vrm4_0 + 1,
2343                   "Unexpected subreg numbering");
2344     return RISCV::sub_vrm4_0 + Index;
2345   }
2346   llvm_unreachable("Invalid vector type.");
2347 }
2348 
2349 unsigned RISCVTargetLowering::getRegClassIDForVecVT(MVT VT) {
2350   if (VT.getVectorElementType() == MVT::i1)
2351     return RISCV::VRRegClassID;
2352   return getRegClassIDForLMUL(getLMUL(VT));
2353 }
2354 
2355 // Attempt to decompose a subvector insert/extract between VecVT and
2356 // SubVecVT via subregister indices. Returns the subregister index that
2357 // can perform the subvector insert/extract with the given element index, as
2358 // well as the index corresponding to any leftover subvectors that must be
2359 // further inserted/extracted within the register class for SubVecVT.
2360 std::pair<unsigned, unsigned>
2361 RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
2362     MVT VecVT, MVT SubVecVT, unsigned InsertExtractIdx,
2363     const RISCVRegisterInfo *TRI) {
2364   static_assert((RISCV::VRM8RegClassID > RISCV::VRM4RegClassID &&
2365                  RISCV::VRM4RegClassID > RISCV::VRM2RegClassID &&
2366                  RISCV::VRM2RegClassID > RISCV::VRRegClassID),
2367                 "Register classes not ordered");
2368   unsigned VecRegClassID = getRegClassIDForVecVT(VecVT);
2369   unsigned SubRegClassID = getRegClassIDForVecVT(SubVecVT);
2370   // Try to compose a subregister index that takes us from the incoming
2371   // LMUL>1 register class down to the outgoing one. At each step we halve
2372   // the LMUL:
2373   //   nxv16i32@12 -> nxv2i32: sub_vrm4_1_then_sub_vrm2_1_then_sub_vrm1_0
2374   // Note that this is not guaranteed to find a subregister index, such as
2375   // when we are extracting from one VR type to another.
2376   unsigned SubRegIdx = RISCV::NoSubRegister;
2377   for (const unsigned RCID :
2378        {RISCV::VRM4RegClassID, RISCV::VRM2RegClassID, RISCV::VRRegClassID})
2379     if (VecRegClassID > RCID && SubRegClassID <= RCID) {
2380       VecVT = VecVT.getHalfNumVectorElementsVT();
2381       bool IsHi =
2382           InsertExtractIdx >= VecVT.getVectorElementCount().getKnownMinValue();
2383       SubRegIdx = TRI->composeSubRegIndices(SubRegIdx,
2384                                             getSubregIndexByMVT(VecVT, IsHi));
2385       if (IsHi)
2386         InsertExtractIdx -= VecVT.getVectorElementCount().getKnownMinValue();
2387     }
2388   return {SubRegIdx, InsertExtractIdx};
2389 }
2390 
2391 // Permit combining of mask vectors as BUILD_VECTOR never expands to scalar
2392 // stores for those types.
2393 bool RISCVTargetLowering::mergeStoresAfterLegalization(EVT VT) const {
2394   return !Subtarget.useRVVForFixedLengthVectors() ||
2395          (VT.isFixedLengthVector() && VT.getVectorElementType() == MVT::i1);
2396 }
2397 
2398 bool RISCVTargetLowering::isLegalElementTypeForRVV(EVT ScalarTy) const {
2399   if (!ScalarTy.isSimple())
2400     return false;
2401   switch (ScalarTy.getSimpleVT().SimpleTy) {
2402   case MVT::iPTR:
2403     return Subtarget.is64Bit() ? Subtarget.hasVInstructionsI64() : true;
2404   case MVT::i8:
2405   case MVT::i16:
2406   case MVT::i32:
2407     return true;
2408   case MVT::i64:
2409     return Subtarget.hasVInstructionsI64();
2410   case MVT::f16:
2411     return Subtarget.hasVInstructionsF16();
2412   case MVT::f32:
2413     return Subtarget.hasVInstructionsF32();
2414   case MVT::f64:
2415     return Subtarget.hasVInstructionsF64();
2416   default:
2417     return false;
2418   }
2419 }
2420 
2421 
2422 unsigned RISCVTargetLowering::combineRepeatedFPDivisors() const {
2423   return NumRepeatedDivisors;
2424 }
2425 
2426 static SDValue getVLOperand(SDValue Op) {
2427   assert((Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
2428           Op.getOpcode() == ISD::INTRINSIC_W_CHAIN) &&
2429          "Unexpected opcode");
2430   bool HasChain = Op.getOpcode() == ISD::INTRINSIC_W_CHAIN;
2431   unsigned IntNo = Op.getConstantOperandVal(HasChain ? 1 : 0);
2432   const RISCVVIntrinsicsTable::RISCVVIntrinsicInfo *II =
2433       RISCVVIntrinsicsTable::getRISCVVIntrinsicInfo(IntNo);
2434   if (!II)
2435     return SDValue();
2436   return Op.getOperand(II->VLOperand + 1 + HasChain);
2437 }
2438 
2439 static bool useRVVForFixedLengthVectorVT(MVT VT,
2440                                          const RISCVSubtarget &Subtarget) {
2441   assert(VT.isFixedLengthVector() && "Expected a fixed length vector type!");
2442   if (!Subtarget.useRVVForFixedLengthVectors())
2443     return false;
2444 
2445   // We only support a set of vector types with a consistent maximum fixed size
2446   // across all supported vector element types to avoid legalization issues.
2447   // Therefore -- since the largest is v1024i8/v512i16/etc -- the largest
2448   // fixed-length vector type we support is 1024 bytes.
2449   if (VT.getFixedSizeInBits() > 1024 * 8)
2450     return false;
2451 
2452   unsigned MinVLen = Subtarget.getRealMinVLen();
2453 
2454   MVT EltVT = VT.getVectorElementType();
2455 
2456   // Don't use RVV for vectors we cannot scalarize if required.
2457   switch (EltVT.SimpleTy) {
2458   // i1 is supported but has different rules.
2459   default:
2460     return false;
2461   case MVT::i1:
2462     // Masks can only use a single register.
2463     if (VT.getVectorNumElements() > MinVLen)
2464       return false;
2465     MinVLen /= 8;
2466     break;
2467   case MVT::i8:
2468   case MVT::i16:
2469   case MVT::i32:
2470     break;
2471   case MVT::i64:
2472     if (!Subtarget.hasVInstructionsI64())
2473       return false;
2474     break;
2475   case MVT::f16:
2476     if (!Subtarget.hasVInstructionsF16Minimal())
2477       return false;
2478     break;
2479   case MVT::f32:
2480     if (!Subtarget.hasVInstructionsF32())
2481       return false;
2482     break;
2483   case MVT::f64:
2484     if (!Subtarget.hasVInstructionsF64())
2485       return false;
2486     break;
2487   }
2488 
2489   // Reject elements larger than ELEN.
2490   if (EltVT.getSizeInBits() > Subtarget.getELen())
2491     return false;
2492 
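       // e.g. with MinVLen == 128, a 512-bit v16i32 requires
       // LMul = ceil(512 / 128) = 4.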
2493   unsigned LMul = divideCeil(VT.getSizeInBits(), MinVLen);
2494   // Don't use RVV for types that don't fit.
2495   if (LMul > Subtarget.getMaxLMULForFixedLengthVectors())
2496     return false;
2497 
2498   // TODO: Perhaps an artificial restriction, but worth having whilst getting
2499   // the base fixed length RVV support in place.
2500   if (!VT.isPow2VectorType())
2501     return false;
2502 
2503   return true;
2504 }
2505 
2506 bool RISCVTargetLowering::useRVVForFixedLengthVectorVT(MVT VT) const {
2507   return ::useRVVForFixedLengthVectorVT(VT, Subtarget);
2508 }
2509 
2510 // Return the scalable container type to use for a fixed-length vector VT.
2511 static MVT getContainerForFixedLengthVector(const TargetLowering &TLI, MVT VT,
2512                                             const RISCVSubtarget &Subtarget) {
2513   // This may be called before legal types are setup.
2514   assert(((VT.isFixedLengthVector() && TLI.isTypeLegal(VT)) ||
2515           useRVVForFixedLengthVectorVT(VT, Subtarget)) &&
2516          "Expected legal fixed length vector!");
2517 
2518   unsigned MinVLen = Subtarget.getRealMinVLen();
2519   unsigned MaxELen = Subtarget.getELen();
2520 
2521   MVT EltVT = VT.getVectorElementType();
2522   switch (EltVT.SimpleTy) {
2523   default:
2524     llvm_unreachable("unexpected element type for RVV container");
2525   case MVT::i1:
2526   case MVT::i8:
2527   case MVT::i16:
2528   case MVT::i32:
2529   case MVT::i64:
2530   case MVT::f16:
2531   case MVT::f32:
2532   case MVT::f64: {
2533     // We prefer to use LMUL=1 for VLEN sized types. Use fractional lmuls for
2534     // narrower types. The smallest fractional LMUL we support is 8/ELEN. Within
2535     // each fractional LMUL we support SEW between 8 and LMUL*ELEN.
2536     unsigned NumElts =
2537         (VT.getVectorNumElements() * RISCV::RVVBitsPerBlock) / MinVLen;
2538     NumElts = std::max(NumElts, RISCV::RVVBitsPerBlock / MaxELen);
2539     assert(isPowerOf2_32(NumElts) && "Expected power of 2 NumElts");
2540     return MVT::getScalableVectorVT(EltVT, NumElts);
2541   }
2542   }
2543 }
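
// Worked example (a sketch, assuming RVVBitsPerBlock == 64, VLEN >= 128 and
// ELEN == 64): v4i32 computes (4 * 64) / 128 = 2 elements, i.e. nxv2i32, the
// LMUL=1 container for SEW=32.  v2i8 computes (2 * 64) / 128 = 1, which
// already satisfies the 64 / 64 = 1 lower bound, giving the fractional-LMUL
// container nxv1i8.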

static MVT getContainerForFixedLengthVector(SelectionDAG &DAG, MVT VT,
                                            const RISCVSubtarget &Subtarget) {
  return getContainerForFixedLengthVector(DAG.getTargetLoweringInfo(), VT,
                                          Subtarget);
}

MVT RISCVTargetLowering::getContainerForFixedLengthVector(MVT VT) const {
  return ::getContainerForFixedLengthVector(*this, VT, getSubtarget());
}

// Grow V to consume an entire RVV register.
static SDValue convertToScalableVector(EVT VT, SDValue V, SelectionDAG &DAG,
                                       const RISCVSubtarget &Subtarget) {
  assert(VT.isScalableVector() &&
         "Expected to convert into a scalable vector!");
  assert(V.getValueType().isFixedLengthVector() &&
         "Expected a fixed length vector operand!");
  SDLoc DL(V);
  SDValue Zero = DAG.getConstant(0, DL, Subtarget.getXLenVT());
  return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, DAG.getUNDEF(VT), V, Zero);
}

// Shrink V so it's just big enough to maintain a VT's worth of data.
static SDValue convertFromScalableVector(EVT VT, SDValue V, SelectionDAG &DAG,
                                         const RISCVSubtarget &Subtarget) {
  assert(VT.isFixedLengthVector() &&
         "Expected to convert into a fixed length vector!");
  assert(V.getValueType().isScalableVector() &&
         "Expected a scalable vector operand!");
  SDLoc DL(V);
  SDValue Zero = DAG.getConstant(0, DL, Subtarget.getXLenVT());
  return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, V, Zero);
}

/// Return the mask type suitable for masking the provided vector type.  This
/// is simply an i1 element type vector of the same (possibly scalable) length.
static MVT getMaskTypeFor(MVT VecVT) {
  assert(VecVT.isVector());
  ElementCount EC = VecVT.getVectorElementCount();
  return MVT::getVectorVT(MVT::i1, EC);
}

/// Creates an all-ones mask suitable for masking a vector of type VecVT with
/// vector length VL.
static SDValue getAllOnesMask(MVT VecVT, SDValue VL, const SDLoc &DL,
                              SelectionDAG &DAG) {
  MVT MaskVT = getMaskTypeFor(VecVT);
  return DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL);
}

static SDValue getVLOp(uint64_t NumElts, MVT ContainerVT, const SDLoc &DL,
                       SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
  // If we know the exact VLEN, and our VL is exactly equal to VLMAX,
  // canonicalize the representation.  InsertVSETVLI will pick the immediate
  // encoding later if profitable.
  const auto [MinVLMAX, MaxVLMAX] =
      RISCVTargetLowering::computeVLMAXBounds(ContainerVT, Subtarget);
  if (MinVLMAX == MaxVLMAX && NumElts == MinVLMAX)
    return DAG.getRegister(RISCV::X0, Subtarget.getXLenVT());

  return DAG.getConstant(NumElts, DL, Subtarget.getXLenVT());
}
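
// Illustrative case (a sketch, assuming an exactly-known VLEN of 128): for a
// ContainerVT of nxv2i32, VLMAX = 1 * 128 / 32 = 4, so a request for
// NumElts == 4 is emitted as the X0/VLMAX form rather than the literal 4;
// any other NumElts is emitted as a plain constant.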

static std::pair<SDValue, SDValue>
getDefaultScalableVLOps(MVT VecVT, const SDLoc &DL, SelectionDAG &DAG,
                        const RISCVSubtarget &Subtarget) {
  assert(VecVT.isScalableVector() && "Expecting a scalable vector");
  SDValue VL = DAG.getRegister(RISCV::X0, Subtarget.getXLenVT());
  SDValue Mask = getAllOnesMask(VecVT, VL, DL, DAG);
  return {Mask, VL};
}

static std::pair<SDValue, SDValue>
getDefaultVLOps(uint64_t NumElts, MVT ContainerVT, const SDLoc &DL,
                SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
  assert(ContainerVT.isScalableVector() && "Expecting scalable container type");
  SDValue VL = getVLOp(NumElts, ContainerVT, DL, DAG, Subtarget);
  SDValue Mask = getAllOnesMask(ContainerVT, VL, DL, DAG);
  return {Mask, VL};
}

// Gets the two common "VL" operands: an all-ones mask and the vector length.
// VecVT is a vector type, either fixed-length or scalable. If VecVT is
// fixed-length, ContainerVT is the scalable container type it is lowered in;
// if VecVT is scalable, ContainerVT should be the same as VecVT.
static std::pair<SDValue, SDValue>
getDefaultVLOps(MVT VecVT, MVT ContainerVT, const SDLoc &DL, SelectionDAG &DAG,
                const RISCVSubtarget &Subtarget) {
  if (VecVT.isFixedLengthVector())
    return getDefaultVLOps(VecVT.getVectorNumElements(), ContainerVT, DL, DAG,
                           Subtarget);
  assert(ContainerVT.isScalableVector() && "Expecting scalable container type");
  return getDefaultScalableVLOps(ContainerVT, DL, DAG, Subtarget);
}

SDValue RISCVTargetLowering::computeVLMax(MVT VecVT, const SDLoc &DL,
                                          SelectionDAG &DAG) const {
  assert(VecVT.isScalableVector() && "Expected scalable vector");
  return DAG.getElementCount(DL, Subtarget.getXLenVT(),
                             VecVT.getVectorElementCount());
}

std::pair<unsigned, unsigned>
RISCVTargetLowering::computeVLMAXBounds(MVT VecVT,
                                        const RISCVSubtarget &Subtarget) {
  assert(VecVT.isScalableVector() && "Expected scalable vector");

  unsigned EltSize = VecVT.getScalarSizeInBits();
  unsigned MinSize = VecVT.getSizeInBits().getKnownMinValue();

  unsigned VectorBitsMax = Subtarget.getRealMaxVLen();
  unsigned MaxVLMAX =
      RISCVTargetLowering::computeVLMAX(VectorBitsMax, EltSize, MinSize);

  unsigned VectorBitsMin = Subtarget.getRealMinVLen();
  unsigned MinVLMAX =
      RISCVTargetLowering::computeVLMAX(VectorBitsMin, EltSize, MinSize);

  return std::make_pair(MinVLMAX, MaxVLMAX);
}
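
// Example bounds (a sketch, assuming the implemented VLEN range is known to
// be 128..512 bits): for VecVT == nxv4i32, i.e. LMUL=2 at SEW=32, VLMAX is
// LMUL * VLEN / SEW, so this returns {2 * 128 / 32, 2 * 512 / 32} == {8, 32}.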

// The state of RVV BUILD_VECTOR and VECTOR_SHUFFLE lowering is that very
// little of either is (currently) supported. This can get us into an infinite
// loop where we try to lower a BUILD_VECTOR as a VECTOR_SHUFFLE as a
// BUILD_VECTOR as a ..., etc.
// Until either (or both) of these can reliably lower any node, reporting that
// we don't want to expand BUILD_VECTORs via VECTOR_SHUFFLEs at least breaks
// the infinite loop. Note that this lowers BUILD_VECTOR through the stack,
// which is not desirable.
bool RISCVTargetLowering::shouldExpandBuildVectorWithShuffles(
    EVT VT, unsigned DefinedValues) const {
  return false;
}

InstructionCost RISCVTargetLowering::getLMULCost(MVT VT) const {
  // TODO: Here we assume a reciprocal throughput of 1 for LMUL_1; it is
  // implementation-defined.
  if (!VT.isVector())
    return InstructionCost::getInvalid();
  unsigned DLenFactor = Subtarget.getDLenFactor();
  unsigned Cost;
  if (VT.isScalableVector()) {
    unsigned LMul;
    bool Fractional;
    std::tie(LMul, Fractional) =
        RISCVVType::decodeVLMUL(RISCVTargetLowering::getLMUL(VT));
    if (Fractional)
      Cost = LMul <= DLenFactor ? (DLenFactor / LMul) : 1;
    else
      Cost = (LMul * DLenFactor);
  } else {
    Cost = divideCeil(VT.getSizeInBits(),
                      Subtarget.getRealMinVLen() / DLenFactor);
  }
  return Cost;
}
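
// Rough cost intuition (an illustration, assuming DLEN == VLEN, i.e.
// DLenFactor == 1): an LMUL=4 type such as nxv8i32 costs 4, a fractional
// LMUL=1/2 type such as nxv1i32 costs 1, and a fixed-length v16i32 with a
// minimum VLEN of 128 costs divideCeil(512, 128) = 4.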

/// Return the cost of a vrgather.vv instruction for the type VT.  vrgather.vv
/// is generally quadratic in the number of vregs implied by LMUL.  Note that
/// the operands (index and possibly mask) are handled separately.
InstructionCost RISCVTargetLowering::getVRGatherVVCost(MVT VT) const {
  return getLMULCost(VT) * getLMULCost(VT);
}

/// Return the cost of a vrgather.vi (or vx) instruction for the type VT.
/// vrgather.vi/vx may be linear in the number of vregs implied by LMUL,
/// or may track the vrgather.vv cost. It is implementation-dependent.
InstructionCost RISCVTargetLowering::getVRGatherVICost(MVT VT) const {
  return getLMULCost(VT);
}

/// Return the cost of a vslidedown.vx or vslideup.vx instruction
/// for the type VT.  (This does not cover the vslide1up or vslide1down
/// variants.)  Slides may be linear in the number of vregs implied by LMUL,
/// or may track the vrgather.vv cost. It is implementation-dependent.
InstructionCost RISCVTargetLowering::getVSlideVXCost(MVT VT) const {
  return getLMULCost(VT);
}

/// Return the cost of a vslidedown.vi or vslideup.vi instruction
/// for the type VT.  (This does not cover the vslide1up or vslide1down
/// variants.)  Slides may be linear in the number of vregs implied by LMUL,
/// or may track the vrgather.vv cost. It is implementation-dependent.
InstructionCost RISCVTargetLowering::getVSlideVICost(MVT VT) const {
  return getLMULCost(VT);
}

static SDValue lowerFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG,
                                  const RISCVSubtarget &Subtarget) {
  // RISC-V FP-to-int conversions saturate to the destination register size,
  // but don't produce 0 for NaN. We can use a conversion instruction and fix
  // the NaN case with a compare and a select.
  SDValue Src = Op.getOperand(0);

  MVT DstVT = Op.getSimpleValueType();
  EVT SatVT = cast<VTSDNode>(Op.getOperand(1))->getVT();

  bool IsSigned = Op.getOpcode() == ISD::FP_TO_SINT_SAT;

  if (!DstVT.isVector()) {
    // For bf16, or for f16 in the absence of Zfh, promote to f32, then
    // saturate the result.
    if ((Src.getValueType() == MVT::f16 && !Subtarget.hasStdExtZfhOrZhinx()) ||
        Src.getValueType() == MVT::bf16) {
      Src = DAG.getNode(ISD::FP_EXTEND, SDLoc(Op), MVT::f32, Src);
    }

    unsigned Opc;
    if (SatVT == DstVT)
      Opc = IsSigned ? RISCVISD::FCVT_X : RISCVISD::FCVT_XU;
    else if (DstVT == MVT::i64 && SatVT == MVT::i32)
      Opc = IsSigned ? RISCVISD::FCVT_W_RV64 : RISCVISD::FCVT_WU_RV64;
    else
      return SDValue();
    // FIXME: Support other SatVTs by clamping before or after the conversion.

    SDLoc DL(Op);
    SDValue FpToInt = DAG.getNode(
        Opc, DL, DstVT, Src,
        DAG.getTargetConstant(RISCVFPRndMode::RTZ, DL, Subtarget.getXLenVT()));

    if (Opc == RISCVISD::FCVT_WU_RV64)
      FpToInt = DAG.getZeroExtendInReg(FpToInt, DL, MVT::i32);

    SDValue ZeroInt = DAG.getConstant(0, DL, DstVT);
    return DAG.getSelectCC(DL, Src, Src, ZeroInt, FpToInt,
                           ISD::CondCode::SETUO);
  }

  // Vectors.

  MVT DstEltVT = DstVT.getVectorElementType();
  MVT SrcVT = Src.getSimpleValueType();
  MVT SrcEltVT = SrcVT.getVectorElementType();
  unsigned SrcEltSize = SrcEltVT.getSizeInBits();
  unsigned DstEltSize = DstEltVT.getSizeInBits();

  // Only handle saturating to the destination type.
  if (SatVT != DstEltVT)
    return SDValue();

  // FIXME: Don't support narrowing by more than 1 step for now.
  if (SrcEltSize > (2 * DstEltSize))
    return SDValue();

  MVT DstContainerVT = DstVT;
  MVT SrcContainerVT = SrcVT;
  if (DstVT.isFixedLengthVector()) {
    DstContainerVT = getContainerForFixedLengthVector(DAG, DstVT, Subtarget);
    SrcContainerVT = getContainerForFixedLengthVector(DAG, SrcVT, Subtarget);
    assert(DstContainerVT.getVectorElementCount() ==
               SrcContainerVT.getVectorElementCount() &&
           "Expected same element count");
    Src = convertToScalableVector(SrcContainerVT, Src, DAG, Subtarget);
  }

  SDLoc DL(Op);

  auto [Mask, VL] = getDefaultVLOps(DstVT, DstContainerVT, DL, DAG, Subtarget);

  SDValue IsNan = DAG.getNode(RISCVISD::SETCC_VL, DL, Mask.getValueType(),
                              {Src, Src, DAG.getCondCode(ISD::SETNE),
                               DAG.getUNDEF(Mask.getValueType()), Mask, VL});

  // If we need to widen by more than 1 step, promote the FP type, then do a
  // widening convert.
  if (DstEltSize > (2 * SrcEltSize)) {
    assert(SrcContainerVT.getVectorElementType() == MVT::f16 &&
           "Unexpected VT!");
    MVT InterVT = SrcContainerVT.changeVectorElementType(MVT::f32);
    Src = DAG.getNode(RISCVISD::FP_EXTEND_VL, DL, InterVT, Src, Mask, VL);
  }

  unsigned RVVOpc =
      IsSigned ? RISCVISD::VFCVT_RTZ_X_F_VL : RISCVISD::VFCVT_RTZ_XU_F_VL;
  SDValue Res = DAG.getNode(RVVOpc, DL, DstContainerVT, Src, Mask, VL);

  SDValue SplatZero = DAG.getNode(
      RISCVISD::VMV_V_X_VL, DL, DstContainerVT, DAG.getUNDEF(DstContainerVT),
      DAG.getConstant(0, DL, Subtarget.getXLenVT()), VL);
  Res = DAG.getNode(RISCVISD::VMERGE_VL, DL, DstContainerVT, IsNan, SplatZero,
                    Res, DAG.getUNDEF(DstContainerVT), VL);

  if (DstVT.isFixedLengthVector())
    Res = convertFromScalableVector(DstVT, Res, DAG, Subtarget);

  return Res;
}
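
// For the scalar path, the select-on-unordered pattern typically folds into a
// short branchless sequence; a plausible result for a signed f32 -> i32
// saturating conversion (illustrative only, actual scheduling may differ) is:
//   fcvt.w.s a0, fa0, rtz   ; saturating convert with RTZ rounding
//   feq.s    a1, fa0, fa0   ; a1 = 0 if fa0 is NaN, 1 otherwise
//   neg      a1, a1         ; all-ones mask when ordered, zero when NaN
//   and      a0, a0, a1     ; force the NaN case to 0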

static RISCVFPRndMode::RoundingMode matchRoundingOp(unsigned Opc) {
  switch (Opc) {
  case ISD::FROUNDEVEN:
  case ISD::STRICT_FROUNDEVEN:
  case ISD::VP_FROUNDEVEN:
    return RISCVFPRndMode::RNE;
  case ISD::FTRUNC:
  case ISD::STRICT_FTRUNC:
  case ISD::VP_FROUNDTOZERO:
    return RISCVFPRndMode::RTZ;
  case ISD::FFLOOR:
  case ISD::STRICT_FFLOOR:
  case ISD::VP_FFLOOR:
    return RISCVFPRndMode::RDN;
  case ISD::FCEIL:
  case ISD::STRICT_FCEIL:
  case ISD::VP_FCEIL:
    return RISCVFPRndMode::RUP;
  case ISD::FROUND:
  case ISD::STRICT_FROUND:
  case ISD::VP_FROUND:
    return RISCVFPRndMode::RMM;
  case ISD::FRINT:
    return RISCVFPRndMode::DYN;
  }

  return RISCVFPRndMode::Invalid;
}
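
// Example mapping (illustrative): ISD::FFLOOR and ISD::VP_FFLOOR both map to
// RISCVFPRndMode::RDN (round towards negative infinity), while ISD::FRINT
// maps to DYN, i.e. whatever rounding mode is currently selected in frm.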

// Expand vector FTRUNC, FCEIL, FFLOOR, FROUND, VP_FCEIL, VP_FFLOOR, VP_FROUND,
// VP_FROUNDEVEN, VP_FROUNDTOZERO, VP_FRINT and VP_FNEARBYINT by converting to
// the integer domain and back, taking care to avoid converting values that
// are NaN or already correct.
static SDValue
lowerVectorFTRUNC_FCEIL_FFLOOR_FROUND(SDValue Op, SelectionDAG &DAG,
                                      const RISCVSubtarget &Subtarget) {
  MVT VT = Op.getSimpleValueType();
  assert(VT.isVector() && "Unexpected type");

  SDLoc DL(Op);

  SDValue Src = Op.getOperand(0);

  MVT ContainerVT = VT;
  if (VT.isFixedLengthVector()) {
    ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
    Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
  }

  SDValue Mask, VL;
  if (Op->isVPOpcode()) {
    Mask = Op.getOperand(1);
    if (VT.isFixedLengthVector())
      Mask = convertToScalableVector(getMaskTypeFor(ContainerVT), Mask, DAG,
                                     Subtarget);
    VL = Op.getOperand(2);
  } else {
    std::tie(Mask, VL) = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
  }

  // Freeze the source since we are increasing the number of uses.
  Src = DAG.getFreeze(Src);

  // We do the conversion on the absolute value and fix the sign at the end.
  SDValue Abs = DAG.getNode(RISCVISD::FABS_VL, DL, ContainerVT, Src, Mask, VL);

  // Determine the largest integer that can be represented exactly. This and
  // values larger than it don't have any fractional bits so don't need to
  // be converted.
  const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(ContainerVT);
  unsigned Precision = APFloat::semanticsPrecision(FltSem);
  APFloat MaxVal = APFloat(FltSem);
  MaxVal.convertFromAPInt(APInt::getOneBitSet(Precision, Precision - 1),
                          /*IsSigned*/ false, APFloat::rmNearestTiesToEven);
  SDValue MaxValNode =
      DAG.getConstantFP(MaxVal, DL, ContainerVT.getVectorElementType());
  SDValue MaxValSplat = DAG.getNode(RISCVISD::VFMV_V_F_VL, DL, ContainerVT,
                                    DAG.getUNDEF(ContainerVT), MaxValNode, VL);

  // If abs(Src) was larger than MaxVal or NaN, keep it.
  MVT SetccVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
  Mask =
      DAG.getNode(RISCVISD::SETCC_VL, DL, SetccVT,
                  {Abs, MaxValSplat, DAG.getCondCode(ISD::SETOLT),
                   Mask, Mask, VL});

  // Truncate to integer and convert back to FP.
  MVT IntVT = ContainerVT.changeVectorElementTypeToInteger();
  MVT XLenVT = Subtarget.getXLenVT();
  SDValue Truncated;

  switch (Op.getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode");
  case ISD::FCEIL:
  case ISD::VP_FCEIL:
  case ISD::FFLOOR:
  case ISD::VP_FFLOOR:
  case ISD::FROUND:
  case ISD::FROUNDEVEN:
  case ISD::VP_FROUND:
  case ISD::VP_FROUNDEVEN:
  case ISD::VP_FROUNDTOZERO: {
    RISCVFPRndMode::RoundingMode FRM = matchRoundingOp(Op.getOpcode());
    assert(FRM != RISCVFPRndMode::Invalid);
    Truncated = DAG.getNode(RISCVISD::VFCVT_RM_X_F_VL, DL, IntVT, Src, Mask,
                            DAG.getTargetConstant(FRM, DL, XLenVT), VL);
    break;
  }
  case ISD::FTRUNC:
    Truncated = DAG.getNode(RISCVISD::VFCVT_RTZ_X_F_VL, DL, IntVT, Src,
                            Mask, VL);
    break;
  case ISD::FRINT:
  case ISD::VP_FRINT:
    Truncated = DAG.getNode(RISCVISD::VFCVT_X_F_VL, DL, IntVT, Src, Mask, VL);
    break;
  case ISD::FNEARBYINT:
  case ISD::VP_FNEARBYINT:
    Truncated = DAG.getNode(RISCVISD::VFROUND_NOEXCEPT_VL, DL, ContainerVT, Src,
                            Mask, VL);
    break;
  }

  // VFROUND_NOEXCEPT_VL includes SINT_TO_FP_VL.
  if (Truncated.getOpcode() != RISCVISD::VFROUND_NOEXCEPT_VL)
    Truncated = DAG.getNode(RISCVISD::SINT_TO_FP_VL, DL, ContainerVT, Truncated,
                            Mask, VL);

  // Restore the original sign so that -0.0 is preserved.
  Truncated = DAG.getNode(RISCVISD::FCOPYSIGN_VL, DL, ContainerVT, Truncated,
                          Src, Src, Mask, VL);

  if (!VT.isFixedLengthVector())
    return Truncated;

  return convertFromScalableVector(VT, Truncated, DAG, Subtarget);
}

// Expand vector STRICT_FTRUNC, STRICT_FCEIL, STRICT_FFLOOR, STRICT_FROUND,
// STRICT_FROUNDEVEN and STRICT_FNEARBYINT by converting sNaNs in the source to
// qNaNs and converting the new source to integer and back to FP.
static SDValue
lowerVectorStrictFTRUNC_FCEIL_FFLOOR_FROUND(SDValue Op, SelectionDAG &DAG,
                                            const RISCVSubtarget &Subtarget) {
  SDLoc DL(Op);
  MVT VT = Op.getSimpleValueType();
  SDValue Chain = Op.getOperand(0);
  SDValue Src = Op.getOperand(1);

  MVT ContainerVT = VT;
  if (VT.isFixedLengthVector()) {
    ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
    Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
  }

  auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);

  // Freeze the source since we are increasing the number of uses.
  Src = DAG.getFreeze(Src);

  // Convert sNaN to qNaN by computing x + x for every unordered element x in
  // Src.
  MVT MaskVT = Mask.getSimpleValueType();
  SDValue Unorder = DAG.getNode(RISCVISD::STRICT_FSETCC_VL, DL,
                                DAG.getVTList(MaskVT, MVT::Other),
                                {Chain, Src, Src, DAG.getCondCode(ISD::SETUNE),
                                 DAG.getUNDEF(MaskVT), Mask, VL});
  Chain = Unorder.getValue(1);
  Src = DAG.getNode(RISCVISD::STRICT_FADD_VL, DL,
                    DAG.getVTList(ContainerVT, MVT::Other),
                    {Chain, Src, Src, DAG.getUNDEF(ContainerVT), Unorder, VL});
  Chain = Src.getValue(1);

  // We do the conversion on the absolute value and fix the sign at the end.
  SDValue Abs = DAG.getNode(RISCVISD::FABS_VL, DL, ContainerVT, Src, Mask, VL);

  // Determine the largest integer that can be represented exactly. This and
  // values larger than it don't have any fractional bits so don't need to
  // be converted.
  const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(ContainerVT);
  unsigned Precision = APFloat::semanticsPrecision(FltSem);
  APFloat MaxVal = APFloat(FltSem);
  MaxVal.convertFromAPInt(APInt::getOneBitSet(Precision, Precision - 1),
                          /*IsSigned*/ false, APFloat::rmNearestTiesToEven);
  SDValue MaxValNode =
      DAG.getConstantFP(MaxVal, DL, ContainerVT.getVectorElementType());
  SDValue MaxValSplat = DAG.getNode(RISCVISD::VFMV_V_F_VL, DL, ContainerVT,
                                    DAG.getUNDEF(ContainerVT), MaxValNode, VL);

  // If abs(Src) was larger than MaxVal or NaN, keep it.
  Mask = DAG.getNode(
      RISCVISD::SETCC_VL, DL, MaskVT,
      {Abs, MaxValSplat, DAG.getCondCode(ISD::SETOLT), Mask, Mask, VL});

  // Truncate to integer and convert back to FP.
  MVT IntVT = ContainerVT.changeVectorElementTypeToInteger();
  MVT XLenVT = Subtarget.getXLenVT();
  SDValue Truncated;

  switch (Op.getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode");
  case ISD::STRICT_FCEIL:
  case ISD::STRICT_FFLOOR:
  case ISD::STRICT_FROUND:
  case ISD::STRICT_FROUNDEVEN: {
    RISCVFPRndMode::RoundingMode FRM = matchRoundingOp(Op.getOpcode());
    assert(FRM != RISCVFPRndMode::Invalid);
    Truncated = DAG.getNode(
        RISCVISD::STRICT_VFCVT_RM_X_F_VL, DL, DAG.getVTList(IntVT, MVT::Other),
        {Chain, Src, Mask, DAG.getTargetConstant(FRM, DL, XLenVT), VL});
    break;
  }
  case ISD::STRICT_FTRUNC:
    Truncated =
        DAG.getNode(RISCVISD::STRICT_VFCVT_RTZ_X_F_VL, DL,
                    DAG.getVTList(IntVT, MVT::Other), Chain, Src, Mask, VL);
    break;
  case ISD::STRICT_FNEARBYINT:
    Truncated = DAG.getNode(RISCVISD::STRICT_VFROUND_NOEXCEPT_VL, DL,
                            DAG.getVTList(ContainerVT, MVT::Other), Chain, Src,
                            Mask, VL);
    break;
  }
  Chain = Truncated.getValue(1);

  // VFROUND_NOEXCEPT_VL includes SINT_TO_FP_VL.
  if (Op.getOpcode() != ISD::STRICT_FNEARBYINT) {
    Truncated = DAG.getNode(RISCVISD::STRICT_SINT_TO_FP_VL, DL,
                            DAG.getVTList(ContainerVT, MVT::Other), Chain,
                            Truncated, Mask, VL);
    Chain = Truncated.getValue(1);
  }

  // Restore the original sign so that -0.0 is preserved.
  Truncated = DAG.getNode(RISCVISD::FCOPYSIGN_VL, DL, ContainerVT, Truncated,
                          Src, Src, Mask, VL);

  if (VT.isFixedLengthVector())
    Truncated = convertFromScalableVector(VT, Truncated, DAG, Subtarget);
  return DAG.getMergeValues({Truncated, Chain}, DL);
}

static SDValue
lowerFTRUNC_FCEIL_FFLOOR_FROUND(SDValue Op, SelectionDAG &DAG,
                                const RISCVSubtarget &Subtarget) {
  MVT VT = Op.getSimpleValueType();
  if (VT.isVector())
    return lowerVectorFTRUNC_FCEIL_FFLOOR_FROUND(Op, DAG, Subtarget);

  if (DAG.shouldOptForSize())
    return SDValue();

  SDLoc DL(Op);
  SDValue Src = Op.getOperand(0);

  // Create an integer the size of the mantissa with the MSB set. This and all
  // values larger than it don't have any fractional bits so don't need to be
  // converted.
  const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(VT);
  unsigned Precision = APFloat::semanticsPrecision(FltSem);
  APFloat MaxVal = APFloat(FltSem);
  MaxVal.convertFromAPInt(APInt::getOneBitSet(Precision, Precision - 1),
                          /*IsSigned*/ false, APFloat::rmNearestTiesToEven);
  SDValue MaxValNode = DAG.getConstantFP(MaxVal, DL, VT);

  RISCVFPRndMode::RoundingMode FRM = matchRoundingOp(Op.getOpcode());
  return DAG.getNode(RISCVISD::FROUND, DL, VT, Src, MaxValNode,
                     DAG.getTargetConstant(FRM, DL, Subtarget.getXLenVT()));
}

// Expand vector LRINT and LLRINT by converting to the integer domain.
static SDValue lowerVectorXRINT(SDValue Op, SelectionDAG &DAG,
                                const RISCVSubtarget &Subtarget) {
  MVT VT = Op.getSimpleValueType();
  assert(VT.isVector() && "Unexpected type");

  SDLoc DL(Op);
  SDValue Src = Op.getOperand(0);
  MVT ContainerVT = VT;

  if (VT.isFixedLengthVector()) {
    ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
    Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
  }

  auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
  SDValue Truncated =
      DAG.getNode(RISCVISD::VFCVT_X_F_VL, DL, ContainerVT, Src, Mask, VL);

  if (!VT.isFixedLengthVector())
    return Truncated;

  return convertFromScalableVector(VT, Truncated, DAG, Subtarget);
}

static SDValue
getVSlidedown(SelectionDAG &DAG, const RISCVSubtarget &Subtarget,
              const SDLoc &DL, EVT VT, SDValue Merge, SDValue Op,
              SDValue Offset, SDValue Mask, SDValue VL,
              unsigned Policy = RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED) {
  if (Merge.isUndef())
    Policy = RISCVII::TAIL_AGNOSTIC | RISCVII::MASK_AGNOSTIC;
  SDValue PolicyOp = DAG.getTargetConstant(Policy, DL, Subtarget.getXLenVT());
  SDValue Ops[] = {Merge, Op, Offset, Mask, VL, PolicyOp};
  return DAG.getNode(RISCVISD::VSLIDEDOWN_VL, DL, VT, Ops);
}

static SDValue
getVSlideup(SelectionDAG &DAG, const RISCVSubtarget &Subtarget, const SDLoc &DL,
            EVT VT, SDValue Merge, SDValue Op, SDValue Offset, SDValue Mask,
            SDValue VL,
            unsigned Policy = RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED) {
  if (Merge.isUndef())
    Policy = RISCVII::TAIL_AGNOSTIC | RISCVII::MASK_AGNOSTIC;
  SDValue PolicyOp = DAG.getTargetConstant(Policy, DL, Subtarget.getXLenVT());
  SDValue Ops[] = {Merge, Op, Offset, Mask, VL, PolicyOp};
  return DAG.getNode(RISCVISD::VSLIDEUP_VL, DL, VT, Ops);
}

static MVT getLMUL1VT(MVT VT) {
  assert(VT.getVectorElementType().getSizeInBits() <= 64 &&
         "Unexpected vector MVT");
  return MVT::getScalableVectorVT(
      VT.getVectorElementType(),
      RISCV::RVVBitsPerBlock / VT.getVectorElementType().getSizeInBits());
}
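
// For instance (illustrative, with RVVBitsPerBlock == 64): getLMUL1VT(nxv8i32)
// returns nxv2i32 (64 / 32 == 2 elements per block), and getLMUL1VT(v16i8)
// returns nxv8i8; both are the LMUL=1, single-register-sized type for their
// element width.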

struct VIDSequence {
  int64_t StepNumerator;
  unsigned StepDenominator;
  int64_t Addend;
};

static std::optional<uint64_t> getExactInteger(const APFloat &APF,
                                               uint32_t BitWidth) {
  APSInt ValInt(BitWidth, !APF.isNegative());
  // We use an arbitrary rounding mode here. If a floating-point value is an
  // exact integer (e.g., 1.0), the rounding mode does not affect the output
  // value. If the rounding mode changes the output value, then it is not an
  // exact integer.
  RoundingMode ArbitraryRM = RoundingMode::TowardZero;
  bool IsExact;
  // If it is out of signed integer range, it will return an invalid operation.
  // If it is not an exact integer, IsExact is false.
  if ((APF.convertToInteger(ValInt, ArbitraryRM, &IsExact) ==
       APFloatBase::opInvalidOp) ||
      !IsExact)
    return std::nullopt;
  return ValInt.extractBitsAsZExtValue(BitWidth, 0);
}
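
// A couple of illustrative results (assuming BitWidth == 32): 3.0 yields 3 and
// -2.0 yields 0xFFFFFFFE (its two's-complement bits), while 2.5 and values
// outside the signed 32-bit range yield std::nullopt.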

// Try to match an arithmetic-sequence BUILD_VECTOR [X,X+S,X+2*S,...,X+(N-1)*S]
// to the (non-zero) step S and start value X. This can be then lowered as the
// RVV sequence (VID * S) + X, for example.
// The step S is represented as an integer numerator divided by a positive
// denominator. Note that the implementation currently only identifies
// sequences in which either the numerator is +/- 1 or the denominator is 1. It
// cannot detect 2/3, for example.
// Note that this method will also match potentially unappealing index
// sequences, like <i32 0, i32 50939494>, however it is left to the caller to
// determine whether this is worth generating code for.
static std::optional<VIDSequence> isSimpleVIDSequence(SDValue Op,
                                                      unsigned EltSizeInBits) {
  unsigned NumElts = Op.getNumOperands();
  assert(Op.getOpcode() == ISD::BUILD_VECTOR && "Unexpected BUILD_VECTOR");
  bool IsInteger = Op.getValueType().isInteger();

  std::optional<unsigned> SeqStepDenom;
  std::optional<int64_t> SeqStepNum, SeqAddend;
  std::optional<std::pair<uint64_t, unsigned>> PrevElt;
  assert(EltSizeInBits >= Op.getValueType().getScalarSizeInBits());
  for (unsigned Idx = 0; Idx < NumElts; Idx++) {
    // Assume undef elements match the sequence; we just have to be careful
    // when interpolating across them.
    if (Op.getOperand(Idx).isUndef())
      continue;

    uint64_t Val;
    if (IsInteger) {
      // The BUILD_VECTOR must be all constants.
      if (!isa<ConstantSDNode>(Op.getOperand(Idx)))
        return std::nullopt;
      Val = Op.getConstantOperandVal(Idx) &
            maskTrailingOnes<uint64_t>(Op.getScalarValueSizeInBits());
    } else {
      // The BUILD_VECTOR must be all constants.
      if (!isa<ConstantFPSDNode>(Op.getOperand(Idx)))
        return std::nullopt;
      if (auto ExactInteger = getExactInteger(
              cast<ConstantFPSDNode>(Op.getOperand(Idx))->getValueAPF(),
              Op.getScalarValueSizeInBits()))
        Val = *ExactInteger;
      else
        return std::nullopt;
    }

    if (PrevElt) {
      // Calculate the step since the last non-undef element, and ensure
      // it's consistent across the entire sequence.
      unsigned IdxDiff = Idx - PrevElt->second;
      int64_t ValDiff = SignExtend64(Val - PrevElt->first, EltSizeInBits);

      // A zero value difference means that we're somewhere in the middle
      // of a fractional step, e.g. <0,0,0*,0,1,1,1,1>. Wait until we notice a
      // step change before evaluating the sequence.
      if (ValDiff == 0)
        continue;

      int64_t Remainder = ValDiff % IdxDiff;
      // Normalize the step if it's greater than 1.
      if (Remainder != ValDiff) {
        // The difference must cleanly divide the element span.
        if (Remainder != 0)
          return std::nullopt;
        ValDiff /= IdxDiff;
        IdxDiff = 1;
      }

      if (!SeqStepNum)
        SeqStepNum = ValDiff;
      else if (ValDiff != SeqStepNum)
        return std::nullopt;

      if (!SeqStepDenom)
        SeqStepDenom = IdxDiff;
      else if (IdxDiff != *SeqStepDenom)
        return std::nullopt;
    }

    // Record this non-undef element for later.
    if (!PrevElt || PrevElt->first != Val)
      PrevElt = std::make_pair(Val, Idx);
  }

  // We need to have logged a step for this to count as a legal index sequence.
  if (!SeqStepNum || !SeqStepDenom)
    return std::nullopt;

  // Loop back through the sequence and validate elements we might have skipped
  // while waiting for a valid step. While doing this, log any sequence addend.
  for (unsigned Idx = 0; Idx < NumElts; Idx++) {
    if (Op.getOperand(Idx).isUndef())
      continue;
    uint64_t Val;
    if (IsInteger) {
      Val = Op.getConstantOperandVal(Idx) &
            maskTrailingOnes<uint64_t>(Op.getScalarValueSizeInBits());
    } else {
      Val = *getExactInteger(
          cast<ConstantFPSDNode>(Op.getOperand(Idx))->getValueAPF(),
          Op.getScalarValueSizeInBits());
    }
    uint64_t ExpectedVal =
        (int64_t)(Idx * (uint64_t)*SeqStepNum) / *SeqStepDenom;
    int64_t Addend = SignExtend64(Val - ExpectedVal, EltSizeInBits);
    if (!SeqAddend)
      SeqAddend = Addend;
    else if (Addend != SeqAddend)
      return std::nullopt;
  }

  assert(SeqAddend && "Must have an addend if we have a step");

  return VIDSequence{*SeqStepNum, *SeqStepDenom, *SeqAddend};
}
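
// Illustrative matches (not exhaustive): <i32 5, i32 7, i32 9, i32 11> yields
// {StepNumerator = 2, StepDenominator = 1, Addend = 5}, and
// <i32 1, i32 1, i32 2, i32 2> yields {1, 2, 1}, i.e. (VID * 1) / 2 + 1.
// A sequence such as <i32 0, i32 2, i32 3> has no consistent step and yields
// std::nullopt.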

// Match a splatted value (SPLAT_VECTOR/BUILD_VECTOR) of an EXTRACT_VECTOR_ELT
// and lower it as a VRGATHER_VX_VL from the source vector.
static SDValue matchSplatAsGather(SDValue SplatVal, MVT VT, const SDLoc &DL,
                                  SelectionDAG &DAG,
                                  const RISCVSubtarget &Subtarget) {
  if (SplatVal.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
    return SDValue();
  SDValue Vec = SplatVal.getOperand(0);
  // Only perform this optimization on vectors of the same size for simplicity.
  // Don't perform this optimization for i1 vectors.
  // FIXME: Support i1 vectors, maybe by promoting to i8?
  if (Vec.getValueType() != VT || VT.getVectorElementType() == MVT::i1)
    return SDValue();
  SDValue Idx = SplatVal.getOperand(1);
  // The index must be a legal type.
  if (Idx.getValueType() != Subtarget.getXLenVT())
    return SDValue();

  MVT ContainerVT = VT;
  if (VT.isFixedLengthVector()) {
    ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
    Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
  }

  auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);

  SDValue Gather = DAG.getNode(RISCVISD::VRGATHER_VX_VL, DL, ContainerVT, Vec,
                               Idx, DAG.getUNDEF(ContainerVT), Mask, VL);

  if (!VT.isFixedLengthVector())
    return Gather;

  return convertFromScalableVector(VT, Gather, DAG, Subtarget);
}

/// Try and optimize BUILD_VECTORs with "dominant values" - these are values
/// which constitute a large proportion of the elements. In such cases we can
/// splat a vector with the dominant element and make up the shortfall with
/// INSERT_VECTOR_ELTs.  Returns an empty SDValue if not profitable.
/// Note that this includes vectors of 2 elements by association. The
/// upper-most element is the "dominant" one, allowing us to use a splat to
/// "insert" the upper element, and an insert of the lower element at position
/// 0, which improves codegen.
static SDValue lowerBuildVectorViaDominantValues(SDValue Op, SelectionDAG &DAG,
                                                 const RISCVSubtarget &Subtarget) {
  MVT VT = Op.getSimpleValueType();
  assert(VT.isFixedLengthVector() && "Unexpected vector!");

  MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);

  SDLoc DL(Op);
  auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);

  MVT XLenVT = Subtarget.getXLenVT();
  unsigned NumElts = Op.getNumOperands();

  SDValue DominantValue;
  unsigned MostCommonCount = 0;
  DenseMap<SDValue, unsigned> ValueCounts;
  unsigned NumUndefElts =
      count_if(Op->op_values(), [](const SDValue &V) { return V.isUndef(); });

  // Track the number of scalar loads we know we'd be inserting, estimated as
  // any non-zero floating-point constant. Other kinds of element are either
  // already in registers or are materialized on demand. The threshold at which
  // a vector load is more desirable than several scalar materialization and
  // vector-insertion instructions is not known.
  unsigned NumScalarLoads = 0;

  for (SDValue V : Op->op_values()) {
    if (V.isUndef())
      continue;

    ValueCounts.insert(std::make_pair(V, 0));
    unsigned &Count = ValueCounts[V];
    if (0 == Count)
      if (auto *CFP = dyn_cast<ConstantFPSDNode>(V))
        NumScalarLoads += !CFP->isExactlyValue(+0.0);

    // Is this value dominant? In case of a tie, prefer the highest element as
    // it's cheaper to insert near the beginning of a vector than it is at the
    // end.
    if (++Count >= MostCommonCount) {
      DominantValue = V;
      MostCommonCount = Count;
    }
  }

  assert(DominantValue && "Not expecting an all-undef BUILD_VECTOR");
  unsigned NumDefElts = NumElts - NumUndefElts;
  unsigned DominantValueCountThreshold = NumDefElts <= 2 ? 0 : NumDefElts - 2;

  // Don't perform this optimization when optimizing for size, since
  // materializing elements and inserting them tends to cause code bloat.
  if (!DAG.shouldOptForSize() && NumScalarLoads < NumElts &&
      (NumElts != 2 || ISD::isBuildVectorOfConstantSDNodes(Op.getNode())) &&
      ((MostCommonCount > DominantValueCountThreshold) ||
       (ValueCounts.size() <= Log2_32(NumDefElts)))) {
    // Start by splatting the most common element.
    SDValue Vec = DAG.getSplatBuildVector(VT, DL, DominantValue);

    DenseSet<SDValue> Processed{DominantValue};

    // We can handle an insert into the last element (of a splat) via
    // v(f)slide1down.  This is slightly better than the vslideup insert
    // lowering as it avoids the need for a vector group temporary.  It
    // is also better than using vmerge.vx as it avoids the need to
    // materialize the mask in a vector register.
    if (SDValue LastOp = Op->getOperand(Op->getNumOperands() - 1);
        !LastOp.isUndef() && ValueCounts[LastOp] == 1 &&
        LastOp != DominantValue) {
      Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
      auto OpCode = VT.isFloatingPoint() ? RISCVISD::VFSLIDE1DOWN_VL
                                         : RISCVISD::VSLIDE1DOWN_VL;
      if (!VT.isFloatingPoint())
        LastOp = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, LastOp);
      Vec = DAG.getNode(OpCode, DL, ContainerVT, DAG.getUNDEF(ContainerVT), Vec,
                        LastOp, Mask, VL);
      Vec = convertFromScalableVector(VT, Vec, DAG, Subtarget);
      Processed.insert(LastOp);
    }

    MVT SelMaskTy = VT.changeVectorElementType(MVT::i1);
    for (const auto &OpIdx : enumerate(Op->ops())) {
      const SDValue &V = OpIdx.value();
      if (V.isUndef() || !Processed.insert(V).second)
        continue;
      if (ValueCounts[V] == 1) {
        Vec = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT, Vec, V,
                          DAG.getConstant(OpIdx.index(), DL, XLenVT));
      } else {
        // Blend in all instances of this value using a VSELECT, using a
        // mask where each bit signals whether that element is the one
        // we're after.
        SmallVector<SDValue> Ops;
        transform(Op->op_values(), std::back_inserter(Ops), [&](SDValue V1) {
          return DAG.getConstant(V == V1, DL, XLenVT);
        });
        Vec = DAG.getNode(ISD::VSELECT, DL, VT,
                          DAG.getBuildVector(SelMaskTy, DL, Ops),
                          DAG.getSplatBuildVector(VT, DL, V), Vec);
      }
    }

    return Vec;
  }

  return SDValue();
}
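
// As a concrete illustration (assuming the heuristic above fires): for a
// v4i32 <2, 2, 2, 7>, the value 2 dominates, so we splat 2 and fold the
// trailing 7 in with a single vslide1down.vx; only the shortfall elements
// ever need scalar inserts.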

static SDValue lowerBuildVectorOfConstants(SDValue Op, SelectionDAG &DAG,
                                           const RISCVSubtarget &Subtarget) {
  MVT VT = Op.getSimpleValueType();
  assert(VT.isFixedLengthVector() && "Unexpected vector!");

  MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);

  SDLoc DL(Op);
  auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);

  MVT XLenVT = Subtarget.getXLenVT();
  unsigned NumElts = Op.getNumOperands();

  if (VT.getVectorElementType() == MVT::i1) {
    if (ISD::isBuildVectorAllZeros(Op.getNode())) {
      SDValue VMClr = DAG.getNode(RISCVISD::VMCLR_VL, DL, ContainerVT, VL);
      return convertFromScalableVector(VT, VMClr, DAG, Subtarget);
    }

    if (ISD::isBuildVectorAllOnes(Op.getNode())) {
      SDValue VMSet = DAG.getNode(RISCVISD::VMSET_VL, DL, ContainerVT, VL);
      return convertFromScalableVector(VT, VMSet, DAG, Subtarget);
    }

    // Lower constant mask BUILD_VECTORs via an integer vector type, in
    // scalar integer chunks whose bit-width depends on the number of mask
    // bits and XLEN.
    // First, determine the most appropriate scalar integer type to use. This
    // is at most XLenVT, but may be shrunk to a smaller vector element type
    // according to the size of the final vector - use i8 chunks rather than
    // XLenVT if we're producing a v8i1. This results in more consistent
    // codegen across RV32 and RV64.
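    // For example (illustrative): a v8i1 <1,0,1,1,0,0,0,1> is packed
    // little-endian, element 0 into bit 0, giving the i8 constant 0x8d, which
    // is built as a v1i8 vector and bitcast back to the mask type below.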
3484     unsigned NumViaIntegerBits = std::clamp(NumElts, 8u, Subtarget.getXLen());
3485     NumViaIntegerBits = std::min(NumViaIntegerBits, Subtarget.getELen());
3486     // If we have to use more than one INSERT_VECTOR_ELT then this
3487     // optimization is likely to increase code size; avoid peforming it in
3488     // such a case. We can use a load from a constant pool in this case.
3489     if (DAG.shouldOptForSize() && NumElts > NumViaIntegerBits)
3490       return SDValue();
3491     // Now we can create our integer vector type. Note that it may be larger
3492     // than the resulting mask type: v4i1 would use v1i8 as its integer type.
3493     unsigned IntegerViaVecElts = divideCeil(NumElts, NumViaIntegerBits);
3494     MVT IntegerViaVecVT =
3495       MVT::getVectorVT(MVT::getIntegerVT(NumViaIntegerBits),
3496                        IntegerViaVecElts);
3497 
3498     uint64_t Bits = 0;
3499     unsigned BitPos = 0, IntegerEltIdx = 0;
3500     SmallVector<SDValue, 8> Elts(IntegerViaVecElts);
3501 
3502     for (unsigned I = 0; I < NumElts;) {
3503       SDValue V = Op.getOperand(I);
3504       bool BitValue = !V.isUndef() && V->getAsZExtVal();
3505       Bits |= ((uint64_t)BitValue << BitPos);
3506       ++BitPos;
3507       ++I;
3508 
3509       // Once we accumulate enough bits to fill our scalar type or process the
3510       // last element, insert into our vector and clear our accumulated data.
3511       if (I % NumViaIntegerBits == 0 || I == NumElts) {
3512         if (NumViaIntegerBits <= 32)
3513           Bits = SignExtend64<32>(Bits);
3514         SDValue Elt = DAG.getConstant(Bits, DL, XLenVT);
3515         Elts[IntegerEltIdx] = Elt;
3516         Bits = 0;
3517         BitPos = 0;
3518         IntegerEltIdx++;
3519       }
3520     }
3521 
3522     SDValue Vec = DAG.getBuildVector(IntegerViaVecVT, DL, Elts);
3523 
3524     if (NumElts < NumViaIntegerBits) {
3525       // If we're producing a smaller vector than our minimum legal integer
3526       // type, bitcast to the equivalent (known-legal) mask type, and extract
3527       // our final mask.
3528       assert(IntegerViaVecVT == MVT::v1i8 && "Unexpected mask vector type");
3529       Vec = DAG.getBitcast(MVT::v8i1, Vec);
3530       Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Vec,
3531                         DAG.getConstant(0, DL, XLenVT));
3532     } else {
3533       // Else we must have produced an integer type with the same size as the
3534       // mask type; bitcast for the final result.
3535       assert(VT.getSizeInBits() == IntegerViaVecVT.getSizeInBits());
3536       Vec = DAG.getBitcast(VT, Vec);
3537     }
3538 
3539     return Vec;
3540   }
3541 
3542   if (SDValue Splat = cast<BuildVectorSDNode>(Op)->getSplatValue()) {
3543     unsigned Opc = VT.isFloatingPoint() ? RISCVISD::VFMV_V_F_VL
3544                                         : RISCVISD::VMV_V_X_VL;
3545     if (!VT.isFloatingPoint())
3546       Splat = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Splat);
3547     Splat =
3548         DAG.getNode(Opc, DL, ContainerVT, DAG.getUNDEF(ContainerVT), Splat, VL);
3549     return convertFromScalableVector(VT, Splat, DAG, Subtarget);
3550   }
3551 
3552   // Try and match index sequences, which we can lower to the vid instruction
3553   // with optional modifications. An all-undef vector is matched by
3554   // getSplatValue, above.
3555   if (auto SimpleVID = isSimpleVIDSequence(Op, Op.getScalarValueSizeInBits())) {
3556     int64_t StepNumerator = SimpleVID->StepNumerator;
3557     unsigned StepDenominator = SimpleVID->StepDenominator;
3558     int64_t Addend = SimpleVID->Addend;
3559 
3560     assert(StepNumerator != 0 && "Invalid step");
3561     bool Negate = false;
3562     int64_t SplatStepVal = StepNumerator;
3563     unsigned StepOpcode = ISD::MUL;
3564     // Exclude INT64_MIN to avoid passing it to std::abs. We won't optimize it
3565     // anyway as the shift of 63 won't fit in uimm5.
3566     if (StepNumerator != 1 && StepNumerator != INT64_MIN &&
3567         isPowerOf2_64(std::abs(StepNumerator))) {
3568       Negate = StepNumerator < 0;
3569       StepOpcode = ISD::SHL;
3570       SplatStepVal = Log2_64(std::abs(StepNumerator));
3571     }
3572 
3573     // Only emit VIDs with suitably-small steps/addends. We use imm5 is a
3574     // threshold since it's the immediate value many RVV instructions accept.
3575     // There is no vmul.vi instruction so ensure multiply constant can fit in
3576     // a single addi instruction.
3577     if (((StepOpcode == ISD::MUL && isInt<12>(SplatStepVal)) ||
3578          (StepOpcode == ISD::SHL && isUInt<5>(SplatStepVal))) &&
3579         isPowerOf2_32(StepDenominator) &&
3580         (SplatStepVal >= 0 || StepDenominator == 1) && isInt<5>(Addend)) {
3581       MVT VIDVT =
3582           VT.isFloatingPoint() ? VT.changeVectorElementTypeToInteger() : VT;
3583       MVT VIDContainerVT =
3584           getContainerForFixedLengthVector(DAG, VIDVT, Subtarget);
3585       SDValue VID = DAG.getNode(RISCVISD::VID_VL, DL, VIDContainerVT, Mask, VL);
3586       // Convert right out of the scalable type so we can use standard ISD
3587       // nodes for the rest of the computation. If we used scalable types with
3588       // these, we'd lose the fixed-length vector info and generate worse
3589       // vsetvli code.
3590       VID = convertFromScalableVector(VIDVT, VID, DAG, Subtarget);
3591       if ((StepOpcode == ISD::MUL && SplatStepVal != 1) ||
3592           (StepOpcode == ISD::SHL && SplatStepVal != 0)) {
3593         SDValue SplatStep = DAG.getConstant(SplatStepVal, DL, VIDVT);
3594         VID = DAG.getNode(StepOpcode, DL, VIDVT, VID, SplatStep);
3595       }
3596       if (StepDenominator != 1) {
3597         SDValue SplatStep =
3598             DAG.getConstant(Log2_64(StepDenominator), DL, VIDVT);
3599         VID = DAG.getNode(ISD::SRL, DL, VIDVT, VID, SplatStep);
3600       }
3601       if (Addend != 0 || Negate) {
3602         SDValue SplatAddend = DAG.getConstant(Addend, DL, VIDVT);
3603         VID = DAG.getNode(Negate ? ISD::SUB : ISD::ADD, DL, VIDVT, SplatAddend,
3604                           VID);
3605       }
3606       if (VT.isFloatingPoint()) {
3607         // TODO: Use vfwcvt to reduce register pressure.
3608         VID = DAG.getNode(ISD::SINT_TO_FP, DL, VT, VID);
3609       }
3610       return VID;
3611     }
3612   }
3613 
3614   // For very small build_vectors, use a single scalar insert of a constant.
3615   // TODO: Base this on constant rematerialization cost, not size.
3616   const unsigned EltBitSize = VT.getScalarSizeInBits();
3617   if (VT.getSizeInBits() <= 32 &&
3618       ISD::isBuildVectorOfConstantSDNodes(Op.getNode())) {
3619     MVT ViaIntVT = MVT::getIntegerVT(VT.getSizeInBits());
3620     assert((ViaIntVT == MVT::i16 || ViaIntVT == MVT::i32) &&
3621            "Unexpected sequence type");
3622     // If we can use the original VL with the modified element type, this
3623     // means we only have a VTYPE toggle, not a VL toggle.  TODO: Should this
3624     // be moved into InsertVSETVLI?
3625     unsigned ViaVecLen =
3626       (Subtarget.getRealMinVLen() >= VT.getSizeInBits() * NumElts) ? NumElts : 1;
3627     MVT ViaVecVT = MVT::getVectorVT(ViaIntVT, ViaVecLen);
3628 
3629     uint64_t EltMask = maskTrailingOnes<uint64_t>(EltBitSize);
3630     uint64_t SplatValue = 0;
3631     // Construct the amalgamated value at this larger vector type.
3632     for (const auto &OpIdx : enumerate(Op->op_values())) {
3633       const auto &SeqV = OpIdx.value();
3634       if (!SeqV.isUndef())
3635         SplatValue |=
3636             ((SeqV->getAsZExtVal() & EltMask) << (OpIdx.index() * EltBitSize));
3637     }
3638 
3639     // On RV64, sign-extend from 32 to 64 bits where possible in order to
3640     // achieve better constant materializion.
3641     if (Subtarget.is64Bit() && ViaIntVT == MVT::i32)
3642       SplatValue = SignExtend64<32>(SplatValue);
3643 
3644     SDValue Vec = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ViaVecVT,
3645                               DAG.getUNDEF(ViaVecVT),
3646                               DAG.getConstant(SplatValue, DL, XLenVT),
3647                               DAG.getConstant(0, DL, XLenVT));
3648     if (ViaVecLen != 1)
3649       Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL,
3650                         MVT::getVectorVT(ViaIntVT, 1), Vec,
3651                         DAG.getConstant(0, DL, XLenVT));
3652     return DAG.getBitcast(VT, Vec);
3653   }
3654 
3655 
3656   // Attempt to detect "hidden" splats, which only reveal themselves as splats
3657   // when re-interpreted as a vector with a larger element type. For example,
3658   //   v4i16 = build_vector i16 0, i16 1, i16 0, i16 1
3659   // could be instead splat as
3660   //   v2i32 = build_vector i32 0x00010000, i32 0x00010000
3661   // TODO: This optimization could also work on non-constant splats, but it
3662   // would require bit-manipulation instructions to construct the splat value.
3663   SmallVector<SDValue> Sequence;
3664   const auto *BV = cast<BuildVectorSDNode>(Op);
3665   if (VT.isInteger() && EltBitSize < Subtarget.getELen() &&
3666       ISD::isBuildVectorOfConstantSDNodes(Op.getNode()) &&
3667       BV->getRepeatedSequence(Sequence) &&
3668       (Sequence.size() * EltBitSize) <= Subtarget.getELen()) {
3669     unsigned SeqLen = Sequence.size();
3670     MVT ViaIntVT = MVT::getIntegerVT(EltBitSize * SeqLen);
3671     assert((ViaIntVT == MVT::i16 || ViaIntVT == MVT::i32 ||
3672             ViaIntVT == MVT::i64) &&
3673            "Unexpected sequence type");
3674 
3675     // If we can use the original VL with the modified element type, this
3676     // means we only have a VTYPE toggle, not a VL toggle.  TODO: Should this
3677     // be moved into InsertVSETVLI?
3678     const unsigned RequiredVL = NumElts / SeqLen;
3679     const unsigned ViaVecLen =
3680       (Subtarget.getRealMinVLen() >= ViaIntVT.getSizeInBits() * NumElts) ?
3681       NumElts : RequiredVL;
3682     MVT ViaVecVT = MVT::getVectorVT(ViaIntVT, ViaVecLen);
3683 
3684     unsigned EltIdx = 0;
3685     uint64_t EltMask = maskTrailingOnes<uint64_t>(EltBitSize);
3686     uint64_t SplatValue = 0;
3687     // Construct the amalgamated value which can be splatted as this larger
3688     // vector type.
3689     for (const auto &SeqV : Sequence) {
3690       if (!SeqV.isUndef())
3691         SplatValue |=
3692             ((SeqV->getAsZExtVal() & EltMask) << (EltIdx * EltBitSize));
3693       EltIdx++;
3694     }
3695 
3696     // On RV64, sign-extend from 32 to 64 bits where possible in order to
3697     // achieve better constant materialization.
3698     if (Subtarget.is64Bit() && ViaIntVT == MVT::i32)
3699       SplatValue = SignExtend64<32>(SplatValue);
3700 
3701     // Since we can't introduce illegal i64 types at this stage, we can only
3702     // perform an i64 splat on RV32 if it is its own sign-extended value. That
3703     // way we can use RVV instructions to splat.
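    // For example (illustrative constants): on RV32 with ViaIntVT == i64, a
    // byte sequence of {0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
    // amalgamates to 0xffffffffffffff01, which is its own 32-bit
    // sign-extension and so can be splatted with vmv.v.x; a sequence packing
    // to 0x8000000000000001 is not, and falls through to the paths below.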
3704     assert((ViaIntVT.bitsLE(XLenVT) ||
3705             (!Subtarget.is64Bit() && ViaIntVT == MVT::i64)) &&
3706            "Unexpected bitcast sequence");
3707     if (ViaIntVT.bitsLE(XLenVT) || isInt<32>(SplatValue)) {
3708       SDValue ViaVL =
3709           DAG.getConstant(ViaVecVT.getVectorNumElements(), DL, XLenVT);
3710       MVT ViaContainerVT =
3711           getContainerForFixedLengthVector(DAG, ViaVecVT, Subtarget);
3712       SDValue Splat =
3713           DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ViaContainerVT,
3714                       DAG.getUNDEF(ViaContainerVT),
3715                       DAG.getConstant(SplatValue, DL, XLenVT), ViaVL);
3716       Splat = convertFromScalableVector(ViaVecVT, Splat, DAG, Subtarget);
3717       if (ViaVecLen != RequiredVL)
3718         Splat = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL,
3719                             MVT::getVectorVT(ViaIntVT, RequiredVL), Splat,
3720                             DAG.getConstant(0, DL, XLenVT));
3721       return DAG.getBitcast(VT, Splat);
3722     }
3723   }
3724 
3725   // If the number of sign bits allows, see if we can lower as a <N x i8>.
3726   // Our main goal here is to reduce LMUL (and thus work) required to
3727   // build the constant, but we will also narrow if the resulting
3728   // narrow vector is known to materialize cheaply.
3729   // TODO: We really should be costing the smaller vector.  There are
3730   // profitable cases this misses.
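  // For example (illustrative constants): a v4i16 build_vector of
  // (7, -3, 0, 1) has more than 8 sign bits per element, so it can be built
  // as a v4i8 vector and then sign-extended back to v4i16.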
3731   if (EltBitSize > 8 && VT.isInteger() &&
3732       (NumElts <= 4 || VT.getSizeInBits() > Subtarget.getRealMinVLen())) {
3733     unsigned SignBits = DAG.ComputeNumSignBits(Op);
3734     if (EltBitSize - SignBits < 8) {
3735       SDValue Source = DAG.getBuildVector(VT.changeVectorElementType(MVT::i8),
3736                                           DL, Op->ops());
3737       Source = convertToScalableVector(ContainerVT.changeVectorElementType(MVT::i8),
3738                                        Source, DAG, Subtarget);
3739       SDValue Res = DAG.getNode(RISCVISD::VSEXT_VL, DL, ContainerVT, Source, Mask, VL);
3740       return convertFromScalableVector(VT, Res, DAG, Subtarget);
3741     }
3742   }
3743 
3744   if (SDValue Res = lowerBuildVectorViaDominantValues(Op, DAG, Subtarget))
3745     return Res;
3746 
3747   // For constant vectors, use generic constant pool lowering.  Otherwise,
3748   // we'd have to materialize constants in GPRs just to move them into the
3749   // vector.
3750   return SDValue();
3751 }
3752 
3753 static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
3754                                  const RISCVSubtarget &Subtarget) {
3755   MVT VT = Op.getSimpleValueType();
3756   assert(VT.isFixedLengthVector() && "Unexpected vector!");
3757 
3758   if (ISD::isBuildVectorOfConstantSDNodes(Op.getNode()) ||
3759       ISD::isBuildVectorOfConstantFPSDNodes(Op.getNode()))
3760     return lowerBuildVectorOfConstants(Op, DAG, Subtarget);
3761 
3762   MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
3763 
3764   SDLoc DL(Op);
3765   auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
3766 
3767   MVT XLenVT = Subtarget.getXLenVT();
3768 
3769   if (VT.getVectorElementType() == MVT::i1) {
3770     // A BUILD_VECTOR can be lowered as a SETCC. For each fixed-length mask
3771     // vector type, we have a legal equivalently-sized i8 type, so we can use
3772     // that.
3773     MVT WideVecVT = VT.changeVectorElementType(MVT::i8);
3774     SDValue VecZero = DAG.getConstant(0, DL, WideVecVT);
3775 
3776     SDValue WideVec;
3777     if (SDValue Splat = cast<BuildVectorSDNode>(Op)->getSplatValue()) {
3778       // For a splat, perform a scalar truncate before creating the wider
3779       // vector.
3780       Splat = DAG.getNode(ISD::AND, DL, Splat.getValueType(), Splat,
3781                           DAG.getConstant(1, DL, Splat.getValueType()));
3782       WideVec = DAG.getSplatBuildVector(WideVecVT, DL, Splat);
3783     } else {
3784       SmallVector<SDValue, 8> Ops(Op->op_values());
3785       WideVec = DAG.getBuildVector(WideVecVT, DL, Ops);
3786       SDValue VecOne = DAG.getConstant(1, DL, WideVecVT);
3787       WideVec = DAG.getNode(ISD::AND, DL, WideVecVT, WideVec, VecOne);
3788     }
3789 
3790     return DAG.getSetCC(DL, VT, WideVec, VecZero, ISD::SETNE);
3791   }
3792 
3793   if (SDValue Splat = cast<BuildVectorSDNode>(Op)->getSplatValue()) {
3794     if (auto Gather = matchSplatAsGather(Splat, VT, DL, DAG, Subtarget))
3795       return Gather;
3796     unsigned Opc = VT.isFloatingPoint() ? RISCVISD::VFMV_V_F_VL
3797                                         : RISCVISD::VMV_V_X_VL;
3798     if (!VT.isFloatingPoint())
3799       Splat = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Splat);
3800     Splat =
3801         DAG.getNode(Opc, DL, ContainerVT, DAG.getUNDEF(ContainerVT), Splat, VL);
3802     return convertFromScalableVector(VT, Splat, DAG, Subtarget);
3803   }
3804 
3805   if (SDValue Res = lowerBuildVectorViaDominantValues(Op, DAG, Subtarget))
3806     return Res;
3807 
3808   // If we're compiling for an exact VLEN value, we can split our work per
3809   // register in the register group.
3810   const unsigned MinVLen = Subtarget.getRealMinVLen();
3811   const unsigned MaxVLen = Subtarget.getRealMaxVLen();
3812   if (MinVLen == MaxVLen && VT.getSizeInBits().getKnownMinValue() > MinVLen) {
3813     MVT ElemVT = VT.getVectorElementType();
3814     unsigned ElemsPerVReg = MinVLen / ElemVT.getFixedSizeInBits();
3815     EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
3816     MVT OneRegVT = MVT::getVectorVT(ElemVT, ElemsPerVReg);
3817     MVT M1VT = getContainerForFixedLengthVector(DAG, OneRegVT, Subtarget);
3818     assert(M1VT == getLMUL1VT(M1VT));
3819 
3820     // The following semantically builds up a fixed length concat_vector
3821     // of the component build_vectors.  We eagerly lower to scalable and
3822     // insert_subvector here to avoid DAG combining it back to a large
3823     // build_vector.
3824     SmallVector<SDValue> BuildVectorOps(Op->op_begin(), Op->op_end());
3825     unsigned NumOpElts = M1VT.getVectorMinNumElements();
3826     SDValue Vec = DAG.getUNDEF(ContainerVT);
3827     for (unsigned i = 0; i < VT.getVectorNumElements(); i += ElemsPerVReg) {
3828       auto OneVRegOfOps = ArrayRef(BuildVectorOps).slice(i, ElemsPerVReg);
3829       SDValue SubBV =
3830           DAG.getNode(ISD::BUILD_VECTOR, DL, OneRegVT, OneVRegOfOps);
3831       SubBV = convertToScalableVector(M1VT, SubBV, DAG, Subtarget);
3832       unsigned InsertIdx = (i / ElemsPerVReg) * NumOpElts;
3833       Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ContainerVT, Vec, SubBV,
3834                         DAG.getVectorIdxConstant(InsertIdx, DL));
3835     }
3836     return convertFromScalableVector(VT, Vec, DAG, Subtarget);
3837   }
3838 
3839   // Cap the cost at a value linear to the number of elements in the vector.
3840   // The default lowering is to use the stack.  The vector store + scalar loads
3841   // is linear in VL.  However, at high LMULs vslide1down and vslidedown end up
3842   // being (at least) linear in LMUL.  As a result, using the vslide1down
3843   // lowering for every element ends up being VL*LMUL.
3844   // TODO: Should we be directly costing the stack alternative?  Doing so might
3845   // give us a more accurate upper bound.
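  // For example (illustrative): an 8 element build_vector with no undefs has
  // a LinearBudget of 16, which covers the 8 slides at LMUL<=2 (cost 1 or 2
  // each), but at LMUL=4 the 8 slides cost 32 and we give up and use the
  // default (stack) lowering instead.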
3846   InstructionCost LinearBudget = VT.getVectorNumElements() * 2;
3847 
3848   // TODO: unify with TTI getSlideCost.
3849   InstructionCost PerSlideCost = 1;
3850   switch (RISCVTargetLowering::getLMUL(ContainerVT)) {
3851   default: break;
3852   case RISCVII::VLMUL::LMUL_2:
3853     PerSlideCost = 2;
3854     break;
3855   case RISCVII::VLMUL::LMUL_4:
3856     PerSlideCost = 4;
3857     break;
3858   case RISCVII::VLMUL::LMUL_8:
3859     PerSlideCost = 8;
3860     break;
3861   }
3862 
3863   // TODO: Should we be using the build instseq then cost + evaluate scheme
3864   // we use for integer constants here?
3865   unsigned UndefCount = 0;
3866   for (const SDValue &V : Op->ops()) {
3867     if (V.isUndef()) {
3868       UndefCount++;
3869       continue;
3870     }
3871     if (UndefCount) {
3872       LinearBudget -= PerSlideCost;
3873       UndefCount = 0;
3874     }
3875     LinearBudget -= PerSlideCost;
3876   }
3877   if (UndefCount) {
3878     LinearBudget -= PerSlideCost;
3879   }
3880 
3881   if (LinearBudget < 0)
3882     return SDValue();
3883 
3884   assert((!VT.isFloatingPoint() ||
3885           VT.getVectorElementType().getSizeInBits() <= Subtarget.getFLen()) &&
3886          "Illegal type which will result in reserved encoding");
3887 
3888   const unsigned Policy = RISCVII::TAIL_AGNOSTIC | RISCVII::MASK_AGNOSTIC;
3889 
3890   SDValue Vec;
3891   UndefCount = 0;
3892   for (SDValue V : Op->ops()) {
3893     if (V.isUndef()) {
3894       UndefCount++;
3895       continue;
3896     }
3897 
3898     // Start our sequence with a TA splat in the hopes that hardware is able to
3899     // recognize there's no dependency on the prior value of our temporary
3900     // register.
3901     if (!Vec) {
3902       Vec = DAG.getSplatVector(VT, DL, V);
3903       Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
3904       UndefCount = 0;
3905       continue;
3906     }
3907 
3908     if (UndefCount) {
3909       const SDValue Offset = DAG.getConstant(UndefCount, DL, Subtarget.getXLenVT());
3910       Vec = getVSlidedown(DAG, Subtarget, DL, ContainerVT, DAG.getUNDEF(ContainerVT),
3911                           Vec, Offset, Mask, VL, Policy);
3912       UndefCount = 0;
3913     }
3914     auto OpCode =
3915       VT.isFloatingPoint() ? RISCVISD::VFSLIDE1DOWN_VL : RISCVISD::VSLIDE1DOWN_VL;
3916     if (!VT.isFloatingPoint())
3917       V = DAG.getNode(ISD::ANY_EXTEND, DL, Subtarget.getXLenVT(), V);
3918     Vec = DAG.getNode(OpCode, DL, ContainerVT, DAG.getUNDEF(ContainerVT), Vec,
3919                       V, Mask, VL);
3920   }
3921   if (UndefCount) {
3922     const SDValue Offset = DAG.getConstant(UndefCount, DL, Subtarget.getXLenVT());
3923     Vec = getVSlidedown(DAG, Subtarget, DL, ContainerVT, DAG.getUNDEF(ContainerVT),
3924                         Vec, Offset, Mask, VL, Policy);
3925   }
3926   return convertFromScalableVector(VT, Vec, DAG, Subtarget);
3927 }
3928 
3929 static SDValue splatPartsI64WithVL(const SDLoc &DL, MVT VT, SDValue Passthru,
3930                                    SDValue Lo, SDValue Hi, SDValue VL,
3931                                    SelectionDAG &DAG) {
3932   if (!Passthru)
3933     Passthru = DAG.getUNDEF(VT);
3934   if (isa<ConstantSDNode>(Lo) && isa<ConstantSDNode>(Hi)) {
3935     int32_t LoC = cast<ConstantSDNode>(Lo)->getSExtValue();
3936     int32_t HiC = cast<ConstantSDNode>(Hi)->getSExtValue();
3937     // If the Hi constant is just Lo's sign bit replicated, lower this as a custom
3938     // node in order to try and match RVV vector/scalar instructions.
3939     if ((LoC >> 31) == HiC)
3940       return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, Passthru, Lo, VL);
3941 
3942     // If vl is equal to VLMAX or fits in 4 bits and Hi constant is equal to Lo,
3943     // we could use vmv.v.x whose EEW = 32 to lower it. This allows us to use
3944     // vlmax vsetvli or vsetivli to change the VL.
3945     // FIXME: Support larger constants?
3946     // FIXME: Support non-constant VLs by saturating?
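    // For example (illustrative constant): splatting the i64 value
    // 0x0000000700000007 on RV32 has Lo == Hi == 7, so we can instead splat
    // 7 into an i32 vector with twice the element count and bitcast the
    // result back to the requested i64 vector type.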
3947     if (LoC == HiC) {
3948       SDValue NewVL;
3949       if (isAllOnesConstant(VL) ||
3950           (isa<RegisterSDNode>(VL) &&
3951            cast<RegisterSDNode>(VL)->getReg() == RISCV::X0))
3952         NewVL = DAG.getRegister(RISCV::X0, MVT::i32);
3953       else if (isa<ConstantSDNode>(VL) && isUInt<4>(VL->getAsZExtVal()))
3954         NewVL = DAG.getNode(ISD::ADD, DL, VL.getValueType(), VL, VL);
3955 
3956       if (NewVL) {
3957         MVT InterVT =
3958             MVT::getVectorVT(MVT::i32, VT.getVectorElementCount() * 2);
3959         auto InterVec = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, InterVT,
3960                                     DAG.getUNDEF(InterVT), Lo,
3961                                     DAG.getRegister(RISCV::X0, MVT::i32));
3962         return DAG.getNode(ISD::BITCAST, DL, VT, InterVec);
3963       }
3964     }
3965   }
3966 
3967   // Detect cases where Hi is (SRA Lo, 31) which means Hi is Lo sign extended.
3968   if (Hi.getOpcode() == ISD::SRA && Hi.getOperand(0) == Lo &&
3969       isa<ConstantSDNode>(Hi.getOperand(1)) &&
3970       Hi.getConstantOperandVal(1) == 31)
3971     return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, Passthru, Lo, VL);
3972 
3973   // If the hi bits of the splat are undefined, then it's fine to just splat Lo
3974   // even if it might be sign extended.
3975   if (Hi.isUndef())
3976     return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, Passthru, Lo, VL);
3977 
3978   // Fall back to a stack store and stride x0 vector load.
3979   return DAG.getNode(RISCVISD::SPLAT_VECTOR_SPLIT_I64_VL, DL, VT, Passthru, Lo,
3980                      Hi, VL);
3981 }
3982 
3983 // Called by type legalization to handle splat of i64 on RV32.
3984 // FIXME: We can optimize this when the type has sign or zero bits in one
3985 // of the halves.
3986 static SDValue splatSplitI64WithVL(const SDLoc &DL, MVT VT, SDValue Passthru,
3987                                    SDValue Scalar, SDValue VL,
3988                                    SelectionDAG &DAG) {
3989   assert(Scalar.getValueType() == MVT::i64 && "Unexpected VT!");
3990   SDValue Lo, Hi;
3991   std::tie(Lo, Hi) = DAG.SplitScalar(Scalar, DL, MVT::i32, MVT::i32);
3992   return splatPartsI64WithVL(DL, VT, Passthru, Lo, Hi, VL, DAG);
3993 }
3994 
3995 // This function lowers a splat of a scalar operand Splat with the vector
3996 // length VL. It ensures the final sequence is type legal, which is useful when
3997 // lowering a splat after type legalization.
3998 static SDValue lowerScalarSplat(SDValue Passthru, SDValue Scalar, SDValue VL,
3999                                 MVT VT, const SDLoc &DL, SelectionDAG &DAG,
4000                                 const RISCVSubtarget &Subtarget) {
4001   bool HasPassthru = Passthru && !Passthru.isUndef();
4002   if (!HasPassthru && !Passthru)
4003     Passthru = DAG.getUNDEF(VT);
4004   if (VT.isFloatingPoint())
4005     return DAG.getNode(RISCVISD::VFMV_V_F_VL, DL, VT, Passthru, Scalar, VL);
4006 
4007   MVT XLenVT = Subtarget.getXLenVT();
4008 
4009   // Simplest case is that the operand needs to be promoted to XLenVT.
4010   if (Scalar.getValueType().bitsLE(XLenVT)) {
4011     // If the operand is a constant, sign extend to increase our chances
4012     // of being able to use a .vi instruction. ANY_EXTEND would become a
4013     // zero extend and the simm5 check in isel would fail.
4014     // FIXME: Should we ignore the upper bits in isel instead?
4015     unsigned ExtOpc =
4016         isa<ConstantSDNode>(Scalar) ? ISD::SIGN_EXTEND : ISD::ANY_EXTEND;
4017     Scalar = DAG.getNode(ExtOpc, DL, XLenVT, Scalar);
4018     return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, Passthru, Scalar, VL);
4019   }
4020 
4021   assert(XLenVT == MVT::i32 && Scalar.getValueType() == MVT::i64 &&
4022          "Unexpected scalar for splat lowering!");
4023 
4024   if (isOneConstant(VL) && isNullConstant(Scalar))
4025     return DAG.getNode(RISCVISD::VMV_S_X_VL, DL, VT, Passthru,
4026                        DAG.getConstant(0, DL, XLenVT), VL);
4027 
4028   // Otherwise use the more complicated splatting algorithm.
4029   return splatSplitI64WithVL(DL, VT, Passthru, Scalar, VL, DAG);
4030 }
4031 
4032 // This function lowers an insert of a scalar operand Scalar into lane
4033 // 0 of the vector regardless of the value of VL.  The contents of the
4034 // remaining lanes of the result vector are unspecified.  VL is assumed
4035 // to be non-zero.
4036 static SDValue lowerScalarInsert(SDValue Scalar, SDValue VL, MVT VT,
4037                                  const SDLoc &DL, SelectionDAG &DAG,
4038                                  const RISCVSubtarget &Subtarget) {
4039   assert(VT.isScalableVector() && "Expect VT is scalable vector type.");
4040 
4041   const MVT XLenVT = Subtarget.getXLenVT();
4042   SDValue Passthru = DAG.getUNDEF(VT);
4043 
4044   if (Scalar.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
4045       isNullConstant(Scalar.getOperand(1))) {
4046     SDValue ExtractedVal = Scalar.getOperand(0);
4047     // The element types must be the same.
4048     if (ExtractedVal.getValueType().getVectorElementType() ==
4049         VT.getVectorElementType()) {
4050       MVT ExtractedVT = ExtractedVal.getSimpleValueType();
4051       MVT ExtractedContainerVT = ExtractedVT;
4052       if (ExtractedContainerVT.isFixedLengthVector()) {
4053         ExtractedContainerVT = getContainerForFixedLengthVector(
4054             DAG, ExtractedContainerVT, Subtarget);
4055         ExtractedVal = convertToScalableVector(ExtractedContainerVT,
4056                                                ExtractedVal, DAG, Subtarget);
4057       }
4058       if (ExtractedContainerVT.bitsLE(VT))
4059         return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Passthru,
4060                            ExtractedVal, DAG.getConstant(0, DL, XLenVT));
4061       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, ExtractedVal,
4062                          DAG.getConstant(0, DL, XLenVT));
4063     }
4064   }
4065 
4066 
4067   if (VT.isFloatingPoint())
4068     return DAG.getNode(RISCVISD::VFMV_S_F_VL, DL, VT,
4069                        DAG.getUNDEF(VT), Scalar, VL);
4070 
4071   // Avoid the tricky legalization cases by falling back to using the
4072   // splat code which already handles it gracefully.
4073   if (!Scalar.getValueType().bitsLE(XLenVT))
4074     return lowerScalarSplat(DAG.getUNDEF(VT), Scalar,
4075                             DAG.getConstant(1, DL, XLenVT),
4076                             VT, DL, DAG, Subtarget);
4077 
4078   // If the operand is a constant, sign extend to increase our chances
4079   // of being able to use a .vi instruction. ANY_EXTEND would become a
4080   // zero extend and the simm5 check in isel would fail.
4081   // FIXME: Should we ignore the upper bits in isel instead?
4082   unsigned ExtOpc =
4083     isa<ConstantSDNode>(Scalar) ? ISD::SIGN_EXTEND : ISD::ANY_EXTEND;
4084   Scalar = DAG.getNode(ExtOpc, DL, XLenVT, Scalar);
4085   return DAG.getNode(RISCVISD::VMV_S_X_VL, DL, VT,
4086                      DAG.getUNDEF(VT), Scalar, VL);
4087 }
4088 
4089 // Does this shuffle extract either the even or odd elements of a vector?
4090 // That is, specifically, either (a) or (b) below.
4091 // t34: v8i8 = extract_subvector t11, Constant:i64<0>
4092 // t33: v8i8 = extract_subvector t11, Constant:i64<8>
4093 // a) t35: v8i8 = vector_shuffle<0,2,4,6,8,10,12,14> t34, t33
4094 // b) t35: v8i8 = vector_shuffle<1,3,5,7,9,11,13,15> t34, t33
4095 // Returns {Src Vector, Even Elements} on success.
4096 static bool isDeinterleaveShuffle(MVT VT, MVT ContainerVT, SDValue V1,
4097                                   SDValue V2, ArrayRef<int> Mask,
4098                                   const RISCVSubtarget &Subtarget) {
4099   // Need to be able to widen the vector.
4100   if (VT.getScalarSizeInBits() >= Subtarget.getELen())
4101     return false;
4102 
4103   // Both inputs must be extracts.
4104   if (V1.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
4105       V2.getOpcode() != ISD::EXTRACT_SUBVECTOR)
4106     return false;
4107 
4108   // Extracting from the same source.
4109   SDValue Src = V1.getOperand(0);
4110   if (Src != V2.getOperand(0))
4111     return false;
4112 
4113   // Src needs to have twice the number of elements.
4114   if (Src.getValueType().getVectorNumElements() != (Mask.size() * 2))
4115     return false;
4116 
4117   // The extracts must extract the two halves of the source.
4118   if (V1.getConstantOperandVal(1) != 0 ||
4119       V2.getConstantOperandVal(1) != Mask.size())
4120     return false;
4121 
4122   // First index must be the first even or odd element from V1.
4123   if (Mask[0] != 0 && Mask[0] != 1)
4124     return false;
4125 
4126   // The others must increase by 2 each time.
4127   // TODO: Support undef elements?
4128   for (unsigned i = 1; i != Mask.size(); ++i)
4129     if (Mask[i] != Mask[i - 1] + 2)
4130       return false;
4131 
4132   return true;
4133 }
4134 
4135 /// Is this shuffle interleaving contiguous elements from one vector into the
4136 /// even elements and contiguous elements from another vector into the odd
4137 /// elements. \p EvenSrc will contain the element that should be in the first
4138 /// even element. \p OddSrc will contain the element that should be in the first
4139 /// odd element. These can be the first element in a source or the element half
4140 /// way through the source.
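/// For example (illustrative masks): for a v8i8 result, the mask
/// <0, 8, 1, 9, 2, 10, 3, 11> interleaves the low halves of two sources
/// (EvenSrc = 0, OddSrc = 8), while <0, 4, 1, 5, 2, 6, 3, 7> is a unary
/// interleave of the two halves of the first source (EvenSrc = 0,
/// OddSrc = 4).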
4141 static bool isInterleaveShuffle(ArrayRef<int> Mask, MVT VT, int &EvenSrc,
4142                                 int &OddSrc, const RISCVSubtarget &Subtarget) {
4143   // We need to be able to widen elements to the next larger integer type.
4144   if (VT.getScalarSizeInBits() >= Subtarget.getELen())
4145     return false;
4146 
4147   int Size = Mask.size();
4148   int NumElts = VT.getVectorNumElements();
4149   assert(Size == (int)NumElts && "Unexpected mask size");
4150 
4151   SmallVector<unsigned, 2> StartIndexes;
4152   if (!ShuffleVectorInst::isInterleaveMask(Mask, 2, Size * 2, StartIndexes))
4153     return false;
4154 
4155   EvenSrc = StartIndexes[0];
4156   OddSrc = StartIndexes[1];
4157 
4158   // One source should be the low half of the first vector.
4159   if (EvenSrc != 0 && OddSrc != 0)
4160     return false;
4161 
4162   // Subvectors will be extracted from either the start of the two input
4163   // vectors, or from the start and middle of the first vector if it's a unary
4164   // interleave.
4165   // In both cases, HalfNumElts will be extracted.
4166   // We need to ensure that the extract indices are 0 or HalfNumElts otherwise
4167   // we'll create an illegal extract_subvector.
4168   // FIXME: We could support other values using a slidedown first.
4169   int HalfNumElts = NumElts / 2;
4170   return ((EvenSrc % HalfNumElts) == 0) && ((OddSrc % HalfNumElts) == 0);
4171 }
4172 
4173 /// Match shuffles that concatenate two vectors, rotate the concatenation,
4174 /// and then extract the original number of elements from the rotated result.
4175 /// This is equivalent to vector.splice or X86's PALIGNR instruction. The
4176 /// returned rotation amount is for a rotate right, where elements move from
4177 /// higher elements to lower elements. \p LoSrc indicates the first source
4178 /// vector of the rotate or -1 for undef. \p HiSrc indicates the second vector
4179 /// of the rotate or -1 for undef. At least one of \p LoSrc and \p HiSrc will be
4180 /// 0 or 1 if a rotation is found.
4181 ///
4182 /// NOTE: We talk about rotate to the right which matches how bit shift and
4183 /// rotate instructions are described where LSBs are on the right, but LLVM IR
4184 /// and the table below write vectors with the lowest elements on the left.
4185 static int isElementRotate(int &LoSrc, int &HiSrc, ArrayRef<int> Mask) {
4186   int Size = Mask.size();
4187 
4188   // We need to detect various ways of spelling a rotation:
4189   //   [11, 12, 13, 14, 15,  0,  1,  2]
4190   //   [-1, 12, 13, 14, -1, -1,  1, -1]
4191   //   [-1, -1, -1, -1, -1, -1,  1,  2]
4192   //   [ 3,  4,  5,  6,  7,  8,  9, 10]
4193   //   [-1,  4,  5,  6, -1, -1,  9, -1]
4194   //   [-1,  4,  5,  6, -1, -1, -1, -1]
4195   int Rotation = 0;
4196   LoSrc = -1;
4197   HiSrc = -1;
4198   for (int i = 0; i != Size; ++i) {
4199     int M = Mask[i];
4200     if (M < 0)
4201       continue;
4202 
4203     // Determine where a rotate vector would have started.
4204     int StartIdx = i - (M % Size);
4205     // The identity rotation isn't interesting, stop.
4206     if (StartIdx == 0)
4207       return -1;
4208 
4209     // If we found the tail of a vector the rotation must be the missing
4210     // front. If we found the head of a vector, it must be how much of the
4211     // head.
4212     int CandidateRotation = StartIdx < 0 ? -StartIdx : Size - StartIdx;
4213 
4214     if (Rotation == 0)
4215       Rotation = CandidateRotation;
4216     else if (Rotation != CandidateRotation)
4217       // The rotations don't match, so we can't match this mask.
4218       return -1;
4219 
4220     // Compute which value this mask is pointing at.
4221     int MaskSrc = M < Size ? 0 : 1;
4222 
4223     // Compute which of the two target values this index should be assigned to.
4224     // This reflects whether the high elements are remaining or the low elements
4225     // are remaining.
4226     int &TargetSrc = StartIdx < 0 ? HiSrc : LoSrc;
4227 
4228     // Either set up this value if we've not encountered it before, or check
4229     // that it remains consistent.
4230     if (TargetSrc < 0)
4231       TargetSrc = MaskSrc;
4232     else if (TargetSrc != MaskSrc)
4233       // This may be a rotation, but it pulls from the inputs in some
4234       // unsupported interleaving.
4235       return -1;
4236   }
4237 
4238   // Check that we successfully analyzed the mask, and normalize the results.
4239   assert(Rotation != 0 && "Failed to locate a viable rotation!");
4240   assert((LoSrc >= 0 || HiSrc >= 0) &&
4241          "Failed to find a rotated input vector!");
4242 
4243   return Rotation;
4244 }
4245 
4246 // Lower a deinterleave shuffle to vnsrl.
4247 // [a, p, b, q, c, r, d, s] -> [a, b, c, d] (EvenElts == true)
4248 //                          -> [p, q, r, s] (EvenElts == false)
4249 // VT is the type of the vector to return, <[vscale x ]n x ty>
4250 // Src is the vector to deinterleave of type <[vscale x ]n*2 x ty>
4251 static SDValue getDeinterleaveViaVNSRL(const SDLoc &DL, MVT VT, SDValue Src,
4252                                        bool EvenElts,
4253                                        const RISCVSubtarget &Subtarget,
4254                                        SelectionDAG &DAG) {
4255   // The result is a vector of type <m x n x ty>
4256   MVT ContainerVT = VT;
4257   // Convert fixed vectors to scalable if needed
4258   if (ContainerVT.isFixedLengthVector()) {
4259     assert(Src.getSimpleValueType().isFixedLengthVector());
4260     ContainerVT = getContainerForFixedLengthVector(DAG, ContainerVT, Subtarget);
4261 
4262     // The source is a vector of type <m x n*2 x ty>
4263     MVT SrcContainerVT =
4264         MVT::getVectorVT(ContainerVT.getVectorElementType(),
4265                          ContainerVT.getVectorElementCount() * 2);
4266     Src = convertToScalableVector(SrcContainerVT, Src, DAG, Subtarget);
4267   }
4268 
4269   auto [TrueMask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
4270 
4271   // Bitcast the source vector from <m x n*2 x ty> -> <m x n x ty*2>
4272   // This also converts FP to int.
4273   unsigned EltBits = ContainerVT.getScalarSizeInBits();
4274   MVT WideSrcContainerVT = MVT::getVectorVT(
4275       MVT::getIntegerVT(EltBits * 2), ContainerVT.getVectorElementCount());
4276   Src = DAG.getBitcast(WideSrcContainerVT, Src);
4277 
4278   // The integer version of the container type.
4279   MVT IntContainerVT = ContainerVT.changeVectorElementTypeToInteger();
4280 
4281   // If we want even elements, then the shift amount is 0. Otherwise, shift by
4282   // the original element size.
4283   unsigned Shift = EvenElts ? 0 : EltBits;
4284   SDValue SplatShift = DAG.getNode(
4285       RISCVISD::VMV_V_X_VL, DL, IntContainerVT, DAG.getUNDEF(ContainerVT),
4286       DAG.getConstant(Shift, DL, Subtarget.getXLenVT()), VL);
4287   SDValue Res =
4288       DAG.getNode(RISCVISD::VNSRL_VL, DL, IntContainerVT, Src, SplatShift,
4289                   DAG.getUNDEF(IntContainerVT), TrueMask, VL);
4290   // Cast back to FP if needed.
4291   Res = DAG.getBitcast(ContainerVT, Res);
4292 
4293   if (VT.isFixedLengthVector())
4294     Res = convertFromScalableVector(VT, Res, DAG, Subtarget);
4295   return Res;
4296 }
4297 
4298 // Lower the following shuffle to vslidedown.
4299 // a)
4300 // t49: v8i8 = extract_subvector t13, Constant:i64<0>
4301 // t109: v8i8 = extract_subvector t13, Constant:i64<8>
4302 // t108: v8i8 = vector_shuffle<1,2,3,4,5,6,7,8> t49, t106
4303 // b)
4304 // t69: v16i16 = extract_subvector t68, Constant:i64<0>
4305 // t23: v8i16 = extract_subvector t69, Constant:i64<0>
4306 // t29: v4i16 = extract_subvector t23, Constant:i64<4>
4307 // t26: v8i16 = extract_subvector t69, Constant:i64<8>
4308 // t30: v4i16 = extract_subvector t26, Constant:i64<0>
4309 // t54: v4i16 = vector_shuffle<1,2,3,4> t29, t30
4310 static SDValue lowerVECTOR_SHUFFLEAsVSlidedown(const SDLoc &DL, MVT VT,
4311                                                SDValue V1, SDValue V2,
4312                                                ArrayRef<int> Mask,
4313                                                const RISCVSubtarget &Subtarget,
4314                                                SelectionDAG &DAG) {
4315   auto findNonEXTRACT_SUBVECTORParent =
4316       [](SDValue Parent) -> std::pair<SDValue, uint64_t> {
4317     uint64_t Offset = 0;
4318     while (Parent.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
4319            // EXTRACT_SUBVECTOR can be used to extract a fixed-width vector from
4320            // a scalable vector. But we don't want to match the case.
4321            Parent.getOperand(0).getSimpleValueType().isFixedLengthVector()) {
4322       Offset += Parent.getConstantOperandVal(1);
4323       Parent = Parent.getOperand(0);
4324     }
4325     return std::make_pair(Parent, Offset);
4326   };
4327 
4328   auto [V1Src, V1IndexOffset] = findNonEXTRACT_SUBVECTORParent(V1);
4329   auto [V2Src, V2IndexOffset] = findNonEXTRACT_SUBVECTORParent(V2);
4330 
4331   // Extracting from the same source.
4332   SDValue Src = V1Src;
4333   if (Src != V2Src)
4334     return SDValue();
4335 
4336   // Rebuild mask because Src may be from multiple EXTRACT_SUBVECTORs.
4337   SmallVector<int, 16> NewMask(Mask);
4338   for (size_t i = 0; i != NewMask.size(); ++i) {
4339     if (NewMask[i] == -1)
4340       continue;
4341 
4342     if (static_cast<size_t>(NewMask[i]) < NewMask.size()) {
4343       NewMask[i] = NewMask[i] + V1IndexOffset;
4344     } else {
4345       // Minus NewMask.size() is needed. Otherwise, the b case would be
4346       // <5,6,7,12> instead of <5,6,7,8>.
4347       NewMask[i] = NewMask[i] - NewMask.size() + V2IndexOffset;
4348     }
4349   }
4350 
4351   // First index must be known and non-zero. It will be used as the slidedown
4352   // amount.
4353   if (NewMask[0] <= 0)
4354     return SDValue();
4355 
4356   // The elements of NewMask must also be consecutive.
4357   for (unsigned i = 1; i != NewMask.size(); ++i)
4358     if (NewMask[i - 1] + 1 != NewMask[i])
4359       return SDValue();
4360 
4361   MVT XLenVT = Subtarget.getXLenVT();
4362   MVT SrcVT = Src.getSimpleValueType();
4363   MVT ContainerVT = getContainerForFixedLengthVector(DAG, SrcVT, Subtarget);
4364   auto [TrueMask, VL] = getDefaultVLOps(SrcVT, ContainerVT, DL, DAG, Subtarget);
4365   SDValue Slidedown =
4366       getVSlidedown(DAG, Subtarget, DL, ContainerVT, DAG.getUNDEF(ContainerVT),
4367                     convertToScalableVector(ContainerVT, Src, DAG, Subtarget),
4368                     DAG.getConstant(NewMask[0], DL, XLenVT), TrueMask, VL);
4369   return DAG.getNode(
4370       ISD::EXTRACT_SUBVECTOR, DL, VT,
4371       convertFromScalableVector(SrcVT, Slidedown, DAG, Subtarget),
4372       DAG.getConstant(0, DL, XLenVT));
4373 }
4374 
4375 // Because vslideup leaves the destination elements at the start intact, we can
4376 // use it to perform shuffles that insert subvectors:
4377 //
4378 // vector_shuffle v8:v8i8, v9:v8i8, <0, 1, 2, 3, 8, 9, 10, 11>
4379 // ->
4380 // vsetvli zero, 8, e8, mf2, ta, ma
4381 // vslideup.vi v8, v9, 4
4382 //
4383 // vector_shuffle v8:v8i8, v9:v8i8 <0, 1, 8, 9, 10, 5, 6, 7>
4384 // ->
4385 // vsetvli zero, 5, e8, mf2, tu, ma
4386 // vslideup.vi v8, v9, 2
4387 static SDValue lowerVECTOR_SHUFFLEAsVSlideup(const SDLoc &DL, MVT VT,
4388                                              SDValue V1, SDValue V2,
4389                                              ArrayRef<int> Mask,
4390                                              const RISCVSubtarget &Subtarget,
4391                                              SelectionDAG &DAG) {
4392   unsigned NumElts = VT.getVectorNumElements();
4393   int NumSubElts, Index;
4394   if (!ShuffleVectorInst::isInsertSubvectorMask(Mask, NumElts, NumSubElts,
4395                                                 Index))
4396     return SDValue();
4397 
4398   bool OpsSwapped = Mask[Index] < (int)NumElts;
4399   SDValue InPlace = OpsSwapped ? V2 : V1;
4400   SDValue ToInsert = OpsSwapped ? V1 : V2;
4401 
4402   MVT XLenVT = Subtarget.getXLenVT();
4403   MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
4404   auto TrueMask = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).first;
4405   // We slide up by the index that the subvector is being inserted at, and set
4406   // VL to the index + the number of elements being inserted.
4407   unsigned Policy = RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED | RISCVII::MASK_AGNOSTIC;
4408   // If we're adding a suffix to the in place vector, i.e. inserting right
4409   // up to the very end of it, then we don't actually care about the tail.
4410   if (NumSubElts + Index >= (int)NumElts)
4411     Policy |= RISCVII::TAIL_AGNOSTIC;
4412 
4413   InPlace = convertToScalableVector(ContainerVT, InPlace, DAG, Subtarget);
4414   ToInsert = convertToScalableVector(ContainerVT, ToInsert, DAG, Subtarget);
4415   SDValue VL = DAG.getConstant(NumSubElts + Index, DL, XLenVT);
4416 
4417   SDValue Res;
4418   // If we're inserting into the lowest elements, use a tail undisturbed
4419   // vmv.v.v.
4420   if (Index == 0)
4421     Res = DAG.getNode(RISCVISD::VMV_V_V_VL, DL, ContainerVT, InPlace, ToInsert,
4422                       VL);
4423   else
4424     Res = getVSlideup(DAG, Subtarget, DL, ContainerVT, InPlace, ToInsert,
4425                       DAG.getConstant(Index, DL, XLenVT), TrueMask, VL, Policy);
4426   return convertFromScalableVector(VT, Res, DAG, Subtarget);
4427 }
4428 
4429 /// Match v(f)slide1up/down idioms.  These operations involve sliding
4430 /// N-1 elements to make room for an inserted scalar at one end.
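/// For example (illustrative masks): with a v4i8 shuffle where V1 is a
/// constant splat build_vector, the mask <0, 4, 5, 6> matches vslide1up
/// (slide V2 up and insert the splat scalar at element 0), and <5, 6, 7, 3>
/// matches vslide1down (slide V2 down and insert the scalar at the end).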
4431 static SDValue lowerVECTOR_SHUFFLEAsVSlide1(const SDLoc &DL, MVT VT,
4432                                             SDValue V1, SDValue V2,
4433                                             ArrayRef<int> Mask,
4434                                             const RISCVSubtarget &Subtarget,
4435                                             SelectionDAG &DAG) {
4436   bool OpsSwapped = false;
4437   if (!isa<BuildVectorSDNode>(V1)) {
4438     if (!isa<BuildVectorSDNode>(V2))
4439       return SDValue();
4440     std::swap(V1, V2);
4441     OpsSwapped = true;
4442   }
4443   SDValue Splat = cast<BuildVectorSDNode>(V1)->getSplatValue();
4444   if (!Splat)
4445     return SDValue();
4446 
4447   // Return true if the mask could describe a slide of Mask.size() - 1
4448   // elements from concat_vector(V1, V2)[Base:] to [Offset:].
4449   auto isSlideMask = [](ArrayRef<int> Mask, unsigned Base, int Offset) {
4450     const unsigned S = (Offset > 0) ? 0 : -Offset;
4451     const unsigned E = Mask.size() - ((Offset > 0) ? Offset : 0);
4452     for (unsigned i = S; i != E; ++i)
4453       if (Mask[i] >= 0 && (unsigned)Mask[i] != Base + i + Offset)
4454         return false;
4455     return true;
4456   };
4457 
4458   const unsigned NumElts = VT.getVectorNumElements();
4459   bool IsVSlidedown = isSlideMask(Mask, OpsSwapped ? 0 : NumElts, 1);
4460   if (!IsVSlidedown && !isSlideMask(Mask, OpsSwapped ? 0 : NumElts, -1))
4461     return SDValue();
4462 
4463   const int InsertIdx = Mask[IsVSlidedown ? (NumElts - 1) : 0];
4464   // The inserted lane must come from the splat; an undef scalar is legal but not profitable.
4465   if (InsertIdx < 0 || InsertIdx / NumElts != (unsigned)OpsSwapped)
4466     return SDValue();
4467 
4468   MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
4469   auto [TrueMask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
4470   auto OpCode = IsVSlidedown ?
4471     (VT.isFloatingPoint() ? RISCVISD::VFSLIDE1DOWN_VL : RISCVISD::VSLIDE1DOWN_VL) :
4472     (VT.isFloatingPoint() ? RISCVISD::VFSLIDE1UP_VL : RISCVISD::VSLIDE1UP_VL);
4473   if (!VT.isFloatingPoint())
4474     Splat = DAG.getNode(ISD::ANY_EXTEND, DL, Subtarget.getXLenVT(), Splat);
4475   auto Vec = DAG.getNode(OpCode, DL, ContainerVT,
4476                          DAG.getUNDEF(ContainerVT),
4477                          convertToScalableVector(ContainerVT, V2, DAG, Subtarget),
4478                          Splat, TrueMask, VL);
4479   return convertFromScalableVector(VT, Vec, DAG, Subtarget);
4480 }
4481 
4482 // Given two input vectors of <[vscale x ]n x ty>, use vwaddu.vv and vwmaccu.vx
4483 // to create an interleaved vector of <[vscale x] n*2 x ty>.
4484 // This requires that the size of ty is less than the subtarget's maximum ELEN.
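// For example (illustrative i8 elements): interleaving EvenV = <a, b> with
// OddV = <x, y> produces the i16 elements ((x << 8) | a) and ((y << 8) | b),
// which bitcast back to the i8 result <a, x, b, y>.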
4485 static SDValue getWideningInterleave(SDValue EvenV, SDValue OddV,
4486                                      const SDLoc &DL, SelectionDAG &DAG,
4487                                      const RISCVSubtarget &Subtarget) {
4488   MVT VecVT = EvenV.getSimpleValueType();
4489   MVT VecContainerVT = VecVT; // <vscale x n x ty>
4490   // Convert fixed vectors to scalable if needed
4491   if (VecContainerVT.isFixedLengthVector()) {
4492     VecContainerVT = getContainerForFixedLengthVector(DAG, VecVT, Subtarget);
4493     EvenV = convertToScalableVector(VecContainerVT, EvenV, DAG, Subtarget);
4494     OddV = convertToScalableVector(VecContainerVT, OddV, DAG, Subtarget);
4495   }
4496 
4497   assert(VecVT.getScalarSizeInBits() < Subtarget.getELen());
4498 
4499   // We're working with a vector of the same size as the resulting
4500   // interleaved vector, but with half the number of elements and
4501   // twice the SEW (Hence the restriction on not using the maximum
4502   // ELEN)
4503   MVT WideVT =
4504       MVT::getVectorVT(MVT::getIntegerVT(VecVT.getScalarSizeInBits() * 2),
4505                        VecVT.getVectorElementCount());
4506   MVT WideContainerVT = WideVT; // <vscale x n x ty*2>
4507   if (WideContainerVT.isFixedLengthVector())
4508     WideContainerVT = getContainerForFixedLengthVector(DAG, WideVT, Subtarget);
4509 
4510   // Bitcast the input vectors to integers in case they are FP
4511   VecContainerVT = VecContainerVT.changeTypeToInteger();
4512   EvenV = DAG.getBitcast(VecContainerVT, EvenV);
4513   OddV = DAG.getBitcast(VecContainerVT, OddV);
4514 
4515   auto [Mask, VL] = getDefaultVLOps(VecVT, VecContainerVT, DL, DAG, Subtarget);
4516   SDValue Passthru = DAG.getUNDEF(WideContainerVT);
4517 
4518   SDValue Interleaved;
4519   if (Subtarget.hasStdExtZvbb()) {
4520     // Interleaved = (OddV << VecVT.getScalarSizeInBits()) + EvenV.
4521     SDValue OffsetVec =
4522         DAG.getSplatVector(VecContainerVT, DL,
4523                            DAG.getConstant(VecVT.getScalarSizeInBits(), DL,
4524                                            Subtarget.getXLenVT()));
4525     Interleaved = DAG.getNode(RISCVISD::VWSLL_VL, DL, WideContainerVT, OddV,
4526                               OffsetVec, Passthru, Mask, VL);
4527     Interleaved = DAG.getNode(RISCVISD::VWADDU_W_VL, DL, WideContainerVT,
4528                               Interleaved, EvenV, Passthru, Mask, VL);
4529   } else {
4530     // Widen EvenV and OddV with 0s and add one copy of OddV to EvenV with
4531     // vwaddu.vv
4532     Interleaved = DAG.getNode(RISCVISD::VWADDU_VL, DL, WideContainerVT, EvenV,
4533                               OddV, Passthru, Mask, VL);
4534 
4535     // Then multiply OddV by (2^VecVT.getScalarSizeInBits() - 1), i.e. all ones.
4536     SDValue AllOnesVec = DAG.getSplatVector(
4537         VecContainerVT, DL, DAG.getAllOnesConstant(DL, Subtarget.getXLenVT()));
4538     SDValue OddsMul = DAG.getNode(RISCVISD::VWMULU_VL, DL, WideContainerVT,
4539                                   OddV, AllOnesVec, Passthru, Mask, VL);
4540 
4541     // Add the two together so we get
4542     //   (OddV * 0xff...ff) + (OddV + EvenV)
4543     // = (OddV * 0x100...00) + EvenV
4544     // = (OddV << VecVT.getScalarSizeInBits()) + EvenV
4545     // Note the ADD_VL and VWMULU_VL should get selected as vwmaccu.vx
4546     Interleaved = DAG.getNode(RISCVISD::ADD_VL, DL, WideContainerVT,
4547                               Interleaved, OddsMul, Passthru, Mask, VL);
4548   }
4549 
4550   // Bitcast from <vscale x n x ty*2> to <vscale x 2*n x ty>
4551   MVT ResultContainerVT = MVT::getVectorVT(
4552       VecVT.getVectorElementType(), // Make sure to use original type
4553       VecContainerVT.getVectorElementCount().multiplyCoefficientBy(2));
4554   Interleaved = DAG.getBitcast(ResultContainerVT, Interleaved);
4555 
4556   // Convert back to a fixed vector if needed
4557   MVT ResultVT =
4558       MVT::getVectorVT(VecVT.getVectorElementType(),
4559                        VecVT.getVectorElementCount().multiplyCoefficientBy(2));
4560   if (ResultVT.isFixedLengthVector())
4561     Interleaved =
4562         convertFromScalableVector(ResultVT, Interleaved, DAG, Subtarget);
4563 
4564   return Interleaved;
4565 }
4566 
4567 // If we have a vector of bits that we want to reverse, we can use a vbrev on a
4568 // larger element type, e.g. v32i1 can be reversed with a v1i32 bitreverse.
4569 static SDValue lowerBitreverseShuffle(ShuffleVectorSDNode *SVN,
4570                                       SelectionDAG &DAG,
4571                                       const RISCVSubtarget &Subtarget) {
4572   SDLoc DL(SVN);
4573   MVT VT = SVN->getSimpleValueType(0);
4574   SDValue V = SVN->getOperand(0);
4575   unsigned NumElts = VT.getVectorNumElements();
4576 
4577   assert(VT.getVectorElementType() == MVT::i1);
4578 
4579   if (!ShuffleVectorInst::isReverseMask(SVN->getMask(),
4580                                         SVN->getMask().size()) ||
4581       !SVN->getOperand(1).isUndef())
4582     return SDValue();
4583 
4584   unsigned ViaEltSize = std::max((uint64_t)8, PowerOf2Ceil(NumElts));
4585   EVT ViaVT = EVT::getVectorVT(
4586       *DAG.getContext(), EVT::getIntegerVT(*DAG.getContext(), ViaEltSize), 1);
4587   EVT ViaBitVT =
4588       EVT::getVectorVT(*DAG.getContext(), MVT::i1, ViaVT.getScalarSizeInBits());
4589 
4590   // If we don't have zvbb or the larger element type > ELEN, the operation will
4591   // be illegal.
4592   if (!Subtarget.getTargetLowering()->isOperationLegalOrCustom(ISD::BITREVERSE,
4593                                                                ViaVT) ||
4594       !Subtarget.getTargetLowering()->isTypeLegal(ViaBitVT))
4595     return SDValue();
4596 
4597   // If the bit vector doesn't fit exactly into the larger element type, we need
4598   // to insert it into the larger vector and then shift up the reversed bits
4599   // afterwards to get rid of the gap introduced.
4600   if (ViaEltSize > NumElts)
4601     V = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ViaBitVT, DAG.getUNDEF(ViaBitVT),
4602                     V, DAG.getVectorIdxConstant(0, DL));
4603 
4604   SDValue Res =
4605       DAG.getNode(ISD::BITREVERSE, DL, ViaVT, DAG.getBitcast(ViaVT, V));
4606 
4607   // Shift up the reversed bits if the vector didn't exactly fit into the larger
4608   // element type.
4609   if (ViaEltSize > NumElts)
4610     Res = DAG.getNode(ISD::SRL, DL, ViaVT, Res,
4611                       DAG.getConstant(ViaEltSize - NumElts, DL, ViaVT));
4612 
4613   Res = DAG.getBitcast(ViaBitVT, Res);
4614 
4615   if (ViaEltSize > NumElts)
4616     Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Res,
4617                       DAG.getVectorIdxConstant(0, DL));
4618   return Res;
4619 }
4620 
4621 // Given a shuffle mask like <3, 0, 1, 2, 7, 4, 5, 6> for v8i8, we can
4622 // reinterpret it as a v2i32 and rotate it right by 8 instead. We can lower this
4623 // as a vror.vi if we have Zvkb, or otherwise as a vsll, vsrl and vor.
4624 static SDValue lowerVECTOR_SHUFFLEAsRotate(ShuffleVectorSDNode *SVN,
4625                                            SelectionDAG &DAG,
4626                                            const RISCVSubtarget &Subtarget) {
4627   SDLoc DL(SVN);
4628 
4629   EVT VT = SVN->getValueType(0);
4630   unsigned NumElts = VT.getVectorNumElements();
4631   unsigned EltSizeInBits = VT.getScalarSizeInBits();
4632   unsigned NumSubElts, RotateAmt;
4633   if (!ShuffleVectorInst::isBitRotateMask(SVN->getMask(), EltSizeInBits, 2,
4634                                           NumElts, NumSubElts, RotateAmt))
4635     return SDValue();
4636   MVT RotateVT = MVT::getVectorVT(MVT::getIntegerVT(EltSizeInBits * NumSubElts),
4637                                   NumElts / NumSubElts);
4638 
4639   // We might have a RotateVT that isn't legal, e.g. v4i64 on zve32x.
4640   if (!Subtarget.getTargetLowering()->isTypeLegal(RotateVT))
4641     return SDValue();
4642 
4643   SDValue Op = DAG.getBitcast(RotateVT, SVN->getOperand(0));
4644 
4645   SDValue Rotate;
4646   // A rotate of an i16 by 8 bits either direction is equivalent to a byteswap,
4647   // so canonicalize to vrev8.
4648   if (RotateVT.getScalarType() == MVT::i16 && RotateAmt == 8)
4649     Rotate = DAG.getNode(ISD::BSWAP, DL, RotateVT, Op);
4650   else
4651     Rotate = DAG.getNode(ISD::ROTL, DL, RotateVT, Op,
4652                          DAG.getConstant(RotateAmt, DL, RotateVT));
4653 
4654   return DAG.getBitcast(VT, Rotate);
4655 }
4656 
4657 // If compiling with an exactly known VLEN, see if we can split a
4658 // shuffle on m2 or larger into a small number of m1 sized shuffles
4659 // which write each destination registers exactly once.
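// For example (illustrative, assuming an exactly known VLEN of 128): a
// v16i32 shuffle (LMUL=4) in which each group of four destination elements
// reads from a single source register can be rewritten as four independent
// m1 shuffles plus insert_subvectors.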
4660 static SDValue lowerShuffleViaVRegSplitting(ShuffleVectorSDNode *SVN,
4661                                             SelectionDAG &DAG,
4662                                             const RISCVSubtarget &Subtarget) {
4663   SDLoc DL(SVN);
4664   MVT VT = SVN->getSimpleValueType(0);
4665   SDValue V1 = SVN->getOperand(0);
4666   SDValue V2 = SVN->getOperand(1);
4667   ArrayRef<int> Mask = SVN->getMask();
4668   unsigned NumElts = VT.getVectorNumElements();
4669 
4670   // If we don't know exact data layout, not much we can do.  If this
4671   // is already m1 or smaller, no point in splitting further.
4672   const unsigned MinVLen = Subtarget.getRealMinVLen();
4673   const unsigned MaxVLen = Subtarget.getRealMaxVLen();
4674   if (MinVLen != MaxVLen || VT.getSizeInBits().getFixedValue() <= MinVLen)
4675     return SDValue();
4676 
4677   MVT ElemVT = VT.getVectorElementType();
4678   unsigned ElemsPerVReg = MinVLen / ElemVT.getFixedSizeInBits();
4679   unsigned VRegsPerSrc = NumElts / ElemsPerVReg;
4680 
4681   SmallVector<std::pair<int, SmallVector<int>>>
4682     OutMasks(VRegsPerSrc, {-1, {}});
4683 
4684   // Check if our mask can be done as a 1-to-1 mapping from source
4685   // to destination registers in the group without needing to
4686   // write each destination more than once.
4687   for (unsigned DstIdx = 0; DstIdx < Mask.size(); DstIdx++) {
4688     int DstVecIdx = DstIdx / ElemsPerVReg;
4689     int DstSubIdx = DstIdx % ElemsPerVReg;
4690     int SrcIdx = Mask[DstIdx];
4691     if (SrcIdx < 0 || (unsigned)SrcIdx >= 2 * NumElts)
4692       continue;
4693     int SrcVecIdx = SrcIdx / ElemsPerVReg;
4694     int SrcSubIdx = SrcIdx % ElemsPerVReg;
4695     if (OutMasks[DstVecIdx].first == -1)
4696       OutMasks[DstVecIdx].first = SrcVecIdx;
4697     if (OutMasks[DstVecIdx].first != SrcVecIdx)
4698       // Note: This case could easily be handled by keeping track of a chain
4699       // of source values and generating two element shuffles below.  This is
4700       // less an implementation question, and more a profitability one.
4701       return SDValue();
4702 
4703     OutMasks[DstVecIdx].second.resize(ElemsPerVReg, -1);
4704     OutMasks[DstVecIdx].second[DstSubIdx] = SrcSubIdx;
4705   }
4706 
4707   EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
4708   MVT OneRegVT = MVT::getVectorVT(ElemVT, ElemsPerVReg);
4709   MVT M1VT = getContainerForFixedLengthVector(DAG, OneRegVT, Subtarget);
4710   assert(M1VT == getLMUL1VT(M1VT));
4711   unsigned NumOpElts = M1VT.getVectorMinNumElements();
4712   SDValue Vec = DAG.getUNDEF(ContainerVT);
4713   // The following semantically builds up a fixed length concat_vector
4714   // of the component shuffle_vectors.  We eagerly lower to scalable here
4715   // to avoid DAG combining it back to a large shuffle_vector again.
4716   V1 = convertToScalableVector(ContainerVT, V1, DAG, Subtarget);
4717   V2 = convertToScalableVector(ContainerVT, V2, DAG, Subtarget);
4718   for (unsigned DstVecIdx = 0 ; DstVecIdx < OutMasks.size(); DstVecIdx++) {
4719     auto &[SrcVecIdx, SrcSubMask] = OutMasks[DstVecIdx];
4720     if (SrcVecIdx == -1)
4721       continue;
4722     unsigned ExtractIdx = (SrcVecIdx % VRegsPerSrc) * NumOpElts;
4723     SDValue SrcVec = (unsigned)SrcVecIdx >= VRegsPerSrc ? V2 : V1;
4724     SDValue SubVec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, M1VT, SrcVec,
4725                                  DAG.getVectorIdxConstant(ExtractIdx, DL));
4726     SubVec = convertFromScalableVector(OneRegVT, SubVec, DAG, Subtarget);
4727     SubVec = DAG.getVectorShuffle(OneRegVT, DL, SubVec, SubVec, SrcSubMask);
4728     SubVec = convertToScalableVector(M1VT, SubVec, DAG, Subtarget);
4729     unsigned InsertIdx = DstVecIdx * NumOpElts;
4730     Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ContainerVT, Vec, SubVec,
4731                       DAG.getVectorIdxConstant(InsertIdx, DL));
4732   }
4733   return convertFromScalableVector(VT, Vec, DAG, Subtarget);
4734 }
4735 
4736 static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
4737                                    const RISCVSubtarget &Subtarget) {
4738   SDValue V1 = Op.getOperand(0);
4739   SDValue V2 = Op.getOperand(1);
4740   SDLoc DL(Op);
4741   MVT XLenVT = Subtarget.getXLenVT();
4742   MVT VT = Op.getSimpleValueType();
4743   unsigned NumElts = VT.getVectorNumElements();
4744   ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Op.getNode());
4745 
4746   if (VT.getVectorElementType() == MVT::i1) {
4747     // Lower to a vror.vi of a larger element type if possible before we promote
4748     // i1s to i8s.
4749     if (SDValue V = lowerVECTOR_SHUFFLEAsRotate(SVN, DAG, Subtarget))
4750       return V;
4751     if (SDValue V = lowerBitreverseShuffle(SVN, DAG, Subtarget))
4752       return V;
4753 
4754     // Promote i1 shuffle to i8 shuffle.
4755     MVT WidenVT = MVT::getVectorVT(MVT::i8, VT.getVectorElementCount());
4756     V1 = DAG.getNode(ISD::ZERO_EXTEND, DL, WidenVT, V1);
4757     V2 = V2.isUndef() ? DAG.getUNDEF(WidenVT)
4758                       : DAG.getNode(ISD::ZERO_EXTEND, DL, WidenVT, V2);
4759     SDValue Shuffled = DAG.getVectorShuffle(WidenVT, DL, V1, V2, SVN->getMask());
4760     return DAG.getSetCC(DL, VT, Shuffled, DAG.getConstant(0, DL, WidenVT),
4761                         ISD::SETNE);
4762   }
4763 
4764   MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
4765 
4766   auto [TrueMask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
4767 
4768   if (SVN->isSplat()) {
4769     const int Lane = SVN->getSplatIndex();
4770     if (Lane >= 0) {
4771       MVT SVT = VT.getVectorElementType();
4772 
4773       // Turn splatted vector load into a strided load with an X0 stride.
4774       SDValue V = V1;
4775       // Peek through CONCAT_VECTORS as VectorCombine can concat a vector
4776       // with undef.
4777       // FIXME: Peek through INSERT_SUBVECTOR, EXTRACT_SUBVECTOR, bitcasts?
4778       int Offset = Lane;
4779       if (V.getOpcode() == ISD::CONCAT_VECTORS) {
4780         int OpElements =
4781             V.getOperand(0).getSimpleValueType().getVectorNumElements();
4782         V = V.getOperand(Offset / OpElements);
4783         Offset %= OpElements;
4784       }
4785 
4786       // We need to ensure the load isn't atomic or volatile.
4787       if (ISD::isNormalLoad(V.getNode()) && cast<LoadSDNode>(V)->isSimple()) {
4788         auto *Ld = cast<LoadSDNode>(V);
4789         Offset *= SVT.getStoreSize();
4790         SDValue NewAddr = DAG.getMemBasePlusOffset(
4791             Ld->getBasePtr(), TypeSize::getFixed(Offset), DL);
4792 
4793         // If this is SEW=64 on RV32, use a strided load with a stride of x0.
4794         if (SVT.isInteger() && SVT.bitsGT(XLenVT)) {
4795           SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
4796           SDValue IntID =
4797               DAG.getTargetConstant(Intrinsic::riscv_vlse, DL, XLenVT);
4798           SDValue Ops[] = {Ld->getChain(),
4799                            IntID,
4800                            DAG.getUNDEF(ContainerVT),
4801                            NewAddr,
4802                            DAG.getRegister(RISCV::X0, XLenVT),
4803                            VL};
4804           SDValue NewLoad = DAG.getMemIntrinsicNode(
4805               ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops, SVT,
4806               DAG.getMachineFunction().getMachineMemOperand(
4807                   Ld->getMemOperand(), Offset, SVT.getStoreSize()));
4808           DAG.makeEquivalentMemoryOrdering(Ld, NewLoad);
4809           return convertFromScalableVector(VT, NewLoad, DAG, Subtarget);
4810         }
4811 
4812         // Otherwise use a scalar load and splat. This will give the best
4813         // opportunity to fold a splat into the operation. ISel can turn it into
4814         // the x0 strided load if we aren't able to fold away the select.
4815         if (SVT.isFloatingPoint())
4816           V = DAG.getLoad(SVT, DL, Ld->getChain(), NewAddr,
4817                           Ld->getPointerInfo().getWithOffset(Offset),
4818                           Ld->getOriginalAlign(),
4819                           Ld->getMemOperand()->getFlags());
4820         else
4821           V = DAG.getExtLoad(ISD::SEXTLOAD, DL, XLenVT, Ld->getChain(), NewAddr,
4822                              Ld->getPointerInfo().getWithOffset(Offset), SVT,
4823                              Ld->getOriginalAlign(),
4824                              Ld->getMemOperand()->getFlags());
4825         DAG.makeEquivalentMemoryOrdering(Ld, V);
4826 
4827         unsigned Opc =
4828             VT.isFloatingPoint() ? RISCVISD::VFMV_V_F_VL : RISCVISD::VMV_V_X_VL;
4829         SDValue Splat =
4830             DAG.getNode(Opc, DL, ContainerVT, DAG.getUNDEF(ContainerVT), V, VL);
4831         return convertFromScalableVector(VT, Splat, DAG, Subtarget);
4832       }
4833 
4834       V1 = convertToScalableVector(ContainerVT, V1, DAG, Subtarget);
4835       assert(Lane < (int)NumElts && "Unexpected lane!");
4836       SDValue Gather = DAG.getNode(RISCVISD::VRGATHER_VX_VL, DL, ContainerVT,
4837                                    V1, DAG.getConstant(Lane, DL, XLenVT),
4838                                    DAG.getUNDEF(ContainerVT), TrueMask, VL);
4839       return convertFromScalableVector(VT, Gather, DAG, Subtarget);
4840     }
4841   }
4842 
4843   // For exact VLEN m2 or greater, try to split to m1 operations if we
4844   // can split cleanly.
4845   if (SDValue V = lowerShuffleViaVRegSplitting(SVN, DAG, Subtarget))
4846     return V;
4847 
4848   ArrayRef<int> Mask = SVN->getMask();
4849 
4850   if (SDValue V =
4851           lowerVECTOR_SHUFFLEAsVSlide1(DL, VT, V1, V2, Mask, Subtarget, DAG))
4852     return V;
4853 
4854   if (SDValue V =
4855           lowerVECTOR_SHUFFLEAsVSlidedown(DL, VT, V1, V2, Mask, Subtarget, DAG))
4856     return V;
4857 
4858   // A bitrotate will be one instruction on Zvkb, so try to lower to it first if
4859   // available.
4860   if (Subtarget.hasStdExtZvkb())
4861     if (SDValue V = lowerVECTOR_SHUFFLEAsRotate(SVN, DAG, Subtarget))
4862       return V;
4863 
4864   // Lower rotations to a SLIDEDOWN and a SLIDEUP. One of the source vectors may
4865   // be undef which can be handled with a single SLIDEDOWN/UP.
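  // For example, the single-source mask <2, 3, 4, 5, 6, 7, 0, 1> is handled by
  // sliding the source down by 2 (producing elements 2..7) and sliding it up
  // by 6 (producing elements 0..1).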
4866   int LoSrc, HiSrc;
4867   int Rotation = isElementRotate(LoSrc, HiSrc, Mask);
4868   if (Rotation > 0) {
4869     SDValue LoV, HiV;
4870     if (LoSrc >= 0) {
4871       LoV = LoSrc == 0 ? V1 : V2;
4872       LoV = convertToScalableVector(ContainerVT, LoV, DAG, Subtarget);
4873     }
4874     if (HiSrc >= 0) {
4875       HiV = HiSrc == 0 ? V1 : V2;
4876       HiV = convertToScalableVector(ContainerVT, HiV, DAG, Subtarget);
4877     }
4878 
4879     // We found a rotation. We need to slide HiV down by Rotation. Then we need
4880     // to slide LoV up by (NumElts - Rotation).
4881     unsigned InvRotate = NumElts - Rotation;
4882 
4883     SDValue Res = DAG.getUNDEF(ContainerVT);
4884     if (HiV) {
4885       // Even though we could use a smaller VL, don't, in order to avoid a
4886       // vsetivli toggle.
4887       Res = getVSlidedown(DAG, Subtarget, DL, ContainerVT, Res, HiV,
4888                           DAG.getConstant(Rotation, DL, XLenVT), TrueMask, VL);
4889     }
4890     if (LoV)
4891       Res = getVSlideup(DAG, Subtarget, DL, ContainerVT, Res, LoV,
4892                         DAG.getConstant(InvRotate, DL, XLenVT), TrueMask, VL,
4893                         RISCVII::TAIL_AGNOSTIC);
4894 
4895     return convertFromScalableVector(VT, Res, DAG, Subtarget);
4896   }
4897 
4898   // If this is a deinterleave and we can widen the vector, then we can use
4899   // vnsrl to deinterleave.
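  // For example, the even bytes of a v16i8 source can be extracted by viewing
  // it as v8i16 and using a narrowing shift (vnsrl) by 0; the odd bytes use a
  // shift amount of 8.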
4900   if (isDeinterleaveShuffle(VT, ContainerVT, V1, V2, Mask, Subtarget)) {
4901     return getDeinterleaveViaVNSRL(DL, VT, V1.getOperand(0), Mask[0] == 0,
4902                                    Subtarget, DAG);
4903   }
4904 
4905   if (SDValue V =
4906           lowerVECTOR_SHUFFLEAsVSlideup(DL, VT, V1, V2, Mask, Subtarget, DAG))
4907     return V;
4908 
4909   // Detect an interleave shuffle and lower to
4910   // (vmaccu.vx (vwaddu.vx lohalf(V1), lohalf(V2)), lohalf(V2), (2^eltbits - 1))
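  // For example, with VT = v8i8 the mask <0, 8, 1, 9, 2, 10, 3, 11>
  // interleaves the low halves of the two source vectors.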
4911   int EvenSrc, OddSrc;
4912   if (isInterleaveShuffle(Mask, VT, EvenSrc, OddSrc, Subtarget)) {
4913     // Extract the halves of the vectors.
4914     MVT HalfVT = VT.getHalfNumVectorElementsVT();
4915 
4916     int Size = Mask.size();
4917     SDValue EvenV, OddV;
4918     assert(EvenSrc >= 0 && "Undef source?");
4919     EvenV = (EvenSrc / Size) == 0 ? V1 : V2;
4920     EvenV = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, EvenV,
4921                         DAG.getConstant(EvenSrc % Size, DL, XLenVT));
4922 
4923     assert(OddSrc >= 0 && "Undef source?");
4924     OddV = (OddSrc / Size) == 0 ? V1 : V2;
4925     OddV = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, OddV,
4926                        DAG.getConstant(OddSrc % Size, DL, XLenVT));
4927 
4928     return getWideningInterleave(EvenV, OddV, DL, DAG, Subtarget);
4929   }
4930 
4931   // Detect shuffles which can be re-expressed as vector selects; these are
4932   // shuffles in which each element in the destination is taken from an element
4933   // at the corresponding index in either source vector.
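  // For example, with NumElts = 4 the mask <0, 5, 2, 7> takes elements 0 and 2
  // from the first source and elements 1 and 3 from the second, so it can be
  // lowered as a vselect with mask <1, 0, 1, 0> (modulo the operand swap
  // below).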
4934   bool IsSelect = all_of(enumerate(Mask), [&](const auto &MaskIdx) {
4935     int MaskIndex = MaskIdx.value();
4936     return MaskIndex < 0 || MaskIdx.index() == (unsigned)MaskIndex % NumElts;
4937   });
4938 
4939   assert(!V1.isUndef() && "Unexpected shuffle canonicalization");
4940 
4941   // By default we preserve the original operand order, and use a mask to
4942   // select LHS as true and RHS as false. However, since RVV vector selects may
4943   // feature splats but only on the LHS, we may choose to invert our mask and
4944   // instead select between RHS and LHS.
4945   bool SwapOps = DAG.isSplatValue(V2) && !DAG.isSplatValue(V1);
4946 
4947   if (IsSelect) {
4948     // Now construct the mask that will be used by the vselect operation.
4949     SmallVector<SDValue> MaskVals;
4950     for (int MaskIndex : Mask) {
4951       bool SelectMaskVal = (MaskIndex < (int)NumElts) ^ SwapOps;
4952       MaskVals.push_back(DAG.getConstant(SelectMaskVal, DL, XLenVT));
4953     }
4954 
4955     if (SwapOps)
4956       std::swap(V1, V2);
4957 
4958     assert(MaskVals.size() == NumElts && "Unexpected select-like shuffle");
4959     MVT MaskVT = MVT::getVectorVT(MVT::i1, NumElts);
4960     SDValue SelectMask = DAG.getBuildVector(MaskVT, DL, MaskVals);
4961     return DAG.getNode(ISD::VSELECT, DL, VT, SelectMask, V1, V2);
4962   }
4963 
4964   // We might be able to express the shuffle as a bitrotate. But even if we
4965   // don't have Zvkb and have to expand, the expanded sequence of approx. 2
4966   // shifts and a vor will have a higher throughput than a vrgather.
4967   if (SDValue V = lowerVECTOR_SHUFFLEAsRotate(SVN, DAG, Subtarget))
4968     return V;
4969 
4970   if (VT.getScalarSizeInBits() == 8 && VT.getVectorNumElements() > 256) {
4971     // On such a large vector we're unable to use i8 as the index type.
4972     // FIXME: We could promote the index to i16 and use vrgatherei16, but that
4973     // may involve vector splitting if we're already at LMUL=8, or our
4974     // user-supplied maximum fixed-length LMUL.
4975     return SDValue();
4976   }
4977 
4978   // As a backup, shuffles can be lowered via a vrgather instruction, possibly
4979   // merged with a second vrgather.
4980   SmallVector<SDValue> GatherIndicesLHS, GatherIndicesRHS;
4981 
4982   // Keep track of which non-undef indices are used by each LHS/RHS shuffle
4983   // half.
4984   DenseMap<int, unsigned> LHSIndexCounts, RHSIndexCounts;
4985 
4986   SmallVector<SDValue> MaskVals;
4987 
4988   // Now construct the mask that will be used by the blended vrgather
4989   // operation, and construct the appropriate indices into each vector.
4990   for (int MaskIndex : Mask) {
4991     bool SelectMaskVal = (MaskIndex < (int)NumElts) ^ !SwapOps;
4992     MaskVals.push_back(DAG.getConstant(SelectMaskVal, DL, XLenVT));
4993     bool IsLHSOrUndefIndex = MaskIndex < (int)NumElts;
4994     GatherIndicesLHS.push_back(IsLHSOrUndefIndex && MaskIndex >= 0
4995                                ? DAG.getConstant(MaskIndex, DL, XLenVT)
4996                                : DAG.getUNDEF(XLenVT));
4997     GatherIndicesRHS.push_back(
4998         IsLHSOrUndefIndex ? DAG.getUNDEF(XLenVT)
4999                           : DAG.getConstant(MaskIndex - NumElts, DL, XLenVT));
5000     if (IsLHSOrUndefIndex && MaskIndex >= 0)
5001       ++LHSIndexCounts[MaskIndex];
5002     if (!IsLHSOrUndefIndex)
5003       ++RHSIndexCounts[MaskIndex - NumElts];
5004   }
5005 
5006   if (SwapOps) {
5007     std::swap(V1, V2);
5008     std::swap(GatherIndicesLHS, GatherIndicesRHS);
5009   }
5010 
5011   assert(MaskVals.size() == NumElts && "Unexpected select-like shuffle");
5012   MVT MaskVT = MVT::getVectorVT(MVT::i1, NumElts);
5013   SDValue SelectMask = DAG.getBuildVector(MaskVT, DL, MaskVals);
5014 
5015   unsigned GatherVXOpc = RISCVISD::VRGATHER_VX_VL;
5016   unsigned GatherVVOpc = RISCVISD::VRGATHER_VV_VL;
5017   MVT IndexVT = VT.changeTypeToInteger();
5018   // Since we can't introduce illegal index types at this stage, use i16 and
5019   // vrgatherei16 if the corresponding index type for plain vrgather is greater
5020   // than XLenVT.
5021   if (IndexVT.getScalarType().bitsGT(XLenVT)) {
5022     GatherVVOpc = RISCVISD::VRGATHEREI16_VV_VL;
5023     IndexVT = IndexVT.changeVectorElementType(MVT::i16);
5024   }
5025 
5026   // If the mask allows, we can do all the index computation in 16 bits.  This
5027   // requires less work and less register pressure at high LMUL, and creates
5028   // smaller constants which may be cheaper to materialize.
5029   if (IndexVT.getScalarType().bitsGT(MVT::i16) && isUInt<16>(NumElts - 1) &&
5030       (IndexVT.getSizeInBits() / Subtarget.getRealMinVLen()) > 1) {
5031     GatherVVOpc = RISCVISD::VRGATHEREI16_VV_VL;
5032     IndexVT = IndexVT.changeVectorElementType(MVT::i16);
5033   }
5034 
5035   MVT IndexContainerVT =
5036       ContainerVT.changeVectorElementType(IndexVT.getScalarType());
5037 
5038   SDValue Gather;
5039   // TODO: This doesn't trigger for i64 vectors on RV32, since there we
5040   // encounter a bitcasted BUILD_VECTOR with low/high i32 values.
5041   if (SDValue SplatValue = DAG.getSplatValue(V1, /*LegalTypes*/ true)) {
5042     Gather = lowerScalarSplat(SDValue(), SplatValue, VL, ContainerVT, DL, DAG,
5043                               Subtarget);
5044   } else {
5045     V1 = convertToScalableVector(ContainerVT, V1, DAG, Subtarget);
5046     // If only one index is used, we can use a "splat" vrgather.
5047     // TODO: We can splat the most-common index and fix-up any stragglers, if
5048     // that's beneficial.
5049     if (LHSIndexCounts.size() == 1) {
5050       int SplatIndex = LHSIndexCounts.begin()->getFirst();
5051       Gather = DAG.getNode(GatherVXOpc, DL, ContainerVT, V1,
5052                            DAG.getConstant(SplatIndex, DL, XLenVT),
5053                            DAG.getUNDEF(ContainerVT), TrueMask, VL);
5054     } else {
5055       SDValue LHSIndices = DAG.getBuildVector(IndexVT, DL, GatherIndicesLHS);
5056       LHSIndices =
5057           convertToScalableVector(IndexContainerVT, LHSIndices, DAG, Subtarget);
5058 
5059       Gather = DAG.getNode(GatherVVOpc, DL, ContainerVT, V1, LHSIndices,
5060                            DAG.getUNDEF(ContainerVT), TrueMask, VL);
5061     }
5062   }
5063 
5064   // If a second vector operand is used by this shuffle, blend it in with an
5065   // additional vrgather.
5066   if (!V2.isUndef()) {
5067     V2 = convertToScalableVector(ContainerVT, V2, DAG, Subtarget);
5068 
5069     MVT MaskContainerVT = ContainerVT.changeVectorElementType(MVT::i1);
5070     SelectMask =
5071         convertToScalableVector(MaskContainerVT, SelectMask, DAG, Subtarget);
5072 
5073     // If only one index is used, we can use a "splat" vrgather.
5074     // TODO: We can splat the most-common index and fix-up any stragglers, if
5075     // that's beneficial.
5076     if (RHSIndexCounts.size() == 1) {
5077       int SplatIndex = RHSIndexCounts.begin()->getFirst();
5078       Gather = DAG.getNode(GatherVXOpc, DL, ContainerVT, V2,
5079                            DAG.getConstant(SplatIndex, DL, XLenVT), Gather,
5080                            SelectMask, VL);
5081     } else {
5082       SDValue RHSIndices = DAG.getBuildVector(IndexVT, DL, GatherIndicesRHS);
5083       RHSIndices =
5084           convertToScalableVector(IndexContainerVT, RHSIndices, DAG, Subtarget);
5085       Gather = DAG.getNode(GatherVVOpc, DL, ContainerVT, V2, RHSIndices, Gather,
5086                            SelectMask, VL);
5087     }
5088   }
5089 
5090   return convertFromScalableVector(VT, Gather, DAG, Subtarget);
5091 }
5092 
5093 bool RISCVTargetLowering::isShuffleMaskLegal(ArrayRef<int> M, EVT VT) const {
5094   // Support splats for any type. These should type legalize well.
5095   if (ShuffleVectorSDNode::isSplatMask(M.data(), VT))
5096     return true;
5097 
5098   // Only support legal VTs for other shuffles for now.
5099   if (!isTypeLegal(VT))
5100     return false;
5101 
5102   MVT SVT = VT.getSimpleVT();
5103 
5104   // Not for i1 vectors.
5105   if (SVT.getScalarType() == MVT::i1)
5106     return false;
5107 
5108   int Dummy1, Dummy2;
5109   return (isElementRotate(Dummy1, Dummy2, M) > 0) ||
5110          isInterleaveShuffle(M, SVT, Dummy1, Dummy2, Subtarget);
5111 }
5112 
5113 // Lower CTLZ_ZERO_UNDEF or CTTZ_ZERO_UNDEF by converting to FP and extracting
5114 // the exponent.
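// For example, an i16 element equal to 0x0100 (1 << 8) converts to an f32 with
// a biased exponent of 127 + 8 = 135, giving CTTZ = 135 - 127 = 8 and
// CTLZ = (127 + 15) - 135 = 7.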
5115 SDValue
5116 RISCVTargetLowering::lowerCTLZ_CTTZ_ZERO_UNDEF(SDValue Op,
5117                                                SelectionDAG &DAG) const {
5118   MVT VT = Op.getSimpleValueType();
5119   unsigned EltSize = VT.getScalarSizeInBits();
5120   SDValue Src = Op.getOperand(0);
5121   SDLoc DL(Op);
5122   MVT ContainerVT = VT;
5123 
5124   SDValue Mask, VL;
5125   if (Op->isVPOpcode()) {
5126     Mask = Op.getOperand(1);
5127     if (VT.isFixedLengthVector())
5128       Mask = convertToScalableVector(getMaskTypeFor(ContainerVT), Mask, DAG,
5129                                      Subtarget);
5130     VL = Op.getOperand(2);
5131   }
5132 
5133   // We choose an FP type that can represent the value exactly if possible.
5134   // Otherwise, use round-toward-zero so the result's exponent stays correct.
5135   // TODO: Use f16 for i8 when possible?
5136   MVT FloatEltVT = (EltSize >= 32) ? MVT::f64 : MVT::f32;
5137   if (!isTypeLegal(MVT::getVectorVT(FloatEltVT, VT.getVectorElementCount())))
5138     FloatEltVT = MVT::f32;
5139   MVT FloatVT = MVT::getVectorVT(FloatEltVT, VT.getVectorElementCount());
5140 
5141   // Legal types should have been checked in the RISCVTargetLowering
5142   // constructor.
5143   // TODO: Splitting may make sense in some cases.
5144   assert(DAG.getTargetLoweringInfo().isTypeLegal(FloatVT) &&
5145          "Expected legal float type!");
5146 
5147   // For CTTZ_ZERO_UNDEF, we need to extract the lowest set bit using X & -X.
5148   // The trailing zero count is equal to log2 of this single bit value.
5149   if (Op.getOpcode() == ISD::CTTZ_ZERO_UNDEF) {
5150     SDValue Neg = DAG.getNegative(Src, DL, VT);
5151     Src = DAG.getNode(ISD::AND, DL, VT, Src, Neg);
5152   } else if (Op.getOpcode() == ISD::VP_CTTZ_ZERO_UNDEF) {
5153     SDValue Neg = DAG.getNode(ISD::VP_SUB, DL, VT, DAG.getConstant(0, DL, VT),
5154                               Src, Mask, VL);
5155     Src = DAG.getNode(ISD::VP_AND, DL, VT, Src, Neg, Mask, VL);
5156   }
5157 
5158   // We have a legal FP type, convert to it.
5159   SDValue FloatVal;
5160   if (FloatVT.bitsGT(VT)) {
5161     if (Op->isVPOpcode())
5162       FloatVal = DAG.getNode(ISD::VP_UINT_TO_FP, DL, FloatVT, Src, Mask, VL);
5163     else
5164       FloatVal = DAG.getNode(ISD::UINT_TO_FP, DL, FloatVT, Src);
5165   } else {
5166     // Use RTZ to avoid rounding influencing exponent of FloatVal.
5167     if (VT.isFixedLengthVector()) {
5168       ContainerVT = getContainerForFixedLengthVector(VT);
5169       Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
5170     }
5171     if (!Op->isVPOpcode())
5172       std::tie(Mask, VL) = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
5173     SDValue RTZRM =
5174         DAG.getTargetConstant(RISCVFPRndMode::RTZ, DL, Subtarget.getXLenVT());
5175     MVT ContainerFloatVT =
5176         MVT::getVectorVT(FloatEltVT, ContainerVT.getVectorElementCount());
5177     FloatVal = DAG.getNode(RISCVISD::VFCVT_RM_F_XU_VL, DL, ContainerFloatVT,
5178                            Src, Mask, RTZRM, VL);
5179     if (VT.isFixedLengthVector())
5180       FloatVal = convertFromScalableVector(FloatVT, FloatVal, DAG, Subtarget);
5181   }
5182   // Bitcast to integer and shift the exponent to the LSB.
5183   EVT IntVT = FloatVT.changeVectorElementTypeToInteger();
5184   SDValue Bitcast = DAG.getBitcast(IntVT, FloatVal);
5185   unsigned ShiftAmt = FloatEltVT == MVT::f64 ? 52 : 23;
5186 
5187   SDValue Exp;
5188   // Restore the original type. The truncate after the SRL is selected as vnsrl.
5189   if (Op->isVPOpcode()) {
5190     Exp = DAG.getNode(ISD::VP_LSHR, DL, IntVT, Bitcast,
5191                       DAG.getConstant(ShiftAmt, DL, IntVT), Mask, VL);
5192     Exp = DAG.getVPZExtOrTrunc(DL, VT, Exp, Mask, VL);
5193   } else {
5194     Exp = DAG.getNode(ISD::SRL, DL, IntVT, Bitcast,
5195                       DAG.getConstant(ShiftAmt, DL, IntVT));
5196     if (IntVT.bitsLT(VT))
5197       Exp = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Exp);
5198     else if (IntVT.bitsGT(VT))
5199       Exp = DAG.getNode(ISD::TRUNCATE, DL, VT, Exp);
5200   }
5201 
5202   // The exponent contains log2 of the value in biased form.
5203   unsigned ExponentBias = FloatEltVT == MVT::f64 ? 1023 : 127;
5204   // For trailing zeros, we just need to subtract the bias.
5205   if (Op.getOpcode() == ISD::CTTZ_ZERO_UNDEF)
5206     return DAG.getNode(ISD::SUB, DL, VT, Exp,
5207                        DAG.getConstant(ExponentBias, DL, VT));
5208   if (Op.getOpcode() == ISD::VP_CTTZ_ZERO_UNDEF)
5209     return DAG.getNode(ISD::VP_SUB, DL, VT, Exp,
5210                        DAG.getConstant(ExponentBias, DL, VT), Mask, VL);
5211 
5212   // For leading zeros, we need to remove the bias and convert from log2 to
5213   // leading zeros. We can do this by subtracting from (Bias + (EltSize - 1)).
5214   unsigned Adjust = ExponentBias + (EltSize - 1);
5215   SDValue Res;
5216   if (Op->isVPOpcode())
5217     Res = DAG.getNode(ISD::VP_SUB, DL, VT, DAG.getConstant(Adjust, DL, VT), Exp,
5218                       Mask, VL);
5219   else
5220     Res = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(Adjust, DL, VT), Exp);
5221 
5222   // With a zero input, the result above equals Adjust, which is greater than
5223   // EltSize. Hence, we can take min(Res, EltSize) for CTLZ.
5224   if (Op.getOpcode() == ISD::CTLZ)
5225     Res = DAG.getNode(ISD::UMIN, DL, VT, Res, DAG.getConstant(EltSize, DL, VT));
5226   else if (Op.getOpcode() == ISD::VP_CTLZ)
5227     Res = DAG.getNode(ISD::VP_UMIN, DL, VT, Res,
5228                       DAG.getConstant(EltSize, DL, VT), Mask, VL);
5229   return Res;
5230 }
5231 
5232 // While RVV has alignment restrictions, we should always be able to load as a
5233 // legal equivalently-sized byte-typed vector instead. This method is
5234 // responsible for re-expressing an ISD::LOAD via a correctly-aligned type. If
5235 // the load is already correctly aligned, it returns SDValue().
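// For example, an unaligned v4i32 load is re-expressed as a v16i8 load of the
// same address followed by a bitcast of the result back to v4i32.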
5236 SDValue RISCVTargetLowering::expandUnalignedRVVLoad(SDValue Op,
5237                                                     SelectionDAG &DAG) const {
5238   auto *Load = cast<LoadSDNode>(Op);
5239   assert(Load && Load->getMemoryVT().isVector() && "Expected vector load");
5240 
5241   if (allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
5242                                      Load->getMemoryVT(),
5243                                      *Load->getMemOperand()))
5244     return SDValue();
5245 
5246   SDLoc DL(Op);
5247   MVT VT = Op.getSimpleValueType();
5248   unsigned EltSizeBits = VT.getScalarSizeInBits();
5249   assert((EltSizeBits == 16 || EltSizeBits == 32 || EltSizeBits == 64) &&
5250          "Unexpected unaligned RVV load type");
5251   MVT NewVT =
5252       MVT::getVectorVT(MVT::i8, VT.getVectorElementCount() * (EltSizeBits / 8));
5253   assert(NewVT.isValid() &&
5254          "Expecting equally-sized RVV vector types to be legal");
5255   SDValue L = DAG.getLoad(NewVT, DL, Load->getChain(), Load->getBasePtr(),
5256                           Load->getPointerInfo(), Load->getOriginalAlign(),
5257                           Load->getMemOperand()->getFlags());
5258   return DAG.getMergeValues({DAG.getBitcast(VT, L), L.getValue(1)}, DL);
5259 }
5260 
5261 // While RVV has alignment restrictions, we should always be able to store as a
5262 // legal equivalently-sized byte-typed vector instead. This method is
5263 // responsible for re-expressing an ISD::STORE via a correctly-aligned type. It
5264 // returns SDValue() if the store is already correctly aligned.
5265 SDValue RISCVTargetLowering::expandUnalignedRVVStore(SDValue Op,
5266                                                      SelectionDAG &DAG) const {
5267   auto *Store = cast<StoreSDNode>(Op);
5268   assert(Store && Store->getValue().getValueType().isVector() &&
5269          "Expected vector store");
5270 
5271   if (allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
5272                                      Store->getMemoryVT(),
5273                                      *Store->getMemOperand()))
5274     return SDValue();
5275 
5276   SDLoc DL(Op);
5277   SDValue StoredVal = Store->getValue();
5278   MVT VT = StoredVal.getSimpleValueType();
5279   unsigned EltSizeBits = VT.getScalarSizeInBits();
5280   assert((EltSizeBits == 16 || EltSizeBits == 32 || EltSizeBits == 64) &&
5281          "Unexpected unaligned RVV store type");
5282   MVT NewVT =
5283       MVT::getVectorVT(MVT::i8, VT.getVectorElementCount() * (EltSizeBits / 8));
5284   assert(NewVT.isValid() &&
5285          "Expecting equally-sized RVV vector types to be legal");
5286   StoredVal = DAG.getBitcast(NewVT, StoredVal);
5287   return DAG.getStore(Store->getChain(), DL, StoredVal, Store->getBasePtr(),
5288                       Store->getPointerInfo(), Store->getOriginalAlign(),
5289                       Store->getMemOperand()->getFlags());
5290 }
5291 
5292 static SDValue lowerConstant(SDValue Op, SelectionDAG &DAG,
5293                              const RISCVSubtarget &Subtarget) {
5294   assert(Op.getValueType() == MVT::i64 && "Unexpected VT");
5295 
5296   int64_t Imm = cast<ConstantSDNode>(Op)->getSExtValue();
5297 
5298   // All simm32 constants should be handled by isel.
5299   // NOTE: The getMaxBuildIntsCost call below should return a value >= 2 making
5300   // this check redundant, but small immediates are common so this check
5301   // should have better compile time.
5302   if (isInt<32>(Imm))
5303     return Op;
5304 
5305   // We only need to cost the immediate if constant pool lowering is enabled.
5306   if (!Subtarget.useConstantPoolForLargeInts())
5307     return Op;
5308 
5309   RISCVMatInt::InstSeq Seq = RISCVMatInt::generateInstSeq(Imm, Subtarget);
5310   if (Seq.size() <= Subtarget.getMaxBuildIntsCost())
5311     return Op;
5312 
5313   // Optimizations below are disabled for opt size. If we're optimizing for
5314   // size, use a constant pool.
5315   if (DAG.shouldOptForSize())
5316     return SDValue();
5317 
5318   // Special case. See if we can build the constant as (ADD (SLLI X, C), X);
5319   // do that if it will avoid a constant pool, though it requires an extra
5320   // temporary register.
5321   // If we have Zba we can use (ADD_UW X, (SLLI X, 32)) to handle cases where
5322   // low and high 32 bits are the same and bit 31 and 63 are set.
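  // For example, 0x8765432187654321 can be built by materializing the
  // sign-extended 32-bit value 0x87654321 once and reusing it as both the low
  // and the (shifted) high half.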
5323   unsigned ShiftAmt, AddOpc;
5324   RISCVMatInt::InstSeq SeqLo =
5325       RISCVMatInt::generateTwoRegInstSeq(Imm, Subtarget, ShiftAmt, AddOpc);
5326   if (!SeqLo.empty() && (SeqLo.size() + 2) <= Subtarget.getMaxBuildIntsCost())
5327     return Op;
5328 
5329   return SDValue();
5330 }
5331 
5332 static SDValue LowerATOMIC_FENCE(SDValue Op, SelectionDAG &DAG,
5333                                  const RISCVSubtarget &Subtarget) {
5334   SDLoc dl(Op);
5335   AtomicOrdering FenceOrdering =
5336       static_cast<AtomicOrdering>(Op.getConstantOperandVal(1));
5337   SyncScope::ID FenceSSID =
5338       static_cast<SyncScope::ID>(Op.getConstantOperandVal(2));
5339 
5340   if (Subtarget.hasStdExtZtso()) {
5341     // The only fence that needs an instruction is a sequentially-consistent
5342     // cross-thread fence.
5343     if (FenceOrdering == AtomicOrdering::SequentiallyConsistent &&
5344         FenceSSID == SyncScope::System)
5345       return Op;
5346 
5347     // MEMBARRIER is a compiler barrier; it codegens to a no-op.
5348     return DAG.getNode(ISD::MEMBARRIER, dl, MVT::Other, Op.getOperand(0));
5349   }
5350 
5351   // singlethread fences only synchronize with signal handlers on the same
5352   // thread and thus only need to preserve instruction order, not actually
5353   // enforce memory ordering.
5354   if (FenceSSID == SyncScope::SingleThread)
5355     // MEMBARRIER is a compiler barrier; it codegens to a no-op.
5356     return DAG.getNode(ISD::MEMBARRIER, dl, MVT::Other, Op.getOperand(0));
5357 
5358   return Op;
5359 }
5360 
5361 SDValue RISCVTargetLowering::LowerIS_FPCLASS(SDValue Op,
5362                                              SelectionDAG &DAG) const {
5363   SDLoc DL(Op);
5364   MVT VT = Op.getSimpleValueType();
5365   MVT XLenVT = Subtarget.getXLenVT();
5366   unsigned Check = Op.getConstantOperandVal(1);
5367   unsigned TDCMask = 0;
5368   if (Check & fcSNan)
5369     TDCMask |= RISCV::FPMASK_Signaling_NaN;
5370   if (Check & fcQNan)
5371     TDCMask |= RISCV::FPMASK_Quiet_NaN;
5372   if (Check & fcPosInf)
5373     TDCMask |= RISCV::FPMASK_Positive_Infinity;
5374   if (Check & fcNegInf)
5375     TDCMask |= RISCV::FPMASK_Negative_Infinity;
5376   if (Check & fcPosNormal)
5377     TDCMask |= RISCV::FPMASK_Positive_Normal;
5378   if (Check & fcNegNormal)
5379     TDCMask |= RISCV::FPMASK_Negative_Normal;
5380   if (Check & fcPosSubnormal)
5381     TDCMask |= RISCV::FPMASK_Positive_Subnormal;
5382   if (Check & fcNegSubnormal)
5383     TDCMask |= RISCV::FPMASK_Negative_Subnormal;
5384   if (Check & fcPosZero)
5385     TDCMask |= RISCV::FPMASK_Positive_Zero;
5386   if (Check & fcNegZero)
5387     TDCMask |= RISCV::FPMASK_Negative_Zero;
5388 
5389   bool IsOneBitMask = isPowerOf2_32(TDCMask);
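  // If only a single class bit is being tested, the fclass result can be
  // compared for equality with that bit instead of and'ing and comparing
  // against zero.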
5390 
5391   SDValue TDCMaskV = DAG.getConstant(TDCMask, DL, XLenVT);
5392 
5393   if (VT.isVector()) {
5394     SDValue Op0 = Op.getOperand(0);
5395     MVT VT0 = Op.getOperand(0).getSimpleValueType();
5396 
5397     if (VT.isScalableVector()) {
5398       MVT DstVT = VT0.changeVectorElementTypeToInteger();
5399       auto [Mask, VL] = getDefaultScalableVLOps(VT0, DL, DAG, Subtarget);
5400       if (Op.getOpcode() == ISD::VP_IS_FPCLASS) {
5401         Mask = Op.getOperand(2);
5402         VL = Op.getOperand(3);
5403       }
5404       SDValue FPCLASS = DAG.getNode(RISCVISD::FCLASS_VL, DL, DstVT, Op0, Mask,
5405                                     VL, Op->getFlags());
5406       if (IsOneBitMask)
5407         return DAG.getSetCC(DL, VT, FPCLASS,
5408                             DAG.getConstant(TDCMask, DL, DstVT),
5409                             ISD::CondCode::SETEQ);
5410       SDValue AND = DAG.getNode(ISD::AND, DL, DstVT, FPCLASS,
5411                                 DAG.getConstant(TDCMask, DL, DstVT));
5412       return DAG.getSetCC(DL, VT, AND, DAG.getConstant(0, DL, DstVT),
5413                           ISD::SETNE);
5414     }
5415 
5416     MVT ContainerVT0 = getContainerForFixedLengthVector(VT0);
5417     MVT ContainerVT = getContainerForFixedLengthVector(VT);
5418     MVT ContainerDstVT = ContainerVT0.changeVectorElementTypeToInteger();
5419     auto [Mask, VL] = getDefaultVLOps(VT0, ContainerVT0, DL, DAG, Subtarget);
5420     if (Op.getOpcode() == ISD::VP_IS_FPCLASS) {
5421       Mask = Op.getOperand(2);
5422       MVT MaskContainerVT =
5423           getContainerForFixedLengthVector(Mask.getSimpleValueType());
5424       Mask = convertToScalableVector(MaskContainerVT, Mask, DAG, Subtarget);
5425       VL = Op.getOperand(3);
5426     }
5427     Op0 = convertToScalableVector(ContainerVT0, Op0, DAG, Subtarget);
5428 
5429     SDValue FPCLASS = DAG.getNode(RISCVISD::FCLASS_VL, DL, ContainerDstVT, Op0,
5430                                   Mask, VL, Op->getFlags());
5431 
5432     TDCMaskV = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerDstVT,
5433                            DAG.getUNDEF(ContainerDstVT), TDCMaskV, VL);
5434     if (IsOneBitMask) {
5435       SDValue VMSEQ =
5436           DAG.getNode(RISCVISD::SETCC_VL, DL, ContainerVT,
5437                       {FPCLASS, TDCMaskV, DAG.getCondCode(ISD::SETEQ),
5438                        DAG.getUNDEF(ContainerVT), Mask, VL});
5439       return convertFromScalableVector(VT, VMSEQ, DAG, Subtarget);
5440     }
5441     SDValue AND = DAG.getNode(RISCVISD::AND_VL, DL, ContainerDstVT, FPCLASS,
5442                               TDCMaskV, DAG.getUNDEF(ContainerDstVT), Mask, VL);
5443 
5444     SDValue SplatZero = DAG.getConstant(0, DL, XLenVT);
5445     SplatZero = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerDstVT,
5446                             DAG.getUNDEF(ContainerDstVT), SplatZero, VL);
5447 
5448     SDValue VMSNE = DAG.getNode(RISCVISD::SETCC_VL, DL, ContainerVT,
5449                                 {AND, SplatZero, DAG.getCondCode(ISD::SETNE),
5450                                  DAG.getUNDEF(ContainerVT), Mask, VL});
5451     return convertFromScalableVector(VT, VMSNE, DAG, Subtarget);
5452   }
5453 
5454   SDValue FCLASS = DAG.getNode(RISCVISD::FCLASS, DL, XLenVT, Op.getOperand(0));
5455   SDValue AND = DAG.getNode(ISD::AND, DL, XLenVT, FCLASS, TDCMaskV);
5456   SDValue Res = DAG.getSetCC(DL, XLenVT, AND, DAG.getConstant(0, DL, XLenVT),
5457                              ISD::CondCode::SETNE);
5458   return DAG.getNode(ISD::TRUNCATE, DL, VT, Res);
5459 }
5460 
5461 // Lower fmaximum and fminimum. Unlike our fmax and fmin instructions, these
5462 // operations propagate NaNs.
5463 static SDValue lowerFMAXIMUM_FMINIMUM(SDValue Op, SelectionDAG &DAG,
5464                                       const RISCVSubtarget &Subtarget) {
5465   SDLoc DL(Op);
5466   MVT VT = Op.getSimpleValueType();
5467 
5468   SDValue X = Op.getOperand(0);
5469   SDValue Y = Op.getOperand(1);
5470 
5471   if (!VT.isVector()) {
5472     MVT XLenVT = Subtarget.getXLenVT();
5473 
5474     // If X is a NaN, replace Y with X. If Y is a NaN, replace X with Y. This
5475     // ensures that when one input is a NaN, the other will also be a NaN,
5476     // allowing the NaN to propagate. If both inputs are NaN, this will swap
5477     // the inputs, which is harmless.
5478 
5479     SDValue NewY = Y;
5480     if (!Op->getFlags().hasNoNaNs() && !DAG.isKnownNeverNaN(X)) {
5481       SDValue XIsNonNan = DAG.getSetCC(DL, XLenVT, X, X, ISD::SETOEQ);
5482       NewY = DAG.getSelect(DL, VT, XIsNonNan, Y, X);
5483     }
5484 
5485     SDValue NewX = X;
5486     if (!Op->getFlags().hasNoNaNs() && !DAG.isKnownNeverNaN(Y)) {
5487       SDValue YIsNonNan = DAG.getSetCC(DL, XLenVT, Y, Y, ISD::SETOEQ);
5488       NewX = DAG.getSelect(DL, VT, YIsNonNan, X, Y);
5489     }
5490 
5491     unsigned Opc =
5492         Op.getOpcode() == ISD::FMAXIMUM ? RISCVISD::FMAX : RISCVISD::FMIN;
5493     return DAG.getNode(Opc, DL, VT, NewX, NewY);
5494   }
5495 
5496   // Check for NaNs before converting fixed-length vectors to scalable form.
5497   bool XIsNeverNan = Op->getFlags().hasNoNaNs() || DAG.isKnownNeverNaN(X);
5498   bool YIsNeverNan = Op->getFlags().hasNoNaNs() || DAG.isKnownNeverNaN(Y);
5499 
5500   MVT ContainerVT = VT;
5501   if (VT.isFixedLengthVector()) {
5502     ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
5503     X = convertToScalableVector(ContainerVT, X, DAG, Subtarget);
5504     Y = convertToScalableVector(ContainerVT, Y, DAG, Subtarget);
5505   }
5506 
5507   SDValue Mask, VL;
5508   if (Op->isVPOpcode()) {
5509     Mask = Op.getOperand(2);
5510     if (VT.isFixedLengthVector())
5511       Mask = convertToScalableVector(getMaskTypeFor(ContainerVT), Mask, DAG,
5512                                      Subtarget);
5513     VL = Op.getOperand(3);
5514   } else {
5515     std::tie(Mask, VL) = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
5516   }
5517 
5518   SDValue NewY = Y;
5519   if (!XIsNeverNan) {
5520     SDValue XIsNonNan = DAG.getNode(RISCVISD::SETCC_VL, DL, Mask.getValueType(),
5521                                     {X, X, DAG.getCondCode(ISD::SETOEQ),
5522                                      DAG.getUNDEF(ContainerVT), Mask, VL});
5523     NewY = DAG.getNode(RISCVISD::VMERGE_VL, DL, ContainerVT, XIsNonNan, Y, X,
5524                        DAG.getUNDEF(ContainerVT), VL);
5525   }
5526 
5527   SDValue NewX = X;
5528   if (!YIsNeverNan) {
5529     SDValue YIsNonNan = DAG.getNode(RISCVISD::SETCC_VL, DL, Mask.getValueType(),
5530                                     {Y, Y, DAG.getCondCode(ISD::SETOEQ),
5531                                      DAG.getUNDEF(ContainerVT), Mask, VL});
5532     NewX = DAG.getNode(RISCVISD::VMERGE_VL, DL, ContainerVT, YIsNonNan, X, Y,
5533                        DAG.getUNDEF(ContainerVT), VL);
5534   }
5535 
5536   unsigned Opc =
5537       Op.getOpcode() == ISD::FMAXIMUM || Op->getOpcode() == ISD::VP_FMAXIMUM
5538           ? RISCVISD::VFMAX_VL
5539           : RISCVISD::VFMIN_VL;
5540   SDValue Res = DAG.getNode(Opc, DL, ContainerVT, NewX, NewY,
5541                             DAG.getUNDEF(ContainerVT), Mask, VL);
5542   if (VT.isFixedLengthVector())
5543     Res = convertFromScalableVector(VT, Res, DAG, Subtarget);
5544   return Res;
5545 }
5546 
5547 /// Get a RISC-V target-specific VL op for a given SDNode.
5548 static unsigned getRISCVVLOp(SDValue Op) {
5549 #define OP_CASE(NODE)                                                          \
5550   case ISD::NODE:                                                              \
5551     return RISCVISD::NODE##_VL;
5552 #define VP_CASE(NODE)                                                          \
5553   case ISD::VP_##NODE:                                                         \
5554     return RISCVISD::NODE##_VL;
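  // For example, OP_CASE(ADD) expands to "case ISD::ADD: return
  // RISCVISD::ADD_VL;", and VP_CASE(ADD) handles ISD::VP_ADD the same way.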
5555   // clang-format off
5556   switch (Op.getOpcode()) {
5557   default:
5558     llvm_unreachable("don't have RISC-V specified VL op for this SDNode");
5559   OP_CASE(ADD)
5560   OP_CASE(SUB)
5561   OP_CASE(MUL)
5562   OP_CASE(MULHS)
5563   OP_CASE(MULHU)
5564   OP_CASE(SDIV)
5565   OP_CASE(SREM)
5566   OP_CASE(UDIV)
5567   OP_CASE(UREM)
5568   OP_CASE(SHL)
5569   OP_CASE(SRA)
5570   OP_CASE(SRL)
5571   OP_CASE(ROTL)
5572   OP_CASE(ROTR)
5573   OP_CASE(BSWAP)
5574   OP_CASE(CTTZ)
5575   OP_CASE(CTLZ)
5576   OP_CASE(CTPOP)
5577   OP_CASE(BITREVERSE)
5578   OP_CASE(SADDSAT)
5579   OP_CASE(UADDSAT)
5580   OP_CASE(SSUBSAT)
5581   OP_CASE(USUBSAT)
5582   OP_CASE(AVGFLOORU)
5583   OP_CASE(AVGCEILU)
5584   OP_CASE(FADD)
5585   OP_CASE(FSUB)
5586   OP_CASE(FMUL)
5587   OP_CASE(FDIV)
5588   OP_CASE(FNEG)
5589   OP_CASE(FABS)
5590   OP_CASE(FSQRT)
5591   OP_CASE(SMIN)
5592   OP_CASE(SMAX)
5593   OP_CASE(UMIN)
5594   OP_CASE(UMAX)
5595   OP_CASE(STRICT_FADD)
5596   OP_CASE(STRICT_FSUB)
5597   OP_CASE(STRICT_FMUL)
5598   OP_CASE(STRICT_FDIV)
5599   OP_CASE(STRICT_FSQRT)
5600   VP_CASE(ADD)        // VP_ADD
5601   VP_CASE(SUB)        // VP_SUB
5602   VP_CASE(MUL)        // VP_MUL
5603   VP_CASE(SDIV)       // VP_SDIV
5604   VP_CASE(SREM)       // VP_SREM
5605   VP_CASE(UDIV)       // VP_UDIV
5606   VP_CASE(UREM)       // VP_UREM
5607   VP_CASE(SHL)        // VP_SHL
5608   VP_CASE(FADD)       // VP_FADD
5609   VP_CASE(FSUB)       // VP_FSUB
5610   VP_CASE(FMUL)       // VP_FMUL
5611   VP_CASE(FDIV)       // VP_FDIV
5612   VP_CASE(FNEG)       // VP_FNEG
5613   VP_CASE(FABS)       // VP_FABS
5614   VP_CASE(SMIN)       // VP_SMIN
5615   VP_CASE(SMAX)       // VP_SMAX
5616   VP_CASE(UMIN)       // VP_UMIN
5617   VP_CASE(UMAX)       // VP_UMAX
5618   VP_CASE(FCOPYSIGN)  // VP_FCOPYSIGN
5619   VP_CASE(SETCC)      // VP_SETCC
5620   VP_CASE(SINT_TO_FP) // VP_SINT_TO_FP
5621   VP_CASE(UINT_TO_FP) // VP_UINT_TO_FP
5622   VP_CASE(BITREVERSE) // VP_BITREVERSE
5623   VP_CASE(BSWAP)      // VP_BSWAP
5624   VP_CASE(CTLZ)       // VP_CTLZ
5625   VP_CASE(CTTZ)       // VP_CTTZ
5626   VP_CASE(CTPOP)      // VP_CTPOP
5627   case ISD::CTLZ_ZERO_UNDEF:
5628   case ISD::VP_CTLZ_ZERO_UNDEF:
5629     return RISCVISD::CTLZ_VL;
5630   case ISD::CTTZ_ZERO_UNDEF:
5631   case ISD::VP_CTTZ_ZERO_UNDEF:
5632     return RISCVISD::CTTZ_VL;
5633   case ISD::FMA:
5634   case ISD::VP_FMA:
5635     return RISCVISD::VFMADD_VL;
5636   case ISD::STRICT_FMA:
5637     return RISCVISD::STRICT_VFMADD_VL;
5638   case ISD::AND:
5639   case ISD::VP_AND:
5640     if (Op.getSimpleValueType().getVectorElementType() == MVT::i1)
5641       return RISCVISD::VMAND_VL;
5642     return RISCVISD::AND_VL;
5643   case ISD::OR:
5644   case ISD::VP_OR:
5645     if (Op.getSimpleValueType().getVectorElementType() == MVT::i1)
5646       return RISCVISD::VMOR_VL;
5647     return RISCVISD::OR_VL;
5648   case ISD::XOR:
5649   case ISD::VP_XOR:
5650     if (Op.getSimpleValueType().getVectorElementType() == MVT::i1)
5651       return RISCVISD::VMXOR_VL;
5652     return RISCVISD::XOR_VL;
5653   case ISD::VP_SELECT:
5654   case ISD::VP_MERGE:
5655     return RISCVISD::VMERGE_VL;
5656   case ISD::VP_ASHR:
5657     return RISCVISD::SRA_VL;
5658   case ISD::VP_LSHR:
5659     return RISCVISD::SRL_VL;
5660   case ISD::VP_SQRT:
5661     return RISCVISD::FSQRT_VL;
5662   case ISD::VP_SIGN_EXTEND:
5663     return RISCVISD::VSEXT_VL;
5664   case ISD::VP_ZERO_EXTEND:
5665     return RISCVISD::VZEXT_VL;
5666   case ISD::VP_FP_TO_SINT:
5667     return RISCVISD::VFCVT_RTZ_X_F_VL;
5668   case ISD::VP_FP_TO_UINT:
5669     return RISCVISD::VFCVT_RTZ_XU_F_VL;
5670   case ISD::FMINNUM:
5671   case ISD::VP_FMINNUM:
5672     return RISCVISD::VFMIN_VL;
5673   case ISD::FMAXNUM:
5674   case ISD::VP_FMAXNUM:
5675     return RISCVISD::VFMAX_VL;
5676   }
5677   // clang-format on
5678 #undef OP_CASE
5679 #undef VP_CASE
5680 }
5681 
5682 /// Return true if a RISC-V target-specific op has a merge operand.
5683 static bool hasMergeOp(unsigned Opcode) {
5684   assert(Opcode > RISCVISD::FIRST_NUMBER &&
5685          Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
5686          "not a RISC-V target specific op");
5687   static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
5688                     126 &&
5689                 RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
5690                         ISD::FIRST_TARGET_STRICTFP_OPCODE ==
5691                     21 &&
5692                 "adding target specific op should update this function");
5693   if (Opcode >= RISCVISD::ADD_VL && Opcode <= RISCVISD::VFMAX_VL)
5694     return true;
5695   if (Opcode == RISCVISD::FCOPYSIGN_VL)
5696     return true;
5697   if (Opcode >= RISCVISD::VWMUL_VL && Opcode <= RISCVISD::VFWSUB_W_VL)
5698     return true;
5699   if (Opcode == RISCVISD::SETCC_VL)
5700     return true;
5701   if (Opcode >= RISCVISD::STRICT_FADD_VL && Opcode <= RISCVISD::STRICT_FDIV_VL)
5702     return true;
5703   if (Opcode == RISCVISD::VMERGE_VL)
5704     return true;
5705   return false;
5706 }
5707 
5708 /// Return true if a RISC-V target-specific op has a mask operand.
5709 static bool hasMaskOp(unsigned Opcode) {
5710   assert(Opcode > RISCVISD::FIRST_NUMBER &&
5711          Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
5712          "not a RISC-V target specific op");
5713   static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
5714                     126 &&
5715                 RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
5716                         ISD::FIRST_TARGET_STRICTFP_OPCODE ==
5717                     21 &&
5718                 "adding target specific op should update this function");
5719   if (Opcode >= RISCVISD::TRUNCATE_VECTOR_VL && Opcode <= RISCVISD::SETCC_VL)
5720     return true;
5721   if (Opcode >= RISCVISD::VRGATHER_VX_VL && Opcode <= RISCVISD::VFIRST_VL)
5722     return true;
5723   if (Opcode >= RISCVISD::STRICT_FADD_VL &&
5724       Opcode <= RISCVISD::STRICT_VFROUND_NOEXCEPT_VL)
5725     return true;
5726   return false;
5727 }
5728 
5729 static SDValue SplitVectorOp(SDValue Op, SelectionDAG &DAG) {
5730   auto [LoVT, HiVT] = DAG.GetSplitDestVTs(Op.getValueType());
5731   SDLoc DL(Op);
5732 
5733   SmallVector<SDValue, 4> LoOperands(Op.getNumOperands());
5734   SmallVector<SDValue, 4> HiOperands(Op.getNumOperands());
5735 
5736   for (unsigned j = 0; j != Op.getNumOperands(); ++j) {
5737     if (!Op.getOperand(j).getValueType().isVector()) {
5738       LoOperands[j] = Op.getOperand(j);
5739       HiOperands[j] = Op.getOperand(j);
5740       continue;
5741     }
5742     std::tie(LoOperands[j], HiOperands[j]) =
5743         DAG.SplitVector(Op.getOperand(j), DL);
5744   }
5745 
5746   SDValue LoRes =
5747       DAG.getNode(Op.getOpcode(), DL, LoVT, LoOperands, Op->getFlags());
5748   SDValue HiRes =
5749       DAG.getNode(Op.getOpcode(), DL, HiVT, HiOperands, Op->getFlags());
5750 
5751   return DAG.getNode(ISD::CONCAT_VECTORS, DL, Op.getValueType(), LoRes, HiRes);
5752 }
5753 
5754 static SDValue SplitVPOp(SDValue Op, SelectionDAG &DAG) {
5755   assert(ISD::isVPOpcode(Op.getOpcode()) && "Not a VP op");
5756   auto [LoVT, HiVT] = DAG.GetSplitDestVTs(Op.getValueType());
5757   SDLoc DL(Op);
5758 
5759   SmallVector<SDValue, 4> LoOperands(Op.getNumOperands());
5760   SmallVector<SDValue, 4> HiOperands(Op.getNumOperands());
5761 
5762   for (unsigned j = 0; j != Op.getNumOperands(); ++j) {
5763     if (ISD::getVPExplicitVectorLengthIdx(Op.getOpcode()) == j) {
5764       std::tie(LoOperands[j], HiOperands[j]) =
5765           DAG.SplitEVL(Op.getOperand(j), Op.getValueType(), DL);
5766       continue;
5767     }
5768     if (!Op.getOperand(j).getValueType().isVector()) {
5769       LoOperands[j] = Op.getOperand(j);
5770       HiOperands[j] = Op.getOperand(j);
5771       continue;
5772     }
5773     std::tie(LoOperands[j], HiOperands[j]) =
5774         DAG.SplitVector(Op.getOperand(j), DL);
5775   }
5776 
5777   SDValue LoRes =
5778       DAG.getNode(Op.getOpcode(), DL, LoVT, LoOperands, Op->getFlags());
5779   SDValue HiRes =
5780       DAG.getNode(Op.getOpcode(), DL, HiVT, HiOperands, Op->getFlags());
5781 
5782   return DAG.getNode(ISD::CONCAT_VECTORS, DL, Op.getValueType(), LoRes, HiRes);
5783 }
5784 
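// Split a vector-predicated reduction into halves, feeding the result of the
// low-half reduction in as the start value of the high-half reduction.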
5785 static SDValue SplitVectorReductionOp(SDValue Op, SelectionDAG &DAG) {
5786   SDLoc DL(Op);
5787 
5788   auto [Lo, Hi] = DAG.SplitVector(Op.getOperand(1), DL);
5789   auto [MaskLo, MaskHi] = DAG.SplitVector(Op.getOperand(2), DL);
5790   auto [EVLLo, EVLHi] =
5791       DAG.SplitEVL(Op.getOperand(3), Op.getOperand(1).getValueType(), DL);
5792 
5793   SDValue ResLo =
5794       DAG.getNode(Op.getOpcode(), DL, Op.getValueType(),
5795                   {Op.getOperand(0), Lo, MaskLo, EVLLo}, Op->getFlags());
5796   return DAG.getNode(Op.getOpcode(), DL, Op.getValueType(),
5797                      {ResLo, Hi, MaskHi, EVLHi}, Op->getFlags());
5798 }
5799 
5800 static SDValue SplitStrictFPVectorOp(SDValue Op, SelectionDAG &DAG) {
5801 
5802   assert(Op->isStrictFPOpcode());
5803 
5804   auto [LoVT, HiVT] = DAG.GetSplitDestVTs(Op->getValueType(0));
5805 
5806   SDVTList LoVTs = DAG.getVTList(LoVT, Op->getValueType(1));
5807   SDVTList HiVTs = DAG.getVTList(HiVT, Op->getValueType(1));
5808 
5809   SDLoc DL(Op);
5810 
5811   SmallVector<SDValue, 4> LoOperands(Op.getNumOperands());
5812   SmallVector<SDValue, 4> HiOperands(Op.getNumOperands());
5813 
5814   for (unsigned j = 0; j != Op.getNumOperands(); ++j) {
5815     if (!Op.getOperand(j).getValueType().isVector()) {
5816       LoOperands[j] = Op.getOperand(j);
5817       HiOperands[j] = Op.getOperand(j);
5818       continue;
5819     }
5820     std::tie(LoOperands[j], HiOperands[j]) =
5821         DAG.SplitVector(Op.getOperand(j), DL);
5822   }
5823 
5824   SDValue LoRes =
5825       DAG.getNode(Op.getOpcode(), DL, LoVTs, LoOperands, Op->getFlags());
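  // Chain the high-half operation on the low half's output chain so the two
  // strict FP operations remain ordered.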
5826   HiOperands[0] = LoRes.getValue(1);
5827   SDValue HiRes =
5828       DAG.getNode(Op.getOpcode(), DL, HiVTs, HiOperands, Op->getFlags());
5829 
5830   SDValue V = DAG.getNode(ISD::CONCAT_VECTORS, DL, Op->getValueType(0),
5831                           LoRes.getValue(0), HiRes.getValue(0));
5832   return DAG.getMergeValues({V, HiRes.getValue(1)}, DL);
5833 }
5834 
5835 SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
5836                                             SelectionDAG &DAG) const {
5837   switch (Op.getOpcode()) {
5838   default:
5839     report_fatal_error("unimplemented operand");
5840   case ISD::ATOMIC_FENCE:
5841     return LowerATOMIC_FENCE(Op, DAG, Subtarget);
5842   case ISD::GlobalAddress:
5843     return lowerGlobalAddress(Op, DAG);
5844   case ISD::BlockAddress:
5845     return lowerBlockAddress(Op, DAG);
5846   case ISD::ConstantPool:
5847     return lowerConstantPool(Op, DAG);
5848   case ISD::JumpTable:
5849     return lowerJumpTable(Op, DAG);
5850   case ISD::GlobalTLSAddress:
5851     return lowerGlobalTLSAddress(Op, DAG);
5852   case ISD::Constant:
5853     return lowerConstant(Op, DAG, Subtarget);
5854   case ISD::SELECT:
5855     return lowerSELECT(Op, DAG);
5856   case ISD::BRCOND:
5857     return lowerBRCOND(Op, DAG);
5858   case ISD::VASTART:
5859     return lowerVASTART(Op, DAG);
5860   case ISD::FRAMEADDR:
5861     return lowerFRAMEADDR(Op, DAG);
5862   case ISD::RETURNADDR:
5863     return lowerRETURNADDR(Op, DAG);
5864   case ISD::SHL_PARTS:
5865     return lowerShiftLeftParts(Op, DAG);
5866   case ISD::SRA_PARTS:
5867     return lowerShiftRightParts(Op, DAG, true);
5868   case ISD::SRL_PARTS:
5869     return lowerShiftRightParts(Op, DAG, false);
5870   case ISD::ROTL:
5871   case ISD::ROTR:
5872     if (Op.getValueType().isFixedLengthVector()) {
5873       assert(Subtarget.hasStdExtZvkb());
5874       return lowerToScalableOp(Op, DAG);
5875     }
5876     assert(Subtarget.hasVendorXTHeadBb() &&
5877            !(Subtarget.hasStdExtZbb() || Subtarget.hasStdExtZbkb()) &&
5878            "Unexpected custom legalization");
5879     // XTHeadBb only supports rotate by constant.
5880     if (!isa<ConstantSDNode>(Op.getOperand(1)))
5881       return SDValue();
5882     return Op;
5883   case ISD::BITCAST: {
5884     SDLoc DL(Op);
5885     EVT VT = Op.getValueType();
5886     SDValue Op0 = Op.getOperand(0);
5887     EVT Op0VT = Op0.getValueType();
5888     MVT XLenVT = Subtarget.getXLenVT();
5889     if (VT == MVT::f16 && Op0VT == MVT::i16 &&
5890         Subtarget.hasStdExtZfhminOrZhinxmin()) {
5891       SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Op0);
5892       SDValue FPConv = DAG.getNode(RISCVISD::FMV_H_X, DL, MVT::f16, NewOp0);
5893       return FPConv;
5894     }
5895     if (VT == MVT::bf16 && Op0VT == MVT::i16 &&
5896         Subtarget.hasStdExtZfbfmin()) {
5897       SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Op0);
5898       SDValue FPConv = DAG.getNode(RISCVISD::FMV_H_X, DL, MVT::bf16, NewOp0);
5899       return FPConv;
5900     }
5901     if (VT == MVT::f32 && Op0VT == MVT::i32 && Subtarget.is64Bit() &&
5902         Subtarget.hasStdExtFOrZfinx()) {
5903       SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, Op0);
5904       SDValue FPConv =
5905           DAG.getNode(RISCVISD::FMV_W_X_RV64, DL, MVT::f32, NewOp0);
5906       return FPConv;
5907     }
5908     if (VT == MVT::f64 && Op0VT == MVT::i64 && XLenVT == MVT::i32 &&
5909         Subtarget.hasStdExtZfa()) {
5910       SDValue Lo, Hi;
5911       std::tie(Lo, Hi) = DAG.SplitScalar(Op0, DL, MVT::i32, MVT::i32);
5912       SDValue RetReg =
5913           DAG.getNode(RISCVISD::BuildPairF64, DL, MVT::f64, Lo, Hi);
5914       return RetReg;
5915     }
5916 
5917     // Consider other scalar<->scalar casts as legal if the types are legal.
5918     // Otherwise expand them.
5919     if (!VT.isVector() && !Op0VT.isVector()) {
5920       if (isTypeLegal(VT) && isTypeLegal(Op0VT))
5921         return Op;
5922       return SDValue();
5923     }
5924 
5925     assert(!VT.isScalableVector() && !Op0VT.isScalableVector() &&
5926            "Unexpected types");
5927 
5928     if (VT.isFixedLengthVector()) {
5929       // We can handle fixed length vector bitcasts with a simple replacement
5930       // in isel.
5931       if (Op0VT.isFixedLengthVector())
5932         return Op;
5933       // When bitcasting from scalar to fixed-length vector, insert the scalar
5934       // into a one-element vector of the result type, and perform a vector
5935       // bitcast.
5936       if (!Op0VT.isVector()) {
5937         EVT BVT = EVT::getVectorVT(*DAG.getContext(), Op0VT, 1);
5938         if (!isTypeLegal(BVT))
5939           return SDValue();
5940         return DAG.getBitcast(VT, DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, BVT,
5941                                               DAG.getUNDEF(BVT), Op0,
5942                                               DAG.getConstant(0, DL, XLenVT)));
5943       }
5944       return SDValue();
5945     }
5946     // Custom-legalize bitcasts from fixed-length vector types to scalar types
5947     // thus: bitcast the vector to a one-element vector type whose element type
5948     // is the same as the result type, and extract the first element.
5949     if (!VT.isVector() && Op0VT.isFixedLengthVector()) {
5950       EVT BVT = EVT::getVectorVT(*DAG.getContext(), VT, 1);
5951       if (!isTypeLegal(BVT))
5952         return SDValue();
5953       SDValue BVec = DAG.getBitcast(BVT, Op0);
5954       return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, BVec,
5955                          DAG.getConstant(0, DL, XLenVT));
5956     }
5957     return SDValue();
5958   }
5959   case ISD::INTRINSIC_WO_CHAIN:
5960     return LowerINTRINSIC_WO_CHAIN(Op, DAG);
5961   case ISD::INTRINSIC_W_CHAIN:
5962     return LowerINTRINSIC_W_CHAIN(Op, DAG);
5963   case ISD::INTRINSIC_VOID:
5964     return LowerINTRINSIC_VOID(Op, DAG);
5965   case ISD::IS_FPCLASS:
5966     return LowerIS_FPCLASS(Op, DAG);
5967   case ISD::BITREVERSE: {
5968     MVT VT = Op.getSimpleValueType();
5969     if (VT.isFixedLengthVector()) {
5970       assert(Subtarget.hasStdExtZvbb());
5971       return lowerToScalableOp(Op, DAG);
5972     }
5973     SDLoc DL(Op);
5974     assert(Subtarget.hasStdExtZbkb() && "Unexpected custom legalization");
5975     assert(Op.getOpcode() == ISD::BITREVERSE && "Unexpected opcode");
5976     // Expand bitreverse to a bswap(rev8) followed by brev8.
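    // rev8 reverses the bytes of the value and brev8 reverses the bits within
    // each byte, so the combination reverses all of the bits.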
5977     SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT, Op.getOperand(0));
5978     return DAG.getNode(RISCVISD::BREV8, DL, VT, BSwap);
5979   }
5980   case ISD::TRUNCATE:
5981     // Only custom-lower vector truncates
5982     if (!Op.getSimpleValueType().isVector())
5983       return Op;
5984     return lowerVectorTruncLike(Op, DAG);
5985   case ISD::ANY_EXTEND:
5986   case ISD::ZERO_EXTEND:
5987     if (Op.getOperand(0).getValueType().isVector() &&
5988         Op.getOperand(0).getValueType().getVectorElementType() == MVT::i1)
5989       return lowerVectorMaskExt(Op, DAG, /*ExtVal*/ 1);
5990     return lowerFixedLengthVectorExtendToRVV(Op, DAG, RISCVISD::VZEXT_VL);
5991   case ISD::SIGN_EXTEND:
5992     if (Op.getOperand(0).getValueType().isVector() &&
5993         Op.getOperand(0).getValueType().getVectorElementType() == MVT::i1)
5994       return lowerVectorMaskExt(Op, DAG, /*ExtVal*/ -1);
5995     return lowerFixedLengthVectorExtendToRVV(Op, DAG, RISCVISD::VSEXT_VL);
5996   case ISD::SPLAT_VECTOR_PARTS:
5997     return lowerSPLAT_VECTOR_PARTS(Op, DAG);
5998   case ISD::INSERT_VECTOR_ELT:
5999     return lowerINSERT_VECTOR_ELT(Op, DAG);
6000   case ISD::EXTRACT_VECTOR_ELT:
6001     return lowerEXTRACT_VECTOR_ELT(Op, DAG);
6002   case ISD::SCALAR_TO_VECTOR: {
6003     MVT VT = Op.getSimpleValueType();
6004     SDLoc DL(Op);
6005     SDValue Scalar = Op.getOperand(0);
6006     if (VT.getVectorElementType() == MVT::i1) {
6007       MVT WideVT = VT.changeVectorElementType(MVT::i8);
6008       SDValue V = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, WideVT, Scalar);
6009       return DAG.getNode(ISD::TRUNCATE, DL, VT, V);
6010     }
6011     MVT ContainerVT = VT;
6012     if (VT.isFixedLengthVector())
6013       ContainerVT = getContainerForFixedLengthVector(VT);
6014     SDValue VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
6015     Scalar = DAG.getNode(ISD::ANY_EXTEND, DL, Subtarget.getXLenVT(), Scalar);
6016     SDValue V = DAG.getNode(RISCVISD::VMV_S_X_VL, DL, ContainerVT,
6017                             DAG.getUNDEF(ContainerVT), Scalar, VL);
6018     if (VT.isFixedLengthVector())
6019       V = convertFromScalableVector(VT, V, DAG, Subtarget);
6020     return V;
6021   }
6022   case ISD::VSCALE: {
6023     MVT XLenVT = Subtarget.getXLenVT();
6024     MVT VT = Op.getSimpleValueType();
6025     SDLoc DL(Op);
6026     SDValue Res = DAG.getNode(RISCVISD::READ_VLENB, DL, XLenVT);
6027     // We define our scalable vector types for lmul=1 to use a 64 bit known
6028     // minimum size. e.g. <vscale x 2 x i32>. VLENB is in bytes so we calculate
6029     // vscale as VLENB / 8.
6030     static_assert(RISCV::RVVBitsPerBlock == 64, "Unexpected bits per block!");
6031     if (Subtarget.getRealMinVLen() < RISCV::RVVBitsPerBlock)
6032       report_fatal_error("Support for VLEN==32 is incomplete.");
6033     // We assume VLENB is a multiple of 8. We manually choose the best shift
6034     // here because SimplifyDemandedBits isn't always able to simplify it.
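    // For example, for a multiplier of 4 we emit VLENB >> 1 directly instead
    // of (VLENB >> 3) << 2.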
6035     uint64_t Val = Op.getConstantOperandVal(0);
6036     if (isPowerOf2_64(Val)) {
6037       uint64_t Log2 = Log2_64(Val);
6038       if (Log2 < 3)
6039         Res = DAG.getNode(ISD::SRL, DL, XLenVT, Res,
6040                           DAG.getConstant(3 - Log2, DL, VT));
6041       else if (Log2 > 3)
6042         Res = DAG.getNode(ISD::SHL, DL, XLenVT, Res,
6043                           DAG.getConstant(Log2 - 3, DL, XLenVT));
6044     } else if ((Val % 8) == 0) {
6045       // If the multiplier is a multiple of 8, scale it down to avoid needing
6046       // to shift the VLENB value.
6047       Res = DAG.getNode(ISD::MUL, DL, XLenVT, Res,
6048                         DAG.getConstant(Val / 8, DL, XLenVT));
6049     } else {
6050       SDValue VScale = DAG.getNode(ISD::SRL, DL, XLenVT, Res,
6051                                    DAG.getConstant(3, DL, XLenVT));
6052       Res = DAG.getNode(ISD::MUL, DL, XLenVT, VScale,
6053                         DAG.getConstant(Val, DL, XLenVT));
6054     }
6055     return DAG.getNode(ISD::TRUNCATE, DL, VT, Res);
6056   }
6057   case ISD::FPOWI: {
6058     // Custom promote f16 powi with illegal i32 integer type on RV64. Once
6059     // promoted this will be legalized into a libcall by LegalizeIntegerTypes.
6060     if (Op.getValueType() == MVT::f16 && Subtarget.is64Bit() &&
6061         Op.getOperand(1).getValueType() == MVT::i32) {
6062       SDLoc DL(Op);
6063       SDValue Op0 = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Op.getOperand(0));
6064       SDValue Powi =
6065           DAG.getNode(ISD::FPOWI, DL, MVT::f32, Op0, Op.getOperand(1));
6066       return DAG.getNode(ISD::FP_ROUND, DL, MVT::f16, Powi,
6067                          DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
6068     }
6069     return SDValue();
6070   }
6071   case ISD::FMAXIMUM:
6072   case ISD::FMINIMUM:
6073     if (Op.getValueType() == MVT::nxv32f16 &&
6074         (Subtarget.hasVInstructionsF16Minimal() &&
6075          !Subtarget.hasVInstructionsF16()))
6076       return SplitVectorOp(Op, DAG);
6077     return lowerFMAXIMUM_FMINIMUM(Op, DAG, Subtarget);
6078   case ISD::FP_EXTEND: {
6079     SDLoc DL(Op);
6080     EVT VT = Op.getValueType();
6081     SDValue Op0 = Op.getOperand(0);
6082     EVT Op0VT = Op0.getValueType();
6083     if (VT == MVT::f32 && Op0VT == MVT::bf16 && Subtarget.hasStdExtZfbfmin())
6084       return DAG.getNode(RISCVISD::FP_EXTEND_BF16, DL, MVT::f32, Op0);
6085     if (VT == MVT::f64 && Op0VT == MVT::bf16 && Subtarget.hasStdExtZfbfmin()) {
6086       SDValue FloatVal =
6087           DAG.getNode(RISCVISD::FP_EXTEND_BF16, DL, MVT::f32, Op0);
6088       return DAG.getNode(ISD::FP_EXTEND, DL, MVT::f64, FloatVal);
6089     }
6090 
6091     if (!Op.getValueType().isVector())
6092       return Op;
6093     return lowerVectorFPExtendOrRoundLike(Op, DAG);
6094   }
6095   case ISD::FP_ROUND: {
6096     SDLoc DL(Op);
6097     EVT VT = Op.getValueType();
6098     SDValue Op0 = Op.getOperand(0);
6099     EVT Op0VT = Op0.getValueType();
6100     if (VT == MVT::bf16 && Op0VT == MVT::f32 && Subtarget.hasStdExtZfbfmin())
6101       return DAG.getNode(RISCVISD::FP_ROUND_BF16, DL, MVT::bf16, Op0);
6102     if (VT == MVT::bf16 && Op0VT == MVT::f64 && Subtarget.hasStdExtZfbfmin() &&
6103         Subtarget.hasStdExtDOrZdinx()) {
6104       SDValue FloatVal =
6105           DAG.getNode(ISD::FP_ROUND, DL, MVT::f32, Op0,
6106                       DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
6107       return DAG.getNode(RISCVISD::FP_ROUND_BF16, DL, MVT::bf16, FloatVal);
6108     }
6109 
6110     if (!Op.getValueType().isVector())
6111       return Op;
6112     return lowerVectorFPExtendOrRoundLike(Op, DAG);
6113   }
6114   case ISD::STRICT_FP_ROUND:
6115   case ISD::STRICT_FP_EXTEND:
6116     return lowerStrictFPExtendOrRoundLike(Op, DAG);
6117   case ISD::SINT_TO_FP:
6118   case ISD::UINT_TO_FP:
6119     if (Op.getValueType().isVector() &&
6120         Op.getValueType().getScalarType() == MVT::f16 &&
6121         (Subtarget.hasVInstructionsF16Minimal() &&
6122          !Subtarget.hasVInstructionsF16())) {
6123       if (Op.getValueType() == MVT::nxv32f16)
6124         return SplitVectorOp(Op, DAG);
6125       // int -> f32
6126       SDLoc DL(Op);
6127       MVT NVT =
6128           MVT::getVectorVT(MVT::f32, Op.getValueType().getVectorElementCount());
6129       SDValue NC = DAG.getNode(Op.getOpcode(), DL, NVT, Op->ops());
6130       // f32 -> f16
6131       return DAG.getNode(ISD::FP_ROUND, DL, Op.getValueType(), NC,
6132                          DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
6133     }
6134     [[fallthrough]];
6135   case ISD::FP_TO_SINT:
6136   case ISD::FP_TO_UINT:
6137     if (SDValue Op1 = Op.getOperand(0);
6138         Op1.getValueType().isVector() &&
6139         Op1.getValueType().getScalarType() == MVT::f16 &&
6140         (Subtarget.hasVInstructionsF16Minimal() &&
6141          !Subtarget.hasVInstructionsF16())) {
6142       if (Op1.getValueType() == MVT::nxv32f16)
6143         return SplitVectorOp(Op, DAG);
6144       // f16 -> f32
6145       SDLoc DL(Op);
6146       MVT NVT = MVT::getVectorVT(MVT::f32,
6147                                  Op1.getValueType().getVectorElementCount());
6148       SDValue WidenVec = DAG.getNode(ISD::FP_EXTEND, DL, NVT, Op1);
6149       // f32 -> int
6150       return DAG.getNode(Op.getOpcode(), DL, Op.getValueType(), WidenVec);
6151     }
6152     [[fallthrough]];
6153   case ISD::STRICT_FP_TO_SINT:
6154   case ISD::STRICT_FP_TO_UINT:
6155   case ISD::STRICT_SINT_TO_FP:
6156   case ISD::STRICT_UINT_TO_FP: {
6157     // RVV can only do fp<->int conversions to types half/double the size as
6158     // the source. We custom-lower any conversions that do two hops into
6159     // sequences.
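    // For example, nxv2f16 -> nxv2i64 is lowered as an fp_extend to nxv2f32
    // followed by a widening fp_to_int, and nxv2i64 -> nxv2f16 as a narrowing
    // int_to_fp to nxv2f32 followed by an fp_round.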
6160     MVT VT = Op.getSimpleValueType();
6161     if (!VT.isVector())
6162       return Op;
6163     SDLoc DL(Op);
6164     bool IsStrict = Op->isStrictFPOpcode();
6165     SDValue Src = Op.getOperand(0 + IsStrict);
6166     MVT EltVT = VT.getVectorElementType();
6167     MVT SrcVT = Src.getSimpleValueType();
6168     MVT SrcEltVT = SrcVT.getVectorElementType();
6169     unsigned EltSize = EltVT.getSizeInBits();
6170     unsigned SrcEltSize = SrcEltVT.getSizeInBits();
6171     assert(isPowerOf2_32(EltSize) && isPowerOf2_32(SrcEltSize) &&
6172            "Unexpected vector element types");
6173 
6174     bool IsInt2FP = SrcEltVT.isInteger();
6175     // Widening conversions
6176     if (EltSize > (2 * SrcEltSize)) {
6177       if (IsInt2FP) {
6178         // Do a regular integer sign/zero extension then convert to float.
6179         MVT IVecVT = MVT::getVectorVT(MVT::getIntegerVT(EltSize / 2),
6180                                       VT.getVectorElementCount());
6181         unsigned ExtOpcode = (Op.getOpcode() == ISD::UINT_TO_FP ||
6182                               Op.getOpcode() == ISD::STRICT_UINT_TO_FP)
6183                                  ? ISD::ZERO_EXTEND
6184                                  : ISD::SIGN_EXTEND;
6185         SDValue Ext = DAG.getNode(ExtOpcode, DL, IVecVT, Src);
6186         if (IsStrict)
6187           return DAG.getNode(Op.getOpcode(), DL, Op->getVTList(),
6188                              Op.getOperand(0), Ext);
6189         return DAG.getNode(Op.getOpcode(), DL, VT, Ext);
6190       }
6191       // FP2Int
6192       assert(SrcEltVT == MVT::f16 && "Unexpected FP_TO_[US]INT lowering");
6193       // Do one doubling fp_extend then complete the operation by converting
6194       // to int.
6195       MVT InterimFVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount());
6196       if (IsStrict) {
6197         auto [FExt, Chain] =
6198             DAG.getStrictFPExtendOrRound(Src, Op.getOperand(0), DL, InterimFVT);
6199         return DAG.getNode(Op.getOpcode(), DL, Op->getVTList(), Chain, FExt);
6200       }
6201       SDValue FExt = DAG.getFPExtendOrRound(Src, DL, InterimFVT);
6202       return DAG.getNode(Op.getOpcode(), DL, VT, FExt);
6203     }
6204 
6205     // Narrowing conversions
6206     if (SrcEltSize > (2 * EltSize)) {
6207       if (IsInt2FP) {
6208         // One narrowing int_to_fp, then an fp_round.
6209         assert(EltVT == MVT::f16 && "Unexpected [US]_TO_FP lowering");
6210         MVT InterimFVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount());
6211         if (IsStrict) {
6212           SDValue Int2FP = DAG.getNode(Op.getOpcode(), DL,
6213                                        DAG.getVTList(InterimFVT, MVT::Other),
6214                                        Op.getOperand(0), Src);
6215           SDValue Chain = Int2FP.getValue(1);
6216           return DAG.getStrictFPExtendOrRound(Int2FP, Chain, DL, VT).first;
6217         }
6218         SDValue Int2FP = DAG.getNode(Op.getOpcode(), DL, InterimFVT, Src);
6219         return DAG.getFPExtendOrRound(Int2FP, DL, VT);
6220       }
6221       // FP2Int
6222       // One narrowing fp_to_int, then truncate the integer. If the float isn't
6223       // representable by the integer, the result is poison.
6224       MVT IVecVT = MVT::getVectorVT(MVT::getIntegerVT(SrcEltSize / 2),
6225                                     VT.getVectorElementCount());
6226       if (IsStrict) {
6227         SDValue FP2Int =
6228             DAG.getNode(Op.getOpcode(), DL, DAG.getVTList(IVecVT, MVT::Other),
6229                         Op.getOperand(0), Src);
6230         SDValue Res = DAG.getNode(ISD::TRUNCATE, DL, VT, FP2Int);
6231         return DAG.getMergeValues({Res, FP2Int.getValue(1)}, DL);
6232       }
6233       SDValue FP2Int = DAG.getNode(Op.getOpcode(), DL, IVecVT, Src);
6234       return DAG.getNode(ISD::TRUNCATE, DL, VT, FP2Int);
6235     }
6236 
6237     // Scalable vectors can exit here. Patterns will handle equally-sized
6238     // conversions as well as halving/doubling ones.
6239     if (!VT.isFixedLengthVector())
6240       return Op;
6241 
6242     // For fixed-length vectors we lower to a custom "VL" node.
6243     unsigned RVVOpc = 0;
6244     switch (Op.getOpcode()) {
6245     default:
6246       llvm_unreachable("Impossible opcode");
6247     case ISD::FP_TO_SINT:
6248       RVVOpc = RISCVISD::VFCVT_RTZ_X_F_VL;
6249       break;
6250     case ISD::FP_TO_UINT:
6251       RVVOpc = RISCVISD::VFCVT_RTZ_XU_F_VL;
6252       break;
6253     case ISD::SINT_TO_FP:
6254       RVVOpc = RISCVISD::SINT_TO_FP_VL;
6255       break;
6256     case ISD::UINT_TO_FP:
6257       RVVOpc = RISCVISD::UINT_TO_FP_VL;
6258       break;
6259     case ISD::STRICT_FP_TO_SINT:
6260       RVVOpc = RISCVISD::STRICT_VFCVT_RTZ_X_F_VL;
6261       break;
6262     case ISD::STRICT_FP_TO_UINT:
6263       RVVOpc = RISCVISD::STRICT_VFCVT_RTZ_XU_F_VL;
6264       break;
6265     case ISD::STRICT_SINT_TO_FP:
6266       RVVOpc = RISCVISD::STRICT_SINT_TO_FP_VL;
6267       break;
6268     case ISD::STRICT_UINT_TO_FP:
6269       RVVOpc = RISCVISD::STRICT_UINT_TO_FP_VL;
6270       break;
6271     }
6272 
6273     MVT ContainerVT = getContainerForFixedLengthVector(VT);
6274     MVT SrcContainerVT = getContainerForFixedLengthVector(SrcVT);
6275     assert(ContainerVT.getVectorElementCount() == SrcContainerVT.getVectorElementCount() &&
6276            "Expected same element count");
6277 
6278     auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
6279 
6280     Src = convertToScalableVector(SrcContainerVT, Src, DAG, Subtarget);
6281     if (IsStrict) {
6282       Src = DAG.getNode(RVVOpc, DL, DAG.getVTList(ContainerVT, MVT::Other),
6283                         Op.getOperand(0), Src, Mask, VL);
6284       SDValue SubVec = convertFromScalableVector(VT, Src, DAG, Subtarget);
6285       return DAG.getMergeValues({SubVec, Src.getValue(1)}, DL);
6286     }
6287     Src = DAG.getNode(RVVOpc, DL, ContainerVT, Src, Mask, VL);
6288     return convertFromScalableVector(VT, Src, DAG, Subtarget);
6289   }
6290   case ISD::FP_TO_SINT_SAT:
6291   case ISD::FP_TO_UINT_SAT:
6292     return lowerFP_TO_INT_SAT(Op, DAG, Subtarget);
6293   case ISD::FP_TO_BF16: {
6294     // Custom lower to ensure the libcall return is passed in an FPR on hard
6295     // float ABIs.
6296     assert(!Subtarget.isSoftFPABI() && "Unexpected custom legalization");
6297     SDLoc DL(Op);
6298     MakeLibCallOptions CallOptions;
6299     RTLIB::Libcall LC =
6300         RTLIB::getFPROUND(Op.getOperand(0).getValueType(), MVT::bf16);
6301     SDValue Res =
6302         makeLibCall(DAG, LC, MVT::f32, Op.getOperand(0), CallOptions, DL).first;
6303     if (Subtarget.is64Bit() && !RV64LegalI32)
6304       return DAG.getNode(RISCVISD::FMV_X_ANYEXTW_RV64, DL, MVT::i64, Res);
6305     return DAG.getBitcast(MVT::i32, Res);
6306   }
6307   case ISD::BF16_TO_FP: {
6308     assert(Subtarget.hasStdExtFOrZfinx() && "Unexpected custom legalization");
6309     MVT VT = Op.getSimpleValueType();
6310     SDLoc DL(Op);
6311     Op = DAG.getNode(
6312         ISD::SHL, DL, Op.getOperand(0).getValueType(), Op.getOperand(0),
6313         DAG.getShiftAmountConstant(16, Op.getOperand(0).getValueType(), DL));
6314     SDValue Res = Subtarget.is64Bit()
6315                       ? DAG.getNode(RISCVISD::FMV_W_X_RV64, DL, MVT::f32, Op)
6316                       : DAG.getBitcast(MVT::f32, Op);
6317     // fp_extend if the target VT is bigger than f32.
6318     if (VT != MVT::f32)
6319       return DAG.getNode(ISD::FP_EXTEND, DL, VT, Res);
6320     return Res;
6321   }
6322   case ISD::FP_TO_FP16: {
6323     // Custom lower to ensure the libcall return is passed in an FPR on hard
6324     // float ABIs.
6325     assert(Subtarget.hasStdExtFOrZfinx() && "Unexpected custom legalisation");
6326     SDLoc DL(Op);
6327     MakeLibCallOptions CallOptions;
6328     RTLIB::Libcall LC =
6329         RTLIB::getFPROUND(Op.getOperand(0).getValueType(), MVT::f16);
6330     SDValue Res =
6331         makeLibCall(DAG, LC, MVT::f32, Op.getOperand(0), CallOptions, DL).first;
6332     if (Subtarget.is64Bit() && !RV64LegalI32)
6333       return DAG.getNode(RISCVISD::FMV_X_ANYEXTW_RV64, DL, MVT::i64, Res);
6334     return DAG.getBitcast(MVT::i32, Res);
6335   }
6336   case ISD::FP16_TO_FP: {
6337     // Custom lower to ensure the libcall argument is passed in an FPR on hard
6338     // float ABIs.
6339     assert(Subtarget.hasStdExtFOrZfinx() && "Unexpected custom legalisation");
6340     SDLoc DL(Op);
6341     MakeLibCallOptions CallOptions;
6342     SDValue Arg = Subtarget.is64Bit()
6343                       ? DAG.getNode(RISCVISD::FMV_W_X_RV64, DL, MVT::f32,
6344                                     Op.getOperand(0))
6345                       : DAG.getBitcast(MVT::f32, Op.getOperand(0));
6346     SDValue Res =
6347         makeLibCall(DAG, RTLIB::FPEXT_F16_F32, MVT::f32, Arg, CallOptions, DL)
6348             .first;
6349     return Res;
6350   }
6351   case ISD::FTRUNC:
6352   case ISD::FCEIL:
6353   case ISD::FFLOOR:
6354   case ISD::FNEARBYINT:
6355   case ISD::FRINT:
6356   case ISD::FROUND:
6357   case ISD::FROUNDEVEN:
6358     return lowerFTRUNC_FCEIL_FFLOOR_FROUND(Op, DAG, Subtarget);
6359   case ISD::LRINT:
6360   case ISD::LLRINT:
6361     return lowerVectorXRINT(Op, DAG, Subtarget);
6362   case ISD::VECREDUCE_ADD:
6363   case ISD::VECREDUCE_UMAX:
6364   case ISD::VECREDUCE_SMAX:
6365   case ISD::VECREDUCE_UMIN:
6366   case ISD::VECREDUCE_SMIN:
6367     return lowerVECREDUCE(Op, DAG);
6368   case ISD::VECREDUCE_AND:
6369   case ISD::VECREDUCE_OR:
6370   case ISD::VECREDUCE_XOR:
6371     if (Op.getOperand(0).getValueType().getVectorElementType() == MVT::i1)
6372       return lowerVectorMaskVecReduction(Op, DAG, /*IsVP*/ false);
6373     return lowerVECREDUCE(Op, DAG);
6374   case ISD::VECREDUCE_FADD:
6375   case ISD::VECREDUCE_SEQ_FADD:
6376   case ISD::VECREDUCE_FMIN:
6377   case ISD::VECREDUCE_FMAX:
6378     return lowerFPVECREDUCE(Op, DAG);
6379   case ISD::VP_REDUCE_ADD:
6380   case ISD::VP_REDUCE_UMAX:
6381   case ISD::VP_REDUCE_SMAX:
6382   case ISD::VP_REDUCE_UMIN:
6383   case ISD::VP_REDUCE_SMIN:
6384   case ISD::VP_REDUCE_FADD:
6385   case ISD::VP_REDUCE_SEQ_FADD:
6386   case ISD::VP_REDUCE_FMIN:
6387   case ISD::VP_REDUCE_FMAX:
6388     if (Op.getOperand(1).getValueType() == MVT::nxv32f16 &&
6389         (Subtarget.hasVInstructionsF16Minimal() &&
6390          !Subtarget.hasVInstructionsF16()))
6391       return SplitVectorReductionOp(Op, DAG);
6392     return lowerVPREDUCE(Op, DAG);
6393   case ISD::VP_REDUCE_AND:
6394   case ISD::VP_REDUCE_OR:
6395   case ISD::VP_REDUCE_XOR:
6396     if (Op.getOperand(1).getValueType().getVectorElementType() == MVT::i1)
6397       return lowerVectorMaskVecReduction(Op, DAG, /*IsVP*/ true);
6398     return lowerVPREDUCE(Op, DAG);
6399   case ISD::UNDEF: {
6400     MVT ContainerVT = getContainerForFixedLengthVector(Op.getSimpleValueType());
6401     return convertFromScalableVector(Op.getSimpleValueType(),
6402                                      DAG.getUNDEF(ContainerVT), DAG, Subtarget);
6403   }
6404   case ISD::INSERT_SUBVECTOR:
6405     return lowerINSERT_SUBVECTOR(Op, DAG);
6406   case ISD::EXTRACT_SUBVECTOR:
6407     return lowerEXTRACT_SUBVECTOR(Op, DAG);
6408   case ISD::VECTOR_DEINTERLEAVE:
6409     return lowerVECTOR_DEINTERLEAVE(Op, DAG);
6410   case ISD::VECTOR_INTERLEAVE:
6411     return lowerVECTOR_INTERLEAVE(Op, DAG);
6412   case ISD::STEP_VECTOR:
6413     return lowerSTEP_VECTOR(Op, DAG);
6414   case ISD::VECTOR_REVERSE:
6415     return lowerVECTOR_REVERSE(Op, DAG);
6416   case ISD::VECTOR_SPLICE:
6417     return lowerVECTOR_SPLICE(Op, DAG);
6418   case ISD::BUILD_VECTOR:
6419     return lowerBUILD_VECTOR(Op, DAG, Subtarget);
6420   case ISD::SPLAT_VECTOR:
6421     if (Op.getValueType().getScalarType() == MVT::f16 &&
6422         (Subtarget.hasVInstructionsF16Minimal() &&
6423          !Subtarget.hasVInstructionsF16())) {
6424       if (Op.getValueType() == MVT::nxv32f16)
6425         return SplitVectorOp(Op, DAG);
6426       SDLoc DL(Op);
6427       SDValue NewScalar =
6428           DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Op.getOperand(0));
6429       SDValue NewSplat = DAG.getNode(
6430           ISD::SPLAT_VECTOR, DL,
6431           MVT::getVectorVT(MVT::f32, Op.getValueType().getVectorElementCount()),
6432           NewScalar);
6433       return DAG.getNode(ISD::FP_ROUND, DL, Op.getValueType(), NewSplat,
6434                          DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
6435     }
6436     if (Op.getValueType().getVectorElementType() == MVT::i1)
6437       return lowerVectorMaskSplat(Op, DAG);
6438     return SDValue();
6439   case ISD::VECTOR_SHUFFLE:
6440     return lowerVECTOR_SHUFFLE(Op, DAG, Subtarget);
6441   case ISD::CONCAT_VECTORS: {
6442     // Split CONCAT_VECTORS into a series of INSERT_SUBVECTOR nodes. This is
6443     // better than going through the stack, as the default expansion does.
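    // For example, a concat_vectors of four v2i32 operands becomes up to four
    // INSERT_SUBVECTOR nodes into an undef v8i32 at element indices 0, 2, 4,
    // and 6, with undef operands skipped entirely.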
6444     SDLoc DL(Op);
6445     MVT VT = Op.getSimpleValueType();
6446     unsigned NumOpElts =
6447         Op.getOperand(0).getSimpleValueType().getVectorMinNumElements();
6448     SDValue Vec = DAG.getUNDEF(VT);
6449     for (const auto &OpIdx : enumerate(Op->ops())) {
6450       SDValue SubVec = OpIdx.value();
6451       // Don't insert undef subvectors.
6452       if (SubVec.isUndef())
6453         continue;
6454       Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Vec, SubVec,
6455                         DAG.getIntPtrConstant(OpIdx.index() * NumOpElts, DL));
6456     }
6457     return Vec;
6458   }
6459   case ISD::LOAD:
6460     if (auto V = expandUnalignedRVVLoad(Op, DAG))
6461       return V;
6462     if (Op.getValueType().isFixedLengthVector())
6463       return lowerFixedLengthVectorLoadToRVV(Op, DAG);
6464     return Op;
6465   case ISD::STORE:
6466     if (auto V = expandUnalignedRVVStore(Op, DAG))
6467       return V;
6468     if (Op.getOperand(1).getValueType().isFixedLengthVector())
6469       return lowerFixedLengthVectorStoreToRVV(Op, DAG);
6470     return Op;
6471   case ISD::MLOAD:
6472   case ISD::VP_LOAD:
6473     return lowerMaskedLoad(Op, DAG);
6474   case ISD::MSTORE:
6475   case ISD::VP_STORE:
6476     return lowerMaskedStore(Op, DAG);
6477   case ISD::SELECT_CC: {
6478     // This occurs because we custom legalize SETGT and SETUGT for setcc. That
6479     // causes LegalizeDAG to think we need to custom legalize select_cc. Expand
6480     // into separate SETCC+SELECT just like LegalizeDAG.
6481     SDValue Tmp1 = Op.getOperand(0);
6482     SDValue Tmp2 = Op.getOperand(1);
6483     SDValue True = Op.getOperand(2);
6484     SDValue False = Op.getOperand(3);
6485     EVT VT = Op.getValueType();
6486     SDValue CC = Op.getOperand(4);
6487     EVT CmpVT = Tmp1.getValueType();
6488     EVT CCVT =
6489         getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), CmpVT);
6490     SDLoc DL(Op);
6491     SDValue Cond =
6492         DAG.getNode(ISD::SETCC, DL, CCVT, Tmp1, Tmp2, CC, Op->getFlags());
6493     return DAG.getSelect(DL, VT, Cond, True, False);
6494   }
6495   case ISD::SETCC: {
6496     MVT OpVT = Op.getOperand(0).getSimpleValueType();
6497     if (OpVT.isScalarInteger()) {
6498       MVT VT = Op.getSimpleValueType();
6499       SDValue LHS = Op.getOperand(0);
6500       SDValue RHS = Op.getOperand(1);
6501       ISD::CondCode CCVal = cast<CondCodeSDNode>(Op.getOperand(2))->get();
6502       assert((CCVal == ISD::SETGT || CCVal == ISD::SETUGT) &&
6503              "Unexpected CondCode");
6504 
6505       SDLoc DL(Op);
6506 
6507       // If the RHS is a constant in the range [-2049, 0) or (0, 2046], we can
6508       // convert this to the equivalent of (set(u)ge X, C+1) by using
6509       // (xori (slti(u) X, C+1), 1). This avoids materializing a small constant
6510       // in a register.
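      // For example, (setugt X, 7) becomes (xori (sltiu X, 8), 1).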
6511       if (isa<ConstantSDNode>(RHS)) {
6512         int64_t Imm = cast<ConstantSDNode>(RHS)->getSExtValue();
6513         if (Imm != 0 && isInt<12>((uint64_t)Imm + 1)) {
6514           // If this is an unsigned compare and the constant is -1, incrementing
6515           // the constant would change behavior. The result should be false.
6516           if (CCVal == ISD::SETUGT && Imm == -1)
6517             return DAG.getConstant(0, DL, VT);
6518           // Using getSetCCSwappedOperands will convert SET(U)GT->SET(U)LT.
6519           CCVal = ISD::getSetCCSwappedOperands(CCVal);
6520           SDValue SetCC = DAG.getSetCC(
6521               DL, VT, LHS, DAG.getConstant(Imm + 1, DL, OpVT), CCVal);
6522           return DAG.getLogicalNOT(DL, SetCC, VT);
6523         }
6524       }
6525 
6526       // Not a constant we could handle, swap the operands and condition code to
6527       // SETLT/SETULT.
6528       CCVal = ISD::getSetCCSwappedOperands(CCVal);
6529       return DAG.getSetCC(DL, VT, RHS, LHS, CCVal);
6530     }
6531 
6532     if (Op.getOperand(0).getSimpleValueType() == MVT::nxv32f16 &&
6533         (Subtarget.hasVInstructionsF16Minimal() &&
6534          !Subtarget.hasVInstructionsF16()))
6535       return SplitVectorOp(Op, DAG);
6536 
6537     return lowerFixedLengthVectorSetccToRVV(Op, DAG);
6538   }
6539   case ISD::ADD:
6540   case ISD::SUB:
6541   case ISD::MUL:
6542   case ISD::MULHS:
6543   case ISD::MULHU:
6544   case ISD::AND:
6545   case ISD::OR:
6546   case ISD::XOR:
6547   case ISD::SDIV:
6548   case ISD::SREM:
6549   case ISD::UDIV:
6550   case ISD::UREM:
6551   case ISD::BSWAP:
6552   case ISD::CTPOP:
6553     return lowerToScalableOp(Op, DAG);
6554   case ISD::SHL:
6555   case ISD::SRA:
6556   case ISD::SRL:
6557     if (Op.getSimpleValueType().isFixedLengthVector())
6558       return lowerToScalableOp(Op, DAG);
6559     // This can be called for an i32 shift amount that needs to be promoted.
6560     assert(Op.getOperand(1).getValueType() == MVT::i32 && Subtarget.is64Bit() &&
6561            "Unexpected custom legalisation");
6562     return SDValue();
6563   case ISD::FADD:
6564   case ISD::FSUB:
6565   case ISD::FMUL:
6566   case ISD::FDIV:
6567   case ISD::FNEG:
6568   case ISD::FABS:
6569   case ISD::FSQRT:
6570   case ISD::FMA:
6571   case ISD::FMINNUM:
6572   case ISD::FMAXNUM:
6573     if (Op.getValueType() == MVT::nxv32f16 &&
6574         (Subtarget.hasVInstructionsF16Minimal() &&
6575          !Subtarget.hasVInstructionsF16()))
6576       return SplitVectorOp(Op, DAG);
6577     [[fallthrough]];
6578   case ISD::AVGFLOORU:
6579   case ISD::AVGCEILU:
6580   case ISD::SADDSAT:
6581   case ISD::UADDSAT:
6582   case ISD::SSUBSAT:
6583   case ISD::USUBSAT:
6584   case ISD::SMIN:
6585   case ISD::SMAX:
6586   case ISD::UMIN:
6587   case ISD::UMAX:
6588     return lowerToScalableOp(Op, DAG);
6589   case ISD::ABS:
6590   case ISD::VP_ABS:
6591     return lowerABS(Op, DAG);
6592   case ISD::CTLZ:
6593   case ISD::CTLZ_ZERO_UNDEF:
6594   case ISD::CTTZ:
6595   case ISD::CTTZ_ZERO_UNDEF:
6596     if (Subtarget.hasStdExtZvbb())
6597       return lowerToScalableOp(Op, DAG);
6598     assert(Op.getOpcode() != ISD::CTTZ);
6599     return lowerCTLZ_CTTZ_ZERO_UNDEF(Op, DAG);
6600   case ISD::VSELECT:
6601     return lowerFixedLengthVectorSelectToRVV(Op, DAG);
6602   case ISD::FCOPYSIGN:
6603     if (Op.getValueType() == MVT::nxv32f16 &&
6604         (Subtarget.hasVInstructionsF16Minimal() &&
6605          !Subtarget.hasVInstructionsF16()))
6606       return SplitVectorOp(Op, DAG);
6607     return lowerFixedLengthVectorFCOPYSIGNToRVV(Op, DAG);
6608   case ISD::STRICT_FADD:
6609   case ISD::STRICT_FSUB:
6610   case ISD::STRICT_FMUL:
6611   case ISD::STRICT_FDIV:
6612   case ISD::STRICT_FSQRT:
6613   case ISD::STRICT_FMA:
6614     if (Op.getValueType() == MVT::nxv32f16 &&
6615         (Subtarget.hasVInstructionsF16Minimal() &&
6616          !Subtarget.hasVInstructionsF16()))
6617       return SplitStrictFPVectorOp(Op, DAG);
6618     return lowerToScalableOp(Op, DAG);
6619   case ISD::STRICT_FSETCC:
6620   case ISD::STRICT_FSETCCS:
6621     return lowerVectorStrictFSetcc(Op, DAG);
6622   case ISD::STRICT_FCEIL:
6623   case ISD::STRICT_FRINT:
6624   case ISD::STRICT_FFLOOR:
6625   case ISD::STRICT_FTRUNC:
6626   case ISD::STRICT_FNEARBYINT:
6627   case ISD::STRICT_FROUND:
6628   case ISD::STRICT_FROUNDEVEN:
6629     return lowerVectorStrictFTRUNC_FCEIL_FFLOOR_FROUND(Op, DAG, Subtarget);
6630   case ISD::MGATHER:
6631   case ISD::VP_GATHER:
6632     return lowerMaskedGather(Op, DAG);
6633   case ISD::MSCATTER:
6634   case ISD::VP_SCATTER:
6635     return lowerMaskedScatter(Op, DAG);
6636   case ISD::GET_ROUNDING:
6637     return lowerGET_ROUNDING(Op, DAG);
6638   case ISD::SET_ROUNDING:
6639     return lowerSET_ROUNDING(Op, DAG);
6640   case ISD::EH_DWARF_CFA:
6641     return lowerEH_DWARF_CFA(Op, DAG);
6642   case ISD::VP_SELECT:
6643   case ISD::VP_MERGE:
6644   case ISD::VP_ADD:
6645   case ISD::VP_SUB:
6646   case ISD::VP_MUL:
6647   case ISD::VP_SDIV:
6648   case ISD::VP_UDIV:
6649   case ISD::VP_SREM:
6650   case ISD::VP_UREM:
6651     return lowerVPOp(Op, DAG);
6652   case ISD::VP_AND:
6653   case ISD::VP_OR:
6654   case ISD::VP_XOR:
6655     return lowerLogicVPOp(Op, DAG);
6656   case ISD::VP_FADD:
6657   case ISD::VP_FSUB:
6658   case ISD::VP_FMUL:
6659   case ISD::VP_FDIV:
6660   case ISD::VP_FNEG:
6661   case ISD::VP_FABS:
6662   case ISD::VP_SQRT:
6663   case ISD::VP_FMA:
6664   case ISD::VP_FMINNUM:
6665   case ISD::VP_FMAXNUM:
6666   case ISD::VP_FCOPYSIGN:
6667     if (Op.getValueType() == MVT::nxv32f16 &&
6668         (Subtarget.hasVInstructionsF16Minimal() &&
6669          !Subtarget.hasVInstructionsF16()))
6670       return SplitVPOp(Op, DAG);
6671     [[fallthrough]];
6672   case ISD::VP_ASHR:
6673   case ISD::VP_LSHR:
6674   case ISD::VP_SHL:
6675     return lowerVPOp(Op, DAG);
6676   case ISD::VP_IS_FPCLASS:
6677     return LowerIS_FPCLASS(Op, DAG);
6678   case ISD::VP_SIGN_EXTEND:
6679   case ISD::VP_ZERO_EXTEND:
6680     if (Op.getOperand(0).getSimpleValueType().getVectorElementType() == MVT::i1)
6681       return lowerVPExtMaskOp(Op, DAG);
6682     return lowerVPOp(Op, DAG);
6683   case ISD::VP_TRUNCATE:
6684     return lowerVectorTruncLike(Op, DAG);
6685   case ISD::VP_FP_EXTEND:
6686   case ISD::VP_FP_ROUND:
6687     return lowerVectorFPExtendOrRoundLike(Op, DAG);
6688   case ISD::VP_SINT_TO_FP:
6689   case ISD::VP_UINT_TO_FP:
6690     if (Op.getValueType().isVector() &&
6691         Op.getValueType().getScalarType() == MVT::f16 &&
6692         (Subtarget.hasVInstructionsF16Minimal() &&
6693          !Subtarget.hasVInstructionsF16())) {
6694       if (Op.getValueType() == MVT::nxv32f16)
6695         return SplitVPOp(Op, DAG);
6696       // int -> f32
6697       SDLoc DL(Op);
6698       MVT NVT =
6699           MVT::getVectorVT(MVT::f32, Op.getValueType().getVectorElementCount());
6700       auto NC = DAG.getNode(Op.getOpcode(), DL, NVT, Op->ops());
6701       // f32 -> f16
6702       return DAG.getNode(ISD::FP_ROUND, DL, Op.getValueType(), NC,
6703                          DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
6704     }
6705     [[fallthrough]];
6706   case ISD::VP_FP_TO_SINT:
6707   case ISD::VP_FP_TO_UINT:
6708     if (SDValue Op1 = Op.getOperand(0);
6709         Op1.getValueType().isVector() &&
6710         Op1.getValueType().getScalarType() == MVT::f16 &&
6711         (Subtarget.hasVInstructionsF16Minimal() &&
6712          !Subtarget.hasVInstructionsF16())) {
6713       if (Op1.getValueType() == MVT::nxv32f16)
6714         return SplitVPOp(Op, DAG);
6715       // f16 -> f32
6716       SDLoc DL(Op);
6717       MVT NVT = MVT::getVectorVT(MVT::f32,
6718                                  Op1.getValueType().getVectorElementCount());
6719       SDValue WidenVec = DAG.getNode(ISD::FP_EXTEND, DL, NVT, Op1);
6720       // f32 -> int
6721       return DAG.getNode(Op.getOpcode(), DL, Op.getValueType(),
6722                          {WidenVec, Op.getOperand(1), Op.getOperand(2)});
6723     }
6724     return lowerVPFPIntConvOp(Op, DAG);
6725   case ISD::VP_SETCC:
6726     if (Op.getOperand(0).getSimpleValueType() == MVT::nxv32f16 &&
6727         (Subtarget.hasVInstructionsF16Minimal() &&
6728          !Subtarget.hasVInstructionsF16()))
6729       return SplitVPOp(Op, DAG);
6730     if (Op.getOperand(0).getSimpleValueType().getVectorElementType() == MVT::i1)
6731       return lowerVPSetCCMaskOp(Op, DAG);
6732     [[fallthrough]];
6733   case ISD::VP_SMIN:
6734   case ISD::VP_SMAX:
6735   case ISD::VP_UMIN:
6736   case ISD::VP_UMAX:
6737   case ISD::VP_BITREVERSE:
6738   case ISD::VP_BSWAP:
6739     return lowerVPOp(Op, DAG);
6740   case ISD::VP_CTLZ:
6741   case ISD::VP_CTLZ_ZERO_UNDEF:
6742     if (Subtarget.hasStdExtZvbb())
6743       return lowerVPOp(Op, DAG);
6744     return lowerCTLZ_CTTZ_ZERO_UNDEF(Op, DAG);
6745   case ISD::VP_CTTZ:
6746   case ISD::VP_CTTZ_ZERO_UNDEF:
6747     if (Subtarget.hasStdExtZvbb())
6748       return lowerVPOp(Op, DAG);
6749     return lowerCTLZ_CTTZ_ZERO_UNDEF(Op, DAG);
6750   case ISD::VP_CTPOP:
6751     return lowerVPOp(Op, DAG);
6752   case ISD::EXPERIMENTAL_VP_STRIDED_LOAD:
6753     return lowerVPStridedLoad(Op, DAG);
6754   case ISD::EXPERIMENTAL_VP_STRIDED_STORE:
6755     return lowerVPStridedStore(Op, DAG);
6756   case ISD::VP_FCEIL:
6757   case ISD::VP_FFLOOR:
6758   case ISD::VP_FRINT:
6759   case ISD::VP_FNEARBYINT:
6760   case ISD::VP_FROUND:
6761   case ISD::VP_FROUNDEVEN:
6762   case ISD::VP_FROUNDTOZERO:
6763     if (Op.getValueType() == MVT::nxv32f16 &&
6764         (Subtarget.hasVInstructionsF16Minimal() &&
6765          !Subtarget.hasVInstructionsF16()))
6766       return SplitVPOp(Op, DAG);
6767     return lowerVectorFTRUNC_FCEIL_FFLOOR_FROUND(Op, DAG, Subtarget);
6768   case ISD::VP_FMAXIMUM:
6769   case ISD::VP_FMINIMUM:
6770     if (Op.getValueType() == MVT::nxv32f16 &&
6771         (Subtarget.hasVInstructionsF16Minimal() &&
6772          !Subtarget.hasVInstructionsF16()))
6773       return SplitVPOp(Op, DAG);
6774     return lowerFMAXIMUM_FMINIMUM(Op, DAG, Subtarget);
6775   case ISD::EXPERIMENTAL_VP_SPLICE:
6776     return lowerVPSpliceExperimental(Op, DAG);
6777   case ISD::EXPERIMENTAL_VP_REVERSE:
6778     return lowerVPReverseExperimental(Op, DAG);
6779   }
6780 }
6781 
6782 static SDValue getTargetNode(GlobalAddressSDNode *N, const SDLoc &DL, EVT Ty,
6783                              SelectionDAG &DAG, unsigned Flags) {
6784   return DAG.getTargetGlobalAddress(N->getGlobal(), DL, Ty, 0, Flags);
6785 }
6786 
6787 static SDValue getTargetNode(BlockAddressSDNode *N, const SDLoc &DL, EVT Ty,
6788                              SelectionDAG &DAG, unsigned Flags) {
6789   return DAG.getTargetBlockAddress(N->getBlockAddress(), Ty, N->getOffset(),
6790                                    Flags);
6791 }
6792 
6793 static SDValue getTargetNode(ConstantPoolSDNode *N, const SDLoc &DL, EVT Ty,
6794                              SelectionDAG &DAG, unsigned Flags) {
6795   return DAG.getTargetConstantPool(N->getConstVal(), Ty, N->getAlign(),
6796                                    N->getOffset(), Flags);
6797 }
6798 
6799 static SDValue getTargetNode(JumpTableSDNode *N, const SDLoc &DL, EVT Ty,
6800                              SelectionDAG &DAG, unsigned Flags) {
6801   return DAG.getTargetJumpTable(N->getIndex(), Ty, Flags);
6802 }
6803 
6804 template <class NodeTy>
6805 SDValue RISCVTargetLowering::getAddr(NodeTy *N, SelectionDAG &DAG,
6806                                      bool IsLocal, bool IsExternWeak) const {
6807   SDLoc DL(N);
6808   EVT Ty = getPointerTy(DAG.getDataLayout());
6809 
6810   // When HWASAN is used and tagging of global variables is enabled,
6811   // they should be accessed via the GOT, since the tagged address of a global
6812   // is incompatible with existing code models. This also applies to non-pic
6813   // mode.
6814   if (isPositionIndependent() || Subtarget.allowTaggedGlobals()) {
6815     SDValue Addr = getTargetNode(N, DL, Ty, DAG, 0);
6816     if (IsLocal && !Subtarget.allowTaggedGlobals())
6817       // Use PC-relative addressing to access the symbol. This generates the
6818       // pattern (PseudoLLA sym), which expands to (addi (auipc %pcrel_hi(sym))
6819       // %pcrel_lo(auipc)).
6820       return DAG.getNode(RISCVISD::LLA, DL, Ty, Addr);
6821 
6822     // Use PC-relative addressing to access the GOT for this symbol, then load
6823     // the address from the GOT. This generates the pattern (PseudoLGA sym),
6824     // which expands to (ld (addi (auipc %got_pcrel_hi(sym)) %pcrel_lo(auipc))).
6825     SDValue Load =
6826         SDValue(DAG.getMachineNode(RISCV::PseudoLGA, DL, Ty, Addr), 0);
6827     MachineFunction &MF = DAG.getMachineFunction();
6828     MachineMemOperand *MemOp = MF.getMachineMemOperand(
6829         MachinePointerInfo::getGOT(MF),
6830         MachineMemOperand::MOLoad | MachineMemOperand::MODereferenceable |
6831             MachineMemOperand::MOInvariant,
6832         LLT(Ty.getSimpleVT()), Align(Ty.getFixedSizeInBits() / 8));
6833     DAG.setNodeMemRefs(cast<MachineSDNode>(Load.getNode()), {MemOp});
6834     return Load;
6835   }
6836 
6837   switch (getTargetMachine().getCodeModel()) {
6838   default:
6839     report_fatal_error("Unsupported code model for lowering");
6840   case CodeModel::Small: {
6841     // Generate a sequence for accessing addresses within the first 2 GiB of
6842     // address space. This generates the pattern (addi (lui %hi(sym)) %lo(sym)).
6843     SDValue AddrHi = getTargetNode(N, DL, Ty, DAG, RISCVII::MO_HI);
6844     SDValue AddrLo = getTargetNode(N, DL, Ty, DAG, RISCVII::MO_LO);
6845     SDValue MNHi = DAG.getNode(RISCVISD::HI, DL, Ty, AddrHi);
6846     return DAG.getNode(RISCVISD::ADD_LO, DL, Ty, MNHi, AddrLo);
6847   }
6848   case CodeModel::Medium: {
6849     SDValue Addr = getTargetNode(N, DL, Ty, DAG, 0);
6850     if (IsExternWeak) {
6851       // An extern weak symbol may be undefined, i.e. have value 0, which may
6852       // not be within 2GiB of PC, so use GOT-indirect addressing to access the
6853       // symbol. This generates the pattern (PseudoLGA sym), which expands to
6854       // (ld (addi (auipc %got_pcrel_hi(sym)) %pcrel_lo(auipc))).
6855       SDValue Load =
6856           SDValue(DAG.getMachineNode(RISCV::PseudoLGA, DL, Ty, Addr), 0);
6857       MachineFunction &MF = DAG.getMachineFunction();
6858       MachineMemOperand *MemOp = MF.getMachineMemOperand(
6859           MachinePointerInfo::getGOT(MF),
6860           MachineMemOperand::MOLoad | MachineMemOperand::MODereferenceable |
6861               MachineMemOperand::MOInvariant,
6862           LLT(Ty.getSimpleVT()), Align(Ty.getFixedSizeInBits() / 8));
6863       DAG.setNodeMemRefs(cast<MachineSDNode>(Load.getNode()), {MemOp});
6864       return Load;
6865     }
6866 
6867     // Generate a sequence for accessing addresses within any 2GiB range within
6868     // the address space. This generates the pattern (PseudoLLA sym), which
6869     // expands to (addi (auipc %pcrel_hi(sym)) %pcrel_lo(auipc)).
6870     return DAG.getNode(RISCVISD::LLA, DL, Ty, Addr);
6871   }
6872   }
6873 }
6874 
6875 SDValue RISCVTargetLowering::lowerGlobalAddress(SDValue Op,
6876                                                 SelectionDAG &DAG) const {
6877   GlobalAddressSDNode *N = cast<GlobalAddressSDNode>(Op);
6878   assert(N->getOffset() == 0 && "unexpected offset in global node");
6879   const GlobalValue *GV = N->getGlobal();
6880   return getAddr(N, DAG, GV->isDSOLocal(), GV->hasExternalWeakLinkage());
6881 }
6882 
6883 SDValue RISCVTargetLowering::lowerBlockAddress(SDValue Op,
6884                                                SelectionDAG &DAG) const {
6885   BlockAddressSDNode *N = cast<BlockAddressSDNode>(Op);
6886 
6887   return getAddr(N, DAG);
6888 }
6889 
6890 SDValue RISCVTargetLowering::lowerConstantPool(SDValue Op,
6891                                                SelectionDAG &DAG) const {
6892   ConstantPoolSDNode *N = cast<ConstantPoolSDNode>(Op);
6893 
6894   return getAddr(N, DAG);
6895 }
6896 
6897 SDValue RISCVTargetLowering::lowerJumpTable(SDValue Op,
6898                                             SelectionDAG &DAG) const {
6899   JumpTableSDNode *N = cast<JumpTableSDNode>(Op);
6900 
6901   return getAddr(N, DAG);
6902 }
6903 
6904 SDValue RISCVTargetLowering::getStaticTLSAddr(GlobalAddressSDNode *N,
6905                                               SelectionDAG &DAG,
6906                                               bool UseGOT) const {
6907   SDLoc DL(N);
6908   EVT Ty = getPointerTy(DAG.getDataLayout());
6909   const GlobalValue *GV = N->getGlobal();
6910   MVT XLenVT = Subtarget.getXLenVT();
6911 
6912   if (UseGOT) {
6913     // Use PC-relative addressing to access the GOT for this TLS symbol, then
6914     // load the address from the GOT and add the thread pointer. This generates
6915     // the pattern (PseudoLA_TLS_IE sym), which expands to
6916     // (ld (auipc %tls_ie_pcrel_hi(sym)) %pcrel_lo(auipc)).
6917     SDValue Addr = DAG.getTargetGlobalAddress(GV, DL, Ty, 0, 0);
6918     SDValue Load =
6919         SDValue(DAG.getMachineNode(RISCV::PseudoLA_TLS_IE, DL, Ty, Addr), 0);
6920     MachineFunction &MF = DAG.getMachineFunction();
6921     MachineMemOperand *MemOp = MF.getMachineMemOperand(
6922         MachinePointerInfo::getGOT(MF),
6923         MachineMemOperand::MOLoad | MachineMemOperand::MODereferenceable |
6924             MachineMemOperand::MOInvariant,
6925         LLT(Ty.getSimpleVT()), Align(Ty.getFixedSizeInBits() / 8));
6926     DAG.setNodeMemRefs(cast<MachineSDNode>(Load.getNode()), {MemOp});
6927 
6928     // Add the thread pointer.
6929     SDValue TPReg = DAG.getRegister(RISCV::X4, XLenVT);
6930     return DAG.getNode(ISD::ADD, DL, Ty, Load, TPReg);
6931   }
6932 
6933   // Generate a sequence for accessing the address relative to the thread
6934   // pointer, with the appropriate adjustment for the thread pointer offset.
6935   // This generates the pattern
6936   // (add (add_tprel (lui %tprel_hi(sym)) tp %tprel_add(sym)) %tprel_lo(sym))
6937   SDValue AddrHi =
6938       DAG.getTargetGlobalAddress(GV, DL, Ty, 0, RISCVII::MO_TPREL_HI);
6939   SDValue AddrAdd =
6940       DAG.getTargetGlobalAddress(GV, DL, Ty, 0, RISCVII::MO_TPREL_ADD);
6941   SDValue AddrLo =
6942       DAG.getTargetGlobalAddress(GV, DL, Ty, 0, RISCVII::MO_TPREL_LO);
6943 
6944   SDValue MNHi = DAG.getNode(RISCVISD::HI, DL, Ty, AddrHi);
6945   SDValue TPReg = DAG.getRegister(RISCV::X4, XLenVT);
6946   SDValue MNAdd =
6947       DAG.getNode(RISCVISD::ADD_TPREL, DL, Ty, MNHi, TPReg, AddrAdd);
6948   return DAG.getNode(RISCVISD::ADD_LO, DL, Ty, MNAdd, AddrLo);
6949 }
6950 
6951 SDValue RISCVTargetLowering::getDynamicTLSAddr(GlobalAddressSDNode *N,
6952                                                SelectionDAG &DAG) const {
6953   SDLoc DL(N);
6954   EVT Ty = getPointerTy(DAG.getDataLayout());
6955   IntegerType *CallTy = Type::getIntNTy(*DAG.getContext(), Ty.getSizeInBits());
6956   const GlobalValue *GV = N->getGlobal();
6957 
6958   // Use a PC-relative addressing mode to access the global dynamic GOT address.
6959   // This generates the pattern (PseudoLA_TLS_GD sym), which expands to
6960   // (addi (auipc %tls_gd_pcrel_hi(sym)) %pcrel_lo(auipc)).
6961   SDValue Addr = DAG.getTargetGlobalAddress(GV, DL, Ty, 0, 0);
6962   SDValue Load =
6963       SDValue(DAG.getMachineNode(RISCV::PseudoLA_TLS_GD, DL, Ty, Addr), 0);
6964 
6965   // Prepare argument list to generate call.
6966   ArgListTy Args;
6967   ArgListEntry Entry;
6968   Entry.Node = Load;
6969   Entry.Ty = CallTy;
6970   Args.push_back(Entry);
6971 
6972   // Setup call to __tls_get_addr.
6973   TargetLowering::CallLoweringInfo CLI(DAG);
6974   CLI.setDebugLoc(DL)
6975       .setChain(DAG.getEntryNode())
6976       .setLibCallee(CallingConv::C, CallTy,
6977                     DAG.getExternalSymbol("__tls_get_addr", Ty),
6978                     std::move(Args));
6979 
6980   return LowerCallTo(CLI).first;
6981 }
6982 
6983 SDValue RISCVTargetLowering::getTLSDescAddr(GlobalAddressSDNode *N,
6984                                             SelectionDAG &DAG) const {
6985   SDLoc DL(N);
6986   EVT Ty = getPointerTy(DAG.getDataLayout());
6987   const GlobalValue *GV = N->getGlobal();
6988 
6989   // Use a PC-relative addressing mode to access the global dynamic GOT address.
6990   // This generates the pattern (PseudoLA_TLSDESC sym), which expands to
6991   //
6992   // auipc tX, %tlsdesc_hi(symbol)         // R_RISCV_TLSDESC_HI20(symbol)
6993   // lw    tY, tX, %tlsdesc_lo_load(label) // R_RISCV_TLSDESC_LOAD_LO12_I(label)
6994   // addi  a0, tX, %tlsdesc_lo_add(label)  // R_RISCV_TLSDESC_ADD_LO12_I(label)
6995   // jalr  t0, tY                          // R_RISCV_TLSDESC_CALL(label)
6996   SDValue Addr = DAG.getTargetGlobalAddress(GV, DL, Ty, 0, 0);
6997   return SDValue(DAG.getMachineNode(RISCV::PseudoLA_TLSDESC, DL, Ty, Addr), 0);
6998 }
6999 
7000 SDValue RISCVTargetLowering::lowerGlobalTLSAddress(SDValue Op,
7001                                                    SelectionDAG &DAG) const {
7002   GlobalAddressSDNode *N = cast<GlobalAddressSDNode>(Op);
7003   assert(N->getOffset() == 0 && "unexpected offset in global node");
7004 
7005   if (DAG.getTarget().useEmulatedTLS())
7006     return LowerToTLSEmulatedModel(N, DAG);
7007 
7008   TLSModel::Model Model = getTargetMachine().getTLSModel(N->getGlobal());
7009 
7010   if (DAG.getMachineFunction().getFunction().getCallingConv() ==
7011       CallingConv::GHC)
7012     report_fatal_error("In GHC calling convention TLS is not supported");
7013 
7014   SDValue Addr;
7015   switch (Model) {
7016   case TLSModel::LocalExec:
7017     Addr = getStaticTLSAddr(N, DAG, /*UseGOT=*/false);
7018     break;
7019   case TLSModel::InitialExec:
7020     Addr = getStaticTLSAddr(N, DAG, /*UseGOT=*/true);
7021     break;
7022   case TLSModel::LocalDynamic:
7023   case TLSModel::GeneralDynamic:
7024     Addr = DAG.getTarget().useTLSDESC() ? getTLSDescAddr(N, DAG)
7025                                         : getDynamicTLSAddr(N, DAG);
7026     break;
7027   }
7028 
7029   return Addr;
7030 }
7031 
7032 // Return true if Val is equal to (setcc LHS, RHS, CC).
7033 // Return false if Val is the inverse of (setcc LHS, RHS, CC).
7034 // Otherwise, return std::nullopt.
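// For example, with Val = (setcc a, b, setlt): matchSetCC(a, b, setlt, Val)
// returns true, matchSetCC(a, b, setge, Val) returns false, and
// matchSetCC(b, a, setgt, Val) returns true.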
7035 static std::optional<bool> matchSetCC(SDValue LHS, SDValue RHS,
7036                                       ISD::CondCode CC, SDValue Val) {
7037   assert(Val->getOpcode() == ISD::SETCC);
7038   SDValue LHS2 = Val.getOperand(0);
7039   SDValue RHS2 = Val.getOperand(1);
7040   ISD::CondCode CC2 = cast<CondCodeSDNode>(Val.getOperand(2))->get();
7041 
7042   if (LHS == LHS2 && RHS == RHS2) {
7043     if (CC == CC2)
7044       return true;
7045     if (CC == ISD::getSetCCInverse(CC2, LHS2.getValueType()))
7046       return false;
7047   } else if (LHS == RHS2 && RHS == LHS2) {
7048     CC2 = ISD::getSetCCSwappedOperands(CC2);
7049     if (CC == CC2)
7050       return true;
7051     if (CC == ISD::getSetCCInverse(CC2, LHS2.getValueType()))
7052       return false;
7053   }
7054 
7055   return std::nullopt;
7056 }
7057 
7058 static SDValue combineSelectToBinOp(SDNode *N, SelectionDAG &DAG,
7059                                     const RISCVSubtarget &Subtarget) {
7060   SDValue CondV = N->getOperand(0);
7061   SDValue TrueV = N->getOperand(1);
7062   SDValue FalseV = N->getOperand(2);
7063   MVT VT = N->getSimpleValueType(0);
7064   SDLoc DL(N);
7065 
7066   if (!Subtarget.hasConditionalMoveFusion()) {
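    // These folds rely on CondV being 0 or 1 (RISC-V uses
    // ZeroOrOneBooleanContent), so both -c and c-1 are either 0 or an
    // all-ones mask.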
7067     // (select c, -1, y) -> -c | y
7068     if (isAllOnesConstant(TrueV)) {
7069       SDValue Neg = DAG.getNegative(CondV, DL, VT);
7070       return DAG.getNode(ISD::OR, DL, VT, Neg, FalseV);
7071     }
7072     // (select c, y, -1) -> (c-1) | y
7073     if (isAllOnesConstant(FalseV)) {
7074       SDValue Neg = DAG.getNode(ISD::ADD, DL, VT, CondV,
7075                                 DAG.getAllOnesConstant(DL, VT));
7076       return DAG.getNode(ISD::OR, DL, VT, Neg, TrueV);
7077     }
7078 
7079     // (select c, 0, y) -> (c-1) & y
7080     if (isNullConstant(TrueV)) {
7081       SDValue Neg = DAG.getNode(ISD::ADD, DL, VT, CondV,
7082                                 DAG.getAllOnesConstant(DL, VT));
7083       return DAG.getNode(ISD::AND, DL, VT, Neg, FalseV);
7084     }
7085     // (select c, y, 0) -> -c & y
7086     if (isNullConstant(FalseV)) {
7087       SDValue Neg = DAG.getNegative(CondV, DL, VT);
7088       return DAG.getNode(ISD::AND, DL, VT, Neg, TrueV);
7089     }
7090   }
7091 
7092   // Try to fold (select (setcc lhs, rhs, cc), truev, falsev) into bitwise ops
7093   // when both truev and falsev are also setcc.
7094   if (CondV.getOpcode() == ISD::SETCC && TrueV.getOpcode() == ISD::SETCC &&
7095       FalseV.getOpcode() == ISD::SETCC) {
7096     SDValue LHS = CondV.getOperand(0);
7097     SDValue RHS = CondV.getOperand(1);
7098     ISD::CondCode CC = cast<CondCodeSDNode>(CondV.getOperand(2))->get();
7099 
7100     // (select x, x, y) -> x | y
7101     // (select !x, x, y) -> x & y
7102     if (std::optional<bool> MatchResult = matchSetCC(LHS, RHS, CC, TrueV)) {
7103       return DAG.getNode(*MatchResult ? ISD::OR : ISD::AND, DL, VT, TrueV,
7104                          FalseV);
7105     }
7106     // (select x, y, x) -> x & y
7107     // (select !x, y, x) -> x | y
7108     if (std::optional<bool> MatchResult = matchSetCC(LHS, RHS, CC, FalseV)) {
7109       return DAG.getNode(*MatchResult ? ISD::AND : ISD::OR, DL, VT, TrueV,
7110                          FalseV);
7111     }
7112   }
7113 
7114   return SDValue();
7115 }
7116 
7117 // Transform `binOp (select cond, x, c0), c1` where `c0` and `c1` are constants
7118 // into `select cond, binOp(x, c1), binOp(c0, c1)` if profitable.
7119 // For now we only consider the transformation profitable if `binOp(c0, c1)`
7120 // ends up being `0` or `-1`. In such cases we can replace `select` with `and`.
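// For example, (add (select cond, x, 1), -1) becomes
// (select cond, (add x, -1), 0), and a select whose false arm is zero can
// later be folded into an AND mask (see combineSelectToBinOp).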
7121 // TODO: Should we also do this if `binOp(c0, c1)` is cheaper to materialize
7122 // than `c0`?
7123 static SDValue
7124 foldBinOpIntoSelectIfProfitable(SDNode *BO, SelectionDAG &DAG,
7125                                 const RISCVSubtarget &Subtarget) {
7126   if (Subtarget.hasShortForwardBranchOpt())
7127     return SDValue();
7128 
7129   unsigned SelOpNo = 0;
7130   SDValue Sel = BO->getOperand(0);
7131   if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse()) {
7132     SelOpNo = 1;
7133     Sel = BO->getOperand(1);
7134   }
7135 
7136   if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse())
7137     return SDValue();
7138 
7139   unsigned ConstSelOpNo = 1;
7140   unsigned OtherSelOpNo = 2;
7141   if (!isa<ConstantSDNode>(Sel->getOperand(ConstSelOpNo))) {
7142     ConstSelOpNo = 2;
7143     OtherSelOpNo = 1;
7144   }
7145   SDValue ConstSelOp = Sel->getOperand(ConstSelOpNo);
7146   ConstantSDNode *ConstSelOpNode = dyn_cast<ConstantSDNode>(ConstSelOp);
7147   if (!ConstSelOpNode || ConstSelOpNode->isOpaque())
7148     return SDValue();
7149 
7150   SDValue ConstBinOp = BO->getOperand(SelOpNo ^ 1);
7151   ConstantSDNode *ConstBinOpNode = dyn_cast<ConstantSDNode>(ConstBinOp);
7152   if (!ConstBinOpNode || ConstBinOpNode->isOpaque())
7153     return SDValue();
7154 
7155   SDLoc DL(Sel);
7156   EVT VT = BO->getValueType(0);
7157 
7158   SDValue NewConstOps[2] = {ConstSelOp, ConstBinOp};
7159   if (SelOpNo == 1)
7160     std::swap(NewConstOps[0], NewConstOps[1]);
7161 
7162   SDValue NewConstOp =
7163       DAG.FoldConstantArithmetic(BO->getOpcode(), DL, VT, NewConstOps);
7164   if (!NewConstOp)
7165     return SDValue();
7166 
7167   const APInt &NewConstAPInt = NewConstOp->getAsAPIntVal();
7168   if (!NewConstAPInt.isZero() && !NewConstAPInt.isAllOnes())
7169     return SDValue();
7170 
7171   SDValue OtherSelOp = Sel->getOperand(OtherSelOpNo);
7172   SDValue NewNonConstOps[2] = {OtherSelOp, ConstBinOp};
7173   if (SelOpNo == 1)
7174     std::swap(NewNonConstOps[0], NewNonConstOps[1]);
7175   SDValue NewNonConstOp = DAG.getNode(BO->getOpcode(), DL, VT, NewNonConstOps);
7176 
7177   SDValue NewT = (ConstSelOpNo == 1) ? NewConstOp : NewNonConstOp;
7178   SDValue NewF = (ConstSelOpNo == 1) ? NewNonConstOp : NewConstOp;
7179   return DAG.getSelect(DL, VT, Sel.getOperand(0), NewT, NewF);
7180 }
7181 
7182 SDValue RISCVTargetLowering::lowerSELECT(SDValue Op, SelectionDAG &DAG) const {
7183   SDValue CondV = Op.getOperand(0);
7184   SDValue TrueV = Op.getOperand(1);
7185   SDValue FalseV = Op.getOperand(2);
7186   SDLoc DL(Op);
7187   MVT VT = Op.getSimpleValueType();
7188   MVT XLenVT = Subtarget.getXLenVT();
7189 
7190   // Lower vector SELECTs to VSELECTs by splatting the condition.
7191   if (VT.isVector()) {
7192     MVT SplatCondVT = VT.changeVectorElementType(MVT::i1);
7193     SDValue CondSplat = DAG.getSplat(SplatCondVT, DL, CondV);
7194     return DAG.getNode(ISD::VSELECT, DL, VT, CondSplat, TrueV, FalseV);
7195   }
7196 
7197   // When Zicond or XVentanaCondOps is present, emit CZERO_EQZ and CZERO_NEZ
7198   // nodes to implement the SELECT. Performing the lowering here allows for
7199   // greater control over when CZERO_{EQZ/NEZ} are used vs another branchless
7200   // sequence or RISCVISD::SELECT_CC node (branch-based select).
7201   if ((Subtarget.hasStdExtZicond() || Subtarget.hasVendorXVentanaCondOps()) &&
7202       VT.isScalarInteger()) {
7203     // (select c, t, 0) -> (czero_eqz t, c)
7204     if (isNullConstant(FalseV))
7205       return DAG.getNode(RISCVISD::CZERO_EQZ, DL, VT, TrueV, CondV);
7206     // (select c, 0, f) -> (czero_nez f, c)
7207     if (isNullConstant(TrueV))
7208       return DAG.getNode(RISCVISD::CZERO_NEZ, DL, VT, FalseV, CondV);
7209 
7210     // (select c, (and f, x), f) -> (or (and f, x), (czero_nez f, c))
7211     if (TrueV.getOpcode() == ISD::AND &&
7212         (TrueV.getOperand(0) == FalseV || TrueV.getOperand(1) == FalseV))
7213       return DAG.getNode(
7214           ISD::OR, DL, VT, TrueV,
7215           DAG.getNode(RISCVISD::CZERO_NEZ, DL, VT, FalseV, CondV));
7216     // (select c, t, (and t, x)) -> (or (czero_eqz t, c), (and t, x))
7217     if (FalseV.getOpcode() == ISD::AND &&
7218         (FalseV.getOperand(0) == TrueV || FalseV.getOperand(1) == TrueV))
7219       return DAG.getNode(
7220           ISD::OR, DL, VT, FalseV,
7221           DAG.getNode(RISCVISD::CZERO_EQZ, DL, VT, TrueV, CondV));
7222 
7223     // Try some other optimizations before falling back to generic lowering.
7224     if (SDValue V = combineSelectToBinOp(Op.getNode(), DAG, Subtarget))
7225       return V;
7226 
7227     // (select c, t, f) -> (or (czero_eqz t, c), (czero_nez f, c))
7228     // Unless we have the short forward branch optimization.
7229     if (!Subtarget.hasConditionalMoveFusion())
7230       return DAG.getNode(
7231           ISD::OR, DL, VT,
7232           DAG.getNode(RISCVISD::CZERO_EQZ, DL, VT, TrueV, CondV),
7233           DAG.getNode(RISCVISD::CZERO_NEZ, DL, VT, FalseV, CondV));
7234   }
7235 
7236   if (SDValue V = combineSelectToBinOp(Op.getNode(), DAG, Subtarget))
7237     return V;
7238 
7239   if (Op.hasOneUse()) {
7240     unsigned UseOpc = Op->use_begin()->getOpcode();
7241     if (isBinOp(UseOpc) && DAG.isSafeToSpeculativelyExecute(UseOpc)) {
7242       SDNode *BinOp = *Op->use_begin();
7243       if (SDValue NewSel = foldBinOpIntoSelectIfProfitable(*Op->use_begin(),
7244                                                            DAG, Subtarget)) {
7245         DAG.ReplaceAllUsesWith(BinOp, &NewSel);
7246         return lowerSELECT(NewSel, DAG);
7247       }
7248     }
7249   }
7250 
7251   // (select cc, 1.0, 0.0) -> (sint_to_fp (zext cc))
7252   // (select cc, 0.0, 1.0) -> (sint_to_fp (zext (xor cc, 1)))
7253   const ConstantFPSDNode *FPTV = dyn_cast<ConstantFPSDNode>(TrueV);
7254   const ConstantFPSDNode *FPFV = dyn_cast<ConstantFPSDNode>(FalseV);
7255   if (FPTV && FPFV) {
7256     if (FPTV->isExactlyValue(1.0) && FPFV->isExactlyValue(0.0))
7257       return DAG.getNode(ISD::SINT_TO_FP, DL, VT, CondV);
7258     if (FPTV->isExactlyValue(0.0) && FPFV->isExactlyValue(1.0)) {
7259       SDValue XOR = DAG.getNode(ISD::XOR, DL, XLenVT, CondV,
7260                                 DAG.getConstant(1, DL, XLenVT));
7261       return DAG.getNode(ISD::SINT_TO_FP, DL, VT, XOR);
7262     }
7263   }
7264 
7265   // If the condition is not an integer SETCC which operates on XLenVT, we need
7266   // to emit a RISCVISD::SELECT_CC comparing the condition to zero. i.e.:
7267   // (select condv, truev, falsev)
7268   // -> (riscvisd::select_cc condv, zero, setne, truev, falsev)
7269   if (CondV.getOpcode() != ISD::SETCC ||
7270       CondV.getOperand(0).getSimpleValueType() != XLenVT) {
7271     SDValue Zero = DAG.getConstant(0, DL, XLenVT);
7272     SDValue SetNE = DAG.getCondCode(ISD::SETNE);
7273 
7274     SDValue Ops[] = {CondV, Zero, SetNE, TrueV, FalseV};
7275 
7276     return DAG.getNode(RISCVISD::SELECT_CC, DL, VT, Ops);
7277   }
7278 
7279   // If the CondV is the output of a SETCC node which operates on XLenVT inputs,
7280   // then merge the SETCC node into the lowered RISCVISD::SELECT_CC to take
7281   // advantage of the integer compare+branch instructions. i.e.:
7282   // (select (setcc lhs, rhs, cc), truev, falsev)
7283   // -> (riscvisd::select_cc lhs, rhs, cc, truev, falsev)
7284   SDValue LHS = CondV.getOperand(0);
7285   SDValue RHS = CondV.getOperand(1);
7286   ISD::CondCode CCVal = cast<CondCodeSDNode>(CondV.getOperand(2))->get();
7287 
7288   // Special case for a select of 2 constants that have a difference of 1.
7289   // Normally this is done by DAGCombine, but if the select is introduced by
7290   // type legalization or op legalization, we miss it. Restricting to SETLT
7291   // case for now because that is what signed saturating add/sub need.
7292   // FIXME: We don't need the condition to be SETLT or even a SETCC,
7293   // but we would probably want to swap the true/false values if the condition
7294   // is SETGE/SETLE to avoid an XORI.
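  // For example, (select (setlt a, b), 5, 4) becomes (add (setlt a, b), 4).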
7295   if (isa<ConstantSDNode>(TrueV) && isa<ConstantSDNode>(FalseV) &&
7296       CCVal == ISD::SETLT) {
7297     const APInt &TrueVal = TrueV->getAsAPIntVal();
7298     const APInt &FalseVal = FalseV->getAsAPIntVal();
7299     if (TrueVal - 1 == FalseVal)
7300       return DAG.getNode(ISD::ADD, DL, VT, CondV, FalseV);
7301     if (TrueVal + 1 == FalseVal)
7302       return DAG.getNode(ISD::SUB, DL, VT, FalseV, CondV);
7303   }
7304 
7305   translateSetCCForBranch(DL, LHS, RHS, CCVal, DAG);
7306   // 1 < x ? x : 1 -> 0 < x ? x : 1
7307   if (isOneConstant(LHS) && (CCVal == ISD::SETLT || CCVal == ISD::SETULT) &&
7308       RHS == TrueV && LHS == FalseV) {
7309     LHS = DAG.getConstant(0, DL, VT);
7310     // 0 <u x is the same as x != 0.
7311     if (CCVal == ISD::SETULT) {
7312       std::swap(LHS, RHS);
7313       CCVal = ISD::SETNE;
7314     }
7315   }
7316 
7317   // x <s -1 ? x : -1 -> x <s 0 ? x : -1
7318   if (isAllOnesConstant(RHS) && CCVal == ISD::SETLT && LHS == TrueV &&
7319       RHS == FalseV) {
7320     RHS = DAG.getConstant(0, DL, VT);
7321   }
7322 
7323   SDValue TargetCC = DAG.getCondCode(CCVal);
7324 
7325   if (isa<ConstantSDNode>(TrueV) && !isa<ConstantSDNode>(FalseV)) {
7326     // (select (setcc lhs, rhs, CC), constant, falsev)
7327     // -> (select (setcc lhs, rhs, InverseCC), falsev, constant)
7328     std::swap(TrueV, FalseV);
7329     TargetCC = DAG.getCondCode(ISD::getSetCCInverse(CCVal, LHS.getValueType()));
7330   }
7331 
7332   SDValue Ops[] = {LHS, RHS, TargetCC, TrueV, FalseV};
7333   return DAG.getNode(RISCVISD::SELECT_CC, DL, VT, Ops);
7334 }
7335 
7336 SDValue RISCVTargetLowering::lowerBRCOND(SDValue Op, SelectionDAG &DAG) const {
7337   SDValue CondV = Op.getOperand(1);
7338   SDLoc DL(Op);
7339   MVT XLenVT = Subtarget.getXLenVT();
7340 
7341   if (CondV.getOpcode() == ISD::SETCC &&
7342       CondV.getOperand(0).getValueType() == XLenVT) {
7343     SDValue LHS = CondV.getOperand(0);
7344     SDValue RHS = CondV.getOperand(1);
7345     ISD::CondCode CCVal = cast<CondCodeSDNode>(CondV.getOperand(2))->get();
7346 
7347     translateSetCCForBranch(DL, LHS, RHS, CCVal, DAG);
7348 
7349     SDValue TargetCC = DAG.getCondCode(CCVal);
7350     return DAG.getNode(RISCVISD::BR_CC, DL, Op.getValueType(), Op.getOperand(0),
7351                        LHS, RHS, TargetCC, Op.getOperand(2));
7352   }
7353 
7354   return DAG.getNode(RISCVISD::BR_CC, DL, Op.getValueType(), Op.getOperand(0),
7355                      CondV, DAG.getConstant(0, DL, XLenVT),
7356                      DAG.getCondCode(ISD::SETNE), Op.getOperand(2));
7357 }
7358 
7359 SDValue RISCVTargetLowering::lowerVASTART(SDValue Op, SelectionDAG &DAG) const {
7360   MachineFunction &MF = DAG.getMachineFunction();
7361   RISCVMachineFunctionInfo *FuncInfo = MF.getInfo<RISCVMachineFunctionInfo>();
7362 
7363   SDLoc DL(Op);
7364   SDValue FI = DAG.getFrameIndex(FuncInfo->getVarArgsFrameIndex(),
7365                                  getPointerTy(MF.getDataLayout()));
7366 
7367   // vastart just stores the address of the VarArgsFrameIndex slot into the
7368   // memory location argument.
7369   const Value *SV = cast<SrcValueSDNode>(Op.getOperand(2))->getValue();
7370   return DAG.getStore(Op.getOperand(0), DL, FI, Op.getOperand(1),
7371                       MachinePointerInfo(SV));
7372 }
7373 
7374 SDValue RISCVTargetLowering::lowerFRAMEADDR(SDValue Op,
7375                                             SelectionDAG &DAG) const {
7376   const RISCVRegisterInfo &RI = *Subtarget.getRegisterInfo();
7377   MachineFunction &MF = DAG.getMachineFunction();
7378   MachineFrameInfo &MFI = MF.getFrameInfo();
7379   MFI.setFrameAddressIsTaken(true);
7380   Register FrameReg = RI.getFrameRegister(MF);
7381   int XLenInBytes = Subtarget.getXLen() / 8;
7382 
7383   EVT VT = Op.getValueType();
7384   SDLoc DL(Op);
7385   SDValue FrameAddr = DAG.getCopyFromReg(DAG.getEntryNode(), DL, FrameReg, VT);
7386   unsigned Depth = Op.getConstantOperandVal(0);
7387   while (Depth--) {
7388     int Offset = -(XLenInBytes * 2);
7389     SDValue Ptr = DAG.getNode(ISD::ADD, DL, VT, FrameAddr,
7390                               DAG.getIntPtrConstant(Offset, DL));
7391     FrameAddr =
7392         DAG.getLoad(VT, DL, DAG.getEntryNode(), Ptr, MachinePointerInfo());
7393   }
7394   return FrameAddr;
7395 }
7396 
7397 SDValue RISCVTargetLowering::lowerRETURNADDR(SDValue Op,
7398                                              SelectionDAG &DAG) const {
7399   const RISCVRegisterInfo &RI = *Subtarget.getRegisterInfo();
7400   MachineFunction &MF = DAG.getMachineFunction();
7401   MachineFrameInfo &MFI = MF.getFrameInfo();
7402   MFI.setReturnAddressIsTaken(true);
7403   MVT XLenVT = Subtarget.getXLenVT();
7404   int XLenInBytes = Subtarget.getXLen() / 8;
7405 
7406   if (verifyReturnAddressArgumentIsConstant(Op, DAG))
7407     return SDValue();
7408 
7409   EVT VT = Op.getValueType();
7410   SDLoc DL(Op);
7411   unsigned Depth = Op.getConstantOperandVal(0);
7412   if (Depth) {
7413     int Off = -XLenInBytes;
7414     SDValue FrameAddr = lowerFRAMEADDR(Op, DAG);
7415     SDValue Offset = DAG.getConstant(Off, DL, VT);
7416     return DAG.getLoad(VT, DL, DAG.getEntryNode(),
7417                        DAG.getNode(ISD::ADD, DL, VT, FrameAddr, Offset),
7418                        MachinePointerInfo());
7419   }
7420 
7421   // Return the value of the return address register, marking it an implicit
7422   // live-in.
7423   Register Reg = MF.addLiveIn(RI.getRARegister(), getRegClassFor(XLenVT));
7424   return DAG.getCopyFromReg(DAG.getEntryNode(), DL, Reg, XLenVT);
7425 }
7426 
7427 SDValue RISCVTargetLowering::lowerShiftLeftParts(SDValue Op,
7428                                                  SelectionDAG &DAG) const {
7429   SDLoc DL(Op);
7430   SDValue Lo = Op.getOperand(0);
7431   SDValue Hi = Op.getOperand(1);
7432   SDValue Shamt = Op.getOperand(2);
7433   EVT VT = Lo.getValueType();
7434 
7435   // if Shamt-XLEN < 0: // Shamt < XLEN
7436   //   Lo = Lo << Shamt
7437   //   Hi = (Hi << Shamt) | ((Lo >>u 1) >>u (XLEN-1 - Shamt))
7438   // else:
7439   //   Lo = 0
7440   //   Hi = Lo << (Shamt-XLEN)
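  //
  // For example, on RV32 (XLEN=32) with Shamt=40: Shamt-XLEN=8 >= 0, so Lo
  // becomes 0 and Hi becomes the original Lo shifted left by 8.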
7441 
7442   SDValue Zero = DAG.getConstant(0, DL, VT);
7443   SDValue One = DAG.getConstant(1, DL, VT);
7444   SDValue MinusXLen = DAG.getConstant(-(int)Subtarget.getXLen(), DL, VT);
7445   SDValue XLenMinus1 = DAG.getConstant(Subtarget.getXLen() - 1, DL, VT);
7446   SDValue ShamtMinusXLen = DAG.getNode(ISD::ADD, DL, VT, Shamt, MinusXLen);
7447   SDValue XLenMinus1Shamt = DAG.getNode(ISD::SUB, DL, VT, XLenMinus1, Shamt);
7448 
7449   SDValue LoTrue = DAG.getNode(ISD::SHL, DL, VT, Lo, Shamt);
7450   SDValue ShiftRight1Lo = DAG.getNode(ISD::SRL, DL, VT, Lo, One);
7451   SDValue ShiftRightLo =
7452       DAG.getNode(ISD::SRL, DL, VT, ShiftRight1Lo, XLenMinus1Shamt);
7453   SDValue ShiftLeftHi = DAG.getNode(ISD::SHL, DL, VT, Hi, Shamt);
7454   SDValue HiTrue = DAG.getNode(ISD::OR, DL, VT, ShiftLeftHi, ShiftRightLo);
7455   SDValue HiFalse = DAG.getNode(ISD::SHL, DL, VT, Lo, ShamtMinusXLen);
7456 
7457   SDValue CC = DAG.getSetCC(DL, VT, ShamtMinusXLen, Zero, ISD::SETLT);
7458 
7459   Lo = DAG.getNode(ISD::SELECT, DL, VT, CC, LoTrue, Zero);
7460   Hi = DAG.getNode(ISD::SELECT, DL, VT, CC, HiTrue, HiFalse);
7461 
7462   SDValue Parts[2] = {Lo, Hi};
7463   return DAG.getMergeValues(Parts, DL);
7464 }
7465 
7466 SDValue RISCVTargetLowering::lowerShiftRightParts(SDValue Op, SelectionDAG &DAG,
7467                                                   bool IsSRA) const {
7468   SDLoc DL(Op);
7469   SDValue Lo = Op.getOperand(0);
7470   SDValue Hi = Op.getOperand(1);
7471   SDValue Shamt = Op.getOperand(2);
7472   EVT VT = Lo.getValueType();
7473 
7474   // SRA expansion:
7475   //   if Shamt-XLEN < 0: // Shamt < XLEN
7476   //     Lo = (Lo >>u Shamt) | ((Hi << 1) << (XLEN-1 - Shamt))
7477   //     Hi = Hi >>s Shamt
7478   //   else:
7479   //     Lo = Hi >>s (Shamt-XLEN);
7480   //     Hi = Hi >>s (XLEN-1)
7481   //
7482   // SRL expansion:
7483   //   if Shamt-XLEN < 0: // Shamt < XLEN
7484   //     Lo = (Lo >>u Shamt) | ((Hi << 1) << (XLEN-1 - Shamt))
7485   //     Hi = Hi >>u Shamt
7486   //   else:
7487   //     Lo = Hi >>u (Shamt-XLEN);
7488   //     Hi = 0;
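  //
  // For example, on RV32 (XLEN=32) with Shamt=4: Lo becomes
  // (Lo >>u 4) | ((Hi << 1) << 27), i.e. the low 4 bits of Hi fill the top of
  // Lo, and Hi becomes Hi >> 4 (arithmetic for SRA, logical for SRL).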
7489 
7490   unsigned ShiftRightOp = IsSRA ? ISD::SRA : ISD::SRL;
7491 
7492   SDValue Zero = DAG.getConstant(0, DL, VT);
7493   SDValue One = DAG.getConstant(1, DL, VT);
7494   SDValue MinusXLen = DAG.getConstant(-(int)Subtarget.getXLen(), DL, VT);
7495   SDValue XLenMinus1 = DAG.getConstant(Subtarget.getXLen() - 1, DL, VT);
7496   SDValue ShamtMinusXLen = DAG.getNode(ISD::ADD, DL, VT, Shamt, MinusXLen);
7497   SDValue XLenMinus1Shamt = DAG.getNode(ISD::SUB, DL, VT, XLenMinus1, Shamt);
7498 
7499   SDValue ShiftRightLo = DAG.getNode(ISD::SRL, DL, VT, Lo, Shamt);
7500   SDValue ShiftLeftHi1 = DAG.getNode(ISD::SHL, DL, VT, Hi, One);
7501   SDValue ShiftLeftHi =
7502       DAG.getNode(ISD::SHL, DL, VT, ShiftLeftHi1, XLenMinus1Shamt);
7503   SDValue LoTrue = DAG.getNode(ISD::OR, DL, VT, ShiftRightLo, ShiftLeftHi);
7504   SDValue HiTrue = DAG.getNode(ShiftRightOp, DL, VT, Hi, Shamt);
7505   SDValue LoFalse = DAG.getNode(ShiftRightOp, DL, VT, Hi, ShamtMinusXLen);
7506   SDValue HiFalse =
7507       IsSRA ? DAG.getNode(ISD::SRA, DL, VT, Hi, XLenMinus1) : Zero;
7508 
7509   SDValue CC = DAG.getSetCC(DL, VT, ShamtMinusXLen, Zero, ISD::SETLT);
7510 
7511   Lo = DAG.getNode(ISD::SELECT, DL, VT, CC, LoTrue, LoFalse);
7512   Hi = DAG.getNode(ISD::SELECT, DL, VT, CC, HiTrue, HiFalse);
7513 
7514   SDValue Parts[2] = {Lo, Hi};
7515   return DAG.getMergeValues(Parts, DL);
7516 }
7517 
7518 // Lower splats of i1 types to SETCC. For each mask vector type, we have a
7519 // legal equivalently-sized i8 type, so we can use that as a go-between.
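// For example, splatting a non-constant i1 value %b into v8i1 becomes
//   (setcc (v8i8 splat (and %b, 1)), (v8i8 splat 0), ne).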
7520 SDValue RISCVTargetLowering::lowerVectorMaskSplat(SDValue Op,
7521                                                   SelectionDAG &DAG) const {
7522   SDLoc DL(Op);
7523   MVT VT = Op.getSimpleValueType();
7524   SDValue SplatVal = Op.getOperand(0);
7525   // All-zeros or all-ones splats are handled specially.
7526   if (ISD::isConstantSplatVectorAllOnes(Op.getNode())) {
7527     SDValue VL = getDefaultScalableVLOps(VT, DL, DAG, Subtarget).second;
7528     return DAG.getNode(RISCVISD::VMSET_VL, DL, VT, VL);
7529   }
7530   if (ISD::isConstantSplatVectorAllZeros(Op.getNode())) {
7531     SDValue VL = getDefaultScalableVLOps(VT, DL, DAG, Subtarget).second;
7532     return DAG.getNode(RISCVISD::VMCLR_VL, DL, VT, VL);
7533   }
7534   MVT InterVT = VT.changeVectorElementType(MVT::i8);
7535   SplatVal = DAG.getNode(ISD::AND, DL, SplatVal.getValueType(), SplatVal,
7536                          DAG.getConstant(1, DL, SplatVal.getValueType()));
7537   SDValue LHS = DAG.getSplatVector(InterVT, DL, SplatVal);
7538   SDValue Zero = DAG.getConstant(0, DL, InterVT);
7539   return DAG.getSetCC(DL, VT, LHS, Zero, ISD::SETNE);
7540 }
7541 
7542 // Custom-lower a SPLAT_VECTOR_PARTS where XLEN<SEW, as the SEW element type is
7543 // illegal (currently only vXi64 RV32).
7544 // FIXME: We could also catch non-constant sign-extended i32 values and lower
7545 // them to VMV_V_X_VL.
7546 SDValue RISCVTargetLowering::lowerSPLAT_VECTOR_PARTS(SDValue Op,
7547                                                      SelectionDAG &DAG) const {
7548   SDLoc DL(Op);
7549   MVT VecVT = Op.getSimpleValueType();
7550   assert(!Subtarget.is64Bit() && VecVT.getVectorElementType() == MVT::i64 &&
7551          "Unexpected SPLAT_VECTOR_PARTS lowering");
7552 
7553   assert(Op.getNumOperands() == 2 && "Unexpected number of operands!");
7554   SDValue Lo = Op.getOperand(0);
7555   SDValue Hi = Op.getOperand(1);
7556 
7557   MVT ContainerVT = VecVT;
7558   if (VecVT.isFixedLengthVector())
7559     ContainerVT = getContainerForFixedLengthVector(VecVT);
7560 
7561   auto VL = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget).second;
7562 
7563   SDValue Res =
7564       splatPartsI64WithVL(DL, ContainerVT, SDValue(), Lo, Hi, VL, DAG);
7565 
7566   if (VecVT.isFixedLengthVector())
7567     Res = convertFromScalableVector(VecVT, Res, DAG, Subtarget);
7568 
7569   return Res;
7570 }
7571 
7572 // Custom-lower extensions from mask vectors by using a vselect either with 1
7573 // for zero/any-extension or -1 for sign-extension:
7574 //   (vXiN = (s|z)ext vXi1:vmask) -> (vXiN = vselect vmask, (-1 or 1), 0)
7575 // Note that any-extension is lowered identically to zero-extension.
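// For example, sign-extending a true lane of a mask to i32 produces -1
// (all ones), while zero/any-extending it produces 1.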
7576 SDValue RISCVTargetLowering::lowerVectorMaskExt(SDValue Op, SelectionDAG &DAG,
7577                                                 int64_t ExtTrueVal) const {
7578   SDLoc DL(Op);
7579   MVT VecVT = Op.getSimpleValueType();
7580   SDValue Src = Op.getOperand(0);
7581   // Only custom-lower extensions from mask types
7582   assert(Src.getValueType().isVector() &&
7583          Src.getValueType().getVectorElementType() == MVT::i1);
7584 
7585   if (VecVT.isScalableVector()) {
7586     SDValue SplatZero = DAG.getConstant(0, DL, VecVT);
7587     SDValue SplatTrueVal = DAG.getConstant(ExtTrueVal, DL, VecVT);
7588     return DAG.getNode(ISD::VSELECT, DL, VecVT, Src, SplatTrueVal, SplatZero);
7589   }
7590 
7591   MVT ContainerVT = getContainerForFixedLengthVector(VecVT);
7592   MVT I1ContainerVT =
7593       MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
7594 
7595   SDValue CC = convertToScalableVector(I1ContainerVT, Src, DAG, Subtarget);
7596 
7597   SDValue VL = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget).second;
7598 
7599   MVT XLenVT = Subtarget.getXLenVT();
7600   SDValue SplatZero = DAG.getConstant(0, DL, XLenVT);
7601   SDValue SplatTrueVal = DAG.getConstant(ExtTrueVal, DL, XLenVT);
7602 
7603   SplatZero = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT,
7604                           DAG.getUNDEF(ContainerVT), SplatZero, VL);
7605   SplatTrueVal = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT,
7606                              DAG.getUNDEF(ContainerVT), SplatTrueVal, VL);
7607   SDValue Select =
7608       DAG.getNode(RISCVISD::VMERGE_VL, DL, ContainerVT, CC, SplatTrueVal,
7609                   SplatZero, DAG.getUNDEF(ContainerVT), VL);
7610 
7611   return convertFromScalableVector(VecVT, Select, DAG, Subtarget);
7612 }
7613 
7614 SDValue RISCVTargetLowering::lowerFixedLengthVectorExtendToRVV(
7615     SDValue Op, SelectionDAG &DAG, unsigned ExtendOpc) const {
7616   MVT ExtVT = Op.getSimpleValueType();
7617   // Only custom-lower extensions from fixed-length vector types.
7618   if (!ExtVT.isFixedLengthVector())
7619     return Op;
7620   MVT VT = Op.getOperand(0).getSimpleValueType();
7621   // Grab the canonical container type for the extended type. Infer the smaller
7622   // type from that to ensure the same number of vector elements, as we know
7623   // the LMUL will be sufficient to hold the smaller type.
7624   MVT ContainerExtVT = getContainerForFixedLengthVector(ExtVT);
7625   // Get the extended container type manually to ensure the same number of
7626   // vector elements between source and dest.
7627   MVT ContainerVT = MVT::getVectorVT(VT.getVectorElementType(),
7628                                      ContainerExtVT.getVectorElementCount());
7629 
7630   SDValue Op1 =
7631       convertToScalableVector(ContainerVT, Op.getOperand(0), DAG, Subtarget);
7632 
7633   SDLoc DL(Op);
7634   auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
7635 
7636   SDValue Ext = DAG.getNode(ExtendOpc, DL, ContainerExtVT, Op1, Mask, VL);
7637 
7638   return convertFromScalableVector(ExtVT, Ext, DAG, Subtarget);
7639 }
7640 
7641 // Custom-lower truncations from vectors to mask vectors by using a mask and a
7642 // setcc operation:
7643 //   (vXi1 = trunc vXiN vec) -> (vXi1 = setcc (and vec, 1), 0, ne)
7644 SDValue RISCVTargetLowering::lowerVectorMaskTruncLike(SDValue Op,
7645                                                       SelectionDAG &DAG) const {
7646   bool IsVPTrunc = Op.getOpcode() == ISD::VP_TRUNCATE;
7647   SDLoc DL(Op);
7648   EVT MaskVT = Op.getValueType();
7649   // Only expect to custom-lower truncations to mask types
7650   assert(MaskVT.isVector() && MaskVT.getVectorElementType() == MVT::i1 &&
7651          "Unexpected type for vector mask lowering");
7652   SDValue Src = Op.getOperand(0);
7653   MVT VecVT = Src.getSimpleValueType();
7654   SDValue Mask, VL;
7655   if (IsVPTrunc) {
7656     Mask = Op.getOperand(1);
7657     VL = Op.getOperand(2);
7658   }
7659   // If this is a fixed vector, we need to convert it to a scalable vector.
7660   MVT ContainerVT = VecVT;
7661 
7662   if (VecVT.isFixedLengthVector()) {
7663     ContainerVT = getContainerForFixedLengthVector(VecVT);
7664     Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
7665     if (IsVPTrunc) {
7666       MVT MaskContainerVT =
7667           getContainerForFixedLengthVector(Mask.getSimpleValueType());
7668       Mask = convertToScalableVector(MaskContainerVT, Mask, DAG, Subtarget);
7669     }
7670   }
7671 
7672   if (!IsVPTrunc) {
7673     std::tie(Mask, VL) =
7674         getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget);
7675   }
7676 
7677   SDValue SplatOne = DAG.getConstant(1, DL, Subtarget.getXLenVT());
7678   SDValue SplatZero = DAG.getConstant(0, DL, Subtarget.getXLenVT());
7679 
7680   SplatOne = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT,
7681                          DAG.getUNDEF(ContainerVT), SplatOne, VL);
7682   SplatZero = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT,
7683                           DAG.getUNDEF(ContainerVT), SplatZero, VL);
7684 
7685   MVT MaskContainerVT = ContainerVT.changeVectorElementType(MVT::i1);
7686   SDValue Trunc = DAG.getNode(RISCVISD::AND_VL, DL, ContainerVT, Src, SplatOne,
7687                               DAG.getUNDEF(ContainerVT), Mask, VL);
7688   Trunc = DAG.getNode(RISCVISD::SETCC_VL, DL, MaskContainerVT,
7689                       {Trunc, SplatZero, DAG.getCondCode(ISD::SETNE),
7690                        DAG.getUNDEF(MaskContainerVT), Mask, VL});
7691   if (MaskVT.isFixedLengthVector())
7692     Trunc = convertFromScalableVector(MaskVT, Trunc, DAG, Subtarget);
7693   return Trunc;
7694 }
7695 
7696 SDValue RISCVTargetLowering::lowerVectorTruncLike(SDValue Op,
7697                                                   SelectionDAG &DAG) const {
7698   bool IsVPTrunc = Op.getOpcode() == ISD::VP_TRUNCATE;
7699   SDLoc DL(Op);
7700 
7701   MVT VT = Op.getSimpleValueType();
7702   // Only custom-lower vector truncates
7703   assert(VT.isVector() && "Unexpected type for vector truncate lowering");
7704 
7705   // Truncates to mask types are handled differently
7706   if (VT.getVectorElementType() == MVT::i1)
7707     return lowerVectorMaskTruncLike(Op, DAG);
7708 
7709   // RVV only has truncates which operate from SEW*2->SEW, so lower arbitrary
7710   // truncates as a series of "RISCVISD::TRUNCATE_VECTOR_VL" nodes which
7711   // truncate by one power of two at a time.
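  // For example, a vXi32 -> vXi8 truncate is emitted as two such nodes:
  // vXi32 -> vXi16 -> vXi8.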
7712   MVT DstEltVT = VT.getVectorElementType();
7713 
7714   SDValue Src = Op.getOperand(0);
7715   MVT SrcVT = Src.getSimpleValueType();
7716   MVT SrcEltVT = SrcVT.getVectorElementType();
7717 
7718   assert(DstEltVT.bitsLT(SrcEltVT) && isPowerOf2_64(DstEltVT.getSizeInBits()) &&
7719          isPowerOf2_64(SrcEltVT.getSizeInBits()) &&
7720          "Unexpected vector truncate lowering");
7721 
7722   MVT ContainerVT = SrcVT;
7723   SDValue Mask, VL;
7724   if (IsVPTrunc) {
7725     Mask = Op.getOperand(1);
7726     VL = Op.getOperand(2);
7727   }
7728   if (SrcVT.isFixedLengthVector()) {
7729     ContainerVT = getContainerForFixedLengthVector(SrcVT);
7730     Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
7731     if (IsVPTrunc) {
7732       MVT MaskVT = getMaskTypeFor(ContainerVT);
7733       Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
7734     }
7735   }
7736 
7737   SDValue Result = Src;
7738   if (!IsVPTrunc) {
7739     std::tie(Mask, VL) =
7740         getDefaultVLOps(SrcVT, ContainerVT, DL, DAG, Subtarget);
7741   }
7742 
7743   LLVMContext &Context = *DAG.getContext();
7744   const ElementCount Count = ContainerVT.getVectorElementCount();
7745   do {
7746     SrcEltVT = MVT::getIntegerVT(SrcEltVT.getSizeInBits() / 2);
7747     EVT ResultVT = EVT::getVectorVT(Context, SrcEltVT, Count);
7748     Result = DAG.getNode(RISCVISD::TRUNCATE_VECTOR_VL, DL, ResultVT, Result,
7749                          Mask, VL);
7750   } while (SrcEltVT != DstEltVT);
7751 
7752   if (SrcVT.isFixedLengthVector())
7753     Result = convertFromScalableVector(VT, Result, DAG, Subtarget);
7754 
7755   return Result;
7756 }
7757 
7758 SDValue
7759 RISCVTargetLowering::lowerStrictFPExtendOrRoundLike(SDValue Op,
7760                                                     SelectionDAG &DAG) const {
7761   SDLoc DL(Op);
7762   SDValue Chain = Op.getOperand(0);
7763   SDValue Src = Op.getOperand(1);
7764   MVT VT = Op.getSimpleValueType();
7765   MVT SrcVT = Src.getSimpleValueType();
7766   MVT ContainerVT = VT;
7767   if (VT.isFixedLengthVector()) {
7768     MVT SrcContainerVT = getContainerForFixedLengthVector(SrcVT);
7769     ContainerVT =
7770         SrcContainerVT.changeVectorElementType(VT.getVectorElementType());
7771     Src = convertToScalableVector(SrcContainerVT, Src, DAG, Subtarget);
7772   }
7773 
7774   auto [Mask, VL] = getDefaultVLOps(SrcVT, ContainerVT, DL, DAG, Subtarget);
7775 
7776   // RVV can only widen/truncate fp to types double/half the size of the source.
7777   if ((VT.getVectorElementType() == MVT::f64 &&
7778        SrcVT.getVectorElementType() == MVT::f16) ||
7779       (VT.getVectorElementType() == MVT::f16 &&
7780        SrcVT.getVectorElementType() == MVT::f64)) {
7781     // For double rounding, the intermediate rounding should be round-to-odd.
7782     unsigned InterConvOpc = Op.getOpcode() == ISD::STRICT_FP_EXTEND
7783                                 ? RISCVISD::STRICT_FP_EXTEND_VL
7784                                 : RISCVISD::STRICT_VFNCVT_ROD_VL;
7785     MVT InterVT = ContainerVT.changeVectorElementType(MVT::f32);
7786     Src = DAG.getNode(InterConvOpc, DL, DAG.getVTList(InterVT, MVT::Other),
7787                       Chain, Src, Mask, VL);
7788     Chain = Src.getValue(1);
7789   }
7790 
7791   unsigned ConvOpc = Op.getOpcode() == ISD::STRICT_FP_EXTEND
7792                          ? RISCVISD::STRICT_FP_EXTEND_VL
7793                          : RISCVISD::STRICT_FP_ROUND_VL;
7794   SDValue Res = DAG.getNode(ConvOpc, DL, DAG.getVTList(ContainerVT, MVT::Other),
7795                             Chain, Src, Mask, VL);
7796   if (VT.isFixedLengthVector()) {
7797     // StrictFP operations have two result values. Their lowered result should
7798     // have the same number of results.
7799     SDValue SubVec = convertFromScalableVector(VT, Res, DAG, Subtarget);
7800     Res = DAG.getMergeValues({SubVec, Res.getValue(1)}, DL);
7801   }
7802   return Res;
7803 }
7804 
7805 SDValue
7806 RISCVTargetLowering::lowerVectorFPExtendOrRoundLike(SDValue Op,
7807                                                     SelectionDAG &DAG) const {
7808   bool IsVP =
7809       Op.getOpcode() == ISD::VP_FP_ROUND || Op.getOpcode() == ISD::VP_FP_EXTEND;
7810   bool IsExtend =
7811       Op.getOpcode() == ISD::VP_FP_EXTEND || Op.getOpcode() == ISD::FP_EXTEND;
7812   // RVV can only truncate fp to types half the size of the source. We
7813   // custom-lower f64->f16 rounds via RVV's round-to-odd float
7814   // conversion instruction.
7815   SDLoc DL(Op);
7816   MVT VT = Op.getSimpleValueType();
7817 
7818   assert(VT.isVector() && "Unexpected type for vector truncate lowering");
7819 
7820   SDValue Src = Op.getOperand(0);
7821   MVT SrcVT = Src.getSimpleValueType();
7822 
7823   bool IsDirectExtend = IsExtend && (VT.getVectorElementType() != MVT::f64 ||
7824                                      SrcVT.getVectorElementType() != MVT::f16);
7825   bool IsDirectTrunc = !IsExtend && (VT.getVectorElementType() != MVT::f16 ||
7826                                      SrcVT.getVectorElementType() != MVT::f64);
7827 
7828   bool IsDirectConv = IsDirectExtend || IsDirectTrunc;
7829 
7830   // Prepare any fixed-length vector operands.
7831   MVT ContainerVT = VT;
7832   SDValue Mask, VL;
7833   if (IsVP) {
7834     Mask = Op.getOperand(1);
7835     VL = Op.getOperand(2);
7836   }
7837   if (VT.isFixedLengthVector()) {
7838     MVT SrcContainerVT = getContainerForFixedLengthVector(SrcVT);
7839     ContainerVT =
7840         SrcContainerVT.changeVectorElementType(VT.getVectorElementType());
7841     Src = convertToScalableVector(SrcContainerVT, Src, DAG, Subtarget);
7842     if (IsVP) {
7843       MVT MaskVT = getMaskTypeFor(ContainerVT);
7844       Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
7845     }
7846   }
7847 
7848   if (!IsVP)
7849     std::tie(Mask, VL) =
7850         getDefaultVLOps(SrcVT, ContainerVT, DL, DAG, Subtarget);
7851 
7852   unsigned ConvOpc = IsExtend ? RISCVISD::FP_EXTEND_VL : RISCVISD::FP_ROUND_VL;
7853 
7854   if (IsDirectConv) {
7855     Src = DAG.getNode(ConvOpc, DL, ContainerVT, Src, Mask, VL);
7856     if (VT.isFixedLengthVector())
7857       Src = convertFromScalableVector(VT, Src, DAG, Subtarget);
7858     return Src;
7859   }
7860 
7861   unsigned InterConvOpc =
7862       IsExtend ? RISCVISD::FP_EXTEND_VL : RISCVISD::VFNCVT_ROD_VL;
7863 
7864   MVT InterVT = ContainerVT.changeVectorElementType(MVT::f32);
7865   SDValue IntermediateConv =
7866       DAG.getNode(InterConvOpc, DL, InterVT, Src, Mask, VL);
7867   SDValue Result =
7868       DAG.getNode(ConvOpc, DL, ContainerVT, IntermediateConv, Mask, VL);
7869   if (VT.isFixedLengthVector())
7870     return convertFromScalableVector(VT, Result, DAG, Subtarget);
7871   return Result;
7872 }
7873 
7874 // Given a scalable vector type and an index into it, returns the type for the
7875 // smallest subvector that the index fits in. This can be used to reduce LMUL
7876 // for operations like vslidedown.
7877 //
7878 // E.g. With Zvl128b, index 3 in a nxv4i32 fits within the first nxv2i32.
7879 static std::optional<MVT>
7880 getSmallestVTForIndex(MVT VecVT, unsigned MaxIdx, SDLoc DL, SelectionDAG &DAG,
7881                       const RISCVSubtarget &Subtarget) {
7882   assert(VecVT.isScalableVector());
7883   const unsigned EltSize = VecVT.getScalarSizeInBits();
7884   const unsigned VectorBitsMin = Subtarget.getRealMinVLen();
7885   const unsigned MinVLMAX = VectorBitsMin / EltSize;
7886   MVT SmallerVT;
7887   if (MaxIdx < MinVLMAX)
7888     SmallerVT = getLMUL1VT(VecVT);
7889   else if (MaxIdx < MinVLMAX * 2)
7890     SmallerVT = getLMUL1VT(VecVT).getDoubleNumVectorElementsVT();
7891   else if (MaxIdx < MinVLMAX * 4)
7892     SmallerVT = getLMUL1VT(VecVT)
7893                     .getDoubleNumVectorElementsVT()
7894                     .getDoubleNumVectorElementsVT();
7895   if (!SmallerVT.isValid() || !VecVT.bitsGT(SmallerVT))
7896     return std::nullopt;
7897   return SmallerVT;
7898 }
7899 
7900 // Custom-legalize INSERT_VECTOR_ELT so that the value is inserted into the
7901 // first position of a vector, and that vector is slid up to the insert index.
7902 // By limiting the active vector length to index+1 and merging with the
7903 // original vector (with an undisturbed tail policy for elements >= VL), we
7904 // achieve the desired result of leaving all elements untouched except the one
7905 // at VL-1, which is replaced with the desired value.
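// For example, an insert at index 2 places the value at element 0 of a
// temporary vector and slides it up by 2 with VL = 3; the undisturbed tail
// policy preserves all elements at index >= 3.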
7906 SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op,
7907                                                     SelectionDAG &DAG) const {
7908   SDLoc DL(Op);
7909   MVT VecVT = Op.getSimpleValueType();
7910   SDValue Vec = Op.getOperand(0);
7911   SDValue Val = Op.getOperand(1);
7912   SDValue Idx = Op.getOperand(2);
7913 
7914   if (VecVT.getVectorElementType() == MVT::i1) {
7915     // FIXME: For now we just promote to an i8 vector and insert into that,
7916     // but this is probably not optimal.
7917     MVT WideVT = MVT::getVectorVT(MVT::i8, VecVT.getVectorElementCount());
7918     Vec = DAG.getNode(ISD::ZERO_EXTEND, DL, WideVT, Vec);
7919     Vec = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, WideVT, Vec, Val, Idx);
7920     return DAG.getNode(ISD::TRUNCATE, DL, VecVT, Vec);
7921   }
7922 
7923   MVT ContainerVT = VecVT;
7924   // If the operand is a fixed-length vector, convert to a scalable one.
7925   if (VecVT.isFixedLengthVector()) {
7926     ContainerVT = getContainerForFixedLengthVector(VecVT);
7927     Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
7928   }
7929 
7930   // If we know the index we're going to insert at, we can shrink Vec so that
7931   // we're performing the scalar inserts and slideup on a smaller LMUL.
7932   MVT OrigContainerVT = ContainerVT;
7933   SDValue OrigVec = Vec;
7934   SDValue AlignedIdx;
7935   if (auto *IdxC = dyn_cast<ConstantSDNode>(Idx)) {
7936     const unsigned OrigIdx = IdxC->getZExtValue();
7937     // Do we know an upper bound on LMUL?
7938     if (auto ShrunkVT = getSmallestVTForIndex(ContainerVT, OrigIdx,
7939                                               DL, DAG, Subtarget)) {
7940       ContainerVT = *ShrunkVT;
7941       AlignedIdx = DAG.getVectorIdxConstant(0, DL);
7942     }
7943 
7944     // If we're compiling for an exact VLEN value, we can always perform
7945     // the insert in m1 as we can determine the register corresponding to
7946     // the index in the register group.
7947     const unsigned MinVLen = Subtarget.getRealMinVLen();
7948     const unsigned MaxVLen = Subtarget.getRealMaxVLen();
7949     const MVT M1VT = getLMUL1VT(ContainerVT);
7950     if (MinVLen == MaxVLen && ContainerVT.bitsGT(M1VT)) {
7951       EVT ElemVT = VecVT.getVectorElementType();
7952       unsigned ElemsPerVReg = MinVLen / ElemVT.getFixedSizeInBits();
7953       unsigned RemIdx = OrigIdx % ElemsPerVReg;
7954       unsigned SubRegIdx = OrigIdx / ElemsPerVReg;
7955       unsigned ExtractIdx =
7956           SubRegIdx * M1VT.getVectorElementCount().getKnownMinValue();
7957       AlignedIdx = DAG.getVectorIdxConstant(ExtractIdx, DL);
7958       Idx = DAG.getVectorIdxConstant(RemIdx, DL);
7959       ContainerVT = M1VT;
7960     }
7961 
7962     if (AlignedIdx)
7963       Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ContainerVT, Vec,
7964                         AlignedIdx);
7965   }
7966 
7967   MVT XLenVT = Subtarget.getXLenVT();
7968 
7969   bool IsLegalInsert = Subtarget.is64Bit() || Val.getValueType() != MVT::i64;
7970   // Even i64-element vectors on RV32 can be lowered without scalar
7971   // legalization if the most-significant 32 bits of the value are not affected
7972   // by the sign-extension of the lower 32 bits.
7973   // TODO: We could also catch sign extensions of a 32-bit value.
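  // For example, the i64 constant -5 qualifies on RV32: its upper 32 bits are
  // exactly the sign-extension of its lower 32 bits, so it can be inserted as
  // the i32 constant -5.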
7974   if (!IsLegalInsert && isa<ConstantSDNode>(Val)) {
7975     const auto *CVal = cast<ConstantSDNode>(Val);
7976     if (isInt<32>(CVal->getSExtValue())) {
7977       IsLegalInsert = true;
7978       Val = DAG.getConstant(CVal->getSExtValue(), DL, MVT::i32);
7979     }
7980   }
7981 
7982   auto [Mask, VL] = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget);
7983 
7984   SDValue ValInVec;
7985 
7986   if (IsLegalInsert) {
7987     unsigned Opc =
7988         VecVT.isFloatingPoint() ? RISCVISD::VFMV_S_F_VL : RISCVISD::VMV_S_X_VL;
7989     if (isNullConstant(Idx)) {
7990       if (!VecVT.isFloatingPoint())
7991         Val = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Val);
7992       Vec = DAG.getNode(Opc, DL, ContainerVT, Vec, Val, VL);
7993 
7994       if (AlignedIdx)
7995         Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, OrigContainerVT, OrigVec,
7996                           Vec, AlignedIdx);
7997       if (!VecVT.isFixedLengthVector())
7998         return Vec;
7999       return convertFromScalableVector(VecVT, Vec, DAG, Subtarget);
8000     }
8001     ValInVec = lowerScalarInsert(Val, VL, ContainerVT, DL, DAG, Subtarget);
8002   } else {
8003     // On RV32, i64-element vectors must be specially handled to place the
8004     // value at element 0, by using two vslide1down instructions in sequence on
8005     // the i32 split lo/hi value. Use an equivalently-sized i32 vector for
8006     // this.
8007     SDValue ValLo, ValHi;
8008     std::tie(ValLo, ValHi) = DAG.SplitScalar(Val, DL, MVT::i32, MVT::i32);
8009     MVT I32ContainerVT =
8010         MVT::getVectorVT(MVT::i32, ContainerVT.getVectorElementCount() * 2);
8011     SDValue I32Mask =
8012         getDefaultScalableVLOps(I32ContainerVT, DL, DAG, Subtarget).first;
8013     // Limit the active VL to two.
8014     SDValue InsertI64VL = DAG.getConstant(2, DL, XLenVT);
8015     // If the Idx is 0 we can insert directly into the vector.
8016     if (isNullConstant(Idx)) {
8017       // First slide in the lo value, then the hi in above it. We use slide1down
8018       // to avoid the register group overlap constraint of vslide1up.
8019       ValInVec = DAG.getNode(RISCVISD::VSLIDE1DOWN_VL, DL, I32ContainerVT,
8020                              Vec, Vec, ValLo, I32Mask, InsertI64VL);
8021       // If the source vector is undef don't pass along the tail elements from
8022       // the previous slide1down.
8023       SDValue Tail = Vec.isUndef() ? Vec : ValInVec;
8024       ValInVec = DAG.getNode(RISCVISD::VSLIDE1DOWN_VL, DL, I32ContainerVT,
8025                              Tail, ValInVec, ValHi, I32Mask, InsertI64VL);
8026       // Bitcast back to the right container type.
8027       ValInVec = DAG.getBitcast(ContainerVT, ValInVec);
8028 
8029       if (AlignedIdx)
8030         ValInVec =
8031             DAG.getNode(ISD::INSERT_SUBVECTOR, DL, OrigContainerVT, OrigVec,
8032                         ValInVec, AlignedIdx);
8033       if (!VecVT.isFixedLengthVector())
8034         return ValInVec;
8035       return convertFromScalableVector(VecVT, ValInVec, DAG, Subtarget);
8036     }
8037 
8038     // First slide in the lo value, then the hi in above it. We use slide1down
8039     // to avoid the register group overlap constraint of vslide1up.
8040     ValInVec = DAG.getNode(RISCVISD::VSLIDE1DOWN_VL, DL, I32ContainerVT,
8041                            DAG.getUNDEF(I32ContainerVT),
8042                            DAG.getUNDEF(I32ContainerVT), ValLo,
8043                            I32Mask, InsertI64VL);
8044     ValInVec = DAG.getNode(RISCVISD::VSLIDE1DOWN_VL, DL, I32ContainerVT,
8045                            DAG.getUNDEF(I32ContainerVT), ValInVec, ValHi,
8046                            I32Mask, InsertI64VL);
8047     // Bitcast back to the right container type.
8048     ValInVec = DAG.getBitcast(ContainerVT, ValInVec);
8049   }
8050 
8051   // Now that the value is in a vector, slide it into position.
8052   SDValue InsertVL =
8053       DAG.getNode(ISD::ADD, DL, XLenVT, Idx, DAG.getConstant(1, DL, XLenVT));
8054 
8055   // Use tail agnostic policy if Idx is the last index of Vec.
8056   unsigned Policy = RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED;
8057   if (VecVT.isFixedLengthVector() && isa<ConstantSDNode>(Idx) &&
8058       Idx->getAsZExtVal() + 1 == VecVT.getVectorNumElements())
8059     Policy = RISCVII::TAIL_AGNOSTIC;
8060   SDValue Slideup = getVSlideup(DAG, Subtarget, DL, ContainerVT, Vec, ValInVec,
8061                                 Idx, Mask, InsertVL, Policy);
8062 
8063   if (AlignedIdx)
8064     Slideup = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, OrigContainerVT, OrigVec,
8065                           Slideup, AlignedIdx);
8066   if (!VecVT.isFixedLengthVector())
8067     return Slideup;
8068   return convertFromScalableVector(VecVT, Slideup, DAG, Subtarget);
8069 }
8070 
8071 // Custom-lower EXTRACT_VECTOR_ELT operations to slide the vector down, then
8072 // extract the first element: (extractelt (slidedown vec, idx), 0). For integer
8073 // types this is done using VMV_X_S to allow us to glean information about the
8074 // sign bits of the result.
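// For example, (extractelt vXi32 %v, 5) is lowered by sliding %v down by 5
// with VL = 1 and then reading element 0 with VMV_X_S.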
8075 SDValue RISCVTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op,
8076                                                      SelectionDAG &DAG) const {
8077   SDLoc DL(Op);
8078   SDValue Idx = Op.getOperand(1);
8079   SDValue Vec = Op.getOperand(0);
8080   EVT EltVT = Op.getValueType();
8081   MVT VecVT = Vec.getSimpleValueType();
8082   MVT XLenVT = Subtarget.getXLenVT();
8083 
8084   if (VecVT.getVectorElementType() == MVT::i1) {
8085     // Use vfirst.m to extract the first bit.
8086     if (isNullConstant(Idx)) {
8087       MVT ContainerVT = VecVT;
8088       if (VecVT.isFixedLengthVector()) {
8089         ContainerVT = getContainerForFixedLengthVector(VecVT);
8090         Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
8091       }
8092       auto [Mask, VL] = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget);
8093       SDValue Vfirst =
8094           DAG.getNode(RISCVISD::VFIRST_VL, DL, XLenVT, Vec, Mask, VL);
8095       SDValue Res = DAG.getSetCC(DL, XLenVT, Vfirst,
8096                                  DAG.getConstant(0, DL, XLenVT), ISD::SETEQ);
8097       return DAG.getNode(ISD::TRUNCATE, DL, EltVT, Res);
8098     }
8099     if (VecVT.isFixedLengthVector()) {
8100       unsigned NumElts = VecVT.getVectorNumElements();
8101       if (NumElts >= 8) {
8102         MVT WideEltVT;
8103         unsigned WidenVecLen;
8104         SDValue ExtractElementIdx;
8105         SDValue ExtractBitIdx;
8106         unsigned MaxEEW = Subtarget.getELen();
8107         MVT LargestEltVT = MVT::getIntegerVT(
8108             std::min(MaxEEW, unsigned(XLenVT.getSizeInBits())));
8109         if (NumElts <= LargestEltVT.getSizeInBits()) {
8110           assert(isPowerOf2_32(NumElts) &&
8111                  "the number of elements should be power of 2");
8112           WideEltVT = MVT::getIntegerVT(NumElts);
8113           WidenVecLen = 1;
8114           ExtractElementIdx = DAG.getConstant(0, DL, XLenVT);
8115           ExtractBitIdx = Idx;
8116         } else {
8117           WideEltVT = LargestEltVT;
8118           WidenVecLen = NumElts / WideEltVT.getSizeInBits();
8119           // extract element index = index / element width
8120           ExtractElementIdx = DAG.getNode(
8121               ISD::SRL, DL, XLenVT, Idx,
8122               DAG.getConstant(Log2_64(WideEltVT.getSizeInBits()), DL, XLenVT));
8123           // mask bit index = index % element width
8124           ExtractBitIdx = DAG.getNode(
8125               ISD::AND, DL, XLenVT, Idx,
8126               DAG.getConstant(WideEltVT.getSizeInBits() - 1, DL, XLenVT));
8127         }
8128         MVT WideVT = MVT::getVectorVT(WideEltVT, WidenVecLen);
8129         Vec = DAG.getNode(ISD::BITCAST, DL, WideVT, Vec);
8130         SDValue ExtractElt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, XLenVT,
8131                                          Vec, ExtractElementIdx);
8132         // Extract the bit from GPR.
8133         SDValue ShiftRight =
8134             DAG.getNode(ISD::SRL, DL, XLenVT, ExtractElt, ExtractBitIdx);
8135         SDValue Res = DAG.getNode(ISD::AND, DL, XLenVT, ShiftRight,
8136                                   DAG.getConstant(1, DL, XLenVT));
8137         return DAG.getNode(ISD::TRUNCATE, DL, EltVT, Res);
8138       }
8139     }
8140     // Otherwise, promote to an i8 vector and extract from that.
8141     MVT WideVT = MVT::getVectorVT(MVT::i8, VecVT.getVectorElementCount());
8142     Vec = DAG.getNode(ISD::ZERO_EXTEND, DL, WideVT, Vec);
8143     return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Vec, Idx);
8144   }
8145 
8146   // If this is a fixed vector, we need to convert it to a scalable vector.
8147   MVT ContainerVT = VecVT;
8148   if (VecVT.isFixedLengthVector()) {
8149     ContainerVT = getContainerForFixedLengthVector(VecVT);
8150     Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
8151   }
8152 
8153   // If we're compiling for an exact VLEN value and we have a known
8154   // constant index, we can always perform the extract in m1 (or
8155   // smaller) as we can determine the register corresponding to
8156   // the index in the register group.
8157   const unsigned MinVLen = Subtarget.getRealMinVLen();
8158   const unsigned MaxVLen = Subtarget.getRealMaxVLen();
8159   if (auto *IdxC = dyn_cast<ConstantSDNode>(Idx);
8160       IdxC && MinVLen == MaxVLen &&
8161       VecVT.getSizeInBits().getKnownMinValue() > MinVLen) {
8162     MVT M1VT = getLMUL1VT(ContainerVT);
8163     unsigned OrigIdx = IdxC->getZExtValue();
8164     EVT ElemVT = VecVT.getVectorElementType();
8165     unsigned ElemsPerVReg = MinVLen / ElemVT.getFixedSizeInBits();
8166     unsigned RemIdx = OrigIdx % ElemsPerVReg;
8167     unsigned SubRegIdx = OrigIdx / ElemsPerVReg;
8168     unsigned ExtractIdx =
8169       SubRegIdx * M1VT.getVectorElementCount().getKnownMinValue();
8170     Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, M1VT, Vec,
8171                       DAG.getVectorIdxConstant(ExtractIdx, DL));
8172     Idx = DAG.getVectorIdxConstant(RemIdx, DL);
8173     ContainerVT = M1VT;
8174   }
8175 
8176   // Reduce the LMUL of our slidedown and vmv.x.s to the smallest LMUL which
8177   // contains our index.
8178   std::optional<uint64_t> MaxIdx;
8179   if (VecVT.isFixedLengthVector())
8180     MaxIdx = VecVT.getVectorNumElements() - 1;
8181   if (auto *IdxC = dyn_cast<ConstantSDNode>(Idx))
8182     MaxIdx = IdxC->getZExtValue();
8183   if (MaxIdx) {
8184     if (auto SmallerVT =
8185             getSmallestVTForIndex(ContainerVT, *MaxIdx, DL, DAG, Subtarget)) {
8186       ContainerVT = *SmallerVT;
8187       Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ContainerVT, Vec,
8188                         DAG.getConstant(0, DL, XLenVT));
8189     }
8190   }
8191 
8192   // If after narrowing, the required slide is still greater than LMUL2,
8193   // fall back to generic expansion and go through the stack.  This is done
8194   // for a subtle reason: extracting *all* elements out of a vector is
8195   // widely expected to be linear in vector size, but because vslidedown
8196   // is linear in LMUL, performing N extracts using vslidedown becomes
8197   // O(n^2) / (VLEN/ETYPE) work.  On the surface, going through the stack
8198   // seems to have the same problem (the store is linear in LMUL), but the
8199   // generic expansion *memoizes* the store, and thus for many extracts of
8200   // the same vector we end up with one store and a bunch of loads.
8201   // TODO: We don't have the same code for insert_vector_elt because we
8202   // have BUILD_VECTOR and handle the degenerate case there.  Should we
8203   // consider adding an inverse BUILD_VECTOR node?
8204   MVT LMUL2VT = getLMUL1VT(ContainerVT).getDoubleNumVectorElementsVT();
8205   if (ContainerVT.bitsGT(LMUL2VT) && VecVT.isFixedLengthVector())
8206     return SDValue();
8207 
8208   // If the index is 0, the vector is already in the right position.
8209   if (!isNullConstant(Idx)) {
8210     // Use a VL of 1 to avoid processing more elements than we need.
8211     auto [Mask, VL] = getDefaultVLOps(1, ContainerVT, DL, DAG, Subtarget);
8212     Vec = getVSlidedown(DAG, Subtarget, DL, ContainerVT,
8213                         DAG.getUNDEF(ContainerVT), Vec, Idx, Mask, VL);
8214   }
8215 
8216   if (!EltVT.isInteger()) {
8217     // Floating-point extracts are handled in TableGen.
8218     return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Vec,
8219                        DAG.getConstant(0, DL, XLenVT));
8220   }
8221 
8222   SDValue Elt0 = DAG.getNode(RISCVISD::VMV_X_S, DL, XLenVT, Vec);
8223   return DAG.getNode(ISD::TRUNCATE, DL, EltVT, Elt0);
8224 }
8225 
8226 // Some RVV intrinsics may claim that they want an integer operand to be
8227 // promoted or expanded.
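// For example, on RV32 an i64 scalar operand of a vXi64 intrinsic is
// truncated to i32 when it is known to be sign-extended; otherwise it is
// split into two i32 halves and splatted (or, for vslide1up/vslide1down,
// handled with two SEW=32 slides).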
8228 static SDValue lowerVectorIntrinsicScalars(SDValue Op, SelectionDAG &DAG,
8229                                            const RISCVSubtarget &Subtarget) {
8230   assert((Op.getOpcode() == ISD::INTRINSIC_VOID ||
8231           Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
8232           Op.getOpcode() == ISD::INTRINSIC_W_CHAIN) &&
8233          "Unexpected opcode");
8234 
8235   if (!Subtarget.hasVInstructions())
8236     return SDValue();
8237 
8238   bool HasChain = Op.getOpcode() == ISD::INTRINSIC_VOID ||
8239                   Op.getOpcode() == ISD::INTRINSIC_W_CHAIN;
8240   unsigned IntNo = Op.getConstantOperandVal(HasChain ? 1 : 0);
8241 
8242   SDLoc DL(Op);
8243 
8244   const RISCVVIntrinsicsTable::RISCVVIntrinsicInfo *II =
8245       RISCVVIntrinsicsTable::getRISCVVIntrinsicInfo(IntNo);
8246   if (!II || !II->hasScalarOperand())
8247     return SDValue();
8248 
8249   unsigned SplatOp = II->ScalarOperand + 1 + HasChain;
8250   assert(SplatOp < Op.getNumOperands());
8251 
8252   SmallVector<SDValue, 8> Operands(Op->op_begin(), Op->op_end());
8253   SDValue &ScalarOp = Operands[SplatOp];
8254   MVT OpVT = ScalarOp.getSimpleValueType();
8255   MVT XLenVT = Subtarget.getXLenVT();
8256 
8257   // If this isn't a scalar, or its type is XLenVT we're done.
8258   if (!OpVT.isScalarInteger() || OpVT == XLenVT)
8259     return SDValue();
8260 
8261   // Simplest case is that the operand needs to be promoted to XLenVT.
8262   if (OpVT.bitsLT(XLenVT)) {
8263     // If the operand is a constant, sign extend to increase our chances
8264     // of being able to use a .vi instruction. ANY_EXTEND would become a
8265     // zero extend and the simm5 check in isel would fail.
8266     // FIXME: Should we ignore the upper bits in isel instead?
8267     unsigned ExtOpc =
8268         isa<ConstantSDNode>(ScalarOp) ? ISD::SIGN_EXTEND : ISD::ANY_EXTEND;
8269     ScalarOp = DAG.getNode(ExtOpc, DL, XLenVT, ScalarOp);
8270     return DAG.getNode(Op->getOpcode(), DL, Op->getVTList(), Operands);
8271   }
8272 
8273   // Use the previous operand to get the vXi64 VT. The result might be a mask
8274   // VT for compares. Using the previous operand assumes that the previous
8275   // operand will never have a smaller element size than a scalar operand and
8276   // that a widening operation never uses SEW=64.
8277   // NOTE: If this fails the below assert, we can probably just find the
8278   // element count from any operand or result and use it to construct the VT.
8279   assert(II->ScalarOperand > 0 && "Unexpected splat operand!");
8280   MVT VT = Op.getOperand(SplatOp - 1).getSimpleValueType();
8281 
8282   // The more complex case is when the scalar is larger than XLenVT.
8283   assert(XLenVT == MVT::i32 && OpVT == MVT::i64 &&
8284          VT.getVectorElementType() == MVT::i64 && "Unexpected VTs!");
8285 
8286   // If this is a sign-extended 32-bit value, we can truncate it and rely on the
8287   // instruction to sign-extend since SEW>XLEN.
8288   if (DAG.ComputeNumSignBits(ScalarOp) > 32) {
8289     ScalarOp = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, ScalarOp);
8290     return DAG.getNode(Op->getOpcode(), DL, Op->getVTList(), Operands);
8291   }
8292 
8293   switch (IntNo) {
8294   case Intrinsic::riscv_vslide1up:
8295   case Intrinsic::riscv_vslide1down:
8296   case Intrinsic::riscv_vslide1up_mask:
8297   case Intrinsic::riscv_vslide1down_mask: {
8298     // We need to special case these when the scalar is larger than XLen.
8299     unsigned NumOps = Op.getNumOperands();
8300     bool IsMasked = NumOps == 7;
8301 
8302     // Convert the vector source to the equivalent nxvXi32 vector.
8303     MVT I32VT = MVT::getVectorVT(MVT::i32, VT.getVectorElementCount() * 2);
8304     SDValue Vec = DAG.getBitcast(I32VT, Operands[2]);
8305     SDValue ScalarLo, ScalarHi;
8306     std::tie(ScalarLo, ScalarHi) =
8307         DAG.SplitScalar(ScalarOp, DL, MVT::i32, MVT::i32);
8308 
8309     // Double the VL since we halved SEW.
8310     SDValue AVL = getVLOperand(Op);
8311     SDValue I32VL;
8312 
8313     // Optimize for constant AVL
8314     if (isa<ConstantSDNode>(AVL)) {
8315       const auto [MinVLMAX, MaxVLMAX] =
8316           RISCVTargetLowering::computeVLMAXBounds(VT, Subtarget);
8317 
8318       uint64_t AVLInt = AVL->getAsZExtVal();
8319       if (AVLInt <= MinVLMAX) {
8320         I32VL = DAG.getConstant(2 * AVLInt, DL, XLenVT);
8321       } else if (AVLInt >= 2 * MaxVLMAX) {
8322         // Just set vl to VLMAX in this situation
8323         RISCVII::VLMUL Lmul = RISCVTargetLowering::getLMUL(I32VT);
8324         SDValue LMUL = DAG.getConstant(Lmul, DL, XLenVT);
8325         unsigned Sew = RISCVVType::encodeSEW(I32VT.getScalarSizeInBits());
8326         SDValue SEW = DAG.getConstant(Sew, DL, XLenVT);
8327         SDValue SETVLMAX = DAG.getTargetConstant(
8328             Intrinsic::riscv_vsetvlimax, DL, MVT::i32);
8329         I32VL = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, XLenVT, SETVLMAX, SEW,
8330                             LMUL);
8331       } else {
8332         // For an AVL in (MinVLMAX, 2 * MaxVLMAX), the actual working VL
8333         // depends on the hardware implementation, so let the code below
8334         // compute it with a vsetvli.
8335       }
8336     }
8337     if (!I32VL) {
8338       RISCVII::VLMUL Lmul = RISCVTargetLowering::getLMUL(VT);
8339       SDValue LMUL = DAG.getConstant(Lmul, DL, XLenVT);
8340       unsigned Sew = RISCVVType::encodeSEW(VT.getScalarSizeInBits());
8341       SDValue SEW = DAG.getConstant(Sew, DL, XLenVT);
8342       SDValue SETVL =
8343           DAG.getTargetConstant(Intrinsic::riscv_vsetvli, DL, MVT::i32);
8344       // Use a vsetvli instruction to get the actually-used length, which
8345       // depends on the hardware implementation.
8346       SDValue VL = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, XLenVT, SETVL, AVL,
8347                                SEW, LMUL);
8348       I32VL =
8349           DAG.getNode(ISD::SHL, DL, XLenVT, VL, DAG.getConstant(1, DL, XLenVT));
8350     }
8351 
8352     SDValue I32Mask = getAllOnesMask(I32VT, I32VL, DL, DAG);
8353 
8354     // Shift the two scalar parts in using SEW=32 slide1up/slide1down
8355     // instructions.
8356     SDValue Passthru;
8357     if (IsMasked)
8358       Passthru = DAG.getUNDEF(I32VT);
8359     else
8360       Passthru = DAG.getBitcast(I32VT, Operands[1]);
8361 
8362     if (IntNo == Intrinsic::riscv_vslide1up ||
8363         IntNo == Intrinsic::riscv_vslide1up_mask) {
8364       Vec = DAG.getNode(RISCVISD::VSLIDE1UP_VL, DL, I32VT, Passthru, Vec,
8365                         ScalarHi, I32Mask, I32VL);
8366       Vec = DAG.getNode(RISCVISD::VSLIDE1UP_VL, DL, I32VT, Passthru, Vec,
8367                         ScalarLo, I32Mask, I32VL);
8368     } else {
8369       Vec = DAG.getNode(RISCVISD::VSLIDE1DOWN_VL, DL, I32VT, Passthru, Vec,
8370                         ScalarLo, I32Mask, I32VL);
8371       Vec = DAG.getNode(RISCVISD::VSLIDE1DOWN_VL, DL, I32VT, Passthru, Vec,
8372                         ScalarHi, I32Mask, I32VL);
8373     }
8374 
8375     // Convert back to nxvXi64.
8376     Vec = DAG.getBitcast(VT, Vec);
8377 
8378     if (!IsMasked)
8379       return Vec;
8380     // Apply mask after the operation.
8381     SDValue Mask = Operands[NumOps - 3];
8382     SDValue MaskedOff = Operands[1];
8383     // Assume Policy operand is the last operand.
8384     uint64_t Policy = Operands[NumOps - 1]->getAsZExtVal();
8385     // We don't need to select maskedoff if it's undef.
8386     if (MaskedOff.isUndef())
8387       return Vec;
8388     // TAMU
8389     if (Policy == RISCVII::TAIL_AGNOSTIC)
8390       return DAG.getNode(RISCVISD::VMERGE_VL, DL, VT, Mask, Vec, MaskedOff,
8391                          DAG.getUNDEF(VT), AVL);
8392     // TUMA or TUMU: Currently we always emit tumu policy regardless of tuma.
8393     // It's fine because vmerge does not care about the mask policy.
8394     return DAG.getNode(RISCVISD::VMERGE_VL, DL, VT, Mask, Vec, MaskedOff,
8395                        MaskedOff, AVL);
8396   }
8397   }
8398 
8399   // We need to convert the scalar to a splat vector.
8400   SDValue VL = getVLOperand(Op);
8401   assert(VL.getValueType() == XLenVT);
8402   ScalarOp = splatSplitI64WithVL(DL, VT, SDValue(), ScalarOp, VL, DAG);
8403   return DAG.getNode(Op->getOpcode(), DL, Op->getVTList(), Operands);
8404 }
8405 
8406 // Lower the llvm.get.vector.length intrinsic to vsetvli. We only support
8407 // scalable vector llvm.get.vector.length for now.
8408 //
8409 // We need to convert from a scalable VF to a vsetvli with VLMax equal to
8410 // (vscale * VF). The vscale and VF are independent of element width. We use
8411 // SEW=8 for the vsetvli because it is the only element width that supports all
8412 // fractional LMULs. The LMUL is chosen so that with SEW=8 the VLMax is
8413 // (vscale * VF), where vscale is defined as VLEN/RVVBitsPerBlock. The
8414 // InsertVSETVLI pass can fix up the vtype of the vsetvli if a different
8415 // SEW and LMUL are better for the surrounding vector instructions.
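// For example, a scalable get.vector.length with VF=4 is lowered to a vsetvli
// with SEW=8 and LMUL=1/2: since RVVBitsPerBlock is 64, VLMax is then
// VLEN/16 = vscale * 4.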
8416 static SDValue lowerGetVectorLength(SDNode *N, SelectionDAG &DAG,
8417                                     const RISCVSubtarget &Subtarget) {
8418   MVT XLenVT = Subtarget.getXLenVT();
8419 
8420   // The smallest LMUL is only valid for the smallest element width.
8421   const unsigned ElementWidth = 8;
8422 
8423   // Determine the VF that corresponds to LMUL 1 for ElementWidth.
8424   unsigned LMul1VF = RISCV::RVVBitsPerBlock / ElementWidth;
8425   // We don't support VF==1 with ELEN==32.
8426   unsigned MinVF = RISCV::RVVBitsPerBlock / Subtarget.getELen();
8427 
8428   unsigned VF = N->getConstantOperandVal(2);
8429   assert(VF >= MinVF && VF <= (LMul1VF * 8) && isPowerOf2_32(VF) &&
8430          "Unexpected VF");
8431   (void)MinVF;
8432 
8433   bool Fractional = VF < LMul1VF;
8434   unsigned LMulVal = Fractional ? LMul1VF / VF : VF / LMul1VF;
8435   unsigned VLMUL = (unsigned)RISCVVType::encodeLMUL(LMulVal, Fractional);
8436   unsigned VSEW = RISCVVType::encodeSEW(ElementWidth);
8437 
8438   SDLoc DL(N);
8439 
8440   SDValue LMul = DAG.getTargetConstant(VLMUL, DL, XLenVT);
8441   SDValue Sew = DAG.getTargetConstant(VSEW, DL, XLenVT);
8442 
8443   SDValue AVL = DAG.getNode(ISD::ZERO_EXTEND, DL, XLenVT, N->getOperand(1));
8444 
8445   SDValue ID = DAG.getTargetConstant(Intrinsic::riscv_vsetvli, DL, XLenVT);
8446   SDValue Res =
8447       DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, XLenVT, ID, AVL, Sew, LMul);
8448   return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), Res);
8449 }
8450 
8451 static void getVCIXOperands(SDValue &Op, SelectionDAG &DAG,
8452                             SmallVector<SDValue> &Ops) {
8453   SDLoc DL(Op);
8454 
8455   const RISCVSubtarget &Subtarget =
8456       DAG.getMachineFunction().getSubtarget<RISCVSubtarget>();
8457   for (const SDValue &V : Op->op_values()) {
8458     EVT ValType = V.getValueType();
8459     if (ValType.isScalableVector() && ValType.isFloatingPoint()) {
8460       MVT InterimIVT =
8461           MVT::getVectorVT(MVT::getIntegerVT(ValType.getScalarSizeInBits()),
8462                            ValType.getVectorElementCount());
8463       Ops.push_back(DAG.getBitcast(InterimIVT, V));
8464     } else if (ValType.isFixedLengthVector()) {
8465       MVT OpContainerVT = getContainerForFixedLengthVector(
8466           DAG, V.getSimpleValueType(), Subtarget);
8467       Ops.push_back(convertToScalableVector(OpContainerVT, V, DAG, Subtarget));
8468     } else
8469       Ops.push_back(V);
8470   }
8471 }
8472 
8473 // LMUL * VLEN should be greater than or equal to EGS * SEW
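// For example, with VLEN=128 a nxv2i32 value (LMUL=1) is valid for EGS=4 and
// SEW=32, since 1 * 128 >= 4 * 32.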
8474 static inline bool isValidEGW(int EGS, EVT VT,
8475                               const RISCVSubtarget &Subtarget) {
8476   return (Subtarget.getRealMinVLen() *
8477              VT.getSizeInBits().getKnownMinValue()) / RISCV::RVVBitsPerBlock >=
8478          EGS * VT.getScalarSizeInBits();
8479 }
8480 
8481 SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
8482                                                      SelectionDAG &DAG) const {
8483   unsigned IntNo = Op.getConstantOperandVal(0);
8484   SDLoc DL(Op);
8485   MVT XLenVT = Subtarget.getXLenVT();
8486 
8487   switch (IntNo) {
8488   default:
8489     break; // Don't custom lower most intrinsics.
8490   case Intrinsic::thread_pointer: {
8491     EVT PtrVT = getPointerTy(DAG.getDataLayout());
8492     return DAG.getRegister(RISCV::X4, PtrVT);
8493   }
8494   case Intrinsic::riscv_orc_b:
8495   case Intrinsic::riscv_brev8:
8496   case Intrinsic::riscv_sha256sig0:
8497   case Intrinsic::riscv_sha256sig1:
8498   case Intrinsic::riscv_sha256sum0:
8499   case Intrinsic::riscv_sha256sum1:
8500   case Intrinsic::riscv_sm3p0:
8501   case Intrinsic::riscv_sm3p1: {
8502     unsigned Opc;
8503     switch (IntNo) {
8504     case Intrinsic::riscv_orc_b:      Opc = RISCVISD::ORC_B;      break;
8505     case Intrinsic::riscv_brev8:      Opc = RISCVISD::BREV8;      break;
8506     case Intrinsic::riscv_sha256sig0: Opc = RISCVISD::SHA256SIG0; break;
8507     case Intrinsic::riscv_sha256sig1: Opc = RISCVISD::SHA256SIG1; break;
8508     case Intrinsic::riscv_sha256sum0: Opc = RISCVISD::SHA256SUM0; break;
8509     case Intrinsic::riscv_sha256sum1: Opc = RISCVISD::SHA256SUM1; break;
8510     case Intrinsic::riscv_sm3p0:      Opc = RISCVISD::SM3P0;      break;
8511     case Intrinsic::riscv_sm3p1:      Opc = RISCVISD::SM3P1;      break;
8512     }
8513 
8514     if (RV64LegalI32 && Subtarget.is64Bit() && Op.getValueType() == MVT::i32) {
8515       SDValue NewOp =
8516           DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, Op.getOperand(1));
8517       SDValue Res = DAG.getNode(Opc, DL, MVT::i64, NewOp);
8518       return DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res);
8519     }
8520 
8521     return DAG.getNode(Opc, DL, XLenVT, Op.getOperand(1));
8522   }
8523   case Intrinsic::riscv_sm4ks:
8524   case Intrinsic::riscv_sm4ed: {
8525     unsigned Opc =
8526         IntNo == Intrinsic::riscv_sm4ks ? RISCVISD::SM4KS : RISCVISD::SM4ED;
8527 
8528     if (RV64LegalI32 && Subtarget.is64Bit() && Op.getValueType() == MVT::i32) {
8529       SDValue NewOp0 =
8530           DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, Op.getOperand(1));
8531       SDValue NewOp1 =
8532           DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, Op.getOperand(2));
8533       SDValue Res =
8534           DAG.getNode(Opc, DL, MVT::i64, NewOp0, NewOp1, Op.getOperand(3));
8535       return DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res);
8536     }
8537 
8538     return DAG.getNode(Opc, DL, XLenVT, Op.getOperand(1), Op.getOperand(2),
8539                        Op.getOperand(3));
8540   }
8541   case Intrinsic::riscv_zip:
8542   case Intrinsic::riscv_unzip: {
8543     unsigned Opc =
8544         IntNo == Intrinsic::riscv_zip ? RISCVISD::ZIP : RISCVISD::UNZIP;
8545     return DAG.getNode(Opc, DL, XLenVT, Op.getOperand(1));
8546   }
8547   case Intrinsic::riscv_clmul:
8548     if (RV64LegalI32 && Subtarget.is64Bit() && Op.getValueType() == MVT::i32) {
8549       SDValue NewOp0 =
8550           DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, Op.getOperand(1));
8551       SDValue NewOp1 =
8552           DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, Op.getOperand(2));
8553       SDValue Res = DAG.getNode(RISCVISD::CLMUL, DL, MVT::i64, NewOp0, NewOp1);
8554       return DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res);
8555     }
8556     return DAG.getNode(RISCVISD::CLMUL, DL, XLenVT, Op.getOperand(1),
8557                        Op.getOperand(2));
8558   case Intrinsic::riscv_clmulh:
8559   case Intrinsic::riscv_clmulr: {
8560     unsigned Opc =
8561         IntNo == Intrinsic::riscv_clmulh ? RISCVISD::CLMULH : RISCVISD::CLMULR;
8562     if (RV64LegalI32 && Subtarget.is64Bit() && Op.getValueType() == MVT::i32) {
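      // Shifting both i32 operands into the upper 32 bits lets the 64-bit
      // CLMULH/CLMULR produce a value whose upper 32 bits are the desired i32
      // result; the SRL by 32 and TRUNCATE below then extract it.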
8563       SDValue NewOp0 =
8564           DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, Op.getOperand(1));
8565       SDValue NewOp1 =
8566           DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, Op.getOperand(2));
8567       NewOp0 = DAG.getNode(ISD::SHL, DL, MVT::i64, NewOp0,
8568                            DAG.getConstant(32, DL, MVT::i64));
8569       NewOp1 = DAG.getNode(ISD::SHL, DL, MVT::i64, NewOp1,
8570                            DAG.getConstant(32, DL, MVT::i64));
8571       SDValue Res = DAG.getNode(Opc, DL, MVT::i64, NewOp0, NewOp1);
8572       Res = DAG.getNode(ISD::SRL, DL, MVT::i64, Res,
8573                         DAG.getConstant(32, DL, MVT::i64));
8574       return DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res);
8575     }
8576 
8577     return DAG.getNode(Opc, DL, XLenVT, Op.getOperand(1), Op.getOperand(2));
8578   }
8579   case Intrinsic::experimental_get_vector_length:
8580     return lowerGetVectorLength(Op.getNode(), DAG, Subtarget);
8581   case Intrinsic::riscv_vmv_x_s: {
8582     SDValue Res = DAG.getNode(RISCVISD::VMV_X_S, DL, XLenVT, Op.getOperand(1));
8583     return DAG.getNode(ISD::TRUNCATE, DL, Op.getValueType(), Res);
8584   }
8585   case Intrinsic::riscv_vfmv_f_s:
8586     return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, Op.getValueType(),
8587                        Op.getOperand(1), DAG.getConstant(0, DL, XLenVT));
8588   case Intrinsic::riscv_vmv_v_x:
8589     return lowerScalarSplat(Op.getOperand(1), Op.getOperand(2),
8590                             Op.getOperand(3), Op.getSimpleValueType(), DL, DAG,
8591                             Subtarget);
8592   case Intrinsic::riscv_vfmv_v_f:
8593     return DAG.getNode(RISCVISD::VFMV_V_F_VL, DL, Op.getValueType(),
8594                        Op.getOperand(1), Op.getOperand(2), Op.getOperand(3));
8595   case Intrinsic::riscv_vmv_s_x: {
8596     SDValue Scalar = Op.getOperand(2);
8597 
8598     if (Scalar.getValueType().bitsLE(XLenVT)) {
8599       Scalar = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Scalar);
8600       return DAG.getNode(RISCVISD::VMV_S_X_VL, DL, Op.getValueType(),
8601                          Op.getOperand(1), Scalar, Op.getOperand(3));
8602     }
8603 
8604     assert(Scalar.getValueType() == MVT::i64 && "Unexpected scalar VT!");
8605 
8606     // This is an i64 value that lives in two scalar registers. We have to
8607     // insert this in a convoluted way. First we build a vXi64 splat containing
8608     // the two values, which we assemble using some bit math. Next we'll use
8609     // vid.v and vmseq to build a mask with bit 0 set. Then we'll use that mask
8610     // to merge element 0 from our splat into the source vector.
8611     // FIXME: This is probably not the best way to do this, but it is
8612     // consistent with INSERT_VECTOR_ELT lowering so it is a good starting
8613     // point.
8614     //   sw lo, (a0)
8615     //   sw hi, 4(a0)
8616     //   vlse vX, (a0)
8617     //
8618     //   vid.v      vVid
8619     //   vmseq.vx   mMask, vVid, 0
8620     //   vmerge.vvm vDest, vSrc, vVal, mMask
8621     MVT VT = Op.getSimpleValueType();
8622     SDValue Vec = Op.getOperand(1);
8623     SDValue VL = getVLOperand(Op);
8624 
8625     SDValue SplattedVal = splatSplitI64WithVL(DL, VT, SDValue(), Scalar, VL, DAG);
8626     if (Op.getOperand(1).isUndef())
8627       return SplattedVal;
8628     SDValue SplattedIdx =
8629         DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, DAG.getUNDEF(VT),
8630                     DAG.getConstant(0, DL, MVT::i32), VL);
8631 
8632     MVT MaskVT = getMaskTypeFor(VT);
8633     SDValue Mask = getAllOnesMask(VT, VL, DL, DAG);
8634     SDValue VID = DAG.getNode(RISCVISD::VID_VL, DL, VT, Mask, VL);
8635     SDValue SelectCond =
8636         DAG.getNode(RISCVISD::SETCC_VL, DL, MaskVT,
8637                     {VID, SplattedIdx, DAG.getCondCode(ISD::SETEQ),
8638                      DAG.getUNDEF(MaskVT), Mask, VL});
8639     return DAG.getNode(RISCVISD::VMERGE_VL, DL, VT, SelectCond, SplattedVal,
8640                        Vec, DAG.getUNDEF(VT), VL);
8641   }
8642   case Intrinsic::riscv_vfmv_s_f:
8643     return DAG.getNode(RISCVISD::VFMV_S_F_VL, DL, Op.getSimpleValueType(),
8644                        Op.getOperand(1), Op.getOperand(2), Op.getOperand(3));
8645   // EGS * EEW >= 128 bits
8646   case Intrinsic::riscv_vaesdf_vv:
8647   case Intrinsic::riscv_vaesdf_vs:
8648   case Intrinsic::riscv_vaesdm_vv:
8649   case Intrinsic::riscv_vaesdm_vs:
8650   case Intrinsic::riscv_vaesef_vv:
8651   case Intrinsic::riscv_vaesef_vs:
8652   case Intrinsic::riscv_vaesem_vv:
8653   case Intrinsic::riscv_vaesem_vs:
8654   case Intrinsic::riscv_vaeskf1:
8655   case Intrinsic::riscv_vaeskf2:
8656   case Intrinsic::riscv_vaesz_vs:
8657   case Intrinsic::riscv_vsm4k:
8658   case Intrinsic::riscv_vsm4r_vv:
8659   case Intrinsic::riscv_vsm4r_vs: {
8660     if (!isValidEGW(4, Op.getSimpleValueType(), Subtarget) ||
8661         !isValidEGW(4, Op->getOperand(1).getSimpleValueType(), Subtarget) ||
8662         !isValidEGW(4, Op->getOperand(2).getSimpleValueType(), Subtarget))
8663       report_fatal_error("EGW should be greater than or equal to 4 * SEW.");
8664     return Op;
8665   }
8666   // EGS * EEW >= 256 bits
8667   case Intrinsic::riscv_vsm3c:
8668   case Intrinsic::riscv_vsm3me: {
8669     if (!isValidEGW(8, Op.getSimpleValueType(), Subtarget) ||
8670         !isValidEGW(8, Op->getOperand(1).getSimpleValueType(), Subtarget))
8671       report_fatal_error("EGW should be greater than or equal to 8 * SEW.");
8672     return Op;
8673   }
8674   // zvknha(SEW=32)/zvknhb(SEW=[32|64])
8675   case Intrinsic::riscv_vsha2ch:
8676   case Intrinsic::riscv_vsha2cl:
8677   case Intrinsic::riscv_vsha2ms: {
8678     if (Op->getSimpleValueType(0).getScalarSizeInBits() == 64 &&
8679         !Subtarget.hasStdExtZvknhb())
8680       report_fatal_error("SEW=64 needs Zvknhb to be enabled.");
8681     if (!isValidEGW(4, Op.getSimpleValueType(), Subtarget) ||
8682         !isValidEGW(4, Op->getOperand(1).getSimpleValueType(), Subtarget) ||
8683         !isValidEGW(4, Op->getOperand(2).getSimpleValueType(), Subtarget))
8684       report_fatal_error("EGW should be greater than or equal to 4 * SEW.");
8685     return Op;
8686   }
8687   case Intrinsic::riscv_sf_vc_v_x:
8688   case Intrinsic::riscv_sf_vc_v_i:
8689   case Intrinsic::riscv_sf_vc_v_xv:
8690   case Intrinsic::riscv_sf_vc_v_iv:
8691   case Intrinsic::riscv_sf_vc_v_vv:
8692   case Intrinsic::riscv_sf_vc_v_fv:
8693   case Intrinsic::riscv_sf_vc_v_xvv:
8694   case Intrinsic::riscv_sf_vc_v_ivv:
8695   case Intrinsic::riscv_sf_vc_v_vvv:
8696   case Intrinsic::riscv_sf_vc_v_fvv:
8697   case Intrinsic::riscv_sf_vc_v_xvw:
8698   case Intrinsic::riscv_sf_vc_v_ivw:
8699   case Intrinsic::riscv_sf_vc_v_vvw:
8700   case Intrinsic::riscv_sf_vc_v_fvw: {
8701     MVT VT = Op.getSimpleValueType();
8702 
8703     SmallVector<SDValue> Ops;
8704     getVCIXOperands(Op, DAG, Ops);
8705 
8706     MVT RetVT = VT;
8707     if (VT.isFixedLengthVector())
8708       RetVT = getContainerForFixedLengthVector(VT);
8709     else if (VT.isFloatingPoint())
8710       RetVT = MVT::getVectorVT(MVT::getIntegerVT(VT.getScalarSizeInBits()),
8711                                VT.getVectorElementCount());
8712 
8713     SDValue NewNode = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, RetVT, Ops);
8714 
8715     if (VT.isFixedLengthVector())
8716       NewNode = convertFromScalableVector(VT, NewNode, DAG, Subtarget);
8717     else if (VT.isFloatingPoint())
8718       NewNode = DAG.getBitcast(VT, NewNode);
8719 
8720     if (Op == NewNode)
8721       break;
8722 
8723     return NewNode;
8724   }
8725   }
8726 
8727   return lowerVectorIntrinsicScalars(Op, DAG, Subtarget);
8728 }
8729 
8730 SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
8731                                                     SelectionDAG &DAG) const {
8732   unsigned IntNo = Op.getConstantOperandVal(1);
8733   switch (IntNo) {
8734   default:
8735     break;
8736   case Intrinsic::riscv_masked_strided_load: {
8737     SDLoc DL(Op);
8738     MVT XLenVT = Subtarget.getXLenVT();
8739 
8740     // If the mask is known to be all ones, optimize to an unmasked intrinsic;
8741     // the selection of the masked intrinsics doesn't do this for us.
8742     SDValue Mask = Op.getOperand(5);
8743     bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
8744 
8745     MVT VT = Op->getSimpleValueType(0);
8746     MVT ContainerVT = VT;
8747     if (VT.isFixedLengthVector())
8748       ContainerVT = getContainerForFixedLengthVector(VT);
8749 
8750     SDValue PassThru = Op.getOperand(2);
8751     if (!IsUnmasked) {
8752       MVT MaskVT = getMaskTypeFor(ContainerVT);
8753       if (VT.isFixedLengthVector()) {
8754         Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
8755         PassThru = convertToScalableVector(ContainerVT, PassThru, DAG, Subtarget);
8756       }
8757     }
8758 
8759     auto *Load = cast<MemIntrinsicSDNode>(Op);
8760     SDValue VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
8761     SDValue Ptr = Op.getOperand(3);
8762     SDValue Stride = Op.getOperand(4);
8763     SDValue Result, Chain;
8764 
8765     // TODO: We restrict this to unmasked loads currently in consideration of
8766     // the complexity of handling all-false masks.
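    // A zero-stride, unmasked strided load reads the same address for every
    // element, so it can be lowered as a single scalar load whose value is
    // then splatted across the vector.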
8767     if (IsUnmasked && isNullConstant(Stride)) {
8768       MVT ScalarVT = ContainerVT.getVectorElementType();
8769       SDValue ScalarLoad =
8770           DAG.getExtLoad(ISD::ZEXTLOAD, DL, XLenVT, Load->getChain(), Ptr,
8771                          ScalarVT, Load->getMemOperand());
8772       Chain = ScalarLoad.getValue(1);
8773       Result = lowerScalarSplat(SDValue(), ScalarLoad, VL, ContainerVT, DL, DAG,
8774                                 Subtarget);
8775     } else {
8776       SDValue IntID = DAG.getTargetConstant(
8777           IsUnmasked ? Intrinsic::riscv_vlse : Intrinsic::riscv_vlse_mask, DL,
8778           XLenVT);
8779 
8780       SmallVector<SDValue, 8> Ops{Load->getChain(), IntID};
8781       if (IsUnmasked)
8782         Ops.push_back(DAG.getUNDEF(ContainerVT));
8783       else
8784         Ops.push_back(PassThru);
8785       Ops.push_back(Ptr);
8786       Ops.push_back(Stride);
8787       if (!IsUnmasked)
8788         Ops.push_back(Mask);
8789       Ops.push_back(VL);
8790       if (!IsUnmasked) {
8791         SDValue Policy =
8792             DAG.getTargetConstant(RISCVII::TAIL_AGNOSTIC, DL, XLenVT);
8793         Ops.push_back(Policy);
8794       }
8795 
8796       SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
8797       Result =
8798           DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops,
8799                                   Load->getMemoryVT(), Load->getMemOperand());
8800       Chain = Result.getValue(1);
8801     }
8802     if (VT.isFixedLengthVector())
8803       Result = convertFromScalableVector(VT, Result, DAG, Subtarget);
8804     return DAG.getMergeValues({Result, Chain}, DL);
8805   }
8806   case Intrinsic::riscv_seg2_load:
8807   case Intrinsic::riscv_seg3_load:
8808   case Intrinsic::riscv_seg4_load:
8809   case Intrinsic::riscv_seg5_load:
8810   case Intrinsic::riscv_seg6_load:
8811   case Intrinsic::riscv_seg7_load:
8812   case Intrinsic::riscv_seg8_load: {
8813     SDLoc DL(Op);
8814     static const Intrinsic::ID VlsegInts[7] = {
8815         Intrinsic::riscv_vlseg2, Intrinsic::riscv_vlseg3,
8816         Intrinsic::riscv_vlseg4, Intrinsic::riscv_vlseg5,
8817         Intrinsic::riscv_vlseg6, Intrinsic::riscv_vlseg7,
8818         Intrinsic::riscv_vlseg8};
8819     unsigned NF = Op->getNumValues() - 1;
8820     assert(NF >= 2 && NF <= 8 && "Unexpected seg number");
8821     MVT XLenVT = Subtarget.getXLenVT();
8822     MVT VT = Op->getSimpleValueType(0);
8823     MVT ContainerVT = getContainerForFixedLengthVector(VT);
8824 
8825     SDValue VL = getVLOp(VT.getVectorNumElements(), ContainerVT, DL, DAG,
8826                          Subtarget);
8827     SDValue IntID = DAG.getTargetConstant(VlsegInts[NF - 2], DL, XLenVT);
8828     auto *Load = cast<MemIntrinsicSDNode>(Op);
8829     SmallVector<EVT, 9> ContainerVTs(NF, ContainerVT);
8830     ContainerVTs.push_back(MVT::Other);
8831     SDVTList VTs = DAG.getVTList(ContainerVTs);
8832     SmallVector<SDValue, 12> Ops = {Load->getChain(), IntID};
8833     Ops.insert(Ops.end(), NF, DAG.getUNDEF(ContainerVT));
8834     Ops.push_back(Op.getOperand(2));
8835     Ops.push_back(VL);
8836     SDValue Result =
8837         DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops,
8838                                 Load->getMemoryVT(), Load->getMemOperand());
8839     SmallVector<SDValue, 9> Results;
8840     for (unsigned int RetIdx = 0; RetIdx < NF; RetIdx++)
8841       Results.push_back(convertFromScalableVector(VT, Result.getValue(RetIdx),
8842                                                   DAG, Subtarget));
8843     Results.push_back(Result.getValue(NF));
8844     return DAG.getMergeValues(Results, DL);
8845   }
8846   case Intrinsic::riscv_sf_vc_v_x_se:
8847   case Intrinsic::riscv_sf_vc_v_i_se:
8848   case Intrinsic::riscv_sf_vc_v_xv_se:
8849   case Intrinsic::riscv_sf_vc_v_iv_se:
8850   case Intrinsic::riscv_sf_vc_v_vv_se:
8851   case Intrinsic::riscv_sf_vc_v_fv_se:
8852   case Intrinsic::riscv_sf_vc_v_xvv_se:
8853   case Intrinsic::riscv_sf_vc_v_ivv_se:
8854   case Intrinsic::riscv_sf_vc_v_vvv_se:
8855   case Intrinsic::riscv_sf_vc_v_fvv_se:
8856   case Intrinsic::riscv_sf_vc_v_xvw_se:
8857   case Intrinsic::riscv_sf_vc_v_ivw_se:
8858   case Intrinsic::riscv_sf_vc_v_vvw_se:
8859   case Intrinsic::riscv_sf_vc_v_fvw_se: {
8860     MVT VT = Op.getSimpleValueType();
8861     SDLoc DL(Op);
8862     SmallVector<SDValue> Ops;
8863     getVCIXOperands(Op, DAG, Ops);
8864 
8865     MVT RetVT = VT;
8866     if (VT.isFixedLengthVector())
8867       RetVT = getContainerForFixedLengthVector(VT);
8868     else if (VT.isFloatingPoint())
8869       RetVT = MVT::getVectorVT(MVT::getIntegerVT(RetVT.getScalarSizeInBits()),
8870                                RetVT.getVectorElementCount());
8871 
8872     SDVTList VTs = DAG.getVTList({RetVT, MVT::Other});
8873     SDValue NewNode = DAG.getNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops);
8874 
8875     if (VT.isFixedLengthVector()) {
8876       SDValue FixedVector =
8877           convertFromScalableVector(VT, NewNode, DAG, Subtarget);
8878       NewNode = DAG.getMergeValues({FixedVector, NewNode.getValue(1)}, DL);
8879     } else if (VT.isFloatingPoint()) {
8880       SDValue BitCast = DAG.getBitcast(VT, NewNode.getValue(0));
8881       NewNode = DAG.getMergeValues({BitCast, NewNode.getValue(1)}, DL);
8882     }
8883 
8884     if (Op == NewNode)
8885       break;
8886 
8887     return NewNode;
8888   }
8889   }
8890 
8891   return lowerVectorIntrinsicScalars(Op, DAG, Subtarget);
8892 }
8893 
8894 SDValue RISCVTargetLowering::LowerINTRINSIC_VOID(SDValue Op,
8895                                                  SelectionDAG &DAG) const {
8896   unsigned IntNo = Op.getConstantOperandVal(1);
8897   switch (IntNo) {
8898   default:
8899     break;
8900   case Intrinsic::riscv_masked_strided_store: {
8901     SDLoc DL(Op);
8902     MVT XLenVT = Subtarget.getXLenVT();
8903 
8904     // If the mask is known to be all ones, optimize to an unmasked intrinsic;
8905     // the selection of the masked intrinsics doesn't do this for us.
8906     SDValue Mask = Op.getOperand(5);
8907     bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
8908 
8909     SDValue Val = Op.getOperand(2);
8910     MVT VT = Val.getSimpleValueType();
8911     MVT ContainerVT = VT;
8912     if (VT.isFixedLengthVector()) {
8913       ContainerVT = getContainerForFixedLengthVector(VT);
8914       Val = convertToScalableVector(ContainerVT, Val, DAG, Subtarget);
8915     }
8916     if (!IsUnmasked) {
8917       MVT MaskVT = getMaskTypeFor(ContainerVT);
8918       if (VT.isFixedLengthVector())
8919         Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
8920     }
8921 
8922     SDValue VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
8923 
8924     SDValue IntID = DAG.getTargetConstant(
8925         IsUnmasked ? Intrinsic::riscv_vsse : Intrinsic::riscv_vsse_mask, DL,
8926         XLenVT);
8927 
8928     auto *Store = cast<MemIntrinsicSDNode>(Op);
8929     SmallVector<SDValue, 8> Ops{Store->getChain(), IntID};
8930     Ops.push_back(Val);
8931     Ops.push_back(Op.getOperand(3)); // Ptr
8932     Ops.push_back(Op.getOperand(4)); // Stride
8933     if (!IsUnmasked)
8934       Ops.push_back(Mask);
8935     Ops.push_back(VL);
8936 
8937     return DAG.getMemIntrinsicNode(ISD::INTRINSIC_VOID, DL, Store->getVTList(),
8938                                    Ops, Store->getMemoryVT(),
8939                                    Store->getMemOperand());
8940   }
8941   case Intrinsic::riscv_seg2_store:
8942   case Intrinsic::riscv_seg3_store:
8943   case Intrinsic::riscv_seg4_store:
8944   case Intrinsic::riscv_seg5_store:
8945   case Intrinsic::riscv_seg6_store:
8946   case Intrinsic::riscv_seg7_store:
8947   case Intrinsic::riscv_seg8_store: {
8948     SDLoc DL(Op);
8949     static const Intrinsic::ID VssegInts[] = {
8950         Intrinsic::riscv_vsseg2, Intrinsic::riscv_vsseg3,
8951         Intrinsic::riscv_vsseg4, Intrinsic::riscv_vsseg5,
8952         Intrinsic::riscv_vsseg6, Intrinsic::riscv_vsseg7,
8953         Intrinsic::riscv_vsseg8};
8954     // Operands are (chain, int_id, vec*, ptr, vl)
8955     unsigned NF = Op->getNumOperands() - 4;
8956     assert(NF >= 2 && NF <= 8 && "Unexpected seg number");
8957     MVT XLenVT = Subtarget.getXLenVT();
8958     MVT VT = Op->getOperand(2).getSimpleValueType();
8959     MVT ContainerVT = getContainerForFixedLengthVector(VT);
8960 
8961     SDValue VL = getVLOp(VT.getVectorNumElements(), ContainerVT, DL, DAG,
8962                          Subtarget);
8963     SDValue IntID = DAG.getTargetConstant(VssegInts[NF - 2], DL, XLenVT);
8964     SDValue Ptr = Op->getOperand(NF + 2);
8965 
8966     auto *FixedIntrinsic = cast<MemIntrinsicSDNode>(Op);
8967     SmallVector<SDValue, 12> Ops = {FixedIntrinsic->getChain(), IntID};
8968     for (unsigned i = 0; i < NF; i++)
8969       Ops.push_back(convertToScalableVector(
8970           ContainerVT, FixedIntrinsic->getOperand(2 + i), DAG, Subtarget));
8971     Ops.append({Ptr, VL});
8972 
8973     return DAG.getMemIntrinsicNode(
8974         ISD::INTRINSIC_VOID, DL, DAG.getVTList(MVT::Other), Ops,
8975         FixedIntrinsic->getMemoryVT(), FixedIntrinsic->getMemOperand());
8976   }
8977   case Intrinsic::riscv_sf_vc_x_se_e8mf8:
8978   case Intrinsic::riscv_sf_vc_x_se_e8mf4:
8979   case Intrinsic::riscv_sf_vc_x_se_e8mf2:
8980   case Intrinsic::riscv_sf_vc_x_se_e8m1:
8981   case Intrinsic::riscv_sf_vc_x_se_e8m2:
8982   case Intrinsic::riscv_sf_vc_x_se_e8m4:
8983   case Intrinsic::riscv_sf_vc_x_se_e8m8:
8984   case Intrinsic::riscv_sf_vc_x_se_e16mf4:
8985   case Intrinsic::riscv_sf_vc_x_se_e16mf2:
8986   case Intrinsic::riscv_sf_vc_x_se_e16m1:
8987   case Intrinsic::riscv_sf_vc_x_se_e16m2:
8988   case Intrinsic::riscv_sf_vc_x_se_e16m4:
8989   case Intrinsic::riscv_sf_vc_x_se_e16m8:
8990   case Intrinsic::riscv_sf_vc_x_se_e32mf2:
8991   case Intrinsic::riscv_sf_vc_x_se_e32m1:
8992   case Intrinsic::riscv_sf_vc_x_se_e32m2:
8993   case Intrinsic::riscv_sf_vc_x_se_e32m4:
8994   case Intrinsic::riscv_sf_vc_x_se_e32m8:
8995   case Intrinsic::riscv_sf_vc_x_se_e64m1:
8996   case Intrinsic::riscv_sf_vc_x_se_e64m2:
8997   case Intrinsic::riscv_sf_vc_x_se_e64m4:
8998   case Intrinsic::riscv_sf_vc_x_se_e64m8:
8999   case Intrinsic::riscv_sf_vc_i_se_e8mf8:
9000   case Intrinsic::riscv_sf_vc_i_se_e8mf4:
9001   case Intrinsic::riscv_sf_vc_i_se_e8mf2:
9002   case Intrinsic::riscv_sf_vc_i_se_e8m1:
9003   case Intrinsic::riscv_sf_vc_i_se_e8m2:
9004   case Intrinsic::riscv_sf_vc_i_se_e8m4:
9005   case Intrinsic::riscv_sf_vc_i_se_e8m8:
9006   case Intrinsic::riscv_sf_vc_i_se_e16mf4:
9007   case Intrinsic::riscv_sf_vc_i_se_e16mf2:
9008   case Intrinsic::riscv_sf_vc_i_se_e16m1:
9009   case Intrinsic::riscv_sf_vc_i_se_e16m2:
9010   case Intrinsic::riscv_sf_vc_i_se_e16m4:
9011   case Intrinsic::riscv_sf_vc_i_se_e16m8:
9012   case Intrinsic::riscv_sf_vc_i_se_e32mf2:
9013   case Intrinsic::riscv_sf_vc_i_se_e32m1:
9014   case Intrinsic::riscv_sf_vc_i_se_e32m2:
9015   case Intrinsic::riscv_sf_vc_i_se_e32m4:
9016   case Intrinsic::riscv_sf_vc_i_se_e32m8:
9017   case Intrinsic::riscv_sf_vc_i_se_e64m1:
9018   case Intrinsic::riscv_sf_vc_i_se_e64m2:
9019   case Intrinsic::riscv_sf_vc_i_se_e64m4:
9020   case Intrinsic::riscv_sf_vc_i_se_e64m8:
9021   case Intrinsic::riscv_sf_vc_xv_se:
9022   case Intrinsic::riscv_sf_vc_iv_se:
9023   case Intrinsic::riscv_sf_vc_vv_se:
9024   case Intrinsic::riscv_sf_vc_fv_se:
9025   case Intrinsic::riscv_sf_vc_xvv_se:
9026   case Intrinsic::riscv_sf_vc_ivv_se:
9027   case Intrinsic::riscv_sf_vc_vvv_se:
9028   case Intrinsic::riscv_sf_vc_fvv_se:
9029   case Intrinsic::riscv_sf_vc_xvw_se:
9030   case Intrinsic::riscv_sf_vc_ivw_se:
9031   case Intrinsic::riscv_sf_vc_vvw_se:
9032   case Intrinsic::riscv_sf_vc_fvw_se: {
9033     SmallVector<SDValue> Ops;
9034     getVCIXOperands(Op, DAG, Ops);
9035 
9036     SDValue NewNode =
9037         DAG.getNode(ISD::INTRINSIC_VOID, SDLoc(Op), Op->getVTList(), Ops);
9038 
9039     if (Op == NewNode)
9040       break;
9041 
9042     return NewNode;
9043   }
9044   }
9045 
9046   return lowerVectorIntrinsicScalars(Op, DAG, Subtarget);
9047 }
9048 
9049 static unsigned getRVVReductionOp(unsigned ISDOpcode) {
9050   switch (ISDOpcode) {
9051   default:
9052     llvm_unreachable("Unhandled reduction");
9053   case ISD::VP_REDUCE_ADD:
9054   case ISD::VECREDUCE_ADD:
9055     return RISCVISD::VECREDUCE_ADD_VL;
9056   case ISD::VP_REDUCE_UMAX:
9057   case ISD::VECREDUCE_UMAX:
9058     return RISCVISD::VECREDUCE_UMAX_VL;
9059   case ISD::VP_REDUCE_SMAX:
9060   case ISD::VECREDUCE_SMAX:
9061     return RISCVISD::VECREDUCE_SMAX_VL;
9062   case ISD::VP_REDUCE_UMIN:
9063   case ISD::VECREDUCE_UMIN:
9064     return RISCVISD::VECREDUCE_UMIN_VL;
9065   case ISD::VP_REDUCE_SMIN:
9066   case ISD::VECREDUCE_SMIN:
9067     return RISCVISD::VECREDUCE_SMIN_VL;
9068   case ISD::VP_REDUCE_AND:
9069   case ISD::VECREDUCE_AND:
9070     return RISCVISD::VECREDUCE_AND_VL;
9071   case ISD::VP_REDUCE_OR:
9072   case ISD::VECREDUCE_OR:
9073     return RISCVISD::VECREDUCE_OR_VL;
9074   case ISD::VP_REDUCE_XOR:
9075   case ISD::VECREDUCE_XOR:
9076     return RISCVISD::VECREDUCE_XOR_VL;
9077   case ISD::VP_REDUCE_FADD:
9078     return RISCVISD::VECREDUCE_FADD_VL;
9079   case ISD::VP_REDUCE_SEQ_FADD:
9080     return RISCVISD::VECREDUCE_SEQ_FADD_VL;
9081   case ISD::VP_REDUCE_FMAX:
9082     return RISCVISD::VECREDUCE_FMAX_VL;
9083   case ISD::VP_REDUCE_FMIN:
9084     return RISCVISD::VECREDUCE_FMIN_VL;
9085   }
9087 }
9088 
9089 SDValue RISCVTargetLowering::lowerVectorMaskVecReduction(SDValue Op,
9090                                                          SelectionDAG &DAG,
9091                                                          bool IsVP) const {
9092   SDLoc DL(Op);
9093   SDValue Vec = Op.getOperand(IsVP ? 1 : 0);
9094   MVT VecVT = Vec.getSimpleValueType();
9095   assert((Op.getOpcode() == ISD::VECREDUCE_AND ||
9096           Op.getOpcode() == ISD::VECREDUCE_OR ||
9097           Op.getOpcode() == ISD::VECREDUCE_XOR ||
9098           Op.getOpcode() == ISD::VP_REDUCE_AND ||
9099           Op.getOpcode() == ISD::VP_REDUCE_OR ||
9100           Op.getOpcode() == ISD::VP_REDUCE_XOR) &&
9101          "Unexpected reduction lowering");
9102 
9103   MVT XLenVT = Subtarget.getXLenVT();
9104 
9105   MVT ContainerVT = VecVT;
9106   if (VecVT.isFixedLengthVector()) {
9107     ContainerVT = getContainerForFixedLengthVector(VecVT);
9108     Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
9109   }
9110 
9111   SDValue Mask, VL;
9112   if (IsVP) {
9113     Mask = Op.getOperand(2);
9114     VL = Op.getOperand(3);
9115   } else {
9116     std::tie(Mask, VL) =
9117         getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget);
9118   }
9119 
9120   unsigned BaseOpc;
9121   ISD::CondCode CC;
9122   SDValue Zero = DAG.getConstant(0, DL, XLenVT);
9123 
9124   switch (Op.getOpcode()) {
9125   default:
9126     llvm_unreachable("Unhandled reduction");
9127   case ISD::VECREDUCE_AND:
9128   case ISD::VP_REDUCE_AND: {
9129     // vcpop ~x == 0
9130     SDValue TrueMask = DAG.getNode(RISCVISD::VMSET_VL, DL, ContainerVT, VL);
9131     Vec = DAG.getNode(RISCVISD::VMXOR_VL, DL, ContainerVT, Vec, TrueMask, VL);
9132     Vec = DAG.getNode(RISCVISD::VCPOP_VL, DL, XLenVT, Vec, Mask, VL);
9133     CC = ISD::SETEQ;
9134     BaseOpc = ISD::AND;
9135     break;
9136   }
9137   case ISD::VECREDUCE_OR:
9138   case ISD::VP_REDUCE_OR:
9139     // vcpop x != 0
9140     Vec = DAG.getNode(RISCVISD::VCPOP_VL, DL, XLenVT, Vec, Mask, VL);
9141     CC = ISD::SETNE;
9142     BaseOpc = ISD::OR;
9143     break;
9144   case ISD::VECREDUCE_XOR:
9145   case ISD::VP_REDUCE_XOR: {
9146     // ((vcpop x) & 1) != 0
9147     SDValue One = DAG.getConstant(1, DL, XLenVT);
9148     Vec = DAG.getNode(RISCVISD::VCPOP_VL, DL, XLenVT, Vec, Mask, VL);
9149     Vec = DAG.getNode(ISD::AND, DL, XLenVT, Vec, One);
9150     CC = ISD::SETNE;
9151     BaseOpc = ISD::XOR;
9152     break;
9153   }
9154   }
9155 
9156   SDValue SetCC = DAG.getSetCC(DL, XLenVT, Vec, Zero, CC);
9157   SetCC = DAG.getNode(ISD::TRUNCATE, DL, Op.getValueType(), SetCC);
9158 
9159   if (!IsVP)
9160     return SetCC;
9161 
9162   // Now include the start value in the operation.
9163   // Note that we must return the start value when no elements are operated
9164   // upon. The vcpop instructions we've emitted in each case above return 0
9165   // when no elements are active, so we've already received the neutral value:
9166   // AND gives us (0 == 0) -> 1 and OR/XOR give us (0 != 0) -> 0. Therefore we
9167   // can simply include the start value.
9168   return DAG.getNode(BaseOpc, DL, Op.getValueType(), SetCC, Op.getOperand(0));
9169 }
9170 
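// An AVL is known to be non-zero if it is the X0 sentinel register (which
// requests VLMAX) or a constant greater than or equal to one.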
9171 static bool isNonZeroAVL(SDValue AVL) {
9172   auto *RegisterAVL = dyn_cast<RegisterSDNode>(AVL);
9173   auto *ImmAVL = dyn_cast<ConstantSDNode>(AVL);
9174   return (RegisterAVL && RegisterAVL->getReg() == RISCV::X0) ||
9175          (ImmAVL && ImmAVL->getZExtValue() >= 1);
9176 }
9177 
9178 /// Helper to lower a reduction sequence of the form:
9179 /// scalar = reduce_op vec, scalar_start
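/// The start value is inserted into an LMUL<=1 vector (widened to the LMUL1
/// type if needed), the RVV reduction instruction folds the whole input into
/// element 0 of an LMUL1 result, and element 0 is then extracted as the
/// scalar result.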
9180 static SDValue lowerReductionSeq(unsigned RVVOpcode, MVT ResVT,
9181                                  SDValue StartValue, SDValue Vec, SDValue Mask,
9182                                  SDValue VL, const SDLoc &DL, SelectionDAG &DAG,
9183                                  const RISCVSubtarget &Subtarget) {
9184   const MVT VecVT = Vec.getSimpleValueType();
9185   const MVT M1VT = getLMUL1VT(VecVT);
9186   const MVT XLenVT = Subtarget.getXLenVT();
9187   const bool NonZeroAVL = isNonZeroAVL(VL);
9188 
9189   // The reduction needs an LMUL1 input; do the splat at either LMUL1
9190   // or the original VT if fractional.
9191   auto InnerVT = VecVT.bitsLE(M1VT) ? VecVT : M1VT;
9192   // We reuse the VL of the reduction to reduce vsetvli toggles if we can
9193   // prove it is non-zero.  For the AVL=0 case, we need the scalar to
9194   // be the result of the reduction operation.
9195   auto InnerVL = NonZeroAVL ? VL : DAG.getConstant(1, DL, XLenVT);
9196   SDValue InitialValue = lowerScalarInsert(StartValue, InnerVL, InnerVT, DL,
9197                                            DAG, Subtarget);
9198   if (M1VT != InnerVT)
9199     InitialValue = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, M1VT,
9200                                DAG.getUNDEF(M1VT),
9201                                InitialValue, DAG.getConstant(0, DL, XLenVT));
9202   SDValue PassThru = NonZeroAVL ? DAG.getUNDEF(M1VT) : InitialValue;
9203   SDValue Policy = DAG.getTargetConstant(RISCVII::TAIL_AGNOSTIC, DL, XLenVT);
9204   SDValue Ops[] = {PassThru, Vec, InitialValue, Mask, VL, Policy};
9205   SDValue Reduction = DAG.getNode(RVVOpcode, DL, M1VT, Ops);
9206   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Reduction,
9207                      DAG.getConstant(0, DL, XLenVT));
9208 }
9209 
9210 SDValue RISCVTargetLowering::lowerVECREDUCE(SDValue Op,
9211                                             SelectionDAG &DAG) const {
9212   SDLoc DL(Op);
9213   SDValue Vec = Op.getOperand(0);
9214   EVT VecEVT = Vec.getValueType();
9215 
9216   unsigned BaseOpc = ISD::getVecReduceBaseOpcode(Op.getOpcode());
9217 
9218   // Due to ordering in legalize types we may have a vector type that needs to
9219   // be split. Do that manually so we can get down to a legal type.
9220   while (getTypeAction(*DAG.getContext(), VecEVT) ==
9221          TargetLowering::TypeSplitVector) {
9222     auto [Lo, Hi] = DAG.SplitVector(Vec, DL);
9223     VecEVT = Lo.getValueType();
9224     Vec = DAG.getNode(BaseOpc, DL, VecEVT, Lo, Hi);
9225   }
9226 
9227   // TODO: The type may need to be widened rather than split. Or widened before
9228   // it can be split.
9229   if (!isTypeLegal(VecEVT))
9230     return SDValue();
9231 
9232   MVT VecVT = VecEVT.getSimpleVT();
9233   MVT VecEltVT = VecVT.getVectorElementType();
9234   unsigned RVVOpcode = getRVVReductionOp(Op.getOpcode());
9235 
9236   MVT ContainerVT = VecVT;
9237   if (VecVT.isFixedLengthVector()) {
9238     ContainerVT = getContainerForFixedLengthVector(VecVT);
9239     Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
9240   }
9241 
9242   auto [Mask, VL] = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget);
9243 
9244   SDValue StartV = DAG.getNeutralElement(BaseOpc, DL, VecEltVT, SDNodeFlags());
9245   switch (BaseOpc) {
9246   case ISD::AND:
9247   case ISD::OR:
9248   case ISD::UMAX:
9249   case ISD::UMIN:
9250   case ISD::SMAX:
9251   case ISD::SMIN:
9252     MVT XLenVT = Subtarget.getXLenVT();
9253     StartV = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltVT, Vec,
9254                          DAG.getConstant(0, DL, XLenVT));
9255   }
9256   return lowerReductionSeq(RVVOpcode, Op.getSimpleValueType(), StartV, Vec,
9257                            Mask, VL, DL, DAG, Subtarget);
9258 }
9259 
9260 // Given a reduction op, this function returns the matching reduction opcode,
9261 // the vector SDValue and the scalar SDValue required to lower this to a
9262 // RISCVISD node.
9263 static std::tuple<unsigned, SDValue, SDValue>
9264 getRVVFPReductionOpAndOperands(SDValue Op, SelectionDAG &DAG, EVT EltVT,
9265                                const RISCVSubtarget &Subtarget) {
9266   SDLoc DL(Op);
9267   auto Flags = Op->getFlags();
9268   unsigned Opcode = Op.getOpcode();
9269   switch (Opcode) {
9270   default:
9271     llvm_unreachable("Unhandled reduction");
9272   case ISD::VECREDUCE_FADD: {
9273     // Use positive zero if we can. It is cheaper to materialize.
9274     SDValue Zero =
9275         DAG.getConstantFP(Flags.hasNoSignedZeros() ? 0.0 : -0.0, DL, EltVT);
9276     return std::make_tuple(RISCVISD::VECREDUCE_FADD_VL, Op.getOperand(0), Zero);
9277   }
9278   case ISD::VECREDUCE_SEQ_FADD:
9279     return std::make_tuple(RISCVISD::VECREDUCE_SEQ_FADD_VL, Op.getOperand(1),
9280                            Op.getOperand(0));
9281   case ISD::VECREDUCE_FMIN:
9282   case ISD::VECREDUCE_FMAX: {
9283     MVT XLenVT = Subtarget.getXLenVT();
9284     SDValue Front =
9285         DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Op.getOperand(0),
9286                     DAG.getConstant(0, DL, XLenVT));
9287     unsigned RVVOpc = (Opcode == ISD::VECREDUCE_FMIN)
9288                           ? RISCVISD::VECREDUCE_FMIN_VL
9289                           : RISCVISD::VECREDUCE_FMAX_VL;
9290     return std::make_tuple(RVVOpc, Op.getOperand(0), Front);
9291   }
9292   }
9293 }
9294 
9295 SDValue RISCVTargetLowering::lowerFPVECREDUCE(SDValue Op,
9296                                               SelectionDAG &DAG) const {
9297   SDLoc DL(Op);
9298   MVT VecEltVT = Op.getSimpleValueType();
9299 
9300   unsigned RVVOpcode;
9301   SDValue VectorVal, ScalarVal;
9302   std::tie(RVVOpcode, VectorVal, ScalarVal) =
9303       getRVVFPReductionOpAndOperands(Op, DAG, VecEltVT, Subtarget);
9304   MVT VecVT = VectorVal.getSimpleValueType();
9305 
9306   MVT ContainerVT = VecVT;
9307   if (VecVT.isFixedLengthVector()) {
9308     ContainerVT = getContainerForFixedLengthVector(VecVT);
9309     VectorVal = convertToScalableVector(ContainerVT, VectorVal, DAG, Subtarget);
9310   }
9311 
9312   auto [Mask, VL] = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget);
9313   return lowerReductionSeq(RVVOpcode, Op.getSimpleValueType(), ScalarVal,
9314                            VectorVal, Mask, VL, DL, DAG, Subtarget);
9315 }
9316 
9317 SDValue RISCVTargetLowering::lowerVPREDUCE(SDValue Op,
9318                                            SelectionDAG &DAG) const {
9319   SDLoc DL(Op);
9320   SDValue Vec = Op.getOperand(1);
9321   EVT VecEVT = Vec.getValueType();
9322 
9323   // TODO: The type may need to be widened rather than split. Or widened before
9324   // it can be split.
9325   if (!isTypeLegal(VecEVT))
9326     return SDValue();
9327 
9328   MVT VecVT = VecEVT.getSimpleVT();
9329   unsigned RVVOpcode = getRVVReductionOp(Op.getOpcode());
9330 
9331   if (VecVT.isFixedLengthVector()) {
9332     auto ContainerVT = getContainerForFixedLengthVector(VecVT);
9333     Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
9334   }
9335 
9336   SDValue VL = Op.getOperand(3);
9337   SDValue Mask = Op.getOperand(2);
9338   return lowerReductionSeq(RVVOpcode, Op.getSimpleValueType(), Op.getOperand(0),
9339                            Vec, Mask, VL, DL, DAG, Subtarget);
9340 }
9341 
9342 SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op,
9343                                                    SelectionDAG &DAG) const {
9344   SDValue Vec = Op.getOperand(0);
9345   SDValue SubVec = Op.getOperand(1);
9346   MVT VecVT = Vec.getSimpleValueType();
9347   MVT SubVecVT = SubVec.getSimpleValueType();
9348 
9349   SDLoc DL(Op);
9350   MVT XLenVT = Subtarget.getXLenVT();
9351   unsigned OrigIdx = Op.getConstantOperandVal(2);
9352   const RISCVRegisterInfo *TRI = Subtarget.getRegisterInfo();
9353 
9354   // We don't have the ability to slide mask vectors up indexed by their i1
9355   // elements; the smallest we can do is i8. Often we are able to bitcast to
9356   // equivalent i8 vectors. Note that when inserting a fixed-length vector
9357   // into a scalable one, we might not necessarily have enough scalable
9358   // elements to safely divide by 8: nxv1i1 = insert nxv1i1, v4i1 is valid.
9359   if (SubVecVT.getVectorElementType() == MVT::i1 &&
9360       (OrigIdx != 0 || !Vec.isUndef())) {
9361     if (VecVT.getVectorMinNumElements() >= 8 &&
9362         SubVecVT.getVectorMinNumElements() >= 8) {
9363       assert(OrigIdx % 8 == 0 && "Invalid index");
9364       assert(VecVT.getVectorMinNumElements() % 8 == 0 &&
9365              SubVecVT.getVectorMinNumElements() % 8 == 0 &&
9366              "Unexpected mask vector lowering");
9367       OrigIdx /= 8;
9368       SubVecVT =
9369           MVT::getVectorVT(MVT::i8, SubVecVT.getVectorMinNumElements() / 8,
9370                            SubVecVT.isScalableVector());
9371       VecVT = MVT::getVectorVT(MVT::i8, VecVT.getVectorMinNumElements() / 8,
9372                                VecVT.isScalableVector());
9373       Vec = DAG.getBitcast(VecVT, Vec);
9374       SubVec = DAG.getBitcast(SubVecVT, SubVec);
9375     } else {
9376       // We can't slide this mask vector up indexed by its i1 elements.
9377       // This poses a problem when we wish to insert a scalable vector which
9378       // can't be re-expressed as a larger type. Just choose the slow path and
9379       // extend to a larger type, then truncate back down.
9380       MVT ExtVecVT = VecVT.changeVectorElementType(MVT::i8);
9381       MVT ExtSubVecVT = SubVecVT.changeVectorElementType(MVT::i8);
9382       Vec = DAG.getNode(ISD::ZERO_EXTEND, DL, ExtVecVT, Vec);
9383       SubVec = DAG.getNode(ISD::ZERO_EXTEND, DL, ExtSubVecVT, SubVec);
9384       Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ExtVecVT, Vec, SubVec,
9385                         Op.getOperand(2));
9386       SDValue SplatZero = DAG.getConstant(0, DL, ExtVecVT);
9387       return DAG.getSetCC(DL, VecVT, Vec, SplatZero, ISD::SETNE);
9388     }
9389   }
9390 
9391   // If the subvector is a fixed-length type, we cannot use subregister
9392   // manipulation to simplify the codegen; we don't know which register of a
9393   // LMUL group contains the specific subvector as we only know the minimum
9394   // register size. Therefore we must slide the vector group up the full
9395   // amount.
9396   if (SubVecVT.isFixedLengthVector()) {
9397     if (OrigIdx == 0 && Vec.isUndef() && !VecVT.isFixedLengthVector())
9398       return Op;
9399     MVT ContainerVT = VecVT;
9400     if (VecVT.isFixedLengthVector()) {
9401       ContainerVT = getContainerForFixedLengthVector(VecVT);
9402       Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
9403     }
9404 
9405     if (OrigIdx == 0 && Vec.isUndef() && VecVT.isFixedLengthVector()) {
9406       SubVec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ContainerVT,
9407                            DAG.getUNDEF(ContainerVT), SubVec,
9408                            DAG.getConstant(0, DL, XLenVT));
9409       SubVec = convertFromScalableVector(VecVT, SubVec, DAG, Subtarget);
9410       return DAG.getBitcast(Op.getValueType(), SubVec);
9411     }
9412 
9413     SubVec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ContainerVT,
9414                          DAG.getUNDEF(ContainerVT), SubVec,
9415                          DAG.getConstant(0, DL, XLenVT));
9416     SDValue Mask =
9417         getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget).first;
9418     // Set the vector length to only the number of elements we care about. Note
9419     // that for slideup this includes the offset.
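    // For example (a sketch): inserting a v2i32 subvector at index 2 of a
    // v8i32 vector slides the subvector up by 2 with VL == 4, so only the
    // offset plus the inserted elements are written and the rest of Vec's
    // tail is left undisturbed.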
9420     unsigned EndIndex = OrigIdx + SubVecVT.getVectorNumElements();
9421     SDValue VL = getVLOp(EndIndex, ContainerVT, DL, DAG, Subtarget);
9422 
9423     // Use tail agnostic policy if we're inserting over Vec's tail.
9424     unsigned Policy = RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED;
9425     if (VecVT.isFixedLengthVector() && EndIndex == VecVT.getVectorNumElements())
9426       Policy = RISCVII::TAIL_AGNOSTIC;
9427 
9428     // If we're inserting into the lowest elements, use a tail undisturbed
9429     // vmv.v.v.
9430     if (OrigIdx == 0) {
9431       SubVec =
9432           DAG.getNode(RISCVISD::VMV_V_V_VL, DL, ContainerVT, Vec, SubVec, VL);
9433     } else {
9434       SDValue SlideupAmt = DAG.getConstant(OrigIdx, DL, XLenVT);
9435       SubVec = getVSlideup(DAG, Subtarget, DL, ContainerVT, Vec, SubVec,
9436                            SlideupAmt, Mask, VL, Policy);
9437     }
9438 
9439     if (VecVT.isFixedLengthVector())
9440       SubVec = convertFromScalableVector(VecVT, SubVec, DAG, Subtarget);
9441     return DAG.getBitcast(Op.getValueType(), SubVec);
9442   }
9443 
9444   unsigned SubRegIdx, RemIdx;
9445   std::tie(SubRegIdx, RemIdx) =
9446       RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
9447           VecVT, SubVecVT, OrigIdx, TRI);
9448 
9449   RISCVII::VLMUL SubVecLMUL = RISCVTargetLowering::getLMUL(SubVecVT);
9450   bool IsSubVecPartReg = SubVecLMUL == RISCVII::VLMUL::LMUL_F2 ||
9451                          SubVecLMUL == RISCVII::VLMUL::LMUL_F4 ||
9452                          SubVecLMUL == RISCVII::VLMUL::LMUL_F8;
9453 
9454   // 1. If the Idx has been completely eliminated and this subvector's size is
9455   // a vector register or a multiple thereof, or the surrounding elements are
9456   // undef, then this is a subvector insert which naturally aligns to a vector
9457   // register. These can easily be handled using subregister manipulation.
9458   // 2. If the subvector is smaller than a vector register, then the insertion
9459   // must preserve the undisturbed elements of the register. We do this by
9460   // lowering to an EXTRACT_SUBVECTOR grabbing the nearest LMUL=1 vector type
9461   // (which resolves to a subregister copy), performing a VSLIDEUP to place the
9462   // subvector within the vector register, and an INSERT_SUBVECTOR of that
9463   // LMUL=1 type back into the larger vector (resolving to another subregister
9464   // operation). See below for how our VSLIDEUP works. We go via a LMUL=1 type
9465   // to avoid allocating a large register group to hold our subvector.
9466   if (RemIdx == 0 && (!IsSubVecPartReg || Vec.isUndef()))
9467     return Op;
9468 
9469   // VSLIDEUP works by leaving elements 0 <= i < OFFSET undisturbed, setting
9470   // elements OFFSET <= i < VL to the "subvector", and applying the tail policy
9471   // (in our case undisturbed) to elements VL <= i < VLMAX. This means we can
9472   // set up a subvector insertion where OFFSET is the insertion offset, and the
9473   // VL is the OFFSET plus the size of the subvector.
9474   MVT InterSubVT = VecVT;
9475   SDValue AlignedExtract = Vec;
9476   unsigned AlignedIdx = OrigIdx - RemIdx;
9477   if (VecVT.bitsGT(getLMUL1VT(VecVT))) {
9478     InterSubVT = getLMUL1VT(VecVT);
9479     // Extract a subvector equal to the nearest full vector register type. This
9480     // should resolve to a EXTRACT_SUBREG instruction.
9481     AlignedExtract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InterSubVT, Vec,
9482                                  DAG.getConstant(AlignedIdx, DL, XLenVT));
9483   }
9484 
9485   SubVec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, InterSubVT,
9486                        DAG.getUNDEF(InterSubVT), SubVec,
9487                        DAG.getConstant(0, DL, XLenVT));
9488 
9489   auto [Mask, VL] = getDefaultScalableVLOps(VecVT, DL, DAG, Subtarget);
9490 
9491   VL = computeVLMax(SubVecVT, DL, DAG);
9492 
9493   // If we're inserting into the lowest elements, use a tail undisturbed
9494   // vmv.v.v.
9495   if (RemIdx == 0) {
9496     SubVec = DAG.getNode(RISCVISD::VMV_V_V_VL, DL, InterSubVT, AlignedExtract,
9497                          SubVec, VL);
9498   } else {
9499     SDValue SlideupAmt =
9500         DAG.getVScale(DL, XLenVT, APInt(XLenVT.getSizeInBits(), RemIdx));
9501 
9502     // Construct the vector length corresponding to RemIdx + length(SubVecVT).
9503     VL = DAG.getNode(ISD::ADD, DL, XLenVT, SlideupAmt, VL);
9504 
9505     SubVec = getVSlideup(DAG, Subtarget, DL, InterSubVT, AlignedExtract, SubVec,
9506                          SlideupAmt, Mask, VL);
9507   }
9508 
9509   // If required, insert this subvector back into the correct vector register.
9510   // This should resolve to an INSERT_SUBREG instruction.
9511   if (VecVT.bitsGT(InterSubVT))
9512     SubVec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VecVT, Vec, SubVec,
9513                          DAG.getConstant(AlignedIdx, DL, XLenVT));
9514 
9515   // We might have bitcast from a mask type: cast back to the original type if
9516   // required.
9517   return DAG.getBitcast(Op.getSimpleValueType(), SubVec);
9518 }
9519 
9520 SDValue RISCVTargetLowering::lowerEXTRACT_SUBVECTOR(SDValue Op,
9521                                                     SelectionDAG &DAG) const {
9522   SDValue Vec = Op.getOperand(0);
9523   MVT SubVecVT = Op.getSimpleValueType();
9524   MVT VecVT = Vec.getSimpleValueType();
9525 
9526   SDLoc DL(Op);
9527   MVT XLenVT = Subtarget.getXLenVT();
9528   unsigned OrigIdx = Op.getConstantOperandVal(1);
9529   const RISCVRegisterInfo *TRI = Subtarget.getRegisterInfo();
9530 
9531   // We don't have the ability to slide mask vectors down indexed by their i1
9532   // elements; the smallest we can do is i8. Often we are able to bitcast to
9533   // equivalent i8 vectors. Note that when extracting a fixed-length vector
9534   // from a scalable one, we might not necessarily have enough scalable
9535   // elements to safely divide by 8: v8i1 = extract nxv1i1 is valid.
9536   if (SubVecVT.getVectorElementType() == MVT::i1 && OrigIdx != 0) {
9537     if (VecVT.getVectorMinNumElements() >= 8 &&
9538         SubVecVT.getVectorMinNumElements() >= 8) {
9539       assert(OrigIdx % 8 == 0 && "Invalid index");
9540       assert(VecVT.getVectorMinNumElements() % 8 == 0 &&
9541              SubVecVT.getVectorMinNumElements() % 8 == 0 &&
9542              "Unexpected mask vector lowering");
9543       OrigIdx /= 8;
9544       SubVecVT =
9545           MVT::getVectorVT(MVT::i8, SubVecVT.getVectorMinNumElements() / 8,
9546                            SubVecVT.isScalableVector());
9547       VecVT = MVT::getVectorVT(MVT::i8, VecVT.getVectorMinNumElements() / 8,
9548                                VecVT.isScalableVector());
9549       Vec = DAG.getBitcast(VecVT, Vec);
9550     } else {
9551       // We can't slide this mask vector down, indexed by its i1 elements.
9552       // This poses a problem when we wish to extract a scalable vector which
9553       // can't be re-expressed as a larger type. Just choose the slow path and
9554       // extend to a larger type, then truncate back down.
9555       // TODO: We could probably improve this when extracting certain fixed
9556       // from fixed, where we can extract as i8 and shift the correct element
9557       // right to reach the desired subvector?
9558       MVT ExtVecVT = VecVT.changeVectorElementType(MVT::i8);
9559       MVT ExtSubVecVT = SubVecVT.changeVectorElementType(MVT::i8);
9560       Vec = DAG.getNode(ISD::ZERO_EXTEND, DL, ExtVecVT, Vec);
9561       Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtSubVecVT, Vec,
9562                         Op.getOperand(1));
9563       SDValue SplatZero = DAG.getConstant(0, DL, ExtSubVecVT);
9564       return DAG.getSetCC(DL, SubVecVT, Vec, SplatZero, ISD::SETNE);
9565     }
9566   }
9567 
9568   // With an index of 0 this is a cast-like subvector, which can be performed
9569   // with subregister operations.
9570   if (OrigIdx == 0)
9571     return Op;
9572 
9573   // If the subvector is a fixed-length type, we cannot use subregister
9574   // manipulation to simplify the codegen; we don't know which register of a
9575   // LMUL group contains the specific subvector as we only know the minimum
9576   // register size. Therefore we must slide the vector group down the full
9577   // amount.
9578   if (SubVecVT.isFixedLengthVector()) {
9579     MVT ContainerVT = VecVT;
9580     if (VecVT.isFixedLengthVector()) {
9581       ContainerVT = getContainerForFixedLengthVector(VecVT);
9582       Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
9583     }
9584 
9585     // Shrink down Vec so we're performing the slidedown on a smaller LMUL.
9586     unsigned LastIdx = OrigIdx + SubVecVT.getVectorNumElements() - 1;
9587     if (auto ShrunkVT =
9588             getSmallestVTForIndex(ContainerVT, LastIdx, DL, DAG, Subtarget)) {
9589       ContainerVT = *ShrunkVT;
9590       Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ContainerVT, Vec,
9591                         DAG.getVectorIdxConstant(0, DL));
9592     }
9593 
9594     SDValue Mask =
9595         getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget).first;
9596     // Set the vector length to only the number of elements we care about. This
9597     // avoids sliding down elements we're going to discard straight away.
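    // For example (a sketch): extracting a v2i32 subvector at index 2 of a
    // v8i32 vector slides the source down by 2 with VL == 2 and then performs
    // a cast-like extract of the low elements.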
9598     SDValue VL = getVLOp(SubVecVT.getVectorNumElements(), ContainerVT, DL, DAG,
9599                          Subtarget);
9600     SDValue SlidedownAmt = DAG.getConstant(OrigIdx, DL, XLenVT);
9601     SDValue Slidedown =
9602         getVSlidedown(DAG, Subtarget, DL, ContainerVT,
9603                       DAG.getUNDEF(ContainerVT), Vec, SlidedownAmt, Mask, VL);
9604     // Now we can use a cast-like subvector extract to get the result.
9605     Slidedown = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVecVT, Slidedown,
9606                             DAG.getConstant(0, DL, XLenVT));
9607     return DAG.getBitcast(Op.getValueType(), Slidedown);
9608   }
9609 
9610   unsigned SubRegIdx, RemIdx;
9611   std::tie(SubRegIdx, RemIdx) =
9612       RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
9613           VecVT, SubVecVT, OrigIdx, TRI);
9614 
9615   // If the Idx has been completely eliminated then this is a subvector extract
9616   // which naturally aligns to a vector register. These can easily be handled
9617   // using subregister manipulation.
9618   if (RemIdx == 0)
9619     return Op;
9620 
9621   // Else SubVecVT is a fractional LMUL and may need to be slid down.
9622   assert(RISCVVType::decodeVLMUL(getLMUL(SubVecVT)).second);
9623 
9624   // If the vector type is an LMUL-group type, extract a subvector equal to the
9625   // nearest full vector register type.
9626   MVT InterSubVT = VecVT;
9627   if (VecVT.bitsGT(getLMUL1VT(VecVT))) {
9628     // If VecVT has an LMUL > 1, then SubVecVT should have a smaller LMUL, and
9629     // we should have successfully decomposed the extract into a subregister.
9630     assert(SubRegIdx != RISCV::NoSubRegister);
9631     InterSubVT = getLMUL1VT(VecVT);
9632     Vec = DAG.getTargetExtractSubreg(SubRegIdx, DL, InterSubVT, Vec);
9633   }
9634 
9635   // Slide this vector register down by the desired number of elements in order
9636   // to place the desired subvector starting at element 0.
9637   SDValue SlidedownAmt =
9638       DAG.getVScale(DL, XLenVT, APInt(XLenVT.getSizeInBits(), RemIdx));
9639 
9640   auto [Mask, VL] = getDefaultScalableVLOps(InterSubVT, DL, DAG, Subtarget);
9641   SDValue Slidedown =
9642       getVSlidedown(DAG, Subtarget, DL, InterSubVT, DAG.getUNDEF(InterSubVT),
9643                     Vec, SlidedownAmt, Mask, VL);
9644 
9645   // Now the vector is in the right position, extract our final subvector. This
9646   // should resolve to a COPY.
9647   Slidedown = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVecVT, Slidedown,
9648                           DAG.getConstant(0, DL, XLenVT));
9649 
9650   // We might have bitcast from a mask type: cast back to the original type if
9651   // required.
9652   return DAG.getBitcast(Op.getSimpleValueType(), Slidedown);
9653 }
9654 
9655 // Widen a vector's operands to i8, then truncate its results back to the
9656 // original type, typically i1.  All operand and result types must be the same.
9657 static SDValue widenVectorOpsToi8(SDValue N, const SDLoc &DL,
9658                                   SelectionDAG &DAG) {
9659   MVT VT = N.getSimpleValueType();
9660   MVT WideVT = VT.changeVectorElementType(MVT::i8);
9661   SmallVector<SDValue, 4> WideOps;
9662   for (SDValue Op : N->ops()) {
9663     assert(Op.getSimpleValueType() == VT &&
9664            "Operands and result must be same type");
9665     WideOps.push_back(DAG.getNode(ISD::ZERO_EXTEND, DL, WideVT, Op));
9666   }
9667 
9668   unsigned NumVals = N->getNumValues();
9669 
9670   SDVTList VTs = DAG.getVTList(SmallVector<EVT, 4>(
9671       NumVals, N.getValueType().changeVectorElementType(MVT::i8)));
9672   SDValue WideN = DAG.getNode(N.getOpcode(), DL, VTs, WideOps);
9673   SmallVector<SDValue, 4> TruncVals;
9674   for (unsigned I = 0; I < NumVals; I++) {
9675     TruncVals.push_back(
9676         DAG.getSetCC(DL, N->getSimpleValueType(I), WideN.getValue(I),
9677                      DAG.getConstant(0, DL, WideVT), ISD::SETNE));
9678   }
9679 
9680   if (TruncVals.size() > 1)
9681     return DAG.getMergeValues(TruncVals, DL);
9682   return TruncVals.front();
9683 }
9684 
9685 SDValue RISCVTargetLowering::lowerVECTOR_DEINTERLEAVE(SDValue Op,
9686                                                       SelectionDAG &DAG) const {
9687   SDLoc DL(Op);
9688   MVT VecVT = Op.getSimpleValueType();
9689   MVT XLenVT = Subtarget.getXLenVT();
9690 
9691   assert(VecVT.isScalableVector() &&
9692          "vector_interleave on non-scalable vector!");
9693 
9694   // i1 element vectors need to be widened to i8
9695   if (VecVT.getVectorElementType() == MVT::i1)
9696     return widenVectorOpsToi8(Op, DL, DAG);
9697 
9698   // If the VT is LMUL=8, we need to split and reassemble.
9699   if (VecVT.getSizeInBits().getKnownMinValue() ==
9700       (8 * RISCV::RVVBitsPerBlock)) {
9701     auto [Op0Lo, Op0Hi] = DAG.SplitVectorOperand(Op.getNode(), 0);
9702     auto [Op1Lo, Op1Hi] = DAG.SplitVectorOperand(Op.getNode(), 1);
9703     EVT SplitVT = Op0Lo.getValueType();
9704 
9705     SDValue ResLo = DAG.getNode(ISD::VECTOR_DEINTERLEAVE, DL,
9706                                 DAG.getVTList(SplitVT, SplitVT), Op0Lo, Op0Hi);
9707     SDValue ResHi = DAG.getNode(ISD::VECTOR_DEINTERLEAVE, DL,
9708                                 DAG.getVTList(SplitVT, SplitVT), Op1Lo, Op1Hi);
9709 
9710     SDValue Even = DAG.getNode(ISD::CONCAT_VECTORS, DL, VecVT,
9711                                ResLo.getValue(0), ResHi.getValue(0));
9712     SDValue Odd = DAG.getNode(ISD::CONCAT_VECTORS, DL, VecVT, ResLo.getValue(1),
9713                               ResHi.getValue(1));
9714     return DAG.getMergeValues({Even, Odd}, DL);
9715   }
9716 
9717   // Concatenate the two vectors as one vector to deinterleave
9718   MVT ConcatVT =
9719       MVT::getVectorVT(VecVT.getVectorElementType(),
9720                        VecVT.getVectorElementCount().multiplyCoefficientBy(2));
9721   SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatVT,
9722                                Op.getOperand(0), Op.getOperand(1));
9723 
9724   // We want to operate on all lanes, so get the mask and VL for it
9725   auto [Mask, VL] = getDefaultScalableVLOps(ConcatVT, DL, DAG, Subtarget);
9726   SDValue Passthru = DAG.getUNDEF(ConcatVT);
9727 
9728   // We can deinterleave through vnsrl.wi if the element type is smaller than
9729   // ELEN
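  // (Roughly: reinterpret the concatenated vector as elements of twice the
  // width; a vnsrl.wi by 0 then yields the even elements and a vnsrl by SEW
  // yields the odd elements.)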
9730   if (VecVT.getScalarSizeInBits() < Subtarget.getELen()) {
9731     SDValue Even =
9732         getDeinterleaveViaVNSRL(DL, VecVT, Concat, true, Subtarget, DAG);
9733     SDValue Odd =
9734         getDeinterleaveViaVNSRL(DL, VecVT, Concat, false, Subtarget, DAG);
9735     return DAG.getMergeValues({Even, Odd}, DL);
9736   }
9737 
9738   // For the indices, use the same SEW to avoid an extra vsetvli
9739   MVT IdxVT = ConcatVT.changeVectorElementTypeToInteger();
9740   // Create a vector of even indices {0, 2, 4, ...}
9741   SDValue EvenIdx =
9742       DAG.getStepVector(DL, IdxVT, APInt(IdxVT.getScalarSizeInBits(), 2));
9743   // Create a vector of odd indices {1, 3, 5, ... }
9744   SDValue OddIdx =
9745       DAG.getNode(ISD::ADD, DL, IdxVT, EvenIdx, DAG.getConstant(1, DL, IdxVT));
9746 
9747   // Gather the even and odd elements into two separate vectors
9748   SDValue EvenWide = DAG.getNode(RISCVISD::VRGATHER_VV_VL, DL, ConcatVT,
9749                                  Concat, EvenIdx, Passthru, Mask, VL);
9750   SDValue OddWide = DAG.getNode(RISCVISD::VRGATHER_VV_VL, DL, ConcatVT,
9751                                 Concat, OddIdx, Passthru, Mask, VL);
9752 
9753   // Extract the result half of the gather for even and odd
9754   SDValue Even = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VecVT, EvenWide,
9755                              DAG.getConstant(0, DL, XLenVT));
9756   SDValue Odd = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VecVT, OddWide,
9757                             DAG.getConstant(0, DL, XLenVT));
9758 
9759   return DAG.getMergeValues({Even, Odd}, DL);
9760 }
9761 
SDValue RISCVTargetLowering::lowerVECTOR_INTERLEAVE(SDValue Op,
9763                                                     SelectionDAG &DAG) const {
9764   SDLoc DL(Op);
9765   MVT VecVT = Op.getSimpleValueType();
9766 
9767   assert(VecVT.isScalableVector() &&
9768          "vector_interleave on non-scalable vector!");
9769 
9770   // i1 vectors need to be widened to i8
9771   if (VecVT.getVectorElementType() == MVT::i1)
9772     return widenVectorOpsToi8(Op, DL, DAG);
9773 
9774   MVT XLenVT = Subtarget.getXLenVT();
9775   SDValue VL = DAG.getRegister(RISCV::X0, XLenVT);
9776 
9777   // If the VT is LMUL=8, we need to split and reassemble.
9778   if (VecVT.getSizeInBits().getKnownMinValue() == (8 * RISCV::RVVBitsPerBlock)) {
9779     auto [Op0Lo, Op0Hi] = DAG.SplitVectorOperand(Op.getNode(), 0);
9780     auto [Op1Lo, Op1Hi] = DAG.SplitVectorOperand(Op.getNode(), 1);
9781     EVT SplitVT = Op0Lo.getValueType();
9782 
9783     SDValue ResLo = DAG.getNode(ISD::VECTOR_INTERLEAVE, DL,
9784                                 DAG.getVTList(SplitVT, SplitVT), Op0Lo, Op1Lo);
9785     SDValue ResHi = DAG.getNode(ISD::VECTOR_INTERLEAVE, DL,
9786                                 DAG.getVTList(SplitVT, SplitVT), Op0Hi, Op1Hi);
9787 
9788     SDValue Lo = DAG.getNode(ISD::CONCAT_VECTORS, DL, VecVT,
9789                              ResLo.getValue(0), ResLo.getValue(1));
9790     SDValue Hi = DAG.getNode(ISD::CONCAT_VECTORS, DL, VecVT,
9791                              ResHi.getValue(0), ResHi.getValue(1));
9792     return DAG.getMergeValues({Lo, Hi}, DL);
9793   }
9794 
9795   SDValue Interleaved;
9796 
9797   // If the element type is smaller than ELEN, then we can interleave with
9798   // vwaddu.vv and vwmaccu.vx
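  // In essence, each pair (a, b) is packed into a 2*SEW-bit element as
  // zext(a) + 2^SEW * zext(b): vwaddu.vv produces zext(a) + zext(b) and
  // vwmaccu.vx then adds (2^SEW - 1) * zext(b).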
9799   if (VecVT.getScalarSizeInBits() < Subtarget.getELen()) {
9800     Interleaved = getWideningInterleave(Op.getOperand(0), Op.getOperand(1), DL,
9801                                         DAG, Subtarget);
9802   } else {
9803     // Otherwise, fallback to using vrgathere16.vv
9804     MVT ConcatVT =
9805       MVT::getVectorVT(VecVT.getVectorElementType(),
9806                        VecVT.getVectorElementCount().multiplyCoefficientBy(2));
9807     SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatVT,
9808                                  Op.getOperand(0), Op.getOperand(1));
9809 
9810     MVT IdxVT = ConcatVT.changeVectorElementType(MVT::i16);
9811 
9812     // 0 1 2 3 4 5 6 7 ...
9813     SDValue StepVec = DAG.getStepVector(DL, IdxVT);
9814 
9815     // 1 1 1 1 1 1 1 1 ...
9816     SDValue Ones = DAG.getSplatVector(IdxVT, DL, DAG.getConstant(1, DL, XLenVT));
9817 
9818     // 1 0 1 0 1 0 1 0 ...
9819     SDValue OddMask = DAG.getNode(ISD::AND, DL, IdxVT, StepVec, Ones);
9820     OddMask = DAG.getSetCC(
9821         DL, IdxVT.changeVectorElementType(MVT::i1), OddMask,
9822         DAG.getSplatVector(IdxVT, DL, DAG.getConstant(0, DL, XLenVT)),
9823         ISD::CondCode::SETNE);
9824 
9825     SDValue VLMax = DAG.getSplatVector(IdxVT, DL, computeVLMax(VecVT, DL, DAG));
9826 
9827     // Build up the index vector for interleaving the concatenated vector
9828     //      0      0      1      1      2      2      3      3 ...
9829     SDValue Idx = DAG.getNode(ISD::SRL, DL, IdxVT, StepVec, Ones);
9830     //      0      n      1    n+1      2    n+2      3    n+3 ...
9831     Idx =
9832         DAG.getNode(RISCVISD::ADD_VL, DL, IdxVT, Idx, VLMax, Idx, OddMask, VL);
9833 
9834     // Then perform the interleave
9835     //   v[0]   v[n]   v[1] v[n+1]   v[2] v[n+2]   v[3] v[n+3] ...
9836     SDValue TrueMask = getAllOnesMask(IdxVT, VL, DL, DAG);
9837     Interleaved = DAG.getNode(RISCVISD::VRGATHEREI16_VV_VL, DL, ConcatVT,
9838                               Concat, Idx, DAG.getUNDEF(ConcatVT), TrueMask, VL);
9839   }
9840 
9841   // Extract the two halves from the interleaved result
9842   SDValue Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VecVT, Interleaved,
9843                            DAG.getVectorIdxConstant(0, DL));
9844   SDValue Hi = DAG.getNode(
9845       ISD::EXTRACT_SUBVECTOR, DL, VecVT, Interleaved,
9846       DAG.getVectorIdxConstant(VecVT.getVectorMinNumElements(), DL));
9847 
9848   return DAG.getMergeValues({Lo, Hi}, DL);
9849 }
9850 
9851 // Lower step_vector to the vid instruction. Any non-identity step value must
// be accounted for by manual expansion.
SDValue RISCVTargetLowering::lowerSTEP_VECTOR(SDValue Op,
9854                                               SelectionDAG &DAG) const {
9855   SDLoc DL(Op);
9856   MVT VT = Op.getSimpleValueType();
9857   assert(VT.isScalableVector() && "Expected scalable vector");
9858   MVT XLenVT = Subtarget.getXLenVT();
9859   auto [Mask, VL] = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
9860   SDValue StepVec = DAG.getNode(RISCVISD::VID_VL, DL, VT, Mask, VL);
9861   uint64_t StepValImm = Op.getConstantOperandVal(0);
9862   if (StepValImm != 1) {
9863     if (isPowerOf2_64(StepValImm)) {
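      // For a power-of-2 step we can shift, e.g. a step of 4 lowers to
      // (vid << 2) = {0, 4, 8, 12, ...}.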
9864       SDValue StepVal =
9865           DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, DAG.getUNDEF(VT),
9866                       DAG.getConstant(Log2_64(StepValImm), DL, XLenVT), VL);
9867       StepVec = DAG.getNode(ISD::SHL, DL, VT, StepVec, StepVal);
9868     } else {
9869       SDValue StepVal = lowerScalarSplat(
9870           SDValue(), DAG.getConstant(StepValImm, DL, VT.getVectorElementType()),
9871           VL, VT, DL, DAG, Subtarget);
9872       StepVec = DAG.getNode(ISD::MUL, DL, VT, StepVec, StepVal);
9873     }
9874   }
9875   return StepVec;
9876 }
9877 
9878 // Implement vector_reverse using vrgather.vv with indices determined by
9879 // subtracting the id of each element from (VLMAX-1). This will convert
9880 // the indices like so:
9881 // (0, 1,..., VLMAX-2, VLMAX-1) -> (VLMAX-1, VLMAX-2,..., 1, 0).
9882 // TODO: This code assumes VLMAX <= 65536 for LMUL=8 SEW=16.
SDValue RISCVTargetLowering::lowerVECTOR_REVERSE(SDValue Op,
9884                                                  SelectionDAG &DAG) const {
9885   SDLoc DL(Op);
9886   MVT VecVT = Op.getSimpleValueType();
9887   if (VecVT.getVectorElementType() == MVT::i1) {
9888     MVT WidenVT = MVT::getVectorVT(MVT::i8, VecVT.getVectorElementCount());
9889     SDValue Op1 = DAG.getNode(ISD::ZERO_EXTEND, DL, WidenVT, Op.getOperand(0));
9890     SDValue Op2 = DAG.getNode(ISD::VECTOR_REVERSE, DL, WidenVT, Op1);
9891     return DAG.getNode(ISD::TRUNCATE, DL, VecVT, Op2);
9892   }
9893   unsigned EltSize = VecVT.getScalarSizeInBits();
9894   unsigned MinSize = VecVT.getSizeInBits().getKnownMinValue();
9895   unsigned VectorBitsMax = Subtarget.getRealMaxVLen();
9896   unsigned MaxVLMAX =
9897     RISCVTargetLowering::computeVLMAX(VectorBitsMax, EltSize, MinSize);
9898 
9899   unsigned GatherOpc = RISCVISD::VRGATHER_VV_VL;
9900   MVT IntVT = VecVT.changeVectorElementTypeToInteger();
9901 
9902   // If this is SEW=8 and VLMAX is potentially more than 256, we need
9903   // to use vrgatherei16.vv.
9904   // TODO: It's also possible to use vrgatherei16.vv for other types to
9905   // decrease register width for the index calculation.
9906   if (MaxVLMAX > 256 && EltSize == 8) {
    // If this is LMUL=8, we have to split before we can use vrgatherei16.vv.
9908     // Reverse each half, then reassemble them in reverse order.
9909     // NOTE: It's also possible that after splitting that VLMAX no longer
9910     // requires vrgatherei16.vv.
9911     if (MinSize == (8 * RISCV::RVVBitsPerBlock)) {
9912       auto [Lo, Hi] = DAG.SplitVectorOperand(Op.getNode(), 0);
9913       auto [LoVT, HiVT] = DAG.GetSplitDestVTs(VecVT);
9914       Lo = DAG.getNode(ISD::VECTOR_REVERSE, DL, LoVT, Lo);
9915       Hi = DAG.getNode(ISD::VECTOR_REVERSE, DL, HiVT, Hi);
9916       // Reassemble the low and high pieces reversed.
9917       // FIXME: This is a CONCAT_VECTORS.
9918       SDValue Res =
9919           DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VecVT, DAG.getUNDEF(VecVT), Hi,
9920                       DAG.getIntPtrConstant(0, DL));
9921       return DAG.getNode(
9922           ISD::INSERT_SUBVECTOR, DL, VecVT, Res, Lo,
9923           DAG.getIntPtrConstant(LoVT.getVectorMinNumElements(), DL));
9924     }
9925 
9926     // Just promote the int type to i16 which will double the LMUL.
9927     IntVT = MVT::getVectorVT(MVT::i16, VecVT.getVectorElementCount());
9928     GatherOpc = RISCVISD::VRGATHEREI16_VV_VL;
9929   }
9930 
9931   MVT XLenVT = Subtarget.getXLenVT();
9932   auto [Mask, VL] = getDefaultScalableVLOps(VecVT, DL, DAG, Subtarget);
9933 
9934   // Calculate VLMAX-1 for the desired SEW.
9935   SDValue VLMinus1 = DAG.getNode(ISD::SUB, DL, XLenVT,
9936                                  computeVLMax(VecVT, DL, DAG),
9937                                  DAG.getConstant(1, DL, XLenVT));
9938 
9939   // Splat VLMAX-1 taking care to handle SEW==64 on RV32.
9940   bool IsRV32E64 =
9941       !Subtarget.is64Bit() && IntVT.getVectorElementType() == MVT::i64;
9942   SDValue SplatVL;
9943   if (!IsRV32E64)
9944     SplatVL = DAG.getSplatVector(IntVT, DL, VLMinus1);
9945   else
9946     SplatVL = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, IntVT, DAG.getUNDEF(IntVT),
9947                           VLMinus1, DAG.getRegister(RISCV::X0, XLenVT));
9948 
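  // Indices = (VLMAX-1) - vid, e.g. with VLMAX=8: {7, 6, 5, 4, 3, 2, 1, 0}.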
9949   SDValue VID = DAG.getNode(RISCVISD::VID_VL, DL, IntVT, Mask, VL);
9950   SDValue Indices = DAG.getNode(RISCVISD::SUB_VL, DL, IntVT, SplatVL, VID,
9951                                 DAG.getUNDEF(IntVT), Mask, VL);
9952 
9953   return DAG.getNode(GatherOpc, DL, VecVT, Op.getOperand(0), Indices,
9954                      DAG.getUNDEF(VecVT), Mask, VL);
9955 }
9956 
SDValue RISCVTargetLowering::lowerVECTOR_SPLICE(SDValue Op,
9958                                                 SelectionDAG &DAG) const {
9959   SDLoc DL(Op);
9960   SDValue V1 = Op.getOperand(0);
9961   SDValue V2 = Op.getOperand(1);
9962   MVT XLenVT = Subtarget.getXLenVT();
9963   MVT VecVT = Op.getSimpleValueType();
9964 
9965   SDValue VLMax = computeVLMax(VecVT, DL, DAG);
9966 
9967   int64_t ImmValue = cast<ConstantSDNode>(Op.getOperand(2))->getSExtValue();
9968   SDValue DownOffset, UpOffset;
9969   if (ImmValue >= 0) {
9970     // The operand is a TargetConstant, we need to rebuild it as a regular
9971     // constant.
9972     DownOffset = DAG.getConstant(ImmValue, DL, XLenVT);
9973     UpOffset = DAG.getNode(ISD::SUB, DL, XLenVT, VLMax, DownOffset);
9974   } else {
9975     // The operand is a TargetConstant, we need to rebuild it as a regular
9976     // constant rather than negating the original operand.
9977     UpOffset = DAG.getConstant(-ImmValue, DL, XLenVT);
9978     DownOffset = DAG.getNode(ISD::SUB, DL, XLenVT, VLMax, UpOffset);
9979   }
9980 
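  // For example, with VLMAX=8 and an offset of 2 the result is
  // { V1[2..7], V2[0..1] }: V1 is slid down by 2 and V2 is slid up starting at
  // element VLMAX-2.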
9981   SDValue TrueMask = getAllOnesMask(VecVT, VLMax, DL, DAG);
9982 
9983   SDValue SlideDown =
9984       getVSlidedown(DAG, Subtarget, DL, VecVT, DAG.getUNDEF(VecVT), V1,
9985                     DownOffset, TrueMask, UpOffset);
9986   return getVSlideup(DAG, Subtarget, DL, VecVT, SlideDown, V2, UpOffset,
9987                      TrueMask, DAG.getRegister(RISCV::X0, XLenVT),
9988                      RISCVII::TAIL_AGNOSTIC);
9989 }
9990 
9991 SDValue
RISCVTargetLowering::lowerFixedLengthVectorLoadToRVV(SDValue Op,
9993                                                      SelectionDAG &DAG) const {
9994   SDLoc DL(Op);
9995   auto *Load = cast<LoadSDNode>(Op);
9996 
9997   assert(allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
9998                                         Load->getMemoryVT(),
9999                                         *Load->getMemOperand()) &&
10000          "Expecting a correctly-aligned load");
10001 
10002   MVT VT = Op.getSimpleValueType();
10003   MVT XLenVT = Subtarget.getXLenVT();
10004   MVT ContainerVT = getContainerForFixedLengthVector(VT);
10005 
10006   // If we know the exact VLEN and our fixed length vector completely fills
10007   // the container, use a whole register load instead.
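  // For example, if the exact VLEN is known to be 128, a v4i32 vector exactly
  // fills one LMUL=1 register, so a plain whole-register load can be used
  // without setting up VL.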
10008   const auto [MinVLMAX, MaxVLMAX] =
10009       RISCVTargetLowering::computeVLMAXBounds(ContainerVT, Subtarget);
10010   if (MinVLMAX == MaxVLMAX && MinVLMAX == VT.getVectorNumElements() &&
10011       getLMUL1VT(ContainerVT).bitsLE(ContainerVT)) {
10012     SDValue NewLoad =
10013         DAG.getLoad(ContainerVT, DL, Load->getChain(), Load->getBasePtr(),
10014                     Load->getMemOperand());
10015     SDValue Result = convertFromScalableVector(VT, NewLoad, DAG, Subtarget);
10016     return DAG.getMergeValues({Result, NewLoad.getValue(1)}, DL);
10017   }
10018 
10019   SDValue VL = getVLOp(VT.getVectorNumElements(), ContainerVT, DL, DAG, Subtarget);
10020 
10021   bool IsMaskOp = VT.getVectorElementType() == MVT::i1;
10022   SDValue IntID = DAG.getTargetConstant(
10023       IsMaskOp ? Intrinsic::riscv_vlm : Intrinsic::riscv_vle, DL, XLenVT);
10024   SmallVector<SDValue, 4> Ops{Load->getChain(), IntID};
10025   if (!IsMaskOp)
10026     Ops.push_back(DAG.getUNDEF(ContainerVT));
10027   Ops.push_back(Load->getBasePtr());
10028   Ops.push_back(VL);
10029   SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
10030   SDValue NewLoad =
10031       DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops,
10032                               Load->getMemoryVT(), Load->getMemOperand());
10033 
10034   SDValue Result = convertFromScalableVector(VT, NewLoad, DAG, Subtarget);
10035   return DAG.getMergeValues({Result, NewLoad.getValue(1)}, DL);
10036 }
10037 
10038 SDValue
RISCVTargetLowering::lowerFixedLengthVectorStoreToRVV(SDValue Op,
10040                                                       SelectionDAG &DAG) const {
10041   SDLoc DL(Op);
10042   auto *Store = cast<StoreSDNode>(Op);
10043 
10044   assert(allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
10045                                         Store->getMemoryVT(),
10046                                         *Store->getMemOperand()) &&
10047          "Expecting a correctly-aligned store");
10048 
10049   SDValue StoreVal = Store->getValue();
10050   MVT VT = StoreVal.getSimpleValueType();
10051   MVT XLenVT = Subtarget.getXLenVT();
10052 
  // If the size is less than a byte, we need to pad with zeros to make a full
  // byte.
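  // e.g. a v4i1 value is widened to v8i1 with the upper lanes cleared so the
  // store writes exactly one byte.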
10054   if (VT.getVectorElementType() == MVT::i1 && VT.getVectorNumElements() < 8) {
10055     VT = MVT::v8i1;
10056     StoreVal = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT,
10057                            DAG.getConstant(0, DL, VT), StoreVal,
10058                            DAG.getIntPtrConstant(0, DL));
10059   }
10060 
10061   MVT ContainerVT = getContainerForFixedLengthVector(VT);
10062 
10063   SDValue NewValue =
10064       convertToScalableVector(ContainerVT, StoreVal, DAG, Subtarget);
10065 
10066 
10067   // If we know the exact VLEN and our fixed length vector completely fills
10068   // the container, use a whole register store instead.
10069   const auto [MinVLMAX, MaxVLMAX] =
10070       RISCVTargetLowering::computeVLMAXBounds(ContainerVT, Subtarget);
10071   if (MinVLMAX == MaxVLMAX && MinVLMAX == VT.getVectorNumElements() &&
10072       getLMUL1VT(ContainerVT).bitsLE(ContainerVT))
10073     return DAG.getStore(Store->getChain(), DL, NewValue, Store->getBasePtr(),
10074                         Store->getMemOperand());
10075 
10076   SDValue VL = getVLOp(VT.getVectorNumElements(), ContainerVT, DL, DAG,
10077                        Subtarget);
10078 
10079   bool IsMaskOp = VT.getVectorElementType() == MVT::i1;
10080   SDValue IntID = DAG.getTargetConstant(
10081       IsMaskOp ? Intrinsic::riscv_vsm : Intrinsic::riscv_vse, DL, XLenVT);
10082   return DAG.getMemIntrinsicNode(
10083       ISD::INTRINSIC_VOID, DL, DAG.getVTList(MVT::Other),
10084       {Store->getChain(), IntID, NewValue, Store->getBasePtr(), VL},
10085       Store->getMemoryVT(), Store->getMemOperand());
10086 }
10087 
SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
10089                                              SelectionDAG &DAG) const {
10090   SDLoc DL(Op);
10091   MVT VT = Op.getSimpleValueType();
10092 
10093   const auto *MemSD = cast<MemSDNode>(Op);
10094   EVT MemVT = MemSD->getMemoryVT();
10095   MachineMemOperand *MMO = MemSD->getMemOperand();
10096   SDValue Chain = MemSD->getChain();
10097   SDValue BasePtr = MemSD->getBasePtr();
10098 
10099   SDValue Mask, PassThru, VL;
10100   if (const auto *VPLoad = dyn_cast<VPLoadSDNode>(Op)) {
10101     Mask = VPLoad->getMask();
10102     PassThru = DAG.getUNDEF(VT);
10103     VL = VPLoad->getVectorLength();
10104   } else {
10105     const auto *MLoad = cast<MaskedLoadSDNode>(Op);
10106     Mask = MLoad->getMask();
10107     PassThru = MLoad->getPassThru();
10108   }
10109 
10110   bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
10111 
10112   MVT XLenVT = Subtarget.getXLenVT();
10113 
10114   MVT ContainerVT = VT;
10115   if (VT.isFixedLengthVector()) {
10116     ContainerVT = getContainerForFixedLengthVector(VT);
10117     PassThru = convertToScalableVector(ContainerVT, PassThru, DAG, Subtarget);
10118     if (!IsUnmasked) {
10119       MVT MaskVT = getMaskTypeFor(ContainerVT);
10120       Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
10121     }
10122   }
10123 
10124   if (!VL)
10125     VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
10126 
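  // Build the operand list for the vle/vle_mask intrinsic:
  //   (chain, id, passthru, ptr, vl)               for the unmasked form
  //   (chain, id, passthru, ptr, mask, vl, policy) for the masked form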
10127   unsigned IntID =
10128       IsUnmasked ? Intrinsic::riscv_vle : Intrinsic::riscv_vle_mask;
10129   SmallVector<SDValue, 8> Ops{Chain, DAG.getTargetConstant(IntID, DL, XLenVT)};
10130   if (IsUnmasked)
10131     Ops.push_back(DAG.getUNDEF(ContainerVT));
10132   else
10133     Ops.push_back(PassThru);
10134   Ops.push_back(BasePtr);
10135   if (!IsUnmasked)
10136     Ops.push_back(Mask);
10137   Ops.push_back(VL);
10138   if (!IsUnmasked)
10139     Ops.push_back(DAG.getTargetConstant(RISCVII::TAIL_AGNOSTIC, DL, XLenVT));
10140 
10141   SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
10142 
10143   SDValue Result =
10144       DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops, MemVT, MMO);
10145   Chain = Result.getValue(1);
10146 
10147   if (VT.isFixedLengthVector())
10148     Result = convertFromScalableVector(VT, Result, DAG, Subtarget);
10149 
10150   return DAG.getMergeValues({Result, Chain}, DL);
10151 }
10152 
SDValue RISCVTargetLowering::lowerMaskedStore(SDValue Op,
10154                                               SelectionDAG &DAG) const {
10155   SDLoc DL(Op);
10156 
10157   const auto *MemSD = cast<MemSDNode>(Op);
10158   EVT MemVT = MemSD->getMemoryVT();
10159   MachineMemOperand *MMO = MemSD->getMemOperand();
10160   SDValue Chain = MemSD->getChain();
10161   SDValue BasePtr = MemSD->getBasePtr();
10162   SDValue Val, Mask, VL;
10163 
10164   if (const auto *VPStore = dyn_cast<VPStoreSDNode>(Op)) {
10165     Val = VPStore->getValue();
10166     Mask = VPStore->getMask();
10167     VL = VPStore->getVectorLength();
10168   } else {
10169     const auto *MStore = cast<MaskedStoreSDNode>(Op);
10170     Val = MStore->getValue();
10171     Mask = MStore->getMask();
10172   }
10173 
10174   bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
10175 
10176   MVT VT = Val.getSimpleValueType();
10177   MVT XLenVT = Subtarget.getXLenVT();
10178 
10179   MVT ContainerVT = VT;
10180   if (VT.isFixedLengthVector()) {
10181     ContainerVT = getContainerForFixedLengthVector(VT);
10182 
10183     Val = convertToScalableVector(ContainerVT, Val, DAG, Subtarget);
10184     if (!IsUnmasked) {
10185       MVT MaskVT = getMaskTypeFor(ContainerVT);
10186       Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
10187     }
10188   }
10189 
10190   if (!VL)
10191     VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
10192 
10193   unsigned IntID =
10194       IsUnmasked ? Intrinsic::riscv_vse : Intrinsic::riscv_vse_mask;
10195   SmallVector<SDValue, 8> Ops{Chain, DAG.getTargetConstant(IntID, DL, XLenVT)};
10196   Ops.push_back(Val);
10197   Ops.push_back(BasePtr);
10198   if (!IsUnmasked)
10199     Ops.push_back(Mask);
10200   Ops.push_back(VL);
10201 
10202   return DAG.getMemIntrinsicNode(ISD::INTRINSIC_VOID, DL,
10203                                  DAG.getVTList(MVT::Other), Ops, MemVT, MMO);
10204 }
10205 
10206 SDValue
RISCVTargetLowering::lowerFixedLengthVectorSetccToRVV(SDValue Op,
10208                                                       SelectionDAG &DAG) const {
10209   MVT InVT = Op.getOperand(0).getSimpleValueType();
10210   MVT ContainerVT = getContainerForFixedLengthVector(InVT);
10211 
10212   MVT VT = Op.getSimpleValueType();
10213 
10214   SDValue Op1 =
10215       convertToScalableVector(ContainerVT, Op.getOperand(0), DAG, Subtarget);
10216   SDValue Op2 =
10217       convertToScalableVector(ContainerVT, Op.getOperand(1), DAG, Subtarget);
10218 
10219   SDLoc DL(Op);
10220   auto [Mask, VL] = getDefaultVLOps(VT.getVectorNumElements(), ContainerVT, DL,
10221                                     DAG, Subtarget);
10222   MVT MaskVT = getMaskTypeFor(ContainerVT);
10223 
10224   SDValue Cmp =
10225       DAG.getNode(RISCVISD::SETCC_VL, DL, MaskVT,
10226                   {Op1, Op2, Op.getOperand(2), DAG.getUNDEF(MaskVT), Mask, VL});
10227 
10228   return convertFromScalableVector(VT, Cmp, DAG, Subtarget);
10229 }
10230 
SDValue RISCVTargetLowering::lowerVectorStrictFSetcc(SDValue Op,
10232                                                      SelectionDAG &DAG) const {
10233   unsigned Opc = Op.getOpcode();
10234   SDLoc DL(Op);
10235   SDValue Chain = Op.getOperand(0);
10236   SDValue Op1 = Op.getOperand(1);
10237   SDValue Op2 = Op.getOperand(2);
10238   SDValue CC = Op.getOperand(3);
10239   ISD::CondCode CCVal = cast<CondCodeSDNode>(CC)->get();
10240   MVT VT = Op.getSimpleValueType();
10241   MVT InVT = Op1.getSimpleValueType();
10242 
  // RVV VMFEQ/VMFNE ignore qNaN, so we expand strict_fsetccs with the OEQ/UNE
  // condition codes.
10245   if (Opc == ISD::STRICT_FSETCCS) {
10246     // Expand strict_fsetccs(x, oeq) to
10247     // (and strict_fsetccs(x, y, oge), strict_fsetccs(x, y, ole))
10248     SDVTList VTList = Op->getVTList();
10249     if (CCVal == ISD::SETEQ || CCVal == ISD::SETOEQ) {
10250       SDValue OLECCVal = DAG.getCondCode(ISD::SETOLE);
10251       SDValue Tmp1 = DAG.getNode(ISD::STRICT_FSETCCS, DL, VTList, Chain, Op1,
10252                                  Op2, OLECCVal);
10253       SDValue Tmp2 = DAG.getNode(ISD::STRICT_FSETCCS, DL, VTList, Chain, Op2,
10254                                  Op1, OLECCVal);
10255       SDValue OutChain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other,
10256                                      Tmp1.getValue(1), Tmp2.getValue(1));
10257       // Tmp1 and Tmp2 might be the same node.
10258       if (Tmp1 != Tmp2)
10259         Tmp1 = DAG.getNode(ISD::AND, DL, VT, Tmp1, Tmp2);
10260       return DAG.getMergeValues({Tmp1, OutChain}, DL);
10261     }
10262 
10263     // Expand (strict_fsetccs x, y, une) to (not (strict_fsetccs x, y, oeq))
10264     if (CCVal == ISD::SETNE || CCVal == ISD::SETUNE) {
10265       SDValue OEQCCVal = DAG.getCondCode(ISD::SETOEQ);
10266       SDValue OEQ = DAG.getNode(ISD::STRICT_FSETCCS, DL, VTList, Chain, Op1,
10267                                 Op2, OEQCCVal);
10268       SDValue Res = DAG.getNOT(DL, OEQ, VT);
10269       return DAG.getMergeValues({Res, OEQ.getValue(1)}, DL);
10270     }
10271   }
10272 
10273   MVT ContainerInVT = InVT;
10274   if (InVT.isFixedLengthVector()) {
10275     ContainerInVT = getContainerForFixedLengthVector(InVT);
10276     Op1 = convertToScalableVector(ContainerInVT, Op1, DAG, Subtarget);
10277     Op2 = convertToScalableVector(ContainerInVT, Op2, DAG, Subtarget);
10278   }
10279   MVT MaskVT = getMaskTypeFor(ContainerInVT);
10280 
10281   auto [Mask, VL] = getDefaultVLOps(InVT, ContainerInVT, DL, DAG, Subtarget);
10282 
10283   SDValue Res;
10284   if (Opc == ISD::STRICT_FSETCC &&
10285       (CCVal == ISD::SETLT || CCVal == ISD::SETOLT || CCVal == ISD::SETLE ||
10286        CCVal == ISD::SETOLE)) {
    // VMFLT/VMFLE/VMFGT/VMFGE raise an exception for qNaN inputs. Generate a
    // mask that is only active when both input elements are ordered.
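    // A quiet self-compare x == x is true exactly when x is not a NaN, so
    // ANDing the two self-compares yields the required ordered mask.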
10289     SDValue True = getAllOnesMask(ContainerInVT, VL, DL, DAG);
10290     SDValue OrderMask1 = DAG.getNode(
10291         RISCVISD::STRICT_FSETCC_VL, DL, DAG.getVTList(MaskVT, MVT::Other),
10292         {Chain, Op1, Op1, DAG.getCondCode(ISD::SETOEQ), DAG.getUNDEF(MaskVT),
10293          True, VL});
10294     SDValue OrderMask2 = DAG.getNode(
10295         RISCVISD::STRICT_FSETCC_VL, DL, DAG.getVTList(MaskVT, MVT::Other),
10296         {Chain, Op2, Op2, DAG.getCondCode(ISD::SETOEQ), DAG.getUNDEF(MaskVT),
10297          True, VL});
10298     Mask =
10299         DAG.getNode(RISCVISD::VMAND_VL, DL, MaskVT, OrderMask1, OrderMask2, VL);
10300     // Use Mask as the merge operand to let the result be 0 if either of the
10301     // inputs is unordered.
10302     Res = DAG.getNode(RISCVISD::STRICT_FSETCCS_VL, DL,
10303                       DAG.getVTList(MaskVT, MVT::Other),
10304                       {Chain, Op1, Op2, CC, Mask, Mask, VL});
10305   } else {
10306     unsigned RVVOpc = Opc == ISD::STRICT_FSETCC ? RISCVISD::STRICT_FSETCC_VL
10307                                                 : RISCVISD::STRICT_FSETCCS_VL;
10308     Res = DAG.getNode(RVVOpc, DL, DAG.getVTList(MaskVT, MVT::Other),
10309                       {Chain, Op1, Op2, CC, DAG.getUNDEF(MaskVT), Mask, VL});
10310   }
10311 
10312   if (VT.isFixedLengthVector()) {
10313     SDValue SubVec = convertFromScalableVector(VT, Res, DAG, Subtarget);
10314     return DAG.getMergeValues({SubVec, Res.getValue(1)}, DL);
10315   }
10316   return Res;
10317 }
10318 
10319 // Lower vector ABS to smax(X, sub(0, X)).
SDValue RISCVTargetLowering::lowerABS(SDValue Op, SelectionDAG &DAG) const {
10321   SDLoc DL(Op);
10322   MVT VT = Op.getSimpleValueType();
10323   SDValue X = Op.getOperand(0);
10324 
10325   assert((Op.getOpcode() == ISD::VP_ABS || VT.isFixedLengthVector()) &&
10326          "Unexpected type for ISD::ABS");
10327 
10328   MVT ContainerVT = VT;
10329   if (VT.isFixedLengthVector()) {
10330     ContainerVT = getContainerForFixedLengthVector(VT);
10331     X = convertToScalableVector(ContainerVT, X, DAG, Subtarget);
10332   }
10333 
10334   SDValue Mask, VL;
10335   if (Op->getOpcode() == ISD::VP_ABS) {
10336     Mask = Op->getOperand(1);
10337     if (VT.isFixedLengthVector())
10338       Mask = convertToScalableVector(getMaskTypeFor(ContainerVT), Mask, DAG,
10339                                      Subtarget);
10340     VL = Op->getOperand(2);
10341   } else
10342     std::tie(Mask, VL) = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
10343 
10344   SDValue SplatZero = DAG.getNode(
10345       RISCVISD::VMV_V_X_VL, DL, ContainerVT, DAG.getUNDEF(ContainerVT),
10346       DAG.getConstant(0, DL, Subtarget.getXLenVT()), VL);
10347   SDValue NegX = DAG.getNode(RISCVISD::SUB_VL, DL, ContainerVT, SplatZero, X,
10348                              DAG.getUNDEF(ContainerVT), Mask, VL);
10349   SDValue Max = DAG.getNode(RISCVISD::SMAX_VL, DL, ContainerVT, X, NegX,
10350                             DAG.getUNDEF(ContainerVT), Mask, VL);
10351 
10352   if (VT.isFixedLengthVector())
10353     Max = convertFromScalableVector(VT, Max, DAG, Subtarget);
10354   return Max;
10355 }
10356 
SDValue RISCVTargetLowering::lowerFixedLengthVectorFCOPYSIGNToRVV(
10358     SDValue Op, SelectionDAG &DAG) const {
10359   SDLoc DL(Op);
10360   MVT VT = Op.getSimpleValueType();
10361   SDValue Mag = Op.getOperand(0);
10362   SDValue Sign = Op.getOperand(1);
10363   assert(Mag.getValueType() == Sign.getValueType() &&
10364          "Can only handle COPYSIGN with matching types.");
10365 
10366   MVT ContainerVT = getContainerForFixedLengthVector(VT);
10367   Mag = convertToScalableVector(ContainerVT, Mag, DAG, Subtarget);
10368   Sign = convertToScalableVector(ContainerVT, Sign, DAG, Subtarget);
10369 
10370   auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
10371 
10372   SDValue CopySign = DAG.getNode(RISCVISD::FCOPYSIGN_VL, DL, ContainerVT, Mag,
10373                                  Sign, DAG.getUNDEF(ContainerVT), Mask, VL);
10374 
10375   return convertFromScalableVector(VT, CopySign, DAG, Subtarget);
10376 }
10377 
SDValue RISCVTargetLowering::lowerFixedLengthVectorSelectToRVV(
10379     SDValue Op, SelectionDAG &DAG) const {
10380   MVT VT = Op.getSimpleValueType();
10381   MVT ContainerVT = getContainerForFixedLengthVector(VT);
10382 
10383   MVT I1ContainerVT =
10384       MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
10385 
10386   SDValue CC =
10387       convertToScalableVector(I1ContainerVT, Op.getOperand(0), DAG, Subtarget);
10388   SDValue Op1 =
10389       convertToScalableVector(ContainerVT, Op.getOperand(1), DAG, Subtarget);
10390   SDValue Op2 =
10391       convertToScalableVector(ContainerVT, Op.getOperand(2), DAG, Subtarget);
10392 
10393   SDLoc DL(Op);
10394   SDValue VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
10395 
10396   SDValue Select = DAG.getNode(RISCVISD::VMERGE_VL, DL, ContainerVT, CC, Op1,
10397                                Op2, DAG.getUNDEF(ContainerVT), VL);
10398 
10399   return convertFromScalableVector(VT, Select, DAG, Subtarget);
10400 }
10401 
SDValue RISCVTargetLowering::lowerToScalableOp(SDValue Op,
10403                                                SelectionDAG &DAG) const {
10404   unsigned NewOpc = getRISCVVLOp(Op);
10405   bool HasMergeOp = hasMergeOp(NewOpc);
10406   bool HasMask = hasMaskOp(NewOpc);
10407 
10408   MVT VT = Op.getSimpleValueType();
10409   MVT ContainerVT = getContainerForFixedLengthVector(VT);
10410 
10411   // Create list of operands by converting existing ones to scalable types.
10412   SmallVector<SDValue, 6> Ops;
10413   for (const SDValue &V : Op->op_values()) {
10414     assert(!isa<VTSDNode>(V) && "Unexpected VTSDNode node!");
10415 
10416     // Pass through non-vector operands.
10417     if (!V.getValueType().isVector()) {
10418       Ops.push_back(V);
10419       continue;
10420     }
10421 
10422     // "cast" fixed length vector to a scalable vector.
10423     assert(useRVVForFixedLengthVectorVT(V.getSimpleValueType()) &&
10424            "Only fixed length vectors are supported!");
10425     Ops.push_back(convertToScalableVector(ContainerVT, V, DAG, Subtarget));
10426   }
10427 
10428   SDLoc DL(Op);
10429   auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
10430   if (HasMergeOp)
10431     Ops.push_back(DAG.getUNDEF(ContainerVT));
10432   if (HasMask)
10433     Ops.push_back(Mask);
10434   Ops.push_back(VL);
10435 
10436   // StrictFP operations have two result values. Their lowered result should
  // have the same result count.
10438   if (Op->isStrictFPOpcode()) {
10439     SDValue ScalableRes =
10440         DAG.getNode(NewOpc, DL, DAG.getVTList(ContainerVT, MVT::Other), Ops,
10441                     Op->getFlags());
10442     SDValue SubVec = convertFromScalableVector(VT, ScalableRes, DAG, Subtarget);
10443     return DAG.getMergeValues({SubVec, ScalableRes.getValue(1)}, DL);
10444   }
10445 
10446   SDValue ScalableRes =
10447       DAG.getNode(NewOpc, DL, ContainerVT, Ops, Op->getFlags());
10448   return convertFromScalableVector(VT, ScalableRes, DAG, Subtarget);
10449 }
10450 
10451 // Lower a VP_* ISD node to the corresponding RISCVISD::*_VL node:
10452 // * Operands of each node are assumed to be in the same order.
10453 // * The EVL operand is promoted from i32 to i64 on RV64.
10454 // * Fixed-length vectors are converted to their scalable-vector container
10455 //   types.
SDValue RISCVTargetLowering::lowerVPOp(SDValue Op, SelectionDAG &DAG) const {
10457   unsigned RISCVISDOpc = getRISCVVLOp(Op);
10458   bool HasMergeOp = hasMergeOp(RISCVISDOpc);
10459 
10460   SDLoc DL(Op);
10461   MVT VT = Op.getSimpleValueType();
10462   SmallVector<SDValue, 4> Ops;
10463 
10464   MVT ContainerVT = VT;
10465   if (VT.isFixedLengthVector())
10466     ContainerVT = getContainerForFixedLengthVector(VT);
10467 
10468   for (const auto &OpIdx : enumerate(Op->ops())) {
10469     SDValue V = OpIdx.value();
10470     assert(!isa<VTSDNode>(V) && "Unexpected VTSDNode node!");
    // Add a dummy merge value before the mask operand, or, if there is no
    // mask, before the EVL operand.
10473     if (HasMergeOp) {
10474       auto MaskIdx = ISD::getVPMaskIdx(Op.getOpcode());
10475       if (MaskIdx) {
10476         if (*MaskIdx == OpIdx.index())
10477           Ops.push_back(DAG.getUNDEF(ContainerVT));
10478       } else if (ISD::getVPExplicitVectorLengthIdx(Op.getOpcode()) ==
10479                  OpIdx.index()) {
10480         if (Op.getOpcode() == ISD::VP_MERGE) {
10481           // For VP_MERGE, copy the false operand instead of an undef value.
10482           Ops.push_back(Ops.back());
10483         } else {
10484           assert(Op.getOpcode() == ISD::VP_SELECT);
10485           // For VP_SELECT, add an undef value.
10486           Ops.push_back(DAG.getUNDEF(ContainerVT));
10487         }
10488       }
10489     }
10490     // Pass through operands which aren't fixed-length vectors.
10491     if (!V.getValueType().isFixedLengthVector()) {
10492       Ops.push_back(V);
10493       continue;
10494     }
10495     // "cast" fixed length vector to a scalable vector.
10496     MVT OpVT = V.getSimpleValueType();
10497     MVT ContainerVT = getContainerForFixedLengthVector(OpVT);
10498     assert(useRVVForFixedLengthVectorVT(OpVT) &&
10499            "Only fixed length vectors are supported!");
10500     Ops.push_back(convertToScalableVector(ContainerVT, V, DAG, Subtarget));
10501   }
10502 
10503   if (!VT.isFixedLengthVector())
10504     return DAG.getNode(RISCVISDOpc, DL, VT, Ops, Op->getFlags());
10505 
10506   SDValue VPOp = DAG.getNode(RISCVISDOpc, DL, ContainerVT, Ops, Op->getFlags());
10507 
10508   return convertFromScalableVector(VT, VPOp, DAG, Subtarget);
10509 }
10510 
SDValue RISCVTargetLowering::lowerVPExtMaskOp(SDValue Op,
10512                                               SelectionDAG &DAG) const {
10513   SDLoc DL(Op);
10514   MVT VT = Op.getSimpleValueType();
10515 
10516   SDValue Src = Op.getOperand(0);
10517   // NOTE: Mask is dropped.
10518   SDValue VL = Op.getOperand(2);
10519 
10520   MVT ContainerVT = VT;
10521   if (VT.isFixedLengthVector()) {
10522     ContainerVT = getContainerForFixedLengthVector(VT);
10523     MVT SrcVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
10524     Src = convertToScalableVector(SrcVT, Src, DAG, Subtarget);
10525   }
10526 
10527   MVT XLenVT = Subtarget.getXLenVT();
10528   SDValue Zero = DAG.getConstant(0, DL, XLenVT);
10529   SDValue ZeroSplat = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT,
10530                                   DAG.getUNDEF(ContainerVT), Zero, VL);
10531 
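  // zext from i1 selects 1 for set mask bits and sext selects -1; clear bits
  // select 0. The extension is therefore a vmerge between the two splats.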
10532   SDValue SplatValue = DAG.getConstant(
10533       Op.getOpcode() == ISD::VP_ZERO_EXTEND ? 1 : -1, DL, XLenVT);
10534   SDValue Splat = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT,
10535                               DAG.getUNDEF(ContainerVT), SplatValue, VL);
10536 
10537   SDValue Result = DAG.getNode(RISCVISD::VMERGE_VL, DL, ContainerVT, Src, Splat,
10538                                ZeroSplat, DAG.getUNDEF(ContainerVT), VL);
10539   if (!VT.isFixedLengthVector())
10540     return Result;
10541   return convertFromScalableVector(VT, Result, DAG, Subtarget);
10542 }
10543 
SDValue RISCVTargetLowering::lowerVPSetCCMaskOp(SDValue Op,
10545                                                 SelectionDAG &DAG) const {
10546   SDLoc DL(Op);
10547   MVT VT = Op.getSimpleValueType();
10548 
10549   SDValue Op1 = Op.getOperand(0);
10550   SDValue Op2 = Op.getOperand(1);
10551   ISD::CondCode Condition = cast<CondCodeSDNode>(Op.getOperand(2))->get();
10552   // NOTE: Mask is dropped.
10553   SDValue VL = Op.getOperand(4);
10554 
10555   MVT ContainerVT = VT;
10556   if (VT.isFixedLengthVector()) {
10557     ContainerVT = getContainerForFixedLengthVector(VT);
10558     Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
10559     Op2 = convertToScalableVector(ContainerVT, Op2, DAG, Subtarget);
10560   }
10561 
10562   SDValue Result;
10563   SDValue AllOneMask = DAG.getNode(RISCVISD::VMSET_VL, DL, ContainerVT, VL);
10564 
10565   switch (Condition) {
10566   default:
10567     break;
10568   // X != Y  --> (X^Y)
10569   case ISD::SETNE:
10570     Result = DAG.getNode(RISCVISD::VMXOR_VL, DL, ContainerVT, Op1, Op2, VL);
10571     break;
10572   // X == Y  --> ~(X^Y)
10573   case ISD::SETEQ: {
10574     SDValue Temp =
10575         DAG.getNode(RISCVISD::VMXOR_VL, DL, ContainerVT, Op1, Op2, VL);
10576     Result =
10577         DAG.getNode(RISCVISD::VMXOR_VL, DL, ContainerVT, Temp, AllOneMask, VL);
10578     break;
10579   }
10580   // X >s Y   -->  X == 0 & Y == 1  -->  ~X & Y
10581   // X <u Y   -->  X == 0 & Y == 1  -->  ~X & Y
10582   case ISD::SETGT:
10583   case ISD::SETULT: {
10584     SDValue Temp =
10585         DAG.getNode(RISCVISD::VMXOR_VL, DL, ContainerVT, Op1, AllOneMask, VL);
10586     Result = DAG.getNode(RISCVISD::VMAND_VL, DL, ContainerVT, Temp, Op2, VL);
10587     break;
10588   }
10589   // X <s Y   --> X == 1 & Y == 0  -->  ~Y & X
10590   // X >u Y   --> X == 1 & Y == 0  -->  ~Y & X
10591   case ISD::SETLT:
10592   case ISD::SETUGT: {
10593     SDValue Temp =
10594         DAG.getNode(RISCVISD::VMXOR_VL, DL, ContainerVT, Op2, AllOneMask, VL);
10595     Result = DAG.getNode(RISCVISD::VMAND_VL, DL, ContainerVT, Op1, Temp, VL);
10596     break;
10597   }
10598   // X >=s Y  --> X == 0 | Y == 1  -->  ~X | Y
10599   // X <=u Y  --> X == 0 | Y == 1  -->  ~X | Y
10600   case ISD::SETGE:
10601   case ISD::SETULE: {
10602     SDValue Temp =
10603         DAG.getNode(RISCVISD::VMXOR_VL, DL, ContainerVT, Op1, AllOneMask, VL);
    Result = DAG.getNode(RISCVISD::VMOR_VL, DL, ContainerVT, Temp, Op2, VL);
10605     break;
10606   }
10607   // X <=s Y  --> X == 1 | Y == 0  -->  ~Y | X
10608   // X >=u Y  --> X == 1 | Y == 0  -->  ~Y | X
10609   case ISD::SETLE:
10610   case ISD::SETUGE: {
10611     SDValue Temp =
10612         DAG.getNode(RISCVISD::VMXOR_VL, DL, ContainerVT, Op2, AllOneMask, VL);
    Result = DAG.getNode(RISCVISD::VMOR_VL, DL, ContainerVT, Temp, Op1, VL);
10614     break;
10615   }
10616   }
10617 
10618   if (!VT.isFixedLengthVector())
10619     return Result;
10620   return convertFromScalableVector(VT, Result, DAG, Subtarget);
10621 }
10622 
10623 // Lower Floating-Point/Integer Type-Convert VP SDNodes
SDValue RISCVTargetLowering::lowerVPFPIntConvOp(SDValue Op,
10625                                                 SelectionDAG &DAG) const {
10626   SDLoc DL(Op);
10627 
10628   SDValue Src = Op.getOperand(0);
10629   SDValue Mask = Op.getOperand(1);
10630   SDValue VL = Op.getOperand(2);
10631   unsigned RISCVISDOpc = getRISCVVLOp(Op);
10632 
10633   MVT DstVT = Op.getSimpleValueType();
10634   MVT SrcVT = Src.getSimpleValueType();
10635   if (DstVT.isFixedLengthVector()) {
10636     DstVT = getContainerForFixedLengthVector(DstVT);
10637     SrcVT = getContainerForFixedLengthVector(SrcVT);
10638     Src = convertToScalableVector(SrcVT, Src, DAG, Subtarget);
10639     MVT MaskVT = getMaskTypeFor(DstVT);
10640     Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
10641   }
10642 
10643   unsigned DstEltSize = DstVT.getScalarSizeInBits();
10644   unsigned SrcEltSize = SrcVT.getScalarSizeInBits();
10645 
10646   SDValue Result;
10647   if (DstEltSize >= SrcEltSize) { // Single-width and widening conversion.
10648     if (SrcVT.isInteger()) {
10649       assert(DstVT.isFloatingPoint() && "Wrong input/output vector types");
10650 
10651       unsigned RISCVISDExtOpc = RISCVISDOpc == RISCVISD::SINT_TO_FP_VL
10652                                     ? RISCVISD::VSEXT_VL
10653                                     : RISCVISD::VZEXT_VL;
10654 
10655       // Do we need to do any pre-widening before converting?
10656       if (SrcEltSize == 1) {
10657         MVT IntVT = DstVT.changeVectorElementTypeToInteger();
10658         MVT XLenVT = Subtarget.getXLenVT();
10659         SDValue Zero = DAG.getConstant(0, DL, XLenVT);
10660         SDValue ZeroSplat = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, IntVT,
10661                                         DAG.getUNDEF(IntVT), Zero, VL);
10662         SDValue One = DAG.getConstant(
10663             RISCVISDExtOpc == RISCVISD::VZEXT_VL ? 1 : -1, DL, XLenVT);
10664         SDValue OneSplat = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, IntVT,
10665                                        DAG.getUNDEF(IntVT), One, VL);
10666         Src = DAG.getNode(RISCVISD::VMERGE_VL, DL, IntVT, Src, OneSplat,
10667                           ZeroSplat, DAG.getUNDEF(IntVT), VL);
10668       } else if (DstEltSize > (2 * SrcEltSize)) {
10669         // Widen before converting.
10670         MVT IntVT = MVT::getVectorVT(MVT::getIntegerVT(DstEltSize / 2),
10671                                      DstVT.getVectorElementCount());
10672         Src = DAG.getNode(RISCVISDExtOpc, DL, IntVT, Src, Mask, VL);
10673       }
10674 
10675       Result = DAG.getNode(RISCVISDOpc, DL, DstVT, Src, Mask, VL);
10676     } else {
10677       assert(SrcVT.isFloatingPoint() && DstVT.isInteger() &&
10678              "Wrong input/output vector types");
10679 
10680       // Convert f16 to f32 then convert f32 to i64.
10681       if (DstEltSize > (2 * SrcEltSize)) {
10682         assert(SrcVT.getVectorElementType() == MVT::f16 && "Unexpected type!");
10683         MVT InterimFVT =
10684             MVT::getVectorVT(MVT::f32, DstVT.getVectorElementCount());
10685         Src =
10686             DAG.getNode(RISCVISD::FP_EXTEND_VL, DL, InterimFVT, Src, Mask, VL);
10687       }
10688 
10689       Result = DAG.getNode(RISCVISDOpc, DL, DstVT, Src, Mask, VL);
10690     }
10691   } else { // Narrowing + Conversion
10692     if (SrcVT.isInteger()) {
10693       assert(DstVT.isFloatingPoint() && "Wrong input/output vector types");
      // First do a narrowing convert to an FP type half the size, then round
      // the result to a smaller FP type if needed.
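      // For example, i64 -> f16 is lowered as a narrowing convert i64 -> f32
      // followed by an f32 -> f16 round.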
10696 
10697       MVT InterimFVT = DstVT;
10698       if (SrcEltSize > (2 * DstEltSize)) {
10699         assert(SrcEltSize == (4 * DstEltSize) && "Unexpected types!");
10700         assert(DstVT.getVectorElementType() == MVT::f16 && "Unexpected type!");
10701         InterimFVT = MVT::getVectorVT(MVT::f32, DstVT.getVectorElementCount());
10702       }
10703 
10704       Result = DAG.getNode(RISCVISDOpc, DL, InterimFVT, Src, Mask, VL);
10705 
10706       if (InterimFVT != DstVT) {
10707         Src = Result;
10708         Result = DAG.getNode(RISCVISD::FP_ROUND_VL, DL, DstVT, Src, Mask, VL);
10709       }
10710     } else {
10711       assert(SrcVT.isFloatingPoint() && DstVT.isInteger() &&
10712              "Wrong input/output vector types");
10713       // First do a narrowing conversion to an integer half the size, then
10714       // truncate if needed.
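      // For example, f64 -> i8 is lowered as a narrowing convert f64 -> i32
      // followed by successive narrowing truncates i32 -> i16 -> i8.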
10715 
10716       if (DstEltSize == 1) {
10717         // First convert to the same size integer, then convert to mask using
10718         // setcc.
10719         assert(SrcEltSize >= 16 && "Unexpected FP type!");
10720         MVT InterimIVT = MVT::getVectorVT(MVT::getIntegerVT(SrcEltSize),
10721                                           DstVT.getVectorElementCount());
10722         Result = DAG.getNode(RISCVISDOpc, DL, InterimIVT, Src, Mask, VL);
10723 
10724         // Compare the integer result to 0. The integer should be 0 or 1/-1,
10725         // otherwise the conversion was undefined.
10726         MVT XLenVT = Subtarget.getXLenVT();
10727         SDValue SplatZero = DAG.getConstant(0, DL, XLenVT);
10728         SplatZero = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, InterimIVT,
10729                                 DAG.getUNDEF(InterimIVT), SplatZero, VL);
10730         Result = DAG.getNode(RISCVISD::SETCC_VL, DL, DstVT,
10731                              {Result, SplatZero, DAG.getCondCode(ISD::SETNE),
10732                               DAG.getUNDEF(DstVT), Mask, VL});
10733       } else {
10734         MVT InterimIVT = MVT::getVectorVT(MVT::getIntegerVT(SrcEltSize / 2),
10735                                           DstVT.getVectorElementCount());
10736 
10737         Result = DAG.getNode(RISCVISDOpc, DL, InterimIVT, Src, Mask, VL);
10738 
10739         while (InterimIVT != DstVT) {
10740           SrcEltSize /= 2;
10741           Src = Result;
10742           InterimIVT = MVT::getVectorVT(MVT::getIntegerVT(SrcEltSize / 2),
10743                                         DstVT.getVectorElementCount());
10744           Result = DAG.getNode(RISCVISD::TRUNCATE_VECTOR_VL, DL, InterimIVT,
10745                                Src, Mask, VL);
10746         }
10747       }
10748     }
10749   }
10750 
10751   MVT VT = Op.getSimpleValueType();
10752   if (!VT.isFixedLengthVector())
10753     return Result;
10754   return convertFromScalableVector(VT, Result, DAG, Subtarget);
10755 }
10756 
10757 SDValue
RISCVTargetLowering::lowerVPSpliceExperimental(SDValue Op,
10759                                                SelectionDAG &DAG) const {
10760   SDLoc DL(Op);
10761 
10762   SDValue Op1 = Op.getOperand(0);
10763   SDValue Op2 = Op.getOperand(1);
10764   SDValue Offset = Op.getOperand(2);
10765   SDValue Mask = Op.getOperand(3);
10766   SDValue EVL1 = Op.getOperand(4);
10767   SDValue EVL2 = Op.getOperand(5);
10768 
10769   const MVT XLenVT = Subtarget.getXLenVT();
10770   MVT VT = Op.getSimpleValueType();
10771   MVT ContainerVT = VT;
10772   if (VT.isFixedLengthVector()) {
10773     ContainerVT = getContainerForFixedLengthVector(VT);
10774     Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
10775     Op2 = convertToScalableVector(ContainerVT, Op2, DAG, Subtarget);
10776     MVT MaskVT = getMaskTypeFor(ContainerVT);
10777     Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
10778   }
10779 
10780   bool IsMaskVector = VT.getVectorElementType() == MVT::i1;
10781   if (IsMaskVector) {
10782     ContainerVT = ContainerVT.changeVectorElementType(MVT::i8);
10783 
10784     // Expand input operands
10785     SDValue SplatOneOp1 = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT,
10786                                       DAG.getUNDEF(ContainerVT),
10787                                       DAG.getConstant(1, DL, XLenVT), EVL1);
10788     SDValue SplatZeroOp1 = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT,
10789                                        DAG.getUNDEF(ContainerVT),
10790                                        DAG.getConstant(0, DL, XLenVT), EVL1);
10791     Op1 = DAG.getNode(RISCVISD::VMERGE_VL, DL, ContainerVT, Op1, SplatOneOp1,
10792                       SplatZeroOp1, DAG.getUNDEF(ContainerVT), EVL1);
10793 
10794     SDValue SplatOneOp2 = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT,
10795                                       DAG.getUNDEF(ContainerVT),
10796                                       DAG.getConstant(1, DL, XLenVT), EVL2);
10797     SDValue SplatZeroOp2 = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT,
10798                                        DAG.getUNDEF(ContainerVT),
10799                                        DAG.getConstant(0, DL, XLenVT), EVL2);
10800     Op2 = DAG.getNode(RISCVISD::VMERGE_VL, DL, ContainerVT, Op2, SplatOneOp2,
10801                       SplatZeroOp2, DAG.getUNDEF(ContainerVT), EVL2);
10802   }
10803 
10804   int64_t ImmValue = cast<ConstantSDNode>(Offset)->getSExtValue();
10805   SDValue DownOffset, UpOffset;
10806   if (ImmValue >= 0) {
10807     // The operand is a TargetConstant, we need to rebuild it as a regular
10808     // constant.
10809     DownOffset = DAG.getConstant(ImmValue, DL, XLenVT);
10810     UpOffset = DAG.getNode(ISD::SUB, DL, XLenVT, EVL1, DownOffset);
10811   } else {
10812     // The operand is a TargetConstant, we need to rebuild it as a regular
10813     // constant rather than negating the original operand.
10814     UpOffset = DAG.getConstant(-ImmValue, DL, XLenVT);
10815     DownOffset = DAG.getNode(ISD::SUB, DL, XLenVT, EVL1, UpOffset);
10816   }
10817 
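  // The splice itself is lowered like VECTOR_SPLICE above: slide Op1 down by
  // the offset, then slide Op2 up into the vacated tail.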
10818   SDValue SlideDown =
10819       getVSlidedown(DAG, Subtarget, DL, ContainerVT, DAG.getUNDEF(ContainerVT),
10820                     Op1, DownOffset, Mask, UpOffset);
10821   SDValue Result = getVSlideup(DAG, Subtarget, DL, ContainerVT, SlideDown, Op2,
10822                                UpOffset, Mask, EVL2, RISCVII::TAIL_AGNOSTIC);
10823 
10824   if (IsMaskVector) {
10825     // Truncate Result back to a mask vector (Result has same EVL as Op2)
10826     Result = DAG.getNode(
10827         RISCVISD::SETCC_VL, DL, ContainerVT.changeVectorElementType(MVT::i1),
10828         {Result, DAG.getConstant(0, DL, ContainerVT),
10829          DAG.getCondCode(ISD::SETNE), DAG.getUNDEF(getMaskTypeFor(ContainerVT)),
10830          Mask, EVL2});
10831   }
10832 
10833   if (!VT.isFixedLengthVector())
10834     return Result;
10835   return convertFromScalableVector(VT, Result, DAG, Subtarget);
10836 }
10837 
10838 SDValue
RISCVTargetLowering::lowerVPReverseExperimental(SDValue Op,
10840                                                 SelectionDAG &DAG) const {
10841   SDLoc DL(Op);
10842   MVT VT = Op.getSimpleValueType();
10843   MVT XLenVT = Subtarget.getXLenVT();
10844 
10845   SDValue Op1 = Op.getOperand(0);
10846   SDValue Mask = Op.getOperand(1);
10847   SDValue EVL = Op.getOperand(2);
10848 
10849   MVT ContainerVT = VT;
10850   if (VT.isFixedLengthVector()) {
10851     ContainerVT = getContainerForFixedLengthVector(VT);
10852     Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
10853     MVT MaskVT = getMaskTypeFor(ContainerVT);
10854     Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
10855   }
10856 
10857   MVT GatherVT = ContainerVT;
10858   MVT IndicesVT = ContainerVT.changeVectorElementTypeToInteger();
10859   // Check if we are working with mask vectors
10860   bool IsMaskVector = ContainerVT.getVectorElementType() == MVT::i1;
10861   if (IsMaskVector) {
10862     GatherVT = IndicesVT = ContainerVT.changeVectorElementType(MVT::i8);
10863 
10864     // Expand input operand
10865     SDValue SplatOne = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, IndicesVT,
10866                                    DAG.getUNDEF(IndicesVT),
10867                                    DAG.getConstant(1, DL, XLenVT), EVL);
10868     SDValue SplatZero = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, IndicesVT,
10869                                     DAG.getUNDEF(IndicesVT),
10870                                     DAG.getConstant(0, DL, XLenVT), EVL);
10871     Op1 = DAG.getNode(RISCVISD::VMERGE_VL, DL, IndicesVT, Op1, SplatOne,
10872                       SplatZero, DAG.getUNDEF(IndicesVT), EVL);
10873   }
10874 
10875   unsigned EltSize = GatherVT.getScalarSizeInBits();
10876   unsigned MinSize = GatherVT.getSizeInBits().getKnownMinValue();
10877   unsigned VectorBitsMax = Subtarget.getRealMaxVLen();
10878   unsigned MaxVLMAX =
10879       RISCVTargetLowering::computeVLMAX(VectorBitsMax, EltSize, MinSize);
10880 
10881   unsigned GatherOpc = RISCVISD::VRGATHER_VV_VL;
10882   // If this is SEW=8 and VLMAX is unknown or more than 256, we need
10883   // to use vrgatherei16.vv.
10884   // TODO: It's also possible to use vrgatherei16.vv for other types to
10885   // decrease register width for the index calculation.
10886   // NOTE: This code assumes VLMAX <= 65536 for LMUL=8 SEW=16.
10887   if (MaxVLMAX > 256 && EltSize == 8) {
10888     // If this is LMUL=8, we have to split before using vrgatherei16.vv.
10889     // Split the vector in half and reverse each half using a full register
10890     // reverse.
10891     // Swap the halves and concatenate them.
10892     // Slide the concatenated result by (VLMax - VL).
10893     if (MinSize == (8 * RISCV::RVVBitsPerBlock)) {
10894       auto [LoVT, HiVT] = DAG.GetSplitDestVTs(GatherVT);
10895       auto [Lo, Hi] = DAG.SplitVector(Op1, DL);
10896 
10897       SDValue LoRev = DAG.getNode(ISD::VECTOR_REVERSE, DL, LoVT, Lo);
10898       SDValue HiRev = DAG.getNode(ISD::VECTOR_REVERSE, DL, HiVT, Hi);
10899 
10900       // Reassemble the low and high pieces reversed.
10901       // NOTE: this Result is unmasked (because we do not need masks for
10902       // shuffles). If in the future this has to change, we can use a SELECT_VL
10903       // between Result and UNDEF using the mask originally passed to VP_REVERSE
10904       SDValue Result =
10905           DAG.getNode(ISD::CONCAT_VECTORS, DL, GatherVT, HiRev, LoRev);
10906 
10907       // Slide off any elements from past EVL that were reversed into the low
10908       // elements.
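      // e.g. with VLMAX=8 and EVL=5 the register reverse leaves the five live
      // elements at positions 3..7, so sliding down by VLMAX-EVL=3 moves them
      // to positions 0..4.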
10909       unsigned MinElts = GatherVT.getVectorMinNumElements();
10910       SDValue VLMax = DAG.getNode(ISD::VSCALE, DL, XLenVT,
10911                                   DAG.getConstant(MinElts, DL, XLenVT));
10912       SDValue Diff = DAG.getNode(ISD::SUB, DL, XLenVT, VLMax, EVL);
10913 
10914       Result = getVSlidedown(DAG, Subtarget, DL, GatherVT,
10915                              DAG.getUNDEF(GatherVT), Result, Diff, Mask, EVL);
10916 
10917       if (IsMaskVector) {
10918         // Truncate Result back to a mask vector
10919         Result =
10920             DAG.getNode(RISCVISD::SETCC_VL, DL, ContainerVT,
10921                         {Result, DAG.getConstant(0, DL, GatherVT),
10922                          DAG.getCondCode(ISD::SETNE),
10923                          DAG.getUNDEF(getMaskTypeFor(ContainerVT)), Mask, EVL});
10924       }
10925 
10926       if (!VT.isFixedLengthVector())
10927         return Result;
10928       return convertFromScalableVector(VT, Result, DAG, Subtarget);
10929     }
10930 
10931     // Just promote the int type to i16 which will double the LMUL.
10932     IndicesVT = MVT::getVectorVT(MVT::i16, IndicesVT.getVectorElementCount());
10933     GatherOpc = RISCVISD::VRGATHEREI16_VV_VL;
10934   }
10935 
10936   SDValue VID = DAG.getNode(RISCVISD::VID_VL, DL, IndicesVT, Mask, EVL);
10937   SDValue VecLen =
10938       DAG.getNode(ISD::SUB, DL, XLenVT, EVL, DAG.getConstant(1, DL, XLenVT));
10939   SDValue VecLenSplat = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, IndicesVT,
10940                                     DAG.getUNDEF(IndicesVT), VecLen, EVL);
10941   SDValue VRSUB = DAG.getNode(RISCVISD::SUB_VL, DL, IndicesVT, VecLenSplat, VID,
10942                               DAG.getUNDEF(IndicesVT), Mask, EVL);
10943   SDValue Result = DAG.getNode(GatherOpc, DL, GatherVT, Op1, VRSUB,
10944                                DAG.getUNDEF(GatherVT), Mask, EVL);
10945 
10946   if (IsMaskVector) {
10947     // Truncate Result back to a mask vector
10948     Result = DAG.getNode(
10949         RISCVISD::SETCC_VL, DL, ContainerVT,
10950         {Result, DAG.getConstant(0, DL, GatherVT), DAG.getCondCode(ISD::SETNE),
10951          DAG.getUNDEF(getMaskTypeFor(ContainerVT)), Mask, EVL});
10952   }
10953 
10954   if (!VT.isFixedLengthVector())
10955     return Result;
10956   return convertFromScalableVector(VT, Result, DAG, Subtarget);
10957 }
10958 
10959 SDValue RISCVTargetLowering::lowerLogicVPOp(SDValue Op,
10960                                             SelectionDAG &DAG) const {
10961   MVT VT = Op.getSimpleValueType();
10962   if (VT.getVectorElementType() != MVT::i1)
10963     return lowerVPOp(Op, DAG);
10964 
10965   // It is safe to drop the mask parameter, as masked-off elements are undef.
10966   SDValue Op1 = Op->getOperand(0);
10967   SDValue Op2 = Op->getOperand(1);
10968   SDValue VL = Op->getOperand(3);
10969 
10970   MVT ContainerVT = VT;
10971   const bool IsFixed = VT.isFixedLengthVector();
10972   if (IsFixed) {
10973     ContainerVT = getContainerForFixedLengthVector(VT);
10974     Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
10975     Op2 = convertToScalableVector(ContainerVT, Op2, DAG, Subtarget);
10976   }
10977 
10978   SDLoc DL(Op);
10979   SDValue Val = DAG.getNode(getRISCVVLOp(Op), DL, ContainerVT, Op1, Op2, VL);
10980   if (!IsFixed)
10981     return Val;
10982   return convertFromScalableVector(VT, Val, DAG, Subtarget);
10983 }
10984 
10985 SDValue RISCVTargetLowering::lowerVPStridedLoad(SDValue Op,
10986                                                 SelectionDAG &DAG) const {
10987   SDLoc DL(Op);
10988   MVT XLenVT = Subtarget.getXLenVT();
10989   MVT VT = Op.getSimpleValueType();
10990   MVT ContainerVT = VT;
10991   if (VT.isFixedLengthVector())
10992     ContainerVT = getContainerForFixedLengthVector(VT);
10993 
10994   SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
10995 
10996   auto *VPNode = cast<VPStridedLoadSDNode>(Op);
10997   // Check if the mask is known to be all ones
10998   SDValue Mask = VPNode->getMask();
10999   bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
11000 
11001   SDValue IntID = DAG.getTargetConstant(IsUnmasked ? Intrinsic::riscv_vlse
11002                                                    : Intrinsic::riscv_vlse_mask,
11003                                         DL, XLenVT);
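  // Operand order expected by the vlse/vlse_mask intrinsics: chain, intrinsic
  // ID, passthru (undef here), base pointer, stride, then the mask (masked
  // form only), the VL, and finally the tail policy (masked form only).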
11004   SmallVector<SDValue, 8> Ops{VPNode->getChain(), IntID,
11005                               DAG.getUNDEF(ContainerVT), VPNode->getBasePtr(),
11006                               VPNode->getStride()};
11007   if (!IsUnmasked) {
11008     if (VT.isFixedLengthVector()) {
11009       MVT MaskVT = ContainerVT.changeVectorElementType(MVT::i1);
11010       Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
11011     }
11012     Ops.push_back(Mask);
11013   }
11014   Ops.push_back(VPNode->getVectorLength());
11015   if (!IsUnmasked) {
11016     SDValue Policy = DAG.getTargetConstant(RISCVII::TAIL_AGNOSTIC, DL, XLenVT);
11017     Ops.push_back(Policy);
11018   }
11019 
11020   SDValue Result =
11021       DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops,
11022                               VPNode->getMemoryVT(), VPNode->getMemOperand());
11023   SDValue Chain = Result.getValue(1);
11024 
11025   if (VT.isFixedLengthVector())
11026     Result = convertFromScalableVector(VT, Result, DAG, Subtarget);
11027 
11028   return DAG.getMergeValues({Result, Chain}, DL);
11029 }
11030 
11031 SDValue RISCVTargetLowering::lowerVPStridedStore(SDValue Op,
11032                                                  SelectionDAG &DAG) const {
11033   SDLoc DL(Op);
11034   MVT XLenVT = Subtarget.getXLenVT();
11035 
11036   auto *VPNode = cast<VPStridedStoreSDNode>(Op);
11037   SDValue StoreVal = VPNode->getValue();
11038   MVT VT = StoreVal.getSimpleValueType();
11039   MVT ContainerVT = VT;
11040   if (VT.isFixedLengthVector()) {
11041     ContainerVT = getContainerForFixedLengthVector(VT);
11042     StoreVal = convertToScalableVector(ContainerVT, StoreVal, DAG, Subtarget);
11043   }
11044 
11045   // Check if the mask is known to be all ones
11046   SDValue Mask = VPNode->getMask();
11047   bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
11048 
11049   SDValue IntID = DAG.getTargetConstant(IsUnmasked ? Intrinsic::riscv_vsse
11050                                                    : Intrinsic::riscv_vsse_mask,
11051                                         DL, XLenVT);
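  // Operand order expected by the vsse/vsse_mask intrinsics: chain, intrinsic
  // ID, the value to store, base pointer, stride, then the mask (masked form
  // only) and the VL.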
11052   SmallVector<SDValue, 8> Ops{VPNode->getChain(), IntID, StoreVal,
11053                               VPNode->getBasePtr(), VPNode->getStride()};
11054   if (!IsUnmasked) {
11055     if (VT.isFixedLengthVector()) {
11056       MVT MaskVT = ContainerVT.changeVectorElementType(MVT::i1);
11057       Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
11058     }
11059     Ops.push_back(Mask);
11060   }
11061   Ops.push_back(VPNode->getVectorLength());
11062 
11063   return DAG.getMemIntrinsicNode(ISD::INTRINSIC_VOID, DL, VPNode->getVTList(),
11064                                  Ops, VPNode->getMemoryVT(),
11065                                  VPNode->getMemOperand());
11066 }
11067 
11068 // Custom lower MGATHER/VP_GATHER to a legalized form for RVV. It will then be
11069 // matched to an RVV indexed load. The RVV indexed load instructions only
11070 // support the "unsigned unscaled" addressing mode; indices are implicitly
11071 // zero-extended or truncated to XLEN and are treated as byte offsets. Any
11072 // signed or scaled indexing is extended to the XLEN value type and scaled
11073 // accordingly.
11074 SDValue RISCVTargetLowering::lowerMaskedGather(SDValue Op,
11075                                                SelectionDAG &DAG) const {
11076   SDLoc DL(Op);
11077   MVT VT = Op.getSimpleValueType();
11078 
11079   const auto *MemSD = cast<MemSDNode>(Op.getNode());
11080   EVT MemVT = MemSD->getMemoryVT();
11081   MachineMemOperand *MMO = MemSD->getMemOperand();
11082   SDValue Chain = MemSD->getChain();
11083   SDValue BasePtr = MemSD->getBasePtr();
11084 
11085   ISD::LoadExtType LoadExtType;
11086   SDValue Index, Mask, PassThru, VL;
11087 
11088   if (auto *VPGN = dyn_cast<VPGatherSDNode>(Op.getNode())) {
11089     Index = VPGN->getIndex();
11090     Mask = VPGN->getMask();
11091     PassThru = DAG.getUNDEF(VT);
11092     VL = VPGN->getVectorLength();
11093     // VP doesn't support extending loads.
11094     LoadExtType = ISD::NON_EXTLOAD;
11095   } else {
11096     // Else it must be a MGATHER.
11097     auto *MGN = cast<MaskedGatherSDNode>(Op.getNode());
11098     Index = MGN->getIndex();
11099     Mask = MGN->getMask();
11100     PassThru = MGN->getPassThru();
11101     LoadExtType = MGN->getExtensionType();
11102   }
11103 
11104   MVT IndexVT = Index.getSimpleValueType();
11105   MVT XLenVT = Subtarget.getXLenVT();
11106 
11107   assert(VT.getVectorElementCount() == IndexVT.getVectorElementCount() &&
11108          "Unexpected VTs!");
11109   assert(BasePtr.getSimpleValueType() == XLenVT && "Unexpected pointer type");
11110   // Targets have to explicitly opt in to extending vector loads.
11111   assert(LoadExtType == ISD::NON_EXTLOAD &&
11112          "Unexpected extending MGATHER/VP_GATHER");
11113   (void)LoadExtType;
11114 
11115   // If the mask is known to be all ones, optimize to an unmasked intrinsic;
11116   // the selection of the masked intrinsics doesn't do this for us.
11117   bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
11118 
11119   MVT ContainerVT = VT;
11120   if (VT.isFixedLengthVector()) {
11121     ContainerVT = getContainerForFixedLengthVector(VT);
11122     IndexVT = MVT::getVectorVT(IndexVT.getVectorElementType(),
11123                                ContainerVT.getVectorElementCount());
11124 
11125     Index = convertToScalableVector(IndexVT, Index, DAG, Subtarget);
11126 
11127     if (!IsUnmasked) {
11128       MVT MaskVT = getMaskTypeFor(ContainerVT);
11129       Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
11130       PassThru = convertToScalableVector(ContainerVT, PassThru, DAG, Subtarget);
11131     }
11132   }
11133 
11134   if (!VL)
11135     VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
11136 
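  // On RV32, i64 indices can simply be truncated: the indexed load treats
  // each index as an XLEN-wide byte offset, so only the low XLEN bits can
  // affect the resulting address.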
11137   if (XLenVT == MVT::i32 && IndexVT.getVectorElementType().bitsGT(XLenVT)) {
11138     IndexVT = IndexVT.changeVectorElementType(XLenVT);
11139     Index = DAG.getNode(ISD::TRUNCATE, DL, IndexVT, Index);
11140   }
11141 
11142   unsigned IntID =
11143       IsUnmasked ? Intrinsic::riscv_vluxei : Intrinsic::riscv_vluxei_mask;
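  // Operand order expected by the vluxei/vluxei_mask intrinsics: chain,
  // intrinsic ID, passthru (undef when unmasked), base pointer, index vector,
  // then the mask (masked form only), the VL, and the tail policy (masked
  // form only).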
11144   SmallVector<SDValue, 8> Ops{Chain, DAG.getTargetConstant(IntID, DL, XLenVT)};
11145   if (IsUnmasked)
11146     Ops.push_back(DAG.getUNDEF(ContainerVT));
11147   else
11148     Ops.push_back(PassThru);
11149   Ops.push_back(BasePtr);
11150   Ops.push_back(Index);
11151   if (!IsUnmasked)
11152     Ops.push_back(Mask);
11153   Ops.push_back(VL);
11154   if (!IsUnmasked)
11155     Ops.push_back(DAG.getTargetConstant(RISCVII::TAIL_AGNOSTIC, DL, XLenVT));
11156 
11157   SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
11158   SDValue Result =
11159       DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops, MemVT, MMO);
11160   Chain = Result.getValue(1);
11161 
11162   if (VT.isFixedLengthVector())
11163     Result = convertFromScalableVector(VT, Result, DAG, Subtarget);
11164 
11165   return DAG.getMergeValues({Result, Chain}, DL);
11166 }
11167 
11168 // Custom lower MSCATTER/VP_SCATTER to a legalized form for RVV. It will then be
11169 // matched to an RVV indexed store. The RVV indexed store instructions only
11170 // support the "unsigned unscaled" addressing mode; indices are implicitly
11171 // zero-extended or truncated to XLEN and are treated as byte offsets. Any
11172 // signed or scaled indexing is extended to the XLEN value type and scaled
11173 // accordingly.
11174 SDValue RISCVTargetLowering::lowerMaskedScatter(SDValue Op,
11175                                                 SelectionDAG &DAG) const {
11176   SDLoc DL(Op);
11177   const auto *MemSD = cast<MemSDNode>(Op.getNode());
11178   EVT MemVT = MemSD->getMemoryVT();
11179   MachineMemOperand *MMO = MemSD->getMemOperand();
11180   SDValue Chain = MemSD->getChain();
11181   SDValue BasePtr = MemSD->getBasePtr();
11182 
11183   bool IsTruncatingStore = false;
11184   SDValue Index, Mask, Val, VL;
11185 
11186   if (auto *VPSN = dyn_cast<VPScatterSDNode>(Op.getNode())) {
11187     Index = VPSN->getIndex();
11188     Mask = VPSN->getMask();
11189     Val = VPSN->getValue();
11190     VL = VPSN->getVectorLength();
11191     // VP doesn't support truncating stores.
11192     IsTruncatingStore = false;
11193   } else {
11194     // Else it must be a MSCATTER.
11195     auto *MSN = cast<MaskedScatterSDNode>(Op.getNode());
11196     Index = MSN->getIndex();
11197     Mask = MSN->getMask();
11198     Val = MSN->getValue();
11199     IsTruncatingStore = MSN->isTruncatingStore();
11200   }
11201 
11202   MVT VT = Val.getSimpleValueType();
11203   MVT IndexVT = Index.getSimpleValueType();
11204   MVT XLenVT = Subtarget.getXLenVT();
11205 
11206   assert(VT.getVectorElementCount() == IndexVT.getVectorElementCount() &&
11207          "Unexpected VTs!");
11208   assert(BasePtr.getSimpleValueType() == XLenVT && "Unexpected pointer type");
11209   // Targets have to explicitly opt in to extending vector loads and
11210   // truncating vector stores.
11211   assert(!IsTruncatingStore && "Unexpected truncating MSCATTER/VP_SCATTER");
11212   (void)IsTruncatingStore;
11213 
11214   // If the mask is known to be all ones, optimize to an unmasked intrinsic;
11215   // the selection of the masked intrinsics doesn't do this for us.
11216   bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
11217 
11218   MVT ContainerVT = VT;
11219   if (VT.isFixedLengthVector()) {
11220     ContainerVT = getContainerForFixedLengthVector(VT);
11221     IndexVT = MVT::getVectorVT(IndexVT.getVectorElementType(),
11222                                ContainerVT.getVectorElementCount());
11223 
11224     Index = convertToScalableVector(IndexVT, Index, DAG, Subtarget);
11225     Val = convertToScalableVector(ContainerVT, Val, DAG, Subtarget);
11226 
11227     if (!IsUnmasked) {
11228       MVT MaskVT = getMaskTypeFor(ContainerVT);
11229       Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
11230     }
11231   }
11232 
11233   if (!VL)
11234     VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
11235 
11236   if (XLenVT == MVT::i32 && IndexVT.getVectorElementType().bitsGT(XLenVT)) {
11237     IndexVT = IndexVT.changeVectorElementType(XLenVT);
11238     Index = DAG.getNode(ISD::TRUNCATE, DL, IndexVT, Index);
11239   }
11240 
11241   unsigned IntID =
11242       IsUnmasked ? Intrinsic::riscv_vsoxei : Intrinsic::riscv_vsoxei_mask;
11243   SmallVector<SDValue, 8> Ops{Chain, DAG.getTargetConstant(IntID, DL, XLenVT)};
11244   Ops.push_back(Val);
11245   Ops.push_back(BasePtr);
11246   Ops.push_back(Index);
11247   if (!IsUnmasked)
11248     Ops.push_back(Mask);
11249   Ops.push_back(VL);
11250 
11251   return DAG.getMemIntrinsicNode(ISD::INTRINSIC_VOID, DL,
11252                                  DAG.getVTList(MVT::Other), Ops, MemVT, MMO);
11253 }
11254 
11255 SDValue RISCVTargetLowering::lowerGET_ROUNDING(SDValue Op,
11256                                                SelectionDAG &DAG) const {
11257   const MVT XLenVT = Subtarget.getXLenVT();
11258   SDLoc DL(Op);
11259   SDValue Chain = Op->getOperand(0);
11260   SDValue SysRegNo = DAG.getTargetConstant(
11261       RISCVSysReg::lookupSysRegByName("FRM")->Encoding, DL, XLenVT);
11262   SDVTList VTs = DAG.getVTList(XLenVT, MVT::Other);
11263   SDValue RM = DAG.getNode(RISCVISD::READ_CSR, DL, VTs, Chain, SysRegNo);
11264 
11265   // The encoding used for the rounding mode in RISC-V differs from the one
11266   // used by FLT_ROUNDS. To convert between them, the RISC-V rounding mode is
11267   // used as an index into a table consisting of a sequence of 4-bit fields,
11268   // each holding the corresponding FLT_ROUNDS mode.
11269   static const int Table =
11270       (int(RoundingMode::NearestTiesToEven) << 4 * RISCVFPRndMode::RNE) |
11271       (int(RoundingMode::TowardZero) << 4 * RISCVFPRndMode::RTZ) |
11272       (int(RoundingMode::TowardNegative) << 4 * RISCVFPRndMode::RDN) |
11273       (int(RoundingMode::TowardPositive) << 4 * RISCVFPRndMode::RUP) |
11274       (int(RoundingMode::NearestTiesToAway) << 4 * RISCVFPRndMode::RMM);
11275 
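  // Worked example: with frm == RISCVFPRndMode::RTZ (1), the code below
  // computes (Table >> (1 * 4)) & 7, i.e. bits [7:4] of the table, which hold
  // int(RoundingMode::TowardZero), the matching FLT_ROUNDS value.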
11276   SDValue Shift =
11277       DAG.getNode(ISD::SHL, DL, XLenVT, RM, DAG.getConstant(2, DL, XLenVT));
11278   SDValue Shifted = DAG.getNode(ISD::SRL, DL, XLenVT,
11279                                 DAG.getConstant(Table, DL, XLenVT), Shift);
11280   SDValue Masked = DAG.getNode(ISD::AND, DL, XLenVT, Shifted,
11281                                DAG.getConstant(7, DL, XLenVT));
11282 
11283   return DAG.getMergeValues({Masked, Chain}, DL);
11284 }
11285 
11286 SDValue RISCVTargetLowering::lowerSET_ROUNDING(SDValue Op,
11287                                                SelectionDAG &DAG) const {
11288   const MVT XLenVT = Subtarget.getXLenVT();
11289   SDLoc DL(Op);
11290   SDValue Chain = Op->getOperand(0);
11291   SDValue RMValue = Op->getOperand(1);
11292   SDValue SysRegNo = DAG.getTargetConstant(
11293       RISCVSysReg::lookupSysRegByName("FRM")->Encoding, DL, XLenVT);
11294 
11295   // The encoding used for the rounding mode in RISC-V differs from the one
11296   // used by FLT_ROUNDS. To convert between them, the C rounding mode is used
11297   // as an index into a table consisting of a sequence of 4-bit fields, each
11298   // holding the corresponding RISC-V rounding mode.
11299   static const unsigned Table =
11300       (RISCVFPRndMode::RNE << 4 * int(RoundingMode::NearestTiesToEven)) |
11301       (RISCVFPRndMode::RTZ << 4 * int(RoundingMode::TowardZero)) |
11302       (RISCVFPRndMode::RDN << 4 * int(RoundingMode::TowardNegative)) |
11303       (RISCVFPRndMode::RUP << 4 * int(RoundingMode::TowardPositive)) |
11304       (RISCVFPRndMode::RMM << 4 * int(RoundingMode::NearestTiesToAway));
11305 
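  // Worked example: for RoundingMode::TowardNegative the shift below selects
  // the 4-bit field at bit 4 * int(RoundingMode::TowardNegative), which holds
  // RISCVFPRndMode::RDN; that value is then written to the FRM CSR.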
11306   RMValue = DAG.getNode(ISD::ZERO_EXTEND, DL, XLenVT, RMValue);
11307 
11308   SDValue Shift = DAG.getNode(ISD::SHL, DL, XLenVT, RMValue,
11309                               DAG.getConstant(2, DL, XLenVT));
11310   SDValue Shifted = DAG.getNode(ISD::SRL, DL, XLenVT,
11311                                 DAG.getConstant(Table, DL, XLenVT), Shift);
11312   RMValue = DAG.getNode(ISD::AND, DL, XLenVT, Shifted,
11313                         DAG.getConstant(0x7, DL, XLenVT));
11314   return DAG.getNode(RISCVISD::WRITE_CSR, DL, MVT::Other, Chain, SysRegNo,
11315                      RMValue);
11316 }
11317 
11318 SDValue RISCVTargetLowering::lowerEH_DWARF_CFA(SDValue Op,
11319                                                SelectionDAG &DAG) const {
11320   MachineFunction &MF = DAG.getMachineFunction();
11321 
11322   bool isRISCV64 = Subtarget.is64Bit();
11323   EVT PtrVT = getPointerTy(DAG.getDataLayout());
11324 
11325   int FI = MF.getFrameInfo().CreateFixedObject(isRISCV64 ? 8 : 4, 0, false);
11326   return DAG.getFrameIndex(FI, PtrVT);
11327 }
11328 
11329 // Returns the opcode of the target-specific SDNode that implements the 32-bit
11330 // form of the given Opcode.
11331 static RISCVISD::NodeType getRISCVWOpcode(unsigned Opcode) {
11332   switch (Opcode) {
11333   default:
11334     llvm_unreachable("Unexpected opcode");
11335   case ISD::SHL:
11336     return RISCVISD::SLLW;
11337   case ISD::SRA:
11338     return RISCVISD::SRAW;
11339   case ISD::SRL:
11340     return RISCVISD::SRLW;
11341   case ISD::SDIV:
11342     return RISCVISD::DIVW;
11343   case ISD::UDIV:
11344     return RISCVISD::DIVUW;
11345   case ISD::UREM:
11346     return RISCVISD::REMUW;
11347   case ISD::ROTL:
11348     return RISCVISD::ROLW;
11349   case ISD::ROTR:
11350     return RISCVISD::RORW;
11351   }
11352 }
11353 
11354 // Converts the given i8/i16/i32 operation to a target-specific SelectionDAG
11355 // node. Because i8/i16/i32 isn't a legal type for RV64, these operations would
11356 // otherwise be promoted to i64, making it difficult to select the
11357 // SLLW/DIVUW/.../*W instructions later on, because the fact that the operation
11358 // was originally of type i8/i16/i32 is lost.
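// For example, an i32 ISD::SRA is rebuilt here as
//   (trunc i32 (RISCVISD::SRAW (any_extend i64 X), (any_extend i64 Y)))
// so that instruction selection can still pick sraw.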
11359 static SDValue customLegalizeToWOp(SDNode *N, SelectionDAG &DAG,
11360                                    unsigned ExtOpc = ISD::ANY_EXTEND) {
11361   SDLoc DL(N);
11362   RISCVISD::NodeType WOpcode = getRISCVWOpcode(N->getOpcode());
11363   SDValue NewOp0 = DAG.getNode(ExtOpc, DL, MVT::i64, N->getOperand(0));
11364   SDValue NewOp1 = DAG.getNode(ExtOpc, DL, MVT::i64, N->getOperand(1));
11365   SDValue NewRes = DAG.getNode(WOpcode, DL, MVT::i64, NewOp0, NewOp1);
11366   // ReplaceNodeResults requires we maintain the same type for the return value.
11367   return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewRes);
11368 }
11369 
11370 // Converts the given 32-bit operation to an i64 operation with sign extension
11371 // semantics in order to reduce the number of sign extension instructions.
11372 static SDValue customLegalizeToWOpWithSExt(SDNode *N, SelectionDAG &DAG) {
11373   SDLoc DL(N);
11374   SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(0));
11375   SDValue NewOp1 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(1));
11376   SDValue NewWOp = DAG.getNode(N->getOpcode(), DL, MVT::i64, NewOp0, NewOp1);
11377   SDValue NewRes = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i64, NewWOp,
11378                                DAG.getValueType(MVT::i32));
11379   return DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, NewRes);
11380 }
11381 
11382 void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
11383                                              SmallVectorImpl<SDValue> &Results,
11384                                              SelectionDAG &DAG) const {
11385   SDLoc DL(N);
11386   switch (N->getOpcode()) {
11387   default:
11388     llvm_unreachable("Don't know how to custom type legalize this operation!");
11389   case ISD::STRICT_FP_TO_SINT:
11390   case ISD::STRICT_FP_TO_UINT:
11391   case ISD::FP_TO_SINT:
11392   case ISD::FP_TO_UINT: {
11393     assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
11394            "Unexpected custom legalisation");
11395     bool IsStrict = N->isStrictFPOpcode();
11396     bool IsSigned = N->getOpcode() == ISD::FP_TO_SINT ||
11397                     N->getOpcode() == ISD::STRICT_FP_TO_SINT;
11398     SDValue Op0 = IsStrict ? N->getOperand(1) : N->getOperand(0);
11399     if (getTypeAction(*DAG.getContext(), Op0.getValueType()) !=
11400         TargetLowering::TypeSoftenFloat) {
11401       if (!isTypeLegal(Op0.getValueType()))
11402         return;
11403       if (IsStrict) {
11404         SDValue Chain = N->getOperand(0);
11405         // In the absence of Zfh, promote f16 to f32, then convert.
11406         if (Op0.getValueType() == MVT::f16 &&
11407             !Subtarget.hasStdExtZfhOrZhinx()) {
11408           Op0 = DAG.getNode(ISD::STRICT_FP_EXTEND, DL, {MVT::f32, MVT::Other},
11409                             {Chain, Op0});
11410           Chain = Op0.getValue(1);
11411         }
11412         unsigned Opc = IsSigned ? RISCVISD::STRICT_FCVT_W_RV64
11413                                 : RISCVISD::STRICT_FCVT_WU_RV64;
11414         SDVTList VTs = DAG.getVTList(MVT::i64, MVT::Other);
11415         SDValue Res = DAG.getNode(
11416             Opc, DL, VTs, Chain, Op0,
11417             DAG.getTargetConstant(RISCVFPRndMode::RTZ, DL, MVT::i64));
11418         Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
11419         Results.push_back(Res.getValue(1));
11420         return;
11421       }
11422       // For bf16, or f16 in the absence of Zfh, promote [b]f16 to f32 and then
11423       // convert.
11424       if ((Op0.getValueType() == MVT::f16 &&
11425            !Subtarget.hasStdExtZfhOrZhinx()) ||
11426           Op0.getValueType() == MVT::bf16)
11427         Op0 = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Op0);
11428 
11429       unsigned Opc = IsSigned ? RISCVISD::FCVT_W_RV64 : RISCVISD::FCVT_WU_RV64;
11430       SDValue Res =
11431           DAG.getNode(Opc, DL, MVT::i64, Op0,
11432                       DAG.getTargetConstant(RISCVFPRndMode::RTZ, DL, MVT::i64));
11433       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
11434       return;
11435     }
11436     // If the FP type needs to be softened, emit a library call using the 'si'
11437     // version. If we left it to default legalization we'd end up with 'di'. If
11438     // the FP type doesn't need to be softened just let generic type
11439     // legalization promote the result type.
11440     RTLIB::Libcall LC;
11441     if (IsSigned)
11442       LC = RTLIB::getFPTOSINT(Op0.getValueType(), N->getValueType(0));
11443     else
11444       LC = RTLIB::getFPTOUINT(Op0.getValueType(), N->getValueType(0));
11445     MakeLibCallOptions CallOptions;
11446     EVT OpVT = Op0.getValueType();
11447     CallOptions.setTypeListBeforeSoften(OpVT, N->getValueType(0), true);
11448     SDValue Chain = IsStrict ? N->getOperand(0) : SDValue();
11449     SDValue Result;
11450     std::tie(Result, Chain) =
11451         makeLibCall(DAG, LC, N->getValueType(0), Op0, CallOptions, DL, Chain);
11452     Results.push_back(Result);
11453     if (IsStrict)
11454       Results.push_back(Chain);
11455     break;
11456   }
11457   case ISD::LROUND: {
11458     SDValue Op0 = N->getOperand(0);
11459     EVT Op0VT = Op0.getValueType();
11460     if (getTypeAction(*DAG.getContext(), Op0.getValueType()) !=
11461         TargetLowering::TypeSoftenFloat) {
11462       if (!isTypeLegal(Op0VT))
11463         return;
11464 
11465       // In the absence of Zfh, promote f16 to f32, then convert.
11466       if (Op0.getValueType() == MVT::f16 && !Subtarget.hasStdExtZfhOrZhinx())
11467         Op0 = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Op0);
11468 
11469       SDValue Res =
11470           DAG.getNode(RISCVISD::FCVT_W_RV64, DL, MVT::i64, Op0,
11471                       DAG.getTargetConstant(RISCVFPRndMode::RMM, DL, MVT::i64));
11472       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
11473       return;
11474     }
11475     // If the FP type needs to be softened, emit a library call to lround. We'll
11476     // need to truncate the result. We assume any value that doesn't fit in i32
11477     // is allowed to return an unspecified value.
11478     RTLIB::Libcall LC =
11479         Op0.getValueType() == MVT::f64 ? RTLIB::LROUND_F64 : RTLIB::LROUND_F32;
11480     MakeLibCallOptions CallOptions;
11481     EVT OpVT = Op0.getValueType();
11482     CallOptions.setTypeListBeforeSoften(OpVT, MVT::i64, true);
11483     SDValue Result = makeLibCall(DAG, LC, MVT::i64, Op0, CallOptions, DL).first;
11484     Result = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Result);
11485     Results.push_back(Result);
11486     break;
11487   }
11488   case ISD::READCYCLECOUNTER: {
11489     assert(!Subtarget.is64Bit() &&
11490            "READCYCLECOUNTER only has custom type legalization on riscv32");
11491 
11492     SDVTList VTs = DAG.getVTList(MVT::i32, MVT::i32, MVT::Other);
11493     SDValue RCW =
11494         DAG.getNode(RISCVISD::READ_CYCLE_WIDE, DL, VTs, N->getOperand(0));
11495 
11496     Results.push_back(
11497         DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i64, RCW, RCW.getValue(1)));
11498     Results.push_back(RCW.getValue(2));
11499     break;
11500   }
11501   case ISD::LOAD: {
11502     if (!ISD::isNON_EXTLoad(N))
11503       return;
11504 
11505     // Use a SEXTLOAD instead of the default EXTLOAD. Similar to the
11506     // sext_inreg we emit for ADD/SUB/MUL/SLLI.
11507     LoadSDNode *Ld = cast<LoadSDNode>(N);
11508 
11509     SDLoc dl(N);
11510     SDValue Res = DAG.getExtLoad(ISD::SEXTLOAD, dl, MVT::i64, Ld->getChain(),
11511                                  Ld->getBasePtr(), Ld->getMemoryVT(),
11512                                  Ld->getMemOperand());
11513     Results.push_back(DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, Res));
11514     Results.push_back(Res.getValue(1));
11515     return;
11516   }
11517   case ISD::MUL: {
11518     unsigned Size = N->getSimpleValueType(0).getSizeInBits();
11519     unsigned XLen = Subtarget.getXLen();
11520     // This multiply needs to be expanded; try to use MULHSU+MUL if possible.
11521     if (Size > XLen) {
11522       assert(Size == (XLen * 2) && "Unexpected custom legalisation");
11523       SDValue LHS = N->getOperand(0);
11524       SDValue RHS = N->getOperand(1);
11525       APInt HighMask = APInt::getHighBitsSet(Size, XLen);
11526 
11527       bool LHSIsU = DAG.MaskedValueIsZero(LHS, HighMask);
11528       bool RHSIsU = DAG.MaskedValueIsZero(RHS, HighMask);
11529       // We need exactly one side to be unsigned.
11530       if (LHSIsU == RHSIsU)
11531         return;
11532 
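      // MULHSU returns the high XLEN bits of the product of a signed LHS and
      // an unsigned RHS, so pairing it with a plain MUL for the low half
      // reconstructs the full 2*XLEN-bit result without a library call.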
11533       auto MakeMULPair = [&](SDValue S, SDValue U) {
11534         MVT XLenVT = Subtarget.getXLenVT();
11535         S = DAG.getNode(ISD::TRUNCATE, DL, XLenVT, S);
11536         U = DAG.getNode(ISD::TRUNCATE, DL, XLenVT, U);
11537         SDValue Lo = DAG.getNode(ISD::MUL, DL, XLenVT, S, U);
11538         SDValue Hi = DAG.getNode(RISCVISD::MULHSU, DL, XLenVT, S, U);
11539         return DAG.getNode(ISD::BUILD_PAIR, DL, N->getValueType(0), Lo, Hi);
11540       };
11541 
11542       bool LHSIsS = DAG.ComputeNumSignBits(LHS) > XLen;
11543       bool RHSIsS = DAG.ComputeNumSignBits(RHS) > XLen;
11544 
11545       // The other operand should be signed, but still prefer MULH when
11546       // possible.
11547       if (RHSIsU && LHSIsS && !RHSIsS)
11548         Results.push_back(MakeMULPair(LHS, RHS));
11549       else if (LHSIsU && RHSIsS && !LHSIsS)
11550         Results.push_back(MakeMULPair(RHS, LHS));
11551 
11552       return;
11553     }
11554     [[fallthrough]];
11555   }
11556   case ISD::ADD:
11557   case ISD::SUB:
11558     assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
11559            "Unexpected custom legalisation");
11560     Results.push_back(customLegalizeToWOpWithSExt(N, DAG));
11561     break;
11562   case ISD::SHL:
11563   case ISD::SRA:
11564   case ISD::SRL:
11565     assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
11566            "Unexpected custom legalisation");
11567     if (N->getOperand(1).getOpcode() != ISD::Constant) {
11568       // If we can use a BSET instruction, allow default promotion to apply.
11569       if (N->getOpcode() == ISD::SHL && Subtarget.hasStdExtZbs() &&
11570           isOneConstant(N->getOperand(0)))
11571         break;
11572       Results.push_back(customLegalizeToWOp(N, DAG));
11573       break;
11574     }
11575 
11576     // Custom legalize ISD::SHL by placing a SIGN_EXTEND_INREG after. This is
11577     // similar to customLegalizeToWOpWithSExt, but we must zero_extend the
11578     // shift amount.
11579     if (N->getOpcode() == ISD::SHL) {
11580       SDLoc DL(N);
11581       SDValue NewOp0 =
11582           DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(0));
11583       SDValue NewOp1 =
11584           DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, N->getOperand(1));
11585       SDValue NewWOp = DAG.getNode(ISD::SHL, DL, MVT::i64, NewOp0, NewOp1);
11586       SDValue NewRes = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i64, NewWOp,
11587                                    DAG.getValueType(MVT::i32));
11588       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, NewRes));
11589     }
11590 
11591     break;
11592   case ISD::ROTL:
11593   case ISD::ROTR:
11594     assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
11595            "Unexpected custom legalisation");
11596     assert((Subtarget.hasStdExtZbb() || Subtarget.hasStdExtZbkb() ||
11597             Subtarget.hasVendorXTHeadBb()) &&
11598            "Unexpected custom legalization");
11599     if (!isa<ConstantSDNode>(N->getOperand(1)) &&
11600         !(Subtarget.hasStdExtZbb() || Subtarget.hasStdExtZbkb()))
11601       return;
11602     Results.push_back(customLegalizeToWOp(N, DAG));
11603     break;
11604   case ISD::CTTZ:
11605   case ISD::CTTZ_ZERO_UNDEF:
11606   case ISD::CTLZ:
11607   case ISD::CTLZ_ZERO_UNDEF: {
11608     assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
11609            "Unexpected custom legalisation");
11610 
11611     SDValue NewOp0 =
11612         DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(0));
11613     bool IsCTZ =
11614         N->getOpcode() == ISD::CTTZ || N->getOpcode() == ISD::CTTZ_ZERO_UNDEF;
11615     unsigned Opc = IsCTZ ? RISCVISD::CTZW : RISCVISD::CLZW;
11616     SDValue Res = DAG.getNode(Opc, DL, MVT::i64, NewOp0);
11617     Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
11618     return;
11619   }
11620   case ISD::SDIV:
11621   case ISD::UDIV:
11622   case ISD::UREM: {
11623     MVT VT = N->getSimpleValueType(0);
11624     assert((VT == MVT::i8 || VT == MVT::i16 || VT == MVT::i32) &&
11625            Subtarget.is64Bit() && Subtarget.hasStdExtM() &&
11626            "Unexpected custom legalisation");
11627     // Don't promote division/remainder by a constant, since we should expand
11628     // those to a multiply by a magic constant instead.
11629     AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
11630     if (N->getOperand(1).getOpcode() == ISD::Constant &&
11631         !isIntDivCheap(N->getValueType(0), Attr))
11632       return;
11633 
11634     // If the input is i32, use ANY_EXTEND since the W instructions don't read
11635     // the upper 32 bits. For other types we need to sign or zero extend
11636     // based on the opcode.
11637     unsigned ExtOpc = ISD::ANY_EXTEND;
11638     if (VT != MVT::i32)
11639       ExtOpc = N->getOpcode() == ISD::SDIV ? ISD::SIGN_EXTEND
11640                                            : ISD::ZERO_EXTEND;
11641 
11642     Results.push_back(customLegalizeToWOp(N, DAG, ExtOpc));
11643     break;
11644   }
11645   case ISD::SADDO: {
11646     assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
11647            "Unexpected custom legalisation");
11648 
11649     // If the RHS is a constant, we can simplify ConditionRHS below. Otherwise
11650     // use the default legalization.
11651     if (!isa<ConstantSDNode>(N->getOperand(1)))
11652       return;
11653 
11654     SDValue LHS = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, N->getOperand(0));
11655     SDValue RHS = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, N->getOperand(1));
11656     SDValue Res = DAG.getNode(ISD::ADD, DL, MVT::i64, LHS, RHS);
11657     Res = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i64, Res,
11658                       DAG.getValueType(MVT::i32));
11659 
11660     SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
11661 
11662     // For an addition, the result should be less than one of the operands (LHS)
11663     // if and only if the other operand (RHS) is negative, otherwise there will
11664     // be overflow.
11665     // For a subtraction, the result should be less than one of the operands
11666     // (LHS) if and only if the other operand (RHS) is (non-zero) positive,
11667     // otherwise there will be overflow.
11668     EVT OType = N->getValueType(1);
11669     SDValue ResultLowerThanLHS = DAG.getSetCC(DL, OType, Res, LHS, ISD::SETLT);
11670     SDValue ConditionRHS = DAG.getSetCC(DL, OType, RHS, Zero, ISD::SETLT);
11671 
11672     SDValue Overflow =
11673         DAG.getNode(ISD::XOR, DL, OType, ConditionRHS, ResultLowerThanLHS);
11674     Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
11675     Results.push_back(Overflow);
11676     return;
11677   }
11678   case ISD::UADDO:
11679   case ISD::USUBO: {
11680     assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
11681            "Unexpected custom legalisation");
11682     bool IsAdd = N->getOpcode() == ISD::UADDO;
11683     // Create an ADDW or SUBW.
11684     SDValue LHS = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(0));
11685     SDValue RHS = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(1));
11686     SDValue Res =
11687         DAG.getNode(IsAdd ? ISD::ADD : ISD::SUB, DL, MVT::i64, LHS, RHS);
11688     Res = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i64, Res,
11689                       DAG.getValueType(MVT::i32));
11690 
11691     SDValue Overflow;
11692     if (IsAdd && isOneConstant(RHS)) {
11693       // Special case uaddo X, 1 overflowed if the addition result is 0.
11694       // The general case (X + C) < C is not necessarily beneficial. Although we
11695       // reduce the live range of X, we may introduce the materialization of
11696       // constant C, especially when the setcc result is used by a branch, and
11697       // RISC-V has no compare-with-constant-and-branch instruction.
11698       Overflow = DAG.getSetCC(DL, N->getValueType(1), Res,
11699                               DAG.getConstant(0, DL, MVT::i64), ISD::SETEQ);
11700     } else if (IsAdd && isAllOnesConstant(RHS)) {
11701       // Special case uaddo X, -1 overflowed if X != 0.
11702       Overflow = DAG.getSetCC(DL, N->getValueType(1), N->getOperand(0),
11703                               DAG.getConstant(0, DL, MVT::i32), ISD::SETNE);
11704     } else {
11705       // Sign extend the LHS and perform an unsigned compare with the ADDW
11706       // result. Since the inputs are sign extended from i32, this is equivalent
11707       // to comparing the lower 32 bits.
11708       LHS = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, N->getOperand(0));
11709       Overflow = DAG.getSetCC(DL, N->getValueType(1), Res, LHS,
11710                               IsAdd ? ISD::SETULT : ISD::SETUGT);
11711     }
11712 
11713     Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
11714     Results.push_back(Overflow);
11715     return;
11716   }
11717   case ISD::UADDSAT:
11718   case ISD::USUBSAT: {
11719     assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
11720            "Unexpected custom legalisation");
11721     if (Subtarget.hasStdExtZbb()) {
11722       // With Zbb we can sign extend and let LegalizeDAG use minu/maxu. Using
11723       // sign extend allows overflow of the lower 32 bits to be detected on
11724       // the promoted size.
11725       SDValue LHS =
11726           DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, N->getOperand(0));
11727       SDValue RHS =
11728           DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, N->getOperand(1));
11729       SDValue Res = DAG.getNode(N->getOpcode(), DL, MVT::i64, LHS, RHS);
11730       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
11731       return;
11732     }
11733 
11734     // Without Zbb, expand to UADDO/USUBO+select which will trigger our custom
11735     // promotion for UADDO/USUBO.
11736     Results.push_back(expandAddSubSat(N, DAG));
11737     return;
11738   }
11739   case ISD::ABS: {
11740     assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
11741            "Unexpected custom legalisation");
11742 
11743     if (Subtarget.hasStdExtZbb()) {
11744       // Emit a special ABSW node that will be expanded to NEGW+MAX at isel.
11745       // This allows us to remember that the result is sign extended. Expanding
11746       // to NEGW+MAX here requires a Freeze which breaks ComputeNumSignBits.
11747       SDValue Src = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64,
11748                                 N->getOperand(0));
11749       SDValue Abs = DAG.getNode(RISCVISD::ABSW, DL, MVT::i64, Src);
11750       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Abs));
11751       return;
11752     }
11753 
11754     // Expand abs to Y = (sraiw X, 31); subw(xor(X, Y), Y)
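    // Why this works: Y is 0 when X >= 0, so (X ^ 0) - 0 == X; Y is all ones
    // when X < 0, so (X ^ -1) - (-1) == ~X + 1 == -X.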
11755     SDValue Src = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(0));
11756 
11757     // Freeze the source so we can increase its use count.
11758     Src = DAG.getFreeze(Src);
11759 
11760     // Copy sign bit to all bits using the sraiw pattern.
11761     SDValue SignFill = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i64, Src,
11762                                    DAG.getValueType(MVT::i32));
11763     SignFill = DAG.getNode(ISD::SRA, DL, MVT::i64, SignFill,
11764                            DAG.getConstant(31, DL, MVT::i64));
11765 
11766     SDValue NewRes = DAG.getNode(ISD::XOR, DL, MVT::i64, Src, SignFill);
11767     NewRes = DAG.getNode(ISD::SUB, DL, MVT::i64, NewRes, SignFill);
11768 
11769     // NOTE: The result is only required to be anyextended, but sext is
11770     // consistent with type legalization of sub.
11771     NewRes = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i64, NewRes,
11772                          DAG.getValueType(MVT::i32));
11773     Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, NewRes));
11774     return;
11775   }
11776   case ISD::BITCAST: {
11777     EVT VT = N->getValueType(0);
11778     assert(VT.isInteger() && !VT.isVector() && "Unexpected VT!");
11779     SDValue Op0 = N->getOperand(0);
11780     EVT Op0VT = Op0.getValueType();
11781     MVT XLenVT = Subtarget.getXLenVT();
11782     if (VT == MVT::i16 && Op0VT == MVT::f16 &&
11783         Subtarget.hasStdExtZfhminOrZhinxmin()) {
11784       SDValue FPConv = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, XLenVT, Op0);
11785       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FPConv));
11786     } else if (VT == MVT::i16 && Op0VT == MVT::bf16 &&
11787                Subtarget.hasStdExtZfbfmin()) {
11788       SDValue FPConv = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, XLenVT, Op0);
11789       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FPConv));
11790     } else if (VT == MVT::i32 && Op0VT == MVT::f32 && Subtarget.is64Bit() &&
11791                Subtarget.hasStdExtFOrZfinx()) {
11792       SDValue FPConv =
11793           DAG.getNode(RISCVISD::FMV_X_ANYEXTW_RV64, DL, MVT::i64, Op0);
11794       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, FPConv));
11795     } else if (VT == MVT::i64 && Op0VT == MVT::f64 && XLenVT == MVT::i32 &&
11796                Subtarget.hasStdExtZfa()) {
11797       SDValue NewReg = DAG.getNode(RISCVISD::SplitF64, DL,
11798                                    DAG.getVTList(MVT::i32, MVT::i32), Op0);
11799       SDValue RetReg = DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i64,
11800                                    NewReg.getValue(0), NewReg.getValue(1));
11801       Results.push_back(RetReg);
11802     } else if (!VT.isVector() && Op0VT.isFixedLengthVector() &&
11803                isTypeLegal(Op0VT)) {
11804       // Custom-legalize bitcasts from fixed-length vector types to illegal
11805       // scalar types in order to improve codegen. Bitcast the vector to a
11806       // one-element vector type whose element type is the same as the result
11807       // type, and extract the first element.
11808       EVT BVT = EVT::getVectorVT(*DAG.getContext(), VT, 1);
11809       if (isTypeLegal(BVT)) {
11810         SDValue BVec = DAG.getBitcast(BVT, Op0);
11811         Results.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, BVec,
11812                                       DAG.getConstant(0, DL, XLenVT)));
11813       }
11814     }
11815     break;
11816   }
11817   case RISCVISD::BREV8: {
11818     MVT VT = N->getSimpleValueType(0);
11819     MVT XLenVT = Subtarget.getXLenVT();
11820     assert((VT == MVT::i16 || (VT == MVT::i32 && Subtarget.is64Bit())) &&
11821            "Unexpected custom legalisation");
11822     assert(Subtarget.hasStdExtZbkb() && "Unexpected extension");
11823     SDValue NewOp = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, N->getOperand(0));
11824     SDValue NewRes = DAG.getNode(N->getOpcode(), DL, XLenVT, NewOp);
11825     // ReplaceNodeResults requires we maintain the same type for the return
11826     // value.
11827     Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, NewRes));
11828     break;
11829   }
11830   case ISD::EXTRACT_VECTOR_ELT: {
11831     // Custom-legalize an EXTRACT_VECTOR_ELT where XLEN<SEW, as the SEW element
11832     // type is illegal (currently only vXi64 RV32).
11833     // With vmv.x.s, when SEW > XLEN, only the least-significant XLEN bits are
11834     // transferred to the destination register. We issue two of these from the
11835     // upper and lower halves of the SEW-bit vector element, slid down to the
11836     // first element.
11837     SDValue Vec = N->getOperand(0);
11838     SDValue Idx = N->getOperand(1);
11839 
11840     // The vector type hasn't been legalized yet so we can't issue target
11841     // specific nodes if it needs legalization.
11842     // FIXME: We would manually legalize if it's important.
11843     if (!isTypeLegal(Vec.getValueType()))
11844       return;
11845 
11846     MVT VecVT = Vec.getSimpleValueType();
11847 
11848     assert(!Subtarget.is64Bit() && N->getValueType(0) == MVT::i64 &&
11849            VecVT.getVectorElementType() == MVT::i64 &&
11850            "Unexpected EXTRACT_VECTOR_ELT legalization");
11851 
11852     // If this is a fixed vector, we need to convert it to a scalable vector.
11853     MVT ContainerVT = VecVT;
11854     if (VecVT.isFixedLengthVector()) {
11855       ContainerVT = getContainerForFixedLengthVector(VecVT);
11856       Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
11857     }
11858 
11859     MVT XLenVT = Subtarget.getXLenVT();
11860 
11861     // Use a VL of 1 to avoid processing more elements than we need.
11862     auto [Mask, VL] = getDefaultVLOps(1, ContainerVT, DL, DAG, Subtarget);
11863 
11864     // Unless the index is known to be 0, we must slide the vector down to get
11865     // the desired element into index 0.
11866     if (!isNullConstant(Idx)) {
11867       Vec = getVSlidedown(DAG, Subtarget, DL, ContainerVT,
11868                           DAG.getUNDEF(ContainerVT), Vec, Idx, Mask, VL);
11869     }
11870 
11871     // Extract the lower XLEN bits of the correct vector element.
11872     SDValue EltLo = DAG.getNode(RISCVISD::VMV_X_S, DL, XLenVT, Vec);
11873 
11874     // To extract the upper XLEN bits of the vector element, shift the first
11875     // element right by 32 bits and re-extract the lower XLEN bits.
11876     SDValue ThirtyTwoV = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT,
11877                                      DAG.getUNDEF(ContainerVT),
11878                                      DAG.getConstant(32, DL, XLenVT), VL);
11879     SDValue LShr32 =
11880         DAG.getNode(RISCVISD::SRL_VL, DL, ContainerVT, Vec, ThirtyTwoV,
11881                     DAG.getUNDEF(ContainerVT), Mask, VL);
11882 
11883     SDValue EltHi = DAG.getNode(RISCVISD::VMV_X_S, DL, XLenVT, LShr32);
11884 
11885     Results.push_back(DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i64, EltLo, EltHi));
11886     break;
11887   }
11888   case ISD::INTRINSIC_WO_CHAIN: {
11889     unsigned IntNo = N->getConstantOperandVal(0);
11890     switch (IntNo) {
11891     default:
11892       llvm_unreachable(
11893           "Don't know how to custom type legalize this intrinsic!");
11894     case Intrinsic::experimental_get_vector_length: {
11895       SDValue Res = lowerGetVectorLength(N, DAG, Subtarget);
11896       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
11897       return;
11898     }
11899     case Intrinsic::riscv_orc_b:
11900     case Intrinsic::riscv_brev8:
11901     case Intrinsic::riscv_sha256sig0:
11902     case Intrinsic::riscv_sha256sig1:
11903     case Intrinsic::riscv_sha256sum0:
11904     case Intrinsic::riscv_sha256sum1:
11905     case Intrinsic::riscv_sm3p0:
11906     case Intrinsic::riscv_sm3p1: {
11907       if (!Subtarget.is64Bit() || N->getValueType(0) != MVT::i32)
11908         return;
11909       unsigned Opc;
11910       switch (IntNo) {
11911       case Intrinsic::riscv_orc_b:      Opc = RISCVISD::ORC_B;      break;
11912       case Intrinsic::riscv_brev8:      Opc = RISCVISD::BREV8;      break;
11913       case Intrinsic::riscv_sha256sig0: Opc = RISCVISD::SHA256SIG0; break;
11914       case Intrinsic::riscv_sha256sig1: Opc = RISCVISD::SHA256SIG1; break;
11915       case Intrinsic::riscv_sha256sum0: Opc = RISCVISD::SHA256SUM0; break;
11916       case Intrinsic::riscv_sha256sum1: Opc = RISCVISD::SHA256SUM1; break;
11917       case Intrinsic::riscv_sm3p0:      Opc = RISCVISD::SM3P0;      break;
11918       case Intrinsic::riscv_sm3p1:      Opc = RISCVISD::SM3P1;      break;
11919       }
11920 
11921       SDValue NewOp =
11922           DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(1));
11923       SDValue Res = DAG.getNode(Opc, DL, MVT::i64, NewOp);
11924       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
11925       return;
11926     }
11927     case Intrinsic::riscv_sm4ks:
11928     case Intrinsic::riscv_sm4ed: {
11929       unsigned Opc =
11930           IntNo == Intrinsic::riscv_sm4ks ? RISCVISD::SM4KS : RISCVISD::SM4ED;
11931       SDValue NewOp0 =
11932           DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(1));
11933       SDValue NewOp1 =
11934           DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(2));
11935       SDValue Res =
11936           DAG.getNode(Opc, DL, MVT::i64, NewOp0, NewOp1, N->getOperand(3));
11937       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
11938       return;
11939     }
11940     case Intrinsic::riscv_clmul: {
11941       if (!Subtarget.is64Bit() || N->getValueType(0) != MVT::i32)
11942         return;
11943 
11944       SDValue NewOp0 =
11945           DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(1));
11946       SDValue NewOp1 =
11947           DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(2));
11948       SDValue Res = DAG.getNode(RISCVISD::CLMUL, DL, MVT::i64, NewOp0, NewOp1);
11949       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
11950       return;
11951     }
11952     case Intrinsic::riscv_clmulh:
11953     case Intrinsic::riscv_clmulr: {
11954       if (!Subtarget.is64Bit() || N->getValueType(0) != MVT::i32)
11955         return;
11956 
11957       // Extend inputs to XLen, and shift by 32. This will add 64 trailing zeros
11958       // to the full 128-bit clmul result of multiplying two xlen values.
11959       // Perform clmulr or clmulh on the shifted values. Finally, extract the
11960       // upper 32 bits.
11961       //
11962       // The alternative is to mask the inputs to 32 bits and use clmul, but
11963       // that requires two shifts to mask each input without zext.w.
11964       // FIXME: If the inputs are known zero extended or could be freely
11965       // zero extended, the mask form would be better.
11966       SDValue NewOp0 =
11967           DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(1));
11968       SDValue NewOp1 =
11969           DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(2));
11970       NewOp0 = DAG.getNode(ISD::SHL, DL, MVT::i64, NewOp0,
11971                            DAG.getConstant(32, DL, MVT::i64));
11972       NewOp1 = DAG.getNode(ISD::SHL, DL, MVT::i64, NewOp1,
11973                            DAG.getConstant(32, DL, MVT::i64));
11974       unsigned Opc = IntNo == Intrinsic::riscv_clmulh ? RISCVISD::CLMULH
11975                                                       : RISCVISD::CLMULR;
11976       SDValue Res = DAG.getNode(Opc, DL, MVT::i64, NewOp0, NewOp1);
11977       Res = DAG.getNode(ISD::SRL, DL, MVT::i64, Res,
11978                         DAG.getConstant(32, DL, MVT::i64));
11979       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
11980       return;
11981     }
11982     case Intrinsic::riscv_vmv_x_s: {
11983       EVT VT = N->getValueType(0);
11984       MVT XLenVT = Subtarget.getXLenVT();
11985       if (VT.bitsLT(XLenVT)) {
11986         // Simple case: just extract using vmv.x.s and truncate.
11987         SDValue Extract = DAG.getNode(RISCVISD::VMV_X_S, DL,
11988                                       Subtarget.getXLenVT(), N->getOperand(1));
11989         Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, Extract));
11990         return;
11991       }
11992 
11993       assert(VT == MVT::i64 && !Subtarget.is64Bit() &&
11994              "Unexpected custom legalization");
11995 
11996       // We need to do the move in two steps.
11997       SDValue Vec = N->getOperand(1);
11998       MVT VecVT = Vec.getSimpleValueType();
11999 
12000       // First extract the lower XLEN bits of the element.
12001       SDValue EltLo = DAG.getNode(RISCVISD::VMV_X_S, DL, XLenVT, Vec);
12002 
12003       // To extract the upper XLEN bits of the vector element, shift the first
12004       // element right by 32 bits and re-extract the lower XLEN bits.
12005       auto [Mask, VL] = getDefaultVLOps(1, VecVT, DL, DAG, Subtarget);
12006 
12007       SDValue ThirtyTwoV =
12008           DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VecVT, DAG.getUNDEF(VecVT),
12009                       DAG.getConstant(32, DL, XLenVT), VL);
12010       SDValue LShr32 = DAG.getNode(RISCVISD::SRL_VL, DL, VecVT, Vec, ThirtyTwoV,
12011                                    DAG.getUNDEF(VecVT), Mask, VL);
12012       SDValue EltHi = DAG.getNode(RISCVISD::VMV_X_S, DL, XLenVT, LShr32);
12013 
12014       Results.push_back(
12015           DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i64, EltLo, EltHi));
12016       break;
12017     }
12018     }
12019     break;
12020   }
12021   case ISD::VECREDUCE_ADD:
12022   case ISD::VECREDUCE_AND:
12023   case ISD::VECREDUCE_OR:
12024   case ISD::VECREDUCE_XOR:
12025   case ISD::VECREDUCE_SMAX:
12026   case ISD::VECREDUCE_UMAX:
12027   case ISD::VECREDUCE_SMIN:
12028   case ISD::VECREDUCE_UMIN:
12029     if (SDValue V = lowerVECREDUCE(SDValue(N, 0), DAG))
12030       Results.push_back(V);
12031     break;
12032   case ISD::VP_REDUCE_ADD:
12033   case ISD::VP_REDUCE_AND:
12034   case ISD::VP_REDUCE_OR:
12035   case ISD::VP_REDUCE_XOR:
12036   case ISD::VP_REDUCE_SMAX:
12037   case ISD::VP_REDUCE_UMAX:
12038   case ISD::VP_REDUCE_SMIN:
12039   case ISD::VP_REDUCE_UMIN:
12040     if (SDValue V = lowerVPREDUCE(SDValue(N, 0), DAG))
12041       Results.push_back(V);
12042     break;
12043   case ISD::GET_ROUNDING: {
12044     SDVTList VTs = DAG.getVTList(Subtarget.getXLenVT(), MVT::Other);
12045     SDValue Res = DAG.getNode(ISD::GET_ROUNDING, DL, VTs, N->getOperand(0));
12046     Results.push_back(Res.getValue(0));
12047     Results.push_back(Res.getValue(1));
12048     break;
12049   }
12050   }
12051 }
12052 
12053 /// Given a binary operator, return the *associative* generic ISD::VECREDUCE_OP
12054 /// which corresponds to it.
12055 static unsigned getVecReduceOpcode(unsigned Opc) {
12056   switch (Opc) {
12057   default:
12058     llvm_unreachable("Unhandled binary to transform reduction");
12059   case ISD::ADD:
12060     return ISD::VECREDUCE_ADD;
12061   case ISD::UMAX:
12062     return ISD::VECREDUCE_UMAX;
12063   case ISD::SMAX:
12064     return ISD::VECREDUCE_SMAX;
12065   case ISD::UMIN:
12066     return ISD::VECREDUCE_UMIN;
12067   case ISD::SMIN:
12068     return ISD::VECREDUCE_SMIN;
12069   case ISD::AND:
12070     return ISD::VECREDUCE_AND;
12071   case ISD::OR:
12072     return ISD::VECREDUCE_OR;
12073   case ISD::XOR:
12074     return ISD::VECREDUCE_XOR;
12075   case ISD::FADD:
12076     // Note: This is the associative form of the generic reduction opcode.
12077     return ISD::VECREDUCE_FADD;
12078   }
12079 }
12080 
12081 /// Perform two related transforms whose purpose is to incrementally recognize
12082 /// an explode_vector followed by scalar reduction as a vector reduction node.
12083 /// This exists to recover from a deficiency in SLP which can't handle
12084 /// forests with multiple roots sharing common nodes.  In some cases, one
12085 /// of the trees will be vectorized, and the other will remain (unprofitably)
12086 /// scalarized.
12087 static SDValue
12088 combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
12089                                   const RISCVSubtarget &Subtarget) {
12090 
12091   // This transform needs to run before all integer types have been legalized
12092   // to i64 (so that the vector element type matches the add type), and while
12093   // it's still safe to introduce odd-sized vector types.
12094   if (DAG.NewNodesMustHaveLegalTypes)
12095     return SDValue();
12096 
12097   // Without V, this transform isn't useful.  We could form the (illegal)
12098   // operations and let them be scalarized again, but there's really no point.
12099   if (!Subtarget.hasVInstructions())
12100     return SDValue();
12101 
12102   const SDLoc DL(N);
12103   const EVT VT = N->getValueType(0);
12104   const unsigned Opc = N->getOpcode();
12105 
12106   // For FADD, we only handle the case with reassociation allowed.  We
12107   // could handle strict reduction order, but at the moment, there's no
12108   // known reason to, and the complexity isn't worth it.
12109   // TODO: Handle fminnum and fmaxnum here
12110   if (!VT.isInteger() &&
12111       (Opc != ISD::FADD || !N->getFlags().hasAllowReassociation()))
12112     return SDValue();
12113 
12114   const unsigned ReduceOpc = getVecReduceOpcode(Opc);
12115   assert(Opc == ISD::getVecReduceBaseOpcode(ReduceOpc) &&
12116          "Inconsistent mappings");
12117   SDValue LHS = N->getOperand(0);
12118   SDValue RHS = N->getOperand(1);
12119 
12120   if (!LHS.hasOneUse() || !RHS.hasOneUse())
12121     return SDValue();
12122 
12123   if (RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
12124     std::swap(LHS, RHS);
12125 
12126   if (RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
12127       !isa<ConstantSDNode>(RHS.getOperand(1)))
12128     return SDValue();
12129 
12130   uint64_t RHSIdx = cast<ConstantSDNode>(RHS.getOperand(1))->getLimitedValue();
12131   SDValue SrcVec = RHS.getOperand(0);
12132   EVT SrcVecVT = SrcVec.getValueType();
12133   assert(SrcVecVT.getVectorElementType() == VT);
12134   if (SrcVecVT.isScalableVector())
12135     return SDValue();
12136 
12137   if (SrcVecVT.getScalarSizeInBits() > Subtarget.getELen())
12138     return SDValue();
12139 
12140   // match binop (extract_vector_elt V, 0), (extract_vector_elt V, 1) to
12141   // reduce_op (extract_subvector [2 x VT] from V).  This will form the
12142   // root of our reduction tree. TODO: We could extend this to any two
12143   // adjacent aligned constant indices if desired.
12144   if (LHS.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
12145       LHS.getOperand(0) == SrcVec && isa<ConstantSDNode>(LHS.getOperand(1))) {
12146     uint64_t LHSIdx =
12147       cast<ConstantSDNode>(LHS.getOperand(1))->getLimitedValue();
12148     if (0 == std::min(LHSIdx, RHSIdx) && 1 == std::max(LHSIdx, RHSIdx)) {
12149       EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
12150       SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
12151                                 DAG.getVectorIdxConstant(0, DL));
12152       return DAG.getNode(ReduceOpc, DL, VT, Vec, N->getFlags());
12153     }
12154   }
12155 
12156   // Match (binop (reduce (extract_subvector V, 0)),
12157   //              (extract_vector_elt V, sizeof(SubVec)))
12158   // into a reduction of one more element from the original vector V.
12159   if (LHS.getOpcode() != ReduceOpc)
12160     return SDValue();
12161 
12162   SDValue ReduceVec = LHS.getOperand(0);
12163   if (ReduceVec.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
12164       ReduceVec.hasOneUse() && ReduceVec.getOperand(0) == RHS.getOperand(0) &&
12165       isNullConstant(ReduceVec.getOperand(1)) &&
12166       ReduceVec.getValueType().getVectorNumElements() == RHSIdx) {
12167     // For illegal types (e.g. 3xi32), most will be combined again into a
12168     // wider (hopefully legal) type.  If this is a terminal state, we are
12169     // relying on type legalization here to produce something reasonable
12170     // and this lowering quality could probably be improved. (TODO)
12171     EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, RHSIdx + 1);
12172     SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
12173                               DAG.getVectorIdxConstant(0, DL));
12174     auto Flags = ReduceVec->getFlags();
12175     Flags.intersectWith(N->getFlags());
12176     return DAG.getNode(ReduceOpc, DL, VT, Vec, Flags);
12177   }
12178 
12179   return SDValue();
12180 }
12181 
12182 
12183 // Try to fold (<bop> x, (reduction.<bop> vec, start))
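//
// For example (a sketch, eliding the exact RVV reduction operand order): if
// the start value of the reduction is the neutral element, e.g.
//   (add x, (extractelt (vecreduce_add_vl ... vec ... (splat 0) ... vl), 0))
// then x can be folded in as the new start value,
//   (extractelt (vecreduce_add_vl ... vec ... (splat x) ... vl), 0)
// removing the outer scalar add.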
12184 static SDValue combineBinOpToReduce(SDNode *N, SelectionDAG &DAG,
12185                                     const RISCVSubtarget &Subtarget) {
12186   auto BinOpToRVVReduce = [](unsigned Opc) {
12187     switch (Opc) {
12188     default:
12189       llvm_unreachable("Unhandled binary op to transform to a reduction");
12190     case ISD::ADD:
12191       return RISCVISD::VECREDUCE_ADD_VL;
12192     case ISD::UMAX:
12193       return RISCVISD::VECREDUCE_UMAX_VL;
12194     case ISD::SMAX:
12195       return RISCVISD::VECREDUCE_SMAX_VL;
12196     case ISD::UMIN:
12197       return RISCVISD::VECREDUCE_UMIN_VL;
12198     case ISD::SMIN:
12199       return RISCVISD::VECREDUCE_SMIN_VL;
12200     case ISD::AND:
12201       return RISCVISD::VECREDUCE_AND_VL;
12202     case ISD::OR:
12203       return RISCVISD::VECREDUCE_OR_VL;
12204     case ISD::XOR:
12205       return RISCVISD::VECREDUCE_XOR_VL;
12206     case ISD::FADD:
12207       return RISCVISD::VECREDUCE_FADD_VL;
12208     case ISD::FMAXNUM:
12209       return RISCVISD::VECREDUCE_FMAX_VL;
12210     case ISD::FMINNUM:
12211       return RISCVISD::VECREDUCE_FMIN_VL;
12212     }
12213   };
12214 
12215   auto IsReduction = [&BinOpToRVVReduce](SDValue V, unsigned Opc) {
12216     return V.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
12217            isNullConstant(V.getOperand(1)) &&
12218            V.getOperand(0).getOpcode() == BinOpToRVVReduce(Opc);
12219   };
12220 
12221   unsigned Opc = N->getOpcode();
12222   unsigned ReduceIdx;
12223   if (IsReduction(N->getOperand(0), Opc))
12224     ReduceIdx = 0;
12225   else if (IsReduction(N->getOperand(1), Opc))
12226     ReduceIdx = 1;
12227   else
12228     return SDValue();
12229 
12230   // Skip FADD if reassociation is not allowed; this combine requires it.
12231   if (Opc == ISD::FADD && !N->getFlags().hasAllowReassociation())
12232     return SDValue();
12233 
12234   SDValue Extract = N->getOperand(ReduceIdx);
12235   SDValue Reduce = Extract.getOperand(0);
12236   if (!Extract.hasOneUse() || !Reduce.hasOneUse())
12237     return SDValue();
12238 
12239   SDValue ScalarV = Reduce.getOperand(2);
12240   EVT ScalarVT = ScalarV.getValueType();
12241   if (ScalarV.getOpcode() == ISD::INSERT_SUBVECTOR &&
12242       ScalarV.getOperand(0)->isUndef() &&
12243       isNullConstant(ScalarV.getOperand(2)))
12244     ScalarV = ScalarV.getOperand(1);
12245 
12246   // Make sure that ScalarV is a splat with VL=1.
12247   if (ScalarV.getOpcode() != RISCVISD::VFMV_S_F_VL &&
12248       ScalarV.getOpcode() != RISCVISD::VMV_S_X_VL &&
12249       ScalarV.getOpcode() != RISCVISD::VMV_V_X_VL)
12250     return SDValue();
12251 
12252   if (!isNonZeroAVL(ScalarV.getOperand(2)))
12253     return SDValue();
12254 
12255   // Check that the scalar operand of ScalarV is the neutral element.
12256   // TODO: Deal with values other than the neutral element.
12257   if (!isNeutralConstant(N->getOpcode(), N->getFlags(), ScalarV.getOperand(1),
12258                          0))
12259     return SDValue();
12260 
12261   // If the AVL is zero, operand 0 will be returned. So it's not safe to fold.
12262   // FIXME: We might be able to improve this if operand 0 is undef.
12263   if (!isNonZeroAVL(Reduce.getOperand(5)))
12264     return SDValue();
12265 
12266   SDValue NewStart = N->getOperand(1 - ReduceIdx);
12267 
12268   SDLoc DL(N);
12269   SDValue NewScalarV =
12270       lowerScalarInsert(NewStart, ScalarV.getOperand(2),
12271                         ScalarV.getSimpleValueType(), DL, DAG, Subtarget);
12272 
12273   // If we looked through an INSERT_SUBVECTOR we need to restore it.
12274   if (ScalarVT != ScalarV.getValueType())
12275     NewScalarV =
12276         DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalarVT, DAG.getUNDEF(ScalarVT),
12277                     NewScalarV, DAG.getConstant(0, DL, Subtarget.getXLenVT()));
12278 
12279   SDValue Ops[] = {Reduce.getOperand(0), Reduce.getOperand(1),
12280                    NewScalarV,           Reduce.getOperand(3),
12281                    Reduce.getOperand(4), Reduce.getOperand(5)};
12282   SDValue NewReduce =
12283       DAG.getNode(Reduce.getOpcode(), DL, Reduce.getValueType(), Ops);
12284   return DAG.getNode(Extract.getOpcode(), DL, Extract.getValueType(), NewReduce,
12285                      Extract.getOperand(1));
12286 }
12287 
12288 // Optimize (add (shl x, c0), (shl y, c1)) ->
12289 //          (SLLI (SH*ADD x, y), c0), if c1-c0 equals to [1|2|3].
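//
// Worked example (illustrative values): with c0 = 5 and c1 = 8 the difference
// is 3, so
//   (add (shl x, 5), (shl y, 8))
//     -> (shl (add (shl y, 3), x), 5)      ; i.e. (SLLI (SH3ADD y, x), 5)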
12290 static SDValue transformAddShlImm(SDNode *N, SelectionDAG &DAG,
12291                                   const RISCVSubtarget &Subtarget) {
12292   // Perform this optimization only in the zba extension.
12293   if (!Subtarget.hasStdExtZba())
12294     return SDValue();
12295 
12296   // Skip for vector types and larger types.
12297   EVT VT = N->getValueType(0);
12298   if (VT.isVector() || VT.getSizeInBits() > Subtarget.getXLen())
12299     return SDValue();
12300 
12301   // The two operand nodes must be SHL and have no other use.
12302   SDValue N0 = N->getOperand(0);
12303   SDValue N1 = N->getOperand(1);
12304   if (N0->getOpcode() != ISD::SHL || N1->getOpcode() != ISD::SHL ||
12305       !N0->hasOneUse() || !N1->hasOneUse())
12306     return SDValue();
12307 
12308   // Check c0 and c1.
12309   auto *N0C = dyn_cast<ConstantSDNode>(N0->getOperand(1));
12310   auto *N1C = dyn_cast<ConstantSDNode>(N1->getOperand(1));
12311   if (!N0C || !N1C)
12312     return SDValue();
12313   int64_t C0 = N0C->getSExtValue();
12314   int64_t C1 = N1C->getSExtValue();
12315   if (C0 <= 0 || C1 <= 0)
12316     return SDValue();
12317 
12318   // Skip if SH1ADD/SH2ADD/SH3ADD are not applicable.
12319   int64_t Bits = std::min(C0, C1);
12320   int64_t Diff = std::abs(C0 - C1);
12321   if (Diff != 1 && Diff != 2 && Diff != 3)
12322     return SDValue();
12323 
12324   // Build nodes.
12325   SDLoc DL(N);
12326   SDValue NS = (C0 < C1) ? N0->getOperand(0) : N1->getOperand(0);
12327   SDValue NL = (C0 > C1) ? N0->getOperand(0) : N1->getOperand(0);
12328   SDValue NA0 =
12329       DAG.getNode(ISD::SHL, DL, VT, NL, DAG.getConstant(Diff, DL, VT));
12330   SDValue NA1 = DAG.getNode(ISD::ADD, DL, VT, NA0, NS);
12331   return DAG.getNode(ISD::SHL, DL, VT, NA1, DAG.getConstant(Bits, DL, VT));
12332 }
12333 
12334 // Combine a constant select operand into its use:
12335 //
12336 // (and (select cond, -1, c), x)
12337 //   -> (select cond, x, (and x, c))  [AllOnes=1]
12338 // (or  (select cond, 0, c), x)
12339 //   -> (select cond, x, (or x, c))  [AllOnes=0]
12340 // (xor (select cond, 0, c), x)
12341 //   -> (select cond, x, (xor x, c))  [AllOnes=0]
12342 // (add (select cond, 0, c), x)
12343 //   -> (select cond, x, (add x, c))  [AllOnes=0]
12344 // (sub x, (select cond, 0, c))
12345 //   -> (select cond, x, (sub x, c))  [AllOnes=0]
12346 static SDValue combineSelectAndUse(SDNode *N, SDValue Slct, SDValue OtherOp,
12347                                    SelectionDAG &DAG, bool AllOnes,
12348                                    const RISCVSubtarget &Subtarget) {
12349   EVT VT = N->getValueType(0);
12350 
12351   // Skip vectors.
12352   if (VT.isVector())
12353     return SDValue();
12354 
12355   if (!Subtarget.hasConditionalMoveFusion()) {
12356     // (select cond, x, (and x, c)) has custom lowering with Zicond.
12357     if ((!Subtarget.hasStdExtZicond() &&
12358          !Subtarget.hasVendorXVentanaCondOps()) ||
12359         N->getOpcode() != ISD::AND)
12360       return SDValue();
12361 
12362     // Maybe harmful when the condition code has multiple uses.
12363     if (Slct.getOpcode() == ISD::SELECT && !Slct.getOperand(0).hasOneUse())
12364       return SDValue();
12365 
12366     // Maybe harmful when VT is wider than XLen.
12367     if (VT.getSizeInBits() > Subtarget.getXLen())
12368       return SDValue();
12369   }
12370 
12371   if ((Slct.getOpcode() != ISD::SELECT &&
12372        Slct.getOpcode() != RISCVISD::SELECT_CC) ||
12373       !Slct.hasOneUse())
12374     return SDValue();
12375 
12376   auto isZeroOrAllOnes = [](SDValue N, bool AllOnes) {
12377     return AllOnes ? isAllOnesConstant(N) : isNullConstant(N);
12378   };
12379 
12380   bool SwapSelectOps;
12381   unsigned OpOffset = Slct.getOpcode() == RISCVISD::SELECT_CC ? 2 : 0;
12382   SDValue TrueVal = Slct.getOperand(1 + OpOffset);
12383   SDValue FalseVal = Slct.getOperand(2 + OpOffset);
12384   SDValue NonConstantVal;
12385   if (isZeroOrAllOnes(TrueVal, AllOnes)) {
12386     SwapSelectOps = false;
12387     NonConstantVal = FalseVal;
12388   } else if (isZeroOrAllOnes(FalseVal, AllOnes)) {
12389     SwapSelectOps = true;
12390     NonConstantVal = TrueVal;
12391   } else
12392     return SDValue();
12393 
12394   // Slct is now known to be the desired identity constant when CC is true.
12395   TrueVal = OtherOp;
12396   FalseVal = DAG.getNode(N->getOpcode(), SDLoc(N), VT, OtherOp, NonConstantVal);
12397   // Unless SwapSelectOps says the condition should be false.
12398   if (SwapSelectOps)
12399     std::swap(TrueVal, FalseVal);
12400 
12401   if (Slct.getOpcode() == RISCVISD::SELECT_CC)
12402     return DAG.getNode(RISCVISD::SELECT_CC, SDLoc(N), VT,
12403                        {Slct.getOperand(0), Slct.getOperand(1),
12404                         Slct.getOperand(2), TrueVal, FalseVal});
12405 
12406   return DAG.getNode(ISD::SELECT, SDLoc(N), VT,
12407                      {Slct.getOperand(0), TrueVal, FalseVal});
12408 }
12409 
12410 // Attempt combineSelectAndUse on each operand of a commutative operator N.
12411 static SDValue combineSelectAndUseCommutative(SDNode *N, SelectionDAG &DAG,
12412                                               bool AllOnes,
12413                                               const RISCVSubtarget &Subtarget) {
12414   SDValue N0 = N->getOperand(0);
12415   SDValue N1 = N->getOperand(1);
12416   if (SDValue Result = combineSelectAndUse(N, N0, N1, DAG, AllOnes, Subtarget))
12417     return Result;
12418   if (SDValue Result = combineSelectAndUse(N, N1, N0, DAG, AllOnes, Subtarget))
12419     return Result;
12420   return SDValue();
12421 }
12422 
12423 // Transform (add (mul x, c0), c1) ->
12424 //           (add (mul (add x, c1/c0), c0), c1%c0).
12425 // if c1/c0 and c1%c0 are simm12, while c1 is not. A special corner case
12426 // that should be excluded is when c0*(c1/c0) is simm12, which will lead
12427 // to an infinite loop in DAGCombine if transformed.
12428 // Or transform (add (mul x, c0), c1) ->
12429 //              (add (mul (add x, c1/c0+1), c0), c1%c0-c0),
12430 // if c1/c0+1 and c1%c0-c0 are simm12, while c1 is not. A special corner
12431 // case that should be excluded is when c0*(c1/c0+1) is simm12, which will
12432 // lead to an infinite loop in DAGCombine if transformed.
12433 // Or transform (add (mul x, c0), c1) ->
12434 //              (add (mul (add x, c1/c0-1), c0), c1%c0+c0),
12435 // if c1/c0-1 and c1%c0+c0 are simm12, while c1 is not. A special corner
12436 // case that should be excluded is when c0*(c1/c0-1) is simm12, which will
12437 // lead to an infinite loop in DAGCombine if transformed.
12438 // Or transform (add (mul x, c0), c1) ->
12439 //              (mul (add x, c1/c0), c0).
12440 // if c1%c0 is zero, and c1/c0 is simm12 while c1 is not.
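//
// Worked example (illustrative values): c0 = 100, c1 = 4100. c1 is not simm12,
// but c1/c0 = 41 and c1%c0 = 0 both are, and c0*(c1/c0) = 4100 is not, so
//   (add (mul x, 100), 4100) -> (add (mul (add x, 41), 100), 0)
// where both new immediates fit in 12 bits (the trailing add of 0 folds away).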
12441 static SDValue transformAddImmMulImm(SDNode *N, SelectionDAG &DAG,
12442                                      const RISCVSubtarget &Subtarget) {
12443   // Skip for vector types and larger types.
12444   EVT VT = N->getValueType(0);
12445   if (VT.isVector() || VT.getSizeInBits() > Subtarget.getXLen())
12446     return SDValue();
12447   // The first operand node must be a MUL and have no other use.
12448   SDValue N0 = N->getOperand(0);
12449   if (!N0->hasOneUse() || N0->getOpcode() != ISD::MUL)
12450     return SDValue();
12451   // Check if c0 and c1 match above conditions.
12452   auto *N0C = dyn_cast<ConstantSDNode>(N0->getOperand(1));
12453   auto *N1C = dyn_cast<ConstantSDNode>(N->getOperand(1));
12454   if (!N0C || !N1C)
12455     return SDValue();
12456   // If N0C has multiple uses it's possible one of the cases in
12457   // DAGCombiner::isMulAddWithConstProfitable will be true, which would result
12458   // in an infinite loop.
12459   if (!N0C->hasOneUse())
12460     return SDValue();
12461   int64_t C0 = N0C->getSExtValue();
12462   int64_t C1 = N1C->getSExtValue();
12463   int64_t CA, CB;
12464   if (C0 == -1 || C0 == 0 || C0 == 1 || isInt<12>(C1))
12465     return SDValue();
12466   // Search for proper CA (non-zero) and CB that both are simm12.
12467   if ((C1 / C0) != 0 && isInt<12>(C1 / C0) && isInt<12>(C1 % C0) &&
12468       !isInt<12>(C0 * (C1 / C0))) {
12469     CA = C1 / C0;
12470     CB = C1 % C0;
12471   } else if ((C1 / C0 + 1) != 0 && isInt<12>(C1 / C0 + 1) &&
12472              isInt<12>(C1 % C0 - C0) && !isInt<12>(C0 * (C1 / C0 + 1))) {
12473     CA = C1 / C0 + 1;
12474     CB = C1 % C0 - C0;
12475   } else if ((C1 / C0 - 1) != 0 && isInt<12>(C1 / C0 - 1) &&
12476              isInt<12>(C1 % C0 + C0) && !isInt<12>(C0 * (C1 / C0 - 1))) {
12477     CA = C1 / C0 - 1;
12478     CB = C1 % C0 + C0;
12479   } else
12480     return SDValue();
12481   // Build new nodes (add (mul (add x, c1/c0), c0), c1%c0).
12482   SDLoc DL(N);
12483   SDValue New0 = DAG.getNode(ISD::ADD, DL, VT, N0->getOperand(0),
12484                              DAG.getConstant(CA, DL, VT));
12485   SDValue New1 =
12486       DAG.getNode(ISD::MUL, DL, VT, New0, DAG.getConstant(C0, DL, VT));
12487   return DAG.getNode(ISD::ADD, DL, VT, New1, DAG.getConstant(CB, DL, VT));
12488 }
12489 
12490 // Try to turn (add (xor bool, 1) -1) into (neg bool).
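// For a boolean b in {0, 1}, (b xor 1) + (-1) = (1 - b) - 1 = -b; e.g. b = 1
// gives 0 + (-1) = -1 and b = 0 gives 1 + (-1) = 0.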
12491 static SDValue combineAddOfBooleanXor(SDNode *N, SelectionDAG &DAG) {
12492   SDValue N0 = N->getOperand(0);
12493   SDValue N1 = N->getOperand(1);
12494   EVT VT = N->getValueType(0);
12495   SDLoc DL(N);
12496 
12497   // RHS should be -1.
12498   if (!isAllOnesConstant(N1))
12499     return SDValue();
12500 
12501   // Look for (xor X, 1).
12502   if (N0.getOpcode() != ISD::XOR || !isOneConstant(N0.getOperand(1)))
12503     return SDValue();
12504 
12505   // First xor input should be 0 or 1.
12506   APInt Mask = APInt::getBitsSetFrom(VT.getSizeInBits(), 1);
12507   if (!DAG.MaskedValueIsZero(N0.getOperand(0), Mask))
12508     return SDValue();
12509 
12510   // Emit a negate of the setcc.
12511   return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
12512                      N0.getOperand(0));
12513 }
12514 
12515 static SDValue performADDCombine(SDNode *N, SelectionDAG &DAG,
12516                                  const RISCVSubtarget &Subtarget) {
12517   if (SDValue V = combineAddOfBooleanXor(N, DAG))
12518     return V;
12519   if (SDValue V = transformAddImmMulImm(N, DAG, Subtarget))
12520     return V;
12521   if (SDValue V = transformAddShlImm(N, DAG, Subtarget))
12522     return V;
12523   if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
12524     return V;
12525   if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget))
12526     return V;
12527 
12528   // fold (add (select lhs, rhs, cc, 0, y), x) ->
12529   //      (select lhs, rhs, cc, x, (add x, y))
12530   return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ false, Subtarget);
12531 }
12532 
12533 // Try to turn a sub boolean RHS and constant LHS into an addi.
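// For example (illustrative): (sub 5, (setcc x, y, eq)) becomes
// (add (setcc x, y, ne), 4), since eq = 1 - ne and 5 - (1 - ne) = ne + 4.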
12534 static SDValue combineSubOfBoolean(SDNode *N, SelectionDAG &DAG) {
12535   SDValue N0 = N->getOperand(0);
12536   SDValue N1 = N->getOperand(1);
12537   EVT VT = N->getValueType(0);
12538   SDLoc DL(N);
12539 
12540   // Require a constant LHS.
12541   auto *N0C = dyn_cast<ConstantSDNode>(N0);
12542   if (!N0C)
12543     return SDValue();
12544 
12545   // All our optimizations involve subtracting 1 from the immediate and forming
12546   // an ADDI. Make sure the new immediate is valid for an ADDI.
12547   APInt ImmValMinus1 = N0C->getAPIntValue() - 1;
12548   if (!ImmValMinus1.isSignedIntN(12))
12549     return SDValue();
12550 
12551   SDValue NewLHS;
12552   if (N1.getOpcode() == ISD::SETCC && N1.hasOneUse()) {
12553     // (sub constant, (setcc x, y, eq/neq)) ->
12554     // (add (setcc x, y, neq/eq), constant - 1)
12555     ISD::CondCode CCVal = cast<CondCodeSDNode>(N1.getOperand(2))->get();
12556     EVT SetCCOpVT = N1.getOperand(0).getValueType();
12557     if (!isIntEqualitySetCC(CCVal) || !SetCCOpVT.isInteger())
12558       return SDValue();
12559     CCVal = ISD::getSetCCInverse(CCVal, SetCCOpVT);
12560     NewLHS =
12561         DAG.getSetCC(SDLoc(N1), VT, N1.getOperand(0), N1.getOperand(1), CCVal);
12562   } else if (N1.getOpcode() == ISD::XOR && isOneConstant(N1.getOperand(1)) &&
12563              N1.getOperand(0).getOpcode() == ISD::SETCC) {
12564     // (sub C, (xor (setcc), 1)) -> (add (setcc), C-1).
12565     // Since setcc returns a bool the xor is equivalent to 1-setcc.
12566     NewLHS = N1.getOperand(0);
12567   } else
12568     return SDValue();
12569 
12570   SDValue NewRHS = DAG.getConstant(ImmValMinus1, DL, VT);
12571   return DAG.getNode(ISD::ADD, DL, VT, NewLHS, NewRHS);
12572 }
12573 
12574 static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG,
12575                                  const RISCVSubtarget &Subtarget) {
12576   if (SDValue V = combineSubOfBoolean(N, DAG))
12577     return V;
12578 
12579   SDValue N0 = N->getOperand(0);
12580   SDValue N1 = N->getOperand(1);
12581   // fold (sub 0, (setcc x, 0, setlt)) -> (sra x, xlen - 1)
12582   if (isNullConstant(N0) && N1.getOpcode() == ISD::SETCC && N1.hasOneUse() &&
12583       isNullConstant(N1.getOperand(1))) {
12584     ISD::CondCode CCVal = cast<CondCodeSDNode>(N1.getOperand(2))->get();
12585     if (CCVal == ISD::SETLT) {
12586       EVT VT = N->getValueType(0);
12587       SDLoc DL(N);
12588       unsigned ShAmt = N0.getValueSizeInBits() - 1;
12589       return DAG.getNode(ISD::SRA, DL, VT, N1.getOperand(0),
12590                          DAG.getConstant(ShAmt, DL, VT));
12591     }
12592   }
12593 
12594   // fold (sub x, (select lhs, rhs, cc, 0, y)) ->
12595   //      (select lhs, rhs, cc, x, (sub x, y))
12596   return combineSelectAndUse(N, N1, N0, DAG, /*AllOnes*/ false, Subtarget);
12597 }
12598 
12599 // Apply DeMorgan's law to (and/or (xor X, 1), (xor Y, 1)) if X and Y are 0/1.
12600 // Legalizing setcc can introduce xors like this. Doing this transform reduces
12601 // the number of xors and may allow the xor to fold into a branch condition.
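// For example, with x and y known to be 0/1:
//   (and (xor x, 1), (xor y, 1)) -> (xor (or x, y), 1)
//   (or  (xor x, 1), (xor y, 1)) -> (xor (and x, y), 1)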
12602 static SDValue combineDeMorganOfBoolean(SDNode *N, SelectionDAG &DAG) {
12603   SDValue N0 = N->getOperand(0);
12604   SDValue N1 = N->getOperand(1);
12605   bool IsAnd = N->getOpcode() == ISD::AND;
12606 
12607   if (N0.getOpcode() != ISD::XOR || N1.getOpcode() != ISD::XOR)
12608     return SDValue();
12609 
12610   if (!N0.hasOneUse() || !N1.hasOneUse())
12611     return SDValue();
12612 
12613   SDValue N01 = N0.getOperand(1);
12614   SDValue N11 = N1.getOperand(1);
12615 
12616   // For AND, SimplifyDemandedBits may have turned one of the (xor X, 1) into
12617   // (xor X, -1) based on the upper bits of the other operand being 0. If the
12618   // operation is And, allow one of the Xors to use -1.
12619   if (isOneConstant(N01)) {
12620     if (!isOneConstant(N11) && !(IsAnd && isAllOnesConstant(N11)))
12621       return SDValue();
12622   } else if (isOneConstant(N11)) {
12623     // N01 and N11 being 1 was already handled. Handle N11==1 and N01==-1.
12624     if (!(IsAnd && isAllOnesConstant(N01)))
12625       return SDValue();
12626   } else
12627     return SDValue();
12628 
12629   EVT VT = N->getValueType(0);
12630 
12631   SDValue N00 = N0.getOperand(0);
12632   SDValue N10 = N1.getOperand(0);
12633 
12634   // The LHS of the xors needs to be 0/1.
12635   APInt Mask = APInt::getBitsSetFrom(VT.getSizeInBits(), 1);
12636   if (!DAG.MaskedValueIsZero(N00, Mask) || !DAG.MaskedValueIsZero(N10, Mask))
12637     return SDValue();
12638 
12639   // Invert the opcode and insert a new xor.
12640   SDLoc DL(N);
12641   unsigned Opc = IsAnd ? ISD::OR : ISD::AND;
12642   SDValue Logic = DAG.getNode(Opc, DL, VT, N00, N10);
12643   return DAG.getNode(ISD::XOR, DL, VT, Logic, DAG.getConstant(1, DL, VT));
12644 }
12645 
12646 static SDValue performTRUNCATECombine(SDNode *N, SelectionDAG &DAG,
12647                                       const RISCVSubtarget &Subtarget) {
12648   SDValue N0 = N->getOperand(0);
12649   EVT VT = N->getValueType(0);
12650 
12651   // Pre-promote (i1 (truncate (srl X, Y))) on RV64 with Zbs without zero
12652   // extending X. This is safe since we only need the LSB after the shift and
12653   // shift amounts larger than 31 would produce poison. If we wait until
12654   // type legalization, we'll create RISCVISD::SRLW and we can't recover it
12655   // to use a BEXT instruction.
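  // For example (a sketch): (i1 (truncate (srl X:i32, Y))) is rewritten here
  // as (i1 (truncate (srl (any_extend X):i64, (zero_extend Y):i64))) so that
  // instruction selection can later match a single BEXT.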
12656   if (!RV64LegalI32 && Subtarget.is64Bit() && Subtarget.hasStdExtZbs() && VT == MVT::i1 &&
12657       N0.getValueType() == MVT::i32 && N0.getOpcode() == ISD::SRL &&
12658       !isa<ConstantSDNode>(N0.getOperand(1)) && N0.hasOneUse()) {
12659     SDLoc DL(N0);
12660     SDValue Op0 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N0.getOperand(0));
12661     SDValue Op1 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, N0.getOperand(1));
12662     SDValue Srl = DAG.getNode(ISD::SRL, DL, MVT::i64, Op0, Op1);
12663     return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Srl);
12664   }
12665 
12666   return SDValue();
12667 }
12668 
12669 // Combine two comparison operations and a logic operation into one selection
12670 // operation (min, max) and a logic operation. Returns the newly constructed
12671 // node if the conditions for the optimization are satisfied.
12672 static SDValue performANDCombine(SDNode *N,
12673                                  TargetLowering::DAGCombinerInfo &DCI,
12674                                  const RISCVSubtarget &Subtarget) {
12675   SelectionDAG &DAG = DCI.DAG;
12676 
12677   SDValue N0 = N->getOperand(0);
12678   // Pre-promote (i32 (and (srl X, Y), 1)) on RV64 with Zbs without zero
12679   // extending X. This is safe since we only need the LSB after the shift and
12680   // shift amounts larger than 31 would produce poison. If we wait until
12681   // type legalization, we'll create RISCVISD::SRLW and we can't recover it
12682   // to use a BEXT instruction.
12683   if (!RV64LegalI32 && Subtarget.is64Bit() && Subtarget.hasStdExtZbs() &&
12684       N->getValueType(0) == MVT::i32 && isOneConstant(N->getOperand(1)) &&
12685       N0.getOpcode() == ISD::SRL && !isa<ConstantSDNode>(N0.getOperand(1)) &&
12686       N0.hasOneUse()) {
12687     SDLoc DL(N);
12688     SDValue Op0 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N0.getOperand(0));
12689     SDValue Op1 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, N0.getOperand(1));
12690     SDValue Srl = DAG.getNode(ISD::SRL, DL, MVT::i64, Op0, Op1);
12691     SDValue And = DAG.getNode(ISD::AND, DL, MVT::i64, Srl,
12692                               DAG.getConstant(1, DL, MVT::i64));
12693     return DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, And);
12694   }
12695 
12696   if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
12697     return V;
12698   if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget))
12699     return V;
12700 
12701   if (DCI.isAfterLegalizeDAG())
12702     if (SDValue V = combineDeMorganOfBoolean(N, DAG))
12703       return V;
12704 
12705   // fold (and (select lhs, rhs, cc, -1, y), x) ->
12706   //      (select lhs, rhs, cc, x, (and x, y))
12707   return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ true, Subtarget);
12708 }
12709 
12710 // Try to pull an xor with 1 through a select idiom that uses czero_eqz/nez.
12711 // FIXME: Generalize to other binary operators with same operand.
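// For example (illustrative): with a shared condition c,
//   (or (czero_eqz (xor a, 1), c), (czero_nez (xor b, 1), c))
//     -> (xor (or (czero_eqz a, c), (czero_nez b, c)), 1)
// so the xor with 1 is applied once to the selected result instead of to both
// arms.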
12712 static SDValue combineOrOfCZERO(SDNode *N, SDValue N0, SDValue N1,
12713                                 SelectionDAG &DAG) {
12714   assert(N->getOpcode() == ISD::OR && "Unexpected opcode");
12715 
12716   if (N0.getOpcode() != RISCVISD::CZERO_EQZ ||
12717       N1.getOpcode() != RISCVISD::CZERO_NEZ ||
12718       !N0.hasOneUse() || !N1.hasOneUse())
12719     return SDValue();
12720 
12721   // Should have the same condition.
12722   SDValue Cond = N0.getOperand(1);
12723   if (Cond != N1.getOperand(1))
12724     return SDValue();
12725 
12726   SDValue TrueV = N0.getOperand(0);
12727   SDValue FalseV = N1.getOperand(0);
12728 
12729   if (TrueV.getOpcode() != ISD::XOR || FalseV.getOpcode() != ISD::XOR ||
12730       TrueV.getOperand(1) != FalseV.getOperand(1) ||
12731       !isOneConstant(TrueV.getOperand(1)) ||
12732       !TrueV.hasOneUse() || !FalseV.hasOneUse())
12733     return SDValue();
12734 
12735   EVT VT = N->getValueType(0);
12736   SDLoc DL(N);
12737 
12738   SDValue NewN0 = DAG.getNode(RISCVISD::CZERO_EQZ, DL, VT, TrueV.getOperand(0),
12739                               Cond);
12740   SDValue NewN1 = DAG.getNode(RISCVISD::CZERO_NEZ, DL, VT, FalseV.getOperand(0),
12741                               Cond);
12742   SDValue NewOr = DAG.getNode(ISD::OR, DL, VT, NewN0, NewN1);
12743   return DAG.getNode(ISD::XOR, DL, VT, NewOr, TrueV.getOperand(1));
12744 }
12745 
12746 static SDValue performORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
12747                                 const RISCVSubtarget &Subtarget) {
12748   SelectionDAG &DAG = DCI.DAG;
12749 
12750   if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
12751     return V;
12752   if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget))
12753     return V;
12754 
12755   if (DCI.isAfterLegalizeDAG())
12756     if (SDValue V = combineDeMorganOfBoolean(N, DAG))
12757       return V;
12758 
12759   // Look for Or of CZERO_EQZ/NEZ with same condition which is the select idiom.
12760   // We may be able to pull a common operation out of the true and false value.
12761   SDValue N0 = N->getOperand(0);
12762   SDValue N1 = N->getOperand(1);
12763   if (SDValue V = combineOrOfCZERO(N, N0, N1, DAG))
12764     return V;
12765   if (SDValue V = combineOrOfCZERO(N, N1, N0, DAG))
12766     return V;
12767 
12768   // fold (or (select cond, 0, y), x) ->
12769   //      (select cond, x, (or x, y))
12770   return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ false, Subtarget);
12771 }
12772 
12773 static SDValue performXORCombine(SDNode *N, SelectionDAG &DAG,
12774                                  const RISCVSubtarget &Subtarget) {
12775   SDValue N0 = N->getOperand(0);
12776   SDValue N1 = N->getOperand(1);
12777 
12778   // Pre-promote (i32 (xor (shl -1, X), ~0)) on RV64 with Zbs so we can use
12779 // (ADDI (BSET X0, X), -1). If we wait until type legalization, we'll create
12780 // RISCVISD::SLLW and we can't recover it to use a BSET instruction.
12781   if (!RV64LegalI32 && Subtarget.is64Bit() && Subtarget.hasStdExtZbs() &&
12782       N->getValueType(0) == MVT::i32 && isAllOnesConstant(N1) &&
12783       N0.getOpcode() == ISD::SHL && isAllOnesConstant(N0.getOperand(0)) &&
12784       !isa<ConstantSDNode>(N0.getOperand(1)) && N0.hasOneUse()) {
12785     SDLoc DL(N);
12786     SDValue Op0 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N0.getOperand(0));
12787     SDValue Op1 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, N0.getOperand(1));
12788     SDValue Shl = DAG.getNode(ISD::SHL, DL, MVT::i64, Op0, Op1);
12789     SDValue And = DAG.getNOT(DL, Shl, MVT::i64);
12790     return DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, And);
12791   }
12792 
12793   // fold (xor (sllw 1, x), -1) -> (rolw ~1, x)
12794   // NOTE: Assumes ROL being legal means ROLW is legal.
12795   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
12796   if (N0.getOpcode() == RISCVISD::SLLW &&
12797       isAllOnesConstant(N1) && isOneConstant(N0.getOperand(0)) &&
12798       TLI.isOperationLegal(ISD::ROTL, MVT::i64)) {
12799     SDLoc DL(N);
12800     return DAG.getNode(RISCVISD::ROLW, DL, MVT::i64,
12801                        DAG.getConstant(~1, DL, MVT::i64), N0.getOperand(1));
12802   }
12803 
12804   // Fold (xor (setcc constant, y, setlt), 1) -> (setcc y, constant + 1, setlt)
12805   if (N0.getOpcode() == ISD::SETCC && isOneConstant(N1) && N0.hasOneUse()) {
12806     auto *ConstN00 = dyn_cast<ConstantSDNode>(N0.getOperand(0));
12807     ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
12808     if (ConstN00 && CC == ISD::SETLT) {
12809       EVT VT = N0.getValueType();
12810       SDLoc DL(N0);
12811       const APInt &Imm = ConstN00->getAPIntValue();
12812       if ((Imm + 1).isSignedIntN(12))
12813         return DAG.getSetCC(DL, VT, N0.getOperand(1),
12814                             DAG.getConstant(Imm + 1, DL, VT), CC);
12815     }
12816   }
12817 
12818   if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
12819     return V;
12820   if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget))
12821     return V;
12822 
12823   // fold (xor (select cond, 0, y), x) ->
12824   //      (select cond, x, (xor x, y))
12825   return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ false, Subtarget);
12826 }
12827 
12828 static SDValue performMULCombine(SDNode *N, SelectionDAG &DAG) {
12829   EVT VT = N->getValueType(0);
12830   if (!VT.isVector())
12831     return SDValue();
12832 
12833   SDLoc DL(N);
12834   SDValue N0 = N->getOperand(0);
12835   SDValue N1 = N->getOperand(1);
12836   SDValue MulOper;
12837   unsigned AddSubOpc;
12838 
12839   // vmadd: (mul (add x, 1), y) -> (add (mul x, y), y)
12840   //        (mul x, (add y, 1)) -> (add x, (mul x, y))
12841   // vnmsub: (mul (sub 1, x), y) -> (sub y, (mul x, y))
12842   //         (mul x, (sub 1, y)) -> (sub x, (mul x, y))
12843   auto IsAddSubWith1 = [&](SDValue V) -> bool {
12844     AddSubOpc = V->getOpcode();
12845     if ((AddSubOpc == ISD::ADD || AddSubOpc == ISD::SUB) && V->hasOneUse()) {
12846       SDValue Opnd = V->getOperand(1);
12847       MulOper = V->getOperand(0);
12848       if (AddSubOpc == ISD::SUB)
12849         std::swap(Opnd, MulOper);
12850       if (isOneOrOneSplat(Opnd))
12851         return true;
12852     }
12853     return false;
12854   };
12855 
12856   if (IsAddSubWith1(N0)) {
12857     SDValue MulVal = DAG.getNode(ISD::MUL, DL, VT, N1, MulOper);
12858     return DAG.getNode(AddSubOpc, DL, VT, N1, MulVal);
12859   }
12860 
12861   if (IsAddSubWith1(N1)) {
12862     SDValue MulVal = DAG.getNode(ISD::MUL, DL, VT, N0, MulOper);
12863     return DAG.getNode(AddSubOpc, DL, VT, N0, MulVal);
12864   }
12865 
12866   return SDValue();
12867 }
12868 
12869 /// According to the property that indexed load/store instructions zero-extend
12870 /// their indices, try to narrow the type of the index operand.
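///
/// For example (illustrative types): an index expression
///   (shl (zext nxv4i8 %x to nxv4i64), splat 2)
/// only needs 8 + 2 = 10 bits, so it can be rebuilt as
///   (shl (zext nxv4i8 %x to nxv4i16), splat 2)
/// which is cheaper at high LMUL and is still zero-extended by the indexed
/// memory operation.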
12871 static bool narrowIndex(SDValue &N, ISD::MemIndexType IndexType, SelectionDAG &DAG) {
12872   if (isIndexTypeSigned(IndexType))
12873     return false;
12874 
12875   if (!N->hasOneUse())
12876     return false;
12877 
12878   EVT VT = N.getValueType();
12879   SDLoc DL(N);
12880 
12881   // In general, what we're doing here is seeing if we can sink a truncate to
12882   // a smaller element type into the expression tree building our index.
12883   // TODO: We can generalize this and handle a bunch more cases if useful.
12884 
12885   // Narrow a buildvector to the narrowest element type.  This requires less
12886   // work and less register pressure at high LMUL, and creates smaller constants
12887   // which may be cheaper to materialize.
12888   if (ISD::isBuildVectorOfConstantSDNodes(N.getNode())) {
12889     KnownBits Known = DAG.computeKnownBits(N);
12890     unsigned ActiveBits = std::max(8u, Known.countMaxActiveBits());
12891     LLVMContext &C = *DAG.getContext();
12892     EVT ResultVT = EVT::getIntegerVT(C, ActiveBits).getRoundIntegerType(C);
12893     if (ResultVT.bitsLT(VT.getVectorElementType())) {
12894       N = DAG.getNode(ISD::TRUNCATE, DL,
12895                       VT.changeVectorElementType(ResultVT), N);
12896       return true;
12897     }
12898   }
12899 
12900   // Handle the pattern (shl (zext x to ty), C) where bits(x) + C < bits(ty).
12901   if (N.getOpcode() != ISD::SHL)
12902     return false;
12903 
12904   SDValue N0 = N.getOperand(0);
12905   if (N0.getOpcode() != ISD::ZERO_EXTEND &&
12906       N0.getOpcode() != RISCVISD::VZEXT_VL)
12907     return false;
12908   if (!N0->hasOneUse())
12909     return false;
12910 
12911   APInt ShAmt;
12912   SDValue N1 = N.getOperand(1);
12913   if (!ISD::isConstantSplatVector(N1.getNode(), ShAmt))
12914     return false;
12915 
12916   SDValue Src = N0.getOperand(0);
12917   EVT SrcVT = Src.getValueType();
12918   unsigned SrcElen = SrcVT.getScalarSizeInBits();
12919   unsigned ShAmtV = ShAmt.getZExtValue();
12920   unsigned NewElen = PowerOf2Ceil(SrcElen + ShAmtV);
12921   NewElen = std::max(NewElen, 8U);
12922 
12923   // Skip if NewElen is not narrower than the original extended type.
12924   if (NewElen >= N0.getValueType().getScalarSizeInBits())
12925     return false;
12926 
12927   EVT NewEltVT = EVT::getIntegerVT(*DAG.getContext(), NewElen);
12928   EVT NewVT = SrcVT.changeVectorElementType(NewEltVT);
12929 
12930   SDValue NewExt = DAG.getNode(N0->getOpcode(), DL, NewVT, N0->ops());
12931   SDValue NewShAmtVec = DAG.getConstant(ShAmtV, DL, NewVT);
12932   N = DAG.getNode(ISD::SHL, DL, NewVT, NewExt, NewShAmtVec);
12933   return true;
12934 }
12935 
12936 // Replace (seteq (i64 (and X, 0xffffffff)), C1) with
12937 // (seteq (i64 (sext_inreg (X, i32)), C1')) where C1' is C1 sign extended from
12938 // bit 31. Same for setne. C1' may be cheaper to materialize and the sext_inreg
12939 // can become a sext.w instead of a shift pair.
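// For example (illustrative constants):
//   (seteq (and X, 0xffffffff), 0x80000000)
//     -> (seteq (sext_inreg X, i32), 0xffffffff80000000)
// where the sign-extended constant may be cheaper to materialize and the
// sext_inreg can become a single sext.w instead of a shift pair.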
12940 static SDValue performSETCCCombine(SDNode *N, SelectionDAG &DAG,
12941                                    const RISCVSubtarget &Subtarget) {
12942   SDValue N0 = N->getOperand(0);
12943   SDValue N1 = N->getOperand(1);
12944   EVT VT = N->getValueType(0);
12945   EVT OpVT = N0.getValueType();
12946 
12947   if (OpVT != MVT::i64 || !Subtarget.is64Bit())
12948     return SDValue();
12949 
12950   // RHS needs to be a constant.
12951   auto *N1C = dyn_cast<ConstantSDNode>(N1);
12952   if (!N1C)
12953     return SDValue();
12954 
12955   // LHS needs to be (and X, 0xffffffff).
12956   if (N0.getOpcode() != ISD::AND || !N0.hasOneUse() ||
12957       !isa<ConstantSDNode>(N0.getOperand(1)) ||
12958       N0.getConstantOperandVal(1) != UINT64_C(0xffffffff))
12959     return SDValue();
12960 
12961   // Looking for an equality compare.
12962   ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get();
12963   if (!isIntEqualitySetCC(Cond))
12964     return SDValue();
12965 
12966   // Don't do this if the sign bit is provably zero, it will be turned back into
12967   // an AND.
12968   APInt SignMask = APInt::getOneBitSet(64, 31);
12969   if (DAG.MaskedValueIsZero(N0.getOperand(0), SignMask))
12970     return SDValue();
12971 
12972   const APInt &C1 = N1C->getAPIntValue();
12973 
12974   SDLoc dl(N);
12975   // If the constant is larger than 2^32 - 1 it is impossible for both sides
12976   // to be equal.
12977   if (C1.getActiveBits() > 32)
12978     return DAG.getBoolConstant(Cond == ISD::SETNE, dl, VT, OpVT);
12979 
12980   SDValue SExtOp = DAG.getNode(ISD::SIGN_EXTEND_INREG, N, OpVT,
12981                                N0.getOperand(0), DAG.getValueType(MVT::i32));
12982   return DAG.getSetCC(dl, VT, SExtOp, DAG.getConstant(C1.trunc(32).sext(64),
12983                                                       dl, OpVT), Cond);
12984 }
12985 
12986 static SDValue
12987 performSIGN_EXTEND_INREGCombine(SDNode *N, SelectionDAG &DAG,
12988                                 const RISCVSubtarget &Subtarget) {
12989   SDValue Src = N->getOperand(0);
12990   EVT VT = N->getValueType(0);
12991 
12992   // Fold (sext_inreg (fmv_x_anyexth X), i16) -> (fmv_x_signexth X)
12993   if (Src.getOpcode() == RISCVISD::FMV_X_ANYEXTH &&
12994       cast<VTSDNode>(N->getOperand(1))->getVT().bitsGE(MVT::i16))
12995     return DAG.getNode(RISCVISD::FMV_X_SIGNEXTH, SDLoc(N), VT,
12996                        Src.getOperand(0));
12997 
12998   return SDValue();
12999 }
13000 
13001 namespace {
13002 // Forward declaration of the structure holding the necessary information to
13003 // apply a combine.
13004 struct CombineResult;
13005 
13006 /// Helper class for folding sign/zero extensions.
13007 /// In particular, this class is used for the following combines:
13008 /// add | add_vl -> vwadd(u) | vwadd(u)_w
13009 /// sub | sub_vl -> vwsub(u) | vwsub(u)_w
13010 /// mul | mul_vl -> vwmul(u) | vwmul_su
13011 ///
13012 /// An object of this class represents an operand of the operation we want to
13013 /// combine.
13014 /// E.g., when trying to combine `mul_vl a, b`, we will have one instance of
13015 /// NodeExtensionHelper for `a` and one for `b`.
13016 ///
13017 /// This class abstracts away how the extension is materialized and
13018 /// how its Mask, VL, and number of users affect the combines.
13019 ///
13020 /// In particular:
13021 /// - VWADD_W is conceptually == add(op0, sext(op1))
13022 /// - VWADDU_W == add(op0, zext(op1))
13023 /// - VWSUB_W == sub(op0, sext(op1))
13024 /// - VWSUBU_W == sub(op0, zext(op1))
13025 ///
13026 /// And VMV_V_X_VL, depending on the value, is conceptually equivalent to
13027 /// zext|sext(smaller_value).
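///
/// For example (a sketch): for a root `add_vl (sext_vl a), (sext_vl b)` where
/// `a` and `b` are nxv4i16 and the result is nxv4i32, both operands report
/// SupportsSExt, so the combine can rebuild the root as `vwadd_vl a, b`,
/// dropping the explicit extensions.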
13028 struct NodeExtensionHelper {
13029   /// Records if this operand behaves as if it were zero extended.
13030   bool SupportsZExt;
13031   /// Records if this operand behaves as if it were sign extended.
13032   /// Note: SupportsZExt and SupportsSExt are not mutually exclusive. For
13033   /// instance, a splat constant (e.g., 3), would support being both sign and
13034   /// zero extended.
13035   bool SupportsSExt;
13036   /// This boolean captures whether we care if this operand would still be
13037   /// around after the folding happens.
13038   bool EnforceOneUse;
13039   /// Records if this operand's mask needs to match the mask of the operation
13040   /// that it will fold into.
13041   bool CheckMask;
13042   /// Value of the Mask for this operand.
13043   /// It may be SDValue().
13044   SDValue Mask;
13045   /// Value of the vector length operand.
13046   /// It may be SDValue().
13047   SDValue VL;
13048   /// Original value that this NodeExtensionHelper represents.
13049   SDValue OrigOperand;
13050 
13051   /// Get the value feeding the extension or the value itself.
13052   /// E.g., for zext(a), this would return a.
13053   SDValue getSource() const {
13054     switch (OrigOperand.getOpcode()) {
13055     case ISD::ZERO_EXTEND:
13056     case ISD::SIGN_EXTEND:
13057     case RISCVISD::VSEXT_VL:
13058     case RISCVISD::VZEXT_VL:
13059       return OrigOperand.getOperand(0);
13060     default:
13061       return OrigOperand;
13062     }
13063   }
13064 
13065   /// Check if this instance represents a splat.
13066   bool isSplat() const {
13067     return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL;
13068   }
13069 
13070   /// Get or create a value that can feed \p Root with the given extension \p
13071   /// SExt. If \p SExt is std::nullopt, this returns the source of this operand.
13072   /// \see ::getSource().
13073   SDValue getOrCreateExtendedOp(SDNode *Root, SelectionDAG &DAG,
13074                                 const RISCVSubtarget &Subtarget,
13075                                 std::optional<bool> SExt) const {
13076     if (!SExt.has_value())
13077       return OrigOperand;
13078 
13079     MVT NarrowVT = getNarrowType(Root);
13080 
13081     SDValue Source = getSource();
13082     if (Source.getValueType() == NarrowVT)
13083       return Source;
13084 
13085     unsigned ExtOpc = *SExt ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL;
13086 
13087     // If we need an extension, we should be changing the type.
13088     SDLoc DL(Root);
13089     auto [Mask, VL] = getMaskAndVL(Root, DAG, Subtarget);
13090     switch (OrigOperand.getOpcode()) {
13091     case ISD::ZERO_EXTEND:
13092     case ISD::SIGN_EXTEND:
13093     case RISCVISD::VSEXT_VL:
13094     case RISCVISD::VZEXT_VL:
13095       return DAG.getNode(ExtOpc, DL, NarrowVT, Source, Mask, VL);
13096     case RISCVISD::VMV_V_X_VL:
13097       return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT,
13098                          DAG.getUNDEF(NarrowVT), Source.getOperand(1), VL);
13099     default:
13100       // Other opcodes can only come from the original LHS of VW(ADD|SUB)_W_VL
13101       // and that operand should already have the right NarrowVT so no
13102       // extension should be required at this point.
13103       llvm_unreachable("Unsupported opcode");
13104     }
13105   }
13106 
13107   /// Helper function to get the narrow type for \p Root.
13108   /// The narrow type is the type of \p Root with each element size halved.
13109   /// E.g., if Root's type is <2 x i16>, the narrow type is <2 x i8>.
13110   /// \pre The size of the elements of Root must be a multiple of 2 and be
13111   /// at least 16.
13112   static MVT getNarrowType(const SDNode *Root) {
13113     MVT VT = Root->getSimpleValueType(0);
13114 
13115     // Determine the narrow size.
13116     unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
13117     assert(NarrowSize >= 8 && "Trying to extend something we can't represent");
13118     MVT NarrowVT = MVT::getVectorVT(MVT::getIntegerVT(NarrowSize),
13119                                     VT.getVectorElementCount());
13120     return NarrowVT;
13121   }
13122 
13123   /// Return the opcode required to materialize the folding of the sign
13124   /// extensions (\p IsSExt == true) or zero extensions (IsSExt == false) for
13125   /// both operands for \p Opcode.
13126   /// Put differently, get the opcode to materialize:
13127   /// - IsSExt == true: \p Opcode(sext(a), sext(b)) -> newOpcode(a, b)
13128   /// - IsSExt == false: \p Opcode(zext(a), zext(b)) -> newOpcode(a, b)
13129   /// \pre \p Opcode represents a supported root (\see ::isSupportedRoot()).
13130   static unsigned getSameExtensionOpcode(unsigned Opcode, bool IsSExt) {
13131     switch (Opcode) {
13132     case ISD::ADD:
13133     case RISCVISD::ADD_VL:
13134     case RISCVISD::VWADD_W_VL:
13135     case RISCVISD::VWADDU_W_VL:
13136       return IsSExt ? RISCVISD::VWADD_VL : RISCVISD::VWADDU_VL;
13137     case ISD::MUL:
13138     case RISCVISD::MUL_VL:
13139       return IsSExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL;
13140     case ISD::SUB:
13141     case RISCVISD::SUB_VL:
13142     case RISCVISD::VWSUB_W_VL:
13143     case RISCVISD::VWSUBU_W_VL:
13144       return IsSExt ? RISCVISD::VWSUB_VL : RISCVISD::VWSUBU_VL;
13145     default:
13146       llvm_unreachable("Unexpected opcode");
13147     }
13148   }
13149 
13150   /// Get the opcode to materialize \p Opcode(sext(a), zext(b)) ->
13151   /// newOpcode(a, b).
13152   static unsigned getSUOpcode(unsigned Opcode) {
13153     assert((Opcode == RISCVISD::MUL_VL || Opcode == ISD::MUL) &&
13154            "SU is only supported for MUL");
13155     return RISCVISD::VWMULSU_VL;
13156   }
13157 
13158   /// Get the opcode to materialize \p Opcode(a, s|zext(b)) ->
13159   /// newOpcode(a, b).
13160   static unsigned getWOpcode(unsigned Opcode, bool IsSExt) {
13161     switch (Opcode) {
13162     case ISD::ADD:
13163     case RISCVISD::ADD_VL:
13164       return IsSExt ? RISCVISD::VWADD_W_VL : RISCVISD::VWADDU_W_VL;
13165     case ISD::SUB:
13166     case RISCVISD::SUB_VL:
13167       return IsSExt ? RISCVISD::VWSUB_W_VL : RISCVISD::VWSUBU_W_VL;
13168     default:
13169       llvm_unreachable("Unexpected opcode");
13170     }
13171   }
13172 
13173   using CombineToTry = std::function<std::optional<CombineResult>(
13174       SDNode * /*Root*/, const NodeExtensionHelper & /*LHS*/,
13175       const NodeExtensionHelper & /*RHS*/, SelectionDAG &,
13176       const RISCVSubtarget &)>;
13177 
13178   /// Check if this node needs to be fully folded or extended for all users.
13179   bool needToPromoteOtherUsers() const { return EnforceOneUse; }
13180 
13181   /// Helper method to set the various fields of this struct based on the
13182   /// type of \p Root.
13183   void fillUpExtensionSupport(SDNode *Root, SelectionDAG &DAG,
13184                               const RISCVSubtarget &Subtarget) {
13185     SupportsZExt = false;
13186     SupportsSExt = false;
13187     EnforceOneUse = true;
13188     CheckMask = true;
13189     unsigned Opc = OrigOperand.getOpcode();
13190     switch (Opc) {
13191     case ISD::ZERO_EXTEND:
13192     case ISD::SIGN_EXTEND: {
13193       MVT VT = OrigOperand.getSimpleValueType();
13194       if (!VT.isVector())
13195         break;
13196 
13197       SDValue NarrowElt = OrigOperand.getOperand(0);
13198       MVT NarrowVT = NarrowElt.getSimpleValueType();
13199 
13200       unsigned ScalarBits = VT.getScalarSizeInBits();
13201       unsigned NarrowScalarBits = NarrowVT.getScalarSizeInBits();
13202 
13203       // Ensure the narrowing element type is legal
13204       if (!Subtarget.getTargetLowering()->isTypeLegal(NarrowElt.getValueType()))
13205         break;
13206 
13207       // Ensure the extension's semantic is equivalent to rvv vzext or vsext.
13208       if (ScalarBits != NarrowScalarBits * 2)
13209         break;
13210 
13211       SupportsZExt = Opc == ISD::ZERO_EXTEND;
13212       SupportsSExt = Opc == ISD::SIGN_EXTEND;
13213 
13214       SDLoc DL(Root);
13215       std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
13216       break;
13217     }
13218     case RISCVISD::VZEXT_VL:
13219       SupportsZExt = true;
13220       Mask = OrigOperand.getOperand(1);
13221       VL = OrigOperand.getOperand(2);
13222       break;
13223     case RISCVISD::VSEXT_VL:
13224       SupportsSExt = true;
13225       Mask = OrigOperand.getOperand(1);
13226       VL = OrigOperand.getOperand(2);
13227       break;
13228     case RISCVISD::VMV_V_X_VL: {
13229       // Historically, we didn't care about splat values not disappearing during
13230       // combines.
13231       EnforceOneUse = false;
13232       CheckMask = false;
13233       VL = OrigOperand.getOperand(2);
13234 
13235       // The operand is a splat of a scalar.
13236 
13237       // The passthru must be undef for tail agnostic.
13238       if (!OrigOperand.getOperand(0).isUndef())
13239         break;
13240 
13241       // Get the scalar value.
13242       SDValue Op = OrigOperand.getOperand(1);
13243 
13244       // See if we have enough sign bits or zero bits in the scalar to use a
13245       // widening opcode by splatting to smaller element size.
13246       MVT VT = Root->getSimpleValueType(0);
13247       unsigned EltBits = VT.getScalarSizeInBits();
13248       unsigned ScalarBits = Op.getValueSizeInBits();
13249       // Make sure we're getting all element bits from the scalar register.
13250       // FIXME: Support implicit sign extension of vmv.v.x?
13251       if (ScalarBits < EltBits)
13252         break;
13253 
13254       unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
13255       // If the narrow type cannot be expressed with a legal VMV,
13256       // this is not a valid candidate.
13257       if (NarrowSize < 8)
13258         break;
13259 
13260       if (DAG.ComputeMaxSignificantBits(Op) <= NarrowSize)
13261         SupportsSExt = true;
13262       if (DAG.MaskedValueIsZero(Op,
13263                                 APInt::getBitsSetFrom(ScalarBits, NarrowSize)))
13264         SupportsZExt = true;
13265       break;
13266     }
13267     default:
13268       break;
13269     }
13270   }
13271 
13272   /// Check if \p Root supports any extension folding combines.
13273   static bool isSupportedRoot(const SDNode *Root, const SelectionDAG &DAG) {
13274     switch (Root->getOpcode()) {
13275     case ISD::ADD:
13276     case ISD::SUB:
13277     case ISD::MUL: {
13278       const TargetLowering &TLI = DAG.getTargetLoweringInfo();
13279       if (!TLI.isTypeLegal(Root->getValueType(0)))
13280         return false;
13281       return Root->getValueType(0).isScalableVector();
13282     }
13283     case RISCVISD::ADD_VL:
13284     case RISCVISD::MUL_VL:
13285     case RISCVISD::VWADD_W_VL:
13286     case RISCVISD::VWADDU_W_VL:
13287     case RISCVISD::SUB_VL:
13288     case RISCVISD::VWSUB_W_VL:
13289     case RISCVISD::VWSUBU_W_VL:
13290       return true;
13291     default:
13292       return false;
13293     }
13294   }
13295 
13296   /// Build a NodeExtensionHelper for \p Root.getOperand(\p OperandIdx).
13297   NodeExtensionHelper(SDNode *Root, unsigned OperandIdx, SelectionDAG &DAG,
13298                       const RISCVSubtarget &Subtarget) {
13299     assert(isSupportedRoot(Root, DAG) && "Trying to build a helper with an "
13300                                          "unsupported root");
13301     assert(OperandIdx < 2 && "Requesting something other than LHS or RHS");
13302     OrigOperand = Root->getOperand(OperandIdx);
13303 
13304     unsigned Opc = Root->getOpcode();
13305     switch (Opc) {
13306     // We consider VW<ADD|SUB>(U)_W(LHS, RHS) as if they were
13307     // <ADD|SUB>(LHS, S|ZEXT(RHS))
13308     case RISCVISD::VWADD_W_VL:
13309     case RISCVISD::VWADDU_W_VL:
13310     case RISCVISD::VWSUB_W_VL:
13311     case RISCVISD::VWSUBU_W_VL:
13312       if (OperandIdx == 1) {
13313         SupportsZExt =
13314             Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL;
13315         SupportsSExt = !SupportsZExt;
13316         std::tie(Mask, VL) = getMaskAndVL(Root, DAG, Subtarget);
13317         CheckMask = true;
13318         // There's no existing extension here, so we don't have to worry about
13319         // making sure it gets removed.
13320         EnforceOneUse = false;
13321         break;
13322       }
13323       [[fallthrough]];
13324     default:
13325       fillUpExtensionSupport(Root, DAG, Subtarget);
13326       break;
13327     }
13328   }
13329 
13330   /// Check if this operand is compatible with the given vector length \p VL.
13331   bool isVLCompatible(SDValue VL) const {
13332     return this->VL != SDValue() && this->VL == VL;
13333   }
13334 
13335   /// Check if this operand is compatible with the given \p Mask.
13336   bool isMaskCompatible(SDValue Mask) const {
13337     return !CheckMask || (this->Mask != SDValue() && this->Mask == Mask);
13338   }
13339 
13340   /// Helper function to get the Mask and VL from \p Root.
13341   static std::pair<SDValue, SDValue>
13342   getMaskAndVL(const SDNode *Root, SelectionDAG &DAG,
13343                const RISCVSubtarget &Subtarget) {
13344     assert(isSupportedRoot(Root, DAG) && "Unexpected root");
13345     switch (Root->getOpcode()) {
13346     case ISD::ADD:
13347     case ISD::SUB:
13348     case ISD::MUL: {
13349       SDLoc DL(Root);
13350       MVT VT = Root->getSimpleValueType(0);
13351       return getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
13352     }
13353     default:
13354       return std::make_pair(Root->getOperand(3), Root->getOperand(4));
13355     }
13356   }
13357 
13358   /// Check if the Mask and VL of this operand are compatible with \p Root.
13359   bool areVLAndMaskCompatible(SDNode *Root, SelectionDAG &DAG,
13360                               const RISCVSubtarget &Subtarget) const {
13361     auto [Mask, VL] = getMaskAndVL(Root, DAG, Subtarget);
13362     return isMaskCompatible(Mask) && isVLCompatible(VL);
13363   }
13364 
13365   /// Helper function to check if \p N is commutative with respect to the
13366   /// foldings that are supported by this class.
13367   static bool isCommutative(const SDNode *N) {
13368     switch (N->getOpcode()) {
13369     case ISD::ADD:
13370     case ISD::MUL:
13371     case RISCVISD::ADD_VL:
13372     case RISCVISD::MUL_VL:
13373     case RISCVISD::VWADD_W_VL:
13374     case RISCVISD::VWADDU_W_VL:
13375       return true;
13376     case ISD::SUB:
13377     case RISCVISD::SUB_VL:
13378     case RISCVISD::VWSUB_W_VL:
13379     case RISCVISD::VWSUBU_W_VL:
13380       return false;
13381     default:
13382       llvm_unreachable("Unexpected opcode");
13383     }
13384   }
13385 
13386   /// Get a list of combines to try for folding extensions in \p Root.
13387   /// Note that each returned CombineToTry function doesn't actually modify
13388   /// anything. Instead it produces an optional CombineResult that, if not
13389   /// std::nullopt, needs to be materialized for the combine to be applied.
13390   /// \see CombineResult::materialize.
13391   /// If the related CombineToTry function returns std::nullopt, that means the
13392   /// combine didn't match.
13393   static SmallVector<CombineToTry> getSupportedFoldings(const SDNode *Root);
13394 };
13395 
13396 /// Helper structure that holds all the necessary information to materialize a
13397 /// combine that does some extension folding.
13398 struct CombineResult {
13399   /// Opcode to be generated when materializing the combine.
13400   unsigned TargetOpcode;
13401   // No value means no extension is needed. If extension is needed, the value
13402   // indicates if it needs to be sign extended.
13403   std::optional<bool> SExtLHS;
13404   std::optional<bool> SExtRHS;
13405   /// Root of the combine.
13406   SDNode *Root;
13407   /// LHS of the TargetOpcode.
13408   NodeExtensionHelper LHS;
13409   /// RHS of the TargetOpcode.
13410   NodeExtensionHelper RHS;
13411 
13412   CombineResult(unsigned TargetOpcode, SDNode *Root,
13413                 const NodeExtensionHelper &LHS, std::optional<bool> SExtLHS,
13414                 const NodeExtensionHelper &RHS, std::optional<bool> SExtRHS)
13415       : TargetOpcode(TargetOpcode), SExtLHS(SExtLHS), SExtRHS(SExtRHS),
13416         Root(Root), LHS(LHS), RHS(RHS) {}
13417 
13418   /// Return a value that uses TargetOpcode and that can be used to replace
13419   /// Root.
13420   /// The actual replacement is *not* done in that method.
13421   SDValue materialize(SelectionDAG &DAG,
13422                       const RISCVSubtarget &Subtarget) const {
13423     SDValue Mask, VL, Merge;
13424     std::tie(Mask, VL) =
13425         NodeExtensionHelper::getMaskAndVL(Root, DAG, Subtarget);
13426     switch (Root->getOpcode()) {
13427     default:
13428       Merge = Root->getOperand(2);
13429       break;
13430     case ISD::ADD:
13431     case ISD::SUB:
13432     case ISD::MUL:
13433       Merge = DAG.getUNDEF(Root->getValueType(0));
13434       break;
13435     }
13436     return DAG.getNode(TargetOpcode, SDLoc(Root), Root->getValueType(0),
13437                        LHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtLHS),
13438                        RHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtRHS),
13439                        Merge, Mask, VL);
13440   }
13441 };
13442 
13443 /// Check if \p Root follows a pattern Root(ext(LHS), ext(RHS))
13444 /// where `ext` is the same for both LHS and RHS (i.e., both are sext or both
13445 /// are zext) and LHS and RHS can be folded into Root.
13446 /// AllowSExt and AllowZExt define which form `ext` can take in this pattern.
13447 ///
13448 /// \note If the pattern can match with both zext and sext, the returned
13449 /// CombineResult will feature the zext result.
13450 ///
13451 /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
13452 /// can be used to apply the pattern.
13453 static std::optional<CombineResult>
13454 canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
13455                                  const NodeExtensionHelper &RHS, bool AllowSExt,
13456                                  bool AllowZExt, SelectionDAG &DAG,
13457                                  const RISCVSubtarget &Subtarget) {
13458   assert((AllowSExt || AllowZExt) && "Forgot to set what you want?");
13459   if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
13460       !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
13461     return std::nullopt;
13462   if (AllowZExt && LHS.SupportsZExt && RHS.SupportsZExt)
13463     return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
13464                              Root->getOpcode(), /*IsSExt=*/false),
13465                          Root, LHS, /*SExtLHS=*/false, RHS, /*SExtRHS=*/false);
13466   if (AllowSExt && LHS.SupportsSExt && RHS.SupportsSExt)
13467     return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
13468                              Root->getOpcode(), /*IsSExt=*/true),
13469                          Root, LHS, /*SExtLHS=*/true, RHS,
13470                          /*SExtRHS=*/true);
13471   return std::nullopt;
13472 }
13473 
13474 /// Check if \p Root follows a pattern Root(ext(LHS), ext(RHS))
13475 /// where `ext` is the same for both LHS and RHS (i.e., both are sext or both
13476 /// are zext) and LHS and RHS can be folded into Root.
13477 ///
13478 /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
13479 /// can be used to apply the pattern.
13480 static std::optional<CombineResult>
13481 canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
13482                              const NodeExtensionHelper &RHS, SelectionDAG &DAG,
13483                              const RISCVSubtarget &Subtarget) {
13484   return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
13485                                           /*AllowZExt=*/true, DAG, Subtarget);
13486 }
13487 
13488 /// Check if \p Root follows a pattern Root(LHS, ext(RHS))
13489 ///
13490 /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
13491 /// can be used to apply the pattern.
13492 static std::optional<CombineResult>
13493 canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
13494               const NodeExtensionHelper &RHS, SelectionDAG &DAG,
13495               const RISCVSubtarget &Subtarget) {
13496   if (!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
13497     return std::nullopt;
13498 
13499   // FIXME: Is it useful to form a vwadd.wx or vwsub.wx if it removes a scalar
13500   // sext/zext?
13501   // Control this behavior behind an option (AllowSplatInVW_W) for testing
13502   // purposes.
13503   if (RHS.SupportsZExt && (!RHS.isSplat() || AllowSplatInVW_W))
13504     return CombineResult(
13505         NodeExtensionHelper::getWOpcode(Root->getOpcode(), /*IsSExt=*/false),
13506         Root, LHS, /*SExtLHS=*/std::nullopt, RHS, /*SExtRHS=*/false);
13507   if (RHS.SupportsSExt && (!RHS.isSplat() || AllowSplatInVW_W))
13508     return CombineResult(
13509         NodeExtensionHelper::getWOpcode(Root->getOpcode(), /*IsSExt=*/true),
13510         Root, LHS, /*SExtLHS=*/std::nullopt, RHS, /*SExtRHS=*/true);
13511   return std::nullopt;
13512 }
13513 
13514 /// Check if \p Root follows a pattern Root(sext(LHS), sext(RHS))
13515 ///
13516 /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
13517 /// can be used to apply the pattern.
13518 static std::optional<CombineResult>
13519 canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
13520                     const NodeExtensionHelper &RHS, SelectionDAG &DAG,
13521                     const RISCVSubtarget &Subtarget) {
13522   return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
13523                                           /*AllowZExt=*/false, DAG, Subtarget);
13524 }
13525 
13526 /// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
13527 ///
13528 /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
13529 /// can be used to apply the pattern.
13530 static std::optional<CombineResult>
13531 canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
13532                     const NodeExtensionHelper &RHS, SelectionDAG &DAG,
13533                     const RISCVSubtarget &Subtarget) {
13534   return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
13535                                           /*AllowZExt=*/true, DAG, Subtarget);
13536 }
13537 
13538 /// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
13539 ///
13540 /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
13541 /// can be used to apply the pattern.
13542 static std::optional<CombineResult>
13543 canFoldToVW_SU(SDNode *Root, const NodeExtensionHelper &LHS,
13544                const NodeExtensionHelper &RHS, SelectionDAG &DAG,
13545                const RISCVSubtarget &Subtarget) {
13546 
13547   if (!LHS.SupportsSExt || !RHS.SupportsZExt)
13548     return std::nullopt;
13549   if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
13550       !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
13551     return std::nullopt;
13552   return CombineResult(NodeExtensionHelper::getSUOpcode(Root->getOpcode()),
13553                        Root, LHS, /*SExtLHS=*/true, RHS, /*SExtRHS=*/false);
13554 }
13555 
13556 SmallVector<NodeExtensionHelper::CombineToTry>
13557 NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
13558   SmallVector<CombineToTry> Strategies;
13559   switch (Root->getOpcode()) {
13560   case ISD::ADD:
13561   case ISD::SUB:
13562   case RISCVISD::ADD_VL:
13563   case RISCVISD::SUB_VL:
13564     // add|sub -> vwadd(u)|vwsub(u)
13565     Strategies.push_back(canFoldToVWWithSameExtension);
13566     // add|sub -> vwadd(u)_w|vwsub(u)_w
13567     Strategies.push_back(canFoldToVW_W);
13568     break;
13569   case ISD::MUL:
13570   case RISCVISD::MUL_VL:
13571     // mul -> vwmul(u)
13572     Strategies.push_back(canFoldToVWWithSameExtension);
13573     // mul -> vwmulsu
13574     Strategies.push_back(canFoldToVW_SU);
13575     break;
13576   case RISCVISD::VWADD_W_VL:
13577   case RISCVISD::VWSUB_W_VL:
13578     // vwadd_w|vwsub_w -> vwadd|vwsub
13579     Strategies.push_back(canFoldToVWWithSEXT);
13580     break;
13581   case RISCVISD::VWADDU_W_VL:
13582   case RISCVISD::VWSUBU_W_VL:
13583     // vwaddu_w|vwsubu_w -> vwaddu|vwsubu
13584     Strategies.push_back(canFoldToVWWithZEXT);
13585     break;
13586   default:
13587     llvm_unreachable("Unexpected opcode");
13588   }
13589   return Strategies;
13590 }
13591 } // End anonymous namespace.
13592 
13593 /// Combine a binary operation to its equivalent VW or VW_W form.
13594 /// The supported combines are:
13595 /// add_vl -> vwadd(u) | vwadd(u)_w
13596 /// sub_vl -> vwsub(u) | vwsub(u)_w
13597 /// mul_vl -> vwmul(u) | vwmul_su
13598 /// vwadd_w(u) -> vwadd(u)
13599 /// vwsub_w(u) -> vwsub(u)
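/// As a rough illustration of the first case (schematic, not exact node
/// names): (add (sext a), (sext b)) over vectors, where a and b are SEW-wide
/// and the add is 2*SEW-wide, can be selected as a single vwadd.vv on the
/// narrow operands instead of two extends followed by a wide vadd.vv.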
13600 static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
13601                                            TargetLowering::DAGCombinerInfo &DCI,
13602                                            const RISCVSubtarget &Subtarget) {
13603   SelectionDAG &DAG = DCI.DAG;
13604 
13605   if (!NodeExtensionHelper::isSupportedRoot(N, DAG))
13606     return SDValue();
13607 
13608   SmallVector<SDNode *> Worklist;
13609   SmallSet<SDNode *, 8> Inserted;
13610   Worklist.push_back(N);
13611   Inserted.insert(N);
13612   SmallVector<CombineResult> CombinesToApply;
13613 
13614   while (!Worklist.empty()) {
13615     SDNode *Root = Worklist.pop_back_val();
13616     if (!NodeExtensionHelper::isSupportedRoot(Root, DAG))
13617       return SDValue();
13618 
13619     NodeExtensionHelper LHS(N, 0, DAG, Subtarget);
13620     NodeExtensionHelper RHS(N, 1, DAG, Subtarget);
13621     auto AppendUsersIfNeeded = [&Worklist,
13622                                 &Inserted](const NodeExtensionHelper &Op) {
13623       if (Op.needToPromoteOtherUsers()) {
13624         for (SDNode *TheUse : Op.OrigOperand->uses()) {
13625           if (Inserted.insert(TheUse).second)
13626             Worklist.push_back(TheUse);
13627         }
13628       }
13629     };
13630 
13631     // Control the compile time by limiting the number of nodes we look at in
13632     // total.
13633     if (Inserted.size() > ExtensionMaxWebSize)
13634       return SDValue();
13635 
13636     SmallVector<NodeExtensionHelper::CombineToTry> FoldingStrategies =
13637         NodeExtensionHelper::getSupportedFoldings(N);
13638 
13639     assert(!FoldingStrategies.empty() && "Nothing to be folded");
13640     bool Matched = false;
13641     for (int Attempt = 0;
13642          (Attempt != 1 + NodeExtensionHelper::isCommutative(N)) && !Matched;
13643          ++Attempt) {
13644 
13645       for (NodeExtensionHelper::CombineToTry FoldingStrategy :
13646            FoldingStrategies) {
13647         std::optional<CombineResult> Res =
13648             FoldingStrategy(N, LHS, RHS, DAG, Subtarget);
13649         if (Res) {
13650           Matched = true;
13651           CombinesToApply.push_back(*Res);
13652           // All the inputs that are extended need to be folded, otherwise
13653           // we would be leaving the old input (since it may still be used),
13654           // and the new one.
13655           if (Res->SExtLHS.has_value())
13656             AppendUsersIfNeeded(LHS);
13657           if (Res->SExtRHS.has_value())
13658             AppendUsersIfNeeded(RHS);
13659           break;
13660         }
13661       }
13662       std::swap(LHS, RHS);
13663     }
13664     // Right now we take an all-or-nothing approach.
13665     if (!Matched)
13666       return SDValue();
13667   }
13668   // Store the value for the replacement of the input node separately.
13669   SDValue InputRootReplacement;
13670   // We do the RAUW after we materialize all the combines, because some replaced
13671   // nodes may be feeding some of the yet-to-be-replaced nodes. Put differently,
13672   // some of these nodes may appear in the NodeExtensionHelpers of some of the
13673   // yet-to-be-visited CombinesToApply roots.
13674   SmallVector<std::pair<SDValue, SDValue>> ValuesToReplace;
13675   ValuesToReplace.reserve(CombinesToApply.size());
13676   for (CombineResult Res : CombinesToApply) {
13677     SDValue NewValue = Res.materialize(DAG, Subtarget);
13678     if (!InputRootReplacement) {
13679       assert(Res.Root == N &&
13680              "First element is expected to be the current node");
13681       InputRootReplacement = NewValue;
13682     } else {
13683       ValuesToReplace.emplace_back(SDValue(Res.Root, 0), NewValue);
13684     }
13685   }
13686   for (std::pair<SDValue, SDValue> OldNewValues : ValuesToReplace) {
13687     DAG.ReplaceAllUsesOfValueWith(OldNewValues.first, OldNewValues.second);
13688     DCI.AddToWorklist(OldNewValues.second.getNode());
13689   }
13690   return InputRootReplacement;
13691 }
13692 
13693 // Helper function for performMemPairCombine.
13694 // Try to combine the memory loads/stores LSNode1 and LSNode2
13695 // into a single memory pair operation.
13696 static SDValue tryMemPairCombine(SelectionDAG &DAG, LSBaseSDNode *LSNode1,
13697                                  LSBaseSDNode *LSNode2, SDValue BasePtr,
13698                                  uint64_t Imm) {
13699   SmallPtrSet<const SDNode *, 32> Visited;
13700   SmallVector<const SDNode *, 8> Worklist = {LSNode1, LSNode2};
13701 
13702   if (SDNode::hasPredecessorHelper(LSNode1, Visited, Worklist) ||
13703       SDNode::hasPredecessorHelper(LSNode2, Visited, Worklist))
13704     return SDValue();
13705 
13706   MachineFunction &MF = DAG.getMachineFunction();
13707   const RISCVSubtarget &Subtarget = MF.getSubtarget<RISCVSubtarget>();
13708 
13709   // The new operation has twice the width.
13710   MVT XLenVT = Subtarget.getXLenVT();
13711   EVT MemVT = LSNode1->getMemoryVT();
13712   EVT NewMemVT = (MemVT == MVT::i32) ? MVT::i64 : MVT::i128;
13713   MachineMemOperand *MMO = LSNode1->getMemOperand();
13714   MachineMemOperand *NewMMO = MF.getMachineMemOperand(
13715       MMO, MMO->getPointerInfo(), MemVT == MVT::i32 ? 8 : 16);
13716 
13717   if (LSNode1->getOpcode() == ISD::LOAD) {
13718     auto Ext = cast<LoadSDNode>(LSNode1)->getExtensionType();
13719     unsigned Opcode;
13720     if (MemVT == MVT::i32)
13721       Opcode = (Ext == ISD::ZEXTLOAD) ? RISCVISD::TH_LWUD : RISCVISD::TH_LWD;
13722     else
13723       Opcode = RISCVISD::TH_LDD;
13724 
13725     SDValue Res = DAG.getMemIntrinsicNode(
13726         Opcode, SDLoc(LSNode1), DAG.getVTList({XLenVT, XLenVT, MVT::Other}),
13727         {LSNode1->getChain(), BasePtr,
13728          DAG.getConstant(Imm, SDLoc(LSNode1), XLenVT)},
13729         NewMemVT, NewMMO);
13730 
13731     SDValue Node1 =
13732         DAG.getMergeValues({Res.getValue(0), Res.getValue(2)}, SDLoc(LSNode1));
13733     SDValue Node2 =
13734         DAG.getMergeValues({Res.getValue(1), Res.getValue(2)}, SDLoc(LSNode2));
13735 
13736     DAG.ReplaceAllUsesWith(LSNode2, Node2.getNode());
13737     return Node1;
13738   } else {
13739     unsigned Opcode = (MemVT == MVT::i32) ? RISCVISD::TH_SWD : RISCVISD::TH_SDD;
13740 
13741     SDValue Res = DAG.getMemIntrinsicNode(
13742         Opcode, SDLoc(LSNode1), DAG.getVTList(MVT::Other),
13743         {LSNode1->getChain(), LSNode1->getOperand(1), LSNode2->getOperand(1),
13744          BasePtr, DAG.getConstant(Imm, SDLoc(LSNode1), XLenVT)},
13745         NewMemVT, NewMMO);
13746 
13747     DAG.ReplaceAllUsesWith(LSNode2, Res.getNode());
13748     return Res;
13749   }
13750 }
13751 
13752 // Try to combine two adjacent loads/stores to a single pair instruction from
13753 // the XTHeadMemPair vendor extension.
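// As an illustration (assuming XTHeadMemPair is available): two i32 loads
// from Base+0 and Base+4 can be merged into one th.lwd, and two i64 loads
// from Base+0 and Base+8 into one th.ldd, as long as the shared offset fits
// the encoding constraints checked below.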
13754 static SDValue performMemPairCombine(SDNode *N,
13755                                      TargetLowering::DAGCombinerInfo &DCI) {
13756   SelectionDAG &DAG = DCI.DAG;
13757   MachineFunction &MF = DAG.getMachineFunction();
13758   const RISCVSubtarget &Subtarget = MF.getSubtarget<RISCVSubtarget>();
13759 
13760   // Target does not support load/store pair.
13761   if (!Subtarget.hasVendorXTHeadMemPair())
13762     return SDValue();
13763 
13764   LSBaseSDNode *LSNode1 = cast<LSBaseSDNode>(N);
13765   EVT MemVT = LSNode1->getMemoryVT();
13766   unsigned OpNum = LSNode1->getOpcode() == ISD::LOAD ? 1 : 2;
13767 
13768   // No volatile, indexed or atomic loads/stores.
13769   if (!LSNode1->isSimple() || LSNode1->isIndexed())
13770     return SDValue();
13771 
13772   // Function to get a base + constant representation from a memory value.
13773   auto ExtractBaseAndOffset = [](SDValue Ptr) -> std::pair<SDValue, uint64_t> {
13774     if (Ptr->getOpcode() == ISD::ADD)
13775       if (auto *C1 = dyn_cast<ConstantSDNode>(Ptr->getOperand(1)))
13776         return {Ptr->getOperand(0), C1->getZExtValue()};
13777     return {Ptr, 0};
13778   };
13779 
13780   auto [Base1, Offset1] = ExtractBaseAndOffset(LSNode1->getOperand(OpNum));
13781 
13782   SDValue Chain = N->getOperand(0);
13783   for (SDNode::use_iterator UI = Chain->use_begin(), UE = Chain->use_end();
13784        UI != UE; ++UI) {
13785     SDUse &Use = UI.getUse();
13786     if (Use.getUser() != N && Use.getResNo() == 0 &&
13787         Use.getUser()->getOpcode() == N->getOpcode()) {
13788       LSBaseSDNode *LSNode2 = cast<LSBaseSDNode>(Use.getUser());
13789 
13790       // No volatile, indexed or atomic loads/stores.
13791       if (!LSNode2->isSimple() || LSNode2->isIndexed())
13792         continue;
13793 
13794       // Check if LSNode1 and LSNode2 have the same type and extension.
13795       if (LSNode1->getOpcode() == ISD::LOAD)
13796         if (cast<LoadSDNode>(LSNode2)->getExtensionType() !=
13797             cast<LoadSDNode>(LSNode1)->getExtensionType())
13798           continue;
13799 
13800       if (LSNode1->getMemoryVT() != LSNode2->getMemoryVT())
13801         continue;
13802 
13803       auto [Base2, Offset2] = ExtractBaseAndOffset(LSNode2->getOperand(OpNum));
13804 
13805       // Check if the base pointer is the same for both instructions.
13806       if (Base1 != Base2)
13807         continue;
13808 
13809       // Check if the offsets match the XTHeadMemPair encoding constraints.
13810       bool Valid = false;
13811       if (MemVT == MVT::i32) {
13812         // Check for adjacent i32 values and a 2-bit index.
13813         if ((Offset1 + 4 == Offset2) && isShiftedUInt<2, 3>(Offset1))
13814           Valid = true;
13815       } else if (MemVT == MVT::i64) {
13816         // Check for adjacent i64 values and a 2-bit index.
13817         if ((Offset1 + 8 == Offset2) && isShiftedUInt<2, 4>(Offset1))
13818           Valid = true;
13819       }
13820 
13821       if (!Valid)
13822         continue;
13823 
13824       // Try to combine.
13825       if (SDValue Res =
13826               tryMemPairCombine(DAG, LSNode1, LSNode2, Base1, Offset1))
13827         return Res;
13828     }
13829   }
13830 
13831   return SDValue();
13832 }
13833 
13834 // Fold
13835 //   (fp_to_int (froundeven X)) -> fcvt X, rne
13836 //   (fp_to_int (ftrunc X))     -> fcvt X, rtz
13837 //   (fp_to_int (ffloor X))     -> fcvt X, rdn
13838 //   (fp_to_int (fceil X))      -> fcvt X, rup
13839 //   (fp_to_int (fround X))     -> fcvt X, rmm
13840 //   (fp_to_int (frint X))      -> fcvt X
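// For example (scalar sketch): (fp_to_sint (ffloor X)) on an f32 input can be
// selected as a single "fcvt.w.s rd, fa0, rdn", folding the rounding step into
// the conversion's static rounding mode.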
13841 static SDValue performFP_TO_INTCombine(SDNode *N,
13842                                        TargetLowering::DAGCombinerInfo &DCI,
13843                                        const RISCVSubtarget &Subtarget) {
13844   SelectionDAG &DAG = DCI.DAG;
13845   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
13846   MVT XLenVT = Subtarget.getXLenVT();
13847 
13848   SDValue Src = N->getOperand(0);
13849 
13850   // Don't do this for strict-fp Src.
13851   if (Src->isStrictFPOpcode() || Src->isTargetStrictFPOpcode())
13852     return SDValue();
13853 
13854   // Ensure the FP type is legal.
13855   if (!TLI.isTypeLegal(Src.getValueType()))
13856     return SDValue();
13857 
13858   // Don't do this for f16 with Zfhmin and not Zfh.
13859   if (Src.getValueType() == MVT::f16 && !Subtarget.hasStdExtZfh())
13860     return SDValue();
13861 
13862   RISCVFPRndMode::RoundingMode FRM = matchRoundingOp(Src.getOpcode());
13863   // If the result is invalid, we didn't find a foldable instruction.
13864   if (FRM == RISCVFPRndMode::Invalid)
13865     return SDValue();
13866 
13867   SDLoc DL(N);
13868   bool IsSigned = N->getOpcode() == ISD::FP_TO_SINT;
13869   EVT VT = N->getValueType(0);
13870 
13871   if (VT.isVector() && TLI.isTypeLegal(VT)) {
13872     MVT SrcVT = Src.getSimpleValueType();
13873     MVT SrcContainerVT = SrcVT;
13874     MVT ContainerVT = VT.getSimpleVT();
13875     SDValue XVal = Src.getOperand(0);
13876 
13877     // For widening and narrowing conversions we just combine it into a
13878     // VFCVT_..._VL node, as there are no specific VFWCVT/VFNCVT VL nodes. They
13879     // end up getting lowered to their appropriate pseudo instructions based on
13880     // their operand types
13881     if (VT.getScalarSizeInBits() > SrcVT.getScalarSizeInBits() * 2 ||
13882         VT.getScalarSizeInBits() * 2 < SrcVT.getScalarSizeInBits())
13883       return SDValue();
13884 
13885     // Make fixed-length vectors scalable first
13886     if (SrcVT.isFixedLengthVector()) {
13887       SrcContainerVT = getContainerForFixedLengthVector(DAG, SrcVT, Subtarget);
13888       XVal = convertToScalableVector(SrcContainerVT, XVal, DAG, Subtarget);
13889       ContainerVT =
13890           getContainerForFixedLengthVector(DAG, ContainerVT, Subtarget);
13891     }
13892 
13893     auto [Mask, VL] =
13894         getDefaultVLOps(SrcVT, SrcContainerVT, DL, DAG, Subtarget);
13895 
13896     SDValue FpToInt;
13897     if (FRM == RISCVFPRndMode::RTZ) {
13898       // Use the dedicated trunc static rounding mode if we're truncating so we
13899       // don't need to generate calls to fsrmi/fsrm
13900       unsigned Opc =
13901           IsSigned ? RISCVISD::VFCVT_RTZ_X_F_VL : RISCVISD::VFCVT_RTZ_XU_F_VL;
13902       FpToInt = DAG.getNode(Opc, DL, ContainerVT, XVal, Mask, VL);
13903     } else if (FRM == RISCVFPRndMode::DYN) {
13904       unsigned Opc =
13905           IsSigned ? RISCVISD::VFCVT_X_F_VL : RISCVISD::VFCVT_XU_F_VL;
13906       FpToInt = DAG.getNode(Opc, DL, ContainerVT, XVal, Mask, VL);
13907     } else {
13908       unsigned Opc =
13909           IsSigned ? RISCVISD::VFCVT_RM_X_F_VL : RISCVISD::VFCVT_RM_XU_F_VL;
13910       FpToInt = DAG.getNode(Opc, DL, ContainerVT, XVal, Mask,
13911                             DAG.getTargetConstant(FRM, DL, XLenVT), VL);
13912     }
13913 
13914     // If converted from fixed-length to scalable, convert back
13915     if (VT.isFixedLengthVector())
13916       FpToInt = convertFromScalableVector(VT, FpToInt, DAG, Subtarget);
13917 
13918     return FpToInt;
13919   }
13920 
13921   // Only handle XLen or i32 types. Other types narrower than XLen will
13922   // eventually be legalized to XLenVT.
13923   if (VT != MVT::i32 && VT != XLenVT)
13924     return SDValue();
13925 
13926   unsigned Opc;
13927   if (VT == XLenVT)
13928     Opc = IsSigned ? RISCVISD::FCVT_X : RISCVISD::FCVT_XU;
13929   else
13930     Opc = IsSigned ? RISCVISD::FCVT_W_RV64 : RISCVISD::FCVT_WU_RV64;
13931 
13932   SDValue FpToInt = DAG.getNode(Opc, DL, XLenVT, Src.getOperand(0),
13933                                 DAG.getTargetConstant(FRM, DL, XLenVT));
13934   return DAG.getNode(ISD::TRUNCATE, DL, VT, FpToInt);
13935 }
13936 
13937 // Fold
13938 //   (fp_to_int_sat (froundeven X)) -> (select X == nan, 0, (fcvt X, rne))
13939 //   (fp_to_int_sat (ftrunc X))     -> (select X == nan, 0, (fcvt X, rtz))
13940 //   (fp_to_int_sat (ffloor X))     -> (select X == nan, 0, (fcvt X, rdn))
13941 //   (fp_to_int_sat (fceil X))      -> (select X == nan, 0, (fcvt X, rup))
13942 //   (fp_to_int_sat (fround X))     -> (select X == nan, 0, (fcvt X, rmm))
13943 //   (fp_to_int_sat (frint X))      -> (select X == nan, 0, (fcvt X, dyn))
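// For example (sketch): (fp_to_sint_sat (ftrunc X), i32) becomes an fcvt.w.s
// with the rtz rounding mode, followed by a select on an unordered compare of
// X with itself so that NaN inputs produce 0, as fp_to_int_sat requires.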
13944 static SDValue performFP_TO_INT_SATCombine(SDNode *N,
13945                                        TargetLowering::DAGCombinerInfo &DCI,
13946                                        const RISCVSubtarget &Subtarget) {
13947   SelectionDAG &DAG = DCI.DAG;
13948   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
13949   MVT XLenVT = Subtarget.getXLenVT();
13950 
13951   // Only handle XLen types. Other types narrower than XLen will eventually be
13952   // legalized to XLenVT.
13953   EVT DstVT = N->getValueType(0);
13954   if (DstVT != XLenVT)
13955     return SDValue();
13956 
13957   SDValue Src = N->getOperand(0);
13958 
13959   // Don't do this for strict-fp Src.
13960   if (Src->isStrictFPOpcode() || Src->isTargetStrictFPOpcode())
13961     return SDValue();
13962 
13963   // Ensure the FP type is also legal.
13964   if (!TLI.isTypeLegal(Src.getValueType()))
13965     return SDValue();
13966 
13967   // Don't do this for f16 with Zfhmin and not Zfh.
13968   if (Src.getValueType() == MVT::f16 && !Subtarget.hasStdExtZfh())
13969     return SDValue();
13970 
13971   EVT SatVT = cast<VTSDNode>(N->getOperand(1))->getVT();
13972 
13973   RISCVFPRndMode::RoundingMode FRM = matchRoundingOp(Src.getOpcode());
13974   if (FRM == RISCVFPRndMode::Invalid)
13975     return SDValue();
13976 
13977   bool IsSigned = N->getOpcode() == ISD::FP_TO_SINT_SAT;
13978 
13979   unsigned Opc;
13980   if (SatVT == DstVT)
13981     Opc = IsSigned ? RISCVISD::FCVT_X : RISCVISD::FCVT_XU;
13982   else if (DstVT == MVT::i64 && SatVT == MVT::i32)
13983     Opc = IsSigned ? RISCVISD::FCVT_W_RV64 : RISCVISD::FCVT_WU_RV64;
13984   else
13985     return SDValue();
13986   // FIXME: Support other SatVTs by clamping before or after the conversion.
13987 
13988   Src = Src.getOperand(0);
13989 
13990   SDLoc DL(N);
13991   SDValue FpToInt = DAG.getNode(Opc, DL, XLenVT, Src,
13992                                 DAG.getTargetConstant(FRM, DL, XLenVT));
13993 
13994   // fcvt.wu.* sign extends bit 31 on RV64. FP_TO_UINT_SAT expects to zero
13995   // extend.
13996   if (Opc == RISCVISD::FCVT_WU_RV64)
13997     FpToInt = DAG.getZeroExtendInReg(FpToInt, DL, MVT::i32);
13998 
13999   // RISC-V FP-to-int conversions saturate to the destination register size, but
14000   // don't produce 0 for nan.
14001   SDValue ZeroInt = DAG.getConstant(0, DL, DstVT);
14002   return DAG.getSelectCC(DL, Src, Src, ZeroInt, FpToInt, ISD::CondCode::SETUO);
14003 }
14004 
14005 // Combine (bitreverse (bswap X)) to RISCVISD::BREV8 if the type is smaller
14006 // than XLenVT.
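// For example (sketch for i16): bswap swaps the two bytes, and bitreverse then
// reverses all 16 bits, so each byte ends up back in its original position
// with only its bits reversed, which is exactly what brev8 computes.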
14007 static SDValue performBITREVERSECombine(SDNode *N, SelectionDAG &DAG,
14008                                         const RISCVSubtarget &Subtarget) {
14009   assert(Subtarget.hasStdExtZbkb() && "Unexpected extension");
14010 
14011   SDValue Src = N->getOperand(0);
14012   if (Src.getOpcode() != ISD::BSWAP)
14013     return SDValue();
14014 
14015   EVT VT = N->getValueType(0);
14016   if (!VT.isScalarInteger() || VT.getSizeInBits() >= Subtarget.getXLen() ||
14017       !llvm::has_single_bit<uint32_t>(VT.getSizeInBits()))
14018     return SDValue();
14019 
14020   SDLoc DL(N);
14021   return DAG.getNode(RISCVISD::BREV8, DL, VT, Src.getOperand(0));
14022 }
14023 
14024 // Convert from one FMA opcode to another based on whether we are negating the
14025 // multiply result and/or the accumulator.
14026 // NOTE: Only supports RVV operations with VL.
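// Illustration (assuming the usual (a * b) + c semantics for VFMADD_VL):
// negating the product turns VFMADD into VFNMSUB, i.e. -(a * b) + c, while
// negating the accumulator turns VFMADD into VFMSUB, i.e. (a * b) - c. The
// two switches below encode exactly these transitions.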
14027 static unsigned negateFMAOpcode(unsigned Opcode, bool NegMul, bool NegAcc) {
14028   // Negating the multiply result changes ADD<->SUB and toggles 'N'.
14029   if (NegMul) {
14030     // clang-format off
14031     switch (Opcode) {
14032     default: llvm_unreachable("Unexpected opcode");
14033     case RISCVISD::VFMADD_VL:  Opcode = RISCVISD::VFNMSUB_VL; break;
14034     case RISCVISD::VFNMSUB_VL: Opcode = RISCVISD::VFMADD_VL;  break;
14035     case RISCVISD::VFNMADD_VL: Opcode = RISCVISD::VFMSUB_VL;  break;
14036     case RISCVISD::VFMSUB_VL:  Opcode = RISCVISD::VFNMADD_VL; break;
14037     case RISCVISD::STRICT_VFMADD_VL:  Opcode = RISCVISD::STRICT_VFNMSUB_VL; break;
14038     case RISCVISD::STRICT_VFNMSUB_VL: Opcode = RISCVISD::STRICT_VFMADD_VL;  break;
14039     case RISCVISD::STRICT_VFNMADD_VL: Opcode = RISCVISD::STRICT_VFMSUB_VL;  break;
14040     case RISCVISD::STRICT_VFMSUB_VL:  Opcode = RISCVISD::STRICT_VFNMADD_VL; break;
14041     }
14042     // clang-format on
14043   }
14044 
14045   // Negating the accumulator changes ADD<->SUB.
14046   if (NegAcc) {
14047     // clang-format off
14048     switch (Opcode) {
14049     default: llvm_unreachable("Unexpected opcode");
14050     case RISCVISD::VFMADD_VL:  Opcode = RISCVISD::VFMSUB_VL;  break;
14051     case RISCVISD::VFMSUB_VL:  Opcode = RISCVISD::VFMADD_VL;  break;
14052     case RISCVISD::VFNMADD_VL: Opcode = RISCVISD::VFNMSUB_VL; break;
14053     case RISCVISD::VFNMSUB_VL: Opcode = RISCVISD::VFNMADD_VL; break;
14054     case RISCVISD::STRICT_VFMADD_VL:  Opcode = RISCVISD::STRICT_VFMSUB_VL;  break;
14055     case RISCVISD::STRICT_VFMSUB_VL:  Opcode = RISCVISD::STRICT_VFMADD_VL;  break;
14056     case RISCVISD::STRICT_VFNMADD_VL: Opcode = RISCVISD::STRICT_VFNMSUB_VL; break;
14057     case RISCVISD::STRICT_VFNMSUB_VL: Opcode = RISCVISD::STRICT_VFNMADD_VL; break;
14058     }
14059     // clang-format on
14060   }
14061 
14062   return Opcode;
14063 }
14064 
14065 static SDValue combineVFMADD_VLWithVFNEG_VL(SDNode *N, SelectionDAG &DAG) {
14066   // Fold FNEG_VL into FMA opcodes.
14067   // The first operand of strict-fp is chain.
14068   unsigned Offset = N->isTargetStrictFPOpcode();
14069   SDValue A = N->getOperand(0 + Offset);
14070   SDValue B = N->getOperand(1 + Offset);
14071   SDValue C = N->getOperand(2 + Offset);
14072   SDValue Mask = N->getOperand(3 + Offset);
14073   SDValue VL = N->getOperand(4 + Offset);
14074 
14075   auto invertIfNegative = [&Mask, &VL](SDValue &V) {
14076     if (V.getOpcode() == RISCVISD::FNEG_VL && V.getOperand(1) == Mask &&
14077         V.getOperand(2) == VL) {
14078       // Return the negated input.
14079       V = V.getOperand(0);
14080       return true;
14081     }
14082 
14083     return false;
14084   };
14085 
14086   bool NegA = invertIfNegative(A);
14087   bool NegB = invertIfNegative(B);
14088   bool NegC = invertIfNegative(C);
14089 
14090   // If no operands are negated, we're done.
14091   if (!NegA && !NegB && !NegC)
14092     return SDValue();
14093 
14094   unsigned NewOpcode = negateFMAOpcode(N->getOpcode(), NegA != NegB, NegC);
14095   if (N->isTargetStrictFPOpcode())
14096     return DAG.getNode(NewOpcode, SDLoc(N), N->getVTList(),
14097                        {N->getOperand(0), A, B, C, Mask, VL});
14098   return DAG.getNode(NewOpcode, SDLoc(N), N->getValueType(0), A, B, C, Mask,
14099                      VL);
14100 }
14101 
14102 static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG,
14103                                        const RISCVSubtarget &Subtarget) {
14104   if (SDValue V = combineVFMADD_VLWithVFNEG_VL(N, DAG))
14105     return V;
14106 
14107   if (N->getValueType(0).isScalableVector() &&
14108       N->getValueType(0).getVectorElementType() == MVT::f32 &&
14109       (Subtarget.hasVInstructionsF16Minimal() &&
14110        !Subtarget.hasVInstructionsF16())) {
14111     return SDValue();
14112   }
14113 
14114   // FIXME: Ignore strict opcodes for now.
14115   if (N->isTargetStrictFPOpcode())
14116     return SDValue();
14117 
14118   // Try to form widening FMA.
14119   SDValue Op0 = N->getOperand(0);
14120   SDValue Op1 = N->getOperand(1);
14121   SDValue Mask = N->getOperand(3);
14122   SDValue VL = N->getOperand(4);
14123 
14124   if (Op0.getOpcode() != RISCVISD::FP_EXTEND_VL ||
14125       Op1.getOpcode() != RISCVISD::FP_EXTEND_VL)
14126     return SDValue();
14127 
14128   // TODO: Refactor to handle more complex cases similar to
14129   // combineBinOp_VLToVWBinOp_VL.
14130   if ((!Op0.hasOneUse() || !Op1.hasOneUse()) &&
14131       (Op0 != Op1 || !Op0->hasNUsesOfValue(2, 0)))
14132     return SDValue();
14133 
14134   // Check the mask and VL are the same.
14135   if (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL ||
14136       Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL)
14137     return SDValue();
14138 
14139   unsigned NewOpc;
14140   switch (N->getOpcode()) {
14141   default:
14142     llvm_unreachable("Unexpected opcode");
14143   case RISCVISD::VFMADD_VL:
14144     NewOpc = RISCVISD::VFWMADD_VL;
14145     break;
14146   case RISCVISD::VFNMSUB_VL:
14147     NewOpc = RISCVISD::VFWNMSUB_VL;
14148     break;
14149   case RISCVISD::VFNMADD_VL:
14150     NewOpc = RISCVISD::VFWNMADD_VL;
14151     break;
14152   case RISCVISD::VFMSUB_VL:
14153     NewOpc = RISCVISD::VFWMSUB_VL;
14154     break;
14155   }
14156 
14157   Op0 = Op0.getOperand(0);
14158   Op1 = Op1.getOperand(0);
14159 
14160   return DAG.getNode(NewOpc, SDLoc(N), N->getValueType(0), Op0, Op1,
14161                      N->getOperand(2), Mask, VL);
14162 }
14163 
14164 static SDValue performVFMUL_VLCombine(SDNode *N, SelectionDAG &DAG,
14165                                       const RISCVSubtarget &Subtarget) {
14166   if (N->getValueType(0).isScalableVector() &&
14167       N->getValueType(0).getVectorElementType() == MVT::f32 &&
14168       (Subtarget.hasVInstructionsF16Minimal() &&
14169        !Subtarget.hasVInstructionsF16())) {
14170     return SDValue();
14171   }
14172 
14173   // FIXME: Ignore strict opcodes for now.
14174   assert(!N->isTargetStrictFPOpcode() && "Unexpected opcode");
14175 
14176   // Try to form widening multiply.
14177   SDValue Op0 = N->getOperand(0);
14178   SDValue Op1 = N->getOperand(1);
14179   SDValue Merge = N->getOperand(2);
14180   SDValue Mask = N->getOperand(3);
14181   SDValue VL = N->getOperand(4);
14182 
14183   if (Op0.getOpcode() != RISCVISD::FP_EXTEND_VL ||
14184       Op1.getOpcode() != RISCVISD::FP_EXTEND_VL)
14185     return SDValue();
14186 
14187   // TODO: Refactor to handle more complex cases similar to
14188   // combineBinOp_VLToVWBinOp_VL.
14189   if ((!Op0.hasOneUse() || !Op1.hasOneUse()) &&
14190       (Op0 != Op1 || !Op0->hasNUsesOfValue(2, 0)))
14191     return SDValue();
14192 
14193   // Check the mask and VL are the same.
14194   if (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL ||
14195       Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL)
14196     return SDValue();
14197 
14198   Op0 = Op0.getOperand(0);
14199   Op1 = Op1.getOperand(0);
14200 
14201   return DAG.getNode(RISCVISD::VFWMUL_VL, SDLoc(N), N->getValueType(0), Op0,
14202                      Op1, Merge, Mask, VL);
14203 }
14204 
14205 static SDValue performFADDSUB_VLCombine(SDNode *N, SelectionDAG &DAG,
14206                                         const RISCVSubtarget &Subtarget) {
14207   if (N->getValueType(0).isScalableVector() &&
14208       N->getValueType(0).getVectorElementType() == MVT::f32 &&
14209       (Subtarget.hasVInstructionsF16Minimal() &&
14210        !Subtarget.hasVInstructionsF16())) {
14211     return SDValue();
14212   }
14213 
14214   SDValue Op0 = N->getOperand(0);
14215   SDValue Op1 = N->getOperand(1);
14216   SDValue Merge = N->getOperand(2);
14217   SDValue Mask = N->getOperand(3);
14218   SDValue VL = N->getOperand(4);
14219 
14220   bool IsAdd = N->getOpcode() == RISCVISD::FADD_VL;
14221 
14222   // Look for foldable FP_EXTENDS.
14223   bool Op0IsExtend =
14224       Op0.getOpcode() == RISCVISD::FP_EXTEND_VL &&
14225       (Op0.hasOneUse() || (Op0 == Op1 && Op0->hasNUsesOfValue(2, 0)));
14226   bool Op1IsExtend =
14227       (Op0 == Op1 && Op0IsExtend) ||
14228       (Op1.getOpcode() == RISCVISD::FP_EXTEND_VL && Op1.hasOneUse());
14229 
14230   // Check the mask and VL.
14231   if (Op0IsExtend && (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL))
14232     Op0IsExtend = false;
14233   if (Op1IsExtend && (Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL))
14234     Op1IsExtend = false;
14235 
14236   // Canonicalize.
14237   if (!Op1IsExtend) {
14238     // Sub requires at least operand 1 to be an extend.
14239     if (!IsAdd)
14240       return SDValue();
14241 
14242     // Add is commutative; if the other operand is foldable, swap the operands.
14243     if (!Op0IsExtend)
14244       return SDValue();
14245 
14246     std::swap(Op0, Op1);
14247     std::swap(Op0IsExtend, Op1IsExtend);
14248   }
14249 
14250   // Op1 is a foldable extend. Op0 might be foldable.
14251   Op1 = Op1.getOperand(0);
14252   if (Op0IsExtend)
14253     Op0 = Op0.getOperand(0);
14254 
14255   unsigned Opc;
14256   if (IsAdd)
14257     Opc = Op0IsExtend ? RISCVISD::VFWADD_VL : RISCVISD::VFWADD_W_VL;
14258   else
14259     Opc = Op0IsExtend ? RISCVISD::VFWSUB_VL : RISCVISD::VFWSUB_W_VL;
14260 
14261   return DAG.getNode(Opc, SDLoc(N), N->getValueType(0), Op0, Op1, Merge, Mask,
14262                      VL);
14263 }
14264 
14265 static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
14266                                  const RISCVSubtarget &Subtarget) {
14267   assert(N->getOpcode() == ISD::SRA && "Unexpected opcode");
14268 
14269   if (N->getValueType(0) != MVT::i64 || !Subtarget.is64Bit())
14270     return SDValue();
14271 
14272   if (!isa<ConstantSDNode>(N->getOperand(1)))
14273     return SDValue();
14274   uint64_t ShAmt = N->getConstantOperandVal(1);
14275   if (ShAmt > 32)
14276     return SDValue();
14277 
14278   SDValue N0 = N->getOperand(0);
14279 
14280   // Combine (sra (sext_inreg (shl X, C1), i32), C2) ->
14281   // (sra (shl X, C1+32), C2+32) so it gets selected as SLLI+SRAI instead of
14282   // SLLIW+SRAIW. SLLI+SRAI have compressed forms.
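  // For example (sketch with concrete constants): with C1 == 2 and C2 == 3
  // this becomes (sra (shl X, 34), 35), i.e. slli + srai, both of which have
  // compressed encodings, unlike slliw + sraiw.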
14283   if (ShAmt < 32 &&
14284       N0.getOpcode() == ISD::SIGN_EXTEND_INREG && N0.hasOneUse() &&
14285       cast<VTSDNode>(N0.getOperand(1))->getVT() == MVT::i32 &&
14286       N0.getOperand(0).getOpcode() == ISD::SHL && N0.getOperand(0).hasOneUse() &&
14287       isa<ConstantSDNode>(N0.getOperand(0).getOperand(1))) {
14288     uint64_t LShAmt = N0.getOperand(0).getConstantOperandVal(1);
14289     if (LShAmt < 32) {
14290       SDLoc ShlDL(N0.getOperand(0));
14291       SDValue Shl = DAG.getNode(ISD::SHL, ShlDL, MVT::i64,
14292                                 N0.getOperand(0).getOperand(0),
14293                                 DAG.getConstant(LShAmt + 32, ShlDL, MVT::i64));
14294       SDLoc DL(N);
14295       return DAG.getNode(ISD::SRA, DL, MVT::i64, Shl,
14296                          DAG.getConstant(ShAmt + 32, DL, MVT::i64));
14297     }
14298   }
14299 
14300   // Combine (sra (shl X, 32), 32 - C) -> (shl (sext_inreg X, i32), C)
14301   // FIXME: Should this be a generic combine? There's a similar combine on X86.
14302   //
14303   // Also try these folds where an add or sub is in the middle.
14304   // (sra (add (shl X, 32), C1), 32 - C) -> (shl (sext_inreg (add X, C1)), C)
14305   // (sra (sub C1, (shl X, 32)), 32 - C) -> (shl (sext_inreg (sub C1, X)), C)
14306   SDValue Shl;
14307   ConstantSDNode *AddC = nullptr;
14308 
14309   // We might have an ADD or SUB between the SRA and SHL.
14310   bool IsAdd = N0.getOpcode() == ISD::ADD;
14311   if ((IsAdd || N0.getOpcode() == ISD::SUB)) {
14312     // Other operand needs to be a constant we can modify.
14313     AddC = dyn_cast<ConstantSDNode>(N0.getOperand(IsAdd ? 1 : 0));
14314     if (!AddC)
14315       return SDValue();
14316 
14317     // AddC needs to have at least 32 trailing zeros.
14318     if (AddC->getAPIntValue().countr_zero() < 32)
14319       return SDValue();
14320 
14321     // All users should be a shift by constant less than or equal to 32. This
14322     // ensures we'll do this optimization for each of them to produce an
14323     // add/sub+sext_inreg they can all share.
14324     for (SDNode *U : N0->uses()) {
14325       if (U->getOpcode() != ISD::SRA ||
14326           !isa<ConstantSDNode>(U->getOperand(1)) ||
14327           U->getConstantOperandVal(1) > 32)
14328         return SDValue();
14329     }
14330 
14331     Shl = N0.getOperand(IsAdd ? 0 : 1);
14332   } else {
14333     // Not an ADD or SUB.
14334     Shl = N0;
14335   }
14336 
14337   // Look for a shift left by 32.
14338   if (Shl.getOpcode() != ISD::SHL || !isa<ConstantSDNode>(Shl.getOperand(1)) ||
14339       Shl.getConstantOperandVal(1) != 32)
14340     return SDValue();
14341 
14342   // If we didn't look through an add/sub, then the shl should have one use.
14343   // If we did look through an add/sub, the sext_inreg we create is free, so
14344   // we're only creating 2 new instructions. It's enough to only remove the
14345   // original sra+add/sub.
14346   if (!AddC && !Shl.hasOneUse())
14347     return SDValue();
14348 
14349   SDLoc DL(N);
14350   SDValue In = Shl.getOperand(0);
14351 
14352   // If we looked through an ADD or SUB, we need to rebuild it with the shifted
14353   // constant.
14354   if (AddC) {
14355     SDValue ShiftedAddC =
14356         DAG.getConstant(AddC->getAPIntValue().lshr(32), DL, MVT::i64);
14357     if (IsAdd)
14358       In = DAG.getNode(ISD::ADD, DL, MVT::i64, In, ShiftedAddC);
14359     else
14360       In = DAG.getNode(ISD::SUB, DL, MVT::i64, ShiftedAddC, In);
14361   }
14362 
14363   SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i64, In,
14364                              DAG.getValueType(MVT::i32));
14365   if (ShAmt == 32)
14366     return SExt;
14367 
14368   return DAG.getNode(
14369       ISD::SHL, DL, MVT::i64, SExt,
14370       DAG.getConstant(32 - ShAmt, DL, MVT::i64));
14371 }
14372 
14373 // Invert (and/or (setcc cc X, Y), (xor Z, 1)) to (or/and (setcc !cc X, Y), Z)
14374 // if the result is used as the condition of a br_cc or select_cc we can
14375 // invert, inverting the setcc is free, and Z is 0/1. Caller will invert the
14376 // br_cc/select_cc.
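// For example (sketch): a branch taken when (and (setcc eq X, Y), (xor Z, 1))
// is nonzero is, by De Morgan, the same as a branch taken when
// (or (setcc ne X, Y), Z) is zero, so the caller flips its br_cc/select_cc
// condition.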
14377 static SDValue tryDemorganOfBooleanCondition(SDValue Cond, SelectionDAG &DAG) {
14378   bool IsAnd = Cond.getOpcode() == ISD::AND;
14379   if (!IsAnd && Cond.getOpcode() != ISD::OR)
14380     return SDValue();
14381 
14382   if (!Cond.hasOneUse())
14383     return SDValue();
14384 
14385   SDValue Setcc = Cond.getOperand(0);
14386   SDValue Xor = Cond.getOperand(1);
14387   // Canonicalize setcc to LHS.
14388   if (Setcc.getOpcode() != ISD::SETCC)
14389     std::swap(Setcc, Xor);
14390   // LHS should be a setcc and RHS should be an xor.
14391   if (Setcc.getOpcode() != ISD::SETCC || !Setcc.hasOneUse() ||
14392       Xor.getOpcode() != ISD::XOR || !Xor.hasOneUse())
14393     return SDValue();
14394 
14395   // If the condition is an And, SimplifyDemandedBits may have changed
14396   // (xor Z, 1) to (not Z).
14397   SDValue Xor1 = Xor.getOperand(1);
14398   if (!isOneConstant(Xor1) && !(IsAnd && isAllOnesConstant(Xor1)))
14399     return SDValue();
14400 
14401   EVT VT = Cond.getValueType();
14402   SDValue Xor0 = Xor.getOperand(0);
14403 
14404   // The LHS of the xor needs to be 0/1.
14405   APInt Mask = APInt::getBitsSetFrom(VT.getSizeInBits(), 1);
14406   if (!DAG.MaskedValueIsZero(Xor0, Mask))
14407     return SDValue();
14408 
14409   // We can only invert integer setccs.
14410   EVT SetCCOpVT = Setcc.getOperand(0).getValueType();
14411   if (!SetCCOpVT.isScalarInteger())
14412     return SDValue();
14413 
14414   ISD::CondCode CCVal = cast<CondCodeSDNode>(Setcc.getOperand(2))->get();
14415   if (ISD::isIntEqualitySetCC(CCVal)) {
14416     CCVal = ISD::getSetCCInverse(CCVal, SetCCOpVT);
14417     Setcc = DAG.getSetCC(SDLoc(Setcc), VT, Setcc.getOperand(0),
14418                          Setcc.getOperand(1), CCVal);
14419   } else if (CCVal == ISD::SETLT && isNullConstant(Setcc.getOperand(0))) {
14420     // Invert (setlt 0, X) by converting to (setlt X, 1).
14421     Setcc = DAG.getSetCC(SDLoc(Setcc), VT, Setcc.getOperand(1),
14422                          DAG.getConstant(1, SDLoc(Setcc), VT), CCVal);
14423   } else if (CCVal == ISD::SETLT && isOneConstant(Setcc.getOperand(1))) {
14424     // Invert (setlt X, 1) by converting to (setlt 0, X).
14425     Setcc = DAG.getSetCC(SDLoc(Setcc), VT,
14426                          DAG.getConstant(0, SDLoc(Setcc), VT),
14427                          Setcc.getOperand(0), CCVal);
14428   } else
14429     return SDValue();
14430 
14431   unsigned Opc = IsAnd ? ISD::OR : ISD::AND;
14432   return DAG.getNode(Opc, SDLoc(Cond), VT, Setcc, Xor.getOperand(0));
14433 }
14434 
14435 // Perform common combines for BR_CC and SELECT_CC conditions.
14436 static bool combine_CC(SDValue &LHS, SDValue &RHS, SDValue &CC, const SDLoc &DL,
14437                        SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
14438   ISD::CondCode CCVal = cast<CondCodeSDNode>(CC)->get();
14439 
14440   // Since an arithmetic right shift always preserves the sign bit, the
14441   // shift can be omitted when comparing against zero.
14442   // Fold setlt (sra X, N), 0 -> setlt X, 0 and
14443   // setge (sra X, N), 0 -> setge X, 0
14444   if (isNullConstant(RHS) && (CCVal == ISD::SETGE || CCVal == ISD::SETLT) &&
14445       LHS.getOpcode() == ISD::SRA) {
14446     LHS = LHS.getOperand(0);
14447     return true;
14448   }
14449 
14450   if (!ISD::isIntEqualitySetCC(CCVal))
14451     return false;
14452 
14453   // Fold ((setlt X, Y), 0, ne) -> (X, Y, lt)
14454   // Sometimes the setcc is introduced after br_cc/select_cc has been formed.
14455   if (LHS.getOpcode() == ISD::SETCC && isNullConstant(RHS) &&
14456       LHS.getOperand(0).getValueType() == Subtarget.getXLenVT()) {
14457     // If we're looking for eq 0 instead of ne 0, we need to invert the
14458     // condition.
14459     bool Invert = CCVal == ISD::SETEQ;
14460     CCVal = cast<CondCodeSDNode>(LHS.getOperand(2))->get();
14461     if (Invert)
14462       CCVal = ISD::getSetCCInverse(CCVal, LHS.getValueType());
14463 
14464     RHS = LHS.getOperand(1);
14465     LHS = LHS.getOperand(0);
14466     translateSetCCForBranch(DL, LHS, RHS, CCVal, DAG);
14467 
14468     CC = DAG.getCondCode(CCVal);
14469     return true;
14470   }
14471 
14472   // Fold ((xor X, Y), 0, eq/ne) -> (X, Y, eq/ne)
14473   if (LHS.getOpcode() == ISD::XOR && isNullConstant(RHS)) {
14474     RHS = LHS.getOperand(1);
14475     LHS = LHS.getOperand(0);
14476     return true;
14477   }
14478 
14479   // Fold ((srl (and X, 1<<C), C), 0, eq/ne) -> ((shl X, XLen-1-C), 0, ge/lt)
14480   if (isNullConstant(RHS) && LHS.getOpcode() == ISD::SRL && LHS.hasOneUse() &&
14481       LHS.getOperand(1).getOpcode() == ISD::Constant) {
14482     SDValue LHS0 = LHS.getOperand(0);
14483     if (LHS0.getOpcode() == ISD::AND &&
14484         LHS0.getOperand(1).getOpcode() == ISD::Constant) {
14485       uint64_t Mask = LHS0.getConstantOperandVal(1);
14486       uint64_t ShAmt = LHS.getConstantOperandVal(1);
14487       if (isPowerOf2_64(Mask) && Log2_64(Mask) == ShAmt) {
14488         CCVal = CCVal == ISD::SETEQ ? ISD::SETGE : ISD::SETLT;
14489         CC = DAG.getCondCode(CCVal);
14490 
14491         ShAmt = LHS.getValueSizeInBits() - 1 - ShAmt;
14492         LHS = LHS0.getOperand(0);
14493         if (ShAmt != 0)
14494           LHS =
14495               DAG.getNode(ISD::SHL, DL, LHS.getValueType(), LHS0.getOperand(0),
14496                           DAG.getConstant(ShAmt, DL, LHS.getValueType()));
14497         return true;
14498       }
14499     }
14500   }
14501 
14502   // (X, 1, setne) -> (X, 0, seteq) if we can prove X is 0/1.
14503   // This can occur when legalizing some floating point comparisons.
14504   APInt Mask = APInt::getBitsSetFrom(LHS.getValueSizeInBits(), 1);
14505   if (isOneConstant(RHS) && DAG.MaskedValueIsZero(LHS, Mask)) {
14506     CCVal = ISD::getSetCCInverse(CCVal, LHS.getValueType());
14507     CC = DAG.getCondCode(CCVal);
14508     RHS = DAG.getConstant(0, DL, LHS.getValueType());
14509     return true;
14510   }
14511 
14512   if (isNullConstant(RHS)) {
14513     if (SDValue NewCond = tryDemorganOfBooleanCondition(LHS, DAG)) {
14514       CCVal = ISD::getSetCCInverse(CCVal, LHS.getValueType());
14515       CC = DAG.getCondCode(CCVal);
14516       LHS = NewCond;
14517       return true;
14518     }
14519   }
14520 
14521   return false;
14522 }
14523 
14524 // Fold
14525 // (select C, (add Y, X), Y) -> (add Y, (select C, X, 0)).
14526 // (select C, (sub Y, X), Y) -> (sub Y, (select C, X, 0)).
14527 // (select C, (or Y, X), Y)  -> (or Y, (select C, X, 0)).
14528 // (select C, (xor Y, X), Y) -> (xor Y, (select C, X, 0)).
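// For example (sketch): (select C, (add Y, X), Y) can be rewritten as
// (add Y, (select C, X, 0)), keeping the add unconditional and selecting
// between X and the identity value instead. For the non-commutative cases
// (sub and the shifts), Y has to be the first operand, and 0 is used as the
// identity for the selected operand.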
14529 static SDValue tryFoldSelectIntoOp(SDNode *N, SelectionDAG &DAG,
14530                                    SDValue TrueVal, SDValue FalseVal,
14531                                    bool Swapped) {
14532   bool Commutative = true;
14533   unsigned Opc = TrueVal.getOpcode();
14534   switch (Opc) {
14535   default:
14536     return SDValue();
14537   case ISD::SHL:
14538   case ISD::SRA:
14539   case ISD::SRL:
14540   case ISD::SUB:
14541     Commutative = false;
14542     break;
14543   case ISD::ADD:
14544   case ISD::OR:
14545   case ISD::XOR:
14546     break;
14547   }
14548 
14549   if (!TrueVal.hasOneUse() || isa<ConstantSDNode>(FalseVal))
14550     return SDValue();
14551 
14552   unsigned OpToFold;
14553   if (FalseVal == TrueVal.getOperand(0))
14554     OpToFold = 0;
14555   else if (Commutative && FalseVal == TrueVal.getOperand(1))
14556     OpToFold = 1;
14557   else
14558     return SDValue();
14559 
14560   EVT VT = N->getValueType(0);
14561   SDLoc DL(N);
14562   SDValue OtherOp = TrueVal.getOperand(1 - OpToFold);
14563   EVT OtherOpVT = OtherOp.getValueType();
14564   SDValue IdentityOperand =
14565       DAG.getNeutralElement(Opc, DL, OtherOpVT, N->getFlags());
14566   if (!Commutative)
14567     IdentityOperand = DAG.getConstant(0, DL, OtherOpVT);
14568   assert(IdentityOperand && "No identity operand!");
14569 
14570   if (Swapped)
14571     std::swap(OtherOp, IdentityOperand);
14572   SDValue NewSel =
14573       DAG.getSelect(DL, OtherOpVT, N->getOperand(0), OtherOp, IdentityOperand);
14574   return DAG.getNode(TrueVal.getOpcode(), DL, VT, FalseVal, NewSel);
14575 }
14576 
14577 // This tries to get rid of the `select` and `icmp` that are being used to
14578 // handle targets that do not support `cttz(0)`/`ctlz(0)`.
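// For example (sketch): (select (seteq X, 0), 0, (cttz X)) can be rewritten
// as (and (cttz X), BitWidth - 1): cttz of zero is defined to return
// BitWidth, and BitWidth & (BitWidth - 1) is 0, so the compare and the select
// both disappear.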
14579 static SDValue foldSelectOfCTTZOrCTLZ(SDNode *N, SelectionDAG &DAG) {
14580   SDValue Cond = N->getOperand(0);
14581 
14582   // This represents either CTTZ or CTLZ instruction.
14583   SDValue CountZeroes;
14584 
14585   SDValue ValOnZero;
14586 
14587   if (Cond.getOpcode() != ISD::SETCC)
14588     return SDValue();
14589 
14590   if (!isNullConstant(Cond->getOperand(1)))
14591     return SDValue();
14592 
14593   ISD::CondCode CCVal = cast<CondCodeSDNode>(Cond->getOperand(2))->get();
14594   if (CCVal == ISD::CondCode::SETEQ) {
14595     CountZeroes = N->getOperand(2);
14596     ValOnZero = N->getOperand(1);
14597   } else if (CCVal == ISD::CondCode::SETNE) {
14598     CountZeroes = N->getOperand(1);
14599     ValOnZero = N->getOperand(2);
14600   } else {
14601     return SDValue();
14602   }
14603 
14604   if (CountZeroes.getOpcode() == ISD::TRUNCATE ||
14605       CountZeroes.getOpcode() == ISD::ZERO_EXTEND)
14606     CountZeroes = CountZeroes.getOperand(0);
14607 
14608   if (CountZeroes.getOpcode() != ISD::CTTZ &&
14609       CountZeroes.getOpcode() != ISD::CTTZ_ZERO_UNDEF &&
14610       CountZeroes.getOpcode() != ISD::CTLZ &&
14611       CountZeroes.getOpcode() != ISD::CTLZ_ZERO_UNDEF)
14612     return SDValue();
14613 
14614   if (!isNullConstant(ValOnZero))
14615     return SDValue();
14616 
14617   SDValue CountZeroesArgument = CountZeroes->getOperand(0);
14618   if (Cond->getOperand(0) != CountZeroesArgument)
14619     return SDValue();
14620 
14621   if (CountZeroes.getOpcode() == ISD::CTTZ_ZERO_UNDEF) {
14622     CountZeroes = DAG.getNode(ISD::CTTZ, SDLoc(CountZeroes),
14623                               CountZeroes.getValueType(), CountZeroesArgument);
14624   } else if (CountZeroes.getOpcode() == ISD::CTLZ_ZERO_UNDEF) {
14625     CountZeroes = DAG.getNode(ISD::CTLZ, SDLoc(CountZeroes),
14626                               CountZeroes.getValueType(), CountZeroesArgument);
14627   }
14628 
14629   unsigned BitWidth = CountZeroes.getValueSizeInBits();
14630   SDValue BitWidthMinusOne =
14631       DAG.getConstant(BitWidth - 1, SDLoc(N), CountZeroes.getValueType());
14632 
14633   auto AndNode = DAG.getNode(ISD::AND, SDLoc(N), CountZeroes.getValueType(),
14634                              CountZeroes, BitWidthMinusOne);
14635   return DAG.getZExtOrTrunc(AndNode, SDLoc(N), N->getValueType(0));
14636 }
14637 
14638 static SDValue useInversedSetcc(SDNode *N, SelectionDAG &DAG,
14639                                 const RISCVSubtarget &Subtarget) {
14640   SDValue Cond = N->getOperand(0);
14641   SDValue True = N->getOperand(1);
14642   SDValue False = N->getOperand(2);
14643   SDLoc DL(N);
14644   EVT VT = N->getValueType(0);
14645   EVT CondVT = Cond.getValueType();
14646 
14647   if (Cond.getOpcode() != ISD::SETCC || !Cond.hasOneUse())
14648     return SDValue();
14649 
14650   // Replace (setcc eq (and x, C)) with (setcc ne (and x, C))) to generate
14651   // BEXTI, where C is a power of 2.
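  // For example (illustrative), with C = 0x10000:
  //   (select (seteq (and x, 0x10000), 0), T, F)
  // becomes (select (setne (and x, 0x10000), 0), F, T), where the setne form
  // can be selected as a single BEXTI feeding the conditional move.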
14652   if (Subtarget.hasStdExtZbs() && VT.isScalarInteger() &&
14653       (Subtarget.hasStdExtZicond() || Subtarget.hasVendorXVentanaCondOps())) {
14654     SDValue LHS = Cond.getOperand(0);
14655     SDValue RHS = Cond.getOperand(1);
14656     ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
14657     if (CC == ISD::SETEQ && LHS.getOpcode() == ISD::AND &&
14658         isa<ConstantSDNode>(LHS.getOperand(1)) && isNullConstant(RHS)) {
14659       const APInt &MaskVal = LHS.getConstantOperandAPInt(1);
14660       if (MaskVal.isPowerOf2() && !MaskVal.isSignedIntN(12))
14661         return DAG.getSelect(DL, VT,
14662                              DAG.getSetCC(DL, CondVT, LHS, RHS, ISD::SETNE),
14663                              False, True);
14664     }
14665   }
14666   return SDValue();
14667 }
14668 
14669 static SDValue performSELECTCombine(SDNode *N, SelectionDAG &DAG,
14670                                     const RISCVSubtarget &Subtarget) {
14671   if (SDValue Folded = foldSelectOfCTTZOrCTLZ(N, DAG))
14672     return Folded;
14673 
14674   if (SDValue V = useInversedSetcc(N, DAG, Subtarget))
14675     return V;
14676 
14677   if (Subtarget.hasConditionalMoveFusion())
14678     return SDValue();
14679 
14680   SDValue TrueVal = N->getOperand(1);
14681   SDValue FalseVal = N->getOperand(2);
14682   if (SDValue V = tryFoldSelectIntoOp(N, DAG, TrueVal, FalseVal, /*Swapped*/false))
14683     return V;
14684   return tryFoldSelectIntoOp(N, DAG, FalseVal, TrueVal, /*Swapped*/true);
14685 }
14686 
14687 /// If we have a build_vector where each lane is binop X, C, where C
14688 /// is a constant (but not necessarily the same constant on all lanes),
14689 /// form binop (build_vector x1, x2, ...), (build_vector c1, c2, c3, ..).
14690 /// We assume that materializing a constant build vector will be no more
14691 /// expensive than performing O(n) binops.
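/// For example (illustrative):
///   (build_vector (add x1, 1), (add x2, 2), (add x3, 3), (add x4, 4))
/// becomes (add (build_vector x1, x2, x3, x4), (build_vector 1, 2, 3, 4)).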
14692 static SDValue performBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG,
14693                                           const RISCVSubtarget &Subtarget,
14694                                           const RISCVTargetLowering &TLI) {
14695   SDLoc DL(N);
14696   EVT VT = N->getValueType(0);
14697 
14698   assert(!VT.isScalableVector() && "unexpected build vector");
14699 
14700   if (VT.getVectorNumElements() == 1)
14701     return SDValue();
14702 
14703   const unsigned Opcode = N->op_begin()->getNode()->getOpcode();
14704   if (!TLI.isBinOp(Opcode))
14705     return SDValue();
14706 
14707   if (!TLI.isOperationLegalOrCustom(Opcode, VT) || !TLI.isTypeLegal(VT))
14708     return SDValue();
14709 
14710   SmallVector<SDValue> LHSOps;
14711   SmallVector<SDValue> RHSOps;
14712   for (SDValue Op : N->ops()) {
14713     if (Op.isUndef()) {
14714       // We can't form a divide or remainder from undef.
14715       if (!DAG.isSafeToSpeculativelyExecute(Opcode))
14716         return SDValue();
14717 
14718       LHSOps.push_back(Op);
14719       RHSOps.push_back(Op);
14720       continue;
14721     }
14722 
14723     // TODO: We can handle operations which have a neutral rhs value
14724     // (e.g. x + 0, a * 1 or a << 0), but we then have to keep track
14725     // of profit in a more explicit manner.
14726     if (Op.getOpcode() != Opcode || !Op.hasOneUse())
14727       return SDValue();
14728 
14729     LHSOps.push_back(Op.getOperand(0));
14730     if (!isa<ConstantSDNode>(Op.getOperand(1)) &&
14731         !isa<ConstantFPSDNode>(Op.getOperand(1)))
14732       return SDValue();
14733     // FIXME: Return failure if the RHS type doesn't match the LHS. Shifts may
14734     // have different LHS and RHS types.
14735     if (Op.getOperand(0).getValueType() != Op.getOperand(1).getValueType())
14736       return SDValue();
14737     RHSOps.push_back(Op.getOperand(1));
14738   }
14739 
14740   return DAG.getNode(Opcode, DL, VT, DAG.getBuildVector(VT, DL, LHSOps),
14741                      DAG.getBuildVector(VT, DL, RHSOps));
14742 }
14743 
14744 static SDValue performINSERT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG,
14745                                                const RISCVSubtarget &Subtarget,
14746                                                const RISCVTargetLowering &TLI) {
14747   SDValue InVec = N->getOperand(0);
14748   SDValue InVal = N->getOperand(1);
14749   SDValue EltNo = N->getOperand(2);
14750   SDLoc DL(N);
14751 
14752   EVT VT = InVec.getValueType();
14753   if (VT.isScalableVector())
14754     return SDValue();
14755 
14756   if (!InVec.hasOneUse())
14757     return SDValue();
14758 
14759   // Given insert_vector_elt (binop a, VecC), (same_binop b, C2), Elt
14760   // move the insert_vector_elts into the arms of the binop.  Note that
14761   // the new RHS must be a constant.
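  // For example (illustrative):
  //   (insert_vector_elt (add v, <1,2,3,4>), (add s, 5), idx)
  // becomes
  //   (add (insert_vector_elt v, s, idx), (insert_vector_elt <1,2,3,4>, 5, idx)).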
14762   const unsigned InVecOpcode = InVec->getOpcode();
14763   if (InVecOpcode == InVal->getOpcode() && TLI.isBinOp(InVecOpcode) &&
14764       InVal.hasOneUse()) {
14765     SDValue InVecLHS = InVec->getOperand(0);
14766     SDValue InVecRHS = InVec->getOperand(1);
14767     SDValue InValLHS = InVal->getOperand(0);
14768     SDValue InValRHS = InVal->getOperand(1);
14769 
14770     if (!ISD::isBuildVectorOfConstantSDNodes(InVecRHS.getNode()))
14771       return SDValue();
14772     if (!isa<ConstantSDNode>(InValRHS) && !isa<ConstantFPSDNode>(InValRHS))
14773       return SDValue();
14774     // FIXME: Return failure if the RHS type doesn't match the LHS. Shifts may
14775     // have different LHS and RHS types.
14776     if (InVec.getOperand(0).getValueType() != InVec.getOperand(1).getValueType())
14777       return SDValue();
14778     SDValue LHS = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT,
14779                               InVecLHS, InValLHS, EltNo);
14780     SDValue RHS = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT,
14781                               InVecRHS, InValRHS, EltNo);
14782     return DAG.getNode(InVecOpcode, DL, VT, LHS, RHS);
14783   }
14784 
14785   // Given insert_vector_elt (concat_vectors ...), InVal, Elt
14786   // move the insert_vector_elt to the source operand of the concat_vector.
14787   if (InVec.getOpcode() != ISD::CONCAT_VECTORS)
14788     return SDValue();
14789 
14790   auto *IndexC = dyn_cast<ConstantSDNode>(EltNo);
14791   if (!IndexC)
14792     return SDValue();
14793   unsigned Elt = IndexC->getZExtValue();
14794 
14795   EVT ConcatVT = InVec.getOperand(0).getValueType();
14796   if (ConcatVT.getVectorElementType() != InVal.getValueType())
14797     return SDValue();
14798   unsigned ConcatNumElts = ConcatVT.getVectorNumElements();
14799   SDValue NewIdx = DAG.getConstant(Elt % ConcatNumElts, DL,
14800                                    EltNo.getValueType());
14801 
14802   unsigned ConcatOpIdx = Elt / ConcatNumElts;
14803   SDValue ConcatOp = InVec.getOperand(ConcatOpIdx);
14804   ConcatOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ConcatVT,
14805                          ConcatOp, InVal, NewIdx);
14806 
14807   SmallVector<SDValue> ConcatOps;
14808   ConcatOps.append(InVec->op_begin(), InVec->op_end());
14809   ConcatOps[ConcatOpIdx] = ConcatOp;
14810   return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
14811 }
14812 
14813 // If we're concatenating a series of vector loads like
14814 // concat_vectors (load v4i8, p+0), (load v4i8, p+n), (load v4i8, p+n*2) ...
14815 // Then we can turn this into a strided load by widening the vector elements
14816 // vlse32 p, stride=n
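// For example (illustrative), four v4i8 loads become a single strided load of
// v4i32 (each i32 lane holding one original v4i8 chunk), which is then bitcast
// back to the type of the original concat_vectors.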
14817 static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
14818                                             const RISCVSubtarget &Subtarget,
14819                                             const RISCVTargetLowering &TLI) {
14820   SDLoc DL(N);
14821   EVT VT = N->getValueType(0);
14822 
14823   // Only perform this combine on legal MVTs.
14824   if (!TLI.isTypeLegal(VT))
14825     return SDValue();
14826 
14827   // TODO: Potentially extend this to scalable vectors
14828   if (VT.isScalableVector())
14829     return SDValue();
14830 
14831   auto *BaseLd = dyn_cast<LoadSDNode>(N->getOperand(0));
14832   if (!BaseLd || !BaseLd->isSimple() || !ISD::isNormalLoad(BaseLd) ||
14833       !SDValue(BaseLd, 0).hasOneUse())
14834     return SDValue();
14835 
14836   EVT BaseLdVT = BaseLd->getValueType(0);
14837 
14838   // Go through the loads and check that they're strided
14839   SmallVector<LoadSDNode *> Lds;
14840   Lds.push_back(BaseLd);
14841   Align Align = BaseLd->getAlign();
14842   for (SDValue Op : N->ops().drop_front()) {
14843     auto *Ld = dyn_cast<LoadSDNode>(Op);
14844     if (!Ld || !Ld->isSimple() || !Op.hasOneUse() ||
14845         Ld->getChain() != BaseLd->getChain() || !ISD::isNormalLoad(Ld) ||
14846         Ld->getValueType(0) != BaseLdVT)
14847       return SDValue();
14848 
14849     Lds.push_back(Ld);
14850 
14851     // The common alignment is the most restrictive (smallest) of all the loads
14852     Align = std::min(Align, Ld->getAlign());
14853   }
14854 
14855   using PtrDiff = std::pair<std::variant<int64_t, SDValue>, bool>;
14856   auto GetPtrDiff = [&DAG](LoadSDNode *Ld1,
14857                            LoadSDNode *Ld2) -> std::optional<PtrDiff> {
14858     // If the load ptrs can be decomposed into a common (Base + Index) with a
14859     // common constant stride, then return the constant stride.
14860     BaseIndexOffset BIO1 = BaseIndexOffset::match(Ld1, DAG);
14861     BaseIndexOffset BIO2 = BaseIndexOffset::match(Ld2, DAG);
14862     if (BIO1.equalBaseIndex(BIO2, DAG))
14863       return {{BIO2.getOffset() - BIO1.getOffset(), false}};
14864 
14865     // Otherwise try to match (add LastPtr, Stride) or (add NextPtr, Stride)
14866     SDValue P1 = Ld1->getBasePtr();
14867     SDValue P2 = Ld2->getBasePtr();
14868     if (P2.getOpcode() == ISD::ADD && P2.getOperand(0) == P1)
14869       return {{P2.getOperand(1), false}};
14870     if (P1.getOpcode() == ISD::ADD && P1.getOperand(0) == P2)
14871       return {{P1.getOperand(1), true}};
14872 
14873     return std::nullopt;
14874   };
14875 
14876   // Get the distance between the first and second loads
14877   auto BaseDiff = GetPtrDiff(Lds[0], Lds[1]);
14878   if (!BaseDiff)
14879     return SDValue();
14880 
14881   // Check all the loads are the same distance apart
14882   for (auto *It = Lds.begin() + 1; It != Lds.end() - 1; It++)
14883     if (GetPtrDiff(*It, *std::next(It)) != BaseDiff)
14884       return SDValue();
14885 
14886   // TODO: At this point, we've successfully matched a generalized gather
14887   // load.  Maybe we should emit that, and then move the specialized
14888   // matchers above and below into a DAG combine?
14889 
14890   // Get the widened scalar type, e.g. v4i8 -> i32
14891   unsigned WideScalarBitWidth =
14892       BaseLdVT.getScalarSizeInBits() * BaseLdVT.getVectorNumElements();
14893   MVT WideScalarVT = MVT::getIntegerVT(WideScalarBitWidth);
14894 
14895   // Get the vector type for the strided load, e.g. 4 x v4i8 -> v4i32
14896   MVT WideVecVT = MVT::getVectorVT(WideScalarVT, N->getNumOperands());
14897   if (!TLI.isTypeLegal(WideVecVT))
14898     return SDValue();
14899 
14900   // Check that the operation is legal
14901   if (!TLI.isLegalStridedLoadStore(WideVecVT, Align))
14902     return SDValue();
14903 
14904   auto [StrideVariant, MustNegateStride] = *BaseDiff;
14905   SDValue Stride = std::holds_alternative<SDValue>(StrideVariant)
14906                        ? std::get<SDValue>(StrideVariant)
14907                        : DAG.getConstant(std::get<int64_t>(StrideVariant), DL,
14908                                          Lds[0]->getOffset().getValueType());
14909   if (MustNegateStride)
14910     Stride = DAG.getNegative(Stride, DL, Stride.getValueType());
14911 
14912   SDVTList VTs = DAG.getVTList({WideVecVT, MVT::Other});
14913   SDValue IntID =
14914     DAG.getTargetConstant(Intrinsic::riscv_masked_strided_load, DL,
14915                           Subtarget.getXLenVT());
14916 
14917   SDValue AllOneMask =
14918     DAG.getSplat(WideVecVT.changeVectorElementType(MVT::i1), DL,
14919                  DAG.getConstant(1, DL, MVT::i1));
14920 
14921   SDValue Ops[] = {BaseLd->getChain(),   IntID,  DAG.getUNDEF(WideVecVT),
14922                    BaseLd->getBasePtr(), Stride, AllOneMask};
14923 
14924   uint64_t MemSize;
14925   if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride);
14926       ConstStride && ConstStride->getSExtValue() >= 0)
14927     // total size = (elsize * n) + (stride - elsize) * (n-1)
14928     //            = elsize + stride * (n-1)
14929     MemSize = WideScalarVT.getSizeInBits() +
14930               ConstStride->getSExtValue() * (N->getNumOperands() - 1);
14931   else
14932     // If Stride isn't constant, then we can't know how much it will load
14933     MemSize = MemoryLocation::UnknownSize;
14934 
14935   MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
14936       BaseLd->getPointerInfo(), BaseLd->getMemOperand()->getFlags(), MemSize,
14937       Align);
14938 
14939   SDValue StridedLoad = DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs,
14940                                                 Ops, WideVecVT, MMO);
14941   for (SDValue Ld : N->ops())
14942     DAG.makeEquivalentMemoryOrdering(cast<LoadSDNode>(Ld), StridedLoad);
14943 
14944   return DAG.getBitcast(VT.getSimpleVT(), StridedLoad);
14945 }
14946 
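// Fold (add/add_vl Addend, (vwmul[u|su]_vl X, Y)) into a single widening
// multiply-accumulate, i.e. (vwmacc[u|su]_vl X, Y, Addend), when the masks
// and VLs of the two operations match.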
14947 static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
14948                                const RISCVSubtarget &Subtarget) {
14949 
14950   assert(N->getOpcode() == RISCVISD::ADD_VL || N->getOpcode() == ISD::ADD);
14951 
14952   if (N->getValueType(0).isFixedLengthVector())
14953     return SDValue();
14954 
14955   SDValue Addend = N->getOperand(0);
14956   SDValue MulOp = N->getOperand(1);
14957 
14958   if (N->getOpcode() == RISCVISD::ADD_VL) {
14959     SDValue AddMergeOp = N->getOperand(2);
14960     if (!AddMergeOp.isUndef())
14961       return SDValue();
14962   }
14963 
14964   auto IsVWMulOpc = [](unsigned Opc) {
14965     switch (Opc) {
14966     case RISCVISD::VWMUL_VL:
14967     case RISCVISD::VWMULU_VL:
14968     case RISCVISD::VWMULSU_VL:
14969       return true;
14970     default:
14971       return false;
14972     }
14973   };
14974 
14975   if (!IsVWMulOpc(MulOp.getOpcode()))
14976     std::swap(Addend, MulOp);
14977 
14978   if (!IsVWMulOpc(MulOp.getOpcode()))
14979     return SDValue();
14980 
14981   SDValue MulMergeOp = MulOp.getOperand(2);
14982 
14983   if (!MulMergeOp.isUndef())
14984     return SDValue();
14985 
14986   auto [AddMask, AddVL] = [](SDNode *N, SelectionDAG &DAG,
14987                              const RISCVSubtarget &Subtarget) {
14988     if (N->getOpcode() == ISD::ADD) {
14989       SDLoc DL(N);
14990       return getDefaultScalableVLOps(N->getSimpleValueType(0), DL, DAG,
14991                                      Subtarget);
14992     }
14993     return std::make_pair(N->getOperand(3), N->getOperand(4));
14994   }(N, DAG, Subtarget);
14995 
14996   SDValue MulMask = MulOp.getOperand(3);
14997   SDValue MulVL = MulOp.getOperand(4);
14998 
14999   if (AddMask != MulMask || AddVL != MulVL)
15000     return SDValue();
15001 
15002   unsigned Opc = RISCVISD::VWMACC_VL + MulOp.getOpcode() - RISCVISD::VWMUL_VL;
15003   static_assert(RISCVISD::VWMACC_VL + 1 == RISCVISD::VWMACCU_VL,
15004                 "Unexpected opcode after VWMACC_VL");
15005   static_assert(RISCVISD::VWMACC_VL + 2 == RISCVISD::VWMACCSU_VL,
15006                 "Unexpected opcode after VWMACC_VL!");
15007   static_assert(RISCVISD::VWMUL_VL + 1 == RISCVISD::VWMULU_VL,
15008                 "Unexpected opcode after VWMUL_VL!");
15009   static_assert(RISCVISD::VWMUL_VL + 2 == RISCVISD::VWMULSU_VL,
15010                 "Unexpected opcode after VWMUL_VL!");
15011 
15012   SDLoc DL(N);
15013   EVT VT = N->getValueType(0);
15014   SDValue Ops[] = {MulOp.getOperand(0), MulOp.getOperand(1), Addend, AddMask,
15015                    AddVL};
15016   return DAG.getNode(Opc, DL, VT, Ops);
15017 }
15018 
15019 static bool legalizeScatterGatherIndexType(SDLoc DL, SDValue &Index,
15020                                            ISD::MemIndexType &IndexType,
15021                                            RISCVTargetLowering::DAGCombinerInfo &DCI) {
15022   if (!DCI.isBeforeLegalize())
15023     return false;
15024 
15025   SelectionDAG &DAG = DCI.DAG;
15026   const MVT XLenVT =
15027     DAG.getMachineFunction().getSubtarget<RISCVSubtarget>().getXLenVT();
15028 
15029   const EVT IndexVT = Index.getValueType();
15030 
15031   // RISC-V indexed loads only support the "unsigned unscaled" addressing
15032   // mode, so anything else must be manually legalized.
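  // For example (illustrative), a signed v4i16 index vector is sign extended
  // to v4iXLen below and then re-labelled as unsigned; modulo-XLen address
  // arithmetic makes the two interpretations produce the same addresses.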
15033   if (!isIndexTypeSigned(IndexType))
15034     return false;
15035 
15036   if (IndexVT.getVectorElementType().bitsLT(XLenVT)) {
15037     // Any index legalization should first promote to XLenVT, so we don't lose
15038     // bits when scaling. This may create an illegal index type so we let
15039     // LLVM's legalization take care of the splitting.
15040     // FIXME: LLVM can't split VP_GATHER or VP_SCATTER yet.
15041     Index = DAG.getNode(ISD::SIGN_EXTEND, DL,
15042                         IndexVT.changeVectorElementType(XLenVT), Index);
15043   }
15044   IndexType = ISD::UNSIGNED_SCALED;
15045   return true;
15046 }
15047 
15048 /// Match the index vector of a scatter or gather node as the shuffle mask
15049 /// which performs the rearrangement if possible.  Will only match if
15050 /// all lanes are touched, and thus replacing the scatter or gather with
15051 /// a unit strided access and shuffle is legal.
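/// For example (illustrative), with i32 elements and an all-ones mask, byte
/// offsets <4, 0, 12, 8> touch lanes <1, 0, 3, 2>, so the gather can instead
/// be a unit-strided load followed by that shuffle.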
15052 static bool matchIndexAsShuffle(EVT VT, SDValue Index, SDValue Mask,
15053                                 SmallVector<int> &ShuffleMask) {
15054   if (!ISD::isConstantSplatVectorAllOnes(Mask.getNode()))
15055     return false;
15056   if (!ISD::isBuildVectorOfConstantSDNodes(Index.getNode()))
15057     return false;
15058 
15059   const unsigned ElementSize = VT.getScalarStoreSize();
15060   const unsigned NumElems = VT.getVectorNumElements();
15061 
15062   // Create the shuffle mask and check all bits active
15063   assert(ShuffleMask.empty());
15064   BitVector ActiveLanes(NumElems);
15065   for (unsigned i = 0; i < Index->getNumOperands(); i++) {
15066     // TODO: We've found an active bit of UB, and could be
15067     // more aggressive here if desired.
15068     if (Index->getOperand(i)->isUndef())
15069       return false;
15070     uint64_t C = Index->getConstantOperandVal(i);
15071     if (C % ElementSize != 0)
15072       return false;
15073     C = C / ElementSize;
15074     if (C >= NumElems)
15075       return false;
15076     ShuffleMask.push_back(C);
15077     ActiveLanes.set(C);
15078   }
15079   return ActiveLanes.all();
15080 }
15081 
15082 /// Match the index of a gather or scatter operation as an operation
15083 /// with twice the element width and half the number of elements.  This is
15084 /// generally profitable (if legal) because these operations are linear
15085 /// in VL, so even if we cause some extra VTYPE/VL toggles, we still
15086 /// come out ahead.
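/// For example (illustrative), an i32 gather with byte offsets <0, 4, 16, 20>
/// can be done as an i64 gather with offsets <0, 16>, because each even/odd
/// pair of offsets is contiguous and starts at a multiple of 8 bytes.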
15087 static bool matchIndexAsWiderOp(EVT VT, SDValue Index, SDValue Mask,
15088                                 Align BaseAlign, const RISCVSubtarget &ST) {
15089   if (!ISD::isConstantSplatVectorAllOnes(Mask.getNode()))
15090     return false;
15091   if (!ISD::isBuildVectorOfConstantSDNodes(Index.getNode()))
15092     return false;
15093 
15094   // Attempt a doubling.  If we can use an element type 4x or 8x in
15095   // size, this will happen via multiple iterations of the transform.
15096   const unsigned NumElems = VT.getVectorNumElements();
15097   if (NumElems % 2 != 0)
15098     return false;
15099 
15100   const unsigned ElementSize = VT.getScalarStoreSize();
15101   const unsigned WiderElementSize = ElementSize * 2;
15102   if (WiderElementSize > ST.getELen()/8)
15103     return false;
15104 
15105   if (!ST.hasFastUnalignedAccess() && BaseAlign < WiderElementSize)
15106     return false;
15107 
15108   for (unsigned i = 0; i < Index->getNumOperands(); i++) {
15109     // TODO: We've found an active bit of UB, and could be
15110     // more aggressive here if desired.
15111     if (Index->getOperand(i)->isUndef())
15112       return false;
15113     // TODO: This offset check is too strict if we support fully
15114     // misaligned memory operations.
15115     uint64_t C = Index->getConstantOperandVal(i);
15116     if (i % 2 == 0) {
15117       if (C % WiderElementSize != 0)
15118         return false;
15119       continue;
15120     }
15121     uint64_t Last = Index->getConstantOperandVal(i-1);
15122     if (C != Last + ElementSize)
15123       return false;
15124   }
15125   return true;
15126 }
15127 
15128 
15129 SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
15130                                                DAGCombinerInfo &DCI) const {
15131   SelectionDAG &DAG = DCI.DAG;
15132   const MVT XLenVT = Subtarget.getXLenVT();
15133   SDLoc DL(N);
15134 
15135   // Helper to call SimplifyDemandedBits on an operand of N where only some low
15136   // bits are demanded. N will be added to the Worklist if it was not deleted.
15137   // Caller should return SDValue(N, 0) if this returns true.
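  // For example (illustrative), the RORW/ROLW handling below calls this with
  // (0, 32) and (1, 5), since only those low bits of the operands affect the
  // result of the 32-bit rotate.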
15138   auto SimplifyDemandedLowBitsHelper = [&](unsigned OpNo, unsigned LowBits) {
15139     SDValue Op = N->getOperand(OpNo);
15140     APInt Mask = APInt::getLowBitsSet(Op.getValueSizeInBits(), LowBits);
15141     if (!SimplifyDemandedBits(Op, Mask, DCI))
15142       return false;
15143 
15144     if (N->getOpcode() != ISD::DELETED_NODE)
15145       DCI.AddToWorklist(N);
15146     return true;
15147   };
15148 
15149   switch (N->getOpcode()) {
15150   default:
15151     break;
15152   case RISCVISD::SplitF64: {
15153     SDValue Op0 = N->getOperand(0);
15154     // If the input to SplitF64 is just BuildPairF64 then the operation is
15155     // redundant. Instead, use BuildPairF64's operands directly.
15156     if (Op0->getOpcode() == RISCVISD::BuildPairF64)
15157       return DCI.CombineTo(N, Op0.getOperand(0), Op0.getOperand(1));
15158 
15159     if (Op0->isUndef()) {
15160       SDValue Lo = DAG.getUNDEF(MVT::i32);
15161       SDValue Hi = DAG.getUNDEF(MVT::i32);
15162       return DCI.CombineTo(N, Lo, Hi);
15163     }
15164 
15165     // It's cheaper to materialise two 32-bit integers than to load a double
15166     // from the constant pool and transfer it to integer registers through the
15167     // stack.
15168     if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Op0)) {
15169       APInt V = C->getValueAPF().bitcastToAPInt();
15170       SDValue Lo = DAG.getConstant(V.trunc(32), DL, MVT::i32);
15171       SDValue Hi = DAG.getConstant(V.lshr(32).trunc(32), DL, MVT::i32);
15172       return DCI.CombineTo(N, Lo, Hi);
15173     }
15174 
15175     // This is a target-specific version of a DAGCombine performed in
15176     // DAGCombiner::visitBITCAST. It performs the equivalent of:
15177     // fold (bitconvert (fneg x)) -> (xor (bitconvert x), signbit)
15178     // fold (bitconvert (fabs x)) -> (and (bitconvert x), (not signbit))
15179     if (!(Op0.getOpcode() == ISD::FNEG || Op0.getOpcode() == ISD::FABS) ||
15180         !Op0.getNode()->hasOneUse())
15181       break;
15182     SDValue NewSplitF64 =
15183         DAG.getNode(RISCVISD::SplitF64, DL, DAG.getVTList(MVT::i32, MVT::i32),
15184                     Op0.getOperand(0));
15185     SDValue Lo = NewSplitF64.getValue(0);
15186     SDValue Hi = NewSplitF64.getValue(1);
15187     APInt SignBit = APInt::getSignMask(32);
15188     if (Op0.getOpcode() == ISD::FNEG) {
15189       SDValue NewHi = DAG.getNode(ISD::XOR, DL, MVT::i32, Hi,
15190                                   DAG.getConstant(SignBit, DL, MVT::i32));
15191       return DCI.CombineTo(N, Lo, NewHi);
15192     }
15193     assert(Op0.getOpcode() == ISD::FABS);
15194     SDValue NewHi = DAG.getNode(ISD::AND, DL, MVT::i32, Hi,
15195                                 DAG.getConstant(~SignBit, DL, MVT::i32));
15196     return DCI.CombineTo(N, Lo, NewHi);
15197   }
15198   case RISCVISD::SLLW:
15199   case RISCVISD::SRAW:
15200   case RISCVISD::SRLW:
15201   case RISCVISD::RORW:
15202   case RISCVISD::ROLW: {
15203     // Only the lower 32 bits of LHS and lower 5 bits of RHS are read.
15204     if (SimplifyDemandedLowBitsHelper(0, 32) ||
15205         SimplifyDemandedLowBitsHelper(1, 5))
15206       return SDValue(N, 0);
15207 
15208     break;
15209   }
15210   case RISCVISD::CLZW:
15211   case RISCVISD::CTZW: {
15212     // Only the lower 32 bits of the first operand are read
15213     if (SimplifyDemandedLowBitsHelper(0, 32))
15214       return SDValue(N, 0);
15215     break;
15216   }
15217   case RISCVISD::FMV_W_X_RV64: {
15218     // If the input to FMV_W_X_RV64 is just FMV_X_ANYEXTW_RV64 then the
15219     // conversion is unnecessary and can be replaced with the
15220     // FMV_X_ANYEXTW_RV64 operand.
15221     SDValue Op0 = N->getOperand(0);
15222     if (Op0.getOpcode() == RISCVISD::FMV_X_ANYEXTW_RV64)
15223       return Op0.getOperand(0);
15224     break;
15225   }
15226   case RISCVISD::FMV_X_ANYEXTH:
15227   case RISCVISD::FMV_X_ANYEXTW_RV64: {
15228     SDLoc DL(N);
15229     SDValue Op0 = N->getOperand(0);
15230     MVT VT = N->getSimpleValueType(0);
15231     // If the input to FMV_X_ANYEXTW_RV64 is just FMV_W_X_RV64 then the
15232     // conversion is unnecessary and can be replaced with the FMV_W_X_RV64
15233     // operand. Similar for FMV_X_ANYEXTH and FMV_H_X.
15234     if ((N->getOpcode() == RISCVISD::FMV_X_ANYEXTW_RV64 &&
15235          Op0->getOpcode() == RISCVISD::FMV_W_X_RV64) ||
15236         (N->getOpcode() == RISCVISD::FMV_X_ANYEXTH &&
15237          Op0->getOpcode() == RISCVISD::FMV_H_X)) {
15238       assert(Op0.getOperand(0).getValueType() == VT &&
15239              "Unexpected value type!");
15240       return Op0.getOperand(0);
15241     }
15242 
15243     // This is a target-specific version of a DAGCombine performed in
15244     // DAGCombiner::visitBITCAST. It performs the equivalent of:
15245     // fold (bitconvert (fneg x)) -> (xor (bitconvert x), signbit)
15246     // fold (bitconvert (fabs x)) -> (and (bitconvert x), (not signbit))
15247     if (!(Op0.getOpcode() == ISD::FNEG || Op0.getOpcode() == ISD::FABS) ||
15248         !Op0.getNode()->hasOneUse())
15249       break;
15250     SDValue NewFMV = DAG.getNode(N->getOpcode(), DL, VT, Op0.getOperand(0));
15251     unsigned FPBits = N->getOpcode() == RISCVISD::FMV_X_ANYEXTW_RV64 ? 32 : 16;
15252     APInt SignBit = APInt::getSignMask(FPBits).sext(VT.getSizeInBits());
15253     if (Op0.getOpcode() == ISD::FNEG)
15254       return DAG.getNode(ISD::XOR, DL, VT, NewFMV,
15255                          DAG.getConstant(SignBit, DL, VT));
15256 
15257     assert(Op0.getOpcode() == ISD::FABS);
15258     return DAG.getNode(ISD::AND, DL, VT, NewFMV,
15259                        DAG.getConstant(~SignBit, DL, VT));
15260   }
15261   case ISD::ADD: {
15262     if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
15263       return V;
15264     if (SDValue V = combineToVWMACC(N, DAG, Subtarget))
15265       return V;
15266     return performADDCombine(N, DAG, Subtarget);
15267   }
15268   case ISD::SUB: {
15269     if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
15270       return V;
15271     return performSUBCombine(N, DAG, Subtarget);
15272   }
15273   case ISD::AND:
15274     return performANDCombine(N, DCI, Subtarget);
15275   case ISD::OR:
15276     return performORCombine(N, DCI, Subtarget);
15277   case ISD::XOR:
15278     return performXORCombine(N, DAG, Subtarget);
15279   case ISD::MUL:
15280     if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
15281       return V;
15282     return performMULCombine(N, DAG);
15283   case ISD::FADD:
15284   case ISD::UMAX:
15285   case ISD::UMIN:
15286   case ISD::SMAX:
15287   case ISD::SMIN:
15288   case ISD::FMAXNUM:
15289   case ISD::FMINNUM: {
15290     if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
15291       return V;
15292     if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget))
15293       return V;
15294     return SDValue();
15295   }
15296   case ISD::SETCC:
15297     return performSETCCCombine(N, DAG, Subtarget);
15298   case ISD::SIGN_EXTEND_INREG:
15299     return performSIGN_EXTEND_INREGCombine(N, DAG, Subtarget);
15300   case ISD::ZERO_EXTEND:
15301     // Fold (zero_extend (fp_to_uint X)) to prevent forming fcvt+zexti32 during
15302     // type legalization. This is safe because fp_to_uint produces poison if
15303     // it overflows.
15304     if (N->getValueType(0) == MVT::i64 && Subtarget.is64Bit()) {
15305       SDValue Src = N->getOperand(0);
15306       if (Src.getOpcode() == ISD::FP_TO_UINT &&
15307           isTypeLegal(Src.getOperand(0).getValueType()))
15308         return DAG.getNode(ISD::FP_TO_UINT, SDLoc(N), MVT::i64,
15309                            Src.getOperand(0));
15310       if (Src.getOpcode() == ISD::STRICT_FP_TO_UINT && Src.hasOneUse() &&
15311           isTypeLegal(Src.getOperand(1).getValueType())) {
15312         SDVTList VTs = DAG.getVTList(MVT::i64, MVT::Other);
15313         SDValue Res = DAG.getNode(ISD::STRICT_FP_TO_UINT, SDLoc(N), VTs,
15314                                   Src.getOperand(0), Src.getOperand(1));
15315         DCI.CombineTo(N, Res);
15316         DAG.ReplaceAllUsesOfValueWith(Src.getValue(1), Res.getValue(1));
15317         DCI.recursivelyDeleteUnusedNodes(Src.getNode());
15318         return SDValue(N, 0); // Return N so it doesn't get rechecked.
15319       }
15320     }
15321     return SDValue();
15322   case RISCVISD::TRUNCATE_VECTOR_VL: {
15323     // trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
15324     // This benefits the cases where X and Y are both low-precision vectors of
15325     // the same value type. Since the truncate would be lowered into
15326     // n levels of TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
15327     // restriction, such a pattern would be expanded into a series of "vsetvli"
15328     // and "vnsrl" instructions later to reach this point.
15329     auto IsTruncNode = [](SDValue V) {
15330       if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
15331         return false;
15332       SDValue VL = V.getOperand(2);
15333       auto *C = dyn_cast<ConstantSDNode>(VL);
15334       // Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
15335       bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
15336                              (isa<RegisterSDNode>(VL) &&
15337                               cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
15338       return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL &&
15339              IsVLMAXForVMSET;
15340     };
15341 
15342     SDValue Op = N->getOperand(0);
15343 
15344     // We need to first find the inner level of TRUNCATE_VECTOR_VL node
15345     // to distinguish such pattern.
15346     while (IsTruncNode(Op)) {
15347       if (!Op.hasOneUse())
15348         return SDValue();
15349       Op = Op.getOperand(0);
15350     }
15351 
15352     if (Op.getOpcode() == ISD::SRA && Op.hasOneUse()) {
15353       SDValue N0 = Op.getOperand(0);
15354       SDValue N1 = Op.getOperand(1);
15355       if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
15356           N1.getOpcode() == ISD::ZERO_EXTEND && N1.hasOneUse()) {
15357         SDValue N00 = N0.getOperand(0);
15358         SDValue N10 = N1.getOperand(0);
15359         if (N00.getValueType().isVector() &&
15360             N00.getValueType() == N10.getValueType() &&
15361             N->getValueType(0) == N10.getValueType()) {
15362           unsigned MaxShAmt = N10.getValueType().getScalarSizeInBits() - 1;
15363           SDValue SMin = DAG.getNode(
15364               ISD::SMIN, SDLoc(N1), N->getValueType(0), N10,
15365               DAG.getConstant(MaxShAmt, SDLoc(N1), N->getValueType(0)));
15366           return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin);
15367         }
15368       }
15369     }
15370     break;
15371   }
15372   case ISD::TRUNCATE:
15373     return performTRUNCATECombine(N, DAG, Subtarget);
15374   case ISD::SELECT:
15375     return performSELECTCombine(N, DAG, Subtarget);
15376   case RISCVISD::CZERO_EQZ:
15377   case RISCVISD::CZERO_NEZ:
15378     // czero_eq X, (xor Y, 1) -> czero_ne X, Y if Y is 0 or 1.
15379     // czero_ne X, (xor Y, 1) -> czero_eq X, Y if Y is 0 or 1.
15380     if (N->getOperand(1).getOpcode() == ISD::XOR &&
15381         isOneConstant(N->getOperand(1).getOperand(1))) {
15382       SDValue Cond = N->getOperand(1).getOperand(0);
15383       APInt Mask = APInt::getBitsSetFrom(Cond.getValueSizeInBits(), 1);
15384       if (DAG.MaskedValueIsZero(Cond, Mask)) {
15385         unsigned NewOpc = N->getOpcode() == RISCVISD::CZERO_EQZ
15386                               ? RISCVISD::CZERO_NEZ
15387                               : RISCVISD::CZERO_EQZ;
15388         return DAG.getNode(NewOpc, SDLoc(N), N->getValueType(0),
15389                            N->getOperand(0), Cond);
15390       }
15391     }
15392     return SDValue();
15393 
15394   case RISCVISD::SELECT_CC: {
15395     // Transform
15396     SDValue LHS = N->getOperand(0);
15397     SDValue RHS = N->getOperand(1);
15398     SDValue CC = N->getOperand(2);
15399     ISD::CondCode CCVal = cast<CondCodeSDNode>(CC)->get();
15400     SDValue TrueV = N->getOperand(3);
15401     SDValue FalseV = N->getOperand(4);
15402     SDLoc DL(N);
15403     EVT VT = N->getValueType(0);
15404 
15405     // If the True and False values are the same, we don't need a select_cc.
15406     if (TrueV == FalseV)
15407       return TrueV;
15408 
15409     // (select (x < 0), y, z)  -> x >> (XLEN - 1) & (y - z) + z
15410     // (select (x >= 0), y, z) -> x >> (XLEN - 1) & (z - y) + y
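    // For example (illustrative), with XLEN=64, y=5, z=3:
    //   x >> 63 is all-ones when x < 0 and zero otherwise, so
    //   ((x >> 63) & (5 - 3)) + 3 yields 5 when x < 0 and 3 otherwise.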
15411     if (!Subtarget.hasShortForwardBranchOpt() && isa<ConstantSDNode>(TrueV) &&
15412         isa<ConstantSDNode>(FalseV) && isNullConstant(RHS) &&
15413         (CCVal == ISD::CondCode::SETLT || CCVal == ISD::CondCode::SETGE)) {
15414       if (CCVal == ISD::CondCode::SETGE)
15415         std::swap(TrueV, FalseV);
15416 
15417       int64_t TrueSImm = cast<ConstantSDNode>(TrueV)->getSExtValue();
15418       int64_t FalseSImm = cast<ConstantSDNode>(FalseV)->getSExtValue();
15419       // Only handle simm12; if the constant is outside this range, it can be
15420       // treated as a register operand instead.
15421       if (isInt<12>(TrueSImm) && isInt<12>(FalseSImm) &&
15422           isInt<12>(TrueSImm - FalseSImm)) {
15423         SDValue SRA =
15424             DAG.getNode(ISD::SRA, DL, VT, LHS,
15425                         DAG.getConstant(Subtarget.getXLen() - 1, DL, VT));
15426         SDValue AND =
15427             DAG.getNode(ISD::AND, DL, VT, SRA,
15428                         DAG.getConstant(TrueSImm - FalseSImm, DL, VT));
15429         return DAG.getNode(ISD::ADD, DL, VT, AND, FalseV);
15430       }
15431 
15432       if (CCVal == ISD::CondCode::SETGE)
15433         std::swap(TrueV, FalseV);
15434     }
15435 
15436     if (combine_CC(LHS, RHS, CC, DL, DAG, Subtarget))
15437       return DAG.getNode(RISCVISD::SELECT_CC, DL, N->getValueType(0),
15438                          {LHS, RHS, CC, TrueV, FalseV});
15439 
15440     if (!Subtarget.hasConditionalMoveFusion()) {
15441       // (select c, -1, y) -> -c | y
15442       if (isAllOnesConstant(TrueV)) {
15443         SDValue C = DAG.getSetCC(DL, VT, LHS, RHS, CCVal);
15444         SDValue Neg = DAG.getNegative(C, DL, VT);
15445         return DAG.getNode(ISD::OR, DL, VT, Neg, FalseV);
15446       }
15447       // (select c, y, -1) -> -!c | y
15448       if (isAllOnesConstant(FalseV)) {
15449         SDValue C =
15450             DAG.getSetCC(DL, VT, LHS, RHS, ISD::getSetCCInverse(CCVal, VT));
15451         SDValue Neg = DAG.getNegative(C, DL, VT);
15452         return DAG.getNode(ISD::OR, DL, VT, Neg, TrueV);
15453       }
15454 
15455       // (select c, 0, y) -> -!c & y
15456       if (isNullConstant(TrueV)) {
15457         SDValue C =
15458             DAG.getSetCC(DL, VT, LHS, RHS, ISD::getSetCCInverse(CCVal, VT));
15459         SDValue Neg = DAG.getNegative(C, DL, VT);
15460         return DAG.getNode(ISD::AND, DL, VT, Neg, FalseV);
15461       }
15462       // (select c, y, 0) -> -c & y
15463       if (isNullConstant(FalseV)) {
15464         SDValue C = DAG.getSetCC(DL, VT, LHS, RHS, CCVal);
15465         SDValue Neg = DAG.getNegative(C, DL, VT);
15466         return DAG.getNode(ISD::AND, DL, VT, Neg, TrueV);
15467       }
15468       // (riscvisd::select_cc x, 0, ne, x, 1) -> (add x, (setcc x, 0, eq))
15469       // (riscvisd::select_cc x, 0, eq, 1, x) -> (add x, (setcc x, 0, eq))
15470       if (((isOneConstant(FalseV) && LHS == TrueV &&
15471             CCVal == ISD::CondCode::SETNE) ||
15472            (isOneConstant(TrueV) && LHS == FalseV &&
15473             CCVal == ISD::CondCode::SETEQ)) &&
15474           isNullConstant(RHS)) {
15475         // freeze it to be safe.
15476         LHS = DAG.getFreeze(LHS);
15477         SDValue C = DAG.getSetCC(DL, VT, LHS, RHS, ISD::CondCode::SETEQ);
15478         return DAG.getNode(ISD::ADD, DL, VT, LHS, C);
15479       }
15480     }
15481 
15482     // If both true/false are an xor with 1, pull through the select.
15483     // This can occur after op legalization if both operands are setccs that
15484     // require an xor to invert.
15485     // FIXME: Generalize to other binary ops with identical operand?
15486     if (TrueV.getOpcode() == ISD::XOR && FalseV.getOpcode() == ISD::XOR &&
15487         TrueV.getOperand(1) == FalseV.getOperand(1) &&
15488         isOneConstant(TrueV.getOperand(1)) &&
15489         TrueV.hasOneUse() && FalseV.hasOneUse()) {
15490       SDValue NewSel = DAG.getNode(RISCVISD::SELECT_CC, DL, VT, LHS, RHS, CC,
15491                                    TrueV.getOperand(0), FalseV.getOperand(0));
15492       return DAG.getNode(ISD::XOR, DL, VT, NewSel, TrueV.getOperand(1));
15493     }
15494 
15495     return SDValue();
15496   }
15497   case RISCVISD::BR_CC: {
15498     SDValue LHS = N->getOperand(1);
15499     SDValue RHS = N->getOperand(2);
15500     SDValue CC = N->getOperand(3);
15501     SDLoc DL(N);
15502 
15503     if (combine_CC(LHS, RHS, CC, DL, DAG, Subtarget))
15504       return DAG.getNode(RISCVISD::BR_CC, DL, N->getValueType(0),
15505                          N->getOperand(0), LHS, RHS, CC, N->getOperand(4));
15506 
15507     return SDValue();
15508   }
15509   case ISD::BITREVERSE:
15510     return performBITREVERSECombine(N, DAG, Subtarget);
15511   case ISD::FP_TO_SINT:
15512   case ISD::FP_TO_UINT:
15513     return performFP_TO_INTCombine(N, DCI, Subtarget);
15514   case ISD::FP_TO_SINT_SAT:
15515   case ISD::FP_TO_UINT_SAT:
15516     return performFP_TO_INT_SATCombine(N, DCI, Subtarget);
15517   case ISD::FCOPYSIGN: {
15518     EVT VT = N->getValueType(0);
15519     if (!VT.isVector())
15520       break;
15521     // There is a form of VFSGNJ which injects the negated sign of its second
15522     // operand. Try to bubble any FNEG up after the extend/round to produce
15523     // this optimized pattern. Avoid modifying cases where the FP_ROUND has
15524     // TRUNC=1.
15525     SDValue In2 = N->getOperand(1);
15526     // Avoid cases where the extend/round has multiple uses, as duplicating
15527     // those is typically more expensive than removing a fneg.
15528     if (!In2.hasOneUse())
15529       break;
15530     if (In2.getOpcode() != ISD::FP_EXTEND &&
15531         (In2.getOpcode() != ISD::FP_ROUND || In2.getConstantOperandVal(1) != 0))
15532       break;
15533     In2 = In2.getOperand(0);
15534     if (In2.getOpcode() != ISD::FNEG)
15535       break;
15536     SDLoc DL(N);
15537     SDValue NewFPExtRound = DAG.getFPExtendOrRound(In2.getOperand(0), DL, VT);
15538     return DAG.getNode(ISD::FCOPYSIGN, DL, VT, N->getOperand(0),
15539                        DAG.getNode(ISD::FNEG, DL, VT, NewFPExtRound));
15540   }
15541   case ISD::MGATHER: {
15542     const auto *MGN = dyn_cast<MaskedGatherSDNode>(N);
15543     const EVT VT = N->getValueType(0);
15544     SDValue Index = MGN->getIndex();
15545     SDValue ScaleOp = MGN->getScale();
15546     ISD::MemIndexType IndexType = MGN->getIndexType();
15547     assert(!MGN->isIndexScaled() &&
15548            "Scaled gather/scatter should not be formed");
15549 
15550     SDLoc DL(N);
15551     if (legalizeScatterGatherIndexType(DL, Index, IndexType, DCI))
15552       return DAG.getMaskedGather(
15553           N->getVTList(), MGN->getMemoryVT(), DL,
15554           {MGN->getChain(), MGN->getPassThru(), MGN->getMask(),
15555            MGN->getBasePtr(), Index, ScaleOp},
15556           MGN->getMemOperand(), IndexType, MGN->getExtensionType());
15557 
15558     if (narrowIndex(Index, IndexType, DAG))
15559       return DAG.getMaskedGather(
15560           N->getVTList(), MGN->getMemoryVT(), DL,
15561           {MGN->getChain(), MGN->getPassThru(), MGN->getMask(),
15562            MGN->getBasePtr(), Index, ScaleOp},
15563           MGN->getMemOperand(), IndexType, MGN->getExtensionType());
15564 
15565     if (Index.getOpcode() == ISD::BUILD_VECTOR &&
15566         MGN->getExtensionType() == ISD::NON_EXTLOAD && isTypeLegal(VT)) {
15567       // The sequence will be XLenVT, not the type of Index. Tell
15568       // isSimpleVIDSequence this so we avoid overflow.
15569       if (std::optional<VIDSequence> SimpleVID =
15570               isSimpleVIDSequence(Index, Subtarget.getXLen());
15571           SimpleVID && SimpleVID->StepDenominator == 1) {
15572         const int64_t StepNumerator = SimpleVID->StepNumerator;
15573         const int64_t Addend = SimpleVID->Addend;
15574 
15575         // Note: We don't need to check alignment here since (by assumption
15576         // from the existence of the gather), our offsets must be sufficiently
15577         // aligned.
15578 
15579         const EVT PtrVT = getPointerTy(DAG.getDataLayout());
15580         assert(MGN->getBasePtr()->getValueType(0) == PtrVT);
15581         assert(IndexType == ISD::UNSIGNED_SCALED);
15582         SDValue BasePtr = DAG.getNode(ISD::ADD, DL, PtrVT, MGN->getBasePtr(),
15583                                       DAG.getConstant(Addend, DL, PtrVT));
15584 
15585         SDVTList VTs = DAG.getVTList({VT, MVT::Other});
15586         SDValue IntID =
15587           DAG.getTargetConstant(Intrinsic::riscv_masked_strided_load, DL,
15588                                 XLenVT);
15589         SDValue Ops[] =
15590           {MGN->getChain(), IntID, MGN->getPassThru(), BasePtr,
15591            DAG.getConstant(StepNumerator, DL, XLenVT), MGN->getMask()};
15592         return DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs,
15593                                        Ops, VT, MGN->getMemOperand());
15594       }
15595     }
15596 
15597     SmallVector<int> ShuffleMask;
15598     if (MGN->getExtensionType() == ISD::NON_EXTLOAD &&
15599         matchIndexAsShuffle(VT, Index, MGN->getMask(), ShuffleMask)) {
15600       SDValue Load = DAG.getMaskedLoad(VT, DL, MGN->getChain(),
15601                                        MGN->getBasePtr(), DAG.getUNDEF(XLenVT),
15602                                        MGN->getMask(), DAG.getUNDEF(VT),
15603                                        MGN->getMemoryVT(), MGN->getMemOperand(),
15604                                        ISD::UNINDEXED, ISD::NON_EXTLOAD);
15605       SDValue Shuffle =
15606         DAG.getVectorShuffle(VT, DL, Load, DAG.getUNDEF(VT), ShuffleMask);
15607       return DAG.getMergeValues({Shuffle, Load.getValue(1)}, DL);
15608     }
15609 
15610     if (MGN->getExtensionType() == ISD::NON_EXTLOAD &&
15611         matchIndexAsWiderOp(VT, Index, MGN->getMask(),
15612                             MGN->getMemOperand()->getBaseAlign(), Subtarget)) {
15613       SmallVector<SDValue> NewIndices;
15614       for (unsigned i = 0; i < Index->getNumOperands(); i += 2)
15615         NewIndices.push_back(Index.getOperand(i));
15616       EVT IndexVT = Index.getValueType()
15617         .getHalfNumVectorElementsVT(*DAG.getContext());
15618       Index = DAG.getBuildVector(IndexVT, DL, NewIndices);
15619 
15620       unsigned ElementSize = VT.getScalarStoreSize();
15621       EVT WideScalarVT = MVT::getIntegerVT(ElementSize * 8 * 2);
15622       auto EltCnt = VT.getVectorElementCount();
15623       assert(EltCnt.isKnownEven() && "Splitting vector, but not in half!");
15624       EVT WideVT = EVT::getVectorVT(*DAG.getContext(), WideScalarVT,
15625                                     EltCnt.divideCoefficientBy(2));
15626       SDValue Passthru = DAG.getBitcast(WideVT, MGN->getPassThru());
15627       EVT MaskVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
15628                                     EltCnt.divideCoefficientBy(2));
15629       SDValue Mask = DAG.getSplat(MaskVT, DL, DAG.getConstant(1, DL, MVT::i1));
15630 
15631       SDValue Gather =
15632         DAG.getMaskedGather(DAG.getVTList(WideVT, MVT::Other), WideVT, DL,
15633                             {MGN->getChain(), Passthru, Mask, MGN->getBasePtr(),
15634                              Index, ScaleOp},
15635                             MGN->getMemOperand(), IndexType, ISD::NON_EXTLOAD);
15636       SDValue Result = DAG.getBitcast(VT, Gather.getValue(0));
15637       return DAG.getMergeValues({Result, Gather.getValue(1)}, DL);
15638     }
15639     break;
15640   }
15641   case ISD::MSCATTER:{
15642     const auto *MSN = dyn_cast<MaskedScatterSDNode>(N);
15643     SDValue Index = MSN->getIndex();
15644     SDValue ScaleOp = MSN->getScale();
15645     ISD::MemIndexType IndexType = MSN->getIndexType();
15646     assert(!MSN->isIndexScaled() &&
15647            "Scaled gather/scatter should not be formed");
15648 
15649     SDLoc DL(N);
15650     if (legalizeScatterGatherIndexType(DL, Index, IndexType, DCI))
15651       return DAG.getMaskedScatter(
15652           N->getVTList(), MSN->getMemoryVT(), DL,
15653           {MSN->getChain(), MSN->getValue(), MSN->getMask(), MSN->getBasePtr(),
15654            Index, ScaleOp},
15655           MSN->getMemOperand(), IndexType, MSN->isTruncatingStore());
15656 
15657     if (narrowIndex(Index, IndexType, DAG))
15658       return DAG.getMaskedScatter(
15659           N->getVTList(), MSN->getMemoryVT(), DL,
15660           {MSN->getChain(), MSN->getValue(), MSN->getMask(), MSN->getBasePtr(),
15661            Index, ScaleOp},
15662           MSN->getMemOperand(), IndexType, MSN->isTruncatingStore());
15663 
15664     EVT VT = MSN->getValue()->getValueType(0);
15665     SmallVector<int> ShuffleMask;
15666     if (!MSN->isTruncatingStore() &&
15667         matchIndexAsShuffle(VT, Index, MSN->getMask(), ShuffleMask)) {
15668       SDValue Shuffle = DAG.getVectorShuffle(VT, DL, MSN->getValue(),
15669                                              DAG.getUNDEF(VT), ShuffleMask);
15670       return DAG.getMaskedStore(MSN->getChain(), DL, Shuffle, MSN->getBasePtr(),
15671                                 DAG.getUNDEF(XLenVT), MSN->getMask(),
15672                                 MSN->getMemoryVT(), MSN->getMemOperand(),
15673                                 ISD::UNINDEXED, false);
15674     }
15675     break;
15676   }
15677   case ISD::VP_GATHER: {
15678     const auto *VPGN = dyn_cast<VPGatherSDNode>(N);
15679     SDValue Index = VPGN->getIndex();
15680     SDValue ScaleOp = VPGN->getScale();
15681     ISD::MemIndexType IndexType = VPGN->getIndexType();
15682     assert(!VPGN->isIndexScaled() &&
15683            "Scaled gather/scatter should not be formed");
15684 
15685     SDLoc DL(N);
15686     if (legalizeScatterGatherIndexType(DL, Index, IndexType, DCI))
15687       return DAG.getGatherVP(N->getVTList(), VPGN->getMemoryVT(), DL,
15688                              {VPGN->getChain(), VPGN->getBasePtr(), Index,
15689                               ScaleOp, VPGN->getMask(),
15690                               VPGN->getVectorLength()},
15691                              VPGN->getMemOperand(), IndexType);
15692 
15693     if (narrowIndex(Index, IndexType, DAG))
15694       return DAG.getGatherVP(N->getVTList(), VPGN->getMemoryVT(), DL,
15695                              {VPGN->getChain(), VPGN->getBasePtr(), Index,
15696                               ScaleOp, VPGN->getMask(),
15697                               VPGN->getVectorLength()},
15698                              VPGN->getMemOperand(), IndexType);
15699 
15700     break;
15701   }
15702   case ISD::VP_SCATTER: {
15703     const auto *VPSN = dyn_cast<VPScatterSDNode>(N);
15704     SDValue Index = VPSN->getIndex();
15705     SDValue ScaleOp = VPSN->getScale();
15706     ISD::MemIndexType IndexType = VPSN->getIndexType();
15707     assert(!VPSN->isIndexScaled() &&
15708            "Scaled gather/scatter should not be formed");
15709 
15710     SDLoc DL(N);
15711     if (legalizeScatterGatherIndexType(DL, Index, IndexType, DCI))
15712       return DAG.getScatterVP(N->getVTList(), VPSN->getMemoryVT(), DL,
15713                               {VPSN->getChain(), VPSN->getValue(),
15714                                VPSN->getBasePtr(), Index, ScaleOp,
15715                                VPSN->getMask(), VPSN->getVectorLength()},
15716                               VPSN->getMemOperand(), IndexType);
15717 
15718     if (narrowIndex(Index, IndexType, DAG))
15719       return DAG.getScatterVP(N->getVTList(), VPSN->getMemoryVT(), DL,
15720                               {VPSN->getChain(), VPSN->getValue(),
15721                                VPSN->getBasePtr(), Index, ScaleOp,
15722                                VPSN->getMask(), VPSN->getVectorLength()},
15723                               VPSN->getMemOperand(), IndexType);
15724     break;
15725   }
15726   case RISCVISD::SRA_VL:
15727   case RISCVISD::SRL_VL:
15728   case RISCVISD::SHL_VL: {
15729     SDValue ShAmt = N->getOperand(1);
15730     if (ShAmt.getOpcode() == RISCVISD::SPLAT_VECTOR_SPLIT_I64_VL) {
15731       // We don't need the upper 32 bits of a 64-bit element for a shift amount.
15732       SDLoc DL(N);
15733       SDValue VL = N->getOperand(4);
15734       EVT VT = N->getValueType(0);
15735       ShAmt = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, DAG.getUNDEF(VT),
15736                           ShAmt.getOperand(1), VL);
15737       return DAG.getNode(N->getOpcode(), DL, VT, N->getOperand(0), ShAmt,
15738                          N->getOperand(2), N->getOperand(3), N->getOperand(4));
15739     }
15740     break;
15741   }
15742   case ISD::SRA:
15743     if (SDValue V = performSRACombine(N, DAG, Subtarget))
15744       return V;
15745     [[fallthrough]];
15746   case ISD::SRL:
15747   case ISD::SHL: {
15748     SDValue ShAmt = N->getOperand(1);
15749     if (ShAmt.getOpcode() == RISCVISD::SPLAT_VECTOR_SPLIT_I64_VL) {
15750       // We don't need the upper 32 bits of a 64-bit element for a shift amount.
15751       SDLoc DL(N);
15752       EVT VT = N->getValueType(0);
15753       ShAmt = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, DAG.getUNDEF(VT),
15754                           ShAmt.getOperand(1),
15755                           DAG.getRegister(RISCV::X0, Subtarget.getXLenVT()));
15756       return DAG.getNode(N->getOpcode(), DL, VT, N->getOperand(0), ShAmt);
15757     }
15758     break;
15759   }
15760   case RISCVISD::ADD_VL:
15761     if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
15762       return V;
15763     return combineToVWMACC(N, DAG, Subtarget);
15764   case RISCVISD::SUB_VL:
15765   case RISCVISD::VWADD_W_VL:
15766   case RISCVISD::VWADDU_W_VL:
15767   case RISCVISD::VWSUB_W_VL:
15768   case RISCVISD::VWSUBU_W_VL:
15769   case RISCVISD::MUL_VL:
15770     return combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget);
15771   case RISCVISD::VFMADD_VL:
15772   case RISCVISD::VFNMADD_VL:
15773   case RISCVISD::VFMSUB_VL:
15774   case RISCVISD::VFNMSUB_VL:
15775   case RISCVISD::STRICT_VFMADD_VL:
15776   case RISCVISD::STRICT_VFNMADD_VL:
15777   case RISCVISD::STRICT_VFMSUB_VL:
15778   case RISCVISD::STRICT_VFNMSUB_VL:
15779     return performVFMADD_VLCombine(N, DAG, Subtarget);
15780   case RISCVISD::FMUL_VL:
15781     return performVFMUL_VLCombine(N, DAG, Subtarget);
15782   case RISCVISD::FADD_VL:
15783   case RISCVISD::FSUB_VL:
15784     return performFADDSUB_VLCombine(N, DAG, Subtarget);
15785   case ISD::LOAD:
15786   case ISD::STORE: {
15787     if (DCI.isAfterLegalizeDAG())
15788       if (SDValue V = performMemPairCombine(N, DCI))
15789         return V;
15790 
15791     if (N->getOpcode() != ISD::STORE)
15792       break;
15793 
15794     auto *Store = cast<StoreSDNode>(N);
15795     SDValue Chain = Store->getChain();
15796     EVT MemVT = Store->getMemoryVT();
15797     SDValue Val = Store->getValue();
15798     SDLoc DL(N);
15799 
15800     bool IsScalarizable =
15801         MemVT.isFixedLengthVector() && ISD::isNormalStore(Store) &&
15802         Store->isSimple() &&
15803         MemVT.getVectorElementType().bitsLE(Subtarget.getXLenVT()) &&
15804         isPowerOf2_64(MemVT.getSizeInBits()) &&
15805         MemVT.getSizeInBits() <= Subtarget.getXLen();
15806 
15807     // If sufficiently aligned we can scalarize stores of constant vectors of
15808     // any power-of-two size up to XLen bits, provided that they aren't too
15809     // expensive to materialize.
15810     //   vsetivli   zero, 2, e8, m1, ta, ma
15811     //   vmv.v.i    v8, 4
15812     //   vse64.v    v8, (a0)
15813     // ->
15814     //   li     a1, 1028
15815     //   sh     a1, 0(a0)
15816     if (DCI.isBeforeLegalize() && IsScalarizable &&
15817         ISD::isBuildVectorOfConstantSDNodes(Val.getNode())) {
15818       // Get the constant vector bits
15819       APInt NewC(Val.getValueSizeInBits(), 0);
15820       uint64_t EltSize = Val.getScalarValueSizeInBits();
15821       for (unsigned i = 0; i < Val.getNumOperands(); i++) {
15822         if (Val.getOperand(i).isUndef())
15823           continue;
15824         NewC.insertBits(Val.getConstantOperandAPInt(i).trunc(EltSize),
15825                         i * EltSize);
15826       }
15827       MVT NewVT = MVT::getIntegerVT(MemVT.getSizeInBits());
15828 
15829       if (RISCVMatInt::getIntMatCost(NewC, Subtarget.getXLen(), Subtarget,
15830                                      true) <= 2 &&
15831           allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
15832                                          NewVT, *Store->getMemOperand())) {
15833         SDValue NewV = DAG.getConstant(NewC, DL, NewVT);
15834         return DAG.getStore(Chain, DL, NewV, Store->getBasePtr(),
15835                             Store->getPointerInfo(), Store->getOriginalAlign(),
15836                             Store->getMemOperand()->getFlags());
15837       }
15838     }
15839 
15840     // Similarly, if sufficiently aligned we can scalarize vector copies, e.g.
15841     //   vsetivli   zero, 2, e16, m1, ta, ma
15842     //   vle16.v    v8, (a0)
15843     //   vse16.v    v8, (a1)
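    // As a rough sketch (assuming a sufficiently aligned v2i16 copy; register
    // names are illustrative), the scalarized form is a single integer
    // load/store pair:
    //   lw     a2, 0(a0)
    //   sw     a2, 0(a1)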
15844     if (auto *L = dyn_cast<LoadSDNode>(Val);
15845         L && DCI.isBeforeLegalize() && IsScalarizable && L->isSimple() &&
15846         L->hasNUsesOfValue(1, 0) && L->hasNUsesOfValue(1, 1) &&
15847         Store->getChain() == SDValue(L, 1) && ISD::isNormalLoad(L) &&
15848         L->getMemoryVT() == MemVT) {
15849       MVT NewVT = MVT::getIntegerVT(MemVT.getSizeInBits());
15850       if (allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
15851                                          NewVT, *Store->getMemOperand()) &&
15852           allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
15853                                          NewVT, *L->getMemOperand())) {
15854         SDValue NewL = DAG.getLoad(NewVT, DL, L->getChain(), L->getBasePtr(),
15855                                    L->getPointerInfo(), L->getOriginalAlign(),
15856                                    L->getMemOperand()->getFlags());
15857         return DAG.getStore(Chain, DL, NewL, Store->getBasePtr(),
15858                             Store->getPointerInfo(), Store->getOriginalAlign(),
15859                             Store->getMemOperand()->getFlags());
15860       }
15861     }
15862 
15863     // Combine store of vmv.x.s/vfmv.f.s to vse with VL of 1.
15864     // vfmv.f.s is represented as extract element from 0. Match it late to avoid
15865     // any illegal types.
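    // As a rough sketch (assuming e32 elements; register names are
    // illustrative), storing the result of (vmv.x.s v8) through a0 becomes:
    //   vsetivli zero, 1, e32, m1, ta, ma
    //   vse32.v  v8, (a0)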
15866     if (Val.getOpcode() == RISCVISD::VMV_X_S ||
15867         (DCI.isAfterLegalizeDAG() &&
15868          Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
15869          isNullConstant(Val.getOperand(1)))) {
15870       SDValue Src = Val.getOperand(0);
15871       MVT VecVT = Src.getSimpleValueType();
15872       // VecVT should be scalable and memory VT should match the element type.
15873       if (!Store->isIndexed() && VecVT.isScalableVector() &&
15874           MemVT == VecVT.getVectorElementType()) {
15875         SDLoc DL(N);
15876         MVT MaskVT = getMaskTypeFor(VecVT);
15877         return DAG.getStoreVP(
15878             Store->getChain(), DL, Src, Store->getBasePtr(), Store->getOffset(),
15879             DAG.getConstant(1, DL, MaskVT),
15880             DAG.getConstant(1, DL, Subtarget.getXLenVT()), MemVT,
15881             Store->getMemOperand(), Store->getAddressingMode(),
15882             Store->isTruncatingStore(), /*IsCompress*/ false);
15883       }
15884     }
15885 
15886     break;
15887   }
15888   case ISD::SPLAT_VECTOR: {
15889     EVT VT = N->getValueType(0);
15890     // Only perform this combine on legal MVT types.
15891     if (!isTypeLegal(VT))
15892       break;
15893     if (auto Gather = matchSplatAsGather(N->getOperand(0), VT.getSimpleVT(), N,
15894                                          DAG, Subtarget))
15895       return Gather;
15896     break;
15897   }
15898   case ISD::BUILD_VECTOR:
15899     if (SDValue V = performBUILD_VECTORCombine(N, DAG, Subtarget, *this))
15900       return V;
15901     break;
15902   case ISD::CONCAT_VECTORS:
15903     if (SDValue V = performCONCAT_VECTORSCombine(N, DAG, Subtarget, *this))
15904       return V;
15905     break;
15906   case ISD::INSERT_VECTOR_ELT:
15907     if (SDValue V = performINSERT_VECTOR_ELTCombine(N, DAG, Subtarget, *this))
15908       return V;
15909     break;
15910   case RISCVISD::VFMV_V_F_VL: {
15911     const MVT VT = N->getSimpleValueType(0);
15912     SDValue Passthru = N->getOperand(0);
15913     SDValue Scalar = N->getOperand(1);
15914     SDValue VL = N->getOperand(2);
15915 
15916     // If VL is 1, we can use vfmv.s.f.
15917     if (isOneConstant(VL))
15918       return DAG.getNode(RISCVISD::VFMV_S_F_VL, DL, VT, Passthru, Scalar, VL);
15919     break;
15920   }
15921   case RISCVISD::VMV_V_X_VL: {
15922     const MVT VT = N->getSimpleValueType(0);
15923     SDValue Passthru = N->getOperand(0);
15924     SDValue Scalar = N->getOperand(1);
15925     SDValue VL = N->getOperand(2);
15926 
15927     // Tail agnostic VMV.V.X only demands the vector element bitwidth from the
15928     // scalar input.
15929     unsigned ScalarSize = Scalar.getValueSizeInBits();
15930     unsigned EltWidth = VT.getScalarSizeInBits();
15931     if (ScalarSize > EltWidth && Passthru.isUndef())
15932       if (SimplifyDemandedLowBitsHelper(1, EltWidth))
15933         return SDValue(N, 0);
15934 
15935     // If VL is 1 and the scalar value won't benefit from an immediate, we can
15936     // use vmv.s.x.
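    // For instance, (vmv.v.x passthru, 0x1234, VL=1) is rewritten to use
    // vmv.s.x, while a small nonzero constant such as 3 is left alone since
    // the vmv.v.i form already handles it.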
15937     ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Scalar);
15938     if (isOneConstant(VL) &&
15939         (!Const || Const->isZero() ||
15940          !Const->getAPIntValue().sextOrTrunc(EltWidth).isSignedIntN(5)))
15941       return DAG.getNode(RISCVISD::VMV_S_X_VL, DL, VT, Passthru, Scalar, VL);
15942 
15943     break;
15944   }
15945   case RISCVISD::VFMV_S_F_VL: {
15946     SDValue Src = N->getOperand(1);
15947     // Try to remove vector->scalar->vector if the scalar->vector is inserting
15948     // into an undef vector.
15949     // TODO: Could use a vslide or vmv.v.v for non-undef.
15950     if (N->getOperand(0).isUndef() &&
15951         Src.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
15952         isNullConstant(Src.getOperand(1)) &&
15953         Src.getOperand(0).getValueType().isScalableVector()) {
15954       EVT VT = N->getValueType(0);
15955       EVT SrcVT = Src.getOperand(0).getValueType();
15956       assert(SrcVT.getVectorElementType() == VT.getVectorElementType());
15957       // Widths match, just return the original vector.
15958       if (SrcVT == VT)
15959         return Src.getOperand(0);
15960       // TODO: Use insert_subvector/extract_subvector to change widen/narrow?
15961     }
15962     [[fallthrough]];
15963   }
15964   case RISCVISD::VMV_S_X_VL: {
15965     const MVT VT = N->getSimpleValueType(0);
15966     SDValue Passthru = N->getOperand(0);
15967     SDValue Scalar = N->getOperand(1);
15968     SDValue VL = N->getOperand(2);
15969 
15970     // Use M1 or smaller to avoid over constraining register allocation
15971     const MVT M1VT = getLMUL1VT(VT);
15972     if (M1VT.bitsLT(VT)) {
15973       SDValue M1Passthru =
15974           DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, M1VT, Passthru,
15975                       DAG.getVectorIdxConstant(0, DL));
15976       SDValue Result =
15977           DAG.getNode(N->getOpcode(), DL, M1VT, M1Passthru, Scalar, VL);
15978       Result = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Passthru, Result,
15979                            DAG.getConstant(0, DL, XLenVT));
15980       return Result;
15981     }
15982 
15983     // We use a vmv.v.i if possible.  We limit this to LMUL1.  LMUL2 or
15984     // higher would involve overly constraining the register allocator for
15985     // no purpose.
15986     if (ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Scalar);
15987         Const && !Const->isZero() && isInt<5>(Const->getSExtValue()) &&
15988         VT.bitsLE(getLMUL1VT(VT)) && Passthru.isUndef())
15989       return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, Passthru, Scalar, VL);
15990 
15991     break;
15992   }
15993   case RISCVISD::VMV_X_S: {
15994     SDValue Vec = N->getOperand(0);
15995     MVT VecVT = N->getOperand(0).getSimpleValueType();
15996     const MVT M1VT = getLMUL1VT(VecVT);
15997     if (M1VT.bitsLT(VecVT)) {
15998       Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, M1VT, Vec,
15999                         DAG.getVectorIdxConstant(0, DL));
16000       return DAG.getNode(RISCVISD::VMV_X_S, DL, N->getSimpleValueType(0), Vec);
16001     }
16002     break;
16003   }
16004   case ISD::INTRINSIC_VOID:
16005   case ISD::INTRINSIC_W_CHAIN:
16006   case ISD::INTRINSIC_WO_CHAIN: {
16007     unsigned IntOpNo = N->getOpcode() == ISD::INTRINSIC_WO_CHAIN ? 0 : 1;
16008     unsigned IntNo = N->getConstantOperandVal(IntOpNo);
16009     switch (IntNo) {
16010       // By default we do not combine any intrinsic.
16011     default:
16012       return SDValue();
16013     case Intrinsic::riscv_masked_strided_load: {
16014       MVT VT = N->getSimpleValueType(0);
16015       auto *Load = cast<MemIntrinsicSDNode>(N);
16016       SDValue PassThru = N->getOperand(2);
16017       SDValue Base = N->getOperand(3);
16018       SDValue Stride = N->getOperand(4);
16019       SDValue Mask = N->getOperand(5);
16020 
16021       // If the stride is equal to the element size in bytes, we can use
16022       // a masked.load.
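      // (For instance, a hypothetical call like
      //    riscv.masked.strided.load.v4i32(passthru, p, /*stride*/ 4, mask)
      //  is effectively unit-stride and becomes a plain masked.load.)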
16023       const unsigned ElementSize = VT.getScalarStoreSize();
16024       if (auto *StrideC = dyn_cast<ConstantSDNode>(Stride);
16025           StrideC && StrideC->getZExtValue() == ElementSize)
16026         return DAG.getMaskedLoad(VT, DL, Load->getChain(), Base,
16027                                  DAG.getUNDEF(XLenVT), Mask, PassThru,
16028                                  Load->getMemoryVT(), Load->getMemOperand(),
16029                                  ISD::UNINDEXED, ISD::NON_EXTLOAD);
16030       return SDValue();
16031     }
16032     case Intrinsic::riscv_masked_strided_store: {
16033       auto *Store = cast<MemIntrinsicSDNode>(N);
16034       SDValue Value = N->getOperand(2);
16035       SDValue Base = N->getOperand(3);
16036       SDValue Stride = N->getOperand(4);
16037       SDValue Mask = N->getOperand(5);
16038 
16039       // If the stride is equal to the element size in bytes, we can use
16040       // a masked.store.
16041       const unsigned ElementSize = Value.getValueType().getScalarStoreSize();
16042       if (auto *StrideC = dyn_cast<ConstantSDNode>(Stride);
16043           StrideC && StrideC->getZExtValue() == ElementSize)
16044         return DAG.getMaskedStore(Store->getChain(), DL, Value, Base,
16045                                   DAG.getUNDEF(XLenVT), Mask,
16046                                   Store->getMemoryVT(), Store->getMemOperand(),
16047                                   ISD::UNINDEXED, false);
16048       return SDValue();
16049     }
16050     case Intrinsic::riscv_vcpop:
16051     case Intrinsic::riscv_vcpop_mask:
16052     case Intrinsic::riscv_vfirst:
16053     case Intrinsic::riscv_vfirst_mask: {
16054       SDValue VL = N->getOperand(2);
16055       if (IntNo == Intrinsic::riscv_vcpop_mask ||
16056           IntNo == Intrinsic::riscv_vfirst_mask)
16057         VL = N->getOperand(3);
16058       if (!isNullConstant(VL))
16059         return SDValue();
16060       // If VL is 0, vcpop -> li 0, vfirst -> li -1.
16061       SDLoc DL(N);
16062       EVT VT = N->getValueType(0);
16063       if (IntNo == Intrinsic::riscv_vfirst ||
16064           IntNo == Intrinsic::riscv_vfirst_mask)
16065         return DAG.getConstant(-1, DL, VT);
16066       return DAG.getConstant(0, DL, VT);
16067     }
16068     }
16069   }
16070   case ISD::BITCAST: {
16071     assert(Subtarget.useRVVForFixedLengthVectors());
16072     SDValue N0 = N->getOperand(0);
16073     EVT VT = N->getValueType(0);
16074     EVT SrcVT = N0.getValueType();
16075     // If this is a bitcast between a MVT::v4i1/v2i1/v1i1 and an illegal integer
16076     // type, widen both sides to avoid a trip through memory.
16077     if ((SrcVT == MVT::v1i1 || SrcVT == MVT::v2i1 || SrcVT == MVT::v4i1) &&
16078         VT.isScalarInteger()) {
16079       unsigned NumConcats = 8 / SrcVT.getVectorNumElements();
16080       SmallVector<SDValue, 4> Ops(NumConcats, DAG.getUNDEF(SrcVT));
16081       Ops[0] = N0;
16082       SDLoc DL(N);
16083       N0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v8i1, Ops);
16084       N0 = DAG.getBitcast(MVT::i8, N0);
16085       return DAG.getNode(ISD::TRUNCATE, DL, VT, N0);
16086     }
16087 
16088     return SDValue();
16089   }
16090   }
16091 
16092   return SDValue();
16093 }
16094 
16095 bool RISCVTargetLowering::shouldTransformSignedTruncationCheck(
16096     EVT XVT, unsigned KeptBits) const {
16097   // For vectors, we don't have a preference.
16098   if (XVT.isVector())
16099     return false;
16100 
16101   if (XVT != MVT::i32 && XVT != MVT::i64)
16102     return false;
16103 
16104   // We can use sext.w for RV64 or an srai 31 on RV32.
16105   if (KeptBits == 32 || KeptBits == 64)
16106     return true;
16107 
16108   // With Zbb we can use sext.h/sext.b.
16109   return Subtarget.hasStdExtZbb() &&
16110          ((KeptBits == 8 && XVT == MVT::i64 && !Subtarget.is64Bit()) ||
16111           KeptBits == 16);
16112 }
16113 
16114 bool RISCVTargetLowering::isDesirableToCommuteWithShift(
16115     const SDNode *N, CombineLevel Level) const {
16116   assert((N->getOpcode() == ISD::SHL || N->getOpcode() == ISD::SRA ||
16117           N->getOpcode() == ISD::SRL) &&
16118          "Expected shift op");
16119 
16120   // The following folds are only desirable if `(OP _, c1 << c2)` can be
16121   // materialised in fewer instructions than `(OP _, c1)`:
16122   //
16123   //   (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2)
16124   //   (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2)
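  // For instance (a sketch; exact costs come from RISCVMatInt), the fold of
  // (shl (add x, 2047), 2) is blocked because 2047 fits an ADDI while
  // 2047 << 2 does not, whereas for (shl (add x, 4096), 2) it is allowed since
  // both 4096 and 16384 cost a single LUI.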
16125   SDValue N0 = N->getOperand(0);
16126   EVT Ty = N0.getValueType();
16127   if (Ty.isScalarInteger() &&
16128       (N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::OR)) {
16129     auto *C1 = dyn_cast<ConstantSDNode>(N0->getOperand(1));
16130     auto *C2 = dyn_cast<ConstantSDNode>(N->getOperand(1));
16131     if (C1 && C2) {
16132       const APInt &C1Int = C1->getAPIntValue();
16133       APInt ShiftedC1Int = C1Int << C2->getAPIntValue();
16134 
16135       // We can materialise `c1 << c2` into an add immediate, so it's "free",
16136       // and the combine should happen, to potentially allow further combines
16137       // later.
16138       if (ShiftedC1Int.getSignificantBits() <= 64 &&
16139           isLegalAddImmediate(ShiftedC1Int.getSExtValue()))
16140         return true;
16141 
16142       // We can materialise `c1` in an add immediate, so it's "free", and the
16143       // combine should be prevented.
16144       if (C1Int.getSignificantBits() <= 64 &&
16145           isLegalAddImmediate(C1Int.getSExtValue()))
16146         return false;
16147 
16148       // Neither constant will fit into an immediate, so find materialisation
16149       // costs.
16150       int C1Cost =
16151           RISCVMatInt::getIntMatCost(C1Int, Ty.getSizeInBits(), Subtarget,
16152                                      /*CompressionCost*/ true);
16153       int ShiftedC1Cost = RISCVMatInt::getIntMatCost(
16154           ShiftedC1Int, Ty.getSizeInBits(), Subtarget,
16155           /*CompressionCost*/ true);
16156 
16157       // Materialising `c1` is cheaper than materialising `c1 << c2`, so the
16158       // combine should be prevented.
16159       if (C1Cost < ShiftedC1Cost)
16160         return false;
16161     }
16162   }
16163   return true;
16164 }
16165 
16166 bool RISCVTargetLowering::targetShrinkDemandedConstant(
16167     SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
16168     TargetLoweringOpt &TLO) const {
16169   // Delay this optimization as late as possible.
16170   if (!TLO.LegalOps)
16171     return false;
16172 
16173   EVT VT = Op.getValueType();
16174   if (VT.isVector())
16175     return false;
16176 
16177   unsigned Opcode = Op.getOpcode();
16178   if (Opcode != ISD::AND && Opcode != ISD::OR && Opcode != ISD::XOR)
16179     return false;
16180 
16181   ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op.getOperand(1));
16182   if (!C)
16183     return false;
16184 
16185   const APInt &Mask = C->getAPIntValue();
16186 
16187   // Clear all non-demanded bits initially.
16188   APInt ShrunkMask = Mask & DemandedBits;
16189 
16190   // Try to make a smaller immediate by setting undemanded bits.
16191 
16192   APInt ExpandedMask = Mask | ~DemandedBits;
16193 
16194   auto IsLegalMask = [ShrunkMask, ExpandedMask](const APInt &Mask) -> bool {
16195     return ShrunkMask.isSubsetOf(Mask) && Mask.isSubsetOf(ExpandedMask);
16196   };
16197   auto UseMask = [Mask, Op, &TLO](const APInt &NewMask) -> bool {
16198     if (NewMask == Mask)
16199       return true;
16200     SDLoc DL(Op);
16201     SDValue NewC = TLO.DAG.getConstant(NewMask, DL, Op.getValueType());
16202     SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), DL, Op.getValueType(),
16203                                     Op.getOperand(0), NewC);
16204     return TLO.CombineTo(Op, NewOp);
16205   };
16206 
16207   // If the shrunk mask fits in sign extended 12 bits, let the target
16208   // independent code apply it.
16209   if (ShrunkMask.isSignedIntN(12))
16210     return false;
16211 
16212   // And has a few special cases for zext.
16213   if (Opcode == ISD::AND) {
16214     // Preserve (and X, 0xffff), if zext.h exists use zext.h,
16215     // otherwise use SLLI + SRLI.
16216     APInt NewMask = APInt(Mask.getBitWidth(), 0xffff);
16217     if (IsLegalMask(NewMask))
16218       return UseMask(NewMask);
16219 
16220     // Try to preserve (and X, 0xffffffff), the (zext_inreg X, i32) pattern.
16221     if (VT == MVT::i64) {
16222       APInt NewMask = APInt(64, 0xffffffff);
16223       if (IsLegalMask(NewMask))
16224         return UseMask(NewMask);
16225     }
16226   }
16227 
16228   // For the remaining optimizations, we need to be able to make a negative
16229   // number through a combination of mask and undemanded bits.
16230   if (!ExpandedMask.isNegative())
16231     return false;
16232 
16233   // Find the fewest number of bits needed to represent the negative number.
16234   unsigned MinSignedBits = ExpandedMask.getSignificantBits();
16235 
16236   // Try to make a 12 bit negative immediate. If that fails try to make a 32
16237   // bit negative immediate unless the shrunk immediate already fits in 32 bits.
16238   // If we can't create a simm12, we shouldn't change opaque constants.
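  // For illustration (assuming an i64 operation): for (xor X, 0x800) where
  // only the low 12 bits are demanded, the constant can be widened to
  // 0xFFFFFFFFFFFFF800 (-2048), which fits in an XORI immediate.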
16239   APInt NewMask = ShrunkMask;
16240   if (MinSignedBits <= 12)
16241     NewMask.setBitsFrom(11);
16242   else if (!C->isOpaque() && MinSignedBits <= 32 && !ShrunkMask.isSignedIntN(32))
16243     NewMask.setBitsFrom(31);
16244   else
16245     return false;
16246 
16247   // Check that our new mask is a subset of the demanded mask.
16248   assert(IsLegalMask(NewMask));
16249   return UseMask(NewMask);
16250 }
16251 
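// Worked example (input values chosen arbitrarily): with a control value of 7
// this matches orc.b / brev8, e.g.
//   computeGREVOrGORC(0x0102000000000003, 7, /*IsGORC=*/true)  == 0xFFFF0000000000FF
//   computeGREVOrGORC(0x0000000000000001, 7, /*IsGORC=*/false) == 0x0000000000000080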
16252 static uint64_t computeGREVOrGORC(uint64_t x, unsigned ShAmt, bool IsGORC) {
16253   static const uint64_t GREVMasks[] = {
16254       0x5555555555555555ULL, 0x3333333333333333ULL, 0x0F0F0F0F0F0F0F0FULL,
16255       0x00FF00FF00FF00FFULL, 0x0000FFFF0000FFFFULL, 0x00000000FFFFFFFFULL};
16256 
16257   for (unsigned Stage = 0; Stage != 6; ++Stage) {
16258     unsigned Shift = 1 << Stage;
16259     if (ShAmt & Shift) {
16260       uint64_t Mask = GREVMasks[Stage];
16261       uint64_t Res = ((x & Mask) << Shift) | ((x >> Shift) & Mask);
16262       if (IsGORC)
16263         Res |= x;
16264       x = Res;
16265     }
16266   }
16267 
16268   return x;
16269 }
16270 
16271 void RISCVTargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
16272                                                         KnownBits &Known,
16273                                                         const APInt &DemandedElts,
16274                                                         const SelectionDAG &DAG,
16275                                                         unsigned Depth) const {
16276   unsigned BitWidth = Known.getBitWidth();
16277   unsigned Opc = Op.getOpcode();
16278   assert((Opc >= ISD::BUILTIN_OP_END ||
16279           Opc == ISD::INTRINSIC_WO_CHAIN ||
16280           Opc == ISD::INTRINSIC_W_CHAIN ||
16281           Opc == ISD::INTRINSIC_VOID) &&
16282          "Should use MaskedValueIsZero if you don't know whether Op"
16283          " is a target node!");
16284 
16285   Known.resetAll();
16286   switch (Opc) {
16287   default: break;
16288   case RISCVISD::SELECT_CC: {
16289     Known = DAG.computeKnownBits(Op.getOperand(4), Depth + 1);
16290     // If we don't know any bits, early out.
16291     if (Known.isUnknown())
16292       break;
16293     KnownBits Known2 = DAG.computeKnownBits(Op.getOperand(3), Depth + 1);
16294 
16295     // Only known if known in both the LHS and RHS.
16296     Known = Known.intersectWith(Known2);
16297     break;
16298   }
16299   case RISCVISD::CZERO_EQZ:
16300   case RISCVISD::CZERO_NEZ:
16301     Known = DAG.computeKnownBits(Op.getOperand(0), Depth + 1);
16302     // Result is either all zero or operand 0. We can propagate zeros, but not
16303     // ones.
16304     Known.One.clearAllBits();
16305     break;
16306   case RISCVISD::REMUW: {
16307     KnownBits Known2;
16308     Known = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
16309     Known2 = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
16310     // We only care about the lower 32 bits.
16311     Known = KnownBits::urem(Known.trunc(32), Known2.trunc(32));
16312     // Restore the original width by sign extending.
16313     Known = Known.sext(BitWidth);
16314     break;
16315   }
16316   case RISCVISD::DIVUW: {
16317     KnownBits Known2;
16318     Known = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
16319     Known2 = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
16320     // We only care about the lower 32 bits.
16321     Known = KnownBits::udiv(Known.trunc(32), Known2.trunc(32));
16322     // Restore the original width by sign extending.
16323     Known = Known.sext(BitWidth);
16324     break;
16325   }
16326   case RISCVISD::SLLW: {
16327     KnownBits Known2;
16328     Known = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
16329     Known2 = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
16330     Known = KnownBits::shl(Known.trunc(32), Known2.trunc(5).zext(32));
16331     // Restore the original width by sign extending.
16332     Known = Known.sext(BitWidth);
16333     break;
16334   }
16335   case RISCVISD::CTZW: {
16336     KnownBits Known2 = DAG.computeKnownBits(Op.getOperand(0), Depth + 1);
16337     unsigned PossibleTZ = Known2.trunc(32).countMaxTrailingZeros();
16338     unsigned LowBits = llvm::bit_width(PossibleTZ);
16339     Known.Zero.setBitsFrom(LowBits);
16340     break;
16341   }
16342   case RISCVISD::CLZW: {
16343     KnownBits Known2 = DAG.computeKnownBits(Op.getOperand(0), Depth + 1);
16344     unsigned PossibleLZ = Known2.trunc(32).countMaxLeadingZeros();
16345     unsigned LowBits = llvm::bit_width(PossibleLZ);
16346     Known.Zero.setBitsFrom(LowBits);
16347     break;
16348   }
16349   case RISCVISD::BREV8:
16350   case RISCVISD::ORC_B: {
16351     // FIXME: This is based on the non-ratified Zbp GREV and GORC where a
16352     // control value of 7 is equivalent to brev8 and orc.b.
16353     Known = DAG.computeKnownBits(Op.getOperand(0), Depth + 1);
16354     bool IsGORC = Op.getOpcode() == RISCVISD::ORC_B;
16355     // To compute zeros, we need to invert the value and invert it back after.
16356     Known.Zero =
16357         ~computeGREVOrGORC(~Known.Zero.getZExtValue(), 7, IsGORC);
16358     Known.One = computeGREVOrGORC(Known.One.getZExtValue(), 7, IsGORC);
16359     break;
16360   }
16361   case RISCVISD::READ_VLENB: {
16362     // We can use the minimum and maximum VLEN values to bound VLENB.  We
16363     // know VLEN must be a power of two.
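    // For example (assuming a Zvl128b minimum and a 512-bit maximum), VLENB is
    // a power of two in [16, 64], so bits [3:0] and all bits above bit 6 are
    // known to be zero.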
16364     const unsigned MinVLenB = Subtarget.getRealMinVLen() / 8;
16365     const unsigned MaxVLenB = Subtarget.getRealMaxVLen() / 8;
16366     assert(MinVLenB > 0 && "READ_VLENB without vector extension enabled?");
16367     Known.Zero.setLowBits(Log2_32(MinVLenB));
16368     Known.Zero.setBitsFrom(Log2_32(MaxVLenB)+1);
16369     if (MaxVLenB == MinVLenB)
16370       Known.One.setBit(Log2_32(MinVLenB));
16371     break;
16372   }
16373   case RISCVISD::FCLASS: {
16374     // fclass will only set one of the low 10 bits.
16375     Known.Zero.setBitsFrom(10);
16376     break;
16377   }
16378   case ISD::INTRINSIC_W_CHAIN:
16379   case ISD::INTRINSIC_WO_CHAIN: {
16380     unsigned IntNo =
16381         Op.getConstantOperandVal(Opc == ISD::INTRINSIC_WO_CHAIN ? 0 : 1);
16382     switch (IntNo) {
16383     default:
16384       // We can't do anything for most intrinsics.
16385       break;
16386     case Intrinsic::riscv_vsetvli:
16387     case Intrinsic::riscv_vsetvlimax: {
16388       bool HasAVL = IntNo == Intrinsic::riscv_vsetvli;
16389       unsigned VSEW = Op.getConstantOperandVal(HasAVL + 1);
16390       RISCVII::VLMUL VLMUL =
16391           static_cast<RISCVII::VLMUL>(Op.getConstantOperandVal(HasAVL + 2));
16392       unsigned SEW = RISCVVType::decodeVSEW(VSEW);
16393       auto [LMul, Fractional] = RISCVVType::decodeVLMUL(VLMUL);
16394       uint64_t MaxVL = Subtarget.getRealMaxVLen() / SEW;
16395       MaxVL = (Fractional) ? MaxVL / LMul : MaxVL * LMul;
16396 
16397       // The result of vsetvli must not be larger than AVL.
16398       if (HasAVL && isa<ConstantSDNode>(Op.getOperand(1)))
16399         MaxVL = std::min(MaxVL, Op.getConstantOperandVal(1));
16400 
16401       unsigned KnownZeroFirstBit = Log2_32(MaxVL) + 1;
16402       if (BitWidth > KnownZeroFirstBit)
16403         Known.Zero.setBitsFrom(KnownZeroFirstBit);
16404       break;
16405     }
16406     }
16407     break;
16408   }
16409   }
16410 }
16411 
16412 unsigned RISCVTargetLowering::ComputeNumSignBitsForTargetNode(
16413     SDValue Op, const APInt &DemandedElts, const SelectionDAG &DAG,
16414     unsigned Depth) const {
16415   switch (Op.getOpcode()) {
16416   default:
16417     break;
16418   case RISCVISD::SELECT_CC: {
16419     unsigned Tmp =
16420         DAG.ComputeNumSignBits(Op.getOperand(3), DemandedElts, Depth + 1);
16421     if (Tmp == 1) return 1;  // Early out.
16422     unsigned Tmp2 =
16423         DAG.ComputeNumSignBits(Op.getOperand(4), DemandedElts, Depth + 1);
16424     return std::min(Tmp, Tmp2);
16425   }
16426   case RISCVISD::CZERO_EQZ:
16427   case RISCVISD::CZERO_NEZ:
16428     // Output is either all zero or operand 0. We can propagate sign bit count
16429     // from operand 0.
16430     return DAG.ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
16431   case RISCVISD::ABSW: {
16432     // We expand this at isel to negw+max. The result will have 33 sign bits
16433     // if the input has at least 33 sign bits.
16434     unsigned Tmp =
16435         DAG.ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
16436     if (Tmp < 33) return 1;
16437     return 33;
16438   }
16439   case RISCVISD::SLLW:
16440   case RISCVISD::SRAW:
16441   case RISCVISD::SRLW:
16442   case RISCVISD::DIVW:
16443   case RISCVISD::DIVUW:
16444   case RISCVISD::REMUW:
16445   case RISCVISD::ROLW:
16446   case RISCVISD::RORW:
16447   case RISCVISD::FCVT_W_RV64:
16448   case RISCVISD::FCVT_WU_RV64:
16449   case RISCVISD::STRICT_FCVT_W_RV64:
16450   case RISCVISD::STRICT_FCVT_WU_RV64:
16451     // TODO: As the result is sign-extended, this is conservatively correct. A
16452     // more precise answer could be calculated for SRAW depending on known
16453     // bits in the shift amount.
16454     return 33;
16455   case RISCVISD::VMV_X_S: {
16456     // The number of sign bits of the scalar result is computed by obtaining the
16457     // element type of the input vector operand, subtracting its width from the
16458     // XLEN, and then adding one (sign bit within the element type). If the
16459     // element type is wider than XLen, the least-significant XLEN bits are
16460     // taken.
16461     unsigned XLen = Subtarget.getXLen();
16462     unsigned EltBits = Op.getOperand(0).getScalarValueSizeInBits();
16463     if (EltBits <= XLen)
16464       return XLen - EltBits + 1;
16465     break;
16466   }
16467   case ISD::INTRINSIC_W_CHAIN: {
16468     unsigned IntNo = Op.getConstantOperandVal(1);
16469     switch (IntNo) {
16470     default:
16471       break;
16472     case Intrinsic::riscv_masked_atomicrmw_xchg_i64:
16473     case Intrinsic::riscv_masked_atomicrmw_add_i64:
16474     case Intrinsic::riscv_masked_atomicrmw_sub_i64:
16475     case Intrinsic::riscv_masked_atomicrmw_nand_i64:
16476     case Intrinsic::riscv_masked_atomicrmw_max_i64:
16477     case Intrinsic::riscv_masked_atomicrmw_min_i64:
16478     case Intrinsic::riscv_masked_atomicrmw_umax_i64:
16479     case Intrinsic::riscv_masked_atomicrmw_umin_i64:
16480     case Intrinsic::riscv_masked_cmpxchg_i64:
16481       // riscv_masked_{atomicrmw_*,cmpxchg} intrinsics represent an emulated
16482       // narrow atomic operation. These are implemented using atomic
16483       // operations at the minimum supported atomicrmw/cmpxchg width whose
16484       // result is then sign extended to XLEN. With +A, the minimum width is
16485       // 32 for both RV64 and RV32.
16486       assert(Subtarget.getXLen() == 64);
16487       assert(getMinCmpXchgSizeInBits() == 32);
16488       assert(Subtarget.hasStdExtA());
16489       return 33;
16490     }
16491     break;
16492   }
16493   }
16494 
16495   return 1;
16496 }
16497 
16498 const Constant *
16499 RISCVTargetLowering::getTargetConstantFromLoad(LoadSDNode *Ld) const {
16500   assert(Ld && "Unexpected null LoadSDNode");
16501   if (!ISD::isNormalLoad(Ld))
16502     return nullptr;
16503 
16504   SDValue Ptr = Ld->getBasePtr();
16505 
16506   // Only constant pools with no offset are supported.
16507   auto GetSupportedConstantPool = [](SDValue Ptr) -> ConstantPoolSDNode * {
16508     auto *CNode = dyn_cast<ConstantPoolSDNode>(Ptr);
16509     if (!CNode || CNode->isMachineConstantPoolEntry() ||
16510         CNode->getOffset() != 0)
16511       return nullptr;
16512 
16513     return CNode;
16514   };
16515 
16516   // Simple case, LLA.
16517   if (Ptr.getOpcode() == RISCVISD::LLA) {
16518     auto *CNode = GetSupportedConstantPool(Ptr);
16519     if (!CNode || CNode->getTargetFlags() != 0)
16520       return nullptr;
16521 
16522     return CNode->getConstVal();
16523   }
16524 
16525   // Look for a HI and ADD_LO pair.
16526   if (Ptr.getOpcode() != RISCVISD::ADD_LO ||
16527       Ptr.getOperand(0).getOpcode() != RISCVISD::HI)
16528     return nullptr;
16529 
16530   auto *CNodeLo = GetSupportedConstantPool(Ptr.getOperand(1));
16531   auto *CNodeHi = GetSupportedConstantPool(Ptr.getOperand(0).getOperand(0));
16532 
16533   if (!CNodeLo || CNodeLo->getTargetFlags() != RISCVII::MO_LO ||
16534       !CNodeHi || CNodeHi->getTargetFlags() != RISCVII::MO_HI)
16535     return nullptr;
16536 
16537   if (CNodeLo->getConstVal() != CNodeHi->getConstVal())
16538     return nullptr;
16539 
16540   return CNodeLo->getConstVal();
16541 }
16542 
16543 static MachineBasicBlock *emitReadCycleWidePseudo(MachineInstr &MI,
16544                                                   MachineBasicBlock *BB) {
16545   assert(MI.getOpcode() == RISCV::ReadCycleWide && "Unexpected instruction");
16546 
16547   // To read the 64-bit cycle CSR on a 32-bit target, we read the two halves.
16548   // Should the count have wrapped while it was being read, we need to try
16549   // again.
16550   // ...
16551   // read:
16552   // rdcycleh x3 # load high word of cycle
16553   // rdcycle  x2 # load low word of cycle
16554   // rdcycleh x4 # load high word of cycle
16555   // bne x3, x4, read # check if high word reads match, otherwise try again
16556   // ...
16557 
16558   MachineFunction &MF = *BB->getParent();
16559   const BasicBlock *LLVM_BB = BB->getBasicBlock();
16560   MachineFunction::iterator It = ++BB->getIterator();
16561 
16562   MachineBasicBlock *LoopMBB = MF.CreateMachineBasicBlock(LLVM_BB);
16563   MF.insert(It, LoopMBB);
16564 
16565   MachineBasicBlock *DoneMBB = MF.CreateMachineBasicBlock(LLVM_BB);
16566   MF.insert(It, DoneMBB);
16567 
16568   // Transfer the remainder of BB and its successor edges to DoneMBB.
16569   DoneMBB->splice(DoneMBB->begin(), BB,
16570                   std::next(MachineBasicBlock::iterator(MI)), BB->end());
16571   DoneMBB->transferSuccessorsAndUpdatePHIs(BB);
16572 
16573   BB->addSuccessor(LoopMBB);
16574 
16575   MachineRegisterInfo &RegInfo = MF.getRegInfo();
16576   Register ReadAgainReg = RegInfo.createVirtualRegister(&RISCV::GPRRegClass);
16577   Register LoReg = MI.getOperand(0).getReg();
16578   Register HiReg = MI.getOperand(1).getReg();
16579   DebugLoc DL = MI.getDebugLoc();
16580 
16581   const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
16582   BuildMI(LoopMBB, DL, TII->get(RISCV::CSRRS), HiReg)
16583       .addImm(RISCVSysReg::lookupSysRegByName("CYCLEH")->Encoding)
16584       .addReg(RISCV::X0);
16585   BuildMI(LoopMBB, DL, TII->get(RISCV::CSRRS), LoReg)
16586       .addImm(RISCVSysReg::lookupSysRegByName("CYCLE")->Encoding)
16587       .addReg(RISCV::X0);
16588   BuildMI(LoopMBB, DL, TII->get(RISCV::CSRRS), ReadAgainReg)
16589       .addImm(RISCVSysReg::lookupSysRegByName("CYCLEH")->Encoding)
16590       .addReg(RISCV::X0);
16591 
16592   BuildMI(LoopMBB, DL, TII->get(RISCV::BNE))
16593       .addReg(HiReg)
16594       .addReg(ReadAgainReg)
16595       .addMBB(LoopMBB);
16596 
16597   LoopMBB->addSuccessor(LoopMBB);
16598   LoopMBB->addSuccessor(DoneMBB);
16599 
16600   MI.eraseFromParent();
16601 
16602   return DoneMBB;
16603 }
16604 
16605 static MachineBasicBlock *emitSplitF64Pseudo(MachineInstr &MI,
16606                                              MachineBasicBlock *BB,
16607                                              const RISCVSubtarget &Subtarget) {
16608   assert((MI.getOpcode() == RISCV::SplitF64Pseudo ||
16609           MI.getOpcode() == RISCV::SplitF64Pseudo_INX) &&
16610          "Unexpected instruction");
16611 
16612   MachineFunction &MF = *BB->getParent();
16613   DebugLoc DL = MI.getDebugLoc();
16614   const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo();
16615   const TargetRegisterInfo *RI = MF.getSubtarget().getRegisterInfo();
16616   Register LoReg = MI.getOperand(0).getReg();
16617   Register HiReg = MI.getOperand(1).getReg();
16618   Register SrcReg = MI.getOperand(2).getReg();
16619 
16620   const TargetRegisterClass *SrcRC = MI.getOpcode() == RISCV::SplitF64Pseudo_INX
16621                                          ? &RISCV::GPRPairRegClass
16622                                          : &RISCV::FPR64RegClass;
16623   int FI = MF.getInfo<RISCVMachineFunctionInfo>()->getMoveF64FrameIndex(MF);
16624 
16625   TII.storeRegToStackSlot(*BB, MI, SrcReg, MI.getOperand(2).isKill(), FI, SrcRC,
16626                           RI, Register());
16627   MachinePointerInfo MPI = MachinePointerInfo::getFixedStack(MF, FI);
16628   MachineMemOperand *MMOLo =
16629       MF.getMachineMemOperand(MPI, MachineMemOperand::MOLoad, 4, Align(8));
16630   MachineMemOperand *MMOHi = MF.getMachineMemOperand(
16631       MPI.getWithOffset(4), MachineMemOperand::MOLoad, 4, Align(8));
16632   BuildMI(*BB, MI, DL, TII.get(RISCV::LW), LoReg)
16633       .addFrameIndex(FI)
16634       .addImm(0)
16635       .addMemOperand(MMOLo);
16636   BuildMI(*BB, MI, DL, TII.get(RISCV::LW), HiReg)
16637       .addFrameIndex(FI)
16638       .addImm(4)
16639       .addMemOperand(MMOHi);
16640   MI.eraseFromParent(); // The pseudo instruction is gone now.
16641   return BB;
16642 }
16643 
16644 static MachineBasicBlock *emitBuildPairF64Pseudo(MachineInstr &MI,
16645                                                  MachineBasicBlock *BB,
16646                                                  const RISCVSubtarget &Subtarget) {
16647   assert((MI.getOpcode() == RISCV::BuildPairF64Pseudo ||
16648           MI.getOpcode() == RISCV::BuildPairF64Pseudo_INX) &&
16649          "Unexpected instruction");
16650 
16651   MachineFunction &MF = *BB->getParent();
16652   DebugLoc DL = MI.getDebugLoc();
16653   const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo();
16654   const TargetRegisterInfo *RI = MF.getSubtarget().getRegisterInfo();
16655   Register DstReg = MI.getOperand(0).getReg();
16656   Register LoReg = MI.getOperand(1).getReg();
16657   Register HiReg = MI.getOperand(2).getReg();
16658 
16659   const TargetRegisterClass *DstRC =
16660       MI.getOpcode() == RISCV::BuildPairF64Pseudo_INX ? &RISCV::GPRPairRegClass
16661                                                       : &RISCV::FPR64RegClass;
16662   int FI = MF.getInfo<RISCVMachineFunctionInfo>()->getMoveF64FrameIndex(MF);
16663 
16664   MachinePointerInfo MPI = MachinePointerInfo::getFixedStack(MF, FI);
16665   MachineMemOperand *MMOLo =
16666       MF.getMachineMemOperand(MPI, MachineMemOperand::MOStore, 4, Align(8));
16667   MachineMemOperand *MMOHi = MF.getMachineMemOperand(
16668       MPI.getWithOffset(4), MachineMemOperand::MOStore, 4, Align(8));
16669   BuildMI(*BB, MI, DL, TII.get(RISCV::SW))
16670       .addReg(LoReg, getKillRegState(MI.getOperand(1).isKill()))
16671       .addFrameIndex(FI)
16672       .addImm(0)
16673       .addMemOperand(MMOLo);
16674   BuildMI(*BB, MI, DL, TII.get(RISCV::SW))
16675       .addReg(HiReg, getKillRegState(MI.getOperand(2).isKill()))
16676       .addFrameIndex(FI)
16677       .addImm(4)
16678       .addMemOperand(MMOHi);
16679   TII.loadRegFromStackSlot(*BB, MI, DstReg, FI, DstRC, RI, Register());
16680   MI.eraseFromParent(); // The pseudo instruction is gone now.
16681   return BB;
16682 }
16683 
16684 static bool isSelectPseudo(MachineInstr &MI) {
16685   switch (MI.getOpcode()) {
16686   default:
16687     return false;
16688   case RISCV::Select_GPR_Using_CC_GPR:
16689   case RISCV::Select_FPR16_Using_CC_GPR:
16690   case RISCV::Select_FPR16INX_Using_CC_GPR:
16691   case RISCV::Select_FPR32_Using_CC_GPR:
16692   case RISCV::Select_FPR32INX_Using_CC_GPR:
16693   case RISCV::Select_FPR64_Using_CC_GPR:
16694   case RISCV::Select_FPR64INX_Using_CC_GPR:
16695   case RISCV::Select_FPR64IN32X_Using_CC_GPR:
16696     return true;
16697   }
16698 }
16699 
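// Expansion sketch for e.g. PseudoQuietFLE_S (register names illustrative):
//   saved = csrr  fflags
//   dst   = fle.s rs1, rs2     ; signals invalid on any NaN input
//   csrw  fflags, saved        ; drop that flag to keep the compare quiet
//   feq.s zero, rs1, rs2       ; re-raises invalid only for signaling NaNs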
16700 static MachineBasicBlock *emitQuietFCMP(MachineInstr &MI, MachineBasicBlock *BB,
16701                                         unsigned RelOpcode, unsigned EqOpcode,
16702                                         const RISCVSubtarget &Subtarget) {
16703   DebugLoc DL = MI.getDebugLoc();
16704   Register DstReg = MI.getOperand(0).getReg();
16705   Register Src1Reg = MI.getOperand(1).getReg();
16706   Register Src2Reg = MI.getOperand(2).getReg();
16707   MachineRegisterInfo &MRI = BB->getParent()->getRegInfo();
16708   Register SavedFFlags = MRI.createVirtualRegister(&RISCV::GPRRegClass);
16709   const TargetInstrInfo &TII = *BB->getParent()->getSubtarget().getInstrInfo();
16710 
16711   // Save the current FFLAGS.
16712   BuildMI(*BB, MI, DL, TII.get(RISCV::ReadFFLAGS), SavedFFlags);
16713 
16714   auto MIB = BuildMI(*BB, MI, DL, TII.get(RelOpcode), DstReg)
16715                  .addReg(Src1Reg)
16716                  .addReg(Src2Reg);
16717   if (MI.getFlag(MachineInstr::MIFlag::NoFPExcept))
16718     MIB->setFlag(MachineInstr::MIFlag::NoFPExcept);
16719 
16720   // Restore the FFLAGS.
16721   BuildMI(*BB, MI, DL, TII.get(RISCV::WriteFFLAGS))
16722       .addReg(SavedFFlags, RegState::Kill);
16723 
16724   // Issue a dummy FEQ opcode to raise exception for signaling NaNs.
16725   auto MIB2 = BuildMI(*BB, MI, DL, TII.get(EqOpcode), RISCV::X0)
16726                   .addReg(Src1Reg, getKillRegState(MI.getOperand(1).isKill()))
16727                   .addReg(Src2Reg, getKillRegState(MI.getOperand(2).isKill()));
16728   if (MI.getFlag(MachineInstr::MIFlag::NoFPExcept))
16729     MIB2->setFlag(MachineInstr::MIFlag::NoFPExcept);
16730 
16731   // Erase the pseudoinstruction.
16732   MI.eraseFromParent();
16733   return BB;
16734 }
16735 
16736 static MachineBasicBlock *
16737 EmitLoweredCascadedSelect(MachineInstr &First, MachineInstr &Second,
16738                           MachineBasicBlock *ThisMBB,
16739                           const RISCVSubtarget &Subtarget) {
16740   // Select_FPRX_ (rs1, rs2, imm, rs4, (Select_FPRX_ rs1, rs2, imm, rs4, rs5))
16741   // Without this, custom-inserter would have generated:
16742   //
16743   //   A
16744   //   | \
16745   //   |  B
16746   //   | /
16747   //   C
16748   //   | \
16749   //   |  D
16750   //   | /
16751   //   E
16752   //
16753   // A: X = ...; Y = ...
16754   // B: empty
16755   // C: Z = PHI [X, A], [Y, B]
16756   // D: empty
16757   // E: PHI [X, C], [Z, D]
16758   //
16759   // If we lower both Select_FPRX_ in a single step, we can instead generate:
16760   //
16761   //   A
16762   //   | \
16763   //   |  C
16764   //   | /|
16765   //   |/ |
16766   //   |  |
16767   //   |  D
16768   //   | /
16769   //   E
16770   //
16771   // A: X = ...; Y = ...
16772   // D: empty
16773   // E: PHI [X, A], [X, C], [Y, D]
16774 
16775   const RISCVInstrInfo &TII = *Subtarget.getInstrInfo();
16776   const DebugLoc &DL = First.getDebugLoc();
16777   const BasicBlock *LLVM_BB = ThisMBB->getBasicBlock();
16778   MachineFunction *F = ThisMBB->getParent();
16779   MachineBasicBlock *FirstMBB = F->CreateMachineBasicBlock(LLVM_BB);
16780   MachineBasicBlock *SecondMBB = F->CreateMachineBasicBlock(LLVM_BB);
16781   MachineBasicBlock *SinkMBB = F->CreateMachineBasicBlock(LLVM_BB);
16782   MachineFunction::iterator It = ++ThisMBB->getIterator();
16783   F->insert(It, FirstMBB);
16784   F->insert(It, SecondMBB);
16785   F->insert(It, SinkMBB);
16786 
16787   // Transfer the remainder of ThisMBB and its successor edges to SinkMBB.
16788   SinkMBB->splice(SinkMBB->begin(), ThisMBB,
16789                   std::next(MachineBasicBlock::iterator(First)),
16790                   ThisMBB->end());
16791   SinkMBB->transferSuccessorsAndUpdatePHIs(ThisMBB);
16792 
16793   // Fallthrough block for ThisMBB.
16794   ThisMBB->addSuccessor(FirstMBB);
16795   // Fallthrough block for FirstMBB.
16796   FirstMBB->addSuccessor(SecondMBB);
16797   ThisMBB->addSuccessor(SinkMBB);
16798   FirstMBB->addSuccessor(SinkMBB);
16799   // This is fallthrough.
16800   SecondMBB->addSuccessor(SinkMBB);
16801 
16802   auto FirstCC = static_cast<RISCVCC::CondCode>(First.getOperand(3).getImm());
16803   Register FLHS = First.getOperand(1).getReg();
16804   Register FRHS = First.getOperand(2).getReg();
16805   // Insert appropriate branch.
16806   BuildMI(FirstMBB, DL, TII.getBrCond(FirstCC))
16807       .addReg(FLHS)
16808       .addReg(FRHS)
16809       .addMBB(SinkMBB);
16810 
16811   Register SLHS = Second.getOperand(1).getReg();
16812   Register SRHS = Second.getOperand(2).getReg();
16813   Register Op1Reg4 = First.getOperand(4).getReg();
16814   Register Op1Reg5 = First.getOperand(5).getReg();
16815 
16816   auto SecondCC = static_cast<RISCVCC::CondCode>(Second.getOperand(3).getImm());
16817   // Insert appropriate branch.
16818   BuildMI(ThisMBB, DL, TII.getBrCond(SecondCC))
16819       .addReg(SLHS)
16820       .addReg(SRHS)
16821       .addMBB(SinkMBB);
16822 
16823   Register DestReg = Second.getOperand(0).getReg();
16824   Register Op2Reg4 = Second.getOperand(4).getReg();
16825   BuildMI(*SinkMBB, SinkMBB->begin(), DL, TII.get(RISCV::PHI), DestReg)
16826       .addReg(Op2Reg4)
16827       .addMBB(ThisMBB)
16828       .addReg(Op1Reg4)
16829       .addMBB(FirstMBB)
16830       .addReg(Op1Reg5)
16831       .addMBB(SecondMBB);
16832 
16833   // Now remove the Select_FPRX_s.
16834   First.eraseFromParent();
16835   Second.eraseFromParent();
16836   return SinkMBB;
16837 }
16838 
16839 static MachineBasicBlock *emitSelectPseudo(MachineInstr &MI,
16840                                            MachineBasicBlock *BB,
16841                                            const RISCVSubtarget &Subtarget) {
16842   // To "insert" Select_* instructions, we actually have to insert the triangle
16843   // control-flow pattern.  The incoming instructions know the destination vreg
16844   // to set, the condition code register to branch on, the true/false values to
16845   // select between, and the condcode to use to select the appropriate branch.
16846   //
16847   // We produce the following control flow:
16848   //     HeadMBB
16849   //     |  \
16850   //     |  IfFalseMBB
16851   //     | /
16852   //    TailMBB
16853   //
16854   // When we find a sequence of selects we attempt to optimize their emission
16855   // by sharing the control flow. Currently we only handle cases where we have
16856   // multiple selects with the exact same condition (same LHS, RHS and CC).
16857   // The selects may be interleaved with other instructions if the other
16858   // instructions meet some requirements we deem safe:
16859   // - They are not pseudo instructions.
16860   // - They are debug instructions, or
16861   // - They do not have side effects, do not access memory, and their inputs
16862   //   do not depend on the results of the select pseudo-instructions.
16863   // The TrueV/FalseV operands of the selects cannot depend on the result of
16864   // previous selects in the sequence.
16865   // These conditions could be further relaxed. See the X86 target for a
16866   // related approach and more information.
16867   //
16868   // Select_FPRX_ (rs1, rs2, imm, rs4, (Select_FPRX_ rs1, rs2, imm, rs4, rs5))
16869   // is checked here and handled by a separate function -
16870   // EmitLoweredCascadedSelect.
16871   Register LHS = MI.getOperand(1).getReg();
16872   Register RHS = MI.getOperand(2).getReg();
16873   auto CC = static_cast<RISCVCC::CondCode>(MI.getOperand(3).getImm());
16874 
16875   SmallVector<MachineInstr *, 4> SelectDebugValues;
16876   SmallSet<Register, 4> SelectDests;
16877   SelectDests.insert(MI.getOperand(0).getReg());
16878 
16879   MachineInstr *LastSelectPseudo = &MI;
16880   auto Next = next_nodbg(MI.getIterator(), BB->instr_end());
16881   if (MI.getOpcode() != RISCV::Select_GPR_Using_CC_GPR && Next != BB->end() &&
16882       Next->getOpcode() == MI.getOpcode() &&
16883       Next->getOperand(5).getReg() == MI.getOperand(0).getReg() &&
16884       Next->getOperand(5).isKill()) {
16885     return EmitLoweredCascadedSelect(MI, *Next, BB, Subtarget);
16886   }
16887 
16888   for (auto E = BB->end(), SequenceMBBI = MachineBasicBlock::iterator(MI);
16889        SequenceMBBI != E; ++SequenceMBBI) {
16890     if (SequenceMBBI->isDebugInstr())
16891       continue;
16892     if (isSelectPseudo(*SequenceMBBI)) {
16893       if (SequenceMBBI->getOperand(1).getReg() != LHS ||
16894           SequenceMBBI->getOperand(2).getReg() != RHS ||
16895           SequenceMBBI->getOperand(3).getImm() != CC ||
16896           SelectDests.count(SequenceMBBI->getOperand(4).getReg()) ||
16897           SelectDests.count(SequenceMBBI->getOperand(5).getReg()))
16898         break;
16899       LastSelectPseudo = &*SequenceMBBI;
16900       SequenceMBBI->collectDebugValues(SelectDebugValues);
16901       SelectDests.insert(SequenceMBBI->getOperand(0).getReg());
16902       continue;
16903     }
16904     if (SequenceMBBI->hasUnmodeledSideEffects() ||
16905         SequenceMBBI->mayLoadOrStore() ||
16906         SequenceMBBI->usesCustomInsertionHook())
16907       break;
16908     if (llvm::any_of(SequenceMBBI->operands(), [&](MachineOperand &MO) {
16909           return MO.isReg() && MO.isUse() && SelectDests.count(MO.getReg());
16910         }))
16911       break;
16912   }
16913 
16914   const RISCVInstrInfo &TII = *Subtarget.getInstrInfo();
16915   const BasicBlock *LLVM_BB = BB->getBasicBlock();
16916   DebugLoc DL = MI.getDebugLoc();
16917   MachineFunction::iterator I = ++BB->getIterator();
16918 
16919   MachineBasicBlock *HeadMBB = BB;
16920   MachineFunction *F = BB->getParent();
16921   MachineBasicBlock *TailMBB = F->CreateMachineBasicBlock(LLVM_BB);
16922   MachineBasicBlock *IfFalseMBB = F->CreateMachineBasicBlock(LLVM_BB);
16923 
16924   F->insert(I, IfFalseMBB);
16925   F->insert(I, TailMBB);
16926 
16927   // Transfer debug instructions associated with the selects to TailMBB.
16928   for (MachineInstr *DebugInstr : SelectDebugValues) {
16929     TailMBB->push_back(DebugInstr->removeFromParent());
16930   }
16931 
16932   // Move all instructions after the sequence to TailMBB.
16933   TailMBB->splice(TailMBB->end(), HeadMBB,
16934                   std::next(LastSelectPseudo->getIterator()), HeadMBB->end());
16935   // Update machine-CFG edges by transferring all successors of the current
16936   // block to the new block which will contain the Phi nodes for the selects.
16937   TailMBB->transferSuccessorsAndUpdatePHIs(HeadMBB);
16938   // Set the successors for HeadMBB.
16939   HeadMBB->addSuccessor(IfFalseMBB);
16940   HeadMBB->addSuccessor(TailMBB);
16941 
16942   // Insert appropriate branch.
16943   BuildMI(HeadMBB, DL, TII.getBrCond(CC))
16944     .addReg(LHS)
16945     .addReg(RHS)
16946     .addMBB(TailMBB);
16947 
16948   // IfFalseMBB just falls through to TailMBB.
16949   IfFalseMBB->addSuccessor(TailMBB);
16950 
16951   // Create PHIs for all of the select pseudo-instructions.
16952   auto SelectMBBI = MI.getIterator();
16953   auto SelectEnd = std::next(LastSelectPseudo->getIterator());
16954   auto InsertionPoint = TailMBB->begin();
16955   while (SelectMBBI != SelectEnd) {
16956     auto Next = std::next(SelectMBBI);
16957     if (isSelectPseudo(*SelectMBBI)) {
16958       // %Result = phi [ %TrueValue, HeadMBB ], [ %FalseValue, IfFalseMBB ]
16959       BuildMI(*TailMBB, InsertionPoint, SelectMBBI->getDebugLoc(),
16960               TII.get(RISCV::PHI), SelectMBBI->getOperand(0).getReg())
16961           .addReg(SelectMBBI->getOperand(4).getReg())
16962           .addMBB(HeadMBB)
16963           .addReg(SelectMBBI->getOperand(5).getReg())
16964           .addMBB(IfFalseMBB);
16965       SelectMBBI->eraseFromParent();
16966     }
16967     SelectMBBI = Next;
16968   }
16969 
16970   F->getProperties().reset(MachineFunctionProperties::Property::NoPHIs);
16971   return TailMBB;
16972 }
16973 
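// Expansion sketch: save fflags, perform a masked vfcvt.x.f.v with frm=DYN,
// convert back with vfcvt.f.x.v, then restore fflags so no exception flags
// escape from the intermediate conversions.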
16974 static MachineBasicBlock *emitVFROUND_NOEXCEPT_MASK(MachineInstr &MI,
16975                                                     MachineBasicBlock *BB,
16976                                                     unsigned CVTXOpc,
16977                                                     unsigned CVTFOpc) {
16978   DebugLoc DL = MI.getDebugLoc();
16979 
16980   const TargetInstrInfo &TII = *BB->getParent()->getSubtarget().getInstrInfo();
16981 
16982   MachineRegisterInfo &MRI = BB->getParent()->getRegInfo();
16983   Register SavedFFLAGS = MRI.createVirtualRegister(&RISCV::GPRRegClass);
16984 
16985   // Save the old value of FFLAGS.
16986   BuildMI(*BB, MI, DL, TII.get(RISCV::ReadFFLAGS), SavedFFLAGS);
16987 
16988   assert(MI.getNumOperands() == 7);
16989 
16990   // Emit a VFCVT_X_F
16991   const TargetRegisterInfo *TRI =
16992       BB->getParent()->getSubtarget().getRegisterInfo();
16993   const TargetRegisterClass *RC = MI.getRegClassConstraint(0, &TII, TRI);
16994   Register Tmp = MRI.createVirtualRegister(RC);
16995   BuildMI(*BB, MI, DL, TII.get(CVTXOpc), Tmp)
16996       .add(MI.getOperand(1))
16997       .add(MI.getOperand(2))
16998       .add(MI.getOperand(3))
16999       .add(MachineOperand::CreateImm(7)) // frm = DYN
17000       .add(MI.getOperand(4))
17001       .add(MI.getOperand(5))
17002       .add(MI.getOperand(6))
17003       .add(MachineOperand::CreateReg(RISCV::FRM,
17004                                      /*IsDef*/ false,
17005                                      /*IsImp*/ true));
17006 
17007   // Emit a VFCVT_F_X
17008   BuildMI(*BB, MI, DL, TII.get(CVTFOpc))
17009       .add(MI.getOperand(0))
17010       .add(MI.getOperand(1))
17011       .addReg(Tmp)
17012       .add(MI.getOperand(3))
17013       .add(MachineOperand::CreateImm(7)) // frm = DYN
17014       .add(MI.getOperand(4))
17015       .add(MI.getOperand(5))
17016       .add(MI.getOperand(6))
17017       .add(MachineOperand::CreateReg(RISCV::FRM,
17018                                      /*IsDef*/ false,
17019                                      /*IsImp*/ true));
17020 
17021   // Restore FFLAGS.
17022   BuildMI(*BB, MI, DL, TII.get(RISCV::WriteFFLAGS))
17023       .addReg(SavedFFLAGS, RegState::Kill);
17024 
17025   // Erase the pseudoinstruction.
17026   MI.eraseFromParent();
17027   return BB;
17028 }
17029 
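// Expansion sketch (opcode and register names are illustrative):
//   fabs = fsgnjx src, src     ; |src|
//   cmp  = flt    fabs, max    ; does |src| still have fraction bits?
//   beq   cmp, zero, done      ; if not, src is already an integer
//   i    = fcvt_int src, frm   ; round to integer with the requested rounding
//   f    = fcvt_fp  i, frm
//   cvt  = fsgnj  f, src       ; restore the original sign
// done:
//   dst  = phi [src, entry], [cvt, convert_block]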
17030 static MachineBasicBlock *emitFROUND(MachineInstr &MI, MachineBasicBlock *MBB,
17031                                      const RISCVSubtarget &Subtarget) {
17032   unsigned CmpOpc, F2IOpc, I2FOpc, FSGNJOpc, FSGNJXOpc;
17033   const TargetRegisterClass *RC;
17034   switch (MI.getOpcode()) {
17035   default:
17036     llvm_unreachable("Unexpected opcode");
17037   case RISCV::PseudoFROUND_H:
17038     CmpOpc = RISCV::FLT_H;
17039     F2IOpc = RISCV::FCVT_W_H;
17040     I2FOpc = RISCV::FCVT_H_W;
17041     FSGNJOpc = RISCV::FSGNJ_H;
17042     FSGNJXOpc = RISCV::FSGNJX_H;
17043     RC = &RISCV::FPR16RegClass;
17044     break;
17045   case RISCV::PseudoFROUND_H_INX:
17046     CmpOpc = RISCV::FLT_H_INX;
17047     F2IOpc = RISCV::FCVT_W_H_INX;
17048     I2FOpc = RISCV::FCVT_H_W_INX;
17049     FSGNJOpc = RISCV::FSGNJ_H_INX;
17050     FSGNJXOpc = RISCV::FSGNJX_H_INX;
17051     RC = &RISCV::GPRF16RegClass;
17052     break;
17053   case RISCV::PseudoFROUND_S:
17054     CmpOpc = RISCV::FLT_S;
17055     F2IOpc = RISCV::FCVT_W_S;
17056     I2FOpc = RISCV::FCVT_S_W;
17057     FSGNJOpc = RISCV::FSGNJ_S;
17058     FSGNJXOpc = RISCV::FSGNJX_S;
17059     RC = &RISCV::FPR32RegClass;
17060     break;
17061   case RISCV::PseudoFROUND_S_INX:
17062     CmpOpc = RISCV::FLT_S_INX;
17063     F2IOpc = RISCV::FCVT_W_S_INX;
17064     I2FOpc = RISCV::FCVT_S_W_INX;
17065     FSGNJOpc = RISCV::FSGNJ_S_INX;
17066     FSGNJXOpc = RISCV::FSGNJX_S_INX;
17067     RC = &RISCV::GPRF32RegClass;
17068     break;
17069   case RISCV::PseudoFROUND_D:
17070     assert(Subtarget.is64Bit() && "Expected 64-bit GPR.");
17071     CmpOpc = RISCV::FLT_D;
17072     F2IOpc = RISCV::FCVT_L_D;
17073     I2FOpc = RISCV::FCVT_D_L;
17074     FSGNJOpc = RISCV::FSGNJ_D;
17075     FSGNJXOpc = RISCV::FSGNJX_D;
17076     RC = &RISCV::FPR64RegClass;
17077     break;
17078   case RISCV::PseudoFROUND_D_INX:
17079     assert(Subtarget.is64Bit() && "Expected 64-bit GPR.");
17080     CmpOpc = RISCV::FLT_D_INX;
17081     F2IOpc = RISCV::FCVT_L_D_INX;
17082     I2FOpc = RISCV::FCVT_D_L_INX;
17083     FSGNJOpc = RISCV::FSGNJ_D_INX;
17084     FSGNJXOpc = RISCV::FSGNJX_D_INX;
17085     RC = &RISCV::GPRRegClass;
17086     break;
17087   }
17088 
17089   const BasicBlock *BB = MBB->getBasicBlock();
17090   DebugLoc DL = MI.getDebugLoc();
17091   MachineFunction::iterator I = ++MBB->getIterator();
17092 
17093   MachineFunction *F = MBB->getParent();
17094   MachineBasicBlock *CvtMBB = F->CreateMachineBasicBlock(BB);
17095   MachineBasicBlock *DoneMBB = F->CreateMachineBasicBlock(BB);
17096 
17097   F->insert(I, CvtMBB);
17098   F->insert(I, DoneMBB);
17099   // Move all instructions after the sequence to DoneMBB.
17100   DoneMBB->splice(DoneMBB->end(), MBB, MachineBasicBlock::iterator(MI),
17101                   MBB->end());
17102   // Update machine-CFG edges by transferring all successors of the current
17103   // block to the new block which will contain the Phi nodes for the selects.
17104   DoneMBB->transferSuccessorsAndUpdatePHIs(MBB);
17105   // Set the successors for MBB.
17106   MBB->addSuccessor(CvtMBB);
17107   MBB->addSuccessor(DoneMBB);
17108 
17109   Register DstReg = MI.getOperand(0).getReg();
17110   Register SrcReg = MI.getOperand(1).getReg();
17111   Register MaxReg = MI.getOperand(2).getReg();
17112   int64_t FRM = MI.getOperand(3).getImm();
17113 
17114   const RISCVInstrInfo &TII = *Subtarget.getInstrInfo();
17115   MachineRegisterInfo &MRI = MBB->getParent()->getRegInfo();
17116 
17117   Register FabsReg = MRI.createVirtualRegister(RC);
17118   BuildMI(MBB, DL, TII.get(FSGNJXOpc), FabsReg).addReg(SrcReg).addReg(SrcReg);
17119 
17120   // Compare the FP value to the max value.
17121   Register CmpReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
17122   auto MIB =
17123       BuildMI(MBB, DL, TII.get(CmpOpc), CmpReg).addReg(FabsReg).addReg(MaxReg);
17124   if (MI.getFlag(MachineInstr::MIFlag::NoFPExcept))
17125     MIB->setFlag(MachineInstr::MIFlag::NoFPExcept);
17126 
17127   // Insert branch.
17128   BuildMI(MBB, DL, TII.get(RISCV::BEQ))
17129       .addReg(CmpReg)
17130       .addReg(RISCV::X0)
17131       .addMBB(DoneMBB);
17132 
17133   CvtMBB->addSuccessor(DoneMBB);
17134 
17135   // Convert to integer.
17136   Register F2IReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
17137   MIB = BuildMI(CvtMBB, DL, TII.get(F2IOpc), F2IReg).addReg(SrcReg).addImm(FRM);
17138   if (MI.getFlag(MachineInstr::MIFlag::NoFPExcept))
17139     MIB->setFlag(MachineInstr::MIFlag::NoFPExcept);
17140 
17141   // Convert back to FP.
17142   Register I2FReg = MRI.createVirtualRegister(RC);
17143   MIB = BuildMI(CvtMBB, DL, TII.get(I2FOpc), I2FReg).addReg(F2IReg).addImm(FRM);
17144   if (MI.getFlag(MachineInstr::MIFlag::NoFPExcept))
17145     MIB->setFlag(MachineInstr::MIFlag::NoFPExcept);
17146 
17147   // Restore the sign bit.
17148   Register CvtReg = MRI.createVirtualRegister(RC);
17149   BuildMI(CvtMBB, DL, TII.get(FSGNJOpc), CvtReg).addReg(I2FReg).addReg(SrcReg);
17150 
17151   // Merge the results.
17152   BuildMI(*DoneMBB, DoneMBB->begin(), DL, TII.get(RISCV::PHI), DstReg)
17153       .addReg(SrcReg)
17154       .addMBB(MBB)
17155       .addReg(CvtReg)
17156       .addMBB(CvtMBB);
17157 
17158   MI.eraseFromParent();
17159   return DoneMBB;
17160 }
17161 
17162 MachineBasicBlock *
17163 RISCVTargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
17164                                                  MachineBasicBlock *BB) const {
17165   switch (MI.getOpcode()) {
17166   default:
17167     llvm_unreachable("Unexpected instr type to insert");
17168   case RISCV::ReadCycleWide:
17169     assert(!Subtarget.is64Bit() &&
17170            "ReadCycleWide is only to be used on riscv32");
17171     return emitReadCycleWidePseudo(MI, BB);
17172   case RISCV::Select_GPR_Using_CC_GPR:
17173   case RISCV::Select_FPR16_Using_CC_GPR:
17174   case RISCV::Select_FPR16INX_Using_CC_GPR:
17175   case RISCV::Select_FPR32_Using_CC_GPR:
17176   case RISCV::Select_FPR32INX_Using_CC_GPR:
17177   case RISCV::Select_FPR64_Using_CC_GPR:
17178   case RISCV::Select_FPR64INX_Using_CC_GPR:
17179   case RISCV::Select_FPR64IN32X_Using_CC_GPR:
17180     return emitSelectPseudo(MI, BB, Subtarget);
17181   case RISCV::BuildPairF64Pseudo:
17182   case RISCV::BuildPairF64Pseudo_INX:
17183     return emitBuildPairF64Pseudo(MI, BB, Subtarget);
17184   case RISCV::SplitF64Pseudo:
17185   case RISCV::SplitF64Pseudo_INX:
17186     return emitSplitF64Pseudo(MI, BB, Subtarget);
17187   case RISCV::PseudoQuietFLE_H:
17188     return emitQuietFCMP(MI, BB, RISCV::FLE_H, RISCV::FEQ_H, Subtarget);
17189   case RISCV::PseudoQuietFLE_H_INX:
17190     return emitQuietFCMP(MI, BB, RISCV::FLE_H_INX, RISCV::FEQ_H_INX, Subtarget);
17191   case RISCV::PseudoQuietFLT_H:
17192     return emitQuietFCMP(MI, BB, RISCV::FLT_H, RISCV::FEQ_H, Subtarget);
17193   case RISCV::PseudoQuietFLT_H_INX:
17194     return emitQuietFCMP(MI, BB, RISCV::FLT_H_INX, RISCV::FEQ_H_INX, Subtarget);
17195   case RISCV::PseudoQuietFLE_S:
17196     return emitQuietFCMP(MI, BB, RISCV::FLE_S, RISCV::FEQ_S, Subtarget);
17197   case RISCV::PseudoQuietFLE_S_INX:
17198     return emitQuietFCMP(MI, BB, RISCV::FLE_S_INX, RISCV::FEQ_S_INX, Subtarget);
17199   case RISCV::PseudoQuietFLT_S:
17200     return emitQuietFCMP(MI, BB, RISCV::FLT_S, RISCV::FEQ_S, Subtarget);
17201   case RISCV::PseudoQuietFLT_S_INX:
17202     return emitQuietFCMP(MI, BB, RISCV::FLT_S_INX, RISCV::FEQ_S_INX, Subtarget);
17203   case RISCV::PseudoQuietFLE_D:
17204     return emitQuietFCMP(MI, BB, RISCV::FLE_D, RISCV::FEQ_D, Subtarget);
17205   case RISCV::PseudoQuietFLE_D_INX:
17206     return emitQuietFCMP(MI, BB, RISCV::FLE_D_INX, RISCV::FEQ_D_INX, Subtarget);
17207   case RISCV::PseudoQuietFLE_D_IN32X:
17208     return emitQuietFCMP(MI, BB, RISCV::FLE_D_IN32X, RISCV::FEQ_D_IN32X,
17209                          Subtarget);
17210   case RISCV::PseudoQuietFLT_D:
17211     return emitQuietFCMP(MI, BB, RISCV::FLT_D, RISCV::FEQ_D, Subtarget);
17212   case RISCV::PseudoQuietFLT_D_INX:
17213     return emitQuietFCMP(MI, BB, RISCV::FLT_D_INX, RISCV::FEQ_D_INX, Subtarget);
17214   case RISCV::PseudoQuietFLT_D_IN32X:
17215     return emitQuietFCMP(MI, BB, RISCV::FLT_D_IN32X, RISCV::FEQ_D_IN32X,
17216                          Subtarget);
17217 
17218   case RISCV::PseudoVFROUND_NOEXCEPT_V_M1_MASK:
17219     return emitVFROUND_NOEXCEPT_MASK(MI, BB, RISCV::PseudoVFCVT_X_F_V_M1_MASK,
17220                                      RISCV::PseudoVFCVT_F_X_V_M1_MASK);
17221   case RISCV::PseudoVFROUND_NOEXCEPT_V_M2_MASK:
17222     return emitVFROUND_NOEXCEPT_MASK(MI, BB, RISCV::PseudoVFCVT_X_F_V_M2_MASK,
17223                                      RISCV::PseudoVFCVT_F_X_V_M2_MASK);
17224   case RISCV::PseudoVFROUND_NOEXCEPT_V_M4_MASK:
17225     return emitVFROUND_NOEXCEPT_MASK(MI, BB, RISCV::PseudoVFCVT_X_F_V_M4_MASK,
17226                                      RISCV::PseudoVFCVT_F_X_V_M4_MASK);
17227   case RISCV::PseudoVFROUND_NOEXCEPT_V_M8_MASK:
17228     return emitVFROUND_NOEXCEPT_MASK(MI, BB, RISCV::PseudoVFCVT_X_F_V_M8_MASK,
17229                                      RISCV::PseudoVFCVT_F_X_V_M8_MASK);
17230   case RISCV::PseudoVFROUND_NOEXCEPT_V_MF2_MASK:
17231     return emitVFROUND_NOEXCEPT_MASK(MI, BB, RISCV::PseudoVFCVT_X_F_V_MF2_MASK,
17232                                      RISCV::PseudoVFCVT_F_X_V_MF2_MASK);
17233   case RISCV::PseudoVFROUND_NOEXCEPT_V_MF4_MASK:
17234     return emitVFROUND_NOEXCEPT_MASK(MI, BB, RISCV::PseudoVFCVT_X_F_V_MF4_MASK,
17235                                      RISCV::PseudoVFCVT_F_X_V_MF4_MASK);
17236   case RISCV::PseudoFROUND_H:
17237   case RISCV::PseudoFROUND_H_INX:
17238   case RISCV::PseudoFROUND_S:
17239   case RISCV::PseudoFROUND_S_INX:
17240   case RISCV::PseudoFROUND_D:
17241   case RISCV::PseudoFROUND_D_INX:
17242   case RISCV::PseudoFROUND_D_IN32X:
17243     return emitFROUND(MI, BB, Subtarget);
17244   case TargetOpcode::STATEPOINT:
17245   case TargetOpcode::STACKMAP:
17246   case TargetOpcode::PATCHPOINT:
17247     if (!Subtarget.is64Bit())
17248       report_fatal_error("STACKMAP, PATCHPOINT and STATEPOINT are only "
17249                          "supported on 64-bit targets");
17250     return emitPatchPoint(MI, BB);
17251   }
17252 }
17253 
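// Post-isel hook: give any instruction whose rounding-mode operand is DYN an
// implicit use of the FRM CSR (for example, a scalar FP instruction or a
// vector FP pseudo selected with the dynamic rounding mode), so later passes
// model the dependency on writes to FRM.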
17254 void RISCVTargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
17255                                                         SDNode *Node) const {
17256   // Add FRM dependency to any instructions with dynamic rounding mode.
17257   int Idx = RISCV::getNamedOperandIdx(MI.getOpcode(), RISCV::OpName::frm);
17258   if (Idx < 0) {
17259     // Vector pseudos have FRM index indicated by TSFlags.
17260     Idx = RISCVII::getFRMOpNum(MI.getDesc());
17261     if (Idx < 0)
17262       return;
17263   }
17264   if (MI.getOperand(Idx).getImm() != RISCVFPRndMode::DYN)
17265     return;
17266   // If the instruction already reads FRM, don't add another read.
17267   if (MI.readsRegister(RISCV::FRM))
17268     return;
17269   MI.addOperand(
17270       MachineOperand::CreateReg(RISCV::FRM, /*isDef*/ false, /*isImp*/ true));
17271 }
17272 
17273 // Calling Convention Implementation.
17274 // The expectations for frontend ABI lowering vary from target to target.
17275 // Ideally, an LLVM frontend would be able to avoid worrying about many ABI
17276 // details, but this is a longer term goal. For now, we simply try to keep the
17277 // role of the frontend as simple and well-defined as possible. The rules can
17278 // be summarised as:
17279 // * Never split up large scalar arguments. We handle them here.
17280 // * If a hardfloat calling convention is being used, and the struct may be
17281 // passed in a pair of registers (fp+fp, int+fp), and both registers are
17282 // available, then pass as two separate arguments. If either the GPRs or FPRs
17283 // are exhausted, then pass according to the rule below.
17284 // * If a struct could never be passed in registers or directly in a stack
17285 // slot (as it is larger than 2*XLEN and the floating point rules don't
17286 // apply), then pass it using a pointer with the byval attribute.
17287 // * If a struct is less than 2*XLEN, then coerce to either a two-element
17288 // word-sized array or a 2*XLEN scalar (depending on alignment).
17289 // * The frontend can determine whether a struct is returned by reference or
17290 // not based on its size and fields. If it will be returned by reference, the
17291 // frontend must modify the prototype so a pointer with the sret annotation is
17292 // passed as the first argument. This is not necessary for large scalar
17293 // returns.
17294 // * Struct return values and varargs should be coerced to structs containing
17295 // register-size fields in the same situations they would be for fixed
17296 // arguments.
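// As a rough illustration of the hard-float struct rule above: under ILP32D,
// a struct { float f; int i; } with both an FPR and a GPR still free is passed
// as two separate arguments (one in an FPR, one in a GPR); once the FPRs are
// exhausted it falls back to the integer rules, e.g. coercion to a two-element
// i32 array or an i64 depending on alignment.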
17297 
17298 static const MCPhysReg ArgFPR16s[] = {
17299   RISCV::F10_H, RISCV::F11_H, RISCV::F12_H, RISCV::F13_H,
17300   RISCV::F14_H, RISCV::F15_H, RISCV::F16_H, RISCV::F17_H
17301 };
17302 static const MCPhysReg ArgFPR32s[] = {
17303   RISCV::F10_F, RISCV::F11_F, RISCV::F12_F, RISCV::F13_F,
17304   RISCV::F14_F, RISCV::F15_F, RISCV::F16_F, RISCV::F17_F
17305 };
17306 static const MCPhysReg ArgFPR64s[] = {
17307   RISCV::F10_D, RISCV::F11_D, RISCV::F12_D, RISCV::F13_D,
17308   RISCV::F14_D, RISCV::F15_D, RISCV::F16_D, RISCV::F17_D
17309 };
17310 // This is an interim calling convention and it may be changed in the future.
17311 static const MCPhysReg ArgVRs[] = {
17312     RISCV::V8,  RISCV::V9,  RISCV::V10, RISCV::V11, RISCV::V12, RISCV::V13,
17313     RISCV::V14, RISCV::V15, RISCV::V16, RISCV::V17, RISCV::V18, RISCV::V19,
17314     RISCV::V20, RISCV::V21, RISCV::V22, RISCV::V23};
17315 static const MCPhysReg ArgVRM2s[] = {RISCV::V8M2,  RISCV::V10M2, RISCV::V12M2,
17316                                      RISCV::V14M2, RISCV::V16M2, RISCV::V18M2,
17317                                      RISCV::V20M2, RISCV::V22M2};
17318 static const MCPhysReg ArgVRM4s[] = {RISCV::V8M4, RISCV::V12M4, RISCV::V16M4,
17319                                      RISCV::V20M4};
17320 static const MCPhysReg ArgVRM8s[] = {RISCV::V8M8, RISCV::V16M8};
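// ArgVRs covers the single registers v8-v23; ArgVRM2s/ArgVRM4s/ArgVRM8s are
// the corresponding LMUL=2/4/8 register groups starting at v8. v0 is not in
// these lists because it is pre-assigned to the first mask argument (see
// allocateRVVReg below).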
17321 
17322 ArrayRef<MCPhysReg> RISCV::getArgGPRs(const RISCVABI::ABI ABI) {
17323   // The GPRs used for passing arguments in the ILP32* and LP64* ABIs, except
17324   // the ILP32E and LP64E ABIs.
17325   static const MCPhysReg ArgIGPRs[] = {RISCV::X10, RISCV::X11, RISCV::X12,
17326                                        RISCV::X13, RISCV::X14, RISCV::X15,
17327                                        RISCV::X16, RISCV::X17};
17328   // The GPRs used for passing arguments in the ILP32E/LP64E ABIs.
17329   static const MCPhysReg ArgEGPRs[] = {RISCV::X10, RISCV::X11, RISCV::X12,
17330                                        RISCV::X13, RISCV::X14, RISCV::X15};
17331 
17332   if (ABI == RISCVABI::ABI_ILP32E || ABI == RISCVABI::ABI_LP64E)
17333     return ArrayRef(ArgEGPRs);
17334 
17335   return ArrayRef(ArgIGPRs);
17336 }
17337 
17338 static ArrayRef<MCPhysReg> getFastCCArgGPRs(const RISCVABI::ABI ABI) {
17339   // The GPRs used for passing arguments in the FastCC; X5 and X6 might be used
17340   // for the save-restore libcall, so we don't use them here.
17341   static const MCPhysReg FastCCIGPRs[] = {
17342       RISCV::X10, RISCV::X11, RISCV::X12, RISCV::X13, RISCV::X14,
17343       RISCV::X15, RISCV::X16, RISCV::X17, RISCV::X7,  RISCV::X28,
17344       RISCV::X29, RISCV::X30, RISCV::X31};
17345 
17346   // The GPRs used for passing arguments in the FastCC when using ILP32E/LP64E.
17347   static const MCPhysReg FastCCEGPRs[] = {RISCV::X10, RISCV::X11, RISCV::X12,
17348                                           RISCV::X13, RISCV::X14, RISCV::X15,
17349                                           RISCV::X7};
17350 
17351   if (ABI == RISCVABI::ABI_ILP32E || ABI == RISCVABI::ABI_LP64E)
17352     return ArrayRef(FastCCEGPRs);
17353 
17354   return ArrayRef(FastCCIGPRs);
17355 }
17356 
17357 // Pass a 2*XLEN argument that has been split into two XLEN values through
17358 // registers or the stack as necessary.
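// For example, for an i64 split into two i32 halves on RV32: if two argument
// GPRs remain, both halves go in registers; if only one remains, the first
// half takes it and the second half is placed on the stack; if none remain,
// both halves go on the stack.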
17359 static bool CC_RISCVAssign2XLen(unsigned XLen, CCState &State, CCValAssign VA1,
17360                                 ISD::ArgFlagsTy ArgFlags1, unsigned ValNo2,
17361                                 MVT ValVT2, MVT LocVT2,
17362                                 ISD::ArgFlagsTy ArgFlags2, bool EABI) {
17363   unsigned XLenInBytes = XLen / 8;
17364   const RISCVSubtarget &STI =
17365       State.getMachineFunction().getSubtarget<RISCVSubtarget>();
17366   ArrayRef<MCPhysReg> ArgGPRs = RISCV::getArgGPRs(STI.getTargetABI());
17367 
17368   if (Register Reg = State.AllocateReg(ArgGPRs)) {
17369     // At least one half can be passed via register.
17370     State.addLoc(CCValAssign::getReg(VA1.getValNo(), VA1.getValVT(), Reg,
17371                                      VA1.getLocVT(), CCValAssign::Full));
17372   } else {
17373     // Both halves must be passed on the stack, with proper alignment.
17374     // TODO: To be compatible with GCC's behavior, we force them to have 4-byte
17375     // alignment. This behavior may change when RV32E/ILP32E is ratified.
17376     Align StackAlign(XLenInBytes);
17377     if (!EABI || XLen != 32)
17378       StackAlign = std::max(StackAlign, ArgFlags1.getNonZeroOrigAlign());
17379     State.addLoc(
17380         CCValAssign::getMem(VA1.getValNo(), VA1.getValVT(),
17381                             State.AllocateStack(XLenInBytes, StackAlign),
17382                             VA1.getLocVT(), CCValAssign::Full));
17383     State.addLoc(CCValAssign::getMem(
17384         ValNo2, ValVT2, State.AllocateStack(XLenInBytes, Align(XLenInBytes)),
17385         LocVT2, CCValAssign::Full));
17386     return false;
17387   }
17388 
17389   if (Register Reg = State.AllocateReg(ArgGPRs)) {
17390     // The second half can also be passed via register.
17391     State.addLoc(
17392         CCValAssign::getReg(ValNo2, ValVT2, Reg, LocVT2, CCValAssign::Full));
17393   } else {
17394     // The second half is passed via the stack, without additional alignment.
17395     State.addLoc(CCValAssign::getMem(
17396         ValNo2, ValVT2, State.AllocateStack(XLenInBytes, Align(XLenInBytes)),
17397         LocVT2, CCValAssign::Full));
17398   }
17399 
17400   return false;
17401 }
17402 
17403 static unsigned allocateRVVReg(MVT ValVT, unsigned ValNo,
17404                                std::optional<unsigned> FirstMaskArgument,
17405                                CCState &State, const RISCVTargetLowering &TLI) {
17406   const TargetRegisterClass *RC = TLI.getRegClassFor(ValVT);
17407   if (RC == &RISCV::VRRegClass) {
17408     // Assign the first mask argument to V0.
17409     // This is an interim calling convention and it may be changed in the
17410     // future.
17411     if (FirstMaskArgument && ValNo == *FirstMaskArgument)
17412       return State.AllocateReg(RISCV::V0);
17413     return State.AllocateReg(ArgVRs);
17414   }
17415   if (RC == &RISCV::VRM2RegClass)
17416     return State.AllocateReg(ArgVRM2s);
17417   if (RC == &RISCV::VRM4RegClass)
17418     return State.AllocateReg(ArgVRM4s);
17419   if (RC == &RISCV::VRM8RegClass)
17420     return State.AllocateReg(ArgVRM8s);
17421   llvm_unreachable("Unhandled register class for ValueType");
17422 }
17423 
17424 // Implements the RISC-V calling convention. Returns true upon failure.
17425 bool RISCV::CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
17426                      MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo,
17427                      ISD::ArgFlagsTy ArgFlags, CCState &State, bool IsFixed,
17428                      bool IsRet, Type *OrigTy, const RISCVTargetLowering &TLI,
17429                      std::optional<unsigned> FirstMaskArgument) {
17430   unsigned XLen = DL.getLargestLegalIntTypeSizeInBits();
17431   assert(XLen == 32 || XLen == 64);
17432   MVT XLenVT = XLen == 32 ? MVT::i32 : MVT::i64;
17433 
17434   // The static chain parameter must not be passed in normal argument registers,
17435   // so we assign t2 to it, as GCC does for __builtin_call_with_static_chain.
17436   if (ArgFlags.isNest()) {
17437     if (unsigned Reg = State.AllocateReg(RISCV::X7)) {
17438       State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
17439       return false;
17440     }
17441   }
17442 
17443   // Any return value split into more than two values can't be returned
17444   // directly. Vectors are returned via the available vector registers.
17445   if (!LocVT.isVector() && IsRet && ValNo > 1)
17446     return true;
17447 
17448   // UseGPRForF16_F32 is true if targeting one of the soft-float ABIs, if passing
17449   // a variadic argument, or if no F16/F32 argument registers are available.
17450   bool UseGPRForF16_F32 = true;
17451   // UseGPRForF64 is true if targeting a soft-float ABI or an FLEN=32 ABI, if
17452   // passing a variadic argument, or if no F64 argument registers are available.
17453   bool UseGPRForF64 = true;
17454 
17455   switch (ABI) {
17456   default:
17457     llvm_unreachable("Unexpected ABI");
17458   case RISCVABI::ABI_ILP32:
17459   case RISCVABI::ABI_ILP32E:
17460   case RISCVABI::ABI_LP64:
17461   case RISCVABI::ABI_LP64E:
17462     break;
17463   case RISCVABI::ABI_ILP32F:
17464   case RISCVABI::ABI_LP64F:
17465     UseGPRForF16_F32 = !IsFixed;
17466     break;
17467   case RISCVABI::ABI_ILP32D:
17468   case RISCVABI::ABI_LP64D:
17469     UseGPRForF16_F32 = !IsFixed;
17470     UseGPRForF64 = !IsFixed;
17471     break;
17472   }
17473 
17474   // FPR16, FPR32, and FPR64 alias each other.
17475   if (State.getFirstUnallocated(ArgFPR32s) == std::size(ArgFPR32s)) {
17476     UseGPRForF16_F32 = true;
17477     UseGPRForF64 = true;
17478   }
17479 
17480   // From this point on, rely on UseGPRForF16_F32, UseGPRForF64 and
17481   // similar local variables rather than directly checking against the target
17482   // ABI.
17483 
17484   if (UseGPRForF16_F32 &&
17485       (ValVT == MVT::f16 || ValVT == MVT::bf16 || ValVT == MVT::f32)) {
17486     LocVT = XLenVT;
17487     LocInfo = CCValAssign::BCvt;
17488   } else if (UseGPRForF64 && XLen == 64 && ValVT == MVT::f64) {
17489     LocVT = MVT::i64;
17490     LocInfo = CCValAssign::BCvt;
17491   }
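  // For example, under a soft-float ABI an f32 argument is bit-converted to
  // i32 (i64 on RV64) and handled by the integer rules below, and a variadic
  // double under LP64D is likewise passed in a GPR as an i64.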
17492 
17493   ArrayRef<MCPhysReg> ArgGPRs = RISCV::getArgGPRs(ABI);
17494 
17495   // If this is a variadic argument, the RISC-V calling convention requires
17496   // that it is assigned an 'even' or 'aligned' register if it has 8-byte
17497   // alignment (RV32) or 16-byte alignment (RV64). An aligned register should
17498   // be used regardless of whether the original argument was split during
17499   // legalisation or not. The argument will not be passed by registers if the
17500   // original type is larger than 2*XLEN, so the register alignment rule does
17501   // not apply.
17502   // TODO: To be compatible with GCC's behavior, we currently don't align
17503   // registers when using the ILP32E calling convention. This behavior may
17504   // change when RV32E/ILP32E is ratified.
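  // For example, a variadic double under ILP32D has 8-byte size and alignment
  // (2*XLEN), so if the next free GPR would be odd-numbered (say a1), that
  // register is skipped and the value lands in the aligned a2/a3 pair.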
17505   unsigned TwoXLenInBytes = (2 * XLen) / 8;
17506   if (!IsFixed && ArgFlags.getNonZeroOrigAlign() == TwoXLenInBytes &&
17507       DL.getTypeAllocSize(OrigTy) == TwoXLenInBytes &&
17508       ABI != RISCVABI::ABI_ILP32E) {
17509     unsigned RegIdx = State.getFirstUnallocated(ArgGPRs);
17510     // Skip 'odd' register if necessary.
17511     if (RegIdx != std::size(ArgGPRs) && RegIdx % 2 == 1)
17512       State.AllocateReg(ArgGPRs);
17513   }
17514 
17515   SmallVectorImpl<CCValAssign> &PendingLocs = State.getPendingLocs();
17516   SmallVectorImpl<ISD::ArgFlagsTy> &PendingArgFlags =
17517       State.getPendingArgFlags();
17518 
17519   assert(PendingLocs.size() == PendingArgFlags.size() &&
17520          "PendingLocs and PendingArgFlags out of sync");
17521 
17522   // Handle passing f64 on RV32D with a soft float ABI or when floating point
17523   // registers are exhausted.
17524   if (UseGPRForF64 && XLen == 32 && ValVT == MVT::f64) {
17525     assert(PendingLocs.empty() && "Can't lower f64 if it is split");
17526     // Depending on available argument GPRs, f64 may be passed in a pair of
17527     // GPRs, split between a GPR and the stack, or passed completely on the
17528     // stack. LowerCall/LowerFormalArguments/LowerReturn must recognise these
17529     // cases.
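    // For example, if only a7 remains, the first half of the f64 travels in a7
    // and the second half in a 4-byte stack slot; if no GPRs remain at all,
    // the whole f64 is placed in an 8-byte, 8-aligned stack slot.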
17530     Register Reg = State.AllocateReg(ArgGPRs);
17531     if (!Reg) {
17532       unsigned StackOffset = State.AllocateStack(8, Align(8));
17533       State.addLoc(
17534           CCValAssign::getMem(ValNo, ValVT, StackOffset, LocVT, LocInfo));
17535       return false;
17536     }
17537     LocVT = MVT::i32;
17538     State.addLoc(CCValAssign::getCustomReg(ValNo, ValVT, Reg, LocVT, LocInfo));
17539     Register HiReg = State.AllocateReg(ArgGPRs);
17540     if (HiReg) {
17541       State.addLoc(
17542           CCValAssign::getCustomReg(ValNo, ValVT, HiReg, LocVT, LocInfo));
17543     } else {
17544       unsigned StackOffset = State.AllocateStack(4, Align(4));
17545       State.addLoc(
17546           CCValAssign::getCustomMem(ValNo, ValVT, StackOffset, LocVT, LocInfo));
17547     }
17548     return false;
17549   }
17550 
17551   // Fixed-length vectors are located in the corresponding scalable-vector
17552   // container types.
17553   if (ValVT.isFixedLengthVector())
17554     LocVT = TLI.getContainerForFixedLengthVector(LocVT);
17555 
17556   // Split arguments might be passed indirectly, so keep track of the pending
17557   // values. Split vectors are passed via a mix of registers and indirectly, so
17558   // treat them as we would any other argument.
17559   if (ValVT.isScalarInteger() && (ArgFlags.isSplit() || !PendingLocs.empty())) {
17560     LocVT = XLenVT;
17561     LocInfo = CCValAssign::Indirect;
17562     PendingLocs.push_back(
17563         CCValAssign::getPending(ValNo, ValVT, LocVT, LocInfo));
17564     PendingArgFlags.push_back(ArgFlags);
17565     if (!ArgFlags.isSplitEnd()) {
17566       return false;
17567     }
17568   }
17569 
17570   // If the split argument only had two elements, it should be passed directly
17571   // in registers or on the stack.
17572   if (ValVT.isScalarInteger() && ArgFlags.isSplitEnd() &&
17573       PendingLocs.size() <= 2) {
17574     assert(PendingLocs.size() == 2 && "Unexpected PendingLocs.size()");
17575     // Apply the normal calling convention rules to the first half of the
17576     // split argument.
17577     CCValAssign VA = PendingLocs[0];
17578     ISD::ArgFlagsTy AF = PendingArgFlags[0];
17579     PendingLocs.clear();
17580     PendingArgFlags.clear();
17581     return CC_RISCVAssign2XLen(
17582         XLen, State, VA, AF, ValNo, ValVT, LocVT, ArgFlags,
17583         ABI == RISCVABI::ABI_ILP32E || ABI == RISCVABI::ABI_LP64E);
17584   }
17585 
17586   // Allocate to a register if possible, or else a stack slot.
17587   Register Reg;
17588   unsigned StoreSizeBytes = XLen / 8;
17589   Align StackAlign = Align(XLen / 8);
17590 
17591   if ((ValVT == MVT::f16 || ValVT == MVT::bf16) && !UseGPRForF16_F32)
17592     Reg = State.AllocateReg(ArgFPR16s);
17593   else if (ValVT == MVT::f32 && !UseGPRForF16_F32)
17594     Reg = State.AllocateReg(ArgFPR32s);
17595   else if (ValVT == MVT::f64 && !UseGPRForF64)
17596     Reg = State.AllocateReg(ArgFPR64s);
17597   else if (ValVT.isVector()) {
17598     Reg = allocateRVVReg(ValVT, ValNo, FirstMaskArgument, State, TLI);
17599     if (!Reg) {
17600       // For return values, the vector must be passed fully via registers or
17601       // via the stack.
17602       // FIXME: The proposed vector ABI only mandates v8-v15 for return values,
17603       // but we're using all of them.
17604       if (IsRet)
17605         return true;
17606       // Try using a GPR to pass the address
17607       if ((Reg = State.AllocateReg(ArgGPRs))) {
17608         LocVT = XLenVT;
17609         LocInfo = CCValAssign::Indirect;
17610       } else if (ValVT.isScalableVector()) {
17611         LocVT = XLenVT;
17612         LocInfo = CCValAssign::Indirect;
17613       } else {
17614         // Pass fixed-length vectors on the stack.
17615         LocVT = ValVT;
17616         StoreSizeBytes = ValVT.getStoreSize();
17617         // Align vectors to their element sizes, being careful for vXi1
17618         // vectors.
17619         StackAlign = MaybeAlign(ValVT.getScalarSizeInBits() / 8).valueOrOne();
17620       }
17621     }
17622   } else {
17623     Reg = State.AllocateReg(ArgGPRs);
17624   }
17625 
17626   unsigned StackOffset =
17627       Reg ? 0 : State.AllocateStack(StoreSizeBytes, StackAlign);
17628 
17629   // If we reach this point and PendingLocs is non-empty, we must be at the
17630   // end of a split argument that must be passed indirectly.
17631   if (!PendingLocs.empty()) {
17632     assert(ArgFlags.isSplitEnd() && "Expected ArgFlags.isSplitEnd()");
17633     assert(PendingLocs.size() > 2 && "Unexpected PendingLocs.size()");
17634 
17635     for (auto &It : PendingLocs) {
17636       if (Reg)
17637         It.convertToReg(Reg);
17638       else
17639         It.convertToMem(StackOffset);
17640       State.addLoc(It);
17641     }
17642     PendingLocs.clear();
17643     PendingArgFlags.clear();
17644     return false;
17645   }
17646 
17647   assert((!UseGPRForF16_F32 || !UseGPRForF64 || LocVT == XLenVT ||
17648           (TLI.getSubtarget().hasVInstructions() && ValVT.isVector())) &&
17649          "Expected an XLenVT or vector types at this stage");
17650 
17651   if (Reg) {
17652     State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
17653     return false;
17654   }
17655 
17656   // When a scalar floating-point value is passed on the stack, no
17657   // bit-conversion is needed.
17658   if (ValVT.isFloatingPoint() && LocInfo != CCValAssign::Indirect) {
17659     assert(!ValVT.isVector());
17660     LocVT = ValVT;
17661     LocInfo = CCValAssign::Full;
17662   }
17663   State.addLoc(CCValAssign::getMem(ValNo, ValVT, StackOffset, LocVT, LocInfo));
17664   return false;
17665 }
17666 
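// Return the index of the first argument whose type is a vector of i1 (a
// mask), if any, so the calling convention can pre-assign it to V0.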
17667 template <typename ArgTy>
17668 static std::optional<unsigned> preAssignMask(const ArgTy &Args) {
17669   for (const auto &ArgIdx : enumerate(Args)) {
17670     MVT ArgVT = ArgIdx.value().VT;
17671     if (ArgVT.isVector() && ArgVT.getVectorElementType() == MVT::i1)
17672       return ArgIdx.index();
17673   }
17674   return std::nullopt;
17675 }
17676 
17677 void RISCVTargetLowering::analyzeInputArgs(
17678     MachineFunction &MF, CCState &CCInfo,
17679     const SmallVectorImpl<ISD::InputArg> &Ins, bool IsRet,
17680     RISCVCCAssignFn Fn) const {
17681   unsigned NumArgs = Ins.size();
17682   FunctionType *FType = MF.getFunction().getFunctionType();
17683 
17684   std::optional<unsigned> FirstMaskArgument;
17685   if (Subtarget.hasVInstructions())
17686     FirstMaskArgument = preAssignMask(Ins);
17687 
17688   for (unsigned i = 0; i != NumArgs; ++i) {
17689     MVT ArgVT = Ins[i].VT;
17690     ISD::ArgFlagsTy ArgFlags = Ins[i].Flags;
17691 
17692     Type *ArgTy = nullptr;
17693     if (IsRet)
17694       ArgTy = FType->getReturnType();
17695     else if (Ins[i].isOrigArg())
17696       ArgTy = FType->getParamType(Ins[i].getOrigArgIndex());
17697 
17698     RISCVABI::ABI ABI = MF.getSubtarget<RISCVSubtarget>().getTargetABI();
17699     if (Fn(MF.getDataLayout(), ABI, i, ArgVT, ArgVT, CCValAssign::Full,
17700            ArgFlags, CCInfo, /*IsFixed=*/true, IsRet, ArgTy, *this,
17701            FirstMaskArgument)) {
17702       LLVM_DEBUG(dbgs() << "InputArg #" << i << " has unhandled type "
17703                         << ArgVT << '\n');
17704       llvm_unreachable(nullptr);
17705     }
17706   }
17707 }
17708 
17709 void RISCVTargetLowering::analyzeOutputArgs(
17710     MachineFunction &MF, CCState &CCInfo,
17711     const SmallVectorImpl<ISD::OutputArg> &Outs, bool IsRet,
17712     CallLoweringInfo *CLI, RISCVCCAssignFn Fn) const {
17713   unsigned NumArgs = Outs.size();
17714 
17715   std::optional<unsigned> FirstMaskArgument;
17716   if (Subtarget.hasVInstructions())
17717     FirstMaskArgument = preAssignMask(Outs);
17718 
17719   for (unsigned i = 0; i != NumArgs; i++) {
17720     MVT ArgVT = Outs[i].VT;
17721     ISD::ArgFlagsTy ArgFlags = Outs[i].Flags;
17722     Type *OrigTy = CLI ? CLI->getArgs()[Outs[i].OrigArgIndex].Ty : nullptr;
17723 
17724     RISCVABI::ABI ABI = MF.getSubtarget<RISCVSubtarget>().getTargetABI();
17725     if (Fn(MF.getDataLayout(), ABI, i, ArgVT, ArgVT, CCValAssign::Full,
17726            ArgFlags, CCInfo, Outs[i].IsFixed, IsRet, OrigTy, *this,
17727            FirstMaskArgument)) {
17728       LLVM_DEBUG(dbgs() << "OutputArg #" << i << " has unhandled type "
17729                         << ArgVT << "\n");
17730       llvm_unreachable(nullptr);
17731     }
17732   }
17733 }
17734 
17735 // Convert Val to a ValVT. Should not be called for CCValAssign::Indirect
17736 // values.
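// For example, an f32 value that arrived bit-cast in a GPR under LP64 is
// recovered with FMV_W_X_RV64 (or a truncate plus bitcast when i32 is legal
// on RV64), and f16/bf16 values are recovered with FMV_H_X.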
17737 static SDValue convertLocVTToValVT(SelectionDAG &DAG, SDValue Val,
17738                                    const CCValAssign &VA, const SDLoc &DL,
17739                                    const RISCVSubtarget &Subtarget) {
17740   switch (VA.getLocInfo()) {
17741   default:
17742     llvm_unreachable("Unexpected CCValAssign::LocInfo");
17743   case CCValAssign::Full:
17744     if (VA.getValVT().isFixedLengthVector() && VA.getLocVT().isScalableVector())
17745       Val = convertFromScalableVector(VA.getValVT(), Val, DAG, Subtarget);
17746     break;
17747   case CCValAssign::BCvt:
17748     if (VA.getLocVT().isInteger() &&
17749         (VA.getValVT() == MVT::f16 || VA.getValVT() == MVT::bf16)) {
17750       Val = DAG.getNode(RISCVISD::FMV_H_X, DL, VA.getValVT(), Val);
17751     } else if (VA.getLocVT() == MVT::i64 && VA.getValVT() == MVT::f32) {
17752       if (RV64LegalI32) {
17753         Val = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Val);
17754         Val = DAG.getNode(ISD::BITCAST, DL, MVT::f32, Val);
17755       } else {
17756         Val = DAG.getNode(RISCVISD::FMV_W_X_RV64, DL, MVT::f32, Val);
17757       }
17758     } else {
17759       Val = DAG.getNode(ISD::BITCAST, DL, VA.getValVT(), Val);
17760     }
17761     break;
17762   }
17763   return Val;
17764 }
17765 
17766 // The caller is responsible for loading the full value if the argument is
17767 // passed with CCValAssign::Indirect.
17768 static SDValue unpackFromRegLoc(SelectionDAG &DAG, SDValue Chain,
17769                                 const CCValAssign &VA, const SDLoc &DL,
17770                                 const ISD::InputArg &In,
17771                                 const RISCVTargetLowering &TLI) {
17772   MachineFunction &MF = DAG.getMachineFunction();
17773   MachineRegisterInfo &RegInfo = MF.getRegInfo();
17774   EVT LocVT = VA.getLocVT();
17775   SDValue Val;
17776   const TargetRegisterClass *RC = TLI.getRegClassFor(LocVT.getSimpleVT());
17777   Register VReg = RegInfo.createVirtualRegister(RC);
17778   RegInfo.addLiveIn(VA.getLocReg(), VReg);
17779   Val = DAG.getCopyFromReg(Chain, DL, VReg, LocVT);
17780 
17781   // If input is sign extended from 32 bits, note it for the SExtWRemoval pass.
17782   if (In.isOrigArg()) {
17783     Argument *OrigArg = MF.getFunction().getArg(In.getOrigArgIndex());
17784     if (OrigArg->getType()->isIntegerTy()) {
17785       unsigned BitWidth = OrigArg->getType()->getIntegerBitWidth();
17786       // An input zero extended from i31 can also be considered sign extended.
17787       if ((BitWidth <= 32 && In.Flags.isSExt()) ||
17788           (BitWidth < 32 && In.Flags.isZExt())) {
17789         RISCVMachineFunctionInfo *RVFI = MF.getInfo<RISCVMachineFunctionInfo>();
17790         RVFI->addSExt32Register(VReg);
17791       }
17792     }
17793   }
17794 
17795   if (VA.getLocInfo() == CCValAssign::Indirect)
17796     return Val;
17797 
17798   return convertLocVTToValVT(DAG, Val, VA, DL, TLI.getSubtarget());
17799 }
17800 
17801 static SDValue convertValVTToLocVT(SelectionDAG &DAG, SDValue Val,
17802                                    const CCValAssign &VA, const SDLoc &DL,
17803                                    const RISCVSubtarget &Subtarget) {
17804   EVT LocVT = VA.getLocVT();
17805 
17806   switch (VA.getLocInfo()) {
17807   default:
17808     llvm_unreachable("Unexpected CCValAssign::LocInfo");
17809   case CCValAssign::Full:
17810     if (VA.getValVT().isFixedLengthVector() && LocVT.isScalableVector())
17811       Val = convertToScalableVector(LocVT, Val, DAG, Subtarget);
17812     break;
17813   case CCValAssign::BCvt:
17814     if (LocVT.isInteger() &&
17815         (VA.getValVT() == MVT::f16 || VA.getValVT() == MVT::bf16)) {
17816       Val = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, LocVT, Val);
17817     } else if (LocVT == MVT::i64 && VA.getValVT() == MVT::f32) {
17818       if (RV64LegalI32) {
17819         Val = DAG.getNode(ISD::BITCAST, DL, MVT::i32, Val);
17820         Val = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, Val);
17821       } else {
17822         Val = DAG.getNode(RISCVISD::FMV_X_ANYEXTW_RV64, DL, MVT::i64, Val);
17823       }
17824     } else {
17825       Val = DAG.getNode(ISD::BITCAST, DL, LocVT, Val);
17826     }
17827     break;
17828   }
17829   return Val;
17830 }
17831 
17832 // The caller is responsible for loading the full value if the argument is
17833 // passed with CCValAssign::Indirect.
17834 static SDValue unpackFromMemLoc(SelectionDAG &DAG, SDValue Chain,
17835                                 const CCValAssign &VA, const SDLoc &DL) {
17836   MachineFunction &MF = DAG.getMachineFunction();
17837   MachineFrameInfo &MFI = MF.getFrameInfo();
17838   EVT LocVT = VA.getLocVT();
17839   EVT ValVT = VA.getValVT();
17840   EVT PtrVT = MVT::getIntegerVT(DAG.getDataLayout().getPointerSizeInBits(0));
17841   if (ValVT.isScalableVector()) {
17842     // When the value is a scalable vector, the stack slot holds a pointer to
17843     // the scalable vector value, so ValVT here becomes the pointer-sized type
17844     // (LocVT) rather than the scalable vector type.
17845     ValVT = LocVT;
17846   }
17847   int FI = MFI.CreateFixedObject(ValVT.getStoreSize(), VA.getLocMemOffset(),
17848                                  /*IsImmutable=*/true);
17849   SDValue FIN = DAG.getFrameIndex(FI, PtrVT);
17850   SDValue Val;
17851 
17852   ISD::LoadExtType ExtType;
17853   switch (VA.getLocInfo()) {
17854   default:
17855     llvm_unreachable("Unexpected CCValAssign::LocInfo");
17856   case CCValAssign::Full:
17857   case CCValAssign::Indirect:
17858   case CCValAssign::BCvt:
17859     ExtType = ISD::NON_EXTLOAD;
17860     break;
17861   }
17862   Val = DAG.getExtLoad(
17863       ExtType, DL, LocVT, Chain, FIN,
17864       MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), FI), ValVT);
17865   return Val;
17866 }
17867 
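// Reassemble an f64 argument whose halves were assigned by CC_RISCV to a GPR
// (first half) plus either a second GPR or a 4-byte stack slot (second half),
// combining them with a BuildPairF64 node.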
17868 static SDValue unpackF64OnRV32DSoftABI(SelectionDAG &DAG, SDValue Chain,
17869                                        const CCValAssign &VA,
17870                                        const CCValAssign &HiVA,
17871                                        const SDLoc &DL) {
17872   assert(VA.getLocVT() == MVT::i32 && VA.getValVT() == MVT::f64 &&
17873          "Unexpected VA");
17874   MachineFunction &MF = DAG.getMachineFunction();
17875   MachineFrameInfo &MFI = MF.getFrameInfo();
17876   MachineRegisterInfo &RegInfo = MF.getRegInfo();
17877 
17878   assert(VA.isRegLoc() && "Expected register VA assignment");
17879 
17880   Register LoVReg = RegInfo.createVirtualRegister(&RISCV::GPRRegClass);
17881   RegInfo.addLiveIn(VA.getLocReg(), LoVReg);
17882   SDValue Lo = DAG.getCopyFromReg(Chain, DL, LoVReg, MVT::i32);
17883   SDValue Hi;
17884   if (HiVA.isMemLoc()) {
17885     // Second half of f64 is passed on the stack.
17886     int FI = MFI.CreateFixedObject(4, HiVA.getLocMemOffset(),
17887                                    /*IsImmutable=*/true);
17888     SDValue FIN = DAG.getFrameIndex(FI, MVT::i32);
17889     Hi = DAG.getLoad(MVT::i32, DL, Chain, FIN,
17890                      MachinePointerInfo::getFixedStack(MF, FI));
17891   } else {
17892     // Second half of f64 is passed in another GPR.
17893     Register HiVReg = RegInfo.createVirtualRegister(&RISCV::GPRRegClass);
17894     RegInfo.addLiveIn(HiVA.getLocReg(), HiVReg);
17895     Hi = DAG.getCopyFromReg(Chain, DL, HiVReg, MVT::i32);
17896   }
17897   return DAG.getNode(RISCVISD::BuildPairF64, DL, MVT::f64, Lo, Hi);
17898 }
17899 
17900 // FastCC gives less than a 1% performance improvement on some particular
17901 // benchmarks, but theoretically it may benefit some cases.
17902 bool RISCV::CC_RISCV_FastCC(const DataLayout &DL, RISCVABI::ABI ABI,
17903                             unsigned ValNo, MVT ValVT, MVT LocVT,
17904                             CCValAssign::LocInfo LocInfo,
17905                             ISD::ArgFlagsTy ArgFlags, CCState &State,
17906                             bool IsFixed, bool IsRet, Type *OrigTy,
17907                             const RISCVTargetLowering &TLI,
17908                             std::optional<unsigned> FirstMaskArgument) {
17909   if (LocVT == MVT::i32 || LocVT == MVT::i64) {
17910     if (unsigned Reg = State.AllocateReg(getFastCCArgGPRs(ABI))) {
17911       State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
17912       return false;
17913     }
17914   }
17915 
17916   const RISCVSubtarget &Subtarget = TLI.getSubtarget();
17917 
17918   if (LocVT == MVT::f16 &&
17919       (Subtarget.hasStdExtZfh() || Subtarget.hasStdExtZfhmin())) {
17920     static const MCPhysReg FPR16List[] = {
17921         RISCV::F10_H, RISCV::F11_H, RISCV::F12_H, RISCV::F13_H, RISCV::F14_H,
17922         RISCV::F15_H, RISCV::F16_H, RISCV::F17_H, RISCV::F0_H,  RISCV::F1_H,
17923         RISCV::F2_H,  RISCV::F3_H,  RISCV::F4_H,  RISCV::F5_H,  RISCV::F6_H,
17924         RISCV::F7_H,  RISCV::F28_H, RISCV::F29_H, RISCV::F30_H, RISCV::F31_H};
17925     if (unsigned Reg = State.AllocateReg(FPR16List)) {
17926       State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
17927       return false;
17928     }
17929   }
17930 
17931   if (LocVT == MVT::f32 && Subtarget.hasStdExtF()) {
17932     static const MCPhysReg FPR32List[] = {
17933         RISCV::F10_F, RISCV::F11_F, RISCV::F12_F, RISCV::F13_F, RISCV::F14_F,
17934         RISCV::F15_F, RISCV::F16_F, RISCV::F17_F, RISCV::F0_F,  RISCV::F1_F,
17935         RISCV::F2_F,  RISCV::F3_F,  RISCV::F4_F,  RISCV::F5_F,  RISCV::F6_F,
17936         RISCV::F7_F,  RISCV::F28_F, RISCV::F29_F, RISCV::F30_F, RISCV::F31_F};
17937     if (unsigned Reg = State.AllocateReg(FPR32List)) {
17938       State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
17939       return false;
17940     }
17941   }
17942 
17943   if (LocVT == MVT::f64 && Subtarget.hasStdExtD()) {
17944     static const MCPhysReg FPR64List[] = {
17945         RISCV::F10_D, RISCV::F11_D, RISCV::F12_D, RISCV::F13_D, RISCV::F14_D,
17946         RISCV::F15_D, RISCV::F16_D, RISCV::F17_D, RISCV::F0_D,  RISCV::F1_D,
17947         RISCV::F2_D,  RISCV::F3_D,  RISCV::F4_D,  RISCV::F5_D,  RISCV::F6_D,
17948         RISCV::F7_D,  RISCV::F28_D, RISCV::F29_D, RISCV::F30_D, RISCV::F31_D};
17949     if (unsigned Reg = State.AllocateReg(FPR64List)) {
17950       State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
17951       return false;
17952     }
17953   }
17954 
17955   // Check if there is an available GPR before hitting the stack.
17956   if ((LocVT == MVT::f16 &&
17957        (Subtarget.hasStdExtZhinx() || Subtarget.hasStdExtZhinxmin())) ||
17958       (LocVT == MVT::f32 && Subtarget.hasStdExtZfinx()) ||
17959       (LocVT == MVT::f64 && Subtarget.is64Bit() &&
17960        Subtarget.hasStdExtZdinx())) {
17961     if (unsigned Reg = State.AllocateReg(getFastCCArgGPRs(ABI))) {
17962       State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
17963       return false;
17964     }
17965   }
17966 
17967   if (LocVT == MVT::f16) {
17968     unsigned Offset2 = State.AllocateStack(2, Align(2));
17969     State.addLoc(CCValAssign::getMem(ValNo, ValVT, Offset2, LocVT, LocInfo));
17970     return false;
17971   }
17972 
17973   if (LocVT == MVT::i32 || LocVT == MVT::f32) {
17974     unsigned Offset4 = State.AllocateStack(4, Align(4));
17975     State.addLoc(CCValAssign::getMem(ValNo, ValVT, Offset4, LocVT, LocInfo));
17976     return false;
17977   }
17978 
17979   if (LocVT == MVT::i64 || LocVT == MVT::f64) {
17980     unsigned Offset5 = State.AllocateStack(8, Align(8));
17981     State.addLoc(CCValAssign::getMem(ValNo, ValVT, Offset5, LocVT, LocInfo));
17982     return false;
17983   }
17984 
17985   if (LocVT.isVector()) {
17986     if (unsigned Reg =
17987             allocateRVVReg(ValVT, ValNo, FirstMaskArgument, State, TLI)) {
17988       // Fixed-length vectors are located in the corresponding scalable-vector
17989       // container types.
17990       if (ValVT.isFixedLengthVector())
17991         LocVT = TLI.getContainerForFixedLengthVector(LocVT);
17992       State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
17993     } else {
17994       // Try and pass the address via a "fast" GPR.
17995       if (unsigned GPRReg = State.AllocateReg(getFastCCArgGPRs(ABI))) {
17996         LocInfo = CCValAssign::Indirect;
17997         LocVT = TLI.getSubtarget().getXLenVT();
17998         State.addLoc(CCValAssign::getReg(ValNo, ValVT, GPRReg, LocVT, LocInfo));
17999       } else if (ValVT.isFixedLengthVector()) {
18000         auto StackAlign =
18001             MaybeAlign(ValVT.getScalarSizeInBits() / 8).valueOrOne();
18002         unsigned StackOffset =
18003             State.AllocateStack(ValVT.getStoreSize(), StackAlign);
18004         State.addLoc(
18005             CCValAssign::getMem(ValNo, ValVT, StackOffset, LocVT, LocInfo));
18006       } else {
18007         // Can't pass scalable vectors on the stack.
18008         return true;
18009       }
18010     }
18011 
18012     return false;
18013   }
18014 
18015   return true; // CC didn't match.
18016 }
18017 
18018 bool RISCV::CC_RISCV_GHC(unsigned ValNo, MVT ValVT, MVT LocVT,
18019                          CCValAssign::LocInfo LocInfo,
18020                          ISD::ArgFlagsTy ArgFlags, CCState &State) {
18021   if (ArgFlags.isNest()) {
18022     report_fatal_error(
18023         "Attribute 'nest' is not supported in GHC calling convention");
18024   }
18025 
18026   static const MCPhysReg GPRList[] = {
18027       RISCV::X9,  RISCV::X18, RISCV::X19, RISCV::X20, RISCV::X21, RISCV::X22,
18028       RISCV::X23, RISCV::X24, RISCV::X25, RISCV::X26, RISCV::X27};
18029 
18030   if (LocVT == MVT::i32 || LocVT == MVT::i64) {
18031     // Pass in STG registers: Base, Sp, Hp, R1, R2, R3, R4, R5, R6, R7, SpLim
18032     //                        s1    s2  s3  s4  s5  s6  s7  s8  s9  s10 s11
18033     if (unsigned Reg = State.AllocateReg(GPRList)) {
18034       State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
18035       return false;
18036     }
18037   }
18038 
18039   const RISCVSubtarget &Subtarget =
18040       State.getMachineFunction().getSubtarget<RISCVSubtarget>();
18041 
18042   if (LocVT == MVT::f32 && Subtarget.hasStdExtF()) {
18043     // Pass in STG registers: F1, ..., F6
18044     //                        fs0 ... fs5
18045     static const MCPhysReg FPR32List[] = {RISCV::F8_F, RISCV::F9_F,
18046                                           RISCV::F18_F, RISCV::F19_F,
18047                                           RISCV::F20_F, RISCV::F21_F};
18048     if (unsigned Reg = State.AllocateReg(FPR32List)) {
18049       State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
18050       return false;
18051     }
18052   }
18053 
18054   if (LocVT == MVT::f64 && Subtarget.hasStdExtD()) {
18055     // Pass in STG registers: D1, ..., D6
18056     //                        fs6 ... fs11
18057     static const MCPhysReg FPR64List[] = {RISCV::F22_D, RISCV::F23_D,
18058                                           RISCV::F24_D, RISCV::F25_D,
18059                                           RISCV::F26_D, RISCV::F27_D};
18060     if (unsigned Reg = State.AllocateReg(FPR64List)) {
18061       State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
18062       return false;
18063     }
18064   }
18065 
18066   if ((LocVT == MVT::f32 && Subtarget.hasStdExtZfinx()) ||
18067       (LocVT == MVT::f64 && Subtarget.hasStdExtZdinx() &&
18068        Subtarget.is64Bit())) {
18069     if (unsigned Reg = State.AllocateReg(GPRList)) {
18070       State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
18071       return false;
18072     }
18073   }
18074 
18075   report_fatal_error("No registers left in GHC calling convention");
18076   return true;
18077 }
18078 
18079 // Transform physical registers into virtual registers.
18080 SDValue RISCVTargetLowering::LowerFormalArguments(
18081     SDValue Chain, CallingConv::ID CallConv, bool IsVarArg,
18082     const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &DL,
18083     SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
18084 
18085   MachineFunction &MF = DAG.getMachineFunction();
18086 
18087   switch (CallConv) {
18088   default:
18089     report_fatal_error("Unsupported calling convention");
18090   case CallingConv::C:
18091   case CallingConv::Fast:
18092   case CallingConv::SPIR_KERNEL:
18093   case CallingConv::GRAAL:
18094     break;
18095   case CallingConv::GHC:
18096     if (Subtarget.isRVE())
18097       report_fatal_error("GHC calling convention is not supported on RVE!");
18098     if (!Subtarget.hasStdExtFOrZfinx() || !Subtarget.hasStdExtDOrZdinx())
18099       report_fatal_error("GHC calling convention requires the (Zfinx/F) and "
18100                          "(Zdinx/D) instruction set extensions");
18101   }
18102 
18103   const Function &Func = MF.getFunction();
18104   if (Func.hasFnAttribute("interrupt")) {
18105     if (!Func.arg_empty())
18106       report_fatal_error(
18107         "Functions with the interrupt attribute cannot have arguments!");
18108 
18109     StringRef Kind =
18110       MF.getFunction().getFnAttribute("interrupt").getValueAsString();
18111 
18112     if (!(Kind == "user" || Kind == "supervisor" || Kind == "machine"))
18113       report_fatal_error(
18114         "Function interrupt attribute argument not supported!");
18115   }
18116 
18117   EVT PtrVT = getPointerTy(DAG.getDataLayout());
18118   MVT XLenVT = Subtarget.getXLenVT();
18119   unsigned XLenInBytes = Subtarget.getXLen() / 8;
18120   // Used with varargs to accumulate store chains.
18121   std::vector<SDValue> OutChains;
18122 
18123   // Assign locations to all of the incoming arguments.
18124   SmallVector<CCValAssign, 16> ArgLocs;
18125   CCState CCInfo(CallConv, IsVarArg, MF, ArgLocs, *DAG.getContext());
18126 
18127   if (CallConv == CallingConv::GHC)
18128     CCInfo.AnalyzeFormalArguments(Ins, RISCV::CC_RISCV_GHC);
18129   else
18130     analyzeInputArgs(MF, CCInfo, Ins, /*IsRet=*/false,
18131                      CallConv == CallingConv::Fast ? RISCV::CC_RISCV_FastCC
18132                                                    : RISCV::CC_RISCV);
18133 
18134   for (unsigned i = 0, e = ArgLocs.size(), InsIdx = 0; i != e; ++i, ++InsIdx) {
18135     CCValAssign &VA = ArgLocs[i];
18136     SDValue ArgValue;
18137     // Passing f64 on RV32D with a soft float ABI must be handled as a special
18138     // case.
18139     if (VA.getLocVT() == MVT::i32 && VA.getValVT() == MVT::f64) {
18140       assert(VA.needsCustom());
18141       ArgValue = unpackF64OnRV32DSoftABI(DAG, Chain, VA, ArgLocs[++i], DL);
18142     } else if (VA.isRegLoc())
18143       ArgValue = unpackFromRegLoc(DAG, Chain, VA, DL, Ins[InsIdx], *this);
18144     else
18145       ArgValue = unpackFromMemLoc(DAG, Chain, VA, DL);
18146 
18147     if (VA.getLocInfo() == CCValAssign::Indirect) {
18148       // If the original argument was split and passed by reference (e.g. i128
18149       // on RV32), we need to load all parts of it here (using the same
18150       // address). Vectors may be partly split to registers and partly to the
18151       // stack, in which case the base address is partly offset and subsequent
18152       // stores are relative to that.
18153       InVals.push_back(DAG.getLoad(VA.getValVT(), DL, Chain, ArgValue,
18154                                    MachinePointerInfo()));
18155       unsigned ArgIndex = Ins[InsIdx].OrigArgIndex;
18156       unsigned ArgPartOffset = Ins[InsIdx].PartOffset;
18157       assert(VA.getValVT().isVector() || ArgPartOffset == 0);
18158       while (i + 1 != e && Ins[InsIdx + 1].OrigArgIndex == ArgIndex) {
18159         CCValAssign &PartVA = ArgLocs[i + 1];
18160         unsigned PartOffset = Ins[InsIdx + 1].PartOffset - ArgPartOffset;
18161         SDValue Offset = DAG.getIntPtrConstant(PartOffset, DL);
18162         if (PartVA.getValVT().isScalableVector())
18163           Offset = DAG.getNode(ISD::VSCALE, DL, XLenVT, Offset);
18164         SDValue Address = DAG.getNode(ISD::ADD, DL, PtrVT, ArgValue, Offset);
18165         InVals.push_back(DAG.getLoad(PartVA.getValVT(), DL, Chain, Address,
18166                                      MachinePointerInfo()));
18167         ++i;
18168         ++InsIdx;
18169       }
18170       continue;
18171     }
18172     InVals.push_back(ArgValue);
18173   }
18174 
18175   if (any_of(ArgLocs,
18176              [](CCValAssign &VA) { return VA.getLocVT().isScalableVector(); }))
18177     MF.getInfo<RISCVMachineFunctionInfo>()->setIsVectorCall();
18178 
18179   if (IsVarArg) {
18180     ArrayRef<MCPhysReg> ArgRegs = RISCV::getArgGPRs(Subtarget.getTargetABI());
18181     unsigned Idx = CCInfo.getFirstUnallocated(ArgRegs);
18182     const TargetRegisterClass *RC = &RISCV::GPRRegClass;
18183     MachineFrameInfo &MFI = MF.getFrameInfo();
18184     MachineRegisterInfo &RegInfo = MF.getRegInfo();
18185     RISCVMachineFunctionInfo *RVFI = MF.getInfo<RISCVMachineFunctionInfo>();
18186 
18187     // Size of the vararg save area. For now, the varargs save area is either
18188     // zero or large enough to hold a0-a7.
18189     int VarArgsSaveSize = XLenInBytes * (ArgRegs.size() - Idx);
18190     int FI;
18191 
18192     // If all registers are allocated, then all varargs must be passed on the
18193     // stack and we don't need to save any argregs.
18194     if (VarArgsSaveSize == 0) {
18195       int VaArgOffset = CCInfo.getStackSize();
18196       FI = MFI.CreateFixedObject(XLenInBytes, VaArgOffset, true);
18197     } else {
18198       int VaArgOffset = -VarArgsSaveSize;
18199       FI = MFI.CreateFixedObject(VarArgsSaveSize, VaArgOffset, true);
18200 
18201       // If saving an odd number of registers, create an extra stack slot to
18202       // ensure that the frame pointer is 2*XLEN-aligned, which in turn ensures
18203       // offsets to even-numbered registers remain 2*XLEN-aligned.
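      // For example, on LP64 with three fixed arguments, a3-a7 (five
      // registers) are saved; the extra XLEN-sized slot created below keeps
      // offsets to the even-numbered registers 2*XLEN aligned.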
18204       if (Idx % 2) {
18205         MFI.CreateFixedObject(
18206             XLenInBytes, VaArgOffset - static_cast<int>(XLenInBytes), true);
18207         VarArgsSaveSize += XLenInBytes;
18208       }
18209 
18210       SDValue FIN = DAG.getFrameIndex(FI, PtrVT);
18211 
18212       // Copy the integer registers that may have been used for passing varargs
18213       // to the vararg save area.
18214       for (unsigned I = Idx; I < ArgRegs.size(); ++I) {
18215         const Register Reg = RegInfo.createVirtualRegister(RC);
18216         RegInfo.addLiveIn(ArgRegs[I], Reg);
18217         SDValue ArgValue = DAG.getCopyFromReg(Chain, DL, Reg, XLenVT);
18218         SDValue Store = DAG.getStore(
18219             Chain, DL, ArgValue, FIN,
18220             MachinePointerInfo::getFixedStack(MF, FI, (I - Idx) * XLenInBytes));
18221         OutChains.push_back(Store);
18222         FIN =
18223             DAG.getMemBasePlusOffset(FIN, TypeSize::getFixed(XLenInBytes), DL);
18224       }
18225     }
18226 
18227     // Record the frame index of the first variable argument,
18228     // which is needed by VASTART.
18229     RVFI->setVarArgsFrameIndex(FI);
18230     RVFI->setVarArgsSaveSize(VarArgsSaveSize);
18231   }
18232 
18233   // All stores are grouped in one node so that the sizes of Ins and InVals
18234   // match. This only happens for vararg functions.
18235   if (!OutChains.empty()) {
18236     OutChains.push_back(Chain);
18237     Chain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, OutChains);
18238   }
18239 
18240   return Chain;
18241 }
18242 
18243 /// isEligibleForTailCallOptimization - Check whether the call is eligible
18244 /// for tail call optimization.
18245 /// Note: This is modelled after ARM's IsEligibleForTailCallOptimization.
18246 bool RISCVTargetLowering::isEligibleForTailCallOptimization(
18247     CCState &CCInfo, CallLoweringInfo &CLI, MachineFunction &MF,
18248     const SmallVector<CCValAssign, 16> &ArgLocs) const {
18249 
18250   auto CalleeCC = CLI.CallConv;
18251   auto &Outs = CLI.Outs;
18252   auto &Caller = MF.getFunction();
18253   auto CallerCC = Caller.getCallingConv();
18254 
18255   // Exception-handling functions need a special set of instructions to
18256   // indicate a return to the hardware. Tail-calling another function would
18257   // probably break this.
18258   // TODO: The "interrupt" attribute isn't currently defined by RISC-V. This
18259   // should be expanded as new function attributes are introduced.
18260   if (Caller.hasFnAttribute("interrupt"))
18261     return false;
18262 
18263   // Do not tail call opt if the stack is used to pass parameters.
18264   if (CCInfo.getStackSize() != 0)
18265     return false;
18266 
18267   // Do not tail call opt if any parameters need to be passed indirectly.
18268   // Since long doubles (fp128) and i128 are larger than 2*XLEN, they are
18269   // passed indirectly. So the address of the value will be passed in a
18270   // register, or if not available, then the address is put on the stack. In
18271   // order to pass indirectly, space on the stack often needs to be allocated
18272   // in order to store the value. In this case the CCInfo.getStackSize() != 0
18273   // check is not enough; we also need to check whether any CCValAssign in
18274   // ArgLocs is passed CCValAssign::Indirect.
18275   for (auto &VA : ArgLocs)
18276     if (VA.getLocInfo() == CCValAssign::Indirect)
18277       return false;
18278 
18279   // Do not tail call opt if either caller or callee uses struct return
18280   // semantics.
18281   auto IsCallerStructRet = Caller.hasStructRetAttr();
18282   auto IsCalleeStructRet = Outs.empty() ? false : Outs[0].Flags.isSRet();
18283   if (IsCallerStructRet || IsCalleeStructRet)
18284     return false;
18285 
18286   // The callee has to preserve all registers the caller needs to preserve.
18287   const RISCVRegisterInfo *TRI = Subtarget.getRegisterInfo();
18288   const uint32_t *CallerPreserved = TRI->getCallPreservedMask(MF, CallerCC);
18289   if (CalleeCC != CallerCC) {
18290     const uint32_t *CalleePreserved = TRI->getCallPreservedMask(MF, CalleeCC);
18291     if (!TRI->regmaskSubsetEqual(CallerPreserved, CalleePreserved))
18292       return false;
18293   }
18294 
18295   // Byval parameters hand the function a pointer directly into the stack area
18296   // we want to reuse during a tail call. Working around this *is* possible
18297   // but less efficient and uglier in LowerCall.
18298   for (auto &Arg : Outs)
18299     if (Arg.Flags.isByVal())
18300       return false;
18301 
18302   return true;
18303 }
18304 
18305 static Align getPrefTypeAlign(EVT VT, SelectionDAG &DAG) {
18306   return DAG.getDataLayout().getPrefTypeAlign(
18307       VT.getTypeForEVT(*DAG.getContext()));
18308 }
18309 
18310 // Lower a call to a callseq_start + CALL + callseq_end chain, and add input
18311 // and output parameter nodes.
18312 SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
18313                                        SmallVectorImpl<SDValue> &InVals) const {
18314   SelectionDAG &DAG = CLI.DAG;
18315   SDLoc &DL = CLI.DL;
18316   SmallVectorImpl<ISD::OutputArg> &Outs = CLI.Outs;
18317   SmallVectorImpl<SDValue> &OutVals = CLI.OutVals;
18318   SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
18319   SDValue Chain = CLI.Chain;
18320   SDValue Callee = CLI.Callee;
18321   bool &IsTailCall = CLI.IsTailCall;
18322   CallingConv::ID CallConv = CLI.CallConv;
18323   bool IsVarArg = CLI.IsVarArg;
18324   EVT PtrVT = getPointerTy(DAG.getDataLayout());
18325   MVT XLenVT = Subtarget.getXLenVT();
18326 
18327   MachineFunction &MF = DAG.getMachineFunction();
18328 
18329   // Analyze the operands of the call, assigning locations to each operand.
18330   SmallVector<CCValAssign, 16> ArgLocs;
18331   CCState ArgCCInfo(CallConv, IsVarArg, MF, ArgLocs, *DAG.getContext());
18332 
18333   if (CallConv == CallingConv::GHC) {
18334     if (Subtarget.isRVE())
18335       report_fatal_error("GHC calling convention is not supported on RVE!");
18336     ArgCCInfo.AnalyzeCallOperands(Outs, RISCV::CC_RISCV_GHC);
18337   } else
18338     analyzeOutputArgs(MF, ArgCCInfo, Outs, /*IsRet=*/false, &CLI,
18339                       CallConv == CallingConv::Fast ? RISCV::CC_RISCV_FastCC
18340                                                     : RISCV::CC_RISCV);
18341 
18342   // Check if it's really possible to do a tail call.
18343   if (IsTailCall)
18344     IsTailCall = isEligibleForTailCallOptimization(ArgCCInfo, CLI, MF, ArgLocs);
18345 
18346   if (IsTailCall)
18347     ++NumTailCalls;
18348   else if (CLI.CB && CLI.CB->isMustTailCall())
18349     report_fatal_error("failed to perform tail call elimination on a call "
18350                        "site marked musttail");
18351 
18352   // Get a count of how many bytes are to be pushed on the stack.
18353   unsigned NumBytes = ArgCCInfo.getStackSize();
18354 
18355   // Create local copies for byval args
18356   SmallVector<SDValue, 8> ByValArgs;
18357   for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
18358     ISD::ArgFlagsTy Flags = Outs[i].Flags;
18359     if (!Flags.isByVal())
18360       continue;
18361 
18362     SDValue Arg = OutVals[i];
18363     unsigned Size = Flags.getByValSize();
18364     Align Alignment = Flags.getNonZeroByValAlign();
18365 
18366     int FI =
18367         MF.getFrameInfo().CreateStackObject(Size, Alignment, /*isSS=*/false);
18368     SDValue FIPtr = DAG.getFrameIndex(FI, getPointerTy(DAG.getDataLayout()));
18369     SDValue SizeNode = DAG.getConstant(Size, DL, XLenVT);
18370 
18371     Chain = DAG.getMemcpy(Chain, DL, FIPtr, Arg, SizeNode, Alignment,
18372                           /*IsVolatile=*/false,
18373                           /*AlwaysInline=*/false, IsTailCall,
18374                           MachinePointerInfo(), MachinePointerInfo());
18375     ByValArgs.push_back(FIPtr);
18376   }
18377 
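        // Tail calls reuse the caller's stack frame, so no call frame setup or
        // teardown (CALLSEQ_START/CALLSEQ_END) is emitted for them.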
18378   if (!IsTailCall)
18379     Chain = DAG.getCALLSEQ_START(Chain, NumBytes, 0, CLI.DL);
18380 
18381   // Copy argument values to their designated locations.
18382   SmallVector<std::pair<Register, SDValue>, 8> RegsToPass;
18383   SmallVector<SDValue, 8> MemOpChains;
18384   SDValue StackPtr;
18385   for (unsigned i = 0, j = 0, e = ArgLocs.size(), OutIdx = 0; i != e;
18386        ++i, ++OutIdx) {
18387     CCValAssign &VA = ArgLocs[i];
18388     SDValue ArgValue = OutVals[OutIdx];
18389     ISD::ArgFlagsTy Flags = Outs[OutIdx].Flags;
18390 
18391     // Handle passing f64 on RV32D with a soft float ABI as a special case.
18392     if (VA.getLocVT() == MVT::i32 && VA.getValVT() == MVT::f64) {
18393       assert(VA.isRegLoc() && "Expected register VA assignment");
18394       assert(VA.needsCustom());
18395       SDValue SplitF64 = DAG.getNode(
18396           RISCVISD::SplitF64, DL, DAG.getVTList(MVT::i32, MVT::i32), ArgValue);
18397       SDValue Lo = SplitF64.getValue(0);
18398       SDValue Hi = SplitF64.getValue(1);
18399 
18400       Register RegLo = VA.getLocReg();
18401       RegsToPass.push_back(std::make_pair(RegLo, Lo));
18402 
18403       // Get the CCValAssign for the Hi part.
18404       CCValAssign &HiVA = ArgLocs[++i];
18405 
18406       if (HiVA.isMemLoc()) {
18407         // Second half of f64 is passed on the stack.
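              // Lazily materialize a copy of the stack pointer (x2); it is
              // shared by every stack store emitted for this call.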
18408         if (!StackPtr.getNode())
18409           StackPtr = DAG.getCopyFromReg(Chain, DL, RISCV::X2, PtrVT);
18410         SDValue Address =
18411             DAG.getNode(ISD::ADD, DL, PtrVT, StackPtr,
18412                         DAG.getIntPtrConstant(HiVA.getLocMemOffset(), DL));
18413         // Emit the store.
18414         MemOpChains.push_back(
18415             DAG.getStore(Chain, DL, Hi, Address, MachinePointerInfo()));
18416       } else {
18417         // Second half of f64 is passed in another GPR.
18418         Register RegHigh = HiVA.getLocReg();
18419         RegsToPass.push_back(std::make_pair(RegHigh, Hi));
18420       }
18421       continue;
18422     }
18423 
18424     // Promote the value if needed.
18425     // For now, only handle fully promoted and indirect arguments.
18426     if (VA.getLocInfo() == CCValAssign::Indirect) {
18427       // Store the argument in a stack slot and pass its address.
18428       Align StackAlign =
18429           std::max(getPrefTypeAlign(Outs[OutIdx].ArgVT, DAG),
18430                    getPrefTypeAlign(ArgValue.getValueType(), DAG));
18431       TypeSize StoredSize = ArgValue.getValueType().getStoreSize();
18432       // If the original argument was split (e.g. i128), we need
18433       // to store the required parts of it here (and pass just one address).
18434       // Vectors may be partly split to registers and partly to the stack, in
18435       // which case the base address is partly offset and subsequent stores are
18436       // relative to that.
18437       unsigned ArgIndex = Outs[OutIdx].OrigArgIndex;
18438       unsigned ArgPartOffset = Outs[OutIdx].PartOffset;
18439       assert(VA.getValVT().isVector() || ArgPartOffset == 0);
18440       // Calculate the total size to store. We don't know this up front, so
18441       // walk the remaining parts of the split argument, accumulating the
18442       // size and alignment as we go.
18443       SmallVector<std::pair<SDValue, SDValue>> Parts;
18444       while (i + 1 != e && Outs[OutIdx + 1].OrigArgIndex == ArgIndex) {
18445         SDValue PartValue = OutVals[OutIdx + 1];
18446         unsigned PartOffset = Outs[OutIdx + 1].PartOffset - ArgPartOffset;
18447         SDValue Offset = DAG.getIntPtrConstant(PartOffset, DL);
18448         EVT PartVT = PartValue.getValueType();
18449         if (PartVT.isScalableVector())
18450           Offset = DAG.getNode(ISD::VSCALE, DL, XLenVT, Offset);
18451         StoredSize += PartVT.getStoreSize();
18452         StackAlign = std::max(StackAlign, getPrefTypeAlign(PartVT, DAG));
18453         Parts.push_back(std::make_pair(PartValue, Offset));
18454         ++i;
18455         ++OutIdx;
18456       }
18457       SDValue SpillSlot = DAG.CreateStackTemporary(StoredSize, StackAlign);
18458       int FI = cast<FrameIndexSDNode>(SpillSlot)->getIndex();
18459       MemOpChains.push_back(
18460           DAG.getStore(Chain, DL, ArgValue, SpillSlot,
18461                        MachinePointerInfo::getFixedStack(MF, FI)));
18462       for (const auto &Part : Parts) {
18463         SDValue PartValue = Part.first;
18464         SDValue PartOffset = Part.second;
18465         SDValue Address =
18466             DAG.getNode(ISD::ADD, DL, PtrVT, SpillSlot, PartOffset);
18467         MemOpChains.push_back(
18468             DAG.getStore(Chain, DL, PartValue, Address,
18469                          MachinePointerInfo::getFixedStack(MF, FI)));
18470       }
18471       ArgValue = SpillSlot;
18472     } else {
18473       ArgValue = convertValVTToLocVT(DAG, ArgValue, VA, DL, Subtarget);
18474     }
18475 
18476     // Use local copy if it is a byval arg.
18477     if (Flags.isByVal())
18478       ArgValue = ByValArgs[j++];
18479 
18480     if (VA.isRegLoc()) {
18481       // Queue up the argument copies and emit them at the end.
18482       RegsToPass.push_back(std::make_pair(VA.getLocReg(), ArgValue));
18483     } else {
18484       assert(VA.isMemLoc() && "Argument not register or memory");
18485       assert(!IsTailCall && "Tail call not allowed if stack is used "
18486                             "for passing parameters");
18487 
18488       // Work out the address of the stack slot.
18489       if (!StackPtr.getNode())
18490         StackPtr = DAG.getCopyFromReg(Chain, DL, RISCV::X2, PtrVT);
18491       SDValue Address =
18492           DAG.getNode(ISD::ADD, DL, PtrVT, StackPtr,
18493                       DAG.getIntPtrConstant(VA.getLocMemOffset(), DL));
18494 
18495       // Emit the store.
18496       MemOpChains.push_back(
18497           DAG.getStore(Chain, DL, ArgValue, Address, MachinePointerInfo()));
18498     }
18499   }
18500 
18501   // Join the stores, which are independent of one another.
18502   if (!MemOpChains.empty())
18503     Chain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, MemOpChains);
18504 
18505   SDValue Glue;
18506 
18507   // Build a sequence of copy-to-reg nodes, chained and glued together.
18508   for (auto &Reg : RegsToPass) {
18509     Chain = DAG.getCopyToReg(Chain, DL, Reg.first, Reg.second, Glue);
18510     Glue = Chain.getValue(1);
18511   }
18512 
18513   // Validate that none of the argument registers have been marked as
18514   // reserved; if so, report an error. Do the same for the return address if
18515   // this is not a tail call.
18516   validateCCReservedRegs(RegsToPass, MF);
18517   if (!IsTailCall &&
18518       MF.getSubtarget<RISCVSubtarget>().isRegisterReservedByUser(RISCV::X1))
18519     MF.getFunction().getContext().diagnose(DiagnosticInfoUnsupported{
18520         MF.getFunction(),
18521         "Return address register required, but has been reserved."});
18522 
18523   // If the callee is a GlobalAddress/ExternalSymbol node, turn it into a
18524   // TargetGlobalAddress/TargetExternalSymbol node so that legalize won't
18525   // split it and then direct call can be matched by PseudoCALL.
18526   if (GlobalAddressSDNode *S = dyn_cast<GlobalAddressSDNode>(Callee)) {
18527     const GlobalValue *GV = S->getGlobal();
18528     Callee = DAG.getTargetGlobalAddress(GV, DL, PtrVT, 0, RISCVII::MO_CALL);
18529   } else if (ExternalSymbolSDNode *S = dyn_cast<ExternalSymbolSDNode>(Callee)) {
18530     Callee = DAG.getTargetExternalSymbol(S->getSymbol(), PtrVT, RISCVII::MO_CALL);
18531   }
18532 
18533   // The first call operand is the chain and the second is the target address.
18534   SmallVector<SDValue, 8> Ops;
18535   Ops.push_back(Chain);
18536   Ops.push_back(Callee);
18537 
18538   // Add argument registers to the end of the list so that they are
18539   // known live into the call.
18540   for (auto &Reg : RegsToPass)
18541     Ops.push_back(DAG.getRegister(Reg.first, Reg.second.getValueType()));
18542 
18543   if (!IsTailCall) {
18544     // Add a register mask operand representing the call-preserved registers.
18545     const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
18546     const uint32_t *Mask = TRI->getCallPreservedMask(MF, CallConv);
18547     assert(Mask && "Missing call preserved mask for calling convention");
18548     Ops.push_back(DAG.getRegisterMask(Mask));
18549   }
18550 
18551   // Glue the call to the argument copies, if any.
18552   if (Glue.getNode())
18553     Ops.push_back(Glue);
18554 
18555   assert((!CLI.CFIType || CLI.CB->isIndirectCall()) &&
18556          "Unexpected CFI type for a direct call");
18557 
18558   // Emit the call.
18559   SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue);
18560 
18561   if (IsTailCall) {
18562     MF.getFrameInfo().setHasTailCall();
18563     SDValue Ret = DAG.getNode(RISCVISD::TAIL, DL, NodeTys, Ops);
18564     if (CLI.CFIType)
18565       Ret.getNode()->setCFIType(CLI.CFIType->getZExtValue());
18566     DAG.addNoMergeSiteInfo(Ret.getNode(), CLI.NoMerge);
18567     return Ret;
18568   }
18569 
18570   Chain = DAG.getNode(RISCVISD::CALL, DL, NodeTys, Ops);
18571   if (CLI.CFIType)
18572     Chain.getNode()->setCFIType(CLI.CFIType->getZExtValue());
18573   DAG.addNoMergeSiteInfo(Chain.getNode(), CLI.NoMerge);
18574   Glue = Chain.getValue(1);
18575 
18576   // Mark the end of the call, which is glued to the call itself.
18577   Chain = DAG.getCALLSEQ_END(Chain, NumBytes, 0, Glue, DL);
18578   Glue = Chain.getValue(1);
18579 
18580   // Assign locations to each value returned by this call.
18581   SmallVector<CCValAssign, 16> RVLocs;
18582   CCState RetCCInfo(CallConv, IsVarArg, MF, RVLocs, *DAG.getContext());
18583   analyzeInputArgs(MF, RetCCInfo, Ins, /*IsRet=*/true, RISCV::CC_RISCV);
18584 
18585   // Copy all of the result registers out of their specified physreg.
18586   for (unsigned i = 0, e = RVLocs.size(); i != e; ++i) {
18587     auto &VA = RVLocs[i];
18588     // Copy the value out
18589     SDValue RetValue =
18590         DAG.getCopyFromReg(Chain, DL, VA.getLocReg(), VA.getLocVT(), Glue);
18591     // Glue the RetValue to the end of the call sequence
18592     Chain = RetValue.getValue(1);
18593     Glue = RetValue.getValue(2);
18594 
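          // For f64 returned on RV32 with a soft-float ABI, the value comes
          // back in a pair of i32 GPRs; read the second half and rebuild the
          // f64 with BuildPairF64.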
18595     if (VA.getLocVT() == MVT::i32 && VA.getValVT() == MVT::f64) {
18596       assert(VA.needsCustom());
18597       SDValue RetValue2 = DAG.getCopyFromReg(Chain, DL, RVLocs[++i].getLocReg(),
18598                                              MVT::i32, Glue);
18599       Chain = RetValue2.getValue(1);
18600       Glue = RetValue2.getValue(2);
18601       RetValue = DAG.getNode(RISCVISD::BuildPairF64, DL, MVT::f64, RetValue,
18602                              RetValue2);
18603     }
18604 
18605     RetValue = convertLocVTToValVT(DAG, RetValue, VA, DL, Subtarget);
18606 
18607     InVals.push_back(RetValue);
18608   }
18609 
18610   return Chain;
18611 }
18612 
18613 bool RISCVTargetLowering::CanLowerReturn(
18614     CallingConv::ID CallConv, MachineFunction &MF, bool IsVarArg,
18615     const SmallVectorImpl<ISD::OutputArg> &Outs, LLVMContext &Context) const {
18616   SmallVector<CCValAssign, 16> RVLocs;
18617   CCState CCInfo(CallConv, IsVarArg, MF, RVLocs, Context);
18618 
18619   std::optional<unsigned> FirstMaskArgument;
18620   if (Subtarget.hasVInstructions())
18621     FirstMaskArgument = preAssignMask(Outs);
18622 
18623   for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
18624     MVT VT = Outs[i].VT;
18625     ISD::ArgFlagsTy ArgFlags = Outs[i].Flags;
18626     RISCVABI::ABI ABI = MF.getSubtarget<RISCVSubtarget>().getTargetABI();
18627     if (RISCV::CC_RISCV(MF.getDataLayout(), ABI, i, VT, VT, CCValAssign::Full,
18628                  ArgFlags, CCInfo, /*IsFixed=*/true, /*IsRet=*/true, nullptr,
18629                  *this, FirstMaskArgument))
18630       return false;
18631   }
18632   return true;
18633 }
18634 
18635 SDValue
18636 RISCVTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
18637                                  bool IsVarArg,
18638                                  const SmallVectorImpl<ISD::OutputArg> &Outs,
18639                                  const SmallVectorImpl<SDValue> &OutVals,
18640                                  const SDLoc &DL, SelectionDAG &DAG) const {
18641   MachineFunction &MF = DAG.getMachineFunction();
18642   const RISCVSubtarget &STI = MF.getSubtarget<RISCVSubtarget>();
18643 
18644   // Stores the assignment of the return value to a location.
18645   SmallVector<CCValAssign, 16> RVLocs;
18646 
18647   // Info about the registers and stack slot.
18648   CCState CCInfo(CallConv, IsVarArg, DAG.getMachineFunction(), RVLocs,
18649                  *DAG.getContext());
18650 
18651   analyzeOutputArgs(DAG.getMachineFunction(), CCInfo, Outs, /*IsRet=*/true,
18652                     nullptr, RISCV::CC_RISCV);
18653 
18654   if (CallConv == CallingConv::GHC && !RVLocs.empty())
18655     report_fatal_error("GHC functions return void only");
18656 
18657   SDValue Glue;
18658   SmallVector<SDValue, 4> RetOps(1, Chain);
18659 
18660   // Copy the result values into the output registers.
18661   for (unsigned i = 0, e = RVLocs.size(), OutIdx = 0; i < e; ++i, ++OutIdx) {
18662     SDValue Val = OutVals[OutIdx];
18663     CCValAssign &VA = RVLocs[i];
18664     assert(VA.isRegLoc() && "Can only return in registers!");
18665 
18666     if (VA.getLocVT() == MVT::i32 && VA.getValVT() == MVT::f64) {
18667       // Handle returning f64 on RV32D with a soft float ABI.
18668       assert(VA.isRegLoc() && "Expected return via registers");
18669       assert(VA.needsCustom());
18670       SDValue SplitF64 = DAG.getNode(RISCVISD::SplitF64, DL,
18671                                      DAG.getVTList(MVT::i32, MVT::i32), Val);
18672       SDValue Lo = SplitF64.getValue(0);
18673       SDValue Hi = SplitF64.getValue(1);
18674       Register RegLo = VA.getLocReg();
18675       Register RegHi = RVLocs[++i].getLocReg();
18676 
18677       if (STI.isRegisterReservedByUser(RegLo) ||
18678           STI.isRegisterReservedByUser(RegHi))
18679         MF.getFunction().getContext().diagnose(DiagnosticInfoUnsupported{
18680             MF.getFunction(),
18681             "Return value register required, but has been reserved."});
18682 
18683       Chain = DAG.getCopyToReg(Chain, DL, RegLo, Lo, Glue);
18684       Glue = Chain.getValue(1);
18685       RetOps.push_back(DAG.getRegister(RegLo, MVT::i32));
18686       Chain = DAG.getCopyToReg(Chain, DL, RegHi, Hi, Glue);
18687       Glue = Chain.getValue(1);
18688       RetOps.push_back(DAG.getRegister(RegHi, MVT::i32));
18689     } else {
18690       // Handle a 'normal' return.
18691       Val = convertValVTToLocVT(DAG, Val, VA, DL, Subtarget);
18692       Chain = DAG.getCopyToReg(Chain, DL, VA.getLocReg(), Val, Glue);
18693 
18694       if (STI.isRegisterReservedByUser(VA.getLocReg()))
18695         MF.getFunction().getContext().diagnose(DiagnosticInfoUnsupported{
18696             MF.getFunction(),
18697             "Return value register required, but has been reserved."});
18698 
18699       // Guarantee that all emitted copies are stuck together.
18700       Glue = Chain.getValue(1);
18701       RetOps.push_back(DAG.getRegister(VA.getLocReg(), VA.getLocVT()));
18702     }
18703   }
18704 
18705   RetOps[0] = Chain; // Update chain.
18706 
18707   // Add the glue node if we have it.
18708   if (Glue.getNode()) {
18709     RetOps.push_back(Glue);
18710   }
18711 
18712   if (any_of(RVLocs,
18713              [](CCValAssign &VA) { return VA.getLocVT().isScalableVector(); }))
18714     MF.getInfo<RISCVMachineFunctionInfo>()->setIsVectorCall();
18715 
18716   unsigned RetOpc = RISCVISD::RET_GLUE;
18717   // Interrupt service routines use different return instructions.
18718   const Function &Func = DAG.getMachineFunction().getFunction();
18719   if (Func.hasFnAttribute("interrupt")) {
18720     if (!Func.getReturnType()->isVoidTy())
18721       report_fatal_error(
18722           "Functions with the interrupt attribute must have void return type!");
18723 
18724     MachineFunction &MF = DAG.getMachineFunction();
18725     StringRef Kind =
18726       MF.getFunction().getFnAttribute("interrupt").getValueAsString();
18727 
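          // "supervisor" handlers return with sret; any other interrupt kind
          // (e.g. "machine") returns with mret.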
18728     if (Kind == "supervisor")
18729       RetOpc = RISCVISD::SRET_GLUE;
18730     else
18731       RetOpc = RISCVISD::MRET_GLUE;
18732   }
18733 
18734   return DAG.getNode(RetOpc, DL, MVT::Other, RetOps);
18735 }
18736 
18737 void RISCVTargetLowering::validateCCReservedRegs(
18738     const SmallVectorImpl<std::pair<llvm::Register, llvm::SDValue>> &Regs,
18739     MachineFunction &MF) const {
18740   const Function &F = MF.getFunction();
18741   const RISCVSubtarget &STI = MF.getSubtarget<RISCVSubtarget>();
18742 
18743   if (llvm::any_of(Regs, [&STI](auto Reg) {
18744         return STI.isRegisterReservedByUser(Reg.first);
18745       }))
18746     F.getContext().diagnose(DiagnosticInfoUnsupported{
18747         F, "Argument register required, but has been reserved."});
18748 }
18749 
18750 // Check if the result of the node is only used as a return value, as
18751 // otherwise we can't perform a tail-call.
18752 bool RISCVTargetLowering::isUsedByReturnOnly(SDNode *N, SDValue &Chain) const {
18753   if (N->getNumValues() != 1)
18754     return false;
18755   if (!N->hasNUsesOfValue(1, 0))
18756     return false;
18757 
18758   SDNode *Copy = *N->use_begin();
18759 
18760   if (Copy->getOpcode() == ISD::BITCAST)
18761     return isUsedByReturnOnly(Copy, Chain);
18763 
18764   // TODO: Handle additional opcodes in order to support tail-calling libcalls
18765   // with soft float ABIs.
18766   if (Copy->getOpcode() != ISD::CopyToReg)
18767     return false;
18769 
18770   // If the ISD::CopyToReg has a glue operand, we conservatively assume it
18771   // isn't safe to perform a tail call.
18772   if (Copy->getOperand(Copy->getNumOperands() - 1).getValueType() == MVT::Glue)
18773     return false;
18774 
18775   // The copy must be used by a RISCVISD::RET_GLUE, and nothing else.
18776   bool HasRet = false;
18777   for (SDNode *Node : Copy->uses()) {
18778     if (Node->getOpcode() != RISCVISD::RET_GLUE)
18779       return false;
18780     HasRet = true;
18781   }
18782   if (!HasRet)
18783     return false;
18784 
18785   Chain = Copy->getOperand(0);
18786   return true;
18787 }
18788 
18789 bool RISCVTargetLowering::mayBeEmittedAsTailCall(const CallInst *CI) const {
18790   return CI->isTailCall();
18791 }
18792 
18793 const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
18794 #define NODE_NAME_CASE(NODE)                                                   \
18795   case RISCVISD::NODE:                                                         \
18796     return "RISCVISD::" #NODE;
18797   // clang-format off
18798   switch ((RISCVISD::NodeType)Opcode) {
18799   case RISCVISD::FIRST_NUMBER:
18800     break;
18801   NODE_NAME_CASE(RET_GLUE)
18802   NODE_NAME_CASE(SRET_GLUE)
18803   NODE_NAME_CASE(MRET_GLUE)
18804   NODE_NAME_CASE(CALL)
18805   NODE_NAME_CASE(SELECT_CC)
18806   NODE_NAME_CASE(BR_CC)
18807   NODE_NAME_CASE(BuildPairF64)
18808   NODE_NAME_CASE(SplitF64)
18809   NODE_NAME_CASE(TAIL)
18810   NODE_NAME_CASE(ADD_LO)
18811   NODE_NAME_CASE(HI)
18812   NODE_NAME_CASE(LLA)
18813   NODE_NAME_CASE(ADD_TPREL)
18814   NODE_NAME_CASE(MULHSU)
18815   NODE_NAME_CASE(SLLW)
18816   NODE_NAME_CASE(SRAW)
18817   NODE_NAME_CASE(SRLW)
18818   NODE_NAME_CASE(DIVW)
18819   NODE_NAME_CASE(DIVUW)
18820   NODE_NAME_CASE(REMUW)
18821   NODE_NAME_CASE(ROLW)
18822   NODE_NAME_CASE(RORW)
18823   NODE_NAME_CASE(CLZW)
18824   NODE_NAME_CASE(CTZW)
18825   NODE_NAME_CASE(ABSW)
18826   NODE_NAME_CASE(FMV_H_X)
18827   NODE_NAME_CASE(FMV_X_ANYEXTH)
18828   NODE_NAME_CASE(FMV_X_SIGNEXTH)
18829   NODE_NAME_CASE(FMV_W_X_RV64)
18830   NODE_NAME_CASE(FMV_X_ANYEXTW_RV64)
18831   NODE_NAME_CASE(FCVT_X)
18832   NODE_NAME_CASE(FCVT_XU)
18833   NODE_NAME_CASE(FCVT_W_RV64)
18834   NODE_NAME_CASE(FCVT_WU_RV64)
18835   NODE_NAME_CASE(STRICT_FCVT_W_RV64)
18836   NODE_NAME_CASE(STRICT_FCVT_WU_RV64)
18837   NODE_NAME_CASE(FP_ROUND_BF16)
18838   NODE_NAME_CASE(FP_EXTEND_BF16)
18839   NODE_NAME_CASE(FROUND)
18840   NODE_NAME_CASE(FCLASS)
18841   NODE_NAME_CASE(FMAX)
18842   NODE_NAME_CASE(FMIN)
18843   NODE_NAME_CASE(READ_CYCLE_WIDE)
18844   NODE_NAME_CASE(BREV8)
18845   NODE_NAME_CASE(ORC_B)
18846   NODE_NAME_CASE(ZIP)
18847   NODE_NAME_CASE(UNZIP)
18848   NODE_NAME_CASE(CLMUL)
18849   NODE_NAME_CASE(CLMULH)
18850   NODE_NAME_CASE(CLMULR)
18851   NODE_NAME_CASE(SHA256SIG0)
18852   NODE_NAME_CASE(SHA256SIG1)
18853   NODE_NAME_CASE(SHA256SUM0)
18854   NODE_NAME_CASE(SHA256SUM1)
18855   NODE_NAME_CASE(SM4KS)
18856   NODE_NAME_CASE(SM4ED)
18857   NODE_NAME_CASE(SM3P0)
18858   NODE_NAME_CASE(SM3P1)
18859   NODE_NAME_CASE(TH_LWD)
18860   NODE_NAME_CASE(TH_LWUD)
18861   NODE_NAME_CASE(TH_LDD)
18862   NODE_NAME_CASE(TH_SWD)
18863   NODE_NAME_CASE(TH_SDD)
18864   NODE_NAME_CASE(VMV_V_V_VL)
18865   NODE_NAME_CASE(VMV_V_X_VL)
18866   NODE_NAME_CASE(VFMV_V_F_VL)
18867   NODE_NAME_CASE(VMV_X_S)
18868   NODE_NAME_CASE(VMV_S_X_VL)
18869   NODE_NAME_CASE(VFMV_S_F_VL)
18870   NODE_NAME_CASE(SPLAT_VECTOR_SPLIT_I64_VL)
18871   NODE_NAME_CASE(READ_VLENB)
18872   NODE_NAME_CASE(TRUNCATE_VECTOR_VL)
18873   NODE_NAME_CASE(VSLIDEUP_VL)
18874   NODE_NAME_CASE(VSLIDE1UP_VL)
18875   NODE_NAME_CASE(VSLIDEDOWN_VL)
18876   NODE_NAME_CASE(VSLIDE1DOWN_VL)
18877   NODE_NAME_CASE(VFSLIDE1UP_VL)
18878   NODE_NAME_CASE(VFSLIDE1DOWN_VL)
18879   NODE_NAME_CASE(VID_VL)
18880   NODE_NAME_CASE(VFNCVT_ROD_VL)
18881   NODE_NAME_CASE(VECREDUCE_ADD_VL)
18882   NODE_NAME_CASE(VECREDUCE_UMAX_VL)
18883   NODE_NAME_CASE(VECREDUCE_SMAX_VL)
18884   NODE_NAME_CASE(VECREDUCE_UMIN_VL)
18885   NODE_NAME_CASE(VECREDUCE_SMIN_VL)
18886   NODE_NAME_CASE(VECREDUCE_AND_VL)
18887   NODE_NAME_CASE(VECREDUCE_OR_VL)
18888   NODE_NAME_CASE(VECREDUCE_XOR_VL)
18889   NODE_NAME_CASE(VECREDUCE_FADD_VL)
18890   NODE_NAME_CASE(VECREDUCE_SEQ_FADD_VL)
18891   NODE_NAME_CASE(VECREDUCE_FMIN_VL)
18892   NODE_NAME_CASE(VECREDUCE_FMAX_VL)
18893   NODE_NAME_CASE(ADD_VL)
18894   NODE_NAME_CASE(AND_VL)
18895   NODE_NAME_CASE(MUL_VL)
18896   NODE_NAME_CASE(OR_VL)
18897   NODE_NAME_CASE(SDIV_VL)
18898   NODE_NAME_CASE(SHL_VL)
18899   NODE_NAME_CASE(SREM_VL)
18900   NODE_NAME_CASE(SRA_VL)
18901   NODE_NAME_CASE(SRL_VL)
18902   NODE_NAME_CASE(ROTL_VL)
18903   NODE_NAME_CASE(ROTR_VL)
18904   NODE_NAME_CASE(SUB_VL)
18905   NODE_NAME_CASE(UDIV_VL)
18906   NODE_NAME_CASE(UREM_VL)
18907   NODE_NAME_CASE(XOR_VL)
18908   NODE_NAME_CASE(AVGFLOORU_VL)
18909   NODE_NAME_CASE(AVGCEILU_VL)
18910   NODE_NAME_CASE(SADDSAT_VL)
18911   NODE_NAME_CASE(UADDSAT_VL)
18912   NODE_NAME_CASE(SSUBSAT_VL)
18913   NODE_NAME_CASE(USUBSAT_VL)
18914   NODE_NAME_CASE(FADD_VL)
18915   NODE_NAME_CASE(FSUB_VL)
18916   NODE_NAME_CASE(FMUL_VL)
18917   NODE_NAME_CASE(FDIV_VL)
18918   NODE_NAME_CASE(FNEG_VL)
18919   NODE_NAME_CASE(FABS_VL)
18920   NODE_NAME_CASE(FSQRT_VL)
18921   NODE_NAME_CASE(FCLASS_VL)
18922   NODE_NAME_CASE(VFMADD_VL)
18923   NODE_NAME_CASE(VFNMADD_VL)
18924   NODE_NAME_CASE(VFMSUB_VL)
18925   NODE_NAME_CASE(VFNMSUB_VL)
18926   NODE_NAME_CASE(VFWMADD_VL)
18927   NODE_NAME_CASE(VFWNMADD_VL)
18928   NODE_NAME_CASE(VFWMSUB_VL)
18929   NODE_NAME_CASE(VFWNMSUB_VL)
18930   NODE_NAME_CASE(FCOPYSIGN_VL)
18931   NODE_NAME_CASE(SMIN_VL)
18932   NODE_NAME_CASE(SMAX_VL)
18933   NODE_NAME_CASE(UMIN_VL)
18934   NODE_NAME_CASE(UMAX_VL)
18935   NODE_NAME_CASE(BITREVERSE_VL)
18936   NODE_NAME_CASE(BSWAP_VL)
18937   NODE_NAME_CASE(CTLZ_VL)
18938   NODE_NAME_CASE(CTTZ_VL)
18939   NODE_NAME_CASE(CTPOP_VL)
18940   NODE_NAME_CASE(VFMIN_VL)
18941   NODE_NAME_CASE(VFMAX_VL)
18942   NODE_NAME_CASE(MULHS_VL)
18943   NODE_NAME_CASE(MULHU_VL)
18944   NODE_NAME_CASE(VFCVT_RTZ_X_F_VL)
18945   NODE_NAME_CASE(VFCVT_RTZ_XU_F_VL)
18946   NODE_NAME_CASE(VFCVT_RM_X_F_VL)
18947   NODE_NAME_CASE(VFCVT_RM_XU_F_VL)
18948   NODE_NAME_CASE(VFCVT_X_F_VL)
18949   NODE_NAME_CASE(VFCVT_XU_F_VL)
18950   NODE_NAME_CASE(VFROUND_NOEXCEPT_VL)
18951   NODE_NAME_CASE(SINT_TO_FP_VL)
18952   NODE_NAME_CASE(UINT_TO_FP_VL)
18953   NODE_NAME_CASE(VFCVT_RM_F_XU_VL)
18954   NODE_NAME_CASE(VFCVT_RM_F_X_VL)
18955   NODE_NAME_CASE(FP_EXTEND_VL)
18956   NODE_NAME_CASE(FP_ROUND_VL)
18957   NODE_NAME_CASE(STRICT_FADD_VL)
18958   NODE_NAME_CASE(STRICT_FSUB_VL)
18959   NODE_NAME_CASE(STRICT_FMUL_VL)
18960   NODE_NAME_CASE(STRICT_FDIV_VL)
18961   NODE_NAME_CASE(STRICT_FSQRT_VL)
18962   NODE_NAME_CASE(STRICT_VFMADD_VL)
18963   NODE_NAME_CASE(STRICT_VFNMADD_VL)
18964   NODE_NAME_CASE(STRICT_VFMSUB_VL)
18965   NODE_NAME_CASE(STRICT_VFNMSUB_VL)
18966   NODE_NAME_CASE(STRICT_FP_ROUND_VL)
18967   NODE_NAME_CASE(STRICT_FP_EXTEND_VL)
18968   NODE_NAME_CASE(STRICT_VFNCVT_ROD_VL)
18969   NODE_NAME_CASE(STRICT_SINT_TO_FP_VL)
18970   NODE_NAME_CASE(STRICT_UINT_TO_FP_VL)
18971   NODE_NAME_CASE(STRICT_VFCVT_RM_X_F_VL)
18972   NODE_NAME_CASE(STRICT_VFCVT_RTZ_X_F_VL)
18973   NODE_NAME_CASE(STRICT_VFCVT_RTZ_XU_F_VL)
18974   NODE_NAME_CASE(STRICT_FSETCC_VL)
18975   NODE_NAME_CASE(STRICT_FSETCCS_VL)
18976   NODE_NAME_CASE(STRICT_VFROUND_NOEXCEPT_VL)
18977   NODE_NAME_CASE(VWMUL_VL)
18978   NODE_NAME_CASE(VWMULU_VL)
18979   NODE_NAME_CASE(VWMULSU_VL)
18980   NODE_NAME_CASE(VWADD_VL)
18981   NODE_NAME_CASE(VWADDU_VL)
18982   NODE_NAME_CASE(VWSUB_VL)
18983   NODE_NAME_CASE(VWSUBU_VL)
18984   NODE_NAME_CASE(VWADD_W_VL)
18985   NODE_NAME_CASE(VWADDU_W_VL)
18986   NODE_NAME_CASE(VWSUB_W_VL)
18987   NODE_NAME_CASE(VWSUBU_W_VL)
18988   NODE_NAME_CASE(VWSLL_VL)
18989   NODE_NAME_CASE(VFWMUL_VL)
18990   NODE_NAME_CASE(VFWADD_VL)
18991   NODE_NAME_CASE(VFWSUB_VL)
18992   NODE_NAME_CASE(VFWADD_W_VL)
18993   NODE_NAME_CASE(VFWSUB_W_VL)
18994   NODE_NAME_CASE(VWMACC_VL)
18995   NODE_NAME_CASE(VWMACCU_VL)
18996   NODE_NAME_CASE(VWMACCSU_VL)
18997   NODE_NAME_CASE(VNSRL_VL)
18998   NODE_NAME_CASE(SETCC_VL)
18999   NODE_NAME_CASE(VMERGE_VL)
19000   NODE_NAME_CASE(VMAND_VL)
19001   NODE_NAME_CASE(VMOR_VL)
19002   NODE_NAME_CASE(VMXOR_VL)
19003   NODE_NAME_CASE(VMCLR_VL)
19004   NODE_NAME_CASE(VMSET_VL)
19005   NODE_NAME_CASE(VRGATHER_VX_VL)
19006   NODE_NAME_CASE(VRGATHER_VV_VL)
19007   NODE_NAME_CASE(VRGATHEREI16_VV_VL)
19008   NODE_NAME_CASE(VSEXT_VL)
19009   NODE_NAME_CASE(VZEXT_VL)
19010   NODE_NAME_CASE(VCPOP_VL)
19011   NODE_NAME_CASE(VFIRST_VL)
19012   NODE_NAME_CASE(READ_CSR)
19013   NODE_NAME_CASE(WRITE_CSR)
19014   NODE_NAME_CASE(SWAP_CSR)
19015   NODE_NAME_CASE(CZERO_EQZ)
19016   NODE_NAME_CASE(CZERO_NEZ)
19017   }
19018   // clang-format on
19019   return nullptr;
19020 #undef NODE_NAME_CASE
19021 }
19022 
19023 /// getConstraintType - Given a constraint letter, return the type of
19024 /// constraint it is for this target.
19025 RISCVTargetLowering::ConstraintType
19026 RISCVTargetLowering::getConstraintType(StringRef Constraint) const {
19027   if (Constraint.size() == 1) {
19028     switch (Constraint[0]) {
19029     default:
19030       break;
19031     case 'f':
19032       return C_RegisterClass;
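          // 'I' is a 12-bit signed immediate, 'J' is the integer zero, and 'K'
          // is a 5-bit unsigned immediate (see LowerAsmOperandForConstraint).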
19033     case 'I':
19034     case 'J':
19035     case 'K':
19036       return C_Immediate;
19037     case 'A':
19038       return C_Memory;
19039     case 'S': // A symbolic address
19040       return C_Other;
19041     }
19042   } else {
19043     if (Constraint == "vr" || Constraint == "vm")
19044       return C_RegisterClass;
19045   }
19046   return TargetLowering::getConstraintType(Constraint);
19047 }
19048 
19049 std::pair<unsigned, const TargetRegisterClass *>
19050 RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
19051                                                   StringRef Constraint,
19052                                                   MVT VT) const {
19053   // First, see if this is a constraint that directly corresponds to a RISC-V
19054   // register class.
19055   if (Constraint.size() == 1) {
19056     switch (Constraint[0]) {
19057     case 'r':
19058       // TODO: Support fixed vectors up to XLen for P extension?
19059       if (VT.isVector())
19060         break;
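            // With the Zfinx family of extensions, FP values live in GPRs (a
            // GPR pair for f64 on RV32), so 'r' can also satisfy FP operands.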
19061       if (VT == MVT::f16 && Subtarget.hasStdExtZhinxmin())
19062         return std::make_pair(0U, &RISCV::GPRF16RegClass);
19063       if (VT == MVT::f32 && Subtarget.hasStdExtZfinx())
19064         return std::make_pair(0U, &RISCV::GPRF32RegClass);
19065       if (VT == MVT::f64 && Subtarget.hasStdExtZdinx() && !Subtarget.is64Bit())
19066         return std::make_pair(0U, &RISCV::GPRPairRegClass);
19067       return std::make_pair(0U, &RISCV::GPRNoX0RegClass);
19068     case 'f':
19069       if (Subtarget.hasStdExtZfhmin() && VT == MVT::f16)
19070         return std::make_pair(0U, &RISCV::FPR16RegClass);
19071       if (Subtarget.hasStdExtF() && VT == MVT::f32)
19072         return std::make_pair(0U, &RISCV::FPR32RegClass);
19073       if (Subtarget.hasStdExtD() && VT == MVT::f64)
19074         return std::make_pair(0U, &RISCV::FPR64RegClass);
19075       break;
19076     default:
19077       break;
19078     }
19079   } else if (Constraint == "vr") {
19080     for (const auto *RC : {&RISCV::VRRegClass, &RISCV::VRM2RegClass,
19081                            &RISCV::VRM4RegClass, &RISCV::VRM8RegClass}) {
19082       if (TRI->isTypeLegalForClass(*RC, VT.SimpleTy))
19083         return std::make_pair(0U, RC);
19084     }
19085   } else if (Constraint == "vm") {
19086     if (TRI->isTypeLegalForClass(RISCV::VMV0RegClass, VT.SimpleTy))
19087       return std::make_pair(0U, &RISCV::VMV0RegClass);
19088   }
19089 
19090   // Clang will correctly decode the usage of register name aliases into their
19091   // official names. However, other frontends like `rustc` do not. This allows
19092   // users of these frontends to use the ABI names for registers in LLVM-style
19093   // register constraints.
19094   unsigned XRegFromAlias = StringSwitch<unsigned>(Constraint.lower())
19095                                .Case("{zero}", RISCV::X0)
19096                                .Case("{ra}", RISCV::X1)
19097                                .Case("{sp}", RISCV::X2)
19098                                .Case("{gp}", RISCV::X3)
19099                                .Case("{tp}", RISCV::X4)
19100                                .Case("{t0}", RISCV::X5)
19101                                .Case("{t1}", RISCV::X6)
19102                                .Case("{t2}", RISCV::X7)
19103                                .Cases("{s0}", "{fp}", RISCV::X8)
19104                                .Case("{s1}", RISCV::X9)
19105                                .Case("{a0}", RISCV::X10)
19106                                .Case("{a1}", RISCV::X11)
19107                                .Case("{a2}", RISCV::X12)
19108                                .Case("{a3}", RISCV::X13)
19109                                .Case("{a4}", RISCV::X14)
19110                                .Case("{a5}", RISCV::X15)
19111                                .Case("{a6}", RISCV::X16)
19112                                .Case("{a7}", RISCV::X17)
19113                                .Case("{s2}", RISCV::X18)
19114                                .Case("{s3}", RISCV::X19)
19115                                .Case("{s4}", RISCV::X20)
19116                                .Case("{s5}", RISCV::X21)
19117                                .Case("{s6}", RISCV::X22)
19118                                .Case("{s7}", RISCV::X23)
19119                                .Case("{s8}", RISCV::X24)
19120                                .Case("{s9}", RISCV::X25)
19121                                .Case("{s10}", RISCV::X26)
19122                                .Case("{s11}", RISCV::X27)
19123                                .Case("{t3}", RISCV::X28)
19124                                .Case("{t4}", RISCV::X29)
19125                                .Case("{t5}", RISCV::X30)
19126                                .Case("{t6}", RISCV::X31)
19127                                .Default(RISCV::NoRegister);
19128   if (XRegFromAlias != RISCV::NoRegister)
19129     return std::make_pair(XRegFromAlias, &RISCV::GPRRegClass);
19130 
19131   // Since TargetLowering::getRegForInlineAsmConstraint uses the name of the
19132   // TableGen record rather than the AsmName to choose registers for InlineAsm
19133   // constraints, and we want to match those names to the widest floating point
19134   // register type available, manually select floating point registers here.
19135   //
19136   // The second case is the ABI name of the register, so that frontends can also
19137   // use the ABI names in register constraint lists.
19138   if (Subtarget.hasStdExtF()) {
19139     unsigned FReg = StringSwitch<unsigned>(Constraint.lower())
19140                         .Cases("{f0}", "{ft0}", RISCV::F0_F)
19141                         .Cases("{f1}", "{ft1}", RISCV::F1_F)
19142                         .Cases("{f2}", "{ft2}", RISCV::F2_F)
19143                         .Cases("{f3}", "{ft3}", RISCV::F3_F)
19144                         .Cases("{f4}", "{ft4}", RISCV::F4_F)
19145                         .Cases("{f5}", "{ft5}", RISCV::F5_F)
19146                         .Cases("{f6}", "{ft6}", RISCV::F6_F)
19147                         .Cases("{f7}", "{ft7}", RISCV::F7_F)
19148                         .Cases("{f8}", "{fs0}", RISCV::F8_F)
19149                         .Cases("{f9}", "{fs1}", RISCV::F9_F)
19150                         .Cases("{f10}", "{fa0}", RISCV::F10_F)
19151                         .Cases("{f11}", "{fa1}", RISCV::F11_F)
19152                         .Cases("{f12}", "{fa2}", RISCV::F12_F)
19153                         .Cases("{f13}", "{fa3}", RISCV::F13_F)
19154                         .Cases("{f14}", "{fa4}", RISCV::F14_F)
19155                         .Cases("{f15}", "{fa5}", RISCV::F15_F)
19156                         .Cases("{f16}", "{fa6}", RISCV::F16_F)
19157                         .Cases("{f17}", "{fa7}", RISCV::F17_F)
19158                         .Cases("{f18}", "{fs2}", RISCV::F18_F)
19159                         .Cases("{f19}", "{fs3}", RISCV::F19_F)
19160                         .Cases("{f20}", "{fs4}", RISCV::F20_F)
19161                         .Cases("{f21}", "{fs5}", RISCV::F21_F)
19162                         .Cases("{f22}", "{fs6}", RISCV::F22_F)
19163                         .Cases("{f23}", "{fs7}", RISCV::F23_F)
19164                         .Cases("{f24}", "{fs8}", RISCV::F24_F)
19165                         .Cases("{f25}", "{fs9}", RISCV::F25_F)
19166                         .Cases("{f26}", "{fs10}", RISCV::F26_F)
19167                         .Cases("{f27}", "{fs11}", RISCV::F27_F)
19168                         .Cases("{f28}", "{ft8}", RISCV::F28_F)
19169                         .Cases("{f29}", "{ft9}", RISCV::F29_F)
19170                         .Cases("{f30}", "{ft10}", RISCV::F30_F)
19171                         .Cases("{f31}", "{ft11}", RISCV::F31_F)
19172                         .Default(RISCV::NoRegister);
19173     if (FReg != RISCV::NoRegister) {
19174       assert(RISCV::F0_F <= FReg && FReg <= RISCV::F31_F && "Unknown fp-reg");
19175       if (Subtarget.hasStdExtD() && (VT == MVT::f64 || VT == MVT::Other)) {
19176         unsigned RegNo = FReg - RISCV::F0_F;
19177         unsigned DReg = RISCV::F0_D + RegNo;
19178         return std::make_pair(DReg, &RISCV::FPR64RegClass);
19179       }
19180       if (VT == MVT::f32 || VT == MVT::Other)
19181         return std::make_pair(FReg, &RISCV::FPR32RegClass);
19182       if (Subtarget.hasStdExtZfhmin() && VT == MVT::f16) {
19183         unsigned RegNo = FReg - RISCV::F0_F;
19184         unsigned HReg = RISCV::F0_H + RegNo;
19185         return std::make_pair(HReg, &RISCV::FPR16RegClass);
19186       }
19187     }
19188   }
19189 
19190   if (Subtarget.hasVInstructions()) {
19191     Register VReg = StringSwitch<Register>(Constraint.lower())
19192                         .Case("{v0}", RISCV::V0)
19193                         .Case("{v1}", RISCV::V1)
19194                         .Case("{v2}", RISCV::V2)
19195                         .Case("{v3}", RISCV::V3)
19196                         .Case("{v4}", RISCV::V4)
19197                         .Case("{v5}", RISCV::V5)
19198                         .Case("{v6}", RISCV::V6)
19199                         .Case("{v7}", RISCV::V7)
19200                         .Case("{v8}", RISCV::V8)
19201                         .Case("{v9}", RISCV::V9)
19202                         .Case("{v10}", RISCV::V10)
19203                         .Case("{v11}", RISCV::V11)
19204                         .Case("{v12}", RISCV::V12)
19205                         .Case("{v13}", RISCV::V13)
19206                         .Case("{v14}", RISCV::V14)
19207                         .Case("{v15}", RISCV::V15)
19208                         .Case("{v16}", RISCV::V16)
19209                         .Case("{v17}", RISCV::V17)
19210                         .Case("{v18}", RISCV::V18)
19211                         .Case("{v19}", RISCV::V19)
19212                         .Case("{v20}", RISCV::V20)
19213                         .Case("{v21}", RISCV::V21)
19214                         .Case("{v22}", RISCV::V22)
19215                         .Case("{v23}", RISCV::V23)
19216                         .Case("{v24}", RISCV::V24)
19217                         .Case("{v25}", RISCV::V25)
19218                         .Case("{v26}", RISCV::V26)
19219                         .Case("{v27}", RISCV::V27)
19220                         .Case("{v28}", RISCV::V28)
19221                         .Case("{v29}", RISCV::V29)
19222                         .Case("{v30}", RISCV::V30)
19223                         .Case("{v31}", RISCV::V31)
19224                         .Default(RISCV::NoRegister);
19225     if (VReg != RISCV::NoRegister) {
19226       if (TRI->isTypeLegalForClass(RISCV::VMRegClass, VT.SimpleTy))
19227         return std::make_pair(VReg, &RISCV::VMRegClass);
19228       if (TRI->isTypeLegalForClass(RISCV::VRRegClass, VT.SimpleTy))
19229         return std::make_pair(VReg, &RISCV::VRRegClass);
19230       for (const auto *RC :
19231            {&RISCV::VRM2RegClass, &RISCV::VRM4RegClass, &RISCV::VRM8RegClass}) {
19232         if (TRI->isTypeLegalForClass(*RC, VT.SimpleTy)) {
19233           VReg = TRI->getMatchingSuperReg(VReg, RISCV::sub_vrm1_0, RC);
19234           return std::make_pair(VReg, RC);
19235         }
19236       }
19237     }
19238   }
19239 
19240   std::pair<Register, const TargetRegisterClass *> Res =
19241       TargetLowering::getRegForInlineAsmConstraint(TRI, Constraint, VT);
19242 
19243   // If we picked one of the Zfinx register classes, remap it to the GPR class.
19244   // FIXME: When Zfinx is supported in CodeGen this will need to take the
19245   // Subtarget into account.
19246   if (Res.second == &RISCV::GPRF16RegClass ||
19247       Res.second == &RISCV::GPRF32RegClass ||
19248       Res.second == &RISCV::GPRPairRegClass)
19249     return std::make_pair(Res.first, &RISCV::GPRRegClass);
19250 
19251   return Res;
19252 }
19253 
19254 InlineAsm::ConstraintCode
19255 RISCVTargetLowering::getInlineAsmMemConstraint(StringRef ConstraintCode) const {
19256   // Currently only support length 1 constraints.
19257   if (ConstraintCode.size() == 1) {
19258     switch (ConstraintCode[0]) {
19259     case 'A':
19260       return InlineAsm::ConstraintCode::A;
19261     default:
19262       break;
19263     }
19264   }
19265 
19266   return TargetLowering::getInlineAsmMemConstraint(ConstraintCode);
19267 }
19268 
19269 void RISCVTargetLowering::LowerAsmOperandForConstraint(
19270     SDValue Op, StringRef Constraint, std::vector<SDValue> &Ops,
19271     SelectionDAG &DAG) const {
19272   // Currently only support length 1 constraints.
19273   if (Constraint.size() == 1) {
19274     switch (Constraint[0]) {
19275     case 'I':
19276       // Validate & create a 12-bit signed immediate operand.
19277       if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
19278         uint64_t CVal = C->getSExtValue();
19279         if (isInt<12>(CVal))
19280           Ops.push_back(
19281               DAG.getTargetConstant(CVal, SDLoc(Op), Subtarget.getXLenVT()));
19282       }
19283       return;
19284     case 'J':
19285       // Validate & create an integer zero operand.
19286       if (isNullConstant(Op))
19287         Ops.push_back(
19288             DAG.getTargetConstant(0, SDLoc(Op), Subtarget.getXLenVT()));
19289       return;
19290     case 'K':
19291       // Validate & create a 5-bit unsigned immediate operand.
19292       if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
19293         uint64_t CVal = C->getZExtValue();
19294         if (isUInt<5>(CVal))
19295           Ops.push_back(
19296               DAG.getTargetConstant(CVal, SDLoc(Op), Subtarget.getXLenVT()));
19297       }
19298       return;
19299     case 'S':
19300       if (const auto *GA = dyn_cast<GlobalAddressSDNode>(Op)) {
19301         Ops.push_back(DAG.getTargetGlobalAddress(GA->getGlobal(), SDLoc(Op),
19302                                                  GA->getValueType(0)));
19303       } else if (const auto *BA = dyn_cast<BlockAddressSDNode>(Op)) {
19304         Ops.push_back(DAG.getTargetBlockAddress(BA->getBlockAddress(),
19305                                                 BA->getValueType(0)));
19306       }
19307       return;
19308     default:
19309       break;
19310     }
19311   }
19312   TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
19313 }
19314 
19315 Instruction *RISCVTargetLowering::emitLeadingFence(IRBuilderBase &Builder,
19316                                                    Instruction *Inst,
19317                                                    AtomicOrdering Ord) const {
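        // Ztso provides total store ordering in hardware, so only a seq_cst
        // load still needs a leading fence.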
19318   if (Subtarget.hasStdExtZtso()) {
19319     if (isa<LoadInst>(Inst) && Ord == AtomicOrdering::SequentiallyConsistent)
19320       return Builder.CreateFence(Ord);
19321     return nullptr;
19322   }
19323 
19324   if (isa<LoadInst>(Inst) && Ord == AtomicOrdering::SequentiallyConsistent)
19325     return Builder.CreateFence(Ord);
19326   if (isa<StoreInst>(Inst) && isReleaseOrStronger(Ord))
19327     return Builder.CreateFence(AtomicOrdering::Release);
19328   return nullptr;
19329 }
19330 
19331 Instruction *RISCVTargetLowering::emitTrailingFence(IRBuilderBase &Builder,
19332                                                     Instruction *Inst,
19333                                                     AtomicOrdering Ord) const {
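        // Under Ztso, only a seq_cst store still needs a trailing fence.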
19334   if (Subtarget.hasStdExtZtso()) {
19335     if (isa<StoreInst>(Inst) && Ord == AtomicOrdering::SequentiallyConsistent)
19336       return Builder.CreateFence(Ord);
19337     return nullptr;
19338   }
19339 
19340   if (isa<LoadInst>(Inst) && isAcquireOrStronger(Ord))
19341     return Builder.CreateFence(AtomicOrdering::Acquire);
19342   if (Subtarget.enableSeqCstTrailingFence() && isa<StoreInst>(Inst) &&
19343       Ord == AtomicOrdering::SequentiallyConsistent)
19344     return Builder.CreateFence(AtomicOrdering::SequentiallyConsistent);
19345   return nullptr;
19346 }
19347 
19348 TargetLowering::AtomicExpansionKind
19349 RISCVTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
19350   // atomicrmw {fadd,fsub} must be expanded to use compare-exchange, as floating
19351   // point operations can't be used in an lr/sc sequence without breaking the
19352   // forward-progress guarantee.
19353   if (AI->isFloatingPointOperation() ||
19354       AI->getOperation() == AtomicRMWInst::UIncWrap ||
19355       AI->getOperation() == AtomicRMWInst::UDecWrap)
19356     return AtomicExpansionKind::CmpXChg;
19357 
19358   // Don't expand forced atomics, we want to have __sync libcalls instead.
19359   if (Subtarget.hasForcedAtomics())
19360     return AtomicExpansionKind::None;
19361 
19362   unsigned Size = AI->getType()->getPrimitiveSizeInBits();
19363   if (Size == 8 || Size == 16)
19364     return AtomicExpansionKind::MaskedIntrinsic;
19365   return AtomicExpansionKind::None;
19366 }
19367 
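      // Map an atomicrmw binary operation to the corresponding
      // riscv_masked_atomicrmw_* intrinsic for the given XLEN.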
19368 static Intrinsic::ID
19369 getIntrinsicForMaskedAtomicRMWBinOp(unsigned XLen, AtomicRMWInst::BinOp BinOp) {
19370   if (XLen == 32) {
19371     switch (BinOp) {
19372     default:
19373       llvm_unreachable("Unexpected AtomicRMW BinOp");
19374     case AtomicRMWInst::Xchg:
19375       return Intrinsic::riscv_masked_atomicrmw_xchg_i32;
19376     case AtomicRMWInst::Add:
19377       return Intrinsic::riscv_masked_atomicrmw_add_i32;
19378     case AtomicRMWInst::Sub:
19379       return Intrinsic::riscv_masked_atomicrmw_sub_i32;
19380     case AtomicRMWInst::Nand:
19381       return Intrinsic::riscv_masked_atomicrmw_nand_i32;
19382     case AtomicRMWInst::Max:
19383       return Intrinsic::riscv_masked_atomicrmw_max_i32;
19384     case AtomicRMWInst::Min:
19385       return Intrinsic::riscv_masked_atomicrmw_min_i32;
19386     case AtomicRMWInst::UMax:
19387       return Intrinsic::riscv_masked_atomicrmw_umax_i32;
19388     case AtomicRMWInst::UMin:
19389       return Intrinsic::riscv_masked_atomicrmw_umin_i32;
19390     }
19391   }
19392 
19393   if (XLen == 64) {
19394     switch (BinOp) {
19395     default:
19396       llvm_unreachable("Unexpected AtomicRMW BinOp");
19397     case AtomicRMWInst::Xchg:
19398       return Intrinsic::riscv_masked_atomicrmw_xchg_i64;
19399     case AtomicRMWInst::Add:
19400       return Intrinsic::riscv_masked_atomicrmw_add_i64;
19401     case AtomicRMWInst::Sub:
19402       return Intrinsic::riscv_masked_atomicrmw_sub_i64;
19403     case AtomicRMWInst::Nand:
19404       return Intrinsic::riscv_masked_atomicrmw_nand_i64;
19405     case AtomicRMWInst::Max:
19406       return Intrinsic::riscv_masked_atomicrmw_max_i64;
19407     case AtomicRMWInst::Min:
19408       return Intrinsic::riscv_masked_atomicrmw_min_i64;
19409     case AtomicRMWInst::UMax:
19410       return Intrinsic::riscv_masked_atomicrmw_umax_i64;
19411     case AtomicRMWInst::UMin:
19412       return Intrinsic::riscv_masked_atomicrmw_umin_i64;
19413     }
19414   }
19415 
19416   llvm_unreachable("Unexpected XLen\n");
19417 }
19418 
19419 Value *RISCVTargetLowering::emitMaskedAtomicRMWIntrinsic(
19420     IRBuilderBase &Builder, AtomicRMWInst *AI, Value *AlignedAddr, Value *Incr,
19421     Value *Mask, Value *ShiftAmt, AtomicOrdering Ord) const {
19422   // In the case of an atomicrmw xchg with a constant 0/-1 operand, replace
19423   // the atomic instruction with an AtomicRMWInst::And/Or with appropriate
19424   // mask, as this produces better code than the LR/SC loop emitted by
19425   // int_riscv_masked_atomicrmw_xchg.
19426   if (AI->getOperation() == AtomicRMWInst::Xchg &&
19427       isa<ConstantInt>(AI->getValOperand())) {
19428     ConstantInt *CVal = cast<ConstantInt>(AI->getValOperand());
19429     if (CVal->isZero())
19430       return Builder.CreateAtomicRMW(AtomicRMWInst::And, AlignedAddr,
19431                                      Builder.CreateNot(Mask, "Inv_Mask"),
19432                                      AI->getAlign(), Ord);
19433     if (CVal->isMinusOne())
19434       return Builder.CreateAtomicRMW(AtomicRMWInst::Or, AlignedAddr, Mask,
19435                                      AI->getAlign(), Ord);
19436   }
19437 
19438   unsigned XLen = Subtarget.getXLen();
19439   Value *Ordering =
19440       Builder.getIntN(XLen, static_cast<uint64_t>(AI->getOrdering()));
19441   Type *Tys[] = {AlignedAddr->getType()};
19442   Function *LrwOpScwLoop = Intrinsic::getDeclaration(
19443       AI->getModule(),
19444       getIntrinsicForMaskedAtomicRMWBinOp(XLen, AI->getOperation()), Tys);
19445 
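        // The i64 variants of the masked intrinsics take XLen-wide operands,
        // so sign-extend the i32 inputs on RV64.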
19446   if (XLen == 64) {
19447     Incr = Builder.CreateSExt(Incr, Builder.getInt64Ty());
19448     Mask = Builder.CreateSExt(Mask, Builder.getInt64Ty());
19449     ShiftAmt = Builder.CreateSExt(ShiftAmt, Builder.getInt64Ty());
19450   }
19451 
19452   Value *Result;
19453 
19454   // Must pass the shift amount needed to sign extend the loaded value prior
19455   // to performing a signed comparison for min/max. ShiftAmt is the number of
19456   // bits to shift the value into position. Pass XLen-ShiftAmt-ValWidth, which
19457   // is the number of bits to left+right shift the value in order to
19458   // sign-extend.
19459   if (AI->getOperation() == AtomicRMWInst::Min ||
19460       AI->getOperation() == AtomicRMWInst::Max) {
19461     const DataLayout &DL = AI->getModule()->getDataLayout();
19462     unsigned ValWidth =
19463         DL.getTypeStoreSizeInBits(AI->getValOperand()->getType());
19464     Value *SextShamt =
19465         Builder.CreateSub(Builder.getIntN(XLen, XLen - ValWidth), ShiftAmt);
19466     Result = Builder.CreateCall(LrwOpScwLoop,
19467                                 {AlignedAddr, Incr, Mask, SextShamt, Ordering});
19468   } else {
19469     Result =
19470         Builder.CreateCall(LrwOpScwLoop, {AlignedAddr, Incr, Mask, Ordering});
19471   }
19472 
19473   if (XLen == 64)
19474     Result = Builder.CreateTrunc(Result, Builder.getInt32Ty());
19475   return Result;
19476 }
19477 
19478 TargetLowering::AtomicExpansionKind
19479 RISCVTargetLowering::shouldExpandAtomicCmpXchgInIR(
19480     AtomicCmpXchgInst *CI) const {
19481   // Don't expand forced atomics, we want to have __sync libcalls instead.
19482   if (Subtarget.hasForcedAtomics())
19483     return AtomicExpansionKind::None;
19484 
19485   unsigned Size = CI->getCompareOperand()->getType()->getPrimitiveSizeInBits();
19486   if (Size == 8 || Size == 16)
19487     return AtomicExpansionKind::MaskedIntrinsic;
19488   return AtomicExpansionKind::None;
19489 }
19490 
19491 Value *RISCVTargetLowering::emitMaskedAtomicCmpXchgIntrinsic(
19492     IRBuilderBase &Builder, AtomicCmpXchgInst *CI, Value *AlignedAddr,
19493     Value *CmpVal, Value *NewVal, Value *Mask, AtomicOrdering Ord) const {
19494   unsigned XLen = Subtarget.getXLen();
19495   Value *Ordering = Builder.getIntN(XLen, static_cast<uint64_t>(Ord));
19496   Intrinsic::ID CmpXchgIntrID = Intrinsic::riscv_masked_cmpxchg_i32;
19497   if (XLen == 64) {
19498     CmpVal = Builder.CreateSExt(CmpVal, Builder.getInt64Ty());
19499     NewVal = Builder.CreateSExt(NewVal, Builder.getInt64Ty());
19500     Mask = Builder.CreateSExt(Mask, Builder.getInt64Ty());
19501     CmpXchgIntrID = Intrinsic::riscv_masked_cmpxchg_i64;
19502   }
19503   Type *Tys[] = {AlignedAddr->getType()};
19504   Function *MaskedCmpXchg =
19505       Intrinsic::getDeclaration(CI->getModule(), CmpXchgIntrID, Tys);
19506   Value *Result = Builder.CreateCall(
19507       MaskedCmpXchg, {AlignedAddr, CmpVal, NewVal, Mask, Ordering});
19508   if (XLen == 64)
19509     Result = Builder.CreateTrunc(Result, Builder.getInt32Ty());
19510   return Result;
19511 }
19512 
19513 bool RISCVTargetLowering::shouldRemoveExtendFromGSIndex(SDValue Extend,
19514                                                         EVT DataVT) const {
19515   // We have indexed loads for all legal index types.  Indices are always
19516   // zero extended.
19517   return Extend.getOpcode() == ISD::ZERO_EXTEND &&
19518     isTypeLegal(Extend.getValueType()) &&
19519     isTypeLegal(Extend.getOperand(0).getValueType());
19520 }
19521 
19522 bool RISCVTargetLowering::shouldConvertFpToSat(unsigned Op, EVT FPVT,
19523                                                EVT VT) const {
19524   if (!isOperationLegalOrCustom(Op, VT) || !FPVT.isSimple())
19525     return false;
19526 
19527   switch (FPVT.getSimpleVT().SimpleTy) {
19528   case MVT::f16:
19529     return Subtarget.hasStdExtZfhmin();
19530   case MVT::f32:
19531     return Subtarget.hasStdExtF();
19532   case MVT::f64:
19533     return Subtarget.hasStdExtD();
19534   default:
19535     return false;
19536   }
19537 }
19538 
19539 unsigned RISCVTargetLowering::getJumpTableEncoding() const {
19540   // If we are using the small code model, we can reduce the size of each
19541   // jump table entry to 4 bytes.
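  // A 32-bit entry is sufficient here: with the small code model the code
  // (and thus the jump targets) lives within a 32-bit signed address range,
  // and we are not position independent.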
19542   if (Subtarget.is64Bit() && !isPositionIndependent() &&
19543       getTargetMachine().getCodeModel() == CodeModel::Small) {
19544     return MachineJumpTableInfo::EK_Custom32;
19545   }
19546   return TargetLowering::getJumpTableEncoding();
19547 }
19548 
19549 const MCExpr *RISCVTargetLowering::LowerCustomJumpTableEntry(
19550     const MachineJumpTableInfo *MJTI, const MachineBasicBlock *MBB,
19551     unsigned uid, MCContext &Ctx) const {
19552   assert(Subtarget.is64Bit() && !isPositionIndependent() &&
19553          getTargetMachine().getCodeModel() == CodeModel::Small);
19554   return MCSymbolRefExpr::create(MBB->getSymbol(), Ctx);
19555 }
19556 
19557 bool RISCVTargetLowering::isVScaleKnownToBeAPowerOfTwo() const {
19558   // We define vscale to be VLEN/RVVBitsPerBlock.  VLEN is always a power
19559   // of two >= 64, and RVVBitsPerBlock is 64.  Thus, vscale must be
19560   // a power of two as well.
19561   // FIXME: This doesn't work for zve32, but that's already broken
19562   // elsewhere for the same reason.
19563   assert(Subtarget.getRealMinVLen() >= 64 && "zve32* unsupported");
19564   static_assert(RISCV::RVVBitsPerBlock == 64,
19565                 "RVVBitsPerBlock changed, audit needed");
19566   return true;
19567 }
19568 
19569 bool RISCVTargetLowering::getIndexedAddressParts(SDNode *Op, SDValue &Base,
19570                                                  SDValue &Offset,
19571                                                  ISD::MemIndexedMode &AM,
19572                                                  SelectionDAG &DAG) const {
19573   // Target does not support indexed loads.
19574   if (!Subtarget.hasVendorXTHeadMemIdx())
19575     return false;
19576 
19577   if (Op->getOpcode() != ISD::ADD && Op->getOpcode() != ISD::SUB)
19578     return false;
19579 
19580   Base = Op->getOperand(0);
19581   if (ConstantSDNode *RHS = dyn_cast<ConstantSDNode>(Op->getOperand(1))) {
19582     int64_t RHSC = RHS->getSExtValue();
19583     if (Op->getOpcode() == ISD::SUB)
19584       RHSC = -(uint64_t)RHSC;
19585 
19586     // The constants that can be encoded in the THeadMemIdx instructions
19587     // are of the form (sign_extend(imm5) << imm2).
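    // For example, an offset of 48 (== 12 << 2) is encodable, whereas 33 is
    // not, since no shift amount of 0-3 leaves a simm5 with zero low bits.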
19588     bool isLegalIndexedOffset = false;
19589     for (unsigned i = 0; i < 4; i++)
19590       if (isInt<5>(RHSC >> i) && ((RHSC % (1LL << i)) == 0)) {
19591         isLegalIndexedOffset = true;
19592         break;
19593       }
19594 
19595     if (!isLegalIndexedOffset)
19596       return false;
19597 
19598     Offset = Op->getOperand(1);
19599     return true;
19600   }
19601 
19602   return false;
19603 }
19604 
19605 bool RISCVTargetLowering::getPreIndexedAddressParts(SDNode *N, SDValue &Base,
19606                                                     SDValue &Offset,
19607                                                     ISD::MemIndexedMode &AM,
19608                                                     SelectionDAG &DAG) const {
19609   EVT VT;
19610   SDValue Ptr;
19611   if (LoadSDNode *LD = dyn_cast<LoadSDNode>(N)) {
19612     VT = LD->getMemoryVT();
19613     Ptr = LD->getBasePtr();
19614   } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(N)) {
19615     VT = ST->getMemoryVT();
19616     Ptr = ST->getBasePtr();
19617   } else
19618     return false;
19619 
19620   if (!getIndexedAddressParts(Ptr.getNode(), Base, Offset, AM, DAG))
19621     return false;
19622 
19623   AM = ISD::PRE_INC;
19624   return true;
19625 }
19626 
19627 bool RISCVTargetLowering::getPostIndexedAddressParts(SDNode *N, SDNode *Op,
19628                                                      SDValue &Base,
19629                                                      SDValue &Offset,
19630                                                      ISD::MemIndexedMode &AM,
19631                                                      SelectionDAG &DAG) const {
19632   EVT VT;
19633   SDValue Ptr;
19634   if (LoadSDNode *LD = dyn_cast<LoadSDNode>(N)) {
19635     VT = LD->getMemoryVT();
19636     Ptr = LD->getBasePtr();
19637   } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(N)) {
19638     VT = ST->getMemoryVT();
19639     Ptr = ST->getBasePtr();
19640   } else
19641     return false;
19642 
19643   if (!getIndexedAddressParts(Op, Base, Offset, AM, DAG))
19644     return false;
19645   // Post-indexing updates the base, so it's not a valid transform
19646   // if that's not the same as the load's pointer.
19647   if (Ptr != Base)
19648     return false;
19649 
19650   AM = ISD::POST_INC;
19651   return true;
19652 }
19653 
19654 bool RISCVTargetLowering::isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
19655                                                      EVT VT) const {
19656   EVT SVT = VT.getScalarType();
19657 
19658   if (!SVT.isSimple())
19659     return false;
19660 
19661   switch (SVT.getSimpleVT().SimpleTy) {
19662   case MVT::f16:
19663     return VT.isVector() ? Subtarget.hasVInstructionsF16()
19664                          : Subtarget.hasStdExtZfhOrZhinx();
19665   case MVT::f32:
19666     return Subtarget.hasStdExtFOrZfinx();
19667   case MVT::f64:
19668     return Subtarget.hasStdExtDOrZdinx();
19669   default:
19670     break;
19671   }
19672 
19673   return false;
19674 }
19675 
19676 ISD::NodeType RISCVTargetLowering::getExtendForAtomicCmpSwapArg() const {
19677   // Zacas will use amocas.w which does not require extension.
19678   return Subtarget.hasStdExtZacas() ? ISD::ANY_EXTEND : ISD::SIGN_EXTEND;
19679 }
19680 
19681 Register RISCVTargetLowering::getExceptionPointerRegister(
19682     const Constant *PersonalityFn) const {
19683   return RISCV::X10;
19684 }
19685 
19686 Register RISCVTargetLowering::getExceptionSelectorRegister(
19687     const Constant *PersonalityFn) const {
19688   return RISCV::X11;
19689 }
19690 
19691 bool RISCVTargetLowering::shouldExtendTypeInLibCall(EVT Type) const {
19692   // Return false to suppress unnecessary extensions if a libcall argument
19693   // or return value is a float narrower than XLEN on a soft FP ABI.
19694   if (Subtarget.isSoftFPABI() && (Type.isFloatingPoint() && !Type.isVector() &&
19695                                   Type.getSizeInBits() < Subtarget.getXLen()))
19696     return false;
19697 
19698   return true;
19699 }
19700 
19701 bool RISCVTargetLowering::shouldSignExtendTypeInLibCall(EVT Type, bool IsSigned) const {
19702   if (Subtarget.is64Bit() && Type == MVT::i32)
19703     return true;
19704 
19705   return IsSigned;
19706 }
19707 
19708 bool RISCVTargetLowering::decomposeMulByConstant(LLVMContext &Context, EVT VT,
19709                                                  SDValue C) const {
19710   // Check integral scalar types.
19711   const bool HasExtMOrZmmul =
19712       Subtarget.hasStdExtM() || Subtarget.hasStdExtZmmul();
19713   if (!VT.isScalarInteger())
19714     return false;
19715 
19716   // Omit the optimization if the subtarget has the M (or Zmmul) extension
19717   // and the data size exceeds XLen.
19718   if (HasExtMOrZmmul && VT.getSizeInBits() > Subtarget.getXLen())
19719     return false;
19720 
19721   if (auto *ConstNode = dyn_cast<ConstantSDNode>(C.getNode())) {
19722     // Break the MUL to a SLLI and an ADD/SUB.
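    // For example, x * 9 == (x << 3) + x and x * 7 == (x << 3) - x.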
19723     const APInt &Imm = ConstNode->getAPIntValue();
19724     if ((Imm + 1).isPowerOf2() || (Imm - 1).isPowerOf2() ||
19725         (1 - Imm).isPowerOf2() || (-1 - Imm).isPowerOf2())
19726       return true;
19727 
19728     // Optimize the MUL to (SH*ADD x, (SLLI x, bits)) if Imm is not simm12.
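    // For example, x * 4100 == sh2add(x, slli(x, 12)), since 4100 - 4 == 4096.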
19729     if (Subtarget.hasStdExtZba() && !Imm.isSignedIntN(12) &&
19730         ((Imm - 2).isPowerOf2() || (Imm - 4).isPowerOf2() ||
19731          (Imm - 8).isPowerOf2()))
19732       return true;
19733 
19734     // Break the MUL to two SLLI instructions and an ADD/SUB, if Imm needs
19735     // a pair of LUI/ADDI.
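    // For example, x * 6144 == (x << 12) + (x << 11).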
19736     if (!Imm.isSignedIntN(12) && Imm.countr_zero() < 12 &&
19737         ConstNode->hasOneUse()) {
19738       APInt ImmS = Imm.ashr(Imm.countr_zero());
19739       if ((ImmS + 1).isPowerOf2() || (ImmS - 1).isPowerOf2() ||
19740           (1 - ImmS).isPowerOf2())
19741         return true;
19742     }
19743   }
19744 
19745   return false;
19746 }
19747 
19748 bool RISCVTargetLowering::isMulAddWithConstProfitable(SDValue AddNode,
19749                                                       SDValue ConstNode) const {
19750   // Let the DAGCombiner decide for vectors.
19751   EVT VT = AddNode.getValueType();
19752   if (VT.isVector())
19753     return true;
19754 
19755   // Let the DAGCombiner decide for larger types.
19756   if (VT.getScalarSizeInBits() > Subtarget.getXLen())
19757     return true;
19758 
19759   // It is not profitable if c1 fits in simm12 while c1*c2 does not.
19760   ConstantSDNode *C1Node = cast<ConstantSDNode>(AddNode.getOperand(1));
19761   ConstantSDNode *C2Node = cast<ConstantSDNode>(ConstNode);
19762   const APInt &C1 = C1Node->getAPIntValue();
19763   const APInt &C2 = C2Node->getAPIntValue();
19764   if (C1.isSignedIntN(12) && !(C1 * C2).isSignedIntN(12))
19765     return false;
19766 
19767   // Default to true and let the DAGCombiner decide.
19768   return true;
19769 }
19770 
19771 bool RISCVTargetLowering::allowsMisalignedMemoryAccesses(
19772     EVT VT, unsigned AddrSpace, Align Alignment, MachineMemOperand::Flags Flags,
19773     unsigned *Fast) const {
19774   if (!VT.isVector()) {
19775     if (Fast)
19776       *Fast = Subtarget.hasFastUnalignedAccess() ||
19777               Subtarget.enableUnalignedScalarMem();
19778     return Subtarget.hasFastUnalignedAccess() ||
19779            Subtarget.enableUnalignedScalarMem();
19780   }
19781 
19782   // All vector implementations must support element alignment
19783   EVT ElemVT = VT.getVectorElementType();
19784   if (Alignment >= ElemVT.getStoreSize()) {
19785     if (Fast)
19786       *Fast = 1;
19787     return true;
19788   }
19789 
19790   // Note: We lower an unmasked unaligned vector access to an equally sized
19791   // e8 element type access.  Given this, we effectively support all unmasked
19792   // misaligned accesses.  TODO: Work through the codegen implications of
19793   // allowing such accesses to be formed, and considered fast.
19794   if (Fast)
19795     *Fast = Subtarget.hasFastUnalignedAccess();
19796   return Subtarget.hasFastUnalignedAccess();
19797 }
19798 
19799 
19800 EVT RISCVTargetLowering::getOptimalMemOpType(const MemOp &Op,
19801                                              const AttributeList &FuncAttributes) const {
19802   if (!Subtarget.hasVInstructions())
19803     return MVT::Other;
19804 
19805   if (FuncAttributes.hasFnAttr(Attribute::NoImplicitFloat))
19806     return MVT::Other;
19807 
19808   // We use LMUL1 memory operations here for a non-obvious reason.  Our caller
19809   // has an expansion threshold, and we want the number of hardware memory
19810   // operations to correspond roughly to that threshold.  LMUL>1 operations
19811   // are typically expanded linearly internally, and thus correspond to more
19812   // than one actual memory operation.  Note that store merging and load
19813   // combining will typically form larger LMUL operations from the LMUL1
19814   // operations emitted here, and that's okay because combining isn't
19815   // introducing new memory operations; it's just merging existing ones.
19816   const unsigned MinVLenInBytes = Subtarget.getRealMinVLen()/8;
19817   if (Op.size() < MinVLenInBytes)
19818     // TODO: Figure out short memops.  For the moment, do the default thing
19819     // which ends up using scalar sequences.
19820     return MVT::Other;
19821 
19822   // Prefer i8 for non-zero memset as it allows us to avoid materializing
19823   // a large scalar constant and instead use vmv.v.x/i to do the
19824   // broadcast.  For everything else, prefer ELenVT to minimize VL and thus
19825   // maximize the chance we can encode the size in the vsetvli.
19826   MVT ELenVT = MVT::getIntegerVT(Subtarget.getELen());
19827   MVT PreferredVT = (Op.isMemset() && !Op.isZeroMemset()) ? MVT::i8 : ELenVT;
19828 
19829   // Do we have sufficient alignment for our preferred VT?  If not, revert
19830   // to largest size allowed by our alignment criteria.
19831   if (PreferredVT != MVT::i8 && !Subtarget.hasFastUnalignedAccess()) {
19832     Align RequiredAlign(PreferredVT.getStoreSize());
19833     if (Op.isFixedDstAlign())
19834       RequiredAlign = std::min(RequiredAlign, Op.getDstAlign());
19835     if (Op.isMemcpy())
19836       RequiredAlign = std::min(RequiredAlign, Op.getSrcAlign());
19837     PreferredVT = MVT::getIntegerVT(RequiredAlign.value() * 8);
19838   }
19839   return MVT::getVectorVT(PreferredVT, MinVLenInBytes/PreferredVT.getStoreSize());
19840 }
19841 
19842 bool RISCVTargetLowering::splitValueIntoRegisterParts(
19843     SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
19844     unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
19845   bool IsABIRegCopy = CC.has_value();
19846   EVT ValueVT = Val.getValueType();
19847   if (IsABIRegCopy && (ValueVT == MVT::f16 || ValueVT == MVT::bf16) &&
19848       PartVT == MVT::f32) {
19849     // Cast the [b]f16 to i16, extend to i32, pad the high bits with ones to
19850     // make an f32 NaN (NaN-boxing), and cast to f32.
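    // For example, an f16 with bit pattern 0x3C00 becomes the f32 bit
    // pattern 0xFFFF3C00.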
19851     Val = DAG.getNode(ISD::BITCAST, DL, MVT::i16, Val);
19852     Val = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Val);
19853     Val = DAG.getNode(ISD::OR, DL, MVT::i32, Val,
19854                       DAG.getConstant(0xFFFF0000, DL, MVT::i32));
19855     Val = DAG.getNode(ISD::BITCAST, DL, MVT::f32, Val);
19856     Parts[0] = Val;
19857     return true;
19858   }
19859 
19860   if (ValueVT.isScalableVector() && PartVT.isScalableVector()) {
19861     LLVMContext &Context = *DAG.getContext();
19862     EVT ValueEltVT = ValueVT.getVectorElementType();
19863     EVT PartEltVT = PartVT.getVectorElementType();
19864     unsigned ValueVTBitSize = ValueVT.getSizeInBits().getKnownMinValue();
19865     unsigned PartVTBitSize = PartVT.getSizeInBits().getKnownMinValue();
19866     if (PartVTBitSize % ValueVTBitSize == 0) {
19867       assert(PartVTBitSize >= ValueVTBitSize);
19868       // If the element types are different, bitcast to the same element type of
19869       // PartVT first.
19870       // For example, to copy a <vscale x 1 x i8> value into
19871       // <vscale x 4 x i16>, we first widen it to <vscale x 8 x i8> with an
19872       // INSERT_SUBVECTOR and then bitcast the result to
19873       // <vscale x 4 x i16>.
19874       if (ValueEltVT != PartEltVT) {
19875         if (PartVTBitSize > ValueVTBitSize) {
19876           unsigned Count = PartVTBitSize / ValueEltVT.getFixedSizeInBits();
19877           assert(Count != 0 && "The number of elements should not be zero.");
19878           EVT SameEltTypeVT =
19879               EVT::getVectorVT(Context, ValueEltVT, Count, /*IsScalable=*/true);
19880           Val = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, SameEltTypeVT,
19881                             DAG.getUNDEF(SameEltTypeVT), Val,
19882                             DAG.getVectorIdxConstant(0, DL));
19883         }
19884         Val = DAG.getNode(ISD::BITCAST, DL, PartVT, Val);
19885       } else {
19886         Val =
19887             DAG.getNode(ISD::INSERT_SUBVECTOR, DL, PartVT, DAG.getUNDEF(PartVT),
19888                         Val, DAG.getVectorIdxConstant(0, DL));
19889       }
19890       Parts[0] = Val;
19891       return true;
19892     }
19893   }
19894   return false;
19895 }
19896 
19897 SDValue RISCVTargetLowering::joinRegisterPartsIntoValue(
19898     SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts, unsigned NumParts,
19899     MVT PartVT, EVT ValueVT, std::optional<CallingConv::ID> CC) const {
19900   bool IsABIRegCopy = CC.has_value();
19901   if (IsABIRegCopy && (ValueVT == MVT::f16 || ValueVT == MVT::bf16) &&
19902       PartVT == MVT::f32) {
19903     SDValue Val = Parts[0];
19904 
19905     // Cast the f32 to i32, truncate to i16, and cast back to [b]f16.
19906     Val = DAG.getNode(ISD::BITCAST, DL, MVT::i32, Val);
19907     Val = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, Val);
19908     Val = DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);
19909     return Val;
19910   }
19911 
19912   if (ValueVT.isScalableVector() && PartVT.isScalableVector()) {
19913     LLVMContext &Context = *DAG.getContext();
19914     SDValue Val = Parts[0];
19915     EVT ValueEltVT = ValueVT.getVectorElementType();
19916     EVT PartEltVT = PartVT.getVectorElementType();
19917     unsigned ValueVTBitSize = ValueVT.getSizeInBits().getKnownMinValue();
19918     unsigned PartVTBitSize = PartVT.getSizeInBits().getKnownMinValue();
19919     if (PartVTBitSize % ValueVTBitSize == 0) {
19920       assert(PartVTBitSize >= ValueVTBitSize);
19921       EVT SameEltTypeVT = ValueVT;
19922       // If the element types are different, convert it to the same element type
19923       // of PartVT.
19924       // For example, to copy a <vscale x 1 x i8> value out of a
19925       // <vscale x 4 x i16>, we first bitcast the <vscale x 4 x i16> to
19926       // <vscale x 8 x i8> and then extract the <vscale x 1 x i8>
19927       // subvector from it.
19928       if (ValueEltVT != PartEltVT) {
19929         unsigned Count = PartVTBitSize / ValueEltVT.getFixedSizeInBits();
19930         assert(Count != 0 && "The number of elements should not be zero.");
19931         SameEltTypeVT =
19932             EVT::getVectorVT(Context, ValueEltVT, Count, /*IsScalable=*/true);
19933         Val = DAG.getNode(ISD::BITCAST, DL, SameEltTypeVT, Val);
19934       }
19935       Val = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ValueVT, Val,
19936                         DAG.getVectorIdxConstant(0, DL));
19937       return Val;
19938     }
19939   }
19940   return SDValue();
19941 }
19942 
19943 bool RISCVTargetLowering::isIntDivCheap(EVT VT, AttributeList Attr) const {
19944   // When aggressively optimizing for code size, we prefer to use a div
19945   // instruction, as it is usually smaller than the alternative sequence.
19946   // TODO: Add vector division?
19947   bool OptSize = Attr.hasFnAttr(Attribute::MinSize);
19948   return OptSize && !VT.isVector();
19949 }
19950 
19951 bool RISCVTargetLowering::preferScalarizeSplat(SDNode *N) const {
19952   // Scalarizing a splatted zero_extend or sign_extend can prevent it from
19953   // matching a widening instruction in some situations.
19954   unsigned Opc = N->getOpcode();
19955   if (Opc == ISD::ZERO_EXTEND || Opc == ISD::SIGN_EXTEND)
19956     return false;
19957   return true;
19958 }
19959 
19960 static Value *useTpOffset(IRBuilderBase &IRB, unsigned Offset) {
19961   Module *M = IRB.GetInsertBlock()->getParent()->getParent();
19962   Function *ThreadPointerFunc =
19963       Intrinsic::getDeclaration(M, Intrinsic::thread_pointer);
19964   return IRB.CreateConstGEP1_32(IRB.getInt8Ty(),
19965                                 IRB.CreateCall(ThreadPointerFunc), Offset);
19966 }
19967 
19968 Value *RISCVTargetLowering::getIRStackGuard(IRBuilderBase &IRB) const {
19969   // Fuchsia provides a fixed TLS slot for the stack cookie.
19970   // <zircon/tls.h> defines ZX_TLS_STACK_GUARD_OFFSET with this value.
19971   if (Subtarget.isTargetFuchsia())
19972     return useTpOffset(IRB, -0x10);
19973 
19974   return TargetLowering::getIRStackGuard(IRB);
19975 }
19976 
19977 bool RISCVTargetLowering::isLegalInterleavedAccessType(
19978     VectorType *VTy, unsigned Factor, Align Alignment, unsigned AddrSpace,
19979     const DataLayout &DL) const {
19980   EVT VT = getValueType(DL, VTy);
19981   // Don't lower vlseg/vsseg for vector types that can't be split.
19982   if (!isTypeLegal(VT))
19983     return false;
19984 
19985   if (!isLegalElementTypeForRVV(VT.getScalarType()) ||
19986       !allowsMemoryAccessForAlignment(VTy->getContext(), DL, VT, AddrSpace,
19987                                       Alignment))
19988     return false;
19989 
19990   MVT ContainerVT = VT.getSimpleVT();
19991 
19992   if (auto *FVTy = dyn_cast<FixedVectorType>(VTy)) {
19993     if (!Subtarget.useRVVForFixedLengthVectors())
19994       return false;
19995     // Sometimes the interleaved access pass picks up splats as interleaves of
19996     // one element. Don't lower these.
19997     if (FVTy->getNumElements() < 2)
19998       return false;
19999 
20000     ContainerVT = getContainerForFixedLengthVector(VT.getSimpleVT());
20001   }
20002 
20003   // Need to make sure that EMUL * NFIELDS ≤ 8
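  // For example, a factor-4 access with a container LMUL of 2 gives
  // EMUL * NFIELDS == 8, which is allowed, while LMUL == 4 would give 16 and
  // be rejected.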
20004   auto [LMUL, Fractional] = RISCVVType::decodeVLMUL(getLMUL(ContainerVT));
20005   if (Fractional)
20006     return true;
20007   return Factor * LMUL <= 8;
20008 }
20009 
20010 bool RISCVTargetLowering::isLegalStridedLoadStore(EVT DataType,
20011                                                   Align Alignment) const {
20012   if (!Subtarget.hasVInstructions())
20013     return false;
20014 
20015   // Only support fixed vectors if we know the minimum vector size.
20016   if (DataType.isFixedLengthVector() && !Subtarget.useRVVForFixedLengthVectors())
20017     return false;
20018 
20019   EVT ScalarType = DataType.getScalarType();
20020   if (!isLegalElementTypeForRVV(ScalarType))
20021     return false;
20022 
20023   if (!Subtarget.hasFastUnalignedAccess() &&
20024       Alignment < ScalarType.getStoreSize())
20025     return false;
20026 
20027   return true;
20028 }
20029 
20030 static const Intrinsic::ID FixedVlsegIntrIds[] = {
20031     Intrinsic::riscv_seg2_load, Intrinsic::riscv_seg3_load,
20032     Intrinsic::riscv_seg4_load, Intrinsic::riscv_seg5_load,
20033     Intrinsic::riscv_seg6_load, Intrinsic::riscv_seg7_load,
20034     Intrinsic::riscv_seg8_load};
20035 
20036 /// Lower an interleaved load into a vlsegN intrinsic.
20037 ///
20038 /// E.g. Lower an interleaved load (Factor = 2):
20039 /// %wide.vec = load <8 x i32>, <8 x i32>* %ptr
20040 /// %v0 = shuffle %wide.vec, undef, <0, 2, 4, 6>  ; Extract even elements
20041 /// %v1 = shuffle %wide.vec, undef, <1, 3, 5, 7>  ; Extract odd elements
20042 ///
20043 /// Into:
20044 /// %ld2 = { <4 x i32>, <4 x i32> } call llvm.riscv.seg2.load.v4i32.p0.i64(
20045 ///                                        %ptr, i64 4)
20046 /// %vec0 = extractelement { <4 x i32>, <4 x i32> } %ld2, i32 0
20047 /// %vec1 = extractelement { <4 x i32>, <4 x i32> } %ld2, i32 1
20048 bool RISCVTargetLowering::lowerInterleavedLoad(
20049     LoadInst *LI, ArrayRef<ShuffleVectorInst *> Shuffles,
20050     ArrayRef<unsigned> Indices, unsigned Factor) const {
20051   IRBuilder<> Builder(LI);
20052 
20053   auto *VTy = cast<FixedVectorType>(Shuffles[0]->getType());
20054   if (!isLegalInterleavedAccessType(VTy, Factor, LI->getAlign(),
20055                                     LI->getPointerAddressSpace(),
20056                                     LI->getModule()->getDataLayout()))
20057     return false;
20058 
20059   auto *XLenTy = Type::getIntNTy(LI->getContext(), Subtarget.getXLen());
20060 
20061   Function *VlsegNFunc =
20062       Intrinsic::getDeclaration(LI->getModule(), FixedVlsegIntrIds[Factor - 2],
20063                                 {VTy, LI->getPointerOperandType(), XLenTy});
20064 
20065   Value *VL = ConstantInt::get(XLenTy, VTy->getNumElements());
20066 
20067   CallInst *VlsegN =
20068       Builder.CreateCall(VlsegNFunc, {LI->getPointerOperand(), VL});
20069 
20070   for (unsigned i = 0; i < Shuffles.size(); i++) {
20071     Value *SubVec = Builder.CreateExtractValue(VlsegN, Indices[i]);
20072     Shuffles[i]->replaceAllUsesWith(SubVec);
20073   }
20074 
20075   return true;
20076 }
20077 
20078 static const Intrinsic::ID FixedVssegIntrIds[] = {
20079     Intrinsic::riscv_seg2_store, Intrinsic::riscv_seg3_store,
20080     Intrinsic::riscv_seg4_store, Intrinsic::riscv_seg5_store,
20081     Intrinsic::riscv_seg6_store, Intrinsic::riscv_seg7_store,
20082     Intrinsic::riscv_seg8_store};
20083 
20084 /// Lower an interleaved store into a vssegN intrinsic.
20085 ///
20086 /// E.g. Lower an interleaved store (Factor = 3):
20087 /// %i.vec = shuffle <8 x i32> %v0, <8 x i32> %v1,
20088 ///                  <0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11>
20089 /// store <12 x i32> %i.vec, <12 x i32>* %ptr
20090 ///
20091 /// Into:
20092 /// %sub.v0 = shuffle <8 x i32> %v0, <8 x i32> v1, <0, 1, 2, 3>
20093 /// %sub.v1 = shuffle <8 x i32> %v0, <8 x i32> v1, <4, 5, 6, 7>
20094 /// %sub.v2 = shuffle <8 x i32> %v0, <8 x i32> v1, <8, 9, 10, 11>
20095 /// call void llvm.riscv.seg3.store.v4i32.p0.i64(%sub.v0, %sub.v1, %sub.v2,
20096 ///                                              %ptr, i32 4)
20097 ///
20098 /// Note that the new shufflevectors will be removed and we'll only generate one
20099 /// vsseg3 instruction in CodeGen.
20100 bool RISCVTargetLowering::lowerInterleavedStore(StoreInst *SI,
20101                                                 ShuffleVectorInst *SVI,
20102                                                 unsigned Factor) const {
20103   IRBuilder<> Builder(SI);
20104   auto *ShuffleVTy = cast<FixedVectorType>(SVI->getType());
20105   // Given SVI : <n*factor x ty>, then VTy : <n x ty>
20106   auto *VTy = FixedVectorType::get(ShuffleVTy->getElementType(),
20107                                    ShuffleVTy->getNumElements() / Factor);
20108   if (!isLegalInterleavedAccessType(VTy, Factor, SI->getAlign(),
20109                                     SI->getPointerAddressSpace(),
20110                                     SI->getModule()->getDataLayout()))
20111     return false;
20112 
20113   auto *XLenTy = Type::getIntNTy(SI->getContext(), Subtarget.getXLen());
20114 
20115   Function *VssegNFunc =
20116       Intrinsic::getDeclaration(SI->getModule(), FixedVssegIntrIds[Factor - 2],
20117                                 {VTy, SI->getPointerOperandType(), XLenTy});
20118 
20119   auto Mask = SVI->getShuffleMask();
20120   SmallVector<Value *, 10> Ops;
20121 
20122   for (unsigned i = 0; i < Factor; i++) {
20123     Value *Shuffle = Builder.CreateShuffleVector(
20124         SVI->getOperand(0), SVI->getOperand(1),
20125         createSequentialMask(Mask[i], VTy->getNumElements(), 0));
20126     Ops.push_back(Shuffle);
20127   }
20128   // This VL should be OK (i.e. executable in one vsseg instruction,
20129   // potentially under a larger LMUL) because we checked that the fixed
20130   // vector type fits in isLegalInterleavedAccessType.
20131   Value *VL = ConstantInt::get(XLenTy, VTy->getNumElements());
20132   Ops.append({SI->getPointerOperand(), VL});
20133 
20134   Builder.CreateCall(VssegNFunc, Ops);
20135 
20136   return true;
20137 }
20138 
20139 bool RISCVTargetLowering::lowerDeinterleaveIntrinsicToLoad(IntrinsicInst *DI,
20140                                                            LoadInst *LI) const {
20141   assert(LI->isSimple());
20142   IRBuilder<> Builder(LI);
20143 
20144   // Only deinterleave2 supported at present.
20145   if (DI->getIntrinsicID() != Intrinsic::experimental_vector_deinterleave2)
20146     return false;
20147 
20148   unsigned Factor = 2;
20149 
20150   VectorType *VTy = cast<VectorType>(DI->getOperand(0)->getType());
20151   VectorType *ResVTy = cast<VectorType>(DI->getType()->getContainedType(0));
20152 
20153   if (!isLegalInterleavedAccessType(ResVTy, Factor, LI->getAlign(),
20154                                     LI->getPointerAddressSpace(),
20155                                     LI->getModule()->getDataLayout()))
20156     return false;
20157 
20158   Function *VlsegNFunc;
20159   Value *VL;
20160   Type *XLenTy = Type::getIntNTy(LI->getContext(), Subtarget.getXLen());
20161   SmallVector<Value *, 10> Ops;
20162 
20163   if (auto *FVTy = dyn_cast<FixedVectorType>(VTy)) {
20164     VlsegNFunc = Intrinsic::getDeclaration(
20165         LI->getModule(), FixedVlsegIntrIds[Factor - 2],
20166         {ResVTy, LI->getPointerOperandType(), XLenTy});
20167     VL = ConstantInt::get(XLenTy, FVTy->getNumElements());
20168   } else {
20169     static const Intrinsic::ID IntrIds[] = {
20170         Intrinsic::riscv_vlseg2, Intrinsic::riscv_vlseg3,
20171         Intrinsic::riscv_vlseg4, Intrinsic::riscv_vlseg5,
20172         Intrinsic::riscv_vlseg6, Intrinsic::riscv_vlseg7,
20173         Intrinsic::riscv_vlseg8};
20174 
20175     VlsegNFunc = Intrinsic::getDeclaration(LI->getModule(), IntrIds[Factor - 2],
20176                                            {ResVTy, XLenTy});
20177     VL = Constant::getAllOnesValue(XLenTy);
20178     Ops.append(Factor, PoisonValue::get(ResVTy));
20179   }
20180 
20181   Ops.append({LI->getPointerOperand(), VL});
20182 
20183   Value *Vlseg = Builder.CreateCall(VlsegNFunc, Ops);
20184   DI->replaceAllUsesWith(Vlseg);
20185 
20186   return true;
20187 }
20188 
20189 bool RISCVTargetLowering::lowerInterleaveIntrinsicToStore(IntrinsicInst *II,
20190                                                           StoreInst *SI) const {
20191   assert(SI->isSimple());
20192   IRBuilder<> Builder(SI);
20193 
20194   // Only interleave2 supported at present.
20195   if (II->getIntrinsicID() != Intrinsic::experimental_vector_interleave2)
20196     return false;
20197 
20198   unsigned Factor = 2;
20199 
20200   VectorType *VTy = cast<VectorType>(II->getType());
20201   VectorType *InVTy = cast<VectorType>(II->getOperand(0)->getType());
20202 
20203   if (!isLegalInterleavedAccessType(InVTy, Factor, SI->getAlign(),
20204                                     SI->getPointerAddressSpace(),
20205                                     SI->getModule()->getDataLayout()))
20206     return false;
20207 
20208   Function *VssegNFunc;
20209   Value *VL;
20210   Type *XLenTy = Type::getIntNTy(SI->getContext(), Subtarget.getXLen());
20211 
20212   if (auto *FVTy = dyn_cast<FixedVectorType>(VTy)) {
20213     VssegNFunc = Intrinsic::getDeclaration(
20214         SI->getModule(), FixedVssegIntrIds[Factor - 2],
20215         {InVTy, SI->getPointerOperandType(), XLenTy});
20216     VL = ConstantInt::get(XLenTy, FVTy->getNumElements());
20217   } else {
20218     static const Intrinsic::ID IntrIds[] = {
20219         Intrinsic::riscv_vsseg2, Intrinsic::riscv_vsseg3,
20220         Intrinsic::riscv_vsseg4, Intrinsic::riscv_vsseg5,
20221         Intrinsic::riscv_vsseg6, Intrinsic::riscv_vsseg7,
20222         Intrinsic::riscv_vsseg8};
20223 
20224     VssegNFunc = Intrinsic::getDeclaration(SI->getModule(), IntrIds[Factor - 2],
20225                                            {InVTy, XLenTy});
20226     VL = Constant::getAllOnesValue(XLenTy);
20227   }
20228 
20229   Builder.CreateCall(VssegNFunc, {II->getOperand(0), II->getOperand(1),
20230                                   SI->getPointerOperand(), VL});
20231 
20232   return true;
20233 }
20234 
20235 MachineInstr *
20236 RISCVTargetLowering::EmitKCFICheck(MachineBasicBlock &MBB,
20237                                    MachineBasicBlock::instr_iterator &MBBI,
20238                                    const TargetInstrInfo *TII) const {
20239   assert(MBBI->isCall() && MBBI->getCFIType() &&
20240          "Invalid call instruction for a KCFI check");
20241   assert(is_contained({RISCV::PseudoCALLIndirect, RISCV::PseudoTAILIndirect},
20242                       MBBI->getOpcode()));
20243 
20244   MachineOperand &Target = MBBI->getOperand(0);
20245   Target.setIsRenamable(false);
20246 
20247   return BuildMI(MBB, MBBI, MBBI->getDebugLoc(), TII->get(RISCV::KCFI_CHECK))
20248       .addReg(Target.getReg())
20249       .addImm(MBBI->getCFIType())
20250       .getInstr();
20251 }
20252 
20253 #define GET_REGISTER_MATCHER
20254 #include "RISCVGenAsmMatcher.inc"
20255 
20256 Register
20257 RISCVTargetLowering::getRegisterByName(const char *RegName, LLT VT,
20258                                        const MachineFunction &MF) const {
20259   Register Reg = MatchRegisterAltName(RegName);
20260   if (Reg == RISCV::NoRegister)
20261     Reg = MatchRegisterName(RegName);
20262   if (Reg == RISCV::NoRegister)
20263     report_fatal_error(
20264         Twine("Invalid register name \"" + StringRef(RegName) + "\"."));
20265   BitVector ReservedRegs = Subtarget.getRegisterInfo()->getReservedRegs(MF);
20266   if (!ReservedRegs.test(Reg) && !Subtarget.isRegisterReservedByUser(Reg))
20267     report_fatal_error(Twine("Trying to obtain non-reserved register \"" +
20268                              StringRef(RegName) + "\"."));
20269   return Reg;
20270 }
20271 
20272 MachineMemOperand::Flags
20273 RISCVTargetLowering::getTargetMMOFlags(const Instruction &I) const {
20274   const MDNode *NontemporalInfo = I.getMetadata(LLVMContext::MD_nontemporal);
20275 
20276   if (NontemporalInfo == nullptr)
20277     return MachineMemOperand::MONone;
20278 
20279   // 1 -> __RISCV_NTLH_ALL (the default value behaves as __RISCV_NTLH_ALL)
20280   // 2 -> __RISCV_NTLH_INNERMOST_PRIVATE
20281   // 3 -> __RISCV_NTLH_ALL_PRIVATE
20282   // 4 -> __RISCV_NTLH_INNERMOST_SHARED
20283   // 5 -> __RISCV_NTLH_ALL
20284   int NontemporalLevel = 5;
20285   const MDNode *RISCVNontemporalInfo =
20286       I.getMetadata("riscv-nontemporal-domain");
20287   if (RISCVNontemporalInfo != nullptr)
20288     NontemporalLevel =
20289         cast<ConstantInt>(
20290             cast<ConstantAsMetadata>(RISCVNontemporalInfo->getOperand(0))
20291                 ->getValue())
20292             ->getZExtValue();
20293 
20294   assert((1 <= NontemporalLevel && NontemporalLevel <= 5) &&
20295          "RISC-V target doesn't support this non-temporal domain.");
20296 
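  // Bias the level so that 2-5 map onto 0-3 and fit in the two MONontemporal
  // flag bits; level 1 wraps around and is treated the same as level 5.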
20297   NontemporalLevel -= 2;
20298   MachineMemOperand::Flags Flags = MachineMemOperand::MONone;
20299   if (NontemporalLevel & 0b1)
20300     Flags |= MONontemporalBit0;
20301   if (NontemporalLevel & 0b10)
20302     Flags |= MONontemporalBit1;
20303 
20304   return Flags;
20305 }
20306 
20307 MachineMemOperand::Flags
20308 RISCVTargetLowering::getTargetMMOFlags(const MemSDNode &Node) const {
20309 
20310   MachineMemOperand::Flags NodeFlags = Node.getMemOperand()->getFlags();
20311   MachineMemOperand::Flags TargetFlags = MachineMemOperand::MONone;
20312   TargetFlags |= (NodeFlags & MONontemporalBit0);
20313   TargetFlags |= (NodeFlags & MONontemporalBit1);
20314 
20315   return TargetFlags;
20316 }
20317 
20318 bool RISCVTargetLowering::areTwoSDNodeTargetMMOFlagsMergeable(
20319     const MemSDNode &NodeX, const MemSDNode &NodeY) const {
20320   return getTargetMMOFlags(NodeX) == getTargetMMOFlags(NodeY);
20321 }
20322 
20323 bool RISCVTargetLowering::isCtpopFast(EVT VT) const {
20324   if (VT.isScalableVector())
20325     return isTypeLegal(VT) && Subtarget.hasStdExtZvbb();
20326   if (VT.isFixedLengthVector() && Subtarget.hasStdExtZvbb())
20327     return true;
20328   return Subtarget.hasStdExtZbb() &&
20329          (VT == MVT::i32 || VT == MVT::i64 || VT.isFixedLengthVector());
20330 }
20331 
20332 unsigned RISCVTargetLowering::getCustomCtpopCost(EVT VT,
20333                                                  ISD::CondCode Cond) const {
20334   return isCtpopFast(VT) ? 0 : 1;
20335 }
20336 
20337 bool RISCVTargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
20338 
20339   // GISel support is in progress or complete for G_ADD, G_SUB, G_AND, G_OR, and
20340   // G_XOR.
20341   unsigned Op = Inst.getOpcode();
20342   if (Op == Instruction::Add || Op == Instruction::Sub ||
20343       Op == Instruction::And || Op == Instruction::Or || Op == Instruction::Xor)
20344     return false;
20345 
20346   if (Inst.getType()->isScalableTy())
20347     return true;
20348 
20349   for (unsigned i = 0; i < Inst.getNumOperands(); ++i)
20350     if (Inst.getOperand(i)->getType()->isScalableTy() &&
20351         !isa<ReturnInst>(&Inst))
20352       return true;
20353 
20354   if (const AllocaInst *AI = dyn_cast<AllocaInst>(&Inst)) {
20355     if (AI->getAllocatedType()->isScalableTy())
20356       return true;
20357   }
20358 
20359   return false;
20360 }
20361 
20362 SDValue
20363 RISCVTargetLowering::BuildSDIVPow2(SDNode *N, const APInt &Divisor,
20364                                    SelectionDAG &DAG,
20365                                    SmallVectorImpl<SDNode *> &Created) const {
20366   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
20367   if (isIntDivCheap(N->getValueType(0), Attr))
20368     return SDValue(N, 0); // Lower SDIV as SDIV
20369 
20370   // Only perform this transform if short forward branch opt is supported.
20371   if (!Subtarget.hasShortForwardBranchOpt())
20372     return SDValue();
20373   EVT VT = N->getValueType(0);
20374   if (!(VT == MVT::i32 || (VT == MVT::i64 && Subtarget.is64Bit())))
20375     return SDValue();
20376 
20377   // Ensure 2**k-1 < 2048 so that we can just emit a single addi/addiw.
20378   if (Divisor.sgt(2048) || Divisor.slt(-2048))
20379     return SDValue();
20380   return TargetLowering::buildSDIVPow2WithCMov(N, Divisor, DAG, Created);
20381 }
20382 
20383 bool RISCVTargetLowering::shouldFoldSelectWithSingleBitTest(
20384     EVT VT, const APInt &AndMask) const {
20385   if (Subtarget.hasStdExtZicond() || Subtarget.hasVendorXVentanaCondOps())
20386     return !Subtarget.hasStdExtZbs() && AndMask.ugt(1024);
20387   return TargetLowering::shouldFoldSelectWithSingleBitTest(VT, AndMask);
20388 }
20389 
20390 unsigned RISCVTargetLowering::getMinimumJumpTableEntries() const {
20391   return Subtarget.getMinimumJumpTableEntries();
20392 }
20393 
20394 namespace llvm::RISCVVIntrinsicsTable {
20395 
20396 #define GET_RISCVVIntrinsicsTable_IMPL
20397 #include "RISCVGenSearchableTables.inc"
20398 
20399 } // namespace llvm::RISCVVIntrinsicsTable
20400