/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file quantized_elemwise_mul.cc
 * \brief CPU implementation of the quantized elementwise binary multiply operator
 */
#include <mxnet/op_attr_types.h>
#include "../tensor/elemwise_binary_op-inl.h"
#include "./quantized_elemwise_mul-inl.h"
#include "./quantization_utils.h"

namespace mxnet {
namespace op {

DMLC_REGISTER_PARAMETER(QuantizeElemwiseMulParam);

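// Lists output names: a single "output" when enable_float_output is set,
// otherwise the quantized output together with its min/max range scalars.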
static std::vector<std::string> QuantizedElemwiseMulOutputNames(const NodeAttrs &attrs) {
  const QuantizeElemwiseMulParam& params = nnvm::get<QuantizeElemwiseMulParam>(attrs.parsed);
  if (params.enable_float_output)
    return std::vector<std::string>{"output"};
  else
    return std::vector<std::string>{"output", "min_output", "max_output"};
}

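// Shape inference: lhs and rhs must have identical, fully known shapes
// (broadcast is not supported), and each min/max input is a scalar of
// shape (1,). The output takes the lhs shape; when the output stays
// quantized, two additional (1,) outputs carry its range.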
inline bool QuantizedElemwiseMulOpShape(const nnvm::NodeAttrs& attrs,
                                        mxnet::ShapeVector *in_attrs,
                                        mxnet::ShapeVector *out_attrs) {
  using namespace mshadow;
  const QuantizeElemwiseMulParam& params = nnvm::get<QuantizeElemwiseMulParam>(attrs.parsed);
  const mxnet::TShape &lshape = (*in_attrs)[quantized_elemwise_mul::kLhs];
  const mxnet::TShape &rshape = (*in_attrs)[quantized_elemwise_mul::kRhs];
  if (!ndim_is_known(lshape) || !ndim_is_known(rshape)) return false;
  CHECK_EQ(lshape.ndim(), rshape.ndim())
    << "Currently, quantized elemwise multiply doesn't support broadcast.";
  for (int i = 0; i < lshape.ndim(); ++i) {
    CHECK_EQ(lshape[i], rshape[i]);
  }
  SHAPE_ASSIGN_CHECK(*in_attrs, quantized_elemwise_mul::kLhsMin, mxnet::TShape(1, 1));
  SHAPE_ASSIGN_CHECK(*in_attrs, quantized_elemwise_mul::kLhsMax, mxnet::TShape(1, 1));
  SHAPE_ASSIGN_CHECK(*in_attrs, quantized_elemwise_mul::kRhsMin, mxnet::TShape(1, 1));
  SHAPE_ASSIGN_CHECK(*in_attrs, quantized_elemwise_mul::kRhsMax, mxnet::TShape(1, 1));

  out_attrs->clear();
  SHAPE_ASSIGN_CHECK(*out_attrs, quantized_elemwise_mul::kOut, lshape);
  if (!params.enable_float_output) {
    SHAPE_ASSIGN_CHECK(*out_attrs, quantized_elemwise_mul::kOutMin, mxnet::TShape(1, 1));
    SHAPE_ASSIGN_CHECK(*out_attrs, quantized_elemwise_mul::kOutMax, mxnet::TShape(1, 1));
  }
  return true;
}

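// Type inference: both data inputs must be int8 and every min/max input is
// float32. The quantized output is int8 when calibration ranges are given
// (so the requantize step can be folded in), int32 otherwise; with
// enable_float_output the single output is float32.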
inline bool QuantizedElemwiseMulOpType(const nnvm::NodeAttrs& attrs,
                                       std::vector<int> *in_type,
                                       std::vector<int> *out_type) {
  const QuantizeElemwiseMulParam& params = nnvm::get<QuantizeElemwiseMulParam>(attrs.parsed);
  for (int i = 0; i < 2; ++i) {
    if (in_type->at(i) == mshadow::kInt8) {
      TYPE_ASSIGN_CHECK(*in_type, i, mshadow::kInt8);
    } else {
      LOG(ERROR) << "Currently, quantized elemwise mul only supports int8 inputs.";
    }
  }
  TYPE_ASSIGN_CHECK(*in_type, quantized_elemwise_mul::kLhsMin, mshadow::kFloat32);
  TYPE_ASSIGN_CHECK(*in_type, quantized_elemwise_mul::kLhsMax, mshadow::kFloat32);
  TYPE_ASSIGN_CHECK(*in_type, quantized_elemwise_mul::kRhsMin, mshadow::kFloat32);
  TYPE_ASSIGN_CHECK(*in_type, quantized_elemwise_mul::kRhsMax, mshadow::kFloat32);

  int dtype = mshadow::kInt32;
  if (params.max_calib_range.has_value() && params.min_calib_range.has_value()) {
    dtype = mshadow::kInt8;
  }
  if (!params.enable_float_output) {
    TYPE_ASSIGN_CHECK(*out_type, quantized_elemwise_mul::kOut, dtype);
    TYPE_ASSIGN_CHECK(*out_type, quantized_elemwise_mul::kOutMin, mshadow::kFloat32);
    TYPE_ASSIGN_CHECK(*out_type, quantized_elemwise_mul::kOutMax, mshadow::kFloat32);
  } else {
    TYPE_ASSIGN_CHECK(*out_type, quantized_elemwise_mul::kOut, mshadow::kFloat32);
  }
  return true;
}

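// Storage inference: the operator runs as a dense FCompute kernel, so all
// inputs and outputs use default (dense) storage.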
inline bool QuantizedElemwiseMulOpStorageType(const nnvm::NodeAttrs& attrs,
                                              int dev_mask,
                                              DispatchMode* dispatch_mode,
                                              std::vector<int> *in_attrs,
                                              std::vector<int> *out_attrs) {
  using namespace common;
  *dispatch_mode = DispatchMode::kFCompute;

  for (auto &v : *out_attrs) {
    v = kDefaultStorage;
    if (common::stype_string(v).compare("unknown") == 0) {
      return false;
    }
  }

  for (auto &v : *in_attrs) {
    v = kDefaultStorage;
    if (common::stype_string(v).compare("unknown") == 0) {
      return false;
    }
  }
  return true;
}

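// Forward pass. Each int8 input x represents the real value x / x_scale,
// where x_scale = kInt8Range / MaxAbs(x_min, x_max). The element-wise
// product in real terms is (a * b) / (lhs_scale * rhs_scale), so a single
// multiplier out_scale suffices to turn the raw int8 product into the
// requantized (int8/int32) or dequantized (float) result.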
void QuantizedElemwiseMulOpForward(const nnvm::NodeAttrs &attrs,
                                   const OpContext &ctx,
                                   const std::vector<TBlob> &inputs,
                                   const std::vector<OpReqType> &req,
                                   const std::vector<TBlob> &outputs) {
  const QuantizeElemwiseMulParam& params = nnvm::get<QuantizeElemwiseMulParam>(attrs.parsed);
  using namespace mxnet_op;

  float lhs_min = inputs[quantized_elemwise_mul::kLhsMin].dptr<float>()[0];
  float lhs_max = inputs[quantized_elemwise_mul::kLhsMax].dptr<float>()[0];
  float rhs_min = inputs[quantized_elemwise_mul::kRhsMin].dptr<float>()[0];
  float rhs_max = inputs[quantized_elemwise_mul::kRhsMax].dptr<float>()[0];

  float cached_output_min_ = 0.f;
  float cached_output_max_ = 0.f;
  float out_data_scale = 1.f;
  float out_scale = 1.f;
  if (!params.enable_float_output) {
    // Both inputs are int8; the output data range depends on the inferred
    // output dtype (int8 when calibrated, int32 otherwise).
    float output_data_range = kInt32Range;
    if (outputs[quantized_elemwise_mul::kOut].type_flag_ == mshadow::kInt8) {
      output_data_range = kInt8Range;
    } else {
      output_data_range = kInt32Range;
    }
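    // Calibrated path: fold requantization into the multiply. Since
    // a * b = real_out * lhs_scale * rhs_scale and the quantized output is
    // real_out * out_data_scale, the single factor
    // out_scale = out_data_scale / (lhs_scale * rhs_scale) converts the raw
    // product directly into the calibrated output range.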
    if (params.max_calib_range.has_value() && params.min_calib_range.has_value()) {
      cached_output_min_ = params.min_calib_range.value();
      cached_output_max_ = params.max_calib_range.value();
      out_data_scale = output_data_range / MaxAbs(cached_output_min_, cached_output_max_);
      auto lhs_scale = kInt8Range / MaxAbs(lhs_min, lhs_max);
      auto rhs_scale = kInt8Range / MaxAbs(rhs_min, rhs_max);
      out_scale = out_data_scale / lhs_scale / rhs_scale;
    } else {
      // Without calibration, compute the int32 output range on the fly from
      // the input ranges.
      Stream<cpu> *s = ctx.get_stream<cpu>();
      if (inputs[quantized_elemwise_mul::kLhs].type_flag_ == mshadow::kInt8 &&
          inputs[quantized_elemwise_mul::kRhs].type_flag_ == mshadow::kInt8) {
        mxnet_op::Kernel<QuantizationRangeForS8S8MultiplicationStruct, cpu>::Launch(
            s, 1, &cached_output_min_, &cached_output_max_, &lhs_min, &lhs_max, &rhs_min, &rhs_max);
      } else {
        LOG(ERROR) << "lhs and rhs only support int8 dtype.";
      }
    }
  } else {
    // Float output: scale the raw int8 product back to real values.
    auto lhs_scale = kInt8Range / MaxAbs(lhs_min, lhs_max);
    auto rhs_scale = kInt8Range / MaxAbs(rhs_min, rhs_max);
    out_scale = 1.0 / lhs_scale / rhs_scale;
  }

  size_t out_size = outputs[quantized_elemwise_mul::kOut].Size();
  auto *input_l = inputs[quantized_elemwise_mul::kLhs].dptr<int8_t>();
  auto *input_r = inputs[quantized_elemwise_mul::kRhs].dptr<int8_t>();
  // TODO(Xinyu): a temp solution to enable Elemwise INT8 computation,
  // will be refactored after the DNNL primitive is done.
  if (!params.enable_float_output) {
    if (params.max_calib_range.has_value() && params.min_calib_range.has_value()) {
      typedef int8_t out_type;
      auto *out_data = outputs[quantized_elemwise_mul::kOut].dptr<out_type>();
#if !defined(_MSC_VER)
#pragma omp simd
#endif
      for (size_t i = 0; i < out_size; ++i) {
        const int8_t a = input_l[i];
        const int8_t b = input_r[i];
        out_data[i] = static_cast<out_type>(a * b * out_scale);
      }
    } else {
      typedef int32_t out_type;
      auto *out_data = outputs[quantized_elemwise_mul::kOut].dptr<out_type>();
#if !defined(_MSC_VER)
#pragma omp simd
#endif
      for (size_t i = 0; i < out_size; ++i) {
        const int8_t a = input_l[i];
        const int8_t b = input_r[i];
        out_data[i] = static_cast<out_type>(a * b * out_scale);
      }
    }
  } else {
    typedef float out_type;
    auto *out_data = outputs[quantized_elemwise_mul::kOut].dptr<out_type>();
#if !defined(_MSC_VER)
#pragma omp simd
#endif
    for (size_t i = 0; i < out_size; ++i) {
      const int8_t a = input_l[i];
      const int8_t b = input_r[i];
      out_data[i] = static_cast<out_type>(a * b * out_scale);
    }
  }

  if (!params.enable_float_output) {
    outputs[quantized_elemwise_mul::kOutMin].dptr<float>()[0] = cached_output_min_;
    outputs[quantized_elemwise_mul::kOutMax].dptr<float>()[0] = cached_output_max_;
  }
}

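// Operator registration: six inputs (two int8 tensors plus their min/max
// scalars) and, depending on enable_float_output, either one float output
// or three outputs (quantized data plus its min/max range).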
NNVM_REGISTER_OP(_contrib_quantized_elemwise_mul)
.describe(R"code(Multiplies int8 arguments element-wise.
)code" ADD_FILELINE)
.set_num_inputs(6)
.set_num_outputs([](const NodeAttrs& attrs) {
  const QuantizeElemwiseMulParam& params = nnvm::get<QuantizeElemwiseMulParam>(attrs.parsed);
  return (!params.enable_float_output) ? 3 : 1;
})
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    return std::vector<std::string>{"lhs", "rhs", "lhs_min", "lhs_max", "rhs_min", "rhs_max"};
  })
.set_attr<nnvm::FListOutputNames>("FListOutputNames", QuantizedElemwiseMulOutputNames)
.set_attr<mxnet::FInferShape>("FInferShape", QuantizedElemwiseMulOpShape)
.set_attr<nnvm::FInferType>("FInferType", QuantizedElemwiseMulOpType)
.set_attr<FInferStorageType>("FInferStorageType", QuantizedElemwiseMulOpStorageType)
.set_attr<FResourceRequest>("FResourceRequest",
  [](const NodeAttrs& attrs) {
    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
  })
.set_attr<FCompute>("FCompute<cpu>", QuantizedElemwiseMulOpForward)
// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow,
// will be reverted after the improvement of CachedOP is done.
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
.add_argument("lhs", "NDArray-or-Symbol", "first input")
.add_argument("rhs", "NDArray-or-Symbol", "second input")
.add_argument("lhs_min", "NDArray-or-Symbol", "Minimum value of first input.")
.add_argument("lhs_max", "NDArray-or-Symbol", "Maximum value of first input.")
.add_argument("rhs_min", "NDArray-or-Symbol", "Minimum value of second input.")
.add_argument("rhs_max", "NDArray-or-Symbol", "Maximum value of second input.")
.set_attr_parser(ParamParser<QuantizeElemwiseMulParam>)
.add_arguments(QuantizeElemwiseMulParam::__FIELDS__());

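// Hook for the quantization graph pass: querying FQuantizedOp on the float
// elemwise_mul yields the corresponding _contrib_quantized_elemwise_mul
// node with the original attributes copied over.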
NNVM_REGISTER_OP(elemwise_mul)
.set_attr<FQuantizedOp>("FQuantizedOp", [](const NodeAttrs& attrs) {
  nnvm::ObjectPtr node = nnvm::Node::Create();
  node->attrs.op = Op::Get("_contrib_quantized_elemwise_mul");
  node->attrs.name = "quantized_" + attrs.name;
  node->attrs.dict = attrs.dict;
  if (node->op()->attr_parser != nullptr) {
    node->op()->attr_parser(&(node->attrs));
  }
  return node;
});

}  // namespace op
}  // namespace mxnet