#include "test.h"
#include "../intgemm/aligned.h"
#include "../intgemm/callbacks.h"
#include "../intgemm/interleave.h"
#include "../intgemm/intgemm.h"
#include "../intgemm/multiply.h"
#include "../intgemm/stats.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <numeric>
#include <random>

namespace intgemm {

#ifndef __INTEL_COMPILER
INTGEMM_SSE2
#endif
TEST_CASE("Transpose 16", "[transpose]") {
  if (kCPU < CPUType::SSE2) return;
  const unsigned N = 8;
  AlignedVector<int16_t> input(N * N);
  std::iota(input.begin(), input.end(), static_cast<int16_t>(0));

  AlignedVector<int16_t> ref(N * N);
  references::Transpose(input.begin(), ref.begin(), N, N);

  // Overwrite input.
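  // Transpose16InLane transposes an 8x8 block of int16_t held in eight SSE
  // registers (one row of 8 values per __m128i), writing the result back into
  // the same registers; the reference transpose above is the expected output.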
  __m128i *t = input.as<__m128i>();
  Transpose16InLane(t[0], t[1], t[2], t[3], t[4], t[5], t[6], t[7]);

  for (std::size_t i = 0; i < input.size(); ++i) {
    CHECK_MESSAGE(ref[i] == input[i], "16-bit transpose failure at: " << i << ": " << ref[i] << " != " << input[i]);
  }
}

#ifndef __INTEL_COMPILER
INTGEMM_SSSE3
#endif
TEST_CASE("Transpose 8", "[transpose]") {
  if (kCPU < CPUType::SSSE3) return;
  const unsigned N = 16;
  AlignedVector<int8_t> input(N * N);
  std::iota(input.begin(), input.end(), static_cast<int8_t>(0));

  AlignedVector<int8_t> ref(input.size());
  references::Transpose(input.begin(), ref.begin(), N, N);

  // Overwrite input.
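  // Transpose8InLane does the same for a 16x16 block of int8_t spread across
  // sixteen SSE registers (one row of 16 values per __m128i).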
  __m128i *t = input.as<__m128i>();
  Transpose8InLane(t[0], t[1], t[2], t[3], t[4], t[5], t[6], t[7], t[8], t[9], t[10], t[11], t[12], t[13], t[14], t[15]);

  for (std::size_t i = 0; i < input.size(); ++i) {
    CHECK_MESSAGE(ref[i] == input[i], "8-bit transpose failure at " << i << ": " << (int16_t)ref[i] << " != " << (int16_t)input[i]);
  }
}

template <class Routine> void TestPrepare(Index rows = 32, Index cols = 16) {
  std::mt19937 gen;
  // Go somewhat out of range too.
  std::uniform_real_distribution<float> dist(-129.0, 129.0);
  // Create array.
  AlignedVector<float> input(rows * cols);
  for (auto& it : input) {
    it = dist(gen);
  }

  using Integer = typename Routine::Integer;
  // Call Prepare
  AlignedVector<Integer> test(input.size());
  Routine::PrepareB(input.begin(), test.begin(), 1, rows, cols);
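  // PrepareB is expected to quantize B (here with multiplier 1) and rearrange
  // it into the routine's kBTileRow x kBTileCol register tiles; the reference
  // below reproduces those two steps separately and compares byte for byte.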
79
80 // Compute reference output.
81 AlignedVector<Integer> quantized(input.size());
82 Routine::Quantize(input.begin(), quantized.begin(), 1, static_cast<Index>(input.size()));
83 AlignedVector<Integer> reference(input.size());
84 // Note this won't work for Int8/Int16 generic routines because tile sizes vary.
85 references::Rearragement(quantized.begin(), reference.begin(), Routine::kBTileRow, Routine::kBTileCol, rows, cols);
86 CHECK_MESSAGE(memcmp(reference.begin(), test.begin(), test.size() * sizeof(Integer)) == 0, Routine::kName << " Mismatch:\n" <<
87 "Quantized Input" << '\n' << PrintMatrix(quantized.begin(), rows, cols) << "Reference" << '\n' <<
88 PrintMatrix(reference.begin(), rows, cols) << "Routine" << '\n' << PrintMatrix(test.begin(), rows, cols));
89 }
90
91 #ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
92 TEST_CASE("Prepare AVX512", "[prepare]") {
93 if (kCPU < CPUType::AVX512BW) return;
94 TestPrepare<AVX512BW::Kernels8>(64, 8);
95 TestPrepare<AVX512BW::Kernels8>(256, 32);
96 TestPrepare<AVX512BW::Kernels16>(64, 8);
97 TestPrepare<AVX512BW::Kernels16>(256, 32);
98 }
99 #endif
100
101 #ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
102 TEST_CASE("Prepare AVX2", "[prepare]") {
103 if (kCPU < CPUType::AVX2) return;
104 TestPrepare<AVX2::Kernels8>(64, 32);
105 TestPrepare<AVX2::Kernels16>(64, 32);
106 }
107 #endif
108
109 TEST_CASE("Prepare SSSE3", "[prepare]") {
110 if (kCPU < CPUType::SSSE3) return;
111 TestPrepare<SSSE3::Kernels8>(16, 8);
112 TestPrepare<SSSE3::Kernels8>(32, 16);
113 TestPrepare<SSSE3::Kernels8>(32, 32);
114 }
115
116 TEST_CASE("Prepare SSE2", "[prepare]") {
117 if (kCPU < CPUType::SSE2) return;
118 TestPrepare<SSE2::Kernels16>(8, 8);
119 TestPrepare<SSE2::Kernels16>(32, 32);
120 }
121
TestSelectColumnsB(Index rows=64,Index cols=16)122 template <class Routine> void TestSelectColumnsB(Index rows = 64, Index cols = 16) {
123 std::mt19937 gen;
124 // Go somewhat out of range too.
125 std::uniform_real_distribution<float> dist(-129.0, 129.0);
126 AlignedVector<float> input(rows * cols);
127 for (auto& it : input) {
128 it = dist(gen);
129 }
130 using Integer = typename Routine::Integer;
131 AlignedVector<Integer> prepared(input.size());
132 Routine::PrepareB(input.begin(), prepared.begin(), 1, rows, cols);
133
134 const int kSelectCols = 24;
135 Index select_cols[kSelectCols];
136 std::uniform_int_distribution<Index> col_dist(0, cols - 1);
137 for (auto& it : select_cols) {
138 it = col_dist(gen);
139 }
140
141 AlignedVector<Integer> test(rows * kSelectCols);
142 Routine::SelectColumnsB(prepared.begin(), test.begin(), rows, select_cols, select_cols + kSelectCols);
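  // SelectColumnsB should gather the (possibly repeated) columns listed in
  // select_cols from the prepared B into a new prepared matrix. The reference
  // below selects the same columns in float space, prepares that selection,
  // and expects an identical byte pattern.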

  // Select columns manually in float space.
  AlignedVector<float> selected(rows * kSelectCols);
  for (Index r = 0; r < rows; ++r) {
    for (int c = 0; c < kSelectCols; ++c) {
      assert(c + r * kSelectCols < rows * kSelectCols);
      selected[c + r * kSelectCols] = input[select_cols[c] + r * cols];
    }
  }
  AlignedVector<Integer> ref(rows * kSelectCols);
  Routine::PrepareB(selected.begin(), ref.begin(), 1, rows, kSelectCols);
  CHECK_MESSAGE(memcmp(ref.begin(), test.begin(), sizeof(Integer) * rows * kSelectCols) == 0, "Reference:\n" <<
      PrintMatrix(ref.begin(), rows, kSelectCols) << PrintMatrix(test.begin(), rows, kSelectCols));
}

#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
TEST_CASE("SelectColumnsB AVX512", "[select]") {
  if (kCPU < CPUType::AVX512BW) return;
  TestSelectColumnsB<AVX512BW::Kernels8>();
  TestSelectColumnsB<AVX512BW::Kernels16>(256, 256);
}
#endif

#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
TEST_CASE("SelectColumnsB AVX2", "[select]") {
  if (kCPU < CPUType::AVX2) return;
  TestSelectColumnsB<AVX2::Kernels8>(256, 256);
  TestSelectColumnsB<AVX2::Kernels16>(256, 256);
}
#endif

TEST_CASE("SelectColumnsB SSSE3", "[select]") {
  if (kCPU < CPUType::SSSE3) return;
  TestSelectColumnsB<SSSE3::Kernels8>();
  TestSelectColumnsB<SSSE3::Kernels8>(256, 256);
}

TEST_CASE("SelectColumnsB SSE2", "[select]") {
  if (kCPU < CPUType::SSE2) return;
  TestSelectColumnsB<SSE2::Kernels16>();
  TestSelectColumnsB<SSE2::Kernels16>(256, 256);
}

template <class Register> void TestMax() {
  Register r = set1_ps<Register>(-2.0);
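  // Plant a single larger value (-1.0) in each lane in turn; the horizontal
  // maximum must find it regardless of position.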
  for (std::size_t i = 0; i < sizeof(Register) / sizeof(float); ++i) {
    Register c = r;
    reinterpret_cast<float*>(&c)[i] = -1.0;
    CHECK_MESSAGE((MaxFloat32(c) == -1.0), "MaxFloat32 produced " << MaxFloat32(c));
  }
}

TEST_CASE("Max", "[max]") {
  TestMax<__m128>();
}

void CompareMaxAbs(const float *begin, const float *end, float test, std::size_t offset) {
  float largest = std::fabs(*std::max_element(begin, end));
  float smallest = std::fabs(*std::min_element(begin, end));
  largest = std::max(largest, smallest);
  CHECK_MESSAGE(largest == test, "Error: " << largest << " versus " << test << " in length " << (end - begin) << " offset " << offset);
}

template <float (*Backend) (const float *, const float *)> void TestMaxAbsolute() {
  std::mt19937 gen;
  std::uniform_real_distribution<float> dist(-8.0, 8.0);
  const std::size_t kLengthMax = 65;
  AlignedVector<float> test(kLengthMax);
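  // Sweep every length in [1, 64] and, for each offset t within that length,
  // plant known extremes (-32 and 32) outside the random [-8, 8) fill so the
  // expected maximum absolute value is unambiguous.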
  for (std::size_t len = 1; len < kLengthMax; ++len) {
    for (std::size_t t = 0; t < len; ++t) {
      // Fill with [-8, 8).
      for (auto& it : test) {
        it = dist(gen);
      }
      CompareMaxAbs(test.begin(), test.begin() + len, Backend(test.begin(), test.begin() + len), t);
      test[t] = -32.0;
      CompareMaxAbs(test.begin(), test.begin() + len, Backend(test.begin(), test.begin() + len), t);
      test[t] = 32.0;
      CompareMaxAbs(test.begin(), test.begin() + len, Backend(test.begin(), test.begin() + len), t);
    }
  }
}

TEST_CASE("MaxAbsolute SSE2", "[max]") {
  if (kCPU < CPUType::SSE2) return;
  TestMaxAbsolute<SSE2::MaxAbsolute>();
}

#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
TEST_CASE("MaxAbsolute AVX2", "[max]") {
  if (kCPU < CPUType::AVX2) return;
  TestMaxAbsolute<AVX2::MaxAbsolute>();
}
#endif

#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
TEST_CASE("MaxAbsolute AVX512BW", "[max]") {
  if (kCPU < CPUType::AVX512BW) return;
  TestMaxAbsolute<AVX512BW::MaxAbsolute>();
}
#endif

// Based on https://arxiv.org/abs/1705.01991

// Copyright (c) 2017 Microsoft Corporation

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:

// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.

// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
// Compute A*B slowly in floats.

template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_cols,
    float int_tolerance=.1, float float_tolerance=1, float MSE_float_tolerance=0, float MSE_int_tolerance=0) {
  using Integer = typename Routine::Integer;
  std::ostringstream info;
  info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n';

  // Initialize A and B.
  AlignedVector<float> A(A_rows * width);
  AlignedVector<float> B(width * B_cols);
  std::mt19937 gen;
  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
  for (auto& it : A) {
    it = dist(gen);
  }
  for (auto& it : B) {
    it = dist(gen);
  }

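  // Both A and B are scaled by quant_mult (1024 for 16-bit, 64 for 8-bit), so
  // each product in the integer accumulator carries a factor of quant_mult^2;
  // unquant_mult undoes that after the multiply.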
  float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64;
  float unquant_mult = 1.0f / (quant_mult*quant_mult);

  AlignedVector<Integer> A_prep(A.size());
  AlignedVector<Integer> B_prep(B.size());
  Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
  Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);

  AlignedVector<float> test_C(A_rows * B_cols);
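  // OMPParallelWrap runs Routine::Multiply with the given callback; when
  // intgemm is built with OpenMP it is expected to split the work across
  // threads, and otherwise it should reduce to a plain Multiply call.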
  OMPParallelWrap<callbacks::UnquantizeAndWrite, Routine>(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWrite(unquant_mult, test_C.begin()));
  // Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::Sequence(
  //   callbacks::Unquantize(unquant_mult),
  //   callbacks::Write<float>(test_C.begin())
  // ));

  AlignedVector<Integer> B_quant(B.size());
  Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, static_cast<Index>(B.size()));
  AlignedVector<float> slowint_C(test_C.size());
  // Assuming A is just quantization here.
  references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo&) {
    return sum * unquant_mult;
  });

  AlignedVector<float> float_C(test_C.size());
  references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo&) {
    return static_cast<float>(sum);
  });

  CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
      int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
}

template <class Routine> void TestMultiplyRelu(Index A_rows, Index width, Index B_cols,
    float int_tolerance=.1, float float_tolerance=1, float MSE_float_tolerance=0, float MSE_int_tolerance=0) {
  using Integer = typename Routine::Integer;
  std::ostringstream info;
  info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n';

  // Initialize A and B.
  AlignedVector<float> A(A_rows * width);
  AlignedVector<float> B(width * B_cols);
  std::mt19937 gen;
  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
  for (auto& it : A) {
    it = dist(gen);
  }
  for (auto& it : B) {
    it = dist(gen);
  }

  float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64;
  float unquant_mult = 1.0f / (quant_mult*quant_mult);

  AlignedVector<Integer> A_prep(A.size());
  AlignedVector<Integer> B_prep(B.size());
  Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
  Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);

  AlignedVector<float> test_C(A_rows * B_cols);
  OMPParallelWrap<callbacks::UnquantizeAndWriteRelu, Routine>(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWriteRelu(unquant_mult, test_C.begin()));
  // Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::Sequence(
  //   callbacks::Unquantize(unquant_mult),
  //   callbacks::Write<float>(test_C.begin())
  // ));

  AlignedVector<Integer> B_quant(B.size());
  Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, static_cast<Index>(B.size()));
  AlignedVector<float> slowint_C(test_C.size());
  // Assuming A is just quantization here.
  references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo&) {
    float ret = std::max(0.0f, sum * unquant_mult);
    return ret;
  });

  AlignedVector<float> float_C(test_C.size());
  references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo&) {
    return static_cast<float>(std::max(0.0,sum));
  });

  CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
      int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
}

// Code duplication could be avoided through some use of variadic templates, as the different
// WriteC symbols require different numbers of arguments. I don't think the refactoring is worth it.
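// A rough sketch of what such a refactor might look like (hypothetical, not used in this file):
// a single driver that takes the callback type and its constructor arguments as a parameter pack,
//   template <class Routine, class Callback, class... Args>
//   void TestMultiplyImpl(Index A_rows, Index width, Index B_cols, Args&&... args) {
//     // ... prepare A and B as above, then:
//     // Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols,
//     //                   Callback(std::forward<Args>(args)...));
//   }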
template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index B_cols,
    float int_tolerance = 0.1f, float float_tolerance = 1.0f, float MSE_float_tolerance = 0.0f, float MSE_int_tolerance = 0.0f) {
  using Integer = typename Routine::Integer;
  std::ostringstream info;
  info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n';

  // Initialize A and B.
  AlignedVector<float> A(A_rows * width);
  AlignedVector<float> B(width * B_cols);
  AlignedVector<float> bias(B_cols);
  std::mt19937 gen;
  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
  for (auto& it : A) {
    it = dist(gen);
  }
  for (auto& it : B) {
    it = dist(gen);
  }
  for (auto& it : bias) {
    it = dist(gen);
  }

  float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64;
  float unquant_mult = 1.0f / (quant_mult*quant_mult);

  AlignedVector<Integer> A_prep(A.size());
  AlignedVector<Integer> B_prep(B.size());
  Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
  Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);

  AlignedVector<float> test_C(A_rows * B_cols);

  Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin()));

  AlignedVector<Integer> B_quant(B.size());
  Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, static_cast<Index>(B.size()));
  AlignedVector<float> slowint_C(test_C.size());
  // Assuming A is just quantization here.
  references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) {
    return sum * unquant_mult + bias[info.col_idx];
  });

  AlignedVector<float> float_C(test_C.size());
  references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo& info) {
    return static_cast<float>(sum) + bias[info.col_idx];
  });

  CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
      int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
}

template <class Routine> void TestMultiplyBiasRelu(Index A_rows, Index width, Index B_cols,
    float int_tolerance = 0.1f, float float_tolerance = 1.0f, float MSE_float_tolerance = 0.0f, float MSE_int_tolerance = 0.0f) {
  using Integer = typename Routine::Integer;
  std::ostringstream info;
  info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n';

  // Initialize A and B.
  AlignedVector<float> A(A_rows * width);
  AlignedVector<float> B(width * B_cols);
  AlignedVector<float> bias(B_cols);
  std::mt19937 gen;
  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
  for (auto& it : A) {
    it = dist(gen);
  }
  for (auto& it : B) {
    it = dist(gen);
  }
  for (auto& it : bias) {
    it = dist(gen);
  }

  float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64;
  float unquant_mult = 1.0f / (quant_mult*quant_mult);

  AlignedVector<Integer> A_prep(A.size());
  AlignedVector<Integer> B_prep(B.size());
  Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
  Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);

  AlignedVector<float> test_C(A_rows * B_cols);

  Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWriteRelu(unquant_mult, bias.begin(), test_C.begin()));

  AlignedVector<Integer> B_quant(B.size());
  Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, static_cast<Index>(B.size()));
  AlignedVector<float> slowint_C(test_C.size());
  // Assuming A is just quantization here.
  references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) {
    return std::max(0.0f, sum * unquant_mult + bias[info.col_idx]);
  });

  AlignedVector<float> float_C(test_C.size());
  references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo& info) {
    return std::max(0.0f, static_cast<float>(sum) + bias[info.col_idx]);
  });

  CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
      int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
}

TEST_CASE ("Multiply SSE2 16bit", "[multiply]") {
  if (kCPU < CPUType::SSE2) return;
  TestMultiply<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
  TestMultiply<SSE2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
  TestMultiply<SSE2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
  TestMultiply<SSE2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
  TestMultiply<SSE2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
  TestMultiply<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}

TEST_CASE ("Multiply SSE2 16bit with relu", "[multiply_relu]") {
  if (kCPU < CPUType::SSE2) return;
  TestMultiplyRelu<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
  TestMultiplyRelu<SSE2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
  TestMultiplyRelu<SSE2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
  TestMultiplyRelu<SSE2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
  TestMultiplyRelu<SSE2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
  TestMultiplyRelu<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}

TEST_CASE ("Multiply SSE2 16bit with bias", "[biased_multiply]") {
  if (kCPU < CPUType::SSE2) return;
  TestMultiplyBias<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBias<SSE2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
  TestMultiplyBias<SSE2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBias<SSE2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBias<SSE2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBias<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}

TEST_CASE ("Multiply SSE2 16bit with bias and relu", "[biased_multiply_relu]") {
  if (kCPU < CPUType::SSE2) return;
  TestMultiplyBiasRelu<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBiasRelu<SSE2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
  TestMultiplyBiasRelu<SSE2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBiasRelu<SSE2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBiasRelu<SSE2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBiasRelu<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}

TEST_CASE ("Multiply SSSE3 8bit", "[multiply]") {
  if (kCPU < CPUType::SSSE3) return;
  TestMultiply<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
  TestMultiply<SSSE3::Kernels8>(8, 2048, 256, 33, 33, 4.4f, 4.4f);
  TestMultiply<SSSE3::Kernels8>(320, 256, 256, 1.9f, 1.9f, 0.1f, 0.01f);
  TestMultiply<SSSE3::Kernels8>(472, 256, 256, 2.1f, 2.1f, 0.1f, 0.011f);
  TestMultiply<SSSE3::Kernels8>(248, 256, 256, 1.7f, 1.7f, 0.1f, 0.012f);
  TestMultiply<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
}

TEST_CASE ("Multiply SSSE3 8bit with relu", "[multiply_relu]") {
  if (kCPU < CPUType::SSSE3) return;
  TestMultiplyRelu<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
  TestMultiplyRelu<SSSE3::Kernels8>(8, 2048, 256, 33, 33, 4.4f, 4.4f);
  TestMultiplyRelu<SSSE3::Kernels8>(320, 256, 256, 1.9f, 1.9f, 0.1f, 0.01f);
  TestMultiplyRelu<SSSE3::Kernels8>(472, 256, 256, 2.1f, 2.1f, 0.1f, 0.011f);
  TestMultiplyRelu<SSSE3::Kernels8>(248, 256, 256, 1.7f, 1.7f, 0.1f, 0.012f);
  TestMultiplyRelu<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
}

TEST_CASE ("Multiply SSSE3 8bit with bias", "[biased_multiply]") {
  if (kCPU < CPUType::SSSE3) return;
  TestMultiplyBias<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
  TestMultiplyBias<SSSE3::Kernels8>(8, 2048, 256, 33, 33, 4.4f, 4.4f);
  TestMultiplyBias<SSSE3::Kernels8>(320, 256, 256, 1.9f, 1.9f, 0.1f, 0.01f);
  TestMultiplyBias<SSSE3::Kernels8>(472, 256, 256, 2.1f, 2.1f, 0.1f, 0.011f);
  TestMultiplyBias<SSSE3::Kernels8>(248, 256, 256, 1.7f, 1.7f, 0.1f, 0.012f);
  TestMultiplyBias<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
}

TEST_CASE ("Multiply SSSE3 8bit with bias and relu", "[biased_multiply_relu]") {
  if (kCPU < CPUType::SSSE3) return;
  TestMultiplyBiasRelu<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
  TestMultiplyBiasRelu<SSSE3::Kernels8>(8, 2048, 256, 33, 33, 4.4f, 4.4f);
  TestMultiplyBiasRelu<SSSE3::Kernels8>(320, 256, 256, 1.9f, 1.9f, 0.1f, 0.01f);
  TestMultiplyBiasRelu<SSSE3::Kernels8>(472, 256, 256, 2.1f, 2.1f, 0.1f, 0.011f);
  TestMultiplyBiasRelu<SSSE3::Kernels8>(248, 256, 256, 1.7f, 1.7f, 0.1f, 0.012f);
  TestMultiplyBiasRelu<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
}


#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
TEST_CASE ("Multiply AVX2 8bit", "[multiply]") {
  if (kCPU < CPUType::AVX2) return;
  TestMultiply<AVX2::Kernels8>(8, 256, 256, .1f, 1, 0.1f);
  TestMultiply<AVX2::Kernels8>(8, 2048, 256, 19, 19, 1.8f, 1.8f);
  TestMultiply<AVX2::Kernels8>(320, 256, 256, .1f, 1, 0.1f);
  TestMultiply<AVX2::Kernels8>(472, 256, 256, .1f, 1, 0.1f);
  TestMultiply<AVX2::Kernels8>(248, 256, 256, .1f, 1, 0.1f);
  TestMultiply<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f);
}

TEST_CASE ("Multiply AVX2 8bit with relu", "[multiply_relu]") {
  if (kCPU < CPUType::AVX2) return;
  TestMultiplyRelu<AVX2::Kernels8>(8, 256, 256, .1f, 1, 0.1f);
  TestMultiplyRelu<AVX2::Kernels8>(8, 2048, 256, 19, 19, 1.8f, 1.8f);
  TestMultiplyRelu<AVX2::Kernels8>(320, 256, 256, .1f, 1, 0.1f);
  TestMultiplyRelu<AVX2::Kernels8>(472, 256, 256, .1f, 1, 0.1f);
  TestMultiplyRelu<AVX2::Kernels8>(248, 256, 256, .1f, 1, 0.1f);
  TestMultiplyRelu<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f);
}

TEST_CASE ("Multiply AVX2 8bit with bias", "[biased_multiply]") {
  if (kCPU < CPUType::AVX2) return;
  TestMultiplyBias<AVX2::Kernels8>(8, 256, 256, .1f, 1, 0.1f);
  TestMultiplyBias<AVX2::Kernels8>(8, 2048, 256, 19, 19, 1.8f, 1.8f);
  TestMultiplyBias<AVX2::Kernels8>(320, 256, 256, .1f, 1, 0.1f);
  TestMultiplyBias<AVX2::Kernels8>(472, 256, 256, .1f, 1, 0.1f);
  TestMultiplyBias<AVX2::Kernels8>(248, 256, 256, .1f, 1, 0.1f);
  TestMultiplyBias<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f);
}

TEST_CASE ("Multiply AVX2 8bit with bias and relu", "[biased_multiply_relu]") {
  if (kCPU < CPUType::AVX2) return;
  TestMultiplyBiasRelu<AVX2::Kernels8>(8, 256, 256, .1f, 1, 0.1f);
  TestMultiplyBiasRelu<AVX2::Kernels8>(8, 2048, 256, 19, 19, 1.8f, 1.8f);
  TestMultiplyBiasRelu<AVX2::Kernels8>(320, 256, 256, .1f, 1, 0.1f);
  TestMultiplyBiasRelu<AVX2::Kernels8>(472, 256, 256, .1f, 1, 0.1f);
  TestMultiplyBiasRelu<AVX2::Kernels8>(248, 256, 256, .1f, 1, 0.1f);
  TestMultiplyBiasRelu<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f);
}

TEST_CASE ("Multiply AVX2 16bit", "[multiply]") {
  if (kCPU < CPUType::AVX2) return;
  TestMultiply<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
  TestMultiply<AVX2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
  TestMultiply<AVX2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
  TestMultiply<AVX2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
  TestMultiply<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
  TestMultiply<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}

TEST_CASE ("Multiply AVX2 16bit with relu", "[multiply_relu]") {
  if (kCPU < CPUType::AVX2) return;
  TestMultiplyRelu<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
  TestMultiplyRelu<AVX2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
  TestMultiplyRelu<AVX2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
  TestMultiplyRelu<AVX2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
  TestMultiplyRelu<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
  TestMultiplyRelu<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}

TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
  if (kCPU < CPUType::AVX2) return;
  TestMultiplyBias<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBias<AVX2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
  TestMultiplyBias<AVX2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBias<AVX2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBias<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBias<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}

TEST_CASE ("Multiply AVX2 16bit with bias and relu", "[biased_multiply_relu]") {
  if (kCPU < CPUType::AVX2) return;
  TestMultiplyBiasRelu<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBiasRelu<AVX2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
  TestMultiplyBiasRelu<AVX2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBiasRelu<AVX2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBiasRelu<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBiasRelu<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}
#endif

#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
TEST_CASE ("Multiply AVX512 8bit", "[multiply]") {
  if (kCPU < CPUType::AVX512BW) return;
  TestMultiply<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
  TestMultiply<AVX512BW::Kernels8>(8, 2048, 256, 3.7f, 4, 0.37f, 0.33f);
  TestMultiply<AVX512BW::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
  TestMultiply<AVX512BW::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
  TestMultiply<AVX512BW::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
  TestMultiply<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
}

TEST_CASE ("Multiply AVX512 8bit with relu", "[multiply_relu]") {
  if (kCPU < CPUType::AVX512BW) return;
  TestMultiplyRelu<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
  TestMultiplyRelu<AVX512BW::Kernels8>(8, 2048, 256, 3.7f, 4, 0.37f, 0.33f);
  TestMultiplyRelu<AVX512BW::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
  TestMultiplyRelu<AVX512BW::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
  TestMultiplyRelu<AVX512BW::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
  TestMultiplyRelu<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
}

TEST_CASE ("Multiply AVX512 8bit with bias", "[biased_multiply]") {
  if (kCPU < CPUType::AVX512BW) return;
  TestMultiplyBias<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
  TestMultiplyBias<AVX512BW::Kernels8>(8, 2048, 256, 3.7f, 4, 0.37f, 0.33f);
  TestMultiplyBias<AVX512BW::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
  TestMultiplyBias<AVX512BW::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
  TestMultiplyBias<AVX512BW::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
  TestMultiplyBias<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
}

TEST_CASE ("Multiply AVX512 8bit with bias and relu", "[biased_multiply_relu]") {
  if (kCPU < CPUType::AVX512BW) return;
  TestMultiplyBiasRelu<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
  TestMultiplyBiasRelu<AVX512BW::Kernels8>(8, 2048, 256, 3.7f, 4, 0.37f, 0.33f);
  TestMultiplyBiasRelu<AVX512BW::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
  TestMultiplyBiasRelu<AVX512BW::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
  TestMultiplyBiasRelu<AVX512BW::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
  TestMultiplyBiasRelu<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
}

#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
TEST_CASE ("Multiply AVX512VNNI 8bit", "[multiply]") {
  if (kCPU < CPUType::AVX512VNNI) return;
  TestMultiply<AVX512VNNI::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
  TestMultiply<AVX512VNNI::Kernels8>(8, 2048, 256, 0, 0.55f, 0.25f);
  TestMultiply<AVX512VNNI::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
  TestMultiply<AVX512VNNI::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
  TestMultiply<AVX512VNNI::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
  TestMultiply<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
}

TEST_CASE ("Multiply AVX512VNNI 8bit with relu", "[multiply_relu]") {
  if (kCPU < CPUType::AVX512VNNI) return;
  TestMultiplyRelu<AVX512VNNI::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
  TestMultiplyRelu<AVX512VNNI::Kernels8>(8, 2048, 256, 0, 0.55f, 0.25f);
  TestMultiplyRelu<AVX512VNNI::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
  TestMultiplyRelu<AVX512VNNI::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
  TestMultiplyRelu<AVX512VNNI::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
  TestMultiplyRelu<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
}

TEST_CASE ("Multiply AVX512VNNI 8bit with bias", "[biased_multiply]") {
  if (kCPU < CPUType::AVX512VNNI) return;
  TestMultiplyBias<AVX512VNNI::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
  TestMultiplyBias<AVX512VNNI::Kernels8>(8, 2048, 256, 0, 0.55f, 0.25f);
  TestMultiplyBias<AVX512VNNI::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
  TestMultiplyBias<AVX512VNNI::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
  TestMultiplyBias<AVX512VNNI::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
  TestMultiplyBias<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
}

TEST_CASE ("Multiply AVX512VNNI 8bit with bias and relu", "[biased_multiply_relu]") {
  if (kCPU < CPUType::AVX512VNNI) return;
  TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
  TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(8, 2048, 256, 0, 0.55f, 0.25f);
  TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
  TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
  TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
  TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
}
#endif

TEST_CASE ("Multiply AVX512 16bit", "[multiply]") {
  if (kCPU < CPUType::AVX512BW) return;
  TestMultiply<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
  TestMultiply<AVX512BW::Kernels16>(8, 2048, 256, .1f, 1, 0.011f);
  TestMultiply<AVX512BW::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
  TestMultiply<AVX512BW::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
  TestMultiply<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
  TestMultiply<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}

TEST_CASE ("Multiply AVX512 16bit with relu", "[multiply_relu]") {
  if (kCPU < CPUType::AVX512BW) return;
  TestMultiplyRelu<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
  TestMultiplyRelu<AVX512BW::Kernels16>(8, 2048, 256, .1f, 1, 0.011f);
  TestMultiplyRelu<AVX512BW::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
  TestMultiplyRelu<AVX512BW::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
  TestMultiplyRelu<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
  TestMultiplyRelu<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}


TEST_CASE ("Multiply AVX512 16bit with bias", "[biased_multiply]") {
  if (kCPU < CPUType::AVX512BW) return;
  TestMultiplyBias<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBias<AVX512BW::Kernels16>(8, 2048, 256, .1f, 1, 0.011f);
  TestMultiplyBias<AVX512BW::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBias<AVX512BW::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBias<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBias<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}

TEST_CASE ("Multiply AVX512 16bit with bias and relu", "[biased_multiply_relu]") {
  if (kCPU < CPUType::AVX512BW) return;
  TestMultiplyBiasRelu<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBiasRelu<AVX512BW::Kernels16>(8, 2048, 256, .1f, 1, 0.011f);
  TestMultiplyBiasRelu<AVX512BW::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBiasRelu<AVX512BW::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBiasRelu<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
  TestMultiplyBiasRelu<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}
#endif

} // namespace intgemm