///////////////////////////////////////////////////////////////////////
// File:        intsimdmatrixavx2.cpp
// Description: matrix-vector product for 8-bit data on avx2.
// Author:      Ray Smith
//
// (C) Copyright 2017, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////

#include "intsimdmatrix.h"

#if !defined(__AVX2__)
#  if defined(__i686__) || defined(__x86_64__)
#    error Implementation only for AVX2 capable architectures
#  endif
#else
#  include <immintrin.h>
#  include <algorithm>
#  include <cstdint>
#  include <vector>

namespace tesseract {

// Number of outputs held in each register. 8 x 32 bit ints.
constexpr int kNumOutputsPerRegister = 8;
// Maximum number of registers that we will use.
constexpr int kMaxOutputRegisters = 8;
// Number of inputs in the inputs register.
constexpr int kNumInputsPerRegister = 32;
// Number of inputs in each weight group.
constexpr int kNumInputsPerGroup = 4;
// Number of groups of inputs to be broadcast.
constexpr int kNumInputGroups = kNumInputsPerRegister / kNumInputsPerGroup;
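
// Illustrative sanity checks (these restate relationships among the
// constants above and hold by construction):
//   - one 256-bit register holds 32 int8 inputs, consumed as 32/4 = 8
//     broadcast groups;
//   - a full-width partial pass fills 8 registers of 8 outputs = 64 results.
static_assert(kNumInputGroups == 8, "32 int8 inputs / 4 inputs per group");
static_assert(kNumOutputsPerRegister * kMaxOutputRegisters == 64,
              "a full-width partial pass computes 64 outputs");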

// Functions to compute part of a matrix.vector multiplication. The weights
// are in a very specific order (see the layout comment on
// PartialMatrixDotVector64 below) in w, which is multiplied by
// u of length num_in, to produce output v after scaling the integer results
// by the corresponding member of scales.
// The amount of w and scales consumed is fixed and not available to the
// caller. The number of outputs written to v will be at most num_out.

// Computes one set of 4x8 products of inputs and weights, adding to result.
// Horizontally adds 4 adjacent results, making 8x32-bit results.
// rep_input is assumed to be an 8x replicated set of 4x8-bit signed integers.
// Note that wi must previously have been re-organized with blocks of 4x8
// weights in contiguous memory.
// ones is a register of 16x16-bit values all equal to 1.
// Note: wi is incremented by the amount of data read.
// weights and reps are scratch registers.
// This function must be inlined with references in order for the compiler to
// correctly use the registers declared in the caller.
static inline void MultiplyGroup(const __m256i &rep_input, const __m256i &ones, const int8_t *&wi,
                                 __m256i &weights, __m256i &reps, __m256i &result) {
  // Load a 4x8 block of weights.
  weights = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(wi));
  wi += kNumInputsPerRegister;
  // Normalize the signs on rep_input, weights, so weights is always +ve.
  reps = _mm256_sign_epi8(rep_input, weights);
  weights = _mm256_sign_epi8(weights, weights);
  // Multiply 32x8-bit reps by 32x8-bit weights to make 16x16-bit results,
  // with adjacent pairs added.
  weights = _mm256_maddubs_epi16(weights, reps);
  // Multiply 16x16-bit result by 16x16-bit ones to make 8x32-bit results,
  // with adjacent pairs added. What we really want is a horizontal add of
  // 16+16=32 bit result, but there is no such instruction, so multiply by
  // 16-bit ones instead. It is probably faster than all the sign-extending,
  // permuting and adding that would otherwise be required.
  weights = _mm256_madd_epi16(weights, ones);
  result = _mm256_add_epi32(result, weights);
}
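
// A minimal scalar sketch of what one MultiplyGroup call accumulates
// (illustrative only; the names group_in and group_w are hypothetical, and
// the 16-bit saturation of _mm256_maddubs_epi16 is ignored here):
#if 0
void MultiplyGroupScalar(const int8_t group_in[4], const int8_t group_w[8][4],
                         int32_t result[8]) {
  for (int k = 0; k < 8; ++k) {   // kNumOutputsPerRegister outputs per call.
    int32_t sum = 0;
    for (int i = 0; i < 4; ++i) { // kNumInputsPerGroup inputs per output.
      sum += int32_t(group_in[i]) * int32_t(group_w[k][i]);
    }
    result[k] += sum;
  }
}
#endif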

// Load 64 bits into the bottom of a 128bit register.
// We don't actually care what the top 64bits are, but this ends
// up with them being zero.
static inline __m128i load64_to_128(const int8_t *wi_) {
  const auto *wi = reinterpret_cast<const int64_t *>(wi_);
  return _mm_set_epi64x(0, wi[0]);
}

#if defined(FAST_FLOAT)

static inline void ExtractResults8(__m256i result, const int8_t *wi,
                                   const float *scales, float *v) {
  __m128i w128 = load64_to_128(wi);          // 8x8bit vals in bottom of 128bit reg
  __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
  __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256 scale01234567 = _mm256_loadu_ps(scales);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result = _mm256_add_epi32(result, w256);     // result += bias * 127
  __m256 res01234567 = _mm256_cvtepi32_ps(result);
  result = _mm256_permute4x64_epi64(result, 2 + (3 << 2));
  res01234567 = _mm256_mul_ps(res01234567, scale01234567);
  _mm256_storeu_ps(v, res01234567);
}
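
// Scalar sketch of the extraction above (illustrative only): the 8 bytes at
// wi are the bias weights, which multiply an implicit input of 127 (hence
// the "bias * 127" comments), and each integer accumulator is converted to
// float and scaled per output.
#if 0
void ExtractResults8Scalar(const int32_t result[8], const int8_t bias[8],
                           const float scales[8], float v[8]) {
  for (int k = 0; k < 8; ++k) {
    v[k] = float(result[k] + int32_t(bias[k]) * 127) * scales[k];
  }
}
#endif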

static inline void ExtractResults16(__m256i result0, __m256i result1,
                                    const int8_t *&wi, const float *&scales,
                                    float *&v) {
  // 16x8bit vals in the 128bit reg; the bottom 8 are converted first.
  __m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
  const __m256i bias_scale =
      _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
  __m256 scale01234567 = _mm256_loadu_ps(scales);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result0 = _mm256_add_epi32(result0, w256);   // result += bias * 127
  __m256 res01234567 = _mm256_cvtepi32_ps(result0);
  result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2));
  res01234567 = _mm256_mul_ps(res01234567, scale01234567);
  _mm256_storeu_ps(v, res01234567);
  // Move the top 8 bias bytes down and extract the second 8 outputs.
  w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));
  w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
  scale01234567 = _mm256_loadu_ps(scales + 8);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result1 = _mm256_add_epi32(result1, w256);   // result += bias * 127
  res01234567 = _mm256_cvtepi32_ps(result1);
  result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2));
  res01234567 = _mm256_mul_ps(res01234567, scale01234567);
  _mm256_storeu_ps(v + 8, res01234567);
  wi += 16;
  scales += 16;
  v += 16;
}

// Computes part of matrix.vector v = Wu. Computes N=64 results.
// The weights *must* be arranged so that consecutive reads from wi
// provide (num_in/kNumInputsPerGroup groups of (N output dim groups of
// (kNumInputsPerGroup inputs))). After that there must be N consecutive
// bias weights, before continuing with any more weights.
// u must be padded out with zeros to
// kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
static void PartialMatrixDotVector64(const int8_t *wi, const float *scales, const int8_t *u,
                                     int num_in, float *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  __m256i result4 = _mm256_setzero_si256();
  __m256i result5 = _mm256_setzero_si256();
  __m256i result6 = _mm256_setzero_si256();
  __m256i result7 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result4);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result5);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result6);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
  ExtractResults16(result4, result5, wi, scales, v);
  ExtractResults16(result6, result7, wi, scales, v);
}
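
// Hypothetical scalar reference for the N=64 kernel (illustrative only, with
// assumed names; it restates the weight order described above: for each group
// of 4 inputs, 4 weights per output for all 64 outputs in sequence, followed
// by 64 bias bytes):
#if 0
void PartialMatrixDotVector64Scalar(const int8_t *wi, const float *scales,
                                    const int8_t *u, int num_in, float *v) {
  constexpr int kN = 64;
  const int num_groups = (num_in + 3) / 4; // u is zero-padded to a multiple of 4.
  for (int r = 0; r < kN; ++r) {
    int32_t sum = 0;
    for (int g = 0; g < num_groups; ++g) {
      for (int i = 0; i < 4; ++i) {
        sum += int32_t(u[g * 4 + i]) * int32_t(wi[g * kN * 4 + r * 4 + i]);
      }
    }
    const int8_t bias = wi[num_groups * kN * 4 + r]; // N biases follow the weights.
    v[r] = float(sum + int32_t(bias) * 127) * scales[r];
  }
}
#endif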

// Computes part of matrix.vector v = Wu. Computes N=32 results.
// For details see PartialMatrixDotVector64 with N=32.
static void PartialMatrixDotVector32(const int8_t *wi, const float *scales, const int8_t *u,
                                     int num_in, float *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=16 results.
// For details see PartialMatrixDotVector64 with N=16.
static void PartialMatrixDotVector16(const int8_t *wi, const float *scales, const int8_t *u,
                                     int num_in, float *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=8 results.
// For details see PartialMatrixDotVector64 with N=8.
static inline void PartialMatrixDotVector8(const int8_t *wi, const float *scales, const int8_t *u,
                                           int num_in, float *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
    }
  }
  ExtractResults8(result0, wi, scales, v);
}

static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const float *scales,
                            const int8_t *u, float *v) {
  const int num_out = dim1;
  const int num_in = dim2 - 1;
  // Each call to a PartialMatrixDotVector kernel produces group_size outputs,
  // except the last one, which can produce fewer.
  const int rounded_num_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
  const int rounded_num_out = IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister);
  int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
  int output = 0;

  int w_step = (rounded_num_in + 1) * group_size;

  // Run with this group size, until it would produce too much output, then
  // switch to a smaller size.
  for (; output + group_size <= rounded_num_out; output += group_size) {
    PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
  }
}
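
// Worked example for the dispatch above: with num_out = 100,
// rounded_num_out = 104, so the kernels run as 64 + 32 + 8 outputs. The
// 64-wide loop runs once (output = 64), the 32-wide block runs
// (output = 96), the 16-wide block is skipped because 96 + 16 > 104, and
// the 8-wide block finishes (output = 104). Note that group_size and w_step
// are halved even when a block is skipped, keeping w_step consistent with
// the weight layout of the next narrower kernel.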
#else

static inline void ExtractResults8(__m256i result, const int8_t *wi, const double *scales,
                                   double *v) {
  __m128i w128 = load64_to_128(wi);          // 8x8bit vals in bottom of 128bit reg
  __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
  __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256d scale0123 = _mm256_loadu_pd(scales);
  __m256d scale4567 = _mm256_loadu_pd(scales + 4);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result = _mm256_add_epi32(result, w256);     // result += bias * 127
  __m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
  result = _mm256_permute4x64_epi64(result, 2 + (3 << 2));
  __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
  res0123 = _mm256_mul_pd(res0123, scale0123);
  res4567 = _mm256_mul_pd(res4567, scale4567);
  _mm256_storeu_pd(v, res0123);
  _mm256_storeu_pd(v + 4, res4567);
}
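
// A note on the two-stage extraction above: _mm256_cvtepi32_pd converts only
// a __m128i (four 32-bit lanes) at a time, so each 8-wide result register is
// extracted in halves. The permute immediate 2 + (3 << 2) = 0b1110 selects
// 64-bit lanes {2, 3} of result into positions {0, 1}, moving the upper 128
// bits down so the second conversion reaches outputs 4..7.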

static inline void ExtractResults16(__m256i result0, __m256i result1, const int8_t *&wi,
                                    const double *&scales, double *&v) {
  // 16x8bit vals in the 128bit reg; the bottom 8 are converted first.
  __m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
  const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
  __m256d scale0123 = _mm256_loadu_pd(scales);
  __m256d scale4567 = _mm256_loadu_pd(scales + 4);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result0 = _mm256_add_epi32(result0, w256);   // result += bias * 127
  __m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
  result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2));
  __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
  res0123 = _mm256_mul_pd(res0123, scale0123);
  res4567 = _mm256_mul_pd(res4567, scale4567);
  _mm256_storeu_pd(v, res0123);
  _mm256_storeu_pd(v + 4, res4567);
  // Move the top 8 bias bytes down and extract the second 8 outputs.
  w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));
  w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
  scale0123 = _mm256_loadu_pd(scales + 8);
  scale4567 = _mm256_loadu_pd(scales + 12);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result1 = _mm256_add_epi32(result1, w256);   // result += bias * 127
  res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
  result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2));
  res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
  res0123 = _mm256_mul_pd(res0123, scale0123);
  res4567 = _mm256_mul_pd(res4567, scale4567);
  _mm256_storeu_pd(v + 8, res0123);
  _mm256_storeu_pd(v + 12, res4567);
  wi += 16;
  scales += 16;
  v += 16;
}

// Computes part of matrix.vector v = Wu. Computes N=64 results.
// The weights *must* be arranged so that consecutive reads from wi
// provide (num_in/kNumInputsPerGroup groups of (N output dim groups of
// (kNumInputsPerGroup inputs))). After that there must be N consecutive
// bias weights, before continuing with any more weights.
// u must be padded out with zeros to
// kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
static void PartialMatrixDotVector64(const int8_t *wi, const double *scales, const int8_t *u,
                                     int num_in, double *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  __m256i result4 = _mm256_setzero_si256();
  __m256i result5 = _mm256_setzero_si256();
  __m256i result6 = _mm256_setzero_si256();
  __m256i result7 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result4);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result5);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result6);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
  ExtractResults16(result4, result5, wi, scales, v);
  ExtractResults16(result6, result7, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=32 results.
// For details see PartialMatrixDotVector64 with N=32.
static void PartialMatrixDotVector32(const int8_t *wi, const double *scales, const int8_t *u,
                                     int num_in, double *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=16 results.
// For details see PartialMatrixDotVector64 with N=16.
static void PartialMatrixDotVector16(const int8_t *wi, const double *scales, const int8_t *u,
                                     int num_in, double *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=8 results.
// For details see PartialMatrixDotVector64 with N=8.
static inline void PartialMatrixDotVector8(const int8_t *wi, const double *scales, const int8_t *u,
                                           int num_in, double *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
    }
  }
  ExtractResults8(result0, wi, scales, v);
}

static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const double *scales,
                            const int8_t *u, double *v) {
  const int num_out = dim1;
  const int num_in = dim2 - 1;
  // Each call to a PartialMatrixDotVector kernel produces group_size outputs,
  // except the last one, which can produce fewer.
  const int rounded_num_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
  const int rounded_num_out = IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister);
  int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
  int output = 0;

  int w_step = (rounded_num_in + 1) * group_size;

  // Run with this group size, until it would produce too much output, then
  // switch to a smaller size.
  for (; output + group_size <= rounded_num_out; output += group_size) {
    PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
  }
}
#endif

const IntSimdMatrix IntSimdMatrix::intSimdMatrixAVX2 = {
    // Function.
    matrixDotVector,
    // Number of 32 bit outputs held in each register.
    kNumOutputsPerRegister,
    // Maximum number of registers that we will use to hold outputs.
    kMaxOutputRegisters,
    // Number of 8 bit inputs in the inputs register.
    kNumInputsPerRegister,
    // Number of inputs in each weight group.
    kNumInputsPerGroup
};
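
// Usage sketch (hedged: the exact member names come from intsimdmatrix.h and
// are assumed here; dim2 counts num_in plus the implicit bias input). Callers
// dispatch through this table rather than naming matrixDotVector directly:
//
//   const IntSimdMatrix &m = IntSimdMatrix::intSimdMatrixAVX2;
//   m.matrixDotVectorFunction(num_out, num_in + 1, shaped_wi, scales, u, v);
//
// where shaped_wi must already be in the block order described at
// PartialMatrixDotVector64.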

} // namespace tesseract.

#endif