# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-locals, too-many-arguments
# pylint: disable=unused-argument, redefined-builtin
"""Bitserial Conv2D operators"""
from __future__ import absolute_import as _abs
import tvm
from tvm import autotvm
from .pad import pad
from .util import get_pad_tuple
from .bitserial_util import bitpack, binary_op_multiplier
from ..util import get_const_tuple

@tvm.target.generic_func
def bitserial_conv2d_nchw(data, kernel, stride, padding, activation_bits, weight_bits,
                          pack_dtype='uint32', out_dtype='int16', unipolar=True):
    """Bitserial Conv2D operator.

    Parameters
    ----------
    data : tvm.Tensor
        4-D with shape [batch, in_channel, in_height, in_width]

    kernel : tvm.Tensor
        4-D with shape [num_filter, in_channel, filter_height, filter_width]

    stride : int or a list/tuple of two ints
        stride size, or [stride_height, stride_width]

    padding : int or a list/tuple of two or four ints
        padding size, or [pad_height, pad_width], or
        [pad_top, pad_left, pad_down, pad_right]

    activation_bits : int
        number of bits used for activations/input elements

    weight_bits : int
        number of bits used for weight elements

    pack_dtype : str
        bit packing type

    out_dtype : str
        return type of convolution

    unipolar : bool
        whether to use unipolar (1/0) binarization instead of bipolar (-1/+1)

    Returns
    -------
    output : tvm.Tensor
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    assert isinstance(stride, int) or len(stride) == 2
    Input_q = bitpack(data, activation_bits, pack_axis=1, bit_axis=2, pack_type=pack_dtype)
    if len(kernel.shape) == 4:
        Filter_q = bitpack(kernel, weight_bits, pack_axis=1, bit_axis=4, pack_type=pack_dtype)
    else:
        Filter_q = kernel
    batch, in_channel, activation_bits, in_height, in_width = Input_q.shape
    num_filter, _, kernel_h, kernel_w, weight_bits = Filter_q.shape

    if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2):
        TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, (kernel_h, kernel_w))
    else:
        TPAD, LPAD, DPAD, RPAD = padding
    pad_before = [0, 0, 0, TPAD, LPAD]
    pad_after = [0, 0, 0, DPAD, RPAD]

    PadInput_q = pad(Input_q, pad_before, pad_after, name="pad_temp")
    # compute the output shape
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride
    out_channel = num_filter
    out_height = (in_height - kernel_h + TPAD + DPAD) // stride_h + 1
    out_width = (in_width - kernel_w + LPAD + RPAD) // stride_w + 1

    rc = tvm.reduce_axis((0, in_channel), name='rc')
    ry = tvm.reduce_axis((0, kernel_h), name='ry')
    rx = tvm.reduce_axis((0, kernel_w), name='rx')
    b1 = tvm.reduce_axis((0, activation_bits), name='b1')
    b2 = tvm.reduce_axis((0, weight_bits), name='b2')

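    # Bitserial accumulation: for activation bitplane b1 and weight bitplane
    # b2, the partial product is popcount(x_b1 & w_b2), weighted by 2^(b1+b2)
    # via the left shift. In the unipolar case a zero weight bit contributes
    # negatively, hence the subtracted popcount(x & ~w) term below.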
    if unipolar:
        def _conv(nn, ff, yy, xx):
            b1b2 = (b1+b2).astype(out_dtype)
            return tvm.sum(
                ((tvm.popcount(PadInput_q[nn, rc, b1, yy * stride_h + ry, xx * stride_w + rx] &
                               Filter_q[ff, rc, ry, rx, b2]) -
                  tvm.popcount(PadInput_q[nn, rc, b1, yy * stride_h + ry, xx * stride_w + rx] &
                               ~Filter_q[ff, rc, ry, rx, b2]))
                 << b1b2).astype(out_dtype),
                axis=[rc, ry, rx, b2, b1]).astype(out_dtype)
    else:
        def _conv(nn, ff, yy, xx):
            b1b2 = (b1+b2).astype(out_dtype)
            return tvm.sum((tvm.popcount(
                PadInput_q[nn, rc, b1, yy * stride_h + ry, xx * stride_w + rx] &
                Filter_q[ff, rc, ry, rx, b2]) << b1b2).astype(out_dtype),
                           axis=[rc, ry, rx, b2, b1]).astype(out_dtype)

    return tvm.compute((batch, out_channel, out_height, out_width), _conv,
                       name="Conv2dOutput", tag="bitserial_conv2d_nchw")
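
# A minimal usage sketch (the shapes, bit widths, and llvm target below are
# illustrative assumptions, not part of this module):
#
#   data = tvm.placeholder((1, 64, 56, 56), dtype='int16', name='data')
#   kernel = tvm.placeholder((64, 64, 3, 3), dtype='int16', name='kernel')
#   out = bitserial_conv2d_nchw(data, kernel, stride=1, padding=1,
#                               activation_bits=2, weight_bits=1)
#   s = tvm.create_schedule(out.op)
#   func = tvm.build(s, [data, kernel, out], target='llvm')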

@tvm.target.generic_func
def bitserial_conv2d_nhwc(data, kernel, stride, padding, activation_bits, weight_bits,
                          pack_dtype='uint32', out_dtype='int16', unipolar=True):
    """Bitserial Conv2D operator.

    Parameters
    ----------
    data : tvm.Tensor
        4-D with shape [batch, in_height, in_width, in_channel]

    kernel : tvm.Tensor
        4-D with shape [filter_height, filter_width, in_channel, num_filter]

    stride : int or a list/tuple of two ints
        stride size, or [stride_height, stride_width]

    padding : int or a list/tuple of two or four ints
        padding size, or [pad_height, pad_width], or
        [pad_top, pad_left, pad_down, pad_right]

    activation_bits : int
        number of bits used for activations/input elements

    weight_bits : int
        number of bits used for weight elements

    pack_dtype : str
        bit packing type

    out_dtype : str
        return type of convolution

    unipolar : bool
        whether to use unipolar (1/0) binarization instead of bipolar (-1/+1)

    Returns
    -------
    output : tvm.Tensor
        4-D with shape [batch, out_height, out_width, out_channel]
    """
    assert isinstance(stride, int) or len(stride) == 2
    Input_q = bitpack(data, activation_bits, pack_axis=3, bit_axis=4, pack_type=pack_dtype)
    if len(kernel.shape) == 4:
        Filter_q = bitpack(kernel, weight_bits, pack_axis=2, bit_axis=4, pack_type=pack_dtype)
        kernel_h, kernel_w, _, num_filter, _ = get_const_tuple(Filter_q.shape)
    else:
        Filter_q = kernel
        kernel_h, kernel_w, _, _, num_filter = get_const_tuple(Filter_q.shape)
    batch, in_height, in_width, in_channel_q, _ = get_const_tuple(Input_q.shape)

    if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2):
        TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, (kernel_h, kernel_w))
    else:
        TPAD, LPAD, DPAD, RPAD = padding
    pad_before = [0, TPAD, LPAD, 0, 0]
    pad_after = [0, DPAD, RPAD, 0, 0]

    # compute the output shape
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride
    out_channel = num_filter
    out_height = (in_height - kernel_h + TPAD + DPAD) // stride_h + 1
    out_width = (in_width - kernel_w + LPAD + RPAD) // stride_w + 1
    PadInput_q = pad(Input_q, pad_before, pad_after, name="PaddedInput")

    rc = tvm.reduce_axis((0, in_channel_q), name='rc')
    ry = tvm.reduce_axis((0, kernel_h), name='ry')
    rx = tvm.reduce_axis((0, kernel_w), name='rx')
    b1 = tvm.reduce_axis((0, activation_bits), name='b1')
    b2 = tvm.reduce_axis((0, weight_bits), name='b2')

    if unipolar:
        def _conv(nn, yy, xx, ff):
            b1b2 = (b1+b2).astype(out_dtype)
            return tvm.sum(
                ((tvm.popcount(PadInput_q[nn, yy * stride_h + ry, xx * stride_w + rx, rc, b1] &
                               Filter_q[ry, rx, rc, ff, b2]) -
                  tvm.popcount(PadInput_q[nn, yy * stride_h + ry, xx * stride_w + rx, rc, b1] &
                               ~Filter_q[ry, rx, rc, ff, b2]))
                 << b1b2).astype(out_dtype),
                axis=[rc, ry, rx, b2, b1])

    else:
        def _conv(nn, yy, xx, ff):
            b1b2 = (b1+b2).astype(out_dtype)
            return tvm.sum((tvm.popcount(
                PadInput_q[nn, yy * stride_h + ry, xx * stride_w + rx, rc, b1] &
                Filter_q[ry, rx, rc, ff, b2]) << b1b2).astype(out_dtype),
                           axis=[rc, ry, rx, b2, b1])

    conv = tvm.compute((batch, out_height, out_width, out_channel), _conv,
                       name="Conv2dOutput", tag="bitserial_conv2d_nhwc")

    return conv

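# The autotvm templates below override the generic functions above for CPU
# targets: the output is "spatially packed" into fixed-size tiles so the
# innermost loops can be vectorized and unrolled by the tuned schedule.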
@autotvm.register_topi_compute(bitserial_conv2d_nchw, ['cpu', 'arm_cpu'], 'direct')
def spatial_pack_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bits,
                      pack_dtype='uint32', out_dtype='int16', unipolar=True):
    """Compute convolution with pack on spatial axes."""
    assert data.shape[0].value == 1, "spatial pack convolution only supports batch size=1"
    data_q = bitpack(data, in_bits, pack_axis=1, bit_axis=0, pack_type=pack_dtype)
    # Check if kernel is already bitpacked
    if len(kernel.shape) == 4:
        kernel_q = bitpack(kernel, weight_bits, pack_axis=1, bit_axis=0, pack_type=pack_dtype)
        KB, CO, _, KH, KW = get_const_tuple(kernel_q.shape)
    else:
        kernel_vec = kernel
        OCO, _, KH, KW, KB, VC = get_const_tuple(kernel_vec.shape)
        CO = OCO * VC

    IB, N, CI, H, W = get_const_tuple(data_q.shape)

    if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2):
        TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, (KH, KW))
    else:
        TPAD, LPAD, DPAD, RPAD = padding
    pad_before = [0, 0, 0, TPAD, LPAD]
    pad_after = [0, 0, 0, DPAD, RPAD]

    if isinstance(stride, (tuple, list)):
        HSTR, WSTR = stride
    else:
        HSTR, WSTR = stride, stride
    HCAT, WCAT = KH-1, KW-1

    TH = H + TPAD + DPAD
    TW = W + LPAD + RPAD
    OH = (H + TPAD + DPAD - KH) // HSTR + 1
    OW = (W + LPAD + RPAD - KW) // WSTR + 1

    # ==================== define configuration space ====================
    n, co, oh, ow = cfg.axis(N), cfg.axis(CO), cfg.axis(OH), cfg.axis(OW)
    ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
    ib, kb = cfg.reduce_axis(in_bits), cfg.reduce_axis(weight_bits)

    co, vc = cfg.define_split('tile_co', co, num_outputs=2,
                              filter=lambda x: max(x.size[1:]) <= 16)
    oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2,
                              filter=lambda x: max(x.size[1:]) <= 16)
    ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2,
                              filter=lambda x: max(x.size[1:]) <= 16)
    cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll')

    cfg.define_reorder("reorder_0",
                       [n, co, oh, ow, vc, vh, vw, kh, kw, kb, ib, ci],
                       policy='interval_all', interval=(6, 11))
    # binary ops
    cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW * binary_op_multiplier(pack_dtype))
    # ====================

    VC = cfg["tile_co"].size[-1]
    VH = cfg["tile_oh"].size[-1]
    VW = cfg["tile_ow"].size[-1]

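    # Tiled layouts: each data tile holds the input region needed for a
    # (VH, VW) block of outputs (VH*HSTR+KH-1 rows by VW*WSTR+KW-1 columns),
    # and the kernel is regrouped into VC-wide output-channel vectors, so the
    # innermost loops run over small fixed-size tiles chosen by the tuner.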
    dvshape = (1, TH//(VH*HSTR), TW//(VW*WSTR), CI, VH*HSTR+HCAT, VW*WSTR+WCAT, IB)
    kvshape = (CO//VC, CI, KH, KW, KB, VC)
    ovshape = (1, CO//VC, OH//VH, OW//VW, VH, VW, VC)
    oshape = (1, CO, OH, OW)

    if TPAD != 0 or LPAD != 0 or DPAD != 0 or RPAD != 0:
        data_pad = pad(data_q, pad_before, pad_after, name="data_pad")
    else:
        data_pad = data_q

    data_vec = tvm.compute(dvshape, lambda n, h, w, ci, vh, vw, b: \
        data_pad[b][n][ci][h*VH*HSTR+vh][w*VW*WSTR+vw], name='data_vec')

    if len(kernel.shape) == 4:
        kernel_vec = tvm.compute(kvshape, lambda co, ci, dh, dw, b, vc: \
            kernel_q[b][co*VC+vc][ci][dh][dw], name='kernel_vec')

    ci = tvm.reduce_axis((0, CI), name='ci')
    dh = tvm.reduce_axis((0, KH), name='dh')
    dw = tvm.reduce_axis((0, KW), name='dw')
    b1 = tvm.reduce_axis((0, IB), name='ib')
    b2 = tvm.reduce_axis((0, KB), name='kb')

    def _conv(n, co, h, w, vh, vw, vc):
        b1b2 = (b1+b2).astype(out_dtype)
        if unipolar:
            return tvm.sum((tvm.popcount(
                data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw, b1] &
                kernel_vec[co, ci, dh, dw, b2, vc]).astype(out_dtype) -
                            tvm.popcount(
                                data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw, b1] &
                                ~kernel_vec[co, ci, dh, dw, b2, vc]).astype(out_dtype)) << b1b2,
                           axis=[ci, dh, dw, b1, b2])

        return tvm.sum((tvm.popcount(
            data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw, b1] &
            kernel_vec[co, ci, dh, dw, b2, vc])).astype(out_dtype) << b1b2,
                       axis=[ci, dh, dw, b1, b2])

    conv = tvm.compute(ovshape, _conv, name='conv_out')
    idxd = tvm.indexdiv
    idxm = tvm.indexmod

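    # Unpack the tiled (CO//VC, OH//VH, OW//VW, VH, VW, VC) output back into
    # plain NCHW via index division/modulo.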
    return tvm.compute(
        oshape, lambda n, co, h, w:
        conv[n,
             idxd(co, VC), idxd(h, VH), idxd(w, VW),
             idxm(h, VH), idxm(w, VW), idxm(co, VC)],
        name='conv_vec', tag='spatial_bitserial_conv_nchw')

@autotvm.register_topi_compute(bitserial_conv2d_nhwc, 'cpu', 'direct')
def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
                      pack_dtype='uint32', out_dtype='int16', unipolar=True):
    """Compute convolution with pack on spatial axes."""
    assert data.shape[0].value == 1, "spatial pack convolution only supports batch size=1"
    data_q = bitpack(data, in_bits, pack_axis=3, bit_axis=4, pack_type=pack_dtype)
    pack_kernel = len(kernel.shape) == 4

    if pack_kernel:
        kernel_q = bitpack(kernel, weight_bits, pack_axis=2, bit_axis=4, pack_type=pack_dtype)
    else:
        kernel_q = kernel

    KH, KW, _, CO, KB = get_const_tuple(kernel_q.shape)
    N, H, W, CI, IB = get_const_tuple(data_q.shape)

    if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2):
        TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, (KH, KW))
    else:
        TPAD, LPAD, DPAD, RPAD = padding
    pad_before = [0, TPAD, LPAD, 0, 0]
    pad_after = [0, DPAD, RPAD, 0, 0]

    if isinstance(stride, (tuple, list)):
        HSTR, WSTR = stride
    else:
        HSTR, WSTR = stride, stride
    HCAT, WCAT = KH-1, KW-1

    PAD_H = H + (TPAD + DPAD)
    PAD_W = W + (LPAD + RPAD)
    OH = (PAD_H - KH) // HSTR + 1
    OW = (PAD_W - KW) // WSTR + 1

    # ==================== define configuration space ====================
    n, oh, ow, co = cfg.axis(N), cfg.axis(OH), cfg.axis(OW), cfg.axis(CO)
    ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
    ib, kb = cfg.reduce_axis(in_bits), cfg.reduce_axis(weight_bits)

    co, vc = cfg.define_split('tile_co', co, num_outputs=2,
                              filter=lambda x: max(x.size[1:]) <= 16)
    oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2,
                              filter=lambda x: max(x.size[1:]) <= 16)
    ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2,
                              filter=lambda x: max(x.size[1:]) <= 16)
    cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll')
    cfg.define_reorder("reorder_0",
                       [n, oh, ow, co, vh, vw, kh, kw, kb, ib, vc, ci],
                       policy='interval_all', interval=(3, 7))
    # binary ops
    cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW * binary_op_multiplier(pack_dtype))
    # ====================

    VC = cfg["tile_co"].size[-1]
    VH = cfg["tile_oh"].size[-1]
    VW = cfg["tile_ow"].size[-1]

    dvshape = (1, PAD_H//(VH*HSTR), PAD_W//(VW*WSTR), VH*HSTR+HCAT, VW*WSTR+WCAT, CI, IB)
    kvshape = (CO//VC, KH, KW, CI, VC, KB)
    ovshape = (1, OH//VH, OW//VW, CO//VC, VH, VW, VC)
    oshape = (1, OH, OW, CO)

    if TPAD != 0 or LPAD != 0 or DPAD != 0 or RPAD != 0:
        data_pad = pad(data_q, pad_before, pad_after, name="data_pad")
    else:
        data_pad = data_q

    data_vec = tvm.compute(dvshape, lambda n, h, w, vh, vw, ci, b: \
        data_pad[n][h*VH*HSTR+vh][w*VW*WSTR+vw][ci][b], name='data_vec')

    kernel_vec = tvm.compute(kvshape, lambda co, dh, dw, ci, vc, b: \
        kernel_q[dh][dw][ci][co*VC+vc][b], name='kernel_vec')

    ci = tvm.reduce_axis((0, CI), name='ci')
    dh = tvm.reduce_axis((0, KH), name='dh')
    dw = tvm.reduce_axis((0, KW), name='dw')
    b1 = tvm.reduce_axis((0, IB), name='ib')
    b2 = tvm.reduce_axis((0, KB), name='kb')

    def _conv(n, h, w, co, vh, vw, vc):
        b1b2 = (b1+b2).astype(out_dtype)
        if unipolar:
            return tvm.sum(
                ((tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1] &
                               kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype) -
                  tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1] &
                               ~kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype)) << b1b2),
                axis=[dh, dw, ci, b1, b2])

        return tvm.sum(tvm.popcount(
            data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1] &
            kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype) << b1b2,
                       axis=[dh, dw, ci, b1, b2])

    conv = tvm.compute(ovshape, _conv, name='conv')

    idxd = tvm.indexdiv
    idxm = tvm.indexmod
    return tvm.compute(
        oshape, lambda n, h, w, co:
        conv[n,
             idxd(h, VH), idxd(w, VW), idxd(co, VC),
             idxm(h, VH), idxm(w, VW), idxm(co, VC)],
        name='output_unpack', tag='spatial_bitserial_conv_nhwc')


@tvm.target.generic_func
def bitserial_conv2d_legalize(attrs, inputs, types):
    """Legalizes Bitserial Conv2D op.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr
    """
    # No change by default.
    return None
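
# A target-specific legalization can be registered on this generic function,
# e.g. (a hypothetical sketch; the target name and rewrite are assumptions):
#
#   @bitserial_conv2d_legalize.register('arm_cpu')
#   def _legalize_arm(attrs, inputs, types):
#       # return a rewritten tvm.relay.Expr, or None to keep the op as-is
#       return None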