// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "reduction.h"

#include <algorithm> // std::max / std::min used by the reduction functors below
#include <float.h>
#include <limits.h>
#include <math.h>

namespace ncnn {

Reduction::Reduction()
{
    one_blob_only = true;
    support_inplace = false;
}

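// param ids read below: 0 = reduction operation, 1 = reduce_all switch,
// 2 = output coefficient, 3 = explicit axes to reduce, 4 = keepdims switch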
int Reduction::load_param(const ParamDict& pd)
{
    operation = pd.get(0, 0);
    reduce_all = pd.get(1, 1);
    coeff = pd.get(2, 1.f);
    axes = pd.get(3, Mat());
    keepdims = pd.get(4, 0);

    return 0;
}

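// Reduce blob a into b, dropping the reduced dimensions from the output shape.
// Op folds one element into an accumulator, Op2 folds the per-row / per-channel
// partial results produced by the parallel loops, and v0 is the initial value
// for both. The reduce_* flags select which axes collapse.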
template<typename Op, typename Op2>
static int reduction_op(const Mat& a, Mat& b, float v0, bool reduce_w, bool reduce_h, bool reduce_c, const Option& opt)
{
    Op op;
    Op2 op2;

    size_t elemsize = a.elemsize;
    int dims = a.dims;

    if (dims == 1)
    {
        int w = a.w;
        b.create(1, elemsize, opt.blob_allocator);
        const float* ptr = a;

        float sum = v0;
        for (int i = 0; i < w; i++)
        {
            sum = op(sum, ptr[i]);
        }
        b[0] = sum;

        return 0;
    }

    if (dims == 2)
    {
        int w = a.w;
        int h = a.h;

        if (reduce_w && reduce_h)
        {
            // w h -> X X
            b.create(1, elemsize, opt.blob_allocator);

            Mat sums(h, elemsize, opt.workspace_allocator);
            if (sums.empty())
                return -100;

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int i = 0; i < h; i++)
            {
                const float* ptr = a.row(i);

                float sum = v0;
                for (int j = 0; j < w; j++)
                {
                    sum = op(sum, ptr[j]);
                }
                sums[i] = sum;
            }

            float sum = v0;
            for (int i = 0; i < h; i++)
            {
                sum = op2(sum, sums[i]);
            }
            b[0] = sum;

            return 0;
        }

        if (reduce_w && !reduce_h)
        {
            // w h -> X h
            b.create(h, elemsize, opt.blob_allocator);

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int i = 0; i < h; i++)
            {
                const float* ptr = a.row(i);

                float sum = v0;
                for (int j = 0; j < w; j++)
                {
                    sum = op(sum, ptr[j]);
                }
                b[i] = sum;
            }
            return 0;
        }

        if (!reduce_w && reduce_h)
        {
            // w h -> w X
            b.create(w, elemsize, opt.blob_allocator);
            b.fill(v0);

            for (int i = 0; i < h; i++)
            {
                const float* ptr = a.row(i);
                for (int j = 0; j < w; j++)
                {
                    b[j] = op(b[j], ptr[j]);
                }
            }
            return 0;
        }
    }

    if (dims == 3)
    {
        int w = a.w;
        int h = a.h;
        int channels = a.c;
        int size = w * h;

        if (reduce_w && reduce_h && reduce_c)
        {
            // w h c -> X X X
            b.create(1, elemsize, opt.blob_allocator);
            Mat sums(channels, elemsize, opt.workspace_allocator);
            if (sums.empty())
                return -100;

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const float* ptr = a.channel(q);

                float sum = v0;
                for (int i = 0; i < size; i++)
                {
                    sum = op(sum, ptr[i]);
                }
                sums[q] = sum;
            }

            float sum = v0;
            for (int i = 0; i < channels; i++)
            {
                sum = op2(sum, sums[i]);
            }
            b[0] = sum;

            return 0;
        }

        if (reduce_w && reduce_h && !reduce_c)
        {
            // w h c -> X X c
            b.create(channels, elemsize, opt.blob_allocator);

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const float* ptr = a.channel(q);

                float sum = v0;
                for (int i = 0; i < size; i++)
                {
                    sum = op(sum, ptr[i]);
                }
                b[q] = sum;
            }

            return 0;
        }

        if (reduce_w && !reduce_h && !reduce_c)
        {
            // w h c -> X h c
            b.create(h, channels, elemsize, opt.blob_allocator);

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const float* ptr = a.channel(q);
                float* outptr = b.row(q);

                for (int i = 0; i < h; i++)
                {
                    float sum = v0;
                    for (int j = 0; j < w; j++)
                    {
                        sum = op(sum, ptr[j]);
                    }
                    outptr[i] = sum;
                    ptr += w;
                }
            }

            return 0;
        }

        if (reduce_w && !reduce_h && reduce_c)
        {
            // w h c -> X h X
            b.create(h, elemsize, opt.blob_allocator);
            Mat mins(1, h, channels, elemsize, opt.workspace_allocator);
            if (mins.empty())
                return -100;

            mins.fill(v0);

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const float* ptr = a.channel(q);
                float* mins_ptr = mins.channel(q);

                for (int i = 0; i < h; i++)
                {
                    float sum = v0;
                    for (int j = 0; j < w; j++)
                    {
                        sum = op(sum, ptr[j]);
                    }
                    mins_ptr[i] = sum;
                    ptr += w;
                }
            }

            b.fill(v0);

            for (int q = 0; q < channels; q++)
            {
                const float* mins_ptr = mins.channel(q);
                for (int i = 0; i < h; i++)
                {
                    b[i] = op2(b[i], mins_ptr[i]);
                }
            }

            return 0;
        }

        if (!reduce_w && reduce_h && reduce_c)
        {
            // w h c -> w X X
            b.create(w, elemsize, opt.blob_allocator);

            Mat mins(w, 1, channels, elemsize, opt.workspace_allocator);
            if (mins.empty())
                return -100;

            mins.fill(v0);

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const float* ptr = a.channel(q);
                float* mins_ptr = mins.channel(q);

                for (int i = 0; i < h; i++)
                {
                    for (int j = 0; j < w; j++)
                    {
                        mins_ptr[j] = op(mins_ptr[j], ptr[j]);
                    }
                    ptr += w;
                }
            }

            b.fill(v0);

            for (int q = 0; q < channels; q++)
            {
                const float* mins_ptr = mins.channel(q);
                for (int j = 0; j < w; j++)
                {
                    b[j] = op2(b[j], mins_ptr[j]);
                }
            }

            return 0;
        }

        if (!reduce_w && !reduce_h && reduce_c)
        {
            // w h c -> w h X
            b.create(w, h, elemsize, opt.blob_allocator);

            b.fill(v0);

            for (int q = 0; q < channels; q++)
            {
                const float* ptr = a.channel(q);

                for (int i = 0; i < size; i++)
                {
                    b[i] = op(b[i], ptr[i]);
                }
            }

            return 0;
        }

        if (!reduce_w && reduce_h && !reduce_c)
        {
            // w h c -> w X c
            b.create(w, channels, elemsize, opt.blob_allocator);

            b.fill(v0);

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const float* ptr = a.channel(q);
                float* outptr = b.row(q);

                for (int i = 0; i < h; i++)
                {
                    for (int j = 0; j < w; j++)
                    {
                        outptr[j] = op(outptr[j], ptr[j]);
                    }
                    ptr += w;
                }
            }
            return 0;
        }
    }

    return 0;
}

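// Same reductions as reduction_op above, but the output keeps every reduced
// dimension with extent 1 (e.g. w h c -> 1 h 1 instead of h).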
template<typename Op, typename Op2>
static int reduction_op_keepdims(const Mat& a, Mat& b, float v0, bool reduce_w, bool reduce_h, bool reduce_c, const Option& opt)
{
    Op op;
    Op2 op2;

    size_t elemsize = a.elemsize;
    int dims = a.dims;

    if (dims == 1)
    {
        int w = a.w;
        b.create(1, elemsize, opt.blob_allocator);
        const float* ptr = a;

        float sum = v0;
        for (int i = 0; i < w; i++)
        {
            sum = op(sum, ptr[i]);
        }
        b[0] = sum;

        return 0;
    }

    if (dims == 2)
    {
        int w = a.w;
        int h = a.h;

        if (reduce_w && reduce_h)
        {
            // w h -> 1 1
            b.create(1, 1, elemsize, opt.blob_allocator);

            Mat sums(h, elemsize, opt.workspace_allocator);
            if (sums.empty())
                return -100;

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int i = 0; i < h; i++)
            {
                const float* ptr = a.row(i);

                float sum = v0;
                for (int j = 0; j < w; j++)
                {
                    sum = op(sum, ptr[j]);
                }
                sums[i] = sum;
            }

            float sum = v0;
            for (int i = 0; i < h; i++)
            {
                sum = op2(sum, sums[i]);
            }
            b[0] = sum;

            return 0;
        }

        if (reduce_w && !reduce_h)
        {
            // w h -> 1 h
            b.create(1, h, elemsize, opt.blob_allocator);

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int i = 0; i < h; i++)
            {
                const float* ptr = a.row(i);

                float sum = v0;
                for (int j = 0; j < w; j++)
                {
                    sum = op(sum, ptr[j]);
                }
                b[i] = sum;
            }

            return 0;
        }

        if (!reduce_w && reduce_h)
        {
            // w h -> w 1
            b.create(w, 1, elemsize, opt.blob_allocator);
            b.fill(v0);

            for (int i = 0; i < h; i++)
            {
                const float* ptr = a.row(i);
                for (int j = 0; j < w; j++)
                {
                    b[j] = op(b[j], ptr[j]);
                }
            }

            return 0;
        }
    }

    if (dims == 3)
    {
        int w = a.w;
        int h = a.h;
        int channels = a.c;
        int size = w * h;

        if (reduce_w && reduce_h && reduce_c)
        {
            // w h c -> 1 1 1
            b.create(1, 1, 1, elemsize, opt.blob_allocator);
            Mat sums(channels, elemsize, opt.workspace_allocator);
            if (sums.empty())
                return -100;

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const float* ptr = a.channel(q);

                float sum = v0;
                for (int i = 0; i < size; i++)
                {
                    sum = op(sum, ptr[i]);
                }
                sums[q] = sum;
            }

            float sum = v0;
            for (int i = 0; i < channels; i++)
            {
                sum = op2(sum, sums[i]);
            }
            b[0] = sum;

            return 0;
        }

        if (reduce_w && reduce_h && !reduce_c)
        {
            // w h c -> 1 1 c
            b.create(1, 1, channels, elemsize, opt.blob_allocator);

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const float* ptr = a.channel(q);
                float* outptr = b.channel(q);

                float sum = v0;
                for (int i = 0; i < size; i++)
                {
                    sum = op(sum, ptr[i]);
                }

                outptr[0] = sum;
            }

            return 0;
        }

        if (reduce_w && !reduce_h && !reduce_c)
        {
            // w h c -> 1 h c
            b.create(1, h, channels, elemsize, opt.blob_allocator);

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const float* ptr = a.channel(q);
                float* outptr = b.channel(q);

                for (int i = 0; i < h; i++)
                {
                    float sum = v0;
                    for (int j = 0; j < w; j++)
                    {
                        sum = op(sum, ptr[j]);
                    }
                    outptr[i] = sum;
                    ptr += w;
                }
            }

            return 0;
        }

        if (reduce_w && !reduce_h && reduce_c)
        {
            // w h c -> 1 h 1
            b.create(1, h, 1, elemsize, opt.blob_allocator);

            Mat mins(1, h, channels, elemsize, opt.workspace_allocator);
            if (mins.empty())
                return -100;

            mins.fill(v0);

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const float* ptr = a.channel(q);
                float* mins_ptr = mins.channel(q);

                for (int i = 0; i < h; i++)
                {
                    float sum = v0;
                    for (int j = 0; j < w; j++)
                    {
                        sum = op(sum, ptr[j]);
                    }
                    mins_ptr[i] = sum;
                    ptr += w;
                }
            }

            b.fill(v0);

            for (int q = 0; q < channels; q++)
            {
                const float* mins_ptr = mins.channel(q);
                for (int i = 0; i < h; i++)
                {
                    b[i] = op2(b[i], mins_ptr[i]);
                }
            }

            return 0;
        }

        if (!reduce_w && reduce_h && reduce_c)
        {
            // w h c -> w 1 1
            b.create(w, 1, 1, elemsize, opt.blob_allocator);

            Mat mins(w, 1, channels, elemsize, opt.workspace_allocator);
            if (mins.empty())
                return -100;

            mins.fill(v0);

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const float* ptr = a.channel(q);
                float* mins_ptr = mins.channel(q);

                for (int i = 0; i < h; i++)
                {
                    for (int j = 0; j < w; j++)
                    {
                        mins_ptr[j] = op(mins_ptr[j], ptr[j]);
                    }
                    ptr += w;
                }
            }

            b.fill(v0);

            for (int q = 0; q < channels; q++)
            {
                const float* mins_ptr = mins.channel(q);
                for (int j = 0; j < w; j++)
                {
                    b[j] = op2(b[j], mins_ptr[j]);
                }
            }

            return 0;
        }

        if (!reduce_w && !reduce_h && reduce_c)
        {
            // w h c -> w h 1
            b.create(w, h, 1, elemsize, opt.blob_allocator);

            b.fill(v0);

            for (int q = 0; q < channels; q++)
            {
                const float* ptr = a.channel(q);

                for (int i = 0; i < size; i++)
                {
                    b[i] = op(b[i], ptr[i]);
                }
            }

            return 0;
        }

        if (!reduce_w && reduce_h && !reduce_c)
        {
            // w h c -> w 1 c
            b.create(w, 1, channels, elemsize, opt.blob_allocator);
            b.fill(v0);

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const float* ptr = a.channel(q);
                float* outptr = b.channel(q);

                for (int i = 0; i < h; i++)
                {
                    for (int j = 0; j < w; j++)
                    {
                        outptr[j] = op(outptr[j], ptr[j]);
                    }
                    ptr += w;
                }
            }

            return 0;
        }
    }

    return 0;
}

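// Apply MathOp to every element of the reduced blob and scale it by coeff,
// e.g. sqrt for the L2 norm and log for LogSum / LogSumExp.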
template<typename MathOp>
static int reduction_post_process(Mat& a, float coeff, const Option& opt)
{
    MathOp mathop;

    int dims = a.dims;
    if (dims == 1)
    {
        int w = a.w;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = 0; i < w; i++)
            a[i] = mathop(a[i]) * coeff;
    }
    else if (dims == 2)
    {
        int size = a.w * a.h;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = 0; i < size; i++)
            a[i] = mathop(a[i]) * coeff;
    }
    else if (dims == 3)
    {
        int c = a.c;
        int size = a.w * a.h;
        if (c == 1)
        {
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int i = 0; i < size; i++)
                a[i] = mathop(a[i]) * coeff;
        }
        else
        {
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < c; q++)
            {
                float* outptr = a.channel(q);
                for (int i = 0; i < size; i++)
                    outptr[i] = mathop(outptr[i]) * coeff;
            }
        }
    }

    return 0;
}

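// Dispatch to the keepdims or non-keepdims reduction, then run the optional
// post-processing pass when a non-identity MathOp or coefficient is requested.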
template<typename Op, typename Op2, typename Op3>
static int reduction(const Mat& a, Mat& b, float v0, bool reduce_w, bool reduce_h, bool reduce_c, bool post_process, float coeff, int keepdims, const Option& opt)
{
    int ret;
    if (keepdims)
        ret = reduction_op_keepdims<Op, Op2>(a, b, v0, reduce_w, reduce_h, reduce_c, opt);
    else
        ret = reduction_op<Op, Op2>(a, b, v0, reduce_w, reduce_h, reduce_c, opt);
    if (ret != 0)
        return -100;

    if (post_process || fabs(coeff - 1.f) > FLT_EPSILON)
    {
        ret = reduction_post_process<Op3>(b, coeff, opt);
        if (ret != 0)
            return -100;
    }
    return ret;
}

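// Element-wise functors plugged into the templates above: post_process_*
// finalize a reduced value, reduction_op_* fold the next element into the
// running accumulator.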
template<typename T>
struct post_process_identity
{
    T operator()(const T& x) const
    {
        return x;
    }
};

template<typename T>
struct post_process_sqrt
{
    T operator()(const T& x) const
    {
        return static_cast<T>(sqrt(x));
    }
};

template<typename T>
struct post_process_log
{
    T operator()(const T& x) const
    {
        return static_cast<T>(log(x));
    }
};

template<typename T>
struct reduction_op_add
{
    T operator()(const T& x, const T& y) const
    {
        return x + y;
    }
};

template<typename T>
struct reduction_op_mul
{
    T operator()(const T& x, const T& y) const
    {
        return x * y;
    }
};

template<typename T>
struct reduction_op_asum
{
    T operator()(const T& x, const T& y) const
    {
        return static_cast<T>(x + fabs(y));
    }
};

template<typename T>
struct reduction_op_sumsq
{
    T operator()(const T& x, const T& y) const
    {
        return x + y * y;
    }
};

template<typename T>
struct reduction_op_sumsexp
{
    T operator()(const T& x, const T& y) const
    {
        return static_cast<T>(x + exp(y));
    }
};

template<typename T>
struct reduction_op_max
{
    T operator()(const T& x, const T& y) const
    {
        return std::max(x, y);
    }
};

template<typename T>
struct reduction_op_min
{
    T operator()(const T& x, const T& y) const
    {
        return std::min(x, y);
    }
};

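// Translate the requested axes into reduce_w / reduce_h / reduce_c flags and
// dispatch to the reduction templates for the selected operation.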
int Reduction::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    int dims = bottom_blob.dims;
    int axes_flag[3] = {0};
    bool reduce_w = false;
    bool reduce_h = false;
    bool reduce_c = false;

    if (reduce_all)
    {
        reduce_w = true;
        reduce_h = true;
        reduce_c = true;
    }
    else
    {
        const int* axes_ptr = axes;
        int reduced_axes_num = axes.w;

        for (int i = 0; i < reduced_axes_num; i++)
        {
            int axis = axes_ptr[i];
            // handle negative axis
            if (axis < 0)
                axis += dims + 1;
            axes_flag[axis - 1] = 1;
        }

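        // axes are 1-based counting from the outermost dimension, so flag
        // index 0 is c for 3-d input and h for 2-d input; map the marked
        // flags onto the reduce_w / reduce_h / reduce_c switches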
        if (dims == 1)
        {
            reduce_w = true;
        }
        else if (dims == 2)
        {
            if (axes_flag[0] == 1) reduce_h = true;
            if (axes_flag[1] == 1) reduce_w = true;
        }
        else if (dims == 3)
        {
            if (axes_flag[0] == 1) reduce_c = true;
            if (axes_flag[1] == 1) reduce_h = true;
            if (axes_flag[2] == 1) reduce_w = true;
        }
    }

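    // the first functor folds elements, the second folds partial results, and
    // the post-process functor finishes the value (identity for most ops,
    // sqrt for L2, log for LogSum and LogSumExp)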
    if (operation == ReductionOp_SUM)
        return reduction<reduction_op_add<float>, reduction_op_add<float>, post_process_identity<float> >(bottom_blob, top_blob, 0.f, reduce_w, reduce_h, reduce_c, false, coeff, keepdims, opt);

    if (operation == ReductionOp_ASUM)
        return reduction<reduction_op_asum<float>, reduction_op_add<float>, post_process_identity<float> >(bottom_blob, top_blob, 0.f, reduce_w, reduce_h, reduce_c, false, coeff, keepdims, opt);

    if (operation == ReductionOp_SUMSQ)
        return reduction<reduction_op_sumsq<float>, reduction_op_add<float>, post_process_identity<float> >(bottom_blob, top_blob, 0.f, reduce_w, reduce_h, reduce_c, false, coeff, keepdims, opt);

    if (operation == ReductionOp_MEAN)
    {
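        // the mean is a plain sum scaled by 1 / (number of reduced elements),
        // folded into the coefficient used by the post-process pass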
        int scale = 1;
        int dims = bottom_blob.dims;
        if (dims == 1)
        {
            scale = bottom_blob.w;
        }
        else if (dims == 2)
        {
            if (reduce_w) scale *= bottom_blob.w;
            if (reduce_h) scale *= bottom_blob.h;
        }
        else if (dims == 3)
        {
            if (reduce_w) scale *= bottom_blob.w;
            if (reduce_h) scale *= bottom_blob.h;
            if (reduce_c) scale *= bottom_blob.c;
        }

        float coeff_mean = coeff / scale;
        return reduction<reduction_op_add<float>, reduction_op_add<float>, post_process_identity<float> >(bottom_blob, top_blob, 0.f, reduce_w, reduce_h, reduce_c, true, coeff_mean, keepdims, opt);
    }

    if (operation == ReductionOp_MAX)
        return reduction<reduction_op_max<float>, reduction_op_max<float>, post_process_identity<float> >(bottom_blob, top_blob, -FLT_MAX, reduce_w, reduce_h, reduce_c, false, coeff, keepdims, opt);

    if (operation == ReductionOp_MIN)
        return reduction<reduction_op_min<float>, reduction_op_min<float>, post_process_identity<float> >(bottom_blob, top_blob, FLT_MAX, reduce_w, reduce_h, reduce_c, false, coeff, keepdims, opt);

    if (operation == ReductionOp_PROD)
        return reduction<reduction_op_mul<float>, reduction_op_mul<float>, post_process_identity<float> >(bottom_blob, top_blob, 1.f, reduce_w, reduce_h, reduce_c, false, coeff, keepdims, opt);

    if (operation == ReductionOp_L1)
        return reduction<reduction_op_asum<float>, reduction_op_add<float>, post_process_identity<float> >(bottom_blob, top_blob, 0.f, reduce_w, reduce_h, reduce_c, false, 1.f, keepdims, opt);

    if (operation == ReductionOp_L2)
        return reduction<reduction_op_sumsq<float>, reduction_op_add<float>, post_process_sqrt<float> >(bottom_blob, top_blob, 0.f, reduce_w, reduce_h, reduce_c, true, 1.f, keepdims, opt);

    if (operation == ReductionOp_LogSum)
        return reduction<reduction_op_add<float>, reduction_op_add<float>, post_process_log<float> >(bottom_blob, top_blob, 0.f, reduce_w, reduce_h, reduce_c, true, 1.f, keepdims, opt);

    if (operation == ReductionOp_LogSumExp)
        return reduction<reduction_op_sumsexp<float>, reduction_op_add<float>, post_process_log<float> >(bottom_blob, top_blob, 0.f, reduce_w, reduce_h, reduce_c, true, 1.f, keepdims, opt);

    return 0;
}

} // namespace ncnn