1 //
2 //  ShapeLinSpace.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2019/12/11.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "shape/SizeComputer.hpp"
10 #include "core/Macro.h"
11 #include "core/TensorUtils.hpp"
12 
13 namespace MNN {
14 
15 class LinSpaceSizeComputer : public SizeComputer {
onComputeSize(const MNN::Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs) const16     virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs,
17                                const std::vector<Tensor*>& outputs) const override {
18         MNN_ASSERT(inputs.size() == 3);
19         MNN_ASSERT(outputs.size() == 1);
20         auto& ib1 = inputs[0]->buffer();
21         auto& ib2 = inputs[1]->buffer();
22         auto& ib3 = inputs[2]->buffer();
23         auto& ob = outputs[0]->buffer();
24         MNN_ASSERT(ib1.dimensions == 0);
25         MNN_ASSERT(ib2.dimensions == 0);
26         MNN_ASSERT(ib3.dimensions == 0);
27 
28         MNN_ASSERT(inputs[0]->getType() == halide_type_of<float>());
29         MNN_ASSERT(inputs[1]->getType() == halide_type_of<float>());
30         MNN_ASSERT(inputs[2]->getType() == halide_type_of<int32_t>());
31 
32         int num = inputs[2]->host<int32_t>()[0];
33         MNN_ASSERT(num > 0);
34 
35         ob.dimensions = 1;
36         ob.dim[0].extent = num;
37         outputs[0]->setType(DataType_DT_FLOAT);
38         TensorUtils::getDescribe(outputs[0])->dimensionFormat = TensorUtils::getDescribe(inputs[0])->dimensionFormat;
39 
40         return true;
41     }
42 };
43 
44 REGISTER_SHAPE_INPUTS(LinSpaceSizeComputer, OpType_LinSpace, {2});
45 } // namespace MNN
46