1 #define CATCH_CONFIG_RUNNER
2 #include "test.h"
3 
4 #include <cmath>
5 
main(int argc,char ** argv)6 int main(int argc, char ** argv) {
7   return Catch::Session().run(argc, argv);
8 }
9 
10 namespace intgemm {
11 
CompareMSE(const float * float_ref,const float * int_ref,const float * int_test,std::size_t size,std::string test_info,float int_tolerance,float float_tolerance,float MSE_float_tolerance,float MSE_int_tolerance)12 void CompareMSE(const float *float_ref, const float *int_ref, const float *int_test, std::size_t size, std::string test_info,
13              float int_tolerance, float float_tolerance, float MSE_float_tolerance, float MSE_int_tolerance) {
14   float int_sum = 0.0, float_sum = 0.0;
15   for (std::size_t i = 0; i < size; ++i) {
16     float int_diff = int_ref[i] - int_test[i];
17     float float_diff = float_ref[i] - int_test[i];
18     CHECK_MESSAGE(std::fabs(int_diff) <= int_tolerance, test_info << "Inaccurate compared to int reference at " << i << ' ' << int_ref[i] << ' ' << int_test[i]);
19     CHECK_MESSAGE(std::fabs(float_diff) <= float_tolerance, test_info << "Inaccurate compared to float reference at " << i << ' ' << float_ref[i] << ' ' << int_test[i]);
20     int_sum += int_diff * int_diff;
21     float_sum += float_diff * float_diff;
22   }
23   CHECK_MESSAGE(std::fabs(sqrt(float_sum / size)) <= MSE_float_tolerance, test_info << "Float MSE = " << sqrt(float_sum / size));
24   CHECK_MESSAGE(std::fabs(sqrt(int_sum / size)) <= MSE_int_tolerance, test_info << "Int MSE = " << sqrt(int_sum / size));
25 }
26 
27 } // namespace intgemm
28