1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14
// Int8 convolution for x86 SSE2 where the input blob is packed 8 channels
// per element (pack8) and the output blob is packed 4 channels per element
// (pack4).
//
// bottom_blob      input activations, int8, 8 interleaved channels per pixel
//                  (each pixel occupies 8 consecutive bytes -- see the
//                  "* 8" pixel stride below)
// top_blob         output accumulators, int32, 4 interleaved channels per
//                  pixel; assumed pre-allocated by the caller with the
//                  correct outw/outh/outch -- TODO confirm against caller
// weight_data_int8 weights pre-reordered so that each kernel tap reads 32
//                  consecutive int8 values: 8 input channels x 4 output
//                  channels (see kptr += 32 in the inner loop)
// kernel_w/h, dilation_w/h, stride_w/h  convolution geometry
// opt              supplies num_threads for the OpenMP loop over outch
static void convolution_pack8to4_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_int8, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt)
{
    int w = bottom_blob.w;
    int channels = bottom_blob.c;

    int outw = top_blob.w;
    int outh = top_blob.h;
    int outch = top_blob.c;

    const int maxk = kernel_w * kernel_h;

    // kernel offsets
    // Precompute, for each of the maxk kernel taps, its element offset from
    // the sliding-window origin within one input channel plane, with dilation
    // applied. This turns the 2D kernel walk into a single table lookup in
    // the hot loop.
    std::vector<int> _space_ofs(maxk);
    int* space_ofs = &_space_ofs[0];
    {
        int p1 = 0;
        int p2 = 0;
        // After walking kernel_w taps of one kernel row, skip ahead to the
        // start of the next (dilated) kernel row in the input plane.
        int gap = w * dilation_h - kernel_w * dilation_w;
        for (int i = 0; i < kernel_h; i++)
        {
            for (int j = 0; j < kernel_w; j++)
            {
                space_ofs[p1] = p2;
                p1++;
                p2 += dilation_w;
            }
            p2 += gap;
        }
    }

    // num_output
    // Each p produces one pack4 output channel plane (4 real output channels).
    #pragma omp parallel for num_threads(opt.num_threads)
    for (int p = 0; p < outch; p++)
    {
        int* outptr = top_blob.channel(p);

        for (int i = 0; i < outh; i++)
        {
            for (int j = 0; j < outw; j++)
            {
                // Four int32x4 accumulators, one per output channel in this
                // pack4 group. Before the transpose below, _sumN holds 4-lane
                // partial sums of (input x weight) products for output
                // channel N.
                __m128i _sum0 = _mm_setzero_si128();
                __m128i _sum1 = _mm_setzero_si128();
                __m128i _sum2 = _mm_setzero_si128();
                __m128i _sum3 = _mm_setzero_si128();

                const signed char* kptr = weight_data_int8.channel(p);

                // channels
                for (int q = 0; q < channels; q++)
                {
                    const Mat m = bottom_blob.channel(q);
                    // Top-left pixel of the sliding window; each pixel is
                    // 8 int8 values (one per packed input channel).
                    const signed char* sptr = m.row<signed char>(i * stride_h) + j * stride_w * 8;

                    for (int k = 0; k < maxk; k++)
                    {
                        // TODO use _mm_cvtepi8_epi16 on sse4.1
                        // Load 8 int8 input values for this tap and
                        // sign-extend to 8x int16: _mm_cmpgt_epi8(0, x)
                        // yields 0xFF for negative bytes, so interleaving
                        // produces the correct 16-bit sign extension.
                        __m128i _val = _mm_loadl_epi64((const __m128i*)(sptr + space_ofs[k] * 8));
                        _val = _mm_unpacklo_epi8(_val, _mm_cmpgt_epi8(_mm_setzero_si128(), _val));

                        // TODO use _mm_cvtepi8_epi16 on sse4.1
                        // 32 int8 weights per tap: _w01 covers output
                        // channels 0 and 1 (8 input channels each), _w23
                        // covers output channels 2 and 3. Same sign-extension
                        // trick as above, producing four 8x int16 vectors.
                        __m128i _w01 = _mm_loadu_si128((const __m128i*)kptr);
                        __m128i _w23 = _mm_loadu_si128((const __m128i*)(kptr + 16));
                        __m128i _extw01 = _mm_cmpgt_epi8(_mm_setzero_si128(), _w01);
                        __m128i _extw23 = _mm_cmpgt_epi8(_mm_setzero_si128(), _w23);
                        __m128i _w0 = _mm_unpacklo_epi8(_w01, _extw01);
                        __m128i _w1 = _mm_unpackhi_epi8(_w01, _extw01);
                        __m128i _w2 = _mm_unpacklo_epi8(_w23, _extw23);
                        __m128i _w3 = _mm_unpackhi_epi8(_w23, _extw23);

                        // Widening 16x16 -> 32-bit multiply: mullo/mulhi give
                        // the low/high halves of each product; the epi16
                        // unpacks below stitch them into full int32 products.
                        __m128i _sl0 = _mm_mullo_epi16(_val, _w0);
                        __m128i _sh0 = _mm_mulhi_epi16(_val, _w0);
                        __m128i _sl1 = _mm_mullo_epi16(_val, _w1);
                        __m128i _sh1 = _mm_mulhi_epi16(_val, _w1);
                        __m128i _sl2 = _mm_mullo_epi16(_val, _w2);
                        __m128i _sh2 = _mm_mulhi_epi16(_val, _w2);
                        __m128i _sl3 = _mm_mullo_epi16(_val, _w3);
                        __m128i _sh3 = _mm_mulhi_epi16(_val, _w3);

                        // Accumulate int32 products of input channels 0-3
                        // (unpacklo) and 4-7 (unpackhi) into each channel's
                        // accumulator.
                        _sum0 = _mm_add_epi32(_sum0, _mm_unpacklo_epi16(_sl0, _sh0));
                        _sum1 = _mm_add_epi32(_sum1, _mm_unpacklo_epi16(_sl1, _sh1));
                        _sum2 = _mm_add_epi32(_sum2, _mm_unpacklo_epi16(_sl2, _sh2));
                        _sum3 = _mm_add_epi32(_sum3, _mm_unpacklo_epi16(_sl3, _sh3));
                        _sum0 = _mm_add_epi32(_sum0, _mm_unpackhi_epi16(_sl0, _sh0));
                        _sum1 = _mm_add_epi32(_sum1, _mm_unpackhi_epi16(_sl1, _sh1));
                        _sum2 = _mm_add_epi32(_sum2, _mm_unpackhi_epi16(_sl2, _sh2));
                        _sum3 = _mm_add_epi32(_sum3, _mm_unpackhi_epi16(_sl3, _sh3));

                        kptr += 32;
                    }
                }

                // transpose 4x4
                // Rearrange so lane L of _sumN holds partial sum N of output
                // channel L; the vertical adds below then reduce the four
                // partials of each channel in parallel.
                {
                    __m128i _tmp0, _tmp1, _tmp2, _tmp3;
                    _tmp0 = _mm_unpacklo_epi32(_sum0, _sum1);
                    _tmp1 = _mm_unpacklo_epi32(_sum2, _sum3);
                    _tmp2 = _mm_unpackhi_epi32(_sum0, _sum1);
                    _tmp3 = _mm_unpackhi_epi32(_sum2, _sum3);
                    _sum0 = _mm_unpacklo_epi64(_tmp0, _tmp1);
                    _sum1 = _mm_unpackhi_epi64(_tmp0, _tmp1);
                    _sum2 = _mm_unpacklo_epi64(_tmp2, _tmp3);
                    _sum3 = _mm_unpackhi_epi64(_tmp2, _tmp3);
                }

                _sum0 = _mm_add_epi32(_sum0, _sum1);
                _sum2 = _mm_add_epi32(_sum2, _sum3);

                _sum0 = _mm_add_epi32(_sum0, _sum2);

                // One int32 result per output channel of this pack4 group.
                _mm_storeu_si128((__m128i*)(outptr + j * 4), _sum0);
            }

            outptr += outw * 4;
        }
    }
}
131