1 // 2 // TensorUtils.hpp 3 // MNN 4 // 5 // Created by MNN on 2019/01/23. 6 // Copyright © 2018, Alibaba Group Holding Limited 7 // 8 9 #ifndef TensorUtils_hpp 10 #define TensorUtils_hpp 11 12 #include <MNN/Tensor.hpp> 13 #include "Tensor_generated.h" 14 15 #ifdef CONSTANT 16 #undef CONSTANT 17 #endif // CONSTANT 18 19 namespace MNN { 20 class Backend; 21 struct TensorArrayAttr { 22 // array size is dynamic or not 23 bool isDynamicSize = false; 24 // elemShape is identical or not 25 bool isIdenticalShape = false; 26 // the number of element 27 uint32_t arraySize = 0; 28 // the shape of element 29 std::vector<std::vector<int>> elemShape; 30 }; 31 struct QuantAttr { 32 float scale; 33 float zero = 0.0f; 34 float min = -127.0f; 35 float max = 127.0f; 36 DataType type = DataType_DT_INT8; 37 }; 38 /** extra tensor info container */ 39 struct Tensor::InsideDescribe { 40 public: 41 /** dimension format */ 42 MNN_DATA_FORMAT dimensionFormat = MNN_DATA_FORMAT_NC4HW4; 43 union { 44 /** Serperate memory offset*/ 45 int offset; 46 47 /** function used to free handle */ 48 void (*handleFreeFunction)(void*); 49 } extra; 50 51 enum MemoryType { 52 /** The tensor's memory come from Backend */ 53 MEMORY_BACKEND = 0, 54 55 /** host memory is owned by tensor or not */ 56 MEMORY_HOST, 57 58 /** The tensor don't has memory */ 59 MEMORY_VIRTUAL, 60 61 /** host memory is owned by tensor or not */ 62 MEMORY_OUTSIDE, 63 64 }; 65 MemoryType memoryType = MEMORY_BACKEND; 66 /** for DEVICE tensor only. backend used to manage tensor's device memory. */ 67 Backend* backend = nullptr; 68 /** for DEVICE tensor only. */ 69 int useCount = 0; 70 enum Usage { 71 NORMAL, 72 INPUT, 73 OUTPUT, 74 CONSTANT, 75 /** Whether the tensor is a trainable parameter. Trainable parameter should be stored in a different area. */ 76 TRAINABLE, 77 }; 78 Usage usage = NORMAL; 79 struct View { 80 int32_t offset = 0; 81 int32_t stride[3] = {1, 1, 1}; 82 }; 83 struct Region { 84 View src; 85 View dst; 86 int32_t size[3] = {1, 1, 1}; 87 Tensor* origin; 88 }; 89 std::vector<Region> regions; 90 halide_dimension_t dims[MNN_MAX_TENSOR_DIM]; 91 // TensorArray Attribute 92 std::shared_ptr<TensorArrayAttr> tensorArrayAttr; 93 // Tensor Quant Attribute 94 std::shared_ptr<QuantAttr> quantAttr; 95 }; 96 typedef Tensor::InsideDescribe::Usage TensorUsage; 97 98 /** tensor utils */ 99 class MNN_PUBLIC TensorUtils { 100 public: 101 /** 102 * @brief get extra tensor info. 103 * @param tensor given tensor. 104 * @return extra tensor info. 105 */ 106 static Tensor::InsideDescribe* getDescribe(const Tensor* tensor); 107 108 /** 109 * @brief copy shape from source tensor to dest tensor. 110 * @param source shape prodiver tensor. 111 * @param dest shape consumer tensor. 112 * @param copyFormat copy data format or not. 113 */ 114 static void copyShape(const Tensor* source, Tensor* dest, bool copyFormat = false); 115 116 /** 117 * @brief set shape for dest tensor from a common int vector. 118 * @param dest shape consumer tensor. 119 * @param alldims dims info. 120 */ 121 static void setShape(Tensor* dest, const std::vector<int>& alldims); 122 123 /** 124 * auto update tensor's strides according to extents and reorder flags. 125 * @param tensor given tensor. 126 */ 127 static void setLinearLayout(Tensor* tensor); 128 129 /** 130 * @brief call handle free function to clear handle of tensor. 131 * @param tensor given tensor. 132 */ 133 static void clearHandleData(Tensor* tensor); 134 135 /** 136 * @brief compare tensor to expected with tolerance. 137 * @param compareTensor comparing tensor. 138 * @param toTensor expected tensor. 139 * @param tolerance tolerable error, any error less than this value will be ignored. 140 * for integer types, compare with `abs(v1 - v2) > tolerance`; 141 * for float types, see `overallTolerance`. 142 * @param overall for float types only. compare with `abs(v1 - v2) / max(abs(allExpectValues))` if true, 143 * `abs(v1 - v2) / abs(v2)` otherwise. 144 * @param printsError print error data or not. 145 * @param printsTensors print tensor data or not when meets error. 146 * @return equals within tolerance or not. 147 */ 148 static bool compareTensors(const Tensor* compareTensor, const Tensor* toTensor, float tolerance = 0, 149 bool overall = false, bool printsError = true, bool printsTensors = false); 150 151 static void setupTensorInfo(const Tensor* tensor, Tensor* wrapTensor, MNN_DATA_FORMAT mMidFormat); 152 static Tensor::InsideDescribe::Region makeFullSlice(Tensor* input); 153 static bool regionIsFull(Tensor* input); 154 static bool isCopyRegion(const Tensor::InsideDescribe::Region& region); 155 static bool reshapeSlice(Tensor::InsideDescribe::Region& slice, int outside, int inside, int axis); 156 static bool fuseRegion(Tensor::InsideDescribe::Region& srcReg, Tensor::InsideDescribe::Region& dstReg); 157 static void adjustTensorForCompability(Tensor* t); 158 static Tensor::DimensionType getDimType(const Tensor* t); 159 static halide_type_t DataTypeToHalideType(DataType t); 160 static DataType HaildeTypeToDataType(halide_type_t t); 161 static std::vector<float> getQuantInfo(const Tensor* t); 162 }; 163 } // namespace MNN 164 165 #endif /* TensorDescribe_hpp */ 166