1 //
2 //  GeometryComputer.hpp
3 //  MNN
4 //
5 //  Created by MNN on 2020/04/01.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #ifndef GeometryComputer_hpp
10 #define GeometryComputer_hpp
11 #include <map>
12 #include <vector>
13 #include "MNN_generated.h"
14 #include "core/Command.hpp"
15 #include "core/TensorUtils.hpp"
16 #include "core/Backend.hpp"
17 
18 namespace MNN {
19 class GeometryComputer {
20 public:
~GeometryComputer()21     virtual ~GeometryComputer() {
22         // Do nothing
23     }
24     class MNN_PUBLIC Context {
25     public:
26         Context(std::shared_ptr<Backend> allocBackend, bool permitVirtual = true, MNNForwardType = MNN_FORWARD_CPU);
27         ~Context();
28 
29         void clear();
30         void setBackend(Backend* backend);
supportVirtual() const31         bool supportVirtual() const {
32             return mPermitVirtual;
33         }
34         void getRasterCacheCreateRecurrse(Tensor* src, CommandBuffer& cmd);
35 
36         // If has cache, return. Otherwise create cache
37         const std::vector<std::shared_ptr<Tensor>>& searchConst(const Op* op);
38         std::shared_ptr<Tensor> allocConst(const Op* key, const std::vector<int>& shape, halide_type_t type,
39                                            Tensor::DimensionType dimType = Tensor::TENSORFLOW);
40         bool allocTensor(Tensor* tenosr);
41         std::vector<Tensor*> pOutputs;
forwardType() const42         inline MNNForwardType forwardType() const {
43             return mForwardType;
44         }
45     private:
46         void getRasterCacheCreate(Tensor* src, CommandBuffer& cmd);
47         std::map<const Op*, std::vector<std::shared_ptr<Tensor>>> mConstTensors;
48         std::vector<std::shared_ptr<Tensor>> mEmpty;
49         std::vector<std::shared_ptr<Tensor>> mTempConstTensors;
50         bool mPermitVirtual;
51         std::shared_ptr<Backend> mBackend;
52         std::vector<uint8_t> mRasterOp;
53         MNNForwardType mForwardType;
54     };
55     static void init();
56     MNN_PUBLIC static const GeometryComputer* search(int opType, Runtime::CompilerType compType);
57     static void registerGeometryComputer(std::shared_ptr<GeometryComputer> comp, std::vector<int> type, Runtime::CompilerType compType = Runtime::Compiler_Geometry);
58     MNN_PUBLIC bool compute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
59                             Context& context, CommandBuffer& cmd) const;
60 
61 protected:
62     virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
63                            Context& context, CommandBuffer& cmd) const = 0;
64 };
65 
66 class DefaultGeometryComputer : public GeometryComputer {
67 public:
DefaultGeometryComputer()68     DefaultGeometryComputer() {
69         // Do nothing
70     }
71     virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
72                            Context& context, CommandBuffer& cmd) const override;
73 };
// Declared here; defined elsewhere (presumably it invokes the hooks generated by
// REGISTER_GEOMETRY below).
void registerGeometryOps();

// REGISTER_GEOMETRY(f, c) defines a function named ___<f>__<c>__() whose body
// simply calls c(). The generated name is presumably referenced verbatim by
// registration code elsewhere, so its spelling must not change.
// NOTE(review): identifiers containing a double underscore are reserved for the
// implementation in C++; kept only because existing code depends on the exact name.
#define REGISTER_GEOMETRY(f, c) \
    void ___##f##__##c##__() {  \
        c();                    \
    }
80 
81 } // namespace MNN
82 
83 #endif
84