1 /*******************************************************************************
2 * Copyright 2019-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "gpu/ocl/xe_lp_x8s8x_convolution.hpp"
18
19 #include "common/c_types_map.hpp"
20 #include "common/dnnl_traits.hpp"
21 #include "common/type_helpers.hpp"
22
23 namespace dnnl {
24 namespace impl {
25 namespace gpu {
26 namespace ocl {
27
is_nhwc(const memory_desc_wrapper & src_mdw,const memory_desc_wrapper & dst_mdw)28 bool is_nhwc(const memory_desc_wrapper &src_mdw,
29 const memory_desc_wrapper &dst_mdw) {
30 using namespace format_tag;
31 const bool is_src_nhwc
32 = src_mdw.matches_one_of_tag(nwc, nhwc, ndhwc) != format_tag::undef;
33 const bool is_dst_nhwc
34 = dst_mdw.matches_one_of_tag(nwc, nhwc, ndhwc) != format_tag::undef;
35 const bool is_nhwc = is_src_nhwc || is_dst_nhwc;
36 return is_nhwc;
37 }
38
init_conf()39 status_t xe_lp_x8s8x_convolution_fwd_t::pd_t::init_conf() {
40 using namespace format_tag;
41
42 const memory_desc_t *src = src_md();
43 const memory_desc_t *dst = dst_md();
44 const memory_desc_t *wei = weights_md();
45 const memory_desc_t *bia = weights_md(1);
46
47 memory_desc_t r_src, r_wei, r_dst;
48
49 int ndims = src_md()->ndims;
50
51 // XXX: try reduce number of spatial dims when iw/ow/kw=1,
52 // memory tags will be selected based on the number of input dimensions
53 bool use_reshaped_mem = ndims > 3;
54 if (dnnl_memory_desc_reshape(&r_src, src, src->ndims - 1, src->dims)
55 != status::success)
56 use_reshaped_mem = false;
57 if (dnnl_memory_desc_reshape(&r_dst, dst, dst->ndims - 1, dst->dims)
58 != status::success)
59 use_reshaped_mem = false;
60 if (dnnl_memory_desc_reshape(&r_wei, wei, wei->ndims - 1, wei->dims)
61 != status::success)
62 use_reshaped_mem = false;
63
64 if (use_reshaped_mem) {
65 src = &r_src;
66 dst = &r_dst;
67 wei = &r_wei;
68 }
69
70 const convolution_desc_t &cd = *desc();
71 const memory_desc_wrapper src_mdw(src);
72 const memory_desc_wrapper weights_mdw(wei);
73 const memory_desc_wrapper dst_mdw(dst);
74
75 set_default_conf(conf, cd, *src, *wei, *dst, *bia, *attr());
76
77 const bool is_1stconv = conf.ic_without_padding <= 4 && !conf.is_depthwise;
78
79 conf.is_nhwc = is_nhwc(src_mdw, dst_mdw);
80 conf.is_dst_nhwc
81 = dst_mdw.matches_one_of_tag(nwc, nhwc, ndhwc) != format_tag::undef;
82 // TODO: Add group convolution support in NHWC kernel.
83 if (!conf.is_depthwise && conf.with_groups && conf.ngroups > 1
84 && (conf.oc % 32 != 0 || conf.ic % 32 != 0))
85 return status::unimplemented;
86
87 conf.dst_data_type = dst_mdw.data_type();
88 conf.src_data_type = src_mdw.data_type();
89
90 conf.oc_block = 32;
91 conf.ic_block = 32;
92 conf.mb_block = 1;
93 conf.ow_block = 1;
94
95 if (conf.is_nhwc) {
96 conf.ver = ver_nhwc;
97 if (conf.is_depthwise) {
98 if (!(conf.kw <= 4 && conf.stride_w <= 2 && conf.dilate_w == 0
99 && conf.l_pad < 4)) {
100 conf.mb_block = 32;
101 } else {
102 int off = conf.kw == 4 ? 1 : 0;
103 if (conf.ow < 15 - off) {
104 conf.ow_block = conf.ow;
105 } else {
106 for (int i = 0; i < 7; ++i) {
107 conf.ow_block = utils::max_div(conf.ow + i, 14 - off);
108 if (conf.ow_block > 4) break;
109 }
110 }
111 }
112
113 int ow_nchunk = utils::div_up(conf.ow, conf.ow_block);
114
115 conf.sub_group_size = 16;
116
117 conf.lws_d[0] = 16;
118 conf.lws_d[1] = 1;
119 conf.lws_d[2] = 1;
120
121 conf.gws_d[0] = utils::div_up(conf.ngroups, 32) * conf.lws_d[0];
122 conf.gws_d[1] = conf.od * conf.oh * ow_nchunk;
123 conf.gws_d[2] = utils::div_up(conf.mb,
124 utils::div_up(conf.mb_block, conf.mb_block == 32 ? 4 : 1));
125 } else {
126 if (!is_1stconv) {
127 conf.ow_block
128 = (conf.mb * conf.oc * conf.oh * conf.ow < 49 * 1024)
129 ? 4
130 : 8;
131 } else { // 1st conv
132 conf.ic_block = 4;
133 conf.ow_block = (conf.kw * conf.kh <= 49 && conf.ow % 16 < 8)
134 ? 16
135 : 12;
136 if (conf.mb == 8 || conf.mb % 16 == 0) { conf.mb_block = 32; }
137 }
138
139 int max_oc = 4;
140 int oc_group = utils::max_div(
141 utils::div_up(conf.oc, conf.oc_block), max_oc);
142 int max_subgroups = 32;
143 int max_ow_group = max_subgroups / oc_group;
144 int ow_nchunk = utils::div_up(conf.ow, conf.ow_block);
145 int ow_group = utils::max_div(ow_nchunk, max_ow_group);
146
147 conf.sub_group_size = 8;
148 conf.nchunk = utils::div_up(conf.oc * conf.ngroups, conf.oc_block);
149 conf.src_slm_size = conf.ic_block / 4
150 * (ow_group * conf.stride_w * conf.ow_block
151 + (conf.kw - 1) * (1 + conf.dilate_w));
152
153 conf.lws_d[0] = 8 * oc_group;
154 conf.lws_d[1] = ow_group;
155 conf.lws_d[2] = 1;
156
157 conf.gws_d[0] = utils::rnd_up(conf.nchunk * 8, conf.lws_d[0]);
158 conf.gws_d[1]
159 = conf.od * conf.oh * utils::rnd_up(ow_nchunk, ow_group);
160 conf.gws_d[2] = is_1stconv
161 ? utils::rnd_up(conf.mb, conf.mb_block)
162 : utils::div_up(conf.mb,
163 utils::div_up(conf.mb_block,
164 conf.mb_block == 32 ? 2 : 1));
165 }
166
167 } else if (conf.is_depthwise) {
168 if (conf.mb == 8 || conf.mb % 16 == 0
169 || !(conf.kw <= 4 && conf.stride_w <= 2 && conf.dilate_w == 0
170 && conf.l_pad < 4)) {
171 conf.ver = ver_mb_block;
172 conf.mb_block = 32;
173 } else {
174 conf.ver = ver_ow_block;
175 int off = conf.kw == 4 ? 1 : 0;
176 // Try to do not use ow blocks of size > 10 as there is
177 // a lot of GRF memory used what leads to spills
178 if (conf.ow < 10 - off) {
179 conf.ow_block = conf.ow;
180 } else {
181 for (int i = 0; i < 7; ++i) {
182 conf.ow_block = utils::max_div(conf.ow + i, 10 - off);
183 if (conf.ow_block > 4) break;
184 }
185 }
186 }
187
188 conf.sub_group_size = 16;
189 const int spatial_global_size
190 = conf.od * conf.oh * utils::div_up(conf.ow, conf.ow_block);
191
192 conf.lws_d[0] = 16;
193 conf.lws_d[1] = 1;
194 if (conf.ver == ver_mb_block) {
195 // Try to increase WG size in order to improve caching
196 for (const int pixels_per_wg : {2, 3, 5}) {
197 if (spatial_global_size % pixels_per_wg == 0) {
198 conf.lws_d[1] = pixels_per_wg;
199 break;
200 }
201 }
202 }
203 conf.lws_d[2] = 1;
204
205 conf.gws_d[0] = utils::div_up(conf.ngroups, 32) * conf.lws_d[0];
206 conf.gws_d[1] = spatial_global_size;
207 conf.gws_d[2] = (conf.mb_block == 32 ? 4 : 1)
208 * utils::div_up(conf.mb, conf.mb_block);
209
210 } else {
211 if (conf.mb % 16 == 0) {
212 conf.ver = ver_mb_block;
213 conf.mb_block = 32;
214 } else {
215 conf.ver = ver_ow_block;
216 }
217 if (conf.ic <= 4) conf.ver = ver_1stconv;
218
219 int max_oc = 4;
220 int oc_group
221 = utils::max_div(utils::div_up(conf.oc, conf.oc_block), max_oc);
222 int max_subgroups = 32;
223 int max_ow_group = max_subgroups / oc_group;
224 int ow_group = 1;
225 int ow_nchunk = 1;
226
227 conf.sub_group_size = 8;
228 conf.nchunk = utils::div_up(conf.oc * conf.ngroups, conf.oc_block);
229
230 switch (conf.ver) {
231 case ver_mb_block:
232 oc_group = 1;
233 conf.ow_block = 1;
234 ow_group = 1;
235 break;
236 case ver_ow_block:
237 conf.ow_block
238 = (conf.mb * conf.oc * conf.oh * conf.ow < 49 * 1024)
239 ? 4
240 : 8;
241 ow_nchunk = utils::div_up(conf.ow, conf.ow_block);
242 ow_group = utils::max_div(ow_nchunk, max_ow_group);
243 break;
244 case ver_1stconv:
245 conf.ic_block = 4;
246 conf.ow_block = (conf.kw * conf.kh <= 49 && conf.ow % 16 < 8)
247 ? 16
248 : 12;
249 ow_nchunk = utils::div_up(conf.ow, conf.ow_block);
250 ow_group = utils::max_div(ow_nchunk, max_ow_group);
251 if (ow_group == 1)
252 ow_group = utils::max_div(ow_nchunk + 1, max_ow_group);
253 break;
254 }
255
256 conf.src_slm_size = conf.ic_block / 4
257 * (ow_group * conf.stride_w * conf.ow_block
258 + (conf.kw - 1) * (1 + conf.dilate_w));
259
260 conf.lws_d[0] = 8 * oc_group;
261 conf.lws_d[1] = ow_group;
262 conf.lws_d[2] = 1;
263
264 conf.src_slm_size = conf.ic_block / 4
265 * (conf.lws_d[1] * conf.stride_w * conf.ow_block
266 + (conf.kw - 1) * (1 + conf.dilate_w) + conf.l_pad);
267
268 conf.gws_d[0] = utils::rnd_up(conf.nchunk * 8, conf.lws_d[0]);
269 conf.gws_d[1] = conf.od * conf.oh
270 * utils::rnd_up(
271 utils::div_up(conf.ow, conf.ow_block), ow_group);
272 conf.gws_d[2] = (conf.mb_block == 32 ? 2 : 1)
273 * utils::div_up(conf.mb, conf.mb_block);
274
275 if (conf.ver == ver_1stconv) {
276 conf.gws_d[2] = utils::rnd_up(conf.mb, conf.mb_block);
277 // Save opportunity to use this implementation with nchw formats,
278 // which will result in worse performance, but prevent us using reorder.
279 // That can be efficient in some cases.
280 conf.is_nchw = src_mdw.matches_one_of_tag(ncw, nchw, ncdhw)
281 || src_mdw.format_kind() == format_kind::any;
282 // decrease src ic_block in case of input nchw
283 if (conf.is_nchw) conf.ic_block = 1;
284 }
285 }
286
287 // TODO: add support for nhwc and dw ow_block
288 const bool has_compensation = conf.attr_info.with_src_zpoints
289 || conf.attr_info.with_dst_zpoints;
290 if (has_compensation)
291 if (conf.is_nhwc || (conf.is_depthwise && conf.mb_block != 32))
292 return status::unimplemented;
293
294 conf.with_bias = cd.bias_desc.format_kind != format_kind::undef;
295
296 format_tag_t src_tag, dst_tag, wei_tag;
297
298 if (conf.is_nhwc) {
299 src_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
300 dst_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
301
302 if (is_1stconv) {
303 wei_tag = conf.with_groups
304 ? utils::pick(ndims - 3, gOIw8o4i, gOIhw8o4i, gOIdhw8o4i)
305 : utils::pick(ndims - 3, OIw8o4i, OIhw8o4i, OIdhw8o4i);
306 if (!conf.is_dst_nhwc) {
307 if (conf.mb_block == 32) {
308 dst_tag = utils::pick(
309 ndims - 3, NCw32n32c, NChw32n32c, NCdhw32n32c);
310 } else {
311 dst_tag = utils::pick(ndims - 3, nCw32c, nChw32c, nCdhw32c);
312 }
313 }
314 } else if (conf.is_depthwise) {
315 wei_tag = utils::pick(ndims - 3, Goiw32g, Goihw32g, Goidhw32g);
316 } else {
317 wei_tag = conf.with_groups ? utils::pick(ndims - 3, gOIw4o8i8o4i,
318 gOIhw4o8i8o4i, gOIdhw4o8i8o4i)
319 : utils::pick(ndims - 3, OIw4o8i8o4i,
320 OIhw4o8i8o4i, OIdhw4o8i8o4i);
321 }
322
323 } else {
324 if (conf.mb_block == 32) {
325 src_tag = utils::pick(
326 ndims - 3, NCw32n32c, NChw32n32c, NCdhw32n32c);
327 dst_tag = utils::pick(
328 ndims - 3, NCw32n32c, NChw32n32c, NCdhw32n32c);
329 } else {
330 src_tag = utils::pick(ndims - 3, nCw32c, nChw32c, nCdhw32c);
331 dst_tag = utils::pick(ndims - 3, nCw32c, nChw32c, nCdhw32c);
332 }
333
334 if (!conf.is_depthwise && conf.ver == ver_1stconv) {
335 src_tag = (conf.is_nchw)
336 ? utils::pick(ndims - 3, ncw, nchw, ncdhw)
337 : utils::pick(ndims - 3, nCw4c, nChw4c, nCdhw4c);
338 }
339
340 if (conf.is_depthwise) {
341 wei_tag = utils::pick(ndims - 3, Goiw32g, Goihw32g, Goidhw32g);
342 } else {
343 if (conf.ver == ver_1stconv) {
344 wei_tag = conf.with_groups
345 ? utils::pick(
346 ndims - 3, gOIw8o4i, gOIhw8o4i, gOIdhw8o4i)
347 : utils::pick(ndims - 3, OIw8o4i, OIhw8o4i, OIdhw8o4i);
348 } else {
349 wei_tag = conf.with_groups ? utils::pick(ndims - 3,
350 gOIw4o8i8o4i, gOIhw4o8i8o4i, gOIdhw4o8i8o4i)
351 : utils::pick(ndims - 3, OIw4o8i8o4i,
352 OIhw4o8i8o4i, OIdhw4o8i8o4i);
353 }
354 }
355 }
356
357 conf.src_tag = src_mdw.format_kind() == format_kind::any
358 ? src_tag
359 : src_mdw.matches_one_of_tag(src_tag);
360 conf.wei_tag = weights_mdw.format_kind() == format_kind::any
361 ? wei_tag
362 : weights_mdw.matches_one_of_tag(wei_tag);
363 conf.dst_tag = dst_mdw.format_kind() == format_kind::any
364 ? dst_tag
365 : dst_mdw.matches_one_of_tag(dst_tag);
366
367 if (conf.src_tag != src_tag || conf.wei_tag != wei_tag
368 || conf.dst_tag != dst_tag)
369 return status::unimplemented;
370
371 return status::success;
372 }
373
init_kernel_ctx(compute::kernel_ctx_t & kernel_ctx) const374 status_t xe_lp_x8s8x_convolution_fwd_t::pd_t::init_kernel_ctx(
375 compute::kernel_ctx_t &kernel_ctx) const {
376 int owx = nstl::max(
377 1, utils::div_up(conf.iw + 2 * conf.l_pad, conf.stride_w));
378 int ow_block_with_stride = conf.stride_w * conf.ow_block;
379 int iw_with_l_pad = conf.iw + conf.l_pad;
380 int iw_len = iw_with_l_pad < ow_block_with_stride + conf.kw - 1
381 ? iw_with_l_pad - ow_block_with_stride
382 : iw_with_l_pad % ow_block_with_stride;
383 int iw_tail
384 = iw_len < (conf.kw - 1) ? ow_block_with_stride + iw_len : iw_len;
385 int ow_tail = conf.ow % conf.ow_block;
386 int iw_nchunk = utils::div_up(conf.iw, ow_block_with_stride);
387 int ow_nchunk = utils::div_up(conf.ow, conf.ow_block);
388 int min_w_nchunk = nstl::min(ow_nchunk, iw_nchunk);
389 int slm_tail
390 = conf.iw - (conf.stride_w * conf.ow_block * (min_w_nchunk - 1));
391 int zero_tail = utils::rnd_up(conf.ow, conf.ow_block) * conf.stride_w
392 - conf.iw + (conf.kw - 1) * (1 + conf.dilate_w) - conf.l_pad;
393
394 kernel_ctx.define_int("NCHW", conf.is_nchw);
395 kernel_ctx.define_int("DST_NHWC", conf.is_dst_nhwc);
396 kernel_ctx.define_int("G", conf.ngroups);
397 kernel_ctx.define_int("MB", conf.mb);
398 kernel_ctx.define_int("IC", conf.ic);
399 kernel_ctx.define_int("ID", conf.id);
400 kernel_ctx.define_int("IH", conf.ih);
401 kernel_ctx.define_int("IW", conf.iw);
402 kernel_ctx.define_int("OC", conf.oc);
403 kernel_ctx.define_int("OD", conf.od);
404 kernel_ctx.define_int("OH", conf.oh);
405 kernel_ctx.define_int("OW", conf.ow);
406 kernel_ctx.define_int("KD", conf.kd);
407 kernel_ctx.define_int("KH", conf.kh);
408 kernel_ctx.define_int("KW", conf.kw);
409 kernel_ctx.define_int("SD", conf.stride_d);
410 kernel_ctx.define_int("SH", conf.stride_h);
411 kernel_ctx.define_int("SW", conf.stride_w);
412 kernel_ctx.define_int("PD", conf.f_pad);
413 kernel_ctx.define_int("PH", conf.t_pad);
414 kernel_ctx.define_int("PW", conf.l_pad);
415 kernel_ctx.define_int("DD", conf.dilate_d);
416 kernel_ctx.define_int("DH", conf.dilate_h);
417 kernel_ctx.define_int("DW", conf.dilate_w);
418
419 kernel_ctx.define_int("OW_PADDED",
420 utils::rnd_up(
421 utils::div_up(conf.ow, conf.ow_block), conf.lws_d[1]));
422 int ow = nstl::max(
423 1, utils::div_up(conf.iw + 2 * conf.l_pad, conf.stride_w));
424 kernel_ctx.define_int("OWX", ow);
425 kernel_ctx.define_int("OWB", utils::div_up(conf.ow, conf.ow_block));
426
427 kernel_ctx.define_int("MB_BLOCK", conf.mb_block);
428 kernel_ctx.define_int("OC_BLOCK", conf.oc_block);
429 kernel_ctx.define_int("IC_BLOCK", conf.ic_block);
430 kernel_ctx.define_int("OW_BLOCK", conf.ow_block);
431 kernel_ctx.define_int("SRC_SLM_SIZE", conf.src_slm_size);
432 kernel_ctx.define_int("SUB_GROUP_SIZE", conf.sub_group_size);
433 kernel_ctx.define_int("WITH_BIAS", conf.with_bias);
434 kernel_ctx.define_int("LWS_0", conf.lws_d[0]);
435 kernel_ctx.define_int("LWS_1", conf.lws_d[1]);
436 kernel_ctx.define_int("LWS_2", conf.lws_d[2]);
437
438 kernel_ctx.define_int("OW_TAIL", ow_tail);
439 kernel_ctx.define_int("IW_TAIL", iw_tail);
440 kernel_ctx.define_int("SLM_TAIL", slm_tail);
441 kernel_ctx.define_int("ZERO_TAIL", zero_tail);
442
443 kernel_ctx.define_int("OW_PADDED", utils::rnd_up(ow_nchunk, conf.lws_d[1]));
444 kernel_ctx.define_int("G_PADDED",
445 utils::div_up(conf.ngroups, conf.oc_block) * conf.oc_block);
446
447 kernel_ctx.define_int("MB_GROUP", 1);
448 kernel_ctx.define_int("SP_GROUP", conf.lws_d[1]);
449 kernel_ctx.define_int("OC_GROUP", utils::div_up(conf.lws_d[0], 8));
450
451 kernel_ctx.define_int("OC_NCHUNK", utils::div_up(conf.oc, conf.oc_block));
452 kernel_ctx.define_int("IC_NCHUNK", utils::div_up(conf.ic, conf.ic_block));
453 kernel_ctx.define_int("OW_NCHUNK", ow_nchunk);
454 kernel_ctx.define_int("SLM_NCHUNK", min_w_nchunk);
455 kernel_ctx.define_int("OWB", ow_nchunk);
456 kernel_ctx.define_int("OWX", owx);
457
458 kernel_ctx.define_int("DISABLE_DPAS", disable_dpas);
459
460 if (conf.is_depthwise)
461 kernel_ctx.define_int("WEI_32G", 1);
462 else
463 kernel_ctx.define_int("WEI_4O8I8O4I", 1);
464
465 kernel_ctx.set_data_type(conf.dst_data_type);
466
467 def_data_type(kernel_ctx, conf.src_data_type, "SRC");
468 def_data_type(kernel_ctx, conf.dst_data_type, "DST");
469 def_data_type(kernel_ctx,
470 conf.attr_info.sum_data_type == dnnl_data_type_undef
471 ? conf.dst_data_type
472 : conf.attr_info.sum_data_type,
473 "SUM");
474
475 def_attr_info(kernel_ctx, conf.attr_info, attr()->post_ops_);
476
477 kernel_ctx.add_option("-Dcl_intel_subgroups_char");
478 kernel_ctx.add_option("-Dcl_intel_subgroups_long");
479
480 return status::success;
481 }
482
init_scratchpad()483 void xe_lp_x8s8x_convolution_fwd_t::pd_t::init_scratchpad() {
484 if (conf.attr_info.with_src_zpoints) {
485 size_t size = conf.is_depthwise
486 ? utils::rnd_up(conf.ngroups, 32)
487 : conf.ngroups * utils::rnd_up(conf.oc, 32);
488
489 auto scratchpad = scratchpad_registry().registrar();
490 scratchpad.book(memory_tracking::names::key_conv_wei_reduction, size,
491 types::data_type_size(data_type::s32), OCL_BUFFER_ALIGNMENT);
492 }
493 }
494
execute_forward(const exec_ctx_t & ctx) const495 status_t xe_lp_x8s8x_convolution_fwd_t::execute_forward(
496 const exec_ctx_t &ctx) const {
497 auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
498 auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS);
499 auto &bias = CTX_IN_STORAGE(DNNL_ARG_BIAS);
500 auto &oscales = CTX_IN_STORAGE(DNNL_ARG_ATTR_OUTPUT_SCALES);
501 auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST);
502 auto &src_zpoints
503 = CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC);
504 auto &dst_zpoints
505 = CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST);
506
507 const auto &conf = pd()->conf;
508
509 // XXX: first convolution calculates compensation in-place
510 const bool precompute_compensation = conf.is_depthwise || conf.ic > 4;
511
512 std::unique_ptr<memory_storage_t> temp_src_compensation;
513 if (conf.attr_info.with_src_zpoints && precompute_compensation) {
514 temp_src_compensation = ctx.get_scratchpad_grantor().get_memory_storage(
515 memory_tracking::names::key_conv_wei_reduction);
516
517 compute::kernel_arg_list_t arg_list;
518 arg_list.set(0, src_zpoints);
519 arg_list.set(1, weights);
520 arg_list.set(2, *temp_src_compensation);
521
522 auto nd_range = conf.is_depthwise
523 ? compute::nd_range_t(
524 {16, utils::div_up(conf.ngroups, 32), 1}, {16, 1, 1})
525 : compute::nd_range_t(
526 {8, utils::div_up(conf.oc, 32), conf.ngroups},
527 {8, 1, 1});
528 status_t status = parallel_for(
529 ctx, nd_range, src_compensation_kernel_, arg_list);
530 if (status != status::success) return status::runtime_error;
531 }
532
533 compute::kernel_arg_list_t arg_list;
534 arg_list.set(0, src);
535 arg_list.set(1, weights);
536 arg_list.set(2, bias);
537 arg_list.set(3, dst);
538
539 unsigned arg_idx = append_post_ops_to_arg_list(
540 ctx, arg_list, 4, pd()->attr()->post_ops_);
541
542 if (conf.attr_info.common_oscales) {
543 float scales = pd()->attr()->output_scales_.scales_[0];
544 arg_list.set(arg_idx++, scales);
545 } else {
546 arg_list.set(arg_idx++, 1.0f);
547 }
548
549 if (conf.attr_info.with_per_oc_oscales) {
550 if (conf.attr_info.with_runtime_oscales)
551 arg_list.set(arg_idx++, oscales);
552 else
553 arg_list.set(arg_idx++, CTX_GPU_RES_STORAGE(SCALES_));
554 } else {
555 arg_list.set(arg_idx++, memory_storage_t::empty_storage());
556 }
557
558 if (conf.attr_info.with_src_zpoints) {
559 if (precompute_compensation)
560 arg_list.set(arg_idx++, *temp_src_compensation);
561 else
562 arg_list.set(arg_idx++, memory_storage_t::empty_storage());
563 arg_list.set(arg_idx++, src_zpoints);
564 } else {
565 arg_list.set(arg_idx++, memory_storage_t::empty_storage());
566 arg_list.set(arg_idx++, memory_storage_t::empty_storage());
567 }
568
569 if (conf.attr_info.with_dst_zpoints)
570 arg_list.set(arg_idx++, dst_zpoints);
571 else
572 arg_list.set(arg_idx++, memory_storage_t::empty_storage());
573
574 auto nd_range = compute::nd_range_t(conf.gws_d, conf.lws_d);
575 status_t status = parallel_for(ctx, nd_range, kernel_, arg_list);
576
577 if (!post_ops_preserves_zeroes(ctx, pd()->attr()->post_ops_)) {
578 ctx.zero_pad_output(DNNL_ARG_DST);
579 }
580 return status;
581 }
582
init_conf()583 status_t xe_lp_x8s8x_convolution_bwd_data_t::pd_t::init_conf() {
584 using namespace format_tag;
585
586 const convolution_desc_t &cd = *desc();
587 const memory_desc_wrapper src_mdw(diff_src_md());
588 const memory_desc_wrapper weights_mdw(weights_md());
589 const memory_desc_wrapper dst_mdw(diff_dst_md());
590 const memory_desc_wrapper bias_mdw(weights_md(1));
591
592 set_default_conf(conf, cd, *diff_src_md(), *weights_md(), *diff_dst_md(),
593 *weights_md(1), *attr());
594
595 conf.is_nhwc = is_nhwc(src_mdw, dst_mdw);
596
597 if (conf.with_groups && conf.ngroups > 1
598 && (conf.oc % 32 != 0 || conf.ic % 32 != 0))
599 return status::unimplemented;
600
601 if (!conf.is_nhwc) {
602 if (conf.mb % 16 == 0) {
603 conf.ver = ver_mb_block;
604 } else {
605 conf.ver = ver_ow_block;
606 }
607 }
608
609 conf.oc_block = 32;
610 conf.ic_block = 32;
611 conf.iw_block = 1;
612
613 conf.sub_group_size = 8;
614 conf.nchunk = utils::div_up(conf.ic * conf.ngroups, conf.ic_block);
615 int ic_group = nstl::min(conf.nchunk, 2);
616
617 if (conf.ver == ver_ow_block || conf.is_nhwc) {
618 conf.mb_block = 1;
619 int max_ic = 4;
620 ic_group
621 = utils::max_div(utils::div_up(conf.ic, conf.ic_block), max_ic);
622 int max_subgroups = 32;
623 int max_iw_group = max_subgroups / ic_group;
624 conf.iw_block
625 = (conf.mb * conf.ic * conf.ih * conf.iw < 49 * 1024) ? 4 : 8;
626 int iw_nchunk = utils::div_up(conf.iw, conf.iw_block);
627 int iw_group = utils::max_div(iw_nchunk, max_iw_group);
628
629 //an upper bound on the number of elems per subgroup
630 conf.dst_slm_size = (conf.oc_block / 4)
631 * ((iw_group * conf.iw_block)
632 + (conf.kw - 1) * (1 + conf.dilate_w));
633 conf.iw_tail = conf.iw % conf.iw_block;
634
635 conf.lws_d[0] = 8 * ic_group;
636 conf.lws_d[1] = iw_group;
637 conf.lws_d[2] = 1;
638
639 conf.gws_d[0] = utils::rnd_up(conf.nchunk * 8, conf.lws_d[0]);
640 conf.gws_d[1] = conf.id * conf.ih * iw_nchunk;
641 conf.gws_d[2] = utils::div_up(conf.mb, utils::div_up(conf.mb_block, 2));
642 } else { //ver_mb_block
643 conf.mb_block = 32;
644 conf.lws_d[0] = 8 * ic_group;
645 conf.lws_d[1] = 8;
646 conf.lws_d[2] = 1;
647
648 conf.gws_d[0] = utils::rnd_up(conf.nchunk * 8, conf.lws_d[0]);
649 conf.gws_d[1]
650 = conf.id * conf.ih * utils::rnd_up(conf.iw, conf.lws_d[1]);
651 conf.gws_d[2] = 2 * utils::div_up(conf.mb, conf.mb_block);
652 }
653 conf.with_bias = cd.bias_desc.format_kind != format_kind::undef;
654
655 format_tag_t src_tag, dst_tag, wei_tag;
656
657 if (conf.is_nhwc) {
658 src_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
659 dst_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
660 } else {
661 src_tag = (conf.ver == ver_ow_block)
662 ? utils::pick(conf.ndims - 3, nCw32c, nChw32c, nCdhw32c)
663 : utils::pick(
664 conf.ndims - 3, NCw32n32c, NChw32n32c, NCdhw32n32c);
665 dst_tag = (conf.ver == ver_ow_block)
666 ? utils::pick(conf.ndims - 3, nCw32c, nChw32c, nCdhw32c)
667 : utils::pick(
668 conf.ndims - 3, NCw32n32c, NChw32n32c, NCdhw32n32c);
669 }
670
671 wei_tag = conf.with_groups ? utils::pick(conf.ndims - 3, gIOw4i8o8i4o,
672 gIOhw4i8o8i4o, gIOdhw4i8o8i4o)
673 : utils::pick(conf.ndims - 3, IOw4i8o8i4o,
674 IOhw4i8o8i4o, IOdhw4i8o8i4o);
675
676 conf.dst_data_type = dst_mdw.data_type();
677 conf.src_data_type = src_mdw.data_type();
678
679 conf.src_tag = src_mdw.format_kind() == format_kind::any
680 ? src_tag
681 : src_mdw.matches_one_of_tag(src_tag);
682 conf.wei_tag = weights_mdw.format_kind() == format_kind::any
683 ? wei_tag
684 : weights_mdw.matches_one_of_tag(wei_tag);
685 conf.dst_tag = dst_mdw.format_kind() == format_kind::any
686 ? dst_tag
687 : dst_mdw.matches_one_of_tag(dst_tag);
688
689 if (conf.src_tag != src_tag || conf.wei_tag != wei_tag
690 || conf.dst_tag != dst_tag)
691 return status::unimplemented;
692
693 return status::success;
694 }
695
init_kernel_ctx(compute::kernel_ctx_t & kernel_ctx) const696 status_t xe_lp_x8s8x_convolution_bwd_data_t::pd_t::init_kernel_ctx(
697 compute::kernel_ctx_t &kernel_ctx) const {
698 kernel_ctx.define_int("G", conf.ngroups);
699 kernel_ctx.define_int("MB", conf.mb);
700 kernel_ctx.define_int("IC", conf.ic);
701 kernel_ctx.define_int("ID", conf.id);
702 kernel_ctx.define_int("IH", conf.ih);
703 kernel_ctx.define_int("IW", conf.iw);
704 kernel_ctx.define_int("OC", conf.oc);
705 kernel_ctx.define_int("OD", conf.od);
706 kernel_ctx.define_int("OH", conf.oh);
707 kernel_ctx.define_int("OW", conf.ow);
708 kernel_ctx.define_int("KD", conf.kd);
709 kernel_ctx.define_int("KH", conf.kh);
710 kernel_ctx.define_int("KW", conf.kw);
711 kernel_ctx.define_int("SD", conf.stride_d);
712 kernel_ctx.define_int("SH", conf.stride_h);
713 kernel_ctx.define_int("SW", conf.stride_w);
714 kernel_ctx.define_int("PD", conf.f_pad);
715 kernel_ctx.define_int("PH", conf.t_pad);
716 kernel_ctx.define_int("PW", conf.l_pad);
717 kernel_ctx.define_int("DD", conf.dilate_d);
718 kernel_ctx.define_int("DH", conf.dilate_h);
719 kernel_ctx.define_int("DW", conf.dilate_w);
720
721 kernel_ctx.define_int("IW_PADDED", utils::rnd_up(conf.iw, conf.lws_d[1]));
722 kernel_ctx.define_int("IW_TAIL", conf.iw_tail);
723
724 kernel_ctx.define_int("MB_BLOCK", conf.mb_block);
725 kernel_ctx.define_int("OC_BLOCK", conf.oc_block);
726 kernel_ctx.define_int("IC_BLOCK", conf.ic_block);
727 kernel_ctx.define_int("IW_BLOCK", conf.iw_block);
728
729 kernel_ctx.define_int("MB_GROUP", 1);
730 kernel_ctx.define_int("IC_GROUP", utils::div_up(conf.lws_d[0], 8));
731 kernel_ctx.define_int("SP_GROUP", conf.lws_d[1]);
732
733 kernel_ctx.define_int("IW_NCHUNK", utils::div_up(conf.iw, conf.iw_block));
734 kernel_ctx.define_int("OC_NCHUNK", utils::div_up(conf.oc, conf.oc_block));
735 kernel_ctx.define_int("IC_NCHUNK", utils::div_up(conf.ic, conf.ic_block));
736
737 kernel_ctx.define_int("DST_SLM_SIZE", conf.dst_slm_size);
738 kernel_ctx.define_int("SUB_GROUP_SIZE", conf.sub_group_size);
739
740 kernel_ctx.define_int("WITH_BIAS", conf.with_bias);
741
742 kernel_ctx.define_int("LWS_0", conf.lws_d[0]);
743 kernel_ctx.define_int("LWS_1", conf.lws_d[1]);
744 kernel_ctx.define_int("LWS_2", conf.lws_d[2]);
745
746 kernel_ctx.define_int("IS_NHWC", conf.is_nhwc);
747
748 kernel_ctx.define_int("DISABLE_DPAS", disable_dpas);
749
750 kernel_ctx.set_data_type(conf.dst_data_type);
751 def_data_type(kernel_ctx, conf.src_data_type, "SRC");
752 def_data_type(kernel_ctx, conf.dst_data_type, "DST");
753 kernel_ctx.add_option("-Dcl_intel_subgroups_char");
754
755 return status::success;
756 }
757
execute_backward_data(const exec_ctx_t & ctx) const758 status_t xe_lp_x8s8x_convolution_bwd_data_t::execute_backward_data(
759 const exec_ctx_t &ctx) const {
760
761 auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
762 auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS);
763 auto &bias = CTX_IN_STORAGE(DNNL_ARG_BIAS);
764 auto &diff_src = CTX_OUT_STORAGE(DNNL_ARG_DIFF_SRC);
765
766 const auto &conf = pd()->conf;
767
768 compute::kernel_arg_list_t arg_list;
769 arg_list.set(0, diff_src);
770 arg_list.set(1, weights);
771 arg_list.set(2, bias);
772 arg_list.set(3, diff_dst);
773
774 auto nd_range = compute::nd_range_t(conf.gws_d, conf.lws_d);
775 status_t status = parallel_for(ctx, nd_range, kernel_, arg_list);
776
777 return status;
778 }
779
780 } // namespace ocl
781 } // namespace gpu
782 } // namespace impl
783 } // namespace dnnl
784