/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "libjit_defs.h"

namespace {

/// Macros for accessing submatrices of a matmul using the leading dimension.
#define A(i, j) a[(j)*lda + (i)]
#define B(i, j) b[(j)*ldb + (i)]
#define C(i, j) c[(j)*ldc + (i)]
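/// For illustration only (not used by the code below): these macros assume
/// column-major storage, so element (i, j) of C lives at c[j * ldc + i].
/// For example, with ldc == 4 the expression C(2, 1) expands to c[1 * 4 + 2],
/// i.e. c[6], the third element of the second column.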

/// Naive gemm helper to handle oddly-sized matrices.
void libjit_matmul_odd(int m, int n, int k, const float *a, int lda,
                       const float *b, int ldb, float *c, int ldc) {
  // The order of these loops is tuned for column-major matrices.
  for (int p = 0; p < k; p++) {
    for (int j = 0; j < n; j++) {
      for (int i = 0; i < m; i++) {
        C(i, j) += A(i, p) * B(p, j);
      }
    }
  }
}

/// Number of registers to use for rows of A in the dot-product kernel.
constexpr int regsA = 4;
/// Number of registers to use for columns of B in the dot-product kernel.
constexpr int regsB = 3;

/// Number of rows of A to process in the kernel.  Vector loads are used for A,
/// so we load eight times as many floats as we use registers.
constexpr int mr = regsA * 8;
/// Number of columns of B to process in the kernel.
constexpr int nr = regsB;
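/// A worked example of the resulting tile size (an illustrative note): with
/// regsA = 4 and 8 floats per vector register, mr = 32, and with regsB = 3,
/// nr = 3, so each invocation of the dot-product kernel computes a 32 x 3
/// tile of C.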

/// Blocking parameters for the outer kernel.  We multiply mc x kc blocks of A
/// with kc x nc panels of B (this approach is referred to as `gebp` in the
/// literature).  TODO: Generalize these parameters for other cache sizes.
constexpr int mc = 256;
constexpr int kc = 128;
constexpr int nc = 4096;
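/// Rough footprint arithmetic behind these defaults (an estimate, assuming
/// 4-byte floats): an mc x kc block of A occupies 256 * 128 * 4 = 128 KiB,
/// which fits in a 256 KiB L2 cache with room to spare, while a kc x nc panel
/// of B occupies 128 * 4096 * 4 = 2 MiB and is streamed through the cache.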

/// Compute a RAxRB block of C using a vectorized dot product, where RA is the
/// number of registers to load from matrix A, and RB is the number of registers
/// to load from matrix B.
template <size_t regsA, size_t regsB>
void libjit_matmul_dot(size_t k, const float *a, size_t lda, const float *b,
                       size_t ldb, float *c, size_t ldc) {
  float8 csum[regsA][regsB] = {{0.0}};
  for (size_t p = 0; p < k; p++) {
    // Perform the DOT product.
    for (size_t ai = 0; ai < regsA; ai++) {
      float8 aa = LoaduFloat8(&A(ai * 8, p));
      for (size_t bi = 0; bi < regsB; bi++) {
        float8 bb = BroadcastFloat8(B(p, bi));
        csum[ai][bi] += aa * bb;
      }
    }
  }

  // Accumulate the results into C.
  for (size_t bi = 0; bi < regsB; bi++) {
    for (size_t ai = 0; ai < regsA; ai++) {
      AdduFloat8(&C(ai * 8, bi), csum[ai][bi]);
    }
  }
}

/// Similar to libjit_matmul_dot, but assumes that \p a and \p b have been
/// packed using z-ordering.
template <size_t regsA, size_t regsB>
void libjit_matmul_zdot(size_t k, const float *a, size_t lda, const float *b,
                        size_t ldb, float *c, size_t ldc) {
  float8 csum[regsA][regsB] = {{0.0}};

  for (size_t p = 0; p < k; p++) {
    // Perform the DOT product.
    float8 *aptr = (float8 *)&A(0, p);
    for (size_t ai = 0; ai < regsA; ai++) {
      float8 aa = *aptr++;
      for (size_t bi = 0; bi < regsB; bi++) {
        float8 bb = BroadcastFloat8(*(b + bi));
        csum[ai][bi] += aa * bb;
      }
    }
    b += regsB;
  }

  // Accumulate the results into C.
  for (size_t bi = 0; bi < regsB; bi++) {
    for (size_t ai = 0; ai < regsA; ai++) {
      AdduFloat8(&C(ai * 8, bi), csum[ai][bi]);
    }
  }
}

/// Pack matrix \p a into matrix \p a_to using a z-ordering, so that the
/// dot-product kernel can stride sequentially through memory.
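/// For illustration (a description derived from the loop below, not a
/// separate data format): for each mr-row strip of A, the packed buffer stores
/// the mr floats of column 0, then the mr floats of column 1, and so on
/// through column k-1, so the kernel reads A contiguously as it advances
/// over p.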
template <size_t regsA>
void pack_matrix_a(size_t m, size_t k, const float *a, size_t lda,
                   float *a_to) {
  for (int i = 0; i < int(m) - mr + 1; i += mr) {
    for (size_t j = 0; j < k; j++) {
      const float *a_ij_pntr = &A(i, j);
      for (size_t ai = 0; ai < regsA; ai++) {
        StoreuFloat8(a_to + 8 * ai, LoaduFloat8(a_ij_pntr + 8 * ai));
      }
      a_to += 8 * regsA;
    }
  }
}

/// Pack matrix \p b into matrix \p b_to using a z-ordering, so that the
/// dot-product kernel can stride sequentially through memory, rather than
/// reading from `regsB` separate columns.
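/// For illustration (a description derived from the loop below): for each
/// group of nr columns of B, the packed buffer interleaves the columns row by
/// row, i.e. B(0, j), B(0, j+1), B(0, j+2), B(1, j), ..., so the kernel can
/// broadcast nr consecutive floats per step of the dot product.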
template <size_t regsB>
void pack_matrix_b(size_t n, size_t k, const float *b, size_t ldb,
                   float *b_to) {
  for (int j = 0; j < int(n) - nr + 1; j += nr) {
    for (size_t i = 0; i < k; i++) {
      for (size_t bi = 0; bi < regsB; bi++) {
        *b_to++ = B(i, j + bi);
      }
    }
  }
}

/// Inner kernel for packed matrices.  The order of the M and N loops matters,
/// because the packed case is more sensitive to cache locality, and N strides
/// over the B matrix, which is very large and will blow out the cache.
void libjit_matmul_inner_packed(int m, int n, int k, const float *packedA,
                                const float *packedB, float *c, int ldc) {
  for (int j = 0; j < n - nr + 1; j += nr) {
    for (int i = 0; i < m - mr + 1; i += mr) {
      libjit_matmul_zdot<regsA, regsB>(k, &packedA[i * k], mr, &packedB[j * k],
                                       k, &C(i, j), ldc);
    }
  }
}

/// Inner kernel for non-packed matrices.  In these cases N is small, so it
/// tends to be beneficial to retain locality in the A matrix.
void libjit_matmul_inner_unpacked(int m, int n, int k, const float *a, int lda,
                                  const float *b, int ldb, float *c, int ldc) {
  for (int i = 0; i < m - mr + 1; i += mr) {
    for (int j = 0; j < n - nr + 1; j += nr) {
      libjit_matmul_dot<regsA, regsB>(k, &A(i, 0), lda, &B(0, j), ldb, &C(i, j),
                                      ldc);
    }
  }
}

/// Compute a portion of C one block at a time.  Handle ragged edges with calls
/// to a slow but general helper.
template <bool pack>
void libjit_matmul_inner(int m, int n, int k, const float *a, int lda,
                         const float *b, int ldb, float *c, int ldc,
                         float *packedB) {
  // The tiling scheme naturally divides each input matrix into two parts,
  // yielding one perfectly-tiled section of C and three "ragged" edges.
  //
  // --------------------    -------
  // | A00*B00 | A00*B01|    | A00 |   -------------
  // -------------------- += ------- * | B00 | B01 |
  // | A10*B00 | A10*B01|    | A10 |   -------------
  // --------------------    -------
  //
  // We can process this as 4 separate matrix multiplications.  A00*B00 is the
  // perfectly-tiled portion, which we handle with the mr x nr dot-product
  // kernel.  The ragged edges are (ideally) less critical, so we handle them
  // with a call to a general matrix multiplication for odd sizes.
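  // A worked example of the split (illustrative numbers only): with mr = 32
  // and nr = 3, a call with m = 100 and n = 50 tiles i up to (100 / 32) * 32 =
  // 96 and j up to (50 / 3) * 3 = 48, leaving a 4-row edge, a 2-column edge,
  // and a 4 x 2 corner for libjit_matmul_odd.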
  float packedA[m * k] __attribute__((aligned(64)));
  if (pack) {
    pack_matrix_a<regsA>(m, k, &A(0, 0), lda, packedA);
  }

  if (pack) {
    libjit_matmul_inner_packed(m, n, k, packedA, packedB, c, ldc);
  } else {
    libjit_matmul_inner_unpacked(m, n, k, a, lda, b, ldb, c, ldc);
  }

  sdim_t i = (m / mr) * mr;
  sdim_t j = (n / nr) * nr;
  if (i < m) {
    libjit_matmul_odd(m - i, j, k, &A(i, 0), lda, &B(0, 0), ldb, &C(i, 0), ldc);
  }
  if (j < n) {
    libjit_matmul_odd(i, n - j, k, &A(0, 0), lda, &B(0, j), ldb, &C(0, j), ldc);
  }
  if (i < m && j < n) {
    libjit_matmul_odd(m - i, n - j, k, &A(i, 0), lda, &B(0, j), ldb, &C(i, j),
                      ldc);
  }
}

/// Tile A into mc * kc blocks, where mc and kc are chosen to approximately fit
/// the L2 cache on recent Intel processors (e.g., 256 KB for Skylake).  Stream
/// kc * n panels of B through memory to compute each mc * n block of C.
/// \p a is an \p m x \p k column-major matrix;
/// \p b is a \p k x \p n column-major matrix;
/// \p c is a \p m x \p n column-major matrix.
/// \p lda, \p ldb, and \p ldc are the leading dimensions of A, B, and C,
/// respectively.
template <bool pack>
void __attribute__((noinline))
libjit_matmul_outer(dim_t m, dim_t n, dim_t k, const float *a, dim_t lda,
                    const float *b, dim_t ldb, float *c, dim_t ldc) {
  float *packedB = nullptr;
  if (pack) {
    // Allocate room for a full kc x nc panel of floats.
    libjit_aligned_malloc((void **)&packedB, 64, kc * nc * sizeof(float));
  }

  for (dim_t p = 0; p < k; p += kc) {
    dim_t pb = MIN(k - p, kc);
    for (dim_t j = 0; j < n; j += nc) {
      dim_t jb = MIN(n - j, nc);
      if (pack) {
        pack_matrix_b<regsB>(jb, pb, &B(p, j), ldb, packedB);
      }
      for (dim_t i = 0; i < m; i += mc) {
        dim_t ib = MIN(m - i, mc);
        libjit_matmul_inner<pack>(ib, jb, pb, &A(i, p), lda, &B(p, j), ldb,
                                  &C(i, j), ldc, packedB);
      }
    }
  }

  if (pack) {
    libjit_aligned_free(packedB);
  }
}

#undef C
#undef B
#undef A

/// Generic template for rowwise quantized FullyConnected. The template allows
/// choosing element type and bias type.
template <typename ElemTy, typename BiasElemTy>
void libjit_rowwise_quantized_fc_generic(
    ElemTy *outW, const ElemTy *inW, const ElemTy *weightsW,
    const BiasElemTy *biasW, const int32_t *weightsOffsets,
    const int32_t *biasPre, const int32_t *biasPost, const int32_t *biasScale,
    const int32_t *outPre, const int32_t *outPost, const int32_t *outScale,
    const dim_t *outWdims, const dim_t *inWdims, const dim_t *weightsWdims,
    const dim_t *biasWdims, dim_t rowNum, int32_t outOffset, int32_t inOffset,
    int32_t biasOffset) {
  dim_t in_w = inWdims[1];
  dim_t out_h = outWdims[0];
  dim_t out_w = outWdims[1];

  // In rowwise quantized FC, the weights are not pre-transposed, so the
  // computation is out = I * Transpose(W) + B, i.e.
  // out(i, j) = in(i, 0) * weights(j, 0) + in(i, 1) * weights(j, 1) + ... +
  //             in(i, k) * weights(j, k) + bias(j);
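  // As a reading aid (assumptions about the surrounding helpers, not new
  // behavior): each quantized value represents roughly scale * (q - offset),
  // so the loop accumulates offset-corrected products in 32-bit, and
  // libjit_scale_i32i8 rescales the accumulator into the output's quantized
  // domain before clipping to the int8 range.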
  for (size_t i = 0; i < out_h; i++) {
    for (size_t j = 0; j < out_w; j++) {
      int32_t sum = 0;
      for (size_t k = 0; k < in_w; k++) {
        int32_t W = weightsW[libjit_getXY(weightsWdims, j, k)];
        int32_t I = inW[libjit_getXY(inWdims, i, k)];
        sum += (W - weightsOffsets[j]) * (I - inOffset);
      }
      int32_t B = libjit_scale_i32i8(biasW[j] - biasOffset, biasPre[j],
                                     biasPost[j], biasScale[j], 0);
      sum += B;
      int32_t scaledSum = libjit_scale_i32i8(sum, outPre[j], outPost[j],
                                             outScale[j], outOffset);
      outW[libjit_getXY(outWdims, i, j)] = libjit_clip(scaledSum);
    }
  }
}
} // namespace

extern "C" {

/// Performs the matrix multiplication c = a * b, where c, a, and b are
/// row-major matrices.
/// \p c is a m x n matrix, so \p cDims = {m, n}
/// \p a is a m x k matrix, so \p aDims = {m, k}
/// \p b is a k x n matrix, so \p bDims = {k, n}
void libjit_matmul_f(float *c, const float *a, const float *b,
                     const dim_t *cDims, const dim_t *aDims,
                     const dim_t *bDims) {
  memset(c, 0, cDims[0] * cDims[1] * sizeof(float));
  // Call the matrix multiplication routine with appropriate dimensions and
  // leading dimensions. The "leading dimension" for a row-major matrix is equal
  // to the number of columns in the matrix.  For a, this is k; for b and c,
  // this is n.
  //
  // This "outer" helper assumes the matrices are given in column-major format
  // (the packing algorithm is more effective with column-major matrices), while
  // the input is row-major. So we compute C += B * A, which is equivalent.
  //
  // The matrix multiplication routine is heavily inspired by:
  // https://github.com/flame/how-to-optimize-gemm
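  // To spell out the equivalence (an explanatory note): a row-major m x n
  // matrix has the same memory layout as a column-major n x m matrix, so the
  // row-major product C = A * B can be computed as the column-major product
  // C^T = B^T * A^T. That is why b is passed as the first operand and the
  // roles of m and n are swapped below.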
  int m = cDims[1];
  int n = cDims[0];
  int k = aDims[1];

  // Use the unpacked version, which does not use extra HEAP or STACK memory,
  // so the memory usage stays predictable. This is very useful when building
  // bundles (AOT) for MCU targets where the HEAP and STACK are relatively
  // limited in size. By avoiding heap/stack usage the memory consumption
  // is controlled and perfectly known (e.g. printed in the bundle API).
  libjit_matmul_outer<false>(m, n, k, b, bDims[1], a, aDims[1], c, cDims[1]);
}

void libjit_matmul_i8(int8_t *outW, const int8_t *lhsW, const int8_t *rhsW,
                      const dim_t *outWdims, const dim_t *lhsWdims,
                      const dim_t *rhsWdims, int32_t outOffset,
                      int32_t lhsOffset, int32_t rhsOffset, int32_t outPre,
                      int32_t outPost, int32_t outScale) {
  for (dim_t x = 0; x < outWdims[0]; x++) {
    for (dim_t y = 0; y < outWdims[1]; y++) {
      int32_t sum = 0;
      for (dim_t i = 0; i < lhsWdims[1]; i++) {
        int32_t lhs = lhsW[libjit_getXY(lhsWdims, x, i)] - lhsOffset;
        int32_t rhs = rhsW[libjit_getXY(rhsWdims, i, y)] - rhsOffset;
        sum += lhs * rhs;
      }
      int32_t s = libjit_scale_i32i8(sum, outPre, outPost, outScale, outOffset);
      outW[libjit_getXY(outWdims, x, y)] = libjit_clip(s);
    }
  }
}

/// Rowwise quantized FullyConnected with int8 precision and int32 bias.
void libjit_rowwise_quantized_fc_i8_i32(
    int8_t *outW, const int8_t *inW, const int8_t *weightsW,
    const int32_t *biasW, const int32_t *weightsOffsets, const int32_t *biasPre,
    const int32_t *biasPost, const int32_t *biasScale, const int32_t *outPre,
    const int32_t *outPost, const int32_t *outScale, const dim_t *outWdims,
    const dim_t *inWdims, const dim_t *weightsWdims, const dim_t *biasWdims,
    dim_t rowNum, int32_t outOffset, int32_t inOffset, int32_t biasOffset) {
  libjit_rowwise_quantized_fc_generic<int8_t, int32_t>(
      outW, inW, weightsW, biasW, weightsOffsets, biasPre, biasPost, biasScale,
      outPre, outPost, outScale, outWdims, inWdims, weightsWdims, biasWdims,
      rowNum, outOffset, inOffset, biasOffset);
}

/// Rowwise quantized FullyConnected with int8 precision and int8 bias.
void libjit_rowwise_quantized_fc_i8_i8(
    int8_t *outW, const int8_t *inW, const int8_t *weightsW,
    const int8_t *biasW, const int32_t *weightsOffsets, const int32_t *biasPre,
    const int32_t *biasPost, const int32_t *biasScale, const int32_t *outPre,
    const int32_t *outPost, const int32_t *outScale, const dim_t *outWdims,
    const dim_t *inWdims, const dim_t *weightsWdims, const dim_t *biasWdims,
    dim_t rowNum, int32_t outOffset, int32_t inOffset, int32_t biasOffset) {
  libjit_rowwise_quantized_fc_generic<int8_t, int8_t>(
      outW, inW, weightsW, biasW, weightsOffsets, biasPre, biasPost, biasScale,
      outPre, outPost, outScale, outWdims, inWdims, weightsWdims, biasWdims,
      rowNum, outOffset, inOffset, biasOffset);
}
}