1# Licensed to the Apache Software Foundation (ASF) under one 2# or more contributor license agreements. See the NOTICE file 3# distributed with this work for additional information 4# regarding copyright ownership. The ASF licenses this file 5# to you under the Apache License, Version 2.0 (the 6# "License"); you may not use this file except in compliance 7# with the License. You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, 12# software distributed under the License is distributed on an 13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14# KIND, either express or implied. See the License for the 15# specific language governing permissions and limitations 16# under the License. 17# pylint: disable=invalid-name, too-many-locals, too-many-arguments 18# pylint: disable=unused-argument, redefined-builtin 19"""Bitserial Conv2D operators""" 20from __future__ import absolute_import as _abs 21import tvm 22from tvm import autotvm 23from .pad import pad 24from .util import get_pad_tuple 25from .bitserial_util import bitpack, binary_op_multiplier 26from ..util import get_const_tuple 27 28@tvm.target.generic_func 29def bitserial_conv2d_nchw(data, kernel, stride, padding, activation_bits, weight_bits, 30 pack_dtype='uint32', out_dtype='int16', unipolar=True): 31 """Bitserial Conv2D operator. 32 33 Parameters 34 ---------- 35 input : tvm.Tensor 36 4-D with shape [batch, in_channel, in_height, in_width] 37 38 filter : tvm.Tensor 39 4-D with shape [num_filter, in_channel, filter_height, filter_width] 40 41 stride : int or a list/tuple of two ints 42 stride size, or [stride_height, stride_width] 43 44 padding : int or a list/tuple of two or four ints 45 padding size, [pad_height, pad_width], [pad_top, pad_left, pad_down, pad_right] 46 47 activation_bits: int 48 number of bits used for activations/input elements 49 50 weight_bits: int 51 number of bits used for weight elements 52 53 out_dtype: str 54 return type of convolution 55 56 pack_dtype: str 57 bit packing type 58 59 unipolar: bool 60 if binarization style is in unipolar 1/0 format, instead of bipolar -1/+1 format 61 62 Returns 63 ------- 64 output : tvm.Tensor 65 4-D with shape [batch, out_channel, out_height, out_width] 66 """ 67 assert isinstance(stride, int) or len(stride) == 2 68 Input_q = bitpack(data, activation_bits, pack_axis=1, bit_axis=2, pack_type=pack_dtype) 69 if len(filter.shape) == 4: 70 Filter_q = bitpack(filter, weight_bits, pack_axis=1, bit_axis=4, pack_type=pack_dtype) 71 else: 72 Filter_q = filter 73 batch, in_channel, activation_bits, in_height, in_width = Input_q.shape 74 num_filter, _, kernel_h, kernel_w, weight_bits = Filter_q.shape 75 76 if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2): 77 TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, kernel) 78 else: 79 TPAD, LPAD, DPAD, RPAD = padding 80 pad_before = [0, 0, 0, TPAD, LPAD] 81 pad_after = [0, 0, 0, DPAD, RPAD] 82 83 PadInput_q = pad(Input_q, pad_before, pad_after, name="pad_temp") 84 # compute the output shape 85 if isinstance(stride, int): 86 stride_h = stride_w = stride 87 else: 88 stride_h, stride_w = stride 89 out_channel = num_filter 90 out_height = (in_height - kernel_h + TPAD + DPAD) // stride_h + 1 91 out_width = (in_width - kernel_w + LPAD + RPAD) // stride_w + 1 92 93 rc = tvm.reduce_axis((0, in_channel), name='rc') 94 ry = tvm.reduce_axis((0, kernel_h), name='ry') 95 rx = tvm.reduce_axis((0, kernel_w), name='rx') 96 b1 = tvm.reduce_axis((0, activation_bits), name='b1') 97 b2 = tvm.reduce_axis((0, weight_bits), name='b2') 98 99 if unipolar: 100 def _conv(nn, ff, yy, xx): 101 b1b2 = (b1+b2).astype(out_dtype) 102 return tvm.sum( 103 ((tvm.popcount(PadInput_q[nn, rc, b1, yy * stride_h + ry, xx * stride_w + rx] & 104 Filter_q[ff, rc, ry, rx, b2]) - 105 tvm.popcount(PadInput_q[nn, rc, b1, yy * stride_h + ry, xx * stride_w + rx] & 106 ~Filter_q[ff, rc, ry, rx, b2])) 107 << (b1b2)).astype(out_dtype), 108 axis=[rc, ry, rx, b2, b1]).astype(out_dtype) 109 else: 110 def _conv(nn, ff, yy, xx): 111 b1b2 = (b1+b2).astype(out_dtype) 112 return tvm.sum((tvm.popcount( 113 PadInput_q[nn, rc, b1, yy * stride_h + ry, xx * stride_w + rx] & 114 Filter_q[ff, rc, ry, rx, b2])<< (b1b2)).astype(out_dtype), 115 axis=[rc, ry, rx, b2, b1]).astype(out_dtype) 116 117 return tvm.compute((batch, out_channel, out_height, out_width), _conv, 118 name="Conv2dOutput", tag="bitserial_conv2d_nchw") 119 120@tvm.target.generic_func 121def bitserial_conv2d_nhwc(data, kernel, stride, padding, activation_bits, weight_bits, 122 pack_dtype='uint32', out_dtype='int16', unipolar=True): 123 """Bitserial Conv2D operator. 124 125 Parameters 126 ---------- 127 input : tvm.Tensor 128 4-D with shape [batch, in_height, in_width, in_channel] 129 130 filter : tvm.Tensor 131 4-D with shape [filter_height, filter_width, in_channel, num_filter] 132 133 stride : int or a list/tuple of two ints 134 stride size, or [stride_height, stride_width] 135 136 padding : int or a list/tuple of two or four ints 137 padding size, [pad_height, pad_width], [pad_top, pad_left, pad_down, pad_right] 138 139 activation_bits: int 140 number of bits used for activations/input elements 141 142 weight_bits: int 143 number of bits used for weight elements 144 145 out_dtype: str 146 return type of convolution 147 148 pack_dtype: str 149 bit packing type 150 151 unipolar: bool 152 if binarization style is in unipolar 1/0 format, instead of bipolar -1/+1 format 153 154 Returns 155 ------- 156 output : tvm.Tensor 157 4-D with shape [batch, out_height, out_width, out_channel] 158 """ 159 assert isinstance(stride, int) or len(stride) == 2 160 Input_q = bitpack(data, activation_bits, pack_axis=3, bit_axis=4, pack_type=pack_dtype) 161 if len(kernel.shape) == 4: 162 Filter_q = bitpack(kernel, weight_bits, pack_axis=2, bit_axis=4, pack_type=pack_dtype) 163 kernel_h, kernel_w, _, num_filter, _ = get_const_tuple(Filter_q.shape) 164 else: 165 Filter_q = kernel 166 kernel_h, kernel_w, _, _, num_filter = get_const_tuple(Filter_q.shape) 167 batch, in_height, in_width, in_channel_q, _ = get_const_tuple(Input_q.shape) 168 169 if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2): 170 TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, kernel) 171 else: 172 TPAD, LPAD, DPAD, RPAD = padding 173 pad_before = [0, TPAD, LPAD, 0, 0] 174 pad_after = [0, DPAD, RPAD, 0, 0] 175 176 # compute the output shape 177 if isinstance(stride, int): 178 stride_h = stride_w = stride 179 else: 180 stride_h, stride_w = stride 181 out_channel = num_filter 182 out_height = (in_height - kernel_h + TPAD + DPAD) // stride_h + 1 183 out_width = (in_width - kernel_w + LPAD + RPAD) // stride_w + 1 184 PadInput_q = pad(Input_q, pad_before, pad_after, name="PaddedInput") 185 186 rc = tvm.reduce_axis((0, in_channel_q), name='rc') 187 ry = tvm.reduce_axis((0, kernel_h), name='ry') 188 rx = tvm.reduce_axis((0, kernel_w), name='rx') 189 b1 = tvm.reduce_axis((0, activation_bits), name='b1') 190 b2 = tvm.reduce_axis((0, weight_bits), name='b2') 191 192 if unipolar: 193 def _conv(nn, yy, xx, ff): 194 b1b2 = (b1+b2).astype(out_dtype) 195 return tvm.sum( 196 ((tvm.popcount(PadInput_q[nn, yy * stride_h + ry, xx * stride_w + rx, rc, b1] & 197 Filter_q[ry, rx, rc, ff, b2]) - 198 tvm.popcount(PadInput_q[nn, yy * stride_h + ry, xx * stride_w + rx, rc, b1] & 199 ~Filter_q[ry, rx, rc, ff, b2])) 200 << b1b2).astype(out_dtype), 201 axis=[rc, ry, rx, b2, b1]) 202 203 else: 204 def _conv(nn, yy, xx, ff): 205 b1b2 = (b1+b2).astype(out_dtype) 206 return tvm.sum((tvm.popcount( 207 PadInput_q[nn, yy * stride_h + ry, xx * stride_w + rx, rc, b1] & 208 Filter_q[ry, rx, rc, ff, b2]) << b1b2).astype(out_dtype), 209 axis=[rc, ry, rx, b2, b1]) 210 211 conv = tvm.compute((batch, out_height, out_width, out_channel), _conv, 212 name="Conv2dOutput", tag="bitserial_conv2d_nhwc") 213 214 return conv 215 216@autotvm.register_topi_compute(bitserial_conv2d_nchw, ['cpu', 'arm_cpu'], 'direct') 217def spatial_pack_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bits, 218 pack_dtype='uint32', out_dtype='int16', unipolar=True): 219 """ Compute convolution with pack on spatial axes. """ 220 assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1" 221 data_q = bitpack(data, in_bits, pack_axis=1, bit_axis=0, pack_type=pack_dtype) 222 # Check if kernel is already bitpacked 223 if len(kernel.shape) == 4: 224 kernel_q = bitpack(kernel, weight_bits, pack_axis=1, bit_axis=0, pack_type=pack_dtype) 225 KB, CO, _, KH, KW = get_const_tuple(kernel_q.shape) 226 else: 227 kernel_vec = kernel 228 OCO, _, KH, KW, KB, VC = get_const_tuple(kernel_vec.shape) 229 CO = OCO * VC 230 231 IB, N, CI, H, W = get_const_tuple(data_q.shape) 232 KB, CO, _, KH, KW = get_const_tuple(kernel_q.shape) 233 234 if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2): 235 TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, kernel) 236 else: 237 TPAD, LPAD, DPAD, RPAD = padding 238 pad_before = [0, 0, 0, TPAD, LPAD] 239 pad_after = [0, 0, 0, DPAD, RPAD] 240 241 if isinstance(stride, (tuple, list)): 242 HSTR, WSTR = stride 243 else: 244 HSTR, WSTR = stride, stride 245 HCAT, WCAT = KH-1, KW-1 246 247 TH = H + TPAD + DPAD 248 TW = W + LPAD + RPAD 249 OH = (H + TPAD + DPAD - KH) // HSTR + 1 250 OW = (W + LPAD + RPAD - KW) // WSTR + 1 251 252 # ==================== define configuration space ==================== 253 n, co, oh, ow = cfg.axis(N), cfg.axis(CO), cfg.axis(OH), cfg.axis(OW) 254 ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW) 255 ib, kb = cfg.reduce_axis(in_bits), cfg.reduce_axis(weight_bits) 256 257 co, vc = cfg.define_split('tile_co', co, num_outputs=2, 258 filter=lambda x: max(x.size[1:]) <= 16) 259 oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2, 260 filter=lambda x: max(x.size[1:]) <= 16) 261 ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2, 262 filter=lambda x: max(x.size[1:]) <= 16) 263 cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll') 264 265 cfg.define_reorder("reorder_0", 266 [n, co, oh, ow, vc, vh, vw, kh, kw, kb, ib, ci], 267 policy='interval_all', interval=(6, 11)) 268 # binary ops 269 cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW * binary_op_multiplier(pack_dtype)) 270 # ==================== 271 272 VC = cfg["tile_co"].size[-1] 273 VH = cfg["tile_oh"].size[-1] 274 VW = cfg["tile_ow"].size[-1] 275 276 dvshape = (1, TH//(VH*HSTR), TW//(VW*WSTR), CI, VH*HSTR+HCAT, VW*WSTR+WCAT, IB) 277 kvshape = (CO//VC, CI, KH, KW, KB, VC) 278 ovshape = (1, CO//VC, OH//VH, OW//VW, VH, VW, VC) 279 oshape = (1, CO, OH, OW) 280 281 if (TPAD != 0 and RPAD != 0): 282 data_pad = pad(data_q, pad_before, pad_after, name="data_pad") 283 else: 284 data_pad = data_q 285 286 data_vec = tvm.compute(dvshape, lambda n, h, w, ci, vh, vw, b: \ 287 data_pad[b][n][ci][h*VH*HSTR+vh][w*VW*WSTR+vw], name='data_vec') 288 289 if len(kernel.shape) == 4: 290 kernel_vec = tvm.compute(kvshape, lambda co, ci, dh, dw, b, vc: \ 291 kernel_q[b][co*VC+vc][ci][dh][dw], name='kernel_vec') 292 293 ci = tvm.reduce_axis((0, CI), name='ci') 294 dh = tvm.reduce_axis((0, KH), name='dh') 295 dw = tvm.reduce_axis((0, KW), name='dw') 296 b1 = tvm.reduce_axis((0, IB), name='ib') 297 b2 = tvm.reduce_axis((0, KB), name='kb') 298 299 def _conv(n, co, h, w, vh, vw, vc): 300 b1b2 = (b1+b2).astype(out_dtype) 301 if unipolar: 302 return tvm.sum((tvm.popcount( 303 data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw, b1].astype(out_dtype) & 304 kernel_vec[co, ci, dh, dw, b2, vc].astype(out_dtype)) - 305 tvm.popcount( 306 data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw, b1].astype(out_dtype) 307 & ~kernel_vec[co, ci, dh, dw, b2, vc]).astype(out_dtype)) << b1b2, 308 axis=[ci, dh, dw, b1, b2]) 309 310 return tvm.sum((tvm.popcount( 311 data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw, b1] & 312 kernel_vec[co, ci, dh, dw, b2, vc])).astype(out_dtype) << b1b2, 313 axis=[ci, dh, dw, b1, b2]) 314 315 conv = tvm.compute(ovshape, _conv, name='conv_out') 316 idxd = tvm.indexdiv 317 idxm = tvm.indexmod 318 319 return tvm.compute( 320 oshape, lambda n, co, h, w: 321 conv[n, 322 idxd(co, VC), idxd(h, VH), idxd(w, VW), 323 idxm(h, VH), idxm(w, VW), idxm(co, VC)], 324 name='conv_vec', tag='spatial_bitserial_conv_nchw') 325 326@autotvm.register_topi_compute(bitserial_conv2d_nhwc, 'cpu', 'direct') 327def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits, 328 pack_dtype='uint32', out_dtype='int16', unipolar=True): 329 """ Compute convolution with pack on spatial axes. """ 330 assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1" 331 data_q = bitpack(data, in_bits, pack_axis=3, bit_axis=4, pack_type=pack_dtype) 332 pack_kernel = len(kernel.shape) == 4 333 334 if pack_kernel: 335 kernel_q = bitpack(kernel, weight_bits, pack_axis=2, bit_axis=4, pack_type=pack_dtype) 336 else: 337 kernel_q = kernel 338 339 KH, KW, _, CO, KB = get_const_tuple(kernel_q.shape) 340 N, H, W, CI, IB = get_const_tuple(data_q.shape) 341 342 if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2): 343 TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, kernel) 344 else: 345 TPAD, LPAD, DPAD, RPAD = padding 346 pad_before = [0, TPAD, LPAD, 0, 0] 347 pad_after = [0, DPAD, RPAD, 0, 0] 348 349 if isinstance(stride, (tuple, list)): 350 HSTR, WSTR = stride 351 else: 352 HSTR, WSTR = stride, stride 353 HCAT, WCAT = KH-1, KW-1 354 355 PAD_H = H + (TPAD + DPAD) 356 PAD_W = W + (LPAD + RPAD) 357 OH = (PAD_H - KH) // HSTR + 1 358 OW = (PAD_W - KW) // WSTR + 1 359 oshape = (1, OH, OW, CO) 360 361 # ==================== define configuration space ==================== 362 n, oh, ow, co = cfg.axis(N), cfg.axis(OH), cfg.axis(OW), cfg.axis(CO) 363 ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW) 364 ib, kb = cfg.reduce_axis(in_bits), cfg.reduce_axis(weight_bits) 365 366 co, vc = cfg.define_split('tile_co', co, num_outputs=2, 367 filter=lambda x: max(x.size[1:]) <= 16) 368 oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2, 369 filter=lambda x: max(x.size[1:]) <= 16) 370 ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2, 371 filter=lambda x: max(x.size[1:]) <= 16) 372 cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll') 373 cfg.define_reorder("reorder_0", 374 [n, oh, ow, co, vh, vw, kh, kw, kb, ib, vc, ci], 375 policy='interval_all', interval=(3, 7)) 376 # binary ops 377 cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW * binary_op_multiplier(pack_dtype)) 378 # ==================== 379 380 VC = cfg["tile_co"].size[-1] 381 VH = cfg["tile_oh"].size[-1] 382 VW = cfg["tile_ow"].size[-1] 383 384 dvshape = (1, PAD_H//(VH*HSTR), PAD_W//(VW*WSTR), VH*HSTR+HCAT, VW*WSTR+WCAT, CI, IB) 385 kvshape = (CO, KH, KW, CI, VC, KB) 386 ovshape = (1, OH, OW, CO, VH, VW, VC) 387 oshape = (1, OH, OW, CO) 388 389 if (DPAD != 0 and RPAD != 0): 390 data_pad = pad(data_q, pad_before, pad_after, name="data_pad") 391 else: 392 data_pad = data_q 393 394 data_vec = tvm.compute(dvshape, lambda n, h, w, vh, vw, ci, b: \ 395 data_pad[n][h*VH*HSTR+vh][w*VW*WSTR+vw][ci][b], name='data_vec') 396 397 kernel_vec = tvm.compute(kvshape, lambda co, dh, dw, ci, vc, b: \ 398 kernel_q[dh][dw][ci][co*VC+vc][b], name='kernel_vec') 399 400 ci = tvm.reduce_axis((0, CI), name='ci') 401 dh = tvm.reduce_axis((0, KH), name='dh') 402 dw = tvm.reduce_axis((0, KW), name='dw') 403 b1 = tvm.reduce_axis((0, IB), name='ib') 404 b2 = tvm.reduce_axis((0, KB), name='kb') 405 406 def _conv(n, h, w, co, vh, vw, vc): 407 b1b2 = (b1+b2).astype(out_dtype) 408 if unipolar: 409 return tvm.sum( 410 ((tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1] & 411 kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype) - 412 tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1]& 413 ~kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype)) << b1b2), 414 axis=[dh, dw, ci, b1, b2]) 415 416 return tvm.sum(tvm.popcount( 417 data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1] & 418 kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype) << b1b2, 419 axis=[dh, dw, ci, b1, b2]) 420 421 conv = tvm.compute(ovshape, _conv, name='conv') 422 423 idxd = tvm.indexdiv 424 idxm = tvm.indexmod 425 return tvm.compute( 426 oshape, lambda n, h, w, co: 427 conv[n, 428 idxd(h, VH), idxd(w, VW), idxd(co, VC), 429 idxm(h, VH), idxm(w, VW), idxm(co, VC)], 430 name='output_unpack', tag='spatial_bitserial_conv_nhwc') 431 432@tvm.target.generic_func 433def bitserial_conv2d_legalize(attrs, inputs, types): 434 """Legalizes Bitserial Conv2D op. 435 436 Parameters 437 ---------- 438 attrs : tvm.attrs.Attrs 439 Attributes of current convolution 440 inputs : list of tvm.relay.Expr 441 The args of the Relay expr to be legalized 442 types : list of types 443 List of input and output types 444 445 Returns 446 ------- 447 result : tvm.relay.Expr 448 The legalized expr 449 """ 450 # not to change by default 451 return None 452