# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Schedule templates for CUDA conv2d operators."""
import tvm
from tvm import autotvm
from ..util import get_const_tuple

def schedule_direct_cuda(cfg, s, conv):
    """Schedule optimized for batch size = 1."""

    ##### space definition begin #####
    n, f, y, x = s[conv].op.axis
    rc, ry, rx = s[conv].op.reduce_axis
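    # split each spatial axis into four factors (block, vthread, thread, inner)
    # and each reduction axis into two (outer, inner)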
    cfg.define_split("tile_f", f, num_outputs=4)
    cfg.define_split("tile_y", y, num_outputs=4)
    cfg.define_split("tile_x", x, num_outputs=4)
    cfg.define_split("tile_rc", rc, num_outputs=2)
    cfg.define_split("tile_ry", ry, num_outputs=2)
    cfg.define_split("tile_rx", rx, num_outputs=2)
    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])

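    # llvm-based codegen backends (nvptx, rocm) rely on explicit unrolling,
    # so the knob is pinned to 1 for them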
    target = tvm.target.current_target()
    if target.target_name in ['nvptx', 'rocm']:
        cfg.define_knob("unroll_explicit", [1])
    else:
        cfg.define_knob("unroll_explicit", [0, 1])

    # fallback support
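    # if no tuned config exists for this workload, seed the config from
    # tophub's pre-tuned reference logs instead of using a blind default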
    if cfg.is_fallback:
        ref_log = autotvm.tophub.load_reference_log(
            target.target_name, target.model, 'conv2d', 'direct')
        cfg.fallback_with_reference_log(ref_log)
    ##### space definition end #####

    pad_data, kernel = s[conv].op.input_tensors

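    # inline the padding stage so no intermediate padded buffer is materialized;
    # likewise inline the kernel-dilation stage if one is present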
    s[pad_data].compute_inline()
    if isinstance(kernel.op, tvm.tensor.ComputeOp) and 'dilate' in kernel.op.tag:
        s[kernel].compute_inline()

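    # if conv is the final stage, add a register-level ('local') cache for it;
    # otherwise conv itself is an intermediate stage, so compute it in registers
    # and schedule the actual output stage instead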
    if conv.op in s.outputs:
        output = conv
        OL = s.cache_write(conv, 'local')
    else:
        output = s.outputs[0].output(0)
        s[conv].set_scope('local')
        OL = conv

    # create cache stage
    AA = s.cache_read(pad_data, 'shared', [OL])
    WW = s.cache_read(kernel, 'shared', [OL])

    # tile and bind spatial axes
    n, f, y, x = s[output].op.axis
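    # split off a dummy outermost scope on which the kernel-wide unrolling
    # pragmas are attached below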
    kernel_scope, n = s[output].split(n, nparts=1)

    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)

    bf = s[output].fuse(n, bf)
    s[output].bind(bf, tvm.thread_axis("blockIdx.z"))
    s[output].bind(by, tvm.thread_axis("blockIdx.y"))
    s[output].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[output].bind(vf, tvm.thread_axis("vthread"))
    s[output].bind(vy, tvm.thread_axis("vthread"))
    s[output].bind(vx, tvm.thread_axis("vthread"))
    s[output].bind(tf, tvm.thread_axis("threadIdx.z"))
    s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
    s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
    s[output].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
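    # compute the cached convolution result within each output thread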
    s[OL].compute_at(s[output], tx)

    # tile reduction axes
    n, f, y, x = s[OL].op.axis
    rc, ry, rx = s[OL].op.reduce_axis
    rco, rci = cfg['tile_rc'].apply(s, OL, rc)
    ryo, ryi = cfg['tile_ry'].apply(s, OL, ry)
    rxo, rxi = cfg['tile_rx'].apply(s, OL, rx)
    s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x)

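    # stage the data and weight tiles into shared memory once per outer
    # reduction step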
    s[AA].compute_at(s[OL], rxo)
    s[WW].compute_at(s[OL], rxo)

    # cooperative fetching
    for load in [AA, WW]:
        n, f, y, x = s[load].op.axis
        fused = s[load].fuse(n, f, y, x)
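        # distribute the fused load across the thread block; tile size[2] is
        # the threadIdx extent chosen for the output stage, so the load and
        # compute stages use the same thread shape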
        tz, fused = s[load].split(fused, nparts=cfg["tile_f"].size[2])
        ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2])
        tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2])
        s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
        s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
        s[load].bind(tx, tvm.thread_axis("threadIdx.x"))

    # unroll
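    # the pragmas are attached at kernel_scope so they apply to the whole kernel body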
    s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
    s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)

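    # record the FLOP count (2 ops per multiply-accumulate) so autotvm can
    # report GFLOPS during tuning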
    N, CO, OH, OW = get_const_tuple(output.shape)
    # kernel layout is OIHW: (out_channel, in_channel, kernel_h, kernel_w)
    _, CI, KH, KW = get_const_tuple(kernel.shape)
    cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW)

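# Example usage (a minimal sketch, not from the original file; assumes NCHW
# `data`/`weight` placeholders and topi's conv2d_nchw operator):
#
#   import topi
#   data = tvm.placeholder((1, 64, 56, 56), name='data')
#   weight = tvm.placeholder((64, 64, 3, 3), name='weight')
#   with tvm.target.cuda():
#       conv = topi.nn.conv2d_nchw(data, weight, stride=1, padding=1, dilation=1)
#       s = tvm.create_schedule(conv.op)
#       schedule_direct_cuda(autotvm.get_config(), s, conv)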