1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 
21 #include "./common.h"
22 #include "./subgraph_property.h"
23 #include "../../imperative/cached_op.h"
24 
25 namespace mxnet {
26 namespace op {
27 
28 /*
29  * This selects nodes for a subgraph that only contains operators
30  * in a given set and it visits nodes via both input and output links.
31  */
32 class ContainOpSelectorV2: public SubgraphSelectorV2 {
33  public:
ContainOpSelectorV2(const std::unordered_set<std::string> & op_names)34   explicit ContainOpSelectorV2(const std::unordered_set<std::string>& op_names)
35     : op_names_(op_names) {}
36 
Select(const BiDirectedNode & sn)37   bool Select(const BiDirectedNode &sn) override {
38     const auto &seed_node = *sn.node;
39     return !seed_node.is_variable() && op_names_.count(seed_node.op()->name);
40   }
41 
SelectInput(const BiDirectedNode & sn,const BiDirectedNode & snew_node)42   bool SelectInput(const BiDirectedNode &sn, const BiDirectedNode &snew_node) override {
43     const auto &input_node = *snew_node.node;
44     return !input_node.is_variable() && op_names_.count(input_node.op()->name);
45   }
46 
SelectOutput(const BiDirectedNode & sn,const BiDirectedNode & snew_node)47   bool SelectOutput(const BiDirectedNode &sn, const BiDirectedNode &snew_node) override {
48     const auto &output_node = *snew_node.node;
49     return !output_node.is_variable() && op_names_.count(output_node.op()->name);
50   }
51  private:
52   const std::unordered_set<std::string>& op_names_;
53 };
54 
55 /*
56  * This subgraph property finds a subgraph whose nodes have only operators
57  * within a set. The operators in the subgraph will be executed by _CachedOp.
58  */
59 class DefaultSubgraphProperty: public SubgraphProperty {
60  public:
Create()61   static SubgraphPropertyPtr Create() { return std::make_shared<DefaultSubgraphProperty>(); }
CreateSubgraphNode(const nnvm::Symbol & sym,const SubgraphSelectorPtr & subgraph_selector,const int subgraph_id=0) const62   nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol &sym,
63                                            const SubgraphSelectorPtr& subgraph_selector,
64                                            const int subgraph_id = 0) const override {
65     nnvm::ObjectPtr n = nnvm::Node::Create();
66     n->attrs.op = Op::Get("_CachedOp");
67     n->attrs.name = "_CachedOp" + std::to_string(subgraph_id);
68     n->attrs.subgraphs.push_back(std::make_shared<nnvm::Symbol>(sym));
69 
70     std::vector<std::pair<std::string, std::string> > flags{{"static_alloc", "true"}};
71     n->attrs.parsed = CachedOpPtr(new CachedOp(sym, flags));
72 
73     return n;
74   }
CreateSubgraphSelectorV2() const75   SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override {
76     return std::make_shared<ContainOpSelectorV2>(
77         this->GetAttr<std::unordered_set<std::string>>("op_names"));
78   }
79 };
80 
// Register a subgraph backend named "default_v2" and attach
// DefaultSubgraphProperty to it as that backend's property.
MXNET_REGISTER_SUBGRAPH_BACKEND(default_v2);

MXNET_REGISTER_SUBGRAPH_PROPERTY(default_v2, DefaultSubgraphProperty);
84 
85 }  // namespace op
86 }  // namespace mxnet
87