/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file mkldnn_fc.cc
 * \brief MKLDNN (Quantized) FullyConnected operator based on subgraph
 * \author Ciyong Chen
 */

#if MXNET_USE_MKLDNN == 1

#include <utility>
#include <vector>
#include <string>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include "../common.h"
#include "../../nn/mkldnn/mkldnn_base-inl.h"
#include "../../nn/mkldnn/mkldnn_ops-inl.h"
#include "../../nn/mkldnn/mkldnn_fully_connected-inl.h"
#include "../../nn/mkldnn/mkldnn_act-inl.h"
#include "../../tensor/matrix_op-inl.h"
#include "../../quantization/quantization_utils.h"
#include "mkldnn_fc-inl.h"
#include "mkldnn_common.h"

namespace mxnet {
namespace op {

class SgMKLDNNFCOp {
 public:
  explicit SgMKLDNNFCOp(const nnvm::NodeAttrs &attrs)
    : subgraph_sym_(*attrs.subgraphs[0]),
      full_param_(nnvm::get<MKLDNNFCFullParam>(attrs.parsed)) {}

  void Forward(const OpContext &ctx,
               const std::vector<NDArray> &inputs,
               const std::vector<OpReqType> &req,
               const std::vector<NDArray> &outputs);

  void Backward(const OpContext &ctx,
                const std::vector<NDArray> &inputs,
                const std::vector<OpReqType> &req,
                const std::vector<NDArray> &outputs) {
    LOG(FATAL) << "Not implemented: subgraph mkldnn fully connected only supports "
                  "inference computation.";
  }

 private:
  bool initialized_{false};
  bool channel_wise_runtime_{false};
  bool reorder_data_{false};
  nnvm::Symbol subgraph_sym_;
  MKLDNNFCFullParam full_param_;
  mkldnn_args_map_t args_;
  std::shared_ptr<MKLDNNFullyConnectedForward> fwd_;
  std::shared_ptr<mkldnn::memory> cached_data_mem_;
  std::shared_ptr<mkldnn::memory> cached_out_mem_;
  NDArray cached_weight_;
  NDArray cached_bias_;
  float cached_min_data_;
  float cached_max_data_;
  float cached_min_weight_;
  float cached_max_weight_;
  float cached_min_bias_;
  float cached_max_bias_;
  size_t weight_ver_;
  size_t bias_ver_;
  float cached_min_output_;
  float cached_max_output_;
  float data_scale_{0.0f};
  std::vector<float> weight_scales_;
  size_t total_num_inputs_;
  size_t total_num_outputs_;
};

void SgMKLDNNFCOp::Forward(const OpContext &ctx,
                           const std::vector<NDArray> &in_data,
                           const std::vector<OpReqType> &req,
                           const std::vector<NDArray> &out_data) {
  auto &mkldnn_param = full_param_.mkldnn_param;
  auto &default_param = full_param_.default_param;
  bool has_bias = !default_param.no_bias;
  size_t base_num_inputs = has_bias ? 3 : 2;
  size_t base_num_outputs = 1;

  float min_data = 0.0f;
  float max_data = 0.0f;
  float min_weight = 0.0f;
  float max_weight = 0.0f;
  float min_bias = 0.0f;
  float max_bias = 0.0f;

  if (!initialized_) {
    if (mkldnn_param.channel_wise_quantize.has_value() &&
        mkldnn_param.channel_wise_quantize) {
      channel_wise_runtime_ = true;
    }

    total_num_inputs_ = base_num_inputs;
    total_num_outputs_ = base_num_outputs;
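    // Quantized FC carries extra float inputs with calibration ranges: channel-wise
    // quantization appends only min/max of the data, while tensor-wise quantization
    // appends min/max for data, weight and (if present) bias. Unless a float output
    // is requested, min/max of the output are appended to the outputs as well.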
    if (mkldnn_param.quantized) {
      total_num_inputs_ = channel_wise_runtime_ ? (base_num_inputs + 2) : (base_num_inputs * 3);
      total_num_outputs_ =
        mkldnn_param.enable_float_output ? base_num_outputs : (base_num_outputs * 3);
    }
  }
  CHECK_EQ(in_data.size(), total_num_inputs_);
  CHECK_EQ(out_data.size(), total_num_outputs_);

  NDArray data = in_data[fullc::kData];
  const NDArray &weight = in_data[fullc::kWeight];
  const NDArray &output = out_data[fullc::kOut];

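  // Read the calibrated float ranges that were appended after the base inputs.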
  if (mkldnn_param.quantized) {
    if (!channel_wise_runtime_) {
      min_weight = in_data[base_num_inputs + quantized_fullc::kWeightMin].data().dptr<float>()[0];
      max_weight = in_data[base_num_inputs + quantized_fullc::kWeightMax].data().dptr<float>()[0];
      if (has_bias) {
        min_bias = in_data[base_num_inputs + quantized_fullc::kBiasMin].data().dptr<float>()[0];
        max_bias = in_data[base_num_inputs + quantized_fullc::kBiasMax].data().dptr<float>()[0];
      }
    }
    min_data = in_data[base_num_inputs + quantized_fullc::kDataMin].data().dptr<float>()[0];
    max_data = in_data[base_num_inputs + quantized_fullc::kDataMax].data().dptr<float>()[0];
  }

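  // With MXNET_MKLDNN_QFC_DYNAMIC_PARAMS set, changed quantization ranges or new
  // weight/bias versions invalidate the cached primitive and force re-initialization.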
  if (initialized_ && mkldnn_param.quantized &&
      dmlc::GetEnv("MXNET_MKLDNN_QFC_DYNAMIC_PARAMS", 0)) {
    if (channel_wise_runtime_) {
      if (cached_min_data_ != min_data || cached_max_data_ != max_data ||
          weight_ver_ != weight.version() ||
          (has_bias && (bias_ver_ != in_data[fullc::kBias].version()))) {
        initialized_ = false;
      }
    } else {
      if (cached_min_data_ != min_data || cached_max_data_ != max_data ||
          cached_min_weight_ != min_weight || cached_max_weight_ != max_weight ||
          (has_bias && (cached_min_bias_ != min_bias || cached_max_bias_ != max_bias))) {
        initialized_ = false;
      }
    }
  }

  if (!initialized_) {
    const auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
    const auto engine = CpuEngine::Get()->get_engine();
    cached_min_data_ = min_data;
    cached_max_data_ = max_data;
    cached_min_weight_ = min_weight;
    cached_max_weight_ = max_weight;
    weight_ver_ = weight.version();
    cached_weight_ = weight;
    if (has_bias) {
      cached_min_bias_ = min_bias;
      cached_max_bias_ = max_bias;
      bias_ver_ = in_data[fullc::kBias].version();
      cached_bias_ = in_data[fullc::kBias];
    } else {
      cached_bias_ = NDArray();
    }
    const mxnet::TShape ishape = data.shape();
    const auto data_ndim = ishape.ndim();
    if (data.IsMKLDNNData()) {
      reorder_data_ = true;
      data = data.Reorder2Default();
    }
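    // Collapse an N-D input into 2-D: with flatten the batch dimension is kept and the
    // remaining dimensions are folded; without flatten the leading dimensions are folded
    // and the last dimension is kept.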
    if (data_ndim != 2) {
      if (!default_param.flatten) {
        data = data.MKLDNNDataReshape(
            Shape2(ishape.ProdShape(0, data_ndim - 1), ishape[data_ndim - 1]));
      } else {
        data = data.MKLDNNDataReshape(Shape2(ishape[0], ishape.ProdShape(1, data_ndim)));
      }
    }

    // create cached out_md
    const mxnet::TShape oshape = output.shape();
    mkldnn::memory::dims out_dims(2);
    if (oshape.ndim() == 2) {
      out_dims[0] = static_cast<int>(oshape[0]);
      out_dims[1] = static_cast<int>(oshape[1]);
    } else {
      if (!default_param.flatten) {
        out_dims[0] = static_cast<int>(oshape.ProdShape(0, oshape.ndim() - 1));
        out_dims[1] = static_cast<int>(oshape[oshape.ndim() - 1]);
      } else {
        out_dims[0] = static_cast<int>(oshape[0]);
        out_dims[1] = static_cast<int>(oshape.ProdShape(1, oshape.ndim()));
      }
    }
    mkldnn::memory::desc out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(output.dtype()),
      static_cast<mkldnn::memory::format_tag>(GetDefaultFormat(2)));
    cached_out_mem_ = std::make_shared<mkldnn::memory>(out_md, engine);

    bool support_channelwise_scale = false;
    if (mkldnn_param.quantized) {
      CHECK(data.dtype() == mshadow::kInt8 || data.dtype() == mshadow::kUint8);
      data_scale_ = GetQuantizeScale(data.dtype(), cached_min_data_, cached_max_data_);
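      // data_scale_ maps the float data range [min_data, max_data] onto the int8/uint8
      // range; weight and output scales below are derived in the same way.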

      bool fuse_requantize = false;
      // Channelwise scaling is only supported when fusion is enabled (requantize or dequantize).
      if (mkldnn_param.min_calib_range.has_value() &&
          mkldnn_param.max_calib_range.has_value()) {
        cached_min_output_ = mkldnn_param.min_calib_range.value();
        cached_max_output_ = mkldnn_param.max_calib_range.value();
        support_channelwise_scale = true;
        fuse_requantize = true;
      }
      if (mkldnn_param.enable_float_output) {
        support_channelwise_scale = true;
      }
      // channel_wise  support_channelwise_scale  result
      // True          True                       True
      // True          False                      Error
      // False         True/False                 False
      if (channel_wise_runtime_ && !support_channelwise_scale) {
        LOG(FATAL)
          << "Currently, channel-wise quantization requires a fused requantize or dequantize."
          << " For requantize fusion, make sure `min_calib_range` and `max_calib_range` are set"
          << " (i.e. the outputs of FullyConnected were collected during the calibration phase)."
          << " For dequantize fusion, make sure the env vars `MXNET_DISABLE_MKLDNN_QFC_FLOAT_OUTPUT`"
          << " and `MXNET_DISABLE_MKLDNN_QFC_FUSE_ALL` are not set to true (default is false).";
      }
      support_channelwise_scale = support_channelwise_scale && channel_wise_runtime_;

      if (support_channelwise_scale) {
        MSHADOW_REAL_TYPE_SWITCH(cached_weight_.dtype(), DType, {
          weight_scales_ =
            GetWeightScales<DType>(cached_weight_, has_bias ? &cached_bias_ : nullptr,
                                   data_scale_, support_channelwise_scale);
        });
      } else {
        weight_scales_.resize(1);
        weight_scales_[0] =
          GetQuantizeScale(cached_weight_.dtype(), cached_min_weight_, cached_max_weight_);
        if (has_bias) {
          float bias_scale = GetQuantizeScale(mshadow::kInt8, cached_min_bias_, cached_max_bias_);
          float bias_int32_rescale = data_scale_ * weight_scales_[0] / bias_scale;
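          // The int8 bias carries its own scale; rescale it into the int32 accumulator
          // domain (data_scale * weight_scale) so it can be added to the quantized
          // matmul accumulation.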
          // TODO(zhennan): mkldnn has a bug when handling INT_MAX in the bias, so cap the
          // bias values at INT_MAX / 2.
          float bias_max_rescale =
              MaxValue<int32_t>() / 2 / MaxAbs(cached_min_bias_, cached_max_bias_) / bias_scale;
          if (bias_int32_rescale > bias_max_rescale) {
            // avoid overflow on bias
            bias_int32_rescale = bias_max_rescale;
            float weight_rescale =
              bias_int32_rescale * bias_scale / data_scale_ / weight_scales_[0];
            int8_t *weight_ptr = weight.data().dptr<int8_t>();
            size_t weight_size = weight.shape().Size();
            #pragma omp parallel for num_threads(nthreads)
            for (index_t i = 0; i < static_cast<index_t>(weight_size); ++i) {
              weight_ptr[i] = std::round(weight_ptr[i] * weight_rescale);
            }
            weight_scales_[0] *= weight_rescale;
          }
          NDArray bias = in_data[fullc::kBias];
          cached_bias_ =
              NDArray(bias.storage_type(), bias.shape(), bias.ctx(), true, mshadow::kInt32);
          int8_t *bias_ptr = bias.data().dptr<int8_t>();
          int32_t *quantized_bias_ptr = cached_bias_.data().dptr<int32_t>();
          size_t bias_size = bias.shape().Size();
          #pragma omp parallel for num_threads(nthreads)
          for (index_t i = 0; i < static_cast<index_t>(bias_size); ++i) {
            quantized_bias_ptr[i] = std::round(bias_ptr[i] * bias_int32_rescale);
          }
        }
      }

      size_t num_channel = cached_weight_.shape()[0];
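      // Fold data, weight and (for requantize) output scales into output_scales so the
      // FC primitive emits already requantized int8/uint8 or dequantized float output.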
      if (fuse_requantize || mkldnn_param.enable_float_output) {
        float tmp_scale_ = 1.0f;
        if (fuse_requantize) {
          if (mkldnn_param.with_eltwise) {
            tmp_scale_ = 1.0 / data_scale_;
            full_param_.eltwise_param.scale =
              GetQuantizeScale(output.dtype(), cached_min_output_, cached_max_output_);
          } else {
            tmp_scale_ =
              GetQuantizeScale(output.dtype(),
                               cached_min_output_,
                               cached_max_output_) / data_scale_;
          }
        } else {
          tmp_scale_ = 1.0 / data_scale_;
        }

        if (support_channelwise_scale) {
          full_param_.output_scales.resize(num_channel);
          #pragma omp parallel for num_threads(nthreads)
          for (index_t i = 0; i < static_cast<index_t>(num_channel); ++i) {
            full_param_.output_scales[i] = tmp_scale_ / weight_scales_[i];
          }
        } else {
          full_param_.output_scales.resize(1);
          full_param_.output_scales[0] = tmp_scale_ / weight_scales_[0];
        }
      } else {
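        // No calibration and no float output: keep the raw int32 result and derive its
        // float range from the data and weight ranges.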
        Stream<cpu> *s = ctx.get_stream<cpu>();
        if (data.dtype() == mshadow::kInt8) {
          mxnet_op::Kernel<QuantizationRangeForS8S8MultiplicationStruct, cpu>::Launch(
              s, 1, &cached_min_output_, &cached_max_output_, &min_data, &max_data, &min_weight,
              &max_weight);
        } else {
          mxnet_op::Kernel<QuantizationRangeForS8U8MultiplicationStruct, cpu>::Launch(
              s, 1, &cached_min_output_, &cached_max_output_, &min_data, &max_data, &min_weight,
              &max_weight);
        }
        full_param_.output_scales.resize(0);
      }
    }

    fwd_.reset(new MKLDNNFullyConnectedForward(full_param_, ctx.is_train, data, cached_weight_,
      (has_bias ? &cached_bias_ : nullptr), out_md));

    // convert weight and bias to the format that MKL-DNN requires
    if (!mkldnn_param.quantized || support_channelwise_scale) {
      mkldnn::memory::desc bias_md;
      if (has_bias) bias_md = fwd_->fwd_pd.bias_desc();
      ConvertWeightBias2MKLDNN(&cached_weight_, &cached_bias_, has_bias,
                               fwd_->fwd_pd.weights_desc(),
                               has_bias ? &bias_md : nullptr,
                               1, data_scale_, weight_scales_, false);
    } else {
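      // Tensor-wise quantized path: the int8 weights are used as-is and only need a
      // layout reorder into the format expected by the FC primitive.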
      const auto def_weight_mem = weight.GetMKLDNNData();
      if (def_weight_mem->get_desc() != fwd_->fwd_pd.weights_desc()) {
        cached_weight_ = NDArray(fwd_->fwd_pd.weights_desc());
        auto cached_weight_mem = cached_weight_.GetMKLDNNData();
        std::unordered_map<int, mkldnn::memory> args(
          {{MKLDNN_ARG_FROM, *def_weight_mem},
          {MKLDNN_ARG_TO, *cached_weight_mem}});
        MKLDNNStream::Get()->RegisterPrimArgs(
          mkldnn::reorder(*def_weight_mem, *cached_weight_mem), args);
      }
    }

    const auto data_mem = data.GetMKLDNNData();
    cached_data_mem_ = std::make_shared<mkldnn::memory>(data_mem->get_desc(), engine);

    args_[MKLDNN_ARG_SRC] = *cached_data_mem_;
    args_[MKLDNN_ARG_WEIGHTS] = *cached_weight_.GetMKLDNNData();
    if (has_bias)
      args_[MKLDNN_ARG_BIAS] = *cached_bias_.GetMKLDNNData();
    args_[MKLDNN_ARG_DST] = *cached_out_mem_;
    initialized_ = true;
  }

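  // Each forward call re-points the cached src/dst memory handles at the current
  // tensors and replays the cached primitive.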
  if (reorder_data_) {
    data = data.Reorder2Default();
  }
  MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
    cached_data_mem_->set_data_handle(reinterpret_cast<void *>(data.data().dptr<DType>()));
  });
  MSHADOW_TYPE_SWITCH(output.dtype(), DType, {
    cached_out_mem_->set_data_handle(reinterpret_cast<void *>(output.data().dptr<DType>()));
  });
  MKLDNNStream::Get()->RegisterPrimArgs(fwd_->GetFwd(), args_);
  MKLDNNStream::Get()->Submit();

  if (mkldnn_param.quantized && !mkldnn_param.enable_float_output) {
    float *min_output_ptr = out_data[quantized_fullc::kOutMin].data().dptr<float>();
    float *max_output_ptr = out_data[quantized_fullc::kOutMax].data().dptr<float>();
    *min_output_ptr = cached_min_output_;
    *max_output_ptr = cached_max_output_;
  }
}

static void SgMKLDNNFCParamParser(nnvm::NodeAttrs *attrs) {
  // For backward compatibility, map the legacy with_relu flag to with_eltwise.
  auto legacy = attrs->dict.find("with_relu");
  if (legacy != attrs->dict.end()) {
    attrs->dict["with_eltwise"] = attrs->dict["with_relu"];
    attrs->dict.erase(legacy);
  }

  MKLDNNFCFullParam full_param;
  try {
    full_param.mkldnn_param.Init(attrs->dict);
  } catch (const dmlc::ParamError &e) {
    std::ostringstream os;
    os << e.what();
    os << ", in operator " << attrs->op->name << "("
       << "name=\"" << attrs->name << "\"";
    for (const auto &k : attrs->dict) {
      os << ", " << k.first << "=\"" << k.second << "\"";
    }
    os << ")";
    throw dmlc::ParamError(os.str());
  }
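  // Walk the subgraph to pick up the FullyConnected parameters and the parameters of
  // any fused element-wise activation (Activation, LeakyReLU, clip, ...).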
  auto subgraph_sym = attrs->subgraphs[0];
  DFSVisit(subgraph_sym->outputs, [&](const nnvm::ObjectPtr &node) {
    if (node->is_variable()) return;
    auto &op_name = node->op()->name;
    if (op_name == "FullyConnected") {
      full_param.default_param =
          nnvm::get<FullyConnectedParam>(node->attrs.parsed);
    } else if (SupportMKLDNNFCEltwiseFusion(op_name)) {
      if (op_name == "Activation") {
        const ActivationParam act_param = nnvm::get<ActivationParam>(node->attrs.parsed);
        full_param.eltwise_param.alg = GetMKLDNNActAlgo(act_param);
      } else if (op_name == "LeakyReLU") {
        const auto act_param = nnvm::get<LeakyReLUParam>(node->attrs.parsed);
        full_param.eltwise_param.alpha = act_param.slope;
        full_param.eltwise_param.alg = GetMKLDNNActAlgo(act_param);
      } else if (op_name == "clip") {
        const ClipParam clip_param = nnvm::get<ClipParam>(node->attrs.parsed);
        full_param.eltwise_param.alg = mkldnn::algorithm::eltwise_bounded_relu;
        full_param.eltwise_param.alpha = clip_param.a_max;
      } else {
        full_param.eltwise_param.alg = GetMKLDNNEltwiseAlgo(op_name);
      }
    }
  });
  attrs->parsed = std::move(full_param);
}

static std::vector<std::string> SgMKLDNNFCListInputNames(const NodeAttrs &attrs) {
  auto const &full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
  std::vector<std::string> input_names = DefaultSubgraphOpListInputs(attrs);
  if (full_param.mkldnn_param.quantized) {
    bool channel_wise = false;
    if (full_param.mkldnn_param.channel_wise_quantize.has_value() &&
        full_param.mkldnn_param.channel_wise_quantize) {
      channel_wise = true;
    }
    input_names.emplace_back("min_data");
    input_names.emplace_back("max_data");
    if (!channel_wise) {
      input_names.emplace_back("min_weight");
      input_names.emplace_back("max_weight");
      if (!full_param.default_param.no_bias) {
        input_names.emplace_back("min_bias");
        input_names.emplace_back("max_bias");
      }
    }
  }
  return input_names;
}

static std::vector<std::string> SgMKLDNNFCListOutputNames(const NodeAttrs &attrs) {
  auto const &full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
  if (full_param.mkldnn_param.quantized) {
    if (full_param.mkldnn_param.enable_float_output)
      return std::vector<std::string>{"output"};
    else
      return std::vector<std::string>{"output", "min_output", "max_output"};
  } else {
    return std::vector<std::string>{"output"};
  }
}

template <typename T>
static inline void FillBaseInputOutputInfo(const FullyConnectedParam &param,
                                           std::vector<T> *base_in_attrs,
                                           std::vector<T> *base_out_attrs,
                                           std::vector<T> *in_attrs,
                                           std::vector<T> *out_attrs) {
  auto base_num_inputs = param.no_bias ? 2 : 3;

  base_out_attrs->push_back(out_attrs->at(0));
  for (int i = 0; i < base_num_inputs; ++i) {
    base_in_attrs->push_back(in_attrs->at(i));
  }
}

static bool SgMKLDNNFCInferShape(const nnvm::NodeAttrs &attrs,
                                 mxnet::ShapeVector *in_shapes,
                                 mxnet::ShapeVector *out_shapes) {
  auto const &full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
  if (full_param.mkldnn_param.quantized) {
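    // Only the base data/weight/bias entries go through the subgraph shape pass; every
    // appended min/max tensor is a scalar of shape (1,).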
    mxnet::ShapeVector base_in_shapes;
    mxnet::ShapeVector base_out_shapes;
    FillBaseInputOutputInfo(full_param.default_param, &base_in_shapes, &base_out_shapes,
                            in_shapes, out_shapes);
    bool ret = DefaultSubgraphOpShape(attrs, &base_in_shapes, &base_out_shapes);

    for (size_t i = 0; i < in_shapes->size(); ++i) {
      if (i < base_in_shapes.size())
        in_shapes->at(i) = base_in_shapes[i];
      else
        SHAPE_ASSIGN_CHECK(*in_shapes, i, Shape1(1));
    }

    out_shapes->at(0) = base_out_shapes[0];
    if (!full_param.mkldnn_param.enable_float_output) {
      SHAPE_ASSIGN_CHECK(*out_shapes, 1, Shape1(1));
      SHAPE_ASSIGN_CHECK(*out_shapes, 2, Shape1(1));
    }
    return ret;
  } else {
    return DefaultSubgraphOpShape(attrs, in_shapes, out_shapes);
  }
}

static bool SgMKLDNNFCInferType(const nnvm::NodeAttrs &attrs,
                                std::vector<int> *in_types,
                                std::vector<int> *out_types) {
  auto const &full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
  if (full_param.mkldnn_param.quantized) {
    bool channel_wise = false;
    if (full_param.mkldnn_param.channel_wise_quantize.has_value() &&
        full_param.mkldnn_param.channel_wise_quantize) {
      channel_wise = true;
    }
    size_t base_num_inputs = full_param.default_param.no_bias ? 2 : 3;
    CHECK(in_types->at(0) == mshadow::kInt8 ||
          in_types->at(0) == mshadow::kUint8)
        << "QuantizedFullyConnected only supports int8/uint8 input, while "
        << in_types->at(0) << " is given.";
    for (size_t i = 1; i < in_types->size(); ++i) {
      if (channel_wise) {
        TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
      } else {
        if (i < base_num_inputs) {
          TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kInt8);
        } else {
          TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
        }
      }
    }

    if (full_param.mkldnn_param.enable_float_output) {
      TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);
    } else {
      if (full_param.mkldnn_param.min_calib_range.has_value() &&
          full_param.mkldnn_param.max_calib_range.has_value()) {
        if (IsOutputUint8(full_param)) {
          TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kUint8);
        } else {
          TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8);
        }
      } else {
        TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt32);
      }
      TYPE_ASSIGN_CHECK(*out_types, 1, mshadow::kFloat32);
      TYPE_ASSIGN_CHECK(*out_types, 2, mshadow::kFloat32);
    }
    return true;
  } else {
    return DefaultSubgraphOpType(attrs, in_types, out_types);
  }
}

static bool SgMKLDNNFCStorageType(const nnvm::NodeAttrs &attrs,
                                  const int dev_mask,
                                  DispatchMode *dispatch_mode,
                                  std::vector<int> *in_attrs,
                                  std::vector<int> *out_attrs) {
  auto const &full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
  if (full_param.mkldnn_param.quantized) {
    std::vector<int> base_in_attrs;
    std::vector<int> base_out_attrs;
    FillBaseInputOutputInfo(full_param.default_param, &base_in_attrs, &base_out_attrs,
                            in_attrs, out_attrs);
    bool ret = DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode,
                                            &base_in_attrs, &base_out_attrs);

    for (size_t i = 0; i < in_attrs->size(); ++i) {
      if (i < base_in_attrs.size())
        in_attrs->at(i) = base_in_attrs[i];
      else
        type_assign(&in_attrs->at(i), mxnet::kDefaultStorage);
    }

    out_attrs->at(0) = base_out_attrs[0];
    if (!full_param.mkldnn_param.enable_float_output) {
      type_assign(&out_attrs->at(1), mxnet::kDefaultStorage);
      type_assign(&out_attrs->at(2), mxnet::kDefaultStorage);
    }
    return ret;
  } else {
    return DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode,
                                        in_attrs, out_attrs);
  }
}

static OpStatePtr CreateSgMKLDNNFCState(const nnvm::NodeAttrs &attrs,
                                        Context ctx,
                                        const mxnet::ShapeVector &in_shapes,
                                        const std::vector<int> &in_types) {
  return OpStatePtr::Create<SgMKLDNNFCOp>(attrs);
}

static void SgMKLDNNFCForward(const OpStatePtr &state_pointer,
                              const OpContext &ctx,
                              const std::vector<NDArray> &inputs,
                              const std::vector<OpReqType> &req,
                              const std::vector<NDArray> &outputs) {
  SgMKLDNNFCOp &op = state_pointer.get_state<SgMKLDNNFCOp>();
  op.Forward(ctx, inputs, req, outputs);
}

nnvm::ObjectPtr SgMKLDNNFCQuantizedOp(const NodeAttrs& attrs) {
  nnvm::ObjectPtr node = nnvm::Node::Create();
  node->attrs.op = Op::Get("_sg_mkldnn_fully_connected");
  node->attrs.name = "quantized_" + attrs.name;
  node->attrs.dict = attrs.dict;
  node->attrs.dict["quantized"] = "True";
  node->attrs.subgraphs.reserve(attrs.subgraphs.size());
  for (const auto &sub : attrs.subgraphs) {
    node->attrs.subgraphs.push_back(sub);
  }
  node->op()->attr_parser(&(node->attrs));
  return node;
}

static bool SgMKLDNNAvoidFCQuantizeInput(const NodeAttrs& attrs, const size_t index_to_check,
                                         const std::string quantize_granularity) {
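  // With channel-wise quantization the weight (and bias) inputs must stay in fp32 so that
  // per-channel scales can be computed when the cached primitive is first built; skip
  // quantizing them here.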
  auto const &full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
  std::unordered_set<size_t> avoid_indexes;
  if (quantize_granularity == "channel-wise") {
    avoid_indexes.insert(fullc::kWeight);   // weight
    if (!full_param.default_param.no_bias) {
      avoid_indexes.insert(fullc::kBias);   // bias
    }
  }

  return avoid_indexes.count(index_to_check);
}

NNVM_REGISTER_OP(_sg_mkldnn_fully_connected)
.describe(R"code(_sg_mkldnn_fully_connected)code" ADD_FILELINE)
.set_num_inputs([](const NodeAttrs& attrs) {
  auto const &full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
  auto num_inputs = full_param.default_param.no_bias ? 2 : 3;
  if (full_param.mkldnn_param.quantized) {
    if (full_param.mkldnn_param.channel_wise_quantize.has_value() &&
        full_param.mkldnn_param.channel_wise_quantize) {
      return num_inputs + 2;  // min_data, max_data
    } else {
      return num_inputs * 3;
    }
  } else {
    return num_inputs;
  }
})
.set_num_outputs([](const NodeAttrs& attrs) {
  auto const &full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
  return (full_param.mkldnn_param.quantized &&
          !full_param.mkldnn_param.enable_float_output) ? 3 : 1;
})
.set_attr_parser(SgMKLDNNFCParamParser)
.set_attr<nnvm::FListInputNames>("FListInputNames", SgMKLDNNFCListInputNames)
.set_attr<nnvm::FListOutputNames>("FListOutputNames", SgMKLDNNFCListOutputNames)
.set_attr<mxnet::FInferShape>("FInferShape", SgMKLDNNFCInferShape)
.set_attr<nnvm::FInferType>("FInferType", SgMKLDNNFCInferType)
.set_attr<FInferStorageType>("FInferStorageType", SgMKLDNNFCStorageType)
.set_attr<FCreateOpState>("FCreateOpState", CreateSgMKLDNNFCState)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", SgMKLDNNFCForward)
.set_attr<bool>("TIsMKLDNN", true)
// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow,
// will be reverted after the improvement of CachedOP is done.
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
                               DefaultSubgraphOpMutableInputs)
.set_attr<std::string>("key_var_num_args", "num_args")
.set_attr<FQuantizable>("FQuantizable", [](const NodeAttrs& attrs) {
    return QuantizeType::kMust;
})
.set_attr<FQuantizedOp>("FQuantizedOp", SgMKLDNNFCQuantizedOp)
.set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
.set_attr<FAvoidQuantizeInput>("FAvoidQuantizeInput", SgMKLDNNAvoidFCQuantizeInput);

}  // namespace op
}  // namespace mxnet

#endif  // if MXNET_USE_MKLDNN == 1