1 //
2 //  Session.hpp
3 //  MNN
4 //
5 //  Created by MNN on 2018/07/30.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #ifndef Session_hpp
10 #define Session_hpp
11 
12 #include <MNN/Tensor.hpp>
13 #include <map>
14 #include <memory>
15 #include <vector>
16 #include "Pipeline.hpp"
17 #include "Schedule.hpp"
18 #include "core/Backend.hpp"
19 #include "core/Macro.h"
20 #include "shape/SizeComputer.hpp"
21 
22 namespace MNN {
23 struct Net;
24 /** infer unit. multiple sessions could share one net. */
25 class MNN_PUBLIC Session {
26 public:
27     Session(Schedule::ScheduleInfo&& info, Interpreter::SessionMode callBackMode, Interpreter::SessionMode inputMode,
28             RuntimeInfo&& runtime);
29     ~Session();
30 
31 public:
32     /**
33      * @brief infer.
34      * @return result code.
35      */
36     ErrorCode run() const;
37     /**
38      * @brief infer with callbacks and sync option.
39      * @param enterCallback callback before each op.
40      * @param exitCallback  callback after each op.
41      * @param sync          wait until all ops done before return or not.
42      * @return result code.
43      */
44     ErrorCode runWithCallBack(const TensorCallBackWithInfo& enterCallback, const TensorCallBackWithInfo& exitCallback,
45                               bool sync = false) const;
46 
47     bool getInfo(Interpreter::SessionInfoCode code, void* ptr) const;
48 
49     void cloneExecution(const std::map<const Op*, std::shared_ptr<Execution>>& cache, int pipelineIndex);
50     const std::map<const Op*, std::shared_ptr<Execution>>& getExecution(int pipelineIndex);
51 public:
52     /**
53      * @brief resize tensors and buffers responding to input changes.
54      * @return result code.
55      */
56     ErrorCode resize(bool isStatic = false);
57 
58     /**
59      * @brief set if needs resize.
60      * @param flag  needs resize or not.
61      */
setNeedResize(bool flag=true)62     void setNeedResize(bool flag = true) {
63         mNeedResize = flag;
64     }
65 
setNeedMalloc(bool flag=true)66     void setNeedMalloc(bool flag = true) {
67         mNeedMalloc = flag;
68     }
69 
70 public:
71     /**
72      * @brief get backend that create the tensor.
73      * @param tensor    given tensor.
74      * @return backend that create the tensor, NULL if the tensor is created by default backend (CPU backend).
75      */
76     const Backend* getBackEnd(const Tensor* tensor) const;
77 
78     /**
79      * @brief get input tensor for given op name.
80      * @param name given op name. if NULL, return first input tensor.
81      * @return input tensor if found, NULL otherwise.
82      */
83     Tensor* getInput(const char* name) const;
84 
85     /**
86      * @brief get output tensor for given op name.
87      * @param name given op name. if NULL, return first output tensor.
88      * @return output tensor if found, NULL otherwise.
89      */
90     Tensor* getOutput(const char* name) const;
91 
92     /**
93      * @brief get output tensors map.
94      * @return get output tensors map.
95      */
96     const std::map<std::string, Tensor*>& getOutputAll() const;
97     const std::map<std::string, Tensor*>& getInputAll() const;
98 
99     /**
100      * @brief check session is valid or not.
101      * @return session is valid or not.
102      */
valid() const103     inline bool valid() const {
104         return mValid;
105     }
106 
107     /**
108      * @brief update the session's const value to origin model's const blob.
109      * @return errorcode
110      */
111     ErrorCode updateToModel(Net* net) const;
112 
113     bool loadCache(const void* buffer, size_t size);
114     std::pair<const void*, size_t> getCache();
115 
116 protected:
getPipelines() const117     const std::vector<std::shared_ptr<Pipeline>>& getPipelines() const {
118         return this->mPipelines;
119     }
120 
121 private:
122     void _clearCache();
123     void _setUpTensorInfo(const Schedule::ScheduleInfo& info);
124 
125 private:
126     RuntimeInfo mRuntime;
127     std::vector<std::shared_ptr<Pipeline>> mPipelines;
128     std::vector<std::pair<int, std::shared_ptr<Tensor>>> mTensors;
129     std::map<std::string, Tensor*> mInputs;
130     std::map<std::string, Tensor*> mOutputs;
131     bool mNeedResize = true;
132     bool mValid      = true;
133     bool mNeedMalloc = true;
134     Interpreter::SessionMode mCallBackMode;
135 };
136 } // namespace MNN
137 
138 #endif /* Session_hpp */
139