/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file c_api_symbolic.cc
 * \brief C API of mxnet
 */
#include "mxnet/base.h"
#include "mxnet/c_api.h"
#include "mxnet/imperative.h"
#include "nnvm/c_api.h"
#include "nnvm/pass.h"
#include "nnvm/pass_functions.h"
#include "nnvm/symbolic.h"
#include "./c_api_common.h"
#include "../common/exec_utils.h"
#include "../operator/operator_common.h"
#include "../executor/exec_pass.h"
#include "../operator/subgraph/subgraph_property.h"
36
37 namespace mxnet {
38 namespace op {
39 void RegisterLegacyOpProp();
40 void RegisterLegacyNDFunc();
41 }
42 const std::vector<std::string> kHiddenKeys = {
43 "ctx_group", "lr_mult", "wd_mult", "force_mirroring", "mirror_stage"
44 };
45 const std::vector<std::string> kReplacedHiddenKeys = {
46 "__ctx_group__", "__lr_mult__", "__wd_mult__", "__force_mirroring__", "__mirror_stage__"
47 };
48 const char *kNamespaceSeparator = "$";
49
50
51 DMLC_JSON_ENABLE_ANY(int, int);
52
53 // convert nnvm symbol to a nnvm graph.
Symbol2Graph(const nnvm::Symbol & s)54 nnvm::Graph Symbol2Graph(const nnvm::Symbol &s) {
55 nnvm::Graph g;
56 g.outputs = s.outputs;
57 g.attrs["mxnet_version"] = std::make_shared<nnvm::any>(static_cast<int>(MXNET_VERSION));
58 if (Imperative::Get()->is_np_shape()) {
59 g.attrs["is_np_shape"] = std::make_shared<nnvm::any>(
60 static_cast<int>(Imperative::Get()->is_np_shape()));
61 }
62 return g;
63 }
64
ReadOnlyArgIndices(const nnvm::IndexedGraph & idx)65 std::vector<uint32_t> ReadOnlyArgIndices(const nnvm::IndexedGraph& idx) {
66 std::vector<uint32_t> ret;
67 auto& arg_nodes = idx.input_nodes();
68 for (uint32_t i = 0; i < arg_nodes.size(); ++i) {
69 if (idx.mutable_input_nodes().count(arg_nodes[i]) == 0) {
70 ret.push_back(i);
71 }
72 }
73 return ret;
74 }
75
76 } // namespace mxnet
77
78 // symbolic configuration generation API.
79 // Redirect to NNVM's C API
MXListAllOpNames(nn_uint * out_size,const char *** out_array)80 int MXListAllOpNames(nn_uint *out_size,
81 const char ***out_array) {
82 mxnet::op::RegisterLegacyOpProp();
83 mxnet::op::RegisterLegacyNDFunc();
84 return NNListAllOpNames(out_size, out_array);
85 }
86
MXSymbolListAtomicSymbolCreators(uint32_t * out_size,AtomicSymbolCreator ** out_array)87 int MXSymbolListAtomicSymbolCreators(uint32_t *out_size,
88 AtomicSymbolCreator **out_array) {
89 mxnet::op::RegisterLegacyOpProp();
90 mxnet::op::RegisterLegacyNDFunc();
91 return NNListUniqueOps(out_size, out_array);
92 }
93
MXSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,const char ** name,const char ** description,uint32_t * num_args,const char *** arg_names,const char *** arg_type_infos,const char *** arg_descriptions,const char ** key_var_num_args,const char ** return_type)94 int MXSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
95 const char **name,
96 const char **description,
97 uint32_t *num_args,
98 const char ***arg_names,
99 const char ***arg_type_infos,
100 const char ***arg_descriptions,
101 const char **key_var_num_args,
102 const char **return_type) {
103 static auto& map_key_var_args = nnvm::Op::GetAttr<std::string>("key_var_num_args");
104 const Op* op = static_cast<Op*>(creator);
105 MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
106 ret->ret_str.resize(0);
107
108 if (map_key_var_args.count(op) != 0) {
109 *key_var_num_args = map_key_var_args[op].c_str();
110 } else {
111 *key_var_num_args = ret->ret_str.c_str();
112 }
113 return NNGetOpInfo(
114 creator, name, description,
115 num_args, arg_names, arg_type_infos,
116 arg_descriptions, return_type);
117 }
118
MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,uint32_t num_param,const char ** keys,const char ** vals,SymbolHandle * out)119 int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,
120 uint32_t num_param,
121 const char **keys,
122 const char **vals,
123 SymbolHandle *out) {
124 nnvm::Symbol *s = new nnvm::Symbol();
125 API_BEGIN();
126 const nnvm::Op* op = static_cast<const nnvm::Op*>(creator);
127 std::unordered_map<std::string, std::string> kwargs;
128 for (nn_uint i = 0; i < num_param; ++i) {
129 bool flag = false;
130 for (const auto &k : kHiddenKeys) {
131 std::string tmp(keys[i]);
132 size_t pos = tmp.rfind(k);
133 if (pos == 0) {
134 kwargs.insert({"__" + tmp + "__", std::string(vals[i])});
135 flag = true;
136 break;
137 } else if (pos != std::string::npos && pos == tmp.length() - k.length()) {
138 std::ostringstream os;
139 os << "setting variable attributes with " << keys[i] << " is deprecated. "
140 << "please instead use\nw = Variable(" << k << "=" << vals[i] << ")\n"
141 << "sym = YourSymbolName(" << tmp.substr(0, pos-1) << "=w)";
142 throw dmlc::Error(os.str());
143 }
144 }
145 if (!flag)
146 kwargs.insert({std::string(keys[i]), std::string(vals[i])});
147 }
148 *s = nnvm::Symbol::CreateFunctor(op, std::move(kwargs));
149 *out = s;
150 API_END_HANDLE_ERROR(delete s;);
151 }
152
MXSymbolCreateVariable(const char * name,SymbolHandle * out)153 int MXSymbolCreateVariable(const char *name, SymbolHandle *out) {
154 return NNSymbolCreateVariable(name, out);
155 }
156
MXSymbolCreateGroup(uint32_t num_symbols,SymbolHandle * symbols,SymbolHandle * out)157 int MXSymbolCreateGroup(uint32_t num_symbols,
158 SymbolHandle *symbols,
159 SymbolHandle *out) {
160 return NNSymbolCreateGroup(num_symbols, symbols, out);
161 }
162
MXSymbolGetOutput(SymbolHandle symbol,uint32_t index,SymbolHandle * out)163 int MXSymbolGetOutput(SymbolHandle symbol,
164 uint32_t index,
165 SymbolHandle *out) {
166 return NNSymbolGetOutput(symbol, index, out);
167 }
168
MXSymbolGetInternals(SymbolHandle symbol,SymbolHandle * out)169 int MXSymbolGetInternals(SymbolHandle symbol,
170 SymbolHandle *out) {
171 nnvm::Symbol *s = new nnvm::Symbol();
172 API_BEGIN();
173 *s = static_cast<nnvm::Symbol*>(symbol)->GetInternals();
174 *out = s;
175 API_END_HANDLE_ERROR(delete s);
176 }
177
MXSymbolGetChildren(SymbolHandle symbol,SymbolHandle * out)178 int MXSymbolGetChildren(SymbolHandle symbol,
179 SymbolHandle *out) {
180 nnvm::Symbol *s = new nnvm::Symbol();
181 API_BEGIN();
182 *s = static_cast<nnvm::Symbol*>(symbol)->GetChildren();
183 *out = s;
184 API_END_HANDLE_ERROR(delete s);
185 }
186
MXSymbolFree(SymbolHandle symbol)187 int MXSymbolFree(SymbolHandle symbol) {
188 return NNSymbolFree(symbol);
189 }
190
MXSymbolCopy(SymbolHandle symbol,SymbolHandle * out)191 int MXSymbolCopy(SymbolHandle symbol, SymbolHandle *out) {
192 return NNSymbolCopy(symbol, out);
193 }
194
MXSymbolPrint(SymbolHandle symbol,const char ** out_str)195 int MXSymbolPrint(SymbolHandle symbol, const char **out_str) {
196 return NNSymbolPrint(symbol, out_str);
197 }
198
MXSymbolGetName(SymbolHandle symbol,const char ** out,int * success)199 int MXSymbolGetName(SymbolHandle symbol,
200 const char** out,
201 int* success) {
202 return NNSymbolGetAttr(symbol, "name", out, success);
203 }
204
MXSymbolGetAttr(SymbolHandle symbol,const char * key,const char ** out,int * success)205 int MXSymbolGetAttr(SymbolHandle symbol,
206 const char* key,
207 const char** out,
208 int* success) {
209 nnvm::Symbol *s = static_cast<nnvm::Symbol*>(symbol);
210 MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
211 API_BEGIN();
212 if (s->GetAttr(key, &(ret->ret_str))) {
213 *out = (ret->ret_str).c_str();
214 *success = 1;
215 } else {
216 *out = nullptr;
217 *success = 0;
218 if (std::find(kHiddenKeys.begin(), kHiddenKeys.end(), key) != kHiddenKeys.end()) {
219 std::string skey = "__" + std::string(key) + "__";
220 if (s->GetAttr(skey, &(ret->ret_str))) {
221 *out = (ret->ret_str).c_str();
222 *success = 1;
223 }
224 }
225 }
226 API_END();
227 }
228
MXSymbolSetAttr(SymbolHandle symbol,const char * key,const char * value)229 int MXSymbolSetAttr(SymbolHandle symbol,
230 const char* key,
231 const char* value) {
232 nnvm::Symbol *s = static_cast<nnvm::Symbol*>(symbol);
233 API_BEGIN();
234 std::vector<std::pair<std::string, std::string> > kwargs;
235 std::string skey(key), sval(value);
236 for (const auto &k : kHiddenKeys) {
237 size_t pos = skey.rfind(k);
238 if (pos == 0 && k.length() == skey.length()) {
239 skey = "__" + skey + "__";
240 break;
241 } else if (pos != std::string::npos && pos + k.length() == skey.length()) {
242 std::ostringstream os;
243 os << "setting variable attributes with " << key << " is deprecated. "
244 << "please instead use\nw = Variable(" << k << "=" << value << ")\n"
245 << "sym = YourSymbolName(" << skey.substr(0, pos-1) << "=w)";
246 throw dmlc::Error(os.str());
247 }
248 }
249 kwargs.emplace_back(std::make_pair(std::move(skey), std::move(sval)));
250 s->SetAttrs(kwargs);
251 API_END();
252 }
253
MXSymbolListAttr(SymbolHandle symbol,uint32_t * out_size,const char *** out)254 int MXSymbolListAttr(SymbolHandle symbol,
255 uint32_t *out_size,
256 const char*** out) {
257 nnvm::Symbol *s = static_cast<nnvm::Symbol*>(symbol);
258 MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
259 API_BEGIN();
260 std::vector<std::tuple<std::string, std::string, std::string> > attr =
261 s->ListAttrsRecursive();
262
263 std::vector<std::string>& attr_list = ret->ret_vec_str;
264 attr_list.clear();
265 for (const auto& tp : attr) {
266 attr_list.emplace_back(std::get<0>(tp) + kNamespaceSeparator + std::get<1>(tp));
267 attr_list.emplace_back(std::get<2>(tp));
268 if (find(kReplacedHiddenKeys.begin(), kReplacedHiddenKeys.end(), std::get<1>(tp))
269 != kReplacedHiddenKeys.end()) {
270 attr_list.push_back(std::get<0>(tp) + kNamespaceSeparator +
271 std::get<1>(tp).substr(2, std::get<1>(tp).length() - 4));
272 attr_list.push_back(std::get<2>(tp));
273 }
274 }
275 *out_size = attr_list.size()/2;
276 ret->ret_vec_charp.clear();
277 for (const auto& attr : attr_list) {
278 ret->ret_vec_charp.push_back(attr.c_str());
279 }
280 *out = dmlc::BeginPtr(ret->ret_vec_charp);
281 API_END();
282 }
283
MXSymbolListAttrShallow(SymbolHandle symbol,uint32_t * out_size,const char *** out)284 int MXSymbolListAttrShallow(SymbolHandle symbol,
285 uint32_t *out_size,
286 const char*** out) {
287 nnvm::Symbol *s = static_cast<nnvm::Symbol*>(symbol);
288 MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
289 API_BEGIN();
290 std::unordered_map<std::string, std::string> attr =
291 s->ListAttrs(static_cast<nnvm::Symbol::ListAttrOption>(1)); // NOLINT(*)
292
293 std::vector<std::string>& attr_list = ret->ret_vec_str;
294 attr_list.clear();
295 for (const auto& kv : attr) {
296 attr_list.push_back(kv.first);
297 attr_list.push_back(kv.second);
298 if (find(kReplacedHiddenKeys.begin(), kReplacedHiddenKeys.end(), kv.first)
299 != kReplacedHiddenKeys.end()) {
300 attr_list.push_back(kv.first.substr(2, kv.first.length() - 4));
301 attr_list.push_back(kv.second);
302 }
303 }
304 *out_size = attr_list.size()/2;
305 ret->ret_vec_charp.clear();
306 for (auto &attr : attr_list) {
307 ret->ret_vec_charp.push_back(attr.c_str());
308 }
309 *out = dmlc::BeginPtr(ret->ret_vec_charp);
310 API_END();
311 }
312
MXSymbolListOutputs(SymbolHandle symbol,uint32_t * out_size,const char *** out_str_array)313 int MXSymbolListOutputs(SymbolHandle symbol,
314 uint32_t *out_size,
315 const char ***out_str_array) {
316 return NNSymbolListOutputNames(symbol, out_size, out_str_array);
317 }
318
MXSymbolGetNumOutputs(SymbolHandle symbol,uint32_t * output_count)319 int MXSymbolGetNumOutputs(SymbolHandle symbol,
320 uint32_t *output_count) {
321 return NNSymbolGetNumOutputs(symbol, output_count);
322 }
323
MXSymbolCompose(SymbolHandle sym,const char * name,uint32_t num_args,const char ** keys,SymbolHandle * args)324 int MXSymbolCompose(SymbolHandle sym,
325 const char *name,
326 uint32_t num_args,
327 const char** keys,
328 SymbolHandle* args) {
329 return NNSymbolCompose(sym, name, num_args, keys, args);
330 }
331
332 // adapter functions that re-implements the functions.
MXSymbolListArguments(SymbolHandle symbol,uint32_t * out_size,const char *** out_str_array)333 int MXSymbolListArguments(SymbolHandle symbol,
334 uint32_t *out_size,
335 const char ***out_str_array) {
336 return NNSymbolListInputNames(symbol, 1, out_size, out_str_array);
337 }
338
MXSymbolListAuxiliaryStates(SymbolHandle symbol,uint32_t * out_size,const char *** out_str_array)339 int MXSymbolListAuxiliaryStates(SymbolHandle symbol,
340 uint32_t *out_size,
341 const char ***out_str_array) {
342 return NNSymbolListInputNames(symbol, 2, out_size, out_str_array);
343 }
344
MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator,const char ** out)345 int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator,
346 const char **out) {
347 API_BEGIN();
348 Op *e = static_cast<Op *>(creator);
349 *out = e->name.c_str();
350 API_END();
351 }
352
353 namespace mxnet {
354
355 extern std::vector<nnvm::Symbol *> GetInputSymbols(const nnvm::Symbol &sym);
356 extern bool CutGraphInputs(const std::vector<nnvm::NodeEntry *> &input_entries,
357 bool skip_var, std::vector<nnvm::NodeEntry> *orig_entries);
358
359 }
360
MXSymbolGetInputSymbols(SymbolHandle sym,SymbolHandle ** input_arr,int * input_size)361 int MXSymbolGetInputSymbols(SymbolHandle sym, SymbolHandle **input_arr, int *input_size) {
362 API_BEGIN();
363 nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
364 std::vector<nnvm::Symbol *> input_syms = mxnet::GetInputSymbols(*s);
365 *input_size = input_syms.size();
366
367 MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
368 ret->ret_handles.clear();
369 ret->ret_handles.reserve(*input_size);
370 for (int i = 0; i < *input_size; ++i) ret->ret_handles.push_back(input_syms[i]);
371 *input_arr = reinterpret_cast<SymbolHandle*>(dmlc::BeginPtr(ret->ret_handles));
372 API_END_HANDLE_ERROR();
373 }
374
MXSymbolCutSubgraph(SymbolHandle sym,SymbolHandle ** input_symbols,int * input_size)375 int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle **input_symbols,
376 int *input_size) {
377 // Given a graph, we want to fetch the nodes that have been marked as part of
378 // a subgraph.
379 API_BEGIN();
380 nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
381 const std::string subg_attr = "__subgraph_name__";
382 auto out_node = s->outputs[0].node;
383 auto it = out_node->attrs.dict.find(subg_attr);
384 if (it != out_node->attrs.dict.end()) {
385 const std::string &subg_name = it->second;
386 std::vector<nnvm::NodeEntry *> input_entries;
387 DFSVisit(s->outputs, [&subg_attr, &subg_name, &input_entries]
388 (nnvm::ObjectPtr n) {
389 // If the node itself isn't in the subgraph, we ignore it.
390 auto it = n->attrs.dict.find(subg_attr);
391 if (it == n->attrs.dict.end() || it->second != subg_name)
392 return;
393
394 // We search for nodes whose node entries aren't in the subgraph.
395 for (size_t j = 0; j < n->inputs.size(); j++) {
396 auto in_node = n->inputs[j].node;
397 auto it = in_node->attrs.dict.find(subg_attr);
398 if (it == in_node->attrs.dict.end() || it->second != subg_name)
399 input_entries.push_back(&n->inputs[j]);
400 }
401 });
402
403 std::vector<nnvm::NodeEntry> orig_entries;
404 CutGraphInputs(input_entries, false, &orig_entries);
405 std::vector<nnvm::Symbol *> input_syms(orig_entries.size());
406 for (size_t i = 0; i < input_syms.size(); i++) {
407 input_syms[i] = new nnvm::Symbol();
408 input_syms[i]->outputs.push_back(orig_entries[i]);
409 }
410 *input_size = input_syms.size();
411
412 MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
413 ret->ret_handles.clear();
414 ret->ret_handles.reserve(*input_size);
415 for (int i = 0; i < *input_size; ++i) ret->ret_handles.push_back(input_syms[i]);
416 *input_symbols = reinterpret_cast<SymbolHandle*>(dmlc::BeginPtr(ret->ret_handles));
417 } else {
418 *input_size = 0;
419 }
420
421 API_END_HANDLE_ERROR();
422 }
423
424
425 /*!
426 * \brief Convert shape attr in graph nodes to comply with NumPy semantics for
427 * legacy models (before 1.6.0) if global flag is_np_shape has been turned on,
428 * i.e., use -1 to indicate unknown number of dimensions and unknown dimension sizes.
429 */
ConvertShapeAttrToNumPyCompatible(nnvm::Graph * g)430 void ConvertShapeAttrToNumPyCompatible(nnvm::Graph* g) {
431 if (Imperative::Get()->is_np_shape()
432 && (!g->HasAttr("is_np_shape") || !g->GetAttr<int>("is_np_shape"))) {
433 DFSVisit(g->outputs, [](nnvm::ObjectPtr n) {
434 if (n->is_variable()) {
435 auto it = n->attrs.dict.find("__shape__");
436 if (it != n->attrs.dict.end()) {
437 mxnet::TShape shape;
438 std::istringstream is(it->second);
439 is >> shape;
440 common::ConvertToNumpyShape(&shape);
441 std::ostringstream os;
442 os << shape;
443 it->second = os.str();
444 }
445 }
446 });
447 }
448 }
449
MXSymbolCreateFromFile(const char * fname,SymbolHandle * out)450 int MXSymbolCreateFromFile(const char *fname, SymbolHandle *out) {
451 nnvm::Symbol *s = new nnvm::Symbol();
452 API_BEGIN();
453 std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
454 dmlc::istream is(fi.get());
455 nnvm::Graph g;
456 g.attrs["json"] = std::make_shared<nnvm::any>(
457 std::string(std::istreambuf_iterator<char>(is), std::istreambuf_iterator<char>()));
458 g = nnvm::ApplyPass(g, "LoadLegacyJSON");
459 ConvertShapeAttrToNumPyCompatible(&g);
460 s->outputs = g.outputs;
461 *out = s;
462 is.set_stream(nullptr);
463 API_END_HANDLE_ERROR(delete s);
464 }
465
MXSymbolCreateFromJSON(const char * json,SymbolHandle * out)466 int MXSymbolCreateFromJSON(const char *json, SymbolHandle *out) {
467 nnvm::Symbol *s = new nnvm::Symbol();
468 API_BEGIN();
469 nnvm::Graph g;
470 g.attrs["json"] = std::make_shared<nnvm::any>(std::string(json));
471 g = nnvm::ApplyPass(g, "LoadLegacyJSON");
472 ConvertShapeAttrToNumPyCompatible(&g);
473 s->outputs = g.outputs;
474 *out = s;
475 API_END_HANDLE_ERROR(delete s);
476 }
477
MXSymbolRemoveAmpCast(SymbolHandle sym_handle,SymbolHandle * ret_sym_handle)478 int MXSymbolRemoveAmpCast(SymbolHandle sym_handle, SymbolHandle* ret_sym_handle) {
479 nnvm::Symbol* s = new nnvm::Symbol();
480 API_BEGIN();
481 nnvm::Symbol *source = static_cast<nnvm::Symbol*>(sym_handle);
482 *s = source->Copy();
483 s->outputs = nnvm::ApplyPass(Symbol2Graph(*s), "RemoveAmpCast").outputs;
484 *ret_sym_handle = s;
485 API_END_HANDLE_ERROR(delete s);
486 }
487
MXSymbolSaveToFile(SymbolHandle symbol,const char * fname)488 int MXSymbolSaveToFile(SymbolHandle symbol, const char *fname) {
489 nnvm::Symbol *s = static_cast<nnvm::Symbol*>(symbol);
490 API_BEGIN();
491 std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname, "w"));
492 dmlc::ostream os(fo.get());
493 os << nnvm::pass::SaveJSON(Symbol2Graph(*s));
494 // reset file pointer, force flush
495 os.set_stream(nullptr);
496 API_END();
497 }
498
MXSymbolSaveToJSON(SymbolHandle symbol,const char ** out_json)499 int MXSymbolSaveToJSON(SymbolHandle symbol, const char **out_json) {
500 nnvm::Symbol *s = static_cast<nnvm::Symbol*>(symbol);
501 MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
502 API_BEGIN();
503 ret->ret_str = nnvm::pass::SaveJSON(Symbol2Graph(*s));
504 *out_json = ret->ret_str.c_str();
505 API_END();
506 }
507
508 namespace mxnet {
509
510 template<typename AttrType>
MatchArguments(const nnvm::IndexedGraph & idx,const std::unordered_map<std::string,AttrType> & known_arg_attrs,std::vector<AttrType> * arg_attrs,const char * source)511 void MatchArguments(
512 const nnvm::IndexedGraph& idx,
513 const std::unordered_map<std::string, AttrType>& known_arg_attrs,
514 std::vector<AttrType>* arg_attrs,
515 const char* source) {
516 auto& arg_nodes = idx.input_nodes();
517 CHECK_EQ(arg_attrs->size(), arg_nodes.size());
518 size_t nmatched = 0;
519 for (size_t i = 0; i < arg_nodes.size(); ++i) {
520 const std::string& name = idx[arg_nodes[i]].source->attrs.name;
521 auto it = known_arg_attrs.find(name);
522 if (it != known_arg_attrs.end()) {
523 arg_attrs->at(i) = it->second;
524 ++nmatched;
525 }
526 }
527 if (nmatched != known_arg_attrs.size()) {
528 std::unordered_set<std::string> keys;
529 std::ostringstream head, msg;
530 msg << "\nCandidate arguments:\n";
531 for (size_t i = 0; i < arg_nodes.size(); ++i) {
532 std::string arg_name = idx[arg_nodes[i]].source->attrs.name;
533 keys.insert(arg_name);
534 msg << "\t[" << i << ']' << arg_name << '\n';
535 }
536 for (const auto& kv : known_arg_attrs) {
537 const std::string& key = kv.first;
538 if (keys.count(key) == 0) {
539 LOG(FATAL) << source
540 << "Keyword argument name " << key << " not found."
541 << msg.str();
542 }
543 }
544 }
545 }
546
547 } // namespace mxnet
548
MXSymbolInferShape(SymbolHandle sym,uint32_t num_args,const char ** keys,const uint32_t * arg_ind_ptr,const uint32_t * arg_shape_data,uint32_t * in_shape_size,const uint32_t ** in_shape_ndim,const uint32_t *** in_shape_data,uint32_t * out_shape_size,const uint32_t ** out_shape_ndim,const uint32_t *** out_shape_data,uint32_t * aux_shape_size,const uint32_t ** aux_shape_ndim,const uint32_t *** aux_shape_data,int * complete)549 int MXSymbolInferShape(SymbolHandle sym,
550 uint32_t num_args,
551 const char** keys,
552 const uint32_t *arg_ind_ptr,
553 const uint32_t *arg_shape_data,
554 uint32_t *in_shape_size,
555 const uint32_t **in_shape_ndim,
556 const uint32_t ***in_shape_data,
557 uint32_t *out_shape_size,
558 const uint32_t **out_shape_ndim,
559 const uint32_t ***out_shape_data,
560 uint32_t *aux_shape_size,
561 const uint32_t **aux_shape_ndim,
562 const uint32_t ***aux_shape_data,
563 int *complete) {
564 nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
565 MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
566 API_BEGIN();
567 nnvm::Graph g = Symbol2Graph(*s);
568 mxnet::ShapeVector arg_shapes(g.indexed_graph().input_nodes().size(), mxnet::TShape());
569 if (keys == nullptr && num_args != 0) {
570 std::vector<uint32_t> read_only_args = mxnet::ReadOnlyArgIndices(g.indexed_graph());
571 CHECK_LE(num_args, read_only_args.size());
572 for (uint32_t i = 0; i < num_args; ++i) {
573 arg_shapes[read_only_args[i]] = mxnet::ShapeTypeCast(
574 arg_shape_data + arg_ind_ptr[i], arg_shape_data + arg_ind_ptr[i+1]);
575 }
576 } else {
577 std::unordered_map<std::string, mxnet::TShape> kwargs;
578 for (uint32_t i = 0; i < num_args; ++i) {
579 kwargs[keys[i]] = mxnet::ShapeTypeCast(
580 arg_shape_data + arg_ind_ptr[i], arg_shape_data + arg_ind_ptr[i+1]);
581 }
582 mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_shapes, "InferShape");
583 }
584
585 try {
586 g = mxnet::exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
587 } catch (const mxnet::op::InferShapeError &err) {
588 throw dmlc::Error(err.msg);
589 }
590
591 // if use legacy shape definition, need to convert numpy shape to legacy shape
592 mxnet::ShapeVector shapes = g.GetAttr<mxnet::ShapeVector>("shape");
593 if (!Imperative::Get()->is_np_shape()) {
594 common::ConvertToLegacyShape(&shapes);
595 }
596
597 // copy back
598 CopyAttr(g.indexed_graph(), shapes,
599 &(ret->arg_shapes), &(ret->out_shapes), &(ret->aux_shapes));
600
601 // copy data back
602 MXAPIThreadLocalEntry<>::SetupShapeArrayReturnWithBuffer(ret->arg_shapes,
603 &(ret->arg_shape_ndim), &(ret->arg_shape_data), &(ret->arg_shape_buffer));
604 MXAPIThreadLocalEntry<>::SetupShapeArrayReturnWithBuffer(ret->out_shapes,
605 &(ret->out_shape_ndim), &(ret->out_shape_data), &(ret->out_shape_buffer));
606 MXAPIThreadLocalEntry<>::SetupShapeArrayReturnWithBuffer(ret->aux_shapes,
607 &(ret->aux_shape_ndim), &(ret->aux_shape_data), &(ret->aux_shape_buffer));
608 *in_shape_size = static_cast<uint32_t>(ret->arg_shapes.size());
609 *in_shape_ndim = dmlc::BeginPtr(ret->arg_shape_ndim);
610 *in_shape_data = dmlc::BeginPtr(ret->arg_shape_data);
611 *out_shape_size = static_cast<uint32_t>(ret->out_shapes.size());
612 *out_shape_ndim = dmlc::BeginPtr(ret->out_shape_ndim);
613 *out_shape_data = dmlc::BeginPtr(ret->out_shape_data);
614 *aux_shape_size = static_cast<uint32_t>(ret->aux_shapes.size());
615 *aux_shape_ndim = dmlc::BeginPtr(ret->aux_shape_ndim);
616 *aux_shape_data = dmlc::BeginPtr(ret->aux_shape_data);
617 // mark complete
618 *complete = (g.GetAttr<size_t>("shape_num_unknown_nodes") == 0);
619 API_END();
620 }
621
622 template<typename dtype, typename stype, typename itype>
SymbolInferShape(const char ** keys,uint32_t num_args,const dtype * arg_shape_data,const itype * arg_ind_ptr,const int ** in_shape_ndim,const dtype *** in_shape_data,const int ** out_shape_ndim,const dtype *** out_shape_data,const int ** aux_shape_ndim,const dtype *** aux_shape_data,nnvm::Symbol * s,MXAPIThreadLocalEntry<dtype> * ret,stype * in_shape_size,stype * out_shape_size,stype * aux_shape_size,int * complete)623 inline void SymbolInferShape(const char** keys,
624 uint32_t num_args,
625 const dtype* arg_shape_data,
626 const itype* arg_ind_ptr,
627 const int** in_shape_ndim,
628 const dtype*** in_shape_data,
629 const int** out_shape_ndim,
630 const dtype*** out_shape_data,
631 const int** aux_shape_ndim,
632 const dtype*** aux_shape_data,
633 nnvm::Symbol* s,
634 MXAPIThreadLocalEntry<dtype>* ret,
635 stype* in_shape_size,
636 stype* out_shape_size,
637 stype* aux_shape_size,
638 int* complete) {
639 nnvm::Graph g = Symbol2Graph(*s);
640 mxnet::ShapeVector arg_shapes(g.indexed_graph().input_nodes().size(), mxnet::TShape());
641 if (keys == nullptr && num_args != 0) {
642 std::vector < uint32_t > read_only_args = mxnet::ReadOnlyArgIndices(g.indexed_graph());
643 CHECK_LE(num_args, read_only_args.size());
644 for (uint32_t i = 0; i < num_args; ++i) {
645 arg_shapes[read_only_args[i]] = mxnet::ShapeTypeCast(arg_shape_data + arg_ind_ptr[i],
646 arg_shape_data + arg_ind_ptr[i + 1]);
647 }
648 } else {
649 std::unordered_map<std::string, mxnet::TShape> kwargs;
650 for (uint32_t i = 0; i < num_args; ++i) {
651 kwargs[keys[i]] = mxnet::ShapeTypeCast(arg_shape_data + arg_ind_ptr[i],
652 arg_shape_data + arg_ind_ptr[i + 1]);
653 }
654 mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_shapes, "InferShape");
655 }
656 try {
657 g = mxnet::exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
658 } catch (const mxnet::op::InferShapeError& err) {
659 throw dmlc::Error(err.msg);
660 }
661 // if use legacy shape definition, need to convert numpy shape to legacy shape
662 mxnet::ShapeVector shapes = g.GetAttr<mxnet::ShapeVector>("shape");
663 if (!Imperative::Get()->is_np_shape()) {
664 common::ConvertToLegacyShape(&shapes);
665 }
666 // copy back
667 CopyAttr(g.indexed_graph(), shapes, &(ret->arg_shapes), &(ret->out_shapes), &(ret->aux_shapes));
668 // copy data back
669 MXAPIThreadLocalEntry<dtype>::SetupShapeArrayReturnWithBufferEx(ret->arg_shapes,
670 &(ret->arg_shape_ndim_ex),
671 &(ret->arg_shape_data_ex),
672 &(ret->arg_shape_buffer_ex));
673 MXAPIThreadLocalEntry<dtype>::SetupShapeArrayReturnWithBufferEx(ret->out_shapes,
674 &(ret->out_shape_ndim_ex),
675 &(ret->out_shape_data_ex),
676 &(ret->out_shape_buffer_ex));
677 MXAPIThreadLocalEntry<dtype>::SetupShapeArrayReturnWithBufferEx(ret->aux_shapes,
678 &(ret->aux_shape_ndim_ex),
679 &(ret->aux_shape_data_ex),
680 &(ret->aux_shape_buffer_ex));
681 *in_shape_size = static_cast<stype>(ret->arg_shapes.size());
682 *in_shape_ndim = dmlc::BeginPtr(ret->arg_shape_ndim_ex);
683 *in_shape_data = dmlc::BeginPtr(ret->arg_shape_data_ex);
684 *out_shape_size = static_cast<stype>(ret->out_shapes.size());
685 *out_shape_ndim = dmlc::BeginPtr(ret->out_shape_ndim_ex);
686 *out_shape_data = dmlc::BeginPtr(ret->out_shape_data_ex);
687 *aux_shape_size = static_cast<stype>(ret->aux_shapes.size());
688 *aux_shape_ndim = dmlc::BeginPtr(ret->aux_shape_ndim_ex);
689 *aux_shape_data = dmlc::BeginPtr(ret->aux_shape_data_ex);
690 // mark complete
691 *complete = (g.GetAttr<size_t>("shape_num_unknown_nodes") == 0);
692 }
693
694 /*!
695 * \brief Executor for Symbol Shape Inference
696 * This api is available when MXNet is built with flag
697 * USE_INT64_TENSOR_SIZE=0 (by default)
698 * \param sym symbol handle
699 * \param num_args number of args
700 * \param keys keys
701 * \param arg_ind_ptr arg index pointer
702 * \param arg_shape_data arg shape data
703 * \param in_shape_size input shape size
704 * \param in_shape_ndim input shape number of dims
705 * \param in_shape_data input shape data
706 * \param out_shape_size ouput shape size
707 * \param out_shape_ndim output shape number of dims
708 * \param out_shape_data output shape data
709 * \param aux_shape_size shape size of auxiliary states
710 * \param aux_shape_ndim number of dims of auxiliary states shape
711 * \param aux_shape_data shape data of auxiliary states
712 * \param complete indicates completion of Shape Inference
713 * \return 0 when success, -1 when failure happens
714 */
MXSymbolInferShapeEx(SymbolHandle sym,uint32_t num_args,const char ** keys,const uint32_t * arg_ind_ptr,const int * arg_shape_data,uint32_t * in_shape_size,const int ** in_shape_ndim,const int *** in_shape_data,uint32_t * out_shape_size,const int ** out_shape_ndim,const int *** out_shape_data,uint32_t * aux_shape_size,const int ** aux_shape_ndim,const int *** aux_shape_data,int * complete)715 int MXSymbolInferShapeEx(SymbolHandle sym,
716 uint32_t num_args,
717 const char** keys,
718 const uint32_t *arg_ind_ptr,
719 const int *arg_shape_data,
720 uint32_t *in_shape_size,
721 const int **in_shape_ndim,
722 const int ***in_shape_data,
723 uint32_t *out_shape_size,
724 const int **out_shape_ndim,
725 const int ***out_shape_data,
726 uint32_t *aux_shape_size,
727 const int **aux_shape_ndim,
728 const int ***aux_shape_data,
729 int *complete) {
730 nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
731 MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
732 API_BEGIN();
733 SymbolInferShape<int, uint32_t, uint32_t>(keys,
734 num_args,
735 arg_shape_data,
736 arg_ind_ptr,
737 in_shape_ndim,
738 in_shape_data,
739 out_shape_ndim,
740 out_shape_data,
741 aux_shape_ndim,
742 aux_shape_data,
743 s,
744 ret,
745 in_shape_size,
746 out_shape_size,
747 aux_shape_size,
748 complete);
749 API_END();
750 }
751
752 /*!
753 * \brief Executor for Symbol Shape Inference
754 * This api is available when MXNet is built with flag
755 * USE_INT64_TENSOR_SIZE=1 (not default) i.e. Large Tensor Support
756 * \param sym symbol handle
757 * \param num_args number of args
758 * \param keys keys
759 * \param arg_ind_ptr arg index pointer
760 * \param arg_shape_data arg shape data
761 * \param in_shape_size input shape size
762 * \param in_shape_ndim input shape number of dims
763 * \param in_shape_data input shape data
764 * \param out_shape_size ouput shape size
765 * \param out_shape_ndim output shape number of dims
766 * \param out_shape_data output shape data
767 * \param aux_shape_size shape size of auxiliary states
768 * \param aux_shape_ndim number of dims of auxiliary states shape
769 * \param aux_shape_data shape data of auxiliary states
770 * \param complete indicates completion of Shape Inference
771 * \return 0 when success, -1 when failure happens
772 */
MXSymbolInferShapeEx64(SymbolHandle sym,uint32_t num_args,const char ** keys,const int64_t * arg_ind_ptr,const int64_t * arg_shape_data,size_t * in_shape_size,const int ** in_shape_ndim,const int64_t *** in_shape_data,size_t * out_shape_size,const int ** out_shape_ndim,const int64_t *** out_shape_data,size_t * aux_shape_size,const int ** aux_shape_ndim,const int64_t *** aux_shape_data,int * complete)773 int MXSymbolInferShapeEx64(SymbolHandle sym,
774 uint32_t num_args,
775 const char** keys,
776 const int64_t *arg_ind_ptr,
777 const int64_t *arg_shape_data,
778 size_t *in_shape_size,
779 const int **in_shape_ndim,
780 const int64_t ***in_shape_data,
781 size_t *out_shape_size,
782 const int **out_shape_ndim,
783 const int64_t ***out_shape_data,
784 size_t *aux_shape_size,
785 const int **aux_shape_ndim,
786 const int64_t ***aux_shape_data,
787 int *complete) {
788 nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
789 MXAPIThreadLocalEntry<int64_t> *ret = MXAPIThreadLocalStore<int64_t>::Get();
790 API_BEGIN();
791 SymbolInferShape<int64_t, size_t, int64_t>(keys,
792 num_args,
793 arg_shape_data,
794 arg_ind_ptr,
795 in_shape_ndim,
796 in_shape_data,
797 out_shape_ndim,
798 out_shape_data,
799 aux_shape_ndim,
800 aux_shape_data,
801 s,
802 ret,
803 in_shape_size,
804 out_shape_size,
805 aux_shape_size,
806 complete);
807 API_END();
808 }
809
MXSymbolInferShapePartial(SymbolHandle sym,uint32_t num_args,const char ** keys,const uint32_t * arg_ind_ptr,const uint32_t * arg_shape_data,uint32_t * in_shape_size,const uint32_t ** in_shape_ndim,const uint32_t *** in_shape_data,uint32_t * out_shape_size,const uint32_t ** out_shape_ndim,const uint32_t *** out_shape_data,uint32_t * aux_shape_size,const uint32_t ** aux_shape_ndim,const uint32_t *** aux_shape_data,int * complete)810 int MXSymbolInferShapePartial(SymbolHandle sym,
811 uint32_t num_args,
812 const char** keys,
813 const uint32_t *arg_ind_ptr,
814 const uint32_t *arg_shape_data,
815 uint32_t *in_shape_size,
816 const uint32_t **in_shape_ndim,
817 const uint32_t ***in_shape_data,
818 uint32_t *out_shape_size,
819 const uint32_t **out_shape_ndim,
820 const uint32_t ***out_shape_data,
821 uint32_t *aux_shape_size,
822 const uint32_t **aux_shape_ndim,
823 const uint32_t ***aux_shape_data,
824 int *complete) {
825 int succ = 0;
826 *complete = 1;
827 return MXSymbolInferShape(sym, num_args, keys,
828 arg_ind_ptr, arg_shape_data,
829 in_shape_size, in_shape_ndim, in_shape_data,
830 out_shape_size, out_shape_ndim, out_shape_data,
831 aux_shape_size, aux_shape_ndim, aux_shape_data,
832 &succ);
833 }
834
835 /*!
836 * \brief Executor for Symbol Partial Shape Inference
837 * This api is available when MXNet is built with flag
838 * USE_INT64_TENSOR_SIZE=0 (by default)
839 * \param sym symbol handle
840 * \param num_args number of args
841 * \param keys keys
842 * \param arg_ind_ptr arg index pointer
843 * \param arg_shape_data arg shape data
844 * \param in_shape_size input shape size
845 * \param in_shape_ndim input shape number of dims
846 * \param in_shape_data input shape data
 * \param out_shape_size output shape size
848 * \param out_shape_ndim output shape number of dims
849 * \param out_shape_data output shape data
850 * \param aux_shape_size shape size of auxiliary states
851 * \param aux_shape_ndim number of dims of auxiliary states shape
852 * \param aux_shape_data shape data of auxiliary states
853 * \param complete indicates completion of Shape Inference
854 * \return 0 when success, -1 when failure happens
855 */
MXSymbolInferShapePartialEx(SymbolHandle sym,uint32_t num_args,const char ** keys,const uint32_t * arg_ind_ptr,const int * arg_shape_data,uint32_t * in_shape_size,const int ** in_shape_ndim,const int *** in_shape_data,uint32_t * out_shape_size,const int ** out_shape_ndim,const int *** out_shape_data,uint32_t * aux_shape_size,const int ** aux_shape_ndim,const int *** aux_shape_data,int * complete)856 int MXSymbolInferShapePartialEx(SymbolHandle sym,
857 uint32_t num_args,
858 const char** keys,
859 const uint32_t *arg_ind_ptr,
860 const int *arg_shape_data,
861 uint32_t *in_shape_size,
862 const int **in_shape_ndim,
863 const int ***in_shape_data,
864 uint32_t *out_shape_size,
865 const int **out_shape_ndim,
866 const int ***out_shape_data,
867 uint32_t *aux_shape_size,
868 const int **aux_shape_ndim,
869 const int ***aux_shape_data,
870 int *complete) {
871 int succ = 0;
872 *complete = 1;
873 return MXSymbolInferShapeEx(sym, num_args, keys,
874 arg_ind_ptr, arg_shape_data,
875 in_shape_size, in_shape_ndim, in_shape_data,
876 out_shape_size, out_shape_ndim, out_shape_data,
877 aux_shape_size, aux_shape_ndim, aux_shape_data,
878 &succ);
879 }
880
881 /*!
882 * \brief Executor for Symbol Partial Shape Inference
883 * This api is available when MXNet is built with flag
884 * USE_INT64_TENSOR_SIZE=1 (not default) i.e. Large Tensor Support
885 * \param sym symbol handle
886 * \param num_args number of args
887 * \param keys keys
888 * \param arg_ind_ptr arg index pointer
889 * \param arg_shape_data arg shape data
890 * \param in_shape_size input shape size
891 * \param in_shape_ndim input shape number of dims
892 * \param in_shape_data input shape data
 * \param out_shape_size output shape size
894 * \param out_shape_ndim output shape number of dims
895 * \param out_shape_data output shape data
896 * \param aux_shape_size shape size of auxiliary states
897 * \param aux_shape_ndim number of dims of auxiliary states shape
898 * \param aux_shape_data shape data of auxiliary states
899 * \param complete indicates completion of Shape Inference
900 * \return 0 when success, -1 when failure happens
901 */
MXSymbolInferShapePartialEx64(SymbolHandle sym,uint32_t num_args,const char ** keys,const int64_t * arg_ind_ptr,const int64_t * arg_shape_data,size_t * in_shape_size,const int ** in_shape_ndim,const int64_t *** in_shape_data,size_t * out_shape_size,const int ** out_shape_ndim,const int64_t *** out_shape_data,size_t * aux_shape_size,const int ** aux_shape_ndim,const int64_t *** aux_shape_data,int * complete)902 int MXSymbolInferShapePartialEx64(SymbolHandle sym,
903 uint32_t num_args,
904 const char** keys,
905 const int64_t *arg_ind_ptr,
906 const int64_t *arg_shape_data,
907 size_t *in_shape_size,
908 const int **in_shape_ndim,
909 const int64_t ***in_shape_data,
910 size_t *out_shape_size,
911 const int **out_shape_ndim,
912 const int64_t ***out_shape_data,
913 size_t *aux_shape_size,
914 const int **aux_shape_ndim,
915 const int64_t ***aux_shape_data,
916 int *complete) {
917 int succ = 0;
918 *complete = 1;
919 return MXSymbolInferShapeEx64(sym, num_args, keys,
920 arg_ind_ptr, arg_shape_data,
921 in_shape_size, in_shape_ndim, in_shape_data,
922 out_shape_size, out_shape_ndim, out_shape_data,
923 aux_shape_size, aux_shape_ndim, aux_shape_data,
924 &succ);
925 }
926
MXSymbolInferType(SymbolHandle sym,uint32_t num_args,const char ** keys,const int * arg_type_data,uint32_t * in_type_size,const int ** in_type_data,uint32_t * out_type_size,const int ** out_type_data,uint32_t * aux_type_size,const int ** aux_type_data,int * complete)927 int MXSymbolInferType(SymbolHandle sym,
928 uint32_t num_args,
929 const char** keys,
930 const int *arg_type_data,
931 uint32_t *in_type_size,
932 const int **in_type_data,
933 uint32_t *out_type_size,
934 const int **out_type_data,
935 uint32_t *aux_type_size,
936 const int **aux_type_data,
937 int *complete) {
938 nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
939 MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
940 API_BEGIN();
941 nnvm::Graph g = Symbol2Graph(*s);
942 nnvm::DTypeVector arg_types(g.indexed_graph().input_nodes().size(), -1);
943 if (keys == nullptr && num_args != 0) {
944 std::vector<uint32_t> read_only_args = mxnet::ReadOnlyArgIndices(g.indexed_graph());
945 CHECK_LE(num_args, read_only_args.size());
946 for (uint32_t i = 0; i < num_args; ++i) {
947 arg_types[read_only_args[i]] = arg_type_data[i];
948 }
949 } else {
950 std::unordered_map<std::string, int> kwargs;
951 for (uint32_t i = 0; i < num_args; ++i) {
952 kwargs[keys[i]] = arg_type_data[i];
953 }
954 mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_types, "InferType");
955 }
956
957 g = mxnet::exec::InferType(std::move(g), std::move(arg_types), "__dtype__");
958 // copy back
959 CopyAttr(g.indexed_graph(), g.GetAttr<nnvm::DTypeVector>("dtype"),
960 &(ret->arg_types), &(ret->out_types), &(ret->aux_types));
961
962 *in_type_size = static_cast<uint32_t>(ret->arg_types.size());
963 *in_type_data = dmlc::BeginPtr(ret->arg_types);
964 *out_type_size = static_cast<uint32_t>(ret->out_types.size());
965 *out_type_data = dmlc::BeginPtr(ret->out_types);
966 *aux_type_size = static_cast<uint32_t>(ret->aux_types.size());
967 *aux_type_data = dmlc::BeginPtr(ret->aux_types);
968 *complete = (g.GetAttr<size_t>("dtype_num_unknown_nodes") == 0);
969 API_END();
970 }
971
MXSymbolInferTypePartial(SymbolHandle sym,uint32_t num_args,const char ** keys,const int * arg_type_data,uint32_t * in_type_size,const int ** in_type_data,uint32_t * out_type_size,const int ** out_type_data,uint32_t * aux_type_size,const int ** aux_type_data,int * complete)972 int MXSymbolInferTypePartial(SymbolHandle sym,
973 uint32_t num_args,
974 const char** keys,
975 const int *arg_type_data,
976 uint32_t *in_type_size,
977 const int **in_type_data,
978 uint32_t *out_type_size,
979 const int **out_type_data,
980 uint32_t *aux_type_size,
981 const int **aux_type_data,
982 int *complete) {
983 int succ = 0;
984 *complete = 1;
985 return MXSymbolInferType(sym, num_args, keys,
986 arg_type_data,
987 in_type_size, in_type_data,
988 out_type_size, out_type_data,
989 aux_type_size, aux_type_data,
990 &succ);
991 }
992
int MXSymbolGrad(SymbolHandle sym, uint32_t num_wrt, const char** wrt, SymbolHandle* out) {
  API_BEGIN();
  // Symbolic gradient through the C API is unimplemented: every call fails
  // with this message (converted to an error status by the API macros).
  LOG(FATAL) << "not implemented";
  API_END();
}
998
MXQuantizeSymbol(SymbolHandle sym_handle,SymbolHandle * ret_sym_handle,const int * dev_type,const uint32_t num_excluded_sym_names,const char ** excluded_sym_names,const uint32_t num_excluded_op_names,const char ** excluded_op_names,const uint32_t num_offline,const char ** offline_params,const char * quantized_dtype,const bool calib_quantize,const char * quantize_mode,const char * quantize_granularity,mx_uint * out_num_calib_names,const char *** out_calib_names)999 int MXQuantizeSymbol(SymbolHandle sym_handle,
1000 SymbolHandle *ret_sym_handle,
1001 const int* dev_type,
1002 const uint32_t num_excluded_sym_names,
1003 const char **excluded_sym_names,
1004 const uint32_t num_excluded_op_names,
1005 const char **excluded_op_names,
1006 const uint32_t num_offline,
1007 const char **offline_params,
1008 const char *quantized_dtype,
1009 const bool calib_quantize,
1010 const char *quantize_mode,
1011 const char *quantize_granularity,
1012 mx_uint* out_num_calib_names,
1013 const char ***out_calib_names) {
1014 nnvm::Symbol *s = new nnvm::Symbol();
1015 API_BEGIN();
1016 nnvm::Symbol *sym = static_cast<nnvm::Symbol*>(sym_handle);
1017 nnvm::Graph g = Symbol2Graph(*sym);
1018 int target_dev = *dev_type;
1019 std::unordered_set<std::string> excluded_node_names;
1020 for (size_t i = 0; i < num_excluded_sym_names; ++i) {
1021 excluded_node_names.emplace(excluded_sym_names[i]);
1022 }
1023 std::unordered_set<std::string> excluded_op;
1024 for (size_t i = 0; i < num_excluded_op_names; ++i) {
1025 excluded_op.emplace(excluded_op_names[i]);
1026 }
1027 std::unordered_set<std::string> offline;
1028 for (size_t i = 0; i < num_offline; ++i) {
1029 offline.emplace(offline_params[i]);
1030 }
1031 std::string quantized_type(quantized_dtype);
1032 std::string quantized_mode(quantize_mode);
1033 std::string quantized_granularity(quantize_granularity);
1034 g.attrs["excluded_nodes"] = std::make_shared<nnvm::any>(std::move(excluded_node_names));
1035 g.attrs["excluded_ops"] = std::make_shared<nnvm::any>(std::move(excluded_op));
1036 g.attrs["offline_params"] = std::make_shared<nnvm::any>(std::move(offline));
1037 g.attrs["quantized_dtype"] = std::make_shared<nnvm::any>(std::move(quantized_type));
1038 g.attrs["target_ctx"] = std::make_shared<nnvm::any>(target_dev);
1039 g.attrs["quantize_mode"] = std::make_shared<nnvm::any>(std::move(quantized_mode));
1040 g.attrs["quantize_granularity"] = std::make_shared<nnvm::any>(std::move(quantized_granularity));
1041 g = ApplyPass(std::move(g), "QuantizeGraph");
1042 const auto& calib_nodes = g.GetAttr<std::vector<std::string>>("calib_nodes");
1043 MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
1044 ret->ret_vec_str = std::move(calib_nodes);
1045 *out_num_calib_names = ret->ret_vec_str.size();
1046 ret->ret_vec_charp.clear();
1047 ret->ret_vec_charp.reserve(ret->ret_vec_str.size());
1048 for (const auto &str : ret->ret_vec_str) {
1049 ret->ret_vec_charp.push_back(str.c_str());
1050 }
1051 *out_calib_names = dmlc::BeginPtr(ret->ret_vec_charp);
1052 s->outputs = g.outputs;
1053 *ret_sym_handle = s;
1054 API_END_HANDLE_ERROR(delete s);
1055 }
1056
1057 // helper function to add mapping of node_name -> dtype map
1058 // for the given indexed graph and inferred_dtypes
_SetInputDTypes(const nnvm::IndexedGraph & idx,const nnvm::DTypeVector & inferred_dtypes,std::unordered_map<std::string,int> * node_name_dtype_map,std::unordered_map<std::string,int> * node_without_dtype_map)1059 static void _SetInputDTypes(
1060 const nnvm::IndexedGraph& idx,
1061 const nnvm::DTypeVector& inferred_dtypes,
1062 std::unordered_map<std::string, int>* node_name_dtype_map,
1063 std::unordered_map<std::string, int>* node_without_dtype_map) {
1064 const std::string dtype_keyword = "__dtype__";
1065 for (uint32_t nid : idx.input_nodes()) {
1066 const auto& node = idx[nid].source;
1067 const auto& node_with_dtype = node->attrs.dict.find(dtype_keyword);
1068 // input nodes classified into nodes_with_dtype, nodes_without_dtype
1069 // This classification required because if param_names not provided
1070 // we want to update dtypes of only those nodes which have dtypes set
1071 // inferred_dtypes are obtained for the nodes, if unknown
1072 // dtype is set to fp32
1073 if (node_with_dtype != node->attrs.dict.end()) {
1074 if (inferred_dtypes[idx.entry_id(nid, 0)] == -1) {
1075 (*node_name_dtype_map)[node->attrs.name] = 0;
1076 } else {
1077 (*node_name_dtype_map)[node->attrs.name] =
1078 inferred_dtypes[idx.entry_id(nid, 0)];
1079 }
1080 } else {
1081 if (inferred_dtypes[idx.entry_id(nid, 0)] == -1) {
1082 (*node_without_dtype_map)[node->attrs.name] = 0;
1083 } else {
1084 (*node_without_dtype_map)[node->attrs.name] =
1085 inferred_dtypes[idx.entry_id(nid, 0)];
1086 }
1087 }
1088 }
1089 }
1090
1091 // helper function update the node dtype attrs for a vector of nodeptrs
1092 // given the node name to dtype information and the names of model_params
1093 // if model_params is provided the function will dtype of only model params.
1094 // if model_params is empty, the function will dtype of all nodes which had
1095 // a prior dtype set.
1096 // args is a const_reference vector of ObjectPtrs. ObjectPtrs are immutable but
1097 // the Nodes they are pointing will be mutated in this function
_UpdateSymDTypeAttrs(const std::unordered_map<std::string,int> & node_name_dtype_map,const std::unordered_map<std::string,int> & node_without_dtype_map,const std::unordered_set<std::string> & model_params,const std::vector<nnvm::ObjectPtr> & args)1098 static void _UpdateSymDTypeAttrs(
1099 const std::unordered_map<std::string, int>& node_name_dtype_map,
1100 const std::unordered_map<std::string, int>& node_without_dtype_map,
1101 const std::unordered_set<std::string>& model_params,
1102 const std::vector<nnvm::ObjectPtr>& args) {
1103 const std::string dtype_keyword = "__dtype__";
1104
1105 // Update args to have the right dtype attrs
1106 if (model_params.size() > 0) {
1107 // if model params provided, set dtype only for model params
1108 for (size_t i = 0; i < args.size(); ++i) {
1109 const std::string& node_name = args[i]->attrs.name;
1110 auto it_model_params = model_params.find(node_name);
1111 auto it_with_dtype = node_name_dtype_map.find(node_name);
1112 auto it_without_dtype = node_without_dtype_map.find(node_name);
1113 if (it_model_params != model_params.end()) {
1114 // need to update __dtype__ attribute if already set, else set it
1115 if (it_with_dtype != node_name_dtype_map.end()) {
1116 args[i]->attrs.dict[dtype_keyword] =
1117 std::to_string(it_with_dtype->second);
1118 } else {
1119 CHECK(it_without_dtype != node_without_dtype_map.end())
1120 << "make sure all nodes without dtype have properly been added "
1121 "in node_without_dtype_map";
1122 args[i]->attrs.dict[dtype_keyword] =
1123 std::to_string(it_without_dtype->second);
1124 }
1125 }
1126 }
1127 } else {
1128 // if model params not provided, update __dtype__ for all inputs,
1129 // which already had it set, don't touch the rest
1130 for (size_t i = 0; i < args.size(); ++i) {
1131 auto it = node_name_dtype_map.find(args[i]->attrs.name);
1132 if (it != node_name_dtype_map.end()) {
1133 if (args[i]->attrs.dict.find(dtype_keyword) !=
1134 args[i]->attrs.dict.end()) {
1135 args[i]->attrs.dict[dtype_keyword] = std::to_string(it->second);
1136 }
1137 }
1138 }
1139 }
1140 }
1141
MXReducePrecisionSymbol(SymbolHandle sym_handle,SymbolHandle * ret_sym_handle,uint32_t num_args,const int * arg_type_data,uint32_t num_ind_ptr,const int * ind_ptr,const int * target_dtype,const int cast_optional_params,const uint32_t num_target_dtype_op_names,const uint32_t num_fp32_op_names,const uint32_t num_widest_dtype_op_names,const uint32_t num_conditional_fp32_op_names,const uint32_t num_excluded_symbols,const uint32_t num_model_params,const char ** target_dtype_op_names,const char ** fp32_op_names,const char ** widest_dtype_op_names,const char ** conditional_fp32_op_names,const char ** excluded_symbols,const char ** param_names,const char ** param_vals,const char ** model_param_names,const char ** arg_names)1142 int MXReducePrecisionSymbol(SymbolHandle sym_handle,
1143 SymbolHandle *ret_sym_handle,
1144 uint32_t num_args,
1145 const int *arg_type_data,
1146 uint32_t num_ind_ptr,
1147 const int* ind_ptr,
1148 const int* target_dtype,
1149 const int cast_optional_params,
1150 const uint32_t num_target_dtype_op_names,
1151 const uint32_t num_fp32_op_names,
1152 const uint32_t num_widest_dtype_op_names,
1153 const uint32_t num_conditional_fp32_op_names,
1154 const uint32_t num_excluded_symbols,
1155 const uint32_t num_model_params,
1156 const char **target_dtype_op_names,
1157 const char **fp32_op_names,
1158 const char **widest_dtype_op_names,
1159 const char **conditional_fp32_op_names,
1160 const char **excluded_symbols,
1161 const char **param_names,
1162 const char **param_vals,
1163 const char **model_param_names,
1164 const char **arg_names) {
1165 nnvm::Symbol *result_sym = new nnvm::Symbol();
1166 API_BEGIN();
1167 nnvm::Symbol *sym = static_cast<nnvm::Symbol *>(sym_handle);
1168 nnvm::Graph g = Symbol2Graph(*sym);
1169 std::unordered_set<std::string> target_dtype_ops;
1170 std::unordered_set<std::string> fp32_ops;
1171 std::unordered_set<std::string> widest_dtype_ops;
1172 std::unordered_set<std::string> excluded_syms;
1173 std::unordered_set<std::string> model_params;
1174
1175 // conditional_fp32_ops contains the mapping of op_name -> (map of param_name -> param_values)
1176 // which need to be conditionally selected to be casted to FP32
1177 std::unordered_map<std::string,
1178 std::unordered_map<std::string,
1179 std::vector<std::string>>> conditional_fp32_ops;
1180 int target_dt = *target_dtype;
1181
1182 for (size_t i = 0; i < num_target_dtype_op_names; ++i) {
1183 target_dtype_ops.emplace(target_dtype_op_names[i]);
1184 }
1185 for (size_t i = 0; i < num_fp32_op_names; ++i) {
1186 fp32_ops.emplace(fp32_op_names[i]);
1187 }
1188 for (size_t i = 0; i < num_widest_dtype_op_names; ++i) {
1189 widest_dtype_ops.emplace(widest_dtype_op_names[i]);
1190 }
1191 for (size_t i = 0; i < num_excluded_symbols; ++i) {
1192 excluded_syms.emplace(excluded_symbols[i]);
1193 }
1194 for (size_t i = 0; i < num_model_params; ++i) {
1195 model_params.emplace(model_param_names[i]);
1196 }
1197
1198 for (size_t i = 0; i < num_ind_ptr - 1; ++i) {
1199 for (int j = ind_ptr[i]; j < ind_ptr[i + 1]; ++j) {
1200 conditional_fp32_ops[conditional_fp32_op_names[i]][param_names[i]]
1201 .emplace_back(std::string(param_vals[j]));
1202 }
1203 }
1204
1205 std::unordered_map<std::string, int> kwargs;
1206 std::unordered_map<std::string, int> node_name_dtype_map, node_without_dtype_map;
1207 nnvm::DTypeVector arg_types(g.indexed_graph().input_nodes().size(), -1);
1208 for (uint32_t i = 0; i < num_args; ++i) {
1209 kwargs[arg_names[i]] = arg_type_data[i];
1210 node_name_dtype_map[arg_names[i]] = arg_type_data[i];
1211 }
1212 mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_types, "InferType");
1213
1214 g.attrs["target_dtype_ops"] =
1215 std::make_shared<nnvm::any>(std::move(target_dtype_ops));
1216 g.attrs["fp32_ops"] = std::make_shared<nnvm::any>(std::move(fp32_ops));
1217 g.attrs["widest_dtype_ops"] =
1218 std::make_shared<nnvm::any>(std::move(widest_dtype_ops));
1219 g.attrs["conditional_fp32_ops"] =
1220 std::make_shared<nnvm::any>(std::move(conditional_fp32_ops));
1221 g.attrs["excluded_syms"] =
1222 std::make_shared<nnvm::any>(std::move(excluded_syms));
1223 g.attrs["target_dtype"] = std::make_shared<nnvm::any>(target_dt);
1224 g.attrs["data_name_types"] = std::make_shared<nnvm::any>(kwargs);
1225 g.attrs["cast_optional_params"] = std::make_shared<nnvm::any>(cast_optional_params);
1226
1227 g = ApplyPass(std::move(g), "ReducePrecision");
1228 // Need to run type inference since it is possible that inferred
1229 // type of some inputs has changed
1230 g = mxnet::exec::InferType(std::move(g), std::move(arg_types), "");
1231 const nnvm::DTypeVector &inferred_dtypes =
1232 g.GetAttr<nnvm::DTypeVector>("dtype");
1233
1234 g.attrs["inferred_dtypes"] = std::make_shared<dmlc::any>(std::move(inferred_dtypes));
1235 g.attrs["target_dtype"] = std::make_shared<nnvm::any>(target_dt);
1236
1237 if (cast_optional_params) {
1238 g = ApplyPass(std::move(g), "AMPInferUnknown");
1239 const nnvm::DTypeVector &inferred_dtype_result =
1240 g.GetAttr<nnvm::DTypeVector>("inferred_dtype_result");
1241 const nnvm::IndexedGraph &idx = g.indexed_graph();
1242 // set node name -> input dtype mapping using infer dtype
1243 _SetInputDTypes(idx, inferred_dtype_result, &node_name_dtype_map, &node_without_dtype_map);
1244 } else {
1245 const nnvm::IndexedGraph &idx = g.indexed_graph();
1246 // set node name -> input dtype mapping using infer dtype
1247 _SetInputDTypes(idx, inferred_dtypes, &node_name_dtype_map, &node_without_dtype_map);
1248 }
1249
1250
1251 result_sym->outputs = g.outputs;
1252 *ret_sym_handle = result_sym;
1253 nnvm::Symbol *ret_sym = static_cast<nnvm::Symbol *>(*ret_sym_handle);
1254 const std::vector<nnvm::ObjectPtr>& args = ret_sym->ListInputs(nnvm::Symbol::kAll);
1255
1256 // update symbol dtype attrs using the node name -> dtype mapping, if dtype is already set
1257 // in the symbol, else set dtype for the model_params
1258 _UpdateSymDTypeAttrs(node_name_dtype_map, node_without_dtype_map, model_params, args);
1259
1260 API_END_HANDLE_ERROR(delete result_sym);
1261 }
1262
MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle,const uint32_t num_layers,const char ** layer_names,const float * min_ranges,const float * max_ranges,SymbolHandle * ret_qsym_handle)1263 int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle,
1264 const uint32_t num_layers,
1265 const char** layer_names,
1266 const float* min_ranges,
1267 const float* max_ranges,
1268 SymbolHandle* ret_qsym_handle) {
1269 nnvm::Symbol* s = new nnvm::Symbol();
1270 API_BEGIN();
1271 nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(qsym_handle);
1272 nnvm::Graph g = Symbol2Graph(*sym);
1273 std::unordered_map<std::string, std::pair<float, float>> calib_table;
1274 for (size_t i = 0; i < num_layers; ++i) {
1275 calib_table.emplace(layer_names[i], std::make_pair(min_ranges[i], max_ranges[i]));
1276 }
1277 g.attrs["calib_table"] = std::make_shared<nnvm::any>(std::move(calib_table));
1278 g = ApplyPass(std::move(g), "SetCalibTableToQuantizedGraph");
1279 s->outputs = g.outputs;
1280 *ret_qsym_handle = s;
1281 API_END_HANDLE_ERROR(delete s);
1282 }
1283
MXGenBackendSubgraph(SymbolHandle sym_handle,const char * backend_name,SymbolHandle * ret_sym_handle)1284 int MXGenBackendSubgraph(SymbolHandle sym_handle, const char *backend_name,
1285 SymbolHandle *ret_sym_handle) {
1286 nnvm::Symbol *s = new nnvm::Symbol();
1287 API_BEGIN();
1288 nnvm::Symbol *sym = static_cast<nnvm::Symbol *>(sym_handle);
1289 *s = sym->Copy();
1290 auto backend = mxnet::op::SubgraphBackendRegistry::Get()->GetSubgraphBackend(backend_name);
1291 const auto& subgraph_prop_list = backend->GetSubgraphProperties();
1292 for (auto property : subgraph_prop_list) {
1293 if (property->HasAttr("disable") && property->GetAttr<bool>("disable") == true) {
1294 auto full_name = property->HasAttr("property_name")
1295 ? property->GetAttr<std::string>("property_name")
1296 : std::string();
1297 LOG(INFO) << "subgraph property " << full_name << " from backend " << backend_name
1298 << " is disabled.";
1299 continue;
1300 }
1301 nnvm::Graph g = Symbol2Graph(*s);
1302 property->SetAttr("graph", g);
1303 g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(property);
1304 g = ApplyPass(std::move(g), "BuildSubgraph");
1305 property->RemoveAttr("graph");
1306 g.attrs.erase("subgraph_property");
1307 s->outputs = g.outputs;
1308 }
1309 *ret_sym_handle = s;
1310 API_END_HANDLE_ERROR(delete s);
1311 }
1312
MXGenAtomicSymbolFromSymbol(SymbolHandle sym_handle,SymbolHandle * ret_sym_handle)1313 int MXGenAtomicSymbolFromSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle) {
1314 nnvm::Symbol *s = new nnvm::Symbol();
1315 API_BEGIN();
1316 nnvm::Symbol *source = static_cast<nnvm::Symbol *>(sym_handle);
1317 CHECK_GE(source->outputs.size(), 1) << "Input symbol does not have outputs.";
1318 const auto &node = source->outputs[0].node;
1319 for (const auto &other_node : source->outputs) {
1320 if (node.get() != other_node.node.get()) {
1321 LOG(FATAL)
1322 << "Generating atomic symbol from other symbol only works for nongrouped symbol.";
1323 }
1324 }
1325 const auto *op = node->op();
1326 const auto attrs = source->ListAttrs(nnvm::Symbol::ListAttrOption::kShallow);
1327 *s = nnvm::Symbol::CreateFunctor(op, attrs);
1328 *ret_sym_handle = s;
1329 API_END_HANDLE_ERROR(delete s);
1330 }
1331
MXShallowCopySymbol(SymbolHandle src,SymbolHandle * out)1332 int MXShallowCopySymbol(SymbolHandle src, SymbolHandle* out) {
1333 nnvm::Symbol* out_sym = new nnvm::Symbol;
1334 API_BEGIN();
1335 nnvm::Symbol* src_sym = static_cast<nnvm::Symbol*>(src);
1336 *out_sym = *src_sym;
1337 *out = out_sym;
1338 API_END_HANDLE_ERROR(delete out_sym);
1339 }
1340
MXOptimizeForBackend(SymbolHandle sym_handle,const char * backend_name,const int dev_type,SymbolHandle * ret_sym_handle,const mx_uint args_len,NDArrayHandle * in_args_handle,const mx_uint aux_len,NDArrayHandle * in_aux_handle,const mx_uint num_options,const char ** keys,const char ** vals,const uint32_t num_input_shapes,const char ** input_shape_names,const int64_t * input_shape_data,const uint32_t * input_shape_idx,const uint32_t num_input_dtypes,const char ** input_dtype_names,const int * input_dtypes,const uint32_t num_input_stypes,const char ** input_stype_names,const int * input_stypes,bool skip_infer,int * new_args_cnt,NDArrayHandle ** new_args_handle,char *** new_arg_names_handle,int * new_aux_cnt,NDArrayHandle ** new_aux_handle,char *** new_aux_names_handle)1341 int MXOptimizeForBackend(SymbolHandle sym_handle,
1342 const char* backend_name,
1343 const int dev_type,
1344 SymbolHandle* ret_sym_handle,
1345 const mx_uint args_len,
1346 NDArrayHandle* in_args_handle,
1347 const mx_uint aux_len,
1348 NDArrayHandle* in_aux_handle,
1349 const mx_uint num_options,
1350 const char** keys,
1351 const char** vals,
1352 const uint32_t num_input_shapes,
1353 const char** input_shape_names,
1354 const int64_t* input_shape_data,
1355 const uint32_t* input_shape_idx,
1356 const uint32_t num_input_dtypes,
1357 const char** input_dtype_names,
1358 const int* input_dtypes,
1359 const uint32_t num_input_stypes,
1360 const char** input_stype_names,
1361 const int* input_stypes,
1362 bool skip_infer,
1363 int* new_args_cnt,
1364 NDArrayHandle** new_args_handle,
1365 char*** new_arg_names_handle,
1366 int* new_aux_cnt,
1367 NDArrayHandle** new_aux_handle,
1368 char*** new_aux_names_handle) {
1369 // create copy of input symbol
1370 nnvm::Symbol *s = new nnvm::Symbol();
1371 API_BEGIN();
1372 nnvm::Symbol *sym = static_cast<nnvm::Symbol *>(sym_handle);
1373 *s = sym->Copy();
1374
1375 // create a data structure from pointer array
1376 std::unordered_map<std::string, std::string> options_map;
1377 for (mx_uint i = 0; i < num_options; ++i)
1378 options_map.emplace(keys[i], vals[i]);
1379
1380 NDArray ***new_args_ptr = reinterpret_cast<NDArray***>(new_args_handle);
1381 NDArray ***new_aux_ptr = reinterpret_cast<NDArray***>(new_aux_handle);
1382 NDArray **in_args_ptr = reinterpret_cast<NDArray**>(in_args_handle);
1383 NDArray **in_aux_ptr = reinterpret_cast<NDArray**>(in_aux_handle);
1384
1385 auto init_graph = [&](nnvm::Symbol* s) {
1386 nnvm::Graph g = Symbol2Graph(*s);
1387 const auto& indexed_graph = g.indexed_graph();
1388 const auto& mutable_nodes = indexed_graph.mutable_input_nodes();
1389 std::vector<std::string> input_names = s->ListInputNames(nnvm::Symbol::kAll);
1390 size_t num_forward_inputs = input_names.size();
1391
1392 if (args_len || aux_len) {
1393 if (!skip_infer) {
1394 Context default_ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
1395 mxnet::ShapeVector arg_shapes(args_len + aux_len);
1396 nnvm::DTypeVector arg_dtypes(args_len + aux_len);
1397 StorageTypeVector arg_stypes(args_len + aux_len);
1398
1399 // create the input shape, dtype and stype maps
1400 std::unordered_map<std::string, mxnet::TShape> input_shape_map(num_input_shapes);
1401 for (uint32_t i = 0; i < num_input_shapes; ++i) {
1402 input_shape_map.emplace(input_shape_names[i],
1403 mxnet::TShape(input_shape_data + input_shape_idx[i],
1404 input_shape_data + input_shape_idx[i+1]));
1405 }
1406 std::unordered_map<std::string, int> input_dtype_map(num_input_dtypes);
1407 for (uint32_t i = 0; i < num_input_dtypes; ++i) {
1408 input_dtype_map.emplace(input_dtype_names[i], input_dtypes[i]);
1409 }
1410 std::unordered_map<std::string, int> input_stype_map(num_input_stypes);
1411 for (uint32_t i = 0; i < num_input_stypes; ++i) {
1412 input_stype_map.emplace(input_stype_names[i], input_stypes[i]);
1413 }
1414
1415 size_t args_top = 0, aux_top = 0;
1416 // loop over inputs to symbol in order and add to args/aux if mutable
1417 for (size_t i = 0; i < num_forward_inputs; ++i) {
1418 const uint32_t nid = indexed_graph.input_nodes().at(i);
1419 if (mutable_nodes.count(nid)) {
1420 CHECK_LT(aux_top, aux_len)
1421 << "Cannot find aux '" << input_names[i] << "' in provided aux to optimize_for";
1422 if (in_aux_ptr[aux_top] != nullptr) {
1423 const auto &in_arg = *(in_aux_ptr[aux_top]);
1424 arg_shapes[i] = in_arg.shape();
1425 arg_dtypes[i] = in_arg.dtype();
1426 arg_stypes[i] = in_arg.storage_type();
1427 }
1428 aux_top++;
1429 } else {
1430 auto name = input_names[i];
1431 CHECK_LT(args_top, args_len)
1432 << "Cannot find arg '" << name << "' in provided args to optimize_for";
1433 if (in_args_ptr[args_top] != nullptr) {
1434 const auto &in_arg = *(in_args_ptr[args_top]);
1435 arg_shapes[i] = in_arg.shape();
1436 arg_dtypes[i] = in_arg.dtype();
1437 arg_stypes[i] = in_arg.storage_type();
1438 } else {
1439 // input_names[i] is not in args but can be in the optional
1440 // shape/type/stype attribute dicts.
1441 auto it_shape = input_shape_map.find(name);
1442 if (it_shape != input_shape_map.end()) {
1443 arg_shapes[i] = it_shape->second;
1444 }
1445 auto it_type = input_dtype_map.find(name);
1446 if (it_type != input_dtype_map.end()) {
1447 arg_dtypes[i] = it_type->second;
1448 }
1449 it_type = input_stype_map.find(name);
1450 if (it_type != input_stype_map.end()) {
1451 arg_stypes[i] = it_type->second;
1452 }
1453 }
1454 args_top++;
1455 }
1456 }
1457
1458 g.attrs["context"] = std::make_shared<nnvm::any>(
1459 exec::ContextVector(indexed_graph.num_nodes(), default_ctx));
1460
1461 // infer shapes
1462 g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
1463 // infer dtypes
1464 g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
1465 // infer stypes
1466 g = exec::InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
1467 }
1468 // set args/aux as attributes on graph so that subgraph property can use them
1469 std::vector<std::string> arg_names = s->ListInputNames(nnvm::Symbol::kReadOnlyArgs);
1470 g.attrs["in_args"] = std::make_shared<nnvm::any>(in_args_ptr);
1471 g.attrs["in_arg_names"] = std::make_shared<nnvm::any>(arg_names);
1472
1473 std::vector<std::string> aux_names = s->ListInputNames(nnvm::Symbol::kAuxiliaryStates);
1474 g.attrs["in_aux"] = std::make_shared<nnvm::any>(in_aux_ptr);
1475 g.attrs["in_aux_names"] = std::make_shared<nnvm::any>(aux_names);
1476 } else {
1477 // args/aux were not specified, so set nullptr/empty-lists
1478 NDArray **in_args_ptr = static_cast<NDArray**>(nullptr);
1479 std::vector<std::string> arg_names;
1480 g.attrs["in_args"] = std::make_shared<nnvm::any>(in_args_ptr);
1481 g.attrs["in_arg_names"] = std::make_shared<nnvm::any>(arg_names);
1482
1483 NDArray **in_aux_ptr = static_cast<NDArray**>(nullptr);
1484 std::vector<std::string> aux_names;
1485 g.attrs["in_aux"] = std::make_shared<nnvm::any>(in_aux_ptr);
1486 g.attrs["in_aux_names"] = std::make_shared<nnvm::any>(aux_names);
1487 }
1488
1489 // set dedup option as attribute on graph to enable dedup during partitioning
1490 if (options_map.count("dedup_subgraph") > 0 &&
1491 options_map.at("dedup_subgraph").compare("True") == 0)
1492 g.attrs["dedup_subgraph"] = std::make_shared<nnvm::any>(std::string("True"));
1493 return g;
1494 };
1495
1496 if (mxnet::op::SubgraphBackendRegistry::Get()->backend_map_.count(backend_name) > 0) {
1497 // use subgraph backend
1498 const auto backend = mxnet::op::SubgraphBackendRegistry
1499 ::Get()->GetSubgraphBackend(backend_name);
1500 const auto& subgraph_prop_list = backend->GetSubgraphProperties();
1501 for (auto property : subgraph_prop_list) {
1502 nnvm::Graph g = init_graph(s);
1503 property->PrePartition(g, options_map);
1504 g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(property);
1505 g = ApplyPass(std::move(g), "BuildSubgraph");
1506 g.attrs.erase("subgraph_property");
1507 property->PostPartition(g);
1508 s->outputs = g.outputs;
1509 }
1510 } else if (dmlc::Registry<nnvm::PassFunctionReg>::Find(backend_name) != nullptr) {
1511 // use graph pass
1512 nnvm::Graph g = init_graph(s);
1513 g.attrs["options_map"] = std::make_shared<nnvm::any>(options_map);
1514 g.attrs["pass_name"] = std::make_shared<nnvm::any>(backend_name);
1515 g = ApplyPass(std::move(g), backend_name);
1516
1517 std::vector<NDArray*> new_args = g.GetAttr<std::vector<NDArray*>>("new_args");
1518 std::vector<NDArray*> new_aux = g.GetAttr<std::vector<NDArray*>>("new_aux");
1519 std::vector<std::string> new_arg_names = g.GetAttr<std::vector<std::string>>("new_arg_names");
1520 std::vector<std::string> new_aux_names = g.GetAttr<std::vector<std::string>>("new_aux_names");
1521 g.attrs.erase("new_args");
1522 g.attrs.erase("new_aux");
1523 g.attrs.erase("new_arg_names");
1524 g.attrs.erase("new_aux_names");
1525 s->outputs = g.outputs;
1526
1527 NDArray** new_arg_arr = new NDArray*[new_arg_names.size()];
1528 NDArray** new_aux_arr = new NDArray*[new_aux_names.size()];
1529 char** new_arg_cstr = new char*[new_arg_names.size()];
1530 char** new_aux_cstr = new char*[new_aux_names.size()];
1531 for (unsigned i = 0; i < new_arg_names.size(); i++) {
1532 new_arg_arr[i] = new_args[i];
1533 std::string& s = new_arg_names[i];
1534 char* tmp = new char[s.length()+1];
1535 s.copy(tmp, s.length());
1536 tmp[s.length()] = '\0';
1537 new_arg_cstr[i] = tmp;
1538 }
1539 for (unsigned i = 0; i < new_aux_names.size(); i++) {
1540 new_aux_arr[i] = new_aux[i];
1541 std::string& s = new_aux_names[i];
1542 char* tmp = new char[s.length()+1];
1543 s.copy(tmp, s.length());
1544 tmp[s.length()] = '\0';
1545 new_aux_cstr[i] = tmp;
1546 }
1547 *new_args_cnt = new_arg_names.size();
1548 *new_aux_cnt = new_aux_names.size();
1549 *new_arg_names_handle = new_arg_cstr;
1550 *new_aux_names_handle = new_aux_cstr;
1551 *new_args_ptr = new_arg_arr;
1552 *new_aux_ptr = new_aux_arr;
1553 } else {
1554 // cannot find graph pass or subgraph backend registered in this name
1555 LOG(ERROR) << "Error optimizing for backend '" << backend_name << "' cannot be found";
1556 }
1557
1558 *ret_sym_handle = s;
1559 API_END_HANDLE_ERROR(delete s);
1560 }
1561