# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,unused-argument,no-else-return, too-many-arguments, too-many-locals, too-many-statements, no-member, too-many-branches, too-many-boolean-expressions
"""conv2d schedule on Intel Graphics"""

from __future__ import absolute_import as _abs

import tvm

from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
from tvm.autotvm.task.topi_integration import deserialize_args
from tvm.autotvm.task import get_config
from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_alter_layout, conv2d_infer_layout
from ..nn.util import get_pad_tuple
from ..nn.depthwise_conv2d import depthwise_conv2d_nchw
from ..nn import pad
from .. import tag
from .. import generic
from .. import util
from ..util import simplify, get_const_tuple


def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False):
    """Fill ``cfg`` with the hand-picked fallback schedule for an NCHW conv2d.

    Sets ``tile_ic`` / ``tile_oc`` (as SplitEntity) and ``block_oh`` /
    ``block_ow`` (as OtherOptionEntity).

    Parameters
    ----------
    cfg : ConfigEntity
        Fallback config to populate.
    data : tvm.Tensor
        4-D input with shape [batch, in_channel, height, width].
    kernel : tvm.Tensor
        4-D weight with shape [out_channel, in_channel, kernel_h, kernel_w].
    strides : int or a list/tuple of two ints
        Convolution stride; only the height stride influences the choice.
    padding, out_dtype :
        Unused; accepted for call-site compatibility.
    is_depthwise : bool
        Depthwise convolution is not supported.

    Raises
    ------
    RuntimeError
        If ``is_depthwise`` is True.
    """
    if is_depthwise:
        raise RuntimeError("Depthwise not supported for intel graphics.")

    _, in_channel, _, _ = get_const_tuple(data.shape)
    out_channel, _, hkernel, _ = get_const_tuple(kernel.shape)
    # Accept a scalar stride as well as an (h, w) pair.
    HSTR = strides[0] if isinstance(strides, (tuple, list)) else strides

    ic_bn = 1
    # Largest divisor of out_channel that does not exceed 16.
    oc_bn_upper = 16
    oc_bn = next(i for i in range(oc_bn_upper, 0, -1) if out_channel % i == 0)

    if HSTR == 2:
        # NOTE(review): 515 presumably targets 512 filters with a 3-high
        # kernel; any pair summing to 515 also matches — confirm intent.
        if out_channel + hkernel == 515:
            block_oh, block_ow = 4, 4
        else:
            block_oh, block_ow = 4, 5
    elif hkernel == 3:
        if out_channel == 512:
            block_oh, block_ow = 2, 7
        else:
            block_oh, block_ow = 2, 14
    else:
        block_oh, block_ow = 1, 16

    cfg["tile_ic"] = SplitEntity([in_channel // ic_bn, ic_bn])
    cfg["tile_oc"] = SplitEntity([out_channel // oc_bn, oc_bn])
    cfg["block_oh"] = OtherOptionEntity(block_oh)
    cfg["block_ow"] = OtherOptionEntity(block_ow)


def _create_schedule_template(cfg, data, kernel, strides, padding, dilation, layout):
    """Create schedule configuration from input arguments.

    Defines four knobs on ``cfg``:

    * ``tile_ic``: divisors of in_channel that are <= 32;
    * ``tile_oc``: divisors of out_channel within [min(out_channel, 8), 64];
    * ``block_oh`` / ``block_ow``: divisors (<= 16) of the output height /
      width, topped up with extra sizes when fewer than five exist.

    Parameters
    ----------
    cfg : ConfigSpace
        The autotvm config space to populate.
    data : tvm.Tensor
        4-D NCHW input.
    kernel : tvm.Tensor
        4-D OIHW weight.
    strides, padding : int or a list/tuple of two ints
        Convolution stride and (symmetric) padding.
    dilation :
        Unused.
    layout : str
        Data layout; only ``'NCHW'`` is supported.

    Raises
    ------
    ValueError
        If ``layout`` is not ``'NCHW'``.
    """
    dshape = get_const_tuple(data.shape)
    kshape = get_const_tuple(kernel.shape)
    if layout != 'NCHW':
        raise ValueError("Not support this layout {} with "
                         "schedule template.".format(layout))
    n, ic, h, w = dshape
    oc, _, kh, kw = kshape

    ph, pw = padding if isinstance(padding, (tuple, list)) else (padding, padding)
    sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides)
    oh = (h - kh + 2 * ph) // sh + 1
    ow = (w - kw + 2 * pw) // sw + 1

    ic_bn_upper = 32
    oc_bn_upper = 64
    oc_bn_lower = min(oc, 8)
    ic_bn_candidates = [i for i in range(1, ic + 1)
                        if ic % i == 0 and i <= ic_bn_upper]
    if not ic_bn_candidates:
        ic_bn_candidates = [1, ic]
    oc_bn_candidates = [i for i in range(1, oc + 1)
                        if oc % i == 0 and oc_bn_lower <= i <= oc_bn_upper]
    if not oc_bn_candidates:
        oc_bn_candidates = [1, oc]

    blk_candidates_low_limits = 5
    # Divisors <= 16 of each output dimension, largest first. Each dimension
    # is scanned on its own: a lockstep walk over both would stop after
    # min(oh, ow) steps and miss the small divisors of the larger one.
    blk_oh_list = [i for i in range(min(oh, 16), 0, -1) if oh % i == 0]
    blk_ow_list = [j for j in range(min(ow, 16), 0, -1) if ow % j == 0]

    # Top up short candidate lists (e.g. for prime output sizes).
    if len(blk_oh_list) < blk_candidates_low_limits:
        for i in range(2, oh):
            if i not in blk_oh_list:
                blk_oh_list.append(i)
                if len(blk_oh_list) >= blk_candidates_low_limits:
                    break

    if len(blk_ow_list) < blk_candidates_low_limits:
        for i in range(min(ow - 1, 16), 1, -1):
            if i not in blk_ow_list:
                blk_ow_list.append(i)
                if len(blk_ow_list) >= blk_candidates_low_limits:
                    break

    # Create schedule config
    cfg.define_knob("tile_ic", ic_bn_candidates)
    cfg.define_knob("tile_oc", oc_bn_candidates)
    cfg.define_knob("block_oh", blk_oh_list)
    cfg.define_knob("block_ow", blk_ow_list)


##### SCHEDULE UTILITIES #####
def tile_and_bind3d(s, tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None):
    """Tile three axes of ``tensor`` and bind them to the GPU grid.

    Each axis is split by its factor. The outer parts are bound to
    ``blockIdx.z/y/x`` and the inner parts to ``threadIdx.z/y/x``.
    ``y_factor`` defaults to ``z_factor`` and ``x_factor`` to ``y_factor``.

    Returns
    -------
    tuple
        ``(xi, thread_z, thread_y, thread_x)``: the inner x axis and the three
        thread IterVars that were bound.
    """
    y_factor = y_factor or z_factor
    x_factor = x_factor or y_factor
    zo, zi = s[tensor].split(z, z_factor)
    yo, yi = s[tensor].split(y, y_factor)
    xo, xi = s[tensor].split(x, x_factor)
    s[tensor].reorder(zo, yo, xo, zi, yi, xi)

    thread_z = tvm.thread_axis((0, z_factor), "threadIdx.z")
    thread_y = tvm.thread_axis((0, y_factor), "threadIdx.y")
    thread_x = tvm.thread_axis((0, x_factor), "threadIdx.x")
    s[tensor].bind(zo, tvm.thread_axis("blockIdx.z"))
    s[tensor].bind(zi, thread_z)
    s[tensor].bind(yo, tvm.thread_axis("blockIdx.y"))
    s[tensor].bind(yi, thread_y)
    s[tensor].bind(xo, tvm.thread_axis("blockIdx.x"))
    s[tensor].bind(xi, thread_x)
    return xi, thread_z, thread_y, thread_x
def _get_tile_size(cfg, key):
    """Return the innermost tile size stored under ``cfg[key]``.

    Config entities created by ``define_knob`` expose the chosen value as
    ``.val``. ``SplitEntity`` objects (used by the fallback config) keep it
    as the last entry of ``.size``.
    """
    entity = cfg[key]
    return entity.val if hasattr(entity, "val") else entity.size[-1]


# Define template function for autotvm task
# We define schedule template in this function instead of
# declaration function since actual input arguments need
# to be altered by the schedule selected.
@autotvm.task.register("topi_intel_graphics_conv2d_NCHWc")
def __topi_nn_conv2d_NCHWc(*args, **kwargs):
    """AutoTVM template: repack NCHW arguments per the sampled config and schedule them.

    Returns the schedule and ``[new_data, new_kernel, C]``. Here ``new_data``
    has shape [N, C // ic_bn, H, W, ic_bn] and ``new_kernel`` has shape
    [O // oc_bn, I // ic_bn, KH, KW, ic_bn, oc_bn].
    """
    assert not kwargs, "Do not support kwargs in template function call"
    data, kernel, strides, padding, dilation, layout, dtype = deserialize_args(args)
    raw_data_shape = get_const_tuple(data.shape)
    raw_kernel_shape = get_const_tuple(kernel.shape)

    # get config here
    cfg = get_config()
    _create_schedule_template(cfg, data, kernel, strides, padding, dilation, layout)
    # NOTE(review): the flop count is a placeholder value of 1.
    cfg.add_flop(1)

    # change shape with the value in config
    ic_bn = _get_tile_size(cfg, "tile_ic")
    oc_bn = _get_tile_size(cfg, "tile_oc")

    new_data_shape = (raw_data_shape[0], raw_data_shape[1] // ic_bn,
                      raw_data_shape[2], raw_data_shape[3], ic_bn)
    new_kernel_shape = (raw_kernel_shape[0] // oc_bn, raw_kernel_shape[1] // ic_bn,
                        raw_kernel_shape[2], raw_kernel_shape[3], ic_bn, oc_bn)
    new_data = tvm.placeholder(new_data_shape, data.dtype)
    new_kernel = tvm.placeholder(new_kernel_shape, kernel.dtype)

    C = _decl_cl_spatialpack_NCHWc(cfg, new_data, new_kernel, strides, padding, dilation, dtype)
    s = _schedule_conv2d_NCHWc(cfg, [C])

    return s, [new_data, new_kernel, C]


@conv2d_alter_layout.register(["intel_graphics"])
def _alter_conv2d_layout(attrs, inputs, tinfo, F):
    """Rewrite an NCHW conv2d into conv2d_NCHWc using the tuned (or fallback) tiles.

    Returns ``None`` (leave the operator unchanged) for non-NCHW layouts,
    grouped convolutions, and depthwise convolutions.
    """
    import nnvm.symbol as sym

    copy_inputs = list(inputs)
    new_attrs = {k: attrs[k] for k in attrs.keys()}

    if F.__name__ == 'tvm.relay.op':
        # Derive channels for frontends (e.g ONNX) that miss "channel" field.
        new_attrs["channels"] = inputs[1].checked_type.shape[attrs['kernel_layout'].index('O')]

    data, kernel = tinfo[0], tinfo[1]
    batch_size, in_channel, height, width = get_const_tuple(data.shape)

    groups = attrs.get_int("groups")
    out_channel = attrs.get_int("channels")
    padding = attrs.get_int_tuple("padding")
    strides = attrs.get_int_tuple("strides")
    dilation = attrs.get_int_tuple("dilation")
    out_dtype = attrs["out_dtype"]

    layout_name = 'layout' if F == sym else 'data_layout'
    layout = attrs[layout_name]
    # Kernel height/width are taken from the kernel tensor's shape below.

    dtype = data.dtype
    out_dtype = dtype if out_dtype in ("same", "") else out_dtype
    is_depthwise = groups == in_channel and groups == out_channel

    # Only dense (groups == 1) NCHW convolutions are rewritten; depthwise
    # convolution is not supported by this backend.
    if layout != 'NCHW':
        return None
    if groups != 1 or is_depthwise:
        return None

    dispatch_ctx = autotvm.task.DispatchContext.current
    target = tvm.target.current_target()

    # query schedule and fallback if necessary
    workload = autotvm.task.args_to_workload(
        [data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
    cfg = dispatch_ctx.query(target, workload)
    if cfg.is_fallback:
        _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise)

    ic_bn = _get_tile_size(cfg, "tile_ic")
    oc_bn = _get_tile_size(cfg, "tile_oc")

    new_attrs[layout_name] = 'NCHW%dc' % ic_bn
    new_attrs['out_layout'] = 'NCHW%dc' % oc_bn

    new_data = tvm.placeholder((batch_size, in_channel // ic_bn, height, width, ic_bn),
                               dtype=data.dtype)

    out_channel, _, kh, kw = get_const_tuple(kernel.shape)
    # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
    new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)

    # Store altered operator's config
    new_kernel = tvm.placeholder((out_channel // oc_bn, in_channel // ic_bn, kh, kw,
                                  ic_bn, oc_bn),
                                 dtype=kernel.dtype)
    new_workload = autotvm.task.args_to_workload(
        [new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
         new_attrs['out_layout'], out_dtype], conv2d_NCHWc)

    dispatch_ctx.update(target, new_workload, cfg)
    if F == sym:
        return F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
    return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)


@autotvm.register_topi_compute(conv2d_NCHWc, 'intel_graphics', 'direct')
def _decl_conv2d(cfg, data, kernel, strides, padding, dilation,
                 layout, out_layout, out_dtype='float32'):
    """Conv2D operator (packed NCHWc layout) for Intel Graphics backend.

    Parameters
    ----------
    data : tvm.Tensor
        5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]

    kernel : tvm.Tensor
        6-D with shape [num_filter_chunk, in_channel_chunk, filter_height,
        filter_width, in_channel_block, num_filter_block]

    strides : int or a list/tuple of two ints
        stride size, or [stride_height, stride_width]

    padding : int or a list/tuple of two ints
        padding size, or [pad_height, pad_width]

    dilation : int or a list/tuple of two ints
        must be 1 (dilation is not supported)

    layout, out_layout : str
        layout of input / output data (not inspected here)

    Returns
    -------
    output : tvm.Tensor
        5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
    """
    dh, dw = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
    assert (dh, dw) == (1, 1), "Does not support dilation"

    n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
    oc_chunk, _, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape)
    in_channel = ic_chunk * ic_bn
    num_filter = oc_chunk * oc_bn
    if cfg.is_fallback:
        # The fallback heuristic works on the unpacked NCHW / OIHW shapes.
        _get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype),
                            tvm.placeholder((num_filter, in_channel,
                                             kernel_height, kernel_width),
                                            dtype=kernel.dtype),
                            strides, padding, out_dtype)

    return _decl_cl_spatialpack_NCHWc(cfg, data, kernel, strides, padding, dilation, out_dtype)


@conv2d_infer_layout.register("intel_graphics")
def _conv2d_infer_layout(workload, cfg):
    """Return ``((in_shape, in_layout),), ((out_shape, out_layout),)`` for ``cfg``."""
    _, data, kernel, strides, padding, dilation, layout, dtype = workload
    # Workload tensors are stored as shape + (dtype,); drop the dtype.
    batch_size, in_channel, in_height, in_width = data[:-1]
    out_channel, _, k_height, k_width = kernel[:-1]
    out_height = (in_height + 2 * padding[0] - k_height) // strides[0] + 1
    out_width = (in_width + 2 * padding[1] - k_width) // strides[1] + 1
    # Knob-based configs expose .val rather than .size; handle both.
    tile_ic = _get_tile_size(cfg, "tile_ic")
    tile_oc = _get_tile_size(cfg, "tile_oc")
    in_shape = (batch_size, in_channel // tile_ic, in_height, in_width, tile_ic)
    in_layout = "NCHW%dc" % tile_ic
    out_shape = (batch_size, out_channel // tile_oc, out_height, out_width, tile_oc)
    out_layout = "NCHW%dc" % tile_oc
    return ((in_shape, in_layout),), ((out_shape, out_layout),)


@autotvm.register_topi_schedule(generic.schedule_conv2d_NCHWc, 'intel_graphics', ['direct'])
def _schedule_conv2d_NCHWc(cfg, outs):
    """Schedule for conv2d_NCHWc for Intel Graphics

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of conv2d_NCHWc
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for conv2d_NCHWc.
    """
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])
    scheduled_ops = []

    def traverse(op):
        """inline all one-to-one-mapping operators except the last stage (output)"""
        if tag.is_injective(op.tag):
            if op not in s.outputs:
                s[op].compute_inline()
            for tensor in op.input_tensors:
                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                    traverse(tensor.op)
        # Matches both the "conv" and "conv_unpack" tags.
        if "conv" in op.tag:
            _schedule_cl_spatialpack_NCHWc(cfg, s, op)

        scheduled_ops.append(op)

    traverse(outs[0].op)

    return s


def _decl_cl_spatialpack_NCHWc(cfg, data, kernel, strides, padding, dilation, out_dtype='float16'):
    """Declare the packed (NCHW[x]c) convolution compute.

    The convolution is computed over an output extent rounded up to a multiple
    of ``cfg["block_oh"]`` x ``cfg["block_ow"]``. When rounding was needed, an
    ``output_unpack`` stage (tag ``conv_unpack``) crops the result back to the
    true output size. Otherwise the ``conv`` stage (tag ``conv``) is returned.
    """
    batch, in_channel, in_height, in_width, vc = [util.get_const_int(x) for x in data.shape]
    in_channel *= vc
    num_filter, channel, kernel_h, kernel_w, ci, co = [util.get_const_int(x) for x in kernel.shape]
    num_filter *= co
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, kernel)

    ic_bn = vc
    assert vc == ci

    if isinstance(strides, (tuple, list)):
        stride_h, stride_w = strides
    else:
        stride_h, stride_w = strides, strides

    out_channel = num_filter
    out_height = simplify((in_height - kernel_h + pad_top + pad_down) // stride_h + 1)
    out_width = simplify((in_width - kernel_w + pad_left + pad_right) // stride_w + 1)
    oshape = (batch, out_channel // co, out_height, out_width, co)

    rc = tvm.reduce_axis((0, in_channel), name='rc')
    ry = tvm.reduce_axis((0, kernel_h), name='ry')
    rx = tvm.reduce_axis((0, kernel_w), name='rx')

    block_h = cfg["block_oh"].val
    block_w = cfg["block_ow"].val

    # Round the computed extent up to a whole number of blocks.
    c_h = out_height
    c_w = out_width

    if out_height % block_h != 0:
        c_h = (out_height // block_h + 1) * block_h

    if out_width % block_w != 0:
        c_w = (out_width // block_w + 1) * block_w

    cshape = (batch, out_channel // co, c_h, c_w, co)

    # Output row yy reads padded input row yy * stride_h + ry, so every extra
    # output row needs stride_h extra padded input rows (same for columns).
    # With stride 1 this is exactly c_h - out_height.
    extra_h = (c_h - out_height) * stride_h
    extra_w = (c_w - out_width) * stride_w
    pad_before = [0, 0, pad_top, pad_left, 0]
    pad_after = [0, 0, pad_down + extra_h, pad_right + extra_w, 0]
    DOPAD = (pad_top != 0 or pad_left != 0 or pad_down + extra_h != 0
             or pad_right + extra_w != 0)
    DOUNPACK = (c_h - out_height != 0 or c_w - out_width != 0)
    if DOPAD:
        temp = pad(data, pad_before, pad_after, name="pad_temp")
    else:
        temp = data

    conv = tvm.compute(
        cshape,
        lambda nn, ff, yy, xx, ff_v: tvm.sum(
            temp[nn, rc // ic_bn, yy * stride_h + ry, xx * stride_w + rx, rc % ic_bn]
            .astype(out_dtype) *
            kernel[ff, rc // ic_bn, ry, rx, rc % ic_bn, ff_v].astype(out_dtype),
            axis=[rc, ry, rx]),
        tag="conv", name='conv')

    if DOUNPACK:
        # Crop the block-aligned result back to the real output size.
        output = tvm.compute(
            oshape,
            lambda nn, ff, yy, xx, ff_v: conv[nn][ff][yy][xx][ff_v],
            name='output_unpack', tag="conv_unpack")
    else:
        output = conv

    return output


def _schedule_conv_and_local(s, conv, conv_L, temp_W, kernel, block_h, block_w):
    """Tile/bind the packed 5-D conv stage and attach its local accumulator.

    The output height and width are split into ``block_h`` x ``block_w``
    blocks, and the packed channel axis is split by 16 onto ``threadIdx.x``.
    ``conv_L`` is computed per thread, and ``temp_W`` is attached at
    ``conv_L``'s ``rc`` axis. The kernel loops are unrolled unless the kernel
    width (``kernel.shape[3]``) is 7.

    Returns
    -------
    tuple
        ``((thread_z, thread_y, thread_x), ry, rx)``: the thread IterVars bound
        on ``conv`` (reused for ``temp_W``) and ``conv_L``'s ry/rx reduce axes.
    """
    z_factor = 1
    y_factor = 1
    x_factor = 16
    thread_z = tvm.thread_axis((0, z_factor), "threadIdx.z")
    thread_y = tvm.thread_axis((0, y_factor), "threadIdx.y")
    thread_x = tvm.thread_axis((0, x_factor), "threadIdx.x")

    # schedule conv
    n, co, oh, ow, vc = s[conv].op.axis
    ooh, ioh = s[conv].split(oh, factor=block_h)
    oow, iow = s[conv].split(ow, factor=block_w)
    s[conv].reorder(n, co, ooh, oow, vc, ioh, iow)
    coo, coi = s[conv].split(co, nparts=1)
    ooho, oohi = s[conv].split(ooh, factor=z_factor)
    oowo, oowi = s[conv].split(oow, factor=y_factor)
    vco, vci = s[conv].split(vc, factor=x_factor)
    s[conv].reorder(n, coo, vco, ooho, oowo, coi, oohi, oowi, vci, ioh, iow)
    s[conv].bind(oohi, thread_z)
    s[conv].bind(oowi, thread_y)
    s[conv].bind(vci, thread_x)
    s[conv].bind(ooho, tvm.thread_axis("blockIdx.z"))
    s[conv].bind(oowo, tvm.thread_axis("blockIdx.y"))
    s[conv].bind(coi, tvm.thread_axis("blockIdx.x"))

    # schedule conv_L
    s[conv_L].compute_at(s[conv], vci)
    li, loc, lh, lw, lvc = s[conv_L].op.axis
    rc, ry, rx = s[conv_L].op.reduce_axis
    s[conv_L].reorder(li, loc, rc, ry, rx, lvc, lh, lw)
    s[temp_W].compute_at(s[conv_L], rc)
    if kernel.shape[3].value != 7:
        s[conv_L].unroll(ry)
        s[conv_L].unroll(rx)

    return (thread_z, thread_y, thread_x), ry, rx


def _bind_warp_cache(s, temp_W, z, y, x, threads):
    """Split ``temp_W``'s z/y/x axes by 1/1/16, bind the inner parts to ``threads``, align rows to 16."""
    thread_z, thread_y, thread_x = threads
    zo, zi = s[temp_W].split(z, 1)
    yo, yi = s[temp_W].split(y, 1)
    xo, xi = s[temp_W].split(x, 16)
    s[temp_W].reorder(zo, yo, xo, zi, yi, xi)
    s[temp_W].bind(zi, thread_z)
    s[temp_W].bind(yi, thread_y)
    s[temp_W].bind(xi, thread_x)
    s[temp_W].storage_align(s[temp_W].op.axis[2], 16, 0)


def _schedule_cl_spatialpack_NCHWc(cfg, s, op):
    """Schedule the stages produced by ``_decl_cl_spatialpack_NCHWc``.

    ``op`` is either the ``output_unpack`` stage (whose input is ``conv``) or
    the ``conv`` stage itself (whose inputs are the data and the kernel).
    """
    output = op.output(0)
    conv = op.input_tensors[0]
    if conv.op.name == "conv":
        # op is output_unpack: schedule conv here and op itself at the end.
        temp = s[conv].op.input_tensors[0]
        kernel = s[conv].op.input_tensors[1]
        temp_W = s.cache_read(temp, "warp", [conv])
        conv_L = s.cache_write(conv, "local")
        SCHEDULE_OUTPUT = True
    else:
        # op is conv itself.
        temp = op.input_tensors[0]
        kernel = op.input_tensors[1]
        temp_W = s.cache_read(temp, "warp", [output])
        conv_L = s.cache_write(output, "local")
        if output.op in s.outputs:
            conv = output
        else:
            # NOTE(review): conv is a reduction; inlining it into a following
            # fused op may not be supported — confirm this path.
            s[output].compute_inline()
            conv = s.outputs[0]
        SCHEDULE_OUTPUT = False
    kernel_L = s.cache_read(kernel, "local", [conv_L])

    OUTPUT_BLOCK_HEIGHT = cfg["block_oh"].val
    OUTPUT_BLOCK_WIDTH = cfg["block_ow"].val

    threads, ry, rx = _schedule_conv_and_local(
        s, conv, conv_L, temp_W, kernel, OUTPUT_BLOCK_HEIGHT, OUTPUT_BLOCK_WIDTH)

    # schedule temp
    if temp.op.name == "pad_temp":
        _, ci, h, w, vci = s[temp].op.axis
        tile_and_bind3d(s, temp, ci, h, w, 1, 16, 16)

    # schedule temp_W (the packed inner-channel axis goes on threadIdx.z)
    _, ci, h, w, vci = s[temp_W].op.axis
    _bind_warp_cache(s, temp_W, vci, h, w, threads)

    # schedule kernel_L
    if OUTPUT_BLOCK_HEIGHT == 2 and OUTPUT_BLOCK_WIDTH == 14:
        s[kernel_L].compute_at(s[conv_L], ry)
    else:
        s[kernel_L].compute_at(s[conv_L], rx)

    # schedule output
    if SCHEDULE_OUTPUT:
        if output.op in s.outputs:
            out = output
        else:
            s[output].compute_inline()
            out = s.outputs[0]

        _, co, h, w, vc = s[out].op.axis
        tile_and_bind3d(s, out, w, h, vc, 4, 8, 8)


def conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype):
    """convert argument to workload

    A 5-D kernel (as produced by alter_op_layout) is first mapped back to the
    4-D shape (shape[0] * shape[4], shape[1], shape[2], shape[3]).
    """
    if len(kernel.shape) == 4:
        raw_kernel = kernel
    else:  # the input kernel is transformed by alter_op_layout
        shape = get_const_tuple(kernel.shape)
        raw_kernel = tvm.placeholder((shape[0] * shape[4], shape[1], shape[2], shape[3]),
                                     dtype=kernel.dtype)
    return ('conv2d', ) + autotvm.task.args_to_workload(
        [data, raw_kernel, strides, padding, layout, out_dtype])


@autotvm.register_topi_compute(conv2d, 'intel_graphics', 'direct')
def decl_conv2d(cfg, data, kernel, stride, padding, dilation, layout='NCHW', out_dtype='float32'):
    """Conv2D operator for Intel Graphics backend.

    Parameters
    ----------
    data : tvm.Tensor
        4-D with shape [batch, in_channel, in_height, in_width]; batch must be 1
    kernel : tvm.Tensor
        4-D with shape [num_filter, in_channel, filter_height, filter_width]
    stride : int or a list/tuple of two ints
        stride size, or [stride_height, stride_width]
    padding : int or a list/tuple of two ints
        padding size, or [pad_height, pad_width]
    dilation :
        not forwarded to the compute (ignored)
    layout : str
        layout of data; only 'NCHW' is supported
    Returns
    -------
    output : tvm.Tensor
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    assert layout == 'NCHW', "only support NCHW convolution on intel gpu"
    assert data.shape[0].value == 1, "only support batch size=1 convolution on intel gpu"
    assert data.dtype == kernel.dtype, "Do not support inputs with different data types now."

    return _decl_cl_spatialpack(cfg, data, kernel, stride, padding, layout, out_dtype)


@autotvm.task.register_topi_schedule(generic.schedule_conv2d_nchw, 'intel_graphics', ['direct'])
def schedule_conv2d_nchw(cfg, outs):
    """Schedule for conv2d_nchw for Intel Graphics

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of conv2d_nchw
        in the format of an array of tensors.
    Returns
    -------
    s: Schedule
        The computation schedule for conv2d_nchw.
    """
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])
    scheduled_ops = []

    def traverse(op):
        """inline all one-to-one-mapping operators except the last stage (output)"""
        if tag.is_broadcast(op.tag):
            if op not in s.outputs:
                s[op].compute_inline()
            for tensor in op.input_tensors:
                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                    traverse(tensor.op)
        if 'conv2d' in op.tag:
            _schedule_cl_spatialpack(cfg, s, op)

        scheduled_ops.append(op)

    traverse(outs[0].op)
    return s


def _decl_cl_spatialpack(cfg, data, kernel, stride, padding, layout, out_dtype='float16'):
    """Declare the NCHW convolution as pad -> packed 5-D conv -> unpack.

    Output channels are packed in groups of 16 (rounded up). Block sizes are
    picked by a fixed heuristic and attached to the ``conv`` stage as the
    ``block_h`` / ``block_w`` attrs. The returned ``output_unpack`` stage has
    tag ``conv2d``.
    """
    batch, in_channel, in_height, in_width = [util.get_const_int(x) for x in data.shape]
    num_filter, channel, kernel_h, kernel_w = [util.get_const_int(x) for x in kernel.shape]
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, kernel)

    if isinstance(stride, (tuple, list)):
        stride_h, stride_w = stride
    else:
        stride_h, stride_w = stride, stride

    out_channel = num_filter
    out_height = simplify((in_height - kernel_h + pad_top + pad_down) // stride_h + 1)
    out_width = simplify((in_width - kernel_w + pad_left + pad_right) // stride_w + 1)
    oshape = (batch, out_channel, out_height, out_width)

    rc = tvm.reduce_axis((0, in_channel), name='rc')
    ry = tvm.reduce_axis((0, kernel_h), name='ry')
    rx = tvm.reduce_axis((0, kernel_w), name='rx')

    if stride_h == 2:
        # NOTE(review): 515 presumably targets 512 filters with a 3-high kernel.
        if num_filter + kernel_h == 515:
            block_h, block_w = 4, 4
        else:
            block_h, block_w = 4, 5
    elif kernel_h == 3:
        if num_filter == 512:
            block_h, block_w = 2, 7
        else:
            block_h, block_w = 2, 14
    elif (kernel_h == 7 and stride_h == 1 and stride_w == 1
          and (pad_top, pad_left, pad_down, pad_right) == (3, 3, 3, 3)):
        # Compare the normalized pad/stride so tuple arguments also match.
        block_h, block_w = 3, 4
    else:
        block_h, block_w = 1, 16
    attrs = {'block_h': block_h, 'block_w': block_w}

    # Round the computed extent up to a whole number of blocks.
    c_h = out_height
    c_w = out_width

    if out_height % block_h != 0:
        c_h = (out_height // block_h + 1) * block_h

    if out_width % block_w != 0:
        c_w = (out_width // block_w + 1) * block_w

    # Covering c_h output rows needs up to (c_h - out_height) * stride_h extra
    # padded input rows (output row yy reads row yy * stride_h + ry). The
    # c_h - block_h amount is kept as a lower bound, so the padded shape only
    # grows where that amount was too small (same for columns).
    extra_h = max(c_h - block_h, (c_h - out_height) * stride_h)
    extra_w = max(c_w - block_w, (c_w - out_width) * stride_w)
    pad_before = [0, 0, pad_top, pad_left]
    pad_after = [0, 0, pad_down + extra_h, pad_right + extra_w]
    temp = pad(data, pad_before, pad_after, name="pad_temp")

    # Pack output channels in groups of nv, rounding num_filter up.
    nv = 16
    if num_filter % nv != 0:
        num_filter = (num_filter // nv + 1) * nv
        out_channel = num_filter

    cshape = (batch, out_channel // nv, c_h, c_w, nv)
    kvshape = (num_filter // nv, channel, kernel_h, kernel_w, nv)

    # NOTE(review): when num_filter was rounded up, co * nv + vc can exceed the
    # real filter count; those reads only feed channels dropped by the unpack
    # below — confirm they are in bounds.
    kernel_vec = tvm.compute(
        kvshape,
        lambda co, ci, kh, kw, vc: kernel[co * nv + vc][ci][kh][kw],
        name='kernel_vec')

    conv = tvm.compute(
        cshape,
        lambda nn, ff, yy, xx, vc: tvm.sum(
            temp[nn, rc, yy * stride_h + ry, xx * stride_w + rx].astype(out_dtype) *
            kernel_vec[ff, rc, ry, rx, vc].astype(out_dtype),
            axis=[rc, ry, rx]),
        name='conv', attrs=attrs)

    output = tvm.compute(
        oshape,
        lambda nn, ff, yy, xx: conv[nn][ff // nv][yy][xx][ff % nv],
        name='output_unpack', tag='conv2d',
        attrs={'workload': conv_arg_to_workload(data, kernel, stride, padding,
                                                layout, out_dtype)})

    return output


def _schedule_cl_spatialpack(cfg, s, op):
    """Schedule the stages produced by ``_decl_cl_spatialpack``; ``op`` is ``output_unpack``."""
    output = op.output(0)

    conv = op.input_tensors[0]
    temp = s[conv].op.input_tensors[0]
    kernel_vec = s[conv].op.input_tensors[1]
    kernel = s[kernel_vec].op.input_tensors[0]
    temp_W = s.cache_read(temp, "warp", [conv])
    conv_L = s.cache_write(conv, "local")

    kernel_L = s.cache_read(kernel_vec, "local", [conv_L])

    # Block sizes are carried as attrs on the conv stage; convert them to
    # Python ints so they can be compared below.
    attrs = s[conv].op.attrs
    OUTPUT_BLOCK_HEIGHT = util.get_const_int(attrs['block_h'])
    OUTPUT_BLOCK_WIDTH = util.get_const_int(attrs['block_w'])

    threads, ry, rx = _schedule_conv_and_local(
        s, conv, conv_L, temp_W, kernel, OUTPUT_BLOCK_HEIGHT, OUTPUT_BLOCK_WIDTH)

    # schedule temp
    _, ci, h, w = s[temp].op.axis
    tile_and_bind3d(s, temp, ci, h, w, 1, 16, 16)

    # schedule temp_W
    _, ci, h, w = s[temp_W].op.axis
    _bind_warp_cache(s, temp_W, ci, h, w, threads)

    s[kernel_vec].compute_inline()

    # schedule kernel_L. The conv stage carries no tag, so dispatch on the
    # block sizes directly (as the NCHWc schedule does).
    if OUTPUT_BLOCK_HEIGHT == 2 and OUTPUT_BLOCK_WIDTH == 14:
        s[kernel_L].compute_at(s[conv_L], ry)
    else:
        s[kernel_L].compute_at(s[conv_L], rx)

    # schedule output
    if output.op in s.outputs:
        out = output
    else:
        s[output].compute_inline()
        out = s.outputs[0]

    _, co, h, w = s[out].op.axis
    tile_and_bind3d(s, out, w, h, co, 4, 8, 8)