// This generator implements a quantized matrix multiplication and schedules
// for HVX and CPU. The generator schedule assumes certain size constraints on
// the two input matrices:
//   * The width of the left-hand-side (mat_a_ below) must be divisible by 4.
//   * The height of mat_a_ must be divisible by 4.
//   * The width of the right-hand-side (mat_b_ below) must be divisible by
//     the natural vector size of the architecture you want to run this code
//     on.
// Note that all these constraints are asserted at runtime, so running with
// illegal sizes will trigger those assertions. Correct input sizes can be
// achieved by padding mat_a_ with the value -mat_a_offset_ and mat_b_ with
// the value -mat_b_offset_.

#include "common.h"
#include <Halide.h>

using Halide::Generator;
using Halide::RVar;
using Halide::ConciseCasts::i32;
using Halide::ConciseCasts::u16;
using Halide::ConciseCasts::u32;
using Halide::ConciseCasts::u8_sat;

class MatrixMultiply : public Generator<MatrixMultiply> {
 public:
  // Two unsigned 8-bit input matrices, indexed by x, y.
  Input<Buffer<uint8_t>> mat_a_{"mat_a", 2};
  Input<Buffer<uint8_t>> mat_b_{"mat_b", 2};

  // A 1D array of 32-bit biases indexed by output width.
  Input<Buffer<int32_t>> bias_{"bias", 1};

  // Offsets and multipliers for the input, filter, and output.
  Input<int16_t> mat_a_offset_{"mat_a_offset", 0, -255, 0};
  Input<int16_t> mat_b_offset_{"mat_b_offset", 0, -255, 0};
  Input<int> output_multiplier_{"output_multiplier"};
  Input<int> output_shift_{"output_shift"};
  Input<int> output_offset_{"output_offset", 0, 0, 255};
  Input<uint8_t> output_min_{"output_min"};
  Input<uint8_t> output_max_{"output_max"};

  Output<Buffer<uint8_t>> output_{"output", 2};

  void generate() {
    // We take two 8-bit matrices as input.
    Var x("x"), y("y");

    // Align the extent of the k dimension to the unroll factor used in the
    // reduction. Unrolling there is needed to use the vrmpy instruction on
    // Hexagon.
    constexpr int kDotProductUnrollFactor = 4;
    Expr k_extent = mat_a_.dim(0).extent();
    k_extent = (k_extent / kDotProductUnrollFactor) * kDotProductUnrollFactor;
    mat_a_.dim(0).set_extent(k_extent);
    mat_b_.dim(1).set_extent(k_extent);

    // We split directly in the algorithm by a factor of 4
    // (== kDotProductUnrollFactor), so we can generate vrmpy instructions on
    // Hexagon.
    RDom rk(0, k_extent / kDotProductUnrollFactor, "k");

    int vector_size_u8 = natural_vector_size<uint8_t>();
    int vector_size_u32 = natural_vector_size<uint32_t>();
    bool use_hexagon = false;
    if (get_target().has_feature(Halide::Target::HVX_64)) {
      vector_size_u8 = 64;
      vector_size_u32 = 16;
      use_hexagon = true;
    } else if (get_target().has_feature(Halide::Target::HVX_128)) {
      vector_size_u8 = 128;
      vector_size_u32 = 32;
      use_hexagon = true;
    }
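    // Why the unroll factor of 4: on HVX, vrmpy computes, for each 32-bit
    // accumulator lane, a dot product of four unsigned 8-bit values; roughly
    // (this pseudocode is illustrative, not the exact intrinsic signature):
    //   acc[i] += a[4 * i + 0] * b[4 * i + 0] + a[4 * i + 1] * b[4 * i + 1] +
    //             a[4 * i + 2] * b[4 * i + 2] + a[4 * i + 3] * b[4 * i + 3];
    // The algorithm below is written so that each group of four consecutive
    // k values feeds one such lane, which lets Halide pattern-match the
    // reduction to vrmpy.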
    // Define the reordering of mat_b_ as a separate stage so we can lift the
    // interleaving required by vrmpy out of the inner loop.
    Func mat_b_swizzled("mat_b_swizzled");
    Var k("k");
    mat_b_swizzled(x, y, k) = mat_b_(x, 4 * y + k);

    // We need to compute the matrix product:
    //   (mat_a_ + mat_a_offset_ * 1_a) * (mat_b_ + mat_b_offset_ * 1_b),
    // where
    //   * mat_a_ and mat_b_ are 8-bit input matrices,
    //   * mat_a_offset_ and mat_b_offset_ are scalar values in [-255, 0],
    //   * 1_a is the matrix of the same size as mat_a_ filled with 1s,
    //   * 1_b is the matrix of the same size as mat_b_ filled with 1s.
    // If we added the offsets upfront, the matrix multiplication would have
    // to be carried out on 16-bit input matrices. To take full advantage of
    // the available instructions we need to hit the correct pattern, which is
    // the expression defining the Func multiplied_no_offsets below.
    //
    // Hence we factor the computation into the following four products:
    //   (1) mat_a_ * mat_b_
    //   (2) mat_a_offset_ * 1_a * mat_b_
    //   (3) mat_b_offset_ * mat_a_ * 1_b
    //   (4) mat_a_offset_ * mat_b_offset_ * 1_a * 1_b
    // Per output element, with K = mat_a_.width(), da = mat_a_offset_ and
    // db = mat_b_offset_, this is the identity
    //   sum_k (A(k, y) + da) * (B(x, k) + db) =
    //       sum_k A(k, y) * B(x, k)    -- (1)
    //     + da * sum_k B(x, k)         -- (2)
    //     + db * sum_k A(k, y)         -- (3)
    //     + K * da * db                -- (4)
    // Product (1) is the main matrix multiplication that we carry out with
    // multiplied_no_offsets.
    // Product (2) can be computed by taking the column sums of mat_b_ (the
    // dot product with the vector containing only 1s), replicating the
    // result into the right shape, and multiplying by mat_a_offset_.
    // Product (3) can be computed by taking the row sums of mat_a_ (the dot
    // product with the vector containing only 1s), replicating the result
    // into the right shape, and multiplying by mat_b_offset_.
    // Finally, product (4) is just
    // mat_a_offset_ * mat_b_offset_ * mat_a_.width() replicated to every
    // element of the resulting matrix.
    Func multiplied_no_offsets("multiplied_no_offsets");
    multiplied_no_offsets(x, y) = u32(0);
    multiplied_no_offsets(x, y) +=
        u32(u16(mat_a_(4 * rk + 0, y)) * u16(mat_b_swizzled(x, rk, 0))) +
        u32(u16(mat_a_(4 * rk + 1, y)) * u16(mat_b_swizzled(x, rk, 1))) +
        u32(u16(mat_a_(4 * rk + 2, y)) * u16(mat_b_swizzled(x, rk, 2))) +
        u32(u16(mat_a_(4 * rk + 3, y)) * u16(mat_b_swizzled(x, rk, 3)));

    RDom fk(0, mat_a_.width(), "fk");

    // We could convert the row sums into a partial horizontal reduction that
    // is vectorized, with the partial results summed up by a scalar sum
    // afterwards. While there is a performance benefit for large matrices,
    // we did not observe any performance difference on practical models. So
    // for simplicity we just use the straightforward row sum computation
    // here.
    Func row_sums_a("row_sums_a");
    row_sums_a(y) = sum(u32(mat_a_(fk, y)));

    Func column_sums_b("column_sums_b");
    column_sums_b(x) = sum(u32(mat_b_(x, fk)));

    Expr offset = cast<int32_t>(mat_a_offset_) *
                  cast<int32_t>(mat_b_offset_) * mat_a_.width();

    Func multiplied("multiplied");
    multiplied(x, y) = multiplied_no_offsets(x, y) +
                       i32(mat_a_offset_) * i32(column_sums_b(x)) +
                       i32(mat_b_offset_) * i32(row_sums_a(y)) + offset;

    // Scale the output.
    Func scaled_plus_offset("scaled_plus_offset");
    scaled_plus_offset(x, y) =
        multiply_quantized_multiplier(multiplied(x, y) + bias_(x),
                                      output_multiplier_, output_shift_) +
        output_offset_;

    // Saturate and narrow the output.
    output_(x, y) =
        clamp(u8_sat(scaled_plus_offset(x, y)), output_min_, output_max_);
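    // Putting the stages together, each output element is, in exact integer
    // arithmetic (a summary of the definitions above, not additional code):
    //   raw(x, y)    = sum_k (mat_a_(k, y) + mat_a_offset_) *
    //                        (mat_b_(x, k) + mat_b_offset_)
    //   scaled(x, y) = multiply_quantized_multiplier(
    //                      raw(x, y) + bias_(x),
    //                      output_multiplier_, output_shift_) + output_offset_
    //   output_(x, y) = clamp(u8_sat(scaled(x, y)), output_min_, output_max_)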
    // Specifying .hexagon() on a Func will generate an RPC to run this stage
    // on Hexagon. If Hexagon is the host (that is, the architecture is
    // Hexagon), we have to omit the .hexagon() directive, as we are already
    // running on Hexagon.
    if (use_hexagon && get_target().arch != Target::Hexagon) {
      output_.hexagon();
    }

    constexpr int kTileSizeHeight = 4;
    if (use_hexagon) {
      Var xo("xo"), yo("yo");

      // Split the output into tiles, traversed in columns of tiles that we
      // parallelize over.
      output_.compute_root()
          .tile(x, y, xo, yo, x, y, vector_size_u8, kTileSizeHeight,
                TailStrategy::RoundUp)
          .reorder(yo, xo)
          .prefetch(mat_a_, yo)
          .vectorize(x)
          .unroll(y)
          .parallel(xo);

      // Compute the product at tiles of the output.
      multiplied_no_offsets.compute_at(output_, yo).vectorize(x).unroll(y);

      multiplied_no_offsets.update(0).reorder(x, y, rk).vectorize(x).unroll(y);

      // Lift the swizzling out of the inner loop.
      mat_b_swizzled.compute_at(output_, xo)
          .reorder_storage(k, x, y)
          .reorder(k, x, y)
          .vectorize(x)
          .unroll(k);

      // Split the rows into chunks we can parallelize over, but prefetch
      // within.
      Var yi("yi");
      row_sums_a.compute_at(output_, Var::outermost())
          .split(y, y, yi, 32)
          .parallel(y)
          .prefetch(mat_a_, yi);

      Var xi("xi");
      column_sums_b.compute_at(output_, Var::outermost())
          .split(x, x, xi, vector_size_u8, TailStrategy::GuardWithIf)
          .parallel(x)
          .vectorize(xi);
    } else {
      Var xi("xi"), xii("xii"), yi("yi"), yii("yii");
      RVar rki("rki");

      // This schedule is taken from test/performance/MatrixMultiply.cpp.
      constexpr int kBlockSize = 32;
      constexpr int kBlockSizeXi = 8;

      output_.compute_root()
          .tile(x, y, x, y, xi, yi, vector_size_u8, kTileSizeHeight,
                TailStrategy::RoundUp)
          .reorder(xi, yi, x, y)
          .vectorize(xi)
          .unroll(yi)
          .parallel(y);

      multiplied_no_offsets.compute_root().vectorize(x, vector_size_u32);

      multiplied_no_offsets.update(0)
          .split(x, x, xi, kBlockSize, TailStrategy::GuardWithIf)
          .split(xi, xi, xii, kBlockSizeXi, TailStrategy::GuardWithIf)
          .split(y, y, yi, kBlockSize, TailStrategy::GuardWithIf)
          .split(yi, yi, yii, 4, TailStrategy::GuardWithIf)
          .split(rk, rk, rki, kBlockSize, TailStrategy::GuardWithIf)
          .reorder(xii, yii, xi, rki, yi, rk, x, y)
          .parallel(y)
          .vectorize(xii)
          .unroll(xi)
          .unroll(yii);

      row_sums_a.compute_root().vectorize(y, vector_size_u8,
                                          TailStrategy::ShiftInwards);

      column_sums_b.compute_root().vectorize(x, vector_size_u8,
                                             TailStrategy::ShiftInwards);
    }

    // Require the extents and strides to be aligned as described at the top
    // of this file; these set_bounds/set_stride requirements become runtime
    // assertions in the compiled pipeline.
    constexpr int kMatAHeightAlign = kTileSizeHeight;
    int vector_dim_align = vector_size_u8;

    mat_a_.dim(0)
        .set_bounds(0, mat_a_.dim(0).extent())
        .dim(1)
        .set_bounds(0, (mat_a_.dim(1).extent() / kMatAHeightAlign) *
                           kMatAHeightAlign)
        .set_stride((mat_a_.dim(1).stride() / kMatAHeightAlign) *
                    kMatAHeightAlign);

    mat_b_.dim(0)
        .set_bounds(0, (mat_b_.dim(0).extent() / vector_dim_align) *
                           vector_dim_align)
        .dim(1)
        .set_bounds(0, mat_b_.dim(1).extent())
        .set_stride(mat_b_.dim(1).stride());

    output_.dim(0)
        .set_bounds(0, (output_.dim(0).extent() / vector_dim_align) *
                           vector_dim_align)
        .dim(1)
        .set_bounds(0, (output_.dim(1).extent() / kMatAHeightAlign) *
                           kMatAHeightAlign)
        .set_stride((output_.dim(1).stride() / kMatAHeightAlign) *
                    kMatAHeightAlign);

    bias_.dim(0).set_bounds(0, bias_.dim(0).extent());
  }
};

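// Register the generator with Halide's generator tooling under the name
// "MatrixMultiply", so it can be compiled ahead of time (e.g., via GenGen).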
HALIDE_REGISTER_GENERATOR(MatrixMultiply, MatrixMultiply)
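// A minimal sketch of a host-side call, assuming this generator has been
// AOT-compiled into a function named matrix_multiply. The function name,
// sizes, and parameter values below are illustrative, not part of this file:
//
//   #include "HalideBuffer.h"
//   #include "matrix_multiply.h"  // hypothetical AOT-generated header
//
//   // k and m divisible by 4; n divisible by the natural vector size.
//   Halide::Runtime::Buffer<uint8_t> mat_a(k, m), mat_b(n, k), out(n, m);
//   Halide::Runtime::Buffer<int32_t> bias(n);
//   matrix_multiply(mat_a, mat_b, bias,
//                   /*mat_a_offset=*/-128, /*mat_b_offset=*/-128,
//                   /*output_multiplier=*/1 << 30, /*output_shift=*/8,
//                   /*output_offset=*/128, /*output_min=*/0,
//                   /*output_max=*/255, out);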