1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file infer_graph_attr_pass.cc
22  * \brief infer graph shape, dtype, and storage type
23  */
24 
25 #include <mxnet/op_attr_types.h>
26 #include <mxnet/graph_attr_types.h>
27 #include <mxnet/imperative.h>
28 #include "./exec_pass.h"
29 #include "../operator/operator_common.h"
30 #include "../common/exec_utils.h"
31 
32 namespace mxnet {
33 namespace exec {
34 
35 template<typename AttrType, typename FInfer>
ApplyOpInferAttr(const nnvm::Graph & g,const FInfer & finfer,const NodeAttrs & attrs,const uint32_t nid,std::vector<AttrType> * in_attrs,std::vector<AttrType> * out_attrs,DispatchMode * dispatch_mode)36 bool ApplyOpInferAttr(const nnvm::Graph& g,
37                       const FInfer& finfer,
38                       const NodeAttrs& attrs,
39                       const uint32_t nid,
40                       std::vector<AttrType>* in_attrs,
41                       std::vector<AttrType>* out_attrs,
42                       DispatchMode* dispatch_mode) {
43   return finfer(attrs, in_attrs, out_attrs);
44 }
45 
46 template<>
ApplyOpInferAttr(const nnvm::Graph & g,const FInferStorageType & finfer,const NodeAttrs & attrs,const uint32_t nid,std::vector<int> * in_attrs,std::vector<int> * out_attrs,DispatchMode * dispatch_mode)47 bool ApplyOpInferAttr<int, FInferStorageType>(const nnvm::Graph& g,
48                                               const FInferStorageType& finfer,
49                                               const NodeAttrs& attrs,
50                                               const uint32_t nid,
51                                               std::vector<int>* in_attrs,
52                                               std::vector<int>* out_attrs,
53                                               DispatchMode* dispatch_mode) {
54   const DevMaskVector& dev_masks = g.GetAttr<DevMaskVector>("dev_mask");
55   const bool success = finfer(attrs, dev_masks[nid], dispatch_mode, in_attrs, out_attrs);
56   if (!success) {
57     LOG(FATAL) << "Operator not implemented: "
58                << common::operator_stype_string(attrs, dev_masks[nid], *in_attrs, *out_attrs);
59   }
60   if (*dispatch_mode == DispatchMode::kFComputeFallback) {
61     common::LogStorageFallback(attrs, dev_masks[nid], in_attrs, out_attrs);
62   }
63   return true;
64 }
65 
/*!
 * \brief Infer the attributes of a backward node's entries from its
 *        corresponding forward node.
 *
 * Uses the forward op's registered FGradient function to rebuild the
 * input-gradient list and thereby establish which backward output entry
 * corresponds to which forward input entry. Marks the node as finished
 * in *inference_finished only when every relevant entry is known.
 *
 * \param nid id of the backward node in the indexed graph
 * \param idx the indexed graph
 * \param rshape_ptr per-entry attribute vector, updated in place
 * \param inference_finished per-node completion flags, updated in place
 * \param fis_none predicate returning true for a not-yet-inferred value
 */
template<typename AttrType, typename IsNone>
inline void GetAttrFromForwardNode(const uint32_t nid,
                                   const nnvm::IndexedGraph &idx,
                                   std::vector<AttrType>* rshape_ptr,
                                   std::vector<bool>* inference_finished,
                                   IsNone fis_none) {
  std::vector<AttrType>& rshape = *rshape_ptr;
  const nnvm::IndexedGraph::Node& inode = idx[nid];
  // gradient function, used to get node correspondence.
  static auto& fgrad =
      Op::GetAttr<nnvm::FGradient>("FGradient");
  // The forward node is the first control dependency of a backward node.
  nnvm::ObjectPtr fwd_ptr = inode.source->control_deps[0];
  const nnvm::IndexedGraph::Node& fnode = idx[inode.control_deps[0]];
  // use gradient function to find out the correspondence.
  // Only the entry indices matter for the matching below.
  std::vector<nnvm::NodeEntry> ograd(fwd_ptr->num_outputs());
  for (size_t i = 0; i < ograd.size(); ++i) {
    ograd[i].index = static_cast<uint32_t>(i);
  }
  // input gradient list
  const std::vector<nnvm::NodeEntry>& igrad = fgrad[fwd_ptr->op()](fwd_ptr, ograd);
  const nnvm::Node* igrad_node = nullptr;
  bool all_attrs_known = true;
  // Input gradient assignment: the i-th input gradient takes the
  // attribute of the i-th forward input entry.
  for (size_t i = 0; i < igrad.size(); ++i) {
    if (igrad[i].node->op() == inode.source->op()) {
      uint32_t eid = idx.entry_id(nid, igrad[i].index);
      if (fis_none(rshape[idx.entry_id(fnode.inputs[i])])) {
        // Need to skip empty forward shape, because it may not be
        // available now and it is possible to infer the forward
        // shape in one of the next a few passes
        all_attrs_known = false;
      } else {
        if (fis_none(rshape[eid])) {
          rshape[eid] = rshape[idx.entry_id(fnode.inputs[i])];
        } else {
          CHECK_EQ(rshape[eid], rshape[idx.entry_id(fnode.inputs[i])])
              << "Backward shape inconsistent with the forward shape";
        }
      }
      // Every matching gradient entry must come from the same node.
      if (igrad_node == nullptr) {
        igrad_node = igrad[i].node.get();
      } else {
        CHECK(igrad_node == igrad[i].node.get());
      }
    }
  }
  // out grad entries
  CHECK(igrad_node != nullptr)
    << "Cannot find matching backward op for " << inode.source->attrs.name;
  for (size_t i = 0; i < igrad_node->inputs.size(); ++i) {
    const nnvm::NodeEntry& e = igrad_node->inputs[i];
    if (e.node == nullptr) {
      // A null node marks an output-gradient placeholder: copy the
      // attribute from the forward node's corresponding output entry.
      uint32_t eid = idx.entry_id(inode.inputs[i]);
      if (fis_none(rshape[eid])) {
        rshape[eid] = rshape[idx.entry_id(inode.control_deps[0], e.index)];
      }
      if (fis_none(rshape[eid])) {
        // If the attr is still unknown
        all_attrs_known = false;
      }
    }
  }
  (*inference_finished)[nid] = all_attrs_known;
}
130 
/*!
 * \brief Infer the attributes of a backward node whose forward node is a
 *        fusion helper.
 *
 * The fused forward node exposes its subgraph's inferred input/output
 * attributes through a registered accessor function (infer_fusion_name,
 * e.g. FAccessSubgraphShape). The FGradient function of the real forward
 * op is then used — exactly as in GetAttrFromForwardNode — to map those
 * attributes onto this backward node's entries.
 *
 * \param nid id of the backward node in the indexed graph
 * \param idx the indexed graph
 * \param rshape_ptr per-entry attribute vector, updated in place
 * \param inference_finished per-node completion flags, updated in place
 * \param fis_none predicate returning true for a not-yet-inferred value
 * \param infer_fusion_name name of the subgraph attribute accessor
 */
template<typename FAccessSubgraphType, typename AttrType, typename IsNone>
void GetAttrFromFusedNode(uint32_t nid,
                          const nnvm::IndexedGraph& idx,
                          std::vector<AttrType>* rshape_ptr,
                          std::vector<bool>* inference_finished,
                          IsNone fis_none,
                          const std::string& infer_fusion_name) {
  std::vector<AttrType>& rshape = *rshape_ptr;
  const auto& inode = idx[nid];
  // gradient function, used to get node correspondence.
  static auto& fgrad =
      Op::GetAttr<nnvm::FGradient>("FGradient");
  // The fused forward node is the first control dependency.
  nnvm::ObjectPtr fused_fwd_ptr = inode.source->control_deps[0];
  static auto& finfer_fused_shape =
    Op::GetAttr<FAccessSubgraphType>(infer_fusion_name);
  auto finfer = finfer_fused_shape.get(fused_fwd_ptr->op(), nullptr);
  CHECK(finfer != nullptr) << "Operator " << fused_fwd_ptr->attrs.name <<
    " is marked as Fusion but does not allow accessing attributes";
  // The accessor yields (real forward node, input attrs, output attrs).
  const auto& inferred_attrs = finfer(fused_fwd_ptr->attrs);
  const auto& fwd_ptr = std::get<0>(inferred_attrs);
  const auto& input_attrs = std::get<1>(inferred_attrs);
  const auto& output_attrs = std::get<2>(inferred_attrs);

  // use gradient function to find out the correspondence.
  // Only the entry indices matter for the matching below.
  std::vector<nnvm::NodeEntry> ograd(fwd_ptr->num_outputs());
  for (size_t i = 0; i < ograd.size(); ++i) {
    ograd[i].index = static_cast<uint32_t>(i);
  }
  // input gradient list
  const std::vector<nnvm::NodeEntry>& igrad = fgrad[fwd_ptr->op()](fwd_ptr, ograd);
  const nnvm::Node* igrad_node = nullptr;
  bool all_attrs_known = true;
  // Set the attributes of output gradients
  // using attributes of forward node inputs
  for (size_t i = 0; i < igrad.size(); ++i) {
    if (igrad[i].node->op() == inode.source->op()) {
      uint32_t eid = idx.entry_id(nid, igrad[i].index);
      if (fis_none(input_attrs[i])) {
        // Need to skip empty forward shape, because it may not be
        // available now and it is possible to infer the forward
        // shape in one of the next a few passes
        all_attrs_known = false;
      } else {
        if (fis_none(rshape[eid])) {
          rshape[eid] = input_attrs[i];
        } else {
          CHECK_EQ(rshape[eid], input_attrs[i])
              << "Backward shape inconsistent with the forward shape";
        }
      }
      // Every matching gradient entry must come from the same node.
      if (igrad_node == nullptr) {
        igrad_node = igrad[i].node.get();
      } else {
        CHECK(igrad_node == igrad[i].node.get());
      }
    }
  }

  // Set the attributes of input gradients
  // using attributes of forward node outputs
  CHECK(igrad_node != nullptr)
    << "Cannot find matching backward op for " << inode.source->attrs.name;
  for (size_t i = 0; i < igrad_node->inputs.size(); ++i) {
    const nnvm::NodeEntry& e = igrad_node->inputs[i];
    if (e.node == nullptr) {
      // A null node marks an output-gradient placeholder: take the
      // attribute from the fused subgraph's output attributes.
      uint32_t eid = idx.entry_id(inode.inputs[i]);
      if (fis_none(rshape[eid])) {
        rshape[eid] = output_attrs[e.index];
      }
      if (fis_none(rshape[eid])) {
        // If the attr is still unknown
        all_attrs_known = false;
      }
    }
  }
  (*inference_finished)[nid] = all_attrs_known;
}
208 
209 template <typename FProvideSubgraphType, typename AttrType>
ProvideAttrToFusion(const uint32_t nid,const nnvm::IndexedGraph & idx,const std::vector<AttrType> & rshape,const std::string & provide_fusion_name)210 void ProvideAttrToFusion(const uint32_t nid,
211                          const nnvm::IndexedGraph& idx,
212                          const std::vector<AttrType>& rshape,
213                          const std::string& provide_fusion_name) {
214   const auto& inode = idx[nid];
215   std::vector<std::vector<AttrType>> in_attrs;
216   std::vector<std::vector<AttrType>> out_attrs;
217   for (const auto& dep_node : inode.source->control_deps) {
218     in_attrs.push_back({});
219     out_attrs.push_back({});
220     auto &current_in_attrs = in_attrs.back();
221     auto &current_out_attrs = out_attrs.back();
222     uint32_t dep_node_id = idx.node_id(dep_node.get());
223     for (const auto& e : idx[dep_node_id].inputs) {
224       current_in_attrs.push_back(rshape[idx.entry_id(e)]);
225     }
226     for (size_t i = 0; i < dep_node->num_outputs(); ++i) {
227       current_out_attrs.push_back(rshape[idx.entry_id(dep_node_id, i)]);
228     }
229   }
230   auto provide =
231     Op::GetAttr<FProvideSubgraphType>(provide_fusion_name).get(inode.source->op(), nullptr);
232   CHECK(provide != nullptr) <<
233     "Encountered Fusion operator that does not implement providing subgraph attr " <<
234     provide_fusion_name << ".";
235   provide(inode.source->attrs, inode.source->control_deps, in_attrs, out_attrs);
236 }
237 
238 /*!\brief
239  * This is a duplicate of the InferAttr function in nnvm with minor modification
240  * to support inferring storage type whose function signature is different from
241  * shape/type inference functions'. The nnvm InferAttr will be deprecated
242  * in the future. Please use interfaces InferShape, InferType, and InferStorageType
243  * to call this function.
244  *
245  * \param ret graph used for attribute inference
 * \param empty_val empty value of the attribute
247  * \param infer_name name of the function used for attribute inference
248  * \param infer_fusion_name name of the function used for accessing attributes in fused nodes
249  * \param input_name name of the attribute in the graph used to store the
250  *                   input data for attribute inference
251  * \param attr_key_name name of the attribute used for inference for variable nodes
252  * \param attr_name name of the inferred attribute
253  * \param unknown_name name of the attribute storing number of entries
254  *                     impossible to infer
255  * \param fis_none function returning true for not fully inferred values
256  * \param fdefault default function used for inference if the node does not
257  *                 provide its own implementation.
258  * \param bwd_identity_assign whether the attributes of forward NDArray and backward
259  *                            NDArray have to be the same. False only for storage
260  *                            type inference
261  * \param dispatch_mode_name name of the dispatch mode attribute on the node. Used for
262  *                           storage type inference
263  * \param default_mode_val default value of the dispatch mode attribute on the node. Used
264  *                         for storage type inference
265  */
template<typename AttrType, typename FInferType, typename FAccessSubgraphType,
         typename FProvideSubgraphType, typename IsNone, typename FDefault>
nnvm::Graph InferAttr(nnvm::Graph &&ret,
                      const AttrType empty_val,
                      const char* infer_name,
                      const char* infer_fusion_name,
                      const char* provide_fusion_name,
                      const char* input_name,
                      const char* attr_key_name,
                      const char* attr_name,
                      const char* unknown_name,
                      IsNone fis_none,
                      FDefault fdefault,
                      bool bwd_identity_assign,
                      const char* dispatch_mode_name,
                      const DispatchMode default_mode_val = DispatchMode::kUndefined) {
  using nnvm::IndexedGraph;
  using nnvm::Op;
  using AttrVector = std::vector<AttrType>;
  using NodeAttrVector = std::vector<DispatchMode>;
  using dmlc::any;

  const IndexedGraph& idx = ret.indexed_graph();
  static auto& finfer_shape =
      Op::GetAttr<FInferType>(infer_name);
  static auto& is_backward =
      Op::GetAttr<nnvm::TIsBackward>("TIsBackward");
  // reshape shape vector
  AttrVector rshape;
  // vector holding information which operators
  // finished attribute inference
  std::vector<bool> inference_finished(idx.num_nodes(), false);
  // dispatch mode vector
  DispatchModeVector dispatch_modes;
  // Reuse a previously stored attribute vector if present, otherwise
  // start from all-empty values.
  if (ret.attrs.count(attr_name) != 0) {
    rshape = ret.MoveCopyAttr<AttrVector>(attr_name);
  } else {
    rshape.resize(idx.num_node_entries(), empty_val);
  }

  // Seed the input nodes' attributes from the provided argument values.
  if (ret.attrs.count(input_name) != 0) {
    const AttrVector& shape_args = ret.GetAttr<AttrVector>(input_name);
    CHECK_LE(shape_args.size(), idx.input_nodes().size())
        << "More provided " << attr_name << "s than number of arguments.";
    for (size_t i = 0; i < shape_args.size(); ++i) {
      rshape[idx.entry_id(idx.input_nodes()[i], 0)] = shape_args[i];
    }
  }

  // get the shape hints
  std::string shape_hints_key = std::string(attr_name) + "_hints";
  if (ret.attrs.count(shape_hints_key)) {
    nnvm::NodeEntryMap<AttrType> shape_hints =
      ret.GetAttr<nnvm::NodeEntryMap<AttrType>>(shape_hints_key);
    for (const auto& kv : shape_hints) {
      nnvm::NodeEntry e = kv.first;
      // Ignore hints for nodes not present in this graph.
      if (idx.exist(e.node.get())) {
        rshape[idx.entry_id(kv.first)] = kv.second;
      }
    }
  }

  // Key under which variable nodes may carry an attribute value in
  // their attrs.dict (e.g. "__shape__").
  std::string shape_attr_key;
  if (ret.attrs.count(attr_key_name) != 0) {
    shape_attr_key = ret.GetAttr<std::string>(attr_key_name);
    // erase the provided arguments
    ret.attrs.erase(attr_key_name);
  }

  // limit inference to part of the graph
  uint32_t node_start = 0, node_end = idx.num_nodes();
  if (ret.attrs.count("node_range")) {
    const auto& range = ret.GetAttr<std::pair<uint32_t, uint32_t> >("node_range");
    node_start = range.first;
    node_end = range.second;
    CHECK_GE(node_start, 0);
    CHECK_LE(node_end, idx.num_nodes());
    ret.attrs.erase("node_range");
  }
  uint32_t entry_start = 0, entry_end = idx.num_node_entries();
  if (ret.attrs.count("entry_range")) {
    const auto& range = ret.GetAttr<std::pair<uint32_t, uint32_t> >("entry_range");
    entry_start = range.first;
    entry_end = range.second;
    CHECK_GE(entry_start, 0);
    CHECK_LE(entry_end, idx.num_node_entries());
    ret.attrs.erase("entry_range");
  }
  // populate the node attribute vector
  // (only used for storage-type inference, where dispatch_mode_name
  // is non-null; the vector must already exist on the graph)
  if (dispatch_mode_name != nullptr) {
    if (ret.attrs.count(dispatch_mode_name) != 0) {
      dispatch_modes = ret.MoveCopyAttr<NodeAttrVector>(dispatch_mode_name);
    } else {
      LOG(FATAL) << "Node attribute " << dispatch_mode_name << " does not exist in the graph";
    }
  }

  // Temp space for shape inference.
  std::vector<AttrType> ishape, oshape;

  // inference step function for nid
  auto infer_step = [&](uint32_t nid, bool last_iter) {
    if (inference_finished[nid]) return;
    const auto& inode = idx[nid];
    const uint32_t num_inputs = inode.inputs.size();
    const uint32_t num_outputs = inode.source->num_outputs();
    if (inode.source->is_variable()) {
      // Variable node. No operator. Only one output entry.
      CHECK(inode.source->op() == nullptr);
      CHECK_EQ(num_outputs, 1U);
      const uint32_t out_ent_id = idx.entry_id(nid, 0);
      // Try to parse the attribute from the variable's attrs.dict
      // (e.g. a user-annotated shape) if it is still unknown.
      if (shape_attr_key.length() != 0 && fis_none(rshape[out_ent_id])) {
        auto it = inode.source->attrs.dict.find(shape_attr_key);
        if (it != inode.source->attrs.dict.end()) {
          std::istringstream is(it->second);
          CHECK(is >> rshape[out_ent_id]) << "Invalid attribute";
        }
      }
      if (!fis_none(rshape[out_ent_id])) {
        inference_finished[nid] = true;
      }
      // assign a default value to node attribute
      if (dispatch_mode_name != nullptr) {
        op::dispatch_mode_assign(&dispatch_modes[nid], default_mode_val);
      }
    } else if (is_backward.get(inode.source->op(), false) &&
               inode.source->control_deps.size() && bwd_identity_assign) {
      // Backward node whose attributes mirror its forward node's.
      CHECK(dispatch_mode_name == nullptr)
        << "Backward inference for node attributes is not available";
      CHECK_GE(inode.source->control_deps.size(), 1U)
        << "BackwardOp need to have control_deps to its forward op";
      nnvm::ObjectPtr fwd_ptr = inode.source->control_deps[0];
      CHECK(fwd_ptr->op() != nullptr) << "Forward op cannot be a variable";

      static auto& is_fusion_helper = Op::GetAttr<exec::TIsFusionHelper>("TIsFusionHelper");
      if (!is_fusion_helper.get(fwd_ptr->op(), false)) {
        GetAttrFromForwardNode(nid, idx, &rshape, &inference_finished, fis_none);
      } else {
        GetAttrFromFusedNode<FAccessSubgraphType>(nid, idx, &rshape, &inference_finished,
                                                  fis_none, infer_fusion_name);
      }
    } else {
      DispatchMode* dispatch_mode = nullptr;
      // Forward operator inference.
      ishape.resize(num_inputs, empty_val);
      for (uint32_t i = 0; i < ishape.size(); ++i) {
        ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
      }
      oshape.resize(num_outputs, empty_val);
      for (uint32_t i = 0; i < oshape.size(); ++i) {
        oshape[i] = rshape[idx.entry_id(nid, i)];
      }
      if (dispatch_mode_name != nullptr) {
        dispatch_mode = &dispatch_modes[nid];
      }
      auto finfer = finfer_shape.get(inode.source->op(), fdefault);
      if (finfer != nullptr) {
        // Call inference function of the operator.
        try {
          static auto& is_fusion = Op::GetAttr<exec::TIsFusion>("TIsFusion");
          if (is_fusion.get(inode.source->op(), false)) {
            // Fusion op: first publish the subgraph attributes to it.
            ProvideAttrToFusion<FProvideSubgraphType>(nid, idx, rshape, provide_fusion_name);
          }
          ApplyOpInferAttr(ret, finfer, inode.source->attrs,
                           nid, &ishape, &oshape, dispatch_mode);
          // The node is done only when every input and output attribute
          // came out fully known.
          bool finished = true;
          for (const auto& attr : ishape) {
            if (fis_none(attr)) finished = false;
          }
          for (const auto& attr : oshape) {
            if (fis_none(attr)) finished = false;
          }
          inference_finished[nid] = finished;
        } catch (const std::exception& e) {
          throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what());
        }
      } else {
        // Operator does not provide attribute inference function,
        // so we need to test if everything was inferred by other operators
        bool all_attrs_known = true;
        for (const auto& attr : ishape) {
          if (fis_none(attr)) {
            all_attrs_known = false;
          }
        }
        for (const auto& attr : oshape) {
          if (fis_none(attr)) {
            all_attrs_known = false;
          }
        }
        inference_finished[nid] = all_attrs_known;
        if (!all_attrs_known) {
          CHECK(!last_iter)
              << "Attribute " << infer_name
              << " is not registered by op " << inode.source->op()->name
              << ". We are not able to complete the inference because of this";
        }
      }
      // Save to the result map.
      for (uint32_t i = 0; i < num_inputs; ++i) {
        rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
      }
      for (uint32_t i = 0; i < num_outputs; ++i) {
        rshape[idx.entry_id(nid, i)] = oshape[i];
      }
    }
  };

  size_t last_num_unknown;
  size_t num_unknown_dispatch_mode = dispatch_mode_name ? node_end - node_start : 0;
  size_t num_unknown_entry_attr = entry_end - entry_start;
  size_t num_unknown = num_unknown_entry_attr + num_unknown_dispatch_mode;
  bool last_iter = false;
  bool do_next_iteration = true;
  // Fixed-point iteration: alternate forward and backward sweeps until
  // the number of unknowns stops decreasing.
  int i = 0;
  do {
    if (i % 2 == 0) {
      for (uint32_t nid = node_start; nid < node_end; ++nid) {
        infer_step(nid, last_iter);
      }
    } else {
      // backward inference
      // NOTE: this loop counter shadows the outer iteration counter `i`.
      for (uint32_t i = node_end; i != node_start; --i) {
        infer_step(i - 1, last_iter);
      }
    }
    // Recount how many entries (and dispatch modes) are still unknown.
    last_num_unknown = num_unknown;
    num_unknown = 0;
    for (size_t j = entry_start; j < entry_end; ++j) {
      if (fis_none(rshape[j])) {
        ++num_unknown;
      }
    }
    if (dispatch_mode_name) {
      for (size_t i = node_start; i < node_end; i++) {
        if (dispatch_modes[i] == DispatchMode::kUndefined) ++num_unknown;
      }
    }
    // Keep iterating only while we are still making progress.
    do_next_iteration = num_unknown > 0 && last_num_unknown > num_unknown;
    if (!do_next_iteration && !last_iter) {
      // Check if every op agrees that it should be
      // the end of attribute inference. If not,
      // perform one final step
      for (const bool done : inference_finished) {
        do_next_iteration = do_next_iteration || !done;
      }
      last_iter = true;
    }
    ++i;
  } while (do_next_iteration);
  // set the shapes
  ret.attrs[attr_name] = std::make_shared<any>(std::move(rshape));
  // set the dispatch modes (storage-type inference only)
  if (dispatch_mode_name) {
    ret.attrs[dispatch_mode_name] = std::make_shared<any>(std::move(dispatch_modes));
  }
  // number of entries that could not be inferred
  ret.attrs[unknown_name] = std::make_shared<any>(num_unknown);
  return ret;
}
526 
527 /*!\brief
528  * This is a version of the InferAttr function specifically for shape inference.
529  *
530  * \param ret graph used for attribute inference
 * \param empty_val empty value of the attribute
532  * \param infer_name name of the function used for attribute inference
533  * \param input_name name of the attribute in the graph used to store the
534  *                   input data for attribute inference
535  * \param attr_key_name name of the attribute used for inference for variable nodes
536  * \param attr_name name of the inferred attribute
537  * \param unknown_name name of the attribute storing number of entries
538  *                     impossible to infer
539  * \param fis_none function returning true for not fully inferred values
540  * \param fnum_unknown function returning how many elements are unknown in
541  *                     partially inferred value of the attribute
542  * \param fdefault default function used for inference if the node does not
543  *                 provide its own implementation.
544  * \param bwd_identity_assign whether the attributes of forward NDArray and backward
545  *                            NDArray have to be the same. False only for storage
546  *                            type inference
547  * \param dispatch_mode_name name of the dispatch mode attribute on the node. Used for
548  *                           storage type inference
549  * \param default_mode_val default value of the dispatch mode attribute on the node. Used
550  *                         for storage type inference
551  */
552 template<typename IsNone, typename FDefault, typename FNumUnknown>
InferShapeAttr(nnvm::Graph && ret,const mxnet::TShape empty_val,const char * infer_name,const char * input_name,const char * attr_key_name,const char * attr_name,const char * unknown_name,IsNone fis_none,FNumUnknown fnum_unknown,FDefault fdefault,bool bwd_identity_assign,const char * dispatch_mode_name,const DispatchMode default_mode_val=DispatchMode::kUndefined)553 nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
554                            const mxnet::TShape empty_val,
555                            const char* infer_name,
556                            const char* input_name,
557                            const char* attr_key_name,
558                            const char* attr_name,
559                            const char* unknown_name,
560                            IsNone fis_none,
561                            FNumUnknown fnum_unknown,
562                            FDefault fdefault,
563                            bool bwd_identity_assign,
564                            const char* dispatch_mode_name,
565                            const DispatchMode default_mode_val = DispatchMode::kUndefined) {
566   using nnvm::IndexedGraph;
567   using nnvm::Op;
568   using AttrType = mxnet::TShape;
569   using FInferType = mxnet::FInferShape;
570   using AttrVector = std::vector<AttrType>;
571   using NodeAttrVector = std::vector<DispatchMode>;
572   using dmlc::any;
573   const IndexedGraph& idx = ret.indexed_graph();
574   static auto& finfer_shape =
575       Op::GetAttr<FInferType>(infer_name);
576   static auto& is_backward =
577       Op::GetAttr<nnvm::TIsBackward>("TIsBackward");
578   // reshape shape vector
579   AttrVector rshape;
580   // vector holding information which operators
581   // finished attribute inference
582   std::vector<bool> inference_finished(idx.num_nodes(), false);
583   // dispatch mode vector
584   DispatchModeVector dispatch_modes;
585   if (ret.attrs.count(attr_name) != 0) {
586     rshape = ret.MoveCopyAttr<AttrVector>(attr_name);
587   } else {
588     rshape.resize(idx.num_node_entries(), empty_val);
589   }
590 
591   if (ret.attrs.count(input_name) != 0) {
592     const AttrVector& shape_args = ret.GetAttr<AttrVector>(input_name);
593     CHECK_LE(shape_args.size(), idx.input_nodes().size())
594         << "More provided " << attr_name << "s than number of arguments.";
595     for (size_t i = 0; i < shape_args.size(); ++i) {
596       rshape[idx.entry_id(idx.input_nodes()[i], 0)] = shape_args[i];
597     }
598   }
599 
600   // get the shape hints
601   std::string shape_hints_key = std::string(attr_name) + "_hints";
602   if (ret.attrs.count(shape_hints_key)) {
603     nnvm::NodeEntryMap<AttrType> shape_hints =
604       ret.GetAttr<nnvm::NodeEntryMap<AttrType>>(shape_hints_key);
605     for (const auto& kv : shape_hints) {
606       nnvm::NodeEntry e = kv.first;
607       if (idx.exist(e.node.get())) {
608         rshape[idx.entry_id(kv.first)] = kv.second;
609       }
610     }
611   }
612 
613   std::string shape_attr_key;
614   if (ret.attrs.count(attr_key_name) != 0) {
615     shape_attr_key = ret.GetAttr<std::string>(attr_key_name);
616     // erase the provided arguments
617     ret.attrs.erase(attr_key_name);
618   }
619 
620   // limit inference to part of the graph
621   uint32_t node_start = 0, node_end = idx.num_nodes();
622   if (ret.attrs.count("node_range")) {
623     const auto& range = ret.GetAttr<std::pair<uint32_t, uint32_t> >("node_range");
624     node_start = range.first;
625     node_end = range.second;
626     CHECK_GE(node_start, 0);
627     CHECK_LE(node_end, idx.num_nodes());
628     ret.attrs.erase("node_range");
629   }
630   uint32_t entry_start = 0, entry_end = idx.num_node_entries();
631   if (ret.attrs.count("entry_range")) {
632     const auto& range = ret.GetAttr<std::pair<uint32_t, uint32_t> >("entry_range");
633     entry_start = range.first;
634     entry_end = range.second;
635     CHECK_GE(entry_start, 0);
636     CHECK_LE(entry_end, idx.num_node_entries());
637     ret.attrs.erase("entry_range");
638   }
639   // populate the node attribute vector
640   if (dispatch_mode_name != nullptr) {
641     if (ret.attrs.count(dispatch_mode_name) != 0) {
642       dispatch_modes = ret.MoveCopyAttr<NodeAttrVector>(dispatch_mode_name);
643     } else {
644       LOG(FATAL) << "Node attribute " << dispatch_mode_name << " does not exist in the graph";
645     }
646   }
647 
648   // Temp space for shape inference.
649   std::vector<AttrType> ishape, oshape;
650   // whether a shape is dynamic
651   std::vector<int> is_dynamic(rshape.size(), 0);
652 
653   // convert to numpy compatible shape to use operator's infer shape function
654   if (!Imperative::Get()->is_np_shape()) {
655     common::ConvertToNumpyShape(&rshape);
656   }
657 
658   // inference step function for nid
659   auto infer_step = [&](uint32_t nid, bool last_iter) {
660     if (inference_finished[nid]) return;
661     const auto& inode = idx[nid];
662     const std::string name = inode.source->attrs.name;
663     const uint32_t num_inputs = inode.inputs.size();
664     const uint32_t num_outputs = inode.source->num_outputs();
665 
666     if (inode.source->is_variable()) {
667       // Variable node. No operator. Only one output entry.
668       CHECK(inode.source->op() == nullptr);
669       CHECK_EQ(num_outputs, 1U);
670       const uint32_t out_ent_id = idx.entry_id(nid, 0);
671       if (shape_attr_key.length() != 0 && fis_none(rshape[out_ent_id])) {
672         auto it = inode.source->attrs.dict.find(shape_attr_key);
673         if (it != inode.source->attrs.dict.end()) {
674           std::istringstream is(it->second);
675           CHECK(is >> rshape[out_ent_id]) << "Invalid attribute";
676           if (!Imperative::Get()->is_np_shape()) {
677             common::ConvertToNumpyShape(&rshape[out_ent_id]);
678           }
679         }
680       }
681       if (!fis_none(rshape[out_ent_id])) {
682         inference_finished[nid] = true;
683       }
684       // assign a default value to node attribute
685       if (dispatch_mode_name != nullptr) {
686         op::dispatch_mode_assign(&dispatch_modes[nid], default_mode_val);
687       }
688     } else if (is_backward.get(inode.source->op(), false) &&
689                inode.source->control_deps.size() && bwd_identity_assign) {
690       CHECK(dispatch_mode_name == nullptr)
691         << "Backward inference for node attributes is not available";
692       CHECK_GE(inode.source->control_deps.size(), 1U)
693         << "BackwardOp need to have control_deps to its forward op";
694       nnvm::ObjectPtr fwd_ptr = inode.source->control_deps[0];
695       CHECK(fwd_ptr->op() != nullptr) << "Forward op cannot be a variable";
696 
697       static auto& is_fusion_helper = Op::GetAttr<exec::TIsFusionHelper>("TIsFusionHelper");
698       if (!is_fusion_helper.get(fwd_ptr->op(), false)) {
699         GetAttrFromForwardNode(nid, idx, &rshape, &inference_finished, fis_none);
700       } else {
701         GetAttrFromFusedNode<exec::FAccessSubgraphShape>(nid, idx, &rshape,
702                                                          &inference_finished,
703                                                          fis_none,
704                                                          "FAccessSubgraphShape");
705       }
706     } else {
707       DispatchMode* dispatch_mode = nullptr;
708       // Forward operator inference.
709       ishape.resize(num_inputs, empty_val);
710       bool is_input_dynamic_shape = false;
711       for (uint32_t i = 0; i < ishape.size(); ++i) {
712         ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
713         if (!mxnet::ndim_is_known(ishape[i]) && is_dynamic[idx.entry_id(inode.inputs[i])]) {
714           is_input_dynamic_shape = true;
715         }
716       }
717       oshape.resize(num_outputs, empty_val);
718       for (uint32_t i = 0; i < oshape.size(); ++i) {
719         oshape[i] = rshape[idx.entry_id(nid, i)];
720       }
721       if (dispatch_mode_name != nullptr) {
722         dispatch_mode = &dispatch_modes[nid];
723       }
724       auto finfer = finfer_shape.get(inode.source->op(), fdefault);
725       if (finfer == nullptr || is_input_dynamic_shape) {
726         for (uint32_t i = 0; i < oshape.size(); ++i) {
727           if (!mxnet::ndim_is_known(oshape[i].ndim())) {
728             is_dynamic[idx.entry_id(nid, i)] = 1;
729           }
730         }
731         inference_finished[nid] = true;
732       } else {
733         // Call inference function of the operator.
734         try {
735           static auto& is_fusion = Op::GetAttr<exec::TIsFusion>("TIsFusion");
736           if (is_fusion.get(inode.source->op(), false)) {
737             ProvideAttrToFusion<exec::FProvideSubgraphShape>(nid, idx, rshape,
738                                                              "FProvideSubgraphShape");
739           }
740           ApplyOpInferAttr(ret, finfer, inode.source->attrs,
741                            nid, &ishape, &oshape, dispatch_mode);
742           bool finished = true;
743           for (const auto& attr : ishape) {
744             if (fis_none(attr)) finished = false;
745           }
746           for (const auto& attr : oshape) {
747             if (fis_none(attr)) finished = false;
748           }
749           inference_finished[nid] = finished;
750         } catch (const std::exception& e) {
751           throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what());
752         }
753       }
754       // Save to the result map.
755       for (uint32_t i = 0; i < num_inputs; ++i) {
756         rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
757       }
758       for (uint32_t i = 0; i < num_outputs; ++i) {
759         rshape[idx.entry_id(nid, i)] = oshape[i];
760       }
761     }
762   };
763 
764   size_t last_num_unknown;
765   size_t num_unknown = static_cast<size_t>(-1);  // Infinity
766   bool last_iter = false;
767   bool do_next_iteration = true;
768 
769   int i = 0;
770   do {
771     if (i % 2 == 0) {
772       // forward inference
773       for (uint32_t nid = node_start; nid < node_end; ++nid) {
774         infer_step(nid, last_iter);
775       }
776     } else {
777       // backward inference
778       for (uint32_t i = node_end; i != node_start; --i) {
779         infer_step(i - 1, last_iter);
780       }
781     }
782     last_num_unknown = num_unknown;
783     num_unknown = 0;
784     for (size_t j = entry_start; j < entry_end; ++j) {
785       if (fis_none(rshape[j])) {
786         num_unknown += fnum_unknown(rshape[j]);
787       }
788     }
789     if (dispatch_mode_name) {
790       for (size_t i = node_start; i < node_end; i++) {
791         if (dispatch_modes[i] == DispatchMode::kUndefined) {
792           ++num_unknown;
793         }
794       }
795     }
796     do_next_iteration = num_unknown > 0 && last_num_unknown > num_unknown;
797     if (!do_next_iteration && !last_iter) {
798       // Check if every op agrees that it should be
799       // the end of attribute inference. If not,
800       // perform one final step
801       for (const bool done : inference_finished) {
802         do_next_iteration = do_next_iteration || !done;
803       }
804       last_iter = true;
805     }
806     ++i;
807   } while (do_next_iteration);
808   // set the shapes
809   ret.attrs[attr_name] = std::make_shared<any>(std::move(rshape));
  // set the dispatch modes
811   if (dispatch_mode_name) {
812     ret.attrs[dispatch_mode_name] = std::make_shared<any>(std::move(dispatch_modes));
813   }
814   // number of nodes who knows the shape.
815   ret.attrs[unknown_name] = std::make_shared<any>(num_unknown);
816   return ret;
817 }
818 
InferShape(nnvm::Graph && graph,mxnet::ShapeVector && shape_inputs,const std::string & shape_attr_key)819 nnvm::Graph InferShape(nnvm::Graph&& graph,
820                        mxnet::ShapeVector&& shape_inputs,
821                        const std::string& shape_attr_key) {
822   using dmlc::any;
823   if (shape_inputs.size() != 0) {
824     graph.attrs["shape_inputs"] = std::make_shared<any>(std::move(shape_inputs));
825   }
826   if (shape_attr_key.length() != 0) {
827     graph.attrs["shape_attr_key"] = std::make_shared<any>(shape_attr_key);
828   }
829   return InferShapeAttr(
830       std::move(graph), mxnet::TShape(),
831       "FInferShape", "shape_inputs", "shape_attr_key",
832       "shape", "shape_num_unknown_nodes",
833       [](const mxnet::TShape& s) { return !mxnet::shape_is_known(s); },
834       [](const mxnet::TShape& s) {
835         if (!mxnet::ndim_is_known(s)) {
836           return static_cast<size_t>(1);
837         }
838         size_t ret = 0;
839         for (const auto& val : s) {
840           if (!mxnet::dim_size_is_known(val)) {
841             ++ret;
842           }
843         }
844         return ret;
845       },
846       nullptr, true, nullptr);
847 }
848 
InferType(nnvm::Graph && graph,nnvm::DTypeVector && dtype_inputs,const std::string & dtype_attr_key)849 nnvm::Graph InferType(nnvm::Graph&& graph,
850                       nnvm::DTypeVector&& dtype_inputs,
851                       const std::string& dtype_attr_key) {
852   using dmlc::any;
853   if (dtype_inputs.size() != 0) {
854     graph.attrs["dtype_inputs"] = std::make_shared<any>(std::move(dtype_inputs));
855   }
856   if (dtype_attr_key.length() != 0) {
857     graph.attrs["dtype_attr_key"] = std::make_shared<any>(dtype_attr_key);
858   }
859   return InferAttr<int, nnvm::FInferType, exec::FAccessSubgraphType,
860                    exec::FProvideSubgraphType>(
861       std::move(graph), -1,
862       "FInferType", "FAccessSubgraphType", "FProvideSubgraphType",
863       "dtype_inputs", "dtype_attr_key", "dtype", "dtype_num_unknown_nodes",
864       [](const int t) { return t == -1; },
865       common::SameType, true, nullptr);
866 }
867 
InferStorageType(nnvm::Graph && graph,StorageTypeVector && storage_type_inputs,const std::string & storage_type_attr_key)868 nnvm::Graph InferStorageType(nnvm::Graph&& graph,
869                              StorageTypeVector&& storage_type_inputs,
870                              const std::string& storage_type_attr_key) {
871   using dmlc::any;
872   if (storage_type_inputs.size() != 0) {
873     graph.attrs["storage_type_inputs"] = std::make_shared<any>(std::move(storage_type_inputs));
874   }
875   if (storage_type_attr_key.length() != 0) {
876     graph.attrs["storage_type_attr_key"] = std::make_shared<any>(storage_type_attr_key);
877   }
878   // initialize unknown values for dispatch modes
879   if (graph.attrs.count("dispatch_mode") == 0) {
880     DispatchModeVector dispatch_modes(graph.indexed_graph().num_nodes(), DispatchMode::kUndefined);
881     graph.attrs["dispatch_mode"] = std::make_shared<any>(std::move(dispatch_modes));
882   }
883   // initialize the dev_mask vector from the context vector
884   if (graph.attrs.count("dev_mask") == 0) {
885     CHECK_GT(graph.attrs.count("context"), 0);
886     DevMaskVector dev_masks(graph.indexed_graph().num_nodes());
887     const ContextVector& vctx = graph.GetAttr<ContextVector>("context");
888     for (size_t i = 0; i < vctx.size(); i++) dev_masks[i] = vctx[i].dev_mask();
889     graph.attrs["dev_mask"] = std::make_shared<any>(std::move(dev_masks));
890   }
891 
892   // for storage type, the backward attr is not necessarily the same as it's correspondence
893   nnvm::Graph ret = InferAttr<int, FInferStorageType, exec::FAccessSubgraphStorageType,
894                               exec::FProvideSubgraphStorageType>(
895       std::move(graph), -1,
896       "FInferStorageType", "FAccessSubgraphStorageType", "FProvideSubgraphStorageType",
897       "storage_type_inputs", "storage_type_attr_key", "storage_type",
898       "storage_type_num_unknown_nodes",
899       [](const int t) { return t == -1; },
900       common::DefaultStorageType, false, "dispatch_mode", DispatchMode::kVariable);
901 
902   // log the storage types and dispatch modes of the graph
903   static bool log_verbose = dmlc::GetEnv("MXNET_INFER_STORAGE_TYPE_VERBOSE_LOGGING", false);
904   if (log_verbose) {
905     common::LogInferStorage(ret);
906   }
907   return ret;
908 }
909 
910 }  // namespace exec
911 }  // namespace mxnet
912