1 // 2 // ShapeLinSpace.cpp 3 // MNN 4 // 5 // Created by MNN on 2019/12/11. 6 // Copyright © 2018, Alibaba Group Holding Limited 7 // 8 9 #include "shape/SizeComputer.hpp" 10 #include "core/Macro.h" 11 #include "core/TensorUtils.hpp" 12 13 namespace MNN { 14 15 class LinSpaceSizeComputer : public SizeComputer { onComputeSize(const MNN::Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs) const16 virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs, 17 const std::vector<Tensor*>& outputs) const override { 18 MNN_ASSERT(inputs.size() == 3); 19 MNN_ASSERT(outputs.size() == 1); 20 auto& ib1 = inputs[0]->buffer(); 21 auto& ib2 = inputs[1]->buffer(); 22 auto& ib3 = inputs[2]->buffer(); 23 auto& ob = outputs[0]->buffer(); 24 MNN_ASSERT(ib1.dimensions == 0); 25 MNN_ASSERT(ib2.dimensions == 0); 26 MNN_ASSERT(ib3.dimensions == 0); 27 28 MNN_ASSERT(inputs[0]->getType() == halide_type_of<float>()); 29 MNN_ASSERT(inputs[1]->getType() == halide_type_of<float>()); 30 MNN_ASSERT(inputs[2]->getType() == halide_type_of<int32_t>()); 31 32 int num = inputs[2]->host<int32_t>()[0]; 33 MNN_ASSERT(num > 0); 34 35 ob.dimensions = 1; 36 ob.dim[0].extent = num; 37 outputs[0]->setType(DataType_DT_FLOAT); 38 TensorUtils::getDescribe(outputs[0])->dimensionFormat = TensorUtils::getDescribe(inputs[0])->dimensionFormat; 39 40 return true; 41 } 42 }; 43 44 REGISTER_SHAPE_INPUTS(LinSpaceSizeComputer, OpType_LinSpace, {2}); 45 } // namespace MNN 46