# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
"""Conv2D schedule on x86"""

import logging
import re

import tvm
from tvm import autotvm
from tvm.autotvm.task.topi_integration import deserialize_args
from tvm.autotvm.task import get_config
from .. import generic, tag
from .. import nn
from ..nn.conv2d import conv2d, conv2d_NCHWc, \
    conv2d_infer_layout, _get_workload as _get_conv2d_workload
from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload
from ..nn.pad import pad
from ..util import get_const_tuple

from . import conv2d_avx_1x1, conv2d_avx_common

logger = logging.getLogger('topi')


def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False,
                        layout='NCHW'):
    """
    Get default schedule config for the workload

    Symbolic dimensions (``tvm.expr.Var``) in ``data.shape`` are replaced by 1 so
    that the workload is derived from a fully static placeholder. The resulting
    workload is then handed to the depthwise, 1x1 or common fallback schedule,
    which fills ``cfg`` in place.
    """
    static_data_shape = [1 if isinstance(dim, tvm.expr.Var) else dim
                         for dim in get_const_tuple(data.shape)]
    data = tvm.placeholder(static_data_shape, dtype=data.dtype)
    if is_depthwise:
        wkl = _get_depthwise_conv2d_workload(data, kernel, strides, padding, out_dtype)
        # Imported locally, as in the original code.
        from .depthwise_conv2d import _fallback_schedule
        _fallback_schedule(cfg, wkl)
    else:
        wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout)
        is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1
        if is_kernel_1x1:
            conv2d_avx_1x1._fallback_schedule(cfg, wkl)
        else:
            conv2d_avx_common._fallback_schedule(cfg, wkl)


def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
    """Create schedule configuration from input arguments

    Defines the ``tile_ic``, ``tile_oc`` and ``tile_ow`` splits on ``cfg`` plus
    either the ``tile_oh`` knob (1x1 kernels) or the ``unroll_kw`` knob (all
    other kernels).

    Supported layouts are ``'NCHW'``, ``'NHWC'`` and blocked ``'NCHW<k>c'``.
    Any other layout raises ``ValueError``.
    """
    dshape = get_const_tuple(data.shape)
    kshape = get_const_tuple(kernel.shape)
    # The block size directly follows "NCHW". The previous pattern
    # r'NCHW.+(\d+)c' required an extra character before the captured digits
    # and therefore rejected single-digit block sizes such as "NCHW8c".
    pat = re.compile(r'NCHW(\d+)c')
    if layout == 'NCHW':
        n, ic, h, w = dshape
        oc, _, kh, kw = kshape
    elif layout == 'NHWC':
        n, h, w, ic = dshape
        kh, kw, oc, _ = kshape
    elif pat.match(layout) is not None:
        n, ic_chunk, h, w, ic_bn = dshape
        # NOTE(review): the result is unused; the call is kept because
        # allow_none=False presumably raises when no target is in scope - confirm.
        target = tvm.target.current_target(allow_none=False)
        oc_chunk, k_ic_chunk, kh, kw, k_ic_bn, oc_bn = kshape
        assert ic_chunk == k_ic_chunk
        assert ic_bn == k_ic_bn
        ic = ic_chunk*ic_bn
        oc = oc_chunk*oc_bn
    else:
        raise ValueError("Not support this layout {} with "
                         "schedule template.".format(layout))

    is_kernel_1x1 = kh == 1 and kw == 1
    ph, pw = padding if isinstance(padding, (tuple, list)) else (padding, padding)
    sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides)
    # NOTE(review): dilation is not taken into account for the output size here.
    oh = (h - kh + 2 * ph) // sh + 1
    ow = (w - kw + 2 * pw) // sw + 1

    # Create schedule config
    cfg.define_split("tile_ic", ic, num_outputs=2)
    cfg.define_split("tile_oc", oc, num_outputs=2)
    cfg.define_split("tile_ow", ow, num_outputs=2, filter=lambda y: y.size[-1] <= 64)
    if is_kernel_1x1:
        cfg.define_knob("tile_oh", [1, 2] if oh > 1 else [1])
    else:
        cfg.define_knob("unroll_kw", [True, False])


@autotvm.register_topi_compute(conv2d, 'cpu', ['direct'])
def _declaration_conv(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
    """Declare the conv2d computation for the 'cpu' target.

    Scalar ``padding``, ``strides`` and ``dilation`` are expanded to 2-tuples.
    ``'NCHW'`` goes through the tunable template (with a fallback config when
    ``cfg.is_fallback``); ``'HWCN'`` and ``'NHWC'`` are forwarded to the generic
    ``nn`` implementations. Any other layout raises ``ValueError``.
    """
    out_dtype = data.dtype if out_dtype is None else out_dtype
    padding = padding if isinstance(padding, (tuple, list)) else (padding, padding)
    strides = strides if isinstance(strides, (tuple, list)) else (strides, strides)
    dilation = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)

    if layout == 'NCHW':
        _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout)
        if cfg.is_fallback:
            _get_default_config(cfg, data, kernel, strides, padding, out_dtype)
        return _declaration_conv_impl(cfg, data, kernel, strides,
                                      padding, dilation, layout, out_dtype)

    # HWOI kernel layout is for NHWC and HWCN
    kh, kw, _, _ = get_const_tuple(kernel.shape)
    if layout == 'HWCN':
        return nn.conv2d_hwcn(data, kernel, strides, padding, dilation, out_dtype)
    # FIXME - https://github.com/apache/incubator-tvm/issues/4122
    # _declaration_conv_nhwc_pack expects kernel layout to be HWOI. However, the tests use HWIO
    # layout. Commenting until we have clarity about the nhwc_pack implementation from the author.
    # elif layout == 'NHWC' and kh == 1 and kw == 1 and kernel.dtype == "int8":
    #     if cfg.is_fallback:
    #         _get_default_config(cfg, data, kernel, strides, padding, out_dtype, False, layout)
    #     # specialize for INT8 1X1 conv on X86
    #     return conv2d_avx_1x1._declaration_conv_nhwc_pack(cfg, data, kernel, strides,
    #                                                       padding, dilation, out_dtype)
    elif layout == 'NHWC':
        return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype)
    raise ValueError("not support this layout {} yet".format(layout))


def _declaration_conv_impl(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
    """Build the NCHW convolution compute using the blocking factors in ``cfg``.

    The data is (optionally) padded and repacked to ``(n, C, h, c, w)``, the
    kernel to ``(CO, CI, h, w, ci, co)``, the convolution is computed in the
    blocked ``(n, oc_chunk, oh, ow, oc_block)`` shape and finally unpacked to
    ``(n, c, h, w)``. The returned tensor is tagged ``'conv2d_nchw'``.

    ``padding`` and ``strides`` must be 2-element sequences; ``dilation`` may be
    an int or a 2-element sequence.
    """
    out_dtype = data.dtype if out_dtype is None else out_dtype
    assert layout == 'NCHW', "only support NCHW convolution for AVX"

    assert isinstance(dilation, int) or len(dilation) == 2
    if isinstance(dilation, int):
        # A scalar dilation applies to both spatial axes. The previous code
        # tried to tuple-unpack the int, which raised TypeError.
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    HPAD, WPAD = padding
    HSTR, WSTR = strides

    batch_size, in_channel, in_height, in_width = get_const_tuple(data.shape)
    num_filter, _, kernel_height, kernel_width = get_const_tuple(kernel.shape)

    pad_height = in_height + 2 * HPAD
    pad_width = in_width + 2 * WPAD

    dilated_kernel_h = (kernel_height - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_width - 1) * dilation_w + 1
    out_height = (in_height + 2 * HPAD - dilated_kernel_h) // HSTR + 1
    out_width = (in_width + 2 * WPAD - dilated_kernel_w) // WSTR + 1

    # pack data
    DOPAD = (HPAD != 0 or WPAD != 0)
    if DOPAD:
        data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
    else:
        data_pad = data

    # fetch schedule
    ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]

    shape = (batch_size, in_channel // ic_bn, pad_height, ic_bn, pad_width)
    data_vec = tvm.compute(shape,
                           lambda n, C, h, c, w: data_pad[n, C * ic_bn + c, h, w],
                           name='data_vec')

    # pack kernel
    shape = (num_filter//oc_bn, in_channel//ic_bn,
             kernel_height, kernel_width, ic_bn, oc_bn)
    kernel_vec = tvm.compute(shape,
                             lambda CO, CI, h, w, ci, co:
                             kernel[CO * oc_bn + co, CI * ic_bn + ci, h, w],
                             name='kernel_vec')

    # convolution
    oshape = (batch_size, num_filter//oc_bn, out_height, out_width, oc_bn)
    unpack_shape = (batch_size, num_filter, out_height, out_width)

    ic = tvm.reduce_axis((0, in_channel), name='ic')
    kh = tvm.reduce_axis((0, kernel_height), name='kh')
    kw = tvm.reduce_axis((0, kernel_width), name='kw')
    idxmod = tvm.indexmod
    idxdiv = tvm.indexdiv

    conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
                       tvm.sum(data_vec[n, idxdiv(ic, ic_bn), oh*HSTR+kh*dilation_h,
                                        idxmod(ic, ic_bn),
                                        ow*WSTR+kw*dilation_w].astype(out_dtype) *
                               kernel_vec[oc_chunk, idxdiv(ic, ic_bn), kh, kw,
                                          idxmod(ic, ic_bn),
                                          oc_block].astype(out_dtype),
                               axis=[ic, kh, kw]), name='conv')

    unpack = tvm.compute(unpack_shape,
                         lambda n, c, h, w: conv[n, idxdiv(c, oc_bn), h, w, idxmod(c, oc_bn)]
                         .astype(out_dtype),
                         name='output_unpack',
                         tag='conv2d_nchw')
    return unpack


@autotvm.register_topi_schedule(generic.schedule_conv2d_nchw, 'cpu', ['direct'])
def schedule_conv2d(cfg, outs):
    """Create schedule for tensors

    Walks the graph from ``outs[0]``, inlines broadcast-tagged stages that are
    not outputs and dispatches ops tagged ``'conv2d_nchw'`` to the 1x1 or common
    AVX schedule. Returns the schedule.
    """
    s = tvm.create_schedule([x.op for x in outs])
    scheduled_ops = []

    def traverse(op):
        """Traverse operators from computation graph"""
        # inline all one-to-one-mapping operators except the last stage (output)
        if tag.is_broadcast(op.tag):
            if op not in s.outputs:
                s[op].compute_inline()
            for tensor in op.input_tensors:
                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                    traverse(tensor.op)

        if 'conv2d_nchw' in op.tag:
            # `op` is the unpack stage; walk back to the conv and packing stages.
            output = op.output(0)
            conv_out = op.input_tensors[0]
            kernel_vec = conv_out.op.input_tensors[1]
            kernel = kernel_vec.op.input_tensors[0]
            if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
                s[kernel].compute_inline()
            data_vec = conv_out.op.input_tensors[0]
            data = data_vec.op.input_tensors[0]
            data_pad = None
            if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
                data_pad = data
                data = data_pad.op.input_tensors[0]

            _, _, kh, kw = get_const_tuple(kernel.shape)
            is_kernel_1x1 = kh == 1 and kw == 1
            args = [s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, outs[0]]
            if is_kernel_1x1:
                conv2d_avx_1x1._schedule_conv(*args)
            else:
                conv2d_avx_common._schedule_conv(*args)

        scheduled_ops.append(op)

    traverse(outs[0].op)
    return s


@generic.schedule_conv2d_nhwc.register("cpu")
def schedule_conv2d_nhwc(outs):
    """Create schedule for tensors

    Walks the graph from ``outs[0]``, inlines or parallelises broadcast-tagged
    stages, and for ops tagged ``'conv2d_nhwc'`` parallelises the padding stage
    (when there is one) and vectorises the conv over its channel axis.
    Returns the schedule.
    """
    s = tvm.create_schedule([x.op for x in outs])
    output_op = outs[0].op
    scheduled_ops = []

    def traverse(op):
        """Traverse operators from computation graph"""
        # inline all one-to-one-mapping operators except the last stage (output)
        if tag.is_broadcast(op.tag):
            if op not in s.outputs:
                s[op].compute_inline()
            else:  # inject custom schedule
                if len(op.axis) == 4:  # schedule bias + bn + relu
                    n, h, w, c = op.axis
                    fused = s[op].fuse(n, h, w)
                    s[op].parallel(fused)
                    s[op].vectorize(c)
            for tensor in op.input_tensors:
                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                    traverse(tensor.op)

        if 'conv2d_nhwc' in op.tag:
            conv = op.output(0)
            kernel = op.input_tensors[1]
            if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
                s[kernel].compute_inline()

            data = op.input_tensors[0]
            data_pad = None
            if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
                data_pad = data
                data = data_pad.op.input_tensors[0]

            # data_pad stays None when the conv input is not produced by a
            # "pad"-tagged compute op; there is no padding stage to schedule then.
            if data_pad is not None:
                n_pad, h_pad, w_pad, c_pad = data_pad.op.axis
                pad_fused = s[data_pad].fuse(n_pad, h_pad)
                s[data_pad].parallel(pad_fused)
            C = conv
            n, h, w, c = C.op.axis
            ry, rx, rc = C.op.reduce_axis
            n_out, h_out, w_out, c_out = output_op.axis
            s[C].vectorize(c)
            if op != output_op:  # fuse bias + bn + relu into conv
                s[C].compute_at(s[output_op], c_out)
            else:
                fused = s[C].fuse(n, h, w)
                s[C].parallel(fused)

        scheduled_ops.append(op)

    traverse(output_op)
    return s


# Define template function for autotvm task
# We define schedule template in this function instead of
# declaration function since actual input arguments need
# to be altered by the schedule selected.
@autotvm.task.register("topi_x86_conv2d_NCHWc")
def _topi_nn_conv2d_NCHWc(*args, **kwargs):
    """AutoTVM template for NCHWc conv2d.

    Accepts 7 positional arguments (data, kernel, strides, padding, dilation,
    origin_layout, dtype) or 8 (with ``out_layout`` before ``dtype``). Keyword
    arguments are rejected. The data and kernel placeholders are re-created in
    the blocked shapes chosen by the config.

    Returns ``(schedule, [new_data, new_kernel, C])``.
    """
    assert not kwargs, "Do not support kwargs in template function call"
    args = deserialize_args(args)

    if len(args) == 7:
        data, kernel, strides, padding, dilation, origin_layout, dtype = args
    else:
        assert len(args) == 8
        data, kernel, strides, padding, dilation, origin_layout, out_layout, dtype = args

    raw_data_shape = get_const_tuple(data.shape)
    raw_kernel_shape = get_const_tuple(kernel.shape)

    # get config here
    cfg = get_config()
    _create_tuning_space(cfg, data, kernel, strides, padding, dilation, origin_layout)

    idxdiv = tvm.indexdiv
    idxmod = tvm.indexmod
    # change shape with the value in config
    ic_bn, oc_bn, ow_bn = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1],
                           cfg["tile_ow"].size[-1])
    new_data_shape = (raw_data_shape[0], idxdiv(raw_data_shape[1], ic_bn),
                      raw_data_shape[2], raw_data_shape[3], ic_bn)
    data_layout = "NCHW%dc" % ic_bn
    # Any out_layout passed in is overridden by the one derived from the config.
    out_layout = "NCHW%dc" % oc_bn
    new_kernel_shape = (idxdiv(raw_kernel_shape[0], oc_bn),
                        idxdiv(raw_kernel_shape[1], ic_bn),
                        raw_kernel_shape[2], raw_kernel_shape[3], ic_bn, oc_bn)
    new_data = tvm.placeholder(new_data_shape, data.dtype)
    new_kernel = tvm.placeholder(new_kernel_shape, kernel.dtype)

    C = _declaration_conv_NCHWc(cfg, new_data, new_kernel, strides, padding, dilation,
                                data_layout, out_layout, dtype)
    s = _schedule_conv2d_NCHWc(cfg, [C])
    return s, [new_data, new_kernel, C]


@conv2d_infer_layout.register("cpu")
def _conv2d_infer_layout(workload, cfg):
    """Infer blocked input/output shapes and layouts from a workload and config.

    The last element of the workload's ``data`` and ``kernel`` entries is dropped
    before unpacking the four shape values. Returns
    ``((in_shape, in_layout),), ((out_shape, out_layout),)`` using the
    ``tile_ic`` / ``tile_oc`` block sizes from ``cfg``.
    """
    _, data, kernel, strides, padding, dilation, layout, dtype = workload
    batch_size, in_channel, in_height, in_width = data[:-1]
    out_channel, _, k_height, k_width = kernel[:-1]
    idxdiv = tvm.indexdiv

    # NOTE(review): dilation is not taken into account for the output size here.
    out_height = idxdiv(in_height + 2 * padding[0] - k_height, strides[0]) + 1
    out_width = idxdiv(in_width + 2 * padding[1] - k_width, strides[1]) + 1
    tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
    in_shape = (batch_size, idxdiv(in_channel, tile_ic), in_height, in_width, tile_ic)
    in_layout = "NCHW%dc" % tile_ic
    out_shape = (batch_size, idxdiv(out_channel, tile_oc), out_height, out_width, tile_oc)
    out_layout = "NCHW%dc" % tile_oc
    return ((in_shape, in_layout),), ((out_shape, out_layout),)


@autotvm.register_topi_compute(conv2d_NCHWc, 'cpu', 'direct')
def _declaration_conv_NCHWc(cfg, data, kernel, strides,
                            padding, dilation, layout, out_layout, out_dtype):
    """Declare the NCHWc conv2d computation for the 'cpu' target.

    When ``cfg.is_fallback`` is set, a default config is derived from NCHW
    placeholders reconstructed from the blocked shapes. The compute itself is
    delegated to ``nn.conv2d_NCHWc_compute``.
    """
    # layout and out_layout are not used here,
    # we keep them for debug convenience when dumping autotvm workload
    n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
    in_channel = ic_chunk * ic_bn
    oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = \
        get_const_tuple(kernel.shape)
    num_filter = oc_chunk * oc_bn

    # If no config was set, we can fallback to NCHW config.
    if cfg.is_fallback:
        _get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype),
                            tvm.placeholder((num_filter, in_channel, kernel_height, kernel_width),
                                            dtype=kernel.dtype),
                            strides, padding, out_dtype)

    return nn.conv2d_NCHWc_compute(data,
                                   kernel,
                                   strides,
                                   padding,
                                   dilation,
                                   layout,
                                   out_layout,
                                   out_dtype)


@autotvm.register_topi_schedule(generic.schedule_conv2d_NCHWc, 'cpu', ['direct'])
def _schedule_conv2d_NCHWc(cfg, outs):
    """Create schedule for tensors

    Walks the graph from ``outs[0]``, inlines broadcast-tagged stages that are
    not outputs and dispatches ops tagged ``'conv2d_NCHWc'`` to the 1x1 or
    common AVX NCHWc schedule. Returns the schedule.
    """
    s = tvm.create_schedule([x.op for x in outs])
    scheduled_ops = []

    def traverse(op):
        """Traverse operators from computation graph"""
        # inline all one-to-one-mapping operators except the last stage (output)
        if tag.is_broadcast(op.tag):
            if op not in s.outputs:
                s[op].compute_inline()
            for tensor in op.input_tensors:
                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                    traverse(tensor.op)

        if 'conv2d_NCHWc' in op.tag:
            conv_out = op.output(0)
            kernel = conv_out.op.input_tensors[1]
            data_vec = conv_out.op.input_tensors[0]
            data = data_vec.op.input_tensors[0] \
                if isinstance(data_vec.op, tvm.tensor.ComputeOp) and "pad" not in data_vec.op.tag \
                else data_vec
            if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
                data_pad = data
                data = data_pad.op.input_tensors[0]

            args = [s, cfg, data_vec, conv_out, outs[0]]
            # NOTE(review): the result is unused; the call is kept because
            # allow_none=False presumably raises when no target is in scope - confirm.
            target = tvm.target.current_target(allow_none=False)
            _, _, kh, kw, _, _, = get_const_tuple(kernel.shape)
            if kh == 1 and kw == 1:
                conv2d_avx_1x1._schedule_conv_NCHWc(*args)
            else:
                conv2d_avx_common._schedule_conv_NCHWc(*args)

        scheduled_ops.append(op)

    traverse(outs[0].op)
    return s