# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,unused-argument
"""1x1 Conv2D schedule for Intel CPU"""
from __future__ import absolute_import as _abs
import tvm
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity

from ..nn.pad import pad
from ..nn.util import infer_pad, get_pad_tuple
from ..generic import conv2d as conv2d_generic
from ..util import get_const_tuple, simplify
from .tensor_intrin import dot_16x1x16_uint8_int8_int32
from .util import get_fp32_len


def _fallback_schedule(cfg, wkl):
    simd_width = get_fp32_len()
    HPAD, WPAD = wkl.hpad, wkl.wpad
    HSTR, WSTR = wkl.hstride, wkl.wstride
    out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
    out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1

    oc_bn = 1
    for bn in range(simd_width, 0, -1):
        if wkl.out_filter % bn == 0:
            oc_bn = bn
            break

    ic_bn = 1
    for bn in range(oc_bn, 0, -1):
        if wkl.in_filter % bn == 0:
            ic_bn = bn
            break

    for ow_factor in range(out_width, 0, -1):
        if out_width % ow_factor == 0:
            for oh_factor in range(out_height, 0, -1):
                if out_height % oh_factor == 0 and ow_factor * oh_factor < 32:
                    cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
                    cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
                    cfg["tile_oh"] = OtherOptionEntity(oh_factor)
                    cfg["tile_ow"] = SplitEntity([out_width // ow_factor, ow_factor])
                    return
    raise ValueError("cannot decide default schedule for workload: {}".format(wkl))


def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, last):
    # fetch schedule
    ic_bn, oc_bn, oh_factor, ow_factor = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1],
                                          cfg["tile_oh"].val, cfg["tile_ow"].size[-1])

    # no stride and padding info here
    padding = infer_pad(data, data_pad)
    HPAD, WPAD = padding
    DOPAD = (HPAD != 0 or WPAD != 0)

    A, W = data, kernel_vec
    A0, A1 = data_pad, data_vec
    # schedule data
    if DOPAD:
        s[A0].compute_inline()
    batch, ic_chunk, ih, ic_block, iw = s[A1].op.axis
    parallel_axis = s[A1].fuse(batch, ic_chunk, ih)
    s[A1].parallel(parallel_axis)

    # schedule kernel pack
    oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[W].op.axis
    s[W].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block)
    if oc_bn > 1:
        s[W].vectorize(oc_block)
    parallel_axis = s[W].fuse(oc_chunk, oh)
    s[W].parallel(parallel_axis)

    C, O0, O = conv_out, output, last
    CC = s.cache_write(C, 'global')

    batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
    oh_outer, oh_inner = s[C].split(oh, factor=oh_factor)
    s[C].vectorize(oc_block)

    s[CC].compute_at(s[C], oh_outer)
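
    # Register-tile the cached-write stage below: split the input-channel
    # reduction by ic_bn and the spatial axes by the autotvm-selected oh/ow
    # factors, vectorize the innermost output-channel block, and unroll the
    # inner spatial loops.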
    _, oc_chunk, oh, ow, oc_block = s[CC].op.axis
    ic, _, _ = s[CC].op.reduce_axis

    ic_chunk, ic_block = s[CC].split(ic, factor=ic_bn)

    oh_outer, oh_inner = s[CC].split(oh, factor=oh_factor)
    ow_outer, ow_inner = s[CC].split(ow, factor=ow_factor)

    s[CC].reorder(oc_chunk, oh_outer, ow_outer, ic_chunk, ic_block, oh_inner, ow_inner, oc_block)
    s[CC].vectorize(oc_block)

    s[CC].unroll(ow_inner)
    s[CC].unroll(oh_inner)

    if O0 != O:
        s[O0].compute_inline()
    batch, oc, oh, ow = s[O].op.axis

    oc_chunk, oc_block = s[O].split(oc, factor=oc_bn)
    oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
    ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
    s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)

    parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
    s[C].compute_at(s[O], parallel_axis)
    s[O].vectorize(oc_block)

    s[O].parallel(parallel_axis)

    return s


def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
    # fetch schedule
    oh_factor, ow_factor = cfg["tile_oh"].val, cfg["tile_ow"].size[-1]
    _, _, _, _, ic_bn = get_const_tuple(data.shape)

    # schedule data
    A = data
    if isinstance(s[A].op, tvm.tensor.ComputeOp):
        batch, ic_chunk, ih, iw, ic_block = s[A].op.axis
        parallel_axis = s[A].fuse(batch, ic_chunk, ih)
        s[A].parallel(parallel_axis)

    C, O = conv_out, last
    CC = s.cache_write(C, 'global')

    batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
    oh_outer, oh_inner = s[C].split(oh, factor=oh_factor)
    ow_outer, ow_inner = s[C].split(ow, factor=ow_factor)
    s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
    s[C].vectorize(oc_block)

    parallel_axis = s[C].fuse(batch, oc_chunk, oh_outer)
    s[CC].compute_at(s[C], parallel_axis)
    if C == O:
        s[C].parallel(parallel_axis)

    _, oc_chunk, oh, ow, oc_block = s[CC].op.axis
    ic, _, _ = s[CC].op.reduce_axis

    ic_chunk, ic_block = s[CC].split(ic, factor=ic_bn)

    oh_outer, oh_inner = s[CC].split(oh, factor=oh_factor)
    ow_outer, ow_inner = s[CC].split(ow, factor=ow_factor)

    s[CC].reorder(oc_chunk, oh_outer, ow_outer, ic_chunk, ic_block, oh_inner, ow_inner, oc_block)
    s[CC].fuse(oc_chunk, oh_outer)
    s[CC].vectorize(oc_block)

    s[CC].unroll(ow_inner)
    s[CC].unroll(oh_inner)

    if C != O:
        batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
        oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
        ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
        s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)

        parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
        s[C].compute_at(s[O], parallel_axis)
        s[O].vectorize(oc_block)
        s[O].parallel(parallel_axis)

    return s


def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
    return conv2d_generic.schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data, conv_out, last,
                                                           int32_lanes=16,
                                                           intrin=dot_16x1x16_uint8_int8_int32())


def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, out_dtype):
    # additional assertions on the shapes
    assert isinstance(stride, int) or len(stride) == 2
    assert isinstance(dilation, int) or len(dilation) == 2
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation
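
    # Expected layouts: Input is NHWC and Filter is HWOI
    # (kernel_h, kernel_w, num_filter, in_channel).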
    batch, in_height, in_width, in_channel = Input.shape
    kernel_h, kernel_w, num_filter, channel = Filter.shape

    # compute the output shape
    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w))
    out_channel = num_filter
    out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
    out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
    pad_before = [0, pad_top, pad_left, 0]
    pad_after = [0, pad_down, pad_right, 0]
    PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
    # TODO: pad the filter to accommodate the intrinsic

    # pack the Filter so that memory accesses are consecutive for the AVX512
    # intrinsic; this is done in a pre-compute stage
    idxd = tvm.indexdiv
    idxm = tvm.indexmod

    packw_shape = (kernel_h, kernel_w, idxd(num_filter, 16), 16 * idxd(channel, 4), 4)
    PackW = tvm.compute(packw_shape,
                        lambda a, b, c, d, e:
                        Filter[a, b,
                               c*16 + idxm(d, 16),
                               idxd(d, 16) * 4 + e],
                        name="packed_filter")

    rc = tvm.reduce_axis((0, in_channel), name='rc')
    ry = tvm.reduce_axis((0, kernel_h), name='ry')
    rx = tvm.reduce_axis((0, kernel_w), name='rx')
    Output = tvm.compute(
        (batch, out_height, out_width, out_channel),
        lambda nn, yy, xx, ff: tvm.sum(
            PaddedInput[nn, yy * stride_h + ry * dilation_h,
                        xx * stride_w + rx * dilation_w, rc].astype(out_dtype) *
            PackW[ry, rx, idxd(ff, 16),
                  idxd(rc, 4) * 16 + idxm(ff, 16),
                  idxm(rc, 4)].astype(out_dtype), axis=[ry, rx, rc]),
        name="Conv2d_1x1_Output_int8", tag="conv2d_nhwc_pack_int8")
    return Output


def _schedule_conv_nhwc_pack_int8(s, cfg, data, conv_out, last):
    """
    Defines the schedule for the int8 NHWC layout. A 1x1 convolution in
    NHWC layout is effectively a matrix multiply, so the weights are
    packed ahead of time to keep memory accesses friendly to the int8
    dot-product intrinsic.
    """
    # FIXME - https://github.com/apache/incubator-tvm/issues/3598
    # pylint: disable=unreachable
    return s

    int32_lanes = 16

    # assertions to reject the unhandled cases
    _, _, _, ic_num = get_const_tuple(data.shape)
    _, _, _, oc_num = get_const_tuple(conv_out.shape)
    assert ic_num % 4 == 0
    assert oc_num % 16 == 0

    ic_factor, oc_factor = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
    # schedule data
    A = data
    if isinstance(s[A].op, tvm.tensor.ComputeOp):
        batch, ih, iw, ic = s[A].op.axis
        d_ic_chunk, d_ic_block = s[A].split(ic, factor=4)
        s[A].vectorize(d_ic_block)

    C, O = conv_out, last

    batch, oh, ow, oc = s[C].op.axis
    kh, kw, ic = s[C].op.reduce_axis
    # match the x86 intrinsic
    ic_outer, ic_inner = s[C].split(ic, factor=4)
    oc_outer, oc_inner = s[C].split(oc, factor=int32_lanes)

    ic_f_outer, ic_s_outer = s[C].split(ic_outer, factor=ic_factor)
    s[C].reorder(oc_outer, oh, ow, ic_f_outer, ic_s_outer, kh, kw, oc_inner, ic_inner)

    pc = dot_16x1x16_uint8_int8_int32()
    s[C].tensorize(oc_inner, pc)

    if C != O:
        batch, last_oh, last_ow, last_oc = s[O].op.axis
        oc_chunk, oc_block = s[O].split(last_oc, factor=16)
        # splitting oh/ow here showed no perf improvement
        s[O].vectorize(oc_block)

    return s
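

# ---------------------------------------------------------------------------
# Illustrative usage sketch: drives _declaration_conv_nhwc_pack directly with
# placeholder tensors and a default schedule to show the expected tensor
# layouts. The shapes and dtypes below are hypothetical; in practice these
# helpers are invoked through topi/autotvm dispatch rather than by hand. Run
# as a module (e.g. `python -m topi.x86.conv2d_avx_1x1`) so the relative
# imports resolve.
if __name__ == "__main__":
    # NHWC uint8 activations and HWOI int8 weights for a 1x1 kernel
    demo_data = tvm.placeholder((1, 56, 56, 64), name="demo_data", dtype="uint8")
    demo_kernel = tvm.placeholder((1, 1, 128, 64), name="demo_kernel", dtype="int8")
    demo_out = _declaration_conv_nhwc_pack(None, demo_data, demo_kernel,
                                           stride=1, padding=0, dilation=1,
                                           out_dtype="int32")
    demo_s = tvm.create_schedule(demo_out.op)
    print(tvm.lower(demo_s, [demo_data, demo_kernel, demo_out], simple_mode=True))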