1 //===-- MPFRUtils.h ---------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H
10 #define LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H
11 
12 #include "utils/CPP/TypeTraits.h"
13 #include "utils/UnitTest/Test.h"
14 
15 #include <stdint.h>
16 
17 namespace __llvm_libc {
18 namespace testing {
19 namespace mpfr {
20 
// Identifies the math operation checked by an MPFR matcher.
//
// Enumerators are grouped by the shape of an operation's inputs and outputs.
// Each group is bracketed by a Begin*/End* pair of sentinel enumerators that
// are not operations themselves. isValidOperation tests
// `Begin* < op && op < End*` to find an operation's group, so a new operation
// must be added strictly between its group's sentinels.
enum class Operation : int {
  // Operations which take a single floating point number as input
  // and produce a single floating point number as output. The input
  // and output floating point numbers are of the same kind.
  BeginUnaryOperationsSingleOutput,
  Abs,
  Ceil,
  Cos,
  Exp,
  Exp2,
  Expm1,
  Floor,
  Mod2PI,
  ModPIOver2,
  ModPIOver4,
  Round,
  Sin,
  Sqrt,
  Tan,
  Trunc,
  EndUnaryOperationsSingleOutput,

  // Operations which take a single floating point number as input
  // but produce two outputs. The first output is a floating point
  // number of the same type as the input. The second output is of type
  // 'int'.
  BeginUnaryOperationsTwoOutputs,
  Frexp, // The first (floating point) output is the fractional part.
  EndUnaryOperationsTwoOutputs,

  // Operations which take two floating point numbers of the same type as
  // input and produce a single floating point number of the same type as
  // output.
  BeginBinaryOperationsSingleOutput,
  Hypot,
  EndBinaryOperationsSingleOutput,

  // Operations which take two floating point numbers of the same type as
  // input and produce two outputs. The first output is a floating point number
  // of the same type as the inputs. The second output is of type 'int'.
  BeginBinaryOperationsTwoOutputs,
  RemQuo, // The first (floating point) output is the remainder.
  EndBinaryOperationsTwoOutputs,

  // Operations which take three floating point numbers of the same type as
  // input and produce a single floating point number of the same type as
  // output.
  // NOTE: "Ouput" in the next sentinel is a misspelling; the name is kept
  // because isValidOperation refers to it.
  BeginTernaryOperationsSingleOuput,
  Fma,
  EndTernaryOperationsSingleOutput,
};
72 
73 template <typename T> struct BinaryInput {
74   static_assert(
75       __llvm_libc::cpp::IsFloatingPointType<T>::Value,
76       "Template parameter of BinaryInput must be a floating point type.");
77 
78   using Type = T;
79   T x, y;
80 };
81 
82 template <typename T> struct TernaryInput {
83   static_assert(
84       __llvm_libc::cpp::IsFloatingPointType<T>::Value,
85       "Template parameter of TernaryInput must be a floating point type.");
86 
87   using Type = T;
88   T x, y, z;
89 };
90 
// The pair of results produced by operations that return a floating point
// value together with an 'int' (for example Frexp and RemQuo).
template <typename T> struct BinaryOutput {
  // The floating point result.
  T f;
  // The integral result.
  int i;
};
95 
96 namespace internal {
97 
98 template <typename T1, typename T2>
99 struct AreMatchingBinaryInputAndBinaryOutput {
100   static constexpr bool value = false;
101 };
102 
103 template <typename T>
104 struct AreMatchingBinaryInputAndBinaryOutput<BinaryInput<T>, BinaryOutput<T>> {
105   static constexpr bool value = cpp::IsFloatingPointType<T>::Value;
106 };
107 
// Comparison routines behind MPFRMatcher::match. Their definitions are not in
// this header. Each routine takes:
//   - the operation `op`;
//   - the input(s) that were passed to the libc function;
//   - the output(s) the libc function produced;
//   - a tolerance `t` (MPFRMatcher passes its ulpTolerance here).
// The returned bool is used directly as the matcher's result.
// NOTE(review): the expected value is presumably computed with MPFR in the
// definitions; confirm there.

// Unary operation with a single floating point output (T -> T).
template <typename T>
bool compareUnaryOperationSingleOutput(Operation op, T input, T libcOutput,
                                       double t);
// Unary operation producing a floating point and an 'int' output.
template <typename T>
bool compareUnaryOperationTwoOutputs(Operation op, T input,
                                     const BinaryOutput<T> &libcOutput,
                                     double t);
// Binary operation producing a floating point and an 'int' output.
template <typename T>
bool compareBinaryOperationTwoOutputs(Operation op, const BinaryInput<T> &input,
                                      const BinaryOutput<T> &libcOutput,
                                      double t);

// Binary operation with a single floating point output.
template <typename T>
bool compareBinaryOperationOneOutput(Operation op, const BinaryInput<T> &input,
                                     T libcOutput, double t);

// Ternary operation with a single floating point output.
template <typename T>
bool compareTernaryOperationOneOutput(Operation op,
                                      const TernaryInput<T> &input,
                                      T libcOutput, double t);
128 
// Error reporters behind MPFRMatcher::explainError. Their definitions are not
// in this header. Each receives the operation `op`, the input(s), the libc
// output(s) recorded by the matcher (`matchValue`), and the stream `OS` to
// write the explanation to.
// NOTE(review): presumably invoked by the test framework only after a failed
// match; confirm against Test.h.

// Unary operation with a single floating point output (T -> T).
template <typename T>
void explainUnaryOperationSingleOutputError(Operation op, T input, T matchValue,
                                            testutils::StreamWrapper &OS);
// Unary operation producing a floating point and an 'int' output.
template <typename T>
void explainUnaryOperationTwoOutputsError(Operation op, T input,
                                          const BinaryOutput<T> &matchValue,
                                          testutils::StreamWrapper &OS);
// Binary operation producing a floating point and an 'int' output.
template <typename T>
void explainBinaryOperationTwoOutputsError(Operation op,
                                           const BinaryInput<T> &input,
                                           const BinaryOutput<T> &matchValue,
                                           testutils::StreamWrapper &OS);

// Binary operation with a single floating point output.
template <typename T>
void explainBinaryOperationOneOutputError(Operation op,
                                          const BinaryInput<T> &input,
                                          T matchValue,
                                          testutils::StreamWrapper &OS);

// Ternary operation with a single floating point output.
template <typename T>
void explainTernaryOperationOneOutputError(Operation op,
                                           const TernaryInput<T> &input,
                                           T matchValue,
                                           testutils::StreamWrapper &OS);
153 
154 template <Operation op, typename InputType, typename OutputType>
155 class MPFRMatcher : public testing::Matcher<OutputType> {
156   InputType input;
157   OutputType matchValue;
158   double ulpTolerance;
159 
160 public:
161   MPFRMatcher(InputType testInput, double ulpTolerance)
162       : input(testInput), ulpTolerance(ulpTolerance) {}
163 
164   bool match(OutputType libcResult) {
165     matchValue = libcResult;
166     return match(input, matchValue, ulpTolerance);
167   }
168 
169   void explainError(testutils::StreamWrapper &OS) override {
170     explainError(input, matchValue, OS);
171   }
172 
173 private:
174   template <typename T> static bool match(T in, T out, double tolerance) {
175     return compareUnaryOperationSingleOutput(op, in, out, tolerance);
176   }
177 
178   template <typename T>
179   static bool match(T in, const BinaryOutput<T> &out, double tolerance) {
180     return compareUnaryOperationTwoOutputs(op, in, out, tolerance);
181   }
182 
183   template <typename T>
184   static bool match(const BinaryInput<T> &in, T out, double tolerance) {
185     return compareBinaryOperationOneOutput(op, in, out, tolerance);
186   }
187 
188   template <typename T>
189   static bool match(BinaryInput<T> in, const BinaryOutput<T> &out,
190                     double tolerance) {
191     return compareBinaryOperationTwoOutputs(op, in, out, tolerance);
192   }
193 
194   template <typename T>
195   static bool match(const TernaryInput<T> &in, T out, double tolerance) {
196     return compareTernaryOperationOneOutput(op, in, out, tolerance);
197   }
198 
199   template <typename T>
200   static void explainError(T in, T out, testutils::StreamWrapper &OS) {
201     explainUnaryOperationSingleOutputError(op, in, out, OS);
202   }
203 
204   template <typename T>
205   static void explainError(T in, const BinaryOutput<T> &out,
206                            testutils::StreamWrapper &OS) {
207     explainUnaryOperationTwoOutputsError(op, in, out, OS);
208   }
209 
210   template <typename T>
211   static void explainError(const BinaryInput<T> &in, const BinaryOutput<T> &out,
212                            testutils::StreamWrapper &OS) {
213     explainBinaryOperationTwoOutputsError(op, in, out, OS);
214   }
215 
216   template <typename T>
217   static void explainError(const BinaryInput<T> &in, T out,
218                            testutils::StreamWrapper &OS) {
219     explainBinaryOperationOneOutputError(op, in, out, OS);
220   }
221 
222   template <typename T>
223   static void explainError(const TernaryInput<T> &in, T out,
224                            testutils::StreamWrapper &OS) {
225     explainTernaryOperationOneOutputError(op, in, out, OS);
226   }
227 };
228 
229 } // namespace internal
230 
231 // Return true if the input and ouput types for the operation op are valid
232 // types.
233 template <Operation op, typename InputType, typename OutputType>
234 constexpr bool isValidOperation() {
235   return (Operation::BeginUnaryOperationsSingleOutput < op &&
236           op < Operation::EndUnaryOperationsSingleOutput &&
237           cpp::IsSame<InputType, OutputType>::Value &&
238           cpp::IsFloatingPointType<InputType>::Value) ||
239          (Operation::BeginUnaryOperationsTwoOutputs < op &&
240           op < Operation::EndUnaryOperationsTwoOutputs &&
241           cpp::IsFloatingPointType<InputType>::Value &&
242           cpp::IsSame<OutputType, BinaryOutput<InputType>>::Value) ||
243          (Operation::BeginBinaryOperationsSingleOutput < op &&
244           op < Operation::EndBinaryOperationsSingleOutput &&
245           cpp::IsFloatingPointType<OutputType>::Value &&
246           cpp::IsSame<InputType, BinaryInput<OutputType>>::Value) ||
247          (Operation::BeginBinaryOperationsTwoOutputs < op &&
248           op < Operation::EndBinaryOperationsTwoOutputs &&
249           internal::AreMatchingBinaryInputAndBinaryOutput<InputType,
250                                                           OutputType>::value) ||
251          (Operation::BeginTernaryOperationsSingleOuput < op &&
252           op < Operation::EndTernaryOperationsSingleOutput &&
253           cpp::IsFloatingPointType<OutputType>::Value &&
254           cpp::IsSame<InputType, TernaryInput<OutputType>>::Value);
255 }
256 
257 template <Operation op, typename InputType, typename OutputType>
258 __attribute__((no_sanitize("address")))
259 cpp::EnableIfType<isValidOperation<op, InputType, OutputType>(),
260                   internal::MPFRMatcher<op, InputType, OutputType>>
261 getMPFRMatcher(InputType input, OutputType outputUnused, double t) {
262   return internal::MPFRMatcher<op, InputType, OutputType>(input, t);
263 }
264 
// Rounding direction passed to Round and RoundToLong.
enum class RoundingMode : uint8_t {
  Upward,
  Downward,
  TowardZero,
  Nearest,
};
266 
// Rounds x using the given rounding mode. Defined outside this header.
// NOTE(review): presumably rounds to an integral value of type T via MPFR
// (a reference for the libc rounding functions); confirm in the definition.
template <typename T> T Round(T x, RoundingMode mode);

// Rounds x to a long, storing the value in `result`. Defined outside this
// header; the meaning of the returned bool is established by the definition
// and should be checked there before relying on it.
// NOTE(review): the overload without `mode` presumably uses a default
// rounding behavior; confirm which one.
template <typename T> bool RoundToLong(T x, long &result);
template <typename T> bool RoundToLong(T x, RoundingMode mode, long &result);
271 
272 } // namespace mpfr
273 } // namespace testing
274 } // namespace __llvm_libc
275 
// Checks that `matchValue`, the libc result for `input`, agrees with the
// result expected for Operation `op` to within `tolerance` ULPs.
// getMPFRMatcher is only viable for type combinations that isValidOperation
// accepts, so an invalid input/output pairing fails to compile.
// NOTE: `matchValue` is expanded twice (once as the checked value, once to
// deduce the output type), so an expression with side effects may be
// evaluated more than once; prefer passing a pre-computed variable.
#define EXPECT_MPFR_MATCH(op, input, matchValue, tolerance)                    \
  EXPECT_THAT(matchValue, __llvm_libc::testing::mpfr::getMPFRMatcher<op>(      \
                              input, matchValue, tolerance))

// Same as EXPECT_MPFR_MATCH but built on ASSERT_THAT instead of EXPECT_THAT
// (the same double expansion of `matchValue` applies).
#define ASSERT_MPFR_MATCH(op, input, matchValue, tolerance)                    \
  ASSERT_THAT(matchValue, __llvm_libc::testing::mpfr::getMPFRMatcher<op>(      \
                              input, matchValue, tolerance))
283 
284 #endif // LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H
285