# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable
"""Depthwise convolution schedule for ARM CPU"""

import tvm
from tvm import te
from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity

from .. import nn
from ..util import traverse_inline, get_const_tuple, get_const_int
from ..nn.util import get_pad_tuple


@autotvm.register_topi_compute("depthwise_conv2d_nchw.arm_cpu")
def depthwise_conv2d_nchw(_, data, kernel, strides, padding, dilation, out_dtype):
    """Compute depthwise_conv2d with NCHW layout

    Thin wrapper that forwards every argument except the (unused) config
    to ``nn.depthwise_conv2d_nchw``.
    """
    return nn.depthwise_conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)


@autotvm.register_topi_schedule("depthwise_conv2d_nchw.arm_cpu")
def schedule_depthwise_conv2d_nchw(cfg, outs):
    """Schedule depthwise conv2d

    Parameters
    ----------
    cfg: ConfigEntity
        The configuration of this template
    outs: Array of Tensor
        The computation graph description of depthwise convolution2d
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for depthwise_conv2d nchw.
    """
    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
    s = te.create_schedule([x.op for x in outs])

    def _schedule(cfg, s, data, data_pad, kernel, output):
        """Define the tuning space and schedule one depthwise conv stage."""
        A, B, C = data, kernel, output
        s[data_pad].compute_inline()

        ##### space definition begin #####
        n, c, h, w = s[output].op.axis
        _, vc = cfg.define_split("tile_c", c, num_outputs=2)
        _, vh = cfg.define_split("tile_h", h, num_outputs=2)
        _, vw = cfg.define_split("tile_w", w, num_outputs=2)
        cfg.define_annotate("ann", [vh, vw, vc], policy="try_unroll_vec")

        # fallback support
        if cfg.is_fallback:
            ref_log = autotvm.tophub.load_reference_log(
                "arm_cpu", "rk3399", "depthwise_conv2d_nchw.arm_cpu"
            )
            cfg.fallback_with_reference_log(ref_log)
        ##### space definition end #####

        # pack data to vector form  [n, c, h, w] -> [n, C, h, w, VC]
        A0 = s.cache_read(data_pad, "global", C)
        n, c, h, w = s[A0].op.axis
        c, vc = cfg["tile_c"].apply(s, A0, c)
        s[A0].reorder(n, c, h, w, vc)
        A1 = s.cache_write(A0, "global")
        s[A0].compute_inline()

        # pack kernel to vector form  [co, ci, kh, kw] -> [CO, ci, kh, kw, VC]
        B0 = s.cache_read(B, "global", C)
        c, m, h, w = s[B0].op.axis
        c, vc = cfg["tile_c"].apply(s, B0, c)
        s[B0].reorder(c, m, h, w, vc)
        B1 = s.cache_write(B0, "global")
        s[B0].compute_inline()

        n, c, h, w = s[C].op.axis
        c, vc = cfg["tile_c"].apply(s, C, c)
        s[C].reorder(n, c, h, w, vc)

        # depthwise conv
        C0 = s.cache_write(C, "global")
        _, c, h, w, vc = s[C0].op.axis
        dh, dw = s[C0].op.reduce_axis
        oh, ih = cfg["tile_h"].apply(s, C0, h)
        ow, iw = cfg["tile_w"].apply(s, C0, w)
        s[C0].reorder(c, oh, ow, dh, dw, ih, iw, vc)
        s[A1].compute_at(s[C0], oh)

        # try unroll and vectorization
        cfg["ann"].apply(
            s,
            C0,
            [ih, iw, vc],
            axis_lens=[cfg["tile_h"].size[-1], cfg["tile_w"].size[-1], cfg["tile_c"].size[-1]],
            max_unroll=16,
            cfg=cfg,
        )

        # fusion
        if C.op not in s.outputs:
            s[C].compute_inline()

        # mark parallel
        last = outs[0]
        n, c, h, w = s[last].op.axis
        s[last].parallel(c)

        n, c, h, w, vc = s[C0].op.axis
        s[C0].parallel(c)

        c, m, h, w, vc = s[B1].op.axis
        s[B1].parallel(c)

        return s

    def _callback(op):
        """Schedule every op tagged ``depthwise_conv2d_nchw`` found during traversal."""
        if op.tag == "depthwise_conv2d_nchw":
            output = op.output(0)
            kernel = op.input_tensors[1]
            data = op.input_tensors[0]
            data_pad = None
            # When the first input is itself a padding stage, keep a handle on it
            # and step back to the tensor that feeds it.
            if isinstance(data.op, tvm.te.ComputeOp) and "pad" in data.op.tag:
                data_pad = data
                data = data_pad.op.input_tensors[0]
            # NOTE(review): _schedule indexes s[data_pad] unconditionally, so this
            # presumably relies on the compute always emitting a "pad" stage - confirm.
            _schedule(cfg, s, data, data_pad, kernel, output)

    traverse_inline(s, outs[0].op, _callback)
    return s


# TODO:
# This schedule has incorrect result on some hardware platforms (like NV Jetson TX2)
# Let us comment it out but not remove.
# see discussion:
# https://discuss.tvm.ai/t/autotuner-incorrect-result-after-tuning-mobilenetv2-on-arm-cpu/6088
@autotvm.register_topi_compute("depthwise_conv2d_nchw_spatial_pack.arm_cpu")
def depthwise_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype):
    """TOPI compute callback for depthwise_conv2d nchw

    Parameters
    ----------
    cfg: ConfigEntity
        The config for this template

    data : tvm.te.Tensor
        4-D with shape [batch, in_channel, in_height, in_width]

    kernel : tvm.te.Tensor
        4-D with shape [num_filter, multiplier, filter_height, filter_width] or
        pre-packed 5-D with shape [num_filter_chunk, multiplier, filter_height,
        filter_width, num_filter_block]

    strides : list of two ints
        [stride_height, stride_width]

    padding : list of two ints
        [pad_height, pad_width]

    dilation : list of two ints
        [dilation_height, dilation_width]

    out_dtype: str
        The output type. This is used for mixed precision.

    Returns
    -------
    output : tvm.te.Tensor
        4-D with shape [batch, out_channel, out_height, out_width]
    """

    return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)


@autotvm.register_topi_compute("depthwise_conv2d_nhwc.arm_cpu")
def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, out_dtype):
    """TOPI compute callback for depthwise_conv2d nhwc

    Parameters
    ----------
    _ : ConfigEntity
        The config for this template (not used by the compute)

    data : tvm.te.Tensor
        4-D with shape [batch, in_height, in_width, in_channel]

    kernel : tvm.te.Tensor
        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]

    strides : list of two ints
        [stride_height, stride_width]

    padding : list of two ints
        [pad_height, pad_width]

    dilation : list of two ints
        [dilation_height, dilation_width]

    out_dtype: str
        The output type. This is used for mixed precision.

    Returns
    -------
    output : tvm.te.Tensor
        4-D with shape [batch, out_height, out_width, out_channel]
    """

    out_dtype = out_dtype or data.dtype

    N, IH, IW, IC = get_const_tuple(data.shape)

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    KH, KW, IC, channel_multiplier = get_const_tuple(kernel.shape)

    dilated_kernel_h = (KH - 1) * dilation_h + 1
    dilated_kernel_w = (KW - 1) * dilation_w + 1

    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w)
    )
    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)

    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1

    # Only materialise a padding stage when some padding is actually requested.
    if pad_top or pad_left or pad_down or pad_right:
        data_pad = nn.pad(
            data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0], name="data_pad"
        )
    else:
        data_pad = data

    output_shape = (N, OH, OW, IC * channel_multiplier)

    idxdiv = tvm.tir.indexdiv
    idxmod = tvm.tir.indexmod

    reduce_h = te.reduce_axis((0, KH), name="reduce_h")
    reduce_w = te.reduce_axis((0, KW), name="reduce_w")

    # Output channel c reads input channel c // channel_multiplier and kernel
    # multiplier slot c % channel_multiplier.
    out = te.compute(
        output_shape,
        lambda n, h, w, c: te.sum(
            data_pad[
                n,
                HSTR * h + dilation_h * reduce_h,
                w * WSTR + reduce_w * dilation_w,
                idxdiv(c, channel_multiplier),
            ].astype(out_dtype)
            * kernel[
                reduce_h, reduce_w, idxdiv(c, channel_multiplier), idxmod(c, channel_multiplier)
            ].astype(out_dtype),
            axis=[reduce_h, reduce_w],
        ),
        name="depthwise_conv2d_nhwc_output",
    )
    return out


@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu")
def schedule_depthwise_conv2d_nhwc(cfg, outs):
    """Create the schedule for depthwise_conv2d_nhwc

    Parameters
    ----------
    cfg: ConfigEntity
        The configuration of this template
    outs: Array of Tensor
        The computation graph description of depthwise convolution2d
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for depthwise_conv2d nhwc.
    """
    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
    s = te.create_schedule([x.op for x in outs])
    out = outs[0]

    ##### space definition begin #####
    n, h, w, c = s[out].op.axis
    cfg.define_split("tile_c", c, num_outputs=2)
    cfg.define_split("tile_h", h, num_outputs=2)
    cfg.define_split("tile_w", w, num_outputs=2)
    cfg.define_knob("locate_output", [0, 1])

    # fallback support
    if cfg.is_fallback:
        cfg["tile_c"] = SplitEntity([-1, 8])
        cfg["tile_h"] = SplitEntity([-1, 2])
        cfg["tile_w"] = SplitEntity([-1, 2])
        cfg["locate_output"] = OtherOptionEntity(1)
    ##### space definition end #####

    def schedule_conv(conv):
        """Tile and vectorize the convolution stage; return the fused (n, ho) axis."""
        conv_data = conv.op.input_tensors[0]

        # The compute declares its axes as (n, h, w, c); unpack them in that
        # order so that "tile_h" splits the height axis and "tile_w" the width
        # axis, matching the space definition above and schedule_conv_out.
        n, h, w, c = conv.op.axis
        r_h, r_w = conv.op.reduce_axis
        ho, hi = cfg["tile_h"].apply(s, conv, h)
        wo, wi = cfg["tile_w"].apply(s, conv, w)
        co, ci = cfg["tile_c"].apply(s, conv, c)

        if conv_data.name == "data_pad":
            assert isinstance(conv_data.op, tvm.te.ComputeOp)
            # Define a policy for padding computation
            cfg.define_knob("data_pad_inline", [1, 2, 3])
            if cfg.is_fallback:
                cfg["data_pad_inline"] = OtherOptionEntity(3)
            if cfg["data_pad_inline"].val == 1:
                s[conv_data].vectorize(list(s[conv_data].op.axis)[-1])
                s[conv_data].compute_at(s[conv], ho)
            if cfg["data_pad_inline"].val == 2:
                s[conv_data].vectorize(list(s[conv_data].op.axis)[-1])
                s[conv_data].compute_at(s[conv], wo)
            if cfg["data_pad_inline"].val == 3:
                s[conv_data].compute_inline()

        s[conv].reorder(n, ho, wo, co, hi, wi, r_h, r_w, ci)
        fused_n_ho = s[conv].fuse(n, ho)
        s[conv].vectorize(ci)
        return fused_n_ho

    def schedule_conv_out(out):
        """Tile the final output stage; return (hi, wi, fused (n, ho) axis)."""
        n, h, w, c = out.op.axis
        co, ci = cfg["tile_c"].apply(s, out, c)
        wo, wi = cfg["tile_w"].apply(s, out, w)
        ho, hi = cfg["tile_h"].apply(s, out, h)
        s[out].reorder(n, ho, wo, co, hi, wi)

        if out.dtype in ["int8", "uint8"]:
            # In case of quantized convolution further split the channel in batches of 4 elements
            # so that we can use arm intrinsics to run fixed_point_multiplication
            ci_outer, ci_inner = s[out].split(ci, 4)
            s[out].vectorize(ci_inner)

        fused_n_ho = s[out].fuse(n, ho)
        return hi, wi, fused_n_ho

    def _callback(op):
        """Schedule the op named ``depthwise_conv2d_nhwc_output`` and parallelize."""
        if op.name == "depthwise_conv2d_nhwc_output":
            conv = op.output(0)
            if conv != out:
                # The convolution is followed by other stages: tile the final
                # output and nest the convolution at the configured loop level.
                hi, wi, p_axis = schedule_conv_out(out)
                schedule_conv(conv)
                if cfg["locate_output"].val == 0:
                    s[conv].compute_at(s[out], hi)
                if cfg["locate_output"].val == 1:
                    s[conv].compute_at(s[out], wi)
            else:
                p_axis = schedule_conv(out)

            s[out].parallel(p_axis)

    traverse_inline(s, outs[0].op, _callback)
    return s


@autotvm.register_topi_schedule("depthwise_conv2d_nchw_spatial_pack.arm_cpu")
def schedule_depthwise_conv2d_nchw_spatial_pack(cfg, outs):
    """Create the schedule for depthwise_conv2d_nchw_spatial_pack

    Parameters
    ----------
    cfg: ConfigEntity
        The configuration of this template
    outs: Array of Tensor
        The computation graph description of depthwise convolution2d
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for depthwise_conv2d nchw spatial pack.
    """
    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
    s = te.create_schedule([x.op for x in outs])

    def _callback(op):
        """Recover the stages built by _decl_spatial_pack and schedule them."""
        if op.tag == "spatial_depthwise_conv2d_nchw_output":
            output = op.output(0)
            conv = op.input_tensors[0]
            data_vec = conv.op.input_tensors[0]
            kernel_vec = conv.op.input_tensors[1]
            # A stage named "kernel_vec" is the packing stage; otherwise the
            # tensor feeding the convolution is treated as the kernel itself.
            if kernel_vec.op.name == "kernel_vec":
                kernel = kernel_vec.op.input_tensors[0]
            else:
                kernel = kernel_vec
            if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
                s[kernel].compute_inline()
            _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, conv, output, outs[0])

    traverse_inline(s, outs[0].op, _callback)
    return s


def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile):
    """Declare the spatial-pack depthwise conv2d compute and its tuning space

    Parameters
    ----------
    cfg: ConfigEntity
        The config for this template; tiling, reorder and annotation knobs
        are defined on it here.

    data : tvm.te.Tensor
        4-D tensor, unpacked as (N, C, IH, IW)

    kernel : tvm.te.Tensor
        4-D tensor unpacked as (C, M, KH, KW), or a pre-packed 5-D tensor
        unpacked as (C, M, KH, KW, VC)

    strides : int or tuple/list of two ints
        Stride along height and width

    padding : int or list of ints
        Forwarded to ``get_pad_tuple``

    dilation : int or pair of ints
        Dilation along height and width

    out_dtype : str
        The output type; falls back to ``data.dtype`` when falsy.

    num_tile : int
        Only ``2`` is supported.

    Returns
    -------
    output : tvm.te.Tensor
        4-D tensor of shape (N, C * M, OH, OW)

    Raises
    ------
    RuntimeError
        If ``num_tile`` is not 2.
    """
    out_dtype = out_dtype or data.dtype

    N, C, IH, IW = get_const_tuple(data.shape)

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    if len(kernel.shape) == 4:
        pre_packed = False
        C, M, KH, KW = get_const_tuple(kernel.shape)
    else:  # kernel tensor is pre packed
        pre_packed = True
        C, M, KH, KW, VC = get_const_tuple(kernel.shape)
        C = C * VC

    dilated_kernel_h = (KH - 1) * dilation_h + 1
    dilated_kernel_w = (KW - 1) * dilation_w + 1

    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w)
    )
    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
    # pack data
    HPAD = pad_top + pad_down
    WPAD = pad_left + pad_right
    DOPAD = HPAD != 0 or WPAD != 0
    if DOPAD:
        data_pad = nn.pad(
            data, (0, 0, pad_top, pad_left), (0, 0, pad_down, pad_right), name="data_pad"
        )
    else:
        data_pad = data

    # fallback support
    # Currently, Mali schedule doesn't use it like conv2d.
    # NOTE(review): this runs before any knob is defined, whereas
    # schedule_depthwise_conv2d_nchw applies its reference log after the space
    # definition - confirm the reference log is actually taken into account here.
    if cfg.is_fallback:
        ref_log = autotvm.tophub.load_reference_log(
            "arm_cpu", "rk3399", "depthwise_conv2d_nchw_spatial_pack.arm_cpu"
        )
        cfg.fallback_with_reference_log(ref_log)

    # ==================== define configuration space ====================
    n, c, oh, ow = cfg.axis(N), cfg.axis(C), cfg.axis(OH), cfg.axis(OW)
    kh, kw = cfg.reduce_axis(KH), cfg.reduce_axis(KW)

    # Currently, Mali schedule doesn't use it like conv2d.
    # Leave num_tile for possible future use of Mali schedule
    if num_tile == 2:  # for arm cpu
        co, vc = cfg.define_split("tile_co", c, num_outputs=2)
        oh, vh = cfg.define_split("tile_oh", oh, num_outputs=2)
        ow, vw = cfg.define_split("tile_ow", ow, num_outputs=2)
    else:
        raise RuntimeError("Invalid num_tile")

    cfg.define_reorder(
        "reorder_0",
        [n, co, oh, ow, kh, kw, vh, vw, vc],
        policy="candidate",
        candidate=[[n, co, oh, ow, kh, kw, vh, vw, vc], [n, co, oh, ow, kh, kw, vc, vh, vw]],
    )

    cfg.define_reorder(
        "reorder_1",
        [n, co, oh, ow, vh, vw, vc],
        policy="candidate",
        candidate=[
            [n, co, oh, ow, vh, vw, vc],
            [n, co, oh, ow, vc, vh, vw],
            [n, co, oh, ow, vh, vc, vw],
        ],
    )

    cfg.define_annotate("ann_reduce", [kh, kw], policy="try_unroll")
    cfg.define_annotate("ann_spatial", [vh, vw, vc], policy="try_unroll_vec")
    # ====================================================================

    # Inner tile sizes chosen by the config.
    VC = cfg["tile_co"].size[-1]
    VH = cfg["tile_oh"].size[-1]
    VW = cfg["tile_ow"].size[-1]

    kvshape = (C // VC, M, KH, KW, VC)
    ovshape = (N, C * M // VC, OH // VH, OW // VW, VH, VW, VC)
    oshape = (N, C * M, OH, OW)

    if dilation_h != 1 or dilation_w != 1:
        # undilate input data
        dvshape = (N, OH // VH, OW // VW, C, KH, KW, VH, VW)
        data_vec = te.compute(
            dvshape,
            lambda n, h, w, c, kh, kw, vh, vw: data_pad[n][c][
                (h * VH + vh) * HSTR + kh * dilation_h
            ][(w * VW + vw) * WSTR + kw * dilation_w],
            name="data_vec_undilated",
        )
    else:
        dvshape = (N, OH // VH, OW // VW, C, VH * HSTR + KH - 1, VW * WSTR + KW - 1)
        data_vec = te.compute(
            dvshape,
            lambda n, h, w, c, vh, vw: data_pad[n][c][h * VH * HSTR + vh][w * VW * WSTR + vw],
            name="data_vec",
        )

    if pre_packed:
        kernel_vec = kernel
    else:
        kernel_vec = te.compute(
            kvshape, lambda co, m, kh, kw, vc: kernel[co * VC + vc][m][kh][kw], name="kernel_vec"
        )

    # From here on kh/kw are the real reduction axes of the compute, replacing
    # the config-space placeholders of the same name defined above.
    kh = te.reduce_axis((0, KH), name="kh")
    kw = te.reduce_axis((0, KW), name="kw")

    idxdiv = tvm.tir.indexdiv
    idxmod = tvm.tir.indexmod

    if dilation_h != 1 or dilation_w != 1:
        conv = te.compute(
            ovshape,
            lambda n, co, h, w, vh, vw, vc: te.sum(
                data_vec[n, h, w, idxdiv(co * VC + vc, M), kh, kw, vh, vw].astype(out_dtype)
                * kernel_vec[idxdiv(co, M), idxmod(co, M), kh, kw, vc].astype(out_dtype),
                axis=[kh, kw],
            ),
            name="depthwise_conv",
        )
    else:
        conv = te.compute(
            ovshape,
            lambda n, co, h, w, vh, vw, vc: te.sum(
                data_vec[n, h, w, idxdiv((co * VC + vc), M), vh * HSTR + kh, vw * WSTR + kw].astype(
                    out_dtype
                )
                * kernel_vec[idxdiv(co, M), idxmod(co, M), kh, kw, vc].astype(out_dtype),
                axis=[kh, kw],
            ),
            name="depthwise_conv",
        )

    # Unpack the tiled result back to the 4-D output shape.
    output = te.compute(
        oshape,
        lambda n, co, h, w: conv[
            n,
            idxdiv(co, VC),
            idxdiv(h, VH),
            idxdiv(w, VW),
            idxmod(h, VH),
            idxmod(w, VW),
            idxmod(co, VC),
        ],
        name="output_unpack",
        tag="spatial_depthwise_conv2d_nchw_output",
    )
    return output


def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, conv, output, last):
    """Schedule the stages declared by ``_decl_spatial_pack``

    Parameters
    ----------
    cfg: ConfigEntity
        The config for this template; extra knobs ("data_pad_inline",
        "data_vec_inline", "conv_inline") are defined on it here.
    s : Schedule
        The schedule to mutate
    data_vec : tvm.te.Tensor
        The packed data stage ("data_vec" or "data_vec_undilated")
    kernel_vec : tvm.te.Tensor
        The kernel tensor feeding the convolution
    conv : tvm.te.Tensor
        The tiled convolution stage
    output : tvm.te.Tensor
        The unpacking stage
    last : tvm.te.Tensor
        The final stage of the graph; ``output`` is inlined when it differs.

    Returns
    -------
    s : Schedule
        The same schedule object that was passed in.
    """
    n, co, oh, ow, vh, vw, vc = s[conv].op.axis
    kh, kw = s[conv].op.reduce_axis

    # The undilated layout carries two extra (kh, kw) axes.
    if data_vec.op.name == "data_vec_undilated":
        _, dv_oh, dv_ow, dv_c, _, _, dv_vh, dv_vw = s[data_vec].op.axis
    else:
        _, dv_oh, dv_ow, dv_c, dv_vh, dv_vw = s[data_vec].op.axis

    data_pad = data_vec.op.input_tensors[0]
    if data_pad.op.name == "data_pad":
        assert isinstance(data_pad.op, tvm.te.ComputeOp)
        has_padding = True
    else:
        assert isinstance(data_pad.op, tvm.te.PlaceholderOp)
        has_padding = False

    cfg.define_knob("data_pad_inline", [0, 1, 2, 3, 4])

    if cfg["data_pad_inline"].val == 1 and has_padding:
        s[data_pad].compute_inline()
    if cfg["data_pad_inline"].val == 2 and has_padding:
        s[data_pad].vectorize(list(s[data_pad].op.axis)[-1])
    if cfg["data_pad_inline"].val == 3 and has_padding:
        s[data_pad].vectorize(list(s[data_pad].op.axis)[-1])
        s[data_pad].compute_at(s[data_vec], dv_oh)
    if cfg["data_pad_inline"].val == 4 and has_padding:
        s[data_pad].vectorize(list(s[data_pad].op.axis)[-1])
        s[data_pad].compute_at(s[data_vec], dv_ow)

    cfg.define_knob("data_vec_inline", [0, 1, 2, 3])
    if cfg["data_vec_inline"].val == 1:
        s[data_vec].compute_at(s[conv], oh)
    if cfg["data_vec_inline"].val == 2:
        s[data_vec].compute_at(s[conv], ow)
    if cfg["data_vec_inline"].val == 3:
        s[data_vec].compute_at(s[conv], co)

    # schedule conv
    cfg["reorder_0"].apply(s, conv, [n, co, oh, ow, kh, kw, vh, vw, vc])
    cfg["ann_reduce"].apply(
        s,
        conv,
        [kh, kw],
        axis_lens=[get_const_int(kh.dom.extent), get_const_int(kw.dom.extent)],
        max_unroll=16,
        cfg=cfg,
    )
    cfg["ann_spatial"].apply(
        s,
        conv,
        [vh, vw, vc],
        axis_lens=[cfg["tile_oh"].size[-1], cfg["tile_ow"].size[-1], cfg["tile_co"].size[-1]],
        max_unroll=16,
        cfg=cfg,
    )

    # schedule fusion
    n, co, h, w = s[last].op.axis
    co, vc = cfg["tile_co"].apply(s, last, co)
    oh, vh = cfg["tile_oh"].apply(s, last, h)
    ow, vw = cfg["tile_ow"].apply(s, last, w)
    cfg["reorder_1"].apply(s, last, [n, co, oh, ow, vh, vw, vc])
    if last != output:
        s[output].compute_inline()
        cfg["ann_spatial"].apply(
            s,
            last,
            [vh, vw, vc],
            axis_lens=[cfg["tile_oh"].size[-1], cfg["tile_ow"].size[-1], cfg["tile_co"].size[-1]],
            max_unroll=16,
            cfg=cfg,
        )
    else:
        s[last].vectorize(vw)
    cfg.define_knob("conv_inline", [0, 1, 2, 3])
    if cfg["conv_inline"].val == 1:
        s[conv].compute_at(s[last], ow)
    if cfg["conv_inline"].val == 2:
        s[conv].compute_at(s[last], oh)
    if cfg["conv_inline"].val == 3:
        s[conv].compute_at(s[last], co)

    # mark parallel
    s[last].parallel(co)

    if data_vec.op.name == "data_vec_undilated":
        _, h, _, _, _, _, _, _ = s[data_vec].op.axis
    else:
        _, h, _, _, _, _ = s[data_vec].op.axis
    s[data_vec].parallel(h)

    if kernel_vec.op.name == "kernel_vec":
        co, _, _, _, _ = s[kernel_vec].op.axis
        if autotvm.GLOBAL_SCOPE.in_tuning:
            # kernel packing will be pre-computed during compilation, so we skip
            # this part to make tuning records correct
            s[kernel_vec].pragma(co, "debug_skip_region")
        else:
            s[kernel_vec].parallel(co)

    return s