1 //
2 //  TFGraphResolver.cpp
3 //  MNNConverter
4 //
5 //  Created by MNN on 2020/06/13.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "TFGraphResolver.hpp"
10 #include "TFGraphResolverHelpers.hpp"
11 
12 #include <vector>
13 #include <queue>
14 #include <unordered_map>
15 #include <unordered_set>
16 
17 #include "graph.pb.h"
18 #include "TmpGraph.hpp"
19 #include "tfOpConverter.hpp"
20 #include "MNN_generated.h"
21 #include "../compression/quantization.hpp"
22 #include <flatbuffers/util.h>
23 
AddNode(const NodeDef * node)24 void TFGraph::AddNode(const NodeDef* node) {
25     std::unique_ptr<TFNode> tf_node(new TFNode);
26     tf_node->node_def = node;
27     tf_node->name = node->name();
28     tf_node->op = node->op();
29     nodes_.push_back(std::move(tf_node));
30 }
31 
Finalize()32 void TFGraph::Finalize() {
33     std::unordered_map<std::string, TFNode*> nodes;
34     for (auto& node : nodes_) {
35         nodes.emplace(node->name, node.get());
36     }
37     for (auto& node : nodes_) {
38         const NodeDef* node_def = node->node_def;
39         for (int i = 0; i < node_def->input_size(); ++i) {
40             const std::string& input = node_def->input(i);
41             if (IsControlInput(input)) {
42                 continue;
43             }
44             std::string input_op = input;
45             auto splits = RSplitString(input, ":");
46             if (splits.size() == 2) {
47                 input_op = splits.at(0);
48             }
49             TFNode* start = nodes.at(input_op);
50             std::unique_ptr<TFEdge> edge(new TFEdge);
51             *edge = TFEdge{input, start, node.get()};
52             node->inputs.push_back(edge.get());
53             start->outputs.push_back(edge.get());
54             edges_.push_back(std::move(edge));
55         }
56     }
57     for (auto& node : nodes_) {
58         if (node->outputs.empty()) {
59             final_nodes_.push_back(node.get());
60         }
61     }
62 }
63 
ToProto() const64 std::unique_ptr<MNN::SubGraphProtoT> TFGraph::ToProto() const {
65     std::unique_ptr<MNN::SubGraphProtoT> graph_proto(new MNN::SubGraphProtoT);
66     graph_proto->name = name_;
67     std::vector<const TFNode*> entry_nodes;
68 
69     std::unordered_map<std::string, int> tensor_indices;
70     // Add normal nodes.
71     for (int i = 0; i < nodes_.size(); ++i) {
72         TFNode* node = nodes_[i].get();
73         std::shared_ptr<TmpNode> tempNode(new TmpNode());
74         tempNode->opName = node->name;
75         tempNode->opType = node->op;
76         tempNode->tfNode = node->node_def;
77 
78         MNN::OpT *op = new MNN::OpT;
79         auto creator = tfOpConverterSuit::get()->search(tempNode->opType);
80         DCHECK(creator) << "MNN Converter NOT_SUPPORTED_OP: [ "
81                         << tempNode->opType << " ]";
82         op->name = tempNode->opName;
83         op->type = creator->opType();
84         op->main.type = creator->type();
85 
86         // resize the inputIndexes and outputIndexes
87         int input_size = node->inputs.size();
88         op->inputIndexes.resize(input_size);
89 
90         // -1 is placeholder value, and the number of -1 is the number of
91         // output tensors.
92         // defalut: every op output one tensor, if the number of the output
93         // tensors is bigger than 1, set the outputIndexes in the op
94         // converter(void run(MNN::OpT *dstOp, TmpNode *srcNode))
95         op->outputIndexes = {-1};
96         creator->run(op, tempNode.get());
97 
98         for (int j = 0; j < input_size; j++) {
99             std::string input = node->inputs[j]->name;
100             auto it = tensor_indices.find(input);
101             if (it == tensor_indices.end()) {
102                 int index = tensor_indices.size();
103                 it = tensor_indices.emplace(input, index).first;
104                 graph_proto->tensors.push_back(input);
105             }
106             op->inputIndexes[j] = it->second;
107         }
108 
109         int output_size = node->outputs.size();
110         for (int j = 0; j < node->outputs.size(); ++j) {
111             std::string output = node->outputs[j]->name;
112             auto it = tensor_indices.find(output);
113             if (it == tensor_indices.end()) {
114                 int index = tensor_indices.size();
115                 it = tensor_indices.emplace(output, index).first;
116                 graph_proto->tensors.push_back(output);
117             }
118             int index = 0;
119             auto splits = RSplitString(output, ":");
120             if (splits.size() == 2) {
121                 index = atoi(splits[1].c_str());
122             }
123             if (op->outputIndexes.size() <= index) {
124                 int origin_size = op->outputIndexes.size();
125                 op->outputIndexes.resize(index + 1);
126                 for (int p = origin_size; p <= index; ++p) {
127                     op->outputIndexes[p] = -1;
128                 }
129             }
130             op->outputIndexes[index] = it->second;
131         }
132         graph_proto->nodes.emplace_back(op);
133     }
134 
135     for (auto &op : graph_proto->nodes) {
136         for (int i = 0; i < op->outputIndexes.size(); ++i) {
137             if (op->outputIndexes[i] == -1) {
138                 int index = graph_proto->tensors.size();
139                 op->outputIndexes[i] = index;
140                 std::string output = op->name;
141                 if (i != 0) {
142                     output += ":" + flatbuffers::NumToString(i);
143                 }
144                 graph_proto->tensors.emplace_back(output);
145             }
146         }
147     }
148     return std::move(graph_proto);
149 }
150 
BuildEdge(const std::string & name,TFNode * start,TFNode * end)151 std::unique_ptr<TFEdge> TFGraphResolver::BuildEdge(
152         const std::string& name, TFNode* start, TFNode* end) {
153     std::unique_ptr<TFEdge> edge(new TFEdge);
154     *edge = TFEdge{name, start, end};
155     return std::move(edge);
156 }
157 
BuildQuantOrDequantNode(const std::string & name,const std::string & op,const int & nbit,const std::vector<float> & scales,const float & zero_point,const float & clamp_min,const float & clamp_max,const MNN::Compression::LayerQuantizeParams_QuantMethod & method)158 std::unique_ptr<TFNode> TFGraphResolver::BuildQuantOrDequantNode(
159                             const std::string& name,
160                             const std::string& op,
161                             const int& nbit,
162                             const std::vector<float>& scales,
163                             const float& zero_point, const float& clamp_min, const float& clamp_max,
164                             const MNN::Compression::LayerQuantizeParams_QuantMethod& method) {
165     std::unique_ptr<NodeDef> node_def(new NodeDef);
166     *(node_def->mutable_name()) = name;
167     *(node_def->mutable_op()) = op;
168     (*node_def->mutable_attr())["nbit"].set_i(nbit);
169     auto* list = (*node_def->mutable_attr())["scale"].mutable_list();
170     for (int i = 0; i < scales.size(); ++i) {
171         if (op == "CustomQuantize") {
172             list->mutable_f()->Add(1.f / scales[i]);
173         } else {
174             list->mutable_f()->Add(scales[i]);
175         }
176     }
177     (*node_def->mutable_attr())["zero_point"].set_f(zero_point);
178     (*node_def->mutable_attr())["clamp_min"].set_f(clamp_min);
179     (*node_def->mutable_attr())["clamp_max"].set_f(clamp_max);
180     (*node_def->mutable_attr())["method"].set_i(int(method));
181     std::unique_ptr<TFNode> quant_node(new TFNode);
182     quant_node->name = name;
183     quant_node->op = op;
184     quant_node->node_def = node_def.get();
185 
186     main_graph()->allocated_nodes_.push_back(std::move(node_def));
187     return std::move(quant_node);
188 }
189 
ResolveQuantization(TFGraph * graph,const compression::Quantization & int8_calibration)190 void TFGraphResolver::ResolveQuantization(
191         TFGraph* graph,
192         const compression::Quantization& int8_calibration) {
193     std::vector<std::unique_ptr<TFNode>> append_nodes;
194     std::vector<std::unique_ptr<TFEdge>> append_edges;
195 
196     static int64_t uuid = 0;
197     auto AddQuantizeAndDequantizeNodes =
198             [&, this](const std::vector<TFEdge*> edges,
199                       const compression::Quantization::TensorParams& params) {
200         TFNode* start_node = edges.at(0)->start;
201         for (TFEdge* edge : edges) {
202             EraseOutput(start_node, edge);
203         }
204         auto splits = RSplitString(edges.at(0)->name, ":");
205         const std::string& op_name = splits.at(0);
206         // Add quantize node.
207         std::string quant_name = op_name + "_quant_" + flatbuffers::NumToString(uuid);
208         std::unique_ptr<TFNode> quant_node = BuildQuantOrDequantNode(
209             quant_name, "CustomQuantize", params.nbit, params.scale,
210             params.zero_point, params.clamp_min, params.clamp_max, params.method);
211         // Add dequantize node.
212         std::string dequant_name = quant_name + "_dequant_" + flatbuffers::NumToString(uuid);
213         std::unique_ptr<TFNode> dequant_node = BuildQuantOrDequantNode(
214             dequant_name, "CustomDequantize", params.nbit, params.scale,
215             params.zero_point, params.clamp_min, params.clamp_max, params.method);
216 
217         // Update UUID.
218         ++uuid;
219 
220         // Connect quantize and dequantize node.
221         std::unique_ptr<TFEdge> quant_edge =
222             BuildEdge(edges.at(0)->name, start_node, quant_node.get());
223         // Connect dequantize and the next node.
224         std::unique_ptr<TFEdge> dequant_edge =
225             BuildEdge(quant_node->name, quant_node.get(), dequant_node.get());
226 
227         AddOutput(start_node, quant_edge.get());
228 
229         quant_node->inputs = {quant_edge.get()};
230         quant_node->outputs = {dequant_edge.get()};
231         dequant_node->inputs = {dequant_edge.get()};
232         dequant_node->outputs = edges;
233         for (TFEdge* edge : edges) {
234             edge->name = dequant_node->name;
235             edge->start = dequant_node.get();
236         }
237         append_nodes.push_back(std::move(quant_node));
238         append_nodes.push_back(std::move(dequant_node));
239         append_edges.push_back(std::move(quant_edge));
240         append_edges.push_back(std::move(dequant_edge));
241 
242         // Return dequant edge.
243         return append_edges.back().get();
244     };
245 
246     const auto& tensor_params = int8_calibration.tensors;
247     for (auto& node : graph->nodes_) {
248         std::unordered_map<std::string, std::vector<TFEdge*>> quant_edges;
249         for (TFEdge* output : node->outputs) {
250             std::string tensor_name = output->name;
251             if (node->op == "Enter" || node->op == "Switch") {
252                 // The input names of the node maybe replaced by the quantize
253                 // and dequantize op, so here we use the input name from the
254                 // `node_def` since it should not be modified at any time.
255                 // tensor_name = node->inputs.at(0)->name;
256                 tensor_name = node->node_def->input(0);
257             }
258             quant_edges[tensor_name].push_back(output);
259         }
260         for (const auto& it : quant_edges) {
261             auto p = tensor_params.find(it.first);
262             if (p == tensor_params.end()) {
263                 continue;
264             }
265             const auto& params = p->second.at(0);
266             AddQuantizeAndDequantizeNodes(it.second, params);
267         }
268     }
269     for (auto& node : graph->nodes_) {
270         std::unordered_map<std::string, std::vector<TFEdge*>> quant_edges;
271         for (int i = 0; i < node->inputs.size(); ++i) {
272             TFEdge* edge = node->inputs[i];
273             quant_edges[edge->name].push_back(edge);
274         }
275         for (const auto& it : quant_edges) {
276             auto p = tensor_params.find(it.first);
277             if (p == tensor_params.end()) {
278                 continue;
279             }
280             const auto& params = p->second.at(0);
281             AddQuantizeAndDequantizeNodes(it.second, params);
282         }
283     }
284     // Append nodes and edges to root graph.
285     for (auto& node : append_nodes) {
286         main_graph()->nodes_.push_back(std::move(node));
287     }
288     for (auto& edge : append_edges) {
289         main_graph()->edges_.push_back(std::move(edge));
290     }
291 }
292 
TFGraphResolver(const tensorflow::GraphDef & graph_def)293 TFGraphResolver::TFGraphResolver(const tensorflow::GraphDef& graph_def) {
294     std::unique_ptr<TFGraph> tf_graph(new TFGraph);
295     const int count = graph_def.node_size();
296     for (int i = 0; i < count; ++i) {
297         const NodeDef& node_def = graph_def.node(i);
298         tf_graph->AddNode(&node_def);
299     }
300     tf_graph->Finalize();
301     graphs_.push_back(std::move(tf_graph));
302 
303     TFGraph* main_graph = graphs_.back().get();
304 }
305 
graph(const int graph_index) const306 const TFGraph* TFGraphResolver::graph(const int graph_index) const {
307     return graphs_.at(graph_index).get();
308 }
309 
graph(const int graph_index)310 TFGraph* TFGraphResolver::graph(const int graph_index) {
311     return graphs_.at(graph_index).get();
312 }
313 
main_graph()314 TFGraph* TFGraphResolver::main_graph() {
315     return this->graph(0);
316 }
317