1 //
2 //  GeometryCosineSimilarity.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2020/07/13.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "geometry/GeometryComputer.hpp"
10 #include "core/OpCommonUtils.hpp"
11 #include "geometry/GeometryComputerUtils.hpp"
12 
13 namespace MNN {
14 class GeometryCosineSimilarity : public GeometryComputer {
15 public:
onCompute(const Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,Context & context,CommandBuffer & res) const16     virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs,
17                                     const std::vector<Tensor*>& outputs, Context& context, CommandBuffer& res) const override {
18         MNN_ASSERT(3 <= inputs.size());
19         MNN_ASSERT(1 == outputs.size());
20 
21         auto input0          = inputs[0];
22         auto input1          = inputs[1];
23         auto dimTensor = inputs[2];
24         const auto dim = dimTensor->host<int32_t>()[0];
25         MNN_ASSERT(dim == 1);
26         auto output          = outputs[0];
27 
28         int dimensions = input0->dimensions();
29         int outside = 1;
30         int channel = 1;
31         int inside = 1;
32         for(int i=0; i<dim; i++) {
33             outside *= input0->length(i);
34         }
35         channel = input0->length(dim);
36         for(int i=dim+1; i<dimensions; i++) {
37             inside *= input0->length(i);
38         }
39         auto dimType = input0->getDimensionType();
40 
41 
42         //input0 transform to NCHW format
43         std::shared_ptr<Tensor> tmpInput0;
44         {
45             tmpInput0.reset(Tensor::createDevice<float>({outside, channel, inside}, dimType));
46             auto outputDes = TensorUtils::getDescribe(tmpInput0.get());
47             outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
48 
49             Tensor::InsideDescribe::Region desReg;
50             desReg.size[0] = outside;
51             desReg.size[1] = channel;
52             desReg.size[2] = inside;
53             desReg.dst.offset = 0;
54             desReg.dst.stride[0] = channel*inside;
55             desReg.dst.stride[1] = inside;
56             desReg.dst.stride[2] = 1;
57             desReg.src.offset = 0;
58             desReg.src.stride[0] = channel*inside;
59             desReg.src.stride[1] = inside;
60             desReg.src.stride[2] = 1;
61             desReg.origin = input0;
62             outputDes->regions.emplace_back(std::move(desReg));
63 
64             res.extras.emplace_back(tmpInput0);
65         }
66 
67         //input1 transform to NCHW format
68         std::shared_ptr<Tensor> tmpInput1;
69         {
70             tmpInput1.reset(Tensor::createDevice<float>({outside, channel, inside}, dimType));
71             auto outputDes = TensorUtils::getDescribe(tmpInput1.get());
72             outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
73             outputDes->dimensionFormat = MNN_DATA_FORMAT_NCHW;
74 
75             Tensor::InsideDescribe::Region desReg;
76             desReg.size[0] = outside;
77             desReg.size[1] = channel;
78             desReg.size[2] = inside;
79             desReg.dst.offset = 0;
80             desReg.dst.stride[0] = channel*inside;
81             desReg.dst.stride[1] = inside;
82             desReg.dst.stride[2] = 1;
83             desReg.src.offset = 0;
84             desReg.src.stride[0] = channel*inside;
85             desReg.src.stride[1] = inside;
86             desReg.src.stride[2] = 1;
87             desReg.origin = input1;
88             outputDes->regions.emplace_back(std::move(desReg));
89 
90             res.extras.emplace_back(tmpInput1);
91         }
92 
93         //input0*input0
94         std::shared_ptr<Tensor> tmpInput0x0;
95         {
96             tmpInput0x0.reset(Tensor::createDevice<float>({outside, channel, inside}, dimType));
97             auto des = TensorUtils::getDescribe(tmpInput0x0.get());
98             des->dimensionFormat = MNN_DATA_FORMAT_NCHW;
99 
100             auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_MUL, tmpInput0.get(), tmpInput0.get(), tmpInput0x0.get());
101 
102             res.extras.emplace_back(tmpInput0x0);
103             res.command.emplace_back(std::move(cmd));
104         }
105 
106         //input0*input1
107         std::shared_ptr<Tensor> tmpInput0x1;
108         {
109             tmpInput0x1.reset(Tensor::createDevice<float>({outside, channel, inside}, dimType));
110             auto des = TensorUtils::getDescribe(tmpInput0x1.get());
111             des->dimensionFormat = MNN_DATA_FORMAT_NCHW;
112 
113             auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_MUL, tmpInput0.get(), tmpInput1.get(), tmpInput0x1.get());
114 
115             res.extras.emplace_back(tmpInput0x1);
116             res.command.emplace_back(std::move(cmd));
117         }
118 
119         //input1*input1
120         std::shared_ptr<Tensor> tmpInput1x1;
121         {
122             tmpInput1x1.reset(Tensor::createDevice<float>({outside, channel, inside}, dimType));
123             auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_MUL, tmpInput1.get(), tmpInput1.get(), tmpInput1x1.get());
124 
125             res.extras.emplace_back(tmpInput1x1);
126             res.command.emplace_back(std::move(cmd));
127         }
128 
129         //reduction sum, axis=1, only support NCHW
130         std::shared_ptr<Tensor> sumValue0x0;
131         {
132             sumValue0x0.reset(Tensor::createDevice<float>({outside, 1, inside}, dimType));
133             auto des = TensorUtils::getDescribe(sumValue0x0.get());
134             auto cmd = GeometryComputerUtils::makeReduce(ReductionType_SUM, tmpInput0x0.get(), sumValue0x0.get());
135             res.extras.emplace_back(sumValue0x0);
136             res.command.emplace_back(std::move(cmd));
137         }
138 
139         //reduction sum, axis=1, only support NCHW
140         std::shared_ptr<Tensor> sumValue0x1;
141         {
142             sumValue0x1.reset(Tensor::createDevice<float>({outside, 1, inside}, dimType));
143             auto des = TensorUtils::getDescribe(sumValue0x1.get());
144             auto cmd = GeometryComputerUtils::makeReduce(ReductionType_SUM, tmpInput0x1.get(), sumValue0x1.get());
145             res.extras.emplace_back(sumValue0x1);
146             res.command.emplace_back(std::move(cmd));
147         }
148 
149         //reduction sum, axis=1, only support NCHW
150         std::shared_ptr<Tensor> sumValue1x1;
151         {
152             sumValue1x1.reset(Tensor::createDevice<float>({outside, 1, inside}, dimType));
153             auto des = TensorUtils::getDescribe(sumValue1x1.get());
154 
155             auto cmd = GeometryComputerUtils::makeReduce(ReductionType_SUM, tmpInput1x1.get(), sumValue1x1.get());
156 
157             res.extras.emplace_back(sumValue1x1);
158             res.command.emplace_back(std::move(cmd));
159         }
160 
161         //sumValue0x0 * sumValue1x1
162         std::shared_ptr<Tensor> mulValue0x0_1x1;
163         {
164             mulValue0x0_1x1.reset(Tensor::createDevice<float>({outside, 1, inside}, dimType));
165             auto des = TensorUtils::getDescribe(mulValue0x0_1x1.get());
166 
167             auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_MUL, sumValue0x0.get(), sumValue1x1.get(), mulValue0x0_1x1.get());
168 
169             res.extras.emplace_back(mulValue0x0_1x1);
170             res.command.emplace_back(std::move(cmd));
171         }
172 
173         //add eps
174         std::shared_ptr<Tensor> mulValue0x0_1x1_eps;
175         {
176             mulValue0x0_1x1_eps.reset(Tensor::createDevice<float>({outside, 1, inside}, dimType));
177             auto des = TensorUtils::getDescribe(mulValue0x0_1x1_eps.get());
178 
179             const float eps         = 1e-8f;
180             auto epsTensor = context.allocConst(op, {1}, halide_type_of<float>());
181             epsTensor.get()->host<float>()[0] = eps;
182 
183             auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_ADD, mulValue0x0_1x1.get(), epsTensor.get(), mulValue0x0_1x1_eps.get());
184 
185             res.extras.emplace_back(mulValue0x0_1x1_eps);
186             res.command.emplace_back(std::move(cmd));
187         }
188 
189         //sqrt(sumValue0x0 * sumValue1x1 + eps)
190         std::shared_ptr<Tensor> sqrtMulValue;
191         {
192             sqrtMulValue.reset(Tensor::createDevice<float>({outside, 1, inside}, dimType));
193             auto des = TensorUtils::getDescribe(sqrtMulValue.get());
194 
195             auto cmd = GeometryComputerUtils::makeUnary(UnaryOpOperation_SQRT, mulValue0x0_1x1_eps.get(), sqrtMulValue.get());
196 
197             res.extras.emplace_back(sqrtMulValue);
198             res.command.emplace_back(std::move(cmd));
199         }
200         //div
201         std::shared_ptr<Tensor> tmpOutput;
202         {
203             tmpOutput.reset(Tensor::createDevice<float>({outside, 1, inside}, dimType));
204             auto des = TensorUtils::getDescribe(tmpOutput.get());
205 
206             auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_REALDIV, sumValue0x1.get(), sqrtMulValue.get(), tmpOutput.get());
207 
208             res.extras.emplace_back(tmpOutput);
209             res.command.emplace_back(std::move(cmd));
210         }
211         //transform to output
212         {
213             auto outputDes = TensorUtils::getDescribe(output);
214             outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
215             Tensor::InsideDescribe::Region desReg;
216             desReg.size[0] = 1;
217             desReg.size[1] = outside;
218             desReg.size[2] = inside;
219             desReg.dst.offset = 0;
220             desReg.dst.stride[0] = outside*inside;
221             desReg.dst.stride[1] = inside;
222             desReg.dst.stride[2] = 1;
223             desReg.src.offset = 0;
224             desReg.src.stride[0] = outside*inside;
225             desReg.src.stride[1] = inside;
226             desReg.src.stride[2] = 1;
227             desReg.origin = tmpOutput.get();
228             outputDes->regions.emplace_back(std::move(desReg));
229         }
230         return true;
231 
232     }
233 };
234 
_create()235 static void _create() {
236     std::shared_ptr<GeometryComputer> comp(new GeometryCosineSimilarity);
237     GeometryComputer::registerGeometryComputer(comp, {OpType_CosineSimilarity});
238 }
239 
240 REGISTER_GEOMETRY(GeometryCosineSimilarity, _create);
241 
242 } // namespace MNN
243