1 /*******************************************************************************
2 * Copyright 2020-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include <math.h>
18 
19 #include "common/c_types_map.hpp"
20 #include "common/dnnl_thread.hpp"
21 #include "common/nstl.hpp"
22 
23 #include "cpu/simple_q10n.hpp"
24 
25 #include "cpu/ref_reduction.hpp"
26 
27 namespace dnnl {
28 namespace impl {
29 namespace cpu {
30 
31 template <data_type_t src_type, data_type_t dst_type, data_type_t acc_type>
init_acc(acc_t & acc,alg_kind_t alg) const32 void ref_reduction_t<src_type, dst_type, acc_type>::init_acc(
33         acc_t &acc, alg_kind_t alg) const {
34     using namespace alg_kind;
35     using namespace nstl;
36 
37     switch (alg) {
38         case reduction_max:
39             acc = static_cast<acc_t>(numeric_limits<src_t>::lowest());
40             break;
41         case reduction_min:
42             acc = static_cast<acc_t>(numeric_limits<src_t>::max());
43             break;
44         case reduction_mean:
45         case reduction_sum: acc = acc_t(0); break;
46         case reduction_mul: acc = acc_t(1); break;
47         case reduction_norm_lp_max:
48         case reduction_norm_lp_sum:
49         case reduction_norm_lp_power_p_max:
50         case reduction_norm_lp_power_p_sum: acc = acc_t(0); break;
51         default: assert(!"unknown alg");
52     }
53 }
54 
55 template <data_type_t src_type, data_type_t dst_type, data_type_t acc_type>
accumulate(acc_t & acc,const src_t & src,alg_kind_t alg,float p) const56 void ref_reduction_t<src_type, dst_type, acc_type>::accumulate(
57         acc_t &acc, const src_t &src, alg_kind_t alg, float p) const {
58     using namespace alg_kind;
59 
60     acc_t src_ = static_cast<acc_t>(src);
61 
62     switch (alg) {
63         case reduction_max: acc = nstl::max(acc, src_); break;
64         case reduction_min: acc = nstl::min(acc, src_); break;
65         case reduction_mean:
66         case reduction_sum: acc += src_; break;
67         case reduction_mul: acc *= src_; break;
68         case reduction_norm_lp_max:
69         case reduction_norm_lp_sum:
70         case reduction_norm_lp_power_p_max:
71         case reduction_norm_lp_power_p_sum:
72             acc += powf(nstl::abs(src_), p);
73             break;
74         default: assert(!"unknown alg");
75     }
76 }
77 
78 template <data_type_t src_type, data_type_t dst_type, data_type_t acc_type>
finalize(float & acc_f32,alg_kind_t alg,float p,float eps,dim_t n) const79 void ref_reduction_t<src_type, dst_type, acc_type>::finalize(
80         float &acc_f32, alg_kind_t alg, float p, float eps, dim_t n) const {
81     using namespace alg_kind;
82 
83     switch (alg) {
84         case reduction_mean: acc_f32 /= n; break;
85         case reduction_norm_lp_max:
86             acc_f32 = nstl::max(acc_f32, eps);
87             acc_f32 = powf(acc_f32, 1.0f / p);
88             break;
89         case reduction_norm_lp_sum:
90             acc_f32 += eps;
91             acc_f32 = powf(acc_f32, 1.0f / p);
92             break;
93         case reduction_norm_lp_power_p_max:
94             acc_f32 = nstl::max(acc_f32, eps);
95             break;
96         case reduction_norm_lp_power_p_sum: acc_f32 += eps; break;
97         default: break;
98     }
99 }
100 
101 template <data_type_t src_type, data_type_t dst_type, data_type_t acc_type>
execute_ref(const exec_ctx_t & ctx) const102 status_t ref_reduction_t<src_type, dst_type, acc_type>::execute_ref(
103         const exec_ctx_t &ctx) const {
104     status_t status = status::success;
105     auto src = CTX_IN_MEM(const src_t *, DNNL_ARG_SRC);
106     auto dst = CTX_OUT_CLEAN_MEM(dst_t *, DNNL_ARG_DST, status);
107     CHECK(status);
108 
109     const memory_desc_wrapper src_mdw(pd()->src_md());
110     const memory_desc_wrapper dst_mdw(pd()->dst_md());
111 
112     const int ndims = src_mdw.ndims();
113     const auto &src_dims = src_mdw.dims();
114     const auto &dst_dims = dst_mdw.dims();
115 
116     const auto alg = pd()->desc()->alg_kind;
117     const auto p = pd()->desc()->p;
118     const auto eps = pd()->desc()->eps;
119 
120     dims_t reduce_dims;
121     dim_t reduce_size {1}, idle_size = dst_mdw.nelems();
122 
123     for (int d = 0; d < ndims; ++d) {
124         reduce_dims[d] = dim_t {1};
125         const bool is_reduction_dim = src_dims[d] != dst_dims[d];
126         if (is_reduction_dim) {
127             reduce_dims[d] = src_dims[d];
128             reduce_size *= reduce_dims[d];
129         }
130     }
131 
132     parallel_nd(idle_size, [&](dim_t l_offset) {
133         dims_t idle_pos, reduce_pos;
134         utils::l_dims_by_l_offset(idle_pos, l_offset, dst_mdw.dims(), ndims);
135         const dim_t dst_off = dst_mdw.off_v(idle_pos);
136         const dim_t src_idle_off = src_mdw.off_v(idle_pos);
137         acc_t acc {0};
138         init_acc(acc, alg);
139         for (dim_t r = 0; r < reduce_size; ++r) {
140             utils::l_dims_by_l_offset(reduce_pos, r, reduce_dims, ndims);
141             const dim_t src_reduce_off = src_mdw.off_v(reduce_pos);
142             const dim_t src_off = src_idle_off + src_reduce_off;
143             accumulate(acc, src[src_off], alg, p);
144         }
145         float acc_f32 = static_cast<float>(acc);
146         finalize(acc_f32, alg, p, eps, reduce_size);
147 
148         ref_post_ops_t::args_t args;
149         args.dst_val = dst[dst_off];
150         args.ctx = &ctx;
151         args.l_offset = l_offset;
152         args.dst_md = pd()->dst_md();
153         ref_post_ops->execute(acc_f32, args);
154 
155         dst[dst_off] = saturate_and_round<dst_t>(acc_f32);
156     });
157 
158     return status::success;
159 }
160 
161 using namespace data_type;
162 template struct ref_reduction_t<f32, f32, f32>;
163 template struct ref_reduction_t<bf16, bf16, f32>;
164 template struct ref_reduction_t<bf16, f32, f32>;
165 template struct ref_reduction_t<s8, s8, s32>;
166 template struct ref_reduction_t<s8, s32, s32>;
167 template struct ref_reduction_t<s8, f32, f32>;
168 template struct ref_reduction_t<u8, u8, s32>;
169 template struct ref_reduction_t<u8, s32, s32>;
170 template struct ref_reduction_t<u8, f32, f32>;
171 
172 } // namespace cpu
173 } // namespace impl
174 } // namespace dnnl
175