1 /*******************************************************************************
2 * Copyright 2019-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include "gpu/ocl/gen9_convolution.hpp"
18 
19 #include "common/c_types_map.hpp"
20 #include "common/dnnl_traits.hpp"
21 #include "common/math_utils.hpp"
22 #include "common/reorder.hpp"
23 #include "common/type_helpers.hpp"
24 #include "gpu/ocl/ocl_memory_storage.hpp"
25 #include "gpu/primitive_conf.hpp"
26 
27 using namespace dnnl::impl::memory_tracking::names;
28 
29 namespace dnnl {
30 namespace impl {
31 namespace gpu {
32 namespace ocl {
33 
34 using namespace dnnl::impl::data_type;
35 using namespace dnnl::impl::format_tag;
36 
fwd_compute_block_sizes(conv_conf_t & conf,const convolution_pd_t * pd)37 static void fwd_compute_block_sizes(
38         conv_conf_t &conf, const convolution_pd_t *pd) {
39 
40     int max_ow_block = (conf.src_data_type == data_type::f16 ? 20 : 16);
41     if (conf.ver == ver_16mb16c || conf.ver == ver_32mb16c) {
42         max_ow_block = 1;
43     } else if (conf.is_depthwise || conf.ver == ver_1stconv) {
44         max_ow_block = 8;
45     }
46     max_ow_block = nstl::min(conf.ow, max_ow_block);
47 
48     if (conf.ver == ver_16mb16c) {
49         conf.mb_block
50                 = (conf.src_data_type == data_type::f16 && !conf.is_depthwise)
51                 ? (conf.mb % 32 == 0 ? 32 : 16)
52                 : 16;
53     } else if (conf.ver == ver_32mb16c) {
54         conf.mb_block = 32;
55     } else {
56         conf.mb_block = 1;
57     }
58 
59     conf.ow_block = utils::max_div(conf.ow, max_ow_block);
60 
61     if (conf.ow_block < max_ow_block / 2) {
62         float min_tail_ratio = 1;
63         int best_ow_block = -1;
64         for (int ow_block = 8; ow_block <= max_ow_block; ow_block++) {
65             float tail_ratio
66                     = (ow_block - (conf.ow % ow_block)) / (float)conf.ow;
67             if (tail_ratio <= min_tail_ratio) {
68                 min_tail_ratio = tail_ratio;
69                 best_ow_block = ow_block;
70             }
71         }
72         assert(best_ow_block > 0);
73         conf.ow_block = best_ow_block;
74     }
75 
76     if (conf.is_depthwise) {
77         conf.oc_block = 16;
78         conf.ic_block = 16;
79         conf.omb = conf.mb_block;
80         return;
81     }
82 
83     if (conf.ver == ver_1stconv && conf.mb_block == 1 && conf.oc % 32 == 0) {
84         conf.oc_block = 32;
85     } else {
86         conf.oc_block = 16;
87     }
88     conf.ic_block = nstl::min(conf.ic, 16);
89 
90     conf.omb = (conf.mb_block == 1 && conf.mb % 16 == 0) ? 16 : conf.mb_block;
91     conf.ocb = utils::max_div(conf.oc / 16, 8) * 16;
92 }
93 
init_conf(engine_t * engine)94 status_t gen9_convolution_fwd_t::pd_t::init_conf(engine_t *engine) {
95 
96     const convolution_desc_t &cd = *desc();
97     const memory_desc_wrapper src_mdw(src_md());
98     const memory_desc_wrapper weights_mdw(weights_md());
99     const memory_desc_wrapper dst_mdw(dst_md());
100     const memory_desc_wrapper bias_mdw(weights_md(1));
101 
102     set_default_conf(conf, cd, *src_md(), *weights_md(), *dst_md(),
103             *weights_md(1), *attr());
104 
105     const bool int8_dst = conf.dst_data_type == data_type::s8;
106     const bool is_src_nhwc
107             = src_mdw.matches_one_of_tag(nwc, nhwc, ndhwc) != format_tag::undef;
108     const bool is_dst_nhwc
109             = dst_mdw.matches_one_of_tag(nwc, nhwc, ndhwc) != format_tag::undef;
110     const bool is_nhwc = is_src_nhwc || is_dst_nhwc;
111 
112     const bool is_1stconv = conf.ic_without_padding == 3;
113     const bool is_depthwise = conf.with_groups && (conf.ic_without_padding == 1)
114             && (conf.oc_without_padding == 1);
115 
116     conf.is_nhwc = is_1stconv ? is_dst_nhwc : is_nhwc;
117     conf.is_depthwise = is_depthwise;
118 
119     const int out_block = int8_dst && !is_1stconv ? 32 : 16;
120     if (is_1stconv || (conf.with_groups && conf.ngroups > 1)) {
121         conf.ic = conf.ic_without_padding;
122         conf.oc = is_1stconv ? utils::rnd_up(conf.oc_without_padding, out_block)
123                              : conf.oc_without_padding;
124     } else {
125         conf.ic = utils::rnd_up(conf.ic_without_padding, 16);
126         conf.oc = utils::rnd_up(conf.oc_without_padding, out_block);
127     }
128 
129     conf.ngroups_without_padding = conf.ngroups;
130     if (is_depthwise)
131         conf.ngroups = utils::rnd_up(conf.ngroups, int8_dst ? 32 : 16);
132 
133     const bool is_dw_16g = (conf.is_depthwise && conf.ngroups % 16 == 0);
134     const bool is_16oc = conf.oc % out_block == 0;
135     const bool is_16ic = conf.ic % 16 == 0;
136 
137     conf.mb_block = 1;
138     conf.oc_block = 1;
139     conf.ic_block = 1;
140     conf.od_block = 1;
141     conf.oh_block = 1;
142     conf.ow_block = 1;
143     conf.omb = 1;
144     conf.ocb = 1;
145     auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
146     const bool is_xe_hp_plus
147             = compute_engine->is_xe_hp() || compute_engine->is_xe_hpg();
148     const bool has_non_uniform_wg
149             = compute_engine->mayiuse_non_uniform_work_groups();
150 
151     if (conf.is_nhwc) {
152         if (!utils::one_of(src_mdw.data_type(), f32, f16))
153             return status::unimplemented;
154         if (conf.is_depthwise && conf.ngroups_without_padding % 16)
155             return status::unimplemented;
156         // TODO: Add group convolution support in NHWC kernel.
157         if (!conf.is_depthwise && conf.ngroups > 1 && !(is_16oc && is_16ic)) {
158             return status::unimplemented;
159         }
160         if (int8_dst) { return status::unimplemented; }
161         conf.ver = ver_nhwc;
162     } else if (is_1stconv) {
163         if (!is_16oc) return status::unimplemented;
164         conf.ver = ver_1stconv;
165     } else if ((is_16oc && is_16ic) || is_dw_16g) {
166         if (conf.mb % 32 == 0 && conf.is_depthwise
167                 && utils::one_of(src_mdw.data_type(), bf16, f16)
168                 && is_xe_hp_plus) {
169             conf.ver = ver_32mb16c;
170         } else {
171             conf.ver = (conf.mb % out_block == 0) ? ver_16mb16c : ver_8ow16c;
172         }
173     } else {
174         return status::unimplemented;
175     }
176 
177     const bool is_fp16 = src_mdw.data_type() == data_type::f16;
178 
179     switch (conf.ver) {
180         case ver_nhwc: {
181             conf.mb_block = 1;
182             conf.oc_block = 16;
183             conf.ic_block = is_1stconv ? 1 : 16;
184 
185             int max_ow_block = (conf.kw > 1) ? 8 : 16;
186             if (conf.oc <= 64 && conf.ic <= 64) max_ow_block = 8;
187 
188             conf.ow_block = utils::max_div(conf.ow, max_ow_block);
189 
190             if (conf.ow_block <= 8) {
191                 int max_tail = 0;
192                 for (int j = 8; j < max_ow_block; j++) {
193                     if (conf.ow % j > max_tail) {
194                         max_tail = conf.ow % j;
195                         conf.ow_block = j;
196                     }
197                 }
198             }
199             if (conf.ow_block <= 8) conf.ow_block = 8;
200             if (conf.ow <= 8 || conf.oc <= 32) conf.ow_block = 8;
201 
202             conf.oh_block = 1;
203             conf.sub_group_size = 16;
204             conf.lws_d[0] = 16;
205             conf.lws_d[1] = 1;
206             conf.lws_d[2] = 1;
207 
208             int max_oc_block = 8;
209             if (conf.is_depthwise) {
210                 conf.ocb = conf.ngroups;
211             } else {
212                 conf.ocb = conf.oc_block
213                         * utils::max_div(utils::div_up(conf.oc, conf.oc_block),
214                                 max_oc_block);
215             }
216 
217             conf.gws_d[0] = conf.ocb;
218             conf.gws_d[1] = utils::div_up(conf.oh, conf.oh_block)
219                     * utils::div_up(conf.ow, conf.ow_block) * conf.od;
220             if (conf.is_depthwise) {
221                 conf.gws_d[2] = conf.mb;
222             } else {
223                 conf.gws_d[2] = conf.mb * utils::div_up(conf.oc, conf.ocb)
224                         * conf.ngroups;
225             }
226         } break;
227         case ver_1stconv:
228         case ver_8ow16c:
229         case ver_16mb16c:
230         case ver_32mb16c: {
231             fwd_compute_block_sizes(conf, this);
232             conf.sub_group_size = 16;
233             conf.gws_d[0] = conf.ngroups * conf.ocb / (conf.oc_block / 16);
234             conf.gws_d[1]
235                     = (conf.od * conf.oh * utils::div_up(conf.ow, conf.ow_block)
236                             * (conf.omb / conf.mb_block));
237             conf.gws_d[2] = (conf.oc / conf.ocb) * (conf.mb / conf.omb);
238             conf.lws_d[0] = is_xe_hp_plus ? 32 : 16;
239             conf.lws_d[1] = 1;
240             conf.lws_d[2] = 1;
241             break;
242         }
243         default: return status::unimplemented;
244     }
245 
246     maybe_fix_non_uniform_work_sizes(has_non_uniform_wg, conf);
247 
248     format_tag_t src_tag, dst_tag, wei_tag;
249 
250     switch (conf.ver) {
251         case ver_nhwc:
252             src_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
253             dst_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
254             if (is_1stconv) {
255                 wei_tag = conf.with_groups ? utils::pick(
256                                   conf.ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o)
257                                            : utils::pick(conf.ndims - 3, Owi16o,
258                                                    Ohwi16o, Odhwi16o);
259             } else if (conf.is_depthwise) {
260                 wei_tag = utils::pick(
261                         conf.ndims - 3, Goiw16g, Goihw16g, Goidhw16g);
262             } else {
263                 wei_tag = conf.with_groups
264                         ? utils::pick(conf.ndims - 3, gOIw16i16o, gOIhw16i16o,
265                                 gOIdhw16i16o)
266                         : utils::pick(conf.ndims - 3, OIw16i16o, OIhw16i16o,
267                                 OIdhw16i16o);
268             }
269             break;
270         case ver_1stconv:
271             if (is_src_nhwc)
272                 src_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
273             else
274                 src_tag = utils::pick(conf.ndims - 3, ncw, nchw, ncdhw);
275 
276             if (is_xe_hp_plus && is_fp16) {
277                 dst_tag = conf.mb % 32 == 0 ? utils::pick(conf.ndims - 3,
278                                   NCw32n16c, NChw32n16c, NCdhw32n16c)
279                                             : utils::pick(conf.ndims - 3,
280                                                     nCw16c, nChw16c, nCdhw16c);
281             } else {
282                 dst_tag = conf.mb % 16 == 0 ? utils::pick(conf.ndims - 3,
283                                   NCw16n16c, NChw16n16c, NCdhw16n16c)
284                                             : utils::pick(conf.ndims - 3,
285                                                     nCw16c, nChw16c, nCdhw16c);
286             }
287             wei_tag = conf.with_groups
288                     ? utils::pick(conf.ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o)
289                     : utils::pick(conf.ndims - 3, Owi16o, Ohwi16o, Odhwi16o);
290             break;
291         case ver_16mb16c:
292             src_tag = utils::pick(
293                     conf.ndims - 3, NCw16n16c, NChw16n16c, NCdhw16n16c);
294             dst_tag = utils::pick(
295                     conf.ndims - 3, NCw16n16c, NChw16n16c, NCdhw16n16c);
296             wei_tag = conf.is_depthwise
297                     ? utils::pick(conf.ndims - 3, Goiw16g, Goihw16g, Goidhw16g)
298                     : (conf.with_groups ? utils::pick(conf.ndims - 3,
299                                gIOw16i16o, gIOhw16i16o, gIOdhw16i16o)
300                                         : utils::pick(conf.ndims - 3, IOw16i16o,
301                                                 IOhw16i16o, IOdhw16i16o));
302             break;
303         case ver_32mb16c:
304             src_tag = utils::pick(
305                     conf.ndims - 3, NCw32n16c, NChw32n16c, NCdhw32n16c);
306             dst_tag = utils::pick(
307                     conf.ndims - 3, NCw32n16c, NChw32n16c, NCdhw32n16c);
308             wei_tag = conf.is_depthwise
309                     ? utils::pick(conf.ndims - 3, Goiw16g, Goihw16g, Goidhw16g)
310                     : (conf.with_groups ? utils::pick(conf.ndims - 3,
311                                gIOw16i16o, gIOhw16i16o, gIOdhw16i16o)
312                                         : utils::pick(conf.ndims - 3, IOw16i16o,
313                                                 IOhw16i16o, IOdhw16i16o));
314             break;
315         case ver_8ow16c:
316             src_tag = utils::pick(conf.ndims - 3, nCw16c, nChw16c, nCdhw16c);
317             dst_tag = utils::pick(conf.ndims - 3, nCw16c, nChw16c, nCdhw16c);
318             wei_tag = conf.is_depthwise
319                     ? utils::pick(conf.ndims - 3, Goiw16g, Goihw16g, Goidhw16g)
320                     : (conf.with_groups ? utils::pick(conf.ndims - 3,
321                                gOIw16i16o, gOIhw16i16o, gOIdhw16i16o)
322                                         : utils::pick(conf.ndims - 3, OIw16i16o,
323                                                 OIhw16i16o, OIdhw16i16o));
324             break;
325         default: return status::unimplemented;
326     }
327     if (int8_dst) {
328         if (is_1stconv && conf.ic_without_padding < 4) {
329             dst_tag = utils::pick(conf.ndims - 3, ncw, nchw, ncdhw);
330         } else if (conf.ver == ver_16mb16c || conf.ver == ver_32mb16c) {
331             dst_tag = utils::pick(
332                     conf.ndims - 3, NCw32n32c, NChw32n32c, NCdhw32n32c);
333         } else {
334             dst_tag = utils::pick(conf.ndims - 3, nCw32c, nChw32c, nCdhw32c);
335         }
336     }
337 
338     if (src_mdw.format_kind() == format_kind::any) {
339         conf.src_tag = src_tag;
340     } else {
341         conf.src_tag = src_mdw.matches_one_of_tag(src_tag);
342     }
343     if (conf.src_tag != src_tag) return status::unimplemented;
344 
345     if (weights_mdw.format_kind() == format_kind::any) {
346         conf.wei_tag = wei_tag;
347     } else {
348         conf.wei_tag = weights_mdw.matches_one_of_tag(wei_tag);
349     }
350     if (conf.wei_tag != wei_tag) return status::unimplemented;
351 
352     if (dst_mdw.format_kind() == format_kind::any) {
353         conf.dst_tag = dst_tag;
354     } else {
355         conf.dst_tag = dst_mdw.matches_one_of_tag(dst_tag);
356     }
357     if (conf.dst_tag != dst_tag) return status::unimplemented;
358 
359     conf.is_src_nchw = utils::one_of(src_tag, ncw, nchw, ncdhw);
360     conf.is_src_nhwc = utils::one_of(src_tag, nwc, nhwc, ndhwc);
361 
362     return status::success;
363 }
364 
init_kernel_ctx(compute::kernel_ctx_t & kernel_ctx) const365 status_t gen9_convolution_fwd_t::pd_t::init_kernel_ctx(
366         compute::kernel_ctx_t &kernel_ctx) const {
367     kernel_ctx.define_int("IS_DW", conf.is_depthwise);
368     kernel_ctx.define_int("G", conf.ngroups);
369     kernel_ctx.define_int("MB", conf.mb);
370     kernel_ctx.define_int("IC", conf.ic);
371     kernel_ctx.define_int("ID", conf.id);
372     kernel_ctx.define_int("IH", conf.ih);
373     kernel_ctx.define_int("IW", conf.iw);
374     kernel_ctx.define_int("OC", conf.oc);
375     kernel_ctx.define_int("OD", conf.od);
376     kernel_ctx.define_int("OH", conf.oh);
377     kernel_ctx.define_int("OW", conf.ow);
378     kernel_ctx.define_int("KD", conf.kd);
379     kernel_ctx.define_int("KH", conf.kh);
380     kernel_ctx.define_int("KW", conf.kw);
381     kernel_ctx.define_int("SD", conf.stride_d);
382     kernel_ctx.define_int("SH", conf.stride_h);
383     kernel_ctx.define_int("SW", conf.stride_w);
384     kernel_ctx.define_int("PD", conf.f_pad);
385     kernel_ctx.define_int("PH", conf.t_pad);
386     kernel_ctx.define_int("PW", conf.l_pad);
387     kernel_ctx.define_int("PD_R", conf.back_pad);
388     kernel_ctx.define_int("PH_R", conf.b_pad);
389     kernel_ctx.define_int("PW_R", conf.r_pad);
390     kernel_ctx.define_int("DD", conf.dilate_d);
391     kernel_ctx.define_int("DH", conf.dilate_h);
392     kernel_ctx.define_int("DW", conf.dilate_w);
393     kernel_ctx.define_int("OW_PADDED", utils::rnd_up(conf.ow, 4));
394     kernel_ctx.define_int("OC_PADDED", conf.oc);
395     kernel_ctx.define_int("OMB", conf.omb);
396     kernel_ctx.define_int("OCB", conf.ocb);
397     kernel_ctx.define_int("MB_BLOCK", conf.mb_block);
398     kernel_ctx.define_int("OH_BLOCK", conf.oh_block);
399     kernel_ctx.define_int("OW_BLOCK", conf.ow_block);
400     kernel_ctx.define_int("OW_LAST", utils::rnd_dn(conf.ow, conf.ow_block));
401     kernel_ctx.define_int("OWB", utils::div_up(conf.ow, conf.ow_block));
402     kernel_ctx.define_int("OHB", utils::div_up(conf.oh, conf.oh_block));
403     kernel_ctx.define_int("WITH_BIAS", conf.with_bias);
404     kernel_ctx.define_int("SUB_GROUP_SIZE", conf.sub_group_size);
405     kernel_ctx.define_int("OC_BLOCK", conf.oc_block);
406     kernel_ctx.define_int("IC_BLOCK", conf.ic_block);
407     kernel_ctx.define_int("G_WO_PADDING", conf.ngroups_without_padding);
408     kernel_ctx.define_int("IC_WO_PADDING", conf.ic_without_padding);
409     kernel_ctx.define_int("OC_WO_PADDING", conf.oc_without_padding);
410     kernel_ctx.define_int("OC_GROUP", conf.lws_d[0] / 8);
411     kernel_ctx.define_int("MB_GROUP", 1);
412     kernel_ctx.define_int("SP_GROUP", conf.lws_d[1]);
413     if (conf.kw == 1)
414         kernel_ctx.define_int("SRC_SP_GROUP", conf.lws_d[1] + conf.kw - 1);
415     else
416         kernel_ctx.define_int(
417                 "SRC_SP_GROUP", conf.stride_w * (conf.lws_d[1] - 1) + conf.kw);
418 
419     kernel_ctx.set_data_type(conf.src_data_type);
420     def_data_type(kernel_ctx, conf.dst_data_type, "DST");
421 
422     kernel_ctx.define_int("VER_1STCONV", conf.ver == ver_1stconv);
423     kernel_ctx.define_int("VER_8OW16C", conf.ver == ver_8ow16c);
424     kernel_ctx.define_int("VER_16MB16C", conf.ver == ver_16mb16c);
425     kernel_ctx.define_int("VER_32MB16C", conf.ver == ver_32mb16c);
426 
427     kernel_ctx.define_int("SRC_NCHW", conf.is_src_nchw);
428     kernel_ctx.define_int("SRC_NHWC", conf.is_src_nhwc);
429     kernel_ctx.define_int("SRC_16N16C",
430             utils::one_of(conf.src_tag, NCw16n16c, NChw16n16c, NCdhw16n16c));
431     kernel_ctx.define_int(
432             "SRC_W16C", utils::one_of(conf.src_tag, nCw16c, nChw16c, nCdhw16c));
433 
434     kernel_ctx.define_int("WEI_I16O",
435             utils::one_of(conf.wei_tag, gOwi16o, gOhwi16o, gOdhwi16o, Owi16o,
436                     Ohwi16o, Odhwi16o));
437     kernel_ctx.define_int("WEI_16I16O",
438             utils::one_of(conf.wei_tag, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o,
439                     OIw16i16o, OIhw16i16o, OIdhw16i16o));
440     kernel_ctx.define_int("WEI_16I16O_FLIPPED",
441             utils::one_of(conf.wei_tag, gIOw16i16o, gIOhw16i16o, gIOdhw16i16o,
442                     IOw16i16o, IOhw16i16o, IOdhw16i16o));
443 
444     kernel_ctx.define_int(
445             "DST_W16C", utils::one_of(conf.dst_tag, nCw16c, nChw16c, nCdhw16c));
446     kernel_ctx.define_int("DST_16N16C",
447             utils::one_of(conf.dst_tag, NCw16n16c, NChw16n16c, NCdhw16n16c));
448     kernel_ctx.define_int("DST_32N16C",
449             utils::one_of(conf.dst_tag, NCw32n16c, NChw32n16c, NCdhw32n16c));
450     kernel_ctx.define_int("DST_32N32C",
451             utils::one_of(conf.dst_tag, NCw32n32c, NChw32n32c, NCdhw32n32c));
452     kernel_ctx.define_int(
453             "DST_W32C", utils::one_of(conf.dst_tag, nCw32c, nChw32c, nCdhw32c));
454     kernel_ctx.define_int(
455             "DST_NCHW", utils::one_of(conf.dst_tag, ncw, nchw, ncdhw));
456 
457     kernel_ctx.define_int("GWS_0", conf.gws_d[0]);
458     kernel_ctx.define_int("GWS_1", conf.gws_d[1]);
459     kernel_ctx.define_int("GWS_2", conf.gws_d[2]);
460 
461     kernel_ctx.define_int("GWS_ORIG_0", conf.gws_orig_d[0]);
462     kernel_ctx.define_int("GWS_ORIG_1", conf.gws_orig_d[1]);
463     kernel_ctx.define_int("GWS_ORIG_2", conf.gws_orig_d[2]);
464 
465     kernel_ctx.define_int("LWS_0", conf.lws_d[0]);
466     kernel_ctx.define_int("LWS_1", conf.lws_d[1]);
467     kernel_ctx.define_int("LWS_2", conf.lws_d[2]);
468 
469     dnnl_dims_t dst_dims;
470     dst_dims[0] = conf.mb;
471     dst_dims[1] = conf.ngroups_without_padding * conf.oc_without_padding;
472     dst_dims[2] = conf.ndims > 4 ? conf.od : conf.oh;
473     dst_dims[3] = conf.ndims > 4 ? conf.oh : conf.ow;
474     dst_dims[4] = conf.ow;
475     kernel_ctx.add_option("-cl-std=CL2.0");
476     def_attr_info(kernel_ctx, conf.attr_info, attr()->post_ops_, &dst_dims);
477 
478     kernel_ctx.print_options();
479     return status::success;
480 }
481 
init_conf(engine_t * engine)482 status_t gen9_convolution_bwd_data_t::pd_t::init_conf(engine_t *engine) {
483     using namespace dnnl::impl::format_tag;
484     using namespace data_type;
485 
486     const convolution_desc_t &cd = *desc();
487     const memory_desc_wrapper src_mdw(diff_src_md());
488     const memory_desc_wrapper weights_mdw(weights_md());
489     const memory_desc_wrapper dst_mdw(diff_dst_md());
490     const memory_desc_wrapper bias_mdw(weights_md(1));
491 
492     set_default_conf(conf, cd, *diff_src_md(), *weights_md(), *diff_dst_md(),
493             *weights_md(1), *attr());
494     const bool is_nhwc
495             = src_mdw.matches_one_of_tag(nwc, nhwc, ndhwc) != format_tag::undef
496             || dst_mdw.matches_one_of_tag(nwc, nhwc, ndhwc)
497                     != format_tag::undef;
498     const bool is_1stconv = conf.ic_without_padding == 3;
499     const bool is_depthwise = conf.with_groups && (conf.ic_without_padding == 1)
500             && (conf.oc_without_padding == 1);
501     conf.is_nhwc = is_nhwc;
502     conf.is_depthwise = is_depthwise;
503 
504     if (is_nhwc && (is_depthwise || is_1stconv)) return status::unimplemented;
505 
506     if (is_1stconv || (conf.with_groups && conf.ngroups > 1) || is_depthwise) {
507         conf.ic = conf.ic_without_padding;
508         conf.oc = is_1stconv ? utils::rnd_up(conf.oc_without_padding, 16)
509                              : conf.oc_without_padding;
510     } else {
511         conf.ic = utils::rnd_up(conf.ic_without_padding, 16);
512         conf.oc = utils::rnd_up(conf.oc_without_padding, 16);
513     }
514     conf.ngroups_without_padding = conf.ngroups;
515     if (is_depthwise) conf.ngroups = utils::rnd_up(conf.ngroups, 16);
516     const bool is_dw_16g = (conf.is_depthwise && conf.ngroups % 16 == 0);
517 
518     const bool is_16ic = conf.ic % 16 == 0;
519     const bool is_16oc = conf.oc % 16 == 0;
520     const bool use_16mb_unroll = !is_nhwc
521             && !(conf.mb == 1 || conf.mb % 16 != 0) && !is_1stconv
522             && ((is_16ic && is_16oc) || is_dw_16g);
523     conf.mb_block = 1;
524     conf.oc_block = 1;
525     conf.ic_block = 1;
526     conf.od_block = 1;
527     conf.oh_block = 1;
528     conf.ow_block = 1;
529     conf.icb = 1;
530     if (is_nhwc)
531         conf.ver = ver_nhwc;
532     else if (use_16mb_unroll)
533         conf.ver = ver_16mb16c;
534     else if (conf.mb % 16 != 0 && ((is_16oc && is_16ic) || is_dw_16g))
535         conf.ver = ver_8ow16c;
536     else
537         return status::unimplemented;
538 
539     auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
540     //TODO: Fix Gtests and reenable
541     const bool is_xe_hp_plus
542             = compute_engine->is_xe_hp() || compute_engine->is_xe_hpg();
543     const bool has_non_uniform_wg
544             = compute_engine->mayiuse_non_uniform_work_groups();
545 
546     status_t status = status::success;
547     switch (conf.ver) {
548         case ver_16mb16c:
549             conf.mb_block = 16;
550             conf.oc_block = 16;
551             conf.ic_block = 16;
552             conf.od_block = 1;
553             conf.ih_block = 1;
554             conf.iw_block = 1;
555             conf.sub_group_size = 16;
556             if (conf.is_depthwise) {
557                 conf.icb = conf.ngroups;
558                 conf.lws_d[0] = 1;
559                 conf.lws_d[1] = is_xe_hp_plus ? 32 : 16;
560                 conf.lws_d[2] = 1;
561                 conf.gws_d[0] = conf.ih * conf.iw * conf.id;
562                 conf.gws_d[1] = conf.ic * conf.ngroups;
563                 conf.gws_d[2] = conf.mb / 16;
564             } else {
565                 conf.icb = 64;
566                 while (conf.icb > 16) {
567                     if (conf.ic % conf.icb == 0) break;
568                     conf.icb /= 2;
569                 }
570                 conf.lws_d[0] = is_xe_hp_plus ? 32 : 16;
571                 conf.lws_d[1] = 1;
572                 conf.lws_d[2] = 1;
573                 conf.gws_d[0] = conf.icb;
574                 conf.gws_d[1] = conf.ih * conf.iw * conf.id;
575                 conf.gws_d[2]
576                         = conf.mb / 16 * (conf.ic / conf.icb) * conf.ngroups;
577             }
578             break;
579         case ver_8ow16c:
580         case ver_nhwc: {
581             conf.mb_block = 1;
582             conf.oc_block = 16;
583             conf.ic_block = 16;
584             conf.od_block = 1;
585             conf.ih_block = 1;
586             int max_iw_block = 16;
587             if (conf.ver == ver_nhwc) { max_iw_block = (conf.kw > 1) ? 8 : 16; }
588             conf.iw_block = nstl::max(8, utils::max_div(conf.iw, max_iw_block));
589             conf.sub_group_size = 16;
590             if (conf.is_depthwise) {
591                 conf.icb = conf.ngroups;
592                 conf.lws_d[0] = 1;
593                 conf.lws_d[1] = conf.ic_block;
594                 conf.lws_d[2] = 1;
595                 conf.gws_d[0] = conf.ih * utils::div_up(conf.iw, conf.iw_block)
596                         * conf.id;
597                 conf.gws_d[1] = conf.ic * conf.ngroups;
598                 conf.gws_d[2] = conf.mb;
599             } else {
600                 conf.icb = 64;
601                 while (conf.icb > conf.ic_block) {
602                     if (utils::rnd_up(conf.ic, conf.ic_block) % conf.icb == 0)
603                         break;
604                     conf.icb /= 2;
605                 }
606                 conf.lws_d[0] = conf.ic_block;
607                 conf.lws_d[1] = 1;
608                 conf.lws_d[2] = 1;
609                 conf.gws_d[0] = conf.icb;
610                 conf.gws_d[1] = conf.ih * utils::div_up(conf.iw, conf.iw_block)
611                         * conf.id;
612                 conf.gws_d[2] = conf.mb
613                         * (utils::rnd_up(conf.ic, conf.ic_block) / conf.icb)
614                         * conf.ngroups;
615             }
616             break;
617         }
618         default: status = status::unimplemented;
619     }
620 
621     maybe_fix_non_uniform_work_sizes(has_non_uniform_wg, conf);
622 
623     format_tag_t src_tag, dst_tag, wei_tag;
624 
625     switch (conf.ver) {
626         case ver_nhwc:
627             src_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
628             dst_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
629             wei_tag = conf.with_groups ? utils::pick(conf.ndims - 3, gOIw16o16i,
630                               gOIhw16o16i, gOIdhw16o16i)
631                                        : utils::pick(conf.ndims - 3, OIw16o16i,
632                                                OIhw16o16i, OIdhw16o16i);
633             break;
634         case ver_16mb16c:
635             src_tag = utils::pick(
636                     conf.ndims - 3, NCw16n16c, NChw16n16c, NCdhw16n16c);
637             dst_tag = utils::pick(
638                     conf.ndims - 3, NCw16n16c, NChw16n16c, NCdhw16n16c);
639             wei_tag = conf.is_depthwise
640                     ? utils::pick(conf.ndims - 3, Goiw16g, Goihw16g, Goidhw16g)
641                     : (conf.with_groups ? utils::pick(conf.ndims - 3,
642                                gOIw16o16i, gOIhw16o16i, gOIdhw16o16i)
643                                         : utils::pick(conf.ndims - 3, OIw16o16i,
644                                                 OIhw16o16i, OIdhw16o16i));
645             break;
646         case ver_8ow16c:
647             src_tag = utils::pick(conf.ndims - 3, nCw16c, nChw16c, nCdhw16c);
648             dst_tag = utils::pick(conf.ndims - 3, nCw16c, nChw16c, nCdhw16c);
649             wei_tag = conf.is_depthwise
650                     ? utils::pick(conf.ndims - 3, Goiw16g, Goihw16g, Goidhw16g)
651                     : (conf.with_groups ? utils::pick(conf.ndims - 3,
652                                gOIw16o16i, gOIhw16o16i, gOIdhw16o16i)
653                                         : utils::pick(conf.ndims - 3, OIw16o16i,
654                                                 OIhw16o16i, OIdhw16o16i));
655             break;
656         default: status = status::unimplemented;
657     }
658     if (status != status::success) return status;
659 
660     if (src_mdw.format_kind() == format_kind::any) {
661         conf.src_tag = src_tag;
662     } else {
663         conf.src_tag = src_mdw.matches_one_of_tag(src_tag);
664     }
665     if (conf.src_tag != src_tag) return status::unimplemented;
666 
667     if (weights_mdw.format_kind() == format_kind::any) {
668         conf.wei_tag = wei_tag;
669     } else {
670         conf.wei_tag = weights_mdw.matches_one_of_tag(wei_tag);
671     }
672     if (conf.wei_tag != wei_tag) return status::unimplemented;
673 
674     if (dst_mdw.format_kind() == format_kind::any) {
675         conf.dst_tag = dst_tag;
676     } else {
677         conf.dst_tag = dst_mdw.matches_one_of_tag(dst_tag);
678     }
679     if (conf.dst_tag != dst_tag) return status::unimplemented;
680 
681     conf.is_src_nchw = utils::one_of(src_tag, ncw, nchw, ncdhw);
682     conf.is_src_nhwc = utils::one_of(src_tag, nwc, nhwc, ndhwc);
683 
684     return status::success;
685 }
686 
init_kernel_ctx(compute::kernel_ctx_t & kernel_ctx) const687 status_t gen9_convolution_bwd_data_t::pd_t::init_kernel_ctx(
688         compute::kernel_ctx_t &kernel_ctx) const {
689     kernel_ctx.define_int("IS_DW", conf.is_depthwise);
690     kernel_ctx.define_int("BWD_DATA", 1);
691     kernel_ctx.define_int("G", conf.ngroups);
692     kernel_ctx.define_int("MB", conf.mb);
693     kernel_ctx.define_int("IC", conf.ic);
694     kernel_ctx.define_int("ICB", conf.icb);
695     kernel_ctx.define_int("ID", conf.id);
696     kernel_ctx.define_int("IH", conf.ih);
697     kernel_ctx.define_int("IW", conf.iw);
698     kernel_ctx.define_int("OC", conf.oc);
699     kernel_ctx.define_int("OD", conf.od);
700     kernel_ctx.define_int("OH", conf.oh);
701     kernel_ctx.define_int("OW", conf.ow);
702     kernel_ctx.define_int("KD", conf.kd);
703     kernel_ctx.define_int("KH", conf.kh);
704     kernel_ctx.define_int("KW", conf.kw);
705     kernel_ctx.define_int("SD", conf.stride_d);
706     kernel_ctx.define_int("SH", conf.stride_h);
707     kernel_ctx.define_int("SW", conf.stride_w);
708     kernel_ctx.define_int("PD", conf.f_pad);
709     kernel_ctx.define_int("PH", conf.t_pad);
710     kernel_ctx.define_int("PW", conf.l_pad);
711     kernel_ctx.define_int("PD_R", conf.back_pad);
712     kernel_ctx.define_int("PH_R", conf.b_pad);
713     kernel_ctx.define_int("PW_R", conf.r_pad);
714     kernel_ctx.define_int("DD", conf.dilate_d);
715     kernel_ctx.define_int("DH", conf.dilate_h);
716     kernel_ctx.define_int("DW", conf.dilate_w);
717     kernel_ctx.define_int("OC_PADDED", utils::rnd_up(conf.oc, conf.oc_block));
718     kernel_ctx.define_int("IC_PADDED", utils::rnd_up(conf.ic, conf.ic_block));
719     kernel_ctx.define_int("G_WO_PADDING", conf.ngroups_without_padding);
720     kernel_ctx.define_int("OC_WO_PADDING", conf.oc_without_padding);
721     kernel_ctx.define_int("IC_WO_PADDING", conf.ic_without_padding);
722     kernel_ctx.define_int("MB_BLOCK", conf.mb_block);
723     kernel_ctx.define_int("IH_BLOCK", conf.ih_block);
724     kernel_ctx.define_int("IW_BLOCK", conf.iw_block);
725     kernel_ctx.define_int("IWB", utils::div_up(conf.iw, conf.iw_block));
726     kernel_ctx.define_int("SUB_GROUP_SIZE", conf.sub_group_size);
727     kernel_ctx.define_int("OC_BLOCK", conf.oc_block);
728     kernel_ctx.define_int("IC_BLOCK", conf.ic_block);
729     kernel_ctx.define_int("WITH_BIAS", conf.with_bias);
730 
731     kernel_ctx.define_int("GWS_0", conf.gws_d[0]);
732     kernel_ctx.define_int("GWS_1", conf.gws_d[1]);
733     kernel_ctx.define_int("GWS_2", conf.gws_d[2]);
734 
735     kernel_ctx.define_int("GWS_ORIG_0", conf.gws_orig_d[0]);
736     kernel_ctx.define_int("GWS_ORIG_1", conf.gws_orig_d[1]);
737     kernel_ctx.define_int("GWS_ORIG_2", conf.gws_orig_d[2]);
738 
739     kernel_ctx.define_int("LWS_0", conf.lws_d[0]);
740     kernel_ctx.define_int("LWS_1", conf.lws_d[1]);
741     kernel_ctx.define_int("LWS_2", conf.lws_d[2]);
742 
743     kernel_ctx.set_data_type(conf.src_data_type);
744 
745     switch (conf.ver) {
746         case ver_16mb16c: kernel_ctx.define_int("VER_16MB16C", 1); break;
747         case ver_8ow16c: kernel_ctx.define_int("VER_8OW16C", 1); break;
748         default: break;
749     }
750 
751     kernel_ctx.add_option("-cl-std=CL2.0");
752 
753     return status::success;
754 }
755 
execute_backward_data(const exec_ctx_t & ctx) const756 status_t gen9_convolution_bwd_data_t::execute_backward_data(
757         const exec_ctx_t &ctx) const {
758 
759     auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
760     auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS);
761     auto &diff_src = CTX_OUT_STORAGE(DNNL_ARG_DIFF_SRC);
762     auto &bias = CTX_IN_STORAGE(DNNL_ARG_BIAS);
763 
764     const auto &conf = pd()->conf;
765 
766     compute::kernel_arg_list_t arg_list;
767     arg_list.set(0, diff_src);
768     arg_list.set(1, weights);
769     arg_list.set(2, diff_dst);
770     arg_list.set(3, bias);
771 
772     auto nd_range = compute::nd_range_t(conf.gws_d, conf.lws_d);
773 
774     status_t status = parallel_for(ctx, nd_range, kernel_, arg_list);
775 
776     return status;
777 }
778 
init_conf(engine_t * engine)779 status_t gen9_convolution_bwd_weights_t::pd_t::init_conf(engine_t *engine) {
780     using namespace dnnl::impl::format_tag;
781     using namespace data_type;
782 
783     const convolution_desc_t &cd = *desc();
784     const memory_desc_wrapper src_mdw(src_md());
785     const memory_desc_wrapper weights_mdw(diff_weights_md());
786     const memory_desc_wrapper dst_mdw(diff_dst_md());
787     const memory_desc_wrapper bias_mdw(diff_weights_md(1));
788 
789     set_default_conf(conf, cd, *src_md(), *diff_weights_md(), *diff_dst_md(),
790             *diff_weights_md(1), *attr());
791 
792     const bool is_nhwc
793             = src_mdw.matches_one_of_tag(nwc, nhwc, ndhwc) != format_tag::undef
794             || dst_mdw.matches_one_of_tag(nwc, nhwc, ndhwc)
795                     != format_tag::undef;
796 
797     const bool is_1stconv = conf.ic_without_padding == 3;
798     const bool is_depthwise = conf.with_groups && (conf.ic_without_padding == 1)
799             && (conf.oc_without_padding == 1);
800 
801     conf.is_nhwc = is_nhwc;
802     conf.is_depthwise = is_depthwise;
803 
804     if (is_1stconv || (conf.with_groups && conf.ngroups > 1) || is_depthwise
805             || is_nhwc) {
806         conf.ic = conf.ic_without_padding;
807         conf.oc = is_1stconv ? utils::rnd_up(conf.oc_without_padding, 16)
808                              : conf.oc_without_padding;
809     } else {
810         conf.ic = utils::rnd_up(conf.ic_without_padding, 16);
811         conf.oc = utils::rnd_up(conf.oc_without_padding, 16);
812     }
813 
814     conf.ngroups_without_padding = conf.ngroups;
815     if (is_depthwise && !is_nhwc)
816         conf.ngroups = utils::rnd_up(conf.ngroups, 16);
817     const bool is_dw_16g = (conf.is_depthwise && conf.ngroups % 16 == 0);
818 
819     const bool is_16ic = conf.ic % 16 == 0;
820     const bool is_16oc = conf.oc % 16 == 0;
821     const bool use_16mb_unroll = !is_nhwc
822             && !(conf.mb == 1 || conf.mb % 16 != 0) && !is_1stconv
823             && ((is_16ic && is_16oc) || is_dw_16g);
824 
825     conf.mb_block = 1;
826     conf.oc_block = 1;
827     conf.ic_block = 1;
828     conf.od_block = 1;
829     conf.oh_block = 1;
830     conf.ow_block = 1;
831     conf.osp_chunk = 1;
832     conf.mb_chunk = 1;
833     if (is_nhwc)
834         conf.ver = ver_nhwc;
835     else if (use_16mb_unroll)
836         conf.ver = ver_16mb16c;
837     else if (conf.mb % 16 != 0 && ((is_16oc && is_16ic) || is_dw_16g))
838         conf.ver = ver_8ow16c;
839     else if (is_1stconv && is_16oc)
840         conf.ver = ver_1stconv;
841     else
842         return status::unimplemented;
843 
844     switch (conf.ver) {
845         case ver_1stconv:
846         case ver_8ow16c:
847         case ver_nhwc:
848             conf.mb_block = 1;
849             conf.oc_block = 16;
850             conf.ic_block = is_1stconv ? 1 : 16;
851             conf.ow_block = 8;
852             break;
853         case ver_16mb16c:
854             conf.mb_block = 16;
855             conf.oc_block = 16;
856             conf.ic_block = 16;
857             conf.ow_block = 1;
858             break;
859     }
860 
861     bwd_w_compute_block_sizes(conf, engine);
862 
863     auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
864 
865     //TODO: Fix Gtests and reenable
866     const bool is_xe_hp_plus
867             = compute_engine->is_xe_hp() || compute_engine->is_xe_hpg();
868     const bool has_non_uniform_wg
869             = compute_engine->mayiuse_non_uniform_work_groups();
870 
871     conf.sub_group_size = 16;
872     conf.lws_d[0] = is_xe_hp_plus ? 32 : 16;
873     conf.lws_d[1] = 1;
874     conf.lws_d[2] = 1;
875 
876     if (conf.is_depthwise) {
877         conf.gws_d[0] = utils::rnd_up(conf.ngroups, 16);
878     } else {
879         conf.gws_d[0] = is_1stconv ? conf.ocb * conf.ngroups
880                                    : conf.ocb * (conf.icb / 16) * conf.ngroups;
881     }
882     conf.gws_d[1] = is_1stconv && !is_nhwc
883             ? utils::div_up(conf.kh * conf.kw * conf.kd * conf.ic, 16)
884             : conf.kh * conf.kw * conf.kd;
885     conf.gws_d[2] = conf.nchunk * utils::div_up(conf.ic, conf.icb)
886             * utils::div_up(conf.oc, conf.ocb);
887 
888     maybe_fix_non_uniform_work_sizes(has_non_uniform_wg, conf);
889 
890     format_tag_t src_tag, dst_tag, wei_tag;
891 
892     switch (conf.ver) {
893         case ver_nhwc:
894             src_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
895             dst_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
896             if (is_1stconv) {
897                 wei_tag = conf.with_groups ? utils::pick(
898                                   conf.ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o)
899                                            : utils::pick(conf.ndims - 3, Owi16o,
900                                                    Ohwi16o, Odhwi16o);
901             } else if (conf.is_depthwise) {
902                 wei_tag = utils::pick(
903                         conf.ndims - 3, Goiw16g, Goihw16g, Goidhw16g);
904             } else {
905                 wei_tag = conf.with_groups
906                         ? utils::pick(conf.ndims - 3, gIOw16i16o, gIOhw16i16o,
907                                 gIOdhw16i16o)
908                         : utils::pick(conf.ndims - 3, IOw16i16o, IOhw16i16o,
909                                 IOdhw16i16o);
910             }
911             break;
912         case ver_1stconv:
913             assert(!conf.is_depthwise);
914             src_tag = utils::pick(conf.ndims - 3, ncw, nchw, ncdhw);
915             dst_tag = utils::pick(conf.ndims - 3, nCw16c, nChw16c, nCdhw16c);
916             wei_tag = conf.with_groups
917                     ? utils::pick(conf.ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o)
918                     : utils::pick(conf.ndims - 3, Owi16o, Ohwi16o, Odhwi16o);
919             break;
920         case ver_16mb16c:
921             src_tag = utils::pick(
922                     conf.ndims - 3, NCw16n16c, NChw16n16c, NCdhw16n16c);
923             dst_tag = utils::pick(
924                     conf.ndims - 3, NCw16n16c, NChw16n16c, NCdhw16n16c);
925             wei_tag = conf.is_depthwise
926                     ? utils::pick(conf.ndims - 3, Goiw16g, Goihw16g, Goidhw16g)
927                     : (conf.with_groups ? utils::pick(conf.ndims - 3,
928                                gIOw16i16o, gIOhw16i16o, gIOdhw16i16o)
929                                         : utils::pick(conf.ndims - 3, IOw16i16o,
930                                                 IOhw16i16o, IOdhw16i16o));
931             break;
932         case ver_8ow16c:
933             src_tag = utils::pick(conf.ndims - 3, nCw16c, nChw16c, nCdhw16c);
934             dst_tag = utils::pick(conf.ndims - 3, nCw16c, nChw16c, nCdhw16c);
935             wei_tag = conf.is_depthwise
936                     ? utils::pick(conf.ndims - 3, Goiw16g, Goihw16g, Goidhw16g)
937                     : (conf.with_groups ? utils::pick(conf.ndims - 3,
938                                gIOw16i16o, gIOhw16i16o, gIOdhw16i16o)
939                                         : utils::pick(conf.ndims - 3, IOw16i16o,
940                                                 IOhw16i16o, IOdhw16i16o));
941             break;
942         default: return status::unimplemented;
943     }
944 
945     if (src_mdw.format_kind() == format_kind::any) {
946         conf.src_tag = src_tag;
947     } else {
948         conf.src_tag = src_mdw.matches_one_of_tag(src_tag);
949     }
950     if (conf.src_tag != src_tag) return status::unimplemented;
951 
952     if (weights_mdw.format_kind() == format_kind::any) {
953         conf.wei_tag = wei_tag;
954     } else {
955         conf.wei_tag = weights_mdw.matches_one_of_tag(wei_tag);
956     }
957     if (conf.wei_tag != wei_tag) return status::unimplemented;
958 
959     if (dst_mdw.format_kind() == format_kind::any) {
960         conf.dst_tag = dst_tag;
961     } else {
962         conf.dst_tag = dst_mdw.matches_one_of_tag(dst_tag);
963     }
964     if (conf.dst_tag != dst_tag) return status::unimplemented;
965 
966     conf.is_src_nchw = utils::one_of(src_tag, ncw, nchw, ncdhw);
967     conf.is_src_nhwc = utils::one_of(src_tag, nwc, nhwc, ndhwc);
968 
969     bool ok = set_default_formats_common(
970             conf.src_tag, conf.wei_tag, conf.dst_tag);
971     if (!ok) return status::unimplemented;
972     if (is_1stconv && !is_nhwc) {
973         if (data_type::bf16 == conf.weights_data_type) {
974             conf.reorder_wei = true;
975             auto temp_wei_md = *diff_weights_md();
976             temp_wei_md.data_type = data_type::f32;
977 
978             primitive_attr_t r_attr(default_attr());
979             if (!r_attr.is_initialized()) return status::out_of_memory;
980 
981             CHECK(reorder_primitive_desc_create(rpd_wei_, engine, &temp_wei_md,
982                     diff_weights_md(), &r_attr));
983         }
984 
985         if (conf.with_bias && data_type::bf16 == conf.bias_data_type) {
986             conf.reorder_bias = true;
987             auto temp_bias_md = *diff_weights_md(1);
988             temp_bias_md.data_type = data_type::f32;
989             primitive_attr_t r_attr(default_attr());
990             if (!r_attr.is_initialized()) return status::out_of_memory;
991 
992             CHECK(reorder_primitive_desc_create(rpd_bia_, engine, &temp_bias_md,
993                     diff_weights_md(1), &r_attr));
994         }
995     }
996 
997     return status::success;
998 }
999 
init_kernel_ctx(compute::kernel_ctx_t & kernel_ctx) const1000 status_t gen9_convolution_bwd_weights_t::pd_t::init_kernel_ctx(
1001         compute::kernel_ctx_t &kernel_ctx) const {
1002     kernel_ctx.define_int("IS_DW", conf.is_depthwise);
1003     kernel_ctx.define_int("BWD_WEIGHTS", 1);
1004     kernel_ctx.define_int("G", conf.ngroups);
1005     kernel_ctx.define_int("MB", conf.mb);
1006     kernel_ctx.define_int("IC", conf.ic);
1007     kernel_ctx.define_int("ICB", conf.icb);
1008     kernel_ctx.define_int("ID", conf.id);
1009     kernel_ctx.define_int("IH", conf.ih);
1010     kernel_ctx.define_int("IW", conf.iw);
1011     kernel_ctx.define_int("OC", conf.oc);
1012     kernel_ctx.define_int("OCB", conf.ocb);
1013     kernel_ctx.define_int("OD", conf.od);
1014     kernel_ctx.define_int("OH", conf.oh);
1015     kernel_ctx.define_int("OW", conf.ow);
1016     kernel_ctx.define_int("KD", conf.kd);
1017     kernel_ctx.define_int("KH", conf.kh);
1018     kernel_ctx.define_int("KW", conf.kw);
1019     kernel_ctx.define_int("SD", conf.stride_d);
1020     kernel_ctx.define_int("SH", conf.stride_h);
1021     kernel_ctx.define_int("SW", conf.stride_w);
1022     kernel_ctx.define_int("PD", conf.f_pad);
1023     kernel_ctx.define_int("PH", conf.t_pad);
1024     kernel_ctx.define_int("PW", conf.l_pad);
1025     kernel_ctx.define_int("PD_R", conf.back_pad);
1026     kernel_ctx.define_int("PH_R", conf.b_pad);
1027     kernel_ctx.define_int("PW_R", conf.r_pad);
1028     kernel_ctx.define_int("DD", conf.dilate_d);
1029     kernel_ctx.define_int("DH", conf.dilate_h);
1030     kernel_ctx.define_int("DW", conf.dilate_w);
1031     kernel_ctx.define_int("OC_PADDED", conf.oc);
1032     kernel_ctx.define_int("OC_WO_PADDING", conf.oc_without_padding);
1033     kernel_ctx.define_int("G_WO_PADDING", conf.ngroups_without_padding);
1034 
1035     kernel_ctx.define_int("OW_BLOCK", conf.ow_block);
1036     kernel_ctx.define_int("ODB", conf.odb);
1037     kernel_ctx.define_int("OHB", conf.ohb);
1038     kernel_ctx.define_int("OWB", conf.owb);
1039 
1040     kernel_ctx.define_int("WITH_BIAS", conf.with_bias);
1041     kernel_ctx.define_int("SUB_GROUP_SIZE", conf.sub_group_size);
1042     kernel_ctx.define_int("MB_BLOCK", conf.mb_block);
1043     kernel_ctx.define_int("OC_BLOCK", conf.oc_block);
1044     kernel_ctx.define_int("IC_BLOCK", conf.ic_block);
1045     kernel_ctx.define_int("NCHUNK", conf.nchunk);
1046     kernel_ctx.define_int("OSP_CHUNK", conf.osp_chunk);
1047     kernel_ctx.define_int("MB_CHUNK", conf.mb_chunk);
1048     kernel_ctx.define_int(
1049             "MB_CHUNK_SIZE", utils::div_up(conf.mb, conf.mb_chunk));
1050     kernel_ctx.define_int("OW_BLOCK", conf.ow_block);
1051 
1052     kernel_ctx.define_int("GWS_0", conf.gws_d[0]);
1053     kernel_ctx.define_int("GWS_1", conf.gws_d[1]);
1054     kernel_ctx.define_int("GWS_2", conf.gws_d[2]);
1055 
1056     kernel_ctx.define_int("GWS_ORIG_0", conf.gws_orig_d[0]);
1057     kernel_ctx.define_int("GWS_ORIG_1", conf.gws_orig_d[1]);
1058     kernel_ctx.define_int("GWS_ORIG_2", conf.gws_orig_d[2]);
1059 
1060     kernel_ctx.define_int("LWS_0", conf.lws_d[0]);
1061     kernel_ctx.define_int("LWS_1", conf.lws_d[1]);
1062     kernel_ctx.define_int("LWS_2", conf.lws_d[2]);
1063 
1064     kernel_ctx.add_option("-cl-std=CL2.0");
1065 
1066     kernel_ctx.set_data_type(data_type::f32);
1067     def_data_type(kernel_ctx, src_md()->data_type, "SRC");
1068 
1069     def_data_type(kernel_ctx, diff_dst_md()->data_type, "DST");
1070 
1071     def_data_type(kernel_ctx,
1072             diff_weights_md(conf.with_bias ? 1 : 0)->data_type, "BIA");
1073 
1074     def_data_type(kernel_ctx, data_type::f32, "WEI");
1075 
1076     switch (conf.ver) {
1077         case ver_16mb16c: kernel_ctx.define_int("VER_16MB16C", 1); break;
1078         case ver_1stconv:
1079         case ver_8ow16c: kernel_ctx.define_int("VER_8OW16C", 1); break;
1080         default: break;
1081     }
1082 
1083     return status::success;
1084 }
1085 
init_scratchpad()1086 status_t gen9_convolution_bwd_weights_t::pd_t::init_scratchpad() {
1087     auto scratchpad = scratchpad_registry().registrar();
1088     if (!conf.reorder_wei && !conf.reorder_bias) return status::success;
1089     if (conf.reorder_wei) {
1090         auto temp_wei_md = *diff_weights_md();
1091         temp_wei_md.data_type = data_type::f32;
1092         memory_desc_wrapper wei_md_d(temp_wei_md);
1093         scratchpad.book(memory_tracking::names::key_conv_bwd_w_1st_wei_reorder,
1094                 wei_md_d.size(), 1, OCL_BUFFER_ALIGNMENT);
1095         scratchpad.book(memory_tracking::names::key_nested_multiple,
1096                 rpd_wei_->scratchpad_registry());
1097     }
1098     if (!conf.reorder_bias) return status::success;
1099     auto temp_bias_md = *diff_weights_md(1);
1100     temp_bias_md.data_type = data_type::f32;
1101     memory_desc_wrapper bia_md_d(temp_bias_md);
1102     scratchpad.book(memory_tracking::names::key_conv_bwd_w_1st_bia_reorder,
1103             bia_md_d.size(), 1, OCL_BUFFER_ALIGNMENT);
1104     scratchpad.book(memory_tracking::names::key_nested_multiple + 1,
1105             rpd_bia_->scratchpad_registry());
1106 
1107     return status::success;
1108 }
1109 
execute_forward(const exec_ctx_t & ctx) const1110 status_t gen9_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
1111 
1112     auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
1113     auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS);
1114     auto &bias = CTX_IN_STORAGE(DNNL_ARG_BIAS);
1115     auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST);
1116 
1117     const auto &conf = pd()->conf;
1118 
1119     compute::kernel_arg_list_t arg_list;
1120     arg_list.set(0, src);
1121     arg_list.set(1, weights);
1122     arg_list.set(2, bias);
1123     arg_list.set(3, dst);
1124     append_post_ops_to_arg_list(ctx, arg_list, 4, pd()->attr()->post_ops_);
1125 
1126     auto nd_range = compute::nd_range_t(conf.gws_d, conf.lws_d);
1127 
1128     status_t status = parallel_for(ctx, nd_range, kernel_, arg_list);
1129 
1130     if (!post_ops_preserves_zeroes(ctx, pd()->attr()->post_ops_)) {
1131         ctx.zero_pad_output(DNNL_ARG_DST);
1132     }
1133     return status;
1134 }
1135 
// Launches the gen9 backward-weights kernel.
//
// When conf.reorder_wei / conf.reorder_bias are set (by init_conf for the
// blocked 1st-conv path with bf16 outputs), the kernel writes into f32
// scratchpad buffers (booked in init_scratchpad). Those buffers are then
// converted into the user's diff_weights / diff_bias by nested reorder
// primitives. Otherwise the kernel writes the user outputs directly.
//
// NOTE: statement order matters here:
//   1. zero-fill the outputs;
//   2. build the kernel arguments that reference the scratch storages;
//   3. run the kernel;
//   4. move ownership of the storages into memory_t wrappers;
//   5. run the reorders.
execute_backward_weights(const exec_ctx_t & ctx) const1136 status_t gen9_convolution_bwd_weights_t::execute_backward_weights(
        const exec_ctx_t &ctx) const {
    auto *compute_stream
            = utils::downcast<compute::compute_stream_t *>(ctx.stream());

    auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
    auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
    auto &diff_weights = CTX_OUT_STORAGE(DNNL_ARG_DIFF_WEIGHTS);
    auto &diff_bias = CTX_OUT_STORAGE(DNNL_ARG_DIFF_BIAS);

    const auto &conf = pd()->conf;

    const uint8_t zero = 0;
    std::unique_ptr<memory_t> wspace_wei;
    std::unique_ptr<memory_t> wspace_bia;
    // Descriptors of what the kernel actually writes: the user descriptors,
    // switched to f32 below when a reorder is required.
    auto temp_wei_md = *pd()->diff_weights_md();
    auto temp_bia_md = *pd()->diff_weights_md(1);
    std::unique_ptr<memory_storage_t> wspace_ptr_wei;
    std::unique_ptr<memory_storage_t> wspace_ptr_bia;
    if (conf.reorder_wei) {
        wspace_ptr_wei = ctx.get_scratchpad_grantor().get_memory_storage(
                memory_tracking::names::key_conv_bwd_w_1st_wei_reorder);

        temp_wei_md.data_type = data_type::f32;
    }
    if (conf.reorder_bias) {
        wspace_ptr_bia = ctx.get_scratchpad_grantor().get_memory_storage(
                memory_tracking::names::key_conv_bwd_w_1st_bia_reorder);

        temp_bia_md.data_type = data_type::f32;
    }

    // Zero the kernel's destination buffers (user outputs or f32 scratch)
    // before launch.
    // NOTE(review): presumably the kernel accumulates partial results into
    // them (e.g. across NCHUNK chunks); confirm against the kernel source.
    memory_desc_wrapper wei_mdw(temp_wei_md);
    CHECK(compute_stream->fill(
            conf.reorder_wei ? *wspace_ptr_wei : diff_weights, zero,
            wei_mdw.size()));
    if (conf.with_bias) {
        memory_desc_wrapper bia_mdw(temp_bia_md);
        CHECK(compute_stream->fill(
                conf.reorder_bias ? *wspace_ptr_bia : diff_bias, zero,
                bia_mdw.size()));
    }

    // Kernel arguments:
    //   0 = src
    //   1 = diff_weights (or f32 scratch)
    //   2 = diff_bias (or f32 scratch)
    //   3 = diff_dst
    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, src);
    arg_list.set(1, conf.reorder_wei ? *wspace_ptr_wei : diff_weights);
    arg_list.set(2, conf.reorder_bias ? *wspace_ptr_bia : diff_bias);
    arg_list.set(3, diff_dst);

    status_t status = parallel_for(ctx,
            compute::nd_range_t(conf.gws_d, conf.lws_d), kernel_, arg_list);
    if (status != status::success) return status;
    // Runs the nested reorder `prim` from `in` to `out`. The reorder uses the
    // nested scratchpad booked under key_nested_multiple + r_num
    // (0 = weights, 1 = bias, matching init_scratchpad).
    auto exec_reorder = [&](memory_t *in, memory_t *out,
                                const std::shared_ptr<primitive_t> &prim,
                                int r_num) -> status_t {
        exec_args_t r_args;
        r_args[DNNL_ARG_FROM] = memory_arg_t {in, true};
        r_args[DNNL_ARG_TO] = memory_arg_t {out, false};
        exec_ctx_t r_ctx(ctx, std::move(r_args));
        nested_scratchpad_t ns(
                ctx, memory_tracking::names::key_nested_multiple + r_num, prim);
        r_ctx.set_scratchpad_grantor(ns.grantor());
        return prim->execute(r_ctx);
    };

    // Wrap each f32 scratch storage in a memory_t (the unique_ptr ownership
    // is moved into it) and reorder it into the user's output.
    if (conf.reorder_wei) {
        CHECK(safe_ptr_assign(wspace_wei,
                new memory_t(ctx.stream()->engine(), &temp_wei_md,
                        std::move(wspace_ptr_wei))));
        CHECK(exec_reorder(wspace_wei.get(), ctx.output(DNNL_ARG_DIFF_WEIGHTS),
                wei_reorder_, 0));
    }
    if (conf.reorder_bias) {
        CHECK(safe_ptr_assign(wspace_bia,
                new memory_t(ctx.stream()->engine(), &temp_bia_md,
                        std::move(wspace_ptr_bia))));
        CHECK(exec_reorder(wspace_bia.get(), ctx.output(DNNL_ARG_DIFF_BIAS),
                bia_reorder_, 1));
    }

    return status::success;
}
1218 
1219 } // namespace ocl
1220 } // namespace gpu
1221 } // namespace impl
1222 } // namespace dnnl
1223 
1224 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
1225