///////////////////////////////////////////////////////////////////////
// File:        intsimdmatrixavx2.cpp
// Description: matrix-vector product for 8-bit data on avx2.
// Author:      Ray Smith
//
// (C) Copyright 2017, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////

#include "intsimdmatrix.h"

#if !defined(__AVX2__)
#  if defined(__i686__) || defined(__x86_64__)
#    error Implementation only for AVX2 capable architectures
#  endif
#else
#  include <immintrin.h>
#  include <algorithm>
#  include <cstdint>
#  include <vector>

namespace tesseract {

// Number of outputs held in each register. 8 x 32 bit ints.
constexpr int kNumOutputsPerRegister = 8;
// Maximum number of registers that we will use.
constexpr int kMaxOutputRegisters = 8;
// Number of inputs in the inputs register.
constexpr int kNumInputsPerRegister = 32;
// Number of inputs in each weight group.
constexpr int kNumInputsPerGroup = 4;
// Number of groups of inputs to be broadcast.
constexpr int kNumInputGroups = kNumInputsPerRegister / kNumInputsPerGroup;
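
// These constants are tied together by simple arithmetic: 32 inputs per
// register consumed 4 at a time gives 32 / 4 = 8 broadcast groups, and the
// widest kernel below fills 8 registers x 8 outputs = 64 outputs per pass.
// The asserts below restate this at compile time; they are an illustrative
// aid and carry no runtime cost.
static_assert(kNumInputGroups == 8, "32 inputs are consumed in groups of 4");
static_assert(kNumOutputsPerRegister * kMaxOutputRegisters == 64,
              "the widest partial kernel computes 64 outputs per pass");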

// Functions to compute part of a matrix.vector multiplication. The weights
// are in a very specific order (see the layout description above
// PartialMatrixDotVector64 below) in w, which is multiplied by u of length
// num_in, to produce output v after scaling the integer results by the
// corresponding member of scales.
// The amount of w and scales consumed is fixed and not reported to the
// caller. The number of outputs written to v will be at most num_out.

// Computes one set of 4x8 products of inputs and weights, adding to result.
// Horizontally adds 4 adjacent results, making 8x32-bit results.
// rep_input is assumed to be an 8x replicated set of 4x8-bit signed integers.
// Note that wi must previously have been re-organized with blocks of 4x8
// weights in contiguous memory.
// ones is a register of 16x16-bit values all equal to 1.
// Note: wi is incremented by the amount of data read.
// weights and reps are scratch registers.
// This function must be inlined with references in order for the compiler to
// correctly use the registers declared in the caller.
static inline void MultiplyGroup(const __m256i &rep_input, const __m256i &ones, const int8_t *&wi,
                                 __m256i &weights, __m256i &reps, __m256i &result) {
  // Load a 4x8 block of weights.
  weights = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(wi));
  wi += kNumInputsPerRegister;
  // Normalize the signs on rep_input and weights so weights is always
  // non-negative, as required by maddubs, which treats its first operand
  // as unsigned.
  reps = _mm256_sign_epi8(rep_input, weights);
  weights = _mm256_sign_epi8(weights, weights);
  // Multiply 32x8-bit reps by 32x8-bit weights to make 16x16-bit results,
  // with adjacent pairs added.
  weights = _mm256_maddubs_epi16(weights, reps);
  // Multiply the 16x16-bit result by 16x16-bit ones to make 8x32-bit results,
  // with adjacent pairs added. What we really want is a horizontal add of
  // 16+16=32 bit results, but there is no such instruction, so multiply by
  // 16-bit ones instead. It is probably faster than all the sign-extending,
  // permuting and adding that would otherwise be required.
  weights = _mm256_madd_epi16(weights, ones);
  result = _mm256_add_epi32(result, weights);
}
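
// As a reading aid, a scalar sketch of what one MultiplyGroup call computes
// (illustrative only; the name and signature below are hypothetical, and the
// real kernel keeps everything in registers):
//
//   void MultiplyGroupScalar(const int8_t in[4], const int8_t *&wi,
//                            int32_t result[8]) {
//     for (int lane = 0; lane < 8; ++lane) {   // 8 x 32-bit output lanes
//       for (int k = 0; k < 4; ++k) {          // 4 inputs per group
//         result[lane] += wi[lane * 4 + k] * in[k];
//       }
//     }
//     wi += kNumInputsPerRegister;             // 32 weights consumed
//   }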

// Load 64 bits into the bottom of a 128bit register.
// We don't actually care what the top 64 bits are, but this ends
// up with them being zero.
static inline __m128i load64_to_128(const int8_t *wi_) {
  const auto *wi = reinterpret_cast<const int64_t *>(wi_);
  return _mm_set_epi64x(0, wi[0]);
}

#if defined(FAST_FLOAT)

static inline void ExtractResults8(__m256i result, const int8_t *wi,
                                   const float *scales, float *v) {
  __m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg
  __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
  __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256 scale01234567 = _mm256_loadu_ps(scales);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result = _mm256_add_epi32(result, w256);     // result += bias * 127
  __m256 res01234567 = _mm256_cvtepi32_ps(result);
  result = _mm256_permute4x64_epi64(result, 2 + (3 << 2));
  res01234567 = _mm256_mul_ps(res01234567, scale01234567);
  _mm256_storeu_ps(v, res01234567);
}
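
// In scalar terms, the extraction above computes, for each of the 8 outputs
// (a summary, with dot_i standing for the accumulated integer dot product):
//
//   v[i] = (dot_i + bias[i] * 127) * scales[i];
//
// The bias is multiplied by 127 because the inputs were quantized so that a
// real-valued 1.0 maps to the int8 maximum of 127; the bias weight must be
// scaled consistently with them before the float rescale by scales[i].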

static inline void ExtractResults16(__m256i result0, __m256i result1,
                                    const int8_t *&wi, const float *&scales,
                                    float *&v) {
  __m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
  // 16x8bit bias vals in the 128bit reg; the low 8 are used first.
  const __m256i bias_scale =
      _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
  __m256 scale01234567 = _mm256_loadu_ps(scales);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result0 = _mm256_add_epi32(result0, w256);   // result += bias * 127
  __m256 res01234567 = _mm256_cvtepi32_ps(result0);
  result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2));
  res01234567 = _mm256_mul_ps(res01234567, scale01234567);
  _mm256_storeu_ps(v, res01234567);
  w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));
  w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
  scale01234567 = _mm256_loadu_ps(scales + 8);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result1 = _mm256_add_epi32(result1, w256);   // result += bias * 127
  res01234567 = _mm256_cvtepi32_ps(result1);
  result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2));
  res01234567 = _mm256_mul_ps(res01234567, scale01234567);
  _mm256_storeu_ps(v + 8, res01234567);
  wi += 16;
  scales += 16;
  v += 16;
}

// Computes part of matrix.vector v = Wu. Computes N=64 results.
// The weights *must* be arranged so that consecutive reads from wi
// provide (num_in/kNumInputsPerGroup groups of (N output dim groups of
// (kNumInputsPerGroup inputs))). After that there must be N consecutive
// bias weights, before continuing with any more weights.
// u must be padded out with zeros to
// kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
static void PartialMatrixDotVector64(const int8_t *wi, const float *scales, const int8_t *u,
                                     int num_in, float *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  __m256i result4 = _mm256_setzero_si256();
  __m256i result5 = _mm256_setzero_si256();
  __m256i result6 = _mm256_setzero_si256();
  __m256i result7 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result4);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result5);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result6);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
  ExtractResults16(result4, result5, wi, scales, v);
  ExtractResults16(result6, result7, wi, scales, v);
}
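
// To make the required weight layout concrete, an illustrative index map for
// the shaped wi buffer with N = 64 (indices are into wi, not the logical
// weight matrix):
//
//   wi[0..31]     outputs 0..7,   inputs 0..3  (4 contiguous weights per output)
//   wi[32..63]    outputs 8..15,  inputs 0..3
//   ...
//   wi[224..255]  outputs 56..63, inputs 0..3
//   wi[256..287]  outputs 0..7,   inputs 4..7
//   ...           one 256-byte stripe per group of 4 inputs,
//   then 64 consecutive int8 bias weights read by the ExtractResults16 calls.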

// Computes part of matrix.vector v = Wu. Computes N=32 results.
// For details see PartialMatrixDotVector64 with N=32.
static void PartialMatrixDotVector32(const int8_t *wi, const float *scales, const int8_t *u,
                                     int num_in, float *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=16 results.
// For details see PartialMatrixDotVector64 with N=16.
static void PartialMatrixDotVector16(const int8_t *wi, const float *scales, const int8_t *u,
                                     int num_in, float *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=8 results.
// For details see PartialMatrixDotVector64 with N=8.
static inline void PartialMatrixDotVector8(const int8_t *wi, const float *scales, const int8_t *u,
                                           int num_in, float *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
    }
  }
  ExtractResults8(result0, wi, scales, v);
}

static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const float *scales,
                            const int8_t *u, float *v) {
  const int num_out = dim1;
  const int num_in = dim2 - 1;
  // Each call to a partial_func_ produces group_size outputs, except the
  // last one, which can produce fewer.
  const int rounded_num_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
  const int rounded_num_out = IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister);
  int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
  int output = 0;

  int w_step = (rounded_num_in + 1) * group_size;

  // Run with this group size, until it would produce too much output, then
  // switch to a smaller size.
  for (; output + group_size <= rounded_num_out; output += group_size) {
    PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
  }
}
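
// A worked example of the dispatch above (illustrative): for num_out = 100,
// rounded_num_out = 104, so the loop runs one 64-output pass (output = 64),
// then group_size drops to 32 and one 32-output pass runs (output = 96),
// the 16-output pass is skipped (96 + 16 > 104), and a final 8-output pass
// covers the rest. w_step halves in lockstep because each smaller kernel
// consumes proportionally fewer weights.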
#else
static inline void ExtractResults8(__m256i result, const int8_t *wi, const double *scales,
                                   double *v) {
  __m128i w128 = load64_to_128(wi);          // 8x8bit vals in bottom of 128bit reg
  __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
  __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256d scale0123 = _mm256_loadu_pd(scales);
  __m256d scale4567 = _mm256_loadu_pd(scales + 4);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result = _mm256_add_epi32(result, w256);     // result += bias * 127
  __m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
  result = _mm256_permute4x64_epi64(result, 2 + (3 << 2));
  __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
  res0123 = _mm256_mul_pd(res0123, scale0123);
  res4567 = _mm256_mul_pd(res4567, scale4567);
  _mm256_storeu_pd(v, res0123);
  _mm256_storeu_pd(v + 4, res4567);
}
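
// Unlike the float path, the double path must split each 8x32-bit result
// register into two halves, because a __m256d holds only 4 doubles:
// _mm256_cvtepi32_pd converts the low 4 lanes, and the permute4x64 moves
// lanes 4..7 down so they can be converted the same way. (The permute
// immediate 2 + (3 << 2) selects 64-bit quadrants 2 and 3 into the low half.)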

static inline void ExtractResults16(__m256i result0, __m256i result1, const int8_t *&wi,
                                    const double *&scales, double *&v) {
  __m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
  // 16x8bit bias vals in the 128bit reg; the low 8 are used first.
  const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
  __m256d scale0123 = _mm256_loadu_pd(scales);
  __m256d scale4567 = _mm256_loadu_pd(scales + 4);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result0 = _mm256_add_epi32(result0, w256);   // result += bias * 127
  __m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
  result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2));
  __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
  res0123 = _mm256_mul_pd(res0123, scale0123);
  res4567 = _mm256_mul_pd(res4567, scale4567);
  _mm256_storeu_pd(v, res0123);
  _mm256_storeu_pd(v + 4, res4567);
  w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));
  w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
  scale0123 = _mm256_loadu_pd(scales + 8);
  scale4567 = _mm256_loadu_pd(scales + 12);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result1 = _mm256_add_epi32(result1, w256);   // result += bias * 127
  res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
  result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2));
  res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
  res0123 = _mm256_mul_pd(res0123, scale0123);
  res4567 = _mm256_mul_pd(res4567, scale4567);
  _mm256_storeu_pd(v + 8, res0123);
  _mm256_storeu_pd(v + 12, res4567);
  wi += 16;
  scales += 16;
  v += 16;
}

// Computes part of matrix.vector v = Wu. Computes N=64 results.
// The weights *must* be arranged so that consecutive reads from wi
// provide (num_in/kNumInputsPerGroup groups of (N output dim groups of
// (kNumInputsPerGroup inputs))). After that there must be N consecutive
// bias weights, before continuing with any more weights.
// u must be padded out with zeros to
// kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
static void PartialMatrixDotVector64(const int8_t *wi, const double *scales, const int8_t *u,
                                     int num_in, double *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  __m256i result4 = _mm256_setzero_si256();
  __m256i result5 = _mm256_setzero_si256();
  __m256i result6 = _mm256_setzero_si256();
  __m256i result7 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result4);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result5);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result6);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
  ExtractResults16(result4, result5, wi, scales, v);
  ExtractResults16(result6, result7, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=32 results.
// For details see PartialMatrixDotVector64 with N=32.
static void PartialMatrixDotVector32(const int8_t *wi, const double *scales, const int8_t *u,
                                     int num_in, double *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=16 results.
// For details see PartialMatrixDotVector64 with N=16.
static void PartialMatrixDotVector16(const int8_t *wi, const double *scales, const int8_t *u,
                                     int num_in, double *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=8 results.
// For details see PartialMatrixDotVector64 with N=8.
static inline void PartialMatrixDotVector8(const int8_t *wi, const double *scales, const int8_t *u,
                                           int num_in, double *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
    }
  }
  ExtractResults8(result0, wi, scales, v);
}

static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const double *scales,
                            const int8_t *u, double *v) {
  const int num_out = dim1;
  const int num_in = dim2 - 1;
  // Each call to a partial_func_ produces group_size outputs, except the
  // last one, which can produce fewer.
  const int rounded_num_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
  const int rounded_num_out = IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister);
  int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
  int output = 0;

  int w_step = (rounded_num_in + 1) * group_size;

  // Run with this group size, until it would produce too much output, then
  // switch to a smaller size.
  for (; output + group_size <= rounded_num_out; output += group_size) {
    PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
  }
}
#endif

const IntSimdMatrix IntSimdMatrix::intSimdMatrixAVX2 = {
    // Function.
    matrixDotVector,
    // Number of 32 bit outputs held in each register.
    kNumOutputsPerRegister,
    // Maximum number of registers that we will use to hold outputs.
    kMaxOutputRegisters,
    // Number of 8 bit inputs in the inputs register.
    kNumInputsPerRegister,
    // Number of inputs in each weight group.
    kNumInputsPerGroup
};
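
// A sketch of how a caller reaches this implementation (illustrative; it
// assumes the weights were already reshaped into the order required above,
// e.g. via IntSimdMatrix::Init, and that u is padded as described):
//
//   const IntSimdMatrix &m = IntSimdMatrix::intSimdMatrixAVX2;
//   // dim1 = num_out, dim2 = num_in + 1; the extra column holds the biases.
//   m.matrixDotVector(dim1, dim2, shaped_wi, scales, padded_u, v);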

} // namespace tesseract.

#endif