// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#ifndef PNNX_IR_H
#define PNNX_IR_H

#include <cstdint> // int64_t, used by Parameter's initializer_list<int64_t> constructor
#include <initializer_list>
#include <map>
#include <string>
#include <vector>

// Forward declarations only: the TorchScript / ATen types are used solely by
// pointer or reference here, so this header does not depend on libtorch headers.
namespace torch {
namespace jit {
struct Value;
struct Node;
} // namespace jit
} // namespace torch
namespace at {
class Tensor;
}

namespace pnnx {

// A tagged scalar-or-list value, stored in the `params` maps of Operator and
// Operand.
//
// `type` records which value member is meaningful:
//   0=null 1=b 2=i 3=f 4=s 5=ai 6=af 7=as 8=others
// Each constructor sets `type` plus the single member it corresponds to; all
// other members keep their defaults (false / 0 / empty).
//
// The single-argument constructors are not `explicit`, so plain values and
// brace lists convert implicitly to a Parameter.
//
// Narrowing: long / long long / int64_t inputs are stored in the `int` member
// `i`, and double inputs are stored in the `float` member `f`.
class Parameter
{
public:
    Parameter()
        : type(0)
    {
    }
    Parameter(bool _b)
        : type(1), b(_b)
    {
    }
    Parameter(int _i)
        : type(2), i(_i)
    {
    }
    Parameter(long _l)
        : type(2), i(static_cast<int>(_l))
    {
    }
    Parameter(long long _l)
        : type(2), i(static_cast<int>(_l))
    {
    }
    Parameter(float _f)
        : type(3), f(_f)
    {
    }
    Parameter(double _d)
        : type(3), f(static_cast<float>(_d))
    {
    }
    Parameter(const char* _s)
        : type(4), s(_s)
    {
    }
    Parameter(const std::string& _s)
        : type(4), s(_s)
    {
    }
    Parameter(const std::initializer_list<int>& _ai)
        : type(5), ai(_ai)
    {
    }
    Parameter(const std::initializer_list<int64_t>& _ai)
        : type(5)
    {
        // each element is narrowed to int, matching the scalar integer overloads
        ai.reserve(_ai.size());
        for (const auto& x : _ai)
            ai.push_back(static_cast<int>(x));
    }
    Parameter(const std::vector<int>& _ai)
        : type(5), ai(_ai)
    {
    }
    Parameter(const std::initializer_list<float>& _af)
        : type(6), af(_af)
    {
    }
    Parameter(const std::initializer_list<double>& _af)
        : type(6)
    {
        // each element is narrowed to float, matching Parameter(double)
        af.reserve(_af.size());
        for (const auto& x : _af)
            af.push_back(static_cast<float>(x));
    }
    Parameter(const std::vector<float>& _af)
        : type(6), af(_af)
    {
    }
    Parameter(const std::initializer_list<const char*>& _as)
        : type(7)
    {
        as.reserve(_as.size());
        for (const auto& x : _as)
            as.push_back(std::string(x));
    }
    Parameter(const std::initializer_list<std::string>& _as)
        : type(7), as(_as)
    {
    }
    Parameter(const std::vector<std::string>& _as)
        : type(7), as(_as)
    {
    }

    // Build a Parameter from a TorchScript node / value.
    // Declared here, defined outside this header.
    Parameter(const torch::jit::Node* value_node);
    Parameter(const torch::jit::Value* value);

    // Build a Parameter from a string; defined outside this header.
    static Parameter parse_from_string(const std::string& value);

    // 0=null 1=b 2=i 3=f 4=s 5=ai 6=af 7=as 8=others
    int type;

    // value
    // The scalar members are zero-initialized so that copying a Parameter whose
    // `type` does not select them never reads an indeterminate value.
    bool b = false;
    int i = 0;
    float f = 0.f;
    std::string s;
    std::vector<int> ai;
    std::vector<float> af;
    std::vector<std::string> as;
};

// A typed, shaped blob of raw bytes, stored in Operator::attrs.
class Attribute
{
public:
    Attribute()
        : type(0)
    {
    }

    // Build from an ATen tensor; defined outside this header.
    Attribute(const at::Tensor& t);

    // Build from a shape and float values; defined outside this header.
    // NOTE(review): presumably produces an f32 (type 1) attribute — confirm in
    // the implementation.
    Attribute(const std::initializer_list<int>& shape, const std::vector<float>& t);

    // element type code of `data`
    // 0=null 1=f32 2=f64 3=f16 4=i32 5=i64 6=i16 7=i8 8=u8
    int type;
    std::vector<int> shape;

    // raw element bytes
    std::vector<char> data;
};

class Operator;

// A value flowing through the graph: one `producer` Operator and a list of
// `consumers`. Only Graph can create Operands (private constructor, friend Graph).
class Operand
{
public:
    // Remove `c` from `consumers`; defined outside this header.
    void remove_consumer(const Operator* c);

    std::string name;

    Operator* producer;
    std::vector<Operator*> consumers;

    // 0=null 1=f32 2=f64 3=f16 4=i32 5=i64 6=i16 7=i8 8=u8
    int type;
    std::vector<int> shape;

    std::map<std::string, Parameter> params;

private:
    friend class Graph;
    // Start with no producer and a null element type rather than indeterminate
    // values; Graph may overwrite both after construction.
    Operand()
        : producer(nullptr), type(0)
    {
    }
};

// A node of the graph: an operation of kind `type` that reads `inputs` and
// writes `outputs`. Only Graph can create Operators (private constructor,
// friend Graph).
class Operator
{
public:
    std::string type;
    std::string name;

    std::vector<Operand*> inputs;
    std::vector<Operand*> outputs;

    // NOTE(review): presumably one name per entry of `inputs` — confirm.
    std::vector<std::string> inputnames;
    std::map<std::string, Parameter> params;
    std::map<std::string, Attribute> attrs;

private:
    friend class Graph;
    Operator()
    {
    }
};

// The IR graph: the Operators in `ops` and the Operands in `operands`.
// Every member function is declared here and defined outside this header;
// file formats and return-code conventions are not visible from this file.
class Graph
{
public:
    Graph();
    ~Graph();

    // Serialization / export entry points.
    int load(const std::string& parampath, const std::string& binpath);
    int save(const std::string& parampath, const std::string& binpath);

    int python(const std::string& pypath, const std::string& binpath);

    int ncnn(const std::string& parampath, const std::string& binpath, const std::string& pypath);

    int parse(const std::string& param);

    // Factories: Operator and Operand constructors are private with Graph as a
    // friend, so these are how new instances are obtained.
    Operator* new_operator(const std::string& type, const std::string& name);

    // NOTE(review): presumably inserts the new operator before `cur` in `ops` — confirm.
    Operator* new_operator_before(const std::string& type, const std::string& name, const Operator* cur);

    Operand* new_operand(const torch::jit::Value* v);

    Operand* new_operand(const std::string& name);

    // Look up an operand by name.
    // NOTE(review): behavior when the name is absent is not visible here — confirm.
    Operand* get_operand(const std::string& name);

    std::vector<Operator*> ops;
    std::vector<Operand*> operands;

private:
    // Copying is disabled: copy construction and copy assignment are private.
    Graph(const Graph& rhs);
    Graph& operator=(const Graph& rhs);
};

} // namespace pnnx

#endif // PNNX_IR_H