#pragma once

#include "intgemm/intgemm_config.h"

#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW

#include "interleave.h"
#include "kernels.h"
#include "multiply.h"
#include "types.h"

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>

/* AVX512 implementation.
 * This uses INTGEMM_AVX512BW, INTGEMM_AVX512DQ, and might use AVX512VL.
 * That means it supports mainstream CPUs with AVX512, starting with Skylake
 * Xeons.
 * It does not support any Knights / Xeon Phi processors.
 *
 * All memory must be 64-byte aligned.
 */

namespace intgemm {

// AVX512 has combined saturating-convert and store instructions:
// _mm512_mask_cvtsepi32_storeu_epi16
// _mm512_mask_cvtsepi32_storeu_epi8
// So conversion to memory uses these, but I also implement a wider version for
// rearranging B.

namespace AVX512BW {

// Load from memory, multiply, and convert to int32_t.
/* Only INTGEMM_AVX512F is necessary but due to a GCC 5.4 bug we have to set INTGEMM_AVX512BW */
INTGEMM_AVX512BW inline __m512i QuantizerGrab(const float *input, const __m512 quant_mult_reg) {
  return kernels::quantize(loadu_ps<__m512>(input), quant_mult_reg);
}

/* Only INTGEMM_AVX512F is necessary but due to a GCC 5.4 bug we have to set INTGEMM_AVX512BW */
INTGEMM_SELECT_COL_B(INTGEMM_AVX512BW, __m512i)

// For PrepareB we want to read 8 columns at a time.  When converting 32-bit
// floats to 8-bit values, that's 32 bytes of floats.  But AVX512 is 64 bytes
// wide so it reads off the edge of the tile.  We could expand the tile size,
// but then the memory written to wouldn't be contiguous anyway, so we'd be
// doing a scatter regardless.  Easier to just read the 8 columns we want as
// two 256-bit halves and concatenate them.
INTGEMM_AVX512DQ inline __m512 Concat(const __m256 first, const __m256 second) {
  // Requires INTGEMM_AVX512DQ, but that goes with INTGEMM_AVX512BW anyway.
  return _mm512_insertf32x8(_mm512_castps256_ps512(first), second, 1);
}

// Like QuantizerGrab, but allows 32-byte halves (i.e. 8 columns) to be controlled independently.
/* Only INTGEMM_AVX512F is necessary but due to a GCC 5.4 bug we have to set INTGEMM_AVX512BW */
INTGEMM_AVX512BW inline __m512i QuantizerGrabHalves(const float *input0, const float *input1, const __m512 quant_mult_reg) {
  __m512 appended = Concat(loadu_ps<__m256>(input0), loadu_ps<__m256>(input1));
  appended = _mm512_mul_ps(appended, quant_mult_reg);
  return _mm512_cvtps_epi32(appended);
}

// These are only used for reshaping due to the AVX512 instructions
// _mm512_mask_cvtsepi32_storeu_epi16 and _mm512_mask_cvtsepi32_storeu_epi8
// being used for the quantizer.
class QuantizeTile16 {
  public:
    INTGEMM_AVX512BW static inline Register ConsecutiveWithWrapping(FRegister quant_mult, const float *input, Index cols_left, Index cols, Index row_step) {
      auto input0 = input;
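      // The second 16 floats normally follow right after the first 16.  If
      // they would run past the columns remaining in this row strip
      // (cols_left <= 16), wrap them to the matching position in the next
      // block of rows instead.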
      auto input1 = input + 16 + (cols_left <= 16 ? cols * (row_step - 1) : 0);
      auto g0 = QuantizerGrabHalves(input0, input1, quant_mult);
      auto g1 = QuantizerGrabHalves(input0 + 8, input1 + 8, quant_mult);
      auto packed = packs_epi32(g0, g1);
      return _mm512_permutex_epi64(packed, 0xd8 /* 0, 2, 1, 3 */);
    }

    INTGEMM_AVX512BW static inline Register ForReshape(FRegister quant_mult, const float *input, Index cols) {
      __m512i g0 = QuantizerGrabHalves(input, input + 16 * cols, quant_mult);
      __m512i g1 = QuantizerGrabHalves(input + 8 * cols, input + 24 * cols, quant_mult);
      __m512i packed = packs_epi32(g0, g1);
      // Permute within 256-bit lanes, so same as INTGEMM_AVX2
      return _mm512_permutex_epi64(packed, 0xd8 /* 0, 2, 1, 3 */);
    }
};

class QuantizeTile8 {
  public:
    INTGEMM_AVX512BW static inline Register ConsecutiveWithWrapping(FRegister quant_mult, const float *input, Index cols_left, Index cols, Index row_step) {
      static const __m512i neg127 = _mm512_set1_epi8(-127);
      static const __m512i shuffle_param = _mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0);

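      // Gather four register-widths (4 x 16 floats) of input, wrapping to the
      // next block of rows whenever fewer than 16 columns remain in the
      // current one.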
      const float* inputs[4];
      for (Index i = 0; i < sizeof(inputs) / sizeof(inputs[0]); ++i) {
        while (cols_left < sizeof(Register) / sizeof(float)) {
          input += cols * (row_step - 1);
          cols_left += cols;
        }
        inputs[i] = input;
        input += sizeof(Register) / sizeof(float);
        cols_left -= sizeof(Register) / sizeof(float);
      }

      auto g0 = QuantizerGrab(inputs[0], quant_mult);
      auto g1 = QuantizerGrab(inputs[1], quant_mult);
      auto g2 = QuantizerGrab(inputs[2], quant_mult);
      auto g3 = QuantizerGrab(inputs[3], quant_mult);

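      // Saturate-pack 32-bit -> 16-bit -> 8-bit, ban -128, then reorder the
      // 32-bit chunks to undo the within-lane interleaving of the packs,
      // mirroring ForReshape below.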
      auto packed0 = packs_epi32(g0, g1);
      auto packed1 = packs_epi32(g2, g3);
      auto packed = _mm512_packs_epi16(packed0, packed1);
      packed = _mm512_max_epi8(packed, neg127);
      return _mm512_permutexvar_epi32(shuffle_param, packed);
    }

    INTGEMM_AVX512BW static inline __m512i ForReshape(FRegister quant_mult, const float *input, Index cols) {
      // TODO: try alternative: _mm512_cvtsepi32_epi8 ?
      const __m512i neg127 = _mm512_set1_epi8(-127);
      // In reverse order: grabbing the first 32-bit values from each 128-bit register, then the second 32-bit values, etc.
      const __m512i shuffle_param = _mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0);

      // 32-bit format.
      __m512i g0 = QuantizerGrabHalves(input, input + 2 * cols, quant_mult);
      __m512i g1 = QuantizerGrabHalves(input + 16 * cols, input + 18 * cols, quant_mult);
      __m512i g2 = QuantizerGrabHalves(input + 32 * cols, input + 34 * cols, quant_mult);
      __m512i g3 = QuantizerGrabHalves(input + 48 * cols, input + 50 * cols, quant_mult);
      // Pack 32-bit to 16-bit.
      __m512i packed0 = packs_epi32(g0, g1);
      __m512i packed1 = packs_epi32(g2, g3);
      // Pack 16-bit to 8-bit.
      __m512i packed = _mm512_packs_epi16(packed0, packed1);
      // Ban -128.
      packed = _mm512_max_epi8(packed, neg127);
      // 0 1 2 3 16 17 18 19 32 33 34 35 48 49 50 51 4 5 6 7 20 21 22 23 36 37 38 39 52 53 54 55 8 9 10 11 24 25 26 27 40 41 42 43 56 57 58 59 12 13 14 15 28 29 30 31 44 45 46 47 60 61 62 63
      return _mm512_permutexvar_epi32(shuffle_param, packed);
    }
};

struct Kernels16 {
  typedef int16_t Integer;

  // Currently A is prepared by quantization but this could theoretically change.
  // rows * cols must be a multiple of 16.
  /* Only INTGEMM_AVX512F is necessary but due to a GCC 5.4 bug we have to set INTGEMM_AVX512BW */
  INTGEMM_AVX512BW static inline void PrepareA(const float *input, int16_t *output, float quant_mult, Index rows, Index cols) {
    Quantize(input, output, quant_mult, rows * cols);
  }

  // Technically output can be unaligned in Quantize.
  // But then it will need to be aligned for Multiply.
  // size must be a multiple of 16.
  // Convert to 16-bit signed integers.
  /* Only INTGEMM_AVX512F is necessary but due to a GCC 5.4 bug we have to set INTGEMM_AVX512BW */
  INTGEMM_AVX512BW static void Quantize(const float *input, int16_t *output, float quant_mult, Index size) {
    assert(size % 16 == 0);
    assert(reinterpret_cast<uintptr_t>(input) % 64 == 0);
    // Fill with the quantization multiplier.
    const __m512 quant_mult_reg = _mm512_set1_ps(quant_mult);
    const float *end = input + size;
    for (; input != end; input += 16, output += 16) {
      // There doesn't seem to be an unmasked version.
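      // The cvtsepi32 store saturates each 32-bit value into the int16_t range
      // as it writes.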
      _mm512_mask_cvtsepi32_storeu_epi16(output, 0xffff, QuantizerGrab(input, quant_mult_reg));
    }
  }


  // Tile size for B; B must be a multiple of this block size.
  static const Index kBTileRow = 32;
  static const Index kBTileCol = 8;

  /* Only INTGEMM_AVX512F is necessary but due to a GCC 5.4 bug we have to set INTGEMM_AVX512BW */
  INTGEMM_PREPARE_B_16(INTGEMM_AVX512BW, QuantizeTile16)
  INTGEMM_PREPARE_B_QUANTIZED_TRANSPOSED(INTGEMM_AVX512BW, int16_t)
  INTGEMM_PREPARE_B_TRANSPOSED(INTGEMM_AVX512BW, QuantizeTile16, int16_t)

  /* Only INTGEMM_AVX512F is necessary but due to a GCC 5.4 bug we have to set INTGEMM_AVX512BW */
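  // Note rows * 2 below (int16_t is 2 bytes per value) versus plain rows in
  // the 8-bit SelectColumnsB; SelectColumnsOfB appears to take the column
  // size in bytes.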
  INTGEMM_AVX512BW static void SelectColumnsB(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end) {
    SelectColumnsOfB((const __m512i*)input, (__m512i*)output, rows * 2, cols_begin, cols_end);
  }

  /* Only INTGEMM_AVX512F is necessary but due to a GCC 5.4 bug we have to set INTGEMM_AVX512BW */
  INTGEMM_MULTIPLY16(__m512i, INTGEMM_AVX512BW, CPUType::AVX2)

  constexpr static const char *const kName = "16-bit AVX512";

  static const CPUType kUses = CPUType::AVX512BW;
};

struct Kernels8 {
  typedef int8_t Integer;

  // Currently A is prepared by quantization but this could theoretically change.
  /* Only INTGEMM_AVX512F is necessary but due to a GCC 5.4 bug we have to set INTGEMM_AVX512BW */
  INTGEMM_AVX512BW static inline void PrepareA(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) {
    Quantize(input, output, quant_mult, rows * cols);
  }

 private:
  /* g++ (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0 does not carry target attributes
   * to the hidden function it creates in implementing #pragma omp parallel for.
   * So intrinsics were not working inside the for loop when compiled with
   * OMP. Also, passing register types across #pragma omp parallel for
   * generated an internal compiler error.
   * The problem does not occur in g++-8 (Ubuntu 8.3.0-6ubuntu1~18.04.1) 8.3.0.
   * As a workaround, I split into #pragma omp parallel with boring types
   * passed across the boundary then call this function with target attributes.
   */
  INTGEMM_AVX512BW static void QuantizeThread(const float *input, int8_t *output, float quant_mult, std::size_t count) {
    const __m512i neg127 = _mm512_set1_epi32(-127);
    const __m512 quant_mult_reg = _mm512_set1_ps(quant_mult);
    const std::size_t kBatch = sizeof(__m512i) / sizeof(float);
#pragma omp for
    for (std::size_t i = 0; i < count; i += kBatch) {
      __m512i asint = QuantizerGrab(input + i, quant_mult_reg);
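      // Ban -128 (as elsewhere in this file): clamping at -127 means the
      // saturating store below never emits -128.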
      asint = _mm512_max_epi32(asint, neg127);
      // There doesn't seem to be an unmasked version.
      _mm512_mask_cvtsepi32_storeu_epi8(output + i, 0xffff, asint);
    }
  }

 public:
  // Technically output can be unaligned in Quantize.
  // But then it will need to be aligned for Multiply.
  // Convert to 8-bit signed integers.
  /* Only INTGEMM_AVX512F is necessary but due to a GCC 5.4 bug we have to set INTGEMM_AVX512BW */
  INTGEMM_AVX512BW static void Quantize(const float *input, int8_t *output, float quant_mult, Index size) {
    assert(reinterpret_cast<uintptr_t>(input) % sizeof(__m512i) == 0);
    const std::size_t kBatch = sizeof(__m512i) / sizeof(float);
    std::size_t fast_size = (size & ~(kBatch - 1));
    const float *fast_input_end = input + fast_size;
    int8_t *fast_output_end = output + fast_size;
#pragma omp parallel
    {
      QuantizeThread(input, output, quant_mult, fast_size);
    }
    std::size_t overhang = size & (kBatch - 1);
    if (!overhang) return; // We needed a branch anyway for the empty case.
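    // Quantize the remaining size % 16 values with a single masked store; the
    // mask enables only the overhang lanes.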
    const __m512i neg127 = _mm512_set1_epi32(-127);
    const __m512 quant_mult_reg = _mm512_set1_ps(quant_mult);
    __m512i asint = QuantizerGrab(fast_input_end, quant_mult_reg);
    asint = _mm512_max_epi32(asint, neg127);
    _mm512_mask_cvtsepi32_storeu_epi8(fast_output_end, (1 << overhang) - 1, asint);
  }

  // Preparing A for the signed/unsigned multiplication by adding 127.
  /* Only INTGEMM_AVX512F is necessary but due to a GCC 5.4 bug we have to set INTGEMM_AVX512BW */
  INTGEMM_AVX512BW static inline void PrepareA(const float *input, uint8_t *output, float quant_mult, Index rows, Index cols) {
    QuantizeU(input, output, quant_mult, rows * cols);
  }

  // Technically output can be unaligned in Quantize.
  // But then it will need to be aligned for Multiply.
  // Convert to 8-bit unsigned integers.
  /* Only INTGEMM_AVX512F is necessary but due to a GCC 5.4 bug we have to set INTGEMM_AVX512BW */
  INTGEMM_AVX512BW static void QuantizeU(const float *input, uint8_t *output, float quant_mult, Index size) {
    assert(size % 16 == 0);
    assert(reinterpret_cast<uintptr_t>(input) % 64 == 0);
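    // Clip each value to at most 127, shift up by 127, then clamp below at 0,
    // so the unsigned output lies in [0, 254].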
    const __m512i pos127 = _mm512_set1_epi32(127);
    const __m512i zero = _mm512_setzero_si512();
    const __m512 quant_mult_reg = _mm512_set1_ps(quant_mult);
    const float *end = input + size;
    for (; input < end; input += 16, output += 16) {
      __m512i asint = QuantizerGrab(input, quant_mult_reg);
      asint = _mm512_min_epi32(asint, pos127);
      asint = _mm512_add_epi32(asint, pos127);
      asint = _mm512_max_epi32(asint, zero);
      _mm512_mask_cvtusepi32_storeu_epi8(output, 0xffff, asint);
    }
  }

  // Tile size for B; B must be a multiple of this block size.
  static const Index kBTileRow = 64;
  static const Index kBTileCol = 8;

  /* Only INTGEMM_AVX512F is necessary but due to a GCC 5.4 bug we have to set INTGEMM_AVX512BW */
  INTGEMM_PREPARE_B_8(INTGEMM_AVX512BW, QuantizeTile8)
  INTGEMM_PREPARE_B_QUANTIZED_TRANSPOSED(INTGEMM_AVX512BW, int8_t)
  INTGEMM_PREPARE_B_TRANSPOSED(INTGEMM_AVX512BW, QuantizeTile8, int8_t)

  /* Only INTGEMM_AVX512F is necessary but due to a GCC 5.4 bug we have to set INTGEMM_AVX512BW */
  INTGEMM_AVX512BW static void SelectColumnsB(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) {
    SelectColumnsOfB((const __m512i*)input, (__m512i*)output, rows, cols_begin, cols_end);
  }

  // Special AVX512 implementation due to having 32 registers (so I don't have to
  // allocate registers manually) and no sign instruction.
  template <typename Callback>
  INTGEMM_AVX512BW static void Multiply(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) {
    // This is copy-paste from Multiply8_SSE2OrAVX2.
    assert(width % sizeof(Register) == 0);
    assert(B_cols % 8 == 0);
    assert(reinterpret_cast<uintptr_t>(A) % sizeof(Register) == 0);
    assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0);
    // There are 8 results for INTGEMM_AVX2 to handle.
    auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback);
    const Index simd_width = width / sizeof(Register);
    // Added for AVX512.
    Register zeros = setzero_si<Register>();
    // Go over 8 columns of B at a time.
#pragma omp for
    for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) {
      const Register *B0_col = reinterpret_cast<const Register*>(B) + B0_colidx * simd_width;
      // Process one row of A at a time.  Doesn't seem to be faster to do multiple rows of A at once.
      for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
        // Iterate over shared (inner) dimension.
        const Register *A_live = reinterpret_cast<const Register *>(A + A_rowidx * width);
        const Register *A_end = A_live + simd_width;
        const Register *B_live = B0_col;

        // Do the first iteration to initialize the sums.
        __m512i a = *A_live;
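        // _mm512_maddubs_epi16 multiplies an unsigned operand by a signed one,
        // so use |a| as the unsigned side and move a's sign onto B by negating
        // B's bytes wherever a is negative.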
        __mmask64 neg_mask = _mm512_test_epi8_mask(a, _mm512_set1_epi8(-128));
        __m512i a_positive = _mm512_abs_epi8(a);
        // These will be packed 16-bit integers containing sums for each column of B multiplied by the row of A.
        Register sum0 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[0], neg_mask, zeros, B_live[0]));
        Register sum1 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[1], neg_mask, zeros, B_live[1]));
        Register sum2 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[2], neg_mask, zeros, B_live[2]));
        Register sum3 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[3], neg_mask, zeros, B_live[3]));
        Register sum4 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[4], neg_mask, zeros, B_live[4]));
        Register sum5 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[5], neg_mask, zeros, B_live[5]));
        Register sum6 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[6], neg_mask, zeros, B_live[6]));
        Register sum7 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[7], neg_mask, zeros, B_live[7]));

        ++A_live;
        B_live += 8;

        // Use A as the loop variable so the add can be done where gcc likes it
        // for branch prediction.
        for (; A_live != A_end; ++A_live, B_live += 8) {
          // Unique code here: can we do an inline function?
          // Retrieve a.  We will use this as the unsigned part.
          a = *A_live;
          // Retrieve the conveniently consecutive values of B.
          __m512i b0 = *B_live;
          __m512i b1 = *(B_live + 1);
          __m512i b2 = *(B_live + 2);
          __m512i b3 = *(B_live + 3);
          __m512i b4 = *(B_live + 4);
          __m512i b5 = *(B_live + 5);
          __m512i b6 = *(B_live + 6);
          __m512i b7 = *(B_live + 7);

          // Get a mask where a is negative.
          // Didn't seem to make a difference defining sign bits here vs at top.
          neg_mask = _mm512_test_epi8_mask(a, _mm512_set1_epi8(-128));
          a_positive = _mm512_abs_epi8(a);

          // Negate by subtracting from zero with a mask.
          b0 = _mm512_mask_sub_epi8(b0, neg_mask, zeros, b0);
          b1 = _mm512_mask_sub_epi8(b1, neg_mask, zeros, b1);
          b2 = _mm512_mask_sub_epi8(b2, neg_mask, zeros, b2);
          b3 = _mm512_mask_sub_epi8(b3, neg_mask, zeros, b3);
          b4 = _mm512_mask_sub_epi8(b4, neg_mask, zeros, b4);
          b5 = _mm512_mask_sub_epi8(b5, neg_mask, zeros, b5);
          b6 = _mm512_mask_sub_epi8(b6, neg_mask, zeros, b6);
          b7 = _mm512_mask_sub_epi8(b7, neg_mask, zeros, b7);
          // The magic 8-bit multiply then horizontal sum into 16-bit.
          b0 = _mm512_maddubs_epi16(a_positive, b0);
          b1 = _mm512_maddubs_epi16(a_positive, b1);
          b2 = _mm512_maddubs_epi16(a_positive, b2);
          b3 = _mm512_maddubs_epi16(a_positive, b3);
          b4 = _mm512_maddubs_epi16(a_positive, b4);
          b5 = _mm512_maddubs_epi16(a_positive, b5);
          b6 = _mm512_maddubs_epi16(a_positive, b6);
          b7 = _mm512_maddubs_epi16(a_positive, b7);
          // Now we have 16-bit results that are the sum of two multiplies.
          // Choosing to approximate and do adds.
          // Perhaps every so often we could accumulate by upcasting.
          sum0 = _mm512_adds_epi16(sum0, b0);
          sum1 = _mm512_adds_epi16(sum1, b1);
          sum2 = _mm512_adds_epi16(sum2, b2);
          sum3 = _mm512_adds_epi16(sum3, b3);
          sum4 = _mm512_adds_epi16(sum4, b4);
          sum5 = _mm512_adds_epi16(sum5, b5);
          sum6 = _mm512_adds_epi16(sum6, b6);
          sum7 = _mm512_adds_epi16(sum7, b7);
          // Unique code ends: can we do an inline function?
        }
        // Upcast to 32-bit and horizontally add.
        Register ones = set1_epi16<Register>(1);
        sum0 = madd_epi16(sum0, ones);
        sum1 = madd_epi16(sum1, ones);
        sum2 = madd_epi16(sum2, ones);
        sum3 = madd_epi16(sum3, ones);
        sum4 = madd_epi16(sum4, ones);
        sum5 = madd_epi16(sum5, ones);
        sum6 = madd_epi16(sum6, ones);
        sum7 = madd_epi16(sum7, ones);
        Register pack0123 = Pack0123(sum0, sum1, sum2, sum3);
        Register pack4567 = Pack0123(sum4, sum5, sum6, sum7);

        auto total = PermuteSummer(pack0123, pack4567);
        callback_impl.Run(total, callbacks::OutputBufferInfo(A_rowidx, B0_colidx, A_rows, B_cols));
      }
    }
  }

  INTGEMM_MULTIPLY8SHIFT(__m512i, INTGEMM_AVX512BW, CPUType::AVX2)

  INTGEMM_PREPAREBIASFOR8(__m512i, INTGEMM_AVX512BW, CPUType::AVX2)

  constexpr static const char *const kName = "8-bit AVX512BW";

  static const CPUType kUses = CPUType::AVX512BW;
};

} // namespace AVX512BW
} // namespace intgemm

#endif