1 //
2 //  GeometrySelect.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2020/05/07.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "ConvertUtils.hpp"
10 #include "geometry/GeometryComputer.hpp"
11 #include "shape/SizeComputer.hpp"
12 namespace MNN {
13 class GeometrySelect : public GeometryComputer {
14 public:
onCompute(const Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,Context & context,CommandBuffer & res) const15     virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
16                            Context& context, CommandBuffer& res) const override {
17         auto input0     = inputs[0];
18         auto input1     = inputs[1];
19         auto input2     = inputs[2];
20         auto output     = outputs[0];
21         auto inputL0    = input0->elementSize();
22         auto inputL1    = input1->elementSize();
23         auto inputL2    = input1->elementSize();
24         auto outputSize = output->elementSize();
25         // Need Broadcast or same shape
26         if (outputSize != inputL0) {
27             std::shared_ptr<Tensor> newTensor(new Tensor);
28             TensorUtils::copyShape(output, newTensor.get(), true);
29             newTensor->buffer().type = output->buffer().type;
30             ConvertUtils::broadcastto(input0, newTensor.get());
31             input0 = newTensor.get();
32             res.extras.emplace_back(newTensor);
33         }
34         if (outputSize != inputL1) {
35             std::shared_ptr<Tensor> newTensor(new Tensor);
36             TensorUtils::copyShape(output, newTensor.get(), true);
37             newTensor->buffer().type = output->buffer().type;
38             ConvertUtils::broadcastto(input1, newTensor.get());
39             input1 = newTensor.get();
40             res.extras.emplace_back(newTensor);
41         }
42         if (outputSize != inputL2) {
43             std::shared_ptr<Tensor> newTensor(new Tensor);
44             TensorUtils::copyShape(output, newTensor.get(), true);
45             newTensor->buffer().type = output->buffer().type;
46             ConvertUtils::broadcastto(input2, newTensor.get());
47             input2 = newTensor.get();
48             res.extras.emplace_back(newTensor);
49         }
50         Command cmd;
51         cmd.op      = op;
52         cmd.inputs  = {input0, input1, input2};
53         cmd.outputs = std::move(outputs);
54         res.command.emplace_back(std::move(cmd));
55         return true;
56     }
57 };
58 
_create()59 static void _create() {
60     std::shared_ptr<GeometryComputer> comp(new GeometrySelect);
61     GeometryComputer::registerGeometryComputer(comp, {OpType_Select});
62 }
63 
64 REGISTER_GEOMETRY(GeometrySelect, _create);
65 
66 } // namespace MNN
67