1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
15 #ifndef PNNX_IR_H
16 #define PNNX_IR_H
17 
18 #include <initializer_list>
19 #include <map>
20 #include <string>
21 #include <vector>
22 
23 namespace torch {
24 namespace jit {
25 struct Value;
26 struct Node;
27 } // namespace jit
28 } // namespace torch
29 namespace at {
30 class Tensor;
31 }
32 
33 namespace pnnx {
34 
35 class Parameter
36 {
37 public:
Parameter()38     Parameter()
39         : type(0)
40     {
41     }
Parameter(bool _b)42     Parameter(bool _b)
43         : type(1), b(_b)
44     {
45     }
Parameter(int _i)46     Parameter(int _i)
47         : type(2), i(_i)
48     {
49     }
Parameter(long _l)50     Parameter(long _l)
51         : type(2), i(_l)
52     {
53     }
Parameter(long long _l)54     Parameter(long long _l)
55         : type(2), i(_l)
56     {
57     }
Parameter(float _f)58     Parameter(float _f)
59         : type(3), f(_f)
60     {
61     }
Parameter(double _d)62     Parameter(double _d)
63         : type(3), f(_d)
64     {
65     }
Parameter(const char * _s)66     Parameter(const char* _s)
67         : type(4), s(_s)
68     {
69     }
Parameter(const std::string & _s)70     Parameter(const std::string& _s)
71         : type(4), s(_s)
72     {
73     }
Parameter(const std::initializer_list<int> & _ai)74     Parameter(const std::initializer_list<int>& _ai)
75         : type(5), ai(_ai)
76     {
77     }
Parameter(const std::initializer_list<int64_t> & _ai)78     Parameter(const std::initializer_list<int64_t>& _ai)
79         : type(5)
80     {
81         for (const auto& x : _ai)
82             ai.push_back((int)x);
83     }
Parameter(const std::vector<int> & _ai)84     Parameter(const std::vector<int>& _ai)
85         : type(5), ai(_ai)
86     {
87     }
Parameter(const std::initializer_list<float> & _af)88     Parameter(const std::initializer_list<float>& _af)
89         : type(6), af(_af)
90     {
91     }
Parameter(const std::initializer_list<double> & _af)92     Parameter(const std::initializer_list<double>& _af)
93         : type(6)
94     {
95         for (const auto& x : _af)
96             af.push_back((float)x);
97     }
Parameter(const std::vector<float> & _af)98     Parameter(const std::vector<float>& _af)
99         : type(6), af(_af)
100     {
101     }
Parameter(const std::initializer_list<const char * > & _as)102     Parameter(const std::initializer_list<const char*>& _as)
103         : type(7)
104     {
105         for (const auto& x : _as)
106             as.push_back(std::string(x));
107     }
Parameter(const std::initializer_list<std::string> & _as)108     Parameter(const std::initializer_list<std::string>& _as)
109         : type(7), as(_as)
110     {
111     }
Parameter(const std::vector<std::string> & _as)112     Parameter(const std::vector<std::string>& _as)
113         : type(7), as(_as)
114     {
115     }
116 
117     Parameter(const torch::jit::Node* value_node);
118     Parameter(const torch::jit::Value* value);
119 
120     static Parameter parse_from_string(const std::string& value);
121 
122     // 0=null 1=b 2=i 3=f 4=s 5=ai 6=af 7=as 8=others
123     int type;
124 
125     // value
126     bool b;
127     int i;
128     float f;
129     std::string s;
130     std::vector<int> ai;
131     std::vector<float> af;
132     std::vector<std::string> as;
133 };
134 
135 class Attribute
136 {
137 public:
Attribute()138     Attribute()
139         : type(0)
140     {
141     }
142 
143     Attribute(const at::Tensor& t);
144 
145     Attribute(const std::initializer_list<int>& shape, const std::vector<float>& t);
146 
147     // 0=null 1=f32 2=f64 3=f16 4=i32 5=i64 6=i16 7=i8 8=u8
148     int type;
149     std::vector<int> shape;
150 
151     std::vector<char> data;
152 };
153 
154 class Operator;
155 class Operand
156 {
157 public:
158     void remove_consumer(const Operator* c);
159 
160     std::string name;
161 
162     Operator* producer;
163     std::vector<Operator*> consumers;
164 
165     // 0=null 1=f32 2=f64 3=f16 4=i32 5=i64 6=i16 7=i8 8=u8
166     int type;
167     std::vector<int> shape;
168 
169     std::map<std::string, Parameter> params;
170 
171 private:
172     friend class Graph;
Operand()173     Operand()
174     {
175     }
176 };
177 
178 class Operator
179 {
180 public:
181     std::string type;
182     std::string name;
183 
184     std::vector<Operand*> inputs;
185     std::vector<Operand*> outputs;
186 
187     std::vector<std::string> inputnames;
188     std::map<std::string, Parameter> params;
189     std::map<std::string, Attribute> attrs;
190 
191 private:
192     friend class Graph;
Operator()193     Operator()
194     {
195     }
196 };
197 
198 class Graph
199 {
200 public:
201     Graph();
202     ~Graph();
203 
204     int load(const std::string& parampath, const std::string& binpath);
205     int save(const std::string& parampath, const std::string& binpath);
206 
207     int python(const std::string& pypath, const std::string& binpath);
208 
209     int ncnn(const std::string& parampath, const std::string& binpath, const std::string& pypath);
210 
211     int parse(const std::string& param);
212 
213     Operator* new_operator(const std::string& type, const std::string& name);
214 
215     Operator* new_operator_before(const std::string& type, const std::string& name, const Operator* cur);
216 
217     Operand* new_operand(const torch::jit::Value* v);
218 
219     Operand* new_operand(const std::string& name);
220 
221     Operand* get_operand(const std::string& name);
222 
223     std::vector<Operator*> ops;
224     std::vector<Operand*> operands;
225 
226 private:
227     Graph(const Graph& rhs);
228     Graph& operator=(const Graph& rhs);
229 };
230 
231 } // namespace pnnx
232 
233 #endif // PNNX_IR_H
234