1 /*========================== begin_copyright_notice ============================
2 
3 Copyright (C) 2017-2021 Intel Corporation
4 
5 SPDX-License-Identifier: MIT
6 
7 ============================= end_copyright_notice ===========================*/
8 
9 //===----------------------------------------------------------------------===//
10 //
// The purpose of this pass is to replace instructions operating on halfs with
// their corresponding float counterparts.
13 //
14 // All unnecessary conversions get cleaned up before code gen.
15 //
16 //===----------------------------------------------------------------------===//
17 
18 
19 #include "HalfPromotion.h"
20 #include "Compiler/IGCPassSupport.h"
21 #include "GenISAIntrinsics/GenIntrinsics.h"
22 #include "IGCIRBuilder.h"
23 
24 #include "common/LLVMWarningsPush.hpp"
25 #include <llvm/IR/Function.h>
26 #include "common/LLVMWarningsPop.hpp"
27 
28 
29 using namespace llvm;
30 using namespace IGC;
31 
// Registration parameters consumed by the IGC_INITIALIZE_PASS_* macros below:
// the command-line flag, a human-readable description, and the two boolean
// properties (CFG-only, analysis) expected by LLVM pass registration.
#define PASS_FLAG "half-promotion"
#define PASS_DESCRIPTION "Promotion of halfs to floats"
#define PASS_CFG_ONLY false
#define PASS_ANALYSIS false
IGC_INITIALIZE_PASS_BEGIN(HalfPromotion, PASS_FLAG, PASS_DESCRIPTION, PASS_CFG_ONLY, PASS_ANALYSIS)
IGC_INITIALIZE_PASS_END(HalfPromotion, PASS_FLAG, PASS_DESCRIPTION, PASS_CFG_ONLY, PASS_ANALYSIS)

// Pass identifier handed to FunctionPass(ID) in the constructor; LLVM keys
// legacy passes on the address of this object, so its value is not significant.
char HalfPromotion::ID = 0;
40 
HalfPromotion()41 HalfPromotion::HalfPromotion() : FunctionPass(ID)
42 {
43     initializeHalfPromotionPass(*PassRegistry::getPassRegistry());
44 }
45 
runOnFunction(Function & F)46 bool HalfPromotion::runOnFunction(Function& F)
47 {
48     visit(F);
49     return m_changed;
50 }
51 
visitCallInst(llvm::CallInst & I)52 void HalfPromotion::visitCallInst(llvm::CallInst& I)
53 {
54     if (llvm::isa<GenIntrinsicInst>(I) && I.getType()->isHalfTy())
55     {
56         handleGenIntrinsic(llvm::cast<GenIntrinsicInst>(I));
57     }
58     else if (llvm::isa<llvm::IntrinsicInst>(I) && I.getType()->isHalfTy())
59     {
60         handleLLVMIntrinsic(llvm::cast<IntrinsicInst>(I));
61     }
62 }
63 
handleLLVMIntrinsic(llvm::IntrinsicInst & I)64 void IGC::HalfPromotion::handleLLVMIntrinsic(llvm::IntrinsicInst& I)
65 {
66     Intrinsic::ID id = I.getIntrinsicID();
67     if (id == Intrinsic::cos ||
68         id == Intrinsic::sin ||
69         id == Intrinsic::log2 ||
70         id == Intrinsic::exp2 ||
71         id == Intrinsic::sqrt ||
72         id == Intrinsic::floor ||
73         id == Intrinsic::ceil ||
74         id == Intrinsic::fabs ||
75         id == Intrinsic::pow ||
76         id == Intrinsic::fma ||
77         id == Intrinsic::maxnum ||
78         id == Intrinsic::minnum)
79     {
80         Module* M = I.getParent()->getParent()->getParent();
81         llvm::IGCIRBuilder<> builder(&I);
82         std::vector<llvm::Value*> arguments;
83 
84         Function* pNewFunc = Intrinsic::getDeclaration(
85             M,
86             I.getIntrinsicID(),
87             builder.getFloatTy());
88 
89         for (unsigned i = 0; i < I.getNumArgOperands(); ++i)
90         {
91             if (I.getOperand(i)->getType()->isHalfTy())
92             {
93                 Value* op = builder.CreateFPExt(I.getOperand(i), builder.getFloatTy());
94                 arguments.push_back(op);
95             }
96             else
97             {
98                 arguments.push_back(I.getOperand(i));
99             }
100         }
101 
102         Value* f32Val = builder.CreateCall(
103             pNewFunc,
104             arguments);
105         Value* f16Val = builder.CreateFPTrunc(f32Val, builder.getHalfTy());
106         I.replaceAllUsesWith(f16Val);
107         m_changed = true;
108     }
109 }
110 
handleGenIntrinsic(llvm::GenIntrinsicInst & I)111 void IGC::HalfPromotion::handleGenIntrinsic(llvm::GenIntrinsicInst& I)
112 {
113     GenISAIntrinsic::ID id = I.getIntrinsicID();
114     if (id == GenISAIntrinsic::GenISA_WaveAll ||
115         id == GenISAIntrinsic::GenISA_WavePrefix ||
116         id == GenISAIntrinsic::GenISA_WaveClustered)
117     {
118         Module* M = I.getParent()->getParent()->getParent();
119         llvm::IGCIRBuilder<> builder(&I);
120         std::vector<llvm::Value*> arguments;
121 
122         Function* pNewFunc = GenISAIntrinsic::getDeclaration(
123             M,
124             I.getIntrinsicID(),
125             builder.getFloatTy());
126 
127         for (unsigned i = 0; i < I.getNumArgOperands(); ++i)
128         {
129             if (I.getOperand(i)->getType()->isHalfTy())
130             {
131                 Value* op = builder.CreateFPExt(I.getOperand(i), builder.getFloatTy());
132                 arguments.push_back(op);
133             }
134             else
135             {
136                 arguments.push_back(I.getOperand(i));
137             }
138         }
139 
140         Value* f32Val = builder.CreateCall(
141             pNewFunc,
142             arguments);
143         Value* f16Val = builder.CreateFPTrunc(f32Val, builder.getHalfTy());
144         I.replaceAllUsesWith(f16Val);
145         I.eraseFromParent();
146         m_changed = true;
147     }
148 }
149 
visitFCmp(llvm::FCmpInst & CmpI)150 void HalfPromotion::visitFCmp(llvm::FCmpInst& CmpI)
151 {
152     if (CmpI.getOperand(0)->getType()->isHalfTy())
153     {
154         llvm::IGCIRBuilder<> builder(&CmpI);
155         Value* op1 = builder.CreateFPExt(CmpI.getOperand(0), builder.getFloatTy());
156         Value* op2 = builder.CreateFPExt(CmpI.getOperand(1), builder.getFloatTy());
157         Value* newOp = builder.CreateFCmp(CmpI.getPredicate(), op1, op2);
158         CmpI.replaceAllUsesWith(newOp);
159         m_changed = true;
160     }
161 }
162 
visitBinaryOperator(llvm::BinaryOperator & BI)163 void HalfPromotion::visitBinaryOperator(llvm::BinaryOperator& BI)
164 {
165     if (BI.getType()->isHalfTy() &&
166         (BI.getOpcode() == BinaryOperator::FAdd ||
167             BI.getOpcode() == BinaryOperator::FSub ||
168             BI.getOpcode() == BinaryOperator::FMul ||
169             BI.getOpcode() == BinaryOperator::FDiv))
170     {
171         llvm::IGCIRBuilder<> builder(&BI);
172         Value* op1 = builder.CreateFPExt(BI.getOperand(0), builder.getFloatTy());
173         Value* op2 = builder.CreateFPExt(BI.getOperand(1), builder.getFloatTy());
174         Value* newOp = builder.CreateBinOp(BI.getOpcode(), op1, op2);
175         Value* f16Val = builder.CreateFPTrunc(newOp, builder.getHalfTy());
176         BI.replaceAllUsesWith(f16Val);
177         m_changed = true;
178     }
179 }
180 
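/*

  What about casts like these?
  %162 = uitofp i32 %160 to half
  %163 = fpext half %162 to float
  %164 = fmul float %163, 1.600000e+01

  Is it safe to do this?
  %162 = uitofp i32 %160 to float
  %164 = fmul float %162, 1.600000e+01

  Not in general: converting to half rounds to 11 significant bits and
  overflows to +inf above 65504, while converting to float does neither for
  these magnitudes. For example, 2049 becomes 2048.0 through half but 2049.0
  directly, and 70000 becomes +inf through half. The fold is only valid when
  %160 is known to be exactly representable in half (every integer up to 2048
  is).

*/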
193 
visitCastInst(llvm::CastInst & CI)194 void HalfPromotion::visitCastInst(llvm::CastInst& CI)
195 {
196     if (CI.getType()->isHalfTy() &&
197         (CI.getOpcode() == CastInst::UIToFP ||
198             CI.getOpcode() == CastInst::SIToFP))
199     {
200         llvm::IGCIRBuilder<> builder(&CI);
201         Value* newOp = nullptr;
202         if (CI.getOpcode() == CastInst::UIToFP)
203         {
204             newOp = builder.CreateUIToFP(CI.getOperand(0), builder.getFloatTy());
205         }
206         else
207         {
208             newOp = builder.CreateSIToFP(CI.getOperand(0), builder.getFloatTy());
209         }
210         Value* f16Val = builder.CreateFPTrunc(newOp, builder.getHalfTy());
211         CI.replaceAllUsesWith(f16Val);
212         m_changed = true;
213     }
214     else if (CI.getOperand(0)->getType()->isHalfTy() &&
215         (CI.getOpcode() == CastInst::FPToUI ||
216             CI.getOpcode() == CastInst::FPToSI))
217     {
218         llvm::IGCIRBuilder<> builder(&CI);
219         Value* newOp = nullptr;
220         Value* f32Val = builder.CreateFPExt(CI.getOperand(0), builder.getFloatTy());
221         if (CI.getOpcode() == CastInst::FPToUI)
222         {
223             newOp = builder.CreateFPToUI(f32Val, CI.getType());
224         }
225         else
226         {
227             newOp = builder.CreateFPToSI(f32Val, CI.getType());
228         }
229         CI.replaceAllUsesWith(newOp);
230         m_changed = true;
231     }
232 }
233 
visitSelectInst(llvm::SelectInst & SI)234 void HalfPromotion::visitSelectInst(llvm::SelectInst& SI)
235 {
236     if (SI.getTrueValue()->getType()->isHalfTy())
237     {
238         llvm::IGCIRBuilder<> builder(&SI);
239         Value* opTrue = builder.CreateFPExt(SI.getTrueValue(), builder.getFloatTy());
240         Value* opFalse = builder.CreateFPExt(SI.getFalseValue(), builder.getFloatTy());
241         Value* newOp = builder.CreateSelect(SI.getCondition(), opTrue, opFalse);
242         Value* f16Val = builder.CreateFPTrunc(newOp, builder.getHalfTy());
243         SI.replaceAllUsesWith(f16Val);
244         m_changed = true;
245     }
246 }
247 
visitPHINode(llvm::PHINode & PHI)248 void HalfPromotion::visitPHINode(llvm::PHINode& PHI)
249 {
250     if (!PHI.getType()->isHalfTy())
251     {
252         return;
253     }
254 
255     llvm::IGCIRBuilder<> builder(&PHI);
256     llvm::PHINode* pNewPhi = llvm::PHINode::Create(builder.getFloatTy(), PHI.getNumIncomingValues(), "", &PHI);
257 
258     for (unsigned int i = 0; i < PHI.getNumIncomingValues(); ++i)
259     {
260         builder.SetInsertPoint(PHI.getIncomingBlock(i)->getTerminator());
261         Value* phiFloatValue = builder.CreateFPExt(PHI.getIncomingValue(i), builder.getFloatTy());
262         pNewPhi->addIncoming(phiFloatValue, PHI.getIncomingBlock(i));
263     }
264 
265     builder.SetInsertPoint(PHI.getParent()->getFirstNonPHI());
266     Value* f16Val = builder.CreateFPTrunc(pNewPhi, builder.getHalfTy());
267     PHI.replaceAllUsesWith(f16Val);
268     PHI.eraseFromParent();
269     m_changed = true;
270 }
271