1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #include "./common.h"
21 #include "./subgraph_property.h"
22 #include "../../imperative/cached_op.h"
23 
24 namespace mxnet {
25 namespace op {
26 
27 /*
28  * This selects nodes for a subgraph that only contains operators
29  * in a given set and it visits nodes via both input and output links.
30  */
31 class ContainOpSelector: public SubgraphSelector {
32  public:
ContainOpSelector(const std::unordered_set<std::string> & op_names)33   explicit ContainOpSelector(const std::unordered_set<std::string>& op_names)
34     : op_names_(op_names) {}
35 
Select(const nnvm::Node & seed_node)36   virtual bool Select(const nnvm::Node &seed_node) {
37     return !seed_node.is_variable() && op_names_.count(seed_node.op()->name);
38   }
39 
SelectInput(const nnvm::Node & cur_node,const nnvm::Node & input_node)40   virtual bool SelectInput(const nnvm::Node &cur_node, const nnvm::Node &input_node) {
41     return !input_node.is_variable() && op_names_.count(input_node.op()->name);
42   }
43 
SelectOutput(const nnvm::Node & cur_node,const nnvm::Node & output_node)44   virtual bool SelectOutput(const nnvm::Node &cur_node, const nnvm::Node &output_node) {
45     return !output_node.is_variable() && op_names_.count(output_node.op()->name);
46   }
47  private:
48   const std::unordered_set<std::string>& op_names_;
49 };
50 
51 /*
52  * This subgraph property finds a subgraph whose nodes have only operators
53  * within a set. The operators in the subgraph will be executed by _CachedOp.
54  */
55 class DefaultSubgraphProperty: public SubgraphProperty {
56  public:
Create()57   static SubgraphPropertyPtr Create() { return std::make_shared<DefaultSubgraphProperty>(); }
CreateSubgraphNode(const nnvm::Symbol & sym,const int subgraph_id=0) const58   virtual nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol &sym,
59                                            const int subgraph_id = 0) const {
60     nnvm::ObjectPtr n = nnvm::Node::Create();
61     n->attrs.op = Op::Get("_CachedOp");
62     n->attrs.name = "_CachedOp" + std::to_string(subgraph_id);
63     n->attrs.subgraphs.push_back(std::make_shared<nnvm::Symbol>(sym));
64 
65     std::vector<std::pair<std::string, std::string> > flags{{"static_alloc", "true"}};
66     n->attrs.parsed = CachedOpPtr(new CachedOp(sym, flags));
67 
68     return n;
69   }
CreateSubgraphSelector() const70   virtual SubgraphSelectorPtr CreateSubgraphSelector() const {
71     return std::make_shared<ContainOpSelector>(
72         this->GetAttr<std::unordered_set<std::string>>("op_names"));
73   }
74 };
75 
// Presumably registers a subgraph backend named `default` and attaches
// DefaultSubgraphProperty to it (inferred from the macro names; confirm
// against the macro definitions in subgraph_property.h).
MXNET_REGISTER_SUBGRAPH_BACKEND(default);
MXNET_REGISTER_SUBGRAPH_PROPERTY(default, DefaultSubgraphProperty);
78 
79 }  // namespace op
80 }  // namespace mxnet
81