1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file c_api_symbolic.cc
 * \brief C API related to symbolic graph composition.
 */
24 #include <nnvm/c_api.h>
25 #include <nnvm/op.h>
26 #include <nnvm/symbolic.h>
27 #include "c_api_common.h"
28
29 using namespace nnvm;
30
NNListAllOpNames(nn_uint * out_size,const char *** out_array)31 int NNListAllOpNames(nn_uint *out_size,
32 const char*** out_array) {
33 API_BEGIN();
34 NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
35 ret->ret_vec_str = dmlc::Registry<Op>::ListAllNames();
36 ret->ret_vec_charp.resize(0);
37 ret->ret_vec_charp.reserve(ret->ret_vec_str.size());
38 for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
39 ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
40 }
41 *out_array = dmlc::BeginPtr(ret->ret_vec_charp);
42 *out_size = static_cast<nn_uint>(ret->ret_vec_str.size());
43 API_END();
44 }
45
NNGetOpHandle(const char * op_name,OpHandle * op_out)46 int NNGetOpHandle(const char* op_name,
47 OpHandle* op_out) {
48 API_BEGIN();
49 *op_out = (OpHandle)Op::Get(op_name); // NOLINT(*)
50 API_END();
51 }
52
NNListUniqueOps(nn_uint * out_size,OpHandle ** out_array)53 int NNListUniqueOps(nn_uint *out_size,
54 OpHandle **out_array) {
55 API_BEGIN();
56 auto &vec = dmlc::Registry<Op>::List();
57 *out_size = static_cast<nn_uint>(vec.size());
58 *out_array = (OpHandle*)(dmlc::BeginPtr(vec)); // NOLINT(*)
59 API_END();
60 }
61
NNAddControlDeps(SymbolHandle handle,SymbolHandle src_dep)62 int NNAddControlDeps(SymbolHandle handle,
63 SymbolHandle src_dep) {
64 API_BEGIN();
65 static_cast<Symbol*>(handle)->AddControlDeps(
66 *static_cast<Symbol*>(src_dep));
67 API_END();
68 }
69
NNGetOpInfo(OpHandle handle,const char ** name,const char ** description,nn_uint * num_doc_args,const char *** arg_names,const char *** arg_type_infos,const char *** arg_descriptions,const char ** return_type)70 int NNGetOpInfo(OpHandle handle,
71 const char **name,
72 const char **description,
73 nn_uint *num_doc_args,
74 const char ***arg_names,
75 const char ***arg_type_infos,
76 const char ***arg_descriptions,
77 const char **return_type) {
78 const Op *op = static_cast<const Op *>(handle);
79 NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
80
81 API_BEGIN();
82 *name = op->name.c_str();
83 *description = op->description.c_str();
84 *num_doc_args = static_cast<nn_uint>(op->arguments.size());
85 if (return_type) *return_type = nullptr;
86 ret->ret_vec_charp.resize(0);
87 ret->ret_vec_charp.reserve(op->arguments.size() * 3);
88 for (size_t i = 0; i < op->arguments.size(); ++i) {
89 ret->ret_vec_charp.push_back(op->arguments[i].name.c_str());
90 }
91 for (size_t i = 0; i < op->arguments.size(); ++i) {
92 ret->ret_vec_charp.push_back(op->arguments[i].type_info_str.c_str());
93 }
94 for (size_t i = 0; i < op->arguments.size(); ++i) {
95 ret->ret_vec_charp.push_back(op->arguments[i].description.c_str());
96 }
97 *arg_names = dmlc::BeginPtr(ret->ret_vec_charp);
98 *arg_type_infos = dmlc::BeginPtr(ret->ret_vec_charp) + op->arguments.size();
99 *arg_descriptions = dmlc::BeginPtr(ret->ret_vec_charp) + (op->arguments.size() * 2);
100 API_END();
101 }
102
NNSymbolCreateAtomicSymbol(OpHandle creator,nn_uint num_param,const char ** keys,const char ** vals,SymbolHandle * out)103 int NNSymbolCreateAtomicSymbol(OpHandle creator,
104 nn_uint num_param,
105 const char **keys,
106 const char **vals,
107 SymbolHandle *out) {
108 Symbol *s = new Symbol();
109 API_BEGIN();
110 const Op* op = static_cast<const Op*>(creator);
111 std::unordered_map<std::string, std::string> kwargs;
112 for (nn_uint i = 0; i < num_param; ++i) {
113 kwargs.insert({std::string(keys[i]), std::string(vals[i])});
114 }
115 *s = Symbol::CreateFunctor(op, std::move(kwargs));
116 *out = s;
117 API_END_HANDLE_ERROR(delete s;);
118 }
119
NNSymbolCreateVariable(const char * name,SymbolHandle * out)120 int NNSymbolCreateVariable(const char *name, SymbolHandle *out) {
121 Symbol *s = new Symbol();
122 API_BEGIN();
123 *s = Symbol::CreateVariable(name);
124 *out = s;
125 API_END_HANDLE_ERROR(delete s);
126 }
127
NNSymbolCreateGroup(nn_uint num_symbols,SymbolHandle * symbols,SymbolHandle * out)128 int NNSymbolCreateGroup(nn_uint num_symbols,
129 SymbolHandle *symbols,
130 SymbolHandle *out) {
131 Symbol *s = new Symbol();
132 Symbol **sym_arr = (Symbol**)symbols; // NOLINT(*)
133 API_BEGIN();
134 std::vector<Symbol> syms;
135 for (nn_uint i = 0; i < num_symbols; ++i) {
136 syms.push_back(*sym_arr[i]);
137 }
138 *s = Symbol::CreateGroup(syms);
139 *out = s;
140 API_END_HANDLE_ERROR(delete s);
141 }
142
NNSymbolGetOutput(SymbolHandle symbol,nn_uint index,SymbolHandle * out)143 int NNSymbolGetOutput(SymbolHandle symbol,
144 nn_uint index,
145 SymbolHandle *out) {
146 Symbol *s = new Symbol();
147 API_BEGIN();
148 *s = (*static_cast<Symbol*>(symbol))[index];
149 *out = s;
150 API_END_HANDLE_ERROR(delete s);
151 }
152
NNSymbolGetInternals(SymbolHandle symbol,SymbolHandle * out)153 int NNSymbolGetInternals(SymbolHandle symbol,
154 SymbolHandle *out) {
155 Symbol *s = new Symbol();
156 API_BEGIN();
157 *s = static_cast<Symbol*>(symbol)->GetInternals();
158 *out = s;
159 API_END_HANDLE_ERROR(delete s);
160 }
161
NNSymbolGetChildren(SymbolHandle symbol,SymbolHandle * out)162 int NNSymbolGetChildren(SymbolHandle symbol,
163 SymbolHandle *out) {
164 Symbol *s = new Symbol();
165 API_BEGIN();
166 *s = static_cast<Symbol*>(symbol)->GetChildren();
167 *out = s;
168 API_END_HANDLE_ERROR(delete s);
169 }
170
NNSymbolFree(SymbolHandle symbol)171 int NNSymbolFree(SymbolHandle symbol) {
172 API_BEGIN();
173 delete static_cast<Symbol*>(symbol);
174 API_END();
175 }
176
NNSymbolCopy(SymbolHandle symbol,SymbolHandle * out)177 int NNSymbolCopy(SymbolHandle symbol, SymbolHandle *out) {
178 Symbol *s = new Symbol();
179 API_BEGIN();
180 *s = static_cast<const Symbol*>(symbol)->Copy();
181 *out = s;
182 API_END_HANDLE_ERROR(delete s);
183 }
184
NNSymbolPrint(SymbolHandle symbol,const char ** out_str)185 int NNSymbolPrint(SymbolHandle symbol, const char **out_str) {
186 Symbol *s = static_cast<Symbol*>(symbol);
187 NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
188 API_BEGIN();
189 std::ostringstream os;
190 s->Print(os);
191 ret->ret_str = os.str();
192 *out_str = (ret->ret_str).c_str();
193 API_END();
194 }
195
NNSymbolGetAttr(SymbolHandle symbol,const char * key,const char ** out,int * success)196 int NNSymbolGetAttr(SymbolHandle symbol,
197 const char* key,
198 const char** out,
199 int* success) {
200 Symbol *s = static_cast<Symbol*>(symbol);
201 NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
202 API_BEGIN();
203 if (s->GetAttr(key, &(ret->ret_str))) {
204 *out = (ret->ret_str).c_str();
205 *success = 1;
206 } else {
207 *out = nullptr;
208 *success = 0;
209 }
210 API_END();
211 }
212
NNSymbolSetAttrs(SymbolHandle symbol,nn_uint num_param,const char ** keys,const char ** vals)213 int NNSymbolSetAttrs(SymbolHandle symbol,
214 nn_uint num_param,
215 const char** keys,
216 const char** vals) {
217 Symbol *s = static_cast<Symbol*>(symbol);
218 API_BEGIN();
219 std::vector<std::pair<std::string, std::string> > kwargs;
220 for (nn_uint i = 0; i < num_param; ++i) {
221 kwargs.emplace_back(
222 std::make_pair(std::string(keys[i]), std::string(vals[i])));
223 }
224 s->SetAttrs(kwargs);
225 API_END();
226 }
227
NNSymbolListAttrs(SymbolHandle symbol,int option,nn_uint * out_size,const char *** out)228 int NNSymbolListAttrs(SymbolHandle symbol,
229 int option,
230 nn_uint *out_size,
231 const char*** out) {
232 Symbol *s = static_cast<Symbol*>(symbol);
233 NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
234 API_BEGIN();
235 std::unordered_map<std::string, std::string> attr =
236 s->ListAttrs(static_cast<Symbol::ListAttrOption>(option)); // NOLINT(*)
237
238 std::vector<std::string>& attr_list = ret->ret_vec_str;
239 attr_list.resize(0);
240 attr_list.reserve(attr.size());
241 for (const auto& kv : attr) {
242 attr_list.push_back(kv.first);
243 attr_list.push_back(kv.second);
244 }
245 *out_size = attr.size();
246 ret->ret_vec_charp.clear();
247 ret->ret_vec_charp.reserve(ret->ret_vec_str.size());
248 for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
249 ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
250 }
251 *out = dmlc::BeginPtr(ret->ret_vec_charp);
252 API_END();
253 }
254
NNSymbolListInputVariables(SymbolHandle symbol,int option,nn_uint * out_size,SymbolHandle ** out_sym_array)255 int NNSymbolListInputVariables(SymbolHandle symbol,
256 int option,
257 nn_uint *out_size,
258 SymbolHandle** out_sym_array) {
259 Symbol *s = static_cast<Symbol*>(symbol);
260 NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
261 API_BEGIN();
262 std::vector<NodePtr> vs = s->ListInputs(Symbol::ListInputOption(option));
263 ret->ret_handles.resize(0);
264 ret->ret_handles.reserve(vs.size());
265 for (size_t i = 0; i < vs.size(); ++i) {
266 nnvm::Symbol* rs = new nnvm::Symbol();
267 rs->outputs.push_back(NodeEntry{vs[i], 0, 0});
268 ret->ret_handles.push_back(rs);
269 }
270 *out_size = static_cast<nn_uint>(vs.size());
271 *out_sym_array = dmlc::BeginPtr(ret->ret_handles);
272 API_END();
273 }
274
NNSymbolListInputNames(SymbolHandle symbol,int option,nn_uint * out_size,const char *** out_str_array)275 int NNSymbolListInputNames(SymbolHandle symbol,
276 int option,
277 nn_uint *out_size,
278 const char ***out_str_array) {
279 Symbol *s = static_cast<Symbol*>(symbol);
280 NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
281 API_BEGIN();
282 ret->ret_vec_str =
283 s->ListInputNames(Symbol::ListInputOption(option));
284 ret->ret_vec_charp.resize(0);
285 ret->ret_vec_charp.reserve(ret->ret_vec_str.size());
286 for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
287 ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
288 }
289 *out_size = static_cast<nn_uint>(ret->ret_vec_charp.size());
290 *out_str_array = dmlc::BeginPtr(ret->ret_vec_charp);
291 API_END();
292 }
293
NNSymbolListOutputNames(SymbolHandle symbol,nn_uint * out_size,const char *** out_str_array)294 int NNSymbolListOutputNames(SymbolHandle symbol,
295 nn_uint *out_size,
296 const char ***out_str_array) {
297 Symbol *s = static_cast<Symbol*>(symbol);
298 NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
299 API_BEGIN();
300 ret->ret_vec_str = s->ListOutputNames();
301 ret->ret_vec_charp.resize(0);
302 ret->ret_vec_charp.reserve(ret->ret_vec_str.size());
303 for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
304 ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
305 }
306 *out_size = static_cast<nn_uint>(ret->ret_vec_charp.size());
307 *out_str_array = dmlc::BeginPtr(ret->ret_vec_charp);
308 API_END();
309 }
310
NNSymbolGetNumOutputs(SymbolHandle symbol,nn_uint * output_count)311 int NNSymbolGetNumOutputs(SymbolHandle symbol,
312 nn_uint *output_count) {
313 Symbol *s = static_cast<Symbol*>(symbol);
314 API_BEGIN();
315 *output_count = static_cast<nn_uint>(s->outputs.size());
316 API_END();
317 }
318
NNSymbolCompose(SymbolHandle sym,const char * name,nn_uint num_args,const char ** keys,SymbolHandle * args)319 int NNSymbolCompose(SymbolHandle sym,
320 const char *name,
321 nn_uint num_args,
322 const char** keys,
323 SymbolHandle* args) {
324 API_BEGIN();
325 NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
326 std::string& s_name = ret->ret_str;
327 std::unordered_map<std::string, const Symbol*>& kwargs
328 = ret->kwarg_symbol;
329 kwargs.clear();
330 if (name != nullptr) {
331 s_name = name;
332 } else {
333 s_name.clear();
334 }
335 Symbol* s = static_cast<Symbol*>(sym);
336 if (keys == nullptr && num_args != 0) {
337 kwargs.clear();
338 array_view<const Symbol*> parg(
339 (Symbol**)args, (Symbol**)args + num_args); // NOLINT(*)
340 s->Compose(parg, kwargs, s_name);
341 } else {
342 for (nn_uint i = 0; i < num_args; ++i) {
343 kwargs[keys[i]] = (Symbol*)args[i]; // NOLINT(*)
344 }
345 s->Compose(array_view<const Symbol*>(), kwargs, s_name);
346 }
347 API_END();
348 }
349