1 /*******************************************************************************
2 * Copyright 2018-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef COMMON_DECONVOLUTION_PD_HPP
18 #define COMMON_DECONVOLUTION_PD_HPP
19 
20 #include "oneapi/dnnl/dnnl.h"
21 
22 #include "c_types_map.hpp"
23 #include "convolution_pd.hpp"
24 #include "primitive_desc.hpp"
25 #include "utils.hpp"
26 
27 namespace dnnl {
28 namespace impl {
29 
30 struct deconvolution_fwd_pd_t;
31 
32 struct deconvolution_pd_t : public primitive_desc_t {
33     static constexpr auto base_pkind = primitive_kind::deconvolution;
34 
descdnnl::impl::deconvolution_pd_t35     const deconvolution_desc_t *desc() const { return &desc_; }
op_descdnnl::impl::deconvolution_pd_t36     const op_desc_t *op_desc() const override {
37         return reinterpret_cast<const op_desc_t *>(this->desc());
38     }
39 
querydnnl::impl::deconvolution_pd_t40     status_t query(query_t what, int idx, void *result) const override {
41         switch (what) {
42             case query::prop_kind:
43                 *(prop_kind_t *)result = desc()->prop_kind;
44                 break;
45             case pkind_traits<base_pkind>::query_d:
46                 *(const deconvolution_desc_t **)result = desc();
47                 break;
48             default: return primitive_desc_t::query(what, idx, result);
49         }
50         return status::success;
51     }
52 
53     /* common deconv aux functions (note that conv_desc_t == deconv_desc_t) */
54 
MBdnnl::impl::deconvolution_pd_t55     dim_t MB() const { return invariant_src_md()->dims[0]; }
56 
ICdnnl::impl::deconvolution_pd_t57     dim_t IC() const { return invariant_src_md()->dims[1]; }
OCdnnl::impl::deconvolution_pd_t58     dim_t OC() const { return invariant_dst_md()->dims[1]; }
Gdnnl::impl::deconvolution_pd_t59     dim_t G() const { return with_groups() ? invariant_wei_md()->dims[0] : 1; }
60 
IDdnnl::impl::deconvolution_pd_t61     dim_t ID() const {
62         return ndims() >= 5 ? invariant_src_md()->dims[ndims() - 3] : 1;
63     }
IHdnnl::impl::deconvolution_pd_t64     dim_t IH() const {
65         return ndims() >= 4 ? invariant_src_md()->dims[ndims() - 2] : 1;
66     }
IWdnnl::impl::deconvolution_pd_t67     dim_t IW() const { return invariant_src_md()->dims[ndims() - 1]; }
68 
ODdnnl::impl::deconvolution_pd_t69     dim_t OD() const {
70         return ndims() >= 5 ? invariant_dst_md()->dims[ndims() - 3] : 1;
71     }
OHdnnl::impl::deconvolution_pd_t72     dim_t OH() const {
73         return ndims() >= 4 ? invariant_dst_md()->dims[ndims() - 2] : 1;
74     }
OWdnnl::impl::deconvolution_pd_t75     dim_t OW() const { return invariant_dst_md()->dims[ndims() - 1]; }
76 
KDdnnl::impl::deconvolution_pd_t77     dim_t KD() const {
78         const int w_ndims = ndims() + with_groups();
79         return ndims() >= 5 ? invariant_wei_md()->dims[w_ndims - 3] : 1;
80     }
KHdnnl::impl::deconvolution_pd_t81     dim_t KH() const {
82         const int w_ndims = ndims() + with_groups();
83         return ndims() >= 4 ? invariant_wei_md()->dims[w_ndims - 2] : 1;
84     }
KWdnnl::impl::deconvolution_pd_t85     dim_t KW() const {
86         const int w_ndims = ndims() + with_groups();
87         return invariant_wei_md()->dims[w_ndims - 1];
88     }
89 
KSDdnnl::impl::deconvolution_pd_t90     dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; }
KSHdnnl::impl::deconvolution_pd_t91     dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; }
KSWdnnl::impl::deconvolution_pd_t92     dim_t KSW() const { return desc_.strides[ndims() - 3]; }
93 
KDDdnnl::impl::deconvolution_pd_t94     dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; }
KDHdnnl::impl::deconvolution_pd_t95     dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 1; }
KDWdnnl::impl::deconvolution_pd_t96     dim_t KDW() const { return desc_.dilates[ndims() - 3]; }
97 
padFrontdnnl::impl::deconvolution_pd_t98     dim_t padFront() const {
99         return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0;
100     }
padBackdnnl::impl::deconvolution_pd_t101     dim_t padBack() const {
102         return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0;
103     }
padTdnnl::impl::deconvolution_pd_t104     dim_t padT() const {
105         return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0;
106     }
padBdnnl::impl::deconvolution_pd_t107     dim_t padB() const {
108         return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0;
109     }
padLdnnl::impl::deconvolution_pd_t110     dim_t padL() const { return desc_.padding[0][ndims() - 3]; }
padRdnnl::impl::deconvolution_pd_t111     dim_t padR() const { return desc_.padding[1][ndims() - 3]; }
112 
with_biasdnnl::impl::deconvolution_pd_t113     bool with_bias() const {
114         auto *bia_d = desc()->prop_kind == prop_kind::backward_weights
115                 ? &desc()->diff_bias_desc
116                 : &desc()->bias_desc;
117         return !memory_desc_wrapper(bia_d).is_zero();
118     }
119 
with_groupsdnnl::impl::deconvolution_pd_t120     bool with_groups() const {
121         return invariant_wei_md()->ndims == ndims() + 1;
122     }
123 
ndimsdnnl::impl::deconvolution_pd_t124     int ndims() const { return invariant_src_md()->ndims; }
125 
is_fwddnnl::impl::deconvolution_pd_t126     bool is_fwd() const {
127         return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
128                 prop_kind::forward_inference);
129     }
130 
has_zero_dim_memorydnnl::impl::deconvolution_pd_t131     bool has_zero_dim_memory() const {
132         const auto s_d = memory_desc_wrapper(*invariant_src_md());
133         const auto d_d = memory_desc_wrapper(*invariant_dst_md());
134         return s_d.has_zero_dim() || d_d.has_zero_dim();
135     }
136 
invariant_src_mddnnl::impl::deconvolution_pd_t137     const memory_desc_t *invariant_src_md() const {
138         return desc()->prop_kind == prop_kind::backward_data ? diff_src_md()
139                                                              : src_md();
140     }
invariant_wei_mddnnl::impl::deconvolution_pd_t141     const memory_desc_t *invariant_wei_md(int index = 0) const {
142         return desc()->prop_kind == prop_kind::backward_weights
143                 ? diff_weights_md(index)
144                 : weights_md(index);
145     }
invariant_bia_mddnnl::impl::deconvolution_pd_t146     const memory_desc_t *invariant_bia_md() const {
147         return invariant_wei_md(1);
148     }
invariant_dst_mddnnl::impl::deconvolution_pd_t149     const memory_desc_t *invariant_dst_md() const {
150         return utils::one_of(desc()->prop_kind, prop_kind::forward_inference,
151                        prop_kind::forward_training)
152                 ? dst_md()
153                 : diff_dst_md();
154     }
155 
156 protected:
157     deconvolution_desc_t desc_;
158     const deconvolution_fwd_pd_t *hint_fwd_pd_;
159 
deconvolution_pd_tdnnl::impl::deconvolution_pd_t160     deconvolution_pd_t(const deconvolution_desc_t *adesc,
161             const primitive_attr_t *attr,
162             const deconvolution_fwd_pd_t *hint_fwd_pd)
163         : primitive_desc_t(attr, base_pkind)
164         , desc_(*adesc)
165         , hint_fwd_pd_(hint_fwd_pd) {}
166 };
167 
168 struct deconvolution_fwd_pd_t : public deconvolution_pd_t {
169     typedef deconvolution_fwd_pd_t base_class;
170     typedef deconvolution_fwd_pd_t hint_class;
171 
arg_usagednnl::impl::deconvolution_fwd_pd_t172     arg_usage_t arg_usage(int arg) const override {
173         if (utils::one_of(arg, DNNL_ARG_SRC, DNNL_ARG_WEIGHTS))
174             return arg_usage_t::input;
175 
176         if (arg == DNNL_ARG_BIAS && with_bias()) return arg_usage_t::input;
177 
178         if (arg == DNNL_ARG_DST) return arg_usage_t::output;
179 
180         return primitive_desc_t::arg_usage(arg);
181     }
182 
arg_mddnnl::impl::deconvolution_fwd_pd_t183     const memory_desc_t *arg_md(int arg) const override {
184         switch (arg) {
185             case DNNL_ARG_SRC: return src_md(0);
186             case DNNL_ARG_WEIGHTS: return weights_md(0);
187             case DNNL_ARG_BIAS: return weights_md(1);
188             case DNNL_ARG_DST: return dst_md(0);
189             default: return deconvolution_pd_t::arg_md(arg);
190         }
191     }
192 
src_mddnnl::impl::deconvolution_fwd_pd_t193     const memory_desc_t *src_md(int index = 0) const override {
194         return index == 0 ? &src_md_ : &glob_zero_md;
195     }
dst_mddnnl::impl::deconvolution_fwd_pd_t196     const memory_desc_t *dst_md(int index = 0) const override {
197         return index == 0 ? &dst_md_ : &glob_zero_md;
198     }
weights_mddnnl::impl::deconvolution_fwd_pd_t199     const memory_desc_t *weights_md(int index = 0) const override {
200         if (index == 0) return &weights_md_;
201         if (index == 1 && with_bias()) return &bias_md_;
202         return &glob_zero_md;
203     }
204 
n_inputsdnnl::impl::deconvolution_fwd_pd_t205     int n_inputs() const override {
206         return 2 + with_bias() + n_binary_po_inputs();
207     }
n_outputsdnnl::impl::deconvolution_fwd_pd_t208     int n_outputs() const override { return 1; }
209 
210 protected:
211     memory_desc_t src_md_;
212     memory_desc_t weights_md_;
213     memory_desc_t bias_md_;
214     memory_desc_t dst_md_;
215 
deconvolution_fwd_pd_tdnnl::impl::deconvolution_fwd_pd_t216     deconvolution_fwd_pd_t(const deconvolution_desc_t *adesc,
217             const primitive_attr_t *attr,
218             const deconvolution_fwd_pd_t *hint_fwd_pd)
219         : deconvolution_pd_t(adesc, attr, hint_fwd_pd)
220         , src_md_(desc_.src_desc)
221         , weights_md_(desc_.weights_desc)
222         , bias_md_(desc_.bias_desc)
223         , dst_md_(desc_.dst_desc) {}
224 };
225 
226 struct deconvolution_bwd_data_pd_t : public deconvolution_pd_t {
227     typedef deconvolution_bwd_data_pd_t base_class;
228     typedef deconvolution_fwd_pd_t hint_class;
229 
arg_usagednnl::impl::deconvolution_bwd_data_pd_t230     arg_usage_t arg_usage(int arg) const override {
231         if (utils::one_of(arg, DNNL_ARG_WEIGHTS, DNNL_ARG_DIFF_DST))
232             return arg_usage_t::input;
233 
234         if (arg == DNNL_ARG_DIFF_SRC) return arg_usage_t::output;
235 
236         return primitive_desc_t::arg_usage(arg);
237     }
238 
arg_mddnnl::impl::deconvolution_bwd_data_pd_t239     const memory_desc_t *arg_md(int arg) const override {
240         switch (arg) {
241             case DNNL_ARG_DIFF_SRC: return diff_src_md(0);
242             case DNNL_ARG_WEIGHTS: return weights_md(0);
243             case DNNL_ARG_BIAS: return weights_md(1);
244             case DNNL_ARG_DIFF_DST: return diff_dst_md(0);
245             default: return deconvolution_pd_t::arg_md(arg);
246         }
247     }
248 
diff_src_mddnnl::impl::deconvolution_bwd_data_pd_t249     const memory_desc_t *diff_src_md(int index = 0) const override {
250         return index == 0 ? &diff_src_md_ : &glob_zero_md;
251     }
diff_dst_mddnnl::impl::deconvolution_bwd_data_pd_t252     const memory_desc_t *diff_dst_md(int index = 0) const override {
253         return index == 0 ? &diff_dst_md_ : &glob_zero_md;
254     }
weights_mddnnl::impl::deconvolution_bwd_data_pd_t255     const memory_desc_t *weights_md(int index = 0) const override {
256         return index == 0 ? &weights_md_ : &glob_zero_md;
257     }
258 
n_inputsdnnl::impl::deconvolution_bwd_data_pd_t259     int n_inputs() const override { return 2; }
n_outputsdnnl::impl::deconvolution_bwd_data_pd_t260     int n_outputs() const override { return 1; }
261 
262 protected:
263     memory_desc_t diff_src_md_;
264     memory_desc_t weights_md_;
265     memory_desc_t diff_dst_md_;
266 
deconvolution_bwd_data_pd_tdnnl::impl::deconvolution_bwd_data_pd_t267     deconvolution_bwd_data_pd_t(const deconvolution_desc_t *adesc,
268             const primitive_attr_t *attr,
269             const deconvolution_fwd_pd_t *hint_fwd_pd)
270         : deconvolution_pd_t(adesc, attr, hint_fwd_pd)
271         , diff_src_md_(desc_.diff_src_desc)
272         , weights_md_(desc_.weights_desc)
273         , diff_dst_md_(desc_.diff_dst_desc) {}
274 };
275 
276 struct deconvolution_bwd_weights_pd_t : public deconvolution_pd_t {
277     typedef deconvolution_bwd_weights_pd_t base_class;
278     typedef deconvolution_fwd_pd_t hint_class;
279 
arg_usagednnl::impl::deconvolution_bwd_weights_pd_t280     arg_usage_t arg_usage(int arg) const override {
281         if (utils::one_of(arg, DNNL_ARG_SRC, DNNL_ARG_DIFF_DST))
282             return arg_usage_t::input;
283 
284         if (arg == DNNL_ARG_DIFF_WEIGHTS) return arg_usage_t::output;
285 
286         if (arg == DNNL_ARG_DIFF_BIAS && with_bias())
287             return arg_usage_t::output;
288 
289         return primitive_desc_t::arg_usage(arg);
290     }
291 
arg_mddnnl::impl::deconvolution_bwd_weights_pd_t292     const memory_desc_t *arg_md(int arg) const override {
293         switch (arg) {
294             case DNNL_ARG_SRC: return src_md(0);
295             case DNNL_ARG_DIFF_WEIGHTS: return diff_weights_md(0);
296             case DNNL_ARG_DIFF_BIAS: return diff_weights_md(1);
297             case DNNL_ARG_DIFF_DST: return diff_dst_md(0);
298             default: return deconvolution_pd_t::arg_md(arg);
299         }
300     }
301 
src_mddnnl::impl::deconvolution_bwd_weights_pd_t302     const memory_desc_t *src_md(int index = 0) const override {
303         return index == 0 ? &src_md_ : &glob_zero_md;
304     }
diff_dst_mddnnl::impl::deconvolution_bwd_weights_pd_t305     const memory_desc_t *diff_dst_md(int index = 0) const override {
306         return index == 0 ? &diff_dst_md_ : &glob_zero_md;
307     }
diff_weights_mddnnl::impl::deconvolution_bwd_weights_pd_t308     const memory_desc_t *diff_weights_md(int index = 0) const override {
309         if (index == 0) return &diff_weights_md_;
310         if (index == 1 && with_bias()) return &diff_bias_md_;
311         return &glob_zero_md;
312     }
313 
n_inputsdnnl::impl::deconvolution_bwd_weights_pd_t314     int n_inputs() const override { return 2; }
n_outputsdnnl::impl::deconvolution_bwd_weights_pd_t315     int n_outputs() const override { return 1 + with_bias(); }
316 
317 protected:
318     memory_desc_t src_md_;
319     memory_desc_t diff_weights_md_;
320     memory_desc_t diff_bias_md_;
321     memory_desc_t diff_dst_md_;
322 
deconvolution_bwd_weights_pd_tdnnl::impl::deconvolution_bwd_weights_pd_t323     deconvolution_bwd_weights_pd_t(const deconvolution_desc_t *adesc,
324             const primitive_attr_t *attr,
325             const deconvolution_fwd_pd_t *hint_fwd_pd)
326         : deconvolution_pd_t(adesc, attr, hint_fwd_pd)
327         , src_md_(desc_.src_desc)
328         , diff_weights_md_(desc_.diff_weights_desc)
329         , diff_bias_md_(desc_.diff_bias_desc)
330         , diff_dst_md_(desc_.diff_dst_desc) {}
331 };
332 
333 } // namespace impl
334 } // namespace dnnl
335 
336 #endif
337 
338 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
339