1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file print_graph_ir.cc
22 * \brief Print the graph IR in LLVM style human readable format.
23 */
#include <nnvm/graph.h>
#include <nnvm/pass.h>
#include <nnvm/tuple.h>

#include <functional>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
29
30 namespace nnvm {
31 namespace pass {
32
/*!
 * \brief Type-erased callback that writes the element at position `index`
 *        of some attribute vector to the given stream.
 */
using AttrPrinter = std::function<void(uint32_t, std::ostream&)>;  // NOLINT(*)

/*!
 * \brief Build an AttrPrinter over an indexable container.
 * \param vec Container supporting operator[] whose elements support
 *        operator<<. It is NOT copied: the returned printer keeps a pointer
 *        to it, so `vec` must outlive every call of the printer.
 * \return Printer that streams `vec[index]` to `os`.
 */
template <typename T>
AttrPrinter GetVectorPrinter_(const T& vec) {
  const T* source = &vec;
  return [source](uint32_t index, std::ostream& os) {  // NOLINT(*)
    const T& container = *source;
    os << container[index];
  };
}
41
GetVectorPrinter(const Graph & graph,const std::string & key)42 AttrPrinter GetVectorPrinter(const Graph& graph, const std::string& key) {
43 auto it = graph.attrs.find(key);
44 CHECK(it != graph.attrs.end()) << "Cannot find " << key << " in graph attr";
45 const any& value = *(it->second);
46 if (value.type() == typeid(std::vector<TShape>)) {
47 return GetVectorPrinter_(nnvm::get<std::vector<TShape> >(value));
48 } else if (value.type() == typeid(std::vector<int>)) {
49 return GetVectorPrinter_(nnvm::get<std::vector<int> >(value));
50 } else if (value.type() == typeid(std::vector<std::string>)) {
51 return GetVectorPrinter_(nnvm::get<std::vector<std::string> >(value));
52 } else {
53 LOG(FATAL) << "Cannot handle type " << value.type().name();
54 return nullptr;
55 }
56 }
57
58 // print the graph ir in readable format
PrintGraphIR_(Graph src,const std::vector<std::string> & join_entry_attrs,const std::vector<std::string> & join_node_attrs,std::ostream & os)59 void PrintGraphIR_(Graph src, const std::vector<std::string>& join_entry_attrs,
60 const std::vector<std::string>& join_node_attrs,
61 std::ostream& os) { // NOLINT(*)
62 const IndexedGraph& idx = src.indexed_graph();
63 std::vector<std::function<void(uint32_t, std::ostream&)> > trigger; // NOLINT(*)
64
65 for (const std::string& key : join_entry_attrs) {
66 AttrPrinter fp = GetVectorPrinter(src, key);
67 auto fprint = [&idx, key, fp](uint32_t nid, std::ostream& os) { // NOLINT(*)
68 const IndexedGraph::Node& inode = idx[nid];
69 os << ", " << key << "=";
70 if (inode.source->num_outputs() != 1) {
71 os << '[';
72 for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
73 if (i != 0) os << ", ";
74 fp(idx.entry_id(nid, i), os);
75 }
76 os << ']';
77 } else {
78 fp(idx.entry_id(nid, 0), os);
79 }
80 };
81 trigger.push_back(fprint);
82 }
83 for (const std::string& key : join_node_attrs) {
84 AttrPrinter fp = GetVectorPrinter(src, key);
85 auto fprint = [&idx, key, fp](uint32_t nid, std::ostream& os) { // NOLINT(*)
86 os << ", " << key << "=";
87 fp(idx.entry_id(nid, 0), os);
88 };
89 trigger.push_back(fprint);
90 }
91
92 os << "Graph(";
93 if (idx.input_nodes().size() < 4) {
94 for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
95 uint32_t nid = idx.input_nodes()[i];
96 if (i != 0) {
97 os << ", ";
98 }
99 os << '%' << idx[nid].source->attrs.name;
100 }
101 } else {
102 for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
103 uint32_t nid = idx.input_nodes()[i];
104 if (i != 0) {
105 os << ",\n ";
106 }
107 os << '%' << idx[nid].source->attrs.name;
108 }
109 }
110 os << ") {\n";
111
112 auto print_entry = [&](const IndexedGraph::NodeEntry& e) {
113 if (idx[e.node_id].source->is_variable()) {
114 os << '%' << idx[e.node_id].source->attrs.name;
115 } else if (idx[e.node_id].source->num_outputs() == 1) {
116 os << '%' << e.node_id;
117 } else {
118 os << '%' << e.node_id << "." << e.index;
119 }
120 };
121
122 if (trigger.size() != 0) {
123 for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
124 uint32_t nid = idx.input_nodes()[i];
125 os << " %" << idx[nid].source->attrs.name;
126 for (const auto& fp : trigger) {
127 fp(nid, os);
128 }
129 os << '\n';
130 }
131 }
132
133 for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
134 const auto& inode = idx[nid];
135 if (inode.source->is_variable()) continue;
136 os << " "
137 << "%" << nid << " = " << inode.source->op()->name << "(";
138 bool first = true;
139 for (const IndexedGraph::NodeEntry& e : inode.inputs) {
140 if (first) {
141 first = false;
142 } else {
143 os << ", ";
144 }
145 print_entry(e);
146 }
147 for (const auto& kv : inode.source->attrs.dict) {
148 if (first) {
149 first = false;
150 } else {
151 os << ", ";
152 }
153 os << kv.first << "=\'" << kv.second << "\'";
154 }
155 os << ")";
156 if (inode.control_deps.size() != 0) {
157 os << ", control_deps=[";
158 for (size_t i = 0; i < inode.control_deps.size(); ++i) {
159 if (i != 0) os << ", ";
160 uint32_t cid = inode.control_deps[i];
161 if (idx[cid].source->is_variable()) {
162 os << '%' << idx[cid].source->attrs.name;
163 } else {
164 os << '%' << cid;
165 }
166 }
167 os << "]";
168 }
169 // additional attribute trigger
170 for (const auto& fp : trigger) {
171 fp(nid, os);
172 }
173 os << "\n";
174 }
175 os << " ret ";
176 {
177 bool first = true;
178 for (const IndexedGraph::NodeEntry& e : idx.outputs()) {
179 if (first) {
180 first = false;
181 } else {
182 os << ", ";
183 }
184 print_entry(e);
185 }
186 }
187 os << "\n}";
188 if (src.attrs.size() != 0) {
189 os << "\ngraph_attr_keys = [";
190 bool first = true;
191 for (const auto& kv : src.attrs) {
192 if (first) {
193 first = false;
194 } else {
195 os << ", ";
196 }
197 os << kv.first;
198 }
199 os << "]\n";
200 }
201 }
202
203 // save a graph to json
PrintGraphIRPass(Graph src)204 Graph PrintGraphIRPass(Graph src) {
205 std::ostringstream os;
206 std::vector<std::string> join_entry_attrs, join_node_attrs;
207 if (src.attrs.count("join_entry_attrs") != 0) {
208 join_entry_attrs = src.MoveCopyAttr<std::vector<std::string> >("join_entry_attrs");
209 }
210 if (src.attrs.count("join_node_attrs") != 0) {
211 join_node_attrs = src.MoveCopyAttr<std::vector<std::string> >("join_node_attrs");
212 }
213 PrintGraphIR_(src, join_entry_attrs, join_node_attrs, os);
214 Graph ret;
215 ret.attrs["graphir"] = std::make_shared<any>(os.str());
216 return ret;
217 }
218
// Register the PrintGraphIR pass. Its body is PrintGraphIRPass, which returns
// a new Graph whose only attribute, "graphir", holds the rendered IR text.
// NOTE(review): the description reads "a empty" (should be "an empty"); it is
// left as-is here because it is user-visible text.
NNVM_REGISTER_PASS(PrintGraphIR)
    .describe("Return a empty Graph, save ir to ret.attrs[\"graphir\"]")
    .set_body(PrintGraphIRPass);
223
224 } // namespace pass
225 } // namespace nnvm
226