1 /******************************************************************************* 2 * Copyright 2018-2021 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 #ifndef COMMON_DECONVOLUTION_PD_HPP 18 #define COMMON_DECONVOLUTION_PD_HPP 19 20 #include "oneapi/dnnl/dnnl.h" 21 22 #include "c_types_map.hpp" 23 #include "convolution_pd.hpp" 24 #include "primitive_desc.hpp" 25 #include "utils.hpp" 26 27 namespace dnnl { 28 namespace impl { 29 30 struct deconvolution_fwd_pd_t; 31 32 struct deconvolution_pd_t : public primitive_desc_t { 33 static constexpr auto base_pkind = primitive_kind::deconvolution; 34 descdnnl::impl::deconvolution_pd_t35 const deconvolution_desc_t *desc() const { return &desc_; } op_descdnnl::impl::deconvolution_pd_t36 const op_desc_t *op_desc() const override { 37 return reinterpret_cast<const op_desc_t *>(this->desc()); 38 } 39 querydnnl::impl::deconvolution_pd_t40 status_t query(query_t what, int idx, void *result) const override { 41 switch (what) { 42 case query::prop_kind: 43 *(prop_kind_t *)result = desc()->prop_kind; 44 break; 45 case pkind_traits<base_pkind>::query_d: 46 *(const deconvolution_desc_t **)result = desc(); 47 break; 48 default: return primitive_desc_t::query(what, idx, result); 49 } 50 return status::success; 51 } 52 53 /* common deconv aux functions (note that conv_desc_t == deconv_desc_t) */ 54 MBdnnl::impl::deconvolution_pd_t55 dim_t MB() const { return invariant_src_md()->dims[0]; } 56 ICdnnl::impl::deconvolution_pd_t57 dim_t IC() const { return invariant_src_md()->dims[1]; } OCdnnl::impl::deconvolution_pd_t58 dim_t OC() const { return invariant_dst_md()->dims[1]; } Gdnnl::impl::deconvolution_pd_t59 dim_t G() const { return with_groups() ? invariant_wei_md()->dims[0] : 1; } 60 IDdnnl::impl::deconvolution_pd_t61 dim_t ID() const { 62 return ndims() >= 5 ? invariant_src_md()->dims[ndims() - 3] : 1; 63 } IHdnnl::impl::deconvolution_pd_t64 dim_t IH() const { 65 return ndims() >= 4 ? invariant_src_md()->dims[ndims() - 2] : 1; 66 } IWdnnl::impl::deconvolution_pd_t67 dim_t IW() const { return invariant_src_md()->dims[ndims() - 1]; } 68 ODdnnl::impl::deconvolution_pd_t69 dim_t OD() const { 70 return ndims() >= 5 ? invariant_dst_md()->dims[ndims() - 3] : 1; 71 } OHdnnl::impl::deconvolution_pd_t72 dim_t OH() const { 73 return ndims() >= 4 ? invariant_dst_md()->dims[ndims() - 2] : 1; 74 } OWdnnl::impl::deconvolution_pd_t75 dim_t OW() const { return invariant_dst_md()->dims[ndims() - 1]; } 76 KDdnnl::impl::deconvolution_pd_t77 dim_t KD() const { 78 const int w_ndims = ndims() + with_groups(); 79 return ndims() >= 5 ? invariant_wei_md()->dims[w_ndims - 3] : 1; 80 } KHdnnl::impl::deconvolution_pd_t81 dim_t KH() const { 82 const int w_ndims = ndims() + with_groups(); 83 return ndims() >= 4 ? invariant_wei_md()->dims[w_ndims - 2] : 1; 84 } KWdnnl::impl::deconvolution_pd_t85 dim_t KW() const { 86 const int w_ndims = ndims() + with_groups(); 87 return invariant_wei_md()->dims[w_ndims - 1]; 88 } 89 KSDdnnl::impl::deconvolution_pd_t90 dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; } KSHdnnl::impl::deconvolution_pd_t91 dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; } KSWdnnl::impl::deconvolution_pd_t92 dim_t KSW() const { return desc_.strides[ndims() - 3]; } 93 KDDdnnl::impl::deconvolution_pd_t94 dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; } KDHdnnl::impl::deconvolution_pd_t95 dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 1; } KDWdnnl::impl::deconvolution_pd_t96 dim_t KDW() const { return desc_.dilates[ndims() - 3]; } 97 padFrontdnnl::impl::deconvolution_pd_t98 dim_t padFront() const { 99 return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; 100 } padBackdnnl::impl::deconvolution_pd_t101 dim_t padBack() const { 102 return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; 103 } padTdnnl::impl::deconvolution_pd_t104 dim_t padT() const { 105 return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; 106 } padBdnnl::impl::deconvolution_pd_t107 dim_t padB() const { 108 return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; 109 } padLdnnl::impl::deconvolution_pd_t110 dim_t padL() const { return desc_.padding[0][ndims() - 3]; } padRdnnl::impl::deconvolution_pd_t111 dim_t padR() const { return desc_.padding[1][ndims() - 3]; } 112 with_biasdnnl::impl::deconvolution_pd_t113 bool with_bias() const { 114 auto *bia_d = desc()->prop_kind == prop_kind::backward_weights 115 ? &desc()->diff_bias_desc 116 : &desc()->bias_desc; 117 return !memory_desc_wrapper(bia_d).is_zero(); 118 } 119 with_groupsdnnl::impl::deconvolution_pd_t120 bool with_groups() const { 121 return invariant_wei_md()->ndims == ndims() + 1; 122 } 123 ndimsdnnl::impl::deconvolution_pd_t124 int ndims() const { return invariant_src_md()->ndims; } 125 is_fwddnnl::impl::deconvolution_pd_t126 bool is_fwd() const { 127 return utils::one_of(desc_.prop_kind, prop_kind::forward_training, 128 prop_kind::forward_inference); 129 } 130 has_zero_dim_memorydnnl::impl::deconvolution_pd_t131 bool has_zero_dim_memory() const { 132 const auto s_d = memory_desc_wrapper(*invariant_src_md()); 133 const auto d_d = memory_desc_wrapper(*invariant_dst_md()); 134 return s_d.has_zero_dim() || d_d.has_zero_dim(); 135 } 136 invariant_src_mddnnl::impl::deconvolution_pd_t137 const memory_desc_t *invariant_src_md() const { 138 return desc()->prop_kind == prop_kind::backward_data ? diff_src_md() 139 : src_md(); 140 } invariant_wei_mddnnl::impl::deconvolution_pd_t141 const memory_desc_t *invariant_wei_md(int index = 0) const { 142 return desc()->prop_kind == prop_kind::backward_weights 143 ? diff_weights_md(index) 144 : weights_md(index); 145 } invariant_bia_mddnnl::impl::deconvolution_pd_t146 const memory_desc_t *invariant_bia_md() const { 147 return invariant_wei_md(1); 148 } invariant_dst_mddnnl::impl::deconvolution_pd_t149 const memory_desc_t *invariant_dst_md() const { 150 return utils::one_of(desc()->prop_kind, prop_kind::forward_inference, 151 prop_kind::forward_training) 152 ? dst_md() 153 : diff_dst_md(); 154 } 155 156 protected: 157 deconvolution_desc_t desc_; 158 const deconvolution_fwd_pd_t *hint_fwd_pd_; 159 deconvolution_pd_tdnnl::impl::deconvolution_pd_t160 deconvolution_pd_t(const deconvolution_desc_t *adesc, 161 const primitive_attr_t *attr, 162 const deconvolution_fwd_pd_t *hint_fwd_pd) 163 : primitive_desc_t(attr, base_pkind) 164 , desc_(*adesc) 165 , hint_fwd_pd_(hint_fwd_pd) {} 166 }; 167 168 struct deconvolution_fwd_pd_t : public deconvolution_pd_t { 169 typedef deconvolution_fwd_pd_t base_class; 170 typedef deconvolution_fwd_pd_t hint_class; 171 arg_usagednnl::impl::deconvolution_fwd_pd_t172 arg_usage_t arg_usage(int arg) const override { 173 if (utils::one_of(arg, DNNL_ARG_SRC, DNNL_ARG_WEIGHTS)) 174 return arg_usage_t::input; 175 176 if (arg == DNNL_ARG_BIAS && with_bias()) return arg_usage_t::input; 177 178 if (arg == DNNL_ARG_DST) return arg_usage_t::output; 179 180 return primitive_desc_t::arg_usage(arg); 181 } 182 arg_mddnnl::impl::deconvolution_fwd_pd_t183 const memory_desc_t *arg_md(int arg) const override { 184 switch (arg) { 185 case DNNL_ARG_SRC: return src_md(0); 186 case DNNL_ARG_WEIGHTS: return weights_md(0); 187 case DNNL_ARG_BIAS: return weights_md(1); 188 case DNNL_ARG_DST: return dst_md(0); 189 default: return deconvolution_pd_t::arg_md(arg); 190 } 191 } 192 src_mddnnl::impl::deconvolution_fwd_pd_t193 const memory_desc_t *src_md(int index = 0) const override { 194 return index == 0 ? &src_md_ : &glob_zero_md; 195 } dst_mddnnl::impl::deconvolution_fwd_pd_t196 const memory_desc_t *dst_md(int index = 0) const override { 197 return index == 0 ? &dst_md_ : &glob_zero_md; 198 } weights_mddnnl::impl::deconvolution_fwd_pd_t199 const memory_desc_t *weights_md(int index = 0) const override { 200 if (index == 0) return &weights_md_; 201 if (index == 1 && with_bias()) return &bias_md_; 202 return &glob_zero_md; 203 } 204 n_inputsdnnl::impl::deconvolution_fwd_pd_t205 int n_inputs() const override { 206 return 2 + with_bias() + n_binary_po_inputs(); 207 } n_outputsdnnl::impl::deconvolution_fwd_pd_t208 int n_outputs() const override { return 1; } 209 210 protected: 211 memory_desc_t src_md_; 212 memory_desc_t weights_md_; 213 memory_desc_t bias_md_; 214 memory_desc_t dst_md_; 215 deconvolution_fwd_pd_tdnnl::impl::deconvolution_fwd_pd_t216 deconvolution_fwd_pd_t(const deconvolution_desc_t *adesc, 217 const primitive_attr_t *attr, 218 const deconvolution_fwd_pd_t *hint_fwd_pd) 219 : deconvolution_pd_t(adesc, attr, hint_fwd_pd) 220 , src_md_(desc_.src_desc) 221 , weights_md_(desc_.weights_desc) 222 , bias_md_(desc_.bias_desc) 223 , dst_md_(desc_.dst_desc) {} 224 }; 225 226 struct deconvolution_bwd_data_pd_t : public deconvolution_pd_t { 227 typedef deconvolution_bwd_data_pd_t base_class; 228 typedef deconvolution_fwd_pd_t hint_class; 229 arg_usagednnl::impl::deconvolution_bwd_data_pd_t230 arg_usage_t arg_usage(int arg) const override { 231 if (utils::one_of(arg, DNNL_ARG_WEIGHTS, DNNL_ARG_DIFF_DST)) 232 return arg_usage_t::input; 233 234 if (arg == DNNL_ARG_DIFF_SRC) return arg_usage_t::output; 235 236 return primitive_desc_t::arg_usage(arg); 237 } 238 arg_mddnnl::impl::deconvolution_bwd_data_pd_t239 const memory_desc_t *arg_md(int arg) const override { 240 switch (arg) { 241 case DNNL_ARG_DIFF_SRC: return diff_src_md(0); 242 case DNNL_ARG_WEIGHTS: return weights_md(0); 243 case DNNL_ARG_BIAS: return weights_md(1); 244 case DNNL_ARG_DIFF_DST: return diff_dst_md(0); 245 default: return deconvolution_pd_t::arg_md(arg); 246 } 247 } 248 diff_src_mddnnl::impl::deconvolution_bwd_data_pd_t249 const memory_desc_t *diff_src_md(int index = 0) const override { 250 return index == 0 ? &diff_src_md_ : &glob_zero_md; 251 } diff_dst_mddnnl::impl::deconvolution_bwd_data_pd_t252 const memory_desc_t *diff_dst_md(int index = 0) const override { 253 return index == 0 ? &diff_dst_md_ : &glob_zero_md; 254 } weights_mddnnl::impl::deconvolution_bwd_data_pd_t255 const memory_desc_t *weights_md(int index = 0) const override { 256 return index == 0 ? &weights_md_ : &glob_zero_md; 257 } 258 n_inputsdnnl::impl::deconvolution_bwd_data_pd_t259 int n_inputs() const override { return 2; } n_outputsdnnl::impl::deconvolution_bwd_data_pd_t260 int n_outputs() const override { return 1; } 261 262 protected: 263 memory_desc_t diff_src_md_; 264 memory_desc_t weights_md_; 265 memory_desc_t diff_dst_md_; 266 deconvolution_bwd_data_pd_tdnnl::impl::deconvolution_bwd_data_pd_t267 deconvolution_bwd_data_pd_t(const deconvolution_desc_t *adesc, 268 const primitive_attr_t *attr, 269 const deconvolution_fwd_pd_t *hint_fwd_pd) 270 : deconvolution_pd_t(adesc, attr, hint_fwd_pd) 271 , diff_src_md_(desc_.diff_src_desc) 272 , weights_md_(desc_.weights_desc) 273 , diff_dst_md_(desc_.diff_dst_desc) {} 274 }; 275 276 struct deconvolution_bwd_weights_pd_t : public deconvolution_pd_t { 277 typedef deconvolution_bwd_weights_pd_t base_class; 278 typedef deconvolution_fwd_pd_t hint_class; 279 arg_usagednnl::impl::deconvolution_bwd_weights_pd_t280 arg_usage_t arg_usage(int arg) const override { 281 if (utils::one_of(arg, DNNL_ARG_SRC, DNNL_ARG_DIFF_DST)) 282 return arg_usage_t::input; 283 284 if (arg == DNNL_ARG_DIFF_WEIGHTS) return arg_usage_t::output; 285 286 if (arg == DNNL_ARG_DIFF_BIAS && with_bias()) 287 return arg_usage_t::output; 288 289 return primitive_desc_t::arg_usage(arg); 290 } 291 arg_mddnnl::impl::deconvolution_bwd_weights_pd_t292 const memory_desc_t *arg_md(int arg) const override { 293 switch (arg) { 294 case DNNL_ARG_SRC: return src_md(0); 295 case DNNL_ARG_DIFF_WEIGHTS: return diff_weights_md(0); 296 case DNNL_ARG_DIFF_BIAS: return diff_weights_md(1); 297 case DNNL_ARG_DIFF_DST: return diff_dst_md(0); 298 default: return deconvolution_pd_t::arg_md(arg); 299 } 300 } 301 src_mddnnl::impl::deconvolution_bwd_weights_pd_t302 const memory_desc_t *src_md(int index = 0) const override { 303 return index == 0 ? &src_md_ : &glob_zero_md; 304 } diff_dst_mddnnl::impl::deconvolution_bwd_weights_pd_t305 const memory_desc_t *diff_dst_md(int index = 0) const override { 306 return index == 0 ? &diff_dst_md_ : &glob_zero_md; 307 } diff_weights_mddnnl::impl::deconvolution_bwd_weights_pd_t308 const memory_desc_t *diff_weights_md(int index = 0) const override { 309 if (index == 0) return &diff_weights_md_; 310 if (index == 1 && with_bias()) return &diff_bias_md_; 311 return &glob_zero_md; 312 } 313 n_inputsdnnl::impl::deconvolution_bwd_weights_pd_t314 int n_inputs() const override { return 2; } n_outputsdnnl::impl::deconvolution_bwd_weights_pd_t315 int n_outputs() const override { return 1 + with_bias(); } 316 317 protected: 318 memory_desc_t src_md_; 319 memory_desc_t diff_weights_md_; 320 memory_desc_t diff_bias_md_; 321 memory_desc_t diff_dst_md_; 322 deconvolution_bwd_weights_pd_tdnnl::impl::deconvolution_bwd_weights_pd_t323 deconvolution_bwd_weights_pd_t(const deconvolution_desc_t *adesc, 324 const primitive_attr_t *attr, 325 const deconvolution_fwd_pd_t *hint_fwd_pd) 326 : deconvolution_pd_t(adesc, attr, hint_fwd_pd) 327 , src_md_(desc_.src_desc) 328 , diff_weights_md_(desc_.diff_weights_desc) 329 , diff_bias_md_(desc_.diff_bias_desc) 330 , diff_dst_md_(desc_.diff_dst_desc) {} 331 }; 332 333 } // namespace impl 334 } // namespace dnnl 335 336 #endif 337 338 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s 339