1 /*******************************************************************************
2 * Copyright 2019-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include "gpu/ocl/xe_lp_x8s8x_convolution.hpp"
18 
19 #include "common/c_types_map.hpp"
20 #include "common/dnnl_traits.hpp"
21 #include "common/type_helpers.hpp"
22 
23 namespace dnnl {
24 namespace impl {
25 namespace gpu {
26 namespace ocl {
27 
is_nhwc(const memory_desc_wrapper & src_mdw,const memory_desc_wrapper & dst_mdw)28 bool is_nhwc(const memory_desc_wrapper &src_mdw,
29         const memory_desc_wrapper &dst_mdw) {
30     using namespace format_tag;
31     const bool is_src_nhwc
32             = src_mdw.matches_one_of_tag(nwc, nhwc, ndhwc) != format_tag::undef;
33     const bool is_dst_nhwc
34             = dst_mdw.matches_one_of_tag(nwc, nhwc, ndhwc) != format_tag::undef;
35     const bool is_nhwc = is_src_nhwc || is_dst_nhwc;
36     return is_nhwc;
37 }
38 
init_conf()39 status_t xe_lp_x8s8x_convolution_fwd_t::pd_t::init_conf() {
40     using namespace format_tag;
41 
42     const memory_desc_t *src = src_md();
43     const memory_desc_t *dst = dst_md();
44     const memory_desc_t *wei = weights_md();
45     const memory_desc_t *bia = weights_md(1);
46 
47     memory_desc_t r_src, r_wei, r_dst;
48 
49     int ndims = src_md()->ndims;
50 
51     // XXX: try reduce number of spatial dims when iw/ow/kw=1,
52     // memory tags will be selected based on the number of input dimensions
53     bool use_reshaped_mem = ndims > 3;
54     if (dnnl_memory_desc_reshape(&r_src, src, src->ndims - 1, src->dims)
55             != status::success)
56         use_reshaped_mem = false;
57     if (dnnl_memory_desc_reshape(&r_dst, dst, dst->ndims - 1, dst->dims)
58             != status::success)
59         use_reshaped_mem = false;
60     if (dnnl_memory_desc_reshape(&r_wei, wei, wei->ndims - 1, wei->dims)
61             != status::success)
62         use_reshaped_mem = false;
63 
64     if (use_reshaped_mem) {
65         src = &r_src;
66         dst = &r_dst;
67         wei = &r_wei;
68     }
69 
70     const convolution_desc_t &cd = *desc();
71     const memory_desc_wrapper src_mdw(src);
72     const memory_desc_wrapper weights_mdw(wei);
73     const memory_desc_wrapper dst_mdw(dst);
74 
75     set_default_conf(conf, cd, *src, *wei, *dst, *bia, *attr());
76 
77     const bool is_1stconv = conf.ic_without_padding <= 4 && !conf.is_depthwise;
78 
79     conf.is_nhwc = is_nhwc(src_mdw, dst_mdw);
80     conf.is_dst_nhwc
81             = dst_mdw.matches_one_of_tag(nwc, nhwc, ndhwc) != format_tag::undef;
82     // TODO: Add group convolution support in NHWC kernel.
83     if (!conf.is_depthwise && conf.with_groups && conf.ngroups > 1
84             && (conf.oc % 32 != 0 || conf.ic % 32 != 0))
85         return status::unimplemented;
86 
87     conf.dst_data_type = dst_mdw.data_type();
88     conf.src_data_type = src_mdw.data_type();
89 
90     conf.oc_block = 32;
91     conf.ic_block = 32;
92     conf.mb_block = 1;
93     conf.ow_block = 1;
94 
95     if (conf.is_nhwc) {
96         conf.ver = ver_nhwc;
97         if (conf.is_depthwise) {
98             if (!(conf.kw <= 4 && conf.stride_w <= 2 && conf.dilate_w == 0
99                         && conf.l_pad < 4)) {
100                 conf.mb_block = 32;
101             } else {
102                 int off = conf.kw == 4 ? 1 : 0;
103                 if (conf.ow < 15 - off) {
104                     conf.ow_block = conf.ow;
105                 } else {
106                     for (int i = 0; i < 7; ++i) {
107                         conf.ow_block = utils::max_div(conf.ow + i, 14 - off);
108                         if (conf.ow_block > 4) break;
109                     }
110                 }
111             }
112 
113             int ow_nchunk = utils::div_up(conf.ow, conf.ow_block);
114 
115             conf.sub_group_size = 16;
116 
117             conf.lws_d[0] = 16;
118             conf.lws_d[1] = 1;
119             conf.lws_d[2] = 1;
120 
121             conf.gws_d[0] = utils::div_up(conf.ngroups, 32) * conf.lws_d[0];
122             conf.gws_d[1] = conf.od * conf.oh * ow_nchunk;
123             conf.gws_d[2] = utils::div_up(conf.mb,
124                     utils::div_up(conf.mb_block, conf.mb_block == 32 ? 4 : 1));
125         } else {
126             if (!is_1stconv) {
127                 conf.ow_block
128                         = (conf.mb * conf.oc * conf.oh * conf.ow < 49 * 1024)
129                         ? 4
130                         : 8;
131             } else { // 1st conv
132                 conf.ic_block = 4;
133                 conf.ow_block = (conf.kw * conf.kh <= 49 && conf.ow % 16 < 8)
134                         ? 16
135                         : 12;
136                 if (conf.mb == 8 || conf.mb % 16 == 0) { conf.mb_block = 32; }
137             }
138 
139             int max_oc = 4;
140             int oc_group = utils::max_div(
141                     utils::div_up(conf.oc, conf.oc_block), max_oc);
142             int max_subgroups = 32;
143             int max_ow_group = max_subgroups / oc_group;
144             int ow_nchunk = utils::div_up(conf.ow, conf.ow_block);
145             int ow_group = utils::max_div(ow_nchunk, max_ow_group);
146 
147             conf.sub_group_size = 8;
148             conf.nchunk = utils::div_up(conf.oc * conf.ngroups, conf.oc_block);
149             conf.src_slm_size = conf.ic_block / 4
150                     * (ow_group * conf.stride_w * conf.ow_block
151                             + (conf.kw - 1) * (1 + conf.dilate_w));
152 
153             conf.lws_d[0] = 8 * oc_group;
154             conf.lws_d[1] = ow_group;
155             conf.lws_d[2] = 1;
156 
157             conf.gws_d[0] = utils::rnd_up(conf.nchunk * 8, conf.lws_d[0]);
158             conf.gws_d[1]
159                     = conf.od * conf.oh * utils::rnd_up(ow_nchunk, ow_group);
160             conf.gws_d[2] = is_1stconv
161                     ? utils::rnd_up(conf.mb, conf.mb_block)
162                     : utils::div_up(conf.mb,
163                             utils::div_up(conf.mb_block,
164                                     conf.mb_block == 32 ? 2 : 1));
165         }
166 
167     } else if (conf.is_depthwise) {
168         if (conf.mb == 8 || conf.mb % 16 == 0
169                 || !(conf.kw <= 4 && conf.stride_w <= 2 && conf.dilate_w == 0
170                         && conf.l_pad < 4)) {
171             conf.ver = ver_mb_block;
172             conf.mb_block = 32;
173         } else {
174             conf.ver = ver_ow_block;
175             int off = conf.kw == 4 ? 1 : 0;
176             // Try to do not use ow blocks of size > 10 as there is
177             // a lot of GRF memory used what leads to spills
178             if (conf.ow < 10 - off) {
179                 conf.ow_block = conf.ow;
180             } else {
181                 for (int i = 0; i < 7; ++i) {
182                     conf.ow_block = utils::max_div(conf.ow + i, 10 - off);
183                     if (conf.ow_block > 4) break;
184                 }
185             }
186         }
187 
188         conf.sub_group_size = 16;
189         const int spatial_global_size
190                 = conf.od * conf.oh * utils::div_up(conf.ow, conf.ow_block);
191 
192         conf.lws_d[0] = 16;
193         conf.lws_d[1] = 1;
194         if (conf.ver == ver_mb_block) {
195             // Try to increase WG size in order to improve caching
196             for (const int pixels_per_wg : {2, 3, 5}) {
197                 if (spatial_global_size % pixels_per_wg == 0) {
198                     conf.lws_d[1] = pixels_per_wg;
199                     break;
200                 }
201             }
202         }
203         conf.lws_d[2] = 1;
204 
205         conf.gws_d[0] = utils::div_up(conf.ngroups, 32) * conf.lws_d[0];
206         conf.gws_d[1] = spatial_global_size;
207         conf.gws_d[2] = (conf.mb_block == 32 ? 4 : 1)
208                 * utils::div_up(conf.mb, conf.mb_block);
209 
210     } else {
211         if (conf.mb % 16 == 0) {
212             conf.ver = ver_mb_block;
213             conf.mb_block = 32;
214         } else {
215             conf.ver = ver_ow_block;
216         }
217         if (conf.ic <= 4) conf.ver = ver_1stconv;
218 
219         int max_oc = 4;
220         int oc_group
221                 = utils::max_div(utils::div_up(conf.oc, conf.oc_block), max_oc);
222         int max_subgroups = 32;
223         int max_ow_group = max_subgroups / oc_group;
224         int ow_group = 1;
225         int ow_nchunk = 1;
226 
227         conf.sub_group_size = 8;
228         conf.nchunk = utils::div_up(conf.oc * conf.ngroups, conf.oc_block);
229 
230         switch (conf.ver) {
231             case ver_mb_block:
232                 oc_group = 1;
233                 conf.ow_block = 1;
234                 ow_group = 1;
235                 break;
236             case ver_ow_block:
237                 conf.ow_block
238                         = (conf.mb * conf.oc * conf.oh * conf.ow < 49 * 1024)
239                         ? 4
240                         : 8;
241                 ow_nchunk = utils::div_up(conf.ow, conf.ow_block);
242                 ow_group = utils::max_div(ow_nchunk, max_ow_group);
243                 break;
244             case ver_1stconv:
245                 conf.ic_block = 4;
246                 conf.ow_block = (conf.kw * conf.kh <= 49 && conf.ow % 16 < 8)
247                         ? 16
248                         : 12;
249                 ow_nchunk = utils::div_up(conf.ow, conf.ow_block);
250                 ow_group = utils::max_div(ow_nchunk, max_ow_group);
251                 if (ow_group == 1)
252                     ow_group = utils::max_div(ow_nchunk + 1, max_ow_group);
253                 break;
254         }
255 
256         conf.src_slm_size = conf.ic_block / 4
257                 * (ow_group * conf.stride_w * conf.ow_block
258                         + (conf.kw - 1) * (1 + conf.dilate_w));
259 
260         conf.lws_d[0] = 8 * oc_group;
261         conf.lws_d[1] = ow_group;
262         conf.lws_d[2] = 1;
263 
264         conf.src_slm_size = conf.ic_block / 4
265                 * (conf.lws_d[1] * conf.stride_w * conf.ow_block
266                         + (conf.kw - 1) * (1 + conf.dilate_w) + conf.l_pad);
267 
268         conf.gws_d[0] = utils::rnd_up(conf.nchunk * 8, conf.lws_d[0]);
269         conf.gws_d[1] = conf.od * conf.oh
270                 * utils::rnd_up(
271                         utils::div_up(conf.ow, conf.ow_block), ow_group);
272         conf.gws_d[2] = (conf.mb_block == 32 ? 2 : 1)
273                 * utils::div_up(conf.mb, conf.mb_block);
274 
275         if (conf.ver == ver_1stconv) {
276             conf.gws_d[2] = utils::rnd_up(conf.mb, conf.mb_block);
277             // Save opportunity to use this implementation with nchw formats,
278             // which will result in worse performance, but prevent us using reorder.
279             // That can be efficient in some cases.
280             conf.is_nchw = src_mdw.matches_one_of_tag(ncw, nchw, ncdhw)
281                     || src_mdw.format_kind() == format_kind::any;
282             // decrease src ic_block in case of input nchw
283             if (conf.is_nchw) conf.ic_block = 1;
284         }
285     }
286 
287     // TODO: add support for nhwc and dw ow_block
288     const bool has_compensation = conf.attr_info.with_src_zpoints
289             || conf.attr_info.with_dst_zpoints;
290     if (has_compensation)
291         if (conf.is_nhwc || (conf.is_depthwise && conf.mb_block != 32))
292             return status::unimplemented;
293 
294     conf.with_bias = cd.bias_desc.format_kind != format_kind::undef;
295 
296     format_tag_t src_tag, dst_tag, wei_tag;
297 
298     if (conf.is_nhwc) {
299         src_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
300         dst_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
301 
302         if (is_1stconv) {
303             wei_tag = conf.with_groups
304                     ? utils::pick(ndims - 3, gOIw8o4i, gOIhw8o4i, gOIdhw8o4i)
305                     : utils::pick(ndims - 3, OIw8o4i, OIhw8o4i, OIdhw8o4i);
306             if (!conf.is_dst_nhwc) {
307                 if (conf.mb_block == 32) {
308                     dst_tag = utils::pick(
309                             ndims - 3, NCw32n32c, NChw32n32c, NCdhw32n32c);
310                 } else {
311                     dst_tag = utils::pick(ndims - 3, nCw32c, nChw32c, nCdhw32c);
312                 }
313             }
314         } else if (conf.is_depthwise) {
315             wei_tag = utils::pick(ndims - 3, Goiw32g, Goihw32g, Goidhw32g);
316         } else {
317             wei_tag = conf.with_groups ? utils::pick(ndims - 3, gOIw4o8i8o4i,
318                               gOIhw4o8i8o4i, gOIdhw4o8i8o4i)
319                                        : utils::pick(ndims - 3, OIw4o8i8o4i,
320                                                OIhw4o8i8o4i, OIdhw4o8i8o4i);
321         }
322 
323     } else {
324         if (conf.mb_block == 32) {
325             src_tag = utils::pick(
326                     ndims - 3, NCw32n32c, NChw32n32c, NCdhw32n32c);
327             dst_tag = utils::pick(
328                     ndims - 3, NCw32n32c, NChw32n32c, NCdhw32n32c);
329         } else {
330             src_tag = utils::pick(ndims - 3, nCw32c, nChw32c, nCdhw32c);
331             dst_tag = utils::pick(ndims - 3, nCw32c, nChw32c, nCdhw32c);
332         }
333 
334         if (!conf.is_depthwise && conf.ver == ver_1stconv) {
335             src_tag = (conf.is_nchw)
336                     ? utils::pick(ndims - 3, ncw, nchw, ncdhw)
337                     : utils::pick(ndims - 3, nCw4c, nChw4c, nCdhw4c);
338         }
339 
340         if (conf.is_depthwise) {
341             wei_tag = utils::pick(ndims - 3, Goiw32g, Goihw32g, Goidhw32g);
342         } else {
343             if (conf.ver == ver_1stconv) {
344                 wei_tag = conf.with_groups
345                         ? utils::pick(
346                                 ndims - 3, gOIw8o4i, gOIhw8o4i, gOIdhw8o4i)
347                         : utils::pick(ndims - 3, OIw8o4i, OIhw8o4i, OIdhw8o4i);
348             } else {
349                 wei_tag = conf.with_groups ? utils::pick(ndims - 3,
350                                   gOIw4o8i8o4i, gOIhw4o8i8o4i, gOIdhw4o8i8o4i)
351                                            : utils::pick(ndims - 3, OIw4o8i8o4i,
352                                                    OIhw4o8i8o4i, OIdhw4o8i8o4i);
353             }
354         }
355     }
356 
357     conf.src_tag = src_mdw.format_kind() == format_kind::any
358             ? src_tag
359             : src_mdw.matches_one_of_tag(src_tag);
360     conf.wei_tag = weights_mdw.format_kind() == format_kind::any
361             ? wei_tag
362             : weights_mdw.matches_one_of_tag(wei_tag);
363     conf.dst_tag = dst_mdw.format_kind() == format_kind::any
364             ? dst_tag
365             : dst_mdw.matches_one_of_tag(dst_tag);
366 
367     if (conf.src_tag != src_tag || conf.wei_tag != wei_tag
368             || conf.dst_tag != dst_tag)
369         return status::unimplemented;
370 
371     return status::success;
372 }
373 
init_kernel_ctx(compute::kernel_ctx_t & kernel_ctx) const374 status_t xe_lp_x8s8x_convolution_fwd_t::pd_t::init_kernel_ctx(
375         compute::kernel_ctx_t &kernel_ctx) const {
376     int owx = nstl::max(
377             1, utils::div_up(conf.iw + 2 * conf.l_pad, conf.stride_w));
378     int ow_block_with_stride = conf.stride_w * conf.ow_block;
379     int iw_with_l_pad = conf.iw + conf.l_pad;
380     int iw_len = iw_with_l_pad < ow_block_with_stride + conf.kw - 1
381             ? iw_with_l_pad - ow_block_with_stride
382             : iw_with_l_pad % ow_block_with_stride;
383     int iw_tail
384             = iw_len < (conf.kw - 1) ? ow_block_with_stride + iw_len : iw_len;
385     int ow_tail = conf.ow % conf.ow_block;
386     int iw_nchunk = utils::div_up(conf.iw, ow_block_with_stride);
387     int ow_nchunk = utils::div_up(conf.ow, conf.ow_block);
388     int min_w_nchunk = nstl::min(ow_nchunk, iw_nchunk);
389     int slm_tail
390             = conf.iw - (conf.stride_w * conf.ow_block * (min_w_nchunk - 1));
391     int zero_tail = utils::rnd_up(conf.ow, conf.ow_block) * conf.stride_w
392             - conf.iw + (conf.kw - 1) * (1 + conf.dilate_w) - conf.l_pad;
393 
394     kernel_ctx.define_int("NCHW", conf.is_nchw);
395     kernel_ctx.define_int("DST_NHWC", conf.is_dst_nhwc);
396     kernel_ctx.define_int("G", conf.ngroups);
397     kernel_ctx.define_int("MB", conf.mb);
398     kernel_ctx.define_int("IC", conf.ic);
399     kernel_ctx.define_int("ID", conf.id);
400     kernel_ctx.define_int("IH", conf.ih);
401     kernel_ctx.define_int("IW", conf.iw);
402     kernel_ctx.define_int("OC", conf.oc);
403     kernel_ctx.define_int("OD", conf.od);
404     kernel_ctx.define_int("OH", conf.oh);
405     kernel_ctx.define_int("OW", conf.ow);
406     kernel_ctx.define_int("KD", conf.kd);
407     kernel_ctx.define_int("KH", conf.kh);
408     kernel_ctx.define_int("KW", conf.kw);
409     kernel_ctx.define_int("SD", conf.stride_d);
410     kernel_ctx.define_int("SH", conf.stride_h);
411     kernel_ctx.define_int("SW", conf.stride_w);
412     kernel_ctx.define_int("PD", conf.f_pad);
413     kernel_ctx.define_int("PH", conf.t_pad);
414     kernel_ctx.define_int("PW", conf.l_pad);
415     kernel_ctx.define_int("DD", conf.dilate_d);
416     kernel_ctx.define_int("DH", conf.dilate_h);
417     kernel_ctx.define_int("DW", conf.dilate_w);
418 
419     kernel_ctx.define_int("OW_PADDED",
420             utils::rnd_up(
421                     utils::div_up(conf.ow, conf.ow_block), conf.lws_d[1]));
422     int ow = nstl::max(
423             1, utils::div_up(conf.iw + 2 * conf.l_pad, conf.stride_w));
424     kernel_ctx.define_int("OWX", ow);
425     kernel_ctx.define_int("OWB", utils::div_up(conf.ow, conf.ow_block));
426 
427     kernel_ctx.define_int("MB_BLOCK", conf.mb_block);
428     kernel_ctx.define_int("OC_BLOCK", conf.oc_block);
429     kernel_ctx.define_int("IC_BLOCK", conf.ic_block);
430     kernel_ctx.define_int("OW_BLOCK", conf.ow_block);
431     kernel_ctx.define_int("SRC_SLM_SIZE", conf.src_slm_size);
432     kernel_ctx.define_int("SUB_GROUP_SIZE", conf.sub_group_size);
433     kernel_ctx.define_int("WITH_BIAS", conf.with_bias);
434     kernel_ctx.define_int("LWS_0", conf.lws_d[0]);
435     kernel_ctx.define_int("LWS_1", conf.lws_d[1]);
436     kernel_ctx.define_int("LWS_2", conf.lws_d[2]);
437 
438     kernel_ctx.define_int("OW_TAIL", ow_tail);
439     kernel_ctx.define_int("IW_TAIL", iw_tail);
440     kernel_ctx.define_int("SLM_TAIL", slm_tail);
441     kernel_ctx.define_int("ZERO_TAIL", zero_tail);
442 
443     kernel_ctx.define_int("OW_PADDED", utils::rnd_up(ow_nchunk, conf.lws_d[1]));
444     kernel_ctx.define_int("G_PADDED",
445             utils::div_up(conf.ngroups, conf.oc_block) * conf.oc_block);
446 
447     kernel_ctx.define_int("MB_GROUP", 1);
448     kernel_ctx.define_int("SP_GROUP", conf.lws_d[1]);
449     kernel_ctx.define_int("OC_GROUP", utils::div_up(conf.lws_d[0], 8));
450 
451     kernel_ctx.define_int("OC_NCHUNK", utils::div_up(conf.oc, conf.oc_block));
452     kernel_ctx.define_int("IC_NCHUNK", utils::div_up(conf.ic, conf.ic_block));
453     kernel_ctx.define_int("OW_NCHUNK", ow_nchunk);
454     kernel_ctx.define_int("SLM_NCHUNK", min_w_nchunk);
455     kernel_ctx.define_int("OWB", ow_nchunk);
456     kernel_ctx.define_int("OWX", owx);
457 
458     kernel_ctx.define_int("DISABLE_DPAS", disable_dpas);
459 
460     if (conf.is_depthwise)
461         kernel_ctx.define_int("WEI_32G", 1);
462     else
463         kernel_ctx.define_int("WEI_4O8I8O4I", 1);
464 
465     kernel_ctx.set_data_type(conf.dst_data_type);
466 
467     def_data_type(kernel_ctx, conf.src_data_type, "SRC");
468     def_data_type(kernel_ctx, conf.dst_data_type, "DST");
469     def_data_type(kernel_ctx,
470             conf.attr_info.sum_data_type == dnnl_data_type_undef
471                     ? conf.dst_data_type
472                     : conf.attr_info.sum_data_type,
473             "SUM");
474 
475     def_attr_info(kernel_ctx, conf.attr_info, attr()->post_ops_);
476 
477     kernel_ctx.add_option("-Dcl_intel_subgroups_char");
478     kernel_ctx.add_option("-Dcl_intel_subgroups_long");
479 
480     return status::success;
481 }
482 
init_scratchpad()483 void xe_lp_x8s8x_convolution_fwd_t::pd_t::init_scratchpad() {
484     if (conf.attr_info.with_src_zpoints) {
485         size_t size = conf.is_depthwise
486                 ? utils::rnd_up(conf.ngroups, 32)
487                 : conf.ngroups * utils::rnd_up(conf.oc, 32);
488 
489         auto scratchpad = scratchpad_registry().registrar();
490         scratchpad.book(memory_tracking::names::key_conv_wei_reduction, size,
491                 types::data_type_size(data_type::s32), OCL_BUFFER_ALIGNMENT);
492     }
493 }
494 
execute_forward(const exec_ctx_t & ctx) const495 status_t xe_lp_x8s8x_convolution_fwd_t::execute_forward(
496         const exec_ctx_t &ctx) const {
497     auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
498     auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS);
499     auto &bias = CTX_IN_STORAGE(DNNL_ARG_BIAS);
500     auto &oscales = CTX_IN_STORAGE(DNNL_ARG_ATTR_OUTPUT_SCALES);
501     auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST);
502     auto &src_zpoints
503             = CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC);
504     auto &dst_zpoints
505             = CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST);
506 
507     const auto &conf = pd()->conf;
508 
509     // XXX: first convolution calculates compensation in-place
510     const bool precompute_compensation = conf.is_depthwise || conf.ic > 4;
511 
512     std::unique_ptr<memory_storage_t> temp_src_compensation;
513     if (conf.attr_info.with_src_zpoints && precompute_compensation) {
514         temp_src_compensation = ctx.get_scratchpad_grantor().get_memory_storage(
515                 memory_tracking::names::key_conv_wei_reduction);
516 
517         compute::kernel_arg_list_t arg_list;
518         arg_list.set(0, src_zpoints);
519         arg_list.set(1, weights);
520         arg_list.set(2, *temp_src_compensation);
521 
522         auto nd_range = conf.is_depthwise
523                 ? compute::nd_range_t(
524                         {16, utils::div_up(conf.ngroups, 32), 1}, {16, 1, 1})
525                 : compute::nd_range_t(
526                         {8, utils::div_up(conf.oc, 32), conf.ngroups},
527                         {8, 1, 1});
528         status_t status = parallel_for(
529                 ctx, nd_range, src_compensation_kernel_, arg_list);
530         if (status != status::success) return status::runtime_error;
531     }
532 
533     compute::kernel_arg_list_t arg_list;
534     arg_list.set(0, src);
535     arg_list.set(1, weights);
536     arg_list.set(2, bias);
537     arg_list.set(3, dst);
538 
539     unsigned arg_idx = append_post_ops_to_arg_list(
540             ctx, arg_list, 4, pd()->attr()->post_ops_);
541 
542     if (conf.attr_info.common_oscales) {
543         float scales = pd()->attr()->output_scales_.scales_[0];
544         arg_list.set(arg_idx++, scales);
545     } else {
546         arg_list.set(arg_idx++, 1.0f);
547     }
548 
549     if (conf.attr_info.with_per_oc_oscales) {
550         if (conf.attr_info.with_runtime_oscales)
551             arg_list.set(arg_idx++, oscales);
552         else
553             arg_list.set(arg_idx++, CTX_GPU_RES_STORAGE(SCALES_));
554     } else {
555         arg_list.set(arg_idx++, memory_storage_t::empty_storage());
556     }
557 
558     if (conf.attr_info.with_src_zpoints) {
559         if (precompute_compensation)
560             arg_list.set(arg_idx++, *temp_src_compensation);
561         else
562             arg_list.set(arg_idx++, memory_storage_t::empty_storage());
563         arg_list.set(arg_idx++, src_zpoints);
564     } else {
565         arg_list.set(arg_idx++, memory_storage_t::empty_storage());
566         arg_list.set(arg_idx++, memory_storage_t::empty_storage());
567     }
568 
569     if (conf.attr_info.with_dst_zpoints)
570         arg_list.set(arg_idx++, dst_zpoints);
571     else
572         arg_list.set(arg_idx++, memory_storage_t::empty_storage());
573 
574     auto nd_range = compute::nd_range_t(conf.gws_d, conf.lws_d);
575     status_t status = parallel_for(ctx, nd_range, kernel_, arg_list);
576 
577     if (!post_ops_preserves_zeroes(ctx, pd()->attr()->post_ops_)) {
578         ctx.zero_pad_output(DNNL_ARG_DST);
579     }
580     return status;
581 }
582 
init_conf()583 status_t xe_lp_x8s8x_convolution_bwd_data_t::pd_t::init_conf() {
584     using namespace format_tag;
585 
586     const convolution_desc_t &cd = *desc();
587     const memory_desc_wrapper src_mdw(diff_src_md());
588     const memory_desc_wrapper weights_mdw(weights_md());
589     const memory_desc_wrapper dst_mdw(diff_dst_md());
590     const memory_desc_wrapper bias_mdw(weights_md(1));
591 
592     set_default_conf(conf, cd, *diff_src_md(), *weights_md(), *diff_dst_md(),
593             *weights_md(1), *attr());
594 
595     conf.is_nhwc = is_nhwc(src_mdw, dst_mdw);
596 
597     if (conf.with_groups && conf.ngroups > 1
598             && (conf.oc % 32 != 0 || conf.ic % 32 != 0))
599         return status::unimplemented;
600 
601     if (!conf.is_nhwc) {
602         if (conf.mb % 16 == 0) {
603             conf.ver = ver_mb_block;
604         } else {
605             conf.ver = ver_ow_block;
606         }
607     }
608 
609     conf.oc_block = 32;
610     conf.ic_block = 32;
611     conf.iw_block = 1;
612 
613     conf.sub_group_size = 8;
614     conf.nchunk = utils::div_up(conf.ic * conf.ngroups, conf.ic_block);
615     int ic_group = nstl::min(conf.nchunk, 2);
616 
617     if (conf.ver == ver_ow_block || conf.is_nhwc) {
618         conf.mb_block = 1;
619         int max_ic = 4;
620         ic_group
621                 = utils::max_div(utils::div_up(conf.ic, conf.ic_block), max_ic);
622         int max_subgroups = 32;
623         int max_iw_group = max_subgroups / ic_group;
624         conf.iw_block
625                 = (conf.mb * conf.ic * conf.ih * conf.iw < 49 * 1024) ? 4 : 8;
626         int iw_nchunk = utils::div_up(conf.iw, conf.iw_block);
627         int iw_group = utils::max_div(iw_nchunk, max_iw_group);
628 
629         //an upper bound on the number of elems per subgroup
630         conf.dst_slm_size = (conf.oc_block / 4)
631                 * ((iw_group * conf.iw_block)
632                         + (conf.kw - 1) * (1 + conf.dilate_w));
633         conf.iw_tail = conf.iw % conf.iw_block;
634 
635         conf.lws_d[0] = 8 * ic_group;
636         conf.lws_d[1] = iw_group;
637         conf.lws_d[2] = 1;
638 
639         conf.gws_d[0] = utils::rnd_up(conf.nchunk * 8, conf.lws_d[0]);
640         conf.gws_d[1] = conf.id * conf.ih * iw_nchunk;
641         conf.gws_d[2] = utils::div_up(conf.mb, utils::div_up(conf.mb_block, 2));
642     } else { //ver_mb_block
643         conf.mb_block = 32;
644         conf.lws_d[0] = 8 * ic_group;
645         conf.lws_d[1] = 8;
646         conf.lws_d[2] = 1;
647 
648         conf.gws_d[0] = utils::rnd_up(conf.nchunk * 8, conf.lws_d[0]);
649         conf.gws_d[1]
650                 = conf.id * conf.ih * utils::rnd_up(conf.iw, conf.lws_d[1]);
651         conf.gws_d[2] = 2 * utils::div_up(conf.mb, conf.mb_block);
652     }
653     conf.with_bias = cd.bias_desc.format_kind != format_kind::undef;
654 
655     format_tag_t src_tag, dst_tag, wei_tag;
656 
657     if (conf.is_nhwc) {
658         src_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
659         dst_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
660     } else {
661         src_tag = (conf.ver == ver_ow_block)
662                 ? utils::pick(conf.ndims - 3, nCw32c, nChw32c, nCdhw32c)
663                 : utils::pick(
664                         conf.ndims - 3, NCw32n32c, NChw32n32c, NCdhw32n32c);
665         dst_tag = (conf.ver == ver_ow_block)
666                 ? utils::pick(conf.ndims - 3, nCw32c, nChw32c, nCdhw32c)
667                 : utils::pick(
668                         conf.ndims - 3, NCw32n32c, NChw32n32c, NCdhw32n32c);
669     }
670 
671     wei_tag = conf.with_groups ? utils::pick(conf.ndims - 3, gIOw4i8o8i4o,
672                       gIOhw4i8o8i4o, gIOdhw4i8o8i4o)
673                                : utils::pick(conf.ndims - 3, IOw4i8o8i4o,
674                                        IOhw4i8o8i4o, IOdhw4i8o8i4o);
675 
676     conf.dst_data_type = dst_mdw.data_type();
677     conf.src_data_type = src_mdw.data_type();
678 
679     conf.src_tag = src_mdw.format_kind() == format_kind::any
680             ? src_tag
681             : src_mdw.matches_one_of_tag(src_tag);
682     conf.wei_tag = weights_mdw.format_kind() == format_kind::any
683             ? wei_tag
684             : weights_mdw.matches_one_of_tag(wei_tag);
685     conf.dst_tag = dst_mdw.format_kind() == format_kind::any
686             ? dst_tag
687             : dst_mdw.matches_one_of_tag(dst_tag);
688 
689     if (conf.src_tag != src_tag || conf.wei_tag != wei_tag
690             || conf.dst_tag != dst_tag)
691         return status::unimplemented;
692 
693     return status::success;
694 }
695 
init_kernel_ctx(compute::kernel_ctx_t & kernel_ctx) const696 status_t xe_lp_x8s8x_convolution_bwd_data_t::pd_t::init_kernel_ctx(
697         compute::kernel_ctx_t &kernel_ctx) const {
698     kernel_ctx.define_int("G", conf.ngroups);
699     kernel_ctx.define_int("MB", conf.mb);
700     kernel_ctx.define_int("IC", conf.ic);
701     kernel_ctx.define_int("ID", conf.id);
702     kernel_ctx.define_int("IH", conf.ih);
703     kernel_ctx.define_int("IW", conf.iw);
704     kernel_ctx.define_int("OC", conf.oc);
705     kernel_ctx.define_int("OD", conf.od);
706     kernel_ctx.define_int("OH", conf.oh);
707     kernel_ctx.define_int("OW", conf.ow);
708     kernel_ctx.define_int("KD", conf.kd);
709     kernel_ctx.define_int("KH", conf.kh);
710     kernel_ctx.define_int("KW", conf.kw);
711     kernel_ctx.define_int("SD", conf.stride_d);
712     kernel_ctx.define_int("SH", conf.stride_h);
713     kernel_ctx.define_int("SW", conf.stride_w);
714     kernel_ctx.define_int("PD", conf.f_pad);
715     kernel_ctx.define_int("PH", conf.t_pad);
716     kernel_ctx.define_int("PW", conf.l_pad);
717     kernel_ctx.define_int("DD", conf.dilate_d);
718     kernel_ctx.define_int("DH", conf.dilate_h);
719     kernel_ctx.define_int("DW", conf.dilate_w);
720 
721     kernel_ctx.define_int("IW_PADDED", utils::rnd_up(conf.iw, conf.lws_d[1]));
722     kernel_ctx.define_int("IW_TAIL", conf.iw_tail);
723 
724     kernel_ctx.define_int("MB_BLOCK", conf.mb_block);
725     kernel_ctx.define_int("OC_BLOCK", conf.oc_block);
726     kernel_ctx.define_int("IC_BLOCK", conf.ic_block);
727     kernel_ctx.define_int("IW_BLOCK", conf.iw_block);
728 
729     kernel_ctx.define_int("MB_GROUP", 1);
730     kernel_ctx.define_int("IC_GROUP", utils::div_up(conf.lws_d[0], 8));
731     kernel_ctx.define_int("SP_GROUP", conf.lws_d[1]);
732 
733     kernel_ctx.define_int("IW_NCHUNK", utils::div_up(conf.iw, conf.iw_block));
734     kernel_ctx.define_int("OC_NCHUNK", utils::div_up(conf.oc, conf.oc_block));
735     kernel_ctx.define_int("IC_NCHUNK", utils::div_up(conf.ic, conf.ic_block));
736 
737     kernel_ctx.define_int("DST_SLM_SIZE", conf.dst_slm_size);
738     kernel_ctx.define_int("SUB_GROUP_SIZE", conf.sub_group_size);
739 
740     kernel_ctx.define_int("WITH_BIAS", conf.with_bias);
741 
742     kernel_ctx.define_int("LWS_0", conf.lws_d[0]);
743     kernel_ctx.define_int("LWS_1", conf.lws_d[1]);
744     kernel_ctx.define_int("LWS_2", conf.lws_d[2]);
745 
746     kernel_ctx.define_int("IS_NHWC", conf.is_nhwc);
747 
748     kernel_ctx.define_int("DISABLE_DPAS", disable_dpas);
749 
750     kernel_ctx.set_data_type(conf.dst_data_type);
751     def_data_type(kernel_ctx, conf.src_data_type, "SRC");
752     def_data_type(kernel_ctx, conf.dst_data_type, "DST");
753     kernel_ctx.add_option("-Dcl_intel_subgroups_char");
754 
755     return status::success;
756 }
757 
execute_backward_data(const exec_ctx_t & ctx) const758 status_t xe_lp_x8s8x_convolution_bwd_data_t::execute_backward_data(
759         const exec_ctx_t &ctx) const {
760 
761     auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
762     auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS);
763     auto &bias = CTX_IN_STORAGE(DNNL_ARG_BIAS);
764     auto &diff_src = CTX_OUT_STORAGE(DNNL_ARG_DIFF_SRC);
765 
766     const auto &conf = pd()->conf;
767 
768     compute::kernel_arg_list_t arg_list;
769     arg_list.set(0, diff_src);
770     arg_list.set(1, weights);
771     arg_list.set(2, bias);
772     arg_list.set(3, diff_dst);
773 
774     auto nd_range = compute::nd_range_t(conf.gws_d, conf.lws_d);
775     status_t status = parallel_for(ctx, nd_range, kernel_, arg_list);
776 
777     return status;
778 }
779 
780 } // namespace ocl
781 } // namespace gpu
782 } // namespace impl
783 } // namespace dnnl
784