1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
15 #include "innerproduct_x86.h"
16 
17 #if __SSE2__
18 #include <emmintrin.h>
19 #if __AVX__
20 #include <immintrin.h>
21 #endif
22 #endif // __SSE2__
23 
24 #include "x86_activation.h"
25 #include "x86_usability.h"
26 
27 #include "layer_type.h"
28 
29 namespace ncnn {
30 
InnerProduct_x86()31 InnerProduct_x86::InnerProduct_x86()
32 {
33 #if __SSE2__
34     support_packing = true;
35 #if __AVX__
36     support_weight_fp16_storage = true;
37 #endif
38 #endif // __SSE2__
39 
40     flatten = 0;
41     activation = 0;
42 }
43 
create_pipeline(const Option & opt)44 int InnerProduct_x86::create_pipeline(const Option& opt)
45 {
46     //     if (opt.use_packing_layout)
47     {
48         flatten = ncnn::create_layer(ncnn::LayerType::Flatten);
49 
50         ncnn::ParamDict pd;
51 
52         flatten->load_param(pd);
53 
54         flatten->create_pipeline(opt);
55     }
56 
57 #if NCNN_INT8
58     if (opt.use_int8_inference && weight_data.elemsize == (size_t)1u)
59     {
60         return create_pipeline_int8_x86(opt);
61     }
62 #endif
63 
64     const int num_input = weight_data_size / num_output;
65 
66     int out_elempack = 1;
67 
68 #if __SSE2__
69     if (opt.use_packing_layout)
70     {
71 #if __AVX__
72         out_elempack = num_output % 8 == 0 ? 8 : num_output % 4 == 0 ? 4 : 1;
73 #else
74         out_elempack = num_output % 4 == 0 ? 4 : 1;
75 #endif
76     }
77 #endif // __SSE2__
78 
79     if (out_elempack != 1)
80     {
81         // src = inch-outch
82         // dst = pb-inch-outch/pb
83         {
84             Mat weight_data_r2 = weight_data.reshape(num_input, num_output);
85 
86             weight_data_packed.create(num_input, num_output / out_elempack, (size_t)4u * out_elempack, out_elempack);
87 
88             for (int q = 0; q + (out_elempack - 1) < num_output; q += out_elempack)
89             {
90                 float* g0 = weight_data_packed.row(q / out_elempack);
91 
92                 for (int p = 0; p < num_input; p++)
93                 {
94                     for (int j = 0; j < out_elempack; j++)
95                     {
96                         *g0++ = weight_data_r2.row(q + j)[p];
97                     }
98                 }
99             }
100         }
101     }
102 
103 #if __AVX__
104     if (opt.use_weight_fp16_storage && weight_data.elemsize == 4u)
105     {
106         ncnn::cast_float32_to_float16(weight_data, weight_data_fp16, opt);
107 
108         return 0;
109     }
110 #endif
111 
112     return 0;
113 }
114 
destroy_pipeline(const Option & opt)115 int InnerProduct_x86::destroy_pipeline(const Option& opt)
116 {
117     if (flatten)
118     {
119         flatten->destroy_pipeline(opt);
120         delete flatten;
121         flatten = 0;
122     }
123 
124     return 0;
125 }
126 
forward(const Mat & bottom_blob,Mat & top_blob,const Option & opt) const127 int InnerProduct_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
128 {
129 #if NCNN_INT8
130     if (opt.use_int8_inference && weight_data.elemsize == (size_t)1u)
131     {
132         return forward_int8_x86(bottom_blob, top_blob, opt);
133     }
134 #endif
135 
136     const int num_input = weight_data_size / num_output;
137 
138     if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
139     {
140         // gemm
141         int h = bottom_blob.h;
142         size_t elemsize = bottom_blob.elemsize;
143         int elempack = bottom_blob.elempack;
144 
145         top_blob.create(num_output, h, elemsize, elempack, opt.blob_allocator);
146         if (top_blob.empty())
147             return -100;
148 
149         int num_output_elempack = 1;
150 #if __SSE2__
151         if (opt.use_packing_layout)
152         {
153 #if __AVX__
154             num_output_elempack = num_output % 8 == 0 ? 8 : num_output % 4 == 0 ? 4 : 1;
155 #else
156             num_output_elempack = num_output % 4 == 0 ? 4 : 1;
157 #endif
158         }
159 #endif // __SSE2__
160 
161         #pragma omp parallel for num_threads(opt.num_threads)
162         for (int j = 0; j < h; j++)
163         {
164 #if __SSE2__
165 #if __AVX__
166             if (elempack == 8 && num_output_elempack == 8)
167             {
168                 float* outptr = top_blob.row(j);
169 
170                 for (int p = 0; p < num_output / num_output_elempack; p++)
171                 {
172                     const float* kptr = (const float*)weight_data_packed + num_input * p * 8;
173                     const float* m = bottom_blob.row(j);
174 
175                     __m256 _sum0 = _mm256_set1_ps(0.f);
176                     __m256 _sum1 = _mm256_set1_ps(0.f);
177                     __m256 _sum2 = _mm256_set1_ps(0.f);
178                     __m256 _sum3 = _mm256_set1_ps(0.f);
179                     __m256 _sum4 = _mm256_set1_ps(0.f);
180                     __m256 _sum5 = _mm256_set1_ps(0.f);
181                     __m256 _sum6 = _mm256_set1_ps(0.f);
182                     __m256 _sum7 = _mm256_set1_ps(0.f);
183 
184                     if (bias_term)
185                     {
186                         _sum0 = _mm256_set1_ps(bias_data[p * 8 + 0]);
187                         _sum1 = _mm256_set1_ps(bias_data[p * 8 + 1]);
188                         _sum2 = _mm256_set1_ps(bias_data[p * 8 + 2]);
189                         _sum3 = _mm256_set1_ps(bias_data[p * 8 + 3]);
190                         _sum4 = _mm256_set1_ps(bias_data[p * 8 + 4]);
191                         _sum5 = _mm256_set1_ps(bias_data[p * 8 + 5]);
192                         _sum6 = _mm256_set1_ps(bias_data[p * 8 + 6]);
193                         _sum7 = _mm256_set1_ps(bias_data[p * 8 + 7]);
194                     }
195 
196                     for (int i = 0; i < num_input; i++)
197                     {
198                         __m256 _val = _mm256_loadu_ps(m);
199                         __m256 _k0 = _mm256_set1_ps(kptr[0]);
200                         __m256 _k1 = _mm256_set1_ps(kptr[1]);
201                         __m256 _k2 = _mm256_set1_ps(kptr[2]);
202                         __m256 _k3 = _mm256_set1_ps(kptr[3]);
203                         __m256 _k4 = _mm256_set1_ps(kptr[4]);
204                         __m256 _k5 = _mm256_set1_ps(kptr[5]);
205                         __m256 _k6 = _mm256_set1_ps(kptr[6]);
206                         __m256 _k7 = _mm256_set1_ps(kptr[7]);
207                         _sum0 = _mm256_fmadd_ps(_val, _k0, _sum0);
208                         _sum1 = _mm256_fmadd_ps(_val, _k1, _sum1);
209                         _sum2 = _mm256_fmadd_ps(_val, _k2, _sum2);
210                         _sum3 = _mm256_fmadd_ps(_val, _k3, _sum3);
211                         _sum4 = _mm256_fmadd_ps(_val, _k4, _sum4);
212                         _sum5 = _mm256_fmadd_ps(_val, _k5, _sum5);
213                         _sum6 = _mm256_fmadd_ps(_val, _k6, _sum6);
214                         _sum7 = _mm256_fmadd_ps(_val, _k7, _sum7);
215 
216                         m += 8;
217                         kptr += 8;
218                     }
219 
220                     _sum0 = activation_avx(_sum0, activation_type, activation_params);
221                     _sum1 = activation_avx(_sum1, activation_type, activation_params);
222                     _sum2 = activation_avx(_sum2, activation_type, activation_params);
223                     _sum3 = activation_avx(_sum3, activation_type, activation_params);
224                     _sum4 = activation_avx(_sum4, activation_type, activation_params);
225                     _sum5 = activation_avx(_sum5, activation_type, activation_params);
226                     _sum6 = activation_avx(_sum6, activation_type, activation_params);
227                     _sum7 = activation_avx(_sum7, activation_type, activation_params);
228 
229                     _mm256_storeu_ps(outptr, _sum0);
230                     _mm256_storeu_ps(outptr + 8, _sum1);
231                     _mm256_storeu_ps(outptr + 16, _sum2);
232                     _mm256_storeu_ps(outptr + 24, _sum3);
233                     _mm256_storeu_ps(outptr + 32, _sum4);
234                     _mm256_storeu_ps(outptr + 40, _sum5);
235                     _mm256_storeu_ps(outptr + 48, _sum6);
236                     _mm256_storeu_ps(outptr + 56, _sum7);
237                     outptr += 64;
238                 }
239             }
240 
241             if (elempack == 1 && num_output_elempack == 8)
242             {
243                 float* outptr = top_blob.row(j);
244 
245                 for (int p = 0; p < num_output / num_output_elempack; p++)
246                 {
247                     const float* kptr = (const float*)weight_data_packed + num_input * p * 8;
248                     const float* m = bottom_blob.row(j);
249 
250                     __m256 _sum = _mm256_set1_ps(0.f);
251 
252                     if (bias_term)
253                     {
254                         _sum = _mm256_loadu_ps((const float*)bias_data + p * 8);
255                     }
256 
257                     int i = 0;
258                     for (; i + 7 < num_input; i += 8)
259                     {
260                         __m256 _val0 = _mm256_broadcast_ss(m);
261                         __m256 _val1 = _mm256_broadcast_ss(m + 1);
262                         __m256 _val2 = _mm256_broadcast_ss(m + 2);
263                         __m256 _val3 = _mm256_broadcast_ss(m + 3);
264                         __m256 _val4 = _mm256_broadcast_ss(m + 4);
265                         __m256 _val5 = _mm256_broadcast_ss(m + 5);
266                         __m256 _val6 = _mm256_broadcast_ss(m + 6);
267                         __m256 _val7 = _mm256_broadcast_ss(m + 7);
268 
269                         __m256 _w0 = _mm256_loadu_ps(kptr);
270                         _sum = _mm256_fmadd_ps(_val0, _w0, _sum);
271                         __m256 _w1 = _mm256_loadu_ps(kptr + 8);
272                         _sum = _mm256_fmadd_ps(_val1, _w1, _sum);
273                         __m256 _w2 = _mm256_loadu_ps(kptr + 16);
274                         _sum = _mm256_fmadd_ps(_val2, _w2, _sum);
275                         __m256 _w3 = _mm256_loadu_ps(kptr + 24);
276                         _sum = _mm256_fmadd_ps(_val3, _w3, _sum);
277                         __m256 _w4 = _mm256_loadu_ps(kptr + 32);
278                         _sum = _mm256_fmadd_ps(_val4, _w4, _sum);
279                         __m256 _w5 = _mm256_loadu_ps(kptr + 40);
280                         _sum = _mm256_fmadd_ps(_val5, _w5, _sum);
281                         __m256 _w6 = _mm256_loadu_ps(kptr + 48);
282                         _sum = _mm256_fmadd_ps(_val6, _w6, _sum);
283                         __m256 _w7 = _mm256_loadu_ps(kptr + 56);
284                         _sum = _mm256_fmadd_ps(_val7, _w7, _sum);
285 
286                         m += 8;
287                         kptr += 64;
288                     }
289                     for (; i + 3 < num_input; i += 4)
290                     {
291                         __m256 _val0 = _mm256_broadcast_ss(m);
292                         __m256 _val1 = _mm256_broadcast_ss(m + 1);
293                         __m256 _val2 = _mm256_broadcast_ss(m + 2);
294                         __m256 _val3 = _mm256_broadcast_ss(m + 3);
295 
296                         __m256 _w0 = _mm256_loadu_ps(kptr);
297                         _sum = _mm256_fmadd_ps(_val0, _w0, _sum);
298                         __m256 _w1 = _mm256_loadu_ps(kptr + 8);
299                         _sum = _mm256_fmadd_ps(_val1, _w1, _sum);
300                         __m256 _w2 = _mm256_loadu_ps(kptr + 16);
301                         _sum = _mm256_fmadd_ps(_val2, _w2, _sum);
302                         __m256 _w3 = _mm256_loadu_ps(kptr + 24);
303                         _sum = _mm256_fmadd_ps(_val3, _w3, _sum);
304 
305                         m += 4;
306                         kptr += 32;
307                     }
308                     for (; i < num_input; i++)
309                     {
310                         __m256 _val = _mm256_set1_ps(m[0]);
311                         __m256 _w = _mm256_loadu_ps(kptr);
312                         _sum = _mm256_fmadd_ps(_val, _w, _sum);
313 
314                         m += 1;
315                         kptr += 8;
316                     }
317 
318                     _sum = activation_avx(_sum, activation_type, activation_params);
319 
320                     _mm256_storeu_ps(outptr, _sum);
321                     outptr += 8;
322                 }
323             }
324 
325             if (elempack == 4 && num_output_elempack == 8)
326             {
327                 float* outptr = top_blob.row(j);
328 
329                 for (int p = 0; p < num_output / num_output_elempack; p++)
330                 {
331                     const float* kptr = (const float*)weight_data_packed + num_input * p * 8;
332                     const float* m = bottom_blob.row(j);
333 
334                     __m128 _sum0 = _mm_set1_ps(0.f);
335                     __m128 _sum1 = _mm_set1_ps(0.f);
336                     __m128 _sum2 = _mm_set1_ps(0.f);
337                     __m128 _sum3 = _mm_set1_ps(0.f);
338                     __m128 _sum4 = _mm_set1_ps(0.f);
339                     __m128 _sum5 = _mm_set1_ps(0.f);
340                     __m128 _sum6 = _mm_set1_ps(0.f);
341                     __m128 _sum7 = _mm_set1_ps(0.f);
342 
343                     if (bias_term)
344                     {
345                         _sum0 = _mm_set1_ps(bias_data[p * 8 + 0]);
346                         _sum1 = _mm_set1_ps(bias_data[p * 8 + 1]);
347                         _sum2 = _mm_set1_ps(bias_data[p * 8 + 2]);
348                         _sum3 = _mm_set1_ps(bias_data[p * 8 + 3]);
349                         _sum4 = _mm_set1_ps(bias_data[p * 8 + 4]);
350                         _sum5 = _mm_set1_ps(bias_data[p * 8 + 5]);
351                         _sum6 = _mm_set1_ps(bias_data[p * 8 + 6]);
352                         _sum7 = _mm_set1_ps(bias_data[p * 8 + 7]);
353                     }
354 
355                     int i = 0;
356                     for (; i < num_input; i++)
357                     {
358                         __m128 _val = _mm_loadu_ps(m);
359                         _sum0 = _mm_fmadd_ps(_val, _mm_set1_ps(kptr[0]), _sum0);
360                         _sum1 = _mm_fmadd_ps(_val, _mm_set1_ps(kptr[1]), _sum1);
361                         _sum2 = _mm_fmadd_ps(_val, _mm_set1_ps(kptr[2]), _sum2);
362                         _sum3 = _mm_fmadd_ps(_val, _mm_set1_ps(kptr[3]), _sum3);
363                         _sum4 = _mm_fmadd_ps(_val, _mm_set1_ps(kptr[4]), _sum4);
364                         _sum5 = _mm_fmadd_ps(_val, _mm_set1_ps(kptr[5]), _sum5);
365                         _sum6 = _mm_fmadd_ps(_val, _mm_set1_ps(kptr[6]), _sum6);
366                         _sum7 = _mm_fmadd_ps(_val, _mm_set1_ps(kptr[7]), _sum7);
367 
368                         m += 4;
369                         kptr += 8;
370                     }
371 
372                     _sum0 = activation_sse(_sum0, activation_type, activation_params);
373                     _sum1 = activation_sse(_sum1, activation_type, activation_params);
374                     _sum2 = activation_sse(_sum2, activation_type, activation_params);
375                     _sum3 = activation_sse(_sum3, activation_type, activation_params);
376                     _sum4 = activation_sse(_sum4, activation_type, activation_params);
377                     _sum5 = activation_sse(_sum5, activation_type, activation_params);
378                     _sum6 = activation_sse(_sum6, activation_type, activation_params);
379                     _sum7 = activation_sse(_sum7, activation_type, activation_params);
380 
381                     _mm_storeu_ps(outptr, _sum0);
382                     _mm_storeu_ps(outptr + 4, _sum1);
383                     _mm_storeu_ps(outptr + 8, _sum2);
384                     _mm_storeu_ps(outptr + 12, _sum3);
385                     _mm_storeu_ps(outptr + 16, _sum4);
386                     _mm_storeu_ps(outptr + 20, _sum5);
387                     _mm_storeu_ps(outptr + 24, _sum6);
388                     _mm_storeu_ps(outptr + 28, _sum7);
389                     outptr += 32;
390                 }
391             }
392 
393             if (elempack == 8 && num_output_elempack == 1)
394             {
395                 float* outptr = top_blob.row(j);
396 
397                 for (int p = 0; p < num_output; p++)
398                 {
399                     const float* kptr = (const float*)weight_data + num_input * p;
400                     const float* m = bottom_blob.row(j);
401 
402                     __m256 _sum0 = _mm256_set1_ps(0.f);
403                     __m256 _sum1 = _mm256_set1_ps(0.f);
404                     __m256 _sum2 = _mm256_set1_ps(0.f);
405                     __m256 _sum3 = _mm256_set1_ps(0.f);
406 
407                     if (bias_term)
408                     {
409                         _sum0 = _mm256_set1_ps(bias_data[p]);
410                     }
411 
412                     int i = 0;
413                     for (; i + 7 < num_input; i += 8)
414                     {
415                         __m256 _val0 = _mm256_loadu_ps(m);
416                         __m256 _val1 = _mm256_loadu_ps(m + 8);
417                         __m256 _val2 = _mm256_loadu_ps(m + 16);
418                         __m256 _val3 = _mm256_loadu_ps(m + 24);
419                         __m256 _val4 = _mm256_loadu_ps(m + 32);
420                         __m256 _val5 = _mm256_loadu_ps(m + 40);
421                         __m256 _val6 = _mm256_loadu_ps(m + 48);
422                         __m256 _val7 = _mm256_loadu_ps(m + 56);
423                         _sum0 = _mm256_fmadd_ps(_val0, _mm256_set1_ps(kptr[0]), _sum0);
424                         _sum1 = _mm256_fmadd_ps(_val1, _mm256_set1_ps(kptr[1]), _sum1);
425                         _sum2 = _mm256_fmadd_ps(_val2, _mm256_set1_ps(kptr[2]), _sum2);
426                         _sum3 = _mm256_fmadd_ps(_val3, _mm256_set1_ps(kptr[3]), _sum3);
427                         _sum0 = _mm256_fmadd_ps(_val4, _mm256_set1_ps(kptr[4]), _sum0);
428                         _sum1 = _mm256_fmadd_ps(_val5, _mm256_set1_ps(kptr[5]), _sum1);
429                         _sum2 = _mm256_fmadd_ps(_val6, _mm256_set1_ps(kptr[6]), _sum2);
430                         _sum3 = _mm256_fmadd_ps(_val7, _mm256_set1_ps(kptr[7]), _sum3);
431 
432                         m += 64;
433                         kptr += 8;
434                     }
435                     for (; i + 3 < num_input; i += 4)
436                     {
437                         __m256 _val0 = _mm256_loadu_ps(m);
438                         __m256 _val1 = _mm256_loadu_ps(m + 8);
439                         __m256 _val2 = _mm256_loadu_ps(m + 16);
440                         __m256 _val3 = _mm256_loadu_ps(m + 24);
441                         _sum0 = _mm256_fmadd_ps(_val0, _mm256_set1_ps(kptr[0]), _sum0);
442                         _sum1 = _mm256_fmadd_ps(_val1, _mm256_set1_ps(kptr[1]), _sum1);
443                         _sum2 = _mm256_fmadd_ps(_val2, _mm256_set1_ps(kptr[2]), _sum2);
444                         _sum3 = _mm256_fmadd_ps(_val3, _mm256_set1_ps(kptr[3]), _sum3);
445 
446                         m += 32;
447                         kptr += 4;
448                     }
449                     for (; i < num_input; i++)
450                     {
451                         __m256 _val = _mm256_loadu_ps(m);
452                         __m256 _k = _mm256_set1_ps(kptr[0]);
453                         _sum0 = _mm256_fmadd_ps(_val, _k, _sum0);
454 
455                         m += 8;
456                         kptr += 1;
457                     }
458 
459                     _sum0 = _mm256_add_ps(_sum0, _sum1);
460                     _sum2 = _mm256_add_ps(_sum2, _sum3);
461                     _sum0 = _mm256_add_ps(_sum0, _sum2);
462 
463                     _sum0 = activation_avx(_sum0, activation_type, activation_params);
464 
465                     _mm256_storeu_ps(outptr, _sum0);
466                     outptr += 8;
467                 }
468             }
469 
470             if (elempack == 8 && num_output_elempack == 4)
471             {
472                 float* outptr = top_blob.row(j);
473 
474                 for (int p = 0; p < num_output / num_output_elempack; p++)
475                 {
476                     const float* kptr = (const float*)weight_data_packed + num_input * p * 4;
477                     const float* m = bottom_blob.row(j);
478 
479                     __m256 _sum0 = _mm256_set1_ps(0.f);
480                     __m256 _sum1 = _mm256_set1_ps(0.f);
481                     __m256 _sum2 = _mm256_set1_ps(0.f);
482                     __m256 _sum3 = _mm256_set1_ps(0.f);
483 
484                     if (bias_term)
485                     {
486                         _sum0 = _mm256_set1_ps(bias_data[p * 4 + 0]);
487                         _sum1 = _mm256_set1_ps(bias_data[p * 4 + 1]);
488                         _sum2 = _mm256_set1_ps(bias_data[p * 4 + 2]);
489                         _sum3 = _mm256_set1_ps(bias_data[p * 4 + 3]);
490                     }
491 
492                     int i = 0;
493                     for (; i + 3 < num_input; i += 4)
494                     {
495                         __m256 _val0 = _mm256_loadu_ps(m);
496                         __m256 _val1 = _mm256_loadu_ps(m + 8);
497                         __m256 _val2 = _mm256_loadu_ps(m + 16);
498                         __m256 _val3 = _mm256_loadu_ps(m + 24);
499                         _sum0 = _mm256_fmadd_ps(_val0, _mm256_set1_ps(kptr[0]), _sum0);
500                         _sum1 = _mm256_fmadd_ps(_val0, _mm256_set1_ps(kptr[1]), _sum1);
501                         _sum2 = _mm256_fmadd_ps(_val0, _mm256_set1_ps(kptr[2]), _sum2);
502                         _sum3 = _mm256_fmadd_ps(_val0, _mm256_set1_ps(kptr[3]), _sum3);
503                         _sum0 = _mm256_fmadd_ps(_val1, _mm256_set1_ps(kptr[4]), _sum0);
504                         _sum1 = _mm256_fmadd_ps(_val1, _mm256_set1_ps(kptr[5]), _sum1);
505                         _sum2 = _mm256_fmadd_ps(_val1, _mm256_set1_ps(kptr[6]), _sum2);
506                         _sum3 = _mm256_fmadd_ps(_val1, _mm256_set1_ps(kptr[7]), _sum3);
507                         kptr += 8;
508 
509                         _sum0 = _mm256_fmadd_ps(_val2, _mm256_set1_ps(kptr[0]), _sum0);
510                         _sum1 = _mm256_fmadd_ps(_val2, _mm256_set1_ps(kptr[1]), _sum1);
511                         _sum2 = _mm256_fmadd_ps(_val2, _mm256_set1_ps(kptr[2]), _sum2);
512                         _sum3 = _mm256_fmadd_ps(_val2, _mm256_set1_ps(kptr[3]), _sum3);
513                         _sum0 = _mm256_fmadd_ps(_val3, _mm256_set1_ps(kptr[4]), _sum0);
514                         _sum1 = _mm256_fmadd_ps(_val3, _mm256_set1_ps(kptr[5]), _sum1);
515                         _sum2 = _mm256_fmadd_ps(_val3, _mm256_set1_ps(kptr[6]), _sum2);
516                         _sum3 = _mm256_fmadd_ps(_val3, _mm256_set1_ps(kptr[7]), _sum3);
517 
518                         m += 32;
519                         kptr += 8;
520                     }
521                     for (; i < num_input; i++)
522                     {
523                         __m256 _val = _mm256_loadu_ps(m);
524                         _sum0 = _mm256_fmadd_ps(_val, _mm256_set1_ps(kptr[0]), _sum0);
525                         _sum1 = _mm256_fmadd_ps(_val, _mm256_set1_ps(kptr[1]), _sum1);
526                         _sum2 = _mm256_fmadd_ps(_val, _mm256_set1_ps(kptr[2]), _sum2);
527                         _sum3 = _mm256_fmadd_ps(_val, _mm256_set1_ps(kptr[3]), _sum3);
528 
529                         m += 8;
530                         kptr += 4;
531                     }
532 
533                     _sum0 = activation_avx(_sum0, activation_type, activation_params);
534                     _sum1 = activation_avx(_sum1, activation_type, activation_params);
535                     _sum2 = activation_avx(_sum2, activation_type, activation_params);
536                     _sum3 = activation_avx(_sum3, activation_type, activation_params);
537 
538                     _mm256_storeu_ps(outptr, _sum0);
539                     _mm256_storeu_ps(outptr + 8, _sum1);
540                     _mm256_storeu_ps(outptr + 16, _sum2);
541                     _mm256_storeu_ps(outptr + 24, _sum3);
542                     outptr += 32;
543                 }
544             }
545 #endif // __AVX__
546 
547             if (elempack == 4 && num_output_elempack == 4)
548             {
549                 float* outptr = top_blob.row(j);
550 
551                 for (int p = 0; p < num_output / num_output_elempack; p++)
552                 {
553                     const float* kptr = (const float*)weight_data_packed + num_input * p * 4;
554                     const float* m = bottom_blob.row(j);
555 
556                     __m128 _sum0 = _mm_set1_ps(0.f);
557                     __m128 _sum1 = _mm_set1_ps(0.f);
558                     __m128 _sum2 = _mm_set1_ps(0.f);
559                     __m128 _sum3 = _mm_set1_ps(0.f);
560 
561                     if (bias_term)
562                     {
563                         _sum0 = _mm_set1_ps(bias_data[p * 4 + 0]);
564                         _sum1 = _mm_set1_ps(bias_data[p * 4 + 1]);
565                         _sum2 = _mm_set1_ps(bias_data[p * 4 + 2]);
566                         _sum3 = _mm_set1_ps(bias_data[p * 4 + 3]);
567                     }
568 
569                     int i = 0;
570                     for (; i + 3 < num_input; i += 4)
571                     {
572                         __m128 _val0 = _mm_loadu_ps(m);
573                         __m128 _val1 = _mm_loadu_ps(m + 4);
574                         __m128 _val2 = _mm_loadu_ps(m + 8);
575                         __m128 _val3 = _mm_loadu_ps(m + 12);
576                         _sum0 = _mm_add_ps(_mm_mul_ps(_val0, _mm_set1_ps(kptr[0])), _sum0);
577                         _sum1 = _mm_add_ps(_mm_mul_ps(_val0, _mm_set1_ps(kptr[1])), _sum1);
578                         _sum2 = _mm_add_ps(_mm_mul_ps(_val0, _mm_set1_ps(kptr[2])), _sum2);
579                         _sum3 = _mm_add_ps(_mm_mul_ps(_val0, _mm_set1_ps(kptr[3])), _sum3);
580                         _sum0 = _mm_add_ps(_mm_mul_ps(_val1, _mm_set1_ps(kptr[4])), _sum0);
581                         _sum1 = _mm_add_ps(_mm_mul_ps(_val1, _mm_set1_ps(kptr[5])), _sum1);
582                         _sum2 = _mm_add_ps(_mm_mul_ps(_val1, _mm_set1_ps(kptr[6])), _sum2);
583                         _sum3 = _mm_add_ps(_mm_mul_ps(_val1, _mm_set1_ps(kptr[7])), _sum3);
584                         _sum0 = _mm_add_ps(_mm_mul_ps(_val2, _mm_set1_ps(kptr[8])), _sum0);
585                         _sum1 = _mm_add_ps(_mm_mul_ps(_val2, _mm_set1_ps(kptr[9])), _sum1);
586                         _sum2 = _mm_add_ps(_mm_mul_ps(_val2, _mm_set1_ps(kptr[10])), _sum2);
587                         _sum3 = _mm_add_ps(_mm_mul_ps(_val2, _mm_set1_ps(kptr[11])), _sum3);
588                         _sum0 = _mm_add_ps(_mm_mul_ps(_val3, _mm_set1_ps(kptr[12])), _sum0);
589                         _sum1 = _mm_add_ps(_mm_mul_ps(_val3, _mm_set1_ps(kptr[13])), _sum1);
590                         _sum2 = _mm_add_ps(_mm_mul_ps(_val3, _mm_set1_ps(kptr[14])), _sum2);
591                         _sum3 = _mm_add_ps(_mm_mul_ps(_val3, _mm_set1_ps(kptr[15])), _sum3);
592 
593                         m += 16;
594                         kptr += 16;
595                     }
596                     for (; i < num_input; i++)
597                     {
598                         __m128 _val = _mm_loadu_ps(m);
599                         _sum0 = _mm_add_ps(_mm_mul_ps(_val, _mm_set1_ps(kptr[0])), _sum0);
600                         _sum1 = _mm_add_ps(_mm_mul_ps(_val, _mm_set1_ps(kptr[1])), _sum1);
601                         _sum2 = _mm_add_ps(_mm_mul_ps(_val, _mm_set1_ps(kptr[2])), _sum2);
602                         _sum3 = _mm_add_ps(_mm_mul_ps(_val, _mm_set1_ps(kptr[3])), _sum3);
603 
604                         m += 4;
605                         kptr += 4;
606                     }
607 
608                     _sum0 = activation_sse(_sum0, activation_type, activation_params);
609                     _sum1 = activation_sse(_sum1, activation_type, activation_params);
610                     _sum2 = activation_sse(_sum2, activation_type, activation_params);
611                     _sum3 = activation_sse(_sum3, activation_type, activation_params);
612 
613                     _mm_storeu_ps(outptr, _sum0);
614                     _mm_storeu_ps(outptr + 4, _sum1);
615                     _mm_storeu_ps(outptr + 8, _sum2);
616                     _mm_storeu_ps(outptr + 12, _sum3);
617                     outptr += 16;
618                 }
619             }
620 
621             if (elempack == 1 && num_output_elempack == 4)
622             {
623                 float* outptr = top_blob.row(j);
624 
625                 for (int p = 0; p < num_output / num_output_elempack; p++)
626                 {
627                     const float* kptr = (const float*)weight_data_packed + num_input * p * 4;
628                     const float* m = bottom_blob.row(j);
629 
630                     __m128 _sum = _mm_set1_ps(0.f);
631 
632                     if (bias_term)
633                     {
634                         _sum = _mm_loadu_ps((const float*)bias_data + p * 4);
635                     }
636 
637                     int i = 0;
638 #if __AVX__
639                     for (; i + 7 < num_input; i += 8)
640                     {
641                         __m128 _val0 = _mm_broadcast_ss(m);
642                         __m128 _val1 = _mm_broadcast_ss(m + 1);
643                         __m128 _val2 = _mm_broadcast_ss(m + 2);
644                         __m128 _val3 = _mm_broadcast_ss(m + 3);
645                         __m128 _val4 = _mm_broadcast_ss(m + 4);
646                         __m128 _val5 = _mm_broadcast_ss(m + 5);
647                         __m128 _val6 = _mm_broadcast_ss(m + 6);
648                         __m128 _val7 = _mm_broadcast_ss(m + 7);
649 
650                         __m128 _w0 = _mm_loadu_ps(kptr);
651                         _sum = _mm_fmadd_ps(_val0, _w0, _sum);
652                         __m128 _w1 = _mm_loadu_ps(kptr + 4);
653                         _sum = _mm_fmadd_ps(_val1, _w1, _sum);
654                         __m128 _w2 = _mm_loadu_ps(kptr + 8);
655                         _sum = _mm_fmadd_ps(_val2, _w2, _sum);
656                         __m128 _w3 = _mm_loadu_ps(kptr + 12);
657                         _sum = _mm_fmadd_ps(_val3, _w3, _sum);
658                         __m128 _w4 = _mm_loadu_ps(kptr + 16);
659                         _sum = _mm_fmadd_ps(_val4, _w4, _sum);
660                         __m128 _w5 = _mm_loadu_ps(kptr + 20);
661                         _sum = _mm_fmadd_ps(_val5, _w5, _sum);
662                         __m128 _w6 = _mm_loadu_ps(kptr + 24);
663                         _sum = _mm_fmadd_ps(_val6, _w6, _sum);
664                         __m128 _w7 = _mm_loadu_ps(kptr + 28);
665                         _sum = _mm_fmadd_ps(_val7, _w7, _sum);
666 
667                         m += 8;
668                         kptr += 32;
669                     }
670 #endif // __AVX__
671                     for (; i + 3 < num_input; i += 4)
672                     {
673                         __m128 _val0 = _mm_set1_ps(m[0]);
674                         __m128 _val1 = _mm_set1_ps(m[1]);
675                         __m128 _val2 = _mm_set1_ps(m[2]);
676                         __m128 _val3 = _mm_set1_ps(m[3]);
677 
678                         __m128 _w0 = _mm_loadu_ps(kptr);
679                         _sum = _mm_add_ps(_mm_mul_ps(_val0, _w0), _sum);
680                         __m128 _w1 = _mm_loadu_ps(kptr + 4);
681                         _sum = _mm_add_ps(_mm_mul_ps(_val1, _w1), _sum);
682                         __m128 _w2 = _mm_loadu_ps(kptr + 8);
683                         _sum = _mm_add_ps(_mm_mul_ps(_val2, _w2), _sum);
684                         __m128 _w3 = _mm_loadu_ps(kptr + 12);
685                         _sum = _mm_add_ps(_mm_mul_ps(_val3, _w3), _sum);
686 
687                         m += 4;
688                         kptr += 16;
689                     }
690                     for (; i < num_input; i++)
691                     {
692                         __m128 _val = _mm_set1_ps(m[0]);
693                         __m128 _k = _mm_loadu_ps(kptr);
694                         _sum = _mm_add_ps(_mm_mul_ps(_val, _k), _sum);
695 
696                         m += 1;
697                         kptr += 4;
698                     }
699 
700                     _sum = activation_sse(_sum, activation_type, activation_params);
701 
702                     _mm_storeu_ps(outptr, _sum);
703                     outptr += 4;
704                 }
705             }
706 
707             if (elempack == 4 && num_output_elempack == 1)
708             {
709                 float* outptr = top_blob.row(j);
710 
711                 for (int p = 0; p < num_output; p++)
712                 {
713                     const float* kptr = (const float*)weight_data + num_input * p;
714                     const float* m = bottom_blob.row(j);
715 
716                     __m128 _sum0 = _mm_set1_ps(0.f);
717                     __m128 _sum1 = _mm_set1_ps(0.f);
718                     __m128 _sum2 = _mm_set1_ps(0.f);
719                     __m128 _sum3 = _mm_set1_ps(0.f);
720 
721                     if (bias_term)
722                     {
723                         _sum0 = _mm_set1_ps(bias_data[p]);
724                     }
725 
726                     int i = 0;
727                     for (; i + 7 < num_input; i += 8)
728                     {
729                         __m128 _val0 = _mm_loadu_ps(m);
730                         __m128 _val1 = _mm_loadu_ps(m + 4);
731                         __m128 _val2 = _mm_loadu_ps(m + 8);
732                         __m128 _val3 = _mm_loadu_ps(m + 12);
733                         __m128 _val4 = _mm_loadu_ps(m + 16);
734                         __m128 _val5 = _mm_loadu_ps(m + 20);
735                         __m128 _val6 = _mm_loadu_ps(m + 24);
736                         __m128 _val7 = _mm_loadu_ps(m + 28);
737                         _sum0 = _mm_add_ps(_mm_mul_ps(_val0, _mm_set1_ps(kptr[0])), _sum0);
738                         _sum1 = _mm_add_ps(_mm_mul_ps(_val1, _mm_set1_ps(kptr[1])), _sum1);
739                         _sum2 = _mm_add_ps(_mm_mul_ps(_val2, _mm_set1_ps(kptr[2])), _sum2);
740                         _sum3 = _mm_add_ps(_mm_mul_ps(_val3, _mm_set1_ps(kptr[3])), _sum3);
741                         _sum0 = _mm_add_ps(_mm_mul_ps(_val4, _mm_set1_ps(kptr[4])), _sum0);
742                         _sum1 = _mm_add_ps(_mm_mul_ps(_val5, _mm_set1_ps(kptr[5])), _sum1);
743                         _sum2 = _mm_add_ps(_mm_mul_ps(_val6, _mm_set1_ps(kptr[6])), _sum2);
744                         _sum3 = _mm_add_ps(_mm_mul_ps(_val7, _mm_set1_ps(kptr[7])), _sum3);
745 
746                         m += 32;
747                         kptr += 8;
748                     }
749                     for (; i + 3 < num_input; i += 4)
750                     {
751                         __m128 _val0 = _mm_loadu_ps(m);
752                         __m128 _val1 = _mm_loadu_ps(m + 4);
753                         __m128 _val2 = _mm_loadu_ps(m + 8);
754                         __m128 _val3 = _mm_loadu_ps(m + 12);
755                         _sum0 = _mm_add_ps(_mm_mul_ps(_val0, _mm_set1_ps(kptr[0])), _sum0);
756                         _sum1 = _mm_add_ps(_mm_mul_ps(_val1, _mm_set1_ps(kptr[1])), _sum1);
757                         _sum2 = _mm_add_ps(_mm_mul_ps(_val2, _mm_set1_ps(kptr[2])), _sum2);
758                         _sum3 = _mm_add_ps(_mm_mul_ps(_val3, _mm_set1_ps(kptr[3])), _sum3);
759 
760                         m += 16;
761                         kptr += 4;
762                     }
763                     for (; i < num_input; i++)
764                     {
765                         __m128 _val = _mm_loadu_ps(m);
766                         __m128 _k = _mm_set1_ps(kptr[0]);
767                         _sum0 = _mm_add_ps(_mm_mul_ps(_val, _k), _sum0);
768 
769                         m += 4;
770                         kptr += 1;
771                     }
772 
773                     _sum0 = _mm_add_ps(_sum0, _sum1);
774                     _sum2 = _mm_add_ps(_sum2, _sum3);
775                     _sum0 = _mm_add_ps(_sum0, _sum2);
776 
777                     _sum0 = activation_sse(_sum0, activation_type, activation_params);
778 
779                     _mm_storeu_ps(outptr, _sum0);
780                     outptr += 4;
781                 }
782             }
783 #endif // __SSE2__
784 
785             if (elempack == 1 && num_output_elempack == 1)
786             {
787                 float* outptr = top_blob.row(j);
788 
789                 for (int p = 0; p < num_output; p++)
790                 {
791                     const float* kptr = (const float*)weight_data + num_input * p;
792                     const float* m = bottom_blob.row(j);
793 
794                     float sum = 0.f;
795 
796                     if (bias_term)
797                     {
798                         sum = bias_data[p];
799                     }
800 
801                     int i = 0;
802 #if __SSE2__
803 #if __AVX__
804                     __m256 _sum = _mm256_set1_ps(0.f);
805                     for (; i + 7 < num_input; i += 8)
806                     {
807                         __m256 _m = _mm256_loadu_ps(m);
808                         __m256 _w = _mm256_loadu_ps(kptr);
809                         _sum = _mm256_fmadd_ps(_m, _w, _sum);
810 
811                         m += 8;
812                         kptr += 8;
813                     }
814 #endif // __AVX__
815                     __m128 _suml = _mm_set1_ps(0.f);
816                     for (; i + 3 < num_input; i += 4)
817                     {
818                         __m128 _val = _mm_loadu_ps(m);
819                         __m128 _k = _mm_loadu_ps(kptr);
820                         _suml = _mm_add_ps(_mm_mul_ps(_val, _k), _suml);
821 
822                         m += 4;
823                         kptr += 4;
824                     }
825 #endif // __SSE2__
826                     for (; i < num_input; i++)
827                     {
828                         sum += *m++ * *kptr++;
829                     }
830 
831 #if __SSE2__
832 #if __AVX__
833                     sum += _mm256_reduce_add_ps(_sum);
834 #endif // __AVX__
835                     sum += _mm_reduce_add_ps(_suml);
836 #endif // __SSE2__
837 
838                     sum = activation_ss(sum, activation_type, activation_params);
839 
840                     outptr[0] = sum;
841                     outptr += 1;
842                 }
843             }
844         }
845 
846         return 0;
847     }
848 
849 #if __AVX__
850     if (opt.use_weight_fp16_storage)
851     {
852         return forward_fp16(bottom_blob, top_blob, opt);
853     }
854 #endif // __AVX__
855 
856     // flatten
857     Mat bottom_blob_flattened = bottom_blob;
858     if (bottom_blob.dims != 1)
859     {
860         Option opt_flatten = opt;
861         opt_flatten.blob_allocator = opt.workspace_allocator;
862 
863         flatten->forward(bottom_blob, bottom_blob_flattened, opt_flatten);
864     }
865 
866     size_t elemsize = bottom_blob_flattened.elemsize;
867     int elempack = bottom_blob_flattened.elempack;
868 
869     int out_elempack = 1;
870 #if __SSE2__
871     if (opt.use_packing_layout)
872     {
873 #if __AVX__
874         out_elempack = num_output % 8 == 0 ? 8 : num_output % 4 == 0 ? 4 : 1;
875 #else
876         out_elempack = num_output % 4 == 0 ? 4 : 1;
877 #endif
878     }
879 #endif // __SSE2__
880     size_t out_elemsize = elemsize / elempack * out_elempack;
881 
882     top_blob.create(num_output / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
883     if (top_blob.empty())
884         return -100;
885 
886 #if __SSE2__
887 #if __AVX__
888     if (out_elempack == 8)
889     {
890         // num_output
891         #pragma omp parallel for num_threads(opt.num_threads)
892         for (int p = 0; p < num_output / out_elempack; p++)
893         {
894             __m256 _sum0 = _mm256_set1_ps(0.f);
895             __m256 _sum1 = _mm256_set1_ps(0.f);
896             __m256 _sum2 = _mm256_set1_ps(0.f);
897             __m256 _sum3 = _mm256_set1_ps(0.f);
898             __m256 _sum4 = _mm256_set1_ps(0.f);
899             __m256 _sum5 = _mm256_set1_ps(0.f);
900             __m256 _sum6 = _mm256_set1_ps(0.f);
901             __m256 _sum7 = _mm256_set1_ps(0.f);
902 
903             if (bias_term)
904             {
905                 _sum0 = _mm256_loadu_ps((const float*)bias_data + p * 8);
906             }
907 
908             const float* kptr = weight_data_packed.row(p);
909 
910             const float* sptr = bottom_blob_flattened;
911 
912             int i = 0;
913             for (; i + 7 < num_input; i += 8)
914             {
915                 __m256 _val0 = _mm256_broadcast_ss(sptr);
916                 __m256 _val1 = _mm256_broadcast_ss(sptr + 1);
917                 __m256 _val2 = _mm256_broadcast_ss(sptr + 2);
918                 __m256 _val3 = _mm256_broadcast_ss(sptr + 3);
919                 __m256 _val4 = _mm256_broadcast_ss(sptr + 4);
920                 __m256 _val5 = _mm256_broadcast_ss(sptr + 5);
921                 __m256 _val6 = _mm256_broadcast_ss(sptr + 6);
922                 __m256 _val7 = _mm256_broadcast_ss(sptr + 7);
923 
924                 __m256 _w0 = _mm256_loadu_ps(kptr);
925                 _sum0 = _mm256_fmadd_ps(_val0, _w0, _sum0);
926                 __m256 _w1 = _mm256_loadu_ps(kptr + 8);
927                 _sum1 = _mm256_fmadd_ps(_val1, _w1, _sum1);
928                 __m256 _w2 = _mm256_loadu_ps(kptr + 16);
929                 _sum2 = _mm256_fmadd_ps(_val2, _w2, _sum2);
930                 __m256 _w3 = _mm256_loadu_ps(kptr + 24);
931                 _sum3 = _mm256_fmadd_ps(_val3, _w3, _sum3);
932                 __m256 _w4 = _mm256_loadu_ps(kptr + 32);
933                 _sum4 = _mm256_fmadd_ps(_val4, _w4, _sum4);
934                 __m256 _w5 = _mm256_loadu_ps(kptr + 40);
935                 _sum5 = _mm256_fmadd_ps(_val5, _w5, _sum5);
936                 __m256 _w6 = _mm256_loadu_ps(kptr + 48);
937                 _sum6 = _mm256_fmadd_ps(_val6, _w6, _sum6);
938                 __m256 _w7 = _mm256_loadu_ps(kptr + 56);
939                 _sum7 = _mm256_fmadd_ps(_val7, _w7, _sum7);
940 
941                 sptr += 8;
942                 kptr += 64;
943             }
944             for (; i + 3 < num_input; i += 4)
945             {
946                 __m256 _val0 = _mm256_broadcast_ss(sptr);
947                 __m256 _val1 = _mm256_broadcast_ss(sptr + 1);
948                 __m256 _val2 = _mm256_broadcast_ss(sptr + 2);
949                 __m256 _val3 = _mm256_broadcast_ss(sptr + 3);
950 
951                 __m256 _w0 = _mm256_loadu_ps(kptr);
952                 _sum0 = _mm256_fmadd_ps(_val0, _w0, _sum0);
953                 __m256 _w1 = _mm256_loadu_ps(kptr + 8);
954                 _sum1 = _mm256_fmadd_ps(_val1, _w1, _sum1);
955                 __m256 _w2 = _mm256_loadu_ps(kptr + 16);
956                 _sum2 = _mm256_fmadd_ps(_val2, _w2, _sum2);
957                 __m256 _w3 = _mm256_loadu_ps(kptr + 24);
958                 _sum3 = _mm256_fmadd_ps(_val3, _w3, _sum3);
959 
960                 sptr += 4;
961                 kptr += 32;
962             }
963             for (; i < num_input; i++)
964             {
965                 __m256 _val = _mm256_set1_ps(sptr[0]);
966                 __m256 _w = _mm256_loadu_ps(kptr);
967                 _sum0 = _mm256_fmadd_ps(_val, _w, _sum0);
968 
969                 sptr += 1;
970                 kptr += 8;
971             }
972 
973             _sum0 = _mm256_add_ps(_sum0, _sum1);
974             _sum2 = _mm256_add_ps(_sum2, _sum3);
975             _sum4 = _mm256_add_ps(_sum4, _sum5);
976             _sum6 = _mm256_add_ps(_sum6, _sum7);
977             _sum0 = _mm256_add_ps(_sum0, _sum2);
978             _sum4 = _mm256_add_ps(_sum4, _sum6);
979             _sum0 = _mm256_add_ps(_sum0, _sum4);
980 
981             _sum0 = activation_avx(_sum0, activation_type, activation_params);
982 
983             float* outptr = top_blob;
984             _mm256_storeu_ps(outptr + p * 8, _sum0);
985         }
986     }
987 #endif // __AVX__
988 
989     if (out_elempack == 4)
990     {
991         // num_output
992         #pragma omp parallel for num_threads(opt.num_threads)
993         for (int p = 0; p < num_output / out_elempack; p++)
994         {
995             __m128 _sum0 = _mm_set1_ps(0.f);
996             __m128 _sum1 = _mm_set1_ps(0.f);
997             __m128 _sum2 = _mm_set1_ps(0.f);
998             __m128 _sum3 = _mm_set1_ps(0.f);
999 #if __AVX__
1000             __m128 _sum4 = _mm_set1_ps(0.f);
1001             __m128 _sum5 = _mm_set1_ps(0.f);
1002             __m128 _sum6 = _mm_set1_ps(0.f);
1003             __m128 _sum7 = _mm_set1_ps(0.f);
1004 #endif
1005 
1006             if (bias_term)
1007             {
1008                 _sum0 = _mm_loadu_ps((const float*)bias_data + p * 4);
1009             }
1010 
1011             const float* kptr = weight_data_packed.row(p);
1012 
1013             const float* sptr = bottom_blob_flattened;
1014 
1015             int i = 0;
1016 #if __AVX__
1017             for (; i + 7 < num_input; i += 8)
1018             {
1019                 __m128 _val0 = _mm_broadcast_ss(sptr);
1020                 __m128 _val1 = _mm_broadcast_ss(sptr + 1);
1021                 __m128 _val2 = _mm_broadcast_ss(sptr + 2);
1022                 __m128 _val3 = _mm_broadcast_ss(sptr + 3);
1023                 __m128 _val4 = _mm_broadcast_ss(sptr + 4);
1024                 __m128 _val5 = _mm_broadcast_ss(sptr + 5);
1025                 __m128 _val6 = _mm_broadcast_ss(sptr + 6);
1026                 __m128 _val7 = _mm_broadcast_ss(sptr + 7);
1027 
1028                 __m128 _w0 = _mm_loadu_ps(kptr);
1029                 _sum0 = _mm_fmadd_ps(_val0, _w0, _sum0);
1030                 __m128 _w1 = _mm_loadu_ps(kptr + 4);
1031                 _sum1 = _mm_fmadd_ps(_val1, _w1, _sum1);
1032                 __m128 _w2 = _mm_loadu_ps(kptr + 8);
1033                 _sum2 = _mm_fmadd_ps(_val2, _w2, _sum2);
1034                 __m128 _w3 = _mm_loadu_ps(kptr + 12);
1035                 _sum3 = _mm_fmadd_ps(_val3, _w3, _sum3);
1036                 __m128 _w4 = _mm_loadu_ps(kptr + 16);
1037                 _sum4 = _mm_fmadd_ps(_val4, _w4, _sum4);
1038                 __m128 _w5 = _mm_loadu_ps(kptr + 20);
1039                 _sum5 = _mm_fmadd_ps(_val5, _w5, _sum5);
1040                 __m128 _w6 = _mm_loadu_ps(kptr + 24);
1041                 _sum6 = _mm_fmadd_ps(_val6, _w6, _sum6);
1042                 __m128 _w7 = _mm_loadu_ps(kptr + 28);
1043                 _sum7 = _mm_fmadd_ps(_val7, _w7, _sum7);
1044 
1045                 sptr += 8;
1046                 kptr += 32;
1047             }
1048 #endif
1049             for (; i + 3 < num_input; i += 4)
1050             {
1051                 __m128 _val0 = _mm_set1_ps(sptr[0]);
1052                 __m128 _val1 = _mm_set1_ps(sptr[1]);
1053                 __m128 _val2 = _mm_set1_ps(sptr[2]);
1054                 __m128 _val3 = _mm_set1_ps(sptr[3]);
1055 
1056                 __m128 _w0 = _mm_loadu_ps(kptr);
1057                 _sum0 = _mm_add_ps(_mm_mul_ps(_val0, _w0), _sum0);
1058                 __m128 _w1 = _mm_loadu_ps(kptr + 4);
1059                 _sum1 = _mm_add_ps(_mm_mul_ps(_val1, _w1), _sum1);
1060                 __m128 _w2 = _mm_loadu_ps(kptr + 8);
1061                 _sum2 = _mm_add_ps(_mm_mul_ps(_val2, _w2), _sum2);
1062                 __m128 _w3 = _mm_loadu_ps(kptr + 12);
1063                 _sum3 = _mm_add_ps(_mm_mul_ps(_val3, _w3), _sum3);
1064 
1065                 sptr += 4;
1066                 kptr += 16;
1067             }
1068             for (; i < num_input; i++)
1069             {
1070                 __m128 _val = _mm_set1_ps(sptr[0]);
1071                 __m128 _w = _mm_loadu_ps(kptr);
1072                 _sum0 = _mm_add_ps(_mm_mul_ps(_val, _w), _sum0);
1073 
1074                 sptr += 1;
1075                 kptr += 4;
1076             }
1077 
1078             _sum0 = _mm_add_ps(_sum0, _sum1);
1079             _sum2 = _mm_add_ps(_sum2, _sum3);
1080 #if __AVX__
1081             _sum4 = _mm_add_ps(_sum4, _sum5);
1082             _sum6 = _mm_add_ps(_sum6, _sum7);
1083 #endif
1084             _sum0 = _mm_add_ps(_sum0, _sum2);
1085 #if __AVX__
1086             _sum4 = _mm_add_ps(_sum4, _sum6);
1087             _sum0 = _mm_add_ps(_sum0, _sum4);
1088 #endif
1089 
1090             _sum0 = activation_sse(_sum0, activation_type, activation_params);
1091 
1092             float* outptr = top_blob;
1093             _mm_storeu_ps(outptr + p * 4, _sum0);
1094         }
1095     }
1096 #endif // __SSE2__
1097 
1098     if (out_elempack == 1)
1099     {
1100 #if __SSE2__
1101 #if __AVX__
1102         int remain_num_output_start = 0;
1103         int nn_num_output = num_output >> 3;
1104 
1105         #pragma omp parallel for num_threads(opt.num_threads)
1106         for (int pp = 0; pp < nn_num_output; pp++)
1107         {
1108             int p = pp * 8;
1109 
1110             float sums[8] = {0.0f};
1111             if (bias_term)
1112             {
1113                 sums[0] = bias_data[p];
1114                 sums[1] = bias_data[p + 1];
1115                 sums[2] = bias_data[p + 2];
1116                 sums[3] = bias_data[p + 3];
1117                 sums[4] = bias_data[p + 4];
1118                 sums[5] = bias_data[p + 5];
1119                 sums[6] = bias_data[p + 6];
1120                 sums[7] = bias_data[p + 7];
1121             }
1122 
1123             const float* w0 = (const float*)weight_data + num_input * p;
1124             const float* w1 = (const float*)weight_data + num_input * (p + 1);
1125             const float* w2 = (const float*)weight_data + num_input * (p + 2);
1126             const float* w3 = (const float*)weight_data + num_input * (p + 3);
1127             const float* w4 = (const float*)weight_data + num_input * (p + 4);
1128             const float* w5 = (const float*)weight_data + num_input * (p + 5);
1129             const float* w6 = (const float*)weight_data + num_input * (p + 6);
1130             const float* w7 = (const float*)weight_data + num_input * (p + 7);
1131 
1132             const float* m = bottom_blob_flattened;
1133 
1134             __m256 _sum0 = _mm256_set1_ps(0.f);
1135             __m256 _sum1 = _mm256_set1_ps(0.f);
1136             __m256 _sum2 = _mm256_set1_ps(0.f);
1137             __m256 _sum3 = _mm256_set1_ps(0.f);
1138             __m256 _sum4 = _mm256_set1_ps(0.f);
1139             __m256 _sum5 = _mm256_set1_ps(0.f);
1140             __m256 _sum6 = _mm256_set1_ps(0.f);
1141             __m256 _sum7 = _mm256_set1_ps(0.f);
1142 
1143             int i = 0;
1144             for (; i + 7 < num_input; i += 8)
1145             {
1146                 __m256 _m = _mm256_loadu_ps(m);
1147 
1148                 __m256 _w0 = _mm256_loadu_ps(w0);
1149                 _sum0 = _mm256_fmadd_ps(_m, _w0, _sum0);
1150                 __m256 _w1 = _mm256_loadu_ps(w1);
1151                 _sum1 = _mm256_fmadd_ps(_m, _w1, _sum1);
1152                 __m256 _w2 = _mm256_loadu_ps(w2);
1153                 _sum2 = _mm256_fmadd_ps(_m, _w2, _sum2);
1154                 __m256 _w3 = _mm256_loadu_ps(w3);
1155                 _sum3 = _mm256_fmadd_ps(_m, _w3, _sum3);
1156                 __m256 _w4 = _mm256_loadu_ps(w4);
1157                 _sum4 = _mm256_fmadd_ps(_m, _w4, _sum4);
1158                 __m256 _w5 = _mm256_loadu_ps(w5);
1159                 _sum5 = _mm256_fmadd_ps(_m, _w5, _sum5);
1160                 __m256 _w6 = _mm256_loadu_ps(w6);
1161                 _sum6 = _mm256_fmadd_ps(_m, _w6, _sum6);
1162                 __m256 _w7 = _mm256_loadu_ps(w7);
1163                 _sum7 = _mm256_fmadd_ps(_m, _w7, _sum7);
1164 
1165                 m += 8;
1166                 w0 += 8;
1167                 w1 += 8;
1168                 w2 += 8;
1169                 w3 += 8;
1170                 w4 += 8;
1171                 w5 += 8;
1172                 w6 += 8;
1173                 w7 += 8;
1174             }
1175             for (; i < num_input; i++)
1176             {
1177                 sums[0] += *m * *w0;
1178                 sums[1] += *m * *w1;
1179                 sums[2] += *m * *w2;
1180                 sums[3] += *m * *w3;
1181                 sums[4] += *m * *w4;
1182                 sums[5] += *m * *w5;
1183                 sums[6] += *m * *w6;
1184                 sums[7] += *m * *w7;
1185 
1186                 m++;
1187                 w0++;
1188                 w1++;
1189                 w2++;
1190                 w3++;
1191                 w4++;
1192                 w5++;
1193                 w6++;
1194                 w7++;
1195             }
1196 
1197             __m256 _sums = HorizontalSums(_sum0, _sum1, _sum2, _sum3, _sum4, _sum5, _sum6, _sum7);
1198             __m256 _sums_f = _mm256_loadu_ps(sums);
1199             _sums = _mm256_add_ps(_sums_f, _sums);
1200             _sums = activation_avx(_sums, activation_type, activation_params);
1201 
1202             float* outptr = top_blob;
1203             _mm256_storeu_ps(outptr + p, _sums);
1204         }
1205 
1206         remain_num_output_start += (nn_num_output << 3);
1207         nn_num_output = (num_output - remain_num_output_start) >> 2;
1208 #else
1209         int remain_num_output_start = 0;
1210         int nn_num_output = num_output >> 2;
1211 #endif // __AVX__
1212 
1213         #pragma omp parallel for num_threads(opt.num_threads)
1214         for (int pp = 0; pp < nn_num_output; pp++)
1215         {
1216             int p = remain_num_output_start + (pp * 4);
1217 
1218             float sums[4] = {0.0f};
1219             if (bias_term)
1220             {
1221                 sums[0] = bias_data[p];
1222                 sums[1] = bias_data[p + 1];
1223                 sums[2] = bias_data[p + 2];
1224                 sums[3] = bias_data[p + 3];
1225             }
1226 
1227             const float* w0 = (const float*)weight_data + num_input * p;
1228             const float* w1 = (const float*)weight_data + num_input * (p + 1);
1229             const float* w2 = (const float*)weight_data + num_input * (p + 2);
1230             const float* w3 = (const float*)weight_data + num_input * (p + 3);
1231 
1232             const float* m = bottom_blob_flattened;
1233 
1234             int i = 0;
1235 #if __AVX__
1236             __m256 _sum0 = _mm256_set1_ps(0.f);
1237             __m256 _sum1 = _mm256_set1_ps(0.f);
1238             __m256 _sum2 = _mm256_set1_ps(0.f);
1239             __m256 _sum3 = _mm256_set1_ps(0.f);
1240             for (; i + 7 < num_input; i += 8)
1241             {
1242                 __m256 _m = _mm256_loadu_ps(m);
1243 
1244                 __m256 _w0 = _mm256_loadu_ps(w0);
1245                 _sum0 = _mm256_fmadd_ps(_m, _w0, _sum0);
1246                 __m256 _w1 = _mm256_loadu_ps(w1);
1247                 _sum1 = _mm256_fmadd_ps(_m, _w1, _sum1);
1248                 __m256 _w2 = _mm256_loadu_ps(w2);
1249                 _sum2 = _mm256_fmadd_ps(_m, _w2, _sum2);
1250                 __m256 _w3 = _mm256_loadu_ps(w3);
1251                 _sum3 = _mm256_fmadd_ps(_m, _w3, _sum3);
1252 
1253                 m += 8;
1254                 w0 += 8;
1255                 w1 += 8;
1256                 w2 += 8;
1257                 w3 += 8;
1258             }
1259 #endif // __AVX__
1260             __m128 _sum0l = _mm_set1_ps(0.f);
1261             __m128 _sum1l = _mm_set1_ps(0.f);
1262             __m128 _sum2l = _mm_set1_ps(0.f);
1263             __m128 _sum3l = _mm_set1_ps(0.f);
1264             for (; i + 3 < num_input; i += 4)
1265             {
1266                 __m128 _m = _mm_loadu_ps(m);
1267 
1268                 __m128 _w0 = _mm_loadu_ps(w0);
1269                 _sum0l = _mm_add_ps(_mm_mul_ps(_m, _w0), _sum0l);
1270                 __m128 _w1 = _mm_loadu_ps(w1);
1271                 _sum1l = _mm_add_ps(_mm_mul_ps(_m, _w1), _sum1l);
1272                 __m128 _w2 = _mm_loadu_ps(w2);
1273                 _sum2l = _mm_add_ps(_mm_mul_ps(_m, _w2), _sum2l);
1274                 __m128 _w3 = _mm_loadu_ps(w3);
1275                 _sum3l = _mm_add_ps(_mm_mul_ps(_m, _w3), _sum3l);
1276 
1277                 m += 4;
1278                 w0 += 4;
1279                 w1 += 4;
1280                 w2 += 4;
1281                 w3 += 4;
1282             }
1283             for (; i < num_input; i++)
1284             {
1285                 sums[0] += *m * *w0;
1286                 sums[1] += *m * *w1;
1287                 sums[2] += *m * *w2;
1288                 sums[3] += *m * *w3;
1289 
1290                 m++;
1291                 w0++;
1292                 w1++;
1293                 w2++;
1294                 w3++;
1295             }
1296 
1297             __m128 _sums = _mm_loadu_ps(sums);
1298 #if __AVX__
1299             _sums = _mm_add_ps(HorizontalSums(_sum0, _sum1, _sum2, _sum3), _sums);
1300 #endif
1301             _MM_TRANSPOSE4_PS(_sum0l, _sum1l, _sum2l, _sum3l);
1302             _sums = _mm_add_ps(_sum0l, _sums);
1303             _sums = _mm_add_ps(_sum1l, _sums);
1304             _sums = _mm_add_ps(_sum2l, _sums);
1305             _sums = _mm_add_ps(_sum3l, _sums);
1306             _sums = activation_sse(_sums, activation_type, activation_params);
1307 
1308             float* outptr = top_blob;
1309             _mm_storeu_ps(outptr + p, _sums);
1310         }
1311 
1312         remain_num_output_start += (nn_num_output << 2);
1313 #else
1314         int remain_num_output_start = 0;
1315 #endif // __SSE2__
1316 
1317         // num_output
1318         #pragma omp parallel for num_threads(opt.num_threads)
1319         for (int p = remain_num_output_start; p < num_output; p++)
1320         {
1321             float sum = 0.f;
1322 
1323             if (bias_term)
1324                 sum = bias_data[p];
1325 
1326             const float* w = (const float*)weight_data + num_input * p;
1327 
1328             const float* m = bottom_blob_flattened;
1329 
1330             int i = 0;
1331 #if __SSE2__
1332 #if __AVX__
1333             __m256 _sum = _mm256_set1_ps(0.f);
1334             for (; i + 7 < num_input; i += 8)
1335             {
1336                 __m256 _m = _mm256_loadu_ps(m);
1337 
1338                 __m256 _w = _mm256_loadu_ps(w);
1339                 _sum = _mm256_fmadd_ps(_m, _w, _sum);
1340 
1341                 m += 8;
1342                 w += 8;
1343             }
1344 #endif // __AVX__
1345             __m128 _suml = _mm_set1_ps(0.f);
1346             for (; i + 3 < num_input; i += 4)
1347             {
1348                 __m128 _m = _mm_loadu_ps(m);
1349 
1350                 __m128 _w = _mm_loadu_ps(w);
1351                 _suml = _mm_add_ps(_mm_mul_ps(_m, _w), _suml);
1352 
1353                 m += 4;
1354                 w += 4;
1355             }
1356 #endif // __SSE2__
1357             for (; i < num_input; i++)
1358             {
1359                 sum += *m * *w;
1360                 m++;
1361                 w++;
1362             }
1363 
1364 #if __SSE2__
1365 #if __AVX__
1366             sum += _mm256_reduce_add_ps(_sum);
1367 #endif
1368             sum += _mm_reduce_add_ps(_suml);
1369 #endif // __SSE2__
1370 
1371             sum = activation_ss(sum, activation_type, activation_params);
1372 
1373             float* outptr = top_blob;
1374             outptr[p] = sum;
1375         }
1376     }
1377 
1378     return 0;
1379 }
1380 #if __AVX__
1381 
forward_fp16(const Mat & bottom_blob,Mat & top_blob,const Option & opt) const1382 int InnerProduct_x86::forward_fp16(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
1383 {
1384     // flatten
1385     Mat bottom_blob_flattened = bottom_blob;
1386     if (bottom_blob.dims != 1)
1387     {
1388         Option opt_flatten = opt;
1389         opt_flatten.blob_allocator = opt.workspace_allocator;
1390 
1391         flatten->forward(bottom_blob, bottom_blob_flattened, opt_flatten);
1392     }
1393 
1394     // pack1
1395     {
1396         bottom_blob_flattened.w *= bottom_blob_flattened.elempack;
1397         bottom_blob_flattened.cstep = bottom_blob_flattened.w;
1398         bottom_blob_flattened.elemsize = 4u;
1399         bottom_blob_flattened.elempack = 1;
1400     }
1401 
1402     int w = bottom_blob_flattened.w;
1403     int h = bottom_blob_flattened.h;
1404     size_t elemsize = bottom_blob_flattened.elemsize;
1405     int size = w * h;
1406     top_blob.create(num_output, elemsize, opt.blob_allocator);
1407     if (top_blob.empty())
1408         return -100;
1409 
1410     const unsigned short* weight_data_ptr = (const unsigned short*)weight_data_fp16;
1411     float* output_ptr = top_blob;
1412     int nn_num_output = num_output >> 3;
1413     int remain_num_output_start = nn_num_output << 3;
1414 
1415     #pragma omp parallel for num_threads(opt.num_threads)
1416     for (int pp = 0; pp < nn_num_output; pp++)
1417     {
1418         int p = pp * 8;
1419 
1420         float sums[8] = {0.0f};
1421         if (bias_term)
1422         {
1423             sums[0] = bias_data[p];
1424             sums[1] = bias_data[p + 1];
1425             sums[2] = bias_data[p + 2];
1426             sums[3] = bias_data[p + 3];
1427             sums[4] = bias_data[p + 4];
1428             sums[5] = bias_data[p + 5];
1429             sums[6] = bias_data[p + 6];
1430             sums[7] = bias_data[p + 7];
1431         }
1432         __m256 _sum0 = _mm256_set1_ps(0.f);
1433         __m256 _sum1 = _mm256_set1_ps(0.f);
1434         __m256 _sum2 = _mm256_set1_ps(0.f);
1435         __m256 _sum3 = _mm256_set1_ps(0.f);
1436         __m256 _sum4 = _mm256_set1_ps(0.f);
1437         __m256 _sum5 = _mm256_set1_ps(0.f);
1438         __m256 _sum6 = _mm256_set1_ps(0.f);
1439         __m256 _sum7 = _mm256_set1_ps(0.f);
1440 
1441         const unsigned short* w0 = (const unsigned short*)weight_data_ptr + size * p;
1442         const unsigned short* w1 = (const unsigned short*)weight_data_ptr + size * (p + 1);
1443         const unsigned short* w2 = (const unsigned short*)weight_data_ptr + size * (p + 2);
1444         const unsigned short* w3 = (const unsigned short*)weight_data_ptr + size * (p + 3);
1445         const unsigned short* w4 = (const unsigned short*)weight_data_ptr + size * (p + 4);
1446         const unsigned short* w5 = (const unsigned short*)weight_data_ptr + size * (p + 5);
1447         const unsigned short* w6 = (const unsigned short*)weight_data_ptr + size * (p + 6);
1448         const unsigned short* w7 = (const unsigned short*)weight_data_ptr + size * (p + 7);
1449 
1450         const float* m = bottom_blob_flattened;
1451         int nn = size >> 3;
1452         int remain = size & 7;
1453 
1454         for (; nn > 0; nn--)
1455         {
1456             __m256 _m = _mm256_loadu_ps(m);
1457 
1458             __m256 _w0 = loadfp16(w0);
1459             _sum0 = _mm256_fmadd_ps(_m, _w0, _sum0);
1460 
1461             __m256 _w1 = loadfp16(w1);
1462             _sum1 = _mm256_fmadd_ps(_m, _w1, _sum1);
1463 
1464             __m256 _w2 = loadfp16(w2);
1465             _sum2 = _mm256_fmadd_ps(_m, _w2, _sum2);
1466 
1467             __m256 _w3 = loadfp16(w3);
1468             _sum3 = _mm256_fmadd_ps(_m, _w3, _sum3);
1469 
1470             __m256 _w4 = loadfp16(w4);
1471             _sum4 = _mm256_fmadd_ps(_m, _w4, _sum4);
1472 
1473             __m256 _w5 = loadfp16(w5);
1474             _sum5 = _mm256_fmadd_ps(_m, _w5, _sum5);
1475 
1476             __m256 _w6 = loadfp16(w6);
1477             _sum6 = _mm256_fmadd_ps(_m, _w6, _sum6);
1478 
1479             __m256 _w7 = loadfp16(w7);
1480             _sum7 = _mm256_fmadd_ps(_m, _w7, _sum7);
1481 
1482             m += 8;
1483             w0 += 8;
1484             w1 += 8;
1485             w2 += 8;
1486             w3 += 8;
1487             w4 += 8;
1488             w5 += 8;
1489             w6 += 8;
1490             w7 += 8;
1491         }
1492         if (remain != 0)
1493         {
1494             unsigned short fp16_weights[8][8] = {{0}};
1495             float _m_f[8] = {0};
1496             int i = 0;
1497             // No fast way to convert to fp32 one element at the time
1498             // so batch an 8 lane vector.
1499             for (; remain > 0; remain--)
1500             {
1501                 _m_f[i] = *m;
1502                 fp16_weights[0][i] = *w0;
1503                 fp16_weights[1][i] = *w1;
1504                 fp16_weights[2][i] = *w2;
1505                 fp16_weights[3][i] = *w3;
1506                 fp16_weights[4][i] = *w4;
1507                 fp16_weights[5][i] = *w5;
1508                 fp16_weights[6][i] = *w6;
1509                 fp16_weights[7][i] = *w7;
1510                 i++;
1511                 m++;
1512                 w0++;
1513                 w1++;
1514                 w2++;
1515                 w3++;
1516                 w4++;
1517                 w5++;
1518                 w6++;
1519                 w7++;
1520             }
1521             __m256 _m = _mm256_loadu_ps(_m_f);
1522 
1523             __m256 _w0 = loadfp16(fp16_weights[0]);
1524             _sum0 = _mm256_fmadd_ps(_m, _w0, _sum0);
1525 
1526             __m256 _w1 = loadfp16(fp16_weights[1]);
1527             _sum1 = _mm256_fmadd_ps(_m, _w1, _sum1);
1528 
1529             __m256 _w2 = loadfp16(fp16_weights[2]);
1530             _sum2 = _mm256_fmadd_ps(_m, _w2, _sum2);
1531 
1532             __m256 _w3 = loadfp16(fp16_weights[3]);
1533             _sum3 = _mm256_fmadd_ps(_m, _w3, _sum3);
1534 
1535             __m256 _w4 = loadfp16(fp16_weights[4]);
1536             _sum4 = _mm256_fmadd_ps(_m, _w4, _sum4);
1537 
1538             __m256 _w5 = loadfp16(fp16_weights[5]);
1539             _sum5 = _mm256_fmadd_ps(_m, _w5, _sum5);
1540 
1541             __m256 _w6 = loadfp16(fp16_weights[6]);
1542             _sum6 = _mm256_fmadd_ps(_m, _w6, _sum6);
1543 
1544             __m256 _w7 = loadfp16(fp16_weights[7]);
1545             _sum7 = _mm256_fmadd_ps(_m, _w7, _sum7);
1546         }
1547 
1548         __m256 _sums = HorizontalSums(_sum0, _sum1, _sum2, _sum3, _sum4, _sum5, _sum6, _sum7);
1549         __m256 _sums_f = _mm256_loadu_ps(sums);
1550         _sums = activation_avx(_mm256_add_ps(_sums_f, _sums), activation_type, activation_params);
1551         _mm256_storeu_ps(output_ptr + p, _sums);
1552     }
1553 
1554     nn_num_output = (num_output - remain_num_output_start) >> 2;
1555     int nn_offset = remain_num_output_start;
1556     remain_num_output_start += (nn_num_output << 2);
1557 
1558     #pragma omp parallel for num_threads(opt.num_threads)
1559     for (int pp = 0; pp < nn_num_output; pp++)
1560     {
1561         int p = nn_offset + (pp * 4);
1562 
1563         float sums[4] = {0.0f};
1564         if (bias_term)
1565         {
1566             sums[0] = bias_data[p];
1567             sums[1] = bias_data[p + 1];
1568             sums[2] = bias_data[p + 2];
1569             sums[3] = bias_data[p + 3];
1570         }
1571         __m256 _sum0 = _mm256_set1_ps(0.f);
1572         __m256 _sum1 = _mm256_set1_ps(0.f);
1573         __m256 _sum2 = _mm256_set1_ps(0.f);
1574         __m256 _sum3 = _mm256_set1_ps(0.f);
1575 
1576         const unsigned short* w0 = (const unsigned short*)weight_data_ptr + size * p;
1577         const unsigned short* w1 = (const unsigned short*)weight_data_ptr + size * (p + 1);
1578         const unsigned short* w2 = (const unsigned short*)weight_data_ptr + size * (p + 2);
1579         const unsigned short* w3 = (const unsigned short*)weight_data_ptr + size * (p + 3);
1580 
1581         const float* m = bottom_blob_flattened;
1582         int nn = size >> 3;
1583         int remain = size & 7;
1584 
1585         for (; nn > 0; nn--)
1586         {
1587             __m256 _m = _mm256_loadu_ps(m);
1588 
1589             __m256 _w0 = loadfp16(w0);
1590             _sum0 = _mm256_fmadd_ps(_m, _w0, _sum0);
1591 
1592             __m256 _w1 = loadfp16(w1);
1593             _sum1 = _mm256_fmadd_ps(_m, _w1, _sum1);
1594 
1595             __m256 _w2 = loadfp16(w2);
1596             _sum2 = _mm256_fmadd_ps(_m, _w2, _sum2);
1597 
1598             __m256 _w3 = loadfp16(w3);
1599             _sum3 = _mm256_fmadd_ps(_m, _w3, _sum3);
1600 
1601             m += 8;
1602             w0 += 8;
1603             w1 += 8;
1604             w2 += 8;
1605             w3 += 8;
1606         }
1607         if (remain != 0)
1608         {
1609             unsigned short fp16_weights[4][8] = {{0}};
1610             float _m_f[8] = {0};
1611             int i = 0;
1612             for (; remain > 0; remain--)
1613             {
1614                 _m_f[i] = *m;
1615                 fp16_weights[0][i] = *w0;
1616                 fp16_weights[1][i] = *w1;
1617                 fp16_weights[2][i] = *w2;
1618                 fp16_weights[3][i] = *w3;
1619                 i++;
1620                 m++;
1621                 w0++;
1622                 w1++;
1623                 w2++;
1624                 w3++;
1625             }
1626             __m256 _m = _mm256_loadu_ps(_m_f);
1627 
1628             __m256 _w0 = loadfp16(fp16_weights[0]);
1629             _sum0 = _mm256_fmadd_ps(_m, _w0, _sum0);
1630 
1631             __m256 _w1 = loadfp16(fp16_weights[1]);
1632             _sum1 = _mm256_fmadd_ps(_m, _w1, _sum1);
1633 
1634             __m256 _w2 = loadfp16(fp16_weights[2]);
1635             _sum2 = _mm256_fmadd_ps(_m, _w2, _sum2);
1636 
1637             __m256 _w3 = loadfp16(fp16_weights[3]);
1638             _sum3 = _mm256_fmadd_ps(_m, _w3, _sum3);
1639         }
1640 
1641         __m128 _sums = HorizontalSums(_sum0, _sum1, _sum2, _sum3);
1642         __m256 _sums_a = activation_avx(_mm256_castps128_ps256(_mm_add_ps(_mm_loadu_ps(sums), _sums)), activation_type, activation_params);
1643         _mm_storeu_ps(output_ptr + p, _mm256_castps256_ps128(_sums_a));
1644     }
1645 
1646 // num_output
1647     #pragma omp parallel for num_threads(opt.num_threads)
1648     for (int p = remain_num_output_start; p < num_output; p++)
1649     {
1650         float sum = 0.f;
1651 
1652         if (bias_term)
1653             sum = bias_data[p];
1654 
1655         const unsigned short* w = (const unsigned short*)weight_data_ptr + size * p;
1656 
1657         __m256 _sum = _mm256_set1_ps(0.f);
1658 
1659         const float* m = bottom_blob_flattened;
1660 
1661         int nn = size >> 3;
1662         int remain = size & 7;
1663         for (; nn > 0; nn--)
1664         {
1665             __m256 _m = _mm256_loadu_ps(m);
1666 
1667             __m256 _w = loadfp16(w);
1668             _sum = _mm256_fmadd_ps(_m, _w, _sum);
1669 
1670             m += 8;
1671             w += 8;
1672         }
1673         if (remain != 0)
1674         {
1675             unsigned short fp16_weights[8] = {0};
1676             float _m_f[8] = {0};
1677             int i = 0;
1678             for (; remain > 0; remain--)
1679             {
1680                 _m_f[i] = *m;
1681                 fp16_weights[i] = *w;
1682                 i++;
1683                 m++;
1684                 w++;
1685             }
1686             __m256 _m = _mm256_loadu_ps(_m_f);
1687 
1688             __m256 _w = loadfp16(fp16_weights);
1689             _sum = _mm256_fmadd_ps(_m, _w, _sum);
1690         }
1691 
1692         sum += _mm256_reduce_add_ps(_sum);
1693         sum = activation_ss(sum, activation_type, activation_params);
1694 
1695         output_ptr[p] = sum;
1696     }
1697     return 0;
1698 }
1699 #endif // __AVX__
1700 
1701 #if NCNN_INT8
create_pipeline_int8_x86(const Option & opt)1702 int InnerProduct_x86::create_pipeline_int8_x86(const Option& opt)
1703 {
1704     if (activation_type == 1)
1705     {
1706         activation = ncnn::create_layer(ncnn::LayerType::ReLU);
1707 
1708         ncnn::ParamDict pd;
1709         activation->load_param(pd);
1710     }
1711     else if (activation_type == 2)
1712     {
1713         activation = ncnn::create_layer(ncnn::LayerType::ReLU);
1714 
1715         ncnn::ParamDict pd;
1716         pd.set(0, activation_params[0]); // slope
1717         activation->load_param(pd);
1718     }
1719     else if (activation_type == 3)
1720     {
1721         activation = ncnn::create_layer(ncnn::LayerType::Clip);
1722 
1723         ncnn::ParamDict pd;
1724         pd.set(0, activation_params[0]); // min
1725         pd.set(1, activation_params[1]); // max
1726         activation->load_param(pd);
1727     }
1728     else if (activation_type == 4)
1729     {
1730         activation = ncnn::create_layer(ncnn::LayerType::Sigmoid);
1731 
1732         ncnn::ParamDict pd;
1733         activation->load_param(pd);
1734     }
1735     else if (activation_type == 5)
1736     {
1737         activation = ncnn::create_layer(ncnn::LayerType::Mish);
1738 
1739         ncnn::ParamDict pd;
1740         activation->load_param(pd);
1741     }
1742 
1743     if (activation)
1744     {
1745         activation->create_pipeline(opt);
1746     }
1747 
1748     const int num_input = weight_data_size / num_output;
1749 
1750     int out_elempack = 1;
1751 #if __SSE2__
1752     if (opt.use_packing_layout)
1753     {
1754         out_elempack = num_output % 8 == 0 ? 8 : 1;
1755     }
1756 #endif // __SSE2__
1757 
1758     // src = inch-outch
1759     // dst = pb-inch-outch/pb
1760     {
1761         Mat weight_data_r2 = weight_data.reshape(num_input, num_output);
1762 
1763         weight_data_int8.create(num_input, num_output / out_elempack, (size_t)out_elempack, out_elempack);
1764 
1765         for (int q = 0; q + (out_elempack - 1) < num_output; q += out_elempack)
1766         {
1767             signed char* g0 = weight_data_int8.row<signed char>(q / out_elempack);
1768 
1769             for (int p = 0; p < num_input; p++)
1770             {
1771                 for (int j = 0; j < out_elempack; j++)
1772                 {
1773                     *g0++ = weight_data_r2.row<signed char>(q + j)[p];
1774                 }
1775             }
1776         }
1777     }
1778 
1779     return 0;
1780 }
1781 
forward_int8_x86(const Mat & bottom_blob,Mat & top_blob,const Option & opt) const1782 int InnerProduct_x86::forward_int8_x86(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
1783 {
1784     const int num_input = weight_data_size / num_output;
1785 
1786     if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1)
1787     {
1788         // gemm
1789         Mat bottom_blob_unpacked;
1790         Option opt_unpack = opt;
1791         opt_unpack.blob_allocator = opt.workspace_allocator;
1792         convert_packing(bottom_blob, bottom_blob_unpacked, 1, opt_unpack);
1793 
1794         return forward_int8(bottom_blob_unpacked, top_blob, opt);
1795     }
1796 
1797     int elembits = bottom_blob.elembits();
1798 
1799     Mat bottom_blob_int8 = bottom_blob;
1800     if (elembits != 8)
1801     {
1802         Option opt_q = opt;
1803         opt_q.blob_allocator = opt.workspace_allocator;
1804         quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt_q);
1805     }
1806 
1807     Mat bottom_blob_int8_flattened = bottom_blob_int8;
1808     if (bottom_blob_int8.dims != 1)
1809     {
1810         Option opt_flatten = opt;
1811         opt_flatten.blob_allocator = opt.workspace_allocator;
1812         flatten->forward(bottom_blob_int8, bottom_blob_int8_flattened, opt_flatten);
1813     }
1814 
1815     //     int elempack = bottom_blob_int8_flattened.elempack;
1816 
1817     int out_elempack = 1;
1818 #if __SSE2__
1819     if (opt.use_packing_layout)
1820     {
1821         out_elempack = num_output % 8 == 0 ? 8 : 1;
1822     }
1823 #endif // __SSE2__
1824     //     size_t out_elemsize = elemsize / elempack * out_elempack;
1825 
1826     top_blob.create(num_output / out_elempack, (size_t)(4u * out_elempack), out_elempack, opt.blob_allocator);
1827     if (top_blob.empty())
1828         return -100;
1829 
1830     Mat top_blob_int32;
1831     top_blob_int32.create(num_output / out_elempack, (size_t)(4u * out_elempack), out_elempack, opt.workspace_allocator);
1832     if (top_blob_int32.empty())
1833         return -100;
1834 
1835 #if __SSE2__
1836     if (out_elempack == 8)
1837     {
1838         // num_output
1839         #pragma omp parallel for num_threads(opt.num_threads)
1840         for (int p = 0; p < num_output / out_elempack; p++)
1841         {
1842             __m128i _sum0 = _mm_setzero_si128();
1843             __m128i _sum1 = _mm_setzero_si128();
1844 
1845             const signed char* kptr = weight_data_int8.row<const signed char>(p);
1846             const signed char* sptr = bottom_blob_int8_flattened;
1847 
1848             int i = 0;
1849             for (; i < num_input; i++)
1850             {
1851                 __m128i _val = _mm_set1_epi16((short)sptr[0]);
1852 
1853                 // TODO use _mm_cvtepi8_epi16 on sse4.1
1854                 __m128i _w = _mm_loadl_epi64((const __m128i*)kptr);
1855                 _w = _mm_unpacklo_epi8(_w, _mm_cmpgt_epi8(_mm_setzero_si128(), _w));
1856 
1857                 __m128i _sl = _mm_mullo_epi16(_val, _w);
1858                 __m128i _sh = _mm_mulhi_epi16(_val, _w);
1859                 __m128i _s0 = _mm_unpacklo_epi16(_sl, _sh);
1860                 __m128i _s1 = _mm_unpackhi_epi16(_sl, _sh);
1861 
1862                 _sum0 = _mm_add_epi32(_sum0, _s0);
1863                 _sum1 = _mm_add_epi32(_sum1, _s1);
1864 
1865                 sptr += 1;
1866                 kptr += 8;
1867             }
1868 
1869             int* outptr = (int*)top_blob_int32;
1870             _mm_storeu_si128((__m128i*)(outptr + p * 8), _sum0);
1871             _mm_storeu_si128((__m128i*)(outptr + p * 8 + 4), _sum1);
1872         }
1873     }
1874 #endif // __SSE2__
1875 
1876     if (out_elempack == 1)
1877     {
1878         // num_output
1879         #pragma omp parallel for num_threads(opt.num_threads)
1880         for (int p = 0; p < num_output / out_elempack; p++)
1881         {
1882             int sum = 0;
1883 
1884             const signed char* kptr = weight_data_int8.row<const signed char>(p);
1885             const signed char* sptr = bottom_blob_int8_flattened;
1886 
1887             int i = 0;
1888             for (; i < num_input; i++)
1889             {
1890                 signed char val = sptr[0];
1891 
1892                 signed char w = kptr[0];
1893 
1894                 sum += val * w;
1895 
1896                 sptr += 1;
1897                 kptr += 1;
1898             }
1899 
1900             int* outptr = (int*)top_blob_int32;
1901             outptr[p] = sum;
1902         }
1903     }
1904 
1905     Mat scale_data(num_output);
1906     for (int p = 0; p < num_output; p++)
1907     {
1908         // dequantize
1909         float scale_in;
1910         if (weight_data_int8_scales[p] == 0)
1911             scale_in = 0;
1912         else
1913             scale_in = 1.f / (bottom_blob_int8_scales[0] * weight_data_int8_scales[p]);
1914 
1915         scale_data[p] = scale_in;
1916     }
1917 
1918     dequantize_from_int32(top_blob_int32, top_blob, scale_data, bias_data, opt);
1919 
1920     if (activation)
1921     {
1922         activation->forward_inplace(top_blob, opt);
1923     }
1924 
1925     return 0;
1926 }
1927 #endif // NCNN_INT8
1928 
1929 } // namespace ncnn
1930