1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
conv5x5s1_sse(const Mat & bottom_blob,Mat & top_blob,const Mat & _kernel,const Mat & _bias,const Option & opt)15 static void conv5x5s1_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& _kernel, const Mat& _bias, const Option& opt)
16 {
17     int w = bottom_blob.w;
18     int inch = bottom_blob.c;
19 
20     int outw = top_blob.w;
21     int outh = top_blob.h;
22     int outch = top_blob.c;
23 
24     const float* kernel = _kernel;
25     const float* bias = _bias;
26 
27     #pragma omp parallel for num_threads(opt.num_threads)
28     for (int p = 0; p < outch; p++)
29     {
30         Mat out = top_blob.channel(p);
31 
32         const float bias0 = bias ? bias[p] : 0.f;
33 
34         out.fill(bias0);
35 
36         for (int q = 0; q < inch; q++)
37         {
38             float* outptr = out;
39             float* outptr2 = outptr + outw;
40 
41             const float* img0 = bottom_blob.channel(q);
42 
43             const float* kernel0 = kernel + p * inch * 25 + q * 25;
44 
45             const float* r0 = img0;
46             const float* r1 = img0 + w;
47             const float* r2 = img0 + w * 2;
48             const float* r3 = img0 + w * 3;
49             const float* r4 = img0 + w * 4;
50             const float* r5 = img0 + w * 5;
51 
52             const float* k0 = kernel0;
53             const float* k1 = kernel0 + 5;
54             const float* k2 = kernel0 + 10;
55             const float* k3 = kernel0 + 15;
56             const float* k4 = kernel0 + 20;
57 
58             int i = 0;
59 
60             for (; i + 1 < outh; i += 2)
61             {
62                 int remain = outw;
63 
64                 for (; remain > 0; remain--)
65                 {
66                     float sum = 0;
67                     float sum2 = 0;
68 
69                     sum += r0[0] * k0[0];
70                     sum += r0[1] * k0[1];
71                     sum += r0[2] * k0[2];
72                     sum += r0[3] * k0[3];
73                     sum += r0[4] * k0[4];
74 
75                     sum += r1[0] * k1[0];
76                     sum += r1[1] * k1[1];
77                     sum += r1[2] * k1[2];
78                     sum += r1[3] * k1[3];
79                     sum += r1[4] * k1[4];
80 
81                     sum += r2[0] * k2[0];
82                     sum += r2[1] * k2[1];
83                     sum += r2[2] * k2[2];
84                     sum += r2[3] * k2[3];
85                     sum += r2[4] * k2[4];
86 
87                     sum += r3[0] * k3[0];
88                     sum += r3[1] * k3[1];
89                     sum += r3[2] * k3[2];
90                     sum += r3[3] * k3[3];
91                     sum += r3[4] * k3[4];
92 
93                     sum += r4[0] * k4[0];
94                     sum += r4[1] * k4[1];
95                     sum += r4[2] * k4[2];
96                     sum += r4[3] * k4[3];
97                     sum += r4[4] * k4[4];
98 
99                     sum2 += r1[0] * k0[0];
100                     sum2 += r1[1] * k0[1];
101                     sum2 += r1[2] * k0[2];
102                     sum2 += r1[3] * k0[3];
103                     sum2 += r1[4] * k0[4];
104 
105                     sum2 += r2[0] * k1[0];
106                     sum2 += r2[1] * k1[1];
107                     sum2 += r2[2] * k1[2];
108                     sum2 += r2[3] * k1[3];
109                     sum2 += r2[4] * k1[4];
110 
111                     sum2 += r3[0] * k2[0];
112                     sum2 += r3[1] * k2[1];
113                     sum2 += r3[2] * k2[2];
114                     sum2 += r3[3] * k2[3];
115                     sum2 += r3[4] * k2[4];
116 
117                     sum2 += r4[0] * k3[0];
118                     sum2 += r4[1] * k3[1];
119                     sum2 += r4[2] * k3[2];
120                     sum2 += r4[3] * k3[3];
121                     sum2 += r4[4] * k3[4];
122 
123                     sum2 += r5[0] * k4[0];
124                     sum2 += r5[1] * k4[1];
125                     sum2 += r5[2] * k4[2];
126                     sum2 += r5[3] * k4[3];
127                     sum2 += r5[4] * k4[4];
128 
129                     *outptr += sum;
130                     *outptr2 += sum2;
131 
132                     r0++;
133                     r1++;
134                     r2++;
135                     r3++;
136                     r4++;
137                     r5++;
138                     outptr++;
139                     outptr2++;
140                 }
141 
142                 r0 += 4 + w;
143                 r1 += 4 + w;
144                 r2 += 4 + w;
145                 r3 += 4 + w;
146                 r4 += 4 + w;
147                 r5 += 4 + w;
148 
149                 outptr += outw;
150                 outptr2 += outw;
151             }
152 
153             for (; i < outh; i++)
154             {
155                 int remain = outw;
156 
157                 for (; remain > 0; remain--)
158                 {
159                     float sum = 0;
160 
161                     sum += r0[0] * k0[0];
162                     sum += r0[1] * k0[1];
163                     sum += r0[2] * k0[2];
164                     sum += r0[3] * k0[3];
165                     sum += r0[4] * k0[4];
166 
167                     sum += r1[0] * k1[0];
168                     sum += r1[1] * k1[1];
169                     sum += r1[2] * k1[2];
170                     sum += r1[3] * k1[3];
171                     sum += r1[4] * k1[4];
172 
173                     sum += r2[0] * k2[0];
174                     sum += r2[1] * k2[1];
175                     sum += r2[2] * k2[2];
176                     sum += r2[3] * k2[3];
177                     sum += r2[4] * k2[4];
178 
179                     sum += r3[0] * k3[0];
180                     sum += r3[1] * k3[1];
181                     sum += r3[2] * k3[2];
182                     sum += r3[3] * k3[3];
183                     sum += r3[4] * k3[4];
184 
185                     sum += r4[0] * k4[0];
186                     sum += r4[1] * k4[1];
187                     sum += r4[2] * k4[2];
188                     sum += r4[3] * k4[3];
189                     sum += r4[4] * k4[4];
190 
191                     *outptr += sum;
192 
193                     r0++;
194                     r1++;
195                     r2++;
196                     r3++;
197                     r4++;
198                     outptr++;
199                 }
200 
201                 r0 += 4;
202                 r1 += 4;
203                 r2 += 4;
204                 r3 += 4;
205                 r4 += 4;
206             }
207         }
208     }
209 }
210 
conv5x5s2_sse(const Mat & bottom_blob,Mat & top_blob,const Mat & _kernel,const Mat & _bias,const Option & opt)211 static void conv5x5s2_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& _kernel, const Mat& _bias, const Option& opt)
212 {
213     int kernel_w = 5;
214     int kernel_h = 5;
215 
216     int stride_w = 2;
217     int stride_h = 2;
218 
219     conv_im2col_sgemm_sse(bottom_blob, top_blob, _kernel, _bias, kernel_w, kernel_h, stride_w, stride_h, opt);
220 }