1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file graph_deep_compare.cc
 * \brief Hashing and deep structural comparison of graphs and graph keys.
23  */
#include <dmlc/common.h>
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/packed_func_ext.h>
#include <tvm/ir.h>
#include <tvm/runtime/packed_func.h>
#include <functional>
#include <vector>
#include <utility>
#include <algorithm>
#include <sstream>
#include <string>
#include "node_attr.h"
#include "graph_hash.h"
36 
37 namespace nnvm {
38 namespace compiler {
39 
40 using namespace tvm;
41 using tvm::ir::IntImm;
42 
HashPlaceHolder(const Tensor & t)43 size_t HashPlaceHolder(const Tensor& t) {
44   size_t key = t->shape.size();
45   key = dmlc::HashCombine(key, (t->dtype.code() << 8) | t->dtype.bits());
46   for (Expr s : t->shape) {
47     if (const IntImm* op = s.as<IntImm>()) {
48       key = dmlc::HashCombine(key, op->value);
49     }
50   }
51   return key;
52 }
53 
PlaceHolderEqual(const Tensor & a,const Tensor & b)54 bool PlaceHolderEqual(const Tensor& a, const Tensor& b) {
55   if (a->shape.size() != b->shape.size()) return false;
56   if (a->dtype != b->dtype) return false;
57   for (size_t i = 0; i < a->shape.size(); ++i) {
58     const IntImm* a_value = a->shape[i].as<IntImm>();
59     const IntImm* b_value = b->shape[i].as<IntImm>();
60     if (a_value && b_value == nullptr) return false;
61     if (b_value && a_value == nullptr) return false;
62     if (a_value == nullptr && b_value == nullptr) {
63       continue;
64     }
65     if (a_value->value != b_value->value) return false;
66   }
67   return true;
68 }
69 
Hash(const GraphKey & gkey)70 size_t GraphKeyHash::Hash(const GraphKey& gkey)  {
71   if (gkey->cache_hash_key_ != 0) return gkey->cache_hash_key_;
72   size_t key = dmlc::HashCombine(GraphHash(gkey->graph), gkey->target);
73   key = dmlc::HashCombine(key, gkey->inputs.size());
74   for (size_t i = 0; i < gkey->inputs.size(); ++i) {
75     key = dmlc::HashCombine(key, HashPlaceHolder(gkey->inputs[i]));
76   }
77   if (key == 0) key = 1;
78   gkey->cache_hash_key_ = key;
79   return key;
80 }
81 
Equal(const GraphKey & a,const GraphKey & b)82 bool GraphKeyEqual::Equal(const GraphKey& a,
83                           const GraphKey& b) {
84   if (a->target != b->target) return false;
85   if (a->inputs.size() != b->inputs.size()) return false;
86   for (size_t i = 0; i < a->inputs.size(); ++i) {
87     if (!PlaceHolderEqual(a->inputs[i], b->inputs[i])) return false;
88   }
89   if (GraphDeepCompare(a->graph, b->graph, false).length() != 0) return false;
90   return true;
91 }
92 
make(Graph graph,tvm::Array<Tensor> inputs,std::string target)93 GraphKey GraphKeyNode::make(Graph graph,
94                             tvm::Array<Tensor> inputs,
95                             std::string target) {
96   auto n = tvm::make_node<GraphKeyNode>();
97   n->graph = std::move(graph);
98   n->inputs = inputs;
99   n->target = std::move(target);
100   return GraphKey(n);
101 }
102 
103 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon3b52ee9d0102(const ObjectRef& ref, IRPrinter* p) 104 .set_dispatch<GraphKeyNode>([](const ObjectRef& ref, IRPrinter* p) {
105     auto* op = static_cast<const GraphKeyNode*>(ref.get());
106     p->stream << "GraphKeyNode("<< op << ")";
107 });
108 
109 
110 // Run graph hash
GraphHash(const Graph & graph)111 size_t GraphHash(const Graph& graph) {
112   const IndexedGraph& idx = graph.indexed_graph();
113   size_t key = 0;
114   // Combine a linearized sequence of ops in subgraph
115   key = dmlc::HashCombine(key, idx.num_nodes());
116   std::hash<std::string> str_hash;
117   std::vector<size_t> hash_temp;
118   for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
119     const IndexedGraph::Node& inode = idx[nid];
120     // Use name instad op address so it is deterministic across runs
121     if (inode.source->is_variable()) continue;
122     key = dmlc::HashCombine(key, inode.source->op()->name);
123     hash_temp.clear();
124     for (const auto& kv : GetAttrDict(inode.source->attrs)) {
125       hash_temp.push_back(dmlc::HashCombine(str_hash(kv.first), kv.second));
126     }
127     // to make sure it is deterministic
128     // since unordered_map is not deterministic
129     std::sort(hash_temp.begin(), hash_temp.end());
130     for (size_t value : hash_temp) {
131       key = dmlc::HashCombine(key, value);
132     }
133   }
134   return key;
135 }
136 
137 // deep compare the graph structure
138 // not considering the graph attributes
139 // return non-empty error message if the graph mismatch.
140 // the comparator won't match name of intermediate node.
141 // compare_var_attr
GraphDeepCompare(const Graph & a,const Graph & b,bool compare_variable_attr)142 std::string GraphDeepCompare(const Graph& a,
143                              const Graph& b,
144                              bool compare_variable_attr) {
145   const IndexedGraph& idxa = a.indexed_graph();
146   const IndexedGraph& idxb = b.indexed_graph();
147   std::ostringstream err;
148   if (idxa.num_nodes() != idxb.num_nodes()) {
149     err << "Number of nodes mismatch (" <<  idxa.num_nodes() << " v.s " << idxb.num_nodes() << ")";
150     return err.str();
151   }
152   if (idxa.num_node_entries() != idxb.num_node_entries()) {
153     err << "Number of node entry mismatch";
154     return err.str();
155   }
156   if (idxa.outputs().size() != idxb.outputs().size()) {
157     err << "Number of outputs mismatch";
158     return err.str();
159   }
160   for (size_t i = 0; i < idxa.outputs().size(); ++i) {
161     if (idxa.outputs()[i].node_id != idxb.outputs()[i].node_id ||
162         idxa.outputs()[i].index != idxb.outputs()[i].index) {
163       err << "Output entry mismatch";
164       return err.str();
165     }
166   }
167   if (idxa.input_nodes().size() != idxb.input_nodes().size()) {
168     err << "Number of inputs mismatch";
169     return err.str();
170   }
171 
172   for (uint32_t nid = 0; nid < idxa.num_nodes(); ++nid) {
173     const IndexedGraph::Node& anode = idxa[nid];
174     const IndexedGraph::Node& bnode = idxb[nid];
175     if (anode.source->op() != bnode.source->op()) {
176       err << "Node mismatch ";
177       return err.str();
178     }
179     if (anode.source->is_variable()) {
180       CHECK(bnode.source->is_variable());
181       if (!compare_variable_attr) continue;
182     }
183     AttrDict adict = GetAttrDict(anode.source->attrs);
184     AttrDict bdict = GetAttrDict(bnode.source->attrs);
185 
186     auto fmatch = [&err, &anode](const AttrDict& adict, const AttrDict& bdict) {
187       for (const auto& kv : adict) {
188         auto it = bdict.find(kv.first);
189         if (it != bdict.end()) {
190           if (it->second != kv.second) {
191             err << "Node attr mismatch, op=" << anode.source->attrs.name
192                 << " attr_key=" << kv.first << " " << it->second
193                 << " v.s. " << kv.second;
194             return false;
195           }
196         } else {
197           err << "One attr_key=" << kv.first << " is missing in another "
198                << "op=" << anode.source->attrs.name;
199           return false;
200         }
201       }
202       return true;
203     };
204     if (!fmatch(adict, bdict)) return err.str();
205     if (adict.size() != bdict.size()) {
206       CHECK(!fmatch(bdict, adict));
207       return err.str();
208     }
209     if (anode.inputs.size() != bnode.inputs.size()) {
210       err << "Node input mismatch, op=" << anode.source->attrs.name;
211       return err.str();
212     }
213     if (anode.control_deps.size() != bnode.control_deps.size()) {
214       err << "Node control_deps mistach, op=" << anode.source->attrs.name;
215       return err.str();
216     }
217     for (size_t i = 0; i < anode.inputs.size(); ++i) {
218       const IndexedGraph::NodeEntry& ae = anode.inputs[i];
219       const IndexedGraph::NodeEntry& be = bnode.inputs[i];
220       if (ae.node_id != be.node_id ||
221           ae.index != be.index ||
222           ae.version != be.version) {
223         err << "Node input mismatch on, op=" << anode.source->attrs.name;
224         return err.str();
225       }
226     }
227     for (size_t i = 0; i < anode.control_deps.size(); ++i) {
228       if (anode.control_deps[i] != bnode.control_deps[i]) {
229         err << "Node control_dep mismatch on, op=" << anode.source->attrs.name;
230         return err.str();
231       }
232     }
233   }
234   return "";
235 }
236 
237 TVM_REGISTER_GLOBAL("nnvm.graph.DeepCompare")
238 .set_body_typed(GraphDeepCompare);
239 }  // namespace compiler
240 }  // namespace nnvm
241