/*******************************************************************************
* Copyright 2016-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/x64/jit_generator.hpp"

#include "cpu/x64/jit_avx2_1x1_convolution.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::status;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;

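// Compute the offset of element (n, c, [d,] [h,] w) in a memory descriptor,
// dispatching to the blk_off() overload that matches the tensor rank
// (3D = ncw, 4D = nchw, 5D = ncdhw).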
#define data_blk_off(f, n, c, d, h, w) \
    ((ndims == 3) ? (f).blk_off(n, c, w) \
                  : ((ndims == 4) ? (f).blk_off(n, c, h, w) \
                                  : (f).blk_off(n, c, d, h, w)))
/* convolution forward */

void jit_avx2_1x1_convolution_fwd_t::execute_forward(
        const exec_ctx_t &ctx) const {
    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const data_t *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
    auto weights_dw = CTX_IN_MEM(
            const data_t *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS);
    auto bias_dw = CTX_IN_MEM(
            const data_t *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx);
    const auto post_ops_binary_rhs_arg_vec_dw = pd()->jcp_dw_
            ? binary_injector::prepare_binary_args(pd()->jcp_dw_->post_ops, ctx,
                    pd()->jcp_.post_ops.entry_.size() + 1)
            : std::vector<const void *> {};

    auto scratchpad = ctx.get_scratchpad_grantor();

    const auto &jcp = kernel_->jcp;
    // TODO (Roma): remove this restriction
    assert(jcp.stride_w == 1 && jcp.stride_h == 1);

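    // When oc is not a multiple of the block size, copy the bias into a
    // zero-padded scratchpad buffer so the kernel can always load full
    // oc blocks.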
    if (pd()->wants_padded_bias()) {
        auto padded_bias = scratchpad.get<data_t>(key_conv_padded_bias);
        utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
        utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
                jcp.oc - jcp.oc_without_padding);
        bias = padded_bias;
    }

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        execute_forward_thr(ithr, nthr, src, weights, bias, weights_dw, bias_dw,
                dst, scratchpad, post_ops_binary_rhs_arg_vec.data(),
                post_ops_binary_rhs_arg_vec_dw.data());
    });

    if (pd()->wants_zero_pad_dst()) ctx.zero_pad_output(DNNL_ARG_DST);
}

void jit_avx2_1x1_convolution_fwd_t::execute_forward_thr(const int ithr,
        const int nthr, const data_t *src, const data_t *weights,
        const data_t *bias, const data_t *weights_dw, const data_t *bias_dw,
        data_t *dst, const memory_tracking::grantor_t &scratchpad,
        const void *post_ops_binary_rhs_arg_vec,
        const void *post_ops_binary_rhs_arg_vec_dw) const {

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper dw_weights_d(
            pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS));
    const memory_desc_wrapper dw_bias_d(
            pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS));

    const auto &jcp = kernel_->jcp;
    auto rtus_space = pd()->rtus_.reduce_src_
            ? scratchpad.get<data_t>(key_conv_rtus_space)
            : nullptr;

    const int ndims = dst_d.ndims();

    const int stride_d = (ndims == 5) ? pd()->desc()->strides[0] : 1;
    const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[ndims - 4];
    const int stride_w = pd()->desc()->strides[ndims - 3];

    const int nb_oc = jcp.nb_load;
    const int nb_ic = jcp.nb_reduce;
    const int nb_ic_blocking = jcp.nb_reduce_blocking;

    auto p = jit_1x1_conv_call_s();
    auto rp = rtus_driver_t<avx2>::call_params_t();

    // override some constants for fused dw_conv
    const int os_block = jcp.with_dw_conv ? jcp.ow : jcp.bcast_block;
    const int nb_bcast = jcp.with_dw_conv ? jcp.oh : jcp.nb_bcast;
    const int nb_bcast_blocking = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking;
    const int nb_bcast_blocking_max
            = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking_max;
    const int nb_load_blocking = jcp.nb_load_blocking;
    const int nb_load_blocking_max = jcp.with_dw_conv
            ? jcp.nb_load_blocking
            : jcp.nb_load_blocking_max;

    // Variables needed for the fused depthwise convolution.
    data_t *pbuf;
    size_t row_offset;
    const int nb_buffer = jcp.nb_load_blocking;
    auto jcp_dw = pd()->jcp_dw_;
    std::vector<data_t *> addrs;
    jit_generator *dw_jit_ker = nullptr;

    auto step = [](int default_step, int remaining, int tail_step) {
        assert(default_step <= tail_step);
        return remaining < tail_step ? remaining : default_step;
    };

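    // Decompose a flat work-item index into (mb, group, spatial block) and
    // derive the corresponding output and input spatial coordinates.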
    auto init_bcast = [&](int iwork, int bcast_end, int &n, int &g,
                              int &bcast_step, int &od, int &oh, int &ow,
                              int &id, int &ih, int &iw) {
        int osb {0};
        nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, nb_bcast);

        bcast_step = step(
                nb_bcast_blocking, nb_bcast - osb, nb_bcast_blocking_max);
        bcast_step = nstl::min(bcast_step, bcast_end - iwork);

        const int os = osb * os_block;
        const int os_2d = os % (jcp.oh * jcp.ow);
        od = os / (jcp.oh * jcp.ow);
        oh = os_2d / jcp.ow;
        ow = os_2d % jcp.ow;
        id = od * stride_d;
        ih = oh * stride_h;
        iw = ow * stride_w;
        rp.iw_start = iw;

        p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block);
        rp.os = p.bcast_dim;
    };

    auto init_load = [&](int ocb, int ocb_end, int &load_step) {
        load_step = step(nb_load_blocking, ocb_end - ocb, nb_load_blocking_max);
        // The binary post-op injector may overwrite zero-padded areas, so
        // output masking must be based on the exact number of channels.
        const auto oc = jcp.with_binary ? jcp.oc_without_padding : jcp.oc;
        p.load_dim = this_block_size(
                ocb * jcp.oc_block, oc, load_step * jcp.oc_block);
    };

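    // Run the 1x1 kernel on one (load, reduce) block pair. With a fused
    // depthwise convolution the output goes to the intermediate row buffer
    // (pbuf) instead of dst.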
    auto ker_1x1 = [&](int ocb, int icb, int ocb_start, int n, int g, int od,
                           int oh, int ow, int id, int ih, int iw) {
        const bool is_dst_layout_nxc = utils::one_of(jcp.dst_tag,
                format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
        const int oc_off_idx = is_dst_layout_nxc
                ? g * jcp.oc + ocb * jcp.oc_block
                : g * nb_oc + ocb;

        p.output_data = jcp.with_dw_conv
                ? pbuf + (oh % jcp_dw->kh) * row_offset
                : &dst[data_blk_off(dst_d, n, oc_off_idx, od, oh, ow)];
        p.bias_data
                = &bias[oc_off_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block)];

        p.first_last_flag = 0 | (icb == 0 ? FLAG_REDUCE_FIRST : 0)
                | (icb + nb_ic_blocking >= nb_ic ? FLAG_REDUCE_LAST : 0);

        p.reduce_dim = this_block_size(
                icb * jcp.ic_block, jcp.ic, nb_ic_blocking * jcp.ic_block);
        rp.icb = p.reduce_dim;

        p.load_data
                = &weights[pd()->with_groups() ? weights_d.blk_off(g, ocb, icb)
                                               : weights_d.blk_off(ocb, icb)];

        const bool is_src_layout_nxc = utils::one_of(jcp.src_tag,
                format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
        const int ic_off_idx = is_src_layout_nxc
                ? g * jcp.ic + icb * jcp.ic_block
                : g * nb_ic + icb;

        if (pd()->rtus_.reduce_src_) {
            rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_
                    + (is_src_layout_nxc ? ic_off_idx
                                         : jcp.is * ic_off_idx * jcp.ic_block);

            if (ocb == ocb_start) {
                rp.src = src + data_blk_off(src_d, n, ic_off_idx, id, ih, iw);
                (*rtus_driver_)(&rp);
            }

            p.bcast_data = rp.ws;
        } else
            p.bcast_data = src + data_blk_off(src_d, n, ic_off_idx, id, ih, iw);

        p.oc_l_off = ocb * jcp.oc_block;
        p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec;
        p.dst_orig = dst;

        (*kernel_)(&p);
    };

    auto conv_1x1 = [&](int bcast_start, int bcast_end, int ocb_start,
                            int ocb_end) {
        if (bcast_start >= bcast_end || ocb_start >= ocb_end) return;
        int iwork = bcast_start;
        while (iwork < bcast_end) {
            int n {0}, g {0}, bcast_step, od, oh, ow, id, ih, iw;
            init_bcast(
                    iwork, bcast_end, n, g, bcast_step, od, oh, ow, id, ih, iw);
            int ocb = ocb_start;
            while (ocb < ocb_end) {
                int load_step;
                init_load(ocb, ocb_end, load_step);
                for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
                    ker_1x1(ocb, icb, ocb_start, n, g, od, oh, ow, id, ih, iw);
                }
                ocb += load_step;
            }
            iwork += bcast_step;
        }
    };

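    // Apply the depthwise kernel to one output row: gather the kh input-row
    // pointers from the circular intermediate buffer and shrink the
    // effective kernel height where the filter overhangs the top or bottom
    // padding.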
    auto ker_dw = [&](int n, int ocb_start, int load_step, int &dw_oh) {
        int oh_1x1 = nstl::max(dw_oh * jcp_dw->stride_h - jcp_dw->t_pad, 0);

        for (int i = 0; i < jcp_dw->kh; ++i)
            addrs[i] = pbuf + ((oh_1x1++) % jcp_dw->kh) * row_offset;

        const auto ocb_end = ocb_start + load_step;
        const auto wch_stride
                = jcp_dw->iw * jcp_dw->nb_ch_blocking * jcp_dw->ch_block;
        const int dil_h = jcp_dw->dilate_h + 1;
        const int str_h = jcp_dw->stride_h;
        const int ch_num = jcp_dw->nb_ch_blocking;
        const int ow = 0;
        const int kw = 0;

        for (int ch = ocb_start; ch < ocb_end; ch += jcp_dw->nb_ch_blocking) {

            const int i_t_overflow
                    = nstl::max(0, (int)(jcp_dw->t_pad - dw_oh * str_h));
            const int i_b_overflow
                    = nstl::max(jcp_dw->ih,
                              (int)(dw_oh * str_h + (jcp_dw->kh - 1) * dil_h
                                      - jcp_dw->t_pad + 1))
                    - jcp_dw->ih;

            const int kh = div_up(i_t_overflow, dil_h);
            const int kh_padding = jcp_dw->kh - div_up(i_t_overflow, dil_h)
                    - div_up(i_b_overflow, dil_h);

            jit_conv_call_s par_conv_dw;

            par_conv_dw.src = addrs.data();
            par_conv_dw.dst = &dst[dst_d.blk_off(n, ch, dw_oh, ow)];

            par_conv_dw.filt
                    = &weights_dw[dw_weights_d.blk_off(ch, 0, 0, kh, kw)];
            if (bias)
                par_conv_dw.bias
                        = &bias_dw[dw_bias_d.blk_off(ch * jcp_dw->ch_block)];

            par_conv_dw.kh_padding = (size_t)nstl::max(0, kh_padding);

            par_conv_dw.load_work = (nstl::min(ch + ch_num, jcp_dw->nb_ch) - ch)
                    * jcp_dw->ch_block;

            par_conv_dw.oc_l_off = ch * jcp_dw->ch_block;
            par_conv_dw.post_ops_binary_rhs_arg_vec
                    = post_ops_binary_rhs_arg_vec_dw;
            par_conv_dw.dst_orig = dst;

            (*dw_jit_ker)(&par_conv_dw);

            for (int i = 0; i < jcp_dw->kh; ++i)
                addrs[i] += wch_stride;
        }
    };

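    // Driver for the fused 1x1 + depthwise pass: for each dw output row,
    // compute just the 1x1 output rows its receptive field needs, then
    // consume them immediately while they are still hot in cache.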
    auto conv_dw = [&]() {
        // Set up the per-thread intermediate buffer and pick the dw kernel.
        memory_tracking::grantor_t dw_scratchpad(
                scratchpad, memory_tracking::names::prefix_fusion);
        auto dw_conv_buffer
                = dw_scratchpad.get<data_t>(key_fusion_inout_buffer);
        dw_jit_ker = kernel_dw_avx2 ? kernel_dw_avx2->ker()
                                    : kernel_dw_sse41->ker();

        const auto dw_conv_buffer_size_
                = (size_t)jcp_dw->kh * jcp.ow * nb_buffer * jcp.oc_block;
        pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_;
        row_offset = dw_conv_buffer_size_ / jcp_dw->kh;
        addrs.resize(jcp_dw->kh);

        int bcast_start {0}, bcast_end {0}, ocb_start {0}, ocb_end {0};
        balance2D(nthr, ithr, jcp.mb * jcp.ngroups * jcp_dw->oh, bcast_start,
                bcast_end, nb_oc, ocb_start, ocb_end, 1);

        while (ocb_start < ocb_end) {
            int load_step;
            init_load(ocb_start, ocb_end, load_step);

            int oh_1x1 = 0;
            auto bcast_iter = bcast_start;
            while (bcast_iter < bcast_end) {
                int n, g, oh_dw;
                nd_iterator_init(bcast_iter, n, jcp.mb, g, jcp.ngroups, oh_dw,
                        jcp_dw->oh);
                if (oh_dw == 0) oh_1x1 = 0; // Reset over mb boundary
                const int oh_1x1_range
                        = oh_dw * jcp_dw->stride_h - jcp_dw->t_pad;
                const int oh_1x1_begin = nstl::max(oh_1x1_range, 0);
                const int oh_1x1_end
                        = nstl::min(oh_1x1_range + jcp_dw->kh, jcp.oh);
                oh_1x1 = nstl::max(
                        oh_1x1_begin, oh_1x1); // Skip rows computed previously
                // Convert dw spatial coordinates into 1x1 spatial
                // coordinates (handles jcp.oh != jcp_dw->oh).
                const int bcast_start_1x1
                        = n * jcp.ngroups * jcp.oh + g * jcp.oh + oh_1x1;
                const int bcast_end_1x1 = bcast_start_1x1 - oh_1x1 + oh_1x1_end;

                conv_1x1(bcast_start_1x1, bcast_end_1x1, ocb_start,
                        ocb_start + load_step);
                oh_1x1 = oh_1x1_end;
                ker_dw(n, g * nb_oc + ocb_start, load_step, oh_dw);

                bcast_iter += nb_bcast_blocking;
            }
            ocb_start += load_step;
        }
    };

    if (jcp.with_dw_conv) {
        conv_dw();
    } else {
        int start {0}, end {0};
        const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;
        balance211(work_amount, nthr, ithr, start, end);
        conv_1x1(start, end, 0, jcp.nb_load);
    }
}

/* convolution backward wrt data */

void jit_avx2_1x1_convolution_bwd_data_t::execute_backward_data(
        const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS);
    auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC);

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());

    const auto &jcp = kernel_->jcp;
    auto rtus_space = pd()->rtus_.reduce_src_
            ? ctx.get_scratchpad_grantor().get<data_t>(key_conv_rtus_space)
            : nullptr;

    // TODO (Roma): remove this restriction
    assert(jcp.stride_w == 1 && jcp.stride_h == 1 && jcp.stride_d == 1);
    const int ndims = diff_dst_d.ndims();

    const int stride_d = (ndims == 5) ? pd()->desc()->strides[0] : 1;
    const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[ndims - 4];
    const int stride_w = pd()->desc()->strides[ndims - 3];

    const int nb_ic = jcp.nb_load;
    const int nb_oc = jcp.nb_reduce;
    const int os_block = jcp.bcast_block;
    const int nb_oc_blocking = jcp.nb_reduce_blocking;

    const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;

    auto step = [](int default_step, int remaining, int tail_step) {
        assert(default_step <= tail_step);
        return remaining < tail_step ? remaining : default_step;
    };

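    // Per-thread worker: iterate over all ic (load) blocks and this
    // thread's share of the (mb, group, spatial) space, reducing over all
    // oc blocks; FLAG_REDUCE_FIRST makes the kernel zero-initialize the
    // accumulator on the first oc block.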
    auto ker = [&](const int ithr, const int nthr) {
        auto p = jit_1x1_conv_call_s();
        auto rp = rtus_driver_t<avx2>::call_params_t();

        int start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);

        int load_step = 0;
        for (int icb = 0; icb < jcp.nb_load; icb += load_step) {
            load_step = step(jcp.nb_load_blocking, jcp.nb_load - icb,
                    jcp.nb_load_blocking_max);

            p.load_dim = this_block_size(
                    icb * jcp.ic_block, jcp.ic, load_step * jcp.ic_block);
            rp.icb = p.load_dim;

            int bcast_step;
            for (int iwork = start; iwork < end; iwork += bcast_step) {
                int n {0}, g {0}, osb {0};
                nd_iterator_init(
                        iwork, n, jcp.mb, g, jcp.ngroups, osb, jcp.nb_bcast);

                bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb,
                        jcp.nb_bcast_blocking_max);
                bcast_step = nstl::min(bcast_step, end - iwork);

                const int os = osb * os_block;
                p.bcast_dim
                        = this_block_size(os, jcp.os, bcast_step * os_block);
                rp.os = p.bcast_dim;

                const int od = os / (jcp.oh * jcp.ow);
                const int os_2d = os % (jcp.oh * jcp.ow);
                const int oh = os_2d / jcp.ow;
                const int ow = os_2d % jcp.ow;
                const int id = od * stride_d;
                const int ih = oh * stride_h;
                const int iw = ow * stride_w;
                rp.iw_start = iw;

                const bool is_dsrc_layout_nxc = utils::one_of(jcp.src_tag,
                        format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
                const int ic_off_idx = is_dsrc_layout_nxc
                        ? g * jcp.ic + icb * jcp.ic_block
                        : g * nb_ic + icb;
                rp.src = diff_src
                        + data_blk_off(diff_src_d, n, ic_off_idx, id, ih, iw);
                if (pd()->rtus_.reduce_src_) {
                    rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_;
                    p.output_data = rp.ws;
                } else
                    p.output_data = rp.src;

                for (int ocb = 0; ocb < jcp.nb_reduce;
                        ocb += jcp.nb_reduce_blocking) {
                    const bool is_ddst_layout_nxc
                            = utils::one_of(jcp.dst_tag, format_tag::nwc,
                                    format_tag::nhwc, format_tag::ndhwc);
                    const int oc_off_idx = is_ddst_layout_nxc
                            ? g * jcp.oc + ocb * jcp.oc_block
                            : g * nb_oc + ocb;
                    size_t diff_dst_off = data_blk_off(
                            diff_dst_d, n, oc_off_idx, od, oh, ow);
                    p.bcast_data = &diff_dst[diff_dst_off];

                    p.load_data = &weights[pd()->with_groups()
                                    ? weights_d.blk_off(g, ocb, icb)
                                    : weights_d.blk_off(ocb, icb)];

                    p.first_last_flag = ocb == 0 ? FLAG_REDUCE_FIRST : 0;

                    p.reduce_dim = this_block_size(ocb * jcp.oc_block, jcp.oc,
                            nb_oc_blocking * jcp.oc_block);

                    (*kernel_)(&p);
                }

                if (pd()->rtus_.reduce_src_) (*rtus_driver_)(&rp);
            }
        }
    };

    parallel(jcp.nthr, ker);
}

/* convolution backward wrt weights */

status_t jit_avx2_1x1_convolution_bwd_weights_t::init(engine_t *engine) {
    CHECK(safe_ptr_assign(kernel_,
            new jit_avx2_1x1_conv_kernel_f32(
                    pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));
    CHECK(kernel_->create_kernel());

    CHECK(safe_ptr_assign(reducer_weights_,
            new cpu_reducer_2d_t<data_type::f32>(pd()->reducer_wei_conf_)));
    CHECK(reducer_weights_->create_kernel());

    CHECK(safe_ptr_assign(reducer_bias_,
            new cpu_reducer_t<data_type::f32>(pd()->reducer_bia_conf_)));
    if (pd()->with_bias()) {
        assert(reducer_weights_->balancer().nthr_
                == reducer_bias_->balancer().nthr_);
        CHECK(reducer_bias_->create_kernel());
    }

    CHECK(init_rtus_driver<avx2>(this));
    return status::success;
}

void jit_avx2_1x1_convolution_bwd_weights_t::execute_backward_weights(
        const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto diff_weights = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_WEIGHTS);
    auto diff_bias_in = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_BIAS);

    auto scratchpad = ctx.get_scratchpad_grantor();

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
    const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1));

    const auto &jcp = kernel_->jcp;
    auto rtus_space = pd()->rtus_.reduce_src_
            ? scratchpad.get<data_t>(key_conv_rtus_space)
            : nullptr;

    const bool is_bias_padded
            = pd()->with_bias() && (jcp.oc_without_padding % jcp.oc_block != 0);

    data_t *diff_bias = is_bias_padded
            ? scratchpad.get<data_t>(key_conv_padded_bias)
            : diff_bias_in;

    auto reducer_bia_scratchpad
            = memory_tracking::grantor_t(scratchpad, prefix_reducer_bia);
    auto rb = this->reducer_bias_.get();
    rb->init(reducer_bia_scratchpad);

    auto reducer_wei_scratchpad
            = memory_tracking::grantor_t(scratchpad, prefix_reducer_wei);
    auto rw = this->reducer_weights_.get();
    rw->init(reducer_wei_scratchpad);

    const int ndims = diff_dst_d.ndims();
    // TODO (Roma): remove this restriction
    assert(jcp.stride_w == 1 && jcp.stride_h == 1);

    const int nb_ic = jcp.nb_bcast;
    const int nb_ic_blocking = jcp.nb_bcast_blocking;
    const int bcast_work = div_up(nb_ic, nb_ic_blocking);

    const int nb_oc = jcp.nb_load;
    const int nb_oc_blocking = jcp.nb_load_blocking;
    const int load_work = div_up(nb_oc, nb_oc_blocking);

    const int sp_dim = jcp.reduce_dim;
    const int mb_sp_work = jcp.mb * sp_dim;

    const int stride_d = (ndims == 5) ? pd()->desc()->strides[0] : 1;
    const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[ndims - 4];
    const int stride_w = pd()->desc()->strides[ndims - 3];

    const bool is_src_layout_nxc = utils::one_of(
            jcp.src_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
    const bool is_ddst_layout_nxc = utils::one_of(
            jcp.dst_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);

    auto step = [](int default_step, int remaining, int tail_step) {
        assert(default_step <= tail_step);
        return remaining < tail_step ? remaining : default_step;
    };

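    // Accumulate the weight-gradient contribution of one image slice: loop
    // over (oc, ic) sub-blocks and reduce over the spatial chunk
    // [sp_start, sp_end); FLAG_REDUCE_FIRST is set only for the first image
    // and first spatial chunk, so the kernel zero-initializes exactly once.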
    auto oc_ic_sp_loop = [=](int sp_start, int sp_end, bool first_image,
                                 data_t *store_to, size_t store_to_ld,
                                 const data_t *diff_dst, const data_t *src,
                                 int ithr) {
        auto p = jit_1x1_conv_call_s();
        auto rp = rtus_driver_t<avx2>::call_params_t();

        p.output_stride = store_to_ld * sizeof(float);

        int oc_b_step = 0;
        for (int oc_b = 0; oc_b < nb_oc_blocking; oc_b += oc_b_step) {
            oc_b_step = step(nb_oc_blocking, nb_oc_blocking - oc_b,
                    jcp.nb_load_blocking_max);
            p.load_dim = this_block_size(
                    oc_b * jcp.oc_block, jcp.oc, oc_b_step * jcp.oc_block);

            int ic_b_step = 0;
            for (int ic_b = 0; ic_b < nb_ic_blocking; ic_b += ic_b_step) {
                ic_b_step = step(nb_ic_blocking, nb_ic_blocking - ic_b,
                        jcp.nb_bcast_blocking_max);
                p.bcast_dim = this_block_size(
                        ic_b * jcp.ic_block, jcp.ic, ic_b_step * jcp.ic_block);
                rp.icb = p.bcast_dim;

                p.output_data = store_to + oc_b * store_to_ld
                        + ic_b * jcp.ic_block * jcp.oc_block;

                /* spatial reduction */
                int sp_step = 0;
                for (int sp = sp_start; sp < sp_end; sp += sp_step) {
                    sp_step = step(jcp.nb_reduce_blocking, sp_end - sp,
                            jcp.nb_reduce_blocking_max);
                    p.reduce_dim = sp_step * jcp.reduce_block;
                    rp.os = p.reduce_dim;

                    p.first_last_flag = sp == sp_start && first_image
                            ? FLAG_REDUCE_FIRST
                            : 0;

                    p.load_data = diff_dst
                            + (oc_b * jcp.reduce_dim + sp)
                                    * (is_ddst_layout_nxc ? jcp.oc
                                                          : jcp.oc_block);

                    if (pd()->rtus_.reduce_src_) {
                        const int od = sp / (jcp.oh * jcp.ow);
                        const int sp_2d = sp % (jcp.oh * jcp.ow);
                        const int oh = sp_2d / jcp.ow;
                        const int ow = sp_2d % jcp.ow;

                        const int id = od * stride_d;
                        const int ih = oh * stride_h;
                        const int iw = ow * stride_w;
                        rp.iw_start = iw;

                        rp.ws = rtus_space
                                + ithr * pd()->rtus_.space_per_thread_
                                + (ic_b * jcp.is + sp) * jcp.ic_block;
                        size_t src_offset
                                = iw * src_d.blocking_desc().strides[ndims - 1];
                        if (ndims > 3)
                            src_offset += ih
                                    * src_d.blocking_desc().strides[ndims - 2];
                        if (ndims == 5)
                            src_offset += id
                                    * src_d.blocking_desc().strides[ndims - 3];

                        rp.src = src + src_offset;
                        if (oc_b == 0) (*rtus_driver_)(&rp);

                        p.bcast_data = rp.ws;
                    } else
                        p.bcast_data = src
                                + (ic_b * jcp.reduce_dim + sp)
                                        * (is_src_layout_nxc ? jcp.ic
                                                             : jcp.ic_block);

                    (*kernel_)(&p);
                }
            }
        }
    };

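    // For nxc layouts the last ic block may be only partially filled; zero
    // the padded tail of the weight gradients so the blocked output stays
    // valid.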
    auto maybe_zero_icpad = [&](const int g_start, const int g_end,
                                    const int ocb_start, const int ocb_end) {
        // write zeros to IC padded region.
        const int ic_tail = jcp.ic_without_padding % jcp.ic_block;
        if (is_ddst_layout_nxc && ic_tail != 0) {
            for_(int g = g_start; g < g_end; ++g)
            for (int z_ocb = ocb_start; z_ocb < ocb_end; ++z_ocb) {
                const int z_icb = nb_ic - 1;
                const size_t off = pd()->with_groups()
                        ? diff_weights_d.blk_off(g, z_ocb, z_icb)
                        : diff_weights_d.blk_off(z_ocb, z_icb);
                data_t *z_wei = diff_weights + off + ic_tail * jcp.oc_block;
                const int zero_work
                        = (nb_ic * jcp.ic_block - jcp.ic_without_padding)
                        * jcp.oc_block;
                PRAGMA_OMP_SIMD()
                for (int o = 0; o < zero_work; ++o) {
                    z_wei[o] = 0;
                }
            }
        }
    };

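    // Weight-gradient worker: each thread owns a set of (group, oc, ic)
    // jobs and a share of the (mb, spatial) reduction space. Partial
    // results go directly to diff_weights when a single thread covers the
    // whole reduction, otherwise to the reducer scratchpad for the final
    // cross-thread reduction.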
    auto ker = [&](const int ithr, const int nthr) {
        assert(nthr == rw->balancer().nthr_);

        const int w_njobs = rw->balancer().ithr_njobs(ithr);
        if (w_njobs == 0) return;

        /* setup: independent work (oc, ic) */
        const int w_job_start = rw->balancer().ithr_job_off(ithr);
        int g {0}, load_i {0}, bcast_i {0};
        nd_iterator_init(w_job_start, g, jcp.ngroups, load_i, load_work,
                bcast_i, bcast_work);

        /* setup: reduction work (mb, sp) */
        int mb_sp_start {0}, mb_sp_end {0};
        balance211(mb_sp_work, rw->balancer().nthr_per_group_,
                rw->balancer().id_in_group(ithr), mb_sp_start, mb_sp_end);
        int img_start {0}, sp_start {0};
        nd_iterator_init(mb_sp_start, img_start, jcp.mb, sp_start, sp_dim);

        /* independent work */
        for (int iwork = 0; iwork < w_njobs; ++iwork) {
            const int oc_b = nb_oc_blocking * load_i;
            const int ic_b = nb_ic_blocking * bcast_i;

            const int oc_off_idx = is_ddst_layout_nxc
                    ? g * jcp.oc + oc_b * jcp.oc_block
                    : g * nb_oc + oc_b;
            const int ic_off_idx = is_src_layout_nxc
                    ? g * jcp.ic + ic_b * jcp.ic_block
                    : g * nb_ic + ic_b;

            data_t *store_to;
            size_t store_to_ld;

            if (rw->balancer().nthr_per_group_ == 1) {
                const size_t off = pd()->with_groups()
                        ? diff_weights_d.blk_off(g, oc_b, ic_b)
                        : diff_weights_d.blk_off(oc_b, ic_b);
                store_to = &diff_weights[off];
                store_to_ld = rnd_up(jcp.ic, jcp.ic_block) * jcp.oc_block;
            } else {
                const size_t off = (size_t)iwork * rw->balancer().job_size_;
                store_to
                        = rw->get_local_ptr(ithr, reducer_wei_scratchpad) + off;
                store_to_ld = nb_ic_blocking * jcp.ic_block * jcp.oc_block;
            }

            /* reduction work */
            int img = img_start;
            int sp = sp_start;
            int sp_step = 0;
            for (int mb_sp = mb_sp_start; mb_sp < mb_sp_end; mb_sp += sp_step) {
                sp_step = nstl::min(sp_dim - sp, mb_sp_end - mb_sp);

                const bool first_image = img == img_start;
                if (is_ddst_layout_nxc && first_image
                        && rw->balancer().nthr_per_group_ > 1) {
                    // Zero-pad the scratchpad when nthr > 1 (since most threads
                    // write to scratchpad) so that zero-padding is maintained
                    // for the final output after reduction
                    array_set(rw->get_local_ptr(ithr, reducer_wei_scratchpad)
                                    + iwork * rw->balancer().job_size_,
                            0, rw->balancer().job_size_);
                }
                oc_ic_sp_loop(sp, sp + sp_step, first_image, store_to,
                        store_to_ld,
                        &diff_dst[diff_dst_d.blk_off(img, oc_off_idx)],
                        &src[src_d.blk_off(img, ic_off_idx)], ithr);

                sp = 0;
                img += 1;
            }

            if (rw->balancer().nthr_per_group_ == 1
                    && bcast_i + 1 >= bcast_work)
                maybe_zero_icpad(g, g + 1, oc_b,
                        nstl::min(nb_oc, oc_b + nb_oc_blocking));

            nd_iterator_step(
                    g, jcp.ngroups, load_i, load_work, bcast_i, bcast_work);
        }

        if (dnnl_thr_syncable())
            rw->reduce(ithr, diff_weights, reducer_wei_scratchpad);
    };

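    // Bias-gradient worker: accumulate diff_dst over the spatial dimension
    // for each owned (group, oc block) job, then reduce across threads.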
    auto ker_bias = [&](int ithr, int nthr) {
        assert(nthr == rb->balancer().nthr_);

        const int b_job_start = rb->balancer().ithr_job_off(ithr);
        const int b_njobs = rb->balancer().ithr_njobs(ithr);

        if (b_njobs == 0) return;

        /* reduction dimension */
        int img_start {0}, img_end {0};
        balance211(jcp.mb, rb->balancer().nthr_per_group_,
                rb->balancer().id_in_group(ithr), img_start, img_end);

        /* jobs */
        int g_start {0}, ocb_start {0};
        nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start, nb_oc);

        for (int img = img_start; img < img_end; ++img) {
            int g = g_start, ocb = ocb_start;
            for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) {
                const int oc_off_idx = is_ddst_layout_nxc
                        ? g * jcp.oc + ocb * jcp.oc_block
                        : g * nb_oc + ocb;

                const data_t *d_dst
                        = &diff_dst[diff_dst_d.blk_off(img, oc_off_idx)];
                data_t *d_bias = rb->get_local_ptr(ithr, diff_bias,
                                         reducer_bia_scratchpad)
                        + b_job_loc * rb->balancer().job_size_;

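                // Zero-initialize one oc block of the accumulator on the
                // first image (8 floats: the f32 AVX2 oc block width).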
                if (img == img_start)
                    for (int o = 0; o < 8; ++o)
                        d_bias[o] = 0.;

                const int spatial_shift
                        = is_ddst_layout_nxc ? jcp.oc : jcp.oc_block;
                const int max_oc = this_block_size(
                        ocb * jcp.oc_block, jcp.oc, jcp.oc_block);
                for (int hw = 0; hw < jcp.os; ++hw) {
                    PRAGMA_OMP_SIMD()
                    for (int o = 0; o < max_oc; ++o)
                        d_bias[o] += d_dst[o];
                    d_dst += spatial_shift;
                }

                nd_iterator_step(g, jcp.ngroups, ocb, nb_oc);
            }
        }

        if (dnnl_thr_syncable())
            rb->reduce(ithr, diff_bias, reducer_bia_scratchpad);
    };

    if (dnnl_thr_syncable()) {
        assert(IMPLICATION(pd()->with_bias(),
                rw->balancer().nthr_ == rb->balancer().nthr_));
        parallel(rw->balancer().nthr_, [&](const int ithr, const int nthr) {
            ker(ithr, nthr);
            if (pd()->with_bias()) ker_bias(ithr, nthr);
        });
    } else {
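        // Without a syncable threading runtime the cross-thread reduction
        // cannot run inside the same parallel region as the compute, so
        // compute and reduce are issued as separate parallel regions.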
        parallel(rw->balancer().nthr_,
                [&](int ithr, int nthr) { ker(ithr, nthr); });
        parallel(rw->balancer().nthr_, [&](int ithr, int nthr) {
            assert(nthr == rw->balancer().nthr_);
            MAYBE_UNUSED(nthr);
            if (rw->balancer().ithr_njobs(ithr) == 0) return;
            rw->reduce_nolock(ithr, diff_weights, reducer_wei_scratchpad);
        });
        if (pd()->with_bias()) {
            parallel(rb->balancer().nthr_,
                    [&](int ithr, int nthr) { ker_bias(ithr, nthr); });
            parallel(rb->balancer().nthr_, [&](int ithr, int nthr) {
                assert(nthr == rb->balancer().nthr_);
                MAYBE_UNUSED(nthr);
                if (rb->balancer().ithr_njobs(ithr) == 0) return;
                rb->reduce_nolock(ithr, diff_bias, reducer_bia_scratchpad);
            });
        }
    }

    /* TODO: put this in ker_bias */
    if (is_bias_padded) {
        assert(IMPLICATION(!is_ddst_layout_nxc, jcp.ngroups == 1));
        const int padded_stride = utils::rnd_up(jcp.oc, jcp.oc_block);
        const int stride = jcp.oc_without_padding;
        for (int g = 0; g < jcp.ngroups; ++g) {
            utils::array_copy(diff_bias_in + g * stride,
                    diff_bias + g * padded_stride, stride);
        }
    }
}

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl