1 /*******************************************************************************
2 * Copyright 2020-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "gpu/ocl/gen9_wino_convolution.hpp"
18
19 #include "common/c_types_map.hpp"
20 #include "common/dnnl_traits.hpp"
21 #include "common/math_utils.hpp"
22 #include "common/type_helpers.hpp"
23 #include "gpu/compute/device_info.hpp"
24 #include "gpu/ocl/ocl_memory_storage.hpp"
25
26 using namespace dnnl::impl::memory_tracking::names;
27
28 namespace dnnl {
29 namespace impl {
30 namespace gpu {
31 namespace ocl {
32
33 using namespace dnnl::impl::data_type;
34 using namespace dnnl::impl::format_tag;
35
is_impl_optimal(conv_conf_t & conf,const convolution_desc_t & cd,const compute::gpu_arch_t arch)36 static bool is_impl_optimal(conv_conf_t &conf, const convolution_desc_t &cd,
37 const compute::gpu_arch_t arch) {
38 if (cd.alg_kind == alg_kind::convolution_winograd) return true;
39
40 int ow_blocks = conf.wino_ow / conf.ow_block;
41 float ow_util = (float)conf.ow / conf.wino_ow;
42 int oh_blocks = conf.wino_oh / conf.oh_block;
43 float oh_util = (float)conf.oh / conf.wino_oh;
44 int oc_blocks = conf.ocb;
45 float oc_util = (float)conf.oc_without_padding / conf.wino_oc;
46 float ic_util = (float)conf.ic_without_padding / conf.wino_ic;
47
48 int blocks = ow_blocks * oh_blocks * oc_blocks;
49 float utilization = ow_util * oh_util * oc_util * ic_util;
50 float score;
51
52 switch (arch) {
53 case compute::gpu_arch_t::gen9:
54 score = blocks * utilization;
55 if (score >= 128 && utilization >= 0.50) return true;
56 return false;
57 case compute::gpu_arch_t::xe_lp:
58 // Performance is poor with large oc*ic and small spatial, this is
59 // likely due to overflowing cache and no blocking on ic.
60 score = (float)conf.oc * conf.ic / (oh_blocks * ow_blocks);
61 if (score < 32 * 1024 && utilization >= 0.50) return true;
62 return false;
63 default: return false;
64 }
65 }
66
fwd_compute_block_sizes(conv_conf_t & conf,const compute::gpu_arch_t arch)67 static void fwd_compute_block_sizes(
68 conv_conf_t &conf, const compute::gpu_arch_t arch) {
69
70 if (conf.ver == ver_16mb16c) {
71 conf.mb_block = (conf.src_data_type == data_type::f16)
72 ? (conf.mb % 32 == 0 ? 32 : 16)
73 : 16;
74 } else {
75 conf.mb_block = 1;
76 }
77
78 //Using F(m, r) for r = 3 and tile_size = m + r - 1
79 const int m = utils::div_up(conf.oh, 6) < utils::div_up(conf.oh, 4)
80 ? 6
81 : conf.oh > 2 ? 4 : 2;
82 const int r = 3;
83 conf.is_fused = true;
84
85 conf.wino_m = m;
86 conf.wino_r = r;
87 conf.tile_size = m + r - 1;
88
89 conf.vect_size = (arch == compute::gpu_arch_t::gen9)
90 ? static_cast<int>(16 / types::data_type_size(conf.src_data_type))
91 : 8;
92 conf.oc_block = 16;
93 conf.ic_block = nstl::min(conf.ic, 16);
94 if (conf.src_data_type == data_type::f16)
95 conf.wino_ic_block = 32;
96 else if (arch != compute::gpu_arch_t::gen9 && conf.ow * conf.oh <= 256)
97 conf.wino_ic_block = 32;
98 else
99 conf.wino_ic_block = 16;
100
101 conf.ocb = utils::div_up(conf.oc, conf.oc_block);
102
103 if (conf.is_fused) {
104 conf.wino_oc_block = 16;
105 conf.oh_block = conf.wino_m;
106 conf.ow_block = conf.ow > 14 ? 14 : utils::rnd_up(conf.ow, 2);
107 } else {
108 conf.wino_oc_block = 32;
109 conf.oh_block = 8;
110 conf.ow_block = conf.wino_m;
111 }
112
113 // Used for the internal data transform
114 conf.wino_ow = utils::rnd_up(conf.ow, conf.ow_block);
115 conf.wino_iw = conf.wino_ow;
116 conf.wino_oh = utils::rnd_up(conf.oh, conf.oh_block);
117 conf.wino_ih = conf.wino_oh + conf.t_pad + conf.b_pad;
118 conf.wino_ic = utils::rnd_up(conf.ic, conf.wino_ic_block);
119 conf.wino_oc = utils::rnd_up(conf.oc, conf.wino_oc_block);
120 }
121
// Validate the convolution shape/layouts for this Winograd implementation and
// populate `conf` (blocking sizes, kernel dispatch geometry, memory tags).
// Returns status::unimplemented for any case this kernel does not cover.
status_t gen9_wino_convolution_fwd_t::pd_t::init_conf(
        compute::compute_engine_t *engine) {

    const convolution_desc_t &cd = *desc();
    const memory_desc_wrapper src_mdw(src_md());
    const memory_desc_wrapper weights_mdw(weights_md());
    const memory_desc_wrapper dst_mdw(dst_md());
    const memory_desc_wrapper bias_mdw(weights_md(1));

    set_default_conf(conf, cd, *src_md(), *weights_md(), *dst_md(),
            *weights_md(1), *attr());

    // Channel counts are padded up to the 16-wide channel block used below.
    conf.ic = utils::rnd_up(conf.ic_without_padding, 16);
    conf.oc = utils::rnd_up(conf.oc_without_padding, 16);

    // Supported shape: 2D, 3x3 filter, no groups, unit stride, no dilation,
    // and padding strictly smaller than the filter in every direction.
    const bool is_wino_shape = conf.ndims == 4 && conf.kh == 3 && conf.kw == 3
            && conf.ngroups == 1 && conf.stride_h == 1 && conf.stride_w == 1
            && conf.dilate_h == 0 && conf.dilate_w == 0 && conf.l_pad < conf.kw
            && conf.r_pad < conf.kw && conf.t_pad < conf.kh
            && conf.b_pad < conf.kh;
    if (!is_wino_shape) return status::unimplemented;

    const bool is_16oc = conf.oc % 16 == 0;
    const bool is_16ic = conf.ic % 16 == 0;

    // Pick the kernel version from the src/dst layouts: plain nhwc, blocked
    // with minibatch blocking (16mb16c), or blocked without (8ow16c).
    if (src_mdw.matches_one_of_tag(nhwc)
            && (dst_mdw.matches_one_of_tag(nhwc)
                    || dst_mdw.format_kind() == format_kind::any)) {
        // Technically this implementation currently requires ic is a multiple
        // of VTRANS_BLOCK = 4. This condition was not implemented yet due to no
        // known use case, and small IC is expected to have poor performance
        // because of extra work created by the current blocking.
        if (conf.ic_without_padding % 16 != 0
                || conf.oc_without_padding % 16 != 0)
            return status::unimplemented;
        conf.ver = ver_nhwc;
    } else if ((is_16oc && is_16ic)) {
        conf.ver = (conf.mb % 16 == 0) ? ver_16mb16c : ver_8ow16c;
    } else {
        return status::unimplemented;
    }

    // Compute blocking first; the heuristic below consumes the wino_* sizes.
    const compute::gpu_arch_t arch = engine->device_info()->gpu_arch();
    fwd_compute_block_sizes(conf, arch);
    if (!is_impl_optimal(conf, cd, arch)) return status::unimplemented;

    // Element counts of the internal buffers: U (transformed weights) and,
    // for the unfused path, M (multiply output) and V (transformed sources).
    size_t U_sz = conf.tile_size * conf.kh * conf.wino_ic * conf.wino_oc;
    size_t M_sz = 0, V_sz = 0;
    if (!conf.is_fused) {
        M_sz = conf.tile_size * conf.mb * conf.wino_oc * conf.wino_oh
                * conf.wino_ow;
        V_sz = conf.tile_size * conf.mb * conf.wino_ic * conf.wino_ih
                * conf.wino_iw;
    }

    // Limit max problem size since this method uses more memory
    if (U_sz + M_sz + V_sz > 300000000) return status::unimplemented;

    //Using F(m, r) for r = 3 and tile_size = m + r - 1
    // Dispatch geometry: gws_d/lws_d drive the main kernel; U_*, V_*, M_*
    // drive the weight, source, and destination transform kernels.
    if (!conf.is_fused) {
        conf.mb_block = 1;
        conf.lws_d[0] = 8;
        conf.lws_d[1] = 1;
        conf.lws_d[2] = 1;
        conf.gws_d[0] = (conf.wino_oc / conf.wino_oc_block) * conf.lws_d[0];
        conf.gws_d[1] = conf.wino_ow * (conf.wino_oh / conf.oh_block);
        conf.gws_d[2] = (conf.mb / conf.mb_block) * conf.tile_size;

        conf.U_lws_d[0] = 1;
        conf.U_lws_d[1] = 1;
        conf.U_lws_d[2] = 1;
        conf.U_gws_d[0] = 1;
        conf.U_gws_d[1] = 3; // kh or kw depending
        conf.U_gws_d[2] = conf.wino_ic * conf.wino_oc;

        conf.V_lws_d[0] = 1;
        conf.V_lws_d[1] = 1;
        conf.V_lws_d[2] = 1;
        conf.V_gws_d[0] = conf.wino_ow;
        conf.V_gws_d[1] = conf.wino_ih;
        conf.V_gws_d[2] = conf.wino_ic / conf.ic_block * conf.mb;

        conf.M_lws_d[0] = 1;
        conf.M_lws_d[1] = 1;
        conf.M_lws_d[2] = 1;
        conf.M_gws_d[0] = utils::div_up(conf.ow, conf.ow_block);
        conf.M_gws_d[1] = conf.oh;
        conf.M_gws_d[2] = conf.oc / conf.oc_block * conf.mb;
    } else {
        // Fused path: only the main kernel and the weight transform run.
        conf.mb_block = 1;
        conf.lws_d[0] = conf.wino_ic_block / 2;
        conf.lws_d[1] = 8;
        conf.lws_d[2] = 1;
        conf.gws_d[0]
                = utils::div_up(conf.wino_ow, conf.ow_block) * conf.lws_d[0];
        conf.gws_d[1]
                = utils::div_up(conf.wino_oh, conf.oh_block) * conf.lws_d[1];
        conf.gws_d[2] = (conf.mb / conf.mb_block)
                * (conf.wino_oc / conf.wino_oc_block);

        conf.U_lws_d[0] = conf.wino_ic_block / 2;
        conf.U_lws_d[1] = 1;
        conf.U_lws_d[2] = 1;
        conf.U_gws_d[0] = conf.wino_ic * conf.wino_oc / conf.vect_size;
        conf.U_gws_d[1] = 3;
        conf.U_gws_d[2] = 1; // kh or kw depending
    }

    // Expected memory tags for the chosen version.
    format_tag_t src_tag, dst_tag, wei_tag;

    switch (conf.ver) {
        case ver_16mb16c:
            src_tag = NChw16n16c;
            dst_tag = NChw16n16c;
            wei_tag = conf.with_groups ? gOIhw16i16o : OIhw16i16o;
            break;
        case ver_8ow16c:
            src_tag = nChw16c;
            dst_tag = nChw16c;
            wei_tag = conf.with_groups ? gOIhw16i16o : OIhw16i16o;
            break;
        case ver_nhwc:
            src_tag = nhwc;
            dst_tag = nhwc;
            wei_tag = conf.with_groups ? gOIhw16i16o : OIhw16i16o;
            break;
        default: return status::unimplemented;
    }

    // For each tensor: adopt the expected tag when the user left the format
    // as `any`; otherwise require an exact match.
    if (src_mdw.format_kind() == format_kind::any) {
        conf.src_tag = src_tag;
    } else {
        conf.src_tag = src_mdw.matches_one_of_tag(src_tag);
    }
    if (conf.src_tag != src_tag) return status::unimplemented;

    if (weights_mdw.format_kind() == format_kind::any) {
        conf.wei_tag = wei_tag;
    } else {
        conf.wei_tag = weights_mdw.matches_one_of_tag(wei_tag);
    }
    if (conf.wei_tag != wei_tag) return status::unimplemented;

    if (dst_mdw.format_kind() == format_kind::any) {
        conf.dst_tag = dst_tag;
    } else {
        conf.dst_tag = dst_mdw.matches_one_of_tag(dst_tag);
    }
    if (conf.dst_tag != dst_tag) return status::unimplemented;

    return status::success;
}
274
init_scratchpad()275 void gen9_wino_convolution_fwd_t::pd_t::init_scratchpad() {
276 auto scratchpad = scratchpad_registry().registrar();
277
278 auto wei_data_t = this->desc()->weights_desc.data_type;
279 size_t U_sz = conf.tile_size * conf.kh * conf.wino_ic * conf.wino_oc;
280 scratchpad.book(key_wino_U, U_sz, types::data_type_size(wei_data_t),
281 OCL_BUFFER_ALIGNMENT);
282
283 if (!conf.is_fused) {
284 auto dst_data_t = this->desc()->dst_desc.data_type;
285 size_t M_sz = conf.tile_size * conf.mb * conf.wino_oc * conf.wino_oh
286 * conf.wino_ow;
287 scratchpad.book(key_wino_M, M_sz, types::data_type_size(dst_data_t),
288 OCL_BUFFER_ALIGNMENT);
289
290 auto src_data_t = this->desc()->src_desc.data_type;
291 size_t V_sz = conf.tile_size * conf.mb * conf.wino_ic * conf.wino_ih
292 * conf.wino_iw;
293 scratchpad.book(key_wino_V, V_sz, types::data_type_size(src_data_t),
294 OCL_BUFFER_ALIGNMENT);
295 }
296 }
297
init_kernel_ctx(compute::kernel_ctx_t & kernel_ctx) const298 status_t gen9_wino_convolution_fwd_t::pd_t::init_kernel_ctx(
299 compute::kernel_ctx_t &kernel_ctx) const {
300 kernel_ctx.define_int("G", conf.ngroups);
301 kernel_ctx.define_int("MB", conf.mb);
302 kernel_ctx.define_int("IC", conf.ic);
303 kernel_ctx.define_int("ID", conf.id);
304 kernel_ctx.define_int("IH", conf.ih);
305 kernel_ctx.define_int("IW", conf.iw);
306 kernel_ctx.define_int("OC", conf.oc);
307 kernel_ctx.define_int("OD", conf.od);
308 kernel_ctx.define_int("OH", conf.oh);
309 kernel_ctx.define_int("OW", conf.ow);
310 kernel_ctx.define_int("KD", conf.kd);
311 kernel_ctx.define_int("KH", conf.kh);
312 kernel_ctx.define_int("KW", conf.kw);
313 kernel_ctx.define_int("PH", conf.t_pad);
314 kernel_ctx.define_int("PW", conf.l_pad);
315 kernel_ctx.define_int("OCB", conf.ocb);
316 kernel_ctx.define_int("MB_BLOCK", conf.mb_block);
317 kernel_ctx.define_int("OH_BLOCK", conf.oh_block);
318 kernel_ctx.define_int("OW_BLOCK", conf.ow_block);
319 kernel_ctx.define_int("OW_LAST", utils::rnd_dn(conf.ow, conf.ow_block));
320 kernel_ctx.define_int("OWB", utils::div_up(conf.ow, conf.ow_block));
321 kernel_ctx.define_int("OHB", utils::div_up(conf.oh, conf.oh_block));
322 kernel_ctx.define_int("OC_WO_PADDING", conf.oc_without_padding);
323 kernel_ctx.define_int("WINO_M", conf.wino_m);
324 kernel_ctx.define_int("WINO_R", conf.wino_r);
325 kernel_ctx.define_int("WINO_IC_BLOCK", conf.wino_ic_block);
326 kernel_ctx.define_int("WINO_OC_BLOCK", conf.wino_oc_block);
327 kernel_ctx.define_int("WINO_IC", conf.wino_ic);
328 kernel_ctx.define_int("WINO_OC", conf.wino_oc);
329 kernel_ctx.define_int("WINO_IH", conf.wino_ih);
330 kernel_ctx.define_int("WINO_IW", conf.wino_iw);
331 kernel_ctx.define_int("WINO_OH", conf.wino_oh);
332 kernel_ctx.define_int("WINO_OW", conf.wino_ow);
333 kernel_ctx.define_int("OC_BLOCK", conf.oc_block);
334 kernel_ctx.define_int("IC_BLOCK", conf.ic_block);
335 kernel_ctx.define_int("VECT_DT_N", conf.vect_size);
336
337 kernel_ctx.set_data_type(conf.src_data_type);
338
339 kernel_ctx.define_int("VER_8OW16C", conf.ver == ver_8ow16c);
340 kernel_ctx.define_int("VER_16MB16C", conf.ver == ver_16mb16c);
341
342 kernel_ctx.define_int("SRC_NHWC", utils::one_of(conf.src_tag, nhwc));
343 kernel_ctx.define_int(
344 "SRC_16N16C", utils::one_of(conf.src_tag, NChw16n16c));
345 kernel_ctx.define_int("SRC_W16C", utils::one_of(conf.src_tag, nChw16c));
346
347 kernel_ctx.define_int(
348 "WEI_16I16O", utils::one_of(conf.wei_tag, gOIhw16i16o, OIhw16i16o));
349 kernel_ctx.define_int("WEI_16I16O_FLIPPED",
350 utils::one_of(conf.wei_tag, gIOhw16i16o, IOhw16i16o));
351
352 kernel_ctx.define_int("DST_NHWC", utils::one_of(conf.src_tag, nhwc));
353 kernel_ctx.define_int(
354 "DST_16N16C", utils::one_of(conf.dst_tag, NChw16n16c));
355 kernel_ctx.define_int("DST_W16C", utils::one_of(conf.dst_tag, nChw16c));
356
357 kernel_ctx.define_int("WITH_BIAS", conf.with_bias);
358
359 def_attr_info(kernel_ctx, conf.attr_info);
360
361 kernel_ctx.print_options();
362 return status::success;
363 }
364
execute_forward(const exec_ctx_t & ctx) const365 status_t gen9_wino_convolution_fwd_t::execute_forward(
366 const exec_ctx_t &ctx) const {
367
368 auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
369 auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS);
370 auto &bias = CTX_IN_STORAGE(DNNL_ARG_BIAS);
371 auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST);
372
373 const auto &conf = pd()->conf;
374 const auto &attr_info = conf.attr_info;
375
376 std::unique_ptr<memory_storage_t> wei_trans
377 = ctx.get_scratchpad_grantor().get_memory_storage(key_wino_U);
378 compute::kernel_arg_list_t wei_transform_args;
379 wei_transform_args.set(0, *wei_trans);
380 wei_transform_args.set(1, weights);
381 auto wei_trans_nd_range = compute::nd_range_t(conf.U_gws_d, conf.U_lws_d);
382 status_t status = parallel_for(
383 ctx, wei_trans_nd_range, wei_trans_kernel_, wei_transform_args);
384
385 if (conf.is_fused) {
386 compute::kernel_arg_list_t arg_list;
387 arg_list.set(0, dst);
388 arg_list.set(1, src);
389 arg_list.set(2, *wei_trans);
390 arg_list.set(3, bias);
391 append_post_ops_to_arg_list(ctx, arg_list, 4, attr_info.all_post_ops);
392 auto nd_range = compute::nd_range_t(conf.gws_d, conf.lws_d);
393 status = parallel_for(ctx, nd_range, kernel_, arg_list);
394 } else {
395 std::unique_ptr<memory_storage_t> src_trans
396 = ctx.get_scratchpad_grantor().get_memory_storage(key_wino_V);
397 compute::kernel_arg_list_t src_transform_args;
398 src_transform_args.set(0, *src_trans);
399 src_transform_args.set(1, src);
400 auto src_trans_nd_range
401 = compute::nd_range_t(conf.V_gws_d, conf.V_lws_d);
402 status = parallel_for(
403 ctx, src_trans_nd_range, src_trans_kernel_, src_transform_args);
404
405 std::unique_ptr<memory_storage_t> M_buf
406 = ctx.get_scratchpad_grantor().get_memory_storage(key_wino_M);
407 compute::kernel_arg_list_t arg_list;
408 arg_list.set(0, *M_buf);
409 arg_list.set(1, *src_trans);
410 arg_list.set(2, *wei_trans);
411 auto nd_range = compute::nd_range_t(conf.gws_d, conf.lws_d);
412 status = parallel_for(ctx, nd_range, kernel_, arg_list);
413
414 compute::kernel_arg_list_t dst_transform_args;
415 dst_transform_args.set(0, dst);
416 dst_transform_args.set(1, *M_buf);
417 dst_transform_args.set(2, bias);
418 append_post_ops_to_arg_list(
419 ctx, dst_transform_args, 3, attr_info.all_post_ops);
420 auto dst_trans_nd_range
421 = compute::nd_range_t(conf.M_gws_d, conf.M_lws_d);
422 status = parallel_for(
423 ctx, dst_trans_nd_range, dst_trans_kernel_, dst_transform_args);
424 }
425
426 if (attr_info.with_eltwise
427 && !gpu_eltwise_fwd_pd_t::eltwise_preserves_zero(
428 attr_info.eltwise_alg, attr_info.eltwise_alpha,
429 attr_info.eltwise_beta)) {
430 ctx.zero_pad_output(DNNL_ARG_DST);
431 }
432 return status;
433 }
434 } // namespace ocl
435 } // namespace gpu
436 } // namespace impl
437 } // namespace dnnl
438
439 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
440