1 /*******************************************************************************
2 * Copyright 2020-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <math.h>
18
19 #include "common/c_types_map.hpp"
20 #include "common/dnnl_thread.hpp"
21 #include "common/nstl.hpp"
22
23 #include "cpu/simple_q10n.hpp"
24
25 #include "cpu/ref_reduction.hpp"
26
27 namespace dnnl {
28 namespace impl {
29 namespace cpu {
30
31 template <data_type_t src_type, data_type_t dst_type, data_type_t acc_type>
init_acc(acc_t & acc,alg_kind_t alg) const32 void ref_reduction_t<src_type, dst_type, acc_type>::init_acc(
33 acc_t &acc, alg_kind_t alg) const {
34 using namespace alg_kind;
35 using namespace nstl;
36
37 switch (alg) {
38 case reduction_max:
39 acc = static_cast<acc_t>(numeric_limits<src_t>::lowest());
40 break;
41 case reduction_min:
42 acc = static_cast<acc_t>(numeric_limits<src_t>::max());
43 break;
44 case reduction_mean:
45 case reduction_sum: acc = acc_t(0); break;
46 case reduction_mul: acc = acc_t(1); break;
47 case reduction_norm_lp_max:
48 case reduction_norm_lp_sum:
49 case reduction_norm_lp_power_p_max:
50 case reduction_norm_lp_power_p_sum: acc = acc_t(0); break;
51 default: assert(!"unknown alg");
52 }
53 }
54
55 template <data_type_t src_type, data_type_t dst_type, data_type_t acc_type>
accumulate(acc_t & acc,const src_t & src,alg_kind_t alg,float p) const56 void ref_reduction_t<src_type, dst_type, acc_type>::accumulate(
57 acc_t &acc, const src_t &src, alg_kind_t alg, float p) const {
58 using namespace alg_kind;
59
60 acc_t src_ = static_cast<acc_t>(src);
61
62 switch (alg) {
63 case reduction_max: acc = nstl::max(acc, src_); break;
64 case reduction_min: acc = nstl::min(acc, src_); break;
65 case reduction_mean:
66 case reduction_sum: acc += src_; break;
67 case reduction_mul: acc *= src_; break;
68 case reduction_norm_lp_max:
69 case reduction_norm_lp_sum:
70 case reduction_norm_lp_power_p_max:
71 case reduction_norm_lp_power_p_sum:
72 acc += powf(nstl::abs(src_), p);
73 break;
74 default: assert(!"unknown alg");
75 }
76 }
77
78 template <data_type_t src_type, data_type_t dst_type, data_type_t acc_type>
finalize(float & acc_f32,alg_kind_t alg,float p,float eps,dim_t n) const79 void ref_reduction_t<src_type, dst_type, acc_type>::finalize(
80 float &acc_f32, alg_kind_t alg, float p, float eps, dim_t n) const {
81 using namespace alg_kind;
82
83 switch (alg) {
84 case reduction_mean: acc_f32 /= n; break;
85 case reduction_norm_lp_max:
86 acc_f32 = nstl::max(acc_f32, eps);
87 acc_f32 = powf(acc_f32, 1.0f / p);
88 break;
89 case reduction_norm_lp_sum:
90 acc_f32 += eps;
91 acc_f32 = powf(acc_f32, 1.0f / p);
92 break;
93 case reduction_norm_lp_power_p_max:
94 acc_f32 = nstl::max(acc_f32, eps);
95 break;
96 case reduction_norm_lp_power_p_sum: acc_f32 += eps; break;
97 default: break;
98 }
99 }
100
101 template <data_type_t src_type, data_type_t dst_type, data_type_t acc_type>
execute_ref(const exec_ctx_t & ctx) const102 status_t ref_reduction_t<src_type, dst_type, acc_type>::execute_ref(
103 const exec_ctx_t &ctx) const {
104 status_t status = status::success;
105 auto src = CTX_IN_MEM(const src_t *, DNNL_ARG_SRC);
106 auto dst = CTX_OUT_CLEAN_MEM(dst_t *, DNNL_ARG_DST, status);
107 CHECK(status);
108
109 const memory_desc_wrapper src_mdw(pd()->src_md());
110 const memory_desc_wrapper dst_mdw(pd()->dst_md());
111
112 const int ndims = src_mdw.ndims();
113 const auto &src_dims = src_mdw.dims();
114 const auto &dst_dims = dst_mdw.dims();
115
116 const auto alg = pd()->desc()->alg_kind;
117 const auto p = pd()->desc()->p;
118 const auto eps = pd()->desc()->eps;
119
120 dims_t reduce_dims;
121 dim_t reduce_size {1}, idle_size = dst_mdw.nelems();
122
123 for (int d = 0; d < ndims; ++d) {
124 reduce_dims[d] = dim_t {1};
125 const bool is_reduction_dim = src_dims[d] != dst_dims[d];
126 if (is_reduction_dim) {
127 reduce_dims[d] = src_dims[d];
128 reduce_size *= reduce_dims[d];
129 }
130 }
131
132 parallel_nd(idle_size, [&](dim_t l_offset) {
133 dims_t idle_pos, reduce_pos;
134 utils::l_dims_by_l_offset(idle_pos, l_offset, dst_mdw.dims(), ndims);
135 const dim_t dst_off = dst_mdw.off_v(idle_pos);
136 const dim_t src_idle_off = src_mdw.off_v(idle_pos);
137 acc_t acc {0};
138 init_acc(acc, alg);
139 for (dim_t r = 0; r < reduce_size; ++r) {
140 utils::l_dims_by_l_offset(reduce_pos, r, reduce_dims, ndims);
141 const dim_t src_reduce_off = src_mdw.off_v(reduce_pos);
142 const dim_t src_off = src_idle_off + src_reduce_off;
143 accumulate(acc, src[src_off], alg, p);
144 }
145 float acc_f32 = static_cast<float>(acc);
146 finalize(acc_f32, alg, p, eps, reduce_size);
147
148 ref_post_ops_t::args_t args;
149 args.dst_val = dst[dst_off];
150 args.ctx = &ctx;
151 args.l_offset = l_offset;
152 args.dst_md = pd()->dst_md();
153 ref_post_ops->execute(acc_f32, args);
154
155 dst[dst_off] = saturate_and_round<dst_t>(acc_f32);
156 });
157
158 return status::success;
159 }
160
161 using namespace data_type;
162 template struct ref_reduction_t<f32, f32, f32>;
163 template struct ref_reduction_t<bf16, bf16, f32>;
164 template struct ref_reduction_t<bf16, f32, f32>;
165 template struct ref_reduction_t<s8, s8, s32>;
166 template struct ref_reduction_t<s8, s32, s32>;
167 template struct ref_reduction_t<s8, f32, f32>;
168 template struct ref_reduction_t<u8, u8, s32>;
169 template struct ref_reduction_t<u8, s32, s32>;
170 template struct ref_reduction_t<u8, f32, f32>;
171
172 } // namespace cpu
173 } // namespace impl
174 } // namespace dnnl
175