1 //
2 //  TestUtils.h
3 //  MNN
4 //
5 //  Created by MNN on 2019/01/15.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #ifndef TestUtils_h
10 #define TestUtils_h
11 
#include <assert.h>
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <functional>
#include <iostream>
#include <string>
#include <MNN/MNNForwardType.h>
#include <MNN/Tensor.hpp>
#include "core/Backend.hpp"
#include "MNN_generated.h"
/**
 * @brief dispatch payload on all available backends
 * @param payload   test to perform; a callable taking an MNNForwardType argument
 * @note declaration only -- the definition is not part of this header.
 */
void dispatch(std::function<void(MNNForwardType)> payload);
/**
 * @brief dispatch payload on given backend
 * @param payload   test to perform; a callable taking an MNNForwardType argument
 * @param backend   given backend (presumably the value handed to payload -- confirm in the definition)
 * @note declaration only -- the definition is not part of this header.
 */
void dispatch(std::function<void(MNNForwardType)> payload, MNNForwardType backend);
33 
/**
 @brief check the result against the ground truth using an absolute tolerance
 @param result      computed data, `size` elements; must not be null
 @param rightData   ground-truth data, `size` elements; must not be null
 @param size        number of elements to compare; must be >= 0
 @param threshold   largest accepted value of |result[i] - rightData[i]|
 @return true when every element is within threshold; false at the first
         mismatch, which is also printed to std::cout. A NaN in either array is
         always a mismatch; two identical infinities are a match.
 */
template <typename T>
bool checkVector(const T* result, const T* rightData, int size, T threshold) {
    MNN_ASSERT(result != nullptr);
    MNN_ASSERT(rightData != nullptr);
    MNN_ASSERT(size >= 0);
    for (int i = 0; i < size; ++i) {
        const auto diff = fabs(result[i] - rightData[i]);
        // Every ordered comparison with NaN is false, so a plain `diff > threshold`
        // would silently accept NaN output. diff is NaN (diff != diff) either when an
        // operand is NaN -- then result != rightData and we report a mismatch -- or
        // when both operands are the same infinity, which still counts as equal.
        const bool mismatch = (diff != diff) ? !(result[i] == rightData[i]) : (diff > threshold);
        if (mismatch) {
            std::cout << "right: " << rightData[i] << ", compute: " << result[i] << std::endl;
            return false;
        }
    }
    return true;
}
54 
/**
 @brief check the result against the ground truth using a relative tolerance
 @param result      computed data, `size` elements; must not be null
 @param rightData   ground-truth data, `size` elements; must not be null
 @param size        number of elements to compare; must be >= 0
 @param rtol        relative tolerance; the accepted absolute error for every
                    element is rtol * max_i |rightData[i]| (relative to the largest
                    reference magnitude, not to each element)
 @return true when every element is within tolerance; false at the first
         mismatch, which is also printed to std::cout together with its index.
         A NaN in result or rightData at a compared element is always a mismatch;
         two identical infinities are a match.
 */
template <typename T>
bool checkVectorByRelativeError(const T* result, const T* rightData, int size, float rtol) {
    MNN_ASSERT(result != nullptr);
    MNN_ASSERT(rightData != nullptr);
    MNN_ASSERT(size >= 0);

    // Largest magnitude found in the reference data.
    float maxValue = 0.0f;
    for (int i = 0; i < size; ++i) {
        maxValue = fmax(fabs(rightData[i]), maxValue);
    }
    // Loop-invariant absolute tolerance derived from the relative one.
    const float tolerance = maxValue * rtol;
    for (int i = 0; i < size; ++i) {
        const auto diff = fabs(result[i] - rightData[i]);
        // Every ordered comparison with NaN is false, so a plain `diff > tolerance`
        // would silently accept NaN output. diff is NaN (diff != diff) either when an
        // operand is NaN -- then result != rightData and we report a mismatch -- or
        // when both operands are the same infinity, which still counts as equal.
        const bool mismatch = (diff != diff) ? !(result[i] == rightData[i]) : (diff > tolerance);
        if (mismatch) {
            std::cout << i << ": right: " << rightData[i] << ", compute: " << result[i] << std::endl;
            return false;
        }
    }
    return true;
}
73 
#ifdef MNN_SUPPORT_BF16
/**
 @brief simulate bf16: prune the trailing fp32 precision down to bf16 precision
 @param fp32Value   value to reduce
 @return fp32Value with the low 16 bits of its bit pattern cleared
         (plain truncation, no rounding)
 */
inline float convertFP32Precision(float fp32Value) {
    static_assert(sizeof(uint32_t) == sizeof(float), "float is expected to be 32 bits wide");
    // Go through memcpy instead of casting the float's address to an integer
    // pointer: reading a float object through an int lvalue breaks the aliasing rules.
    uint32_t bits = 0;
    memcpy(&bits, &fp32Value, sizeof(bits));
    bits &= 0xffff0000u; // keep only the upper 16 bits of the pattern
    memcpy(&fp32Value, &bits, sizeof(bits));
    return fp32Value;
}
#else
/**
 @brief simulate fp16 -- not implemented yet, the value is returned unchanged
 @param fp32Value   value to reduce
 @return fp32Value as given
 */
inline float convertFP32Precision(float fp32Value) {
    // todo: convert exp part and fraction part.
    return fp32Value;
}

#endif
89 
/**
 @brief identity converter: hands the fp32 value back untouched
 @param fp32Value   value to pass through
 @return fp32Value as given
 */
inline float keepFP32Precision(float fp32Value) {
    const float untouched = fp32Value;
    return untouched;
}
93 
94 using ConvertFP32 = float(*)(float fp32Value);
95 const static ConvertFP32 FP32Converter[MNN::BackendConfig::Precision_Low + 1] = {keepFP32Precision, keepFP32Precision, convertFP32Precision};
96 
97 #endif /* TestUtils_h */
98