1 //===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the SPIRVTargetLowering class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "SPIRVISelLowering.h"
14 #include "SPIRV.h"
15 
16 #define DEBUG_TYPE "spirv-lower"
17 
18 using namespace llvm;
19 
20 unsigned SPIRVTargetLowering::getNumRegistersForCallingConv(
21     LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
22   // This code avoids CallLowering fail inside getVectorTypeBreakdown
23   // on v3i1 arguments. Maybe we need to return 1 for all types.
24   // TODO: remove it once this case is supported by the default implementation.
25   if (VT.isVector() && VT.getVectorNumElements() == 3 &&
26       (VT.getVectorElementType() == MVT::i1 ||
27        VT.getVectorElementType() == MVT::i8))
28     return 1;
29   return getNumRegisters(Context, VT);
30 }
31 
32 MVT SPIRVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
33                                                        CallingConv::ID CC,
34                                                        EVT VT) const {
35   // This code avoids CallLowering fail inside getVectorTypeBreakdown
36   // on v3i1 arguments. Maybe we need to return i32 for all types.
37   // TODO: remove it once this case is supported by the default implementation.
38   if (VT.isVector() && VT.getVectorNumElements() == 3) {
39     if (VT.getVectorElementType() == MVT::i1)
40       return MVT::v4i1;
41     else if (VT.getVectorElementType() == MVT::i8)
42       return MVT::v4i8;
43   }
44   return getRegisterType(Context, VT);
45 }
46