1 ///////////////////////////////////////////////////////////////////////
2 // File:        intsimdmatrixneon.cpp
3 // Description: matrix-vector product for 8-bit data on neon.
4 // Author:      Robin Watts (from the AVX2 original by Ray Smith)
5 //
6 // (C) Copyright 2017, Google Inc.
7 // (C) Copyright 2020, Artifex Software Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////
18 
19 #if defined(__ARM_NEON)
20 
21 #  include "intsimdmatrix.h"
22 #  include "tesstypes.h"
23 
24 #  include <algorithm>
25 #  include <cstdint>
26 #  include <vector>
27 #  include "arm_neon.h"
28 
29 namespace tesseract {
30 
// Geometry of the NEON kernel.
// Outputs per "register": the kernel accumulates into a pair of 4 x 32 bit
// registers, i.e. 8 x 32 bit ints.
constexpr int kNumOutputsPerRegister{8};
// Upper bound on the number of output registers used.
constexpr int kMaxOutputRegisters{1};
// Inputs held in the inputs register.
constexpr int kNumInputsPerRegister{8};
// Inputs in each weight group.
constexpr int kNumInputsPerGroup{8};
40 
// Function to compute part of a matrix.vector multiplication. The weights
// are in a very specific order (described below) in wi, which is multiplied
// by u of length num_in, to produce output v after scaling the integer
// results by the corresponding member of scales.
// The amount of wi and scales consumed is fixed and not available to the
// caller.
47 
// Computes up to 8 outputs of the matrix.vector product v = Wu.
//
// Weight layout read from wi: for every group of 8 inputs, 8 rows (one per
// output) of 8 int8 weights each, i.e. 64 bytes per group. Once all num_in
// inputs have been consumed, 8 int8 bias weights follow.
// u is read 8 int8 values at a time until num_in inputs are covered, so it
// must be padded out with zeros to a multiple of 8 elements.
//
// Each 32 bit total (dot product plus bias * 127) is multiplied by the
// matching entry of scales and stored in v. v[0] is always written; outputs
// 1..7 are written only while their index is below num_out.
static inline void PartialMatrixDotVector8(const int8_t *__restrict wi,
                                           const TFloat *__restrict scales,
                                           const int8_t *__restrict u, int num_in,
                                           TFloat *__restrict v, int num_out) {
  // Pairwise-adds the lanes of two 4 x 32 bit vectors into one vector:
  // {a0+a1, a2+a3, b0+b1, b2+b3}.
  const auto fold_pairs = [](int32x4_t a, int32x4_t b) -> int32x4_t {
    return vcombine_s32(vpadd_s32(vget_low_s32(a), vget_high_s32(a)),
                        vpadd_s32(vget_low_s32(b), vget_high_s32(b)));
  };

  // Running totals for outputs 0..3 and 4..7.
  int32x4_t sums_lo = vdupq_n_s32(0);
  int32x4_t sums_hi = vdupq_n_s32(0);

  // Walk the inputs 8 at a time; each step consumes an 8x8 tile of weights.
  for (int done = 0; done < num_in; done += 8, u += 8, wi += 64) {
    const int8x8_t inputs = vld1_s8(u);
    // Each 16 byte load holds two consecutive weight rows.
    const int8x16_t rows01 = vld1q_s8(wi);
    const int8x16_t rows23 = vld1q_s8(wi + 16);
    const int8x16_t rows45 = vld1q_s8(wi + 32);
    const int8x16_t rows67 = vld1q_s8(wi + 48);

    // Per row: 8 widening products (8 bit -> 16 bit), then adjacent pairs
    // are summed into 4 x 32 bit partial sums.
    const int32x4_t r0 = vpaddlq_s16(vmull_s8(vget_low_s8(rows01), inputs));
    const int32x4_t r1 = vpaddlq_s16(vmull_s8(vget_high_s8(rows01), inputs));
    const int32x4_t r2 = vpaddlq_s16(vmull_s8(vget_low_s8(rows23), inputs));
    const int32x4_t r3 = vpaddlq_s16(vmull_s8(vget_high_s8(rows23), inputs));
    const int32x4_t r4 = vpaddlq_s16(vmull_s8(vget_low_s8(rows45), inputs));
    const int32x4_t r5 = vpaddlq_s16(vmull_s8(vget_high_s8(rows45), inputs));
    const int32x4_t r6 = vpaddlq_s16(vmull_s8(vget_low_s8(rows67), inputs));
    const int32x4_t r7 = vpaddlq_s16(vmull_s8(vget_high_s8(rows67), inputs));

    // First fold: two half-row sums per row, two rows per vector.
    const int32x4_t r01 = fold_pairs(r0, r1);
    const int32x4_t r23 = fold_pairs(r2, r3);
    const int32x4_t r45 = fold_pairs(r4, r5);
    const int32x4_t r67 = fold_pairs(r6, r7);

    // Second fold: one complete row sum per lane, four rows per vector.
    sums_lo = vaddq_s32(sums_lo, fold_pairs(r01, r23));
    sums_hi = vaddq_s32(sums_hi, fold_pairs(r45, r67));
  }

  // The 8 bias weights sit directly after the last weight tile; each is
  // multiplied by 127 before being added to its output.
  const int16x8_t scaled_bias = vmull_s8(vld1_s8(wi), vdup_n_s8(127));
  sums_lo = vaddw_s16(sums_lo, vget_low_s16(scaled_bias));
  sums_hi = vaddw_s16(sums_hi, vget_high_s16(scaled_bias));

  // Spill the totals so they can be scaled one output at a time.
  int32_t totals[8];
  vst1q_s32(totals, sums_lo);
  vst1q_s32(totals + 4, sums_hi);

  // The first output is written unconditionally; the rest are limited by
  // num_out.
  v[0] = totals[0] * scales[0];
  for (int k = 1; k < 8 && k < num_out; ++k) {
    v[k] = totals[k] * scales[k];
  }
}
166 
matrixDotVector(int dim1,int dim2,const int8_t * wi,const TFloat * scales,const int8_t * u,TFloat * v)167 static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const TFloat *scales,
168                             const int8_t *u, TFloat *v) {
169   const int num_out = dim1;
170   const int num_in = dim2 - 1;
171   // Each call to a partial_func_ produces group_size outputs, except the
172   // last one, which can produce less.
173   const int rounded_num_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
174   int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
175   int output = 0;
176 
177   int w_step = (rounded_num_in + 1) * group_size;
178 
179   for (; output + group_size <= num_out; output += group_size) {
180     PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v, kNumOutputsPerRegister);
181     wi += w_step;
182     scales += group_size;
183     v += group_size;
184   }
185   if (output < num_out)
186     PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v,
187                             num_out & (kNumOutputsPerRegister - 1));
188 }
189 
190 const IntSimdMatrix IntSimdMatrix::intSimdMatrixNEON = {
191     // Function.
192     matrixDotVector,
193     // Number of 32 bit outputs held in each register.
194     kNumOutputsPerRegister,
195     // Maximum number of registers that we will use to hold outputs.
196     kMaxOutputRegisters,
197     // Number of 8 bit inputs in the inputs register.
198     kNumInputsPerRegister,
199     // Number of inputs in each weight group.
200     kNumInputsPerGroup
201 };
202 
203 } // namespace tesseract.
204 
205 #endif /* __ARM_NEON */
206