// BUG1989 is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2019 BUG1989. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

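// conv_im2col_sgemm_int8_sse: im2row the int8 input, pack the input rows and the
// kernel into 4x4 tiles, then run an int8 GEMM whose int32 accumulators are written
// directly to top_blob (no dequantization).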
static void conv_im2col_sgemm_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& _kernel,
                                       const int kernel_w, const int kernel_h, const int stride_w, const int stride_h, const Option& opt)
{
    int w = bottom_blob.w;
    int inch = bottom_blob.c;

    int outw = top_blob.w;
    int outh = top_blob.h;
    int outch = top_blob.c;

    const signed char* kernel = _kernel;

    // im2row: one row per output position (outw * outh rows), each row holding the
    // inch * kernel_h * kernel_w int8 values under that sliding window
    Mat bottom_im2row(kernel_h * kernel_w * inch, outw * outh, 1UL, opt.workspace_allocator);
    {
        signed char* ret = (signed char*)bottom_im2row;
        int retID = 0;

        for (int i = 0; i < outh; i++)
        {
            for (int j = 0; j < outw; j++)
            {
                for (int p = 0; p < inch; p++)
                {
                    const signed char* input = bottom_blob.channel(p);
                    for (int u = 0; u < kernel_h; u++)
                    {
                        for (int v = 0; v < kernel_w; v++)
                        {
                            int row = u + i * stride_h;
                            int col = v + j * stride_w;
                            int index = row * w + col;
                            ret[retID] = input[index];
                            retID++;
                        }
                    }
                }
            }
        }
    }

    int kernel_size = kernel_w * kernel_h;
    int out_size = outw * outh;

    // int M = outch;  // outch
    int N = outw * outh;                // outsize or out stride
    int K = kernel_w * kernel_h * inch; // ksize * inch

    // pack bottom_im2row into tiles of 4 output positions, interleaving pairs along K,
    // so the 4x4 GEMM micro-kernel below can read contiguously; leftover positions get
    // one channel each
    Mat bottom_tm(4 * kernel_size, inch, out_size / 4 + out_size % 4, (size_t)1u, opt.workspace_allocator);
    {
        int nn_size = out_size >> 2;
        int remain_size_start = nn_size << 2;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int ii = 0; ii < nn_size; ii++)
        {
            int i = ii * 4;

            const signed char* img0 = bottom_im2row.row<signed char>(i);
            const signed char* img1 = bottom_im2row.row<signed char>(i + 1);
            const signed char* img2 = bottom_im2row.row<signed char>(i + 2);
            const signed char* img3 = bottom_im2row.row<signed char>(i + 3);

            signed char* tmpptr = bottom_tm.channel(i / 4);

            int q = 0;
            for (; q + 1 < inch * kernel_size; q = q + 2)
            {
                tmpptr[0] = img0[0];
                tmpptr[1] = img0[1];
                tmpptr[2] = img1[0];
                tmpptr[3] = img1[1];
                tmpptr[4] = img2[0];
                tmpptr[5] = img2[1];
                tmpptr[6] = img3[0];
                tmpptr[7] = img3[1];

                tmpptr += 8;
                img0 += 2;
                img1 += 2;
                img2 += 2;
                img3 += 2;
            }

            for (; q < inch * kernel_size; q++)
            {
                tmpptr[0] = img0[0];
                tmpptr[1] = img1[0];
                tmpptr[2] = img2[0];
                tmpptr[3] = img3[0];

                tmpptr += 4;
                img0 += 1;
                img1 += 1;
                img2 += 1;
                img3 += 1;
            }
        }

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = remain_size_start; i < out_size; i++)
        {
            const signed char* img0 = bottom_im2row.row<signed char>(i);

            signed char* tmpptr = bottom_tm.channel(i / 4 + i % 4);

            int q = 0;
            for (; q + 1 < inch * kernel_size; q = q + 2)
            {
                tmpptr[0] = img0[0];
                tmpptr[1] = img0[1];

                tmpptr += 2;
                img0 += 2;
            }

            for (; q < inch * kernel_size; q++)
            {
                tmpptr[0] = img0[0];

                tmpptr += 1;
                img0 += 1;
            }
        }
    }

    // pack the kernel into tiles of 4 output channels with the same pair-interleaved
    // layout along K; leftover output channels get one channel each
    Mat kernel_tm(4 * kernel_size, inch, outch / 4 + outch % 4, (size_t)1u, opt.workspace_allocator);
    {
        int nn_outch = 0;
        int remain_outch_start = 0;

        nn_outch = outch >> 2;
        remain_outch_start = nn_outch << 2;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int pp = 0; pp < nn_outch; pp++)
        {
            int p = pp * 4;

            const signed char* k0 = kernel + (p + 0) * inch * kernel_size;
            const signed char* k1 = kernel + (p + 1) * inch * kernel_size;
            const signed char* k2 = kernel + (p + 2) * inch * kernel_size;
            const signed char* k3 = kernel + (p + 3) * inch * kernel_size;

            signed char* ktmp = kernel_tm.channel(p / 4);

            int q = 0;
            for (; q + 1 < inch * kernel_size; q += 2)
            {
                ktmp[0] = k0[0];
                ktmp[1] = k0[1];
                ktmp[2] = k1[0];
                ktmp[3] = k1[1];
                ktmp[4] = k2[0];
                ktmp[5] = k2[1];
                ktmp[6] = k3[0];
                ktmp[7] = k3[1];

                ktmp += 8;

                k0 += 2;
                k1 += 2;
                k2 += 2;
                k3 += 2;
            }

            for (; q < inch * kernel_size; q++)
            {
                ktmp[0] = k0[0];
                ktmp[1] = k1[0];
                ktmp[2] = k2[0];
                ktmp[3] = k3[0];
                ktmp += 4;

                k0 += 1;
                k1 += 1;
                k2 += 1;
                k3 += 1;
            }
        }

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = remain_outch_start; p < outch; p++)
        {
            const signed char* k0 = kernel + (p + 0) * inch * kernel_size;

            signed char* ktmp = kernel_tm.channel(p / 4 + p % 4);

            int q = 0;
            for (; q + 1 < inch * kernel_size; q = q + 2)
            {
                ktmp[0] = k0[0];
                ktmp[1] = k0[1];
                ktmp += 2;
                k0 += 2;
            }

            for (; q < inch * kernel_size; q++)
            {
                ktmp[0] = k0[0];
                ktmp++;
                k0++;
            }
        }
    }

    // 4x4 int8 GEMM micro-kernel: 4 output channels x 4 output positions per iteration,
    // accumulated in int32 and stored to top_blob as-is
    // sgemm(int M, int N, int K, float* A, float* B, float* C)
    {
        // int M = outch;  // outch
        // int N = outw * outh; // outsize or out stride
        // int L = kernel_w * kernel_h * inch; // ksize * inch

        int nn_outch = 0;
        int remain_outch_start = 0;

        nn_outch = outch >> 2;
        remain_outch_start = nn_outch << 2;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int pp = 0; pp < nn_outch; pp++)
        {
            int i = pp * 4;

            int* output0 = top_blob.channel(i);
            int* output1 = top_blob.channel(i + 1);
            int* output2 = top_blob.channel(i + 2);
            int* output3 = top_blob.channel(i + 3);

            int j = 0;
            for (; j + 3 < N; j = j + 4)
            {
                signed char* vb = bottom_tm.channel(j / 4);
                signed char* va = kernel_tm.channel(i / 4);

                int sum0[4] = {0};
                int sum1[4] = {0};
                int sum2[4] = {0};
                int sum3[4] = {0};

                int k = 0;

                for (; k + 1 < K; k = k + 2)
                {
                    for (int n = 0; n < 4; n++)
                    {
                        sum0[n] += (int)va[0] * vb[2 * n]; // k0
                        sum0[n] += (int)va[1] * vb[2 * n + 1];

                        sum1[n] += (int)va[2] * vb[2 * n]; // k1
                        sum1[n] += (int)va[3] * vb[2 * n + 1];

                        sum2[n] += (int)va[4] * vb[2 * n]; // k2
                        sum2[n] += (int)va[5] * vb[2 * n + 1];

                        sum3[n] += (int)va[6] * vb[2 * n]; // k3
                        sum3[n] += (int)va[7] * vb[2 * n + 1];
                    }

                    va += 8;
                    vb += 8;
                }

                for (; k < K; k++)
                {
                    for (int n = 0; n < 4; n++)
                    {
                        sum0[n] += (int)va[0] * vb[n];
                        sum1[n] += (int)va[1] * vb[n];
                        sum2[n] += (int)va[2] * vb[n];
                        sum3[n] += (int)va[3] * vb[n];
                    }

                    va += 4;
                    vb += 4;
                }

                for (int n = 0; n < 4; n++)
                {
                    output0[n] = sum0[n];
                    output1[n] = sum1[n];
                    output2[n] = sum2[n];
                    output3[n] = sum3[n];
                }
                output0 += 4;
                output1 += 4;
                output2 += 4;
                output3 += 4;
            }

            for (; j < N; j++)
            {
                int sum0 = 0;
                int sum1 = 0;
                int sum2 = 0;
                int sum3 = 0;

                signed char* vb = bottom_tm.channel(j / 4 + j % 4);
                signed char* va = kernel_tm.channel(i / 4);

                int k = 0;

                for (; k + 1 < K; k = k + 2)
                {
                    sum0 += (int)va[0] * vb[0];
                    sum0 += (int)va[1] * vb[1];

                    sum1 += (int)va[2] * vb[0];
                    sum1 += (int)va[3] * vb[1];

                    sum2 += (int)va[4] * vb[0];
                    sum2 += (int)va[5] * vb[1];

                    sum3 += (int)va[6] * vb[0];
                    sum3 += (int)va[7] * vb[1];

                    va += 8;
                    vb += 2;
                }

                for (; k < K; k++)
                {
                    sum0 += (int)va[0] * vb[0];
                    sum1 += (int)va[1] * vb[0];
                    sum2 += (int)va[2] * vb[0];
                    sum3 += (int)va[3] * vb[0];

                    va += 4;
                    vb += 1;
                }

                output0[0] = sum0;
                output1[0] = sum1;
                output2[0] = sum2;
                output3[0] = sum3;

                output0++;
                output1++;
                output2++;
                output3++;
            }
        }

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = remain_outch_start; i < outch; i++)
        {
            int* output = top_blob.channel(i);

            int j = 0;
            for (; j + 3 < N; j = j + 4)
            {
                signed char* vb = bottom_tm.channel(j / 4);
                signed char* va = kernel_tm.channel(i / 4 + i % 4);
                int sum[4] = {0};

                int k = 0;
                for (; k + 1 < K; k = k + 2)
                {
                    for (int n = 0; n < 4; n++)
                    {
                        sum[n] += (int)va[0] * vb[2 * n];
                        sum[n] += (int)va[1] * vb[2 * n + 1];
                    }
                    va += 2;
                    vb += 8;
                }

                for (; k < K; k++)
                {
                    for (int n = 0; n < 4; n++)
                    {
                        sum[n] += (int)va[0] * vb[n];
                    }
                    va += 1;
                    vb += 4;
                }

                for (int n = 0; n < 4; n++)
                {
                    output[n] = sum[n];
                }
                output += 4;
            }

            for (; j < N; j++)
            {
                int sum = 0;

                signed char* vb = bottom_tm.channel(j / 4 + j % 4);
                signed char* va = kernel_tm.channel(i / 4 + i % 4);

                for (int k = 0; k < K; k++)
                {
                    sum += (int)va[0] * vb[0];

                    va += 1;
                    vb += 1;
                }
                output[0] = sum;

                output++;
            }
        }
    }

    // // sgemm(int M, int N, int K, float* A, float* B, float* C)
    // {
    //     for (int i=0; i<M; i++)
    //     {
    //         int* output = top_blob.channel(i);

    //         for (int j=0; j<N; j++)
    //         {
    //             int sum = 0;

    //             signed char* vb = (signed char*)bottom_im2row + K * j;
    //             const signed char* va = kernel + K * i;

    //             for (int k=0; k<K; k++)
    //             {
    //                 sum += (int)va[0] * vb[0];

    //                 va += 1;
    //                 vb += 1;
    //             }
    //             output[0] = sum;

    //             output++;
    //         }
    //     }
    // }
}

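// conv_im2col_sgemm_int8_dequant_sse: same im2row + packed 4x4 int8 GEMM as above,
// but each int32 accumulator is dequantized to float on store:
// output = sum * scale_dequant[out_channel] + bias[out_channel].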
static void conv_im2col_sgemm_int8_dequant_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& _kernel,
        const int kernel_w, const int kernel_h, const int stride_w, const int stride_h, const Mat& _bias, std::vector<float> scale_dequant, const Option& opt)
{
    int w = bottom_blob.w;
    int inch = bottom_blob.c;

    int outw = top_blob.w;
    int outh = top_blob.h;
    int outch = top_blob.c;

    const signed char* kernel = _kernel;
    const float* bias = _bias;

    // im2row
    Mat bottom_im2row(kernel_h * kernel_w * inch, outw * outh, 1UL, opt.workspace_allocator);
    {
        signed char* ret = (signed char*)bottom_im2row;
        int retID = 0;

        for (int i = 0; i < outh; i++)
        {
            for (int j = 0; j < outw; j++)
            {
                for (int p = 0; p < inch; p++)
                {
                    const signed char* input = bottom_blob.channel(p);
                    for (int u = 0; u < kernel_h; u++)
                    {
                        for (int v = 0; v < kernel_w; v++)
                        {
                            int row = u + i * stride_h;
                            int col = v + j * stride_w;
                            int index = row * w + col;
                            ret[retID] = input[index];
                            retID++;
                        }
                    }
                }
            }
        }
    }

    int kernel_size = kernel_w * kernel_h;
    int out_size = outw * outh;

    // int M = outch;  // outch
    int N = outw * outh;                // outsize or out stride
    int K = kernel_w * kernel_h * inch; // ksize * inch

    // bottom_im2row memory packed 4 x 4
    Mat bottom_tm(4 * kernel_size, inch, out_size / 4 + out_size % 4, (size_t)1u, opt.workspace_allocator);
    {
        int nn_size = out_size >> 2;
        int remain_size_start = nn_size << 2;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int ii = 0; ii < nn_size; ii++)
        {
            int i = ii * 4;

            const signed char* img0 = bottom_im2row.row<signed char>(i);
            const signed char* img1 = bottom_im2row.row<signed char>(i + 1);
            const signed char* img2 = bottom_im2row.row<signed char>(i + 2);
            const signed char* img3 = bottom_im2row.row<signed char>(i + 3);

            signed char* tmpptr = bottom_tm.channel(i / 4);

            int q = 0;
            for (; q + 1 < inch * kernel_size; q = q + 2)
            {
                tmpptr[0] = img0[0];
                tmpptr[1] = img0[1];
                tmpptr[2] = img1[0];
                tmpptr[3] = img1[1];
                tmpptr[4] = img2[0];
                tmpptr[5] = img2[1];
                tmpptr[6] = img3[0];
                tmpptr[7] = img3[1];

                tmpptr += 8;
                img0 += 2;
                img1 += 2;
                img2 += 2;
                img3 += 2;
            }

            for (; q < inch * kernel_size; q++)
            {
                tmpptr[0] = img0[0];
                tmpptr[1] = img1[0];
                tmpptr[2] = img2[0];
                tmpptr[3] = img3[0];

                tmpptr += 4;
                img0 += 1;
                img1 += 1;
                img2 += 1;
                img3 += 1;
            }
        }

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = remain_size_start; i < out_size; i++)
        {
            const signed char* img0 = bottom_im2row.row<signed char>(i);

            signed char* tmpptr = bottom_tm.channel(i / 4 + i % 4);

            int q = 0;
            for (; q + 1 < inch * kernel_size; q = q + 2)
            {
                tmpptr[0] = img0[0];
                tmpptr[1] = img0[1];

                tmpptr += 2;
                img0 += 2;
            }

            for (; q < inch * kernel_size; q++)
            {
                tmpptr[0] = img0[0];

                tmpptr += 1;
                img0 += 1;
            }
        }
    }

    // kernel memory packed 4 x 4
    Mat kernel_tm(4 * kernel_size, inch, outch / 4 + outch % 4, (size_t)1u, opt.workspace_allocator);
    {
        int nn_outch = 0;
        int remain_outch_start = 0;

        nn_outch = outch >> 2;
        remain_outch_start = nn_outch << 2;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int pp = 0; pp < nn_outch; pp++)
        {
            int p = pp * 4;

            const signed char* k0 = kernel + (p + 0) * inch * kernel_size;
            const signed char* k1 = kernel + (p + 1) * inch * kernel_size;
            const signed char* k2 = kernel + (p + 2) * inch * kernel_size;
            const signed char* k3 = kernel + (p + 3) * inch * kernel_size;

            signed char* ktmp = kernel_tm.channel(p / 4);

            int q = 0;
            for (; q + 1 < inch * kernel_size; q += 2)
            {
                ktmp[0] = k0[0];
                ktmp[1] = k0[1];
                ktmp[2] = k1[0];
                ktmp[3] = k1[1];
                ktmp[4] = k2[0];
                ktmp[5] = k2[1];
                ktmp[6] = k3[0];
                ktmp[7] = k3[1];

                ktmp += 8;

                k0 += 2;
                k1 += 2;
                k2 += 2;
                k3 += 2;
            }

            for (; q < inch * kernel_size; q++)
            {
                ktmp[0] = k0[0];
                ktmp[1] = k1[0];
                ktmp[2] = k2[0];
                ktmp[3] = k3[0];
                ktmp += 4;

                k0 += 1;
                k1 += 1;
                k2 += 1;
                k3 += 1;
            }
        }

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = remain_outch_start; p < outch; p++)
        {
            const signed char* k0 = kernel + (p + 0) * inch * kernel_size;

            signed char* ktmp = kernel_tm.channel(p / 4 + p % 4);

            int q = 0;
            for (; q + 1 < inch * kernel_size; q = q + 2)
            {
                ktmp[0] = k0[0];
                ktmp[1] = k0[1];
                ktmp += 2;
                k0 += 2;
            }

            for (; q < inch * kernel_size; q++)
            {
                ktmp[0] = k0[0];
                ktmp++;
                k0++;
            }
        }
    }

    // 4x4
    // sgemm(int M, int N, int K, float* A, float* B, float* C)
    {
        // int M = outch;  // outch
        // int N = outw * outh; // outsize or out stride
        // int L = kernel_w * kernel_h * inch; // ksize * inch

        int nn_outch = 0;
        int remain_outch_start = 0;

        nn_outch = outch >> 2;
        remain_outch_start = nn_outch << 2;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int pp = 0; pp < nn_outch; pp++)
        {
            int i = pp * 4;

            const float bias0 = bias ? bias[i] : 0.f;
            const float bias1 = bias ? bias[i + 1] : 0.f;
            const float bias2 = bias ? bias[i + 2] : 0.f;
            const float bias3 = bias ? bias[i + 3] : 0.f;

            const float scale_dequant0 = scale_dequant[i];
            const float scale_dequant1 = scale_dequant[i + 1];
            const float scale_dequant2 = scale_dequant[i + 2];
            const float scale_dequant3 = scale_dequant[i + 3];

            float* output0 = top_blob.channel(i);
            float* output1 = top_blob.channel(i + 1);
            float* output2 = top_blob.channel(i + 2);
            float* output3 = top_blob.channel(i + 3);

            int j = 0;
            for (; j + 3 < N; j = j + 4)
            {
                signed char* vb = bottom_tm.channel(j / 4);
                signed char* va = kernel_tm.channel(i / 4);

                int sum0[4] = {0};
                int sum1[4] = {0};
                int sum2[4] = {0};
                int sum3[4] = {0};

                int k = 0;

                for (; k + 1 < K; k = k + 2)
                {
                    for (int n = 0; n < 4; n++)
                    {
                        sum0[n] += (int)va[0] * vb[2 * n]; // k0
                        sum0[n] += (int)va[1] * vb[2 * n + 1];

                        sum1[n] += (int)va[2] * vb[2 * n]; // k1
                        sum1[n] += (int)va[3] * vb[2 * n + 1];

                        sum2[n] += (int)va[4] * vb[2 * n]; // k2
                        sum2[n] += (int)va[5] * vb[2 * n + 1];

                        sum3[n] += (int)va[6] * vb[2 * n]; // k3
                        sum3[n] += (int)va[7] * vb[2 * n + 1];
                    }

                    va += 8;
                    vb += 8;
                }

                for (; k < K; k++)
                {
                    for (int n = 0; n < 4; n++)
                    {
                        sum0[n] += (int)va[0] * vb[n];
                        sum1[n] += (int)va[1] * vb[n];
                        sum2[n] += (int)va[2] * vb[n];
                        sum3[n] += (int)va[3] * vb[n];
                    }

                    va += 4;
                    vb += 4;
                }

                for (int n = 0; n < 4; n++)
                {
                    output0[n] = (float)sum0[n] * scale_dequant0 + bias0;
                    output1[n] = (float)sum1[n] * scale_dequant1 + bias1;
                    output2[n] = (float)sum2[n] * scale_dequant2 + bias2;
                    output3[n] = (float)sum3[n] * scale_dequant3 + bias3;
                }
                output0 += 4;
                output1 += 4;
                output2 += 4;
                output3 += 4;
            }

            for (; j < N; j++)
            {
                int sum0 = 0;
                int sum1 = 0;
                int sum2 = 0;
                int sum3 = 0;

                signed char* vb = bottom_tm.channel(j / 4 + j % 4);
                signed char* va = kernel_tm.channel(i / 4);

                int k = 0;

                for (; k + 1 < K; k = k + 2)
                {
                    sum0 += (int)va[0] * vb[0];
                    sum0 += (int)va[1] * vb[1];

                    sum1 += (int)va[2] * vb[0];
                    sum1 += (int)va[3] * vb[1];

                    sum2 += (int)va[4] * vb[0];
                    sum2 += (int)va[5] * vb[1];

                    sum3 += (int)va[6] * vb[0];
                    sum3 += (int)va[7] * vb[1];

                    va += 8;
                    vb += 2;
                }

                for (; k < K; k++)
                {
                    sum0 += (int)va[0] * vb[0];
                    sum1 += (int)va[1] * vb[0];
                    sum2 += (int)va[2] * vb[0];
                    sum3 += (int)va[3] * vb[0];

                    va += 4;
                    vb += 1;
                }

                output0[0] = (float)sum0 * scale_dequant0 + bias0;
                output1[0] = (float)sum1 * scale_dequant1 + bias1;
                output2[0] = (float)sum2 * scale_dequant2 + bias2;
                output3[0] = (float)sum3 * scale_dequant3 + bias3;

                output0++;
                output1++;
                output2++;
                output3++;
            }
        }

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = remain_outch_start; i < outch; i++)
        {
            float* output = top_blob.channel(i);

            const float bias0 = bias ? bias[i] : 0.f;
            const float scale_dequant0 = scale_dequant[i];

            int j = 0;
            for (; j + 3 < N; j = j + 4)
            {
                signed char* vb = bottom_tm.channel(j / 4);
                signed char* va = kernel_tm.channel(i / 4 + i % 4);
                int sum[4] = {0};

                int k = 0;
                for (; k + 1 < K; k = k + 2)
                {
                    for (int n = 0; n < 4; n++)
                    {
                        sum[n] += (int)va[0] * vb[2 * n];
                        sum[n] += (int)va[1] * vb[2 * n + 1];
                    }
                    va += 2;
                    vb += 8;
                }

                for (; k < K; k++)
                {
                    for (int n = 0; n < 4; n++)
                    {
                        sum[n] += (int)va[0] * vb[n];
                    }
                    va += 1;
                    vb += 4;
                }

                for (int n = 0; n < 4; n++)
                {
                    output[n] = (float)sum[n] * scale_dequant0 + bias0;
                }
                output += 4;
            }

            for (; j < N; j++)
            {
                int sum = 0;

                signed char* vb = bottom_tm.channel(j / 4 + j % 4);
                signed char* va = kernel_tm.channel(i / 4 + i % 4);

                for (int k = 0; k < K; k++)
                {
                    sum += (int)va[0] * vb[0];

                    va += 1;
                    vb += 1;
                }
                output[0] = (float)sum * scale_dequant0 + bias0;

                output++;
            }
        }
    }

    // // sgemm(int M, int N, int K, float* A, float* B, float* C)
    // {
    //     for (int i=0; i<M; i++)
    //     {
    //         int* output = top_blob.channel(i);

    //         for (int j=0; j<N; j++)
    //         {
    //             int sum = 0;

    //             signed char* vb = (signed char*)bottom_im2row + K * j;
    //             const signed char* va = kernel + K * i;

    //             for (int k=0; k<K; k++)
    //             {
    //                 sum += (int)va[0] * vb[0];

    //                 va += 1;
    //                 vb += 1;
    //             }
    //             output[0] = sum;

    //             output++;
    //         }
    //     }
    // }
}

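// conv_im2col_sgemm_int8_requant_sse: same im2row + packed 4x4 int8 GEMM as above,
// but each int32 accumulator is requantized back to int8 on store:
// output = float2int8((sum * scale_requant[2 * out_channel] + bias[out_channel]) * scale_requant[2 * out_channel + 1]).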
static void conv_im2col_sgemm_int8_requant_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& _kernel,
        const int kernel_w, const int kernel_h, const int stride_w, const int stride_h, const Mat& _bias, std::vector<float> scale_requant, const Option& opt)
{
    int w = bottom_blob.w;
    int inch = bottom_blob.c;

    int outw = top_blob.w;
    int outh = top_blob.h;
    int outch = top_blob.c;

    const signed char* kernel = _kernel;
    const float* bias = _bias;

    // im2row
    Mat bottom_im2row(kernel_h * kernel_w * inch, outw * outh, 1UL, opt.workspace_allocator);
    {
        signed char* ret = (signed char*)bottom_im2row;
        int retID = 0;

        for (int i = 0; i < outh; i++)
        {
            for (int j = 0; j < outw; j++)
            {
                for (int p = 0; p < inch; p++)
                {
                    const signed char* input = bottom_blob.channel(p);
                    for (int u = 0; u < kernel_h; u++)
                    {
                        for (int v = 0; v < kernel_w; v++)
                        {
                            int row = u + i * stride_h;
                            int col = v + j * stride_w;
                            int index = row * w + col;
                            ret[retID] = input[index];
                            retID++;
                        }
                    }
                }
            }
        }
    }

    int kernel_size = kernel_w * kernel_h;
    int out_size = outw * outh;

    // int M = outch;  // outch
    int N = outw * outh;                // outsize or out stride
    int K = kernel_w * kernel_h * inch; // ksize * inch

    // bottom_im2row memory packed 4 x 4
    Mat bottom_tm(4 * kernel_size, inch, out_size / 4 + out_size % 4, (size_t)1u, opt.workspace_allocator);
    {
        int nn_size = out_size >> 2;
        int remain_size_start = nn_size << 2;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int ii = 0; ii < nn_size; ii++)
        {
            int i = ii * 4;

            const signed char* img0 = bottom_im2row.row<signed char>(i);
            const signed char* img1 = bottom_im2row.row<signed char>(i + 1);
            const signed char* img2 = bottom_im2row.row<signed char>(i + 2);
            const signed char* img3 = bottom_im2row.row<signed char>(i + 3);

            signed char* tmpptr = bottom_tm.channel(i / 4);

            int q = 0;
            for (; q + 1 < inch * kernel_size; q = q + 2)
            {
                tmpptr[0] = img0[0];
                tmpptr[1] = img0[1];
                tmpptr[2] = img1[0];
                tmpptr[3] = img1[1];
                tmpptr[4] = img2[0];
                tmpptr[5] = img2[1];
                tmpptr[6] = img3[0];
                tmpptr[7] = img3[1];

                tmpptr += 8;
                img0 += 2;
                img1 += 2;
                img2 += 2;
                img3 += 2;
            }

            for (; q < inch * kernel_size; q++)
            {
                tmpptr[0] = img0[0];
                tmpptr[1] = img1[0];
                tmpptr[2] = img2[0];
                tmpptr[3] = img3[0];

                tmpptr += 4;
                img0 += 1;
                img1 += 1;
                img2 += 1;
                img3 += 1;
            }
        }

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = remain_size_start; i < out_size; i++)
        {
            const signed char* img0 = bottom_im2row.row<signed char>(i);

            signed char* tmpptr = bottom_tm.channel(i / 4 + i % 4);

            int q = 0;
            for (; q + 1 < inch * kernel_size; q = q + 2)
            {
                tmpptr[0] = img0[0];
                tmpptr[1] = img0[1];

                tmpptr += 2;
                img0 += 2;
            }

            for (; q < inch * kernel_size; q++)
            {
                tmpptr[0] = img0[0];

                tmpptr += 1;
                img0 += 1;
            }
        }
    }

    // kernel memory packed 4 x 4
    Mat kernel_tm(4 * kernel_size, inch, outch / 4 + outch % 4, (size_t)1u, opt.workspace_allocator);
    {
        int nn_outch = 0;
        int remain_outch_start = 0;

        nn_outch = outch >> 2;
        remain_outch_start = nn_outch << 2;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int pp = 0; pp < nn_outch; pp++)
        {
            int p = pp * 4;

            const signed char* k0 = kernel + (p + 0) * inch * kernel_size;
            const signed char* k1 = kernel + (p + 1) * inch * kernel_size;
            const signed char* k2 = kernel + (p + 2) * inch * kernel_size;
            const signed char* k3 = kernel + (p + 3) * inch * kernel_size;

            signed char* ktmp = kernel_tm.channel(p / 4);

            int q = 0;
            for (; q + 1 < inch * kernel_size; q += 2)
            {
                ktmp[0] = k0[0];
                ktmp[1] = k0[1];
                ktmp[2] = k1[0];
                ktmp[3] = k1[1];
                ktmp[4] = k2[0];
                ktmp[5] = k2[1];
                ktmp[6] = k3[0];
                ktmp[7] = k3[1];

                ktmp += 8;

                k0 += 2;
                k1 += 2;
                k2 += 2;
                k3 += 2;
            }

            for (; q < inch * kernel_size; q++)
            {
                ktmp[0] = k0[0];
                ktmp[1] = k1[0];
                ktmp[2] = k2[0];
                ktmp[3] = k3[0];
                ktmp += 4;

                k0 += 1;
                k1 += 1;
                k2 += 1;
                k3 += 1;
            }
        }

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = remain_outch_start; p < outch; p++)
        {
            const signed char* k0 = kernel + (p + 0) * inch * kernel_size;

            signed char* ktmp = kernel_tm.channel(p / 4 + p % 4);

            int q = 0;
            for (; q + 1 < inch * kernel_size; q = q + 2)
            {
                ktmp[0] = k0[0];
                ktmp[1] = k0[1];
                ktmp += 2;
                k0 += 2;
            }

            for (; q < inch * kernel_size; q++)
            {
                ktmp[0] = k0[0];
                ktmp++;
                k0++;
            }
        }
    }

    // 4x4
    // sgemm(int M, int N, int K, float* A, float* B, float* C)
    {
        // int M = outch;  // outch
        // int N = outw * outh; // outsize or out stride
        // int L = kernel_w * kernel_h * inch; // ksize * inch

        int nn_outch = 0;
        int remain_outch_start = 0;

        nn_outch = outch >> 2;
        remain_outch_start = nn_outch << 2;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int pp = 0; pp < nn_outch; pp++)
        {
            int i = pp * 4;

            signed char* output0 = top_blob.channel(i);
            signed char* output1 = top_blob.channel(i + 1);
            signed char* output2 = top_blob.channel(i + 2);
            signed char* output3 = top_blob.channel(i + 3);

            const float bias0 = bias ? bias[i] : 0.f;
            const float bias1 = bias ? bias[i + 1] : 0.f;
            const float bias2 = bias ? bias[i + 2] : 0.f;
            const float bias3 = bias ? bias[i + 3] : 0.f;

            const float scale_requant_in0 = scale_requant[2 * i];
            const float scale_requant_out0 = scale_requant[2 * i + 1];
            const float scale_requant_in1 = scale_requant[2 * (i + 1)];
            const float scale_requant_out1 = scale_requant[2 * (i + 1) + 1];
            const float scale_requant_in2 = scale_requant[2 * (i + 2)];
            const float scale_requant_out2 = scale_requant[2 * (i + 2) + 1];
            const float scale_requant_in3 = scale_requant[2 * (i + 3)];
            const float scale_requant_out3 = scale_requant[2 * (i + 3) + 1];

            int j = 0;
            for (; j + 3 < N; j = j + 4)
            {
                signed char* vb = bottom_tm.channel(j / 4);
                signed char* va = kernel_tm.channel(i / 4);

                int sum0[4] = {0};
                int sum1[4] = {0};
                int sum2[4] = {0};
                int sum3[4] = {0};

                int k = 0;

                for (; k + 1 < K; k = k + 2)
                {
                    for (int n = 0; n < 4; n++)
                    {
                        sum0[n] += (int)va[0] * vb[2 * n]; // k0
                        sum0[n] += (int)va[1] * vb[2 * n + 1];

                        sum1[n] += (int)va[2] * vb[2 * n]; // k1
                        sum1[n] += (int)va[3] * vb[2 * n + 1];

                        sum2[n] += (int)va[4] * vb[2 * n]; // k2
                        sum2[n] += (int)va[5] * vb[2 * n + 1];

                        sum3[n] += (int)va[6] * vb[2 * n]; // k3
                        sum3[n] += (int)va[7] * vb[2 * n + 1];
                    }

                    va += 8;
                    vb += 8;
                }

                for (; k < K; k++)
                {
                    for (int n = 0; n < 4; n++)
                    {
                        sum0[n] += (int)va[0] * vb[n];
                        sum1[n] += (int)va[1] * vb[n];
                        sum2[n] += (int)va[2] * vb[n];
                        sum3[n] += (int)va[3] * vb[n];
                    }

                    va += 4;
                    vb += 4;
                }

                for (int n = 0; n < 4; n++)
                {
                    output0[n] = float2int8(((float)sum0[n] * scale_requant_in0 + bias0) * scale_requant_out0);
                    output1[n] = float2int8(((float)sum1[n] * scale_requant_in1 + bias1) * scale_requant_out1);
                    output2[n] = float2int8(((float)sum2[n] * scale_requant_in2 + bias2) * scale_requant_out2);
                    output3[n] = float2int8(((float)sum3[n] * scale_requant_in3 + bias3) * scale_requant_out3);
                }
                output0 += 4;
                output1 += 4;
                output2 += 4;
                output3 += 4;
            }

            for (; j < N; j++)
            {
                int sum0 = 0;
                int sum1 = 0;
                int sum2 = 0;
                int sum3 = 0;

                signed char* vb = bottom_tm.channel(j / 4 + j % 4);
                signed char* va = kernel_tm.channel(i / 4);

                int k = 0;

                for (; k + 1 < K; k = k + 2)
                {
                    sum0 += (int)va[0] * vb[0];
                    sum0 += (int)va[1] * vb[1];

                    sum1 += (int)va[2] * vb[0];
                    sum1 += (int)va[3] * vb[1];

                    sum2 += (int)va[4] * vb[0];
                    sum2 += (int)va[5] * vb[1];

                    sum3 += (int)va[6] * vb[0];
                    sum3 += (int)va[7] * vb[1];

                    va += 8;
                    vb += 2;
                }

                for (; k < K; k++)
                {
                    sum0 += (int)va[0] * vb[0];
                    sum1 += (int)va[1] * vb[0];
                    sum2 += (int)va[2] * vb[0];
                    sum3 += (int)va[3] * vb[0];

                    va += 4;
                    vb += 1;
                }

                output0[0] = float2int8(((float)sum0 * scale_requant_in0 + bias0) * scale_requant_out0);
                output1[0] = float2int8(((float)sum1 * scale_requant_in1 + bias1) * scale_requant_out1);
                output2[0] = float2int8(((float)sum2 * scale_requant_in2 + bias2) * scale_requant_out2);
                output3[0] = float2int8(((float)sum3 * scale_requant_in3 + bias3) * scale_requant_out3);

                output0++;
                output1++;
                output2++;
                output3++;
            }
        }

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = remain_outch_start; i < outch; i++)
        {
            signed char* output = top_blob.channel(i);

            const float bias0 = bias ? bias[i] : 0.f;

            const float scale_requant_in0 = scale_requant[2 * i];
            const float scale_requant_out0 = scale_requant[2 * i + 1];

            int j = 0;
            for (; j + 3 < N; j = j + 4)
            {
                signed char* vb = bottom_tm.channel(j / 4);
                signed char* va = kernel_tm.channel(i / 4 + i % 4);
                int sum[4] = {0};

                int k = 0;
                for (; k + 1 < K; k = k + 2)
                {
                    for (int n = 0; n < 4; n++)
                    {
                        sum[n] += (int)va[0] * vb[2 * n];
                        sum[n] += (int)va[1] * vb[2 * n + 1];
                    }
                    va += 2;
                    vb += 8;
                }

                for (; k < K; k++)
                {
                    for (int n = 0; n < 4; n++)
                    {
                        sum[n] += (int)va[0] * vb[n];
                    }
                    va += 1;
                    vb += 4;
                }

                for (int n = 0; n < 4; n++)
                {
                    output[n] = float2int8(((float)sum[n] * scale_requant_in0 + bias0) * scale_requant_out0);
                }
                output += 4;
            }

            for (; j < N; j++)
            {
                int sum = 0;

                signed char* vb = bottom_tm.channel(j / 4 + j % 4);
                signed char* va = kernel_tm.channel(i / 4 + i % 4);

                for (int k = 0; k < K; k++)
                {
                    sum += (int)va[0] * vb[0];

                    va += 1;
                    vb += 1;
                }
                output[0] = float2int8(((float)sum * scale_requant_in0 + bias0) * scale_requant_out0);

                output++;
            }
        }
    }

    // // sgemm(int M, int N, int K, float* A, float* B, float* C)
    // {
    //     for (int i=0; i<M; i++)
    //     {
    //         int* output = top_blob.channel(i);

    //         for (int j=0; j<N; j++)
    //         {
    //             int sum = 0;

    //             signed char* vb = (signed char*)bottom_im2row + K * j;
    //             const signed char* va = kernel + K * i;

    //             for (int k=0; k<K; k++)
    //             {
    //                 sum += (int)va[0] * vb[0];

    //                 va += 1;
    //                 vb += 1;
    //             }
    //             output[0] = sum;

    //             output++;
    //         }
    //     }
    // }
}