// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

im2col_sgemm_pack8to4_int8_sse(const Mat & bottom_im2col,Mat & top_blob,const Mat & kernel,const Option & opt)15 static void im2col_sgemm_pack8to4_int8_sse(const Mat& bottom_im2col, Mat& top_blob, const Mat& kernel, const Option& opt)
16 {
17 // Mat bottom_im2col(size, maxk, inch, 8u, 8, opt.workspace_allocator);
18
19 const int size = bottom_im2col.w;
20 const int maxk = bottom_im2col.h;
21 const int inch = bottom_im2col.c;
22
23 const int outch = top_blob.c;
24
25 // permute
26 Mat tmp;
27 if (size >= 2)
28 tmp.create(2 * maxk, inch, size / 2 + size % 2, 8u, 8, opt.workspace_allocator);
29 else
30 tmp.create(maxk, inch, size, 8u, 8, opt.workspace_allocator);
31 {
32 int remain_size_start = 0;
33 int nn_size = size >> 1;
34
35 #pragma omp parallel for num_threads(opt.num_threads)
36 for (int ii = 0; ii < nn_size; ii++)
37 {
38 int i = remain_size_start + ii * 2;
39
40 int64_t* tmpptr = tmp.channel(i / 2);
41
42 for (int q = 0; q < inch; q++)
43 {
44 const int64_t* img0 = (const int64_t*)bottom_im2col.channel(q) + i;
45
46 for (int k = 0; k < maxk; k++)
47 {
48 __m128i _v = _mm_loadu_si128((const __m128i*)img0);
49 _mm_storeu_si128((__m128i*)tmpptr, _v);
50 tmpptr += 2;
51 img0 += size;
52 }
53 }
54 }
55
56 remain_size_start += nn_size << 1;
57
58 #pragma omp parallel for num_threads(opt.num_threads)
59 for (int i = remain_size_start; i < size; i++)
60 {
61 int64_t* tmpptr = tmp.channel(i / 2 + i % 2);
62
63 for (int q = 0; q < inch; q++)
64 {
65 const int64_t* img0 = (const int64_t*)bottom_im2col.channel(q) + i;
66
67 for (int k = 0; k < maxk; k++)
68 {
69 tmpptr[0] = img0[0];
70 tmpptr += 1;
71 img0 += size;
72 }
73 }
74 }
75 }
76
77 #pragma omp parallel for num_threads(opt.num_threads)
78 for (int p = 0; p < outch; p++)
79 {
80 int* outptr0 = top_blob.channel(p);
81
82 int i = 0;
83 for (; i + 1 < size; i += 2)
84 {
85 const signed char* tmpptr = tmp.channel(i / 2);
86 const signed char* kptr0 = kernel.channel(p);
87
88 int nn = inch * maxk; // inch always > 0
89
90 __m128i _sum00 = _mm_setzero_si128();
91 __m128i _sum01 = _mm_setzero_si128();
92 __m128i _sum02 = _mm_setzero_si128();
93 __m128i _sum03 = _mm_setzero_si128();
94 __m128i _sum10 = _mm_setzero_si128();
95 __m128i _sum11 = _mm_setzero_si128();
96 __m128i _sum12 = _mm_setzero_si128();
97 __m128i _sum13 = _mm_setzero_si128();
98
99 int j = 0;
100 for (; j < nn; j++)
101 {
102 // TODO use _mm_cvtepi8_epi16 on sse4.1
103 __m128i _val01 = _mm_loadu_si128((const __m128i*)tmpptr);
104 __m128i _extval01 = _mm_cmpgt_epi8(_mm_setzero_si128(), _val01);
105 __m128i _val0 = _mm_unpacklo_epi8(_val01, _extval01);
106 __m128i _val1 = _mm_unpackhi_epi8(_val01, _extval01);
107
108 // TODO use _mm_cvtepi8_epi16 on sse4.1
109 __m128i _w01 = _mm_loadu_si128((const __m128i*)kptr0);
110 __m128i _w23 = _mm_loadu_si128((const __m128i*)(kptr0 + 16));
111 __m128i _extw01 = _mm_cmpgt_epi8(_mm_setzero_si128(), _w01);
112 __m128i _extw23 = _mm_cmpgt_epi8(_mm_setzero_si128(), _w23);
113 __m128i _w0 = _mm_unpacklo_epi8(_w01, _extw01);
114 __m128i _w1 = _mm_unpackhi_epi8(_w01, _extw01);
115 __m128i _w2 = _mm_unpacklo_epi8(_w23, _extw23);
116 __m128i _w3 = _mm_unpackhi_epi8(_w23, _extw23);
117
118 __m128i _sl00 = _mm_mullo_epi16(_val0, _w0);
119 __m128i _sh00 = _mm_mulhi_epi16(_val0, _w0);
120 __m128i _sl01 = _mm_mullo_epi16(_val0, _w1);
121 __m128i _sh01 = _mm_mulhi_epi16(_val0, _w1);
122 __m128i _sl02 = _mm_mullo_epi16(_val0, _w2);
123 __m128i _sh02 = _mm_mulhi_epi16(_val0, _w2);
124 __m128i _sl03 = _mm_mullo_epi16(_val0, _w3);
125 __m128i _sh03 = _mm_mulhi_epi16(_val0, _w3);
126 __m128i _sl10 = _mm_mullo_epi16(_val1, _w0);
127 __m128i _sh10 = _mm_mulhi_epi16(_val1, _w0);
128 __m128i _sl11 = _mm_mullo_epi16(_val1, _w1);
129 __m128i _sh11 = _mm_mulhi_epi16(_val1, _w1);
130 __m128i _sl12 = _mm_mullo_epi16(_val1, _w2);
131 __m128i _sh12 = _mm_mulhi_epi16(_val1, _w2);
132 __m128i _sl13 = _mm_mullo_epi16(_val1, _w3);
133 __m128i _sh13 = _mm_mulhi_epi16(_val1, _w3);
134
135 _sum00 = _mm_add_epi32(_sum00, _mm_unpacklo_epi16(_sl00, _sh00));
136 _sum01 = _mm_add_epi32(_sum01, _mm_unpacklo_epi16(_sl01, _sh01));
137 _sum02 = _mm_add_epi32(_sum02, _mm_unpacklo_epi16(_sl02, _sh02));
138 _sum03 = _mm_add_epi32(_sum03, _mm_unpacklo_epi16(_sl03, _sh03));
139 _sum00 = _mm_add_epi32(_sum00, _mm_unpackhi_epi16(_sl00, _sh00));
140 _sum01 = _mm_add_epi32(_sum01, _mm_unpackhi_epi16(_sl01, _sh01));
141 _sum02 = _mm_add_epi32(_sum02, _mm_unpackhi_epi16(_sl02, _sh02));
142 _sum03 = _mm_add_epi32(_sum03, _mm_unpackhi_epi16(_sl03, _sh03));
143 _sum10 = _mm_add_epi32(_sum10, _mm_unpacklo_epi16(_sl10, _sh10));
144 _sum11 = _mm_add_epi32(_sum11, _mm_unpacklo_epi16(_sl11, _sh11));
145 _sum12 = _mm_add_epi32(_sum12, _mm_unpacklo_epi16(_sl12, _sh12));
146 _sum13 = _mm_add_epi32(_sum13, _mm_unpacklo_epi16(_sl13, _sh13));
147 _sum10 = _mm_add_epi32(_sum10, _mm_unpackhi_epi16(_sl10, _sh10));
148 _sum11 = _mm_add_epi32(_sum11, _mm_unpackhi_epi16(_sl11, _sh11));
149 _sum12 = _mm_add_epi32(_sum12, _mm_unpackhi_epi16(_sl12, _sh12));
150 _sum13 = _mm_add_epi32(_sum13, _mm_unpackhi_epi16(_sl13, _sh13));
151
152 tmpptr += 16;
153 kptr0 += 32;
154 }
155
156 // transpose 4x4
157 {
158 __m128i _tmp0, _tmp1, _tmp2, _tmp3;
159 _tmp0 = _mm_unpacklo_epi32(_sum00, _sum01);
160 _tmp1 = _mm_unpacklo_epi32(_sum02, _sum03);
161 _tmp2 = _mm_unpackhi_epi32(_sum00, _sum01);
162 _tmp3 = _mm_unpackhi_epi32(_sum02, _sum03);
163 _sum00 = _mm_unpacklo_epi64(_tmp0, _tmp1);
164 _sum01 = _mm_unpackhi_epi64(_tmp0, _tmp1);
165 _sum02 = _mm_unpacklo_epi64(_tmp2, _tmp3);
166 _sum03 = _mm_unpackhi_epi64(_tmp2, _tmp3);
167 }
168 {
169 __m128i _tmp0, _tmp1, _tmp2, _tmp3;
170 _tmp0 = _mm_unpacklo_epi32(_sum10, _sum11);
171 _tmp1 = _mm_unpacklo_epi32(_sum12, _sum13);
172 _tmp2 = _mm_unpackhi_epi32(_sum10, _sum11);
173 _tmp3 = _mm_unpackhi_epi32(_sum12, _sum13);
174 _sum10 = _mm_unpacklo_epi64(_tmp0, _tmp1);
175 _sum11 = _mm_unpackhi_epi64(_tmp0, _tmp1);
176 _sum12 = _mm_unpacklo_epi64(_tmp2, _tmp3);
177 _sum13 = _mm_unpackhi_epi64(_tmp2, _tmp3);
178 }
179
180 _sum00 = _mm_add_epi32(_sum00, _sum01);
181 _sum02 = _mm_add_epi32(_sum02, _sum03);
182 _sum10 = _mm_add_epi32(_sum10, _sum11);
183 _sum12 = _mm_add_epi32(_sum12, _sum13);
184
185 _sum00 = _mm_add_epi32(_sum00, _sum02);
186 _sum10 = _mm_add_epi32(_sum10, _sum12);
187
188 _mm_storeu_si128((__m128i*)outptr0, _sum00);
189 _mm_storeu_si128((__m128i*)(outptr0 + 4), _sum10);
190 outptr0 += 8;
191 }
192 for (; i < size; i++)
193 {
194 const signed char* tmpptr = tmp.channel(i / 2 + i % 2);
195 const signed char* kptr0 = kernel.channel(p);
196
197 int nn = inch * maxk; // inch always > 0
198
199 __m128i _sum0 = _mm_setzero_si128();
200 __m128i _sum1 = _mm_setzero_si128();
201 __m128i _sum2 = _mm_setzero_si128();
202 __m128i _sum3 = _mm_setzero_si128();
203
204 int j = 0;
205 for (; j < nn; j++)
206 {
207 // TODO use _mm_cvtepi8_epi16 on sse4.1
208 __m128i _val = _mm_loadl_epi64((const __m128i*)tmpptr);
209 _val = _mm_unpacklo_epi8(_val, _mm_cmpgt_epi8(_mm_setzero_si128(), _val));
210
211 // TODO use _mm_cvtepi8_epi16 on sse4.1
212 __m128i _w01 = _mm_loadu_si128((const __m128i*)kptr0);
213 __m128i _w23 = _mm_loadu_si128((const __m128i*)(kptr0 + 16));
214 __m128i _extw01 = _mm_cmpgt_epi8(_mm_setzero_si128(), _w01);
215 __m128i _extw23 = _mm_cmpgt_epi8(_mm_setzero_si128(), _w23);
216 __m128i _w0 = _mm_unpacklo_epi8(_w01, _extw01);
217 __m128i _w1 = _mm_unpackhi_epi8(_w01, _extw01);
218 __m128i _w2 = _mm_unpacklo_epi8(_w23, _extw23);
219 __m128i _w3 = _mm_unpackhi_epi8(_w23, _extw23);
220
221 __m128i _sl0 = _mm_mullo_epi16(_val, _w0);
222 __m128i _sh0 = _mm_mulhi_epi16(_val, _w0);
223 __m128i _sl1 = _mm_mullo_epi16(_val, _w1);
224 __m128i _sh1 = _mm_mulhi_epi16(_val, _w1);
225 __m128i _sl2 = _mm_mullo_epi16(_val, _w2);
226 __m128i _sh2 = _mm_mulhi_epi16(_val, _w2);
227 __m128i _sl3 = _mm_mullo_epi16(_val, _w3);
228 __m128i _sh3 = _mm_mulhi_epi16(_val, _w3);
229
230 _sum0 = _mm_add_epi32(_sum0, _mm_unpacklo_epi16(_sl0, _sh0));
231 _sum1 = _mm_add_epi32(_sum1, _mm_unpacklo_epi16(_sl1, _sh1));
232 _sum2 = _mm_add_epi32(_sum2, _mm_unpacklo_epi16(_sl2, _sh2));
233 _sum3 = _mm_add_epi32(_sum3, _mm_unpacklo_epi16(_sl3, _sh3));
234 _sum0 = _mm_add_epi32(_sum0, _mm_unpackhi_epi16(_sl0, _sh0));
235 _sum1 = _mm_add_epi32(_sum1, _mm_unpackhi_epi16(_sl1, _sh1));
236 _sum2 = _mm_add_epi32(_sum2, _mm_unpackhi_epi16(_sl2, _sh2));
237 _sum3 = _mm_add_epi32(_sum3, _mm_unpackhi_epi16(_sl3, _sh3));
238
239 tmpptr += 8;
240 kptr0 += 32;
241 }
242
243 // transpose 4x4
244 {
245 __m128i _tmp0, _tmp1, _tmp2, _tmp3;
246 _tmp0 = _mm_unpacklo_epi32(_sum0, _sum1);
247 _tmp1 = _mm_unpacklo_epi32(_sum2, _sum3);
248 _tmp2 = _mm_unpackhi_epi32(_sum0, _sum1);
249 _tmp3 = _mm_unpackhi_epi32(_sum2, _sum3);
250 _sum0 = _mm_unpacklo_epi64(_tmp0, _tmp1);
251 _sum1 = _mm_unpackhi_epi64(_tmp0, _tmp1);
252 _sum2 = _mm_unpacklo_epi64(_tmp2, _tmp3);
253 _sum3 = _mm_unpackhi_epi64(_tmp2, _tmp3);
254 }
255
256 _sum0 = _mm_add_epi32(_sum0, _sum1);
257 _sum2 = _mm_add_epi32(_sum2, _sum3);
258
259 _sum0 = _mm_add_epi32(_sum0, _sum2);
260
261 _mm_storeu_si128((__m128i*)outptr0, _sum0);
262 outptr0 += 4;
263 }
264 }
265 }
266
convolution_im2col_sgemm_transform_kernel_pack8to4_int8_sse(const Mat & _kernel,Mat & kernel_tm,int inch,int outch,int kernel_w,int kernel_h)267 static void convolution_im2col_sgemm_transform_kernel_pack8to4_int8_sse(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h)
268 {
269 const int maxk = kernel_w * kernel_h;
270
271 // interleave
272 // src = maxk-inch-outch
273 // dst = 8a-4b-maxk-inch/8a-outch/4b
274 Mat kernel = _kernel.reshape(maxk, inch, outch);
275 kernel_tm.create(32 * maxk, inch / 8, outch / 4, (size_t)1u);
276
277 for (int q = 0; q + 3 < outch; q += 4)
278 {
279 signed char* g00 = kernel_tm.channel(q / 4);
280
281 for (int p = 0; p + 7 < inch; p += 8)
282 {
283 for (int k = 0; k < maxk; k++)
284 {
285 for (int i = 0; i < 4; i++)
286 {
287 for (int j = 0; j < 8; j++)
288 {
289 const signed char* k00 = kernel.channel(q + i).row<const signed char>(p + j);
290
291 g00[0] = k00[k];
292
293 g00++;
294 }
295 }
296 }
297 }
298 }
299 }
300
convolution_im2col_sgemm_pack8to4_int8_sse(const Mat & bottom_blob,Mat & top_blob,const Mat & kernel,int kernel_w,int kernel_h,int dilation_w,int dilation_h,int stride_w,int stride_h,const Option & opt)301 static void convolution_im2col_sgemm_pack8to4_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt)
302 {
303 int w = bottom_blob.w;
304 int inch = bottom_blob.c;
305
306 int outw = top_blob.w;
307 int outh = top_blob.h;
308 const int size = outw * outh;
309
310 const int maxk = kernel_w * kernel_h;
311
312 // im2col
313 Mat bottom_im2col(size, maxk, inch, 8u, 8, opt.workspace_allocator);
314 {
315 const int gap = w * stride_h - outw * stride_w;
316
317 #pragma omp parallel for num_threads(opt.num_threads)
318 for (int p = 0; p < inch; p++)
319 {
320 const Mat img = bottom_blob.channel(p);
321 int64_t* ptr = bottom_im2col.channel(p);
322
323 for (int u = 0; u < kernel_h; u++)
324 {
325 for (int v = 0; v < kernel_w; v++)
326 {
327 const int64_t* sptr = img.row<const int64_t>(dilation_h * u) + dilation_w * v;
328
329 for (int i = 0; i < outh; i++)
330 {
331 int j = 0;
332 for (; j < outw; j++)
333 {
334 ptr[0] = sptr[0];
335
336 sptr += stride_w;
337 ptr += 1;
338 }
339
340 sptr += gap;
341 }
342 }
343 }
344 }
345 }
346
347 im2col_sgemm_pack8to4_int8_sse(bottom_im2col, top_blob, kernel, opt);
348 }
349