/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file quantized_fully_connected.cc
 * \brief Quantized fully connected operator (int8 inputs, int32 accumulation).
 * \author Ziheng Jiang, Jun Wu
*/
#include <vector>
#include "quantization_utils.h"
#include "../nn/fully_connected-inl.h"
#if MXNET_USE_MKLDNN == 1
#include "../nn/mkldnn/mkldnn_fully_connected-inl.h"
#include "mkldnn/mkldnn_quantized_ops-inl.h"
#endif

namespace mxnet {
namespace op {

namespace quantized_fc {
enum QuantizedfcOpResource {kTempSpace};
}

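// Inference helpers for _contrib_quantized_fully_connected.  The operator takes
// num_inputs (2 without bias, 3 with bias) quantized tensors (data, weight[, bias])
// followed by a float32 min/max scalar pair for each of them, i.e. 3 * num_inputs
// inputs in total, and produces three outputs: the int32 result plus its float32
// min/max range.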
bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& attrs,
                                  mxnet::ShapeVector *in_shape,
                                  mxnet::ShapeVector *out_shape) {
  const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
  using namespace mshadow;
  uint32_t num_inputs = param.no_bias ? 2 : 3;
  CHECK_EQ(in_shape->size(), num_inputs * 3);
  CHECK_EQ(out_shape->size(), 3U);

  mxnet::TShape dshape = (*in_shape)[0];
  // require data ndim to be known
  if (!mxnet::ndim_is_known(dshape)) return false;

  index_t num_input;
  if (!param.flatten) {
    num_input = dshape[dshape.ndim() - 1];
  } else {
    num_input = dshape.ProdShape(1, dshape.ndim());
  }

  mxnet::TShape wshape = Shape2(param.num_hidden, num_input);
  SHAPE_ASSIGN_CHECK(*in_shape, 1, wshape);
  if (!param.no_bias) {
    mxnet::TShape bshape = Shape1(param.num_hidden);
    SHAPE_ASSIGN_CHECK(*in_shape, 2, bshape);
  }

  for (size_t i = num_inputs; i < 3 * num_inputs; ++i) {
    SHAPE_ASSIGN_CHECK(*in_shape, i, mxnet::TShape(1, 1));
  }

  if (!param.flatten) {
    mxnet::TShape result_shape(dshape);
    result_shape[dshape.ndim() - 1] = param.num_hidden;
    SHAPE_ASSIGN_CHECK(*out_shape, 0, result_shape);
  } else {
    SHAPE_ASSIGN_CHECK(*out_shape, 0, Shape2(dshape[0], param.num_hidden));
  }
  SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape(1, 1));
  SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape(1, 1));

  if ((*out_shape)[0].ndim() > 0) {
    dshape[0] = ((*out_shape)[0])[0];
    SHAPE_ASSIGN_CHECK(*in_shape, 0, dshape);
  }
  return true;
}

bool QuantizedFullyConnectedType(const nnvm::NodeAttrs& attrs,
                                 std::vector<int> *in_type,
                                 std::vector<int> *out_type) {
  const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
  uint32_t num_inputs = param.no_bias ? 2 : 3;
  CHECK_EQ(in_type->size(), num_inputs * 3);
  CHECK_EQ(out_type->size(), 3U);

#if MXNET_USE_MKLDNN == 1
  CHECK(in_type->at(0) == mshadow::kInt8 || in_type->at(0) == mshadow::kUint8)
      << "QuantizedFullyConnected only supports int8/uint8 input, while "
      << in_type->at(0) << " is given.";
#else
  TYPE_ASSIGN_CHECK(*in_type, 0, mshadow::kInt8);
#endif
  for (size_t i = 1; i < num_inputs; ++i) {
    TYPE_ASSIGN_CHECK(*in_type, i, mshadow::kInt8);
  }
  for (size_t i = num_inputs; i < 3 * num_inputs; ++i) {
    TYPE_ASSIGN_CHECK(*in_type, i, mshadow::kFloat32);
  }

  TYPE_ASSIGN_CHECK(*out_type, 0, mshadow::kInt32);
  TYPE_ASSIGN_CHECK(*out_type, 1, mshadow::kFloat32);
  TYPE_ASSIGN_CHECK(*out_type, 2, mshadow::kFloat32);
  return true;
}

bool QuantizedFullyConnectedStorageType(const nnvm::NodeAttrs& attrs,
                                        const int dev_mask,
                                        DispatchMode* dispatch_mode,
                                        std::vector<int> *in_attrs,
                                        std::vector<int> *out_attrs) {
  const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
  uint32_t num_inputs = param.no_bias ? 2 : 3;
  CHECK_EQ(in_attrs->size(), num_inputs * 3);
  CHECK_EQ(out_attrs->size(), 3U);

#if MXNET_USE_MKLDNN == 1
  return MKLDNNStorageType(attrs, dev_mask, true,
                           dispatch_mode, in_attrs, out_attrs);
#else
  *dispatch_mode = DispatchMode::kFCompute;

  for (auto &v : *out_attrs) {
    v = kDefaultStorage;
    if (common::stype_string(v).compare("unknown") == 0) {
      return false;
    }
  }

  for (auto &v : *in_attrs) {
    v = kDefaultStorage;
    if (common::stype_string(v).compare("unknown") == 0) {
      return false;
    }
  }
  return true;
#endif
}

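// Initializes the int32 output buffer with the requantized bias: the int8 bias is
// stored with scale MaxAbs(min_bias, max_bias) / MaxValue<int8_t>, while the int32
// output uses scale MaxAbs(min_out, max_out) / MaxValue<int32_t>, so each bias
// element is multiplied by the ratio of the two scales before accumulation.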
struct QuantizedSumInitKernelWithBias {
  //  init sum data with bias for matrix b (n)
  MSHADOW_XINLINE static void Map(int i, int32_t *out,
                                  const int8_t *bias, const float *min_out,
                                  const float *max_out, const float *min_bias,
                                  const float *max_bias) {
    typedef int32_t T1;
    typedef int8_t  T2;
    using mshadow::red::limits::MinValue;
    using mshadow::red::limits::MaxValue;
    float float_for_one_out_quant  =
        MaxAbs(*min_out, *max_out) / static_cast<double>(MaxValue<T1>());
    float float_for_one_bias_quant =
        MaxAbs(*min_bias, *max_bias) / static_cast<double>(MaxValue<T2>());
    if (float_for_one_out_quant != 0) {
      out[i] = bias[i] * float_for_one_bias_quant /
          float_for_one_out_quant;
    } else {
      LOG(INFO) << "float_for_one_out_quant is 0, "
                << "need to check why MaxAbs(*min_out, *max_out) of out_data is 0!";
      out[i] = 0;
    }
  }
};


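// CPU forward path for the quantized fully connected operator.  It relies on MKL's
// cblas_gemm_s8u8s32 kernel, which expects the first (data) matrix in uint8 and the
// second (weight) matrix in int8, so the int8 data is shifted by +128 into a uint8
// temporary buffer and the extra contribution is compensated for before the GEMM
// (see the row-sum correction below).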
void QuantizedFullyConnectedForwardCPU(const nnvm::NodeAttrs& attrs,
                                       const OpContext &ctx,
                                       const std::vector<TBlob> &in_data,
                                       const std::vector<OpReqType> &req,
                                       const std::vector<TBlob> &out_data) {
#if MSHADOW_USE_MKL == 1
  const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
  using namespace mshadow;
  using namespace mxnet_op;
  Stream<cpu> *s = ctx.get_stream<cpu>();
  size_t num_inputs = param.no_bias ? 2 : 3;
  CHECK_EQ(in_data.size(),  num_inputs * 3);
  CHECK_EQ(out_data.size(), 3U);

  const mxnet::TShape &dshape = in_data[fullc::kData].shape_;
  const mxnet::TShape &wshape = in_data[fullc::kWeight].shape_;
  const mxnet::TShape &oshape = out_data[fullc::kOut].shape_;

  CHECK(in_data[fullc::kData].type_flag_ == mshadow::kInt8)
    << "QuantizedFullyConnectedForwardCPU Op only supports int8 for now, but got "
    << mxnet::op::type_string(in_data[fullc::kData].type_flag_);

  if (dshape.ndim() != 2)
    CHECK(param.flatten)
        << "QuantizedFullyConnectedForwardCPU only supports flatten=true "
        << "when dshape.ndim() != 2 for now.";

  Tensor<cpu, 2, int8_t> weight = in_data[fullc::kWeight].get<cpu, 2, int8_t>(s);
  Tensor<cpu, 2, int8_t> data = in_data[fullc::kData].get_with_shape<cpu, 2, int8_t>(
    Shape2(dshape[0], dshape.ProdShape(1, dshape.ndim())), s);
  Tensor<cpu, 2, int32_t> out = out_data[fullc::kOut].get_with_shape<cpu, 2, int32_t>(
    Shape2(oshape[0], oshape.ProdShape(1, oshape.ndim())), s);

  auto data_temp = data.dptr_;
  auto weight_temp = weight.dptr_;
  auto output_temp = out.dptr_;
  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
  const float alpha = 1.0f;
  const float beta  = 1.0f;
  const CBLAS_OFFSET offsetc = CblasFixOffset;
  const MKL_INT8 oa = 0;
  const MKL_INT8 ob = 0;
  MKL_INT32 oc = 0;
  const int m = dshape[0], n = wshape[0], k = dshape.ProdShape(1, dshape.ndim());
  //  cblas_gemm_s8u8s32 requires the first matrix to be uint8,
  //  so shift data from int8 (range [-128, 127]) to uint8 (range [0, 255])
  int shift = 128;
  Tensor<cpu, 1, uint8_t> shiftdata =
    ctx.requested[quantized_fc::kTempSpace].get_space_typed<cpu, 1, uint8_t>(
      Shape1(m * k), s);
  #pragma omp parallel for num_threads(omp_threads)
  for (int i = 0; i < m * k; ++i) {
    shiftdata.dptr_[i] = data_temp[i] + shift;
  }

  Tensor<cpu, 1, float> min_output = out_data[quantized_fullc::kOutMin].get<cpu, 1, float>(s);
  Tensor<cpu, 1, float> max_output = out_data[quantized_fullc::kOutMax].get<cpu, 1, float>(s);
  Tensor<cpu, 1, float> min_data =
    in_data[num_inputs + quantized_fullc::kDataMin].get<cpu, 1, float>(s);
  Tensor<cpu, 1, float> max_data =
    in_data[num_inputs + quantized_fullc::kDataMax].get<cpu, 1, float>(s);
  Tensor<cpu, 1, float> min_weight =
    in_data[num_inputs + quantized_fullc::kWeightMin].get<cpu, 1, float>(s);
  Tensor<cpu, 1, float> max_weight =
    in_data[num_inputs + quantized_fullc::kWeightMax].get<cpu, 1, float>(s);

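  // Derive the representable range of the int32 output from the data and weight
  // ranges (product of two int8-quantized operands).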
  Kernel<QuantizationRangeForS8S8MultiplicationStruct, cpu>::Launch(s, 1, min_output.dptr_,
      max_output.dptr_, min_data.dptr_, max_data.dptr_, min_weight.dptr_, max_weight.dptr_);
  if (!param.no_bias) {
    Tensor<cpu, 1, int8_t> bias = in_data[fullc::kBias].get_with_shape<cpu, 1, int8_t>(
      Shape1(wshape[0]), s);
    Tensor<cpu, 1, float> min_bias =
      in_data[num_inputs + quantized_fullc::kBiasMin].get<cpu, 1, float>(s);
    Tensor<cpu, 1, float> max_bias =
      in_data[num_inputs + quantized_fullc::kBiasMax].get<cpu, 1, float>(s);

    Kernel<QuantizedSumInitKernelWithBias, cpu>::Launch(s, n, out.dptr_,
        bias.dptr_, min_output.dptr_, max_output.dptr_, min_bias.dptr_, max_bias.dptr_);
  } else {
    #pragma omp parallel for num_threads(omp_threads)
    for (int i = 0; i < m * n; ++i) {
      output_temp[i] = 0;
    }
  }
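  // Compensate for the +128 data shift: (data + shift) * W^T = data * W^T +
  // shift * rowsum(W), so subtract shift * rowsum(W) from the first output row
  // (which already holds the bias, if any) and broadcast it to all m rows;
  // the GEMM below then accumulates on top of this correction.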
  #pragma omp parallel for num_threads(omp_threads)
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j < k; ++j) {
      output_temp[i] -= shift * weight_temp[i * k + j];
    }
  }
  #pragma omp parallel for num_threads(omp_threads)
  for (int i = n; i < m * n; ++i) {
    output_temp[i] = output_temp[i % n];
  }
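  // out (m x n, int32) += shiftdata (m x k, uint8) * weight (n x k, int8)^T,
  // row-major, with a fixed zero output offset (oc = 0).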
  cblas_gemm_s8u8s32(CblasRowMajor,
                     CblasNoTrans,
                     CblasTrans,
                     offsetc,
                     m,
                     n,
                     k,
                     alpha,
                     shiftdata.dptr_,
                     k,
                     oa,
                     weight.dptr_,
                     k,
                     ob,
                     beta,
                     out.dptr_,
                     n,
                     &oc);
#else
  LOG(FATAL) << "Quantized fully connected operator relies on cblas_gemm_s8u8s32"
             << " which is only supported by MKL BLAS."
             << " Please build MXNet with USE_BLAS=mkl to leverage this operator.";
#endif
}

#if MXNET_USE_MKLDNN == 1
void QuantizedFullyConnectedForwardExCPU(const nnvm::NodeAttrs &attrs,
                                         const OpContext &ctx,
                                         const std::vector<NDArray> &in_data,
                                         const std::vector<OpReqType> &req,
                                         const std::vector<NDArray> &out_data) {
  MKLDNNQuantizedFullyConnectedForward(attrs, ctx, in_data, req, out_data);
}
#endif

NNVM_REGISTER_OP(_contrib_quantized_fully_connected)
.describe(R"code(Fully Connected operator for input, weight and bias of data type int8,
accumulating into an int32 output. For each argument, two more arguments of type float32
must be provided, representing the thresholds used to quantize that argument from float32
to int8. The final outputs contain the fully connected result in int32, and min and max
thresholds representing the thresholds for quantizing the float32 output into int32.

.. Note::
    This operator only supports forward propagation. DO NOT use it in training.)code" ADD_FILELINE)
.set_num_inputs(
  [](const NodeAttrs& attrs) {
    const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
    return param.no_bias ? 6 : 9;
  })
.set_num_outputs(3)
.set_attr_parser(ParamParser<FullyConnectedParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
    if (param.no_bias) {
      return std::vector<std::string>{"data", "weight", "min_data", "max_data",
                                      "min_weight", "max_weight"};
    } else {
      return std::vector<std::string>{"data", "weight", "bias", "min_data", "max_data",
                                      "min_weight", "max_weight", "min_bias", "max_bias"};
    }
  })
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
  [](const NodeAttrs& attrs) {
    return std::vector<std::string>{"output", "min_output", "max_output"};
  })
.set_attr<mxnet::FInferShape>("FInferShape", QuantizedFullyConnectedShape)
.set_attr<nnvm::FInferType>("FInferType", QuantizedFullyConnectedType)
.set_attr<FInferStorageType>("FInferStorageType", QuantizedFullyConnectedStorageType)
// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow,
// will be reverted after the improvement of CachedOP is done.
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
.set_attr<FCompute>("FCompute<cpu>", QuantizedFullyConnectedForwardCPU)
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", QuantizedFullyConnectedForwardExCPU)
#endif
.set_attr<FResourceRequest>("FResourceRequest",
  [](const NodeAttrs& attrs) {
    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
  })
.add_argument("data", "NDArray-or-Symbol", "Input data.")
.add_argument("weight", "NDArray-or-Symbol", "weight.")
.add_argument("bias", "NDArray-or-Symbol", "bias.")
.add_argument("min_data", "NDArray-or-Symbol", "Minimum value of data.")
.add_argument("max_data", "NDArray-or-Symbol", "Maximum value of data.")
.add_argument("min_weight", "NDArray-or-Symbol", "Minimum value of weight.")
.add_argument("max_weight", "NDArray-or-Symbol", "Maximum value of weight.")
.add_argument("min_bias", "NDArray-or-Symbol", "Minimum value of bias.")
.add_argument("max_bias", "NDArray-or-Symbol", "Maximum value of bias.")
.add_arguments(FullyConnectedParam::__FIELDS__());

NNVM_REGISTER_OP(FullyConnected)
.set_attr<FQuantizable>("FQuantizable", [](const NodeAttrs& attrs) {
    return QuantizeType::kMust;
})
.set_attr<FQuantizedOp>("FQuantizedOp", [](const NodeAttrs& attrs) {
    nnvm::ObjectPtr node = nnvm::Node::Create();
    node->attrs.op = Op::Get("_contrib_quantized_fully_connected");
    node->attrs.name = "quantized_" + attrs.name;
    node->attrs.dict = attrs.dict;
    if (node->op()->attr_parser != nullptr) {
      node->op()->attr_parser(&(node->attrs));
    }
    return node;
  });
}  // namespace op
}  // namespace mxnet
378