1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=invalid-name, too-many-locals, too-many-arguments
18"""Utility functions for bitserial operators"""
19import numpy as np
20import tvm
21from topi.transform import concatenate
22from ..util import get_const_int
23
def bitpack(data, bits, pack_axis, bit_axis, pack_type, name="QuantizeInput"):
    """Packs data into format necessary for bitserial computation

    For each of the lowest ``bits`` bit positions, that bit is taken from
    ``data_width`` consecutive elements along ``pack_axis`` and gathered into
    one ``pack_type`` word. The element with the lowest index inside a group
    ends up in the most significant bit of the word.

    Parameters
    ----------
    data : tensor
        Input to pack. Only ``data.shape`` and ``data(*indices)`` are used.
    bits : int
        Number of bit planes to extract from each element (1 to 8).
    pack_axis : int
       index of the axis to pack in data
    bit_axis : int
       index of axis to place bit axis in resulting packed data
    pack_type : str
        Word type to pack into: 'uint8', 'uint16', 'uint32' or 'uint64'.
    name : str
        Name given to the compute stage.

    Returns
    -------
    The output of ``tvm.compute`` when ``bits`` is 1; otherwise the per-bit
    outputs joined by ``concatenate`` along ``bit_axis``.

    Raises
    ------
    ValueError
        If ``pack_type`` is not supported or ``bits`` is outside 1..8.
    AssertionError
        If the extent of ``pack_axis`` is not a multiple of the word size.
    """
    ishape = data.shape
    n = len(ishape)

    # Validate up front so an unsupported type fails with a clear message
    # instead of an unbound data_width further down.
    word_sizes = {'uint8': 8, 'uint16': 16, 'uint32': 32, 'uint64': 64}
    if pack_type not in word_sizes:
        raise ValueError("Unsupported pack_type %r, expected one of %s"
                         % (pack_type, sorted(word_sizes)))
    data_width = word_sizes[pack_type]

    # One mask per extractable bit position; bits beyond the table would
    # otherwise fail with an IndexError inside the compute lambda.
    masks = np.array([0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80])
    if not 1 <= bits <= len(masks):
        raise ValueError("bits must be between 1 and %d, got %r" % (len(masks), bits))

    # Data must be in multiples of the data_width
    assert get_const_int(ishape[pack_axis]) % data_width == 0, "Not a multiple of word size"

    # Output shape: pack axis shrinks by the word size, and a unit bit axis is
    # inserted (one compute output is produced per bit plane).
    shape_vec = list(ishape)
    shape_vec[pack_axis] = (shape_vec[pack_axis] // data_width)
    shape_vec.insert(bit_axis, 1)
    bitserial_oshape = tuple(shape_vec)

    # pack axis shifts if bit axis comes before
    if bit_axis <= pack_axis:
        pack_axis += 1

    def _bitpack(*indices):
        packed_data = [tvm.const(0, pack_type)] * bits
        for k in range(data_width):
            # Translate indices for packed data back to original: drop the
            # bit axis and expand the pack axis to the k-th element of the word.
            idx = [indices[i] * data_width + k if i == pack_axis else indices[i]
                   for i in range(n + 1) if i != bit_axis]

            element = data(*idx)
            for b in range(bits):
                extracted_bit = ((element & tvm.const(masks[b], "int32")) >> b).astype(pack_type)
                packed_data[b] = (packed_data[b] | extracted_bit)
                # Make room for the next element's bit, except after the last one.
                if k < data_width - 1:
                    packed_data[b] = packed_data[b] << 1

        return tuple(packed_data)

    output_tuple = tvm.compute(bitserial_oshape, _bitpack, name=name, tag='bitpack')

    if bits > 1:
        return concatenate(output_tuple, axis=bit_axis)
    return output_tuple
85
def binary_op_multiplier(pack_dtype):
    """Returns number of bits packed into one element of pack_dtype

    Parameters
    ----------
    pack_dtype: string
        pack type for the operator (must be a uint), e.g. 'uint8'

    Returns
    -------
    int
        The number following the 'uint' prefix, e.g. 8 for 'uint8'.

    Raises
    ------
    ValueError
        If pack_dtype is not of the form 'uint<digits>'.
    """
    prefix = "uint"
    width = pack_dtype[len(prefix):]
    # Reject anything that is not uint<digits>; blindly slicing would turn
    # e.g. 'int32' into 2 without any error.
    if not pack_dtype.startswith(prefix) or not width.isdigit():
        raise ValueError("pack_dtype must be an unsigned integer type, got %r" % (pack_dtype,))
    return int(width)
91