1 //===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the SPIRVTargetLowering class.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "SPIRVISelLowering.h"
14 #include "SPIRV.h"
15 #include "llvm/IR/IntrinsicsSPIRV.h"
16
17 #define DEBUG_TYPE "spirv-lower"
18
19 using namespace llvm;
20
getNumRegistersForCallingConv(LLVMContext & Context,CallingConv::ID CC,EVT VT) const21 unsigned SPIRVTargetLowering::getNumRegistersForCallingConv(
22 LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
23 // This code avoids CallLowering fail inside getVectorTypeBreakdown
24 // on v3i1 arguments. Maybe we need to return 1 for all types.
25 // TODO: remove it once this case is supported by the default implementation.
26 if (VT.isVector() && VT.getVectorNumElements() == 3 &&
27 (VT.getVectorElementType() == MVT::i1 ||
28 VT.getVectorElementType() == MVT::i8))
29 return 1;
30 if (!VT.isVector() && VT.isInteger() && VT.getSizeInBits() <= 64)
31 return 1;
32 return getNumRegisters(Context, VT);
33 }
34
getRegisterTypeForCallingConv(LLVMContext & Context,CallingConv::ID CC,EVT VT) const35 MVT SPIRVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
36 CallingConv::ID CC,
37 EVT VT) const {
38 // This code avoids CallLowering fail inside getVectorTypeBreakdown
39 // on v3i1 arguments. Maybe we need to return i32 for all types.
40 // TODO: remove it once this case is supported by the default implementation.
41 if (VT.isVector() && VT.getVectorNumElements() == 3) {
42 if (VT.getVectorElementType() == MVT::i1)
43 return MVT::v4i1;
44 else if (VT.getVectorElementType() == MVT::i8)
45 return MVT::v4i8;
46 }
47 return getRegisterType(Context, VT);
48 }
49
getTgtMemIntrinsic(IntrinsicInfo & Info,const CallInst & I,MachineFunction & MF,unsigned Intrinsic) const50 bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
51 const CallInst &I,
52 MachineFunction &MF,
53 unsigned Intrinsic) const {
54 unsigned AlignIdx = 3;
55 switch (Intrinsic) {
56 case Intrinsic::spv_load:
57 AlignIdx = 2;
58 [[fallthrough]];
59 case Intrinsic::spv_store: {
60 if (I.getNumOperands() >= AlignIdx + 1) {
61 auto *AlignOp = cast<ConstantInt>(I.getOperand(AlignIdx));
62 Info.align = Align(AlignOp->getZExtValue());
63 }
64 Info.flags = static_cast<MachineMemOperand::Flags>(
65 cast<ConstantInt>(I.getOperand(AlignIdx - 1))->getZExtValue());
66 Info.memVT = MVT::i64;
67 // TODO: take into account opaque pointers (don't use getElementType).
68 // MVT::getVT(PtrTy->getElementType());
69 return true;
70 break;
71 }
72 default:
73 break;
74 }
75 return false;
76 }
77