/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file infer_shape.cc
 * \brief Infer the shapes given existing information.
 */
#include <nnvm/pass.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <sstream>
#include <string>
#include <vector>

namespace nnvm {
namespace pass {
namespace {

template<typename AttrType, typename IsNone, typename FDefault>
Graph InferAttr(Graph &&ret,
                const AttrType empty_val,
                const char* infer_name,
                const char* input_name,
                const char* attr_key_name,
                const char* attr_name,
                const char* unknown_name,
                IsNone fis_none,
                FDefault fdefault) {
  using AttrVector = std::vector<AttrType>;
  const IndexedGraph& idx = ret.indexed_graph();
  static auto& finfer_shape =
      Op::GetAttr<FInferNodeEntryAttr<AttrType> >(infer_name);
  static auto& is_backward =
      Op::GetAttr<TIsBackward>("TIsBackward");
  // gradient function, used to get node correspondence.
  static auto& fgrad =
      Op::GetAttr<FGradient>("FGradient");
  // result attribute vector (e.g. shapes), indexed by entry id
  AttrVector rshape;
  if (ret.attrs.count(attr_name) != 0) {
    rshape = ret.MoveCopyAttr<AttrVector>(attr_name);
  } else {
    rshape.resize(idx.num_node_entries(), empty_val);
  }

  if (ret.attrs.count(input_name) != 0) {
    const AttrVector& shape_args = ret.GetAttr<AttrVector>(input_name);
    CHECK_LE(shape_args.size(), idx.input_nodes().size())
        << "More provided shapes than number of arguments.";
    for (size_t i = 0; i < shape_args.size(); ++i) {
      rshape[idx.entry_id(idx.input_nodes()[i], 0)] = shape_args[i];
    }
    // erase the provided arguments
    ret.attrs.erase(input_name);
  }

  // get the shape hints
  std::string shape_hints_key = std::string(attr_name) + "_hints";
  if (ret.attrs.count(shape_hints_key)) {
    NodeEntryMap<AttrType> shape_hints =
      ret.GetAttr<NodeEntryMap<AttrType>>(shape_hints_key);
    for (const auto& kv : shape_hints) {
      NodeEntry e = kv.first;
      if (idx.exist(e.node.get())) {
        rshape[idx.entry_id(kv.first)] = kv.second;
      }
    }
  }

  std::string shape_attr_key;
  if (ret.attrs.count(attr_key_name) != 0) {
    shape_attr_key = ret.GetAttr<std::string>(attr_key_name);
    // erase the provided arguments
    ret.attrs.erase(attr_key_name);
  } else {
    shape_attr_key = attr_name;
  }
  // Temp space for shape inference.
  std::vector<AttrType> ishape, oshape;

  // inference step function for nid
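  // For a variable node the attribute is read from its attribute dictionary;
  // for a backward node it is copied from the corresponding forward entries
  // via the gradient function; otherwise the operator's registered inference
  // function (or fdefault) is invoked.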
  auto infer_step = [&](uint32_t nid, bool last_iter) {
    const auto& inode = idx[nid];
    const uint32_t num_inputs = inode.inputs.size();
    const uint32_t num_outputs = inode.source->num_outputs();
    if (inode.source->is_variable()) {
      // Variable node. No operator. Only one output entry.
      CHECK(inode.source->op() == nullptr);
      CHECK_EQ(num_outputs, 1U);
      const uint32_t out_ent_id = idx.entry_id(nid, 0);
      if (shape_attr_key.length() != 0 && fis_none(rshape[out_ent_id])) {
        auto it = inode.source->attrs.dict.find(shape_attr_key);
        if (it != inode.source->attrs.dict.end()) {
          std::istringstream is(it->second);
          CHECK(is >> rshape[out_ent_id]) << "Invalid attribute";
        }
      }
    } else if (is_backward.get(inode.source->op(), false) && inode.control_deps.size()) {
      CHECK_GE(inode.control_deps.size(), 1U)
        << "BackwardOp needs to have control_deps to its forward op";
      const IndexedGraph::Node& fnode = idx[inode.control_deps[0]];
      NodePtr fwd_ptr = inode.source->control_deps[0];
      CHECK(fwd_ptr->op() != nullptr) << "Forward op cannot be a variable";
      // use gradient function to find out the correspondence.
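      // Re-running the forward op's FGradient on placeholder output gradients
      // rebuilds the backward graph; the input-gradient entries produced by
      // the same op as this backward node map its outputs (by index) onto the
      // forward node's inputs.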
      std::vector<NodeEntry> ograd(fwd_ptr->num_outputs());
      for (size_t i = 0; i < ograd.size(); ++i) {
        ograd[i].index = static_cast<uint32_t>(i);
      }
      // input gradient list
      auto igrad = fgrad[fwd_ptr->op()](fwd_ptr, ograd);
      const Node* igrad_node = nullptr;
      // Input gradient assignment
      for (size_t i = 0; i < igrad.size(); ++i) {
        if (igrad[i].node->op() == inode.source->op()) {
          uint32_t eid = idx.entry_id(nid, igrad[i].index);
          if (fis_none(rshape[eid])) {
            rshape[eid] = rshape[idx.entry_id(fnode.inputs[i])];
          } else if (!fis_none(rshape[idx.entry_id(fnode.inputs[i])])) {
            CHECK_EQ(rshape[eid], rshape[idx.entry_id(fnode.inputs[i])])
                << "Backward shape inconsistent with the forward shape";
          }
          if (igrad_node == nullptr) {
            igrad_node = igrad[i].node.get();
          } else {
            CHECK(igrad_node == igrad[i].node.get());
          }
        }
      }
      // out grad entries
      CHECK(igrad_node != nullptr)
        << "Cannot find matching backward op for " << inode.source->attrs.name;
      for (size_t i = 0; i < igrad_node->inputs.size(); ++i) {
        const NodeEntry& e = igrad_node->inputs[i];
        if (e.node == nullptr) {
          uint32_t eid = idx.entry_id(inode.inputs[i]);
          if (fis_none(rshape[eid])) {
            rshape[eid] = rshape[idx.entry_id(inode.control_deps[0], e.index)];
          }
        }
      }
    } else {
      bool forward_known = true;
      // Forward operator inference.
      ishape.resize(num_inputs, empty_val);
      for (uint32_t i = 0; i < ishape.size(); ++i) {
        ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
        if (fis_none(ishape[i])) forward_known = false;
      }
      oshape.resize(num_outputs, empty_val);
      for (uint32_t i = 0; i < oshape.size(); ++i) {
        oshape[i] = rshape[idx.entry_id(nid, i)];
        if (fis_none(oshape[i])) forward_known = false;
      }
      auto finfer = finfer_shape.get(inode.source->op(), fdefault);
      if (!forward_known) {
        if (finfer != nullptr) {
          // Call inference function of the operator.
          try {
            forward_known = finfer(inode.source->attrs, &ishape, &oshape);
          } catch (const std::exception& e) {
            throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what());
          }
        } else {
          CHECK(!last_iter)
              << "Attribute " << infer_name
              << " is not registered by op " << inode.source->op()->name
              << "; unable to complete the inference without it";
        }
      }
      // Save to the result map.
      for (uint32_t i = 0; i < num_inputs; ++i) {
        rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
      }
      for (uint32_t i = 0; i < num_outputs; ++i) {
        rshape[idx.entry_id(nid, i)] = oshape[i];
      }
    }
  };

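  // Alternate forward (topological order) and backward (reverse order) sweeps
  // over the graph until every entry is known or a full sweep no longer
  // reduces the number of unknown entries.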
  size_t last_num_unknown;
  size_t num_unknown = rshape.size();
  int i = 0;
  do {
    if (i % 2 == 0) {
      for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
        infer_step(nid, false);
      }
    } else {
      // backward inference
      for (uint32_t nid = idx.num_nodes(); nid != 0; --nid) {
        infer_step(nid - 1, false);
      }
    }
    last_num_unknown = num_unknown;
    num_unknown = 0;
    for (size_t j = 0; j < idx.num_node_entries(); ++j) {
      if (fis_none(rshape[j])) {
        ++num_unknown;
      }
    }
    ++i;
  } while (num_unknown > 0 && last_num_unknown > num_unknown);
  // store the inferred attributes back into the graph
  ret.attrs[attr_name] = std::make_shared<any>(std::move(rshape));
  // number of entries whose attribute is still unknown
  ret.attrs[unknown_name] = std::make_shared<any>(num_unknown);
  return std::move(ret);
}

NNVM_REGISTER_PASS(InferShape)
.describe("Infer the shape of each node entry.")
.set_body([](Graph ret) {
    return InferAttr<TShape>(
        std::move(ret), TShape(),
        "FInferShape", "shape_inputs", "shape_attr_key",
        "shape", "shape_num_unknown_nodes",
        [](const TShape& s) { return s.ndim() == 0 || s.Size() == 0; },
        nullptr);
  })
.set_change_graph(false)
.provide_graph_attr("shape");

// inference function for same type
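// SameType is the fallback (fdefault) used by InferType: if any input or
// output entry already has a known dtype, that dtype is propagated to all
// of the node's inputs and outputs.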
inline bool SameType(const NodeAttrs& attrs,
                     std::vector<int> *iattr,
                     std::vector<int> *oattr) {
  int def_v = -1;
  for (int v : *oattr) {
    if (v != -1) {
      def_v = v; break;
    }
  }
  if (def_v == -1) {
    for (int v : *iattr) {
      if (v != -1) {
        def_v = v; break;
      }
    }
  }
  if (def_v == -1) return false;
  for (int& v : *oattr) {
    v = def_v;
  }
  for (int& v : *iattr) {
    v = def_v;
  }
  return true;
}

NNVM_REGISTER_PASS(InferType)
.describe("Infer the dtype of each node entry.")
.set_body([](Graph ret) {
    return InferAttr<int>(
        std::move(ret), -1,
        "FInferType", "dtype_inputs", "dtype_attr_key",
        "dtype", "dtype_num_unknown_nodes",
        [](const int t) { return t == -1; },
        SameType);
  })
.set_change_graph(false)
.provide_graph_attr("dtype");

DMLC_JSON_ENABLE_ANY(ShapeVector, list_shape);
DMLC_JSON_ENABLE_ANY(DTypeVector, list_int);
DMLC_JSON_ENABLE_ANY(size_t, size_t);

}  // namespace
}  // namespace pass
}  // namespace nnvm