1 //===-- VVPISelLowering.cpp - VE DAG Lowering Implementation --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the lowering and legalization of vector instructions to
10 // VVP_*layer SDNodes.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "VECustomDAG.h"
15 #include "VEISelLowering.h"
16
17 using namespace llvm;
18
19 #define DEBUG_TYPE "ve-lower"
20
splitMaskArithmetic(SDValue Op,SelectionDAG & DAG) const21 SDValue VETargetLowering::splitMaskArithmetic(SDValue Op,
22 SelectionDAG &DAG) const {
23 VECustomDAG CDAG(DAG, Op);
24 SDValue AVL =
25 CDAG.getConstant(Op.getValueType().getVectorNumElements(), MVT::i32);
26 SDValue A = Op->getOperand(0);
27 SDValue B = Op->getOperand(1);
28 SDValue LoA = CDAG.getUnpack(MVT::v256i1, A, PackElem::Lo, AVL);
29 SDValue HiA = CDAG.getUnpack(MVT::v256i1, A, PackElem::Hi, AVL);
30 SDValue LoB = CDAG.getUnpack(MVT::v256i1, B, PackElem::Lo, AVL);
31 SDValue HiB = CDAG.getUnpack(MVT::v256i1, B, PackElem::Hi, AVL);
32 unsigned Opc = Op.getOpcode();
33 auto LoRes = CDAG.getNode(Opc, MVT::v256i1, {LoA, LoB});
34 auto HiRes = CDAG.getNode(Opc, MVT::v256i1, {HiA, HiB});
35 return CDAG.getPack(MVT::v512i1, LoRes, HiRes, AVL);
36 }
37
lowerToVVP(SDValue Op,SelectionDAG & DAG) const38 SDValue VETargetLowering::lowerToVVP(SDValue Op, SelectionDAG &DAG) const {
39 // Can we represent this as a VVP node.
40 const unsigned Opcode = Op->getOpcode();
41 auto VVPOpcodeOpt = getVVPOpcode(Opcode);
42 if (!VVPOpcodeOpt)
43 return SDValue();
44 unsigned VVPOpcode = *VVPOpcodeOpt;
45 const bool FromVP = ISD::isVPOpcode(Opcode);
46
47 // The representative and legalized vector type of this operation.
48 VECustomDAG CDAG(DAG, Op);
49 // Dispatch to complex lowering functions.
50 switch (VVPOpcode) {
51 case VEISD::VVP_LOAD:
52 case VEISD::VVP_STORE:
53 return lowerVVP_LOAD_STORE(Op, CDAG);
54 case VEISD::VVP_GATHER:
55 case VEISD::VVP_SCATTER:
56 return lowerVVP_GATHER_SCATTER(Op, CDAG);
57 }
58
59 EVT OpVecVT = *getIdiomaticVectorType(Op.getNode());
60 EVT LegalVecVT = getTypeToTransformTo(*DAG.getContext(), OpVecVT);
61 auto Packing = getTypePacking(LegalVecVT.getSimpleVT());
62
63 SDValue AVL;
64 SDValue Mask;
65
66 if (FromVP) {
67 // All upstream VP SDNodes always have a mask and avl.
68 auto MaskIdx = ISD::getVPMaskIdx(Opcode);
69 auto AVLIdx = ISD::getVPExplicitVectorLengthIdx(Opcode);
70 if (MaskIdx)
71 Mask = Op->getOperand(*MaskIdx);
72 if (AVLIdx)
73 AVL = Op->getOperand(*AVLIdx);
74 }
75
76 // Materialize default mask and avl.
77 if (!AVL)
78 AVL = CDAG.getConstant(OpVecVT.getVectorNumElements(), MVT::i32);
79 if (!Mask)
80 Mask = CDAG.getConstantMask(Packing, true);
81
82 assert(LegalVecVT.isSimple());
83 if (isVVPUnaryOp(VVPOpcode))
84 return CDAG.getNode(VVPOpcode, LegalVecVT, {Op->getOperand(0), Mask, AVL});
85 if (isVVPBinaryOp(VVPOpcode))
86 return CDAG.getNode(VVPOpcode, LegalVecVT,
87 {Op->getOperand(0), Op->getOperand(1), Mask, AVL});
88 if (isVVPReductionOp(VVPOpcode)) {
89 auto SrcHasStart = hasReductionStartParam(Op->getOpcode());
90 SDValue StartV = SrcHasStart ? Op->getOperand(0) : SDValue();
91 SDValue VectorV = Op->getOperand(SrcHasStart ? 1 : 0);
92 return CDAG.getLegalReductionOpVVP(VVPOpcode, Op.getValueType(), StartV,
93 VectorV, Mask, AVL, Op->getFlags());
94 }
95
96 switch (VVPOpcode) {
97 default:
98 llvm_unreachable("lowerToVVP called for unexpected SDNode.");
99 case VEISD::VVP_FFMA: {
100 // VE has a swizzled operand order in FMA (compared to LLVM IR and
101 // SDNodes).
102 auto X = Op->getOperand(2);
103 auto Y = Op->getOperand(0);
104 auto Z = Op->getOperand(1);
105 return CDAG.getNode(VVPOpcode, LegalVecVT, {X, Y, Z, Mask, AVL});
106 }
107 case VEISD::VVP_SELECT: {
108 auto Mask = Op->getOperand(0);
109 auto OnTrue = Op->getOperand(1);
110 auto OnFalse = Op->getOperand(2);
111 return CDAG.getNode(VVPOpcode, LegalVecVT, {OnTrue, OnFalse, Mask, AVL});
112 }
113 case VEISD::VVP_SETCC: {
114 EVT LegalResVT = getTypeToTransformTo(*DAG.getContext(), Op.getValueType());
115 auto LHS = Op->getOperand(0);
116 auto RHS = Op->getOperand(1);
117 auto Pred = Op->getOperand(2);
118 return CDAG.getNode(VVPOpcode, LegalResVT, {LHS, RHS, Pred, Mask, AVL});
119 }
120 }
121 }
122
lowerVVP_LOAD_STORE(SDValue Op,VECustomDAG & CDAG) const123 SDValue VETargetLowering::lowerVVP_LOAD_STORE(SDValue Op,
124 VECustomDAG &CDAG) const {
125 auto VVPOpc = *getVVPOpcode(Op->getOpcode());
126 const bool IsLoad = (VVPOpc == VEISD::VVP_LOAD);
127
128 // Shares.
129 SDValue BasePtr = getMemoryPtr(Op);
130 SDValue Mask = getNodeMask(Op);
131 SDValue Chain = getNodeChain(Op);
132 SDValue AVL = getNodeAVL(Op);
133 // Store specific.
134 SDValue Data = getStoredValue(Op);
135 // Load specific.
136 SDValue PassThru = getNodePassthru(Op);
137
138 SDValue StrideV = getLoadStoreStride(Op, CDAG);
139
140 auto DataVT = *getIdiomaticVectorType(Op.getNode());
141 auto Packing = getTypePacking(DataVT);
142
143 // TODO: Infer lower AVL from mask.
144 if (!AVL)
145 AVL = CDAG.getConstant(DataVT.getVectorNumElements(), MVT::i32);
146
147 // Default to the all-true mask.
148 if (!Mask)
149 Mask = CDAG.getConstantMask(Packing, true);
150
151 if (IsLoad) {
152 MVT LegalDataVT = getLegalVectorType(
153 Packing, DataVT.getVectorElementType().getSimpleVT());
154
155 auto NewLoadV = CDAG.getNode(VEISD::VVP_LOAD, {LegalDataVT, MVT::Other},
156 {Chain, BasePtr, StrideV, Mask, AVL});
157
158 if (!PassThru || PassThru->isUndef())
159 return NewLoadV;
160
161 // Convert passthru to an explicit select node.
162 SDValue DataV = CDAG.getNode(VEISD::VVP_SELECT, DataVT,
163 {NewLoadV, PassThru, Mask, AVL});
164 SDValue NewLoadChainV = SDValue(NewLoadV.getNode(), 1);
165
166 // Merge them back into one node.
167 return CDAG.getMergeValues({DataV, NewLoadChainV});
168 }
169
170 // VVP_STORE
171 assert(VVPOpc == VEISD::VVP_STORE);
172 return CDAG.getNode(VEISD::VVP_STORE, Op.getNode()->getVTList(),
173 {Chain, Data, BasePtr, StrideV, Mask, AVL});
174 }
175
splitPackedLoadStore(SDValue Op,VECustomDAG & CDAG) const176 SDValue VETargetLowering::splitPackedLoadStore(SDValue Op,
177 VECustomDAG &CDAG) const {
178 auto VVPOC = *getVVPOpcode(Op.getOpcode());
179 assert((VVPOC == VEISD::VVP_LOAD) || (VVPOC == VEISD::VVP_STORE));
180
181 MVT DataVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT();
182 assert(getTypePacking(DataVT) == Packing::Dense &&
183 "Can only split packed load/store");
184 MVT SplitDataVT = splitVectorType(DataVT);
185
186 assert(!getNodePassthru(Op) &&
187 "Should have been folded in lowering to VVP layer");
188
189 // Analyze the operation
190 SDValue PackedMask = getNodeMask(Op);
191 SDValue PackedAVL = getAnnotatedNodeAVL(Op).first;
192 SDValue PackPtr = getMemoryPtr(Op);
193 SDValue PackData = getStoredValue(Op);
194 SDValue PackStride = getLoadStoreStride(Op, CDAG);
195
196 unsigned ChainResIdx = PackData ? 0 : 1;
197
198 SDValue PartOps[2];
199
200 SDValue UpperPartAVL; // we will use this for packing things back together
201 for (PackElem Part : {PackElem::Hi, PackElem::Lo}) {
202 // VP ops already have an explicit mask and AVL. When expanding from non-VP
203 // attach those additional inputs here.
204 auto SplitTM = CDAG.getTargetSplitMask(PackedMask, PackedAVL, Part);
205
206 // Keep track of the (higher) lvl.
207 if (Part == PackElem::Hi)
208 UpperPartAVL = SplitTM.AVL;
209
210 // Attach non-predicating value operands
211 SmallVector<SDValue, 4> OpVec;
212
213 // Chain
214 OpVec.push_back(getNodeChain(Op));
215
216 // Data
217 if (PackData) {
218 SDValue PartData =
219 CDAG.getUnpack(SplitDataVT, PackData, Part, SplitTM.AVL);
220 OpVec.push_back(PartData);
221 }
222
223 // Ptr & Stride
224 // Push (ptr + ElemBytes * <Part>, 2 * ElemBytes)
225 // Stride info
226 // EVT DataVT = LegalizeVectorType(getMemoryDataVT(Op), Op, DAG, Mode);
227 OpVec.push_back(CDAG.getSplitPtrOffset(PackPtr, PackStride, Part));
228 OpVec.push_back(CDAG.getSplitPtrStride(PackStride));
229
230 // Add predicating args and generate part node
231 OpVec.push_back(SplitTM.Mask);
232 OpVec.push_back(SplitTM.AVL);
233
234 if (PackData) {
235 // Store
236 PartOps[(int)Part] = CDAG.getNode(VVPOC, MVT::Other, OpVec);
237 } else {
238 // Load
239 PartOps[(int)Part] =
240 CDAG.getNode(VVPOC, {SplitDataVT, MVT::Other}, OpVec);
241 }
242 }
243
244 // Merge the chains
245 SDValue LowChain = SDValue(PartOps[(int)PackElem::Lo].getNode(), ChainResIdx);
246 SDValue HiChain = SDValue(PartOps[(int)PackElem::Hi].getNode(), ChainResIdx);
247 SDValue FusedChains =
248 CDAG.getNode(ISD::TokenFactor, MVT::Other, {LowChain, HiChain});
249
250 // Chain only [store]
251 if (PackData)
252 return FusedChains;
253
254 // Re-pack into full packed vector result
255 MVT PackedVT =
256 getLegalVectorType(Packing::Dense, DataVT.getVectorElementType());
257 SDValue PackedVals = CDAG.getPack(PackedVT, PartOps[(int)PackElem::Lo],
258 PartOps[(int)PackElem::Hi], UpperPartAVL);
259
260 return CDAG.getMergeValues({PackedVals, FusedChains});
261 }
262
lowerVVP_GATHER_SCATTER(SDValue Op,VECustomDAG & CDAG) const263 SDValue VETargetLowering::lowerVVP_GATHER_SCATTER(SDValue Op,
264 VECustomDAG &CDAG) const {
265 EVT DataVT = *getIdiomaticVectorType(Op.getNode());
266 auto Packing = getTypePacking(DataVT);
267 MVT LegalDataVT =
268 getLegalVectorType(Packing, DataVT.getVectorElementType().getSimpleVT());
269
270 SDValue AVL = getAnnotatedNodeAVL(Op).first;
271 SDValue Index = getGatherScatterIndex(Op);
272 SDValue BasePtr = getMemoryPtr(Op);
273 SDValue Mask = getNodeMask(Op);
274 SDValue Chain = getNodeChain(Op);
275 SDValue Scale = getGatherScatterScale(Op);
276 SDValue PassThru = getNodePassthru(Op);
277 SDValue StoredValue = getStoredValue(Op);
278 if (PassThru && PassThru->isUndef())
279 PassThru = SDValue();
280
281 bool IsScatter = (bool)StoredValue;
282
283 // TODO: Infer lower AVL from mask.
284 if (!AVL)
285 AVL = CDAG.getConstant(DataVT.getVectorNumElements(), MVT::i32);
286
287 // Default to the all-true mask.
288 if (!Mask)
289 Mask = CDAG.getConstantMask(Packing, true);
290
291 SDValue AddressVec =
292 CDAG.getGatherScatterAddress(BasePtr, Scale, Index, Mask, AVL);
293 if (IsScatter)
294 return CDAG.getNode(VEISD::VVP_SCATTER, MVT::Other,
295 {Chain, StoredValue, AddressVec, Mask, AVL});
296
297 // Gather.
298 SDValue NewLoadV = CDAG.getNode(VEISD::VVP_GATHER, {LegalDataVT, MVT::Other},
299 {Chain, AddressVec, Mask, AVL});
300
301 if (!PassThru)
302 return NewLoadV;
303
304 // TODO: Use vvp_select
305 SDValue DataV = CDAG.getNode(VEISD::VVP_SELECT, LegalDataVT,
306 {NewLoadV, PassThru, Mask, AVL});
307 SDValue NewLoadChainV = SDValue(NewLoadV.getNode(), 1);
308 return CDAG.getMergeValues({DataV, NewLoadChainV});
309 }
310
legalizeInternalLoadStoreOp(SDValue Op,VECustomDAG & CDAG) const311 SDValue VETargetLowering::legalizeInternalLoadStoreOp(SDValue Op,
312 VECustomDAG &CDAG) const {
313 LLVM_DEBUG(dbgs() << "::legalizeInternalLoadStoreOp\n";);
314 MVT DataVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT();
315
316 // TODO: Recognize packable load,store.
317 if (isPackedVectorType(DataVT))
318 return splitPackedLoadStore(Op, CDAG);
319
320 return legalizePackedAVL(Op, CDAG);
321 }
322
legalizeInternalVectorOp(SDValue Op,SelectionDAG & DAG) const323 SDValue VETargetLowering::legalizeInternalVectorOp(SDValue Op,
324 SelectionDAG &DAG) const {
325 LLVM_DEBUG(dbgs() << "::legalizeInternalVectorOp\n";);
326 VECustomDAG CDAG(DAG, Op);
327
328 // Dispatch to specialized legalization functions.
329 switch (Op->getOpcode()) {
330 case VEISD::VVP_LOAD:
331 case VEISD::VVP_STORE:
332 return legalizeInternalLoadStoreOp(Op, CDAG);
333 }
334
335 EVT IdiomVT = Op.getValueType();
336 if (isPackedVectorType(IdiomVT) &&
337 !supportsPackedMode(Op.getOpcode(), IdiomVT))
338 return splitVectorOp(Op, CDAG);
339
340 // TODO: Implement odd/even splitting.
341 return legalizePackedAVL(Op, CDAG);
342 }
343
splitVectorOp(SDValue Op,VECustomDAG & CDAG) const344 SDValue VETargetLowering::splitVectorOp(SDValue Op, VECustomDAG &CDAG) const {
345 MVT ResVT = splitVectorType(Op.getValue(0).getSimpleValueType());
346
347 auto AVLPos = getAVLPos(Op->getOpcode());
348 auto MaskPos = getMaskPos(Op->getOpcode());
349
350 SDValue PackedMask = getNodeMask(Op);
351 auto AVLPair = getAnnotatedNodeAVL(Op);
352 SDValue PackedAVL = AVLPair.first;
353 assert(!AVLPair.second && "Expecting non pack-legalized oepration");
354
355 // request the parts
356 SDValue PartOps[2];
357
358 SDValue UpperPartAVL; // we will use this for packing things back together
359 for (PackElem Part : {PackElem::Hi, PackElem::Lo}) {
360 // VP ops already have an explicit mask and AVL. When expanding from non-VP
361 // attach those additional inputs here.
362 auto SplitTM = CDAG.getTargetSplitMask(PackedMask, PackedAVL, Part);
363
364 if (Part == PackElem::Hi)
365 UpperPartAVL = SplitTM.AVL;
366
367 // Attach non-predicating value operands
368 SmallVector<SDValue, 4> OpVec;
369 for (unsigned i = 0; i < Op.getNumOperands(); ++i) {
370 if (AVLPos && ((int)i) == *AVLPos)
371 continue;
372 if (MaskPos && ((int)i) == *MaskPos)
373 continue;
374
375 // Value operand
376 auto PackedOperand = Op.getOperand(i);
377 auto UnpackedOpVT = splitVectorType(PackedOperand.getSimpleValueType());
378 SDValue PartV =
379 CDAG.getUnpack(UnpackedOpVT, PackedOperand, Part, SplitTM.AVL);
380 OpVec.push_back(PartV);
381 }
382
383 // Add predicating args and generate part node.
384 OpVec.push_back(SplitTM.Mask);
385 OpVec.push_back(SplitTM.AVL);
386 // Emit legal VVP nodes.
387 PartOps[(int)Part] =
388 CDAG.getNode(Op.getOpcode(), ResVT, OpVec, Op->getFlags());
389 }
390
391 // Re-package vectors.
392 return CDAG.getPack(Op.getValueType(), PartOps[(int)PackElem::Lo],
393 PartOps[(int)PackElem::Hi], UpperPartAVL);
394 }
395
legalizePackedAVL(SDValue Op,VECustomDAG & CDAG) const396 SDValue VETargetLowering::legalizePackedAVL(SDValue Op,
397 VECustomDAG &CDAG) const {
398 LLVM_DEBUG(dbgs() << "::legalizePackedAVL\n";);
399 // Only required for VEC and VVP ops.
400 if (!isVVPOrVEC(Op->getOpcode()))
401 return Op;
402
403 // Operation already has a legal AVL.
404 auto AVL = getNodeAVL(Op);
405 if (isLegalAVL(AVL))
406 return Op;
407
408 // Half and round up EVL for 32bit element types.
409 SDValue LegalAVL = AVL;
410 MVT IdiomVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT();
411 if (isPackedVectorType(IdiomVT)) {
412 assert(maySafelyIgnoreMask(Op) &&
413 "TODO Shift predication from EVL into Mask");
414
415 if (auto *ConstAVL = dyn_cast<ConstantSDNode>(AVL)) {
416 LegalAVL = CDAG.getConstant((ConstAVL->getZExtValue() + 1) / 2, MVT::i32);
417 } else {
418 auto ConstOne = CDAG.getConstant(1, MVT::i32);
419 auto PlusOne = CDAG.getNode(ISD::ADD, MVT::i32, {AVL, ConstOne});
420 LegalAVL = CDAG.getNode(ISD::SRL, MVT::i32, {PlusOne, ConstOne});
421 }
422 }
423
424 SDValue AnnotatedLegalAVL = CDAG.annotateLegalAVL(LegalAVL);
425
426 // Copy the operand list.
427 int NumOp = Op->getNumOperands();
428 auto AVLPos = getAVLPos(Op->getOpcode());
429 std::vector<SDValue> FixedOperands;
430 for (int i = 0; i < NumOp; ++i) {
431 if (AVLPos && (i == *AVLPos)) {
432 FixedOperands.push_back(AnnotatedLegalAVL);
433 continue;
434 }
435 FixedOperands.push_back(Op->getOperand(i));
436 }
437
438 // Clone the operation with fixed operands.
439 auto Flags = Op->getFlags();
440 SDValue NewN =
441 CDAG.getNode(Op->getOpcode(), Op->getVTList(), FixedOperands, Flags);
442 return NewN;
443 }
444