//
//  ShapeStridedSlice.cpp
//  MNN
//
//  Created by MNN on 2019/01/10.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include <algorithm>
#include <array>
#include "shape/SizeComputer.hpp"
#include "core/Macro.h"
#include "core/TensorUtils.hpp"
namespace MNN {
class StridedSliceComputer : public SizeComputer {
public:
    virtual bool onComputeSize(const MNN::Op *op, const std::vector<Tensor *> &inputs,
                               const std::vector<Tensor *> &outputs) const override {
        MNN_ASSERT(4 == inputs.size());
        MNN_ASSERT(1 == outputs.size());

        Tensor *input = inputs[0];
        const int inputDim = input->buffer().dimensions;
        if (inputDim <= 0 || inputDim > MNN_MAX_TENSOR_DIM) {
            return false;
        }
        auto parameter = op->main_as_StridedSliceParam();
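        // Per-axis meaning of the five bit masks; the semantics appear to mirror the
        // TensorFlow StridedSlice convention, which the code below implements:
        //   beginMask bit i set      -> ignore begin[i] and start at the beginning of the axis
        //   endMask bit i set        -> ignore end[i] and stop at the end of the axis
        //   shrinkAxisMask bit i set -> take a single element and drop the axis from the output
        //   newAxisMask bit i set    -> insert a new axis of length 1 at this position
        //   ellipsisMask bit i set   -> this spec entry stands for all otherwise unspecified axes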
        int32_t beginMask = parameter->beginMask();
        int32_t endMask = parameter->endMask();
        int32_t shrinkAxisMask = parameter->shrinkAxisMask();
        int32_t ellipsisMask = parameter->ellipsisMask();
        int32_t newAxisMask = parameter->newAxisMask();
        if (ellipsisMask && (ellipsisMask & (ellipsisMask - 1))) {
            MNN_ERROR("only one non-zero bit is allowed in ellipsisMask\n");
            return false;
        }

        Tensor *begin = inputs[1];
        Tensor *end = inputs[2];
        Tensor *strided = inputs[3];
        auto output = outputs[0];

        MNN_ASSERT(begin->buffer().dimensions == end->buffer().dimensions &&
                   begin->buffer().dimensions == strided->buffer().dimensions);

        int32_t inputShape[MNN_MAX_TENSOR_DIM] = { 0 };
        int32_t begins[MNN_MAX_TENSOR_DIM] = { 0 };
        int32_t ends[MNN_MAX_TENSOR_DIM] = { 0 };
        int32_t strides[MNN_MAX_TENSOR_DIM] = { 0 };
        int32_t beginMasks[MNN_MAX_TENSOR_DIM] = { 0 };
        int32_t endMasks[MNN_MAX_TENSOR_DIM] = { 0 };
        int32_t shrinkAxisMasks[MNN_MAX_TENSOR_DIM] = { 0 };
        int32_t newAxisMasks[MNN_MAX_TENSOR_DIM] = { 0 };
        int strideSize = begin->length(0);
        for (int i = 0; i < inputDim; i++) {
            inputShape[i] = input->length(i);
        }
        for (int i = 0; i < strideSize; i++) {
            beginMasks[i] = beginMask & (1 << i);
        }
        for (int i = 0; i < strideSize; i++) {
            endMasks[i] = endMask & (1 << i);
        }
        for (int i = 0; i < strideSize; i++) {
            shrinkAxisMasks[i] = shrinkAxisMask & (1 << i);
        }
        for (int i = 0; i < strideSize; i++) {
            newAxisMasks[i] = newAxisMask & (1 << i);
        }

        // deal with ellipsis: expand the stride info to cover every input dimension
        if (ellipsisMask > 0) {
            int32_t beginMasksTmp[MNN_MAX_TENSOR_DIM] = { 0 };
            int32_t endMasksTmp[MNN_MAX_TENSOR_DIM] = { 0 };
            int32_t shrinkAxisMasksTmp[MNN_MAX_TENSOR_DIM] = { 0 };
            int32_t newAxisMasksTmp[MNN_MAX_TENSOR_DIM] = { 0 };
            // expand stride info
            int ellipsisPos = -1;
            for (int i = 0; i < strideSize; i++) {
                int temp = ellipsisMask & (1 << i);
                if (temp != 0) {
                    ellipsisPos = i;
                    break;
                }
            }
            MNN_ASSERT(ellipsisPos >= 0 && ellipsisPos < strideSize);
            /*
             Example: foo's dim is [2, 3, 4, 5, 6, 7], foo[0:2, :, 3:5, 3:6]:
             1. strideSize = 4, inputDim = 6, ellipsisMask = 2 (0010)
             2. left part: 0:2, right part: 3:5, 3:6
             3. expand: foo[0:2, 0:3, 0:4, 3:5, 3:6]
             */
            int ellipsisSize = inputDim - strideSize, strideIdx = 0;
            for (int i = 0; i < inputDim; i++) {
                if (i == ellipsisPos) {
                    strideIdx++;
                }
                if (i >= ellipsisPos && i <= ellipsisPos + ellipsisSize) {
                    begins[i] = 0;
                    ends[i] = inputShape[i];
                    strides[i] = 1;
                    beginMasksTmp[i] = 0;
                    endMasksTmp[i] = 0;
                    shrinkAxisMasksTmp[i] = 0;
                } else {
                    begins[i] = begin->host<int32_t>()[strideIdx];
                    ends[i] = end->host<int32_t>()[strideIdx];
                    strides[i] = strided->host<int32_t>()[strideIdx];
                    beginMasksTmp[i] = beginMasks[strideIdx];
                    endMasksTmp[i] = endMasks[strideIdx];
                    shrinkAxisMasksTmp[i] = shrinkAxisMasks[strideIdx];
                    newAxisMasksTmp[i] = newAxisMasks[strideIdx++];
                }
            }
            for (int i = 0; i < inputDim; i++) {
                beginMasks[i] = beginMasksTmp[i];
                endMasks[i] = endMasksTmp[i];
                shrinkAxisMasks[i] = shrinkAxisMasksTmp[i];
                newAxisMasks[i] = newAxisMasksTmp[i];
            }
            strideSize = inputDim;
        } else {
            for (int i = 0; i < strideSize; i++) {
                begins[i] = begin->host<int>()[i];
                ends[i] = end->host<int>()[i];
                strides[i] = strided->host<int>()[i];
            }
        }

        int32_t beginShape[MNN_MAX_TENSOR_DIM];
        int32_t endShape[MNN_MAX_TENSOR_DIM];
        int32_t stridedShape[MNN_MAX_TENSOR_DIM];
        int32_t outputShape[MNN_MAX_TENSOR_DIM];
        int32_t outputShapeShrinked[MNN_MAX_TENSOR_DIM];

        int outputShapeSize = 0;
        int outputShapeShrinkSize = 0;
        int strideDealDims = 0;

        // Clamp a begin/end index into the valid range of a dimension of size dimSize
        // and wrap negative indices; `exclusive` widens the range by one for the
        // exclusive end index.
        auto beginAndEndShapeLimit = [](int shape, int dimSize, bool exclusive) -> int {
            int maxShape = dimSize - 1, minShape = -dimSize;
            if (exclusive) {
                ++maxShape;
                --minShape;
            }
            shape = (shape > maxShape ? maxShape : shape);
            shape = (shape < minShape ? minShape : shape);
            if (shape < 0) {
                shape += dimSize;
            }
            return shape;
        };

        int inputDimOffset = 0;
        for (int i = 0; i < strideSize; i++) {
            if (newAxisMasks[i] > 0) {
                outputShape[outputShapeSize] = 1;
                outputShapeSize++;
                outputShapeShrinked[outputShapeShrinkSize] = 1;
                outputShapeShrinkSize++;
                continue;
            }
            auto inputDim = inputShape[inputDimOffset++];
            strideDealDims++;
            if (beginMasks[i] > 0) {
                beginShape[i] = 0;
            } else {
                beginShape[i] = std::min(inputDim, begins[i]);
            }
            if (beginShape[i] < 0) {
                beginShape[i] += input->buffer().dim[i].extent;
            }
            if (endMasks[i] > 0) {
                endShape[i] = inputDim;
            } else {
                endShape[i] = beginAndEndShapeLimit(ends[i], inputDim, true);
            }
            stridedShape[i] = shrinkAxisMasks[i] > 0 ? 1 : strides[i];

            if (endShape[i] < beginShape[i]) {
                int t = beginShape[i];
                beginShape[i] = endShape[i];
                endShape[i] = t;

                MNN_ASSERT(stridedShape[i] != 0);
                if (stridedShape[i] < 0) {
                    stridedShape[i] = -stridedShape[i];
                } else {
                    // MNN_ASSERT(false); // TODO: should be the wrong case, but there is one in
                    // linfeng's faster rcnn face model
                    beginShape[i] = endShape[i]; // TODO: temp solution
                }
            }

            if (shrinkAxisMasks[i] == 0) {
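                // Number of elements sampled from [beginShape, endShape) with the given
                // stride, i.e. ceil((end - begin) / stride) for a positive stride.
                // Worked example: begin = 1, end = 7, stride = 2 picks indices 1, 3, 5,
                // and (7 - 1 - 1) / 2 + 1 = 3.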
                int size = (endShape[i] - beginShape[i] - 1) / stridedShape[i] + 1;
                outputShape[outputShapeSize] = size;
                outputShapeSize++;
                outputShapeShrinked[outputShapeShrinkSize] = size;
                outputShapeShrinkSize++;
            } else {
                outputShape[outputShapeSize] = std::min(1, inputDim);
                outputShapeSize++;
            }
        }

        int outputDimensionsWithoutRemain = strideDealDims;
        int dimensionRemained = input->buffer().dimensions - strideDealDims;

        for (int i = 0; i < dimensionRemained; i++) {
            outputShapeShrinked[outputShapeShrinkSize] = input->buffer().dim[outputDimensionsWithoutRemain + i].extent;
            outputShapeShrinkSize++;
        }

        output->buffer().dimensions = outputShapeShrinkSize;
        output->buffer().type = input->buffer().type;

        for (int i = 0; i < outputShapeShrinkSize; i++) {
            output->buffer().dim[i].extent = outputShapeShrinked[i];
        }
        TensorUtils::getDescribe(outputs[0])->dimensionFormat = TensorUtils::getDescribe(inputs[0])->dimensionFormat;
        return true;
    }
};

REGISTER_SHAPE_INPUTS(StridedSliceComputer, OpType_StridedSlice, (std::vector<int>{1,2,3}));
} // namespace MNN