# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-locals, too-many-arguments
"""Utility functions for bitserial operators"""
import numpy as np
import tvm
from topi.transform import concatenate
from ..util import get_const_int

# Word width (in bits) of each supported packing type.
_PACK_TYPE_WIDTHS = {'uint8': 8, 'uint16': 16, 'uint32': 32, 'uint64': 64}


def bitpack(data, bits, pack_axis, bit_axis, pack_type, name="QuantizeInput"):
    """Packs data into format necessary for bitserial computation

    Every group of ``data_width`` consecutive elements along ``pack_axis``
    (``data_width`` being the bit width of ``pack_type``) is packed into one
    ``pack_type`` word per bit plane. Bit plane ``b`` holds bit ``b`` of each
    element; inside a word, the first element of the group ends up in the
    most significant bit and the last element in the least significant bit.

    Parameters
    ----------
    data : tvm.Tensor
        Input tensor. Its elements are AND-ed with an int32 mask, so it
        presumably has an integer dtype (TODO confirm accepted dtypes).
    bits : int
        Number of low-order bits to extract from each element (one bit plane
        per bit).
    pack_axis : int
        index of the axis to pack in data
    bit_axis : int
        index of axis to place bit axis in resulting packed data
    pack_type : str
        Type of the packed words: 'uint8', 'uint16', 'uint32' or 'uint64'.
    name : str, optional
        Name of the packing compute op.

    Returns
    -------
    packed : tvm.Tensor
        Tensor shaped like ``data`` except that ``pack_axis`` is divided by
        ``data_width`` and a new axis of size ``bits`` is inserted at
        ``bit_axis``.

    Raises
    ------
    ValueError
        If ``pack_type`` is not one of the supported unsigned types.
    """
    if pack_type not in _PACK_TYPE_WIDTHS:
        # Previously an unknown type left data_width unbound and surfaced as
        # an UnboundLocalError further down.
        raise ValueError("pack_type must be one of uint8, uint16, uint32, uint64; got %r"
                         % (pack_type,))
    data_width = _PACK_TYPE_WIDTHS[pack_type]

    ishape = data.shape
    n = len(ishape)

    # Data must be in multiples of the data_width
    assert get_const_int(ishape[pack_axis]) % data_width == 0, "Not a multiple of word size"

    shape_vec = list(ishape)
    shape_vec[pack_axis] = (shape_vec[pack_axis] // data_width)
    # Each compute output holds a single bit plane; the planes are joined
    # along this size-1 axis below when bits > 1.
    shape_vec.insert(bit_axis, 1)
    bitserial_oshape = tuple(shape_vec)
    # masks[b] == 1 << b selects bit b. Sized by `bits` (instead of a fixed
    # 8-entry table) so more than 8 bit planes can be extracted.
    masks = 1 << np.arange(bits)

    # pack axis shifts if bit axis comes before
    if bit_axis <= pack_axis:
        pack_axis += 1

    def _bitpack(*indices):
        packed_data = [tvm.const(0, pack_type)] * bits
        for k in range(data_width):
            # Translate indices for packed data back to original: drop the
            # bit axis and expand the packed axis to the k-th element of its
            # group.
            idx = [indices[i] * data_width + k if i == pack_axis else indices[i]
                   for i in range(n + 1) if i != bit_axis]

            element = data(*idx)
            for b in range(bits):
                extracted_bit = ((element & tvm.const(masks[b], "int32")) >> b) \
                    .astype(pack_type)
                packed_data[b] = (packed_data[b] | extracted_bit)
                # Shift after every element except the last, so element k
                # ends up at bit position (data_width - 1 - k).
                if k < data_width - 1:
                    packed_data[b] = packed_data[b] << 1
        return tuple(packed_data)

    output_tuple = tvm.compute(bitserial_oshape, _bitpack, name=name, tag='bitpack')

    if bits > 1:
        return concatenate(output_tuple, axis=bit_axis)
    return output_tuple


def binary_op_multiplier(pack_dtype):
    """Returns number of bits packed into one word of pack_dtype

    Parameters
    ----------
    pack_dtype : str
        pack type for the operator (must be a uint, e.g. 'uint32')

    Returns
    -------
    int
        Bit width parsed from the dtype name, e.g. 8 for 'uint8'.

    Raises
    ------
    ValueError
        If ``pack_dtype`` does not name an unsigned integer type.
    """
    # Without this check, e.g. 'int16' would silently yield 6.
    if not pack_dtype.startswith('uint'):
        raise ValueError("pack_dtype must be an unsigned type such as 'uint32'; got %r"
                         % (pack_dtype,))
    return int(pack_dtype[4:])