/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_CONV_PROPERTY_H_
#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_CONV_PROPERTY_H_
#if MXNET_USE_MKLDNN == 1

#include <algorithm>      // std::find, std::rotate
#include <memory>         // std::shared_ptr, std::make_shared
#include <sstream>        // std::ostringstream
#include <string>
#include <unordered_set>  // std::unordered_set
#include <utility>        // std::swap
#include <vector>
#include "../../nn/activation-inl.h"
#include "../../leaky_relu-inl.h"
#include "../../nn/convolution-inl.h"
#include "../../nn/mkldnn/mkldnn_ops-inl.h"
#include "../../tensor/matrix_op-inl.h"
#include "../common.h"
#include "mkldnn_subgraph_base-inl.h"

namespace mxnet {
namespace op {
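// Matches the linear node sequence
//   Convolution [-> BatchNorm] [-> elemwise_add] [-> Activation | LeakyReLU | clip]
// so that it can be fused into a single _sg_mkldnn_conv node.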
class SgMKLDNNConvSelector : public SubgraphSelector {
 public:
  /*! \brief pattern match status */
  enum SelectStatus {
    kFail = 0,  // no match; Filter() returns an empty set
    kStart,     // a Convolution node has been matched
    kBN,        // conv + BatchNorm matched
    kSum,       // conv [+ bn] + elemwise_add matched
    kSuccess,   // pattern complete; no further nodes are accepted
  };

 private:
  bool disable_all_;
  bool disable_conv_bn_;
  bool disable_conv_act_;
  bool disable_conv_sum_;
  bool quantize_;
  SelectStatus status_;
  std::vector<const nnvm::Node *> matched_list_;

 public:
  SgMKLDNNConvSelector(int dis_all, int dis_conv_bn, int dis_conv_act, int dis_conv_sum,
                       int quantize)
      : disable_all_(dis_all),
        disable_conv_bn_(dis_conv_bn),
        disable_conv_act_(dis_conv_act),
        disable_conv_sum_(dis_conv_sum),
        quantize_(quantize) {}

  bool Select(const nnvm::Node &n, const std::shared_ptr<NodeAttr> &node_attr) override {
    if (n.op() && n.op()->name == "Convolution") {
      const auto &param = nnvm::get<ConvolutionParam>(n.attrs.parsed);
      if ((param.kernel.ndim() == 2 || param.kernel.ndim() == 3) &&
          SupportMKLDNNAttr(node_attr)) {
        status_ = disable_all_ ? kSuccess : kStart;
        matched_list_.clear();
        matched_list_.push_back(&n);
        return true;
      }
    }
    return false;
  }

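  // Input-side matching is disabled: the pattern only grows along the
  // convolution's output chain.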
  bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override {
    return false;
  }

  bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override {
    // If n isn't the last matched node, we encountered an internal
    // branch; pop the nodes matched after n and stop the fusion.
    if (matched_list_.back() != &n) {
      if (std::find(matched_list_.begin(), matched_list_.end(), &n) !=
          matched_list_.end()) {
        while (matched_list_.back() != &n) {
          matched_list_.pop_back();
        }
      }
      status_ = kSuccess;
      return false;
    }
    if (status_ == kFail || status_ == kSuccess || new_node.is_variable())
      return false;

    // Use a state machine to do the selection. The state transitions are:
    // kStart -> kBN -> kSum -> kSuccess.
    switch (status_) {
      case kStart:
        if ((!disable_conv_bn_) && new_node.op()->name == "BatchNorm") {
          matched_list_.push_back(&new_node);
          status_ = kBN;
          return true;
        }
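        // Fall through: the BatchNorm is optional, so also try to match
        // elemwise_add directly after the Convolution.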
      case kBN:
        if ((!disable_conv_sum_) && new_node.op()->name == "elemwise_add") {
          matched_list_.push_back(&new_node);
          status_ = kSum;
          return true;
        }
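        // Fall through: the sum is optional as well, so try to match an
        // activation next.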
      case kSum:
      default:
        if ((!disable_conv_act_) && new_node.op()->name == "Activation") {
          const ActivationParam &param =
              nnvm::get<ActivationParam>(new_node.attrs.parsed);
          if ((quantize_ && SupportQuantizedMKLDNNAct(param)) ||
              (!quantize_ && SupportMKLDNNAct(param))) {
            matched_list_.push_back(&new_node);
            // conv + relu + sum is not supported yet, so finish here.
            status_ = kSuccess;
            return true;
          }
        } else if ((!disable_conv_act_) && new_node.op()->name == "LeakyReLU") {
          const LeakyReLUParam &param =
              nnvm::get<LeakyReLUParam>(new_node.attrs.parsed);
          if (param.act_type == leakyrelu::kLeakyReLU ||
              param.act_type == leakyrelu::kGELU) {
            matched_list_.push_back(&new_node);
            // conv + relu + sum is not supported yet, so finish here.
            status_ = kSuccess;
            return true;
          }
        } else if ((!disable_conv_act_) && new_node.op()->name == "clip") {
          if (!(quantize_ && (status_ == kSum))) {
            // TODO(zhennan): int8 conv+sum+relu6 is not supported at the moment.
            // To support it, we need to fuse conv+sum first and calibrate with
            // it, then fuse the int8 relu6 into the fused conv.
            const ClipParam &param = nnvm::get<ClipParam>(new_node.attrs.parsed);
            if (param.a_min == 0.f) {
              matched_list_.push_back(&new_node);
              // conv + relu + sum is not supported yet, so finish here.
              status_ = kSuccess;
              return true;
            }
          }
        }
        status_ = kSuccess;
        return false;
    }
  }

  std::vector<nnvm::Node *> Filter(
      const std::vector<nnvm::Node *> &candidates) override {
    if (status_ == kFail) {
      return std::vector<nnvm::Node *>();
    } else {
      std::vector<nnvm::Node *> ret;
      for (auto i : matched_list_) {
        auto non_const_i = const_cast<nnvm::Node *>(i);
        if (std::find(candidates.begin(), candidates.end(), non_const_i) !=
            candidates.end()) {
          ret.push_back(non_const_i);
        }
      }
      return ret;
    }
  }

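  // Restart matching from the first matched node (the Convolution),
  // discarding any partially matched tail.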
  void Reset() override {
    CHECK_GE(matched_list_.size(), 1);
    auto new_selector = SgMKLDNNConvSelector(disable_all_, disable_conv_bn_, disable_conv_act_,
                                             disable_conv_sum_, quantize_);
    new_selector.Select(*matched_list_[0], nullptr);
    *this = new_selector;
  }
};

class SgMKLDNNConvProperty : public SubgraphProperty {
 public:
  SgMKLDNNConvProperty() {
    disable_conv_bn_ = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSE_CONV_BN", 0);
    disable_conv_act_ = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSE_CONV_RELU", 0);
    disable_conv_sum_ = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSE_CONV_SUM", 0);

    disable_all_ = disable_conv_bn_ && disable_conv_act_ && disable_conv_sum_;
  }
  static SubgraphPropertyPtr Create() {
    static const std::string name = "MKLDNN convolution optimization pass";
    auto property = std::make_shared<SgMKLDNNConvProperty>();
    property->SetAttr<std::string>("property_name", name);
    property->SetAttr<bool>("inference_only", true);
    if (dmlc::GetEnv("MXNET_DISABLE_MKLDNN_CONV_OPT", 0)) {
      property->SetAttr<bool>("disable", true);
    }
    return property;
  }
  nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol &sym,
                                     const int subgraph_id = 0) const override {
    nnvm::ObjectPtr n = nnvm::Node::Create();
    // This op has a single output; remove duplicated outputs.
    auto last_node = sym.outputs[0].node;
    nnvm::Symbol new_sym;
    new_sym.outputs.emplace_back(last_node);
    std::ostringstream node_name;
    node_name << "sg_mkldnn_";
    bool _with_sum = false;
    DFSVisit(new_sym.outputs, [&](const nnvm::ObjectPtr &node) {
      if (node->is_variable()) return;
      auto &sub_name = node->op()->name;
      if (sub_name == "Convolution") {
        node_name << "conv_";
      } else if (sub_name == "BatchNorm") {
        node_name << "bn_";
        n->attrs.dict["with_bn"] = "true";
      } else if (sub_name == "elemwise_add") {
        node_name << "add_";
        n->attrs.dict["with_sum"] = "true";
        _with_sum = true;
      } else if (sub_name == "Activation" || sub_name == "LeakyReLU" || sub_name == "clip") {
        node_name << "act_";
        if (!_with_sum) {
          n->attrs.dict["with_act"] = "true";
        } else {
          n->attrs.dict["with_postsum_act"] = "true";
        }
      }
    });
    node_name << std::to_string(subgraph_id);
    n->attrs.name = node_name.str();
    n->attrs.op = Op::Get("_sg_mkldnn_conv");
    CHECK(n->attrs.op);
    n->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(new_sym));
    n->op()->attr_parser(&(n->attrs));
    return n;
  }

  SubgraphSelectorPtr CreateSubgraphSelector() const override {
    bool quantize = HasAttr("quantize") ? GetAttr<bool>("quantize") : false;
    auto selector = std::make_shared<SgMKLDNNConvSelector>(
        disable_all_, disable_conv_bn_, disable_conv_act_, disable_conv_sum_, quantize);
    return selector;
  }

  void ConnectSubgraphOutputs(
      const nnvm::ObjectPtr n,
      std::vector<nnvm::NodeEntry *> *output_entries) const override {
    // Connect all external output entries to output[0].
    for (size_t i = 0; i < output_entries->size(); ++i) {
      *output_entries->at(i) = nnvm::NodeEntry{n, 0, 0};
    }
  }

  void ConnectSubgraphInputs(
      const nnvm::ObjectPtr n, std::vector<nnvm::NodeEntry *> *input_entries,
      std::vector<nnvm::NodeEntry> *orig_input_entries) const override {
    auto sym = n->attrs.subgraphs[0];
    std::unordered_set<const nnvm::Node *> node_sets;
    DFSVisit(sym->outputs, [&](const nnvm::ObjectPtr &node) {
      if (node->is_variable()) return;
      node_sets.insert(node.get());
      if (node->op()->name == "elemwise_add") {
        // Make sure the in-subgraph operand is the left operand of the sum.
        // If it isn't, swap the sum's operands so that the extra (external)
        // sum operand stays last in the inputs.
        if (node_sets.count(node->inputs[1].node.get())) {
          std::swap(node->inputs[0], node->inputs[1]);
          std::rotate(input_entries->begin(), input_entries->begin() + 1,
                      input_entries->end());
          std::rotate(orig_input_entries->begin(),
                      orig_input_entries->begin() + 1,
                      orig_input_entries->end());
        }
      }
    });
    n->inputs = *orig_input_entries;
  }

 private:
  int disable_all_;
  int disable_conv_bn_;
  int disable_conv_act_;
  int disable_conv_sum_;
};
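
// A minimal usage sketch (an assumption, not shown in this header): the
// property is typically registered from a .cc file so the graph partitioner
// can find it, e.g.
//
//   MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNConvProperty);
//
// After registration, the environment variables read above (e.g. setting
// MXNET_DISABLE_MKLDNN_FUSE_CONV_BN=1) toggle the individual fusion patterns.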

}  // namespace op
}  // namespace mxnet

#endif  // if MXNET_USE_MKLDNN == 1
#endif  // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_CONV_PROPERTY_H_