1 /**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include <cstring>

#include "libjit_defs.h"
17
18 namespace {
19
/// Column-major element accessors used by the matmul helpers below.
/// X(i, j) names row i, column j of the matrix whose leading dimension (the
/// distance, in elements, between the starts of consecutive columns) is ldX.
#define A(i, j) a[(i) + (j)*lda]
#define B(i, j) b[(i) + (j)*ldb]
#define C(i, j) c[(i) + (j)*ldc]
24
25 /// Naive gemm helper to handle oddly-sized matrices.
libjit_matmul_odd(int m,int n,int k,const float * a,int lda,const float * b,int ldb,float * c,int ldc)26 void libjit_matmul_odd(int m, int n, int k, const float *a, int lda,
27 const float *b, int ldb, float *c, int ldc) {
28 // The order of these loops is tuned for column-major matrices.
29 for (int p = 0; p < k; p++) {
30 for (int j = 0; j < n; j++) {
31 for (int i = 0; i < m; i++) {
32 C(i, j) += A(i, p) * B(p, j);
33 }
34 }
35 }
36 }
37
38 /// Number of registers to use for rows of A in the dot-product kernel.
39 constexpr int regsA = 4;
40 /// Number of registers to use for columns of B in the dot-product kernel.
41 constexpr int regsB = 3;
42
43 /// Number of rows of A to process in the kernel. Vector loads are used for A,
44 /// so we load eight times as many floats as we use registers.
45 constexpr int mr = regsA * 8;
46 /// Number of columns of B to process in the kernel.
47 constexpr int nr = regsB;
48
49 /// Blocking parameters for the outer kernel. We multiply mc x kc blocks of A
50 /// with kc x nc panels of B (this approach is referred to as `gebp` in the
51 /// literature). TODO: Generalize these parameters for other cache sizes.
52 constexpr int mc = 256;
53 constexpr int kc = 128;
54 constexpr int nc = 4096;
55
56 /// Compute a RAxRB block of C using a vectorized dot product, where RA is the
57 /// number of registers to load from matrix A, and RB is the number of registers
58 /// to load from matrix B.
59 template <size_t regsA, size_t regsB>
libjit_matmul_dot(size_t k,const float * a,size_t lda,const float * b,size_t ldb,float * c,size_t ldc)60 void libjit_matmul_dot(size_t k, const float *a, size_t lda, const float *b,
61 size_t ldb, float *c, size_t ldc) {
62 float8 csum[regsA][regsB] = {{0.0}};
63 for (size_t p = 0; p < k; p++) {
64 // Perform the DOT product.
65 for (size_t ai = 0; ai < regsA; ai++) {
66 float8 aa = LoaduFloat8(&A(ai * 8, p));
67 for (size_t bi = 0; bi < regsB; bi++) {
68 float8 bb = BroadcastFloat8(B(p, bi));
69 csum[ai][bi] += aa * bb;
70 }
71 }
72 }
73
74 // Accumulate the results into C.
75 for (size_t bi = 0; bi < regsB; bi++) {
76 for (size_t ai = 0; ai < regsA; ai++) {
77 AdduFloat8(&C(ai * 8, bi), csum[ai][bi]);
78 }
79 }
80 }
81
82 /// Similar to libjit_matmul_dot, but assumes that \p a and \p b have been
83 /// packed using z-ordering.
84 template <size_t regsA, size_t regsB>
libjit_matmul_zdot(size_t k,const float * a,size_t lda,const float * b,size_t ldb,float * c,size_t ldc)85 void libjit_matmul_zdot(size_t k, const float *a, size_t lda, const float *b,
86 size_t ldb, float *c, size_t ldc) {
87 float8 csum[regsA][regsB] = {{0.0}};
88
89 for (size_t p = 0; p < k; p++) {
90 // Perform the DOT product.
91 float8 *aptr = (float8 *)&A(0, p);
92 for (size_t ai = 0; ai < regsA; ai++) {
93 float8 aa = *aptr++;
94 for (size_t bi = 0; bi < regsB; bi++) {
95 float8 bb = BroadcastFloat8(*(b + bi));
96 csum[ai][bi] += aa * bb;
97 }
98 }
99 b += regsB;
100 }
101
102 // Accumulate the results into C.
103 for (size_t bi = 0; bi < regsB; bi++) {
104 for (size_t ai = 0; ai < regsA; ai++) {
105 AdduFloat8(&C(ai * 8, bi), csum[ai][bi]);
106 }
107 }
108 }
109
/// Pack matrix \p a into matrix \p a_to using a z-ordering, so that the
/// dot-product kernel can stride sequentially through memory.
///
/// \p a is an \p m x \p k column-major matrix with leading dimension \p lda.
/// Rows are processed in full blocks of 8 * regsA (one float8 register per 8
/// rows). For each block, the block's slice of every column is copied
/// contiguously into \p a_to, column after column. Trailing rows that do not
/// fill a whole block are not packed.
template <size_t regsA>
void pack_matrix_a(size_t m, size_t k, const float *a, size_t lda,
                   float *a_to) {
  // Derive the block height from the template argument, not the global `mr`,
  // so that the row stride always matches the amount copied per column.
  constexpr size_t blockRows = 8 * regsA;
  for (size_t i = 0; i + blockRows <= m; i += blockRows) {
    for (size_t j = 0; j < k; j++) {
      // Column j of a column-major block is contiguous: one bulk copy.
      std::memcpy(a_to, &a[j * lda + i], blockRows * sizeof(float));
      a_to += blockRows;
    }
  }
}
125
/// Pack matrix \p b into matrix \p b_to using a z-ordering, so that the
/// dot-product kernel can stride sequentially through memory, rather than
/// reading from `regsB` separate columns.
///
/// \p b is a \p k x \p n column-major matrix with leading dimension \p ldb.
/// Columns are processed in full groups of regsB. For each group, every row's
/// regsB elements are written consecutively to \p b_to, row after row.
/// Trailing columns that do not fill a whole group are not packed.
template <size_t regsB>
void pack_matrix_b(size_t n, size_t k, const float *b, size_t ldb,
                   float *b_to) {
  // Step by the template argument, not the global `nr`, so the group width
  // always matches the number of columns copied per row.
  for (size_t j = 0; j + regsB <= n; j += regsB) {
    for (size_t i = 0; i < k; i++) {
      for (size_t bi = 0; bi < regsB; bi++) {
        *b_to++ = b[(j + bi) * ldb + i];
      }
    }
  }
}
140
141 /// Inner kernel for packed matrices. The order of the M and N loops matters,
142 /// because packed matrices need to be more more sensitive to cache locality,
143 /// and N strides over the B matrix, which is very large and will blow out the
144 /// cache.
libjit_matmul_inner_packed(int m,int n,int k,const float * packedA,const float * packedB,float * c,int ldc)145 void libjit_matmul_inner_packed(int m, int n, int k, const float *packedA,
146 const float *packedB, float *c, int ldc) {
147 for (int j = 0; j < n - nr + 1; j += nr) {
148 for (int i = 0; i < m - mr + 1; i += mr) {
149 libjit_matmul_zdot<regsA, regsB>(k, &packedA[i * k], mr, &packedB[j * k],
150 k, &C(i, j), ldc);
151 }
152 }
153 }
154
155 /// Inner kernel for non-packed matrices. In these cases N is small, so it
156 /// tends to be beneficial to retain locality in the A matrix.
libjit_matmul_inner_unpacked(int m,int n,int k,const float * a,int lda,const float * b,int ldb,float * c,int ldc)157 void libjit_matmul_inner_unpacked(int m, int n, int k, const float *a, int lda,
158 const float *b, int ldb, float *c, int ldc) {
159 for (int i = 0; i < m - mr + 1; i += mr) {
160 for (int j = 0; j < n - nr + 1; j += nr) {
161 libjit_matmul_dot<regsA, regsB>(k, &A(i, 0), lda, &B(0, j), ldb, &C(i, j),
162 ldc);
163 }
164 }
165 }
166
167 /// Compute a portion of C one block at a time. Handle ragged edges with calls
168 /// to a slow but general helper.
169 template <bool pack>
libjit_matmul_inner(int m,int n,int k,const float * a,int lda,const float * b,int ldb,float * c,int ldc,float * packedB)170 void libjit_matmul_inner(int m, int n, int k, const float *a, int lda,
171 const float *b, int ldb, float *c, int ldc,
172 float *packedB) {
173 // The tiling scheme naturally divides the input matrices into 2 parts each;
174 // one tiled section, and three "ragged" edges.
175 //
176 // -------------------- -------
177 // | A00*B00 | A00*B01| | A00 | -------------
178 // -------------------- += ------- * | B00 | B01 |
179 // | A10*B00 | A10*B01| | A10 | -------------
180 // -------------------- -------
181 //
182 // We can process this as 4 separate matrix multiplications. A00*B00 is the
183 // perfectly-tiled portion, which we handly with a 4x16 dot-product kernel.
184 // The ragged edges are (ideally) less critical, so we handle them with a call
185 // to a general matrix-multiplication for odd sizes.
186 float packedA[m * k] __attribute__((aligned(64)));
187 if (pack) {
188 pack_matrix_a<regsA>(m, k, &A(0, 0), lda, packedA);
189 }
190
191 if (pack) {
192 libjit_matmul_inner_packed(m, n, k, packedA, packedB, c, ldc);
193 } else {
194 libjit_matmul_inner_unpacked(m, n, k, a, lda, b, ldb, c, ldc);
195 }
196
197 sdim_t i = (m / mr) * mr;
198 sdim_t j = (n / nr) * nr;
199 if (i < m) {
200 libjit_matmul_odd(m - i, j, k, &A(i, 0), lda, &B(0, 0), ldb, &C(i, 0), ldc);
201 }
202 if (j < n) {
203 libjit_matmul_odd(i, n - j, k, &A(0, 0), lda, &B(0, j), ldb, &C(0, j), ldc);
204 }
205 if (i < m && j < n) {
206 libjit_matmul_odd(m - i, n - j, k, &A(i, 0), lda, &B(0, j), ldb, &C(i, j),
207 ldc);
208 }
209 }
210
211 /// Tile A into mc * kc blocks, where mc and kc are chosen to approximately fit
212 /// the L2 cache on recent Intel processors (e.g., 256 KB for Skylake). Stream
213 /// kc * n panels of B through memory to compute each mc * n block of C.
214 /// \p a is an \p m x \p k column-major matrix;
215 /// \p b is a \p k x \p n column-major matrix;
216 /// \p c is a \p m x \p n column-major matrix.
217 /// \p lda, \p ldb, and \p ldc are the leading dimensions of A, B, and C,
218 /// respectively.
219 template <bool pack>
220 void __attribute__((noinline))
libjit_matmul_outer(dim_t m,dim_t n,dim_t k,const float * a,dim_t lda,const float * b,dim_t ldb,float * c,dim_t ldc)221 libjit_matmul_outer(dim_t m, dim_t n, dim_t k, const float *a, dim_t lda,
222 const float *b, dim_t ldb, float *c, dim_t ldc) {
223 float *packedB = nullptr;
224 if (pack) {
225 libjit_aligned_malloc((void **)&packedB, 64, kc * nc);
226 }
227
228 for (dim_t p = 0; p < k; p += kc) {
229 dim_t pb = MIN(k - p, kc);
230 for (dim_t j = 0; j < n; j += nc) {
231 dim_t jb = MIN(n - j, nc);
232 if (pack) {
233 pack_matrix_b<regsB>(jb, pb, &B(p, j), ldb, packedB);
234 }
235 for (dim_t i = 0; i < m; i += mc) {
236 dim_t ib = MIN(m - i, mc);
237 libjit_matmul_inner<pack>(ib, jb, pb, &A(i, p), lda, &B(p, j), ldb,
238 &C(i, j), ldc, packedB);
239 }
240 }
241 }
242
243 if (pack) {
244 libjit_aligned_free(packedB);
245 }
246 }
247
// The single-letter accessor macros are only meant for the matmul helpers
// above. Drop them so they cannot leak into the code that follows.
#undef A
#undef B
#undef C
251
252 /// Generic template for rowwise quantized FullyConnected. The template allows
253 /// choosing element type and bias type.
254 template <typename ElemTy, typename BiasElemTy>
libjit_rowwise_quantized_fc_generic(ElemTy * outW,const ElemTy * inW,const ElemTy * weightsW,const BiasElemTy * biasW,const int32_t * weightsOffsets,const int32_t * biasPre,const int32_t * biasPost,const int32_t * biasScale,const int32_t * outPre,const int32_t * outPost,const int32_t * outScale,const dim_t * outWdims,const dim_t * inWdims,const dim_t * weightsWdims,const dim_t * biasWdims,dim_t rowNum,int32_t outOffset,int32_t inOffset,int32_t biasOffset)255 void libjit_rowwise_quantized_fc_generic(
256 ElemTy *outW, const ElemTy *inW, const ElemTy *weightsW,
257 const BiasElemTy *biasW, const int32_t *weightsOffsets,
258 const int32_t *biasPre, const int32_t *biasPost, const int32_t *biasScale,
259 const int32_t *outPre, const int32_t *outPost, const int32_t *outScale,
260 const dim_t *outWdims, const dim_t *inWdims, const dim_t *weightsWdims,
261 const dim_t *biasWdims, dim_t rowNum, int32_t outOffset, int32_t inOffset,
262 int32_t biasOffset) {
263 dim_t in_w = inWdims[1];
264 dim_t out_h = outWdims[0];
265 dim_t out_w = outWdims[1];
266
267 // In rowwise quantized FC, weights is not pretransposed : I * Tranpose(W) +
268 // B. out(i, j) = in(i, 0) * weights(j, 0) + in(i, 1) * weights(j, 1) + ... +
269 // in(i, k) * weights(j, k) + bias(j);
270 for (size_t i = 0; i < out_h; i++) {
271 for (size_t j = 0; j < out_w; j++) {
272 int32_t sum = 0;
273 for (size_t k = 0; k < in_w; k++) {
274 int32_t W = weightsW[libjit_getXY(weightsWdims, j, k)];
275 int32_t I = inW[libjit_getXY(inWdims, i, k)];
276 sum += (W - weightsOffsets[j]) * (I - inOffset);
277 }
278 int32_t B = libjit_scale_i32i8(biasW[j] - biasOffset, biasPre[j],
279 biasPost[j], biasScale[j], 0);
280 sum += B;
281 int32_t scaledSum = libjit_scale_i32i8(sum, outPre[j], outPost[j],
282 outScale[j], outOffset);
283 outW[libjit_getXY(outWdims, i, j)] = libjit_clip(scaledSum);
284 }
285 }
286 }
287 } // namespace
288
289 extern "C" {
290
291 /// Performs the matrix multiplication c = a * b, where c, a, and b are
292 /// row-major matrices.
293 /// \p c is a m x n matrix, so \p cDims = {m, n}
294 /// \p a is a m x k matrix, so \p aDims = {m, k}
295 /// \p b is a k x n matrix, so \p bDims = {k, n}
libjit_matmul_f(float * c,const float * a,const float * b,const dim_t * cDims,const dim_t * aDims,const dim_t * bDims)296 void libjit_matmul_f(float *c, const float *a, const float *b,
297 const dim_t *cDims, const dim_t *aDims,
298 const dim_t *bDims) {
299 memset(c, 0, cDims[0] * cDims[1] * sizeof(float));
300 // Call the matrix multiplication routine with appropriate dimensions and
301 // leading dimensions. The "leading dimension" for a row-major matrix is equal
302 // to the number of columns in the matrix. For a, this is k; for b and c,
303 // this is n.
304 //
305 // This "outer" helper assumes the matrices are given in column-major format
306 // (the packing algorithm is more effective with column-major matrices), while
307 // the input is row-major. So we compute C += B * A, which is equivalent.
308 //
309 // The matrix multiplication routine is heavily inspired by:
310 // https://github.com/flame/how-to-optimize-gemm
311 int m = cDims[1];
312 int n = cDims[0];
313 int k = aDims[1];
314
315 // Use the unpacked version which does not use extra HEAP or STACK which
316 // makes the memory usage predictable. This is very useful when building
317 // bundles (AOT) for MCU targets where the HEAP and STACK are relatively
318 // limited in size. By avoiding heap/stack usage the memory consumption
319 // is controlled and perfectly known (e.g. printed in the bundle API).
320 libjit_matmul_outer<false>(m, n, k, b, bDims[1], a, aDims[1], c, cDims[1]);
321 }
322
libjit_matmul_i8(int8_t * outW,const int8_t * lhsW,const int8_t * rhsW,const dim_t * outWdims,const dim_t * lhsWdims,const dim_t * rhsWdims,int32_t outOffset,int32_t lhsOffset,int32_t rhsOffset,int32_t outPre,int32_t outPost,int32_t outScale)323 void libjit_matmul_i8(int8_t *outW, const int8_t *lhsW, const int8_t *rhsW,
324 const dim_t *outWdims, const dim_t *lhsWdims,
325 const dim_t *rhsWdims, int32_t outOffset,
326 int32_t lhsOffset, int32_t rhsOffset, int32_t outPre,
327 int32_t outPost, int32_t outScale) {
328 for (dim_t x = 0; x < outWdims[0]; x++) {
329 for (dim_t y = 0; y < outWdims[1]; y++) {
330 int32_t sum = 0;
331 for (dim_t i = 0; i < lhsWdims[1]; i++) {
332 int32_t lhs = lhsW[libjit_getXY(lhsWdims, x, i)] - lhsOffset;
333 int32_t rhs = rhsW[libjit_getXY(rhsWdims, i, y)] - rhsOffset;
334 sum += lhs * rhs;
335 }
336 int32_t s = libjit_scale_i32i8(sum, outPre, outPost, outScale, outOffset);
337 outW[libjit_getXY(outWdims, x, y)] = libjit_clip(s);
338 }
339 }
340 }
341
342 /// Rowwise quantized FullyConnected with int8 precision and int32 bias.
libjit_rowwise_quantized_fc_i8_i32(int8_t * outW,const int8_t * inW,const int8_t * weightsW,const int32_t * biasW,const int32_t * weightsOffsets,const int32_t * biasPre,const int32_t * biasPost,const int32_t * biasScale,const int32_t * outPre,const int32_t * outPost,const int32_t * outScale,const dim_t * outWdims,const dim_t * inWdims,const dim_t * weightsWdims,const dim_t * biasWdims,dim_t rowNum,int32_t outOffset,int32_t inOffset,int32_t biasOffset)343 void libjit_rowwise_quantized_fc_i8_i32(
344 int8_t *outW, const int8_t *inW, const int8_t *weightsW,
345 const int32_t *biasW, const int32_t *weightsOffsets, const int32_t *biasPre,
346 const int32_t *biasPost, const int32_t *biasScale, const int32_t *outPre,
347 const int32_t *outPost, const int32_t *outScale, const dim_t *outWdims,
348 const dim_t *inWdims, const dim_t *weightsWdims, const dim_t *biasWdims,
349 dim_t rowNum, int32_t outOffset, int32_t inOffset, int32_t biasOffset) {
350 libjit_rowwise_quantized_fc_generic<int8_t, int32_t>(
351 outW, inW, weightsW, biasW, weightsOffsets, biasPre, biasPost, biasScale,
352 outPre, outPost, outScale, outWdims, inWdims, weightsWdims, biasWdims,
353 rowNum, outOffset, inOffset, biasOffset);
354 }
355
356 /// Rowwise quantized FullyConnected with int8 precision and int8 bias.
libjit_rowwise_quantized_fc_i8_i8(int8_t * outW,const int8_t * inW,const int8_t * weightsW,const int8_t * biasW,const int32_t * weightsOffsets,const int32_t * biasPre,const int32_t * biasPost,const int32_t * biasScale,const int32_t * outPre,const int32_t * outPost,const int32_t * outScale,const dim_t * outWdims,const dim_t * inWdims,const dim_t * weightsWdims,const dim_t * biasWdims,dim_t rowNum,int32_t outOffset,int32_t inOffset,int32_t biasOffset)357 void libjit_rowwise_quantized_fc_i8_i8(
358 int8_t *outW, const int8_t *inW, const int8_t *weightsW,
359 const int8_t *biasW, const int32_t *weightsOffsets, const int32_t *biasPre,
360 const int32_t *biasPost, const int32_t *biasScale, const int32_t *outPre,
361 const int32_t *outPost, const int32_t *outScale, const dim_t *outWdims,
362 const dim_t *inWdims, const dim_t *weightsWdims, const dim_t *biasWdims,
363 dim_t rowNum, int32_t outOffset, int32_t inOffset, int32_t biasOffset) {
364 libjit_rowwise_quantized_fc_generic<int8_t, int8_t>(
365 outW, inW, weightsW, biasW, weightsOffsets, biasPre, biasPost, biasScale,
366 outPre, outPost, outScale, outWdims, inWdims, weightsWdims, biasWdims,
367 rowNum, outOffset, inOffset, biasOffset);
368 }
369 }
370