1 // This file is part of OpenCV project. 2 // It is subject to the license terms in the LICENSE file found in the top-level directory 3 // of this distribution and at http://opencv.org/license.html. 4 5 // Copyright (C) 2020, Intel Corporation, all rights reserved. 6 // Third party copyrights are property of their respective owners. 7 8 #ifndef __OPENCV_DNN_GRAPH_SIMPLIFIER_HPP__ 9 #define __OPENCV_DNN_GRAPH_SIMPLIFIER_HPP__ 10 11 #include <string> 12 13 #include <opencv2/core.hpp> 14 15 namespace cv { namespace dnn { 16 17 class ImportNodeWrapper 18 { 19 public: ~ImportNodeWrapper()20 virtual ~ImportNodeWrapper() {}; 21 22 virtual int getNumInputs() const = 0; 23 24 virtual std::string getInputName(int idx) const = 0; 25 26 virtual std::string getType() const = 0; 27 28 virtual void setType(const std::string& type) = 0; 29 30 virtual void setInputNames(const std::vector<std::string>& inputs) = 0; 31 }; 32 33 class ImportGraphWrapper 34 { 35 public: ~ImportGraphWrapper()36 virtual ~ImportGraphWrapper() {}; 37 38 virtual Ptr<ImportNodeWrapper> getNode(int idx) const = 0; 39 40 virtual int getNumNodes() const = 0; 41 42 virtual int getNumOutputs(int nodeId) const = 0; 43 44 virtual std::string getOutputName(int nodeId, int outId) const = 0; 45 46 virtual void removeNode(int idx) = 0; 47 }; 48 49 class Subgraph // Interface to match and replace subgraphs. 50 { 51 public: 52 virtual ~Subgraph(); 53 54 // Add a node to be matched in the origin graph. Specify ids of nodes that 55 // are expected to be inputs. Returns id of a newly added node. 56 // TODO: Replace inputs to std::vector<int> in C++11 57 int addNodeToMatch(const std::string& op, int input_0 = -1, int input_1 = -1, 58 int input_2 = -1, int input_3 = -1); 59 60 int addNodeToMatch(const std::string& op, const std::vector<int>& inputs_); 61 62 // Specify resulting node. All the matched nodes in subgraph excluding 63 // input nodes will be fused into this single node. 64 // TODO: Replace inputs to std::vector<int> in C++11 65 void setFusedNode(const std::string& op, int input_0 = -1, int input_1 = -1, 66 int input_2 = -1, int input_3 = -1, int input_4 = -1, 67 int input_5 = -1); 68 69 void setFusedNode(const std::string& op, const std::vector<int>& inputs_); 70 71 static int getInputNodeId(const Ptr<ImportGraphWrapper>& net, 72 const Ptr<ImportNodeWrapper>& node, 73 int inpId); 74 75 // Match TensorFlow subgraph starting from <nodeId> with a set of nodes to be fused. 76 // Const nodes are skipped during matching. Returns true if nodes are matched and can be fused. 77 virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId, 78 std::vector<int>& matchedNodesIds, 79 std::vector<int>& targetNodesIds); 80 81 // Fuse matched subgraph. 82 void replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int>& matchedNodesIds, 83 const std::vector<int>& targetNodesIds); 84 85 virtual void finalize(const Ptr<ImportGraphWrapper>& net, 86 const Ptr<ImportNodeWrapper>& fusedNode, 87 std::vector<Ptr<ImportNodeWrapper> >& inputs); 88 89 private: 90 std::vector<std::string> nodes; // Nodes to be matched in the origin graph. 91 std::vector<std::vector<int> > inputs; // Connections of an every node to it's inputs. 92 93 std::string fusedNodeOp; // Operation name of resulting fused node. 94 std::vector<int> fusedNodeInputs; // Inputs of fused node. 95 }; 96 97 void simplifySubgraphs(const Ptr<ImportGraphWrapper>& net, 98 const std::vector<Ptr<Subgraph> >& patterns); 99 100 }} // namespace dnn, namespace cv 101 102 #endif // __OPENCV_DNN_GRAPH_SIMPLIFIER_HPP__ 103