1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file infer_shape.cc
 * \brief Infer the shapes given existing information.
 */
#include <nnvm/pass.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <sstream>
#include <string>
#include <vector>
27
28 namespace nnvm {
29 namespace pass {
30 namespace {
31
32 template<typename AttrType, typename IsNone, typename FDefault>
InferAttr(Graph && ret,const AttrType empty_val,const char * infer_name,const char * input_name,const char * attr_key_name,const char * attr_name,const char * unknown_name,IsNone fis_none,FDefault fdefault)33 Graph InferAttr(Graph &&ret,
34 const AttrType empty_val,
35 const char* infer_name,
36 const char* input_name,
37 const char* attr_key_name,
38 const char* attr_name,
39 const char* unknown_name,
40 IsNone fis_none,
41 FDefault fdefault) {
42 using AttrVector = std::vector<AttrType>;
43 const IndexedGraph& idx = ret.indexed_graph();
44 static auto& finfer_shape =
45 Op::GetAttr<FInferNodeEntryAttr<AttrType> >(infer_name);
46 static auto& is_backward =
47 Op::GetAttr<TIsBackward>("TIsBackward");
48 // gradient function, used to get node correspondence.
49 static auto& fgrad =
50 Op::GetAttr<FGradient>("FGradient");
51 // reshape shape vector
52 AttrVector rshape;
53 if (ret.attrs.count(attr_name) != 0) {
54 rshape = ret.MoveCopyAttr<AttrVector>(attr_name);
55 } else {
56 rshape.resize(idx.num_node_entries(), empty_val);
57 }
58
59 if (ret.attrs.count(input_name) != 0) {
60 const AttrVector& shape_args = ret.GetAttr<AttrVector>(input_name);
61 CHECK_LE(shape_args.size(), idx.input_nodes().size())
62 << "More provided shapes than number of arguments.";
63 for (size_t i = 0; i < shape_args.size(); ++i) {
64 rshape[idx.entry_id(idx.input_nodes()[i], 0)] = shape_args[i];
65 }
66 // erase the provided arguments
67 ret.attrs.erase(input_name);
68 }
69
70 // get the shape hints
71 std::string shape_hints_key = std::string(attr_name) + "_hints";
72 if (ret.attrs.count(shape_hints_key)) {
73 NodeEntryMap<AttrType> shape_hints =
74 ret.GetAttr<NodeEntryMap<AttrType>>(shape_hints_key);
75 for (const auto& kv : shape_hints) {
76 NodeEntry e = kv.first;
77 if (idx.exist(e.node.get())) {
78 rshape[idx.entry_id(kv.first)] = kv.second;
79 }
80 }
81 }
82
83 std::string shape_attr_key;
84 if (ret.attrs.count(attr_key_name) != 0) {
85 shape_attr_key = ret.GetAttr<std::string>(attr_key_name);
86 // erase the provided arguments
87 ret.attrs.erase(attr_key_name);
88 } else {
89 shape_attr_key = attr_name;
90 }
91 // Temp space for shape inference.
92 std::vector<AttrType> ishape, oshape;
93
94 // inference step function for nid
95 auto infer_step = [&](uint32_t nid, bool last_iter) {
96 const auto& inode = idx[nid];
97 const uint32_t num_inputs = inode.inputs.size();
98 const uint32_t num_outputs = inode.source->num_outputs();
99 if (inode.source->is_variable()) {
100 // Variable node. No operator. Only one output entry.
101 CHECK(inode.source->op() == nullptr);
102 CHECK_EQ(num_outputs, 1U);
103 const uint32_t out_ent_id = idx.entry_id(nid, 0);
104 if (shape_attr_key.length() != 0 && fis_none(rshape[out_ent_id])) {
105 auto it = inode.source->attrs.dict.find(shape_attr_key);
106 if (it != inode.source->attrs.dict.end()) {
107 std::istringstream is(it->second);
108 CHECK(is >> rshape[out_ent_id]) << "Invalid attribute";
109 }
110 }
111 } else if (is_backward.get(inode.source->op(), false) && inode.control_deps.size()) {
112 CHECK_GE(inode.control_deps.size(), 1U)
113 << "BackwardOp need to have control_deps to its forward op";
114 const IndexedGraph::Node& fnode = idx[inode.control_deps[0]];
115 NodePtr fwd_ptr = inode.source->control_deps[0];
116 CHECK(fwd_ptr->op() != nullptr) << "Forward op cannot be a variable";
117 // use gradient function to find out the correspondence.
118 std::vector<NodeEntry> ograd(fwd_ptr->num_outputs());
119 for (size_t i = 0; i < ograd.size(); ++i) {
120 ograd[i].index = static_cast<uint32_t>(i);
121 }
122 // input gradient list
123 auto igrad = fgrad[fwd_ptr->op()](fwd_ptr, ograd);
124 const Node* igrad_node = nullptr;
125 // Input gradient assignement
126 for (size_t i = 0; i < igrad.size(); ++i) {
127 if (igrad[i].node->op() == inode.source->op()) {
128 uint32_t eid = idx.entry_id(nid, igrad[i].index);
129 if (fis_none(rshape[eid])) {
130 rshape[eid] = rshape[idx.entry_id(fnode.inputs[i])];
131 } else if (!fis_none(rshape[idx.entry_id(fnode.inputs[i])])) {
132 CHECK_EQ(rshape[eid], rshape[idx.entry_id(fnode.inputs[i])])
133 << "Backward shape inconsistent with the forward shape";
134 }
135 if (igrad_node == nullptr) {
136 igrad_node = igrad[i].node.get();
137 } else {
138 CHECK(igrad_node == igrad[i].node.get());
139 }
140 }
141 }
142 // out grad entries
143 CHECK(igrad_node != nullptr)
144 << "Cannot find matching backward op for " << inode.source->attrs.name;
145 for (size_t i = 0; i < igrad_node->inputs.size(); ++i) {
146 const NodeEntry& e = igrad_node->inputs[i];
147 if (e.node == nullptr) {
148 uint32_t eid = idx.entry_id(inode.inputs[i]);
149 if (fis_none(rshape[eid])) {
150 rshape[eid] = rshape[idx.entry_id(inode.control_deps[0], e.index)];
151 }
152 }
153 }
154 } else {
155 bool forward_known = true;
156 // Forward operator inference.
157 ishape.resize(num_inputs, empty_val);
158 for (uint32_t i = 0; i < ishape.size(); ++i) {
159 ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
160 if (fis_none(ishape[i])) forward_known = false;
161 }
162 oshape.resize(num_outputs, empty_val);
163 for (uint32_t i = 0; i < oshape.size(); ++i) {
164 oshape[i] = rshape[idx.entry_id(nid, i)];
165 if (fis_none(oshape[i])) forward_known = false;
166 }
167 auto finfer = finfer_shape.get(inode.source->op(), fdefault);
168 if (!forward_known) {
169 if (finfer != nullptr) {
170 // Call inference function of the operator.
171 try {
172 forward_known = finfer(inode.source->attrs, &ishape, &oshape);
173 } catch (const std::exception& e) {
174 throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what());
175 }
176 } else {
177 CHECK(!last_iter)
178 << "Attribute " << infer_name
179 << " is not registered by op " << inode.source->op()->name
180 << " we are not able to complete the inference because of this";
181 }
182 }
183 // Save to the result map.
184 for (uint32_t i = 0; i < num_inputs; ++i) {
185 rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
186 }
187 for (uint32_t i = 0; i < num_outputs; ++i) {
188 rshape[idx.entry_id(nid, i)] = oshape[i];
189 }
190 }
191 };
192
193 size_t last_num_unknown;
194 size_t num_unknown = rshape.size();
195 int i = 0;
196 do {
197 if (i % 2 == 0) {
198 for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
199 infer_step(nid, false);
200 }
201 } else {
202 // backward inference
203 for (uint32_t i = idx.num_nodes(); i != 0; --i) {
204 infer_step(i - 1, false);
205 }
206 }
207 last_num_unknown = num_unknown;
208 num_unknown = 0;
209 for (size_t j = 0; j < idx.num_node_entries(); ++j) {
210 if (fis_none(rshape[j])) {
211 ++num_unknown;
212 }
213 }
214 ++i;
215 } while (num_unknown > 0 && last_num_unknown > num_unknown);
216 // set the shapes
217 ret.attrs[attr_name] = std::make_shared<any>(std::move(rshape));
218 // number of nodes who knows the shape.
219 ret.attrs[unknown_name] = std::make_shared<any>(num_unknown);
220 return std::move(ret);
221 }
222
223 NNVM_REGISTER_PASS(InferShape)
224 .describe("Infer the shape of each node entries.")
__anon918239df0302(Graph ret) 225 .set_body([](Graph ret) {
226 return InferAttr<TShape>(
227 std::move(ret), TShape(),
228 "FInferShape", "shape_inputs", "shape_attr_key",
229 "shape", "shape_num_unknown_nodes",
230 [](const TShape& s) { return s.ndim() == 0 || s.Size() == 0; },
231 nullptr);
232 })
233 .set_change_graph(false)
234 .provide_graph_attr("shape");
235
236 // inference function for same type
SameType(const NodeAttrs & attrs,std::vector<int> * iattr,std::vector<int> * oattr)237 inline bool SameType(const NodeAttrs& attrs,
238 std::vector<int> *iattr,
239 std::vector<int> *oattr) {
240 int def_v = -1;
241 for (int v : *oattr) {
242 if (v != -1) {
243 def_v = v; break;
244 }
245 }
246 if (def_v == -1) {
247 for (int v : *iattr) {
248 if (v != -1) {
249 def_v = v; break;
250 }
251 }
252 }
253 if (def_v == -1) return false;
254 for (int& v : *oattr) {
255 v = def_v;
256 }
257 for (int& v : *iattr) {
258 v = def_v;
259 }
260 return true;
261 }
262
263 NNVM_REGISTER_PASS(InferType)
264 .describe("Infer the dtype of each node entries.")
__anon918239df0502(Graph ret) 265 .set_body([](Graph ret) {
266 return InferAttr<int>(
267 std::move(ret), -1,
268 "FInferType", "dtype_inputs", "dtype_attr_key",
269 "dtype", "dtype_num_unknown_nodes",
270 [](const int t) { return t == -1; },
271 SameType);
272 })
273 .set_change_graph(false)
274 .provide_graph_attr("dtype");
275
// Enable dmlc JSON (de)serialization of `any` values holding the result
// types stored by the passes above (shape vectors, dtype vectors, and the
// unknown-entry counters).
DMLC_JSON_ENABLE_ANY(ShapeVector, list_shape);
DMLC_JSON_ENABLE_ANY(DTypeVector, list_int);
DMLC_JSON_ENABLE_ANY(size_t, size_t);
279
280 } // namespace
281 } // namespace pass
282 } // namespace nnvm
283