1 //===-- RISCVISelLowering.cpp - RISCV DAG Lowering Implementation --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the interfaces that RISCV uses to lower LLVM code into a
10 // selection DAG.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "RISCVISelLowering.h"
15 #include "MCTargetDesc/RISCVMatInt.h"
16 #include "RISCV.h"
17 #include "RISCVMachineFunctionInfo.h"
18 #include "RISCVRegisterInfo.h"
19 #include "RISCVSubtarget.h"
20 #include "RISCVTargetMachine.h"
21 #include "llvm/ADT/SmallSet.h"
22 #include "llvm/ADT/Statistic.h"
23 #include "llvm/Analysis/MemoryLocation.h"
24 #include "llvm/CodeGen/MachineFrameInfo.h"
25 #include "llvm/CodeGen/MachineFunction.h"
26 #include "llvm/CodeGen/MachineInstrBuilder.h"
27 #include "llvm/CodeGen/MachineJumpTableInfo.h"
28 #include "llvm/CodeGen/MachineRegisterInfo.h"
29 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
30 #include "llvm/CodeGen/ValueTypes.h"
31 #include "llvm/IR/DiagnosticInfo.h"
32 #include "llvm/IR/DiagnosticPrinter.h"
33 #include "llvm/IR/IRBuilder.h"
34 #include "llvm/IR/IntrinsicsRISCV.h"
35 #include "llvm/IR/PatternMatch.h"
36 #include "llvm/Support/CommandLine.h"
37 #include "llvm/Support/Debug.h"
38 #include "llvm/Support/ErrorHandling.h"
39 #include "llvm/Support/KnownBits.h"
40 #include "llvm/Support/MathExtras.h"
41 #include "llvm/Support/raw_ostream.h"
42 #include <optional>
43
44 using namespace llvm;
45
46 #define DEBUG_TYPE "riscv-lower"
47
48 STATISTIC(NumTailCalls, "Number of tail calls");
49
50 static cl::opt<unsigned> ExtensionMaxWebSize(
51 DEBUG_TYPE "-ext-max-web-size", cl::Hidden,
52 cl::desc("Give the maximum size (in number of nodes) of the web of "
53 "instructions that we will consider for VW expansion"),
54 cl::init(18));
55
56 static cl::opt<bool>
57 AllowSplatInVW_W(DEBUG_TYPE "-form-vw-w-with-splat", cl::Hidden,
58 cl::desc("Allow the formation of VW_W operations (e.g., "
59 "VWADD_W) with splat constants"),
60 cl::init(false));
61
62 static cl::opt<unsigned> NumRepeatedDivisors(
63 DEBUG_TYPE "-fp-repeated-divisors", cl::Hidden,
64 cl::desc("Set the minimum number of repetitions of a divisor to allow "
65 "transformation to multiplications by the reciprocal"),
66 cl::init(2));
67
RISCVTargetLowering(const TargetMachine & TM,const RISCVSubtarget & STI)68 RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
69 const RISCVSubtarget &STI)
70 : TargetLowering(TM), Subtarget(STI) {
71
72 if (Subtarget.isRV32E())
73 report_fatal_error("Codegen not yet implemented for RV32E");
74
75 RISCVABI::ABI ABI = Subtarget.getTargetABI();
76 assert(ABI != RISCVABI::ABI_Unknown && "Improperly initialised target ABI");
77
78 if ((ABI == RISCVABI::ABI_ILP32F || ABI == RISCVABI::ABI_LP64F) &&
79 !Subtarget.hasStdExtF()) {
80 errs() << "Hard-float 'f' ABI can't be used for a target that "
81 "doesn't support the F instruction set extension (ignoring "
82 "target-abi)\n";
83 ABI = Subtarget.is64Bit() ? RISCVABI::ABI_LP64 : RISCVABI::ABI_ILP32;
84 } else if ((ABI == RISCVABI::ABI_ILP32D || ABI == RISCVABI::ABI_LP64D) &&
85 !Subtarget.hasStdExtD()) {
86 errs() << "Hard-float 'd' ABI can't be used for a target that "
87 "doesn't support the D instruction set extension (ignoring "
88 "target-abi)\n";
89 ABI = Subtarget.is64Bit() ? RISCVABI::ABI_LP64 : RISCVABI::ABI_ILP32;
90 }
91
92 switch (ABI) {
93 default:
94 report_fatal_error("Don't know how to lower this ABI");
95 case RISCVABI::ABI_ILP32:
96 case RISCVABI::ABI_ILP32F:
97 case RISCVABI::ABI_ILP32D:
98 case RISCVABI::ABI_LP64:
99 case RISCVABI::ABI_LP64F:
100 case RISCVABI::ABI_LP64D:
101 break;
102 }
103
104 MVT XLenVT = Subtarget.getXLenVT();
105
106 // Set up the register classes.
107 addRegisterClass(XLenVT, &RISCV::GPRRegClass);
108
109 if (Subtarget.hasStdExtZfhOrZfhmin())
110 addRegisterClass(MVT::f16, &RISCV::FPR16RegClass);
111 if (Subtarget.hasStdExtF())
112 addRegisterClass(MVT::f32, &RISCV::FPR32RegClass);
113 if (Subtarget.hasStdExtD())
114 addRegisterClass(MVT::f64, &RISCV::FPR64RegClass);
115
116 static const MVT::SimpleValueType BoolVecVTs[] = {
117 MVT::nxv1i1, MVT::nxv2i1, MVT::nxv4i1, MVT::nxv8i1,
118 MVT::nxv16i1, MVT::nxv32i1, MVT::nxv64i1};
119 static const MVT::SimpleValueType IntVecVTs[] = {
120 MVT::nxv1i8, MVT::nxv2i8, MVT::nxv4i8, MVT::nxv8i8, MVT::nxv16i8,
121 MVT::nxv32i8, MVT::nxv64i8, MVT::nxv1i16, MVT::nxv2i16, MVT::nxv4i16,
122 MVT::nxv8i16, MVT::nxv16i16, MVT::nxv32i16, MVT::nxv1i32, MVT::nxv2i32,
123 MVT::nxv4i32, MVT::nxv8i32, MVT::nxv16i32, MVT::nxv1i64, MVT::nxv2i64,
124 MVT::nxv4i64, MVT::nxv8i64};
125 static const MVT::SimpleValueType F16VecVTs[] = {
126 MVT::nxv1f16, MVT::nxv2f16, MVT::nxv4f16,
127 MVT::nxv8f16, MVT::nxv16f16, MVT::nxv32f16};
128 static const MVT::SimpleValueType F32VecVTs[] = {
129 MVT::nxv1f32, MVT::nxv2f32, MVT::nxv4f32, MVT::nxv8f32, MVT::nxv16f32};
130 static const MVT::SimpleValueType F64VecVTs[] = {
131 MVT::nxv1f64, MVT::nxv2f64, MVT::nxv4f64, MVT::nxv8f64};
132
133 if (Subtarget.hasVInstructions()) {
134 auto addRegClassForRVV = [this](MVT VT) {
135 // Disable the smallest fractional LMUL types if ELEN is less than
136 // RVVBitsPerBlock.
137 unsigned MinElts = RISCV::RVVBitsPerBlock / Subtarget.getELEN();
138 if (VT.getVectorMinNumElements() < MinElts)
139 return;
140
141 unsigned Size = VT.getSizeInBits().getKnownMinValue();
142 const TargetRegisterClass *RC;
143 if (Size <= RISCV::RVVBitsPerBlock)
144 RC = &RISCV::VRRegClass;
145 else if (Size == 2 * RISCV::RVVBitsPerBlock)
146 RC = &RISCV::VRM2RegClass;
147 else if (Size == 4 * RISCV::RVVBitsPerBlock)
148 RC = &RISCV::VRM4RegClass;
149 else if (Size == 8 * RISCV::RVVBitsPerBlock)
150 RC = &RISCV::VRM8RegClass;
151 else
152 llvm_unreachable("Unexpected size");
153
154 addRegisterClass(VT, RC);
155 };
156
157 for (MVT VT : BoolVecVTs)
158 addRegClassForRVV(VT);
159 for (MVT VT : IntVecVTs) {
160 if (VT.getVectorElementType() == MVT::i64 &&
161 !Subtarget.hasVInstructionsI64())
162 continue;
163 addRegClassForRVV(VT);
164 }
165
166 if (Subtarget.hasVInstructionsF16())
167 for (MVT VT : F16VecVTs)
168 addRegClassForRVV(VT);
169
170 if (Subtarget.hasVInstructionsF32())
171 for (MVT VT : F32VecVTs)
172 addRegClassForRVV(VT);
173
174 if (Subtarget.hasVInstructionsF64())
175 for (MVT VT : F64VecVTs)
176 addRegClassForRVV(VT);
177
178 if (Subtarget.useRVVForFixedLengthVectors()) {
179 auto addRegClassForFixedVectors = [this](MVT VT) {
180 MVT ContainerVT = getContainerForFixedLengthVector(VT);
181 unsigned RCID = getRegClassIDForVecVT(ContainerVT);
182 const RISCVRegisterInfo &TRI = *Subtarget.getRegisterInfo();
183 addRegisterClass(VT, TRI.getRegClass(RCID));
184 };
185 for (MVT VT : MVT::integer_fixedlen_vector_valuetypes())
186 if (useRVVForFixedLengthVectorVT(VT))
187 addRegClassForFixedVectors(VT);
188
189 for (MVT VT : MVT::fp_fixedlen_vector_valuetypes())
190 if (useRVVForFixedLengthVectorVT(VT))
191 addRegClassForFixedVectors(VT);
192 }
193 }
194
195 // Compute derived properties from the register classes.
196 computeRegisterProperties(STI.getRegisterInfo());
197
198 setStackPointerRegisterToSaveRestore(RISCV::X2);
199
200 setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, XLenVT,
201 MVT::i1, Promote);
202 // DAGCombiner can call isLoadExtLegal for types that aren't legal.
203 setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, MVT::i32,
204 MVT::i1, Promote);
205
206 // TODO: add all necessary setOperationAction calls.
207 setOperationAction(ISD::DYNAMIC_STACKALLOC, XLenVT, Expand);
208
209 setOperationAction(ISD::BR_JT, MVT::Other, Expand);
210 setOperationAction(ISD::BR_CC, XLenVT, Expand);
211 setOperationAction(ISD::BRCOND, MVT::Other, Custom);
212 setOperationAction(ISD::SELECT_CC, XLenVT, Expand);
213
214 setCondCodeAction(ISD::SETLE, XLenVT, Expand);
215 setCondCodeAction(ISD::SETGT, XLenVT, Custom);
216 setCondCodeAction(ISD::SETGE, XLenVT, Expand);
217 setCondCodeAction(ISD::SETULE, XLenVT, Expand);
218 setCondCodeAction(ISD::SETUGT, XLenVT, Custom);
219 setCondCodeAction(ISD::SETUGE, XLenVT, Expand);
220
221 setOperationAction({ISD::STACKSAVE, ISD::STACKRESTORE}, MVT::Other, Expand);
222
223 setOperationAction(ISD::VASTART, MVT::Other, Custom);
224 setOperationAction({ISD::VAARG, ISD::VACOPY, ISD::VAEND}, MVT::Other, Expand);
225
226 setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand);
227
228 setOperationAction(ISD::EH_DWARF_CFA, MVT::i32, Custom);
229
230 if (!Subtarget.hasStdExtZbb())
231 setOperationAction(ISD::SIGN_EXTEND_INREG, {MVT::i8, MVT::i16}, Expand);
232
233 if (Subtarget.is64Bit()) {
234 setOperationAction(ISD::EH_DWARF_CFA, MVT::i64, Custom);
235
236 setOperationAction(ISD::LOAD, MVT::i32, Custom);
237
238 setOperationAction({ISD::ADD, ISD::SUB, ISD::SHL, ISD::SRA, ISD::SRL},
239 MVT::i32, Custom);
240
241 setOperationAction({ISD::UADDO, ISD::USUBO, ISD::UADDSAT, ISD::USUBSAT},
242 MVT::i32, Custom);
243 } else {
244 setLibcallName(
245 {RTLIB::SHL_I128, RTLIB::SRL_I128, RTLIB::SRA_I128, RTLIB::MUL_I128},
246 nullptr);
247 setLibcallName(RTLIB::MULO_I64, nullptr);
248 }
249
250 if (!Subtarget.hasStdExtM() && !Subtarget.hasStdExtZmmul()) {
251 setOperationAction({ISD::MUL, ISD::MULHS, ISD::MULHU}, XLenVT, Expand);
252 } else {
253 if (Subtarget.is64Bit()) {
254 setOperationAction(ISD::MUL, {MVT::i32, MVT::i128}, Custom);
255 } else {
256 setOperationAction(ISD::MUL, MVT::i64, Custom);
257 }
258 }
259
260 if (!Subtarget.hasStdExtM()) {
261 setOperationAction({ISD::SDIV, ISD::UDIV, ISD::SREM, ISD::UREM},
262 XLenVT, Expand);
263 } else {
264 if (Subtarget.is64Bit()) {
265 setOperationAction({ISD::SDIV, ISD::UDIV, ISD::UREM},
266 {MVT::i8, MVT::i16, MVT::i32}, Custom);
267 }
268 }
269
270 setOperationAction(
271 {ISD::SDIVREM, ISD::UDIVREM, ISD::SMUL_LOHI, ISD::UMUL_LOHI}, XLenVT,
272 Expand);
273
274 setOperationAction({ISD::SHL_PARTS, ISD::SRL_PARTS, ISD::SRA_PARTS}, XLenVT,
275 Custom);
276
277 if (Subtarget.hasStdExtZbb() || Subtarget.hasStdExtZbkb()) {
278 if (Subtarget.is64Bit())
279 setOperationAction({ISD::ROTL, ISD::ROTR}, MVT::i32, Custom);
280 } else {
281 setOperationAction({ISD::ROTL, ISD::ROTR}, XLenVT, Expand);
282 }
283
284 // With Zbb we have an XLen rev8 instruction, but not GREVI. So we'll
285 // pattern match it directly in isel.
286 setOperationAction(ISD::BSWAP, XLenVT,
287 (Subtarget.hasStdExtZbb() || Subtarget.hasStdExtZbkb())
288 ? Legal
289 : Expand);
290 // Zbkb can use rev8+brev8 to implement bitreverse.
291 setOperationAction(ISD::BITREVERSE, XLenVT,
292 Subtarget.hasStdExtZbkb() ? Custom : Expand);
293
294 if (Subtarget.hasStdExtZbb()) {
295 setOperationAction({ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX}, XLenVT,
296 Legal);
297
298 if (Subtarget.is64Bit())
299 setOperationAction(
300 {ISD::CTTZ, ISD::CTTZ_ZERO_UNDEF, ISD::CTLZ, ISD::CTLZ_ZERO_UNDEF},
301 MVT::i32, Custom);
302 } else {
303 setOperationAction({ISD::CTTZ, ISD::CTLZ, ISD::CTPOP}, XLenVT, Expand);
304 }
305
306 if (Subtarget.is64Bit())
307 setOperationAction(ISD::ABS, MVT::i32, Custom);
308
309 if (!Subtarget.hasVendorXVentanaCondOps())
310 setOperationAction(ISD::SELECT, XLenVT, Custom);
311
312 static const unsigned FPLegalNodeTypes[] = {
313 ISD::FMINNUM, ISD::FMAXNUM, ISD::LRINT,
314 ISD::LLRINT, ISD::LROUND, ISD::LLROUND,
315 ISD::STRICT_LRINT, ISD::STRICT_LLRINT, ISD::STRICT_LROUND,
316 ISD::STRICT_LLROUND, ISD::STRICT_FMA, ISD::STRICT_FADD,
317 ISD::STRICT_FSUB, ISD::STRICT_FMUL, ISD::STRICT_FDIV,
318 ISD::STRICT_FSQRT, ISD::STRICT_FSETCC, ISD::STRICT_FSETCCS};
319
320 static const ISD::CondCode FPCCToExpand[] = {
321 ISD::SETOGT, ISD::SETOGE, ISD::SETONE, ISD::SETUEQ, ISD::SETUGT,
322 ISD::SETUGE, ISD::SETULT, ISD::SETULE, ISD::SETUNE, ISD::SETGT,
323 ISD::SETGE, ISD::SETNE, ISD::SETO, ISD::SETUO};
324
325 static const unsigned FPOpToExpand[] = {
326 ISD::FSIN, ISD::FCOS, ISD::FSINCOS, ISD::FPOW,
327 ISD::FREM, ISD::FP16_TO_FP, ISD::FP_TO_FP16};
328
329 static const unsigned FPRndMode[] = {
330 ISD::FCEIL, ISD::FFLOOR, ISD::FTRUNC, ISD::FRINT, ISD::FROUND,
331 ISD::FROUNDEVEN};
332
333 if (Subtarget.hasStdExtZfhOrZfhmin())
334 setOperationAction(ISD::BITCAST, MVT::i16, Custom);
335
336 if (Subtarget.hasStdExtZfhOrZfhmin()) {
337 if (Subtarget.hasStdExtZfh()) {
338 setOperationAction(FPLegalNodeTypes, MVT::f16, Legal);
339 setOperationAction(FPRndMode, MVT::f16, Custom);
340 setOperationAction(ISD::SELECT, MVT::f16, Custom);
341 } else {
342 static const unsigned ZfhminPromoteOps[] = {
343 ISD::FMINNUM, ISD::FMAXNUM, ISD::FADD,
344 ISD::FSUB, ISD::FMUL, ISD::FMA,
345 ISD::FDIV, ISD::FSQRT, ISD::FABS,
346 ISD::FNEG, ISD::STRICT_FMA, ISD::STRICT_FADD,
347 ISD::STRICT_FSUB, ISD::STRICT_FMUL, ISD::STRICT_FDIV,
348 ISD::STRICT_FSQRT, ISD::STRICT_FSETCC, ISD::STRICT_FSETCCS,
349 ISD::SETCC, ISD::FCEIL, ISD::FFLOOR,
350 ISD::FTRUNC, ISD::FRINT, ISD::FROUND,
351 ISD::FROUNDEVEN, ISD::SELECT};
352
353 setOperationAction(ZfhminPromoteOps, MVT::f16, Promote);
354 setOperationAction({ISD::STRICT_LRINT, ISD::STRICT_LLRINT,
355 ISD::STRICT_LROUND, ISD::STRICT_LLROUND},
356 MVT::f16, Legal);
357 // FIXME: Need to promote f16 FCOPYSIGN to f32, but the
358 // DAGCombiner::visitFP_ROUND probably needs improvements first.
359 setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand);
360 }
361
362 setOperationAction(ISD::STRICT_FP_ROUND, MVT::f16, Legal);
363 setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f32, Legal);
364 setCondCodeAction(FPCCToExpand, MVT::f16, Expand);
365 setOperationAction(ISD::SELECT_CC, MVT::f16, Expand);
366 setOperationAction(ISD::BR_CC, MVT::f16, Expand);
367
368 setOperationAction({ISD::FREM, ISD::FNEARBYINT, ISD::FPOW, ISD::FPOWI,
369 ISD::FCOS, ISD::FSIN, ISD::FSINCOS, ISD::FEXP,
370 ISD::FEXP2, ISD::FLOG, ISD::FLOG2, ISD::FLOG10},
371 MVT::f16, Promote);
372
373 // FIXME: Need to promote f16 STRICT_* to f32 libcalls, but we don't have
374 // complete support for all operations in LegalizeDAG.
375 setOperationAction({ISD::STRICT_FCEIL, ISD::STRICT_FFLOOR,
376 ISD::STRICT_FNEARBYINT, ISD::STRICT_FRINT,
377 ISD::STRICT_FROUND, ISD::STRICT_FROUNDEVEN,
378 ISD::STRICT_FTRUNC},
379 MVT::f16, Promote);
380
381 // We need to custom promote this.
382 if (Subtarget.is64Bit())
383 setOperationAction(ISD::FPOWI, MVT::i32, Custom);
384 }
385
386 if (Subtarget.hasStdExtF()) {
387 setOperationAction(FPLegalNodeTypes, MVT::f32, Legal);
388 setOperationAction(FPRndMode, MVT::f32, Custom);
389 setCondCodeAction(FPCCToExpand, MVT::f32, Expand);
390 setOperationAction(ISD::SELECT_CC, MVT::f32, Expand);
391 setOperationAction(ISD::SELECT, MVT::f32, Custom);
392 setOperationAction(ISD::BR_CC, MVT::f32, Expand);
393 setOperationAction(FPOpToExpand, MVT::f32, Expand);
394 setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
395 setTruncStoreAction(MVT::f32, MVT::f16, Expand);
396 }
397
398 if (Subtarget.hasStdExtF() && Subtarget.is64Bit())
399 setOperationAction(ISD::BITCAST, MVT::i32, Custom);
400
401 if (Subtarget.hasStdExtD()) {
402 setOperationAction(FPLegalNodeTypes, MVT::f64, Legal);
403 if (Subtarget.is64Bit()) {
404 setOperationAction(FPRndMode, MVT::f64, Custom);
405 }
406 setOperationAction(ISD::STRICT_FP_ROUND, MVT::f32, Legal);
407 setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f64, Legal);
408 setCondCodeAction(FPCCToExpand, MVT::f64, Expand);
409 setOperationAction(ISD::SELECT_CC, MVT::f64, Expand);
410 setOperationAction(ISD::SELECT, MVT::f64, Custom);
411 setOperationAction(ISD::BR_CC, MVT::f64, Expand);
412 setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
413 setTruncStoreAction(MVT::f64, MVT::f32, Expand);
414 setOperationAction(FPOpToExpand, MVT::f64, Expand);
415 setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
416 setTruncStoreAction(MVT::f64, MVT::f16, Expand);
417 }
418
419 if (Subtarget.is64Bit())
420 setOperationAction({ISD::FP_TO_UINT, ISD::FP_TO_SINT,
421 ISD::STRICT_FP_TO_UINT, ISD::STRICT_FP_TO_SINT},
422 MVT::i32, Custom);
423
424 if (Subtarget.hasStdExtF()) {
425 setOperationAction({ISD::FP_TO_UINT_SAT, ISD::FP_TO_SINT_SAT}, XLenVT,
426 Custom);
427
428 setOperationAction({ISD::STRICT_FP_TO_UINT, ISD::STRICT_FP_TO_SINT,
429 ISD::STRICT_UINT_TO_FP, ISD::STRICT_SINT_TO_FP},
430 XLenVT, Legal);
431
432 setOperationAction(ISD::GET_ROUNDING, XLenVT, Custom);
433 setOperationAction(ISD::SET_ROUNDING, MVT::Other, Custom);
434 }
435
436 setOperationAction({ISD::GlobalAddress, ISD::BlockAddress, ISD::ConstantPool,
437 ISD::JumpTable},
438 XLenVT, Custom);
439
440 setOperationAction(ISD::GlobalTLSAddress, XLenVT, Custom);
441
442 if (Subtarget.is64Bit())
443 setOperationAction(ISD::Constant, MVT::i64, Custom);
444
445 // TODO: On M-mode only targets, the cycle[h] CSR may not be present.
446 // Unfortunately this can't be determined just from the ISA naming string.
447 setOperationAction(ISD::READCYCLECOUNTER, MVT::i64,
448 Subtarget.is64Bit() ? Legal : Custom);
449
450 setOperationAction({ISD::TRAP, ISD::DEBUGTRAP}, MVT::Other, Legal);
451 setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);
452 if (Subtarget.is64Bit())
453 setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i32, Custom);
454
455 if (Subtarget.hasStdExtA()) {
456 setMaxAtomicSizeInBitsSupported(Subtarget.getXLen());
457 setMinCmpXchgSizeInBits(32);
458 } else if (Subtarget.hasForcedAtomics()) {
459 setMaxAtomicSizeInBitsSupported(Subtarget.getXLen());
460 } else {
461 setMaxAtomicSizeInBitsSupported(0);
462 }
463
464 setOperationAction(ISD::ATOMIC_FENCE, MVT::Other, Custom);
465
466 setBooleanContents(ZeroOrOneBooleanContent);
467
468 if (Subtarget.hasVInstructions()) {
469 setBooleanVectorContents(ZeroOrOneBooleanContent);
470
471 setOperationAction(ISD::VSCALE, XLenVT, Custom);
472
473 // RVV intrinsics may have illegal operands.
474 // We also need to custom legalize vmv.x.s.
475 setOperationAction({ISD::INTRINSIC_WO_CHAIN, ISD::INTRINSIC_W_CHAIN},
476 {MVT::i8, MVT::i16}, Custom);
477 if (Subtarget.is64Bit())
478 setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i32, Custom);
479 else
480 setOperationAction({ISD::INTRINSIC_WO_CHAIN, ISD::INTRINSIC_W_CHAIN},
481 MVT::i64, Custom);
482
483 setOperationAction({ISD::INTRINSIC_W_CHAIN, ISD::INTRINSIC_VOID},
484 MVT::Other, Custom);
485
486 static const unsigned IntegerVPOps[] = {
487 ISD::VP_ADD, ISD::VP_SUB, ISD::VP_MUL,
488 ISD::VP_SDIV, ISD::VP_UDIV, ISD::VP_SREM,
489 ISD::VP_UREM, ISD::VP_AND, ISD::VP_OR,
490 ISD::VP_XOR, ISD::VP_ASHR, ISD::VP_LSHR,
491 ISD::VP_SHL, ISD::VP_REDUCE_ADD, ISD::VP_REDUCE_AND,
492 ISD::VP_REDUCE_OR, ISD::VP_REDUCE_XOR, ISD::VP_REDUCE_SMAX,
493 ISD::VP_REDUCE_SMIN, ISD::VP_REDUCE_UMAX, ISD::VP_REDUCE_UMIN,
494 ISD::VP_MERGE, ISD::VP_SELECT, ISD::VP_FP_TO_SINT,
495 ISD::VP_FP_TO_UINT, ISD::VP_SETCC, ISD::VP_SIGN_EXTEND,
496 ISD::VP_ZERO_EXTEND, ISD::VP_TRUNCATE, ISD::VP_SMIN,
497 ISD::VP_SMAX, ISD::VP_UMIN, ISD::VP_UMAX,
498 ISD::VP_ABS};
499
500 static const unsigned FloatingPointVPOps[] = {
501 ISD::VP_FADD, ISD::VP_FSUB, ISD::VP_FMUL,
502 ISD::VP_FDIV, ISD::VP_FNEG, ISD::VP_FABS,
503 ISD::VP_FMA, ISD::VP_REDUCE_FADD, ISD::VP_REDUCE_SEQ_FADD,
504 ISD::VP_REDUCE_FMIN, ISD::VP_REDUCE_FMAX, ISD::VP_MERGE,
505 ISD::VP_SELECT, ISD::VP_SINT_TO_FP, ISD::VP_UINT_TO_FP,
506 ISD::VP_SETCC, ISD::VP_FP_ROUND, ISD::VP_FP_EXTEND,
507 ISD::VP_SQRT, ISD::VP_FMINNUM, ISD::VP_FMAXNUM,
508 ISD::VP_FCEIL, ISD::VP_FFLOOR, ISD::VP_FROUND,
509 ISD::VP_FROUNDEVEN, ISD::VP_FCOPYSIGN, ISD::VP_FROUNDTOZERO,
510 ISD::VP_FRINT, ISD::VP_FNEARBYINT};
511
512 static const unsigned IntegerVecReduceOps[] = {
513 ISD::VECREDUCE_ADD, ISD::VECREDUCE_AND, ISD::VECREDUCE_OR,
514 ISD::VECREDUCE_XOR, ISD::VECREDUCE_SMAX, ISD::VECREDUCE_SMIN,
515 ISD::VECREDUCE_UMAX, ISD::VECREDUCE_UMIN};
516
517 static const unsigned FloatingPointVecReduceOps[] = {
518 ISD::VECREDUCE_FADD, ISD::VECREDUCE_SEQ_FADD, ISD::VECREDUCE_FMIN,
519 ISD::VECREDUCE_FMAX};
520
521 if (!Subtarget.is64Bit()) {
522 // We must custom-lower certain vXi64 operations on RV32 due to the vector
523 // element type being illegal.
524 setOperationAction({ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT},
525 MVT::i64, Custom);
526
527 setOperationAction(IntegerVecReduceOps, MVT::i64, Custom);
528
529 setOperationAction({ISD::VP_REDUCE_ADD, ISD::VP_REDUCE_AND,
530 ISD::VP_REDUCE_OR, ISD::VP_REDUCE_XOR,
531 ISD::VP_REDUCE_SMAX, ISD::VP_REDUCE_SMIN,
532 ISD::VP_REDUCE_UMAX, ISD::VP_REDUCE_UMIN},
533 MVT::i64, Custom);
534 }
535
536 for (MVT VT : BoolVecVTs) {
537 if (!isTypeLegal(VT))
538 continue;
539
540 setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
541
542 // Mask VTs are custom-expanded into a series of standard nodes
543 setOperationAction({ISD::TRUNCATE, ISD::CONCAT_VECTORS,
544 ISD::INSERT_SUBVECTOR, ISD::EXTRACT_SUBVECTOR},
545 VT, Custom);
546
547 setOperationAction({ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT}, VT,
548 Custom);
549
550 setOperationAction(ISD::SELECT, VT, Custom);
551 setOperationAction(
552 {ISD::SELECT_CC, ISD::VSELECT, ISD::VP_MERGE, ISD::VP_SELECT}, VT,
553 Expand);
554
555 setOperationAction({ISD::VP_AND, ISD::VP_OR, ISD::VP_XOR}, VT, Custom);
556
557 setOperationAction(
558 {ISD::VECREDUCE_AND, ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR}, VT,
559 Custom);
560
561 setOperationAction(
562 {ISD::VP_REDUCE_AND, ISD::VP_REDUCE_OR, ISD::VP_REDUCE_XOR}, VT,
563 Custom);
564
565 // RVV has native int->float & float->int conversions where the
566 // element type sizes are within one power-of-two of each other. Any
567 // wider distances between type sizes have to be lowered as sequences
568 // which progressively narrow the gap in stages.
569 setOperationAction(
570 {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT},
571 VT, Custom);
572 setOperationAction({ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT}, VT,
573 Custom);
574
575 // Expand all extending loads to types larger than this, and truncating
576 // stores from types larger than this.
577 for (MVT OtherVT : MVT::integer_scalable_vector_valuetypes()) {
578 setTruncStoreAction(OtherVT, VT, Expand);
579 setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, OtherVT,
580 VT, Expand);
581 }
582
583 setOperationAction({ISD::VP_FP_TO_SINT, ISD::VP_FP_TO_UINT,
584 ISD::VP_TRUNCATE, ISD::VP_SETCC},
585 VT, Custom);
586 setOperationAction(ISD::VECTOR_REVERSE, VT, Custom);
587
588 setOperationPromotedToType(
589 ISD::VECTOR_SPLICE, VT,
590 MVT::getVectorVT(MVT::i8, VT.getVectorElementCount()));
591 }
592
593 for (MVT VT : IntVecVTs) {
594 if (!isTypeLegal(VT))
595 continue;
596
597 setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
598 setOperationAction(ISD::SPLAT_VECTOR_PARTS, VT, Custom);
599
600 // Vectors implement MULHS/MULHU.
601 setOperationAction({ISD::SMUL_LOHI, ISD::UMUL_LOHI}, VT, Expand);
602
603 // nxvXi64 MULHS/MULHU requires the V extension instead of Zve64*.
604 if (VT.getVectorElementType() == MVT::i64 && !Subtarget.hasStdExtV())
605 setOperationAction({ISD::MULHU, ISD::MULHS}, VT, Expand);
606
607 setOperationAction({ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX}, VT,
608 Legal);
609
610 setOperationAction({ISD::ROTL, ISD::ROTR}, VT, Expand);
611
612 setOperationAction({ISD::CTTZ, ISD::CTLZ, ISD::CTPOP}, VT, Expand);
613
614 setOperationAction(ISD::BSWAP, VT, Expand);
615 setOperationAction({ISD::VP_BSWAP, ISD::VP_BITREVERSE}, VT, Expand);
616 setOperationAction({ISD::VP_FSHL, ISD::VP_FSHR}, VT, Expand);
617 setOperationAction({ISD::VP_CTLZ, ISD::VP_CTLZ_ZERO_UNDEF, ISD::VP_CTTZ,
618 ISD::VP_CTTZ_ZERO_UNDEF, ISD::VP_CTPOP},
619 VT, Expand);
620
621 // Custom-lower extensions and truncations from/to mask types.
622 setOperationAction({ISD::ANY_EXTEND, ISD::SIGN_EXTEND, ISD::ZERO_EXTEND},
623 VT, Custom);
624
625 // RVV has native int->float & float->int conversions where the
626 // element type sizes are within one power-of-two of each other. Any
627 // wider distances between type sizes have to be lowered as sequences
628 // which progressively narrow the gap in stages.
629 setOperationAction(
630 {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT},
631 VT, Custom);
632 setOperationAction({ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT}, VT,
633 Custom);
634
635 setOperationAction(
636 {ISD::SADDSAT, ISD::UADDSAT, ISD::SSUBSAT, ISD::USUBSAT}, VT, Legal);
637
638 // Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR_VL"
639 // nodes which truncate by one power of two at a time.
640 setOperationAction(ISD::TRUNCATE, VT, Custom);
641
642 // Custom-lower insert/extract operations to simplify patterns.
643 setOperationAction({ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT}, VT,
644 Custom);
645
646 // Custom-lower reduction operations to set up the corresponding custom
647 // nodes' operands.
648 setOperationAction(IntegerVecReduceOps, VT, Custom);
649
650 setOperationAction(IntegerVPOps, VT, Custom);
651
652 setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom);
653
654 setOperationAction({ISD::MLOAD, ISD::MSTORE, ISD::MGATHER, ISD::MSCATTER},
655 VT, Custom);
656
657 setOperationAction(
658 {ISD::VP_LOAD, ISD::VP_STORE, ISD::EXPERIMENTAL_VP_STRIDED_LOAD,
659 ISD::EXPERIMENTAL_VP_STRIDED_STORE, ISD::VP_GATHER, ISD::VP_SCATTER},
660 VT, Custom);
661
662 setOperationAction(
663 {ISD::CONCAT_VECTORS, ISD::INSERT_SUBVECTOR, ISD::EXTRACT_SUBVECTOR},
664 VT, Custom);
665
666 setOperationAction(ISD::SELECT, VT, Custom);
667 setOperationAction(ISD::SELECT_CC, VT, Expand);
668
669 setOperationAction({ISD::STEP_VECTOR, ISD::VECTOR_REVERSE}, VT, Custom);
670
671 for (MVT OtherVT : MVT::integer_scalable_vector_valuetypes()) {
672 setTruncStoreAction(VT, OtherVT, Expand);
673 setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, OtherVT,
674 VT, Expand);
675 }
676
677 // Splice
678 setOperationAction(ISD::VECTOR_SPLICE, VT, Custom);
679
680 // Lower CTLZ_ZERO_UNDEF and CTTZ_ZERO_UNDEF if element of VT in the range
681 // of f32.
682 EVT FloatVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount());
683 if (isTypeLegal(FloatVT)) {
684 setOperationAction(
685 {ISD::CTLZ, ISD::CTLZ_ZERO_UNDEF, ISD::CTTZ_ZERO_UNDEF}, VT,
686 Custom);
687 }
688 }
689
690 // Expand various CCs to best match the RVV ISA, which natively supports UNE
691 // but no other unordered comparisons, and supports all ordered comparisons
692 // except ONE. Additionally, we expand GT,OGT,GE,OGE for optimization
693 // purposes; they are expanded to their swapped-operand CCs (LT,OLT,LE,OLE),
694 // and we pattern-match those back to the "original", swapping operands once
695 // more. This way we catch both operations and both "vf" and "fv" forms with
696 // fewer patterns.
697 static const ISD::CondCode VFPCCToExpand[] = {
698 ISD::SETO, ISD::SETONE, ISD::SETUEQ, ISD::SETUGT,
699 ISD::SETUGE, ISD::SETULT, ISD::SETULE, ISD::SETUO,
700 ISD::SETGT, ISD::SETOGT, ISD::SETGE, ISD::SETOGE,
701 };
702
703 // Sets common operation actions on RVV floating-point vector types.
704 const auto SetCommonVFPActions = [&](MVT VT) {
705 setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
706 // RVV has native FP_ROUND & FP_EXTEND conversions where the element type
707 // sizes are within one power-of-two of each other. Therefore conversions
708 // between vXf16 and vXf64 must be lowered as sequences which convert via
709 // vXf32.
710 setOperationAction({ISD::FP_ROUND, ISD::FP_EXTEND}, VT, Custom);
711 // Custom-lower insert/extract operations to simplify patterns.
712 setOperationAction({ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT}, VT,
713 Custom);
714 // Expand various condition codes (explained above).
715 setCondCodeAction(VFPCCToExpand, VT, Expand);
716
717 setOperationAction({ISD::FMINNUM, ISD::FMAXNUM}, VT, Legal);
718
719 setOperationAction(
720 {ISD::FTRUNC, ISD::FCEIL, ISD::FFLOOR, ISD::FROUND, ISD::FROUNDEVEN},
721 VT, Custom);
722
723 setOperationAction(FloatingPointVecReduceOps, VT, Custom);
724
725 // Expand FP operations that need libcalls.
726 setOperationAction(ISD::FREM, VT, Expand);
727 setOperationAction(ISD::FPOW, VT, Expand);
728 setOperationAction(ISD::FCOS, VT, Expand);
729 setOperationAction(ISD::FSIN, VT, Expand);
730 setOperationAction(ISD::FSINCOS, VT, Expand);
731 setOperationAction(ISD::FEXP, VT, Expand);
732 setOperationAction(ISD::FEXP2, VT, Expand);
733 setOperationAction(ISD::FLOG, VT, Expand);
734 setOperationAction(ISD::FLOG2, VT, Expand);
735 setOperationAction(ISD::FLOG10, VT, Expand);
736 setOperationAction(ISD::FRINT, VT, Expand);
737 setOperationAction(ISD::FNEARBYINT, VT, Expand);
738
739 setOperationAction(ISD::FCOPYSIGN, VT, Legal);
740
741 setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom);
742
743 setOperationAction({ISD::MLOAD, ISD::MSTORE, ISD::MGATHER, ISD::MSCATTER},
744 VT, Custom);
745
746 setOperationAction(
747 {ISD::VP_LOAD, ISD::VP_STORE, ISD::EXPERIMENTAL_VP_STRIDED_LOAD,
748 ISD::EXPERIMENTAL_VP_STRIDED_STORE, ISD::VP_GATHER, ISD::VP_SCATTER},
749 VT, Custom);
750
751 setOperationAction(ISD::SELECT, VT, Custom);
752 setOperationAction(ISD::SELECT_CC, VT, Expand);
753
754 setOperationAction(
755 {ISD::CONCAT_VECTORS, ISD::INSERT_SUBVECTOR, ISD::EXTRACT_SUBVECTOR},
756 VT, Custom);
757
758 setOperationAction({ISD::VECTOR_REVERSE, ISD::VECTOR_SPLICE}, VT, Custom);
759
760 setOperationAction(FloatingPointVPOps, VT, Custom);
761 };
762
763 // Sets common extload/truncstore actions on RVV floating-point vector
764 // types.
765 const auto SetCommonVFPExtLoadTruncStoreActions =
766 [&](MVT VT, ArrayRef<MVT::SimpleValueType> SmallerVTs) {
767 for (auto SmallVT : SmallerVTs) {
768 setTruncStoreAction(VT, SmallVT, Expand);
769 setLoadExtAction(ISD::EXTLOAD, VT, SmallVT, Expand);
770 }
771 };
772
773 if (Subtarget.hasVInstructionsF16()) {
774 for (MVT VT : F16VecVTs) {
775 if (!isTypeLegal(VT))
776 continue;
777 SetCommonVFPActions(VT);
778 }
779 }
780
781 if (Subtarget.hasVInstructionsF32()) {
782 for (MVT VT : F32VecVTs) {
783 if (!isTypeLegal(VT))
784 continue;
785 SetCommonVFPActions(VT);
786 SetCommonVFPExtLoadTruncStoreActions(VT, F16VecVTs);
787 }
788 }
789
790 if (Subtarget.hasVInstructionsF64()) {
791 for (MVT VT : F64VecVTs) {
792 if (!isTypeLegal(VT))
793 continue;
794 SetCommonVFPActions(VT);
795 SetCommonVFPExtLoadTruncStoreActions(VT, F16VecVTs);
796 SetCommonVFPExtLoadTruncStoreActions(VT, F32VecVTs);
797 }
798 }
799
800 if (Subtarget.useRVVForFixedLengthVectors()) {
801 for (MVT VT : MVT::integer_fixedlen_vector_valuetypes()) {
802 if (!useRVVForFixedLengthVectorVT(VT))
803 continue;
804
805 // By default everything must be expanded.
806 for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op)
807 setOperationAction(Op, VT, Expand);
808 for (MVT OtherVT : MVT::integer_fixedlen_vector_valuetypes()) {
809 setTruncStoreAction(VT, OtherVT, Expand);
810 setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD},
811 OtherVT, VT, Expand);
812 }
813
814 // We use EXTRACT_SUBVECTOR as a "cast" from scalable to fixed.
815 setOperationAction({ISD::INSERT_SUBVECTOR, ISD::EXTRACT_SUBVECTOR}, VT,
816 Custom);
817
818 setOperationAction({ISD::BUILD_VECTOR, ISD::CONCAT_VECTORS}, VT,
819 Custom);
820
821 setOperationAction({ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT},
822 VT, Custom);
823
824 setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom);
825
826 setOperationAction(ISD::SETCC, VT, Custom);
827
828 setOperationAction(ISD::SELECT, VT, Custom);
829
830 setOperationAction(ISD::TRUNCATE, VT, Custom);
831
832 setOperationAction(ISD::BITCAST, VT, Custom);
833
834 setOperationAction(
835 {ISD::VECREDUCE_AND, ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR}, VT,
836 Custom);
837
838 setOperationAction(
839 {ISD::VP_REDUCE_AND, ISD::VP_REDUCE_OR, ISD::VP_REDUCE_XOR}, VT,
840 Custom);
841
842 setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT,
843 ISD::FP_TO_UINT},
844 VT, Custom);
845 setOperationAction({ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT}, VT,
846 Custom);
847
848 // Operations below are different for between masks and other vectors.
849 if (VT.getVectorElementType() == MVT::i1) {
850 setOperationAction({ISD::VP_AND, ISD::VP_OR, ISD::VP_XOR, ISD::AND,
851 ISD::OR, ISD::XOR},
852 VT, Custom);
853
854 setOperationAction({ISD::VP_FP_TO_SINT, ISD::VP_FP_TO_UINT,
855 ISD::VP_SETCC, ISD::VP_TRUNCATE},
856 VT, Custom);
857 continue;
858 }
859
860 // Make SPLAT_VECTOR Legal so DAGCombine will convert splat vectors to
861 // it before type legalization for i64 vectors on RV32. It will then be
862 // type legalized to SPLAT_VECTOR_PARTS which we need to Custom handle.
863 // FIXME: Use SPLAT_VECTOR for all types? DAGCombine probably needs
864 // improvements first.
865 if (!Subtarget.is64Bit() && VT.getVectorElementType() == MVT::i64) {
866 setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
867 setOperationAction(ISD::SPLAT_VECTOR_PARTS, VT, Custom);
868 }
869
870 setOperationAction(ISD::VECTOR_SHUFFLE, VT, Custom);
871
872 setOperationAction(
873 {ISD::MLOAD, ISD::MSTORE, ISD::MGATHER, ISD::MSCATTER}, VT, Custom);
874
875 setOperationAction({ISD::VP_LOAD, ISD::VP_STORE,
876 ISD::EXPERIMENTAL_VP_STRIDED_LOAD,
877 ISD::EXPERIMENTAL_VP_STRIDED_STORE, ISD::VP_GATHER,
878 ISD::VP_SCATTER},
879 VT, Custom);
880
881 setOperationAction({ISD::ADD, ISD::MUL, ISD::SUB, ISD::AND, ISD::OR,
882 ISD::XOR, ISD::SDIV, ISD::SREM, ISD::UDIV,
883 ISD::UREM, ISD::SHL, ISD::SRA, ISD::SRL},
884 VT, Custom);
885
886 setOperationAction(
887 {ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX, ISD::ABS}, VT, Custom);
888
889 // vXi64 MULHS/MULHU requires the V extension instead of Zve64*.
890 if (VT.getVectorElementType() != MVT::i64 || Subtarget.hasStdExtV())
891 setOperationAction({ISD::MULHS, ISD::MULHU}, VT, Custom);
892
893 setOperationAction(
894 {ISD::SADDSAT, ISD::UADDSAT, ISD::SSUBSAT, ISD::USUBSAT}, VT,
895 Custom);
896
897 setOperationAction(ISD::VSELECT, VT, Custom);
898 setOperationAction(ISD::SELECT_CC, VT, Expand);
899
900 setOperationAction(
901 {ISD::ANY_EXTEND, ISD::SIGN_EXTEND, ISD::ZERO_EXTEND}, VT, Custom);
902
903 // Custom-lower reduction operations to set up the corresponding custom
904 // nodes' operands.
905 setOperationAction({ISD::VECREDUCE_ADD, ISD::VECREDUCE_SMAX,
906 ISD::VECREDUCE_SMIN, ISD::VECREDUCE_UMAX,
907 ISD::VECREDUCE_UMIN},
908 VT, Custom);
909
910 setOperationAction(IntegerVPOps, VT, Custom);
911
912 // Lower CTLZ_ZERO_UNDEF and CTTZ_ZERO_UNDEF if element of VT in the
913 // range of f32.
914 EVT FloatVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount());
915 if (isTypeLegal(FloatVT))
916 setOperationAction(
917 {ISD::CTLZ, ISD::CTLZ_ZERO_UNDEF, ISD::CTTZ_ZERO_UNDEF}, VT,
918 Custom);
919 }
920
921 for (MVT VT : MVT::fp_fixedlen_vector_valuetypes()) {
922 if (!useRVVForFixedLengthVectorVT(VT))
923 continue;
924
925 // By default everything must be expanded.
926 for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op)
927 setOperationAction(Op, VT, Expand);
928 for (MVT OtherVT : MVT::fp_fixedlen_vector_valuetypes()) {
929 setLoadExtAction(ISD::EXTLOAD, OtherVT, VT, Expand);
930 setTruncStoreAction(VT, OtherVT, Expand);
931 }
932
933 // We use EXTRACT_SUBVECTOR as a "cast" from scalable to fixed.
934 setOperationAction({ISD::INSERT_SUBVECTOR, ISD::EXTRACT_SUBVECTOR}, VT,
935 Custom);
936
937 setOperationAction({ISD::BUILD_VECTOR, ISD::CONCAT_VECTORS,
938 ISD::VECTOR_SHUFFLE, ISD::INSERT_VECTOR_ELT,
939 ISD::EXTRACT_VECTOR_ELT},
940 VT, Custom);
941
942 setOperationAction({ISD::LOAD, ISD::STORE, ISD::MLOAD, ISD::MSTORE,
943 ISD::MGATHER, ISD::MSCATTER},
944 VT, Custom);
945
946 setOperationAction({ISD::VP_LOAD, ISD::VP_STORE,
947 ISD::EXPERIMENTAL_VP_STRIDED_LOAD,
948 ISD::EXPERIMENTAL_VP_STRIDED_STORE, ISD::VP_GATHER,
949 ISD::VP_SCATTER},
950 VT, Custom);
951
952 setOperationAction({ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FDIV,
953 ISD::FNEG, ISD::FABS, ISD::FCOPYSIGN, ISD::FSQRT,
954 ISD::FMA, ISD::FMINNUM, ISD::FMAXNUM},
955 VT, Custom);
956
957 setOperationAction({ISD::FP_ROUND, ISD::FP_EXTEND}, VT, Custom);
958
959 setOperationAction({ISD::FTRUNC, ISD::FCEIL, ISD::FFLOOR, ISD::FROUND,
960 ISD::FROUNDEVEN},
961 VT, Custom);
962
963 setCondCodeAction(VFPCCToExpand, VT, Expand);
964
965 setOperationAction({ISD::VSELECT, ISD::SELECT}, VT, Custom);
966 setOperationAction(ISD::SELECT_CC, VT, Expand);
967
968 setOperationAction(ISD::BITCAST, VT, Custom);
969
970 setOperationAction(FloatingPointVecReduceOps, VT, Custom);
971
972 setOperationAction(FloatingPointVPOps, VT, Custom);
973 }
974
975 // Custom-legalize bitcasts from fixed-length vectors to scalar types.
976 setOperationAction(ISD::BITCAST, {MVT::i8, MVT::i16, MVT::i32, MVT::i64},
977 Custom);
978 if (Subtarget.hasStdExtZfhOrZfhmin())
979 setOperationAction(ISD::BITCAST, MVT::f16, Custom);
980 if (Subtarget.hasStdExtF())
981 setOperationAction(ISD::BITCAST, MVT::f32, Custom);
982 if (Subtarget.hasStdExtD())
983 setOperationAction(ISD::BITCAST, MVT::f64, Custom);
984 }
985 }
986
987 if (Subtarget.hasForcedAtomics()) {
988 // Set atomic rmw/cas operations to expand to force __sync libcalls.
989 setOperationAction(
990 {ISD::ATOMIC_CMP_SWAP, ISD::ATOMIC_SWAP, ISD::ATOMIC_LOAD_ADD,
991 ISD::ATOMIC_LOAD_SUB, ISD::ATOMIC_LOAD_AND, ISD::ATOMIC_LOAD_OR,
992 ISD::ATOMIC_LOAD_XOR, ISD::ATOMIC_LOAD_NAND, ISD::ATOMIC_LOAD_MIN,
993 ISD::ATOMIC_LOAD_MAX, ISD::ATOMIC_LOAD_UMIN, ISD::ATOMIC_LOAD_UMAX},
994 XLenVT, Expand);
995 }
996
997 // Function alignments.
998 const Align FunctionAlignment(Subtarget.hasStdExtCOrZca() ? 2 : 4);
999 setMinFunctionAlignment(FunctionAlignment);
1000 setPrefFunctionAlignment(FunctionAlignment);
1001
1002 setMinimumJumpTableEntries(5);
1003
1004 // Jumps are expensive, compared to logic
1005 setJumpIsExpensive();
1006
1007 setTargetDAGCombine({ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::AND,
1008 ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT});
1009 if (Subtarget.is64Bit())
1010 setTargetDAGCombine(ISD::SRA);
1011
1012 if (Subtarget.hasStdExtF())
1013 setTargetDAGCombine({ISD::FADD, ISD::FMAXNUM, ISD::FMINNUM});
1014
1015 if (Subtarget.hasStdExtZbb())
1016 setTargetDAGCombine({ISD::UMAX, ISD::UMIN, ISD::SMAX, ISD::SMIN});
1017
1018 if (Subtarget.hasStdExtZbs() && Subtarget.is64Bit())
1019 setTargetDAGCombine(ISD::TRUNCATE);
1020
1021 if (Subtarget.hasStdExtZbkb())
1022 setTargetDAGCombine(ISD::BITREVERSE);
1023 if (Subtarget.hasStdExtZfhOrZfhmin())
1024 setTargetDAGCombine(ISD::SIGN_EXTEND_INREG);
1025 if (Subtarget.hasStdExtF())
1026 setTargetDAGCombine({ISD::ZERO_EXTEND, ISD::FP_TO_SINT, ISD::FP_TO_UINT,
1027 ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT});
1028 if (Subtarget.hasVInstructions())
1029 setTargetDAGCombine({ISD::FCOPYSIGN, ISD::MGATHER, ISD::MSCATTER,
1030 ISD::VP_GATHER, ISD::VP_SCATTER, ISD::SRA, ISD::SRL,
1031 ISD::SHL, ISD::STORE, ISD::SPLAT_VECTOR});
1032 if (Subtarget.useRVVForFixedLengthVectors())
1033 setTargetDAGCombine(ISD::BITCAST);
1034
1035 setLibcallName(RTLIB::FPEXT_F16_F32, "__extendhfsf2");
1036 setLibcallName(RTLIB::FPROUND_F32_F16, "__truncsfhf2");
1037 }
1038
getSetCCResultType(const DataLayout & DL,LLVMContext & Context,EVT VT) const1039 EVT RISCVTargetLowering::getSetCCResultType(const DataLayout &DL,
1040 LLVMContext &Context,
1041 EVT VT) const {
1042 if (!VT.isVector())
1043 return getPointerTy(DL);
1044 if (Subtarget.hasVInstructions() &&
1045 (VT.isScalableVector() || Subtarget.useRVVForFixedLengthVectors()))
1046 return EVT::getVectorVT(Context, MVT::i1, VT.getVectorElementCount());
1047 return VT.changeVectorElementTypeToInteger();
1048 }
1049
getVPExplicitVectorLengthTy() const1050 MVT RISCVTargetLowering::getVPExplicitVectorLengthTy() const {
1051 return Subtarget.getXLenVT();
1052 }
1053
getTgtMemIntrinsic(IntrinsicInfo & Info,const CallInst & I,MachineFunction & MF,unsigned Intrinsic) const1054 bool RISCVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
1055 const CallInst &I,
1056 MachineFunction &MF,
1057 unsigned Intrinsic) const {
1058 auto &DL = I.getModule()->getDataLayout();
1059 switch (Intrinsic) {
1060 default:
1061 return false;
1062 case Intrinsic::riscv_masked_atomicrmw_xchg_i32:
1063 case Intrinsic::riscv_masked_atomicrmw_add_i32:
1064 case Intrinsic::riscv_masked_atomicrmw_sub_i32:
1065 case Intrinsic::riscv_masked_atomicrmw_nand_i32:
1066 case Intrinsic::riscv_masked_atomicrmw_max_i32:
1067 case Intrinsic::riscv_masked_atomicrmw_min_i32:
1068 case Intrinsic::riscv_masked_atomicrmw_umax_i32:
1069 case Intrinsic::riscv_masked_atomicrmw_umin_i32:
1070 case Intrinsic::riscv_masked_cmpxchg_i32:
1071 Info.opc = ISD::INTRINSIC_W_CHAIN;
1072 Info.memVT = MVT::i32;
1073 Info.ptrVal = I.getArgOperand(0);
1074 Info.offset = 0;
1075 Info.align = Align(4);
1076 Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore |
1077 MachineMemOperand::MOVolatile;
1078 return true;
1079 case Intrinsic::riscv_masked_strided_load:
1080 Info.opc = ISD::INTRINSIC_W_CHAIN;
1081 Info.ptrVal = I.getArgOperand(1);
1082 Info.memVT = getValueType(DL, I.getType()->getScalarType());
1083 Info.align = Align(DL.getTypeSizeInBits(I.getType()->getScalarType()) / 8);
1084 Info.size = MemoryLocation::UnknownSize;
1085 Info.flags |= MachineMemOperand::MOLoad;
1086 return true;
1087 case Intrinsic::riscv_masked_strided_store:
1088 Info.opc = ISD::INTRINSIC_VOID;
1089 Info.ptrVal = I.getArgOperand(1);
1090 Info.memVT =
1091 getValueType(DL, I.getArgOperand(0)->getType()->getScalarType());
1092 Info.align = Align(
1093 DL.getTypeSizeInBits(I.getArgOperand(0)->getType()->getScalarType()) /
1094 8);
1095 Info.size = MemoryLocation::UnknownSize;
1096 Info.flags |= MachineMemOperand::MOStore;
1097 return true;
1098 case Intrinsic::riscv_seg2_load:
1099 case Intrinsic::riscv_seg3_load:
1100 case Intrinsic::riscv_seg4_load:
1101 case Intrinsic::riscv_seg5_load:
1102 case Intrinsic::riscv_seg6_load:
1103 case Intrinsic::riscv_seg7_load:
1104 case Intrinsic::riscv_seg8_load:
1105 Info.opc = ISD::INTRINSIC_W_CHAIN;
1106 Info.ptrVal = I.getArgOperand(0);
1107 Info.memVT =
1108 getValueType(DL, I.getType()->getStructElementType(0)->getScalarType());
1109 Info.align =
1110 Align(DL.getTypeSizeInBits(
1111 I.getType()->getStructElementType(0)->getScalarType()) /
1112 8);
1113 Info.size = MemoryLocation::UnknownSize;
1114 Info.flags |= MachineMemOperand::MOLoad;
1115 return true;
1116 }
1117 }
1118
isLegalAddressingMode(const DataLayout & DL,const AddrMode & AM,Type * Ty,unsigned AS,Instruction * I) const1119 bool RISCVTargetLowering::isLegalAddressingMode(const DataLayout &DL,
1120 const AddrMode &AM, Type *Ty,
1121 unsigned AS,
1122 Instruction *I) const {
1123 // No global is ever allowed as a base.
1124 if (AM.BaseGV)
1125 return false;
1126
1127 // RVV instructions only support register addressing.
1128 if (Subtarget.hasVInstructions() && isa<VectorType>(Ty))
1129 return AM.HasBaseReg && AM.Scale == 0 && !AM.BaseOffs;
1130
1131 // Require a 12-bit signed offset.
1132 if (!isInt<12>(AM.BaseOffs))
1133 return false;
1134
1135 switch (AM.Scale) {
1136 case 0: // "r+i" or just "i", depending on HasBaseReg.
1137 break;
1138 case 1:
1139 if (!AM.HasBaseReg) // allow "r+i".
1140 break;
1141 return false; // disallow "r+r" or "r+r+i".
1142 default:
1143 return false;
1144 }
1145
1146 return true;
1147 }
1148
isLegalICmpImmediate(int64_t Imm) const1149 bool RISCVTargetLowering::isLegalICmpImmediate(int64_t Imm) const {
1150 return isInt<12>(Imm);
1151 }
1152
isLegalAddImmediate(int64_t Imm) const1153 bool RISCVTargetLowering::isLegalAddImmediate(int64_t Imm) const {
1154 return isInt<12>(Imm);
1155 }
1156
1157 // On RV32, 64-bit integers are split into their high and low parts and held
1158 // in two different registers, so the trunc is free since the low register can
1159 // just be used.
1160 // FIXME: Should we consider i64->i32 free on RV64 to match the EVT version of
1161 // isTruncateFree?
isTruncateFree(Type * SrcTy,Type * DstTy) const1162 bool RISCVTargetLowering::isTruncateFree(Type *SrcTy, Type *DstTy) const {
1163 if (Subtarget.is64Bit() || !SrcTy->isIntegerTy() || !DstTy->isIntegerTy())
1164 return false;
1165 unsigned SrcBits = SrcTy->getPrimitiveSizeInBits();
1166 unsigned DestBits = DstTy->getPrimitiveSizeInBits();
1167 return (SrcBits == 64 && DestBits == 32);
1168 }
1169
isTruncateFree(EVT SrcVT,EVT DstVT) const1170 bool RISCVTargetLowering::isTruncateFree(EVT SrcVT, EVT DstVT) const {
1171 // We consider i64->i32 free on RV64 since we have good selection of W
1172 // instructions that make promoting operations back to i64 free in many cases.
1173 if (SrcVT.isVector() || DstVT.isVector() || !SrcVT.isInteger() ||
1174 !DstVT.isInteger())
1175 return false;
1176 unsigned SrcBits = SrcVT.getSizeInBits();
1177 unsigned DestBits = DstVT.getSizeInBits();
1178 return (SrcBits == 64 && DestBits == 32);
1179 }
1180
isZExtFree(SDValue Val,EVT VT2) const1181 bool RISCVTargetLowering::isZExtFree(SDValue Val, EVT VT2) const {
1182 // Zexts are free if they can be combined with a load.
1183 // Don't advertise i32->i64 zextload as being free for RV64. It interacts
1184 // poorly with type legalization of compares preferring sext.
1185 if (auto *LD = dyn_cast<LoadSDNode>(Val)) {
1186 EVT MemVT = LD->getMemoryVT();
1187 if ((MemVT == MVT::i8 || MemVT == MVT::i16) &&
1188 (LD->getExtensionType() == ISD::NON_EXTLOAD ||
1189 LD->getExtensionType() == ISD::ZEXTLOAD))
1190 return true;
1191 }
1192
1193 return TargetLowering::isZExtFree(Val, VT2);
1194 }
1195
isSExtCheaperThanZExt(EVT SrcVT,EVT DstVT) const1196 bool RISCVTargetLowering::isSExtCheaperThanZExt(EVT SrcVT, EVT DstVT) const {
1197 return Subtarget.is64Bit() && SrcVT == MVT::i32 && DstVT == MVT::i64;
1198 }
1199
signExtendConstant(const ConstantInt * CI) const1200 bool RISCVTargetLowering::signExtendConstant(const ConstantInt *CI) const {
1201 return Subtarget.is64Bit() && CI->getType()->isIntegerTy(32);
1202 }
1203
isCheapToSpeculateCttz(Type * Ty) const1204 bool RISCVTargetLowering::isCheapToSpeculateCttz(Type *Ty) const {
1205 return Subtarget.hasStdExtZbb();
1206 }
1207
isCheapToSpeculateCtlz(Type * Ty) const1208 bool RISCVTargetLowering::isCheapToSpeculateCtlz(Type *Ty) const {
1209 return Subtarget.hasStdExtZbb();
1210 }
1211
isMaskAndCmp0FoldingBeneficial(const Instruction & AndI) const1212 bool RISCVTargetLowering::isMaskAndCmp0FoldingBeneficial(
1213 const Instruction &AndI) const {
1214 // We expect to be able to match a bit extraction instruction if the Zbs
1215 // extension is supported and the mask is a power of two. However, we
1216 // conservatively return false if the mask would fit in an ANDI instruction,
1217 // on the basis that it's possible the sinking+duplication of the AND in
1218 // CodeGenPrepare triggered by this hook wouldn't decrease the instruction
1219 // count and would increase code size (e.g. ANDI+BNEZ => BEXTI+BNEZ).
1220 if (!Subtarget.hasStdExtZbs())
1221 return false;
1222 ConstantInt *Mask = dyn_cast<ConstantInt>(AndI.getOperand(1));
1223 if (!Mask)
1224 return false;
1225 return !Mask->getValue().isSignedIntN(12) && Mask->getValue().isPowerOf2();
1226 }
1227
hasAndNotCompare(SDValue Y) const1228 bool RISCVTargetLowering::hasAndNotCompare(SDValue Y) const {
1229 EVT VT = Y.getValueType();
1230
1231 // FIXME: Support vectors once we have tests.
1232 if (VT.isVector())
1233 return false;
1234
1235 return (Subtarget.hasStdExtZbb() || Subtarget.hasStdExtZbkb()) &&
1236 !isa<ConstantSDNode>(Y);
1237 }
1238
hasBitTest(SDValue X,SDValue Y) const1239 bool RISCVTargetLowering::hasBitTest(SDValue X, SDValue Y) const {
1240 // Zbs provides BEXT[_I], which can be used with SEQZ/SNEZ as a bit test.
1241 if (Subtarget.hasStdExtZbs())
1242 return X.getValueType().isScalarInteger();
1243 // We can use ANDI+SEQZ/SNEZ as a bit test. Y contains the bit position.
1244 auto *C = dyn_cast<ConstantSDNode>(Y);
1245 return C && C->getAPIntValue().ule(10);
1246 }
1247
shouldFoldSelectWithIdentityConstant(unsigned Opcode,EVT VT) const1248 bool RISCVTargetLowering::shouldFoldSelectWithIdentityConstant(unsigned Opcode,
1249 EVT VT) const {
1250 // Only enable for rvv.
1251 if (!VT.isVector() || !Subtarget.hasVInstructions())
1252 return false;
1253
1254 if (VT.isFixedLengthVector() && !isTypeLegal(VT))
1255 return false;
1256
1257 return true;
1258 }
1259
shouldConvertConstantLoadToIntImm(const APInt & Imm,Type * Ty) const1260 bool RISCVTargetLowering::shouldConvertConstantLoadToIntImm(const APInt &Imm,
1261 Type *Ty) const {
1262 assert(Ty->isIntegerTy());
1263
1264 unsigned BitSize = Ty->getIntegerBitWidth();
1265 if (BitSize > Subtarget.getXLen())
1266 return false;
1267
1268 // Fast path, assume 32-bit immediates are cheap.
1269 int64_t Val = Imm.getSExtValue();
1270 if (isInt<32>(Val))
1271 return true;
1272
1273 // A constant pool entry may be more aligned thant he load we're trying to
1274 // replace. If we don't support unaligned scalar mem, prefer the constant
1275 // pool.
1276 // TODO: Can the caller pass down the alignment?
1277 if (!Subtarget.enableUnalignedScalarMem())
1278 return true;
1279
1280 // Prefer to keep the load if it would require many instructions.
1281 // This uses the same threshold we use for constant pools but doesn't
1282 // check useConstantPoolForLargeInts.
1283 // TODO: Should we keep the load only when we're definitely going to emit a
1284 // constant pool?
1285
1286 RISCVMatInt::InstSeq Seq =
1287 RISCVMatInt::generateInstSeq(Val, Subtarget.getFeatureBits());
1288 return Seq.size() <= Subtarget.getMaxBuildIntsCost();
1289 }
1290
1291 bool RISCVTargetLowering::
shouldProduceAndByConstByHoistingConstFromShiftsLHSOfAnd(SDValue X,ConstantSDNode * XC,ConstantSDNode * CC,SDValue Y,unsigned OldShiftOpcode,unsigned NewShiftOpcode,SelectionDAG & DAG) const1292 shouldProduceAndByConstByHoistingConstFromShiftsLHSOfAnd(
1293 SDValue X, ConstantSDNode *XC, ConstantSDNode *CC, SDValue Y,
1294 unsigned OldShiftOpcode, unsigned NewShiftOpcode,
1295 SelectionDAG &DAG) const {
1296 // One interesting pattern that we'd want to form is 'bit extract':
1297 // ((1 >> Y) & 1) ==/!= 0
1298 // But we also need to be careful not to try to reverse that fold.
1299
1300 // Is this '((1 >> Y) & 1)'?
1301 if (XC && OldShiftOpcode == ISD::SRL && XC->isOne())
1302 return false; // Keep the 'bit extract' pattern.
1303
1304 // Will this be '((1 >> Y) & 1)' after the transform?
1305 if (NewShiftOpcode == ISD::SRL && CC->isOne())
1306 return true; // Do form the 'bit extract' pattern.
1307
1308 // If 'X' is a constant, and we transform, then we will immediately
1309 // try to undo the fold, thus causing endless combine loop.
1310 // So only do the transform if X is not a constant. This matches the default
1311 // implementation of this function.
1312 return !XC;
1313 }
1314
canSplatOperand(unsigned Opcode,int Operand) const1315 bool RISCVTargetLowering::canSplatOperand(unsigned Opcode, int Operand) const {
1316 switch (Opcode) {
1317 case Instruction::Add:
1318 case Instruction::Sub:
1319 case Instruction::Mul:
1320 case Instruction::And:
1321 case Instruction::Or:
1322 case Instruction::Xor:
1323 case Instruction::FAdd:
1324 case Instruction::FSub:
1325 case Instruction::FMul:
1326 case Instruction::FDiv:
1327 case Instruction::ICmp:
1328 case Instruction::FCmp:
1329 return true;
1330 case Instruction::Shl:
1331 case Instruction::LShr:
1332 case Instruction::AShr:
1333 case Instruction::UDiv:
1334 case Instruction::SDiv:
1335 case Instruction::URem:
1336 case Instruction::SRem:
1337 return Operand == 1;
1338 default:
1339 return false;
1340 }
1341 }
1342
1343
canSplatOperand(Instruction * I,int Operand) const1344 bool RISCVTargetLowering::canSplatOperand(Instruction *I, int Operand) const {
1345 if (!I->getType()->isVectorTy() || !Subtarget.hasVInstructions())
1346 return false;
1347
1348 if (canSplatOperand(I->getOpcode(), Operand))
1349 return true;
1350
1351 auto *II = dyn_cast<IntrinsicInst>(I);
1352 if (!II)
1353 return false;
1354
1355 switch (II->getIntrinsicID()) {
1356 case Intrinsic::fma:
1357 case Intrinsic::vp_fma:
1358 return Operand == 0 || Operand == 1;
1359 case Intrinsic::vp_shl:
1360 case Intrinsic::vp_lshr:
1361 case Intrinsic::vp_ashr:
1362 case Intrinsic::vp_udiv:
1363 case Intrinsic::vp_sdiv:
1364 case Intrinsic::vp_urem:
1365 case Intrinsic::vp_srem:
1366 return Operand == 1;
1367 // These intrinsics are commutative.
1368 case Intrinsic::vp_add:
1369 case Intrinsic::vp_mul:
1370 case Intrinsic::vp_and:
1371 case Intrinsic::vp_or:
1372 case Intrinsic::vp_xor:
1373 case Intrinsic::vp_fadd:
1374 case Intrinsic::vp_fmul:
1375 // These intrinsics have 'vr' versions.
1376 case Intrinsic::vp_sub:
1377 case Intrinsic::vp_fsub:
1378 case Intrinsic::vp_fdiv:
1379 return Operand == 0 || Operand == 1;
1380 default:
1381 return false;
1382 }
1383 }
1384
1385 /// Check if sinking \p I's operands to I's basic block is profitable, because
1386 /// the operands can be folded into a target instruction, e.g.
1387 /// splats of scalars can fold into vector instructions.
shouldSinkOperands(Instruction * I,SmallVectorImpl<Use * > & Ops) const1388 bool RISCVTargetLowering::shouldSinkOperands(
1389 Instruction *I, SmallVectorImpl<Use *> &Ops) const {
1390 using namespace llvm::PatternMatch;
1391
1392 if (!I->getType()->isVectorTy() || !Subtarget.hasVInstructions())
1393 return false;
1394
1395 for (auto OpIdx : enumerate(I->operands())) {
1396 if (!canSplatOperand(I, OpIdx.index()))
1397 continue;
1398
1399 Instruction *Op = dyn_cast<Instruction>(OpIdx.value().get());
1400 // Make sure we are not already sinking this operand
1401 if (!Op || any_of(Ops, [&](Use *U) { return U->get() == Op; }))
1402 continue;
1403
1404 // We are looking for a splat that can be sunk.
1405 if (!match(Op, m_Shuffle(m_InsertElt(m_Undef(), m_Value(), m_ZeroInt()),
1406 m_Undef(), m_ZeroMask())))
1407 continue;
1408
1409 // All uses of the shuffle should be sunk to avoid duplicating it across gpr
1410 // and vector registers
1411 for (Use &U : Op->uses()) {
1412 Instruction *Insn = cast<Instruction>(U.getUser());
1413 if (!canSplatOperand(Insn, U.getOperandNo()))
1414 return false;
1415 }
1416
1417 Ops.push_back(&Op->getOperandUse(0));
1418 Ops.push_back(&OpIdx.value());
1419 }
1420 return true;
1421 }
1422
shouldScalarizeBinop(SDValue VecOp) const1423 bool RISCVTargetLowering::shouldScalarizeBinop(SDValue VecOp) const {
1424 unsigned Opc = VecOp.getOpcode();
1425
1426 // Assume target opcodes can't be scalarized.
1427 // TODO - do we have any exceptions?
1428 if (Opc >= ISD::BUILTIN_OP_END)
1429 return false;
1430
1431 // If the vector op is not supported, try to convert to scalar.
1432 EVT VecVT = VecOp.getValueType();
1433 if (!isOperationLegalOrCustomOrPromote(Opc, VecVT))
1434 return true;
1435
1436 // If the vector op is supported, but the scalar op is not, the transform may
1437 // not be worthwhile.
1438 EVT ScalarVT = VecVT.getScalarType();
1439 return isOperationLegalOrCustomOrPromote(Opc, ScalarVT);
1440 }
1441
isOffsetFoldingLegal(const GlobalAddressSDNode * GA) const1442 bool RISCVTargetLowering::isOffsetFoldingLegal(
1443 const GlobalAddressSDNode *GA) const {
1444 // In order to maximise the opportunity for common subexpression elimination,
1445 // keep a separate ADD node for the global address offset instead of folding
1446 // it in the global address node. Later peephole optimisations may choose to
1447 // fold it back in when profitable.
1448 return false;
1449 }
1450
isFPImmLegal(const APFloat & Imm,EVT VT,bool ForCodeSize) const1451 bool RISCVTargetLowering::isFPImmLegal(const APFloat &Imm, EVT VT,
1452 bool ForCodeSize) const {
1453 if (VT == MVT::f16 && !Subtarget.hasStdExtZfhOrZfhmin())
1454 return false;
1455 if (VT == MVT::f32 && !Subtarget.hasStdExtF())
1456 return false;
1457 if (VT == MVT::f64 && !Subtarget.hasStdExtD())
1458 return false;
1459 return Imm.isZero();
1460 }
1461
1462 // TODO: This is very conservative.
isExtractSubvectorCheap(EVT ResVT,EVT SrcVT,unsigned Index) const1463 bool RISCVTargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
1464 unsigned Index) const {
1465 if (!isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, ResVT))
1466 return false;
1467
1468 // Only support extracting a fixed from a fixed vector for now.
1469 if (ResVT.isScalableVector() || SrcVT.isScalableVector())
1470 return false;
1471
1472 unsigned ResElts = ResVT.getVectorNumElements();
1473 unsigned SrcElts = SrcVT.getVectorNumElements();
1474
1475 // Convervatively only handle extracting half of a vector.
1476 // TODO: Relax this.
1477 if ((ResElts * 2) != SrcElts)
1478 return false;
1479
1480 // The smallest type we can slide is i8.
1481 // TODO: We can extract index 0 from a mask vector without a slide.
1482 if (ResVT.getVectorElementType() == MVT::i1)
1483 return false;
1484
1485 // Slide can support arbitrary index, but we only treat vslidedown.vi as
1486 // cheap.
1487 if (Index >= 32)
1488 return false;
1489
1490 // TODO: We can do arbitrary slidedowns, but for now only support extracting
1491 // the upper half of a vector until we have more test coverage.
1492 return Index == 0 || Index == ResElts;
1493 }
1494
1495 bool RISCVTargetLowering::hasBitPreservingFPLogic(EVT VT) const {
1496 return (VT == MVT::f16 && Subtarget.hasStdExtZfhOrZfhmin()) ||
1497 (VT == MVT::f32 && Subtarget.hasStdExtF()) ||
1498 (VT == MVT::f64 && Subtarget.hasStdExtD());
1499 }
1500
1501 MVT RISCVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
1502 CallingConv::ID CC,
1503 EVT VT) const {
1504 // Use f32 to pass f16 if it is legal and Zfh/Zfhmin is not enabled.
1505 // We might still end up using a GPR but that will be decided based on ABI.
1506 if (VT == MVT::f16 && Subtarget.hasStdExtF() &&
1507 !Subtarget.hasStdExtZfhOrZfhmin())
1508 return MVT::f32;
1509
1510 return TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT);
1511 }
1512
1513 unsigned RISCVTargetLowering::getNumRegistersForCallingConv(LLVMContext &Context,
1514 CallingConv::ID CC,
1515 EVT VT) const {
1516 // Use f32 to pass f16 if it is legal and Zfh/Zfhmin is not enabled.
1517 // We might still end up using a GPR but that will be decided based on ABI.
1518 if (VT == MVT::f16 && Subtarget.hasStdExtF() &&
1519 !Subtarget.hasStdExtZfhOrZfhmin())
1520 return 1;
1521
1522 return TargetLowering::getNumRegistersForCallingConv(Context, CC, VT);
1523 }
1524
1525 // Changes the condition code and swaps operands if necessary, so the SetCC
1526 // operation matches one of the comparisons supported directly by branches
1527 // in the RISC-V ISA. May adjust compares to favor compare with 0 over compare
1528 // with 1/-1.
1529 static void translateSetCCForBranch(const SDLoc &DL, SDValue &LHS, SDValue &RHS,
1530 ISD::CondCode &CC, SelectionDAG &DAG) {
1531 // If this is a single bit test that can't be handled by ANDI, shift the
1532 // bit to be tested to the MSB and perform a signed compare with 0.
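  // For example, on RV64 (seteq (and X, 0x1000), 0), whose mask does not fit
  // in an ANDI immediate, becomes (setge (shl X, 51), 0): shifting bit 12 up
  // to bit 63 lets the sign of the shifted value encode the result.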
1533 if (isIntEqualitySetCC(CC) && isNullConstant(RHS) &&
1534 LHS.getOpcode() == ISD::AND && LHS.hasOneUse() &&
1535 isa<ConstantSDNode>(LHS.getOperand(1))) {
1536 uint64_t Mask = LHS.getConstantOperandVal(1);
1537 if ((isPowerOf2_64(Mask) || isMask_64(Mask)) && !isInt<12>(Mask)) {
1538 unsigned ShAmt = 0;
1539 if (isPowerOf2_64(Mask)) {
1540 CC = CC == ISD::SETEQ ? ISD::SETGE : ISD::SETLT;
1541 ShAmt = LHS.getValueSizeInBits() - 1 - Log2_64(Mask);
1542 } else {
1543 ShAmt = LHS.getValueSizeInBits() - llvm::bit_width(Mask);
1544 }
1545
1546 LHS = LHS.getOperand(0);
1547 if (ShAmt != 0)
1548 LHS = DAG.getNode(ISD::SHL, DL, LHS.getValueType(), LHS,
1549 DAG.getConstant(ShAmt, DL, LHS.getValueType()));
1550 return;
1551 }
1552 }
1553
1554 if (auto *RHSC = dyn_cast<ConstantSDNode>(RHS)) {
1555 int64_t C = RHSC->getSExtValue();
1556 switch (CC) {
1557 default: break;
1558 case ISD::SETGT:
1559 // Convert X > -1 to X >= 0.
1560 if (C == -1) {
1561 RHS = DAG.getConstant(0, DL, RHS.getValueType());
1562 CC = ISD::SETGE;
1563 return;
1564 }
1565 break;
1566 case ISD::SETLT:
1567 // Convert X < 1 to 0 <= X.
1568 if (C == 1) {
1569 RHS = LHS;
1570 LHS = DAG.getConstant(0, DL, RHS.getValueType());
1571 CC = ISD::SETGE;
1572 return;
1573 }
1574 break;
1575 }
1576 }
1577
1578 switch (CC) {
1579 default:
1580 break;
1581 case ISD::SETGT:
1582 case ISD::SETLE:
1583 case ISD::SETUGT:
1584 case ISD::SETULE:
1585 CC = ISD::getSetCCSwappedOperands(CC);
1586 std::swap(LHS, RHS);
1587 break;
1588 }
1589 }
1590
1591 RISCVII::VLMUL RISCVTargetLowering::getLMUL(MVT VT) {
1592 assert(VT.isScalableVector() && "Expecting a scalable vector type");
1593 unsigned KnownSize = VT.getSizeInBits().getKnownMinValue();
1594 if (VT.getVectorElementType() == MVT::i1)
1595 KnownSize *= 8;
1596
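  // For example, nxv4i32 has a known minimum size of 4 * 32 = 128 bits and
  // maps to LMUL_2 (RVVBitsPerBlock is 64); nxv8i1 is scaled up to 64 bits and
  // shares the LMUL_1 grouping of nxv8i8.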
1597 switch (KnownSize) {
1598 default:
1599 llvm_unreachable("Invalid LMUL.");
1600 case 8:
1601 return RISCVII::VLMUL::LMUL_F8;
1602 case 16:
1603 return RISCVII::VLMUL::LMUL_F4;
1604 case 32:
1605 return RISCVII::VLMUL::LMUL_F2;
1606 case 64:
1607 return RISCVII::VLMUL::LMUL_1;
1608 case 128:
1609 return RISCVII::VLMUL::LMUL_2;
1610 case 256:
1611 return RISCVII::VLMUL::LMUL_4;
1612 case 512:
1613 return RISCVII::VLMUL::LMUL_8;
1614 }
1615 }
1616
1617 unsigned RISCVTargetLowering::getRegClassIDForLMUL(RISCVII::VLMUL LMul) {
1618 switch (LMul) {
1619 default:
1620 llvm_unreachable("Invalid LMUL.");
1621 case RISCVII::VLMUL::LMUL_F8:
1622 case RISCVII::VLMUL::LMUL_F4:
1623 case RISCVII::VLMUL::LMUL_F2:
1624 case RISCVII::VLMUL::LMUL_1:
1625 return RISCV::VRRegClassID;
1626 case RISCVII::VLMUL::LMUL_2:
1627 return RISCV::VRM2RegClassID;
1628 case RISCVII::VLMUL::LMUL_4:
1629 return RISCV::VRM4RegClassID;
1630 case RISCVII::VLMUL::LMUL_8:
1631 return RISCV::VRM8RegClassID;
1632 }
1633 }
1634
1635 unsigned RISCVTargetLowering::getSubregIndexByMVT(MVT VT, unsigned Index) {
1636 RISCVII::VLMUL LMUL = getLMUL(VT);
1637 if (LMUL == RISCVII::VLMUL::LMUL_F8 ||
1638 LMUL == RISCVII::VLMUL::LMUL_F4 ||
1639 LMUL == RISCVII::VLMUL::LMUL_F2 ||
1640 LMUL == RISCVII::VLMUL::LMUL_1) {
1641 static_assert(RISCV::sub_vrm1_7 == RISCV::sub_vrm1_0 + 7,
1642 "Unexpected subreg numbering");
1643 return RISCV::sub_vrm1_0 + Index;
1644 }
1645 if (LMUL == RISCVII::VLMUL::LMUL_2) {
1646 static_assert(RISCV::sub_vrm2_3 == RISCV::sub_vrm2_0 + 3,
1647 "Unexpected subreg numbering");
1648 return RISCV::sub_vrm2_0 + Index;
1649 }
1650 if (LMUL == RISCVII::VLMUL::LMUL_4) {
1651 static_assert(RISCV::sub_vrm4_1 == RISCV::sub_vrm4_0 + 1,
1652 "Unexpected subreg numbering");
1653 return RISCV::sub_vrm4_0 + Index;
1654 }
1655 llvm_unreachable("Invalid vector type.");
1656 }
1657
1658 unsigned RISCVTargetLowering::getRegClassIDForVecVT(MVT VT) {
1659 if (VT.getVectorElementType() == MVT::i1)
1660 return RISCV::VRRegClassID;
1661 return getRegClassIDForLMUL(getLMUL(VT));
1662 }
1663
1664 // Attempt to decompose a subvector insert/extract between VecVT and
1665 // SubVecVT via subregister indices. Returns the subregister index that
1666 // can perform the subvector insert/extract with the given element index, as
1667 // well as the index corresponding to any leftover subvectors that must be
1668 // further inserted/extracted within the register class for SubVecVT.
1669 std::pair<unsigned, unsigned>
1670 RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
1671 MVT VecVT, MVT SubVecVT, unsigned InsertExtractIdx,
1672 const RISCVRegisterInfo *TRI) {
1673 static_assert((RISCV::VRM8RegClassID > RISCV::VRM4RegClassID &&
1674 RISCV::VRM4RegClassID > RISCV::VRM2RegClassID &&
1675 RISCV::VRM2RegClassID > RISCV::VRRegClassID),
1676 "Register classes not ordered");
1677 unsigned VecRegClassID = getRegClassIDForVecVT(VecVT);
1678 unsigned SubRegClassID = getRegClassIDForVecVT(SubVecVT);
1679 // Try to compose a subregister index that takes us from the incoming
1680 // LMUL>1 register class down to the outgoing one. At each step we half
1681 // the LMUL:
1682 // nxv16i32@12 -> nxv2i32: sub_vrm4_1_then_sub_vrm2_1_then_sub_vrm1_0
1683 // Note that this is not guaranteed to find a subregister index, such as
1684 // when we are extracting from one VR type to another.
1685 unsigned SubRegIdx = RISCV::NoSubRegister;
1686 for (const unsigned RCID :
1687 {RISCV::VRM4RegClassID, RISCV::VRM2RegClassID, RISCV::VRRegClassID})
1688 if (VecRegClassID > RCID && SubRegClassID <= RCID) {
1689 VecVT = VecVT.getHalfNumVectorElementsVT();
1690 bool IsHi =
1691 InsertExtractIdx >= VecVT.getVectorElementCount().getKnownMinValue();
1692 SubRegIdx = TRI->composeSubRegIndices(SubRegIdx,
1693 getSubregIndexByMVT(VecVT, IsHi));
1694 if (IsHi)
1695 InsertExtractIdx -= VecVT.getVectorElementCount().getKnownMinValue();
1696 }
1697 return {SubRegIdx, InsertExtractIdx};
1698 }
1699
1700 // Permit combining of mask vectors as BUILD_VECTOR never expands to scalar
1701 // stores for those types.
1702 bool RISCVTargetLowering::mergeStoresAfterLegalization(EVT VT) const {
1703 return !Subtarget.useRVVForFixedLengthVectors() ||
1704 (VT.isFixedLengthVector() && VT.getVectorElementType() == MVT::i1);
1705 }
1706
1707 bool RISCVTargetLowering::isLegalElementTypeForRVV(Type *ScalarTy) const {
1708 if (ScalarTy->isPointerTy())
1709 return true;
1710
1711 if (ScalarTy->isIntegerTy(8) || ScalarTy->isIntegerTy(16) ||
1712 ScalarTy->isIntegerTy(32))
1713 return true;
1714
1715 if (ScalarTy->isIntegerTy(64))
1716 return Subtarget.hasVInstructionsI64();
1717
1718 if (ScalarTy->isHalfTy())
1719 return Subtarget.hasVInstructionsF16();
1720 if (ScalarTy->isFloatTy())
1721 return Subtarget.hasVInstructionsF32();
1722 if (ScalarTy->isDoubleTy())
1723 return Subtarget.hasVInstructionsF64();
1724
1725 return false;
1726 }
1727
1728 unsigned RISCVTargetLowering::combineRepeatedFPDivisors() const {
1729 return NumRepeatedDivisors;
1730 }
1731
1732 static SDValue getVLOperand(SDValue Op) {
1733 assert((Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
1734 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN) &&
1735 "Unexpected opcode");
1736 bool HasChain = Op.getOpcode() == ISD::INTRINSIC_W_CHAIN;
1737 unsigned IntNo = Op.getConstantOperandVal(HasChain ? 1 : 0);
1738 const RISCVVIntrinsicsTable::RISCVVIntrinsicInfo *II =
1739 RISCVVIntrinsicsTable::getRISCVVIntrinsicInfo(IntNo);
1740 if (!II)
1741 return SDValue();
1742 return Op.getOperand(II->VLOperand + 1 + HasChain);
1743 }
1744
1745 static bool useRVVForFixedLengthVectorVT(MVT VT,
1746 const RISCVSubtarget &Subtarget) {
1747 assert(VT.isFixedLengthVector() && "Expected a fixed length vector type!");
1748 if (!Subtarget.useRVVForFixedLengthVectors())
1749 return false;
1750
1751 // We only support a set of vector types with a consistent maximum fixed size
1752 // across all supported vector element types to avoid legalization issues.
1753 // Therefore -- since the largest is v1024i8/v512i16/etc -- the largest
1754 // fixed-length vector type we support is 1024 bytes.
1755 if (VT.getFixedSizeInBits() > 1024 * 8)
1756 return false;
1757
1758 unsigned MinVLen = Subtarget.getRealMinVLen();
1759
1760 MVT EltVT = VT.getVectorElementType();
1761
1762 // Don't use RVV for vectors we cannot scalarize if required.
1763 switch (EltVT.SimpleTy) {
1764 // i1 is supported but has different rules.
1765 default:
1766 return false;
1767 case MVT::i1:
1768 // Masks can only use a single register.
1769 if (VT.getVectorNumElements() > MinVLen)
1770 return false;
1771 MinVLen /= 8;
1772 break;
1773 case MVT::i8:
1774 case MVT::i16:
1775 case MVT::i32:
1776 break;
1777 case MVT::i64:
1778 if (!Subtarget.hasVInstructionsI64())
1779 return false;
1780 break;
1781 case MVT::f16:
1782 if (!Subtarget.hasVInstructionsF16())
1783 return false;
1784 break;
1785 case MVT::f32:
1786 if (!Subtarget.hasVInstructionsF32())
1787 return false;
1788 break;
1789 case MVT::f64:
1790 if (!Subtarget.hasVInstructionsF64())
1791 return false;
1792 break;
1793 }
1794
1795 // Reject elements larger than ELEN.
1796 if (EltVT.getSizeInBits() > Subtarget.getELEN())
1797 return false;
1798
1799 unsigned LMul = divideCeil(VT.getSizeInBits(), MinVLen);
1800 // Don't use RVV for types that don't fit.
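  // E.g., with MinVLen=128 and an LMUL cap of 8, fixed-length vectors wider
  // than 1024 bits (v32i32, v128i8, ...) are rejected here.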
1801 if (LMul > Subtarget.getMaxLMULForFixedLengthVectors())
1802 return false;
1803
1804 // TODO: Perhaps an artificial restriction, but worth having whilst getting
1805 // the base fixed length RVV support in place.
1806 if (!VT.isPow2VectorType())
1807 return false;
1808
1809 return true;
1810 }
1811
1812 bool RISCVTargetLowering::useRVVForFixedLengthVectorVT(MVT VT) const {
1813 return ::useRVVForFixedLengthVectorVT(VT, Subtarget);
1814 }
1815
1816 // Return the largest legal scalable vector type that matches VT's element type.
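// For example, with VLEN=128 and ELEN=64, v4i32 (128 bits) is contained in
// nxv2i32, which occupies exactly one vector register, while v2i32 gets the
// fractional-LMUL container nxv1i32.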
1817 static MVT getContainerForFixedLengthVector(const TargetLowering &TLI, MVT VT,
1818 const RISCVSubtarget &Subtarget) {
1819 // This may be called before legal types are setup.
1820 assert(((VT.isFixedLengthVector() && TLI.isTypeLegal(VT)) ||
1821 useRVVForFixedLengthVectorVT(VT, Subtarget)) &&
1822 "Expected legal fixed length vector!");
1823
1824 unsigned MinVLen = Subtarget.getRealMinVLen();
1825 unsigned MaxELen = Subtarget.getELEN();
1826
1827 MVT EltVT = VT.getVectorElementType();
1828 switch (EltVT.SimpleTy) {
1829 default:
1830 llvm_unreachable("unexpected element type for RVV container");
1831 case MVT::i1:
1832 case MVT::i8:
1833 case MVT::i16:
1834 case MVT::i32:
1835 case MVT::i64:
1836 case MVT::f16:
1837 case MVT::f32:
1838 case MVT::f64: {
1839 // We prefer to use LMUL=1 for VLEN sized types. Use fractional lmuls for
1840 // narrower types. The smallest fractional LMUL we support is 8/ELEN. Within
1841 // each fractional LMUL we support SEW between 8 and LMUL*ELEN.
1842 unsigned NumElts =
1843 (VT.getVectorNumElements() * RISCV::RVVBitsPerBlock) / MinVLen;
1844 NumElts = std::max(NumElts, RISCV::RVVBitsPerBlock / MaxELen);
1845 assert(isPowerOf2_32(NumElts) && "Expected power of 2 NumElts");
1846 return MVT::getScalableVectorVT(EltVT, NumElts);
1847 }
1848 }
1849 }
1850
1851 static MVT getContainerForFixedLengthVector(SelectionDAG &DAG, MVT VT,
1852 const RISCVSubtarget &Subtarget) {
1853 return getContainerForFixedLengthVector(DAG.getTargetLoweringInfo(), VT,
1854 Subtarget);
1855 }
1856
1857 MVT RISCVTargetLowering::getContainerForFixedLengthVector(MVT VT) const {
1858 return ::getContainerForFixedLengthVector(*this, VT, getSubtarget());
1859 }
1860
1861 // Grow V to consume an entire RVV register.
1862 static SDValue convertToScalableVector(EVT VT, SDValue V, SelectionDAG &DAG,
1863 const RISCVSubtarget &Subtarget) {
1864 assert(VT.isScalableVector() &&
1865 "Expected to convert into a scalable vector!");
1866 assert(V.getValueType().isFixedLengthVector() &&
1867 "Expected a fixed length vector operand!");
1868 SDLoc DL(V);
1869 SDValue Zero = DAG.getConstant(0, DL, Subtarget.getXLenVT());
1870 return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, DAG.getUNDEF(VT), V, Zero);
1871 }
1872
1873 // Shrink V so it's just big enough to maintain a VT's worth of data.
1874 static SDValue convertFromScalableVector(EVT VT, SDValue V, SelectionDAG &DAG,
1875 const RISCVSubtarget &Subtarget) {
1876 assert(VT.isFixedLengthVector() &&
1877 "Expected to convert into a fixed length vector!");
1878 assert(V.getValueType().isScalableVector() &&
1879 "Expected a scalable vector operand!");
1880 SDLoc DL(V);
1881 SDValue Zero = DAG.getConstant(0, DL, Subtarget.getXLenVT());
1882 return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, V, Zero);
1883 }
1884
1885 /// Return the mask type suitable for masking the provided vector type.
1886 /// This is simply an i1 element type vector of the same (possibly scalable)
1887 /// length.
1888 static MVT getMaskTypeFor(MVT VecVT) {
1889 assert(VecVT.isVector());
1890 ElementCount EC = VecVT.getVectorElementCount();
1891 return MVT::getVectorVT(MVT::i1, EC);
1892 }
1893
1894 /// Creates an all-ones mask suitable for masking a vector of type VecVT with
1895 /// vector length VL.
1896 static SDValue getAllOnesMask(MVT VecVT, SDValue VL, SDLoc DL,
1897 SelectionDAG &DAG) {
1898 MVT MaskVT = getMaskTypeFor(VecVT);
1899 return DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL);
1900 }
1901
1902 static SDValue getVLOp(uint64_t NumElts, SDLoc DL, SelectionDAG &DAG,
1903 const RISCVSubtarget &Subtarget) {
1904 return DAG.getConstant(NumElts, DL, Subtarget.getXLenVT());
1905 }
1906
1907 static std::pair<SDValue, SDValue>
1908 getDefaultVLOps(uint64_t NumElts, MVT ContainerVT, SDLoc DL, SelectionDAG &DAG,
1909 const RISCVSubtarget &Subtarget) {
1910 assert(ContainerVT.isScalableVector() && "Expecting scalable container type");
1911 SDValue VL = getVLOp(NumElts, DL, DAG, Subtarget);
1912 SDValue Mask = getAllOnesMask(ContainerVT, VL, DL, DAG);
1913 return {Mask, VL};
1914 }
1915
1916 // Gets the two common "VL" operands: an all-ones mask and the vector length.
1917 // VecVT is a vector type, either fixed-length or scalable, and ContainerVT is
1918 // the vector type that the fixed-length vector is contained in. Otherwise if
1919 // VecVT is scalable, then ContainerVT should be the same as VecVT.
1920 static std::pair<SDValue, SDValue>
1921 getDefaultVLOps(MVT VecVT, MVT ContainerVT, SDLoc DL, SelectionDAG &DAG,
1922 const RISCVSubtarget &Subtarget) {
1923 if (VecVT.isFixedLengthVector())
1924 return getDefaultVLOps(VecVT.getVectorNumElements(), ContainerVT, DL, DAG,
1925 Subtarget);
1926 assert(ContainerVT.isScalableVector() && "Expecting scalable container type");
1927 MVT XLenVT = Subtarget.getXLenVT();
1928 SDValue VL = DAG.getRegister(RISCV::X0, XLenVT);
1929 SDValue Mask = getAllOnesMask(ContainerVT, VL, DL, DAG);
1930 return {Mask, VL};
1931 }
1932
1933 // As above but assuming the given type is a scalable vector type.
1934 static std::pair<SDValue, SDValue>
1935 getDefaultScalableVLOps(MVT VecVT, SDLoc DL, SelectionDAG &DAG,
1936 const RISCVSubtarget &Subtarget) {
1937 assert(VecVT.isScalableVector() && "Expecting a scalable vector");
1938 return getDefaultVLOps(VecVT, VecVT, DL, DAG, Subtarget);
1939 }
1940
1941 // The state of RVV BUILD_VECTOR and VECTOR_SHUFFLE lowering is that very few
1942 // of either is (currently) supported. This can get us into an infinite loop
1943 // where we try to lower a BUILD_VECTOR as a VECTOR_SHUFFLE as a BUILD_VECTOR
1944 // as a ..., etc.
1945 // Until either (or both) of these can reliably lower any node, reporting that
1946 // we don't want to expand BUILD_VECTORs via VECTOR_SHUFFLEs at least breaks
1947 // the infinite loop. Note that this lowers BUILD_VECTOR through the stack,
1948 // which is not desirable.
1949 bool RISCVTargetLowering::shouldExpandBuildVectorWithShuffles(
1950 EVT VT, unsigned DefinedValues) const {
1951 return false;
1952 }
1953
1954 static SDValue lowerFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG,
1955 const RISCVSubtarget &Subtarget) {
1956 // RISCV FP-to-int conversions saturate to the destination register size, but
1957 // don't produce 0 for nan. We can use a conversion instruction and fix the
1958 // nan case with a compare and a select.
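  // Scalar sketch for (fp_to_sint_sat f32:X to i32):
  //   Y = fcvt(X, RTZ)           // saturating conversion
  //   result = isnan(X) ? 0 : Y  // the SETUO select emitted below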
1959 SDValue Src = Op.getOperand(0);
1960
1961 MVT DstVT = Op.getSimpleValueType();
1962 EVT SatVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
1963
1964 bool IsSigned = Op.getOpcode() == ISD::FP_TO_SINT_SAT;
1965
1966 if (!DstVT.isVector()) {
1967     // In the absence of Zfh, promote f16 to f32, then saturate the result.
1968 if (Src.getSimpleValueType() == MVT::f16 && !Subtarget.hasStdExtZfh()) {
1969 Src = DAG.getNode(ISD::FP_EXTEND, SDLoc(Op), MVT::f32, Src);
1970 }
1971
1972 unsigned Opc;
1973 if (SatVT == DstVT)
1974 Opc = IsSigned ? RISCVISD::FCVT_X : RISCVISD::FCVT_XU;
1975 else if (DstVT == MVT::i64 && SatVT == MVT::i32)
1976 Opc = IsSigned ? RISCVISD::FCVT_W_RV64 : RISCVISD::FCVT_WU_RV64;
1977 else
1978 return SDValue();
1979 // FIXME: Support other SatVTs by clamping before or after the conversion.
1980
1981 SDLoc DL(Op);
1982 SDValue FpToInt = DAG.getNode(
1983 Opc, DL, DstVT, Src,
1984 DAG.getTargetConstant(RISCVFPRndMode::RTZ, DL, Subtarget.getXLenVT()));
1985
1986 if (Opc == RISCVISD::FCVT_WU_RV64)
1987 FpToInt = DAG.getZeroExtendInReg(FpToInt, DL, MVT::i32);
1988
1989 SDValue ZeroInt = DAG.getConstant(0, DL, DstVT);
1990 return DAG.getSelectCC(DL, Src, Src, ZeroInt, FpToInt,
1991 ISD::CondCode::SETUO);
1992 }
1993
1994 // Vectors.
1995
1996 MVT DstEltVT = DstVT.getVectorElementType();
1997 MVT SrcVT = Src.getSimpleValueType();
1998 MVT SrcEltVT = SrcVT.getVectorElementType();
1999 unsigned SrcEltSize = SrcEltVT.getSizeInBits();
2000 unsigned DstEltSize = DstEltVT.getSizeInBits();
2001
2002 // Only handle saturating to the destination type.
2003 if (SatVT != DstEltVT)
2004 return SDValue();
2005
2006   // FIXME: Don't support narrowing by more than 1 step for now.
2007 if (SrcEltSize > (2 * DstEltSize))
2008 return SDValue();
2009
2010 MVT DstContainerVT = DstVT;
2011 MVT SrcContainerVT = SrcVT;
2012 if (DstVT.isFixedLengthVector()) {
2013 DstContainerVT = getContainerForFixedLengthVector(DAG, DstVT, Subtarget);
2014 SrcContainerVT = getContainerForFixedLengthVector(DAG, SrcVT, Subtarget);
2015 assert(DstContainerVT.getVectorElementCount() ==
2016 SrcContainerVT.getVectorElementCount() &&
2017 "Expected same element count");
2018 Src = convertToScalableVector(SrcContainerVT, Src, DAG, Subtarget);
2019 }
2020
2021 SDLoc DL(Op);
2022
2023 auto [Mask, VL] = getDefaultVLOps(DstVT, DstContainerVT, DL, DAG, Subtarget);
2024
2025 SDValue IsNan = DAG.getNode(RISCVISD::SETCC_VL, DL, Mask.getValueType(),
2026 {Src, Src, DAG.getCondCode(ISD::SETNE),
2027 DAG.getUNDEF(Mask.getValueType()), Mask, VL});
2028
2029 // Need to widen by more than 1 step, promote the FP type, then do a widening
2030 // convert.
2031 if (DstEltSize > (2 * SrcEltSize)) {
2032 assert(SrcContainerVT.getVectorElementType() == MVT::f16 && "Unexpected VT!");
2033 MVT InterVT = SrcContainerVT.changeVectorElementType(MVT::f32);
2034 Src = DAG.getNode(RISCVISD::FP_EXTEND_VL, DL, InterVT, Src, Mask, VL);
2035 }
2036
2037 unsigned RVVOpc =
2038 IsSigned ? RISCVISD::VFCVT_RTZ_X_F_VL : RISCVISD::VFCVT_RTZ_XU_F_VL;
2039 SDValue Res = DAG.getNode(RVVOpc, DL, DstContainerVT, Src, Mask, VL);
2040
2041 SDValue SplatZero = DAG.getNode(
2042 RISCVISD::VMV_V_X_VL, DL, DstContainerVT, DAG.getUNDEF(DstContainerVT),
2043 DAG.getConstant(0, DL, Subtarget.getXLenVT()), VL);
2044 Res = DAG.getNode(RISCVISD::VSELECT_VL, DL, DstContainerVT, IsNan, SplatZero,
2045 Res, VL);
2046
2047 if (DstVT.isFixedLengthVector())
2048 Res = convertFromScalableVector(DstVT, Res, DAG, Subtarget);
2049
2050 return Res;
2051 }
2052
2053 static RISCVFPRndMode::RoundingMode matchRoundingOp(unsigned Opc) {
2054 switch (Opc) {
2055 case ISD::FROUNDEVEN:
2056 case ISD::VP_FROUNDEVEN:
2057 return RISCVFPRndMode::RNE;
2058 case ISD::FTRUNC:
2059 case ISD::VP_FROUNDTOZERO:
2060 return RISCVFPRndMode::RTZ;
2061 case ISD::FFLOOR:
2062 case ISD::VP_FFLOOR:
2063 return RISCVFPRndMode::RDN;
2064 case ISD::FCEIL:
2065 case ISD::VP_FCEIL:
2066 return RISCVFPRndMode::RUP;
2067 case ISD::FROUND:
2068 case ISD::VP_FROUND:
2069 return RISCVFPRndMode::RMM;
2070 case ISD::FRINT:
2071 return RISCVFPRndMode::DYN;
2072 }
2073
2074 return RISCVFPRndMode::Invalid;
2075 }
2076
2077 // Expand vector FTRUNC, FCEIL, FFLOOR, FROUND, VP_FCEIL, VP_FFLOOR, VP_FROUND,
2078 // VP_FROUNDEVEN, VP_FROUNDTOZERO, VP_FRINT and VP_FNEARBYINT by converting to
2079 // the integer domain and back. Taking care to avoid converting values that are
2080 // nan or already correct.
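// Rough shape of the expansion, e.g. for FFLOOR on a float vector X:
//   M   = (|X| < MaxVal) & Mask   // lanes that may still have fraction bits
//   I   = vfcvt_x(X, RDN) under M
//   R   = sint_to_fp(I)
//   Res = copysign(R, X), merged with X on lanes not in M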
2081 static SDValue
2082 lowerVectorFTRUNC_FCEIL_FFLOOR_FROUND(SDValue Op, SelectionDAG &DAG,
2083 const RISCVSubtarget &Subtarget) {
2084 MVT VT = Op.getSimpleValueType();
2085 assert(VT.isVector() && "Unexpected type");
2086
2087 SDLoc DL(Op);
2088
2089 SDValue Src = Op.getOperand(0);
2090
2091 MVT ContainerVT = VT;
2092 if (VT.isFixedLengthVector()) {
2093 ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
2094 Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
2095 }
2096
2097 SDValue Mask, VL;
2098 if (Op->isVPOpcode()) {
2099 Mask = Op.getOperand(1);
2100 VL = Op.getOperand(2);
2101 } else {
2102 std::tie(Mask, VL) = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
2103 }
2104
2105 // Freeze the source since we are increasing the number of uses.
2106 Src = DAG.getFreeze(Src);
2107
2108 // We do the conversion on the absolute value and fix the sign at the end.
2109 SDValue Abs = DAG.getNode(RISCVISD::FABS_VL, DL, ContainerVT, Src, Mask, VL);
2110
2111 // Determine the largest integer that can be represented exactly. This and
2112 // values larger than it don't have any fractional bits so don't need to
2113 // be converted.
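  // (The threshold computed here is 2^(precision - 1): 2^23 for f32 and 2^52
  // for f64.)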
2114 const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(ContainerVT);
2115 unsigned Precision = APFloat::semanticsPrecision(FltSem);
2116 APFloat MaxVal = APFloat(FltSem);
2117 MaxVal.convertFromAPInt(APInt::getOneBitSet(Precision, Precision - 1),
2118 /*IsSigned*/ false, APFloat::rmNearestTiesToEven);
2119 SDValue MaxValNode =
2120 DAG.getConstantFP(MaxVal, DL, ContainerVT.getVectorElementType());
2121 SDValue MaxValSplat = DAG.getNode(RISCVISD::VFMV_V_F_VL, DL, ContainerVT,
2122 DAG.getUNDEF(ContainerVT), MaxValNode, VL);
2123
2124 // If abs(Src) was larger than MaxVal or nan, keep it.
2125 MVT SetccVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
2126 Mask =
2127 DAG.getNode(RISCVISD::SETCC_VL, DL, SetccVT,
2128 {Abs, MaxValSplat, DAG.getCondCode(ISD::SETOLT),
2129 Mask, Mask, VL});
2130
2131 // Truncate to integer and convert back to FP.
2132 MVT IntVT = ContainerVT.changeVectorElementTypeToInteger();
2133 MVT XLenVT = Subtarget.getXLenVT();
2134 SDValue Truncated;
2135
2136 switch (Op.getOpcode()) {
2137 default:
2138 llvm_unreachable("Unexpected opcode");
2139 case ISD::FCEIL:
2140 case ISD::VP_FCEIL:
2141 case ISD::FFLOOR:
2142 case ISD::VP_FFLOOR:
2143 case ISD::FROUND:
2144 case ISD::FROUNDEVEN:
2145 case ISD::VP_FROUND:
2146 case ISD::VP_FROUNDEVEN:
2147 case ISD::VP_FROUNDTOZERO: {
2148 RISCVFPRndMode::RoundingMode FRM = matchRoundingOp(Op.getOpcode());
2149 assert(FRM != RISCVFPRndMode::Invalid);
2150 Truncated = DAG.getNode(RISCVISD::VFCVT_RM_X_F_VL, DL, IntVT, Src, Mask,
2151 DAG.getTargetConstant(FRM, DL, XLenVT), VL);
2152 break;
2153 }
2154 case ISD::FTRUNC:
2155 Truncated = DAG.getNode(RISCVISD::VFCVT_RTZ_X_F_VL, DL, IntVT, Src,
2156 Mask, VL);
2157 break;
2158 case ISD::VP_FRINT:
2159 Truncated = DAG.getNode(RISCVISD::VFCVT_X_F_VL, DL, IntVT, Src, Mask, VL);
2160 break;
2161 case ISD::VP_FNEARBYINT:
2162 Truncated = DAG.getNode(RISCVISD::VFROUND_NOEXCEPT_VL, DL, ContainerVT, Src,
2163 Mask, VL);
2164 break;
2165 }
2166
2167 // VFROUND_NOEXCEPT_VL includes SINT_TO_FP_VL.
2168 if (Op.getOpcode() != ISD::VP_FNEARBYINT)
2169 Truncated = DAG.getNode(RISCVISD::SINT_TO_FP_VL, DL, ContainerVT, Truncated,
2170 Mask, VL);
2171
2172 // Restore the original sign so that -0.0 is preserved.
2173 Truncated = DAG.getNode(RISCVISD::FCOPYSIGN_VL, DL, ContainerVT, Truncated,
2174 Src, Src, Mask, VL);
2175
2176 if (!VT.isFixedLengthVector())
2177 return Truncated;
2178
2179 return convertFromScalableVector(VT, Truncated, DAG, Subtarget);
2180 }
2181
2182 static SDValue
2183 lowerFTRUNC_FCEIL_FFLOOR_FROUND(SDValue Op, SelectionDAG &DAG,
2184 const RISCVSubtarget &Subtarget) {
2185 MVT VT = Op.getSimpleValueType();
2186 if (VT.isVector())
2187 return lowerVectorFTRUNC_FCEIL_FFLOOR_FROUND(Op, DAG, Subtarget);
2188
2189 if (DAG.shouldOptForSize())
2190 return SDValue();
2191
2192 SDLoc DL(Op);
2193 SDValue Src = Op.getOperand(0);
2194
2195 // Create an integer the size of the mantissa with the MSB set. This and all
2196 // values larger than it don't have any fractional bits so don't need to be
2197 // converted.
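  // (That is 2^(precision - 1), e.g. 2^52 for f64.)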
2198 const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(VT);
2199 unsigned Precision = APFloat::semanticsPrecision(FltSem);
2200 APFloat MaxVal = APFloat(FltSem);
2201 MaxVal.convertFromAPInt(APInt::getOneBitSet(Precision, Precision - 1),
2202 /*IsSigned*/ false, APFloat::rmNearestTiesToEven);
2203 SDValue MaxValNode = DAG.getConstantFP(MaxVal, DL, VT);
2204
2205 RISCVFPRndMode::RoundingMode FRM = matchRoundingOp(Op.getOpcode());
2206 return DAG.getNode(RISCVISD::FROUND, DL, VT, Src, MaxValNode,
2207 DAG.getTargetConstant(FRM, DL, Subtarget.getXLenVT()));
2208 }
2209
2210 struct VIDSequence {
2211 int64_t StepNumerator;
2212 unsigned StepDenominator;
2213 int64_t Addend;
2214 };
2215
2216 static std::optional<uint64_t> getExactInteger(const APFloat &APF,
2217 uint32_t BitWidth) {
2218 APSInt ValInt(BitWidth, !APF.isNegative());
2219   // We use an arbitrary rounding mode here. If a floating-point value is an
2220   // exact integer (e.g., 1.0), the rounding mode does not affect the output
2221   // value. If the rounding mode changes the output value, then it is not an
2222   // exact integer.
2223 RoundingMode ArbitraryRM = RoundingMode::TowardZero;
2224 bool IsExact;
2225 // If it is out of signed integer range, it will return an invalid operation.
2226 // If it is not an exact integer, IsExact is false.
2227 if ((APF.convertToInteger(ValInt, ArbitraryRM, &IsExact) ==
2228 APFloatBase::opInvalidOp) ||
2229 !IsExact)
2230 return std::nullopt;
2231 return ValInt.extractBitsAsZExtValue(BitWidth, 0);
2232 }
2233
2234 // Try to match an arithmetic-sequence BUILD_VECTOR [X,X+S,X+2*S,...,X+(N-1)*S]
2235 // to the (non-zero) step S and start value X. This can be then lowered as the
2236 // RVV sequence (VID * S) + X, for example.
2237 // The step S is represented as an integer numerator divided by a positive
2238 // denominator. Note that the implementation currently only identifies
2239 // sequences in which either the numerator is +/- 1 or the denominator is 1. It
2240 // cannot detect 2/3, for example.
2241 // Note that this method will also match potentially unappealing index
2242 // sequences, like <i32 0, i32 50939494>, however it is left to the caller to
2243 // determine whether this is worth generating code for.
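// For example, <i32 1, i32 3, i32 5, i32 7> matches with StepNumerator=2,
// StepDenominator=1 and Addend=1, while <i32 0, i32 0, i32 1, i32 1> matches
// with StepNumerator=1, StepDenominator=2 and Addend=0.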
2244 static std::optional<VIDSequence> isSimpleVIDSequence(SDValue Op) {
2245 unsigned NumElts = Op.getNumOperands();
2246 assert(Op.getOpcode() == ISD::BUILD_VECTOR && "Unexpected BUILD_VECTOR");
2247 bool IsInteger = Op.getValueType().isInteger();
2248
2249 std::optional<unsigned> SeqStepDenom;
2250 std::optional<int64_t> SeqStepNum, SeqAddend;
2251 std::optional<std::pair<uint64_t, unsigned>> PrevElt;
2252 unsigned EltSizeInBits = Op.getValueType().getScalarSizeInBits();
2253 for (unsigned Idx = 0; Idx < NumElts; Idx++) {
2254 // Assume undef elements match the sequence; we just have to be careful
2255 // when interpolating across them.
2256 if (Op.getOperand(Idx).isUndef())
2257 continue;
2258
2259 uint64_t Val;
2260 if (IsInteger) {
2261 // The BUILD_VECTOR must be all constants.
2262 if (!isa<ConstantSDNode>(Op.getOperand(Idx)))
2263 return std::nullopt;
2264 Val = Op.getConstantOperandVal(Idx) &
2265 maskTrailingOnes<uint64_t>(EltSizeInBits);
2266 } else {
2267 // The BUILD_VECTOR must be all constants.
2268 if (!isa<ConstantFPSDNode>(Op.getOperand(Idx)))
2269 return std::nullopt;
2270 if (auto ExactInteger = getExactInteger(
2271 cast<ConstantFPSDNode>(Op.getOperand(Idx))->getValueAPF(),
2272 EltSizeInBits))
2273 Val = *ExactInteger;
2274 else
2275 return std::nullopt;
2276 }
2277
2278 if (PrevElt) {
2279 // Calculate the step since the last non-undef element, and ensure
2280 // it's consistent across the entire sequence.
2281 unsigned IdxDiff = Idx - PrevElt->second;
2282 int64_t ValDiff = SignExtend64(Val - PrevElt->first, EltSizeInBits);
2283
2284       // A zero value difference means that we're somewhere in the middle
2285 // of a fractional step, e.g. <0,0,0*,0,1,1,1,1>. Wait until we notice a
2286 // step change before evaluating the sequence.
2287 if (ValDiff == 0)
2288 continue;
2289
2290 int64_t Remainder = ValDiff % IdxDiff;
2291 // Normalize the step if it's greater than 1.
2292 if (Remainder != ValDiff) {
2293 // The difference must cleanly divide the element span.
2294 if (Remainder != 0)
2295 return std::nullopt;
2296 ValDiff /= IdxDiff;
2297 IdxDiff = 1;
2298 }
2299
2300 if (!SeqStepNum)
2301 SeqStepNum = ValDiff;
2302 else if (ValDiff != SeqStepNum)
2303 return std::nullopt;
2304
2305 if (!SeqStepDenom)
2306 SeqStepDenom = IdxDiff;
2307 else if (IdxDiff != *SeqStepDenom)
2308 return std::nullopt;
2309 }
2310
2311 // Record this non-undef element for later.
2312 if (!PrevElt || PrevElt->first != Val)
2313 PrevElt = std::make_pair(Val, Idx);
2314 }
2315
2316 // We need to have logged a step for this to count as a legal index sequence.
2317 if (!SeqStepNum || !SeqStepDenom)
2318 return std::nullopt;
2319
2320 // Loop back through the sequence and validate elements we might have skipped
2321 // while waiting for a valid step. While doing this, log any sequence addend.
2322 for (unsigned Idx = 0; Idx < NumElts; Idx++) {
2323 if (Op.getOperand(Idx).isUndef())
2324 continue;
2325 uint64_t Val;
2326 if (IsInteger) {
2327 Val = Op.getConstantOperandVal(Idx) &
2328 maskTrailingOnes<uint64_t>(EltSizeInBits);
2329 } else {
2330 Val = *getExactInteger(
2331 cast<ConstantFPSDNode>(Op.getOperand(Idx))->getValueAPF(),
2332 EltSizeInBits);
2333 }
2334 uint64_t ExpectedVal =
2335 (int64_t)(Idx * (uint64_t)*SeqStepNum) / *SeqStepDenom;
2336 int64_t Addend = SignExtend64(Val - ExpectedVal, EltSizeInBits);
2337 if (!SeqAddend)
2338 SeqAddend = Addend;
2339 else if (Addend != SeqAddend)
2340 return std::nullopt;
2341 }
2342
2343 assert(SeqAddend && "Must have an addend if we have a step");
2344
2345 return VIDSequence{*SeqStepNum, *SeqStepDenom, *SeqAddend};
2346 }
2347
2348 // Match a splatted value (SPLAT_VECTOR/BUILD_VECTOR) of an EXTRACT_VECTOR_ELT
2349 // and lower it as a VRGATHER_VX_VL from the source vector.
2350 static SDValue matchSplatAsGather(SDValue SplatVal, MVT VT, const SDLoc &DL,
2351 SelectionDAG &DAG,
2352 const RISCVSubtarget &Subtarget) {
2353 if (SplatVal.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
2354 return SDValue();
2355 SDValue Vec = SplatVal.getOperand(0);
2356 // Only perform this optimization on vectors of the same size for simplicity.
2357 // Don't perform this optimization for i1 vectors.
2358 // FIXME: Support i1 vectors, maybe by promoting to i8?
2359 if (Vec.getValueType() != VT || VT.getVectorElementType() == MVT::i1)
2360 return SDValue();
2361 SDValue Idx = SplatVal.getOperand(1);
2362 // The index must be a legal type.
2363 if (Idx.getValueType() != Subtarget.getXLenVT())
2364 return SDValue();
2365
2366 MVT ContainerVT = VT;
2367 if (VT.isFixedLengthVector()) {
2368 ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
2369 Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
2370 }
2371
2372 auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
2373
2374 SDValue Gather = DAG.getNode(RISCVISD::VRGATHER_VX_VL, DL, ContainerVT, Vec,
2375 Idx, DAG.getUNDEF(ContainerVT), Mask, VL);
2376
2377 if (!VT.isFixedLengthVector())
2378 return Gather;
2379
2380 return convertFromScalableVector(VT, Gather, DAG, Subtarget);
2381 }
2382
2383 static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
2384 const RISCVSubtarget &Subtarget) {
2385 MVT VT = Op.getSimpleValueType();
2386 assert(VT.isFixedLengthVector() && "Unexpected vector!");
2387
2388 MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
2389
2390 SDLoc DL(Op);
2391 auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
2392
2393 MVT XLenVT = Subtarget.getXLenVT();
2394 unsigned NumElts = Op.getNumOperands();
2395
2396 if (VT.getVectorElementType() == MVT::i1) {
2397 if (ISD::isBuildVectorAllZeros(Op.getNode())) {
2398 SDValue VMClr = DAG.getNode(RISCVISD::VMCLR_VL, DL, ContainerVT, VL);
2399 return convertFromScalableVector(VT, VMClr, DAG, Subtarget);
2400 }
2401
2402 if (ISD::isBuildVectorAllOnes(Op.getNode())) {
2403 SDValue VMSet = DAG.getNode(RISCVISD::VMSET_VL, DL, ContainerVT, VL);
2404 return convertFromScalableVector(VT, VMSet, DAG, Subtarget);
2405 }
2406
2407 // Lower constant mask BUILD_VECTORs via an integer vector type, in
2408 // scalar integer chunks whose bit-width depends on the number of mask
2409 // bits and XLEN.
2410 // First, determine the most appropriate scalar integer type to use. This
2411 // is at most XLenVT, but may be shrunk to a smaller vector element type
2412 // according to the size of the final vector - use i8 chunks rather than
2413 // XLenVT if we're producing a v8i1. This results in more consistent
2414 // codegen across RV32 and RV64.
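    // For example (a sketch), v8i1 <1,0,1,1,0,0,1,0> is built as the single
    // i8 constant 0b01001101 (0x4d) in a v1i8 vector and bitcast back to v8i1.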
2415 unsigned NumViaIntegerBits =
2416 std::min(std::max(NumElts, 8u), Subtarget.getXLen());
2417 NumViaIntegerBits = std::min(NumViaIntegerBits, Subtarget.getELEN());
2418 if (ISD::isBuildVectorOfConstantSDNodes(Op.getNode())) {
2419 // If we have to use more than one INSERT_VECTOR_ELT then this
2420       // optimization is likely to increase code size; avoid performing it in
2421 // such a case. We can use a load from a constant pool in this case.
2422 if (DAG.shouldOptForSize() && NumElts > NumViaIntegerBits)
2423 return SDValue();
2424 // Now we can create our integer vector type. Note that it may be larger
2425 // than the resulting mask type: v4i1 would use v1i8 as its integer type.
2426 MVT IntegerViaVecVT =
2427 MVT::getVectorVT(MVT::getIntegerVT(NumViaIntegerBits),
2428 divideCeil(NumElts, NumViaIntegerBits));
2429
2430 uint64_t Bits = 0;
2431 unsigned BitPos = 0, IntegerEltIdx = 0;
2432 SDValue Vec = DAG.getUNDEF(IntegerViaVecVT);
2433
2434 for (unsigned I = 0; I < NumElts; I++, BitPos++) {
2435 // Once we accumulate enough bits to fill our scalar type, insert into
2436 // our vector and clear our accumulated data.
2437 if (I != 0 && I % NumViaIntegerBits == 0) {
2438 if (NumViaIntegerBits <= 32)
2439 Bits = SignExtend64<32>(Bits);
2440 SDValue Elt = DAG.getConstant(Bits, DL, XLenVT);
2441 Vec = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, IntegerViaVecVT, Vec,
2442 Elt, DAG.getConstant(IntegerEltIdx, DL, XLenVT));
2443 Bits = 0;
2444 BitPos = 0;
2445 IntegerEltIdx++;
2446 }
2447 SDValue V = Op.getOperand(I);
2448 bool BitValue = !V.isUndef() && cast<ConstantSDNode>(V)->getZExtValue();
2449 Bits |= ((uint64_t)BitValue << BitPos);
2450 }
2451
2452 // Insert the (remaining) scalar value into position in our integer
2453 // vector type.
2454 if (NumViaIntegerBits <= 32)
2455 Bits = SignExtend64<32>(Bits);
2456 SDValue Elt = DAG.getConstant(Bits, DL, XLenVT);
2457 Vec = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, IntegerViaVecVT, Vec, Elt,
2458 DAG.getConstant(IntegerEltIdx, DL, XLenVT));
2459
2460 if (NumElts < NumViaIntegerBits) {
2461 // If we're producing a smaller vector than our minimum legal integer
2462 // type, bitcast to the equivalent (known-legal) mask type, and extract
2463 // our final mask.
2464 assert(IntegerViaVecVT == MVT::v1i8 && "Unexpected mask vector type");
2465 Vec = DAG.getBitcast(MVT::v8i1, Vec);
2466 Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Vec,
2467 DAG.getConstant(0, DL, XLenVT));
2468 } else {
2469 // Else we must have produced an integer type with the same size as the
2470 // mask type; bitcast for the final result.
2471 assert(VT.getSizeInBits() == IntegerViaVecVT.getSizeInBits());
2472 Vec = DAG.getBitcast(VT, Vec);
2473 }
2474
2475 return Vec;
2476 }
2477
2478 // A BUILD_VECTOR can be lowered as a SETCC. For each fixed-length mask
2479 // vector type, we have a legal equivalently-sized i8 type, so we can use
2480 // that.
2481 MVT WideVecVT = VT.changeVectorElementType(MVT::i8);
2482 SDValue VecZero = DAG.getConstant(0, DL, WideVecVT);
2483
2484 SDValue WideVec;
2485 if (SDValue Splat = cast<BuildVectorSDNode>(Op)->getSplatValue()) {
2486 // For a splat, perform a scalar truncate before creating the wider
2487 // vector.
2488 assert(Splat.getValueType() == XLenVT &&
2489 "Unexpected type for i1 splat value");
2490 Splat = DAG.getNode(ISD::AND, DL, XLenVT, Splat,
2491 DAG.getConstant(1, DL, XLenVT));
2492 WideVec = DAG.getSplatBuildVector(WideVecVT, DL, Splat);
2493 } else {
2494 SmallVector<SDValue, 8> Ops(Op->op_values());
2495 WideVec = DAG.getBuildVector(WideVecVT, DL, Ops);
2496 SDValue VecOne = DAG.getConstant(1, DL, WideVecVT);
2497 WideVec = DAG.getNode(ISD::AND, DL, WideVecVT, WideVec, VecOne);
2498 }
2499
2500 return DAG.getSetCC(DL, VT, WideVec, VecZero, ISD::SETNE);
2501 }
2502
2503 if (SDValue Splat = cast<BuildVectorSDNode>(Op)->getSplatValue()) {
2504 if (auto Gather = matchSplatAsGather(Splat, VT, DL, DAG, Subtarget))
2505 return Gather;
2506 unsigned Opc = VT.isFloatingPoint() ? RISCVISD::VFMV_V_F_VL
2507 : RISCVISD::VMV_V_X_VL;
2508 Splat =
2509 DAG.getNode(Opc, DL, ContainerVT, DAG.getUNDEF(ContainerVT), Splat, VL);
2510 return convertFromScalableVector(VT, Splat, DAG, Subtarget);
2511 }
2512
2513 // Try and match index sequences, which we can lower to the vid instruction
2514 // with optional modifications. An all-undef vector is matched by
2515 // getSplatValue, above.
2516 if (auto SimpleVID = isSimpleVIDSequence(Op)) {
2517 int64_t StepNumerator = SimpleVID->StepNumerator;
2518 unsigned StepDenominator = SimpleVID->StepDenominator;
2519 int64_t Addend = SimpleVID->Addend;
2520
2521 assert(StepNumerator != 0 && "Invalid step");
2522 bool Negate = false;
2523 int64_t SplatStepVal = StepNumerator;
2524 unsigned StepOpcode = ISD::MUL;
2525 if (StepNumerator != 1) {
2526 if (isPowerOf2_64(std::abs(StepNumerator))) {
2527 Negate = StepNumerator < 0;
2528 StepOpcode = ISD::SHL;
2529 SplatStepVal = Log2_64(std::abs(StepNumerator));
2530 }
2531 }
2532
2533     // Only emit VIDs with suitably-small steps/addends. We use imm5 as a
2534     // threshold since it's the immediate value many RVV instructions accept.
2535     // There is no vmul.vi instruction so ensure the multiply constant can fit
2536     // in a single addi instruction.
2537 if (((StepOpcode == ISD::MUL && isInt<12>(SplatStepVal)) ||
2538 (StepOpcode == ISD::SHL && isUInt<5>(SplatStepVal))) &&
2539 isPowerOf2_32(StepDenominator) &&
2540 (SplatStepVal >= 0 || StepDenominator == 1) && isInt<5>(Addend)) {
2541 MVT VIDVT =
2542 VT.isFloatingPoint() ? VT.changeVectorElementTypeToInteger() : VT;
2543 MVT VIDContainerVT =
2544 getContainerForFixedLengthVector(DAG, VIDVT, Subtarget);
2545 SDValue VID = DAG.getNode(RISCVISD::VID_VL, DL, VIDContainerVT, Mask, VL);
2546 // Convert right out of the scalable type so we can use standard ISD
2547 // nodes for the rest of the computation. If we used scalable types with
2548 // these, we'd lose the fixed-length vector info and generate worse
2549 // vsetvli code.
2550 VID = convertFromScalableVector(VIDVT, VID, DAG, Subtarget);
2551 if ((StepOpcode == ISD::MUL && SplatStepVal != 1) ||
2552 (StepOpcode == ISD::SHL && SplatStepVal != 0)) {
2553 SDValue SplatStep = DAG.getSplatBuildVector(
2554 VIDVT, DL, DAG.getConstant(SplatStepVal, DL, XLenVT));
2555 VID = DAG.getNode(StepOpcode, DL, VIDVT, VID, SplatStep);
2556 }
2557 if (StepDenominator != 1) {
2558 SDValue SplatStep = DAG.getSplatBuildVector(
2559 VIDVT, DL, DAG.getConstant(Log2_64(StepDenominator), DL, XLenVT));
2560 VID = DAG.getNode(ISD::SRL, DL, VIDVT, VID, SplatStep);
2561 }
2562 if (Addend != 0 || Negate) {
2563 SDValue SplatAddend = DAG.getSplatBuildVector(
2564 VIDVT, DL, DAG.getConstant(Addend, DL, XLenVT));
2565 VID = DAG.getNode(Negate ? ISD::SUB : ISD::ADD, DL, VIDVT, SplatAddend,
2566 VID);
2567 }
2568 if (VT.isFloatingPoint()) {
2569 // TODO: Use vfwcvt to reduce register pressure.
2570 VID = DAG.getNode(ISD::SINT_TO_FP, DL, VT, VID);
2571 }
2572 return VID;
2573 }
2574 }
2575
2576 // Attempt to detect "hidden" splats, which only reveal themselves as splats
2577 // when re-interpreted as a vector with a larger element type. For example,
2578 // v4i16 = build_vector i16 0, i16 1, i16 0, i16 1
2579 // could be instead splat as
2580 // v2i32 = build_vector i32 0x00010000, i32 0x00010000
2581 // TODO: This optimization could also work on non-constant splats, but it
2582 // would require bit-manipulation instructions to construct the splat value.
2583 SmallVector<SDValue> Sequence;
2584 unsigned EltBitSize = VT.getScalarSizeInBits();
2585 const auto *BV = cast<BuildVectorSDNode>(Op);
2586 if (VT.isInteger() && EltBitSize < 64 &&
2587 ISD::isBuildVectorOfConstantSDNodes(Op.getNode()) &&
2588 BV->getRepeatedSequence(Sequence) &&
2589 (Sequence.size() * EltBitSize) <= 64) {
2590 unsigned SeqLen = Sequence.size();
2591 MVT ViaIntVT = MVT::getIntegerVT(EltBitSize * SeqLen);
2592 MVT ViaVecVT = MVT::getVectorVT(ViaIntVT, NumElts / SeqLen);
2593 assert((ViaIntVT == MVT::i16 || ViaIntVT == MVT::i32 ||
2594 ViaIntVT == MVT::i64) &&
2595 "Unexpected sequence type");
2596
2597 unsigned EltIdx = 0;
2598 uint64_t EltMask = maskTrailingOnes<uint64_t>(EltBitSize);
2599 uint64_t SplatValue = 0;
2600 // Construct the amalgamated value which can be splatted as this larger
2601 // vector type.
2602 for (const auto &SeqV : Sequence) {
2603 if (!SeqV.isUndef())
2604 SplatValue |= ((cast<ConstantSDNode>(SeqV)->getZExtValue() & EltMask)
2605 << (EltIdx * EltBitSize));
2606 EltIdx++;
2607 }
2608
2609 // On RV64, sign-extend from 32 to 64 bits where possible in order to
2610     // achieve better constant materialization.
2611 if (Subtarget.is64Bit() && ViaIntVT == MVT::i32)
2612 SplatValue = SignExtend64<32>(SplatValue);
2613
2614 // Since we can't introduce illegal i64 types at this stage, we can only
2615 // perform an i64 splat on RV32 if it is its own sign-extended value. That
2616 // way we can use RVV instructions to splat.
2617 assert((ViaIntVT.bitsLE(XLenVT) ||
2618 (!Subtarget.is64Bit() && ViaIntVT == MVT::i64)) &&
2619 "Unexpected bitcast sequence");
2620 if (ViaIntVT.bitsLE(XLenVT) || isInt<32>(SplatValue)) {
2621 SDValue ViaVL =
2622 DAG.getConstant(ViaVecVT.getVectorNumElements(), DL, XLenVT);
2623 MVT ViaContainerVT =
2624 getContainerForFixedLengthVector(DAG, ViaVecVT, Subtarget);
2625 SDValue Splat =
2626 DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ViaContainerVT,
2627 DAG.getUNDEF(ViaContainerVT),
2628 DAG.getConstant(SplatValue, DL, XLenVT), ViaVL);
2629 Splat = convertFromScalableVector(ViaVecVT, Splat, DAG, Subtarget);
2630 return DAG.getBitcast(VT, Splat);
2631 }
2632 }
2633
2634 // Try and optimize BUILD_VECTORs with "dominant values" - these are values
2635 // which constitute a large proportion of the elements. In such cases we can
2636 // splat a vector with the dominant element and make up the shortfall with
2637 // INSERT_VECTOR_ELTs.
2638 // Note that this includes vectors of 2 elements by association. The
2639 // upper-most element is the "dominant" one, allowing us to use a splat to
2640 // "insert" the upper element, and an insert of the lower element at position
2641 // 0, which improves codegen.
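  // For example, <i32 2, i32 2, i32 2, i32 3> is lowered as a splat of 2
  // followed by a single insert of 3 at index 3.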
2642 SDValue DominantValue;
2643 unsigned MostCommonCount = 0;
2644 DenseMap<SDValue, unsigned> ValueCounts;
2645 unsigned NumUndefElts =
2646 count_if(Op->op_values(), [](const SDValue &V) { return V.isUndef(); });
2647
2648 // Track the number of scalar loads we know we'd be inserting, estimated as
2649 // any non-zero floating-point constant. Other kinds of element are either
2650 // already in registers or are materialized on demand. The threshold at which
2651   // a vector load is more desirable than several scalar materialization and
2652 // vector-insertion instructions is not known.
2653 unsigned NumScalarLoads = 0;
2654
2655 for (SDValue V : Op->op_values()) {
2656 if (V.isUndef())
2657 continue;
2658
2659 ValueCounts.insert(std::make_pair(V, 0));
2660 unsigned &Count = ValueCounts[V];
2661
2662 if (auto *CFP = dyn_cast<ConstantFPSDNode>(V))
2663 NumScalarLoads += !CFP->isExactlyValue(+0.0);
2664
2665 // Is this value dominant? In case of a tie, prefer the highest element as
2666 // it's cheaper to insert near the beginning of a vector than it is at the
2667 // end.
2668 if (++Count >= MostCommonCount) {
2669 DominantValue = V;
2670 MostCommonCount = Count;
2671 }
2672 }
2673
2674 assert(DominantValue && "Not expecting an all-undef BUILD_VECTOR");
2675 unsigned NumDefElts = NumElts - NumUndefElts;
2676 unsigned DominantValueCountThreshold = NumDefElts <= 2 ? 0 : NumDefElts - 2;
2677
2678 // Don't perform this optimization when optimizing for size, since
2679 // materializing elements and inserting them tends to cause code bloat.
2680 if (!DAG.shouldOptForSize() && NumScalarLoads < NumElts &&
2681 ((MostCommonCount > DominantValueCountThreshold) ||
2682 (ValueCounts.size() <= Log2_32(NumDefElts)))) {
2683 // Start by splatting the most common element.
2684 SDValue Vec = DAG.getSplatBuildVector(VT, DL, DominantValue);
2685
2686 DenseSet<SDValue> Processed{DominantValue};
2687 MVT SelMaskTy = VT.changeVectorElementType(MVT::i1);
2688 for (const auto &OpIdx : enumerate(Op->ops())) {
2689 const SDValue &V = OpIdx.value();
2690 if (V.isUndef() || !Processed.insert(V).second)
2691 continue;
2692 if (ValueCounts[V] == 1) {
2693 Vec = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT, Vec, V,
2694 DAG.getConstant(OpIdx.index(), DL, XLenVT));
2695 } else {
2696 // Blend in all instances of this value using a VSELECT, using a
2697 // mask where each bit signals whether that element is the one
2698 // we're after.
2699 SmallVector<SDValue> Ops;
2700 transform(Op->op_values(), std::back_inserter(Ops), [&](SDValue V1) {
2701 return DAG.getConstant(V == V1, DL, XLenVT);
2702 });
2703 Vec = DAG.getNode(ISD::VSELECT, DL, VT,
2704 DAG.getBuildVector(SelMaskTy, DL, Ops),
2705 DAG.getSplatBuildVector(VT, DL, V), Vec);
2706 }
2707 }
2708
2709 return Vec;
2710 }
2711
2712 return SDValue();
2713 }
2714
2715 static SDValue splatPartsI64WithVL(const SDLoc &DL, MVT VT, SDValue Passthru,
2716 SDValue Lo, SDValue Hi, SDValue VL,
2717 SelectionDAG &DAG) {
2718 if (!Passthru)
2719 Passthru = DAG.getUNDEF(VT);
2720 if (isa<ConstantSDNode>(Lo) && isa<ConstantSDNode>(Hi)) {
2721 int32_t LoC = cast<ConstantSDNode>(Lo)->getSExtValue();
2722 int32_t HiC = cast<ConstantSDNode>(Hi)->getSExtValue();
2723 // If Hi constant is all the same sign bit as Lo, lower this as a custom
2724 // node in order to try and match RVV vector/scalar instructions.
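    // E.g., splatting the i64 constant -1 on RV32 has LoC == HiC == -1 and
    // (LoC >> 31) == HiC, so a single sign-extending vmv.v.x of Lo suffices.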
2725 if ((LoC >> 31) == HiC)
2726 return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, Passthru, Lo, VL);
2727
2728 // If vl is equal to XLEN_MAX and Hi constant is equal to Lo, we could use
2729 // vmv.v.x whose EEW = 32 to lower it.
2730 auto *Const = dyn_cast<ConstantSDNode>(VL);
2731 if (LoC == HiC && Const && Const->isAllOnesValue()) {
2732 MVT InterVT = MVT::getVectorVT(MVT::i32, VT.getVectorElementCount() * 2);
2733 // TODO: if vl <= min(VLMAX), we can also do this. But we could not
2734 // access the subtarget here now.
2735 auto InterVec = DAG.getNode(
2736 RISCVISD::VMV_V_X_VL, DL, InterVT, DAG.getUNDEF(InterVT), Lo,
2737 DAG.getRegister(RISCV::X0, MVT::i32));
2738 return DAG.getNode(ISD::BITCAST, DL, VT, InterVec);
2739 }
2740 }
2741
2742 // Fall back to a stack store and stride x0 vector load.
2743 return DAG.getNode(RISCVISD::SPLAT_VECTOR_SPLIT_I64_VL, DL, VT, Passthru, Lo,
2744 Hi, VL);
2745 }
2746
2747 // Called by type legalization to handle splat of i64 on RV32.
2748 // FIXME: We can optimize this when the type has sign or zero bits in one
2749 // of the halves.
2750 static SDValue splatSplitI64WithVL(const SDLoc &DL, MVT VT, SDValue Passthru,
2751 SDValue Scalar, SDValue VL,
2752 SelectionDAG &DAG) {
2753 assert(Scalar.getValueType() == MVT::i64 && "Unexpected VT!");
2754 SDValue Lo = DAG.getNode(ISD::EXTRACT_ELEMENT, DL, MVT::i32, Scalar,
2755 DAG.getConstant(0, DL, MVT::i32));
2756 SDValue Hi = DAG.getNode(ISD::EXTRACT_ELEMENT, DL, MVT::i32, Scalar,
2757 DAG.getConstant(1, DL, MVT::i32));
2758 return splatPartsI64WithVL(DL, VT, Passthru, Lo, Hi, VL, DAG);
2759 }
2760
2761 // This function lowers a splat of the scalar operand Scalar with the vector
2762 // length VL. It ensures the final sequence is type legal, which is useful when
2763 // lowering a splat after type legalization.
2764 static SDValue lowerScalarSplat(SDValue Passthru, SDValue Scalar, SDValue VL,
2765 MVT VT, SDLoc DL, SelectionDAG &DAG,
2766 const RISCVSubtarget &Subtarget) {
2767 bool HasPassthru = Passthru && !Passthru.isUndef();
2768 if (!HasPassthru && !Passthru)
2769 Passthru = DAG.getUNDEF(VT);
2770 if (VT.isFloatingPoint()) {
2771 // If VL is 1, we could use vfmv.s.f.
2772 if (isOneConstant(VL))
2773 return DAG.getNode(RISCVISD::VFMV_S_F_VL, DL, VT, Passthru, Scalar, VL);
2774 return DAG.getNode(RISCVISD::VFMV_V_F_VL, DL, VT, Passthru, Scalar, VL);
2775 }
2776
2777 MVT XLenVT = Subtarget.getXLenVT();
2778
2779 // Simplest case is that the operand needs to be promoted to XLenVT.
2780 if (Scalar.getValueType().bitsLE(XLenVT)) {
2781 // If the operand is a constant, sign extend to increase our chances
2782 // of being able to use a .vi instruction. ANY_EXTEND would become a
2783     // zero extend and the simm5 check in isel would fail.
2784 // FIXME: Should we ignore the upper bits in isel instead?
2785 unsigned ExtOpc =
2786 isa<ConstantSDNode>(Scalar) ? ISD::SIGN_EXTEND : ISD::ANY_EXTEND;
2787 Scalar = DAG.getNode(ExtOpc, DL, XLenVT, Scalar);
2788 ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Scalar);
2789 // If VL is 1 and the scalar value won't benefit from immediate, we could
2790 // use vmv.s.x.
2791 if (isOneConstant(VL) &&
2792 (!Const || isNullConstant(Scalar) || !isInt<5>(Const->getSExtValue())))
2793 return DAG.getNode(RISCVISD::VMV_S_X_VL, DL, VT, Passthru, Scalar, VL);
2794 return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, Passthru, Scalar, VL);
2795 }
2796
2797 assert(XLenVT == MVT::i32 && Scalar.getValueType() == MVT::i64 &&
2798 "Unexpected scalar for splat lowering!");
2799
2800 if (isOneConstant(VL) && isNullConstant(Scalar))
2801 return DAG.getNode(RISCVISD::VMV_S_X_VL, DL, VT, Passthru,
2802 DAG.getConstant(0, DL, XLenVT), VL);
2803
2804 // Otherwise use the more complicated splatting algorithm.
2805 return splatSplitI64WithVL(DL, VT, Passthru, Scalar, VL, DAG);
2806 }
2807
2808 static MVT getLMUL1VT(MVT VT) {
2809 assert(VT.getVectorElementType().getSizeInBits() <= 64 &&
2810 "Unexpected vector MVT");
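  // E.g., getLMUL1VT(nxv8i32) is nxv2i32: 64 (RVVBitsPerBlock) / 32 = 2
  // elements, i.e. the LMUL=1 type with the same element type.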
2811 return MVT::getScalableVectorVT(
2812 VT.getVectorElementType(),
2813 RISCV::RVVBitsPerBlock / VT.getVectorElementType().getSizeInBits());
2814 }
2815
2816 // This function lowers an insert of a scalar operand Scalar into lane
2817 // 0 of the vector regardless of the value of VL. The contents of the
2818 // remaining lanes of the result vector are unspecified. VL is assumed
2819 // to be non-zero.
2820 static SDValue lowerScalarInsert(SDValue Scalar, SDValue VL,
2821 MVT VT, SDLoc DL, SelectionDAG &DAG,
2822 const RISCVSubtarget &Subtarget) {
2823 const MVT XLenVT = Subtarget.getXLenVT();
2824
2825 SDValue Passthru = DAG.getUNDEF(VT);
2826 if (VT.isFloatingPoint()) {
2827 // TODO: Use vmv.v.i for appropriate constants
2828 // Use M1 or smaller to avoid over constraining register allocation
2829 const MVT M1VT = getLMUL1VT(VT);
2830 auto InnerVT = VT.bitsLE(M1VT) ? VT : M1VT;
2831 SDValue Result = DAG.getNode(RISCVISD::VFMV_S_F_VL, DL, InnerVT,
2832 DAG.getUNDEF(InnerVT), Scalar, VL);
2833 if (VT != InnerVT)
2834 Result = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT,
2835 DAG.getUNDEF(VT),
2836 Result, DAG.getConstant(0, DL, XLenVT));
2837 return Result;
2838 }
2839
2841 // Avoid the tricky legalization cases by falling back to using the
2842 // splat code which already handles it gracefully.
2843 if (!Scalar.getValueType().bitsLE(XLenVT))
2844 return lowerScalarSplat(DAG.getUNDEF(VT), Scalar,
2845 DAG.getConstant(1, DL, XLenVT),
2846 VT, DL, DAG, Subtarget);
2847
2848 // If the operand is a constant, sign extend to increase our chances
2849 // of being able to use a .vi instruction. ANY_EXTEND would become a
2850 // zero extend and the simm5 check in isel would fail.
2851 // FIXME: Should we ignore the upper bits in isel instead?
2852 unsigned ExtOpc =
2853 isa<ConstantSDNode>(Scalar) ? ISD::SIGN_EXTEND : ISD::ANY_EXTEND;
2854 Scalar = DAG.getNode(ExtOpc, DL, XLenVT, Scalar);
2855 // We use a vmv.v.i if possible. We limit this to LMUL1. LMUL2 or
2856 // higher would involve overly constraining the register allocator for
2857 // no purpose.
2858 if (ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Scalar)) {
2859 if (!isNullConstant(Scalar) && isInt<5>(Const->getSExtValue()) &&
2860 VT.bitsLE(getLMUL1VT(VT)))
2861 return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, Passthru, Scalar, VL);
2862 }
2863 // Use M1 or smaller to avoid over constraining register allocation
2864 const MVT M1VT = getLMUL1VT(VT);
2865 auto InnerVT = VT.bitsLE(M1VT) ? VT : M1VT;
2866 SDValue Result = DAG.getNode(RISCVISD::VMV_S_X_VL, DL, InnerVT,
2867 DAG.getUNDEF(InnerVT), Scalar, VL);
2868 if (VT != InnerVT)
2869 Result = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT,
2870 DAG.getUNDEF(VT),
2871 Result, DAG.getConstant(0, DL, XLenVT));
2872 return Result;
2874 }
2875
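// Recognize masks that interleave the low halves of two source vectors. As an
// illustrative example with 8-element inputs:
//   t3: v8i8 = vector_shuffle<0,8,1,9,2,10,3,11> t1, t2
// Even result elements must all come from one source and odd result elements
// from the other, with each source walked in order from element 0.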
2876 static bool isInterleaveShuffle(ArrayRef<int> Mask, MVT VT, bool &SwapSources,
2877 const RISCVSubtarget &Subtarget) {
2878 // We need to be able to widen elements to the next larger integer type.
2879 if (VT.getScalarSizeInBits() >= Subtarget.getELEN())
2880 return false;
2881
2882 int Size = Mask.size();
2883 assert(Size == (int)VT.getVectorNumElements() && "Unexpected mask size");
2884
2885 int Srcs[] = {-1, -1};
2886 for (int i = 0; i != Size; ++i) {
2887 // Ignore undef elements.
2888 if (Mask[i] < 0)
2889 continue;
2890
2891 // Is this an even or odd element.
2892 int Pol = i % 2;
2893
2894 // Ensure we consistently use the same source for this element polarity.
2895 int Src = Mask[i] / Size;
2896 if (Srcs[Pol] < 0)
2897 Srcs[Pol] = Src;
2898 if (Srcs[Pol] != Src)
2899 return false;
2900
2901 // Make sure the element within the source is appropriate for this element
2902 // in the destination.
2903 int Elt = Mask[i] % Size;
2904 if (Elt != i / 2)
2905 return false;
2906 }
2907
2908 // We need to find a source for each polarity and they can't be the same.
2909 if (Srcs[0] < 0 || Srcs[1] < 0 || Srcs[0] == Srcs[1])
2910 return false;
2911
2912 // Swap the sources if the second source was in the even polarity.
2913 SwapSources = Srcs[0] > Srcs[1];
2914
2915 return true;
2916 }
2917
2918 /// Match shuffles that concatenate two vectors, rotate the concatenation,
2919 /// and then extract the original number of elements from the rotated result.
2920 /// This is equivalent to vector.splice or X86's PALIGNR instruction. The
2921 /// returned rotation amount is for a rotate right, where elements move from
2922 /// higher elements to lower elements. \p LoSrc indicates the first source
2923 /// vector of the rotate or -1 for undef. \p HiSrc indicates the second vector
2924 /// of the rotate or -1 for undef. At least one of \p LoSrc and \p HiSrc will be
2925 /// 0 or 1 if a rotation is found.
2926 ///
2927 /// NOTE: We talk about rotate to the right which matches how bit shift and
2928 /// rotate instructions are described where LSBs are on the right, but LLVM IR
2929 /// and the table below write vectors with the lowest elements on the left.
2930 static int isElementRotate(int &LoSrc, int &HiSrc, ArrayRef<int> Mask) {
2931 int Size = Mask.size();
2932
2933 // We need to detect various ways of spelling a rotation:
2934 // [11, 12, 13, 14, 15, 0, 1, 2]
2935 // [-1, 12, 13, 14, -1, -1, 1, -1]
2936 // [-1, -1, -1, -1, -1, -1, 1, 2]
2937 // [ 3, 4, 5, 6, 7, 8, 9, 10]
2938 // [-1, 4, 5, 6, -1, -1, 9, -1]
2939 // [-1, 4, 5, 6, -1, -1, -1, -1]
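  // As a worked example (illustrative), take the first mask above with
  // Size == 8: element 0 reads input element 11, so StartIdx == -3 and the
  // candidate rotation is 3; element 5 reads input element 0, so StartIdx == 5
  // and the candidate rotation is 8 - 5 == 3, which agrees. The result is
  // Rotation == 3 with LoSrc == 0 and HiSrc == 1.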
2940 int Rotation = 0;
2941 LoSrc = -1;
2942 HiSrc = -1;
2943 for (int i = 0; i != Size; ++i) {
2944 int M = Mask[i];
2945 if (M < 0)
2946 continue;
2947
2948 // Determine where a rotate vector would have started.
2949 int StartIdx = i - (M % Size);
2950 // The identity rotation isn't interesting, stop.
2951 if (StartIdx == 0)
2952 return -1;
2953
2954 // If we found the tail of a vector the rotation must be the missing
2955 // front. If we found the head of a vector, it must be how much of the
2956 // head.
2957 int CandidateRotation = StartIdx < 0 ? -StartIdx : Size - StartIdx;
2958
2959 if (Rotation == 0)
2960 Rotation = CandidateRotation;
2961 else if (Rotation != CandidateRotation)
2962 // The rotations don't match, so we can't match this mask.
2963 return -1;
2964
2965 // Compute which value this mask is pointing at.
2966 int MaskSrc = M < Size ? 0 : 1;
2967
2968 // Compute which of the two target values this index should be assigned to.
2969 // This reflects whether the high elements are remaining or the low elements
2970 // are remaining.
2971 int &TargetSrc = StartIdx < 0 ? HiSrc : LoSrc;
2972
2973 // Either set up this value if we've not encountered it before, or check
2974 // that it remains consistent.
2975 if (TargetSrc < 0)
2976 TargetSrc = MaskSrc;
2977 else if (TargetSrc != MaskSrc)
2978 // This may be a rotation, but it pulls from the inputs in some
2979 // unsupported interleaving.
2980 return -1;
2981 }
2982
2983 // Check that we successfully analyzed the mask, and normalize the results.
2984 assert(Rotation != 0 && "Failed to locate a viable rotation!");
2985 assert((LoSrc >= 0 || HiSrc >= 0) &&
2986 "Failed to find a rotated input vector!");
2987
2988 return Rotation;
2989 }
2990
2991 // Lower the following shuffles to vnsrl.
2992 // t34: v8i8 = extract_subvector t11, Constant:i64<0>
2993 // t33: v8i8 = extract_subvector t11, Constant:i64<8>
2994 // a) t35: v8i8 = vector_shuffle<0,2,4,6,8,10,12,14> t34, t33
2995 // b) t35: v8i8 = vector_shuffle<1,3,5,7,9,11,13,15> t34, t33
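// As a rough illustration, for case b) (the odd elements) the source is viewed
// as a vector of i16 and narrowed with a shift of 8, roughly
//   vnsrl.wi vd, vs, 8
// which keeps bits [15:8] of each 16-bit lane, i.e. the odd bytes; case a)
// uses a shift of 0 and keeps the even bytes.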
2996 static SDValue lowerVECTOR_SHUFFLEAsVNSRL(const SDLoc &DL, MVT VT,
2997 MVT ContainerVT, SDValue V1,
2998 SDValue V2, SDValue TrueMask,
2999 SDValue VL, ArrayRef<int> Mask,
3000 const RISCVSubtarget &Subtarget,
3001 SelectionDAG &DAG) {
3002 // We need to be able to widen the elements to the next larger integer type.
3003 if (VT.getScalarSizeInBits() >= Subtarget.getELEN())
3004 return SDValue();
3005
3006 // Both inputs must be extracts.
3007 if (V1.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
3008 V2.getOpcode() != ISD::EXTRACT_SUBVECTOR)
3009 return SDValue();
3010
3011 // Extracting from the same source.
3012 SDValue Src = V1.getOperand(0);
3013 if (Src != V2.getOperand(0))
3014 return SDValue();
3015
3016 // Src needs to have twice the number of elements.
3017 if (Src.getValueType().getVectorNumElements() != (Mask.size() * 2))
3018 return SDValue();
3019
3020 // The extracts must extract the two halves of the source.
3021 if (V1.getConstantOperandVal(1) != 0 ||
3022 V2.getConstantOperandVal(1) != Mask.size())
3023 return SDValue();
3024
3025 // First index must be the first even or odd element from V1.
3026 if (Mask[0] != 0 && Mask[0] != 1)
3027 return SDValue();
3028
3029 // The others must increase by 2 each time.
3030 // TODO: Support undef elements?
3031 for (unsigned i = 1; i != Mask.size(); ++i)
3032 if (Mask[i] != Mask[i - 1] + 2)
3033 return SDValue();
3034
3035 // Convert the source using a container type with twice the elements. Since
3036 // source VT is legal and twice this VT, we know VT isn't LMUL=8 so it is
3037 // safe to double.
3038 MVT DoubleContainerVT =
3039 MVT::getVectorVT(ContainerVT.getVectorElementType(),
3040 ContainerVT.getVectorElementCount() * 2);
3041 Src = convertToScalableVector(DoubleContainerVT, Src, DAG, Subtarget);
3042
3043 // Convert the vector to a wider integer type with the original element
3044 // count. This also converts FP to int.
3045 unsigned EltBits = ContainerVT.getScalarSizeInBits();
3046 MVT WideIntEltVT = MVT::getIntegerVT(EltBits * 2);
3047 MVT WideIntContainerVT =
3048 MVT::getVectorVT(WideIntEltVT, ContainerVT.getVectorElementCount());
3049 Src = DAG.getBitcast(WideIntContainerVT, Src);
3050
3051 // Convert to the integer version of the container type.
3052 MVT IntEltVT = MVT::getIntegerVT(EltBits);
3053 MVT IntContainerVT =
3054 MVT::getVectorVT(IntEltVT, ContainerVT.getVectorElementCount());
3055
3056 // If we want even elements, then the shift amount is 0. Otherwise, shift by
3057 // the original element size.
3058 unsigned Shift = Mask[0] == 0 ? 0 : EltBits;
3059 SDValue SplatShift = DAG.getNode(
3060 RISCVISD::VMV_V_X_VL, DL, IntContainerVT, DAG.getUNDEF(ContainerVT),
3061 DAG.getConstant(Shift, DL, Subtarget.getXLenVT()), VL);
3062 SDValue Res =
3063 DAG.getNode(RISCVISD::VNSRL_VL, DL, IntContainerVT, Src, SplatShift,
3064 DAG.getUNDEF(IntContainerVT), TrueMask, VL);
3065 // Cast back to FP if needed.
3066 Res = DAG.getBitcast(ContainerVT, Res);
3067
3068 return convertFromScalableVector(VT, Res, DAG, Subtarget);
3069 }
3070
3071 static SDValue
3072 getVSlidedown(SelectionDAG &DAG, const RISCVSubtarget &Subtarget, SDLoc DL,
3073 EVT VT, SDValue Merge, SDValue Op, SDValue Offset, SDValue Mask,
3074 SDValue VL,
3075 unsigned Policy = RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED) {
3076 if (Merge.isUndef())
3077 Policy = RISCVII::TAIL_AGNOSTIC | RISCVII::MASK_AGNOSTIC;
3078 SDValue PolicyOp = DAG.getTargetConstant(Policy, DL, Subtarget.getXLenVT());
3079 SDValue Ops[] = {Merge, Op, Offset, Mask, VL, PolicyOp};
3080 return DAG.getNode(RISCVISD::VSLIDEDOWN_VL, DL, VT, Ops);
3081 }
3082
3083 static SDValue
3084 getVSlideup(SelectionDAG &DAG, const RISCVSubtarget &Subtarget, SDLoc DL,
3085 EVT VT, SDValue Merge, SDValue Op, SDValue Offset, SDValue Mask,
3086 SDValue VL,
3087 unsigned Policy = RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED) {
3088 if (Merge.isUndef())
3089 Policy = RISCVII::TAIL_AGNOSTIC | RISCVII::MASK_AGNOSTIC;
3090 SDValue PolicyOp = DAG.getTargetConstant(Policy, DL, Subtarget.getXLenVT());
3091 SDValue Ops[] = {Merge, Op, Offset, Mask, VL, PolicyOp};
3092 return DAG.getNode(RISCVISD::VSLIDEUP_VL, DL, VT, Ops);
3093 }
3094
3095 // Lower the following shuffle to vslidedown.
3096 // a)
3097 // t49: v8i8 = extract_subvector t13, Constant:i64<0>
3098 // t109: v8i8 = extract_subvector t13, Constant:i64<8>
3099 // t108: v8i8 = vector_shuffle<1,2,3,4,5,6,7,8> t49, t106
3100 // b)
3101 // t69: v16i16 = extract_subvector t68, Constant:i64<0>
3102 // t23: v8i16 = extract_subvector t69, Constant:i64<0>
3103 // t29: v4i16 = extract_subvector t23, Constant:i64<4>
3104 // t26: v8i16 = extract_subvector t69, Constant:i64<8>
3105 // t30: v4i16 = extract_subvector t26, Constant:i64<0>
3106 // t54: v4i16 = vector_shuffle<1,2,3,4> t29, t30
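// In example b), both operands trace back to t69 with offsets 4 and 8, so the
// mask rebuilt relative to t69 is <5,6,7,8>; the shuffle then becomes (roughly)
// a vslidedown by 5 on t69 followed by extracting the first 4 elements.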
3107 static SDValue lowerVECTOR_SHUFFLEAsVSlidedown(const SDLoc &DL, MVT VT,
3108 SDValue V1, SDValue V2,
3109 ArrayRef<int> Mask,
3110 const RISCVSubtarget &Subtarget,
3111 SelectionDAG &DAG) {
3112 auto findNonEXTRACT_SUBVECTORParent =
3113 [](SDValue Parent) -> std::pair<SDValue, uint64_t> {
3114 uint64_t Offset = 0;
3115 while (Parent.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
3116 // EXTRACT_SUBVECTOR can be used to extract a fixed-width vector from
3117 // a scalable vector, but we don't want to match that case here.
3118 Parent.getOperand(0).getSimpleValueType().isFixedLengthVector()) {
3119 Offset += Parent.getConstantOperandVal(1);
3120 Parent = Parent.getOperand(0);
3121 }
3122 return std::make_pair(Parent, Offset);
3123 };
3124
3125 auto [V1Src, V1IndexOffset] = findNonEXTRACT_SUBVECTORParent(V1);
3126 auto [V2Src, V2IndexOffset] = findNonEXTRACT_SUBVECTORParent(V2);
3127
3128 // Extracting from the same source.
3129 SDValue Src = V1Src;
3130 if (Src != V2Src)
3131 return SDValue();
3132
3133 // Rebuild mask because Src may be from multiple EXTRACT_SUBVECTORs.
3134 SmallVector<int, 16> NewMask(Mask);
3135 for (size_t i = 0; i != NewMask.size(); ++i) {
3136 if (NewMask[i] == -1)
3137 continue;
3138
3139 if (static_cast<size_t>(NewMask[i]) < NewMask.size()) {
3140 NewMask[i] = NewMask[i] + V1IndexOffset;
3141 } else {
3142 // We must subtract NewMask.size(); otherwise example b) above would
3143 // produce <5,6,7,12> instead of <5,6,7,8>.
3144 NewMask[i] = NewMask[i] - NewMask.size() + V2IndexOffset;
3145 }
3146 }
3147
3148 // First index must be known and non-zero. It will be used as the slidedown
3149 // amount.
3150 if (NewMask[0] <= 0)
3151 return SDValue();
3152
3153 // The remaining indices of NewMask must also be consecutive.
3154 for (unsigned i = 1; i != NewMask.size(); ++i)
3155 if (NewMask[i - 1] + 1 != NewMask[i])
3156 return SDValue();
3157
3158 MVT XLenVT = Subtarget.getXLenVT();
3159 MVT SrcVT = Src.getSimpleValueType();
3160 MVT ContainerVT = getContainerForFixedLengthVector(DAG, SrcVT, Subtarget);
3161 auto [TrueMask, VL] = getDefaultVLOps(SrcVT, ContainerVT, DL, DAG, Subtarget);
3162 SDValue Slidedown =
3163 getVSlidedown(DAG, Subtarget, DL, ContainerVT, DAG.getUNDEF(ContainerVT),
3164 convertToScalableVector(ContainerVT, Src, DAG, Subtarget),
3165 DAG.getConstant(NewMask[0], DL, XLenVT), TrueMask, VL);
3166 return DAG.getNode(
3167 ISD::EXTRACT_SUBVECTOR, DL, VT,
3168 convertFromScalableVector(SrcVT, Slidedown, DAG, Subtarget),
3169 DAG.getConstant(0, DL, XLenVT));
3170 }
3171
3172 static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
3173 const RISCVSubtarget &Subtarget) {
3174 SDValue V1 = Op.getOperand(0);
3175 SDValue V2 = Op.getOperand(1);
3176 SDLoc DL(Op);
3177 MVT XLenVT = Subtarget.getXLenVT();
3178 MVT VT = Op.getSimpleValueType();
3179 unsigned NumElts = VT.getVectorNumElements();
3180 ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Op.getNode());
3181
3182 MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
3183
3184 auto [TrueMask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
3185
3186 if (SVN->isSplat()) {
3187 const int Lane = SVN->getSplatIndex();
3188 if (Lane >= 0) {
3189 MVT SVT = VT.getVectorElementType();
3190
3191 // Turn splatted vector load into a strided load with an X0 stride.
3192 SDValue V = V1;
3193 // Peek through CONCAT_VECTORS as VectorCombine can concat a vector
3194 // with undef.
3195 // FIXME: Peek through INSERT_SUBVECTOR, EXTRACT_SUBVECTOR, bitcasts?
3196 int Offset = Lane;
3197 if (V.getOpcode() == ISD::CONCAT_VECTORS) {
3198 int OpElements =
3199 V.getOperand(0).getSimpleValueType().getVectorNumElements();
3200 V = V.getOperand(Offset / OpElements);
3201 Offset %= OpElements;
3202 }
3203
3204 // We need to ensure the load isn't atomic or volatile.
3205 if (ISD::isNormalLoad(V.getNode()) && cast<LoadSDNode>(V)->isSimple()) {
3206 auto *Ld = cast<LoadSDNode>(V);
3207 Offset *= SVT.getStoreSize();
3208 SDValue NewAddr = DAG.getMemBasePlusOffset(Ld->getBasePtr(),
3209 TypeSize::Fixed(Offset), DL);
3210
3211 // If this is SEW=64 on RV32, use a strided load with a stride of x0.
3212 if (SVT.isInteger() && SVT.bitsGT(XLenVT)) {
3213 SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
3214 SDValue IntID =
3215 DAG.getTargetConstant(Intrinsic::riscv_vlse, DL, XLenVT);
3216 SDValue Ops[] = {Ld->getChain(),
3217 IntID,
3218 DAG.getUNDEF(ContainerVT),
3219 NewAddr,
3220 DAG.getRegister(RISCV::X0, XLenVT),
3221 VL};
3222 SDValue NewLoad = DAG.getMemIntrinsicNode(
3223 ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops, SVT,
3224 DAG.getMachineFunction().getMachineMemOperand(
3225 Ld->getMemOperand(), Offset, SVT.getStoreSize()));
3226 DAG.makeEquivalentMemoryOrdering(Ld, NewLoad);
3227 return convertFromScalableVector(VT, NewLoad, DAG, Subtarget);
3228 }
3229
3230 // Otherwise use a scalar load and splat. This will give the best
3231 // opportunity to fold a splat into the operation. ISel can turn it into
3232 // the x0 strided load if we aren't able to fold away the select.
3233 if (SVT.isFloatingPoint())
3234 V = DAG.getLoad(SVT, DL, Ld->getChain(), NewAddr,
3235 Ld->getPointerInfo().getWithOffset(Offset),
3236 Ld->getOriginalAlign(),
3237 Ld->getMemOperand()->getFlags());
3238 else
3239 V = DAG.getExtLoad(ISD::SEXTLOAD, DL, XLenVT, Ld->getChain(), NewAddr,
3240 Ld->getPointerInfo().getWithOffset(Offset), SVT,
3241 Ld->getOriginalAlign(),
3242 Ld->getMemOperand()->getFlags());
3243 DAG.makeEquivalentMemoryOrdering(Ld, V);
3244
3245 unsigned Opc =
3246 VT.isFloatingPoint() ? RISCVISD::VFMV_V_F_VL : RISCVISD::VMV_V_X_VL;
3247 SDValue Splat =
3248 DAG.getNode(Opc, DL, ContainerVT, DAG.getUNDEF(ContainerVT), V, VL);
3249 return convertFromScalableVector(VT, Splat, DAG, Subtarget);
3250 }
3251
3252 V1 = convertToScalableVector(ContainerVT, V1, DAG, Subtarget);
3253 assert(Lane < (int)NumElts && "Unexpected lane!");
3254 SDValue Gather = DAG.getNode(RISCVISD::VRGATHER_VX_VL, DL, ContainerVT,
3255 V1, DAG.getConstant(Lane, DL, XLenVT),
3256 DAG.getUNDEF(ContainerVT), TrueMask, VL);
3257 return convertFromScalableVector(VT, Gather, DAG, Subtarget);
3258 }
3259 }
3260
3261 ArrayRef<int> Mask = SVN->getMask();
3262
3263 if (SDValue V =
3264 lowerVECTOR_SHUFFLEAsVSlidedown(DL, VT, V1, V2, Mask, Subtarget, DAG))
3265 return V;
3266
3267 // Lower rotations to a SLIDEDOWN and a SLIDEUP. One of the source vectors
3268 // may be undef, in which case a single SLIDEDOWN or SLIDEUP suffices.
3269 int LoSrc, HiSrc;
3270 int Rotation = isElementRotate(LoSrc, HiSrc, Mask);
3271 if (Rotation > 0) {
3272 SDValue LoV, HiV;
3273 if (LoSrc >= 0) {
3274 LoV = LoSrc == 0 ? V1 : V2;
3275 LoV = convertToScalableVector(ContainerVT, LoV, DAG, Subtarget);
3276 }
3277 if (HiSrc >= 0) {
3278 HiV = HiSrc == 0 ? V1 : V2;
3279 HiV = convertToScalableVector(ContainerVT, HiV, DAG, Subtarget);
3280 }
3281
3282 // We found a rotation. We need to slide HiV down by Rotation. Then we need
3283 // to slide LoV up by (NumElts - Rotation).
3284 unsigned InvRotate = NumElts - Rotation;
3285
3286 SDValue Res = DAG.getUNDEF(ContainerVT);
3287 if (HiV) {
3288 // If we are doing a SLIDEDOWN+SLIDEUP, reduce the VL for the SLIDEDOWN.
3289 // FIXME: If we are only doing a SLIDEDOWN, don't reduce the VL as it
3290 // causes multiple vsetvlis in some test cases such as lowering
3291 // reduce.mul
3292 SDValue DownVL = VL;
3293 if (LoV)
3294 DownVL = DAG.getConstant(InvRotate, DL, XLenVT);
3295 Res = getVSlidedown(DAG, Subtarget, DL, ContainerVT, Res, HiV,
3296 DAG.getConstant(Rotation, DL, XLenVT), TrueMask,
3297 DownVL);
3298 }
3299 if (LoV)
3300 Res = getVSlideup(DAG, Subtarget, DL, ContainerVT, Res, LoV,
3301 DAG.getConstant(InvRotate, DL, XLenVT), TrueMask, VL,
3302 RISCVII::TAIL_AGNOSTIC);
3303
3304 return convertFromScalableVector(VT, Res, DAG, Subtarget);
3305 }
3306
3307 if (SDValue V = lowerVECTOR_SHUFFLEAsVNSRL(
3308 DL, VT, ContainerVT, V1, V2, TrueMask, VL, Mask, Subtarget, DAG))
3309 return V;
3310
3311 // Detect an interleave shuffle and lower to
3312 // (vmaccu.vx (vwaddu.vx lohalf(V1), lohalf(V2)), lohalf(V2), (2^eltbits - 1))
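  // The arithmetic behind this (a sketch): for each element pair a (from V1)
  // and b (from V2), the widened value zext(a) + (2^eltbits) * zext(b) puts a
  // in the low half and b in the high half of a double-width lane, which is
  // exactly the interleaved pair when reinterpreted in the original element
  // type. Since (2^eltbits) * b == b + (2^eltbits - 1) * b, the widening add
  // below contributes the first copy of b and the widening multiply by the
  // all-ones constant contributes the remaining copies.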
3313 bool SwapSources;
3314 if (isInterleaveShuffle(Mask, VT, SwapSources, Subtarget)) {
3315 // Swap sources if needed.
3316 if (SwapSources)
3317 std::swap(V1, V2);
3318
3319 // Extract the lower half of the vectors.
3320 MVT HalfVT = VT.getHalfNumVectorElementsVT();
3321 V1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, V1,
3322 DAG.getConstant(0, DL, XLenVT));
3323 V2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, V2,
3324 DAG.getConstant(0, DL, XLenVT));
3325
3326 // Double the element width and halve the number of elements in an int type.
3327 unsigned EltBits = VT.getScalarSizeInBits();
3328 MVT WideIntEltVT = MVT::getIntegerVT(EltBits * 2);
3329 MVT WideIntVT =
3330 MVT::getVectorVT(WideIntEltVT, VT.getVectorNumElements() / 2);
3331 // Convert this to a scalable vector. We need to base this on the
3332 // destination size to ensure there's always a type with a smaller LMUL.
3333 MVT WideIntContainerVT =
3334 getContainerForFixedLengthVector(DAG, WideIntVT, Subtarget);
3335
3336 // Convert sources to scalable vectors with the same element count as the
3337 // larger type.
3338 MVT HalfContainerVT = MVT::getVectorVT(
3339 VT.getVectorElementType(), WideIntContainerVT.getVectorElementCount());
3340 V1 = convertToScalableVector(HalfContainerVT, V1, DAG, Subtarget);
3341 V2 = convertToScalableVector(HalfContainerVT, V2, DAG, Subtarget);
3342
3343 // Cast sources to integer.
3344 MVT IntEltVT = MVT::getIntegerVT(EltBits);
3345 MVT IntHalfVT =
3346 MVT::getVectorVT(IntEltVT, HalfContainerVT.getVectorElementCount());
3347 V1 = DAG.getBitcast(IntHalfVT, V1);
3348 V2 = DAG.getBitcast(IntHalfVT, V2);
3349
3350 // Freeze V2 since we use it twice and we need to be sure that the add and
3351 // multiply see the same value.
3352 V2 = DAG.getFreeze(V2);
3353
3354 // Recreate TrueMask using the widened type's element count.
3355 TrueMask = getAllOnesMask(HalfContainerVT, VL, DL, DAG);
3356
3357 // Widen V1 and V2 with 0s and add one copy of V2 to V1.
3358 SDValue Add =
3359 DAG.getNode(RISCVISD::VWADDU_VL, DL, WideIntContainerVT, V1, V2,
3360 DAG.getUNDEF(WideIntContainerVT), TrueMask, VL);
3361 // Create 2^eltbits - 1 copies of V2 by multiplying by the largest integer.
3362 SDValue Multiplier = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, IntHalfVT,
3363 DAG.getUNDEF(IntHalfVT),
3364 DAG.getAllOnesConstant(DL, XLenVT), VL);
3365 SDValue WidenMul =
3366 DAG.getNode(RISCVISD::VWMULU_VL, DL, WideIntContainerVT, V2, Multiplier,
3367 DAG.getUNDEF(WideIntContainerVT), TrueMask, VL);
3368 // Add the new copies to our previous addition giving us 2^eltbits copies of
3369 // V2. This is equivalent to shifting V2 left by eltbits. This should
3370 // combine with the vwmulu.vv above to form vwmaccu.vv.
3371 Add = DAG.getNode(RISCVISD::ADD_VL, DL, WideIntContainerVT, Add, WidenMul,
3372 DAG.getUNDEF(WideIntContainerVT), TrueMask, VL);
3373 // Cast back to ContainerVT. We need to re-create a new ContainerVT in case
3374 // WideIntContainerVT is a larger fractional LMUL than implied by the fixed
3375 // vector VT.
3376 ContainerVT =
3377 MVT::getVectorVT(VT.getVectorElementType(),
3378 WideIntContainerVT.getVectorElementCount() * 2);
3379 Add = DAG.getBitcast(ContainerVT, Add);
3380 return convertFromScalableVector(VT, Add, DAG, Subtarget);
3381 }
3382
3383 // Detect shuffles which can be re-expressed as vector selects; these are
3384 // shuffles in which each element in the destination is taken from an element
3385 // at the corresponding index in one of the two source vectors.
3386 bool IsSelect = all_of(enumerate(Mask), [&](const auto &MaskIdx) {
3387 int MaskIndex = MaskIdx.value();
3388 return MaskIndex < 0 || MaskIdx.index() == (unsigned)MaskIndex % NumElts;
3389 });
3390
3391 assert(!V1.isUndef() && "Unexpected shuffle canonicalization");
3392
3393 SmallVector<SDValue> MaskVals;
3394 // As a backup, shuffles can be lowered via a vrgather instruction, possibly
3395 // merged with a second vrgather.
3396 SmallVector<SDValue> GatherIndicesLHS, GatherIndicesRHS;
3397
3398 // By default we preserve the original operand order, and use a mask to
3399 // select LHS as true and RHS as false. However, since RVV vector selects may
3400 // feature splats but only on the LHS, we may choose to invert our mask and
3401 // instead select between RHS and LHS.
3402 bool SwapOps = DAG.isSplatValue(V2) && !DAG.isSplatValue(V1);
3403 bool InvertMask = IsSelect == SwapOps;
3404
3405 // Keep track of which non-undef indices are used by each LHS/RHS shuffle
3406 // half.
3407 DenseMap<int, unsigned> LHSIndexCounts, RHSIndexCounts;
3408
3409 // Now construct the mask that will be used by the vselect or blended
3410 // vrgather operation. For vrgathers, construct the appropriate indices into
3411 // each vector.
3412 for (int MaskIndex : Mask) {
3413 bool SelectMaskVal = (MaskIndex < (int)NumElts) ^ InvertMask;
3414 MaskVals.push_back(DAG.getConstant(SelectMaskVal, DL, XLenVT));
3415 if (!IsSelect) {
3416 bool IsLHSOrUndefIndex = MaskIndex < (int)NumElts;
3417 GatherIndicesLHS.push_back(IsLHSOrUndefIndex && MaskIndex >= 0
3418 ? DAG.getConstant(MaskIndex, DL, XLenVT)
3419 : DAG.getUNDEF(XLenVT));
3420 GatherIndicesRHS.push_back(
3421 IsLHSOrUndefIndex ? DAG.getUNDEF(XLenVT)
3422 : DAG.getConstant(MaskIndex - NumElts, DL, XLenVT));
3423 if (IsLHSOrUndefIndex && MaskIndex >= 0)
3424 ++LHSIndexCounts[MaskIndex];
3425 if (!IsLHSOrUndefIndex)
3426 ++RHSIndexCounts[MaskIndex - NumElts];
3427 }
3428 }
3429
3430 if (SwapOps) {
3431 std::swap(V1, V2);
3432 std::swap(GatherIndicesLHS, GatherIndicesRHS);
3433 }
3434
3435 assert(MaskVals.size() == NumElts && "Unexpected select-like shuffle");
3436 MVT MaskVT = MVT::getVectorVT(MVT::i1, NumElts);
3437 SDValue SelectMask = DAG.getBuildVector(MaskVT, DL, MaskVals);
3438
3439 if (IsSelect)
3440 return DAG.getNode(ISD::VSELECT, DL, VT, SelectMask, V1, V2);
3441
3442 if (VT.getScalarSizeInBits() == 8 && VT.getVectorNumElements() > 256) {
3443 // On such a large vector we're unable to use i8 as the index type.
3444 // FIXME: We could promote the index to i16 and use vrgatherei16, but that
3445 // may involve vector splitting if we're already at LMUL=8, or our
3446 // user-supplied maximum fixed-length LMUL.
3447 return SDValue();
3448 }
3449
3450 unsigned GatherVXOpc = RISCVISD::VRGATHER_VX_VL;
3451 unsigned GatherVVOpc = RISCVISD::VRGATHER_VV_VL;
3452 MVT IndexVT = VT.changeTypeToInteger();
3453 // Since we can't introduce illegal index types at this stage, use i16 and
3454 // vrgatherei16 if the corresponding index type for plain vrgather is greater
3455 // than XLenVT.
3456 if (IndexVT.getScalarType().bitsGT(XLenVT)) {
3457 GatherVVOpc = RISCVISD::VRGATHEREI16_VV_VL;
3458 IndexVT = IndexVT.changeVectorElementType(MVT::i16);
3459 }
3460
3461 MVT IndexContainerVT =
3462 ContainerVT.changeVectorElementType(IndexVT.getScalarType());
3463
3464 SDValue Gather;
3465 // TODO: This doesn't trigger for i64 vectors on RV32, since there we
3466 // encounter a bitcasted BUILD_VECTOR with low/high i32 values.
3467 if (SDValue SplatValue = DAG.getSplatValue(V1, /*LegalTypes*/ true)) {
3468 Gather = lowerScalarSplat(SDValue(), SplatValue, VL, ContainerVT, DL, DAG,
3469 Subtarget);
3470 } else {
3471 V1 = convertToScalableVector(ContainerVT, V1, DAG, Subtarget);
3472 // If only one index is used, we can use a "splat" vrgather.
3473 // TODO: We can splat the most-common index and fix-up any stragglers, if
3474 // that's beneficial.
3475 if (LHSIndexCounts.size() == 1) {
3476 int SplatIndex = LHSIndexCounts.begin()->getFirst();
3477 Gather = DAG.getNode(GatherVXOpc, DL, ContainerVT, V1,
3478 DAG.getConstant(SplatIndex, DL, XLenVT),
3479 DAG.getUNDEF(ContainerVT), TrueMask, VL);
3480 } else {
3481 SDValue LHSIndices = DAG.getBuildVector(IndexVT, DL, GatherIndicesLHS);
3482 LHSIndices =
3483 convertToScalableVector(IndexContainerVT, LHSIndices, DAG, Subtarget);
3484
3485 Gather = DAG.getNode(GatherVVOpc, DL, ContainerVT, V1, LHSIndices,
3486 DAG.getUNDEF(ContainerVT), TrueMask, VL);
3487 }
3488 }
3489
3490 // If a second vector operand is used by this shuffle, blend it in with an
3491 // additional vrgather.
3492 if (!V2.isUndef()) {
3493 V2 = convertToScalableVector(ContainerVT, V2, DAG, Subtarget);
3494
3495 MVT MaskContainerVT = ContainerVT.changeVectorElementType(MVT::i1);
3496 SelectMask =
3497 convertToScalableVector(MaskContainerVT, SelectMask, DAG, Subtarget);
3498
3499 // If only one index is used, we can use a "splat" vrgather.
3500 // TODO: We can splat the most-common index and fix-up any stragglers, if
3501 // that's beneficial.
3502 if (RHSIndexCounts.size() == 1) {
3503 int SplatIndex = RHSIndexCounts.begin()->getFirst();
3504 Gather = DAG.getNode(GatherVXOpc, DL, ContainerVT, V2,
3505 DAG.getConstant(SplatIndex, DL, XLenVT), Gather,
3506 SelectMask, VL);
3507 } else {
3508 SDValue RHSIndices = DAG.getBuildVector(IndexVT, DL, GatherIndicesRHS);
3509 RHSIndices =
3510 convertToScalableVector(IndexContainerVT, RHSIndices, DAG, Subtarget);
3511 Gather = DAG.getNode(GatherVVOpc, DL, ContainerVT, V2, RHSIndices, Gather,
3512 SelectMask, VL);
3513 }
3514 }
3515
3516 return convertFromScalableVector(VT, Gather, DAG, Subtarget);
3517 }
3518
3519 bool RISCVTargetLowering::isShuffleMaskLegal(ArrayRef<int> M, EVT VT) const {
3520 // Support splats for any type. These should type legalize well.
3521 if (ShuffleVectorSDNode::isSplatMask(M.data(), VT))
3522 return true;
3523
3524 // Only support legal VTs for other shuffles for now.
3525 if (!isTypeLegal(VT))
3526 return false;
3527
3528 MVT SVT = VT.getSimpleVT();
3529
3530 bool SwapSources;
3531 int LoSrc, HiSrc;
3532 return (isElementRotate(LoSrc, HiSrc, M) > 0) ||
3533 isInterleaveShuffle(M, SVT, SwapSources, Subtarget);
3534 }
3535
3536 // Lower CTLZ_ZERO_UNDEF or CTTZ_ZERO_UNDEF by converting to FP and extracting
3537 // the exponent.
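// As a rough worked example for i32 elements converted to f64: an input of 16
// converts to a double whose biased exponent field is 1023 + 4 == 1027, so
// ctlz == (1023 + 31) - 1027 == 27, and after isolating the lowest set bit
// with x & -x, cttz == 1027 - 1023 == 4.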
3538 SDValue
3539 RISCVTargetLowering::lowerCTLZ_CTTZ_ZERO_UNDEF(SDValue Op,
3540 SelectionDAG &DAG) const {
3541 MVT VT = Op.getSimpleValueType();
3542 unsigned EltSize = VT.getScalarSizeInBits();
3543 SDValue Src = Op.getOperand(0);
3544 SDLoc DL(Op);
3545
3546 // If possible, we choose an FP type that can represent the value exactly.
3547 // Otherwise, we use a round-towards-zero conversion for a correct exponent.
3548 // TODO: Use f16 for i8 when possible?
3549 MVT FloatEltVT = (EltSize >= 32) ? MVT::f64 : MVT::f32;
3550 if (!isTypeLegal(MVT::getVectorVT(FloatEltVT, VT.getVectorElementCount())))
3551 FloatEltVT = MVT::f32;
3552 MVT FloatVT = MVT::getVectorVT(FloatEltVT, VT.getVectorElementCount());
3553
3554 // Legal types should have been checked in the RISCVTargetLowering
3555 // constructor.
3556 // TODO: Splitting may make sense in some cases.
3557 assert(DAG.getTargetLoweringInfo().isTypeLegal(FloatVT) &&
3558 "Expected legal float type!");
3559
3560 // For CTTZ_ZERO_UNDEF, we need to extract the lowest set bit using X & -X.
3561 // The trailing zero count is equal to log2 of this single bit value.
3562 if (Op.getOpcode() == ISD::CTTZ_ZERO_UNDEF) {
3563 SDValue Neg = DAG.getNegative(Src, DL, VT);
3564 Src = DAG.getNode(ISD::AND, DL, VT, Src, Neg);
3565 }
3566
3567 // We have a legal FP type, convert to it.
3568 SDValue FloatVal;
3569 if (FloatVT.bitsGT(VT)) {
3570 FloatVal = DAG.getNode(ISD::UINT_TO_FP, DL, FloatVT, Src);
3571 } else {
3572 // Use RTZ to avoid rounding influencing exponent of FloatVal.
3573 MVT ContainerVT = VT;
3574 if (VT.isFixedLengthVector()) {
3575 ContainerVT = getContainerForFixedLengthVector(VT);
3576 Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
3577 }
3578
3579 auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
3580 SDValue RTZRM =
3581 DAG.getTargetConstant(RISCVFPRndMode::RTZ, DL, Subtarget.getXLenVT());
3582 MVT ContainerFloatVT =
3583 MVT::getVectorVT(FloatEltVT, ContainerVT.getVectorElementCount());
3584 FloatVal = DAG.getNode(RISCVISD::VFCVT_RM_F_XU_VL, DL, ContainerFloatVT,
3585 Src, Mask, RTZRM, VL);
3586 if (VT.isFixedLengthVector())
3587 FloatVal = convertFromScalableVector(FloatVT, FloatVal, DAG, Subtarget);
3588 }
3589 // Bitcast to integer and shift the exponent to the LSB.
3590 EVT IntVT = FloatVT.changeVectorElementTypeToInteger();
3591 SDValue Bitcast = DAG.getBitcast(IntVT, FloatVal);
3592 unsigned ShiftAmt = FloatEltVT == MVT::f64 ? 52 : 23;
3593 SDValue Exp = DAG.getNode(ISD::SRL, DL, IntVT, Bitcast,
3594 DAG.getConstant(ShiftAmt, DL, IntVT));
3595 // Restore to the original type. The truncate after SRL will form a vnsrl.
3596 if (IntVT.bitsLT(VT))
3597 Exp = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Exp);
3598 else if (IntVT.bitsGT(VT))
3599 Exp = DAG.getNode(ISD::TRUNCATE, DL, VT, Exp);
3600 // The exponent contains log2 of the value in biased form.
3601 unsigned ExponentBias = FloatEltVT == MVT::f64 ? 1023 : 127;
3602
3603 // For trailing zeros, we just need to subtract the bias.
3604 if (Op.getOpcode() == ISD::CTTZ_ZERO_UNDEF)
3605 return DAG.getNode(ISD::SUB, DL, VT, Exp,
3606 DAG.getConstant(ExponentBias, DL, VT));
3607
3608 // For leading zeros, we need to remove the bias and convert from log2 to
3609 // leading zeros. We can do this by subtracting from (Bias + (EltSize - 1)).
3610 unsigned Adjust = ExponentBias + (EltSize - 1);
3611 SDValue Res =
3612 DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(Adjust, DL, VT), Exp);
3613 // For a zero input the result above equals Adjust, which is greater than
3614 // EltSize. Hence, we can clamp with min(Res, EltSize) for CTLZ.
3615 if (Op.getOpcode() == ISD::CTLZ)
3616 Res = DAG.getNode(ISD::UMIN, DL, VT, Res, DAG.getConstant(EltSize, DL, VT));
3617 return Res;
3618 }
3619
3620 // While RVV has alignment restrictions, we should always be able to load as a
3621 // legal equivalently-sized byte-typed vector instead. This method is
3622 // responsible for re-expressing an ISD::LOAD via a correctly-aligned type. If
3623 // the load is already correctly-aligned, it returns SDValue().
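// For example (illustrative): a v8i16 load with only byte alignment is
// re-issued as a v16i8 load of the same chain, pointer and flags, and the
// loaded value is bitcast back to v8i16.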
3624 SDValue RISCVTargetLowering::expandUnalignedRVVLoad(SDValue Op,
3625 SelectionDAG &DAG) const {
3626 auto *Load = cast<LoadSDNode>(Op);
3627 assert(Load && Load->getMemoryVT().isVector() && "Expected vector load");
3628
3629 if (allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
3630 Load->getMemoryVT(),
3631 *Load->getMemOperand()))
3632 return SDValue();
3633
3634 SDLoc DL(Op);
3635 MVT VT = Op.getSimpleValueType();
3636 unsigned EltSizeBits = VT.getScalarSizeInBits();
3637 assert((EltSizeBits == 16 || EltSizeBits == 32 || EltSizeBits == 64) &&
3638 "Unexpected unaligned RVV load type");
3639 MVT NewVT =
3640 MVT::getVectorVT(MVT::i8, VT.getVectorElementCount() * (EltSizeBits / 8));
3641 assert(NewVT.isValid() &&
3642 "Expecting equally-sized RVV vector types to be legal");
3643 SDValue L = DAG.getLoad(NewVT, DL, Load->getChain(), Load->getBasePtr(),
3644 Load->getPointerInfo(), Load->getOriginalAlign(),
3645 Load->getMemOperand()->getFlags());
3646 return DAG.getMergeValues({DAG.getBitcast(VT, L), L.getValue(1)}, DL);
3647 }
3648
3649 // While RVV has alignment restrictions, we should always be able to store as a
3650 // legal equivalently-sized byte-typed vector instead. This method is
3651 // responsible for re-expressing an ISD::STORE via a correctly-aligned type. It
3652 // returns SDValue() if the store is already correctly aligned.
3653 SDValue RISCVTargetLowering::expandUnalignedRVVStore(SDValue Op,
3654 SelectionDAG &DAG) const {
3655 auto *Store = cast<StoreSDNode>(Op);
3656 assert(Store && Store->getValue().getValueType().isVector() &&
3657 "Expected vector store");
3658
3659 if (allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
3660 Store->getMemoryVT(),
3661 *Store->getMemOperand()))
3662 return SDValue();
3663
3664 SDLoc DL(Op);
3665 SDValue StoredVal = Store->getValue();
3666 MVT VT = StoredVal.getSimpleValueType();
3667 unsigned EltSizeBits = VT.getScalarSizeInBits();
3668 assert((EltSizeBits == 16 || EltSizeBits == 32 || EltSizeBits == 64) &&
3669 "Unexpected unaligned RVV store type");
3670 MVT NewVT =
3671 MVT::getVectorVT(MVT::i8, VT.getVectorElementCount() * (EltSizeBits / 8));
3672 assert(NewVT.isValid() &&
3673 "Expecting equally-sized RVV vector types to be legal");
3674 StoredVal = DAG.getBitcast(NewVT, StoredVal);
3675 return DAG.getStore(Store->getChain(), DL, StoredVal, Store->getBasePtr(),
3676 Store->getPointerInfo(), Store->getOriginalAlign(),
3677 Store->getMemOperand()->getFlags());
3678 }
3679
3680 static SDValue lowerConstant(SDValue Op, SelectionDAG &DAG,
3681 const RISCVSubtarget &Subtarget) {
3682 assert(Op.getValueType() == MVT::i64 && "Unexpected VT");
3683
3684 int64_t Imm = cast<ConstantSDNode>(Op)->getSExtValue();
3685
3686 // All simm32 constants should be handled by isel.
3687 // NOTE: The getMaxBuildIntsCost call below should return a value >= 2 making
3688 // this check redundant, but small immediates are common so checking for them
3689 // early should improve compile time.
3690 if (isInt<32>(Imm))
3691 return Op;
3692
3693 // We only need to cost the immediate if constant pool lowering is enabled.
3694 if (!Subtarget.useConstantPoolForLargeInts())
3695 return Op;
3696
3697 RISCVMatInt::InstSeq Seq =
3698 RISCVMatInt::generateInstSeq(Imm, Subtarget.getFeatureBits());
3699 if (Seq.size() <= Subtarget.getMaxBuildIntsCost())
3700 return Op;
3701
3702 // Expand to a constant pool using the default expansion code.
3703 return SDValue();
3704 }
3705
3706 static SDValue LowerATOMIC_FENCE(SDValue Op, SelectionDAG &DAG) {
3707 SDLoc dl(Op);
3708 SyncScope::ID FenceSSID =
3709 static_cast<SyncScope::ID>(Op.getConstantOperandVal(2));
3710
3711 // singlethread fences only synchronize with signal handlers on the same
3712 // thread and thus only need to preserve instruction order, not actually
3713 // enforce memory ordering.
3714 if (FenceSSID == SyncScope::SingleThread)
3715 // MEMBARRIER is a compiler barrier; it codegens to a no-op.
3716 return DAG.getNode(ISD::MEMBARRIER, dl, MVT::Other, Op.getOperand(0));
3717
3718 return Op;
3719 }
3720
3721 SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
3722 SelectionDAG &DAG) const {
3723 switch (Op.getOpcode()) {
3724 default:
3725 report_fatal_error("unimplemented operand");
3726 case ISD::ATOMIC_FENCE:
3727 return LowerATOMIC_FENCE(Op, DAG);
3728 case ISD::GlobalAddress:
3729 return lowerGlobalAddress(Op, DAG);
3730 case ISD::BlockAddress:
3731 return lowerBlockAddress(Op, DAG);
3732 case ISD::ConstantPool:
3733 return lowerConstantPool(Op, DAG);
3734 case ISD::JumpTable:
3735 return lowerJumpTable(Op, DAG);
3736 case ISD::GlobalTLSAddress:
3737 return lowerGlobalTLSAddress(Op, DAG);
3738 case ISD::Constant:
3739 return lowerConstant(Op, DAG, Subtarget);
3740 case ISD::SELECT:
3741 return lowerSELECT(Op, DAG);
3742 case ISD::BRCOND:
3743 return lowerBRCOND(Op, DAG);
3744 case ISD::VASTART:
3745 return lowerVASTART(Op, DAG);
3746 case ISD::FRAMEADDR:
3747 return lowerFRAMEADDR(Op, DAG);
3748 case ISD::RETURNADDR:
3749 return lowerRETURNADDR(Op, DAG);
3750 case ISD::SHL_PARTS:
3751 return lowerShiftLeftParts(Op, DAG);
3752 case ISD::SRA_PARTS:
3753 return lowerShiftRightParts(Op, DAG, true);
3754 case ISD::SRL_PARTS:
3755 return lowerShiftRightParts(Op, DAG, false);
3756 case ISD::BITCAST: {
3757 SDLoc DL(Op);
3758 EVT VT = Op.getValueType();
3759 SDValue Op0 = Op.getOperand(0);
3760 EVT Op0VT = Op0.getValueType();
3761 MVT XLenVT = Subtarget.getXLenVT();
3762 if (VT == MVT::f16 && Op0VT == MVT::i16 &&
3763 Subtarget.hasStdExtZfhOrZfhmin()) {
3764 SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Op0);
3765 SDValue FPConv = DAG.getNode(RISCVISD::FMV_H_X, DL, MVT::f16, NewOp0);
3766 return FPConv;
3767 }
3768 if (VT == MVT::f32 && Op0VT == MVT::i32 && Subtarget.is64Bit() &&
3769 Subtarget.hasStdExtF()) {
3770 SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, Op0);
3771 SDValue FPConv =
3772 DAG.getNode(RISCVISD::FMV_W_X_RV64, DL, MVT::f32, NewOp0);
3773 return FPConv;
3774 }
3775
3776 // Consider other scalar<->scalar casts as legal if the types are legal.
3777 // Otherwise expand them.
3778 if (!VT.isVector() && !Op0VT.isVector()) {
3779 if (isTypeLegal(VT) && isTypeLegal(Op0VT))
3780 return Op;
3781 return SDValue();
3782 }
3783
3784 assert(!VT.isScalableVector() && !Op0VT.isScalableVector() &&
3785 "Unexpected types");
3786
3787 if (VT.isFixedLengthVector()) {
3788 // We can handle fixed length vector bitcasts with a simple replacement
3789 // in isel.
3790 if (Op0VT.isFixedLengthVector())
3791 return Op;
3792 // When bitcasting from scalar to fixed-length vector, insert the scalar
3793 // into a one-element vector of the result type, and perform a vector
3794 // bitcast.
3795 if (!Op0VT.isVector()) {
3796 EVT BVT = EVT::getVectorVT(*DAG.getContext(), Op0VT, 1);
3797 if (!isTypeLegal(BVT))
3798 return SDValue();
3799 return DAG.getBitcast(VT, DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, BVT,
3800 DAG.getUNDEF(BVT), Op0,
3801 DAG.getConstant(0, DL, XLenVT)));
3802 }
3803 return SDValue();
3804 }
3805 // Custom-legalize bitcasts from fixed-length vector types to scalar types
3806 // thus: bitcast the vector to a one-element vector type whose element type
3807 // is the same as the result type, and extract the first element.
3808 if (!VT.isVector() && Op0VT.isFixedLengthVector()) {
3809 EVT BVT = EVT::getVectorVT(*DAG.getContext(), VT, 1);
3810 if (!isTypeLegal(BVT))
3811 return SDValue();
3812 SDValue BVec = DAG.getBitcast(BVT, Op0);
3813 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, BVec,
3814 DAG.getConstant(0, DL, XLenVT));
3815 }
3816 return SDValue();
3817 }
3818 case ISD::INTRINSIC_WO_CHAIN:
3819 return LowerINTRINSIC_WO_CHAIN(Op, DAG);
3820 case ISD::INTRINSIC_W_CHAIN:
3821 return LowerINTRINSIC_W_CHAIN(Op, DAG);
3822 case ISD::INTRINSIC_VOID:
3823 return LowerINTRINSIC_VOID(Op, DAG);
3824 case ISD::BITREVERSE: {
3825 MVT VT = Op.getSimpleValueType();
3826 SDLoc DL(Op);
3827 assert(Subtarget.hasStdExtZbkb() && "Unexpected custom legalization");
3828 assert(Op.getOpcode() == ISD::BITREVERSE && "Unexpected opcode");
3829 // Expand bitreverse to a bswap(rev8) followed by brev8.
3830 SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT, Op.getOperand(0));
3831 return DAG.getNode(RISCVISD::BREV8, DL, VT, BSwap);
3832 }
3833 case ISD::TRUNCATE:
3834 // Only custom-lower vector truncates
3835 if (!Op.getSimpleValueType().isVector())
3836 return Op;
3837 return lowerVectorTruncLike(Op, DAG);
3838 case ISD::ANY_EXTEND:
3839 case ISD::ZERO_EXTEND:
3840 if (Op.getOperand(0).getValueType().isVector() &&
3841 Op.getOperand(0).getValueType().getVectorElementType() == MVT::i1)
3842 return lowerVectorMaskExt(Op, DAG, /*ExtVal*/ 1);
3843 return lowerFixedLengthVectorExtendToRVV(Op, DAG, RISCVISD::VZEXT_VL);
3844 case ISD::SIGN_EXTEND:
3845 if (Op.getOperand(0).getValueType().isVector() &&
3846 Op.getOperand(0).getValueType().getVectorElementType() == MVT::i1)
3847 return lowerVectorMaskExt(Op, DAG, /*ExtVal*/ -1);
3848 return lowerFixedLengthVectorExtendToRVV(Op, DAG, RISCVISD::VSEXT_VL);
3849 case ISD::SPLAT_VECTOR_PARTS:
3850 return lowerSPLAT_VECTOR_PARTS(Op, DAG);
3851 case ISD::INSERT_VECTOR_ELT:
3852 return lowerINSERT_VECTOR_ELT(Op, DAG);
3853 case ISD::EXTRACT_VECTOR_ELT:
3854 return lowerEXTRACT_VECTOR_ELT(Op, DAG);
3855 case ISD::VSCALE: {
3856 MVT VT = Op.getSimpleValueType();
3857 SDLoc DL(Op);
3858 SDValue VLENB = DAG.getNode(RISCVISD::READ_VLENB, DL, VT);
3859 // We define our scalable vector types for LMUL=1 to use a 64-bit known
3860 // minimum size, e.g. <vscale x 2 x i32>. VLENB is in bytes, so we calculate
3861 // vscale as VLENB / 8.
3862 static_assert(RISCV::RVVBitsPerBlock == 64, "Unexpected bits per block!");
3863 if (Subtarget.getRealMinVLen() < RISCV::RVVBitsPerBlock)
3864 report_fatal_error("Support for VLEN==32 is incomplete.");
3865 // We assume VLENB is a multiple of 8. We manually choose the best shift
3866 // here because SimplifyDemandedBits isn't always able to simplify it.
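    // For instance (illustrative): vscale * 4 is (VLENB >> 3) << 2, emitted
    // directly as VLENB >> 1; vscale * 16 becomes VLENB << 1; and vscale * 24
    // takes the multiple-of-8 path below and becomes VLENB * 3.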
3867 uint64_t Val = Op.getConstantOperandVal(0);
3868 if (isPowerOf2_64(Val)) {
3869 uint64_t Log2 = Log2_64(Val);
3870 if (Log2 < 3)
3871 return DAG.getNode(ISD::SRL, DL, VT, VLENB,
3872 DAG.getConstant(3 - Log2, DL, VT));
3873 if (Log2 > 3)
3874 return DAG.getNode(ISD::SHL, DL, VT, VLENB,
3875 DAG.getConstant(Log2 - 3, DL, VT));
3876 return VLENB;
3877 }
3878 // If the multiplier is a multiple of 8, scale it down to avoid needing
3879 // to shift the VLENB value.
3880 if ((Val % 8) == 0)
3881 return DAG.getNode(ISD::MUL, DL, VT, VLENB,
3882 DAG.getConstant(Val / 8, DL, VT));
3883
3884 SDValue VScale = DAG.getNode(ISD::SRL, DL, VT, VLENB,
3885 DAG.getConstant(3, DL, VT));
3886 return DAG.getNode(ISD::MUL, DL, VT, VScale, Op.getOperand(0));
3887 }
3888 case ISD::FPOWI: {
3889 // Custom promote f16 powi with illegal i32 integer type on RV64. Once
3890 // promoted this will be legalized into a libcall by LegalizeIntegerTypes.
3891 if (Op.getValueType() == MVT::f16 && Subtarget.is64Bit() &&
3892 Op.getOperand(1).getValueType() == MVT::i32) {
3893 SDLoc DL(Op);
3894 SDValue Op0 = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Op.getOperand(0));
3895 SDValue Powi =
3896 DAG.getNode(ISD::FPOWI, DL, MVT::f32, Op0, Op.getOperand(1));
3897 return DAG.getNode(ISD::FP_ROUND, DL, MVT::f16, Powi,
3898 DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
3899 }
3900 return SDValue();
3901 }
3902 case ISD::FP_EXTEND:
3903 case ISD::FP_ROUND:
3904 if (!Op.getValueType().isVector())
3905 return Op;
3906 return lowerVectorFPExtendOrRoundLike(Op, DAG);
3907 case ISD::FP_TO_SINT:
3908 case ISD::FP_TO_UINT:
3909 case ISD::SINT_TO_FP:
3910 case ISD::UINT_TO_FP: {
3911 // RVV can only do fp<->int conversions to types half or double the size of
3912 // the source. We custom-lower any conversion that would otherwise require
3913 // two hops into a sequence of single-hop conversions.
3914 MVT VT = Op.getSimpleValueType();
3915 if (!VT.isVector())
3916 return Op;
3917 SDLoc DL(Op);
3918 SDValue Src = Op.getOperand(0);
3919 MVT EltVT = VT.getVectorElementType();
3920 MVT SrcVT = Src.getSimpleValueType();
3921 MVT SrcEltVT = SrcVT.getVectorElementType();
3922 unsigned EltSize = EltVT.getSizeInBits();
3923 unsigned SrcEltSize = SrcEltVT.getSizeInBits();
3924 assert(isPowerOf2_32(EltSize) && isPowerOf2_32(SrcEltSize) &&
3925 "Unexpected vector element types");
3926
3927 bool IsInt2FP = SrcEltVT.isInteger();
3928 // Widening conversions
3929 if (EltSize > (2 * SrcEltSize)) {
3930 if (IsInt2FP) {
3931 // Do a regular integer sign/zero extension then convert to float.
3932 MVT IVecVT = MVT::getVectorVT(MVT::getIntegerVT(EltSize / 2),
3933 VT.getVectorElementCount());
3934 unsigned ExtOpcode = Op.getOpcode() == ISD::UINT_TO_FP
3935 ? ISD::ZERO_EXTEND
3936 : ISD::SIGN_EXTEND;
3937 SDValue Ext = DAG.getNode(ExtOpcode, DL, IVecVT, Src);
3938 return DAG.getNode(Op.getOpcode(), DL, VT, Ext);
3939 }
3940 // FP2Int
3941 assert(SrcEltVT == MVT::f16 && "Unexpected FP_TO_[US]INT lowering");
3942 // Do one doubling fp_extend then complete the operation by converting
3943 // to int.
3944 MVT InterimFVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount());
3945 SDValue FExt = DAG.getFPExtendOrRound(Src, DL, InterimFVT);
3946 return DAG.getNode(Op.getOpcode(), DL, VT, FExt);
3947 }
3948
3949 // Narrowing conversions
3950 if (SrcEltSize > (2 * EltSize)) {
3951 if (IsInt2FP) {
3952 // One narrowing int_to_fp, then an fp_round.
3953 assert(EltVT == MVT::f16 && "Unexpected [US]_TO_FP lowering");
3954 MVT InterimFVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount());
3955 SDValue Int2FP = DAG.getNode(Op.getOpcode(), DL, InterimFVT, Src);
3956 return DAG.getFPExtendOrRound(Int2FP, DL, VT);
3957 }
3958 // FP2Int
3959 // One narrowing fp_to_int, then truncate the integer. If the float isn't
3960 // representable by the integer, the result is poison.
3961 MVT IVecVT = MVT::getVectorVT(MVT::getIntegerVT(SrcEltSize / 2),
3962 VT.getVectorElementCount());
3963 SDValue FP2Int = DAG.getNode(Op.getOpcode(), DL, IVecVT, Src);
3964 return DAG.getNode(ISD::TRUNCATE, DL, VT, FP2Int);
3965 }
3966
3967 // Scalable vectors can exit here; isel patterns handle equally-sized
3968 // conversions as well as the halving/doubling ones.
3969 if (!VT.isFixedLengthVector())
3970 return Op;
3971
3972 // For fixed-length vectors we lower to a custom "VL" node.
3973 unsigned RVVOpc = 0;
3974 switch (Op.getOpcode()) {
3975 default:
3976 llvm_unreachable("Impossible opcode");
3977 case ISD::FP_TO_SINT:
3978 RVVOpc = RISCVISD::VFCVT_RTZ_X_F_VL;
3979 break;
3980 case ISD::FP_TO_UINT:
3981 RVVOpc = RISCVISD::VFCVT_RTZ_XU_F_VL;
3982 break;
3983 case ISD::SINT_TO_FP:
3984 RVVOpc = RISCVISD::SINT_TO_FP_VL;
3985 break;
3986 case ISD::UINT_TO_FP:
3987 RVVOpc = RISCVISD::UINT_TO_FP_VL;
3988 break;
3989 }
3990
3991 MVT ContainerVT = getContainerForFixedLengthVector(VT);
3992 MVT SrcContainerVT = getContainerForFixedLengthVector(SrcVT);
3993 assert(ContainerVT.getVectorElementCount() == SrcContainerVT.getVectorElementCount() &&
3994 "Expected same element count");
3995
3996 auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
3997
3998 Src = convertToScalableVector(SrcContainerVT, Src, DAG, Subtarget);
3999 Src = DAG.getNode(RVVOpc, DL, ContainerVT, Src, Mask, VL);
4000 return convertFromScalableVector(VT, Src, DAG, Subtarget);
4001 }
4002 case ISD::FP_TO_SINT_SAT:
4003 case ISD::FP_TO_UINT_SAT:
4004 return lowerFP_TO_INT_SAT(Op, DAG, Subtarget);
4005 case ISD::FTRUNC:
4006 case ISD::FCEIL:
4007 case ISD::FFLOOR:
4008 case ISD::FRINT:
4009 case ISD::FROUND:
4010 case ISD::FROUNDEVEN:
4011 return lowerFTRUNC_FCEIL_FFLOOR_FROUND(Op, DAG, Subtarget);
4012 case ISD::VECREDUCE_ADD:
4013 case ISD::VECREDUCE_UMAX:
4014 case ISD::VECREDUCE_SMAX:
4015 case ISD::VECREDUCE_UMIN:
4016 case ISD::VECREDUCE_SMIN:
4017 return lowerVECREDUCE(Op, DAG);
4018 case ISD::VECREDUCE_AND:
4019 case ISD::VECREDUCE_OR:
4020 case ISD::VECREDUCE_XOR:
4021 if (Op.getOperand(0).getValueType().getVectorElementType() == MVT::i1)
4022 return lowerVectorMaskVecReduction(Op, DAG, /*IsVP*/ false);
4023 return lowerVECREDUCE(Op, DAG);
4024 case ISD::VECREDUCE_FADD:
4025 case ISD::VECREDUCE_SEQ_FADD:
4026 case ISD::VECREDUCE_FMIN:
4027 case ISD::VECREDUCE_FMAX:
4028 return lowerFPVECREDUCE(Op, DAG);
4029 case ISD::VP_REDUCE_ADD:
4030 case ISD::VP_REDUCE_UMAX:
4031 case ISD::VP_REDUCE_SMAX:
4032 case ISD::VP_REDUCE_UMIN:
4033 case ISD::VP_REDUCE_SMIN:
4034 case ISD::VP_REDUCE_FADD:
4035 case ISD::VP_REDUCE_SEQ_FADD:
4036 case ISD::VP_REDUCE_FMIN:
4037 case ISD::VP_REDUCE_FMAX:
4038 return lowerVPREDUCE(Op, DAG);
4039 case ISD::VP_REDUCE_AND:
4040 case ISD::VP_REDUCE_OR:
4041 case ISD::VP_REDUCE_XOR:
4042 if (Op.getOperand(1).getValueType().getVectorElementType() == MVT::i1)
4043 return lowerVectorMaskVecReduction(Op, DAG, /*IsVP*/ true);
4044 return lowerVPREDUCE(Op, DAG);
4045 case ISD::INSERT_SUBVECTOR:
4046 return lowerINSERT_SUBVECTOR(Op, DAG);
4047 case ISD::EXTRACT_SUBVECTOR:
4048 return lowerEXTRACT_SUBVECTOR(Op, DAG);
4049 case ISD::STEP_VECTOR:
4050 return lowerSTEP_VECTOR(Op, DAG);
4051 case ISD::VECTOR_REVERSE:
4052 return lowerVECTOR_REVERSE(Op, DAG);
4053 case ISD::VECTOR_SPLICE:
4054 return lowerVECTOR_SPLICE(Op, DAG);
4055 case ISD::BUILD_VECTOR:
4056 return lowerBUILD_VECTOR(Op, DAG, Subtarget);
4057 case ISD::SPLAT_VECTOR:
4058 if (Op.getValueType().getVectorElementType() == MVT::i1)
4059 return lowerVectorMaskSplat(Op, DAG);
4060 return SDValue();
4061 case ISD::VECTOR_SHUFFLE:
4062 return lowerVECTOR_SHUFFLE(Op, DAG, Subtarget);
4063 case ISD::CONCAT_VECTORS: {
4064 // Split CONCAT_VECTORS into a series of INSERT_SUBVECTOR nodes. This is
4065 // better than going through the stack, as the default expansion does.
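    // For example (illustrative): (v4f64 concat_vectors A:v2f64, B:v2f64)
    // becomes insert_subvector(insert_subvector(undef:v4f64, A, 0), B, 2).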
4066 SDLoc DL(Op);
4067 MVT VT = Op.getSimpleValueType();
4068 unsigned NumOpElts =
4069 Op.getOperand(0).getSimpleValueType().getVectorMinNumElements();
4070 SDValue Vec = DAG.getUNDEF(VT);
4071 for (const auto &OpIdx : enumerate(Op->ops())) {
4072 SDValue SubVec = OpIdx.value();
4073 // Don't insert undef subvectors.
4074 if (SubVec.isUndef())
4075 continue;
4076 Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Vec, SubVec,
4077 DAG.getIntPtrConstant(OpIdx.index() * NumOpElts, DL));
4078 }
4079 return Vec;
4080 }
4081 case ISD::LOAD:
4082 if (auto V = expandUnalignedRVVLoad(Op, DAG))
4083 return V;
4084 if (Op.getValueType().isFixedLengthVector())
4085 return lowerFixedLengthVectorLoadToRVV(Op, DAG);
4086 return Op;
4087 case ISD::STORE:
4088 if (auto V = expandUnalignedRVVStore(Op, DAG))
4089 return V;
4090 if (Op.getOperand(1).getValueType().isFixedLengthVector())
4091 return lowerFixedLengthVectorStoreToRVV(Op, DAG);
4092 return Op;
4093 case ISD::MLOAD:
4094 case ISD::VP_LOAD:
4095 return lowerMaskedLoad(Op, DAG);
4096 case ISD::MSTORE:
4097 case ISD::VP_STORE:
4098 return lowerMaskedStore(Op, DAG);
4099 case ISD::SELECT_CC: {
4100 // This occurs because we custom legalize SETGT and SETUGT for setcc. That
4101 // causes LegalizeDAG to think we need to custom legalize select_cc. Expand
4102 // into separate SETCC+SELECT just like LegalizeDAG.
4103 SDValue Tmp1 = Op.getOperand(0);
4104 SDValue Tmp2 = Op.getOperand(1);
4105 SDValue True = Op.getOperand(2);
4106 SDValue False = Op.getOperand(3);
4107 EVT VT = Op.getValueType();
4108 SDValue CC = Op.getOperand(4);
4109 EVT CmpVT = Tmp1.getValueType();
4110 EVT CCVT =
4111 getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), CmpVT);
4112 SDLoc DL(Op);
4113 SDValue Cond =
4114 DAG.getNode(ISD::SETCC, DL, CCVT, Tmp1, Tmp2, CC, Op->getFlags());
4115 return DAG.getSelect(DL, VT, Cond, True, False);
4116 }
4117 case ISD::SETCC: {
4118 MVT OpVT = Op.getOperand(0).getSimpleValueType();
4119 if (OpVT.isScalarInteger()) {
4120 MVT VT = Op.getSimpleValueType();
4121 SDValue LHS = Op.getOperand(0);
4122 SDValue RHS = Op.getOperand(1);
4123 ISD::CondCode CCVal = cast<CondCodeSDNode>(Op.getOperand(2))->get();
4124 assert((CCVal == ISD::SETGT || CCVal == ISD::SETUGT) &&
4125 "Unexpected CondCode");
4126
4127 SDLoc DL(Op);
4128
4129 // If the RHS is a constant in the range [-2049, 0) or (0, 2046], we can
4130 // convert this to the equivalent of (set(u)ge X, C+1) by using
4131 // (xori (slti(u) X, C+1), 1). This avoids materializing a small constant
4132 // in a register.
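      // For example (illustrative): (setugt X, 7) becomes
      // (xor (setult X, 8), 1), which selects to "sltiu" with immediate 8
      // followed by "xori" with immediate 1, with no extra constant
      // materialization.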
4133 if (isa<ConstantSDNode>(RHS)) {
4134 int64_t Imm = cast<ConstantSDNode>(RHS)->getSExtValue();
4135 if (Imm != 0 && isInt<12>((uint64_t)Imm + 1)) {
4136 // X > -1 should have been replaced with false.
4137 assert((CCVal != ISD::SETUGT || Imm != -1) &&
4138 "Missing canonicalization");
4139 // Using getSetCCSwappedOperands will convert SET(U)GT->SET(U)LT.
4140 CCVal = ISD::getSetCCSwappedOperands(CCVal);
4141 SDValue SetCC = DAG.getSetCC(
4142 DL, VT, LHS, DAG.getConstant(Imm + 1, DL, OpVT), CCVal);
4143 return DAG.getLogicalNOT(DL, SetCC, VT);
4144 }
4145 }
4146
4147 // Not a constant we could handle, swap the operands and condition code to
4148 // SETLT/SETULT.
4149 CCVal = ISD::getSetCCSwappedOperands(CCVal);
4150 return DAG.getSetCC(DL, VT, RHS, LHS, CCVal);
4151 }
4152
4153 return lowerFixedLengthVectorSetccToRVV(Op, DAG);
4154 }
4155 case ISD::ADD:
4156 return lowerToScalableOp(Op, DAG, RISCVISD::ADD_VL, /*HasMergeOp*/ true);
4157 case ISD::SUB:
4158 return lowerToScalableOp(Op, DAG, RISCVISD::SUB_VL, /*HasMergeOp*/ true);
4159 case ISD::MUL:
4160 return lowerToScalableOp(Op, DAG, RISCVISD::MUL_VL, /*HasMergeOp*/ true);
4161 case ISD::MULHS:
4162 return lowerToScalableOp(Op, DAG, RISCVISD::MULHS_VL, /*HasMergeOp*/ true);
4163 case ISD::MULHU:
4164 return lowerToScalableOp(Op, DAG, RISCVISD::MULHU_VL, /*HasMergeOp*/ true);
4165 case ISD::AND:
4166 return lowerFixedLengthVectorLogicOpToRVV(Op, DAG, RISCVISD::VMAND_VL,
4167 RISCVISD::AND_VL);
4168 case ISD::OR:
4169 return lowerFixedLengthVectorLogicOpToRVV(Op, DAG, RISCVISD::VMOR_VL,
4170 RISCVISD::OR_VL);
4171 case ISD::XOR:
4172 return lowerFixedLengthVectorLogicOpToRVV(Op, DAG, RISCVISD::VMXOR_VL,
4173 RISCVISD::XOR_VL);
4174 case ISD::SDIV:
4175 return lowerToScalableOp(Op, DAG, RISCVISD::SDIV_VL, /*HasMergeOp*/ true);
4176 case ISD::SREM:
4177 return lowerToScalableOp(Op, DAG, RISCVISD::SREM_VL, /*HasMergeOp*/ true);
4178 case ISD::UDIV:
4179 return lowerToScalableOp(Op, DAG, RISCVISD::UDIV_VL, /*HasMergeOp*/ true);
4180 case ISD::UREM:
4181 return lowerToScalableOp(Op, DAG, RISCVISD::UREM_VL, /*HasMergeOp*/ true);
4182 case ISD::SHL:
4183 case ISD::SRA:
4184 case ISD::SRL:
4185 if (Op.getSimpleValueType().isFixedLengthVector())
4186 return lowerFixedLengthVectorShiftToRVV(Op, DAG);
4187 // This can be called for an i32 shift amount that needs to be promoted.
4188 assert(Op.getOperand(1).getValueType() == MVT::i32 && Subtarget.is64Bit() &&
4189 "Unexpected custom legalisation");
4190 return SDValue();
4191 case ISD::SADDSAT:
4192 return lowerToScalableOp(Op, DAG, RISCVISD::SADDSAT_VL,
4193 /*HasMergeOp*/ true);
4194 case ISD::UADDSAT:
4195 return lowerToScalableOp(Op, DAG, RISCVISD::UADDSAT_VL,
4196 /*HasMergeOp*/ true);
4197 case ISD::SSUBSAT:
4198 return lowerToScalableOp(Op, DAG, RISCVISD::SSUBSAT_VL,
4199 /*HasMergeOp*/ true);
4200 case ISD::USUBSAT:
4201 return lowerToScalableOp(Op, DAG, RISCVISD::USUBSAT_VL,
4202 /*HasMergeOp*/ true);
4203 case ISD::FADD:
4204 return lowerToScalableOp(Op, DAG, RISCVISD::FADD_VL, /*HasMergeOp*/ true);
4205 case ISD::FSUB:
4206 return lowerToScalableOp(Op, DAG, RISCVISD::FSUB_VL, /*HasMergeOp*/ true);
4207 case ISD::FMUL:
4208 return lowerToScalableOp(Op, DAG, RISCVISD::FMUL_VL, /*HasMergeOp*/ true);
4209 case ISD::FDIV:
4210 return lowerToScalableOp(Op, DAG, RISCVISD::FDIV_VL, /*HasMergeOp*/ true);
4211 case ISD::FNEG:
4212 return lowerToScalableOp(Op, DAG, RISCVISD::FNEG_VL);
4213 case ISD::FABS:
4214 return lowerToScalableOp(Op, DAG, RISCVISD::FABS_VL);
4215 case ISD::FSQRT:
4216 return lowerToScalableOp(Op, DAG, RISCVISD::FSQRT_VL);
4217 case ISD::FMA:
4218 return lowerToScalableOp(Op, DAG, RISCVISD::VFMADD_VL);
4219 case ISD::SMIN:
4220 return lowerToScalableOp(Op, DAG, RISCVISD::SMIN_VL, /*HasMergeOp*/ true);
4221 case ISD::SMAX:
4222 return lowerToScalableOp(Op, DAG, RISCVISD::SMAX_VL, /*HasMergeOp*/ true);
4223 case ISD::UMIN:
4224 return lowerToScalableOp(Op, DAG, RISCVISD::UMIN_VL, /*HasMergeOp*/ true);
4225 case ISD::UMAX:
4226 return lowerToScalableOp(Op, DAG, RISCVISD::UMAX_VL, /*HasMergeOp*/ true);
4227 case ISD::FMINNUM:
4228 return lowerToScalableOp(Op, DAG, RISCVISD::FMINNUM_VL,
4229 /*HasMergeOp*/ true);
4230 case ISD::FMAXNUM:
4231 return lowerToScalableOp(Op, DAG, RISCVISD::FMAXNUM_VL,
4232 /*HasMergeOp*/ true);
4233 case ISD::ABS:
4234 case ISD::VP_ABS:
4235 return lowerABS(Op, DAG);
4236 case ISD::CTLZ:
4237 case ISD::CTLZ_ZERO_UNDEF:
4238 case ISD::CTTZ_ZERO_UNDEF:
4239 return lowerCTLZ_CTTZ_ZERO_UNDEF(Op, DAG);
4240 case ISD::VSELECT:
4241 return lowerFixedLengthVectorSelectToRVV(Op, DAG);
4242 case ISD::FCOPYSIGN:
4243 return lowerFixedLengthVectorFCOPYSIGNToRVV(Op, DAG);
4244 case ISD::MGATHER:
4245 case ISD::VP_GATHER:
4246 return lowerMaskedGather(Op, DAG);
4247 case ISD::MSCATTER:
4248 case ISD::VP_SCATTER:
4249 return lowerMaskedScatter(Op, DAG);
4250 case ISD::GET_ROUNDING:
4251 return lowerGET_ROUNDING(Op, DAG);
4252 case ISD::SET_ROUNDING:
4253 return lowerSET_ROUNDING(Op, DAG);
4254 case ISD::EH_DWARF_CFA:
4255 return lowerEH_DWARF_CFA(Op, DAG);
4256 case ISD::VP_SELECT:
4257 return lowerVPOp(Op, DAG, RISCVISD::VSELECT_VL);
4258 case ISD::VP_MERGE:
4259 return lowerVPOp(Op, DAG, RISCVISD::VP_MERGE_VL);
4260 case ISD::VP_ADD:
4261 return lowerVPOp(Op, DAG, RISCVISD::ADD_VL, /*HasMergeOp*/ true);
4262 case ISD::VP_SUB:
4263 return lowerVPOp(Op, DAG, RISCVISD::SUB_VL, /*HasMergeOp*/ true);
4264 case ISD::VP_MUL:
4265 return lowerVPOp(Op, DAG, RISCVISD::MUL_VL, /*HasMergeOp*/ true);
4266 case ISD::VP_SDIV:
4267 return lowerVPOp(Op, DAG, RISCVISD::SDIV_VL, /*HasMergeOp*/ true);
4268 case ISD::VP_UDIV:
4269 return lowerVPOp(Op, DAG, RISCVISD::UDIV_VL, /*HasMergeOp*/ true);
4270 case ISD::VP_SREM:
4271 return lowerVPOp(Op, DAG, RISCVISD::SREM_VL, /*HasMergeOp*/ true);
4272 case ISD::VP_UREM:
4273 return lowerVPOp(Op, DAG, RISCVISD::UREM_VL, /*HasMergeOp*/ true);
4274 case ISD::VP_AND:
4275 return lowerLogicVPOp(Op, DAG, RISCVISD::VMAND_VL, RISCVISD::AND_VL);
4276 case ISD::VP_OR:
4277 return lowerLogicVPOp(Op, DAG, RISCVISD::VMOR_VL, RISCVISD::OR_VL);
4278 case ISD::VP_XOR:
4279 return lowerLogicVPOp(Op, DAG, RISCVISD::VMXOR_VL, RISCVISD::XOR_VL);
4280 case ISD::VP_ASHR:
4281 return lowerVPOp(Op, DAG, RISCVISD::SRA_VL, /*HasMergeOp*/ true);
4282 case ISD::VP_LSHR:
4283 return lowerVPOp(Op, DAG, RISCVISD::SRL_VL, /*HasMergeOp*/ true);
4284 case ISD::VP_SHL:
4285 return lowerVPOp(Op, DAG, RISCVISD::SHL_VL, /*HasMergeOp*/ true);
4286 case ISD::VP_FADD:
4287 return lowerVPOp(Op, DAG, RISCVISD::FADD_VL, /*HasMergeOp*/ true);
4288 case ISD::VP_FSUB:
4289 return lowerVPOp(Op, DAG, RISCVISD::FSUB_VL, /*HasMergeOp*/ true);
4290 case ISD::VP_FMUL:
4291 return lowerVPOp(Op, DAG, RISCVISD::FMUL_VL, /*HasMergeOp*/ true);
4292 case ISD::VP_FDIV:
4293 return lowerVPOp(Op, DAG, RISCVISD::FDIV_VL, /*HasMergeOp*/ true);
4294 case ISD::VP_FNEG:
4295 return lowerVPOp(Op, DAG, RISCVISD::FNEG_VL);
4296 case ISD::VP_FABS:
4297 return lowerVPOp(Op, DAG, RISCVISD::FABS_VL);
4298 case ISD::VP_SQRT:
4299 return lowerVPOp(Op, DAG, RISCVISD::FSQRT_VL);
4300 case ISD::VP_FMA:
4301 return lowerVPOp(Op, DAG, RISCVISD::VFMADD_VL);
4302 case ISD::VP_FMINNUM:
4303 return lowerVPOp(Op, DAG, RISCVISD::FMINNUM_VL, /*HasMergeOp*/ true);
4304 case ISD::VP_FMAXNUM:
4305 return lowerVPOp(Op, DAG, RISCVISD::FMAXNUM_VL, /*HasMergeOp*/ true);
4306 case ISD::VP_FCOPYSIGN:
4307 return lowerVPOp(Op, DAG, RISCVISD::FCOPYSIGN_VL, /*HasMergeOp*/ true);
4308 case ISD::VP_SIGN_EXTEND:
4309 case ISD::VP_ZERO_EXTEND:
4310 if (Op.getOperand(0).getSimpleValueType().getVectorElementType() == MVT::i1)
4311 return lowerVPExtMaskOp(Op, DAG);
4312 return lowerVPOp(Op, DAG,
4313 Op.getOpcode() == ISD::VP_SIGN_EXTEND
4314 ? RISCVISD::VSEXT_VL
4315 : RISCVISD::VZEXT_VL);
4316 case ISD::VP_TRUNCATE:
4317 return lowerVectorTruncLike(Op, DAG);
4318 case ISD::VP_FP_EXTEND:
4319 case ISD::VP_FP_ROUND:
4320 return lowerVectorFPExtendOrRoundLike(Op, DAG);
4321 case ISD::VP_FP_TO_SINT:
4322 return lowerVPFPIntConvOp(Op, DAG, RISCVISD::VFCVT_RTZ_X_F_VL);
4323 case ISD::VP_FP_TO_UINT:
4324 return lowerVPFPIntConvOp(Op, DAG, RISCVISD::VFCVT_RTZ_XU_F_VL);
4325 case ISD::VP_SINT_TO_FP:
4326 return lowerVPFPIntConvOp(Op, DAG, RISCVISD::SINT_TO_FP_VL);
4327 case ISD::VP_UINT_TO_FP:
4328 return lowerVPFPIntConvOp(Op, DAG, RISCVISD::UINT_TO_FP_VL);
4329 case ISD::VP_SETCC:
4330 if (Op.getOperand(0).getSimpleValueType().getVectorElementType() == MVT::i1)
4331 return lowerVPSetCCMaskOp(Op, DAG);
4332 return lowerVPOp(Op, DAG, RISCVISD::SETCC_VL, /*HasMergeOp*/ true);
4333 case ISD::VP_SMIN:
4334 return lowerVPOp(Op, DAG, RISCVISD::SMIN_VL, /*HasMergeOp*/ true);
4335 case ISD::VP_SMAX:
4336 return lowerVPOp(Op, DAG, RISCVISD::SMAX_VL, /*HasMergeOp*/ true);
4337 case ISD::VP_UMIN:
4338 return lowerVPOp(Op, DAG, RISCVISD::UMIN_VL, /*HasMergeOp*/ true);
4339 case ISD::VP_UMAX:
4340 return lowerVPOp(Op, DAG, RISCVISD::UMAX_VL, /*HasMergeOp*/ true);
4341 case ISD::EXPERIMENTAL_VP_STRIDED_LOAD:
4342 return lowerVPStridedLoad(Op, DAG);
4343 case ISD::EXPERIMENTAL_VP_STRIDED_STORE:
4344 return lowerVPStridedStore(Op, DAG);
4345 case ISD::VP_FCEIL:
4346 case ISD::VP_FFLOOR:
4347 case ISD::VP_FRINT:
4348 case ISD::VP_FNEARBYINT:
4349 case ISD::VP_FROUND:
4350 case ISD::VP_FROUNDEVEN:
4351 case ISD::VP_FROUNDTOZERO:
4352 return lowerVectorFTRUNC_FCEIL_FFLOOR_FROUND(Op, DAG, Subtarget);
4353 }
4354 }
4355
4356 static SDValue getTargetNode(GlobalAddressSDNode *N, SDLoc DL, EVT Ty,
4357 SelectionDAG &DAG, unsigned Flags) {
4358 return DAG.getTargetGlobalAddress(N->getGlobal(), DL, Ty, 0, Flags);
4359 }
4360
4361 static SDValue getTargetNode(BlockAddressSDNode *N, SDLoc DL, EVT Ty,
4362 SelectionDAG &DAG, unsigned Flags) {
4363 return DAG.getTargetBlockAddress(N->getBlockAddress(), Ty, N->getOffset(),
4364 Flags);
4365 }
4366
4367 static SDValue getTargetNode(ConstantPoolSDNode *N, SDLoc DL, EVT Ty,
4368 SelectionDAG &DAG, unsigned Flags) {
4369 return DAG.getTargetConstantPool(N->getConstVal(), Ty, N->getAlign(),
4370 N->getOffset(), Flags);
4371 }
4372
4373 static SDValue getTargetNode(JumpTableSDNode *N, SDLoc DL, EVT Ty,
4374 SelectionDAG &DAG, unsigned Flags) {
4375 return DAG.getTargetJumpTable(N->getIndex(), Ty, Flags);
4376 }
4377
4378 template <class NodeTy>
4379 SDValue RISCVTargetLowering::getAddr(NodeTy *N, SelectionDAG &DAG,
4380 bool IsLocal) const {
4381 SDLoc DL(N);
4382 EVT Ty = getPointerTy(DAG.getDataLayout());
4383
4384 // When HWASAN is used and tagging of global variables is enabled
4385 // they should be accessed via the GOT, since the tagged address of a global
4386 // is incompatible with existing code models. This also applies to non-pic
4387 // mode.
4388 if (isPositionIndependent() || Subtarget.allowTaggedGlobals()) {
4389 SDValue Addr = getTargetNode(N, DL, Ty, DAG, 0);
4390 if (IsLocal && !Subtarget.allowTaggedGlobals())
4391 // Use PC-relative addressing to access the symbol. This generates the
4392 // pattern (PseudoLLA sym), which expands to (addi (auipc %pcrel_hi(sym))
4393 // %pcrel_lo(auipc)).
4394 return DAG.getNode(RISCVISD::LLA, DL, Ty, Addr);
4395
4396 // Use PC-relative addressing to access the GOT for this symbol, then load
4397 // the address from the GOT. This generates the pattern (PseudoLA sym),
4398 // which expands to (ld (addi (auipc %got_pcrel_hi(sym)) %pcrel_lo(auipc))).
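// As a rough sketch (register names and the .Lpcrel_hi0 label below are just
// placeholders chosen at assembly time), on RV64 this typically assembles to:
//   .Lpcrel_hi0: auipc a0, %got_pcrel_hi(sym)
//                ld    a0, %pcrel_lo(.Lpcrel_hi0)(a0)
// with lw instead of ld on RV32.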
4399 MachineFunction &MF = DAG.getMachineFunction();
4400 MachineMemOperand *MemOp = MF.getMachineMemOperand(
4401 MachinePointerInfo::getGOT(MF),
4402 MachineMemOperand::MOLoad | MachineMemOperand::MODereferenceable |
4403 MachineMemOperand::MOInvariant,
4404 LLT(Ty.getSimpleVT()), Align(Ty.getFixedSizeInBits() / 8));
4405 SDValue Load =
4406 DAG.getMemIntrinsicNode(RISCVISD::LA, DL, DAG.getVTList(Ty, MVT::Other),
4407 {DAG.getEntryNode(), Addr}, Ty, MemOp);
4408 return Load;
4409 }
4410
4411 switch (getTargetMachine().getCodeModel()) {
4412 default:
4413 report_fatal_error("Unsupported code model for lowering");
4414 case CodeModel::Small: {
4415 // Generate a sequence for accessing addresses within the first 2 GiB of
4416 // address space. This generates the pattern (addi (lui %hi(sym)) %lo(sym)).
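// For example (a sketch; g and a0 are placeholders), taking the address of a
// global g under the small code model typically becomes:
//   lui  a0, %hi(g)
//   addi a0, a0, %lo(g)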
4417 SDValue AddrHi = getTargetNode(N, DL, Ty, DAG, RISCVII::MO_HI);
4418 SDValue AddrLo = getTargetNode(N, DL, Ty, DAG, RISCVII::MO_LO);
4419 SDValue MNHi = DAG.getNode(RISCVISD::HI, DL, Ty, AddrHi);
4420 return DAG.getNode(RISCVISD::ADD_LO, DL, Ty, MNHi, AddrLo);
4421 }
4422 case CodeModel::Medium: {
4423 // Generate a sequence for accessing addresses within any 2 GiB range of
4424 // the address space. This generates the pattern (PseudoLLA sym), which
4425 // expands to (addi (auipc %pcrel_hi(sym)) %pcrel_lo(auipc)).
4426 SDValue Addr = getTargetNode(N, DL, Ty, DAG, 0);
4427 return DAG.getNode(RISCVISD::LLA, DL, Ty, Addr);
4428 }
4429 }
4430 }
4431
4432 SDValue RISCVTargetLowering::lowerGlobalAddress(SDValue Op,
4433 SelectionDAG &DAG) const {
4434 GlobalAddressSDNode *N = cast<GlobalAddressSDNode>(Op);
4435 assert(N->getOffset() == 0 && "unexpected offset in global node");
4436 return getAddr(N, DAG, N->getGlobal()->isDSOLocal());
4437 }
4438
4439 SDValue RISCVTargetLowering::lowerBlockAddress(SDValue Op,
4440 SelectionDAG &DAG) const {
4441 BlockAddressSDNode *N = cast<BlockAddressSDNode>(Op);
4442
4443 return getAddr(N, DAG);
4444 }
4445
4446 SDValue RISCVTargetLowering::lowerConstantPool(SDValue Op,
4447 SelectionDAG &DAG) const {
4448 ConstantPoolSDNode *N = cast<ConstantPoolSDNode>(Op);
4449
4450 return getAddr(N, DAG);
4451 }
4452
4453 SDValue RISCVTargetLowering::lowerJumpTable(SDValue Op,
4454 SelectionDAG &DAG) const {
4455 JumpTableSDNode *N = cast<JumpTableSDNode>(Op);
4456
4457 return getAddr(N, DAG);
4458 }
4459
4460 SDValue RISCVTargetLowering::getStaticTLSAddr(GlobalAddressSDNode *N,
4461 SelectionDAG &DAG,
4462 bool UseGOT) const {
4463 SDLoc DL(N);
4464 EVT Ty = getPointerTy(DAG.getDataLayout());
4465 const GlobalValue *GV = N->getGlobal();
4466 MVT XLenVT = Subtarget.getXLenVT();
4467
4468 if (UseGOT) {
4469 // Use PC-relative addressing to access the GOT for this TLS symbol, then
4470 // load the address from the GOT and add the thread pointer. This generates
4471 // the pattern (PseudoLA_TLS_IE sym), which expands to
4472 // (ld (auipc %tls_ie_pcrel_hi(sym)) %pcrel_lo(auipc)).
4473 SDValue Addr = DAG.getTargetGlobalAddress(GV, DL, Ty, 0, 0);
4474 MachineFunction &MF = DAG.getMachineFunction();
4475 MachineMemOperand *MemOp = MF.getMachineMemOperand(
4476 MachinePointerInfo::getGOT(MF),
4477 MachineMemOperand::MOLoad | MachineMemOperand::MODereferenceable |
4478 MachineMemOperand::MOInvariant,
4479 LLT(Ty.getSimpleVT()), Align(Ty.getFixedSizeInBits() / 8));
4480 SDValue Load = DAG.getMemIntrinsicNode(
4481 RISCVISD::LA_TLS_IE, DL, DAG.getVTList(Ty, MVT::Other),
4482 {DAG.getEntryNode(), Addr}, Ty, MemOp);
4483
4484 // Add the thread pointer.
4485 SDValue TPReg = DAG.getRegister(RISCV::X4, XLenVT);
4486 return DAG.getNode(ISD::ADD, DL, Ty, Load, TPReg);
4487 }
4488
4489 // Generate a sequence for accessing the address relative to the thread
4490 // pointer, with the appropriate adjustment for the thread pointer offset.
4491 // This generates the pattern
4492 // (add (add_tprel (lui %tprel_hi(sym)) tp %tprel_add(sym)) %tprel_lo(sym))
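// As a sketch (a0 is a placeholder register), the local-exec sequence above
// typically assembles to:
//   lui  a0, %tprel_hi(sym)
//   add  a0, a0, tp, %tprel_add(sym)
//   addi a0, a0, %tprel_lo(sym)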
4493 SDValue AddrHi =
4494 DAG.getTargetGlobalAddress(GV, DL, Ty, 0, RISCVII::MO_TPREL_HI);
4495 SDValue AddrAdd =
4496 DAG.getTargetGlobalAddress(GV, DL, Ty, 0, RISCVII::MO_TPREL_ADD);
4497 SDValue AddrLo =
4498 DAG.getTargetGlobalAddress(GV, DL, Ty, 0, RISCVII::MO_TPREL_LO);
4499
4500 SDValue MNHi = DAG.getNode(RISCVISD::HI, DL, Ty, AddrHi);
4501 SDValue TPReg = DAG.getRegister(RISCV::X4, XLenVT);
4502 SDValue MNAdd =
4503 DAG.getNode(RISCVISD::ADD_TPREL, DL, Ty, MNHi, TPReg, AddrAdd);
4504 return DAG.getNode(RISCVISD::ADD_LO, DL, Ty, MNAdd, AddrLo);
4505 }
4506
4507 SDValue RISCVTargetLowering::getDynamicTLSAddr(GlobalAddressSDNode *N,
4508 SelectionDAG &DAG) const {
4509 SDLoc DL(N);
4510 EVT Ty = getPointerTy(DAG.getDataLayout());
4511 IntegerType *CallTy = Type::getIntNTy(*DAG.getContext(), Ty.getSizeInBits());
4512 const GlobalValue *GV = N->getGlobal();
4513
4514 // Use a PC-relative addressing mode to access the global dynamic GOT address.
4515 // This generates the pattern (PseudoLA_TLS_GD sym), which expands to
4516 // (addi (auipc %tls_gd_pcrel_hi(sym)) %pcrel_lo(auipc)).
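// Illustrative general-dynamic sequence (a sketch; the call itself is set up
// below via LowerCallTo, and the register/label names are placeholders):
//   .Lpcrel_hi0: auipc a0, %tls_gd_pcrel_hi(sym)
//                addi  a0, a0, %pcrel_lo(.Lpcrel_hi0)
//                call  __tls_get_addr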
4517 SDValue Addr = DAG.getTargetGlobalAddress(GV, DL, Ty, 0, 0);
4518 SDValue Load = DAG.getNode(RISCVISD::LA_TLS_GD, DL, Ty, Addr);
4519
4520 // Prepare argument list to generate call.
4521 ArgListTy Args;
4522 ArgListEntry Entry;
4523 Entry.Node = Load;
4524 Entry.Ty = CallTy;
4525 Args.push_back(Entry);
4526
4527 // Setup call to __tls_get_addr.
4528 TargetLowering::CallLoweringInfo CLI(DAG);
4529 CLI.setDebugLoc(DL)
4530 .setChain(DAG.getEntryNode())
4531 .setLibCallee(CallingConv::C, CallTy,
4532 DAG.getExternalSymbol("__tls_get_addr", Ty),
4533 std::move(Args));
4534
4535 return LowerCallTo(CLI).first;
4536 }
4537
4538 SDValue RISCVTargetLowering::lowerGlobalTLSAddress(SDValue Op,
4539 SelectionDAG &DAG) const {
4544 GlobalAddressSDNode *N = cast<GlobalAddressSDNode>(Op);
4545 assert(N->getOffset() == 0 && "unexpected offset in global node");
4546
4547 if (DAG.getTarget().useEmulatedTLS())
4548 return LowerToTLSEmulatedModel(N, DAG);
4549
4550 TLSModel::Model Model = getTargetMachine().getTLSModel(N->getGlobal());
4551
4552 if (DAG.getMachineFunction().getFunction().getCallingConv() ==
4553 CallingConv::GHC)
4554 report_fatal_error("In GHC calling convention TLS is not supported");
4555
4556 SDValue Addr;
4557 switch (Model) {
4558 case TLSModel::LocalExec:
4559 Addr = getStaticTLSAddr(N, DAG, /*UseGOT=*/false);
4560 break;
4561 case TLSModel::InitialExec:
4562 Addr = getStaticTLSAddr(N, DAG, /*UseGOT=*/true);
4563 break;
4564 case TLSModel::LocalDynamic:
4565 case TLSModel::GeneralDynamic:
4566 Addr = getDynamicTLSAddr(N, DAG);
4567 break;
4568 }
4569
4570 return Addr;
4571 }
4572
4573 SDValue RISCVTargetLowering::lowerSELECT(SDValue Op, SelectionDAG &DAG) const {
4574 SDValue CondV = Op.getOperand(0);
4575 SDValue TrueV = Op.getOperand(1);
4576 SDValue FalseV = Op.getOperand(2);
4577 SDLoc DL(Op);
4578 MVT VT = Op.getSimpleValueType();
4579 MVT XLenVT = Subtarget.getXLenVT();
4580
4581 // Lower vector SELECTs to VSELECTs by splatting the condition.
4582 if (VT.isVector()) {
4583 MVT SplatCondVT = VT.changeVectorElementType(MVT::i1);
4584 SDValue CondSplat = DAG.getSplat(SplatCondVT, DL, CondV);
4585 return DAG.getNode(ISD::VSELECT, DL, VT, CondSplat, TrueV, FalseV);
4586 }
4587
4588 if (!Subtarget.hasShortForwardBranchOpt()) {
4589 // (select c, -1, y) -> -c | y
4590 if (isAllOnesConstant(TrueV)) {
4591 SDValue Neg = DAG.getNegative(CondV, DL, VT);
4592 return DAG.getNode(ISD::OR, DL, VT, Neg, FalseV);
4593 }
4594 // (select c, y, -1) -> (c-1) | y
4595 if (isAllOnesConstant(FalseV)) {
4596 SDValue Neg = DAG.getNode(ISD::ADD, DL, VT, CondV,
4597 DAG.getAllOnesConstant(DL, VT));
4598 return DAG.getNode(ISD::OR, DL, VT, Neg, TrueV);
4599 }
4600
4601 // (select c, 0, y) -> (c-1) & y
4602 if (isNullConstant(TrueV)) {
4603 SDValue Neg = DAG.getNode(ISD::ADD, DL, VT, CondV,
4604 DAG.getAllOnesConstant(DL, VT));
4605 return DAG.getNode(ISD::AND, DL, VT, Neg, FalseV);
4606 }
4607 // (select c, y, 0) -> -c & y
4608 if (isNullConstant(FalseV)) {
4609 SDValue Neg = DAG.getNegative(CondV, DL, VT);
4610 return DAG.getNode(ISD::AND, DL, VT, Neg, TrueV);
4611 }
4612 }
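// Worked example for the folds above, treating the condition c as 0 or 1
// (RISCV uses zero-or-one booleans):
//   (select c, -1, y): -c is 0 or all-ones, so (-c | y) yields y or -1.
//   (select c, 0, y):  c-1 is all-ones or 0, so ((c-1) & y) yields y or 0.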
4613
4614 // If the condition is not an integer SETCC which operates on XLenVT, we need
4615 // to emit a RISCVISD::SELECT_CC comparing the condition to zero. i.e.:
4616 // (select condv, truev, falsev)
4617 // -> (riscvisd::select_cc condv, zero, setne, truev, falsev)
4618 if (CondV.getOpcode() != ISD::SETCC ||
4619 CondV.getOperand(0).getSimpleValueType() != XLenVT) {
4620 SDValue Zero = DAG.getConstant(0, DL, XLenVT);
4621 SDValue SetNE = DAG.getCondCode(ISD::SETNE);
4622
4623 SDValue Ops[] = {CondV, Zero, SetNE, TrueV, FalseV};
4624
4625 return DAG.getNode(RISCVISD::SELECT_CC, DL, VT, Ops);
4626 }
4627
4628 // If the CondV is the output of a SETCC node which operates on XLenVT inputs,
4629 // then merge the SETCC node into the lowered RISCVISD::SELECT_CC to take
4630 // advantage of the integer compare+branch instructions. i.e.:
4631 // (select (setcc lhs, rhs, cc), truev, falsev)
4632 // -> (riscvisd::select_cc lhs, rhs, cc, truev, falsev)
4633 SDValue LHS = CondV.getOperand(0);
4634 SDValue RHS = CondV.getOperand(1);
4635 ISD::CondCode CCVal = cast<CondCodeSDNode>(CondV.getOperand(2))->get();
4636
4637 // Special case for a select of 2 constants that have a difference of 1.
4638 // Normally this is done by DAGCombine, but if the select is introduced by
4639 // type legalization or op legalization, we miss it. Restricting to SETLT
4640 // case for now because that is what signed saturating add/sub need.
4641 // FIXME: We don't need the condition to be SETLT or even a SETCC,
4642 // but we would probably want to swap the true/false values if the condition
4643 // is SETGE/SETLE to avoid an XORI.
4644 if (isa<ConstantSDNode>(TrueV) && isa<ConstantSDNode>(FalseV) &&
4645 CCVal == ISD::SETLT) {
4646 const APInt &TrueVal = cast<ConstantSDNode>(TrueV)->getAPIntValue();
4647 const APInt &FalseVal = cast<ConstantSDNode>(FalseV)->getAPIntValue();
4648 if (TrueVal - 1 == FalseVal)
4649 return DAG.getNode(ISD::ADD, DL, VT, CondV, FalseV);
4650 if (TrueVal + 1 == FalseVal)
4651 return DAG.getNode(ISD::SUB, DL, VT, FalseV, CondV);
4652 }
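// Example of the fold above (illustrative values): (select (setlt x, y), 6, 5)
// has TrueVal - 1 == FalseVal, so it becomes (add (setlt x, y), 5), which
// evaluates to 6 when the compare is true and 5 otherwise.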
4653
4654 translateSetCCForBranch(DL, LHS, RHS, CCVal, DAG);
4655 // 1 < x ? x : 1 -> 0 < x ? x : 1
4656 if (isOneConstant(LHS) && (CCVal == ISD::SETLT || CCVal == ISD::SETULT) &&
4657 RHS == TrueV && LHS == FalseV) {
4658 LHS = DAG.getConstant(0, DL, VT);
4659 // 0 <u x is the same as x != 0.
4660 if (CCVal == ISD::SETULT) {
4661 std::swap(LHS, RHS);
4662 CCVal = ISD::SETNE;
4663 }
4664 }
4665
4666 // x <s -1 ? x : -1 -> x <s 0 ? x : -1
4667 if (isAllOnesConstant(RHS) && CCVal == ISD::SETLT && LHS == TrueV &&
4668 RHS == FalseV) {
4669 RHS = DAG.getConstant(0, DL, VT);
4670 }
4671
4672 SDValue TargetCC = DAG.getCondCode(CCVal);
4673
4674 if (isa<ConstantSDNode>(TrueV) && !isa<ConstantSDNode>(FalseV)) {
4675 // (select (setcc lhs, rhs, CC), constant, falsev)
4676 // -> (select (setcc lhs, rhs, InverseCC), falsev, constant)
4677 std::swap(TrueV, FalseV);
4678 TargetCC = DAG.getCondCode(ISD::getSetCCInverse(CCVal, LHS.getValueType()));
4679 }
4680
4681 SDValue Ops[] = {LHS, RHS, TargetCC, TrueV, FalseV};
4682 return DAG.getNode(RISCVISD::SELECT_CC, DL, VT, Ops);
4683 }
4684
4685 SDValue RISCVTargetLowering::lowerBRCOND(SDValue Op, SelectionDAG &DAG) const {
4686 SDValue CondV = Op.getOperand(1);
4687 SDLoc DL(Op);
4688 MVT XLenVT = Subtarget.getXLenVT();
4689
4690 if (CondV.getOpcode() == ISD::SETCC &&
4691 CondV.getOperand(0).getValueType() == XLenVT) {
4692 SDValue LHS = CondV.getOperand(0);
4693 SDValue RHS = CondV.getOperand(1);
4694 ISD::CondCode CCVal = cast<CondCodeSDNode>(CondV.getOperand(2))->get();
4695
4696 translateSetCCForBranch(DL, LHS, RHS, CCVal, DAG);
4697
4698 SDValue TargetCC = DAG.getCondCode(CCVal);
4699 return DAG.getNode(RISCVISD::BR_CC, DL, Op.getValueType(), Op.getOperand(0),
4700 LHS, RHS, TargetCC, Op.getOperand(2));
4701 }
4702
4703 return DAG.getNode(RISCVISD::BR_CC, DL, Op.getValueType(), Op.getOperand(0),
4704 CondV, DAG.getConstant(0, DL, XLenVT),
4705 DAG.getCondCode(ISD::SETNE), Op.getOperand(2));
4706 }
4707
4708 SDValue RISCVTargetLowering::lowerVASTART(SDValue Op, SelectionDAG &DAG) const {
4709 MachineFunction &MF = DAG.getMachineFunction();
4710 RISCVMachineFunctionInfo *FuncInfo = MF.getInfo<RISCVMachineFunctionInfo>();
4711
4712 SDLoc DL(Op);
4713 SDValue FI = DAG.getFrameIndex(FuncInfo->getVarArgsFrameIndex(),
4714 getPointerTy(MF.getDataLayout()));
4715
4716 // vastart just stores the address of the VarArgsFrameIndex slot into the
4717 // memory location argument.
4718 const Value *SV = cast<SrcValueSDNode>(Op.getOperand(2))->getValue();
4719 return DAG.getStore(Op.getOperand(0), DL, FI, Op.getOperand(1),
4720 MachinePointerInfo(SV));
4721 }
4722
4723 SDValue RISCVTargetLowering::lowerFRAMEADDR(SDValue Op,
4724 SelectionDAG &DAG) const {
4725 const RISCVRegisterInfo &RI = *Subtarget.getRegisterInfo();
4726 MachineFunction &MF = DAG.getMachineFunction();
4727 MachineFrameInfo &MFI = MF.getFrameInfo();
4728 MFI.setFrameAddressIsTaken(true);
4729 Register FrameReg = RI.getFrameRegister(MF);
4730 int XLenInBytes = Subtarget.getXLen() / 8;
4731
4732 EVT VT = Op.getValueType();
4733 SDLoc DL(Op);
4734 SDValue FrameAddr = DAG.getCopyFromReg(DAG.getEntryNode(), DL, FrameReg, VT);
4735 unsigned Depth = cast<ConstantSDNode>(Op.getOperand(0))->getZExtValue();
4736 while (Depth--) {
4737 int Offset = -(XLenInBytes * 2);
4738 SDValue Ptr = DAG.getNode(ISD::ADD, DL, VT, FrameAddr,
4739 DAG.getIntPtrConstant(Offset, DL));
4740 FrameAddr =
4741 DAG.getLoad(VT, DL, DAG.getEntryNode(), Ptr, MachinePointerInfo());
4742 }
4743 return FrameAddr;
4744 }
4745
4746 SDValue RISCVTargetLowering::lowerRETURNADDR(SDValue Op,
4747 SelectionDAG &DAG) const {
4748 const RISCVRegisterInfo &RI = *Subtarget.getRegisterInfo();
4749 MachineFunction &MF = DAG.getMachineFunction();
4750 MachineFrameInfo &MFI = MF.getFrameInfo();
4751 MFI.setReturnAddressIsTaken(true);
4752 MVT XLenVT = Subtarget.getXLenVT();
4753 int XLenInBytes = Subtarget.getXLen() / 8;
4754
4755 if (verifyReturnAddressArgumentIsConstant(Op, DAG))
4756 return SDValue();
4757
4758 EVT VT = Op.getValueType();
4759 SDLoc DL(Op);
4760 unsigned Depth = cast<ConstantSDNode>(Op.getOperand(0))->getZExtValue();
4761 if (Depth) {
4762 int Off = -XLenInBytes;
4763 SDValue FrameAddr = lowerFRAMEADDR(Op, DAG);
4764 SDValue Offset = DAG.getConstant(Off, DL, VT);
4765 return DAG.getLoad(VT, DL, DAG.getEntryNode(),
4766 DAG.getNode(ISD::ADD, DL, VT, FrameAddr, Offset),
4767 MachinePointerInfo());
4768 }
4769
4770 // Return the value of the return address register, marking it an implicit
4771 // live-in.
4772 Register Reg = MF.addLiveIn(RI.getRARegister(), getRegClassFor(XLenVT));
4773 return DAG.getCopyFromReg(DAG.getEntryNode(), DL, Reg, XLenVT);
4774 }
4775
4776 SDValue RISCVTargetLowering::lowerShiftLeftParts(SDValue Op,
4777 SelectionDAG &DAG) const {
4778 SDLoc DL(Op);
4779 SDValue Lo = Op.getOperand(0);
4780 SDValue Hi = Op.getOperand(1);
4781 SDValue Shamt = Op.getOperand(2);
4782 EVT VT = Lo.getValueType();
4783
4784 // if Shamt-XLEN < 0: // Shamt < XLEN
4785 // Lo = Lo << Shamt
4786 // Hi = (Hi << Shamt) | ((Lo >>u 1) >>u (XLEN-1 ^ Shamt))
4787 // else:
4788 // Lo = 0
4789 // Hi = Lo << (Shamt-XLEN)
4790
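// Worked example of the pseudocode above on RV32 (XLEN == 32) with Shamt == 4:
//   Lo' = Lo << 4
//   Hi' = (Hi << 4) | ((Lo >> 1) >> 27)   // == Lo >> 28
// Splitting the right shift as (Lo >>u 1) >>u (XLEN-1 ^ Shamt), which equals
// XLEN-1 - Shamt for in-range amounts, keeps every shift amount below XLEN
// even when Shamt == 0, where a single shift by XLEN - Shamt would be invalid.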
4791 SDValue Zero = DAG.getConstant(0, DL, VT);
4792 SDValue One = DAG.getConstant(1, DL, VT);
4793 SDValue MinusXLen = DAG.getConstant(-(int)Subtarget.getXLen(), DL, VT);
4794 SDValue XLenMinus1 = DAG.getConstant(Subtarget.getXLen() - 1, DL, VT);
4795 SDValue ShamtMinusXLen = DAG.getNode(ISD::ADD, DL, VT, Shamt, MinusXLen);
4796 SDValue XLenMinus1Shamt = DAG.getNode(ISD::SUB, DL, VT, XLenMinus1, Shamt);
4797
4798 SDValue LoTrue = DAG.getNode(ISD::SHL, DL, VT, Lo, Shamt);
4799 SDValue ShiftRight1Lo = DAG.getNode(ISD::SRL, DL, VT, Lo, One);
4800 SDValue ShiftRightLo =
4801 DAG.getNode(ISD::SRL, DL, VT, ShiftRight1Lo, XLenMinus1Shamt);
4802 SDValue ShiftLeftHi = DAG.getNode(ISD::SHL, DL, VT, Hi, Shamt);
4803 SDValue HiTrue = DAG.getNode(ISD::OR, DL, VT, ShiftLeftHi, ShiftRightLo);
4804 SDValue HiFalse = DAG.getNode(ISD::SHL, DL, VT, Lo, ShamtMinusXLen);
4805
4806 SDValue CC = DAG.getSetCC(DL, VT, ShamtMinusXLen, Zero, ISD::SETLT);
4807
4808 Lo = DAG.getNode(ISD::SELECT, DL, VT, CC, LoTrue, Zero);
4809 Hi = DAG.getNode(ISD::SELECT, DL, VT, CC, HiTrue, HiFalse);
4810
4811 SDValue Parts[2] = {Lo, Hi};
4812 return DAG.getMergeValues(Parts, DL);
4813 }
4814
4815 SDValue RISCVTargetLowering::lowerShiftRightParts(SDValue Op, SelectionDAG &DAG,
4816 bool IsSRA) const {
4817 SDLoc DL(Op);
4818 SDValue Lo = Op.getOperand(0);
4819 SDValue Hi = Op.getOperand(1);
4820 SDValue Shamt = Op.getOperand(2);
4821 EVT VT = Lo.getValueType();
4822
4823 // SRA expansion:
4824 // if Shamt-XLEN < 0: // Shamt < XLEN
4825 // Lo = (Lo >>u Shamt) | ((Hi << 1) << (ShAmt ^ XLEN-1))
4826 // Hi = Hi >>s Shamt
4827 // else:
4828 // Lo = Hi >>s (Shamt-XLEN);
4829 // Hi = Hi >>s (XLEN-1)
4830 //
4831 // SRL expansion:
4832 // if Shamt-XLEN < 0: // Shamt < XLEN
4833 // Lo = (Lo >>u Shamt) | ((Hi << 1) << (ShAmt ^ XLEN-1))
4834 // Hi = Hi >>u Shamt
4835 // else:
4836 // Lo = Hi >>u (Shamt-XLEN);
4837 // Hi = 0;
4838
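// For instance, in the SRA case with Shamt >= XLEN the whole low half comes
// from Hi (Lo = Hi >>s (Shamt-XLEN)) and the high half collapses to the sign
// bit (Hi >>s (XLEN-1)), matching an arithmetic shift of the full 2*XLEN-bit
// value.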
4839 unsigned ShiftRightOp = IsSRA ? ISD::SRA : ISD::SRL;
4840
4841 SDValue Zero = DAG.getConstant(0, DL, VT);
4842 SDValue One = DAG.getConstant(1, DL, VT);
4843 SDValue MinusXLen = DAG.getConstant(-(int)Subtarget.getXLen(), DL, VT);
4844 SDValue XLenMinus1 = DAG.getConstant(Subtarget.getXLen() - 1, DL, VT);
4845 SDValue ShamtMinusXLen = DAG.getNode(ISD::ADD, DL, VT, Shamt, MinusXLen);
4846 SDValue XLenMinus1Shamt = DAG.getNode(ISD::SUB, DL, VT, XLenMinus1, Shamt);
4847
4848 SDValue ShiftRightLo = DAG.getNode(ISD::SRL, DL, VT, Lo, Shamt);
4849 SDValue ShiftLeftHi1 = DAG.getNode(ISD::SHL, DL, VT, Hi, One);
4850 SDValue ShiftLeftHi =
4851 DAG.getNode(ISD::SHL, DL, VT, ShiftLeftHi1, XLenMinus1Shamt);
4852 SDValue LoTrue = DAG.getNode(ISD::OR, DL, VT, ShiftRightLo, ShiftLeftHi);
4853 SDValue HiTrue = DAG.getNode(ShiftRightOp, DL, VT, Hi, Shamt);
4854 SDValue LoFalse = DAG.getNode(ShiftRightOp, DL, VT, Hi, ShamtMinusXLen);
4855 SDValue HiFalse =
4856 IsSRA ? DAG.getNode(ISD::SRA, DL, VT, Hi, XLenMinus1) : Zero;
4857
4858 SDValue CC = DAG.getSetCC(DL, VT, ShamtMinusXLen, Zero, ISD::SETLT);
4859
4860 Lo = DAG.getNode(ISD::SELECT, DL, VT, CC, LoTrue, LoFalse);
4861 Hi = DAG.getNode(ISD::SELECT, DL, VT, CC, HiTrue, HiFalse);
4862
4863 SDValue Parts[2] = {Lo, Hi};
4864 return DAG.getMergeValues(Parts, DL);
4865 }
4866
4867 // Lower splats of i1 types to SETCC. For each mask vector type, we have a
4868 // legal equivalently-sized i8 type, so we can use that as a go-between.
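// Illustrative sketch (type chosen only as an example): splatting an i1 value
// x into <vscale x 4 x i1> goes through an i8 vector:
//   t    = splat (x & 1)            ; <vscale x 4 x i8>
//   mask = setcc t, splat(0), ne    ; <vscale x 4 x i1>
// Constant all-ones / all-zeros splats are instead matched directly to
// VMSET_VL / VMCLR_VL (vmset.m / vmclr.m).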
4869 SDValue RISCVTargetLowering::lowerVectorMaskSplat(SDValue Op,
4870 SelectionDAG &DAG) const {
4871 SDLoc DL(Op);
4872 MVT VT = Op.getSimpleValueType();
4873 SDValue SplatVal = Op.getOperand(0);
4874 // All-zeros or all-ones splats are handled specially.
4875 if (ISD::isConstantSplatVectorAllOnes(Op.getNode())) {
4876 SDValue VL = getDefaultScalableVLOps(VT, DL, DAG, Subtarget).second;
4877 return DAG.getNode(RISCVISD::VMSET_VL, DL, VT, VL);
4878 }
4879 if (ISD::isConstantSplatVectorAllZeros(Op.getNode())) {
4880 SDValue VL = getDefaultScalableVLOps(VT, DL, DAG, Subtarget).second;
4881 return DAG.getNode(RISCVISD::VMCLR_VL, DL, VT, VL);
4882 }
4883 MVT XLenVT = Subtarget.getXLenVT();
4884 assert(SplatVal.getValueType() == XLenVT &&
4885 "Unexpected type for i1 splat value");
4886 MVT InterVT = VT.changeVectorElementType(MVT::i8);
4887 SplatVal = DAG.getNode(ISD::AND, DL, XLenVT, SplatVal,
4888 DAG.getConstant(1, DL, XLenVT));
4889 SDValue LHS = DAG.getSplatVector(InterVT, DL, SplatVal);
4890 SDValue Zero = DAG.getConstant(0, DL, InterVT);
4891 return DAG.getSetCC(DL, VT, LHS, Zero, ISD::SETNE);
4892 }
4893
4894 // Custom-lower a SPLAT_VECTOR_PARTS where XLEN<SEW, as the SEW element type is
4895 // illegal (currently only vXi64 RV32).
4896 // FIXME: We could also catch non-constant sign-extended i32 values and lower
4897 // them to VMV_V_X_VL.
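// For example (sketch), on RV32 a (splat_vector_parts lo, hi) where
// hi == (sra lo, 31) is just a sign-extended 32-bit value, so it can be
// emitted as a single VMV_V_X_VL of lo; otherwise the split-splat fallback
// below is used.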
4898 SDValue RISCVTargetLowering::lowerSPLAT_VECTOR_PARTS(SDValue Op,
4899 SelectionDAG &DAG) const {
4900 SDLoc DL(Op);
4901 MVT VecVT = Op.getSimpleValueType();
4902 assert(!Subtarget.is64Bit() && VecVT.getVectorElementType() == MVT::i64 &&
4903 "Unexpected SPLAT_VECTOR_PARTS lowering");
4904
4905 assert(Op.getNumOperands() == 2 && "Unexpected number of operands!");
4906 SDValue Lo = Op.getOperand(0);
4907 SDValue Hi = Op.getOperand(1);
4908
4909 if (VecVT.isFixedLengthVector()) {
4910 MVT ContainerVT = getContainerForFixedLengthVector(VecVT);
4911 SDLoc DL(Op);
4912 auto VL = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget).second;
4913
4914 SDValue Res =
4915 splatPartsI64WithVL(DL, ContainerVT, SDValue(), Lo, Hi, VL, DAG);
4916 return convertFromScalableVector(VecVT, Res, DAG, Subtarget);
4917 }
4918
4919 if (isa<ConstantSDNode>(Lo) && isa<ConstantSDNode>(Hi)) {
4920 int32_t LoC = cast<ConstantSDNode>(Lo)->getSExtValue();
4921 int32_t HiC = cast<ConstantSDNode>(Hi)->getSExtValue();
4922 // If Hi is just Lo's sign bit replicated (i.e. the pair is a sign-extended
4923 // 32-bit value), lower this as a custom node to try and match RVV
4924 // vector/scalar instructions.
4924 if ((LoC >> 31) == HiC)
4925 return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VecVT, DAG.getUNDEF(VecVT),
4926 Lo, DAG.getRegister(RISCV::X0, MVT::i32));
4927 }
4928
4929 // Detect cases where Hi is (SRA Lo, 31) which means Hi is Lo sign extended.
4930 if (Hi.getOpcode() == ISD::SRA && Hi.getOperand(0) == Lo &&
4931 isa<ConstantSDNode>(Hi.getOperand(1)) &&
4932 Hi.getConstantOperandVal(1) == 31)
4933 return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VecVT, DAG.getUNDEF(VecVT), Lo,
4934 DAG.getRegister(RISCV::X0, MVT::i32));
4935
4936 // Fall back to use a stack store and stride x0 vector load. Use X0 as VL.
4937 return DAG.getNode(RISCVISD::SPLAT_VECTOR_SPLIT_I64_VL, DL, VecVT,
4938 DAG.getUNDEF(VecVT), Lo, Hi,
4939 DAG.getRegister(RISCV::X0, MVT::i32));
4940 }
4941
4942 // Custom-lower extensions from mask vectors by using a vselect either with 1
4943 // for zero/any-extension or -1 for sign-extension:
4944 // (vXiN = (s|z)ext vXi1:vmask) -> (vXiN = vselect vmask, (-1 or 1), 0)
4945 // Note that any-extension is lowered identically to zero-extension.
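// For example (types chosen only for illustration),
//   (sext <vscale x 2 x i1> %m to <vscale x 2 x i32>)
// becomes (vselect %m, splat(-1), splat(0)); zext and anyext use splat(1) as
// the true value instead.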
4946 SDValue RISCVTargetLowering::lowerVectorMaskExt(SDValue Op, SelectionDAG &DAG,
4947 int64_t ExtTrueVal) const {
4948 SDLoc DL(Op);
4949 MVT VecVT = Op.getSimpleValueType();
4950 SDValue Src = Op.getOperand(0);
4951 // Only custom-lower extensions from mask types
4952 assert(Src.getValueType().isVector() &&
4953 Src.getValueType().getVectorElementType() == MVT::i1);
4954
4955 if (VecVT.isScalableVector()) {
4956 SDValue SplatZero = DAG.getConstant(0, DL, VecVT);
4957 SDValue SplatTrueVal = DAG.getConstant(ExtTrueVal, DL, VecVT);
4958 return DAG.getNode(ISD::VSELECT, DL, VecVT, Src, SplatTrueVal, SplatZero);
4959 }
4960
4961 MVT ContainerVT = getContainerForFixedLengthVector(VecVT);
4962 MVT I1ContainerVT =
4963 MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
4964
4965 SDValue CC = convertToScalableVector(I1ContainerVT, Src, DAG, Subtarget);
4966
4967 SDValue VL = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget).second;
4968
4969 MVT XLenVT = Subtarget.getXLenVT();
4970 SDValue SplatZero = DAG.getConstant(0, DL, XLenVT);
4971 SDValue SplatTrueVal = DAG.getConstant(ExtTrueVal, DL, XLenVT);
4972
4973 SplatZero = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT,
4974 DAG.getUNDEF(ContainerVT), SplatZero, VL);
4975 SplatTrueVal = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT,
4976 DAG.getUNDEF(ContainerVT), SplatTrueVal, VL);
4977 SDValue Select = DAG.getNode(RISCVISD::VSELECT_VL, DL, ContainerVT, CC,
4978 SplatTrueVal, SplatZero, VL);
4979
4980 return convertFromScalableVector(VecVT, Select, DAG, Subtarget);
4981 }
4982
4983 SDValue RISCVTargetLowering::lowerFixedLengthVectorExtendToRVV(
4984 SDValue Op, SelectionDAG &DAG, unsigned ExtendOpc) const {
4985 MVT ExtVT = Op.getSimpleValueType();
4986 // Only custom-lower extensions from fixed-length vector types.
4987 if (!ExtVT.isFixedLengthVector())
4988 return Op;
4989 MVT VT = Op.getOperand(0).getSimpleValueType();
4990 // Grab the canonical container type for the extended type. Infer the smaller
4991 // type from that to ensure the same number of vector elements, as we know
4992 // the LMUL will be sufficient to hold the smaller type.
4993 MVT ContainerExtVT = getContainerForFixedLengthVector(ExtVT);
4994 // Get the extended container type manually to ensure the same number of
4995 // vector elements between source and dest.
4996 MVT ContainerVT = MVT::getVectorVT(VT.getVectorElementType(),
4997 ContainerExtVT.getVectorElementCount());
4998
4999 SDValue Op1 =
5000 convertToScalableVector(ContainerVT, Op.getOperand(0), DAG, Subtarget);
5001
5002 SDLoc DL(Op);
5003 auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
5004
5005 SDValue Ext = DAG.getNode(ExtendOpc, DL, ContainerExtVT, Op1, Mask, VL);
5006
5007 return convertFromScalableVector(ExtVT, Ext, DAG, Subtarget);
5008 }
5009
5010 // Custom-lower truncations from vectors to mask vectors by using a mask and a
5011 // setcc operation:
5012 // (vXi1 = trunc vXiN vec) -> (vXi1 = setcc (and vec, 1), 0, ne)
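// For example (illustrative type), (trunc <4 x i8> %v to <4 x i1>) becomes
//   (setcc (and %v, splat(1)), splat(0), ne)
// so only bit 0 of each source element contributes to the resulting mask.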
5013 SDValue RISCVTargetLowering::lowerVectorMaskTruncLike(SDValue Op,
5014 SelectionDAG &DAG) const {
5015 bool IsVPTrunc = Op.getOpcode() == ISD::VP_TRUNCATE;
5016 SDLoc DL(Op);
5017 EVT MaskVT = Op.getValueType();
5018 // Only expect to custom-lower truncations to mask types
5019 assert(MaskVT.isVector() && MaskVT.getVectorElementType() == MVT::i1 &&
5020 "Unexpected type for vector mask lowering");
5021 SDValue Src = Op.getOperand(0);
5022 MVT VecVT = Src.getSimpleValueType();
5023 SDValue Mask, VL;
5024 if (IsVPTrunc) {
5025 Mask = Op.getOperand(1);
5026 VL = Op.getOperand(2);
5027 }
5028 // If this is a fixed vector, we need to convert it to a scalable vector.
5029 MVT ContainerVT = VecVT;
5030
5031 if (VecVT.isFixedLengthVector()) {
5032 ContainerVT = getContainerForFixedLengthVector(VecVT);
5033 Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
5034 if (IsVPTrunc) {
5035 MVT MaskContainerVT =
5036 getContainerForFixedLengthVector(Mask.getSimpleValueType());
5037 Mask = convertToScalableVector(MaskContainerVT, Mask, DAG, Subtarget);
5038 }
5039 }
5040
5041 if (!IsVPTrunc) {
5042 std::tie(Mask, VL) =
5043 getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget);
5044 }
5045
5046 SDValue SplatOne = DAG.getConstant(1, DL, Subtarget.getXLenVT());
5047 SDValue SplatZero = DAG.getConstant(0, DL, Subtarget.getXLenVT());
5048
5049 SplatOne = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT,
5050 DAG.getUNDEF(ContainerVT), SplatOne, VL);
5051 SplatZero = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT,
5052 DAG.getUNDEF(ContainerVT), SplatZero, VL);
5053
5054 MVT MaskContainerVT = ContainerVT.changeVectorElementType(MVT::i1);
5055 SDValue Trunc = DAG.getNode(RISCVISD::AND_VL, DL, ContainerVT, Src, SplatOne,
5056 DAG.getUNDEF(ContainerVT), Mask, VL);
5057 Trunc = DAG.getNode(RISCVISD::SETCC_VL, DL, MaskContainerVT,
5058 {Trunc, SplatZero, DAG.getCondCode(ISD::SETNE),
5059 DAG.getUNDEF(MaskContainerVT), Mask, VL});
5060 if (MaskVT.isFixedLengthVector())
5061 Trunc = convertFromScalableVector(MaskVT, Trunc, DAG, Subtarget);
5062 return Trunc;
5063 }
5064
5065 SDValue RISCVTargetLowering::lowerVectorTruncLike(SDValue Op,
5066 SelectionDAG &DAG) const {
5067 bool IsVPTrunc = Op.getOpcode() == ISD::VP_TRUNCATE;
5068 SDLoc DL(Op);
5069
5070 MVT VT = Op.getSimpleValueType();
5071 // Only custom-lower vector truncates
5072 assert(VT.isVector() && "Unexpected type for vector truncate lowering");
5073
5074 // Truncates to mask types are handled differently
5075 if (VT.getVectorElementType() == MVT::i1)
5076 return lowerVectorMaskTruncLike(Op, DAG);
5077
5078 // RVV only has truncates which operate from SEW*2->SEW, so lower arbitrary
5079 // truncates as a series of "RISCVISD::TRUNCATE_VECTOR_VL" nodes which
5080 // truncate by one power of two at a time.
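// For example, an i64 -> i8 element truncate is emitted as three such steps:
// i64 -> i32 -> i16 -> i8.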
5081 MVT DstEltVT = VT.getVectorElementType();
5082
5083 SDValue Src = Op.getOperand(0);
5084 MVT SrcVT = Src.getSimpleValueType();
5085 MVT SrcEltVT = SrcVT.getVectorElementType();
5086
5087 assert(DstEltVT.bitsLT(SrcEltVT) && isPowerOf2_64(DstEltVT.getSizeInBits()) &&
5088 isPowerOf2_64(SrcEltVT.getSizeInBits()) &&
5089 "Unexpected vector truncate lowering");
5090
5091 MVT ContainerVT = SrcVT;
5092 SDValue Mask, VL;
5093 if (IsVPTrunc) {
5094 Mask = Op.getOperand(1);
5095 VL = Op.getOperand(2);
5096 }
5097 if (SrcVT.isFixedLengthVector()) {
5098 ContainerVT = getContainerForFixedLengthVector(SrcVT);
5099 Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
5100 if (IsVPTrunc) {
5101 MVT MaskVT = getMaskTypeFor(ContainerVT);
5102 Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
5103 }
5104 }
5105
5106 SDValue Result = Src;
5107 if (!IsVPTrunc) {
5108 std::tie(Mask, VL) =
5109 getDefaultVLOps(SrcVT, ContainerVT, DL, DAG, Subtarget);
5110 }
5111
5112 LLVMContext &Context = *DAG.getContext();
5113 const ElementCount Count = ContainerVT.getVectorElementCount();
5114 do {
5115 SrcEltVT = MVT::getIntegerVT(SrcEltVT.getSizeInBits() / 2);
5116 EVT ResultVT = EVT::getVectorVT(Context, SrcEltVT, Count);
5117 Result = DAG.getNode(RISCVISD::TRUNCATE_VECTOR_VL, DL, ResultVT, Result,
5118 Mask, VL);
5119 } while (SrcEltVT != DstEltVT);
5120
5121 if (SrcVT.isFixedLengthVector())
5122 Result = convertFromScalableVector(VT, Result, DAG, Subtarget);
5123
5124 return Result;
5125 }
5126
5127 SDValue
5128 RISCVTargetLowering::lowerVectorFPExtendOrRoundLike(SDValue Op,
5129 SelectionDAG &DAG) const {
5130 bool IsVP =
5131 Op.getOpcode() == ISD::VP_FP_ROUND || Op.getOpcode() == ISD::VP_FP_EXTEND;
5132 bool IsExtend =
5133 Op.getOpcode() == ISD::VP_FP_EXTEND || Op.getOpcode() == ISD::FP_EXTEND;
5134 // RVV can only truncate fp to types half the size of the source. We
5135 // custom-lower f64->f16 rounds via RVV's round-to-odd float
5136 // conversion instruction.
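// For example (sketch), a vector fpround from f64 to f16 is emitted in two
// hops: f64 -> f32 with round-to-odd (VFNCVT_ROD_VL) and then f32 -> f16 with
// the ordinary rounding mode; rounding to odd in the first step is what
// avoids double rounding.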
5137 SDLoc DL(Op);
5138 MVT VT = Op.getSimpleValueType();
5139
5140 assert(VT.isVector() && "Unexpected type for vector truncate lowering");
5141
5142 SDValue Src = Op.getOperand(0);
5143 MVT SrcVT = Src.getSimpleValueType();
5144
5145 bool IsDirectExtend = IsExtend && (VT.getVectorElementType() != MVT::f64 ||
5146 SrcVT.getVectorElementType() != MVT::f16);
5147 bool IsDirectTrunc = !IsExtend && (VT.getVectorElementType() != MVT::f16 ||
5148 SrcVT.getVectorElementType() != MVT::f64);
5149
5150 bool IsDirectConv = IsDirectExtend || IsDirectTrunc;
5151
5152 // Prepare any fixed-length vector operands.
5153 MVT ContainerVT = VT;
5154 SDValue Mask, VL;
5155 if (IsVP) {
5156 Mask = Op.getOperand(1);
5157 VL = Op.getOperand(2);
5158 }
5159 if (VT.isFixedLengthVector()) {
5160 MVT SrcContainerVT = getContainerForFixedLengthVector(SrcVT);
5161 ContainerVT =
5162 SrcContainerVT.changeVectorElementType(VT.getVectorElementType());
5163 Src = convertToScalableVector(SrcContainerVT, Src, DAG, Subtarget);
5164 if (IsVP) {
5165 MVT MaskVT = getMaskTypeFor(ContainerVT);
5166 Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
5167 }
5168 }
5169
5170 if (!IsVP)
5171 std::tie(Mask, VL) =
5172 getDefaultVLOps(SrcVT, ContainerVT, DL, DAG, Subtarget);
5173
5174 unsigned ConvOpc = IsExtend ? RISCVISD::FP_EXTEND_VL : RISCVISD::FP_ROUND_VL;
5175
5176 if (IsDirectConv) {
5177 Src = DAG.getNode(ConvOpc, DL, ContainerVT, Src, Mask, VL);
5178 if (VT.isFixedLengthVector())
5179 Src = convertFromScalableVector(VT, Src, DAG, Subtarget);
5180 return Src;
5181 }
5182
5183 unsigned InterConvOpc =
5184 IsExtend ? RISCVISD::FP_EXTEND_VL : RISCVISD::VFNCVT_ROD_VL;
5185
5186 MVT InterVT = ContainerVT.changeVectorElementType(MVT::f32);
5187 SDValue IntermediateConv =
5188 DAG.getNode(InterConvOpc, DL, InterVT, Src, Mask, VL);
5189 SDValue Result =
5190 DAG.getNode(ConvOpc, DL, ContainerVT, IntermediateConv, Mask, VL);
5191 if (VT.isFixedLengthVector())
5192 return convertFromScalableVector(VT, Result, DAG, Subtarget);
5193 return Result;
5194 }
5195
5196 // Custom-legalize INSERT_VECTOR_ELT so that the value is inserted into the
5197 // first position of a vector, and that vector is slid up to the insert index.
5198 // By limiting the active vector length to index+1 and merging with the
5199 // original vector (with an undisturbed tail policy for elements >= VL), we
5200 // achieve the desired result of leaving all elements untouched except the one
5201 // at VL-1, which is replaced with the desired value.
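// Sketch of the sequence for inserting at, say, index 2 (index chosen only as
// an example):
//   valvec = VMV_S_X_VL / VFMV_S_F_VL-style move of the scalar into element 0
//   vslideup of valvec into the source vector with VL = 3 (index + 1) and a
//   tail-undisturbed policy, so every other element keeps its old value.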
5202 SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op,
5203 SelectionDAG &DAG) const {
5204 SDLoc DL(Op);
5205 MVT VecVT = Op.getSimpleValueType();
5206 SDValue Vec = Op.getOperand(0);
5207 SDValue Val = Op.getOperand(1);
5208 SDValue Idx = Op.getOperand(2);
5209
5210 if (VecVT.getVectorElementType() == MVT::i1) {
5211 // FIXME: For now we just promote to an i8 vector and insert into that,
5212 // but this is probably not optimal.
5213 MVT WideVT = MVT::getVectorVT(MVT::i8, VecVT.getVectorElementCount());
5214 Vec = DAG.getNode(ISD::ZERO_EXTEND, DL, WideVT, Vec);
5215 Vec = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, WideVT, Vec, Val, Idx);
5216 return DAG.getNode(ISD::TRUNCATE, DL, VecVT, Vec);
5217 }
5218
5219 MVT ContainerVT = VecVT;
5220 // If the operand is a fixed-length vector, convert to a scalable one.
5221 if (VecVT.isFixedLengthVector()) {
5222 ContainerVT = getContainerForFixedLengthVector(VecVT);
5223 Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
5224 }
5225
5226 MVT XLenVT = Subtarget.getXLenVT();
5227
5228 SDValue Zero = DAG.getConstant(0, DL, XLenVT);
5229 bool IsLegalInsert = Subtarget.is64Bit() || Val.getValueType() != MVT::i64;
5230 // Even i64-element vectors on RV32 can be lowered without scalar
5231 // legalization if the most-significant 32 bits of the value are not affected
5232 // by the sign-extension of the lower 32 bits.
5233 // TODO: We could also catch sign extensions of a 32-bit value.
5234 if (!IsLegalInsert && isa<ConstantSDNode>(Val)) {
5235 const auto *CVal = cast<ConstantSDNode>(Val);
5236 if (isInt<32>(CVal->getSExtValue())) {
5237 IsLegalInsert = true;
5238 Val = DAG.getConstant(CVal->getSExtValue(), DL, MVT::i32);
5239 }
5240 }
5241
5242 auto [Mask, VL] = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget);
5243
5244 SDValue ValInVec;
5245
5246 if (IsLegalInsert) {
5247 unsigned Opc =
5248 VecVT.isFloatingPoint() ? RISCVISD::VFMV_S_F_VL : RISCVISD::VMV_S_X_VL;
5249 if (isNullConstant(Idx)) {
5250 Vec = DAG.getNode(Opc, DL, ContainerVT, Vec, Val, VL);
5251 if (!VecVT.isFixedLengthVector())
5252 return Vec;
5253 return convertFromScalableVector(VecVT, Vec, DAG, Subtarget);
5254 }
5255 ValInVec = lowerScalarInsert(Val, VL, ContainerVT, DL, DAG, Subtarget);
5256 } else {
5257 // On RV32, i64-element vectors must be specially handled to place the
5258 // value at element 0, by using two vslide1down instructions in sequence on
5259 // the i32 split lo/hi value. Use an equivalently-sized i32 vector for
5260 // this.
5261 SDValue One = DAG.getConstant(1, DL, XLenVT);
5262 SDValue ValLo = DAG.getNode(ISD::EXTRACT_ELEMENT, DL, MVT::i32, Val, Zero);
5263 SDValue ValHi = DAG.getNode(ISD::EXTRACT_ELEMENT, DL, MVT::i32, Val, One);
5264 MVT I32ContainerVT =
5265 MVT::getVectorVT(MVT::i32, ContainerVT.getVectorElementCount() * 2);
5266 SDValue I32Mask =
5267 getDefaultScalableVLOps(I32ContainerVT, DL, DAG, Subtarget).first;
5268 // Limit the active VL to two.
5269 SDValue InsertI64VL = DAG.getConstant(2, DL, XLenVT);
5270 // If the Idx is 0 we can insert directly into the vector.
5271 if (isNullConstant(Idx)) {
5272 // First slide in the lo value, then the hi in above it. We use slide1down
5273 // to avoid the register group overlap constraint of vslide1up.
5274 ValInVec = DAG.getNode(RISCVISD::VSLIDE1DOWN_VL, DL, I32ContainerVT,
5275 Vec, Vec, ValLo, I32Mask, InsertI64VL);
5276 // If the source vector is undef don't pass along the tail elements from
5277 // the previous slide1down.
5278 SDValue Tail = Vec.isUndef() ? Vec : ValInVec;
5279 ValInVec = DAG.getNode(RISCVISD::VSLIDE1DOWN_VL, DL, I32ContainerVT,
5280 Tail, ValInVec, ValHi, I32Mask, InsertI64VL);
5281 // Bitcast back to the right container type.
5282 ValInVec = DAG.getBitcast(ContainerVT, ValInVec);
5283
5284 if (!VecVT.isFixedLengthVector())
5285 return ValInVec;
5286 return convertFromScalableVector(VecVT, ValInVec, DAG, Subtarget);
5287 }
5288
5289 // First slide in the lo value, then the hi in above it. We use slide1down
5290 // to avoid the register group overlap constraint of vslide1up.
5291 ValInVec = DAG.getNode(RISCVISD::VSLIDE1DOWN_VL, DL, I32ContainerVT,
5292 DAG.getUNDEF(I32ContainerVT),
5293 DAG.getUNDEF(I32ContainerVT), ValLo,
5294 I32Mask, InsertI64VL);
5295 ValInVec = DAG.getNode(RISCVISD::VSLIDE1DOWN_VL, DL, I32ContainerVT,
5296 DAG.getUNDEF(I32ContainerVT), ValInVec, ValHi,
5297 I32Mask, InsertI64VL);
5298 // Bitcast back to the right container type.
5299 ValInVec = DAG.getBitcast(ContainerVT, ValInVec);
5300 }
5301
5302 // Now that the value is in a vector, slide it into position.
5303 SDValue InsertVL =
5304 DAG.getNode(ISD::ADD, DL, XLenVT, Idx, DAG.getConstant(1, DL, XLenVT));
5305
5306 // Use tail agnostic policy if Idx is the last index of Vec.
5307 unsigned Policy = RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED;
5308 if (VecVT.isFixedLengthVector() && isa<ConstantSDNode>(Idx) &&
5309 cast<ConstantSDNode>(Idx)->getZExtValue() + 1 ==
5310 VecVT.getVectorNumElements())
5311 Policy = RISCVII::TAIL_AGNOSTIC;
5312 SDValue Slideup = getVSlideup(DAG, Subtarget, DL, ContainerVT, Vec, ValInVec,
5313 Idx, Mask, InsertVL, Policy);
5314 if (!VecVT.isFixedLengthVector())
5315 return Slideup;
5316 return convertFromScalableVector(VecVT, Slideup, DAG, Subtarget);
5317 }
5318
5319 // Custom-lower EXTRACT_VECTOR_ELT operations to slide the vector down, then
5320 // extract the first element: (extractelt (slidedown vec, idx), 0). For integer
5321 // types this is done using VMV_X_S to allow us to glean information about the
5322 // sign bits of the result.
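// For example (sketch), (extractelt %v, 5) is lowered roughly as a
// vslidedown of %v by 5 (with VL limited to 1), followed by VMV_X_S of
// element 0 for integer elements, or a plain extract of element 0 handled by
// TableGen patterns for FP elements.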
5323 SDValue RISCVTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op,
5324 SelectionDAG &DAG) const {
5325 SDLoc DL(Op);
5326 SDValue Idx = Op.getOperand(1);
5327 SDValue Vec = Op.getOperand(0);
5328 EVT EltVT = Op.getValueType();
5329 MVT VecVT = Vec.getSimpleValueType();
5330 MVT XLenVT = Subtarget.getXLenVT();
5331
5332 if (VecVT.getVectorElementType() == MVT::i1) {
5333 // Use vfirst.m to extract the first bit.
5334 if (isNullConstant(Idx)) {
5335 MVT ContainerVT = VecVT;
5336 if (VecVT.isFixedLengthVector()) {
5337 ContainerVT = getContainerForFixedLengthVector(VecVT);
5338 Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
5339 }
5340 auto [Mask, VL] = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget);
5341 SDValue Vfirst =
5342 DAG.getNode(RISCVISD::VFIRST_VL, DL, XLenVT, Vec, Mask, VL);
5343 return DAG.getSetCC(DL, XLenVT, Vfirst, DAG.getConstant(0, DL, XLenVT),
5344 ISD::SETEQ);
5345 }
5346 if (VecVT.isFixedLengthVector()) {
5347 unsigned NumElts = VecVT.getVectorNumElements();
5348 if (NumElts >= 8) {
5349 MVT WideEltVT;
5350 unsigned WidenVecLen;
5351 SDValue ExtractElementIdx;
5352 SDValue ExtractBitIdx;
5353 unsigned MaxEEW = Subtarget.getELEN();
5354 MVT LargestEltVT = MVT::getIntegerVT(
5355 std::min(MaxEEW, unsigned(XLenVT.getSizeInBits())));
5356 if (NumElts <= LargestEltVT.getSizeInBits()) {
5357 assert(isPowerOf2_32(NumElts) &&
5358 "the number of elements should be power of 2");
5359 WideEltVT = MVT::getIntegerVT(NumElts);
5360 WidenVecLen = 1;
5361 ExtractElementIdx = DAG.getConstant(0, DL, XLenVT);
5362 ExtractBitIdx = Idx;
5363 } else {
5364 WideEltVT = LargestEltVT;
5365 WidenVecLen = NumElts / WideEltVT.getSizeInBits();
5366 // extract element index = index / element width
5367 ExtractElementIdx = DAG.getNode(
5368 ISD::SRL, DL, XLenVT, Idx,
5369 DAG.getConstant(Log2_64(WideEltVT.getSizeInBits()), DL, XLenVT));
5370 // mask bit index = index % element width
5371 ExtractBitIdx = DAG.getNode(
5372 ISD::AND, DL, XLenVT, Idx,
5373 DAG.getConstant(WideEltVT.getSizeInBits() - 1, DL, XLenVT));
5374 }
5375 MVT WideVT = MVT::getVectorVT(WideEltVT, WidenVecLen);
5376 Vec = DAG.getNode(ISD::BITCAST, DL, WideVT, Vec);
5377 SDValue ExtractElt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, XLenVT,
5378 Vec, ExtractElementIdx);
5379 // Extract the bit from GPR.
5380 SDValue ShiftRight =
5381 DAG.getNode(ISD::SRL, DL, XLenVT, ExtractElt, ExtractBitIdx);
5382 return DAG.getNode(ISD::AND, DL, XLenVT, ShiftRight,
5383 DAG.getConstant(1, DL, XLenVT));
5384 }
5385 }
5386 // Otherwise, promote to an i8 vector and extract from that.
5387 MVT WideVT = MVT::getVectorVT(MVT::i8, VecVT.getVectorElementCount());
5388 Vec = DAG.getNode(ISD::ZERO_EXTEND, DL, WideVT, Vec);
5389 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Vec, Idx);
5390 }
5391
5392 // If this is a fixed vector, we need to convert it to a scalable vector.
5393 MVT ContainerVT = VecVT;
5394 if (VecVT.isFixedLengthVector()) {
5395 ContainerVT = getContainerForFixedLengthVector(VecVT);
5396 Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
5397 }
5398
5399 // If the index is 0, the vector is already in the right position.
5400 if (!isNullConstant(Idx)) {
5401 // Use a VL of 1 to avoid processing more elements than we need.
5402 auto [Mask, VL] = getDefaultVLOps(1, ContainerVT, DL, DAG, Subtarget);
5403 Vec = getVSlidedown(DAG, Subtarget, DL, ContainerVT,
5404 DAG.getUNDEF(ContainerVT), Vec, Idx, Mask, VL);
5405 }
5406
5407 if (!EltVT.isInteger()) {
5408 // Floating-point extracts are handled in TableGen.
5409 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Vec,
5410 DAG.getConstant(0, DL, XLenVT));
5411 }
5412
5413 SDValue Elt0 = DAG.getNode(RISCVISD::VMV_X_S, DL, XLenVT, Vec);
5414 return DAG.getNode(ISD::TRUNCATE, DL, EltVT, Elt0);
5415 }
5416
5417 // Some RVV intrinsics may claim that they want an integer operand to be
5418 // promoted or expanded.
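// For example, on RV32 an i64 scalar operand of an SEW=64 intrinsic does not
// fit in a single GPR; the code in this function either truncates it (when it
// is provably sign-extended) or handles it specially, e.g. the
// vslide1up/vslide1down family below is split into two SEW=32 slides.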
5419 static SDValue lowerVectorIntrinsicScalars(SDValue Op, SelectionDAG &DAG,
5420 const RISCVSubtarget &Subtarget) {
5421 assert((Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
5422 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN) &&
5423 "Unexpected opcode");
5424
5425 if (!Subtarget.hasVInstructions())
5426 return SDValue();
5427
5428 bool HasChain = Op.getOpcode() == ISD::INTRINSIC_W_CHAIN;
5429 unsigned IntNo = Op.getConstantOperandVal(HasChain ? 1 : 0);
5430 SDLoc DL(Op);
5431
5432 const RISCVVIntrinsicsTable::RISCVVIntrinsicInfo *II =
5433 RISCVVIntrinsicsTable::getRISCVVIntrinsicInfo(IntNo);
5434 if (!II || !II->hasScalarOperand())
5435 return SDValue();
5436
5437 unsigned SplatOp = II->ScalarOperand + 1 + HasChain;
5438 assert(SplatOp < Op.getNumOperands());
5439
5440 SmallVector<SDValue, 8> Operands(Op->op_begin(), Op->op_end());
5441 SDValue &ScalarOp = Operands[SplatOp];
5442 MVT OpVT = ScalarOp.getSimpleValueType();
5443 MVT XLenVT = Subtarget.getXLenVT();
5444
5445 // If this isn't a scalar, or its type is XLenVT we're done.
5446 if (!OpVT.isScalarInteger() || OpVT == XLenVT)
5447 return SDValue();
5448
5449 // Simplest case is that the operand needs to be promoted to XLenVT.
5450 if (OpVT.bitsLT(XLenVT)) {
5451 // If the operand is a constant, sign extend to increase our chances
5452 // of being able to use a .vi instruction. ANY_EXTEND would become a
5453 // zero extend and the simm5 check in isel would fail.
5454 // FIXME: Should we ignore the upper bits in isel instead?
5455 unsigned ExtOpc =
5456 isa<ConstantSDNode>(ScalarOp) ? ISD::SIGN_EXTEND : ISD::ANY_EXTEND;
5457 ScalarOp = DAG.getNode(ExtOpc, DL, XLenVT, ScalarOp);
5458 return DAG.getNode(Op->getOpcode(), DL, Op->getVTList(), Operands);
5459 }
5460
5461 // Use the previous operand to get the vXi64 VT. The result might be a mask
5462 // VT for compares. Using the previous operand assumes that the previous
5463 // operand will never have a smaller element size than a scalar operand and
5464 // that a widening operation never uses SEW=64.
5465 // NOTE: If this fails the below assert, we can probably just find the
5466 // element count from any operand or result and use it to construct the VT.
5467 assert(II->ScalarOperand > 0 && "Unexpected splat operand!");
5468 MVT VT = Op.getOperand(SplatOp - 1).getSimpleValueType();
5469
5470 // The more complex case is when the scalar is larger than XLenVT.
5471 assert(XLenVT == MVT::i32 && OpVT == MVT::i64 &&
5472 VT.getVectorElementType() == MVT::i64 && "Unexpected VTs!");
5473
5474 // If this is a sign-extended 32-bit value, we can truncate it and rely on the
5475 // instruction to sign-extend since SEW>XLEN.
5476 if (DAG.ComputeNumSignBits(ScalarOp) > 32) {
5477 ScalarOp = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, ScalarOp);
5478 return DAG.getNode(Op->getOpcode(), DL, Op->getVTList(), Operands);
5479 }
5480
5481 switch (IntNo) {
5482 case Intrinsic::riscv_vslide1up:
5483 case Intrinsic::riscv_vslide1down:
5484 case Intrinsic::riscv_vslide1up_mask:
5485 case Intrinsic::riscv_vslide1down_mask: {
5486 // We need to special case these when the scalar is larger than XLen.
5487 unsigned NumOps = Op.getNumOperands();
5488 bool IsMasked = NumOps == 7;
5489
5490 // Convert the vector source to the equivalent nxvXi32 vector.
5491 MVT I32VT = MVT::getVectorVT(MVT::i32, VT.getVectorElementCount() * 2);
5492 SDValue Vec = DAG.getBitcast(I32VT, Operands[2]);
5493
5494 SDValue ScalarLo = DAG.getNode(ISD::EXTRACT_ELEMENT, DL, MVT::i32, ScalarOp,
5495 DAG.getConstant(0, DL, XLenVT));
5496 SDValue ScalarHi = DAG.getNode(ISD::EXTRACT_ELEMENT, DL, MVT::i32, ScalarOp,
5497 DAG.getConstant(1, DL, XLenVT));
5498
5499 // Double the VL since we halved SEW.
5500 SDValue AVL = getVLOperand(Op);
5501 SDValue I32VL;
5502
5503 // Optimize for constant AVL
5504 if (isa<ConstantSDNode>(AVL)) {
5505 unsigned EltSize = VT.getScalarSizeInBits();
5506 unsigned MinSize = VT.getSizeInBits().getKnownMinValue();
5507
5508 unsigned VectorBitsMax = Subtarget.getRealMaxVLen();
5509 unsigned MaxVLMAX =
5510 RISCVTargetLowering::computeVLMAX(VectorBitsMax, EltSize, MinSize);
5511
5512 unsigned VectorBitsMin = Subtarget.getRealMinVLen();
5513 unsigned MinVLMAX =
5514 RISCVTargetLowering::computeVLMAX(VectorBitsMin, EltSize, MinSize);
5515
5516 uint64_t AVLInt = cast<ConstantSDNode>(AVL)->getZExtValue();
5517 if (AVLInt <= MinVLMAX) {
5518 I32VL = DAG.getConstant(2 * AVLInt, DL, XLenVT);
5519 } else if (AVLInt >= 2 * MaxVLMAX) {
5520 // Just set vl to VLMAX in this situation
5521 RISCVII::VLMUL Lmul = RISCVTargetLowering::getLMUL(I32VT);
5522 SDValue LMUL = DAG.getConstant(Lmul, DL, XLenVT);
5523 unsigned Sew = RISCVVType::encodeSEW(I32VT.getScalarSizeInBits());
5524 SDValue SEW = DAG.getConstant(Sew, DL, XLenVT);
5525 SDValue SETVLMAX = DAG.getTargetConstant(
5526 Intrinsic::riscv_vsetvlimax_opt, DL, MVT::i32);
5527 I32VL = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, XLenVT, SETVLMAX, SEW,
5528 LMUL);
5529 } else {
5530 // For AVL between (MinVLMAX, 2 * MaxVLMAX), the actual working vl
5531 // is related to the hardware implementation.
5532 // So let the code below handle it.
5533 }
5534 }
5535 if (!I32VL) {
5536 RISCVII::VLMUL Lmul = RISCVTargetLowering::getLMUL(VT);
5537 SDValue LMUL = DAG.getConstant(Lmul, DL, XLenVT);
5538 unsigned Sew = RISCVVType::encodeSEW(VT.getScalarSizeInBits());
5539 SDValue SEW = DAG.getConstant(Sew, DL, XLenVT);
5540 SDValue SETVL =
5541 DAG.getTargetConstant(Intrinsic::riscv_vsetvli_opt, DL, MVT::i32);
5542 // Use a vsetvli instruction to get the actual vector length, which depends
5543 // on the hardware implementation.
5544 SDValue VL = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, XLenVT, SETVL, AVL,
5545 SEW, LMUL);
5546 I32VL =
5547 DAG.getNode(ISD::SHL, DL, XLenVT, VL, DAG.getConstant(1, DL, XLenVT));
5548 }
5549
5550 SDValue I32Mask = getAllOnesMask(I32VT, I32VL, DL, DAG);
5551
5552 // Shift the two scalar parts in using SEW=32 slide1up/slide1down
5553 // instructions.
5554 SDValue Passthru;
5555 if (IsMasked)
5556 Passthru = DAG.getUNDEF(I32VT);
5557 else
5558 Passthru = DAG.getBitcast(I32VT, Operands[1]);
5559
5560 if (IntNo == Intrinsic::riscv_vslide1up ||
5561 IntNo == Intrinsic::riscv_vslide1up_mask) {
5562 Vec = DAG.getNode(RISCVISD::VSLIDE1UP_VL, DL, I32VT, Passthru, Vec,
5563 ScalarHi, I32Mask, I32VL);
5564 Vec = DAG.getNode(RISCVISD::VSLIDE1UP_VL, DL, I32VT, Passthru, Vec,
5565 ScalarLo, I32Mask, I32VL);
5566 } else {
5567 Vec = DAG.getNode(RISCVISD::VSLIDE1DOWN_VL, DL, I32VT, Passthru, Vec,
5568 ScalarLo, I32Mask, I32VL);
5569 Vec = DAG.getNode(RISCVISD::VSLIDE1DOWN_VL, DL, I32VT, Passthru, Vec,
5570 ScalarHi, I32Mask, I32VL);
5571 }
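// Illustrative sketch of the sequence built above for an unmasked vslide1up
// with a 64-bit scalar on RV32 (register names are assumed):
//   vsetvli zero, t0, e32, ...    ; t0 holds the doubled VL at SEW=32
//   vslide1up.vx v12, v8, a1      ; insert the high 32 bits first
//   vslide1up.vx v8, v12, a0      ; then the low 32 bits
// The vslide1down case emits the same pair with the low half first.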
5572
5573 // Convert back to nxvXi64.
5574 Vec = DAG.getBitcast(VT, Vec);
5575
5576 if (!IsMasked)
5577 return Vec;
5578 // Apply mask after the operation.
5579 SDValue Mask = Operands[NumOps - 3];
5580 SDValue MaskedOff = Operands[1];
5581 // Assume Policy operand is the last operand.
5582 uint64_t Policy =
5583 cast<ConstantSDNode>(Operands[NumOps - 1])->getZExtValue();
5584 // We don't need to select maskedoff if it's undef.
5585 if (MaskedOff.isUndef())
5586 return Vec;
5587 // TAMU
5588 if (Policy == RISCVII::TAIL_AGNOSTIC)
5589 return DAG.getNode(RISCVISD::VSELECT_VL, DL, VT, Mask, Vec, MaskedOff,
5590 AVL);
5591 // TUMA or TUMU: Currently we always emit tumu policy regardless of tuma.
5592 // It's fine because vmerge does not care about the mask policy.
5593 return DAG.getNode(RISCVISD::VP_MERGE_VL, DL, VT, Mask, Vec, MaskedOff,
5594 AVL);
5595 }
5596 }
5597
5598 // We need to convert the scalar to a splat vector.
5599 SDValue VL = getVLOperand(Op);
5600 assert(VL.getValueType() == XLenVT);
5601 ScalarOp = splatSplitI64WithVL(DL, VT, SDValue(), ScalarOp, VL, DAG);
5602 return DAG.getNode(Op->getOpcode(), DL, Op->getVTList(), Operands);
5603 }
5604
5605 SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
5606 SelectionDAG &DAG) const {
5607 unsigned IntNo = Op.getConstantOperandVal(0);
5608 SDLoc DL(Op);
5609 MVT XLenVT = Subtarget.getXLenVT();
5610
5611 switch (IntNo) {
5612 default:
5613 break; // Don't custom lower most intrinsics.
5614 case Intrinsic::thread_pointer: {
5615 EVT PtrVT = getPointerTy(DAG.getDataLayout());
5616 return DAG.getRegister(RISCV::X4, PtrVT);
5617 }
5618 case Intrinsic::riscv_orc_b:
5619 case Intrinsic::riscv_brev8: {
5620 unsigned Opc =
5621 IntNo == Intrinsic::riscv_brev8 ? RISCVISD::BREV8 : RISCVISD::ORC_B;
5622 return DAG.getNode(Opc, DL, XLenVT, Op.getOperand(1));
5623 }
5624 case Intrinsic::riscv_zip:
5625 case Intrinsic::riscv_unzip: {
5626 unsigned Opc =
5627 IntNo == Intrinsic::riscv_zip ? RISCVISD::ZIP : RISCVISD::UNZIP;
5628 return DAG.getNode(Opc, DL, XLenVT, Op.getOperand(1));
5629 }
5630 case Intrinsic::riscv_vmv_x_s:
5631 assert(Op.getValueType() == XLenVT && "Unexpected VT!");
5632 return DAG.getNode(RISCVISD::VMV_X_S, DL, Op.getValueType(),
5633 Op.getOperand(1));
5634 case Intrinsic::riscv_vfmv_f_s:
5635 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, Op.getValueType(),
5636 Op.getOperand(1), DAG.getConstant(0, DL, XLenVT));
5637 case Intrinsic::riscv_vmv_v_x:
5638 return lowerScalarSplat(Op.getOperand(1), Op.getOperand(2),
5639 Op.getOperand(3), Op.getSimpleValueType(), DL, DAG,
5640 Subtarget);
5641 case Intrinsic::riscv_vfmv_v_f:
5642 return DAG.getNode(RISCVISD::VFMV_V_F_VL, DL, Op.getValueType(),
5643 Op.getOperand(1), Op.getOperand(2), Op.getOperand(3));
5644 case Intrinsic::riscv_vmv_s_x: {
5645 SDValue Scalar = Op.getOperand(2);
5646
5647 if (Scalar.getValueType().bitsLE(XLenVT)) {
5648 Scalar = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Scalar);
5649 return DAG.getNode(RISCVISD::VMV_S_X_VL, DL, Op.getValueType(),
5650 Op.getOperand(1), Scalar, Op.getOperand(3));
5651 }
5652
5653 assert(Scalar.getValueType() == MVT::i64 && "Unexpected scalar VT!");
5654
5655 // This is an i64 value that lives in two scalar registers. We have to
5656 // insert this in a convoluted way. First we build a vXi64 splat containing
5657 // the two values that we assemble using some bit math. Next we'll use
5658 // vid.v and vmseq to build a mask with bit 0 set. Then we'll use that mask
5659 // to merge element 0 from our splat into the source vector.
5660 // FIXME: This is probably not the best way to do this, but it is
5661 // consistent with INSERT_VECTOR_ELT lowering so it is a good starting
5662 // point.
5663 // sw lo, (a0)
5664 // sw hi, 4(a0)
5665 // vlse vX, (a0)
5666 //
5667 // vid.v vVid
5668 // vmseq.vx mMask, vVid, 0
5669 // vmerge.vvm vDest, vSrc, vVal, mMask
5670 MVT VT = Op.getSimpleValueType();
5671 SDValue Vec = Op.getOperand(1);
5672 SDValue VL = getVLOperand(Op);
5673
5674 SDValue SplattedVal = splatSplitI64WithVL(DL, VT, SDValue(), Scalar, VL, DAG);
5675 if (Op.getOperand(1).isUndef())
5676 return SplattedVal;
5677 SDValue SplattedIdx =
5678 DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, DAG.getUNDEF(VT),
5679 DAG.getConstant(0, DL, MVT::i32), VL);
5680
5681 MVT MaskVT = getMaskTypeFor(VT);
5682 SDValue Mask = getAllOnesMask(VT, VL, DL, DAG);
5683 SDValue VID = DAG.getNode(RISCVISD::VID_VL, DL, VT, Mask, VL);
5684 SDValue SelectCond =
5685 DAG.getNode(RISCVISD::SETCC_VL, DL, MaskVT,
5686 {VID, SplattedIdx, DAG.getCondCode(ISD::SETEQ),
5687 DAG.getUNDEF(MaskVT), Mask, VL});
5688 return DAG.getNode(RISCVISD::VSELECT_VL, DL, VT, SelectCond, SplattedVal,
5689 Vec, VL);
5690 }
5691 }
5692
5693 return lowerVectorIntrinsicScalars(Op, DAG, Subtarget);
5694 }
5695
5696 SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
5697 SelectionDAG &DAG) const {
5698 unsigned IntNo = Op.getConstantOperandVal(1);
5699 switch (IntNo) {
5700 default:
5701 break;
5702 case Intrinsic::riscv_masked_strided_load: {
5703 SDLoc DL(Op);
5704 MVT XLenVT = Subtarget.getXLenVT();
5705
5706 // If the mask is known to be all ones, optimize to an unmasked intrinsic;
5707 // the selection of the masked intrinsics doesn't do this for us.
5708 SDValue Mask = Op.getOperand(5);
5709 bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
5710
5711 MVT VT = Op->getSimpleValueType(0);
5712 MVT ContainerVT = VT;
5713 if (VT.isFixedLengthVector())
5714 ContainerVT = getContainerForFixedLengthVector(VT);
5715
5716 SDValue PassThru = Op.getOperand(2);
5717 if (!IsUnmasked) {
5718 MVT MaskVT = getMaskTypeFor(ContainerVT);
5719 if (VT.isFixedLengthVector()) {
5720 Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
5721 PassThru = convertToScalableVector(ContainerVT, PassThru, DAG, Subtarget);
5722 }
5723 }
5724
5725 auto *Load = cast<MemIntrinsicSDNode>(Op);
5726 SDValue VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
5727 SDValue Ptr = Op.getOperand(3);
5728 SDValue Stride = Op.getOperand(4);
5729 SDValue Result, Chain;
5730
5731 // TODO: We restrict this to unmasked loads currently in consideration of
5732 // the complexity of handling all-false masks.
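// A zero-stride unmasked load reads the same address for every element, so
// it can be lowered to a single scalar load plus a splat; roughly
// (illustrative only, assumed registers):
//   lw      a1, 0(a0)
//   vmv.v.x v8, a1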
5733 if (IsUnmasked && isNullConstant(Stride)) {
5734 MVT ScalarVT = ContainerVT.getVectorElementType();
5735 SDValue ScalarLoad =
5736 DAG.getExtLoad(ISD::ZEXTLOAD, DL, XLenVT, Load->getChain(), Ptr,
5737 ScalarVT, Load->getMemOperand());
5738 Chain = ScalarLoad.getValue(1);
5739 Result = lowerScalarSplat(SDValue(), ScalarLoad, VL, ContainerVT, DL, DAG,
5740 Subtarget);
5741 } else {
5742 SDValue IntID = DAG.getTargetConstant(
5743 IsUnmasked ? Intrinsic::riscv_vlse : Intrinsic::riscv_vlse_mask, DL,
5744 XLenVT);
5745
5746 SmallVector<SDValue, 8> Ops{Load->getChain(), IntID};
5747 if (IsUnmasked)
5748 Ops.push_back(DAG.getUNDEF(ContainerVT));
5749 else
5750 Ops.push_back(PassThru);
5751 Ops.push_back(Ptr);
5752 Ops.push_back(Stride);
5753 if (!IsUnmasked)
5754 Ops.push_back(Mask);
5755 Ops.push_back(VL);
5756 if (!IsUnmasked) {
5757 SDValue Policy =
5758 DAG.getTargetConstant(RISCVII::TAIL_AGNOSTIC, DL, XLenVT);
5759 Ops.push_back(Policy);
5760 }
5761
5762 SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
5763 Result =
5764 DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops,
5765 Load->getMemoryVT(), Load->getMemOperand());
5766 Chain = Result.getValue(1);
5767 }
5768 if (VT.isFixedLengthVector())
5769 Result = convertFromScalableVector(VT, Result, DAG, Subtarget);
5770 return DAG.getMergeValues({Result, Chain}, DL);
5771 }
5772 case Intrinsic::riscv_seg2_load:
5773 case Intrinsic::riscv_seg3_load:
5774 case Intrinsic::riscv_seg4_load:
5775 case Intrinsic::riscv_seg5_load:
5776 case Intrinsic::riscv_seg6_load:
5777 case Intrinsic::riscv_seg7_load:
5778 case Intrinsic::riscv_seg8_load: {
5779 SDLoc DL(Op);
5780 static const Intrinsic::ID VlsegInts[7] = {
5781 Intrinsic::riscv_vlseg2, Intrinsic::riscv_vlseg3,
5782 Intrinsic::riscv_vlseg4, Intrinsic::riscv_vlseg5,
5783 Intrinsic::riscv_vlseg6, Intrinsic::riscv_vlseg7,
5784 Intrinsic::riscv_vlseg8};
5785 unsigned NF = Op->getNumValues() - 1;
5786 assert(NF >= 2 && NF <= 8 && "Unexpected seg number");
5787 MVT XLenVT = Subtarget.getXLenVT();
5788 MVT VT = Op->getSimpleValueType(0);
5789 MVT ContainerVT = getContainerForFixedLengthVector(VT);
5790
5791 SDValue VL = getVLOp(VT.getVectorNumElements(), DL, DAG, Subtarget);
5792 SDValue IntID = DAG.getTargetConstant(VlsegInts[NF - 2], DL, XLenVT);
5793 auto *Load = cast<MemIntrinsicSDNode>(Op);
5794 SmallVector<EVT, 9> ContainerVTs(NF, ContainerVT);
5795 ContainerVTs.push_back(MVT::Other);
5796 SDVTList VTs = DAG.getVTList(ContainerVTs);
5797 SmallVector<SDValue, 12> Ops = {Load->getChain(), IntID};
5798 Ops.insert(Ops.end(), NF, DAG.getUNDEF(ContainerVT));
5799 Ops.push_back(Op.getOperand(2));
5800 Ops.push_back(VL);
5801 SDValue Result =
5802 DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops,
5803 Load->getMemoryVT(), Load->getMemOperand());
5804 SmallVector<SDValue, 9> Results;
5805 for (unsigned int RetIdx = 0; RetIdx < NF; RetIdx++)
5806 Results.push_back(convertFromScalableVector(VT, Result.getValue(RetIdx),
5807 DAG, Subtarget));
5808 Results.push_back(Result.getValue(NF));
5809 return DAG.getMergeValues(Results, DL);
5810 }
5811 }
5812
5813 return lowerVectorIntrinsicScalars(Op, DAG, Subtarget);
5814 }
5815
5816 SDValue RISCVTargetLowering::LowerINTRINSIC_VOID(SDValue Op,
5817 SelectionDAG &DAG) const {
5818 unsigned IntNo = Op.getConstantOperandVal(1);
5819 switch (IntNo) {
5820 default:
5821 break;
5822 case Intrinsic::riscv_masked_strided_store: {
5823 SDLoc DL(Op);
5824 MVT XLenVT = Subtarget.getXLenVT();
5825
5826 // If the mask is known to be all ones, optimize to an unmasked intrinsic;
5827 // the selection of the masked intrinsics doesn't do this for us.
5828 SDValue Mask = Op.getOperand(5);
5829 bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
5830
5831 SDValue Val = Op.getOperand(2);
5832 MVT VT = Val.getSimpleValueType();
5833 MVT ContainerVT = VT;
5834 if (VT.isFixedLengthVector()) {
5835 ContainerVT = getContainerForFixedLengthVector(VT);
5836 Val = convertToScalableVector(ContainerVT, Val, DAG, Subtarget);
5837 }
5838 if (!IsUnmasked) {
5839 MVT MaskVT = getMaskTypeFor(ContainerVT);
5840 if (VT.isFixedLengthVector())
5841 Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
5842 }
5843
5844 SDValue VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
5845
5846 SDValue IntID = DAG.getTargetConstant(
5847 IsUnmasked ? Intrinsic::riscv_vsse : Intrinsic::riscv_vsse_mask, DL,
5848 XLenVT);
5849
5850 auto *Store = cast<MemIntrinsicSDNode>(Op);
5851 SmallVector<SDValue, 8> Ops{Store->getChain(), IntID};
5852 Ops.push_back(Val);
5853 Ops.push_back(Op.getOperand(3)); // Ptr
5854 Ops.push_back(Op.getOperand(4)); // Stride
5855 if (!IsUnmasked)
5856 Ops.push_back(Mask);
5857 Ops.push_back(VL);
5858
5859 return DAG.getMemIntrinsicNode(ISD::INTRINSIC_VOID, DL, Store->getVTList(),
5860 Ops, Store->getMemoryVT(),
5861 Store->getMemOperand());
5862 }
5863 }
5864
5865 return SDValue();
5866 }
5867
5868 static unsigned getRVVReductionOp(unsigned ISDOpcode) {
5869 switch (ISDOpcode) {
5870 default:
5871 llvm_unreachable("Unhandled reduction");
5872 case ISD::VECREDUCE_ADD:
5873 return RISCVISD::VECREDUCE_ADD_VL;
5874 case ISD::VECREDUCE_UMAX:
5875 return RISCVISD::VECREDUCE_UMAX_VL;
5876 case ISD::VECREDUCE_SMAX:
5877 return RISCVISD::VECREDUCE_SMAX_VL;
5878 case ISD::VECREDUCE_UMIN:
5879 return RISCVISD::VECREDUCE_UMIN_VL;
5880 case ISD::VECREDUCE_SMIN:
5881 return RISCVISD::VECREDUCE_SMIN_VL;
5882 case ISD::VECREDUCE_AND:
5883 return RISCVISD::VECREDUCE_AND_VL;
5884 case ISD::VECREDUCE_OR:
5885 return RISCVISD::VECREDUCE_OR_VL;
5886 case ISD::VECREDUCE_XOR:
5887 return RISCVISD::VECREDUCE_XOR_VL;
5888 }
5889 }
5890
5891 SDValue RISCVTargetLowering::lowerVectorMaskVecReduction(SDValue Op,
5892 SelectionDAG &DAG,
5893 bool IsVP) const {
5894 SDLoc DL(Op);
5895 SDValue Vec = Op.getOperand(IsVP ? 1 : 0);
5896 MVT VecVT = Vec.getSimpleValueType();
5897 assert((Op.getOpcode() == ISD::VECREDUCE_AND ||
5898 Op.getOpcode() == ISD::VECREDUCE_OR ||
5899 Op.getOpcode() == ISD::VECREDUCE_XOR ||
5900 Op.getOpcode() == ISD::VP_REDUCE_AND ||
5901 Op.getOpcode() == ISD::VP_REDUCE_OR ||
5902 Op.getOpcode() == ISD::VP_REDUCE_XOR) &&
5903 "Unexpected reduction lowering");
5904
5905 MVT XLenVT = Subtarget.getXLenVT();
5906 assert(Op.getValueType() == XLenVT &&
5907 "Expected reduction output to be legalized to XLenVT");
5908
5909 MVT ContainerVT = VecVT;
5910 if (VecVT.isFixedLengthVector()) {
5911 ContainerVT = getContainerForFixedLengthVector(VecVT);
5912 Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
5913 }
5914
5915 SDValue Mask, VL;
5916 if (IsVP) {
5917 Mask = Op.getOperand(2);
5918 VL = Op.getOperand(3);
5919 } else {
5920 std::tie(Mask, VL) =
5921 getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget);
5922 }
5923
5924 unsigned BaseOpc;
5925 ISD::CondCode CC;
5926 SDValue Zero = DAG.getConstant(0, DL, XLenVT);
5927
5928 switch (Op.getOpcode()) {
5929 default:
5930 llvm_unreachable("Unhandled reduction");
5931 case ISD::VECREDUCE_AND:
5932 case ISD::VP_REDUCE_AND: {
5933 // vcpop ~x == 0
5934 SDValue TrueMask = DAG.getNode(RISCVISD::VMSET_VL, DL, ContainerVT, VL);
5935 Vec = DAG.getNode(RISCVISD::VMXOR_VL, DL, ContainerVT, Vec, TrueMask, VL);
5936 Vec = DAG.getNode(RISCVISD::VCPOP_VL, DL, XLenVT, Vec, Mask, VL);
5937 CC = ISD::SETEQ;
5938 BaseOpc = ISD::AND;
5939 break;
5940 }
5941 case ISD::VECREDUCE_OR:
5942 case ISD::VP_REDUCE_OR:
5943 // vcpop x != 0
5944 Vec = DAG.getNode(RISCVISD::VCPOP_VL, DL, XLenVT, Vec, Mask, VL);
5945 CC = ISD::SETNE;
5946 BaseOpc = ISD::OR;
5947 break;
5948 case ISD::VECREDUCE_XOR:
5949 case ISD::VP_REDUCE_XOR: {
5950 // ((vcpop x) & 1) != 0
5951 SDValue One = DAG.getConstant(1, DL, XLenVT);
5952 Vec = DAG.getNode(RISCVISD::VCPOP_VL, DL, XLenVT, Vec, Mask, VL);
5953 Vec = DAG.getNode(ISD::AND, DL, XLenVT, Vec, One);
5954 CC = ISD::SETNE;
5955 BaseOpc = ISD::XOR;
5956 break;
5957 }
5958 }
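// Illustrative sketch (assumed registers) of what the unmasked VECREDUCE_AND
// case above, combined with the setcc below, selects to:
//   vmnot.m v9, v8      ; the vmxor with an all-ones mask
//   vcpop.m a0, v9
//   seqz    a0, a0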
5959
5960 SDValue SetCC = DAG.getSetCC(DL, XLenVT, Vec, Zero, CC);
5961
5962 if (!IsVP)
5963 return SetCC;
5964
5965 // Now include the start value in the operation.
5966 // Note that we must return the start value when no elements are operated
5967 // upon. The vcpop instructions we've emitted in each case above will return
5968 // 0 for an inactive vector, and so we've already received the neutral value:
5969 // AND gives us (0 == 0) -> 1 and OR/XOR give us (0 != 0) -> 0. Therefore we
5970 // can simply include the start value.
5971 return DAG.getNode(BaseOpc, DL, XLenVT, SetCC, Op.getOperand(0));
5972 }
5973
5974 static bool hasNonZeroAVL(SDValue AVL) {
5975 auto *RegisterAVL = dyn_cast<RegisterSDNode>(AVL);
5976 auto *ImmAVL = dyn_cast<ConstantSDNode>(AVL);
5977 return (RegisterAVL && RegisterAVL->getReg() == RISCV::X0) ||
5978 (ImmAVL && ImmAVL->getZExtValue() >= 1);
5979 }
5980
5981 /// Helper to lower a reduction sequence of the form:
5982 /// scalar = reduce_op vec, scalar_start
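/// A rough, illustrative example (assumed registers) of the output for an
/// integer VECREDUCE_ADD with a known non-zero AVL:
///   vmv.s.x    v12, a0        ; start value into element 0 of an LMUL=1 reg
///   vredsum.vs v12, v8, v12
///   vmv.x.s    a0, v12        ; extract the scalar result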
5983 static SDValue lowerReductionSeq(unsigned RVVOpcode, MVT ResVT,
5984 SDValue StartValue, SDValue Vec, SDValue Mask,
5985 SDValue VL, SDLoc DL, SelectionDAG &DAG,
5986 const RISCVSubtarget &Subtarget) {
5987 const MVT VecVT = Vec.getSimpleValueType();
5988 const MVT M1VT = getLMUL1VT(VecVT);
5989 const MVT XLenVT = Subtarget.getXLenVT();
5990 const bool NonZeroAVL = hasNonZeroAVL(VL);
5991
5992 // The reduction needs an LMUL1 input; do the splat at either LMUL1
5993 // or the original VT if fractional.
5994 auto InnerVT = VecVT.bitsLE(M1VT) ? VecVT : M1VT;
5995 // We reuse the VL of the reduction to reduce vsetvli toggles if we can
5996 // prove it is non-zero. For the AVL=0 case, we need the scalar to
5997 // be the result of the reduction operation.
5998 auto InnerVL = NonZeroAVL ? VL : DAG.getConstant(1, DL, XLenVT);
5999 SDValue InitialValue = lowerScalarInsert(StartValue, InnerVL, InnerVT, DL,
6000 DAG, Subtarget);
6001 if (M1VT != InnerVT)
6002 InitialValue = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, M1VT,
6003 DAG.getUNDEF(M1VT),
6004 InitialValue, DAG.getConstant(0, DL, XLenVT));
6005 SDValue PassThru = NonZeroAVL ? DAG.getUNDEF(M1VT) : InitialValue;
6006 SDValue Reduction = DAG.getNode(RVVOpcode, DL, M1VT, PassThru, Vec,
6007 InitialValue, Mask, VL);
6008 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Reduction,
6009 DAG.getConstant(0, DL, XLenVT));
6010 }
6011
6012 SDValue RISCVTargetLowering::lowerVECREDUCE(SDValue Op,
6013 SelectionDAG &DAG) const {
6014 SDLoc DL(Op);
6015 SDValue Vec = Op.getOperand(0);
6016 EVT VecEVT = Vec.getValueType();
6017
6018 unsigned BaseOpc = ISD::getVecReduceBaseOpcode(Op.getOpcode());
6019
6020 // Due to ordering in legalize types we may have a vector type that needs to
6021 // be split. Do that manually so we can get down to a legal type.
6022 while (getTypeAction(*DAG.getContext(), VecEVT) ==
6023 TargetLowering::TypeSplitVector) {
6024 auto [Lo, Hi] = DAG.SplitVector(Vec, DL);
6025 VecEVT = Lo.getValueType();
6026 Vec = DAG.getNode(BaseOpc, DL, VecEVT, Lo, Hi);
6027 }
6028
6029 // TODO: The type may need to be widened rather than split. Or widened before
6030 // it can be split.
6031 if (!isTypeLegal(VecEVT))
6032 return SDValue();
6033
6034 MVT VecVT = VecEVT.getSimpleVT();
6035 MVT VecEltVT = VecVT.getVectorElementType();
6036 unsigned RVVOpcode = getRVVReductionOp(Op.getOpcode());
6037
6038 MVT ContainerVT = VecVT;
6039 if (VecVT.isFixedLengthVector()) {
6040 ContainerVT = getContainerForFixedLengthVector(VecVT);
6041 Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
6042 }
6043
6044 auto [Mask, VL] = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget);
6045
6046 SDValue NeutralElem =
6047 DAG.getNeutralElement(BaseOpc, DL, VecEltVT, SDNodeFlags());
6048 return lowerReductionSeq(RVVOpcode, Op.getSimpleValueType(), NeutralElem, Vec,
6049 Mask, VL, DL, DAG, Subtarget);
6050 }
6051
6052 // Given a reduction op, this function returns the matching reduction opcode,
6053 // the vector SDValue and the scalar SDValue required to lower this to a
6054 // RISCVISD node.
6055 static std::tuple<unsigned, SDValue, SDValue>
6056 getRVVFPReductionOpAndOperands(SDValue Op, SelectionDAG &DAG, EVT EltVT) {
6057 SDLoc DL(Op);
6058 auto Flags = Op->getFlags();
6059 unsigned Opcode = Op.getOpcode();
6060 unsigned BaseOpcode = ISD::getVecReduceBaseOpcode(Opcode);
6061 switch (Opcode) {
6062 default:
6063 llvm_unreachable("Unhandled reduction");
6064 case ISD::VECREDUCE_FADD: {
6065 // Use positive zero if we can. It is cheaper to materialize.
6066 SDValue Zero =
6067 DAG.getConstantFP(Flags.hasNoSignedZeros() ? 0.0 : -0.0, DL, EltVT);
6068 return std::make_tuple(RISCVISD::VECREDUCE_FADD_VL, Op.getOperand(0), Zero);
6069 }
6070 case ISD::VECREDUCE_SEQ_FADD:
6071 return std::make_tuple(RISCVISD::VECREDUCE_SEQ_FADD_VL, Op.getOperand(1),
6072 Op.getOperand(0));
6073 case ISD::VECREDUCE_FMIN:
6074 return std::make_tuple(RISCVISD::VECREDUCE_FMIN_VL, Op.getOperand(0),
6075 DAG.getNeutralElement(BaseOpcode, DL, EltVT, Flags));
6076 case ISD::VECREDUCE_FMAX:
6077 return std::make_tuple(RISCVISD::VECREDUCE_FMAX_VL, Op.getOperand(0),
6078 DAG.getNeutralElement(BaseOpcode, DL, EltVT, Flags));
6079 }
6080 }
6081
6082 SDValue RISCVTargetLowering::lowerFPVECREDUCE(SDValue Op,
6083 SelectionDAG &DAG) const {
6084 SDLoc DL(Op);
6085 MVT VecEltVT = Op.getSimpleValueType();
6086
6087 unsigned RVVOpcode;
6088 SDValue VectorVal, ScalarVal;
6089 std::tie(RVVOpcode, VectorVal, ScalarVal) =
6090 getRVVFPReductionOpAndOperands(Op, DAG, VecEltVT);
6091 MVT VecVT = VectorVal.getSimpleValueType();
6092
6093 MVT ContainerVT = VecVT;
6094 if (VecVT.isFixedLengthVector()) {
6095 ContainerVT = getContainerForFixedLengthVector(VecVT);
6096 VectorVal = convertToScalableVector(ContainerVT, VectorVal, DAG, Subtarget);
6097 }
6098
6099 auto [Mask, VL] = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget);
6100 return lowerReductionSeq(RVVOpcode, Op.getSimpleValueType(), ScalarVal,
6101 VectorVal, Mask, VL, DL, DAG, Subtarget);
6102 }
6103
6104 static unsigned getRVVVPReductionOp(unsigned ISDOpcode) {
6105 switch (ISDOpcode) {
6106 default:
6107 llvm_unreachable("Unhandled reduction");
6108 case ISD::VP_REDUCE_ADD:
6109 return RISCVISD::VECREDUCE_ADD_VL;
6110 case ISD::VP_REDUCE_UMAX:
6111 return RISCVISD::VECREDUCE_UMAX_VL;
6112 case ISD::VP_REDUCE_SMAX:
6113 return RISCVISD::VECREDUCE_SMAX_VL;
6114 case ISD::VP_REDUCE_UMIN:
6115 return RISCVISD::VECREDUCE_UMIN_VL;
6116 case ISD::VP_REDUCE_SMIN:
6117 return RISCVISD::VECREDUCE_SMIN_VL;
6118 case ISD::VP_REDUCE_AND:
6119 return RISCVISD::VECREDUCE_AND_VL;
6120 case ISD::VP_REDUCE_OR:
6121 return RISCVISD::VECREDUCE_OR_VL;
6122 case ISD::VP_REDUCE_XOR:
6123 return RISCVISD::VECREDUCE_XOR_VL;
6124 case ISD::VP_REDUCE_FADD:
6125 return RISCVISD::VECREDUCE_FADD_VL;
6126 case ISD::VP_REDUCE_SEQ_FADD:
6127 return RISCVISD::VECREDUCE_SEQ_FADD_VL;
6128 case ISD::VP_REDUCE_FMAX:
6129 return RISCVISD::VECREDUCE_FMAX_VL;
6130 case ISD::VP_REDUCE_FMIN:
6131 return RISCVISD::VECREDUCE_FMIN_VL;
6132 }
6133 }
6134
6135 SDValue RISCVTargetLowering::lowerVPREDUCE(SDValue Op,
6136 SelectionDAG &DAG) const {
6137 SDLoc DL(Op);
6138 SDValue Vec = Op.getOperand(1);
6139 EVT VecEVT = Vec.getValueType();
6140
6141 // TODO: The type may need to be widened rather than split. Or widened before
6142 // it can be split.
6143 if (!isTypeLegal(VecEVT))
6144 return SDValue();
6145
6146 MVT VecVT = VecEVT.getSimpleVT();
6147 unsigned RVVOpcode = getRVVVPReductionOp(Op.getOpcode());
6148
6149 if (VecVT.isFixedLengthVector()) {
6150 auto ContainerVT = getContainerForFixedLengthVector(VecVT);
6151 Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
6152 }
6153
6154 SDValue VL = Op.getOperand(3);
6155 SDValue Mask = Op.getOperand(2);
6156 return lowerReductionSeq(RVVOpcode, Op.getSimpleValueType(), Op.getOperand(0),
6157 Vec, Mask, VL, DL, DAG, Subtarget);
6158 }
6159
6160 SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op,
6161 SelectionDAG &DAG) const {
6162 SDValue Vec = Op.getOperand(0);
6163 SDValue SubVec = Op.getOperand(1);
6164 MVT VecVT = Vec.getSimpleValueType();
6165 MVT SubVecVT = SubVec.getSimpleValueType();
6166
6167 SDLoc DL(Op);
6168 MVT XLenVT = Subtarget.getXLenVT();
6169 unsigned OrigIdx = Op.getConstantOperandVal(2);
6170 const RISCVRegisterInfo *TRI = Subtarget.getRegisterInfo();
6171
6172 // We don't have the ability to slide mask vectors up indexed by their i1
6173 // elements; the smallest we can do is i8. Often we are able to bitcast to
6174 // equivalent i8 vectors. Note that when inserting a fixed-length vector
6175 // into a scalable one, we might not necessarily have enough scalable
6176 // elements to safely divide by 8: nxv1i1 = insert nxv1i1, v4i1 is valid.
6177 if (SubVecVT.getVectorElementType() == MVT::i1 &&
6178 (OrigIdx != 0 || !Vec.isUndef())) {
6179 if (VecVT.getVectorMinNumElements() >= 8 &&
6180 SubVecVT.getVectorMinNumElements() >= 8) {
6181 assert(OrigIdx % 8 == 0 && "Invalid index");
6182 assert(VecVT.getVectorMinNumElements() % 8 == 0 &&
6183 SubVecVT.getVectorMinNumElements() % 8 == 0 &&
6184 "Unexpected mask vector lowering");
6185 OrigIdx /= 8;
6186 SubVecVT =
6187 MVT::getVectorVT(MVT::i8, SubVecVT.getVectorMinNumElements() / 8,
6188 SubVecVT.isScalableVector());
6189 VecVT = MVT::getVectorVT(MVT::i8, VecVT.getVectorMinNumElements() / 8,
6190 VecVT.isScalableVector());
6191 Vec = DAG.getBitcast(VecVT, Vec);
6192 SubVec = DAG.getBitcast(SubVecVT, SubVec);
6193 } else {
6194 // We can't slide this mask vector up indexed by its i1 elements.
6195 // This poses a problem when we wish to insert a scalable vector which
6196 // can't be re-expressed as a larger type. Just choose the slow path and
6197 // extend to a larger type, then truncate back down.
6198 MVT ExtVecVT = VecVT.changeVectorElementType(MVT::i8);
6199 MVT ExtSubVecVT = SubVecVT.changeVectorElementType(MVT::i8);
6200 Vec = DAG.getNode(ISD::ZERO_EXTEND, DL, ExtVecVT, Vec);
6201 SubVec = DAG.getNode(ISD::ZERO_EXTEND, DL, ExtSubVecVT, SubVec);
6202 Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ExtVecVT, Vec, SubVec,
6203 Op.getOperand(2));
6204 SDValue SplatZero = DAG.getConstant(0, DL, ExtVecVT);
6205 return DAG.getSetCC(DL, VecVT, Vec, SplatZero, ISD::SETNE);
6206 }
6207 }
6208
6209 // If the subvector is a fixed-length type, we cannot use subregister
6210 // manipulation to simplify the codegen; we don't know which register of a
6211 // LMUL group contains the specific subvector as we only know the minimum
6212 // register size. Therefore we must slide the vector group up the full
6213 // amount.
6214 if (SubVecVT.isFixedLengthVector()) {
6215 if (OrigIdx == 0 && Vec.isUndef() && !VecVT.isFixedLengthVector())
6216 return Op;
6217 MVT ContainerVT = VecVT;
6218 if (VecVT.isFixedLengthVector()) {
6219 ContainerVT = getContainerForFixedLengthVector(VecVT);
6220 Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
6221 }
6222 SubVec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ContainerVT,
6223 DAG.getUNDEF(ContainerVT), SubVec,
6224 DAG.getConstant(0, DL, XLenVT));
6225 if (OrigIdx == 0 && Vec.isUndef() && VecVT.isFixedLengthVector()) {
6226 SubVec = convertFromScalableVector(VecVT, SubVec, DAG, Subtarget);
6227 return DAG.getBitcast(Op.getValueType(), SubVec);
6228 }
6229 SDValue Mask =
6230 getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget).first;
6231 // Set the vector length to only the number of elements we care about. Note
6232 // that for slideup this includes the offset.
6233 SDValue VL =
6234 getVLOp(OrigIdx + SubVecVT.getVectorNumElements(), DL, DAG, Subtarget);
6235 SDValue SlideupAmt = DAG.getConstant(OrigIdx, DL, XLenVT);
6236
6237 // Use tail agnostic policy if OrigIdx is the last index of Vec.
6238 unsigned Policy = RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED;
6239 if (VecVT.isFixedLengthVector() &&
6240 OrigIdx + 1 == VecVT.getVectorNumElements())
6241 Policy = RISCVII::TAIL_AGNOSTIC;
6242 SDValue Slideup = getVSlideup(DAG, Subtarget, DL, ContainerVT, Vec, SubVec,
6243 SlideupAmt, Mask, VL, Policy);
6244 if (VecVT.isFixedLengthVector())
6245 Slideup = convertFromScalableVector(VecVT, Slideup, DAG, Subtarget);
6246 return DAG.getBitcast(Op.getValueType(), Slideup);
6247 }
6248
6249 unsigned SubRegIdx, RemIdx;
6250 std::tie(SubRegIdx, RemIdx) =
6251 RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
6252 VecVT, SubVecVT, OrigIdx, TRI);
6253
6254 RISCVII::VLMUL SubVecLMUL = RISCVTargetLowering::getLMUL(SubVecVT);
6255 bool IsSubVecPartReg = SubVecLMUL == RISCVII::VLMUL::LMUL_F2 ||
6256 SubVecLMUL == RISCVII::VLMUL::LMUL_F4 ||
6257 SubVecLMUL == RISCVII::VLMUL::LMUL_F8;
6258
6259 // 1. If the Idx has been completely eliminated and this subvector's size is
6260 // a vector register or a multiple thereof, or the surrounding elements are
6261 // undef, then this is a subvector insert which naturally aligns to a vector
6262 // register. These can easily be handled using subregister manipulation.
6263 // 2. If the subvector is smaller than a vector register, then the insertion
6264 // must preserve the undisturbed elements of the register. We do this by
6265 // lowering to an EXTRACT_SUBVECTOR grabbing the nearest LMUL=1 vector type
6266 // (which resolves to a subregister copy), performing a VSLIDEUP to place the
6267 // subvector within the vector register, and an INSERT_SUBVECTOR of that
6268 // LMUL=1 type back into the larger vector (resolving to another subregister
6269 // operation). See below for how our VSLIDEUP works. We go via a LMUL=1 type
6270 // to avoid allocating a large register group to hold our subvector.
6271 if (RemIdx == 0 && (!IsSubVecPartReg || Vec.isUndef()))
6272 return Op;
6273
6274 // VSLIDEUP works by leaving elements 0<i<OFFSET undisturbed, elements
6275 // OFFSET<=i<VL set to the "subvector" and vl<=i<VLMAX set to the tail policy
6276 // (in our case undisturbed). This means we can set up a subvector insertion
6277 // where OFFSET is the insertion offset, and the VL is the OFFSET plus the
6278 // size of the subvector.
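// Rough shape of the lowering for a non-register-aligned scalable insert
// (illustrative pseudocode, names assumed):
//   Inter  = extract_subvector Vec, AlignedIdx      ; LMUL=1 subregister copy
//   Inter  = vslideup Inter, SubVec, RemIdx         ; VL = RemIdx + len(SubVec)
//   Result = insert_subvector Vec, Inter, AlignedIdx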
6279 MVT InterSubVT = VecVT;
6280 SDValue AlignedExtract = Vec;
6281 unsigned AlignedIdx = OrigIdx - RemIdx;
6282 if (VecVT.bitsGT(getLMUL1VT(VecVT))) {
6283 InterSubVT = getLMUL1VT(VecVT);
6284 // Extract a subvector equal to the nearest full vector register type. This
6285 // should resolve to an EXTRACT_SUBREG instruction.
6286 AlignedExtract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InterSubVT, Vec,
6287 DAG.getConstant(AlignedIdx, DL, XLenVT));
6288 }
6289
6290 SDValue SlideupAmt = DAG.getConstant(RemIdx, DL, XLenVT);
6291 // For scalable vectors this must be further multiplied by vscale.
6292 SlideupAmt = DAG.getNode(ISD::VSCALE, DL, XLenVT, SlideupAmt);
6293
6294 auto [Mask, VL] = getDefaultScalableVLOps(VecVT, DL, DAG, Subtarget);
6295
6296 // Construct the vector length corresponding to RemIdx + length(SubVecVT).
6297 VL = DAG.getConstant(SubVecVT.getVectorMinNumElements(), DL, XLenVT);
6298 VL = DAG.getNode(ISD::VSCALE, DL, XLenVT, VL);
6299 VL = DAG.getNode(ISD::ADD, DL, XLenVT, SlideupAmt, VL);
6300
6301 SubVec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, InterSubVT,
6302 DAG.getUNDEF(InterSubVT), SubVec,
6303 DAG.getConstant(0, DL, XLenVT));
6304
6305 SDValue Slideup = getVSlideup(DAG, Subtarget, DL, InterSubVT, AlignedExtract,
6306 SubVec, SlideupAmt, Mask, VL);
6307
6308 // If required, insert this subvector back into the correct vector register.
6309 // This should resolve to an INSERT_SUBREG instruction.
6310 if (VecVT.bitsGT(InterSubVT))
6311 Slideup = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VecVT, Vec, Slideup,
6312 DAG.getConstant(AlignedIdx, DL, XLenVT));
6313
6314 // We might have bitcast from a mask type: cast back to the original type if
6315 // required.
6316 return DAG.getBitcast(Op.getSimpleValueType(), Slideup);
6317 }
6318
6319 SDValue RISCVTargetLowering::lowerEXTRACT_SUBVECTOR(SDValue Op,
6320 SelectionDAG &DAG) const {
6321 SDValue Vec = Op.getOperand(0);
6322 MVT SubVecVT = Op.getSimpleValueType();
6323 MVT VecVT = Vec.getSimpleValueType();
6324
6325 SDLoc DL(Op);
6326 MVT XLenVT = Subtarget.getXLenVT();
6327 unsigned OrigIdx = Op.getConstantOperandVal(1);
6328 const RISCVRegisterInfo *TRI = Subtarget.getRegisterInfo();
6329
6330 // We don't have the ability to slide mask vectors down indexed by their i1
6331 // elements; the smallest we can do is i8. Often we are able to bitcast to
6332 // equivalent i8 vectors. Note that when extracting a fixed-length vector
6333 // from a scalable one, we might not necessarily have enough scalable
6334 // elements to safely divide by 8: v8i1 = extract nxv1i1 is valid.
6335 if (SubVecVT.getVectorElementType() == MVT::i1 && OrigIdx != 0) {
6336 if (VecVT.getVectorMinNumElements() >= 8 &&
6337 SubVecVT.getVectorMinNumElements() >= 8) {
6338 assert(OrigIdx % 8 == 0 && "Invalid index");
6339 assert(VecVT.getVectorMinNumElements() % 8 == 0 &&
6340 SubVecVT.getVectorMinNumElements() % 8 == 0 &&
6341 "Unexpected mask vector lowering");
6342 OrigIdx /= 8;
6343 SubVecVT =
6344 MVT::getVectorVT(MVT::i8, SubVecVT.getVectorMinNumElements() / 8,
6345 SubVecVT.isScalableVector());
6346 VecVT = MVT::getVectorVT(MVT::i8, VecVT.getVectorMinNumElements() / 8,
6347 VecVT.isScalableVector());
6348 Vec = DAG.getBitcast(VecVT, Vec);
6349 } else {
6350 // We can't slide this mask vector down, indexed by its i1 elements.
6351 // This poses a problem when we wish to extract a scalable vector which
6352 // can't be re-expressed as a larger type. Just choose the slow path and
6353 // extend to a larger type, then truncate back down.
6354 // TODO: We could probably improve this when extracting a fixed-length vector
6355 // from a fixed-length one, where we can extract as i8 and shift the correct
6356 // element right to reach the desired subvector.
6357 MVT ExtVecVT = VecVT.changeVectorElementType(MVT::i8);
6358 MVT ExtSubVecVT = SubVecVT.changeVectorElementType(MVT::i8);
6359 Vec = DAG.getNode(ISD::ZERO_EXTEND, DL, ExtVecVT, Vec);
6360 Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtSubVecVT, Vec,
6361 Op.getOperand(1));
6362 SDValue SplatZero = DAG.getConstant(0, DL, ExtSubVecVT);
6363 return DAG.getSetCC(DL, SubVecVT, Vec, SplatZero, ISD::SETNE);
6364 }
6365 }
6366
6367 // If the subvector is a fixed-length type, we cannot use subregister
6368 // manipulation to simplify the codegen; we don't know which register of a
6369 // LMUL group contains the specific subvector as we only know the minimum
6370 // register size. Therefore we must slide the vector group down the full
6371 // amount.
6372 if (SubVecVT.isFixedLengthVector()) {
6373 // With an index of 0 this is a cast-like subvector, which can be performed
6374 // with subregister operations.
6375 if (OrigIdx == 0)
6376 return Op;
6377 MVT ContainerVT = VecVT;
6378 if (VecVT.isFixedLengthVector()) {
6379 ContainerVT = getContainerForFixedLengthVector(VecVT);
6380 Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
6381 }
6382 SDValue Mask =
6383 getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget).first;
6384 // Set the vector length to only the number of elements we care about. This
6385 // avoids sliding down elements we're going to discard straight away.
6386 SDValue VL = getVLOp(SubVecVT.getVectorNumElements(), DL, DAG, Subtarget);
6387 SDValue SlidedownAmt = DAG.getConstant(OrigIdx, DL, XLenVT);
6388 SDValue Slidedown =
6389 getVSlidedown(DAG, Subtarget, DL, ContainerVT,
6390 DAG.getUNDEF(ContainerVT), Vec, SlidedownAmt, Mask, VL);
6391 // Now we can use a cast-like subvector extract to get the result.
6392 Slidedown = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVecVT, Slidedown,
6393 DAG.getConstant(0, DL, XLenVT));
6394 return DAG.getBitcast(Op.getValueType(), Slidedown);
6395 }
6396
6397 unsigned SubRegIdx, RemIdx;
6398 std::tie(SubRegIdx, RemIdx) =
6399 RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
6400 VecVT, SubVecVT, OrigIdx, TRI);
6401
6402 // If the Idx has been completely eliminated then this is a subvector extract
6403 // which naturally aligns to a vector register. These can easily be handled
6404 // using subregister manipulation.
6405 if (RemIdx == 0)
6406 return Op;
6407
6408 // Else we must shift our vector register directly to extract the subvector.
6409 // Do this using VSLIDEDOWN.
6410
6411 // If the vector type is an LMUL-group type, extract a subvector equal to the
6412 // nearest full vector register type. This should resolve to an EXTRACT_SUBREG
6413 // instruction.
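// Rough shape of the lowering (illustrative pseudocode, names assumed):
//   Inter  = extract_subvector Vec, OrigIdx - RemIdx  ; LMUL=1 subregister copy
//   Inter  = vslidedown Inter, RemIdx                 ; RemIdx scaled by vscale
//   Result = extract_subvector Inter, 0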
6414 MVT InterSubVT = VecVT;
6415 if (VecVT.bitsGT(getLMUL1VT(VecVT))) {
6416 InterSubVT = getLMUL1VT(VecVT);
6417 Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InterSubVT, Vec,
6418 DAG.getConstant(OrigIdx - RemIdx, DL, XLenVT));
6419 }
6420
6421 // Slide this vector register down by the desired number of elements in order
6422 // to place the desired subvector starting at element 0.
6423 SDValue SlidedownAmt = DAG.getConstant(RemIdx, DL, XLenVT);
6424 // For scalable vectors this must be further multiplied by vscale.
6425 SlidedownAmt = DAG.getNode(ISD::VSCALE, DL, XLenVT, SlidedownAmt);
6426
6427 auto [Mask, VL] = getDefaultScalableVLOps(InterSubVT, DL, DAG, Subtarget);
6428 SDValue Slidedown =
6429 getVSlidedown(DAG, Subtarget, DL, InterSubVT, DAG.getUNDEF(InterSubVT),
6430 Vec, SlidedownAmt, Mask, VL);
6431
6432 // Now the vector is in the right position, extract our final subvector. This
6433 // should resolve to a COPY.
6434 Slidedown = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVecVT, Slidedown,
6435 DAG.getConstant(0, DL, XLenVT));
6436
6437 // We might have bitcast from a mask type: cast back to the original type if
6438 // required.
6439 return DAG.getBitcast(Op.getSimpleValueType(), Slidedown);
6440 }
6441
6442 // Lower step_vector to the vid instruction. Any non-identity step value must
6443 // be accounted for by manual expansion.
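// For example (illustrative only), a step_vector with step 4 becomes roughly:
//   vid.v   v8
//   vsll.vi v8, v8, 2
// while a non-power-of-two step multiplies the vid result by a splat of the
// step instead.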
6444 SDValue RISCVTargetLowering::lowerSTEP_VECTOR(SDValue Op,
6445 SelectionDAG &DAG) const {
6446 SDLoc DL(Op);
6447 MVT VT = Op.getSimpleValueType();
6448 assert(VT.isScalableVector() && "Expected scalable vector");
6449 MVT XLenVT = Subtarget.getXLenVT();
6450 auto [Mask, VL] = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
6451 SDValue StepVec = DAG.getNode(RISCVISD::VID_VL, DL, VT, Mask, VL);
6452 uint64_t StepValImm = Op.getConstantOperandVal(0);
6453 if (StepValImm != 1) {
6454 if (isPowerOf2_64(StepValImm)) {
6455 SDValue StepVal =
6456 DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, DAG.getUNDEF(VT),
6457 DAG.getConstant(Log2_64(StepValImm), DL, XLenVT), VL);
6458 StepVec = DAG.getNode(ISD::SHL, DL, VT, StepVec, StepVal);
6459 } else {
6460 SDValue StepVal = lowerScalarSplat(
6461 SDValue(), DAG.getConstant(StepValImm, DL, VT.getVectorElementType()),
6462 VL, VT, DL, DAG, Subtarget);
6463 StepVec = DAG.getNode(ISD::MUL, DL, VT, StepVec, StepVal);
6464 }
6465 }
6466 return StepVec;
6467 }
6468
6469 // Implement vector_reverse using vrgather.vv with indices determined by
6470 // subtracting the id of each element from (VLMAX-1). This will convert
6471 // the indices like so:
6472 // (0, 1,..., VLMAX-2, VLMAX-1) -> (VLMAX-1, VLMAX-2,..., 1, 0).
6473 // TODO: This code assumes VLMAX <= 65536 for LMUL=8 SEW=16.
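// Illustrative sketch (assumed registers) of the common non-mask case:
//   vid.v       v12
//   vrsub.vx    v12, v12, a0    ; a0 = VLMAX-1, giving reversed indices
//   vrgather.vv v8, v16, v12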
6474 SDValue RISCVTargetLowering::lowerVECTOR_REVERSE(SDValue Op,
6475 SelectionDAG &DAG) const {
6476 SDLoc DL(Op);
6477 MVT VecVT = Op.getSimpleValueType();
6478 if (VecVT.getVectorElementType() == MVT::i1) {
6479 MVT WidenVT = MVT::getVectorVT(MVT::i8, VecVT.getVectorElementCount());
6480 SDValue Op1 = DAG.getNode(ISD::ZERO_EXTEND, DL, WidenVT, Op.getOperand(0));
6481 SDValue Op2 = DAG.getNode(ISD::VECTOR_REVERSE, DL, WidenVT, Op1);
6482 return DAG.getNode(ISD::TRUNCATE, DL, VecVT, Op2);
6483 }
6484 unsigned EltSize = VecVT.getScalarSizeInBits();
6485 unsigned MinSize = VecVT.getSizeInBits().getKnownMinValue();
6486 unsigned VectorBitsMax = Subtarget.getRealMaxVLen();
6487 unsigned MaxVLMAX =
6488 RISCVTargetLowering::computeVLMAX(VectorBitsMax, EltSize, MinSize);
6489
6490 unsigned GatherOpc = RISCVISD::VRGATHER_VV_VL;
6491 MVT IntVT = VecVT.changeVectorElementTypeToInteger();
6492
6493 // If this is SEW=8 and VLMAX is potentially more than 256, we need
6494 // to use vrgatherei16.vv.
6495 // TODO: It's also possible to use vrgatherei16.vv for other types to
6496 // decrease register width for the index calculation.
6497 if (MaxVLMAX > 256 && EltSize == 8) {
6498 // If this is LMUL=8, we have to split before we can use vrgatherei16.vv.
6499 // Reverse each half, then reassemble them in reverse order.
6500 // NOTE: It's also possible that after splitting, VLMAX no longer
6501 // requires vrgatherei16.vv.
6502 if (MinSize == (8 * RISCV::RVVBitsPerBlock)) {
6503 auto [Lo, Hi] = DAG.SplitVectorOperand(Op.getNode(), 0);
6504 auto [LoVT, HiVT] = DAG.GetSplitDestVTs(VecVT);
6505 Lo = DAG.getNode(ISD::VECTOR_REVERSE, DL, LoVT, Lo);
6506 Hi = DAG.getNode(ISD::VECTOR_REVERSE, DL, HiVT, Hi);
6507 // Reassemble the low and high pieces reversed.
6508 // FIXME: This is a CONCAT_VECTORS.
6509 SDValue Res =
6510 DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VecVT, DAG.getUNDEF(VecVT), Hi,
6511 DAG.getIntPtrConstant(0, DL));
6512 return DAG.getNode(
6513 ISD::INSERT_SUBVECTOR, DL, VecVT, Res, Lo,
6514 DAG.getIntPtrConstant(LoVT.getVectorMinNumElements(), DL));
6515 }
6516
6517 // Just promote the int type to i16 which will double the LMUL.
6518 IntVT = MVT::getVectorVT(MVT::i16, VecVT.getVectorElementCount());
6519 GatherOpc = RISCVISD::VRGATHEREI16_VV_VL;
6520 }
6521
6522 MVT XLenVT = Subtarget.getXLenVT();
6523 auto [Mask, VL] = getDefaultScalableVLOps(VecVT, DL, DAG, Subtarget);
6524
6525 // Calculate VLMAX-1 for the desired SEW.
6526 unsigned MinElts = VecVT.getVectorMinNumElements();
6527 SDValue VLMax = DAG.getNode(ISD::VSCALE, DL, XLenVT,
6528 getVLOp(MinElts, DL, DAG, Subtarget));
6529 SDValue VLMinus1 =
6530 DAG.getNode(ISD::SUB, DL, XLenVT, VLMax, DAG.getConstant(1, DL, XLenVT));
6531
6532 // Splat VLMAX-1 taking care to handle SEW==64 on RV32.
6533 bool IsRV32E64 =
6534 !Subtarget.is64Bit() && IntVT.getVectorElementType() == MVT::i64;
6535 SDValue SplatVL;
6536 if (!IsRV32E64)
6537 SplatVL = DAG.getSplatVector(IntVT, DL, VLMinus1);
6538 else
6539 SplatVL = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, IntVT, DAG.getUNDEF(IntVT),
6540 VLMinus1, DAG.getRegister(RISCV::X0, XLenVT));
6541
6542 SDValue VID = DAG.getNode(RISCVISD::VID_VL, DL, IntVT, Mask, VL);
6543 SDValue Indices = DAG.getNode(RISCVISD::SUB_VL, DL, IntVT, SplatVL, VID,
6544 DAG.getUNDEF(IntVT), Mask, VL);
6545
6546 return DAG.getNode(GatherOpc, DL, VecVT, Op.getOperand(0), Indices,
6547 DAG.getUNDEF(VecVT), Mask, VL);
6548 }
6549
6550 SDValue RISCVTargetLowering::lowerVECTOR_SPLICE(SDValue Op,
6551 SelectionDAG &DAG) const {
6552 SDLoc DL(Op);
6553 SDValue V1 = Op.getOperand(0);
6554 SDValue V2 = Op.getOperand(1);
6555 MVT XLenVT = Subtarget.getXLenVT();
6556 MVT VecVT = Op.getSimpleValueType();
6557
6558 unsigned MinElts = VecVT.getVectorMinNumElements();
6559 SDValue VLMax = DAG.getNode(ISD::VSCALE, DL, XLenVT,
6560 getVLOp(MinElts, DL, DAG, Subtarget));
6561
6562 int64_t ImmValue = cast<ConstantSDNode>(Op.getOperand(2))->getSExtValue();
6563 SDValue DownOffset, UpOffset;
6564 if (ImmValue >= 0) {
6565 // The operand is a TargetConstant, we need to rebuild it as a regular
6566 // constant.
6567 DownOffset = DAG.getConstant(ImmValue, DL, XLenVT);
6568 UpOffset = DAG.getNode(ISD::SUB, DL, XLenVT, VLMax, DownOffset);
6569 } else {
6570 // The operand is a TargetConstant, we need to rebuild it as a regular
6571 // constant rather than negating the original operand.
6572 UpOffset = DAG.getConstant(-ImmValue, DL, XLenVT);
6573 DownOffset = DAG.getNode(ISD::SUB, DL, XLenVT, VLMax, UpOffset);
6574 }
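// The splice is an unmasked slidedown of V1 keeping UpOffset elements,
// followed by a tail-agnostic slideup of V2; roughly (illustrative only):
//   vslidedown.vx v8, v16, <DownOffset>    ; vl = UpOffset
//   vslideup.vx   v8, v20, <UpOffset>      ; vl = VLMAX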
6575
6576 SDValue TrueMask = getAllOnesMask(VecVT, VLMax, DL, DAG);
6577
6578 SDValue SlideDown =
6579 getVSlidedown(DAG, Subtarget, DL, VecVT, DAG.getUNDEF(VecVT), V1,
6580 DownOffset, TrueMask, UpOffset);
6581 return getVSlideup(DAG, Subtarget, DL, VecVT, SlideDown, V2, UpOffset,
6582 TrueMask, DAG.getRegister(RISCV::X0, XLenVT),
6583 RISCVII::TAIL_AGNOSTIC);
6584 }
6585
6586 SDValue
6587 RISCVTargetLowering::lowerFixedLengthVectorLoadToRVV(SDValue Op,
6588 SelectionDAG &DAG) const {
6589 SDLoc DL(Op);
6590 auto *Load = cast<LoadSDNode>(Op);
6591
6592 assert(allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
6593 Load->getMemoryVT(),
6594 *Load->getMemOperand()) &&
6595 "Expecting a correctly-aligned load");
6596
6597 MVT VT = Op.getSimpleValueType();
6598 MVT XLenVT = Subtarget.getXLenVT();
6599 MVT ContainerVT = getContainerForFixedLengthVector(VT);
6600
6601 SDValue VL = getVLOp(VT.getVectorNumElements(), DL, DAG, Subtarget);
6602
6603 bool IsMaskOp = VT.getVectorElementType() == MVT::i1;
6604 SDValue IntID = DAG.getTargetConstant(
6605 IsMaskOp ? Intrinsic::riscv_vlm : Intrinsic::riscv_vle, DL, XLenVT);
6606 SmallVector<SDValue, 4> Ops{Load->getChain(), IntID};
6607 if (!IsMaskOp)
6608 Ops.push_back(DAG.getUNDEF(ContainerVT));
6609 Ops.push_back(Load->getBasePtr());
6610 Ops.push_back(VL);
6611 SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
6612 SDValue NewLoad =
6613 DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops,
6614 Load->getMemoryVT(), Load->getMemOperand());
6615
6616 SDValue Result = convertFromScalableVector(VT, NewLoad, DAG, Subtarget);
6617 return DAG.getMergeValues({Result, NewLoad.getValue(1)}, DL);
6618 }
6619
6620 SDValue
6621 RISCVTargetLowering::lowerFixedLengthVectorStoreToRVV(SDValue Op,
6622 SelectionDAG &DAG) const {
6623 SDLoc DL(Op);
6624 auto *Store = cast<StoreSDNode>(Op);
6625
6626 assert(allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
6627 Store->getMemoryVT(),
6628 *Store->getMemOperand()) &&
6629 "Expecting a correctly-aligned store");
6630
6631 SDValue StoreVal = Store->getValue();
6632 MVT VT = StoreVal.getSimpleValueType();
6633 MVT XLenVT = Subtarget.getXLenVT();
6634
6635 // If the size is less than a byte, we need to pad with zeros to make a byte.
6636 if (VT.getVectorElementType() == MVT::i1 && VT.getVectorNumElements() < 8) {
6637 VT = MVT::v8i1;
6638 StoreVal = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT,
6639 DAG.getConstant(0, DL, VT), StoreVal,
6640 DAG.getIntPtrConstant(0, DL));
6641 }
6642
6643 MVT ContainerVT = getContainerForFixedLengthVector(VT);
6644
6645 SDValue VL = getVLOp(VT.getVectorNumElements(), DL, DAG, Subtarget);
6646
6647 SDValue NewValue =
6648 convertToScalableVector(ContainerVT, StoreVal, DAG, Subtarget);
6649
6650 bool IsMaskOp = VT.getVectorElementType() == MVT::i1;
6651 SDValue IntID = DAG.getTargetConstant(
6652 IsMaskOp ? Intrinsic::riscv_vsm : Intrinsic::riscv_vse, DL, XLenVT);
6653 return DAG.getMemIntrinsicNode(
6654 ISD::INTRINSIC_VOID, DL, DAG.getVTList(MVT::Other),
6655 {Store->getChain(), IntID, NewValue, Store->getBasePtr(), VL},
6656 Store->getMemoryVT(), Store->getMemOperand());
6657 }
6658
6659 SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
6660 SelectionDAG &DAG) const {
6661 SDLoc DL(Op);
6662 MVT VT = Op.getSimpleValueType();
6663
6664 const auto *MemSD = cast<MemSDNode>(Op);
6665 EVT MemVT = MemSD->getMemoryVT();
6666 MachineMemOperand *MMO = MemSD->getMemOperand();
6667 SDValue Chain = MemSD->getChain();
6668 SDValue BasePtr = MemSD->getBasePtr();
6669
6670 SDValue Mask, PassThru, VL;
6671 if (const auto *VPLoad = dyn_cast<VPLoadSDNode>(Op)) {
6672 Mask = VPLoad->getMask();
6673 PassThru = DAG.getUNDEF(VT);
6674 VL = VPLoad->getVectorLength();
6675 } else {
6676 const auto *MLoad = cast<MaskedLoadSDNode>(Op);
6677 Mask = MLoad->getMask();
6678 PassThru = MLoad->getPassThru();
6679 }
6680
6681 bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
6682
6683 MVT XLenVT = Subtarget.getXLenVT();
6684
6685 MVT ContainerVT = VT;
6686 if (VT.isFixedLengthVector()) {
6687 ContainerVT = getContainerForFixedLengthVector(VT);
6688 PassThru = convertToScalableVector(ContainerVT, PassThru, DAG, Subtarget);
6689 if (!IsUnmasked) {
6690 MVT MaskVT = getMaskTypeFor(ContainerVT);
6691 Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
6692 }
6693 }
6694
6695 if (!VL)
6696 VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
6697
6698 unsigned IntID =
6699 IsUnmasked ? Intrinsic::riscv_vle : Intrinsic::riscv_vle_mask;
6700 SmallVector<SDValue, 8> Ops{Chain, DAG.getTargetConstant(IntID, DL, XLenVT)};
6701 if (IsUnmasked)
6702 Ops.push_back(DAG.getUNDEF(ContainerVT));
6703 else
6704 Ops.push_back(PassThru);
6705 Ops.push_back(BasePtr);
6706 if (!IsUnmasked)
6707 Ops.push_back(Mask);
6708 Ops.push_back(VL);
6709 if (!IsUnmasked)
6710 Ops.push_back(DAG.getTargetConstant(RISCVII::TAIL_AGNOSTIC, DL, XLenVT));
6711
6712 SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
6713
6714 SDValue Result =
6715 DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops, MemVT, MMO);
6716 Chain = Result.getValue(1);
6717
6718 if (VT.isFixedLengthVector())
6719 Result = convertFromScalableVector(VT, Result, DAG, Subtarget);
6720
6721 return DAG.getMergeValues({Result, Chain}, DL);
6722 }
6723
6724 SDValue RISCVTargetLowering::lowerMaskedStore(SDValue Op,
6725 SelectionDAG &DAG) const {
6726 SDLoc DL(Op);
6727
6728 const auto *MemSD = cast<MemSDNode>(Op);
6729 EVT MemVT = MemSD->getMemoryVT();
6730 MachineMemOperand *MMO = MemSD->getMemOperand();
6731 SDValue Chain = MemSD->getChain();
6732 SDValue BasePtr = MemSD->getBasePtr();
6733 SDValue Val, Mask, VL;
6734
6735 if (const auto *VPStore = dyn_cast<VPStoreSDNode>(Op)) {
6736 Val = VPStore->getValue();
6737 Mask = VPStore->getMask();
6738 VL = VPStore->getVectorLength();
6739 } else {
6740 const auto *MStore = cast<MaskedStoreSDNode>(Op);
6741 Val = MStore->getValue();
6742 Mask = MStore->getMask();
6743 }
6744
6745 bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
6746
6747 MVT VT = Val.getSimpleValueType();
6748 MVT XLenVT = Subtarget.getXLenVT();
6749
6750 MVT ContainerVT = VT;
6751 if (VT.isFixedLengthVector()) {
6752 ContainerVT = getContainerForFixedLengthVector(VT);
6753
6754 Val = convertToScalableVector(ContainerVT, Val, DAG, Subtarget);
6755 if (!IsUnmasked) {
6756 MVT MaskVT = getMaskTypeFor(ContainerVT);
6757 Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
6758 }
6759 }
6760
6761 if (!VL)
6762 VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
6763
6764 unsigned IntID =
6765 IsUnmasked ? Intrinsic::riscv_vse : Intrinsic::riscv_vse_mask;
6766 SmallVector<SDValue, 8> Ops{Chain, DAG.getTargetConstant(IntID, DL, XLenVT)};
6767 Ops.push_back(Val);
6768 Ops.push_back(BasePtr);
6769 if (!IsUnmasked)
6770 Ops.push_back(Mask);
6771 Ops.push_back(VL);
6772
6773 return DAG.getMemIntrinsicNode(ISD::INTRINSIC_VOID, DL,
6774 DAG.getVTList(MVT::Other), Ops, MemVT, MMO);
6775 }
6776
6777 SDValue
6778 RISCVTargetLowering::lowerFixedLengthVectorSetccToRVV(SDValue Op,
6779 SelectionDAG &DAG) const {
6780 MVT InVT = Op.getOperand(0).getSimpleValueType();
6781 MVT ContainerVT = getContainerForFixedLengthVector(InVT);
6782
6783 MVT VT = Op.getSimpleValueType();
6784
6785 SDValue Op1 =
6786 convertToScalableVector(ContainerVT, Op.getOperand(0), DAG, Subtarget);
6787 SDValue Op2 =
6788 convertToScalableVector(ContainerVT, Op.getOperand(1), DAG, Subtarget);
6789
6790 SDLoc DL(Op);
6791 auto [Mask, VL] = getDefaultVLOps(VT.getVectorNumElements(), ContainerVT, DL,
6792 DAG, Subtarget);
6793 MVT MaskVT = getMaskTypeFor(ContainerVT);
6794
6795 SDValue Cmp =
6796 DAG.getNode(RISCVISD::SETCC_VL, DL, MaskVT,
6797 {Op1, Op2, Op.getOperand(2), DAG.getUNDEF(MaskVT), Mask, VL});
6798
6799 return convertFromScalableVector(VT, Cmp, DAG, Subtarget);
6800 }
6801
6802 SDValue RISCVTargetLowering::lowerFixedLengthVectorLogicOpToRVV(
6803 SDValue Op, SelectionDAG &DAG, unsigned MaskOpc, unsigned VecOpc) const {
6804 MVT VT = Op.getSimpleValueType();
6805
6806 if (VT.getVectorElementType() == MVT::i1)
6807 return lowerToScalableOp(Op, DAG, MaskOpc, /*HasMergeOp*/ false,
6808 /*HasMask*/ false);
6809
6810 return lowerToScalableOp(Op, DAG, VecOpc, /*HasMergeOp*/ true);
6811 }
6812
6813 SDValue
6814 RISCVTargetLowering::lowerFixedLengthVectorShiftToRVV(SDValue Op,
6815 SelectionDAG &DAG) const {
6816 unsigned Opc;
6817 switch (Op.getOpcode()) {
6818 default: llvm_unreachable("Unexpected opcode!");
6819 case ISD::SHL: Opc = RISCVISD::SHL_VL; break;
6820 case ISD::SRA: Opc = RISCVISD::SRA_VL; break;
6821 case ISD::SRL: Opc = RISCVISD::SRL_VL; break;
6822 }
6823
6824 return lowerToScalableOp(Op, DAG, Opc, /*HasMergeOp*/ true);
6825 }
6826
6827 // Lower vector ABS to smax(X, sub(0, X)).
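// For example (illustrative only, assumed registers), for an i32 vector:
//   vrsub.vi v12, v8, 0
//   vmax.vv  v8, v8, v12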
6828 SDValue RISCVTargetLowering::lowerABS(SDValue Op, SelectionDAG &DAG) const {
6829 SDLoc DL(Op);
6830 MVT VT = Op.getSimpleValueType();
6831 SDValue X = Op.getOperand(0);
6832
6833 assert((Op.getOpcode() == ISD::VP_ABS || VT.isFixedLengthVector()) &&
6834 "Unexpected type for ISD::ABS");
6835
6836 MVT ContainerVT = VT;
6837 if (VT.isFixedLengthVector()) {
6838 ContainerVT = getContainerForFixedLengthVector(VT);
6839 X = convertToScalableVector(ContainerVT, X, DAG, Subtarget);
6840 }
6841
6842 SDValue Mask, VL;
6843 if (Op->getOpcode() == ISD::VP_ABS) {
6844 Mask = Op->getOperand(1);
6845 VL = Op->getOperand(2);
6846 } else
6847 std::tie(Mask, VL) = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
6848
6849 SDValue SplatZero = DAG.getNode(
6850 RISCVISD::VMV_V_X_VL, DL, ContainerVT, DAG.getUNDEF(ContainerVT),
6851 DAG.getConstant(0, DL, Subtarget.getXLenVT()), VL);
6852 SDValue NegX = DAG.getNode(RISCVISD::SUB_VL, DL, ContainerVT, SplatZero, X,
6853 DAG.getUNDEF(ContainerVT), Mask, VL);
6854 SDValue Max = DAG.getNode(RISCVISD::SMAX_VL, DL, ContainerVT, X, NegX,
6855 DAG.getUNDEF(ContainerVT), Mask, VL);
6856
6857 if (VT.isFixedLengthVector())
6858 Max = convertFromScalableVector(VT, Max, DAG, Subtarget);
6859 return Max;
6860 }
6861
6862 SDValue RISCVTargetLowering::lowerFixedLengthVectorFCOPYSIGNToRVV(
6863 SDValue Op, SelectionDAG &DAG) const {
6864 SDLoc DL(Op);
6865 MVT VT = Op.getSimpleValueType();
6866 SDValue Mag = Op.getOperand(0);
6867 SDValue Sign = Op.getOperand(1);
6868 assert(Mag.getValueType() == Sign.getValueType() &&
6869 "Can only handle COPYSIGN with matching types.");
6870
6871 MVT ContainerVT = getContainerForFixedLengthVector(VT);
6872 Mag = convertToScalableVector(ContainerVT, Mag, DAG, Subtarget);
6873 Sign = convertToScalableVector(ContainerVT, Sign, DAG, Subtarget);
6874
6875 auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
6876
6877 SDValue CopySign = DAG.getNode(RISCVISD::FCOPYSIGN_VL, DL, ContainerVT, Mag,
6878 Sign, DAG.getUNDEF(ContainerVT), Mask, VL);
6879
6880 return convertFromScalableVector(VT, CopySign, DAG, Subtarget);
6881 }
6882
6883 SDValue RISCVTargetLowering::lowerFixedLengthVectorSelectToRVV(
6884 SDValue Op, SelectionDAG &DAG) const {
6885 MVT VT = Op.getSimpleValueType();
6886 MVT ContainerVT = getContainerForFixedLengthVector(VT);
6887
6888 MVT I1ContainerVT =
6889 MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
6890
6891 SDValue CC =
6892 convertToScalableVector(I1ContainerVT, Op.getOperand(0), DAG, Subtarget);
6893 SDValue Op1 =
6894 convertToScalableVector(ContainerVT, Op.getOperand(1), DAG, Subtarget);
6895 SDValue Op2 =
6896 convertToScalableVector(ContainerVT, Op.getOperand(2), DAG, Subtarget);
6897
6898 SDLoc DL(Op);
6899 SDValue VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
6900
6901 SDValue Select =
6902 DAG.getNode(RISCVISD::VSELECT_VL, DL, ContainerVT, CC, Op1, Op2, VL);
6903
6904 return convertFromScalableVector(VT, Select, DAG, Subtarget);
6905 }
6906
6907 SDValue RISCVTargetLowering::lowerToScalableOp(SDValue Op, SelectionDAG &DAG,
6908 unsigned NewOpc, bool HasMergeOp,
6909 bool HasMask) const {
6910 MVT VT = Op.getSimpleValueType();
6911 MVT ContainerVT = getContainerForFixedLengthVector(VT);
6912
6913 // Create list of operands by converting existing ones to scalable types.
6914 SmallVector<SDValue, 6> Ops;
6915 for (const SDValue &V : Op->op_values()) {
6916 assert(!isa<VTSDNode>(V) && "Unexpected VTSDNode node!");
6917
6918 // Pass through non-vector operands.
6919 if (!V.getValueType().isVector()) {
6920 Ops.push_back(V);
6921 continue;
6922 }
6923
6924 // "cast" fixed length vector to a scalable vector.
6925 assert(useRVVForFixedLengthVectorVT(V.getSimpleValueType()) &&
6926 "Only fixed length vectors are supported!");
6927 Ops.push_back(convertToScalableVector(ContainerVT, V, DAG, Subtarget));
6928 }
6929
6930 SDLoc DL(Op);
6931 auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
6932 if (HasMergeOp)
6933 Ops.push_back(DAG.getUNDEF(ContainerVT));
6934 if (HasMask)
6935 Ops.push_back(Mask);
6936 Ops.push_back(VL);
6937
6938 SDValue ScalableRes =
6939 DAG.getNode(NewOpc, DL, ContainerVT, Ops, Op->getFlags());
6940 return convertFromScalableVector(VT, ScalableRes, DAG, Subtarget);
6941 }
6942
6943 // Lower a VP_* ISD node to the corresponding RISCVISD::*_VL node:
6944 // * Operands of each node are assumed to be in the same order.
6945 // * The EVL operand is promoted from i32 to i64 on RV64.
6946 // * Fixed-length vectors are converted to their scalable-vector container
6947 // types.
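// Illustrative sketch (not tied to a particular test): a fixed-length vp.add
// on two v4i32 operands with mask %m and EVL %evl becomes RISCVISD::ADD_VL on
// the scalable container type, with the fixed-length operands inserted into
// the container on entry and the result extracted back to v4i32 on exit; %evl
// is carried through as the VL operand.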
6948 SDValue RISCVTargetLowering::lowerVPOp(SDValue Op, SelectionDAG &DAG,
6949 unsigned RISCVISDOpc,
6950 bool HasMergeOp) const {
6951 SDLoc DL(Op);
6952 MVT VT = Op.getSimpleValueType();
6953 SmallVector<SDValue, 4> Ops;
6954
6955 MVT ContainerVT = VT;
6956 if (VT.isFixedLengthVector())
6957 ContainerVT = getContainerForFixedLengthVector(VT);
6958
6959 for (const auto &OpIdx : enumerate(Op->ops())) {
6960 SDValue V = OpIdx.value();
6961 assert(!isa<VTSDNode>(V) && "Unexpected VTSDNode node!");
6962 // Add dummy merge value before the mask.
6963 if (HasMergeOp && *ISD::getVPMaskIdx(Op.getOpcode()) == OpIdx.index())
6964 Ops.push_back(DAG.getUNDEF(ContainerVT));
6965 // Pass through operands which aren't fixed-length vectors.
6966 if (!V.getValueType().isFixedLengthVector()) {
6967 Ops.push_back(V);
6968 continue;
6969 }
6970 // "cast" fixed length vector to a scalable vector.
6971 MVT OpVT = V.getSimpleValueType();
6972 MVT ContainerVT = getContainerForFixedLengthVector(OpVT);
6973 assert(useRVVForFixedLengthVectorVT(OpVT) &&
6974 "Only fixed length vectors are supported!");
6975 Ops.push_back(convertToScalableVector(ContainerVT, V, DAG, Subtarget));
6976 }
6977
6978 if (!VT.isFixedLengthVector())
6979 return DAG.getNode(RISCVISDOpc, DL, VT, Ops, Op->getFlags());
6980
6981 SDValue VPOp = DAG.getNode(RISCVISDOpc, DL, ContainerVT, Ops, Op->getFlags());
6982
6983 return convertFromScalableVector(VT, VPOp, DAG, Subtarget);
6984 }
6985
6986 SDValue RISCVTargetLowering::lowerVPExtMaskOp(SDValue Op,
6987 SelectionDAG &DAG) const {
6988 SDLoc DL(Op);
6989 MVT VT = Op.getSimpleValueType();
6990
6991 SDValue Src = Op.getOperand(0);
6992 // NOTE: Mask is dropped.
6993 SDValue VL = Op.getOperand(2);
6994
6995 MVT ContainerVT = VT;
6996 if (VT.isFixedLengthVector()) {
6997 ContainerVT = getContainerForFixedLengthVector(VT);
6998 MVT SrcVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
6999 Src = convertToScalableVector(SrcVT, Src, DAG, Subtarget);
7000 }
7001
7002 MVT XLenVT = Subtarget.getXLenVT();
7003 SDValue Zero = DAG.getConstant(0, DL, XLenVT);
7004 SDValue ZeroSplat = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT,
7005 DAG.getUNDEF(ContainerVT), Zero, VL);
7006
7007 SDValue SplatValue = DAG.getConstant(
7008 Op.getOpcode() == ISD::VP_ZERO_EXTEND ? 1 : -1, DL, XLenVT);
7009 SDValue Splat = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT,
7010 DAG.getUNDEF(ContainerVT), SplatValue, VL);
7011
7012 SDValue Result = DAG.getNode(RISCVISD::VSELECT_VL, DL, ContainerVT, Src,
7013 Splat, ZeroSplat, VL);
7014 if (!VT.isFixedLengthVector())
7015 return Result;
7016 return convertFromScalableVector(VT, Result, DAG, Subtarget);
7017 }
7018
7019 SDValue RISCVTargetLowering::lowerVPSetCCMaskOp(SDValue Op,
7020 SelectionDAG &DAG) const {
7021 SDLoc DL(Op);
7022 MVT VT = Op.getSimpleValueType();
7023
7024 SDValue Op1 = Op.getOperand(0);
7025 SDValue Op2 = Op.getOperand(1);
7026 ISD::CondCode Condition = cast<CondCodeSDNode>(Op.getOperand(2))->get();
7027 // NOTE: Mask is dropped.
7028 SDValue VL = Op.getOperand(4);
7029
7030 MVT ContainerVT = VT;
7031 if (VT.isFixedLengthVector()) {
7032 ContainerVT = getContainerForFixedLengthVector(VT);
7033 Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
7034 Op2 = convertToScalableVector(ContainerVT, Op2, DAG, Subtarget);
7035 }
7036
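// Note on the identities used below: for i1 vectors, true is all-ones and
// false is zero, so e.g. X <u Y (over {0,1}) and X >s Y (over {0,-1}) both
// hold exactly when X is false and Y is true, i.e. ~X & Y.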
7037 SDValue Result;
7038 SDValue AllOneMask = DAG.getNode(RISCVISD::VMSET_VL, DL, ContainerVT, VL);
7039
7040 switch (Condition) {
7041 default:
7042 break;
7043 // X != Y --> (X^Y)
7044 case ISD::SETNE:
7045 Result = DAG.getNode(RISCVISD::VMXOR_VL, DL, ContainerVT, Op1, Op2, VL);
7046 break;
7047 // X == Y --> ~(X^Y)
7048 case ISD::SETEQ: {
7049 SDValue Temp =
7050 DAG.getNode(RISCVISD::VMXOR_VL, DL, ContainerVT, Op1, Op2, VL);
7051 Result =
7052 DAG.getNode(RISCVISD::VMXOR_VL, DL, ContainerVT, Temp, AllOneMask, VL);
7053 break;
7054 }
7055 // X >s Y --> X == 0 & Y == 1 --> ~X & Y
7056 // X <u Y --> X == 0 & Y == 1 --> ~X & Y
7057 case ISD::SETGT:
7058 case ISD::SETULT: {
7059 SDValue Temp =
7060 DAG.getNode(RISCVISD::VMXOR_VL, DL, ContainerVT, Op1, AllOneMask, VL);
7061 Result = DAG.getNode(RISCVISD::VMAND_VL, DL, ContainerVT, Temp, Op2, VL);
7062 break;
7063 }
7064 // X <s Y --> X == 1 & Y == 0 --> ~Y & X
7065 // X >u Y --> X == 1 & Y == 0 --> ~Y & X
7066 case ISD::SETLT:
7067 case ISD::SETUGT: {
7068 SDValue Temp =
7069 DAG.getNode(RISCVISD::VMXOR_VL, DL, ContainerVT, Op2, AllOneMask, VL);
7070 Result = DAG.getNode(RISCVISD::VMAND_VL, DL, ContainerVT, Op1, Temp, VL);
7071 break;
7072 }
7073 // X >=s Y --> X == 0 | Y == 1 --> ~X | Y
7074 // X <=u Y --> X == 0 | Y == 1 --> ~X | Y
7075 case ISD::SETGE:
7076 case ISD::SETULE: {
7077 SDValue Temp =
7078 DAG.getNode(RISCVISD::VMXOR_VL, DL, ContainerVT, Op1, AllOneMask, VL);
7079 Result = DAG.getNode(RISCVISD::VMXOR_VL, DL, ContainerVT, Temp, Op2, VL);
7080 break;
7081 }
7082 // X <=s Y --> X == 1 | Y == 0 --> ~Y | X
7083 // X >=u Y --> X == 1 | Y == 0 --> ~Y | X
7084 case ISD::SETLE:
7085 case ISD::SETUGE: {
7086 SDValue Temp =
7087 DAG.getNode(RISCVISD::VMXOR_VL, DL, ContainerVT, Op2, AllOneMask, VL);
7088 Result = DAG.getNode(RISCVISD::VMXOR_VL, DL, ContainerVT, Temp, Op1, VL);
7089 break;
7090 }
7091 }
7092
7093 if (!VT.isFixedLengthVector())
7094 return Result;
7095 return convertFromScalableVector(VT, Result, DAG, Subtarget);
7096 }
7097
7098 // Lower Floating-Point/Integer Type-Convert VP SDNodes
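// Rough shape of the lowering below: small integer sources (including i1) are
// widened first, the conversion is emitted at the destination or intermediate
// width, and narrowing cases go through at most one intermediate step (f32 or
// a half-width integer), with any remaining narrowing done by repeated
// truncates.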
7099 SDValue RISCVTargetLowering::lowerVPFPIntConvOp(SDValue Op, SelectionDAG &DAG,
7100 unsigned RISCVISDOpc) const {
7101 SDLoc DL(Op);
7102
7103 SDValue Src = Op.getOperand(0);
7104 SDValue Mask = Op.getOperand(1);
7105 SDValue VL = Op.getOperand(2);
7106
7107 MVT DstVT = Op.getSimpleValueType();
7108 MVT SrcVT = Src.getSimpleValueType();
7109 if (DstVT.isFixedLengthVector()) {
7110 DstVT = getContainerForFixedLengthVector(DstVT);
7111 SrcVT = getContainerForFixedLengthVector(SrcVT);
7112 Src = convertToScalableVector(SrcVT, Src, DAG, Subtarget);
7113 MVT MaskVT = getMaskTypeFor(DstVT);
7114 Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
7115 }
7116
7117 unsigned DstEltSize = DstVT.getScalarSizeInBits();
7118 unsigned SrcEltSize = SrcVT.getScalarSizeInBits();
7119
7120 SDValue Result;
7121 if (DstEltSize >= SrcEltSize) { // Single-width and widening conversion.
7122 if (SrcVT.isInteger()) {
7123 assert(DstVT.isFloatingPoint() && "Wrong input/output vector types");
7124
7125 unsigned RISCVISDExtOpc = RISCVISDOpc == RISCVISD::SINT_TO_FP_VL
7126 ? RISCVISD::VSEXT_VL
7127 : RISCVISD::VZEXT_VL;
7128
7129 // Do we need to do any pre-widening before converting?
7130 if (SrcEltSize == 1) {
7131 MVT IntVT = DstVT.changeVectorElementTypeToInteger();
7132 MVT XLenVT = Subtarget.getXLenVT();
7133 SDValue Zero = DAG.getConstant(0, DL, XLenVT);
7134 SDValue ZeroSplat = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, IntVT,
7135 DAG.getUNDEF(IntVT), Zero, VL);
7136 SDValue One = DAG.getConstant(
7137 RISCVISDExtOpc == RISCVISD::VZEXT_VL ? 1 : -1, DL, XLenVT);
7138 SDValue OneSplat = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, IntVT,
7139 DAG.getUNDEF(IntVT), One, VL);
7140 Src = DAG.getNode(RISCVISD::VSELECT_VL, DL, IntVT, Src, OneSplat,
7141 ZeroSplat, VL);
7142 } else if (DstEltSize > (2 * SrcEltSize)) {
7143 // Widen before converting.
7144 MVT IntVT = MVT::getVectorVT(MVT::getIntegerVT(DstEltSize / 2),
7145 DstVT.getVectorElementCount());
7146 Src = DAG.getNode(RISCVISDExtOpc, DL, IntVT, Src, Mask, VL);
7147 }
7148
7149 Result = DAG.getNode(RISCVISDOpc, DL, DstVT, Src, Mask, VL);
7150 } else {
7151 assert(SrcVT.isFloatingPoint() && DstVT.isInteger() &&
7152 "Wrong input/output vector types");
7153
7154 // Convert f16 to f32 then convert f32 to i64.
7155 if (DstEltSize > (2 * SrcEltSize)) {
7156 assert(SrcVT.getVectorElementType() == MVT::f16 && "Unexpected type!");
7157 MVT InterimFVT =
7158 MVT::getVectorVT(MVT::f32, DstVT.getVectorElementCount());
7159 Src =
7160 DAG.getNode(RISCVISD::FP_EXTEND_VL, DL, InterimFVT, Src, Mask, VL);
7161 }
7162
7163 Result = DAG.getNode(RISCVISDOpc, DL, DstVT, Src, Mask, VL);
7164 }
7165 } else { // Narrowing + Conversion
7166 if (SrcVT.isInteger()) {
7167 assert(DstVT.isFloatingPoint() && "Wrong input/output vector types");
7168 // First do a narrowing conversion to an FP type half the size, then round
7169 // the result down to the smaller FP type if needed.
7170
7171 MVT InterimFVT = DstVT;
7172 if (SrcEltSize > (2 * DstEltSize)) {
7173 assert(SrcEltSize == (4 * DstEltSize) && "Unexpected types!");
7174 assert(DstVT.getVectorElementType() == MVT::f16 && "Unexpected type!");
7175 InterimFVT = MVT::getVectorVT(MVT::f32, DstVT.getVectorElementCount());
7176 }
7177
7178 Result = DAG.getNode(RISCVISDOpc, DL, InterimFVT, Src, Mask, VL);
7179
7180 if (InterimFVT != DstVT) {
7181 Src = Result;
7182 Result = DAG.getNode(RISCVISD::FP_ROUND_VL, DL, DstVT, Src, Mask, VL);
7183 }
7184 } else {
7185 assert(SrcVT.isFloatingPoint() && DstVT.isInteger() &&
7186 "Wrong input/output vector types");
7187 // First do a narrowing conversion to an integer half the size, then
7188 // truncate if needed.
7189
7190 if (DstEltSize == 1) {
7191 // First convert to the same size integer, then convert to mask using
7192 // setcc.
7193 assert(SrcEltSize >= 16 && "Unexpected FP type!");
7194 MVT InterimIVT = MVT::getVectorVT(MVT::getIntegerVT(SrcEltSize),
7195 DstVT.getVectorElementCount());
7196 Result = DAG.getNode(RISCVISDOpc, DL, InterimIVT, Src, Mask, VL);
7197
7198 // Compare the integer result to 0. The integer should be 0 or 1/-1,
7199 // otherwise the conversion was undefined.
7200 MVT XLenVT = Subtarget.getXLenVT();
7201 SDValue SplatZero = DAG.getConstant(0, DL, XLenVT);
7202 SplatZero = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, InterimIVT,
7203 DAG.getUNDEF(InterimIVT), SplatZero, VL);
7204 Result = DAG.getNode(RISCVISD::SETCC_VL, DL, DstVT,
7205 {Result, SplatZero, DAG.getCondCode(ISD::SETNE),
7206 DAG.getUNDEF(DstVT), Mask, VL});
7207 } else {
7208 MVT InterimIVT = MVT::getVectorVT(MVT::getIntegerVT(SrcEltSize / 2),
7209 DstVT.getVectorElementCount());
7210
7211 Result = DAG.getNode(RISCVISDOpc, DL, InterimIVT, Src, Mask, VL);
7212
7213 while (InterimIVT != DstVT) {
7214 SrcEltSize /= 2;
7215 Src = Result;
7216 InterimIVT = MVT::getVectorVT(MVT::getIntegerVT(SrcEltSize / 2),
7217 DstVT.getVectorElementCount());
7218 Result = DAG.getNode(RISCVISD::TRUNCATE_VECTOR_VL, DL, InterimIVT,
7219 Src, Mask, VL);
7220 }
7221 }
7222 }
7223 }
7224
7225 MVT VT = Op.getSimpleValueType();
7226 if (!VT.isFixedLengthVector())
7227 return Result;
7228 return convertFromScalableVector(VT, Result, DAG, Subtarget);
7229 }
7230
7231 SDValue RISCVTargetLowering::lowerLogicVPOp(SDValue Op, SelectionDAG &DAG,
7232 unsigned MaskOpc,
7233 unsigned VecOpc) const {
7234 MVT VT = Op.getSimpleValueType();
7235 if (VT.getVectorElementType() != MVT::i1)
7236 return lowerVPOp(Op, DAG, VecOpc, true);
7237
7238 // It is safe to drop the mask parameter as masked-off elements are undef.
7239 SDValue Op1 = Op->getOperand(0);
7240 SDValue Op2 = Op->getOperand(1);
7241 SDValue VL = Op->getOperand(3);
7242
7243 MVT ContainerVT = VT;
7244 const bool IsFixed = VT.isFixedLengthVector();
7245 if (IsFixed) {
7246 ContainerVT = getContainerForFixedLengthVector(VT);
7247 Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
7248 Op2 = convertToScalableVector(ContainerVT, Op2, DAG, Subtarget);
7249 }
7250
7251 SDLoc DL(Op);
7252 SDValue Val = DAG.getNode(MaskOpc, DL, ContainerVT, Op1, Op2, VL);
7253 if (!IsFixed)
7254 return Val;
7255 return convertFromScalableVector(VT, Val, DAG, Subtarget);
7256 }
7257
7258 SDValue RISCVTargetLowering::lowerVPStridedLoad(SDValue Op,
7259 SelectionDAG &DAG) const {
7260 SDLoc DL(Op);
7261 MVT XLenVT = Subtarget.getXLenVT();
7262 MVT VT = Op.getSimpleValueType();
7263 MVT ContainerVT = VT;
7264 if (VT.isFixedLengthVector())
7265 ContainerVT = getContainerForFixedLengthVector(VT);
7266
7267 SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
7268
7269 auto *VPNode = cast<VPStridedLoadSDNode>(Op);
7270 // Check if the mask is known to be all ones
7271 SDValue Mask = VPNode->getMask();
7272 bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
7273
7274 SDValue IntID = DAG.getTargetConstant(IsUnmasked ? Intrinsic::riscv_vlse
7275 : Intrinsic::riscv_vlse_mask,
7276 DL, XLenVT);
7277 SmallVector<SDValue, 8> Ops{VPNode->getChain(), IntID,
7278 DAG.getUNDEF(ContainerVT), VPNode->getBasePtr(),
7279 VPNode->getStride()};
7280 if (!IsUnmasked) {
7281 if (VT.isFixedLengthVector()) {
7282 MVT MaskVT = ContainerVT.changeVectorElementType(MVT::i1);
7283 Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
7284 }
7285 Ops.push_back(Mask);
7286 }
7287 Ops.push_back(VPNode->getVectorLength());
7288 if (!IsUnmasked) {
7289 SDValue Policy = DAG.getTargetConstant(RISCVII::TAIL_AGNOSTIC, DL, XLenVT);
7290 Ops.push_back(Policy);
7291 }
7292
7293 SDValue Result =
7294 DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops,
7295 VPNode->getMemoryVT(), VPNode->getMemOperand());
7296 SDValue Chain = Result.getValue(1);
7297
7298 if (VT.isFixedLengthVector())
7299 Result = convertFromScalableVector(VT, Result, DAG, Subtarget);
7300
7301 return DAG.getMergeValues({Result, Chain}, DL);
7302 }
7303
7304 SDValue RISCVTargetLowering::lowerVPStridedStore(SDValue Op,
7305 SelectionDAG &DAG) const {
7306 SDLoc DL(Op);
7307 MVT XLenVT = Subtarget.getXLenVT();
7308
7309 auto *VPNode = cast<VPStridedStoreSDNode>(Op);
7310 SDValue StoreVal = VPNode->getValue();
7311 MVT VT = StoreVal.getSimpleValueType();
7312 MVT ContainerVT = VT;
7313 if (VT.isFixedLengthVector()) {
7314 ContainerVT = getContainerForFixedLengthVector(VT);
7315 StoreVal = convertToScalableVector(ContainerVT, StoreVal, DAG, Subtarget);
7316 }
7317
7318 // Check if the mask is known to be all ones
7319 SDValue Mask = VPNode->getMask();
7320 bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
7321
7322 SDValue IntID = DAG.getTargetConstant(IsUnmasked ? Intrinsic::riscv_vsse
7323 : Intrinsic::riscv_vsse_mask,
7324 DL, XLenVT);
7325 SmallVector<SDValue, 8> Ops{VPNode->getChain(), IntID, StoreVal,
7326 VPNode->getBasePtr(), VPNode->getStride()};
7327 if (!IsUnmasked) {
7328 if (VT.isFixedLengthVector()) {
7329 MVT MaskVT = ContainerVT.changeVectorElementType(MVT::i1);
7330 Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
7331 }
7332 Ops.push_back(Mask);
7333 }
7334 Ops.push_back(VPNode->getVectorLength());
7335
7336 return DAG.getMemIntrinsicNode(ISD::INTRINSIC_VOID, DL, VPNode->getVTList(),
7337 Ops, VPNode->getMemoryVT(),
7338 VPNode->getMemOperand());
7339 }
7340
7341 // Custom lower MGATHER/VP_GATHER to a legalized form for RVV. It will then be
7342 // matched to a RVV indexed load. The RVV indexed load instructions only
7343 // support the "unsigned unscaled" addressing mode; indices are implicitly
7344 // zero-extended or truncated to XLEN and are treated as byte offsets. Any
7345 // signed or scaled indexing is extended to the XLEN value type and scaled
7346 // accordingly.
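// For example (illustrative only): a VP_GATHER of v4i32 with v4i16 indices is
// lowered to the riscv_vluxei / riscv_vluxei_mask intrinsic on the i32
// container type, with the i16 indices kept as-is since they are already
// treated as zero-extended byte offsets.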
7347 SDValue RISCVTargetLowering::lowerMaskedGather(SDValue Op,
7348 SelectionDAG &DAG) const {
7349 SDLoc DL(Op);
7350 MVT VT = Op.getSimpleValueType();
7351
7352 const auto *MemSD = cast<MemSDNode>(Op.getNode());
7353 EVT MemVT = MemSD->getMemoryVT();
7354 MachineMemOperand *MMO = MemSD->getMemOperand();
7355 SDValue Chain = MemSD->getChain();
7356 SDValue BasePtr = MemSD->getBasePtr();
7357
7358 ISD::LoadExtType LoadExtType;
7359 SDValue Index, Mask, PassThru, VL;
7360
7361 if (auto *VPGN = dyn_cast<VPGatherSDNode>(Op.getNode())) {
7362 Index = VPGN->getIndex();
7363 Mask = VPGN->getMask();
7364 PassThru = DAG.getUNDEF(VT);
7365 VL = VPGN->getVectorLength();
7366 // VP doesn't support extending loads.
7367 LoadExtType = ISD::NON_EXTLOAD;
7368 } else {
7369 // Else it must be a MGATHER.
7370 auto *MGN = cast<MaskedGatherSDNode>(Op.getNode());
7371 Index = MGN->getIndex();
7372 Mask = MGN->getMask();
7373 PassThru = MGN->getPassThru();
7374 LoadExtType = MGN->getExtensionType();
7375 }
7376
7377 MVT IndexVT = Index.getSimpleValueType();
7378 MVT XLenVT = Subtarget.getXLenVT();
7379
7380 assert(VT.getVectorElementCount() == IndexVT.getVectorElementCount() &&
7381 "Unexpected VTs!");
7382 assert(BasePtr.getSimpleValueType() == XLenVT && "Unexpected pointer type");
7383 // Targets have to explicitly opt-in for extending vector loads.
7384 assert(LoadExtType == ISD::NON_EXTLOAD &&
7385 "Unexpected extending MGATHER/VP_GATHER");
7386 (void)LoadExtType;
7387
7388 // If the mask is known to be all ones, optimize to an unmasked intrinsic;
7389 // the selection of the masked intrinsics doesn't do this for us.
7390 bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
7391
7392 MVT ContainerVT = VT;
7393 if (VT.isFixedLengthVector()) {
7394 ContainerVT = getContainerForFixedLengthVector(VT);
7395 IndexVT = MVT::getVectorVT(IndexVT.getVectorElementType(),
7396 ContainerVT.getVectorElementCount());
7397
7398 Index = convertToScalableVector(IndexVT, Index, DAG, Subtarget);
7399
7400 if (!IsUnmasked) {
7401 MVT MaskVT = getMaskTypeFor(ContainerVT);
7402 Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
7403 PassThru = convertToScalableVector(ContainerVT, PassThru, DAG, Subtarget);
7404 }
7405 }
7406
7407 if (!VL)
7408 VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
7409
7410 if (XLenVT == MVT::i32 && IndexVT.getVectorElementType().bitsGT(XLenVT)) {
7411 IndexVT = IndexVT.changeVectorElementType(XLenVT);
7412 SDValue TrueMask = DAG.getNode(RISCVISD::VMSET_VL, DL, Mask.getValueType(),
7413 VL);
7414 Index = DAG.getNode(RISCVISD::TRUNCATE_VECTOR_VL, DL, IndexVT, Index,
7415 TrueMask, VL);
7416 }
7417
7418 unsigned IntID =
7419 IsUnmasked ? Intrinsic::riscv_vluxei : Intrinsic::riscv_vluxei_mask;
7420 SmallVector<SDValue, 8> Ops{Chain, DAG.getTargetConstant(IntID, DL, XLenVT)};
7421 if (IsUnmasked)
7422 Ops.push_back(DAG.getUNDEF(ContainerVT));
7423 else
7424 Ops.push_back(PassThru);
7425 Ops.push_back(BasePtr);
7426 Ops.push_back(Index);
7427 if (!IsUnmasked)
7428 Ops.push_back(Mask);
7429 Ops.push_back(VL);
7430 if (!IsUnmasked)
7431 Ops.push_back(DAG.getTargetConstant(RISCVII::TAIL_AGNOSTIC, DL, XLenVT));
7432
7433 SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
7434 SDValue Result =
7435 DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops, MemVT, MMO);
7436 Chain = Result.getValue(1);
7437
7438 if (VT.isFixedLengthVector())
7439 Result = convertFromScalableVector(VT, Result, DAG, Subtarget);
7440
7441 return DAG.getMergeValues({Result, Chain}, DL);
7442 }
7443
7444 // Custom lower MSCATTER/VP_SCATTER to a legalized form for RVV. It will then be
7445 // matched to a RVV indexed store. The RVV indexed store instructions only
7446 // support the "unsigned unscaled" addressing mode; indices are implicitly
7447 // zero-extended or truncated to XLEN and are treated as byte offsets. Any
7448 // signed or scaled indexing is extended to the XLEN value type and scaled
7449 // accordingly.
7450 SDValue RISCVTargetLowering::lowerMaskedScatter(SDValue Op,
7451 SelectionDAG &DAG) const {
7452 SDLoc DL(Op);
7453 const auto *MemSD = cast<MemSDNode>(Op.getNode());
7454 EVT MemVT = MemSD->getMemoryVT();
7455 MachineMemOperand *MMO = MemSD->getMemOperand();
7456 SDValue Chain = MemSD->getChain();
7457 SDValue BasePtr = MemSD->getBasePtr();
7458
7459 bool IsTruncatingStore = false;
7460 SDValue Index, Mask, Val, VL;
7461
7462 if (auto *VPSN = dyn_cast<VPScatterSDNode>(Op.getNode())) {
7463 Index = VPSN->getIndex();
7464 Mask = VPSN->getMask();
7465 Val = VPSN->getValue();
7466 VL = VPSN->getVectorLength();
7467 // VP doesn't support truncating stores.
7468 IsTruncatingStore = false;
7469 } else {
7470 // Else it must be a MSCATTER.
7471 auto *MSN = cast<MaskedScatterSDNode>(Op.getNode());
7472 Index = MSN->getIndex();
7473 Mask = MSN->getMask();
7474 Val = MSN->getValue();
7475 IsTruncatingStore = MSN->isTruncatingStore();
7476 }
7477
7478 MVT VT = Val.getSimpleValueType();
7479 MVT IndexVT = Index.getSimpleValueType();
7480 MVT XLenVT = Subtarget.getXLenVT();
7481
7482 assert(VT.getVectorElementCount() == IndexVT.getVectorElementCount() &&
7483 "Unexpected VTs!");
7484 assert(BasePtr.getSimpleValueType() == XLenVT && "Unexpected pointer type");
7485 // Targets have to explicitly opt-in for extending vector loads and
7486 // truncating vector stores.
7487 assert(!IsTruncatingStore && "Unexpected truncating MSCATTER/VP_SCATTER");
7488 (void)IsTruncatingStore;
7489
7490 // If the mask is known to be all ones, optimize to an unmasked intrinsic;
7491 // the selection of the masked intrinsics doesn't do this for us.
7492 bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
7493
7494 MVT ContainerVT = VT;
7495 if (VT.isFixedLengthVector()) {
7496 ContainerVT = getContainerForFixedLengthVector(VT);
7497 IndexVT = MVT::getVectorVT(IndexVT.getVectorElementType(),
7498 ContainerVT.getVectorElementCount());
7499
7500 Index = convertToScalableVector(IndexVT, Index, DAG, Subtarget);
7501 Val = convertToScalableVector(ContainerVT, Val, DAG, Subtarget);
7502
7503 if (!IsUnmasked) {
7504 MVT MaskVT = getMaskTypeFor(ContainerVT);
7505 Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
7506 }
7507 }
7508
7509 if (!VL)
7510 VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
7511
7512 if (XLenVT == MVT::i32 && IndexVT.getVectorElementType().bitsGT(XLenVT)) {
7513 IndexVT = IndexVT.changeVectorElementType(XLenVT);
7514 SDValue TrueMask = DAG.getNode(RISCVISD::VMSET_VL, DL, Mask.getValueType(),
7515 VL);
7516 Index = DAG.getNode(RISCVISD::TRUNCATE_VECTOR_VL, DL, IndexVT, Index,
7517 TrueMask, VL);
7518 }
7519
7520 unsigned IntID =
7521 IsUnmasked ? Intrinsic::riscv_vsoxei : Intrinsic::riscv_vsoxei_mask;
7522 SmallVector<SDValue, 8> Ops{Chain, DAG.getTargetConstant(IntID, DL, XLenVT)};
7523 Ops.push_back(Val);
7524 Ops.push_back(BasePtr);
7525 Ops.push_back(Index);
7526 if (!IsUnmasked)
7527 Ops.push_back(Mask);
7528 Ops.push_back(VL);
7529
7530 return DAG.getMemIntrinsicNode(ISD::INTRINSIC_VOID, DL,
7531 DAG.getVTList(MVT::Other), Ops, MemVT, MMO);
7532 }
7533
7534 SDValue RISCVTargetLowering::lowerGET_ROUNDING(SDValue Op,
7535 SelectionDAG &DAG) const {
7536 const MVT XLenVT = Subtarget.getXLenVT();
7537 SDLoc DL(Op);
7538 SDValue Chain = Op->getOperand(0);
7539 SDValue SysRegNo = DAG.getTargetConstant(
7540 RISCVSysReg::lookupSysRegByName("FRM")->Encoding, DL, XLenVT);
7541 SDVTList VTs = DAG.getVTList(XLenVT, MVT::Other);
7542 SDValue RM = DAG.getNode(RISCVISD::READ_CSR, DL, VTs, Chain, SysRegNo);
7543
7544 // The encoding used for the rounding mode in RISCV differs from the one used
7545 // by FLT_ROUNDS. To convert between them, the RISCV rounding mode is used as
7546 // an index into a table consisting of a sequence of 4-bit fields, each
7547 // holding the corresponding FLT_ROUNDS mode.
7548 static const int Table =
7549 (int(RoundingMode::NearestTiesToEven) << 4 * RISCVFPRndMode::RNE) |
7550 (int(RoundingMode::TowardZero) << 4 * RISCVFPRndMode::RTZ) |
7551 (int(RoundingMode::TowardNegative) << 4 * RISCVFPRndMode::RDN) |
7552 (int(RoundingMode::TowardPositive) << 4 * RISCVFPRndMode::RUP) |
7553 (int(RoundingMode::NearestTiesToAway) << 4 * RISCVFPRndMode::RMM);
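// Worked example (assuming the standard FRM encoding where RDN == 2 and the
// FLT_ROUNDS value for round-toward-negative is 3): reading FRM == 2 gives
// Shift == 2 << 2 == 8, and (Table >> 8) & 7 extracts the 4-bit field that
// was populated with RoundingMode::TowardNegative, i.e. 3.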
7554
7555 SDValue Shift =
7556 DAG.getNode(ISD::SHL, DL, XLenVT, RM, DAG.getConstant(2, DL, XLenVT));
7557 SDValue Shifted = DAG.getNode(ISD::SRL, DL, XLenVT,
7558 DAG.getConstant(Table, DL, XLenVT), Shift);
7559 SDValue Masked = DAG.getNode(ISD::AND, DL, XLenVT, Shifted,
7560 DAG.getConstant(7, DL, XLenVT));
7561
7562 return DAG.getMergeValues({Masked, Chain}, DL);
7563 }
7564
7565 SDValue RISCVTargetLowering::lowerSET_ROUNDING(SDValue Op,
7566 SelectionDAG &DAG) const {
7567 const MVT XLenVT = Subtarget.getXLenVT();
7568 SDLoc DL(Op);
7569 SDValue Chain = Op->getOperand(0);
7570 SDValue RMValue = Op->getOperand(1);
7571 SDValue SysRegNo = DAG.getTargetConstant(
7572 RISCVSysReg::lookupSysRegByName("FRM")->Encoding, DL, XLenVT);
7573
7574 // The encoding used for the rounding mode in RISCV differs from the one used
7575 // by FLT_ROUNDS. To convert between them, the C rounding mode is used as an
7576 // index into a table consisting of a sequence of 4-bit fields, each holding
7577 // the corresponding RISCV mode.
7578 static const unsigned Table =
7579 (RISCVFPRndMode::RNE << 4 * int(RoundingMode::NearestTiesToEven)) |
7580 (RISCVFPRndMode::RTZ << 4 * int(RoundingMode::TowardZero)) |
7581 (RISCVFPRndMode::RDN << 4 * int(RoundingMode::TowardNegative)) |
7582 (RISCVFPRndMode::RUP << 4 * int(RoundingMode::TowardPositive)) |
7583 (RISCVFPRndMode::RMM << 4 * int(RoundingMode::NearestTiesToAway));
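// This table is the inverse of the one in lowerGET_ROUNDING; e.g. the
// FLT_ROUNDS value for round-toward-zero (0, assuming the usual C encoding)
// selects bits [3:0], which hold RISCVFPRndMode::RTZ.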
7584
7585 SDValue Shift = DAG.getNode(ISD::SHL, DL, XLenVT, RMValue,
7586 DAG.getConstant(2, DL, XLenVT));
7587 SDValue Shifted = DAG.getNode(ISD::SRL, DL, XLenVT,
7588 DAG.getConstant(Table, DL, XLenVT), Shift);
7589 RMValue = DAG.getNode(ISD::AND, DL, XLenVT, Shifted,
7590 DAG.getConstant(0x7, DL, XLenVT));
7591 return DAG.getNode(RISCVISD::WRITE_CSR, DL, MVT::Other, Chain, SysRegNo,
7592 RMValue);
7593 }
7594
7595 SDValue RISCVTargetLowering::lowerEH_DWARF_CFA(SDValue Op,
7596 SelectionDAG &DAG) const {
7597 MachineFunction &MF = DAG.getMachineFunction();
7598
7599 bool isRISCV64 = Subtarget.is64Bit();
7600 EVT PtrVT = getPointerTy(DAG.getDataLayout());
7601
7602 int FI = MF.getFrameInfo().CreateFixedObject(isRISCV64 ? 8 : 4, 0, false);
7603 return DAG.getFrameIndex(FI, PtrVT);
7604 }
7605
7606 // Returns the opcode of the target-specific SDNode that implements the 32-bit
7607 // form of the given Opcode.
7608 static RISCVISD::NodeType getRISCVWOpcode(unsigned Opcode) {
7609 switch (Opcode) {
7610 default:
7611 llvm_unreachable("Unexpected opcode");
7612 case ISD::SHL:
7613 return RISCVISD::SLLW;
7614 case ISD::SRA:
7615 return RISCVISD::SRAW;
7616 case ISD::SRL:
7617 return RISCVISD::SRLW;
7618 case ISD::SDIV:
7619 return RISCVISD::DIVW;
7620 case ISD::UDIV:
7621 return RISCVISD::DIVUW;
7622 case ISD::UREM:
7623 return RISCVISD::REMUW;
7624 case ISD::ROTL:
7625 return RISCVISD::ROLW;
7626 case ISD::ROTR:
7627 return RISCVISD::RORW;
7628 }
7629 }
7630
7631 // Converts the given i8/i16/i32 operation to a target-specific SelectionDAG
7632 // node. Because i8/i16/i32 isn't a legal type for RV64, these operations would
7633 // otherwise be promoted to i64, making it difficult to select the
7634 // SLLW/DIVUW/.../*W later on because the fact that the operation was originally of
7635 // type i8/i16/i32 is lost.
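// For example, on RV64 an i32 logical shift right by a variable amount is
// rewritten as (trunc i32 (SRLW (any_extend i64 X), (any_extend i64 Y))),
// which can later be selected as a single srlw.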
7636 static SDValue customLegalizeToWOp(SDNode *N, SelectionDAG &DAG,
7637 unsigned ExtOpc = ISD::ANY_EXTEND) {
7638 SDLoc DL(N);
7639 RISCVISD::NodeType WOpcode = getRISCVWOpcode(N->getOpcode());
7640 SDValue NewOp0 = DAG.getNode(ExtOpc, DL, MVT::i64, N->getOperand(0));
7641 SDValue NewOp1 = DAG.getNode(ExtOpc, DL, MVT::i64, N->getOperand(1));
7642 SDValue NewRes = DAG.getNode(WOpcode, DL, MVT::i64, NewOp0, NewOp1);
7643 // ReplaceNodeResults requires we maintain the same type for the return value.
7644 return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewRes);
7645 }
7646
7647 // Converts the given 32-bit operation to an i64 operation with sign-extension
7648 // semantics so that fewer sign-extension instructions are needed.
7649 static SDValue customLegalizeToWOpWithSExt(SDNode *N, SelectionDAG &DAG) {
7650 SDLoc DL(N);
7651 SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(0));
7652 SDValue NewOp1 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(1));
7653 SDValue NewWOp = DAG.getNode(N->getOpcode(), DL, MVT::i64, NewOp0, NewOp1);
7654 SDValue NewRes = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i64, NewWOp,
7655 DAG.getValueType(MVT::i32));
7656 return DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, NewRes);
7657 }
7658
7659 void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
7660 SmallVectorImpl<SDValue> &Results,
7661 SelectionDAG &DAG) const {
7662 SDLoc DL(N);
7663 switch (N->getOpcode()) {
7664 default:
7665 llvm_unreachable("Don't know how to custom type legalize this operation!");
7666 case ISD::STRICT_FP_TO_SINT:
7667 case ISD::STRICT_FP_TO_UINT:
7668 case ISD::FP_TO_SINT:
7669 case ISD::FP_TO_UINT: {
7670 assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
7671 "Unexpected custom legalisation");
7672 bool IsStrict = N->isStrictFPOpcode();
7673 bool IsSigned = N->getOpcode() == ISD::FP_TO_SINT ||
7674 N->getOpcode() == ISD::STRICT_FP_TO_SINT;
7675 SDValue Op0 = IsStrict ? N->getOperand(1) : N->getOperand(0);
7676 if (getTypeAction(*DAG.getContext(), Op0.getValueType()) !=
7677 TargetLowering::TypeSoftenFloat) {
7678 if (!isTypeLegal(Op0.getValueType()))
7679 return;
7680 if (IsStrict) {
7681 SDValue Chain = N->getOperand(0);
7682 // In the absence of Zfh, promote f16 to f32, then convert.
7683 if (Op0.getValueType() == MVT::f16 && !Subtarget.hasStdExtZfh()) {
7684 Op0 = DAG.getNode(ISD::STRICT_FP_EXTEND, DL, {MVT::f32, MVT::Other},
7685 {Chain, Op0});
7686 Chain = Op0.getValue(1);
7687 }
7688 unsigned Opc = IsSigned ? RISCVISD::STRICT_FCVT_W_RV64
7689 : RISCVISD::STRICT_FCVT_WU_RV64;
7690 SDVTList VTs = DAG.getVTList(MVT::i64, MVT::Other);
7691 SDValue Res = DAG.getNode(
7692 Opc, DL, VTs, Chain, Op0,
7693 DAG.getTargetConstant(RISCVFPRndMode::RTZ, DL, MVT::i64));
7694 Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
7695 Results.push_back(Res.getValue(1));
7696 return;
7697 }
7698 // In the absence of Zfh, promote f16 to f32, then convert.
7699 if (Op0.getValueType() == MVT::f16 && !Subtarget.hasStdExtZfh())
7700 Op0 = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Op0);
7701
7702 unsigned Opc = IsSigned ? RISCVISD::FCVT_W_RV64 : RISCVISD::FCVT_WU_RV64;
7703 SDValue Res =
7704 DAG.getNode(Opc, DL, MVT::i64, Op0,
7705 DAG.getTargetConstant(RISCVFPRndMode::RTZ, DL, MVT::i64));
7706 Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
7707 return;
7708 }
7709 // If the FP type needs to be softened, emit a library call using the 'si'
7710 // version. If we left it to default legalization we'd end up with 'di'. If
7711 // the FP type doesn't need to be softened just let generic type
7712 // legalization promote the result type.
7713 RTLIB::Libcall LC;
7714 if (IsSigned)
7715 LC = RTLIB::getFPTOSINT(Op0.getValueType(), N->getValueType(0));
7716 else
7717 LC = RTLIB::getFPTOUINT(Op0.getValueType(), N->getValueType(0));
7718 MakeLibCallOptions CallOptions;
7719 EVT OpVT = Op0.getValueType();
7720 CallOptions.setTypeListBeforeSoften(OpVT, N->getValueType(0), true);
7721 SDValue Chain = IsStrict ? N->getOperand(0) : SDValue();
7722 SDValue Result;
7723 std::tie(Result, Chain) =
7724 makeLibCall(DAG, LC, N->getValueType(0), Op0, CallOptions, DL, Chain);
7725 Results.push_back(Result);
7726 if (IsStrict)
7727 Results.push_back(Chain);
7728 break;
7729 }
7730 case ISD::READCYCLECOUNTER: {
7731 assert(!Subtarget.is64Bit() &&
7732 "READCYCLECOUNTER only has custom type legalization on riscv32");
7733
7734 SDVTList VTs = DAG.getVTList(MVT::i32, MVT::i32, MVT::Other);
7735 SDValue RCW =
7736 DAG.getNode(RISCVISD::READ_CYCLE_WIDE, DL, VTs, N->getOperand(0));
7737
7738 Results.push_back(
7739 DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i64, RCW, RCW.getValue(1)));
7740 Results.push_back(RCW.getValue(2));
7741 break;
7742 }
7743 case ISD::LOAD: {
7744 if (!ISD::isNON_EXTLoad(N))
7745 return;
7746
7747 // Use a SEXTLOAD instead of the default EXTLOAD. Similar to the
7748 // sext_inreg we emit for ADD/SUB/MUL/SLLI.
7749 LoadSDNode *Ld = cast<LoadSDNode>(N);
7750
7751 SDLoc dl(N);
7752 SDValue Res = DAG.getExtLoad(ISD::SEXTLOAD, dl, MVT::i64, Ld->getChain(),
7753 Ld->getBasePtr(), Ld->getMemoryVT(),
7754 Ld->getMemOperand());
7755 Results.push_back(DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, Res));
7756 Results.push_back(Res.getValue(1));
7757 return;
7758 }
7759 case ISD::MUL: {
7760 unsigned Size = N->getSimpleValueType(0).getSizeInBits();
7761 unsigned XLen = Subtarget.getXLen();
7762 // This multiply needs to be expanded; try to use MULHSU+MUL if possible.
7763 if (Size > XLen) {
7764 assert(Size == (XLen * 2) && "Unexpected custom legalisation");
7765 SDValue LHS = N->getOperand(0);
7766 SDValue RHS = N->getOperand(1);
7767 APInt HighMask = APInt::getHighBitsSet(Size, XLen);
7768
7769 bool LHSIsU = DAG.MaskedValueIsZero(LHS, HighMask);
7770 bool RHSIsU = DAG.MaskedValueIsZero(RHS, HighMask);
7771 // We need exactly one side to be unsigned.
7772 if (LHSIsU == RHSIsU)
7773 return;
7774
7775 auto MakeMULPair = [&](SDValue S, SDValue U) {
7776 MVT XLenVT = Subtarget.getXLenVT();
7777 S = DAG.getNode(ISD::TRUNCATE, DL, XLenVT, S);
7778 U = DAG.getNode(ISD::TRUNCATE, DL, XLenVT, U);
7779 SDValue Lo = DAG.getNode(ISD::MUL, DL, XLenVT, S, U);
7780 SDValue Hi = DAG.getNode(RISCVISD::MULHSU, DL, XLenVT, S, U);
7781 return DAG.getNode(ISD::BUILD_PAIR, DL, N->getValueType(0), Lo, Hi);
7782 };
7783
7784 bool LHSIsS = DAG.ComputeNumSignBits(LHS) > XLen;
7785 bool RHSIsS = DAG.ComputeNumSignBits(RHS) > XLen;
7786
7787 // The other operand should be signed, but still prefer MULH when
7788 // possible.
7789 if (RHSIsU && LHSIsS && !RHSIsS)
7790 Results.push_back(MakeMULPair(LHS, RHS));
7791 else if (LHSIsU && RHSIsS && !LHSIsS)
7792 Results.push_back(MakeMULPair(RHS, LHS));
7793
7794 return;
7795 }
7796 [[fallthrough]];
7797 }
7798 case ISD::ADD:
7799 case ISD::SUB:
7800 assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
7801 "Unexpected custom legalisation");
7802 Results.push_back(customLegalizeToWOpWithSExt(N, DAG));
7803 break;
7804 case ISD::SHL:
7805 case ISD::SRA:
7806 case ISD::SRL:
7807 assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
7808 "Unexpected custom legalisation");
7809 if (N->getOperand(1).getOpcode() != ISD::Constant) {
7810 // If we can use a BSET instruction, allow default promotion to apply.
7811 if (N->getOpcode() == ISD::SHL && Subtarget.hasStdExtZbs() &&
7812 isOneConstant(N->getOperand(0)))
7813 break;
7814 Results.push_back(customLegalizeToWOp(N, DAG));
7815 break;
7816 }
7817
7818 // Custom legalize ISD::SHL by placing a SIGN_EXTEND_INREG after. This is
7819 // similar to customLegalizeToWOpWithSExt, but we must zero_extend the
7820 // shift amount.
7821 if (N->getOpcode() == ISD::SHL) {
7822 SDLoc DL(N);
7823 SDValue NewOp0 =
7824 DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(0));
7825 SDValue NewOp1 =
7826 DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, N->getOperand(1));
7827 SDValue NewWOp = DAG.getNode(ISD::SHL, DL, MVT::i64, NewOp0, NewOp1);
7828 SDValue NewRes = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i64, NewWOp,
7829 DAG.getValueType(MVT::i32));
7830 Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, NewRes));
7831 }
7832
7833 break;
7834 case ISD::ROTL:
7835 case ISD::ROTR:
7836 assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
7837 "Unexpected custom legalisation");
7838 Results.push_back(customLegalizeToWOp(N, DAG));
7839 break;
7840 case ISD::CTTZ:
7841 case ISD::CTTZ_ZERO_UNDEF:
7842 case ISD::CTLZ:
7843 case ISD::CTLZ_ZERO_UNDEF: {
7844 assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
7845 "Unexpected custom legalisation");
7846
7847 SDValue NewOp0 =
7848 DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(0));
7849 bool IsCTZ =
7850 N->getOpcode() == ISD::CTTZ || N->getOpcode() == ISD::CTTZ_ZERO_UNDEF;
7851 unsigned Opc = IsCTZ ? RISCVISD::CTZW : RISCVISD::CLZW;
7852 SDValue Res = DAG.getNode(Opc, DL, MVT::i64, NewOp0);
7853 Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
7854 return;
7855 }
7856 case ISD::SDIV:
7857 case ISD::UDIV:
7858 case ISD::UREM: {
7859 MVT VT = N->getSimpleValueType(0);
7860 assert((VT == MVT::i8 || VT == MVT::i16 || VT == MVT::i32) &&
7861 Subtarget.is64Bit() && Subtarget.hasStdExtM() &&
7862 "Unexpected custom legalisation");
7863 // Don't promote division/remainder by a constant, since those should be
7864 // expanded to a multiply by a magic constant instead.
7865 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
7866 if (N->getOperand(1).getOpcode() == ISD::Constant &&
7867 !isIntDivCheap(N->getValueType(0), Attr))
7868 return;
7869
7870 // If the input is i32, use ANY_EXTEND since the W instructions don't read
7871 // the upper 32 bits. For other types we need to sign or zero extend
7872 // based on the opcode.
7873 unsigned ExtOpc = ISD::ANY_EXTEND;
7874 if (VT != MVT::i32)
7875 ExtOpc = N->getOpcode() == ISD::SDIV ? ISD::SIGN_EXTEND
7876 : ISD::ZERO_EXTEND;
7877
7878 Results.push_back(customLegalizeToWOp(N, DAG, ExtOpc));
7879 break;
7880 }
7881 case ISD::UADDO:
7882 case ISD::USUBO: {
7883 assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
7884 "Unexpected custom legalisation");
7885 bool IsAdd = N->getOpcode() == ISD::UADDO;
7886 // Create an ADDW or SUBW.
7887 SDValue LHS = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(0));
7888 SDValue RHS = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(1));
7889 SDValue Res =
7890 DAG.getNode(IsAdd ? ISD::ADD : ISD::SUB, DL, MVT::i64, LHS, RHS);
7891 Res = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i64, Res,
7892 DAG.getValueType(MVT::i32));
7893
7894 SDValue Overflow;
7895 if (IsAdd && isOneConstant(RHS)) {
7896 // Special case uaddo X, 1 overflowed if the addition result is 0.
7897 // The general case (X + C) < C is not necessarily beneficial. Although we
7898 // reduce the live range of X, we may introduce the materialization of
7899 // constant C, especially when the setcc result is used by a branch. RISCV
7900 // has no compare-with-immediate-and-branch instructions.
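// For instance, with X == 0xffffffff the 32-bit sum wraps to 0, and that is
// the only input for which X + 1 overflows, so seteq(Res, 0) is exactly the
// carry-out.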
7901 Overflow = DAG.getSetCC(DL, N->getValueType(1), Res,
7902 DAG.getConstant(0, DL, MVT::i64), ISD::SETEQ);
7903 } else {
7904 // Sign extend the LHS and perform an unsigned compare with the ADDW
7905 // result. Since the inputs are sign extended from i32, this is equivalent
7906 // to comparing the lower 32 bits.
7907 LHS = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, N->getOperand(0));
7908 Overflow = DAG.getSetCC(DL, N->getValueType(1), Res, LHS,
7909 IsAdd ? ISD::SETULT : ISD::SETUGT);
7910 }
7911
7912 Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
7913 Results.push_back(Overflow);
7914 return;
7915 }
7916 case ISD::UADDSAT:
7917 case ISD::USUBSAT: {
7918 assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
7919 "Unexpected custom legalisation");
7920 if (Subtarget.hasStdExtZbb()) {
7921 // With Zbb we can sign extend and let LegalizeDAG use minu/maxu. Sign
7922 // extending allows overflow of the lower 32 bits to be detected on the
7923 // promoted size.
7924 SDValue LHS =
7925 DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, N->getOperand(0));
7926 SDValue RHS =
7927 DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, N->getOperand(1));
7928 SDValue Res = DAG.getNode(N->getOpcode(), DL, MVT::i64, LHS, RHS);
7929 Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
7930 return;
7931 }
7932
7933 // Without Zbb, expand to UADDO/USUBO+select which will trigger our custom
7934 // promotion for UADDO/USUBO.
7935 Results.push_back(expandAddSubSat(N, DAG));
7936 return;
7937 }
7938 case ISD::ABS: {
7939 assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
7940 "Unexpected custom legalisation");
7941
7942 if (Subtarget.hasStdExtZbb()) {
7943 // Emit a special ABSW node that will be expanded to NEGW+MAX at isel.
7944 // This allows us to remember that the result is sign extended. Expanding
7945 // to NEGW+MAX here requires a Freeze which breaks ComputeNumSignBits.
7946 SDValue Src = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64,
7947 N->getOperand(0));
7948 SDValue Abs = DAG.getNode(RISCVISD::ABSW, DL, MVT::i64, Src);
7949 Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Abs));
7950 return;
7951 }
7952
7953 // Expand abs to Y = (sraiw X, 31); subw(xor(X, Y), Y)
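// Illustrative arithmetic: for X = -5, Y = -5 >> 31 = all-ones, so
// xor(X, Y) = 4 and sub(4, -1) = 5 = |X|; for non-negative X, Y = 0 and the
// expression reduces to X itself.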
7954 SDValue Src = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(0));
7955
7956 // Freeze the source so we can increase its use count.
7957 Src = DAG.getFreeze(Src);
7958
7959 // Copy sign bit to all bits using the sraiw pattern.
7960 SDValue SignFill = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i64, Src,
7961 DAG.getValueType(MVT::i32));
7962 SignFill = DAG.getNode(ISD::SRA, DL, MVT::i64, SignFill,
7963 DAG.getConstant(31, DL, MVT::i64));
7964
7965 SDValue NewRes = DAG.getNode(ISD::XOR, DL, MVT::i64, Src, SignFill);
7966 NewRes = DAG.getNode(ISD::SUB, DL, MVT::i64, NewRes, SignFill);
7967
7968 // NOTE: The result is only required to be anyextended, but sext is
7969 // consistent with type legalization of sub.
7970 NewRes = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i64, NewRes,
7971 DAG.getValueType(MVT::i32));
7972 Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, NewRes));
7973 return;
7974 }
7975 case ISD::BITCAST: {
7976 EVT VT = N->getValueType(0);
7977 assert(VT.isInteger() && !VT.isVector() && "Unexpected VT!");
7978 SDValue Op0 = N->getOperand(0);
7979 EVT Op0VT = Op0.getValueType();
7980 MVT XLenVT = Subtarget.getXLenVT();
7981 if (VT == MVT::i16 && Op0VT == MVT::f16 &&
7982 Subtarget.hasStdExtZfhOrZfhmin()) {
7983 SDValue FPConv = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, XLenVT, Op0);
7984 Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FPConv));
7985 } else if (VT == MVT::i32 && Op0VT == MVT::f32 && Subtarget.is64Bit() &&
7986 Subtarget.hasStdExtF()) {
7987 SDValue FPConv =
7988 DAG.getNode(RISCVISD::FMV_X_ANYEXTW_RV64, DL, MVT::i64, Op0);
7989 Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, FPConv));
7990 } else if (!VT.isVector() && Op0VT.isFixedLengthVector() &&
7991 isTypeLegal(Op0VT)) {
7992 // Custom-legalize bitcasts from fixed-length vector types to illegal
7993 // scalar types in order to improve codegen. Bitcast the vector to a
7994 // one-element vector type whose element type is the same as the result
7995 // type, and extract the first element.
7996 EVT BVT = EVT::getVectorVT(*DAG.getContext(), VT, 1);
7997 if (isTypeLegal(BVT)) {
7998 SDValue BVec = DAG.getBitcast(BVT, Op0);
7999 Results.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, BVec,
8000 DAG.getConstant(0, DL, XLenVT)));
8001 }
8002 }
8003 break;
8004 }
8005 case RISCVISD::BREV8: {
8006 MVT VT = N->getSimpleValueType(0);
8007 MVT XLenVT = Subtarget.getXLenVT();
8008 assert((VT == MVT::i16 || (VT == MVT::i32 && Subtarget.is64Bit())) &&
8009 "Unexpected custom legalisation");
8010 assert(Subtarget.hasStdExtZbkb() && "Unexpected extension");
8011 SDValue NewOp = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, N->getOperand(0));
8012 SDValue NewRes = DAG.getNode(N->getOpcode(), DL, XLenVT, NewOp);
8013 // ReplaceNodeResults requires we maintain the same type for the return
8014 // value.
8015 Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, NewRes));
8016 break;
8017 }
8018 case ISD::EXTRACT_VECTOR_ELT: {
8019 // Custom-legalize an EXTRACT_VECTOR_ELT where XLEN<SEW, as the SEW element
8020 // type is illegal (currently only vXi64 RV32).
8021 // With vmv.x.s, when SEW > XLEN, only the least-significant XLEN bits are
8022 // transferred to the destination register. We issue two of these from the
8023 // upper and lower halves of the SEW-bit vector element, slid down to the
8024 // first element.
8025 SDValue Vec = N->getOperand(0);
8026 SDValue Idx = N->getOperand(1);
8027
8028 // The vector type hasn't been legalized yet so we can't issue target
8029 // specific nodes if it needs legalization.
8030 // FIXME: We would manually legalize if it's important.
8031 if (!isTypeLegal(Vec.getValueType()))
8032 return;
8033
8034 MVT VecVT = Vec.getSimpleValueType();
8035
8036 assert(!Subtarget.is64Bit() && N->getValueType(0) == MVT::i64 &&
8037 VecVT.getVectorElementType() == MVT::i64 &&
8038 "Unexpected EXTRACT_VECTOR_ELT legalization");
8039
8040 // If this is a fixed vector, we need to convert it to a scalable vector.
8041 MVT ContainerVT = VecVT;
8042 if (VecVT.isFixedLengthVector()) {
8043 ContainerVT = getContainerForFixedLengthVector(VecVT);
8044 Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
8045 }
8046
8047 MVT XLenVT = Subtarget.getXLenVT();
8048
8049 // Use a VL of 1 to avoid processing more elements than we need.
8050 auto [Mask, VL] = getDefaultVLOps(1, ContainerVT, DL, DAG, Subtarget);
8051
8052 // Unless the index is known to be 0, we must slide the vector down to get
8053 // the desired element into index 0.
8054 if (!isNullConstant(Idx)) {
8055 Vec = getVSlidedown(DAG, Subtarget, DL, ContainerVT,
8056 DAG.getUNDEF(ContainerVT), Vec, Idx, Mask, VL);
8057 }
8058
8059 // Extract the lower XLEN bits of the correct vector element.
8060 SDValue EltLo = DAG.getNode(RISCVISD::VMV_X_S, DL, XLenVT, Vec);
8061
8062 // To extract the upper XLEN bits of the vector element, shift the first
8063 // element right by 32 bits and re-extract the lower XLEN bits.
8064 SDValue ThirtyTwoV = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT,
8065 DAG.getUNDEF(ContainerVT),
8066 DAG.getConstant(32, DL, XLenVT), VL);
8067 SDValue LShr32 =
8068 DAG.getNode(RISCVISD::SRL_VL, DL, ContainerVT, Vec, ThirtyTwoV,
8069 DAG.getUNDEF(ContainerVT), Mask, VL);
8070
8071 SDValue EltHi = DAG.getNode(RISCVISD::VMV_X_S, DL, XLenVT, LShr32);
8072
8073 Results.push_back(DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i64, EltLo, EltHi));
8074 break;
8075 }
8076 case ISD::INTRINSIC_WO_CHAIN: {
8077 unsigned IntNo = cast<ConstantSDNode>(N->getOperand(0))->getZExtValue();
8078 switch (IntNo) {
8079 default:
8080 llvm_unreachable(
8081 "Don't know how to custom type legalize this intrinsic!");
8082 case Intrinsic::riscv_orc_b: {
8083 SDValue NewOp =
8084 DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(1));
8085 SDValue Res = DAG.getNode(RISCVISD::ORC_B, DL, MVT::i64, NewOp);
8086 Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
8087 return;
8088 }
8089 case Intrinsic::riscv_vmv_x_s: {
8090 EVT VT = N->getValueType(0);
8091 MVT XLenVT = Subtarget.getXLenVT();
8092 if (VT.bitsLT(XLenVT)) {
8093 // Simple case: just extract using vmv.x.s and truncate.
8094 SDValue Extract = DAG.getNode(RISCVISD::VMV_X_S, DL,
8095 Subtarget.getXLenVT(), N->getOperand(1));
8096 Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, Extract));
8097 return;
8098 }
8099
8100 assert(VT == MVT::i64 && !Subtarget.is64Bit() &&
8101 "Unexpected custom legalization");
8102
8103 // We need to do the move in two steps.
8104 SDValue Vec = N->getOperand(1);
8105 MVT VecVT = Vec.getSimpleValueType();
8106
8107 // First extract the lower XLEN bits of the element.
8108 SDValue EltLo = DAG.getNode(RISCVISD::VMV_X_S, DL, XLenVT, Vec);
8109
8110 // To extract the upper XLEN bits of the vector element, shift the first
8111 // element right by 32 bits and re-extract the lower XLEN bits.
8112 auto [Mask, VL] = getDefaultVLOps(1, VecVT, DL, DAG, Subtarget);
8113
8114 SDValue ThirtyTwoV =
8115 DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VecVT, DAG.getUNDEF(VecVT),
8116 DAG.getConstant(32, DL, XLenVT), VL);
8117 SDValue LShr32 = DAG.getNode(RISCVISD::SRL_VL, DL, VecVT, Vec, ThirtyTwoV,
8118 DAG.getUNDEF(VecVT), Mask, VL);
8119 SDValue EltHi = DAG.getNode(RISCVISD::VMV_X_S, DL, XLenVT, LShr32);
8120
8121 Results.push_back(
8122 DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i64, EltLo, EltHi));
8123 break;
8124 }
8125 }
8126 break;
8127 }
8128 case ISD::VECREDUCE_ADD:
8129 case ISD::VECREDUCE_AND:
8130 case ISD::VECREDUCE_OR:
8131 case ISD::VECREDUCE_XOR:
8132 case ISD::VECREDUCE_SMAX:
8133 case ISD::VECREDUCE_UMAX:
8134 case ISD::VECREDUCE_SMIN:
8135 case ISD::VECREDUCE_UMIN:
8136 if (SDValue V = lowerVECREDUCE(SDValue(N, 0), DAG))
8137 Results.push_back(V);
8138 break;
8139 case ISD::VP_REDUCE_ADD:
8140 case ISD::VP_REDUCE_AND:
8141 case ISD::VP_REDUCE_OR:
8142 case ISD::VP_REDUCE_XOR:
8143 case ISD::VP_REDUCE_SMAX:
8144 case ISD::VP_REDUCE_UMAX:
8145 case ISD::VP_REDUCE_SMIN:
8146 case ISD::VP_REDUCE_UMIN:
8147 if (SDValue V = lowerVPREDUCE(SDValue(N, 0), DAG))
8148 Results.push_back(V);
8149 break;
8150 case ISD::GET_ROUNDING: {
8151 SDVTList VTs = DAG.getVTList(Subtarget.getXLenVT(), MVT::Other);
8152 SDValue Res = DAG.getNode(ISD::GET_ROUNDING, DL, VTs, N->getOperand(0));
8153 Results.push_back(Res.getValue(0));
8154 Results.push_back(Res.getValue(1));
8155 break;
8156 }
8157 }
8158 }
8159
8160 // Try to fold (<bop> x, (reduction.<bop> vec, start))
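// i.e., when the reduction's start value is the neutral element (e.g. 0 for
// add), x can be folded in as the new start value, roughly:
//   (add x, (extractelt (<reduce> vec, start=0, ...), 0))
//     -> (extractelt (<reduce> vec, start=x, ...), 0)
// which saves the scalar add.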
8161 static SDValue combineBinOpToReduce(SDNode *N, SelectionDAG &DAG,
8162 const RISCVSubtarget &Subtarget) {
8163 auto BinOpToRVVReduce = [](unsigned Opc) {
8164 switch (Opc) {
8165 default:
8166 llvm_unreachable("Unhandled binary to transform reduction");
8167 case ISD::ADD:
8168 return RISCVISD::VECREDUCE_ADD_VL;
8169 case ISD::UMAX:
8170 return RISCVISD::VECREDUCE_UMAX_VL;
8171 case ISD::SMAX:
8172 return RISCVISD::VECREDUCE_SMAX_VL;
8173 case ISD::UMIN:
8174 return RISCVISD::VECREDUCE_UMIN_VL;
8175 case ISD::SMIN:
8176 return RISCVISD::VECREDUCE_SMIN_VL;
8177 case ISD::AND:
8178 return RISCVISD::VECREDUCE_AND_VL;
8179 case ISD::OR:
8180 return RISCVISD::VECREDUCE_OR_VL;
8181 case ISD::XOR:
8182 return RISCVISD::VECREDUCE_XOR_VL;
8183 case ISD::FADD:
8184 return RISCVISD::VECREDUCE_FADD_VL;
8185 case ISD::FMAXNUM:
8186 return RISCVISD::VECREDUCE_FMAX_VL;
8187 case ISD::FMINNUM:
8188 return RISCVISD::VECREDUCE_FMIN_VL;
8189 }
8190 };
8191
8192 auto IsReduction = [&BinOpToRVVReduce](SDValue V, unsigned Opc) {
8193 return V.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
8194 isNullConstant(V.getOperand(1)) &&
8195 V.getOperand(0).getOpcode() == BinOpToRVVReduce(Opc);
8196 };
8197
8198 unsigned Opc = N->getOpcode();
8199 unsigned ReduceIdx;
8200 if (IsReduction(N->getOperand(0), Opc))
8201 ReduceIdx = 0;
8202 else if (IsReduction(N->getOperand(1), Opc))
8203 ReduceIdx = 1;
8204 else
8205 return SDValue();
8206
8207 // Skip if FADD disallows reassociation, which this combine requires.
8208 if (Opc == ISD::FADD && !N->getFlags().hasAllowReassociation())
8209 return SDValue();
8210
8211 SDValue Extract = N->getOperand(ReduceIdx);
8212 SDValue Reduce = Extract.getOperand(0);
8213 if (!Reduce.hasOneUse())
8214 return SDValue();
8215
8216 SDValue ScalarV = Reduce.getOperand(2);
8217 EVT ScalarVT = ScalarV.getValueType();
8218 if (ScalarV.getOpcode() == ISD::INSERT_SUBVECTOR &&
8219 ScalarV.getOperand(0)->isUndef())
8220 ScalarV = ScalarV.getOperand(1);
8221
8222 // Make sure that ScalarV is a splat with VL=1.
8223 if (ScalarV.getOpcode() != RISCVISD::VFMV_S_F_VL &&
8224 ScalarV.getOpcode() != RISCVISD::VMV_S_X_VL &&
8225 ScalarV.getOpcode() != RISCVISD::VMV_V_X_VL)
8226 return SDValue();
8227
8228 if (!hasNonZeroAVL(ScalarV.getOperand(2)))
8229 return SDValue();
8230
8231 // Check that the scalar of ScalarV is the neutral element.
8232 // TODO: Deal with value other than neutral element.
8233 if (!isNeutralConstant(N->getOpcode(), N->getFlags(), ScalarV.getOperand(1),
8234 0))
8235 return SDValue();
8236
8237 if (!ScalarV.hasOneUse())
8238 return SDValue();
8239
8240 SDValue NewStart = N->getOperand(1 - ReduceIdx);
8241
8242 SDLoc DL(N);
8243 SDValue NewScalarV =
8244 lowerScalarInsert(NewStart, ScalarV.getOperand(2),
8245 ScalarV.getSimpleValueType(), DL, DAG, Subtarget);
8246
8247 // If we looked through an INSERT_SUBVECTOR we need to restore it.
8248 if (ScalarVT != ScalarV.getValueType())
8249 NewScalarV =
8250 DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalarVT, DAG.getUNDEF(ScalarVT),
8251 NewScalarV, DAG.getConstant(0, DL, Subtarget.getXLenVT()));
8252
8253 SDValue NewReduce =
8254 DAG.getNode(Reduce.getOpcode(), DL, Reduce.getValueType(),
8255 Reduce.getOperand(0), Reduce.getOperand(1), NewScalarV,
8256 Reduce.getOperand(3), Reduce.getOperand(4));
8257 return DAG.getNode(Extract.getOpcode(), DL, Extract.getValueType(), NewReduce,
8258 Extract.getOperand(1));
8259 }
8260
8261 // Optimize (add (shl x, c0), (shl y, c1)) ->
8262 // (SLLI (SH*ADD x, y), c0), if c1-c0 equals 1, 2, or 3.
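// For example (with Zba), (add (shl x, 5), (shl y, 7)) is rewritten as
// (shl (add (shl y, 2), x), 5), which instruction selection can then match as
// a sh2add followed by a slli.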
8263 static SDValue transformAddShlImm(SDNode *N, SelectionDAG &DAG,
8264 const RISCVSubtarget &Subtarget) {
8265 // Perform this optimization only in the zba extension.
8266 if (!Subtarget.hasStdExtZba())
8267 return SDValue();
8268
8269 // Skip for vector types and larger types.
8270 EVT VT = N->getValueType(0);
8271 if (VT.isVector() || VT.getSizeInBits() > Subtarget.getXLen())
8272 return SDValue();
8273
8274 // The two operand nodes must be SHL and have no other use.
8275 SDValue N0 = N->getOperand(0);
8276 SDValue N1 = N->getOperand(1);
8277 if (N0->getOpcode() != ISD::SHL || N1->getOpcode() != ISD::SHL ||
8278 !N0->hasOneUse() || !N1->hasOneUse())
8279 return SDValue();
8280
8281 // Check c0 and c1.
8282 auto *N0C = dyn_cast<ConstantSDNode>(N0->getOperand(1));
8283 auto *N1C = dyn_cast<ConstantSDNode>(N1->getOperand(1));
8284 if (!N0C || !N1C)
8285 return SDValue();
8286 int64_t C0 = N0C->getSExtValue();
8287 int64_t C1 = N1C->getSExtValue();
8288 if (C0 <= 0 || C1 <= 0)
8289 return SDValue();
8290
8291 // Skip if SH1ADD/SH2ADD/SH3ADD are not applicable.
8292 int64_t Bits = std::min(C0, C1);
8293 int64_t Diff = std::abs(C0 - C1);
8294 if (Diff != 1 && Diff != 2 && Diff != 3)
8295 return SDValue();
8296
8297 // Build nodes.
8298 SDLoc DL(N);
8299 SDValue NS = (C0 < C1) ? N0->getOperand(0) : N1->getOperand(0);
8300 SDValue NL = (C0 > C1) ? N0->getOperand(0) : N1->getOperand(0);
8301 SDValue NA0 =
8302 DAG.getNode(ISD::SHL, DL, VT, NL, DAG.getConstant(Diff, DL, VT));
8303 SDValue NA1 = DAG.getNode(ISD::ADD, DL, VT, NA0, NS);
8304 return DAG.getNode(ISD::SHL, DL, VT, NA1, DAG.getConstant(Bits, DL, VT));
8305 }
8306
8307 // Combine a constant select operand into its use:
8308 //
8309 // (and (select cond, -1, c), x)
8310 // -> (select cond, x, (and x, c)) [AllOnes=1]
8311 // (or (select cond, 0, c), x)
8312 // -> (select cond, x, (or x, c)) [AllOnes=0]
8313 // (xor (select cond, 0, c), x)
8314 // -> (select cond, x, (xor x, c)) [AllOnes=0]
8315 // (add (select cond, 0, c), x)
8316 // -> (select cond, x, (add x, c)) [AllOnes=0]
8317 // (sub x, (select cond, 0, c))
8318 // -> (select cond, x, (sub x, c)) [AllOnes=0]
8319 static SDValue combineSelectAndUse(SDNode *N, SDValue Slct, SDValue OtherOp,
8320 SelectionDAG &DAG, bool AllOnes,
8321 const RISCVSubtarget &Subtarget) {
8322 EVT VT = N->getValueType(0);
8323
8324 // Skip vectors.
8325 if (VT.isVector())
8326 return SDValue();
8327
8328 if (!Subtarget.hasShortForwardBranchOpt() ||
8329 (Slct.getOpcode() != ISD::SELECT &&
8330 Slct.getOpcode() != RISCVISD::SELECT_CC) ||
8331 !Slct.hasOneUse())
8332 return SDValue();
8333
8334 auto isZeroOrAllOnes = [](SDValue N, bool AllOnes) {
8335 return AllOnes ? isAllOnesConstant(N) : isNullConstant(N);
8336 };
8337
8338 bool SwapSelectOps;
8339 unsigned OpOffset = Slct.getOpcode() == RISCVISD::SELECT_CC ? 2 : 0;
8340 SDValue TrueVal = Slct.getOperand(1 + OpOffset);
8341 SDValue FalseVal = Slct.getOperand(2 + OpOffset);
8342 SDValue NonConstantVal;
8343 if (isZeroOrAllOnes(TrueVal, AllOnes)) {
8344 SwapSelectOps = false;
8345 NonConstantVal = FalseVal;
8346 } else if (isZeroOrAllOnes(FalseVal, AllOnes)) {
8347 SwapSelectOps = true;
8348 NonConstantVal = TrueVal;
8349 } else
8350 return SDValue();
8351
8352   // Slct is now known to be the desired identity constant when CC is true.
8353 TrueVal = OtherOp;
8354 FalseVal = DAG.getNode(N->getOpcode(), SDLoc(N), VT, OtherOp, NonConstantVal);
8355 // Unless SwapSelectOps says the condition should be false.
8356 if (SwapSelectOps)
8357 std::swap(TrueVal, FalseVal);
8358
8359 if (Slct.getOpcode() == RISCVISD::SELECT_CC)
8360 return DAG.getNode(RISCVISD::SELECT_CC, SDLoc(N), VT,
8361 {Slct.getOperand(0), Slct.getOperand(1),
8362 Slct.getOperand(2), TrueVal, FalseVal});
8363
8364 return DAG.getNode(ISD::SELECT, SDLoc(N), VT,
8365 {Slct.getOperand(0), TrueVal, FalseVal});
8366 }
8367
8368 // Attempt combineSelectAndUse on each operand of a commutative operator N.
8369 static SDValue combineSelectAndUseCommutative(SDNode *N, SelectionDAG &DAG,
8370 bool AllOnes,
8371 const RISCVSubtarget &Subtarget) {
8372 SDValue N0 = N->getOperand(0);
8373 SDValue N1 = N->getOperand(1);
8374 if (SDValue Result = combineSelectAndUse(N, N0, N1, DAG, AllOnes, Subtarget))
8375 return Result;
8376 if (SDValue Result = combineSelectAndUse(N, N1, N0, DAG, AllOnes, Subtarget))
8377 return Result;
8378 return SDValue();
8379 }
8380
8381 // Transform (add (mul x, c0), c1) ->
8382 // (add (mul (add x, c1/c0), c0), c1%c0).
8383 // if c1/c0 and c1%c0 are simm12, while c1 is not. A special corner case
8384 // that should be excluded is when c0*(c1/c0) is simm12, which will lead
8385 // to an infinite loop in DAGCombine if transformed.
8386 // Or transform (add (mul x, c0), c1) ->
8387 // (add (mul (add x, c1/c0+1), c0), c1%c0-c0),
8388 // if c1/c0+1 and c1%c0-c0 are simm12, while c1 is not. A special corner
8389 // case that should be excluded is when c0*(c1/c0+1) is simm12, which will
8390 // lead to an infinite loop in DAGCombine if transformed.
8391 // Or transform (add (mul x, c0), c1) ->
8392 // (add (mul (add x, c1/c0-1), c0), c1%c0+c0),
8393 // if c1/c0-1 and c1%c0+c0 are simm12, while c1 is not. A special corner
8394 // case that should be excluded is when c0*(c1/c0-1) is simm12, which will
8395 // lead to an infinite loop in DAGCombine if transformed.
8396 // Or transform (add (mul x, c0), c1) ->
8397 // (mul (add x, c1/c0), c0).
8398 // if c1%c0 is zero, and c1/c0 is simm12 while c1 is not.
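// Worked example (illustrative only): with c0 = 100 and c1 = 204723, c1 is not
// simm12, but c1/c0 = 2047 and c1%c0 = 23 are, and c0*(c1/c0) = 204700 is not,
// so the first form applies:
//   (add (mul x, 100), 204723) -> (add (mul (add x, 2047), 100), 23)
// which replaces materializing the large constant with two simm12 ADDIs.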
8399 static SDValue transformAddImmMulImm(SDNode *N, SelectionDAG &DAG,
8400 const RISCVSubtarget &Subtarget) {
8401 // Skip for vector types and larger types.
8402 EVT VT = N->getValueType(0);
8403 if (VT.isVector() || VT.getSizeInBits() > Subtarget.getXLen())
8404 return SDValue();
8405   // The first operand node must be a MUL and have no other use.
8406 SDValue N0 = N->getOperand(0);
8407 if (!N0->hasOneUse() || N0->getOpcode() != ISD::MUL)
8408 return SDValue();
8409 // Check if c0 and c1 match above conditions.
8410 auto *N0C = dyn_cast<ConstantSDNode>(N0->getOperand(1));
8411 auto *N1C = dyn_cast<ConstantSDNode>(N->getOperand(1));
8412 if (!N0C || !N1C)
8413 return SDValue();
8414 // If N0C has multiple uses it's possible one of the cases in
8415 // DAGCombiner::isMulAddWithConstProfitable will be true, which would result
8416 // in an infinite loop.
8417 if (!N0C->hasOneUse())
8418 return SDValue();
8419 int64_t C0 = N0C->getSExtValue();
8420 int64_t C1 = N1C->getSExtValue();
8421 int64_t CA, CB;
8422 if (C0 == -1 || C0 == 0 || C0 == 1 || isInt<12>(C1))
8423 return SDValue();
8424 // Search for proper CA (non-zero) and CB that both are simm12.
8425 if ((C1 / C0) != 0 && isInt<12>(C1 / C0) && isInt<12>(C1 % C0) &&
8426 !isInt<12>(C0 * (C1 / C0))) {
8427 CA = C1 / C0;
8428 CB = C1 % C0;
8429 } else if ((C1 / C0 + 1) != 0 && isInt<12>(C1 / C0 + 1) &&
8430 isInt<12>(C1 % C0 - C0) && !isInt<12>(C0 * (C1 / C0 + 1))) {
8431 CA = C1 / C0 + 1;
8432 CB = C1 % C0 - C0;
8433 } else if ((C1 / C0 - 1) != 0 && isInt<12>(C1 / C0 - 1) &&
8434 isInt<12>(C1 % C0 + C0) && !isInt<12>(C0 * (C1 / C0 - 1))) {
8435 CA = C1 / C0 - 1;
8436 CB = C1 % C0 + C0;
8437 } else
8438 return SDValue();
8439 // Build new nodes (add (mul (add x, c1/c0), c0), c1%c0).
8440 SDLoc DL(N);
8441 SDValue New0 = DAG.getNode(ISD::ADD, DL, VT, N0->getOperand(0),
8442 DAG.getConstant(CA, DL, VT));
8443 SDValue New1 =
8444 DAG.getNode(ISD::MUL, DL, VT, New0, DAG.getConstant(C0, DL, VT));
8445 return DAG.getNode(ISD::ADD, DL, VT, New1, DAG.getConstant(CB, DL, VT));
8446 }
8447
8448 static SDValue performADDCombine(SDNode *N, SelectionDAG &DAG,
8449 const RISCVSubtarget &Subtarget) {
8450 if (SDValue V = transformAddImmMulImm(N, DAG, Subtarget))
8451 return V;
8452 if (SDValue V = transformAddShlImm(N, DAG, Subtarget))
8453 return V;
8454 if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
8455 return V;
8456 // fold (add (select lhs, rhs, cc, 0, y), x) ->
8457 // (select lhs, rhs, cc, x, (add x, y))
8458 return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ false, Subtarget);
8459 }
8460
8461 // Try to turn a sub with a boolean RHS and a constant LHS into an ADDI.
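// For example (illustrative): (sub 5, (setcc x, y, eq)) becomes
// (add (setcc x, y, ne), 4), because 5 - (x == y) == (x != y) + 4.
// The rewritten form only needs an ADDI as long as the new immediate
// (constant - 1) fits in simm12.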
8462 static SDValue combineSubOfBoolean(SDNode *N, SelectionDAG &DAG) {
8463 SDValue N0 = N->getOperand(0);
8464 SDValue N1 = N->getOperand(1);
8465 EVT VT = N->getValueType(0);
8466 SDLoc DL(N);
8467
8468 // Require a constant LHS.
8469 auto *N0C = dyn_cast<ConstantSDNode>(N0);
8470 if (!N0C)
8471 return SDValue();
8472
8473 // All our optimizations involve subtracting 1 from the immediate and forming
8474 // an ADDI. Make sure the new immediate is valid for an ADDI.
8475 APInt ImmValMinus1 = N0C->getAPIntValue() - 1;
8476 if (!ImmValMinus1.isSignedIntN(12))
8477 return SDValue();
8478
8479 SDValue NewLHS;
8480 if (N1.getOpcode() == ISD::SETCC && N1.hasOneUse()) {
8481 // (sub constant, (setcc x, y, eq/neq)) ->
8482 // (add (setcc x, y, neq/eq), constant - 1)
8483 ISD::CondCode CCVal = cast<CondCodeSDNode>(N1.getOperand(2))->get();
8484 EVT SetCCOpVT = N1.getOperand(0).getValueType();
8485 if (!isIntEqualitySetCC(CCVal) || !SetCCOpVT.isInteger())
8486 return SDValue();
8487 CCVal = ISD::getSetCCInverse(CCVal, SetCCOpVT);
8488 NewLHS =
8489 DAG.getSetCC(SDLoc(N1), VT, N1.getOperand(0), N1.getOperand(1), CCVal);
8490 } else if (N1.getOpcode() == ISD::XOR && isOneConstant(N1.getOperand(1)) &&
8491 N1.getOperand(0).getOpcode() == ISD::SETCC) {
8492 // (sub C, (xor (setcc), 1)) -> (add (setcc), C-1).
8493 // Since setcc returns a bool the xor is equivalent to 1-setcc.
8494 NewLHS = N1.getOperand(0);
8495 } else
8496 return SDValue();
8497
8498 SDValue NewRHS = DAG.getConstant(ImmValMinus1, DL, VT);
8499 return DAG.getNode(ISD::ADD, DL, VT, NewLHS, NewRHS);
8500 }
8501
8502 static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG,
8503 const RISCVSubtarget &Subtarget) {
8504 if (SDValue V = combineSubOfBoolean(N, DAG))
8505 return V;
8506
8507 // fold (sub x, (select lhs, rhs, cc, 0, y)) ->
8508 // (select lhs, rhs, cc, x, (sub x, y))
8509 SDValue N0 = N->getOperand(0);
8510 SDValue N1 = N->getOperand(1);
8511 return combineSelectAndUse(N, N1, N0, DAG, /*AllOnes*/ false, Subtarget);
8512 }
8513
8514 // Apply DeMorgan's law to (and/or (xor X, 1), (xor Y, 1)) if X and Y are 0/1.
8515 // Legalizing setcc can introduce xors like this. Doing this transform reduces
8516 // the number of xors and may allow the xor to fold into a branch condition.
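// For example (assuming X and Y are known to be 0 or 1):
//   (and (xor X, 1), (xor Y, 1)) -> (xor (or X, Y), 1)
//   (or  (xor X, 1), (xor Y, 1)) -> (xor (and X, Y), 1)
// i.e. two logical NOTs are replaced by a single one.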
8517 static SDValue combineDeMorganOfBoolean(SDNode *N, SelectionDAG &DAG) {
8518 SDValue N0 = N->getOperand(0);
8519 SDValue N1 = N->getOperand(1);
8520 bool IsAnd = N->getOpcode() == ISD::AND;
8521
8522 if (N0.getOpcode() != ISD::XOR || N1.getOpcode() != ISD::XOR)
8523 return SDValue();
8524
8525 if (!N0.hasOneUse() || !N1.hasOneUse())
8526 return SDValue();
8527
8528 SDValue N01 = N0.getOperand(1);
8529 SDValue N11 = N1.getOperand(1);
8530
8531 // For AND, SimplifyDemandedBits may have turned one of the (xor X, 1) into
8532 // (xor X, -1) based on the upper bits of the other operand being 0. If the
8533 // operation is And, allow one of the Xors to use -1.
8534 if (isOneConstant(N01)) {
8535 if (!isOneConstant(N11) && !(IsAnd && isAllOnesConstant(N11)))
8536 return SDValue();
8537 } else if (isOneConstant(N11)) {
8538 // N01 and N11 being 1 was already handled. Handle N11==1 and N01==-1.
8539 if (!(IsAnd && isAllOnesConstant(N01)))
8540 return SDValue();
8541 } else
8542 return SDValue();
8543
8544 EVT VT = N->getValueType(0);
8545
8546 SDValue N00 = N0.getOperand(0);
8547 SDValue N10 = N1.getOperand(0);
8548
8549 // The LHS of the xors needs to be 0/1.
8550 APInt Mask = APInt::getBitsSetFrom(VT.getSizeInBits(), 1);
8551 if (!DAG.MaskedValueIsZero(N00, Mask) || !DAG.MaskedValueIsZero(N10, Mask))
8552 return SDValue();
8553
8554 // Invert the opcode and insert a new xor.
8555 SDLoc DL(N);
8556 unsigned Opc = IsAnd ? ISD::OR : ISD::AND;
8557 SDValue Logic = DAG.getNode(Opc, DL, VT, N00, N10);
8558 return DAG.getNode(ISD::XOR, DL, VT, Logic, DAG.getConstant(1, DL, VT));
8559 }
8560
8561 static SDValue performTRUNCATECombine(SDNode *N, SelectionDAG &DAG,
8562 const RISCVSubtarget &Subtarget) {
8563 SDValue N0 = N->getOperand(0);
8564 EVT VT = N->getValueType(0);
8565
8566 // Pre-promote (i1 (truncate (srl X, Y))) on RV64 with Zbs without zero
8567 // extending X. This is safe since we only need the LSB after the shift and
8568 // shift amounts larger than 31 would produce poison. If we wait until
8569 // type legalization, we'll create RISCVISD::SRLW and we can't recover it
8570 // to use a BEXT instruction.
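  // For example (illustrative), (i1 (truncate (srl X, Y))) is rewritten here
  // to (truncate (srl (any_extend X), (zero_extend Y))) on i64 so that isel
  // can later form a single BEXT instead of a shift pair.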
8571 if (Subtarget.is64Bit() && Subtarget.hasStdExtZbs() && VT == MVT::i1 &&
8572 N0.getValueType() == MVT::i32 && N0.getOpcode() == ISD::SRL &&
8573 !isa<ConstantSDNode>(N0.getOperand(1)) && N0.hasOneUse()) {
8574 SDLoc DL(N0);
8575 SDValue Op0 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N0.getOperand(0));
8576 SDValue Op1 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, N0.getOperand(1));
8577 SDValue Srl = DAG.getNode(ISD::SRL, DL, MVT::i64, Op0, Op1);
8578 return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Srl);
8579 }
8580
8581 return SDValue();
8582 }
8583
8584 namespace {
8585 // Helper class that contains information about a comparison operation.
8586 // The first two operands of the operation are the compared values and the
8587 // last one is the condition code.
8588 // The compared values are stored in Ops.
8589 // The condition code is stored in CCode.
8590 class CmpOpInfo {
8591 static unsigned constexpr Size = 2u;
8592
8593 // Type for storing operands of compare operation.
8594 using OpsArray = std::array<SDValue, Size>;
8595 OpsArray Ops;
8596
8597 using const_iterator = OpsArray::const_iterator;
8598   const_iterator begin() const { return Ops.begin(); }
8599   const_iterator end() const { return Ops.end(); }
8600
8601 ISD::CondCode CCode;
8602
8603 unsigned CommonPos{Size};
8604 unsigned DifferPos{Size};
8605
8606 // Sets CommonPos and DifferPos based on incoming position
8607 // of common operand CPos.
8608   void setPositions(const_iterator CPos) {
8609 assert(CPos != Ops.end() && "Common operand has to be in OpsArray.\n");
8610 CommonPos = CPos == Ops.begin() ? 0 : 1;
8611 DifferPos = 1 - CommonPos;
8612 assert((DifferPos == 0 || DifferPos == 1) &&
8613 "Positions can be only 0 or 1.");
8614 }
8615
8616   // Private constructor of comparison info based on a comparison operation.
8617   // It is private because a CmpOpInfo is only meaningful relative to another
8618   // comparison operation. Therefore, info about a pair of comparison
8619   // operations has to be collected simultaneously via CmpOpInfo::getInfoAbout().
8620   CmpOpInfo(const SDValue &CmpOp)
8621 : Ops{CmpOp.getOperand(0), CmpOp.getOperand(1)},
8622 CCode{cast<CondCodeSDNode>(CmpOp.getOperand(2))->get()} {}
8623
8624   // Finds the common operand of Op1 and Op2 and finishes filling both
8625   // CmpOpInfos. Returns true if a common operand is found, false otherwise.
8626   static bool establishCorrespondence(CmpOpInfo &Op1, CmpOpInfo &Op2) {
8627 const auto CommonOpIt1 =
8628 std::find_first_of(Op1.begin(), Op1.end(), Op2.begin(), Op2.end());
8629 if (CommonOpIt1 == Op1.end())
8630 return false;
8631
8632 const auto CommonOpIt2 = std::find(Op2.begin(), Op2.end(), *CommonOpIt1);
8633 assert(CommonOpIt2 != Op2.end() &&
8634 "Cannot find common operand in the second comparison operation.");
8635
8636 Op1.setPositions(CommonOpIt1);
8637 Op2.setPositions(CommonOpIt2);
8638
8639 return true;
8640 }
8641
8642 public:
8643 CmpOpInfo(const CmpOpInfo &) = default;
8644 CmpOpInfo(CmpOpInfo &&) = default;
8645
8646   SDValue const &operator[](unsigned Pos) const {
8647 assert(Pos < Size && "Out of range\n");
8648 return Ops[Pos];
8649 }
8650
8651   // Creates info about the comparison operations CmpOp0 and CmpOp1.
8652   // If there is no common operand, returns std::nullopt. Otherwise, returns
8653   // the correspondence info about the comparison operations.
8654 static std::optional<std::pair<CmpOpInfo, CmpOpInfo>>
8655   getInfoAbout(SDValue const &CmpOp0, SDValue const &CmpOp1) {
8656 CmpOpInfo Op0{CmpOp0};
8657 CmpOpInfo Op1{CmpOp1};
8658 if (!establishCorrespondence(Op0, Op1))
8659 return std::nullopt;
8660 return std::make_pair(Op0, Op1);
8661 }
8662
8663   // Returns the position of the common operand.
8664   unsigned getCPos() const { return CommonPos; }
8665
8666   // Returns the position of the differing operand.
8667   unsigned getDPos() const { return DifferPos; }
8668
8669   // Returns the common operand.
8670   SDValue const &getCOp() const { return operator[](CommonPos); }
8671
8672   // Returns the differing operand.
8673   SDValue const &getDOp() const { return operator[](DifferPos); }
8674
8675   // Returns the condition code of the comparison operation.
8676   ISD::CondCode getCondCode() const { return CCode; }
8677 };
8678 } // namespace
8679
8680 // Verifies the conditions needed to apply the optimization.
8681 // Returns the reference comparison code and three operands A, B, C.
8682 // Conditions for the optimization:
8683 //   One operand of the comparisons has to be common.
8684 //   This operand is written to C.
8685 //   The two other operands are different. They are written to A and B.
8686 //   The comparisons have to be similar with respect to the common operand C.
8687 //     e.g. A < C; C > B are similar
8688 //     but A < C; B > C are not.
8689 // The reference comparison code is the comparison code when the common
8690 // operand is placed on the right-hand side.
8691 // e.g. C > A will be swapped to A < C.
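// Illustrative example: for (setcc a, c, setlt) and (setcc c, b, setgt) the
// common operand is c. The second comparison is canonicalized to
// (setcc b, c, setlt), so RefCond = setlt, A = a, B = b and C = c.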
8692 static std::optional<std::tuple<ISD::CondCode, SDValue, SDValue, SDValue>>
8693 verifyCompareConds(SDNode *N, SelectionDAG &DAG) {
8694 LLVM_DEBUG(
8695 dbgs() << "Checking conditions for comparison operation combining.\n";);
8696
8697 SDValue V0 = N->getOperand(0);
8698 SDValue V1 = N->getOperand(1);
8699 assert(V0.getValueType() == V1.getValueType() &&
8700 "Operations must have the same value type.");
8701
8702   // Condition 1. The comparisons must only be used by this logic operation.
8703 if (!V0.hasOneUse() || !V1.hasOneUse())
8704 return std::nullopt;
8705
8706 // Condition 2. Operands have to be comparison operations.
8707 if (V0.getOpcode() != ISD::SETCC || V1.getOpcode() != ISD::SETCC)
8708 return std::nullopt;
8709
8710   // Condition 3.1. Only integer comparisons are handled.
8711 if (!V0.getOperand(0).getValueType().isInteger())
8712 return std::nullopt;
8713
8714 const auto ComparisonInfo = CmpOpInfo::getInfoAbout(V0, V1);
8715   // Condition 3.2. The two comparisons must share a common operand.
8716 if (!ComparisonInfo)
8717 return std::nullopt;
8718
8719 const auto [Op0, Op1] = ComparisonInfo.value();
8720
8721 LLVM_DEBUG(dbgs() << "Shared operands are on positions: " << Op0.getCPos()
8722 << " and " << Op1.getCPos() << '\n';);
8723   // If the common operand is at the first position, swap the comparison so it
8724   // matches the canonical pattern: the common operand must be the right-hand side.
8725 ISD::CondCode RefCond = Op0.getCondCode();
8726 ISD::CondCode AssistCode = Op1.getCondCode();
8727 if (!Op0.getCPos())
8728 RefCond = ISD::getSetCCSwappedOperands(RefCond);
8729 if (!Op1.getCPos())
8730 AssistCode = ISD::getSetCCSwappedOperands(AssistCode);
8731 LLVM_DEBUG(dbgs() << "Reference condition is: " << RefCond << '\n';);
8732   // If the canonicalized comparison codes differ, do not perform the
8733   // optimization. E.g. a < c; c < b canonicalizes to a < c; b > c.
8734 if (RefCond != AssistCode)
8735 return std::nullopt;
8736
8737 // Conditions can be only similar to Less or Greater. (>, >=, <, <=)
8738 // Applying this mask to the operation will determine Less and Greater
8739 // operations.
8740 const unsigned CmpMask = 0b110;
8741 const unsigned MaskedOpcode = CmpMask & RefCond;
8742 // If masking gave 0b110, then this is an operation NE, O or TRUE.
8743 if (MaskedOpcode == CmpMask)
8744 return std::nullopt;
8745   // If masking gave 0b000, then this is an operation EQ, UO or FALSE.
8746 if (MaskedOpcode == 0)
8747 return std::nullopt;
8748 // Everything else is similar to Less or Greater.
8749
8750 SDValue A = Op0.getDOp();
8751 SDValue B = Op1.getDOp();
8752 SDValue C = Op0.getCOp();
8753
8754 LLVM_DEBUG(
8755 dbgs() << "The conditions for combining comparisons are satisfied.\n";);
8756 return std::make_tuple(RefCond, A, B, C);
8757 }
8758
8759 static ISD::NodeType getSelectionCode(bool IsUnsigned, bool IsAnd,
8760 bool IsGreaterOp) {
8761 // Codes of selection operation. The first index selects signed or unsigned,
8762 // the second index selects MIN/MAX.
8763 static constexpr ISD::NodeType SelectionCodes[2][2] = {
8764 {ISD::SMIN, ISD::SMAX}, {ISD::UMIN, ISD::UMAX}};
8765 const bool ChooseSelCode = IsAnd ^ IsGreaterOp;
8766 return SelectionCodes[IsUnsigned][ChooseSelCode];
8767 }
8768
8769 // Combines two comparison operations and a logic operation into one
8770 // selection operation (min/max) and one comparison. Returns the newly
8771 // constructed node if the conditions for the optimization are satisfied.
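// Illustrative example (requires Zbb):
//   (and (setcc a, c, setlt), (setcc b, c, setlt))
//     -> (setcc (smax a, b), c, setlt)
// since both a < c and b < c hold if and only if max(a, b) < c.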
8772 static SDValue combineCmpOp(SDNode *N, SelectionDAG &DAG,
8773 const RISCVSubtarget &Subtarget) {
8774 if (!Subtarget.hasStdExtZbb())
8775 return SDValue();
8776
8777 const unsigned BitOpcode = N->getOpcode();
8778 assert((BitOpcode == ISD::AND || BitOpcode == ISD::OR) &&
8779 "This optimization can be used only with AND/OR operations");
8780
8781 const auto Props = verifyCompareConds(N, DAG);
8782   // If the conditions are not satisfied, do not perform the optimization.
8783 if (!Props)
8784 return SDValue();
8785
8786 const auto [RefOpcode, A, B, C] = Props.value();
8787 const EVT CmpOpVT = A.getValueType();
8788
8789 const bool IsGreaterOp = RefOpcode & 0b10;
8790 const bool IsUnsigned = ISD::isUnsignedIntSetCC(RefOpcode);
8791 assert((IsUnsigned || ISD::isSignedIntSetCC(RefOpcode)) &&
8792          "Comparison is neither signed nor unsigned.");
8793
8794 const bool IsAnd = BitOpcode == ISD::AND;
8795 const ISD::NodeType PickCode =
8796 getSelectionCode(IsUnsigned, IsAnd, IsGreaterOp);
8797
8798 SDLoc DL(N);
8799 SDValue Pick = DAG.getNode(PickCode, DL, CmpOpVT, A, B);
8800 SDValue Cmp =
8801 DAG.getSetCC(DL, N->getOperand(0).getValueType(), Pick, C, RefOpcode);
8802
8803 return Cmp;
8804 }
8805
8806 static SDValue performANDCombine(SDNode *N,
8807 TargetLowering::DAGCombinerInfo &DCI,
8808 const RISCVSubtarget &Subtarget) {
8809 SelectionDAG &DAG = DCI.DAG;
8810
8811 SDValue N0 = N->getOperand(0);
8812 // Pre-promote (i32 (and (srl X, Y), 1)) on RV64 with Zbs without zero
8813 // extending X. This is safe since we only need the LSB after the shift and
8814 // shift amounts larger than 31 would produce poison. If we wait until
8815 // type legalization, we'll create RISCVISD::SRLW and we can't recover it
8816 // to use a BEXT instruction.
8817 if (Subtarget.is64Bit() && Subtarget.hasStdExtZbs() &&
8818 N->getValueType(0) == MVT::i32 && isOneConstant(N->getOperand(1)) &&
8819 N0.getOpcode() == ISD::SRL && !isa<ConstantSDNode>(N0.getOperand(1)) &&
8820 N0.hasOneUse()) {
8821 SDLoc DL(N);
8822 SDValue Op0 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N0.getOperand(0));
8823 SDValue Op1 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, N0.getOperand(1));
8824 SDValue Srl = DAG.getNode(ISD::SRL, DL, MVT::i64, Op0, Op1);
8825 SDValue And = DAG.getNode(ISD::AND, DL, MVT::i64, Srl,
8826 DAG.getConstant(1, DL, MVT::i64));
8827 return DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, And);
8828 }
8829
8830 if (SDValue V = combineCmpOp(N, DAG, Subtarget))
8831 return V;
8832
8833 if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
8834 return V;
8835
8836 if (DCI.isAfterLegalizeDAG())
8837 if (SDValue V = combineDeMorganOfBoolean(N, DAG))
8838 return V;
8839
8840 // fold (and (select lhs, rhs, cc, -1, y), x) ->
8841 // (select lhs, rhs, cc, x, (and x, y))
8842 return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ true, Subtarget);
8843 }
8844
8845 static SDValue performORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
8846 const RISCVSubtarget &Subtarget) {
8847 SelectionDAG &DAG = DCI.DAG;
8848
8849 if (SDValue V = combineCmpOp(N, DAG, Subtarget))
8850 return V;
8851
8852 if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
8853 return V;
8854
8855 if (DCI.isAfterLegalizeDAG())
8856 if (SDValue V = combineDeMorganOfBoolean(N, DAG))
8857 return V;
8858
8859 // fold (or (select cond, 0, y), x) ->
8860 // (select cond, x, (or x, y))
8861 return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ false, Subtarget);
8862 }
8863
8864 static SDValue performXORCombine(SDNode *N, SelectionDAG &DAG,
8865 const RISCVSubtarget &Subtarget) {
8866 SDValue N0 = N->getOperand(0);
8867 SDValue N1 = N->getOperand(1);
8868
8869 // fold (xor (sllw 1, x), -1) -> (rolw ~1, x)
8870 // NOTE: Assumes ROL being legal means ROLW is legal.
8871 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
8872 if (N0.getOpcode() == RISCVISD::SLLW &&
8873 isAllOnesConstant(N1) && isOneConstant(N0.getOperand(0)) &&
8874 TLI.isOperationLegal(ISD::ROTL, MVT::i64)) {
8875 SDLoc DL(N);
8876 return DAG.getNode(RISCVISD::ROLW, DL, MVT::i64,
8877 DAG.getConstant(~1, DL, MVT::i64), N0.getOperand(1));
8878 }
8879
8880 if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
8881 return V;
8882 // fold (xor (select cond, 0, y), x) ->
8883 // (select cond, x, (xor x, y))
8884 return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ false, Subtarget);
8885 }
8886
8887 // Replace (seteq (i64 (and X, 0xffffffff)), C1) with
8888 // (seteq (i64 (sext_inreg (X, i32)), C1')) where C1' is C1 sign extended from
8889 // bit 31. Same for setne. C1' may be cheaper to materialize and the sext_inreg
8890 // can become a sext.w instead of a shift pair.
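// For example (assuming the sign bit of X's low 32 bits is not known zero):
//   (seteq (and X, 0xffffffff), 0x80000000)
//     -> (seteq (sext_inreg X, i32), 0xffffffff80000000)
// where the new constant is a single LUI and the sext_inreg a single sext.w.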
8891 static SDValue performSETCCCombine(SDNode *N, SelectionDAG &DAG,
8892 const RISCVSubtarget &Subtarget) {
8893 SDValue N0 = N->getOperand(0);
8894 SDValue N1 = N->getOperand(1);
8895 EVT VT = N->getValueType(0);
8896 EVT OpVT = N0.getValueType();
8897
8898 if (OpVT != MVT::i64 || !Subtarget.is64Bit())
8899 return SDValue();
8900
8901 // RHS needs to be a constant.
8902 auto *N1C = dyn_cast<ConstantSDNode>(N1);
8903 if (!N1C)
8904 return SDValue();
8905
8906 // LHS needs to be (and X, 0xffffffff).
8907 if (N0.getOpcode() != ISD::AND || !N0.hasOneUse() ||
8908 !isa<ConstantSDNode>(N0.getOperand(1)) ||
8909 N0.getConstantOperandVal(1) != UINT64_C(0xffffffff))
8910 return SDValue();
8911
8912 // Looking for an equality compare.
8913 ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get();
8914 if (!isIntEqualitySetCC(Cond))
8915 return SDValue();
8916
8917 // Don't do this if the sign bit is provably zero, it will be turned back into
8918 // an AND.
8919 APInt SignMask = APInt::getOneBitSet(64, 31);
8920 if (DAG.MaskedValueIsZero(N0.getOperand(0), SignMask))
8921 return SDValue();
8922
8923 const APInt &C1 = N1C->getAPIntValue();
8924
8925 SDLoc dl(N);
8926 // If the constant is larger than 2^32 - 1 it is impossible for both sides
8927 // to be equal.
8928 if (C1.getActiveBits() > 32)
8929 return DAG.getBoolConstant(Cond == ISD::SETNE, dl, VT, OpVT);
8930
8931 SDValue SExtOp = DAG.getNode(ISD::SIGN_EXTEND_INREG, N, OpVT,
8932 N0.getOperand(0), DAG.getValueType(MVT::i32));
8933 return DAG.getSetCC(dl, VT, SExtOp, DAG.getConstant(C1.trunc(32).sext(64),
8934 dl, OpVT), Cond);
8935 }
8936
8937 static SDValue
8938 performSIGN_EXTEND_INREGCombine(SDNode *N, SelectionDAG &DAG,
8939 const RISCVSubtarget &Subtarget) {
8940 SDValue Src = N->getOperand(0);
8941 EVT VT = N->getValueType(0);
8942
8943 // Fold (sext_inreg (fmv_x_anyexth X), i16) -> (fmv_x_signexth X)
8944 if (Src.getOpcode() == RISCVISD::FMV_X_ANYEXTH &&
8945 cast<VTSDNode>(N->getOperand(1))->getVT().bitsGE(MVT::i16))
8946 return DAG.getNode(RISCVISD::FMV_X_SIGNEXTH, SDLoc(N), VT,
8947 Src.getOperand(0));
8948
8949 return SDValue();
8950 }
8951
8952 namespace {
8953 // Forward declaration of the structure holding the necessary information to
8954 // apply a combine.
8955 struct CombineResult;
8956
8957 /// Helper class for folding sign/zero extensions.
8958 /// In particular, this class is used for the following combines:
8959 /// add_vl -> vwadd(u) | vwadd(u)_w
8960 /// sub_vl -> vwsub(u) | vwsub(u)_w
8961 /// mul_vl -> vwmul(u) | vwmul_su
8962 ///
8963 /// An object of this class represents an operand of the operation we want to
8964 /// combine.
8965 /// E.g., when trying to combine `mul_vl a, b`, we will have one instance of
8966 /// NodeExtensionHelper for `a` and one for `b`.
8967 ///
8968 /// This class abstracts away how the extension is materialized and
8969 /// how its Mask, VL, number of users affect the combines.
8970 ///
8971 /// In particular:
8972 /// - VWADD_W is conceptually == add(op0, sext(op1))
8973 /// - VWADDU_W == add(op0, zext(op1))
8974 /// - VWSUB_W == sub(op0, sext(op1))
8975 /// - VWSUBU_W == sub(op0, zext(op1))
8976 ///
8977 /// And VMV_V_X_VL, depending on the value, is conceptually equivalent to
8978 /// zext|sext(smaller_value).
8979 struct NodeExtensionHelper {
8980 /// Records if this operand is like being zero extended.
8981 bool SupportsZExt;
8982 /// Records if this operand is like being sign extended.
8983 /// Note: SupportsZExt and SupportsSExt are not mutually exclusive. For
8984 /// instance, a splat constant (e.g., 3), would support being both sign and
8985 /// zero extended.
8986 bool SupportsSExt;
8987 /// This boolean captures whether we care if this operand would still be
8988 /// around after the folding happens.
8989 bool EnforceOneUse;
8990 /// Records if this operand's mask needs to match the mask of the operation
8991 /// that it will fold into.
8992 bool CheckMask;
8993 /// Value of the Mask for this operand.
8994 /// It may be SDValue().
8995 SDValue Mask;
8996 /// Value of the vector length operand.
8997 /// It may be SDValue().
8998 SDValue VL;
8999 /// Original value that this NodeExtensionHelper represents.
9000 SDValue OrigOperand;
9001
9002 /// Get the value feeding the extension or the value itself.
9003 /// E.g., for zext(a), this would return a.
9004   SDValue getSource() const {
9005 switch (OrigOperand.getOpcode()) {
9006 case RISCVISD::VSEXT_VL:
9007 case RISCVISD::VZEXT_VL:
9008 return OrigOperand.getOperand(0);
9009 default:
9010 return OrigOperand;
9011 }
9012 }
9013
9014 /// Check if this instance represents a splat.
9015   bool isSplat() const {
9016 return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL;
9017 }
9018
9019 /// Get or create a value that can feed \p Root with the given extension \p
9020 /// SExt. If \p SExt is None, this returns the source of this operand.
9021 /// \see ::getSource().
9022   SDValue getOrCreateExtendedOp(const SDNode *Root, SelectionDAG &DAG,
9023 std::optional<bool> SExt) const {
9024 if (!SExt.has_value())
9025 return OrigOperand;
9026
9027 MVT NarrowVT = getNarrowType(Root);
9028
9029 SDValue Source = getSource();
9030 if (Source.getValueType() == NarrowVT)
9031 return Source;
9032
9033 unsigned ExtOpc = *SExt ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL;
9034
9035 // If we need an extension, we should be changing the type.
9036 SDLoc DL(Root);
9037 auto [Mask, VL] = getMaskAndVL(Root);
9038 switch (OrigOperand.getOpcode()) {
9039 case RISCVISD::VSEXT_VL:
9040 case RISCVISD::VZEXT_VL:
9041 return DAG.getNode(ExtOpc, DL, NarrowVT, Source, Mask, VL);
9042 case RISCVISD::VMV_V_X_VL:
9043 return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT,
9044 DAG.getUNDEF(NarrowVT), Source.getOperand(1), VL);
9045 default:
9046 // Other opcodes can only come from the original LHS of VW(ADD|SUB)_W_VL
9047 // and that operand should already have the right NarrowVT so no
9048 // extension should be required at this point.
9049 llvm_unreachable("Unsupported opcode");
9050 }
9051 }
9052
9053 /// Helper function to get the narrow type for \p Root.
9054 /// The narrow type is the type of \p Root where we divided the size of each
9055 /// element by 2. E.g., if Root's type <2xi16> -> narrow type <2xi8>.
9056 /// \pre The size of the type of the elements of Root must be a multiple of 2
9057 /// and be greater than 16.
9058   static MVT getNarrowType(const SDNode *Root) {
9059 MVT VT = Root->getSimpleValueType(0);
9060
9061 // Determine the narrow size.
9062 unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
9063 assert(NarrowSize >= 8 && "Trying to extend something we can't represent");
9064 MVT NarrowVT = MVT::getVectorVT(MVT::getIntegerVT(NarrowSize),
9065 VT.getVectorElementCount());
9066 return NarrowVT;
9067 }
9068
9069 /// Return the opcode required to materialize the folding of the sign
9070 /// extensions (\p IsSExt == true) or zero extensions (IsSExt == false) for
9071 /// both operands for \p Opcode.
9072 /// Put differently, get the opcode to materialize:
9073   /// - IsSExt == true: \p Opcode(sext(a), sext(b)) -> newOpcode(a, b)
9074   /// - IsSExt == false: \p Opcode(zext(a), zext(b)) -> newOpcode(a, b)
9075 /// \pre \p Opcode represents a supported root (\see ::isSupportedRoot()).
9076   static unsigned getSameExtensionOpcode(unsigned Opcode, bool IsSExt) {
9077 switch (Opcode) {
9078 case RISCVISD::ADD_VL:
9079 case RISCVISD::VWADD_W_VL:
9080 case RISCVISD::VWADDU_W_VL:
9081 return IsSExt ? RISCVISD::VWADD_VL : RISCVISD::VWADDU_VL;
9082 case RISCVISD::MUL_VL:
9083 return IsSExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL;
9084 case RISCVISD::SUB_VL:
9085 case RISCVISD::VWSUB_W_VL:
9086 case RISCVISD::VWSUBU_W_VL:
9087 return IsSExt ? RISCVISD::VWSUB_VL : RISCVISD::VWSUBU_VL;
9088 default:
9089 llvm_unreachable("Unexpected opcode");
9090 }
9091 }
9092
9093 /// Get the opcode to materialize \p Opcode(sext(a), zext(b)) ->
9094 /// newOpcode(a, b).
9095   static unsigned getSUOpcode(unsigned Opcode) {
9096 assert(Opcode == RISCVISD::MUL_VL && "SU is only supported for MUL");
9097 return RISCVISD::VWMULSU_VL;
9098 }
9099
9100 /// Get the opcode to materialize \p Opcode(a, s|zext(b)) ->
9101 /// newOpcode(a, b).
9102   static unsigned getWOpcode(unsigned Opcode, bool IsSExt) {
9103 switch (Opcode) {
9104 case RISCVISD::ADD_VL:
9105 return IsSExt ? RISCVISD::VWADD_W_VL : RISCVISD::VWADDU_W_VL;
9106 case RISCVISD::SUB_VL:
9107 return IsSExt ? RISCVISD::VWSUB_W_VL : RISCVISD::VWSUBU_W_VL;
9108 default:
9109 llvm_unreachable("Unexpected opcode");
9110 }
9111 }
9112
9113 using CombineToTry = std::function<std::optional<CombineResult>(
9114 SDNode * /*Root*/, const NodeExtensionHelper & /*LHS*/,
9115 const NodeExtensionHelper & /*RHS*/)>;
9116
9117 /// Check if this node needs to be fully folded or extended for all users.
9118   bool needToPromoteOtherUsers() const { return EnforceOneUse; }
9119
9120 /// Helper method to set the various fields of this struct based on the
9121 /// type of \p Root.
9122   void fillUpExtensionSupport(SDNode *Root, SelectionDAG &DAG) {
9123 SupportsZExt = false;
9124 SupportsSExt = false;
9125 EnforceOneUse = true;
9126 CheckMask = true;
9127 switch (OrigOperand.getOpcode()) {
9128 case RISCVISD::VZEXT_VL:
9129 SupportsZExt = true;
9130 Mask = OrigOperand.getOperand(1);
9131 VL = OrigOperand.getOperand(2);
9132 break;
9133 case RISCVISD::VSEXT_VL:
9134 SupportsSExt = true;
9135 Mask = OrigOperand.getOperand(1);
9136 VL = OrigOperand.getOperand(2);
9137 break;
9138 case RISCVISD::VMV_V_X_VL: {
9139 // Historically, we didn't care about splat values not disappearing during
9140 // combines.
9141 EnforceOneUse = false;
9142 CheckMask = false;
9143 VL = OrigOperand.getOperand(2);
9144
9145 // The operand is a splat of a scalar.
9146
9147     // The passthru must be undef for tail agnostic.
9148 if (!OrigOperand.getOperand(0).isUndef())
9149 break;
9150
9151 // Get the scalar value.
9152 SDValue Op = OrigOperand.getOperand(1);
9153
9154 // See if we have enough sign bits or zero bits in the scalar to use a
9155 // widening opcode by splatting to smaller element size.
9156 MVT VT = Root->getSimpleValueType(0);
9157 unsigned EltBits = VT.getScalarSizeInBits();
9158 unsigned ScalarBits = Op.getValueSizeInBits();
9159 // Make sure we're getting all element bits from the scalar register.
9160 // FIXME: Support implicit sign extension of vmv.v.x?
9161 if (ScalarBits < EltBits)
9162 break;
9163
9164 unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
9165 // If the narrow type cannot be expressed with a legal VMV,
9166 // this is not a valid candidate.
9167 if (NarrowSize < 8)
9168 break;
9169
9170 if (DAG.ComputeMaxSignificantBits(Op) <= NarrowSize)
9171 SupportsSExt = true;
9172 if (DAG.MaskedValueIsZero(Op,
9173 APInt::getBitsSetFrom(ScalarBits, NarrowSize)))
9174 SupportsZExt = true;
9175 break;
9176 }
9177 default:
9178 break;
9179 }
9180 }
9181
9182 /// Check if \p Root supports any extension folding combines.
9183   static bool isSupportedRoot(const SDNode *Root) {
9184 switch (Root->getOpcode()) {
9185 case RISCVISD::ADD_VL:
9186 case RISCVISD::MUL_VL:
9187 case RISCVISD::VWADD_W_VL:
9188 case RISCVISD::VWADDU_W_VL:
9189 case RISCVISD::SUB_VL:
9190 case RISCVISD::VWSUB_W_VL:
9191 case RISCVISD::VWSUBU_W_VL:
9192 return true;
9193 default:
9194 return false;
9195 }
9196 }
9197
9198 /// Build a NodeExtensionHelper for \p Root.getOperand(\p OperandIdx).
9199   NodeExtensionHelper(SDNode *Root, unsigned OperandIdx, SelectionDAG &DAG) {
9200     assert(isSupportedRoot(Root) && "Trying to build a helper with an "
9201 "unsupported root");
9202 assert(OperandIdx < 2 && "Requesting something else than LHS or RHS");
9203 OrigOperand = Root->getOperand(OperandIdx);
9204
9205 unsigned Opc = Root->getOpcode();
9206 switch (Opc) {
9207 // We consider VW<ADD|SUB>(U)_W(LHS, RHS) as if they were
9208 // <ADD|SUB>(LHS, S|ZEXT(RHS))
9209 case RISCVISD::VWADD_W_VL:
9210 case RISCVISD::VWADDU_W_VL:
9211 case RISCVISD::VWSUB_W_VL:
9212 case RISCVISD::VWSUBU_W_VL:
9213 if (OperandIdx == 1) {
9214 SupportsZExt =
9215 Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL;
9216 SupportsSExt = !SupportsZExt;
9217 std::tie(Mask, VL) = getMaskAndVL(Root);
9218 CheckMask = true;
9219 // There's no existing extension here, so we don't have to worry about
9220 // making sure it gets removed.
9221 EnforceOneUse = false;
9222 break;
9223 }
9224 [[fallthrough]];
9225 default:
9226 fillUpExtensionSupport(Root, DAG);
9227 break;
9228 }
9229 }
9230
9231 /// Check if this operand is compatible with the given vector length \p VL.
9232   bool isVLCompatible(SDValue VL) const {
9233 return this->VL != SDValue() && this->VL == VL;
9234 }
9235
9236 /// Check if this operand is compatible with the given \p Mask.
9237   bool isMaskCompatible(SDValue Mask) const {
9238 return !CheckMask || (this->Mask != SDValue() && this->Mask == Mask);
9239 }
9240
9241 /// Helper function to get the Mask and VL from \p Root.
9242   static std::pair<SDValue, SDValue> getMaskAndVL(const SDNode *Root) {
9243 assert(isSupportedRoot(Root) && "Unexpected root");
9244 return std::make_pair(Root->getOperand(3), Root->getOperand(4));
9245 }
9246
9247 /// Check if the Mask and VL of this operand are compatible with \p Root.
9248   bool areVLAndMaskCompatible(const SDNode *Root) const {
9249 auto [Mask, VL] = getMaskAndVL(Root);
9250 return isMaskCompatible(Mask) && isVLCompatible(VL);
9251 }
9252
9253 /// Helper function to check if \p N is commutative with respect to the
9254 /// foldings that are supported by this class.
9255   static bool isCommutative(const SDNode *N) {
9256 switch (N->getOpcode()) {
9257 case RISCVISD::ADD_VL:
9258 case RISCVISD::MUL_VL:
9259 case RISCVISD::VWADD_W_VL:
9260 case RISCVISD::VWADDU_W_VL:
9261 return true;
9262 case RISCVISD::SUB_VL:
9263 case RISCVISD::VWSUB_W_VL:
9264 case RISCVISD::VWSUBU_W_VL:
9265 return false;
9266 default:
9267 llvm_unreachable("Unexpected opcode");
9268 }
9269 }
9270
9271 /// Get a list of combine to try for folding extensions in \p Root.
9272 /// Note that each returned CombineToTry function doesn't actually modify
9273   /// anything. Instead, they produce an optional CombineResult that, if not
9274   /// std::nullopt, needs to be materialized for the combine to be applied.
9275 /// \see CombineResult::materialize.
9276 /// If the related CombineToTry function returns std::nullopt, that means the
9277 /// combine didn't match.
9278 static SmallVector<CombineToTry> getSupportedFoldings(const SDNode *Root);
9279 };
9280
9281 /// Helper structure that holds all the necessary information to materialize a
9282 /// combine that does some extension folding.
9283 struct CombineResult {
9284 /// Opcode to be generated when materializing the combine.
9285 unsigned TargetOpcode;
9286 // No value means no extension is needed. If extension is needed, the value
9287 // indicates if it needs to be sign extended.
9288 std::optional<bool> SExtLHS;
9289 std::optional<bool> SExtRHS;
9290 /// Root of the combine.
9291 SDNode *Root;
9292 /// LHS of the TargetOpcode.
9293 NodeExtensionHelper LHS;
9294 /// RHS of the TargetOpcode.
9295 NodeExtensionHelper RHS;
9296
9297   CombineResult(unsigned TargetOpcode, SDNode *Root,
9298 const NodeExtensionHelper &LHS, std::optional<bool> SExtLHS,
9299 const NodeExtensionHelper &RHS, std::optional<bool> SExtRHS)
9300 : TargetOpcode(TargetOpcode), SExtLHS(SExtLHS), SExtRHS(SExtRHS),
9301 Root(Root), LHS(LHS), RHS(RHS) {}
9302
9303 /// Return a value that uses TargetOpcode and that can be used to replace
9304 /// Root.
9305 /// The actual replacement is *not* done in that method.
9306   SDValue materialize(SelectionDAG &DAG) const {
9307 SDValue Mask, VL, Merge;
9308 std::tie(Mask, VL) = NodeExtensionHelper::getMaskAndVL(Root);
9309 Merge = Root->getOperand(2);
9310 return DAG.getNode(TargetOpcode, SDLoc(Root), Root->getValueType(0),
9311 LHS.getOrCreateExtendedOp(Root, DAG, SExtLHS),
9312 RHS.getOrCreateExtendedOp(Root, DAG, SExtRHS), Merge,
9313 Mask, VL);
9314 }
9315 };
9316
9317 /// Check if \p Root follows a pattern Root(ext(LHS), ext(RHS))
9318 /// where `ext` is the same for both LHS and RHS (i.e., both are sext or both
9319 /// are zext) and LHS and RHS can be folded into Root.
9320 /// AllowSExt and AllowZExt define which form `ext` can take in this pattern.
9321 ///
9322 /// \note If the pattern can match with both zext and sext, the returned
9323 /// CombineResult will feature the zext result.
9324 ///
9325 /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
9326 /// can be used to apply the pattern.
9327 static std::optional<CombineResult>
9328 canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
9329 const NodeExtensionHelper &RHS, bool AllowSExt,
9330 bool AllowZExt) {
9331 assert((AllowSExt || AllowZExt) && "Forgot to set what you want?");
9332 if (!LHS.areVLAndMaskCompatible(Root) || !RHS.areVLAndMaskCompatible(Root))
9333 return std::nullopt;
9334 if (AllowZExt && LHS.SupportsZExt && RHS.SupportsZExt)
9335 return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
9336 Root->getOpcode(), /*IsSExt=*/false),
9337 Root, LHS, /*SExtLHS=*/false, RHS,
9338 /*SExtRHS=*/false);
9339 if (AllowSExt && LHS.SupportsSExt && RHS.SupportsSExt)
9340 return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
9341 Root->getOpcode(), /*IsSExt=*/true),
9342 Root, LHS, /*SExtLHS=*/true, RHS,
9343 /*SExtRHS=*/true);
9344 return std::nullopt;
9345 }
9346
9347 /// Check if \p Root follows a pattern Root(ext(LHS), ext(RHS))
9348 /// where `ext` is the same for both LHS and RHS (i.e., both are sext or both
9349 /// are zext) and LHS and RHS can be folded into Root.
9350 ///
9351 /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
9352 /// can be used to apply the pattern.
9353 static std::optional<CombineResult>
9354 canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
9355 const NodeExtensionHelper &RHS) {
9356 return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
9357 /*AllowZExt=*/true);
9358 }
9359
9360 /// Check if \p Root follows a pattern Root(LHS, ext(RHS))
9361 ///
9362 /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
9363 /// can be used to apply the pattern.
9364 static std::optional<CombineResult>
9365 canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
9366 const NodeExtensionHelper &RHS) {
9367 if (!RHS.areVLAndMaskCompatible(Root))
9368 return std::nullopt;
9369
9370 // FIXME: Is it useful to form a vwadd.wx or vwsub.wx if it removes a scalar
9371 // sext/zext?
9372 // Control this behavior behind an option (AllowSplatInVW_W) for testing
9373 // purposes.
9374 if (RHS.SupportsZExt && (!RHS.isSplat() || AllowSplatInVW_W))
9375 return CombineResult(
9376 NodeExtensionHelper::getWOpcode(Root->getOpcode(), /*IsSExt=*/false),
9377 Root, LHS, /*SExtLHS=*/std::nullopt, RHS, /*SExtRHS=*/false);
9378 if (RHS.SupportsSExt && (!RHS.isSplat() || AllowSplatInVW_W))
9379 return CombineResult(
9380 NodeExtensionHelper::getWOpcode(Root->getOpcode(), /*IsSExt=*/true),
9381 Root, LHS, /*SExtLHS=*/std::nullopt, RHS, /*SExtRHS=*/true);
9382 return std::nullopt;
9383 }
9384
9385 /// Check if \p Root follows a pattern Root(sext(LHS), sext(RHS))
9386 ///
9387 /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
9388 /// can be used to apply the pattern.
9389 static std::optional<CombineResult>
9390 canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
9391 const NodeExtensionHelper &RHS) {
9392 return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
9393 /*AllowZExt=*/false);
9394 }
9395
9396 /// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
9397 ///
9398 /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
9399 /// can be used to apply the pattern.
9400 static std::optional<CombineResult>
9401 canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
9402 const NodeExtensionHelper &RHS) {
9403 return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
9404 /*AllowZExt=*/true);
9405 }
9406
9407 /// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
9408 ///
9409 /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
9410 /// can be used to apply the pattern.
9411 static std::optional<CombineResult>
9412 canFoldToVW_SU(SDNode *Root, const NodeExtensionHelper &LHS,
9413 const NodeExtensionHelper &RHS) {
9414 if (!LHS.SupportsSExt || !RHS.SupportsZExt)
9415 return std::nullopt;
9416 if (!LHS.areVLAndMaskCompatible(Root) || !RHS.areVLAndMaskCompatible(Root))
9417 return std::nullopt;
9418 return CombineResult(NodeExtensionHelper::getSUOpcode(Root->getOpcode()),
9419 Root, LHS, /*SExtLHS=*/true, RHS, /*SExtRHS=*/false);
9420 }
9421
9422 SmallVector<NodeExtensionHelper::CombineToTry>
9423 NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
9424 SmallVector<CombineToTry> Strategies;
9425 switch (Root->getOpcode()) {
9426 case RISCVISD::ADD_VL:
9427 case RISCVISD::SUB_VL:
9428 // add|sub -> vwadd(u)|vwsub(u)
9429 Strategies.push_back(canFoldToVWWithSameExtension);
9430 // add|sub -> vwadd(u)_w|vwsub(u)_w
9431 Strategies.push_back(canFoldToVW_W);
9432 break;
9433 case RISCVISD::MUL_VL:
9434 // mul -> vwmul(u)
9435 Strategies.push_back(canFoldToVWWithSameExtension);
9436 // mul -> vwmulsu
9437 Strategies.push_back(canFoldToVW_SU);
9438 break;
9439 case RISCVISD::VWADD_W_VL:
9440 case RISCVISD::VWSUB_W_VL:
9441 // vwadd_w|vwsub_w -> vwadd|vwsub
9442 Strategies.push_back(canFoldToVWWithSEXT);
9443 break;
9444 case RISCVISD::VWADDU_W_VL:
9445 case RISCVISD::VWSUBU_W_VL:
9446 // vwaddu_w|vwsubu_w -> vwaddu|vwsubu
9447 Strategies.push_back(canFoldToVWWithZEXT);
9448 break;
9449 default:
9450 llvm_unreachable("Unexpected opcode");
9451 }
9452 return Strategies;
9453 }
9454 } // End anonymous namespace.
9455
9456 /// Combine a binary operation to its equivalent VW or VW_W form.
9457 /// The supported combines are:
9458 /// add_vl -> vwadd(u) | vwadd(u)_w
9459 /// sub_vl -> vwsub(u) | vwsub(u)_w
9460 /// mul_vl -> vwmul(u) | vwmul_su
9461 /// vwadd(u)_w -> vwadd(u)
9462 /// vwsub(u)_w -> vwsub(u)
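/// For instance (illustrative), when both inputs are sign extended with a
/// matching mask and VL and the extends have no other users:
///   (mul_vl (vsext_vl a, m, vl), (vsext_vl b, m, vl), passthru, m, vl)
///     -> (vwmul_vl a, b, passthru, m, vl)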
9463 static SDValue
9464 combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
9465 SelectionDAG &DAG = DCI.DAG;
9466
9467 assert(NodeExtensionHelper::isSupportedRoot(N) &&
9468 "Shouldn't have called this method");
9469 SmallVector<SDNode *> Worklist;
9470 SmallSet<SDNode *, 8> Inserted;
9471 Worklist.push_back(N);
9472 Inserted.insert(N);
9473 SmallVector<CombineResult> CombinesToApply;
9474
9475 while (!Worklist.empty()) {
9476 SDNode *Root = Worklist.pop_back_val();
9477 if (!NodeExtensionHelper::isSupportedRoot(Root))
9478 return SDValue();
9479
9480 NodeExtensionHelper LHS(N, 0, DAG);
9481 NodeExtensionHelper RHS(N, 1, DAG);
9482 auto AppendUsersIfNeeded = [&Worklist,
9483 &Inserted](const NodeExtensionHelper &Op) {
9484 if (Op.needToPromoteOtherUsers()) {
9485 for (SDNode *TheUse : Op.OrigOperand->uses()) {
9486 if (Inserted.insert(TheUse).second)
9487 Worklist.push_back(TheUse);
9488 }
9489 }
9490 };
9491
9492     // Control the compile time by limiting the total number of nodes we look
9493     // at.
9494 if (Inserted.size() > ExtensionMaxWebSize)
9495 return SDValue();
9496
9497 SmallVector<NodeExtensionHelper::CombineToTry> FoldingStrategies =
9498 NodeExtensionHelper::getSupportedFoldings(N);
9499
9500 assert(!FoldingStrategies.empty() && "Nothing to be folded");
9501 bool Matched = false;
9502 for (int Attempt = 0;
9503 (Attempt != 1 + NodeExtensionHelper::isCommutative(N)) && !Matched;
9504 ++Attempt) {
9505
9506 for (NodeExtensionHelper::CombineToTry FoldingStrategy :
9507 FoldingStrategies) {
9508 std::optional<CombineResult> Res = FoldingStrategy(N, LHS, RHS);
9509 if (Res) {
9510 Matched = true;
9511 CombinesToApply.push_back(*Res);
9512           // All the inputs that are extended need to be folded, otherwise
9513           // we would be leaving the old input (since it may still be used),
9514           // and the new one.
9515 if (Res->SExtLHS.has_value())
9516 AppendUsersIfNeeded(LHS);
9517 if (Res->SExtRHS.has_value())
9518 AppendUsersIfNeeded(RHS);
9519 break;
9520 }
9521 }
9522 std::swap(LHS, RHS);
9523 }
9524     // Right now we do an all-or-nothing approach.
9525 if (!Matched)
9526 return SDValue();
9527 }
9528 // Store the value for the replacement of the input node separately.
9529 SDValue InputRootReplacement;
9530 // We do the RAUW after we materialize all the combines, because some replaced
9531 // nodes may be feeding some of the yet-to-be-replaced nodes. Put differently,
9532 // some of these nodes may appear in the NodeExtensionHelpers of some of the
9533 // yet-to-be-visited CombinesToApply roots.
9534 SmallVector<std::pair<SDValue, SDValue>> ValuesToReplace;
9535 ValuesToReplace.reserve(CombinesToApply.size());
9536 for (CombineResult Res : CombinesToApply) {
9537 SDValue NewValue = Res.materialize(DAG);
9538 if (!InputRootReplacement) {
9539 assert(Res.Root == N &&
9540 "First element is expected to be the current node");
9541 InputRootReplacement = NewValue;
9542 } else {
9543 ValuesToReplace.emplace_back(SDValue(Res.Root, 0), NewValue);
9544 }
9545 }
9546 for (std::pair<SDValue, SDValue> OldNewValues : ValuesToReplace) {
9547 DAG.ReplaceAllUsesOfValueWith(OldNewValues.first, OldNewValues.second);
9548 DCI.AddToWorklist(OldNewValues.second.getNode());
9549 }
9550 return InputRootReplacement;
9551 }
9552
9553 // Fold
9554 // (fp_to_int (froundeven X)) -> fcvt X, rne
9555 // (fp_to_int (ftrunc X)) -> fcvt X, rtz
9556 // (fp_to_int (ffloor X)) -> fcvt X, rdn
9557 // (fp_to_int (fceil X)) -> fcvt X, rup
9558 // (fp_to_int (fround X)) -> fcvt X, rmm
9559 static SDValue performFP_TO_INTCombine(SDNode *N,
9560 TargetLowering::DAGCombinerInfo &DCI,
9561 const RISCVSubtarget &Subtarget) {
9562 SelectionDAG &DAG = DCI.DAG;
9563 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
9564 MVT XLenVT = Subtarget.getXLenVT();
9565
9566 SDValue Src = N->getOperand(0);
9567
9568 // Ensure the FP type is legal.
9569 if (!TLI.isTypeLegal(Src.getValueType()))
9570 return SDValue();
9571
9572 // Don't do this for f16 with Zfhmin and not Zfh.
9573 if (Src.getValueType() == MVT::f16 && !Subtarget.hasStdExtZfh())
9574 return SDValue();
9575
9576 RISCVFPRndMode::RoundingMode FRM = matchRoundingOp(Src.getOpcode());
9577 if (FRM == RISCVFPRndMode::Invalid)
9578 return SDValue();
9579
9580 SDLoc DL(N);
9581 bool IsSigned = N->getOpcode() == ISD::FP_TO_SINT;
9582 EVT VT = N->getValueType(0);
9583
9584 if (VT.isVector() && TLI.isTypeLegal(VT)) {
9585 MVT SrcVT = Src.getSimpleValueType();
9586 MVT SrcContainerVT = SrcVT;
9587 MVT ContainerVT = VT.getSimpleVT();
9588 SDValue XVal = Src.getOperand(0);
9589
9590 // For widening and narrowing conversions we just combine it into a
9591 // VFCVT_..._VL node, as there are no specific VFWCVT/VFNCVT VL nodes. They
9592 // end up getting lowered to their appropriate pseudo instructions based on
9593 // their operand types
9594 if (VT.getScalarSizeInBits() > SrcVT.getScalarSizeInBits() * 2 ||
9595 VT.getScalarSizeInBits() * 2 < SrcVT.getScalarSizeInBits())
9596 return SDValue();
9597
9598 // Make fixed-length vectors scalable first
9599 if (SrcVT.isFixedLengthVector()) {
9600 SrcContainerVT = getContainerForFixedLengthVector(DAG, SrcVT, Subtarget);
9601 XVal = convertToScalableVector(SrcContainerVT, XVal, DAG, Subtarget);
9602 ContainerVT =
9603 getContainerForFixedLengthVector(DAG, ContainerVT, Subtarget);
9604 }
9605
9606 auto [Mask, VL] =
9607 getDefaultVLOps(SrcVT, SrcContainerVT, DL, DAG, Subtarget);
9608
9609 SDValue FpToInt;
9610 if (FRM == RISCVFPRndMode::RTZ) {
9611 // Use the dedicated trunc static rounding mode if we're truncating so we
9612 // don't need to generate calls to fsrmi/fsrm
9613 unsigned Opc =
9614 IsSigned ? RISCVISD::VFCVT_RTZ_X_F_VL : RISCVISD::VFCVT_RTZ_XU_F_VL;
9615 FpToInt = DAG.getNode(Opc, DL, ContainerVT, XVal, Mask, VL);
9616 } else {
9617 unsigned Opc =
9618 IsSigned ? RISCVISD::VFCVT_RM_X_F_VL : RISCVISD::VFCVT_RM_XU_F_VL;
9619 FpToInt = DAG.getNode(Opc, DL, ContainerVT, XVal, Mask,
9620 DAG.getTargetConstant(FRM, DL, XLenVT), VL);
9621 }
9622
9623 // If converted from fixed-length to scalable, convert back
9624 if (VT.isFixedLengthVector())
9625 FpToInt = convertFromScalableVector(VT, FpToInt, DAG, Subtarget);
9626
9627 return FpToInt;
9628 }
9629
9630 // Only handle XLen or i32 types. Other types narrower than XLen will
9631 // eventually be legalized to XLenVT.
9632 if (VT != MVT::i32 && VT != XLenVT)
9633 return SDValue();
9634
9635 unsigned Opc;
9636 if (VT == XLenVT)
9637 Opc = IsSigned ? RISCVISD::FCVT_X : RISCVISD::FCVT_XU;
9638 else
9639 Opc = IsSigned ? RISCVISD::FCVT_W_RV64 : RISCVISD::FCVT_WU_RV64;
9640
9641 SDValue FpToInt = DAG.getNode(Opc, DL, XLenVT, Src.getOperand(0),
9642 DAG.getTargetConstant(FRM, DL, XLenVT));
9643 return DAG.getNode(ISD::TRUNCATE, DL, VT, FpToInt);
9644 }
9645
9646 // Fold
9647 // (fp_to_int_sat (froundeven X)) -> (select X == nan, 0, (fcvt X, rne))
9648 // (fp_to_int_sat (ftrunc X)) -> (select X == nan, 0, (fcvt X, rtz))
9649 // (fp_to_int_sat (ffloor X)) -> (select X == nan, 0, (fcvt X, rdn))
9650 // (fp_to_int_sat (fceil X)) -> (select X == nan, 0, (fcvt X, rup))
9651 // (fp_to_int_sat (fround X)) -> (select X == nan, 0, (fcvt X, rmm))
9652 static SDValue performFP_TO_INT_SATCombine(SDNode *N,
9653 TargetLowering::DAGCombinerInfo &DCI,
9654 const RISCVSubtarget &Subtarget) {
9655 SelectionDAG &DAG = DCI.DAG;
9656 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
9657 MVT XLenVT = Subtarget.getXLenVT();
9658
9659 // Only handle XLen types. Other types narrower than XLen will eventually be
9660 // legalized to XLenVT.
9661 EVT DstVT = N->getValueType(0);
9662 if (DstVT != XLenVT)
9663 return SDValue();
9664
9665 SDValue Src = N->getOperand(0);
9666
9667 // Ensure the FP type is also legal.
9668 if (!TLI.isTypeLegal(Src.getValueType()))
9669 return SDValue();
9670
9671 // Don't do this for f16 with Zfhmin and not Zfh.
9672 if (Src.getValueType() == MVT::f16 && !Subtarget.hasStdExtZfh())
9673 return SDValue();
9674
9675 EVT SatVT = cast<VTSDNode>(N->getOperand(1))->getVT();
9676
9677 RISCVFPRndMode::RoundingMode FRM = matchRoundingOp(Src.getOpcode());
9678 if (FRM == RISCVFPRndMode::Invalid)
9679 return SDValue();
9680
9681 bool IsSigned = N->getOpcode() == ISD::FP_TO_SINT_SAT;
9682
9683 unsigned Opc;
9684 if (SatVT == DstVT)
9685 Opc = IsSigned ? RISCVISD::FCVT_X : RISCVISD::FCVT_XU;
9686 else if (DstVT == MVT::i64 && SatVT == MVT::i32)
9687 Opc = IsSigned ? RISCVISD::FCVT_W_RV64 : RISCVISD::FCVT_WU_RV64;
9688 else
9689 return SDValue();
9690 // FIXME: Support other SatVTs by clamping before or after the conversion.
9691
9692 Src = Src.getOperand(0);
9693
9694 SDLoc DL(N);
9695 SDValue FpToInt = DAG.getNode(Opc, DL, XLenVT, Src,
9696 DAG.getTargetConstant(FRM, DL, XLenVT));
9697
9698 // fcvt.wu.* sign extends bit 31 on RV64. FP_TO_UINT_SAT expects to zero
9699 // extend.
9700 if (Opc == RISCVISD::FCVT_WU_RV64)
9701 FpToInt = DAG.getZeroExtendInReg(FpToInt, DL, MVT::i32);
9702
9703 // RISCV FP-to-int conversions saturate to the destination register size, but
9704 // don't produce 0 for nan.
9705 SDValue ZeroInt = DAG.getConstant(0, DL, DstVT);
9706 return DAG.getSelectCC(DL, Src, Src, ZeroInt, FpToInt, ISD::CondCode::SETUO);
9707 }
9708
9709 // Combine (bitreverse (bswap X)) to the BREV8 GREVI encoding if the type is
9710 // smaller than XLenVT.
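// For example (illustrative), for an i16 value bswap swaps the two bytes and
// bitreverse then reverses all 16 bits, which together reverse the bits within
// each byte: bitreverse(bswap(0xAB12)) == brev8(0xAB12) == 0xD548.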
9711 static SDValue performBITREVERSECombine(SDNode *N, SelectionDAG &DAG,
9712 const RISCVSubtarget &Subtarget) {
9713 assert(Subtarget.hasStdExtZbkb() && "Unexpected extension");
9714
9715 SDValue Src = N->getOperand(0);
9716 if (Src.getOpcode() != ISD::BSWAP)
9717 return SDValue();
9718
9719 EVT VT = N->getValueType(0);
9720 if (!VT.isScalarInteger() || VT.getSizeInBits() >= Subtarget.getXLen() ||
9721 !isPowerOf2_32(VT.getSizeInBits()))
9722 return SDValue();
9723
9724 SDLoc DL(N);
9725 return DAG.getNode(RISCVISD::BREV8, DL, VT, Src.getOperand(0));
9726 }
9727
9728 // Convert from one FMA opcode to another based on whether we are negating the
9729 // multiply result and/or the accumulator.
9730 // NOTE: Only supports RVV operations with VL.
9731 static unsigned negateFMAOpcode(unsigned Opcode, bool NegMul, bool NegAcc) {
9732 assert((NegMul || NegAcc) && "Not negating anything?");
9733
9734 // Negating the multiply result changes ADD<->SUB and toggles 'N'.
9735 if (NegMul) {
9736 // clang-format off
9737 switch (Opcode) {
9738 default: llvm_unreachable("Unexpected opcode");
9739 case RISCVISD::VFMADD_VL: Opcode = RISCVISD::VFNMSUB_VL; break;
9740 case RISCVISD::VFNMSUB_VL: Opcode = RISCVISD::VFMADD_VL; break;
9741 case RISCVISD::VFNMADD_VL: Opcode = RISCVISD::VFMSUB_VL; break;
9742 case RISCVISD::VFMSUB_VL: Opcode = RISCVISD::VFNMADD_VL; break;
9743 }
9744 // clang-format on
9745 }
9746
9747 // Negating the accumulator changes ADD<->SUB.
9748 if (NegAcc) {
9749 // clang-format off
9750 switch (Opcode) {
9751 default: llvm_unreachable("Unexpected opcode");
9752 case RISCVISD::VFMADD_VL: Opcode = RISCVISD::VFMSUB_VL; break;
9753 case RISCVISD::VFMSUB_VL: Opcode = RISCVISD::VFMADD_VL; break;
9754 case RISCVISD::VFNMADD_VL: Opcode = RISCVISD::VFNMSUB_VL; break;
9755 case RISCVISD::VFNMSUB_VL: Opcode = RISCVISD::VFNMADD_VL; break;
9756 }
9757 // clang-format on
9758 }
9759
9760 return Opcode;
9761 }
9762
9763 static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
9764 const RISCVSubtarget &Subtarget) {
9765 assert(N->getOpcode() == ISD::SRA && "Unexpected opcode");
9766
9767 if (N->getValueType(0) != MVT::i64 || !Subtarget.is64Bit())
9768 return SDValue();
9769
9770 if (!isa<ConstantSDNode>(N->getOperand(1)))
9771 return SDValue();
9772 uint64_t ShAmt = N->getConstantOperandVal(1);
9773 if (ShAmt > 32)
9774 return SDValue();
9775
9776 SDValue N0 = N->getOperand(0);
9777
9778 // Combine (sra (sext_inreg (shl X, C1), i32), C2) ->
9779 // (sra (shl X, C1+32), C2+32) so it gets selected as SLLI+SRAI instead of
9780 // SLLIW+SRAIW. SLLI+SRAI have compressed forms.
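// For example (illustrative), with C1 = 3 and C2 = 5:
// (sra (sext_inreg (shl X, 3), i32), 5) -> (sra (shl X, 35), 37).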
9781 if (ShAmt < 32 &&
9782 N0.getOpcode() == ISD::SIGN_EXTEND_INREG && N0.hasOneUse() &&
9783 cast<VTSDNode>(N0.getOperand(1))->getVT() == MVT::i32 &&
9784 N0.getOperand(0).getOpcode() == ISD::SHL && N0.getOperand(0).hasOneUse() &&
9785 isa<ConstantSDNode>(N0.getOperand(0).getOperand(1))) {
9786 uint64_t LShAmt = N0.getOperand(0).getConstantOperandVal(1);
9787 if (LShAmt < 32) {
9788 SDLoc ShlDL(N0.getOperand(0));
9789 SDValue Shl = DAG.getNode(ISD::SHL, ShlDL, MVT::i64,
9790 N0.getOperand(0).getOperand(0),
9791 DAG.getConstant(LShAmt + 32, ShlDL, MVT::i64));
9792 SDLoc DL(N);
9793 return DAG.getNode(ISD::SRA, DL, MVT::i64, Shl,
9794 DAG.getConstant(ShAmt + 32, DL, MVT::i64));
9795 }
9796 }
9797
9798 // Combine (sra (shl X, 32), 32 - C) -> (shl (sext_inreg X, i32), C)
9799 // FIXME: Should this be a generic combine? There's a similar combine on X86.
9800 //
9801 // Also try these folds where an add or sub is in the middle.
9802 // (sra (add (shl X, 32), C1), 32 - C) -> (shl (sext_inreg (add X, C1 >> 32), i32), C)
9803 // (sra (sub C1, (shl X, 32)), 32 - C) -> (shl (sext_inreg (sub (C1 >> 32), X), i32), C)
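// For example (illustrative), with C1 = (5 << 32) and C = 8:
// (sra (add (shl X, 32), (5 << 32)), 24) -> (shl (sext_inreg (add X, 5), i32), 8).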
9804 SDValue Shl;
9805 ConstantSDNode *AddC = nullptr;
9806
9807 // We might have an ADD or SUB between the SRA and SHL.
9808 bool IsAdd = N0.getOpcode() == ISD::ADD;
9809 if ((IsAdd || N0.getOpcode() == ISD::SUB)) {
9810 // Other operand needs to be a constant we can modify.
9811 AddC = dyn_cast<ConstantSDNode>(N0.getOperand(IsAdd ? 1 : 0));
9812 if (!AddC)
9813 return SDValue();
9814
9815 // AddC needs to have at least 32 trailing zeros.
9816 if (AddC->getAPIntValue().countTrailingZeros() < 32)
9817 return SDValue();
9818
9819 // All users should be shifts by a constant less than or equal to 32. This
9820 // ensures we'll do this optimization for each of them to produce an
9821 // add/sub+sext_inreg they can all share.
9822 for (SDNode *U : N0->uses()) {
9823 if (U->getOpcode() != ISD::SRA ||
9824 !isa<ConstantSDNode>(U->getOperand(1)) ||
9825 cast<ConstantSDNode>(U->getOperand(1))->getZExtValue() > 32)
9826 return SDValue();
9827 }
9828
9829 Shl = N0.getOperand(IsAdd ? 0 : 1);
9830 } else {
9831 // Not an ADD or SUB.
9832 Shl = N0;
9833 }
9834
9835 // Look for a shift left by 32.
9836 if (Shl.getOpcode() != ISD::SHL || !isa<ConstantSDNode>(Shl.getOperand(1)) ||
9837 Shl.getConstantOperandVal(1) != 32)
9838 return SDValue();
9839
9840 // If we didn't look through an add/sub, then the shl should have one use.
9841 // If we did look through an add/sub, the sext_inreg we create is free so
9842 // we're only creating 2 new instructions. It's enough to only remove the
9843 // original sra+add/sub.
9844 if (!AddC && !Shl.hasOneUse())
9845 return SDValue();
9846
9847 SDLoc DL(N);
9848 SDValue In = Shl.getOperand(0);
9849
9850 // If we looked through an ADD or SUB, we need to rebuild it with the shifted
9851 // constant.
9852 if (AddC) {
9853 SDValue ShiftedAddC =
9854 DAG.getConstant(AddC->getAPIntValue().lshr(32), DL, MVT::i64);
9855 if (IsAdd)
9856 In = DAG.getNode(ISD::ADD, DL, MVT::i64, In, ShiftedAddC);
9857 else
9858 In = DAG.getNode(ISD::SUB, DL, MVT::i64, ShiftedAddC, In);
9859 }
9860
9861 SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i64, In,
9862 DAG.getValueType(MVT::i32));
9863 if (ShAmt == 32)
9864 return SExt;
9865
9866 return DAG.getNode(
9867 ISD::SHL, DL, MVT::i64, SExt,
9868 DAG.getConstant(32 - ShAmt, DL, MVT::i64));
9869 }
9870
9871 // Invert (and/or (set cc X, Y), (xor Z, 1)) to (or/and (set !cc X, Y), Z) if
9872 // the result is used as the condition of a br_cc or select_cc we can invert,
9873 // inverting the setcc is free, and Z is 0/1. Caller will invert the
9874 // br_cc/select_cc.
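// For example (illustrative), a br_cc on (and (setcc X, Y, eq), (xor Z, 1))
// can instead branch on the inverse of (or (setcc X, Y, ne), Z), since
// !((X != Y) || Z) == ((X == Y) && !Z) when Z is known to be 0/1.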
9875 static SDValue tryDemorganOfBooleanCondition(SDValue Cond, SelectionDAG &DAG) {
9876 bool IsAnd = Cond.getOpcode() == ISD::AND;
9877 if (!IsAnd && Cond.getOpcode() != ISD::OR)
9878 return SDValue();
9879
9880 if (!Cond.hasOneUse())
9881 return SDValue();
9882
9883 SDValue Setcc = Cond.getOperand(0);
9884 SDValue Xor = Cond.getOperand(1);
9885 // Canonicalize setcc to LHS.
9886 if (Setcc.getOpcode() != ISD::SETCC)
9887 std::swap(Setcc, Xor);
9888 // LHS should be a setcc and RHS should be an xor.
9889 if (Setcc.getOpcode() != ISD::SETCC || !Setcc.hasOneUse() ||
9890 Xor.getOpcode() != ISD::XOR || !Xor.hasOneUse())
9891 return SDValue();
9892
9893 // If the condition is an And, SimplifyDemandedBits may have changed
9894 // (xor Z, 1) to (not Z).
9895 SDValue Xor1 = Xor.getOperand(1);
9896 if (!isOneConstant(Xor1) && !(IsAnd && isAllOnesConstant(Xor1)))
9897 return SDValue();
9898
9899 EVT VT = Cond.getValueType();
9900 SDValue Xor0 = Xor.getOperand(0);
9901
9902 // The LHS of the xor needs to be 0/1.
9903 APInt Mask = APInt::getBitsSetFrom(VT.getSizeInBits(), 1);
9904 if (!DAG.MaskedValueIsZero(Xor0, Mask))
9905 return SDValue();
9906
9907 // We can only invert integer setccs.
9908 EVT SetCCOpVT = Setcc.getOperand(0).getValueType();
9909 if (!SetCCOpVT.isScalarInteger())
9910 return SDValue();
9911
9912 ISD::CondCode CCVal = cast<CondCodeSDNode>(Setcc.getOperand(2))->get();
9913 if (ISD::isIntEqualitySetCC(CCVal)) {
9914 CCVal = ISD::getSetCCInverse(CCVal, SetCCOpVT);
9915 Setcc = DAG.getSetCC(SDLoc(Setcc), VT, Setcc.getOperand(0),
9916 Setcc.getOperand(1), CCVal);
9917 } else if (CCVal == ISD::SETLT && isNullConstant(Setcc.getOperand(0))) {
9918 // Invert (setlt 0, X) by converting to (setlt X, 1).
9919 Setcc = DAG.getSetCC(SDLoc(Setcc), VT, Setcc.getOperand(1),
9920 DAG.getConstant(1, SDLoc(Setcc), VT), CCVal);
9921 } else if (CCVal == ISD::SETLT && isOneConstant(Setcc.getOperand(1))) {
9922 // Invert (setlt X, 1) by converting to (setlt 0, X).
9923 Setcc = DAG.getSetCC(SDLoc(Setcc), VT,
9924 DAG.getConstant(0, SDLoc(Setcc), VT),
9925 Setcc.getOperand(0), CCVal);
9926 } else
9927 return SDValue();
9928
9929 unsigned Opc = IsAnd ? ISD::OR : ISD::AND;
9930 return DAG.getNode(Opc, SDLoc(Cond), VT, Setcc, Xor.getOperand(0));
9931 }
9932
9933 // Perform common combines for BR_CC and SELECT_CC conditions.
9934 static bool combine_CC(SDValue &LHS, SDValue &RHS, SDValue &CC, const SDLoc &DL,
9935 SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
9936 ISD::CondCode CCVal = cast<CondCodeSDNode>(CC)->get();
9937
9938 // Since an arithmetic right shift always preserves the sign bit, the
9939 // shift can be omitted when only comparing against zero.
9940 // Fold setlt (sra X, N), 0 -> setlt X, 0 and
9941 // setge (sra X, N), 0 -> setge X, 0
9942 if (auto *RHSConst = dyn_cast<ConstantSDNode>(RHS.getNode())) {
9943 if ((CCVal == ISD::SETGE || CCVal == ISD::SETLT) &&
9944 LHS.getOpcode() == ISD::SRA && RHSConst->isZero()) {
9945 LHS = LHS.getOperand(0);
9946 return true;
9947 }
9948 }
9949
9950 if (!ISD::isIntEqualitySetCC(CCVal))
9951 return false;
9952
9953 // Fold ((setlt X, Y), 0, ne) -> (X, Y, lt)
9954 // Sometimes the setcc is introduced after br_cc/select_cc has been formed.
9955 if (LHS.getOpcode() == ISD::SETCC && isNullConstant(RHS) &&
9956 LHS.getOperand(0).getValueType() == Subtarget.getXLenVT()) {
9957 // If we're looking for eq 0 instead of ne 0, we need to invert the
9958 // condition.
9959 bool Invert = CCVal == ISD::SETEQ;
9960 CCVal = cast<CondCodeSDNode>(LHS.getOperand(2))->get();
9961 if (Invert)
9962 CCVal = ISD::getSetCCInverse(CCVal, LHS.getValueType());
9963
9964 RHS = LHS.getOperand(1);
9965 LHS = LHS.getOperand(0);
9966 translateSetCCForBranch(DL, LHS, RHS, CCVal, DAG);
9967
9968 CC = DAG.getCondCode(CCVal);
9969 return true;
9970 }
9971
9972 // Fold ((xor X, Y), 0, eq/ne) -> (X, Y, eq/ne)
9973 if (LHS.getOpcode() == ISD::XOR && isNullConstant(RHS)) {
9974 RHS = LHS.getOperand(1);
9975 LHS = LHS.getOperand(0);
9976 return true;
9977 }
9978
9979 // Fold ((srl (and X, 1<<C), C), 0, eq/ne) -> ((shl X, XLen-1-C), 0, ge/lt)
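// For example (illustrative), on RV64 with C = 3:
// ((srl (and X, 8), 3), 0, ne) -> ((shl X, 60), 0, lt), i.e. move bit 3 of X
// into the sign bit and test the sign instead of isolating the bit.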
9980 if (isNullConstant(RHS) && LHS.getOpcode() == ISD::SRL && LHS.hasOneUse() &&
9981 LHS.getOperand(1).getOpcode() == ISD::Constant) {
9982 SDValue LHS0 = LHS.getOperand(0);
9983 if (LHS0.getOpcode() == ISD::AND &&
9984 LHS0.getOperand(1).getOpcode() == ISD::Constant) {
9985 uint64_t Mask = LHS0.getConstantOperandVal(1);
9986 uint64_t ShAmt = LHS.getConstantOperandVal(1);
9987 if (isPowerOf2_64(Mask) && Log2_64(Mask) == ShAmt) {
9988 CCVal = CCVal == ISD::SETEQ ? ISD::SETGE : ISD::SETLT;
9989 CC = DAG.getCondCode(CCVal);
9990
9991 ShAmt = LHS.getValueSizeInBits() - 1 - ShAmt;
9992 LHS = LHS0.getOperand(0);
9993 if (ShAmt != 0)
9994 LHS =
9995 DAG.getNode(ISD::SHL, DL, LHS.getValueType(), LHS0.getOperand(0),
9996 DAG.getConstant(ShAmt, DL, LHS.getValueType()));
9997 return true;
9998 }
9999 }
10000 }
10001
10002 // (X, 1, setne) -> (X, 0, seteq) if we can prove X is 0/1.
10003 // This can occur when legalizing some floating point comparisons.
10004 APInt Mask = APInt::getBitsSetFrom(LHS.getValueSizeInBits(), 1);
10005 if (isOneConstant(RHS) && DAG.MaskedValueIsZero(LHS, Mask)) {
10006 CCVal = ISD::getSetCCInverse(CCVal, LHS.getValueType());
10007 CC = DAG.getCondCode(CCVal);
10008 RHS = DAG.getConstant(0, DL, LHS.getValueType());
10009 return true;
10010 }
10011
10012 if (isNullConstant(RHS)) {
10013 if (SDValue NewCond = tryDemorganOfBooleanCondition(LHS, DAG)) {
10014 CCVal = ISD::getSetCCInverse(CCVal, LHS.getValueType());
10015 CC = DAG.getCondCode(CCVal);
10016 LHS = NewCond;
10017 return true;
10018 }
10019 }
10020
10021 return false;
10022 }
10023
10024 // Fold
10025 // (select C, (add Y, X), Y) -> (add Y, (select C, X, 0)).
10026 // (select C, (sub Y, X), Y) -> (sub Y, (select C, X, 0)).
10027 // (select C, (or Y, X), Y) -> (or Y, (select C, X, 0)).
10028 // (select C, (xor Y, X), Y) -> (xor Y, (select C, X, 0)).
10029 static SDValue tryFoldSelectIntoOp(SDNode *N, SelectionDAG &DAG,
10030 SDValue TrueVal, SDValue FalseVal,
10031 bool Swapped) {
10032 bool Commutative = true;
10033 switch (TrueVal.getOpcode()) {
10034 default:
10035 return SDValue();
10036 case ISD::SUB:
10037 Commutative = false;
10038 break;
10039 case ISD::ADD:
10040 case ISD::OR:
10041 case ISD::XOR:
10042 break;
10043 }
10044
10045 if (!TrueVal.hasOneUse() || isa<ConstantSDNode>(FalseVal))
10046 return SDValue();
10047
10048 unsigned OpToFold;
10049 if (FalseVal == TrueVal.getOperand(0))
10050 OpToFold = 0;
10051 else if (Commutative && FalseVal == TrueVal.getOperand(1))
10052 OpToFold = 1;
10053 else
10054 return SDValue();
10055
10056 EVT VT = N->getValueType(0);
10057 SDLoc DL(N);
10058 SDValue Zero = DAG.getConstant(0, DL, VT);
10059 SDValue OtherOp = TrueVal.getOperand(1 - OpToFold);
10060
10061 if (Swapped)
10062 std::swap(OtherOp, Zero);
10063 SDValue NewSel = DAG.getSelect(DL, VT, N->getOperand(0), OtherOp, Zero);
10064 return DAG.getNode(TrueVal.getOpcode(), DL, VT, FalseVal, NewSel);
10065 }
10066
10067 static SDValue performSELECTCombine(SDNode *N, SelectionDAG &DAG,
10068 const RISCVSubtarget &Subtarget) {
10069 if (Subtarget.hasShortForwardBranchOpt())
10070 return SDValue();
10071
10072 // Only support XLenVT.
10073 if (N->getValueType(0) != Subtarget.getXLenVT())
10074 return SDValue();
10075
10076 SDValue TrueVal = N->getOperand(1);
10077 SDValue FalseVal = N->getOperand(2);
10078 if (SDValue V = tryFoldSelectIntoOp(N, DAG, TrueVal, FalseVal, /*Swapped*/false))
10079 return V;
10080 return tryFoldSelectIntoOp(N, DAG, FalseVal, TrueVal, /*Swapped*/true);
10081 }
10082
10083 SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
10084 DAGCombinerInfo &DCI) const {
10085 SelectionDAG &DAG = DCI.DAG;
10086
10087 // Helper to call SimplifyDemandedBits on an operand of N where only some low
10088 // bits are demanded. N will be added to the Worklist if it was not deleted.
10089 // Caller should return SDValue(N, 0) if this returns true.
10090 auto SimplifyDemandedLowBitsHelper = [&](unsigned OpNo, unsigned LowBits) {
10091 SDValue Op = N->getOperand(OpNo);
10092 APInt Mask = APInt::getLowBitsSet(Op.getValueSizeInBits(), LowBits);
10093 if (!SimplifyDemandedBits(Op, Mask, DCI))
10094 return false;
10095
10096 if (N->getOpcode() != ISD::DELETED_NODE)
10097 DCI.AddToWorklist(N);
10098 return true;
10099 };
10100
10101 switch (N->getOpcode()) {
10102 default:
10103 break;
10104 case RISCVISD::SplitF64: {
10105 SDValue Op0 = N->getOperand(0);
10106 // If the input to SplitF64 is just BuildPairF64 then the operation is
10107 // redundant. Instead, use BuildPairF64's operands directly.
10108 if (Op0->getOpcode() == RISCVISD::BuildPairF64)
10109 return DCI.CombineTo(N, Op0.getOperand(0), Op0.getOperand(1));
10110
10111 if (Op0->isUndef()) {
10112 SDValue Lo = DAG.getUNDEF(MVT::i32);
10113 SDValue Hi = DAG.getUNDEF(MVT::i32);
10114 return DCI.CombineTo(N, Lo, Hi);
10115 }
10116
10117 SDLoc DL(N);
10118
10119 // It's cheaper to materialise two 32-bit integers than to load a double
10120 // from the constant pool and transfer it to integer registers through the
10121 // stack.
10122 if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Op0)) {
10123 APInt V = C->getValueAPF().bitcastToAPInt();
10124 SDValue Lo = DAG.getConstant(V.trunc(32), DL, MVT::i32);
10125 SDValue Hi = DAG.getConstant(V.lshr(32).trunc(32), DL, MVT::i32);
10126 return DCI.CombineTo(N, Lo, Hi);
10127 }
10128
10129 // This is a target-specific version of a DAGCombine performed in
10130 // DAGCombiner::visitBITCAST. It performs the equivalent of:
10131 // fold (bitconvert (fneg x)) -> (xor (bitconvert x), signbit)
10132 // fold (bitconvert (fabs x)) -> (and (bitconvert x), (not signbit))
10133 if (!(Op0.getOpcode() == ISD::FNEG || Op0.getOpcode() == ISD::FABS) ||
10134 !Op0.getNode()->hasOneUse())
10135 break;
10136 SDValue NewSplitF64 =
10137 DAG.getNode(RISCVISD::SplitF64, DL, DAG.getVTList(MVT::i32, MVT::i32),
10138 Op0.getOperand(0));
10139 SDValue Lo = NewSplitF64.getValue(0);
10140 SDValue Hi = NewSplitF64.getValue(1);
10141 APInt SignBit = APInt::getSignMask(32);
10142 if (Op0.getOpcode() == ISD::FNEG) {
10143 SDValue NewHi = DAG.getNode(ISD::XOR, DL, MVT::i32, Hi,
10144 DAG.getConstant(SignBit, DL, MVT::i32));
10145 return DCI.CombineTo(N, Lo, NewHi);
10146 }
10147 assert(Op0.getOpcode() == ISD::FABS);
10148 SDValue NewHi = DAG.getNode(ISD::AND, DL, MVT::i32, Hi,
10149 DAG.getConstant(~SignBit, DL, MVT::i32));
10150 return DCI.CombineTo(N, Lo, NewHi);
10151 }
10152 case RISCVISD::SLLW:
10153 case RISCVISD::SRAW:
10154 case RISCVISD::SRLW:
10155 case RISCVISD::RORW:
10156 case RISCVISD::ROLW: {
10157 // Only the lower 32 bits of LHS and lower 5 bits of RHS are read.
10158 if (SimplifyDemandedLowBitsHelper(0, 32) ||
10159 SimplifyDemandedLowBitsHelper(1, 5))
10160 return SDValue(N, 0);
10161
10162 break;
10163 }
10164 case RISCVISD::CLZW:
10165 case RISCVISD::CTZW: {
10166 // Only the lower 32 bits of the first operand are read
10167 if (SimplifyDemandedLowBitsHelper(0, 32))
10168 return SDValue(N, 0);
10169 break;
10170 }
10171 case RISCVISD::FMV_X_ANYEXTH:
10172 case RISCVISD::FMV_X_ANYEXTW_RV64: {
10173 SDLoc DL(N);
10174 SDValue Op0 = N->getOperand(0);
10175 MVT VT = N->getSimpleValueType(0);
10176 // If the input to FMV_X_ANYEXTW_RV64 is just FMV_W_X_RV64 then the
10177 // conversion is unnecessary and can be replaced with the FMV_W_X_RV64
10178 // operand. Similar for FMV_X_ANYEXTH and FMV_H_X.
10179 if ((N->getOpcode() == RISCVISD::FMV_X_ANYEXTW_RV64 &&
10180 Op0->getOpcode() == RISCVISD::FMV_W_X_RV64) ||
10181 (N->getOpcode() == RISCVISD::FMV_X_ANYEXTH &&
10182 Op0->getOpcode() == RISCVISD::FMV_H_X)) {
10183 assert(Op0.getOperand(0).getValueType() == VT &&
10184 "Unexpected value type!");
10185 return Op0.getOperand(0);
10186 }
10187
10188 // This is a target-specific version of a DAGCombine performed in
10189 // DAGCombiner::visitBITCAST. It performs the equivalent of:
10190 // fold (bitconvert (fneg x)) -> (xor (bitconvert x), signbit)
10191 // fold (bitconvert (fabs x)) -> (and (bitconvert x), (not signbit))
10192 if (!(Op0.getOpcode() == ISD::FNEG || Op0.getOpcode() == ISD::FABS) ||
10193 !Op0.getNode()->hasOneUse())
10194 break;
10195 SDValue NewFMV = DAG.getNode(N->getOpcode(), DL, VT, Op0.getOperand(0));
10196 unsigned FPBits = N->getOpcode() == RISCVISD::FMV_X_ANYEXTW_RV64 ? 32 : 16;
10197 APInt SignBit = APInt::getSignMask(FPBits).sext(VT.getSizeInBits());
10198 if (Op0.getOpcode() == ISD::FNEG)
10199 return DAG.getNode(ISD::XOR, DL, VT, NewFMV,
10200 DAG.getConstant(SignBit, DL, VT));
10201
10202 assert(Op0.getOpcode() == ISD::FABS);
10203 return DAG.getNode(ISD::AND, DL, VT, NewFMV,
10204 DAG.getConstant(~SignBit, DL, VT));
10205 }
10206 case ISD::ADD:
10207 return performADDCombine(N, DAG, Subtarget);
10208 case ISD::SUB:
10209 return performSUBCombine(N, DAG, Subtarget);
10210 case ISD::AND:
10211 return performANDCombine(N, DCI, Subtarget);
10212 case ISD::OR:
10213 return performORCombine(N, DCI, Subtarget);
10214 case ISD::XOR:
10215 return performXORCombine(N, DAG, Subtarget);
10216 case ISD::FADD:
10217 case ISD::UMAX:
10218 case ISD::UMIN:
10219 case ISD::SMAX:
10220 case ISD::SMIN:
10221 case ISD::FMAXNUM:
10222 case ISD::FMINNUM:
10223 return combineBinOpToReduce(N, DAG, Subtarget);
10224 case ISD::SETCC:
10225 return performSETCCCombine(N, DAG, Subtarget);
10226 case ISD::SIGN_EXTEND_INREG:
10227 return performSIGN_EXTEND_INREGCombine(N, DAG, Subtarget);
10228 case ISD::ZERO_EXTEND:
10229 // Fold (zero_extend (fp_to_uint X)) to prevent forming fcvt+zexti32 during
10230 // type legalization. This is safe because fp_to_uint produces poison if
10231 // it overflows.
10232 if (N->getValueType(0) == MVT::i64 && Subtarget.is64Bit()) {
10233 SDValue Src = N->getOperand(0);
10234 if (Src.getOpcode() == ISD::FP_TO_UINT &&
10235 isTypeLegal(Src.getOperand(0).getValueType()))
10236 return DAG.getNode(ISD::FP_TO_UINT, SDLoc(N), MVT::i64,
10237 Src.getOperand(0));
10238 if (Src.getOpcode() == ISD::STRICT_FP_TO_UINT && Src.hasOneUse() &&
10239 isTypeLegal(Src.getOperand(1).getValueType())) {
10240 SDVTList VTs = DAG.getVTList(MVT::i64, MVT::Other);
10241 SDValue Res = DAG.getNode(ISD::STRICT_FP_TO_UINT, SDLoc(N), VTs,
10242 Src.getOperand(0), Src.getOperand(1));
10243 DCI.CombineTo(N, Res);
10244 DAG.ReplaceAllUsesOfValueWith(Src.getValue(1), Res.getValue(1));
10245 DCI.recursivelyDeleteUnusedNodes(Src.getNode());
10246 return SDValue(N, 0); // Return N so it doesn't get rechecked.
10247 }
10248 }
10249 return SDValue();
10250 case ISD::TRUNCATE:
10251 return performTRUNCATECombine(N, DAG, Subtarget);
10252 case ISD::SELECT:
10253 return performSELECTCombine(N, DAG, Subtarget);
10254 case RISCVISD::SELECT_CC: {
10255 // Try to fold this select_cc into a simpler form.
10256 SDValue LHS = N->getOperand(0);
10257 SDValue RHS = N->getOperand(1);
10258 SDValue CC = N->getOperand(2);
10259 ISD::CondCode CCVal = cast<CondCodeSDNode>(CC)->get();
10260 SDValue TrueV = N->getOperand(3);
10261 SDValue FalseV = N->getOperand(4);
10262 SDLoc DL(N);
10263 EVT VT = N->getValueType(0);
10264
10265 // If the True and False values are the same, we don't need a select_cc.
10266 if (TrueV == FalseV)
10267 return TrueV;
10268
10269 // (select (x < 0), y, z) -> x >> (XLEN - 1) & (y - z) + z
10270 // (select (x >= 0), y, z) -> x >> (XLEN - 1) & (z - y) + y
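// For example (illustrative), (select (x < 0), 5, 3) becomes
// ((x >> (XLEN - 1)) & (5 - 3)) + 3: the arithmetic shift yields all-ones when
// x is negative and zero otherwise, selecting 5 or 3 without a branch.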
10271 if (!Subtarget.hasShortForwardBranchOpt() && isa<ConstantSDNode>(TrueV) &&
10272 isa<ConstantSDNode>(FalseV) && isNullConstant(RHS) &&
10273 (CCVal == ISD::CondCode::SETLT || CCVal == ISD::CondCode::SETGE)) {
10274 if (CCVal == ISD::CondCode::SETGE)
10275 std::swap(TrueV, FalseV);
10276
10277 int64_t TrueSImm = cast<ConstantSDNode>(TrueV)->getSExtValue();
10278 int64_t FalseSImm = cast<ConstantSDNode>(FalseV)->getSExtValue();
10279 // Only handle simm12 constants; anything outside that range would have to be
10280 // materialised into a register anyway.
10281 if (isInt<12>(TrueSImm) && isInt<12>(FalseSImm) &&
10282 isInt<12>(TrueSImm - FalseSImm)) {
10283 SDValue SRA =
10284 DAG.getNode(ISD::SRA, DL, VT, LHS,
10285 DAG.getConstant(Subtarget.getXLen() - 1, DL, VT));
10286 SDValue AND =
10287 DAG.getNode(ISD::AND, DL, VT, SRA,
10288 DAG.getConstant(TrueSImm - FalseSImm, DL, VT));
10289 return DAG.getNode(ISD::ADD, DL, VT, AND, FalseV);
10290 }
10291
10292 if (CCVal == ISD::CondCode::SETGE)
10293 std::swap(TrueV, FalseV);
10294 }
10295
10296 if (combine_CC(LHS, RHS, CC, DL, DAG, Subtarget))
10297 return DAG.getNode(RISCVISD::SELECT_CC, DL, N->getValueType(0),
10298 {LHS, RHS, CC, TrueV, FalseV});
10299
10300 if (!Subtarget.hasShortForwardBranchOpt()) {
10301 // (select c, -1, y) -> -c | y
10302 if (isAllOnesConstant(TrueV)) {
10303 SDValue C = DAG.getSetCC(DL, VT, LHS, RHS, CCVal);
10304 SDValue Neg = DAG.getNegative(C, DL, VT);
10305 return DAG.getNode(ISD::OR, DL, VT, Neg, FalseV);
10306 }
10307 // (select c, y, -1) -> -!c | y
10308 if (isAllOnesConstant(FalseV)) {
10309 SDValue C =
10310 DAG.getSetCC(DL, VT, LHS, RHS, ISD::getSetCCInverse(CCVal, VT));
10311 SDValue Neg = DAG.getNegative(C, DL, VT);
10312 return DAG.getNode(ISD::OR, DL, VT, Neg, TrueV);
10313 }
10314
10315 // (select c, 0, y) -> -!c & y
10316 if (isNullConstant(TrueV)) {
10317 SDValue C =
10318 DAG.getSetCC(DL, VT, LHS, RHS, ISD::getSetCCInverse(CCVal, VT));
10319 SDValue Neg = DAG.getNegative(C, DL, VT);
10320 return DAG.getNode(ISD::AND, DL, VT, Neg, FalseV);
10321 }
10322 // (select c, y, 0) -> -c & y
10323 if (isNullConstant(FalseV)) {
10324 SDValue C = DAG.getSetCC(DL, VT, LHS, RHS, CCVal);
10325 SDValue Neg = DAG.getNegative(C, DL, VT);
10326 return DAG.getNode(ISD::AND, DL, VT, Neg, TrueV);
10327 }
10328 }
10329
10330 return SDValue();
10331 }
10332 case RISCVISD::BR_CC: {
10333 SDValue LHS = N->getOperand(1);
10334 SDValue RHS = N->getOperand(2);
10335 SDValue CC = N->getOperand(3);
10336 SDLoc DL(N);
10337
10338 if (combine_CC(LHS, RHS, CC, DL, DAG, Subtarget))
10339 return DAG.getNode(RISCVISD::BR_CC, DL, N->getValueType(0),
10340 N->getOperand(0), LHS, RHS, CC, N->getOperand(4));
10341
10342 return SDValue();
10343 }
10344 case ISD::BITREVERSE:
10345 return performBITREVERSECombine(N, DAG, Subtarget);
10346 case ISD::FP_TO_SINT:
10347 case ISD::FP_TO_UINT:
10348 return performFP_TO_INTCombine(N, DCI, Subtarget);
10349 case ISD::FP_TO_SINT_SAT:
10350 case ISD::FP_TO_UINT_SAT:
10351 return performFP_TO_INT_SATCombine(N, DCI, Subtarget);
10352 case ISD::FCOPYSIGN: {
10353 EVT VT = N->getValueType(0);
10354 if (!VT.isVector())
10355 break;
10356 // There is a form of VFSGNJ which injects the negated sign of its second
10357 // operand. Try and bubble any FNEG up after the extend/round to produce
10358 // this optimized pattern. Avoid modifying cases where the FP_ROUND has
10359 // TRUNC=1.
10360 SDValue In2 = N->getOperand(1);
10361 // Avoid cases where the extend/round has multiple uses, as duplicating
10362 // those is typically more expensive than removing a fneg.
10363 if (!In2.hasOneUse())
10364 break;
10365 if (In2.getOpcode() != ISD::FP_EXTEND &&
10366 (In2.getOpcode() != ISD::FP_ROUND || In2.getConstantOperandVal(1) != 0))
10367 break;
10368 In2 = In2.getOperand(0);
10369 if (In2.getOpcode() != ISD::FNEG)
10370 break;
10371 SDLoc DL(N);
10372 SDValue NewFPExtRound = DAG.getFPExtendOrRound(In2.getOperand(0), DL, VT);
10373 return DAG.getNode(ISD::FCOPYSIGN, DL, VT, N->getOperand(0),
10374 DAG.getNode(ISD::FNEG, DL, VT, NewFPExtRound));
10375 }
10376 case ISD::MGATHER:
10377 case ISD::MSCATTER:
10378 case ISD::VP_GATHER:
10379 case ISD::VP_SCATTER: {
10380 if (!DCI.isBeforeLegalize())
10381 break;
10382 SDValue Index, ScaleOp;
10383 bool IsIndexSigned = false;
10384 if (const auto *VPGSN = dyn_cast<VPGatherScatterSDNode>(N)) {
10385 Index = VPGSN->getIndex();
10386 ScaleOp = VPGSN->getScale();
10387 IsIndexSigned = VPGSN->isIndexSigned();
10388 assert(!VPGSN->isIndexScaled() &&
10389 "Scaled gather/scatter should not be formed");
10390 } else {
10391 const auto *MGSN = cast<MaskedGatherScatterSDNode>(N);
10392 Index = MGSN->getIndex();
10393 ScaleOp = MGSN->getScale();
10394 IsIndexSigned = MGSN->isIndexSigned();
10395 assert(!MGSN->isIndexScaled() &&
10396 "Scaled gather/scatter should not be formed");
10397
10398 }
10399 EVT IndexVT = Index.getValueType();
10400 MVT XLenVT = Subtarget.getXLenVT();
10401 // RISCV indexed loads only support the "unsigned unscaled" addressing
10402 // mode, so anything else must be manually legalized.
10403 bool NeedsIdxLegalization =
10404 (IsIndexSigned && IndexVT.getVectorElementType().bitsLT(XLenVT));
10405 if (!NeedsIdxLegalization)
10406 break;
10407
10408 SDLoc DL(N);
10409
10410 // Any index legalization should first promote to XLenVT, so we don't lose
10411 // bits when scaling. This may create an illegal index type so we let
10412 // LLVM's legalization take care of the splitting.
10413 // FIXME: LLVM can't split VP_GATHER or VP_SCATTER yet.
10414 if (IndexVT.getVectorElementType().bitsLT(XLenVT)) {
10415 IndexVT = IndexVT.changeVectorElementType(XLenVT);
10416 Index = DAG.getNode(IsIndexSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND,
10417 DL, IndexVT, Index);
10418 }
10419
10420 ISD::MemIndexType NewIndexTy = ISD::UNSIGNED_SCALED;
10421 if (const auto *VPGN = dyn_cast<VPGatherSDNode>(N))
10422 return DAG.getGatherVP(N->getVTList(), VPGN->getMemoryVT(), DL,
10423 {VPGN->getChain(), VPGN->getBasePtr(), Index,
10424 ScaleOp, VPGN->getMask(),
10425 VPGN->getVectorLength()},
10426 VPGN->getMemOperand(), NewIndexTy);
10427 if (const auto *VPSN = dyn_cast<VPScatterSDNode>(N))
10428 return DAG.getScatterVP(N->getVTList(), VPSN->getMemoryVT(), DL,
10429 {VPSN->getChain(), VPSN->getValue(),
10430 VPSN->getBasePtr(), Index, ScaleOp,
10431 VPSN->getMask(), VPSN->getVectorLength()},
10432 VPSN->getMemOperand(), NewIndexTy);
10433 if (const auto *MGN = dyn_cast<MaskedGatherSDNode>(N))
10434 return DAG.getMaskedGather(
10435 N->getVTList(), MGN->getMemoryVT(), DL,
10436 {MGN->getChain(), MGN->getPassThru(), MGN->getMask(),
10437 MGN->getBasePtr(), Index, ScaleOp},
10438 MGN->getMemOperand(), NewIndexTy, MGN->getExtensionType());
10439 const auto *MSN = cast<MaskedScatterSDNode>(N);
10440 return DAG.getMaskedScatter(
10441 N->getVTList(), MSN->getMemoryVT(), DL,
10442 {MSN->getChain(), MSN->getValue(), MSN->getMask(), MSN->getBasePtr(),
10443 Index, ScaleOp},
10444 MSN->getMemOperand(), NewIndexTy, MSN->isTruncatingStore());
10445 }
10446 case RISCVISD::SRA_VL:
10447 case RISCVISD::SRL_VL:
10448 case RISCVISD::SHL_VL: {
10449 SDValue ShAmt = N->getOperand(1);
10450 if (ShAmt.getOpcode() == RISCVISD::SPLAT_VECTOR_SPLIT_I64_VL) {
10451 // We don't need the upper 32 bits of a 64-bit element for a shift amount.
10452 SDLoc DL(N);
10453 SDValue VL = N->getOperand(3);
10454 EVT VT = N->getValueType(0);
10455 ShAmt = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, DAG.getUNDEF(VT),
10456 ShAmt.getOperand(1), VL);
10457 return DAG.getNode(N->getOpcode(), DL, VT, N->getOperand(0), ShAmt,
10458 N->getOperand(2), N->getOperand(3), N->getOperand(4));
10459 }
10460 break;
10461 }
10462 case ISD::SRA:
10463 if (SDValue V = performSRACombine(N, DAG, Subtarget))
10464 return V;
10465 [[fallthrough]];
10466 case ISD::SRL:
10467 case ISD::SHL: {
10468 SDValue ShAmt = N->getOperand(1);
10469 if (ShAmt.getOpcode() == RISCVISD::SPLAT_VECTOR_SPLIT_I64_VL) {
10470 // We don't need the upper 32 bits of a 64-bit element for a shift amount.
10471 SDLoc DL(N);
10472 EVT VT = N->getValueType(0);
10473 ShAmt = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, DAG.getUNDEF(VT),
10474 ShAmt.getOperand(1),
10475 DAG.getRegister(RISCV::X0, Subtarget.getXLenVT()));
10476 return DAG.getNode(N->getOpcode(), DL, VT, N->getOperand(0), ShAmt);
10477 }
10478 break;
10479 }
10480 case RISCVISD::ADD_VL:
10481 case RISCVISD::SUB_VL:
10482 case RISCVISD::VWADD_W_VL:
10483 case RISCVISD::VWADDU_W_VL:
10484 case RISCVISD::VWSUB_W_VL:
10485 case RISCVISD::VWSUBU_W_VL:
10486 case RISCVISD::MUL_VL:
10487 return combineBinOp_VLToVWBinOp_VL(N, DCI);
10488 case RISCVISD::VFMADD_VL:
10489 case RISCVISD::VFNMADD_VL:
10490 case RISCVISD::VFMSUB_VL:
10491 case RISCVISD::VFNMSUB_VL: {
10492 // Fold FNEG_VL into FMA opcodes.
10493 SDValue A = N->getOperand(0);
10494 SDValue B = N->getOperand(1);
10495 SDValue C = N->getOperand(2);
10496 SDValue Mask = N->getOperand(3);
10497 SDValue VL = N->getOperand(4);
10498
10499 auto invertIfNegative = [&Mask, &VL](SDValue &V) {
10500 if (V.getOpcode() == RISCVISD::FNEG_VL && V.getOperand(1) == Mask &&
10501 V.getOperand(2) == VL) {
10502 // Return the negated input.
10503 V = V.getOperand(0);
10504 return true;
10505 }
10506
10507 return false;
10508 };
10509
10510 bool NegA = invertIfNegative(A);
10511 bool NegB = invertIfNegative(B);
10512 bool NegC = invertIfNegative(C);
10513
10514 // If no operands are negated, we're done.
10515 if (!NegA && !NegB && !NegC)
10516 return SDValue();
10517
10518 unsigned NewOpcode = negateFMAOpcode(N->getOpcode(), NegA != NegB, NegC);
10519 return DAG.getNode(NewOpcode, SDLoc(N), N->getValueType(0), A, B, C, Mask,
10520 VL);
10521 }
10522 case ISD::STORE: {
10523 auto *Store = cast<StoreSDNode>(N);
10524 SDValue Val = Store->getValue();
10525 // Combine store of vmv.x.s/vfmv.f.s to vse with VL of 1.
10526 // vfmv.f.s is represented as extract element from 0. Match it late to avoid
10527 // any illegal types.
10528 if (Val.getOpcode() == RISCVISD::VMV_X_S ||
10529 (DCI.isAfterLegalizeDAG() &&
10530 Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
10531 isNullConstant(Val.getOperand(1)))) {
10532 SDValue Src = Val.getOperand(0);
10533 MVT VecVT = Src.getSimpleValueType();
10534 EVT MemVT = Store->getMemoryVT();
10535 // VecVT should be scalable and memory VT should match the element type.
10536 if (VecVT.isScalableVector() &&
10537 MemVT == VecVT.getVectorElementType()) {
10538 SDLoc DL(N);
10539 MVT MaskVT = getMaskTypeFor(VecVT);
10540 return DAG.getStoreVP(
10541 Store->getChain(), DL, Src, Store->getBasePtr(), Store->getOffset(),
10542 DAG.getConstant(1, DL, MaskVT),
10543 DAG.getConstant(1, DL, Subtarget.getXLenVT()), MemVT,
10544 Store->getMemOperand(), Store->getAddressingMode(),
10545 Store->isTruncatingStore(), /*IsCompress*/ false);
10546 }
10547 }
10548
10549 break;
10550 }
10551 case ISD::SPLAT_VECTOR: {
10552 EVT VT = N->getValueType(0);
10553 // Only perform this combine on legal MVT types.
10554 if (!isTypeLegal(VT))
10555 break;
10556 if (auto Gather = matchSplatAsGather(N->getOperand(0), VT.getSimpleVT(), N,
10557 DAG, Subtarget))
10558 return Gather;
10559 break;
10560 }
10561 case RISCVISD::VMV_V_X_VL: {
10562 // Tail agnostic VMV.V.X only demands the vector element bitwidth from the
10563 // scalar input.
10564 unsigned ScalarSize = N->getOperand(1).getValueSizeInBits();
10565 unsigned EltWidth = N->getValueType(0).getScalarSizeInBits();
10566 if (ScalarSize > EltWidth && N->getOperand(0).isUndef())
10567 if (SimplifyDemandedLowBitsHelper(1, EltWidth))
10568 return SDValue(N, 0);
10569
10570 break;
10571 }
10572 case RISCVISD::VFMV_S_F_VL: {
10573 SDValue Src = N->getOperand(1);
10574 // Try to remove vector->scalar->vector if the scalar->vector is inserting
10575 // into an undef vector.
10576 // TODO: Could use a vslide or vmv.v.v for non-undef.
10577 if (N->getOperand(0).isUndef() &&
10578 Src.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
10579 isNullConstant(Src.getOperand(1)) &&
10580 Src.getOperand(0).getValueType().isScalableVector()) {
10581 EVT VT = N->getValueType(0);
10582 EVT SrcVT = Src.getOperand(0).getValueType();
10583 assert(SrcVT.getVectorElementType() == VT.getVectorElementType());
10584 // Widths match, just return the original vector.
10585 if (SrcVT == VT)
10586 return Src.getOperand(0);
10587 // TODO: Use insert_subvector/extract_subvector to change widen/narrow?
10588 }
10589 break;
10590 }
10591 case ISD::INTRINSIC_WO_CHAIN: {
10592 unsigned IntNo = N->getConstantOperandVal(0);
10593 switch (IntNo) {
10594 // By default we do not combine any intrinsic.
10595 default:
10596 return SDValue();
10597 case Intrinsic::riscv_vcpop:
10598 case Intrinsic::riscv_vcpop_mask:
10599 case Intrinsic::riscv_vfirst:
10600 case Intrinsic::riscv_vfirst_mask: {
10601 SDValue VL = N->getOperand(2);
10602 if (IntNo == Intrinsic::riscv_vcpop_mask ||
10603 IntNo == Intrinsic::riscv_vfirst_mask)
10604 VL = N->getOperand(3);
10605 if (!isNullConstant(VL))
10606 return SDValue();
10607 // If VL is 0, vcpop -> li 0, vfirst -> li -1.
10608 SDLoc DL(N);
10609 EVT VT = N->getValueType(0);
10610 if (IntNo == Intrinsic::riscv_vfirst ||
10611 IntNo == Intrinsic::riscv_vfirst_mask)
10612 return DAG.getConstant(-1, DL, VT);
10613 return DAG.getConstant(0, DL, VT);
10614 }
10615 }
10616 }
10617 case ISD::BITCAST: {
10618 assert(Subtarget.useRVVForFixedLengthVectors());
10619 SDValue N0 = N->getOperand(0);
10620 EVT VT = N->getValueType(0);
10621 EVT SrcVT = N0.getValueType();
10622 // If this is a bitcast between a MVT::v4i1/v2i1/v1i1 and an illegal integer
10623 // type, widen both sides to avoid a trip through memory.
10624 if ((SrcVT == MVT::v1i1 || SrcVT == MVT::v2i1 || SrcVT == MVT::v4i1) &&
10625 VT.isScalarInteger()) {
10626 unsigned NumConcats = 8 / SrcVT.getVectorNumElements();
10627 SmallVector<SDValue, 4> Ops(NumConcats, DAG.getUNDEF(SrcVT));
10628 Ops[0] = N0;
10629 SDLoc DL(N);
10630 N0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v8i1, Ops);
10631 N0 = DAG.getBitcast(MVT::i8, N0);
10632 return DAG.getNode(ISD::TRUNCATE, DL, VT, N0);
10633 }
10634
10635 return SDValue();
10636 }
10637 }
10638
10639 return SDValue();
10640 }
10641
10642 bool RISCVTargetLowering::isDesirableToCommuteWithShift(
10643 const SDNode *N, CombineLevel Level) const {
10644 assert((N->getOpcode() == ISD::SHL || N->getOpcode() == ISD::SRA ||
10645 N->getOpcode() == ISD::SRL) &&
10646 "Expected shift op");
10647
10648 // The following folds are only desirable if `(OP _, c1 << c2)` can be
10649 // materialised in fewer instructions than `(OP _, c1)`:
10650 //
10651 // (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2)
10652 // (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2)
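// For example (illustrative), with c1 = 3 and c2 = 4, c1 << c2 = 48 still fits
// in an ADDI immediate, so the fold is allowed; with c1 = 2047 and c2 = 4,
// c1 << c2 = 32752 no longer fits while c1 itself does, so the fold is
// rejected to keep the cheaper constant.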
10653 SDValue N0 = N->getOperand(0);
10654 EVT Ty = N0.getValueType();
10655 if (Ty.isScalarInteger() &&
10656 (N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::OR)) {
10657 auto *C1 = dyn_cast<ConstantSDNode>(N0->getOperand(1));
10658 auto *C2 = dyn_cast<ConstantSDNode>(N->getOperand(1));
10659 if (C1 && C2) {
10660 const APInt &C1Int = C1->getAPIntValue();
10661 APInt ShiftedC1Int = C1Int << C2->getAPIntValue();
10662
10663 // We can materialise `c1 << c2` into an add immediate, so it's "free",
10664 // and the combine should happen, to potentially allow further combines
10665 // later.
10666 if (ShiftedC1Int.getMinSignedBits() <= 64 &&
10667 isLegalAddImmediate(ShiftedC1Int.getSExtValue()))
10668 return true;
10669
10670 // We can materialise `c1` in an add immediate, so it's "free", and the
10671 // combine should be prevented.
10672 if (C1Int.getMinSignedBits() <= 64 &&
10673 isLegalAddImmediate(C1Int.getSExtValue()))
10674 return false;
10675
10676 // Neither constant will fit into an immediate, so find materialisation
10677 // costs.
10678 int C1Cost = RISCVMatInt::getIntMatCost(C1Int, Ty.getSizeInBits(),
10679 Subtarget.getFeatureBits(),
10680 /*CompressionCost*/true);
10681 int ShiftedC1Cost = RISCVMatInt::getIntMatCost(
10682 ShiftedC1Int, Ty.getSizeInBits(), Subtarget.getFeatureBits(),
10683 /*CompressionCost*/true);
10684
10685 // Materialising `c1` is cheaper than materialising `c1 << c2`, so the
10686 // combine should be prevented.
10687 if (C1Cost < ShiftedC1Cost)
10688 return false;
10689 }
10690 }
10691 return true;
10692 }
10693
10694 bool RISCVTargetLowering::targetShrinkDemandedConstant(
10695 SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
10696 TargetLoweringOpt &TLO) const {
10697 // Delay this optimization as late as possible.
10698 if (!TLO.LegalOps)
10699 return false;
10700
10701 EVT VT = Op.getValueType();
10702 if (VT.isVector())
10703 return false;
10704
10705 unsigned Opcode = Op.getOpcode();
10706 if (Opcode != ISD::AND && Opcode != ISD::OR && Opcode != ISD::XOR)
10707 return false;
10708
10709 ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op.getOperand(1));
10710 if (!C)
10711 return false;
10712
10713 const APInt &Mask = C->getAPIntValue();
10714
10715 // Clear all non-demanded bits initially.
10716 APInt ShrunkMask = Mask & DemandedBits;
10717
10718 // Try to make a smaller immediate by setting undemanded bits.
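// For example (illustrative), if Op is (and X, 0xfff0) and only bits 4..15 are
// demanded, the mask can be widened to 0xffff, which is handled below and
// selects to zext.h (or slli+srli) instead of materialising a constant.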
10719
10720 APInt ExpandedMask = Mask | ~DemandedBits;
10721
10722 auto IsLegalMask = [ShrunkMask, ExpandedMask](const APInt &Mask) -> bool {
10723 return ShrunkMask.isSubsetOf(Mask) && Mask.isSubsetOf(ExpandedMask);
10724 };
10725 auto UseMask = [Mask, Op, &TLO](const APInt &NewMask) -> bool {
10726 if (NewMask == Mask)
10727 return true;
10728 SDLoc DL(Op);
10729 SDValue NewC = TLO.DAG.getConstant(NewMask, DL, Op.getValueType());
10730 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), DL, Op.getValueType(),
10731 Op.getOperand(0), NewC);
10732 return TLO.CombineTo(Op, NewOp);
10733 };
10734
10735 // If the shrunk mask fits in sign extended 12 bits, let the target
10736 // independent code apply it.
10737 if (ShrunkMask.isSignedIntN(12))
10738 return false;
10739
10740 // And has a few special cases for zext.
10741 if (Opcode == ISD::AND) {
10742 // Preserve (and X, 0xffff), if zext.h exists use zext.h,
10743 // otherwise use SLLI + SRLI.
10744 APInt NewMask = APInt(Mask.getBitWidth(), 0xffff);
10745 if (IsLegalMask(NewMask))
10746 return UseMask(NewMask);
10747
10748 // Try to preserve (and X, 0xffffffff), the (zext_inreg X, i32) pattern.
10749 if (VT == MVT::i64) {
10750 APInt NewMask = APInt(64, 0xffffffff);
10751 if (IsLegalMask(NewMask))
10752 return UseMask(NewMask);
10753 }
10754 }
10755
10756 // For the remaining optimizations, we need to be able to make a negative
10757 // number through a combination of mask and undemanded bits.
10758 if (!ExpandedMask.isNegative())
10759 return false;
10760
10761 // Compute the fewest number of bits we need to represent the negative number.
10762 unsigned MinSignedBits = ExpandedMask.getMinSignedBits();
10763
10764 // Try to make a 12 bit negative immediate. If that fails try to make a 32
10765 // bit negative immediate unless the shrunk immediate already fits in 32 bits.
10766 // If we can't create a simm12, we shouldn't change opaque constants.
10767 APInt NewMask = ShrunkMask;
10768 if (MinSignedBits <= 12)
10769 NewMask.setBitsFrom(11);
10770 else if (!C->isOpaque() && MinSignedBits <= 32 && !ShrunkMask.isSignedIntN(32))
10771 NewMask.setBitsFrom(31);
10772 else
10773 return false;
10774
10775 // Check that our new mask is a subset of the demanded mask.
10776 assert(IsLegalMask(NewMask));
10777 return UseMask(NewMask);
10778 }
10779
10780 static uint64_t computeGREVOrGORC(uint64_t x, unsigned ShAmt, bool IsGORC) {
10781 static const uint64_t GREVMasks[] = {
10782 0x5555555555555555ULL, 0x3333333333333333ULL, 0x0F0F0F0F0F0F0F0FULL,
10783 0x00FF00FF00FF00FFULL, 0x0000FFFF0000FFFFULL, 0x00000000FFFFFFFFULL};
10784
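// Illustrative example: with ShAmt == 7 and IsGORC == false this reverses the
// bits within each byte (brev8), e.g.
// computeGREVOrGORC(0x0102030405060708, 7, false) == 0x8040C020A060E010.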
10785 for (unsigned Stage = 0; Stage != 6; ++Stage) {
10786 unsigned Shift = 1 << Stage;
10787 if (ShAmt & Shift) {
10788 uint64_t Mask = GREVMasks[Stage];
10789 uint64_t Res = ((x & Mask) << Shift) | ((x >> Shift) & Mask);
10790 if (IsGORC)
10791 Res |= x;
10792 x = Res;
10793 }
10794 }
10795
10796 return x;
10797 }
10798
10799 void RISCVTargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
10800 KnownBits &Known,
10801 const APInt &DemandedElts,
10802 const SelectionDAG &DAG,
10803 unsigned Depth) const {
10804 unsigned BitWidth = Known.getBitWidth();
10805 unsigned Opc = Op.getOpcode();
10806 assert((Opc >= ISD::BUILTIN_OP_END ||
10807 Opc == ISD::INTRINSIC_WO_CHAIN ||
10808 Opc == ISD::INTRINSIC_W_CHAIN ||
10809 Opc == ISD::INTRINSIC_VOID) &&
10810 "Should use MaskedValueIsZero if you don't know whether Op"
10811 " is a target node!");
10812
10813 Known.resetAll();
10814 switch (Opc) {
10815 default: break;
10816 case RISCVISD::SELECT_CC: {
10817 Known = DAG.computeKnownBits(Op.getOperand(4), Depth + 1);
10818 // If we don't know any bits, early out.
10819 if (Known.isUnknown())
10820 break;
10821 KnownBits Known2 = DAG.computeKnownBits(Op.getOperand(3), Depth + 1);
10822
10823 // Only known if known in both the LHS and RHS.
10824 Known = KnownBits::commonBits(Known, Known2);
10825 break;
10826 }
10827 case RISCVISD::REMUW: {
10828 KnownBits Known2;
10829 Known = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
10830 Known2 = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
10831 // We only care about the lower 32 bits.
10832 Known = KnownBits::urem(Known.trunc(32), Known2.trunc(32));
10833 // Restore the original width by sign extending.
10834 Known = Known.sext(BitWidth);
10835 break;
10836 }
10837 case RISCVISD::DIVUW: {
10838 KnownBits Known2;
10839 Known = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
10840 Known2 = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
10841 // We only care about the lower 32 bits.
10842 Known = KnownBits::udiv(Known.trunc(32), Known2.trunc(32));
10843 // Restore the original width by sign extending.
10844 Known = Known.sext(BitWidth);
10845 break;
10846 }
10847 case RISCVISD::CTZW: {
10848 KnownBits Known2 = DAG.computeKnownBits(Op.getOperand(0), Depth + 1);
10849 unsigned PossibleTZ = Known2.trunc(32).countMaxTrailingZeros();
10850 unsigned LowBits = llvm::bit_width(PossibleTZ);
10851 Known.Zero.setBitsFrom(LowBits);
10852 break;
10853 }
10854 case RISCVISD::CLZW: {
10855 KnownBits Known2 = DAG.computeKnownBits(Op.getOperand(0), Depth + 1);
10856 unsigned PossibleLZ = Known2.trunc(32).countMaxLeadingZeros();
10857 unsigned LowBits = llvm::bit_width(PossibleLZ);
10858 Known.Zero.setBitsFrom(LowBits);
10859 break;
10860 }
10861 case RISCVISD::BREV8:
10862 case RISCVISD::ORC_B: {
10863 // FIXME: This is based on the non-ratified Zbp GREV and GORC where a
10864 // control value of 7 is equivalent to brev8 and orc.b.
10865 Known = DAG.computeKnownBits(Op.getOperand(0), Depth + 1);
10866 bool IsGORC = Op.getOpcode() == RISCVISD::ORC_B;
10867 // To compute zeros, we need to invert the value and invert it back after.
10868 Known.Zero =
10869 ~computeGREVOrGORC(~Known.Zero.getZExtValue(), 7, IsGORC);
10870 Known.One = computeGREVOrGORC(Known.One.getZExtValue(), 7, IsGORC);
10871 break;
10872 }
10873 case RISCVISD::READ_VLENB: {
10874 // We can use the minimum and maximum VLEN values to bound VLENB. We
10875 // know VLEN must be a power of two.
10876 const unsigned MinVLenB = Subtarget.getRealMinVLen() / 8;
10877 const unsigned MaxVLenB = Subtarget.getRealMaxVLen() / 8;
10878 assert(MinVLenB > 0 && "READ_VLENB without vector extension enabled?");
10879 Known.Zero.setLowBits(Log2_32(MinVLenB));
10880 Known.Zero.setBitsFrom(Log2_32(MaxVLenB)+1);
10881 if (MaxVLenB == MinVLenB)
10882 Known.One.setBit(Log2_32(MinVLenB));
10883 break;
10884 }
10885 case ISD::INTRINSIC_W_CHAIN:
10886 case ISD::INTRINSIC_WO_CHAIN: {
10887 unsigned IntNo =
10888 Op.getConstantOperandVal(Opc == ISD::INTRINSIC_WO_CHAIN ? 0 : 1);
10889 switch (IntNo) {
10890 default:
10891 // We can't do anything for most intrinsics.
10892 break;
10893 case Intrinsic::riscv_vsetvli:
10894 case Intrinsic::riscv_vsetvlimax:
10895 case Intrinsic::riscv_vsetvli_opt:
10896 case Intrinsic::riscv_vsetvlimax_opt:
10897 // Assume that VL output is positive and would fit in an int32_t.
10898 // TODO: VLEN might be capped at 16 bits in a future V spec update.
10899 if (BitWidth >= 32)
10900 Known.Zero.setBitsFrom(31);
10901 break;
10902 }
10903 break;
10904 }
10905 }
10906 }
10907
10908 unsigned RISCVTargetLowering::ComputeNumSignBitsForTargetNode(
10909 SDValue Op, const APInt &DemandedElts, const SelectionDAG &DAG,
10910 unsigned Depth) const {
10911 switch (Op.getOpcode()) {
10912 default:
10913 break;
10914 case RISCVISD::SELECT_CC: {
10915 unsigned Tmp =
10916 DAG.ComputeNumSignBits(Op.getOperand(3), DemandedElts, Depth + 1);
10917 if (Tmp == 1) return 1; // Early out.
10918 unsigned Tmp2 =
10919 DAG.ComputeNumSignBits(Op.getOperand(4), DemandedElts, Depth + 1);
10920 return std::min(Tmp, Tmp2);
10921 }
10922 case RISCVISD::ABSW: {
10923 // We expand this at isel to negw+max. The result will have 33 sign bits
10924 // if the input has at least 33 sign bits.
10925 unsigned Tmp =
10926 DAG.ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
10927 if (Tmp < 33) return 1;
10928 return 33;
10929 }
10930 case RISCVISD::SLLW:
10931 case RISCVISD::SRAW:
10932 case RISCVISD::SRLW:
10933 case RISCVISD::DIVW:
10934 case RISCVISD::DIVUW:
10935 case RISCVISD::REMUW:
10936 case RISCVISD::ROLW:
10937 case RISCVISD::RORW:
10938 case RISCVISD::FCVT_W_RV64:
10939 case RISCVISD::FCVT_WU_RV64:
10940 case RISCVISD::STRICT_FCVT_W_RV64:
10941 case RISCVISD::STRICT_FCVT_WU_RV64:
10942 // TODO: As the result is sign-extended, this is conservatively correct. A
10943 // more precise answer could be calculated for SRAW depending on known
10944 // bits in the shift amount.
10945 return 33;
10946 case RISCVISD::VMV_X_S: {
10947 // The number of sign bits of the scalar result is computed by obtaining the
10948 // element type of the input vector operand, subtracting its width from the
10949 // XLEN, and then adding one (sign bit within the element type). If the
10950 // element type is wider than XLen, the least-significant XLEN bits are
10951 // taken.
10952 unsigned XLen = Subtarget.getXLen();
10953 unsigned EltBits = Op.getOperand(0).getScalarValueSizeInBits();
10954 if (EltBits <= XLen)
10955 return XLen - EltBits + 1;
10956 break;
10957 }
10958 case ISD::INTRINSIC_W_CHAIN: {
10959 unsigned IntNo = Op.getConstantOperandVal(1);
10960 switch (IntNo) {
10961 default:
10962 break;
10963 case Intrinsic::riscv_masked_atomicrmw_xchg_i64:
10964 case Intrinsic::riscv_masked_atomicrmw_add_i64:
10965 case Intrinsic::riscv_masked_atomicrmw_sub_i64:
10966 case Intrinsic::riscv_masked_atomicrmw_nand_i64:
10967 case Intrinsic::riscv_masked_atomicrmw_max_i64:
10968 case Intrinsic::riscv_masked_atomicrmw_min_i64:
10969 case Intrinsic::riscv_masked_atomicrmw_umax_i64:
10970 case Intrinsic::riscv_masked_atomicrmw_umin_i64:
10971 case Intrinsic::riscv_masked_cmpxchg_i64:
10972 // riscv_masked_{atomicrmw_*,cmpxchg} intrinsics represent an emulated
10973 // narrow atomic operation. These are implemented using atomic
10974 // operations at the minimum supported atomicrmw/cmpxchg width whose
10975 // result is then sign extended to XLEN. With +A, the minimum width is
10976 // 32 for both RV64 and RV32.
10977 assert(Subtarget.getXLen() == 64);
10978 assert(getMinCmpXchgSizeInBits() == 32);
10979 assert(Subtarget.hasStdExtA());
10980 return 33;
10981 }
10982 }
10983 }
10984
10985 return 1;
10986 }
10987
10988 const Constant *
10989 RISCVTargetLowering::getTargetConstantFromLoad(LoadSDNode *Ld) const {
10990 assert(Ld && "Unexpected null LoadSDNode");
10991 if (!ISD::isNormalLoad(Ld))
10992 return nullptr;
10993
10994 SDValue Ptr = Ld->getBasePtr();
10995
10996 // Only constant pools with no offset are supported.
10997 auto GetSupportedConstantPool = [](SDValue Ptr) -> ConstantPoolSDNode * {
10998 auto *CNode = dyn_cast<ConstantPoolSDNode>(Ptr);
10999 if (!CNode || CNode->isMachineConstantPoolEntry() ||
11000 CNode->getOffset() != 0)
11001 return nullptr;
11002
11003 return CNode;
11004 };
11005
11006 // Simple case, LLA.
11007 if (Ptr.getOpcode() == RISCVISD::LLA) {
11008 auto *CNode = GetSupportedConstantPool(Ptr);
11009 if (!CNode || CNode->getTargetFlags() != 0)
11010 return nullptr;
11011
11012 return CNode->getConstVal();
11013 }
11014
11015 // Look for a HI and ADD_LO pair.
11016 if (Ptr.getOpcode() != RISCVISD::ADD_LO ||
11017 Ptr.getOperand(0).getOpcode() != RISCVISD::HI)
11018 return nullptr;
11019
11020 auto *CNodeLo = GetSupportedConstantPool(Ptr.getOperand(1));
11021 auto *CNodeHi = GetSupportedConstantPool(Ptr.getOperand(0).getOperand(0));
11022
11023 if (!CNodeLo || CNodeLo->getTargetFlags() != RISCVII::MO_LO ||
11024 !CNodeHi || CNodeHi->getTargetFlags() != RISCVII::MO_HI)
11025 return nullptr;
11026
11027 if (CNodeLo->getConstVal() != CNodeHi->getConstVal())
11028 return nullptr;
11029
11030 return CNodeLo->getConstVal();
11031 }
11032
11033 static MachineBasicBlock *emitReadCycleWidePseudo(MachineInstr &MI,
11034 MachineBasicBlock *BB) {
11035 assert(MI.getOpcode() == RISCV::ReadCycleWide && "Unexpected instruction");
11036
11037 // To read the 64-bit cycle CSR on a 32-bit target, we read the two halves.
11038 // Should the count have wrapped while it was being read, we need to try
11039 // again.
11040 // ...
11041 // read:
11042 // rdcycleh x3 # load high word of cycle
11043 // rdcycle x2 # load low word of cycle
11044 // rdcycleh x4 # load high word of cycle
11045 // bne x3, x4, read # check if high word reads match, otherwise try again
11046 // ...
11047
11048 MachineFunction &MF = *BB->getParent();
11049 const BasicBlock *LLVM_BB = BB->getBasicBlock();
11050 MachineFunction::iterator It = ++BB->getIterator();
11051
11052 MachineBasicBlock *LoopMBB = MF.CreateMachineBasicBlock(LLVM_BB);
11053 MF.insert(It, LoopMBB);
11054
11055 MachineBasicBlock *DoneMBB = MF.CreateMachineBasicBlock(LLVM_BB);
11056 MF.insert(It, DoneMBB);
11057
11058 // Transfer the remainder of BB and its successor edges to DoneMBB.
11059 DoneMBB->splice(DoneMBB->begin(), BB,
11060 std::next(MachineBasicBlock::iterator(MI)), BB->end());
11061 DoneMBB->transferSuccessorsAndUpdatePHIs(BB);
11062
11063 BB->addSuccessor(LoopMBB);
11064
11065 MachineRegisterInfo &RegInfo = MF.getRegInfo();
11066 Register ReadAgainReg = RegInfo.createVirtualRegister(&RISCV::GPRRegClass);
11067 Register LoReg = MI.getOperand(0).getReg();
11068 Register HiReg = MI.getOperand(1).getReg();
11069 DebugLoc DL = MI.getDebugLoc();
11070
11071 const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
11072 BuildMI(LoopMBB, DL, TII->get(RISCV::CSRRS), HiReg)
11073 .addImm(RISCVSysReg::lookupSysRegByName("CYCLEH")->Encoding)
11074 .addReg(RISCV::X0);
11075 BuildMI(LoopMBB, DL, TII->get(RISCV::CSRRS), LoReg)
11076 .addImm(RISCVSysReg::lookupSysRegByName("CYCLE")->Encoding)
11077 .addReg(RISCV::X0);
11078 BuildMI(LoopMBB, DL, TII->get(RISCV::CSRRS), ReadAgainReg)
11079 .addImm(RISCVSysReg::lookupSysRegByName("CYCLEH")->Encoding)
11080 .addReg(RISCV::X0);
11081
11082 BuildMI(LoopMBB, DL, TII->get(RISCV::BNE))
11083 .addReg(HiReg)
11084 .addReg(ReadAgainReg)
11085 .addMBB(LoopMBB);
11086
11087 LoopMBB->addSuccessor(LoopMBB);
11088 LoopMBB->addSuccessor(DoneMBB);
11089
11090 MI.eraseFromParent();
11091
11092 return DoneMBB;
11093 }
11094
11095 static MachineBasicBlock *emitSplitF64Pseudo(MachineInstr &MI,
11096 MachineBasicBlock *BB) {
11097 assert(MI.getOpcode() == RISCV::SplitF64Pseudo && "Unexpected instruction");
11098
11099 MachineFunction &MF = *BB->getParent();
11100 DebugLoc DL = MI.getDebugLoc();
11101 const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo();
11102 const TargetRegisterInfo *RI = MF.getSubtarget().getRegisterInfo();
11103 Register LoReg = MI.getOperand(0).getReg();
11104 Register HiReg = MI.getOperand(1).getReg();
11105 Register SrcReg = MI.getOperand(2).getReg();
11106 const TargetRegisterClass *SrcRC = &RISCV::FPR64RegClass;
11107 int FI = MF.getInfo<RISCVMachineFunctionInfo>()->getMoveF64FrameIndex(MF);
11108
11109 TII.storeRegToStackSlot(*BB, MI, SrcReg, MI.getOperand(2).isKill(), FI, SrcRC,
11110 RI, Register());
11111 MachinePointerInfo MPI = MachinePointerInfo::getFixedStack(MF, FI);
11112 MachineMemOperand *MMOLo =
11113 MF.getMachineMemOperand(MPI, MachineMemOperand::MOLoad, 4, Align(8));
11114 MachineMemOperand *MMOHi = MF.getMachineMemOperand(
11115 MPI.getWithOffset(4), MachineMemOperand::MOLoad, 4, Align(8));
11116 BuildMI(*BB, MI, DL, TII.get(RISCV::LW), LoReg)
11117 .addFrameIndex(FI)
11118 .addImm(0)
11119 .addMemOperand(MMOLo);
11120 BuildMI(*BB, MI, DL, TII.get(RISCV::LW), HiReg)
11121 .addFrameIndex(FI)
11122 .addImm(4)
11123 .addMemOperand(MMOHi);
11124 MI.eraseFromParent(); // The pseudo instruction is gone now.
11125 return BB;
11126 }
11127
11128 static MachineBasicBlock *emitBuildPairF64Pseudo(MachineInstr &MI,
11129 MachineBasicBlock *BB) {
11130 assert(MI.getOpcode() == RISCV::BuildPairF64Pseudo &&
11131 "Unexpected instruction");
11132
11133 MachineFunction &MF = *BB->getParent();
11134 DebugLoc DL = MI.getDebugLoc();
11135 const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo();
11136 const TargetRegisterInfo *RI = MF.getSubtarget().getRegisterInfo();
11137 Register DstReg = MI.getOperand(0).getReg();
11138 Register LoReg = MI.getOperand(1).getReg();
11139 Register HiReg = MI.getOperand(2).getReg();
11140 const TargetRegisterClass *DstRC = &RISCV::FPR64RegClass;
11141 int FI = MF.getInfo<RISCVMachineFunctionInfo>()->getMoveF64FrameIndex(MF);
11142
11143 MachinePointerInfo MPI = MachinePointerInfo::getFixedStack(MF, FI);
11144 MachineMemOperand *MMOLo =
11145 MF.getMachineMemOperand(MPI, MachineMemOperand::MOStore, 4, Align(8));
11146 MachineMemOperand *MMOHi = MF.getMachineMemOperand(
11147 MPI.getWithOffset(4), MachineMemOperand::MOStore, 4, Align(8));
11148 BuildMI(*BB, MI, DL, TII.get(RISCV::SW))
11149 .addReg(LoReg, getKillRegState(MI.getOperand(1).isKill()))
11150 .addFrameIndex(FI)
11151 .addImm(0)
11152 .addMemOperand(MMOLo);
11153 BuildMI(*BB, MI, DL, TII.get(RISCV::SW))
11154 .addReg(HiReg, getKillRegState(MI.getOperand(2).isKill()))
11155 .addFrameIndex(FI)
11156 .addImm(4)
11157 .addMemOperand(MMOHi);
11158 TII.loadRegFromStackSlot(*BB, MI, DstReg, FI, DstRC, RI, Register());
11159 MI.eraseFromParent(); // The pseudo instruction is gone now.
11160 return BB;
11161 }
11162
11163 static bool isSelectPseudo(MachineInstr &MI) {
11164 switch (MI.getOpcode()) {
11165 default:
11166 return false;
11167 case RISCV::Select_GPR_Using_CC_GPR:
11168 case RISCV::Select_FPR16_Using_CC_GPR:
11169 case RISCV::Select_FPR32_Using_CC_GPR:
11170 case RISCV::Select_FPR64_Using_CC_GPR:
11171 return true;
11172 }
11173 }
11174
11175 static MachineBasicBlock *emitQuietFCMP(MachineInstr &MI, MachineBasicBlock *BB,
11176 unsigned RelOpcode, unsigned EqOpcode,
11177 const RISCVSubtarget &Subtarget) {
11178 DebugLoc DL = MI.getDebugLoc();
11179 Register DstReg = MI.getOperand(0).getReg();
11180 Register Src1Reg = MI.getOperand(1).getReg();
11181 Register Src2Reg = MI.getOperand(2).getReg();
11182 MachineRegisterInfo &MRI = BB->getParent()->getRegInfo();
11183 Register SavedFFlags = MRI.createVirtualRegister(&RISCV::GPRRegClass);
11184 const TargetInstrInfo &TII = *BB->getParent()->getSubtarget().getInstrInfo();
11185
11186 // Save the current FFLAGS.
11187 BuildMI(*BB, MI, DL, TII.get(RISCV::ReadFFLAGS), SavedFFlags);
11188
11189 auto MIB = BuildMI(*BB, MI, DL, TII.get(RelOpcode), DstReg)
11190 .addReg(Src1Reg)
11191 .addReg(Src2Reg);
11192 if (MI.getFlag(MachineInstr::MIFlag::NoFPExcept))
11193 MIB->setFlag(MachineInstr::MIFlag::NoFPExcept);
11194
11195 // Restore the FFLAGS.
11196 BuildMI(*BB, MI, DL, TII.get(RISCV::WriteFFLAGS))
11197 .addReg(SavedFFlags, RegState::Kill);
11198
11199 // Issue a dummy FEQ opcode to raise an exception for signaling NaNs.
11200 auto MIB2 = BuildMI(*BB, MI, DL, TII.get(EqOpcode), RISCV::X0)
11201 .addReg(Src1Reg, getKillRegState(MI.getOperand(1).isKill()))
11202 .addReg(Src2Reg, getKillRegState(MI.getOperand(2).isKill()));
11203 if (MI.getFlag(MachineInstr::MIFlag::NoFPExcept))
11204 MIB2->setFlag(MachineInstr::MIFlag::NoFPExcept);
11205
11206 // Erase the pseudoinstruction.
11207 MI.eraseFromParent();
11208 return BB;
11209 }
11210
11211 static MachineBasicBlock *
11212 EmitLoweredCascadedSelect(MachineInstr &First, MachineInstr &Second,
11213 MachineBasicBlock *ThisMBB,
11214 const RISCVSubtarget &Subtarget) {
11215 // Select_FPRX_ (rs1, rs2, imm, rs4, (Select_FPRX_ rs1, rs2, imm, rs4, rs5))
11216 // Without this, custom-inserter would have generated:
11217 //
11218 // A
11219 // | \
11220 // | B
11221 // | /
11222 // C
11223 // | \
11224 // | D
11225 // | /
11226 // E
11227 //
11228 // A: X = ...; Y = ...
11229 // B: empty
11230 // C: Z = PHI [X, A], [Y, B]
11231 // D: empty
11232 // E: PHI [X, C], [Z, D]
11233 //
11234 // If we lower both Select_FPRX_ in a single step, we can instead generate:
11235 //
11236 // A
11237 // | \
11238 // | C
11239 // | /|
11240 // |/ |
11241 // | |
11242 // | D
11243 // | /
11244 // E
11245 //
11246 // A: X = ...; Y = ...
11247 // D: empty
11248 // E: PHI [X, A], [X, C], [Y, D]
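 //
 // In source form this corresponds (illustratively) to a nested select such as
 //
 //   %z   = select %cond1, %x, %y
 //   %res = select %cond2, %x, %z
 //
 // where the inner select feeds the false operand of the outer one, so only
 // the three-way PHI shown in E is needed.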
11249
11250 const RISCVInstrInfo &TII = *Subtarget.getInstrInfo();
11251 const DebugLoc &DL = First.getDebugLoc();
11252 const BasicBlock *LLVM_BB = ThisMBB->getBasicBlock();
11253 MachineFunction *F = ThisMBB->getParent();
11254 MachineBasicBlock *FirstMBB = F->CreateMachineBasicBlock(LLVM_BB);
11255 MachineBasicBlock *SecondMBB = F->CreateMachineBasicBlock(LLVM_BB);
11256 MachineBasicBlock *SinkMBB = F->CreateMachineBasicBlock(LLVM_BB);
11257 MachineFunction::iterator It = ++ThisMBB->getIterator();
11258 F->insert(It, FirstMBB);
11259 F->insert(It, SecondMBB);
11260 F->insert(It, SinkMBB);
11261
11262 // Transfer the remainder of ThisMBB and its successor edges to SinkMBB.
11263 SinkMBB->splice(SinkMBB->begin(), ThisMBB,
11264 std::next(MachineBasicBlock::iterator(First)),
11265 ThisMBB->end());
11266 SinkMBB->transferSuccessorsAndUpdatePHIs(ThisMBB);
11267
11268 // Fallthrough block for ThisMBB.
11269 ThisMBB->addSuccessor(FirstMBB);
11270 // Fallthrough block for FirstMBB.
11271 FirstMBB->addSuccessor(SecondMBB);
11272 ThisMBB->addSuccessor(SinkMBB);
11273 FirstMBB->addSuccessor(SinkMBB);
11274 // This is fallthrough.
11275 SecondMBB->addSuccessor(SinkMBB);
11276
11277 auto FirstCC = static_cast<RISCVCC::CondCode>(First.getOperand(3).getImm());
11278 Register FLHS = First.getOperand(1).getReg();
11279 Register FRHS = First.getOperand(2).getReg();
11280 // Insert appropriate branch.
11281 BuildMI(FirstMBB, DL, TII.getBrCond(FirstCC))
11282 .addReg(FLHS)
11283 .addReg(FRHS)
11284 .addMBB(SinkMBB);
11285
11286 Register SLHS = Second.getOperand(1).getReg();
11287 Register SRHS = Second.getOperand(2).getReg();
11288 Register Op1Reg4 = First.getOperand(4).getReg();
11289 Register Op1Reg5 = First.getOperand(5).getReg();
11290
11291 auto SecondCC = static_cast<RISCVCC::CondCode>(Second.getOperand(3).getImm());
11292 // Insert appropriate branch.
11293 BuildMI(ThisMBB, DL, TII.getBrCond(SecondCC))
11294 .addReg(SLHS)
11295 .addReg(SRHS)
11296 .addMBB(SinkMBB);
11297
11298 Register DestReg = Second.getOperand(0).getReg();
11299 Register Op2Reg4 = Second.getOperand(4).getReg();
11300 BuildMI(*SinkMBB, SinkMBB->begin(), DL, TII.get(RISCV::PHI), DestReg)
11301 .addReg(Op2Reg4)
11302 .addMBB(ThisMBB)
11303 .addReg(Op1Reg4)
11304 .addMBB(FirstMBB)
11305 .addReg(Op1Reg5)
11306 .addMBB(SecondMBB);
11307
11308 // Now remove the Select_FPRX_s.
11309 First.eraseFromParent();
11310 Second.eraseFromParent();
11311 return SinkMBB;
11312 }
11313
11314 static MachineBasicBlock *emitSelectPseudo(MachineInstr &MI,
11315 MachineBasicBlock *BB,
11316 const RISCVSubtarget &Subtarget) {
11317 // To "insert" Select_* instructions, we actually have to insert the triangle
11318 // control-flow pattern. The incoming instructions know the destination vreg
11319 // to set, the condition code register to branch on, the true/false values to
11320 // select between, and the condcode to use to select the appropriate branch.
11321 //
11322 // We produce the following control flow:
11323 // HeadMBB
11324 // | \
11325 // | IfFalseMBB
11326 // | /
11327 // TailMBB
11328 //
11329 // When we find a sequence of selects we attempt to optimize their emission
11330 // by sharing the control flow. Currently we only handle cases where we have
11331 // multiple selects with the exact same condition (same LHS, RHS and CC).
11332 // The selects may be interleaved with other instructions if the other
11333 // instructions meet some requirements we deem safe:
11334 // - They are not pseudo instructions.
11335 // - They are debug instructions, or otherwise:
11336 // - They do not have side-effects, do not access memory, and their inputs do
11337 // not depend on the results of the select pseudo-instructions.
11338 // The TrueV/FalseV operands of the selects cannot depend on the result of
11339 // previous selects in the sequence.
11340 // These conditions could be further relaxed. See the X86 target for a
11341 // related approach and more information.
11342 //
11343 // Select_FPRX_ (rs1, rs2, imm, rs4, (Select_FPRX_ rs1, rs2, imm, rs4, rs5))
11344 // is checked here and handled by a separate function -
11345 // EmitLoweredCascadedSelect.
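 //
 // As an illustration (a sketch, not verbatim MIR produced here), two selects
 // with identical LHS/RHS/CC such as
 //
 //   %a = Select_GPR_Using_CC_GPR %x, %y, cc, %t1, %f1
 //   %b = Select_GPR_Using_CC_GPR %x, %y, cc, %t2, %f2
 //
 // share a single conditional branch and both become PHIs in TailMBB:
 //
 //   HeadMBB:    <brcond cc> %x, %y, TailMBB   ; fallthrough to IfFalseMBB
 //   IfFalseMBB: ; empty, falls through to TailMBB
 //   TailMBB:    %a = PHI [ %t1, HeadMBB ], [ %f1, IfFalseMBB ]
 //               %b = PHI [ %t2, HeadMBB ], [ %f2, IfFalseMBB ]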
11346 Register LHS = MI.getOperand(1).getReg();
11347 Register RHS = MI.getOperand(2).getReg();
11348 auto CC = static_cast<RISCVCC::CondCode>(MI.getOperand(3).getImm());
11349
11350 SmallVector<MachineInstr *, 4> SelectDebugValues;
11351 SmallSet<Register, 4> SelectDests;
11352 SelectDests.insert(MI.getOperand(0).getReg());
11353
11354 MachineInstr *LastSelectPseudo = &MI;
11355 auto Next = next_nodbg(MI.getIterator(), BB->instr_end());
11356 if (MI.getOpcode() != RISCV::Select_GPR_Using_CC_GPR && Next != BB->end() &&
11357 Next->getOpcode() == MI.getOpcode() &&
11358 Next->getOperand(5).getReg() == MI.getOperand(0).getReg() &&
11359 Next->getOperand(5).isKill()) {
11360 return EmitLoweredCascadedSelect(MI, *Next, BB, Subtarget);
11361 }
11362
11363 for (auto E = BB->end(), SequenceMBBI = MachineBasicBlock::iterator(MI);
11364 SequenceMBBI != E; ++SequenceMBBI) {
11365 if (SequenceMBBI->isDebugInstr())
11366 continue;
11367 if (isSelectPseudo(*SequenceMBBI)) {
11368 if (SequenceMBBI->getOperand(1).getReg() != LHS ||
11369 SequenceMBBI->getOperand(2).getReg() != RHS ||
11370 SequenceMBBI->getOperand(3).getImm() != CC ||
11371 SelectDests.count(SequenceMBBI->getOperand(4).getReg()) ||
11372 SelectDests.count(SequenceMBBI->getOperand(5).getReg()))
11373 break;
11374 LastSelectPseudo = &*SequenceMBBI;
11375 SequenceMBBI->collectDebugValues(SelectDebugValues);
11376 SelectDests.insert(SequenceMBBI->getOperand(0).getReg());
11377 continue;
11378 }
11379 if (SequenceMBBI->hasUnmodeledSideEffects() ||
11380 SequenceMBBI->mayLoadOrStore() ||
11381 SequenceMBBI->usesCustomInsertionHook())
11382 break;
11383 if (llvm::any_of(SequenceMBBI->operands(), [&](MachineOperand &MO) {
11384 return MO.isReg() && MO.isUse() && SelectDests.count(MO.getReg());
11385 }))
11386 break;
11387 }
11388
11389 const RISCVInstrInfo &TII = *Subtarget.getInstrInfo();
11390 const BasicBlock *LLVM_BB = BB->getBasicBlock();
11391 DebugLoc DL = MI.getDebugLoc();
11392 MachineFunction::iterator I = ++BB->getIterator();
11393
11394 MachineBasicBlock *HeadMBB = BB;
11395 MachineFunction *F = BB->getParent();
11396 MachineBasicBlock *TailMBB = F->CreateMachineBasicBlock(LLVM_BB);
11397 MachineBasicBlock *IfFalseMBB = F->CreateMachineBasicBlock(LLVM_BB);
11398
11399 F->insert(I, IfFalseMBB);
11400 F->insert(I, TailMBB);
11401
11402 // Transfer debug instructions associated with the selects to TailMBB.
11403 for (MachineInstr *DebugInstr : SelectDebugValues) {
11404 TailMBB->push_back(DebugInstr->removeFromParent());
11405 }
11406
11407 // Move all instructions after the sequence to TailMBB.
11408 TailMBB->splice(TailMBB->end(), HeadMBB,
11409 std::next(LastSelectPseudo->getIterator()), HeadMBB->end());
11410 // Update machine-CFG edges by transferring all successors of the current
11411 // block to the new block which will contain the Phi nodes for the selects.
11412 TailMBB->transferSuccessorsAndUpdatePHIs(HeadMBB);
11413 // Set the successors for HeadMBB.
11414 HeadMBB->addSuccessor(IfFalseMBB);
11415 HeadMBB->addSuccessor(TailMBB);
11416
11417 // Insert appropriate branch.
11418 BuildMI(HeadMBB, DL, TII.getBrCond(CC))
11419 .addReg(LHS)
11420 .addReg(RHS)
11421 .addMBB(TailMBB);
11422
11423 // IfFalseMBB just falls through to TailMBB.
11424 IfFalseMBB->addSuccessor(TailMBB);
11425
11426 // Create PHIs for all of the select pseudo-instructions.
11427 auto SelectMBBI = MI.getIterator();
11428 auto SelectEnd = std::next(LastSelectPseudo->getIterator());
11429 auto InsertionPoint = TailMBB->begin();
11430 while (SelectMBBI != SelectEnd) {
11431 auto Next = std::next(SelectMBBI);
11432 if (isSelectPseudo(*SelectMBBI)) {
11433 // %Result = phi [ %TrueValue, HeadMBB ], [ %FalseValue, IfFalseMBB ]
11434 BuildMI(*TailMBB, InsertionPoint, SelectMBBI->getDebugLoc(),
11435 TII.get(RISCV::PHI), SelectMBBI->getOperand(0).getReg())
11436 .addReg(SelectMBBI->getOperand(4).getReg())
11437 .addMBB(HeadMBB)
11438 .addReg(SelectMBBI->getOperand(5).getReg())
11439 .addMBB(IfFalseMBB);
11440 SelectMBBI->eraseFromParent();
11441 }
11442 SelectMBBI = Next;
11443 }
11444
11445 F->getProperties().reset(MachineFunctionProperties::Property::NoPHIs);
11446 return TailMBB;
11447 }
11448
11449 static MachineBasicBlock *
11450 emitVFCVT_RM_MASK(MachineInstr &MI, MachineBasicBlock *BB, unsigned Opcode) {
11451 DebugLoc DL = MI.getDebugLoc();
11452
11453 const TargetInstrInfo &TII = *BB->getParent()->getSubtarget().getInstrInfo();
11454
11455 MachineRegisterInfo &MRI = BB->getParent()->getRegInfo();
11456 Register SavedFRM = MRI.createVirtualRegister(&RISCV::GPRRegClass);
11457
11458 // Update FRM and save the old value.
11459 BuildMI(*BB, MI, DL, TII.get(RISCV::SwapFRMImm), SavedFRM)
11460 .addImm(MI.getOperand(4).getImm());
11461
11462 // Emit a VFCVT without the FRM operand.
11463 assert(MI.getNumOperands() == 8);
11464 auto MIB = BuildMI(*BB, MI, DL, TII.get(Opcode))
11465 .add(MI.getOperand(0))
11466 .add(MI.getOperand(1))
11467 .add(MI.getOperand(2))
11468 .add(MI.getOperand(3))
11469 .add(MI.getOperand(5))
11470 .add(MI.getOperand(6))
11471 .add(MI.getOperand(7));
11472 if (MI.getFlag(MachineInstr::MIFlag::NoFPExcept))
11473 MIB->setFlag(MachineInstr::MIFlag::NoFPExcept);
11474
11475 // Restore FRM.
11476 BuildMI(*BB, MI, DL, TII.get(RISCV::WriteFRM))
11477 .addReg(SavedFRM, RegState::Kill);
11478
11479 // Erase the pseudoinstruction.
11480 MI.eraseFromParent();
11481 return BB;
11482 }
11483
11484 static MachineBasicBlock *emitVFROUND_NOEXCEPT_MASK(MachineInstr &MI,
11485 MachineBasicBlock *BB,
11486 unsigned CVTXOpc,
11487 unsigned CVTFOpc) {
11488 DebugLoc DL = MI.getDebugLoc();
11489
11490 const TargetInstrInfo &TII = *BB->getParent()->getSubtarget().getInstrInfo();
11491
11492 MachineRegisterInfo &MRI = BB->getParent()->getRegInfo();
11493 Register SavedFFLAGS = MRI.createVirtualRegister(&RISCV::GPRRegClass);
11494
11495 // Save the old value of FFLAGS.
11496 BuildMI(*BB, MI, DL, TII.get(RISCV::ReadFFLAGS), SavedFFLAGS);
11497
11498 assert(MI.getNumOperands() == 7);
11499
11500 // Emit a VFCVT_X_F
11501 const TargetRegisterInfo *TRI =
11502 BB->getParent()->getSubtarget().getRegisterInfo();
11503 const TargetRegisterClass *RC = MI.getRegClassConstraint(0, &TII, TRI);
11504 Register Tmp = MRI.createVirtualRegister(RC);
11505 BuildMI(*BB, MI, DL, TII.get(CVTXOpc), Tmp)
11506 .add(MI.getOperand(1))
11507 .add(MI.getOperand(2))
11508 .add(MI.getOperand(3))
11509 .add(MI.getOperand(4))
11510 .add(MI.getOperand(5))
11511 .add(MI.getOperand(6));
11512
11513 // Emit a VFCVT_F_X
11514 BuildMI(*BB, MI, DL, TII.get(CVTFOpc))
11515 .add(MI.getOperand(0))
11516 .add(MI.getOperand(1))
11517 .addReg(Tmp)
11518 .add(MI.getOperand(3))
11519 .add(MI.getOperand(4))
11520 .add(MI.getOperand(5))
11521 .add(MI.getOperand(6));
11522
11523 // Restore FFLAGS.
11524 BuildMI(*BB, MI, DL, TII.get(RISCV::WriteFFLAGS))
11525 .addReg(SavedFFLAGS, RegState::Kill);
11526
11527 // Erase the pseudoinstruction.
11528 MI.eraseFromParent();
11529 return BB;
11530 }
11531
11532 static MachineBasicBlock *emitFROUND(MachineInstr &MI, MachineBasicBlock *MBB,
11533 const RISCVSubtarget &Subtarget) {
11534 unsigned CmpOpc, F2IOpc, I2FOpc, FSGNJOpc, FSGNJXOpc;
11535 const TargetRegisterClass *RC;
11536 switch (MI.getOpcode()) {
11537 default:
11538 llvm_unreachable("Unexpected opcode");
11539 case RISCV::PseudoFROUND_H:
11540 CmpOpc = RISCV::FLT_H;
11541 F2IOpc = RISCV::FCVT_W_H;
11542 I2FOpc = RISCV::FCVT_H_W;
11543 FSGNJOpc = RISCV::FSGNJ_H;
11544 FSGNJXOpc = RISCV::FSGNJX_H;
11545 RC = &RISCV::FPR16RegClass;
11546 break;
11547 case RISCV::PseudoFROUND_S:
11548 CmpOpc = RISCV::FLT_S;
11549 F2IOpc = RISCV::FCVT_W_S;
11550 I2FOpc = RISCV::FCVT_S_W;
11551 FSGNJOpc = RISCV::FSGNJ_S;
11552 FSGNJXOpc = RISCV::FSGNJX_S;
11553 RC = &RISCV::FPR32RegClass;
11554 break;
11555 case RISCV::PseudoFROUND_D:
11556 assert(Subtarget.is64Bit() && "Expected 64-bit GPR.");
11557 CmpOpc = RISCV::FLT_D;
11558 F2IOpc = RISCV::FCVT_L_D;
11559 I2FOpc = RISCV::FCVT_D_L;
11560 FSGNJOpc = RISCV::FSGNJ_D;
11561 FSGNJXOpc = RISCV::FSGNJX_D;
11562 RC = &RISCV::FPR64RegClass;
11563 break;
11564 }
11565
11566 const BasicBlock *BB = MBB->getBasicBlock();
11567 DebugLoc DL = MI.getDebugLoc();
11568 MachineFunction::iterator I = ++MBB->getIterator();
11569
11570 MachineFunction *F = MBB->getParent();
11571 MachineBasicBlock *CvtMBB = F->CreateMachineBasicBlock(BB);
11572 MachineBasicBlock *DoneMBB = F->CreateMachineBasicBlock(BB);
11573
11574 F->insert(I, CvtMBB);
11575 F->insert(I, DoneMBB);
11576 // Move all instructions after the sequence to DoneMBB.
11577 DoneMBB->splice(DoneMBB->end(), MBB, MachineBasicBlock::iterator(MI),
11578 MBB->end());
11579 // Update machine-CFG edges by transferring all successors of the current
11580 // block to the new block which will contain the Phi nodes for the selects.
11581 DoneMBB->transferSuccessorsAndUpdatePHIs(MBB);
11582 // Set the successors for MBB.
11583 MBB->addSuccessor(CvtMBB);
11584 MBB->addSuccessor(DoneMBB);
11585
11586 Register DstReg = MI.getOperand(0).getReg();
11587 Register SrcReg = MI.getOperand(1).getReg();
11588 Register MaxReg = MI.getOperand(2).getReg();
11589 int64_t FRM = MI.getOperand(3).getImm();
11590
11591 const RISCVInstrInfo &TII = *Subtarget.getInstrInfo();
11592 MachineRegisterInfo &MRI = MBB->getParent()->getRegInfo();
11593
11594 Register FabsReg = MRI.createVirtualRegister(RC);
11595 BuildMI(MBB, DL, TII.get(FSGNJXOpc), FabsReg).addReg(SrcReg).addReg(SrcReg);
11596
11597 // Compare the FP value to the max value.
11598 Register CmpReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
11599 auto MIB =
11600 BuildMI(MBB, DL, TII.get(CmpOpc), CmpReg).addReg(FabsReg).addReg(MaxReg);
11601 if (MI.getFlag(MachineInstr::MIFlag::NoFPExcept))
11602 MIB->setFlag(MachineInstr::MIFlag::NoFPExcept);
11603
11604 // Insert branch.
11605 BuildMI(MBB, DL, TII.get(RISCV::BEQ))
11606 .addReg(CmpReg)
11607 .addReg(RISCV::X0)
11608 .addMBB(DoneMBB);
11609
11610 CvtMBB->addSuccessor(DoneMBB);
11611
11612 // Convert to integer.
11613 Register F2IReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
11614 MIB = BuildMI(CvtMBB, DL, TII.get(F2IOpc), F2IReg).addReg(SrcReg).addImm(FRM);
11615 if (MI.getFlag(MachineInstr::MIFlag::NoFPExcept))
11616 MIB->setFlag(MachineInstr::MIFlag::NoFPExcept);
11617
11618 // Convert back to FP.
11619 Register I2FReg = MRI.createVirtualRegister(RC);
11620 MIB = BuildMI(CvtMBB, DL, TII.get(I2FOpc), I2FReg).addReg(F2IReg).addImm(FRM);
11621 if (MI.getFlag(MachineInstr::MIFlag::NoFPExcept))
11622 MIB->setFlag(MachineInstr::MIFlag::NoFPExcept);
11623
11624 // Restore the sign bit.
11625 Register CvtReg = MRI.createVirtualRegister(RC);
11626 BuildMI(CvtMBB, DL, TII.get(FSGNJOpc), CvtReg).addReg(I2FReg).addReg(SrcReg);
11627
11628 // Merge the results.
11629 BuildMI(*DoneMBB, DoneMBB->begin(), DL, TII.get(RISCV::PHI), DstReg)
11630 .addReg(SrcReg)
11631 .addMBB(MBB)
11632 .addReg(CvtReg)
11633 .addMBB(CvtMBB);
11634
11635 MI.eraseFromParent();
11636 return DoneMBB;
11637 }
11638
11639 MachineBasicBlock *
11640 RISCVTargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
11641 MachineBasicBlock *BB) const {
11642 switch (MI.getOpcode()) {
11643 default:
11644 llvm_unreachable("Unexpected instr type to insert");
11645 case RISCV::ReadCycleWide:
11646 assert(!Subtarget.is64Bit() &&
11647 "ReadCycleWide is only to be used on riscv32");
11648 return emitReadCycleWidePseudo(MI, BB);
11649 case RISCV::Select_GPR_Using_CC_GPR:
11650 case RISCV::Select_FPR16_Using_CC_GPR:
11651 case RISCV::Select_FPR32_Using_CC_GPR:
11652 case RISCV::Select_FPR64_Using_CC_GPR:
11653 return emitSelectPseudo(MI, BB, Subtarget);
11654 case RISCV::BuildPairF64Pseudo:
11655 return emitBuildPairF64Pseudo(MI, BB);
11656 case RISCV::SplitF64Pseudo:
11657 return emitSplitF64Pseudo(MI, BB);
11658 case RISCV::PseudoQuietFLE_H:
11659 return emitQuietFCMP(MI, BB, RISCV::FLE_H, RISCV::FEQ_H, Subtarget);
11660 case RISCV::PseudoQuietFLT_H:
11661 return emitQuietFCMP(MI, BB, RISCV::FLT_H, RISCV::FEQ_H, Subtarget);
11662 case RISCV::PseudoQuietFLE_S:
11663 return emitQuietFCMP(MI, BB, RISCV::FLE_S, RISCV::FEQ_S, Subtarget);
11664 case RISCV::PseudoQuietFLT_S:
11665 return emitQuietFCMP(MI, BB, RISCV::FLT_S, RISCV::FEQ_S, Subtarget);
11666 case RISCV::PseudoQuietFLE_D:
11667 return emitQuietFCMP(MI, BB, RISCV::FLE_D, RISCV::FEQ_D, Subtarget);
11668 case RISCV::PseudoQuietFLT_D:
11669 return emitQuietFCMP(MI, BB, RISCV::FLT_D, RISCV::FEQ_D, Subtarget);
11670
11671 // =========================================================================
11672 // VFCVT
11673 // =========================================================================
11674
11675 case RISCV::PseudoVFCVT_RM_X_F_V_M1_MASK:
11676 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFCVT_X_F_V_M1_MASK);
11677 case RISCV::PseudoVFCVT_RM_X_F_V_M2_MASK:
11678 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFCVT_X_F_V_M2_MASK);
11679 case RISCV::PseudoVFCVT_RM_X_F_V_M4_MASK:
11680 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFCVT_X_F_V_M4_MASK);
11681 case RISCV::PseudoVFCVT_RM_X_F_V_M8_MASK:
11682 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFCVT_X_F_V_M8_MASK);
11683 case RISCV::PseudoVFCVT_RM_X_F_V_MF2_MASK:
11684 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFCVT_X_F_V_MF2_MASK);
11685 case RISCV::PseudoVFCVT_RM_X_F_V_MF4_MASK:
11686 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFCVT_X_F_V_MF4_MASK);
11687
11688 case RISCV::PseudoVFCVT_RM_XU_F_V_M1_MASK:
11689 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFCVT_XU_F_V_M1_MASK);
11690 case RISCV::PseudoVFCVT_RM_XU_F_V_M2_MASK:
11691 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFCVT_XU_F_V_M2_MASK);
11692 case RISCV::PseudoVFCVT_RM_XU_F_V_M4_MASK:
11693 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFCVT_XU_F_V_M4_MASK);
11694 case RISCV::PseudoVFCVT_RM_XU_F_V_M8_MASK:
11695 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFCVT_XU_F_V_M8_MASK);
11696 case RISCV::PseudoVFCVT_RM_XU_F_V_MF2_MASK:
11697 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFCVT_XU_F_V_MF2_MASK);
11698 case RISCV::PseudoVFCVT_RM_XU_F_V_MF4_MASK:
11699 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFCVT_XU_F_V_MF4_MASK);
11700
11701 case RISCV::PseudoVFCVT_RM_F_XU_V_M1_MASK:
11702 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFCVT_F_XU_V_M1_MASK);
11703 case RISCV::PseudoVFCVT_RM_F_XU_V_M2_MASK:
11704 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFCVT_F_XU_V_M2_MASK);
11705 case RISCV::PseudoVFCVT_RM_F_XU_V_M4_MASK:
11706 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFCVT_F_XU_V_M4_MASK);
11707 case RISCV::PseudoVFCVT_RM_F_XU_V_M8_MASK:
11708 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFCVT_F_XU_V_M8_MASK);
11709 case RISCV::PseudoVFCVT_RM_F_XU_V_MF2_MASK:
11710 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFCVT_F_XU_V_MF2_MASK);
11711 case RISCV::PseudoVFCVT_RM_F_XU_V_MF4_MASK:
11712 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFCVT_F_XU_V_MF4_MASK);
11713
11714 case RISCV::PseudoVFCVT_RM_F_X_V_M1_MASK:
11715 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFCVT_F_X_V_M1_MASK);
11716 case RISCV::PseudoVFCVT_RM_F_X_V_M2_MASK:
11717 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFCVT_F_X_V_M2_MASK);
11718 case RISCV::PseudoVFCVT_RM_F_X_V_M4_MASK:
11719 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFCVT_F_X_V_M4_MASK);
11720 case RISCV::PseudoVFCVT_RM_F_X_V_M8_MASK:
11721 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFCVT_F_X_V_M8_MASK);
11722 case RISCV::PseudoVFCVT_RM_F_X_V_MF2_MASK:
11723 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFCVT_F_X_V_MF2_MASK);
11724 case RISCV::PseudoVFCVT_RM_F_X_V_MF4_MASK:
11725 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFCVT_F_X_V_MF4_MASK);
11726
11727 // =========================================================================
11728 // VFWCVT
11729 // =========================================================================
11730
11731 case RISCV::PseudoVFWCVT_RM_XU_F_V_M1_MASK:
11732 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFWCVT_XU_F_V_M1_MASK);
11733 case RISCV::PseudoVFWCVT_RM_XU_F_V_M2_MASK:
11734 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFWCVT_XU_F_V_M2_MASK);
11735 case RISCV::PseudoVFWCVT_RM_XU_F_V_M4_MASK:
11736 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFWCVT_XU_F_V_M4_MASK);
11737 case RISCV::PseudoVFWCVT_RM_XU_F_V_MF2_MASK:
11738 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFWCVT_XU_F_V_MF2_MASK);
11739 case RISCV::PseudoVFWCVT_RM_XU_F_V_MF4_MASK:
11740 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFWCVT_XU_F_V_MF4_MASK);
11741
11742 case RISCV::PseudoVFWCVT_RM_X_F_V_M1_MASK:
11743 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFWCVT_X_F_V_M1_MASK);
11744 case RISCV::PseudoVFWCVT_RM_X_F_V_M2_MASK:
11745 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFWCVT_X_F_V_M2_MASK);
11746 case RISCV::PseudoVFWCVT_RM_X_F_V_M4_MASK:
11747 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFWCVT_X_F_V_M4_MASK);
11748 case RISCV::PseudoVFWCVT_RM_X_F_V_MF2_MASK:
11749 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFWCVT_X_F_V_MF2_MASK);
11750 case RISCV::PseudoVFWCVT_RM_X_F_V_MF4_MASK:
11751 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFWCVT_X_F_V_MF4_MASK);
11752
11753 case RISCV::PseudoVFWCVT_RM_F_XU_V_M1_MASK:
11754 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFWCVT_F_XU_V_M1_MASK);
11755 case RISCV::PseudoVFWCVT_RM_F_XU_V_M2_MASK:
11756 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFWCVT_F_XU_V_M2_MASK);
11757 case RISCV::PseudoVFWCVT_RM_F_XU_V_M4_MASK:
11758 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFWCVT_F_XU_V_M4_MASK);
11759 case RISCV::PseudoVFWCVT_RM_F_XU_V_MF2_MASK:
11760 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFWCVT_F_XU_V_MF2_MASK);
11761 case RISCV::PseudoVFWCVT_RM_F_XU_V_MF4_MASK:
11762 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFWCVT_F_XU_V_MF4_MASK);
11763 case RISCV::PseudoVFWCVT_RM_F_XU_V_MF8_MASK:
11764 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFWCVT_F_XU_V_MF8_MASK);
11765
11766 case RISCV::PseudoVFWCVT_RM_F_X_V_M1_MASK:
11767 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFWCVT_F_X_V_M1_MASK);
11768 case RISCV::PseudoVFWCVT_RM_F_X_V_M2_MASK:
11769 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFWCVT_F_X_V_M2_MASK);
11770 case RISCV::PseudoVFWCVT_RM_F_X_V_M4_MASK:
11771 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFWCVT_F_X_V_M4_MASK);
11772 case RISCV::PseudoVFWCVT_RM_F_X_V_MF2_MASK:
11773 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFWCVT_F_X_V_MF2_MASK);
11774 case RISCV::PseudoVFWCVT_RM_F_X_V_MF4_MASK:
11775 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFWCVT_F_X_V_MF4_MASK);
11776 case RISCV::PseudoVFWCVT_RM_F_X_V_MF8_MASK:
11777 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFWCVT_F_X_V_MF8_MASK);
11778
11779 // =========================================================================
11780 // VFNCVT
11781 // =========================================================================
11782
11783 case RISCV::PseudoVFNCVT_RM_XU_F_W_M1_MASK:
11784 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFNCVT_XU_F_W_M1_MASK);
11785 case RISCV::PseudoVFNCVT_RM_XU_F_W_M2_MASK:
11786 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFNCVT_XU_F_W_M2_MASK);
11787 case RISCV::PseudoVFNCVT_RM_XU_F_W_M4_MASK:
11788 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFNCVT_XU_F_W_M4_MASK);
11789 case RISCV::PseudoVFNCVT_RM_XU_F_W_MF2_MASK:
11790 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFNCVT_XU_F_W_MF2_MASK);
11791 case RISCV::PseudoVFNCVT_RM_XU_F_W_MF4_MASK:
11792 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFNCVT_XU_F_W_MF4_MASK);
11793 case RISCV::PseudoVFNCVT_RM_XU_F_W_MF8_MASK:
11794 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFNCVT_XU_F_W_MF8_MASK);
11795
11796 case RISCV::PseudoVFNCVT_RM_X_F_W_M1_MASK:
11797 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFNCVT_X_F_W_M1_MASK);
11798 case RISCV::PseudoVFNCVT_RM_X_F_W_M2_MASK:
11799 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFNCVT_X_F_W_M2_MASK);
11800 case RISCV::PseudoVFNCVT_RM_X_F_W_M4_MASK:
11801 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFNCVT_X_F_W_M4_MASK);
11802 case RISCV::PseudoVFNCVT_RM_X_F_W_MF2_MASK:
11803 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFNCVT_X_F_W_MF2_MASK);
11804 case RISCV::PseudoVFNCVT_RM_X_F_W_MF4_MASK:
11805 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFNCVT_X_F_W_MF4_MASK);
11806 case RISCV::PseudoVFNCVT_RM_X_F_W_MF8_MASK:
11807 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFNCVT_X_F_W_MF8_MASK);
11808
11809 case RISCV::PseudoVFNCVT_RM_F_XU_W_M1_MASK:
11810 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFNCVT_F_XU_W_M1_MASK);
11811 case RISCV::PseudoVFNCVT_RM_F_XU_W_M2_MASK:
11812 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFNCVT_F_XU_W_M2_MASK);
11813 case RISCV::PseudoVFNCVT_RM_F_XU_W_M4_MASK:
11814 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFNCVT_F_XU_W_M4_MASK);
11815 case RISCV::PseudoVFNCVT_RM_F_XU_W_MF2_MASK:
11816 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFNCVT_F_XU_W_MF2_MASK);
11817 case RISCV::PseudoVFNCVT_RM_F_XU_W_MF4_MASK:
11818 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFNCVT_F_XU_W_MF4_MASK);
11819
11820 case RISCV::PseudoVFNCVT_RM_F_X_W_M1_MASK:
11821 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFNCVT_F_X_W_M1_MASK);
11822 case RISCV::PseudoVFNCVT_RM_F_X_W_M2_MASK:
11823 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFNCVT_F_X_W_M2_MASK);
11824 case RISCV::PseudoVFNCVT_RM_F_X_W_M4_MASK:
11825 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFNCVT_F_X_W_M4_MASK);
11826 case RISCV::PseudoVFNCVT_RM_F_X_W_MF2_MASK:
11827 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFNCVT_F_X_W_MF2_MASK);
11828 case RISCV::PseudoVFNCVT_RM_F_X_W_MF4_MASK:
11829 return emitVFCVT_RM_MASK(MI, BB, RISCV::PseudoVFNCVT_F_X_W_MF4_MASK);
11830
11831 case RISCV::PseudoVFROUND_NOEXCEPT_V_M1_MASK:
11832 return emitVFROUND_NOEXCEPT_MASK(MI, BB, RISCV::PseudoVFCVT_X_F_V_M1_MASK,
11833 RISCV::PseudoVFCVT_F_X_V_M1_MASK);
11834 case RISCV::PseudoVFROUND_NOEXCEPT_V_M2_MASK:
11835 return emitVFROUND_NOEXCEPT_MASK(MI, BB, RISCV::PseudoVFCVT_X_F_V_M2_MASK,
11836 RISCV::PseudoVFCVT_F_X_V_M2_MASK);
11837 case RISCV::PseudoVFROUND_NOEXCEPT_V_M4_MASK:
11838 return emitVFROUND_NOEXCEPT_MASK(MI, BB, RISCV::PseudoVFCVT_X_F_V_M4_MASK,
11839 RISCV::PseudoVFCVT_F_X_V_M4_MASK);
11840 case RISCV::PseudoVFROUND_NOEXCEPT_V_M8_MASK:
11841 return emitVFROUND_NOEXCEPT_MASK(MI, BB, RISCV::PseudoVFCVT_X_F_V_M8_MASK,
11842 RISCV::PseudoVFCVT_F_X_V_M8_MASK);
11843 case RISCV::PseudoVFROUND_NOEXCEPT_V_MF2_MASK:
11844 return emitVFROUND_NOEXCEPT_MASK(MI, BB, RISCV::PseudoVFCVT_X_F_V_MF2_MASK,
11845 RISCV::PseudoVFCVT_F_X_V_MF2_MASK);
11846 case RISCV::PseudoVFROUND_NOEXCEPT_V_MF4_MASK:
11847 return emitVFROUND_NOEXCEPT_MASK(MI, BB, RISCV::PseudoVFCVT_X_F_V_MF4_MASK,
11848 RISCV::PseudoVFCVT_F_X_V_MF4_MASK);
11849 case RISCV::PseudoFROUND_H:
11850 case RISCV::PseudoFROUND_S:
11851 case RISCV::PseudoFROUND_D:
11852 return emitFROUND(MI, BB, Subtarget);
11853 }
11854 }
11855
11856 void RISCVTargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
11857 SDNode *Node) const {
11858 // Add FRM dependency to any instructions with dynamic rounding mode.
11859 unsigned Opc = MI.getOpcode();
11860 auto Idx = RISCV::getNamedOperandIdx(Opc, RISCV::OpName::frm);
11861 if (Idx < 0)
11862 return;
11863 if (MI.getOperand(Idx).getImm() != RISCVFPRndMode::DYN)
11864 return;
11865 // If the instruction already reads FRM, don't add another read.
11866 if (MI.readsRegister(RISCV::FRM))
11867 return;
11868 MI.addOperand(
11869 MachineOperand::CreateReg(RISCV::FRM, /*isDef*/ false, /*isImp*/ true));
11870 }
11871
11872 // Calling Convention Implementation.
11873 // The expectations for frontend ABI lowering vary from target to target.
11874 // Ideally, an LLVM frontend would be able to avoid worrying about many ABI
11875 // details, but this is a longer term goal. For now, we simply try to keep the
11876 // role of the frontend as simple and well-defined as possible. The rules can
11877 // be summarised as:
11878 // * Never split up large scalar arguments. We handle them here.
11879 // * If a hardfloat calling convention is being used, and the struct may be
11880 // passed in a pair of registers (fp+fp, int+fp), and both registers are
11881 // available, then pass as two separate arguments. If either the GPRs or FPRs
11882 // are exhausted, then pass according to the rule below.
11883 // * If a struct could never be passed in registers or directly in a stack
11884 // slot (as it is larger than 2*XLEN and the floating point rules don't
11885 // apply), then pass it using a pointer with the byval attribute.
11886 // * If a struct is less than 2*XLEN, then coerce to either a two-element
11887 // word-sized array or a 2*XLEN scalar (depending on alignment).
11888 // * The frontend can determine whether a struct is returned by reference or
11889 // not based on its size and fields. If it will be returned by reference, the
11890 // frontend must modify the prototype so a pointer with the sret annotation is
11891 // passed as the first argument. This is not necessary for large scalar
11892 // returns.
11893 // * Struct return values and varargs should be coerced to structs containing
11894 // register-size fields in the same situations they would be for fixed
11895 // arguments.
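 //
 // A rough illustration of the rules above for an ilp32 target (XLEN=32); this
 // is a sketch of the expected frontend lowering, not code in this file:
 //
 //   struct { int32_t a; int32_t b; }   // fits in 2*XLEN, 4-byte aligned
 //     -> coerced to [2 x i32] and passed directly.
 //   struct { int64_t a; }              // fits in 2*XLEN, 8-byte aligned
 //     -> coerced to a single i64 scalar.
 //   struct { int32_t a[4]; }           // larger than 2*XLEN
 //     -> passed as a pointer with the byval attribute.
 //   struct { float f; int32_t i; }     // ilp32f, an FPR and a GPR available
 //     -> passed as two separate arguments (one FPR, one GPR).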
11896
11897 static const MCPhysReg ArgGPRs[] = {
11898 RISCV::X10, RISCV::X11, RISCV::X12, RISCV::X13,
11899 RISCV::X14, RISCV::X15, RISCV::X16, RISCV::X17
11900 };
11901 static const MCPhysReg ArgFPR16s[] = {
11902 RISCV::F10_H, RISCV::F11_H, RISCV::F12_H, RISCV::F13_H,
11903 RISCV::F14_H, RISCV::F15_H, RISCV::F16_H, RISCV::F17_H
11904 };
11905 static const MCPhysReg ArgFPR32s[] = {
11906 RISCV::F10_F, RISCV::F11_F, RISCV::F12_F, RISCV::F13_F,
11907 RISCV::F14_F, RISCV::F15_F, RISCV::F16_F, RISCV::F17_F
11908 };
11909 static const MCPhysReg ArgFPR64s[] = {
11910 RISCV::F10_D, RISCV::F11_D, RISCV::F12_D, RISCV::F13_D,
11911 RISCV::F14_D, RISCV::F15_D, RISCV::F16_D, RISCV::F17_D
11912 };
11913 // This is an interim calling convention and it may be changed in the future.
11914 static const MCPhysReg ArgVRs[] = {
11915 RISCV::V8, RISCV::V9, RISCV::V10, RISCV::V11, RISCV::V12, RISCV::V13,
11916 RISCV::V14, RISCV::V15, RISCV::V16, RISCV::V17, RISCV::V18, RISCV::V19,
11917 RISCV::V20, RISCV::V21, RISCV::V22, RISCV::V23};
11918 static const MCPhysReg ArgVRM2s[] = {RISCV::V8M2, RISCV::V10M2, RISCV::V12M2,
11919 RISCV::V14M2, RISCV::V16M2, RISCV::V18M2,
11920 RISCV::V20M2, RISCV::V22M2};
11921 static const MCPhysReg ArgVRM4s[] = {RISCV::V8M4, RISCV::V12M4, RISCV::V16M4,
11922 RISCV::V20M4};
11923 static const MCPhysReg ArgVRM8s[] = {RISCV::V8M8, RISCV::V16M8};
11924
11925 // Pass a 2*XLEN argument that has been split into two XLEN values through
11926 // registers or the stack as necessary.
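 // For example (illustrative only), an i64 argument on RV32 may be assigned:
 //   - a pair of GPRs (e.g. a0/a1) if two argument registers are still free,
 //   - one GPR plus one stack word if only a single register is left, or
 //   - two stack words (the first 8-byte aligned) if the GPRs are exhausted.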
11927 static bool CC_RISCVAssign2XLen(unsigned XLen, CCState &State, CCValAssign VA1,
11928 ISD::ArgFlagsTy ArgFlags1, unsigned ValNo2,
11929 MVT ValVT2, MVT LocVT2,
11930 ISD::ArgFlagsTy ArgFlags2) {
11931 unsigned XLenInBytes = XLen / 8;
11932 if (Register Reg = State.AllocateReg(ArgGPRs)) {
11933 // At least one half can be passed via register.
11934 State.addLoc(CCValAssign::getReg(VA1.getValNo(), VA1.getValVT(), Reg,
11935 VA1.getLocVT(), CCValAssign::Full));
11936 } else {
11937 // Both halves must be passed on the stack, with proper alignment.
11938 Align StackAlign =
11939 std::max(Align(XLenInBytes), ArgFlags1.getNonZeroOrigAlign());
11940 State.addLoc(
11941 CCValAssign::getMem(VA1.getValNo(), VA1.getValVT(),
11942 State.AllocateStack(XLenInBytes, StackAlign),
11943 VA1.getLocVT(), CCValAssign::Full));
11944 State.addLoc(CCValAssign::getMem(
11945 ValNo2, ValVT2, State.AllocateStack(XLenInBytes, Align(XLenInBytes)),
11946 LocVT2, CCValAssign::Full));
11947 return false;
11948 }
11949
11950 if (Register Reg = State.AllocateReg(ArgGPRs)) {
11951 // The second half can also be passed via register.
11952 State.addLoc(
11953 CCValAssign::getReg(ValNo2, ValVT2, Reg, LocVT2, CCValAssign::Full));
11954 } else {
11955 // The second half is passed via the stack, without additional alignment.
11956 State.addLoc(CCValAssign::getMem(
11957 ValNo2, ValVT2, State.AllocateStack(XLenInBytes, Align(XLenInBytes)),
11958 LocVT2, CCValAssign::Full));
11959 }
11960
11961 return false;
11962 }
11963
11964 static unsigned allocateRVVReg(MVT ValVT, unsigned ValNo,
11965 std::optional<unsigned> FirstMaskArgument,
11966 CCState &State, const RISCVTargetLowering &TLI) {
11967 const TargetRegisterClass *RC = TLI.getRegClassFor(ValVT);
11968 if (RC == &RISCV::VRRegClass) {
11969 // Assign the first mask argument to V0.
11970 // This is an interim calling convention and it may be changed in the
11971 // future.
11972 if (FirstMaskArgument && ValNo == *FirstMaskArgument)
11973 return State.AllocateReg(RISCV::V0);
11974 return State.AllocateReg(ArgVRs);
11975 }
11976 if (RC == &RISCV::VRM2RegClass)
11977 return State.AllocateReg(ArgVRM2s);
11978 if (RC == &RISCV::VRM4RegClass)
11979 return State.AllocateReg(ArgVRM4s);
11980 if (RC == &RISCV::VRM8RegClass)
11981 return State.AllocateReg(ArgVRM8s);
11982 llvm_unreachable("Unhandled register class for ValueType");
11983 }
11984
11985 // Implements the RISC-V calling convention. Returns true upon failure.
11986 static bool CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
11987 MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo,
11988 ISD::ArgFlagsTy ArgFlags, CCState &State, bool IsFixed,
11989 bool IsRet, Type *OrigTy, const RISCVTargetLowering &TLI,
11990 std::optional<unsigned> FirstMaskArgument) {
11991 unsigned XLen = DL.getLargestLegalIntTypeSizeInBits();
11992 assert(XLen == 32 || XLen == 64);
11993 MVT XLenVT = XLen == 32 ? MVT::i32 : MVT::i64;
11994
11995 // Static chain parameter must not be passed in normal argument registers,
11996 // so we assign t2 for it as done in GCC's __builtin_call_with_static_chain.
11997 if (ArgFlags.isNest()) {
11998 if (unsigned Reg = State.AllocateReg(RISCV::X7)) {
11999 State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
12000 return false;
12001 }
12002 }
12003
12004 // Any return value split into more than two values can't be returned
12005 // directly. Vectors are returned via the available vector registers.
12006 if (!LocVT.isVector() && IsRet && ValNo > 1)
12007 return true;
12008
12009 // UseGPRForF16_F32 if targeting one of the soft-float ABIs, if passing a
12010 // variadic argument, or if no F16/F32 argument registers are available.
12011 bool UseGPRForF16_F32 = true;
12012 // UseGPRForF64 if targeting soft-float ABIs or an FLEN=32 ABI, if passing a
12013 // variadic argument, or if no F64 argument registers are available.
12014 bool UseGPRForF64 = true;
12015
12016 switch (ABI) {
12017 default:
12018 llvm_unreachable("Unexpected ABI");
12019 case RISCVABI::ABI_ILP32:
12020 case RISCVABI::ABI_LP64:
12021 break;
12022 case RISCVABI::ABI_ILP32F:
12023 case RISCVABI::ABI_LP64F:
12024 UseGPRForF16_F32 = !IsFixed;
12025 break;
12026 case RISCVABI::ABI_ILP32D:
12027 case RISCVABI::ABI_LP64D:
12028 UseGPRForF16_F32 = !IsFixed;
12029 UseGPRForF64 = !IsFixed;
12030 break;
12031 }
12032
12033 // FPR16, FPR32, and FPR64 alias each other.
12034 if (State.getFirstUnallocated(ArgFPR32s) == std::size(ArgFPR32s)) {
12035 UseGPRForF16_F32 = true;
12036 UseGPRForF64 = true;
12037 }
12038
12039 // From this point on, rely on UseGPRForF16_F32, UseGPRForF64 and
12040 // similar local variables rather than directly checking against the target
12041 // ABI.
12042
12043 if (UseGPRForF16_F32 && (ValVT == MVT::f16 || ValVT == MVT::f32)) {
12044 LocVT = XLenVT;
12045 LocInfo = CCValAssign::BCvt;
12046 } else if (UseGPRForF64 && XLen == 64 && ValVT == MVT::f64) {
12047 LocVT = MVT::i64;
12048 LocInfo = CCValAssign::BCvt;
12049 }
12050
12051 // If this is a variadic argument, the RISC-V calling convention requires
12052 // that it is assigned an 'even' or 'aligned' register if it has 8-byte
12053 // alignment (RV32) or 16-byte alignment (RV64). An aligned register should
12054 // be used regardless of whether the original argument was split during
12055 // legalisation or not. The argument will not be passed by registers if the
12056 // original type is larger than 2*XLEN, so the register alignment rule does
12057 // not apply.
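 // For example (illustrative only), in a call like printf("%f", d) on RV32,
 // a0 holds the format string and the variadic double has 8-byte alignment,
 // so a1 is skipped and the double is passed in the aligned pair a2/a3.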
12058 unsigned TwoXLenInBytes = (2 * XLen) / 8;
12059 if (!IsFixed && ArgFlags.getNonZeroOrigAlign() == TwoXLenInBytes &&
12060 DL.getTypeAllocSize(OrigTy) == TwoXLenInBytes) {
12061 unsigned RegIdx = State.getFirstUnallocated(ArgGPRs);
12062 // Skip 'odd' register if necessary.
12063 if (RegIdx != std::size(ArgGPRs) && RegIdx % 2 == 1)
12064 State.AllocateReg(ArgGPRs);
12065 }
12066
12067 SmallVectorImpl<CCValAssign> &PendingLocs = State.getPendingLocs();
12068 SmallVectorImpl<ISD::ArgFlagsTy> &PendingArgFlags =
12069 State.getPendingArgFlags();
12070
12071 assert(PendingLocs.size() == PendingArgFlags.size() &&
12072 "PendingLocs and PendingArgFlags out of sync");
12073
12074 // Handle passing f64 on RV32D with a soft float ABI or when floating point
12075 // registers are exhausted.
12076 if (UseGPRForF64 && XLen == 32 && ValVT == MVT::f64) {
12077 assert(!ArgFlags.isSplit() && PendingLocs.empty() &&
12078 "Can't lower f64 if it is split");
12079 // Depending on available argument GPRs, f64 may be passed in a pair of
12080 // GPRs, split between a GPR and the stack, or passed completely on the
12081 // stack. LowerCall/LowerFormalArguments/LowerReturn must recognise these
12082 // cases.
12083 Register Reg = State.AllocateReg(ArgGPRs);
12084 LocVT = MVT::i32;
12085 if (!Reg) {
12086 unsigned StackOffset = State.AllocateStack(8, Align(8));
12087 State.addLoc(
12088 CCValAssign::getMem(ValNo, ValVT, StackOffset, LocVT, LocInfo));
12089 return false;
12090 }
12091 if (!State.AllocateReg(ArgGPRs))
12092 State.AllocateStack(4, Align(4));
12093 State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
12094 return false;
12095 }
12096
12097 // Fixed-length vectors are located in the corresponding scalable-vector
12098 // container types.
12099 if (ValVT.isFixedLengthVector())
12100 LocVT = TLI.getContainerForFixedLengthVector(LocVT);
12101
12102 // Split arguments might be passed indirectly, so keep track of the pending
12103 // values. Split vectors are passed via a mix of registers and indirectly, so
12104 // treat them as we would any other argument.
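 // For example (illustrative only), an i128 argument on RV32 is legalised
 // into four i32 pieces; since that is more than two XLEN values, the whole
 // argument ends up being passed indirectly (a single pointer in a GPR or a
 // stack slot; see the PendingLocs handling further down).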
12105 if (ValVT.isScalarInteger() && (ArgFlags.isSplit() || !PendingLocs.empty())) {
12106 LocVT = XLenVT;
12107 LocInfo = CCValAssign::Indirect;
12108 PendingLocs.push_back(
12109 CCValAssign::getPending(ValNo, ValVT, LocVT, LocInfo));
12110 PendingArgFlags.push_back(ArgFlags);
12111 if (!ArgFlags.isSplitEnd()) {
12112 return false;
12113 }
12114 }
12115
12116 // If the split argument only had two elements, it should be passed directly
12117 // in registers or on the stack.
12118 if (ValVT.isScalarInteger() && ArgFlags.isSplitEnd() &&
12119 PendingLocs.size() <= 2) {
12120 assert(PendingLocs.size() == 2 && "Unexpected PendingLocs.size()");
12121 // Apply the normal calling convention rules to the first half of the
12122 // split argument.
12123 CCValAssign VA = PendingLocs[0];
12124 ISD::ArgFlagsTy AF = PendingArgFlags[0];
12125 PendingLocs.clear();
12126 PendingArgFlags.clear();
12127 return CC_RISCVAssign2XLen(XLen, State, VA, AF, ValNo, ValVT, LocVT,
12128 ArgFlags);
12129 }
12130
12131 // Allocate to a register if possible, or else a stack slot.
12132 Register Reg;
12133 unsigned StoreSizeBytes = XLen / 8;
12134 Align StackAlign = Align(XLen / 8);
12135
12136 if (ValVT == MVT::f16 && !UseGPRForF16_F32)
12137 Reg = State.AllocateReg(ArgFPR16s);
12138 else if (ValVT == MVT::f32 && !UseGPRForF16_F32)
12139 Reg = State.AllocateReg(ArgFPR32s);
12140 else if (ValVT == MVT::f64 && !UseGPRForF64)
12141 Reg = State.AllocateReg(ArgFPR64s);
12142 else if (ValVT.isVector()) {
12143 Reg = allocateRVVReg(ValVT, ValNo, FirstMaskArgument, State, TLI);
12144 if (!Reg) {
12145 // For return values, the vector must be passed fully via registers or
12146 // via the stack.
12147 // FIXME: The proposed vector ABI only mandates v8-v15 for return values,
12148 // but we're using all of them.
12149 if (IsRet)
12150 return true;
12151 // Try using a GPR to pass the address
12152 if ((Reg = State.AllocateReg(ArgGPRs))) {
12153 LocVT = XLenVT;
12154 LocInfo = CCValAssign::Indirect;
12155 } else if (ValVT.isScalableVector()) {
12156 LocVT = XLenVT;
12157 LocInfo = CCValAssign::Indirect;
12158 } else {
12159 // Pass fixed-length vectors on the stack.
12160 LocVT = ValVT;
12161 StoreSizeBytes = ValVT.getStoreSize();
12162 // Align vectors to their element sizes, being careful for vXi1
12163 // vectors.
12164 StackAlign = MaybeAlign(ValVT.getScalarSizeInBits() / 8).valueOrOne();
12165 }
12166 }
12167 } else {
12168 Reg = State.AllocateReg(ArgGPRs);
12169 }
12170
12171 unsigned StackOffset =
12172 Reg ? 0 : State.AllocateStack(StoreSizeBytes, StackAlign);
12173
12174 // If we reach this point and PendingLocs is non-empty, we must be at the
12175 // end of a split argument that must be passed indirectly.
12176 if (!PendingLocs.empty()) {
12177 assert(ArgFlags.isSplitEnd() && "Expected ArgFlags.isSplitEnd()");
12178 assert(PendingLocs.size() > 2 && "Unexpected PendingLocs.size()");
12179
12180 for (auto &It : PendingLocs) {
12181 if (Reg)
12182 It.convertToReg(Reg);
12183 else
12184 It.convertToMem(StackOffset);
12185 State.addLoc(It);
12186 }
12187 PendingLocs.clear();
12188 PendingArgFlags.clear();
12189 return false;
12190 }
12191
12192 assert((!UseGPRForF16_F32 || !UseGPRForF64 || LocVT == XLenVT ||
12193 (TLI.getSubtarget().hasVInstructions() && ValVT.isVector())) &&
12194 "Expected an XLenVT or vector types at this stage");
12195
12196 if (Reg) {
12197 State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
12198 return false;
12199 }
12200
12201 // When a floating-point value is passed on the stack, no bit-conversion is
12202 // needed.
12203 if (ValVT.isFloatingPoint()) {
12204 LocVT = ValVT;
12205 LocInfo = CCValAssign::Full;
12206 }
12207 State.addLoc(CCValAssign::getMem(ValNo, ValVT, StackOffset, LocVT, LocInfo));
12208 return false;
12209 }
12210
12211 template <typename ArgTy>
12212 static std::optional<unsigned> preAssignMask(const ArgTy &Args) {
12213 for (const auto &ArgIdx : enumerate(Args)) {
12214 MVT ArgVT = ArgIdx.value().VT;
12215 if (ArgVT.isVector() && ArgVT.getVectorElementType() == MVT::i1)
12216 return ArgIdx.index();
12217 }
12218 return std::nullopt;
12219 }
12220
12221 void RISCVTargetLowering::analyzeInputArgs(
12222 MachineFunction &MF, CCState &CCInfo,
12223 const SmallVectorImpl<ISD::InputArg> &Ins, bool IsRet,
12224 RISCVCCAssignFn Fn) const {
12225 unsigned NumArgs = Ins.size();
12226 FunctionType *FType = MF.getFunction().getFunctionType();
12227
12228 std::optional<unsigned> FirstMaskArgument;
12229 if (Subtarget.hasVInstructions())
12230 FirstMaskArgument = preAssignMask(Ins);
12231
12232 for (unsigned i = 0; i != NumArgs; ++i) {
12233 MVT ArgVT = Ins[i].VT;
12234 ISD::ArgFlagsTy ArgFlags = Ins[i].Flags;
12235
12236 Type *ArgTy = nullptr;
12237 if (IsRet)
12238 ArgTy = FType->getReturnType();
12239 else if (Ins[i].isOrigArg())
12240 ArgTy = FType->getParamType(Ins[i].getOrigArgIndex());
12241
12242 RISCVABI::ABI ABI = MF.getSubtarget<RISCVSubtarget>().getTargetABI();
12243 if (Fn(MF.getDataLayout(), ABI, i, ArgVT, ArgVT, CCValAssign::Full,
12244 ArgFlags, CCInfo, /*IsFixed=*/true, IsRet, ArgTy, *this,
12245 FirstMaskArgument)) {
12246 LLVM_DEBUG(dbgs() << "InputArg #" << i << " has unhandled type "
12247 << EVT(ArgVT).getEVTString() << '\n');
12248 llvm_unreachable(nullptr);
12249 }
12250 }
12251 }
12252
12253 void RISCVTargetLowering::analyzeOutputArgs(
12254 MachineFunction &MF, CCState &CCInfo,
12255 const SmallVectorImpl<ISD::OutputArg> &Outs, bool IsRet,
12256 CallLoweringInfo *CLI, RISCVCCAssignFn Fn) const {
12257 unsigned NumArgs = Outs.size();
12258
12259 std::optional<unsigned> FirstMaskArgument;
12260 if (Subtarget.hasVInstructions())
12261 FirstMaskArgument = preAssignMask(Outs);
12262
12263 for (unsigned i = 0; i != NumArgs; i++) {
12264 MVT ArgVT = Outs[i].VT;
12265 ISD::ArgFlagsTy ArgFlags = Outs[i].Flags;
12266 Type *OrigTy = CLI ? CLI->getArgs()[Outs[i].OrigArgIndex].Ty : nullptr;
12267
12268 RISCVABI::ABI ABI = MF.getSubtarget<RISCVSubtarget>().getTargetABI();
12269 if (Fn(MF.getDataLayout(), ABI, i, ArgVT, ArgVT, CCValAssign::Full,
12270 ArgFlags, CCInfo, Outs[i].IsFixed, IsRet, OrigTy, *this,
12271 FirstMaskArgument)) {
12272 LLVM_DEBUG(dbgs() << "OutputArg #" << i << " has unhandled type "
12273 << EVT(ArgVT).getEVTString() << "\n");
12274 llvm_unreachable(nullptr);
12275 }
12276 }
12277 }
12278
12279 // Convert Val to a ValVT. Should not be called for CCValAssign::Indirect
12280 // values.
12281 static SDValue convertLocVTToValVT(SelectionDAG &DAG, SDValue Val,
12282 const CCValAssign &VA, const SDLoc &DL,
12283 const RISCVSubtarget &Subtarget) {
12284 switch (VA.getLocInfo()) {
12285 default:
12286 llvm_unreachable("Unexpected CCValAssign::LocInfo");
12287 case CCValAssign::Full:
12288 if (VA.getValVT().isFixedLengthVector() && VA.getLocVT().isScalableVector())
12289 Val = convertFromScalableVector(VA.getValVT(), Val, DAG, Subtarget);
12290 break;
12291 case CCValAssign::BCvt:
12292 if (VA.getLocVT().isInteger() && VA.getValVT() == MVT::f16)
12293 Val = DAG.getNode(RISCVISD::FMV_H_X, DL, MVT::f16, Val);
12294 else if (VA.getLocVT() == MVT::i64 && VA.getValVT() == MVT::f32)
12295 Val = DAG.getNode(RISCVISD::FMV_W_X_RV64, DL, MVT::f32, Val);
12296 else
12297 Val = DAG.getNode(ISD::BITCAST, DL, VA.getValVT(), Val);
12298 break;
12299 }
12300 return Val;
12301 }
12302
12303 // The caller is responsible for loading the full value if the argument is
12304 // passed with CCValAssign::Indirect.
12305 static SDValue unpackFromRegLoc(SelectionDAG &DAG, SDValue Chain,
12306 const CCValAssign &VA, const SDLoc &DL,
12307 const ISD::InputArg &In,
12308 const RISCVTargetLowering &TLI) {
12309 MachineFunction &MF = DAG.getMachineFunction();
12310 MachineRegisterInfo &RegInfo = MF.getRegInfo();
12311 EVT LocVT = VA.getLocVT();
12312 SDValue Val;
12313 const TargetRegisterClass *RC = TLI.getRegClassFor(LocVT.getSimpleVT());
12314 Register VReg = RegInfo.createVirtualRegister(RC);
12315 RegInfo.addLiveIn(VA.getLocReg(), VReg);
12316 Val = DAG.getCopyFromReg(Chain, DL, VReg, LocVT);
12317
12318 // If input is sign extended from 32 bits, note it for the SExtWRemoval pass.
12319 if (In.isOrigArg()) {
12320 Argument *OrigArg = MF.getFunction().getArg(In.getOrigArgIndex());
12321 if (OrigArg->getType()->isIntegerTy()) {
12322 unsigned BitWidth = OrigArg->getType()->getIntegerBitWidth();
12323 // An input zero extended from i31 can also be considered sign extended,
12323 // since bit 31 is then known to be zero.
12324 if ((BitWidth <= 32 && In.Flags.isSExt()) ||
12325 (BitWidth < 32 && In.Flags.isZExt())) {
12326 RISCVMachineFunctionInfo *RVFI = MF.getInfo<RISCVMachineFunctionInfo>();
12327 RVFI->addSExt32Register(VReg);
12328 }
12329 }
12330 }
12331
12332 if (VA.getLocInfo() == CCValAssign::Indirect)
12333 return Val;
12334
12335 return convertLocVTToValVT(DAG, Val, VA, DL, TLI.getSubtarget());
12336 }
12337
12338 static SDValue convertValVTToLocVT(SelectionDAG &DAG, SDValue Val,
12339 const CCValAssign &VA, const SDLoc &DL,
12340 const RISCVSubtarget &Subtarget) {
12341 EVT LocVT = VA.getLocVT();
12342
12343 switch (VA.getLocInfo()) {
12344 default:
12345 llvm_unreachable("Unexpected CCValAssign::LocInfo");
12346 case CCValAssign::Full:
12347 if (VA.getValVT().isFixedLengthVector() && LocVT.isScalableVector())
12348 Val = convertToScalableVector(LocVT, Val, DAG, Subtarget);
12349 break;
12350 case CCValAssign::BCvt:
12351 if (VA.getLocVT().isInteger() && VA.getValVT() == MVT::f16)
12352 Val = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, VA.getLocVT(), Val);
12353 else if (VA.getLocVT() == MVT::i64 && VA.getValVT() == MVT::f32)
12354 Val = DAG.getNode(RISCVISD::FMV_X_ANYEXTW_RV64, DL, MVT::i64, Val);
12355 else
12356 Val = DAG.getNode(ISD::BITCAST, DL, LocVT, Val);
12357 break;
12358 }
12359 return Val;
12360 }
12361
12362 // The caller is responsible for loading the full value if the argument is
12363 // passed with CCValAssign::Indirect.
static SDValue unpackFromMemLoc(SelectionDAG &DAG, SDValue Chain,
                                const CCValAssign &VA, const SDLoc &DL) {
12366 MachineFunction &MF = DAG.getMachineFunction();
12367 MachineFrameInfo &MFI = MF.getFrameInfo();
12368 EVT LocVT = VA.getLocVT();
12369 EVT ValVT = VA.getValVT();
12370 EVT PtrVT = MVT::getIntegerVT(DAG.getDataLayout().getPointerSizeInBits(0));
12371 if (ValVT.isScalableVector()) {
12372 // When the value is a scalable vector, we save the pointer which points to
12373 // the scalable vector value in the stack. The ValVT will be the pointer
12374 // type, instead of the scalable vector type.
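    // Illustrative example (assumed): a scalable <vscale x 4 x i32> argument
    // that did not fit in vector registers is passed indirectly, so the fixed
    // stack slot created below only holds an XLEN-sized pointer to the vector
    // data rather than the vector itself.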
12375 ValVT = LocVT;
12376 }
12377 int FI = MFI.CreateFixedObject(ValVT.getStoreSize(), VA.getLocMemOffset(),
12378 /*IsImmutable=*/true);
12379 SDValue FIN = DAG.getFrameIndex(FI, PtrVT);
12380 SDValue Val;
12381
12382 ISD::LoadExtType ExtType;
12383 switch (VA.getLocInfo()) {
12384 default:
12385 llvm_unreachable("Unexpected CCValAssign::LocInfo");
12386 case CCValAssign::Full:
12387 case CCValAssign::Indirect:
12388 case CCValAssign::BCvt:
12389 ExtType = ISD::NON_EXTLOAD;
12390 break;
12391 }
12392 Val = DAG.getExtLoad(
12393 ExtType, DL, LocVT, Chain, FIN,
12394 MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), FI), ValVT);
12395 return Val;
12396 }
12397
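// On RV32 with a soft-double ABI (ilp32/ilp32f) an f64 argument is split into
// two i32 halves. Illustrative example (assumed C-level signature):
//   double f(int, int, int, int, int, int, int, double d);
// The seven ints occupy a0-a6, the low half of `d` lands in a7 (X17) and the
// high half spills to the stack, which is the special case handled below.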
static SDValue unpackF64OnRV32DSoftABI(SelectionDAG &DAG, SDValue Chain,
                                       const CCValAssign &VA, const SDLoc &DL) {
12400 assert(VA.getLocVT() == MVT::i32 && VA.getValVT() == MVT::f64 &&
12401 "Unexpected VA");
12402 MachineFunction &MF = DAG.getMachineFunction();
12403 MachineFrameInfo &MFI = MF.getFrameInfo();
12404 MachineRegisterInfo &RegInfo = MF.getRegInfo();
12405
12406 if (VA.isMemLoc()) {
12407 // f64 is passed on the stack.
12408 int FI =
12409 MFI.CreateFixedObject(8, VA.getLocMemOffset(), /*IsImmutable=*/true);
12410 SDValue FIN = DAG.getFrameIndex(FI, MVT::i32);
12411 return DAG.getLoad(MVT::f64, DL, Chain, FIN,
12412 MachinePointerInfo::getFixedStack(MF, FI));
12413 }
12414
12415 assert(VA.isRegLoc() && "Expected register VA assignment");
12416
12417 Register LoVReg = RegInfo.createVirtualRegister(&RISCV::GPRRegClass);
12418 RegInfo.addLiveIn(VA.getLocReg(), LoVReg);
12419 SDValue Lo = DAG.getCopyFromReg(Chain, DL, LoVReg, MVT::i32);
12420 SDValue Hi;
12421 if (VA.getLocReg() == RISCV::X17) {
12422 // Second half of f64 is passed on the stack.
12423 int FI = MFI.CreateFixedObject(4, 0, /*IsImmutable=*/true);
12424 SDValue FIN = DAG.getFrameIndex(FI, MVT::i32);
12425 Hi = DAG.getLoad(MVT::i32, DL, Chain, FIN,
12426 MachinePointerInfo::getFixedStack(MF, FI));
12427 } else {
12428 // Second half of f64 is passed in another GPR.
12429 Register HiVReg = RegInfo.createVirtualRegister(&RISCV::GPRRegClass);
12430 RegInfo.addLiveIn(VA.getLocReg() + 1, HiVReg);
12431 Hi = DAG.getCopyFromReg(Chain, DL, HiVReg, MVT::i32);
12432 }
12433 return DAG.getNode(RISCVISD::BuildPairF64, DL, MVT::f64, Lo, Hi);
12434 }
12435
// FastCC shows less than a 1% performance improvement on some benchmarks,
// but it may theoretically benefit certain cases.
static bool CC_RISCV_FastCC(const DataLayout &DL, RISCVABI::ABI ABI,
                            unsigned ValNo, MVT ValVT, MVT LocVT,
                            CCValAssign::LocInfo LocInfo,
                            ISD::ArgFlagsTy ArgFlags, CCState &State,
                            bool IsFixed, bool IsRet, Type *OrigTy,
                            const RISCVTargetLowering &TLI,
                            std::optional<unsigned> FirstMaskArgument) {
12445
12446 // X5 and X6 might be used for save-restore libcall.
12447 static const MCPhysReg GPRList[] = {
12448 RISCV::X10, RISCV::X11, RISCV::X12, RISCV::X13, RISCV::X14,
12449 RISCV::X15, RISCV::X16, RISCV::X17, RISCV::X7, RISCV::X28,
12450 RISCV::X29, RISCV::X30, RISCV::X31};
12451
12452 if (LocVT == MVT::i32 || LocVT == MVT::i64) {
12453 if (unsigned Reg = State.AllocateReg(GPRList)) {
12454 State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
12455 return false;
12456 }
12457 }
12458
12459 if (LocVT == MVT::f16) {
12460 static const MCPhysReg FPR16List[] = {
12461 RISCV::F10_H, RISCV::F11_H, RISCV::F12_H, RISCV::F13_H, RISCV::F14_H,
12462 RISCV::F15_H, RISCV::F16_H, RISCV::F17_H, RISCV::F0_H, RISCV::F1_H,
12463 RISCV::F2_H, RISCV::F3_H, RISCV::F4_H, RISCV::F5_H, RISCV::F6_H,
12464 RISCV::F7_H, RISCV::F28_H, RISCV::F29_H, RISCV::F30_H, RISCV::F31_H};
12465 if (unsigned Reg = State.AllocateReg(FPR16List)) {
12466 State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
12467 return false;
12468 }
12469 }
12470
12471 if (LocVT == MVT::f32) {
12472 static const MCPhysReg FPR32List[] = {
12473 RISCV::F10_F, RISCV::F11_F, RISCV::F12_F, RISCV::F13_F, RISCV::F14_F,
12474 RISCV::F15_F, RISCV::F16_F, RISCV::F17_F, RISCV::F0_F, RISCV::F1_F,
12475 RISCV::F2_F, RISCV::F3_F, RISCV::F4_F, RISCV::F5_F, RISCV::F6_F,
12476 RISCV::F7_F, RISCV::F28_F, RISCV::F29_F, RISCV::F30_F, RISCV::F31_F};
12477 if (unsigned Reg = State.AllocateReg(FPR32List)) {
12478 State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
12479 return false;
12480 }
12481 }
12482
12483 if (LocVT == MVT::f64) {
12484 static const MCPhysReg FPR64List[] = {
12485 RISCV::F10_D, RISCV::F11_D, RISCV::F12_D, RISCV::F13_D, RISCV::F14_D,
12486 RISCV::F15_D, RISCV::F16_D, RISCV::F17_D, RISCV::F0_D, RISCV::F1_D,
12487 RISCV::F2_D, RISCV::F3_D, RISCV::F4_D, RISCV::F5_D, RISCV::F6_D,
12488 RISCV::F7_D, RISCV::F28_D, RISCV::F29_D, RISCV::F30_D, RISCV::F31_D};
12489 if (unsigned Reg = State.AllocateReg(FPR64List)) {
12490 State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
12491 return false;
12492 }
12493 }
12494
12495 if (LocVT == MVT::i32 || LocVT == MVT::f32) {
12496 unsigned Offset4 = State.AllocateStack(4, Align(4));
12497 State.addLoc(CCValAssign::getMem(ValNo, ValVT, Offset4, LocVT, LocInfo));
12498 return false;
12499 }
12500
12501 if (LocVT == MVT::i64 || LocVT == MVT::f64) {
12502 unsigned Offset5 = State.AllocateStack(8, Align(8));
12503 State.addLoc(CCValAssign::getMem(ValNo, ValVT, Offset5, LocVT, LocInfo));
12504 return false;
12505 }
12506
12507 if (LocVT.isVector()) {
12508 if (unsigned Reg =
12509 allocateRVVReg(ValVT, ValNo, FirstMaskArgument, State, TLI)) {
12510 // Fixed-length vectors are located in the corresponding scalable-vector
12511 // container types.
12512 if (ValVT.isFixedLengthVector())
12513 LocVT = TLI.getContainerForFixedLengthVector(LocVT);
12514 State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
12515 } else {
12516 // Try and pass the address via a "fast" GPR.
12517 if (unsigned GPRReg = State.AllocateReg(GPRList)) {
12518 LocInfo = CCValAssign::Indirect;
12519 LocVT = TLI.getSubtarget().getXLenVT();
12520 State.addLoc(CCValAssign::getReg(ValNo, ValVT, GPRReg, LocVT, LocInfo));
12521 } else if (ValVT.isFixedLengthVector()) {
12522 auto StackAlign =
12523 MaybeAlign(ValVT.getScalarSizeInBits() / 8).valueOrOne();
12524 unsigned StackOffset =
12525 State.AllocateStack(ValVT.getStoreSize(), StackAlign);
12526 State.addLoc(
12527 CCValAssign::getMem(ValNo, ValVT, StackOffset, LocVT, LocInfo));
12528 } else {
12529 // Can't pass scalable vectors on the stack.
12530 return true;
12531 }
12532 }
12533
12534 return false;
12535 }
12536
12537 return true; // CC didn't match.
12538 }
12539
static bool CC_RISCV_GHC(unsigned ValNo, MVT ValVT, MVT LocVT,
                         CCValAssign::LocInfo LocInfo,
                         ISD::ArgFlagsTy ArgFlags, CCState &State) {
12543
12544 if (ArgFlags.isNest()) {
12545 report_fatal_error(
12546 "Attribute 'nest' is not supported in GHC calling convention");
12547 }
12548
12549 if (LocVT == MVT::i32 || LocVT == MVT::i64) {
12550 // Pass in STG registers: Base, Sp, Hp, R1, R2, R3, R4, R5, R6, R7, SpLim
12551 // s1 s2 s3 s4 s5 s6 s7 s8 s9 s10 s11
12552 static const MCPhysReg GPRList[] = {
12553 RISCV::X9, RISCV::X18, RISCV::X19, RISCV::X20, RISCV::X21, RISCV::X22,
12554 RISCV::X23, RISCV::X24, RISCV::X25, RISCV::X26, RISCV::X27};
12555 if (unsigned Reg = State.AllocateReg(GPRList)) {
12556 State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
12557 return false;
12558 }
12559 }
12560
12561 if (LocVT == MVT::f32) {
12562 // Pass in STG registers: F1, ..., F6
12563 // fs0 ... fs5
12564 static const MCPhysReg FPR32List[] = {RISCV::F8_F, RISCV::F9_F,
12565 RISCV::F18_F, RISCV::F19_F,
12566 RISCV::F20_F, RISCV::F21_F};
12567 if (unsigned Reg = State.AllocateReg(FPR32List)) {
12568 State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
12569 return false;
12570 }
12571 }
12572
12573 if (LocVT == MVT::f64) {
12574 // Pass in STG registers: D1, ..., D6
12575 // fs6 ... fs11
12576 static const MCPhysReg FPR64List[] = {RISCV::F22_D, RISCV::F23_D,
12577 RISCV::F24_D, RISCV::F25_D,
12578 RISCV::F26_D, RISCV::F27_D};
12579 if (unsigned Reg = State.AllocateReg(FPR64List)) {
12580 State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
12581 return false;
12582 }
12583 }
12584
12585 report_fatal_error("No registers left in GHC calling convention");
12586 return true;
12587 }
12588
12589 // Transform physical registers into virtual registers.
SDValue RISCVTargetLowering::LowerFormalArguments(
    SDValue Chain, CallingConv::ID CallConv, bool IsVarArg,
    const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &DL,
    SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
12594
12595 MachineFunction &MF = DAG.getMachineFunction();
12596
12597 switch (CallConv) {
12598 default:
12599 report_fatal_error("Unsupported calling convention");
12600 case CallingConv::C:
12601 case CallingConv::Fast:
12602 break;
12603 case CallingConv::GHC:
12604 if (!MF.getSubtarget().getFeatureBits()[RISCV::FeatureStdExtF] ||
12605 !MF.getSubtarget().getFeatureBits()[RISCV::FeatureStdExtD])
12606 report_fatal_error(
12607 "GHC calling convention requires the F and D instruction set extensions");
12608 }
12609
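  // Functions marked with the "interrupt" attribute are lowered as interrupt
  // handlers. Illustrative IR (assumed) accepted by the checks below:
  //   define void @isr() "interrupt"="machine" { ... ret void }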
12610 const Function &Func = MF.getFunction();
12611 if (Func.hasFnAttribute("interrupt")) {
12612 if (!Func.arg_empty())
12613 report_fatal_error(
12614 "Functions with the interrupt attribute cannot have arguments!");
12615
12616 StringRef Kind =
12617 MF.getFunction().getFnAttribute("interrupt").getValueAsString();
12618
12619 if (!(Kind == "user" || Kind == "supervisor" || Kind == "machine"))
12620 report_fatal_error(
12621 "Function interrupt attribute argument not supported!");
12622 }
12623
12624 EVT PtrVT = getPointerTy(DAG.getDataLayout());
12625 MVT XLenVT = Subtarget.getXLenVT();
12626 unsigned XLenInBytes = Subtarget.getXLen() / 8;
  // Used with varargs to accumulate store chains.
12628 std::vector<SDValue> OutChains;
12629
12630 // Assign locations to all of the incoming arguments.
12631 SmallVector<CCValAssign, 16> ArgLocs;
12632 CCState CCInfo(CallConv, IsVarArg, MF, ArgLocs, *DAG.getContext());
12633
12634 if (CallConv == CallingConv::GHC)
12635 CCInfo.AnalyzeFormalArguments(Ins, CC_RISCV_GHC);
12636 else
12637 analyzeInputArgs(MF, CCInfo, Ins, /*IsRet=*/false,
12638 CallConv == CallingConv::Fast ? CC_RISCV_FastCC
12639 : CC_RISCV);
12640
12641 for (unsigned i = 0, e = ArgLocs.size(); i != e; ++i) {
12642 CCValAssign &VA = ArgLocs[i];
12643 SDValue ArgValue;
12644 // Passing f64 on RV32D with a soft float ABI must be handled as a special
12645 // case.
12646 if (VA.getLocVT() == MVT::i32 && VA.getValVT() == MVT::f64)
12647 ArgValue = unpackF64OnRV32DSoftABI(DAG, Chain, VA, DL);
12648 else if (VA.isRegLoc())
12649 ArgValue = unpackFromRegLoc(DAG, Chain, VA, DL, Ins[i], *this);
12650 else
12651 ArgValue = unpackFromMemLoc(DAG, Chain, VA, DL);
12652
12653 if (VA.getLocInfo() == CCValAssign::Indirect) {
12654 // If the original argument was split and passed by reference (e.g. i128
12655 // on RV32), we need to load all parts of it here (using the same
12656 // address). Vectors may be partly split to registers and partly to the
12657 // stack, in which case the base address is partly offset and subsequent
12658 // stores are relative to that.
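      // Illustrative example (assumed): an i128 argument on RV32 arrives here
      // as a single GPR holding its address; its four i32 parts are then
      // loaded from offsets 0, 4, 8 and 12 of that same base address.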
12659 InVals.push_back(DAG.getLoad(VA.getValVT(), DL, Chain, ArgValue,
12660 MachinePointerInfo()));
12661 unsigned ArgIndex = Ins[i].OrigArgIndex;
12662 unsigned ArgPartOffset = Ins[i].PartOffset;
12663 assert(VA.getValVT().isVector() || ArgPartOffset == 0);
12664 while (i + 1 != e && Ins[i + 1].OrigArgIndex == ArgIndex) {
12665 CCValAssign &PartVA = ArgLocs[i + 1];
12666 unsigned PartOffset = Ins[i + 1].PartOffset - ArgPartOffset;
12667 SDValue Offset = DAG.getIntPtrConstant(PartOffset, DL);
12668 if (PartVA.getValVT().isScalableVector())
12669 Offset = DAG.getNode(ISD::VSCALE, DL, XLenVT, Offset);
12670 SDValue Address = DAG.getNode(ISD::ADD, DL, PtrVT, ArgValue, Offset);
12671 InVals.push_back(DAG.getLoad(PartVA.getValVT(), DL, Chain, Address,
12672 MachinePointerInfo()));
12673 ++i;
12674 }
12675 continue;
12676 }
12677 InVals.push_back(ArgValue);
12678 }
12679
12680 if (any_of(ArgLocs,
12681 [](CCValAssign &VA) { return VA.getLocVT().isScalableVector(); }))
12682 MF.getInfo<RISCVMachineFunctionInfo>()->setIsVectorCall();
12683
12684 if (IsVarArg) {
12685 ArrayRef<MCPhysReg> ArgRegs = ArrayRef(ArgGPRs);
12686 unsigned Idx = CCInfo.getFirstUnallocated(ArgRegs);
12687 const TargetRegisterClass *RC = &RISCV::GPRRegClass;
12688 MachineFrameInfo &MFI = MF.getFrameInfo();
12689 MachineRegisterInfo &RegInfo = MF.getRegInfo();
12690 RISCVMachineFunctionInfo *RVFI = MF.getInfo<RISCVMachineFunctionInfo>();
12691
12692 // Offset of the first variable argument from stack pointer, and size of
12693 // the vararg save area. For now, the varargs save area is either zero or
12694 // large enough to hold a0-a7.
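    // Illustrative layout (assumed) on RV64 when a0-a3 carry fixed arguments:
    // a4-a7 are spilled to a 32-byte save area at offsets -32..-8 from the
    // incoming stack pointer, and VarArgsFrameIndex refers to the a4 slot.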
12695 int VaArgOffset, VarArgsSaveSize;
12696
12697 // If all registers are allocated, then all varargs must be passed on the
12698 // stack and we don't need to save any argregs.
12699 if (ArgRegs.size() == Idx) {
12700 VaArgOffset = CCInfo.getNextStackOffset();
12701 VarArgsSaveSize = 0;
12702 } else {
12703 VarArgsSaveSize = XLenInBytes * (ArgRegs.size() - Idx);
12704 VaArgOffset = -VarArgsSaveSize;
12705 }
12706
    // Record the frame index of the first variable argument, which is needed
    // by VASTART.
12709 int FI = MFI.CreateFixedObject(XLenInBytes, VaArgOffset, true);
12710 RVFI->setVarArgsFrameIndex(FI);
12711
    // If saving an odd number of registers then create an extra stack slot to
    // ensure that the frame pointer is 2*XLEN-aligned, which in turn ensures
    // offsets to even-numbered registers remain 2*XLEN-aligned.
12715 if (Idx % 2) {
12716 MFI.CreateFixedObject(XLenInBytes, VaArgOffset - (int)XLenInBytes, true);
12717 VarArgsSaveSize += XLenInBytes;
12718 }
12719
12720 // Copy the integer registers that may have been used for passing varargs
12721 // to the vararg save area.
12722 for (unsigned I = Idx; I < ArgRegs.size();
12723 ++I, VaArgOffset += XLenInBytes) {
12724 const Register Reg = RegInfo.createVirtualRegister(RC);
12725 RegInfo.addLiveIn(ArgRegs[I], Reg);
12726 SDValue ArgValue = DAG.getCopyFromReg(Chain, DL, Reg, XLenVT);
12727 FI = MFI.CreateFixedObject(XLenInBytes, VaArgOffset, true);
12728 SDValue PtrOff = DAG.getFrameIndex(FI, getPointerTy(DAG.getDataLayout()));
12729 SDValue Store = DAG.getStore(Chain, DL, ArgValue, PtrOff,
12730 MachinePointerInfo::getFixedStack(MF, FI));
12731 cast<StoreSDNode>(Store.getNode())
12732 ->getMemOperand()
12733 ->setValue((Value *)nullptr);
12734 OutChains.push_back(Store);
12735 }
12736 RVFI->setVarArgsSaveSize(VarArgsSaveSize);
12737 }
12738
12739 // All stores are grouped in one node to allow the matching between
12740 // the size of Ins and InVals. This only happens for vararg functions.
12741 if (!OutChains.empty()) {
12742 OutChains.push_back(Chain);
12743 Chain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, OutChains);
12744 }
12745
12746 return Chain;
12747 }
12748
12749 /// isEligibleForTailCallOptimization - Check whether the call is eligible
12750 /// for tail call optimization.
12751 /// Note: This is modelled after ARM's IsEligibleForTailCallOptimization.
bool RISCVTargetLowering::isEligibleForTailCallOptimization(
    CCState &CCInfo, CallLoweringInfo &CLI, MachineFunction &MF,
    const SmallVector<CCValAssign, 16> &ArgLocs) const {
12755
12756 auto &Callee = CLI.Callee;
12757 auto CalleeCC = CLI.CallConv;
12758 auto &Outs = CLI.Outs;
12759 auto &Caller = MF.getFunction();
12760 auto CallerCC = Caller.getCallingConv();
12761
12762 // Exception-handling functions need a special set of instructions to
12763 // indicate a return to the hardware. Tail-calling another function would
12764 // probably break this.
12765 // TODO: The "interrupt" attribute isn't currently defined by RISC-V. This
12766 // should be expanded as new function attributes are introduced.
12767 if (Caller.hasFnAttribute("interrupt"))
12768 return false;
12769
12770 // Do not tail call opt if the stack is used to pass parameters.
12771 if (CCInfo.getNextStackOffset() != 0)
12772 return false;
12773
12774 // Do not tail call opt if any parameters need to be passed indirectly.
12775 // Since long doubles (fp128) and i128 are larger than 2*XLEN, they are
12776 // passed indirectly. So the address of the value will be passed in a
12777 // register, or if not available, then the address is put on the stack. In
12778 // order to pass indirectly, space on the stack often needs to be allocated
12779 // in order to store the value. In this case the CCInfo.getNextStackOffset()
12780 // != 0 check is not enough and we need to check if any CCValAssign ArgsLocs
12781 // are passed CCValAssign::Indirect.
12782 for (auto &VA : ArgLocs)
12783 if (VA.getLocInfo() == CCValAssign::Indirect)
12784 return false;
12785
12786 // Do not tail call opt if either caller or callee uses struct return
12787 // semantics.
12788 auto IsCallerStructRet = Caller.hasStructRetAttr();
12789 auto IsCalleeStructRet = Outs.empty() ? false : Outs[0].Flags.isSRet();
12790 if (IsCallerStructRet || IsCalleeStructRet)
12791 return false;
12792
12793 // Externally-defined functions with weak linkage should not be
12794 // tail-called. The behaviour of branch instructions in this situation (as
12795 // used for tail calls) is implementation-defined, so we cannot rely on the
12796 // linker replacing the tail call with a return.
12797 if (GlobalAddressSDNode *G = dyn_cast<GlobalAddressSDNode>(Callee)) {
12798 const GlobalValue *GV = G->getGlobal();
12799 if (GV->hasExternalWeakLinkage())
12800 return false;
12801 }
12802
12803 // The callee has to preserve all registers the caller needs to preserve.
12804 const RISCVRegisterInfo *TRI = Subtarget.getRegisterInfo();
12805 const uint32_t *CallerPreserved = TRI->getCallPreservedMask(MF, CallerCC);
12806 if (CalleeCC != CallerCC) {
12807 const uint32_t *CalleePreserved = TRI->getCallPreservedMask(MF, CalleeCC);
12808 if (!TRI->regmaskSubsetEqual(CallerPreserved, CalleePreserved))
12809 return false;
12810 }
12811
12812 // Byval parameters hand the function a pointer directly into the stack area
12813 // we want to reuse during a tail call. Working around this *is* possible
12814 // but less efficient and uglier in LowerCall.
12815 for (auto &Arg : Outs)
12816 if (Arg.Flags.isByVal())
12817 return false;
12818
12819 return true;
12820 }
12821
static Align getPrefTypeAlign(EVT VT, SelectionDAG &DAG) {
12823 return DAG.getDataLayout().getPrefTypeAlign(
12824 VT.getTypeForEVT(*DAG.getContext()));
12825 }
12826
12827 // Lower a call to a callseq_start + CALL + callseq_end chain, and add input
12828 // and output parameter nodes.
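// Illustrative DAG shape produced here (assumed, simplified):
//   callseq_start -> CopyToReg(a0, ...), stores for stack args
//   -> RISCVISD::CALL or RISCVISD::TAIL -> callseq_end
//   -> CopyFromReg for each returned value.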
SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
                                       SmallVectorImpl<SDValue> &InVals) const {
12831 SelectionDAG &DAG = CLI.DAG;
12832 SDLoc &DL = CLI.DL;
12833 SmallVectorImpl<ISD::OutputArg> &Outs = CLI.Outs;
12834 SmallVectorImpl<SDValue> &OutVals = CLI.OutVals;
12835 SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
12836 SDValue Chain = CLI.Chain;
12837 SDValue Callee = CLI.Callee;
12838 bool &IsTailCall = CLI.IsTailCall;
12839 CallingConv::ID CallConv = CLI.CallConv;
12840 bool IsVarArg = CLI.IsVarArg;
12841 EVT PtrVT = getPointerTy(DAG.getDataLayout());
12842 MVT XLenVT = Subtarget.getXLenVT();
12843
12844 MachineFunction &MF = DAG.getMachineFunction();
12845
12846 // Analyze the operands of the call, assigning locations to each operand.
12847 SmallVector<CCValAssign, 16> ArgLocs;
12848 CCState ArgCCInfo(CallConv, IsVarArg, MF, ArgLocs, *DAG.getContext());
12849
12850 if (CallConv == CallingConv::GHC)
12851 ArgCCInfo.AnalyzeCallOperands(Outs, CC_RISCV_GHC);
12852 else
12853 analyzeOutputArgs(MF, ArgCCInfo, Outs, /*IsRet=*/false, &CLI,
12854 CallConv == CallingConv::Fast ? CC_RISCV_FastCC
12855 : CC_RISCV);
12856
12857 // Check if it's really possible to do a tail call.
12858 if (IsTailCall)
12859 IsTailCall = isEligibleForTailCallOptimization(ArgCCInfo, CLI, MF, ArgLocs);
12860
12861 if (IsTailCall)
12862 ++NumTailCalls;
12863 else if (CLI.CB && CLI.CB->isMustTailCall())
12864 report_fatal_error("failed to perform tail call elimination on a call "
12865 "site marked musttail");
12866
12867 // Get a count of how many bytes are to be pushed on the stack.
12868 unsigned NumBytes = ArgCCInfo.getNextStackOffset();
12869
12870 // Create local copies for byval args
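  // Illustrative example (assumed): for a byval C parameter such as
  // `void callee(struct S s)`, the struct is memcpy'd into a fresh stack
  // object below and the address of that copy is what gets passed.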
12871 SmallVector<SDValue, 8> ByValArgs;
12872 for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
12873 ISD::ArgFlagsTy Flags = Outs[i].Flags;
12874 if (!Flags.isByVal())
12875 continue;
12876
12877 SDValue Arg = OutVals[i];
12878 unsigned Size = Flags.getByValSize();
12879 Align Alignment = Flags.getNonZeroByValAlign();
12880
12881 int FI =
12882 MF.getFrameInfo().CreateStackObject(Size, Alignment, /*isSS=*/false);
12883 SDValue FIPtr = DAG.getFrameIndex(FI, getPointerTy(DAG.getDataLayout()));
12884 SDValue SizeNode = DAG.getConstant(Size, DL, XLenVT);
12885
12886 Chain = DAG.getMemcpy(Chain, DL, FIPtr, Arg, SizeNode, Alignment,
12887 /*IsVolatile=*/false,
12888 /*AlwaysInline=*/false, IsTailCall,
12889 MachinePointerInfo(), MachinePointerInfo());
12890 ByValArgs.push_back(FIPtr);
12891 }
12892
12893 if (!IsTailCall)
12894 Chain = DAG.getCALLSEQ_START(Chain, NumBytes, 0, CLI.DL);
12895
12896 // Copy argument values to their designated locations.
12897 SmallVector<std::pair<Register, SDValue>, 8> RegsToPass;
12898 SmallVector<SDValue, 8> MemOpChains;
12899 SDValue StackPtr;
12900 for (unsigned i = 0, j = 0, e = ArgLocs.size(); i != e; ++i) {
12901 CCValAssign &VA = ArgLocs[i];
12902 SDValue ArgValue = OutVals[i];
12903 ISD::ArgFlagsTy Flags = Outs[i].Flags;
12904
12905 // Handle passing f64 on RV32D with a soft float ABI as a special case.
12906 bool IsF64OnRV32DSoftABI =
12907 VA.getLocVT() == MVT::i32 && VA.getValVT() == MVT::f64;
12908 if (IsF64OnRV32DSoftABI && VA.isRegLoc()) {
12909 SDValue SplitF64 = DAG.getNode(
12910 RISCVISD::SplitF64, DL, DAG.getVTList(MVT::i32, MVT::i32), ArgValue);
12911 SDValue Lo = SplitF64.getValue(0);
12912 SDValue Hi = SplitF64.getValue(1);
12913
12914 Register RegLo = VA.getLocReg();
12915 RegsToPass.push_back(std::make_pair(RegLo, Lo));
12916
12917 if (RegLo == RISCV::X17) {
12918 // Second half of f64 is passed on the stack.
12919 // Work out the address of the stack slot.
12920 if (!StackPtr.getNode())
12921 StackPtr = DAG.getCopyFromReg(Chain, DL, RISCV::X2, PtrVT);
12922 // Emit the store.
12923 MemOpChains.push_back(
12924 DAG.getStore(Chain, DL, Hi, StackPtr, MachinePointerInfo()));
12925 } else {
12926 // Second half of f64 is passed in another GPR.
12927 assert(RegLo < RISCV::X31 && "Invalid register pair");
12928 Register RegHigh = RegLo + 1;
12929 RegsToPass.push_back(std::make_pair(RegHigh, Hi));
12930 }
12931 continue;
12932 }
12933
12934 // IsF64OnRV32DSoftABI && VA.isMemLoc() is handled below in the same way
12935 // as any other MemLoc.
12936
12937 // Promote the value if needed.
12938 // For now, only handle fully promoted and indirect arguments.
12939 if (VA.getLocInfo() == CCValAssign::Indirect) {
12940 // Store the argument in a stack slot and pass its address.
12941 Align StackAlign =
12942 std::max(getPrefTypeAlign(Outs[i].ArgVT, DAG),
12943 getPrefTypeAlign(ArgValue.getValueType(), DAG));
12944 TypeSize StoredSize = ArgValue.getValueType().getStoreSize();
12945 // If the original argument was split (e.g. i128), we need
12946 // to store the required parts of it here (and pass just one address).
12947 // Vectors may be partly split to registers and partly to the stack, in
12948 // which case the base address is partly offset and subsequent stores are
12949 // relative to that.
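      // Illustrative example (assumed): an i128 outgoing argument on RV32 is
      // stored as four i32 parts into the single spill slot created below, and
      // only that slot's address ends up in a register or stack location.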
12950 unsigned ArgIndex = Outs[i].OrigArgIndex;
12951 unsigned ArgPartOffset = Outs[i].PartOffset;
12952 assert(VA.getValVT().isVector() || ArgPartOffset == 0);
12953 // Calculate the total size to store. We don't have access to what we're
12954 // actually storing other than performing the loop and collecting the
12955 // info.
12956 SmallVector<std::pair<SDValue, SDValue>> Parts;
12957 while (i + 1 != e && Outs[i + 1].OrigArgIndex == ArgIndex) {
12958 SDValue PartValue = OutVals[i + 1];
12959 unsigned PartOffset = Outs[i + 1].PartOffset - ArgPartOffset;
12960 SDValue Offset = DAG.getIntPtrConstant(PartOffset, DL);
12961 EVT PartVT = PartValue.getValueType();
12962 if (PartVT.isScalableVector())
12963 Offset = DAG.getNode(ISD::VSCALE, DL, XLenVT, Offset);
12964 StoredSize += PartVT.getStoreSize();
12965 StackAlign = std::max(StackAlign, getPrefTypeAlign(PartVT, DAG));
12966 Parts.push_back(std::make_pair(PartValue, Offset));
12967 ++i;
12968 }
12969 SDValue SpillSlot = DAG.CreateStackTemporary(StoredSize, StackAlign);
12970 int FI = cast<FrameIndexSDNode>(SpillSlot)->getIndex();
12971 MemOpChains.push_back(
12972 DAG.getStore(Chain, DL, ArgValue, SpillSlot,
12973 MachinePointerInfo::getFixedStack(MF, FI)));
12974 for (const auto &Part : Parts) {
12975 SDValue PartValue = Part.first;
12976 SDValue PartOffset = Part.second;
12977 SDValue Address =
12978 DAG.getNode(ISD::ADD, DL, PtrVT, SpillSlot, PartOffset);
12979 MemOpChains.push_back(
12980 DAG.getStore(Chain, DL, PartValue, Address,
12981 MachinePointerInfo::getFixedStack(MF, FI)));
12982 }
12983 ArgValue = SpillSlot;
12984 } else {
12985 ArgValue = convertValVTToLocVT(DAG, ArgValue, VA, DL, Subtarget);
12986 }
12987
12988 // Use local copy if it is a byval arg.
12989 if (Flags.isByVal())
12990 ArgValue = ByValArgs[j++];
12991
12992 if (VA.isRegLoc()) {
12993 // Queue up the argument copies and emit them at the end.
12994 RegsToPass.push_back(std::make_pair(VA.getLocReg(), ArgValue));
12995 } else {
12996 assert(VA.isMemLoc() && "Argument not register or memory");
12997 assert(!IsTailCall && "Tail call not allowed if stack is used "
12998 "for passing parameters");
12999
13000 // Work out the address of the stack slot.
13001 if (!StackPtr.getNode())
13002 StackPtr = DAG.getCopyFromReg(Chain, DL, RISCV::X2, PtrVT);
13003 SDValue Address =
13004 DAG.getNode(ISD::ADD, DL, PtrVT, StackPtr,
13005 DAG.getIntPtrConstant(VA.getLocMemOffset(), DL));
13006
13007 // Emit the store.
13008 MemOpChains.push_back(
13009 DAG.getStore(Chain, DL, ArgValue, Address, MachinePointerInfo()));
13010 }
13011 }
13012
13013 // Join the stores, which are independent of one another.
13014 if (!MemOpChains.empty())
13015 Chain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, MemOpChains);
13016
13017 SDValue Glue;
13018
13019 // Build a sequence of copy-to-reg nodes, chained and glued together.
13020 for (auto &Reg : RegsToPass) {
13021 Chain = DAG.getCopyToReg(Chain, DL, Reg.first, Reg.second, Glue);
13022 Glue = Chain.getValue(1);
13023 }
13024
13025 // Validate that none of the argument registers have been marked as
13026 // reserved, if so report an error. Do the same for the return address if this
13027 // is not a tailcall.
13028 validateCCReservedRegs(RegsToPass, MF);
13029 if (!IsTailCall &&
13030 MF.getSubtarget<RISCVSubtarget>().isRegisterReservedByUser(RISCV::X1))
13031 MF.getFunction().getContext().diagnose(DiagnosticInfoUnsupported{
13032 MF.getFunction(),
13033 "Return address register required, but has been reserved."});
13034
13035 // If the callee is a GlobalAddress/ExternalSymbol node, turn it into a
13036 // TargetGlobalAddress/TargetExternalSymbol node so that legalize won't
13037 // split it and then direct call can be matched by PseudoCALL.
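  // e.g. (illustrative) a dso_local callee keeps MO_CALL and is selected to
  // PseudoCALL (auipc+jalr, possibly relaxed by the linker to jal), while a
  // preemptible callee gets MO_PLT and calls through the PLT.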
13038 if (GlobalAddressSDNode *S = dyn_cast<GlobalAddressSDNode>(Callee)) {
13039 const GlobalValue *GV = S->getGlobal();
13040
13041 unsigned OpFlags = RISCVII::MO_CALL;
13042 if (!getTargetMachine().shouldAssumeDSOLocal(*GV->getParent(), GV))
13043 OpFlags = RISCVII::MO_PLT;
13044
13045 Callee = DAG.getTargetGlobalAddress(GV, DL, PtrVT, 0, OpFlags);
13046 } else if (ExternalSymbolSDNode *S = dyn_cast<ExternalSymbolSDNode>(Callee)) {
13047 unsigned OpFlags = RISCVII::MO_CALL;
13048
13049 if (!getTargetMachine().shouldAssumeDSOLocal(*MF.getFunction().getParent(),
13050 nullptr))
13051 OpFlags = RISCVII::MO_PLT;
13052
13053 Callee = DAG.getTargetExternalSymbol(S->getSymbol(), PtrVT, OpFlags);
13054 }
13055
13056 // The first call operand is the chain and the second is the target address.
13057 SmallVector<SDValue, 8> Ops;
13058 Ops.push_back(Chain);
13059 Ops.push_back(Callee);
13060
13061 // Add argument registers to the end of the list so that they are
13062 // known live into the call.
13063 for (auto &Reg : RegsToPass)
13064 Ops.push_back(DAG.getRegister(Reg.first, Reg.second.getValueType()));
13065
13066 if (!IsTailCall) {
13067 // Add a register mask operand representing the call-preserved registers.
13068 const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
13069 const uint32_t *Mask = TRI->getCallPreservedMask(MF, CallConv);
13070 assert(Mask && "Missing call preserved mask for calling convention");
13071 Ops.push_back(DAG.getRegisterMask(Mask));
13072 }
13073
13074 // Glue the call to the argument copies, if any.
13075 if (Glue.getNode())
13076 Ops.push_back(Glue);
13077
13078 // Emit the call.
13079 SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue);
13080
13081 if (IsTailCall) {
13082 MF.getFrameInfo().setHasTailCall();
13083 return DAG.getNode(RISCVISD::TAIL, DL, NodeTys, Ops);
13084 }
13085
13086 Chain = DAG.getNode(RISCVISD::CALL, DL, NodeTys, Ops);
13087 DAG.addNoMergeSiteInfo(Chain.getNode(), CLI.NoMerge);
13088 Glue = Chain.getValue(1);
13089
13090 // Mark the end of the call, which is glued to the call itself.
13091 Chain = DAG.getCALLSEQ_END(Chain, NumBytes, 0, Glue, DL);
13092 Glue = Chain.getValue(1);
13093
13094 // Assign locations to each value returned by this call.
13095 SmallVector<CCValAssign, 16> RVLocs;
13096 CCState RetCCInfo(CallConv, IsVarArg, MF, RVLocs, *DAG.getContext());
13097 analyzeInputArgs(MF, RetCCInfo, Ins, /*IsRet=*/true, CC_RISCV);
13098
13099 // Copy all of the result registers out of their specified physreg.
13100 for (auto &VA : RVLocs) {
13101 // Copy the value out
13102 SDValue RetValue =
13103 DAG.getCopyFromReg(Chain, DL, VA.getLocReg(), VA.getLocVT(), Glue);
13104 // Glue the RetValue to the end of the call sequence
13105 Chain = RetValue.getValue(1);
13106 Glue = RetValue.getValue(2);
13107
13108 if (VA.getLocVT() == MVT::i32 && VA.getValVT() == MVT::f64) {
13109 assert(VA.getLocReg() == ArgGPRs[0] && "Unexpected reg assignment");
13110 SDValue RetValue2 =
13111 DAG.getCopyFromReg(Chain, DL, ArgGPRs[1], MVT::i32, Glue);
13112 Chain = RetValue2.getValue(1);
13113 Glue = RetValue2.getValue(2);
13114 RetValue = DAG.getNode(RISCVISD::BuildPairF64, DL, MVT::f64, RetValue,
13115 RetValue2);
13116 }
13117
13118 RetValue = convertLocVTToValVT(DAG, RetValue, VA, DL, Subtarget);
13119
13120 InVals.push_back(RetValue);
13121 }
13122
13123 return Chain;
13124 }
13125
bool RISCVTargetLowering::CanLowerReturn(
    CallingConv::ID CallConv, MachineFunction &MF, bool IsVarArg,
    const SmallVectorImpl<ISD::OutputArg> &Outs, LLVMContext &Context) const {
13129 SmallVector<CCValAssign, 16> RVLocs;
13130 CCState CCInfo(CallConv, IsVarArg, MF, RVLocs, Context);
13131
13132 std::optional<unsigned> FirstMaskArgument;
13133 if (Subtarget.hasVInstructions())
13134 FirstMaskArgument = preAssignMask(Outs);
13135
13136 for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
13137 MVT VT = Outs[i].VT;
13138 ISD::ArgFlagsTy ArgFlags = Outs[i].Flags;
13139 RISCVABI::ABI ABI = MF.getSubtarget<RISCVSubtarget>().getTargetABI();
13140 if (CC_RISCV(MF.getDataLayout(), ABI, i, VT, VT, CCValAssign::Full,
13141 ArgFlags, CCInfo, /*IsFixed=*/true, /*IsRet=*/true, nullptr,
13142 *this, FirstMaskArgument))
13143 return false;
13144 }
13145 return true;
13146 }
13147
SDValue
RISCVTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
                                 bool IsVarArg,
                                 const SmallVectorImpl<ISD::OutputArg> &Outs,
                                 const SmallVectorImpl<SDValue> &OutVals,
                                 const SDLoc &DL, SelectionDAG &DAG) const {
13154 MachineFunction &MF = DAG.getMachineFunction();
13155 const RISCVSubtarget &STI = MF.getSubtarget<RISCVSubtarget>();
13156
13157 // Stores the assignment of the return value to a location.
13158 SmallVector<CCValAssign, 16> RVLocs;
13159
13160 // Info about the registers and stack slot.
13161 CCState CCInfo(CallConv, IsVarArg, DAG.getMachineFunction(), RVLocs,
13162 *DAG.getContext());
13163
13164 analyzeOutputArgs(DAG.getMachineFunction(), CCInfo, Outs, /*IsRet=*/true,
13165 nullptr, CC_RISCV);
13166
13167 if (CallConv == CallingConv::GHC && !RVLocs.empty())
13168 report_fatal_error("GHC functions return void only");
13169
13170 SDValue Glue;
13171 SmallVector<SDValue, 4> RetOps(1, Chain);
13172
13173 // Copy the result values into the output registers.
13174 for (unsigned i = 0, e = RVLocs.size(); i < e; ++i) {
13175 SDValue Val = OutVals[i];
13176 CCValAssign &VA = RVLocs[i];
13177 assert(VA.isRegLoc() && "Can only return in registers!");
13178
13179 if (VA.getLocVT() == MVT::i32 && VA.getValVT() == MVT::f64) {
13180 // Handle returning f64 on RV32D with a soft float ABI.
13181 assert(VA.isRegLoc() && "Expected return via registers");
13182 SDValue SplitF64 = DAG.getNode(RISCVISD::SplitF64, DL,
13183 DAG.getVTList(MVT::i32, MVT::i32), Val);
13184 SDValue Lo = SplitF64.getValue(0);
13185 SDValue Hi = SplitF64.getValue(1);
13186 Register RegLo = VA.getLocReg();
13187 assert(RegLo < RISCV::X31 && "Invalid register pair");
13188 Register RegHi = RegLo + 1;
13189
13190 if (STI.isRegisterReservedByUser(RegLo) ||
13191 STI.isRegisterReservedByUser(RegHi))
13192 MF.getFunction().getContext().diagnose(DiagnosticInfoUnsupported{
13193 MF.getFunction(),
13194 "Return value register required, but has been reserved."});
13195
13196 Chain = DAG.getCopyToReg(Chain, DL, RegLo, Lo, Glue);
13197 Glue = Chain.getValue(1);
13198 RetOps.push_back(DAG.getRegister(RegLo, MVT::i32));
13199 Chain = DAG.getCopyToReg(Chain, DL, RegHi, Hi, Glue);
13200 Glue = Chain.getValue(1);
13201 RetOps.push_back(DAG.getRegister(RegHi, MVT::i32));
13202 } else {
13203 // Handle a 'normal' return.
13204 Val = convertValVTToLocVT(DAG, Val, VA, DL, Subtarget);
13205 Chain = DAG.getCopyToReg(Chain, DL, VA.getLocReg(), Val, Glue);
13206
13207 if (STI.isRegisterReservedByUser(VA.getLocReg()))
13208 MF.getFunction().getContext().diagnose(DiagnosticInfoUnsupported{
13209 MF.getFunction(),
13210 "Return value register required, but has been reserved."});
13211
13212 // Guarantee that all emitted copies are stuck together.
13213 Glue = Chain.getValue(1);
13214 RetOps.push_back(DAG.getRegister(VA.getLocReg(), VA.getLocVT()));
13215 }
13216 }
13217
13218 RetOps[0] = Chain; // Update chain.
13219
13220 // Add the glue node if we have it.
13221 if (Glue.getNode()) {
13222 RetOps.push_back(Glue);
13223 }
13224
13225 if (any_of(RVLocs,
13226 [](CCValAssign &VA) { return VA.getLocVT().isScalableVector(); }))
13227 MF.getInfo<RISCVMachineFunctionInfo>()->setIsVectorCall();
13228
13229 unsigned RetOpc = RISCVISD::RET_FLAG;
13230 // Interrupt service routines use different return instructions.
13231 const Function &Func = DAG.getMachineFunction().getFunction();
13232 if (Func.hasFnAttribute("interrupt")) {
13233 if (!Func.getReturnType()->isVoidTy())
13234 report_fatal_error(
13235 "Functions with the interrupt attribute must have void return type!");
13236
13237 MachineFunction &MF = DAG.getMachineFunction();
13238 StringRef Kind =
13239 MF.getFunction().getFnAttribute("interrupt").getValueAsString();
13240
13241 if (Kind == "user")
13242 RetOpc = RISCVISD::URET_FLAG;
13243 else if (Kind == "supervisor")
13244 RetOpc = RISCVISD::SRET_FLAG;
13245 else
13246 RetOpc = RISCVISD::MRET_FLAG;
13247 }
13248
13249 return DAG.getNode(RetOpc, DL, MVT::Other, RetOps);
13250 }
13251
void RISCVTargetLowering::validateCCReservedRegs(
    const SmallVectorImpl<std::pair<llvm::Register, llvm::SDValue>> &Regs,
    MachineFunction &MF) const {
13255 const Function &F = MF.getFunction();
13256 const RISCVSubtarget &STI = MF.getSubtarget<RISCVSubtarget>();
13257
13258 if (llvm::any_of(Regs, [&STI](auto Reg) {
13259 return STI.isRegisterReservedByUser(Reg.first);
13260 }))
13261 F.getContext().diagnose(DiagnosticInfoUnsupported{
13262 F, "Argument register required, but has been reserved."});
13263 }
13264
13265 // Check if the result of the node is only used as a return value, as
13266 // otherwise we can't perform a tail-call.
bool RISCVTargetLowering::isUsedByReturnOnly(SDNode *N, SDValue &Chain) const {
13268 if (N->getNumValues() != 1)
13269 return false;
13270 if (!N->hasNUsesOfValue(1, 0))
13271 return false;
13272
13273 SDNode *Copy = *N->use_begin();
13274 // TODO: Handle additional opcodes in order to support tail-calling libcalls
13275 // with soft float ABIs.
13276 if (Copy->getOpcode() != ISD::CopyToReg) {
13277 return false;
13278 }
13279
13280 // If the ISD::CopyToReg has a glue operand, we conservatively assume it
13281 // isn't safe to perform a tail call.
13282 if (Copy->getOperand(Copy->getNumOperands() - 1).getValueType() == MVT::Glue)
13283 return false;
13284
13285 // The copy must be used by a RISCVISD::RET_FLAG, and nothing else.
13286 bool HasRet = false;
13287 for (SDNode *Node : Copy->uses()) {
13288 if (Node->getOpcode() != RISCVISD::RET_FLAG)
13289 return false;
13290 HasRet = true;
13291 }
13292 if (!HasRet)
13293 return false;
13294
13295 Chain = Copy->getOperand(0);
13296 return true;
13297 }
13298
bool RISCVTargetLowering::mayBeEmittedAsTailCall(const CallInst *CI) const {
13300 return CI->isTailCall();
13301 }
13302
const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
13304 #define NODE_NAME_CASE(NODE) \
13305 case RISCVISD::NODE: \
13306 return "RISCVISD::" #NODE;
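// e.g. NODE_NAME_CASE(CALL) expands to:
//   case RISCVISD::CALL: return "RISCVISD::CALL";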
13307 // clang-format off
13308 switch ((RISCVISD::NodeType)Opcode) {
13309 case RISCVISD::FIRST_NUMBER:
13310 break;
13311 NODE_NAME_CASE(RET_FLAG)
13312 NODE_NAME_CASE(URET_FLAG)
13313 NODE_NAME_CASE(SRET_FLAG)
13314 NODE_NAME_CASE(MRET_FLAG)
13315 NODE_NAME_CASE(CALL)
13316 NODE_NAME_CASE(SELECT_CC)
13317 NODE_NAME_CASE(BR_CC)
13318 NODE_NAME_CASE(BuildPairF64)
13319 NODE_NAME_CASE(SplitF64)
13320 NODE_NAME_CASE(TAIL)
13321 NODE_NAME_CASE(ADD_LO)
13322 NODE_NAME_CASE(HI)
13323 NODE_NAME_CASE(LLA)
13324 NODE_NAME_CASE(ADD_TPREL)
13325 NODE_NAME_CASE(LA)
13326 NODE_NAME_CASE(LA_TLS_IE)
13327 NODE_NAME_CASE(LA_TLS_GD)
13328 NODE_NAME_CASE(MULHSU)
13329 NODE_NAME_CASE(SLLW)
13330 NODE_NAME_CASE(SRAW)
13331 NODE_NAME_CASE(SRLW)
13332 NODE_NAME_CASE(DIVW)
13333 NODE_NAME_CASE(DIVUW)
13334 NODE_NAME_CASE(REMUW)
13335 NODE_NAME_CASE(ROLW)
13336 NODE_NAME_CASE(RORW)
13337 NODE_NAME_CASE(CLZW)
13338 NODE_NAME_CASE(CTZW)
13339 NODE_NAME_CASE(ABSW)
13340 NODE_NAME_CASE(FMV_H_X)
13341 NODE_NAME_CASE(FMV_X_ANYEXTH)
13342 NODE_NAME_CASE(FMV_X_SIGNEXTH)
13343 NODE_NAME_CASE(FMV_W_X_RV64)
13344 NODE_NAME_CASE(FMV_X_ANYEXTW_RV64)
13345 NODE_NAME_CASE(FCVT_X)
13346 NODE_NAME_CASE(FCVT_XU)
13347 NODE_NAME_CASE(FCVT_W_RV64)
13348 NODE_NAME_CASE(FCVT_WU_RV64)
13349 NODE_NAME_CASE(STRICT_FCVT_W_RV64)
13350 NODE_NAME_CASE(STRICT_FCVT_WU_RV64)
13351 NODE_NAME_CASE(FROUND)
13352 NODE_NAME_CASE(READ_CYCLE_WIDE)
13353 NODE_NAME_CASE(BREV8)
13354 NODE_NAME_CASE(ORC_B)
13355 NODE_NAME_CASE(ZIP)
13356 NODE_NAME_CASE(UNZIP)
13357 NODE_NAME_CASE(VMV_V_X_VL)
13358 NODE_NAME_CASE(VFMV_V_F_VL)
13359 NODE_NAME_CASE(VMV_X_S)
13360 NODE_NAME_CASE(VMV_S_X_VL)
13361 NODE_NAME_CASE(VFMV_S_F_VL)
13362 NODE_NAME_CASE(SPLAT_VECTOR_SPLIT_I64_VL)
13363 NODE_NAME_CASE(READ_VLENB)
13364 NODE_NAME_CASE(TRUNCATE_VECTOR_VL)
13365 NODE_NAME_CASE(VSLIDEUP_VL)
13366 NODE_NAME_CASE(VSLIDE1UP_VL)
13367 NODE_NAME_CASE(VSLIDEDOWN_VL)
13368 NODE_NAME_CASE(VSLIDE1DOWN_VL)
13369 NODE_NAME_CASE(VID_VL)
13370 NODE_NAME_CASE(VFNCVT_ROD_VL)
13371 NODE_NAME_CASE(VECREDUCE_ADD_VL)
13372 NODE_NAME_CASE(VECREDUCE_UMAX_VL)
13373 NODE_NAME_CASE(VECREDUCE_SMAX_VL)
13374 NODE_NAME_CASE(VECREDUCE_UMIN_VL)
13375 NODE_NAME_CASE(VECREDUCE_SMIN_VL)
13376 NODE_NAME_CASE(VECREDUCE_AND_VL)
13377 NODE_NAME_CASE(VECREDUCE_OR_VL)
13378 NODE_NAME_CASE(VECREDUCE_XOR_VL)
13379 NODE_NAME_CASE(VECREDUCE_FADD_VL)
13380 NODE_NAME_CASE(VECREDUCE_SEQ_FADD_VL)
13381 NODE_NAME_CASE(VECREDUCE_FMIN_VL)
13382 NODE_NAME_CASE(VECREDUCE_FMAX_VL)
13383 NODE_NAME_CASE(ADD_VL)
13384 NODE_NAME_CASE(AND_VL)
13385 NODE_NAME_CASE(MUL_VL)
13386 NODE_NAME_CASE(OR_VL)
13387 NODE_NAME_CASE(SDIV_VL)
13388 NODE_NAME_CASE(SHL_VL)
13389 NODE_NAME_CASE(SREM_VL)
13390 NODE_NAME_CASE(SRA_VL)
13391 NODE_NAME_CASE(SRL_VL)
13392 NODE_NAME_CASE(SUB_VL)
13393 NODE_NAME_CASE(UDIV_VL)
13394 NODE_NAME_CASE(UREM_VL)
13395 NODE_NAME_CASE(XOR_VL)
13396 NODE_NAME_CASE(SADDSAT_VL)
13397 NODE_NAME_CASE(UADDSAT_VL)
13398 NODE_NAME_CASE(SSUBSAT_VL)
13399 NODE_NAME_CASE(USUBSAT_VL)
13400 NODE_NAME_CASE(FADD_VL)
13401 NODE_NAME_CASE(FSUB_VL)
13402 NODE_NAME_CASE(FMUL_VL)
13403 NODE_NAME_CASE(FDIV_VL)
13404 NODE_NAME_CASE(FNEG_VL)
13405 NODE_NAME_CASE(FABS_VL)
13406 NODE_NAME_CASE(FSQRT_VL)
13407 NODE_NAME_CASE(VFMADD_VL)
13408 NODE_NAME_CASE(VFNMADD_VL)
13409 NODE_NAME_CASE(VFMSUB_VL)
13410 NODE_NAME_CASE(VFNMSUB_VL)
13411 NODE_NAME_CASE(FCOPYSIGN_VL)
13412 NODE_NAME_CASE(SMIN_VL)
13413 NODE_NAME_CASE(SMAX_VL)
13414 NODE_NAME_CASE(UMIN_VL)
13415 NODE_NAME_CASE(UMAX_VL)
13416 NODE_NAME_CASE(FMINNUM_VL)
13417 NODE_NAME_CASE(FMAXNUM_VL)
13418 NODE_NAME_CASE(MULHS_VL)
13419 NODE_NAME_CASE(MULHU_VL)
13420 NODE_NAME_CASE(VFCVT_RTZ_X_F_VL)
13421 NODE_NAME_CASE(VFCVT_RTZ_XU_F_VL)
13422 NODE_NAME_CASE(VFCVT_RM_X_F_VL)
13423 NODE_NAME_CASE(VFCVT_RM_XU_F_VL)
13424 NODE_NAME_CASE(VFCVT_X_F_VL)
13425 NODE_NAME_CASE(VFCVT_XU_F_VL)
13426 NODE_NAME_CASE(VFROUND_NOEXCEPT_VL)
13427 NODE_NAME_CASE(SINT_TO_FP_VL)
13428 NODE_NAME_CASE(UINT_TO_FP_VL)
13429 NODE_NAME_CASE(VFCVT_RM_F_XU_VL)
13430 NODE_NAME_CASE(VFCVT_RM_F_X_VL)
13431 NODE_NAME_CASE(FP_EXTEND_VL)
13432 NODE_NAME_CASE(FP_ROUND_VL)
13433 NODE_NAME_CASE(VWMUL_VL)
13434 NODE_NAME_CASE(VWMULU_VL)
13435 NODE_NAME_CASE(VWMULSU_VL)
13436 NODE_NAME_CASE(VWADD_VL)
13437 NODE_NAME_CASE(VWADDU_VL)
13438 NODE_NAME_CASE(VWSUB_VL)
13439 NODE_NAME_CASE(VWSUBU_VL)
13440 NODE_NAME_CASE(VWADD_W_VL)
13441 NODE_NAME_CASE(VWADDU_W_VL)
13442 NODE_NAME_CASE(VWSUB_W_VL)
13443 NODE_NAME_CASE(VWSUBU_W_VL)
13444 NODE_NAME_CASE(VNSRL_VL)
13445 NODE_NAME_CASE(SETCC_VL)
13446 NODE_NAME_CASE(VSELECT_VL)
13447 NODE_NAME_CASE(VP_MERGE_VL)
13448 NODE_NAME_CASE(VMAND_VL)
13449 NODE_NAME_CASE(VMOR_VL)
13450 NODE_NAME_CASE(VMXOR_VL)
13451 NODE_NAME_CASE(VMCLR_VL)
13452 NODE_NAME_CASE(VMSET_VL)
13453 NODE_NAME_CASE(VRGATHER_VX_VL)
13454 NODE_NAME_CASE(VRGATHER_VV_VL)
13455 NODE_NAME_CASE(VRGATHEREI16_VV_VL)
13456 NODE_NAME_CASE(VSEXT_VL)
13457 NODE_NAME_CASE(VZEXT_VL)
13458 NODE_NAME_CASE(VCPOP_VL)
13459 NODE_NAME_CASE(VFIRST_VL)
13460 NODE_NAME_CASE(READ_CSR)
13461 NODE_NAME_CASE(WRITE_CSR)
13462 NODE_NAME_CASE(SWAP_CSR)
13463 }
13464 // clang-format on
13465 return nullptr;
13466 #undef NODE_NAME_CASE
13467 }
13468
13469 /// getConstraintType - Given a constraint letter, return the type of
13470 /// constraint it is for this target.
RISCVTargetLowering::ConstraintType
RISCVTargetLowering::getConstraintType(StringRef Constraint) const {
13473 if (Constraint.size() == 1) {
13474 switch (Constraint[0]) {
13475 default:
13476 break;
13477 case 'f':
13478 return C_RegisterClass;
13479 case 'I':
13480 case 'J':
13481 case 'K':
13482 return C_Immediate;
13483 case 'A':
13484 return C_Memory;
13485 case 'S': // A symbolic address
13486 return C_Other;
13487 }
13488 } else {
13489 if (Constraint == "vr" || Constraint == "vm")
13490 return C_RegisterClass;
13491 }
13492 return TargetLowering::getConstraintType(Constraint);
13493 }
13494
std::pair<unsigned, const TargetRegisterClass *>
RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
                                                  StringRef Constraint,
                                                  MVT VT) const {
13499 // First, see if this is a constraint that directly corresponds to a
13500 // RISCV register class.
13501 if (Constraint.size() == 1) {
13502 switch (Constraint[0]) {
13503 case 'r':
13504 // TODO: Support fixed vectors up to XLen for P extension?
13505 if (VT.isVector())
13506 break;
13507 return std::make_pair(0U, &RISCV::GPRRegClass);
13508 case 'f':
13509 if (Subtarget.hasStdExtZfhOrZfhmin() && VT == MVT::f16)
13510 return std::make_pair(0U, &RISCV::FPR16RegClass);
13511 if (Subtarget.hasStdExtF() && VT == MVT::f32)
13512 return std::make_pair(0U, &RISCV::FPR32RegClass);
13513 if (Subtarget.hasStdExtD() && VT == MVT::f64)
13514 return std::make_pair(0U, &RISCV::FPR64RegClass);
13515 break;
13516 default:
13517 break;
13518 }
13519 } else if (Constraint == "vr") {
13520 for (const auto *RC : {&RISCV::VRRegClass, &RISCV::VRM2RegClass,
13521 &RISCV::VRM4RegClass, &RISCV::VRM8RegClass}) {
13522 if (TRI->isTypeLegalForClass(*RC, VT.SimpleTy))
13523 return std::make_pair(0U, RC);
13524 }
13525 } else if (Constraint == "vm") {
13526 if (TRI->isTypeLegalForClass(RISCV::VMV0RegClass, VT.SimpleTy))
13527 return std::make_pair(0U, &RISCV::VMV0RegClass);
13528 }
13529
13530 // Clang will correctly decode the usage of register name aliases into their
13531 // official names. However, other frontends like `rustc` do not. This allows
13532 // users of these frontends to use the ABI names for registers in LLVM-style
13533 // register constraints.
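  // e.g. (illustrative) an asm constraint written as "{a0}" or "{fa0}" is
  // accepted here and mapped to X10 / F10_F respectively.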
13534 unsigned XRegFromAlias = StringSwitch<unsigned>(Constraint.lower())
13535 .Case("{zero}", RISCV::X0)
13536 .Case("{ra}", RISCV::X1)
13537 .Case("{sp}", RISCV::X2)
13538 .Case("{gp}", RISCV::X3)
13539 .Case("{tp}", RISCV::X4)
13540 .Case("{t0}", RISCV::X5)
13541 .Case("{t1}", RISCV::X6)
13542 .Case("{t2}", RISCV::X7)
13543 .Cases("{s0}", "{fp}", RISCV::X8)
13544 .Case("{s1}", RISCV::X9)
13545 .Case("{a0}", RISCV::X10)
13546 .Case("{a1}", RISCV::X11)
13547 .Case("{a2}", RISCV::X12)
13548 .Case("{a3}", RISCV::X13)
13549 .Case("{a4}", RISCV::X14)
13550 .Case("{a5}", RISCV::X15)
13551 .Case("{a6}", RISCV::X16)
13552 .Case("{a7}", RISCV::X17)
13553 .Case("{s2}", RISCV::X18)
13554 .Case("{s3}", RISCV::X19)
13555 .Case("{s4}", RISCV::X20)
13556 .Case("{s5}", RISCV::X21)
13557 .Case("{s6}", RISCV::X22)
13558 .Case("{s7}", RISCV::X23)
13559 .Case("{s8}", RISCV::X24)
13560 .Case("{s9}", RISCV::X25)
13561 .Case("{s10}", RISCV::X26)
13562 .Case("{s11}", RISCV::X27)
13563 .Case("{t3}", RISCV::X28)
13564 .Case("{t4}", RISCV::X29)
13565 .Case("{t5}", RISCV::X30)
13566 .Case("{t6}", RISCV::X31)
13567 .Default(RISCV::NoRegister);
13568 if (XRegFromAlias != RISCV::NoRegister)
13569 return std::make_pair(XRegFromAlias, &RISCV::GPRRegClass);
13570
  // Since TargetLowering::getRegForInlineAsmConstraint uses the name of the
  // TableGen record rather than the AsmName to choose registers for InlineAsm
  // constraints, and because we want to match those names to the widest
  // floating point register type available, manually select floating point
  // registers here.
13575 //
13576 // The second case is the ABI name of the register, so that frontends can also
13577 // use the ABI names in register constraint lists.
13578 if (Subtarget.hasStdExtF()) {
13579 unsigned FReg = StringSwitch<unsigned>(Constraint.lower())
13580 .Cases("{f0}", "{ft0}", RISCV::F0_F)
13581 .Cases("{f1}", "{ft1}", RISCV::F1_F)
13582 .Cases("{f2}", "{ft2}", RISCV::F2_F)
13583 .Cases("{f3}", "{ft3}", RISCV::F3_F)
13584 .Cases("{f4}", "{ft4}", RISCV::F4_F)
13585 .Cases("{f5}", "{ft5}", RISCV::F5_F)
13586 .Cases("{f6}", "{ft6}", RISCV::F6_F)
13587 .Cases("{f7}", "{ft7}", RISCV::F7_F)
13588 .Cases("{f8}", "{fs0}", RISCV::F8_F)
13589 .Cases("{f9}", "{fs1}", RISCV::F9_F)
13590 .Cases("{f10}", "{fa0}", RISCV::F10_F)
13591 .Cases("{f11}", "{fa1}", RISCV::F11_F)
13592 .Cases("{f12}", "{fa2}", RISCV::F12_F)
13593 .Cases("{f13}", "{fa3}", RISCV::F13_F)
13594 .Cases("{f14}", "{fa4}", RISCV::F14_F)
13595 .Cases("{f15}", "{fa5}", RISCV::F15_F)
13596 .Cases("{f16}", "{fa6}", RISCV::F16_F)
13597 .Cases("{f17}", "{fa7}", RISCV::F17_F)
13598 .Cases("{f18}", "{fs2}", RISCV::F18_F)
13599 .Cases("{f19}", "{fs3}", RISCV::F19_F)
13600 .Cases("{f20}", "{fs4}", RISCV::F20_F)
13601 .Cases("{f21}", "{fs5}", RISCV::F21_F)
13602 .Cases("{f22}", "{fs6}", RISCV::F22_F)
13603 .Cases("{f23}", "{fs7}", RISCV::F23_F)
13604 .Cases("{f24}", "{fs8}", RISCV::F24_F)
13605 .Cases("{f25}", "{fs9}", RISCV::F25_F)
13606 .Cases("{f26}", "{fs10}", RISCV::F26_F)
13607 .Cases("{f27}", "{fs11}", RISCV::F27_F)
13608 .Cases("{f28}", "{ft8}", RISCV::F28_F)
13609 .Cases("{f29}", "{ft9}", RISCV::F29_F)
13610 .Cases("{f30}", "{ft10}", RISCV::F30_F)
13611 .Cases("{f31}", "{ft11}", RISCV::F31_F)
13612 .Default(RISCV::NoRegister);
13613 if (FReg != RISCV::NoRegister) {
13614 assert(RISCV::F0_F <= FReg && FReg <= RISCV::F31_F && "Unknown fp-reg");
13615 if (Subtarget.hasStdExtD() && (VT == MVT::f64 || VT == MVT::Other)) {
13616 unsigned RegNo = FReg - RISCV::F0_F;
13617 unsigned DReg = RISCV::F0_D + RegNo;
13618 return std::make_pair(DReg, &RISCV::FPR64RegClass);
13619 }
13620 if (VT == MVT::f32 || VT == MVT::Other)
13621 return std::make_pair(FReg, &RISCV::FPR32RegClass);
13622 if (Subtarget.hasStdExtZfhOrZfhmin() && VT == MVT::f16) {
13623 unsigned RegNo = FReg - RISCV::F0_F;
13624 unsigned HReg = RISCV::F0_H + RegNo;
13625 return std::make_pair(HReg, &RISCV::FPR16RegClass);
13626 }
13627 }
13628 }
13629
13630 if (Subtarget.hasVInstructions()) {
13631 Register VReg = StringSwitch<Register>(Constraint.lower())
13632 .Case("{v0}", RISCV::V0)
13633 .Case("{v1}", RISCV::V1)
13634 .Case("{v2}", RISCV::V2)
13635 .Case("{v3}", RISCV::V3)
13636 .Case("{v4}", RISCV::V4)
13637 .Case("{v5}", RISCV::V5)
13638 .Case("{v6}", RISCV::V6)
13639 .Case("{v7}", RISCV::V7)
13640 .Case("{v8}", RISCV::V8)
13641 .Case("{v9}", RISCV::V9)
13642 .Case("{v10}", RISCV::V10)
13643 .Case("{v11}", RISCV::V11)
13644 .Case("{v12}", RISCV::V12)
13645 .Case("{v13}", RISCV::V13)
13646 .Case("{v14}", RISCV::V14)
13647 .Case("{v15}", RISCV::V15)
13648 .Case("{v16}", RISCV::V16)
13649 .Case("{v17}", RISCV::V17)
13650 .Case("{v18}", RISCV::V18)
13651 .Case("{v19}", RISCV::V19)
13652 .Case("{v20}", RISCV::V20)
13653 .Case("{v21}", RISCV::V21)
13654 .Case("{v22}", RISCV::V22)
13655 .Case("{v23}", RISCV::V23)
13656 .Case("{v24}", RISCV::V24)
13657 .Case("{v25}", RISCV::V25)
13658 .Case("{v26}", RISCV::V26)
13659 .Case("{v27}", RISCV::V27)
13660 .Case("{v28}", RISCV::V28)
13661 .Case("{v29}", RISCV::V29)
13662 .Case("{v30}", RISCV::V30)
13663 .Case("{v31}", RISCV::V31)
13664 .Default(RISCV::NoRegister);
13665 if (VReg != RISCV::NoRegister) {
13666 if (TRI->isTypeLegalForClass(RISCV::VMRegClass, VT.SimpleTy))
13667 return std::make_pair(VReg, &RISCV::VMRegClass);
13668 if (TRI->isTypeLegalForClass(RISCV::VRRegClass, VT.SimpleTy))
13669 return std::make_pair(VReg, &RISCV::VRRegClass);
13670 for (const auto *RC :
13671 {&RISCV::VRM2RegClass, &RISCV::VRM4RegClass, &RISCV::VRM8RegClass}) {
13672 if (TRI->isTypeLegalForClass(*RC, VT.SimpleTy)) {
13673 VReg = TRI->getMatchingSuperReg(VReg, RISCV::sub_vrm1_0, RC);
13674 return std::make_pair(VReg, RC);
13675 }
13676 }
13677 }
13678 }
13679
13680 std::pair<Register, const TargetRegisterClass *> Res =
13681 TargetLowering::getRegForInlineAsmConstraint(TRI, Constraint, VT);
13682
13683 // If we picked one of the Zfinx register classes, remap it to the GPR class.
13684 // FIXME: When Zfinx is supported in CodeGen this will need to take the
13685 // Subtarget into account.
13686 if (Res.second == &RISCV::GPRF16RegClass ||
13687 Res.second == &RISCV::GPRF32RegClass ||
13688 Res.second == &RISCV::GPRF64RegClass)
13689 return std::make_pair(Res.first, &RISCV::GPRRegClass);
13690
13691 return Res;
13692 }
13693
unsigned
RISCVTargetLowering::getInlineAsmMemConstraint(StringRef ConstraintCode) const {
13696 // Currently only support length 1 constraints.
13697 if (ConstraintCode.size() == 1) {
13698 switch (ConstraintCode[0]) {
13699 case 'A':
13700 return InlineAsm::Constraint_A;
13701 default:
13702 break;
13703 }
13704 }
13705
13706 return TargetLowering::getInlineAsmMemConstraint(ConstraintCode);
13707 }
13708
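// Illustrative uses of the immediate constraints handled below (assumed C):
//   asm volatile("addi %0, %1, %2" : "=r"(d) : "r"(s), "I"(42));  // simm12
//   asm volatile("slli %0, %1, %2" : "=r"(d) : "r"(s), "K"(5));   // uimm5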
void RISCVTargetLowering::LowerAsmOperandForConstraint(
    SDValue Op, std::string &Constraint, std::vector<SDValue> &Ops,
    SelectionDAG &DAG) const {
13712 // Currently only support length 1 constraints.
13713 if (Constraint.length() == 1) {
13714 switch (Constraint[0]) {
13715 case 'I':
13716 // Validate & create a 12-bit signed immediate operand.
13717 if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
13718 uint64_t CVal = C->getSExtValue();
13719 if (isInt<12>(CVal))
13720 Ops.push_back(
13721 DAG.getTargetConstant(CVal, SDLoc(Op), Subtarget.getXLenVT()));
13722 }
13723 return;
13724 case 'J':
13725 // Validate & create an integer zero operand.
13726 if (auto *C = dyn_cast<ConstantSDNode>(Op))
13727 if (C->getZExtValue() == 0)
13728 Ops.push_back(
13729 DAG.getTargetConstant(0, SDLoc(Op), Subtarget.getXLenVT()));
13730 return;
13731 case 'K':
13732 // Validate & create a 5-bit unsigned immediate operand.
13733 if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
13734 uint64_t CVal = C->getZExtValue();
13735 if (isUInt<5>(CVal))
13736 Ops.push_back(
13737 DAG.getTargetConstant(CVal, SDLoc(Op), Subtarget.getXLenVT()));
13738 }
13739 return;
13740 case 'S':
13741 if (const auto *GA = dyn_cast<GlobalAddressSDNode>(Op)) {
13742 Ops.push_back(DAG.getTargetGlobalAddress(GA->getGlobal(), SDLoc(Op),
13743 GA->getValueType(0)));
13744 } else if (const auto *BA = dyn_cast<BlockAddressSDNode>(Op)) {
13745 Ops.push_back(DAG.getTargetBlockAddress(BA->getBlockAddress(),
13746 BA->getValueType(0)));
13747 }
13748 return;
13749 default:
13750 break;
13751 }
13752 }
13753 TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
13754 }
13755
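// Sketch of the resulting fence placement (assumed, following the RISC-V
// atomics mapping):
//   load seq_cst   -> fence rw,rw ; lw ; fence r,rw
//   load acquire   -> lw ; fence r,rw         (trailing fence, next function)
//   store release+ -> fence rw,w  ; sw        (leading fence below)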
Instruction *RISCVTargetLowering::emitLeadingFence(IRBuilderBase &Builder,
                                                   Instruction *Inst,
                                                   AtomicOrdering Ord) const {
13759 if (isa<LoadInst>(Inst) && Ord == AtomicOrdering::SequentiallyConsistent)
13760 return Builder.CreateFence(Ord);
13761 if (isa<StoreInst>(Inst) && isReleaseOrStronger(Ord))
13762 return Builder.CreateFence(AtomicOrdering::Release);
13763 return nullptr;
13764 }
13765
emitTrailingFence(IRBuilderBase & Builder,Instruction * Inst,AtomicOrdering Ord) const13766 Instruction *RISCVTargetLowering::emitTrailingFence(IRBuilderBase &Builder,
13767 Instruction *Inst,
13768 AtomicOrdering Ord) const {
13769 if (isa<LoadInst>(Inst) && isAcquireOrStronger(Ord))
13770 return Builder.CreateFence(AtomicOrdering::Acquire);
13771 return nullptr;
13772 }
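
// Roughly, these two hooks implement the recommended RVWMO mapping: a seq_cst
// load is emitted as "fence rw,rw; load; fence r,rw", and a release (or
// stronger) store is emitted as "fence rw,w; store".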

TargetLowering::AtomicExpansionKind
RISCVTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
  // atomicrmw {fadd,fsub} must be expanded to use compare-exchange, as
  // floating-point operations can't be used in an LR/SC sequence without
  // breaking the forward-progress guarantee. uinc_wrap and udec_wrap are
  // expanded via compare-exchange as well.
  if (AI->isFloatingPointOperation() ||
      AI->getOperation() == AtomicRMWInst::UIncWrap ||
      AI->getOperation() == AtomicRMWInst::UDecWrap)
    return AtomicExpansionKind::CmpXChg;

  // Don't expand forced atomics; we want __sync libcalls for them instead.
  if (Subtarget.hasForcedAtomics())
    return AtomicExpansionKind::None;

  unsigned Size = AI->getType()->getPrimitiveSizeInBits();
  if (Size == 8 || Size == 16)
    return AtomicExpansionKind::MaskedIntrinsic;
  return AtomicExpansionKind::None;
}
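
// For example, an i8 `atomicrmw add` is rewritten by AtomicExpandPass into a
// call to the masked intrinsic chosen below, operating on the aligned
// containing word; that intrinsic is then lowered to an LR/SC loop.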

static Intrinsic::ID
getIntrinsicForMaskedAtomicRMWBinOp(unsigned XLen, AtomicRMWInst::BinOp BinOp) {
  if (XLen == 32) {
    switch (BinOp) {
    default:
      llvm_unreachable("Unexpected AtomicRMW BinOp");
    case AtomicRMWInst::Xchg:
      return Intrinsic::riscv_masked_atomicrmw_xchg_i32;
    case AtomicRMWInst::Add:
      return Intrinsic::riscv_masked_atomicrmw_add_i32;
    case AtomicRMWInst::Sub:
      return Intrinsic::riscv_masked_atomicrmw_sub_i32;
    case AtomicRMWInst::Nand:
      return Intrinsic::riscv_masked_atomicrmw_nand_i32;
    case AtomicRMWInst::Max:
      return Intrinsic::riscv_masked_atomicrmw_max_i32;
    case AtomicRMWInst::Min:
      return Intrinsic::riscv_masked_atomicrmw_min_i32;
    case AtomicRMWInst::UMax:
      return Intrinsic::riscv_masked_atomicrmw_umax_i32;
    case AtomicRMWInst::UMin:
      return Intrinsic::riscv_masked_atomicrmw_umin_i32;
    }
  }

  if (XLen == 64) {
    switch (BinOp) {
    default:
      llvm_unreachable("Unexpected AtomicRMW BinOp");
    case AtomicRMWInst::Xchg:
      return Intrinsic::riscv_masked_atomicrmw_xchg_i64;
    case AtomicRMWInst::Add:
      return Intrinsic::riscv_masked_atomicrmw_add_i64;
    case AtomicRMWInst::Sub:
      return Intrinsic::riscv_masked_atomicrmw_sub_i64;
    case AtomicRMWInst::Nand:
      return Intrinsic::riscv_masked_atomicrmw_nand_i64;
    case AtomicRMWInst::Max:
      return Intrinsic::riscv_masked_atomicrmw_max_i64;
    case AtomicRMWInst::Min:
      return Intrinsic::riscv_masked_atomicrmw_min_i64;
    case AtomicRMWInst::UMax:
      return Intrinsic::riscv_masked_atomicrmw_umax_i64;
    case AtomicRMWInst::UMin:
      return Intrinsic::riscv_masked_atomicrmw_umin_i64;
    }
  }

  llvm_unreachable("Unexpected XLen\n");
}

Value *RISCVTargetLowering::emitMaskedAtomicRMWIntrinsic(
    IRBuilderBase &Builder, AtomicRMWInst *AI, Value *AlignedAddr, Value *Incr,
    Value *Mask, Value *ShiftAmt, AtomicOrdering Ord) const {
  unsigned XLen = Subtarget.getXLen();
  Value *Ordering =
      Builder.getIntN(XLen, static_cast<uint64_t>(AI->getOrdering()));
  Type *Tys[] = {AlignedAddr->getType()};
  Function *LrwOpScwLoop = Intrinsic::getDeclaration(
      AI->getModule(),
      getIntrinsicForMaskedAtomicRMWBinOp(XLen, AI->getOperation()), Tys);

  if (XLen == 64) {
    Incr = Builder.CreateSExt(Incr, Builder.getInt64Ty());
    Mask = Builder.CreateSExt(Mask, Builder.getInt64Ty());
    ShiftAmt = Builder.CreateSExt(ShiftAmt, Builder.getInt64Ty());
  }

  Value *Result;

  // Must pass the shift amount needed to sign extend the loaded value prior
  // to performing a signed comparison for min/max. ShiftAmt is the number of
  // bits to shift the value into position. Pass XLen-ShiftAmt-ValWidth, which
  // is the number of bits to left+right shift the value in order to
  // sign-extend.
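  // For instance (hypothetical values): an i8 atomicrmw min on RV32 whose byte
  // sits at bit offset 8 of the word has ShiftAmt = 8 and ValWidth = 8, so
  // SextShamt = 32 - 8 - 8 = 16; the LR/SC loop shifts the loaded word left by
  // 16 and arithmetic-shifts it right by 16 to sign-extend before comparing.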
  if (AI->getOperation() == AtomicRMWInst::Min ||
      AI->getOperation() == AtomicRMWInst::Max) {
    const DataLayout &DL = AI->getModule()->getDataLayout();
    unsigned ValWidth =
        DL.getTypeStoreSizeInBits(AI->getValOperand()->getType());
    Value *SextShamt =
        Builder.CreateSub(Builder.getIntN(XLen, XLen - ValWidth), ShiftAmt);
    Result = Builder.CreateCall(LrwOpScwLoop,
                                {AlignedAddr, Incr, Mask, SextShamt, Ordering});
  } else {
    Result =
        Builder.CreateCall(LrwOpScwLoop, {AlignedAddr, Incr, Mask, Ordering});
  }

  if (XLen == 64)
    Result = Builder.CreateTrunc(Result, Builder.getInt32Ty());
  return Result;
}

TargetLowering::AtomicExpansionKind
RISCVTargetLowering::shouldExpandAtomicCmpXchgInIR(
    AtomicCmpXchgInst *CI) const {
  // Don't expand forced atomics; we want __sync libcalls for them instead.
  if (Subtarget.hasForcedAtomics())
    return AtomicExpansionKind::None;

  unsigned Size = CI->getCompareOperand()->getType()->getPrimitiveSizeInBits();
  if (Size == 8 || Size == 16)
    return AtomicExpansionKind::MaskedIntrinsic;
  return AtomicExpansionKind::None;
}

Value *RISCVTargetLowering::emitMaskedAtomicCmpXchgIntrinsic(
    IRBuilderBase &Builder, AtomicCmpXchgInst *CI, Value *AlignedAddr,
    Value *CmpVal, Value *NewVal, Value *Mask, AtomicOrdering Ord) const {
  unsigned XLen = Subtarget.getXLen();
  Value *Ordering = Builder.getIntN(XLen, static_cast<uint64_t>(Ord));
  Intrinsic::ID CmpXchgIntrID = Intrinsic::riscv_masked_cmpxchg_i32;
  if (XLen == 64) {
    CmpVal = Builder.CreateSExt(CmpVal, Builder.getInt64Ty());
    NewVal = Builder.CreateSExt(NewVal, Builder.getInt64Ty());
    Mask = Builder.CreateSExt(Mask, Builder.getInt64Ty());
    CmpXchgIntrID = Intrinsic::riscv_masked_cmpxchg_i64;
  }
  Type *Tys[] = {AlignedAddr->getType()};
  Function *MaskedCmpXchg =
      Intrinsic::getDeclaration(CI->getModule(), CmpXchgIntrID, Tys);
  Value *Result = Builder.CreateCall(
      MaskedCmpXchg, {AlignedAddr, CmpVal, NewVal, Mask, Ordering});
  if (XLen == 64)
    Result = Builder.CreateTrunc(Result, Builder.getInt32Ty());
  return Result;
}
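
// For example, an i16 cmpxchg on RV64 ends up here as a call to
// @llvm.riscv.masked.cmpxchg.i64 with CmpVal, NewVal and Mask sign-extended to
// i64, and the XLen-wide result truncated back to i32 for the caller.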

bool RISCVTargetLowering::shouldRemoveExtendFromGSIndex(EVT IndexVT,
                                                        EVT DataVT) const {
  return false;
}

bool RISCVTargetLowering::shouldConvertFpToSat(unsigned Op, EVT FPVT,
                                               EVT VT) const {
  if (!isOperationLegalOrCustom(Op, VT) || !FPVT.isSimple())
    return false;

  switch (FPVT.getSimpleVT().SimpleTy) {
  case MVT::f16:
    return Subtarget.hasStdExtZfhOrZfhmin();
  case MVT::f32:
    return Subtarget.hasStdExtF();
  case MVT::f64:
    return Subtarget.hasStdExtD();
  default:
    return false;
  }
}
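
// Note: forming the saturating FP_TO_*INT_SAT nodes is attractive here because
// the scalar fcvt instructions already clamp out-of-range inputs; the
// remaining difference (NaN must produce zero) is handled during custom
// lowering of the saturating node.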

unsigned RISCVTargetLowering::getJumpTableEncoding() const {
  // If we are using the small code model, we can reduce the size of each jump
  // table entry to 4 bytes.
  if (Subtarget.is64Bit() && !isPositionIndependent() &&
      getTargetMachine().getCodeModel() == CodeModel::Small) {
    return MachineJumpTableInfo::EK_Custom32;
  }
  return TargetLowering::getJumpTableEncoding();
}

const MCExpr *RISCVTargetLowering::LowerCustomJumpTableEntry(
    const MachineJumpTableInfo *MJTI, const MachineBasicBlock *MBB,
    unsigned uid, MCContext &Ctx) const {
  assert(Subtarget.is64Bit() && !isPositionIndependent() &&
         getTargetMachine().getCodeModel() == CodeModel::Small);
  return MCSymbolRefExpr::create(MBB->getSymbol(), Ctx);
}
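
// With EK_Custom32 each jump table entry is a 32-bit absolute reference to the
// basic block's symbol (roughly a ".word .LBB<n>_<m>" directive), instead of
// the 64-bit pointer that would otherwise be emitted on RV64.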

bool RISCVTargetLowering::isVScaleKnownToBeAPowerOfTwo() const {
  // We define vscale to be VLEN/RVVBitsPerBlock. VLEN is always a power
  // of two >= 64, and RVVBitsPerBlock is 64. Thus, vscale must be
  // a power of two as well.
  // FIXME: This doesn't work for zve32, but that's already broken
  // elsewhere for the same reason.
  assert(Subtarget.getRealMinVLen() >= 64 && "zve32* unsupported");
  static_assert(RISCV::RVVBitsPerBlock == 64,
                "RVVBitsPerBlock changed, audit needed");
  return true;
}

bool RISCVTargetLowering::isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
                                                     EVT VT) const {
  EVT SVT = VT.getScalarType();

  if (!SVT.isSimple())
    return false;

  switch (SVT.getSimpleVT().SimpleTy) {
  case MVT::f16:
    return VT.isVector() ? Subtarget.hasVInstructionsF16()
                         : Subtarget.hasStdExtZfh();
  case MVT::f32:
    return Subtarget.hasStdExtF();
  case MVT::f64:
    return Subtarget.hasStdExtD();
  default:
    break;
  }

  return false;
}

Register RISCVTargetLowering::getExceptionPointerRegister(
    const Constant *PersonalityFn) const {
  return RISCV::X10;
}

Register RISCVTargetLowering::getExceptionSelectorRegister(
    const Constant *PersonalityFn) const {
  return RISCV::X11;
}

bool RISCVTargetLowering::shouldExtendTypeInLibCall(EVT Type) const {
  // Return false to suppress unnecessary extensions when a libcall argument or
  // return value is of f32 type under the LP64 ABI.
  RISCVABI::ABI ABI = Subtarget.getTargetABI();
  if (ABI == RISCVABI::ABI_LP64 && (Type == MVT::f32))
    return false;

  return true;
}

bool RISCVTargetLowering::shouldSignExtendTypeInLibCall(EVT Type,
                                                        bool IsSigned) const {
  if (Subtarget.is64Bit() && Type == MVT::i32)
    return true;

  return IsSigned;
}
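
// For example, an `unsigned int` libcall argument on RV64 is still
// sign-extended to 64 bits here, matching the LP64 convention that 32-bit
// values are kept in registers in sign-extended form.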

bool RISCVTargetLowering::decomposeMulByConstant(LLVMContext &Context, EVT VT,
                                                 SDValue C) const {
  // Check integral scalar types.
  const bool HasExtMOrZmmul =
      Subtarget.hasStdExtM() || Subtarget.hasStdExtZmmul();
  if (VT.isScalarInteger()) {
    // Omit the optimization if the subtarget has the M extension and the data
    // size exceeds XLen.
    if (HasExtMOrZmmul && VT.getSizeInBits() > Subtarget.getXLen())
      return false;
    if (auto *ConstNode = dyn_cast<ConstantSDNode>(C.getNode())) {
      // Break the MUL to a SLLI and an ADD/SUB.
      const APInt &Imm = ConstNode->getAPIntValue();
      if ((Imm + 1).isPowerOf2() || (Imm - 1).isPowerOf2() ||
          (1 - Imm).isPowerOf2() || (-1 - Imm).isPowerOf2())
        return true;
      // Optimize the MUL to (SH*ADD x, (SLLI x, bits)) if Imm is not simm12.
      if (Subtarget.hasStdExtZba() && !Imm.isSignedIntN(12) &&
          ((Imm - 2).isPowerOf2() || (Imm - 4).isPowerOf2() ||
           (Imm - 8).isPowerOf2()))
        return true;
      // Omit the following optimization if the subtarget has the M extension
      // and the data size >= XLen.
      if (HasExtMOrZmmul && VT.getSizeInBits() >= Subtarget.getXLen())
        return false;
      // Break the MUL to two SLLI instructions and an ADD/SUB, if Imm needs
      // a pair of LUI/ADDI.
      if (!Imm.isSignedIntN(12) && Imm.countTrailingZeros() < 12) {
        APInt ImmS = Imm.ashr(Imm.countTrailingZeros());
        if ((ImmS + 1).isPowerOf2() || (ImmS - 1).isPowerOf2() ||
            (1 - ImmS).isPowerOf2())
          return true;
      }
    }
  }

  return false;
}
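
// For illustration: x * 9 can become (x << 3) + x, x * 15 can become
// (x << 4) - x, and with Zba a constant such as 4098 (= 2 + 4096) can become
// sh1add(x, x << 12). Whether a decomposition actually fires is still up to
// the generic DAG combine that queries this hook.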

bool RISCVTargetLowering::isMulAddWithConstProfitable(SDValue AddNode,
                                                      SDValue ConstNode) const {
  // Let the DAGCombiner decide for vectors.
  EVT VT = AddNode.getValueType();
  if (VT.isVector())
    return true;

  // Let the DAGCombiner decide for larger types.
  if (VT.getScalarSizeInBits() > Subtarget.getXLen())
    return true;

  // It is not profitable if C1 is simm12 while C1*C2 is not.
  ConstantSDNode *C1Node = cast<ConstantSDNode>(AddNode.getOperand(1));
  ConstantSDNode *C2Node = cast<ConstantSDNode>(ConstNode);
  const APInt &C1 = C1Node->getAPIntValue();
  const APInt &C2 = C2Node->getAPIntValue();
  if (C1.isSignedIntN(12) && !(C1 * C2).isSignedIntN(12))
    return false;

  // Default to true and let the DAGCombiner decide.
  return true;
}
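
// For example, for (x + 3) * 2000 it is better to keep the `addi x, x, 3`
// and multiply afterwards: folding it to x * 2000 + 6000 would require
// materializing 6000 with an extra LUI/ADDI pair, since 6000 no longer fits
// in a signed 12-bit immediate.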

bool RISCVTargetLowering::allowsMisalignedMemoryAccesses(
    EVT VT, unsigned AddrSpace, Align Alignment, MachineMemOperand::Flags Flags,
    unsigned *Fast) const {
  if (!VT.isVector()) {
    if (Fast)
      *Fast = 0;
    return Subtarget.enableUnalignedScalarMem();
  }

  // All vector implementations must support element alignment.
  EVT ElemVT = VT.getVectorElementType();
  if (Alignment >= ElemVT.getStoreSize()) {
    if (Fast)
      *Fast = 1;
    return true;
  }

  return false;
}

bool RISCVTargetLowering::splitValueIntoRegisterParts(
    SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
    unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
  bool IsABIRegCopy = CC.has_value();
  EVT ValueVT = Val.getValueType();
  if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
    // Cast the f16 to i16, extend to i32, pad the upper 16 bits with ones so
    // the result is a NaN-boxed f32, and cast to f32.
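    // For instance, the f16 value 1.0 (0x3C00) is passed as the f32 bit
    // pattern 0xFFFF3C00, which is how a NaN-boxed half-precision value is
    // represented in an f register.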
    Val = DAG.getNode(ISD::BITCAST, DL, MVT::i16, Val);
    Val = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Val);
    Val = DAG.getNode(ISD::OR, DL, MVT::i32, Val,
                      DAG.getConstant(0xFFFF0000, DL, MVT::i32));
    Val = DAG.getNode(ISD::BITCAST, DL, MVT::f32, Val);
    Parts[0] = Val;
    return true;
  }

  if (ValueVT.isScalableVector() && PartVT.isScalableVector()) {
    LLVMContext &Context = *DAG.getContext();
    EVT ValueEltVT = ValueVT.getVectorElementType();
    EVT PartEltVT = PartVT.getVectorElementType();
    unsigned ValueVTBitSize = ValueVT.getSizeInBits().getKnownMinValue();
    unsigned PartVTBitSize = PartVT.getSizeInBits().getKnownMinValue();
    if (PartVTBitSize % ValueVTBitSize == 0) {
      assert(PartVTBitSize >= ValueVTBitSize);
      // If the element types are different, bitcast to the same element type
      // of PartVT first.
      // For example, to copy a <vscale x 1 x i8> value into <vscale x 4 x i16>,
      // we first insert the <vscale x 1 x i8> into a <vscale x 8 x i8> as a
      // subvector, and then bitcast that to <vscale x 4 x i16>.
      if (ValueEltVT != PartEltVT) {
        if (PartVTBitSize > ValueVTBitSize) {
          unsigned Count = PartVTBitSize / ValueEltVT.getFixedSizeInBits();
          assert(Count != 0 && "The number of elements should not be zero.");
          EVT SameEltTypeVT =
              EVT::getVectorVT(Context, ValueEltVT, Count, /*IsScalable=*/true);
          Val = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, SameEltTypeVT,
                            DAG.getUNDEF(SameEltTypeVT), Val,
                            DAG.getVectorIdxConstant(0, DL));
        }
        Val = DAG.getNode(ISD::BITCAST, DL, PartVT, Val);
      } else {
        Val =
            DAG.getNode(ISD::INSERT_SUBVECTOR, DL, PartVT, DAG.getUNDEF(PartVT),
                        Val, DAG.getVectorIdxConstant(0, DL));
      }
      Parts[0] = Val;
      return true;
    }
  }
  return false;
}

SDValue RISCVTargetLowering::joinRegisterPartsIntoValue(
    SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts, unsigned NumParts,
    MVT PartVT, EVT ValueVT, std::optional<CallingConv::ID> CC) const {
  bool IsABIRegCopy = CC.has_value();
  if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
    SDValue Val = Parts[0];

    // Cast the f32 to i32, truncate to i16, and cast back to f16.
    Val = DAG.getNode(ISD::BITCAST, DL, MVT::i32, Val);
    Val = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, Val);
    Val = DAG.getNode(ISD::BITCAST, DL, MVT::f16, Val);
    return Val;
  }

  if (ValueVT.isScalableVector() && PartVT.isScalableVector()) {
    LLVMContext &Context = *DAG.getContext();
    SDValue Val = Parts[0];
    EVT ValueEltVT = ValueVT.getVectorElementType();
    EVT PartEltVT = PartVT.getVectorElementType();
    unsigned ValueVTBitSize = ValueVT.getSizeInBits().getKnownMinValue();
    unsigned PartVTBitSize = PartVT.getSizeInBits().getKnownMinValue();
    if (PartVTBitSize % ValueVTBitSize == 0) {
      assert(PartVTBitSize >= ValueVTBitSize);
      EVT SameEltTypeVT = ValueVT;
      // If the element types are different, convert to the element type of
      // PartVT first.
      // For example, to copy a <vscale x 1 x i8> value out of a
      // <vscale x 4 x i16>, we first bitcast the <vscale x 4 x i16> to
      // <vscale x 8 x i8>, and then extract the <vscale x 1 x i8> subvector.
      if (ValueEltVT != PartEltVT) {
        unsigned Count = PartVTBitSize / ValueEltVT.getFixedSizeInBits();
        assert(Count != 0 && "The number of elements should not be zero.");
        SameEltTypeVT =
            EVT::getVectorVT(Context, ValueEltVT, Count, /*IsScalable=*/true);
        Val = DAG.getNode(ISD::BITCAST, DL, SameEltTypeVT, Val);
      }
      Val = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ValueVT, Val,
                        DAG.getVectorIdxConstant(0, DL));
      return Val;
    }
  }
  return SDValue();
}

bool RISCVTargetLowering::isIntDivCheap(EVT VT, AttributeList Attr) const {
  // When aggressively optimizing for code size, we prefer to use a div
  // instruction, as it is usually smaller than the alternative sequence.
  // TODO: Add vector division?
  bool OptSize = Attr.hasFnAttr(Attribute::MinSize);
  return OptSize && !VT.isVector();
}

bool RISCVTargetLowering::preferScalarizeSplat(unsigned Opc) const {
  // Scalarizing zero_extend and sign_extend can prevent the splat from
  // matching a widening instruction in some situations, so avoid it for
  // those opcodes.
  if (Opc == ISD::ZERO_EXTEND || Opc == ISD::SIGN_EXTEND)
    return false;
  return true;
}

static Value *useTpOffset(IRBuilderBase &IRB, unsigned Offset) {
  Module *M = IRB.GetInsertBlock()->getParent()->getParent();
  Function *ThreadPointerFunc =
      Intrinsic::getDeclaration(M, Intrinsic::thread_pointer);
  return IRB.CreatePointerCast(
      IRB.CreateConstGEP1_32(IRB.getInt8Ty(),
                             IRB.CreateCall(ThreadPointerFunc), Offset),
      IRB.getInt8PtrTy()->getPointerTo(0));
}

Value *RISCVTargetLowering::getIRStackGuard(IRBuilderBase &IRB) const {
  // Fuchsia provides a fixed TLS slot for the stack cookie.
  // <zircon/tls.h> defines ZX_TLS_STACK_GUARD_OFFSET with this value.
  if (Subtarget.isTargetFuchsia())
    return useTpOffset(IRB, -0x10);

  return TargetLowering::getIRStackGuard(IRB);
}

#define GET_REGISTER_MATCHER
#include "RISCVGenAsmMatcher.inc"

Register
RISCVTargetLowering::getRegisterByName(const char *RegName, LLT VT,
                                       const MachineFunction &MF) const {
  Register Reg = MatchRegisterAltName(RegName);
  if (Reg == RISCV::NoRegister)
    Reg = MatchRegisterName(RegName);
  if (Reg == RISCV::NoRegister)
    report_fatal_error(
        Twine("Invalid register name \"" + StringRef(RegName) + "\"."));
  BitVector ReservedRegs = Subtarget.getRegisterInfo()->getReservedRegs(MF);
  if (!ReservedRegs.test(Reg) && !Subtarget.isRegisterReservedByUser(Reg))
    report_fatal_error(Twine("Trying to obtain non-reserved register \"" +
                             StringRef(RegName) + "\"."));
  return Reg;
}

namespace llvm::RISCVVIntrinsicsTable {

#define GET_RISCVVIntrinsicsTable_IMPL
#include "RISCVGenSearchableTables.inc"

} // namespace llvm::RISCVVIntrinsicsTable