1 // 2 // GeometryComputer.hpp 3 // MNN 4 // 5 // Created by MNN on 2020/04/01. 6 // Copyright © 2018, Alibaba Group Holding Limited 7 // 8 9 #ifndef GeometryComputer_hpp 10 #define GeometryComputer_hpp 11 #include <map> 12 #include <vector> 13 #include "MNN_generated.h" 14 #include "core/Command.hpp" 15 #include "core/TensorUtils.hpp" 16 #include "core/Backend.hpp" 17 18 namespace MNN { 19 class GeometryComputer { 20 public: ~GeometryComputer()21 virtual ~GeometryComputer() { 22 // Do nothing 23 } 24 class MNN_PUBLIC Context { 25 public: 26 Context(std::shared_ptr<Backend> allocBackend, bool permitVirtual = true, MNNForwardType = MNN_FORWARD_CPU); 27 ~Context(); 28 29 void clear(); 30 void setBackend(Backend* backend); supportVirtual() const31 bool supportVirtual() const { 32 return mPermitVirtual; 33 } 34 void getRasterCacheCreateRecurrse(Tensor* src, CommandBuffer& cmd); 35 36 // If has cache, return. Otherwise create cache 37 const std::vector<std::shared_ptr<Tensor>>& searchConst(const Op* op); 38 std::shared_ptr<Tensor> allocConst(const Op* key, const std::vector<int>& shape, halide_type_t type, 39 Tensor::DimensionType dimType = Tensor::TENSORFLOW); 40 bool allocTensor(Tensor* tenosr); 41 std::vector<Tensor*> pOutputs; forwardType() const42 inline MNNForwardType forwardType() const { 43 return mForwardType; 44 } 45 private: 46 void getRasterCacheCreate(Tensor* src, CommandBuffer& cmd); 47 std::map<const Op*, std::vector<std::shared_ptr<Tensor>>> mConstTensors; 48 std::vector<std::shared_ptr<Tensor>> mEmpty; 49 std::vector<std::shared_ptr<Tensor>> mTempConstTensors; 50 bool mPermitVirtual; 51 std::shared_ptr<Backend> mBackend; 52 std::vector<uint8_t> mRasterOp; 53 MNNForwardType mForwardType; 54 }; 55 static void init(); 56 MNN_PUBLIC static const GeometryComputer* search(int opType, Runtime::CompilerType compType); 57 static void registerGeometryComputer(std::shared_ptr<GeometryComputer> comp, std::vector<int> type, Runtime::CompilerType compType = Runtime::Compiler_Geometry); 58 MNN_PUBLIC bool compute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, 59 Context& context, CommandBuffer& cmd) const; 60 61 protected: 62 virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, 63 Context& context, CommandBuffer& cmd) const = 0; 64 }; 65 66 class DefaultGeometryComputer : public GeometryComputer { 67 public: DefaultGeometryComputer()68 DefaultGeometryComputer() { 69 // Do nothing 70 } 71 virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, 72 Context& context, CommandBuffer& cmd) const override; 73 }; 74 void registerGeometryOps(); 75 76 #define REGISTER_GEOMETRY(f, c) \ 77 extern void ___##f##__##c##__() { \ 78 c(); \ 79 } 80 81 } // namespace MNN 82 83 #endif 84