1 //
2 //  Module.hpp
3 //  MNN
4 //
5 //  Created by MNN on 2019/11/25.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #ifndef MNN_Train_Module_hpp
10 #define MNN_Train_Module_hpp
11 
#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <MNN/expr/Expr.hpp>
#include <MNN/MNNForwardType.h>
17 
18 namespace MNN {
19 namespace Express {
20 struct SubGraph;
21 class MNN_PUBLIC Module {
22 public:
23     Module()                                                                               = default;
24     virtual ~Module()                                                                      = default;
25     virtual std::vector<Express::VARP> onForward(const std::vector<Express::VARP>& inputs) = 0;
26     Express::VARP forward(Express::VARP input);
27     std::vector<Express::VARP> parameters() const;
28     bool loadParameters(const std::vector<Express::VARP>& parameters);
29     void setIsTraining(const bool isTraining);
30     bool getIsTraining();
31     void clearCache();
32 
name() const33     const std::string& name() const {
34         return mName;
35     };
setName(std::string name)36     void setName(std::string name) {
37         mName = std::move(name);
38     }
type() const39     const std::string type() const {
40         return mType;
41     }
setType(std::string type)42     void setType(std::string type) {
43         mType = std::move(type);
44     }
45     // Return the parameter index
46     int addParameter(Express::VARP parameter);
47 
48     void setParameter(Express::VARP parameter, int index);
49     static Module* createEmpty(const std::vector<Express::VARP>& parameters);
50 
51     struct BackendInfo {
52         MNNForwardType type = MNN_FORWARD_CPU;
53         BackendConfig* config = nullptr;
54     };
55 
56     struct Config {
57         // Load module as dynamic, default static
58         bool dynamic = false;
59 
60         // for static mode, if the shape is mutable, set true, otherwise set false to avoid resizeSession freqencily
61         bool shapeMutable = true;
62         // Pre-rearrange weights or not. Disabled by default.
63         // The weights will be rearranged in a general way, so the best implementation
64         // may not be adopted if `rearrange` is enabled.
65         bool rearrange = false;
66 
67         BackendInfo* backend = nullptr;
68     };
69     static Module* load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const uint8_t* buffer, size_t length, const Config* config = nullptr);
70     static Module* load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const char* fileName, const Config* config = nullptr);
71     static Module* extract(std::vector<Express::VARP> inputs, std::vector<Express::VARP> outputs, bool fortrain, const std::map<std::string, SubGraph>& subGraph = {});
72 
73     static Module* clone(const Module* module, const bool shareParams = false);
74 
75     class CloneContext {
76     public:
77         CloneContext() = default;
CloneContext(const bool shareParams)78         explicit CloneContext(const bool shareParams)
79             : mShareParams(shareParams) {}
80         virtual ~CloneContext() = default;
81 
shareParams() const82         const bool shareParams() const { return mShareParams; }
83 
84         EXPRP getOrClone(const EXPRP expr);
85         VARP getOrClone(const VARP var);
86 
87     private:
88         bool mShareParams = false;
89         std::unordered_map<const Expr*, EXPRP> mExprMap;
90         std::unordered_map<const Variable*, VARP> mVarMap;
91     };
92 
clone(CloneContext * ctx) const93     virtual Module* clone(CloneContext* ctx) const {
94         return nullptr;
95     }
96 
97 protected:
98     void registerModel(const std::vector<std::shared_ptr<Module>>& children);
onClearCache()99     virtual void onClearCache() {
100     }
101 
102     Module* cloneBaseTo(CloneContext* ctx, Module* module) const;
103 
104 private:
105     void _collectParameters(std::vector<Express::VARP>& result) const;
106     std::vector<std::shared_ptr<Module>> mChildren;
107     std::vector<Express::VARP> mParameters;
108     bool mIsTraining = true;
109     std::string mName;
110     std::string mType;
111 };
112 
113 struct SubGraph {
114     std::vector<std::string> inputs;
115     std::vector<std::string> outputs;
116     std::shared_ptr<Module> m;
117 };
118 
} // namespace Express
120 } // namespace MNN
121 
122 #endif
123