1 //
2 //  GeometryInnerProduct.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2020/05/07.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "geometry/GeometryComputer.hpp"
10 #include "geometry/GeometryComputerUtils.hpp"
11 #include "core/OpCommonUtils.hpp"
12 #include "core/ConvolutionCommon.hpp"
13 #include "ConvertUtils.hpp"
14 #define MNN_OPEN_TIME_TRACE
15 #include <MNN/AutoTime.hpp>
16 namespace MNN {
17 class GeometryInnerProduct : public GeometryComputer {
18 public:
19 
onCompute(const Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,Context & context,CommandBuffer & res) const20     virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs,
21                                     const std::vector<Tensor*>& outputs, Context& context, CommandBuffer& res) const override {
22         auto parameter  = op->main_as_InnerProduct();
23         int outputCount = parameter->outputCount();
24         int srcCount    = parameter->weight()->size() / outputCount;
25 
26         MNN_ASSERT(inputs.size() == 1);
27         MNN_ASSERT(outputs.size() == 1);
28         auto input = inputs[0];
29         auto output = outputs[0];
30         int inputDims = input->dimensions();
31         int outputDims = output->dimensions();
32         MNN_ASSERT(inputDims >= 2);
33         MNN_ASSERT(outputDims == 2);
34         MNN_ASSERT(output->length(1) == outputCount);
35 
36         int batch = output->length(0);
37         MNN_ASSERT(input->length(0) == batch);
38         int mulNum = 1;
39         for(int i=1; i < inputDims; i++) {
40             mulNum *= input->length(i);
41         }
42         if (srcCount != mulNum) {
43             return false;
44         }
45 
46         Tensor* A = nullptr;
47         Tensor* B = nullptr;
48         {
49             std::shared_ptr<Tensor> tmpInput(new Tensor);
50             tmpInput->buffer().type = halide_type_of<float>();
51             tmpInput->buffer().dimensions = 2;
52             tmpInput->setLength(0, batch);
53             tmpInput->setLength(1, srcCount);
54             auto des = TensorUtils::getDescribe(tmpInput.get());
55             des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
56             des->dimensionFormat = MNN_DATA_FORMAT_NCHW;
57             des->regions.clear();
58             des->regions.reserve(1);
59 
60             Tensor::InsideDescribe::Region region;
61             region.origin = input;
62             region.size[0] = 1;
63             region.size[1] = batch;
64             region.size[2] = srcCount;
65             region.src.offset = 0;
66             region.dst.offset = 0;
67             region.src.stride[0] = 1;
68             region.dst.stride[0] = 1;
69             region.src.stride[1] = srcCount;
70             region.dst.stride[1] = srcCount;
71             region.src.stride[2] = 1;
72             region.dst.stride[2] = 1;
73             des->regions.emplace_back(std::move(region));
74 
75             A = tmpInput.get();
76             res.extras.emplace_back(tmpInput);
77         }
78 
79         std::shared_ptr<Tensor> tmpOutput(new Tensor);
80         std::shared_ptr<Tensor> C(new Tensor);
81         auto constTensors = context.searchConst(op);
82         Tensor* weight = nullptr;
83         Tensor* bias = nullptr;
84         if (!constTensors.empty()) {
85             MNN_ASSERT(constTensors.size() == 2);
86             weight = constTensors[0].get();
87             bias = constTensors[1].get();
88         } else {
89             auto weightTensor = context.allocConst(op, {outputCount, srcCount}, halide_type_of<float>());
90             ::memcpy(weightTensor.get()->host<float>(), parameter->weight()->data(), parameter->weight()->size()*sizeof(float));
91             weight = weightTensor.get();
92             auto biasTensor = context.allocConst(op, {batch, outputCount}, halide_type_of<float>());
93             ::memcpy(biasTensor.get()->host<float>(), parameter->bias()->data(), parameter->bias()->size()*sizeof(float));
94             bias = biasTensor.get();
95         }
96         {
97             B = weight;
98 
99             C->buffer().type = halide_type_of<float>();
100             C->buffer().dimensions = 2;
101             C->setLength(0, batch);
102             C->setLength(1, outputCount);
103 
104             auto cmd = GeometryComputerUtils::makeMatMul(A, B, C.get(), nullptr, false, true);
105             res.extras.emplace_back(C);
106             res.command.emplace_back(std::move(cmd));
107         }
108 
109         {
110             tmpOutput->buffer().type = halide_type_of<float>();
111             tmpOutput->buffer().dimensions = 2;
112             tmpOutput->setLength(0, batch);
113             tmpOutput->setLength(1, outputCount);
114 
115             auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_ADD, C.get(), bias, tmpOutput.get());
116             res.extras.emplace_back(tmpOutput);
117             res.command.emplace_back(std::move(cmd));
118         }
119 
120         {
121             auto des = TensorUtils::getDescribe(output);
122             des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
123             des->regions.clear();
124             des->regions.reserve(1);
125 
126             Tensor::InsideDescribe::Region region;
127             region.origin = tmpOutput.get();
128             region.size[0] = 1;
129             region.size[1] = batch;
130             region.size[2] = outputCount;
131             region.src.offset = 0;
132             region.dst.offset = 0;
133             region.src.stride[0] = 1;
134             region.dst.stride[0] = 1;
135             region.src.stride[1] = outputCount;
136             region.dst.stride[1] = outputCount;
137             region.src.stride[2] = 1;
138             region.dst.stride[2] = 1;
139             des->regions.emplace_back(std::move(region));
140         }
141 
142         return true;
143     }
144 };
145 
_create()146 static void _create() {
147     std::shared_ptr<GeometryComputer> comp(new GeometryInnerProduct);
148     GeometryComputer::registerGeometryComputer(comp, {OpType_InnerProduct});
149 }
150 
151 REGISTER_GEOMETRY(GeometryInnerProduct, _create);
152 
153 } // namespace MNN
154