1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file quantize-inl.h
22  * \brief implementation of quantize operation
23  */
24 #ifndef MXNET_OPERATOR_QUANTIZATION_QUANTIZE_INL_H_
25 #define MXNET_OPERATOR_QUANTIZATION_QUANTIZE_INL_H_
26 
27 #include <mxnet/operator_util.h>
28 #include <vector>
29 #include <limits>
30 #include "../elemwise_op_common.h"
31 #include "../mshadow_op.h"
32 #include "../mxnet_op.h"
33 #include "./quantization_utils.h"
34 
35 namespace mxnet {
36 namespace op {
37 
38 struct QuantizeParam : public dmlc::Parameter<QuantizeParam> {
39   int   out_type;
DMLC_DECLARE_PARAMETERQuantizeParam40   DMLC_DECLARE_PARAMETER(QuantizeParam) {
41     DMLC_DECLARE_FIELD(out_type)
42     .add_enum("int8", mshadow::kInt8)
43     .add_enum("uint8", mshadow::kUint8)
44     .set_default(mshadow::kUint8)
45     .describe("Output data type.");
46   }
47 };
48 
49 // quantize float to uint8_t
50 struct quantize_unsigned {
51   template<typename DstDType, typename SrcDType>
Mapquantize_unsigned52   MSHADOW_XINLINE static void Map(int i, DstDType *out, float *omin_range,
53                                   float *omax_range, const SrcDType *in,
54                                   const float *imin_range, const float *imax_range,
55                                   const double min_limit, const double max_limit) {
56     using mshadow::red::limits::MinValue;
57     using mshadow::red::limits::MaxValue;
58     const float scale = (max_limit - min_limit) / (*imax_range - *imin_range);
59     out[i] = static_cast<DstDType>((in[i] - *imin_range) * scale + 0.5);
60     *omin_range = *imin_range;
61     *omax_range = *imax_range;
62   }
63 };
64 
65 
66 // keep zero-center
67 struct quantize_zero_centered {
68   template<typename DstDType, typename SrcDType>
Mapquantize_zero_centered69   MSHADOW_XINLINE static void Map(int i, DstDType *out, float *omin_range,
70                                   float *omax_range, const SrcDType *in,
71                                   const float *imin_range, const float *imax_range,
72                                   const float quantized_range) {
73     float real_range = MaxAbs(*imin_range, *imax_range);
74     float scale = quantized_range / real_range;
75     SrcDType x = in[i];
76     out[i] = static_cast<DstDType>(
77         Sign(x) * Min(Abs(x) * scale + 0.5f, quantized_range));
78     *omin_range = -real_range;
79     *omax_range =  real_range;
80   }
81 };
82 
83 template<typename xpu>
QuantizeCompute(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<OpReqType> & req,const std::vector<TBlob> & outputs)84 void QuantizeCompute(const nnvm::NodeAttrs& attrs,
85                      const OpContext& ctx,
86                      const std::vector<TBlob>& inputs,
87                      const std::vector<OpReqType>& req,
88                      const std::vector<TBlob>& outputs) {
89   using namespace mshadow;
90   using namespace mxnet_op;
91   using mshadow::red::limits::MinValue;
92   using mshadow::red::limits::MaxValue;
93   Stream<xpu> *s = ctx.get_stream<xpu>();
94 
95   const QuantizeParam& param = nnvm::get<QuantizeParam>(attrs.parsed);
96   if (param.out_type == mshadow::kUint8) {
97     if (std::is_same<xpu, gpu>::value) {
98       LOG(FATAL) << "currently, uint8 quantization is only supported by CPU, "
99                     "please switch to the context of CPU or int8 data type for GPU.";
100     }
101     Kernel<quantize_unsigned, xpu>::Launch(s, outputs[0].Size(),
102       outputs[0].dptr<uint8_t>(), outputs[1].dptr<float>(), outputs[2].dptr<float>(),
103       inputs[0].dptr<float>(), inputs[1].dptr<float>(), inputs[2].dptr<float>(),
104       MinValue<uint8_t>(), MaxValue<uint8_t>());
105   } else if (param.out_type == mshadow::kInt8) {  // zero-centered quantization
106     Kernel<quantize_zero_centered, xpu>::Launch(s, outputs[0].Size(),
107       outputs[0].dptr<int8_t>(), outputs[1].dptr<float>(), outputs[2].dptr<float>(),
108       inputs[0].dptr<float>(), inputs[1].dptr<float>(), inputs[2].dptr<float>(),
109       MinAbs(MaxValue<int8_t>(), MinValue<int8_t>()));
110   } else {
111     LOG(FATAL) << "quantize op only supports int8 and uint8 as output type";
112   }
113 }
114 
QuantizeShape(const nnvm::NodeAttrs & attrs,mxnet::ShapeVector * in_attrs,mxnet::ShapeVector * out_attrs)115 inline bool QuantizeShape(const nnvm::NodeAttrs& attrs,
116                           mxnet::ShapeVector *in_attrs,
117                           mxnet::ShapeVector *out_attrs) {
118   CHECK_EQ(in_attrs->size(), 3U);
119   CHECK_EQ(out_attrs->size(), 3U);
120 
121   mxnet::TShape dshape = (*in_attrs)[0];
122   for (size_t i = 1; i < 3; ++i) {
123     SHAPE_ASSIGN_CHECK(*in_attrs, i, mxnet::TShape(1, 1));
124   }
125 
126   SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
127   SHAPE_ASSIGN_CHECK(*out_attrs, 1, mxnet::TShape(1, 1));
128   SHAPE_ASSIGN_CHECK(*out_attrs, 2, mxnet::TShape(1, 1));
129 
130   if ((*out_attrs)[0].ndim() > 0) {
131     dshape[0] = ((*out_attrs)[0])[0];
132     SHAPE_ASSIGN_CHECK(*in_attrs, 0, dshape);
133   }
134 
135   return shape_is_known(out_attrs->at(0));
136 }
137 
QuantizeType(const nnvm::NodeAttrs & attrs,std::vector<int> * in_attrs,std::vector<int> * out_attrs)138 inline bool QuantizeType(const nnvm::NodeAttrs& attrs,
139                          std::vector<int> *in_attrs,
140                          std::vector<int> *out_attrs) {
141   CHECK_EQ(in_attrs->size(), 3U);
142   CHECK_EQ(out_attrs->size(), 3U);
143   const QuantizeParam& param = nnvm::get<QuantizeParam>(attrs.parsed);
144   TYPE_ASSIGN_CHECK(*in_attrs, 0, mshadow::kFloat32);
145   TYPE_ASSIGN_CHECK(*in_attrs, 1, mshadow::kFloat32);
146   TYPE_ASSIGN_CHECK(*in_attrs, 2, mshadow::kFloat32);
147   if (param.out_type == mshadow::kUint8) {
148     TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kUint8);
149   } else if (param.out_type == mshadow::kInt8) {
150     TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt8);
151   } else {
152     LOG(FATAL) << "quantize op only supports int8 and uint8 as output type";
153   }
154   TYPE_ASSIGN_CHECK(*out_attrs, 1, mshadow::kFloat32);
155   TYPE_ASSIGN_CHECK(*out_attrs, 2, mshadow::kFloat32);
156   return (*in_attrs)[0] != -1;
157 }
158 
159 }  // namespace op
160 }  // namespace mxnet
161 #endif  // MXNET_OPERATOR_QUANTIZATION_QUANTIZE_INL_H_
162