1 //
2 //  ShapeDetectionOutput.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2019/01/10.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "shape/SizeComputer.hpp"
10 #include "core/Macro.h"
11 namespace MNN {
12 
13 // Size Computer
14 class DetectionOutputComputer : public SizeComputer {
onComputeSize(const MNN::Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs) const15     virtual bool onComputeSize(const MNN::Op *op, const std::vector<Tensor *> &inputs,
16                                const std::vector<Tensor *> &outputs) const override {
17         MNN_ASSERT(3 <= inputs.size());
18         MNN_ASSERT(1 == outputs.size());
19 
20         // set dims
21         auto &output    = outputs[0]->buffer();
22         auto maxNumber = op->main_as_DetectionOutput()->keepTopK();
23 
24         output.dim[0].extent = 1;
25         output.dim[1].extent = 1;
26         output.dim[2].extent = maxNumber;
27         output.dim[3].extent = 6; // maximum width
28         TensorUtils::getDescribe(outputs[0])->dimensionFormat = MNN_DATA_FORMAT_NC4HW4;
29         output.type = halide_type_of<float>();
30 
31         return true;
32     }
33 };
34 
35 REGISTER_SHAPE(DetectionOutputComputer, OpType_DetectionOutput);
36 } // namespace MNN
37