1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file infer_graph_attr_pass.cc
22 * \brief infer graph shape, dtype, and storage type
23 */
24
25 #include <mxnet/op_attr_types.h>
26 #include <mxnet/graph_attr_types.h>
27 #include <mxnet/imperative.h>
28 #include "./exec_pass.h"
29 #include "../operator/operator_common.h"
30 #include "../common/exec_utils.h"
31
32 namespace mxnet {
33 namespace exec {
34
template<typename AttrType, typename FInfer>
// Generic dispatcher that invokes an operator's inference function for
// shape/dtype-style attributes. The inference function only consumes the node
// attributes and the per-entry in/out attribute vectors; the graph, node id
// and dispatch mode are unused here and exist so that the storage-type
// specialization below can share the same call site.
//
// Returns whatever the operator's inference function returns (true on
// successful inference).
bool ApplyOpInferAttr(const nnvm::Graph& g,
                      const FInfer& finfer,
                      const NodeAttrs& attrs,
                      const uint32_t nid,
                      std::vector<AttrType>* in_attrs,
                      std::vector<AttrType>* out_attrs,
                      DispatchMode* dispatch_mode) {
  return finfer(attrs, in_attrs, out_attrs);
}
45
46 template<>
ApplyOpInferAttr(const nnvm::Graph & g,const FInferStorageType & finfer,const NodeAttrs & attrs,const uint32_t nid,std::vector<int> * in_attrs,std::vector<int> * out_attrs,DispatchMode * dispatch_mode)47 bool ApplyOpInferAttr<int, FInferStorageType>(const nnvm::Graph& g,
48 const FInferStorageType& finfer,
49 const NodeAttrs& attrs,
50 const uint32_t nid,
51 std::vector<int>* in_attrs,
52 std::vector<int>* out_attrs,
53 DispatchMode* dispatch_mode) {
54 const DevMaskVector& dev_masks = g.GetAttr<DevMaskVector>("dev_mask");
55 const bool success = finfer(attrs, dev_masks[nid], dispatch_mode, in_attrs, out_attrs);
56 if (!success) {
57 LOG(FATAL) << "Operator not implemented: "
58 << common::operator_stype_string(attrs, dev_masks[nid], *in_attrs, *out_attrs);
59 }
60 if (*dispatch_mode == DispatchMode::kFComputeFallback) {
61 common::LogStorageFallback(attrs, dev_masks[nid], in_attrs, out_attrs);
62 }
63 return true;
64 }
65
template<typename AttrType, typename IsNone>
// Propagate attribute values between a backward (gradient) node and its
// forward node. The correspondence between the backward node's entries and
// the forward node's entries is recovered by re-running the forward op's
// FGradient function on dummy output-gradient entries and matching the
// resulting nodes against the backward node.
//
// \param nid id of the backward node in idx
// \param idx indexed graph being inferred over
// \param rshape_ptr per-entry attribute vector being filled in
// \param inference_finished per-node completion flags; entry nid is set to
//        true once all attributes touched here are known
// \param fis_none predicate returning true for an unknown attribute value
inline void GetAttrFromForwardNode(const uint32_t nid,
                                   const nnvm::IndexedGraph &idx,
                                   std::vector<AttrType>* rshape_ptr,
                                   std::vector<bool>* inference_finished,
                                   IsNone fis_none) {
  std::vector<AttrType>& rshape = *rshape_ptr;
  const nnvm::IndexedGraph::Node& inode = idx[nid];
  // gradient function, used to get node correspondence.
  static auto& fgrad =
      Op::GetAttr<nnvm::FGradient>("FGradient");
  // control_deps[0] of a backward node is its forward node.
  nnvm::ObjectPtr fwd_ptr = inode.source->control_deps[0];
  const nnvm::IndexedGraph::Node& fnode = idx[inode.control_deps[0]];
  // use gradient function to find out the correspondence.
  // ograd entries are dummies: only their indices matter for matching.
  std::vector<nnvm::NodeEntry> ograd(fwd_ptr->num_outputs());
  for (size_t i = 0; i < ograd.size(); ++i) {
    ograd[i].index = static_cast<uint32_t>(i);
  }
  // input gradient list
  const std::vector<nnvm::NodeEntry>& igrad = fgrad[fwd_ptr->op()](fwd_ptr, ograd);
  const nnvm::Node* igrad_node = nullptr;
  bool all_attrs_known = true;
  // Input gradient assignment: backward outputs take the attributes of the
  // corresponding forward inputs.
  for (size_t i = 0; i < igrad.size(); ++i) {
    if (igrad[i].node->op() == inode.source->op()) {
      uint32_t eid = idx.entry_id(nid, igrad[i].index);
      if (fis_none(rshape[idx.entry_id(fnode.inputs[i])])) {
        // Need to skip empty forward shape, because it may not be
        // available now and it is possible to infer the forward
        // shape in one of the next a few passes
        all_attrs_known = false;
      } else {
        if (fis_none(rshape[eid])) {
          rshape[eid] = rshape[idx.entry_id(fnode.inputs[i])];
        } else {
          CHECK_EQ(rshape[eid], rshape[idx.entry_id(fnode.inputs[i])])
              << "Backward shape inconsistent with the forward shape";
        }
      }
      // All matching gradient entries must come from the same node.
      if (igrad_node == nullptr) {
        igrad_node = igrad[i].node.get();
      } else {
        CHECK(igrad_node == igrad[i].node.get());
      }
    }
  }
  // out grad entries: backward inputs with a null node are the dummy ograds,
  // which correspond to the forward node's outputs.
  CHECK(igrad_node != nullptr)
      << "Cannot find matching backward op for " << inode.source->attrs.name;
  for (size_t i = 0; i < igrad_node->inputs.size(); ++i) {
    const nnvm::NodeEntry& e = igrad_node->inputs[i];
    if (e.node == nullptr) {
      uint32_t eid = idx.entry_id(inode.inputs[i]);
      if (fis_none(rshape[eid])) {
        rshape[eid] = rshape[idx.entry_id(inode.control_deps[0], e.index)];
      }
      if (fis_none(rshape[eid])) {
        // If the attr is still unknown
        all_attrs_known = false;
      }
    }
  }
  (*inference_finished)[nid] = all_attrs_known;
}
130
template<typename FAccessSubgraphType, typename AttrType, typename IsNone>
// Variant of GetAttrFromForwardNode for backward nodes whose forward node is
// a fusion helper. Instead of reading attributes off the forward node's graph
// entries, it asks the fused op (via the registered FAccessSubgraph* function
// named by infer_fusion_name) for the original forward node and its
// input/output attributes, then applies the same FGradient-based matching.
//
// \param nid id of the backward node in idx
// \param idx indexed graph being inferred over
// \param rshape_ptr per-entry attribute vector being filled in
// \param inference_finished per-node completion flags; entry nid is set to
//        true once all attributes touched here are known
// \param fis_none predicate returning true for an unknown attribute value
// \param infer_fusion_name name of the op attribute used to access the fused
//        subgraph's attributes
void GetAttrFromFusedNode(uint32_t nid,
                          const nnvm::IndexedGraph& idx,
                          std::vector<AttrType>* rshape_ptr,
                          std::vector<bool>* inference_finished,
                          IsNone fis_none,
                          const std::string& infer_fusion_name) {
  std::vector<AttrType>& rshape = *rshape_ptr;
  const auto& inode = idx[nid];
  // gradient function, used to get node correspondence.
  static auto& fgrad =
      Op::GetAttr<nnvm::FGradient>("FGradient");
  // control_deps[0] of a backward node is its (fused) forward node.
  nnvm::ObjectPtr fused_fwd_ptr = inode.source->control_deps[0];
  static auto& finfer_fused_shape =
    Op::GetAttr<FAccessSubgraphType>(infer_fusion_name);
  auto finfer = finfer_fused_shape.get(fused_fwd_ptr->op(), nullptr);
  CHECK(finfer != nullptr) << "Operator " << fused_fwd_ptr->attrs.name <<
    " is marked as Fusion but does not allow accessing attributes";
  // Tuple of (original forward node, its input attrs, its output attrs).
  const auto& inferred_attrs = finfer(fused_fwd_ptr->attrs);
  const auto& fwd_ptr = std::get<0>(inferred_attrs);
  const auto& input_attrs = std::get<1>(inferred_attrs);
  const auto& output_attrs = std::get<2>(inferred_attrs);

  // use gradient function to find out the correspondence.
  // ograd entries are dummies: only their indices matter for matching.
  std::vector<nnvm::NodeEntry> ograd(fwd_ptr->num_outputs());
  for (size_t i = 0; i < ograd.size(); ++i) {
    ograd[i].index = static_cast<uint32_t>(i);
  }
  // input gradient list
  const std::vector<nnvm::NodeEntry>& igrad = fgrad[fwd_ptr->op()](fwd_ptr, ograd);
  const nnvm::Node* igrad_node = nullptr;
  bool all_attrs_known = true;
  // Set the attributes of output gradients
  // using attributes of forward node inputs
  for (size_t i = 0; i < igrad.size(); ++i) {
    if (igrad[i].node->op() == inode.source->op()) {
      uint32_t eid = idx.entry_id(nid, igrad[i].index);
      if (fis_none(input_attrs[i])) {
        // Need to skip empty forward shape, because it may not be
        // available now and it is possible to infer the forward
        // shape in one of the next a few passes
        all_attrs_known = false;
      } else {
        if (fis_none(rshape[eid])) {
          rshape[eid] = input_attrs[i];
        } else {
          CHECK_EQ(rshape[eid], input_attrs[i])
              << "Backward shape inconsistent with the forward shape";
        }
      }
      // All matching gradient entries must come from the same node.
      if (igrad_node == nullptr) {
        igrad_node = igrad[i].node.get();
      } else {
        CHECK(igrad_node == igrad[i].node.get());
      }
    }
  }

  // Set the attributes of input gradients
  // using attributes of forward node outputs
  CHECK(igrad_node != nullptr)
      << "Cannot find matching backward op for " << inode.source->attrs.name;
  for (size_t i = 0; i < igrad_node->inputs.size(); ++i) {
    const nnvm::NodeEntry& e = igrad_node->inputs[i];
    if (e.node == nullptr) {
      uint32_t eid = idx.entry_id(inode.inputs[i]);
      if (fis_none(rshape[eid])) {
        rshape[eid] = output_attrs[e.index];
      }
      if (fis_none(rshape[eid])) {
        // If the attr is still unknown
        all_attrs_known = false;
      }
    }
  }
  (*inference_finished)[nid] = all_attrs_known;
}
208
209 template <typename FProvideSubgraphType, typename AttrType>
ProvideAttrToFusion(const uint32_t nid,const nnvm::IndexedGraph & idx,const std::vector<AttrType> & rshape,const std::string & provide_fusion_name)210 void ProvideAttrToFusion(const uint32_t nid,
211 const nnvm::IndexedGraph& idx,
212 const std::vector<AttrType>& rshape,
213 const std::string& provide_fusion_name) {
214 const auto& inode = idx[nid];
215 std::vector<std::vector<AttrType>> in_attrs;
216 std::vector<std::vector<AttrType>> out_attrs;
217 for (const auto& dep_node : inode.source->control_deps) {
218 in_attrs.push_back({});
219 out_attrs.push_back({});
220 auto ¤t_in_attrs = in_attrs.back();
221 auto ¤t_out_attrs = out_attrs.back();
222 uint32_t dep_node_id = idx.node_id(dep_node.get());
223 for (const auto& e : idx[dep_node_id].inputs) {
224 current_in_attrs.push_back(rshape[idx.entry_id(e)]);
225 }
226 for (size_t i = 0; i < dep_node->num_outputs(); ++i) {
227 current_out_attrs.push_back(rshape[idx.entry_id(dep_node_id, i)]);
228 }
229 }
230 auto provide =
231 Op::GetAttr<FProvideSubgraphType>(provide_fusion_name).get(inode.source->op(), nullptr);
232 CHECK(provide != nullptr) <<
233 "Encountered Fusion operator that does not implement providing subgraph attr " <<
234 provide_fusion_name << ".";
235 provide(inode.source->attrs, inode.source->control_deps, in_attrs, out_attrs);
236 }
237
/*!\brief
 * This is a duplicate of the InferAttr function in nnvm with minor modification
 * to support inferring storage type whose function signature is different from
 * shape/type inference functions'. The nnvm InferAttr will be deprecated
 * in the future. Please use interfaces InferShape, InferType, and InferStorageType
 * to call this function.
 *
 * Inference runs as a fixed-point iteration: alternating forward and backward
 * sweeps over the (optionally restricted) node range, repeated while the
 * number of unknown entries keeps strictly decreasing, plus one final sweep
 * if some node has not yet declared its inference finished.
 *
 * \param ret graph used for attribute inference
 * \param empty_val empty value of the attribute
 * \param infer_name name of the function used for attribute inference
 * \param infer_fusion_name name of the function used for accessing attributes in fused nodes
 * \param provide_fusion_name name of the function used for providing attributes to fused nodes
 * \param input_name name of the attribute in the graph used to store the
 *                   input data for attribute inference
 * \param attr_key_name name of the attribute used for inference for variable nodes
 * \param attr_name name of the inferred attribute
 * \param unknown_name name of the attribute storing number of entries
 *                     impossible to infer
 * \param fis_none function returning true for not fully inferred values
 * \param fdefault default function used for inference if the node does not
 *                 provide its own implementation.
 * \param bwd_identity_assign whether the attributes of forward NDArray and backward
 *                            NDArray have to be the same. False only for storage
 *                            type inference
 * \param dispatch_mode_name name of the dispatch mode attribute on the node. Used for
 *                           storage type inference
 * \param default_mode_val default value of the dispatch mode attribute on the node. Used
 *                         for storage type inference
 */
template<typename AttrType, typename FInferType, typename FAccessSubgraphType,
         typename FProvideSubgraphType, typename IsNone, typename FDefault>
nnvm::Graph InferAttr(nnvm::Graph &&ret,
                      const AttrType empty_val,
                      const char* infer_name,
                      const char* infer_fusion_name,
                      const char* provide_fusion_name,
                      const char* input_name,
                      const char* attr_key_name,
                      const char* attr_name,
                      const char* unknown_name,
                      IsNone fis_none,
                      FDefault fdefault,
                      bool bwd_identity_assign,
                      const char* dispatch_mode_name,
                      const DispatchMode default_mode_val = DispatchMode::kUndefined) {
  using nnvm::IndexedGraph;
  using nnvm::Op;
  using AttrVector = std::vector<AttrType>;
  using NodeAttrVector = std::vector<DispatchMode>;
  using dmlc::any;

  const IndexedGraph& idx = ret.indexed_graph();
  static auto& finfer_shape =
      Op::GetAttr<FInferType>(infer_name);
  static auto& is_backward =
      Op::GetAttr<nnvm::TIsBackward>("TIsBackward");
  // per-entry attribute vector (one slot per node entry in the graph)
  AttrVector rshape;
  // vector holding information which operators
  // finished attribute inference
  std::vector<bool> inference_finished(idx.num_nodes(), false);
  // dispatch mode vector
  DispatchModeVector dispatch_modes;
  // Reuse a partially inferred attribute vector if the graph already has one.
  if (ret.attrs.count(attr_name) != 0) {
    rshape = ret.MoveCopyAttr<AttrVector>(attr_name);
  } else {
    rshape.resize(idx.num_node_entries(), empty_val);
  }

  // Seed the graph inputs with the caller-provided attribute values.
  if (ret.attrs.count(input_name) != 0) {
    const AttrVector& shape_args = ret.GetAttr<AttrVector>(input_name);
    CHECK_LE(shape_args.size(), idx.input_nodes().size())
        << "More provided " << attr_name << "s than number of arguments.";
    for (size_t i = 0; i < shape_args.size(); ++i) {
      rshape[idx.entry_id(idx.input_nodes()[i], 0)] = shape_args[i];
    }
  }

  // get the shape hints: optional per-entry seed values keyed by node entry
  std::string shape_hints_key = std::string(attr_name) + "_hints";
  if (ret.attrs.count(shape_hints_key)) {
    nnvm::NodeEntryMap<AttrType> shape_hints =
      ret.GetAttr<nnvm::NodeEntryMap<AttrType>>(shape_hints_key);
    for (const auto& kv : shape_hints) {
      nnvm::NodeEntry e = kv.first;
      // Hints may refer to nodes that are not part of this graph; skip those.
      if (idx.exist(e.node.get())) {
        rshape[idx.entry_id(kv.first)] = kv.second;
      }
    }
  }

  // key under which variable nodes may carry their attribute in attrs.dict
  std::string shape_attr_key;
  if (ret.attrs.count(attr_key_name) != 0) {
    shape_attr_key = ret.GetAttr<std::string>(attr_key_name);
    // erase the provided arguments
    ret.attrs.erase(attr_key_name);
  }

  // limit inference to part of the graph
  uint32_t node_start = 0, node_end = idx.num_nodes();
  if (ret.attrs.count("node_range")) {
    const auto& range = ret.GetAttr<std::pair<uint32_t, uint32_t> >("node_range");
    node_start = range.first;
    node_end = range.second;
    CHECK_GE(node_start, 0);
    CHECK_LE(node_end, idx.num_nodes());
    ret.attrs.erase("node_range");
  }
  uint32_t entry_start = 0, entry_end = idx.num_node_entries();
  if (ret.attrs.count("entry_range")) {
    const auto& range = ret.GetAttr<std::pair<uint32_t, uint32_t> >("entry_range");
    entry_start = range.first;
    entry_end = range.second;
    CHECK_GE(entry_start, 0);
    CHECK_LE(entry_end, idx.num_node_entries());
    ret.attrs.erase("entry_range");
  }
  // populate the node attribute vector (dispatch modes must be pre-populated
  // by the caller when dispatch_mode_name is given)
  if (dispatch_mode_name != nullptr) {
    if (ret.attrs.count(dispatch_mode_name) != 0) {
      dispatch_modes = ret.MoveCopyAttr<NodeAttrVector>(dispatch_mode_name);
    } else {
      LOG(FATAL) << "Node attribute " << dispatch_mode_name << " does not exist in the graph";
    }
  }

  // Temp space for shape inference.
  std::vector<AttrType> ishape, oshape;

  // inference step function for nid
  auto infer_step = [&](uint32_t nid, bool last_iter) {
    if (inference_finished[nid]) return;
    const auto& inode = idx[nid];
    const uint32_t num_inputs = inode.inputs.size();
    const uint32_t num_outputs = inode.source->num_outputs();
    if (inode.source->is_variable()) {
      // Variable node. No operator. Only one output entry.
      CHECK(inode.source->op() == nullptr);
      CHECK_EQ(num_outputs, 1U);
      const uint32_t out_ent_id = idx.entry_id(nid, 0);
      // A variable may carry its attribute value as a string in attrs.dict.
      if (shape_attr_key.length() != 0 && fis_none(rshape[out_ent_id])) {
        auto it = inode.source->attrs.dict.find(shape_attr_key);
        if (it != inode.source->attrs.dict.end()) {
          std::istringstream is(it->second);
          CHECK(is >> rshape[out_ent_id]) << "Invalid attribute";
        }
      }
      if (!fis_none(rshape[out_ent_id])) {
        inference_finished[nid] = true;
      }
      // assign a default value to node attribute
      if (dispatch_mode_name != nullptr) {
        op::dispatch_mode_assign(&dispatch_modes[nid], default_mode_val);
      }
    } else if (is_backward.get(inode.source->op(), false) &&
               inode.source->control_deps.size() && bwd_identity_assign) {
      // Backward node: copy attributes from the corresponding forward node.
      CHECK(dispatch_mode_name == nullptr)
          << "Backward inference for node attributes is not available";
      CHECK_GE(inode.source->control_deps.size(), 1U)
          << "BackwardOp need to have control_deps to its forward op";
      nnvm::ObjectPtr fwd_ptr = inode.source->control_deps[0];
      CHECK(fwd_ptr->op() != nullptr) << "Forward op cannot be a variable";

      static auto& is_fusion_helper = Op::GetAttr<exec::TIsFusionHelper>("TIsFusionHelper");
      if (!is_fusion_helper.get(fwd_ptr->op(), false)) {
        GetAttrFromForwardNode(nid, idx, &rshape, &inference_finished, fis_none);
      } else {
        GetAttrFromFusedNode<FAccessSubgraphType>(nid, idx, &rshape, &inference_finished,
                                                  fis_none, infer_fusion_name);
      }
    } else {
      DispatchMode* dispatch_mode = nullptr;
      // Forward operator inference.
      ishape.resize(num_inputs, empty_val);
      for (uint32_t i = 0; i < ishape.size(); ++i) {
        ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
      }
      oshape.resize(num_outputs, empty_val);
      for (uint32_t i = 0; i < oshape.size(); ++i) {
        oshape[i] = rshape[idx.entry_id(nid, i)];
      }
      if (dispatch_mode_name != nullptr) {
        dispatch_mode = &dispatch_modes[nid];
      }
      auto finfer = finfer_shape.get(inode.source->op(), fdefault);
      if (finfer != nullptr) {
        // Call inference function of the operator.
        try {
          static auto& is_fusion = Op::GetAttr<exec::TIsFusion>("TIsFusion");
          if (is_fusion.get(inode.source->op(), false)) {
            // Give the fusion op the current attrs of its subgraph first.
            ProvideAttrToFusion<FProvideSubgraphType>(nid, idx, rshape, provide_fusion_name);
          }
          ApplyOpInferAttr(ret, finfer, inode.source->attrs,
                           nid, &ishape, &oshape, dispatch_mode);
          // The node is finished only when all in/out attrs are known.
          bool finished = true;
          for (const auto& attr : ishape) {
            if (fis_none(attr)) finished = false;
          }
          for (const auto& attr : oshape) {
            if (fis_none(attr)) finished = false;
          }
          inference_finished[nid] = finished;
        } catch (const std::exception& e) {
          throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what());
        }
      } else {
        // Operator does not provide attribute inference function,
        // so we need to test if everything was inferred by other operators
        bool all_attrs_known = true;
        for (const auto& attr : ishape) {
          if (fis_none(attr)) {
            all_attrs_known = false;
          }
        }
        for (const auto& attr : oshape) {
          if (fis_none(attr)) {
            all_attrs_known = false;
          }
        }
        inference_finished[nid] = all_attrs_known;
        if (!all_attrs_known) {
          CHECK(!last_iter)
              << "Attribute " << infer_name
              << " is not registered by op " << inode.source->op()->name
              << ". We are not able to complete the inference because of this";
        }
      }
      // Save to the result map.
      for (uint32_t i = 0; i < num_inputs; ++i) {
        rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
      }
      for (uint32_t i = 0; i < num_outputs; ++i) {
        rshape[idx.entry_id(nid, i)] = oshape[i];
      }
    }
  };

  size_t last_num_unknown;
  // Start the unknown count at its theoretical maximum so the first iteration
  // always sees progress.
  size_t num_unknown_dispatch_mode = dispatch_mode_name ? node_end - node_start : 0;
  size_t num_unknown_entry_attr = entry_end - entry_start;
  size_t num_unknown = num_unknown_entry_attr + num_unknown_dispatch_mode;
  bool last_iter = false;
  bool do_next_iteration = true;
  int i = 0;
  do {
    if (i % 2 == 0) {
      // forward inference
      for (uint32_t nid = node_start; nid < node_end; ++nid) {
        infer_step(nid, last_iter);
      }
    } else {
      // backward inference
      for (uint32_t i = node_end; i != node_start; --i) {
        infer_step(i - 1, last_iter);
      }
    }
    // Recount the unknowns to detect progress.
    last_num_unknown = num_unknown;
    num_unknown = 0;
    for (size_t j = entry_start; j < entry_end; ++j) {
      if (fis_none(rshape[j])) {
        ++num_unknown;
      }
    }
    if (dispatch_mode_name) {
      for (size_t i = node_start; i < node_end; i++) {
        if (dispatch_modes[i] == DispatchMode::kUndefined) ++num_unknown;
      }
    }
    // Keep iterating only while the unknown count strictly decreases.
    do_next_iteration = num_unknown > 0 && last_num_unknown > num_unknown;
    if (!do_next_iteration && !last_iter) {
      // Check if every op agrees that it should be
      // the end of attribute inference. If not,
      // perform one final step
      for (const bool done : inference_finished) {
        do_next_iteration = do_next_iteration || !done;
      }
      last_iter = true;
    }
    ++i;
  } while (do_next_iteration);
  // store the inferred attribute vector on the graph
  ret.attrs[attr_name] = std::make_shared<any>(std::move(rshape));
  // store the dispatch modes (storage-type inference only)
  if (dispatch_mode_name) {
    ret.attrs[dispatch_mode_name] = std::make_shared<any>(std::move(dispatch_modes));
  }
  // number of entries that remained unknown after inference
  ret.attrs[unknown_name] = std::make_shared<any>(num_unknown);
  return ret;
}
526
/*!\brief
 * This is a version of the InferAttr function specifically for shape inference.
 *
 * Compared to the generic InferAttr it additionally tracks entries whose
 * shape is dynamic (unknown ndim produced by an op with no shape-inference
 * function), converts shapes to the numpy-compatible representation when
 * numpy shape semantics are off, and counts unknown dimensions (via
 * fnum_unknown) rather than unknown entries.
 *
 * \param ret graph used for attribute inference
 * \param empty_val empty value of the attribute
 * \param infer_name name of the function used for attribute inference
 * \param input_name name of the attribute in the graph used to store the
 *                   input data for attribute inference
 * \param attr_key_name name of the attribute used for inference for variable nodes
 * \param attr_name name of the inferred attribute
 * \param unknown_name name of the attribute storing number of entries
 *                     impossible to infer
 * \param fis_none function returning true for not fully inferred values
 * \param fnum_unknown function returning how many elements are unknown in
 *                     partially inferred value of the attribute
 * \param fdefault default function used for inference if the node does not
 *                 provide its own implementation.
 * \param bwd_identity_assign whether the attributes of forward NDArray and backward
 *                            NDArray have to be the same. False only for storage
 *                            type inference
 * \param dispatch_mode_name name of the dispatch mode attribute on the node. Used for
 *                           storage type inference
 * \param default_mode_val default value of the dispatch mode attribute on the node. Used
 *                         for storage type inference
 */
template<typename IsNone, typename FDefault, typename FNumUnknown>
nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
                           const mxnet::TShape empty_val,
                           const char* infer_name,
                           const char* input_name,
                           const char* attr_key_name,
                           const char* attr_name,
                           const char* unknown_name,
                           IsNone fis_none,
                           FNumUnknown fnum_unknown,
                           FDefault fdefault,
                           bool bwd_identity_assign,
                           const char* dispatch_mode_name,
                           const DispatchMode default_mode_val = DispatchMode::kUndefined) {
  using nnvm::IndexedGraph;
  using nnvm::Op;
  using AttrType = mxnet::TShape;
  using FInferType = mxnet::FInferShape;
  using AttrVector = std::vector<AttrType>;
  using NodeAttrVector = std::vector<DispatchMode>;
  using dmlc::any;
  const IndexedGraph& idx = ret.indexed_graph();
  static auto& finfer_shape =
      Op::GetAttr<FInferType>(infer_name);
  static auto& is_backward =
      Op::GetAttr<nnvm::TIsBackward>("TIsBackward");
  // per-entry shape vector (one slot per node entry in the graph)
  AttrVector rshape;
  // vector holding information which operators
  // finished attribute inference
  std::vector<bool> inference_finished(idx.num_nodes(), false);
  // dispatch mode vector
  DispatchModeVector dispatch_modes;
  // Reuse a partially inferred shape vector if the graph already has one.
  if (ret.attrs.count(attr_name) != 0) {
    rshape = ret.MoveCopyAttr<AttrVector>(attr_name);
  } else {
    rshape.resize(idx.num_node_entries(), empty_val);
  }

  // Seed the graph inputs with the caller-provided shapes.
  if (ret.attrs.count(input_name) != 0) {
    const AttrVector& shape_args = ret.GetAttr<AttrVector>(input_name);
    CHECK_LE(shape_args.size(), idx.input_nodes().size())
        << "More provided " << attr_name << "s than number of arguments.";
    for (size_t i = 0; i < shape_args.size(); ++i) {
      rshape[idx.entry_id(idx.input_nodes()[i], 0)] = shape_args[i];
    }
  }

  // get the shape hints: optional per-entry seed values keyed by node entry
  std::string shape_hints_key = std::string(attr_name) + "_hints";
  if (ret.attrs.count(shape_hints_key)) {
    nnvm::NodeEntryMap<AttrType> shape_hints =
      ret.GetAttr<nnvm::NodeEntryMap<AttrType>>(shape_hints_key);
    for (const auto& kv : shape_hints) {
      nnvm::NodeEntry e = kv.first;
      // Hints may refer to nodes that are not part of this graph; skip those.
      if (idx.exist(e.node.get())) {
        rshape[idx.entry_id(kv.first)] = kv.second;
      }
    }
  }

  // key under which variable nodes may carry their shape in attrs.dict
  std::string shape_attr_key;
  if (ret.attrs.count(attr_key_name) != 0) {
    shape_attr_key = ret.GetAttr<std::string>(attr_key_name);
    // erase the provided arguments
    ret.attrs.erase(attr_key_name);
  }

  // limit inference to part of the graph
  uint32_t node_start = 0, node_end = idx.num_nodes();
  if (ret.attrs.count("node_range")) {
    const auto& range = ret.GetAttr<std::pair<uint32_t, uint32_t> >("node_range");
    node_start = range.first;
    node_end = range.second;
    CHECK_GE(node_start, 0);
    CHECK_LE(node_end, idx.num_nodes());
    ret.attrs.erase("node_range");
  }
  uint32_t entry_start = 0, entry_end = idx.num_node_entries();
  if (ret.attrs.count("entry_range")) {
    const auto& range = ret.GetAttr<std::pair<uint32_t, uint32_t> >("entry_range");
    entry_start = range.first;
    entry_end = range.second;
    CHECK_GE(entry_start, 0);
    CHECK_LE(entry_end, idx.num_node_entries());
    ret.attrs.erase("entry_range");
  }
  // populate the node attribute vector (dispatch modes must be pre-populated
  // by the caller when dispatch_mode_name is given)
  if (dispatch_mode_name != nullptr) {
    if (ret.attrs.count(dispatch_mode_name) != 0) {
      dispatch_modes = ret.MoveCopyAttr<NodeAttrVector>(dispatch_mode_name);
    } else {
      LOG(FATAL) << "Node attribute " << dispatch_mode_name << " does not exist in the graph";
    }
  }

  // Temp space for shape inference.
  std::vector<AttrType> ishape, oshape;
  // whether a shape is dynamic
  std::vector<int> is_dynamic(rshape.size(), 0);

  // convert to numpy compatible shape to use operator's infer shape function
  if (!Imperative::Get()->is_np_shape()) {
    common::ConvertToNumpyShape(&rshape);
  }

  // inference step function for nid
  auto infer_step = [&](uint32_t nid, bool last_iter) {
    if (inference_finished[nid]) return;
    const auto& inode = idx[nid];
    // NOTE(review): 'name' appears unused in this function — confirm before
    // removing, as it copies a string per visited node.
    const std::string name = inode.source->attrs.name;
    const uint32_t num_inputs = inode.inputs.size();
    const uint32_t num_outputs = inode.source->num_outputs();

    if (inode.source->is_variable()) {
      // Variable node. No operator. Only one output entry.
      CHECK(inode.source->op() == nullptr);
      CHECK_EQ(num_outputs, 1U);
      const uint32_t out_ent_id = idx.entry_id(nid, 0);
      // A variable may carry its shape as a string in attrs.dict.
      if (shape_attr_key.length() != 0 && fis_none(rshape[out_ent_id])) {
        auto it = inode.source->attrs.dict.find(shape_attr_key);
        if (it != inode.source->attrs.dict.end()) {
          std::istringstream is(it->second);
          CHECK(is >> rshape[out_ent_id]) << "Invalid attribute";
          if (!Imperative::Get()->is_np_shape()) {
            common::ConvertToNumpyShape(&rshape[out_ent_id]);
          }
        }
      }
      if (!fis_none(rshape[out_ent_id])) {
        inference_finished[nid] = true;
      }
      // assign a default value to node attribute
      if (dispatch_mode_name != nullptr) {
        op::dispatch_mode_assign(&dispatch_modes[nid], default_mode_val);
      }
    } else if (is_backward.get(inode.source->op(), false) &&
               inode.source->control_deps.size() && bwd_identity_assign) {
      // Backward node: copy shapes from the corresponding forward node.
      CHECK(dispatch_mode_name == nullptr)
          << "Backward inference for node attributes is not available";
      CHECK_GE(inode.source->control_deps.size(), 1U)
          << "BackwardOp need to have control_deps to its forward op";
      nnvm::ObjectPtr fwd_ptr = inode.source->control_deps[0];
      CHECK(fwd_ptr->op() != nullptr) << "Forward op cannot be a variable";

      static auto& is_fusion_helper = Op::GetAttr<exec::TIsFusionHelper>("TIsFusionHelper");
      if (!is_fusion_helper.get(fwd_ptr->op(), false)) {
        GetAttrFromForwardNode(nid, idx, &rshape, &inference_finished, fis_none);
      } else {
        GetAttrFromFusedNode<exec::FAccessSubgraphShape>(nid, idx, &rshape,
                                                         &inference_finished,
                                                         fis_none,
                                                         "FAccessSubgraphShape");
      }
    } else {
      DispatchMode* dispatch_mode = nullptr;
      // Forward operator inference.
      ishape.resize(num_inputs, empty_val);
      bool is_input_dynamic_shape = false;
      for (uint32_t i = 0; i < ishape.size(); ++i) {
        ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
        if (!mxnet::ndim_is_known(ishape[i]) && is_dynamic[idx.entry_id(inode.inputs[i])]) {
          is_input_dynamic_shape = true;
        }
      }
      oshape.resize(num_outputs, empty_val);
      for (uint32_t i = 0; i < oshape.size(); ++i) {
        oshape[i] = rshape[idx.entry_id(nid, i)];
      }
      if (dispatch_mode_name != nullptr) {
        dispatch_mode = &dispatch_modes[nid];
      }
      auto finfer = finfer_shape.get(inode.source->op(), fdefault);
      if (finfer == nullptr || is_input_dynamic_shape) {
        // No inference function or dynamic input: outputs with unknown ndim
        // are marked dynamic and this node is considered done.
        for (uint32_t i = 0; i < oshape.size(); ++i) {
          if (!mxnet::ndim_is_known(oshape[i].ndim())) {
            is_dynamic[idx.entry_id(nid, i)] = 1;
          }
        }
        inference_finished[nid] = true;
      } else {
        // Call inference function of the operator.
        try {
          static auto& is_fusion = Op::GetAttr<exec::TIsFusion>("TIsFusion");
          if (is_fusion.get(inode.source->op(), false)) {
            // Give the fusion op the current shapes of its subgraph first.
            ProvideAttrToFusion<exec::FProvideSubgraphShape>(nid, idx, rshape,
                                                             "FProvideSubgraphShape");
          }
          ApplyOpInferAttr(ret, finfer, inode.source->attrs,
                           nid, &ishape, &oshape, dispatch_mode);
          // The node is finished only when all in/out shapes are known.
          bool finished = true;
          for (const auto& attr : ishape) {
            if (fis_none(attr)) finished = false;
          }
          for (const auto& attr : oshape) {
            if (fis_none(attr)) finished = false;
          }
          inference_finished[nid] = finished;
        } catch (const std::exception& e) {
          throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what());
        }
      }
      // Save to the result map.
      for (uint32_t i = 0; i < num_inputs; ++i) {
        rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
      }
      for (uint32_t i = 0; i < num_outputs; ++i) {
        rshape[idx.entry_id(nid, i)] = oshape[i];
      }
    }
  };

  size_t last_num_unknown;
  size_t num_unknown = static_cast<size_t>(-1);  // Infinity
  bool last_iter = false;
  bool do_next_iteration = true;

  int i = 0;
  do {
    if (i % 2 == 0) {
      // forward inference
      for (uint32_t nid = node_start; nid < node_end; ++nid) {
        infer_step(nid, last_iter);
      }
    } else {
      // backward inference
      for (uint32_t i = node_end; i != node_start; --i) {
        infer_step(i - 1, last_iter);
      }
    }
    // Recount the unknown dimensions to detect progress.
    last_num_unknown = num_unknown;
    num_unknown = 0;
    for (size_t j = entry_start; j < entry_end; ++j) {
      if (fis_none(rshape[j])) {
        num_unknown += fnum_unknown(rshape[j]);
      }
    }
    if (dispatch_mode_name) {
      for (size_t i = node_start; i < node_end; i++) {
        if (dispatch_modes[i] == DispatchMode::kUndefined) {
          ++num_unknown;
        }
      }
    }
    // Keep iterating only while the unknown count strictly decreases.
    do_next_iteration = num_unknown > 0 && last_num_unknown > num_unknown;
    if (!do_next_iteration && !last_iter) {
      // Check if every op agrees that it should be
      // the end of attribute inference. If not,
      // perform one final step
      for (const bool done : inference_finished) {
        do_next_iteration = do_next_iteration || !done;
      }
      last_iter = true;
    }
    ++i;
  } while (do_next_iteration);
  // store the inferred shape vector on the graph
  ret.attrs[attr_name] = std::make_shared<any>(std::move(rshape));
  // store the dispatch modes (storage-type inference only)
  if (dispatch_mode_name) {
    ret.attrs[dispatch_mode_name] = std::make_shared<any>(std::move(dispatch_modes));
  }
  // number of dimensions that remained unknown after inference
  ret.attrs[unknown_name] = std::make_shared<any>(num_unknown);
  return ret;
}
818
InferShape(nnvm::Graph && graph,mxnet::ShapeVector && shape_inputs,const std::string & shape_attr_key)819 nnvm::Graph InferShape(nnvm::Graph&& graph,
820 mxnet::ShapeVector&& shape_inputs,
821 const std::string& shape_attr_key) {
822 using dmlc::any;
823 if (shape_inputs.size() != 0) {
824 graph.attrs["shape_inputs"] = std::make_shared<any>(std::move(shape_inputs));
825 }
826 if (shape_attr_key.length() != 0) {
827 graph.attrs["shape_attr_key"] = std::make_shared<any>(shape_attr_key);
828 }
829 return InferShapeAttr(
830 std::move(graph), mxnet::TShape(),
831 "FInferShape", "shape_inputs", "shape_attr_key",
832 "shape", "shape_num_unknown_nodes",
833 [](const mxnet::TShape& s) { return !mxnet::shape_is_known(s); },
834 [](const mxnet::TShape& s) {
835 if (!mxnet::ndim_is_known(s)) {
836 return static_cast<size_t>(1);
837 }
838 size_t ret = 0;
839 for (const auto& val : s) {
840 if (!mxnet::dim_size_is_known(val)) {
841 ++ret;
842 }
843 }
844 return ret;
845 },
846 nullptr, true, nullptr);
847 }
848
InferType(nnvm::Graph && graph,nnvm::DTypeVector && dtype_inputs,const std::string & dtype_attr_key)849 nnvm::Graph InferType(nnvm::Graph&& graph,
850 nnvm::DTypeVector&& dtype_inputs,
851 const std::string& dtype_attr_key) {
852 using dmlc::any;
853 if (dtype_inputs.size() != 0) {
854 graph.attrs["dtype_inputs"] = std::make_shared<any>(std::move(dtype_inputs));
855 }
856 if (dtype_attr_key.length() != 0) {
857 graph.attrs["dtype_attr_key"] = std::make_shared<any>(dtype_attr_key);
858 }
859 return InferAttr<int, nnvm::FInferType, exec::FAccessSubgraphType,
860 exec::FProvideSubgraphType>(
861 std::move(graph), -1,
862 "FInferType", "FAccessSubgraphType", "FProvideSubgraphType",
863 "dtype_inputs", "dtype_attr_key", "dtype", "dtype_num_unknown_nodes",
864 [](const int t) { return t == -1; },
865 common::SameType, true, nullptr);
866 }
867
InferStorageType(nnvm::Graph && graph,StorageTypeVector && storage_type_inputs,const std::string & storage_type_attr_key)868 nnvm::Graph InferStorageType(nnvm::Graph&& graph,
869 StorageTypeVector&& storage_type_inputs,
870 const std::string& storage_type_attr_key) {
871 using dmlc::any;
872 if (storage_type_inputs.size() != 0) {
873 graph.attrs["storage_type_inputs"] = std::make_shared<any>(std::move(storage_type_inputs));
874 }
875 if (storage_type_attr_key.length() != 0) {
876 graph.attrs["storage_type_attr_key"] = std::make_shared<any>(storage_type_attr_key);
877 }
878 // initialize unknown values for dispatch modes
879 if (graph.attrs.count("dispatch_mode") == 0) {
880 DispatchModeVector dispatch_modes(graph.indexed_graph().num_nodes(), DispatchMode::kUndefined);
881 graph.attrs["dispatch_mode"] = std::make_shared<any>(std::move(dispatch_modes));
882 }
883 // initialize the dev_mask vector from the context vector
884 if (graph.attrs.count("dev_mask") == 0) {
885 CHECK_GT(graph.attrs.count("context"), 0);
886 DevMaskVector dev_masks(graph.indexed_graph().num_nodes());
887 const ContextVector& vctx = graph.GetAttr<ContextVector>("context");
888 for (size_t i = 0; i < vctx.size(); i++) dev_masks[i] = vctx[i].dev_mask();
889 graph.attrs["dev_mask"] = std::make_shared<any>(std::move(dev_masks));
890 }
891
892 // for storage type, the backward attr is not necessarily the same as it's correspondence
893 nnvm::Graph ret = InferAttr<int, FInferStorageType, exec::FAccessSubgraphStorageType,
894 exec::FProvideSubgraphStorageType>(
895 std::move(graph), -1,
896 "FInferStorageType", "FAccessSubgraphStorageType", "FProvideSubgraphStorageType",
897 "storage_type_inputs", "storage_type_attr_key", "storage_type",
898 "storage_type_num_unknown_nodes",
899 [](const int t) { return t == -1; },
900 common::DefaultStorageType, false, "dispatch_mode", DispatchMode::kVariable);
901
902 // log the storage types and dispatch modes of the graph
903 static bool log_verbose = dmlc::GetEnv("MXNET_INFER_STORAGE_TYPE_VERBOSE_LOGGING", false);
904 if (log_verbose) {
905 common::LogInferStorage(ret);
906 }
907 return ret;
908 }
909
910 } // namespace exec
911 } // namespace mxnet
912