1 //
2 //  ShapeConst.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2019/01/10.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "shape/SizeComputer.hpp"
10 #include "core/Macro.h"
11 
12 namespace MNN {
13 class ConstComputer : public SizeComputer {
14 public:
onComputeSize(const MNN::Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs) const15     virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs,
16                                const std::vector<Tensor*>& outputs) const override {
17         //MNN_ASSERT(0 == inputs.size());
18         MNN_ASSERT(1 == outputs.size());
19 
20         // copy dims
21         auto output    = outputs[0];
22         auto parameter = op->main_as_Blob();
23 
24         output->buffer().dimensions = parameter->dims() ? parameter->dims()->size() : 0;
25         for (int i = 0; i < output->buffer().dimensions; i++) {
26             output->buffer().dim[i].extent = parameter->dims()->Get(i);
27         }
28         if (parameter->dataType() == DataType_DT_HALF) {
29             output->setType(DataType_DT_FLOAT);
30         } else {
31             output->setType(parameter->dataType());
32         }
33         TensorUtils::getDescribe(output)->dimensionFormat = parameter->dataFormat();
34 
35         return true;
36     }
37 };
38 
// Both constants and trainable parameters take their output shape, type and
// layout from their Blob parameter, so they share ConstComputer.
REGISTER_SHAPE(ConstComputer, OpType_Const);
REGISTER_SHAPE(ConstComputer, OpType_TrainableParam);
41 
42 } // namespace MNN
43