/*******************************************************************************
* Copyright 2016-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>
#include <math.h>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"

#include "cpu/simple_q10n.hpp"

#include "cpu/ref_pooling.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

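// Maps a logical (n, c, d, h, w) point to its physical offset for 3D, 4D,
// and 5D tensors; unused leading spatial coordinates are ignored.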
static inline dim_t get_offset(const memory_desc_wrapper &mdw, dim_t n, dim_t c,
        dim_t d, dim_t h, dim_t w) {
    switch (mdw.ndims()) {
        case 3: return mdw.off(n, c, w);
        case 4: return mdw.off(n, c, h, w);
        case 5: return mdw.off(n, c, d, h, w);
        default: assert(!"Invalid tensor dimension in pooling");
    }
    return 0;
}

using namespace nstl;

template <data_type_t data_type, data_type_t acc_type>
status_t ref_pooling_fwd_t<data_type, acc_type>::execute_forward(
        const exec_ctx_t &ctx) const {

    status_t status = status::success;
    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto dst = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DST, status);
    CHECK(status);
    auto ws = CTX_OUT_CLEAN_MEM(unsigned char *, DNNL_ARG_WORKSPACE, status);
    CHECK(status);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper ws_d(pd()->workspace_md());

    const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
    if (ws) assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);

    const auto alg = pd()->desc()->alg_kind;
    const dim_t MB = pd()->MB();
    const dim_t OC = pd()->OC();
    const dim_t OD = pd()->OD();
    const dim_t OH = pd()->OH();
    const dim_t OW = pd()->OW();
    const dim_t ID = pd()->ID();
    const dim_t IH = pd()->IH();
    const dim_t IW = pd()->IW();
    const dim_t KD = pd()->KD();
    const dim_t KH = pd()->KH();
    const dim_t KW = pd()->KW();
    const dim_t SD = pd()->KSD();
    const dim_t SH = pd()->KSH();
    const dim_t SW = pd()->KSW();
    const dim_t padF = pd()->padFront();
    const dim_t padT = pd()->padT();
    const dim_t padL = pd()->padL();
    const dim_t DD = pd()->KDD();
    const dim_t DH = pd()->KDH();
    const dim_t DW = pd()->KDW();

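    // Stores the linearized in-kernel index of the current maximum into the
    // workspace (u8 or s32) so the backward pass can route the gradient.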
    auto set_ws = [=](dim_t mb, dim_t oc, dim_t od, dim_t oh, dim_t ow,
                          dim_t value) {
        if (ws) {
            const auto off = get_offset(ws_d, mb, oc, od, oh, ow);
            if (ws_dt == data_type::u8) {
                assert(0 <= value
                        && value <= numeric_limits<typename prec_traits<
                                        data_type::u8>::type>::max());
                ws[off] = value;
            } else
                reinterpret_cast<int *>(ws)[off] = value;
        }
    };

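    // Max pooling: scan the (dilated) kernel window, skipping positions that
    // fall into the padding, and track both the maximum value and its
    // in-kernel index (kd, kh, kw) linearized as (kd * KH + kh) * KW + kw.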
    auto ker_max = [=](float &d, dim_t mb, dim_t oc, dim_t od, dim_t oh,
                           dim_t ow) {
        set_ws(mb, oc, od, oh, ow, 0);
        for (dim_t kd = 0; kd < KD; ++kd) {
            const dim_t id = od * SD - padF + kd * (DD + 1);
            if (id < 0 || id >= ID) continue;
            for (dim_t kh = 0; kh < KH; ++kh) {
                const dim_t ih = oh * SH - padT + kh * (DH + 1);
                if (ih < 0 || ih >= IH) continue;
                for (dim_t kw = 0; kw < KW; ++kw) {
                    const dim_t iw = ow * SW - padL + kw * (DW + 1);
                    if (iw < 0 || iw >= IW) continue;

                    const auto off = get_offset(src_d, mb, oc, id, ih, iw);
                    auto s = src[off];
                    if (s > d) {
                        d = s;
                        set_ws(mb, oc, od, oh, ow, (kd * KH + kh) * KW + kw);
                    }
                }
            }
        }
    };

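    // Average pooling: accumulate the in-bounds window elements, then divide
    // by the full kernel size (include-padding) or by the number of kernel
    // taps that actually landed inside the source tensor (exclude-padding).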
    auto ker_avg = [=](float &d, dim_t mb, dim_t oc, dim_t od, dim_t oh,
                           dim_t ow) {
        for (dim_t kd = 0; kd < KD; ++kd) {
            const dim_t id = od * SD - padF + kd * (DD + 1);
            if (id < 0 || id >= ID) continue;
            for (dim_t kh = 0; kh < KH; ++kh) {
                const dim_t ih = oh * SH - padT + kh * (DH + 1);
                if (ih < 0 || ih >= IH) continue;
                for (dim_t kw = 0; kw < KW; ++kw) {
                    const dim_t iw = ow * SW - padL + kw * (DW + 1);
                    if (iw < 0 || iw >= IW) continue;

                    const auto off = get_offset(src_d, mb, oc, id, ih, iw);
                    d += src[off];
                }
            }
        }
        int num_summands;
        if (alg == alg_kind::pooling_avg_include_padding)
            num_summands = KW * KH * KD;
        else {
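            // The *_start/*_end pairs delimit the dilated window in source
            // coordinates; each *_excluded term counts the kernel taps that
            // fall into the virtual padding on that side, stepping by the
            // dilated stride (DD + 1), (DH + 1), or (DW + 1).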
            auto id_start = od * SD - padF;
            auto ih_start = oh * SH - padT;
            auto iw_start = ow * SW - padL;
            auto id_end = od * SD - padF + (KD - 1) * DD + KD;
            auto ih_end = oh * SH - padT + (KH - 1) * DH + KH;
            auto iw_end = ow * SW - padL + (KW - 1) * DW + KW;

            auto id_start_excluded
                    = id_start < 0 ? (0 - id_start - 1) / (DD + 1) + 1 : 0;
            auto ih_start_excluded
                    = ih_start < 0 ? (0 - ih_start - 1) / (DH + 1) + 1 : 0;
            auto iw_start_excluded
                    = iw_start < 0 ? (0 - iw_start - 1) / (DW + 1) + 1 : 0;
            auto id_end_excluded
                    = id_end > ID ? (id_end - ID - 1) / (DD + 1) + 1 : 0;
            auto ih_end_excluded
                    = ih_end > IH ? (ih_end - IH - 1) / (DH + 1) + 1 : 0;
            auto iw_end_excluded
                    = iw_end > IW ? (iw_end - IW - 1) / (DW + 1) + 1 : 0;

            num_summands = (KD - id_start_excluded - id_end_excluded)
                    * (KH - ih_start_excluded - ih_end_excluded)
                    * (KW - iw_start_excluded - iw_end_excluded);
        }
        d /= num_summands;
    };

    const bool is_max_pool = alg == alg_kind::pooling_max;

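    // Max pooling seeds the accumulator with the lowest representable source
    // value; average pooling accumulates into zero.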
    float base_res
            = is_max_pool ? (float)numeric_limits<data_t>::lowest() : 0.f;
    using ker_t
            = std::function<void(float &, dim_t, dim_t, dim_t, dim_t, dim_t)>;
    ker_t kernel = is_max_pool ? (ker_t)ker_max : (ker_t)ker_avg;

    parallel_nd(MB, OC, OD, OH, OW,
            [&](dim_t mb, dim_t oc, dim_t od, dim_t oh, dim_t ow) {
                auto data_p_off = get_offset(dst_d, mb, oc, od, oh, ow);
                auto data_l_off
                        = (((mb * OC + oc) * OD + od) * OH + oh) * OW + ow;
                float res = base_res;
                kernel(res, mb, oc, od, oh, ow);

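                // Apply the attached post-ops on the f32 accumulator before
                // converting the result to the destination data type.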
                ref_post_ops_t::args_t args;
                args.ctx = &ctx;
                args.l_offset = data_l_off;
                args.dst_md = pd()->dst_md();
                ref_post_ops->execute(res, args);

                dst[data_p_off] = cpu::saturate_and_round<data_t>(res);
            });

    return status::success;
}

template <data_type_t data_type>
status_t ref_pooling_bwd_t<data_type>::execute_backward(
        const exec_ctx_t &ctx) const {

    status_t status = status::success;

    const auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    const auto ws = CTX_IN_MEM(const unsigned char *, DNNL_ARG_WORKSPACE);
    auto diff_src = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DIFF_SRC, status);
    CHECK(status);

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
    const memory_desc_wrapper ws_d(pd()->workspace_md());

    const auto alg = pd()->desc()->alg_kind;
    const dim_t MB = pd()->MB();
    const dim_t OC = pd()->OC();
    const dim_t OD = pd()->OD();
    const dim_t OH = pd()->OH();
    const dim_t OW = pd()->OW();
    const dim_t ID = pd()->ID();
    const dim_t IH = pd()->IH();
    const dim_t IW = pd()->IW();
    const dim_t KD = pd()->KD();
    const dim_t KH = pd()->KH();
    const dim_t KW = pd()->KW();
    const dim_t SD = pd()->KSD();
    const dim_t SH = pd()->KSH();
    const dim_t SW = pd()->KSW();
    const dim_t padF = pd()->padFront();
    const dim_t padT = pd()->padT();
    const dim_t padL = pd()->padL();
    const dim_t DD = pd()->KDD();
    const dim_t DH = pd()->KDH();
    const dim_t DW = pd()->KDW();

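    // diff_src is accumulated into below, so zero the whole (mb, oc) slice
    // up front.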
    auto ker_zero = [=](dim_t mb, dim_t oc) {
        for_(dim_t id = 0; id < ID; ++id)
        for_(dim_t ih = 0; ih < IH; ++ih)
        for (dim_t iw = 0; iw < IW; ++iw) {
            const auto off = get_offset(diff_src_d, mb, oc, id, ih, iw);
            diff_src[off] = data_t(0);
        }
    };

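    // Max pooling backward: decode the in-kernel index stored in the
    // workspace and route the whole gradient to the single source element
    // that produced the maximum.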
    auto ker_max = [=](dim_t mb, dim_t oc, dim_t od, dim_t oh, dim_t ow) {
        const auto ws_off = get_offset(ws_d, mb, oc, od, oh, ow);
        const int index = ws_d.data_type() == data_type::u8
                ? (int)ws[ws_off]
                : ((int *)ws)[ws_off];
        const dim_t kd = (index / KW) / KH;
        const dim_t kh = (index / KW) % KH;
        const dim_t kw = index % KW;
        const dim_t id = od * SD - padF + kd * (DD + 1);
        const dim_t ih = oh * SH - padT + kh * (DH + 1);
        const dim_t iw = ow * SW - padL + kw * (DW + 1);

        // If the kernel window fits entirely into the padding area, the
        // computed source coordinates fall out of bounds. Nothing needs to
        // be back-propagated there, since padding is virtual in the
        // pooling_max case.
        if (id < 0 || id >= ID) return;
        if (ih < 0 || ih >= IH) return;
        if (iw < 0 || iw >= IW) return;

        const auto d_src_off = get_offset(diff_src_d, mb, oc, id, ih, iw);
        const auto d_dst_off = get_offset(diff_dst_d, mb, oc, od, oh, ow);
        diff_src[d_src_off] += diff_dst[d_dst_off];
    };

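    // Average pooling backward: spread the gradient uniformly over the
    // in-bounds kernel taps; num_summands mirrors the divisor used in the
    // forward pass.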
    auto ker_avg = [=](dim_t mb, dim_t oc, dim_t od, dim_t oh, dim_t ow) {
        int num_summands;
        if (alg == alg_kind::pooling_avg_include_padding)
            num_summands = KW * KH * KD;
        else {
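            // Same excluded-tap computation as in the forward pass.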
            auto id_start = od * SD - padF;
            auto ih_start = oh * SH - padT;
            auto iw_start = ow * SW - padL;
            auto id_end = od * SD - padF + (KD - 1) * DD + KD;
            auto ih_end = oh * SH - padT + (KH - 1) * DH + KH;
            auto iw_end = ow * SW - padL + (KW - 1) * DW + KW;

            auto id_start_excluded
                    = id_start < 0 ? (0 - id_start - 1) / (DD + 1) + 1 : 0;
            auto ih_start_excluded
                    = ih_start < 0 ? (0 - ih_start - 1) / (DH + 1) + 1 : 0;
            auto iw_start_excluded
                    = iw_start < 0 ? (0 - iw_start - 1) / (DW + 1) + 1 : 0;
            auto id_end_excluded
                    = id_end > ID ? (id_end - ID - 1) / (DD + 1) + 1 : 0;
            auto ih_end_excluded
                    = ih_end > IH ? (ih_end - IH - 1) / (DH + 1) + 1 : 0;
            auto iw_end_excluded
                    = iw_end > IW ? (iw_end - IW - 1) / (DW + 1) + 1 : 0;

            num_summands = (KD - id_start_excluded - id_end_excluded)
                    * (KH - ih_start_excluded - ih_end_excluded)
                    * (KW - iw_start_excluded - iw_end_excluded);
        }
        for (dim_t kd = 0; kd < KD; ++kd) {
            const dim_t id = od * SD - padF + kd * (DD + 1);
            if (id < 0 || id >= ID) continue;
            for (dim_t kh = 0; kh < KH; ++kh) {
                const dim_t ih = oh * SH - padT + kh * (DH + 1);
                if (ih < 0 || ih >= IH) continue;
                for (dim_t kw = 0; kw < KW; ++kw) {
                    const dim_t iw = ow * SW - padL + kw * (DW + 1);
                    if (iw < 0 || iw >= IW) continue;

                    const auto d_src_off
                            = get_offset(diff_src_d, mb, oc, id, ih, iw);
                    const auto d_dst_off
                            = get_offset(diff_dst_d, mb, oc, od, oh, ow);
                    diff_src[d_src_off] += diff_dst[d_dst_off] / num_summands;
                }
            }
        }
    };

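    // Restrict the loops to output points whose receptive field overlaps the
    // source tensor; outputs whose windows lie entirely in the padding
    // contribute no gradient.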
    dim_t ow_start
            = max(dim_t(0), utils::div_up(padL - ((KW - 1) * DW + KW) + 1, SW));
    dim_t ow_end = min(OW, 1 + (padL + IW - 1) / SW);

    dim_t oh_start
            = max(dim_t(0), utils::div_up(padT - ((KH - 1) * DH + KH) + 1, SH));
    dim_t oh_end = min(OH, 1 + (padT + IH - 1) / SH);

    dim_t od_start
            = max(dim_t(0), utils::div_up(padF - ((KD - 1) * DD + KD) + 1, SD));
    dim_t od_end = min(OD, 1 + (padF + ID - 1) / SD);

    using ker_t = std::function<void(dim_t, dim_t, dim_t, dim_t, dim_t)>;
    ker_t kernel
            = alg == alg_kind::pooling_max ? (ker_t)ker_max : (ker_t)ker_avg;

    parallel_nd(MB, OC, [&](dim_t mb, dim_t oc) {
        ker_zero(mb, oc);
        for_(dim_t od = od_start; od < od_end; ++od)
        for_(dim_t oh = oh_start; oh < oh_end; ++oh)
        for (dim_t ow = ow_start; ow < ow_end; ++ow) {
            kernel(mb, oc, od, oh, ow);
        }
    });

    return status::success;
}

template struct ref_pooling_fwd_t<data_type::f32>;
template struct ref_pooling_fwd_t<data_type::s32>;
template struct ref_pooling_fwd_t<data_type::bf16, data_type::f32>;
template struct ref_pooling_fwd_t<data_type::s8, data_type::s32>;
template struct ref_pooling_fwd_t<data_type::u8, data_type::s32>;

template struct ref_pooling_bwd_t<data_type::f32>;
template struct ref_pooling_bwd_t<data_type::bf16>;
} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s