/*******************************************************************************
* Copyright 2017-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <atomic>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/math_utils.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/binary_injector_utils.hpp"
#include "cpu/cpu_primitive.hpp"
#include "cpu/gemm/gemm.hpp"
#include "cpu/gemm_x8s8s32x_conv_zp_src_pad_comp.hpp"
#include "cpu/gemm_x8s8s32x_convolution.hpp"
#include "cpu/simple_q10n.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

using namespace dnnl::impl::utils;
using namespace dnnl::impl::math;
using namespace dnnl::impl::memory_tracking::names;

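// Multiplies the per-output-channel src zero-point compensation that was
// precomputed from the weights by a common src zero-point value. The work is
// split into cache-line-sized chunks: full chunks are processed in parallel,
// and any remainder is handled in a final serial pass.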
const int32_t *mul_zp_src_comp_from_wei_by_zp_src(const int zp_comp_size,
        int32_t *zp_src_comp_scratch_dst,
        const int32_t *const zp_src_comp_from_wei, const int32_t zp_src) {
    static constexpr int cache_line_size
            = platform::get_cache_line_size() / sizeof(int);
    const auto res = std::div(zp_comp_size, cache_line_size);

    if (res.quot) {
        parallel_nd(res.quot, [&](const int shift_factor) {
            const auto shift = shift_factor * cache_line_size;
            const int32_t *__restrict const src = zp_src_comp_from_wei + shift;
            int32_t *__restrict dst = zp_src_comp_scratch_dst + shift;

            PRAGMA_OMP_SIMD()
            for (int i = 0; i < cache_line_size; ++i) {
                dst[i] = src[i] * zp_src;
            }
        });
    }

    if (res.rem) {
        const auto shift = res.quot * cache_line_size;
        const int32_t *__restrict const src = zp_src_comp_from_wei + shift;
        int32_t *__restrict dst = zp_src_comp_scratch_dst + shift;

        PRAGMA_OMP_SIMD()
        for (int i = 0; i < res.rem; ++i) {
            dst[i] = src[i] * zp_src;
        }
    }

    return zp_src_comp_scratch_dst;
}

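// Collects the zero-point arguments for the current execution: the raw
// src/dst zero points, the src compensation (scaled by the common zero point
// when one is used), and, if spatial padding exists, the compensation for the
// padded area computed into the scratchpad after the scaled values.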
static zero_point_call_params_t prepare_zp_params(const conv_gemm_conf_t &jcp,
        const memory_tracking::grantor_t &scratchpad, const int8_t *weights,
        const memory_desc_wrapper &weights_md, bool with_groups,
        const int32_t *zp_src, const int32_t *zp_dst) {

    int32_t *zp_src_comp_pad = nullptr;
    const int32_t *zp_src_comp = nullptr;

    if (jcp.zp.src_exists) {
        const int32_t *zp_src_comp_from_wei = get_src_zp_comp_from_wei(
                weights, weights_md, jcp.signed_input, jcp.ngroups, jcp.oc);
        int32_t *zp_src_comp_scratch
                = scratchpad.get<int32_t>(key_conv_gemm_zp_src_comp);
        static constexpr auto cache_line_size
                = platform::get_cache_line_size() / sizeof(int);
        const auto zp_comp_size = jcp.oc * jcp.ngroups;

        if (jcp.zp.src_is_common) {
            zp_src_comp = mul_zp_src_comp_from_wei_by_zp_src(zp_comp_size,
                    zp_src_comp_scratch, zp_src_comp_from_wei, *zp_src);
        } else
            zp_src_comp = zp_src_comp_from_wei;

        if (jit_gemm_convolution_utils::padding_exists(jcp)) {
            const auto shift = jcp.zp.src_is_common
                    ? utils::rnd_up(zp_comp_size, cache_line_size)
                    : 0;
            zp_src_comp_pad = zp_src_comp_scratch + shift;
            compute_zp_src_comp_pad(jcp, zp_src_comp_pad, zp_src, weights,
                    weights_md, with_groups);
        }
    }

    return {zp_src, zp_dst, zp_src_comp, zp_src_comp_pad};
}

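// Forward int8 GEMM convolution driver: fetches the arguments from the
// execution context, prepares the zero-point parameters, and runs the
// per-thread implementation, propagating a failing status if any thread
// reports one.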
template <data_type_t src_type, data_type_t dst_type>
status_t _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::execute_forward(
        const exec_ctx_t &ctx) const {
    const conv_gemm_conf_t &jcp = this->pd()->jcp_;
    auto src_base = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto wei_base = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto bia_base = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst_base = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
    DEFINE_ZERO_POINTS_BUFFER(zp_src, DNNL_ARG_SRC);
    DEFINE_ZERO_POINTS_BUFFER(zp_dst, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector_utils::prepare_binary_args(
                    this->pd()->attr()->post_ops_, ctx);

    auto scratchpad = ctx.get_scratchpad_grantor();

    assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1));

    const zero_point_call_params_t zp = prepare_zp_params(jcp, scratchpad,
            wei_base, memory_desc_wrapper(pd()->weights_md(0)),
            this->pd()->with_groups(), zp_src, zp_dst);

    std::atomic<status_t> st(status::success);

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        status_t st_thr = execute_forward_thr(ithr, nthr, src_base, wei_base,
                bia_base, dst_base, zp, scratchpad,
                post_ops_binary_rhs_arg_vec.data(), ctx);

        if (st_thr != status::success) st = st_thr;
    });

    return st;
}

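// The weights compensation (used when the src data type is signed) is stored
// as an extra int32 buffer appended to the weights memory; return a pointer
// to it based on the descriptor sizes.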
static const int32_t *get_wei_comp(
        const int8_t *weights, const memory_desc_wrapper &weights_md) {
    const size_t comp_off
            = weights_md.size() - weights_md.additional_buffer_size();
    return reinterpret_cast<const int32_t *>(&weights[comp_off]);
}

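// Per-thread forward path. Each thread gets a balanced share of
// (mb, group, oh-block, ow-block) work items; for every item it optionally
// runs im2col, performs an int8 GEMM into the int32 accumulator, applies the
// src zero-point padding compensation when needed, and finally calls the
// post-ops kernel that quantizes the accumulator into the destination.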
template <data_type_t src_type, data_type_t dst_type>
status_t
_gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::execute_forward_thr(
        const int ithr, const int nthr, const src_data_t *src_base,
        const wei_data_t *wei_base, const char *bia_base, dst_data_t *dst_base,
        const zero_point_call_params_t &zp,
        const memory_tracking::grantor_t &scratchpad,
        const void *post_ops_binary_rhs_arg_vec, const exec_ctx_t &ctx) const {

    const conv_gemm_conf_t &jcp = this->pd()->jcp_;

    const auto src_md = memory_desc_wrapper(pd()->src_md());
    const size_t src_mb_stride = src_md.blk_off(1);
    const size_t src_g_stride = src_md.blk_off(0, 1) * jcp.ic;

    const auto wei_md = memory_desc_wrapper(pd()->weights_md(0));
    const size_t wei_g_stride = pd()->with_groups() ? wei_md.blk_off(1) : 0;

    const auto dst_md = memory_desc_wrapper(pd()->dst_md());
    const size_t dst_mb_stride = dst_md.blk_off(1);
    const size_t dst_g_stride = dst_md.blk_off(0, 1) * jcp.oc;

    const float *scales = pd()->attr()->output_scales_.scales_;

    const auto &post_ops = pd()->attr()->post_ops_;
    const bool do_sum = post_ops.contain(primitive_kind::sum, 0);
    const float sum_scale = do_sum ? post_ops.entry_[0].sum.scale : 0;

    uint8_t *__restrict col = scratchpad.get<uint8_t>(key_conv_gemm_col)
            + (ptrdiff_t)ithr * jcp.im2col_sz;
    src_data_t *__restrict imtr = scratchpad.get<src_data_t>(key_conv_gemm_imtr)
            + (ptrdiff_t)ithr * jcp.is * jcp.ic;
    acc_data_t *__restrict acc
            = scratchpad.get<acc_data_t>(key_conv_int_dat_in_acc_dt)
            + (ptrdiff_t)ithr * jcp.oh_block * jcp.ow_block * jcp.oc;

    const int32_t *_wei_comp
            = jcp.signed_input ? get_wei_comp(wei_base, wei_md) : nullptr;

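    // The src zero-point compensation for padded pixels is either fused into
    // the jit post-ops kernel (when one is available) or applied directly to
    // the accumulator before the post-ops are run.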
    const bool should_apply_zp_src_comp_pad = jcp.zp.src_exists
            && jit_gemm_convolution_utils::padding_exists(jcp);
    const bool should_apply_zp_src_comp_pad_jit_pp
            = should_apply_zp_src_comp_pad
            && gemm_x8s8s32x_convolution_utils::mayiuse_jit_pp_kernel();
    const bool should_apply_zp_src_comp_outside_pp
            = should_apply_zp_src_comp_pad
            && !gemm_x8s8s32x_convolution_utils::mayiuse_jit_pp_kernel();

    int g {0}, n {0}, ohb {0}, owb {0};
    size_t start = 0, end = 0;

    const bool is_problem_3d = pd()->ndims() == 5;
    assert(IMPLICATION(is_problem_3d,
            jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow
                    && jcp.ic_block == jcp.ic));

    const int nb_oh = div_up(jcp.oh, jcp.oh_block);
    const int nb_ow = div_up(jcp.ow, jcp.ow_block);
    const size_t work_amount = (size_t)jcp.ngroups * jcp.mb * nb_oh * nb_ow;
    balance211(work_amount, nthr, ithr, start, end);
    nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow);
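    // When the src data type is signed, the data is treated as u8 shifted by
    // 128 and compensated through the weights compensation ("C" offset in the
    // GEMM call below); pre-fill the im2col buffer with the same shift so that
    // padded entries contribute a matching offset.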
    const uint8_t shift = jcp.signed_input ? 128 : 0;
    parallel_nd(jcp.im2col_sz, [&](ptrdiff_t i) { col[i] = shift; });

    status_t st = status::success;

    for (size_t iwork = start; iwork < end; ++iwork) {
        const int oh = ohb * jcp.oh_block;
        const int ow = owb * jcp.ow_block;
        const src_data_t *__restrict src
                = src_base + n * src_mb_stride + g * src_g_stride;
        const wei_data_t *__restrict wei = wei_base + g * wei_g_stride;
        const int32_t *__restrict wei_comp
                = _wei_comp ? _wei_comp + g * jcp.oc : nullptr;
        const int h_step = nstl::min(jcp.oh_block, jcp.oh - oh);
        const int w_step = nstl::min(jcp.ow_block, jcp.ow - ow);
        if (jcp.im2col_sz && is_problem_3d)
            jit_gemm_convolution_utils::transpose_dt<src_data_t>(
                    jcp, src, imtr);

        for (int od = 0; od < jcp.od; od++) {
            dst_data_t *__restrict dst = dst_base + n * dst_mb_stride
                    + g * dst_g_stride
                    + ((od * jcp.oh + oh) * jcp.ow + ow) * jcp.dst_os_stride;
            if (jcp.im2col_sz) {
                if (is_problem_3d)
                    jit_gemm_convolution_utils::im2col_dt_3d<src_data_t,
                            uint8_t>(jcp, imtr, col, od);
                else
                    jit_gemm_convolution_utils::im2col_dt<src_data_t, uint8_t>(
                            jcp, src, imtr, col, oh, h_step, ow, w_step);
            }

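            // Single GEMM over the current spatial block: M = output channels
            // of the group, N = h_step * w_step output points, K = kernel
            // footprint times input channels. The B matrix comes from the
            // im2col buffer (transposed) or, when im2col is not needed,
            // directly from the source at this output depth.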
            const dim_t M = jcp.oc;
            const dim_t K = jcp.ks * jcp.ic;
            const dim_t N = h_step * w_step;
            const dim_t LDA = M * jcp.ngroups;
            const dim_t LDB = jcp.im2col_sz ? N : K * jcp.ngroups;
            const char *BT = jcp.im2col_sz ? "T" : "N";
            const int8_t off_a = 0;
            const uint8_t off_b = 0;
            const int32_t off_c = 0;
            const float onef = 1.f, zerof = 0.f;
            const src_data_t *__restrict src_od
                    = src + od * jcp.oh * jcp.ow * jcp.ngroups * jcp.ic;
            st = gemm_s8x8s32("N", BT, jcp.signed_input ? "C" : "F", &M, &N, &K,
                    &onef, wei, &LDA, &off_a,
                    jcp.im2col_sz ? col : (uint8_t *)src_od, &LDB, &off_b,
                    &zerof, acc, &M, jcp.signed_input ? wei_comp : &off_c);

            if (st != status::success) return st;

            const auto wei_adj_scale
                    = (wei_md.extra().flags & memory_extra_flags::scale_adjust)
                    ? wei_md.extra().scale_adjust
                    : 1.f;

            if (should_apply_zp_src_comp_outside_pp)
                apply_zp_src_comp_pad(jcp, g, od, oh, ow, h_step, w_step, acc,
                        zp.src_pad_comp);

            const single_gemm_conv_chunk_desc_t chunk_desc
                    = should_apply_zp_src_comp_pad_jit_pp
                    ? single_gemm_conv_chunk_desc_t {od, 1, oh, h_step, ow,
                            w_step}
                    : single_gemm_conv_chunk_desc_t {};

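            // Post-ops / quantization step: the pp kernel converts the int32
            // accumulator into the destination type, applying bias, output
            // scales, sum and binary post-ops, and zero points; the range is
            // split across all available threads.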
            parallel(0, [&](int ithr, int nthr) {
                size_t _start, _end;
                balance211((size_t)N * jcp.oc, nthr, ithr, _start, _end);

                (*pp_ker_)(dst, acc, bia_base, scales, sum_scale,
                        1.f / wei_adj_scale, g, _start, _end, zp,
                        post_ops_binary_rhs_arg_vec, dst_base, ctx,
                        *pd()->dst_md(), chunk_desc);
            });
        }
        nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow);
    }

    return st;
}

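// Backward-data int8 GEMM convolution driver: mirrors the forward driver but
// computes diff_src from diff_dst and the weights.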
template <data_type_t dst_type>
status_t _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>::execute_backward_data(
        const exec_ctx_t &ctx) const {
    auto diff_dst_base = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
    auto wei_base = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto bia_base = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto diff_src_base = CTX_OUT_MEM(diff_src_data_t *, DNNL_ARG_DIFF_SRC);

    auto scratchpad = ctx.get_scratchpad_grantor();

    const conv_gemm_conf_t &jcp = this->pd()->jcp_;

    std::atomic<status_t> st(status::success);

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        status_t st_thr = execute_backward_data_thr(ithr, nthr, diff_dst_base,
                wei_base, bia_base, diff_src_base, scratchpad);

        if (st_thr != status::success) st = st_thr;
    });

    return st;
}

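// Per-thread backward-data path. For each (mb, group) pair: a transposed int8
// GEMM multiplies the weights by diff_dst, col2im scatters the result back to
// image layout when im2col is used, and a final pass adds the bias, applies
// the output scales, and converts to the diff_src data type.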
template <data_type_t dst_type>
status_t
_gemm_u8s8s32x_convolution_bwd_data_t<dst_type>::execute_backward_data_thr(
        const int ithr, const int nthr, const diff_dst_data_t *diff_dst_base,
        const wei_data_t *wei_base, const char *bia_base,
        diff_src_data_t *diff_src_base,
        const memory_tracking::grantor_t &scratchpad) const {
    const conv_gemm_conf_t &jcp = this->pd()->jcp_;

    const auto diff_dst_md = memory_desc_wrapper(pd()->diff_dst_md());
    const size_t diff_dst_mb_stride = diff_dst_md.blk_off(1);
    const size_t diff_dst_g_stride = diff_dst_md.blk_off(0, 1) * jcp.oc;

    const auto wei_md = memory_desc_wrapper(pd()->weights_md(0));
    const size_t wei_g_stride = pd()->with_groups() ? wei_md.blk_off(1) : 0;

    const auto diff_src_md = memory_desc_wrapper(pd()->diff_src_md());
    const size_t diff_src_mb_stride = diff_src_md.blk_off(1);
    const size_t diff_src_g_stride = diff_src_md.blk_off(0, 1) * jcp.ic;
    const size_t diff_src_os_stride
            = diff_src_md.blocking_desc().strides[pd()->ndims() - 1];

    /* scale_idx_mult = 1 for per_oc scales and 0 otherwise */
    const int scale_idx_mult = pd()->attr()->output_scales_.mask_ == (1 << 1);
    const float *__restrict scales = pd()->attr()->output_scales_.scales_;
    const size_t work_amount = jcp.ngroups * jcp.mb;

    acc_data_t *__restrict col = scratchpad.get<acc_data_t>(key_conv_gemm_col)
            + (ptrdiff_t)ithr * jcp.im2col_sz;
    acc_data_t *__restrict acc
            = scratchpad.get<acc_data_t>(key_conv_int_dat_in_acc_dt)
            + (ptrdiff_t)ithr * jcp.is * jcp.id * jcp.ic;

    int n {0}, g {0};
    size_t start = 0, end = 0;

    balance211(work_amount, nthr, ithr, start, end);
    nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups);

    for (size_t iwork = start; iwork < end; ++iwork) {
        const diff_dst_data_t *__restrict diff_dst = diff_dst_base
                + n * diff_dst_mb_stride + g * diff_dst_g_stride;
        const wei_data_t *__restrict wei = wei_base + g * wei_g_stride;
        diff_src_data_t *__restrict diff_src = diff_src_base
                + n * diff_src_mb_stride + g * diff_src_g_stride;

        const dim_t M = jcp.ks * jcp.ic;
        const dim_t N = jcp.os * jcp.od;
        const dim_t K = jcp.oc;
        const int8_t off_a = 0;
        const diff_dst_data_t off_b = 0;
        const int32_t off_c = 0;
        const float onef = 1.0, zerof = 0.0;
        const dim_t LD = K * jcp.ngroups;

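        // Backward GEMM: M = kernel footprint times input channels, N = full
        // output spatial size, K = output channels; the weights are used
        // transposed, and the result lands in the col buffer when im2col is
        // required, otherwise directly in the accumulator.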
        status_t st = gemm_s8x8s32("T", "N", "F", &M, &N, &K, &onef, wei, &LD,
                &off_a, diff_dst, &LD, &off_b, &zerof,
                jcp.im2col_sz ? col : acc, &M, &off_c);

        if (st != status::success) return st;

        if (jcp.im2col_sz)
            jit_gemm_convolution_utils::col2im_dt<int32_t>(jcp, col, acc);

        // TODO: the code below is not tested and broken anyway.
        parallel_nd(jcp.is * jcp.id, [&](int is) {
            diff_src_data_t *__restrict diff_src_loc
                    = diff_src + is * diff_src_os_stride;
            const acc_data_t *__restrict acc_loc = acc + is * jcp.ic;
            const float *__restrict scales_loc
                    = scales + g * jcp.ic * scale_idx_mult;
            for (int ic = 0; ic < jcp.ic; ic++) {
                acc_data_t d = acc_loc[ic];
                if (jcp.with_bias)
                    d += get_bias(bia_base, g * jcp.ic + ic,
                            pd()->desc()->bias_desc.data_type);
                d *= scales_loc[ic * scale_idx_mult];
                diff_src_loc[ic] = qz_a1b0<acc_data_t, diff_src_data_t>()(d);
            }
        });
        nd_iterator_step(n, jcp.mb, g, jcp.ngroups);
    }

    return status::success;
}

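// Explicit instantiations for the supported src/dst data type combinations.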
using namespace data_type;

template struct _gemm_x8s8s32x_convolution_fwd_t<u8, f32>;
template struct _gemm_x8s8s32x_convolution_fwd_t<u8, s32>;
template struct _gemm_x8s8s32x_convolution_fwd_t<u8, s8>;
template struct _gemm_x8s8s32x_convolution_fwd_t<u8, u8>;

template struct _gemm_x8s8s32x_convolution_fwd_t<s8, f32>;
template struct _gemm_x8s8s32x_convolution_fwd_t<s8, s32>;
template struct _gemm_x8s8s32x_convolution_fwd_t<s8, s8>;
template struct _gemm_x8s8s32x_convolution_fwd_t<s8, u8>;

template struct _gemm_u8s8s32x_convolution_bwd_data_t<f32>;
template struct _gemm_u8s8s32x_convolution_bwd_data_t<s32>;
template struct _gemm_u8s8s32x_convolution_bwd_data_t<s8>;
template struct _gemm_u8s8s32x_convolution_bwd_data_t<u8>;
} // namespace cpu
} // namespace impl
} // namespace dnnl