1 #pragma once
2
3 #include "intgemm/intgemm_config.h"
4 #include "intrinsics.h"
5 #include "types.h"
6
7 #include <algorithm>
8 #include <cassert>
9
10 namespace intgemm {
11
/*
 * Interleave the elements of two SIMD registers (unpack low/high pairs).
 * After the call, `first` holds the interleaved bottom halves and `second`
 * holds the interleaved top halves of the original pair.
 */
#define INTGEMM_INTERLEAVE_N(target, type, N) \
target static inline void Interleave##N(type &first, type &second) { \
  type top = unpackhi_epi##N(first, second); \
  first = unpacklo_epi##N(first, second); \
  second = top; \
}

/* Instantiate Interleave8/16/32/64 for a given register type. */
#define INTGEMM_INTERLEAVE(target, type) \
INTGEMM_INTERLEAVE_N(target, type, 8) \
INTGEMM_INTERLEAVE_N(target, type, 16) \
INTGEMM_INTERLEAVE_N(target, type, 32) \
INTGEMM_INTERLEAVE_N(target, type, 64)
27
/* Note: the SSE2 line previously appeared twice fused onto one line
 * (a duplicated, mangled invocation), which would expand the interleave
 * functions twice and fail to compile. Exactly one invocation per target. */
INTGEMM_INTERLEAVE(INTGEMM_SSE2, __m128i)
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
INTGEMM_INTERLEAVE(INTGEMM_AVX2, __m256i)
#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
INTGEMM_INTERLEAVE(INTGEMM_AVX512BW, __m512i)
#endif
36
37 /*
38 * Swap vectors.
39 */
40 #define INTGEMM_SWAP(target, Register) \
41 target static inline void Swap(Register &a, Register &b) { \
42 Register tmp = a; \
43 a = b; \
44 b = tmp; \
45 } \
46
47 INTGEMM_SWAP(INTGEMM_SSE2, __m128i)
48 #ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
49 INTGEMM_SWAP(INTGEMM_AVX2, __m256i)
50 #endif
51 #ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
52 /* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */
53 INTGEMM_SWAP(INTGEMM_AVX512BW, __m512i)
54 #endif
55
/* Transpose registers containing 8 packed 16-bit integers.
 * Each 128-bit lane is handled independently.
 * On entry r0..r7 hold rows 0..7 of an 8x8 tile of 16-bit values (per lane);
 * on exit r0..r7 hold columns 0..7 of that tile.
 */
#define INTGEMM_TRANSPOSE16(target, Register) \
target static inline void Transpose16InLane(Register &r0, Register &r1, Register &r2, Register &r3, Register &r4, Register &r5, Register &r6, Register &r7) { \
  /* r0: columns 0 1 2 3 4 5 6 7 from row 0
     r1: columns 0 1 2 3 4 5 6 7 from row 1*/ \
  Interleave16(r0, r1); \
  Interleave16(r2, r3); \
  Interleave16(r4, r5); \
  Interleave16(r6, r7); \
  /* r0: columns 0 0 1 1 2 2 3 3 from rows 0 and 1
     r1: columns 4 4 5 5 6 6 7 7 from rows 0 and 1
     r2: columns 0 0 1 1 2 2 3 3 from rows 2 and 3
     r3: columns 4 4 5 5 6 6 7 7 from rows 2 and 3
     r4: columns 0 0 1 1 2 2 3 3 from rows 4 and 5
     r5: columns 4 4 5 5 6 6 7 7 from rows 4 and 5
     r6: columns 0 0 1 1 2 2 3 3 from rows 6 and 7
     r7: columns 4 4 5 5 6 6 7 7 from rows 6 and 7*/ \
  Interleave32(r0, r2); \
  Interleave32(r1, r3); \
  Interleave32(r4, r6); \
  Interleave32(r5, r7); \
  /* r0: columns 0 0 0 0 1 1 1 1 from rows 0, 1, 2, and 3
     r1: columns 4 4 4 4 5 5 5 5 from rows 0, 1, 2, and 3
     r2: columns 2 2 2 2 3 3 3 3 from rows 0, 1, 2, and 3
     r3: columns 6 6 6 6 7 7 7 7 from rows 0, 1, 2, and 3
     r4: columns 0 0 0 0 1 1 1 1 from rows 4, 5, 6, and 7
     r5: columns 4 4 4 4 5 5 5 5 from rows 4, 5, 6, and 7
     r6: columns 2 2 2 2 3 3 3 3 from rows 4, 5, 6, and 7
     r7: columns 6 6 6 6 7 7 7 7 from rows 4, 5, 6, and 7*/ \
  Interleave64(r0, r4); \
  Interleave64(r1, r5); \
  Interleave64(r2, r6); \
  Interleave64(r3, r7); \
  /* r0: columns 0 0 0 0 0 0 0 0 from rows 0 through 7
     r1: columns 4 4 4 4 4 4 4 4 from rows 0 through 7
     r2: columns 2 2 2 2 2 2 2 2 from rows 0 through 7
     r3: columns 6 6 6 6 6 6 6 6 from rows 0 through 7
     r4: columns 1 1 1 1 1 1 1 1 from rows 0 through 7
     r5: columns 5 5 5 5 5 5 5 5 from rows 0 through 7
     r6: columns 3 3 3 3 3 3 3 3 from rows 0 through 7
     r7: columns 7 7 7 7 7 7 7 7 from rows 0 through 7*/ \
  /* Empirically gcc is able to remove these movs and just rename the outputs of Interleave64. */ \
  /* Swapping r1<->r4 and r3<->r6 puts r0..r7 in column order 0..7. */ \
  Swap(r1, r4); \
  Swap(r3, r6); \
} \

INTGEMM_TRANSPOSE16(INTGEMM_SSE2, __m128i)
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
INTGEMM_TRANSPOSE16(INTGEMM_AVX2, __m256i)
#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
/* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */
INTGEMM_TRANSPOSE16(INTGEMM_AVX512BW, __m512i)
#endif
110
/* Transpose registers containing 16 packed 8-bit integers.
 * Each 128-bit lane is handled independently.
 */
template <class Register> static inline void Transpose8InLane(
    Register &r0, Register &r1, Register &r2, Register &r3, Register &r4, Register &r5, Register &r6, Register &r7,
    Register &r8, Register &r9, Register &r10, Register &r11, Register &r12, Register &r13, Register &r14, Register &r15) {
  // Widen pairs of 8-bit values into 16-bit units so neighboring rows can
  // travel together through the 16-bit transpose below.
  Interleave8(r0, r1);
  // r0: columns 0 0 1 1 2 2 3 3 4 4 5 5 6 6 7 7 from rows 0 and 1.
  // r1: columns 8 8 9 9 10 10 11 11 12 12 13 13 14 14 15 15 from rows 0 and 1.
  Interleave8(r2, r3);
  // r2: columns 0 0 1 1 2 2 3 3 4 4 5 5 6 6 7 7 from rows 2 and 3.
  Interleave8(r4, r5);
  Interleave8(r6, r7);
  Interleave8(r8, r9);
  Interleave8(r10, r11);
  Interleave8(r12, r13);
  Interleave8(r14, r15);
  // Even-position registers now carry columns 0-7; odd ones carry columns 8-15.
  Transpose16InLane(r0, r2, r4, r6, r8, r10, r12, r14);
  Transpose16InLane(r1, r3, r5, r7, r9, r11, r13, r15);
  // Permute into the correct output order.  The permutation decomposes into
  // three 4-cycles plus one 2-cycle, written here as swaps.  This is free
  // because the compiler just renames the outputs of the transposes above.
  Swap(r1, r2);   // 4-cycle: r1 <- r2 <- r4 <- r8 <- r1
  Swap(r2, r4);
  Swap(r4, r8);
  Swap(r3, r6);   // 4-cycle: r3 <- r6 <- r12 <- r9 <- r3
  Swap(r6, r12);
  Swap(r12, r9);
  Swap(r5, r10);  // 2-cycle: r5 <-> r10
  Swap(r7, r14);  // 4-cycle: r7 <- r14 <- r13 <- r11 <- r7
  Swap(r14, r13);
  Swap(r13, r11);
}
152
// PREPARE B: quantize and rearrange. B is presumed to be constant parameters
// so we can take our time rearranging it in order to save during the multiply.
//
// We presume B starts in row-major order.
//
// In INTGEMM_AVX2, a register holds 32 8-bit values or 16 16-bit values and we want
// that many values from the same column in the register.
//
// The multiplier reads 8 rows at a time and we want these reads to be
// contiguous.
//
// Each 8x32 (for 8-bit) or 8x16 (for 16-bit) tile of B is transposed.
// The tiles are stored in column major order.
//
// For INTGEMM_AVX2, this matrix shows what index each value of B will be stored at:
//   0  16 ... 240
//   1  17 ... 241
//   2  18 ... 242
//   3  19 ... 243
//   4  20 ... 244
//   5  21 ... 245
//   6  22 ... 246
//   7  23 ... 247
//   8  24 ... 248
//   9  25 ... 249
//  10  26 ... 250
//  11  27 ... 251
//  12  28 ... 252
//  13  29 ... 253
//  14  30 ... 254
//  15  31 ... 255
// 256 272
// 257 273
// ... ...
#define INTGEMM_PREPARE_B_8(target, QuantClass) \
target static inline void PrepareB(const float *input, int8_t *output_shadow, float quant_mult, Index rows, Index cols) { \
  FRegister q = set1_ps<FRegister>(quant_mult); \
  /* Currently all multipliers have a stride of 8 columns.*/ \
  const Index kColStride = 8; \
  assert(cols % kColStride == 0); \
  assert(rows % sizeof(Register) == 0); \
  assert(reinterpret_cast<uintptr_t>(input) % sizeof(Register) == 0); \
  Register *output = reinterpret_cast<Register*>(output_shadow); \
  assert(reinterpret_cast<uintptr_t>(output) % sizeof(Register) == 0); \
  /* Walk tiles: columns in strides of 8, rows in strides of one register's
     worth of bytes; each tile emits 8 output registers. */ \
  for (Index c = 0; c < cols; c += kColStride) { \
    for (Index r = 0; r < rows; r += sizeof(Register), output += 8) { \
      /* Quantize and perform a transpose with height sizeof(Register) and width 8. \
         This isn't quite Transpose8InLane because it's half the number of columns, \
         so each register starts with two rows instead of being one row. \
         The quantizers know to skip a row (presumably packing rows r+k and
         r+k+2 into one register -- see QuantClass::ForReshape; hence the
         0,1,4,5,8,9,12,13 row offsets below). */ \
      output[0] = QuantClass::ForReshape(q, input + cols * (r ) + c, cols); \
      output[1] = QuantClass::ForReshape(q, input + cols * (r + 1) + c, cols); \
      output[2] = QuantClass::ForReshape(q, input + cols * (r + 4) + c, cols); \
      output[3] = QuantClass::ForReshape(q, input + cols * (r + 5) + c, cols); \
      output[4] = QuantClass::ForReshape(q, input + cols * (r + 8) + c, cols); \
      output[5] = QuantClass::ForReshape(q, input + cols * (r + 9) + c, cols); \
      output[6] = QuantClass::ForReshape(q, input + cols * (r + 12) + c, cols); \
      output[7] = QuantClass::ForReshape(q, input + cols * (r + 13) + c, cols); \
      Interleave8(output[0], output[1]); \
      Interleave8(output[2], output[3]); \
      Interleave8(output[4], output[5]); \
      Interleave8(output[6], output[7]); \
      Transpose16InLane(output[0], output[1], output[2], output[3], output[4], output[5], output[6], output[7]); \
    } \
  } \
} \
219
/* 16-bit variant of PrepareB: quantize row-major B into the tiled,
 * transposed layout described in the comment above INTGEMM_PREPARE_B_8.
 * Each tile is 8 columns wide and one register's worth of 16-bit rows tall. */
#define INTGEMM_PREPARE_B_16(target, QuantClass) \
target static inline void PrepareB(const float *input, int16_t *output_shadow, float quant_mult, Index rows, Index cols) { \
  FRegister q = set1_ps<FRegister>(quant_mult); \
  assert(cols % 8 == 0); \
  assert(rows % (sizeof(Register) / sizeof(int16_t)) == 0); \
  assert(reinterpret_cast<uintptr_t>(input) % sizeof(Register) == 0); \
  Register *output = reinterpret_cast<Register*>(output_shadow); \
  assert(reinterpret_cast<uintptr_t>(output) % sizeof(Register) == 0); \
  for (Index c = 0; c < cols; c += 8) { \
    for (Index r = 0; r < rows; r += (sizeof(Register) / sizeof(int16_t)), output += 8) { \
      /* gcc unrolls this loop and uses registers for output[k]*/ \
      for (Index k = 0; k < 8; ++k) { \
        output[k] = QuantClass::ForReshape(q, input + cols * (r + k) + c, cols); \
      } \
      Transpose16InLane(output[0], output[1], output[2], output[3], output[4], output[5], output[6], output[7]); \
    } \
  } \
}
238
/*
 * Prepare B matrix.
 * B matrix has to be transposed and quantized.
 * Cols has to be a multiple of sizeof(Register) / sizeof(Integer).
 *
 * cols and rows describe size of transposed B.
 */
#define INTGEMM_PREPARE_B_QUANTIZED_TRANSPOSED(target, Integer) \
target static inline void PrepareBQuantizedTransposed(const Integer* input, Integer* output, Index cols, Index rows) { \
  const Index kElemsPerRegister = sizeof(Register) / sizeof(Integer); \
  const Index kRowStride = 8; \
  \
  assert(cols % kElemsPerRegister == 0); \
  assert(rows % kRowStride == 0); \
  assert(reinterpret_cast<uintptr_t>(input) % sizeof(Register) == 0); \
  assert(reinterpret_cast<uintptr_t>(output) % sizeof(Register) == 0); \
  \
  /* Copy 8-row bands, one register-wide column chunk at a time. */ \
  Register* dst = reinterpret_cast<Register*>(output); \
  for (Index r = 0; r < rows; r += kRowStride) { \
    for (Index c = 0; c < cols; c += kElemsPerRegister) { \
      const Integer* src = input + r * cols + c; \
      for (Index ri = 0; ri < kRowStride; ++ri, src += cols) \
        *dst++ = *reinterpret_cast<const Register*>(src); \
    } \
  } \
}
262
/*
 * Prepare B matrix.
 * B matrix has to be transposed (but not quantized; quantization happens here).
 * Cols has to be a multiple of sizeof(Register) / sizeof(float).
 *
 * cols and rows describe size of transposed B.
 */
#define INTGEMM_PREPARE_B_TRANSPOSED(target, Quantizer, Integer) \
target static inline void PrepareBTransposed(const float* input, Integer* output, float quant_mult, Index cols, Index rows) { \
  const Index RegisterElemsInt = sizeof(Register) / sizeof(Integer); \
  const Index kColStride = 8; \
  \
  assert(cols % (sizeof(Register) / sizeof(float)) == 0); \
  assert(rows % kColStride == 0); \
  assert(reinterpret_cast<uintptr_t>(input) % sizeof(Register) == 0); \
  assert(reinterpret_cast<uintptr_t>(output) % sizeof(Register) == 0); \
  \
  FRegister q = set1_ps<FRegister>(quant_mult); \
  Register* output_it = reinterpret_cast<Register*>(output); \
  Index r = 0; \
  Index c = 0; \
  /* Emit one quantized register per row of the current 8-row band.  When a
     register needs more floats than remain in the row, the quantizer wraps
     to the next band -- presumably what ConsecutiveWithWrapping's
     (cols - c, cols, 8) arguments encode; confirm against the quantizer. */ \
  while (r < rows) { \
    for (Index ri = 0; ri < 8; ++ri) \
      *output_it++ = Quantizer::ConsecutiveWithWrapping(q, input + (r + ri) * cols + c, cols - c, cols, 8); \
    c += RegisterElemsInt; \
    /* If the column cursor ran past the row end, advance to the next
       8-row band; the leftover wraps around (c may exceed cols by more
       than one row's worth only if cols < RegisterElemsInt). */ \
    while (c >= cols) { \
      r += kColStride; \
      c -= cols; \
    } \
  } \
}
294
/* Copy selected columns of B (already in PrepareB format) into output, which
 * ends up in PrepareB format as well.  Column indices must come in multiples
 * of 8 because PrepareB stores columns in tiles of 8.
 */
#define INTGEMM_SELECT_COL_B(target, Register) \
target static inline void SelectColumnsOfB(const Register *input, Register *output, Index rows_bytes /* number of bytes in a row */, const Index *cols_begin, const Index *cols_end) { \
  assert(rows_bytes % sizeof(Register) == 0); \
  assert((cols_end - cols_begin) % 8 == 0); \
  Index register_rows = rows_bytes / sizeof(Register); \
  /* Read cursors, one per column in the current group of 8. */ \
  const Register *from[8]; \
  for (; cols_begin != cols_end; cols_begin += 8) { \
    for (Index k = 0; k < 8; ++k) { \
      /* col & ~7 picks the source tile of 8 columns; col & 7 picks the
         register offset within that tile. */ \
      from[k] = input + (cols_begin[k] & 7) + (cols_begin[k] & ~7) * register_rows; \
    } \
    /* Interleave the 8 columns row-register by row-register; consecutive
       registers of one column are 8 apart in PrepareB layout. */ \
    for (Index r = 0; r < register_rows; ++r) { \
      for (Index k = 0; k < 8; ++k) { \
        *output = *from[k]; \
        ++output; \
        from[k] += 8; \
      } \
    } \
  } \
}
316
317 } // namespace intgemm
318