1 //===- llvm/FixedPointBuilder.h - Builder for fixed-point ops ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the FixedPointBuilder class, which is used as a convenient
10 // way to lower fixed-point arithmetic operations to LLVM IR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef LLVM_IR_FIXEDPOINTBUILDER_H
15 #define LLVM_IR_FIXEDPOINTBUILDER_H
16 
#include "llvm/ADT/APFixedPoint.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"

#include <algorithm>
#include <cmath>
27 
28 namespace llvm {
29 
30 template <class IRBuilderTy> class FixedPointBuilder {
31   IRBuilderTy &B;
32 
33   Value *Convert(Value *Src, const FixedPointSemantics &SrcSema,
34                  const FixedPointSemantics &DstSema, bool DstIsInteger) {
35     unsigned SrcWidth = SrcSema.getWidth();
36     unsigned DstWidth = DstSema.getWidth();
37     unsigned SrcScale = SrcSema.getScale();
38     unsigned DstScale = DstSema.getScale();
39     bool SrcIsSigned = SrcSema.isSigned();
40     bool DstIsSigned = DstSema.isSigned();
41 
42     Type *DstIntTy = B.getIntNTy(DstWidth);
43 
44     Value *Result = Src;
45     unsigned ResultWidth = SrcWidth;
46 
47     // Downscale.
48     if (DstScale < SrcScale) {
49       // When converting to integers, we round towards zero. For negative
50       // numbers, right shifting rounds towards negative infinity. In this case,
51       // we can just round up before shifting.
52       if (DstIsInteger && SrcIsSigned) {
53         Value *Zero = Constant::getNullValue(Result->getType());
54         Value *IsNegative = B.CreateICmpSLT(Result, Zero);
55         Value *LowBits = ConstantInt::get(
56             B.getContext(), APInt::getLowBitsSet(ResultWidth, SrcScale));
57         Value *Rounded = B.CreateAdd(Result, LowBits);
58         Result = B.CreateSelect(IsNegative, Rounded, Result);
59       }
60 
61       Result = SrcIsSigned
62                    ? B.CreateAShr(Result, SrcScale - DstScale, "downscale")
63                    : B.CreateLShr(Result, SrcScale - DstScale, "downscale");
64     }
65 
66     if (!DstSema.isSaturated()) {
67       // Resize.
68       Result = B.CreateIntCast(Result, DstIntTy, SrcIsSigned, "resize");
69 
70       // Upscale.
71       if (DstScale > SrcScale)
72         Result = B.CreateShl(Result, DstScale - SrcScale, "upscale");
73     } else {
74       // Adjust the number of fractional bits.
75       if (DstScale > SrcScale) {
76         // Compare to DstWidth to prevent resizing twice.
77         ResultWidth = std::max(SrcWidth + DstScale - SrcScale, DstWidth);
78         Type *UpscaledTy = B.getIntNTy(ResultWidth);
79         Result = B.CreateIntCast(Result, UpscaledTy, SrcIsSigned, "resize");
80         Result = B.CreateShl(Result, DstScale - SrcScale, "upscale");
81       }
82 
83       // Handle saturation.
84       bool LessIntBits = DstSema.getIntegralBits() < SrcSema.getIntegralBits();
85       if (LessIntBits) {
86         Value *Max = ConstantInt::get(
87             B.getContext(),
88             APFixedPoint::getMax(DstSema).getValue().extOrTrunc(ResultWidth));
89         Value *TooHigh = SrcIsSigned ? B.CreateICmpSGT(Result, Max)
90                                      : B.CreateICmpUGT(Result, Max);
91         Result = B.CreateSelect(TooHigh, Max, Result, "satmax");
92       }
93       // Cannot overflow min to dest type if src is unsigned since all fixed
94       // point types can cover the unsigned min of 0.
95       if (SrcIsSigned && (LessIntBits || !DstIsSigned)) {
96         Value *Min = ConstantInt::get(
97             B.getContext(),
98             APFixedPoint::getMin(DstSema).getValue().extOrTrunc(ResultWidth));
99         Value *TooLow = B.CreateICmpSLT(Result, Min);
100         Result = B.CreateSelect(TooLow, Min, Result, "satmin");
101       }
102 
103       // Resize the integer part to get the final destination size.
104       if (ResultWidth != DstWidth)
105         Result = B.CreateIntCast(Result, DstIntTy, SrcIsSigned, "resize");
106     }
107     return Result;
108   }
109 
110   /// Get the common semantic for two semantics, with the added imposition that
111   /// saturated padded types retain the padding bit.
112   FixedPointSemantics
113   getCommonBinopSemantic(const FixedPointSemantics &LHSSema,
114                          const FixedPointSemantics &RHSSema) {
115     auto C = LHSSema.getCommonSemantics(RHSSema);
116     bool BothPadded =
117         LHSSema.hasUnsignedPadding() && RHSSema.hasUnsignedPadding();
118     return FixedPointSemantics(
119         C.getWidth() + (unsigned)(BothPadded && C.isSaturated()), C.getScale(),
120         C.isSigned(), C.isSaturated(), BothPadded);
121   }
122 
123   /// Given a floating point type and a fixed-point semantic, return a floating
124   /// point type which can accommodate the fixed-point semantic. This is either
125   /// \p Ty, or a floating point type with a larger exponent than Ty.
126   Type *getAccommodatingFloatType(Type *Ty, const FixedPointSemantics &Sema) {
127     const fltSemantics *FloatSema = &Ty->getFltSemantics();
128     while (!Sema.fitsInFloatSemantics(*FloatSema))
129       FloatSema = APFixedPoint::promoteFloatSemantics(FloatSema);
130     return Type::getFloatingPointTy(Ty->getContext(), *FloatSema);
131   }
132 
133 public:
134   FixedPointBuilder(IRBuilderTy &Builder) : B(Builder) {}
135 
136   /// Convert an integer value representing a fixed-point number from one
137   /// fixed-point semantic to another fixed-point semantic.
138   /// \p Src     - The source value
139   /// \p SrcSema - The fixed-point semantic of the source value
140   /// \p DstSema - The resulting fixed-point semantic
141   Value *CreateFixedToFixed(Value *Src, const FixedPointSemantics &SrcSema,
142                             const FixedPointSemantics &DstSema) {
143     return Convert(Src, SrcSema, DstSema, false);
144   }
145 
146   /// Convert an integer value representing a fixed-point number to an integer
147   /// with the given bit width and signedness.
148   /// \p Src         - The source value
149   /// \p SrcSema     - The fixed-point semantic of the source value
150   /// \p DstWidth    - The bit width of the result value
151   /// \p DstIsSigned - The signedness of the result value
152   Value *CreateFixedToInteger(Value *Src, const FixedPointSemantics &SrcSema,
153                               unsigned DstWidth, bool DstIsSigned) {
154     return Convert(
155         Src, SrcSema,
156         FixedPointSemantics::GetIntegerSemantics(DstWidth, DstIsSigned), true);
157   }
158 
159   /// Convert an integer value with the given signedness to an integer value
160   /// representing the given fixed-point semantic.
161   /// \p Src         - The source value
162   /// \p SrcIsSigned - The signedness of the source value
163   /// \p DstSema     - The resulting fixed-point semantic
164   Value *CreateIntegerToFixed(Value *Src, unsigned SrcIsSigned,
165                               const FixedPointSemantics &DstSema) {
166     return Convert(Src,
167                    FixedPointSemantics::GetIntegerSemantics(
168                        Src->getType()->getScalarSizeInBits(), SrcIsSigned),
169                    DstSema, false);
170   }
171 
172   Value *CreateFixedToFloating(Value *Src, const FixedPointSemantics &SrcSema,
173                                Type *DstTy) {
174     Value *Result;
175     Type *OpTy = getAccommodatingFloatType(DstTy, SrcSema);
176     // Convert the raw fixed-point value directly to floating point. If the
177     // value is too large to fit, it will be rounded, not truncated.
178     Result = SrcSema.isSigned() ? B.CreateSIToFP(Src, OpTy)
179                                 : B.CreateUIToFP(Src, OpTy);
180     // Rescale the integral-in-floating point by the scaling factor. This is
181     // lossless, except for overflow to infinity which is unlikely.
182     Result = B.CreateFMul(Result,
183         ConstantFP::get(OpTy, std::pow(2, -(int)SrcSema.getScale())));
184     if (OpTy != DstTy)
185       Result = B.CreateFPTrunc(Result, DstTy);
186     return Result;
187   }
188 
189   Value *CreateFloatingToFixed(Value *Src, const FixedPointSemantics &DstSema) {
190     bool UseSigned = DstSema.isSigned() || DstSema.hasUnsignedPadding();
191     Value *Result = Src;
192     Type *OpTy = getAccommodatingFloatType(Src->getType(), DstSema);
193     if (OpTy != Src->getType())
194       Result = B.CreateFPExt(Result, OpTy);
195     // Rescale the floating point value so that its significant bits (for the
196     // purposes of the conversion) are in the integral range.
197     Result = B.CreateFMul(Result,
198         ConstantFP::get(OpTy, std::pow(2, DstSema.getScale())));
199 
200     Type *ResultTy = B.getIntNTy(DstSema.getWidth());
201     if (DstSema.isSaturated()) {
202       Intrinsic::ID IID =
203           UseSigned ? Intrinsic::fptosi_sat : Intrinsic::fptoui_sat;
204       Result = B.CreateIntrinsic(IID, {ResultTy, OpTy}, {Result});
205     } else {
206       Result = UseSigned ? B.CreateFPToSI(Result, ResultTy)
207                          : B.CreateFPToUI(Result, ResultTy);
208     }
209 
210     // When saturating unsigned-with-padding using signed operations, we may
211     // get negative values. Emit an extra clamp to zero.
212     if (DstSema.isSaturated() && DstSema.hasUnsignedPadding()) {
213       Constant *Zero = Constant::getNullValue(Result->getType());
214       Result =
215           B.CreateSelect(B.CreateICmpSLT(Result, Zero), Zero, Result, "satmin");
216     }
217 
218     return Result;
219   }
220 
221   /// Add two fixed-point values and return the result in their common semantic.
222   /// \p LHS     - The left hand side
223   /// \p LHSSema - The semantic of the left hand side
224   /// \p RHS     - The right hand side
225   /// \p RHSSema - The semantic of the right hand side
226   Value *CreateAdd(Value *LHS, const FixedPointSemantics &LHSSema,
227                    Value *RHS, const FixedPointSemantics &RHSSema) {
228     auto CommonSema = getCommonBinopSemantic(LHSSema, RHSSema);
229     bool UseSigned = CommonSema.isSigned() || CommonSema.hasUnsignedPadding();
230 
231     Value *WideLHS = CreateFixedToFixed(LHS, LHSSema, CommonSema);
232     Value *WideRHS = CreateFixedToFixed(RHS, RHSSema, CommonSema);
233 
234     Value *Result;
235     if (CommonSema.isSaturated()) {
236       Intrinsic::ID IID = UseSigned ? Intrinsic::sadd_sat : Intrinsic::uadd_sat;
237       Result = B.CreateBinaryIntrinsic(IID, WideLHS, WideRHS);
238     } else {
239       Result = B.CreateAdd(WideLHS, WideRHS);
240     }
241 
242     return CreateFixedToFixed(Result, CommonSema,
243                               LHSSema.getCommonSemantics(RHSSema));
244   }
245 
246   /// Subtract two fixed-point values and return the result in their common
247   /// semantic.
248   /// \p LHS     - The left hand side
249   /// \p LHSSema - The semantic of the left hand side
250   /// \p RHS     - The right hand side
251   /// \p RHSSema - The semantic of the right hand side
252   Value *CreateSub(Value *LHS, const FixedPointSemantics &LHSSema,
253                    Value *RHS, const FixedPointSemantics &RHSSema) {
254     auto CommonSema = getCommonBinopSemantic(LHSSema, RHSSema);
255     bool UseSigned = CommonSema.isSigned() || CommonSema.hasUnsignedPadding();
256 
257     Value *WideLHS = CreateFixedToFixed(LHS, LHSSema, CommonSema);
258     Value *WideRHS = CreateFixedToFixed(RHS, RHSSema, CommonSema);
259 
260     Value *Result;
261     if (CommonSema.isSaturated()) {
262       Intrinsic::ID IID = UseSigned ? Intrinsic::ssub_sat : Intrinsic::usub_sat;
263       Result = B.CreateBinaryIntrinsic(IID, WideLHS, WideRHS);
264     } else {
265       Result = B.CreateSub(WideLHS, WideRHS);
266     }
267 
268     // Subtraction can end up below 0 for padded unsigned operations, so emit
269     // an extra clamp in that case.
270     if (CommonSema.isSaturated() && CommonSema.hasUnsignedPadding()) {
271       Constant *Zero = Constant::getNullValue(Result->getType());
272       Result =
273           B.CreateSelect(B.CreateICmpSLT(Result, Zero), Zero, Result, "satmin");
274     }
275 
276     return CreateFixedToFixed(Result, CommonSema,
277                               LHSSema.getCommonSemantics(RHSSema));
278   }
279 
280   /// Multiply two fixed-point values and return the result in their common
281   /// semantic.
282   /// \p LHS     - The left hand side
283   /// \p LHSSema - The semantic of the left hand side
284   /// \p RHS     - The right hand side
285   /// \p RHSSema - The semantic of the right hand side
286   Value *CreateMul(Value *LHS, const FixedPointSemantics &LHSSema,
287                    Value *RHS, const FixedPointSemantics &RHSSema) {
288     auto CommonSema = getCommonBinopSemantic(LHSSema, RHSSema);
289     bool UseSigned = CommonSema.isSigned() || CommonSema.hasUnsignedPadding();
290 
291     Value *WideLHS = CreateFixedToFixed(LHS, LHSSema, CommonSema);
292     Value *WideRHS = CreateFixedToFixed(RHS, RHSSema, CommonSema);
293 
294     Intrinsic::ID IID;
295     if (CommonSema.isSaturated()) {
296       IID = UseSigned ? Intrinsic::smul_fix_sat : Intrinsic::umul_fix_sat;
297     } else {
298       IID = UseSigned ? Intrinsic::smul_fix : Intrinsic::umul_fix;
299     }
300     Value *Result = B.CreateIntrinsic(
301         IID, {WideLHS->getType()},
302         {WideLHS, WideRHS, B.getInt32(CommonSema.getScale())});
303 
304     return CreateFixedToFixed(Result, CommonSema,
305                               LHSSema.getCommonSemantics(RHSSema));
306   }
307 
308   /// Divide two fixed-point values and return the result in their common
309   /// semantic.
310   /// \p LHS     - The left hand side
311   /// \p LHSSema - The semantic of the left hand side
312   /// \p RHS     - The right hand side
313   /// \p RHSSema - The semantic of the right hand side
314   Value *CreateDiv(Value *LHS, const FixedPointSemantics &LHSSema,
315                    Value *RHS, const FixedPointSemantics &RHSSema) {
316     auto CommonSema = getCommonBinopSemantic(LHSSema, RHSSema);
317     bool UseSigned = CommonSema.isSigned() || CommonSema.hasUnsignedPadding();
318 
319     Value *WideLHS = CreateFixedToFixed(LHS, LHSSema, CommonSema);
320     Value *WideRHS = CreateFixedToFixed(RHS, RHSSema, CommonSema);
321 
322     Intrinsic::ID IID;
323     if (CommonSema.isSaturated()) {
324       IID = UseSigned ? Intrinsic::sdiv_fix_sat : Intrinsic::udiv_fix_sat;
325     } else {
326       IID = UseSigned ? Intrinsic::sdiv_fix : Intrinsic::udiv_fix;
327     }
328     Value *Result = B.CreateIntrinsic(
329         IID, {WideLHS->getType()},
330         {WideLHS, WideRHS, B.getInt32(CommonSema.getScale())});
331 
332     return CreateFixedToFixed(Result, CommonSema,
333                               LHSSema.getCommonSemantics(RHSSema));
334   }
335 
336   /// Left shift a fixed-point value by an unsigned integer value. The integer
337   /// value can be any bit width.
338   /// \p LHS     - The left hand side
339   /// \p LHSSema - The semantic of the left hand side
340   /// \p RHS     - The right hand side
341   Value *CreateShl(Value *LHS, const FixedPointSemantics &LHSSema, Value *RHS) {
342     bool UseSigned = LHSSema.isSigned() || LHSSema.hasUnsignedPadding();
343 
344     RHS = B.CreateIntCast(RHS, LHS->getType(), /*IsSigned=*/false);
345 
346     Value *Result;
347     if (LHSSema.isSaturated()) {
348       Intrinsic::ID IID = UseSigned ? Intrinsic::sshl_sat : Intrinsic::ushl_sat;
349       Result = B.CreateBinaryIntrinsic(IID, LHS, RHS);
350     } else {
351       Result = B.CreateShl(LHS, RHS);
352     }
353 
354     return Result;
355   }
356 
357   /// Right shift a fixed-point value by an unsigned integer value. The integer
358   /// value can be any bit width.
359   /// \p LHS     - The left hand side
360   /// \p LHSSema - The semantic of the left hand side
361   /// \p RHS     - The right hand side
362   Value *CreateShr(Value *LHS, const FixedPointSemantics &LHSSema, Value *RHS) {
363     RHS = B.CreateIntCast(RHS, LHS->getType(), false);
364 
365     return LHSSema.isSigned() ? B.CreateAShr(LHS, RHS) : B.CreateLShr(LHS, RHS);
366   }
367 
368   /// Compare two fixed-point values for equality.
369   /// \p LHS     - The left hand side
370   /// \p LHSSema - The semantic of the left hand side
371   /// \p RHS     - The right hand side
372   /// \p RHSSema - The semantic of the right hand side
373   Value *CreateEQ(Value *LHS, const FixedPointSemantics &LHSSema,
374                   Value *RHS, const FixedPointSemantics &RHSSema) {
375     auto CommonSema = getCommonBinopSemantic(LHSSema, RHSSema);
376 
377     Value *WideLHS = CreateFixedToFixed(LHS, LHSSema, CommonSema);
378     Value *WideRHS = CreateFixedToFixed(RHS, RHSSema, CommonSema);
379 
380     return B.CreateICmpEQ(WideLHS, WideRHS);
381   }
382 
383   /// Compare two fixed-point values for inequality.
384   /// \p LHS     - The left hand side
385   /// \p LHSSema - The semantic of the left hand side
386   /// \p RHS     - The right hand side
387   /// \p RHSSema - The semantic of the right hand side
388   Value *CreateNE(Value *LHS, const FixedPointSemantics &LHSSema,
389                   Value *RHS, const FixedPointSemantics &RHSSema) {
390     auto CommonSema = getCommonBinopSemantic(LHSSema, RHSSema);
391 
392     Value *WideLHS = CreateFixedToFixed(LHS, LHSSema, CommonSema);
393     Value *WideRHS = CreateFixedToFixed(RHS, RHSSema, CommonSema);
394 
395     return B.CreateICmpNE(WideLHS, WideRHS);
396   }
397 
398   /// Compare two fixed-point values as LHS < RHS.
399   /// \p LHS     - The left hand side
400   /// \p LHSSema - The semantic of the left hand side
401   /// \p RHS     - The right hand side
402   /// \p RHSSema - The semantic of the right hand side
403   Value *CreateLT(Value *LHS, const FixedPointSemantics &LHSSema,
404                   Value *RHS, const FixedPointSemantics &RHSSema) {
405     auto CommonSema = getCommonBinopSemantic(LHSSema, RHSSema);
406 
407     Value *WideLHS = CreateFixedToFixed(LHS, LHSSema, CommonSema);
408     Value *WideRHS = CreateFixedToFixed(RHS, RHSSema, CommonSema);
409 
410     return CommonSema.isSigned() ? B.CreateICmpSLT(WideLHS, WideRHS)
411                                  : B.CreateICmpULT(WideLHS, WideRHS);
412   }
413 
414   /// Compare two fixed-point values as LHS <= RHS.
415   /// \p LHS     - The left hand side
416   /// \p LHSSema - The semantic of the left hand side
417   /// \p RHS     - The right hand side
418   /// \p RHSSema - The semantic of the right hand side
419   Value *CreateLE(Value *LHS, const FixedPointSemantics &LHSSema,
420                   Value *RHS, const FixedPointSemantics &RHSSema) {
421     auto CommonSema = getCommonBinopSemantic(LHSSema, RHSSema);
422 
423     Value *WideLHS = CreateFixedToFixed(LHS, LHSSema, CommonSema);
424     Value *WideRHS = CreateFixedToFixed(RHS, RHSSema, CommonSema);
425 
426     return CommonSema.isSigned() ? B.CreateICmpSLE(WideLHS, WideRHS)
427                                  : B.CreateICmpULE(WideLHS, WideRHS);
428   }
429 
430   /// Compare two fixed-point values as LHS > RHS.
431   /// \p LHS     - The left hand side
432   /// \p LHSSema - The semantic of the left hand side
433   /// \p RHS     - The right hand side
434   /// \p RHSSema - The semantic of the right hand side
435   Value *CreateGT(Value *LHS, const FixedPointSemantics &LHSSema,
436                   Value *RHS, const FixedPointSemantics &RHSSema) {
437     auto CommonSema = getCommonBinopSemantic(LHSSema, RHSSema);
438 
439     Value *WideLHS = CreateFixedToFixed(LHS, LHSSema, CommonSema);
440     Value *WideRHS = CreateFixedToFixed(RHS, RHSSema, CommonSema);
441 
442     return CommonSema.isSigned() ? B.CreateICmpSGT(WideLHS, WideRHS)
443                                  : B.CreateICmpUGT(WideLHS, WideRHS);
444   }
445 
446   /// Compare two fixed-point values as LHS >= RHS.
447   /// \p LHS     - The left hand side
448   /// \p LHSSema - The semantic of the left hand side
449   /// \p RHS     - The right hand side
450   /// \p RHSSema - The semantic of the right hand side
451   Value *CreateGE(Value *LHS, const FixedPointSemantics &LHSSema,
452                   Value *RHS, const FixedPointSemantics &RHSSema) {
453     auto CommonSema = getCommonBinopSemantic(LHSSema, RHSSema);
454 
455     Value *WideLHS = CreateFixedToFixed(LHS, LHSSema, CommonSema);
456     Value *WideRHS = CreateFixedToFixed(RHS, RHSSema, CommonSema);
457 
458     return CommonSema.isSigned() ? B.CreateICmpSGE(WideLHS, WideRHS)
459                                  : B.CreateICmpUGE(WideLHS, WideRHS);
460   }
461 };
462 
463 } // end namespace llvm
464 
465 #endif // LLVM_IR_FIXEDPOINTBUILDER_H
466