# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""The templates for cuda conv2d operators"""
import tvm
from tvm import autotvm
from ..util import get_const_tuple


def schedule_direct_cuda(cfg, s, conv):
    """Schedule a direct conv2d stage; optimized for batch size = 1.

    The function first declares the tuning space on ``cfg`` (tile splits for
    the three non-batch output axes and the three reduction axes, plus two
    unroll knobs), then applies the selected configuration to the schedule
    ``s`` in place.

    Parameters
    ----------
    cfg : object
        Tuning configuration.  It must provide ``define_split``,
        ``define_knob``, ``is_fallback``, ``fallback_with_reference_log``,
        ``add_flop`` and item access by knob name (presumably an AutoTVM
        config space/entity -- confirm against the caller).

    s : object
        Schedule containing ``conv``; it is mutated in place.

    conv : object
        The convolution tensor.  Its op must expose four data-parallel axes
        ``(n, f, y, x)``, three reduce axes ``(rc, ry, rx)`` and exactly two
        input tensors ``(pad_data, kernel)``.

    Returns
    -------
    None
        All effects are applied to ``cfg`` and ``s``.

    Notes
    -----
    ``tvm.target.current_target()`` is dereferenced unconditionally, so this
    presumably has to be called inside an active target scope -- confirm
    against the caller.
    """

    ##### space definition begin #####
    # NOTE: the order of the define_* calls is part of the config space
    # layout; keep it stable.
    n, f, y, x = s[conv].op.axis
    rc, ry, rx = s[conv].op.reduce_axis
    cfg.define_split("tile_f", f, num_outputs=4)
    cfg.define_split("tile_y", y, num_outputs=4)
    cfg.define_split("tile_x", x, num_outputs=4)
    cfg.define_split("tile_rc", rc, num_outputs=2)
    cfg.define_split("tile_ry", ry, num_outputs=2)
    cfg.define_split("tile_rx", rx, num_outputs=2)
    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])

    target = tvm.target.current_target()
    if target.target_name in ['nvptx', 'rocm']:
        # only explicit unrolling is offered for these targets
        cfg.define_knob("unroll_explicit", [1])
    else:
        cfg.define_knob("unroll_explicit", [0, 1])

    # fallback support
    if cfg.is_fallback:
        ref_log = autotvm.tophub.load_reference_log(
            target.target_name, target.model, 'conv2d', 'direct')
        cfg.fallback_with_reference_log(ref_log)
    ##### space definition end #####

    pad_data, kernel = s[conv].op.input_tensors

    s[pad_data].compute_inline()
    if isinstance(kernel.op, tvm.tensor.ComputeOp) and 'dilate' in kernel.op.tag:
        s[kernel].compute_inline()

    if conv.op in s.outputs:
        # conv is the final output: accumulate in a 'local' write cache
        output = conv
        OL = s.cache_write(conv, 'local')
    else:
        # conv feeds a later stage: that stage is the output and conv itself
        # becomes the 'local' accumulator
        output = s.outputs[0].output(0)
        s[conv].set_scope('local')
        OL = conv

    # create cache stage
    AA = s.cache_read(pad_data, 'shared', [OL])
    WW = s.cache_read(kernel, 'shared', [OL])

    # tile and bind spatial axes
    n, f, y, x = s[output].op.axis
    # single-iteration outer axis used only to attach the unroll pragmas below
    kernel_scope, n = s[output].split(n, nparts=1)

    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)

    bf = s[output].fuse(n, bf)
    s[output].bind(bf, tvm.thread_axis("blockIdx.z"))
    s[output].bind(by, tvm.thread_axis("blockIdx.y"))
    s[output].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[output].bind(vf, tvm.thread_axis("vthread"))
    s[output].bind(vy, tvm.thread_axis("vthread"))
    s[output].bind(vx, tvm.thread_axis("vthread"))
    s[output].bind(tf, tvm.thread_axis("threadIdx.z"))
    s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
    s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
    s[output].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
    s[OL].compute_at(s[output], tx)

    # tile reduction axes
    n, f, y, x = s[OL].op.axis
    rc, ry, rx = s[OL].op.reduce_axis
    rco, rci = cfg['tile_rc'].apply(s, OL, rc)
    ryo, ryi = cfg['tile_ry'].apply(s, OL, ry)
    rxo, rxi = cfg['tile_rx'].apply(s, OL, rx)
    s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x)

    s[AA].compute_at(s[OL], rxo)
    s[WW].compute_at(s[OL], rxo)

    # cooperative fetching
    for load in [AA, WW]:
        n, f, y, x = s[load].op.axis
        fused = s[load].fuse(n, f, y, x)
        # split the fused load by the same thread extents (index 2 of each
        # tile split) that were bound to threadIdx.z/y/x on the output
        tz, fused = s[load].split(fused, nparts=cfg["tile_f"].size[2])
        ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2])
        tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2])
        s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
        s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
        s[load].bind(tx, tvm.thread_axis("threadIdx.x"))

    # unroll
    s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
    s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)

    N, CO, OH, OW = get_const_tuple(output.shape)
    # Kernel dims are named to match the (f, rc, ry, rx) indexing used above
    # (assumes an OIHW kernel -- confirm); only their product enters the
    # FLOP estimate.
    _, CI, KH, KW = get_const_tuple(kernel.shape)
    cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW)