1 //
2 //  TensorUtils.hpp
3 //  MNN
4 //
5 //  Created by MNN on 2019/01/23.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #ifndef TensorUtils_hpp
10 #define TensorUtils_hpp
11 
12 #include <MNN/Tensor.hpp>
13 #include "Tensor_generated.h"
14 
15 #ifdef CONSTANT
16 #undef CONSTANT
17 #endif // CONSTANT
18 
19 namespace MNN {
20 class Backend;
/** attributes describing a TensorArray (a dynamically-sized list of tensors) */
struct TensorArrayAttr {
    // whether the array size may grow at runtime (true) or is fixed (false)
    bool isDynamicSize = false;
    // whether every element shares one shape (true) or each element keeps its own (false)
    bool isIdenticalShape = false;
    // the number of elements currently in the array
    uint32_t arraySize = 0;
    // per-element shapes; NOTE(review): presumably holds a single entry when
    // isIdenticalShape is true and one entry per element otherwise — confirm with users
    std::vector<std::vector<int>> elemShape;
};
31 struct QuantAttr {
32     float scale;
33     float zero = 0.0f;
34     float min  = -127.0f;
35     float max  = 127.0f;
36     DataType type = DataType_DT_INT8;
37 };
38 /** extra tensor info container */
/** extra tensor info container; backend-internal description attached to a Tensor */
struct Tensor::InsideDescribe {
public:
    /** dimension format (layout) of the tensor's data */
    MNN_DATA_FORMAT dimensionFormat = MNN_DATA_FORMAT_NC4HW4;
    union {
        /** separate memory offset */
        int offset;

        /** function used to free handle */
        void (*handleFreeFunction)(void*);
    } extra;

    enum MemoryType {
        /** The tensor's memory comes from Backend */
        MEMORY_BACKEND = 0,

        /** host memory owned by the tensor itself */
        MEMORY_HOST,

        /** The tensor doesn't have memory of its own
            (NOTE(review): presumably its content is described via `regions` — confirm) */
        MEMORY_VIRTUAL,

        /** memory supplied from outside; the tensor does not own it */
        MEMORY_OUTSIDE,

    };
    MemoryType memoryType = MEMORY_BACKEND;
    /** for DEVICE tensor only. backend used to manage tensor's device memory. */
    Backend* backend = nullptr;
    /** for DEVICE tensor only. reference count of consumers still using this tensor. */
    int useCount = 0;
    /** role of the tensor within a computation graph */
    enum Usage {
        NORMAL,
        INPUT,
        OUTPUT,
        CONSTANT,
        /** Whether the tensor is a trainable parameter. Trainable parameter should be stored in a different area. */
        TRAINABLE,
    };
    Usage usage = NORMAL;
    /** 3-D view into a buffer: a start offset plus one stride per dimension */
    struct View {
        int32_t offset = 0;
        int32_t stride[3] = {1, 1, 1};
    };
    /** a copy mapping from a source view to a destination view over `size` elements */
    struct Region {
        View src;
        View dst;
        int32_t size[3] = {1, 1, 1};
        // source tensor the region reads from.
        // NOTE(review): left uninitialized here — callers must set it before use; confirm
        Tensor* origin;
    };
    /** for MEMORY_VIRTUAL-style tensors: the regions composing this tensor's content */
    std::vector<Region> regions;
    /** per-dimension extents/strides, valid up to the tensor's dimension count */
    halide_dimension_t dims[MNN_MAX_TENSOR_DIM];
    // TensorArray Attribute (null when the tensor is not a TensorArray)
    std::shared_ptr<TensorArrayAttr> tensorArrayAttr;
    // Tensor Quant Attribute (null when the tensor carries no quantization info)
    std::shared_ptr<QuantAttr> quantAttr;
};
96 typedef Tensor::InsideDescribe::Usage TensorUsage;
97 
/** tensor utils: static helpers for inspecting and manipulating Tensor metadata */
class MNN_PUBLIC TensorUtils {
public:
    /**
     * @brief get extra tensor info.
     * @param tensor    given tensor.
     * @return extra tensor info.
     */
    static Tensor::InsideDescribe* getDescribe(const Tensor* tensor);

    /**
     * @brief copy shape from source tensor to dest tensor.
     * @param source        shape provider tensor.
     * @param dest          shape consumer tensor.
     * @param copyFormat    copy data format or not.
     */
    static void copyShape(const Tensor* source, Tensor* dest, bool copyFormat = false);

    /**
     * @brief set shape for dest tensor from a common int vector.
     * @param dest          shape consumer tensor.
     * @param alldims       dims info.
     */
    static void setShape(Tensor* dest, const std::vector<int>& alldims);

    /**
     * auto update tensor's strides according to extents and reorder flags.
     * @param tensor    given tensor.
     */
    static void setLinearLayout(Tensor* tensor);

    /**
     * @brief call handle free function to clear handle of tensor.
     * @param tensor    given tensor.
     */
    static void clearHandleData(Tensor* tensor);

    /**
     * @brief compare tensor to expected with tolerance.
     * @param compareTensor comparing tensor.
     * @param toTensor      expected tensor.
     * @param tolerance     tolerable error, any error less than this value will be ignored.
     *                      for integer types, compare with `abs(v1 - v2) > tolerance`;
     *                      for float types, see `overallTolerance`.
     * @param overall       for float types only. compare with `abs(v1 - v2) / max(abs(allExpectValues))` if true,
     *                      `abs(v1 - v2) / abs(v2)` otherwise.
     * @param printsError   print error data or not.
     * @param printsTensors print tensor data or not when meets error.
     * @return equals within tolerance or not.
     */
    static bool compareTensors(const Tensor* compareTensor, const Tensor* toTensor, float tolerance = 0,
                               bool overall = false, bool printsError = true, bool printsTensors = false);

    // copy describe info from tensor into wrapTensor using the given intermediate format.
    // NOTE(review): exact semantics defined in the .cpp — confirm before relying on details.
    static void setupTensorInfo(const Tensor* tensor, Tensor* wrapTensor, MNN_DATA_FORMAT mMidFormat);
    // build a Region that covers the whole of `input` (a full-tensor slice).
    static Tensor::InsideDescribe::Region makeFullSlice(Tensor* input);
    // whether the tensor's regions collectively cover its full extent.
    static bool regionIsFull(Tensor* input);
    // whether the region describes a plain contiguous copy (no gather/transpose).
    static bool isCopyRegion(const Tensor::InsideDescribe::Region& region);
    // reshape a region slice along `axis` into (outside, inside) factors; returns success.
    static bool reshapeSlice(Tensor::InsideDescribe::Region& slice, int outside, int inside, int axis);
    // try to merge srcReg into dstReg so a single region performs both copies; returns success.
    static bool fuseRegion(Tensor::InsideDescribe::Region& srcReg, Tensor::InsideDescribe::Region& dstReg);
    // adjust tensor's describe info for backward compatibility across MNN versions.
    static void adjustTensorForCompability(Tensor* t);
    // map a tensor's data format to the public Tensor::DimensionType.
    static Tensor::DimensionType getDimType(const Tensor* t);
    // convert a flatbuffer DataType to the equivalent halide_type_t.
    static halide_type_t DataTypeToHalideType(DataType t);
    // convert a halide_type_t to the equivalent flatbuffer DataType.
    // (name keeps historical "Hailde" spelling; renaming would break callers)
    static DataType HaildeTypeToDataType(halide_type_t t);
    // extract quantization info (e.g. scale/zero/min/max) from the tensor's QuantAttr.
    // NOTE(review): element order defined in the .cpp — confirm before indexing.
    static std::vector<float> getQuantInfo(const Tensor* t);
};
163 } // namespace MNN
164 
#endif /* TensorUtils_hpp */
166