/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file mkldnn_common.h
 * \brief Common header file for the MKLDNN backend subgraph
 * \author Ciyong Chen
 */

#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_COMMON_H_
#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_COMMON_H_
#if MXNET_USE_MKLDNN == 1
#include <vector>

namespace mxnet {
namespace op {

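/*!
 * \brief Compute the int8 quantization scale(s) for a weight tensor.
 *
 * Scans each output channel (axis 0) of the weight for its min/max range.
 * With \p weight_channelwise_scale set, returns one scale per channel, each
 * capped so that the corresponding quantized bias entry stays within
 * INT32_MAX / 2 (see the TODO below). Otherwise returns
 * {scale, total_min, total_max} computed over the whole tensor.
 */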
template <typename DType>
static std::vector<float> GetWeightScales(const NDArray &weight, const NDArray *bias,
                                          const float data_scale, bool weight_channelwise_scale) {
  auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
  std::vector<float> weight_scales;
  const DType *weight_ptr = weight.data().dptr<DType>();
  const DType *bias_ptr = bias ? bias->data().dptr<DType>() : nullptr;
  const auto wshape = weight.shape();
  size_t channel = wshape[0];

  // Collect the per-channel (axis 0) min/max of the weight tensor.
  size_t offset = wshape.ProdShape(1, wshape.ndim());
  std::vector<DType> weight_c_min(channel, MaxValue<DType>());
  std::vector<DType> weight_c_max(channel, MinValue<DType>());
  for (int c = 0; c < static_cast<int>(channel); ++c) {
    const DType *p1 = weight_ptr + c * offset;
    for (size_t k = 0; k < offset; ++k) {
      if (weight_c_min[c] > p1[k])
        weight_c_min[c] = p1[k];
      if (weight_c_max[c] < p1[k])
        weight_c_max[c] = p1[k];
    }
  }

  if (weight_channelwise_scale) {
    weight_scales.resize(channel);
#pragma omp parallel for num_threads(nthreads)
    for (int c = 0; c < static_cast<int>(channel); ++c) {
      float scale = GetQuantizeScale(mshadow::kInt8, weight_c_min[c], weight_c_max[c]);
      if (bias_ptr && bias_ptr[c]) {
        // Cap the scale so the quantized bias (bias * data_scale * scale) cannot overflow.
        // TODO(zhennan): mkldnn has a bug handling INT_MAX in bias, so set the maximum value
        // of bias to INT_MAX / 2.
        float scale_max =
            static_cast<float>(bias_ptr[c] > 0 ? MaxValue<int32_t>() : MinValue<int32_t>()) / 2 /
            bias_ptr[c] / data_scale;
        scale = Min(scale, scale_max);
      }
      weight_scales[c] = scale;
    }
  } else {
    // Tensor-wise quantization: reduce the per-channel ranges to a single
    // range and return {scale, total_min, total_max}.
    DType total_min = weight_c_min[0];
    DType total_max = weight_c_max[0];
    for (size_t c = 0; c < channel; ++c) {
      if (total_min > weight_c_min[c]) total_min = weight_c_min[c];
      if (total_max < weight_c_max[c]) total_max = weight_c_max[c];
    }
    weight_scales.resize(3);
    weight_scales[0] = GetQuantizeScale(mshadow::kInt8, total_min, total_max);
    weight_scales[1] = total_min;
    weight_scales[2] = total_max;
  }
  return weight_scales;
}

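/*!
 * \brief Reorder (and quantize) weight and optionally bias into the MKLDNN
 *        layouts given by \p weight_md and \p bias_md.
 *
 * Quantization happens inside the reorder primitives via output scales: the
 * weight uses \p weight_scales directly, while the bias uses
 * weight_scale * data_scale per entry. With \p submit (the default), the
 * registered reorders are executed immediately; otherwise the caller is
 * responsible for submitting the MKLDNNStream.
 */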
static void ConvertWeightBias2MKLDNN(NDArray *weight, NDArray *bias, bool has_bias,
                                     const mkldnn::memory::desc &weight_md,
                                     const mkldnn::memory::desc *bias_md,
                                     const int num_group, float data_scale,
                                     const std::vector<float> &weight_scales,
                                     const bool submit = true) {
  MKLDNNStream *stream = MKLDNNStream::Get();
  const auto new_weight = NDArray(weight_md);
  const auto conv_weights_memory = new_weight.GetMKLDNNData();
  mkldnn::primitive_attr weight_attr;
  if (!weight_scales.empty()) {
    // Mask 0 applies one common scale; mask 1 applies one scale per dim-0 (output channel).
    const int weight_mask = weight_scales.size() == 1 ? 0 : 1;
    weight_attr.set_output_scales(weight_mask, weight_scales);
  }
  auto default_weights_memory = GetWeights(*weight, num_group);
  if (default_weights_memory == nullptr) default_weights_memory = weight->GetMKLDNNData();
  const auto weight_reorder_pd =
      mkldnn::reorder::primitive_desc(*default_weights_memory, *conv_weights_memory, weight_attr);
  stream->RegisterPrimArgs(
      mkldnn::reorder(weight_reorder_pd),
      {{MKLDNN_ARG_FROM, *default_weights_memory}, {MKLDNN_ARG_TO, *conv_weights_memory}});

  NDArray new_bias;
  if (has_bias && data_scale) {
    // Each bias entry is quantized with scale = weight_scale * data_scale.
    std::vector<float> bias_scales(weight_scales.size());
    for (size_t c = 0; c < weight_scales.size(); ++c) {
      bias_scales[c] = weight_scales[c] * data_scale;
    }
    new_bias = NDArray(*bias_md);
    const auto conv_bias_memory = new_bias.GetMKLDNNData();
    const int bias_mask = bias_scales.size() == 1 ? 0 : 1;
    mkldnn::primitive_attr bias_attr;
    bias_attr.set_output_scales(bias_mask, bias_scales);
    auto bias_weights_memory = bias->GetMKLDNNData();
    const auto bias_reorder_pd =
        mkldnn::reorder::primitive_desc(*bias_weights_memory, *conv_bias_memory, bias_attr);
    stream->RegisterPrimArgs(
        mkldnn::reorder(bias_reorder_pd),
        {{MKLDNN_ARG_FROM, *bias_weights_memory}, {MKLDNN_ARG_TO, *conv_bias_memory}});
  }
  if (submit)
    stream->Submit();
  *weight = new_weight;
  if (has_bias && data_scale) *bias = new_bias;
}
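
// A minimal usage sketch (hypothetical caller; names such as fwd_pd, data_min,
// data_max, and bias_md are illustrative, not part of this header): a
// quantized convolution would typically derive the data scale, compute the
// weight scales, and then reorder the parameters into the layouts expected by
// its primitive descriptor:
//
//   float data_scale = GetQuantizeScale(mshadow::kUint8, data_min, data_max);
//   auto scales = GetWeightScales<float>(weight, has_bias ? &bias : nullptr,
//                                        data_scale, /*weight_channelwise_scale=*/true);
//   auto bias_md = fwd_pd.bias_desc();
//   ConvertWeightBias2MKLDNN(&weight, &bias, has_bias, fwd_pd.weights_desc(),
//                            has_bias ? &bias_md : nullptr, num_group,
//                            data_scale, scales);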

}  // namespace op
}  // namespace mxnet

#endif  // if MXNET_USE_MKLDNN == 1
#endif  // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_COMMON_H_