1 /*******************************************************************************
2 * Copyright 2017-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <atomic>
18
19 #include "common/c_types_map.hpp"
20 #include "common/dnnl_thread.hpp"
21 #include "common/math_utils.hpp"
22 #include "common/type_helpers.hpp"
23 #include "common/utils.hpp"
24
25 #include "cpu/binary_injector_utils.hpp"
26 #include "cpu/cpu_primitive.hpp"
27 #include "cpu/gemm/gemm.hpp"
28 #include "cpu/gemm_x8s8s32x_conv_zp_src_pad_comp.hpp"
29 #include "cpu/gemm_x8s8s32x_convolution.hpp"
30 #include "cpu/simple_q10n.hpp"
31
32 namespace dnnl {
33 namespace impl {
34 namespace cpu {
35
36 using namespace dnnl::impl::utils;
37 using namespace dnnl::impl::math;
38 using namespace dnnl::impl::memory_tracking::names;
39
mul_zp_src_comp_from_wei_by_zp_src(const int zp_comp_size,int32_t * zp_src_comp_scratch_dst,const int32_t * const zp_src_comp_from_wei,const int32_t zp_src)40 const int32_t *mul_zp_src_comp_from_wei_by_zp_src(const int zp_comp_size,
41 int32_t *zp_src_comp_scratch_dst,
42 const int32_t *const zp_src_comp_from_wei, const int32_t zp_src) {
43 static constexpr int cache_line_size
44 = platform::get_cache_line_size() / sizeof(int);
45 const auto res = std::div(zp_comp_size, cache_line_size);
46
47 if (res.quot) {
48 parallel_nd(res.quot, [&](const int shift_factor) {
49 const auto shift = shift_factor * cache_line_size;
50 const int32_t *__restrict const src = zp_src_comp_from_wei + shift;
51 int32_t *__restrict dst = zp_src_comp_scratch_dst + shift;
52
53 PRAGMA_OMP_SIMD()
54 for (int i = 0; i < cache_line_size; ++i) {
55 dst[i] = src[i] * zp_src;
56 }
57 });
58 }
59
60 if (res.rem) {
61 const auto shift = res.quot * cache_line_size;
62 const int32_t *__restrict const src = zp_src_comp_from_wei + shift;
63 int32_t *__restrict dst = zp_src_comp_scratch_dst + shift;
64
65 PRAGMA_OMP_SIMD()
66 for (int i = 0; i < res.rem; ++i) {
67 dst[i] = src[i] * zp_src;
68 }
69 }
70
71 return zp_src_comp_scratch_dst;
72 }
73
prepare_zp_params(const conv_gemm_conf_t & jcp,const memory_tracking::grantor_t & scratchpad,const int8_t * weights,const memory_desc_wrapper & weights_md,bool with_groups,const int32_t * zp_src,const int32_t * zp_dst)74 static zero_point_call_params_t prepare_zp_params(const conv_gemm_conf_t &jcp,
75 const memory_tracking::grantor_t &scratchpad, const int8_t *weights,
76 const memory_desc_wrapper &weights_md, bool with_groups,
77 const int32_t *zp_src, const int32_t *zp_dst) {
78
79 int32_t *zp_src_comp_pad = nullptr;
80 const int32_t *zp_src_comp = nullptr;
81
82 if (jcp.zp.src_exists) {
83 const int32_t *zp_src_comp_from_wei = get_src_zp_comp_from_wei(
84 weights, weights_md, jcp.signed_input, jcp.ngroups, jcp.oc);
85 int32_t *zp_src_comp_scratch
86 = scratchpad.get<int32_t>(key_conv_gemm_zp_src_comp);
87 static constexpr auto cache_line_size
88 = platform::get_cache_line_size() / sizeof(int);
89 const auto zp_comp_size = jcp.oc * jcp.ngroups;
90
91 if (jcp.zp.src_is_common) {
92 zp_src_comp = mul_zp_src_comp_from_wei_by_zp_src(zp_comp_size,
93 zp_src_comp_scratch, zp_src_comp_from_wei, *zp_src);
94 } else
95 zp_src_comp = zp_src_comp_from_wei;
96
97 if (jit_gemm_convolution_utils::padding_exists(jcp)) {
98 const auto shift = jcp.zp.src_is_common
99 ? utils::rnd_up(zp_comp_size, cache_line_size)
100 : 0;
101 zp_src_comp_pad = zp_src_comp_scratch + shift;
102 compute_zp_src_comp_pad(jcp, zp_src_comp_pad, zp_src, weights,
103 weights_md, with_groups);
104 }
105 }
106
107 return {zp_src, zp_dst, zp_src_comp, zp_src_comp_pad};
108 }
109
110 template <data_type_t src_type, data_type_t dst_type>
execute_forward(const exec_ctx_t & ctx) const111 status_t _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::execute_forward(
112 const exec_ctx_t &ctx) const {
113 const conv_gemm_conf_t &jcp = this->pd()->jcp_;
114 auto src_base = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
115 auto wei_base = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
116 auto bia_base = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
117 auto dst_base = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
118 DEFINE_ZERO_POINTS_BUFFER(zp_src, DNNL_ARG_SRC);
119 DEFINE_ZERO_POINTS_BUFFER(zp_dst, DNNL_ARG_DST);
120 const auto post_ops_binary_rhs_arg_vec
121 = binary_injector_utils::prepare_binary_args(
122 this->pd()->attr()->post_ops_, ctx);
123
124 auto scratchpad = ctx.get_scratchpad_grantor();
125
126 assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1));
127
128 const zero_point_call_params_t zp = prepare_zp_params(jcp, scratchpad,
129 wei_base, memory_desc_wrapper(pd()->weights_md(0)),
130 this->pd()->with_groups(), zp_src, zp_dst);
131
132 std::atomic<status_t> st(status::success);
133
134 parallel(jcp.nthr, [&](const int ithr, const int nthr) {
135 status_t st_thr = execute_forward_thr(ithr, nthr, src_base, wei_base,
136 bia_base, dst_base, zp, scratchpad,
137 post_ops_binary_rhs_arg_vec.data(), ctx);
138
139 if (st_thr != status::success) st = st_thr;
140 });
141
142 return st;
143 }
144
get_wei_comp(const int8_t * weights,const memory_desc_wrapper & weights_md)145 static const int32_t *get_wei_comp(
146 const int8_t *weights, const memory_desc_wrapper &weights_md) {
147 const size_t comp_off
148 = weights_md.size() - weights_md.additional_buffer_size();
149 return reinterpret_cast<const int32_t *>(&weights[comp_off]);
150 }
151
152 template <data_type_t src_type, data_type_t dst_type>
153 status_t
execute_forward_thr(const int ithr,const int nthr,const src_data_t * src_base,const wei_data_t * wei_base,const char * bia_base,dst_data_t * dst_base,const zero_point_call_params_t & zp,const memory_tracking::grantor_t & scratchpad,const void * post_ops_binary_rhs_arg_vec,const exec_ctx_t & ctx) const154 _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::execute_forward_thr(
155 const int ithr, const int nthr, const src_data_t *src_base,
156 const wei_data_t *wei_base, const char *bia_base, dst_data_t *dst_base,
157 const zero_point_call_params_t &zp,
158 const memory_tracking::grantor_t &scratchpad,
159 const void *post_ops_binary_rhs_arg_vec, const exec_ctx_t &ctx) const {
160
161 const conv_gemm_conf_t &jcp = this->pd()->jcp_;
162
163 const auto src_md = memory_desc_wrapper(pd()->src_md());
164 const size_t src_mb_stride = src_md.blk_off(1);
165 const size_t src_g_stride = src_md.blk_off(0, 1) * jcp.ic;
166
167 const auto wei_md = memory_desc_wrapper(pd()->weights_md(0));
168 const size_t wei_g_stride = pd()->with_groups() ? wei_md.blk_off(1) : 0;
169
170 const auto dst_md = memory_desc_wrapper(pd()->dst_md());
171 const size_t dst_mb_stride = dst_md.blk_off(1);
172 const size_t dst_g_stride = dst_md.blk_off(0, 1) * jcp.oc;
173
174 const float *scales = pd()->attr()->output_scales_.scales_;
175
176 const auto &post_ops = pd()->attr()->post_ops_;
177 const bool do_sum = post_ops.contain(primitive_kind::sum, 0);
178 const float sum_scale = do_sum ? post_ops.entry_[0].sum.scale : 0;
179
180 uint8_t *__restrict col = scratchpad.get<uint8_t>(key_conv_gemm_col)
181 + (ptrdiff_t)ithr * jcp.im2col_sz;
182 src_data_t *__restrict imtr = scratchpad.get<src_data_t>(key_conv_gemm_imtr)
183 + (ptrdiff_t)ithr * jcp.is * jcp.ic;
184 acc_data_t *__restrict acc
185 = scratchpad.get<acc_data_t>(key_conv_int_dat_in_acc_dt)
186 + (ptrdiff_t)ithr * jcp.oh_block * jcp.ow_block * jcp.oc;
187
188 const int32_t *_wei_comp
189 = jcp.signed_input ? get_wei_comp(wei_base, wei_md) : nullptr;
190
191 const bool should_apply_zp_src_comp_pad = jcp.zp.src_exists
192 && jit_gemm_convolution_utils::padding_exists(jcp);
193 const bool should_apply_zp_src_comp_pad_jit_pp
194 = should_apply_zp_src_comp_pad
195 && gemm_x8s8s32x_convolution_utils::mayiuse_jit_pp_kernel();
196 const bool should_apply_zp_src_comp_outside_pp
197 = should_apply_zp_src_comp_pad
198 && !gemm_x8s8s32x_convolution_utils::mayiuse_jit_pp_kernel();
199
200 int g {0}, n {0}, ohb {0}, owb {0};
201 size_t start = 0, end = 0;
202
203 const bool is_problem_3d = pd()->ndims() == 5;
204 assert(IMPLICATION(is_problem_3d,
205 jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow
206 && jcp.ic_block == jcp.ic));
207
208 const int nb_oh = div_up(jcp.oh, jcp.oh_block);
209 const int nb_ow = div_up(jcp.ow, jcp.ow_block);
210 const size_t work_amount = (size_t)jcp.ngroups * jcp.mb * nb_oh * nb_ow;
211 balance211(work_amount, nthr, ithr, start, end);
212 nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow);
213 const uint8_t shift = jcp.signed_input ? 128 : 0;
214 parallel_nd(jcp.im2col_sz, [&](ptrdiff_t i) { col[i] = shift; });
215
216 status_t st = status::success;
217
218 for (size_t iwork = start; iwork < end; ++iwork) {
219 const int oh = ohb * jcp.oh_block;
220 const int ow = owb * jcp.ow_block;
221 const src_data_t *__restrict src
222 = src_base + n * src_mb_stride + g * src_g_stride;
223 const wei_data_t *__restrict wei = wei_base + g * wei_g_stride;
224 const int32_t *__restrict wei_comp
225 = _wei_comp ? _wei_comp + g * jcp.oc : nullptr;
226 const int h_step = nstl::min(jcp.oh_block, jcp.oh - oh);
227 const int w_step = nstl::min(jcp.ow_block, jcp.ow - ow);
228 if (jcp.im2col_sz && is_problem_3d)
229 jit_gemm_convolution_utils::transpose_dt<src_data_t>(
230 jcp, src, imtr);
231
232 for (int od = 0; od < jcp.od; od++) {
233 dst_data_t *__restrict dst = dst_base + n * dst_mb_stride
234 + g * dst_g_stride
235 + ((od * jcp.oh + oh) * jcp.ow + ow) * jcp.dst_os_stride;
236 if (jcp.im2col_sz) {
237 if (is_problem_3d)
238 jit_gemm_convolution_utils::im2col_dt_3d<src_data_t,
239 uint8_t>(jcp, imtr, col, od);
240 else
241 jit_gemm_convolution_utils::im2col_dt<src_data_t, uint8_t>(
242 jcp, src, imtr, col, oh, h_step, ow, w_step);
243 }
244
245 const dim_t M = jcp.oc;
246 const dim_t K = jcp.ks * jcp.ic;
247 const dim_t N = h_step * w_step;
248 const dim_t LDA = M * jcp.ngroups;
249 const dim_t LDB = jcp.im2col_sz ? N : K * jcp.ngroups;
250 const char *BT = jcp.im2col_sz ? "T" : "N";
251 const int8_t off_a = 0;
252 const uint8_t off_b = 0;
253 const int32_t off_c = 0;
254 const float onef = 1.f, zerof = 0.f;
255 const src_data_t *__restrict src_od
256 = src + od * jcp.oh * jcp.ow * jcp.ngroups * jcp.ic;
257 st = gemm_s8x8s32("N", BT, jcp.signed_input ? "C" : "F", &M, &N, &K,
258 &onef, wei, &LDA, &off_a,
259 jcp.im2col_sz ? col : (uint8_t *)src_od, &LDB, &off_b,
260 &zerof, acc, &M, jcp.signed_input ? wei_comp : &off_c);
261
262 if (st != status::success) return st;
263
264 const auto wei_adj_scale
265 = (wei_md.extra().flags & memory_extra_flags::scale_adjust)
266 ? wei_md.extra().scale_adjust
267 : 1.f;
268
269 if (should_apply_zp_src_comp_outside_pp)
270 apply_zp_src_comp_pad(jcp, g, od, oh, ow, h_step, w_step, acc,
271 zp.src_pad_comp);
272
273 const single_gemm_conv_chunk_desc_t chunk_desc
274 = should_apply_zp_src_comp_pad_jit_pp
275 ? single_gemm_conv_chunk_desc_t {od, 1, oh, h_step, ow,
276 w_step}
277 : single_gemm_conv_chunk_desc_t {};
278
279 parallel(0, [&](int ithr, int nthr) {
280 size_t _start, _end;
281 balance211((size_t)N * jcp.oc, nthr, ithr, _start, _end);
282
283 (*pp_ker_)(dst, acc, bia_base, scales, sum_scale,
284 1.f / wei_adj_scale, g, _start, _end, zp,
285 post_ops_binary_rhs_arg_vec, dst_base, ctx,
286 *pd()->dst_md(), chunk_desc);
287 });
288 }
289 nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow);
290 }
291
292 return st;
293 }
294
295 template <data_type_t dst_type>
execute_backward_data(const exec_ctx_t & ctx) const296 status_t _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>::execute_backward_data(
297 const exec_ctx_t &ctx) const {
298 auto diff_dst_base = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
299 auto wei_base = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
300 auto bia_base = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
301 auto diff_src_base = CTX_OUT_MEM(diff_src_data_t *, DNNL_ARG_DIFF_SRC);
302
303 auto scratchpad = ctx.get_scratchpad_grantor();
304
305 const conv_gemm_conf_t &jcp = this->pd()->jcp_;
306
307 std::atomic<status_t> st(status::success);
308
309 parallel(jcp.nthr, [&](const int ithr, const int nthr) {
310 status_t st_thr = execute_backward_data_thr(ithr, nthr, diff_dst_base,
311 wei_base, bia_base, diff_src_base, scratchpad);
312
313 if (st_thr != status::success) st = st_thr;
314 });
315
316 return st;
317 }
318
319 template <data_type_t dst_type>
320 status_t
execute_backward_data_thr(const int ithr,const int nthr,const diff_dst_data_t * diff_dst_base,const wei_data_t * wei_base,const char * bia_base,diff_src_data_t * diff_src_base,const memory_tracking::grantor_t & scratchpad) const321 _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>::execute_backward_data_thr(
322 const int ithr, const int nthr, const diff_dst_data_t *diff_dst_base,
323 const wei_data_t *wei_base, const char *bia_base,
324 diff_src_data_t *diff_src_base,
325 const memory_tracking::grantor_t &scratchpad) const {
326 const conv_gemm_conf_t &jcp = this->pd()->jcp_;
327
328 const auto diff_dst_md = memory_desc_wrapper(pd()->diff_dst_md());
329 const size_t diff_dst_mb_stride = diff_dst_md.blk_off(1);
330 const size_t diff_dst_g_stride = diff_dst_md.blk_off(0, 1) * jcp.oc;
331
332 const auto wei_md = memory_desc_wrapper(pd()->weights_md(0));
333 const size_t wei_g_stride = pd()->with_groups() ? wei_md.blk_off(1) : 0;
334
335 const auto diff_src_md = memory_desc_wrapper(pd()->diff_src_md());
336 const size_t diff_src_mb_stride = diff_src_md.blk_off(1);
337 const size_t diff_src_g_stride = diff_src_md.blk_off(0, 1) * jcp.ic;
338 const size_t diff_src_os_stride
339 = diff_src_md.blocking_desc().strides[pd()->ndims() - 1];
340
341 /* scale_idx_mult = 1 for per_oc scales and 0, otherwise */
342 const int scale_idx_mult = pd()->attr()->output_scales_.mask_ == (1 << 1);
343 const float *__restrict scales = pd()->attr()->output_scales_.scales_;
344 const size_t work_amount = jcp.ngroups * jcp.mb;
345
346 acc_data_t *__restrict col = scratchpad.get<acc_data_t>(key_conv_gemm_col)
347 + (ptrdiff_t)ithr * jcp.im2col_sz;
348 acc_data_t *__restrict acc
349 = scratchpad.get<acc_data_t>(key_conv_int_dat_in_acc_dt)
350 + (ptrdiff_t)ithr * jcp.is * jcp.id * jcp.ic;
351
352 int n {0}, g {0};
353 size_t start = 0, end = 0;
354
355 balance211(work_amount, nthr, ithr, start, end);
356 nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups);
357
358 for (size_t iwork = start; iwork < end; ++iwork) {
359 const diff_dst_data_t *__restrict diff_dst = diff_dst_base
360 + n * diff_dst_mb_stride + g * diff_dst_g_stride;
361 const wei_data_t *__restrict wei = wei_base + g * wei_g_stride;
362 diff_src_data_t *__restrict diff_src = diff_src_base
363 + n * diff_src_mb_stride + g * diff_src_g_stride;
364
365 const dim_t M = jcp.ks * jcp.ic;
366 const dim_t N = jcp.os * jcp.od;
367 const dim_t K = jcp.oc;
368 const int8_t off_a = 0;
369 const diff_dst_data_t off_b = 0;
370 const int32_t off_c = 0;
371 const float onef = 1.0, zerof = 0.0;
372 const dim_t LD = K * jcp.ngroups;
373
374 status_t st = gemm_s8x8s32("T", "N", "F", &M, &N, &K, &onef, wei, &LD,
375 &off_a, diff_dst, &LD, &off_b, &zerof,
376 jcp.im2col_sz ? col : acc, &M, &off_c);
377
378 if (st != status::success) return st;
379
380 if (jcp.im2col_sz)
381 jit_gemm_convolution_utils::col2im_dt<int32_t>(jcp, col, acc);
382
383 // TODO: the code below is not tested and broken anyway.
384 parallel_nd(jcp.is * jcp.id, [&](int is) {
385 diff_src_data_t *__restrict diff_src_loc
386 = diff_src + is * diff_src_os_stride;
387 const acc_data_t *__restrict acc_loc = acc + is * jcp.ic;
388 const float *__restrict scales_loc
389 = scales + g * jcp.ic * scale_idx_mult;
390 for (int ic = 0; ic < jcp.ic; ic++) {
391 acc_data_t d = acc_loc[ic];
392 if (jcp.with_bias)
393 d += get_bias(bia_base, g * jcp.ic + ic,
394 pd()->desc()->bias_desc.data_type);
395 d *= scales_loc[ic * scale_idx_mult];
396 diff_src_loc[ic] = qz_a1b0<acc_data_t, diff_src_data_t>()(d);
397 }
398 });
399 nd_iterator_step(n, jcp.mb, g, jcp.ngroups);
400 }
401
402 return status::success;
403 }
404
using namespace data_type;

// Explicit instantiations of the forward implementation for
// src in {u8, s8} crossed with dst in {f32, s32, s8, u8}.
template struct _gemm_x8s8s32x_convolution_fwd_t<u8, f32>;
template struct _gemm_x8s8s32x_convolution_fwd_t<u8, s32>;
template struct _gemm_x8s8s32x_convolution_fwd_t<u8, s8>;
template struct _gemm_x8s8s32x_convolution_fwd_t<u8, u8>;

template struct _gemm_x8s8s32x_convolution_fwd_t<s8, f32>;
template struct _gemm_x8s8s32x_convolution_fwd_t<s8, s32>;
template struct _gemm_x8s8s32x_convolution_fwd_t<s8, s8>;
template struct _gemm_x8s8s32x_convolution_fwd_t<s8, u8>;

// Explicit instantiations of the backward-data implementation for
// dst_type in {f32, s32, s8, u8}. Presumably dst_type is the diff_src data
// type here; verify against the class declaration.
template struct _gemm_u8s8s32x_convolution_bwd_data_t<f32>;
template struct _gemm_u8s8s32x_convolution_bwd_data_t<s32>;
template struct _gemm_u8s8s32x_convolution_bwd_data_t<s8>;
template struct _gemm_u8s8s32x_convolution_bwd_data_t<u8>;
421 } // namespace cpu
422 } // namespace impl
423 } // namespace dnnl
424