1 /*******************************************************************************
2 * Copyright 2016-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include <assert.h>
18 #include <math.h>
19 
20 #include "common/c_types_map.hpp"
21 #include "common/dnnl_thread.hpp"
22 #include "common/nstl.hpp"
23 #include "common/type_helpers.hpp"
24 
25 #include "cpu/simple_q10n.hpp"
26 
27 #include "cpu/ref_pooling.hpp"
28 
29 namespace dnnl {
30 namespace impl {
31 namespace cpu {
32 
get_offset(const memory_desc_wrapper & mdw,dim_t n,dim_t c,dim_t d,dim_t h,dim_t w)33 static inline dim_t get_offset(const memory_desc_wrapper &mdw, dim_t n, dim_t c,
34         dim_t d, dim_t h, dim_t w) {
35     switch (mdw.ndims()) {
36         case 3: return mdw.off(n, c, w);
37         case 4: return mdw.off(n, c, h, w);
38         case 5: return mdw.off(n, c, d, h, w);
39         default: assert(!"Invalid tensor dimension in pooling");
40     }
41     return 0;
42 }
43 
44 using namespace nstl;
45 
46 template <data_type_t data_type, data_type_t acc_type>
execute_forward(const exec_ctx_t & ctx) const47 status_t ref_pooling_fwd_t<data_type, acc_type>::execute_forward(
48         const exec_ctx_t &ctx) const {
49 
50     status_t status = status::success;
51     auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
52     auto dst = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DST, status);
53     CHECK(status);
54     auto ws = CTX_OUT_CLEAN_MEM(unsigned char *, DNNL_ARG_WORKSPACE, status);
55     CHECK(status);
56 
57     const memory_desc_wrapper src_d(pd()->src_md());
58     const memory_desc_wrapper dst_d(pd()->dst_md());
59     const memory_desc_wrapper ws_d(pd()->workspace_md());
60 
61     const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
62     if (ws) assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
63 
64     const auto alg = pd()->desc()->alg_kind;
65     const dim_t MB = pd()->MB();
66     const dim_t OC = pd()->OC();
67     const dim_t OD = pd()->OD();
68     const dim_t OH = pd()->OH();
69     const dim_t OW = pd()->OW();
70     const dim_t ID = pd()->ID();
71     const dim_t IH = pd()->IH();
72     const dim_t IW = pd()->IW();
73     const dim_t KD = pd()->KD();
74     const dim_t KH = pd()->KH();
75     const dim_t KW = pd()->KW();
76     const dim_t SD = pd()->KSD();
77     const dim_t SH = pd()->KSH();
78     const dim_t SW = pd()->KSW();
79     const dim_t padF = pd()->padFront();
80     const dim_t padT = pd()->padT();
81     const dim_t padL = pd()->padL();
82     const dim_t DD = pd()->KDD();
83     const dim_t DH = pd()->KDH();
84     const dim_t DW = pd()->KDW();
85 
86     auto set_ws = [=](dim_t mb, dim_t oc, dim_t od, dim_t oh, dim_t ow,
87                           dim_t value) {
88         if (ws) {
89             const auto off = get_offset(ws_d, mb, oc, od, oh, ow);
90             if (ws_dt == data_type::u8) {
91                 assert(0 <= value
92                         && value <= numeric_limits<typename prec_traits<
93                                         data_type::u8>::type>::max());
94                 ws[off] = value;
95             } else
96                 reinterpret_cast<int *>(ws)[off] = value;
97         }
98     };
99 
100     auto ker_max = [=](float &d, dim_t mb, dim_t oc, dim_t od, dim_t oh,
101                            dim_t ow) {
102         set_ws(mb, oc, od, oh, ow, 0);
103         for (dim_t kd = 0; kd < KD; ++kd) {
104             const dim_t id = od * SD - padF + kd * (DD + 1);
105             if (id < 0 || id >= ID) continue;
106             for (dim_t kh = 0; kh < KH; ++kh) {
107                 const dim_t ih = oh * SH - padT + kh * (DH + 1);
108                 if (ih < 0 || ih >= IH) continue;
109                 for (dim_t kw = 0; kw < KW; ++kw) {
110                     const dim_t iw = ow * SW - padL + kw * (DW + 1);
111                     if (iw < 0 || iw >= IW) continue;
112 
113                     const auto off = get_offset(src_d, mb, oc, id, ih, iw);
114                     auto s = src[off];
115                     if (s > d) {
116                         d = s;
117                         set_ws(mb, oc, od, oh, ow, (kd * KH + kh) * KW + kw);
118                     }
119                 }
120             }
121         }
122     };
123 
124     auto ker_avg = [=](float &d, dim_t mb, dim_t oc, dim_t od, dim_t oh,
125                            dim_t ow) {
126         for (dim_t kd = 0; kd < KD; ++kd) {
127             const dim_t id = od * SD - padF + kd * (DD + 1);
128             if (id < 0 || id >= ID) continue;
129             for (dim_t kh = 0; kh < KH; ++kh) {
130                 const dim_t ih = oh * SH - padT + kh * (DH + 1);
131                 if (ih < 0 || ih >= IH) continue;
132                 for (dim_t kw = 0; kw < KW; ++kw) {
133                     const dim_t iw = ow * SW - padL + kw * (DW + 1);
134                     if (iw < 0 || iw >= IW) continue;
135 
136                     const auto off = get_offset(src_d, mb, oc, id, ih, iw);
137                     d += src[off];
138                 }
139             }
140         }
141         int num_summands;
142         if (alg == alg_kind::pooling_avg_include_padding)
143             num_summands = KW * KH * KD;
144         else {
145             auto id_start = od * SD - padF;
146             auto ih_start = oh * SH - padT;
147             auto iw_start = ow * SW - padL;
148             auto id_end = od * SD - padF + (KD - 1) * DD + KD;
149             auto ih_end = oh * SH - padT + (KH - 1) * DH + KH;
150             auto iw_end = ow * SW - padL + (KW - 1) * DW + KW;
151 
152             auto id_start_excluded
153                     = id_start < 0 ? (0 - id_start - 1) / (DD + 1) + 1 : 0;
154             auto ih_start_excluded
155                     = ih_start < 0 ? (0 - ih_start - 1) / (DH + 1) + 1 : 0;
156             auto iw_start_excluded
157                     = iw_start < 0 ? (0 - iw_start - 1) / (DW + 1) + 1 : 0;
158             auto id_end_excluded
159                     = id_end > ID ? (id_end - ID - 1) / (DD + 1) + 1 : 0;
160             auto ih_end_excluded
161                     = ih_end > IH ? (ih_end - IH - 1) / (DH + 1) + 1 : 0;
162             auto iw_end_excluded
163                     = iw_end > IW ? (iw_end - IW - 1) / (DW + 1) + 1 : 0;
164 
165             num_summands = (KD - id_start_excluded - id_end_excluded)
166                     * (KH - ih_start_excluded - ih_end_excluded)
167                     * (KW - iw_start_excluded - iw_end_excluded);
168         }
169         d /= num_summands;
170     };
171 
172     const bool is_max_pool = alg == alg_kind::pooling_max;
173 
174     float base_res
175             = is_max_pool ? (float)numeric_limits<data_t>::lowest() : 0.f;
176     using ker_t
177             = std::function<void(float &, dim_t, dim_t, dim_t, dim_t, dim_t)>;
178     ker_t kernel = is_max_pool ? (ker_t)ker_max : (ker_t)ker_avg;
179 
180     parallel_nd(MB, OC, OD, OH, OW,
181             [&](dim_t mb, dim_t oc, dim_t od, dim_t oh, dim_t ow) {
182                 auto data_p_off = get_offset(dst_d, mb, oc, od, oh, ow);
183                 auto data_l_off
184                         = (((mb * OC + oc) * OD + od) * OH + oh) * OW + ow;
185                 float res = base_res;
186                 kernel(res, mb, oc, od, oh, ow);
187 
188                 ref_post_ops_t::args_t args;
189                 args.ctx = &ctx;
190                 args.l_offset = data_l_off;
191                 args.dst_md = pd()->dst_md();
192                 ref_post_ops->execute(res, args);
193 
194                 dst[data_p_off] = cpu::saturate_and_round<data_t>(res);
195             });
196 
197     return status::success;
198 }
199 
200 template <data_type_t data_type>
execute_backward(const exec_ctx_t & ctx) const201 status_t ref_pooling_bwd_t<data_type>::execute_backward(
202         const exec_ctx_t &ctx) const {
203 
204     status_t status = status::success;
205 
206     const auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
207     const auto ws = CTX_IN_MEM(const unsigned char *, DNNL_ARG_WORKSPACE);
208     auto diff_src = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DIFF_SRC, status);
209     CHECK(status);
210 
211     const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
212     const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
213     const memory_desc_wrapper ws_d(pd()->workspace_md());
214 
215     const auto alg = pd()->desc()->alg_kind;
216     const dim_t MB = pd()->MB();
217     const dim_t OC = pd()->OC();
218     const dim_t OD = pd()->OD();
219     const dim_t OH = pd()->OH();
220     const dim_t OW = pd()->OW();
221     const dim_t ID = pd()->ID();
222     const dim_t IH = pd()->IH();
223     const dim_t IW = pd()->IW();
224     const dim_t KD = pd()->KD();
225     const dim_t KH = pd()->KH();
226     const dim_t KW = pd()->KW();
227     const dim_t SD = pd()->KSD();
228     const dim_t SH = pd()->KSH();
229     const dim_t SW = pd()->KSW();
230     const dim_t padF = pd()->padFront();
231     const dim_t padT = pd()->padT();
232     const dim_t padL = pd()->padL();
233     const dim_t DD = pd()->KDD();
234     const dim_t DH = pd()->KDH();
235     const dim_t DW = pd()->KDW();
236 
237     auto ker_zero = [=](dim_t mb, dim_t oc) {
238         for_(dim_t id = 0; id < ID; ++id)
239         for_(dim_t ih = 0; ih < IH; ++ih)
240         for (dim_t iw = 0; iw < IW; ++iw) {
241             const auto off = get_offset(diff_src_d, mb, oc, id, ih, iw);
242             diff_src[off] = data_type_t(0);
243         }
244     };
245 
246     auto ker_max = [=](dim_t mb, dim_t oc, dim_t od, dim_t oh, dim_t ow) {
247         const auto ws_off = get_offset(ws_d, mb, oc, od, oh, ow);
248         const int index = ws_d.data_type() == data_type::u8
249                 ? (int)ws[ws_off]
250                 : ((int *)ws)[ws_off];
251         const dim_t kd = (index / KW) / KH;
252         const dim_t kh = (index / KW) % KH;
253         const dim_t kw = index % KW;
254         const dim_t id = od * SD - padF + kd * (DD + 1);
255         const dim_t ih = oh * SH - padT + kh * (DH + 1);
256         const dim_t iw = ow * SW - padL + kw * (DW + 1);
257 
258         // If padding area could fit the kernel,
259         // then input displacement would be out of bounds.
260         // No need to back propagate there as padding is
261         // virtual in pooling_max case.
262         if (id < 0 || id >= ID) return;
263         if (ih < 0 || ih >= IH) return;
264         if (iw < 0 || iw >= IW) return;
265 
266         const auto d_src_off = get_offset(diff_src_d, mb, oc, id, ih, iw);
267         const auto d_dst_off = get_offset(diff_dst_d, mb, oc, od, oh, ow);
268         diff_src[d_src_off] += diff_dst[d_dst_off];
269     };
270 
271     auto ker_avg = [=](dim_t mb, dim_t oc, dim_t od, dim_t oh, dim_t ow) {
272         int num_summands;
273         if (alg == alg_kind::pooling_avg_include_padding)
274             num_summands = KW * KH * KD;
275         else {
276             auto id_start = od * SD - padF;
277             auto ih_start = oh * SH - padT;
278             auto iw_start = ow * SW - padL;
279             auto id_end = od * SD - padF + (KD - 1) * DD + KD;
280             auto ih_end = oh * SH - padT + (KH - 1) * DH + KH;
281             auto iw_end = ow * SW - padL + (KW - 1) * DW + KW;
282 
283             auto id_start_excluded
284                     = id_start < 0 ? (0 - id_start - 1) / (DD + 1) + 1 : 0;
285             auto ih_start_excluded
286                     = ih_start < 0 ? (0 - ih_start - 1) / (DH + 1) + 1 : 0;
287             auto iw_start_excluded
288                     = iw_start < 0 ? (0 - iw_start - 1) / (DW + 1) + 1 : 0;
289             auto id_end_excluded
290                     = id_end > ID ? (id_end - ID - 1) / (DD + 1) + 1 : 0;
291             auto ih_end_excluded
292                     = ih_end > IH ? (ih_end - IH - 1) / (DH + 1) + 1 : 0;
293             auto iw_end_excluded
294                     = iw_end > IW ? (iw_end - IW - 1) / (DW + 1) + 1 : 0;
295 
296             num_summands = (KD - id_start_excluded - id_end_excluded)
297                     * (KH - ih_start_excluded - ih_end_excluded)
298                     * (KW - iw_start_excluded - iw_end_excluded);
299         }
300         for (dim_t kd = 0; kd < KD; ++kd) {
301             const dim_t id = od * SD - padF + kd * (DD + 1);
302             if (id < 0 || id >= ID) continue;
303             for (dim_t kh = 0; kh < KH; ++kh) {
304                 const dim_t ih = oh * SH - padT + kh * (DH + 1);
305                 if (ih < 0 || ih >= IH) continue;
306                 for (dim_t kw = 0; kw < KW; ++kw) {
307                     const dim_t iw = ow * SW - padL + kw * (DW + 1);
308                     if (iw < 0 || iw >= IW) continue;
309 
310                     const auto d_src_off
311                             = get_offset(diff_src_d, mb, oc, id, ih, iw);
312                     const auto d_dst_off
313                             = get_offset(diff_dst_d, mb, oc, od, oh, ow);
314                     diff_src[d_src_off] += diff_dst[d_dst_off] / num_summands;
315                 }
316             }
317         }
318     };
319 
320     dim_t ow_start
321             = max(dim_t(0), utils::div_up(padL - ((KW - 1) * DW + KW) + 1, SW));
322     dim_t ow_end = min(OW, 1 + (padL + IW - 1) / SW);
323 
324     dim_t oh_start
325             = max(dim_t(0), utils::div_up(padT - ((KH - 1) * DH + KH) + 1, SH));
326     dim_t oh_end = min(OH, 1 + (padT + IH - 1) / SH);
327 
328     dim_t od_start
329             = max(dim_t(0), utils::div_up(padF - ((KD - 1) * DD + KD) + 1, SD));
330     dim_t od_end = min(OD, 1 + (padF + ID - 1) / SD);
331 
332     using ker_t = std::function<void(dim_t, dim_t, dim_t, dim_t, dim_t)>;
333     ker_t kernel
334             = alg == alg_kind::pooling_max ? (ker_t)ker_max : (ker_t)ker_avg;
335 
336     parallel_nd(MB, OC, [&](dim_t mb, dim_t oc) {
337         ker_zero(mb, oc);
338         for_(dim_t od = od_start; od < od_end; ++od)
339         for_(dim_t oh = oh_start; oh < oh_end; ++oh)
340         for (dim_t ow = ow_start; ow < ow_end; ++ow) {
341             kernel(mb, oc, od, oh, ow);
342         }
343     });
344 
345     return status::success;
346 }
347 
// Explicit instantiations of the reference forward pooling primitive.
// The second template argument, where given, is acc_type: f32 for bf16 and
// s32 for the s8/u8 variants.
template struct ref_pooling_fwd_t<data_type::f32>;
template struct ref_pooling_fwd_t<data_type::s32>;
template struct ref_pooling_fwd_t<data_type::bf16, data_type::f32>;
template struct ref_pooling_fwd_t<data_type::s8, data_type::s32>;
template struct ref_pooling_fwd_t<data_type::u8, data_type::s32>;

// Explicit instantiations of the reference backward pooling primitive
// (f32 and bf16 only).
template struct ref_pooling_bwd_t<data_type::f32>;
template struct ref_pooling_bwd_t<data_type::bf16>;
356 } // namespace cpu
357 } // namespace impl
358 } // namespace dnnl
359 
360 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
361