// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

static void convolution_pack8to4_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_int8, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt)
{
    int w = bottom_blob.w;
    int channels = bottom_blob.c;

    int outw = top_blob.w;
    int outh = top_blob.h;
    int outch = top_blob.c;

    const int maxk = kernel_w * kernel_h;

    // kernel offsets
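    // space_ofs[k] is the element offset of kernel tap k within the input plane, with dilation applied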
    std::vector<int> _space_ofs(maxk);
    int* space_ofs = &_space_ofs[0];
    {
        int p1 = 0;
        int p2 = 0;
        int gap = w * dilation_h - kernel_w * dilation_w;
        for (int i = 0; i < kernel_h; i++)
        {
            for (int j = 0; j < kernel_w; j++)
            {
                space_ofs[p1] = p2;
                p1++;
                p2 += dilation_w;
            }
            p2 += gap;
        }
    }

    // num_output
    #pragma omp parallel for num_threads(opt.num_threads)
    for (int p = 0; p < outch; p++)
    {
        int* outptr = top_blob.channel(p);

        for (int i = 0; i < outh; i++)
        {
            for (int j = 0; j < outw; j++)
            {
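                // _sum0-_sum3 accumulate the four output channels of this pack4 group;
                // the four int32 lanes of each register hold partial sums that are reduced after the loop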
                __m128i _sum0 = _mm_setzero_si128();
                __m128i _sum1 = _mm_setzero_si128();
                __m128i _sum2 = _mm_setzero_si128();
                __m128i _sum3 = _mm_setzero_si128();

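                // weights for this output channel group: 32 int8 per kernel tap (4 outch x 8 packed inch)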
                const signed char* kptr = weight_data_int8.channel(p);

                // channels
                for (int q = 0; q < channels; q++)
                {
                    const Mat m = bottom_blob.channel(q);
                    const signed char* sptr = m.row<signed char>(i * stride_h) + j * stride_w * 8;

                    for (int k = 0; k < maxk; k++)
                    {
                        // TODO use _mm_cvtepi8_epi16 on sse4.1
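                        // load 8 packed int8 input values and sign-extend them to int16;
                        // the compare-against-zero mask emulates sign extension on plain SSE2
                        // (with SSE4.1 this could be, e.g., _val = _mm_cvtepi8_epi16(_mm_loadl_epi64(...)))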
                        __m128i _val = _mm_loadl_epi64((const __m128i*)(sptr + space_ofs[k] * 8));
                        _val = _mm_unpacklo_epi8(_val, _mm_cmpgt_epi8(_mm_setzero_si128(), _val));

                        // TODO use _mm_cvtepi8_epi16 on sse4.1
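                        // load the 32 int8 weights of this tap and sign-extend them into
                        // four vectors of 8 int16, one per output channel (_w0.._w3)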
                        __m128i _w01 = _mm_loadu_si128((const __m128i*)kptr);
                        __m128i _w23 = _mm_loadu_si128((const __m128i*)(kptr + 16));
                        __m128i _extw01 = _mm_cmpgt_epi8(_mm_setzero_si128(), _w01);
                        __m128i _extw23 = _mm_cmpgt_epi8(_mm_setzero_si128(), _w23);
                        __m128i _w0 = _mm_unpacklo_epi8(_w01, _extw01);
                        __m128i _w1 = _mm_unpackhi_epi8(_w01, _extw01);
                        __m128i _w2 = _mm_unpacklo_epi8(_w23, _extw23);
                        __m128i _w3 = _mm_unpackhi_epi8(_w23, _extw23);

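                        // widening multiply: mullo/mulhi give the low and high 16 bits of each 16x16-bit product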
                        __m128i _sl0 = _mm_mullo_epi16(_val, _w0);
                        __m128i _sh0 = _mm_mulhi_epi16(_val, _w0);
                        __m128i _sl1 = _mm_mullo_epi16(_val, _w1);
                        __m128i _sh1 = _mm_mulhi_epi16(_val, _w1);
                        __m128i _sl2 = _mm_mullo_epi16(_val, _w2);
                        __m128i _sh2 = _mm_mulhi_epi16(_val, _w2);
                        __m128i _sl3 = _mm_mullo_epi16(_val, _w3);
                        __m128i _sh3 = _mm_mulhi_epi16(_val, _w3);

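                        // interleave low/high halves into int32 products and accumulate all 8 input channels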
                        _sum0 = _mm_add_epi32(_sum0, _mm_unpacklo_epi16(_sl0, _sh0));
                        _sum1 = _mm_add_epi32(_sum1, _mm_unpacklo_epi16(_sl1, _sh1));
                        _sum2 = _mm_add_epi32(_sum2, _mm_unpacklo_epi16(_sl2, _sh2));
                        _sum3 = _mm_add_epi32(_sum3, _mm_unpacklo_epi16(_sl3, _sh3));
                        _sum0 = _mm_add_epi32(_sum0, _mm_unpackhi_epi16(_sl0, _sh0));
                        _sum1 = _mm_add_epi32(_sum1, _mm_unpackhi_epi16(_sl1, _sh1));
                        _sum2 = _mm_add_epi32(_sum2, _mm_unpackhi_epi16(_sl2, _sh2));
                        _sum3 = _mm_add_epi32(_sum3, _mm_unpackhi_epi16(_sl3, _sh3));

                        kptr += 32;
                    }
                }

                // transpose 4x4
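                // after the transpose each register holds one partial sum per output channel,
                // so the final reduction can be done with vector adds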
                {
                    __m128i _tmp0, _tmp1, _tmp2, _tmp3;
                    _tmp0 = _mm_unpacklo_epi32(_sum0, _sum1);
                    _tmp1 = _mm_unpacklo_epi32(_sum2, _sum3);
                    _tmp2 = _mm_unpackhi_epi32(_sum0, _sum1);
                    _tmp3 = _mm_unpackhi_epi32(_sum2, _sum3);
                    _sum0 = _mm_unpacklo_epi64(_tmp0, _tmp1);
                    _sum1 = _mm_unpackhi_epi64(_tmp0, _tmp1);
                    _sum2 = _mm_unpacklo_epi64(_tmp2, _tmp3);
                    _sum3 = _mm_unpackhi_epi64(_tmp2, _tmp3);
                }

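                // reduce the four partial-sum registers into the final pack4 int32 result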
                _sum0 = _mm_add_epi32(_sum0, _sum1);
                _sum2 = _mm_add_epi32(_sum2, _sum3);

                _sum0 = _mm_add_epi32(_sum0, _sum2);

                _mm_storeu_si128((__m128i*)(outptr + j * 4), _sum0);
            }

            outptr += outw * 4;
        }
    }
}