// This generator implements a quantized matrix multiplication and schedules for
// HVX and CPU. The generator schedule assumes certain size constraints on the
// two input matrices:
// * The width of the left-hand side (mat_a_ below) must be divisible by 4.
// * The height of mat_a_ must be divisible by 4.
// * The width of the right-hand side (mat_b_ below) must be divisible by
//   the natural vector size of the architecture you want to run this code on.
// Note that all these constraints are asserted at runtime, so running with
// illegal sizes will trigger those assertions. Correct input sizes can be
// achieved by padding mat_a_ with the value -mat_a_offset_ and mat_b_ with the
// value -mat_b_offset_.
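//
// For example, on a target with 128-byte vectors (HVX_128), a mat_a_ of width
// 50 and height 70 would need to be padded to 52x72, and a mat_b_ of width 100
// would need to be padded to width 128. Padding with the negated offsets works
// because the product computed below adds mat_a_offset_ and mat_b_offset_ back
// to every element, so padded entries contribute exactly zero to each dot
// product.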

#include "common.h"
#include <Halide.h>

using Halide::Generator;
using Halide::RVar;
using Halide::ConciseCasts::i32;
using Halide::ConciseCasts::u16;
using Halide::ConciseCasts::u32;
using Halide::ConciseCasts::u8_sat;

class MatrixMultiply : public Generator<MatrixMultiply> {
public:
    // Two unsigned 8-bit input matrices, indexed by x, y.
    Input<Buffer<uint8_t>> mat_a_{"mat_a", 2};
    Input<Buffer<uint8_t>> mat_b_{"mat_b", 2};

    // A 1D array of 32-bit biases, indexed by the output's x coordinate.
    Input<Buffer<int32_t>> bias_{"bias", 1};

    // Offsets and multipliers for the input, filter, and output.
    Input<int16_t> mat_a_offset_{"mat_a_offset", 0, -255, 0};
    Input<int16_t> mat_b_offset_{"mat_b_offset", 0, -255, 0};
    Input<int> output_multiplier_{"output_multiplier"};
    Input<int> output_shift_{"output_shift"};
    Input<int> output_offset_{"output_offset", 0, 0, 255};
    Input<uint8_t> output_min_{"output_min"};
    Input<uint8_t> output_max_{"output_max"};

    Output<Buffer<uint8_t>> output_{"output", 2};

    void generate() {
        // We take two 8-bit matrices as input.
        Var x("x"), y("y");

        // Align the extent of the k dimension to the unroll factor used in the
        // reduction. Unrolling there is needed to use the vrmpy instruction on
        // Hexagon.
        constexpr int kDotProductUnrollFactor = 4;
        Expr k_extent = mat_a_.dim(0).extent();
        k_extent = (k_extent / kDotProductUnrollFactor) * kDotProductUnrollFactor;
        mat_a_.dim(0).set_extent(k_extent);
        mat_b_.dim(1).set_extent(k_extent);
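
        // Because k_extent is derived from the extent itself, set_extent
        // amounts to the runtime check extent == (extent / 4) * 4, i.e. the
        // width of mat_a_ (and the height of mat_b_) must be a multiple of 4:
        // an extent of 64 passes, while an extent of 66 fails the assertion.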

        // We split directly in the algorithm by a factor of 4
        // (== kDotProductUnrollFactor), so we can generate vrmpy instructions on
        // Hexagon.
        RDom rk(0, k_extent / kDotProductUnrollFactor, "k");

        int vector_size_u8 = natural_vector_size<uint8_t>();
        int vector_size_u32 = natural_vector_size<uint32_t>();
        bool use_hexagon = false;
        if (get_target().has_feature(Halide::Target::HVX_64)) {
            vector_size_u8 = 64;
            vector_size_u32 = 16;
            use_hexagon = true;
        } else if (get_target().has_feature(Halide::Target::HVX_128)) {
            vector_size_u8 = 128;
            vector_size_u32 = 32;
            use_hexagon = true;
        }

        // Define the reordering of mat_b_ as a separate stage so we can lift
        // the interleaving required by vrmpy out of the inner loop.
        Func mat_b_swizzled("mat_b_swizzled");
        Var k("k");
        mat_b_swizzled(x, y, k) = mat_b_(x, 4 * y + k);
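
        // mat_b_swizzled(x, y, k) gathers the four consecutive rows
        // 4 * y .. 4 * y + 3 of mat_b_ under the new innermost index k, which
        // the reduction below addresses with k fully unrolled.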

        // We need to compute the matrix product:
        //   (mat_a_ + mat_a_offset_ * 1_a) * (mat_b_ + mat_b_offset_ * 1_b),
        // where
        //   * mat_a_ and mat_b_ are 8-bit input matrices,
        //   * mat_a_offset_ and mat_b_offset_ are scalar values in [-255, 0],
        //   * 1_a is the matrix of the same size as mat_a_ filled with 1s,
        //   * 1_b is the matrix of the same size as mat_b_ filled with 1s.
        // If we added the offsets upfront, the matrix multiplication would have
        // to be carried out on 16-bit input matrices. To take full advantage of
        // the available instructions we need to hit the correct pattern, which
        // is the expression defining the Func multiplied_no_offsets below.
        //
        // Hence we factor the computation into the following four products:
        // (1) mat_a_ * mat_b_
        // (2) mat_a_offset_ * 1_a * mat_b_
        // (3) mat_b_offset_ * mat_a_ * 1_b
        // (4) mat_a_offset_ * mat_b_offset_ * 1_a * 1_b
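        // These are exactly the terms obtained by distributing
        //   (mat_a_ + mat_a_offset_ * 1_a) * (mat_b_ + mat_b_offset_ * 1_b);
        // the scalar offsets commute with the matrix products, so each term
        // can be computed independently and summed at the end.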
        // Product (1) is the main matrix multiplication that we carry out with
        // multiplied_no_offsets.
        // Product (2) can be computed by taking the column sums of mat_b_ (dot
        // products with the vector containing only 1s), replicating the result
        // into the right shape, and multiplying by mat_a_offset_.
        // Product (3) can be computed by taking the row sums of mat_a_ (dot
        // products with the vector containing only 1s), replicating the result
        // into the right shape, and multiplying by mat_b_offset_.
        // Finally, product (4) is just
        // mat_a_offset_ * mat_b_offset_ * mat_a_.width() replicated to every
        // element of the resulting matrix.
        Func multiplied_no_offsets("multiplied_no_offsets");
        multiplied_no_offsets(x, y) = u32(0);
        multiplied_no_offsets(x, y) +=
            u32(u16(mat_a_(4 * rk + 0, y)) * u16(mat_b_swizzled(x, rk, 0))) +
            u32(u16(mat_a_(4 * rk + 1, y)) * u16(mat_b_swizzled(x, rk, 1))) +
            u32(u16(mat_a_(4 * rk + 2, y)) * u16(mat_b_swizzled(x, rk, 2))) +
            u32(u16(mat_a_(4 * rk + 3, y)) * u16(mat_b_swizzled(x, rk, 3)));
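
        // Each update step accumulates four u8 * u8 products, widened to u16
        // and then to u32. This widening multiply-accumulate over groups of
        // four is the pattern the Hexagon backend matches to the vrmpy
        // instruction.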

        RDom fk(0, mat_a_.width(), "fk");

        // We could convert the row sums into a partial horizontal reduction
        // that is vectorized, with the partial results summed up by a scalar
        // sum afterwards. While there is a performance benefit for large
        // matrices, we did not observe any performance difference on practical
        // models. So for simplicity we just use the straightforward row sum
        // computation here.
        Func row_sums_a("row_sums_a");
        row_sums_a(y) = sum(u32(mat_a_(fk, y)));

        Func column_sums_b("column_sums_b");
        column_sums_b(x) = sum(u32(mat_b_(x, fk)));
        Expr offset = cast<int32_t>(mat_a_offset_) *
                      cast<int32_t>(mat_b_offset_) * mat_a_.width();

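        // Combine the four products: multiplied_no_offsets is (1), the
        // column-sum term is (2), the row-sum term is (3), and offset is (4).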
        Func multiplied("multiplied");
        multiplied(x, y) =
            multiplied_no_offsets(x, y) +
            i32(mat_a_offset_) * i32(column_sums_b(x)) +
            i32(mat_b_offset_) * i32(row_sums_a(y)) + offset;

        // Scale the output.
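        // multiply_quantized_multiplier comes from common.h; conceptually it
        // performs a fixed-point multiplication by output_multiplier_ followed
        // by a rounding shift by output_shift_, as is common in quantized
        // inference kernels.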
        Func scaled_plus_offset("scaled_plus_offset");
        scaled_plus_offset(x, y) =
            multiply_quantized_multiplier(multiplied(x, y) + bias_(x),
                                          output_multiplier_, output_shift_) +
            output_offset_;

        // Saturate and narrow the output.
        output_(x, y) =
            clamp(u8_sat(scaled_plus_offset(x, y)), output_min_, output_max_);

        // Specifying .hexagon() on a Func generates an RPC to run this stage on
        // Hexagon. If we are compiling with Hexagon as the host architecture we
        // must omit the directive, since the code is already running on Hexagon.
        if (use_hexagon && get_target().arch != Target::Hexagon) {
            output_.hexagon();
        }

        constexpr int kTileSizeHeight = 4;
        if (use_hexagon) {
            Var xo("xo"), yo("yo");

            // Split the output into tiles, traversed in columns of tiles
            // that we parallelize over.
            output_.compute_root()
                .tile(x, y, xo, yo, x, y, vector_size_u8, kTileSizeHeight,
                      TailStrategy::RoundUp)
                .reorder(yo, xo)
                .prefetch(mat_a_, yo)
                .vectorize(x)
                .unroll(y)
                .parallel(xo);
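
            // The resulting loop nest is xo (parallel), then yo, with a
            // vector_size_u8 x kTileSizeHeight tile vectorized in x and
            // unrolled in y; each thread walks down one column of tiles,
            // prefetching the mat_a_ rows for the next tile.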

            // Compute the product at tiles of the output.
            multiplied_no_offsets.compute_at(output_, yo).vectorize(x).unroll(y);

            multiplied_no_offsets.update(0).reorder(x, y, rk).vectorize(x).unroll(y);

            // Lift the swizzling out of the inner loop.
            mat_b_swizzled.compute_at(output_, xo)
                .reorder_storage(k, x, y)
                .reorder(k, x, y)
                .vectorize(x)
                .unroll(k);
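
            // reorder_storage(k, x, y) gives k a stride of 1, so the four
            // bytes that vrmpy consumes for each output lane are laid out
            // adjacently once per xo tile, instead of being gathered inside
            // the reduction loop.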

            // Split the rows into chunks we can parallelize over, but prefetch
            // within.
            Var yi("yi");
            row_sums_a.compute_at(output_, Var::outermost())
                .split(y, y, yi, 32)
                .parallel(y)
                .prefetch(mat_a_, yi);

            Var xi("xi");
            column_sums_b.compute_at(output_, Var::outermost())
                .split(x, x, xi, vector_size_u8, TailStrategy::GuardWithIf)
                .parallel(x)
                .vectorize(xi);

        } else {
            Var xi("xi"), xii("xii"), yi("yi"), yii("yii");
            RVar rki("rki");

            // This schedule is taken from test/performance/MatrixMultiply.cpp.
            constexpr int kBlockSize = 32;
            constexpr int kBlockSizeXi = 8;

            output_.compute_root()
                .tile(x, y, x, y, xi, yi, vector_size_u8, kTileSizeHeight,
                      TailStrategy::RoundUp)
                .reorder(xi, yi, x, y)
                .vectorize(xi)
                .unroll(yi)
                .parallel(y);

            multiplied_no_offsets.compute_root().vectorize(x, vector_size_u32);

            multiplied_no_offsets.update(0)
                .split(x, x, xi, kBlockSize, TailStrategy::GuardWithIf)
                .split(xi, xi, xii, kBlockSizeXi, TailStrategy::GuardWithIf)
                .split(y, y, yi, kBlockSize, TailStrategy::GuardWithIf)
                .split(yi, yi, yii, 4, TailStrategy::GuardWithIf)
                .split(rk, rk, rki, kBlockSize, TailStrategy::GuardWithIf)
                .reorder(xii, yii, xi, rki, yi, rk, x, y)
                .parallel(y)
                .vectorize(xii)
                .unroll(xi)
                .unroll(yii);
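
            // Reading the reorder from outermost to innermost (y, x, rk, yi,
            // rki, xi, yii, xii), this computes kBlockSize x kBlockSize blocks
            // of the output (parallel over y), accumulating kBlockSize
            // reduction steps at a time, with a kBlockSizeXi-wide by 4-row
            // micro-kernel (xii vectorized, xi and yii unrolled) for
            // register-level reuse.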

            row_sums_a.compute_root().vectorize(y, vector_size_u8,
                                                TailStrategy::ShiftInwards);

            column_sums_b.compute_root().vectorize(x, vector_size_u8,
                                                   TailStrategy::ShiftInwards);
        }

        constexpr int kMatAHeightAlign = kTileSizeHeight;
        int vector_dim_align = vector_size_u8;

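        // Constrain the input and output shapes: all mins are 0, the heights
        // of mat_a_ and output_ (and their row strides) are multiples of
        // kMatAHeightAlign, and the widths of mat_b_ and output_ are multiples
        // of the vector size. Like the k-extent constraint above, these turn
        // into runtime assertions and let the aligned schedules above avoid
        // boundary handling.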
        mat_a_.dim(0)
            .set_bounds(0, mat_a_.dim(0).extent())
            .dim(1)
            .set_bounds(0, (mat_a_.dim(1).extent() / kMatAHeightAlign) * kMatAHeightAlign)
            .set_stride((mat_a_.dim(1).stride() / kMatAHeightAlign) *
                        kMatAHeightAlign);

        mat_b_.dim(0)
            .set_bounds(0, (mat_b_.dim(0).extent() / vector_dim_align) * vector_dim_align)
            .dim(1)
            .set_bounds(0, mat_b_.dim(1).extent())
            .set_stride(mat_b_.dim(1).stride());

        output_.dim(0)
            .set_bounds(0, (output_.dim(0).extent() / vector_dim_align) * vector_dim_align)
            .dim(1)
            .set_bounds(0, (output_.dim(1).extent() / kMatAHeightAlign) * kMatAHeightAlign)
            .set_stride((output_.dim(1).stride() / kMatAHeightAlign) * kMatAHeightAlign);

        bias_.dim(0).set_bounds(0, bias_.dim(0).extent());
    }
};

HALIDE_REGISTER_GENERATOR(MatrixMultiply, MatrixMultiply)