1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file mkldnn_quantize_v2-inl.h
22 * \brief
23 */
24
25 #ifndef MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZE_V2_INL_H_
26 #define MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZE_V2_INL_H_
27 #if MXNET_USE_MKLDNN == 1
28 #include <algorithm>
29 #include <string>
30 #include <vector>
31 #include "../../nn/mkldnn/mkldnn_base-inl.h"
32 #include "../quantize_v2-inl.h"
33
34 namespace mxnet {
35 namespace op {
36
37 class SgMKLDNNQuantizeOperator {
38 public:
SgMKLDNNQuantizeOperator(const nnvm::NodeAttrs & attrs)39 explicit SgMKLDNNQuantizeOperator(const nnvm::NodeAttrs &attrs)
40 : param_(nnvm::get<QuantizeV2Param>(attrs.parsed)) {}
41
42 void Forward(const OpContext &ctx, const std::vector<NDArray> &inputs,
43 const std::vector<OpReqType> &req, const std::vector<NDArray> &outputs);
44
45 private:
46 bool initalized_{false};
47 QuantizeV2Param param_;
48 float cached_data_min_{0.f};
49 float cached_data_max_{0.f};
50 mkldnn::memory::desc o_desc_;
51 mkldnn_args_map_t args_;
52 std::shared_ptr<mkldnn::reorder> fwd_pd_;
53 };
54
Forward(const OpContext & ctx,const std::vector<NDArray> & inputs,const std::vector<OpReqType> & req,const std::vector<NDArray> & outputs)55 void SgMKLDNNQuantizeOperator::Forward(const OpContext &ctx, const std::vector<NDArray> &inputs,
56 const std::vector<OpReqType> &req,
57 const std::vector<NDArray> &outputs) {
58 float quantized_range = 0.0;
59 NDArray in_buffer = inputs[0];
60 float data_min = mshadow::red::limits::MaxValue<float>();
61 float data_max = mshadow::red::limits::MinValue<float>();
62
63 // Pass through quantized data
64 if (inputs[0].dtype() == mshadow::kUint8 || inputs[0].dtype() == mshadow::kInt8) {
65 if (param_.min_calib_range.has_value() && param_.max_calib_range.has_value()) {
66 *outputs[1].data().dptr<float>() = param_.min_calib_range.value();
67 *outputs[2].data().dptr<float>() = param_.max_calib_range.value();
68 } else {
69 if (inputs[0].dtype() == mshadow::kUint8) {
70 *outputs[1].data().dptr<float>() = 0;
71 *outputs[2].data().dptr<float>() = kUint8Range;
72 } else {
73 *outputs[1].data().dptr<float>() = -kInt8Range;
74 *outputs[2].data().dptr<float>() = kInt8Range;
75 }
76 }
77 if (req[0] != kWriteInplace) {
78 const_cast<NDArray &>(outputs[0]).CopyFrom(*inputs[0].GetMKLDNNData());
79 MKLDNNStream::Get()->Submit();
80 }
81 } else {
82 if (in_buffer.IsView() && in_buffer.IsMKLDNNData()) in_buffer = inputs[0].Reorder2Default();
83 auto i_mem = in_buffer.GetMKLDNNData();
84
85 if (param_.min_calib_range.has_value() && param_.max_calib_range.has_value()) {
86 data_min = param_.min_calib_range.value();
87 data_max = param_.max_calib_range.value();
88 } else {
89 // no calib info
90 in_buffer = inputs[0].Reorder2Default();
91 auto in_ptr = in_buffer.data().dptr<float>();
92 auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
93 std::vector<float> data_maxs(nthreads, data_max);
94 std::vector<float> data_mins(nthreads, data_min);
95 #pragma omp parallel for num_threads(nthreads)
96 for (index_t i = 0; i < static_cast<index_t>(in_buffer.shape().Size()); i++) {
97 int tid = omp_get_thread_num();
98 if (in_ptr[i] > data_maxs[tid]) data_maxs[tid] = in_ptr[i];
99 if (in_ptr[i] < data_mins[tid]) data_mins[tid] = in_ptr[i];
100 }
101 for (index_t i = 0; i < nthreads; i++) {
102 if (data_maxs[i] > data_max) data_max = data_maxs[i];
103 if (data_mins[i] < data_min) data_min = data_mins[i];
104 }
105
106 if (initalized_ && (cached_data_min_ != data_min || cached_data_max_ != data_max))
107 initalized_ = false;
108 }
109
110 // Write output min/max
111 auto out_type = GetQuantizeOutputType(param_);
112 if (out_type == mshadow::kUint8) {
113 quantized_range = kUint8Range;
114 *outputs[1].data().dptr<float>() = data_min;
115 *outputs[2].data().dptr<float>() = data_max;
116 } else if (out_type == mshadow::kInt8) {
117 float real_range = MaxAbs(data_min, data_max);
118 quantized_range = kInt8Range;
119 *outputs[1].data().dptr<float>() = -real_range;
120 *outputs[2].data().dptr<float>() = real_range;
121 } else {
122 LOG(FATAL) << "mkldnn quantize op only supports int8 and uint8 as output type";
123 }
124
125 if (!initalized_) {
126 cached_data_min_ = data_min;
127 cached_data_max_ = data_max;
128 float real_range = MaxAbs(data_min, data_max);
129 float scale = quantized_range / real_range;
130 mkldnn::primitive_attr attr;
131 const int mask = 0;
132 std::vector<float> scales = {scale};
133 attr.set_output_scales(mask, scales);
134 mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine();
135 auto i_desc = i_mem->get_desc();
136 size_t i_ndim = in_buffer.shape().ndim();
137 if (i_ndim == 4) {
138 mkldnn::memory::format_tag o_fmt = mkldnn::memory::format_tag::nhwc;
139 mkldnn::memory::dims o_dims(i_desc.data.dims, i_desc.data.dims + i_desc.data.ndims);
140 o_desc_ = mkldnn::memory::desc(o_dims, get_mkldnn_type(out_type), o_fmt);
141 } else {
142 o_desc_ = i_desc;
143 o_desc_.data.data_type = get_mkldnn_type_t(out_type);
144 }
145 auto reorder_pd =
146 mkldnn::reorder::primitive_desc(cpu_engine, i_desc, cpu_engine, o_desc_, attr);
147 fwd_pd_ = std::make_shared<mkldnn::reorder>(reorder_pd);
148 initalized_ = true;
149 }
150 auto o_mem = CreateMKLDNNMem(outputs[0], o_desc_, req[0]);
151 args_[MKLDNN_ARG_FROM] = *i_mem;
152 args_[MKLDNN_ARG_TO] = *o_mem.second;
153 MKLDNNStream::Get()->RegisterPrimArgs(*fwd_pd_, args_);
154 CommitOutput(outputs[0], o_mem);
155 MKLDNNStream::Get()->Submit();
156 }
157 }
158
SgMKLDNNQuantizeForward(const OpStatePtr & state_ptr,const OpContext & ctx,const std::vector<NDArray> & inputs,const std::vector<OpReqType> & req,const std::vector<NDArray> & outputs)159 static void SgMKLDNNQuantizeForward(const OpStatePtr &state_ptr, const OpContext &ctx,
160 const std::vector<NDArray> &inputs,
161 const std::vector<OpReqType> &req,
162 const std::vector<NDArray> &outputs) {
163 SgMKLDNNQuantizeOperator &op = state_ptr.get_state<SgMKLDNNQuantizeOperator>();
164 op.Forward(ctx, inputs, req, outputs);
165 }
166
167 } // namespace op
168 } // namespace mxnet
169
170 #endif // MXNET_USE_MKLDNN == 1
171 #endif // MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZE_V2_INL_H_
172