1 //
2 // GeometryCosineSimilarity.cpp
3 // MNN
4 //
5 // Created by MNN on 2020/07/13.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
9 #include "geometry/GeometryComputer.hpp"
10 #include "core/OpCommonUtils.hpp"
11 #include "geometry/GeometryComputerUtils.hpp"
12
13 namespace MNN {
14 class GeometryCosineSimilarity : public GeometryComputer {
15 public:
onCompute(const Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,Context & context,CommandBuffer & res) const16 virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs,
17 const std::vector<Tensor*>& outputs, Context& context, CommandBuffer& res) const override {
18 MNN_ASSERT(3 <= inputs.size());
19 MNN_ASSERT(1 == outputs.size());
20
21 auto input0 = inputs[0];
22 auto input1 = inputs[1];
23 auto dimTensor = inputs[2];
24 const auto dim = dimTensor->host<int32_t>()[0];
25 MNN_ASSERT(dim == 1);
26 auto output = outputs[0];
27
28 int dimensions = input0->dimensions();
29 int outside = 1;
30 int channel = 1;
31 int inside = 1;
32 for(int i=0; i<dim; i++) {
33 outside *= input0->length(i);
34 }
35 channel = input0->length(dim);
36 for(int i=dim+1; i<dimensions; i++) {
37 inside *= input0->length(i);
38 }
39 auto dimType = input0->getDimensionType();
40
41
42 //input0 transform to NCHW format
43 std::shared_ptr<Tensor> tmpInput0;
44 {
45 tmpInput0.reset(Tensor::createDevice<float>({outside, channel, inside}, dimType));
46 auto outputDes = TensorUtils::getDescribe(tmpInput0.get());
47 outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
48
49 Tensor::InsideDescribe::Region desReg;
50 desReg.size[0] = outside;
51 desReg.size[1] = channel;
52 desReg.size[2] = inside;
53 desReg.dst.offset = 0;
54 desReg.dst.stride[0] = channel*inside;
55 desReg.dst.stride[1] = inside;
56 desReg.dst.stride[2] = 1;
57 desReg.src.offset = 0;
58 desReg.src.stride[0] = channel*inside;
59 desReg.src.stride[1] = inside;
60 desReg.src.stride[2] = 1;
61 desReg.origin = input0;
62 outputDes->regions.emplace_back(std::move(desReg));
63
64 res.extras.emplace_back(tmpInput0);
65 }
66
67 //input1 transform to NCHW format
68 std::shared_ptr<Tensor> tmpInput1;
69 {
70 tmpInput1.reset(Tensor::createDevice<float>({outside, channel, inside}, dimType));
71 auto outputDes = TensorUtils::getDescribe(tmpInput1.get());
72 outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
73 outputDes->dimensionFormat = MNN_DATA_FORMAT_NCHW;
74
75 Tensor::InsideDescribe::Region desReg;
76 desReg.size[0] = outside;
77 desReg.size[1] = channel;
78 desReg.size[2] = inside;
79 desReg.dst.offset = 0;
80 desReg.dst.stride[0] = channel*inside;
81 desReg.dst.stride[1] = inside;
82 desReg.dst.stride[2] = 1;
83 desReg.src.offset = 0;
84 desReg.src.stride[0] = channel*inside;
85 desReg.src.stride[1] = inside;
86 desReg.src.stride[2] = 1;
87 desReg.origin = input1;
88 outputDes->regions.emplace_back(std::move(desReg));
89
90 res.extras.emplace_back(tmpInput1);
91 }
92
93 //input0*input0
94 std::shared_ptr<Tensor> tmpInput0x0;
95 {
96 tmpInput0x0.reset(Tensor::createDevice<float>({outside, channel, inside}, dimType));
97 auto des = TensorUtils::getDescribe(tmpInput0x0.get());
98 des->dimensionFormat = MNN_DATA_FORMAT_NCHW;
99
100 auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_MUL, tmpInput0.get(), tmpInput0.get(), tmpInput0x0.get());
101
102 res.extras.emplace_back(tmpInput0x0);
103 res.command.emplace_back(std::move(cmd));
104 }
105
106 //input0*input1
107 std::shared_ptr<Tensor> tmpInput0x1;
108 {
109 tmpInput0x1.reset(Tensor::createDevice<float>({outside, channel, inside}, dimType));
110 auto des = TensorUtils::getDescribe(tmpInput0x1.get());
111 des->dimensionFormat = MNN_DATA_FORMAT_NCHW;
112
113 auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_MUL, tmpInput0.get(), tmpInput1.get(), tmpInput0x1.get());
114
115 res.extras.emplace_back(tmpInput0x1);
116 res.command.emplace_back(std::move(cmd));
117 }
118
119 //input1*input1
120 std::shared_ptr<Tensor> tmpInput1x1;
121 {
122 tmpInput1x1.reset(Tensor::createDevice<float>({outside, channel, inside}, dimType));
123 auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_MUL, tmpInput1.get(), tmpInput1.get(), tmpInput1x1.get());
124
125 res.extras.emplace_back(tmpInput1x1);
126 res.command.emplace_back(std::move(cmd));
127 }
128
129 //reduction sum, axis=1, only support NCHW
130 std::shared_ptr<Tensor> sumValue0x0;
131 {
132 sumValue0x0.reset(Tensor::createDevice<float>({outside, 1, inside}, dimType));
133 auto des = TensorUtils::getDescribe(sumValue0x0.get());
134 auto cmd = GeometryComputerUtils::makeReduce(ReductionType_SUM, tmpInput0x0.get(), sumValue0x0.get());
135 res.extras.emplace_back(sumValue0x0);
136 res.command.emplace_back(std::move(cmd));
137 }
138
139 //reduction sum, axis=1, only support NCHW
140 std::shared_ptr<Tensor> sumValue0x1;
141 {
142 sumValue0x1.reset(Tensor::createDevice<float>({outside, 1, inside}, dimType));
143 auto des = TensorUtils::getDescribe(sumValue0x1.get());
144 auto cmd = GeometryComputerUtils::makeReduce(ReductionType_SUM, tmpInput0x1.get(), sumValue0x1.get());
145 res.extras.emplace_back(sumValue0x1);
146 res.command.emplace_back(std::move(cmd));
147 }
148
149 //reduction sum, axis=1, only support NCHW
150 std::shared_ptr<Tensor> sumValue1x1;
151 {
152 sumValue1x1.reset(Tensor::createDevice<float>({outside, 1, inside}, dimType));
153 auto des = TensorUtils::getDescribe(sumValue1x1.get());
154
155 auto cmd = GeometryComputerUtils::makeReduce(ReductionType_SUM, tmpInput1x1.get(), sumValue1x1.get());
156
157 res.extras.emplace_back(sumValue1x1);
158 res.command.emplace_back(std::move(cmd));
159 }
160
161 //sumValue0x0 * sumValue1x1
162 std::shared_ptr<Tensor> mulValue0x0_1x1;
163 {
164 mulValue0x0_1x1.reset(Tensor::createDevice<float>({outside, 1, inside}, dimType));
165 auto des = TensorUtils::getDescribe(mulValue0x0_1x1.get());
166
167 auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_MUL, sumValue0x0.get(), sumValue1x1.get(), mulValue0x0_1x1.get());
168
169 res.extras.emplace_back(mulValue0x0_1x1);
170 res.command.emplace_back(std::move(cmd));
171 }
172
173 //add eps
174 std::shared_ptr<Tensor> mulValue0x0_1x1_eps;
175 {
176 mulValue0x0_1x1_eps.reset(Tensor::createDevice<float>({outside, 1, inside}, dimType));
177 auto des = TensorUtils::getDescribe(mulValue0x0_1x1_eps.get());
178
179 const float eps = 1e-8f;
180 auto epsTensor = context.allocConst(op, {1}, halide_type_of<float>());
181 epsTensor.get()->host<float>()[0] = eps;
182
183 auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_ADD, mulValue0x0_1x1.get(), epsTensor.get(), mulValue0x0_1x1_eps.get());
184
185 res.extras.emplace_back(mulValue0x0_1x1_eps);
186 res.command.emplace_back(std::move(cmd));
187 }
188
189 //sqrt(sumValue0x0 * sumValue1x1 + eps)
190 std::shared_ptr<Tensor> sqrtMulValue;
191 {
192 sqrtMulValue.reset(Tensor::createDevice<float>({outside, 1, inside}, dimType));
193 auto des = TensorUtils::getDescribe(sqrtMulValue.get());
194
195 auto cmd = GeometryComputerUtils::makeUnary(UnaryOpOperation_SQRT, mulValue0x0_1x1_eps.get(), sqrtMulValue.get());
196
197 res.extras.emplace_back(sqrtMulValue);
198 res.command.emplace_back(std::move(cmd));
199 }
200 //div
201 std::shared_ptr<Tensor> tmpOutput;
202 {
203 tmpOutput.reset(Tensor::createDevice<float>({outside, 1, inside}, dimType));
204 auto des = TensorUtils::getDescribe(tmpOutput.get());
205
206 auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_REALDIV, sumValue0x1.get(), sqrtMulValue.get(), tmpOutput.get());
207
208 res.extras.emplace_back(tmpOutput);
209 res.command.emplace_back(std::move(cmd));
210 }
211 //transform to output
212 {
213 auto outputDes = TensorUtils::getDescribe(output);
214 outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
215 Tensor::InsideDescribe::Region desReg;
216 desReg.size[0] = 1;
217 desReg.size[1] = outside;
218 desReg.size[2] = inside;
219 desReg.dst.offset = 0;
220 desReg.dst.stride[0] = outside*inside;
221 desReg.dst.stride[1] = inside;
222 desReg.dst.stride[2] = 1;
223 desReg.src.offset = 0;
224 desReg.src.stride[0] = outside*inside;
225 desReg.src.stride[1] = inside;
226 desReg.src.stride[2] = 1;
227 desReg.origin = tmpOutput.get();
228 outputDes->regions.emplace_back(std::move(desReg));
229 }
230 return true;
231
232 }
233 };
234
_create()235 static void _create() {
236 std::shared_ptr<GeometryComputer> comp(new GeometryCosineSimilarity);
237 GeometryComputer::registerGeometryComputer(comp, {OpType_CosineSimilarity});
238 }
239
240 REGISTER_GEOMETRY(GeometryCosineSimilarity, _create);
241
242 } // namespace MNN
243