1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file mkldnn_dequantize-inl.h
22 * \author Wenting Jiang, Xinyu Chen
23 * \brief
24 */
25
26 #ifndef MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_DEQUANTIZE_INL_H_
27 #define MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_DEQUANTIZE_INL_H_
28 #if MXNET_USE_MKLDNN == 1
29 #include <algorithm>
30 #include <string>
31 #include <vector>
32 #include "../../nn/mkldnn/mkldnn_base-inl.h"
33
34 namespace mxnet {
35 namespace op {
36
37
38 class SgMKLDNNDequantizeOperator {
39 public:
SgMKLDNNDequantizeOperator(const nnvm::NodeAttrs & attrs)40 explicit SgMKLDNNDequantizeOperator(const nnvm::NodeAttrs &attrs)
41 : param_(nnvm::get<DequantizeParam>(attrs.parsed)) {}
42
43 void Forward(const OpContext &ctx, const std::vector<NDArray> &inputs,
44 const std::vector<OpReqType> &req, const std::vector<NDArray> &outputs);
45
46 private:
47 bool initialized_{false};
48 DequantizeParam param_;
49 float cached_data_min_{0.f};
50 float cached_data_max_{0.f};
51 mkldnn::memory::desc o_desc_;
52 mkldnn_args_map_t args_;
53 std::shared_ptr<mkldnn::reorder> fwd_pd_;
54 };
55
Forward(const OpContext & ctx,const std::vector<NDArray> & inputs,const std::vector<OpReqType> & req,const std::vector<NDArray> & outputs)56 void SgMKLDNNDequantizeOperator::Forward(const OpContext &ctx, const std::vector<NDArray> &inputs,
57 const std::vector<OpReqType> &req,
58 const std::vector<NDArray> &outputs) {
59 NDArray in_buffer = inputs[0];
60 if (inputs[0].IsView() && inputs[0].IsMKLDNNData()) in_buffer = inputs[0].Reorder2Default();
61 auto i_mem = in_buffer.GetMKLDNNData();
62 float data_min = *inputs[1].data().dptr<float>();
63 float data_max = *inputs[2].data().dptr<float>();
64
65 if (initialized_ && (cached_data_min_ != data_min || cached_data_max_ != data_max))
66 initialized_ = false;
67
68 if (!initialized_) {
69 cached_data_min_ = data_min;
70 cached_data_max_ = data_max;
71 float real_range = MaxAbs(cached_data_min_, cached_data_max_);
72 float quantized_range = 0.0;
73 if (inputs[0].dtype() == mshadow::kUint8) {
74 quantized_range = kUint8Range;
75 } else if (inputs[0].dtype() == mshadow::kInt8) {
76 quantized_range = kInt8Range;
77 real_range = MaxAbs(*inputs[1].data().dptr<float>(), *inputs[2].data().dptr<float>());
78 } else {
79 LOG(FATAL) << "mkldnn dequantize op only supports int8 and uint8 as output type";
80 }
81 float scale = real_range / quantized_range;
82 mkldnn::primitive_attr attr;
83 const int mask = 0;
84 std::vector<float> scales = {scale};
85 attr.set_output_scales(mask, scales);
86 mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine();
87 auto i_desc = i_mem->get_desc();
88 size_t i_ndim = in_buffer.shape().ndim();
89 if (i_ndim == 4) {
90 mkldnn::memory::format_tag o_fmt = mkldnn::memory::format_tag::nchw;
91 mkldnn::memory::dims o_dims(i_desc.data.dims, i_desc.data.dims + i_desc.data.ndims);
92 o_desc_ = mkldnn::memory::desc(o_dims, get_mkldnn_type<float>(), o_fmt);
93 } else {
94 o_desc_ = i_desc;
95 o_desc_.data.data_type = get_mkldnn_type_t<float>();
96 }
97 auto reorder_pd =
98 mkldnn::reorder::primitive_desc(cpu_engine, i_desc, cpu_engine, o_desc_, attr);
99 fwd_pd_ = std::make_shared<mkldnn::reorder>(reorder_pd);
100 initialized_ = true;
101 }
102 auto o_mem = CreateMKLDNNMem(outputs[0], o_desc_, req[0]);
103 args_[MKLDNN_ARG_FROM] = *i_mem;
104 args_[MKLDNN_ARG_TO] = *o_mem.second;
105 MKLDNNStream::Get()->RegisterPrimArgs(*fwd_pd_, args_);
106 CommitOutput(outputs[0], o_mem);
107 MKLDNNStream::Get()->Submit();
108 }
109
SgMKLDNNDequantizeForward(const OpStatePtr & state_ptr,const OpContext & ctx,const std::vector<NDArray> & inputs,const std::vector<OpReqType> & req,const std::vector<NDArray> & outputs)110 static void SgMKLDNNDequantizeForward(const OpStatePtr &state_ptr, const OpContext &ctx,
111 const std::vector<NDArray> &inputs,
112 const std::vector<OpReqType> &req,
113 const std::vector<NDArray> &outputs) {
114 SgMKLDNNDequantizeOperator &op = state_ptr.get_state<SgMKLDNNDequantizeOperator>();
115 op.Forward(ctx, inputs, req, outputs);
116 }
117
118
119
120 } // namespace op
121 } // namespace mxnet
122
123 #endif // MXNET_USE_MKLDNN == 1
124 #endif // MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_DEQUANTIZE_INL_H_
125