1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file print_graph_ir.cc
22  * \brief Print the graph IR in LLVM style human readable format.
23  */
24 #include <nnvm/graph.h>
25 #include <nnvm/pass.h>
26 #include <nnvm/tuple.h>
27 
28 #include <iostream>
29 
30 namespace nnvm {
31 namespace pass {
32 
/*!
 * \brief Callback that writes one attribute value to a stream.
 *
 * `index` selects the element to print and `os` is the destination stream.
 */
using AttrPrinter = std::function<void(uint32_t index, std::ostream& os)>;  // NOLINT(*)

/*!
 * \brief Build an AttrPrinter that streams single elements of a container.
 *
 * The printer keeps only the address of vec and never copies the container,
 * so vec must stay alive (and stay at the same address) for as long as the
 * returned printer may be invoked. The index handed to the printer is passed
 * straight to operator[] without any range check.
 *
 * \tparam T Indexable container type whose elements support operator<<.
 * \param vec Container to read from.
 * \return Printer that writes vec[index] to the supplied stream.
 */
template <typename T>
AttrPrinter GetVectorPrinter_(const T& vec) {
  const T* const elements = &vec;
  return [elements](uint32_t position, std::ostream& out) {  // NOLINT(*)
    out << (*elements)[position];
  };
}
41 
GetVectorPrinter(const Graph & graph,const std::string & key)42 AttrPrinter GetVectorPrinter(const Graph& graph, const std::string& key) {
43   auto it = graph.attrs.find(key);
44   CHECK(it != graph.attrs.end()) << "Cannot find " << key << " in graph attr";
45   const any& value = *(it->second);
46   if (value.type() == typeid(std::vector<TShape>)) {
47     return GetVectorPrinter_(nnvm::get<std::vector<TShape> >(value));
48   } else if (value.type() == typeid(std::vector<int>)) {
49     return GetVectorPrinter_(nnvm::get<std::vector<int> >(value));
50   } else if (value.type() == typeid(std::vector<std::string>)) {
51     return GetVectorPrinter_(nnvm::get<std::vector<std::string> >(value));
52   } else {
53     LOG(FATAL) << "Cannot handle type " << value.type().name();
54     return nullptr;
55   }
56 }
57 
58 // print the graph ir in readable format
PrintGraphIR_(Graph src,const std::vector<std::string> & join_entry_attrs,const std::vector<std::string> & join_node_attrs,std::ostream & os)59 void PrintGraphIR_(Graph src, const std::vector<std::string>& join_entry_attrs,
60                    const std::vector<std::string>& join_node_attrs,
61                    std::ostream& os) {  // NOLINT(*)
62   const IndexedGraph& idx = src.indexed_graph();
63   std::vector<std::function<void(uint32_t, std::ostream&)> > trigger;  // NOLINT(*)
64 
65   for (const std::string& key : join_entry_attrs) {
66     AttrPrinter fp = GetVectorPrinter(src, key);
67     auto fprint = [&idx, key, fp](uint32_t nid, std::ostream& os) {  // NOLINT(*)
68       const IndexedGraph::Node& inode = idx[nid];
69       os << ", " << key << "=";
70       if (inode.source->num_outputs() != 1) {
71         os << '[';
72         for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
73           if (i != 0) os << ", ";
74           fp(idx.entry_id(nid, i), os);
75         }
76         os << ']';
77       } else {
78         fp(idx.entry_id(nid, 0), os);
79       }
80     };
81     trigger.push_back(fprint);
82   }
83   for (const std::string& key : join_node_attrs) {
84     AttrPrinter fp = GetVectorPrinter(src, key);
85     auto fprint = [&idx, key, fp](uint32_t nid, std::ostream& os) {  // NOLINT(*)
86       os << ", " << key << "=";
87       fp(idx.entry_id(nid, 0), os);
88     };
89     trigger.push_back(fprint);
90   }
91 
92   os << "Graph(";
93   if (idx.input_nodes().size() < 4) {
94     for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
95       uint32_t nid = idx.input_nodes()[i];
96       if (i != 0) {
97         os << ", ";
98       }
99       os << '%' << idx[nid].source->attrs.name;
100     }
101   } else {
102     for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
103       uint32_t nid = idx.input_nodes()[i];
104       if (i != 0) {
105         os << ",\n      ";
106       }
107       os << '%' << idx[nid].source->attrs.name;
108     }
109   }
110   os << ") {\n";
111 
112   auto print_entry = [&](const IndexedGraph::NodeEntry& e) {
113     if (idx[e.node_id].source->is_variable()) {
114       os << '%' << idx[e.node_id].source->attrs.name;
115     } else if (idx[e.node_id].source->num_outputs() == 1) {
116       os << '%' << e.node_id;
117     } else {
118       os << '%' << e.node_id << "." << e.index;
119     }
120   };
121 
122   if (trigger.size() != 0) {
123     for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
124       uint32_t nid = idx.input_nodes()[i];
125       os << "  %" << idx[nid].source->attrs.name;
126       for (const auto& fp : trigger) {
127         fp(nid, os);
128       }
129       os << '\n';
130     }
131   }
132 
133   for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
134     const auto& inode = idx[nid];
135     if (inode.source->is_variable()) continue;
136     os << "  "
137        << "%" << nid << " = " << inode.source->op()->name << "(";
138     bool first = true;
139     for (const IndexedGraph::NodeEntry& e : inode.inputs) {
140       if (first) {
141         first = false;
142       } else {
143         os << ", ";
144       }
145       print_entry(e);
146     }
147     for (const auto& kv : inode.source->attrs.dict) {
148       if (first) {
149         first = false;
150       } else {
151         os << ", ";
152       }
153       os << kv.first << "=\'" << kv.second << "\'";
154     }
155     os << ")";
156     if (inode.control_deps.size() != 0) {
157       os << ", control_deps=[";
158       for (size_t i = 0; i < inode.control_deps.size(); ++i) {
159         if (i != 0) os << ", ";
160         uint32_t cid = inode.control_deps[i];
161         if (idx[cid].source->is_variable()) {
162           os << '%' << idx[cid].source->attrs.name;
163         } else {
164           os << '%' << cid;
165         }
166       }
167       os << "]";
168     }
169     // additional attribute trigger
170     for (const auto& fp : trigger) {
171       fp(nid, os);
172     }
173     os << "\n";
174   }
175   os << "  ret ";
176   {
177     bool first = true;
178     for (const IndexedGraph::NodeEntry& e : idx.outputs()) {
179       if (first) {
180         first = false;
181       } else {
182         os << ", ";
183       }
184       print_entry(e);
185     }
186   }
187   os << "\n}";
188   if (src.attrs.size() != 0) {
189     os << "\ngraph_attr_keys = [";
190     bool first = true;
191     for (const auto& kv : src.attrs) {
192       if (first) {
193         first = false;
194       } else {
195         os << ", ";
196       }
197       os << kv.first;
198     }
199     os << "]\n";
200   }
201 }
202 
203 // save a graph to json
PrintGraphIRPass(Graph src)204 Graph PrintGraphIRPass(Graph src) {
205   std::ostringstream os;
206   std::vector<std::string> join_entry_attrs, join_node_attrs;
207   if (src.attrs.count("join_entry_attrs") != 0) {
208     join_entry_attrs = src.MoveCopyAttr<std::vector<std::string> >("join_entry_attrs");
209   }
210   if (src.attrs.count("join_node_attrs") != 0) {
211     join_node_attrs = src.MoveCopyAttr<std::vector<std::string> >("join_node_attrs");
212   }
213   PrintGraphIR_(src, join_entry_attrs, join_node_attrs, os);
214   Graph ret;
215   ret.attrs["graphir"] = std::make_shared<any>(os.str());
216   return ret;
217 }
218 
// Register the pass under the name PrintGraphIR, with PrintGraphIRPass as its
// body. The result graph carries only the printed text, stored in
// attrs["graphir"]; the describe() string below is the registered summary.
NNVM_REGISTER_PASS(PrintGraphIR)
    .describe("Return a empty Graph, save ir to ret.attrs[\"graphir\"]")
    .set_body(PrintGraphIRPass);
223 
224 }  // namespace pass
225 }  // namespace nnvm
226