1 //
2 // GeometryInnerProduct.cpp
3 // MNN
4 //
5 // Created by MNN on 2020/05/07.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
9 #include "geometry/GeometryComputer.hpp"
10 #include "geometry/GeometryComputerUtils.hpp"
11 #include "core/OpCommonUtils.hpp"
12 #include "core/ConvolutionCommon.hpp"
13 #include "ConvertUtils.hpp"
14 #define MNN_OPEN_TIME_TRACE
15 #include <MNN/AutoTime.hpp>
16 namespace MNN {
17 class GeometryInnerProduct : public GeometryComputer {
18 public:
19
onCompute(const Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,Context & context,CommandBuffer & res) const20 virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs,
21 const std::vector<Tensor*>& outputs, Context& context, CommandBuffer& res) const override {
22 auto parameter = op->main_as_InnerProduct();
23 int outputCount = parameter->outputCount();
24 int srcCount = parameter->weight()->size() / outputCount;
25
26 MNN_ASSERT(inputs.size() == 1);
27 MNN_ASSERT(outputs.size() == 1);
28 auto input = inputs[0];
29 auto output = outputs[0];
30 int inputDims = input->dimensions();
31 int outputDims = output->dimensions();
32 MNN_ASSERT(inputDims >= 2);
33 MNN_ASSERT(outputDims == 2);
34 MNN_ASSERT(output->length(1) == outputCount);
35
36 int batch = output->length(0);
37 MNN_ASSERT(input->length(0) == batch);
38 int mulNum = 1;
39 for(int i=1; i < inputDims; i++) {
40 mulNum *= input->length(i);
41 }
42 if (srcCount != mulNum) {
43 return false;
44 }
45
46 Tensor* A = nullptr;
47 Tensor* B = nullptr;
48 {
49 std::shared_ptr<Tensor> tmpInput(new Tensor);
50 tmpInput->buffer().type = halide_type_of<float>();
51 tmpInput->buffer().dimensions = 2;
52 tmpInput->setLength(0, batch);
53 tmpInput->setLength(1, srcCount);
54 auto des = TensorUtils::getDescribe(tmpInput.get());
55 des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
56 des->dimensionFormat = MNN_DATA_FORMAT_NCHW;
57 des->regions.clear();
58 des->regions.reserve(1);
59
60 Tensor::InsideDescribe::Region region;
61 region.origin = input;
62 region.size[0] = 1;
63 region.size[1] = batch;
64 region.size[2] = srcCount;
65 region.src.offset = 0;
66 region.dst.offset = 0;
67 region.src.stride[0] = 1;
68 region.dst.stride[0] = 1;
69 region.src.stride[1] = srcCount;
70 region.dst.stride[1] = srcCount;
71 region.src.stride[2] = 1;
72 region.dst.stride[2] = 1;
73 des->regions.emplace_back(std::move(region));
74
75 A = tmpInput.get();
76 res.extras.emplace_back(tmpInput);
77 }
78
79 std::shared_ptr<Tensor> tmpOutput(new Tensor);
80 std::shared_ptr<Tensor> C(new Tensor);
81 auto constTensors = context.searchConst(op);
82 Tensor* weight = nullptr;
83 Tensor* bias = nullptr;
84 if (!constTensors.empty()) {
85 MNN_ASSERT(constTensors.size() == 2);
86 weight = constTensors[0].get();
87 bias = constTensors[1].get();
88 } else {
89 auto weightTensor = context.allocConst(op, {outputCount, srcCount}, halide_type_of<float>());
90 ::memcpy(weightTensor.get()->host<float>(), parameter->weight()->data(), parameter->weight()->size()*sizeof(float));
91 weight = weightTensor.get();
92 auto biasTensor = context.allocConst(op, {batch, outputCount}, halide_type_of<float>());
93 ::memcpy(biasTensor.get()->host<float>(), parameter->bias()->data(), parameter->bias()->size()*sizeof(float));
94 bias = biasTensor.get();
95 }
96 {
97 B = weight;
98
99 C->buffer().type = halide_type_of<float>();
100 C->buffer().dimensions = 2;
101 C->setLength(0, batch);
102 C->setLength(1, outputCount);
103
104 auto cmd = GeometryComputerUtils::makeMatMul(A, B, C.get(), nullptr, false, true);
105 res.extras.emplace_back(C);
106 res.command.emplace_back(std::move(cmd));
107 }
108
109 {
110 tmpOutput->buffer().type = halide_type_of<float>();
111 tmpOutput->buffer().dimensions = 2;
112 tmpOutput->setLength(0, batch);
113 tmpOutput->setLength(1, outputCount);
114
115 auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_ADD, C.get(), bias, tmpOutput.get());
116 res.extras.emplace_back(tmpOutput);
117 res.command.emplace_back(std::move(cmd));
118 }
119
120 {
121 auto des = TensorUtils::getDescribe(output);
122 des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
123 des->regions.clear();
124 des->regions.reserve(1);
125
126 Tensor::InsideDescribe::Region region;
127 region.origin = tmpOutput.get();
128 region.size[0] = 1;
129 region.size[1] = batch;
130 region.size[2] = outputCount;
131 region.src.offset = 0;
132 region.dst.offset = 0;
133 region.src.stride[0] = 1;
134 region.dst.stride[0] = 1;
135 region.src.stride[1] = outputCount;
136 region.dst.stride[1] = outputCount;
137 region.src.stride[2] = 1;
138 region.dst.stride[2] = 1;
139 des->regions.emplace_back(std::move(region));
140 }
141
142 return true;
143 }
144 };
145
_create()146 static void _create() {
147 std::shared_ptr<GeometryComputer> comp(new GeometryInnerProduct);
148 GeometryComputer::registerGeometryComputer(comp, {OpType_InnerProduct});
149 }
150
// Hook _create into the geometry registration mechanism.
// NOTE(review): the macro is defined outside this file; presumably it arranges for _create
// to run when geometry computers are registered — confirm against its definition.
REGISTER_GEOMETRY(GeometryInnerProduct, _create);
152
153 } // namespace MNN
154