1 //
2 // GeometrySelect.cpp
3 // MNN
4 //
5 // Created by MNN on 2020/05/07.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
9 #include "ConvertUtils.hpp"
10 #include "geometry/GeometryComputer.hpp"
11 #include "shape/SizeComputer.hpp"
12 namespace MNN {
13 class GeometrySelect : public GeometryComputer {
14 public:
onCompute(const Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,Context & context,CommandBuffer & res) const15 virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
16 Context& context, CommandBuffer& res) const override {
17 auto input0 = inputs[0];
18 auto input1 = inputs[1];
19 auto input2 = inputs[2];
20 auto output = outputs[0];
21 auto inputL0 = input0->elementSize();
22 auto inputL1 = input1->elementSize();
23 auto inputL2 = input1->elementSize();
24 auto outputSize = output->elementSize();
25 // Need Broadcast or same shape
26 if (outputSize != inputL0) {
27 std::shared_ptr<Tensor> newTensor(new Tensor);
28 TensorUtils::copyShape(output, newTensor.get(), true);
29 newTensor->buffer().type = output->buffer().type;
30 ConvertUtils::broadcastto(input0, newTensor.get());
31 input0 = newTensor.get();
32 res.extras.emplace_back(newTensor);
33 }
34 if (outputSize != inputL1) {
35 std::shared_ptr<Tensor> newTensor(new Tensor);
36 TensorUtils::copyShape(output, newTensor.get(), true);
37 newTensor->buffer().type = output->buffer().type;
38 ConvertUtils::broadcastto(input1, newTensor.get());
39 input1 = newTensor.get();
40 res.extras.emplace_back(newTensor);
41 }
42 if (outputSize != inputL2) {
43 std::shared_ptr<Tensor> newTensor(new Tensor);
44 TensorUtils::copyShape(output, newTensor.get(), true);
45 newTensor->buffer().type = output->buffer().type;
46 ConvertUtils::broadcastto(input2, newTensor.get());
47 input2 = newTensor.get();
48 res.extras.emplace_back(newTensor);
49 }
50 Command cmd;
51 cmd.op = op;
52 cmd.inputs = {input0, input1, input2};
53 cmd.outputs = std::move(outputs);
54 res.command.emplace_back(std::move(cmd));
55 return true;
56 }
57 };
58
_create()59 static void _create() {
60 std::shared_ptr<GeometryComputer> comp(new GeometrySelect);
61 GeometryComputer::registerGeometryComputer(comp, {OpType_Select});
62 }
63
64 REGISTER_GEOMETRY(GeometrySelect, _create);
65
66 } // namespace MNN
67