#include "test.h"
#include "../intgemm/aligned.h"
#include "../intgemm/callbacks.h"
#include "../intgemm/interleave.h"
#include "../intgemm/intgemm.h"
#include "../intgemm/multiply.h"
#include "../intgemm/stats.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <numeric>
#include <random>
#include <sstream>

namespace intgemm {

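// INTGEMM_SSE2 / INTGEMM_SSSE3 below attach the matching target attribute to the
// generated test function on compilers that need it; the Intel compiler is excluded,
// presumably because it does not use these per-function target attributes.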
#ifndef __INTEL_COMPILER
INTGEMM_SSE2
#endif
TEST_CASE("Transpose 16", "[transpose]") {
  if (kCPU < CPUType::SSE2) return;
  const unsigned N = 8;
  AlignedVector<int16_t> input(N * N);
  std::iota(input.begin(), input.end(), static_cast<int16_t>(0));

  AlignedVector<int16_t> ref(N * N);
  references::Transpose(input.begin(), ref.begin(), N, N);

  // Overwrite input.
  __m128i *t = input.as<__m128i>();
  Transpose16InLane(t[0], t[1], t[2], t[3], t[4], t[5], t[6], t[7]);

  for (std::size_t i = 0; i < input.size(); ++i) {
    CHECK_MESSAGE(ref[i] == input[i], "16-bit transpose failure at " << i << ": " << ref[i] << " != " << input[i]);
  }
}

#ifndef __INTEL_COMPILER
INTGEMM_SSSE3
#endif
TEST_CASE("Transpose 8", "[transpose]") {
  if (kCPU < CPUType::SSSE3) return;
  const unsigned N = 16;
  AlignedVector<int8_t> input(N * N);
  std::iota(input.begin(), input.end(), static_cast<int8_t>(0));

  AlignedVector<int8_t> ref(input.size());
  references::Transpose(input.begin(), ref.begin(), N, N);

  // Overwrite input.
  __m128i *t = input.as<__m128i>();
  Transpose8InLane(t[0], t[1], t[2], t[3], t[4], t[5], t[6], t[7], t[8], t[9], t[10], t[11], t[12], t[13], t[14], t[15]);

  for (std::size_t i = 0; i < input.size(); ++i) {
    CHECK_MESSAGE(ref[i] == input[i], "8-bit transpose failure at " << i << ": " << (int16_t)ref[i] << " != " << (int16_t)input[i]);
  }
}

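// Quantize and rearrange a random B matrix with Routine::PrepareB, then check the
// result against Routine::Quantize followed by the reference rearrangement.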
template <class Routine> void TestPrepare(Index rows = 32, Index cols = 16) {
  std::mt19937 gen;
  // Go somewhat out of range too.
  std::uniform_real_distribution<float> dist(-129.0, 129.0);
  // Create array.
  AlignedVector<float> input(rows * cols);
  for (auto& it : input) {
    it = dist(gen);
  }

  using Integer = typename Routine::Integer;
  // Call Prepare
  AlignedVector<Integer> test(input.size());
  Routine::PrepareB(input.begin(), test.begin(), 1, rows, cols);

  // Compute reference output.
  AlignedVector<Integer> quantized(input.size());
  Routine::Quantize(input.begin(), quantized.begin(), 1, static_cast<Index>(input.size()));
  AlignedVector<Integer> reference(input.size());
  // Note this won't work for Int8/Int16 generic routines because tile sizes vary.
  references::Rearragement(quantized.begin(), reference.begin(), Routine::kBTileRow, Routine::kBTileCol, rows, cols);
  CHECK_MESSAGE(memcmp(reference.begin(), test.begin(), test.size() * sizeof(Integer)) == 0, Routine::kName << " Mismatch:\n" <<
    "Quantized Input" << '\n' << PrintMatrix(quantized.begin(), rows, cols) << "Reference" << '\n' <<
    PrintMatrix(reference.begin(), rows, cols) << "Routine" << '\n' << PrintMatrix(test.begin(), rows, cols));
}

#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
TEST_CASE("Prepare AVX512", "[prepare]") {
  if (kCPU < CPUType::AVX512BW) return;
  TestPrepare<AVX512BW::Kernels8>(64, 8);
  TestPrepare<AVX512BW::Kernels8>(256, 32);
  TestPrepare<AVX512BW::Kernels16>(64, 8);
  TestPrepare<AVX512BW::Kernels16>(256, 32);
}
#endif

#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
TEST_CASE("Prepare AVX2", "[prepare]") {
  if (kCPU < CPUType::AVX2) return;
  TestPrepare<AVX2::Kernels8>(64, 32);
  TestPrepare<AVX2::Kernels16>(64, 32);
}
#endif

TEST_CASE("Prepare SSSE3", "[prepare]") {
  if (kCPU < CPUType::SSSE3) return;
  TestPrepare<SSSE3::Kernels8>(16, 8);
  TestPrepare<SSSE3::Kernels8>(32, 16);
  TestPrepare<SSSE3::Kernels8>(32, 32);
}

TEST_CASE("Prepare SSE2", "[prepare]") {
  if (kCPU < CPUType::SSE2) return;
  TestPrepare<SSE2::Kernels16>(8, 8);
  TestPrepare<SSE2::Kernels16>(32, 32);
}

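// Prepare a B matrix, pick a random set of kSelectCols columns with
// Routine::SelectColumnsB, and compare against preparing a matrix assembled
// from the same columns selected in float space.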
template <class Routine> void TestSelectColumnsB(Index rows = 64, Index cols = 16) {
  std::mt19937 gen;
  // Go somewhat out of range too.
  std::uniform_real_distribution<float> dist(-129.0, 129.0);
  AlignedVector<float> input(rows * cols);
  for (auto& it : input) {
    it = dist(gen);
  }
  using Integer = typename Routine::Integer;
  AlignedVector<Integer> prepared(input.size());
  Routine::PrepareB(input.begin(), prepared.begin(), 1, rows, cols);

  const int kSelectCols = 24;
  Index select_cols[kSelectCols];
  std::uniform_int_distribution<Index> col_dist(0, cols - 1);
  for (auto& it : select_cols) {
    it = col_dist(gen);
  }

  AlignedVector<Integer> test(rows * kSelectCols);
  Routine::SelectColumnsB(prepared.begin(), test.begin(), rows, select_cols, select_cols + kSelectCols);

  // Select columns manually in float space.
  AlignedVector<float> selected(rows * kSelectCols);
  for (Index r = 0; r < rows; ++r) {
    for (int c = 0; c < kSelectCols; ++c) {
      assert(c + r * kSelectCols < rows * kSelectCols);
      selected[c + r * kSelectCols] = input[select_cols[c] + r * cols];
    }
  }
  AlignedVector<Integer> ref(rows * kSelectCols);
  Routine::PrepareB(selected.begin(), ref.begin(), 1, rows, kSelectCols);
  CHECK_MESSAGE(memcmp(ref.begin(), test.begin(), sizeof(Integer) * rows * kSelectCols) == 0, "Reference:\n" <<
    PrintMatrix(ref.begin(), rows, kSelectCols) << PrintMatrix(test.begin(), rows, kSelectCols));
}

#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
TEST_CASE("SelectColumnsB AVX512", "[select]") {
  if (kCPU < CPUType::AVX512BW) return;
  TestSelectColumnsB<AVX512BW::Kernels8>();
  TestSelectColumnsB<AVX512BW::Kernels16>(256, 256);
}
#endif

#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
TEST_CASE("SelectColumnsB AVX2", "[select]") {
  if (kCPU < CPUType::AVX2) return;
  TestSelectColumnsB<AVX2::Kernels8>(256, 256);
  TestSelectColumnsB<AVX2::Kernels16>(256, 256);
}
#endif

TEST_CASE("SelectColumnsB SSSE3", "[select]") {
  if (kCPU < CPUType::SSSE3) return;
  TestSelectColumnsB<SSSE3::Kernels8>();
  TestSelectColumnsB<SSSE3::Kernels8>(256, 256);
}

TEST_CASE("SelectColumnsB SSE2", "[select]") {
  if (kCPU < CPUType::SSE2) return;
  TestSelectColumnsB<SSE2::Kernels16>();
  TestSelectColumnsB<SSE2::Kernels16>(256, 256);
}

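// MaxFloat32 should find the largest element no matter which slot of the
// register holds it.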
template <class Register> void TestMax() {
  Register r = set1_ps<Register>(-2.0);
  for (std::size_t i = 0; i < sizeof(Register) / sizeof(float); ++i) {
    Register c = r;
    reinterpret_cast<float*>(&c)[i] = -1.0;
    CHECK_MESSAGE((MaxFloat32(c) == -1.0), "MaxFloat32 produced " << MaxFloat32(c));
  }
}

TEST_CASE("Max", "[max]") {
  TestMax<__m128>();
}

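// Check a claimed maximum absolute value against a scalar computation over
// [begin, end).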
void CompareMaxAbs(const float *begin, const float *end, float test, std::size_t offset) {
  float largest = std::fabs(*std::max_element(begin, end));
  float smallest = std::fabs(*std::min_element(begin, end));
  largest = std::max(largest, smallest);
  CHECK_MESSAGE(largest == test, "Error: " << largest << " versus " << test << " in length " << (end - begin) << " offset " << offset);
}

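// Run a MaxAbsolute backend over every length below kLengthMax, planting an
// out-of-range spike (-32 and then 32) at each position in turn.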
template <float (*Backend) (const float *, const float *)> void TestMaxAbsolute() {
  std::mt19937 gen;
  std::uniform_real_distribution<float> dist(-8.0, 8.0);
  const std::size_t kLengthMax = 65;
  AlignedVector<float> test(kLengthMax);
  for (std::size_t len = 1; len < kLengthMax; ++len) {
    for (std::size_t t = 0; t < len; ++t) {
      // Fill with [-8, 8).
      for (auto& it : test) {
        it = dist(gen);
      }
      CompareMaxAbs(test.begin(), test.begin() + len, Backend(test.begin(), test.begin() + len), t);
      test[t] = -32.0;
      CompareMaxAbs(test.begin(), test.begin() + len, Backend(test.begin(), test.begin() + len), t);
      test[t] = 32.0;
      CompareMaxAbs(test.begin(), test.begin() + len, Backend(test.begin(), test.begin() + len), t);
    }
  }
}

TEST_CASE("MaxAbsolute SSE2", "[max]") {
  if (kCPU < CPUType::SSE2) return;
  TestMaxAbsolute<SSE2::MaxAbsolute>();
}

#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
TEST_CASE("MaxAbsolute AVX2", "[max]") {
  if (kCPU < CPUType::AVX2) return;
  TestMaxAbsolute<AVX2::MaxAbsolute>();
}
#endif

#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
TEST_CASE("MaxAbsolute AVX512BW", "[max]") {
  if (kCPU < CPUType::AVX512BW) return;
  TestMaxAbsolute<AVX512BW::MaxAbsolute>();
}
#endif

// Based on https://arxiv.org/abs/1705.01991

// Copyright (c) 2017 Microsoft Corporation

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:

// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.

// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
// Compute A*B slowly in floats.

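// Multiply prepared A and B (via OMPParallelWrap over Routine) and compare the
// result against a slow integer reference computed from the quantized inputs
// and a pure float reference, within the given tolerances.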
template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_cols,
 float int_tolerance=.1, float float_tolerance=1, float MSE_float_tolerance=0, float MSE_int_tolerance=0) {
  using Integer = typename Routine::Integer;
  std::ostringstream info;
  info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n';

  // Initialize A and B.
  AlignedVector<float> A(A_rows * width);
  AlignedVector<float> B(width * B_cols);
  std::mt19937 gen;
  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
  for (auto& it : A) {
    it = dist(gen);
  }
  for (auto& it : B) {
    it = dist(gen);
  }

  float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64;
  float unquant_mult = 1.0f / (quant_mult*quant_mult);

  AlignedVector<Integer> A_prep(A.size());
  AlignedVector<Integer> B_prep(B.size());
  Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
  Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);

  AlignedVector<float> test_C(A_rows * B_cols);
  OMPParallelWrap<callbacks::UnquantizeAndWrite, Routine>(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWrite(unquant_mult, test_C.begin()));
  // Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::Sequence(
  //   callbacks::Unquantize(unquant_mult),
  //   callbacks::Write<float>(test_C.begin())
  // ));

  AlignedVector<Integer> B_quant(B.size());
  Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, static_cast<Index>(B.size()));
  AlignedVector<float> slowint_C(test_C.size());
  // Assuming A is just quantization here.
  references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo&) {
    return sum * unquant_mult;
  });

  AlignedVector<float> float_C(test_C.size());
  references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo&) {
    return static_cast<float>(sum);
  });

  CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
   int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
}

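// Same as TestMultiply, but with a ReLU: the tested path uses
// UnquantizeAndWriteRelu and both references clamp at zero.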
template <class Routine> void TestMultiplyRelu(Index A_rows, Index width, Index B_cols,
 float int_tolerance=.1, float float_tolerance=1, float MSE_float_tolerance=0, float MSE_int_tolerance=0) {
  using Integer = typename Routine::Integer;
  std::ostringstream info;
  info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n';

  // Initialize A and B.
  AlignedVector<float> A(A_rows * width);
  AlignedVector<float> B(width * B_cols);
  std::mt19937 gen;
  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
  for (auto& it : A) {
    it = dist(gen);
  }
  for (auto& it : B) {
    it = dist(gen);
  }

  float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64;
  float unquant_mult = 1.0f / (quant_mult*quant_mult);

  AlignedVector<Integer> A_prep(A.size());
  AlignedVector<Integer> B_prep(B.size());
  Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
  Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);

  AlignedVector<float> test_C(A_rows * B_cols);
  OMPParallelWrap<callbacks::UnquantizeAndWriteRelu, Routine>(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWriteRelu(unquant_mult, test_C.begin()));
  // Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::Sequence(
  //   callbacks::Unquantize(unquant_mult),
  //   callbacks::Write<float>(test_C.begin())
  // ));

  AlignedVector<Integer> B_quant(B.size());
  Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, static_cast<Index>(B.size()));
  AlignedVector<float> slowint_C(test_C.size());
  // Assuming A is just quantization here.
  references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo&) {
    float ret = std::max(0.0f, sum * unquant_mult);
    return ret;
  });

  AlignedVector<float> float_C(test_C.size());
  references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo&) {
    return static_cast<float>(std::max(0.0, sum));
  });

  CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
   int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
}

// Code duplication may be avoided through some use of variadic templates, as the different WriteC symbols
// require different numbers of arguments. I don't think the refactoring is worth it.
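// Same as TestMultiply, but adds a per-column bias: the tested path uses
// UnquantizeAndAddBiasAndWrite and both references add bias[info.col_idx].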
template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index B_cols,
 float int_tolerance = 0.1f, float float_tolerance = 1.0f, float MSE_float_tolerance = 0.0f, float MSE_int_tolerance = 0.0f) {
  using Integer = typename Routine::Integer;
  std::ostringstream info;
  info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n';

  // Initialize A and B.
  AlignedVector<float> A(A_rows * width);
  AlignedVector<float> B(width * B_cols);
  AlignedVector<float> bias(B_cols);
  std::mt19937 gen;
  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
  for (auto& it : A) {
    it = dist(gen);
  }
  for (auto& it : B) {
    it = dist(gen);
  }
  for (auto& it : bias) {
    it = dist(gen);
  }

  float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64;
  float unquant_mult = 1.0f / (quant_mult*quant_mult);

  AlignedVector<Integer> A_prep(A.size());
  AlignedVector<Integer> B_prep(B.size());
  Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
  Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);

  AlignedVector<float> test_C(A_rows * B_cols);

  Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin()));

  AlignedVector<Integer> B_quant(B.size());
  Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, static_cast<Index>(B.size()));
  AlignedVector<float> slowint_C(test_C.size());
  // Assuming A is just quantization here.
  references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) {
    return sum * unquant_mult + bias[info.col_idx];
  });

  AlignedVector<float> float_C(test_C.size());
  references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo& info) {
    return static_cast<float>(sum) + bias[info.col_idx];
  });

  CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
   int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
}

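// Same as TestMultiplyBias, but with a ReLU applied after the bias in both the
// tested callback and the references.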
template <class Routine> void TestMultiplyBiasRelu(Index A_rows, Index width, Index B_cols,
 float int_tolerance = 0.1f, float float_tolerance = 1.0f, float MSE_float_tolerance = 0.0f, float MSE_int_tolerance = 0.0f) {
  using Integer = typename Routine::Integer;
  std::ostringstream info;
  info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n';

  // Initialize A and B.
  AlignedVector<float> A(A_rows * width);
  AlignedVector<float> B(width * B_cols);
  AlignedVector<float> bias(B_cols);
  std::mt19937 gen;
  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
  for (auto& it : A) {
    it = dist(gen);
  }
  for (auto& it : B) {
    it = dist(gen);
  }
  for (auto& it : bias) {
    it = dist(gen);
  }

  float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64;
  float unquant_mult = 1.0f / (quant_mult*quant_mult);

  AlignedVector<Integer> A_prep(A.size());
  AlignedVector<Integer> B_prep(B.size());
  Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
  Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);

  AlignedVector<float> test_C(A_rows * B_cols);

  Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWriteRelu(unquant_mult, bias.begin(), test_C.begin()));

  AlignedVector<Integer> B_quant(B.size());
  Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, static_cast<Index>(B.size()));
  AlignedVector<float> slowint_C(test_C.size());
  // Assuming A is just quantization here.
  references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) {
    return std::max(0.0f, sum * unquant_mult + bias[info.col_idx]);
  });

  AlignedVector<float> float_C(test_C.size());
  references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo& info) {
    return std::max(0.0f, static_cast<float>(sum) + bias[info.col_idx]);
  });

  CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
   int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
}

TEST_CASE ("Multiply SSE2 16bit", "[multiply]") {
  if (kCPU < CPUType::SSE2) return;
  TestMultiply<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
  TestMultiply<SSE2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
  TestMultiply<SSE2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
  TestMultiply<SSE2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
  TestMultiply<SSE2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
  TestMultiply<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}

TEST_CASE ("Multiply SSE2 16bit with relu", "[multiply_relu]") {
  if (kCPU < CPUType::SSE2) return;
  TestMultiplyRelu<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
  TestMultiplyRelu<SSE2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
  TestMultiplyRelu<SSE2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
  TestMultiplyRelu<SSE2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
  TestMultiplyRelu<SSE2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
  TestMultiplyRelu<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}

TEST_CASE ("Multiply SSE2 16bit with bias", "[biased_multiply]") {
  if (kCPU < CPUType::SSE2) return;
  TestMultiplyBias<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBias<SSE2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
  TestMultiplyBias<SSE2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBias<SSE2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBias<SSE2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBias<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}

TEST_CASE ("Multiply SSE2 16bit with bias and relu", "[biased_multiply_relu]") {
  if (kCPU < CPUType::SSE2) return;
  TestMultiplyBiasRelu<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBiasRelu<SSE2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
  TestMultiplyBiasRelu<SSE2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBiasRelu<SSE2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBiasRelu<SSE2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBiasRelu<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}

TEST_CASE ("Multiply SSSE3 8bit", "[multiply]") {
  if (kCPU < CPUType::SSSE3) return;
  TestMultiply<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
  TestMultiply<SSSE3::Kernels8>(8, 2048, 256, 33, 33, 4.4f, 4.4f);
  TestMultiply<SSSE3::Kernels8>(320, 256, 256, 1.9f, 1.9f, 0.1f, 0.01f);
  TestMultiply<SSSE3::Kernels8>(472, 256, 256, 2.1f, 2.1f, 0.1f, 0.011f);
  TestMultiply<SSSE3::Kernels8>(248, 256, 256, 1.7f, 1.7f, 0.1f, 0.012f);
  TestMultiply<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
}

TEST_CASE ("Multiply SSSE3 8bit with relu", "[multiply_relu]") {
  if (kCPU < CPUType::SSSE3) return;
  TestMultiplyRelu<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
  TestMultiplyRelu<SSSE3::Kernels8>(8, 2048, 256, 33, 33, 4.4f, 4.4f);
  TestMultiplyRelu<SSSE3::Kernels8>(320, 256, 256, 1.9f, 1.9f, 0.1f, 0.01f);
  TestMultiplyRelu<SSSE3::Kernels8>(472, 256, 256, 2.1f, 2.1f, 0.1f, 0.011f);
  TestMultiplyRelu<SSSE3::Kernels8>(248, 256, 256, 1.7f, 1.7f, 0.1f, 0.012f);
  TestMultiplyRelu<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
}

TEST_CASE ("Multiply SSSE3 8bit with bias", "[biased_multiply]") {
  if (kCPU < CPUType::SSSE3) return;
  TestMultiplyBias<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
  TestMultiplyBias<SSSE3::Kernels8>(8, 2048, 256, 33, 33, 4.4f, 4.4f);
  TestMultiplyBias<SSSE3::Kernels8>(320, 256, 256, 1.9f, 1.9f, 0.1f, 0.01f);
  TestMultiplyBias<SSSE3::Kernels8>(472, 256, 256, 2.1f, 2.1f, 0.1f, 0.011f);
  TestMultiplyBias<SSSE3::Kernels8>(248, 256, 256, 1.7f, 1.7f, 0.1f, 0.012f);
  TestMultiplyBias<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
}

TEST_CASE ("Multiply SSSE3 8bit with bias and relu", "[biased_multiply_relu]") {
  if (kCPU < CPUType::SSSE3) return;
  TestMultiplyBiasRelu<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
  TestMultiplyBiasRelu<SSSE3::Kernels8>(8, 2048, 256, 33, 33, 4.4f, 4.4f);
  TestMultiplyBiasRelu<SSSE3::Kernels8>(320, 256, 256, 1.9f, 1.9f, 0.1f, 0.01f);
  TestMultiplyBiasRelu<SSSE3::Kernels8>(472, 256, 256, 2.1f, 2.1f, 0.1f, 0.011f);
  TestMultiplyBiasRelu<SSSE3::Kernels8>(248, 256, 256, 1.7f, 1.7f, 0.1f, 0.012f);
  TestMultiplyBiasRelu<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
}


#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
TEST_CASE ("Multiply AVX2 8bit", "[multiply]") {
  if (kCPU < CPUType::AVX2) return;
  TestMultiply<AVX2::Kernels8>(8, 256, 256, .1f, 1, 0.1f);
  TestMultiply<AVX2::Kernels8>(8, 2048, 256, 19, 19, 1.8f, 1.8f);
  TestMultiply<AVX2::Kernels8>(320, 256, 256, .1f, 1, 0.1f);
  TestMultiply<AVX2::Kernels8>(472, 256, 256, .1f, 1, 0.1f);
  TestMultiply<AVX2::Kernels8>(248, 256, 256, .1f, 1, 0.1f);
  TestMultiply<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f);
}

TEST_CASE ("Multiply AVX2 8bit with relu", "[multiply_relu]") {
  if (kCPU < CPUType::AVX2) return;
  TestMultiplyRelu<AVX2::Kernels8>(8, 256, 256, .1f, 1, 0.1f);
  TestMultiplyRelu<AVX2::Kernels8>(8, 2048, 256, 19, 19, 1.8f, 1.8f);
  TestMultiplyRelu<AVX2::Kernels8>(320, 256, 256, .1f, 1, 0.1f);
  TestMultiplyRelu<AVX2::Kernels8>(472, 256, 256, .1f, 1, 0.1f);
  TestMultiplyRelu<AVX2::Kernels8>(248, 256, 256, .1f, 1, 0.1f);
  TestMultiplyRelu<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f);
}

TEST_CASE ("Multiply AVX2 8bit with bias", "[biased_multiply]") {
  if (kCPU < CPUType::AVX2) return;
  TestMultiplyBias<AVX2::Kernels8>(8, 256, 256, .1f, 1, 0.1f);
  TestMultiplyBias<AVX2::Kernels8>(8, 2048, 256, 19, 19, 1.8f, 1.8f);
  TestMultiplyBias<AVX2::Kernels8>(320, 256, 256, .1f, 1, 0.1f);
  TestMultiplyBias<AVX2::Kernels8>(472, 256, 256, .1f, 1, 0.1f);
  TestMultiplyBias<AVX2::Kernels8>(248, 256, 256, .1f, 1, 0.1f);
  TestMultiplyBias<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f);
}

TEST_CASE ("Multiply AVX2 8bit with bias and relu", "[biased_multiply_relu]") {
  if (kCPU < CPUType::AVX2) return;
  TestMultiplyBiasRelu<AVX2::Kernels8>(8, 256, 256, .1f, 1, 0.1f);
  TestMultiplyBiasRelu<AVX2::Kernels8>(8, 2048, 256, 19, 19, 1.8f, 1.8f);
  TestMultiplyBiasRelu<AVX2::Kernels8>(320, 256, 256, .1f, 1, 0.1f);
  TestMultiplyBiasRelu<AVX2::Kernels8>(472, 256, 256, .1f, 1, 0.1f);
  TestMultiplyBiasRelu<AVX2::Kernels8>(248, 256, 256, .1f, 1, 0.1f);
  TestMultiplyBiasRelu<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f);
}

TEST_CASE ("Multiply AVX2 16bit", "[multiply]") {
  if (kCPU < CPUType::AVX2) return;
  TestMultiply<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
  TestMultiply<AVX2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
  TestMultiply<AVX2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
  TestMultiply<AVX2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
  TestMultiply<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
  TestMultiply<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}

TEST_CASE ("Multiply AVX2 16bit with relu", "[multiply_relu]") {
  if (kCPU < CPUType::AVX2) return;
  TestMultiplyRelu<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
  TestMultiplyRelu<AVX2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
  TestMultiplyRelu<AVX2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
  TestMultiplyRelu<AVX2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
  TestMultiplyRelu<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
  TestMultiplyRelu<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}

TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
  if (kCPU < CPUType::AVX2) return;
  TestMultiplyBias<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBias<AVX2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
  TestMultiplyBias<AVX2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBias<AVX2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBias<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBias<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}

TEST_CASE ("Multiply AVX2 16bit with bias and relu", "[biased_multiply_relu]") {
  if (kCPU < CPUType::AVX2) return;
  TestMultiplyBiasRelu<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBiasRelu<AVX2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
  TestMultiplyBiasRelu<AVX2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBiasRelu<AVX2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBiasRelu<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBiasRelu<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}
#endif

#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
  TEST_CASE ("Multiply AVX512 8bit", "[multiply]") {
    if (kCPU < CPUType::AVX512BW) return;
    TestMultiply<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
    TestMultiply<AVX512BW::Kernels8>(8, 2048, 256, 3.7f, 4, 0.37f, 0.33f);
    TestMultiply<AVX512BW::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
    TestMultiply<AVX512BW::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
    TestMultiply<AVX512BW::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
    TestMultiply<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
  }

  TEST_CASE ("Multiply AVX512 8bit with relu", "[multiply_relu]") {
    if (kCPU < CPUType::AVX512BW) return;
    TestMultiplyRelu<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
    TestMultiplyRelu<AVX512BW::Kernels8>(8, 2048, 256, 3.7f, 4, 0.37f, 0.33f);
    TestMultiplyRelu<AVX512BW::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
    TestMultiplyRelu<AVX512BW::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
    TestMultiplyRelu<AVX512BW::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
    TestMultiplyRelu<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
  }

  TEST_CASE ("Multiply AVX512 8bit with bias", "[biased_multiply]") {
    if (kCPU < CPUType::AVX512BW) return;
    TestMultiplyBias<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
    TestMultiplyBias<AVX512BW::Kernels8>(8, 2048, 256, 3.7f, 4, 0.37f, 0.33f);
    TestMultiplyBias<AVX512BW::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
    TestMultiplyBias<AVX512BW::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
    TestMultiplyBias<AVX512BW::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
    TestMultiplyBias<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
  }

  TEST_CASE ("Multiply AVX512 8bit with bias and relu", "[biased_multiply_relu]") {
    if (kCPU < CPUType::AVX512BW) return;
    TestMultiplyBiasRelu<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
    TestMultiplyBiasRelu<AVX512BW::Kernels8>(8, 2048, 256, 3.7f, 4, 0.37f, 0.33f);
    TestMultiplyBiasRelu<AVX512BW::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
    TestMultiplyBiasRelu<AVX512BW::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
    TestMultiplyBiasRelu<AVX512BW::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
    TestMultiplyBiasRelu<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
  }

  #ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
    TEST_CASE ("Multiply AVX512VNNI 8bit", "[multiply]") {
      if (kCPU < CPUType::AVX512VNNI) return;
      TestMultiply<AVX512VNNI::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
      TestMultiply<AVX512VNNI::Kernels8>(8, 2048, 256, 0, 0.55f, 0.25f);
      TestMultiply<AVX512VNNI::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
      TestMultiply<AVX512VNNI::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
      TestMultiply<AVX512VNNI::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
      TestMultiply<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
    }

    TEST_CASE ("Multiply AVX512VNNI 8bit with relu", "[multiply_relu]") {
      if (kCPU < CPUType::AVX512VNNI) return;
      TestMultiplyRelu<AVX512VNNI::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
      TestMultiplyRelu<AVX512VNNI::Kernels8>(8, 2048, 256, 0, 0.55f, 0.25f);
      TestMultiplyRelu<AVX512VNNI::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
      TestMultiplyRelu<AVX512VNNI::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
      TestMultiplyRelu<AVX512VNNI::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
      TestMultiplyRelu<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
    }

    TEST_CASE ("Multiply AVX512VNNI 8bit with bias", "[biased_multiply]") {
      if (kCPU < CPUType::AVX512VNNI) return;
      TestMultiplyBias<AVX512VNNI::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
      TestMultiplyBias<AVX512VNNI::Kernels8>(8, 2048, 256, 0, 0.55f, 0.25f);
      TestMultiplyBias<AVX512VNNI::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
      TestMultiplyBias<AVX512VNNI::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
      TestMultiplyBias<AVX512VNNI::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
      TestMultiplyBias<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
    }

    TEST_CASE ("Multiply AVX512VNNI 8bit with bias and relu", "[biased_multiply_relu]") {
      if (kCPU < CPUType::AVX512VNNI) return;
      TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
      TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(8, 2048, 256, 0, 0.55f, 0.25f);
      TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
      TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
      TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
      TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
    }
  #endif

  TEST_CASE ("Multiply AVX512 16bit", "[multiply]") {
    if (kCPU < CPUType::AVX512BW) return;
    TestMultiply<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
    TestMultiply<AVX512BW::Kernels16>(8, 2048, 256, .1f, 1, 0.011f);
    TestMultiply<AVX512BW::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
    TestMultiply<AVX512BW::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
    TestMultiply<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
    TestMultiply<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
  }

  TEST_CASE ("Multiply AVX512 16bit with relu", "[multiply_relu]") {
    if (kCPU < CPUType::AVX512BW) return;
    TestMultiplyRelu<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
    TestMultiplyRelu<AVX512BW::Kernels16>(8, 2048, 256, .1f, 1, 0.011f);
    TestMultiplyRelu<AVX512BW::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
    TestMultiplyRelu<AVX512BW::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
    TestMultiplyRelu<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
    TestMultiplyRelu<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
  }

  TEST_CASE ("Multiply AVX512 16bit with bias", "[biased_multiply]") {
    if (kCPU < CPUType::AVX512BW) return;
    TestMultiplyBias<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
    TestMultiplyBias<AVX512BW::Kernels16>(8, 2048, 256, .1f, 1, 0.011f);
    TestMultiplyBias<AVX512BW::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
    TestMultiplyBias<AVX512BW::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
    TestMultiplyBias<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
    TestMultiplyBias<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
  }

  TEST_CASE ("Multiply AVX512 16bit with bias and relu", "[biased_multiply_relu]") {
    if (kCPU < CPUType::AVX512BW) return;
    TestMultiplyBiasRelu<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
    TestMultiplyBiasRelu<AVX512BW::Kernels16>(8, 2048, 256, .1f, 1, 0.011f);
    TestMultiplyBiasRelu<AVX512BW::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
    TestMultiplyBiasRelu<AVX512BW::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
    TestMultiplyBiasRelu<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
    TestMultiplyBiasRelu<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
  }
#endif

} // namespace intgemm