1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=invalid-name,unused-variable
18"""Depthwise convolution schedule for ARM CPU"""
19
20import tvm
21from tvm import te
22from tvm import autotvm
23from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
24
25from .. import nn
26from ..util import traverse_inline, get_const_tuple, get_const_int
27from ..nn.util import get_pad_tuple
28
29
@autotvm.register_topi_compute("depthwise_conv2d_nchw.arm_cpu")
def depthwise_conv2d_nchw(_, data, kernel, strides, padding, dilation, out_dtype):
    """Compute depthwise_conv2d with NCHW layout.

    The first positional argument (the AutoTVM config) is not used; all other
    arguments are forwarded unchanged to ``nn.depthwise_conv2d_nchw``.
    """
    conv_args = (data, kernel, strides, padding, dilation, out_dtype)
    return nn.depthwise_conv2d_nchw(*conv_args)
34
35
@autotvm.register_topi_schedule("depthwise_conv2d_nchw.arm_cpu")
def schedule_depthwise_conv2d_nchw(cfg, outs):
    """Schedule depthwise conv2d

    Parameters
    ----------
    cfg: ConfigEntity
        The configuration of this template
    outs: Array of Tensor
        The computation graph description of depthwise convolution2d
        in the format of an array of tensors. A single Tensor is also
        accepted and is wrapped into a list.

    Returns
    -------
    s: Schedule
        The computation schedule for depthwise_conv2d nchw.
    """
    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
    s = te.create_schedule([x.op for x in outs])

    def _schedule(cfg, s, data, data_pad, kernel, output):
        """Apply the tiled schedule to one depthwise conv stage.

        ``data_pad`` is the padding stage feeding the convolution, or None
        when the convolution reads ``data`` directly.
        """
        B, C = kernel, output
        if data_pad is None:
            # No padding stage in front of the conv: nothing to inline, and the
            # conv reads ``data`` itself. Indexing the schedule with None would fail.
            conv_input = data
        else:
            s[data_pad].compute_inline()
            conv_input = data_pad

        ##### space definition begin #####
        n, c, h, w = s[output].op.axis
        _, vc = cfg.define_split("tile_c", c, num_outputs=2)
        _, vh = cfg.define_split("tile_h", h, num_outputs=2)
        _, vw = cfg.define_split("tile_w", w, num_outputs=2)
        cfg.define_annotate("ann", [vh, vw, vc], policy="try_unroll_vec")

        # fallback support
        if cfg.is_fallback:
            ref_log = autotvm.tophub.load_reference_log(
                "arm_cpu", "rk3399", "depthwise_conv2d_nchw.arm_cpu"
            )
            cfg.fallback_with_reference_log(ref_log)
        ##### space definition end #####

        # pack data to vector form  [n, c, h, w] -> [n, C, h, w, VC]
        A0 = s.cache_read(conv_input, "global", C)
        n, c, h, w = s[A0].op.axis
        c, vc = cfg["tile_c"].apply(s, A0, c)
        s[A0].reorder(n, c, h, w, vc)
        A1 = s.cache_write(A0, "global")
        s[A0].compute_inline()

        # pack kernel to vector form  [co, ci, kh, kw] -> [CO, ci, kh, kw, VC]
        B0 = s.cache_read(B, "global", C)
        c, m, h, w = s[B0].op.axis
        c, vc = cfg["tile_c"].apply(s, B0, c)
        s[B0].reorder(c, m, h, w, vc)
        B1 = s.cache_write(B0, "global")
        s[B0].compute_inline()

        n, c, h, w = s[C].op.axis
        c, vc = cfg["tile_c"].apply(s, C, c)
        s[C].reorder(n, c, h, w, vc)

        # depthwise conv
        C0 = s.cache_write(C, "global")
        _, c, h, w, vc = s[C0].op.axis
        dh, dw = s[C0].op.reduce_axis
        oh, ih = cfg["tile_h"].apply(s, C0, h)
        ow, iw = cfg["tile_w"].apply(s, C0, w)
        s[C0].reorder(c, oh, ow, dh, dw, ih, iw, vc)
        s[A1].compute_at(s[C0], oh)

        # try unroll and vectorization
        cfg["ann"].apply(
            s,
            C0,
            [ih, iw, vc],
            axis_lens=[cfg["tile_h"].size[-1], cfg["tile_w"].size[-1], cfg["tile_c"].size[-1]],
            max_unroll=16,
            cfg=cfg,
        )

        # fusion: when the conv is not a schedule output, fold it into its consumer
        if C.op not in s.outputs:
            s[C].compute_inline()

        # mark parallel
        last = outs[0]
        n, c, h, w = s[last].op.axis
        s[last].parallel(c)

        n, c, h, w, vc = s[C0].op.axis
        s[C0].parallel(c)

        c, m, h, w, vc = s[B1].op.axis
        s[B1].parallel(c)

        return s

    def _callback(op):
        if op.tag == "depthwise_conv2d_nchw":
            output = op.output(0)
            kernel = op.input_tensors[1]
            data = op.input_tensors[0]
            data_pad = None
            # When the conv input is a pad stage, schedule through it to the raw data
            if isinstance(data.op, tvm.te.ComputeOp) and "pad" in data.op.tag:
                data_pad = data
                data = data_pad.op.input_tensors[0]
            _schedule(cfg, s, data, data_pad, kernel, output)

    traverse_inline(s, outs[0].op, _callback)
    return s
148
149
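# TODO:
# This schedule produces incorrect results on some hardware platforms
# (e.g. NVIDIA Jetson TX2).
# The decision was to comment it out rather than remove it.
# NOTE(review): the definitions below are still active in this file; confirm
# where this template is actually disabled.
# See discussion:
# https://discuss.tvm.ai/t/autotuner-incorrect-result-after-tuning-mobilenetv2-on-arm-cpu/6088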
@autotvm.register_topi_compute("depthwise_conv2d_nchw_spatial_pack.arm_cpu")
def depthwise_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype):
    """TOPI compute callback for the spatial-pack depthwise_conv2d (NCHW).

    Delegates to ``_decl_spatial_pack`` with ``num_tile=2``.

    Parameters
    ----------
    cfg: ConfigEntity
        The config for this template

    data : tvm.te.Tensor
        4-D with shape [batch, in_channel, in_height, in_width]

    kernel : tvm.te.Tensor
        4-D with shape [num_filter, multiplier, filter_height, filter_width] or
        pre-packed 5-D with shape [num_filter_chunk, multiplier, filter_height,
        filter_width, num_filter_block]

    strides : list of two ints
        [stride_height, stride_width]

    padding : list of two ints
        [pad_height, pad_width]

    dilation : list of two ints
        [dilation_height, dilation_width]

    out_dtype: str
        The output type. This is used for mixed precision.

    Returns
    -------
    output : tvm.te.Tensor
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    packed_output = _decl_spatial_pack(
        cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2
    )
    return packed_output
191
192
@autotvm.register_topi_compute("depthwise_conv2d_nhwc.arm_cpu")
def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, out_dtype):
    """TOPI compute callback for depthwise_conv2d nhwc

    Parameters
    ----------
    _ : ConfigEntity
        The config for this template (not used by this compute)

    data : tvm.te.Tensor
        4-D with shape [batch, in_height, in_width, in_channel]

    kernel : tvm.te.Tensor
        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]

    strides : int or list of two ints
        [stride_height, stride_width]; a single int is used for both

    padding : list of two ints
        [pad_height, pad_width]

    dilation : int or list of two ints
        [dilation_height, dilation_width]; a single int is used for both

    out_dtype: str
        The output type. This is used for mixed precision. Falls back to the
        dtype of ``data`` when empty.

    Returns
    -------
    output : tvm.te.Tensor
        4-D with shape [batch, out_height, out_width, out_channel]
    """
    out_dtype = out_dtype or data.dtype

    batch, in_height, in_width, _data_channels = get_const_tuple(data.shape)
    # The channel count used for the output shape comes from the kernel
    kernel_h, kernel_w, in_channels, channel_multiplier = get_const_tuple(kernel.shape)

    dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation

    if isinstance(strides, (tuple, list)):
        stride_h, stride_w = strides
    else:
        stride_h = stride_w = strides

    # Effective kernel extent once dilation gaps are accounted for
    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1

    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w)
    )

    out_height = (in_height + pad_top + pad_down - dilated_kernel_h) // stride_h + 1
    out_width = (in_width + pad_left + pad_right - dilated_kernel_w) // stride_w + 1

    needs_padding = any((pad_top, pad_left, pad_down, pad_right))
    data_pad = (
        nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0], name="data_pad")
        if needs_padding
        else data
    )

    idxdiv = tvm.tir.indexdiv
    idxmod = tvm.tir.indexmod

    reduce_h = te.reduce_axis((0, kernel_h), name="reduce_h")
    reduce_w = te.reduce_axis((0, kernel_w), name="reduce_w")

    def _output_element(n, h, w, c):
        # Output channel c reads input channel c // multiplier and uses the
        # (c % multiplier)-th filter of that channel.
        pixel = data_pad[
            n,
            stride_h * h + dilation_h * reduce_h,
            w * stride_w + reduce_w * dilation_w,
            idxdiv(c, channel_multiplier),
        ]
        weight = kernel[
            reduce_h, reduce_w, idxdiv(c, channel_multiplier), idxmod(c, channel_multiplier)
        ]
        return te.sum(
            pixel.astype(out_dtype) * weight.astype(out_dtype), axis=[reduce_h, reduce_w]
        )

    return te.compute(
        (batch, out_height, out_width, in_channels * channel_multiplier),
        _output_element,
        name="depthwise_conv2d_nhwc_output",
    )
280
281
@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu")
def schedule_depthwise_conv2d_nhwc(cfg, outs):
    """Create the schedule for depthwise_conv2d_nhwc

    Parameters
    ----------
    cfg: ConfigEntity
        The configuration of this template
    outs: Array of Tensor
        The computation graph description of the depthwise convolution.
        A single Tensor is also accepted and is wrapped into a list.

    Returns
    -------
    s: Schedule
        The computation schedule for depthwise_conv2d nhwc.
    """
    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
    s = te.create_schedule([x.op for x in outs])
    out = outs[0]

    ##### space definition begin #####
    _, h, w, c = s[out].op.axis
    cfg.define_split("tile_c", c, num_outputs=2)
    cfg.define_split("tile_h", h, num_outputs=2)
    cfg.define_split("tile_w", w, num_outputs=2)
    cfg.define_knob("locate_output", [0, 1])

    # fallback support
    if cfg.is_fallback:
        cfg["tile_c"] = SplitEntity([-1, 8])
        cfg["tile_h"] = SplitEntity([-1, 2])
        cfg["tile_w"] = SplitEntity([-1, 2])
        cfg["locate_output"] = OtherOptionEntity(1)
    ##### space definition end #####

    def schedule_conv(conv):
        """Tile and vectorize the conv stage; return the fused (n, ho) axis."""
        conv_data = conv.op.input_tensors[0]

        # The conv is computed over (n, h, w, c); unpack in that order so that
        # tile_h / tile_w are applied to the height / width axes respectively.
        n, h, w, c = conv.op.axis
        r_h, r_w = conv.op.reduce_axis
        ho, hi = cfg["tile_h"].apply(s, conv, h)
        wo, wi = cfg["tile_w"].apply(s, conv, w)
        co, ci = cfg["tile_c"].apply(s, conv, c)

        if conv_data.name == "data_pad":
            assert isinstance(conv_data.op, tvm.te.ComputeOp)
            # Define a policy for padding computation
            cfg.define_knob("data_pad_inline", [1, 2, 3])
            if cfg.is_fallback:
                cfg["data_pad_inline"] = OtherOptionEntity(3)
            if cfg["data_pad_inline"].val == 1:
                s[conv_data].vectorize(list(s[conv_data].op.axis)[-1])
                s[conv_data].compute_at(s[conv], ho)
            if cfg["data_pad_inline"].val == 2:
                s[conv_data].vectorize(list(s[conv_data].op.axis)[-1])
                s[conv_data].compute_at(s[conv], wo)
            if cfg["data_pad_inline"].val == 3:
                s[conv_data].compute_inline()

        s[conv].reorder(n, ho, wo, co, hi, wi, r_h, r_w, ci)
        fused_n_ho = s[conv].fuse(n, ho)
        s[conv].vectorize(ci)
        return fused_n_ho

    def schedule_conv_out(out):
        """Tile the final output stage; return (hi, wi, fused (n, ho) axis)."""
        n, h, w, c = out.op.axis
        co, ci = cfg["tile_c"].apply(s, out, c)
        wo, wi = cfg["tile_w"].apply(s, out, w)
        ho, hi = cfg["tile_h"].apply(s, out, h)
        s[out].reorder(n, ho, wo, co, hi, wi)

        if out.dtype in ["int8", "uint8"]:
            # In case of quantized convolution further split the channel in batches of 4 elements
            # so that we can use arm intrinsics to run fixed_point_multiplication
            _, ci_inner = s[out].split(ci, 4)
            s[out].vectorize(ci_inner)

        fused_n_ho = s[out].fuse(n, ho)
        return hi, wi, fused_n_ho

    def _callback(op):
        if op.name == "depthwise_conv2d_nhwc_output":
            conv = op.output(0)
            if conv != out:
                # Conv is followed by other stages: tile the final output and
                # attach the conv inside it at the location chosen by the knob.
                hi, wi, p_axis = schedule_conv_out(out)
                schedule_conv(conv)
                if cfg["locate_output"].val == 0:
                    s[conv].compute_at(s[out], hi)
                if cfg["locate_output"].val == 1:
                    s[conv].compute_at(s[out], wi)
            else:
                p_axis = schedule_conv(out)

            s[out].parallel(p_axis)

    traverse_inline(s, outs[0].op, _callback)
    return s
366
367
@autotvm.register_topi_schedule("depthwise_conv2d_nchw_spatial_pack.arm_cpu")
def schedule_depthwise_conv2d_nchw_spatial_pack(cfg, outs):
    """Create the schedule for depthwise_conv2d_nchw_spatial_pack

    ``outs`` may be a single Tensor or an array of Tensors; the schedule built
    over their ops is returned.
    """
    if isinstance(outs, te.tensor.Tensor):
        outs = [outs]
    s = te.create_schedule([tensor.op for tensor in outs])
    final_tensor = outs[0]

    def _visit(op):
        # Only the unpack stage of the spatial-pack compute triggers scheduling
        if op.tag != "spatial_depthwise_conv2d_nchw_output":
            return

        unpacked = op.output(0)
        conv = op.input_tensors[0]
        conv_inputs = conv.op.input_tensors
        data_vec = conv_inputs[0]
        kernel_vec = conv_inputs[1]

        # Step through the packing stage (if present) to reach the kernel tensor
        is_packing_stage = kernel_vec.op.name == "kernel_vec"
        kernel = kernel_vec.op.input_tensors[0] if is_packing_stage else kernel_vec

        if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
            s[kernel].compute_inline()

        _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, conv, unpacked, final_tensor)

    traverse_inline(s, final_tensor.op, _visit)
    return s
390
391
def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile):
    """Declare the spatial-pack depthwise conv2d compute (NCHW) and its config space.

    Parameters
    ----------
    cfg: ConfigEntity
        The config for this template. The tiling, reorder and annotation
        knobs are defined on it here.

    data : tvm.te.Tensor
        4-D, unpacked as (N, C, IH, IW)

    kernel : tvm.te.Tensor
        4-D, unpacked as (C, M, KH, KW), or pre-packed 5-D, unpacked as
        (C // VC, M, KH, KW, VC)

    strides : int or tuple/list of two ints
        (stride_height, stride_width); a single int is used for both

    padding : int or list of ints
        Forwarded to ``get_pad_tuple``

    dilation : int or list of two ints
        (dilation_height, dilation_width); a single int is used for both

    out_dtype: str
        The output type. Falls back to the dtype of ``data`` when empty.

    num_tile : int
        Only 2 (arm cpu) is supported.

    Returns
    -------
    output : tvm.te.Tensor
        4-D with shape (N, C * M, OH, OW), tagged
        "spatial_depthwise_conv2d_nchw_output"

    Raises
    ------
    RuntimeError
        If ``num_tile`` is not 2.
    """
    out_dtype = out_dtype or data.dtype

    N, C, IH, IW = get_const_tuple(data.shape)

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    if len(kernel.shape) == 4:
        pre_packed = False
        C, M, KH, KW = get_const_tuple(kernel.shape)
    else:  # kernel tensor is pre packed
        pre_packed = True
        C, M, KH, KW, VC = get_const_tuple(kernel.shape)
        C = C * VC

    dilated_kernel_h = (KH - 1) * dilation_h + 1
    dilated_kernel_w = (KW - 1) * dilation_w + 1

    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w)
    )
    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
    # pack data
    HPAD = pad_top + pad_down
    WPAD = pad_left + pad_right
    DOPAD = HPAD != 0 or WPAD != 0
    if DOPAD:
        data_pad = nn.pad(
            data, (0, 0, pad_top, pad_left), (0, 0, pad_down, pad_right), name="data_pad"
        )
    else:
        data_pad = data

    # ==================== define configuration space ====================
    n, c, oh, ow = cfg.axis(N), cfg.axis(C), cfg.axis(OH), cfg.axis(OW)
    kh, kw = cfg.reduce_axis(KH), cfg.reduce_axis(KW)

    # Currently, Mali schedule doesn't use it like conv2d.
    # Leave num_tile for possible future use of Mali schedule
    if num_tile == 2:  # for arm cpu
        co, vc = cfg.define_split("tile_co", c, num_outputs=2)
        oh, vh = cfg.define_split("tile_oh", oh, num_outputs=2)
        ow, vw = cfg.define_split("tile_ow", ow, num_outputs=2)
    else:
        raise RuntimeError("Invalid num_tile")

    cfg.define_reorder(
        "reorder_0",
        [n, co, oh, ow, kh, kw, vh, vw, vc],
        policy="candidate",
        candidate=[[n, co, oh, ow, kh, kw, vh, vw, vc], [n, co, oh, ow, kh, kw, vc, vh, vw]],
    )

    cfg.define_reorder(
        "reorder_1",
        [n, co, oh, ow, vh, vw, vc],
        policy="candidate",
        candidate=[
            [n, co, oh, ow, vh, vw, vc],
            [n, co, oh, ow, vc, vh, vw],
            [n, co, oh, ow, vh, vc, vw],
        ],
    )

    cfg.define_annotate("ann_reduce", [kh, kw], policy="try_unroll")
    cfg.define_annotate("ann_spatial", [vh, vw, vc], policy="try_unroll_vec")

    # fallback support
    # The reference-log fallback picks values for the knobs already defined on
    # cfg, so it has to run after the space definition above (the same order
    # schedule_depthwise_conv2d_nchw uses in this file).
    if cfg.is_fallback:
        ref_log = autotvm.tophub.load_reference_log(
            "arm_cpu", "rk3399", "depthwise_conv2d_nchw_spatial_pack.arm_cpu"
        )
        cfg.fallback_with_reference_log(ref_log)
    # ====================================================================

    VC = cfg["tile_co"].size[-1]
    VH = cfg["tile_oh"].size[-1]
    VW = cfg["tile_ow"].size[-1]

    kvshape = (C // VC, M, KH, KW, VC)
    ovshape = (N, C * M // VC, OH // VH, OW // VW, VH, VW, VC)
    oshape = (N, C * M, OH, OW)

    if dilation_h != 1 or dilation_w != 1:
        # undilate input data
        dvshape = (N, OH // VH, OW // VW, C, KH, KW, VH, VW)
        data_vec = te.compute(
            dvshape,
            lambda n, h, w, c, kh, kw, vh, vw: data_pad[n][c][
                (h * VH + vh) * HSTR + kh * dilation_h
            ][(w * VW + vw) * WSTR + kw * dilation_w],
            name="data_vec_undilated",
        )
    else:
        dvshape = (N, OH // VH, OW // VW, C, VH * HSTR + KH - 1, VW * WSTR + KW - 1)
        data_vec = te.compute(
            dvshape,
            lambda n, h, w, c, vh, vw: data_pad[n][c][h * VH * HSTR + vh][w * VW * WSTR + vw],
            name="data_vec",
        )

    if pre_packed:
        kernel_vec = kernel
    else:
        kernel_vec = te.compute(
            kvshape, lambda co, m, kh, kw, vc: kernel[co * VC + vc][m][kh][kw], name="kernel_vec"
        )

    kh = te.reduce_axis((0, KH), name="kh")
    kw = te.reduce_axis((0, KW), name="kw")

    idxdiv = tvm.tir.indexdiv
    idxmod = tvm.tir.indexmod

    def _kernel_at(co, vc):
        """Return the kernel_vec element used by packed output channel (co, vc)."""
        if M == 1:
            # One filter per input channel: (co, vc) addresses kernel_vec directly.
            chunk, mult, lane = co, 0, vc
        else:
            # Output channel oc uses filter (oc % M) of input channel oc // M;
            # that input channel lives at kernel_vec[ic // VC, ..., ic % VC].
            out_c = co * VC + vc
            in_c = idxdiv(out_c, M)
            chunk, mult, lane = idxdiv(in_c, VC), idxmod(out_c, M), idxmod(in_c, VC)
        return kernel_vec[chunk, mult, kh, kw, lane].astype(out_dtype)

    if dilation_h != 1 or dilation_w != 1:
        conv = te.compute(
            ovshape,
            lambda n, co, h, w, vh, vw, vc: te.sum(
                data_vec[n, h, w, idxdiv(co * VC + vc, M), kh, kw, vh, vw].astype(out_dtype)
                * _kernel_at(co, vc),
                axis=[kh, kw],
            ),
            name="depthwise_conv",
        )
    else:
        conv = te.compute(
            ovshape,
            lambda n, co, h, w, vh, vw, vc: te.sum(
                data_vec[n, h, w, idxdiv((co * VC + vc), M), vh * HSTR + kh, vw * WSTR + kw].astype(
                    out_dtype
                )
                * _kernel_at(co, vc),
                axis=[kh, kw],
            ),
            name="depthwise_conv",
        )

    # Unpack the tiled result back to plain NCHW
    output = te.compute(
        oshape,
        lambda n, co, h, w: conv[
            n,
            idxdiv(co, VC),
            idxdiv(h, VH),
            idxdiv(w, VW),
            idxmod(h, VH),
            idxmod(w, VW),
            idxmod(co, VC),
        ],
        name="output_unpack",
        tag="spatial_depthwise_conv2d_nchw_output",
    )
    return output
550
551
def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, conv, output, last):
    """Schedule the stages of the spatial-pack depthwise conv2d.

    Parameters
    ----------
    cfg: ConfigEntity
        The config for this template; additional knobs ("data_pad_inline",
        "data_vec_inline", "conv_inline") are defined on it here.
    s: Schedule
        The schedule being built; it is modified in place.
    data_vec : tvm.te.Tensor
        The packed input stage (op named "data_vec_undilated" or otherwise
        the 6-axis packed form).
    kernel_vec : tvm.te.Tensor
        The kernel consumed by ``conv``; it is only scheduled when its op is
        named "kernel_vec".
    conv : tvm.te.Tensor
        The tiled convolution stage.
    output : tvm.te.Tensor
        The unpack stage; inlined when it is not ``last``.
    last : tvm.te.Tensor
        The stage that is tiled, annotated and parallelized as the final one.

    Returns
    -------
    s: Schedule
        The same schedule object that was passed in.
    """
    n, co, oh, ow, vh, vw, vc = s[conv].op.axis
    kh, kw = s[conv].op.reduce_axis

    # The undilated packing has two extra (kernel) axes in the middle
    if data_vec.op.name == "data_vec_undilated":
        _, dv_oh, dv_ow, dv_c, _, _, dv_vh, dv_vw = s[data_vec].op.axis
    else:
        _, dv_oh, dv_ow, dv_c, dv_vh, dv_vw = s[data_vec].op.axis

    # data_vec reads either the padding stage or the raw input placeholder
    data_pad = data_vec.op.input_tensors[0]
    if data_pad.op.name == "data_pad":
        assert isinstance(data_pad.op, tvm.te.ComputeOp)
        has_padding = True
    else:
        assert isinstance(data_pad.op, tvm.te.PlaceholderOp)
        has_padding = False

    # Padding policy: 0 leaves the pad stage untouched, 1 inlines it, 2 only
    # vectorizes it, 3/4 vectorize it and attach it inside data_vec.
    cfg.define_knob("data_pad_inline", [0, 1, 2, 3, 4])

    if cfg["data_pad_inline"].val == 1 and has_padding:
        s[data_pad].compute_inline()
    if cfg["data_pad_inline"].val == 2 and has_padding:
        s[data_pad].vectorize(list(s[data_pad].op.axis)[-1])
    if cfg["data_pad_inline"].val == 3 and has_padding:
        s[data_pad].vectorize(list(s[data_pad].op.axis)[-1])
        s[data_pad].compute_at(s[data_vec], dv_oh)
    if cfg["data_pad_inline"].val == 4 and has_padding:
        s[data_pad].vectorize(list(s[data_pad].op.axis)[-1])
        s[data_pad].compute_at(s[data_vec], dv_ow)

    # Where to attach data_vec inside conv; 0 leaves it as its own stage
    cfg.define_knob("data_vec_inline", [0, 1, 2, 3])
    if cfg["data_vec_inline"].val == 1:
        s[data_vec].compute_at(s[conv], oh)
    if cfg["data_vec_inline"].val == 2:
        s[data_vec].compute_at(s[conv], ow)
    if cfg["data_vec_inline"].val == 3:
        s[data_vec].compute_at(s[conv], co)

    # schedule conv
    cfg["reorder_0"].apply(s, conv, [n, co, oh, ow, kh, kw, vh, vw, vc])
    cfg["ann_reduce"].apply(
        s,
        conv,
        [kh, kw],
        axis_lens=[get_const_int(kh.dom.extent), get_const_int(kw.dom.extent)],
        max_unroll=16,
        cfg=cfg,
    )
    cfg["ann_spatial"].apply(
        s,
        conv,
        [vh, vw, vc],
        axis_lens=[cfg["tile_oh"].size[-1], cfg["tile_ow"].size[-1], cfg["tile_co"].size[-1]],
        max_unroll=16,
        cfg=cfg,
    )

    # schedule fusion
    # From here on co/oh/ow/vh/vw/vc refer to the axes of ``last``, not ``conv``
    n, co, h, w = s[last].op.axis
    co, vc = cfg["tile_co"].apply(s, last, co)
    oh, vh = cfg["tile_oh"].apply(s, last, h)
    ow, vw = cfg["tile_ow"].apply(s, last, w)
    cfg["reorder_1"].apply(s, last, [n, co, oh, ow, vh, vw, vc])
    if last != output:
        # Other stages follow the unpack: inline it into its consumer
        s[output].compute_inline()
        cfg["ann_spatial"].apply(
            s,
            last,
            [vh, vw, vc],
            axis_lens=[cfg["tile_oh"].size[-1], cfg["tile_ow"].size[-1], cfg["tile_co"].size[-1]],
            max_unroll=16,
            cfg=cfg,
        )
    else:
        s[last].vectorize(vw)
    # Where to attach conv inside ``last``; 0 leaves it as its own stage
    cfg.define_knob("conv_inline", [0, 1, 2, 3])
    if cfg["conv_inline"].val == 1:
        s[conv].compute_at(s[last], ow)
    if cfg["conv_inline"].val == 2:
        s[conv].compute_at(s[last], oh)
    if cfg["conv_inline"].val == 3:
        s[conv].compute_at(s[last], co)

    # mark parallel
    s[last].parallel(co)

    # Parallelize data_vec over its second axis
    if data_vec.op.name == "data_vec_undilated":
        _, h, _, _, _, _, _, _ = s[data_vec].op.axis
    else:
        _, h, _, _, _, _ = s[data_vec].op.axis
    s[data_vec].parallel(h)

    if kernel_vec.op.name == "kernel_vec":
        co, _, _, _, _ = s[kernel_vec].op.axis
        if autotvm.GLOBAL_SCOPE.in_tuning:
            # kernel packing will be pre-computed during compilation, so we skip
            # this part to make tuning records correct
            s[kernel_vec].pragma(co, "debug_skip_region")
        else:
            s[kernel_vec].parallel(co)

    return s
655