/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file mkldnn_fc.cc
 * \brief MKLDNN (Quantized) FullyConnected operator based on subgraph
 * \author Ciyong Chen
 */

#if MXNET_USE_MKLDNN == 1

#include <utility>
#include <vector>
#include <string>
#include "../common.h"
#include "../../nn/mkldnn/mkldnn_base-inl.h"
#include "../../nn/mkldnn/mkldnn_ops-inl.h"
#include "../../nn/mkldnn/mkldnn_fully_connected-inl.h"
#include "../../nn/mkldnn/mkldnn_act-inl.h"
#include "../../tensor/matrix_op-inl.h"
#include "../../quantization/quantization_utils.h"
#include "mkldnn_fc-inl.h"
#include "mkldnn_common.h"

namespace mxnet {
namespace op {

class SgMKLDNNFCOp {
 public:
  explicit SgMKLDNNFCOp(const nnvm::NodeAttrs &attrs)
      : subgraph_sym_(*attrs.subgraphs[0]),
        full_param_(nnvm::get<MKLDNNFCFullParam>(attrs.parsed)) {}

  void Forward(const OpContext &ctx,
               const std::vector<NDArray> &inputs,
               const std::vector<OpReqType> &req,
               const std::vector<NDArray> &outputs);

  void Backward(const OpContext &ctx,
                const std::vector<NDArray> &inputs,
                const std::vector<OpReqType> &req,
                const std::vector<NDArray> &outputs) {
    LOG(FATAL) << "Not implemented: subgraph mkldnn fully connected only supports "
                  "inference computation.";
  }

 private:
  bool initialized_{false};
  bool channel_wise_runtime_{false};
  bool reorder_data_{false};
  nnvm::Symbol subgraph_sym_;
  MKLDNNFCFullParam full_param_;
  mkldnn_args_map_t args_;
  std::shared_ptr<MKLDNNFullyConnectedForward> fwd_;
  std::shared_ptr<mkldnn::memory> cached_data_mem_;
  std::shared_ptr<mkldnn::memory> cached_out_mem_;
  NDArray cached_weight_;
  NDArray cached_bias_;
  float cached_min_data_;
  float cached_max_data_;
  float cached_min_weight_;
  float cached_max_weight_;
  float cached_min_bias_;
  float cached_max_bias_;
  size_t weight_ver_;
  size_t bias_ver_;
  float cached_min_output_;
  float cached_max_output_;
  float data_scale_{0.0f};
  std::vector<float> weight_scales_;
  size_t total_num_inputs_;
  size_t total_num_outputs_;
};

void SgMKLDNNFCOp::Forward(const OpContext &ctx,
                           const std::vector<NDArray> &in_data,
                           const std::vector<OpReqType> &req,
                           const std::vector<NDArray> &out_data) {
  auto &mkldnn_param = full_param_.mkldnn_param;
  auto &default_param = full_param_.default_param;
  bool has_bias = !default_param.no_bias;
  size_t base_num_inputs = has_bias ? 3 : 2;
  size_t base_num_outputs = 1;

  float min_data = 0.0f;
  float max_data = 0.0f;
  float min_weight = 0.0f;
  float max_weight = 0.0f;
  float min_bias = 0.0f;
  float max_bias = 0.0f;

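  // On the first run (or after the cache is invalidated), work out how many inputs and
  // outputs this fused node carries. A quantized FC appends min/max ranges to the base
  // data/weight/bias inputs: with channel-wise quantization only the data range is passed,
  // otherwise the ranges of data, weight and (if present) bias are appended. Unless float
  // output is enabled, a quantized FC also emits min/max of the output next to the data.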
  if (!initialized_) {
    if (mkldnn_param.channel_wise_quantize.has_value() &&
        mkldnn_param.channel_wise_quantize) {
      channel_wise_runtime_ = true;
    }

    total_num_inputs_ = base_num_inputs;
    total_num_outputs_ = base_num_outputs;
    if (mkldnn_param.quantized) {
      total_num_inputs_ = channel_wise_runtime_ ? (base_num_inputs + 2) : (base_num_inputs * 3);
      total_num_outputs_ =
          mkldnn_param.enable_float_output ? base_num_outputs : (base_num_outputs * 3);
    }
  }
  CHECK_EQ(in_data.size(), total_num_inputs_);
  CHECK_EQ(out_data.size(), total_num_outputs_);

  NDArray data = in_data[fullc::kData];
  const NDArray &weight = in_data[fullc::kWeight];
  const NDArray &output = out_data[fullc::kOut];

  if (mkldnn_param.quantized) {
    if (!channel_wise_runtime_) {
      min_weight = in_data[base_num_inputs + quantized_fullc::kWeightMin].data().dptr<float>()[0];
      max_weight = in_data[base_num_inputs + quantized_fullc::kWeightMax].data().dptr<float>()[0];
      if (has_bias) {
        min_bias = in_data[base_num_inputs + quantized_fullc::kBiasMin].data().dptr<float>()[0];
        max_bias = in_data[base_num_inputs + quantized_fullc::kBiasMax].data().dptr<float>()[0];
      }
    }
    min_data = in_data[base_num_inputs + quantized_fullc::kDataMin].data().dptr<float>()[0];
    max_data = in_data[base_num_inputs + quantized_fullc::kDataMax].data().dptr<float>()[0];
  }

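  // When MXNET_MKLDNN_QFC_DYNAMIC_PARAMS is set, the cached primitive is invalidated whenever
  // the data range changes, or (channel-wise) the weight/bias NDArray version changes, or
  // (tensor-wise) the weight/bias ranges change, so the scales below are recomputed.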
  if (initialized_ && mkldnn_param.quantized &&
      dmlc::GetEnv("MXNET_MKLDNN_QFC_DYNAMIC_PARAMS", 0)) {
    if (channel_wise_runtime_) {
      if (cached_min_data_ != min_data || cached_max_data_ != max_data ||
          weight_ver_ != weight.version() ||
          (has_bias && (bias_ver_ != in_data[fullc::kBias].version()))) {
        initialized_ = false;
      }
    } else {
      if (cached_min_data_ != min_data || cached_max_data_ != max_data ||
          cached_min_weight_ != min_weight || cached_max_weight_ != max_weight ||
          (has_bias && (cached_min_bias_ != min_bias || cached_max_bias_ != max_bias))) {
        initialized_ = false;
      }
    }
  }

  if (!initialized_) {
    const auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
    const auto engine = CpuEngine::Get()->get_engine();
    cached_min_data_ = min_data;
    cached_max_data_ = max_data;
    cached_min_weight_ = min_weight;
    cached_max_weight_ = max_weight;
    weight_ver_ = weight.version();
    cached_weight_ = weight;
    if (has_bias) {
      cached_min_bias_ = min_bias;
      cached_max_bias_ = max_bias;
      bias_ver_ = in_data[fullc::kBias].version();
      cached_bias_ = in_data[fullc::kBias];
    } else {
      cached_bias_ = NDArray();
    }
    const mxnet::TShape ishape = data.shape();
    const auto data_ndim = ishape.ndim();
    if (data.IsMKLDNNData()) {
      reorder_data_ = true;
      data = data.Reorder2Default();
    }
    if (data_ndim != 2) {
      if (!default_param.flatten) {
        data = data.MKLDNNDataReshape(
            Shape2(ishape.ProdShape(0, data_ndim - 1), ishape[data_ndim - 1]));
      } else {
        data = data.MKLDNNDataReshape(Shape2(ishape[0], ishape.ProdShape(1, data_ndim)));
      }
    }

    // create cached out_md
    const mxnet::TShape oshape = output.shape();
    mkldnn::memory::dims out_dims(2);
    if (oshape.ndim() == 2) {
      out_dims[0] = static_cast<int>(oshape[0]);
      out_dims[1] = static_cast<int>(oshape[1]);
    } else {
      if (!default_param.flatten) {
        out_dims[0] = static_cast<int>(oshape.ProdShape(0, oshape.ndim() - 1));
        out_dims[1] = static_cast<int>(oshape[oshape.ndim() - 1]);
      } else {
        out_dims[0] = static_cast<int>(oshape[0]);
        out_dims[1] = static_cast<int>(oshape.ProdShape(1, oshape.ndim()));
      }
    }
    mkldnn::memory::desc out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(output.dtype()),
        static_cast<mkldnn::memory::format_tag>(GetDefaultFormat(2)));
    cached_out_mem_ = std::make_shared<mkldnn::memory>(out_md, engine);
    bool support_channelwise_scale = false;
    if (mkldnn_param.quantized) {
      CHECK(data.dtype() == mshadow::kInt8 || data.dtype() == mshadow::kUint8);
      data_scale_ = GetQuantizeScale(data.dtype(), cached_min_data_, cached_max_data_);

      bool fuse_requantize = false;
      // Channelwise scaling is only supported when fusion is enabled (requantize or dequantize).
      if (mkldnn_param.min_calib_range.has_value() &&
          mkldnn_param.max_calib_range.has_value()) {
        cached_min_output_ = mkldnn_param.min_calib_range.value();
        cached_max_output_ = mkldnn_param.max_calib_range.value();
        support_channelwise_scale = true;
        fuse_requantize = true;
      }
      if (mkldnn_param.enable_float_output) {
        support_channelwise_scale = true;
      }
      // channel_wise  support_channelwise_scale  result
      //     True               True               True
      //     True               False              Error
      //     False            True/False           False
      if (channel_wise_runtime_ && !support_channelwise_scale) {
        LOG(FATAL)
          << "Currently, channel-wise quantization requires fusing requantize or dequantize."
          << " Please make sure `min_calib_range` and `max_calib_range` are set when only"
          << " requantize is fused (the outputs of FullyConnected are collected during the"
          << " calibration phase), or that the env vars `MXNET_DISABLE_MKLDNN_QFC_FLOAT_OUTPUT`"
          << " and `MXNET_DISABLE_MKLDNN_QFC_FUSE_ALL` are not set to true (default is false).";
      }
      support_channelwise_scale = support_channelwise_scale && channel_wise_runtime_;

      if (support_channelwise_scale) {
        MSHADOW_REAL_TYPE_SWITCH(cached_weight_.dtype(), DType, {
          weight_scales_ =
              GetWeightScales<DType>(cached_weight_, has_bias ? &cached_bias_ : nullptr,
                                     data_scale_, support_channelwise_scale);
        });
      } else {
        weight_scales_.resize(1);
        weight_scales_[0] =
            GetQuantizeScale(cached_weight_.dtype(), cached_min_weight_, cached_max_weight_);
        if (has_bias) {
          float bias_scale = GetQuantizeScale(mshadow::kInt8, cached_min_bias_, cached_max_bias_);
          float bias_int32_rescale = data_scale_ * weight_scales_[0] / bias_scale;
          // TODO(zhennan): mkldnn has a bug handling INT_MAX in bias, so cap the maximum
          // bias value at INT_MAX / 2.
          float bias_max_rescale =
              MaxValue<int32_t>() / 2 / MaxAbs(cached_min_bias_, cached_max_bias_) / bias_scale;
          if (bias_int32_rescale > bias_max_rescale) {
            // avoid overflow on bias
            bias_int32_rescale = bias_max_rescale;
            float weight_rescale =
                bias_int32_rescale * bias_scale / data_scale_ / weight_scales_[0];
            int8_t *weight_ptr = weight.data().dptr<int8_t>();
            size_t weight_size = weight.shape().Size();
#pragma omp parallel for num_threads(nthreads)
            for (index_t i = 0; i < static_cast<index_t>(weight_size); ++i) {
              weight_ptr[i] = std::round(weight_ptr[i] * weight_rescale);
            }
            weight_scales_[0] *= weight_rescale;
          }
          NDArray bias = in_data[fullc::kBias];
          cached_bias_ =
              NDArray(bias.storage_type(), bias.shape(), bias.ctx(), true, mshadow::kInt32);
          int8_t *bias_ptr = bias.data().dptr<int8_t>();
          int32_t *quantized_bias_ptr = cached_bias_.data().dptr<int32_t>();
          size_t bias_size = bias.shape().Size();
#pragma omp parallel for num_threads(nthreads)
          for (index_t i = 0; i < static_cast<index_t>(bias_size); ++i) {
            quantized_bias_ptr[i] = std::round(bias_ptr[i] * bias_int32_rescale);
          }
        }
      }

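      // Fold requantization into oneDNN output_scales: out_scale[c] = s_out / (data_scale *
      // weight_scale[c]), where s_out is the output quantize scale when requantize is fused
      // and 1.0 when float output is enabled; with a fused eltwise post-op, s_out is carried
      // by eltwise_param.scale instead. Without calibrated ranges the output stays int32 and
      // its min/max are derived from the data and weight ranges below.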
      size_t num_channel = cached_weight_.shape()[0];
      if (fuse_requantize || mkldnn_param.enable_float_output) {
        float tmp_scale_ = 1.0f;
        if (fuse_requantize) {
          if (mkldnn_param.with_eltwise) {
            tmp_scale_ = 1.0 / data_scale_;
            full_param_.eltwise_param.scale =
                GetQuantizeScale(output.dtype(), cached_min_output_, cached_max_output_);
          } else {
            tmp_scale_ =
                GetQuantizeScale(output.dtype(),
                                 cached_min_output_,
                                 cached_max_output_) / data_scale_;
          }
        } else {
          tmp_scale_ = 1.0 / data_scale_;
        }

        if (support_channelwise_scale) {
          full_param_.output_scales.resize(num_channel);
#pragma omp parallel for num_threads(nthreads)
          for (index_t i = 0; i < static_cast<index_t>(num_channel); ++i) {
            full_param_.output_scales[i] = tmp_scale_ / weight_scales_[i];
          }
        } else {
          full_param_.output_scales.resize(1);
          full_param_.output_scales[0] = tmp_scale_ / weight_scales_[0];
        }
      } else {
        Stream<cpu> *s = ctx.get_stream<cpu>();
        if (data.dtype() == mshadow::kInt8) {
          mxnet_op::Kernel<QuantizationRangeForS8S8MultiplicationStruct, cpu>::Launch(
              s, 1, &cached_min_output_, &cached_max_output_, &min_data, &max_data, &min_weight,
              &max_weight);
        } else {
          mxnet_op::Kernel<QuantizationRangeForS8U8MultiplicationStruct, cpu>::Launch(
              s, 1, &cached_min_output_, &cached_max_output_, &min_data, &max_data, &min_weight,
              &max_weight);
        }
        full_param_.output_scales.resize(0);
      }
    }

    fwd_.reset(new MKLDNNFullyConnectedForward(full_param_, ctx.is_train, data, cached_weight_,
                                               (has_bias ? &cached_bias_ : nullptr), out_md));

    // Convert weight and bias to the format that MKL-DNN requires.
    if (!mkldnn_param.quantized || support_channelwise_scale) {
      mkldnn::memory::desc bias_md;
      if (has_bias) bias_md = fwd_->fwd_pd.bias_desc();
      ConvertWeightBias2MKLDNN(&cached_weight_, &cached_bias_, has_bias,
                               fwd_->fwd_pd.weights_desc(),
                               has_bias ? &bias_md : nullptr,
                               1, data_scale_, weight_scales_, false);
    } else {
      const auto def_weight_mem = weight.GetMKLDNNData();
      if (def_weight_mem->get_desc() != fwd_->fwd_pd.weights_desc()) {
        cached_weight_ = NDArray(fwd_->fwd_pd.weights_desc());
        auto cached_weight_mem = cached_weight_.GetMKLDNNData();
        std::unordered_map<int, mkldnn::memory> args(
            {{MKLDNN_ARG_FROM, *def_weight_mem},
             {MKLDNN_ARG_TO, *cached_weight_mem}});
        MKLDNNStream::Get()->RegisterPrimArgs(
            mkldnn::reorder(*def_weight_mem, *cached_weight_mem), args);
      }
    }

    const auto data_mem = data.GetMKLDNNData();
    cached_data_mem_ = std::make_shared<mkldnn::memory>(data_mem->get_desc(), engine);

    args_[MKLDNN_ARG_SRC] = *cached_data_mem_;
    args_[MKLDNN_ARG_WEIGHTS] = *cached_weight_.GetMKLDNNData();
    if (has_bias)
      args_[MKLDNN_ARG_BIAS] = *cached_bias_.GetMKLDNNData();
    args_[MKLDNN_ARG_DST] = *cached_out_mem_;
    initialized_ = true;
  }

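  // Steady-state path: reorder the input to the default layout if needed, point the cached
  // src/dst memory objects at this iteration's buffers, then run the cached primitive.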
  if (reorder_data_) {
    data = data.Reorder2Default();
  }
  MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
    cached_data_mem_->set_data_handle(reinterpret_cast<void *>(data.data().dptr<DType>()));
  });
  MSHADOW_TYPE_SWITCH(output.dtype(), DType, {
    cached_out_mem_->set_data_handle(reinterpret_cast<void *>(output.data().dptr<DType>()));
  });
  MKLDNNStream::Get()->RegisterPrimArgs(fwd_->GetFwd(), args_);
  MKLDNNStream::Get()->Submit();

  if (mkldnn_param.quantized && !mkldnn_param.enable_float_output) {
    float *min_output_ptr = out_data[quantized_fullc::kOutMin].data().dptr<float>();
    float *max_output_ptr = out_data[quantized_fullc::kOutMax].data().dptr<float>();
    *min_output_ptr = cached_min_output_;
    *max_output_ptr = cached_max_output_;
  }
}

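// Parse the fused node's attributes: initialize the MKLDNN FC params, then walk the wrapped
// subgraph to pick up the FullyConnected params and any fused eltwise (Activation /
// LeakyReLU / clip) parameters.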
static void SgMKLDNNFCParamParser(nnvm::NodeAttrs *attrs) {
  // For backward compatibility, map with_relu -> with_eltwise.
  auto legacy = attrs->dict.find("with_relu");
  if (legacy != attrs->dict.end()) {
    attrs->dict["with_eltwise"] = attrs->dict["with_relu"];
    attrs->dict.erase(legacy);
  }

  MKLDNNFCFullParam full_param;
  try {
    full_param.mkldnn_param.Init(attrs->dict);
  } catch (const dmlc::ParamError &e) {
    std::ostringstream os;
    os << e.what();
    os << ", in operator " << attrs->op->name << "("
       << "name=\"" << attrs->name << "\"";
    for (const auto &k : attrs->dict) {
      os << ", " << k.first << "=\"" << k.second << "\"";
    }
    os << ")";
    throw dmlc::ParamError(os.str());
  }
  auto subgraph_sym = attrs->subgraphs[0];
  DFSVisit(subgraph_sym->outputs, [&](const nnvm::ObjectPtr &node) {
    if (node->is_variable()) return;
    auto &op_name = node->op()->name;
    if (op_name == "FullyConnected") {
      full_param.default_param =
          nnvm::get<FullyConnectedParam>(node->attrs.parsed);
    } else if (SupportMKLDNNFCEltwiseFusion(op_name)) {
      if (op_name == "Activation") {
        const ActivationParam act_param = nnvm::get<ActivationParam>(node->attrs.parsed);
        full_param.eltwise_param.alg = GetMKLDNNActAlgo(act_param);
      } else if (op_name == "LeakyReLU") {
        const auto act_param = nnvm::get<LeakyReLUParam>(node->attrs.parsed);
        full_param.eltwise_param.alpha = act_param.slope;
        full_param.eltwise_param.alg = GetMKLDNNActAlgo(act_param);
      } else if (op_name == "clip") {
        const ClipParam clip_param = nnvm::get<ClipParam>(node->attrs.parsed);
        full_param.eltwise_param.alg = mkldnn::algorithm::eltwise_bounded_relu;
        full_param.eltwise_param.alpha = clip_param.a_max;
      } else {
        full_param.eltwise_param.alg = GetMKLDNNEltwiseAlgo(op_name);
      }
    }
  });
  attrs->parsed = std::move(full_param);
}

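// A quantized FC appends min/max inputs after the base data/weight/bias inputs; with
// channel-wise quantization only the data range is appended, since the weight (and bias)
// ranges are computed inside the operator.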
static std::vector<std::string> SgMKLDNNFCListInputNames(const NodeAttrs &attrs) {
  auto const &full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
  std::vector<std::string> input_names = DefaultSubgraphOpListInputs(attrs);
  if (full_param.mkldnn_param.quantized) {
    bool channel_wise = false;
    if (full_param.mkldnn_param.channel_wise_quantize.has_value() &&
        full_param.mkldnn_param.channel_wise_quantize) {
      channel_wise = true;
    }
    input_names.emplace_back("min_data");
    input_names.emplace_back("max_data");
    if (!channel_wise) {
      input_names.emplace_back("min_weight");
      input_names.emplace_back("max_weight");
      if (!full_param.default_param.no_bias) {
        input_names.emplace_back("min_bias");
        input_names.emplace_back("max_bias");
      }
    }
  }
  return input_names;
}

static std::vector<std::string> SgMKLDNNFCListOutputNames(const NodeAttrs &attrs) {
  auto const &full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
  if (full_param.mkldnn_param.quantized) {
    if (full_param.mkldnn_param.enable_float_output)
      return std::vector<std::string>{"output"};
    else
      return std::vector<std::string>{"output", "min_output", "max_output"};
  } else {
    return std::vector<std::string>{"output"};
  }
}

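// Copy only the base FC inputs/outputs (data, weight, optional bias / output) into separate
// vectors so the default subgraph shape/type/storage inference can be reused; the callers
// then fill the remaining min/max entries with scalar attributes.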
template <typename T>
static inline void FillBaseInputOutputInfo(const FullyConnectedParam &param,
                                           std::vector<T> *base_in_attrs,
                                           std::vector<T> *base_out_attrs,
                                           std::vector<T> *in_attrs,
                                           std::vector<T> *out_attrs) {
  auto base_num_inputs = param.no_bias ? 2 : 3;

  base_out_attrs->push_back(out_attrs->at(0));
  for (int i = 0; i < base_num_inputs; ++i) {
    base_in_attrs->push_back(in_attrs->at(i));
  }
}

static bool SgMKLDNNFCInferShape(const nnvm::NodeAttrs &attrs,
                                 mxnet::ShapeVector *in_shapes,
                                 mxnet::ShapeVector *out_shapes) {
  auto const &full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
  if (full_param.mkldnn_param.quantized) {
    mxnet::ShapeVector base_in_shapes;
    mxnet::ShapeVector base_out_shapes;
    FillBaseInputOutputInfo(full_param.default_param, &base_in_shapes, &base_out_shapes,
                            in_shapes, out_shapes);
    bool ret = DefaultSubgraphOpShape(attrs, &base_in_shapes, &base_out_shapes);

    for (size_t i = 0; i < in_shapes->size(); ++i) {
      if (i < base_in_shapes.size())
        in_shapes->at(i) = base_in_shapes[i];
      else
        SHAPE_ASSIGN_CHECK(*in_shapes, i, Shape1(1));
    }

    out_shapes->at(0) = base_out_shapes[0];
    if (!full_param.mkldnn_param.enable_float_output) {
      SHAPE_ASSIGN_CHECK(*out_shapes, 1, Shape1(1));
      SHAPE_ASSIGN_CHECK(*out_shapes, 2, Shape1(1));
    }
    return ret;
  } else {
    return DefaultSubgraphOpShape(attrs, in_shapes, out_shapes);
  }
}

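// Type rules for the quantized op: data must be int8/uint8; with channel-wise quantization
// the weight and bias stay float32, otherwise they are int8 and the min/max scalars are
// float32; the output is int8/uint8 when calibrated, int32 otherwise, and float32 when
// enable_float_output is set.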
static bool SgMKLDNNFCInferType(const nnvm::NodeAttrs &attrs,
                                std::vector<int> *in_types,
                                std::vector<int> *out_types) {
  auto const &full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
  if (full_param.mkldnn_param.quantized) {
    bool channel_wise = false;
    if (full_param.mkldnn_param.channel_wise_quantize.has_value() &&
        full_param.mkldnn_param.channel_wise_quantize) {
      channel_wise = true;
    }
    size_t base_num_inputs = full_param.default_param.no_bias ? 2 : 3;
    CHECK(in_types->at(0) == mshadow::kInt8 ||
          in_types->at(0) == mshadow::kUint8)
        << "QuantizedFullyConnected only supports int8/uint8 input, while "
        << in_types->at(0) << " is given.";
    for (size_t i = 1; i < in_types->size(); ++i) {
      if (channel_wise) {
        TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
      } else {
        if (i < base_num_inputs) {
          TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kInt8);
        } else {
          TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
        }
      }
    }

    if (full_param.mkldnn_param.enable_float_output) {
      TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);
    } else {
      if (full_param.mkldnn_param.min_calib_range.has_value() &&
          full_param.mkldnn_param.max_calib_range.has_value()) {
        if (IsOutputUint8(full_param)) {
          TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kUint8);
        } else {
          TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8);
        }
      } else {
        TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt32);
      }
      TYPE_ASSIGN_CHECK(*out_types, 1, mshadow::kFloat32);
      TYPE_ASSIGN_CHECK(*out_types, 2, mshadow::kFloat32);
    }
    return true;
  } else {
    return DefaultSubgraphOpType(attrs, in_types, out_types);
  }
}

static bool SgMKLDNNFCStorageType(const nnvm::NodeAttrs &attrs,
                                  const int dev_mask,
                                  DispatchMode *dispatch_mode,
                                  std::vector<int> *in_attrs,
                                  std::vector<int> *out_attrs) {
  auto const &full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
  if (full_param.mkldnn_param.quantized) {
    std::vector<int> base_in_attrs;
    std::vector<int> base_out_attrs;
    FillBaseInputOutputInfo(full_param.default_param, &base_in_attrs, &base_out_attrs,
                            in_attrs, out_attrs);
    bool ret = DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode,
                                            &base_in_attrs, &base_out_attrs);

    for (size_t i = 0; i < in_attrs->size(); ++i) {
      if (i < base_in_attrs.size())
        in_attrs->at(i) = base_in_attrs[i];
      else
        type_assign(&in_attrs->at(i), mxnet::kDefaultStorage);
    }

    out_attrs->at(0) = base_out_attrs[0];
    if (!full_param.mkldnn_param.enable_float_output) {
      type_assign(&out_attrs->at(1), mxnet::kDefaultStorage);
      type_assign(&out_attrs->at(2), mxnet::kDefaultStorage);
    }
    return ret;
  } else {
    return DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode,
                                        in_attrs, out_attrs);
  }
}

static OpStatePtr CreateSgMKLDNNFCState(const nnvm::NodeAttrs &attrs,
                                        Context ctx,
                                        const mxnet::ShapeVector &in_shapes,
                                        const std::vector<int> &in_types) {
  return OpStatePtr::Create<SgMKLDNNFCOp>(attrs);
}

static void SgMKLDNNFCForward(const OpStatePtr &state_pointer,
                              const OpContext &ctx,
                              const std::vector<NDArray> &inputs,
                              const std::vector<OpReqType> &req,
                              const std::vector<NDArray> &outputs) {
  SgMKLDNNFCOp &op = state_pointer.get_state<SgMKLDNNFCOp>();
  op.Forward(ctx, inputs, req, outputs);
}

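// FQuantizedOp hook: clone the fp32 fused node into its quantized counterpart by copying the
// attrs and subgraphs, setting quantized=True, and re-running the attr parser.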
nnvm::ObjectPtr SgMKLDNNFCQuantizedOp(const NodeAttrs& attrs) {
  nnvm::ObjectPtr node = nnvm::Node::Create();
  node->attrs.op = Op::Get("_sg_mkldnn_fully_connected");
  node->attrs.name = "quantized_" + attrs.name;
  node->attrs.dict = attrs.dict;
  node->attrs.dict["quantized"] = "True";
  node->attrs.subgraphs.reserve(attrs.subgraphs.size());
  for (auto sub : attrs.subgraphs) {
    node->attrs.subgraphs.push_back(sub);
  }
  node->op()->attr_parser(&(node->attrs));
  return node;
}

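// With channel-wise quantization the weight (and bias) inputs are kept in fp32 and quantized
// inside the operator, so the quantization pass must not quantize them here.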
static bool SgMKLDNNAvoidFCQuantizeInput(const NodeAttrs& attrs, const size_t index_to_check,
                                         const std::string quantize_granularity) {
  auto const &full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
  std::unordered_set<size_t> avoid_indexes;
  if (quantize_granularity == "channel-wise") {
    avoid_indexes.insert(fullc::kWeight);   // weight
    if (!full_param.default_param.no_bias) {
      avoid_indexes.insert(fullc::kBias);   // bias
    }
  }

  return avoid_indexes.count(index_to_check);
}

NNVM_REGISTER_OP(_sg_mkldnn_fully_connected)
.describe(R"code(_sg_mkldnn_fully_connected)code" ADD_FILELINE)
.set_num_inputs([](const NodeAttrs& attrs) {
  auto const &full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
  auto num_inputs = full_param.default_param.no_bias ? 2 : 3;
  if (full_param.mkldnn_param.quantized) {
    if (full_param.mkldnn_param.channel_wise_quantize.has_value() &&
        full_param.mkldnn_param.channel_wise_quantize) {
      return num_inputs + 2;  // min_data, max_data
    } else {
      return num_inputs * 3;
    }
  } else {
    return num_inputs;
  }
})
.set_num_outputs([](const NodeAttrs& attrs) {
  auto const &full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
  return (full_param.mkldnn_param.quantized &&
          !full_param.mkldnn_param.enable_float_output) ? 3 : 1;
})
.set_attr_parser(SgMKLDNNFCParamParser)
.set_attr<nnvm::FListInputNames>("FListInputNames", SgMKLDNNFCListInputNames)
.set_attr<nnvm::FListOutputNames>("FListOutputNames", SgMKLDNNFCListOutputNames)
.set_attr<mxnet::FInferShape>("FInferShape", SgMKLDNNFCInferShape)
.set_attr<nnvm::FInferType>("FInferType", SgMKLDNNFCInferType)
.set_attr<FInferStorageType>("FInferStorageType", SgMKLDNNFCStorageType)
.set_attr<FCreateOpState>("FCreateOpState", CreateSgMKLDNNFCState)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", SgMKLDNNFCForward)
.set_attr<bool>("TIsMKLDNN", true)
// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow,
// will be reverted after the improvement of CachedOP is done.
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
                               DefaultSubgraphOpMutableInputs)
.set_attr<std::string>("key_var_num_args", "num_args")
.set_attr<FQuantizable>("FQuantizable", [](const NodeAttrs& attrs) {
  return QuantizeType::kMust;
})
.set_attr<FQuantizedOp>("FQuantizedOp", SgMKLDNNFCQuantizedOp)
.set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
.set_attr<FAvoidQuantizeInput>("FAvoidQuantizeInput", SgMKLDNNAvoidFCQuantizeInput);

}  // namespace op
}  // namespace mxnet

#endif  // if MXNET_USE_MKLDNN == 1