// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "innerproduct_x86.h"

#if __SSE2__
#include <emmintrin.h>
#if __AVX__
#include <immintrin.h>
#endif
#endif // __SSE2__

#include "x86_activation.h"
#include "x86_usability.h"

#include "layer_type.h"

namespace ncnn {

InnerProduct_x86::InnerProduct_x86()
{
#if __SSE2__
    support_packing = true;
#if __AVX__
    support_weight_fp16_storage = true;
#endif
#endif // __SSE2__

    flatten = 0;
    activation = 0;
}

int InnerProduct_x86::create_pipeline(const Option& opt)
{
    // if (opt.use_packing_layout)
    {
        flatten = ncnn::create_layer(ncnn::LayerType::Flatten);

        ncnn::ParamDict pd;

        flatten->load_param(pd);

        flatten->create_pipeline(opt);
    }

#if NCNN_INT8
    if (opt.use_int8_inference && weight_data.elemsize == (size_t)1u)
    {
        return create_pipeline_int8_x86(opt);
    }
#endif

    const int num_input = weight_data_size / num_output;

    int out_elempack = 1;

#if __SSE2__
    if (opt.use_packing_layout)
    {
#if __AVX__
        out_elempack = num_output % 8 == 0 ? 8 : num_output % 4 == 0 ? 4 : 1;
#else
        out_elempack = num_output % 4 == 0 ? 4 : 1;
#endif
    }
#endif // __SSE2__

    if (out_elempack != 1)
    {
        // src = inch-outch
        // dst = pb-inch-outch/pb
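        // e.g. with out_elempack = 4, rows q..q+3 of the inch-outch matrix are
        // interleaved per input element: w(q,0) w(q+1,0) w(q+2,0) w(q+3,0) w(q,1) ...
        // so a single packed load feeds out_elempack outputs at once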
        {
            Mat weight_data_r2 = weight_data.reshape(num_input, num_output);

            weight_data_packed.create(num_input, num_output / out_elempack, (size_t)4u * out_elempack, out_elempack);

            for (int q = 0; q + (out_elempack - 1) < num_output; q += out_elempack)
            {
                float* g0 = weight_data_packed.row(q / out_elempack);

                for (int p = 0; p < num_input; p++)
                {
                    for (int j = 0; j < out_elempack; j++)
                    {
                        *g0++ = weight_data_r2.row(q + j)[p];
                    }
                }
            }
        }
    }

#if __AVX__
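    // storing the weights as fp16 halves their memory footprint; forward_fp16()
    // widens them back to fp32 vectors on the fly via loadfp16()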
    if (opt.use_weight_fp16_storage && weight_data.elemsize == 4u)
    {
        ncnn::cast_float32_to_float16(weight_data, weight_data_fp16, opt);

        return 0;
    }
#endif

    return 0;
}

int InnerProduct_x86::destroy_pipeline(const Option& opt)
{
    if (flatten)
    {
        flatten->destroy_pipeline(opt);
        delete flatten;
        flatten = 0;
    }

    return 0;
}

int InnerProduct_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
#if NCNN_INT8
    if (opt.use_int8_inference && weight_data.elemsize == (size_t)1u)
    {
        return forward_int8_x86(bottom_blob, top_blob, opt);
    }
#endif

    const int num_input = weight_data_size / num_output;

    if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
    {
        // gemm
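        // a 2D input whose row width equals num_input is treated as a batched
        // matrix product: each input row j yields one output row of num_output
        // values, with the kernel below chosen by the (input, output) elempack pair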
        int h = bottom_blob.h;
        size_t elemsize = bottom_blob.elemsize;
        int elempack = bottom_blob.elempack;

        top_blob.create(num_output, h, elemsize, elempack, opt.blob_allocator);
        if (top_blob.empty())
            return -100;

        int num_output_elempack = 1;
#if __SSE2__
        if (opt.use_packing_layout)
        {
#if __AVX__
            num_output_elempack = num_output % 8 == 0 ? 8 : num_output % 4 == 0 ? 4 : 1;
#else
            num_output_elempack = num_output % 4 == 0 ? 4 : 1;
#endif
        }
#endif // __SSE2__

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int j = 0; j < h; j++)
        {
#if __SSE2__
#if __AVX__
            if (elempack == 8 && num_output_elempack == 8)
            {
                float* outptr = top_blob.row(j);

                for (int p = 0; p < num_output / num_output_elempack; p++)
                {
                    const float* kptr = (const float*)weight_data_packed + num_input * p * 8;
                    const float* m = bottom_blob.row(j);

                    __m256 _sum0 = _mm256_set1_ps(0.f);
                    __m256 _sum1 = _mm256_set1_ps(0.f);
                    __m256 _sum2 = _mm256_set1_ps(0.f);
                    __m256 _sum3 = _mm256_set1_ps(0.f);
                    __m256 _sum4 = _mm256_set1_ps(0.f);
                    __m256 _sum5 = _mm256_set1_ps(0.f);
                    __m256 _sum6 = _mm256_set1_ps(0.f);
                    __m256 _sum7 = _mm256_set1_ps(0.f);

                    if (bias_term)
                    {
                        _sum0 = _mm256_set1_ps(bias_data[p * 8 + 0]);
                        _sum1 = _mm256_set1_ps(bias_data[p * 8 + 1]);
                        _sum2 = _mm256_set1_ps(bias_data[p * 8 + 2]);
                        _sum3 = _mm256_set1_ps(bias_data[p * 8 + 3]);
                        _sum4 = _mm256_set1_ps(bias_data[p * 8 + 4]);
                        _sum5 = _mm256_set1_ps(bias_data[p * 8 + 5]);
                        _sum6 = _mm256_set1_ps(bias_data[p * 8 + 6]);
                        _sum7 = _mm256_set1_ps(bias_data[p * 8 + 7]);
                    }

                    for (int i = 0; i < num_input; i++)
                    {
                        __m256 _val = _mm256_loadu_ps(m);
                        __m256 _k0 = _mm256_set1_ps(kptr[0]);
                        __m256 _k1 = _mm256_set1_ps(kptr[1]);
                        __m256 _k2 = _mm256_set1_ps(kptr[2]);
                        __m256 _k3 = _mm256_set1_ps(kptr[3]);
                        __m256 _k4 = _mm256_set1_ps(kptr[4]);
                        __m256 _k5 = _mm256_set1_ps(kptr[5]);
                        __m256 _k6 = _mm256_set1_ps(kptr[6]);
                        __m256 _k7 = _mm256_set1_ps(kptr[7]);
                        _sum0 = _mm256_fmadd_ps(_val, _k0, _sum0);
                        _sum1 = _mm256_fmadd_ps(_val, _k1, _sum1);
                        _sum2 = _mm256_fmadd_ps(_val, _k2, _sum2);
                        _sum3 = _mm256_fmadd_ps(_val, _k3, _sum3);
                        _sum4 = _mm256_fmadd_ps(_val, _k4, _sum4);
                        _sum5 = _mm256_fmadd_ps(_val, _k5, _sum5);
                        _sum6 = _mm256_fmadd_ps(_val, _k6, _sum6);
                        _sum7 = _mm256_fmadd_ps(_val, _k7, _sum7);

                        m += 8;
                        kptr += 8;
                    }

                    _sum0 = activation_avx(_sum0, activation_type, activation_params);
                    _sum1 = activation_avx(_sum1, activation_type, activation_params);
                    _sum2 = activation_avx(_sum2, activation_type, activation_params);
                    _sum3 = activation_avx(_sum3, activation_type, activation_params);
                    _sum4 = activation_avx(_sum4, activation_type, activation_params);
                    _sum5 = activation_avx(_sum5, activation_type, activation_params);
                    _sum6 = activation_avx(_sum6, activation_type, activation_params);
                    _sum7 = activation_avx(_sum7, activation_type, activation_params);

                    _mm256_storeu_ps(outptr, _sum0);
                    _mm256_storeu_ps(outptr + 8, _sum1);
                    _mm256_storeu_ps(outptr + 16, _sum2);
                    _mm256_storeu_ps(outptr + 24, _sum3);
                    _mm256_storeu_ps(outptr + 32, _sum4);
                    _mm256_storeu_ps(outptr + 40, _sum5);
                    _mm256_storeu_ps(outptr + 48, _sum6);
                    _mm256_storeu_ps(outptr + 56, _sum7);
                    outptr += 64;
                }
            }

            if (elempack == 1 && num_output_elempack == 8)
            {
                float* outptr = top_blob.row(j);

                for (int p = 0; p < num_output / num_output_elempack; p++)
                {
                    const float* kptr = (const float*)weight_data_packed + num_input * p * 8;
                    const float* m = bottom_blob.row(j);

                    __m256 _sum = _mm256_set1_ps(0.f);

                    if (bias_term)
                    {
                        _sum = _mm256_loadu_ps((const float*)bias_data + p * 8);
                    }

                    int i = 0;
                    for (; i + 7 < num_input; i += 8)
                    {
                        __m256 _val0 = _mm256_broadcast_ss(m);
                        __m256 _val1 = _mm256_broadcast_ss(m + 1);
                        __m256 _val2 = _mm256_broadcast_ss(m + 2);
                        __m256 _val3 = _mm256_broadcast_ss(m + 3);
                        __m256 _val4 = _mm256_broadcast_ss(m + 4);
                        __m256 _val5 = _mm256_broadcast_ss(m + 5);
                        __m256 _val6 = _mm256_broadcast_ss(m + 6);
                        __m256 _val7 = _mm256_broadcast_ss(m + 7);

                        __m256 _w0 = _mm256_loadu_ps(kptr);
                        _sum = _mm256_fmadd_ps(_val0, _w0, _sum);
                        __m256 _w1 = _mm256_loadu_ps(kptr + 8);
                        _sum = _mm256_fmadd_ps(_val1, _w1, _sum);
                        __m256 _w2 = _mm256_loadu_ps(kptr + 16);
                        _sum = _mm256_fmadd_ps(_val2, _w2, _sum);
                        __m256 _w3 = _mm256_loadu_ps(kptr + 24);
                        _sum = _mm256_fmadd_ps(_val3, _w3, _sum);
                        __m256 _w4 = _mm256_loadu_ps(kptr + 32);
                        _sum = _mm256_fmadd_ps(_val4, _w4, _sum);
                        __m256 _w5 = _mm256_loadu_ps(kptr + 40);
                        _sum = _mm256_fmadd_ps(_val5, _w5, _sum);
                        __m256 _w6 = _mm256_loadu_ps(kptr + 48);
                        _sum = _mm256_fmadd_ps(_val6, _w6, _sum);
                        __m256 _w7 = _mm256_loadu_ps(kptr + 56);
                        _sum = _mm256_fmadd_ps(_val7, _w7, _sum);

                        m += 8;
                        kptr += 64;
                    }
                    for (; i + 3 < num_input; i += 4)
                    {
                        __m256 _val0 = _mm256_broadcast_ss(m);
                        __m256 _val1 = _mm256_broadcast_ss(m + 1);
                        __m256 _val2 = _mm256_broadcast_ss(m + 2);
                        __m256 _val3 = _mm256_broadcast_ss(m + 3);

                        __m256 _w0 = _mm256_loadu_ps(kptr);
                        _sum = _mm256_fmadd_ps(_val0, _w0, _sum);
                        __m256 _w1 = _mm256_loadu_ps(kptr + 8);
                        _sum = _mm256_fmadd_ps(_val1, _w1, _sum);
                        __m256 _w2 = _mm256_loadu_ps(kptr + 16);
                        _sum = _mm256_fmadd_ps(_val2, _w2, _sum);
                        __m256 _w3 = _mm256_loadu_ps(kptr + 24);
                        _sum = _mm256_fmadd_ps(_val3, _w3, _sum);

                        m += 4;
                        kptr += 32;
                    }
                    for (; i < num_input; i++)
                    {
                        __m256 _val = _mm256_set1_ps(m[0]);
                        __m256 _w = _mm256_loadu_ps(kptr);
                        _sum = _mm256_fmadd_ps(_val, _w, _sum);

                        m += 1;
                        kptr += 8;
                    }

                    _sum = activation_avx(_sum, activation_type, activation_params);

                    _mm256_storeu_ps(outptr, _sum);
                    outptr += 8;
                }
            }

            if (elempack == 4 && num_output_elempack == 8)
            {
                float* outptr = top_blob.row(j);

                for (int p = 0; p < num_output / num_output_elempack; p++)
                {
                    const float* kptr = (const float*)weight_data_packed + num_input * p * 8;
                    const float* m = bottom_blob.row(j);

                    __m128 _sum0 = _mm_set1_ps(0.f);
                    __m128 _sum1 = _mm_set1_ps(0.f);
                    __m128 _sum2 = _mm_set1_ps(0.f);
                    __m128 _sum3 = _mm_set1_ps(0.f);
                    __m128 _sum4 = _mm_set1_ps(0.f);
                    __m128 _sum5 = _mm_set1_ps(0.f);
                    __m128 _sum6 = _mm_set1_ps(0.f);
                    __m128 _sum7 = _mm_set1_ps(0.f);

                    if (bias_term)
                    {
                        _sum0 = _mm_set1_ps(bias_data[p * 8 + 0]);
                        _sum1 = _mm_set1_ps(bias_data[p * 8 + 1]);
                        _sum2 = _mm_set1_ps(bias_data[p * 8 + 2]);
                        _sum3 = _mm_set1_ps(bias_data[p * 8 + 3]);
                        _sum4 = _mm_set1_ps(bias_data[p * 8 + 4]);
                        _sum5 = _mm_set1_ps(bias_data[p * 8 + 5]);
                        _sum6 = _mm_set1_ps(bias_data[p * 8 + 6]);
                        _sum7 = _mm_set1_ps(bias_data[p * 8 + 7]);
                    }

                    int i = 0;
                    for (; i < num_input; i++)
                    {
                        __m128 _val = _mm_loadu_ps(m);
                        _sum0 = _mm_fmadd_ps(_val, _mm_set1_ps(kptr[0]), _sum0);
                        _sum1 = _mm_fmadd_ps(_val, _mm_set1_ps(kptr[1]), _sum1);
                        _sum2 = _mm_fmadd_ps(_val, _mm_set1_ps(kptr[2]), _sum2);
                        _sum3 = _mm_fmadd_ps(_val, _mm_set1_ps(kptr[3]), _sum3);
                        _sum4 = _mm_fmadd_ps(_val, _mm_set1_ps(kptr[4]), _sum4);
                        _sum5 = _mm_fmadd_ps(_val, _mm_set1_ps(kptr[5]), _sum5);
                        _sum6 = _mm_fmadd_ps(_val, _mm_set1_ps(kptr[6]), _sum6);
                        _sum7 = _mm_fmadd_ps(_val, _mm_set1_ps(kptr[7]), _sum7);

                        m += 4;
                        kptr += 8;
                    }

                    _sum0 = activation_sse(_sum0, activation_type, activation_params);
                    _sum1 = activation_sse(_sum1, activation_type, activation_params);
                    _sum2 = activation_sse(_sum2, activation_type, activation_params);
                    _sum3 = activation_sse(_sum3, activation_type, activation_params);
                    _sum4 = activation_sse(_sum4, activation_type, activation_params);
                    _sum5 = activation_sse(_sum5, activation_type, activation_params);
                    _sum6 = activation_sse(_sum6, activation_type, activation_params);
                    _sum7 = activation_sse(_sum7, activation_type, activation_params);

                    _mm_storeu_ps(outptr, _sum0);
                    _mm_storeu_ps(outptr + 4, _sum1);
                    _mm_storeu_ps(outptr + 8, _sum2);
                    _mm_storeu_ps(outptr + 12, _sum3);
                    _mm_storeu_ps(outptr + 16, _sum4);
                    _mm_storeu_ps(outptr + 20, _sum5);
                    _mm_storeu_ps(outptr + 24, _sum6);
                    _mm_storeu_ps(outptr + 28, _sum7);
                    outptr += 32;
                }
            }

            if (elempack == 8 && num_output_elempack == 1)
            {
                float* outptr = top_blob.row(j);

                for (int p = 0; p < num_output; p++)
                {
                    const float* kptr = (const float*)weight_data + num_input * p;
                    const float* m = bottom_blob.row(j);

                    __m256 _sum0 = _mm256_set1_ps(0.f);
                    __m256 _sum1 = _mm256_set1_ps(0.f);
                    __m256 _sum2 = _mm256_set1_ps(0.f);
                    __m256 _sum3 = _mm256_set1_ps(0.f);

                    if (bias_term)
                    {
                        _sum0 = _mm256_set1_ps(bias_data[p]);
                    }

                    int i = 0;
                    for (; i + 7 < num_input; i += 8)
                    {
                        __m256 _val0 = _mm256_loadu_ps(m);
                        __m256 _val1 = _mm256_loadu_ps(m + 8);
                        __m256 _val2 = _mm256_loadu_ps(m + 16);
                        __m256 _val3 = _mm256_loadu_ps(m + 24);
                        __m256 _val4 = _mm256_loadu_ps(m + 32);
                        __m256 _val5 = _mm256_loadu_ps(m + 40);
                        __m256 _val6 = _mm256_loadu_ps(m + 48);
                        __m256 _val7 = _mm256_loadu_ps(m + 56);
                        _sum0 = _mm256_fmadd_ps(_val0, _mm256_set1_ps(kptr[0]), _sum0);
                        _sum1 = _mm256_fmadd_ps(_val1, _mm256_set1_ps(kptr[1]), _sum1);
                        _sum2 = _mm256_fmadd_ps(_val2, _mm256_set1_ps(kptr[2]), _sum2);
                        _sum3 = _mm256_fmadd_ps(_val3, _mm256_set1_ps(kptr[3]), _sum3);
                        _sum0 = _mm256_fmadd_ps(_val4, _mm256_set1_ps(kptr[4]), _sum0);
                        _sum1 = _mm256_fmadd_ps(_val5, _mm256_set1_ps(kptr[5]), _sum1);
                        _sum2 = _mm256_fmadd_ps(_val6, _mm256_set1_ps(kptr[6]), _sum2);
                        _sum3 = _mm256_fmadd_ps(_val7, _mm256_set1_ps(kptr[7]), _sum3);

                        m += 64;
                        kptr += 8;
                    }
                    for (; i + 3 < num_input; i += 4)
                    {
                        __m256 _val0 = _mm256_loadu_ps(m);
                        __m256 _val1 = _mm256_loadu_ps(m + 8);
                        __m256 _val2 = _mm256_loadu_ps(m + 16);
                        __m256 _val3 = _mm256_loadu_ps(m + 24);
                        _sum0 = _mm256_fmadd_ps(_val0, _mm256_set1_ps(kptr[0]), _sum0);
                        _sum1 = _mm256_fmadd_ps(_val1, _mm256_set1_ps(kptr[1]), _sum1);
                        _sum2 = _mm256_fmadd_ps(_val2, _mm256_set1_ps(kptr[2]), _sum2);
                        _sum3 = _mm256_fmadd_ps(_val3, _mm256_set1_ps(kptr[3]), _sum3);

                        m += 32;
                        kptr += 4;
                    }
                    for (; i < num_input; i++)
                    {
                        __m256 _val = _mm256_loadu_ps(m);
                        __m256 _k = _mm256_set1_ps(kptr[0]);
                        _sum0 = _mm256_fmadd_ps(_val, _k, _sum0);

                        m += 8;
                        kptr += 1;
                    }

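                    // merge the four interleaved accumulators pairwise before the activation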
                    _sum0 = _mm256_add_ps(_sum0, _sum1);
                    _sum2 = _mm256_add_ps(_sum2, _sum3);
                    _sum0 = _mm256_add_ps(_sum0, _sum2);

                    _sum0 = activation_avx(_sum0, activation_type, activation_params);

                    _mm256_storeu_ps(outptr, _sum0);
                    outptr += 8;
                }
            }

            if (elempack == 8 && num_output_elempack == 4)
            {
                float* outptr = top_blob.row(j);

                for (int p = 0; p < num_output / num_output_elempack; p++)
                {
                    const float* kptr = (const float*)weight_data_packed + num_input * p * 4;
                    const float* m = bottom_blob.row(j);

                    __m256 _sum0 = _mm256_set1_ps(0.f);
                    __m256 _sum1 = _mm256_set1_ps(0.f);
                    __m256 _sum2 = _mm256_set1_ps(0.f);
                    __m256 _sum3 = _mm256_set1_ps(0.f);

                    if (bias_term)
                    {
                        _sum0 = _mm256_set1_ps(bias_data[p * 4 + 0]);
                        _sum1 = _mm256_set1_ps(bias_data[p * 4 + 1]);
                        _sum2 = _mm256_set1_ps(bias_data[p * 4 + 2]);
                        _sum3 = _mm256_set1_ps(bias_data[p * 4 + 3]);
                    }

                    int i = 0;
                    for (; i + 3 < num_input; i += 4)
                    {
                        __m256 _val0 = _mm256_loadu_ps(m);
                        __m256 _val1 = _mm256_loadu_ps(m + 8);
                        __m256 _val2 = _mm256_loadu_ps(m + 16);
                        __m256 _val3 = _mm256_loadu_ps(m + 24);
                        _sum0 = _mm256_fmadd_ps(_val0, _mm256_set1_ps(kptr[0]), _sum0);
                        _sum1 = _mm256_fmadd_ps(_val0, _mm256_set1_ps(kptr[1]), _sum1);
                        _sum2 = _mm256_fmadd_ps(_val0, _mm256_set1_ps(kptr[2]), _sum2);
                        _sum3 = _mm256_fmadd_ps(_val0, _mm256_set1_ps(kptr[3]), _sum3);
                        _sum0 = _mm256_fmadd_ps(_val1, _mm256_set1_ps(kptr[4]), _sum0);
                        _sum1 = _mm256_fmadd_ps(_val1, _mm256_set1_ps(kptr[5]), _sum1);
                        _sum2 = _mm256_fmadd_ps(_val1, _mm256_set1_ps(kptr[6]), _sum2);
                        _sum3 = _mm256_fmadd_ps(_val1, _mm256_set1_ps(kptr[7]), _sum3);
                        kptr += 8;

                        _sum0 = _mm256_fmadd_ps(_val2, _mm256_set1_ps(kptr[0]), _sum0);
                        _sum1 = _mm256_fmadd_ps(_val2, _mm256_set1_ps(kptr[1]), _sum1);
                        _sum2 = _mm256_fmadd_ps(_val2, _mm256_set1_ps(kptr[2]), _sum2);
                        _sum3 = _mm256_fmadd_ps(_val2, _mm256_set1_ps(kptr[3]), _sum3);
                        _sum0 = _mm256_fmadd_ps(_val3, _mm256_set1_ps(kptr[4]), _sum0);
                        _sum1 = _mm256_fmadd_ps(_val3, _mm256_set1_ps(kptr[5]), _sum1);
                        _sum2 = _mm256_fmadd_ps(_val3, _mm256_set1_ps(kptr[6]), _sum2);
                        _sum3 = _mm256_fmadd_ps(_val3, _mm256_set1_ps(kptr[7]), _sum3);

                        m += 32;
                        kptr += 8;
                    }
                    for (; i < num_input; i++)
                    {
                        __m256 _val = _mm256_loadu_ps(m);
                        _sum0 = _mm256_fmadd_ps(_val, _mm256_set1_ps(kptr[0]), _sum0);
                        _sum1 = _mm256_fmadd_ps(_val, _mm256_set1_ps(kptr[1]), _sum1);
                        _sum2 = _mm256_fmadd_ps(_val, _mm256_set1_ps(kptr[2]), _sum2);
                        _sum3 = _mm256_fmadd_ps(_val, _mm256_set1_ps(kptr[3]), _sum3);

                        m += 8;
                        kptr += 4;
                    }

                    _sum0 = activation_avx(_sum0, activation_type, activation_params);
                    _sum1 = activation_avx(_sum1, activation_type, activation_params);
                    _sum2 = activation_avx(_sum2, activation_type, activation_params);
                    _sum3 = activation_avx(_sum3, activation_type, activation_params);

                    _mm256_storeu_ps(outptr, _sum0);
                    _mm256_storeu_ps(outptr + 8, _sum1);
                    _mm256_storeu_ps(outptr + 16, _sum2);
                    _mm256_storeu_ps(outptr + 24, _sum3);
                    outptr += 32;
                }
            }
#endif // __AVX__

            if (elempack == 4 && num_output_elempack == 4)
            {
                float* outptr = top_blob.row(j);

                for (int p = 0; p < num_output / num_output_elempack; p++)
                {
                    const float* kptr = (const float*)weight_data_packed + num_input * p * 4;
                    const float* m = bottom_blob.row(j);

                    __m128 _sum0 = _mm_set1_ps(0.f);
                    __m128 _sum1 = _mm_set1_ps(0.f);
                    __m128 _sum2 = _mm_set1_ps(0.f);
                    __m128 _sum3 = _mm_set1_ps(0.f);

                    if (bias_term)
                    {
                        _sum0 = _mm_set1_ps(bias_data[p * 4 + 0]);
                        _sum1 = _mm_set1_ps(bias_data[p * 4 + 1]);
                        _sum2 = _mm_set1_ps(bias_data[p * 4 + 2]);
                        _sum3 = _mm_set1_ps(bias_data[p * 4 + 3]);
                    }

                    int i = 0;
                    for (; i + 3 < num_input; i += 4)
                    {
                        __m128 _val0 = _mm_loadu_ps(m);
                        __m128 _val1 = _mm_loadu_ps(m + 4);
                        __m128 _val2 = _mm_loadu_ps(m + 8);
                        __m128 _val3 = _mm_loadu_ps(m + 12);
                        _sum0 = _mm_add_ps(_mm_mul_ps(_val0, _mm_set1_ps(kptr[0])), _sum0);
                        _sum1 = _mm_add_ps(_mm_mul_ps(_val0, _mm_set1_ps(kptr[1])), _sum1);
                        _sum2 = _mm_add_ps(_mm_mul_ps(_val0, _mm_set1_ps(kptr[2])), _sum2);
                        _sum3 = _mm_add_ps(_mm_mul_ps(_val0, _mm_set1_ps(kptr[3])), _sum3);
                        _sum0 = _mm_add_ps(_mm_mul_ps(_val1, _mm_set1_ps(kptr[4])), _sum0);
                        _sum1 = _mm_add_ps(_mm_mul_ps(_val1, _mm_set1_ps(kptr[5])), _sum1);
                        _sum2 = _mm_add_ps(_mm_mul_ps(_val1, _mm_set1_ps(kptr[6])), _sum2);
                        _sum3 = _mm_add_ps(_mm_mul_ps(_val1, _mm_set1_ps(kptr[7])), _sum3);
                        _sum0 = _mm_add_ps(_mm_mul_ps(_val2, _mm_set1_ps(kptr[8])), _sum0);
                        _sum1 = _mm_add_ps(_mm_mul_ps(_val2, _mm_set1_ps(kptr[9])), _sum1);
                        _sum2 = _mm_add_ps(_mm_mul_ps(_val2, _mm_set1_ps(kptr[10])), _sum2);
                        _sum3 = _mm_add_ps(_mm_mul_ps(_val2, _mm_set1_ps(kptr[11])), _sum3);
                        _sum0 = _mm_add_ps(_mm_mul_ps(_val3, _mm_set1_ps(kptr[12])), _sum0);
                        _sum1 = _mm_add_ps(_mm_mul_ps(_val3, _mm_set1_ps(kptr[13])), _sum1);
                        _sum2 = _mm_add_ps(_mm_mul_ps(_val3, _mm_set1_ps(kptr[14])), _sum2);
                        _sum3 = _mm_add_ps(_mm_mul_ps(_val3, _mm_set1_ps(kptr[15])), _sum3);

                        m += 16;
                        kptr += 16;
                    }
                    for (; i < num_input; i++)
                    {
                        __m128 _val = _mm_loadu_ps(m);
                        _sum0 = _mm_add_ps(_mm_mul_ps(_val, _mm_set1_ps(kptr[0])), _sum0);
                        _sum1 = _mm_add_ps(_mm_mul_ps(_val, _mm_set1_ps(kptr[1])), _sum1);
                        _sum2 = _mm_add_ps(_mm_mul_ps(_val, _mm_set1_ps(kptr[2])), _sum2);
                        _sum3 = _mm_add_ps(_mm_mul_ps(_val, _mm_set1_ps(kptr[3])), _sum3);

                        m += 4;
                        kptr += 4;
                    }

                    _sum0 = activation_sse(_sum0, activation_type, activation_params);
                    _sum1 = activation_sse(_sum1, activation_type, activation_params);
                    _sum2 = activation_sse(_sum2, activation_type, activation_params);
                    _sum3 = activation_sse(_sum3, activation_type, activation_params);

                    _mm_storeu_ps(outptr, _sum0);
                    _mm_storeu_ps(outptr + 4, _sum1);
                    _mm_storeu_ps(outptr + 8, _sum2);
                    _mm_storeu_ps(outptr + 12, _sum3);
                    outptr += 16;
                }
            }

            if (elempack == 1 && num_output_elempack == 4)
            {
                float* outptr = top_blob.row(j);

                for (int p = 0; p < num_output / num_output_elempack; p++)
                {
                    const float* kptr = (const float*)weight_data_packed + num_input * p * 4;
                    const float* m = bottom_blob.row(j);

                    __m128 _sum = _mm_set1_ps(0.f);

                    if (bias_term)
                    {
                        _sum = _mm_loadu_ps((const float*)bias_data + p * 4);
                    }

                    int i = 0;
#if __AVX__
                    for (; i + 7 < num_input; i += 8)
                    {
                        __m128 _val0 = _mm_broadcast_ss(m);
                        __m128 _val1 = _mm_broadcast_ss(m + 1);
                        __m128 _val2 = _mm_broadcast_ss(m + 2);
                        __m128 _val3 = _mm_broadcast_ss(m + 3);
                        __m128 _val4 = _mm_broadcast_ss(m + 4);
                        __m128 _val5 = _mm_broadcast_ss(m + 5);
                        __m128 _val6 = _mm_broadcast_ss(m + 6);
                        __m128 _val7 = _mm_broadcast_ss(m + 7);

                        __m128 _w0 = _mm_loadu_ps(kptr);
                        _sum = _mm_fmadd_ps(_val0, _w0, _sum);
                        __m128 _w1 = _mm_loadu_ps(kptr + 4);
                        _sum = _mm_fmadd_ps(_val1, _w1, _sum);
                        __m128 _w2 = _mm_loadu_ps(kptr + 8);
                        _sum = _mm_fmadd_ps(_val2, _w2, _sum);
                        __m128 _w3 = _mm_loadu_ps(kptr + 12);
                        _sum = _mm_fmadd_ps(_val3, _w3, _sum);
                        __m128 _w4 = _mm_loadu_ps(kptr + 16);
                        _sum = _mm_fmadd_ps(_val4, _w4, _sum);
                        __m128 _w5 = _mm_loadu_ps(kptr + 20);
                        _sum = _mm_fmadd_ps(_val5, _w5, _sum);
                        __m128 _w6 = _mm_loadu_ps(kptr + 24);
                        _sum = _mm_fmadd_ps(_val6, _w6, _sum);
                        __m128 _w7 = _mm_loadu_ps(kptr + 28);
                        _sum = _mm_fmadd_ps(_val7, _w7, _sum);

                        m += 8;
                        kptr += 32;
                    }
#endif // __AVX__
                    for (; i + 3 < num_input; i += 4)
                    {
                        __m128 _val0 = _mm_set1_ps(m[0]);
                        __m128 _val1 = _mm_set1_ps(m[1]);
                        __m128 _val2 = _mm_set1_ps(m[2]);
                        __m128 _val3 = _mm_set1_ps(m[3]);

                        __m128 _w0 = _mm_loadu_ps(kptr);
                        _sum = _mm_add_ps(_mm_mul_ps(_val0, _w0), _sum);
                        __m128 _w1 = _mm_loadu_ps(kptr + 4);
                        _sum = _mm_add_ps(_mm_mul_ps(_val1, _w1), _sum);
                        __m128 _w2 = _mm_loadu_ps(kptr + 8);
                        _sum = _mm_add_ps(_mm_mul_ps(_val2, _w2), _sum);
                        __m128 _w3 = _mm_loadu_ps(kptr + 12);
                        _sum = _mm_add_ps(_mm_mul_ps(_val3, _w3), _sum);

                        m += 4;
                        kptr += 16;
                    }
                    for (; i < num_input; i++)
                    {
                        __m128 _val = _mm_set1_ps(m[0]);
                        __m128 _k = _mm_loadu_ps(kptr);
                        _sum = _mm_add_ps(_mm_mul_ps(_val, _k), _sum);

                        m += 1;
                        kptr += 4;
                    }

                    _sum = activation_sse(_sum, activation_type, activation_params);

                    _mm_storeu_ps(outptr, _sum);
                    outptr += 4;
                }
            }

            if (elempack == 4 && num_output_elempack == 1)
            {
                float* outptr = top_blob.row(j);

                for (int p = 0; p < num_output; p++)
                {
                    const float* kptr = (const float*)weight_data + num_input * p;
                    const float* m = bottom_blob.row(j);

                    __m128 _sum0 = _mm_set1_ps(0.f);
                    __m128 _sum1 = _mm_set1_ps(0.f);
                    __m128 _sum2 = _mm_set1_ps(0.f);
                    __m128 _sum3 = _mm_set1_ps(0.f);

                    if (bias_term)
                    {
                        _sum0 = _mm_set1_ps(bias_data[p]);
                    }

                    int i = 0;
                    for (; i + 7 < num_input; i += 8)
                    {
                        __m128 _val0 = _mm_loadu_ps(m);
                        __m128 _val1 = _mm_loadu_ps(m + 4);
                        __m128 _val2 = _mm_loadu_ps(m + 8);
                        __m128 _val3 = _mm_loadu_ps(m + 12);
                        __m128 _val4 = _mm_loadu_ps(m + 16);
                        __m128 _val5 = _mm_loadu_ps(m + 20);
                        __m128 _val6 = _mm_loadu_ps(m + 24);
                        __m128 _val7 = _mm_loadu_ps(m + 28);
                        _sum0 = _mm_add_ps(_mm_mul_ps(_val0, _mm_set1_ps(kptr[0])), _sum0);
                        _sum1 = _mm_add_ps(_mm_mul_ps(_val1, _mm_set1_ps(kptr[1])), _sum1);
                        _sum2 = _mm_add_ps(_mm_mul_ps(_val2, _mm_set1_ps(kptr[2])), _sum2);
                        _sum3 = _mm_add_ps(_mm_mul_ps(_val3, _mm_set1_ps(kptr[3])), _sum3);
                        _sum0 = _mm_add_ps(_mm_mul_ps(_val4, _mm_set1_ps(kptr[4])), _sum0);
                        _sum1 = _mm_add_ps(_mm_mul_ps(_val5, _mm_set1_ps(kptr[5])), _sum1);
                        _sum2 = _mm_add_ps(_mm_mul_ps(_val6, _mm_set1_ps(kptr[6])), _sum2);
                        _sum3 = _mm_add_ps(_mm_mul_ps(_val7, _mm_set1_ps(kptr[7])), _sum3);

                        m += 32;
                        kptr += 8;
                    }
                    for (; i + 3 < num_input; i += 4)
                    {
                        __m128 _val0 = _mm_loadu_ps(m);
                        __m128 _val1 = _mm_loadu_ps(m + 4);
                        __m128 _val2 = _mm_loadu_ps(m + 8);
                        __m128 _val3 = _mm_loadu_ps(m + 12);
                        _sum0 = _mm_add_ps(_mm_mul_ps(_val0, _mm_set1_ps(kptr[0])), _sum0);
                        _sum1 = _mm_add_ps(_mm_mul_ps(_val1, _mm_set1_ps(kptr[1])), _sum1);
                        _sum2 = _mm_add_ps(_mm_mul_ps(_val2, _mm_set1_ps(kptr[2])), _sum2);
                        _sum3 = _mm_add_ps(_mm_mul_ps(_val3, _mm_set1_ps(kptr[3])), _sum3);

                        m += 16;
                        kptr += 4;
                    }
                    for (; i < num_input; i++)
                    {
                        __m128 _val = _mm_loadu_ps(m);
                        __m128 _k = _mm_set1_ps(kptr[0]);
                        _sum0 = _mm_add_ps(_mm_mul_ps(_val, _k), _sum0);

                        m += 4;
                        kptr += 1;
                    }

                    _sum0 = _mm_add_ps(_sum0, _sum1);
                    _sum2 = _mm_add_ps(_sum2, _sum3);
                    _sum0 = _mm_add_ps(_sum0, _sum2);

                    _sum0 = activation_sse(_sum0, activation_type, activation_params);

                    _mm_storeu_ps(outptr, _sum0);
                    outptr += 4;
                }
            }
#endif // __SSE2__

            if (elempack == 1 && num_output_elempack == 1)
            {
                float* outptr = top_blob.row(j);

                for (int p = 0; p < num_output; p++)
                {
                    const float* kptr = (const float*)weight_data + num_input * p;
                    const float* m = bottom_blob.row(j);

                    float sum = 0.f;

                    if (bias_term)
                    {
                        sum = bias_data[p];
                    }

                    int i = 0;
#if __SSE2__
#if __AVX__
                    __m256 _sum = _mm256_set1_ps(0.f);
                    for (; i + 7 < num_input; i += 8)
                    {
                        __m256 _m = _mm256_loadu_ps(m);
                        __m256 _w = _mm256_loadu_ps(kptr);
                        _sum = _mm256_fmadd_ps(_m, _w, _sum);

                        m += 8;
                        kptr += 8;
                    }
#endif // __AVX__
                    __m128 _suml = _mm_set1_ps(0.f);
                    for (; i + 3 < num_input; i += 4)
                    {
                        __m128 _val = _mm_loadu_ps(m);
                        __m128 _k = _mm_loadu_ps(kptr);
                        _suml = _mm_add_ps(_mm_mul_ps(_val, _k), _suml);

                        m += 4;
                        kptr += 4;
                    }
#endif // __SSE2__
                    for (; i < num_input; i++)
                    {
                        sum += *m++ * *kptr++;
                    }

#if __SSE2__
#if __AVX__
                    sum += _mm256_reduce_add_ps(_sum);
#endif // __AVX__
                    sum += _mm_reduce_add_ps(_suml);
#endif // __SSE2__

                    sum = activation_ss(sum, activation_type, activation_params);

                    outptr[0] = sum;
                    outptr += 1;
                }
            }
        }

        return 0;
    }

#if __AVX__
    if (opt.use_weight_fp16_storage)
    {
        return forward_fp16(bottom_blob, top_blob, opt);
    }
#endif // __AVX__

    // flatten
    Mat bottom_blob_flattened = bottom_blob;
    if (bottom_blob.dims != 1)
    {
        Option opt_flatten = opt;
        opt_flatten.blob_allocator = opt.workspace_allocator;

        flatten->forward(bottom_blob, bottom_blob_flattened, opt_flatten);
    }

    size_t elemsize = bottom_blob_flattened.elemsize;
    int elempack = bottom_blob_flattened.elempack;

    int out_elempack = 1;
#if __SSE2__
    if (opt.use_packing_layout)
    {
#if __AVX__
        out_elempack = num_output % 8 == 0 ? 8 : num_output % 4 == 0 ? 4 : 1;
#else
        out_elempack = num_output % 4 == 0 ? 4 : 1;
#endif
    }
#endif // __SSE2__
    size_t out_elemsize = elemsize / elempack * out_elempack;

    top_blob.create(num_output / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

#if __SSE2__
#if __AVX__
    if (out_elempack == 8)
    {
        // num_output
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < num_output / out_elempack; p++)
        {
            __m256 _sum0 = _mm256_set1_ps(0.f);
            __m256 _sum1 = _mm256_set1_ps(0.f);
            __m256 _sum2 = _mm256_set1_ps(0.f);
            __m256 _sum3 = _mm256_set1_ps(0.f);
            __m256 _sum4 = _mm256_set1_ps(0.f);
            __m256 _sum5 = _mm256_set1_ps(0.f);
            __m256 _sum6 = _mm256_set1_ps(0.f);
            __m256 _sum7 = _mm256_set1_ps(0.f);

            if (bias_term)
            {
                _sum0 = _mm256_loadu_ps((const float*)bias_data + p * 8);
            }

            const float* kptr = weight_data_packed.row(p);

            const float* sptr = bottom_blob_flattened;

            int i = 0;
            for (; i + 7 < num_input; i += 8)
            {
                __m256 _val0 = _mm256_broadcast_ss(sptr);
                __m256 _val1 = _mm256_broadcast_ss(sptr + 1);
                __m256 _val2 = _mm256_broadcast_ss(sptr + 2);
                __m256 _val3 = _mm256_broadcast_ss(sptr + 3);
                __m256 _val4 = _mm256_broadcast_ss(sptr + 4);
                __m256 _val5 = _mm256_broadcast_ss(sptr + 5);
                __m256 _val6 = _mm256_broadcast_ss(sptr + 6);
                __m256 _val7 = _mm256_broadcast_ss(sptr + 7);

                __m256 _w0 = _mm256_loadu_ps(kptr);
                _sum0 = _mm256_fmadd_ps(_val0, _w0, _sum0);
                __m256 _w1 = _mm256_loadu_ps(kptr + 8);
                _sum1 = _mm256_fmadd_ps(_val1, _w1, _sum1);
                __m256 _w2 = _mm256_loadu_ps(kptr + 16);
                _sum2 = _mm256_fmadd_ps(_val2, _w2, _sum2);
                __m256 _w3 = _mm256_loadu_ps(kptr + 24);
                _sum3 = _mm256_fmadd_ps(_val3, _w3, _sum3);
                __m256 _w4 = _mm256_loadu_ps(kptr + 32);
                _sum4 = _mm256_fmadd_ps(_val4, _w4, _sum4);
                __m256 _w5 = _mm256_loadu_ps(kptr + 40);
                _sum5 = _mm256_fmadd_ps(_val5, _w5, _sum5);
                __m256 _w6 = _mm256_loadu_ps(kptr + 48);
                _sum6 = _mm256_fmadd_ps(_val6, _w6, _sum6);
                __m256 _w7 = _mm256_loadu_ps(kptr + 56);
                _sum7 = _mm256_fmadd_ps(_val7, _w7, _sum7);

                sptr += 8;
                kptr += 64;
            }
            for (; i + 3 < num_input; i += 4)
            {
                __m256 _val0 = _mm256_broadcast_ss(sptr);
                __m256 _val1 = _mm256_broadcast_ss(sptr + 1);
                __m256 _val2 = _mm256_broadcast_ss(sptr + 2);
                __m256 _val3 = _mm256_broadcast_ss(sptr + 3);

                __m256 _w0 = _mm256_loadu_ps(kptr);
                _sum0 = _mm256_fmadd_ps(_val0, _w0, _sum0);
                __m256 _w1 = _mm256_loadu_ps(kptr + 8);
                _sum1 = _mm256_fmadd_ps(_val1, _w1, _sum1);
                __m256 _w2 = _mm256_loadu_ps(kptr + 16);
                _sum2 = _mm256_fmadd_ps(_val2, _w2, _sum2);
                __m256 _w3 = _mm256_loadu_ps(kptr + 24);
                _sum3 = _mm256_fmadd_ps(_val3, _w3, _sum3);

                sptr += 4;
                kptr += 32;
            }
            for (; i < num_input; i++)
            {
                __m256 _val = _mm256_set1_ps(sptr[0]);
                __m256 _w = _mm256_loadu_ps(kptr);
                _sum0 = _mm256_fmadd_ps(_val, _w, _sum0);

                sptr += 1;
                kptr += 8;
            }

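            // fold the eight partial accumulators into _sum0 with a three-level adder tree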
            _sum0 = _mm256_add_ps(_sum0, _sum1);
            _sum2 = _mm256_add_ps(_sum2, _sum3);
            _sum4 = _mm256_add_ps(_sum4, _sum5);
            _sum6 = _mm256_add_ps(_sum6, _sum7);
            _sum0 = _mm256_add_ps(_sum0, _sum2);
            _sum4 = _mm256_add_ps(_sum4, _sum6);
            _sum0 = _mm256_add_ps(_sum0, _sum4);

            _sum0 = activation_avx(_sum0, activation_type, activation_params);

            float* outptr = top_blob;
            _mm256_storeu_ps(outptr + p * 8, _sum0);
        }
    }
#endif // __AVX__

    if (out_elempack == 4)
    {
        // num_output
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < num_output / out_elempack; p++)
        {
            __m128 _sum0 = _mm_set1_ps(0.f);
            __m128 _sum1 = _mm_set1_ps(0.f);
            __m128 _sum2 = _mm_set1_ps(0.f);
            __m128 _sum3 = _mm_set1_ps(0.f);
#if __AVX__
            __m128 _sum4 = _mm_set1_ps(0.f);
            __m128 _sum5 = _mm_set1_ps(0.f);
            __m128 _sum6 = _mm_set1_ps(0.f);
            __m128 _sum7 = _mm_set1_ps(0.f);
#endif

            if (bias_term)
            {
                _sum0 = _mm_loadu_ps((const float*)bias_data + p * 4);
            }

            const float* kptr = weight_data_packed.row(p);

            const float* sptr = bottom_blob_flattened;

            int i = 0;
#if __AVX__
            for (; i + 7 < num_input; i += 8)
            {
                __m128 _val0 = _mm_broadcast_ss(sptr);
                __m128 _val1 = _mm_broadcast_ss(sptr + 1);
                __m128 _val2 = _mm_broadcast_ss(sptr + 2);
                __m128 _val3 = _mm_broadcast_ss(sptr + 3);
                __m128 _val4 = _mm_broadcast_ss(sptr + 4);
                __m128 _val5 = _mm_broadcast_ss(sptr + 5);
                __m128 _val6 = _mm_broadcast_ss(sptr + 6);
                __m128 _val7 = _mm_broadcast_ss(sptr + 7);

                __m128 _w0 = _mm_loadu_ps(kptr);
                _sum0 = _mm_fmadd_ps(_val0, _w0, _sum0);
                __m128 _w1 = _mm_loadu_ps(kptr + 4);
                _sum1 = _mm_fmadd_ps(_val1, _w1, _sum1);
                __m128 _w2 = _mm_loadu_ps(kptr + 8);
                _sum2 = _mm_fmadd_ps(_val2, _w2, _sum2);
                __m128 _w3 = _mm_loadu_ps(kptr + 12);
                _sum3 = _mm_fmadd_ps(_val3, _w3, _sum3);
                __m128 _w4 = _mm_loadu_ps(kptr + 16);
                _sum4 = _mm_fmadd_ps(_val4, _w4, _sum4);
                __m128 _w5 = _mm_loadu_ps(kptr + 20);
                _sum5 = _mm_fmadd_ps(_val5, _w5, _sum5);
                __m128 _w6 = _mm_loadu_ps(kptr + 24);
                _sum6 = _mm_fmadd_ps(_val6, _w6, _sum6);
                __m128 _w7 = _mm_loadu_ps(kptr + 28);
                _sum7 = _mm_fmadd_ps(_val7, _w7, _sum7);

                sptr += 8;
                kptr += 32;
            }
#endif
            for (; i + 3 < num_input; i += 4)
            {
                __m128 _val0 = _mm_set1_ps(sptr[0]);
                __m128 _val1 = _mm_set1_ps(sptr[1]);
                __m128 _val2 = _mm_set1_ps(sptr[2]);
                __m128 _val3 = _mm_set1_ps(sptr[3]);

                __m128 _w0 = _mm_loadu_ps(kptr);
                _sum0 = _mm_add_ps(_mm_mul_ps(_val0, _w0), _sum0);
                __m128 _w1 = _mm_loadu_ps(kptr + 4);
                _sum1 = _mm_add_ps(_mm_mul_ps(_val1, _w1), _sum1);
                __m128 _w2 = _mm_loadu_ps(kptr + 8);
                _sum2 = _mm_add_ps(_mm_mul_ps(_val2, _w2), _sum2);
                __m128 _w3 = _mm_loadu_ps(kptr + 12);
                _sum3 = _mm_add_ps(_mm_mul_ps(_val3, _w3), _sum3);

                sptr += 4;
                kptr += 16;
            }
            for (; i < num_input; i++)
            {
                __m128 _val = _mm_set1_ps(sptr[0]);
                __m128 _w = _mm_loadu_ps(kptr);
                _sum0 = _mm_add_ps(_mm_mul_ps(_val, _w), _sum0);

                sptr += 1;
                kptr += 4;
            }

            _sum0 = _mm_add_ps(_sum0, _sum1);
            _sum2 = _mm_add_ps(_sum2, _sum3);
#if __AVX__
            _sum4 = _mm_add_ps(_sum4, _sum5);
            _sum6 = _mm_add_ps(_sum6, _sum7);
#endif
            _sum0 = _mm_add_ps(_sum0, _sum2);
#if __AVX__
            _sum4 = _mm_add_ps(_sum4, _sum6);
            _sum0 = _mm_add_ps(_sum0, _sum4);
#endif

            _sum0 = activation_sse(_sum0, activation_type, activation_params);

            float* outptr = top_blob;
            _mm_storeu_ps(outptr + p * 4, _sum0);
        }
    }
#endif // __SSE2__

    if (out_elempack == 1)
    {
#if __SSE2__
#if __AVX__
        int remain_num_output_start = 0;
        int nn_num_output = num_output >> 3;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int pp = 0; pp < nn_num_output; pp++)
        {
            int p = pp * 8;

            float sums[8] = {0.0f};
            if (bias_term)
            {
                sums[0] = bias_data[p];
                sums[1] = bias_data[p + 1];
                sums[2] = bias_data[p + 2];
                sums[3] = bias_data[p + 3];
                sums[4] = bias_data[p + 4];
                sums[5] = bias_data[p + 5];
                sums[6] = bias_data[p + 6];
                sums[7] = bias_data[p + 7];
            }

            const float* w0 = (const float*)weight_data + num_input * p;
            const float* w1 = (const float*)weight_data + num_input * (p + 1);
            const float* w2 = (const float*)weight_data + num_input * (p + 2);
            const float* w3 = (const float*)weight_data + num_input * (p + 3);
            const float* w4 = (const float*)weight_data + num_input * (p + 4);
            const float* w5 = (const float*)weight_data + num_input * (p + 5);
            const float* w6 = (const float*)weight_data + num_input * (p + 6);
            const float* w7 = (const float*)weight_data + num_input * (p + 7);

            const float* m = bottom_blob_flattened;

            __m256 _sum0 = _mm256_set1_ps(0.f);
            __m256 _sum1 = _mm256_set1_ps(0.f);
            __m256 _sum2 = _mm256_set1_ps(0.f);
            __m256 _sum3 = _mm256_set1_ps(0.f);
            __m256 _sum4 = _mm256_set1_ps(0.f);
            __m256 _sum5 = _mm256_set1_ps(0.f);
            __m256 _sum6 = _mm256_set1_ps(0.f);
            __m256 _sum7 = _mm256_set1_ps(0.f);

            int i = 0;
            for (; i + 7 < num_input; i += 8)
            {
                __m256 _m = _mm256_loadu_ps(m);

                __m256 _w0 = _mm256_loadu_ps(w0);
                _sum0 = _mm256_fmadd_ps(_m, _w0, _sum0);
                __m256 _w1 = _mm256_loadu_ps(w1);
                _sum1 = _mm256_fmadd_ps(_m, _w1, _sum1);
                __m256 _w2 = _mm256_loadu_ps(w2);
                _sum2 = _mm256_fmadd_ps(_m, _w2, _sum2);
                __m256 _w3 = _mm256_loadu_ps(w3);
                _sum3 = _mm256_fmadd_ps(_m, _w3, _sum3);
                __m256 _w4 = _mm256_loadu_ps(w4);
                _sum4 = _mm256_fmadd_ps(_m, _w4, _sum4);
                __m256 _w5 = _mm256_loadu_ps(w5);
                _sum5 = _mm256_fmadd_ps(_m, _w5, _sum5);
                __m256 _w6 = _mm256_loadu_ps(w6);
                _sum6 = _mm256_fmadd_ps(_m, _w6, _sum6);
                __m256 _w7 = _mm256_loadu_ps(w7);
                _sum7 = _mm256_fmadd_ps(_m, _w7, _sum7);

                m += 8;
                w0 += 8;
                w1 += 8;
                w2 += 8;
                w3 += 8;
                w4 += 8;
                w5 += 8;
                w6 += 8;
                w7 += 8;
            }
            for (; i < num_input; i++)
            {
                sums[0] += *m * *w0;
                sums[1] += *m * *w1;
                sums[2] += *m * *w2;
                sums[3] += *m * *w3;
                sums[4] += *m * *w4;
                sums[5] += *m * *w5;
                sums[6] += *m * *w6;
                sums[7] += *m * *w7;

                m++;
                w0++;
                w1++;
                w2++;
                w3++;
                w4++;
                w5++;
                w6++;
                w7++;
            }

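            // HorizontalSums (from x86_usability.h) reduces each vector accumulator
            // to a scalar and packs the eight results into one __m256 for the bias add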
            __m256 _sums = HorizontalSums(_sum0, _sum1, _sum2, _sum3, _sum4, _sum5, _sum6, _sum7);
            __m256 _sums_f = _mm256_loadu_ps(sums);
            _sums = _mm256_add_ps(_sums_f, _sums);
            _sums = activation_avx(_sums, activation_type, activation_params);

            float* outptr = top_blob;
            _mm256_storeu_ps(outptr + p, _sums);
        }

        remain_num_output_start += (nn_num_output << 3);
        nn_num_output = (num_output - remain_num_output_start) >> 2;
#else
        int remain_num_output_start = 0;
        int nn_num_output = num_output >> 2;
#endif // __AVX__

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int pp = 0; pp < nn_num_output; pp++)
        {
            int p = remain_num_output_start + (pp * 4);

            float sums[4] = {0.0f};
            if (bias_term)
            {
                sums[0] = bias_data[p];
                sums[1] = bias_data[p + 1];
                sums[2] = bias_data[p + 2];
                sums[3] = bias_data[p + 3];
            }

            const float* w0 = (const float*)weight_data + num_input * p;
            const float* w1 = (const float*)weight_data + num_input * (p + 1);
            const float* w2 = (const float*)weight_data + num_input * (p + 2);
            const float* w3 = (const float*)weight_data + num_input * (p + 3);

            const float* m = bottom_blob_flattened;

            int i = 0;
#if __AVX__
            __m256 _sum0 = _mm256_set1_ps(0.f);
            __m256 _sum1 = _mm256_set1_ps(0.f);
            __m256 _sum2 = _mm256_set1_ps(0.f);
            __m256 _sum3 = _mm256_set1_ps(0.f);
            for (; i + 7 < num_input; i += 8)
            {
                __m256 _m = _mm256_loadu_ps(m);

                __m256 _w0 = _mm256_loadu_ps(w0);
                _sum0 = _mm256_fmadd_ps(_m, _w0, _sum0);
                __m256 _w1 = _mm256_loadu_ps(w1);
                _sum1 = _mm256_fmadd_ps(_m, _w1, _sum1);
                __m256 _w2 = _mm256_loadu_ps(w2);
                _sum2 = _mm256_fmadd_ps(_m, _w2, _sum2);
                __m256 _w3 = _mm256_loadu_ps(w3);
                _sum3 = _mm256_fmadd_ps(_m, _w3, _sum3);

                m += 8;
                w0 += 8;
                w1 += 8;
                w2 += 8;
                w3 += 8;
            }
#endif // __AVX__
            __m128 _sum0l = _mm_set1_ps(0.f);
            __m128 _sum1l = _mm_set1_ps(0.f);
            __m128 _sum2l = _mm_set1_ps(0.f);
            __m128 _sum3l = _mm_set1_ps(0.f);
            for (; i + 3 < num_input; i += 4)
            {
                __m128 _m = _mm_loadu_ps(m);

                __m128 _w0 = _mm_loadu_ps(w0);
                _sum0l = _mm_add_ps(_mm_mul_ps(_m, _w0), _sum0l);
                __m128 _w1 = _mm_loadu_ps(w1);
                _sum1l = _mm_add_ps(_mm_mul_ps(_m, _w1), _sum1l);
                __m128 _w2 = _mm_loadu_ps(w2);
                _sum2l = _mm_add_ps(_mm_mul_ps(_m, _w2), _sum2l);
                __m128 _w3 = _mm_loadu_ps(w3);
                _sum3l = _mm_add_ps(_mm_mul_ps(_m, _w3), _sum3l);

                m += 4;
                w0 += 4;
                w1 += 4;
                w2 += 4;
                w3 += 4;
            }
            for (; i < num_input; i++)
            {
                sums[0] += *m * *w0;
                sums[1] += *m * *w1;
                sums[2] += *m * *w2;
                sums[3] += *m * *w3;

                m++;
                w0++;
                w1++;
                w2++;
                w3++;
            }

            __m128 _sums = _mm_loadu_ps(sums);
#if __AVX__
            _sums = _mm_add_ps(HorizontalSums(_sum0, _sum1, _sum2, _sum3), _sums);
#endif
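            // transpose the four SSE accumulators so each lane collects the partial
            // sums of one output, then lane-wise adds complete the horizontal reduction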
            _MM_TRANSPOSE4_PS(_sum0l, _sum1l, _sum2l, _sum3l);
            _sums = _mm_add_ps(_sum0l, _sums);
            _sums = _mm_add_ps(_sum1l, _sums);
            _sums = _mm_add_ps(_sum2l, _sums);
            _sums = _mm_add_ps(_sum3l, _sums);
            _sums = activation_sse(_sums, activation_type, activation_params);

            float* outptr = top_blob;
            _mm_storeu_ps(outptr + p, _sums);
        }

        remain_num_output_start += (nn_num_output << 2);
#else
        int remain_num_output_start = 0;
#endif // __SSE2__

        // num_output
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = remain_num_output_start; p < num_output; p++)
        {
            float sum = 0.f;

            if (bias_term)
                sum = bias_data[p];

            const float* w = (const float*)weight_data + num_input * p;

            const float* m = bottom_blob_flattened;

            int i = 0;
#if __SSE2__
#if __AVX__
            __m256 _sum = _mm256_set1_ps(0.f);
            for (; i + 7 < num_input; i += 8)
            {
                __m256 _m = _mm256_loadu_ps(m);

                __m256 _w = _mm256_loadu_ps(w);
                _sum = _mm256_fmadd_ps(_m, _w, _sum);

                m += 8;
                w += 8;
            }
#endif // __AVX__
            __m128 _suml = _mm_set1_ps(0.f);
            for (; i + 3 < num_input; i += 4)
            {
                __m128 _m = _mm_loadu_ps(m);

                __m128 _w = _mm_loadu_ps(w);
                _suml = _mm_add_ps(_mm_mul_ps(_m, _w), _suml);

                m += 4;
                w += 4;
            }
#endif // __SSE2__
            for (; i < num_input; i++)
            {
                sum += *m * *w;
                m++;
                w++;
            }

#if __SSE2__
#if __AVX__
            sum += _mm256_reduce_add_ps(_sum);
#endif
            sum += _mm_reduce_add_ps(_suml);
#endif // __SSE2__

            sum = activation_ss(sum, activation_type, activation_params);

            float* outptr = top_blob;
            outptr[p] = sum;
        }
    }

    return 0;
}
#if __AVX__

int InnerProduct_x86::forward_fp16(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    // flatten
    Mat bottom_blob_flattened = bottom_blob;
    if (bottom_blob.dims != 1)
    {
        Option opt_flatten = opt;
        opt_flatten.blob_allocator = opt.workspace_allocator;

        flatten->forward(bottom_blob, bottom_blob_flattened, opt_flatten);
    }

    // pack1
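    // reinterpret the flattened blob as elempack 1 in place: the fp32 payload is
    // untouched, only the packing bookkeeping (w, cstep, elemsize) is rewritten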
    {
        bottom_blob_flattened.w *= bottom_blob_flattened.elempack;
        bottom_blob_flattened.cstep = bottom_blob_flattened.w;
        bottom_blob_flattened.elemsize = 4u;
        bottom_blob_flattened.elempack = 1;
    }

    int w = bottom_blob_flattened.w;
    int h = bottom_blob_flattened.h;
    size_t elemsize = bottom_blob_flattened.elemsize;
    int size = w * h;
    top_blob.create(num_output, elemsize, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    const unsigned short* weight_data_ptr = (const unsigned short*)weight_data_fp16;
    float* output_ptr = top_blob;
    int nn_num_output = num_output >> 3;
    int remain_num_output_start = nn_num_output << 3;

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int pp = 0; pp < nn_num_output; pp++)
    {
        int p = pp * 8;

        float sums[8] = {0.0f};
        if (bias_term)
        {
            sums[0] = bias_data[p];
            sums[1] = bias_data[p + 1];
            sums[2] = bias_data[p + 2];
            sums[3] = bias_data[p + 3];
            sums[4] = bias_data[p + 4];
            sums[5] = bias_data[p + 5];
            sums[6] = bias_data[p + 6];
            sums[7] = bias_data[p + 7];
        }
        __m256 _sum0 = _mm256_set1_ps(0.f);
        __m256 _sum1 = _mm256_set1_ps(0.f);
        __m256 _sum2 = _mm256_set1_ps(0.f);
        __m256 _sum3 = _mm256_set1_ps(0.f);
        __m256 _sum4 = _mm256_set1_ps(0.f);
        __m256 _sum5 = _mm256_set1_ps(0.f);
        __m256 _sum6 = _mm256_set1_ps(0.f);
        __m256 _sum7 = _mm256_set1_ps(0.f);

        const unsigned short* w0 = (const unsigned short*)weight_data_ptr + size * p;
        const unsigned short* w1 = (const unsigned short*)weight_data_ptr + size * (p + 1);
        const unsigned short* w2 = (const unsigned short*)weight_data_ptr + size * (p + 2);
        const unsigned short* w3 = (const unsigned short*)weight_data_ptr + size * (p + 3);
        const unsigned short* w4 = (const unsigned short*)weight_data_ptr + size * (p + 4);
        const unsigned short* w5 = (const unsigned short*)weight_data_ptr + size * (p + 5);
        const unsigned short* w6 = (const unsigned short*)weight_data_ptr + size * (p + 6);
        const unsigned short* w7 = (const unsigned short*)weight_data_ptr + size * (p + 7);

        const float* m = bottom_blob_flattened;
        int nn = size >> 3;
        int remain = size & 7;

        for (; nn > 0; nn--)
        {
            __m256 _m = _mm256_loadu_ps(m);

            __m256 _w0 = loadfp16(w0);
            _sum0 = _mm256_fmadd_ps(_m, _w0, _sum0);

            __m256 _w1 = loadfp16(w1);
            _sum1 = _mm256_fmadd_ps(_m, _w1, _sum1);

            __m256 _w2 = loadfp16(w2);
            _sum2 = _mm256_fmadd_ps(_m, _w2, _sum2);

            __m256 _w3 = loadfp16(w3);
            _sum3 = _mm256_fmadd_ps(_m, _w3, _sum3);

            __m256 _w4 = loadfp16(w4);
            _sum4 = _mm256_fmadd_ps(_m, _w4, _sum4);

            __m256 _w5 = loadfp16(w5);
            _sum5 = _mm256_fmadd_ps(_m, _w5, _sum5);

            __m256 _w6 = loadfp16(w6);
            _sum6 = _mm256_fmadd_ps(_m, _w6, _sum6);

            __m256 _w7 = loadfp16(w7);
            _sum7 = _mm256_fmadd_ps(_m, _w7, _sum7);

            m += 8;
            w0 += 8;
            w1 += 8;
            w2 += 8;
            w3 += 8;
            w4 += 8;
            w5 += 8;
            w6 += 8;
            w7 += 8;
        }
        if (remain != 0)
        {
            unsigned short fp16_weights[8][8] = {{0}};
            float _m_f[8] = {0};
            int i = 0;
            // No fast way to convert to fp32 one element at a time,
            // so batch an 8 lane vector.
            for (; remain > 0; remain--)
            {
                _m_f[i] = *m;
                fp16_weights[0][i] = *w0;
                fp16_weights[1][i] = *w1;
                fp16_weights[2][i] = *w2;
                fp16_weights[3][i] = *w3;
                fp16_weights[4][i] = *w4;
                fp16_weights[5][i] = *w5;
                fp16_weights[6][i] = *w6;
                fp16_weights[7][i] = *w7;
                i++;
                m++;
                w0++;
                w1++;
                w2++;
                w3++;
                w4++;
                w5++;
                w6++;
                w7++;
            }
            __m256 _m = _mm256_loadu_ps(_m_f);

            __m256 _w0 = loadfp16(fp16_weights[0]);
            _sum0 = _mm256_fmadd_ps(_m, _w0, _sum0);

            __m256 _w1 = loadfp16(fp16_weights[1]);
            _sum1 = _mm256_fmadd_ps(_m, _w1, _sum1);

            __m256 _w2 = loadfp16(fp16_weights[2]);
            _sum2 = _mm256_fmadd_ps(_m, _w2, _sum2);

            __m256 _w3 = loadfp16(fp16_weights[3]);
            _sum3 = _mm256_fmadd_ps(_m, _w3, _sum3);

            __m256 _w4 = loadfp16(fp16_weights[4]);
            _sum4 = _mm256_fmadd_ps(_m, _w4, _sum4);

            __m256 _w5 = loadfp16(fp16_weights[5]);
            _sum5 = _mm256_fmadd_ps(_m, _w5, _sum5);

            __m256 _w6 = loadfp16(fp16_weights[6]);
            _sum6 = _mm256_fmadd_ps(_m, _w6, _sum6);

            __m256 _w7 = loadfp16(fp16_weights[7]);
            _sum7 = _mm256_fmadd_ps(_m, _w7, _sum7);
        }

        __m256 _sums = HorizontalSums(_sum0, _sum1, _sum2, _sum3, _sum4, _sum5, _sum6, _sum7);
        __m256 _sums_f = _mm256_loadu_ps(sums);
        _sums = activation_avx(_mm256_add_ps(_sums_f, _sums), activation_type, activation_params);
        _mm256_storeu_ps(output_ptr + p, _sums);
    }

    nn_num_output = (num_output - remain_num_output_start) >> 2;
    int nn_offset = remain_num_output_start;
    remain_num_output_start += (nn_num_output << 2);

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int pp = 0; pp < nn_num_output; pp++)
    {
        int p = nn_offset + (pp * 4);

        float sums[4] = {0.0f};
        if (bias_term)
        {
            sums[0] = bias_data[p];
            sums[1] = bias_data[p + 1];
            sums[2] = bias_data[p + 2];
            sums[3] = bias_data[p + 3];
        }
        __m256 _sum0 = _mm256_set1_ps(0.f);
        __m256 _sum1 = _mm256_set1_ps(0.f);
        __m256 _sum2 = _mm256_set1_ps(0.f);
        __m256 _sum3 = _mm256_set1_ps(0.f);

        const unsigned short* w0 = (const unsigned short*)weight_data_ptr + size * p;
        const unsigned short* w1 = (const unsigned short*)weight_data_ptr + size * (p + 1);
        const unsigned short* w2 = (const unsigned short*)weight_data_ptr + size * (p + 2);
        const unsigned short* w3 = (const unsigned short*)weight_data_ptr + size * (p + 3);

        const float* m = bottom_blob_flattened;
        int nn = size >> 3;
        int remain = size & 7;

        for (; nn > 0; nn--)
        {
            __m256 _m = _mm256_loadu_ps(m);

            __m256 _w0 = loadfp16(w0);
            _sum0 = _mm256_fmadd_ps(_m, _w0, _sum0);

            __m256 _w1 = loadfp16(w1);
            _sum1 = _mm256_fmadd_ps(_m, _w1, _sum1);

            __m256 _w2 = loadfp16(w2);
            _sum2 = _mm256_fmadd_ps(_m, _w2, _sum2);

            __m256 _w3 = loadfp16(w3);
            _sum3 = _mm256_fmadd_ps(_m, _w3, _sum3);

            m += 8;
            w0 += 8;
            w1 += 8;
            w2 += 8;
            w3 += 8;
        }
        if (remain != 0)
        {
            unsigned short fp16_weights[4][8] = {{0}};
            float _m_f[8] = {0};
            int i = 0;
            for (; remain > 0; remain--)
            {
                _m_f[i] = *m;
                fp16_weights[0][i] = *w0;
                fp16_weights[1][i] = *w1;
                fp16_weights[2][i] = *w2;
                fp16_weights[3][i] = *w3;
                i++;
                m++;
                w0++;
                w1++;
                w2++;
                w3++;
            }
            __m256 _m = _mm256_loadu_ps(_m_f);

            __m256 _w0 = loadfp16(fp16_weights[0]);
            _sum0 = _mm256_fmadd_ps(_m, _w0, _sum0);

            __m256 _w1 = loadfp16(fp16_weights[1]);
            _sum1 = _mm256_fmadd_ps(_m, _w1, _sum1);

            __m256 _w2 = loadfp16(fp16_weights[2]);
            _sum2 = _mm256_fmadd_ps(_m, _w2, _sum2);

            __m256 _w3 = loadfp16(fp16_weights[3]);
            _sum3 = _mm256_fmadd_ps(_m, _w3, _sum3);
        }

        __m128 _sums = HorizontalSums(_sum0, _sum1, _sum2, _sum3);
        __m256 _sums_a = activation_avx(_mm256_castps128_ps256(_mm_add_ps(_mm_loadu_ps(sums), _sums)), activation_type, activation_params);
        _mm_storeu_ps(output_ptr + p, _mm256_castps256_ps128(_sums_a));
    }

    // num_output
    #pragma omp parallel for num_threads(opt.num_threads)
    for (int p = remain_num_output_start; p < num_output; p++)
    {
        float sum = 0.f;

        if (bias_term)
            sum = bias_data[p];

        const unsigned short* w = (const unsigned short*)weight_data_ptr + size * p;

        __m256 _sum = _mm256_set1_ps(0.f);

        const float* m = bottom_blob_flattened;

        int nn = size >> 3;
        int remain = size & 7;
        for (; nn > 0; nn--)
        {
            __m256 _m = _mm256_loadu_ps(m);

            __m256 _w = loadfp16(w);
            _sum = _mm256_fmadd_ps(_m, _w, _sum);

            m += 8;
            w += 8;
        }
        if (remain != 0)
        {
            unsigned short fp16_weights[8] = {0};
            float _m_f[8] = {0};
            int i = 0;
            for (; remain > 0; remain--)
            {
                _m_f[i] = *m;
                fp16_weights[i] = *w;
                i++;
                m++;
                w++;
            }
            __m256 _m = _mm256_loadu_ps(_m_f);

            __m256 _w = loadfp16(fp16_weights);
            _sum = _mm256_fmadd_ps(_m, _w, _sum);
        }

        sum += _mm256_reduce_add_ps(_sum);
        sum = activation_ss(sum, activation_type, activation_params);

        output_ptr[p] = sum;
    }
    return 0;
}
#endif // __AVX__

#if NCNN_INT8
int InnerProduct_x86::create_pipeline_int8_x86(const Option& opt)
{
    if (activation_type == 1)
    {
        activation = ncnn::create_layer(ncnn::LayerType::ReLU);

        ncnn::ParamDict pd;
        activation->load_param(pd);
    }
    else if (activation_type == 2)
    {
        activation = ncnn::create_layer(ncnn::LayerType::ReLU);

        ncnn::ParamDict pd;
        pd.set(0, activation_params[0]); // slope
        activation->load_param(pd);
    }
    else if (activation_type == 3)
    {
        activation = ncnn::create_layer(ncnn::LayerType::Clip);

        ncnn::ParamDict pd;
        pd.set(0, activation_params[0]); // min
        pd.set(1, activation_params[1]); // max
        activation->load_param(pd);
    }
    else if (activation_type == 4)
    {
        activation = ncnn::create_layer(ncnn::LayerType::Sigmoid);

        ncnn::ParamDict pd;
        activation->load_param(pd);
    }
    else if (activation_type == 5)
    {
        activation = ncnn::create_layer(ncnn::LayerType::Mish);

        ncnn::ParamDict pd;
        activation->load_param(pd);
    }

    if (activation)
    {
        activation->create_pipeline(opt);
    }

    const int num_input = weight_data_size / num_output;

    int out_elempack = 1;
#if __SSE2__
    if (opt.use_packing_layout)
    {
        out_elempack = num_output % 8 == 0 ? 8 : 1;
    }
#endif // __SSE2__

    // src = inch-outch
    // dst = pb-inch-outch/pb
    {
        Mat weight_data_r2 = weight_data.reshape(num_input, num_output);

        weight_data_int8.create(num_input, num_output / out_elempack, (size_t)out_elempack, out_elempack);

        for (int q = 0; q + (out_elempack - 1) < num_output; q += out_elempack)
        {
            signed char* g0 = weight_data_int8.row<signed char>(q / out_elempack);

            for (int p = 0; p < num_input; p++)
            {
                for (int j = 0; j < out_elempack; j++)
                {
                    *g0++ = weight_data_r2.row<signed char>(q + j)[p];
                }
            }
        }
    }

    return 0;
}

int InnerProduct_x86::forward_int8_x86(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    const int num_input = weight_data_size / num_output;

    if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
    {
        // gemm
        Mat bottom_blob_unpacked;
        Option opt_unpack = opt;
        opt_unpack.blob_allocator = opt.workspace_allocator;
        convert_packing(bottom_blob, bottom_blob_unpacked, 1, opt_unpack);

        return forward_int8(bottom_blob_unpacked, top_blob, opt);
    }

    int elembits = bottom_blob.elembits();

    Mat bottom_blob_int8 = bottom_blob;
    if (elembits != 8)
    {
        Option opt_q = opt;
        opt_q.blob_allocator = opt.workspace_allocator;
        quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt_q);
    }

    Mat bottom_blob_int8_flattened = bottom_blob_int8;
    if (bottom_blob_int8.dims != 1)
    {
        Option opt_flatten = opt;
        opt_flatten.blob_allocator = opt.workspace_allocator;
        flatten->forward(bottom_blob_int8, bottom_blob_int8_flattened, opt_flatten);
    }

    // int elempack = bottom_blob_int8_flattened.elempack;

    int out_elempack = 1;
#if __SSE2__
    if (opt.use_packing_layout)
    {
        out_elempack = num_output % 8 == 0 ? 8 : 1;
    }
#endif // __SSE2__
    // size_t out_elemsize = elemsize / elempack * out_elempack;

    top_blob.create(num_output / out_elempack, (size_t)(4u * out_elempack), out_elempack, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    Mat top_blob_int32;
    top_blob_int32.create(num_output / out_elempack, (size_t)(4u * out_elempack), out_elempack, opt.workspace_allocator);
    if (top_blob_int32.empty())
        return -100;

#if __SSE2__
    if (out_elempack == 8)
    {
        // num_output
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < num_output / out_elempack; p++)
        {
            __m128i _sum0 = _mm_setzero_si128();
            __m128i _sum1 = _mm_setzero_si128();

            const signed char* kptr = weight_data_int8.row<const signed char>(p);
            const signed char* sptr = bottom_blob_int8_flattened;

            int i = 0;
            for (; i < num_input; i++)
            {
                __m128i _val = _mm_set1_epi16((short)sptr[0]);

                // TODO use _mm_cvtepi8_epi16 on sse4.1
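                // 0 > w is 0xff exactly in the negative lanes, providing the sign
                // byte that widens the eight int8 weights to int16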
                __m128i _w = _mm_loadl_epi64((const __m128i*)kptr);
                _w = _mm_unpacklo_epi8(_w, _mm_cmpgt_epi8(_mm_setzero_si128(), _w));

                __m128i _sl = _mm_mullo_epi16(_val, _w);
                __m128i _sh = _mm_mulhi_epi16(_val, _w);
                __m128i _s0 = _mm_unpacklo_epi16(_sl, _sh);
                __m128i _s1 = _mm_unpackhi_epi16(_sl, _sh);

                _sum0 = _mm_add_epi32(_sum0, _s0);
                _sum1 = _mm_add_epi32(_sum1, _s1);

                sptr += 1;
                kptr += 8;
            }

            int* outptr = (int*)top_blob_int32;
            _mm_storeu_si128((__m128i*)(outptr + p * 8), _sum0);
            _mm_storeu_si128((__m128i*)(outptr + p * 8 + 4), _sum1);
        }
    }
#endif // __SSE2__

    if (out_elempack == 1)
    {
        // num_output
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < num_output / out_elempack; p++)
        {
            int sum = 0;

            const signed char* kptr = weight_data_int8.row<const signed char>(p);
            const signed char* sptr = bottom_blob_int8_flattened;

            int i = 0;
            for (; i < num_input; i++)
            {
                signed char val = sptr[0];

                signed char w = kptr[0];

                sum += val * w;

                sptr += 1;
                kptr += 1;
            }

            int* outptr = (int*)top_blob_int32;
            outptr[p] = sum;
        }
    }

    Mat scale_data(num_output);
    for (int p = 0; p < num_output; p++)
    {
        // dequantize
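        // int32 accumulators return to fp32 as sum * scale_in with
        // scale_in = 1 / (input_scale * weight_scale); bias is added inside
        // dequantize_from_int32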
        float scale_in;
        if (weight_data_int8_scales[p] == 0)
            scale_in = 0;
        else
            scale_in = 1.f / (bottom_blob_int8_scales[0] * weight_data_int8_scales[p]);

        scale_data[p] = scale_in;
    }

    dequantize_from_int32(top_blob_int32, top_blob, scale_data, bias_data, opt);

    if (activation)
    {
        activation->forward_inplace(top_blob, opt);
    }

    return 0;
}
#endif // NCNN_INT8

} // namespace ncnn