// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

// Int8 convolution: pack8 input layout -> pack4 output layout (ARM NEON).
//
// bottom_blob       int8 activations, 8 input channels interleaved per spatial
//                   element (each element is 8 bytes wide — see the *8 strides).
// top_blob          int32 accumulators, 4 output channels interleaved per
//                   spatial element (stored 4 lanes at a time).
// weight_data_int8  weights for output-channel group p in channel(p); each
//                   kernel tap consumes 32 bytes = 4 output channels x 8 input
//                   channels (kptr advances by 32 per tap).
// kernel_w/h, dilation_w/h, stride_w/h describe the 2-D convolution geometry.
// opt               controls the OpenMP thread count.
static void convolution_pack8to4_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_int8, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt)
{
    int w = bottom_blob.w;
    int channels = bottom_blob.c;

    int outw = top_blob.w;
    int outh = top_blob.h;
    int outch = top_blob.c;

    const int maxk = kernel_w * kernel_h;

    // Precompute, for each kernel tap, its element offset within the input
    // window so the inner loop reduces to an indexed gather.
    std::vector<int> _space_ofs(maxk);
    int* space_ofs = &_space_ofs[0];
    {
        int p1 = 0;
        int p2 = 0;
        // Jump from the end of one kernel row to the start of the next.
        int gap = w * dilation_h - kernel_w * dilation_w;
        for (int i = 0; i < kernel_h; i++)
        {
            for (int j = 0; j < kernel_w; j++)
            {
                space_ofs[p1] = p2;
                p1++;
                p2 += dilation_w;
            }
            p2 += gap;
        }
    }

    // num_output
    #pragma omp parallel for num_threads(opt.num_threads)
    for (int p = 0; p < outch; p++)
    {
        int* outptr = top_blob.channel(p);

        for (int i = 0; i < outh; i++)
        {
            for (int j = 0; j < outw; j++)
            {
                // Partial sums: _sum01 accumulates for output channels 0/1,
                // _sum23 for output channels 2/3 (two int32 lanes each).
                int32x4_t _sum01 = vdupq_n_s32(0);
                int32x4_t _sum23 = vdupq_n_s32(0);

                const signed char* kptr = weight_data_int8.channel(p);

                // Accumulate over all input channel groups (8 channels each).
                for (int q = 0; q < channels; q++)
                {
                    const Mat m = bottom_blob.channel(q);
                    const signed char* sptr = m.row<signed char>(i * stride_h) + j * stride_w * 8;

                    for (int k = 0; k < maxk; k++)
                    {
                        // 8 input channels of this tap.
                        int8x8_t _val = vld1_s8(sptr + space_ofs[k] * 8);

                        // Weights of the 4 output channels for the same 8 inputs.
                        int8x8_t _w0 = vld1_s8(kptr);
                        int8x8_t _w1 = vld1_s8(kptr + 8);
                        int8x8_t _w2 = vld1_s8(kptr + 16);
                        int8x8_t _w3 = vld1_s8(kptr + 24);

                        // int8 x int8 -> int16 widening multiply, 8 products each.
                        int16x8_t _wv0 = vmull_s8(_val, _w0);
                        int16x8_t _wv1 = vmull_s8(_val, _w1);
                        int16x8_t _wv2 = vmull_s8(_val, _w2);
                        int16x8_t _wv3 = vmull_s8(_val, _w3);

                        // Pairwise-reduce each 8-product vector to 4 int16 partials.
                        int16x4_t _wv00 = vpadd_s16(vget_low_s16(_wv0), vget_high_s16(_wv0));
                        int16x4_t _wv11 = vpadd_s16(vget_low_s16(_wv1), vget_high_s16(_wv1));
                        int16x4_t _wv22 = vpadd_s16(vget_low_s16(_wv2), vget_high_s16(_wv2));
                        int16x4_t _wv33 = vpadd_s16(vget_low_s16(_wv3), vget_high_s16(_wv3));

                        // Pairwise-accumulate into the int32 running sums, leaving
                        // two int32 lanes per output channel.
                        _sum01 = vpadalq_s16(_sum01, vcombine_s16(_wv00, _wv11));
                        _sum23 = vpadalq_s16(_sum23, vcombine_s16(_wv22, _wv33));

                        kptr += 32;
                    }
                }

                // Final pairwise reduction: one int32 per output channel.
                int32x4_t _sum0 = vcombine_s32(vpadd_s32(vget_low_s32(_sum01), vget_high_s32(_sum01)), vpadd_s32(vget_low_s32(_sum23), vget_high_s32(_sum23)));

                vst1q_s32(outptr + j * 4, _sum0);
            }

            outptr += outw * 4;
        }
    }
}