1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
4 
5 // Copyright (C) 2020, Intel Corporation, all rights reserved.
6 // Third party copyrights are property of their respective owners.
7 
8 #ifndef __OPENCV_DNN_GRAPH_SIMPLIFIER_HPP__
9 #define __OPENCV_DNN_GRAPH_SIMPLIFIER_HPP__
10 
#include <string>
#include <vector>

#include <opencv2/core.hpp>
14 
15 namespace cv { namespace dnn {
16 
// Abstract, framework-independent view of a single node of an imported graph.
// The subgraph-fusion code (Subgraph, simplifySubgraphs) works only through
// this interface; concrete importers are expected to provide implementations.
class ImportNodeWrapper
{
public:
    // Virtual so that implementations can be destroyed through a base
    // pointer, e.g. when held in a Ptr<ImportNodeWrapper>.
    virtual ~ImportNodeWrapper() {}

    // Number of inputs of this node.
    virtual int getNumInputs() const = 0;

    // Name of the input at position idx.
    // NOTE(review): presumably requires 0 <= idx < getNumInputs(); this
    // interface does not specify behavior for out-of-range indices.
    virtual std::string getInputName(int idx) const = 0;

    // Operation type of this node.
    virtual std::string getType() const = 0;

    // Overwrite the operation type of this node.
    virtual void setType(const std::string& type) = 0;

    // Replace the list of input names of this node with `inputs`.
    virtual void setInputNames(const std::vector<std::string>& inputs) = 0;
};
32 
33 class ImportGraphWrapper
34 {
35 public:
~ImportGraphWrapper()36     virtual ~ImportGraphWrapper() {};
37 
38     virtual Ptr<ImportNodeWrapper> getNode(int idx) const = 0;
39 
40     virtual int getNumNodes() const = 0;
41 
42     virtual int getNumOutputs(int nodeId) const = 0;
43 
44     virtual std::string getOutputName(int nodeId, int outId) const = 0;
45 
46     virtual void removeNode(int idx) = 0;
47 };
48 
// Pattern of nodes to search for in an imported graph, together with the
// single node the pattern should be fused into.
// Build the pattern with addNodeToMatch()/setFusedNode(); match() looks for it
// in an ImportGraphWrapper and replace() performs the fusion.
class Subgraph  // Interface to match and replace subgraphs.
{
public:
    virtual ~Subgraph();

    // Add a node to be matched in the origin graph. Specify ids of nodes that
    // are expected to be inputs. Returns id of a newly added node.
    // NOTE(review): inputs left at the default -1 presumably mean "no input";
    // confirm in the implementation.
    // TODO: Replace the separate int inputs with std::vector<int> (C++11);
    // the std::vector overload below already accepts an arbitrary list.
    int addNodeToMatch(const std::string& op, int input_0 = -1, int input_1 = -1,
                       int input_2 = -1, int input_3 = -1);

    // Same as above, with the input node ids given as a vector.
    int addNodeToMatch(const std::string& op, const std::vector<int>& inputs_);

    // Specify resulting node. All the matched nodes in subgraph excluding
    // input nodes will be fused into this single node.
    // TODO: Replace the separate int inputs with std::vector<int> (C++11);
    // the std::vector overload below already accepts an arbitrary list.
    void setFusedNode(const std::string& op, int input_0 = -1, int input_1 = -1,
                      int input_2 = -1, int input_3 = -1, int input_4 = -1,
                      int input_5 = -1);

    // Same as above, with the fused node's input ids given as a vector.
    void setFusedNode(const std::string& op, const std::vector<int>& inputs_);

    // Id of the node in net that feeds input inpId of node.
    // NOTE(review): the result when that input is not produced by any node
    // is defined in the implementation; confirm there.
    static int getInputNodeId(const Ptr<ImportGraphWrapper>& net,
                              const Ptr<ImportNodeWrapper>& node,
                              int inpId);

    // Match the subgraph pattern against <net>, starting from node <nodeId>.
    // Const nodes are skipped during matching. Returns true if nodes are matched
    // and can be fused.
    // matchedNodesIds/targetNodesIds are output parameters. NOTE(review):
    // presumably they receive the ids of the matched graph nodes and of the
    // nodes that become the fused node's inputs; confirm in the implementation.
    virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
                       std::vector<int>& matchedNodesIds,
                       std::vector<int>& targetNodesIds);

    // Fuse matched subgraph.
    // NOTE(review): presumably expects the id vectors produced by a
    // successful match() call on the same net.
    void replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int>& matchedNodesIds,
                 const std::vector<int>& targetNodesIds);

    // Overridable hook that receives the fused node and its input nodes.
    // NOTE(review): when it is invoked (e.g. from replace()) and what the
    // default implementation does are defined out of view; confirm there.
    virtual void finalize(const Ptr<ImportGraphWrapper>& net,
                          const Ptr<ImportNodeWrapper>& fusedNode,
                          std::vector<Ptr<ImportNodeWrapper> >& inputs);

private:
    std::vector<std::string> nodes;         // Nodes to be matched in the origin graph.
    std::vector<std::vector<int> > inputs;  // For every pattern node, the ids of its input nodes.

    std::string fusedNodeOp;           // Operation name of the resulting fused node.
    std::vector<int> fusedNodeInputs;  // Input ids of the fused node.
};
96 
// Apply the given subgraph patterns to net.
// NOTE(review): presumably tries every pattern against the graph's nodes and
// fuses each match via Subgraph::match/Subgraph::replace; the iteration order
// and whether matching repeats after a fusion are defined in the
// implementation. Confirm there.
void simplifySubgraphs(const Ptr<ImportGraphWrapper>& net,
                       const std::vector<Ptr<Subgraph> >& patterns);
99 
100 }}  // namespace dnn, namespace cv
101 
102 #endif  // __OPENCV_DNN_GRAPH_SIMPLIFIER_HPP__
103