/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_CONV_PROPERTY_H_
#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_CONV_PROPERTY_H_
#if MXNET_USE_MKLDNN == 1

#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>
#include "../../nn/activation-inl.h"
#include "../../leaky_relu-inl.h"
#include "../../nn/convolution-inl.h"
#include "../../nn/mkldnn/mkldnn_ops-inl.h"
#include "../../tensor/matrix_op-inl.h"
#include "../common.h"
#include "mkldnn_subgraph_base-inl.h"

namespace mxnet {
namespace op {

class SgMKLDNNConvSelector : public SubgraphSelector {
 public:
  /*! \brief pattern match status */
  enum SelectStatus {
    kFail = 0,
    kStart,
    kBN,
    kSum,
    kSuccess,
  };

 private:
  bool disable_all_;
  bool disable_conv_bn_;
  bool disable_conv_act_;
  bool disable_conv_sum_;
  bool quantize_;
  SelectStatus status_;
  std::vector<const nnvm::Node *> matched_list_;

 public:
  SgMKLDNNConvSelector(int dis_all, int dis_conv_bn, int dis_conv_act, int dis_conv_sum,
                       int quantize)
      : disable_all_(dis_all),
        disable_conv_bn_(dis_conv_bn),
        disable_conv_act_(dis_conv_act),
        disable_conv_sum_(dis_conv_sum),
        quantize_(quantize) {}

  bool Select(const nnvm::Node &n, const std::shared_ptr<NodeAttr> &node_attr) override {
    if (n.op() && n.op()->name == "Convolution") {
      const auto &param = nnvm::get<ConvolutionParam>(n.attrs.parsed);
      if ((param.kernel.ndim() == 2 || param.kernel.ndim() == 3) &&
          SupportMKLDNNAttr(node_attr)) {
        status_ = disable_all_ ? kSuccess : kStart;
        matched_list_.clear();
        matched_list_.push_back(&n);
        return true;
      }
    }
    return false;
  }

  bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override {
    return false;
  }

  bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override {
    // If n isn't the last matched node, we encountered an internal
    // branch: pop the nodes matched after n and stop fusion.
    if (matched_list_.back() != &n) {
      if (std::find(matched_list_.begin(), matched_list_.end(), &n) !=
          matched_list_.end()) {
        while (matched_list_.back() != &n) {
          matched_list_.pop_back();
        }
      }
      status_ = kSuccess;
      return false;
    }
    if (status_ == kFail || status_ == kSuccess || new_node.is_variable())
      return false;

    // Use a state machine to do the selection. The state changes as
    // kStart -> kBN -> kSum -> kSuccess.
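    // Note: the cases below intentionally fall through (there is no break), so
    // any of the BatchNorm / elemwise_add / activation stages may be skipped,
    // e.g. a Convolution followed directly by an Activation still matches.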
    switch (status_) {
      case kStart:
        if ((!disable_conv_bn_) && new_node.op()->name == "BatchNorm") {
          matched_list_.push_back(&new_node);
          status_ = kBN;
          return true;
        }
      case kBN:
        if ((!disable_conv_sum_) && new_node.op()->name == "elemwise_add") {
          matched_list_.push_back(&new_node);
          status_ = kSum;
          return true;
        }
      case kSum:
      default:
        if ((!disable_conv_act_) && new_node.op()->name == "Activation") {
          const ActivationParam &param =
              nnvm::get<ActivationParam>(new_node.attrs.parsed);
          if ((quantize_ && SupportQuantizedMKLDNNAct(param)) ||
              (!quantize_ && SupportMKLDNNAct(param))) {
            matched_list_.push_back(&new_node);
            // conv + relu + sum is not supported yet.
            status_ = kSuccess;
            return true;
          }
        } else if ((!disable_conv_act_) && new_node.op()->name == "LeakyReLU") {
          const LeakyReLUParam &param =
              nnvm::get<LeakyReLUParam>(new_node.attrs.parsed);
          if (param.act_type == leakyrelu::kLeakyReLU ||
              param.act_type == leakyrelu::kGELU) {
            matched_list_.push_back(&new_node);
            // conv + relu + sum is not supported yet.
            status_ = kSuccess;
            return true;
          }
        } else if ((!disable_conv_act_) && new_node.op()->name == "clip") {
          if (!(quantize_ && (status_ == kSum))) {
            // TODO(zhennan): int8 conv+sum+relu6 is not supported at the moment. To support
            // it, we need to fuse conv+sum first and calibrate with it, then fuse the int8
            // relu6 into the fused conv.
            const ClipParam &param = nnvm::get<ClipParam>(new_node.attrs.parsed);
            if (param.a_min == 0.f) {
              matched_list_.push_back(&new_node);
              // conv + relu + sum is not supported yet.
              status_ = kSuccess;
              return true;
            }
          }
        }
        status_ = kSuccess;
        return false;
    }
  }

  std::vector<nnvm::Node *> Filter(
      const std::vector<nnvm::Node *> &candidates) override {
    if (status_ == kFail) {
      return std::vector<nnvm::Node *>(0);
    } else {
      std::vector<nnvm::Node *> ret;
      for (auto i : matched_list_) {
        auto non_const_i = const_cast<nnvm::Node *>(i);
        if (std::find(candidates.begin(), candidates.end(), non_const_i) !=
            candidates.end()) {
          ret.push_back(non_const_i);
        }
      }
      return ret;
    }
  }

  void Reset() override {
    CHECK_GE(matched_list_.size(), 1);
    auto new_selector = SgMKLDNNConvSelector(disable_all_, disable_conv_bn_, disable_conv_act_,
                                             disable_conv_sum_, quantize_);
    new_selector.Select(*matched_list_[0], nullptr);
    *this = new_selector;
  }
};

class SgMKLDNNConvProperty : public SubgraphProperty {
 public:
  SgMKLDNNConvProperty() {
    disable_conv_bn_ = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSE_CONV_BN", 0);
    disable_conv_act_ = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSE_CONV_RELU", 0);
    disable_conv_sum_ = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSE_CONV_SUM", 0);

    disable_all_ = disable_conv_bn_ && disable_conv_act_ && disable_conv_sum_;
  }

  static SubgraphPropertyPtr Create() {
    static const std::string &name = "MKLDNN convolution optimization pass";
    auto property = std::make_shared<SgMKLDNNConvProperty>();
    property->SetAttr<std::string>("property_name", name);
    property->SetAttr<bool>("inference_only", true);
    if (dmlc::GetEnv("MXNET_DISABLE_MKLDNN_CONV_OPT", 0)) {
      property->SetAttr<bool>("disable", true);
    }
    return property;
  }
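  // CreateSubgraphNode wraps the matched chain into a single "_sg_mkldnn_conv"
  // node; the with_bn / with_sum / with_act / with_postsum_act attributes below
  // record which operators were folded in, and the node name is built from the
  // visited operator names plus the subgraph id.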
  nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol &sym,
                                     const int subgraph_id = 0) const override {
    nnvm::ObjectPtr n = nnvm::Node::Create();
    // This op has a single output; remove duplicated outputs.
    auto last_node = sym.outputs[0].node;
    nnvm::Symbol new_sym;
    new_sym.outputs.emplace_back(last_node);
    std::ostringstream node_name;
    node_name << "sg_mkldnn_";
    bool _with_sum = false;
    DFSVisit(new_sym.outputs, [&](const nnvm::ObjectPtr &node) {
      if (node->is_variable()) return;
      auto &sub_name = node->op()->name;
      if (sub_name == "Convolution") {
        node_name << "conv_";
      } else if (sub_name == "BatchNorm") {
        node_name << "bn_";
        n->attrs.dict["with_bn"] = "true";
      } else if (sub_name == "elemwise_add") {
        node_name << "add_";
        n->attrs.dict["with_sum"] = "true";
        _with_sum = true;
      } else if (sub_name == "Activation" || sub_name == "LeakyReLU" || sub_name == "clip") {
        node_name << "act_";
        if (!_with_sum) {
          n->attrs.dict["with_act"] = "true";
        } else {
          n->attrs.dict["with_postsum_act"] = "true";
        }
      }
    });
    node_name << std::to_string(subgraph_id);
    n->attrs.name = node_name.str();
    n->attrs.op = Op::Get("_sg_mkldnn_conv");
    CHECK(n->attrs.op);
    n->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(new_sym));
    n->op()->attr_parser(&(n->attrs));
    return n;
  }

  SubgraphSelectorPtr CreateSubgraphSelector() const override {
    bool quantize = HasAttr("quantize") ? GetAttr<bool>("quantize") : false;
    auto selector = std::make_shared<SgMKLDNNConvSelector>(
        disable_all_, disable_conv_bn_, disable_conv_act_, disable_conv_sum_, quantize);
    return selector;
  }

  void ConnectSubgraphOutputs(
      const nnvm::ObjectPtr n,
      std::vector<nnvm::NodeEntry *> *output_entries) const override {
    // Connect all external output entries to output[0].
    for (size_t i = 0; i < output_entries->size(); ++i) {
      *output_entries->at(i) = nnvm::NodeEntry{n, 0, 0};
    }
  }

  void ConnectSubgraphInputs(
      const nnvm::ObjectPtr n, std::vector<nnvm::NodeEntry *> *input_entries,
      std::vector<nnvm::NodeEntry> *orig_input_entries) const override {
    auto sym = n->attrs.subgraphs[0];
    std::unordered_set<const nnvm::Node *> node_sets;
    DFSVisit(sym->outputs, [&](const nnvm::ObjectPtr &node) {
      if (node->is_variable()) return;
      node_sets.insert(node.get());
      if (node->op()->name == "elemwise_add") {
        // Make sure n is the left operand of the sum; if not, swap the sum
        // operands so that the extra sum operand stays last in the inputs.
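        // The reversed case is detected by checking whether inputs[1] was
        // produced inside the subgraph (i.e. it is in node_sets); the external
        // input entries are rotated to keep them aligned with the swap.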
        if (node_sets.count(node->inputs[1].node.get())) {
          auto tmp = node->inputs[1];
          node->inputs[1] = node->inputs[0];
          node->inputs[0] = tmp;
          std::rotate(input_entries->begin(), input_entries->begin() + 1,
                      input_entries->end());
          std::rotate(orig_input_entries->begin(),
                      orig_input_entries->begin() + 1,
                      orig_input_entries->end());
        }
      }
    });
    n->inputs = *orig_input_entries;
  }

 private:
  int disable_all_;
  int disable_conv_bn_;
  int disable_conv_act_;
  int disable_conv_sum_;
};

}  // namespace op
}  // namespace mxnet

#endif  // if MXNET_USE_MKLDNN == 1
#endif  // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_CONV_PROPERTY_H_