1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
im2col_sgemm_pack8to4_int8_sse(const Mat & bottom_im2col,Mat & top_blob,const Mat & kernel,const Option & opt)15 static void im2col_sgemm_pack8to4_int8_sse(const Mat& bottom_im2col, Mat& top_blob, const Mat& kernel, const Option& opt)
16 {
17     // Mat bottom_im2col(size, maxk, inch, 8u, 8, opt.workspace_allocator);
18 
19     const int size = bottom_im2col.w;
20     const int maxk = bottom_im2col.h;
21     const int inch = bottom_im2col.c;
22 
23     const int outch = top_blob.c;
24 
25     // permute
26     Mat tmp;
27     if (size >= 2)
28         tmp.create(2 * maxk, inch, size / 2 + size % 2, 8u, 8, opt.workspace_allocator);
29     else
30         tmp.create(maxk, inch, size, 8u, 8, opt.workspace_allocator);
31     {
32         int remain_size_start = 0;
33         int nn_size = size >> 1;
34 
35         #pragma omp parallel for num_threads(opt.num_threads)
36         for (int ii = 0; ii < nn_size; ii++)
37         {
38             int i = remain_size_start + ii * 2;
39 
40             int64_t* tmpptr = tmp.channel(i / 2);
41 
42             for (int q = 0; q < inch; q++)
43             {
44                 const int64_t* img0 = (const int64_t*)bottom_im2col.channel(q) + i;
45 
46                 for (int k = 0; k < maxk; k++)
47                 {
48                     __m128i _v = _mm_loadu_si128((const __m128i*)img0);
49                     _mm_storeu_si128((__m128i*)tmpptr, _v);
50                     tmpptr += 2;
51                     img0 += size;
52                 }
53             }
54         }
55 
56         remain_size_start += nn_size << 1;
57 
58         #pragma omp parallel for num_threads(opt.num_threads)
59         for (int i = remain_size_start; i < size; i++)
60         {
61             int64_t* tmpptr = tmp.channel(i / 2 + i % 2);
62 
63             for (int q = 0; q < inch; q++)
64             {
65                 const int64_t* img0 = (const int64_t*)bottom_im2col.channel(q) + i;
66 
67                 for (int k = 0; k < maxk; k++)
68                 {
69                     tmpptr[0] = img0[0];
70                     tmpptr += 1;
71                     img0 += size;
72                 }
73             }
74         }
75     }
76 
77     #pragma omp parallel for num_threads(opt.num_threads)
78     for (int p = 0; p < outch; p++)
79     {
80         int* outptr0 = top_blob.channel(p);
81 
82         int i = 0;
83         for (; i + 1 < size; i += 2)
84         {
85             const signed char* tmpptr = tmp.channel(i / 2);
86             const signed char* kptr0 = kernel.channel(p);
87 
88             int nn = inch * maxk; // inch always > 0
89 
90             __m128i _sum00 = _mm_setzero_si128();
91             __m128i _sum01 = _mm_setzero_si128();
92             __m128i _sum02 = _mm_setzero_si128();
93             __m128i _sum03 = _mm_setzero_si128();
94             __m128i _sum10 = _mm_setzero_si128();
95             __m128i _sum11 = _mm_setzero_si128();
96             __m128i _sum12 = _mm_setzero_si128();
97             __m128i _sum13 = _mm_setzero_si128();
98 
99             int j = 0;
100             for (; j < nn; j++)
101             {
102                 // TODO use _mm_cvtepi8_epi16 on sse4.1
103                 __m128i _val01 = _mm_loadu_si128((const __m128i*)tmpptr);
104                 __m128i _extval01 = _mm_cmpgt_epi8(_mm_setzero_si128(), _val01);
105                 __m128i _val0 = _mm_unpacklo_epi8(_val01, _extval01);
106                 __m128i _val1 = _mm_unpackhi_epi8(_val01, _extval01);
107 
108                 // TODO use _mm_cvtepi8_epi16 on sse4.1
109                 __m128i _w01 = _mm_loadu_si128((const __m128i*)kptr0);
110                 __m128i _w23 = _mm_loadu_si128((const __m128i*)(kptr0 + 16));
111                 __m128i _extw01 = _mm_cmpgt_epi8(_mm_setzero_si128(), _w01);
112                 __m128i _extw23 = _mm_cmpgt_epi8(_mm_setzero_si128(), _w23);
113                 __m128i _w0 = _mm_unpacklo_epi8(_w01, _extw01);
114                 __m128i _w1 = _mm_unpackhi_epi8(_w01, _extw01);
115                 __m128i _w2 = _mm_unpacklo_epi8(_w23, _extw23);
116                 __m128i _w3 = _mm_unpackhi_epi8(_w23, _extw23);
117 
118                 __m128i _sl00 = _mm_mullo_epi16(_val0, _w0);
119                 __m128i _sh00 = _mm_mulhi_epi16(_val0, _w0);
120                 __m128i _sl01 = _mm_mullo_epi16(_val0, _w1);
121                 __m128i _sh01 = _mm_mulhi_epi16(_val0, _w1);
122                 __m128i _sl02 = _mm_mullo_epi16(_val0, _w2);
123                 __m128i _sh02 = _mm_mulhi_epi16(_val0, _w2);
124                 __m128i _sl03 = _mm_mullo_epi16(_val0, _w3);
125                 __m128i _sh03 = _mm_mulhi_epi16(_val0, _w3);
126                 __m128i _sl10 = _mm_mullo_epi16(_val1, _w0);
127                 __m128i _sh10 = _mm_mulhi_epi16(_val1, _w0);
128                 __m128i _sl11 = _mm_mullo_epi16(_val1, _w1);
129                 __m128i _sh11 = _mm_mulhi_epi16(_val1, _w1);
130                 __m128i _sl12 = _mm_mullo_epi16(_val1, _w2);
131                 __m128i _sh12 = _mm_mulhi_epi16(_val1, _w2);
132                 __m128i _sl13 = _mm_mullo_epi16(_val1, _w3);
133                 __m128i _sh13 = _mm_mulhi_epi16(_val1, _w3);
134 
135                 _sum00 = _mm_add_epi32(_sum00, _mm_unpacklo_epi16(_sl00, _sh00));
136                 _sum01 = _mm_add_epi32(_sum01, _mm_unpacklo_epi16(_sl01, _sh01));
137                 _sum02 = _mm_add_epi32(_sum02, _mm_unpacklo_epi16(_sl02, _sh02));
138                 _sum03 = _mm_add_epi32(_sum03, _mm_unpacklo_epi16(_sl03, _sh03));
139                 _sum00 = _mm_add_epi32(_sum00, _mm_unpackhi_epi16(_sl00, _sh00));
140                 _sum01 = _mm_add_epi32(_sum01, _mm_unpackhi_epi16(_sl01, _sh01));
141                 _sum02 = _mm_add_epi32(_sum02, _mm_unpackhi_epi16(_sl02, _sh02));
142                 _sum03 = _mm_add_epi32(_sum03, _mm_unpackhi_epi16(_sl03, _sh03));
143                 _sum10 = _mm_add_epi32(_sum10, _mm_unpacklo_epi16(_sl10, _sh10));
144                 _sum11 = _mm_add_epi32(_sum11, _mm_unpacklo_epi16(_sl11, _sh11));
145                 _sum12 = _mm_add_epi32(_sum12, _mm_unpacklo_epi16(_sl12, _sh12));
146                 _sum13 = _mm_add_epi32(_sum13, _mm_unpacklo_epi16(_sl13, _sh13));
147                 _sum10 = _mm_add_epi32(_sum10, _mm_unpackhi_epi16(_sl10, _sh10));
148                 _sum11 = _mm_add_epi32(_sum11, _mm_unpackhi_epi16(_sl11, _sh11));
149                 _sum12 = _mm_add_epi32(_sum12, _mm_unpackhi_epi16(_sl12, _sh12));
150                 _sum13 = _mm_add_epi32(_sum13, _mm_unpackhi_epi16(_sl13, _sh13));
151 
152                 tmpptr += 16;
153                 kptr0 += 32;
154             }
155 
156             // transpose 4x4
157             {
158                 __m128i _tmp0, _tmp1, _tmp2, _tmp3;
159                 _tmp0 = _mm_unpacklo_epi32(_sum00, _sum01);
160                 _tmp1 = _mm_unpacklo_epi32(_sum02, _sum03);
161                 _tmp2 = _mm_unpackhi_epi32(_sum00, _sum01);
162                 _tmp3 = _mm_unpackhi_epi32(_sum02, _sum03);
163                 _sum00 = _mm_unpacklo_epi64(_tmp0, _tmp1);
164                 _sum01 = _mm_unpackhi_epi64(_tmp0, _tmp1);
165                 _sum02 = _mm_unpacklo_epi64(_tmp2, _tmp3);
166                 _sum03 = _mm_unpackhi_epi64(_tmp2, _tmp3);
167             }
168             {
169                 __m128i _tmp0, _tmp1, _tmp2, _tmp3;
170                 _tmp0 = _mm_unpacklo_epi32(_sum10, _sum11);
171                 _tmp1 = _mm_unpacklo_epi32(_sum12, _sum13);
172                 _tmp2 = _mm_unpackhi_epi32(_sum10, _sum11);
173                 _tmp3 = _mm_unpackhi_epi32(_sum12, _sum13);
174                 _sum10 = _mm_unpacklo_epi64(_tmp0, _tmp1);
175                 _sum11 = _mm_unpackhi_epi64(_tmp0, _tmp1);
176                 _sum12 = _mm_unpacklo_epi64(_tmp2, _tmp3);
177                 _sum13 = _mm_unpackhi_epi64(_tmp2, _tmp3);
178             }
179 
180             _sum00 = _mm_add_epi32(_sum00, _sum01);
181             _sum02 = _mm_add_epi32(_sum02, _sum03);
182             _sum10 = _mm_add_epi32(_sum10, _sum11);
183             _sum12 = _mm_add_epi32(_sum12, _sum13);
184 
185             _sum00 = _mm_add_epi32(_sum00, _sum02);
186             _sum10 = _mm_add_epi32(_sum10, _sum12);
187 
188             _mm_storeu_si128((__m128i*)outptr0, _sum00);
189             _mm_storeu_si128((__m128i*)(outptr0 + 4), _sum10);
190             outptr0 += 8;
191         }
192         for (; i < size; i++)
193         {
194             const signed char* tmpptr = tmp.channel(i / 2 + i % 2);
195             const signed char* kptr0 = kernel.channel(p);
196 
197             int nn = inch * maxk; // inch always > 0
198 
199             __m128i _sum0 = _mm_setzero_si128();
200             __m128i _sum1 = _mm_setzero_si128();
201             __m128i _sum2 = _mm_setzero_si128();
202             __m128i _sum3 = _mm_setzero_si128();
203 
204             int j = 0;
205             for (; j < nn; j++)
206             {
207                 // TODO use _mm_cvtepi8_epi16 on sse4.1
208                 __m128i _val = _mm_loadl_epi64((const __m128i*)tmpptr);
209                 _val = _mm_unpacklo_epi8(_val, _mm_cmpgt_epi8(_mm_setzero_si128(), _val));
210 
211                 // TODO use _mm_cvtepi8_epi16 on sse4.1
212                 __m128i _w01 = _mm_loadu_si128((const __m128i*)kptr0);
213                 __m128i _w23 = _mm_loadu_si128((const __m128i*)(kptr0 + 16));
214                 __m128i _extw01 = _mm_cmpgt_epi8(_mm_setzero_si128(), _w01);
215                 __m128i _extw23 = _mm_cmpgt_epi8(_mm_setzero_si128(), _w23);
216                 __m128i _w0 = _mm_unpacklo_epi8(_w01, _extw01);
217                 __m128i _w1 = _mm_unpackhi_epi8(_w01, _extw01);
218                 __m128i _w2 = _mm_unpacklo_epi8(_w23, _extw23);
219                 __m128i _w3 = _mm_unpackhi_epi8(_w23, _extw23);
220 
221                 __m128i _sl0 = _mm_mullo_epi16(_val, _w0);
222                 __m128i _sh0 = _mm_mulhi_epi16(_val, _w0);
223                 __m128i _sl1 = _mm_mullo_epi16(_val, _w1);
224                 __m128i _sh1 = _mm_mulhi_epi16(_val, _w1);
225                 __m128i _sl2 = _mm_mullo_epi16(_val, _w2);
226                 __m128i _sh2 = _mm_mulhi_epi16(_val, _w2);
227                 __m128i _sl3 = _mm_mullo_epi16(_val, _w3);
228                 __m128i _sh3 = _mm_mulhi_epi16(_val, _w3);
229 
230                 _sum0 = _mm_add_epi32(_sum0, _mm_unpacklo_epi16(_sl0, _sh0));
231                 _sum1 = _mm_add_epi32(_sum1, _mm_unpacklo_epi16(_sl1, _sh1));
232                 _sum2 = _mm_add_epi32(_sum2, _mm_unpacklo_epi16(_sl2, _sh2));
233                 _sum3 = _mm_add_epi32(_sum3, _mm_unpacklo_epi16(_sl3, _sh3));
234                 _sum0 = _mm_add_epi32(_sum0, _mm_unpackhi_epi16(_sl0, _sh0));
235                 _sum1 = _mm_add_epi32(_sum1, _mm_unpackhi_epi16(_sl1, _sh1));
236                 _sum2 = _mm_add_epi32(_sum2, _mm_unpackhi_epi16(_sl2, _sh2));
237                 _sum3 = _mm_add_epi32(_sum3, _mm_unpackhi_epi16(_sl3, _sh3));
238 
239                 tmpptr += 8;
240                 kptr0 += 32;
241             }
242 
243             // transpose 4x4
244             {
245                 __m128i _tmp0, _tmp1, _tmp2, _tmp3;
246                 _tmp0 = _mm_unpacklo_epi32(_sum0, _sum1);
247                 _tmp1 = _mm_unpacklo_epi32(_sum2, _sum3);
248                 _tmp2 = _mm_unpackhi_epi32(_sum0, _sum1);
249                 _tmp3 = _mm_unpackhi_epi32(_sum2, _sum3);
250                 _sum0 = _mm_unpacklo_epi64(_tmp0, _tmp1);
251                 _sum1 = _mm_unpackhi_epi64(_tmp0, _tmp1);
252                 _sum2 = _mm_unpacklo_epi64(_tmp2, _tmp3);
253                 _sum3 = _mm_unpackhi_epi64(_tmp2, _tmp3);
254             }
255 
256             _sum0 = _mm_add_epi32(_sum0, _sum1);
257             _sum2 = _mm_add_epi32(_sum2, _sum3);
258 
259             _sum0 = _mm_add_epi32(_sum0, _sum2);
260 
261             _mm_storeu_si128((__m128i*)outptr0, _sum0);
262             outptr0 += 4;
263         }
264     }
265 }
266 
convolution_im2col_sgemm_transform_kernel_pack8to4_int8_sse(const Mat & _kernel,Mat & kernel_tm,int inch,int outch,int kernel_w,int kernel_h)267 static void convolution_im2col_sgemm_transform_kernel_pack8to4_int8_sse(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h)
268 {
269     const int maxk = kernel_w * kernel_h;
270 
271     // interleave
272     // src = maxk-inch-outch
273     // dst = 8a-4b-maxk-inch/8a-outch/4b
274     Mat kernel = _kernel.reshape(maxk, inch, outch);
275     kernel_tm.create(32 * maxk, inch / 8, outch / 4, (size_t)1u);
276 
277     for (int q = 0; q + 3 < outch; q += 4)
278     {
279         signed char* g00 = kernel_tm.channel(q / 4);
280 
281         for (int p = 0; p + 7 < inch; p += 8)
282         {
283             for (int k = 0; k < maxk; k++)
284             {
285                 for (int i = 0; i < 4; i++)
286                 {
287                     for (int j = 0; j < 8; j++)
288                     {
289                         const signed char* k00 = kernel.channel(q + i).row<const signed char>(p + j);
290 
291                         g00[0] = k00[k];
292 
293                         g00++;
294                     }
295                 }
296             }
297         }
298     }
299 }
300 
convolution_im2col_sgemm_pack8to4_int8_sse(const Mat & bottom_blob,Mat & top_blob,const Mat & kernel,int kernel_w,int kernel_h,int dilation_w,int dilation_h,int stride_w,int stride_h,const Option & opt)301 static void convolution_im2col_sgemm_pack8to4_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt)
302 {
303     int w = bottom_blob.w;
304     int inch = bottom_blob.c;
305 
306     int outw = top_blob.w;
307     int outh = top_blob.h;
308     const int size = outw * outh;
309 
310     const int maxk = kernel_w * kernel_h;
311 
312     // im2col
313     Mat bottom_im2col(size, maxk, inch, 8u, 8, opt.workspace_allocator);
314     {
315         const int gap = w * stride_h - outw * stride_w;
316 
317         #pragma omp parallel for num_threads(opt.num_threads)
318         for (int p = 0; p < inch; p++)
319         {
320             const Mat img = bottom_blob.channel(p);
321             int64_t* ptr = bottom_im2col.channel(p);
322 
323             for (int u = 0; u < kernel_h; u++)
324             {
325                 for (int v = 0; v < kernel_w; v++)
326                 {
327                     const int64_t* sptr = img.row<const int64_t>(dilation_h * u) + dilation_w * v;
328 
329                     for (int i = 0; i < outh; i++)
330                     {
331                         int j = 0;
332                         for (; j < outw; j++)
333                         {
334                             ptr[0] = sptr[0];
335 
336                             sptr += stride_w;
337                             ptr += 1;
338                         }
339 
340                         sptr += gap;
341                     }
342                 }
343             }
344         }
345     }
346 
347     im2col_sgemm_pack8to4_int8_sse(bottom_im2col, top_blob, kernel, opt);
348 }
349