#pragma once

#include "intgemm/intgemm_config.h"
#include "intrinsics.h"
#include "types.h"

#include <algorithm>
#include <cassert>

namespace intgemm {

/*
 * Interleave vectors.
 */
#define INTGEMM_INTERLEAVE_N(target, type, N) \
target static inline void Interleave##N(type &first, type &second) { \
  type temp = unpacklo_epi##N(first, second); \
  second = unpackhi_epi##N(first, second); \
  first = temp; \
}

#define INTGEMM_INTERLEAVE(target, type) \
INTGEMM_INTERLEAVE_N(target, type, 8) \
INTGEMM_INTERLEAVE_N(target, type, 16) \
INTGEMM_INTERLEAVE_N(target, type, 32) \
INTGEMM_INTERLEAVE_N(target, type, 64)

INTGEMM_INTERLEAVE(INTGEMM_SSE2, __m128i)

#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
INTGEMM_INTERLEAVE(INTGEMM_AVX2, __m256i)
#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
INTGEMM_INTERLEAVE(INTGEMM_AVX512BW, __m512i)
#endif
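
/* Example (illustrative, not part of the original header): with 32-bit
 * elements, Interleave32 applies unpacklo/unpackhi within each 128-bit lane,
 * so for SSE2:
 *   first  = {0, 1, 2, 3}, second = {4, 5, 6, 7}
 * becomes
 *   first  = {0, 4, 1, 5}, second = {2, 6, 3, 7}.
 */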

/*
 * Swap vectors.
 */
#define INTGEMM_SWAP(target, Register) \
target static inline void Swap(Register &a, Register &b) { \
  Register tmp = a; \
  a = b; \
  b = tmp; \
}

INTGEMM_SWAP(INTGEMM_SSE2, __m128i)
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
INTGEMM_SWAP(INTGEMM_AVX2, __m256i)
#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
/* Only INTGEMM_AVX512F is necessary, but due to a GCC 5.4 bug we have to use INTGEMM_AVX512BW. */
INTGEMM_SWAP(INTGEMM_AVX512BW, __m512i)
#endif

/* Transpose registers containing 8 packed 16-bit integers.
 * Each 128-bit lane is handled independently.
 */
#define INTGEMM_TRANSPOSE16(target, Register) \
target static inline void Transpose16InLane(Register &r0, Register &r1, Register &r2, Register &r3, Register &r4, Register &r5, Register &r6, Register &r7) { \
  /* r0: columns 0 1 2 3 4 5 6 7 from row 0
     r1: columns 0 1 2 3 4 5 6 7 from row 1*/ \
  Interleave16(r0, r1); \
  Interleave16(r2, r3); \
  Interleave16(r4, r5); \
  Interleave16(r6, r7); \
  /* r0: columns 0 0 1 1 2 2 3 3 from rows 0 and 1
     r1: columns 4 4 5 5 6 6 7 7 from rows 0 and 1
     r2: columns 0 0 1 1 2 2 3 3 from rows 2 and 3
     r3: columns 4 4 5 5 6 6 7 7 from rows 2 and 3
     r4: columns 0 0 1 1 2 2 3 3 from rows 4 and 5
     r5: columns 4 4 5 5 6 6 7 7 from rows 4 and 5
     r6: columns 0 0 1 1 2 2 3 3 from rows 6 and 7
     r7: columns 4 4 5 5 6 6 7 7 from rows 6 and 7*/ \
  Interleave32(r0, r2); \
  Interleave32(r1, r3); \
  Interleave32(r4, r6); \
  Interleave32(r5, r7); \
  /* r0: columns 0 0 0 0 1 1 1 1 from rows 0, 1, 2, and 3
     r1: columns 4 4 4 4 5 5 5 5 from rows 0, 1, 2, and 3
     r2: columns 2 2 2 2 3 3 3 3 from rows 0, 1, 2, and 3
     r3: columns 6 6 6 6 7 7 7 7 from rows 0, 1, 2, and 3
     r4: columns 0 0 0 0 1 1 1 1 from rows 4, 5, 6, and 7
     r5: columns 4 4 4 4 5 5 5 5 from rows 4, 5, 6, and 7
     r6: columns 2 2 2 2 3 3 3 3 from rows 4, 5, 6, and 7
     r7: columns 6 6 6 6 7 7 7 7 from rows 4, 5, 6, and 7*/ \
  Interleave64(r0, r4); \
  Interleave64(r1, r5); \
  Interleave64(r2, r6); \
  Interleave64(r3, r7); \
  /* r0: columns 0 0 0 0 0 0 0 0 from rows 0 through 7
     r1: columns 4 4 4 4 4 4 4 4 from rows 0 through 7
     r2: columns 2 2 2 2 2 2 2 2 from rows 0 through 7
     r3: columns 6 6 6 6 6 6 6 6 from rows 0 through 7
     r4: columns 1 1 1 1 1 1 1 1 from rows 0 through 7
     r5: columns 5 5 5 5 5 5 5 5 from rows 0 through 7
     r6: columns 3 3 3 3 3 3 3 3 from rows 0 through 7
     r7: columns 7 7 7 7 7 7 7 7 from rows 0 through 7*/ \
  /* Empirically gcc is able to remove these movs and just rename the outputs of Interleave64. */ \
  Swap(r1, r4); \
  Swap(r3, r6); \
}

INTGEMM_TRANSPOSE16(INTGEMM_SSE2, __m128i)
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
INTGEMM_TRANSPOSE16(INTGEMM_AVX2, __m256i)
#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
/* Only INTGEMM_AVX512F is necessary, but due to a GCC 5.4 bug we have to use INTGEMM_AVX512BW. */
INTGEMM_TRANSPOSE16(INTGEMM_AVX512BW, __m512i)
#endif
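
/* Usage sketch (illustrative): transpose an 8x8 block of 16-bit values held
 * one row per SSE2 register; afterwards r0..r7 hold columns 0..7.
 *   __m128i r0, r1, r2, r3, r4, r5, r6, r7;  // each loaded with one row of 8 int16_t
 *   Transpose16InLane(r0, r1, r2, r3, r4, r5, r6, r7);
 */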

/* Transpose registers containing 16 packed 8-bit integers.
 * Each 128-bit lane is handled independently.
 */
template <class Register> static inline void Transpose8InLane(
    Register &r0, Register &r1, Register &r2, Register &r3, Register &r4, Register &r5, Register &r6, Register &r7,
    Register &r8, Register &r9, Register &r10, Register &r11, Register &r12, Register &r13, Register &r14, Register &r15) {
  // Get 8-bit values to 16-bit values so they can travel together.
  Interleave8(r0, r1);
  // r0: columns 0 0 1 1 2 2 3 3 4 4 5 5 6 6 7 7 from rows 0 and 1.
  // r1: columns 8 8 9 9 10 10 11 11 12 12 13 13 14 14 15 15 from rows 0 and 1.
  Interleave8(r2, r3);
  // r2: columns 0 0 1 1 2 2 3 3 4 4 5 5 6 6 7 7 from rows 2 and 3.
  Interleave8(r4, r5);
  Interleave8(r6, r7);
  Interleave8(r8, r9);
  Interleave8(r10, r11);
  Interleave8(r12, r13);
  Interleave8(r14, r15);
  Transpose16InLane(r0, r2, r4, r6, r8, r10, r12, r14);
  Transpose16InLane(r1, r3, r5, r7, r9, r11, r13, r15);
  // Permute into correct order.  This is free because the outputs just get permuted.
  Register tmp;
  // Cycle: r1 <- r2 <- r4 <- r8 <- r1.
  tmp = r2;
  r2 = r4;
  r4 = r8;
  r8 = r1;
  r1 = tmp;
  // Cycle: r3 <- r6 <- r12 <- r9 <- r3.
  tmp = r3;
  r3 = r6;
  r6 = r12;
  r12 = r9;
  r9 = tmp;
  // Swap r5 and r10.
  tmp = r5;
  r5 = r10;
  r10 = tmp;
  // Cycle: r7 <- r14 <- r13 <- r11 <- r7.
  tmp = r7;
  r7 = r14;
  r14 = r13;
  r13 = r11;
  r11 = tmp;
}
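
/* Analogous sketch (illustrative): with sixteen SSE2 registers, one 16-byte
 * row each, Transpose8InLane transposes a 16x16 block of 8-bit values; with
 * wider registers, each 128-bit lane carries an independent 16x16 transpose.
 */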

// PREPARE B: quantize and rearrange.  B is presumed to be constant parameters
// so we can take our time rearranging it in order to save time during the multiply.
//
// We presume B starts in row-major order.
//
// In INTGEMM_AVX2, a register holds 32 8-bit values or 16 16-bit values and we want
// that many values from the same column in the register.
//
// The multiplier reads 8 rows at a time and we want these reads to be
// contiguous.
//
// Each 8x32 (for 8-bit) or 8x16 (for 16-bit) tile of B is transposed.
// The tiles are stored in column major order.
//
// For INTGEMM_AVX2, this matrix shows what index each value of B will be stored at:
//   0  16 ... 240
//   1  17 ... 241
//   2  18 ... 242
//   3  19 ... 243
//   4  20 ... 244
//   5  21 ... 245
//   6  22 ... 246
//   7  23 ... 247
//   8  24 ... 248
//   9  25 ... 249
//  10  26 ... 250
//  11  27 ... 251
//  12  28 ... 252
//  13  29 ... 253
//  14  30 ... 254
//  15  31 ... 255
// 256 272
// 257 273
// ... ...
#define INTGEMM_PREPARE_B_8(target, QuantClass) \
target static inline void PrepareB(const float *input, int8_t *output_shadow, float quant_mult, Index rows, Index cols) { \
  FRegister q = set1_ps<FRegister>(quant_mult); \
  /* Currently all multipliers have a stride of 8 columns.*/ \
  const Index kColStride = 8; \
  assert(cols % kColStride == 0); \
  assert(rows % sizeof(Register) == 0); \
  assert(reinterpret_cast<uintptr_t>(input) % sizeof(Register) == 0); \
  Register *output = reinterpret_cast<Register*>(output_shadow); \
  assert(reinterpret_cast<uintptr_t>(output) % sizeof(Register) == 0); \
  for (Index c = 0; c < cols; c += kColStride) { \
    for (Index r = 0; r < rows; r += sizeof(Register), output += 8) { \
      /* Quantize and perform a transpose with height sizeof(Register) and width 8. \
         This isn't quite Transpose8InLane because it's half the number of columns, \
         so each register starts with two rows instead of being one row. \
         The quantizers know to skip a row.*/ \
      output[0] = QuantClass::ForReshape(q, input + cols * (r    ) + c, cols); \
      output[1] = QuantClass::ForReshape(q, input + cols * (r + 1) + c, cols); \
      output[2] = QuantClass::ForReshape(q, input + cols * (r + 4) + c, cols); \
      output[3] = QuantClass::ForReshape(q, input + cols * (r + 5) + c, cols); \
      output[4] = QuantClass::ForReshape(q, input + cols * (r + 8) + c, cols); \
      output[5] = QuantClass::ForReshape(q, input + cols * (r + 9) + c, cols); \
      output[6] = QuantClass::ForReshape(q, input + cols * (r + 12) + c, cols); \
      output[7] = QuantClass::ForReshape(q, input + cols * (r + 13) + c, cols); \
      Interleave8(output[0], output[1]); \
      Interleave8(output[2], output[3]); \
      Interleave8(output[4], output[5]); \
      Interleave8(output[6], output[7]); \
      Transpose16InLane(output[0], output[1], output[2], output[3], output[4], output[5], output[6], output[7]); \
    } \
  } \
}

#define INTGEMM_PREPARE_B_16(target, QuantClass) \
target static inline void PrepareB(const float *input, int16_t *output_shadow, float quant_mult, Index rows, Index cols) { \
  FRegister q = set1_ps<FRegister>(quant_mult); \
  assert(cols % 8 == 0); \
  assert(rows % (sizeof(Register) / sizeof(int16_t)) == 0); \
  assert(reinterpret_cast<uintptr_t>(input) % sizeof(Register) == 0); \
  Register *output = reinterpret_cast<Register*>(output_shadow); \
  assert(reinterpret_cast<uintptr_t>(output) % sizeof(Register) == 0); \
  for (Index c = 0; c < cols; c += 8) { \
    for (Index r = 0; r < rows; r += (sizeof(Register) / sizeof(int16_t)), output += 8) { \
      /* gcc unrolls this loop and uses registers for output[k]*/ \
      for (Index k = 0; k < 8; ++k) { \
        output[k] = QuantClass::ForReshape(q, input + cols * (r + k) + c, cols); \
      } \
      Transpose16InLane(output[0], output[1], output[2], output[3], output[4], output[5], output[6], output[7]); \
    } \
  } \
}
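
/* Usage sketch (illustrative; PrepareB is instantiated per target elsewhere
 * in intgemm, and AlignedVector comes from aligned.h, which this header does
 * not include):
 *   AlignedVector<float> b(rows * cols);           // row-major B
 *   AlignedVector<int8_t> b_prepared(rows * cols);
 *   // ... fill b, choose quant_mult (e.g. 127.0f / max absolute value) ...
 *   PrepareB(b.begin(), b_prepared.begin(), quant_mult, rows, cols);
 */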

/*
 * Prepare B matrix.
 * The B matrix has to be transposed and quantized.
 * Cols has to be a multiple of sizeof(Register) / sizeof(Integer).
 *
 * cols and rows describe the size of the transposed B.
 */
#define INTGEMM_PREPARE_B_QUANTIZED_TRANSPOSED(target, Integer) \
target static inline void PrepareBQuantizedTransposed(const Integer* input, Integer* output, Index cols, Index rows) { \
  const Index RegisterElems = sizeof(Register) / sizeof(Integer); \
  const Index kColStride = 8; \
  \
  assert(cols % RegisterElems == 0); \
  assert(rows % kColStride == 0); \
  assert(reinterpret_cast<uintptr_t>(input) % sizeof(Register) == 0); \
  assert(reinterpret_cast<uintptr_t>(output) % sizeof(Register) == 0); \
  \
  Register* output_it = reinterpret_cast<Register*>(output); \
  for (Index r = 0; r < rows; r += kColStride) \
    for (Index c = 0; c < cols; c += RegisterElems) \
      for (Index ri = 0; ri < 8; ++ri) \
        *output_it++ = *reinterpret_cast<const Register*>(input + (r + ri) * cols + c); \
}
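
/* For example (illustrative, assuming SSE2 with int16_t so RegisterElems ==
 * 8): the first eight output registers hold input rows 0..7 at columns 0..7,
 * the next eight hold rows 0..7 at columns 8..15, and rows 8..15 follow only
 * after every column chunk of rows 0..7 has been copied.
 */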

/*
 * Prepare B matrix.
 * The B matrix has to be transposed.
 * Cols has to be a multiple of sizeof(Register) / sizeof(float).
 *
 * cols and rows describe the size of the transposed B.
 */
#define INTGEMM_PREPARE_B_TRANSPOSED(target, Quantizer, Integer) \
target static inline void PrepareBTransposed(const float* input, Integer* output, float quant_mult, Index cols, Index rows) { \
  const Index RegisterElemsInt = sizeof(Register) / sizeof(Integer); \
  const Index kColStride = 8; \
  \
  assert(cols % (sizeof(Register) / sizeof(float)) == 0); \
  assert(rows % kColStride == 0); \
  assert(reinterpret_cast<uintptr_t>(input) % sizeof(Register) == 0); \
  assert(reinterpret_cast<uintptr_t>(output) % sizeof(Register) == 0); \
  \
  FRegister q = set1_ps<FRegister>(quant_mult); \
  Register* output_it = reinterpret_cast<Register*>(output); \
  Index r = 0; \
  Index c = 0; \
  while (r < rows) { \
    for (Index ri = 0; ri < 8; ++ri) \
      *output_it++ = Quantizer::ConsecutiveWithWrapping(q, input + (r + ri) * cols + c, cols - c, cols, 8); \
    c += RegisterElemsInt; \
    while (c >= cols) { \
      r += kColStride; \
      c -= cols; \
    } \
  } \
}
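
/* Walk-through (illustrative, assuming AVX2 with int8_t so RegisterElemsInt ==
 * 32): with cols == 64, rows 0..7 are emitted at columns 0..31, then at
 * columns 32..63, before moving on to rows 8..15.  With cols == 8, one
 * register spans four 8-row blocks, so ConsecutiveWithWrapping continues 8
 * rows down at column 0 whenever it reaches the end of a row.
 */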

/* Select columns of B from PrepareB format to PrepareB format: both the input
 * and the output use the prepared layout.
 */
#define INTGEMM_SELECT_COL_B(target, Register) \
target static inline void SelectColumnsOfB(const Register *input, Register *output, Index rows_bytes /* number of bytes in a row */, const Index *cols_begin, const Index *cols_end) { \
  assert(rows_bytes % sizeof(Register) == 0); \
  assert((cols_end - cols_begin) % 8 == 0); \
  /* Do columns for multiples of 8.*/ \
  Index register_rows = rows_bytes / sizeof(Register); \
  const Register *starts[8]; \
  for (; cols_begin != cols_end; cols_begin += 8) { \
    for (Index k = 0; k < 8; ++k) { \
      starts[k] = input + (cols_begin[k] & 7) + (cols_begin[k] & ~7) * register_rows; \
    } \
    for (Index r = 0; r < register_rows; ++r) { \
      for (Index k = 0; k < 8; ++k) { \
        *(output++) = *starts[k]; \
        starts[k] += 8; \
      } \
    } \
  } \
}
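
/* Example (illustrative): to select column 11, (11 & ~7) == 8 locates the
 * 8-column group and (11 & 7) == 3 the register offset within it, so the
 * column starts at input + 3 + 8 * register_rows and advances by 8 registers
 * per register row.
 */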

} // namespace intgemm