1 //===-- MPFRUtils.h ---------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H
10 #define LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H
11 
12 #include "utils/CPP/TypeTraits.h"
13 #include "utils/UnitTest/Test.h"
14 
15 #include <stdint.h>
16 
17 namespace __llvm_libc {
18 namespace testing {
19 namespace mpfr {
20 
// Encodes a tolerance as a small bit pattern that is scaled relative to a
// base precision (see the formula on |bits|).
struct Tolerance {
  // Count of bits that hold the fractional part of a 'float' value.
  constexpr static unsigned floatPrecision = 23;

  // Count of bits that hold the fractional part of a 'double' value.
  constexpr static unsigned doublePrecision = 52;

  // Base precision of the number being checked; e.g. |floatPrecision|
  // when the values are of type float.
  unsigned basePrecision;

  // How many of the least significant bits of |bits| are valid.
  unsigned width;

  // Bit pattern of the tolerance. The tolerance it stands for is
  //   sum(bits[width - i] * 2 ^ (- basePrecision - i))
  // taken over |i| in the range [1, width].
  uint32_t bits;
};
41 
// Selects the operation whose result is being checked. Enumerators are
// numbered consecutively starting at 0, so their order is significant:
// inserting a new entry anywhere but the end renumbers the ones after it.
enum class Operation : int {
  Abs = 0,
  Ceil,
  Cos,
  Exp,
  Exp2,
  Floor,
  Round,
  Sin,
  Trunc,
};
53 
54 namespace internal {
55 
// Checks |libcOutput|, the libc result obtained for |op| applied to |input|,
// using the tolerance |t|. Only the declaration lives in this header;
// MPFRMatcher::match forwards its arguments here and returns the result
// unchanged.
// NOTE(review): presumably returns true when |libcOutput| lies within |t| of
// the MPFR-computed reference value -- confirm against the definition.
template <typename T>
bool compare(Operation op, T input, T libcOutput, const Tolerance &t);
58 
59 template <typename T> class MPFRMatcher : public testing::Matcher<T> {
60   static_assert(__llvm_libc::cpp::IsFloatingPointType<T>::Value,
61                 "MPFRMatcher can only be used with floating point values.");
62 
63   Operation operation;
64   T input;
65   Tolerance tolerance;
66   T matchValue;
67 
68 public:
MPFRMatcher(Operation op,T testInput,Tolerance & t)69   MPFRMatcher(Operation op, T testInput, Tolerance &t)
70       : operation(op), input(testInput), tolerance(t) {}
71 
match(T libcResult)72   bool match(T libcResult) {
73     matchValue = libcResult;
74     return internal::compare(operation, input, libcResult, tolerance);
75   }
76 
77   void explainError(testutils::StreamWrapper &OS) override;
78 };
79 
80 } // namespace internal
81 
82 template <typename T>
83 __attribute__((no_sanitize("address")))
getMPFRMatcher(Operation op,T input,Tolerance t)84 internal::MPFRMatcher<T> getMPFRMatcher(Operation op, T input, Tolerance t) {
85   static_assert(
86       __llvm_libc::cpp::IsFloatingPointType<T>::Value,
87       "getMPFRMatcher can only be used to match floating point results.");
88   return internal::MPFRMatcher<T>(op, input, t);
89 }
90 
91 } // namespace mpfr
92 } // namespace testing
93 } // namespace __llvm_libc
94 
// Expands to EXPECT_THAT with |matchValue| as the value under test and a
// matcher built by getMPFRMatcher(op, input, tolerance). Note the argument
// order: |matchValue| is the third macro argument but is passed first to
// EXPECT_THAT.
// NOTE(review): EXPECT_THAT is not defined in this header; presumably it comes
// from the included unit test header -- confirm.
#define EXPECT_MPFR_MATCH(op, input, matchValue, tolerance)                    \
  EXPECT_THAT(matchValue, __llvm_libc::testing::mpfr::getMPFRMatcher(          \
                              op, input, tolerance))
98 
// Expands to ASSERT_THAT with |matchValue| as the value under test and a
// matcher built by getMPFRMatcher(op, input, tolerance). Note the argument
// order: |matchValue| is the third macro argument but is passed first to
// ASSERT_THAT.
// NOTE(review): ASSERT_THAT is not defined in this header; presumably it comes
// from the included unit test header and stops the test on failure -- confirm.
#define ASSERT_MPFR_MATCH(op, input, matchValue, tolerance)                    \
  ASSERT_THAT(matchValue, __llvm_libc::testing::mpfr::getMPFRMatcher(          \
                              op, input, tolerance))
102 
103 #endif // LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H
104