1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file c_api_symbolic.cc
22  * \brief C API of mxnet
23  */
24 #include "mxnet/base.h"
25 #include "mxnet/c_api.h"
26 #include "mxnet/imperative.h"
27 #include "nnvm/c_api.h"
28 #include "nnvm/pass.h"
29 #include "nnvm/pass_functions.h"
30 #include "nnvm/symbolic.h"
31 #include "./c_api_common.h"
32 #include "../common/exec_utils.h"
33 #include "../operator/operator_common.h"
34 #include "../executor/exec_pass.h"
35 #include "../operator/subgraph/subgraph_property.h"
36 
37 namespace mxnet {
// Forward declarations of the legacy registration hooks that the op-listing
// entry points in this file call before querying NNVM.
namespace op {
void RegisterLegacyOpProp();
void RegisterLegacyNDFunc();
}  // namespace op
// Attribute names that are stored on symbols in their "__name__" form.
const std::vector<std::string> kHiddenKeys = {
  "ctx_group",
  "lr_mult",
  "wd_mult",
  "force_mirroring",
  "mirror_stage"
};
// The stored ("__name__") spelling of every entry of kHiddenKeys, in the same order.
const std::vector<std::string> kReplacedHiddenKeys = {
  "__ctx_group__",
  "__lr_mult__",
  "__wd_mult__",
  "__force_mirroring__",
  "__mirror_stage__"
};
// Separator placed between a node name and an attribute key when attributes are listed.
const char *kNamespaceSeparator = "$";
49 
50 
51 DMLC_JSON_ENABLE_ANY(int, int);
52 
53 // convert nnvm symbol to a nnvm graph.
Symbol2Graph(const nnvm::Symbol & s)54 nnvm::Graph Symbol2Graph(const nnvm::Symbol &s) {
55   nnvm::Graph g;
56   g.outputs = s.outputs;
57   g.attrs["mxnet_version"] = std::make_shared<nnvm::any>(static_cast<int>(MXNET_VERSION));
58   if (Imperative::Get()->is_np_shape()) {
59     g.attrs["is_np_shape"] = std::make_shared<nnvm::any>(
60         static_cast<int>(Imperative::Get()->is_np_shape()));
61   }
62   return g;
63 }
64 
ReadOnlyArgIndices(const nnvm::IndexedGraph & idx)65 std::vector<uint32_t> ReadOnlyArgIndices(const nnvm::IndexedGraph& idx) {
66   std::vector<uint32_t> ret;
67   auto& arg_nodes = idx.input_nodes();
68   for (uint32_t i = 0; i < arg_nodes.size(); ++i) {
69     if (idx.mutable_input_nodes().count(arg_nodes[i]) == 0) {
70       ret.push_back(i);
71     }
72   }
73   return ret;
74 }
75 
76 }  // namespace mxnet
77 
78 // symbolic configuration generation API.
79 // Redirect to NNVM's C API
MXListAllOpNames(nn_uint * out_size,const char *** out_array)80 int MXListAllOpNames(nn_uint *out_size,
81                      const char ***out_array) {
82   mxnet::op::RegisterLegacyOpProp();
83   mxnet::op::RegisterLegacyNDFunc();
84   return NNListAllOpNames(out_size, out_array);
85 }
86 
MXSymbolListAtomicSymbolCreators(uint32_t * out_size,AtomicSymbolCreator ** out_array)87 int MXSymbolListAtomicSymbolCreators(uint32_t *out_size,
88                                      AtomicSymbolCreator **out_array) {
89   mxnet::op::RegisterLegacyOpProp();
90   mxnet::op::RegisterLegacyNDFunc();
91   return NNListUniqueOps(out_size, out_array);
92 }
93 
MXSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,const char ** name,const char ** description,uint32_t * num_args,const char *** arg_names,const char *** arg_type_infos,const char *** arg_descriptions,const char ** key_var_num_args,const char ** return_type)94 int MXSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
95                                 const char **name,
96                                 const char **description,
97                                 uint32_t *num_args,
98                                 const char ***arg_names,
99                                 const char ***arg_type_infos,
100                                 const char ***arg_descriptions,
101                                 const char **key_var_num_args,
102                                 const char **return_type) {
103   static auto& map_key_var_args = nnvm::Op::GetAttr<std::string>("key_var_num_args");
104   const Op* op = static_cast<Op*>(creator);
105   MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
106   ret->ret_str.resize(0);
107 
108   if (map_key_var_args.count(op) != 0) {
109     *key_var_num_args = map_key_var_args[op].c_str();
110   } else {
111     *key_var_num_args = ret->ret_str.c_str();
112   }
113   return NNGetOpInfo(
114       creator, name, description,
115       num_args, arg_names, arg_type_infos,
116       arg_descriptions, return_type);
117 }
118 
MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,uint32_t num_param,const char ** keys,const char ** vals,SymbolHandle * out)119 int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,
120                                uint32_t num_param,
121                                const char **keys,
122                                const char **vals,
123                                SymbolHandle *out) {
124   nnvm::Symbol *s = new nnvm::Symbol();
125   API_BEGIN();
126   const nnvm::Op* op = static_cast<const nnvm::Op*>(creator);
127   std::unordered_map<std::string, std::string> kwargs;
128   for (nn_uint i = 0; i < num_param; ++i) {
129     bool flag = false;
130     for (const auto &k : kHiddenKeys) {
131       std::string tmp(keys[i]);
132       size_t pos = tmp.rfind(k);
133       if (pos == 0) {
134         kwargs.insert({"__" + tmp + "__", std::string(vals[i])});
135         flag = true;
136         break;
137       } else if (pos != std::string::npos && pos == tmp.length() - k.length()) {
138         std::ostringstream os;
139         os << "setting variable attributes with " << keys[i] << " is deprecated. "
140            << "please instead use\nw = Variable(" << k << "=" << vals[i] << ")\n"
141            << "sym = YourSymbolName(" << tmp.substr(0, pos-1) << "=w)";
142         throw dmlc::Error(os.str());
143       }
144     }
145     if (!flag)
146       kwargs.insert({std::string(keys[i]), std::string(vals[i])});
147   }
148   *s = nnvm::Symbol::CreateFunctor(op, std::move(kwargs));
149   *out = s;
150   API_END_HANDLE_ERROR(delete s;);
151 }
152 
MXSymbolCreateVariable(const char * name,SymbolHandle * out)153 int MXSymbolCreateVariable(const char *name, SymbolHandle *out) {
154   return NNSymbolCreateVariable(name, out);
155 }
156 
MXSymbolCreateGroup(uint32_t num_symbols,SymbolHandle * symbols,SymbolHandle * out)157 int MXSymbolCreateGroup(uint32_t num_symbols,
158                         SymbolHandle *symbols,
159                         SymbolHandle *out) {
160   return NNSymbolCreateGroup(num_symbols, symbols, out);
161 }
162 
MXSymbolGetOutput(SymbolHandle symbol,uint32_t index,SymbolHandle * out)163 int MXSymbolGetOutput(SymbolHandle symbol,
164                       uint32_t index,
165                       SymbolHandle *out) {
166   return NNSymbolGetOutput(symbol, index, out);
167 }
168 
MXSymbolGetInternals(SymbolHandle symbol,SymbolHandle * out)169 int MXSymbolGetInternals(SymbolHandle symbol,
170                          SymbolHandle *out) {
171   nnvm::Symbol *s = new nnvm::Symbol();
172   API_BEGIN();
173   *s = static_cast<nnvm::Symbol*>(symbol)->GetInternals();
174   *out = s;
175   API_END_HANDLE_ERROR(delete s);
176 }
177 
MXSymbolGetChildren(SymbolHandle symbol,SymbolHandle * out)178 int MXSymbolGetChildren(SymbolHandle symbol,
179                         SymbolHandle *out) {
180   nnvm::Symbol *s = new nnvm::Symbol();
181   API_BEGIN();
182   *s = static_cast<nnvm::Symbol*>(symbol)->GetChildren();
183   *out = s;
184   API_END_HANDLE_ERROR(delete s);
185 }
186 
MXSymbolFree(SymbolHandle symbol)187 int MXSymbolFree(SymbolHandle symbol) {
188   return NNSymbolFree(symbol);
189 }
190 
MXSymbolCopy(SymbolHandle symbol,SymbolHandle * out)191 int MXSymbolCopy(SymbolHandle symbol, SymbolHandle *out) {
192   return NNSymbolCopy(symbol, out);
193 }
194 
MXSymbolPrint(SymbolHandle symbol,const char ** out_str)195 int MXSymbolPrint(SymbolHandle symbol, const char **out_str) {
196   return NNSymbolPrint(symbol, out_str);
197 }
198 
MXSymbolGetName(SymbolHandle symbol,const char ** out,int * success)199 int MXSymbolGetName(SymbolHandle symbol,
200                     const char** out,
201                     int* success) {
202   return NNSymbolGetAttr(symbol, "name", out, success);
203 }
204 
MXSymbolGetAttr(SymbolHandle symbol,const char * key,const char ** out,int * success)205 int MXSymbolGetAttr(SymbolHandle symbol,
206                     const char* key,
207                     const char** out,
208                     int* success) {
209   nnvm::Symbol *s = static_cast<nnvm::Symbol*>(symbol);
210   MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
211   API_BEGIN();
212   if (s->GetAttr(key, &(ret->ret_str))) {
213     *out = (ret->ret_str).c_str();
214     *success = 1;
215   } else {
216     *out = nullptr;
217     *success = 0;
218     if (std::find(kHiddenKeys.begin(), kHiddenKeys.end(), key) != kHiddenKeys.end()) {
219       std::string skey = "__" + std::string(key) + "__";
220       if (s->GetAttr(skey, &(ret->ret_str))) {
221         *out = (ret->ret_str).c_str();
222         *success = 1;
223       }
224     }
225   }
226   API_END();
227 }
228 
MXSymbolSetAttr(SymbolHandle symbol,const char * key,const char * value)229 int MXSymbolSetAttr(SymbolHandle symbol,
230                     const char* key,
231                     const char* value) {
232   nnvm::Symbol *s = static_cast<nnvm::Symbol*>(symbol);
233   API_BEGIN();
234   std::vector<std::pair<std::string, std::string> > kwargs;
235   std::string skey(key), sval(value);
236   for (const auto &k : kHiddenKeys) {
237     size_t pos = skey.rfind(k);
238     if (pos == 0 && k.length() == skey.length()) {
239       skey = "__" + skey + "__";
240       break;
241     } else if (pos != std::string::npos && pos + k.length() == skey.length()) {
242       std::ostringstream os;
243       os << "setting variable attributes with " << key << " is deprecated. "
244          << "please instead use\nw = Variable(" << k << "=" << value << ")\n"
245          << "sym = YourSymbolName(" << skey.substr(0, pos-1) << "=w)";
246       throw dmlc::Error(os.str());
247     }
248   }
249   kwargs.emplace_back(std::make_pair(std::move(skey), std::move(sval)));
250   s->SetAttrs(kwargs);
251   API_END();
252 }
253 
MXSymbolListAttr(SymbolHandle symbol,uint32_t * out_size,const char *** out)254 int MXSymbolListAttr(SymbolHandle symbol,
255                      uint32_t *out_size,
256                      const char*** out) {
257   nnvm::Symbol *s = static_cast<nnvm::Symbol*>(symbol);
258   MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
259   API_BEGIN();
260   std::vector<std::tuple<std::string, std::string, std::string> > attr =
261       s->ListAttrsRecursive();
262 
263   std::vector<std::string>& attr_list = ret->ret_vec_str;
264   attr_list.clear();
265   for (const auto& tp : attr) {
266     attr_list.emplace_back(std::get<0>(tp) + kNamespaceSeparator + std::get<1>(tp));
267     attr_list.emplace_back(std::get<2>(tp));
268     if (find(kReplacedHiddenKeys.begin(), kReplacedHiddenKeys.end(), std::get<1>(tp))
269           != kReplacedHiddenKeys.end()) {
270       attr_list.push_back(std::get<0>(tp) + kNamespaceSeparator +
271                           std::get<1>(tp).substr(2, std::get<1>(tp).length() - 4));
272       attr_list.push_back(std::get<2>(tp));
273     }
274   }
275   *out_size = attr_list.size()/2;
276   ret->ret_vec_charp.clear();
277   for (const auto& attr : attr_list) {
278     ret->ret_vec_charp.push_back(attr.c_str());
279   }
280   *out = dmlc::BeginPtr(ret->ret_vec_charp);
281   API_END();
282 }
283 
MXSymbolListAttrShallow(SymbolHandle symbol,uint32_t * out_size,const char *** out)284 int MXSymbolListAttrShallow(SymbolHandle symbol,
285                             uint32_t *out_size,
286                             const char*** out) {
287   nnvm::Symbol *s = static_cast<nnvm::Symbol*>(symbol);
288   MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
289   API_BEGIN();
290   std::unordered_map<std::string, std::string> attr =
291       s->ListAttrs(static_cast<nnvm::Symbol::ListAttrOption>(1));  // NOLINT(*)
292 
293   std::vector<std::string>& attr_list = ret->ret_vec_str;
294   attr_list.clear();
295   for (const auto& kv : attr) {
296     attr_list.push_back(kv.first);
297     attr_list.push_back(kv.second);
298     if (find(kReplacedHiddenKeys.begin(), kReplacedHiddenKeys.end(), kv.first)
299           != kReplacedHiddenKeys.end()) {
300       attr_list.push_back(kv.first.substr(2, kv.first.length() - 4));
301       attr_list.push_back(kv.second);
302     }
303   }
304   *out_size = attr_list.size()/2;
305   ret->ret_vec_charp.clear();
306   for (auto &attr : attr_list) {
307     ret->ret_vec_charp.push_back(attr.c_str());
308   }
309   *out = dmlc::BeginPtr(ret->ret_vec_charp);
310   API_END();
311 }
312 
MXSymbolListOutputs(SymbolHandle symbol,uint32_t * out_size,const char *** out_str_array)313 int MXSymbolListOutputs(SymbolHandle symbol,
314                         uint32_t *out_size,
315                         const char ***out_str_array) {
316   return NNSymbolListOutputNames(symbol, out_size, out_str_array);
317 }
318 
MXSymbolGetNumOutputs(SymbolHandle symbol,uint32_t * output_count)319 int MXSymbolGetNumOutputs(SymbolHandle symbol,
320                            uint32_t *output_count) {
321   return NNSymbolGetNumOutputs(symbol, output_count);
322 }
323 
MXSymbolCompose(SymbolHandle sym,const char * name,uint32_t num_args,const char ** keys,SymbolHandle * args)324 int MXSymbolCompose(SymbolHandle sym,
325                     const char *name,
326                     uint32_t num_args,
327                     const char** keys,
328                     SymbolHandle* args) {
329   return NNSymbolCompose(sym, name, num_args, keys, args);
330 }
331 
332 // adapter functions that re-implements the functions.
MXSymbolListArguments(SymbolHandle symbol,uint32_t * out_size,const char *** out_str_array)333 int MXSymbolListArguments(SymbolHandle symbol,
334                           uint32_t *out_size,
335                           const char ***out_str_array) {
336   return NNSymbolListInputNames(symbol, 1, out_size, out_str_array);
337 }
338 
MXSymbolListAuxiliaryStates(SymbolHandle symbol,uint32_t * out_size,const char *** out_str_array)339 int MXSymbolListAuxiliaryStates(SymbolHandle symbol,
340                                 uint32_t *out_size,
341                                 const char ***out_str_array) {
342   return NNSymbolListInputNames(symbol, 2, out_size, out_str_array);
343 }
344 
MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator,const char ** out)345 int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator,
346                                 const char **out) {
347   API_BEGIN();
348   Op *e = static_cast<Op *>(creator);
349   *out = e->name.c_str();
350   API_END();
351 }
352 
353 namespace mxnet {
354 
355 extern std::vector<nnvm::Symbol *> GetInputSymbols(const nnvm::Symbol &sym);
356 extern bool CutGraphInputs(const std::vector<nnvm::NodeEntry *> &input_entries,
357                            bool skip_var, std::vector<nnvm::NodeEntry> *orig_entries);
358 
359 }
360 
MXSymbolGetInputSymbols(SymbolHandle sym,SymbolHandle ** input_arr,int * input_size)361 int MXSymbolGetInputSymbols(SymbolHandle sym, SymbolHandle **input_arr, int *input_size) {
362   API_BEGIN();
363   nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
364   std::vector<nnvm::Symbol *> input_syms = mxnet::GetInputSymbols(*s);
365   *input_size = input_syms.size();
366 
367   MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
368   ret->ret_handles.clear();
369   ret->ret_handles.reserve(*input_size);
370   for (int i = 0; i < *input_size; ++i) ret->ret_handles.push_back(input_syms[i]);
371   *input_arr = reinterpret_cast<SymbolHandle*>(dmlc::BeginPtr(ret->ret_handles));
372   API_END_HANDLE_ERROR();
373 }
374 
MXSymbolCutSubgraph(SymbolHandle sym,SymbolHandle ** input_symbols,int * input_size)375 int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle **input_symbols,
376                         int *input_size) {
377   // Given a graph, we want to fetch the nodes that have been marked as part of
378   // a subgraph.
379   API_BEGIN();
380   nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
381   const std::string subg_attr = "__subgraph_name__";
382   auto out_node = s->outputs[0].node;
383   auto it = out_node->attrs.dict.find(subg_attr);
384   if (it != out_node->attrs.dict.end()) {
385     const std::string &subg_name = it->second;
386     std::vector<nnvm::NodeEntry *> input_entries;
387     DFSVisit(s->outputs, [&subg_attr, &subg_name, &input_entries]
388              (nnvm::ObjectPtr n) {
389       // If the node itself isn't in the subgraph, we ignore it.
390       auto it = n->attrs.dict.find(subg_attr);
391       if (it == n->attrs.dict.end() || it->second != subg_name)
392         return;
393 
394       // We search for nodes whose node entries aren't in the subgraph.
395       for (size_t j = 0; j < n->inputs.size(); j++) {
396         auto in_node = n->inputs[j].node;
397         auto it = in_node->attrs.dict.find(subg_attr);
398         if (it == in_node->attrs.dict.end() || it->second != subg_name)
399           input_entries.push_back(&n->inputs[j]);
400       }
401     });
402 
403     std::vector<nnvm::NodeEntry> orig_entries;
404     CutGraphInputs(input_entries, false, &orig_entries);
405     std::vector<nnvm::Symbol *> input_syms(orig_entries.size());
406     for (size_t i = 0; i < input_syms.size(); i++) {
407       input_syms[i] = new nnvm::Symbol();
408       input_syms[i]->outputs.push_back(orig_entries[i]);
409     }
410     *input_size = input_syms.size();
411 
412     MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
413     ret->ret_handles.clear();
414     ret->ret_handles.reserve(*input_size);
415     for (int i = 0; i < *input_size; ++i) ret->ret_handles.push_back(input_syms[i]);
416     *input_symbols = reinterpret_cast<SymbolHandle*>(dmlc::BeginPtr(ret->ret_handles));
417   } else {
418     *input_size = 0;
419   }
420 
421   API_END_HANDLE_ERROR();
422 }
423 
424 
425 /*!
426  * \brief Convert shape attr in graph nodes to comply with NumPy semantics for
427  * legacy models (before 1.6.0) if global flag is_np_shape has been turned on,
428  * i.e., use -1 to indicate unknown number of dimensions and unknown dimension sizes.
429  */
ConvertShapeAttrToNumPyCompatible(nnvm::Graph * g)430 void ConvertShapeAttrToNumPyCompatible(nnvm::Graph* g) {
431   if (Imperative::Get()->is_np_shape()
432     && (!g->HasAttr("is_np_shape") || !g->GetAttr<int>("is_np_shape"))) {
433     DFSVisit(g->outputs, [](nnvm::ObjectPtr n) {
434       if (n->is_variable()) {
435         auto it = n->attrs.dict.find("__shape__");
436         if (it != n->attrs.dict.end()) {
437           mxnet::TShape shape;
438           std::istringstream is(it->second);
439           is >> shape;
440           common::ConvertToNumpyShape(&shape);
441           std::ostringstream os;
442           os << shape;
443           it->second = os.str();
444         }
445       }
446     });
447   }
448 }
449 
MXSymbolCreateFromFile(const char * fname,SymbolHandle * out)450 int MXSymbolCreateFromFile(const char *fname, SymbolHandle *out) {
451   nnvm::Symbol *s = new nnvm::Symbol();
452   API_BEGIN();
453   std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
454   dmlc::istream is(fi.get());
455   nnvm::Graph g;
456   g.attrs["json"] = std::make_shared<nnvm::any>(
457     std::string(std::istreambuf_iterator<char>(is), std::istreambuf_iterator<char>()));
458   g = nnvm::ApplyPass(g, "LoadLegacyJSON");
459   ConvertShapeAttrToNumPyCompatible(&g);
460   s->outputs = g.outputs;
461   *out = s;
462   is.set_stream(nullptr);
463   API_END_HANDLE_ERROR(delete s);
464 }
465 
MXSymbolCreateFromJSON(const char * json,SymbolHandle * out)466 int MXSymbolCreateFromJSON(const char *json, SymbolHandle *out) {
467   nnvm::Symbol *s = new nnvm::Symbol();
468   API_BEGIN();
469   nnvm::Graph g;
470   g.attrs["json"] = std::make_shared<nnvm::any>(std::string(json));
471   g = nnvm::ApplyPass(g, "LoadLegacyJSON");
472   ConvertShapeAttrToNumPyCompatible(&g);
473   s->outputs = g.outputs;
474   *out = s;
475   API_END_HANDLE_ERROR(delete s);
476 }
477 
MXSymbolRemoveAmpCast(SymbolHandle sym_handle,SymbolHandle * ret_sym_handle)478 int MXSymbolRemoveAmpCast(SymbolHandle sym_handle, SymbolHandle* ret_sym_handle) {
479   nnvm::Symbol* s = new nnvm::Symbol();
480   API_BEGIN();
481   nnvm::Symbol *source = static_cast<nnvm::Symbol*>(sym_handle);
482   *s = source->Copy();
483   s->outputs = nnvm::ApplyPass(Symbol2Graph(*s), "RemoveAmpCast").outputs;
484   *ret_sym_handle = s;
485   API_END_HANDLE_ERROR(delete s);
486 }
487 
MXSymbolSaveToFile(SymbolHandle symbol,const char * fname)488 int MXSymbolSaveToFile(SymbolHandle symbol, const char *fname) {
489   nnvm::Symbol *s = static_cast<nnvm::Symbol*>(symbol);
490   API_BEGIN();
491   std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname, "w"));
492   dmlc::ostream os(fo.get());
493   os << nnvm::pass::SaveJSON(Symbol2Graph(*s));
494   // reset file pointer, force flush
495   os.set_stream(nullptr);
496   API_END();
497 }
498 
MXSymbolSaveToJSON(SymbolHandle symbol,const char ** out_json)499 int MXSymbolSaveToJSON(SymbolHandle symbol, const char **out_json) {
500   nnvm::Symbol *s = static_cast<nnvm::Symbol*>(symbol);
501   MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
502   API_BEGIN();
503   ret->ret_str = nnvm::pass::SaveJSON(Symbol2Graph(*s));
504   *out_json = ret->ret_str.c_str();
505   API_END();
506 }
507 
508 namespace mxnet {
509 
510 template<typename AttrType>
MatchArguments(const nnvm::IndexedGraph & idx,const std::unordered_map<std::string,AttrType> & known_arg_attrs,std::vector<AttrType> * arg_attrs,const char * source)511 void MatchArguments(
512     const nnvm::IndexedGraph& idx,
513     const std::unordered_map<std::string, AttrType>& known_arg_attrs,
514     std::vector<AttrType>* arg_attrs,
515     const char* source) {
516   auto& arg_nodes = idx.input_nodes();
517   CHECK_EQ(arg_attrs->size(), arg_nodes.size());
518   size_t nmatched = 0;
519   for (size_t i = 0; i < arg_nodes.size(); ++i) {
520     const std::string& name = idx[arg_nodes[i]].source->attrs.name;
521     auto it = known_arg_attrs.find(name);
522     if (it != known_arg_attrs.end()) {
523       arg_attrs->at(i) = it->second;
524       ++nmatched;
525     }
526   }
527   if (nmatched != known_arg_attrs.size()) {
528     std::unordered_set<std::string> keys;
529     std::ostringstream head, msg;
530     msg << "\nCandidate arguments:\n";
531     for (size_t i = 0; i < arg_nodes.size(); ++i) {
532       std::string arg_name = idx[arg_nodes[i]].source->attrs.name;
533       keys.insert(arg_name);
534       msg << "\t[" << i << ']' << arg_name << '\n';
535     }
536     for (const auto& kv : known_arg_attrs) {
537       const std::string& key = kv.first;
538       if (keys.count(key) == 0) {
539         LOG(FATAL) << source
540                    << "Keyword argument name " << key << " not found."
541                    << msg.str();
542       }
543     }
544   }
545 }
546 
547 }  // namespace mxnet
548 
MXSymbolInferShape(SymbolHandle sym,uint32_t num_args,const char ** keys,const uint32_t * arg_ind_ptr,const uint32_t * arg_shape_data,uint32_t * in_shape_size,const uint32_t ** in_shape_ndim,const uint32_t *** in_shape_data,uint32_t * out_shape_size,const uint32_t ** out_shape_ndim,const uint32_t *** out_shape_data,uint32_t * aux_shape_size,const uint32_t ** aux_shape_ndim,const uint32_t *** aux_shape_data,int * complete)549 int MXSymbolInferShape(SymbolHandle sym,
550                        uint32_t num_args,
551                        const char** keys,
552                        const uint32_t *arg_ind_ptr,
553                        const uint32_t *arg_shape_data,
554                        uint32_t *in_shape_size,
555                        const uint32_t **in_shape_ndim,
556                        const uint32_t ***in_shape_data,
557                        uint32_t *out_shape_size,
558                        const uint32_t **out_shape_ndim,
559                        const uint32_t ***out_shape_data,
560                        uint32_t *aux_shape_size,
561                        const uint32_t **aux_shape_ndim,
562                        const uint32_t ***aux_shape_data,
563                        int *complete) {
564   nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
565   MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
566   API_BEGIN();
567   nnvm::Graph g = Symbol2Graph(*s);
568   mxnet::ShapeVector arg_shapes(g.indexed_graph().input_nodes().size(), mxnet::TShape());
569   if (keys == nullptr && num_args != 0) {
570     std::vector<uint32_t> read_only_args = mxnet::ReadOnlyArgIndices(g.indexed_graph());
571     CHECK_LE(num_args, read_only_args.size());
572     for (uint32_t i = 0; i < num_args; ++i) {
573       arg_shapes[read_only_args[i]] = mxnet::ShapeTypeCast(
574           arg_shape_data + arg_ind_ptr[i], arg_shape_data + arg_ind_ptr[i+1]);
575     }
576   } else {
577     std::unordered_map<std::string, mxnet::TShape> kwargs;
578     for (uint32_t i = 0; i < num_args; ++i) {
579       kwargs[keys[i]] = mxnet::ShapeTypeCast(
580           arg_shape_data + arg_ind_ptr[i], arg_shape_data + arg_ind_ptr[i+1]);
581     }
582     mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_shapes, "InferShape");
583   }
584 
585   try {
586     g = mxnet::exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
587   } catch (const mxnet::op::InferShapeError &err) {
588     throw dmlc::Error(err.msg);
589   }
590 
591   // if use legacy shape definition, need to convert numpy shape to legacy shape
592   mxnet::ShapeVector shapes = g.GetAttr<mxnet::ShapeVector>("shape");
593   if (!Imperative::Get()->is_np_shape()) {
594     common::ConvertToLegacyShape(&shapes);
595   }
596 
597   // copy back
598   CopyAttr(g.indexed_graph(), shapes,
599            &(ret->arg_shapes), &(ret->out_shapes), &(ret->aux_shapes));
600 
601   // copy data back
602   MXAPIThreadLocalEntry<>::SetupShapeArrayReturnWithBuffer(ret->arg_shapes,
603       &(ret->arg_shape_ndim), &(ret->arg_shape_data), &(ret->arg_shape_buffer));
604   MXAPIThreadLocalEntry<>::SetupShapeArrayReturnWithBuffer(ret->out_shapes,
605       &(ret->out_shape_ndim), &(ret->out_shape_data), &(ret->out_shape_buffer));
606   MXAPIThreadLocalEntry<>::SetupShapeArrayReturnWithBuffer(ret->aux_shapes,
607       &(ret->aux_shape_ndim), &(ret->aux_shape_data), &(ret->aux_shape_buffer));
608   *in_shape_size = static_cast<uint32_t>(ret->arg_shapes.size());
609   *in_shape_ndim = dmlc::BeginPtr(ret->arg_shape_ndim);
610   *in_shape_data = dmlc::BeginPtr(ret->arg_shape_data);
611   *out_shape_size = static_cast<uint32_t>(ret->out_shapes.size());
612   *out_shape_ndim = dmlc::BeginPtr(ret->out_shape_ndim);
613   *out_shape_data = dmlc::BeginPtr(ret->out_shape_data);
614   *aux_shape_size = static_cast<uint32_t>(ret->aux_shapes.size());
615   *aux_shape_ndim = dmlc::BeginPtr(ret->aux_shape_ndim);
616   *aux_shape_data = dmlc::BeginPtr(ret->aux_shape_data);
617   // mark complete
618   *complete = (g.GetAttr<size_t>("shape_num_unknown_nodes") == 0);
619   API_END();
620 }
621 
622 template<typename dtype, typename stype, typename itype>
SymbolInferShape(const char ** keys,uint32_t num_args,const dtype * arg_shape_data,const itype * arg_ind_ptr,const int ** in_shape_ndim,const dtype *** in_shape_data,const int ** out_shape_ndim,const dtype *** out_shape_data,const int ** aux_shape_ndim,const dtype *** aux_shape_data,nnvm::Symbol * s,MXAPIThreadLocalEntry<dtype> * ret,stype * in_shape_size,stype * out_shape_size,stype * aux_shape_size,int * complete)623 inline void SymbolInferShape(const char** keys,
624                              uint32_t num_args,
625                              const dtype* arg_shape_data,
626                              const itype* arg_ind_ptr,
627                              const int** in_shape_ndim,
628                              const dtype*** in_shape_data,
629                              const int** out_shape_ndim,
630                              const dtype*** out_shape_data,
631                              const int** aux_shape_ndim,
632                              const dtype*** aux_shape_data,
633                              nnvm::Symbol* s,
634                              MXAPIThreadLocalEntry<dtype>* ret,
635                              stype* in_shape_size,
636                              stype* out_shape_size,
637                              stype* aux_shape_size,
638                              int* complete) {
639   nnvm::Graph g = Symbol2Graph(*s);
640   mxnet::ShapeVector arg_shapes(g.indexed_graph().input_nodes().size(), mxnet::TShape());
641   if (keys == nullptr && num_args != 0) {
642     std::vector < uint32_t > read_only_args = mxnet::ReadOnlyArgIndices(g.indexed_graph());
643     CHECK_LE(num_args, read_only_args.size());
644     for (uint32_t i = 0; i < num_args; ++i) {
645       arg_shapes[read_only_args[i]] = mxnet::ShapeTypeCast(arg_shape_data + arg_ind_ptr[i],
646                                                            arg_shape_data + arg_ind_ptr[i + 1]);
647     }
648   } else {
649     std::unordered_map<std::string, mxnet::TShape> kwargs;
650     for (uint32_t i = 0; i < num_args; ++i) {
651       kwargs[keys[i]] = mxnet::ShapeTypeCast(arg_shape_data + arg_ind_ptr[i],
652                                              arg_shape_data + arg_ind_ptr[i + 1]);
653     }
654     mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_shapes, "InferShape");
655   }
656   try {
657     g = mxnet::exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
658   } catch (const mxnet::op::InferShapeError& err) {
659     throw dmlc::Error(err.msg);
660   }
661   // if use legacy shape definition, need to convert numpy shape to legacy shape
662   mxnet::ShapeVector shapes = g.GetAttr<mxnet::ShapeVector>("shape");
663   if (!Imperative::Get()->is_np_shape()) {
664     common::ConvertToLegacyShape(&shapes);
665   }
666   // copy back
667   CopyAttr(g.indexed_graph(), shapes, &(ret->arg_shapes), &(ret->out_shapes), &(ret->aux_shapes));
668   // copy data back
669   MXAPIThreadLocalEntry<dtype>::SetupShapeArrayReturnWithBufferEx(ret->arg_shapes,
670                                                                   &(ret->arg_shape_ndim_ex),
671                                                                   &(ret->arg_shape_data_ex),
672                                                                   &(ret->arg_shape_buffer_ex));
673   MXAPIThreadLocalEntry<dtype>::SetupShapeArrayReturnWithBufferEx(ret->out_shapes,
674                                                                   &(ret->out_shape_ndim_ex),
675                                                                   &(ret->out_shape_data_ex),
676                                                                   &(ret->out_shape_buffer_ex));
677   MXAPIThreadLocalEntry<dtype>::SetupShapeArrayReturnWithBufferEx(ret->aux_shapes,
678                                                                   &(ret->aux_shape_ndim_ex),
679                                                                   &(ret->aux_shape_data_ex),
680                                                                   &(ret->aux_shape_buffer_ex));
681   *in_shape_size = static_cast<stype>(ret->arg_shapes.size());
682   *in_shape_ndim = dmlc::BeginPtr(ret->arg_shape_ndim_ex);
683   *in_shape_data = dmlc::BeginPtr(ret->arg_shape_data_ex);
684   *out_shape_size = static_cast<stype>(ret->out_shapes.size());
685   *out_shape_ndim = dmlc::BeginPtr(ret->out_shape_ndim_ex);
686   *out_shape_data = dmlc::BeginPtr(ret->out_shape_data_ex);
687   *aux_shape_size = static_cast<stype>(ret->aux_shapes.size());
688   *aux_shape_ndim = dmlc::BeginPtr(ret->aux_shape_ndim_ex);
689   *aux_shape_data = dmlc::BeginPtr(ret->aux_shape_data_ex);
690   // mark complete
691   *complete = (g.GetAttr<size_t>("shape_num_unknown_nodes") == 0);
692 }
693 
694 /*!
695  * \brief Executor for Symbol Shape Inference
696  *  This api is available when MXNet is built with flag
697  *  USE_INT64_TENSOR_SIZE=0 (by default)
698  * \param sym symbol handle
699  * \param num_args number of args
700  * \param keys keys
701  * \param arg_ind_ptr arg index pointer
702  * \param arg_shape_data arg shape data
703  * \param in_shape_size input shape size
704  * \param in_shape_ndim input shape number of dims
705  * \param in_shape_data input shape data
706  * \param out_shape_size ouput shape size
707  * \param out_shape_ndim output shape number of dims
708  * \param out_shape_data output shape data
709  * \param aux_shape_size shape size of auxiliary states
710  * \param aux_shape_ndim number of dims of auxiliary states shape
711  * \param aux_shape_data shape data of auxiliary states
712  * \param complete indicates completion of Shape Inference
713  * \return 0 when success, -1 when failure happens
714  */
MXSymbolInferShapeEx(SymbolHandle sym,uint32_t num_args,const char ** keys,const uint32_t * arg_ind_ptr,const int * arg_shape_data,uint32_t * in_shape_size,const int ** in_shape_ndim,const int *** in_shape_data,uint32_t * out_shape_size,const int ** out_shape_ndim,const int *** out_shape_data,uint32_t * aux_shape_size,const int ** aux_shape_ndim,const int *** aux_shape_data,int * complete)715 int MXSymbolInferShapeEx(SymbolHandle sym,
716                          uint32_t num_args,
717                          const char** keys,
718                          const uint32_t *arg_ind_ptr,
719                          const int *arg_shape_data,
720                          uint32_t *in_shape_size,
721                          const int **in_shape_ndim,
722                          const int ***in_shape_data,
723                          uint32_t *out_shape_size,
724                          const int **out_shape_ndim,
725                          const int ***out_shape_data,
726                          uint32_t *aux_shape_size,
727                          const int **aux_shape_ndim,
728                          const int ***aux_shape_data,
729                          int *complete) {
730   nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
731   MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
732   API_BEGIN();
733   SymbolInferShape<int, uint32_t, uint32_t>(keys,
734                                               num_args,
735                                               arg_shape_data,
736                                               arg_ind_ptr,
737                                               in_shape_ndim,
738                                               in_shape_data,
739                                               out_shape_ndim,
740                                               out_shape_data,
741                                               aux_shape_ndim,
742                                               aux_shape_data,
743                                               s,
744                                               ret,
745                                               in_shape_size,
746                                               out_shape_size,
747                                               aux_shape_size,
748                                               complete);
749   API_END();
750 }
751 
752 /*!
753  * \brief Executor for Symbol Shape Inference
754  *  This api is available when MXNet is built with flag
755  *  USE_INT64_TENSOR_SIZE=1 (not default) i.e. Large Tensor Support
756  * \param sym symbol handle
757  * \param num_args number of args
758  * \param keys keys
759  * \param arg_ind_ptr arg index pointer
760  * \param arg_shape_data arg shape data
761  * \param in_shape_size input shape size
762  * \param in_shape_ndim input shape number of dims
763  * \param in_shape_data input shape data
764  * \param out_shape_size ouput shape size
765  * \param out_shape_ndim output shape number of dims
766  * \param out_shape_data output shape data
767  * \param aux_shape_size shape size of auxiliary states
768  * \param aux_shape_ndim number of dims of auxiliary states shape
769  * \param aux_shape_data shape data of auxiliary states
770  * \param complete indicates completion of Shape Inference
771  * \return 0 when success, -1 when failure happens
772  */
MXSymbolInferShapeEx64(SymbolHandle sym,uint32_t num_args,const char ** keys,const int64_t * arg_ind_ptr,const int64_t * arg_shape_data,size_t * in_shape_size,const int ** in_shape_ndim,const int64_t *** in_shape_data,size_t * out_shape_size,const int ** out_shape_ndim,const int64_t *** out_shape_data,size_t * aux_shape_size,const int ** aux_shape_ndim,const int64_t *** aux_shape_data,int * complete)773 int MXSymbolInferShapeEx64(SymbolHandle sym,
774                            uint32_t num_args,
775                            const char** keys,
776                            const int64_t *arg_ind_ptr,
777                            const int64_t *arg_shape_data,
778                            size_t *in_shape_size,
779                            const int **in_shape_ndim,
780                            const int64_t ***in_shape_data,
781                            size_t *out_shape_size,
782                            const int **out_shape_ndim,
783                            const int64_t ***out_shape_data,
784                            size_t *aux_shape_size,
785                            const int **aux_shape_ndim,
786                            const int64_t ***aux_shape_data,
787                            int *complete) {
788   nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
789   MXAPIThreadLocalEntry<int64_t> *ret = MXAPIThreadLocalStore<int64_t>::Get();
790   API_BEGIN();
791   SymbolInferShape<int64_t, size_t, int64_t>(keys,
792                                                  num_args,
793                                                  arg_shape_data,
794                                                  arg_ind_ptr,
795                                                  in_shape_ndim,
796                                                  in_shape_data,
797                                                  out_shape_ndim,
798                                                  out_shape_data,
799                                                  aux_shape_ndim,
800                                                  aux_shape_data,
801                                                  s,
802                                                  ret,
803                                                  in_shape_size,
804                                                  out_shape_size,
805                                                  aux_shape_size,
806                                                  complete);
807   API_END();
808 }
809 
MXSymbolInferShapePartial(SymbolHandle sym,uint32_t num_args,const char ** keys,const uint32_t * arg_ind_ptr,const uint32_t * arg_shape_data,uint32_t * in_shape_size,const uint32_t ** in_shape_ndim,const uint32_t *** in_shape_data,uint32_t * out_shape_size,const uint32_t ** out_shape_ndim,const uint32_t *** out_shape_data,uint32_t * aux_shape_size,const uint32_t ** aux_shape_ndim,const uint32_t *** aux_shape_data,int * complete)810 int MXSymbolInferShapePartial(SymbolHandle sym,
811                               uint32_t num_args,
812                               const char** keys,
813                               const uint32_t *arg_ind_ptr,
814                               const uint32_t *arg_shape_data,
815                               uint32_t *in_shape_size,
816                               const uint32_t **in_shape_ndim,
817                               const uint32_t ***in_shape_data,
818                               uint32_t *out_shape_size,
819                               const uint32_t **out_shape_ndim,
820                               const uint32_t ***out_shape_data,
821                               uint32_t *aux_shape_size,
822                               const uint32_t **aux_shape_ndim,
823                               const uint32_t ***aux_shape_data,
824                               int *complete) {
825   int succ = 0;
826   *complete = 1;
827   return MXSymbolInferShape(sym, num_args, keys,
828                             arg_ind_ptr, arg_shape_data,
829                             in_shape_size, in_shape_ndim, in_shape_data,
830                             out_shape_size, out_shape_ndim, out_shape_data,
831                             aux_shape_size, aux_shape_ndim, aux_shape_data,
832                             &succ);
833 }
834 
835 /*!
836  * \brief Executor for Symbol Partial Shape Inference
837  *  This api is available when MXNet is built with flag
838  *  USE_INT64_TENSOR_SIZE=0 (by default)
839  * \param sym symbol handle
840  * \param num_args number of args
841  * \param keys keys
842  * \param arg_ind_ptr arg index pointer
843  * \param arg_shape_data arg shape data
844  * \param in_shape_size input shape size
845  * \param in_shape_ndim input shape number of dims
846  * \param in_shape_data input shape data
847  * \param out_shape_size ouput shape size
848  * \param out_shape_ndim output shape number of dims
849  * \param out_shape_data output shape data
850  * \param aux_shape_size shape size of auxiliary states
851  * \param aux_shape_ndim number of dims of auxiliary states shape
852  * \param aux_shape_data shape data of auxiliary states
853  * \param complete indicates completion of Shape Inference
854  * \return 0 when success, -1 when failure happens
855  */
MXSymbolInferShapePartialEx(SymbolHandle sym,uint32_t num_args,const char ** keys,const uint32_t * arg_ind_ptr,const int * arg_shape_data,uint32_t * in_shape_size,const int ** in_shape_ndim,const int *** in_shape_data,uint32_t * out_shape_size,const int ** out_shape_ndim,const int *** out_shape_data,uint32_t * aux_shape_size,const int ** aux_shape_ndim,const int *** aux_shape_data,int * complete)856 int MXSymbolInferShapePartialEx(SymbolHandle sym,
857                                 uint32_t num_args,
858                                 const char** keys,
859                                 const uint32_t *arg_ind_ptr,
860                                 const int *arg_shape_data,
861                                 uint32_t *in_shape_size,
862                                 const int **in_shape_ndim,
863                                 const int ***in_shape_data,
864                                 uint32_t *out_shape_size,
865                                 const int **out_shape_ndim,
866                                 const int ***out_shape_data,
867                                 uint32_t *aux_shape_size,
868                                 const int **aux_shape_ndim,
869                                 const int ***aux_shape_data,
870                                 int *complete) {
871   int succ = 0;
872   *complete = 1;
873   return MXSymbolInferShapeEx(sym, num_args, keys,
874                               arg_ind_ptr, arg_shape_data,
875                               in_shape_size, in_shape_ndim, in_shape_data,
876                               out_shape_size, out_shape_ndim, out_shape_data,
877                               aux_shape_size, aux_shape_ndim, aux_shape_data,
878                               &succ);
879 }
880 
881 /*!
882  * \brief Executor for Symbol Partial Shape Inference
883  *  This api is available when MXNet is built with flag
884  *  USE_INT64_TENSOR_SIZE=1 (not default) i.e. Large Tensor Support
885  * \param sym symbol handle
886  * \param num_args number of args
887  * \param keys keys
888  * \param arg_ind_ptr arg index pointer
889  * \param arg_shape_data arg shape data
890  * \param in_shape_size input shape size
891  * \param in_shape_ndim input shape number of dims
892  * \param in_shape_data input shape data
893  * \param out_shape_size ouput shape size
894  * \param out_shape_ndim output shape number of dims
895  * \param out_shape_data output shape data
896  * \param aux_shape_size shape size of auxiliary states
897  * \param aux_shape_ndim number of dims of auxiliary states shape
898  * \param aux_shape_data shape data of auxiliary states
899  * \param complete indicates completion of Shape Inference
900  * \return 0 when success, -1 when failure happens
901  */
MXSymbolInferShapePartialEx64(SymbolHandle sym,uint32_t num_args,const char ** keys,const int64_t * arg_ind_ptr,const int64_t * arg_shape_data,size_t * in_shape_size,const int ** in_shape_ndim,const int64_t *** in_shape_data,size_t * out_shape_size,const int ** out_shape_ndim,const int64_t *** out_shape_data,size_t * aux_shape_size,const int ** aux_shape_ndim,const int64_t *** aux_shape_data,int * complete)902 int MXSymbolInferShapePartialEx64(SymbolHandle sym,
903                                   uint32_t num_args,
904                                   const char** keys,
905                                   const int64_t *arg_ind_ptr,
906                                   const int64_t *arg_shape_data,
907                                   size_t *in_shape_size,
908                                   const int **in_shape_ndim,
909                                   const int64_t ***in_shape_data,
910                                   size_t *out_shape_size,
911                                   const int **out_shape_ndim,
912                                   const int64_t ***out_shape_data,
913                                   size_t *aux_shape_size,
914                                   const int **aux_shape_ndim,
915                                   const int64_t ***aux_shape_data,
916                                   int *complete) {
917   int succ = 0;
918   *complete = 1;
919   return MXSymbolInferShapeEx64(sym, num_args, keys,
920                                 arg_ind_ptr, arg_shape_data,
921                                 in_shape_size, in_shape_ndim, in_shape_data,
922                                 out_shape_size, out_shape_ndim, out_shape_data,
923                                 aux_shape_size, aux_shape_ndim, aux_shape_data,
924                                 &succ);
925 }
926 
MXSymbolInferType(SymbolHandle sym,uint32_t num_args,const char ** keys,const int * arg_type_data,uint32_t * in_type_size,const int ** in_type_data,uint32_t * out_type_size,const int ** out_type_data,uint32_t * aux_type_size,const int ** aux_type_data,int * complete)927 int MXSymbolInferType(SymbolHandle sym,
928                       uint32_t num_args,
929                       const char** keys,
930                       const int *arg_type_data,
931                       uint32_t *in_type_size,
932                       const int **in_type_data,
933                       uint32_t *out_type_size,
934                       const int **out_type_data,
935                       uint32_t *aux_type_size,
936                       const int **aux_type_data,
937                       int *complete) {
938   nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
939   MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
940   API_BEGIN();
941   nnvm::Graph g = Symbol2Graph(*s);
942   nnvm::DTypeVector arg_types(g.indexed_graph().input_nodes().size(), -1);
943   if (keys == nullptr && num_args != 0) {
944     std::vector<uint32_t> read_only_args = mxnet::ReadOnlyArgIndices(g.indexed_graph());
945     CHECK_LE(num_args, read_only_args.size());
946     for (uint32_t i = 0; i < num_args; ++i) {
947       arg_types[read_only_args[i]] = arg_type_data[i];
948     }
949   } else {
950     std::unordered_map<std::string, int> kwargs;
951     for (uint32_t i = 0; i < num_args; ++i) {
952       kwargs[keys[i]] = arg_type_data[i];
953     }
954     mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_types, "InferType");
955   }
956 
957   g = mxnet::exec::InferType(std::move(g), std::move(arg_types), "__dtype__");
958   // copy back
959   CopyAttr(g.indexed_graph(), g.GetAttr<nnvm::DTypeVector>("dtype"),
960            &(ret->arg_types), &(ret->out_types), &(ret->aux_types));
961 
962   *in_type_size = static_cast<uint32_t>(ret->arg_types.size());
963   *in_type_data = dmlc::BeginPtr(ret->arg_types);
964   *out_type_size = static_cast<uint32_t>(ret->out_types.size());
965   *out_type_data = dmlc::BeginPtr(ret->out_types);
966   *aux_type_size = static_cast<uint32_t>(ret->aux_types.size());
967   *aux_type_data = dmlc::BeginPtr(ret->aux_types);
968   *complete = (g.GetAttr<size_t>("dtype_num_unknown_nodes") == 0);
969   API_END();
970 }
971 
MXSymbolInferTypePartial(SymbolHandle sym,uint32_t num_args,const char ** keys,const int * arg_type_data,uint32_t * in_type_size,const int ** in_type_data,uint32_t * out_type_size,const int ** out_type_data,uint32_t * aux_type_size,const int ** aux_type_data,int * complete)972 int MXSymbolInferTypePartial(SymbolHandle sym,
973                              uint32_t num_args,
974                              const char** keys,
975                              const int *arg_type_data,
976                              uint32_t *in_type_size,
977                              const int **in_type_data,
978                              uint32_t *out_type_size,
979                              const int **out_type_data,
980                              uint32_t *aux_type_size,
981                              const int **aux_type_data,
982                              int *complete) {
983   int succ = 0;
984   *complete = 1;
985   return MXSymbolInferType(sym, num_args, keys,
986                             arg_type_data,
987                             in_type_size, in_type_data,
988                             out_type_size, out_type_data,
989                             aux_type_size, aux_type_data,
990                             &succ);
991 }
992 
MXSymbolGrad(SymbolHandle sym,uint32_t num_wrt,const char ** wrt,SymbolHandle * out)993 int MXSymbolGrad(SymbolHandle sym, uint32_t num_wrt, const char** wrt, SymbolHandle* out) {
994   API_BEGIN();
995   LOG(FATAL) << "not implemented";
996   API_END();
997 }
998 
MXQuantizeSymbol(SymbolHandle sym_handle,SymbolHandle * ret_sym_handle,const int * dev_type,const uint32_t num_excluded_sym_names,const char ** excluded_sym_names,const uint32_t num_excluded_op_names,const char ** excluded_op_names,const uint32_t num_offline,const char ** offline_params,const char * quantized_dtype,const bool calib_quantize,const char * quantize_mode,const char * quantize_granularity,mx_uint * out_num_calib_names,const char *** out_calib_names)999 int MXQuantizeSymbol(SymbolHandle sym_handle,
1000                      SymbolHandle *ret_sym_handle,
1001                      const int* dev_type,
1002                      const uint32_t num_excluded_sym_names,
1003                      const char **excluded_sym_names,
1004                      const uint32_t num_excluded_op_names,
1005                      const char **excluded_op_names,
1006                      const uint32_t num_offline,
1007                      const char **offline_params,
1008                      const char *quantized_dtype,
1009                      const bool calib_quantize,
1010                      const char *quantize_mode,
1011                      const char *quantize_granularity,
1012                      mx_uint* out_num_calib_names,
1013                      const char ***out_calib_names) {
1014   nnvm::Symbol *s = new nnvm::Symbol();
1015   API_BEGIN();
1016   nnvm::Symbol *sym = static_cast<nnvm::Symbol*>(sym_handle);
1017   nnvm::Graph g = Symbol2Graph(*sym);
1018   int target_dev = *dev_type;
1019   std::unordered_set<std::string> excluded_node_names;
1020   for (size_t i = 0; i < num_excluded_sym_names; ++i) {
1021     excluded_node_names.emplace(excluded_sym_names[i]);
1022   }
1023   std::unordered_set<std::string> excluded_op;
1024   for (size_t i = 0; i < num_excluded_op_names; ++i) {
1025     excluded_op.emplace(excluded_op_names[i]);
1026   }
1027   std::unordered_set<std::string> offline;
1028   for (size_t i = 0; i < num_offline; ++i) {
1029     offline.emplace(offline_params[i]);
1030   }
1031   std::string quantized_type(quantized_dtype);
1032   std::string quantized_mode(quantize_mode);
1033   std::string quantized_granularity(quantize_granularity);
1034   g.attrs["excluded_nodes"] = std::make_shared<nnvm::any>(std::move(excluded_node_names));
1035   g.attrs["excluded_ops"] = std::make_shared<nnvm::any>(std::move(excluded_op));
1036   g.attrs["offline_params"] = std::make_shared<nnvm::any>(std::move(offline));
1037   g.attrs["quantized_dtype"] = std::make_shared<nnvm::any>(std::move(quantized_type));
1038   g.attrs["target_ctx"] = std::make_shared<nnvm::any>(target_dev);
1039   g.attrs["quantize_mode"] = std::make_shared<nnvm::any>(std::move(quantized_mode));
1040   g.attrs["quantize_granularity"] = std::make_shared<nnvm::any>(std::move(quantized_granularity));
1041   g = ApplyPass(std::move(g), "QuantizeGraph");
1042   const auto& calib_nodes = g.GetAttr<std::vector<std::string>>("calib_nodes");
1043   MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
1044   ret->ret_vec_str = std::move(calib_nodes);
1045   *out_num_calib_names = ret->ret_vec_str.size();
1046   ret->ret_vec_charp.clear();
1047   ret->ret_vec_charp.reserve(ret->ret_vec_str.size());
1048   for (const auto &str : ret->ret_vec_str) {
1049     ret->ret_vec_charp.push_back(str.c_str());
1050   }
1051   *out_calib_names = dmlc::BeginPtr(ret->ret_vec_charp);
1052   s->outputs = g.outputs;
1053   *ret_sym_handle = s;
1054   API_END_HANDLE_ERROR(delete s);
1055 }
1056 
1057 // helper function to add mapping of node_name -> dtype map
1058 // for the given indexed graph and inferred_dtypes
_SetInputDTypes(const nnvm::IndexedGraph & idx,const nnvm::DTypeVector & inferred_dtypes,std::unordered_map<std::string,int> * node_name_dtype_map,std::unordered_map<std::string,int> * node_without_dtype_map)1059 static void _SetInputDTypes(
1060     const nnvm::IndexedGraph& idx,
1061     const nnvm::DTypeVector& inferred_dtypes,
1062     std::unordered_map<std::string, int>* node_name_dtype_map,
1063     std::unordered_map<std::string, int>* node_without_dtype_map) {
1064   const std::string dtype_keyword = "__dtype__";
1065   for (uint32_t nid : idx.input_nodes()) {
1066     const auto& node = idx[nid].source;
1067     const auto& node_with_dtype = node->attrs.dict.find(dtype_keyword);
1068     // input nodes classified into nodes_with_dtype, nodes_without_dtype
1069     // This classification required because if param_names not provided
1070     // we want to update dtypes of only those nodes which have dtypes set
1071     // inferred_dtypes are obtained for the nodes, if unknown
1072     // dtype is set to fp32
1073     if (node_with_dtype != node->attrs.dict.end()) {
1074       if (inferred_dtypes[idx.entry_id(nid, 0)] == -1) {
1075         (*node_name_dtype_map)[node->attrs.name] = 0;
1076       } else {
1077         (*node_name_dtype_map)[node->attrs.name] =
1078             inferred_dtypes[idx.entry_id(nid, 0)];
1079       }
1080     } else {
1081       if (inferred_dtypes[idx.entry_id(nid, 0)] == -1) {
1082         (*node_without_dtype_map)[node->attrs.name] = 0;
1083       } else {
1084         (*node_without_dtype_map)[node->attrs.name] =
1085             inferred_dtypes[idx.entry_id(nid, 0)];
1086       }
1087     }
1088   }
1089 }
1090 
1091 // helper function update the node dtype attrs for a vector of nodeptrs
1092 // given the node name to dtype information and the names of model_params
1093 // if model_params is provided the function will dtype of only model params.
1094 // if model_params is empty, the function will dtype of all nodes which had
1095 // a prior dtype set.
1096 // args is a const_reference vector of ObjectPtrs. ObjectPtrs are immutable but
1097 // the Nodes they are pointing will be mutated in this function
_UpdateSymDTypeAttrs(const std::unordered_map<std::string,int> & node_name_dtype_map,const std::unordered_map<std::string,int> & node_without_dtype_map,const std::unordered_set<std::string> & model_params,const std::vector<nnvm::ObjectPtr> & args)1098 static void _UpdateSymDTypeAttrs(
1099     const std::unordered_map<std::string, int>& node_name_dtype_map,
1100     const std::unordered_map<std::string, int>& node_without_dtype_map,
1101     const std::unordered_set<std::string>& model_params,
1102     const std::vector<nnvm::ObjectPtr>& args) {
1103   const std::string dtype_keyword = "__dtype__";
1104 
1105   // Update args to have the right dtype attrs
1106   if (model_params.size() > 0) {
1107     // if model params provided, set dtype only for model params
1108     for (size_t i = 0; i < args.size(); ++i) {
1109       const std::string& node_name = args[i]->attrs.name;
1110       auto it_model_params = model_params.find(node_name);
1111       auto it_with_dtype = node_name_dtype_map.find(node_name);
1112       auto it_without_dtype = node_without_dtype_map.find(node_name);
1113       if (it_model_params != model_params.end()) {
1114         // need to update __dtype__ attribute if already set, else set it
1115         if (it_with_dtype != node_name_dtype_map.end()) {
1116           args[i]->attrs.dict[dtype_keyword] =
1117               std::to_string(it_with_dtype->second);
1118         } else {
1119           CHECK(it_without_dtype != node_without_dtype_map.end())
1120               << "make sure all nodes without dtype have properly been added "
1121                  "in node_without_dtype_map";
1122           args[i]->attrs.dict[dtype_keyword] =
1123               std::to_string(it_without_dtype->second);
1124         }
1125       }
1126     }
1127   } else {
1128     // if model params not provided, update __dtype__ for all inputs,
1129     // which already had it set, don't touch the rest
1130     for (size_t i = 0; i < args.size(); ++i) {
1131       auto it = node_name_dtype_map.find(args[i]->attrs.name);
1132       if (it != node_name_dtype_map.end()) {
1133         if (args[i]->attrs.dict.find(dtype_keyword) !=
1134             args[i]->attrs.dict.end()) {
1135           args[i]->attrs.dict[dtype_keyword] = std::to_string(it->second);
1136         }
1137       }
1138     }
1139   }
1140 }
1141 
MXReducePrecisionSymbol(SymbolHandle sym_handle,SymbolHandle * ret_sym_handle,uint32_t num_args,const int * arg_type_data,uint32_t num_ind_ptr,const int * ind_ptr,const int * target_dtype,const int cast_optional_params,const uint32_t num_target_dtype_op_names,const uint32_t num_fp32_op_names,const uint32_t num_widest_dtype_op_names,const uint32_t num_conditional_fp32_op_names,const uint32_t num_excluded_symbols,const uint32_t num_model_params,const char ** target_dtype_op_names,const char ** fp32_op_names,const char ** widest_dtype_op_names,const char ** conditional_fp32_op_names,const char ** excluded_symbols,const char ** param_names,const char ** param_vals,const char ** model_param_names,const char ** arg_names)1142 int MXReducePrecisionSymbol(SymbolHandle sym_handle,
1143                             SymbolHandle *ret_sym_handle,
1144                             uint32_t num_args,
1145                             const int *arg_type_data,
1146                             uint32_t num_ind_ptr,
1147                             const int* ind_ptr,
1148                             const int* target_dtype,
1149                             const int cast_optional_params,
1150                             const uint32_t num_target_dtype_op_names,
1151                             const uint32_t num_fp32_op_names,
1152                             const uint32_t num_widest_dtype_op_names,
1153                             const uint32_t num_conditional_fp32_op_names,
1154                             const uint32_t num_excluded_symbols,
1155                             const uint32_t num_model_params,
1156                             const char **target_dtype_op_names,
1157                             const char **fp32_op_names,
1158                             const char **widest_dtype_op_names,
1159                             const char **conditional_fp32_op_names,
1160                             const char **excluded_symbols,
1161                             const char **param_names,
1162                             const char **param_vals,
1163                             const char **model_param_names,
1164                             const char **arg_names) {
1165   nnvm::Symbol *result_sym = new nnvm::Symbol();
1166   API_BEGIN();
1167   nnvm::Symbol *sym = static_cast<nnvm::Symbol *>(sym_handle);
1168   nnvm::Graph g = Symbol2Graph(*sym);
1169   std::unordered_set<std::string> target_dtype_ops;
1170   std::unordered_set<std::string> fp32_ops;
1171   std::unordered_set<std::string> widest_dtype_ops;
1172   std::unordered_set<std::string> excluded_syms;
1173   std::unordered_set<std::string> model_params;
1174 
1175   // conditional_fp32_ops contains the mapping of op_name -> (map of param_name -> param_values)
1176   // which need to be conditionally selected to be casted to FP32
1177   std::unordered_map<std::string,
1178                      std::unordered_map<std::string,
1179                                         std::vector<std::string>>> conditional_fp32_ops;
1180   int target_dt = *target_dtype;
1181 
1182   for (size_t i = 0; i < num_target_dtype_op_names; ++i) {
1183     target_dtype_ops.emplace(target_dtype_op_names[i]);
1184   }
1185   for (size_t i = 0; i < num_fp32_op_names; ++i) {
1186     fp32_ops.emplace(fp32_op_names[i]);
1187   }
1188   for (size_t i = 0; i < num_widest_dtype_op_names; ++i) {
1189     widest_dtype_ops.emplace(widest_dtype_op_names[i]);
1190   }
1191   for (size_t i = 0; i < num_excluded_symbols; ++i) {
1192     excluded_syms.emplace(excluded_symbols[i]);
1193   }
1194   for (size_t i = 0; i < num_model_params; ++i) {
1195     model_params.emplace(model_param_names[i]);
1196   }
1197 
1198   for (size_t i = 0; i < num_ind_ptr - 1; ++i) {
1199     for (int j = ind_ptr[i]; j < ind_ptr[i + 1]; ++j) {
1200       conditional_fp32_ops[conditional_fp32_op_names[i]][param_names[i]]
1201           .emplace_back(std::string(param_vals[j]));
1202     }
1203   }
1204 
1205   std::unordered_map<std::string, int> kwargs;
1206   std::unordered_map<std::string, int> node_name_dtype_map, node_without_dtype_map;
1207   nnvm::DTypeVector arg_types(g.indexed_graph().input_nodes().size(), -1);
1208   for (uint32_t i = 0; i < num_args; ++i) {
1209     kwargs[arg_names[i]] = arg_type_data[i];
1210     node_name_dtype_map[arg_names[i]] = arg_type_data[i];
1211   }
1212   mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_types, "InferType");
1213 
1214   g.attrs["target_dtype_ops"] =
1215       std::make_shared<nnvm::any>(std::move(target_dtype_ops));
1216   g.attrs["fp32_ops"] = std::make_shared<nnvm::any>(std::move(fp32_ops));
1217   g.attrs["widest_dtype_ops"] =
1218       std::make_shared<nnvm::any>(std::move(widest_dtype_ops));
1219   g.attrs["conditional_fp32_ops"] =
1220       std::make_shared<nnvm::any>(std::move(conditional_fp32_ops));
1221   g.attrs["excluded_syms"] =
1222       std::make_shared<nnvm::any>(std::move(excluded_syms));
1223   g.attrs["target_dtype"] = std::make_shared<nnvm::any>(target_dt);
1224   g.attrs["data_name_types"] = std::make_shared<nnvm::any>(kwargs);
1225   g.attrs["cast_optional_params"] = std::make_shared<nnvm::any>(cast_optional_params);
1226 
1227   g = ApplyPass(std::move(g), "ReducePrecision");
1228   // Need to run type inference since it is possible that inferred
1229   // type of some inputs has changed
1230   g = mxnet::exec::InferType(std::move(g), std::move(arg_types), "");
1231   const nnvm::DTypeVector &inferred_dtypes =
1232       g.GetAttr<nnvm::DTypeVector>("dtype");
1233 
1234   g.attrs["inferred_dtypes"] = std::make_shared<dmlc::any>(std::move(inferred_dtypes));
1235   g.attrs["target_dtype"] = std::make_shared<nnvm::any>(target_dt);
1236 
1237   if (cast_optional_params) {
1238     g = ApplyPass(std::move(g), "AMPInferUnknown");
1239     const nnvm::DTypeVector &inferred_dtype_result =
1240         g.GetAttr<nnvm::DTypeVector>("inferred_dtype_result");
1241     const nnvm::IndexedGraph &idx = g.indexed_graph();
1242     // set node name -> input dtype mapping using infer dtype
1243     _SetInputDTypes(idx, inferred_dtype_result, &node_name_dtype_map, &node_without_dtype_map);
1244   } else {
1245     const nnvm::IndexedGraph &idx = g.indexed_graph();
1246     // set node name -> input dtype mapping using infer dtype
1247     _SetInputDTypes(idx, inferred_dtypes, &node_name_dtype_map, &node_without_dtype_map);
1248   }
1249 
1250 
1251   result_sym->outputs = g.outputs;
1252   *ret_sym_handle = result_sym;
1253   nnvm::Symbol *ret_sym = static_cast<nnvm::Symbol *>(*ret_sym_handle);
1254   const std::vector<nnvm::ObjectPtr>& args = ret_sym->ListInputs(nnvm::Symbol::kAll);
1255 
1256   // update symbol dtype attrs using the node name -> dtype mapping, if dtype is already set
1257   // in the symbol, else set dtype for the model_params
1258   _UpdateSymDTypeAttrs(node_name_dtype_map, node_without_dtype_map, model_params, args);
1259 
1260   API_END_HANDLE_ERROR(delete result_sym);
1261 }
1262 
MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle,const uint32_t num_layers,const char ** layer_names,const float * min_ranges,const float * max_ranges,SymbolHandle * ret_qsym_handle)1263 int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle,
1264                                      const uint32_t num_layers,
1265                                      const char** layer_names,
1266                                      const float* min_ranges,
1267                                      const float* max_ranges,
1268                                      SymbolHandle* ret_qsym_handle) {
1269   nnvm::Symbol* s = new nnvm::Symbol();
1270   API_BEGIN();
1271   nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(qsym_handle);
1272   nnvm::Graph g = Symbol2Graph(*sym);
1273   std::unordered_map<std::string, std::pair<float, float>> calib_table;
1274   for (size_t i = 0; i < num_layers; ++i) {
1275     calib_table.emplace(layer_names[i], std::make_pair(min_ranges[i], max_ranges[i]));
1276   }
1277   g.attrs["calib_table"] = std::make_shared<nnvm::any>(std::move(calib_table));
1278   g = ApplyPass(std::move(g), "SetCalibTableToQuantizedGraph");
1279   s->outputs = g.outputs;
1280   *ret_qsym_handle = s;
1281   API_END_HANDLE_ERROR(delete s);
1282 }
1283 
MXGenBackendSubgraph(SymbolHandle sym_handle,const char * backend_name,SymbolHandle * ret_sym_handle)1284 int MXGenBackendSubgraph(SymbolHandle sym_handle, const char *backend_name,
1285                          SymbolHandle *ret_sym_handle) {
1286   nnvm::Symbol *s = new nnvm::Symbol();
1287   API_BEGIN();
1288   nnvm::Symbol *sym = static_cast<nnvm::Symbol *>(sym_handle);
1289   *s = sym->Copy();
1290   auto backend = mxnet::op::SubgraphBackendRegistry::Get()->GetSubgraphBackend(backend_name);
1291   const auto& subgraph_prop_list = backend->GetSubgraphProperties();
1292   for (auto property : subgraph_prop_list) {
1293     if (property->HasAttr("disable") && property->GetAttr<bool>("disable") == true) {
1294       auto full_name = property->HasAttr("property_name")
1295                            ? property->GetAttr<std::string>("property_name")
1296                            : std::string();
1297       LOG(INFO) << "subgraph property " << full_name << " from backend " << backend_name
1298                 << " is disabled.";
1299       continue;
1300     }
1301     nnvm::Graph g = Symbol2Graph(*s);
1302     property->SetAttr("graph", g);
1303     g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(property);
1304     g = ApplyPass(std::move(g), "BuildSubgraph");
1305     property->RemoveAttr("graph");
1306     g.attrs.erase("subgraph_property");
1307     s->outputs = g.outputs;
1308   }
1309   *ret_sym_handle = s;
1310   API_END_HANDLE_ERROR(delete s);
1311 }
1312 
MXGenAtomicSymbolFromSymbol(SymbolHandle sym_handle,SymbolHandle * ret_sym_handle)1313 int MXGenAtomicSymbolFromSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle) {
1314   nnvm::Symbol *s = new nnvm::Symbol();
1315   API_BEGIN();
1316   nnvm::Symbol *source = static_cast<nnvm::Symbol *>(sym_handle);
1317   CHECK_GE(source->outputs.size(), 1) << "Input symbol does not have outputs.";
1318   const auto &node = source->outputs[0].node;
1319   for (const auto &other_node : source->outputs) {
1320     if (node.get() != other_node.node.get()) {
1321       LOG(FATAL)
1322         << "Generating atomic symbol from other symbol only works for nongrouped symbol.";
1323     }
1324   }
1325   const auto *op = node->op();
1326   const auto attrs = source->ListAttrs(nnvm::Symbol::ListAttrOption::kShallow);
1327   *s = nnvm::Symbol::CreateFunctor(op, attrs);
1328   *ret_sym_handle = s;
1329   API_END_HANDLE_ERROR(delete s);
1330 }
1331 
MXShallowCopySymbol(SymbolHandle src,SymbolHandle * out)1332 int MXShallowCopySymbol(SymbolHandle src, SymbolHandle* out) {
1333   nnvm::Symbol* out_sym = new nnvm::Symbol;
1334   API_BEGIN();
1335   nnvm::Symbol* src_sym = static_cast<nnvm::Symbol*>(src);
1336   *out_sym = *src_sym;
1337   *out = out_sym;
1338   API_END_HANDLE_ERROR(delete out_sym);
1339 }
1340 
MXOptimizeForBackend(SymbolHandle sym_handle,const char * backend_name,const int dev_type,SymbolHandle * ret_sym_handle,const mx_uint args_len,NDArrayHandle * in_args_handle,const mx_uint aux_len,NDArrayHandle * in_aux_handle,const mx_uint num_options,const char ** keys,const char ** vals,const uint32_t num_input_shapes,const char ** input_shape_names,const int64_t * input_shape_data,const uint32_t * input_shape_idx,const uint32_t num_input_dtypes,const char ** input_dtype_names,const int * input_dtypes,const uint32_t num_input_stypes,const char ** input_stype_names,const int * input_stypes,bool skip_infer,int * new_args_cnt,NDArrayHandle ** new_args_handle,char *** new_arg_names_handle,int * new_aux_cnt,NDArrayHandle ** new_aux_handle,char *** new_aux_names_handle)1341 int MXOptimizeForBackend(SymbolHandle sym_handle,
1342                          const char* backend_name,
1343                          const int dev_type,
1344                          SymbolHandle* ret_sym_handle,
1345                          const mx_uint args_len,
1346                          NDArrayHandle* in_args_handle,
1347                          const mx_uint aux_len,
1348                          NDArrayHandle* in_aux_handle,
1349                          const mx_uint num_options,
1350                          const char** keys,
1351                          const char** vals,
1352                          const uint32_t num_input_shapes,
1353                          const char** input_shape_names,
1354                          const int64_t* input_shape_data,
1355                          const uint32_t* input_shape_idx,
1356                          const uint32_t num_input_dtypes,
1357                          const char** input_dtype_names,
1358                          const int* input_dtypes,
1359                          const uint32_t num_input_stypes,
1360                          const char** input_stype_names,
1361                          const int* input_stypes,
1362                          bool skip_infer,
1363                          int* new_args_cnt,
1364                          NDArrayHandle** new_args_handle,
1365                          char*** new_arg_names_handle,
1366                          int* new_aux_cnt,
1367                          NDArrayHandle** new_aux_handle,
1368                          char*** new_aux_names_handle) {
1369   // create copy of input symbol
1370   nnvm::Symbol *s = new nnvm::Symbol();
1371   API_BEGIN();
1372   nnvm::Symbol *sym = static_cast<nnvm::Symbol *>(sym_handle);
1373   *s = sym->Copy();
1374 
1375   // create a data structure from pointer array
1376   std::unordered_map<std::string, std::string> options_map;
1377   for (mx_uint i = 0; i < num_options; ++i)
1378     options_map.emplace(keys[i], vals[i]);
1379 
1380   NDArray ***new_args_ptr = reinterpret_cast<NDArray***>(new_args_handle);
1381   NDArray ***new_aux_ptr = reinterpret_cast<NDArray***>(new_aux_handle);
1382   NDArray **in_args_ptr = reinterpret_cast<NDArray**>(in_args_handle);
1383   NDArray **in_aux_ptr = reinterpret_cast<NDArray**>(in_aux_handle);
1384 
1385   auto init_graph = [&](nnvm::Symbol* s) {
1386     nnvm::Graph g = Symbol2Graph(*s);
1387     const auto& indexed_graph = g.indexed_graph();
1388     const auto& mutable_nodes = indexed_graph.mutable_input_nodes();
1389     std::vector<std::string> input_names = s->ListInputNames(nnvm::Symbol::kAll);
1390     size_t num_forward_inputs = input_names.size();
1391 
1392     if (args_len || aux_len) {
1393       if (!skip_infer) {
1394         Context default_ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
1395         mxnet::ShapeVector arg_shapes(args_len + aux_len);
1396         nnvm::DTypeVector arg_dtypes(args_len + aux_len);
1397         StorageTypeVector arg_stypes(args_len + aux_len);
1398 
1399         // create the input shape, dtype and stype maps
1400         std::unordered_map<std::string, mxnet::TShape> input_shape_map(num_input_shapes);
1401         for (uint32_t i = 0; i < num_input_shapes; ++i) {
1402           input_shape_map.emplace(input_shape_names[i],
1403                     mxnet::TShape(input_shape_data + input_shape_idx[i],
1404                     input_shape_data + input_shape_idx[i+1]));
1405         }
1406         std::unordered_map<std::string, int> input_dtype_map(num_input_dtypes);
1407         for (uint32_t i = 0; i < num_input_dtypes; ++i) {
1408           input_dtype_map.emplace(input_dtype_names[i], input_dtypes[i]);
1409         }
1410         std::unordered_map<std::string, int> input_stype_map(num_input_stypes);
1411         for (uint32_t i = 0; i < num_input_stypes; ++i) {
1412           input_stype_map.emplace(input_stype_names[i], input_stypes[i]);
1413         }
1414 
1415         size_t args_top = 0, aux_top = 0;
1416         // loop over inputs to symbol in order and add to args/aux if mutable
1417         for (size_t i = 0; i < num_forward_inputs; ++i) {
1418           const uint32_t nid = indexed_graph.input_nodes().at(i);
1419           if (mutable_nodes.count(nid)) {
1420             CHECK_LT(aux_top, aux_len)
1421               << "Cannot find aux '" << input_names[i] << "' in provided aux to optimize_for";
1422             if (in_aux_ptr[aux_top] != nullptr) {
1423               const auto &in_arg = *(in_aux_ptr[aux_top]);
1424               arg_shapes[i] = in_arg.shape();
1425               arg_dtypes[i] = in_arg.dtype();
1426               arg_stypes[i] = in_arg.storage_type();
1427             }
1428             aux_top++;
1429           } else {
1430             auto name = input_names[i];
1431             CHECK_LT(args_top, args_len)
1432               << "Cannot find arg '" << name << "' in provided args to optimize_for";
1433             if (in_args_ptr[args_top] != nullptr) {
1434               const auto &in_arg = *(in_args_ptr[args_top]);
1435               arg_shapes[i] = in_arg.shape();
1436               arg_dtypes[i] = in_arg.dtype();
1437               arg_stypes[i] = in_arg.storage_type();
1438             } else {
1439               // input_names[i] is not in args but can be in the optional
1440               // shape/type/stype attribute dicts.
1441               auto it_shape = input_shape_map.find(name);
1442               if (it_shape != input_shape_map.end()) {
1443                 arg_shapes[i] = it_shape->second;
1444               }
1445               auto it_type = input_dtype_map.find(name);
1446               if (it_type != input_dtype_map.end()) {
1447                 arg_dtypes[i] = it_type->second;
1448               }
1449               it_type = input_stype_map.find(name);
1450               if (it_type != input_stype_map.end()) {
1451                 arg_stypes[i] = it_type->second;
1452               }
1453             }
1454             args_top++;
1455           }
1456         }
1457 
1458         g.attrs["context"] = std::make_shared<nnvm::any>(
1459           exec::ContextVector(indexed_graph.num_nodes(), default_ctx));
1460 
1461         // infer shapes
1462         g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
1463         // infer dtypes
1464         g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
1465         // infer stypes
1466         g = exec::InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
1467       }
1468       // set args/aux as attributes on graph so that subgraph property can use them
1469       std::vector<std::string> arg_names = s->ListInputNames(nnvm::Symbol::kReadOnlyArgs);
1470       g.attrs["in_args"] = std::make_shared<nnvm::any>(in_args_ptr);
1471       g.attrs["in_arg_names"] = std::make_shared<nnvm::any>(arg_names);
1472 
1473       std::vector<std::string> aux_names = s->ListInputNames(nnvm::Symbol::kAuxiliaryStates);
1474       g.attrs["in_aux"] = std::make_shared<nnvm::any>(in_aux_ptr);
1475       g.attrs["in_aux_names"] = std::make_shared<nnvm::any>(aux_names);
1476     } else {
1477       // args/aux were not specified, so set nullptr/empty-lists
1478       NDArray **in_args_ptr = static_cast<NDArray**>(nullptr);
1479       std::vector<std::string> arg_names;
1480       g.attrs["in_args"] = std::make_shared<nnvm::any>(in_args_ptr);
1481       g.attrs["in_arg_names"] = std::make_shared<nnvm::any>(arg_names);
1482 
1483       NDArray **in_aux_ptr = static_cast<NDArray**>(nullptr);
1484       std::vector<std::string> aux_names;
1485       g.attrs["in_aux"] = std::make_shared<nnvm::any>(in_aux_ptr);
1486       g.attrs["in_aux_names"] = std::make_shared<nnvm::any>(aux_names);
1487     }
1488 
1489     // set dedup option as attribute on graph to enable dedup during partitioning
1490     if (options_map.count("dedup_subgraph") > 0 &&
1491         options_map.at("dedup_subgraph").compare("True") == 0)
1492       g.attrs["dedup_subgraph"] = std::make_shared<nnvm::any>(std::string("True"));
1493     return g;
1494   };
1495 
1496   if (mxnet::op::SubgraphBackendRegistry::Get()->backend_map_.count(backend_name) > 0) {
1497     // use subgraph backend
1498     const auto backend = mxnet::op::SubgraphBackendRegistry
1499                                       ::Get()->GetSubgraphBackend(backend_name);
1500     const auto& subgraph_prop_list = backend->GetSubgraphProperties();
1501     for (auto property : subgraph_prop_list) {
1502       nnvm::Graph g = init_graph(s);
1503       property->PrePartition(g, options_map);
1504       g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(property);
1505       g = ApplyPass(std::move(g), "BuildSubgraph");
1506       g.attrs.erase("subgraph_property");
1507       property->PostPartition(g);
1508       s->outputs = g.outputs;
1509     }
1510   } else if (dmlc::Registry<nnvm::PassFunctionReg>::Find(backend_name) != nullptr) {
1511     // use graph pass
1512     nnvm::Graph g = init_graph(s);
1513     g.attrs["options_map"] = std::make_shared<nnvm::any>(options_map);
1514     g.attrs["pass_name"] = std::make_shared<nnvm::any>(backend_name);
1515     g = ApplyPass(std::move(g), backend_name);
1516 
1517     std::vector<NDArray*> new_args = g.GetAttr<std::vector<NDArray*>>("new_args");
1518     std::vector<NDArray*> new_aux = g.GetAttr<std::vector<NDArray*>>("new_aux");
1519     std::vector<std::string> new_arg_names = g.GetAttr<std::vector<std::string>>("new_arg_names");
1520     std::vector<std::string> new_aux_names = g.GetAttr<std::vector<std::string>>("new_aux_names");
1521     g.attrs.erase("new_args");
1522     g.attrs.erase("new_aux");
1523     g.attrs.erase("new_arg_names");
1524     g.attrs.erase("new_aux_names");
1525     s->outputs = g.outputs;
1526 
1527     NDArray** new_arg_arr = new NDArray*[new_arg_names.size()];
1528     NDArray** new_aux_arr = new NDArray*[new_aux_names.size()];
1529     char** new_arg_cstr = new char*[new_arg_names.size()];
1530     char** new_aux_cstr = new char*[new_aux_names.size()];
1531     for (unsigned i = 0; i < new_arg_names.size(); i++) {
1532       new_arg_arr[i] = new_args[i];
1533       std::string& s = new_arg_names[i];
1534       char* tmp = new char[s.length()+1];
1535       s.copy(tmp, s.length());
1536       tmp[s.length()] = '\0';
1537       new_arg_cstr[i] = tmp;
1538     }
1539     for (unsigned i = 0; i < new_aux_names.size(); i++) {
1540       new_aux_arr[i] = new_aux[i];
1541       std::string& s = new_aux_names[i];
1542       char* tmp = new char[s.length()+1];
1543       s.copy(tmp, s.length());
1544       tmp[s.length()] = '\0';
1545       new_aux_cstr[i] = tmp;
1546     }
1547     *new_args_cnt = new_arg_names.size();
1548     *new_aux_cnt = new_aux_names.size();
1549     *new_arg_names_handle = new_arg_cstr;
1550     *new_aux_names_handle = new_aux_cstr;
1551     *new_args_ptr = new_arg_arr;
1552     *new_aux_ptr = new_aux_arr;
1553   } else {
1554     // cannot find graph pass or subgraph backend registered in this name
1555     LOG(ERROR) << "Error optimizing for backend '" << backend_name << "' cannot be found";
1556   }
1557 
1558   *ret_sym_handle = s;
1559   API_END_HANDLE_ERROR(delete s);
1560 }
1561