1 /*========================== begin_copyright_notice ============================
2
3 Copyright (C) 2017-2021 Intel Corporation
4
5 SPDX-License-Identifier: MIT
6
7 ============================= end_copyright_notice ===========================*/
8
9 //===----------------------------------------------------------------------===//
10 //
11 // The purpose of this pass is to replace instructions that operate on halfs with
12 // their float counterparts, extending operands to float and truncating results back to half.
13 //
14 // All unnecessary conversions get cleaned up before code gen.
15 //
16 //===----------------------------------------------------------------------===//
17
18
19 #include "HalfPromotion.h"
20 #include "Compiler/IGCPassSupport.h"
21 #include "GenISAIntrinsics/GenIntrinsics.h"
22 #include "IGCIRBuilder.h"
23
24 #include "common/LLVMWarningsPush.hpp"
25 #include <llvm/IR/Function.h>
26 #include "common/LLVMWarningsPop.hpp"
27
28
29 using namespace llvm;
30 using namespace IGC;
31
// Pass registration: command-line flag, human-readable description and the
// two boolean properties (CFG-only, analysis) handed to the IGC
// pass-initialization macros below. This pass rewrites instructions, so both
// properties are false.
32 #define PASS_FLAG "half-promotion"
33 #define PASS_DESCRIPTION "Promotion of halfs to floats"
34 #define PASS_CFG_ONLY false
35 #define PASS_ANALYSIS false
36 IGC_INITIALIZE_PASS_BEGIN(HalfPromotion, PASS_FLAG, PASS_DESCRIPTION, PASS_CFG_ONLY, PASS_ANALYSIS)
37 IGC_INITIALIZE_PASS_END(HalfPromotion, PASS_FLAG, PASS_DESCRIPTION, PASS_CFG_ONLY, PASS_ANALYSIS)
38
// Pass identifier object; the constructor hands it to FunctionPass(ID).
39 char HalfPromotion::ID = 0;
40
HalfPromotion()41 HalfPromotion::HalfPromotion() : FunctionPass(ID)
42 {
43 initializeHalfPromotionPass(*PassRegistry::getPassRegistry());
44 }
45
runOnFunction(Function & F)46 bool HalfPromotion::runOnFunction(Function& F)
47 {
48 visit(F);
49 return m_changed;
50 }
51
visitCallInst(llvm::CallInst & I)52 void HalfPromotion::visitCallInst(llvm::CallInst& I)
53 {
54 if (llvm::isa<GenIntrinsicInst>(I) && I.getType()->isHalfTy())
55 {
56 handleGenIntrinsic(llvm::cast<GenIntrinsicInst>(I));
57 }
58 else if (llvm::isa<llvm::IntrinsicInst>(I) && I.getType()->isHalfTy())
59 {
60 handleLLVMIntrinsic(llvm::cast<IntrinsicInst>(I));
61 }
62 }
63
handleLLVMIntrinsic(llvm::IntrinsicInst & I)64 void IGC::HalfPromotion::handleLLVMIntrinsic(llvm::IntrinsicInst& I)
65 {
66 Intrinsic::ID id = I.getIntrinsicID();
67 if (id == Intrinsic::cos ||
68 id == Intrinsic::sin ||
69 id == Intrinsic::log2 ||
70 id == Intrinsic::exp2 ||
71 id == Intrinsic::sqrt ||
72 id == Intrinsic::floor ||
73 id == Intrinsic::ceil ||
74 id == Intrinsic::fabs ||
75 id == Intrinsic::pow ||
76 id == Intrinsic::fma ||
77 id == Intrinsic::maxnum ||
78 id == Intrinsic::minnum)
79 {
80 Module* M = I.getParent()->getParent()->getParent();
81 llvm::IGCIRBuilder<> builder(&I);
82 std::vector<llvm::Value*> arguments;
83
84 Function* pNewFunc = Intrinsic::getDeclaration(
85 M,
86 I.getIntrinsicID(),
87 builder.getFloatTy());
88
89 for (unsigned i = 0; i < I.getNumArgOperands(); ++i)
90 {
91 if (I.getOperand(i)->getType()->isHalfTy())
92 {
93 Value* op = builder.CreateFPExt(I.getOperand(i), builder.getFloatTy());
94 arguments.push_back(op);
95 }
96 else
97 {
98 arguments.push_back(I.getOperand(i));
99 }
100 }
101
102 Value* f32Val = builder.CreateCall(
103 pNewFunc,
104 arguments);
105 Value* f16Val = builder.CreateFPTrunc(f32Val, builder.getHalfTy());
106 I.replaceAllUsesWith(f16Val);
107 m_changed = true;
108 }
109 }
110
handleGenIntrinsic(llvm::GenIntrinsicInst & I)111 void IGC::HalfPromotion::handleGenIntrinsic(llvm::GenIntrinsicInst& I)
112 {
113 GenISAIntrinsic::ID id = I.getIntrinsicID();
114 if (id == GenISAIntrinsic::GenISA_WaveAll ||
115 id == GenISAIntrinsic::GenISA_WavePrefix ||
116 id == GenISAIntrinsic::GenISA_WaveClustered)
117 {
118 Module* M = I.getParent()->getParent()->getParent();
119 llvm::IGCIRBuilder<> builder(&I);
120 std::vector<llvm::Value*> arguments;
121
122 Function* pNewFunc = GenISAIntrinsic::getDeclaration(
123 M,
124 I.getIntrinsicID(),
125 builder.getFloatTy());
126
127 for (unsigned i = 0; i < I.getNumArgOperands(); ++i)
128 {
129 if (I.getOperand(i)->getType()->isHalfTy())
130 {
131 Value* op = builder.CreateFPExt(I.getOperand(i), builder.getFloatTy());
132 arguments.push_back(op);
133 }
134 else
135 {
136 arguments.push_back(I.getOperand(i));
137 }
138 }
139
140 Value* f32Val = builder.CreateCall(
141 pNewFunc,
142 arguments);
143 Value* f16Val = builder.CreateFPTrunc(f32Val, builder.getHalfTy());
144 I.replaceAllUsesWith(f16Val);
145 I.eraseFromParent();
146 m_changed = true;
147 }
148 }
149
visitFCmp(llvm::FCmpInst & CmpI)150 void HalfPromotion::visitFCmp(llvm::FCmpInst& CmpI)
151 {
152 if (CmpI.getOperand(0)->getType()->isHalfTy())
153 {
154 llvm::IGCIRBuilder<> builder(&CmpI);
155 Value* op1 = builder.CreateFPExt(CmpI.getOperand(0), builder.getFloatTy());
156 Value* op2 = builder.CreateFPExt(CmpI.getOperand(1), builder.getFloatTy());
157 Value* newOp = builder.CreateFCmp(CmpI.getPredicate(), op1, op2);
158 CmpI.replaceAllUsesWith(newOp);
159 m_changed = true;
160 }
161 }
162
visitBinaryOperator(llvm::BinaryOperator & BI)163 void HalfPromotion::visitBinaryOperator(llvm::BinaryOperator& BI)
164 {
165 if (BI.getType()->isHalfTy() &&
166 (BI.getOpcode() == BinaryOperator::FAdd ||
167 BI.getOpcode() == BinaryOperator::FSub ||
168 BI.getOpcode() == BinaryOperator::FMul ||
169 BI.getOpcode() == BinaryOperator::FDiv))
170 {
171 llvm::IGCIRBuilder<> builder(&BI);
172 Value* op1 = builder.CreateFPExt(BI.getOperand(0), builder.getFloatTy());
173 Value* op2 = builder.CreateFPExt(BI.getOperand(1), builder.getFloatTy());
174 Value* newOp = builder.CreateBinOp(BI.getOpcode(), op1, op2);
175 Value* f16Val = builder.CreateFPTrunc(newOp, builder.getHalfTy());
176 BI.replaceAllUsesWith(f16Val);
177 m_changed = true;
178 }
179 }
180
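181 /*
182
183 Open question: casts like the following widen an integer to half and then immediately to float:
184 %162 = uitofp i32 %160 to half
185 %163 = fpext half %162 to float
186 %164 = fmul float %163, 1.600000e+01
187
188 Would it be safe to drop the intermediate half and convert directly to float?
189 %162 = uitofp i32 %160 to float
190 %164 = fmul float %162, 1.600000e+01
191 Not in general: half represents integers exactly only up to 2048 and overflows to infinity at magnitudes of 65520 and above, so the fold changes results unless the integer is known to be at most 2048 in magnitude.
192 */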
193
visitCastInst(llvm::CastInst & CI)194 void HalfPromotion::visitCastInst(llvm::CastInst& CI)
195 {
196 if (CI.getType()->isHalfTy() &&
197 (CI.getOpcode() == CastInst::UIToFP ||
198 CI.getOpcode() == CastInst::SIToFP))
199 {
200 llvm::IGCIRBuilder<> builder(&CI);
201 Value* newOp = nullptr;
202 if (CI.getOpcode() == CastInst::UIToFP)
203 {
204 newOp = builder.CreateUIToFP(CI.getOperand(0), builder.getFloatTy());
205 }
206 else
207 {
208 newOp = builder.CreateSIToFP(CI.getOperand(0), builder.getFloatTy());
209 }
210 Value* f16Val = builder.CreateFPTrunc(newOp, builder.getHalfTy());
211 CI.replaceAllUsesWith(f16Val);
212 m_changed = true;
213 }
214 else if (CI.getOperand(0)->getType()->isHalfTy() &&
215 (CI.getOpcode() == CastInst::FPToUI ||
216 CI.getOpcode() == CastInst::FPToSI))
217 {
218 llvm::IGCIRBuilder<> builder(&CI);
219 Value* newOp = nullptr;
220 Value* f32Val = builder.CreateFPExt(CI.getOperand(0), builder.getFloatTy());
221 if (CI.getOpcode() == CastInst::FPToUI)
222 {
223 newOp = builder.CreateFPToUI(f32Val, CI.getType());
224 }
225 else
226 {
227 newOp = builder.CreateFPToSI(f32Val, CI.getType());
228 }
229 CI.replaceAllUsesWith(newOp);
230 m_changed = true;
231 }
232 }
233
visitSelectInst(llvm::SelectInst & SI)234 void HalfPromotion::visitSelectInst(llvm::SelectInst& SI)
235 {
236 if (SI.getTrueValue()->getType()->isHalfTy())
237 {
238 llvm::IGCIRBuilder<> builder(&SI);
239 Value* opTrue = builder.CreateFPExt(SI.getTrueValue(), builder.getFloatTy());
240 Value* opFalse = builder.CreateFPExt(SI.getFalseValue(), builder.getFloatTy());
241 Value* newOp = builder.CreateSelect(SI.getCondition(), opTrue, opFalse);
242 Value* f16Val = builder.CreateFPTrunc(newOp, builder.getHalfTy());
243 SI.replaceAllUsesWith(f16Val);
244 m_changed = true;
245 }
246 }
247
visitPHINode(llvm::PHINode & PHI)248 void HalfPromotion::visitPHINode(llvm::PHINode& PHI)
249 {
250 if (!PHI.getType()->isHalfTy())
251 {
252 return;
253 }
254
255 llvm::IGCIRBuilder<> builder(&PHI);
256 llvm::PHINode* pNewPhi = llvm::PHINode::Create(builder.getFloatTy(), PHI.getNumIncomingValues(), "", &PHI);
257
258 for (unsigned int i = 0; i < PHI.getNumIncomingValues(); ++i)
259 {
260 builder.SetInsertPoint(PHI.getIncomingBlock(i)->getTerminator());
261 Value* phiFloatValue = builder.CreateFPExt(PHI.getIncomingValue(i), builder.getFloatTy());
262 pNewPhi->addIncoming(phiFloatValue, PHI.getIncomingBlock(i));
263 }
264
265 builder.SetInsertPoint(PHI.getParent()->getFirstNonPHI());
266 Value* f16Val = builder.CreateFPTrunc(pNewPhi, builder.getHalfTy());
267 PHI.replaceAllUsesWith(f16Val);
268 PHI.eraseFromParent();
269 m_changed = true;
270 }
271