// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
14
conv5x5s1_sse(const Mat & bottom_blob,Mat & top_blob,const Mat & _kernel,const Mat & _bias,const Option & opt)15 static void conv5x5s1_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& _kernel, const Mat& _bias, const Option& opt)
16 {
17 int w = bottom_blob.w;
18 int inch = bottom_blob.c;
19
20 int outw = top_blob.w;
21 int outh = top_blob.h;
22 int outch = top_blob.c;
23
24 const float* kernel = _kernel;
25 const float* bias = _bias;
26
27 #pragma omp parallel for num_threads(opt.num_threads)
28 for (int p = 0; p < outch; p++)
29 {
30 Mat out = top_blob.channel(p);
31
32 const float bias0 = bias ? bias[p] : 0.f;
33
34 out.fill(bias0);
35
36 for (int q = 0; q < inch; q++)
37 {
38 float* outptr = out;
39 float* outptr2 = outptr + outw;
40
41 const float* img0 = bottom_blob.channel(q);
42
43 const float* kernel0 = kernel + p * inch * 25 + q * 25;
44
45 const float* r0 = img0;
46 const float* r1 = img0 + w;
47 const float* r2 = img0 + w * 2;
48 const float* r3 = img0 + w * 3;
49 const float* r4 = img0 + w * 4;
50 const float* r5 = img0 + w * 5;
51
52 const float* k0 = kernel0;
53 const float* k1 = kernel0 + 5;
54 const float* k2 = kernel0 + 10;
55 const float* k3 = kernel0 + 15;
56 const float* k4 = kernel0 + 20;
57
58 int i = 0;
59
60 for (; i + 1 < outh; i += 2)
61 {
62 int remain = outw;
63
64 for (; remain > 0; remain--)
65 {
66 float sum = 0;
67 float sum2 = 0;
68
69 sum += r0[0] * k0[0];
70 sum += r0[1] * k0[1];
71 sum += r0[2] * k0[2];
72 sum += r0[3] * k0[3];
73 sum += r0[4] * k0[4];
74
75 sum += r1[0] * k1[0];
76 sum += r1[1] * k1[1];
77 sum += r1[2] * k1[2];
78 sum += r1[3] * k1[3];
79 sum += r1[4] * k1[4];
80
81 sum += r2[0] * k2[0];
82 sum += r2[1] * k2[1];
83 sum += r2[2] * k2[2];
84 sum += r2[3] * k2[3];
85 sum += r2[4] * k2[4];
86
87 sum += r3[0] * k3[0];
88 sum += r3[1] * k3[1];
89 sum += r3[2] * k3[2];
90 sum += r3[3] * k3[3];
91 sum += r3[4] * k3[4];
92
93 sum += r4[0] * k4[0];
94 sum += r4[1] * k4[1];
95 sum += r4[2] * k4[2];
96 sum += r4[3] * k4[3];
97 sum += r4[4] * k4[4];
98
99 sum2 += r1[0] * k0[0];
100 sum2 += r1[1] * k0[1];
101 sum2 += r1[2] * k0[2];
102 sum2 += r1[3] * k0[3];
103 sum2 += r1[4] * k0[4];
104
105 sum2 += r2[0] * k1[0];
106 sum2 += r2[1] * k1[1];
107 sum2 += r2[2] * k1[2];
108 sum2 += r2[3] * k1[3];
109 sum2 += r2[4] * k1[4];
110
111 sum2 += r3[0] * k2[0];
112 sum2 += r3[1] * k2[1];
113 sum2 += r3[2] * k2[2];
114 sum2 += r3[3] * k2[3];
115 sum2 += r3[4] * k2[4];
116
117 sum2 += r4[0] * k3[0];
118 sum2 += r4[1] * k3[1];
119 sum2 += r4[2] * k3[2];
120 sum2 += r4[3] * k3[3];
121 sum2 += r4[4] * k3[4];
122
123 sum2 += r5[0] * k4[0];
124 sum2 += r5[1] * k4[1];
125 sum2 += r5[2] * k4[2];
126 sum2 += r5[3] * k4[3];
127 sum2 += r5[4] * k4[4];
128
129 *outptr += sum;
130 *outptr2 += sum2;
131
132 r0++;
133 r1++;
134 r2++;
135 r3++;
136 r4++;
137 r5++;
138 outptr++;
139 outptr2++;
140 }
141
142 r0 += 4 + w;
143 r1 += 4 + w;
144 r2 += 4 + w;
145 r3 += 4 + w;
146 r4 += 4 + w;
147 r5 += 4 + w;
148
149 outptr += outw;
150 outptr2 += outw;
151 }
152
153 for (; i < outh; i++)
154 {
155 int remain = outw;
156
157 for (; remain > 0; remain--)
158 {
159 float sum = 0;
160
161 sum += r0[0] * k0[0];
162 sum += r0[1] * k0[1];
163 sum += r0[2] * k0[2];
164 sum += r0[3] * k0[3];
165 sum += r0[4] * k0[4];
166
167 sum += r1[0] * k1[0];
168 sum += r1[1] * k1[1];
169 sum += r1[2] * k1[2];
170 sum += r1[3] * k1[3];
171 sum += r1[4] * k1[4];
172
173 sum += r2[0] * k2[0];
174 sum += r2[1] * k2[1];
175 sum += r2[2] * k2[2];
176 sum += r2[3] * k2[3];
177 sum += r2[4] * k2[4];
178
179 sum += r3[0] * k3[0];
180 sum += r3[1] * k3[1];
181 sum += r3[2] * k3[2];
182 sum += r3[3] * k3[3];
183 sum += r3[4] * k3[4];
184
185 sum += r4[0] * k4[0];
186 sum += r4[1] * k4[1];
187 sum += r4[2] * k4[2];
188 sum += r4[3] * k4[3];
189 sum += r4[4] * k4[4];
190
191 *outptr += sum;
192
193 r0++;
194 r1++;
195 r2++;
196 r3++;
197 r4++;
198 outptr++;
199 }
200
201 r0 += 4;
202 r1 += 4;
203 r2 += 4;
204 r3 += 4;
205 r4 += 4;
206 }
207 }
208 }
209 }
210
conv5x5s2_sse(const Mat & bottom_blob,Mat & top_blob,const Mat & _kernel,const Mat & _bias,const Option & opt)211 static void conv5x5s2_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& _kernel, const Mat& _bias, const Option& opt)
212 {
213 int kernel_w = 5;
214 int kernel_h = 5;
215
216 int stride_w = 2;
217 int stride_h = 2;
218
219 conv_im2col_sgemm_sse(bottom_blob, top_blob, _kernel, _bias, kernel_w, kernel_h, stride_w, stride_h, opt);
220 }