# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable
"""Depthwise convolution schedule for ARM CPU"""

import tvm
from tvm import autotvm

from ..generic import schedule_depthwise_conv2d_nchw
from ..nn import depthwise_conv2d_nchw, pad
from ..util import traverse_inline, get_const_tuple, get_const_int
from ..nn.util import get_pad_tuple

# register original implementation of depthwise_conv2d_nchw since we don't need to change this part
autotvm.register_topi_compute(depthwise_conv2d_nchw, 'arm_cpu', 'direct',
                              depthwise_conv2d_nchw.fdefault)

# register customized schedule for arm cpu.
@autotvm.register_topi_schedule(schedule_depthwise_conv2d_nchw, 'arm_cpu',
                                ['direct', 'contrib_spatial_pack'])
def schedule_depthwise_conv2d_nchw_arm(cfg, outs):
    """Schedule depthwise conv2d

    Parameters
    ----------
    cfg: ConfigEntity
        The configuration of this template
    outs: Array of Tensor
        The computation graph description of depthwise convolution2d
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for depthwise_conv2d nchw.
    """
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])

    def _schedule(cfg, s, data, data_pad, kernel, output):
        """Schedule for the 'direct' template: split the channel axis into
        [C, VC] vector lanes, tile the spatial axes, and parallelize over
        the (outer) channel axis."""
        # Short aliases: A = input data, B = kernel, C = conv output.
        A, B, C = data, kernel, output
        s[data_pad].compute_inline()

        ##### space definition begin #####
        n, c, h, w = s[output].op.axis
        _, vc = cfg.define_split('tile_c', c, num_outputs=2)
        _, vh = cfg.define_split('tile_h', h, num_outputs=2)
        _, vw = cfg.define_split('tile_w', w, num_outputs=2)
        cfg.define_annotate('ann', [vh, vw, vc], policy='try_unroll_vec')

        # fallback support: when no tuned config exists, borrow the closest
        # entry from the pre-tuned rk3399 reference log.
        if cfg.is_fallback:
            ref_log = autotvm.tophub.load_reference_log(
                'arm_cpu', 'rk3399', 'depthwise_conv2d_nchw', 'direct')
            cfg.fallback_with_reference_log(ref_log)
        ##### space definition end #####

        # pack data into vector form  [n, c, h, w] -> [n, C, h, w, VC]
        # A0 only establishes the packed layout and is inlined; A1 (its
        # cache_write stage) is the stage that actually materializes it.
        A0 = s.cache_read(data_pad, "global", C)
        n, c, h, w = s[A0].op.axis
        c, vc = cfg['tile_c'].apply(s, A0, c)
        s[A0].reorder(n, c, h, w, vc)
        A1 = s.cache_write(A0, 'global')
        s[A0].compute_inline()

        # pack kernel into vector form  [co, ci, kh, kw] -> [CO, ci, kh, kw, VC]
        # Same inline-the-layout-stage trick as for the data above.
        B0 = s.cache_read(B, "global", C)
        c, m, h, w = s[B0].op.axis
        c, vc, = cfg['tile_c'].apply(s, B0, c)
        s[B0].reorder(c, m, h, w, vc)
        B1 = s.cache_write(B0, 'global')
        s[B0].compute_inline()

        # Split the output channel axis the same way so stages line up.
        n, c, h, w = s[C].op.axis
        c, vc, = cfg['tile_c'].apply(s, C, c)
        s[C].reorder(n, c, h, w, vc)

        # depthwise conv: compute into a cached stage C0, tiling h and w and
        # keeping the reduction (dh, dw) outside the inner tile.
        C0 = s.cache_write(C, 'global')
        _, c, h, w, vc = s[C0].op.axis
        dh, dw = s[C0].op.reduce_axis
        oh, ih = cfg['tile_h'].apply(s, C0, h)
        ow, iw = cfg['tile_w'].apply(s, C0, w)
        s[C0].reorder(c, oh, ow, dh, dw, ih, iw, vc)
        # Materialize the packed data per outer-h tile of the conv stage.
        s[A1].compute_at(s[C0], oh)

        # try unroll and vectorization on the innermost tile axes
        cfg['ann'].apply(s, C0, [ih, iw, vc],
                         axis_lens=[cfg['tile_h'].size[-1],
                                    cfg['tile_w'].size[-1],
                                    cfg['tile_c'].size[-1]],
                         max_unroll=16,
                         cfg=cfg)

        # fusion: if C is not itself a graph output, fold it into its consumer
        if C.op not in s.outputs:
            s[C].compute_inline()

        # mark parallel over the channel axis of each materialized stage
        last = outs[0]
        n, c, h, w = s[last].op.axis
        s[last].parallel(c)

        n, c, h, w, vc = s[C0].op.axis
        s[C0].parallel(c)

        c, m, h, w, vc = s[B1].op.axis
        s[B1].parallel(c)

        return s

    def _callback(op):
        """Dispatch on the op tag: 'depthwise_conv2d_nchw' -> direct
        schedule; 'spatial_depthwise_conv2d_nchw_output' -> spatial pack."""
        if op.tag == 'depthwise_conv2d_nchw':
            output = op.output(0)
            kernel = op.input_tensors[1]
            data = op.input_tensors[0]
            data_pad = None
            # If the direct input is a pad stage, unwrap it so _schedule
            # receives both the raw data and the pad stage.
            if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
                data_pad = data
                data = data_pad.op.input_tensors[0]
            _schedule(cfg, s, data, data_pad, kernel, output)

        if op.tag == 'spatial_depthwise_conv2d_nchw_output':
            output = op.output(0)
            conv = op.input_tensors[0]
            data_vec = conv.op.input_tensors[0]
            kernel_vec = conv.op.input_tensors[1]
            # kernel_vec may be the packing stage (named 'kernel_vec') or an
            # already pre-packed tensor passed in directly.
            if kernel_vec.op.name == 'kernel_vec':
                kernel = kernel_vec.op.input_tensors[0]
            else:
                kernel = kernel_vec
            # Inline any dilation stage applied to the kernel.
            if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
                s[kernel].compute_inline()

            _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, conv, output, outs[0])

    traverse_inline(s, outs[0].op, _callback)
    return s

@autotvm.register_topi_compute(depthwise_conv2d_nchw, 'arm_cpu', ['contrib_spatial_pack'])
def depthwise_conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, out_dtype):
    """TOPI compute callback for depthwise_conv2d nchw

    Parameters
    ----------
    cfg: ConfigEntity
        The config for this template

    data : tvm.Tensor
        4-D with shape [batch, in_channel, in_height, in_width]

    kernel : tvm.Tensor
        4-D with shape [num_filter, multiplier, filter_height, filter_width] or
        pre-packed 5-D with shape [num_filter_chunk, multiplier, filter_height,
        filter_width, num_filter_block]

    strides : list of two ints
        [stride_height, stride_width]

    padding : list of two ints
        [pad_height, pad_width]

    dilation : list of two ints
        [dilation_height, dilation_width]

    out_dtype: str
        The output type. This is used for mixed precision.

    Returns
    -------
    output : tvm.Tensor
        4-D with shape [batch, out_channel, out_height, out_width]
    """

    return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)


def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile):
    """Declare the spatial-pack depthwise conv2d compute.

    Packs the padded input into a tiled layout (``data_vec``) and the kernel
    into a channel-vectorized layout (``kernel_vec``), computes the
    convolution in the packed layout, then unpacks to NCHW.  The final
    stage is tagged 'spatial_depthwise_conv2d_nchw_output' so the schedule
    callback above can locate it.
    """
    out_dtype = out_dtype or data.dtype

    N, C, IH, IW = get_const_tuple(data.shape)

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    if len(kernel.shape) == 4:
        pre_packed = False
        C, M, KH, KW = get_const_tuple(kernel.shape)
    else:  # kernel tensor is pre packed: [C // VC, M, KH, KW, VC]
        pre_packed = True
        C, M, KH, KW, VC = get_const_tuple(kernel.shape)
        C = C * VC

    # Effective kernel extent after dilation.
    dilated_kernel_h = (KH - 1) * dilation_h + 1
    dilated_kernel_w = (KW - 1) * dilation_w + 1

    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w))
    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
    # pack data
    HPAD = pad_top + pad_down
    WPAD = pad_left + pad_right
    DOPAD = (HPAD != 0 or WPAD != 0)
    if DOPAD:
        data_pad = pad(data, (0, 0, pad_top, pad_left), (0, 0, pad_down, pad_right),
                       name="data_pad")
    else:
        data_pad = data

    # fallback support
    # Currently, Mali schedule doesn't use it like conv2d.
    if cfg.is_fallback:
        ref_log = autotvm.tophub.load_reference_log('arm_cpu', 'rk3399', 'depthwise_conv2d_nchw',
                                                    'contrib_spatial_pack')
        cfg.fallback_with_reference_log(ref_log)

    # ==================== define configuration space ====================
    n, c, oh, ow = cfg.axis(N), cfg.axis(C), cfg.axis(OH), cfg.axis(OW)
    kh, kw = cfg.reduce_axis(KH), cfg.reduce_axis(KW)

    # Currently, Mali schedule doesn't use it like conv2d.
    # Leave num_tile for possible future use of Mali schedule
    if num_tile == 2:  # for arm cpu
        co, vc = cfg.define_split('tile_co', c, num_outputs=2)
        oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2)
        ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2)
    else:
        raise RuntimeError("Invalid num_tile")

    cfg.define_reorder("reorder_0",
                       [n, co, oh, ow, kh, kw, vh, vw, vc],
                       policy='candidate', candidate=[
                           [n, co, oh, ow, kh, kw, vh, vw, vc],
                           [n, co, oh, ow, kh, kw, vc, vh, vw]])

    cfg.define_reorder("reorder_1",
                       [n, co, oh, ow, vh, vw, vc],
                       policy='candidate', candidate=[
                           [n, co, oh, ow, vh, vw, vc],
                           [n, co, oh, ow, vc, vh, vw],
                           [n, co, oh, ow, vh, vc, vw]])

    cfg.define_annotate("ann_reduce", [kh, kw], policy='try_unroll')
    cfg.define_annotate("ann_spatial", [vh, vw, vc], policy='try_unroll_vec')
    # ====================================================================

    # Inner tile sizes chosen by the config (last factor of each split).
    VC = cfg["tile_co"].size[-1]
    VH = cfg["tile_oh"].size[-1]
    VW = cfg["tile_ow"].size[-1]

    kvshape = (C // VC, M, KH, KW, VC)
    ovshape = (N, C * M // VC, OH // VH, OW // VW, VH, VW, VC)
    oshape = (N, C * M, OH, OW)

    if dilation_h != 1 or dilation_w != 1:
        # undilate input data: gather each tile's receptive field explicitly,
        # with kh/kw as extra axes, so the conv below needs no dilated indexing.
        dvshape = (N, OH // VH, OW // VW, C, KH, KW, VH, VW)
        data_vec = tvm.compute(dvshape, lambda n, h, w, c, kh, kw, vh, vw:
                               data_pad[n][c][(h * VH + vh) * HSTR + kh * dilation_h]
                               [(w*VW+vw)*WSTR+kw*dilation_w],
                               name='data_vec_undilated')
    else:
        # No dilation: each tile carries its strided halo of size
        # (VH*HSTR + KH-1, VW*WSTR + KW-1).
        dvshape = (N, OH // VH, OW // VW, C, VH*HSTR + KH-1, VW*WSTR + KW-1)
        data_vec = tvm.compute(dvshape, lambda n, h, w, c, vh, vw:
                               data_pad[n][c][h * VH * HSTR + vh][w * VW * WSTR + vw],
                               name='data_vec')

    if pre_packed:
        kernel_vec = kernel
    else:
        kernel_vec = tvm.compute(kvshape, lambda co, m, kh, kw, vc:
                                 kernel[co*VC+vc][m][kh][kw],
                                 name='kernel_vec')

    kh = tvm.reduce_axis((0, KH), name='kh')
    kw = tvm.reduce_axis((0, KW), name='kw')

    idxdiv = tvm.indexdiv
    idxmod = tvm.indexmod

    if dilation_h != 1 or dilation_w != 1:
        # Dilated path reads the pre-gathered data_vec_undilated layout.
        conv = tvm.compute(
            ovshape, lambda n, co, h, w, vh, vw, vc: \
            tvm.sum(data_vec[n, h, w, idxdiv(co * VC + vc, M), kh, kw, vh, vw]
                    .astype(out_dtype) *
                    kernel_vec[idxdiv(co, M), idxmod(co, M), kh, kw, vc].astype(out_dtype),
                    axis=[kh, kw]), name='depthwise_conv')
    else:
        conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
                           tvm.sum(data_vec[n, h, w, idxdiv((co * VC + vc), M), vh * HSTR + kh,
                                            vw * WSTR + kw].astype(out_dtype) *
                                   kernel_vec[idxdiv(co, M),
                                              idxmod(co, M),
                                              kh, kw, vc].astype(out_dtype),
                                   axis=[kh, kw]), name='depthwise_conv')

    # Unpack [N, CO, OH//VH, OW//VW, VH, VW, VC] back to NCHW.
    output = tvm.compute(oshape, lambda n, co, h, w:
                         conv[n,
                              idxdiv(co, VC), idxdiv(h, VH), idxdiv(w, VW),
                              idxmod(h, VH), idxmod(w, VW), idxmod(co, VC)],
                         name='output_unpack', tag='spatial_depthwise_conv2d_nchw_output')
    return output

def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec,
                           conv, output, last):
    """Schedule for the spatial-pack template declared by _decl_spatial_pack.

    Applies the config's knobs: where to inline/compute_at the padding and
    data packing stages, the two candidate loop reorders, unroll/vectorize
    annotations, and parallelization over the outer channel axis.
    """
    n, co, oh, ow, vh, vw, vc = s[conv].op.axis
    kh, kw = s[conv].op.reduce_axis

    # The dilated variant has two extra (kh, kw) axes in the packed layout.
    if data_vec.op.name == 'data_vec_undilated':
        _, dv_oh, dv_ow, dv_c, _, _, dv_vh, dv_vw = s[data_vec].op.axis
    else:
        _, dv_oh, dv_ow, dv_c, dv_vh, dv_vw = s[data_vec].op.axis

    data_pad = data_vec.op.input_tensors[0]
    if data_pad.op.name == "data_pad":
        assert isinstance(data_pad.op, tvm.tensor.ComputeOp)
        has_padding = True
    else:
        assert isinstance(data_pad.op, tvm.tensor.PlaceholderOp)
        has_padding = False

    # 0: leave as-is, 1: inline pad, 2: vectorize pad,
    # 3/4: vectorize + compute_at the data packing stage (outer/inner tile).
    cfg.define_knob('data_pad_inline', [0, 1, 2, 3, 4])

    if cfg['data_pad_inline'].val == 1 and has_padding:
        s[data_pad].compute_inline()
    if cfg['data_pad_inline'].val == 2 and has_padding:
        s[data_pad].vectorize(list(s[data_pad].op.axis)[-1])
    if cfg['data_pad_inline'].val == 3 and has_padding:
        s[data_pad].vectorize(list(s[data_pad].op.axis)[-1])
        s[data_pad].compute_at(s[data_vec], dv_oh)
    if cfg['data_pad_inline'].val == 4 and has_padding:
        s[data_pad].vectorize(list(s[data_pad].op.axis)[-1])
        s[data_pad].compute_at(s[data_vec], dv_ow)

    # Where to attach the data packing stage relative to the conv loops.
    cfg.define_knob('data_vec_inline', [0, 1, 2, 3])
    if cfg['data_vec_inline'].val == 1:
        s[data_vec].compute_at(s[conv], oh)
    if cfg['data_vec_inline'].val == 2:
        s[data_vec].compute_at(s[conv], ow)
    if cfg['data_vec_inline'].val == 3:
        s[data_vec].compute_at(s[conv], co)

    # schedule conv
    cfg["reorder_0"].apply(s, conv, [n, co, oh, ow, kh, kw, vh, vw, vc])
    cfg["ann_reduce"].apply(s, conv, [kh, kw],
                            axis_lens=[get_const_int(kh.dom.extent),
                                       get_const_int(kw.dom.extent)],
                            max_unroll=16,
                            cfg=cfg)
    cfg["ann_spatial"].apply(s, conv, [vh, vw, vc],
                             axis_lens=[cfg['tile_oh'].size[-1],
                                        cfg['tile_ow'].size[-1],
                                        cfg['tile_co'].size[-1]],
                             max_unroll=16,
                             cfg=cfg)

    # schedule fusion: tile the final (possibly fused) output the same way
    n, co, h, w = s[last].op.axis
    co, vc = cfg['tile_co'].apply(s, last, co)
    oh, vh = cfg['tile_oh'].apply(s, last, h)
    ow, vw = cfg['tile_ow'].apply(s, last, w)
    cfg["reorder_1"].apply(s, last, [n, co, oh, ow, vh, vw, vc])
    if last != output:
        # There is a fused consumer: inline the unpack stage into it.
        s[output].compute_inline()
        cfg["ann_spatial"].apply(s, last, [vh, vw, vc],
                                 axis_lens=[cfg['tile_oh'].size[-1],
                                            cfg['tile_ow'].size[-1],
                                            cfg['tile_co'].size[-1]],
                                 max_unroll=16,
                                 cfg=cfg)
    else:
        s[last].vectorize(vw)
    # Where to attach the conv stage relative to the output loops.
    cfg.define_knob('conv_inline', [0, 1, 2, 3])
    if cfg['conv_inline'].val == 1:
        s[conv].compute_at(s[last], ow)
    if cfg['conv_inline'].val == 2:
        s[conv].compute_at(s[last], oh)
    if cfg['conv_inline'].val == 3:
        s[conv].compute_at(s[last], co)

    # mark parallel
    s[last].parallel(co)

    if data_vec.op.name == 'data_vec_undilated':
        _, h, _, _, _, _, _, _ = s[data_vec].op.axis
    else:
        _, h, _, _, _, _ = s[data_vec].op.axis
    s[data_vec].parallel(h)

    if kernel_vec.op.name == 'kernel_vec':
        co, _, _, _, _ = s[kernel_vec].op.axis
        if autotvm.GLOBAL_SCOPE.in_tuning:
            # kernel packing will be pre-computed during compilation, so we skip
            # this part to make tuning records correct
            s[kernel_vec].pragma(co, 'debug_skip_region')
        else:
            s[kernel_vec].parallel(co)

    return s