1 //
2 //  TmpGraph.hpp
3 //  MNNConverter
4 //
5 //  Created by MNN on 2019/01/31.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #ifndef TMPGRAPH_HPP
10 #define TMPGRAPH_HPP
11 
#include <iostream>
#include <map>
#include <string>
#include <vector>

#include "graph.pb.h"
17 
18 class TmpNode {
19 public:
20     TmpNode();
21     ~TmpNode();
22 
23 public:
24     std::string opName;
25     std::string opType;
26 
27     const tensorflow::NodeDef *tfNode;
28 
29     std::vector<std::string> inEdges;  // node
30     std::vector<std::string> outEdges; // node
31 
32     std::vector<std::string> inTensors;  // tensor names
33     std::vector<std::string> outTensors; // tensor names
34 
35     std::string future;
36 
37     bool isCovered;
38     bool isDelete;
39     int leftInEdges;
40     std::string DebugString() const;
41 };
42 
43 class TmpGraph {
44 public:
45     TmpGraph(const tensorflow::GraphDef &tfGraph);
46     ~TmpGraph();
47 
48 public:
49     tensorflow::GraphDef _tfGraph;
50 
51     std::vector<TmpNode *> tmpNodes;
52     std::map<std::string, TmpNode *> tmpNodeMap; // nodeName, TmpNode*
53 
54     // constant nodes which have no input
55     std::vector<std::string> inputNodes;
56     std::vector<std::string> outputNodes;
57 
58     std::vector<std::string> opsInOrder;
59 
60 public:
61     int buildGraph(); // build the min Graph
62     TmpNode *_getTmpNode(const std::string &nodeName);
63 
64 private:
65     TmpGraph();
66     bool _allOpSupported();
67     int _setInOutTensorsName(TmpNode *parentNode, TmpNode *curNode, std::string inputName);
68 
69     int _setOuputTensorsName(std::vector<std::string> &tensorVector, std::string inputName, int index);
70 
71     int _makeConnection(TmpNode *srcNode, TmpNode *dstNode, const std::string srcName, const std::string dstName);
72 
73     void _genMinGraph();
74     void _changInOutName(std::vector<std::string> &inOutEdges, std::string name, std::string deleteName);
75 
76     int _getOpsInorder(const std::vector<std::string> inputNodes);
77 
78     void _getInputNodes();
79 
80     int _pushNoReaptedItem(std::vector<std::string> &tensorNames, const std::string item);
81     void _getTmpNodeMapAndConnection();
82     int _optimizeTfModel();
83     bool _hasContinuousConstantNode();
84 };
85 
86 #endif // TMPGRAPH_HPP
87