1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file mkldnn_dequantize-inl.h
22  * \author Wenting Jiang, Xinyu Chen
 23  * \brief MKLDNN (oneDNN) implementation of the dequantize operator
 23  *        (int8/uint8 -> fp32 via a scaled reorder primitive).
24  */
25 
26 #ifndef MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_DEQUANTIZE_INL_H_
27 #define MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_DEQUANTIZE_INL_H_
28 #if MXNET_USE_MKLDNN == 1
29 #include <algorithm>
30 #include <string>
31 #include <vector>
32 #include "../../nn/mkldnn/mkldnn_base-inl.h"
33 
34 namespace mxnet {
35 namespace op {
36 
37 
/*!
 * \brief Stateful MKLDNN dequantize operator.
 *
 * Caches the compiled MKLDNN reorder primitive together with the
 * calibration range it was built for, so repeated Forward() calls with an
 * unchanged [min, max] range reuse the primitive instead of rebuilding it.
 */
class SgMKLDNNDequantizeOperator {
 public:
  explicit SgMKLDNNDequantizeOperator(const nnvm::NodeAttrs &attrs)
      : param_(nnvm::get<DequantizeParam>(attrs.parsed)) {}

  /*!
   * \brief Dequantize inputs[0] into outputs[0].
   * \param ctx     operator context
   * \param inputs  [0] quantized data, [1] min range, [2] max range
   * \param req     write request for outputs[0]
   * \param outputs [0] fp32 result
   */
  void Forward(const OpContext &ctx, const std::vector<NDArray> &inputs,
               const std::vector<OpReqType> &req, const std::vector<NDArray> &outputs);

 private:
  bool initialized_{false};         // true once fwd_pd_/o_desc_ match the cached range
  DequantizeParam param_;
  float cached_data_min_{0.f};      // calibration range the cached primitive was built for
  float cached_data_max_{0.f};
  mkldnn::memory::desc o_desc_;     // output memory descriptor (fp32)
  mkldnn_args_map_t args_;          // FROM/TO argument map reused across calls
  std::shared_ptr<mkldnn::reorder> fwd_pd_;  // cached reorder primitive
};
55 
Forward(const OpContext & ctx,const std::vector<NDArray> & inputs,const std::vector<OpReqType> & req,const std::vector<NDArray> & outputs)56 void SgMKLDNNDequantizeOperator::Forward(const OpContext &ctx, const std::vector<NDArray> &inputs,
57                                          const std::vector<OpReqType> &req,
58                                          const std::vector<NDArray> &outputs) {
59   NDArray in_buffer = inputs[0];
60   if (inputs[0].IsView() && inputs[0].IsMKLDNNData()) in_buffer = inputs[0].Reorder2Default();
61   auto i_mem = in_buffer.GetMKLDNNData();
62   float data_min = *inputs[1].data().dptr<float>();
63   float data_max = *inputs[2].data().dptr<float>();
64 
65   if (initialized_ && (cached_data_min_ != data_min || cached_data_max_ != data_max))
66     initialized_ = false;
67 
68   if (!initialized_) {
69     cached_data_min_ = data_min;
70     cached_data_max_ = data_max;
71     float real_range = MaxAbs(cached_data_min_, cached_data_max_);
72     float quantized_range = 0.0;
73     if (inputs[0].dtype() == mshadow::kUint8) {
74       quantized_range = kUint8Range;
75     } else if (inputs[0].dtype() == mshadow::kInt8) {
76       quantized_range = kInt8Range;
77       real_range = MaxAbs(*inputs[1].data().dptr<float>(), *inputs[2].data().dptr<float>());
78     } else {
79       LOG(FATAL) << "mkldnn dequantize op only supports int8 and uint8 as output type";
80     }
81     float scale = real_range / quantized_range;
82     mkldnn::primitive_attr attr;
83     const int mask = 0;
84     std::vector<float> scales = {scale};
85     attr.set_output_scales(mask, scales);
86     mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine();
87     auto i_desc = i_mem->get_desc();
88     size_t i_ndim = in_buffer.shape().ndim();
89     if (i_ndim == 4) {
90       mkldnn::memory::format_tag o_fmt = mkldnn::memory::format_tag::nchw;
91       mkldnn::memory::dims o_dims(i_desc.data.dims, i_desc.data.dims + i_desc.data.ndims);
92       o_desc_ = mkldnn::memory::desc(o_dims, get_mkldnn_type<float>(), o_fmt);
93     } else {
94       o_desc_ = i_desc;
95       o_desc_.data.data_type = get_mkldnn_type_t<float>();
96     }
97     auto reorder_pd =
98         mkldnn::reorder::primitive_desc(cpu_engine, i_desc, cpu_engine, o_desc_, attr);
99     fwd_pd_ = std::make_shared<mkldnn::reorder>(reorder_pd);
100     initialized_ = true;
101   }
102   auto o_mem = CreateMKLDNNMem(outputs[0], o_desc_, req[0]);
103   args_[MKLDNN_ARG_FROM] = *i_mem;
104   args_[MKLDNN_ARG_TO] = *o_mem.second;
105   MKLDNNStream::Get()->RegisterPrimArgs(*fwd_pd_, args_);
106   CommitOutput(outputs[0], o_mem);
107   MKLDNNStream::Get()->Submit();
108 }
109 
SgMKLDNNDequantizeForward(const OpStatePtr & state_ptr,const OpContext & ctx,const std::vector<NDArray> & inputs,const std::vector<OpReqType> & req,const std::vector<NDArray> & outputs)110 static void SgMKLDNNDequantizeForward(const OpStatePtr &state_ptr, const OpContext &ctx,
111                                       const std::vector<NDArray> &inputs,
112                                       const std::vector<OpReqType> &req,
113                                       const std::vector<NDArray> &outputs) {
114   SgMKLDNNDequantizeOperator &op = state_ptr.get_state<SgMKLDNNDequantizeOperator>();
115   op.Forward(ctx, inputs, req, outputs);
116 }
117 
118 
119 
120 }  // namespace op
121 }  // namespace mxnet
122 
123 #endif  // MXNET_USE_MKLDNN == 1
124 #endif  // MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_DEQUANTIZE_INL_H_
125