/*******************************************************************************
* Copyright 2016-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/x64/jit_avx2_convolution.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::status;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;
using namespace nstl;

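// Helper macros: pick the 1D/2D/3D blk_off() overload based on the problem's
// ndims (3 = 1D, 4 = 2D, 5 = 3D convolution); wht_blk_off additionally
// prepends the group index when the weights carry a groups dimension.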
#define src_blk_off(f, n, c, d, h, w) \
    (pd()->ndims() == 3) ? (f).blk_off(n, c, w) \
                         : (pd()->ndims() == 4) ? (f).blk_off(n, c, h, w) \
                                                : (f).blk_off(n, c, d, h, w)

#define wht_blk_off_(f, g, ...) \
    pd()->with_groups() ? (f).blk_off(g, __VA_ARGS__) : (f).blk_off(__VA_ARGS__)
#define wht_blk_off(f, g, oc, ic, kd, kh, kw) \
    (pd()->ndims() == 3) \
            ? wht_blk_off_(f, g, oc, ic, kw) \
            : (pd()->ndims() == 4) ? wht_blk_off_(f, g, oc, ic, kh, kw) \
                                   : wht_blk_off_(f, g, oc, ic, kd, kh, kw)

void jit_avx2_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
    const auto &jcp = kernel_->jcp;
    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const data_t *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const size_t ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
    const size_t work_amount
            = jcp.mb * jcp.ngroups * ocb_work * jcp.od * jcp.oh;

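    // The parallel work is a flattened 5D space: minibatch x groups x
    // oc-block chunks x output depth x output height; balance211 hands each
    // thread a contiguous [start, end) slice of it.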
    auto ker = [&](const int ithr, const int nthr) {
        size_t start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);

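        // Physically blocked layouts (nCw8c/nChw8c/nCdhw8c) index blk_off()
        // by block number; plain channels-last layouts index it by raw
        // channel, hence the different offsets and scales below.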
        bool is_ic_physically_blocked = one_of(jcp.src_tag, format_tag::nCw8c,
                format_tag::nChw8c, format_tag::nCdhw8c);
        int g_ic_offset = is_ic_physically_blocked ? jcp.nb_ic : jcp.ic;
        int icb_ic_scale = is_ic_physically_blocked ? 1 : jcp.ic_block;

        bool is_oc_physically_blocked = one_of(jcp.dst_tag, format_tag::nCw8c,
                format_tag::nChw8c, format_tag::nCdhw8c);
        int g_oc_offset = is_oc_physically_blocked ? jcp.nb_oc : jcp.oc;
        int ocb_oc_scale = is_oc_physically_blocked ? 1 : jcp.oc_block;
        int oc_bias_scale = is_oc_physically_blocked ? jcp.oc_block : 1;

        int icbb = 0;
        while (icbb < jcp.nb_ic) {
            int icb_step = jcp.nb_ic_blocking;
            int icb_step_rem = jcp.nb_ic - icbb;
            if (icb_step_rem < jcp.nb_ic_blocking_max) icb_step = icb_step_rem;

            size_t n {0}, g {0}, ocbb {0}, oh {0}, od {0};
            nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work,
                    od, jcp.od, oh, jcp.oh);
            for (size_t iwork = start; iwork < end; ++iwork) {
                int ocb = ocbb * jcp.nb_oc_blocking;
                int ocb_num = jcp.nb_oc_blocking;

                for (int icb = icbb; icb < icbb + icb_step; ++icb) {
                    auto par_conv = jit_conv_call_s();

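                    // Count how many (dilated) kernel rows/planes fall into
                    // the top/bottom and front/back padding for this output
                    // position; the kernel then runs only on the in-bounds
                    // sub-range (see kh_padding/kd_padding below).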
                    const int ij = oh * jcp.stride_h;
                    const int i_t_overflow = nstl::max(0, jcp.t_pad - ij);
                    const int i_b_overflow
                            = nstl::max(jcp.ih,
                                      ij + (jcp.kh - 1) * (jcp.dilate_h + 1)
                                              - jcp.t_pad + 1)
                            - jcp.ih;

                    const int dj = od * jcp.stride_d;
                    const int d_t_overflow = nstl::max(0, jcp.f_pad - dj);
                    const int d_b_overflow
                            = nstl::max(jcp.id,
                                      dj + (jcp.kd - 1) * (jcp.dilate_d + 1)
                                              - jcp.f_pad + 1)
                            - jcp.id;

                    const size_t _oc = g * g_oc_offset + ocb * ocb_oc_scale;
                    const size_t _ic = g * g_ic_offset + icb * icb_ic_scale;

                    const int ih = nstl::max(ij - jcp.t_pad
                                    + div_up(i_t_overflow, (jcp.dilate_h + 1))
                                            * (jcp.dilate_h + 1),
                            0);

                    const int id = nstl::max(dj - jcp.f_pad
                                    + div_up(d_t_overflow, (jcp.dilate_d + 1))
                                            * (jcp.dilate_d + 1),
                            0);

                    par_conv.src = &src[src_blk_off(src_d, n, _ic, id, ih, 0)];

                    par_conv.dst = &dst[src_blk_off(dst_d, n, _oc, od, oh, 0)];

                    const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1));
                    const int wd = div_up(d_t_overflow, (jcp.dilate_d + 1));
                    par_conv.filt = &weights[wht_blk_off(
                            weights_d, g, ocb, icb, wd, wh, 0)];

                    if (icb == 0) {
                        if (bias)
                            par_conv.bias = &bias[bias_d.blk_off(
                                    _oc * oc_bias_scale)];

                        par_conv.flags |= FLAG_IC_FIRST;
                    }

                    if ((jcp.with_eltwise || jcp.with_binary)
                            && icb + 1 == jcp.nb_ic)
                        par_conv.flags |= FLAG_IC_LAST;

                    par_conv.reduce_work = this_block_size(
                            icb * jcp.ic_block, jcp.ic, jcp.ic_block);

                    par_conv.oc_blocks
                            = nstl::min(ocb + ocb_num, jcp.nb_oc) - ocb;

                    if (ocbb == ocb_work - 1) par_conv.oc_flag |= FLAG_OC_LAST;

                    par_conv.kw_padding = 0;
                    const int kh_padding = jcp.kh
                            - div_up(i_t_overflow, (jcp.dilate_h + 1))
                            - div_up(i_b_overflow, (jcp.dilate_h + 1));
                    par_conv.kh_padding = nstl::max(0, kh_padding);

                    const int kd_padding = jcp.kd
                            - div_up(d_t_overflow, (jcp.dilate_d + 1))
                            - div_up(d_b_overflow, (jcp.dilate_d + 1));
                    par_conv.kd_padding = nstl::max(0, kd_padding);

                    par_conv.oc_l_off = _oc * oc_bias_scale;
                    par_conv.post_ops_binary_rhs_arg_vec
                            = post_ops_binary_rhs_arg_vec.data();
                    par_conv.dst_orig = dst;

                    (*kernel_)(&par_conv);
                }
                nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, od,
                        jcp.od, oh, jcp.oh);
            }
            icbb += icb_step;
        }
    };

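    // If oc is not a multiple of the block size, copy the bias into a
    // zero-padded scratchpad buffer so the kernel can read whole blocks.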
    if (pd()->wants_padded_bias()) {
        auto padded_bias = ctx.get_scratchpad_grantor().get<data_t>(
                key_conv_padded_bias);
        utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
        utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
                jcp.oc - jcp.oc_without_padding);
        bias = padded_bias;
    }

    parallel(jcp.nthr, ker);

    if (pd()->wants_zero_pad_dst()) ctx.zero_pad_output(DNNL_ARG_DST);
}

void jit_avx2_convolution_bwd_data_t::execute_backward_data(
        const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS);
    auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC);

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    const auto &jcp = kernel_->jcp;

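    // Start with whole-image height blocks; the heuristic below switches to
    // single-row blocks when there is too little parallel work or when the
    // per-iteration working set (src + dst + weights slices) overflows L2.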
    int icb_work = jcp.nb_ic / jcp.nb_ic_blocking;
    int ih_block_size = jcp.ih;
    int num_ih_blocks = utils::div_up(jcp.ih, ih_block_size);
    size_t work_amount = jcp.mb * jcp.ngroups * icb_work * num_ih_blocks;

    const auto data_size = sizeof(data_t);
    const auto L2 = platform::get_per_core_cache_size(2) / data_size;
    // input + output + weights per iteration by nb_oc_blocking
    auto ic_chunk = jcp.nb_ic_blocking * jcp.ic_block;
    auto oc_chunk = jcp.nb_oc_blocking * jcp.oc_block;
    auto iter_data_amount = (size_t)jcp.id * jcp.ih * jcp.iw * ic_chunk
            + (size_t)jcp.od * jcp.oh * jcp.ow * oc_chunk
            + (size_t)jcp.kd * jcp.kh * jcp.kw * ic_chunk * oc_chunk;

    if (work_amount < (size_t)2 * jcp.nthr || iter_data_amount > L2) {
        ih_block_size = 1;
        num_ih_blocks = utils::div_up(jcp.ih, ih_block_size);
        work_amount *= num_ih_blocks;
    }

    const int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
    const int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);

    bool is_ic_physically_blocked = one_of(jcp.src_tag, format_tag::nCw8c,
            format_tag::nChw8c, format_tag::nCdhw8c);
    int g_ic_offset = is_ic_physically_blocked ? jcp.nb_ic : jcp.ic;
    int icb_ic_scale = is_ic_physically_blocked ? 1 : jcp.ic_block;

    bool is_oc_physically_blocked = one_of(jcp.dst_tag, format_tag::nCw8c,
            format_tag::nChw8c, format_tag::nCdhw8c);
    int g_oc_offset = is_oc_physically_blocked ? jcp.nb_oc : jcp.oc;
    int ocb_oc_scale = is_oc_physically_blocked ? 1 : jcp.oc_block;

    const bool is_ddst_layout_nxc = one_of(
            jcp.dst_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
    const int oc_step = is_ddst_layout_nxc ? jcp.nb_oc_blocking : 1;

    auto ker = [&](const int ithr, const int nthr) {
        size_t start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);

        size_t n {0}, g {0}, icbb {0}, ihb {0};
        nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, icbb, icb_work, ihb,
                num_ih_blocks);
        for (size_t iwork = start; iwork < end; ++iwork) {
            for_(int oc = 0; oc < jcp.nb_oc; oc += jcp.nb_oc_blocking)
            for (int id = 0; id < jcp.id; ++id) {
                int cur_nb_oc = nstl::min(jcp.nb_oc - oc, jcp.nb_oc_blocking);

                auto par_conv = jit_conv_call_s();

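                // Map this input depth index back onto the range of kernel
                // taps and output-depth positions that contribute to it,
                // clipping against front/back padding; the dilated case
                // (which implies stride 1 here) and the strided case differ.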
                int d_t_overflow, d_b_overflow, od;
                if (jcp.dilate_d != 0) { // stride == 1
                    const int dilate_d = jcp.dilate_d + 1;
                    d_t_overflow
                            = div_up(nstl::max(0, ext_kd - 1 - id - jcp.f_pad),
                                    dilate_d);
                    d_b_overflow = div_up(
                            nstl::max(0, ext_kd - jcp.id + id - jcp.back_pad),
                            dilate_d);
                    od = id + jcp.f_pad - d_b_overflow * dilate_d;
                } else {
                    d_t_overflow = nstl::max(0, jcp.kd - 1 - id - jcp.f_pad);
                    d_b_overflow = nstl::max(
                            0, jcp.kd - 1 - (jcp.id - 1 - id) - jcp.back_pad);
                    od = id + jcp.f_pad - d_b_overflow;
                }
                par_conv.kd_padding = jcp.kd - d_t_overflow - d_b_overflow;

                int ih_start = ihb * ih_block_size;
                int ih_end = nstl::min(jcp.ih, ih_start + ih_block_size);
                for (int ih = ih_start; ih < ih_end; ++ih) {

                    int k_lo, oh;
                    if (jcp.dilate_h != 0) { // stride == 1
                        const int dilate_h = jcp.dilate_h + 1;
                        int i_t_overflow = div_up(
                                nstl::max(0, ext_kh - 1 - ih - jcp.t_pad),
                                dilate_h);
                        int i_b_overflow = div_up(
                                nstl::max(0, ext_kh - jcp.ih + ih - jcp.b_pad),
                                dilate_h);
                        par_conv.kh_padding
                                = jcp.kh - i_t_overflow - i_b_overflow;
                        k_lo = i_b_overflow;
                        oh = ih + jcp.t_pad - k_lo * dilate_h;
                    } else {
                        int i_t_overflow = nstl::max(0,
                                (jcp.kh - 1 - ih - jcp.t_pad) / jcp.stride_h);
                        int i_b_overflow = nstl::max(0,
                                (jcp.kh - jcp.ih + ih - jcp.b_pad)
                                        / jcp.stride_h);
                        int overflow_kh_hi = jcp.kh - 1
                                - modulo(jcp.ih - 1 + jcp.b_pad - ih,
                                        jcp.stride_h);
                        int overflow_kh_lo = (ih + jcp.t_pad) % jcp.stride_h;

                        par_conv.kh_padding = (overflow_kh_hi - overflow_kh_lo)
                                        / jcp.stride_h
                                + 1 - i_t_overflow - i_b_overflow;

                        k_lo = overflow_kh_lo + i_b_overflow * jcp.stride_h;
                        oh = (ih + jcp.t_pad - k_lo) / jcp.stride_h;
                    }
                    par_conv.kw_padding = 0;

                    par_conv.src = &diff_src[src_blk_off(diff_src_d, n,
                            g * g_ic_offset
                                    + jcp.nb_ic_blocking * icbb * icb_ic_scale,
                            id, ih, 0)];
                    par_conv.dst = &diff_dst[src_blk_off(diff_dst_d, n,
                            g * g_oc_offset + ocb_oc_scale * oc, od, oh, 0)];
                    par_conv.filt = &weights[wht_blk_off(weights_d, g, oc,
                            jcp.nb_ic_blocking * icbb, d_b_overflow, k_lo, 0)];

                    par_conv.src_prf = nullptr;
                    par_conv.dst_prf = nullptr;
                    par_conv.filt_prf = nullptr;
                    par_conv.channel = oc;
                    par_conv.ch_blocks = cur_nb_oc;

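                    // For channels-last diff_dst, pass exact channel tails:
                    // load_work/reduce_work clip the last ic/oc chunk, and
                    // FLAG_IC_LAST marks a partial trailing ic block.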
                    if (is_ddst_layout_nxc) {
                        par_conv.load_work = this_block_size(
                                icbb * jcp.nb_ic_blocking * jcp.ic_block,
                                (size_t)jcp.ic,
                                jcp.nb_ic_blocking * jcp.ic_block);
                        par_conv.reduce_work
                                = this_block_size(oc * jcp.oc_block, jcp.oc,
                                        oc_step * jcp.oc_block);

                        if (par_conv.load_work % jcp.ic_block > 0)
                            par_conv.flags |= FLAG_IC_LAST;
                    }

                    (*kernel_)(&par_conv);
                }
            }
            nd_iterator_step(n, jcp.mb, g, jcp.ngroups, icbb, icb_work, ihb,
                    num_ih_blocks);
        }
    };

    parallel(jcp.nthr, ker);
}

void jit_avx2_convolution_bwd_weights_t::execute_backward_weights(
        const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto diff_weights = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_WEIGHTS);
    auto diff_bias_in = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_BIAS);

    auto scratchpad = ctx.get_scratchpad_grantor();

    const auto &jcp = kernel_->jcp;

    const bool is_bias_padded
            = pd()->with_bias() && (jcp.oc_without_padding % jcp.oc_block != 0);

    data_t *diff_bias = is_bias_padded
            ? scratchpad.get<data_t>(key_conv_padded_bias)
            : diff_bias_in;

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));

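    // Weights and bias gradients are accumulated into per-thread scratchpad
    // buffers managed by the reducers and combined across threads afterwards.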
    auto reducer_bia_scratchpad
            = memory_tracking::grantor_t(scratchpad, prefix_reducer_bia);
    auto rb = this->reducer_bias_.get();
    rb->init(reducer_bia_scratchpad);

    auto reducer_wei_scratchpad
            = memory_tracking::grantor_t(scratchpad, prefix_reducer_wei);
    auto rw = this->reducer_weights_.get();
    rw->init(reducer_wei_scratchpad);

    bool is_ic_physically_blocked = one_of(jcp.src_tag, format_tag::nCw8c,
            format_tag::nChw8c, format_tag::nCdhw8c);
    int g_ic_offset = is_ic_physically_blocked ? jcp.nb_ic : jcp.ic;
    int icb_ic_scale = is_ic_physically_blocked ? 1 : jcp.ic_block;

    bool is_oc_physically_blocked = one_of(jcp.dst_tag, format_tag::nCw8c,
            format_tag::nChw8c, format_tag::nCdhw8c);
    bool is_ddst_layout_nxc = !is_oc_physically_blocked;
    int g_oc_offset = is_oc_physically_blocked ? jcp.nb_oc : jcp.oc;
    int ocb_oc_scale = is_oc_physically_blocked ? 1 : jcp.oc_block;

    auto ker = [&](int ithr, int nthr) {
        assert(nthr == rw->balancer().nthr_);

        const int w_job_start = rw->balancer().ithr_job_off(ithr);
        const int w_njobs = rw->balancer().ithr_njobs(ithr);

        if (w_njobs == 0) return;

        /* reduction dimension */
        int img_od_start {0}, img_od_end {0}, img {0}, od_s {0};
        balance211(jcp.mb * jcp.od, rw->balancer().nthr_per_group_,
                rw->balancer().id_in_group(ithr), img_od_start, img_od_end);

        int img_start = img_od_start, img_end = img_od_end;
        nd_iterator_init(img_start, img, jcp.mb, od_s, jcp.od);
        const int img_first = img;

        /* jobs */
        int g_start {0}, ocb_start {0}, icb_start {0};
        nd_iterator_init(w_job_start, g_start, jcp.ngroups, ocb_start,
                jcp.nb_oc, icb_start, jcp.nb_ic);

        while (img_start < img_end) {
            int g = g_start, ocb = ocb_start, icb = icb_start;

            const int work_rem = img_end - img_start;
            const int od_e
                    = od_s + work_rem > jcp.od ? jcp.od : od_s + work_rem;
            const int id_s = od_s * jcp.stride_d;
            const int idp = jcp.id + jcp.f_pad + jcp.back_pad;

            if (id_s < idp - jcp.back_pad - jcp.kd + 1)
                for (int w_job_loc = 0; w_job_loc < w_njobs; ++w_job_loc) {
                    const size_t _oc = g * g_oc_offset + ocb * ocb_oc_scale;
                    const size_t _ic = g * g_ic_offset + icb * icb_ic_scale;

                    /* TODO: put dw <-- 0 in kernel */
                    if (img == img_first)
                        array_set(rw->get_local_ptr(ithr, diff_weights,
                                          reducer_wei_scratchpad)
                                        + w_job_loc * rw->balancer().job_size_,
                                0, rw->balancer().job_size_);

                    for (int od = od_s; od < od_e; ++od) {
                        const int id = od * jcp.stride_d;
                        if (id >= jcp.id - jcp.back_pad - jcp.kd + 1) break;

                        auto par_conv = jit_conv_call_s();
                        par_conv.src
                                = &src[src_blk_off(src_d, img, _ic, id, 0, 0)];
                        par_conv.dst = &diff_dst[src_blk_off(
                                diff_dst_d, img, _oc, od, 0, 0)];
                        par_conv.filt = rw->get_local_ptr(ithr, diff_weights,
                                                reducer_wei_scratchpad)
                                + w_job_loc * rw->balancer().job_size_;

                        if (ocb == jcp.nb_oc - 1)
                            par_conv.flags |= FLAG_OC_LAST;

                        par_conv.channel = this_block_size(
                                icb * jcp.ic_block, jcp.ic, jcp.ic_block);

                        (*kernel_)(&par_conv);
                    }
                    nd_iterator_step(
                            g, jcp.ngroups, ocb, jcp.nb_oc, icb, jcp.nb_ic);
                }
            nd_iterator_jump(img_start, img_end, img, jcp.mb, od_s, jcp.od);
        }

        if (dnnl_thr_syncable())
            rw->reduce(ithr, diff_weights, reducer_wei_scratchpad);
    };

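    // Bias gradient: each thread sums diff_dst over the spatial dimensions
    // for its (group, oc block) jobs into a private buffer, zeroed on the
    // first image it touches.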
    auto ker_bias = [&](int ithr, int nthr) {
        assert(nthr == rb->balancer().nthr_);

        const int b_job_start = rb->balancer().ithr_job_off(ithr);
        const int b_njobs = rb->balancer().ithr_njobs(ithr);

        if (b_njobs == 0) return;

        /* reduction dimension */
        int img_start {0}, img_end {0};
        balance211(jcp.mb, rb->balancer().nthr_per_group_,
                rb->balancer().id_in_group(ithr), img_start, img_end);

        /* jobs */
        int g_start {0}, ocb_start {0};
        nd_iterator_init(
                b_job_start, g_start, jcp.ngroups, ocb_start, jcp.nb_oc);

        for (int img = img_start; img < img_end; ++img) {
            int g = g_start, ocb = ocb_start;
            for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) {
                const size_t _oc = g * g_oc_offset + ocb * ocb_oc_scale;

                const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)];
                data_t *d_bias = rb->get_local_ptr(ithr, diff_bias,
                                         reducer_bia_scratchpad)
                        + b_job_loc * rb->balancer().job_size_;

                if (img == img_start)
                    for (int o = 0; o < jcp.oc_block; ++o)
                        d_bias[o] = 0.;

                const int max_oc = this_block_size(
                        ocb * jcp.oc_block, jcp.oc, jcp.oc_block);

                for (int dhw = 0; dhw < jcp.od * jcp.oh * jcp.ow; ++dhw) {
                    PRAGMA_OMP_SIMD()
                    for (int o = 0; o < max_oc; ++o)
                        d_bias[o] += d_dst[o];
                    d_dst += is_ddst_layout_nxc ? jcp.ngroups * jcp.oc
                                                : jcp.oc_block;
                }

                nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc);
            }
        }

        if (dnnl_thr_syncable())
            rb->reduce(ithr, diff_bias, reducer_bia_scratchpad);
    };

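    // With a syncable threading runtime the cross-thread reduction runs
    // inside ker/ker_bias; otherwise accumulation and reduce_nolock are
    // issued as separate parallel regions.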
    if (dnnl_thr_syncable()) {
        assert(IMPLICATION(pd()->with_bias(),
                rw->balancer().nthr_ == rb->balancer().nthr_));
        parallel(rw->balancer().nthr_, [&](const int ithr, const int nthr) {
            ker(ithr, nthr);
            if (pd()->with_bias()) ker_bias(ithr, nthr);
        });
    } else {
        parallel(rw->balancer().nthr_,
                [&](int ithr, int nthr) { ker(ithr, nthr); });
        parallel(rw->balancer().nthr_, [&](int ithr, int nthr) {
            assert(nthr == rw->balancer().nthr_);
            MAYBE_UNUSED(nthr);
            if (rw->balancer().ithr_njobs(ithr) == 0) return;
            rw->reduce_nolock(ithr, diff_weights, reducer_wei_scratchpad);
        });
        if (pd()->with_bias()) {
            parallel(rb->balancer().nthr_,
                    [&](int ithr, int nthr) { ker_bias(ithr, nthr); });
            parallel(rb->balancer().nthr_, [&](int ithr, int nthr) {
                assert(nthr == rb->balancer().nthr_);
                MAYBE_UNUSED(nthr);
                if (rb->balancer().ithr_njobs(ithr) == 0) return;
                rb->reduce_nolock(ithr, diff_bias, reducer_bia_scratchpad);
            });
        }
    }

    /* TODO: put this in ker_bias */
    if (pd()->with_bias() && (jcp.oc_without_padding % jcp.oc_block != 0)) {
        const int padded_stride = rnd_up(jcp.oc, jcp.oc_block);
        const int stride = jcp.oc_without_padding;
        for (int g = 0; g < jcp.ngroups; ++g)
            utils::array_copy(diff_bias_in + g * stride,
                    diff_bias + g * padded_stride, stride);
    }
}

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s