1 // 2 // ShapeConst.cpp 3 // MNN 4 // 5 // Created by MNN on 2019/01/10. 6 // Copyright © 2018, Alibaba Group Holding Limited 7 // 8 9 #include "shape/SizeComputer.hpp" 10 #include "core/Macro.h" 11 12 namespace MNN { 13 class ConstComputer : public SizeComputer { 14 public: onComputeSize(const MNN::Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs) const15 virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs, 16 const std::vector<Tensor*>& outputs) const override { 17 //MNN_ASSERT(0 == inputs.size()); 18 MNN_ASSERT(1 == outputs.size()); 19 20 // copy dims 21 auto output = outputs[0]; 22 auto parameter = op->main_as_Blob(); 23 24 output->buffer().dimensions = parameter->dims() ? parameter->dims()->size() : 0; 25 for (int i = 0; i < output->buffer().dimensions; i++) { 26 output->buffer().dim[i].extent = parameter->dims()->Get(i); 27 } 28 if (parameter->dataType() == DataType_DT_HALF) { 29 output->setType(DataType_DT_FLOAT); 30 } else { 31 output->setType(parameter->dataType()); 32 } 33 TensorUtils::getDescribe(output)->dimensionFormat = parameter->dataFormat(); 34 35 return true; 36 } 37 }; 38 39 REGISTER_SHAPE(ConstComputer, OpType_Const); 40 REGISTER_SHAPE(ConstComputer, OpType_TrainableParam); 41 42 } // namespace MNN 43