/*******************************************************************************
* Copyright 2020-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/ocl/gen9_wino_convolution.hpp"

#include "common/c_types_map.hpp"
#include "common/dnnl_traits.hpp"
#include "common/math_utils.hpp"
#include "common/type_helpers.hpp"
#include "gpu/compute/device_info.hpp"
#include "gpu/ocl/ocl_memory_storage.hpp"

using namespace dnnl::impl::memory_tracking::names;

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

using namespace dnnl::impl::data_type;
using namespace dnnl::impl::format_tag;

static bool is_impl_optimal(conv_conf_t &conf, const convolution_desc_t &cd,
        const compute::gpu_arch_t arch) {
    if (cd.alg_kind == alg_kind::convolution_winograd) return true;

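    // When Winograd was not explicitly requested, decide heuristically
    // whether this implementation is likely to be profitable: 'blocks'
    // approximates the amount of parallel work and 'utilization' the
    // fraction of the padded Winograd domain that maps to real
    // (unpadded) data.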
    int ow_blocks = conf.wino_ow / conf.ow_block;
    float ow_util = (float)conf.ow / conf.wino_ow;
    int oh_blocks = conf.wino_oh / conf.oh_block;
    float oh_util = (float)conf.oh / conf.wino_oh;
    int oc_blocks = conf.ocb;
    float oc_util = (float)conf.oc_without_padding / conf.wino_oc;
    float ic_util = (float)conf.ic_without_padding / conf.wino_ic;

    int blocks = ow_blocks * oh_blocks * oc_blocks;
    float utilization = ow_util * oh_util * oc_util * ic_util;
    float score;

    switch (arch) {
        case compute::gpu_arch_t::gen9:
            score = blocks * utilization;
            if (score >= 128 && utilization >= 0.50) return true;
            return false;
        case compute::gpu_arch_t::xe_lp:
            // Performance is poor when oc * ic is large and the spatial
            // dimensions are small, likely because the cache overflows and
            // there is no blocking on ic.
            score = (float)conf.oc * conf.ic / (oh_blocks * ow_blocks);
            if (score < 32 * 1024 && utilization >= 0.50) return true;
            return false;
        default: return false;
    }
}

static void fwd_compute_block_sizes(
        conv_conf_t &conf, const compute::gpu_arch_t arch) {

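    // For the 16mb16c layout, also block over the minibatch: 32 for f16 when
    // mb allows it (presumably to keep a per-block byte footprint comparable
    // to 16 x f32), otherwise 16; other layouts do not block over mb.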
    if (conf.ver == ver_16mb16c) {
        conf.mb_block = (conf.src_data_type == data_type::f16)
                ? (conf.mb % 32 == 0 ? 32 : 16)
                : 16;
    } else {
        conf.mb_block = 1;
    }

    // Using F(m, r) for r = 3 and tile_size = m + r - 1
    const int m = utils::div_up(conf.oh, 6) < utils::div_up(conf.oh, 4)
            ? 6
            : conf.oh > 2 ? 4 : 2;
    const int r = 3;
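    // The fused variant performs the source transform, multiply, and output
    // transform in a single kernel and is currently always selected; the
    // non-fused paths below use separate transform kernels plus extra
    // scratchpad buffers.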
    conf.is_fused = true;

    conf.wino_m = m;
    conf.wino_r = r;
    conf.tile_size = m + r - 1;

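    // vect_size becomes VECT_DT_N in the kernels: on gen9 a vector spans
    // 16 bytes (4 f32 or 8 f16 elements); other architectures use 8.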
    conf.vect_size = (arch == compute::gpu_arch_t::gen9)
            ? static_cast<int>(16 / types::data_type_size(conf.src_data_type))
            : 8;
    conf.oc_block = 16;
    conf.ic_block = nstl::min(conf.ic, 16);
    if (conf.src_data_type == data_type::f16)
        conf.wino_ic_block = 32;
    else if (arch != compute::gpu_arch_t::gen9 && conf.ow * conf.oh <= 256)
        conf.wino_ic_block = 32;
    else
        conf.wino_ic_block = 16;

    conf.ocb = utils::div_up(conf.oc, conf.oc_block);

    if (conf.is_fused) {
        conf.wino_oc_block = 16;
        conf.oh_block = conf.wino_m;
        conf.ow_block = conf.ow > 14 ? 14 : utils::rnd_up(conf.ow, 2);
    } else {
        conf.wino_oc_block = 32;
        conf.oh_block = 8;
        conf.ow_block = conf.wino_m;
    }

    // Used for the internal data transform
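    // For example, with f32 and oh = ow = 56 on the fused path: m = 6,
    // oh_block = 6, ow_block = 14, so wino_oh = 60 and wino_ow = 56.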
    conf.wino_ow = utils::rnd_up(conf.ow, conf.ow_block);
    conf.wino_iw = conf.wino_ow;
    conf.wino_oh = utils::rnd_up(conf.oh, conf.oh_block);
    conf.wino_ih = conf.wino_oh + conf.t_pad + conf.b_pad;
    conf.wino_ic = utils::rnd_up(conf.ic, conf.wino_ic_block);
    conf.wino_oc = utils::rnd_up(conf.oc, conf.wino_oc_block);
}

status_t gen9_wino_convolution_fwd_t::pd_t::init_conf(
        compute::compute_engine_t *engine) {

    const convolution_desc_t &cd = *desc();
    const memory_desc_wrapper src_mdw(src_md());
    const memory_desc_wrapper weights_mdw(weights_md());
    const memory_desc_wrapper dst_mdw(dst_md());
    const memory_desc_wrapper bias_mdw(weights_md(1));

    set_default_conf(conf, cd, *src_md(), *weights_md(), *dst_md(),
            *weights_md(1), *attr());

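    // Round ic/oc up to multiples of 16 to match the 16-channel blocking used
    // throughout (oc_block/ic_block and the blocked memory formats).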
    conf.ic = utils::rnd_up(conf.ic_without_padding, 16);
    conf.oc = utils::rnd_up(conf.oc_without_padding, 16);

    const bool is_wino_shape = conf.ndims == 4 && conf.kh == 3 && conf.kw == 3
            && conf.ngroups == 1 && conf.stride_h == 1 && conf.stride_w == 1
            && conf.dilate_h == 0 && conf.dilate_w == 0 && conf.l_pad < conf.kw
            && conf.r_pad < conf.kw && conf.t_pad < conf.kh
            && conf.b_pad < conf.kh;
    if (!is_wino_shape) return status::unimplemented;

    const bool is_16oc = conf.oc % 16 == 0;
    const bool is_16ic = conf.ic % 16 == 0;

    if (src_mdw.matches_one_of_tag(nhwc)
            && (dst_mdw.matches_one_of_tag(nhwc)
                    || dst_mdw.format_kind() == format_kind::any)) {
        // Technically, this implementation only requires ic to be a multiple
        // of VTRANS_BLOCK = 4. The looser check is not implemented because
        // there is no known use case, and a small IC is expected to perform
        // poorly due to the extra work created by the current blocking.
        if (conf.ic_without_padding % 16 != 0
                || conf.oc_without_padding % 16 != 0)
            return status::unimplemented;
        conf.ver = ver_nhwc;
    } else if (is_16oc && is_16ic) {
        conf.ver = (conf.mb % 16 == 0) ? ver_16mb16c : ver_8ow16c;
    } else {
        return status::unimplemented;
    }

    const compute::gpu_arch_t arch = engine->device_info()->gpu_arch();
    fwd_compute_block_sizes(conf, arch);
    if (!is_impl_optimal(conf, cd, arch)) return status::unimplemented;

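    // U, V, and M follow the usual Winograd naming: U holds the transformed
    // weights, V the transformed source, and M the multiplication result
    // before the output transform. V and M are only needed when not fused.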
    size_t U_sz = conf.tile_size * conf.kh * conf.wino_ic * conf.wino_oc;
    size_t M_sz = 0, V_sz = 0;
    if (!conf.is_fused) {
        M_sz = conf.tile_size * conf.mb * conf.wino_oc * conf.wino_oh
                * conf.wino_ow;
        V_sz = conf.tile_size * conf.mb * conf.wino_ic * conf.wino_ih
                * conf.wino_iw;
    }

    // Limit max problem size since this method uses more memory
    if (U_sz + M_sz + V_sz > 300000000) return status::unimplemented;

    // Using F(m, r) for r = 3 and tile_size = m + r - 1
    if (!conf.is_fused) {
        conf.mb_block = 1;
        conf.lws_d[0] = 8;
        conf.lws_d[1] = 1;
        conf.lws_d[2] = 1;
        conf.gws_d[0] = (conf.wino_oc / conf.wino_oc_block) * conf.lws_d[0];
        conf.gws_d[1] = conf.wino_ow * (conf.wino_oh / conf.oh_block);
        conf.gws_d[2] = (conf.mb / conf.mb_block) * conf.tile_size;

        conf.U_lws_d[0] = 1;
        conf.U_lws_d[1] = 1;
        conf.U_lws_d[2] = 1;
        conf.U_gws_d[0] = 1;
        conf.U_gws_d[1] = 3; // kh or kw, depending on the transform
        conf.U_gws_d[2] = conf.wino_ic * conf.wino_oc;

        conf.V_lws_d[0] = 1;
        conf.V_lws_d[1] = 1;
        conf.V_lws_d[2] = 1;
        conf.V_gws_d[0] = conf.wino_ow;
        conf.V_gws_d[1] = conf.wino_ih;
        conf.V_gws_d[2] = conf.wino_ic / conf.ic_block * conf.mb;

        conf.M_lws_d[0] = 1;
        conf.M_lws_d[1] = 1;
        conf.M_lws_d[2] = 1;
        conf.M_gws_d[0] = utils::div_up(conf.ow, conf.ow_block);
        conf.M_gws_d[1] = conf.oh;
        conf.M_gws_d[2] = conf.oc / conf.oc_block * conf.mb;
    } else {
        conf.mb_block = 1;
        conf.lws_d[0] = conf.wino_ic_block / 2;
        conf.lws_d[1] = 8;
        conf.lws_d[2] = 1;
        conf.gws_d[0]
                = utils::div_up(conf.wino_ow, conf.ow_block) * conf.lws_d[0];
        conf.gws_d[1]
                = utils::div_up(conf.wino_oh, conf.oh_block) * conf.lws_d[1];
        conf.gws_d[2] = (conf.mb / conf.mb_block)
                * (conf.wino_oc / conf.wino_oc_block);

        conf.U_lws_d[0] = conf.wino_ic_block / 2;
        conf.U_lws_d[1] = 1;
        conf.U_lws_d[2] = 1;
        conf.U_gws_d[0] = conf.wino_ic * conf.wino_oc / conf.vect_size;
        conf.U_gws_d[1] = 3;
        conf.U_gws_d[2] = 1; // kh or kw, depending on the transform
    }

    format_tag_t src_tag, dst_tag, wei_tag;

    switch (conf.ver) {
        case ver_16mb16c:
            src_tag = NChw16n16c;
            dst_tag = NChw16n16c;
            wei_tag = conf.with_groups ? gOIhw16i16o : OIhw16i16o;
            break;
        case ver_8ow16c:
            src_tag = nChw16c;
            dst_tag = nChw16c;
            wei_tag = conf.with_groups ? gOIhw16i16o : OIhw16i16o;
            break;
        case ver_nhwc:
            src_tag = nhwc;
            dst_tag = nhwc;
            wei_tag = conf.with_groups ? gOIhw16i16o : OIhw16i16o;
            break;
        default: return status::unimplemented;
    }

    if (src_mdw.format_kind() == format_kind::any) {
        conf.src_tag = src_tag;
    } else {
        conf.src_tag = src_mdw.matches_one_of_tag(src_tag);
    }
    if (conf.src_tag != src_tag) return status::unimplemented;

    if (weights_mdw.format_kind() == format_kind::any) {
        conf.wei_tag = wei_tag;
    } else {
        conf.wei_tag = weights_mdw.matches_one_of_tag(wei_tag);
    }
    if (conf.wei_tag != wei_tag) return status::unimplemented;

    if (dst_mdw.format_kind() == format_kind::any) {
        conf.dst_tag = dst_tag;
    } else {
        conf.dst_tag = dst_mdw.matches_one_of_tag(dst_tag);
    }
    if (conf.dst_tag != dst_tag) return status::unimplemented;

    return status::success;
}

void gen9_wino_convolution_fwd_t::pd_t::init_scratchpad() {
    auto scratchpad = scratchpad_registry().registrar();

    auto wei_data_t = this->desc()->weights_desc.data_type;
    size_t U_sz = conf.tile_size * conf.kh * conf.wino_ic * conf.wino_oc;
    scratchpad.book(key_wino_U, U_sz, types::data_type_size(wei_data_t),
            OCL_BUFFER_ALIGNMENT);

    if (!conf.is_fused) {
        auto dst_data_t = this->desc()->dst_desc.data_type;
        size_t M_sz = conf.tile_size * conf.mb * conf.wino_oc * conf.wino_oh
                * conf.wino_ow;
        scratchpad.book(key_wino_M, M_sz, types::data_type_size(dst_data_t),
                OCL_BUFFER_ALIGNMENT);

        auto src_data_t = this->desc()->src_desc.data_type;
        size_t V_sz = conf.tile_size * conf.mb * conf.wino_ic * conf.wino_ih
                * conf.wino_iw;
        scratchpad.book(key_wino_V, V_sz, types::data_type_size(src_data_t),
                OCL_BUFFER_ALIGNMENT);
    }
}

status_t gen9_wino_convolution_fwd_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
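    // These defines become compile-time macros for the OpenCL kernel source,
    // so each problem shape gets a specialized kernel build.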
    kernel_ctx.define_int("G", conf.ngroups);
    kernel_ctx.define_int("MB", conf.mb);
    kernel_ctx.define_int("IC", conf.ic);
    kernel_ctx.define_int("ID", conf.id);
    kernel_ctx.define_int("IH", conf.ih);
    kernel_ctx.define_int("IW", conf.iw);
    kernel_ctx.define_int("OC", conf.oc);
    kernel_ctx.define_int("OD", conf.od);
    kernel_ctx.define_int("OH", conf.oh);
    kernel_ctx.define_int("OW", conf.ow);
    kernel_ctx.define_int("KD", conf.kd);
    kernel_ctx.define_int("KH", conf.kh);
    kernel_ctx.define_int("KW", conf.kw);
    kernel_ctx.define_int("PH", conf.t_pad);
    kernel_ctx.define_int("PW", conf.l_pad);
    kernel_ctx.define_int("OCB", conf.ocb);
    kernel_ctx.define_int("MB_BLOCK", conf.mb_block);
    kernel_ctx.define_int("OH_BLOCK", conf.oh_block);
    kernel_ctx.define_int("OW_BLOCK", conf.ow_block);
    kernel_ctx.define_int("OW_LAST", utils::rnd_dn(conf.ow, conf.ow_block));
    kernel_ctx.define_int("OWB", utils::div_up(conf.ow, conf.ow_block));
    kernel_ctx.define_int("OHB", utils::div_up(conf.oh, conf.oh_block));
    kernel_ctx.define_int("OC_WO_PADDING", conf.oc_without_padding);
    kernel_ctx.define_int("WINO_M", conf.wino_m);
    kernel_ctx.define_int("WINO_R", conf.wino_r);
    kernel_ctx.define_int("WINO_IC_BLOCK", conf.wino_ic_block);
    kernel_ctx.define_int("WINO_OC_BLOCK", conf.wino_oc_block);
    kernel_ctx.define_int("WINO_IC", conf.wino_ic);
    kernel_ctx.define_int("WINO_OC", conf.wino_oc);
    kernel_ctx.define_int("WINO_IH", conf.wino_ih);
    kernel_ctx.define_int("WINO_IW", conf.wino_iw);
    kernel_ctx.define_int("WINO_OH", conf.wino_oh);
    kernel_ctx.define_int("WINO_OW", conf.wino_ow);
    kernel_ctx.define_int("OC_BLOCK", conf.oc_block);
    kernel_ctx.define_int("IC_BLOCK", conf.ic_block);
    kernel_ctx.define_int("VECT_DT_N", conf.vect_size);

    kernel_ctx.set_data_type(conf.src_data_type);

    kernel_ctx.define_int("VER_8OW16C", conf.ver == ver_8ow16c);
    kernel_ctx.define_int("VER_16MB16C", conf.ver == ver_16mb16c);

    kernel_ctx.define_int("SRC_NHWC", utils::one_of(conf.src_tag, nhwc));
    kernel_ctx.define_int(
            "SRC_16N16C", utils::one_of(conf.src_tag, NChw16n16c));
    kernel_ctx.define_int("SRC_W16C", utils::one_of(conf.src_tag, nChw16c));

    kernel_ctx.define_int(
            "WEI_16I16O", utils::one_of(conf.wei_tag, gOIhw16i16o, OIhw16i16o));
    kernel_ctx.define_int("WEI_16I16O_FLIPPED",
            utils::one_of(conf.wei_tag, gIOhw16i16o, IOhw16i16o));

352     kernel_ctx.define_int("DST_NHWC", utils::one_of(conf.src_tag, nhwc));
    kernel_ctx.define_int(
            "DST_16N16C", utils::one_of(conf.dst_tag, NChw16n16c));
    kernel_ctx.define_int("DST_W16C", utils::one_of(conf.dst_tag, nChw16c));

    kernel_ctx.define_int("WITH_BIAS", conf.with_bias);

    def_attr_info(kernel_ctx, conf.attr_info);

    kernel_ctx.print_options();
    return status::success;
}

status_t gen9_wino_convolution_fwd_t::execute_forward(
        const exec_ctx_t &ctx) const {

    auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
    auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS);
    auto &bias = CTX_IN_STORAGE(DNNL_ARG_BIAS);
    auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST);

    const auto &conf = pd()->conf;
    const auto &attr_info = conf.attr_info;

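    // Both paths first transform the weights into the U scratchpad buffer.
    // The fused path then launches a single kernel over dst/src/U/bias (plus
    // post-op arguments); the non-fused path runs three more kernels: a
    // source transform into V, a multiply of V and U into M, and an output
    // transform from M that applies bias and post-ops.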
    std::unique_ptr<memory_storage_t> wei_trans
            = ctx.get_scratchpad_grantor().get_memory_storage(key_wino_U);
    compute::kernel_arg_list_t wei_transform_args;
    wei_transform_args.set(0, *wei_trans);
    wei_transform_args.set(1, weights);
    auto wei_trans_nd_range = compute::nd_range_t(conf.U_gws_d, conf.U_lws_d);
    status_t status = parallel_for(
            ctx, wei_trans_nd_range, wei_trans_kernel_, wei_transform_args);

    if (conf.is_fused) {
        compute::kernel_arg_list_t arg_list;
        arg_list.set(0, dst);
        arg_list.set(1, src);
        arg_list.set(2, *wei_trans);
        arg_list.set(3, bias);
        append_post_ops_to_arg_list(ctx, arg_list, 4, attr_info.all_post_ops);
        auto nd_range = compute::nd_range_t(conf.gws_d, conf.lws_d);
        status = parallel_for(ctx, nd_range, kernel_, arg_list);
    } else {
        std::unique_ptr<memory_storage_t> src_trans
                = ctx.get_scratchpad_grantor().get_memory_storage(key_wino_V);
        compute::kernel_arg_list_t src_transform_args;
        src_transform_args.set(0, *src_trans);
        src_transform_args.set(1, src);
        auto src_trans_nd_range
                = compute::nd_range_t(conf.V_gws_d, conf.V_lws_d);
        status = parallel_for(
                ctx, src_trans_nd_range, src_trans_kernel_, src_transform_args);

        std::unique_ptr<memory_storage_t> M_buf
                = ctx.get_scratchpad_grantor().get_memory_storage(key_wino_M);
        compute::kernel_arg_list_t arg_list;
        arg_list.set(0, *M_buf);
        arg_list.set(1, *src_trans);
        arg_list.set(2, *wei_trans);
        auto nd_range = compute::nd_range_t(conf.gws_d, conf.lws_d);
        status = parallel_for(ctx, nd_range, kernel_, arg_list);

        compute::kernel_arg_list_t dst_transform_args;
        dst_transform_args.set(0, dst);
        dst_transform_args.set(1, *M_buf);
        dst_transform_args.set(2, bias);
        append_post_ops_to_arg_list(
                ctx, dst_transform_args, 3, attr_info.all_post_ops);
        auto dst_trans_nd_range
                = compute::nd_range_t(conf.M_gws_d, conf.M_lws_d);
        status = parallel_for(
                ctx, dst_trans_nd_range, dst_trans_kernel_, dst_transform_args);
    }

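    // If the eltwise post-op does not map zero to zero, the zero-padded
    // region of the destination is no longer zero after the kernels run,
    // so re-zero it here.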
    if (attr_info.with_eltwise
            && !gpu_eltwise_fwd_pd_t::eltwise_preserves_zero(
                    attr_info.eltwise_alg, attr_info.eltwise_alpha,
                    attr_info.eltwise_beta)) {
        ctx.zero_pad_output(DNNL_ARG_DST);
    }
    return status;
}
} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s