//
//  TensorUtils.cpp
//  MNN
//
//  Created by MNN on 2018/08/11.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include "core/TensorUtils.hpp"
#include <float.h>
#include <math.h>
#include <stdio.h>
#include <cmath>
#include <cstring>
#include "core/Backend.hpp"
#include "core/Macro.h"

namespace MNN {
Tensor::InsideDescribe* TensorUtils::getDescribe(const Tensor* tensor) {
    return tensor->mDescribe;
}
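
// Returns true when a virtual tensor's regions together cover every element of the
// tensor (the summed region volumes equal the tensor's element count); non-virtual
// tensors are always considered full.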
bool TensorUtils::regionIsFull(Tensor* input) {
    auto des = TensorUtils::getDescribe(input);
    if (des->memoryType != Tensor::InsideDescribe::MEMORY_VIRTUAL) {
        return true;
    }
    int size = 1;
    for (int i = 0; i < input->dimensions(); ++i) {
        size *= input->length(i);
    }
    int regionSize = 0;
    for (auto& region : des->regions) {
        regionSize += region.size[1] * region.size[0] * region.size[2];
    }
    return regionSize == size;
}

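// Builds a single region that views the whole tensor as one contiguous block:
// size[2] holds the total element count and the outer src/dst strides are set to that count.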
Tensor::InsideDescribe::Region TensorUtils::makeFullSlice(Tensor* input) {
    Tensor::InsideDescribe::Region totalSlice;
    totalSlice.src.offset = 0;
    totalSlice.dst.offset = 0;
    totalSlice.origin     = input;
    for (int i = 0; i < input->dimensions(); ++i) {
        totalSlice.size[2] *= input->length(i);
    }
    totalSlice.dst.stride[1] = totalSlice.size[2];
    totalSlice.dst.stride[0] = totalSlice.size[2];
    totalSlice.src.stride[1] = totalSlice.size[2];
    totalSlice.src.stride[0] = totalSlice.size[2];
    return totalSlice;
}
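
// Reinterprets a flat (1 x 1 x outside*axis*inside) slice as an
// (outside x axis x inside) slice, adjusting the src/dst strides accordingly;
// returns false when the existing sizes are incompatible with that reshape.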
bool TensorUtils::reshapeSlice(Tensor::InsideDescribe::Region& slice, int outside, int inside, int axis) {
    if (slice.size[1] == 1 && slice.size[0] == 1 && slice.size[2] == outside * inside * axis) {
        slice.size[0]       = outside;
        slice.size[2]       = inside;
        slice.size[1]       = axis;
        slice.dst.stride[0] = inside * axis;
        slice.dst.stride[1] = inside;

        auto originStride   = slice.src.stride[2];
        slice.src.stride[0] = originStride * inside * axis;
        slice.src.stride[1] = originStride * inside;
        return true;
    }
    if (slice.size[0] == outside && slice.size[1] == axis && slice.size[2] == inside) {
        return true;
    }
    return false;
}

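// Copies the shape of `tensor` into `wrapTensor`, permuting dimensions when converting
// between caffe-style (NCHW / NC4HW4) and tensorflow-style (NHWC / NHWC4) layouts,
// then recomputes a linear layout and copies the element type.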
void TensorUtils::setupTensorInfo(const Tensor* tensor, Tensor* wrapTensor, MNN_DATA_FORMAT mMidFormat) {
    TensorUtils::getDescribe(wrapTensor)->dimensionFormat = mMidFormat;
    auto tensorFormat                                     = TensorUtils::getDescribe(tensor)->dimensionFormat;
    bool originCaffeFormat = (tensorFormat == MNN_DATA_FORMAT_NCHW || tensorFormat == MNN_DATA_FORMAT_NC4HW4);
    bool wrapCaffeFormat   = (mMidFormat == MNN_DATA_FORMAT_NCHW || mMidFormat == MNN_DATA_FORMAT_NC4HW4);
    bool originTfFormat    = (tensorFormat == MNN_DATA_FORMAT_NHWC || tensorFormat == MNN_DATA_FORMAT_NHWC4);
    bool wrapTfFormat      = (mMidFormat == MNN_DATA_FORMAT_NHWC || mMidFormat == MNN_DATA_FORMAT_NHWC4);
    if ((originCaffeFormat && wrapCaffeFormat) || (originTfFormat && wrapTfFormat)) {
        TensorUtils::copyShape(tensor, wrapTensor);
    } else if (originCaffeFormat && wrapTfFormat) {
        for (int i = 1; i < wrapTensor->dimensions() - 1; ++i) {
            wrapTensor->setLength(i, tensor->length(i + 1));
        }
        wrapTensor->setLength(0, tensor->length(0));
        wrapTensor->setLength(wrapTensor->dimensions() - 1, tensor->length(1));
    } else if (originTfFormat && wrapCaffeFormat) {
        for (int i = 2; i < wrapTensor->dimensions(); ++i) {
            wrapTensor->setLength(i, tensor->length(i - 1));
        }
        wrapTensor->setLength(0, tensor->length(0));
        wrapTensor->setLength(1, tensor->length(tensor->dimensions() - 1));
    } else {
        // will not reach here
        MNN_ASSERT(false);
    }
    TensorUtils::setLinearLayout(wrapTensor);
    wrapTensor->buffer().type = tensor->getType();
}

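// Copies the dimension count and per-dimension extent/stride from `source` to `dest`;
// the dimension format is copied as well only when `copyFormat` is true.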
void TensorUtils::copyShape(const Tensor* source, Tensor* dest, bool copyFormat) {
    auto& ob      = dest->buffer();
    auto& ib      = source->buffer();
    ob.dimensions = ib.dimensions;
    ::memcpy(ob.dim, ib.dim, ib.dimensions * sizeof(halide_dimension_t));
    if (copyFormat) {
        getDescribe(dest)->dimensionFormat = getDescribe(source)->dimensionFormat;
    }
}

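// Sets `dest`'s shape to `alldims` and fills in tightly packed row-major strides.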
void TensorUtils::setShape(Tensor* dest, const std::vector<int>& alldims) {
    auto& ob      = dest->buffer();
    ob.dimensions = alldims.size();
    int stride = 1;
    for (int i = alldims.size() - 1; i >= 0; --i) {
        ob.dim[i].stride = stride;
        ob.dim[i].extent = alldims[i];
        stride *= alldims[i];
    }
    return;
}

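// Recomputes tightly packed strides from the innermost dimension outward. For NC4HW4
// tensors the channel extent (dimension 1) is rounded up to a multiple of 4 while
// accumulating strides, e.g. a 1x3x224x224 NC4HW4 tensor gets a batch stride of
// 4*224*224 rather than 3*224*224; the stored extents themselves are left unchanged.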
void TensorUtils::setLinearLayout(Tensor* tensor) {
    auto& buffer = tensor->buffer();
    int size     = 1;
    for (int i = 0; i < buffer.dimensions; ++i) {
        auto index  = buffer.dimensions - i - 1;
        auto extent = buffer.dim[index].extent;
        if (1 == index && tensor->mDescribe->dimensionFormat == MNN_DATA_FORMAT_NC4HW4) {
            extent = ROUND_UP(extent, 4);
        }
        buffer.dim[index].stride = size;
        size *= extent;
    }
}

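// For handle-typed tensors, releases every non-null handle with the registered free
// function and clears the slot; no-op for other element types or when there is no host data.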
void TensorUtils::clearHandleData(Tensor* tensor) {
    if (tensor->buffer().type.code != halide_type_handle) {
        return;
    }
    auto handle = tensor->host<void*>();
    if (nullptr == handle) {
        return;
    }

    MNN_ASSERT(tensor->mDescribe->extra.handleFreeFunction != nullptr);
    for (int i = 0; i < tensor->elementSize(); ++i) {
        if (nullptr != handle[i]) {
            tensor->mDescribe->extra.handleFreeFunction(handle[i]);
            handle[i] = nullptr;
        }
    }
}

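// Returns a host-side, planar (NCHW/NHWC) view of `source`. If the tensor already lives
// on CPU in a planar format it is returned unchanged; otherwise a new host tensor is
// created (the caller owns and must delete it) and the data is copied over, de-packing
// NC4HW4 through a CPU backend when the source is already on the host.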
static const Tensor* createHostPlanar(const Tensor* source) {
    // check
    auto bnType        = MNN_FORWARD_CPU;
    auto tensorBackend = TensorUtils::getDescribe(source)->backend;
    if (tensorBackend) {
        bnType = tensorBackend->type();
    }
    bool device = bnType != MNN_FORWARD_CPU;
    bool chunky = TensorUtils::getDescribe(source)->dimensionFormat == MNN_DATA_FORMAT_NC4HW4;

    // no conversion needed
    if (!device && !chunky) {
        return source;
    }

    // convert
    if (chunky) {
        Tensor* result = source->createHostTensorFromDevice(source, false);
        if (result->getDimensionType() == MNN::Tensor::TENSORFLOW) {
            TensorUtils::getDescribe(result)->dimensionFormat = MNN_DATA_FORMAT_NHWC;
        } else {
            TensorUtils::getDescribe(result)->dimensionFormat = MNN_DATA_FORMAT_NCHW;
        }
        TensorUtils::setLinearLayout(result);

        if (device) {
            source->copyToHostTensor(result);
        } else {
            Backend::Info info;
            info.type = MNN_FORWARD_CPU;
            std::shared_ptr<Runtime> runtime(MNNGetExtraRuntimeCreator(MNN_FORWARD_CPU)->onCreate(info));
            auto backend = runtime->onCreate();
            backend->onCopyBuffer(source, result);
            delete backend;
        }
        return result;
    } else {
        return source->createHostTensorFromDevice(source, true);
    }
}

template <typename T>
static void copyTensorToFloat(const Tensor* source, double* dest) {
    auto srcData = source->host<T>();
    auto size    = source->elementSize();
    for (int i = 0; i < size; ++i) {
        dest[i] = srcData[i];
    }
}

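// Element-wise comparison of two double buffers using a relative tolerance: values are
// considered equal when |a - b| / divisor <= tolerance, where the divisor is |b| per
// element, or max|b| over the whole buffer when `overall` is true. Pairs that are both
// infinite, or both smaller than `epsilon` in magnitude, are skipped.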
static bool equals(const double* pa, const double* pb, size_t size, double tolerance, double epsilon, bool overall,
                   bool prints) {
    // get max if using overall tolerance
    double max = fabs(pb[0]);
    if (overall) {
        for (int i = 1; i < size; i++) {
            max = std::max(max, fabs(pb[i]));
        }
    }

    // compare
    for (int i = 0; i < size; i++) {
        float va = pa[i], vb = pb[i];
        if (std::isinf(va) && std::isinf(vb)) {
            continue;
        }
        if (fabs(va) < epsilon && fabs(vb) < epsilon) {
            continue;
        }
        float div = overall ? max : fabsf(vb);
        if (fabsf(va - vb) / div > tolerance) {
            if (prints) {
                MNN_PRINT("%d: %f != %f\n", i, va, vb);
            }
            return false;
        }
    }
    return true;
}

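// Compares two tensors for approximate equality: element types and shapes must match
// exactly, then values are brought to the host, widened to double and checked with
// `equals` above. Illustrative call (hypothetical `output`/`expected` tensors):
//   bool ok = TensorUtils::compareTensors(output, expected, 0.01f, true, true, false);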
bool TensorUtils::compareTensors(const Tensor* compare, const Tensor* expect, float tolerance, bool overall,
                                 bool printsErrors, bool printsTensors) {
    // type
    if (compare->getType().code != expect->getType().code || compare->getType().bits != expect->getType().bits) {
        if (printsErrors) {
            MNN_PRINT("NOT equal in type: %d/%d - %d/%d.\n", compare->getType().code, compare->getType().bits,
                      expect->getType().code, expect->getType().bits);
        }
        return false;
    }

    // dimensions
    if (compare->dimensions() != expect->dimensions()) {
        if (printsErrors) {
            MNN_PRINT("NOT equal in dimensions: %d - %d.\n", compare->dimensions(), expect->dimensions());
        }
        return false;
    }
    for (int i = 0; i < compare->dimensions(); i++) {
        if (compare->length(i) == expect->length(i)) {
            continue;
        }
        if (printsErrors) {
            MNN_PRINT("NOT equal in dimensions[%d]: %d - %d.\n", i, compare->length(i), expect->length(i));
        }
        return false;
    }

    // convert to host if needed
    auto a = createHostPlanar(compare), b = createHostPlanar(expect);

    // get value as double
    auto size = expect->elementSize();
    std::vector<double> expectValue(expect->elementSize(), 0.0f);
    std::vector<double> compareValue(compare->elementSize(), 0.0f);

    auto result = false;
    if (b->buffer().type.code == halide_type_uint) {
        switch (b->buffer().type.bits) {
            case 8:
                copyTensorToFloat<uint8_t>(a, compareValue.data());
                copyTensorToFloat<uint8_t>(b, expectValue.data());
                break;
            case 16:
                copyTensorToFloat<uint16_t>(a, compareValue.data());
                copyTensorToFloat<uint16_t>(b, expectValue.data());
                break;
            case 32:
                copyTensorToFloat<uint32_t>(a, compareValue.data());
                copyTensorToFloat<uint32_t>(b, expectValue.data());
                break;
            case 64:
                copyTensorToFloat<uint64_t>(a, compareValue.data());
                copyTensorToFloat<uint64_t>(b, expectValue.data());
                break;
            default:
                break;
        }
    } else if (b->buffer().type.code == halide_type_int) {
        switch (b->buffer().type.bits) {
            case 8:
                copyTensorToFloat<int8_t>(a, compareValue.data());
                copyTensorToFloat<int8_t>(b, expectValue.data());
                break;
            case 16:
                copyTensorToFloat<int16_t>(a, compareValue.data());
                copyTensorToFloat<int16_t>(b, expectValue.data());
                break;
            case 32:
                copyTensorToFloat<int32_t>(a, compareValue.data());
                copyTensorToFloat<int32_t>(b, expectValue.data());
                break;
            case 64:
                copyTensorToFloat<int64_t>(a, compareValue.data());
                copyTensorToFloat<int64_t>(b, expectValue.data());
                break;
            default:
                break;
        }
    } else if (b->buffer().type.code == halide_type_float) {
        switch (b->buffer().type.bits) {
            case 32:
                copyTensorToFloat<float>(a, compareValue.data());
                copyTensorToFloat<float>(b, expectValue.data());
                break;
            default:
                break;
        }
    } else {
        if (printsErrors) {
            MNN_PRINT("unsupported data type.\n");
        }
    }
    auto epsilon = FLT_EPSILON;
    if ((NULL != compareValue.data()) && (NULL != expectValue.data())) {
        result = equals(compareValue.data(), expectValue.data(), size, tolerance, epsilon, overall, printsErrors);
    }

    // clean up
    if (a != compare) {
        delete a;
    }
    if (b != expect) {
        delete b;
    }
    return result;
}

// whether the region is a pure memory copy: for every dimension, either the src and
// dst strides match or the dimension is degenerate (size <= 1)
bool TensorUtils::isCopyRegion(const Tensor::InsideDescribe::Region& region) {
    bool eq = true;
    for (int i = 0; i < 3; i++) {
        eq &= ((region.src.stride[i] == region.dst.stride[i]) || (region.size[i] <= 1));
    }
    return eq;
}

// map a flat offset through a region's strides; with backward == true the mapping is
// inverted (dst offset -> src offset) by swapping the src and dst descriptors first
static inline int offsetCompute(Tensor::InsideDescribe::Region reg, int offset, bool backward) {
    if (backward) {
        auto tmp = reg.src;
        reg.src = reg.dst;
        reg.dst = tmp;
    }
    int res = 0;
    for (int i = 0; i < 3; i++) {
        if (reg.size[i] > 1) {
            res += offset / reg.src.stride[i] * reg.dst.stride[i];
            offset %= reg.src.stride[i];
        }
    }
    return res;
}
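// Example (illustrative numbers): with size = {2, 3, 4}, src.stride = {12, 4, 1} and
// dst.stride = {24, 8, 2}, a source offset of 17 = 1*12 + 1*4 + 1*1 maps to
// 1*24 + 1*8 + 1*2 = 34 on the destination side.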

// expand the src strides by splitting one dimension at expandValue; the dst strides
// and sizes are split to match (std::vector variant)
static inline bool expandSrc(std::vector<int>& src, std::vector<int>& dst, std::vector<int>& size, int expandValue) {
    if (expandValue <= 0) {
        return false;
    }
    for (int i = size.size()-1; i >= 0; i--) {
        int splitSize = expandValue / src[i];
        if (!(expandValue % src[i] || size[i] % splitSize)) {
            src.insert(src.begin()+i, expandValue);
            dst.insert(dst.begin()+i, splitSize * dst[i]);
            size[i] /= splitSize;
            size.insert(size.begin()+i+1, splitSize);
            return true;
        }
    }
    return false;
}
// expand stride and size in fixed 3-int arrays by splitting one dimension at expandValue
static inline bool expandStrideSize(int* src, int* dst, int* size, int& num, int expandValue) {
#define MNN_3_INT_INSERT(x, i, y) if (i == 2) { x[2] = y; } else if (i == 1) { x[2] = x[1]; x[1] = y; } else if (i == 0) { x[2] = x[1]; x[1] = x[0]; x[0] = y; } else { return false; }
    for (int i = num-1; i >= 0; i--) {
        int splitSize = expandValue / src[i];
        if (!(expandValue % src[i] || size[i] % splitSize)) {
            MNN_3_INT_INSERT(src, i, expandValue)
            MNN_3_INT_INSERT(dst, i, (splitSize * dst[i]))
            size[i] /= splitSize;
            MNN_3_INT_INSERT(size, (i+1), splitSize)
            if (++num > 3) return false;
            return true;
        }
    }
    return false;
#undef MNN_3_INT_INSERT
}
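// Example (illustrative numbers): src = {4, 1}, dst = {6, 1}, size = {8, 4}, num = 2 and
// expandValue = 16 splits the i = 0 entry (16 % 4 == 0 and 8 % (16/4) == 0), giving
// src = {16, 4, 1}, dst = {24, 6, 1}, size = {2, 4, 4} and num = 3.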

// fuse srcRegion and dstRegion into dstRegion if possible; returns true on success
bool TensorUtils::fuseRegion(Tensor::InsideDescribe::Region& srcReg, Tensor::InsideDescribe::Region& dstReg) {
    // src data is not the full data of dst
    if (srcReg.dst.offset > dstReg.src.offset ||
        srcReg.dst.stride[1] > srcReg.size[2] ||
        srcReg.dst.stride[2] > srcReg.size[1] * srcReg.size[2]) {
        return false;
    }
    int dstTotalSize = 1, srcTotalSize = 1;
    for (int i = 0; i < 3; i++) {
        if (dstReg.size[i] > 1) {
            dstTotalSize *= dstReg.size[i];
        }
        if (srcReg.size[i] > 1) {
            srcTotalSize *= srcReg.size[i];
        }
    }
    // src data is not the full data of dst
    if (dstTotalSize > srcTotalSize) {
        return false;
    }
    // don't handle the case of size > 1 with stride <= 0
    for (int i = 0; i < 3; i++) {
        if (srcReg.size[i] > 1 && (srcReg.src.stride[i] <= 0 || srcReg.dst.stride[i] <= 0)) {
            return false;
        }
        if (dstReg.size[i] > 1 && (dstReg.src.stride[i] <= 0 || dstReg.dst.stride[i] <= 0)) {
            return false;
        }
    }
    // src copy fuse
    if (isCopyRegion(srcReg)) {
        dstReg.origin = srcReg.origin;
        dstReg.src.offset += srcReg.src.offset - srcReg.dst.offset;
        return true;
    }
    // dst copy fuse
    if (isCopyRegion(dstReg) && dstTotalSize == srcTotalSize) {
        int srcOff = dstReg.src.offset - srcReg.dst.offset;
        int dstOff = dstReg.dst.offset;
        srcOff = offsetCompute(srcReg, srcOff, true) + srcReg.src.offset;
        if (srcReg.src.stride[2] > 0 && srcOff % srcReg.src.stride[2] != 0) {
            // for transpose + slice, the offset is not aligned, so the regions can't be fused
            return false;
        }
        dstReg.origin = srcReg.origin;
        dstReg.dst = srcReg.dst;
        dstReg.src = srcReg.src;
        dstReg.src.offset = srcOff;
        dstReg.dst.offset = dstOff;
        dstReg.size[0] = srcReg.size[0];
        dstReg.size[1] = srcReg.size[1];
        dstReg.size[2] = srcReg.size[2];
        return true;
    }
#define MNN_FAST_FUSE_WITHOUT_STL
#ifdef MNN_FAST_FUSE_WITHOUT_STL
    // general fuse
    int srcDst[3], srcSrc[3], dstSrc[3], dstDst[3], srcSize[3], dstSize[3], newSrc[3], dstStride[3], srcStride[3];
#define MNN_3_INT_INIT(x, y) { x[0] = y; x[1] = y; x[2] = y; }
    MNN_3_INT_INIT(dstStride, -1)
    MNN_3_INT_INIT(srcStride, -1)
#undef MNN_3_INT_INIT
    int srcNum = 0, dstNum = 0, sizeNum = 0;
    for (int i = 0; i < 3; i++) {
        if (srcReg.size[i] > 1) {
            srcStride[srcNum] = srcReg.dst.stride[i];
            srcDst[srcNum]    = srcReg.dst.stride[i];
            srcSrc[srcNum]    = srcReg.src.stride[i];
            srcSize[srcNum]   = srcReg.size[i];
            srcNum++;
        }
        if (dstReg.size[i] > 1) {
            dstStride[dstNum] = dstReg.src.stride[i];
            dstDst[dstNum]    = dstReg.dst.stride[i];
            dstSrc[dstNum]    = dstReg.src.stride[i];
            dstSize[dstNum]   = dstReg.size[i];
            dstNum++;
        }
    }
    sizeNum = dstNum;
#define MNN_3_INT_DIFF(r, x, y, i) if ((x[i] != y[0]) && (x[i] != y[1]) && (x[i] != y[2])) { if (r > 0) { return false; } else { r = x[i]; } }
    int srcExtra = -1, dstExtra = -1;
    MNN_3_INT_DIFF(srcExtra, srcStride, dstStride, 0)
    MNN_3_INT_DIFF(srcExtra, srcStride, dstStride, 1)
    MNN_3_INT_DIFF(srcExtra, srcStride, dstStride, 2)
    MNN_3_INT_DIFF(dstExtra, dstStride, srcStride, 0)
    MNN_3_INT_DIFF(dstExtra, dstStride, srcStride, 1)
    MNN_3_INT_DIFF(dstExtra, dstStride, srcStride, 2)
#undef MNN_3_INT_DIFF
    if (dstExtra > 0) {
        if (!expandStrideSize(srcDst, srcSrc, srcSize, srcNum, dstExtra)) {
            return false;
        }
    }
    if (srcExtra > 0) {
        if (!expandStrideSize(dstSrc, dstDst, dstSize, dstNum, srcExtra)) {
            return false;
        }
    }
    // reorder srcSrc into newSrc by aligning srcDst with dstSrc
    for (int i = 0; i < dstNum; i++) {
        int index = 0;
        for (int j = 0; j < srcNum; j++) {
            if (dstSrc[j] == srcDst[i]) {
                index = j;
            }
        }
        newSrc[index] = srcSrc[i];
    }
    // set the final size and record expandIdx when the expanded value is 1
    int expandIdx = -1;
    if (dstNum > sizeNum) {
        for (int i = 2; i >= 0; i--) {
            if (i < dstNum) {
                if (dstSize[i] == 1) {
                    expandIdx = i;
                }
                dstReg.size[i] = dstSize[i];
            } else {
                dstReg.size[i] = 1;
            }
        }
    }
#else
    // general fuse
    std::set<int> dstStride, srcStride, dstDiff, srcDiff;
    std::vector<int> dstDst, dstSrc, srcDst, srcSrc, newSrc, dstSize, srcSize;
    for (int i = 0; i < 3; i++) {
        if (srcReg.size[i] > 1) {
            srcStride.insert(srcReg.dst.stride[i]);
            srcDst.push_back(srcReg.dst.stride[i]);
            srcSrc.push_back(srcReg.src.stride[i]);
            srcSize.push_back(srcReg.size[i]);
        }
        if (dstReg.size[i] > 1) {
            dstStride.insert(dstReg.src.stride[i]);
            dstDst.push_back(dstReg.dst.stride[i]);
            dstSrc.push_back(dstReg.src.stride[i]);
            dstSize.push_back(dstReg.size[i]);
        }
    }
    int sizeNum = dstSize.size();
    std::set_difference(dstStride.begin(), dstStride.end(), srcStride.begin(), srcStride.end(), std::inserter(dstDiff, dstDiff.begin()));
    std::set_difference(srcStride.begin(), srcStride.end(), dstStride.begin(), dstStride.end(), std::inserter(srcDiff, srcDiff.begin()));
    if (dstDiff.size() > 1 || srcDiff.size() > 1) {
        // more than one differing stride isn't handled yet
        return false;
    }
    // expand strides when the middle tensor's strides differ
    if (!dstDiff.empty()) {
        if (!expandSrc(srcDst, srcSrc, srcSize, *dstDiff.begin())) {
            return false;
        }
    }
    if (!srcDiff.empty()) {
        if (!expandSrc(dstSrc, dstDst, dstSize, *srcDiff.begin())) {
            return false;
        }
    }
    if (dstSize.size() > 3) {
        // the region would need to be split, which isn't handled yet
        return false;
    }
    // reorder srcSrc into newSrc by aligning srcDst with dstSrc
    newSrc.resize(srcSrc.size());
    for (int i = 0; i < dstSrc.size(); i++) {
        int index = std::distance(dstSrc.begin(), std::find(dstSrc.begin(), dstSrc.end(), srcDst[i]));
        newSrc[index] = srcSrc[i];
    }
    // set the final size and record expandIdx when the expanded value is 1
    int expandIdx = -1;
    if (dstSize.size() > sizeNum) {
        for (int i = 2; i >= 0; i--) {
            if (i < dstSize.size()) {
                if (dstSize[i] == 1) {
                    expandIdx = i;
                }
                dstReg.size[i] = dstSize[i];
            } else {
                dstReg.size[i] = 1;
            }
        }
    }
#endif
    int idx = 0;
    for (int i = 0; i < 3; i++) {
        if (dstReg.size[i] > 1 || i == expandIdx) {
            dstReg.src.stride[i] = newSrc[idx];
            dstReg.dst.stride[i] = dstDst[idx++];
        }
    }
    dstReg.origin = srcReg.origin;
    dstReg.src.offset = offsetCompute(srcReg, dstReg.src.offset - srcReg.dst.offset, true) + srcReg.src.offset;
    return true;
}
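
// Pads tensors of rank < 4 by writing 1 into the unused dimension lengths, so code
// that indexes up to four dimensions reads a length of 1 for the absent ones.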
void TensorUtils::adjustTensorForCompability(Tensor* newTensor) {
    if (newTensor->dimensions() < 4) {
        for (int n = newTensor->dimensions(); n < 4; ++n) {
            newTensor->setLength(n, 1);
        }
    }
}

Tensor::DimensionType TensorUtils::getDimType(const Tensor* t) {
    auto format = TensorUtils::getDescribe(t)->dimensionFormat;
    switch (format) {
        case MNN_DATA_FORMAT_NCHW:
            return Tensor::CAFFE;
        case MNN_DATA_FORMAT_NC4HW4:
            return Tensor::CAFFE_C4;
        case MNN_DATA_FORMAT_NHWC:
            return Tensor::TENSORFLOW;
        default:
            break;
    }
    return Tensor::TENSORFLOW;
}

halide_type_t TensorUtils::DataTypeToHalideType(DataType t) {
    switch (t) {
        case DataType_DT_DOUBLE:
        case DataType_DT_FLOAT:
            return halide_type_of<float>();
        case DataType_DT_BFLOAT16:
            return halide_type_t(halide_type_float, 16);
        case DataType_DT_QINT32:
        case DataType_DT_INT32:
        case DataType_DT_BOOL:
        case DataType_DT_INT64:
            return halide_type_of<int32_t>();
        case DataType_DT_QINT8:
        case DataType_DT_INT8:
            return halide_type_of<int8_t>();
        case DataType_DT_QUINT8:
        case DataType_DT_UINT8:
            return halide_type_of<uint8_t>();
        case DataType_DT_QUINT16:
        case DataType_DT_UINT16:
            return halide_type_of<uint16_t>();
        case DataType_DT_QINT16:
        case DataType_DT_INT16:
            return halide_type_of<int16_t>();
        case DataType_DT_STRING:
        default:
            MNN_PRINT("Unsupported data type!");
            MNN_ASSERT(false);
            return halide_type_of<float>();
    }
}

DataType TensorUtils::HaildeTypeToDataType(halide_type_t t) {
    if (t == halide_type_of<int8_t>()) {
        return DataType_DT_INT8;
    }
    if (t == halide_type_of<int16_t>()) {
        return DataType_DT_INT16;
    }
    if (t == halide_type_of<int32_t>()) {
        return DataType_DT_INT32;
    }
    if (t == halide_type_of<int64_t>()) {
        return DataType_DT_INT64;
    }
    if (t == halide_type_of<uint8_t>()) {
        return DataType_DT_UINT8;
    }
    if (t == halide_type_of<uint16_t>()) {
        return DataType_DT_UINT16;
    }
    if (t == halide_type_t(halide_type_float, 16)) {
        return DataType_DT_BFLOAT16;
    }
    if (t == halide_type_of<float>()) {
        return DataType_DT_FLOAT;
    }
    if (t == halide_type_of<double>()) {
        return DataType_DT_DOUBLE;
    }
    MNN_PRINT("Unsupported data type!");
    MNN_ASSERT(false);
    return DataType_DT_INVALID;
}
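
// Returns {scale, zero, min, max} from the tensor's quantization attributes, or the
// defaults {0, 0, -127, 127} when no quantization info is attached.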
std::vector<float> TensorUtils::getQuantInfo(const Tensor* t) {
    auto quantAttr = getDescribe(t)->quantAttr;
    float scale = quantAttr ? quantAttr->scale : 0.0f;
    float zero  = quantAttr ? quantAttr->zero : 0.0f;
    float min   = quantAttr ? quantAttr->min : -127.0f;
    float max   = quantAttr ? quantAttr->max : 127.0f;
    return {scale, zero, min, max};
}
} // namespace MNN