/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file c_api_symbolic.cc * \brief C API of mxnet */ #include "mxnet/base.h" #include "mxnet/c_api.h" #include "mxnet/imperative.h" #include "nnvm/c_api.h" #include "nnvm/pass.h" #include "nnvm/pass_functions.h" #include "nnvm/symbolic.h" #include "./c_api_common.h" #include "../common/exec_utils.h" #include "../operator/operator_common.h" #include "../executor/exec_pass.h" #include "../operator/subgraph/subgraph_property.h" namespace mxnet { namespace op { void RegisterLegacyOpProp(); void RegisterLegacyNDFunc(); } const std::vector kHiddenKeys = { "ctx_group", "lr_mult", "wd_mult", "force_mirroring", "mirror_stage" }; const std::vector kReplacedHiddenKeys = { "__ctx_group__", "__lr_mult__", "__wd_mult__", "__force_mirroring__", "__mirror_stage__" }; const char *kNamespaceSeparator = "$"; DMLC_JSON_ENABLE_ANY(int, int); // convert nnvm symbol to a nnvm graph. nnvm::Graph Symbol2Graph(const nnvm::Symbol &s) { nnvm::Graph g; g.outputs = s.outputs; g.attrs["mxnet_version"] = std::make_shared(static_cast(MXNET_VERSION)); if (Imperative::Get()->is_np_shape()) { g.attrs["is_np_shape"] = std::make_shared( static_cast(Imperative::Get()->is_np_shape())); } return g; } std::vector ReadOnlyArgIndices(const nnvm::IndexedGraph& idx) { std::vector ret; auto& arg_nodes = idx.input_nodes(); for (uint32_t i = 0; i < arg_nodes.size(); ++i) { if (idx.mutable_input_nodes().count(arg_nodes[i]) == 0) { ret.push_back(i); } } return ret; } } // namespace mxnet // symbolic configuration generation API. // Redirect to NNVM's C API int MXListAllOpNames(nn_uint *out_size, const char ***out_array) { mxnet::op::RegisterLegacyOpProp(); mxnet::op::RegisterLegacyNDFunc(); return NNListAllOpNames(out_size, out_array); } int MXSymbolListAtomicSymbolCreators(uint32_t *out_size, AtomicSymbolCreator **out_array) { mxnet::op::RegisterLegacyOpProp(); mxnet::op::RegisterLegacyNDFunc(); return NNListUniqueOps(out_size, out_array); } int MXSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator, const char **name, const char **description, uint32_t *num_args, const char ***arg_names, const char ***arg_type_infos, const char ***arg_descriptions, const char **key_var_num_args, const char **return_type) { static auto& map_key_var_args = nnvm::Op::GetAttr("key_var_num_args"); const Op* op = static_cast(creator); MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); ret->ret_str.resize(0); if (map_key_var_args.count(op) != 0) { *key_var_num_args = map_key_var_args[op].c_str(); } else { *key_var_num_args = ret->ret_str.c_str(); } return NNGetOpInfo( creator, name, description, num_args, arg_names, arg_type_infos, arg_descriptions, return_type); } int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator, uint32_t num_param, const char **keys, const char **vals, SymbolHandle *out) { nnvm::Symbol *s = new nnvm::Symbol(); API_BEGIN(); const nnvm::Op* op = static_cast(creator); std::unordered_map kwargs; for (nn_uint i = 0; i < num_param; ++i) { bool flag = false; for (const auto &k : kHiddenKeys) { std::string tmp(keys[i]); size_t pos = tmp.rfind(k); if (pos == 0) { kwargs.insert({"__" + tmp + "__", std::string(vals[i])}); flag = true; break; } else if (pos != std::string::npos && pos == tmp.length() - k.length()) { std::ostringstream os; os << "setting variable attributes with " << keys[i] << " is deprecated. " << "please instead use\nw = Variable(" << k << "=" << vals[i] << ")\n" << "sym = YourSymbolName(" << tmp.substr(0, pos-1) << "=w)"; throw dmlc::Error(os.str()); } } if (!flag) kwargs.insert({std::string(keys[i]), std::string(vals[i])}); } *s = nnvm::Symbol::CreateFunctor(op, std::move(kwargs)); *out = s; API_END_HANDLE_ERROR(delete s;); } int MXSymbolCreateVariable(const char *name, SymbolHandle *out) { return NNSymbolCreateVariable(name, out); } int MXSymbolCreateGroup(uint32_t num_symbols, SymbolHandle *symbols, SymbolHandle *out) { return NNSymbolCreateGroup(num_symbols, symbols, out); } int MXSymbolGetOutput(SymbolHandle symbol, uint32_t index, SymbolHandle *out) { return NNSymbolGetOutput(symbol, index, out); } int MXSymbolGetInternals(SymbolHandle symbol, SymbolHandle *out) { nnvm::Symbol *s = new nnvm::Symbol(); API_BEGIN(); *s = static_cast(symbol)->GetInternals(); *out = s; API_END_HANDLE_ERROR(delete s); } int MXSymbolGetChildren(SymbolHandle symbol, SymbolHandle *out) { nnvm::Symbol *s = new nnvm::Symbol(); API_BEGIN(); *s = static_cast(symbol)->GetChildren(); *out = s; API_END_HANDLE_ERROR(delete s); } int MXSymbolFree(SymbolHandle symbol) { return NNSymbolFree(symbol); } int MXSymbolCopy(SymbolHandle symbol, SymbolHandle *out) { return NNSymbolCopy(symbol, out); } int MXSymbolPrint(SymbolHandle symbol, const char **out_str) { return NNSymbolPrint(symbol, out_str); } int MXSymbolGetName(SymbolHandle symbol, const char** out, int* success) { return NNSymbolGetAttr(symbol, "name", out, success); } int MXSymbolGetAttr(SymbolHandle symbol, const char* key, const char** out, int* success) { nnvm::Symbol *s = static_cast(symbol); MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); if (s->GetAttr(key, &(ret->ret_str))) { *out = (ret->ret_str).c_str(); *success = 1; } else { *out = nullptr; *success = 0; if (std::find(kHiddenKeys.begin(), kHiddenKeys.end(), key) != kHiddenKeys.end()) { std::string skey = "__" + std::string(key) + "__"; if (s->GetAttr(skey, &(ret->ret_str))) { *out = (ret->ret_str).c_str(); *success = 1; } } } API_END(); } int MXSymbolSetAttr(SymbolHandle symbol, const char* key, const char* value) { nnvm::Symbol *s = static_cast(symbol); API_BEGIN(); std::vector > kwargs; std::string skey(key), sval(value); for (const auto &k : kHiddenKeys) { size_t pos = skey.rfind(k); if (pos == 0 && k.length() == skey.length()) { skey = "__" + skey + "__"; break; } else if (pos != std::string::npos && pos + k.length() == skey.length()) { std::ostringstream os; os << "setting variable attributes with " << key << " is deprecated. " << "please instead use\nw = Variable(" << k << "=" << value << ")\n" << "sym = YourSymbolName(" << skey.substr(0, pos-1) << "=w)"; throw dmlc::Error(os.str()); } } kwargs.emplace_back(std::make_pair(std::move(skey), std::move(sval))); s->SetAttrs(kwargs); API_END(); } int MXSymbolListAttr(SymbolHandle symbol, uint32_t *out_size, const char*** out) { nnvm::Symbol *s = static_cast(symbol); MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); std::vector > attr = s->ListAttrsRecursive(); std::vector& attr_list = ret->ret_vec_str; attr_list.clear(); for (const auto& tp : attr) { attr_list.emplace_back(std::get<0>(tp) + kNamespaceSeparator + std::get<1>(tp)); attr_list.emplace_back(std::get<2>(tp)); if (find(kReplacedHiddenKeys.begin(), kReplacedHiddenKeys.end(), std::get<1>(tp)) != kReplacedHiddenKeys.end()) { attr_list.push_back(std::get<0>(tp) + kNamespaceSeparator + std::get<1>(tp).substr(2, std::get<1>(tp).length() - 4)); attr_list.push_back(std::get<2>(tp)); } } *out_size = attr_list.size()/2; ret->ret_vec_charp.clear(); for (const auto& attr : attr_list) { ret->ret_vec_charp.push_back(attr.c_str()); } *out = dmlc::BeginPtr(ret->ret_vec_charp); API_END(); } int MXSymbolListAttrShallow(SymbolHandle symbol, uint32_t *out_size, const char*** out) { nnvm::Symbol *s = static_cast(symbol); MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); std::unordered_map attr = s->ListAttrs(static_cast(1)); // NOLINT(*) std::vector& attr_list = ret->ret_vec_str; attr_list.clear(); for (const auto& kv : attr) { attr_list.push_back(kv.first); attr_list.push_back(kv.second); if (find(kReplacedHiddenKeys.begin(), kReplacedHiddenKeys.end(), kv.first) != kReplacedHiddenKeys.end()) { attr_list.push_back(kv.first.substr(2, kv.first.length() - 4)); attr_list.push_back(kv.second); } } *out_size = attr_list.size()/2; ret->ret_vec_charp.clear(); for (auto &attr : attr_list) { ret->ret_vec_charp.push_back(attr.c_str()); } *out = dmlc::BeginPtr(ret->ret_vec_charp); API_END(); } int MXSymbolListOutputs(SymbolHandle symbol, uint32_t *out_size, const char ***out_str_array) { return NNSymbolListOutputNames(symbol, out_size, out_str_array); } int MXSymbolGetNumOutputs(SymbolHandle symbol, uint32_t *output_count) { return NNSymbolGetNumOutputs(symbol, output_count); } int MXSymbolCompose(SymbolHandle sym, const char *name, uint32_t num_args, const char** keys, SymbolHandle* args) { return NNSymbolCompose(sym, name, num_args, keys, args); } // adapter functions that re-implements the functions. int MXSymbolListArguments(SymbolHandle symbol, uint32_t *out_size, const char ***out_str_array) { return NNSymbolListInputNames(symbol, 1, out_size, out_str_array); } int MXSymbolListAuxiliaryStates(SymbolHandle symbol, uint32_t *out_size, const char ***out_str_array) { return NNSymbolListInputNames(symbol, 2, out_size, out_str_array); } int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator, const char **out) { API_BEGIN(); Op *e = static_cast(creator); *out = e->name.c_str(); API_END(); } namespace mxnet { extern std::vector GetInputSymbols(const nnvm::Symbol &sym); extern bool CutGraphInputs(const std::vector &input_entries, bool skip_var, std::vector *orig_entries); } int MXSymbolGetInputSymbols(SymbolHandle sym, SymbolHandle **input_arr, int *input_size) { API_BEGIN(); nnvm::Symbol *s = static_cast(sym); std::vector input_syms = mxnet::GetInputSymbols(*s); *input_size = input_syms.size(); MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); ret->ret_handles.clear(); ret->ret_handles.reserve(*input_size); for (int i = 0; i < *input_size; ++i) ret->ret_handles.push_back(input_syms[i]); *input_arr = reinterpret_cast(dmlc::BeginPtr(ret->ret_handles)); API_END_HANDLE_ERROR(); } int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle **input_symbols, int *input_size) { // Given a graph, we want to fetch the nodes that have been marked as part of // a subgraph. API_BEGIN(); nnvm::Symbol *s = static_cast(sym); const std::string subg_attr = "__subgraph_name__"; auto out_node = s->outputs[0].node; auto it = out_node->attrs.dict.find(subg_attr); if (it != out_node->attrs.dict.end()) { const std::string &subg_name = it->second; std::vector input_entries; DFSVisit(s->outputs, [&subg_attr, &subg_name, &input_entries] (nnvm::ObjectPtr n) { // If the node itself isn't in the subgraph, we ignore it. auto it = n->attrs.dict.find(subg_attr); if (it == n->attrs.dict.end() || it->second != subg_name) return; // We search for nodes whose node entries aren't in the subgraph. for (size_t j = 0; j < n->inputs.size(); j++) { auto in_node = n->inputs[j].node; auto it = in_node->attrs.dict.find(subg_attr); if (it == in_node->attrs.dict.end() || it->second != subg_name) input_entries.push_back(&n->inputs[j]); } }); std::vector orig_entries; CutGraphInputs(input_entries, false, &orig_entries); std::vector input_syms(orig_entries.size()); for (size_t i = 0; i < input_syms.size(); i++) { input_syms[i] = new nnvm::Symbol(); input_syms[i]->outputs.push_back(orig_entries[i]); } *input_size = input_syms.size(); MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); ret->ret_handles.clear(); ret->ret_handles.reserve(*input_size); for (int i = 0; i < *input_size; ++i) ret->ret_handles.push_back(input_syms[i]); *input_symbols = reinterpret_cast(dmlc::BeginPtr(ret->ret_handles)); } else { *input_size = 0; } API_END_HANDLE_ERROR(); } /*! * \brief Convert shape attr in graph nodes to comply with NumPy semantics for * legacy models (before 1.6.0) if global flag is_np_shape has been turned on, * i.e., use -1 to indicate unknown number of dimensions and unknown dimension sizes. */ void ConvertShapeAttrToNumPyCompatible(nnvm::Graph* g) { if (Imperative::Get()->is_np_shape() && (!g->HasAttr("is_np_shape") || !g->GetAttr("is_np_shape"))) { DFSVisit(g->outputs, [](nnvm::ObjectPtr n) { if (n->is_variable()) { auto it = n->attrs.dict.find("__shape__"); if (it != n->attrs.dict.end()) { mxnet::TShape shape; std::istringstream is(it->second); is >> shape; common::ConvertToNumpyShape(&shape); std::ostringstream os; os << shape; it->second = os.str(); } } }); } } int MXSymbolCreateFromFile(const char *fname, SymbolHandle *out) { nnvm::Symbol *s = new nnvm::Symbol(); API_BEGIN(); std::unique_ptr fi(dmlc::Stream::Create(fname, "r")); dmlc::istream is(fi.get()); nnvm::Graph g; g.attrs["json"] = std::make_shared( std::string(std::istreambuf_iterator(is), std::istreambuf_iterator())); g = nnvm::ApplyPass(g, "LoadLegacyJSON"); ConvertShapeAttrToNumPyCompatible(&g); s->outputs = g.outputs; *out = s; is.set_stream(nullptr); API_END_HANDLE_ERROR(delete s); } int MXSymbolCreateFromJSON(const char *json, SymbolHandle *out) { nnvm::Symbol *s = new nnvm::Symbol(); API_BEGIN(); nnvm::Graph g; g.attrs["json"] = std::make_shared(std::string(json)); g = nnvm::ApplyPass(g, "LoadLegacyJSON"); ConvertShapeAttrToNumPyCompatible(&g); s->outputs = g.outputs; *out = s; API_END_HANDLE_ERROR(delete s); } int MXSymbolRemoveAmpCast(SymbolHandle sym_handle, SymbolHandle* ret_sym_handle) { nnvm::Symbol* s = new nnvm::Symbol(); API_BEGIN(); nnvm::Symbol *source = static_cast(sym_handle); *s = source->Copy(); s->outputs = nnvm::ApplyPass(Symbol2Graph(*s), "RemoveAmpCast").outputs; *ret_sym_handle = s; API_END_HANDLE_ERROR(delete s); } int MXSymbolSaveToFile(SymbolHandle symbol, const char *fname) { nnvm::Symbol *s = static_cast(symbol); API_BEGIN(); std::unique_ptr fo(dmlc::Stream::Create(fname, "w")); dmlc::ostream os(fo.get()); os << nnvm::pass::SaveJSON(Symbol2Graph(*s)); // reset file pointer, force flush os.set_stream(nullptr); API_END(); } int MXSymbolSaveToJSON(SymbolHandle symbol, const char **out_json) { nnvm::Symbol *s = static_cast(symbol); MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); ret->ret_str = nnvm::pass::SaveJSON(Symbol2Graph(*s)); *out_json = ret->ret_str.c_str(); API_END(); } namespace mxnet { template void MatchArguments( const nnvm::IndexedGraph& idx, const std::unordered_map& known_arg_attrs, std::vector* arg_attrs, const char* source) { auto& arg_nodes = idx.input_nodes(); CHECK_EQ(arg_attrs->size(), arg_nodes.size()); size_t nmatched = 0; for (size_t i = 0; i < arg_nodes.size(); ++i) { const std::string& name = idx[arg_nodes[i]].source->attrs.name; auto it = known_arg_attrs.find(name); if (it != known_arg_attrs.end()) { arg_attrs->at(i) = it->second; ++nmatched; } } if (nmatched != known_arg_attrs.size()) { std::unordered_set keys; std::ostringstream head, msg; msg << "\nCandidate arguments:\n"; for (size_t i = 0; i < arg_nodes.size(); ++i) { std::string arg_name = idx[arg_nodes[i]].source->attrs.name; keys.insert(arg_name); msg << "\t[" << i << ']' << arg_name << '\n'; } for (const auto& kv : known_arg_attrs) { const std::string& key = kv.first; if (keys.count(key) == 0) { LOG(FATAL) << source << "Keyword argument name " << key << " not found." << msg.str(); } } } } } // namespace mxnet int MXSymbolInferShape(SymbolHandle sym, uint32_t num_args, const char** keys, const uint32_t *arg_ind_ptr, const uint32_t *arg_shape_data, uint32_t *in_shape_size, const uint32_t **in_shape_ndim, const uint32_t ***in_shape_data, uint32_t *out_shape_size, const uint32_t **out_shape_ndim, const uint32_t ***out_shape_data, uint32_t *aux_shape_size, const uint32_t **aux_shape_ndim, const uint32_t ***aux_shape_data, int *complete) { nnvm::Symbol *s = static_cast(sym); MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); nnvm::Graph g = Symbol2Graph(*s); mxnet::ShapeVector arg_shapes(g.indexed_graph().input_nodes().size(), mxnet::TShape()); if (keys == nullptr && num_args != 0) { std::vector read_only_args = mxnet::ReadOnlyArgIndices(g.indexed_graph()); CHECK_LE(num_args, read_only_args.size()); for (uint32_t i = 0; i < num_args; ++i) { arg_shapes[read_only_args[i]] = mxnet::ShapeTypeCast( arg_shape_data + arg_ind_ptr[i], arg_shape_data + arg_ind_ptr[i+1]); } } else { std::unordered_map kwargs; for (uint32_t i = 0; i < num_args; ++i) { kwargs[keys[i]] = mxnet::ShapeTypeCast( arg_shape_data + arg_ind_ptr[i], arg_shape_data + arg_ind_ptr[i+1]); } mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_shapes, "InferShape"); } try { g = mxnet::exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__"); } catch (const mxnet::op::InferShapeError &err) { throw dmlc::Error(err.msg); } // if use legacy shape definition, need to convert numpy shape to legacy shape mxnet::ShapeVector shapes = g.GetAttr("shape"); if (!Imperative::Get()->is_np_shape()) { common::ConvertToLegacyShape(&shapes); } // copy back CopyAttr(g.indexed_graph(), shapes, &(ret->arg_shapes), &(ret->out_shapes), &(ret->aux_shapes)); // copy data back MXAPIThreadLocalEntry<>::SetupShapeArrayReturnWithBuffer(ret->arg_shapes, &(ret->arg_shape_ndim), &(ret->arg_shape_data), &(ret->arg_shape_buffer)); MXAPIThreadLocalEntry<>::SetupShapeArrayReturnWithBuffer(ret->out_shapes, &(ret->out_shape_ndim), &(ret->out_shape_data), &(ret->out_shape_buffer)); MXAPIThreadLocalEntry<>::SetupShapeArrayReturnWithBuffer(ret->aux_shapes, &(ret->aux_shape_ndim), &(ret->aux_shape_data), &(ret->aux_shape_buffer)); *in_shape_size = static_cast(ret->arg_shapes.size()); *in_shape_ndim = dmlc::BeginPtr(ret->arg_shape_ndim); *in_shape_data = dmlc::BeginPtr(ret->arg_shape_data); *out_shape_size = static_cast(ret->out_shapes.size()); *out_shape_ndim = dmlc::BeginPtr(ret->out_shape_ndim); *out_shape_data = dmlc::BeginPtr(ret->out_shape_data); *aux_shape_size = static_cast(ret->aux_shapes.size()); *aux_shape_ndim = dmlc::BeginPtr(ret->aux_shape_ndim); *aux_shape_data = dmlc::BeginPtr(ret->aux_shape_data); // mark complete *complete = (g.GetAttr("shape_num_unknown_nodes") == 0); API_END(); } template inline void SymbolInferShape(const char** keys, uint32_t num_args, const dtype* arg_shape_data, const itype* arg_ind_ptr, const int** in_shape_ndim, const dtype*** in_shape_data, const int** out_shape_ndim, const dtype*** out_shape_data, const int** aux_shape_ndim, const dtype*** aux_shape_data, nnvm::Symbol* s, MXAPIThreadLocalEntry* ret, stype* in_shape_size, stype* out_shape_size, stype* aux_shape_size, int* complete) { nnvm::Graph g = Symbol2Graph(*s); mxnet::ShapeVector arg_shapes(g.indexed_graph().input_nodes().size(), mxnet::TShape()); if (keys == nullptr && num_args != 0) { std::vector < uint32_t > read_only_args = mxnet::ReadOnlyArgIndices(g.indexed_graph()); CHECK_LE(num_args, read_only_args.size()); for (uint32_t i = 0; i < num_args; ++i) { arg_shapes[read_only_args[i]] = mxnet::ShapeTypeCast(arg_shape_data + arg_ind_ptr[i], arg_shape_data + arg_ind_ptr[i + 1]); } } else { std::unordered_map kwargs; for (uint32_t i = 0; i < num_args; ++i) { kwargs[keys[i]] = mxnet::ShapeTypeCast(arg_shape_data + arg_ind_ptr[i], arg_shape_data + arg_ind_ptr[i + 1]); } mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_shapes, "InferShape"); } try { g = mxnet::exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__"); } catch (const mxnet::op::InferShapeError& err) { throw dmlc::Error(err.msg); } // if use legacy shape definition, need to convert numpy shape to legacy shape mxnet::ShapeVector shapes = g.GetAttr("shape"); if (!Imperative::Get()->is_np_shape()) { common::ConvertToLegacyShape(&shapes); } // copy back CopyAttr(g.indexed_graph(), shapes, &(ret->arg_shapes), &(ret->out_shapes), &(ret->aux_shapes)); // copy data back MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBufferEx(ret->arg_shapes, &(ret->arg_shape_ndim_ex), &(ret->arg_shape_data_ex), &(ret->arg_shape_buffer_ex)); MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBufferEx(ret->out_shapes, &(ret->out_shape_ndim_ex), &(ret->out_shape_data_ex), &(ret->out_shape_buffer_ex)); MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBufferEx(ret->aux_shapes, &(ret->aux_shape_ndim_ex), &(ret->aux_shape_data_ex), &(ret->aux_shape_buffer_ex)); *in_shape_size = static_cast(ret->arg_shapes.size()); *in_shape_ndim = dmlc::BeginPtr(ret->arg_shape_ndim_ex); *in_shape_data = dmlc::BeginPtr(ret->arg_shape_data_ex); *out_shape_size = static_cast(ret->out_shapes.size()); *out_shape_ndim = dmlc::BeginPtr(ret->out_shape_ndim_ex); *out_shape_data = dmlc::BeginPtr(ret->out_shape_data_ex); *aux_shape_size = static_cast(ret->aux_shapes.size()); *aux_shape_ndim = dmlc::BeginPtr(ret->aux_shape_ndim_ex); *aux_shape_data = dmlc::BeginPtr(ret->aux_shape_data_ex); // mark complete *complete = (g.GetAttr("shape_num_unknown_nodes") == 0); } /*! * \brief Executor for Symbol Shape Inference * This api is available when MXNet is built with flag * USE_INT64_TENSOR_SIZE=0 (by default) * \param sym symbol handle * \param num_args number of args * \param keys keys * \param arg_ind_ptr arg index pointer * \param arg_shape_data arg shape data * \param in_shape_size input shape size * \param in_shape_ndim input shape number of dims * \param in_shape_data input shape data * \param out_shape_size ouput shape size * \param out_shape_ndim output shape number of dims * \param out_shape_data output shape data * \param aux_shape_size shape size of auxiliary states * \param aux_shape_ndim number of dims of auxiliary states shape * \param aux_shape_data shape data of auxiliary states * \param complete indicates completion of Shape Inference * \return 0 when success, -1 when failure happens */ int MXSymbolInferShapeEx(SymbolHandle sym, uint32_t num_args, const char** keys, const uint32_t *arg_ind_ptr, const int *arg_shape_data, uint32_t *in_shape_size, const int **in_shape_ndim, const int ***in_shape_data, uint32_t *out_shape_size, const int **out_shape_ndim, const int ***out_shape_data, uint32_t *aux_shape_size, const int **aux_shape_ndim, const int ***aux_shape_data, int *complete) { nnvm::Symbol *s = static_cast(sym); MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); SymbolInferShape(keys, num_args, arg_shape_data, arg_ind_ptr, in_shape_ndim, in_shape_data, out_shape_ndim, out_shape_data, aux_shape_ndim, aux_shape_data, s, ret, in_shape_size, out_shape_size, aux_shape_size, complete); API_END(); } /*! * \brief Executor for Symbol Shape Inference * This api is available when MXNet is built with flag * USE_INT64_TENSOR_SIZE=1 (not default) i.e. Large Tensor Support * \param sym symbol handle * \param num_args number of args * \param keys keys * \param arg_ind_ptr arg index pointer * \param arg_shape_data arg shape data * \param in_shape_size input shape size * \param in_shape_ndim input shape number of dims * \param in_shape_data input shape data * \param out_shape_size ouput shape size * \param out_shape_ndim output shape number of dims * \param out_shape_data output shape data * \param aux_shape_size shape size of auxiliary states * \param aux_shape_ndim number of dims of auxiliary states shape * \param aux_shape_data shape data of auxiliary states * \param complete indicates completion of Shape Inference * \return 0 when success, -1 when failure happens */ int MXSymbolInferShapeEx64(SymbolHandle sym, uint32_t num_args, const char** keys, const int64_t *arg_ind_ptr, const int64_t *arg_shape_data, size_t *in_shape_size, const int **in_shape_ndim, const int64_t ***in_shape_data, size_t *out_shape_size, const int **out_shape_ndim, const int64_t ***out_shape_data, size_t *aux_shape_size, const int **aux_shape_ndim, const int64_t ***aux_shape_data, int *complete) { nnvm::Symbol *s = static_cast(sym); MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); API_BEGIN(); SymbolInferShape(keys, num_args, arg_shape_data, arg_ind_ptr, in_shape_ndim, in_shape_data, out_shape_ndim, out_shape_data, aux_shape_ndim, aux_shape_data, s, ret, in_shape_size, out_shape_size, aux_shape_size, complete); API_END(); } int MXSymbolInferShapePartial(SymbolHandle sym, uint32_t num_args, const char** keys, const uint32_t *arg_ind_ptr, const uint32_t *arg_shape_data, uint32_t *in_shape_size, const uint32_t **in_shape_ndim, const uint32_t ***in_shape_data, uint32_t *out_shape_size, const uint32_t **out_shape_ndim, const uint32_t ***out_shape_data, uint32_t *aux_shape_size, const uint32_t **aux_shape_ndim, const uint32_t ***aux_shape_data, int *complete) { int succ = 0; *complete = 1; return MXSymbolInferShape(sym, num_args, keys, arg_ind_ptr, arg_shape_data, in_shape_size, in_shape_ndim, in_shape_data, out_shape_size, out_shape_ndim, out_shape_data, aux_shape_size, aux_shape_ndim, aux_shape_data, &succ); } /*! * \brief Executor for Symbol Partial Shape Inference * This api is available when MXNet is built with flag * USE_INT64_TENSOR_SIZE=0 (by default) * \param sym symbol handle * \param num_args number of args * \param keys keys * \param arg_ind_ptr arg index pointer * \param arg_shape_data arg shape data * \param in_shape_size input shape size * \param in_shape_ndim input shape number of dims * \param in_shape_data input shape data * \param out_shape_size ouput shape size * \param out_shape_ndim output shape number of dims * \param out_shape_data output shape data * \param aux_shape_size shape size of auxiliary states * \param aux_shape_ndim number of dims of auxiliary states shape * \param aux_shape_data shape data of auxiliary states * \param complete indicates completion of Shape Inference * \return 0 when success, -1 when failure happens */ int MXSymbolInferShapePartialEx(SymbolHandle sym, uint32_t num_args, const char** keys, const uint32_t *arg_ind_ptr, const int *arg_shape_data, uint32_t *in_shape_size, const int **in_shape_ndim, const int ***in_shape_data, uint32_t *out_shape_size, const int **out_shape_ndim, const int ***out_shape_data, uint32_t *aux_shape_size, const int **aux_shape_ndim, const int ***aux_shape_data, int *complete) { int succ = 0; *complete = 1; return MXSymbolInferShapeEx(sym, num_args, keys, arg_ind_ptr, arg_shape_data, in_shape_size, in_shape_ndim, in_shape_data, out_shape_size, out_shape_ndim, out_shape_data, aux_shape_size, aux_shape_ndim, aux_shape_data, &succ); } /*! * \brief Executor for Symbol Partial Shape Inference * This api is available when MXNet is built with flag * USE_INT64_TENSOR_SIZE=1 (not default) i.e. Large Tensor Support * \param sym symbol handle * \param num_args number of args * \param keys keys * \param arg_ind_ptr arg index pointer * \param arg_shape_data arg shape data * \param in_shape_size input shape size * \param in_shape_ndim input shape number of dims * \param in_shape_data input shape data * \param out_shape_size ouput shape size * \param out_shape_ndim output shape number of dims * \param out_shape_data output shape data * \param aux_shape_size shape size of auxiliary states * \param aux_shape_ndim number of dims of auxiliary states shape * \param aux_shape_data shape data of auxiliary states * \param complete indicates completion of Shape Inference * \return 0 when success, -1 when failure happens */ int MXSymbolInferShapePartialEx64(SymbolHandle sym, uint32_t num_args, const char** keys, const int64_t *arg_ind_ptr, const int64_t *arg_shape_data, size_t *in_shape_size, const int **in_shape_ndim, const int64_t ***in_shape_data, size_t *out_shape_size, const int **out_shape_ndim, const int64_t ***out_shape_data, size_t *aux_shape_size, const int **aux_shape_ndim, const int64_t ***aux_shape_data, int *complete) { int succ = 0; *complete = 1; return MXSymbolInferShapeEx64(sym, num_args, keys, arg_ind_ptr, arg_shape_data, in_shape_size, in_shape_ndim, in_shape_data, out_shape_size, out_shape_ndim, out_shape_data, aux_shape_size, aux_shape_ndim, aux_shape_data, &succ); } int MXSymbolInferType(SymbolHandle sym, uint32_t num_args, const char** keys, const int *arg_type_data, uint32_t *in_type_size, const int **in_type_data, uint32_t *out_type_size, const int **out_type_data, uint32_t *aux_type_size, const int **aux_type_data, int *complete) { nnvm::Symbol *s = static_cast(sym); MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); nnvm::Graph g = Symbol2Graph(*s); nnvm::DTypeVector arg_types(g.indexed_graph().input_nodes().size(), -1); if (keys == nullptr && num_args != 0) { std::vector read_only_args = mxnet::ReadOnlyArgIndices(g.indexed_graph()); CHECK_LE(num_args, read_only_args.size()); for (uint32_t i = 0; i < num_args; ++i) { arg_types[read_only_args[i]] = arg_type_data[i]; } } else { std::unordered_map kwargs; for (uint32_t i = 0; i < num_args; ++i) { kwargs[keys[i]] = arg_type_data[i]; } mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_types, "InferType"); } g = mxnet::exec::InferType(std::move(g), std::move(arg_types), "__dtype__"); // copy back CopyAttr(g.indexed_graph(), g.GetAttr("dtype"), &(ret->arg_types), &(ret->out_types), &(ret->aux_types)); *in_type_size = static_cast(ret->arg_types.size()); *in_type_data = dmlc::BeginPtr(ret->arg_types); *out_type_size = static_cast(ret->out_types.size()); *out_type_data = dmlc::BeginPtr(ret->out_types); *aux_type_size = static_cast(ret->aux_types.size()); *aux_type_data = dmlc::BeginPtr(ret->aux_types); *complete = (g.GetAttr("dtype_num_unknown_nodes") == 0); API_END(); } int MXSymbolInferTypePartial(SymbolHandle sym, uint32_t num_args, const char** keys, const int *arg_type_data, uint32_t *in_type_size, const int **in_type_data, uint32_t *out_type_size, const int **out_type_data, uint32_t *aux_type_size, const int **aux_type_data, int *complete) { int succ = 0; *complete = 1; return MXSymbolInferType(sym, num_args, keys, arg_type_data, in_type_size, in_type_data, out_type_size, out_type_data, aux_type_size, aux_type_data, &succ); } int MXSymbolGrad(SymbolHandle sym, uint32_t num_wrt, const char** wrt, SymbolHandle* out) { API_BEGIN(); LOG(FATAL) << "not implemented"; API_END(); } int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle, const int* dev_type, const uint32_t num_excluded_sym_names, const char **excluded_sym_names, const uint32_t num_excluded_op_names, const char **excluded_op_names, const uint32_t num_offline, const char **offline_params, const char *quantized_dtype, const bool calib_quantize, const char *quantize_mode, const char *quantize_granularity, mx_uint* out_num_calib_names, const char ***out_calib_names) { nnvm::Symbol *s = new nnvm::Symbol(); API_BEGIN(); nnvm::Symbol *sym = static_cast(sym_handle); nnvm::Graph g = Symbol2Graph(*sym); int target_dev = *dev_type; std::unordered_set excluded_node_names; for (size_t i = 0; i < num_excluded_sym_names; ++i) { excluded_node_names.emplace(excluded_sym_names[i]); } std::unordered_set excluded_op; for (size_t i = 0; i < num_excluded_op_names; ++i) { excluded_op.emplace(excluded_op_names[i]); } std::unordered_set offline; for (size_t i = 0; i < num_offline; ++i) { offline.emplace(offline_params[i]); } std::string quantized_type(quantized_dtype); std::string quantized_mode(quantize_mode); std::string quantized_granularity(quantize_granularity); g.attrs["excluded_nodes"] = std::make_shared(std::move(excluded_node_names)); g.attrs["excluded_ops"] = std::make_shared(std::move(excluded_op)); g.attrs["offline_params"] = std::make_shared(std::move(offline)); g.attrs["quantized_dtype"] = std::make_shared(std::move(quantized_type)); g.attrs["target_ctx"] = std::make_shared(target_dev); g.attrs["quantize_mode"] = std::make_shared(std::move(quantized_mode)); g.attrs["quantize_granularity"] = std::make_shared(std::move(quantized_granularity)); g = ApplyPass(std::move(g), "QuantizeGraph"); const auto& calib_nodes = g.GetAttr>("calib_nodes"); MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); ret->ret_vec_str = std::move(calib_nodes); *out_num_calib_names = ret->ret_vec_str.size(); ret->ret_vec_charp.clear(); ret->ret_vec_charp.reserve(ret->ret_vec_str.size()); for (const auto &str : ret->ret_vec_str) { ret->ret_vec_charp.push_back(str.c_str()); } *out_calib_names = dmlc::BeginPtr(ret->ret_vec_charp); s->outputs = g.outputs; *ret_sym_handle = s; API_END_HANDLE_ERROR(delete s); } // helper function to add mapping of node_name -> dtype map // for the given indexed graph and inferred_dtypes static void _SetInputDTypes( const nnvm::IndexedGraph& idx, const nnvm::DTypeVector& inferred_dtypes, std::unordered_map* node_name_dtype_map, std::unordered_map* node_without_dtype_map) { const std::string dtype_keyword = "__dtype__"; for (uint32_t nid : idx.input_nodes()) { const auto& node = idx[nid].source; const auto& node_with_dtype = node->attrs.dict.find(dtype_keyword); // input nodes classified into nodes_with_dtype, nodes_without_dtype // This classification required because if param_names not provided // we want to update dtypes of only those nodes which have dtypes set // inferred_dtypes are obtained for the nodes, if unknown // dtype is set to fp32 if (node_with_dtype != node->attrs.dict.end()) { if (inferred_dtypes[idx.entry_id(nid, 0)] == -1) { (*node_name_dtype_map)[node->attrs.name] = 0; } else { (*node_name_dtype_map)[node->attrs.name] = inferred_dtypes[idx.entry_id(nid, 0)]; } } else { if (inferred_dtypes[idx.entry_id(nid, 0)] == -1) { (*node_without_dtype_map)[node->attrs.name] = 0; } else { (*node_without_dtype_map)[node->attrs.name] = inferred_dtypes[idx.entry_id(nid, 0)]; } } } } // helper function update the node dtype attrs for a vector of nodeptrs // given the node name to dtype information and the names of model_params // if model_params is provided the function will dtype of only model params. // if model_params is empty, the function will dtype of all nodes which had // a prior dtype set. // args is a const_reference vector of ObjectPtrs. ObjectPtrs are immutable but // the Nodes they are pointing will be mutated in this function static void _UpdateSymDTypeAttrs( const std::unordered_map& node_name_dtype_map, const std::unordered_map& node_without_dtype_map, const std::unordered_set& model_params, const std::vector& args) { const std::string dtype_keyword = "__dtype__"; // Update args to have the right dtype attrs if (model_params.size() > 0) { // if model params provided, set dtype only for model params for (size_t i = 0; i < args.size(); ++i) { const std::string& node_name = args[i]->attrs.name; auto it_model_params = model_params.find(node_name); auto it_with_dtype = node_name_dtype_map.find(node_name); auto it_without_dtype = node_without_dtype_map.find(node_name); if (it_model_params != model_params.end()) { // need to update __dtype__ attribute if already set, else set it if (it_with_dtype != node_name_dtype_map.end()) { args[i]->attrs.dict[dtype_keyword] = std::to_string(it_with_dtype->second); } else { CHECK(it_without_dtype != node_without_dtype_map.end()) << "make sure all nodes without dtype have properly been added " "in node_without_dtype_map"; args[i]->attrs.dict[dtype_keyword] = std::to_string(it_without_dtype->second); } } } } else { // if model params not provided, update __dtype__ for all inputs, // which already had it set, don't touch the rest for (size_t i = 0; i < args.size(); ++i) { auto it = node_name_dtype_map.find(args[i]->attrs.name); if (it != node_name_dtype_map.end()) { if (args[i]->attrs.dict.find(dtype_keyword) != args[i]->attrs.dict.end()) { args[i]->attrs.dict[dtype_keyword] = std::to_string(it->second); } } } } } int MXReducePrecisionSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle, uint32_t num_args, const int *arg_type_data, uint32_t num_ind_ptr, const int* ind_ptr, const int* target_dtype, const int cast_optional_params, const uint32_t num_target_dtype_op_names, const uint32_t num_fp32_op_names, const uint32_t num_widest_dtype_op_names, const uint32_t num_conditional_fp32_op_names, const uint32_t num_excluded_symbols, const uint32_t num_model_params, const char **target_dtype_op_names, const char **fp32_op_names, const char **widest_dtype_op_names, const char **conditional_fp32_op_names, const char **excluded_symbols, const char **param_names, const char **param_vals, const char **model_param_names, const char **arg_names) { nnvm::Symbol *result_sym = new nnvm::Symbol(); API_BEGIN(); nnvm::Symbol *sym = static_cast(sym_handle); nnvm::Graph g = Symbol2Graph(*sym); std::unordered_set target_dtype_ops; std::unordered_set fp32_ops; std::unordered_set widest_dtype_ops; std::unordered_set excluded_syms; std::unordered_set model_params; // conditional_fp32_ops contains the mapping of op_name -> (map of param_name -> param_values) // which need to be conditionally selected to be casted to FP32 std::unordered_map>> conditional_fp32_ops; int target_dt = *target_dtype; for (size_t i = 0; i < num_target_dtype_op_names; ++i) { target_dtype_ops.emplace(target_dtype_op_names[i]); } for (size_t i = 0; i < num_fp32_op_names; ++i) { fp32_ops.emplace(fp32_op_names[i]); } for (size_t i = 0; i < num_widest_dtype_op_names; ++i) { widest_dtype_ops.emplace(widest_dtype_op_names[i]); } for (size_t i = 0; i < num_excluded_symbols; ++i) { excluded_syms.emplace(excluded_symbols[i]); } for (size_t i = 0; i < num_model_params; ++i) { model_params.emplace(model_param_names[i]); } for (size_t i = 0; i < num_ind_ptr - 1; ++i) { for (int j = ind_ptr[i]; j < ind_ptr[i + 1]; ++j) { conditional_fp32_ops[conditional_fp32_op_names[i]][param_names[i]] .emplace_back(std::string(param_vals[j])); } } std::unordered_map kwargs; std::unordered_map node_name_dtype_map, node_without_dtype_map; nnvm::DTypeVector arg_types(g.indexed_graph().input_nodes().size(), -1); for (uint32_t i = 0; i < num_args; ++i) { kwargs[arg_names[i]] = arg_type_data[i]; node_name_dtype_map[arg_names[i]] = arg_type_data[i]; } mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_types, "InferType"); g.attrs["target_dtype_ops"] = std::make_shared(std::move(target_dtype_ops)); g.attrs["fp32_ops"] = std::make_shared(std::move(fp32_ops)); g.attrs["widest_dtype_ops"] = std::make_shared(std::move(widest_dtype_ops)); g.attrs["conditional_fp32_ops"] = std::make_shared(std::move(conditional_fp32_ops)); g.attrs["excluded_syms"] = std::make_shared(std::move(excluded_syms)); g.attrs["target_dtype"] = std::make_shared(target_dt); g.attrs["data_name_types"] = std::make_shared(kwargs); g.attrs["cast_optional_params"] = std::make_shared(cast_optional_params); g = ApplyPass(std::move(g), "ReducePrecision"); // Need to run type inference since it is possible that inferred // type of some inputs has changed g = mxnet::exec::InferType(std::move(g), std::move(arg_types), ""); const nnvm::DTypeVector &inferred_dtypes = g.GetAttr("dtype"); g.attrs["inferred_dtypes"] = std::make_shared(std::move(inferred_dtypes)); g.attrs["target_dtype"] = std::make_shared(target_dt); if (cast_optional_params) { g = ApplyPass(std::move(g), "AMPInferUnknown"); const nnvm::DTypeVector &inferred_dtype_result = g.GetAttr("inferred_dtype_result"); const nnvm::IndexedGraph &idx = g.indexed_graph(); // set node name -> input dtype mapping using infer dtype _SetInputDTypes(idx, inferred_dtype_result, &node_name_dtype_map, &node_without_dtype_map); } else { const nnvm::IndexedGraph &idx = g.indexed_graph(); // set node name -> input dtype mapping using infer dtype _SetInputDTypes(idx, inferred_dtypes, &node_name_dtype_map, &node_without_dtype_map); } result_sym->outputs = g.outputs; *ret_sym_handle = result_sym; nnvm::Symbol *ret_sym = static_cast(*ret_sym_handle); const std::vector& args = ret_sym->ListInputs(nnvm::Symbol::kAll); // update symbol dtype attrs using the node name -> dtype mapping, if dtype is already set // in the symbol, else set dtype for the model_params _UpdateSymDTypeAttrs(node_name_dtype_map, node_without_dtype_map, model_params, args); API_END_HANDLE_ERROR(delete result_sym); } int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle, const uint32_t num_layers, const char** layer_names, const float* min_ranges, const float* max_ranges, SymbolHandle* ret_qsym_handle) { nnvm::Symbol* s = new nnvm::Symbol(); API_BEGIN(); nnvm::Symbol* sym = static_cast(qsym_handle); nnvm::Graph g = Symbol2Graph(*sym); std::unordered_map> calib_table; for (size_t i = 0; i < num_layers; ++i) { calib_table.emplace(layer_names[i], std::make_pair(min_ranges[i], max_ranges[i])); } g.attrs["calib_table"] = std::make_shared(std::move(calib_table)); g = ApplyPass(std::move(g), "SetCalibTableToQuantizedGraph"); s->outputs = g.outputs; *ret_qsym_handle = s; API_END_HANDLE_ERROR(delete s); } int MXGenBackendSubgraph(SymbolHandle sym_handle, const char *backend_name, SymbolHandle *ret_sym_handle) { nnvm::Symbol *s = new nnvm::Symbol(); API_BEGIN(); nnvm::Symbol *sym = static_cast(sym_handle); *s = sym->Copy(); auto backend = mxnet::op::SubgraphBackendRegistry::Get()->GetSubgraphBackend(backend_name); const auto& subgraph_prop_list = backend->GetSubgraphProperties(); for (auto property : subgraph_prop_list) { if (property->HasAttr("disable") && property->GetAttr("disable") == true) { auto full_name = property->HasAttr("property_name") ? property->GetAttr("property_name") : std::string(); LOG(INFO) << "subgraph property " << full_name << " from backend " << backend_name << " is disabled."; continue; } nnvm::Graph g = Symbol2Graph(*s); property->SetAttr("graph", g); g.attrs["subgraph_property"] = std::make_shared(property); g = ApplyPass(std::move(g), "BuildSubgraph"); property->RemoveAttr("graph"); g.attrs.erase("subgraph_property"); s->outputs = g.outputs; } *ret_sym_handle = s; API_END_HANDLE_ERROR(delete s); } int MXGenAtomicSymbolFromSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle) { nnvm::Symbol *s = new nnvm::Symbol(); API_BEGIN(); nnvm::Symbol *source = static_cast(sym_handle); CHECK_GE(source->outputs.size(), 1) << "Input symbol does not have outputs."; const auto &node = source->outputs[0].node; for (const auto &other_node : source->outputs) { if (node.get() != other_node.node.get()) { LOG(FATAL) << "Generating atomic symbol from other symbol only works for nongrouped symbol."; } } const auto *op = node->op(); const auto attrs = source->ListAttrs(nnvm::Symbol::ListAttrOption::kShallow); *s = nnvm::Symbol::CreateFunctor(op, attrs); *ret_sym_handle = s; API_END_HANDLE_ERROR(delete s); } int MXShallowCopySymbol(SymbolHandle src, SymbolHandle* out) { nnvm::Symbol* out_sym = new nnvm::Symbol; API_BEGIN(); nnvm::Symbol* src_sym = static_cast(src); *out_sym = *src_sym; *out = out_sym; API_END_HANDLE_ERROR(delete out_sym); } int MXOptimizeForBackend(SymbolHandle sym_handle, const char* backend_name, const int dev_type, SymbolHandle* ret_sym_handle, const mx_uint args_len, NDArrayHandle* in_args_handle, const mx_uint aux_len, NDArrayHandle* in_aux_handle, const mx_uint num_options, const char** keys, const char** vals, const uint32_t num_input_shapes, const char** input_shape_names, const int64_t* input_shape_data, const uint32_t* input_shape_idx, const uint32_t num_input_dtypes, const char** input_dtype_names, const int* input_dtypes, const uint32_t num_input_stypes, const char** input_stype_names, const int* input_stypes, bool skip_infer, int* new_args_cnt, NDArrayHandle** new_args_handle, char*** new_arg_names_handle, int* new_aux_cnt, NDArrayHandle** new_aux_handle, char*** new_aux_names_handle) { // create copy of input symbol nnvm::Symbol *s = new nnvm::Symbol(); API_BEGIN(); nnvm::Symbol *sym = static_cast(sym_handle); *s = sym->Copy(); // create a data structure from pointer array std::unordered_map options_map; for (mx_uint i = 0; i < num_options; ++i) options_map.emplace(keys[i], vals[i]); NDArray ***new_args_ptr = reinterpret_cast(new_args_handle); NDArray ***new_aux_ptr = reinterpret_cast(new_aux_handle); NDArray **in_args_ptr = reinterpret_cast(in_args_handle); NDArray **in_aux_ptr = reinterpret_cast(in_aux_handle); auto init_graph = [&](nnvm::Symbol* s) { nnvm::Graph g = Symbol2Graph(*s); const auto& indexed_graph = g.indexed_graph(); const auto& mutable_nodes = indexed_graph.mutable_input_nodes(); std::vector input_names = s->ListInputNames(nnvm::Symbol::kAll); size_t num_forward_inputs = input_names.size(); if (args_len || aux_len) { if (!skip_infer) { Context default_ctx = Context::Create(static_cast(dev_type), 0); mxnet::ShapeVector arg_shapes(args_len + aux_len); nnvm::DTypeVector arg_dtypes(args_len + aux_len); StorageTypeVector arg_stypes(args_len + aux_len); // create the input shape, dtype and stype maps std::unordered_map input_shape_map(num_input_shapes); for (uint32_t i = 0; i < num_input_shapes; ++i) { input_shape_map.emplace(input_shape_names[i], mxnet::TShape(input_shape_data + input_shape_idx[i], input_shape_data + input_shape_idx[i+1])); } std::unordered_map input_dtype_map(num_input_dtypes); for (uint32_t i = 0; i < num_input_dtypes; ++i) { input_dtype_map.emplace(input_dtype_names[i], input_dtypes[i]); } std::unordered_map input_stype_map(num_input_stypes); for (uint32_t i = 0; i < num_input_stypes; ++i) { input_stype_map.emplace(input_stype_names[i], input_stypes[i]); } size_t args_top = 0, aux_top = 0; // loop over inputs to symbol in order and add to args/aux if mutable for (size_t i = 0; i < num_forward_inputs; ++i) { const uint32_t nid = indexed_graph.input_nodes().at(i); if (mutable_nodes.count(nid)) { CHECK_LT(aux_top, aux_len) << "Cannot find aux '" << input_names[i] << "' in provided aux to optimize_for"; if (in_aux_ptr[aux_top] != nullptr) { const auto &in_arg = *(in_aux_ptr[aux_top]); arg_shapes[i] = in_arg.shape(); arg_dtypes[i] = in_arg.dtype(); arg_stypes[i] = in_arg.storage_type(); } aux_top++; } else { auto name = input_names[i]; CHECK_LT(args_top, args_len) << "Cannot find arg '" << name << "' in provided args to optimize_for"; if (in_args_ptr[args_top] != nullptr) { const auto &in_arg = *(in_args_ptr[args_top]); arg_shapes[i] = in_arg.shape(); arg_dtypes[i] = in_arg.dtype(); arg_stypes[i] = in_arg.storage_type(); } else { // input_names[i] is not in args but can be in the optional // shape/type/stype attribute dicts. auto it_shape = input_shape_map.find(name); if (it_shape != input_shape_map.end()) { arg_shapes[i] = it_shape->second; } auto it_type = input_dtype_map.find(name); if (it_type != input_dtype_map.end()) { arg_dtypes[i] = it_type->second; } it_type = input_stype_map.find(name); if (it_type != input_stype_map.end()) { arg_stypes[i] = it_type->second; } } args_top++; } } g.attrs["context"] = std::make_shared( exec::ContextVector(indexed_graph.num_nodes(), default_ctx)); // infer shapes g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__"); // infer dtypes g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__"); // infer stypes g = exec::InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__"); } // set args/aux as attributes on graph so that subgraph property can use them std::vector arg_names = s->ListInputNames(nnvm::Symbol::kReadOnlyArgs); g.attrs["in_args"] = std::make_shared(in_args_ptr); g.attrs["in_arg_names"] = std::make_shared(arg_names); std::vector aux_names = s->ListInputNames(nnvm::Symbol::kAuxiliaryStates); g.attrs["in_aux"] = std::make_shared(in_aux_ptr); g.attrs["in_aux_names"] = std::make_shared(aux_names); } else { // args/aux were not specified, so set nullptr/empty-lists NDArray **in_args_ptr = static_cast(nullptr); std::vector arg_names; g.attrs["in_args"] = std::make_shared(in_args_ptr); g.attrs["in_arg_names"] = std::make_shared(arg_names); NDArray **in_aux_ptr = static_cast(nullptr); std::vector aux_names; g.attrs["in_aux"] = std::make_shared(in_aux_ptr); g.attrs["in_aux_names"] = std::make_shared(aux_names); } // set dedup option as attribute on graph to enable dedup during partitioning if (options_map.count("dedup_subgraph") > 0 && options_map.at("dedup_subgraph").compare("True") == 0) g.attrs["dedup_subgraph"] = std::make_shared(std::string("True")); return g; }; if (mxnet::op::SubgraphBackendRegistry::Get()->backend_map_.count(backend_name) > 0) { // use subgraph backend const auto backend = mxnet::op::SubgraphBackendRegistry ::Get()->GetSubgraphBackend(backend_name); const auto& subgraph_prop_list = backend->GetSubgraphProperties(); for (auto property : subgraph_prop_list) { nnvm::Graph g = init_graph(s); property->PrePartition(g, options_map); g.attrs["subgraph_property"] = std::make_shared(property); g = ApplyPass(std::move(g), "BuildSubgraph"); g.attrs.erase("subgraph_property"); property->PostPartition(g); s->outputs = g.outputs; } } else if (dmlc::Registry::Find(backend_name) != nullptr) { // use graph pass nnvm::Graph g = init_graph(s); g.attrs["options_map"] = std::make_shared(options_map); g.attrs["pass_name"] = std::make_shared(backend_name); g = ApplyPass(std::move(g), backend_name); std::vector new_args = g.GetAttr>("new_args"); std::vector new_aux = g.GetAttr>("new_aux"); std::vector new_arg_names = g.GetAttr>("new_arg_names"); std::vector new_aux_names = g.GetAttr>("new_aux_names"); g.attrs.erase("new_args"); g.attrs.erase("new_aux"); g.attrs.erase("new_arg_names"); g.attrs.erase("new_aux_names"); s->outputs = g.outputs; NDArray** new_arg_arr = new NDArray*[new_arg_names.size()]; NDArray** new_aux_arr = new NDArray*[new_aux_names.size()]; char** new_arg_cstr = new char*[new_arg_names.size()]; char** new_aux_cstr = new char*[new_aux_names.size()]; for (unsigned i = 0; i < new_arg_names.size(); i++) { new_arg_arr[i] = new_args[i]; std::string& s = new_arg_names[i]; char* tmp = new char[s.length()+1]; s.copy(tmp, s.length()); tmp[s.length()] = '\0'; new_arg_cstr[i] = tmp; } for (unsigned i = 0; i < new_aux_names.size(); i++) { new_aux_arr[i] = new_aux[i]; std::string& s = new_aux_names[i]; char* tmp = new char[s.length()+1]; s.copy(tmp, s.length()); tmp[s.length()] = '\0'; new_aux_cstr[i] = tmp; } *new_args_cnt = new_arg_names.size(); *new_aux_cnt = new_aux_names.size(); *new_arg_names_handle = new_arg_cstr; *new_aux_names_handle = new_aux_cstr; *new_args_ptr = new_arg_arr; *new_aux_ptr = new_aux_arr; } else { // cannot find graph pass or subgraph backend registered in this name LOG(ERROR) << "Error optimizing for backend '" << backend_name << "' cannot be found"; } *ret_sym_handle = s; API_END_HANDLE_ERROR(delete s); }