1 /*******************************************************************************
2 * Copyright 2019-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <cassert>
18 #include <cfloat>
19
20 #include "common/c_types_map.hpp"
21 #include "common/dnnl_thread.hpp"
22 #include "common/math_utils.hpp"
23 #include "common/type_helpers.hpp"
24
25 #include "cpu/resampling_utils.hpp"
26
27 #include "cpu/ref_resampling.hpp"
28
29 namespace dnnl {
30 namespace impl {
31 namespace cpu {
32
33 using namespace resampling_utils;
34
// Raw byte pointer element type for type-erased tensor buffers.
using byte = unsigned char;
// Reads the element at `offset` (counted in elements, not bytes) from `base`
// and converts it to float.
using load_fn_t = std::function<float(const byte *base, const dim_t offset)>;
// Converts `val` to the destination data type and writes it at `offset`
// (counted in elements, not bytes) into `base`.
using store_fn_t
        = std::function<void(const float val, byte *base, const dim_t offset)>;
39
40 namespace {
41 template <data_type_t type>
create_load()42 load_fn_t create_load() {
43 return [](const byte *base, dim_t offset) -> float {
44 return static_cast<float>(
45 reinterpret_cast<const typename prec_traits<type>::type *>(
46 base)[offset]);
47 };
48 }
49 template <>
create_load()50 load_fn_t create_load<data_type::f32>() {
51 return [](const byte *base, dim_t offset) -> float {
52 return reinterpret_cast<const float *>(base)[offset];
53 };
54 }
55 template <data_type_t type>
create_store()56 store_fn_t create_store() {
57 using dst_t = typename prec_traits<type>::type;
58 return [](const float val, byte *base, const dim_t offset) {
59 *reinterpret_cast<dst_t *>(base + sizeof(dst_t) * offset)
60 = cpu::saturate_and_round<dst_t>(val);
61 };
62 }
63 template <>
create_store()64 store_fn_t create_store<data_type::f32>() {
65 return [](const float val, byte *base, const dim_t offset) {
66 *reinterpret_cast<float *>(base + sizeof(float) * offset) = val;
67 };
68 }
69 } // namespace
70
create_load(const data_type_t src_dtype)71 static load_fn_t create_load(const data_type_t src_dtype) {
72 using namespace data_type;
73
74 switch (src_dtype) {
75 case f32: return create_load<f32>();
76 case s32: return create_load<s32>();
77 case bf16: return create_load<bf16>();
78 case s8: return create_load<s8>();
79 case u8: return create_load<u8>();
80 default: assert(!"Unsupported data type.");
81 }
82 return create_load<f32>();
83 }
84
create_store(const data_type_t dst_dtype)85 static store_fn_t create_store(const data_type_t dst_dtype) {
86 using namespace data_type;
87
88 switch (dst_dtype) {
89 case f32: return create_store<f32>();
90 case s32: return create_store<s32>();
91 case bf16: return create_store<bf16>();
92 case s8: return create_store<s8>();
93 case u8: return create_store<u8>();
94 default: assert(!"Unsupported data type.");
95 }
96 return create_store<f32>();
97 }
98
get_offset(const memory_desc_wrapper & data_d,int n,int c,int d,int h,int w)99 static dim_t get_offset(
100 const memory_desc_wrapper &data_d, int n, int c, int d, int h, int w) {
101 if (data_d.ndims() == 5) return data_d.off(n, c, d, h, w);
102 if (data_d.ndims() == 4) return data_d.off(n, c, h, w);
103 return data_d.off(n, c, w);
104 }
105
// Forward primitive: constructs the reference post-ops executor from the
// attribute's post-ops list so execute_forward() can apply them per element.
ref_resampling_fwd_t::ref_resampling_fwd_t(const pd_t *apd)
    : primitive_t(apd), ref_post_ops_(pd()->attr()->post_ops_) {}

ref_resampling_fwd_t::~ref_resampling_fwd_t() = default;
110
execute_forward(const exec_ctx_t & ctx) const111 void ref_resampling_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
112 if (this->pd()->has_zero_dim_memory()) return;
113
114 const auto src = CTX_IN_MEM(const byte *, DNNL_ARG_SRC);
115 auto dst = CTX_OUT_MEM(byte *, DNNL_ARG_DST);
116
117 const memory_desc_wrapper src_d(pd()->src_md());
118 const memory_desc_wrapper dst_d(pd()->dst_md());
119
120 const data_type_t src_dt = pd()->src_md()->data_type;
121 const data_type_t dst_dt = pd()->dst_md()->data_type;
122
123 load_fn_t load_fn = create_load(src_dt);
124 store_fn_t store_fn = create_store(dst_dt);
125
126 const auto alg = pd()->desc()->alg_kind;
127
128 const int MB = pd()->MB();
129 const int C = pd()->C();
130 const int ID = pd()->ID();
131 const int IH = pd()->IH();
132 const int IW = pd()->IW();
133 const int OD = pd()->OD();
134 const int OH = pd()->OH();
135 const int OW = pd()->OW();
136
137 auto lin_interp = [&](float c0, float c1, float w) {
138 return c0 * w + c1 * (1 - w);
139 };
140 auto bilin_interp = [&](float c00, float c01, float c10, float c11,
141 float w0, float w1) {
142 return lin_interp(
143 lin_interp(c00, c10, w0), lin_interp(c01, c11, w0), w1);
144 };
145 auto trilin_interp = [&](float c000, float c001, float c010, float c011,
146 float c100, float c101, float c110, float c111,
147 float w0, float w1, float w2) {
148 return lin_interp(bilin_interp(c000, c010, c100, c110, w0, w1),
149 bilin_interp(c001, c011, c101, c111, w0, w1), w2);
150 };
151
152 parallel_nd(MB, C, OD, OH, OW,
153 [&](dim_t mb, dim_t ch, dim_t od, dim_t oh, dim_t ow) {
154 const dim_t data_p_off = get_offset(dst_d, mb, ch, od, oh, ow);
155 const dim_t data_l_off
156 = (((mb * C + ch) * OD + od) * OH + oh) * OW + ow;
157 float res = 0.f;
158
159 if (alg == alg_kind::resampling_nearest) {
160 const dim_t id = nearest_idx(od, OD, ID);
161 const dim_t ih = nearest_idx(oh, OH, IH);
162 const dim_t iw = nearest_idx(ow, OW, IW);
163 res = load_fn(src, get_offset(src_d, mb, ch, id, ih, iw));
164 } else if (alg == alg_kind::resampling_linear) {
165 // Trilinear interpolation (linear interpolation on a 3D spatial
166 // tensor) can be expressed as linear interpolation along
167 // dimension x followed by interpolation along dimension y and z
168 // C011--C11--C111
169 // - - |
170 // - - |
171 //C001--C01--C111 |
172 // - .C - C110
173 // - - -
174 // - - -
175 //C000--C00--C100
176 auto id = linear_coeffs_t(od, OD, ID);
177 auto iw = linear_coeffs_t(ow, OW, IW);
178 auto ih = linear_coeffs_t(oh, OH, IH);
179 float src_l[8] = {0};
180 for_(int i = 0; i < 2; i++)
181 for_(int j = 0; j < 2; j++)
182 for (int k = 0; k < 2; k++) {
183 src_l[4 * i + 2 * j + k] = load_fn(src,
184 get_offset(src_d, mb, ch, id.idx[i], ih.idx[j],
185 iw.idx[k]));
186 }
187 res = trilin_interp(src_l[0], src_l[1], src_l[2], src_l[3],
188 src_l[4], src_l[5], src_l[6], src_l[7], id.wei[0],
189 ih.wei[0], iw.wei[0]);
190 }
191
192 ref_post_ops_t::args_t args;
193 args.ctx = &ctx;
194 args.dst_md = pd()->dst_md();
195 args.l_offset = data_l_off;
196 args.dst_val = dst[data_p_off];
197 ref_post_ops_.execute(res, args);
198
199 store_fn(res, dst, data_p_off);
200 });
201 }
202
// Backward primitive: no post-ops on the backward pass, plain construction.
ref_resampling_bwd_t::ref_resampling_bwd_t(const pd_t *apd)
    : primitive_t(apd) {}

ref_resampling_bwd_t::~ref_resampling_bwd_t() = default;
207
// Backward reference resampling: computes diff_src from diff_dst by
// inverting the forward mapping — each source element accumulates the
// gradients of all output elements it contributed to.
void ref_resampling_bwd_t::execute_backward(const exec_ctx_t &ctx) const {
    if (this->pd()->has_zero_dim_memory()) return;

    const auto diff_dst = CTX_IN_MEM(const byte *, DNNL_ARG_DIFF_DST);
    auto diff_src = CTX_OUT_MEM(byte *, DNNL_ARG_DIFF_SRC);

    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());

    const data_type_t diff_dst_dt = pd()->diff_dst_md()->data_type;
    const data_type_t diff_src_dt = pd()->diff_src_md()->data_type;

    // Gradients flow dst -> src: load from diff_dst's type, store in
    // diff_src's type.
    load_fn_t load_fn = create_load(diff_dst_dt);
    store_fn_t store_fn = create_store(diff_src_dt);

    const auto alg = pd()->desc()->alg_kind;

    const int MB = pd()->MB();
    const int C = pd()->C();
    const int ID = pd()->ID();
    const int IH = pd()->IH();
    const int IW = pd()->IW();
    const int OD = pd()->OD();
    const int OH = pd()->OH();
    const int OW = pd()->OW();

    if (alg == alg_kind::resampling_nearest) {
        // Each source element (id, ih, iw) is the nearest neighbor of the
        // half-open output ranges [*_start, *_end) computed below (the
        // forward nearest mapping inverted via ceil_idx, a resampling_utils
        // helper); its gradient is the unweighted sum of diff_dst over that
        // range.
        parallel_nd(MB, C, ID, IH, IW,
                [&](dim_t mb, dim_t ch, dim_t id, dim_t ih, dim_t iw) {
                    const dim_t od_start
                            = ceil_idx(((float)id * OD / ID) - 0.5f);
                    const dim_t oh_start
                            = ceil_idx(((float)ih * OH / IH) - 0.5f);
                    const dim_t ow_start
                            = ceil_idx(((float)iw * OW / IW) - 0.5f);

                    const dim_t od_end
                            = ceil_idx(((id + 1.f) * OD / ID) - 0.5f);
                    const dim_t oh_end
                            = ceil_idx(((ih + 1.f) * OH / IH) - 0.5f);
                    const dim_t ow_end
                            = ceil_idx(((iw + 1.f) * OW / IW) - 0.5f);

                    float ds = 0;
                    for_(dim_t od = od_start; od < od_end; od++)
                    for_(dim_t oh = oh_start; oh < oh_end; oh++)
                    for (dim_t ow = ow_start; ow < ow_end; ow++)
                        ds += load_fn(diff_dst,
                                get_offset(diff_dst_d, mb, ch, od, oh, ow));
                    store_fn(ds, diff_src,
                            get_offset(diff_src_d, mb, ch, id, ih, iw));
                });
    } else {
        // Linear: per dimension, a source element contributes to output
        // elements on its left (index 0) and right (index 1) sides; the
        // [start, end) ranges and interpolation weights come from the
        // bwd_linear_coeffs_t / linear_weight helpers in resampling_utils.
        // The gradient is the weight-scaled sum over all 2x2x2 side
        // combinations.
        parallel_nd(MB, C, ID, IH, IW,
                [&](dim_t mb, dim_t ch, dim_t id, dim_t ih, dim_t iw) {
                    bwd_linear_coeffs_t d(id, OD, ID);
                    bwd_linear_coeffs_t h(ih, OH, IH);
                    bwd_linear_coeffs_t w(iw, OW, IW);

                    float ds = 0;
                    for_(int i = 0; i < 2; i++)
                    for_(int j = 0; j < 2; j++)
                    for_(int k = 0; k < 2; k++)
                    for_(dim_t od = d.start[i]; od < d.end[i]; od++)
                    for_(dim_t oh = h.start[j]; oh < h.end[j]; oh++)
                    for (dim_t ow = w.start[k]; ow < w.end[k]; ow++) {
                        // Separable weights: product of the per-dimension
                        // linear interpolation weights.
                        const float weight_d = linear_weight(i, od, OD, ID);
                        const float weight_h = linear_weight(j, oh, OH, IH);
                        const float weight_w = linear_weight(k, ow, OW, IW);

                        float dd = load_fn(diff_dst,
                                get_offset(diff_dst_d, mb, ch, od, oh, ow));
                        ds += dd * weight_d * weight_h * weight_w;
                    }
                    store_fn(ds, diff_src,
                            get_offset(diff_src_d, mb, ch, id, ih, iw));
                });
    }
}
287
288 } // namespace cpu
289 } // namespace impl
290 } // namespace dnnl
291
292 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
293