1 /*******************************************************************************
2 * Copyright 2016-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef COMMON_POOLING_PD_HPP
18 #define COMMON_POOLING_PD_HPP
19 
20 #include "oneapi/dnnl/dnnl.h"
21 
22 #include "c_types_map.hpp"
23 #include "primitive_desc.hpp"
24 #include "type_helpers.hpp"
25 #include "utils.hpp"
26 
27 namespace dnnl {
28 namespace impl {
29 
30 struct pooling_fwd_pd_t;
31 
32 struct pooling_pd_t : public primitive_desc_t {
33     static constexpr auto base_pkind = primitive_kind::pooling_v2;
34 
descdnnl::impl::pooling_pd_t35     const pooling_v2_desc_t *desc() const { return &desc_; }
op_descdnnl::impl::pooling_pd_t36     const op_desc_t *op_desc() const override {
37         return reinterpret_cast<const op_desc_t *>(this->desc());
38     }
39 
querydnnl::impl::pooling_pd_t40     status_t query(query_t what, int idx, void *result) const override {
41         switch (what) {
42             case query::prop_kind:
43                 *(prop_kind_t *)result = desc()->prop_kind;
44                 break;
45             case query::pooling_d:
46                 *(const pooling_desc_t **)result
47                         = reinterpret_cast<const pooling_desc_t *>(desc());
48                 break;
49             case query::pooling_v2_d:
50                 *(const pooling_v2_desc_t **)result = desc();
51                 break;
52             case query::primitive_kind:
53                 *(primitive_kind_t *)result = desc_.primitive_kind;
54                 break;
55             default: return primitive_desc_t::query(what, idx, result);
56         }
57         return status::success;
58     }
59 
60     /* common pooling aux functions */
61 
MBdnnl::impl::pooling_pd_t62     dim_t MB() const { return src_desc().dims[0]; }
ICdnnl::impl::pooling_pd_t63     dim_t IC() const { return src_desc().dims[1]; }
OCdnnl::impl::pooling_pd_t64     dim_t OC() const { return IC(); }
65 
IDdnnl::impl::pooling_pd_t66     dim_t ID() const { return ndims() >= 5 ? src_desc().dims[ndims() - 3] : 1; }
IHdnnl::impl::pooling_pd_t67     dim_t IH() const { return ndims() >= 4 ? src_desc().dims[ndims() - 2] : 1; }
IWdnnl::impl::pooling_pd_t68     dim_t IW() const { return src_desc().dims[ndims() - 1]; }
69 
ODdnnl::impl::pooling_pd_t70     dim_t OD() const { return ndims() >= 5 ? dst_desc().dims[ndims() - 3] : 1; }
OHdnnl::impl::pooling_pd_t71     dim_t OH() const { return ndims() >= 4 ? dst_desc().dims[ndims() - 2] : 1; }
OWdnnl::impl::pooling_pd_t72     dim_t OW() const { return dst_desc().dims[ndims() - 1]; }
73 
KDdnnl::impl::pooling_pd_t74     dim_t KD() const { return ndims() >= 5 ? desc_.kernel[ndims() - 5] : 1; }
KHdnnl::impl::pooling_pd_t75     dim_t KH() const { return ndims() >= 4 ? desc_.kernel[ndims() - 4] : 1; }
KWdnnl::impl::pooling_pd_t76     dim_t KW() const { return desc_.kernel[ndims() - 3]; }
77 
KSDdnnl::impl::pooling_pd_t78     dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; }
KSHdnnl::impl::pooling_pd_t79     dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; }
KSWdnnl::impl::pooling_pd_t80     dim_t KSW() const { return desc_.strides[ndims() - 3]; }
81 
KDDdnnl::impl::pooling_pd_t82     dim_t KDD() const {
83         return is_pooling_v2()
84                 ? (ndims() >= 5 ? desc_.dilation[ndims() - 5] : 0)
85                 : 0;
86     }
KDHdnnl::impl::pooling_pd_t87     dim_t KDH() const {
88         return is_pooling_v2()
89                 ? (ndims() >= 4 ? desc_.dilation[ndims() - 4] : 0)
90                 : 0;
91     }
KDWdnnl::impl::pooling_pd_t92     dim_t KDW() const {
93         return is_pooling_v2() ? desc_.dilation[ndims() - 3] : 0;
94     }
95 
padFrontdnnl::impl::pooling_pd_t96     dim_t padFront() const {
97         return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0;
98     }
padBackdnnl::impl::pooling_pd_t99     dim_t padBack() const {
100         return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0;
101     }
padTdnnl::impl::pooling_pd_t102     dim_t padT() const {
103         return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0;
104     }
padBdnnl::impl::pooling_pd_t105     dim_t padB() const {
106         return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0;
107     }
padLdnnl::impl::pooling_pd_t108     dim_t padL() const { return desc_.padding[0][ndims() - 3]; }
padRdnnl::impl::pooling_pd_t109     dim_t padR() const { return desc_.padding[1][ndims() - 3]; }
110 
ndimsdnnl::impl::pooling_pd_t111     int ndims() const { return src_desc().ndims; }
spatial_ndimsdnnl::impl::pooling_pd_t112     int spatial_ndims() const { return ndims() - 2; }
113 
is_pooling_v2dnnl::impl::pooling_pd_t114     bool is_pooling_v2() const {
115         return desc_.primitive_kind == primitive_kind::pooling_v2;
116     }
is_dilateddnnl::impl::pooling_pd_t117     bool is_dilated() const { return KDD() != 0 || KDH() != 0 || KDW() != 0; }
118 
has_zero_dim_memorydnnl::impl::pooling_pd_t119     bool has_zero_dim_memory() const {
120         return memory_desc_wrapper(src_desc()).has_zero_dim();
121     }
122 
is_fwddnnl::impl::pooling_pd_t123     bool is_fwd() const {
124         return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
125                 prop_kind::forward_inference);
126     }
127 
invariant_src_mddnnl::impl::pooling_pd_t128     const memory_desc_t *invariant_src_md() const {
129         return is_fwd() ? src_md() : diff_src_md();
130     }
131 
invariant_dst_mddnnl::impl::pooling_pd_t132     const memory_desc_t *invariant_dst_md() const {
133         return is_fwd() ? dst_md() : diff_dst_md();
134     }
135 
136 protected:
137     pooling_v2_desc_t desc_;
138     const pooling_fwd_pd_t *hint_fwd_pd_;
139 
140     memory_desc_t ws_md_;
141 
pooling_pd_tdnnl::impl::pooling_pd_t142     pooling_pd_t(const pooling_v2_desc_t *adesc, const primitive_attr_t *attr,
143             const pooling_fwd_pd_t *hint_fwd_pd)
144         : primitive_desc_t(attr, base_pkind)
145         , desc_(cast_pool_v1_to_v2(*adesc))
146         , hint_fwd_pd_(hint_fwd_pd)
147         , ws_md_() {}
148 
init_default_wsdnnl::impl::pooling_pd_t149     void init_default_ws(data_type_t dt = data_type::undef) {
150         ws_md_ = is_fwd() ? *dst_md() : *diff_dst_md();
151         ws_md_.data_type = (dt != data_type::undef) ? dt : indices_data_type();
152     }
153 
indices_data_typednnl::impl::pooling_pd_t154     data_type_t indices_data_type() const {
155         /* the simplest way to express 256... */
156         const int u8_max = nstl::numeric_limits<
157                 typename prec_traits<data_type::u8>::type>::max();
158         return utils::array_product(desc()->kernel, spatial_ndims()) <= u8_max
159                 ? data_type::u8
160                 : data_type::s32;
161     }
162 
163 private:
src_descdnnl::impl::pooling_pd_t164     const memory_desc_t &src_desc() const {
165         return is_fwd() ? desc_.src_desc : desc_.diff_src_desc;
166     }
dst_descdnnl::impl::pooling_pd_t167     const memory_desc_t &dst_desc() const {
168         return is_fwd() ? desc_.dst_desc : desc_.diff_dst_desc;
169     }
170 
cast_pool_v1_to_v2dnnl::impl::pooling_pd_t171     pooling_v2_desc_t cast_pool_v1_to_v2(
172             const pooling_v2_desc_t &pool_desc) const {
173         if (pool_desc.primitive_kind == primitive_kind::pooling_v2)
174             return pool_desc;
175 
176         pooling_v2_desc_t pool_v2_desc;
177         pool_v2_desc.primitive_kind = primitive_kind::pooling;
178         pool_v2_desc.prop_kind = pool_desc.prop_kind;
179         pool_v2_desc.alg_kind = pool_desc.alg_kind;
180         pool_v2_desc.src_desc = pool_desc.src_desc;
181         pool_v2_desc.diff_src_desc = pool_desc.diff_src_desc;
182         pool_v2_desc.dst_desc = pool_desc.dst_desc;
183         pool_v2_desc.diff_dst_desc = pool_desc.diff_dst_desc;
184         utils::array_copy(
185                 pool_v2_desc.strides, pool_desc.strides, DNNL_MAX_NDIMS);
186         utils::array_copy(
187                 pool_v2_desc.kernel, pool_desc.kernel, DNNL_MAX_NDIMS);
188         utils::array_copy(
189                 pool_v2_desc.padding[0], pool_desc.padding[0], DNNL_MAX_NDIMS);
190         utils::array_copy(
191                 pool_v2_desc.padding[1], pool_desc.padding[1], DNNL_MAX_NDIMS);
192         utils::array_copy(
193                 pool_v2_desc.kernel, pool_desc.kernel, DNNL_MAX_NDIMS);
194         utils::array_copy(
195                 pool_v2_desc.kernel, pool_desc.kernel, DNNL_MAX_NDIMS);
196         utils::array_set(pool_v2_desc.dilation, 0, DNNL_MAX_NDIMS);
197         pool_v2_desc.accum_data_type = pool_desc.accum_data_type;
198 
199         return pool_v2_desc;
200     }
201 };
202 
203 struct pooling_fwd_pd_t : public pooling_pd_t {
204     typedef pooling_fwd_pd_t base_class;
205     typedef pooling_fwd_pd_t hint_class;
206 
arg_usagednnl::impl::pooling_fwd_pd_t207     arg_usage_t arg_usage(int arg) const override {
208         if (arg == DNNL_ARG_SRC) return arg_usage_t::input;
209 
210         if (arg == DNNL_ARG_DST) return arg_usage_t::output;
211 
212         if (arg == DNNL_ARG_WORKSPACE && (!types::is_zero_md(workspace_md())))
213             return arg_usage_t::output;
214 
215         return primitive_desc_t::arg_usage(arg);
216     }
217 
arg_mddnnl::impl::pooling_fwd_pd_t218     const memory_desc_t *arg_md(int arg) const override {
219         switch (arg) {
220             case DNNL_ARG_SRC: return src_md(0);
221             case DNNL_ARG_DST: return dst_md(0);
222             default: return pooling_pd_t::arg_md(arg);
223         }
224     }
225 
src_mddnnl::impl::pooling_fwd_pd_t226     const memory_desc_t *src_md(int index = 0) const override {
227         return index == 0 ? &src_md_ : &glob_zero_md;
228     }
dst_mddnnl::impl::pooling_fwd_pd_t229     const memory_desc_t *dst_md(int index = 0) const override {
230         return index == 0 ? &dst_md_ : &glob_zero_md;
231     }
workspace_mddnnl::impl::pooling_fwd_pd_t232     const memory_desc_t *workspace_md(int index = 0) const override {
233         return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_
234                                                          : &glob_zero_md;
235     }
236 
n_inputsdnnl::impl::pooling_fwd_pd_t237     int n_inputs() const override { return 1 + n_binary_po_inputs(); }
n_outputsdnnl::impl::pooling_fwd_pd_t238     int n_outputs() const override {
239         return 1 + (!types::is_zero_md(workspace_md()));
240     }
241 
hint_mdsdnnl::impl::pooling_fwd_pd_t242     std::vector<memory_desc_t> hint_mds(bool is_hint) const override {
243         if (!is_hint) return {};
244         return {*dst_md(0)};
245     }
246 
247 protected:
248     memory_desc_t src_md_;
249     memory_desc_t dst_md_;
250 
pooling_fwd_pd_tdnnl::impl::pooling_fwd_pd_t251     pooling_fwd_pd_t(const pooling_v2_desc_t *adesc,
252             const primitive_attr_t *attr, const pooling_fwd_pd_t *hint_fwd_pd)
253         : pooling_pd_t(adesc, attr, hint_fwd_pd)
254         , src_md_(desc_.src_desc)
255         , dst_md_(desc_.dst_desc) {}
256 
set_default_paramsdnnl::impl::pooling_fwd_pd_t257     virtual status_t set_default_params() {
258         if (dst_md()->format_kind != format_kind::any) return status::success;
259 
260         if (src_md()->format_kind != format_kind::blocked)
261             return status::unimplemented;
262 
263         return memory_desc_init_by_blocking_desc(
264                 dst_md_, src_md_.format_desc.blocking);
265     }
266 };
267 
268 struct pooling_bwd_pd_t : public pooling_pd_t {
269     typedef pooling_bwd_pd_t base_class;
270     typedef pooling_fwd_pd_t hint_class;
271 
arg_usagednnl::impl::pooling_bwd_pd_t272     arg_usage_t arg_usage(int arg) const override {
273         if (arg == DNNL_ARG_DIFF_DST) return arg_usage_t::input;
274 
275         if (arg == DNNL_ARG_DIFF_SRC) return arg_usage_t::output;
276 
277         if (arg == DNNL_ARG_WORKSPACE && (!types::is_zero_md(workspace_md())))
278             return arg_usage_t::input;
279 
280         return primitive_desc_t::arg_usage(arg);
281     }
282 
arg_mddnnl::impl::pooling_bwd_pd_t283     const memory_desc_t *arg_md(int arg) const override {
284         switch (arg) {
285             case DNNL_ARG_DIFF_SRC: return diff_src_md(0);
286             case DNNL_ARG_DIFF_DST: return diff_dst_md(0);
287             default: return pooling_pd_t::arg_md(arg);
288         }
289     }
290 
diff_src_mddnnl::impl::pooling_bwd_pd_t291     const memory_desc_t *diff_src_md(int index = 0) const override {
292         return index == 0 ? &diff_src_md_ : &glob_zero_md;
293     }
diff_dst_mddnnl::impl::pooling_bwd_pd_t294     const memory_desc_t *diff_dst_md(int index = 0) const override {
295         return index == 0 ? &diff_dst_md_ : &glob_zero_md;
296     }
workspace_mddnnl::impl::pooling_bwd_pd_t297     const memory_desc_t *workspace_md(int index = 0) const override {
298         return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_
299                                                          : &glob_zero_md;
300     }
301 
n_inputsdnnl::impl::pooling_bwd_pd_t302     int n_inputs() const override {
303         return 1 + (!types::is_zero_md(workspace_md()));
304     }
n_outputsdnnl::impl::pooling_bwd_pd_t305     int n_outputs() const override { return 1; }
306 
hint_mdsdnnl::impl::pooling_bwd_pd_t307     std::vector<memory_desc_t> hint_mds(bool is_hint) const override {
308         assert(!is_hint);
309         MAYBE_UNUSED(is_hint);
310         return hint_mds_;
311     }
312 
313 protected:
314     memory_desc_t diff_src_md_;
315     memory_desc_t diff_dst_md_;
316 
pooling_bwd_pd_tdnnl::impl::pooling_bwd_pd_t317     pooling_bwd_pd_t(const pooling_v2_desc_t *adesc,
318             const primitive_attr_t *attr, const pooling_fwd_pd_t *hint_fwd_pd)
319         : pooling_pd_t(adesc, attr, hint_fwd_pd)
320         , diff_src_md_(desc_.diff_src_desc)
321         , diff_dst_md_(desc_.diff_dst_desc) {
322         if (hint_fwd_pd_)
323             hint_mds_ = hint_fwd_pd_->hint_mds(true /* is_hint */);
324     }
325 
set_default_paramsdnnl::impl::pooling_bwd_pd_t326     virtual status_t set_default_params() {
327         if (diff_dst_md()->format_kind == format_kind::any) {
328             status_t status = status::success;
329             if (hint_fwd_pd_)
330                 status = memory_desc_init_by_md_and_dt(diff_dst_md_,
331                         hint_mds(false /* is_hint */)[0],
332                         diff_dst_md_.data_type);
333             else
334                 status = memory_desc_init_by_strides(diff_dst_md_, nullptr);
335             if (status != status::success) return status;
336         }
337 
338         if (diff_src_md()->format_kind != format_kind::any)
339             return status::success;
340 
341         if (diff_dst_md()->format_kind != format_kind::blocked)
342             return status::unimplemented;
343 
344         return memory_desc_init_by_blocking_desc(
345                 diff_src_md_, diff_dst_md_.format_desc.blocking);
346     }
347 
348 private:
349     std::vector<memory_desc_t> hint_mds_;
350 };
351 
352 } // namespace impl
353 } // namespace dnnl
354 
355 #endif
356 
357 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
358