1 //===-- MPFRUtils.h ---------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #ifndef LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H
10 #define LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H
11
12 #include "utils/CPP/TypeTraits.h"
13 #include "utils/UnitTest/Test.h"
14
15 #include <stdint.h>
16
17 namespace __llvm_libc {
18 namespace testing {
19 namespace mpfr {
20
// Describes the error allowed when the MPFR matchers in this file compare a
// libc result (it is passed through to internal::compare).
struct Tolerance {
  // Number of bits used to represent the fractional
  // part of a value of type 'float'.
  static constexpr unsigned int floatPrecision = 23;

  // Number of bits used to represent the fractional
  // part of a value of type 'double'.
  static constexpr unsigned int doublePrecision = 52;

  // The base precision of the number. For example, for values of
  // type float, the base precision is the value |floatPrecision|.
  unsigned int basePrecision;

  // Number of valid least significant bits in |bits|. Only the low |width|
  // bits of |bits| contribute to the tolerance value. Since |bits| is a
  // uint32_t, |width| is presumably expected to be at most 32.
  unsigned int width;

  // The bits in the tolerance value. Numbering the bits of |bits| from the
  // least significant end (bit 0), the tolerance value is
  //   sum(bits[width - i] * 2^(-basePrecision - i)) for i in [1, width],
  // so the most significant valid bit, bits[width - 1], has weight
  // 2^(-basePrecision - 1).
  uint32_t bits;
};
41
// Operations that can be named when building an MPFR matcher (they are
// forwarded to internal::compare). Enumerators are numbered consecutively
// from 0 in declaration order.
enum class Operation : int {
  Abs = 0,
  Ceil = 1,
  Cos = 2,
  Exp = 3,
  Exp2 = 4,
  Floor = 5,
  Round = 6,
  Sin = 7,
  Trunc = 8,
};
53
54 namespace internal {
55
56 template <typename T>
57 bool compare(Operation op, T input, T libcOutput, const Tolerance &t);
58
59 template <typename T> class MPFRMatcher : public testing::Matcher<T> {
60 static_assert(__llvm_libc::cpp::IsFloatingPointType<T>::Value,
61 "MPFRMatcher can only be used with floating point values.");
62
63 Operation operation;
64 T input;
65 Tolerance tolerance;
66 T matchValue;
67
68 public:
MPFRMatcher(Operation op,T testInput,Tolerance & t)69 MPFRMatcher(Operation op, T testInput, Tolerance &t)
70 : operation(op), input(testInput), tolerance(t) {}
71
match(T libcResult)72 bool match(T libcResult) {
73 matchValue = libcResult;
74 return internal::compare(operation, input, libcResult, tolerance);
75 }
76
77 void explainError(testutils::StreamWrapper &OS) override;
78 };
79
80 } // namespace internal
81
82 template <typename T>
83 __attribute__((no_sanitize("address")))
getMPFRMatcher(Operation op,T input,Tolerance t)84 internal::MPFRMatcher<T> getMPFRMatcher(Operation op, T input, Tolerance t) {
85 static_assert(
86 __llvm_libc::cpp::IsFloatingPointType<T>::Value,
87 "getMPFRMatcher can only be used to match floating point results.");
88 return internal::MPFRMatcher<T>(op, input, t);
89 }
90
91 } // namespace mpfr
92 } // namespace testing
93 } // namespace __llvm_libc
94
// Test helpers: check |matchValue| (the libc result) against |op| applied to
// |input| with tolerance |tolerance|. The matcher is built by getMPFRMatcher.
// EXPECT_MPFR_MATCH expands to EXPECT_THAT; ASSERT_MPFR_MATCH expands to
// ASSERT_THAT.
#define EXPECT_MPFR_MATCH(op, input, matchValue, tolerance)                    \
  EXPECT_THAT(matchValue,                                                      \
              __llvm_libc::testing::mpfr::getMPFRMatcher(op, input, tolerance))

#define ASSERT_MPFR_MATCH(op, input, matchValue, tolerance)                    \
  ASSERT_THAT(matchValue,                                                      \
              __llvm_libc::testing::mpfr::getMPFRMatcher(op, input, tolerance))
102
103 #endif // LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H
104