1 // 2 // TmpGraph.hpp 3 // MNNConverter 4 // 5 // Created by MNN on 2019/01/31. 6 // Copyright © 2018, Alibaba Group Holding Limited 7 // 8 9 #ifndef TMPGRAPH_HPP 10 #define TMPGRAPH_HPP 11 12 #include <iostream> 13 #include <map> 14 #include <vector> 15 16 #include "graph.pb.h" 17 18 class TmpNode { 19 public: 20 TmpNode(); 21 ~TmpNode(); 22 23 public: 24 std::string opName; 25 std::string opType; 26 27 const tensorflow::NodeDef *tfNode; 28 29 std::vector<std::string> inEdges; // node 30 std::vector<std::string> outEdges; // node 31 32 std::vector<std::string> inTensors; // tensor names 33 std::vector<std::string> outTensors; // tensor names 34 35 std::string future; 36 37 bool isCovered; 38 bool isDelete; 39 int leftInEdges; 40 std::string DebugString() const; 41 }; 42 43 class TmpGraph { 44 public: 45 TmpGraph(const tensorflow::GraphDef &tfGraph); 46 ~TmpGraph(); 47 48 public: 49 tensorflow::GraphDef _tfGraph; 50 51 std::vector<TmpNode *> tmpNodes; 52 std::map<std::string, TmpNode *> tmpNodeMap; // nodeName, TmpNode* 53 54 // constant nodes which have no input 55 std::vector<std::string> inputNodes; 56 std::vector<std::string> outputNodes; 57 58 std::vector<std::string> opsInOrder; 59 60 public: 61 int buildGraph(); // build the min Graph 62 TmpNode *_getTmpNode(const std::string &nodeName); 63 64 private: 65 TmpGraph(); 66 bool _allOpSupported(); 67 int _setInOutTensorsName(TmpNode *parentNode, TmpNode *curNode, std::string inputName); 68 69 int _setOuputTensorsName(std::vector<std::string> &tensorVector, std::string inputName, int index); 70 71 int _makeConnection(TmpNode *srcNode, TmpNode *dstNode, const std::string srcName, const std::string dstName); 72 73 void _genMinGraph(); 74 void _changInOutName(std::vector<std::string> &inOutEdges, std::string name, std::string deleteName); 75 76 int _getOpsInorder(const std::vector<std::string> inputNodes); 77 78 void _getInputNodes(); 79 80 int _pushNoReaptedItem(std::vector<std::string> &tensorNames, const std::string item); 81 void _getTmpNodeMapAndConnection(); 82 int _optimizeTfModel(); 83 bool _hasContinuousConstantNode(); 84 }; 85 86 #endif // TMPGRAPH_HPP 87