1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file c_api_symbolic.cc
 * \brief C API related to symbolic graph composition.
23  */
24 #include <nnvm/c_api.h>
25 #include <nnvm/op.h>
26 #include <nnvm/symbolic.h>
27 #include "c_api_common.h"
28 
29 using namespace nnvm;
30 
NNListAllOpNames(nn_uint * out_size,const char *** out_array)31 int NNListAllOpNames(nn_uint *out_size,
32                      const char*** out_array) {
33   API_BEGIN();
34   NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
35   ret->ret_vec_str = dmlc::Registry<Op>::ListAllNames();
36   ret->ret_vec_charp.resize(0);
37   ret->ret_vec_charp.reserve(ret->ret_vec_str.size());
38   for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
39     ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
40   }
41   *out_array = dmlc::BeginPtr(ret->ret_vec_charp);
42   *out_size = static_cast<nn_uint>(ret->ret_vec_str.size());
43   API_END();
44 }
45 
NNGetOpHandle(const char * op_name,OpHandle * op_out)46 int NNGetOpHandle(const char* op_name,
47                   OpHandle* op_out) {
48   API_BEGIN();
49   *op_out = (OpHandle)Op::Get(op_name);  // NOLINT(*)
50   API_END();
51 }
52 
NNListUniqueOps(nn_uint * out_size,OpHandle ** out_array)53 int NNListUniqueOps(nn_uint *out_size,
54                     OpHandle **out_array) {
55   API_BEGIN();
56   auto &vec = dmlc::Registry<Op>::List();
57   *out_size = static_cast<nn_uint>(vec.size());
58   *out_array = (OpHandle*)(dmlc::BeginPtr(vec));  //  NOLINT(*)
59   API_END();
60 }
61 
NNAddControlDeps(SymbolHandle handle,SymbolHandle src_dep)62 int NNAddControlDeps(SymbolHandle handle,
63                      SymbolHandle src_dep) {
64   API_BEGIN();
65   static_cast<Symbol*>(handle)->AddControlDeps(
66       *static_cast<Symbol*>(src_dep));
67   API_END();
68 }
69 
NNGetOpInfo(OpHandle handle,const char ** name,const char ** description,nn_uint * num_doc_args,const char *** arg_names,const char *** arg_type_infos,const char *** arg_descriptions,const char ** return_type)70 int NNGetOpInfo(OpHandle handle,
71                 const char **name,
72                 const char **description,
73                 nn_uint *num_doc_args,
74                 const char ***arg_names,
75                 const char ***arg_type_infos,
76                 const char ***arg_descriptions,
77                 const char **return_type) {
78   const Op *op = static_cast<const Op *>(handle);
79   NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
80 
81   API_BEGIN();
82   *name = op->name.c_str();
83   *description = op->description.c_str();
84   *num_doc_args = static_cast<nn_uint>(op->arguments.size());
85   if (return_type) *return_type = nullptr;
86   ret->ret_vec_charp.resize(0);
87   ret->ret_vec_charp.reserve(op->arguments.size() * 3);
88   for (size_t i = 0; i < op->arguments.size(); ++i) {
89     ret->ret_vec_charp.push_back(op->arguments[i].name.c_str());
90   }
91   for (size_t i = 0; i < op->arguments.size(); ++i) {
92     ret->ret_vec_charp.push_back(op->arguments[i].type_info_str.c_str());
93   }
94   for (size_t i = 0; i < op->arguments.size(); ++i) {
95     ret->ret_vec_charp.push_back(op->arguments[i].description.c_str());
96   }
97   *arg_names = dmlc::BeginPtr(ret->ret_vec_charp);
98   *arg_type_infos = dmlc::BeginPtr(ret->ret_vec_charp) + op->arguments.size();
99   *arg_descriptions = dmlc::BeginPtr(ret->ret_vec_charp) + (op->arguments.size() * 2);
100   API_END();
101 }
102 
NNSymbolCreateAtomicSymbol(OpHandle creator,nn_uint num_param,const char ** keys,const char ** vals,SymbolHandle * out)103 int NNSymbolCreateAtomicSymbol(OpHandle creator,
104                                nn_uint num_param,
105                                const char **keys,
106                                const char **vals,
107                                SymbolHandle *out) {
108   Symbol *s = new Symbol();
109   API_BEGIN();
110   const Op* op = static_cast<const Op*>(creator);
111   std::unordered_map<std::string, std::string> kwargs;
112   for (nn_uint i = 0; i < num_param; ++i) {
113     kwargs.insert({std::string(keys[i]), std::string(vals[i])});
114   }
115   *s = Symbol::CreateFunctor(op, std::move(kwargs));
116   *out = s;
117   API_END_HANDLE_ERROR(delete s;);
118 }
119 
NNSymbolCreateVariable(const char * name,SymbolHandle * out)120 int NNSymbolCreateVariable(const char *name, SymbolHandle *out) {
121   Symbol *s = new Symbol();
122   API_BEGIN();
123   *s = Symbol::CreateVariable(name);
124   *out = s;
125   API_END_HANDLE_ERROR(delete s);
126 }
127 
NNSymbolCreateGroup(nn_uint num_symbols,SymbolHandle * symbols,SymbolHandle * out)128 int NNSymbolCreateGroup(nn_uint num_symbols,
129                         SymbolHandle *symbols,
130                         SymbolHandle *out) {
131   Symbol *s = new Symbol();
132   Symbol **sym_arr = (Symbol**)symbols; // NOLINT(*)
133   API_BEGIN();
134   std::vector<Symbol> syms;
135   for (nn_uint i = 0; i < num_symbols; ++i) {
136     syms.push_back(*sym_arr[i]);
137   }
138   *s = Symbol::CreateGroup(syms);
139   *out = s;
140   API_END_HANDLE_ERROR(delete s);
141 }
142 
NNSymbolGetOutput(SymbolHandle symbol,nn_uint index,SymbolHandle * out)143 int NNSymbolGetOutput(SymbolHandle symbol,
144                       nn_uint index,
145                       SymbolHandle *out) {
146   Symbol *s = new Symbol();
147   API_BEGIN();
148   *s = (*static_cast<Symbol*>(symbol))[index];
149   *out = s;
150   API_END_HANDLE_ERROR(delete s);
151 }
152 
NNSymbolGetInternals(SymbolHandle symbol,SymbolHandle * out)153 int NNSymbolGetInternals(SymbolHandle symbol,
154                          SymbolHandle *out) {
155   Symbol *s = new Symbol();
156   API_BEGIN();
157   *s = static_cast<Symbol*>(symbol)->GetInternals();
158   *out = s;
159   API_END_HANDLE_ERROR(delete s);
160 }
161 
NNSymbolGetChildren(SymbolHandle symbol,SymbolHandle * out)162 int NNSymbolGetChildren(SymbolHandle symbol,
163                         SymbolHandle *out) {
164   Symbol *s = new Symbol();
165   API_BEGIN();
166   *s = static_cast<Symbol*>(symbol)->GetChildren();
167   *out = s;
168   API_END_HANDLE_ERROR(delete s);
169 }
170 
NNSymbolFree(SymbolHandle symbol)171 int NNSymbolFree(SymbolHandle symbol) {
172   API_BEGIN();
173   delete static_cast<Symbol*>(symbol);
174   API_END();
175 }
176 
NNSymbolCopy(SymbolHandle symbol,SymbolHandle * out)177 int NNSymbolCopy(SymbolHandle symbol, SymbolHandle *out) {
178   Symbol *s = new Symbol();
179   API_BEGIN();
180   *s = static_cast<const Symbol*>(symbol)->Copy();
181   *out = s;
182   API_END_HANDLE_ERROR(delete s);
183 }
184 
NNSymbolPrint(SymbolHandle symbol,const char ** out_str)185 int NNSymbolPrint(SymbolHandle symbol, const char **out_str) {
186   Symbol *s = static_cast<Symbol*>(symbol);
187   NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
188   API_BEGIN();
189   std::ostringstream os;
190   s->Print(os);
191   ret->ret_str = os.str();
192   *out_str = (ret->ret_str).c_str();
193   API_END();
194 }
195 
NNSymbolGetAttr(SymbolHandle symbol,const char * key,const char ** out,int * success)196 int NNSymbolGetAttr(SymbolHandle symbol,
197                     const char* key,
198                     const char** out,
199                     int* success) {
200   Symbol *s = static_cast<Symbol*>(symbol);
201   NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
202   API_BEGIN();
203   if (s->GetAttr(key, &(ret->ret_str))) {
204     *out = (ret->ret_str).c_str();
205     *success = 1;
206   } else {
207     *out = nullptr;
208     *success = 0;
209   }
210   API_END();
211 }
212 
NNSymbolSetAttrs(SymbolHandle symbol,nn_uint num_param,const char ** keys,const char ** vals)213 int NNSymbolSetAttrs(SymbolHandle symbol,
214                      nn_uint num_param,
215                      const char** keys,
216                      const char** vals) {
217   Symbol *s = static_cast<Symbol*>(symbol);
218   API_BEGIN();
219   std::vector<std::pair<std::string, std::string> > kwargs;
220   for (nn_uint i = 0; i < num_param; ++i) {
221     kwargs.emplace_back(
222         std::make_pair(std::string(keys[i]), std::string(vals[i])));
223   }
224   s->SetAttrs(kwargs);
225   API_END();
226 }
227 
NNSymbolListAttrs(SymbolHandle symbol,int option,nn_uint * out_size,const char *** out)228 int NNSymbolListAttrs(SymbolHandle symbol,
229                       int option,
230                       nn_uint *out_size,
231                       const char*** out) {
232   Symbol *s = static_cast<Symbol*>(symbol);
233   NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
234   API_BEGIN();
235   std::unordered_map<std::string, std::string> attr =
236       s->ListAttrs(static_cast<Symbol::ListAttrOption>(option));  // NOLINT(*)
237 
238   std::vector<std::string>& attr_list = ret->ret_vec_str;
239   attr_list.resize(0);
240   attr_list.reserve(attr.size());
241   for (const auto& kv : attr) {
242     attr_list.push_back(kv.first);
243     attr_list.push_back(kv.second);
244   }
245   *out_size = attr.size();
246   ret->ret_vec_charp.clear();
247   ret->ret_vec_charp.reserve(ret->ret_vec_str.size());
248   for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
249     ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
250   }
251   *out = dmlc::BeginPtr(ret->ret_vec_charp);
252   API_END();
253 }
254 
NNSymbolListInputVariables(SymbolHandle symbol,int option,nn_uint * out_size,SymbolHandle ** out_sym_array)255 int NNSymbolListInputVariables(SymbolHandle symbol,
256                                int option,
257                                nn_uint *out_size,
258                                SymbolHandle** out_sym_array) {
259   Symbol *s = static_cast<Symbol*>(symbol);
260   NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
261   API_BEGIN();
262   std::vector<NodePtr> vs = s->ListInputs(Symbol::ListInputOption(option));
263   ret->ret_handles.resize(0);
264   ret->ret_handles.reserve(vs.size());
265   for (size_t i = 0; i < vs.size(); ++i) {
266     nnvm::Symbol* rs = new nnvm::Symbol();
267     rs->outputs.push_back(NodeEntry{vs[i], 0, 0});
268     ret->ret_handles.push_back(rs);
269   }
270   *out_size = static_cast<nn_uint>(vs.size());
271   *out_sym_array = dmlc::BeginPtr(ret->ret_handles);
272   API_END();
273 }
274 
NNSymbolListInputNames(SymbolHandle symbol,int option,nn_uint * out_size,const char *** out_str_array)275 int NNSymbolListInputNames(SymbolHandle symbol,
276                            int option,
277                            nn_uint *out_size,
278                            const char ***out_str_array) {
279   Symbol *s = static_cast<Symbol*>(symbol);
280   NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
281   API_BEGIN();
282   ret->ret_vec_str =
283       s->ListInputNames(Symbol::ListInputOption(option));
284   ret->ret_vec_charp.resize(0);
285   ret->ret_vec_charp.reserve(ret->ret_vec_str.size());
286   for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
287     ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
288   }
289   *out_size = static_cast<nn_uint>(ret->ret_vec_charp.size());
290   *out_str_array = dmlc::BeginPtr(ret->ret_vec_charp);
291   API_END();
292 }
293 
NNSymbolListOutputNames(SymbolHandle symbol,nn_uint * out_size,const char *** out_str_array)294 int NNSymbolListOutputNames(SymbolHandle symbol,
295                             nn_uint *out_size,
296                             const char ***out_str_array) {
297   Symbol *s = static_cast<Symbol*>(symbol);
298   NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
299   API_BEGIN();
300   ret->ret_vec_str = s->ListOutputNames();
301   ret->ret_vec_charp.resize(0);
302   ret->ret_vec_charp.reserve(ret->ret_vec_str.size());
303   for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
304     ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
305   }
306   *out_size = static_cast<nn_uint>(ret->ret_vec_charp.size());
307   *out_str_array = dmlc::BeginPtr(ret->ret_vec_charp);
308   API_END();
309 }
310 
NNSymbolGetNumOutputs(SymbolHandle symbol,nn_uint * output_count)311 int NNSymbolGetNumOutputs(SymbolHandle symbol,
312                            nn_uint *output_count) {
313   Symbol *s = static_cast<Symbol*>(symbol);
314   API_BEGIN();
315   *output_count = static_cast<nn_uint>(s->outputs.size());
316   API_END();
317 }
318 
NNSymbolCompose(SymbolHandle sym,const char * name,nn_uint num_args,const char ** keys,SymbolHandle * args)319 int NNSymbolCompose(SymbolHandle sym,
320                     const char *name,
321                     nn_uint num_args,
322                     const char** keys,
323                     SymbolHandle* args) {
324   API_BEGIN();
325   NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
326   std::string& s_name = ret->ret_str;
327   std::unordered_map<std::string, const Symbol*>& kwargs
328       = ret->kwarg_symbol;
329   kwargs.clear();
330   if (name != nullptr) {
331     s_name = name;
332   } else {
333     s_name.clear();
334   }
335   Symbol* s = static_cast<Symbol*>(sym);
336   if (keys == nullptr && num_args != 0) {
337     kwargs.clear();
338     array_view<const Symbol*> parg(
339         (Symbol**)args, (Symbol**)args + num_args); // NOLINT(*)
340     s->Compose(parg, kwargs, s_name);
341   } else {
342     for (nn_uint i = 0; i < num_args; ++i) {
343       kwargs[keys[i]] = (Symbol*)args[i];  //  NOLINT(*)
344     }
345     s->Compose(array_view<const Symbol*>(), kwargs, s_name);
346   }
347   API_END();
348 }
349