//
//  Session.hpp
//  MNN
//
//  Created by MNN on 2018/07/30.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#ifndef Session_hpp
#define Session_hpp

#include <MNN/Tensor.hpp>
#include <map>
#include <memory>
#include <vector>
#include "Pipeline.hpp"
#include "Schedule.hpp"
#include "core/Backend.hpp"
#include "core/Macro.h"
#include "shape/SizeComputer.hpp"

namespace MNN {
struct Net;
/** Inference unit. Multiple sessions may share one net. */
class MNN_PUBLIC Session {
public:
    Session(Schedule::ScheduleInfo&& info, Interpreter::SessionMode callBackMode, Interpreter::SessionMode inputMode,
            RuntimeInfo&& runtime);
    ~Session();

public:
    /**
     * @brief infer.
     * @return result code.
     */
    ErrorCode run() const;
    /**
     * @brief infer with callbacks and a sync option.
     * @param enterCallback callback invoked before each op.
     * @param exitCallback  callback invoked after each op.
     * @param sync          whether to wait until all ops are done before returning.
     * @return result code.
     */
    ErrorCode runWithCallBack(const TensorCallBackWithInfo& enterCallback, const TensorCallBackWithInfo& exitCallback,
                              bool sync = false) const;

    /**
     * @brief query session information identified by code; the result is written to ptr.
     * @return whether the information is available.
     */
    bool getInfo(Interpreter::SessionInfoCode code, void* ptr) const;

    /** copy cached op executions into the pipeline at the given index. */
    void cloneExecution(const std::map<const Op*, std::shared_ptr<Execution>>& cache, int pipelineIndex);
    /** get the op-to-execution map of the pipeline at the given index. */
    const std::map<const Op*, std::shared_ptr<Execution>>& getExecution(int pipelineIndex);

public:
    /**
     * @brief resize tensors and buffers in response to input changes.
     * @return result code.
     */
    ErrorCode resize(bool isStatic = false);

    /**
     * @brief set whether a resize is needed.
     * @param flag needs resize or not.
     */
    void setNeedResize(bool flag = true) {
        mNeedResize = flag;
    }

    /**
     * @brief set whether buffers need to be (re)allocated.
     * @param flag needs malloc or not.
     */
    void setNeedMalloc(bool flag = true) {
        mNeedMalloc = flag;
    }

public:
    /**
     * @brief get the backend that created the tensor.
     * @param tensor given tensor.
     * @return backend that created the tensor, NULL if the tensor was created by the default backend (CPU backend).
     */
    const Backend* getBackEnd(const Tensor* tensor) const;

    /**
     * @brief get the input tensor for a given op name.
     * @param name given op name. If NULL, return the first input tensor.
     * @return input tensor if found, NULL otherwise.
     */
    Tensor* getInput(const char* name) const;

    /**
     * @brief get the output tensor for a given op name.
     * @param name given op name. If NULL, return the first output tensor.
     * @return output tensor if found, NULL otherwise.
     */
    Tensor* getOutput(const char* name) const;

    /**
     * @brief get all output tensors mapped by name.
     * @return output tensors map.
     */
    const std::map<std::string, Tensor*>& getOutputAll() const;
    /**
     * @brief get all input tensors mapped by name.
     * @return input tensors map.
     */
    const std::map<std::string, Tensor*>& getInputAll() const;

    /**
     * @brief check whether the session is valid.
     * @return session is valid or not.
     */
    inline bool valid() const {
        return mValid;
    }

    /**
     * @brief update the session's const values back to the origin model's const blobs.
     * @return error code.
     */
    ErrorCode updateToModel(Net* net) const;

    /** load a previously serialized runtime cache buffer. */
    bool loadCache(const void* buffer, size_t size);
    /** get the serialized runtime cache buffer and its size. */
    std::pair<const void*, size_t> getCache();

protected:
    const std::vector<std::shared_ptr<Pipeline>>& getPipelines() const {
        return this->mPipelines;
    }

private:
    void _clearCache();
    void _setUpTensorInfo(const Schedule::ScheduleInfo& info);

private:
    RuntimeInfo mRuntime;
    std::vector<std::shared_ptr<Pipeline>> mPipelines;
    std::vector<std::pair<int, std::shared_ptr<Tensor>>> mTensors;
    std::map<std::string, Tensor*> mInputs;
    std::map<std::string, Tensor*> mOutputs;
    bool mNeedResize = true;
    bool mValid = true;
    bool mNeedMalloc = true;
    Interpreter::SessionMode mCallBackMode;
};
} // namespace MNN

#endif /* Session_hpp */
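
/*
 * Illustrative usage sketch (not part of this header, and only a hedged example):
 * in application code a Session is normally created and owned by an Interpreter
 * rather than constructed directly. Assuming a valid `Session* session` obtained
 * that way, a typical call order on the interface declared above is:
 *
 *   MNN::Tensor* input = session->getInput(nullptr);    // first input tensor
 *   // ... fill the input tensor's host data here ...
 *   session->setNeedResize();                           // mark shapes dirty if the input geometry changed
 *   session->resize();                                  // re-compute shapes and re-allocate buffers
 *   MNN::ErrorCode code = session->run();               // execute inference
 *   MNN::Tensor* output = session->getOutput(nullptr);  // first output tensor
 */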