1 //
2 // TFGraphResolver.cpp
3 // MNNConverter
4 //
5 // Created by MNN on 2020/06/13.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
9 #include "TFGraphResolver.hpp"
10 #include "TFGraphResolverHelpers.hpp"
11
12 #include <vector>
13 #include <queue>
14 #include <unordered_map>
15 #include <unordered_set>
16
17 #include "graph.pb.h"
18 #include "TmpGraph.hpp"
19 #include "tfOpConverter.hpp"
20 #include "MNN_generated.h"
21 #include "../compression/quantization.hpp"
22 #include <flatbuffers/util.h>
23
AddNode(const NodeDef * node)24 void TFGraph::AddNode(const NodeDef* node) {
25 std::unique_ptr<TFNode> tf_node(new TFNode);
26 tf_node->node_def = node;
27 tf_node->name = node->name();
28 tf_node->op = node->op();
29 nodes_.push_back(std::move(tf_node));
30 }
31
Finalize()32 void TFGraph::Finalize() {
33 std::unordered_map<std::string, TFNode*> nodes;
34 for (auto& node : nodes_) {
35 nodes.emplace(node->name, node.get());
36 }
37 for (auto& node : nodes_) {
38 const NodeDef* node_def = node->node_def;
39 for (int i = 0; i < node_def->input_size(); ++i) {
40 const std::string& input = node_def->input(i);
41 if (IsControlInput(input)) {
42 continue;
43 }
44 std::string input_op = input;
45 auto splits = RSplitString(input, ":");
46 if (splits.size() == 2) {
47 input_op = splits.at(0);
48 }
49 TFNode* start = nodes.at(input_op);
50 std::unique_ptr<TFEdge> edge(new TFEdge);
51 *edge = TFEdge{input, start, node.get()};
52 node->inputs.push_back(edge.get());
53 start->outputs.push_back(edge.get());
54 edges_.push_back(std::move(edge));
55 }
56 }
57 for (auto& node : nodes_) {
58 if (node->outputs.empty()) {
59 final_nodes_.push_back(node.get());
60 }
61 }
62 }
63
ToProto() const64 std::unique_ptr<MNN::SubGraphProtoT> TFGraph::ToProto() const {
65 std::unique_ptr<MNN::SubGraphProtoT> graph_proto(new MNN::SubGraphProtoT);
66 graph_proto->name = name_;
67 std::vector<const TFNode*> entry_nodes;
68
69 std::unordered_map<std::string, int> tensor_indices;
70 // Add normal nodes.
71 for (int i = 0; i < nodes_.size(); ++i) {
72 TFNode* node = nodes_[i].get();
73 std::shared_ptr<TmpNode> tempNode(new TmpNode());
74 tempNode->opName = node->name;
75 tempNode->opType = node->op;
76 tempNode->tfNode = node->node_def;
77
78 MNN::OpT *op = new MNN::OpT;
79 auto creator = tfOpConverterSuit::get()->search(tempNode->opType);
80 DCHECK(creator) << "MNN Converter NOT_SUPPORTED_OP: [ "
81 << tempNode->opType << " ]";
82 op->name = tempNode->opName;
83 op->type = creator->opType();
84 op->main.type = creator->type();
85
86 // resize the inputIndexes and outputIndexes
87 int input_size = node->inputs.size();
88 op->inputIndexes.resize(input_size);
89
90 // -1 is placeholder value, and the number of -1 is the number of
91 // output tensors.
92 // defalut: every op output one tensor, if the number of the output
93 // tensors is bigger than 1, set the outputIndexes in the op
94 // converter(void run(MNN::OpT *dstOp, TmpNode *srcNode))
95 op->outputIndexes = {-1};
96 creator->run(op, tempNode.get());
97
98 for (int j = 0; j < input_size; j++) {
99 std::string input = node->inputs[j]->name;
100 auto it = tensor_indices.find(input);
101 if (it == tensor_indices.end()) {
102 int index = tensor_indices.size();
103 it = tensor_indices.emplace(input, index).first;
104 graph_proto->tensors.push_back(input);
105 }
106 op->inputIndexes[j] = it->second;
107 }
108
109 int output_size = node->outputs.size();
110 for (int j = 0; j < node->outputs.size(); ++j) {
111 std::string output = node->outputs[j]->name;
112 auto it = tensor_indices.find(output);
113 if (it == tensor_indices.end()) {
114 int index = tensor_indices.size();
115 it = tensor_indices.emplace(output, index).first;
116 graph_proto->tensors.push_back(output);
117 }
118 int index = 0;
119 auto splits = RSplitString(output, ":");
120 if (splits.size() == 2) {
121 index = atoi(splits[1].c_str());
122 }
123 if (op->outputIndexes.size() <= index) {
124 int origin_size = op->outputIndexes.size();
125 op->outputIndexes.resize(index + 1);
126 for (int p = origin_size; p <= index; ++p) {
127 op->outputIndexes[p] = -1;
128 }
129 }
130 op->outputIndexes[index] = it->second;
131 }
132 graph_proto->nodes.emplace_back(op);
133 }
134
135 for (auto &op : graph_proto->nodes) {
136 for (int i = 0; i < op->outputIndexes.size(); ++i) {
137 if (op->outputIndexes[i] == -1) {
138 int index = graph_proto->tensors.size();
139 op->outputIndexes[i] = index;
140 std::string output = op->name;
141 if (i != 0) {
142 output += ":" + flatbuffers::NumToString(i);
143 }
144 graph_proto->tensors.emplace_back(output);
145 }
146 }
147 }
148 return std::move(graph_proto);
149 }
150
BuildEdge(const std::string & name,TFNode * start,TFNode * end)151 std::unique_ptr<TFEdge> TFGraphResolver::BuildEdge(
152 const std::string& name, TFNode* start, TFNode* end) {
153 std::unique_ptr<TFEdge> edge(new TFEdge);
154 *edge = TFEdge{name, start, end};
155 return std::move(edge);
156 }
157
BuildQuantOrDequantNode(const std::string & name,const std::string & op,const int & nbit,const std::vector<float> & scales,const float & zero_point,const float & clamp_min,const float & clamp_max,const MNN::Compression::LayerQuantizeParams_QuantMethod & method)158 std::unique_ptr<TFNode> TFGraphResolver::BuildQuantOrDequantNode(
159 const std::string& name,
160 const std::string& op,
161 const int& nbit,
162 const std::vector<float>& scales,
163 const float& zero_point, const float& clamp_min, const float& clamp_max,
164 const MNN::Compression::LayerQuantizeParams_QuantMethod& method) {
165 std::unique_ptr<NodeDef> node_def(new NodeDef);
166 *(node_def->mutable_name()) = name;
167 *(node_def->mutable_op()) = op;
168 (*node_def->mutable_attr())["nbit"].set_i(nbit);
169 auto* list = (*node_def->mutable_attr())["scale"].mutable_list();
170 for (int i = 0; i < scales.size(); ++i) {
171 if (op == "CustomQuantize") {
172 list->mutable_f()->Add(1.f / scales[i]);
173 } else {
174 list->mutable_f()->Add(scales[i]);
175 }
176 }
177 (*node_def->mutable_attr())["zero_point"].set_f(zero_point);
178 (*node_def->mutable_attr())["clamp_min"].set_f(clamp_min);
179 (*node_def->mutable_attr())["clamp_max"].set_f(clamp_max);
180 (*node_def->mutable_attr())["method"].set_i(int(method));
181 std::unique_ptr<TFNode> quant_node(new TFNode);
182 quant_node->name = name;
183 quant_node->op = op;
184 quant_node->node_def = node_def.get();
185
186 main_graph()->allocated_nodes_.push_back(std::move(node_def));
187 return std::move(quant_node);
188 }
189
ResolveQuantization(TFGraph * graph,const compression::Quantization & int8_calibration)190 void TFGraphResolver::ResolveQuantization(
191 TFGraph* graph,
192 const compression::Quantization& int8_calibration) {
193 std::vector<std::unique_ptr<TFNode>> append_nodes;
194 std::vector<std::unique_ptr<TFEdge>> append_edges;
195
196 static int64_t uuid = 0;
197 auto AddQuantizeAndDequantizeNodes =
198 [&, this](const std::vector<TFEdge*> edges,
199 const compression::Quantization::TensorParams& params) {
200 TFNode* start_node = edges.at(0)->start;
201 for (TFEdge* edge : edges) {
202 EraseOutput(start_node, edge);
203 }
204 auto splits = RSplitString(edges.at(0)->name, ":");
205 const std::string& op_name = splits.at(0);
206 // Add quantize node.
207 std::string quant_name = op_name + "_quant_" + flatbuffers::NumToString(uuid);
208 std::unique_ptr<TFNode> quant_node = BuildQuantOrDequantNode(
209 quant_name, "CustomQuantize", params.nbit, params.scale,
210 params.zero_point, params.clamp_min, params.clamp_max, params.method);
211 // Add dequantize node.
212 std::string dequant_name = quant_name + "_dequant_" + flatbuffers::NumToString(uuid);
213 std::unique_ptr<TFNode> dequant_node = BuildQuantOrDequantNode(
214 dequant_name, "CustomDequantize", params.nbit, params.scale,
215 params.zero_point, params.clamp_min, params.clamp_max, params.method);
216
217 // Update UUID.
218 ++uuid;
219
220 // Connect quantize and dequantize node.
221 std::unique_ptr<TFEdge> quant_edge =
222 BuildEdge(edges.at(0)->name, start_node, quant_node.get());
223 // Connect dequantize and the next node.
224 std::unique_ptr<TFEdge> dequant_edge =
225 BuildEdge(quant_node->name, quant_node.get(), dequant_node.get());
226
227 AddOutput(start_node, quant_edge.get());
228
229 quant_node->inputs = {quant_edge.get()};
230 quant_node->outputs = {dequant_edge.get()};
231 dequant_node->inputs = {dequant_edge.get()};
232 dequant_node->outputs = edges;
233 for (TFEdge* edge : edges) {
234 edge->name = dequant_node->name;
235 edge->start = dequant_node.get();
236 }
237 append_nodes.push_back(std::move(quant_node));
238 append_nodes.push_back(std::move(dequant_node));
239 append_edges.push_back(std::move(quant_edge));
240 append_edges.push_back(std::move(dequant_edge));
241
242 // Return dequant edge.
243 return append_edges.back().get();
244 };
245
246 const auto& tensor_params = int8_calibration.tensors;
247 for (auto& node : graph->nodes_) {
248 std::unordered_map<std::string, std::vector<TFEdge*>> quant_edges;
249 for (TFEdge* output : node->outputs) {
250 std::string tensor_name = output->name;
251 if (node->op == "Enter" || node->op == "Switch") {
252 // The input names of the node maybe replaced by the quantize
253 // and dequantize op, so here we use the input name from the
254 // `node_def` since it should not be modified at any time.
255 // tensor_name = node->inputs.at(0)->name;
256 tensor_name = node->node_def->input(0);
257 }
258 quant_edges[tensor_name].push_back(output);
259 }
260 for (const auto& it : quant_edges) {
261 auto p = tensor_params.find(it.first);
262 if (p == tensor_params.end()) {
263 continue;
264 }
265 const auto& params = p->second.at(0);
266 AddQuantizeAndDequantizeNodes(it.second, params);
267 }
268 }
269 for (auto& node : graph->nodes_) {
270 std::unordered_map<std::string, std::vector<TFEdge*>> quant_edges;
271 for (int i = 0; i < node->inputs.size(); ++i) {
272 TFEdge* edge = node->inputs[i];
273 quant_edges[edge->name].push_back(edge);
274 }
275 for (const auto& it : quant_edges) {
276 auto p = tensor_params.find(it.first);
277 if (p == tensor_params.end()) {
278 continue;
279 }
280 const auto& params = p->second.at(0);
281 AddQuantizeAndDequantizeNodes(it.second, params);
282 }
283 }
284 // Append nodes and edges to root graph.
285 for (auto& node : append_nodes) {
286 main_graph()->nodes_.push_back(std::move(node));
287 }
288 for (auto& edge : append_edges) {
289 main_graph()->edges_.push_back(std::move(edge));
290 }
291 }
292
TFGraphResolver(const tensorflow::GraphDef & graph_def)293 TFGraphResolver::TFGraphResolver(const tensorflow::GraphDef& graph_def) {
294 std::unique_ptr<TFGraph> tf_graph(new TFGraph);
295 const int count = graph_def.node_size();
296 for (int i = 0; i < count; ++i) {
297 const NodeDef& node_def = graph_def.node(i);
298 tf_graph->AddNode(&node_def);
299 }
300 tf_graph->Finalize();
301 graphs_.push_back(std::move(tf_graph));
302
303 TFGraph* main_graph = graphs_.back().get();
304 }
305
graph(const int graph_index) const306 const TFGraph* TFGraphResolver::graph(const int graph_index) const {
307 return graphs_.at(graph_index).get();
308 }
309
graph(const int graph_index)310 TFGraph* TFGraphResolver::graph(const int graph_index) {
311 return graphs_.at(graph_index).get();
312 }
313
main_graph()314 TFGraph* TFGraphResolver::main_graph() {
315 return this->graph(0);
316 }
317