1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file graph_deep_compare.cc
22 * \brief Deep compare two graph structure
23 */
24 #include <dmlc/common.h>
25 #include <nnvm/graph.h>
26 #include <nnvm/op_attr_types.h>
27 #include <nnvm/compiler/packed_func_ext.h>
28 #include <tvm/ir.h>
29 #include <tvm/runtime/packed_func.h>
30 #include <functional>
31 #include <vector>
32 #include <utility>
33 #include <algorithm>
34 #include "node_attr.h"
35 #include "graph_hash.h"
36
37 namespace nnvm {
38 namespace compiler {
39
40 using namespace tvm;
41 using tvm::ir::IntImm;
42
HashPlaceHolder(const Tensor & t)43 size_t HashPlaceHolder(const Tensor& t) {
44 size_t key = t->shape.size();
45 key = dmlc::HashCombine(key, (t->dtype.code() << 8) | t->dtype.bits());
46 for (Expr s : t->shape) {
47 if (const IntImm* op = s.as<IntImm>()) {
48 key = dmlc::HashCombine(key, op->value);
49 }
50 }
51 return key;
52 }
53
PlaceHolderEqual(const Tensor & a,const Tensor & b)54 bool PlaceHolderEqual(const Tensor& a, const Tensor& b) {
55 if (a->shape.size() != b->shape.size()) return false;
56 if (a->dtype != b->dtype) return false;
57 for (size_t i = 0; i < a->shape.size(); ++i) {
58 const IntImm* a_value = a->shape[i].as<IntImm>();
59 const IntImm* b_value = b->shape[i].as<IntImm>();
60 if (a_value && b_value == nullptr) return false;
61 if (b_value && a_value == nullptr) return false;
62 if (a_value == nullptr && b_value == nullptr) {
63 continue;
64 }
65 if (a_value->value != b_value->value) return false;
66 }
67 return true;
68 }
69
Hash(const GraphKey & gkey)70 size_t GraphKeyHash::Hash(const GraphKey& gkey) {
71 if (gkey->cache_hash_key_ != 0) return gkey->cache_hash_key_;
72 size_t key = dmlc::HashCombine(GraphHash(gkey->graph), gkey->target);
73 key = dmlc::HashCombine(key, gkey->inputs.size());
74 for (size_t i = 0; i < gkey->inputs.size(); ++i) {
75 key = dmlc::HashCombine(key, HashPlaceHolder(gkey->inputs[i]));
76 }
77 if (key == 0) key = 1;
78 gkey->cache_hash_key_ = key;
79 return key;
80 }
81
Equal(const GraphKey & a,const GraphKey & b)82 bool GraphKeyEqual::Equal(const GraphKey& a,
83 const GraphKey& b) {
84 if (a->target != b->target) return false;
85 if (a->inputs.size() != b->inputs.size()) return false;
86 for (size_t i = 0; i < a->inputs.size(); ++i) {
87 if (!PlaceHolderEqual(a->inputs[i], b->inputs[i])) return false;
88 }
89 if (GraphDeepCompare(a->graph, b->graph, false).length() != 0) return false;
90 return true;
91 }
92
make(Graph graph,tvm::Array<Tensor> inputs,std::string target)93 GraphKey GraphKeyNode::make(Graph graph,
94 tvm::Array<Tensor> inputs,
95 std::string target) {
96 auto n = tvm::make_node<GraphKeyNode>();
97 n->graph = std::move(graph);
98 n->inputs = inputs;
99 n->target = std::move(target);
100 return GraphKey(n);
101 }
102
103 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon3b52ee9d0102(const ObjectRef& ref, IRPrinter* p) 104 .set_dispatch<GraphKeyNode>([](const ObjectRef& ref, IRPrinter* p) {
105 auto* op = static_cast<const GraphKeyNode*>(ref.get());
106 p->stream << "GraphKeyNode("<< op << ")";
107 });
108
109
110 // Run graph hash
GraphHash(const Graph & graph)111 size_t GraphHash(const Graph& graph) {
112 const IndexedGraph& idx = graph.indexed_graph();
113 size_t key = 0;
114 // Combine a linearized sequence of ops in subgraph
115 key = dmlc::HashCombine(key, idx.num_nodes());
116 std::hash<std::string> str_hash;
117 std::vector<size_t> hash_temp;
118 for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
119 const IndexedGraph::Node& inode = idx[nid];
120 // Use name instad op address so it is deterministic across runs
121 if (inode.source->is_variable()) continue;
122 key = dmlc::HashCombine(key, inode.source->op()->name);
123 hash_temp.clear();
124 for (const auto& kv : GetAttrDict(inode.source->attrs)) {
125 hash_temp.push_back(dmlc::HashCombine(str_hash(kv.first), kv.second));
126 }
127 // to make sure it is deterministic
128 // since unordered_map is not deterministic
129 std::sort(hash_temp.begin(), hash_temp.end());
130 for (size_t value : hash_temp) {
131 key = dmlc::HashCombine(key, value);
132 }
133 }
134 return key;
135 }
136
137 // deep compare the graph structure
138 // not considering the graph attributes
139 // return non-empty error message if the graph mismatch.
140 // the comparator won't match name of intermediate node.
141 // compare_var_attr
GraphDeepCompare(const Graph & a,const Graph & b,bool compare_variable_attr)142 std::string GraphDeepCompare(const Graph& a,
143 const Graph& b,
144 bool compare_variable_attr) {
145 const IndexedGraph& idxa = a.indexed_graph();
146 const IndexedGraph& idxb = b.indexed_graph();
147 std::ostringstream err;
148 if (idxa.num_nodes() != idxb.num_nodes()) {
149 err << "Number of nodes mismatch (" << idxa.num_nodes() << " v.s " << idxb.num_nodes() << ")";
150 return err.str();
151 }
152 if (idxa.num_node_entries() != idxb.num_node_entries()) {
153 err << "Number of node entry mismatch";
154 return err.str();
155 }
156 if (idxa.outputs().size() != idxb.outputs().size()) {
157 err << "Number of outputs mismatch";
158 return err.str();
159 }
160 for (size_t i = 0; i < idxa.outputs().size(); ++i) {
161 if (idxa.outputs()[i].node_id != idxb.outputs()[i].node_id ||
162 idxa.outputs()[i].index != idxb.outputs()[i].index) {
163 err << "Output entry mismatch";
164 return err.str();
165 }
166 }
167 if (idxa.input_nodes().size() != idxb.input_nodes().size()) {
168 err << "Number of inputs mismatch";
169 return err.str();
170 }
171
172 for (uint32_t nid = 0; nid < idxa.num_nodes(); ++nid) {
173 const IndexedGraph::Node& anode = idxa[nid];
174 const IndexedGraph::Node& bnode = idxb[nid];
175 if (anode.source->op() != bnode.source->op()) {
176 err << "Node mismatch ";
177 return err.str();
178 }
179 if (anode.source->is_variable()) {
180 CHECK(bnode.source->is_variable());
181 if (!compare_variable_attr) continue;
182 }
183 AttrDict adict = GetAttrDict(anode.source->attrs);
184 AttrDict bdict = GetAttrDict(bnode.source->attrs);
185
186 auto fmatch = [&err, &anode](const AttrDict& adict, const AttrDict& bdict) {
187 for (const auto& kv : adict) {
188 auto it = bdict.find(kv.first);
189 if (it != bdict.end()) {
190 if (it->second != kv.second) {
191 err << "Node attr mismatch, op=" << anode.source->attrs.name
192 << " attr_key=" << kv.first << " " << it->second
193 << " v.s. " << kv.second;
194 return false;
195 }
196 } else {
197 err << "One attr_key=" << kv.first << " is missing in another "
198 << "op=" << anode.source->attrs.name;
199 return false;
200 }
201 }
202 return true;
203 };
204 if (!fmatch(adict, bdict)) return err.str();
205 if (adict.size() != bdict.size()) {
206 CHECK(!fmatch(bdict, adict));
207 return err.str();
208 }
209 if (anode.inputs.size() != bnode.inputs.size()) {
210 err << "Node input mismatch, op=" << anode.source->attrs.name;
211 return err.str();
212 }
213 if (anode.control_deps.size() != bnode.control_deps.size()) {
214 err << "Node control_deps mistach, op=" << anode.source->attrs.name;
215 return err.str();
216 }
217 for (size_t i = 0; i < anode.inputs.size(); ++i) {
218 const IndexedGraph::NodeEntry& ae = anode.inputs[i];
219 const IndexedGraph::NodeEntry& be = bnode.inputs[i];
220 if (ae.node_id != be.node_id ||
221 ae.index != be.index ||
222 ae.version != be.version) {
223 err << "Node input mismatch on, op=" << anode.source->attrs.name;
224 return err.str();
225 }
226 }
227 for (size_t i = 0; i < anode.control_deps.size(); ++i) {
228 if (anode.control_deps[i] != bnode.control_deps[i]) {
229 err << "Node control_dep mismatch on, op=" << anode.source->attrs.name;
230 return err.str();
231 }
232 }
233 }
234 return "";
235 }
236
237 TVM_REGISTER_GLOBAL("nnvm.graph.DeepCompare")
238 .set_body_typed(GraphDeepCompare);
239 } // namespace compiler
240 } // namespace nnvm
241