1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
15 #include "binaryop_x86.h"
16 
17 #if __SSE2__
18 #include "sse_mathfun.h"
19 #if __AVX__
20 #include "avx_mathfun.h"
21 #endif // __AVX__
22 #endif // __SSE2__
23 
24 #include <math.h>
25 
26 namespace ncnn {
27 
// Constructor: advertise packed-layout (elempack > 1) support when the target
// has SSE2 — all of the pack4/pack8 kernels below depend on it.
BinaryOp_x86::BinaryOp_x86()
{
#if __SSE2__
    support_packing = true;
#endif // __SSE2__
}
34 
35 #if __SSE2__
36 #if __AVX__
37 // broadcasting rule
38 // https://github.com/Tencent/ncnn/wiki/binaryop-broadcasting
39 
// Apply the elementwise binary functor Op to a and b, writing the result to c.
// Both operands use the 8-wide packed float layout (elempack = 8, AVX), except
// where a branch explicitly handles an elempack==1 side.
// Implements ncnn's broadcasting rules for 1-D/2-D/3-D operand combinations
// (see the wiki link above); the "special type" / "type N" labels match that
// rule table. Returns 0 on success, -100 if output allocation fails.
template<typename Op>
static int binary_op_pack8(const Mat& a, const Mat& b, Mat& c, const Option& opt)
{
    Op op;

    // shape of the first operand (w*h packs per channel)
    int w = a.w;
    int h = a.h;
    int channels = a.c;
    int size = w * h;
    size_t elemsize = a.elemsize;
    int elempack = a.elempack;

    // shape of the second operand
    int w1 = b.w;
    int h1 = b.h;
    int channels1 = b.c;
    int size1 = w1 * h1;
    size_t elemsize1 = b.elemsize;
    int elempack1 = b.elempack;

    if (a.dims == 3)
    {
        if (b.dims == 3)
        {
            // 3-D vs 3-D: try the eight special broadcast shapes first,
            // then fall through to the plain elementwise case (type 19).
            if (w1 == 1 && h1 == 1 && channels1 == channels)
            {
                // special type 1: b is one pack per channel — broadcast it
                // across the whole channel of a
                c.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
                if (c.empty())
                    return -100;

                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels; q++)
                {
                    const float* ptr = a.channel(q);
                    const float* b0 = b.channel(q);
                    float* outptr = c.channel(q);
                    __m256 _b0 = _mm256_loadu_ps(b0);
                    for (int i = 0; i < size; i++)
                    {
                        __m256 _p = _mm256_loadu_ps(ptr);
                        __m256 _outp = op(_p, _b0);
                        _mm256_storeu_ps(outptr, _outp);
                        ptr += 8;
                        outptr += 8;
                    }
                }

                return 0;
            }

            if (w1 == w && h1 == h && channels1 == 1 && elempack1 == 1)
            {
                // special type 2: b is a single unpacked w*h plane — each
                // scalar of b is splatted into a full pack of 8
                c.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
                if (c.empty())
                    return -100;

                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels; q++)
                {
                    const float* ptr = a.channel(q);
                    const float* ptr1 = b;
                    float* outptr = c.channel(q);
                    for (int i = 0; i < size; i++)
                    {
                        __m256 _p = _mm256_loadu_ps(ptr);
                        __m256 _p1 = _mm256_broadcast_ss(ptr1);
                        __m256 _outp = op(_p, _p1);
                        _mm256_storeu_ps(outptr, _outp);
                        ptr += 8;
                        ptr1 += 1;
                        outptr += 8;
                    }
                }

                return 0;
            }

            if (w == 1 && h == 1 && channels1 == channels)
            {
                // special type 3: mirror of type 1 — a is one pack per channel
                c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator);
                if (c.empty())
                    return -100;

                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels1; q++)
                {
                    const float* a0 = a.channel(q);
                    const float* ptr1 = b.channel(q);
                    float* outptr = c.channel(q);
                    __m256 _a0 = _mm256_loadu_ps(a0);
                    for (int i = 0; i < size1; i++)
                    {
                        __m256 _p1 = _mm256_loadu_ps(ptr1);
                        __m256 _outp = op(_a0, _p1);
                        _mm256_storeu_ps(outptr, _outp);
                        ptr1 += 8;
                        outptr += 8;
                    }
                }

                return 0;
            }

            if (w1 == w && h1 == h && channels == 1 && elempack == 1)
            {
                // special type 4: mirror of type 2 — a is a single unpacked plane
                c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator);
                if (c.empty())
                    return -100;

                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels1; q++)
                {
                    const float* ptr = a;
                    const float* ptr1 = b.channel(q);
                    float* outptr = c.channel(q);
                    for (int i = 0; i < size1; i++)
                    {
                        __m256 _p = _mm256_broadcast_ss(ptr);
                        __m256 _p1 = _mm256_loadu_ps(ptr1);
                        __m256 _outp = op(_p, _p1);
                        _mm256_storeu_ps(outptr, _outp);
                        ptr += 1;
                        ptr1 += 8;
                        outptr += 8;
                    }
                }

                return 0;
            }

            if (w != 1 && w1 == 1 && h1 == h && channels1 == channels)
            {
                // special type 5: b has one pack per row — broadcast it across x
                c.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
                if (c.empty())
                    return -100;

                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels1; q++)
                {
                    const float* ptr = a.channel(q);
                    const float* ptr1 = b.channel(q);
                    float* outptr = c.channel(q);

                    for (int y = 0; y < h; y++)
                    {
                        __m256 _p1 = _mm256_loadu_ps(ptr1 + y * 8);
                        for (int x = 0; x < w; x++)
                        {
                            __m256 _p = _mm256_loadu_ps(ptr);
                            __m256 _outp = op(_p, _p1);
                            _mm256_storeu_ps(outptr, _outp);

                            ptr += 8;
                            outptr += 8;
                        }
                    }
                }

                return 0;
            }

            if (w1 == w && h != 1 && h1 == 1 && channels1 == channels)
            {
                // special type 6: b has a single row — broadcast it down y
                c.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
                if (c.empty())
                    return -100;

                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels1; q++)
                {
                    const float* ptr = a.channel(q);
                    const float* ptr1 = b.channel(q);
                    float* outptr = c.channel(q);

                    for (int y = 0; y < h; y++)
                    {
                        for (int x = 0; x < w; x++)
                        {
                            __m256 _p = _mm256_loadu_ps(ptr);
                            __m256 _p1 = _mm256_loadu_ps(ptr1 + x * 8);
                            __m256 _outp = op(_p, _p1);
                            _mm256_storeu_ps(outptr, _outp);

                            ptr += 8;
                            outptr += 8;
                        }
                    }
                }

                return 0;
            }

            if (w1 != 1 && w == 1 && h1 == h && channels1 == channels)
            {
                // special type 7: mirror of type 5 — a has one pack per row
                c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator);
                if (c.empty())
                    return -100;

                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels1; q++)
                {
                    const float* ptr = a.channel(q);
                    const float* ptr1 = b.channel(q);
                    float* outptr = c.channel(q);

                    for (int y = 0; y < h1; y++)
                    {
                        __m256 _p = _mm256_loadu_ps(ptr + y * 8);
                        for (int x = 0; x < w1; x++)
                        {
                            __m256 _p1 = _mm256_loadu_ps(ptr1);
                            __m256 _outp = op(_p, _p1);
                            _mm256_storeu_ps(outptr, _outp);

                            ptr1 += 8;
                            outptr += 8;
                        }
                    }
                }

                return 0;
            }

            if (w1 == w && h1 != 1 && h == 1 && channels1 == channels)
            {
                // special type 8: mirror of type 6 — a has a single row
                c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator);
                if (c.empty())
                    return -100;

                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels1; q++)
                {
                    const float* ptr = a.channel(q);
                    const float* ptr1 = b.channel(q);
                    float* outptr = c.channel(q);

                    for (int y = 0; y < h1; y++)
                    {
                        for (int x = 0; x < w1; x++)
                        {
                            __m256 _p = _mm256_loadu_ps(ptr + x * 8);
                            __m256 _p1 = _mm256_loadu_ps(ptr1);
                            __m256 _outp = op(_p, _p1);
                            _mm256_storeu_ps(outptr, _outp);

                            ptr1 += 8;
                            outptr += 8;
                        }
                    }
                }

                return 0;
            }

            // type 19: identical shapes — plain elementwise op
            c.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
            if (c.empty())
                return -100;

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const float* ptr = a.channel(q);
                const float* ptr1 = b.channel(q);
                float* outptr = c.channel(q);

                for (int i = 0; i < size; i++)
                {
                    __m256 _p = _mm256_loadu_ps(ptr);
                    __m256 _p1 = _mm256_loadu_ps(ptr1);
                    __m256 _outp = op(_p, _p1);
                    _mm256_storeu_ps(outptr, _outp);
                    ptr += 8;
                    ptr1 += 8;
                    outptr += 8;
                }
            }

            return 0;
        }

        // 3-D a vs lower-dim b: output always takes a's shape
        c.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
        if (c.empty())
            return -100;

        if (b.dims == 2)
        {
            // type 18: one row of b per channel of a, one pack of that row
            // per y, broadcast across x
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const float* ptr = a.channel(q);
                const float* ptr1 = b.row(q);
                float* outptr = c.channel(q);

                for (int y = 0; y < h; y++)
                {
                    __m256 _b0 = _mm256_loadu_ps(ptr1);
                    for (int x = 0; x < w; x++)
                    {
                        __m256 _p = _mm256_loadu_ps(ptr);
                        __m256 _outp = op(_p, _b0);
                        _mm256_storeu_ps(outptr, _outp);
                        ptr += 8;
                        outptr += 8;
                    }

                    ptr1 += 8;
                }
            }

            return 0;
        }

        if (b.dims == 1)
        {
            if (b.w == 1 && elempack1 == 1)
            {
                // type 16: b is a true scalar — splat it once
                __m256 _b0 = _mm256_set1_ps(b[0]);
                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels; q++)
                {
                    const float* ptr = a.channel(q);
                    float* outptr = c.channel(q);

                    for (int i = 0; i < size; i++)
                    {
                        __m256 _p = _mm256_loadu_ps(ptr);
                        __m256 _outp = op(_p, _b0);
                        _mm256_storeu_ps(outptr, _outp);
                        ptr += 8;
                        outptr += 8;
                    }
                }

                return 0;
            }

            // type 17: one pack of b per channel of a
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const float* ptr = a.channel(q);
                __m256 _b0 = _mm256_loadu_ps((const float*)b + q * 8);
                float* outptr = c.channel(q);

                for (int i = 0; i < size; i++)
                {
                    __m256 _p = _mm256_loadu_ps(ptr);
                    __m256 _outp = op(_p, _b0);
                    _mm256_storeu_ps(outptr, _outp);
                    ptr += 8;
                    outptr += 8;
                }
            }

            return 0;
        }
    }
    else if (a.dims == 2)
    {
        if (b.dims == 3)
        {
            // type 14: mirror of type 18 — one row of a per channel of b
            c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator);
            if (c.empty())
                return -100;

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels1; q++)
            {
                const float* ptr = a.row(q);
                const float* ptr1 = b.channel(q);
                float* outptr = c.channel(q);

                for (int y = 0; y < h1; y++)
                {
                    __m256 _a0 = _mm256_loadu_ps(ptr);
                    for (int x = 0; x < w1; x++)
                    {
                        __m256 _p1 = _mm256_loadu_ps(ptr1);
                        __m256 _outp = op(_a0, _p1);
                        _mm256_storeu_ps(outptr, _outp);
                        ptr1 += 8;
                        outptr += 8;
                    }

                    ptr += 8;
                }
            }

            return 0;
        }

        c.create(w, h, elemsize, elempack, opt.blob_allocator);
        if (c.empty())
            return -100;

        if (b.dims == 2)
        {
            // type 13: identical 2-D shapes — plain elementwise op
            const float* ptr = a;
            const float* ptr1 = b;
            float* outptr = c;
            for (int i = 0; i < size; i++)
            {
                __m256 _p = _mm256_loadu_ps(ptr);
                __m256 _p1 = _mm256_loadu_ps(ptr1);
                __m256 _outp = op(_p, _p1);
                _mm256_storeu_ps(outptr, _outp);
                ptr += 8;
                ptr1 += 8;
                outptr += 8;
            }

            return 0;
        }

        if (b.dims == 1)
        {
            // NOTE(review): c was already created with this exact shape just
            // above (before the b.dims checks); this second create is
            // redundant but harmless.
            c.create(w, h, elemsize, elempack, opt.blob_allocator);
            if (c.empty())
                return -100;

            if (b.w == 1 && elempack1 == 1)
            {
                // type 11: b is a true scalar
                __m256 _b0 = _mm256_set1_ps(b[0]);
                const float* ptr = a;
                float* outptr = c;
                for (int i = 0; i < size; i++)
                {
                    __m256 _p = _mm256_loadu_ps(ptr);
                    __m256 _outp = op(_p, _b0);
                    _mm256_storeu_ps(outptr, _outp);
                    ptr += 8;
                    outptr += 8;
                }

                return 0;
            }

            // type 12: one pack of b per row of a, broadcast across x
            const float* ptr = a;
            const float* ptr1 = b;
            float* outptr = c;

            for (int y = 0; y < h; y++)
            {
                __m256 _b0 = _mm256_loadu_ps(ptr1);
                for (int x = 0; x < w; x++)
                {
                    __m256 _p = _mm256_loadu_ps(ptr);
                    __m256 _outp = op(_p, _b0);
                    _mm256_storeu_ps(outptr, _outp);
                    ptr += 8;
                    outptr += 8;
                }

                ptr1 += 8;
            }

            return 0;
        }
    }
    else if (a.dims == 1)
    {
        if (a.w == 1 && elempack == 1)
        {
            // a is a true scalar: splat it once and stream over b
            if (b.dims == 3)
            {
                // type 4
                c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator);
                if (c.empty())
                    return -100;

                __m256 _a0 = _mm256_set1_ps(a[0]);
                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels1; q++)
                {
                    const float* ptr1 = b.channel(q);
                    float* outptr = c.channel(q);

                    for (int i = 0; i < size1; i++)
                    {
                        __m256 _p1 = _mm256_loadu_ps(ptr1);
                        __m256 _outp = op(_a0, _p1);
                        _mm256_storeu_ps(outptr, _outp);
                        ptr1 += 8;
                        outptr += 8;
                    }
                }

                return 0;
            }

            if (b.dims == 2)
            {
                // type 3
                c.create(w1, h1, elemsize1, elempack1, opt.blob_allocator);
                if (c.empty())
                    return -100;

                __m256 _a0 = _mm256_set1_ps(a[0]);
                const float* ptr1 = b;
                float* outptr = c;
                for (int i = 0; i < size1; i++)
                {
                    __m256 _p1 = _mm256_loadu_ps(ptr1);
                    __m256 _outp = op(_a0, _p1);
                    _mm256_storeu_ps(outptr, _outp);
                    ptr1 += 8;
                    outptr += 8;
                }

                return 0;
            }

            if (b.dims == 1)
            {
                // type 2
                c.create(w1, elemsize1, elempack1, opt.blob_allocator);
                if (c.empty())
                    return -100;

                __m256 _a0 = _mm256_set1_ps(a[0]);
                const float* ptr1 = b;
                float* outptr = c;
                for (int i = 0; i < w1; i++)
                {
                    __m256 _p1 = _mm256_loadu_ps(ptr1);
                    __m256 _outp = op(_a0, _p1);
                    _mm256_storeu_ps(outptr, _outp);
                    ptr1 += 8;
                    outptr += 8;
                }

                return 0;
            }
        }

        if (b.dims == 3)
        {
            // type 9: one pack of a per channel of b
            c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator);
            if (c.empty())
                return -100;

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels1; q++)
            {
                __m256 _a0 = _mm256_loadu_ps((const float*)a + q * 8);
                const float* ptr1 = b.channel(q);
                float* outptr = c.channel(q);

                for (int i = 0; i < size1; i++)
                {
                    __m256 _p1 = _mm256_loadu_ps(ptr1);
                    __m256 _outp = op(_a0, _p1);
                    _mm256_storeu_ps(outptr, _outp);
                    ptr1 += 8;
                    outptr += 8;
                }
            }

            return 0;
        }

        if (b.dims == 2)
        {
            // type 8: one pack of a per row of b, broadcast across x
            c.create(w1, h1, elemsize1, elempack1, opt.blob_allocator);
            if (c.empty())
                return -100;

            const float* ptr = a;
            const float* ptr1 = b;
            float* outptr = c;

            for (int y = 0; y < h1; y++)
            {
                __m256 _a0 = _mm256_loadu_ps(ptr);
                for (int x = 0; x < w1; x++)
                {
                    __m256 _p1 = _mm256_loadu_ps(ptr1);
                    __m256 _outp = op(_a0, _p1);
                    _mm256_storeu_ps(outptr, _outp);
                    ptr1 += 8;
                    outptr += 8;
                }

                ptr += 8;
            }

            return 0;
        }

        if (b.dims == 1)
        {
            c.create(w, elemsize, elempack, opt.blob_allocator);
            if (c.empty())
                return -100;

            if (b.w == 1 && elempack1 == 1)
            {
                // type 6: b is a true scalar
                __m256 _b0 = _mm256_set1_ps(b[0]);
                const float* ptr = a;
                float* outptr = c;
                for (int i = 0; i < w; i++)
                {
                    __m256 _p = _mm256_loadu_ps(ptr);
                    __m256 _outp = op(_p, _b0);
                    _mm256_storeu_ps(outptr, _outp);
                    ptr += 8;
                    outptr += 8;
                }

                return 0;
            }

            // type 7: identical 1-D shapes — plain elementwise op
            const float* ptr = a;
            const float* ptr1 = b;
            float* outptr = c;
            for (int i = 0; i < w; i++)
            {
                __m256 _p = _mm256_loadu_ps(ptr);
                __m256 _p1 = _mm256_loadu_ps(ptr1);
                __m256 _outp = op(_p, _p1);
                _mm256_storeu_ps(outptr, _outp);
                ptr += 8;
                ptr1 += 8;
                outptr += 8;
            }
        }
    }

    return 0;
}
688 
689 template<typename Op>
binary_op_scalar_inplace_pack8(Mat & a,float b,const Option & opt)690 static int binary_op_scalar_inplace_pack8(Mat& a, float b, const Option& opt)
691 {
692     Op op;
693 
694     int w = a.w;
695     int h = a.h;
696     int channels = a.c;
697     int size = w * h;
698 
699     __m256 _b = _mm256_set1_ps(b);
700 
701     #pragma omp parallel for num_threads(opt.num_threads)
702     for (int q = 0; q < channels; q++)
703     {
704         float* ptr = a.channel(q);
705 
706         for (int i = 0; i < size; i++)
707         {
708             __m256 _p = _mm256_loadu_ps(ptr);
709             _p = op(_p, _b);
710             _mm256_storeu_ps(ptr, _p);
711             ptr += 8;
712         }
713     }
714 
715     return 0;
716 }
717 
// Packed 8-float add: returns x + y lane-wise.
struct binary_op_add_pack8
{
    __m256 operator()(const __m256& x, const __m256& y) const
    {
        return _mm256_add_ps(x, y);
    }
};
725 
// Packed 8-float subtract: returns x - y lane-wise.
struct binary_op_sub_pack8
{
    __m256 operator()(const __m256& x, const __m256& y) const
    {
        return _mm256_sub_ps(x, y);
    }
};
733 
// Packed 8-float multiply: returns x * y lane-wise.
struct binary_op_mul_pack8
{
    __m256 operator()(const __m256& x, const __m256& y) const
    {
        return _mm256_mul_ps(x, y);
    }
};
741 
// Packed 8-float divide: returns x / y lane-wise.
struct binary_op_div_pack8
{
    __m256 operator()(const __m256& x, const __m256& y) const
    {
        return _mm256_div_ps(x, y);
    }
};
749 
// Packed 8-float maximum: returns max(x, y) lane-wise.
struct binary_op_max_pack8
{
    __m256 operator()(const __m256& x, const __m256& y) const
    {
        return _mm256_max_ps(x, y);
    }
};
757 
// Packed 8-float minimum: returns min(x, y) lane-wise.
struct binary_op_min_pack8
{
    __m256 operator()(const __m256& x, const __m256& y) const
    {
        return _mm256_min_ps(x, y);
    }
};
765 
// Packed 8-float power: x^y computed as exp(y * log(x)) via the avx_mathfun
// approximations. As with the exp/log identity for powf, lanes with
// non-positive x do not produce a meaningful power (log is undefined there).
struct binary_op_pow_pack8
{
    __m256 operator()(const __m256& x, const __m256& y) const
    {
        return exp256_ps(_mm256_mul_ps(y, log256_ps(x)));
    }
};
773 
// Packed 8-float reverse subtract: returns y - x lane-wise.
struct binary_op_rsub_pack8
{
    __m256 operator()(const __m256& x, const __m256& y) const
    {
        return _mm256_sub_ps(y, x);
    }
};
781 
// Packed 8-float reverse divide: returns y / x lane-wise.
struct binary_op_rdiv_pack8
{
    __m256 operator()(const __m256& x, const __m256& y) const
    {
        return _mm256_div_ps(y, x);
    }
};
789 #endif // __AVX__
790 
791 template<typename Op>
binary_op_pack4(const Mat & a,const Mat & b,Mat & c,const Option & opt)792 static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt)
793 {
794     Op op;
795 
796     int w = a.w;
797     int h = a.h;
798     int channels = a.c;
799     int size = w * h;
800     size_t elemsize = a.elemsize;
801     int elempack = a.elempack;
802 
803     int w1 = b.w;
804     int h1 = b.h;
805     int channels1 = b.c;
806     int size1 = w1 * h1;
807     size_t elemsize1 = b.elemsize;
808     int elempack1 = b.elempack;
809 
810     if (a.dims == 3)
811     {
812         if (b.dims == 3)
813         {
814             if (w1 == 1 && h1 == 1 && channels1 == channels)
815             {
816                 // special type 1
817                 c.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
818                 if (c.empty())
819                     return -100;
820 
821                 #pragma omp parallel for num_threads(opt.num_threads)
822                 for (int q = 0; q < channels; q++)
823                 {
824                     const float* ptr = a.channel(q);
825                     float* outptr = c.channel(q);
826                     const float* b0 = b.channel(q);
827                     __m128 _b0 = _mm_loadu_ps(b0);
828                     for (int i = 0; i < size; i++)
829                     {
830                         __m128 _p = _mm_loadu_ps(ptr);
831                         __m128 _outp = op(_p, _b0);
832                         _mm_storeu_ps(outptr, _outp);
833                         ptr += 4;
834                         outptr += 4;
835                     }
836                 }
837 
838                 return 0;
839             }
840 
841             if (w1 == w && h1 == h && channels1 == 1 && elempack1 == 1)
842             {
843                 // special type 2
844                 c.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
845                 if (c.empty())
846                     return -100;
847 
848                 #pragma omp parallel for num_threads(opt.num_threads)
849                 for (int q = 0; q < channels; q++)
850                 {
851                     const float* ptr = a.channel(q);
852                     const float* ptr1 = b;
853                     float* outptr = c.channel(q);
854                     for (int i = 0; i < size; i++)
855                     {
856                         __m128 _p = _mm_loadu_ps(ptr);
857                         __m128 _p1 = _mm_set1_ps(*ptr1);
858                         __m128 _outp = op(_p, _p1);
859                         _mm_storeu_ps(outptr, _outp);
860                         ptr += 4;
861                         ptr1 += 1;
862                         outptr += 4;
863                     }
864                 }
865 
866                 return 0;
867             }
868 
869             if (w == 1 && h == 1 && channels1 == channels)
870             {
871                 // special type 3
872                 c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator);
873                 if (c.empty())
874                     return -100;
875 
876                 #pragma omp parallel for num_threads(opt.num_threads)
877                 for (int q = 0; q < channels1; q++)
878                 {
879                     const float* a0 = a.channel(q);
880                     float* outptr = c.channel(q);
881                     const float* ptr1 = b.channel(q);
882                     __m128 _a0 = _mm_loadu_ps(a0);
883                     for (int i = 0; i < size1; i++)
884                     {
885                         __m128 _p1 = _mm_loadu_ps(ptr1);
886                         __m128 _outp = op(_a0, _p1);
887                         _mm_storeu_ps(outptr, _outp);
888                         ptr1 += 4;
889                         outptr += 4;
890                     }
891                 }
892 
893                 return 0;
894             }
895 
896             if (w1 == w && h1 == h && channels == 1 && elempack == 1)
897             {
898                 // special type 4
899                 c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator);
900                 if (c.empty())
901                     return -100;
902 
903                 #pragma omp parallel for num_threads(opt.num_threads)
904                 for (int q = 0; q < channels1; q++)
905                 {
906                     const float* ptr = a;
907                     const float* ptr1 = b.channel(q);
908                     float* outptr = c.channel(q);
909                     for (int i = 0; i < size1; i++)
910                     {
911                         __m128 _p = _mm_set1_ps(*ptr);
912                         __m128 _p1 = _mm_loadu_ps(ptr1);
913                         __m128 _outp = op(_p, _p1);
914                         _mm_storeu_ps(outptr, _outp);
915                         ptr += 1;
916                         ptr1 += 4;
917                         outptr += 4;
918                     }
919                 }
920 
921                 return 0;
922             }
923 
924             if (w != 1 && w1 == 1 && h1 == h && channels1 == channels)
925             {
926                 // special type 5
927                 c.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
928                 if (c.empty())
929                     return -100;
930 
931                 #pragma omp parallel for num_threads(opt.num_threads)
932                 for (int q = 0; q < channels1; q++)
933                 {
934                     const float* ptr = a.channel(q);
935                     const float* ptr1 = b.channel(q);
936                     float* outptr = c.channel(q);
937 
938                     for (int y = 0; y < h; y++)
939                     {
940                         __m128 _p1 = _mm_loadu_ps(ptr1 + y * 4);
941                         for (int x = 0; x < w; x++)
942                         {
943                             __m128 _p = _mm_loadu_ps(ptr);
944                             __m128 _outp = op(_p, _p1);
945                             _mm_storeu_ps(outptr, _outp);
946 
947                             ptr += 4;
948                             outptr += 4;
949                         }
950                     }
951                 }
952 
953                 return 0;
954             }
955 
956             if (w1 == w && h != 1 && h1 == 1 && channels1 == channels)
957             {
958                 // special type 6
959                 c.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
960                 if (c.empty())
961                     return -100;
962 
963                 #pragma omp parallel for num_threads(opt.num_threads)
964                 for (int q = 0; q < channels1; q++)
965                 {
966                     const float* ptr = a.channel(q);
967                     const float* ptr1 = b.channel(q);
968                     float* outptr = c.channel(q);
969 
970                     for (int y = 0; y < h; y++)
971                     {
972                         for (int x = 0; x < w; x++)
973                         {
974                             __m128 _p = _mm_loadu_ps(ptr);
975                             __m128 _p1 = _mm_loadu_ps(ptr1 + x * 4);
976                             __m128 _outp = op(_p, _p1);
977                             _mm_storeu_ps(outptr, _outp);
978 
979                             ptr += 4;
980                             outptr += 4;
981                         }
982                     }
983                 }
984 
985                 return 0;
986             }
987 
988             if (w1 != 1 && w == 1 && h1 == h && channels1 == channels)
989             {
990                 // special type 7
991                 c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator);
992                 if (c.empty())
993                     return -100;
994 
995                 #pragma omp parallel for num_threads(opt.num_threads)
996                 for (int q = 0; q < channels1; q++)
997                 {
998                     const float* ptr = a.channel(q);
999                     const float* ptr1 = b.channel(q);
1000                     float* outptr = c.channel(q);
1001 
1002                     for (int y = 0; y < h1; y++)
1003                     {
1004                         __m128 _p = _mm_loadu_ps(ptr + y * 4);
1005                         for (int x = 0; x < w1; x++)
1006                         {
1007                             __m128 _p1 = _mm_loadu_ps(ptr1);
1008                             __m128 _outp = op(_p, _p1);
1009                             _mm_storeu_ps(outptr, _outp);
1010 
1011                             ptr1 += 4;
1012                             outptr += 4;
1013                         }
1014                     }
1015                 }
1016 
1017                 return 0;
1018             }
1019 
1020             if (w1 == w && h1 != 1 && h == 1 && channels1 == channels)
1021             {
1022                 // special type 8
1023                 c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator);
1024                 if (c.empty())
1025                     return -100;
1026 
1027                 #pragma omp parallel for num_threads(opt.num_threads)
1028                 for (int q = 0; q < channels1; q++)
1029                 {
1030                     const float* ptr = a.channel(q);
1031                     const float* ptr1 = b.channel(q);
1032                     float* outptr = c.channel(q);
1033 
1034                     for (int y = 0; y < h1; y++)
1035                     {
1036                         for (int x = 0; x < w1; x++)
1037                         {
1038                             __m128 _p = _mm_loadu_ps(ptr + x * 4);
1039                             __m128 _p1 = _mm_loadu_ps(ptr1);
1040                             __m128 _outp = op(_p, _p1);
1041                             _mm_storeu_ps(outptr, _outp);
1042 
1043                             ptr1 += 4;
1044                             outptr += 4;
1045                         }
1046                     }
1047                 }
1048 
1049                 return 0;
1050             }
1051 
1052             // type 19
1053             c.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
1054             if (c.empty())
1055                 return -100;
1056 
1057             #pragma omp parallel for num_threads(opt.num_threads)
1058             for (int q = 0; q < channels; q++)
1059             {
1060                 const float* ptr = a.channel(q);
1061                 const float* ptr1 = b.channel(q);
1062                 float* outptr = c.channel(q);
1063 
1064                 for (int i = 0; i < size; i++)
1065                 {
1066                     __m128 _p = _mm_loadu_ps(ptr);
1067                     __m128 _p1 = _mm_loadu_ps(ptr1);
1068                     __m128 _outp = op(_p, _p1);
1069                     _mm_storeu_ps(outptr, _outp);
1070                     ptr += 4;
1071                     ptr1 += 4;
1072                     outptr += 4;
1073                 }
1074             }
1075 
1076             return 0;
1077         }
1078 
1079         c.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
1080         if (c.empty())
1081             return -100;
1082 
1083         if (b.dims == 2)
1084         {
1085             // type 18
1086             #pragma omp parallel for num_threads(opt.num_threads)
1087             for (int q = 0; q < channels; q++)
1088             {
1089                 const float* ptr = a.channel(q);
1090                 const float* ptr1 = b.row<const float>(q);
1091                 float* outptr = c.channel(q);
1092 
1093                 for (int y = 0; y < h; y++)
1094                 {
1095                     __m128 _b0 = _mm_loadu_ps(ptr1);
1096                     for (int x = 0; x < w; x++)
1097                     {
1098                         __m128 _p = _mm_loadu_ps(ptr);
1099                         __m128 _outp = op(_p, _b0);
1100                         _mm_storeu_ps(outptr, _outp);
1101                         ptr += 4;
1102                         outptr += 4;
1103                     }
1104 
1105                     ptr1 += 4;
1106                 }
1107             }
1108 
1109             return 0;
1110         }
1111 
1112         if (b.dims == 1)
1113         {
1114             if (b.w == 1 && elempack1 == 1)
1115             {
1116                 // type 16
1117                 __m128 _b0 = _mm_set1_ps(((const float*)b)[0]);
1118                 #pragma omp parallel for num_threads(opt.num_threads)
1119                 for (int q = 0; q < channels; q++)
1120                 {
1121                     const float* ptr = a.channel(q);
1122                     float* outptr = c.channel(q);
1123 
1124                     for (int i = 0; i < size; i++)
1125                     {
1126                         __m128 _p = _mm_loadu_ps(ptr);
1127                         __m128 _outp = op(_p, _b0);
1128                         _mm_storeu_ps(outptr, _outp);
1129                         ptr += 4;
1130                         outptr += 4;
1131                     }
1132                 }
1133 
1134                 return 0;
1135             }
1136 
1137             // type 17
1138             #pragma omp parallel for num_threads(opt.num_threads)
1139             for (int q = 0; q < channels; q++)
1140             {
1141                 const float* ptr = a.channel(q);
1142                 __m128 _b0 = _mm_loadu_ps((const float*)b + q * 4);
1143                 float* outptr = c.channel(q);
1144 
1145                 for (int i = 0; i < size; i++)
1146                 {
1147                     __m128 _p = _mm_loadu_ps(ptr);
1148                     __m128 _outp = op(_p, _b0);
1149                     _mm_storeu_ps(outptr, _outp);
1150                     ptr += 4;
1151                     outptr += 4;
1152                 }
1153             }
1154 
1155             return 0;
1156         }
1157     }
1158     else if (a.dims == 2)
1159     {
1160         if (b.dims == 3)
1161         {
1162             // type 14
1163             c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator);
1164             if (c.empty())
1165                 return -100;
1166 
1167             #pragma omp parallel for num_threads(opt.num_threads)
1168             for (int q = 0; q < channels1; q++)
1169             {
1170                 const float* ptr = a.row<const float>(q);
1171                 const float* ptr1 = b.channel(q);
1172                 float* outptr = c.channel(q);
1173 
1174                 for (int y = 0; y < h1; y++)
1175                 {
1176                     __m128 _a0 = _mm_loadu_ps(ptr);
1177                     for (int x = 0; x < w1; x++)
1178                     {
1179                         __m128 _p1 = _mm_loadu_ps(ptr1);
1180                         __m128 _outp = op(_a0, _p1);
1181                         _mm_storeu_ps(outptr, _outp);
1182                         ptr1 += 4;
1183                         outptr += 4;
1184                     }
1185 
1186                     ptr += 4;
1187                 }
1188             }
1189 
1190             return 0;
1191         }
1192 
1193         c.create(w, h, elemsize, elempack, opt.blob_allocator);
1194         if (c.empty())
1195             return -100;
1196 
1197         if (b.dims == 2)
1198         {
1199             // type 13
1200             const float* ptr = a;
1201             const float* ptr1 = b;
1202             float* outptr = c;
1203             for (int i = 0; i < size; i++)
1204             {
1205                 __m128 _p = _mm_loadu_ps(ptr);
1206                 __m128 _p1 = _mm_loadu_ps(ptr1);
1207                 __m128 _outp = op(_p, _p1);
1208                 _mm_storeu_ps(outptr, _outp);
1209                 ptr += 4;
1210                 ptr1 += 4;
1211                 outptr += 4;
1212             }
1213 
1214             return 0;
1215         }
1216 
1217         if (b.dims == 1)
1218         {
1219             c.create(w, h, elemsize, elempack, opt.blob_allocator);
1220             if (c.empty())
1221                 return -100;
1222 
1223             if (b.w == 1 && elempack1 == 1)
1224             {
1225                 // type 11
1226                 __m128 _b0 = _mm_set1_ps(((const float*)b)[0]);
1227                 const float* ptr = a;
1228                 float* outptr = c;
1229                 for (int i = 0; i < size; i++)
1230                 {
1231                     __m128 _p = _mm_loadu_ps(ptr);
1232                     __m128 _outp = op(_p, _b0);
1233                     _mm_storeu_ps(outptr, _outp);
1234                     ptr += 4;
1235                     outptr += 4;
1236                 }
1237 
1238                 return 0;
1239             }
1240 
1241             // type 12
1242             const float* ptr = a;
1243             const float* ptr1 = b;
1244             float* outptr = c;
1245 
1246             for (int y = 0; y < h; y++)
1247             {
1248                 __m128 _b0 = _mm_loadu_ps(ptr1);
1249                 for (int x = 0; x < w; x++)
1250                 {
1251                     __m128 _p = _mm_loadu_ps(ptr);
1252                     __m128 _outp = op(_p, _b0);
1253                     _mm_storeu_ps(outptr, _outp);
1254                     ptr += 4;
1255                     outptr += 4;
1256                 }
1257 
1258                 ptr1 += 4;
1259             }
1260 
1261             return 0;
1262         }
1263     }
1264     else if (a.dims == 1)
1265     {
1266         if (a.w == 1 && elempack == 1)
1267         {
1268             if (b.dims == 3)
1269             {
1270                 // type 4
1271                 c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator);
1272                 if (c.empty())
1273                     return -100;
1274 
1275                 __m128 _a0 = _mm_set1_ps(((const float*)a)[0]);
1276                 #pragma omp parallel for num_threads(opt.num_threads)
1277                 for (int q = 0; q < channels1; q++)
1278                 {
1279                     const float* ptr1 = b.channel(q);
1280                     float* outptr = c.channel(q);
1281 
1282                     for (int i = 0; i < size1; i++)
1283                     {
1284                         __m128 _p1 = _mm_loadu_ps(ptr1);
1285                         __m128 _outp = op(_a0, _p1);
1286                         _mm_storeu_ps(outptr, _outp);
1287                         ptr1 += 4;
1288                         outptr += 4;
1289                     }
1290                 }
1291 
1292                 return 0;
1293             }
1294 
1295             if (b.dims == 2)
1296             {
1297                 // type 3
1298                 c.create(w1, h1, elemsize1, elempack1, opt.blob_allocator);
1299                 if (c.empty())
1300                     return -100;
1301 
1302                 __m128 _a0 = _mm_set1_ps(((const float*)a)[0]);
1303                 const float* ptr1 = b;
1304                 float* outptr = c;
1305                 for (int i = 0; i < size1; i++)
1306                 {
1307                     __m128 _p1 = _mm_loadu_ps(ptr1);
1308                     __m128 _outp = op(_a0, _p1);
1309                     _mm_storeu_ps(outptr, _outp);
1310                     ptr1 += 4;
1311                     outptr += 4;
1312                 }
1313 
1314                 return 0;
1315             }
1316 
1317             if (b.dims == 1)
1318             {
1319                 // type 2
1320                 c.create(w1, elemsize1, elempack1, opt.blob_allocator);
1321                 if (c.empty())
1322                     return -100;
1323 
1324                 __m128 _a0 = _mm_set1_ps(((const float*)a)[0]);
1325                 const float* ptr1 = b;
1326                 float* outptr = c;
1327                 for (int i = 0; i < w1; i++)
1328                 {
1329                     __m128 _p1 = _mm_loadu_ps(ptr1);
1330                     __m128 _outp = op(_a0, _p1);
1331                     _mm_storeu_ps(outptr, _outp);
1332                     ptr1 += 4;
1333                     outptr += 4;
1334                 }
1335 
1336                 return 0;
1337             }
1338         }
1339 
1340         if (b.dims == 3)
1341         {
1342             // type 9
1343             c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator);
1344             if (c.empty())
1345                 return -100;
1346 
1347             #pragma omp parallel for num_threads(opt.num_threads)
1348             for (int q = 0; q < channels1; q++)
1349             {
1350                 __m128 _a0 = _mm_loadu_ps((const float*)a + q * 4);
1351                 const float* ptr1 = b.channel(q);
1352                 float* outptr = c.channel(q);
1353 
1354                 for (int i = 0; i < size1; i++)
1355                 {
1356                     __m128 _p1 = _mm_loadu_ps(ptr1);
1357                     __m128 _outp = op(_a0, _p1);
1358                     _mm_storeu_ps(outptr, _outp);
1359                     ptr1 += 4;
1360                     outptr += 4;
1361                 }
1362             }
1363 
1364             return 0;
1365         }
1366 
1367         if (b.dims == 2)
1368         {
1369             // type 8
1370             c.create(w1, h1, elemsize1, elempack1, opt.blob_allocator);
1371             if (c.empty())
1372                 return -100;
1373 
1374             const float* ptr = a;
1375             const float* ptr1 = b;
1376             float* outptr = c;
1377 
1378             for (int y = 0; y < h1; y++)
1379             {
1380                 __m128 _a0 = _mm_loadu_ps(ptr);
1381                 for (int x = 0; x < w1; x++)
1382                 {
1383                     __m128 _p1 = _mm_loadu_ps(ptr1);
1384                     __m128 _outp = op(_a0, _p1);
1385                     _mm_storeu_ps(outptr, _outp);
1386                     ptr1 += 4;
1387                     outptr += 4;
1388                 }
1389 
1390                 ptr += 4;
1391             }
1392 
1393             return 0;
1394         }
1395 
1396         if (b.dims == 1)
1397         {
1398             c.create(w, elemsize, elempack, opt.blob_allocator);
1399             if (c.empty())
1400                 return -100;
1401 
1402             if (b.w == 1 && elempack1 == 1)
1403             {
1404                 // type 6
1405                 __m128 _b0 = _mm_set1_ps(((const float*)b)[0]);
1406                 const float* ptr = a;
1407                 float* outptr = c;
1408                 for (int i = 0; i < w; i++)
1409                 {
1410                     __m128 _p = _mm_loadu_ps(ptr);
1411                     __m128 _outp = op(_p, _b0);
1412                     _mm_storeu_ps(outptr, _outp);
1413                     ptr += 4;
1414                     outptr += 4;
1415                 }
1416 
1417                 return 0;
1418             }
1419 
1420             // type 7
1421             const float* ptr = a;
1422             const float* ptr1 = b;
1423             float* outptr = c;
1424             for (int i = 0; i < w; i++)
1425             {
1426                 __m128 _p = _mm_loadu_ps(ptr);
1427                 __m128 _p1 = _mm_loadu_ps(ptr1);
1428                 __m128 _outp = op(_p, _p1);
1429                 _mm_storeu_ps(outptr, _outp);
1430                 ptr += 4;
1431                 ptr1 += 4;
1432                 outptr += 4;
1433             }
1434         }
1435     }
1436 
1437     return 0;
1438 }
1439 
1440 template<typename Op>
binary_op_scalar_inplace_pack4(Mat & a,float b,const Option & opt)1441 static int binary_op_scalar_inplace_pack4(Mat& a, float b, const Option& opt)
1442 {
1443     Op op;
1444 
1445     int w = a.w;
1446     int h = a.h;
1447     int channels = a.c;
1448     int size = w * h;
1449 
1450     __m128 _b = _mm_set1_ps((float)b);
1451 
1452     #pragma omp parallel for num_threads(opt.num_threads)
1453     for (int q = 0; q < channels; q++)
1454     {
1455         float* ptr = a.channel(q);
1456 
1457         for (int i = 0; i < size; i++)
1458         {
1459             __m128 _p = _mm_loadu_ps(ptr);
1460             _p = op(_p, _b);
1461             _mm_storeu_ps(ptr, _p);
1462             ptr += 4;
1463         }
1464     }
1465 
1466     return 0;
1467 }
1468 
// pack4 elementwise addition: x + y
struct binary_op_add_pack4
{
    __m128 operator()(const __m128& x, const __m128& y) const
    {
        __m128 _sum = _mm_add_ps(x, y);
        return _sum;
    }
};
1476 
// pack4 elementwise subtraction: x - y
struct binary_op_sub_pack4
{
    __m128 operator()(const __m128& x, const __m128& y) const
    {
        __m128 _diff = _mm_sub_ps(x, y);
        return _diff;
    }
};
1484 
// pack4 elementwise multiplication: x * y
struct binary_op_mul_pack4
{
    __m128 operator()(const __m128& x, const __m128& y) const
    {
        __m128 _prod = _mm_mul_ps(x, y);
        return _prod;
    }
};
1492 
// pack4 elementwise division: x / y
struct binary_op_div_pack4
{
    __m128 operator()(const __m128& x, const __m128& y) const
    {
        __m128 _quot = _mm_div_ps(x, y);
        return _quot;
    }
};
1500 
// pack4 elementwise maximum: max(x, y)
struct binary_op_max_pack4
{
    __m128 operator()(const __m128& x, const __m128& y) const
    {
        __m128 _m = _mm_max_ps(x, y);
        return _m;
    }
};
1508 
// pack4 elementwise minimum: min(x, y)
struct binary_op_min_pack4
{
    __m128 operator()(const __m128& x, const __m128& y) const
    {
        __m128 _m = _mm_min_ps(x, y);
        return _m;
    }
};
1516 
1517 struct binary_op_pow_pack4
1518 {
operator ()ncnn::binary_op_pow_pack41519     __m128 operator()(const __m128& x, const __m128& y) const
1520     {
1521         return exp_ps(_mm_mul_ps(y, log_ps(x)));
1522     }
1523 };
1524 
// pack4 reversed subtraction: y - x
struct binary_op_rsub_pack4
{
    __m128 operator()(const __m128& x, const __m128& y) const
    {
        __m128 _diff = _mm_sub_ps(y, x);
        return _diff;
    }
};
1532 
// pack4 reversed division: y / x
struct binary_op_rdiv_pack4
{
    __m128 operator()(const __m128& x, const __m128& y) const
    {
        __m128 _quot = _mm_div_ps(y, x);
        return _quot;
    }
};
1540 #endif // __SSE2__
1541 
forward(const std::vector<Mat> & bottom_blobs,std::vector<Mat> & top_blobs,const Option & opt) const1542 int BinaryOp_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
1543 {
1544 #if __SSE2__
1545     const Mat& bottom_blob = bottom_blobs[0];
1546     const Mat& bottom_blob1 = bottom_blobs[1];
1547     Mat& top_blob = top_blobs[0];
1548 
1549     int elempack = bottom_blob.elempack;
1550     int elempack1 = bottom_blob1.elempack;
1551 
1552 #if __AVX__
1553     if (elempack == 8 || elempack1 == 8)
1554     {
1555         if (op_type == Operation_ADD)
1556             return binary_op_pack8<binary_op_add_pack8>(bottom_blob, bottom_blob1, top_blob, opt);
1557 
1558         if (op_type == Operation_SUB)
1559             return binary_op_pack8<binary_op_sub_pack8>(bottom_blob, bottom_blob1, top_blob, opt);
1560 
1561         if (op_type == Operation_MUL)
1562             return binary_op_pack8<binary_op_mul_pack8>(bottom_blob, bottom_blob1, top_blob, opt);
1563 
1564         if (op_type == Operation_DIV)
1565             return binary_op_pack8<binary_op_div_pack8>(bottom_blob, bottom_blob1, top_blob, opt);
1566 
1567         if (op_type == Operation_MAX)
1568             return binary_op_pack8<binary_op_max_pack8>(bottom_blob, bottom_blob1, top_blob, opt);
1569 
1570         if (op_type == Operation_MIN)
1571             return binary_op_pack8<binary_op_min_pack8>(bottom_blob, bottom_blob1, top_blob, opt);
1572 
1573         if (op_type == Operation_POW)
1574             return binary_op_pack8<binary_op_pow_pack8>(bottom_blob, bottom_blob1, top_blob, opt);
1575 
1576         if (op_type == Operation_RSUB)
1577             return binary_op_pack8<binary_op_rsub_pack8>(bottom_blob, bottom_blob1, top_blob, opt);
1578 
1579         if (op_type == Operation_RDIV)
1580             return binary_op_pack8<binary_op_rdiv_pack8>(bottom_blob, bottom_blob1, top_blob, opt);
1581     }
1582 #endif // __AVX__
1583 
1584     if (elempack == 4 || elempack1 == 4)
1585     {
1586         if (op_type == Operation_ADD)
1587             return binary_op_pack4<binary_op_add_pack4>(bottom_blob, bottom_blob1, top_blob, opt);
1588 
1589         if (op_type == Operation_SUB)
1590             return binary_op_pack4<binary_op_sub_pack4>(bottom_blob, bottom_blob1, top_blob, opt);
1591 
1592         if (op_type == Operation_MUL)
1593             return binary_op_pack4<binary_op_mul_pack4>(bottom_blob, bottom_blob1, top_blob, opt);
1594 
1595         if (op_type == Operation_DIV)
1596             return binary_op_pack4<binary_op_div_pack4>(bottom_blob, bottom_blob1, top_blob, opt);
1597 
1598         if (op_type == Operation_MAX)
1599             return binary_op_pack4<binary_op_max_pack4>(bottom_blob, bottom_blob1, top_blob, opt);
1600 
1601         if (op_type == Operation_MIN)
1602             return binary_op_pack4<binary_op_min_pack4>(bottom_blob, bottom_blob1, top_blob, opt);
1603 
1604         if (op_type == Operation_POW)
1605             return binary_op_pack4<binary_op_pow_pack4>(bottom_blob, bottom_blob1, top_blob, opt);
1606 
1607         if (op_type == Operation_RSUB)
1608             return binary_op_pack4<binary_op_rsub_pack4>(bottom_blob, bottom_blob1, top_blob, opt);
1609 
1610         if (op_type == Operation_RDIV)
1611             return binary_op_pack4<binary_op_rdiv_pack4>(bottom_blob, bottom_blob1, top_blob, opt);
1612     }
1613 #endif // __SSE2__
1614 
1615     return BinaryOp::forward(bottom_blobs, top_blobs, opt);
1616 }
1617 
forward_inplace(Mat & bottom_top_blob,const Option & opt) const1618 int BinaryOp_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
1619 {
1620 #if __SSE2__
1621     int elempack = bottom_top_blob.elempack;
1622 
1623 #if __AVX__
1624     if (elempack == 8)
1625     {
1626         if (op_type == Operation_ADD)
1627             return binary_op_scalar_inplace_pack8<binary_op_add_pack8>(bottom_top_blob, b, opt);
1628 
1629         if (op_type == Operation_SUB)
1630             return binary_op_scalar_inplace_pack8<binary_op_sub_pack8>(bottom_top_blob, b, opt);
1631 
1632         if (op_type == Operation_MUL)
1633             return binary_op_scalar_inplace_pack8<binary_op_mul_pack8>(bottom_top_blob, b, opt);
1634 
1635         if (op_type == Operation_DIV)
1636             return binary_op_scalar_inplace_pack8<binary_op_div_pack8>(bottom_top_blob, b, opt);
1637 
1638         if (op_type == Operation_MAX)
1639             return binary_op_scalar_inplace_pack8<binary_op_max_pack8>(bottom_top_blob, b, opt);
1640 
1641         if (op_type == Operation_MIN)
1642             return binary_op_scalar_inplace_pack8<binary_op_min_pack8>(bottom_top_blob, b, opt);
1643 
1644         if (op_type == Operation_POW)
1645             return binary_op_scalar_inplace_pack8<binary_op_pow_pack8>(bottom_top_blob, b, opt);
1646 
1647         if (op_type == Operation_RSUB)
1648             return binary_op_scalar_inplace_pack8<binary_op_rsub_pack8>(bottom_top_blob, b, opt);
1649 
1650         if (op_type == Operation_RDIV)
1651             return binary_op_scalar_inplace_pack8<binary_op_rdiv_pack8>(bottom_top_blob, b, opt);
1652     }
1653 #endif // __AVX__
1654 
1655     if (elempack == 4)
1656     {
1657         if (op_type == Operation_ADD)
1658             return binary_op_scalar_inplace_pack4<binary_op_add_pack4>(bottom_top_blob, b, opt);
1659 
1660         if (op_type == Operation_SUB)
1661             return binary_op_scalar_inplace_pack4<binary_op_sub_pack4>(bottom_top_blob, b, opt);
1662 
1663         if (op_type == Operation_MUL)
1664             return binary_op_scalar_inplace_pack4<binary_op_mul_pack4>(bottom_top_blob, b, opt);
1665 
1666         if (op_type == Operation_DIV)
1667             return binary_op_scalar_inplace_pack4<binary_op_div_pack4>(bottom_top_blob, b, opt);
1668 
1669         if (op_type == Operation_MAX)
1670             return binary_op_scalar_inplace_pack4<binary_op_max_pack4>(bottom_top_blob, b, opt);
1671 
1672         if (op_type == Operation_MIN)
1673             return binary_op_scalar_inplace_pack4<binary_op_min_pack4>(bottom_top_blob, b, opt);
1674 
1675         if (op_type == Operation_POW)
1676             return binary_op_scalar_inplace_pack4<binary_op_pow_pack4>(bottom_top_blob, b, opt);
1677 
1678         if (op_type == Operation_RSUB)
1679             return binary_op_scalar_inplace_pack4<binary_op_rsub_pack4>(bottom_top_blob, b, opt);
1680 
1681         if (op_type == Operation_RDIV)
1682             return binary_op_scalar_inplace_pack4<binary_op_rdiv_pack4>(bottom_top_blob, b, opt);
1683     }
1684 #endif // __SSE2__
1685 
1686     return BinaryOp::forward_inplace(bottom_top_blob, opt);
1687 }
1688 
1689 } // namespace ncnn
1690