1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 21 #include "./common.h" 22 #include "./subgraph_property.h" 23 #include "../../imperative/cached_op.h" 24 25 namespace mxnet { 26 namespace op { 27 28 /* 29 * This selects nodes for a subgraph that only contains operators 30 * in a given set and it visits nodes via both input and output links. 31 */ 32 class ContainOpSelectorV2: public SubgraphSelectorV2 { 33 public: ContainOpSelectorV2(const std::unordered_set<std::string> & op_names)34 explicit ContainOpSelectorV2(const std::unordered_set<std::string>& op_names) 35 : op_names_(op_names) {} 36 Select(const BiDirectedNode & sn)37 bool Select(const BiDirectedNode &sn) override { 38 const auto &seed_node = *sn.node; 39 return !seed_node.is_variable() && op_names_.count(seed_node.op()->name); 40 } 41 SelectInput(const BiDirectedNode & sn,const BiDirectedNode & snew_node)42 bool SelectInput(const BiDirectedNode &sn, const BiDirectedNode &snew_node) override { 43 const auto &input_node = *snew_node.node; 44 return !input_node.is_variable() && op_names_.count(input_node.op()->name); 45 } 46 SelectOutput(const BiDirectedNode & sn,const BiDirectedNode & snew_node)47 bool SelectOutput(const BiDirectedNode &sn, const BiDirectedNode &snew_node) override { 48 const auto &output_node = *snew_node.node; 49 return !output_node.is_variable() && op_names_.count(output_node.op()->name); 50 } 51 private: 52 const std::unordered_set<std::string>& op_names_; 53 }; 54 55 /* 56 * This subgraph property finds a subgraph whose nodes have only operators 57 * within a set. The operators in the subgraph will be executed by _CachedOp. 58 */ 59 class DefaultSubgraphProperty: public SubgraphProperty { 60 public: Create()61 static SubgraphPropertyPtr Create() { return std::make_shared<DefaultSubgraphProperty>(); } CreateSubgraphNode(const nnvm::Symbol & sym,const SubgraphSelectorPtr & subgraph_selector,const int subgraph_id=0) const62 nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol &sym, 63 const SubgraphSelectorPtr& subgraph_selector, 64 const int subgraph_id = 0) const override { 65 nnvm::ObjectPtr n = nnvm::Node::Create(); 66 n->attrs.op = Op::Get("_CachedOp"); 67 n->attrs.name = "_CachedOp" + std::to_string(subgraph_id); 68 n->attrs.subgraphs.push_back(std::make_shared<nnvm::Symbol>(sym)); 69 70 std::vector<std::pair<std::string, std::string> > flags{{"static_alloc", "true"}}; 71 n->attrs.parsed = CachedOpPtr(new CachedOp(sym, flags)); 72 73 return n; 74 } CreateSubgraphSelectorV2() const75 SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override { 76 return std::make_shared<ContainOpSelectorV2>( 77 this->GetAttr<std::unordered_set<std::string>>("op_names")); 78 } 79 }; 80 81 MXNET_REGISTER_SUBGRAPH_BACKEND(default_v2); 82 83 MXNET_REGISTER_SUBGRAPH_PROPERTY(default_v2, DefaultSubgraphProperty); 84 85 } // namespace op 86 } // namespace mxnet 87