1 //
2 // TestUtils.h
3 // MNN
4 //
5 // Created by MNN on 2019/01/15.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
9 #ifndef TestUtils_h
10 #define TestUtils_h
11
#include <assert.h>
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <functional>
#include <iostream>
#include <string>
#include <MNN/MNNForwardType.h>
#include <MNN/Tensor.hpp>
#include "core/Backend.hpp"
#include "MNN_generated.h"
/**
 * @brief dispatch payload on all available backends
 * @param payload test to perform; a callable taking one MNNForwardType argument
 *                (presumably the backend currently under test -- the definition
 *                is not in this header, confirm against the implementation)
 */
void dispatch(std::function<void(MNNForwardType)> payload);
/**
 * @brief dispatch payload on given backend
 * @param payload test to perform; a callable taking one MNNForwardType argument
 * @param backend given backend
 * NOTE(review): behaviour when the requested backend is unavailable is decided
 * by the definition, which is not visible in this header.
 */
void dispatch(std::function<void(MNNForwardType)> payload, MNNForwardType backend);
33
34 /**
35 @brief check the result with the ground truth
36 @param result data
37 @param rightData
38 @param size
39 @param threshold
40 */
41 template <typename T>
checkVector(const T * result,const T * rightData,int size,T threshold)42 bool checkVector(const T* result, const T* rightData, int size, T threshold){
43 MNN_ASSERT(result != nullptr);
44 MNN_ASSERT(rightData != nullptr);
45 MNN_ASSERT(size >= 0);
46 for(int i = 0; i < size; ++i){
47 if(fabs(result[i] - rightData[i]) > threshold){
48 std::cout << "right: " << rightData[i] << ", compute: " << result[i] << std::endl;
49 return false;
50 }
51 }
52 return true;
53 }
54
55 template <typename T>
checkVectorByRelativeError(const T * result,const T * rightData,int size,float rtol)56 bool checkVectorByRelativeError(const T* result, const T* rightData, int size, float rtol) {
57 MNN_ASSERT(result != nullptr);
58 MNN_ASSERT(rightData != nullptr);
59 MNN_ASSERT(size >= 0);
60
61 float maxValue = 0.0f;
62 for(int i = 0; i < size; ++i){
63 maxValue = fmax(fabs(rightData[i]), maxValue);
64 }
65 for(int i = 0; i < size; ++i){
66 if (fabs(result[i] - rightData[i]) > maxValue * rtol) {
67 std::cout << i << ": right: " << rightData[i] << ", compute: " << result[i] << std::endl;
68 return false;
69 }
70 }
71 return true;
72 }
73
74 #ifdef MNN_SUPPORT_BF16
// Simulate bf16: keep the upper 16 bits of the fp32 bit pattern and clear the
// lower 16 bits, i.e. truncate the trailing fp32 precision without rounding.
// The bits are moved with memcpy instead of a pointer cast so that the float
// is never accessed through an integer lvalue (which would break the
// strict-aliasing rule), and the mask is applied to an unsigned value.
inline float convertFP32Precision(float fp32Value) {
    static_assert(sizeof(float) == sizeof(uint32_t), "float is expected to be 32 bits wide");
    uint32_t bits = 0;
    memcpy(&bits, &fp32Value, sizeof(bits));
    bits &= 0xffff0000u;
    memcpy(&fp32Value, &bits, sizeof(fp32Value));
    return fp32Value;
}
81 #else
// Placeholder for fp16 simulation (used when MNN_SUPPORT_BF16 is not defined).
// No precision is removed yet: the argument is handed back unchanged.
inline float convertFP32Precision(float fp32Value) {
    // TODO: convert the exponent part and the fraction part to fp16 precision.
    const float passthrough = fp32Value;
    return passthrough;
}
87
88 #endif
89
// Identity conversion: the fp32 value is returned exactly as it was passed in.
// It has the same signature as convertFP32Precision.
inline float keepFP32Precision(float fp32Value) {
    const float untouched = fp32Value;
    return untouched;
}
93
94 using ConvertFP32 = float(*)(float fp32Value);
95 const static ConvertFP32 FP32Converter[MNN::BackendConfig::Precision_Low + 1] = {keepFP32Precision, keepFP32Precision, convertFP32Precision};
96
97 #endif /* TestUtils_h */
98