1 /*******************************************************************************
2 * Copyright 2019-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include <cassert>
18 #include <cfloat>
19 
20 #include "common/c_types_map.hpp"
21 #include "common/dnnl_thread.hpp"
22 #include "common/math_utils.hpp"
23 #include "common/type_helpers.hpp"
24 
25 #include "cpu/resampling_utils.hpp"
26 
27 #include "cpu/ref_resampling.hpp"
28 
29 namespace dnnl {
30 namespace impl {
31 namespace cpu {
32 
33 using namespace resampling_utils;
34 
// Raw byte view used to address tensor memory whose element type is only
// known at run time.
using byte = unsigned char;
// Callable that reads the element with index `offset` (counted in elements of
// the underlying data type) from `base` and returns it converted to float.
using load_fn_t = std::function<float(const byte *base, const dim_t offset)>;
// Callable that converts `val` to the underlying data type and writes it to the
// element with index `offset` (counted in elements) relative to `base`.
using store_fn_t
        = std::function<void(const float val, byte *base, const dim_t offset)>;
39 
40 namespace {
41 template <data_type_t type>
create_load()42 load_fn_t create_load() {
43     return [](const byte *base, dim_t offset) -> float {
44         return static_cast<float>(
45                 reinterpret_cast<const typename prec_traits<type>::type *>(
46                         base)[offset]);
47     };
48 }
49 template <>
create_load()50 load_fn_t create_load<data_type::f32>() {
51     return [](const byte *base, dim_t offset) -> float {
52         return reinterpret_cast<const float *>(base)[offset];
53     };
54 }
55 template <data_type_t type>
create_store()56 store_fn_t create_store() {
57     using dst_t = typename prec_traits<type>::type;
58     return [](const float val, byte *base, const dim_t offset) {
59         *reinterpret_cast<dst_t *>(base + sizeof(dst_t) * offset)
60                 = cpu::saturate_and_round<dst_t>(val);
61     };
62 }
63 template <>
create_store()64 store_fn_t create_store<data_type::f32>() {
65     return [](const float val, byte *base, const dim_t offset) {
66         *reinterpret_cast<float *>(base + sizeof(float) * offset) = val;
67     };
68 }
69 } // namespace
70 
create_load(const data_type_t src_dtype)71 static load_fn_t create_load(const data_type_t src_dtype) {
72     using namespace data_type;
73 
74     switch (src_dtype) {
75         case f32: return create_load<f32>();
76         case s32: return create_load<s32>();
77         case bf16: return create_load<bf16>();
78         case s8: return create_load<s8>();
79         case u8: return create_load<u8>();
80         default: assert(!"Unsupported data type.");
81     }
82     return create_load<f32>();
83 }
84 
create_store(const data_type_t dst_dtype)85 static store_fn_t create_store(const data_type_t dst_dtype) {
86     using namespace data_type;
87 
88     switch (dst_dtype) {
89         case f32: return create_store<f32>();
90         case s32: return create_store<s32>();
91         case bf16: return create_store<bf16>();
92         case s8: return create_store<s8>();
93         case u8: return create_store<u8>();
94         default: assert(!"Unsupported data type.");
95     }
96     return create_store<f32>();
97 }
98 
get_offset(const memory_desc_wrapper & data_d,int n,int c,int d,int h,int w)99 static dim_t get_offset(
100         const memory_desc_wrapper &data_d, int n, int c, int d, int h, int w) {
101     if (data_d.ndims() == 5) return data_d.off(n, c, d, h, w);
102     if (data_d.ndims() == 4) return data_d.off(n, c, h, w);
103     return data_d.off(n, c, w);
104 }
105 
// Passes the primitive descriptor to the primitive_t base and constructs the
// post-ops helper from the post-ops stored in the descriptor attributes.
// The base class is initialized before the member, so pd() is only called
// after `apd` has been handed to primitive_t.
ref_resampling_fwd_t::ref_resampling_fwd_t(const pd_t *apd)
    : primitive_t(apd), ref_post_ops_(pd()->attr()->post_ops_) {}

// The compiler-generated destructor is used; it is defined in this
// translation unit rather than in the class declaration.
ref_resampling_fwd_t::~ref_resampling_fwd_t() = default;
110 
execute_forward(const exec_ctx_t & ctx) const111 void ref_resampling_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
112     if (this->pd()->has_zero_dim_memory()) return;
113 
114     const auto src = CTX_IN_MEM(const byte *, DNNL_ARG_SRC);
115     auto dst = CTX_OUT_MEM(byte *, DNNL_ARG_DST);
116 
117     const memory_desc_wrapper src_d(pd()->src_md());
118     const memory_desc_wrapper dst_d(pd()->dst_md());
119 
120     const data_type_t src_dt = pd()->src_md()->data_type;
121     const data_type_t dst_dt = pd()->dst_md()->data_type;
122 
123     load_fn_t load_fn = create_load(src_dt);
124     store_fn_t store_fn = create_store(dst_dt);
125 
126     const auto alg = pd()->desc()->alg_kind;
127 
128     const int MB = pd()->MB();
129     const int C = pd()->C();
130     const int ID = pd()->ID();
131     const int IH = pd()->IH();
132     const int IW = pd()->IW();
133     const int OD = pd()->OD();
134     const int OH = pd()->OH();
135     const int OW = pd()->OW();
136 
137     auto lin_interp = [&](float c0, float c1, float w) {
138         return c0 * w + c1 * (1 - w);
139     };
140     auto bilin_interp = [&](float c00, float c01, float c10, float c11,
141                                 float w0, float w1) {
142         return lin_interp(
143                 lin_interp(c00, c10, w0), lin_interp(c01, c11, w0), w1);
144     };
145     auto trilin_interp = [&](float c000, float c001, float c010, float c011,
146                                  float c100, float c101, float c110, float c111,
147                                  float w0, float w1, float w2) {
148         return lin_interp(bilin_interp(c000, c010, c100, c110, w0, w1),
149                 bilin_interp(c001, c011, c101, c111, w0, w1), w2);
150     };
151 
152     parallel_nd(MB, C, OD, OH, OW,
153             [&](dim_t mb, dim_t ch, dim_t od, dim_t oh, dim_t ow) {
154                 const dim_t data_p_off = get_offset(dst_d, mb, ch, od, oh, ow);
155                 const dim_t data_l_off
156                         = (((mb * C + ch) * OD + od) * OH + oh) * OW + ow;
157                 float res = 0.f;
158 
159                 if (alg == alg_kind::resampling_nearest) {
160                     const dim_t id = nearest_idx(od, OD, ID);
161                     const dim_t ih = nearest_idx(oh, OH, IH);
162                     const dim_t iw = nearest_idx(ow, OW, IW);
163                     res = load_fn(src, get_offset(src_d, mb, ch, id, ih, iw));
164                 } else if (alg == alg_kind::resampling_linear) {
165                     // Trilinear interpolation (linear interpolation on a 3D spatial
166                     // tensor) can be expressed as linear interpolation along
167                     // dimension x followed by interpolation along dimension y and z
168                     //      C011--C11--C111
169                     //     -          - |
170                     //   -          -   |
171                     //C001--C01--C111   |
172                     // -     .C   -    C110
173                     // -          -    -
174                     // -          -  -
175                     //C000--C00--C100
176                     auto id = linear_coeffs_t(od, OD, ID);
177                     auto iw = linear_coeffs_t(ow, OW, IW);
178                     auto ih = linear_coeffs_t(oh, OH, IH);
179                     float src_l[8] = {0};
180                     for_(int i = 0; i < 2; i++)
181                     for_(int j = 0; j < 2; j++)
182                     for (int k = 0; k < 2; k++) {
183                         src_l[4 * i + 2 * j + k] = load_fn(src,
184                                 get_offset(src_d, mb, ch, id.idx[i], ih.idx[j],
185                                         iw.idx[k]));
186                     }
187                     res = trilin_interp(src_l[0], src_l[1], src_l[2], src_l[3],
188                             src_l[4], src_l[5], src_l[6], src_l[7], id.wei[0],
189                             ih.wei[0], iw.wei[0]);
190                 }
191 
192                 ref_post_ops_t::args_t args;
193                 args.ctx = &ctx;
194                 args.dst_md = pd()->dst_md();
195                 args.l_offset = data_l_off;
196                 args.dst_val = dst[data_p_off];
197                 ref_post_ops_.execute(res, args);
198 
199                 store_fn(res, dst, data_p_off);
200             });
201 }
202 
// Passes the primitive descriptor to the primitive_t base; no other state is
// initialized here.
ref_resampling_bwd_t::ref_resampling_bwd_t(const pd_t *apd)
    : primitive_t(apd) {}

// The compiler-generated destructor is used; it is defined in this
// translation unit rather than in the class declaration.
ref_resampling_bwd_t::~ref_resampling_bwd_t() = default;
207 
execute_backward(const exec_ctx_t & ctx) const208 void ref_resampling_bwd_t::execute_backward(const exec_ctx_t &ctx) const {
209     if (this->pd()->has_zero_dim_memory()) return;
210 
211     const auto diff_dst = CTX_IN_MEM(const byte *, DNNL_ARG_DIFF_DST);
212     auto diff_src = CTX_OUT_MEM(byte *, DNNL_ARG_DIFF_SRC);
213 
214     const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
215     const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
216 
217     const data_type_t diff_dst_dt = pd()->diff_dst_md()->data_type;
218     const data_type_t diff_src_dt = pd()->diff_src_md()->data_type;
219 
220     load_fn_t load_fn = create_load(diff_dst_dt);
221     store_fn_t store_fn = create_store(diff_src_dt);
222 
223     const auto alg = pd()->desc()->alg_kind;
224 
225     const int MB = pd()->MB();
226     const int C = pd()->C();
227     const int ID = pd()->ID();
228     const int IH = pd()->IH();
229     const int IW = pd()->IW();
230     const int OD = pd()->OD();
231     const int OH = pd()->OH();
232     const int OW = pd()->OW();
233 
234     if (alg == alg_kind::resampling_nearest) {
235         parallel_nd(MB, C, ID, IH, IW,
236                 [&](dim_t mb, dim_t ch, dim_t id, dim_t ih, dim_t iw) {
237                     const dim_t od_start
238                             = ceil_idx(((float)id * OD / ID) - 0.5f);
239                     const dim_t oh_start
240                             = ceil_idx(((float)ih * OH / IH) - 0.5f);
241                     const dim_t ow_start
242                             = ceil_idx(((float)iw * OW / IW) - 0.5f);
243 
244                     const dim_t od_end
245                             = ceil_idx(((id + 1.f) * OD / ID) - 0.5f);
246                     const dim_t oh_end
247                             = ceil_idx(((ih + 1.f) * OH / IH) - 0.5f);
248                     const dim_t ow_end
249                             = ceil_idx(((iw + 1.f) * OW / IW) - 0.5f);
250 
251                     float ds = 0;
252                     for_(dim_t od = od_start; od < od_end; od++)
253                     for_(dim_t oh = oh_start; oh < oh_end; oh++)
254                     for (dim_t ow = ow_start; ow < ow_end; ow++)
255                         ds += load_fn(diff_dst,
256                                 get_offset(diff_dst_d, mb, ch, od, oh, ow));
257                     store_fn(ds, diff_src,
258                             get_offset(diff_src_d, mb, ch, id, ih, iw));
259                 });
260     } else {
261         parallel_nd(MB, C, ID, IH, IW,
262                 [&](dim_t mb, dim_t ch, dim_t id, dim_t ih, dim_t iw) {
263                     bwd_linear_coeffs_t d(id, OD, ID);
264                     bwd_linear_coeffs_t h(ih, OH, IH);
265                     bwd_linear_coeffs_t w(iw, OW, IW);
266 
267                     float ds = 0;
268                     for_(int i = 0; i < 2; i++)
269                     for_(int j = 0; j < 2; j++)
270                     for_(int k = 0; k < 2; k++)
271                     for_(dim_t od = d.start[i]; od < d.end[i]; od++)
272                     for_(dim_t oh = h.start[j]; oh < h.end[j]; oh++)
273                     for (dim_t ow = w.start[k]; ow < w.end[k]; ow++) {
274                         const float weight_d = linear_weight(i, od, OD, ID);
275                         const float weight_h = linear_weight(j, oh, OH, IH);
276                         const float weight_w = linear_weight(k, ow, OW, IW);
277 
278                         float dd = load_fn(diff_dst,
279                                 get_offset(diff_dst_d, mb, ch, od, oh, ow));
280                         ds += dd * weight_d * weight_h * weight_w;
281                     }
282                     store_fn(ds, diff_src,
283                             get_offset(diff_src_d, mb, ch, id, ih, iw));
284                 });
285     }
286 }
287 
288 } // namespace cpu
289 } // namespace impl
290 } // namespace dnnl
291 
292 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
293