/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file mkldnn_common.h
 * \brief Common header file for MKLDNN backend subgraph
 * \author Ciyong Chen
 */

#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_COMMON_H_
#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_COMMON_H_
#if MXNET_USE_MKLDNN == 1
#include <vector>

namespace mxnet {
namespace op {

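/*!
 * \brief Compute int8 quantization scales for a weight tensor.
 *
 * Scans the weight per output channel (dim 0) to find each channel's
 * min/max range. With \p weight_channelwise_scale the result holds one
 * scale per channel, each clamped so the quantized bias cannot overflow
 * int32; otherwise it holds {scale, total_min, total_max} for the whole
 * tensor.
 */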
template <typename DType>
static std::vector<float> GetWeightScales(const NDArray &weight, const NDArray *bias,
                                          const float data_scale, bool weight_channelwise_scale) {
  auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
  std::vector<float> weight_scales;
  const DType *weight_ptr = weight.data().dptr<DType>();
  const DType *bias_ptr = bias ? bias->data().dptr<DType>() : nullptr;
  const auto wshape = weight.shape();
  size_t channel = wshape[0];

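  // `offset` is the number of weight elements per output channel (dim 0).
  // First pass: record each channel's min/max, from which the scales derive.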
  size_t offset = wshape.ProdShape(1, wshape.ndim());
  std::vector<DType> weight_c_min(channel, MaxValue<DType>());
  std::vector<DType> weight_c_max(channel, MinValue<DType>());
  for (int c = 0; c < static_cast<int>(channel); ++c) {
    const DType *p1 = weight_ptr + c * offset;
    for (size_t k = 0; k < offset; ++k) {
      if (weight_c_min[c] > p1[k])
        weight_c_min[c] = p1[k];
      if (weight_c_max[c] < p1[k])
        weight_c_max[c] = p1[k];
    }
  }

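  // Second pass: turn the ranges into scales, either one per channel or a
  // single tensor-wide scale (plus the overall min/max for the caller).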
  if (weight_channelwise_scale) {
    weight_scales.resize(channel);
    #pragma omp parallel for num_threads(nthreads)
    for (int c = 0; c < static_cast<int>(channel); ++c) {
      float scale = GetQuantizeScale(mshadow::kInt8, weight_c_min[c], weight_c_max[c]);
      if (bias_ptr && bias_ptr[c]) {
        // Avoid overflow in the quantized bias.
        // TODO(zhennan): mkldnn has a bug handling INT_MAX in bias, so cap the
        // quantized bias at INT_MAX / 2.
        float scale_max =
            static_cast<float>(bias_ptr[c] > 0 ? MaxValue<int32_t>() : MinValue<int32_t>()) / 2 /
            bias_ptr[c] / data_scale;
        scale = Min(scale, scale_max);
      }
      weight_scales[c] = scale;
    }
  } else {
    DType total_min = weight_c_min[0];
    DType total_max = weight_c_max[0];
    for (size_t c = 0; c < channel; ++c) {
      if (total_min > weight_c_min[c]) total_min = weight_c_min[c];
      if (total_max < weight_c_max[c]) total_max = weight_c_max[c];
    }
    weight_scales.resize(3);
    weight_scales[0] = GetQuantizeScale(mshadow::kInt8, total_min, total_max);
    weight_scales[1] = total_min;
    weight_scales[2] = total_max;
  }
  return weight_scales;
}

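/*!
 * \brief Reorder weight (and optionally bias) into the MKLDNN memory layouts
 *        given by \p weight_md / \p bias_md, quantizing the values with the
 *        scales from GetWeightScales during the reorder.
 *
 * The reorders are queued on the current MKLDNNStream; pass submit = false
 * to batch them with later primitives and submit the stream yourself.
 */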
static void ConvertWeightBias2MKLDNN(NDArray *weight, NDArray *bias, bool has_bias,
                                     const mkldnn::memory::desc &weight_md,
                                     const mkldnn::memory::desc *bias_md,
                                     const int num_group, float data_scale,
                                     const std::vector<float> &weight_scales,
                                     const bool submit = true) {
  MKLDNNStream *stream = MKLDNNStream::Get();
  const auto new_weight = NDArray(weight_md);
  const auto conv_weights_memory = new_weight.GetMKLDNNData();
  mkldnn::primitive_attr weight_attr;
  if (!weight_scales.empty()) {
    // Mask 0 applies one common scale; mask 1 scales per output channel (dim 0).
    const int weight_mask = weight_scales.size() == 1 ? 0 : 1;
    weight_attr.set_output_scales(weight_mask, weight_scales);
  }
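  // Reorder the FP32 weights (regrouped if num_group > 1) into the target
  // layout; the attached output scales quantize the values on the fly.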
  auto default_weights_memory = GetWeights(*weight, num_group);
  if (default_weights_memory == nullptr) default_weights_memory = weight->GetMKLDNNData();
  const auto weight_reorder_pd =
      mkldnn::reorder::primitive_desc(*default_weights_memory, *conv_weights_memory, weight_attr);
  stream->RegisterPrimArgs(
      mkldnn::reorder(weight_reorder_pd),
      {{MKLDNN_ARG_FROM, *default_weights_memory}, {MKLDNN_ARG_TO, *conv_weights_memory}});
  NDArray new_bias;
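  // The int32 bias must be quantized with scale = data_scale * weight_scale,
  // the convention MKLDNN expects for the bias of a quantized convolution.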
  if (has_bias && data_scale) {
    std::vector<float> bias_scales(weight_scales.size());
    for (size_t c = 0; c < weight_scales.size(); ++c) {
      bias_scales[c] = weight_scales[c] * data_scale;
    }
    new_bias = NDArray(*bias_md);
    const auto conv_bias_memory = new_bias.GetMKLDNNData();
    const int bias_mask = bias_scales.size() == 1 ? 0 : 1;
    mkldnn::primitive_attr bias_attr;
    bias_attr.set_output_scales(bias_mask, bias_scales);
    auto bias_weights_memory = bias->GetMKLDNNData();
    const auto bias_reorder_pd =
        mkldnn::reorder::primitive_desc(*bias_weights_memory, *conv_bias_memory, bias_attr);
    stream->RegisterPrimArgs(
        mkldnn::reorder(bias_reorder_pd),
        {{MKLDNN_ARG_FROM, *bias_weights_memory}, {MKLDNN_ARG_TO, *conv_bias_memory}});
  }
  if (submit)
    stream->Submit();
  *weight = new_weight;
  if (has_bias && data_scale) *bias = new_bias;
}
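
// A minimal call-site sketch (hypothetical names such as `fwd_pd`, `weight`,
// `bias`, and `data_scale` are assumptions, not part of this header): derive
// the per-channel scales first, then reorder into the layouts the convolution
// primitive descriptor selected:
//
//   const auto scales = GetWeightScales<float>(
//       weight, has_bias ? &bias : nullptr, data_scale, /*channelwise*/ true);
//   const auto bias_md = fwd_pd.bias_desc();
//   ConvertWeightBias2MKLDNN(&weight, &bias, has_bias, fwd_pd.weights_desc(),
//                            has_bias ? &bias_md : nullptr, num_group,
//                            data_scale, scales);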

}  // namespace op
}  // namespace mxnet

#endif  // if MXNET_USE_MKLDNN == 1
#endif  // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_COMMON_H_