1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file quantize-inl.h
22 * \brief implementation of quantize operation
23 */
24 #ifndef MXNET_OPERATOR_QUANTIZATION_QUANTIZE_INL_H_
25 #define MXNET_OPERATOR_QUANTIZATION_QUANTIZE_INL_H_
26
27 #include <mxnet/operator_util.h>
28 #include <vector>
29 #include <limits>
30 #include "../elemwise_op_common.h"
31 #include "../mshadow_op.h"
32 #include "../mxnet_op.h"
33 #include "./quantization_utils.h"
34
35 namespace mxnet {
36 namespace op {
37
/*!
 * \brief Parameters of the quantize operator.
 */
struct QuantizeParam : public dmlc::Parameter<QuantizeParam> {
  // Target integer type, held as an mshadow type flag. Only mshadow::kInt8
  // ("int8") and mshadow::kUint8 ("uint8") are registered below. QuantizeCompute
  // dispatches on this value: uint8 -> quantize_unsigned, int8 ->
  // quantize_zero_centered.
  int out_type;
  DMLC_DECLARE_PARAMETER(QuantizeParam) {
    DMLC_DECLARE_FIELD(out_type)
    .add_enum("int8", mshadow::kInt8)
    .add_enum("uint8", mshadow::kUint8)
    .set_default(mshadow::kUint8)  // uint8 unless the caller requests int8
    .describe("Output data type.");
  }
};
48
49 // quantize float to uint8_t
50 struct quantize_unsigned {
51 template<typename DstDType, typename SrcDType>
Mapquantize_unsigned52 MSHADOW_XINLINE static void Map(int i, DstDType *out, float *omin_range,
53 float *omax_range, const SrcDType *in,
54 const float *imin_range, const float *imax_range,
55 const double min_limit, const double max_limit) {
56 using mshadow::red::limits::MinValue;
57 using mshadow::red::limits::MaxValue;
58 const float scale = (max_limit - min_limit) / (*imax_range - *imin_range);
59 out[i] = static_cast<DstDType>((in[i] - *imin_range) * scale + 0.5);
60 *omin_range = *imin_range;
61 *omax_range = *imax_range;
62 }
63 };
64
65
66 // keep zero-center
67 struct quantize_zero_centered {
68 template<typename DstDType, typename SrcDType>
Mapquantize_zero_centered69 MSHADOW_XINLINE static void Map(int i, DstDType *out, float *omin_range,
70 float *omax_range, const SrcDType *in,
71 const float *imin_range, const float *imax_range,
72 const float quantized_range) {
73 float real_range = MaxAbs(*imin_range, *imax_range);
74 float scale = quantized_range / real_range;
75 SrcDType x = in[i];
76 out[i] = static_cast<DstDType>(
77 Sign(x) * Min(Abs(x) * scale + 0.5f, quantized_range));
78 *omin_range = -real_range;
79 *omax_range = real_range;
80 }
81 };
82
83 template<typename xpu>
QuantizeCompute(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<OpReqType> & req,const std::vector<TBlob> & outputs)84 void QuantizeCompute(const nnvm::NodeAttrs& attrs,
85 const OpContext& ctx,
86 const std::vector<TBlob>& inputs,
87 const std::vector<OpReqType>& req,
88 const std::vector<TBlob>& outputs) {
89 using namespace mshadow;
90 using namespace mxnet_op;
91 using mshadow::red::limits::MinValue;
92 using mshadow::red::limits::MaxValue;
93 Stream<xpu> *s = ctx.get_stream<xpu>();
94
95 const QuantizeParam& param = nnvm::get<QuantizeParam>(attrs.parsed);
96 if (param.out_type == mshadow::kUint8) {
97 if (std::is_same<xpu, gpu>::value) {
98 LOG(FATAL) << "currently, uint8 quantization is only supported by CPU, "
99 "please switch to the context of CPU or int8 data type for GPU.";
100 }
101 Kernel<quantize_unsigned, xpu>::Launch(s, outputs[0].Size(),
102 outputs[0].dptr<uint8_t>(), outputs[1].dptr<float>(), outputs[2].dptr<float>(),
103 inputs[0].dptr<float>(), inputs[1].dptr<float>(), inputs[2].dptr<float>(),
104 MinValue<uint8_t>(), MaxValue<uint8_t>());
105 } else if (param.out_type == mshadow::kInt8) { // zero-centered quantization
106 Kernel<quantize_zero_centered, xpu>::Launch(s, outputs[0].Size(),
107 outputs[0].dptr<int8_t>(), outputs[1].dptr<float>(), outputs[2].dptr<float>(),
108 inputs[0].dptr<float>(), inputs[1].dptr<float>(), inputs[2].dptr<float>(),
109 MinAbs(MaxValue<int8_t>(), MinValue<int8_t>()));
110 } else {
111 LOG(FATAL) << "quantize op only supports int8 and uint8 as output type";
112 }
113 }
114
QuantizeShape(const nnvm::NodeAttrs & attrs,mxnet::ShapeVector * in_attrs,mxnet::ShapeVector * out_attrs)115 inline bool QuantizeShape(const nnvm::NodeAttrs& attrs,
116 mxnet::ShapeVector *in_attrs,
117 mxnet::ShapeVector *out_attrs) {
118 CHECK_EQ(in_attrs->size(), 3U);
119 CHECK_EQ(out_attrs->size(), 3U);
120
121 mxnet::TShape dshape = (*in_attrs)[0];
122 for (size_t i = 1; i < 3; ++i) {
123 SHAPE_ASSIGN_CHECK(*in_attrs, i, mxnet::TShape(1, 1));
124 }
125
126 SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
127 SHAPE_ASSIGN_CHECK(*out_attrs, 1, mxnet::TShape(1, 1));
128 SHAPE_ASSIGN_CHECK(*out_attrs, 2, mxnet::TShape(1, 1));
129
130 if ((*out_attrs)[0].ndim() > 0) {
131 dshape[0] = ((*out_attrs)[0])[0];
132 SHAPE_ASSIGN_CHECK(*in_attrs, 0, dshape);
133 }
134
135 return shape_is_known(out_attrs->at(0));
136 }
137
QuantizeType(const nnvm::NodeAttrs & attrs,std::vector<int> * in_attrs,std::vector<int> * out_attrs)138 inline bool QuantizeType(const nnvm::NodeAttrs& attrs,
139 std::vector<int> *in_attrs,
140 std::vector<int> *out_attrs) {
141 CHECK_EQ(in_attrs->size(), 3U);
142 CHECK_EQ(out_attrs->size(), 3U);
143 const QuantizeParam& param = nnvm::get<QuantizeParam>(attrs.parsed);
144 TYPE_ASSIGN_CHECK(*in_attrs, 0, mshadow::kFloat32);
145 TYPE_ASSIGN_CHECK(*in_attrs, 1, mshadow::kFloat32);
146 TYPE_ASSIGN_CHECK(*in_attrs, 2, mshadow::kFloat32);
147 if (param.out_type == mshadow::kUint8) {
148 TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kUint8);
149 } else if (param.out_type == mshadow::kInt8) {
150 TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt8);
151 } else {
152 LOG(FATAL) << "quantize op only supports int8 and uint8 as output type";
153 }
154 TYPE_ASSIGN_CHECK(*out_attrs, 1, mshadow::kFloat32);
155 TYPE_ASSIGN_CHECK(*out_attrs, 2, mshadow::kFloat32);
156 return (*in_attrs)[0] != -1;
157 }
158
159 } // namespace op
160 } // namespace mxnet
161 #endif // MXNET_OPERATOR_QUANTIZATION_QUANTIZE_INL_H_
162