/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \brief Binary neural network (BNN) op constructions
 * \file nn/bnn.h
 */
#ifndef TVM_TOPI_NN_BNN_H_
#define TVM_TOPI_NN_BNN_H_

#include <tvm/arith/analyzer.h>
#include <tvm/te/operation.h>
#include <tvm/topi/detail/constant_utils.h>
#include <tvm/topi/tags.h>

#include <string>

namespace tvm {
namespace topi {
namespace nn {

using namespace tvm::te;

/*!
 * \brief Binarization and bit-packing along a certain axis.
 *
 * \param data N-D tensor, can be any layout
 * \param axis The axis along which to do binarization and bit-packing. This axis
 * must have a size equal to an integer multiple of 32.
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
 * \return Output tensor with dtype uint32
 */
inline tvm::te::Tensor binarize_pack(const tvm::te::Tensor& data, int axis,
                                     std::string name = "PackedInput",
                                     std::string tag = "binarize_pack") {
  auto ishape = data->shape;
  CHECK_EQ(GetConstInt(ishape[axis]) % 32, 0)
      << "binarize_pack: axis size must be a multiple of 32";

  arith::Analyzer analyzer;
  auto n = ishape.size();
  Array<PrimExpr> oshape;
  for (size_t i = 0; i < n; ++i) {
    oshape.push_back(i == static_cast<size_t>(axis) ? analyzer.Simplify(indexdiv(ishape[i], 32))
                                                    : ishape[i]);
  }

  return tvm::te::compute(
      oshape,
      [&](const Array<Var>& indices) {
        Array<PrimExpr> start_idx;
        for (size_t i = 0; i < n; ++i) {
          start_idx.push_back(i == static_cast<size_t>(axis) ? indices[i] * 32
                                                             : static_cast<PrimExpr>(indices[i]));
        }
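        // Pack the signs of 32 consecutive elements along `axis` into one uint32:
        // each bit is 1 where the input is non-negative. The or-then-shift loop
        // below places the element with the lowest index in the most significant bit.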
        auto packed = make_const(DataType::UInt(32), 0);
        for (size_t j = 0; j < 32; ++j) {
          Array<PrimExpr> idx;
          for (size_t i = 0; i < n; ++i) {
            idx.push_back(i == static_cast<size_t>(axis) ? start_idx[i] + static_cast<int>(j)
                                                         : start_idx[i]);
          }
          auto sign = tvm::cast(DataType::UInt(32), data(idx) >= 0);
          packed = (packed | sign);
          if (j == 31) {
            return packed;
          }
          packed = packed << 1;
        }
        return packed;  // never reached, but suppress compiler warning
      },
      name, tag);
}
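
/*
 * Usage sketch for binarize_pack (illustrative only; the shapes and names below
 * are made up for this example and are not part of the API):
 *
 *   // A float tensor whose packing axis (axis 1) has size 256, a multiple of 32.
 *   tvm::te::Tensor x = tvm::te::placeholder({64, 256}, DataType::Float(32), "x");
 *   // Packing along axis 1 yields a {64, 8} tensor of dtype uint32, where each
 *   // word holds the sign bits of 32 consecutive inputs.
 *   tvm::te::Tensor xp = binarize_pack(x, 1);
 */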

/*!
 * \brief Binary matrix multiplication using xor and bit-count
 *
 * \param data Tensor with shape [batch, in_dim], dtype is uint32
 * \param weight Tensor with shape [out_dim, in_dim], dtype is uint32
 *
 * \return Tensor with shape [batch, out_dim], dtype is float32
 */
inline tvm::te::Tensor binary_dense(const tvm::te::Tensor& data, const tvm::te::Tensor& weight) {
  CHECK_EQ(data->shape.size(), 2) << "binary_dense requires 2-D data";
  CHECK_EQ(weight->shape.size(), 2) << "binary_dense requires 2-D weight";
  CHECK_EQ(data->dtype, DataType::UInt(32)) << "binary_dense requires uint32 data";
  CHECK_EQ(weight->dtype, DataType::UInt(32)) << "binary_dense requires uint32 weight";

  auto batch = data->shape[0];
  auto in_dim = data->shape[1];
  auto out_dim = weight->shape[0];

  auto k = tvm::te::reduce_axis(Range(0, in_dim), "k");
  auto matmul = tvm::te::compute(
      {batch, out_dim},
      [&](Var i, Var j) { return tvm::sum(popcount(data(i, k) ^ weight(j, k)), {k}); }, "tensor",
      "binary_dense");

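  // Recover the +/-1 dot product from the xor-popcount form: of the 32 * in_dim
  // bit pairs, matching bits contribute +1 and differing bits contribute -1, so
  // dot = (32 * in_dim - popcount) - popcount = 32 * in_dim - 2 * popcount.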
  return tvm::te::compute(
      {batch, out_dim}, [&](Var i, Var j) { return 32 * in_dim - 2.0f * matmul(i, j); }, "tensor",
      kElementWise);
}
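
/*
 * Usage sketch for binary_dense (illustrative only; the shapes and names are made
 * up, and both operands are assumed to have been packed with binarize_pack along
 * their last axis):
 *
 *   // 1024 input features packed into 1024 / 32 = 32 uint32 words per row.
 *   tvm::te::Tensor a = tvm::te::placeholder({16, 32}, DataType::UInt(32), "a");
 *   tvm::te::Tensor w = tvm::te::placeholder({10, 32}, DataType::UInt(32), "w");
 *   // Result has shape {16, 10} and a float dtype holding the +/-1 dot products.
 *   tvm::te::Tensor y = binary_dense(a, w);
 */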

}  // namespace nn
}  // namespace topi
}  // namespace tvm
#endif  // TVM_TOPI_NN_BNN_H_