1 // 2 // Module.hpp 3 // MNN 4 // 5 // Created by MNN on 2019/11/25. 6 // Copyright © 2018, Alibaba Group Holding Limited 7 // 8 9 #ifndef MNN_Train_Module_hpp 10 #define MNN_Train_Module_hpp 11 12 #include <vector> 13 #include <unordered_map> 14 15 #include <MNN/expr/Expr.hpp> 16 #include <MNN/MNNForwardType.h> 17 18 namespace MNN { 19 namespace Express { 20 struct SubGraph; 21 class MNN_PUBLIC Module { 22 public: 23 Module() = default; 24 virtual ~Module() = default; 25 virtual std::vector<Express::VARP> onForward(const std::vector<Express::VARP>& inputs) = 0; 26 Express::VARP forward(Express::VARP input); 27 std::vector<Express::VARP> parameters() const; 28 bool loadParameters(const std::vector<Express::VARP>& parameters); 29 void setIsTraining(const bool isTraining); 30 bool getIsTraining(); 31 void clearCache(); 32 name() const33 const std::string& name() const { 34 return mName; 35 }; setName(std::string name)36 void setName(std::string name) { 37 mName = std::move(name); 38 } type() const39 const std::string type() const { 40 return mType; 41 } setType(std::string type)42 void setType(std::string type) { 43 mType = std::move(type); 44 } 45 // Return the parameter index 46 int addParameter(Express::VARP parameter); 47 48 void setParameter(Express::VARP parameter, int index); 49 static Module* createEmpty(const std::vector<Express::VARP>& parameters); 50 51 struct BackendInfo { 52 MNNForwardType type = MNN_FORWARD_CPU; 53 BackendConfig* config = nullptr; 54 }; 55 56 struct Config { 57 // Load module as dynamic, default static 58 bool dynamic = false; 59 60 // for static mode, if the shape is mutable, set true, otherwise set false to avoid resizeSession freqencily 61 bool shapeMutable = true; 62 // Pre-rearrange weights or not. Disabled by default. 63 // The weights will be rearranged in a general way, so the best implementation 64 // may not be adopted if `rearrange` is enabled. 65 bool rearrange = false; 66 67 BackendInfo* backend = nullptr; 68 }; 69 static Module* load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const uint8_t* buffer, size_t length, const Config* config = nullptr); 70 static Module* load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const char* fileName, const Config* config = nullptr); 71 static Module* extract(std::vector<Express::VARP> inputs, std::vector<Express::VARP> outputs, bool fortrain, const std::map<std::string, SubGraph>& subGraph = {}); 72 73 static Module* clone(const Module* module, const bool shareParams = false); 74 75 class CloneContext { 76 public: 77 CloneContext() = default; CloneContext(const bool shareParams)78 explicit CloneContext(const bool shareParams) 79 : mShareParams(shareParams) {} 80 virtual ~CloneContext() = default; 81 shareParams() const82 const bool shareParams() const { return mShareParams; } 83 84 EXPRP getOrClone(const EXPRP expr); 85 VARP getOrClone(const VARP var); 86 87 private: 88 bool mShareParams = false; 89 std::unordered_map<const Expr*, EXPRP> mExprMap; 90 std::unordered_map<const Variable*, VARP> mVarMap; 91 }; 92 clone(CloneContext * ctx) const93 virtual Module* clone(CloneContext* ctx) const { 94 return nullptr; 95 } 96 97 protected: 98 void registerModel(const std::vector<std::shared_ptr<Module>>& children); onClearCache()99 virtual void onClearCache() { 100 } 101 102 Module* cloneBaseTo(CloneContext* ctx, Module* module) const; 103 104 private: 105 void _collectParameters(std::vector<Express::VARP>& result) const; 106 std::vector<std::shared_ptr<Module>> mChildren; 107 std::vector<Express::VARP> mParameters; 108 bool mIsTraining = true; 109 std::string mName; 110 std::string mType; 111 }; 112 113 struct SubGraph { 114 std::vector<std::string> inputs; 115 std::vector<std::string> outputs; 116 std::shared_ptr<Module> m; 117 }; 118 119 } // namespace Train 120 } // namespace MNN 121 122 #endif 123