1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 #include "./common.h" 21 #include "./subgraph_property.h" 22 #include "../../imperative/cached_op.h" 23 24 namespace mxnet { 25 namespace op { 26 27 /* 28 * This selects nodes for a subgraph that only contains operators 29 * in a given set and it visits nodes via both input and output links. 30 */ 31 class ContainOpSelector: public SubgraphSelector { 32 public: ContainOpSelector(const std::unordered_set<std::string> & op_names)33 explicit ContainOpSelector(const std::unordered_set<std::string>& op_names) 34 : op_names_(op_names) {} 35 Select(const nnvm::Node & seed_node)36 virtual bool Select(const nnvm::Node &seed_node) { 37 return !seed_node.is_variable() && op_names_.count(seed_node.op()->name); 38 } 39 SelectInput(const nnvm::Node & cur_node,const nnvm::Node & input_node)40 virtual bool SelectInput(const nnvm::Node &cur_node, const nnvm::Node &input_node) { 41 return !input_node.is_variable() && op_names_.count(input_node.op()->name); 42 } 43 SelectOutput(const nnvm::Node & cur_node,const nnvm::Node & output_node)44 virtual bool SelectOutput(const nnvm::Node &cur_node, const nnvm::Node &output_node) { 45 return !output_node.is_variable() && op_names_.count(output_node.op()->name); 46 } 47 private: 48 const std::unordered_set<std::string>& op_names_; 49 }; 50 51 /* 52 * This subgraph property finds a subgraph whose nodes have only operators 53 * within a set. The operators in the subgraph will be executed by _CachedOp. 54 */ 55 class DefaultSubgraphProperty: public SubgraphProperty { 56 public: Create()57 static SubgraphPropertyPtr Create() { return std::make_shared<DefaultSubgraphProperty>(); } CreateSubgraphNode(const nnvm::Symbol & sym,const int subgraph_id=0) const58 virtual nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol &sym, 59 const int subgraph_id = 0) const { 60 nnvm::ObjectPtr n = nnvm::Node::Create(); 61 n->attrs.op = Op::Get("_CachedOp"); 62 n->attrs.name = "_CachedOp" + std::to_string(subgraph_id); 63 n->attrs.subgraphs.push_back(std::make_shared<nnvm::Symbol>(sym)); 64 65 std::vector<std::pair<std::string, std::string> > flags{{"static_alloc", "true"}}; 66 n->attrs.parsed = CachedOpPtr(new CachedOp(sym, flags)); 67 68 return n; 69 } CreateSubgraphSelector() const70 virtual SubgraphSelectorPtr CreateSubgraphSelector() const { 71 return std::make_shared<ContainOpSelector>( 72 this->GetAttr<std::unordered_set<std::string>>("op_names")); 73 } 74 }; 75 76 MXNET_REGISTER_SUBGRAPH_BACKEND(default); 77 MXNET_REGISTER_SUBGRAPH_PROPERTY(default, DefaultSubgraphProperty); 78 79 } // namespace op 80 } // namespace mxnet 81