1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14
15 #include "binaryop_x86.h"
16
17 #if __SSE2__
18 #include "sse_mathfun.h"
19 #if __AVX__
20 #include "avx_mathfun.h"
21 #endif // __AVX__
22 #endif // __SSE2__
23
24 #include <math.h>
25
26 namespace ncnn {
27
// Constructor: advertise packed-layout (elempack) support to the ncnn
// framework when the build target has SSE2, so packed blobs are routed here.
BinaryOp_x86::BinaryOp_x86()
{
#if __SSE2__
    support_packing = true;
#endif // __SSE2__
}
34
35 #if __SSE2__
36 #if __AVX__
37 // broadcasting rule
38 // https://github.com/Tencent/ncnn/wiki/binaryop-broadcasting
39
// Apply the elementwise binary functor Op to blobs a and b (stored as 8-lane
// AVX-packed floats, elempack == 8 except where a branch tests elempack == 1)
// and write the result into c, following the ncnn broadcasting rules:
// https://github.com/Tencent/ncnn/wiki/binaryop-broadcasting
// Returns 0 on success, -100 on output allocation failure.
template<typename Op>
static int binary_op_pack8(const Mat& a, const Mat& b, Mat& c, const Option& opt)
{
    Op op;

    // shape of a
    int w = a.w;
    int h = a.h;
    int channels = a.c;
    int size = w * h;
    size_t elemsize = a.elemsize;
    int elempack = a.elempack;

    // shape of b
    int w1 = b.w;
    int h1 = b.h;
    int channels1 = b.c;
    int size1 = w1 * h1;
    size_t elemsize1 = b.elemsize;
    int elempack1 = b.elempack;

    if (a.dims == 3)
    {
        if (b.dims == 3)
        {
            if (w1 == 1 && h1 == 1 && channels1 == channels)
            {
                // special type 1: b is (1,1,channels) - one packed value per
                // channel, broadcast over every pixel of the matching channel of a
                c.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
                if (c.empty())
                    return -100;

                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels; q++)
                {
                    const float* ptr = a.channel(q);
                    const float* b0 = b.channel(q);
                    float* outptr = c.channel(q);
                    __m256 _b0 = _mm256_loadu_ps(b0);
                    for (int i = 0; i < size; i++)
                    {
                        __m256 _p = _mm256_loadu_ps(ptr);
                        __m256 _outp = op(_p, _b0);
                        _mm256_storeu_ps(outptr, _outp);
                        ptr += 8;
                        outptr += 8;
                    }
                }

                return 0;
            }

            if (w1 == w && h1 == h && channels1 == 1 && elempack1 == 1)
            {
                // special type 2: b is (w,h,1) with elempack 1 - each scalar of
                // b is splatted across the 8 lanes of the matching pixel of a
                c.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
                if (c.empty())
                    return -100;

                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels; q++)
                {
                    const float* ptr = a.channel(q);
                    const float* ptr1 = b;
                    float* outptr = c.channel(q);
                    for (int i = 0; i < size; i++)
                    {
                        __m256 _p = _mm256_loadu_ps(ptr);
                        __m256 _p1 = _mm256_broadcast_ss(ptr1);
                        __m256 _outp = op(_p, _p1);
                        _mm256_storeu_ps(outptr, _outp);
                        ptr += 8;
                        ptr1 += 1;
                        outptr += 8;
                    }
                }

                return 0;
            }

            if (w == 1 && h == 1 && channels1 == channels)
            {
                // special type 3: a is (1,1,channels) - mirror of special type 1
                c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator);
                if (c.empty())
                    return -100;

                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels1; q++)
                {
                    const float* a0 = a.channel(q);
                    const float* ptr1 = b.channel(q);
                    float* outptr = c.channel(q);
                    __m256 _a0 = _mm256_loadu_ps(a0);
                    for (int i = 0; i < size1; i++)
                    {
                        __m256 _p1 = _mm256_loadu_ps(ptr1);
                        __m256 _outp = op(_a0, _p1);
                        _mm256_storeu_ps(outptr, _outp);
                        ptr1 += 8;
                        outptr += 8;
                    }
                }

                return 0;
            }

            if (w1 == w && h1 == h && channels == 1 && elempack == 1)
            {
                // special type 4: a is (w,h,1) with elempack 1 - mirror of special type 2
                c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator);
                if (c.empty())
                    return -100;

                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels1; q++)
                {
                    const float* ptr = a;
                    const float* ptr1 = b.channel(q);
                    float* outptr = c.channel(q);
                    for (int i = 0; i < size1; i++)
                    {
                        __m256 _p = _mm256_broadcast_ss(ptr);
                        __m256 _p1 = _mm256_loadu_ps(ptr1);
                        __m256 _outp = op(_p, _p1);
                        _mm256_storeu_ps(outptr, _outp);
                        ptr += 1;
                        ptr1 += 8;
                        outptr += 8;
                    }
                }

                return 0;
            }

            if (w != 1 && w1 == 1 && h1 == h && channels1 == channels)
            {
                // special type 5: b is (1,h,channels) - one packed value per
                // row of b, broadcast along the width of a
                c.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
                if (c.empty())
                    return -100;

                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels1; q++)
                {
                    const float* ptr = a.channel(q);
                    const float* ptr1 = b.channel(q);
                    float* outptr = c.channel(q);

                    for (int y = 0; y < h; y++)
                    {
                        __m256 _p1 = _mm256_loadu_ps(ptr1 + y * 8);
                        for (int x = 0; x < w; x++)
                        {
                            __m256 _p = _mm256_loadu_ps(ptr);
                            __m256 _outp = op(_p, _p1);
                            _mm256_storeu_ps(outptr, _outp);

                            ptr += 8;
                            outptr += 8;
                        }
                    }
                }

                return 0;
            }

            if (w1 == w && h != 1 && h1 == 1 && channels1 == channels)
            {
                // special type 6: b is (w,1,channels) - one packed value per
                // column of b, broadcast along the height of a
                c.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
                if (c.empty())
                    return -100;

                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels1; q++)
                {
                    const float* ptr = a.channel(q);
                    const float* ptr1 = b.channel(q);
                    float* outptr = c.channel(q);

                    for (int y = 0; y < h; y++)
                    {
                        for (int x = 0; x < w; x++)
                        {
                            __m256 _p = _mm256_loadu_ps(ptr);
                            __m256 _p1 = _mm256_loadu_ps(ptr1 + x * 8);
                            __m256 _outp = op(_p, _p1);
                            _mm256_storeu_ps(outptr, _outp);

                            ptr += 8;
                            outptr += 8;
                        }
                    }
                }

                return 0;
            }

            if (w1 != 1 && w == 1 && h1 == h && channels1 == channels)
            {
                // special type 7: a is (1,h,channels) - mirror of special type 5
                c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator);
                if (c.empty())
                    return -100;

                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels1; q++)
                {
                    const float* ptr = a.channel(q);
                    const float* ptr1 = b.channel(q);
                    float* outptr = c.channel(q);

                    for (int y = 0; y < h1; y++)
                    {
                        __m256 _p = _mm256_loadu_ps(ptr + y * 8);
                        for (int x = 0; x < w1; x++)
                        {
                            __m256 _p1 = _mm256_loadu_ps(ptr1);
                            __m256 _outp = op(_p, _p1);
                            _mm256_storeu_ps(outptr, _outp);

                            ptr1 += 8;
                            outptr += 8;
                        }
                    }
                }

                return 0;
            }

            if (w1 == w && h1 != 1 && h == 1 && channels1 == channels)
            {
                // special type 8: a is (w,1,channels) - mirror of special type 6
                c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator);
                if (c.empty())
                    return -100;

                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels1; q++)
                {
                    const float* ptr = a.channel(q);
                    const float* ptr1 = b.channel(q);
                    float* outptr = c.channel(q);

                    for (int y = 0; y < h1; y++)
                    {
                        for (int x = 0; x < w1; x++)
                        {
                            __m256 _p = _mm256_loadu_ps(ptr + x * 8);
                            __m256 _p1 = _mm256_loadu_ps(ptr1);
                            __m256 _outp = op(_p, _p1);
                            _mm256_storeu_ps(outptr, _outp);

                            ptr1 += 8;
                            outptr += 8;
                        }
                    }
                }

                return 0;
            }

            // type 19: identical 3D shapes - plain elementwise operation
            c.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
            if (c.empty())
                return -100;

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const float* ptr = a.channel(q);
                const float* ptr1 = b.channel(q);
                float* outptr = c.channel(q);

                for (int i = 0; i < size; i++)
                {
                    __m256 _p = _mm256_loadu_ps(ptr);
                    __m256 _p1 = _mm256_loadu_ps(ptr1);
                    __m256 _outp = op(_p, _p1);
                    _mm256_storeu_ps(outptr, _outp);
                    ptr += 8;
                    ptr1 += 8;
                    outptr += 8;
                }
            }

            return 0;
        }

        // b.dims is 2 or 1 - output takes a's shape
        c.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
        if (c.empty())
            return -100;

        if (b.dims == 2)
        {
            // type 18: b.row(q) holds one packed value per row of channel q,
            // broadcast along the width
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const float* ptr = a.channel(q);
                const float* ptr1 = b.row(q);
                float* outptr = c.channel(q);

                for (int y = 0; y < h; y++)
                {
                    __m256 _b0 = _mm256_loadu_ps(ptr1);
                    for (int x = 0; x < w; x++)
                    {
                        __m256 _p = _mm256_loadu_ps(ptr);
                        __m256 _outp = op(_p, _b0);
                        _mm256_storeu_ps(outptr, _outp);
                        ptr += 8;
                        outptr += 8;
                    }

                    ptr1 += 8;
                }
            }

            return 0;
        }

        if (b.dims == 1)
        {
            if (b.w == 1 && elempack1 == 1)
            {
                // type 16: b is a single scalar, broadcast everywhere
                __m256 _b0 = _mm256_set1_ps(b[0]);
                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels; q++)
                {
                    const float* ptr = a.channel(q);
                    float* outptr = c.channel(q);

                    for (int i = 0; i < size; i++)
                    {
                        __m256 _p = _mm256_loadu_ps(ptr);
                        __m256 _outp = op(_p, _b0);
                        _mm256_storeu_ps(outptr, _outp);
                        ptr += 8;
                        outptr += 8;
                    }
                }

                return 0;
            }

            // type 17: b holds one packed value per channel of a
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const float* ptr = a.channel(q);
                __m256 _b0 = _mm256_loadu_ps((const float*)b + q * 8);
                float* outptr = c.channel(q);

                for (int i = 0; i < size; i++)
                {
                    __m256 _p = _mm256_loadu_ps(ptr);
                    __m256 _outp = op(_p, _b0);
                    _mm256_storeu_ps(outptr, _outp);
                    ptr += 8;
                    outptr += 8;
                }
            }

            return 0;
        }
    }
    else if (a.dims == 2)
    {
        if (b.dims == 3)
        {
            // type 14: a.row(q) holds one packed value per row of output
            // channel q, broadcast along the width
            c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator);
            if (c.empty())
                return -100;

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels1; q++)
            {
                const float* ptr = a.row(q);
                const float* ptr1 = b.channel(q);
                float* outptr = c.channel(q);

                for (int y = 0; y < h1; y++)
                {
                    __m256 _a0 = _mm256_loadu_ps(ptr);
                    for (int x = 0; x < w1; x++)
                    {
                        __m256 _p1 = _mm256_loadu_ps(ptr1);
                        __m256 _outp = op(_a0, _p1);
                        _mm256_storeu_ps(outptr, _outp);
                        ptr1 += 8;
                        outptr += 8;
                    }

                    ptr += 8;
                }
            }

            return 0;
        }

        c.create(w, h, elemsize, elempack, opt.blob_allocator);
        if (c.empty())
            return -100;

        if (b.dims == 2)
        {
            // type 13: both 2D with identical shapes - plain elementwise operation
            const float* ptr = a;
            const float* ptr1 = b;
            float* outptr = c;
            for (int i = 0; i < size; i++)
            {
                __m256 _p = _mm256_loadu_ps(ptr);
                __m256 _p1 = _mm256_loadu_ps(ptr1);
                __m256 _outp = op(_p, _p1);
                _mm256_storeu_ps(outptr, _outp);
                ptr += 8;
                ptr1 += 8;
                outptr += 8;
            }

            return 0;
        }

        if (b.dims == 1)
        {
            // NOTE(review): c was already created just above with the same
            // shape; this second create is redundant but harmless
            c.create(w, h, elemsize, elempack, opt.blob_allocator);
            if (c.empty())
                return -100;

            if (b.w == 1 && elempack1 == 1)
            {
                // type 11: b is a single scalar, broadcast everywhere
                __m256 _b0 = _mm256_set1_ps(b[0]);
                const float* ptr = a;
                float* outptr = c;
                for (int i = 0; i < size; i++)
                {
                    __m256 _p = _mm256_loadu_ps(ptr);
                    __m256 _outp = op(_p, _b0);
                    _mm256_storeu_ps(outptr, _outp);
                    ptr += 8;
                    outptr += 8;
                }

                return 0;
            }

            // type 12: b holds one packed value per row, broadcast along the width
            const float* ptr = a;
            const float* ptr1 = b;
            float* outptr = c;

            for (int y = 0; y < h; y++)
            {
                __m256 _b0 = _mm256_loadu_ps(ptr1);
                for (int x = 0; x < w; x++)
                {
                    __m256 _p = _mm256_loadu_ps(ptr);
                    __m256 _outp = op(_p, _b0);
                    _mm256_storeu_ps(outptr, _outp);
                    ptr += 8;
                    outptr += 8;
                }

                ptr1 += 8;
            }

            return 0;
        }
    }
    else if (a.dims == 1)
    {
        if (a.w == 1 && elempack == 1)
        {
            // a is a single scalar, broadcast against every element of b
            if (b.dims == 3)
            {
                // type 4
                c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator);
                if (c.empty())
                    return -100;

                __m256 _a0 = _mm256_set1_ps(a[0]);
                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels1; q++)
                {
                    const float* ptr1 = b.channel(q);
                    float* outptr = c.channel(q);

                    for (int i = 0; i < size1; i++)
                    {
                        __m256 _p1 = _mm256_loadu_ps(ptr1);
                        __m256 _outp = op(_a0, _p1);
                        _mm256_storeu_ps(outptr, _outp);
                        ptr1 += 8;
                        outptr += 8;
                    }
                }

                return 0;
            }

            if (b.dims == 2)
            {
                // type 3
                c.create(w1, h1, elemsize1, elempack1, opt.blob_allocator);
                if (c.empty())
                    return -100;

                __m256 _a0 = _mm256_set1_ps(a[0]);
                const float* ptr1 = b;
                float* outptr = c;
                for (int i = 0; i < size1; i++)
                {
                    __m256 _p1 = _mm256_loadu_ps(ptr1);
                    __m256 _outp = op(_a0, _p1);
                    _mm256_storeu_ps(outptr, _outp);
                    ptr1 += 8;
                    outptr += 8;
                }

                return 0;
            }

            if (b.dims == 1)
            {
                // type 2
                c.create(w1, elemsize1, elempack1, opt.blob_allocator);
                if (c.empty())
                    return -100;

                __m256 _a0 = _mm256_set1_ps(a[0]);
                const float* ptr1 = b;
                float* outptr = c;
                for (int i = 0; i < w1; i++)
                {
                    __m256 _p1 = _mm256_loadu_ps(ptr1);
                    __m256 _outp = op(_a0, _p1);
                    _mm256_storeu_ps(outptr, _outp);
                    ptr1 += 8;
                    outptr += 8;
                }

                return 0;
            }
        }

        if (b.dims == 3)
        {
            // type 9: a holds one packed value per channel of b
            c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator);
            if (c.empty())
                return -100;

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels1; q++)
            {
                __m256 _a0 = _mm256_loadu_ps((const float*)a + q * 8);
                const float* ptr1 = b.channel(q);
                float* outptr = c.channel(q);

                for (int i = 0; i < size1; i++)
                {
                    __m256 _p1 = _mm256_loadu_ps(ptr1);
                    __m256 _outp = op(_a0, _p1);
                    _mm256_storeu_ps(outptr, _outp);
                    ptr1 += 8;
                    outptr += 8;
                }
            }

            return 0;
        }

        if (b.dims == 2)
        {
            // type 8: a holds one packed value per row of b, broadcast along the width
            c.create(w1, h1, elemsize1, elempack1, opt.blob_allocator);
            if (c.empty())
                return -100;

            const float* ptr = a;
            const float* ptr1 = b;
            float* outptr = c;

            for (int y = 0; y < h1; y++)
            {
                __m256 _a0 = _mm256_loadu_ps(ptr);
                for (int x = 0; x < w1; x++)
                {
                    __m256 _p1 = _mm256_loadu_ps(ptr1);
                    __m256 _outp = op(_a0, _p1);
                    _mm256_storeu_ps(outptr, _outp);
                    ptr1 += 8;
                    outptr += 8;
                }

                ptr += 8;
            }

            return 0;
        }

        if (b.dims == 1)
        {
            c.create(w, elemsize, elempack, opt.blob_allocator);
            if (c.empty())
                return -100;

            if (b.w == 1 && elempack1 == 1)
            {
                // type 6: b is a single scalar
                __m256 _b0 = _mm256_set1_ps(b[0]);
                const float* ptr = a;
                float* outptr = c;
                for (int i = 0; i < w; i++)
                {
                    __m256 _p = _mm256_loadu_ps(ptr);
                    __m256 _outp = op(_p, _b0);
                    _mm256_storeu_ps(outptr, _outp);
                    ptr += 8;
                    outptr += 8;
                }

                return 0;
            }

            // type 7: both 1D - plain elementwise operation
            const float* ptr = a;
            const float* ptr1 = b;
            float* outptr = c;
            for (int i = 0; i < w; i++)
            {
                __m256 _p = _mm256_loadu_ps(ptr);
                __m256 _p1 = _mm256_loadu_ps(ptr1);
                __m256 _outp = op(_p, _p1);
                _mm256_storeu_ps(outptr, _outp);
                ptr += 8;
                ptr1 += 8;
                outptr += 8;
            }
        }
    }

    return 0;
}
688
689 template<typename Op>
binary_op_scalar_inplace_pack8(Mat & a,float b,const Option & opt)690 static int binary_op_scalar_inplace_pack8(Mat& a, float b, const Option& opt)
691 {
692 Op op;
693
694 int w = a.w;
695 int h = a.h;
696 int channels = a.c;
697 int size = w * h;
698
699 __m256 _b = _mm256_set1_ps(b);
700
701 #pragma omp parallel for num_threads(opt.num_threads)
702 for (int q = 0; q < channels; q++)
703 {
704 float* ptr = a.channel(q);
705
706 for (int i = 0; i < size; i++)
707 {
708 __m256 _p = _mm256_loadu_ps(ptr);
709 _p = op(_p, _b);
710 _mm256_storeu_ps(ptr, _p);
711 ptr += 8;
712 }
713 }
714
715 return 0;
716 }
717
// Functor: lane-wise addition of two 8-float AVX vectors.
struct binary_op_add_pack8
{
    __m256 operator()(const __m256& x, const __m256& y) const
    {
        const __m256 sum = _mm256_add_ps(x, y);
        return sum;
    }
};
725
// Functor: lane-wise subtraction x - y of two 8-float AVX vectors.
struct binary_op_sub_pack8
{
    __m256 operator()(const __m256& x, const __m256& y) const
    {
        const __m256 diff = _mm256_sub_ps(x, y);
        return diff;
    }
};
733
// Functor: lane-wise multiplication of two 8-float AVX vectors.
struct binary_op_mul_pack8
{
    __m256 operator()(const __m256& x, const __m256& y) const
    {
        const __m256 prod = _mm256_mul_ps(x, y);
        return prod;
    }
};
741
// Functor: lane-wise division x / y of two 8-float AVX vectors.
struct binary_op_div_pack8
{
    __m256 operator()(const __m256& x, const __m256& y) const
    {
        const __m256 quot = _mm256_div_ps(x, y);
        return quot;
    }
};
749
// Functor: lane-wise maximum of two 8-float AVX vectors.
struct binary_op_max_pack8
{
    __m256 operator()(const __m256& x, const __m256& y) const
    {
        const __m256 hi = _mm256_max_ps(x, y);
        return hi;
    }
};
757
// Functor: lane-wise minimum of two 8-float AVX vectors.
struct binary_op_min_pack8
{
    __m256 operator()(const __m256& x, const __m256& y) const
    {
        const __m256 lo = _mm256_min_ps(x, y);
        return lo;
    }
};
765
766 struct binary_op_pow_pack8
767 {
operator ()ncnn::binary_op_pow_pack8768 __m256 operator()(const __m256& x, const __m256& y) const
769 {
770 return exp256_ps(_mm256_mul_ps(y, log256_ps(x)));
771 }
772 };
773
// Functor: reversed subtraction y - x of two 8-float AVX vectors.
struct binary_op_rsub_pack8
{
    __m256 operator()(const __m256& x, const __m256& y) const
    {
        const __m256 diff = _mm256_sub_ps(y, x);
        return diff;
    }
};
781
// Functor: reversed division y / x of two 8-float AVX vectors.
struct binary_op_rdiv_pack8
{
    __m256 operator()(const __m256& x, const __m256& y) const
    {
        const __m256 quot = _mm256_div_ps(y, x);
        return quot;
    }
};
789 #endif // __AVX__
790
// Apply the elementwise binary functor Op to blobs a and b (stored as 4-lane
// SSE-packed floats, elempack == 4 except where a branch tests elempack == 1)
// and write the result into c, following the ncnn broadcasting rules:
// https://github.com/Tencent/ncnn/wiki/binaryop-broadcasting
// SSE twin of binary_op_pack8. Returns 0 on success, -100 on allocation failure.
template<typename Op>
static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt)
{
    Op op;

    // shape of a
    int w = a.w;
    int h = a.h;
    int channels = a.c;
    int size = w * h;
    size_t elemsize = a.elemsize;
    int elempack = a.elempack;

    // shape of b
    int w1 = b.w;
    int h1 = b.h;
    int channels1 = b.c;
    int size1 = w1 * h1;
    size_t elemsize1 = b.elemsize;
    int elempack1 = b.elempack;

    if (a.dims == 3)
    {
        if (b.dims == 3)
        {
            if (w1 == 1 && h1 == 1 && channels1 == channels)
            {
                // special type 1: b is (1,1,channels) - one packed value per
                // channel, broadcast over every pixel of the matching channel of a
                c.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
                if (c.empty())
                    return -100;

                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels; q++)
                {
                    const float* ptr = a.channel(q);
                    float* outptr = c.channel(q);
                    const float* b0 = b.channel(q);
                    __m128 _b0 = _mm_loadu_ps(b0);
                    for (int i = 0; i < size; i++)
                    {
                        __m128 _p = _mm_loadu_ps(ptr);
                        __m128 _outp = op(_p, _b0);
                        _mm_storeu_ps(outptr, _outp);
                        ptr += 4;
                        outptr += 4;
                    }
                }

                return 0;
            }

            if (w1 == w && h1 == h && channels1 == 1 && elempack1 == 1)
            {
                // special type 2: b is (w,h,1) with elempack 1 - each scalar of
                // b is splatted across the 4 lanes of the matching pixel of a
                c.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
                if (c.empty())
                    return -100;

                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels; q++)
                {
                    const float* ptr = a.channel(q);
                    const float* ptr1 = b;
                    float* outptr = c.channel(q);
                    for (int i = 0; i < size; i++)
                    {
                        __m128 _p = _mm_loadu_ps(ptr);
                        __m128 _p1 = _mm_set1_ps(*ptr1);
                        __m128 _outp = op(_p, _p1);
                        _mm_storeu_ps(outptr, _outp);
                        ptr += 4;
                        ptr1 += 1;
                        outptr += 4;
                    }
                }

                return 0;
            }

            if (w == 1 && h == 1 && channels1 == channels)
            {
                // special type 3: a is (1,1,channels) - mirror of special type 1
                c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator);
                if (c.empty())
                    return -100;

                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels1; q++)
                {
                    const float* a0 = a.channel(q);
                    float* outptr = c.channel(q);
                    const float* ptr1 = b.channel(q);
                    __m128 _a0 = _mm_loadu_ps(a0);
                    for (int i = 0; i < size1; i++)
                    {
                        __m128 _p1 = _mm_loadu_ps(ptr1);
                        __m128 _outp = op(_a0, _p1);
                        _mm_storeu_ps(outptr, _outp);
                        ptr1 += 4;
                        outptr += 4;
                    }
                }

                return 0;
            }

            if (w1 == w && h1 == h && channels == 1 && elempack == 1)
            {
                // special type 4: a is (w,h,1) with elempack 1 - mirror of special type 2
                c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator);
                if (c.empty())
                    return -100;

                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels1; q++)
                {
                    const float* ptr = a;
                    const float* ptr1 = b.channel(q);
                    float* outptr = c.channel(q);
                    for (int i = 0; i < size1; i++)
                    {
                        __m128 _p = _mm_set1_ps(*ptr);
                        __m128 _p1 = _mm_loadu_ps(ptr1);
                        __m128 _outp = op(_p, _p1);
                        _mm_storeu_ps(outptr, _outp);
                        ptr += 1;
                        ptr1 += 4;
                        outptr += 4;
                    }
                }

                return 0;
            }

            if (w != 1 && w1 == 1 && h1 == h && channels1 == channels)
            {
                // special type 5: b is (1,h,channels) - one packed value per
                // row of b, broadcast along the width of a
                c.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
                if (c.empty())
                    return -100;

                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels1; q++)
                {
                    const float* ptr = a.channel(q);
                    const float* ptr1 = b.channel(q);
                    float* outptr = c.channel(q);

                    for (int y = 0; y < h; y++)
                    {
                        __m128 _p1 = _mm_loadu_ps(ptr1 + y * 4);
                        for (int x = 0; x < w; x++)
                        {
                            __m128 _p = _mm_loadu_ps(ptr);
                            __m128 _outp = op(_p, _p1);
                            _mm_storeu_ps(outptr, _outp);

                            ptr += 4;
                            outptr += 4;
                        }
                    }
                }

                return 0;
            }

            if (w1 == w && h != 1 && h1 == 1 && channels1 == channels)
            {
                // special type 6: b is (w,1,channels) - one packed value per
                // column of b, broadcast along the height of a
                c.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
                if (c.empty())
                    return -100;

                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels1; q++)
                {
                    const float* ptr = a.channel(q);
                    const float* ptr1 = b.channel(q);
                    float* outptr = c.channel(q);

                    for (int y = 0; y < h; y++)
                    {
                        for (int x = 0; x < w; x++)
                        {
                            __m128 _p = _mm_loadu_ps(ptr);
                            __m128 _p1 = _mm_loadu_ps(ptr1 + x * 4);
                            __m128 _outp = op(_p, _p1);
                            _mm_storeu_ps(outptr, _outp);

                            ptr += 4;
                            outptr += 4;
                        }
                    }
                }

                return 0;
            }

            if (w1 != 1 && w == 1 && h1 == h && channels1 == channels)
            {
                // special type 7: a is (1,h,channels) - mirror of special type 5
                c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator);
                if (c.empty())
                    return -100;

                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels1; q++)
                {
                    const float* ptr = a.channel(q);
                    const float* ptr1 = b.channel(q);
                    float* outptr = c.channel(q);

                    for (int y = 0; y < h1; y++)
                    {
                        __m128 _p = _mm_loadu_ps(ptr + y * 4);
                        for (int x = 0; x < w1; x++)
                        {
                            __m128 _p1 = _mm_loadu_ps(ptr1);
                            __m128 _outp = op(_p, _p1);
                            _mm_storeu_ps(outptr, _outp);

                            ptr1 += 4;
                            outptr += 4;
                        }
                    }
                }

                return 0;
            }

            if (w1 == w && h1 != 1 && h == 1 && channels1 == channels)
            {
                // special type 8: a is (w,1,channels) - mirror of special type 6
                c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator);
                if (c.empty())
                    return -100;

                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels1; q++)
                {
                    const float* ptr = a.channel(q);
                    const float* ptr1 = b.channel(q);
                    float* outptr = c.channel(q);

                    for (int y = 0; y < h1; y++)
                    {
                        for (int x = 0; x < w1; x++)
                        {
                            __m128 _p = _mm_loadu_ps(ptr + x * 4);
                            __m128 _p1 = _mm_loadu_ps(ptr1);
                            __m128 _outp = op(_p, _p1);
                            _mm_storeu_ps(outptr, _outp);

                            ptr1 += 4;
                            outptr += 4;
                        }
                    }
                }

                return 0;
            }

            // type 19: identical 3D shapes - plain elementwise operation
            c.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
            if (c.empty())
                return -100;

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const float* ptr = a.channel(q);
                const float* ptr1 = b.channel(q);
                float* outptr = c.channel(q);

                for (int i = 0; i < size; i++)
                {
                    __m128 _p = _mm_loadu_ps(ptr);
                    __m128 _p1 = _mm_loadu_ps(ptr1);
                    __m128 _outp = op(_p, _p1);
                    _mm_storeu_ps(outptr, _outp);
                    ptr += 4;
                    ptr1 += 4;
                    outptr += 4;
                }
            }

            return 0;
        }

        // b.dims is 2 or 1 - output takes a's shape
        c.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
        if (c.empty())
            return -100;

        if (b.dims == 2)
        {
            // type 18: b.row(q) holds one packed value per row of channel q,
            // broadcast along the width
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const float* ptr = a.channel(q);
                const float* ptr1 = b.row<const float>(q);
                float* outptr = c.channel(q);

                for (int y = 0; y < h; y++)
                {
                    __m128 _b0 = _mm_loadu_ps(ptr1);
                    for (int x = 0; x < w; x++)
                    {
                        __m128 _p = _mm_loadu_ps(ptr);
                        __m128 _outp = op(_p, _b0);
                        _mm_storeu_ps(outptr, _outp);
                        ptr += 4;
                        outptr += 4;
                    }

                    ptr1 += 4;
                }
            }

            return 0;
        }

        if (b.dims == 1)
        {
            if (b.w == 1 && elempack1 == 1)
            {
                // type 16: b is a single scalar, broadcast everywhere
                __m128 _b0 = _mm_set1_ps(((const float*)b)[0]);
                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels; q++)
                {
                    const float* ptr = a.channel(q);
                    float* outptr = c.channel(q);

                    for (int i = 0; i < size; i++)
                    {
                        __m128 _p = _mm_loadu_ps(ptr);
                        __m128 _outp = op(_p, _b0);
                        _mm_storeu_ps(outptr, _outp);
                        ptr += 4;
                        outptr += 4;
                    }
                }

                return 0;
            }

            // type 17: b holds one packed value per channel of a
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const float* ptr = a.channel(q);
                __m128 _b0 = _mm_loadu_ps((const float*)b + q * 4);
                float* outptr = c.channel(q);

                for (int i = 0; i < size; i++)
                {
                    __m128 _p = _mm_loadu_ps(ptr);
                    __m128 _outp = op(_p, _b0);
                    _mm_storeu_ps(outptr, _outp);
                    ptr += 4;
                    outptr += 4;
                }
            }

            return 0;
        }
    }
    else if (a.dims == 2)
    {
        if (b.dims == 3)
        {
            // type 14: a.row(q) holds one packed value per row of output
            // channel q, broadcast along the width
            c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator);
            if (c.empty())
                return -100;

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels1; q++)
            {
                const float* ptr = a.row<const float>(q);
                const float* ptr1 = b.channel(q);
                float* outptr = c.channel(q);

                for (int y = 0; y < h1; y++)
                {
                    __m128 _a0 = _mm_loadu_ps(ptr);
                    for (int x = 0; x < w1; x++)
                    {
                        __m128 _p1 = _mm_loadu_ps(ptr1);
                        __m128 _outp = op(_a0, _p1);
                        _mm_storeu_ps(outptr, _outp);
                        ptr1 += 4;
                        outptr += 4;
                    }

                    ptr += 4;
                }
            }

            return 0;
        }

        c.create(w, h, elemsize, elempack, opt.blob_allocator);
        if (c.empty())
            return -100;

        if (b.dims == 2)
        {
            // type 13: both 2D with identical shapes - plain elementwise operation
            const float* ptr = a;
            const float* ptr1 = b;
            float* outptr = c;
            for (int i = 0; i < size; i++)
            {
                __m128 _p = _mm_loadu_ps(ptr);
                __m128 _p1 = _mm_loadu_ps(ptr1);
                __m128 _outp = op(_p, _p1);
                _mm_storeu_ps(outptr, _outp);
                ptr += 4;
                ptr1 += 4;
                outptr += 4;
            }

            return 0;
        }

        if (b.dims == 1)
        {
            // NOTE(review): c was already created just above with the same
            // shape; this second create is redundant but harmless
            c.create(w, h, elemsize, elempack, opt.blob_allocator);
            if (c.empty())
                return -100;

            if (b.w == 1 && elempack1 == 1)
            {
                // type 11: b is a single scalar, broadcast everywhere
                __m128 _b0 = _mm_set1_ps(((const float*)b)[0]);
                const float* ptr = a;
                float* outptr = c;
                for (int i = 0; i < size; i++)
                {
                    __m128 _p = _mm_loadu_ps(ptr);
                    __m128 _outp = op(_p, _b0);
                    _mm_storeu_ps(outptr, _outp);
                    ptr += 4;
                    outptr += 4;
                }

                return 0;
            }

            // type 12: b holds one packed value per row, broadcast along the width
            const float* ptr = a;
            const float* ptr1 = b;
            float* outptr = c;

            for (int y = 0; y < h; y++)
            {
                __m128 _b0 = _mm_loadu_ps(ptr1);
                for (int x = 0; x < w; x++)
                {
                    __m128 _p = _mm_loadu_ps(ptr);
                    __m128 _outp = op(_p, _b0);
                    _mm_storeu_ps(outptr, _outp);
                    ptr += 4;
                    outptr += 4;
                }

                ptr1 += 4;
            }

            return 0;
        }
    }
    else if (a.dims == 1)
    {
        if (a.w == 1 && elempack == 1)
        {
            // a is a single scalar, broadcast against every element of b
            if (b.dims == 3)
            {
                // type 4
                c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator);
                if (c.empty())
                    return -100;

                __m128 _a0 = _mm_set1_ps(((const float*)a)[0]);
                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels1; q++)
                {
                    const float* ptr1 = b.channel(q);
                    float* outptr = c.channel(q);

                    for (int i = 0; i < size1; i++)
                    {
                        __m128 _p1 = _mm_loadu_ps(ptr1);
                        __m128 _outp = op(_a0, _p1);
                        _mm_storeu_ps(outptr, _outp);
                        ptr1 += 4;
                        outptr += 4;
                    }
                }

                return 0;
            }

            if (b.dims == 2)
            {
                // type 3
                c.create(w1, h1, elemsize1, elempack1, opt.blob_allocator);
                if (c.empty())
                    return -100;

                __m128 _a0 = _mm_set1_ps(((const float*)a)[0]);
                const float* ptr1 = b;
                float* outptr = c;
                for (int i = 0; i < size1; i++)
                {
                    __m128 _p1 = _mm_loadu_ps(ptr1);
                    __m128 _outp = op(_a0, _p1);
                    _mm_storeu_ps(outptr, _outp);
                    ptr1 += 4;
                    outptr += 4;
                }

                return 0;
            }

            if (b.dims == 1)
            {
                // type 2
                c.create(w1, elemsize1, elempack1, opt.blob_allocator);
                if (c.empty())
                    return -100;

                __m128 _a0 = _mm_set1_ps(((const float*)a)[0]);
                const float* ptr1 = b;
                float* outptr = c;
                for (int i = 0; i < w1; i++)
                {
                    __m128 _p1 = _mm_loadu_ps(ptr1);
                    __m128 _outp = op(_a0, _p1);
                    _mm_storeu_ps(outptr, _outp);
                    ptr1 += 4;
                    outptr += 4;
                }

                return 0;
            }
        }

        if (b.dims == 3)
        {
            // type 9: a holds one packed value per channel of b
            c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator);
            if (c.empty())
                return -100;

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels1; q++)
            {
                __m128 _a0 = _mm_loadu_ps((const float*)a + q * 4);
                const float* ptr1 = b.channel(q);
                float* outptr = c.channel(q);

                for (int i = 0; i < size1; i++)
                {
                    __m128 _p1 = _mm_loadu_ps(ptr1);
                    __m128 _outp = op(_a0, _p1);
                    _mm_storeu_ps(outptr, _outp);
                    ptr1 += 4;
                    outptr += 4;
                }
            }

            return 0;
        }

        if (b.dims == 2)
        {
            // type 8: a holds one packed value per row of b, broadcast along the width
            c.create(w1, h1, elemsize1, elempack1, opt.blob_allocator);
            if (c.empty())
                return -100;

            const float* ptr = a;
            const float* ptr1 = b;
            float* outptr = c;

            for (int y = 0; y < h1; y++)
            {
                __m128 _a0 = _mm_loadu_ps(ptr);
                for (int x = 0; x < w1; x++)
                {
                    __m128 _p1 = _mm_loadu_ps(ptr1);
                    __m128 _outp = op(_a0, _p1);
                    _mm_storeu_ps(outptr, _outp);
                    ptr1 += 4;
                    outptr += 4;
                }

                ptr += 4;
            }

            return 0;
        }

        if (b.dims == 1)
        {
            c.create(w, elemsize, elempack, opt.blob_allocator);
            if (c.empty())
                return -100;

            if (b.w == 1 && elempack1 == 1)
            {
                // type 6: b is a single scalar
                __m128 _b0 = _mm_set1_ps(((const float*)b)[0]);
                const float* ptr = a;
                float* outptr = c;
                for (int i = 0; i < w; i++)
                {
                    __m128 _p = _mm_loadu_ps(ptr);
                    __m128 _outp = op(_p, _b0);
                    _mm_storeu_ps(outptr, _outp);
                    ptr += 4;
                    outptr += 4;
                }

                return 0;
            }

            // type 7: both 1D - plain elementwise operation
            const float* ptr = a;
            const float* ptr1 = b;
            float* outptr = c;
            for (int i = 0; i < w; i++)
            {
                __m128 _p = _mm_loadu_ps(ptr);
                __m128 _p1 = _mm_loadu_ps(ptr1);
                __m128 _outp = op(_p, _p1);
                _mm_storeu_ps(outptr, _outp);
                ptr += 4;
                ptr1 += 4;
                outptr += 4;
            }
        }
    }

    return 0;
}
1439
1440 template<typename Op>
binary_op_scalar_inplace_pack4(Mat & a,float b,const Option & opt)1441 static int binary_op_scalar_inplace_pack4(Mat& a, float b, const Option& opt)
1442 {
1443 Op op;
1444
1445 int w = a.w;
1446 int h = a.h;
1447 int channels = a.c;
1448 int size = w * h;
1449
1450 __m128 _b = _mm_set1_ps((float)b);
1451
1452 #pragma omp parallel for num_threads(opt.num_threads)
1453 for (int q = 0; q < channels; q++)
1454 {
1455 float* ptr = a.channel(q);
1456
1457 for (int i = 0; i < size; i++)
1458 {
1459 __m128 _p = _mm_loadu_ps(ptr);
1460 _p = op(_p, _b);
1461 _mm_storeu_ps(ptr, _p);
1462 ptr += 4;
1463 }
1464 }
1465
1466 return 0;
1467 }
1468
// Lane-wise x + y on a pack4 float register.
struct binary_op_add_pack4
{
    __m128 operator()(const __m128& x, const __m128& y) const
    {
        const __m128 _sum = _mm_add_ps(x, y);
        return _sum;
    }
};
1476
// Lane-wise x - y on a pack4 float register.
struct binary_op_sub_pack4
{
    __m128 operator()(const __m128& x, const __m128& y) const
    {
        const __m128 _diff = _mm_sub_ps(x, y);
        return _diff;
    }
};
1484
// Lane-wise x * y on a pack4 float register.
struct binary_op_mul_pack4
{
    __m128 operator()(const __m128& x, const __m128& y) const
    {
        const __m128 _prod = _mm_mul_ps(x, y);
        return _prod;
    }
};
1492
// Lane-wise x / y on a pack4 float register.
struct binary_op_div_pack4
{
    __m128 operator()(const __m128& x, const __m128& y) const
    {
        const __m128 _quot = _mm_div_ps(x, y);
        return _quot;
    }
};
1500
// Lane-wise max(x, y) on a pack4 float register.
struct binary_op_max_pack4
{
    __m128 operator()(const __m128& x, const __m128& y) const
    {
        const __m128 _hi = _mm_max_ps(x, y);
        return _hi;
    }
};
1508
// Lane-wise min(x, y) on a pack4 float register.
struct binary_op_min_pack4
{
    __m128 operator()(const __m128& x, const __m128& y) const
    {
        const __m128 _lo = _mm_min_ps(x, y);
        return _lo;
    }
};
1516
1517 struct binary_op_pow_pack4
1518 {
operator ()ncnn::binary_op_pow_pack41519 __m128 operator()(const __m128& x, const __m128& y) const
1520 {
1521 return exp_ps(_mm_mul_ps(y, log_ps(x)));
1522 }
1523 };
1524
// Reversed subtraction: lane-wise y - x on a pack4 float register.
struct binary_op_rsub_pack4
{
    __m128 operator()(const __m128& x, const __m128& y) const
    {
        const __m128 _diff = _mm_sub_ps(y, x);
        return _diff;
    }
};
1532
// Reversed division: lane-wise y / x on a pack4 float register.
struct binary_op_rdiv_pack4
{
    __m128 operator()(const __m128& x, const __m128& y) const
    {
        const __m128 _quot = _mm_div_ps(y, x);
        return _quot;
    }
};
1540 #endif // __SSE2__
1541
forward(const std::vector<Mat> & bottom_blobs,std::vector<Mat> & top_blobs,const Option & opt) const1542 int BinaryOp_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
1543 {
1544 #if __SSE2__
1545 const Mat& bottom_blob = bottom_blobs[0];
1546 const Mat& bottom_blob1 = bottom_blobs[1];
1547 Mat& top_blob = top_blobs[0];
1548
1549 int elempack = bottom_blob.elempack;
1550 int elempack1 = bottom_blob1.elempack;
1551
1552 #if __AVX__
1553 if (elempack == 8 || elempack1 == 8)
1554 {
1555 if (op_type == Operation_ADD)
1556 return binary_op_pack8<binary_op_add_pack8>(bottom_blob, bottom_blob1, top_blob, opt);
1557
1558 if (op_type == Operation_SUB)
1559 return binary_op_pack8<binary_op_sub_pack8>(bottom_blob, bottom_blob1, top_blob, opt);
1560
1561 if (op_type == Operation_MUL)
1562 return binary_op_pack8<binary_op_mul_pack8>(bottom_blob, bottom_blob1, top_blob, opt);
1563
1564 if (op_type == Operation_DIV)
1565 return binary_op_pack8<binary_op_div_pack8>(bottom_blob, bottom_blob1, top_blob, opt);
1566
1567 if (op_type == Operation_MAX)
1568 return binary_op_pack8<binary_op_max_pack8>(bottom_blob, bottom_blob1, top_blob, opt);
1569
1570 if (op_type == Operation_MIN)
1571 return binary_op_pack8<binary_op_min_pack8>(bottom_blob, bottom_blob1, top_blob, opt);
1572
1573 if (op_type == Operation_POW)
1574 return binary_op_pack8<binary_op_pow_pack8>(bottom_blob, bottom_blob1, top_blob, opt);
1575
1576 if (op_type == Operation_RSUB)
1577 return binary_op_pack8<binary_op_rsub_pack8>(bottom_blob, bottom_blob1, top_blob, opt);
1578
1579 if (op_type == Operation_RDIV)
1580 return binary_op_pack8<binary_op_rdiv_pack8>(bottom_blob, bottom_blob1, top_blob, opt);
1581 }
1582 #endif // __AVX__
1583
1584 if (elempack == 4 || elempack1 == 4)
1585 {
1586 if (op_type == Operation_ADD)
1587 return binary_op_pack4<binary_op_add_pack4>(bottom_blob, bottom_blob1, top_blob, opt);
1588
1589 if (op_type == Operation_SUB)
1590 return binary_op_pack4<binary_op_sub_pack4>(bottom_blob, bottom_blob1, top_blob, opt);
1591
1592 if (op_type == Operation_MUL)
1593 return binary_op_pack4<binary_op_mul_pack4>(bottom_blob, bottom_blob1, top_blob, opt);
1594
1595 if (op_type == Operation_DIV)
1596 return binary_op_pack4<binary_op_div_pack4>(bottom_blob, bottom_blob1, top_blob, opt);
1597
1598 if (op_type == Operation_MAX)
1599 return binary_op_pack4<binary_op_max_pack4>(bottom_blob, bottom_blob1, top_blob, opt);
1600
1601 if (op_type == Operation_MIN)
1602 return binary_op_pack4<binary_op_min_pack4>(bottom_blob, bottom_blob1, top_blob, opt);
1603
1604 if (op_type == Operation_POW)
1605 return binary_op_pack4<binary_op_pow_pack4>(bottom_blob, bottom_blob1, top_blob, opt);
1606
1607 if (op_type == Operation_RSUB)
1608 return binary_op_pack4<binary_op_rsub_pack4>(bottom_blob, bottom_blob1, top_blob, opt);
1609
1610 if (op_type == Operation_RDIV)
1611 return binary_op_pack4<binary_op_rdiv_pack4>(bottom_blob, bottom_blob1, top_blob, opt);
1612 }
1613 #endif // __SSE2__
1614
1615 return BinaryOp::forward(bottom_blobs, top_blobs, opt);
1616 }
1617
forward_inplace(Mat & bottom_top_blob,const Option & opt) const1618 int BinaryOp_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
1619 {
1620 #if __SSE2__
1621 int elempack = bottom_top_blob.elempack;
1622
1623 #if __AVX__
1624 if (elempack == 8)
1625 {
1626 if (op_type == Operation_ADD)
1627 return binary_op_scalar_inplace_pack8<binary_op_add_pack8>(bottom_top_blob, b, opt);
1628
1629 if (op_type == Operation_SUB)
1630 return binary_op_scalar_inplace_pack8<binary_op_sub_pack8>(bottom_top_blob, b, opt);
1631
1632 if (op_type == Operation_MUL)
1633 return binary_op_scalar_inplace_pack8<binary_op_mul_pack8>(bottom_top_blob, b, opt);
1634
1635 if (op_type == Operation_DIV)
1636 return binary_op_scalar_inplace_pack8<binary_op_div_pack8>(bottom_top_blob, b, opt);
1637
1638 if (op_type == Operation_MAX)
1639 return binary_op_scalar_inplace_pack8<binary_op_max_pack8>(bottom_top_blob, b, opt);
1640
1641 if (op_type == Operation_MIN)
1642 return binary_op_scalar_inplace_pack8<binary_op_min_pack8>(bottom_top_blob, b, opt);
1643
1644 if (op_type == Operation_POW)
1645 return binary_op_scalar_inplace_pack8<binary_op_pow_pack8>(bottom_top_blob, b, opt);
1646
1647 if (op_type == Operation_RSUB)
1648 return binary_op_scalar_inplace_pack8<binary_op_rsub_pack8>(bottom_top_blob, b, opt);
1649
1650 if (op_type == Operation_RDIV)
1651 return binary_op_scalar_inplace_pack8<binary_op_rdiv_pack8>(bottom_top_blob, b, opt);
1652 }
1653 #endif // __AVX__
1654
1655 if (elempack == 4)
1656 {
1657 if (op_type == Operation_ADD)
1658 return binary_op_scalar_inplace_pack4<binary_op_add_pack4>(bottom_top_blob, b, opt);
1659
1660 if (op_type == Operation_SUB)
1661 return binary_op_scalar_inplace_pack4<binary_op_sub_pack4>(bottom_top_blob, b, opt);
1662
1663 if (op_type == Operation_MUL)
1664 return binary_op_scalar_inplace_pack4<binary_op_mul_pack4>(bottom_top_blob, b, opt);
1665
1666 if (op_type == Operation_DIV)
1667 return binary_op_scalar_inplace_pack4<binary_op_div_pack4>(bottom_top_blob, b, opt);
1668
1669 if (op_type == Operation_MAX)
1670 return binary_op_scalar_inplace_pack4<binary_op_max_pack4>(bottom_top_blob, b, opt);
1671
1672 if (op_type == Operation_MIN)
1673 return binary_op_scalar_inplace_pack4<binary_op_min_pack4>(bottom_top_blob, b, opt);
1674
1675 if (op_type == Operation_POW)
1676 return binary_op_scalar_inplace_pack4<binary_op_pow_pack4>(bottom_top_blob, b, opt);
1677
1678 if (op_type == Operation_RSUB)
1679 return binary_op_scalar_inplace_pack4<binary_op_rsub_pack4>(bottom_top_blob, b, opt);
1680
1681 if (op_type == Operation_RDIV)
1682 return binary_op_scalar_inplace_pack4<binary_op_rdiv_pack4>(bottom_top_blob, b, opt);
1683 }
1684 #endif // __SSE2__
1685
1686 return BinaryOp::forward_inplace(bottom_top_blob, opt);
1687 }
1688
1689 } // namespace ncnn
1690