1 /******************************************************************************* 2 * Copyright 2016-2021 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 #ifndef COMMON_CONVOLUTION_PD_HPP 18 #define COMMON_CONVOLUTION_PD_HPP 19 20 #include "oneapi/dnnl/dnnl.h" 21 22 #include "c_types_map.hpp" 23 #include "primitive_desc.hpp" 24 #include "utils.hpp" 25 26 namespace dnnl { 27 namespace impl { 28 29 status_t conv_desc_init(convolution_desc_t *conv_desc, prop_kind_t prop_kind, 30 alg_kind_t alg_kind, const memory_desc_t *src_desc, 31 const memory_desc_t *weights_desc, const memory_desc_t *bias_desc, 32 const memory_desc_t *dst_desc, const dims_t strides, 33 const dims_t dilates, const dims_t padding_l, const dims_t padding_r); 34 35 memory_desc_t *conv_prop_invariant_src_d(convolution_desc_t *desc); 36 memory_desc_t *conv_prop_invariant_wei_d(convolution_desc_t *desc); 37 memory_desc_t *conv_prop_invariant_bia_d(convolution_desc_t *desc); 38 memory_desc_t *conv_prop_invariant_dst_d(convolution_desc_t *desc); 39 const memory_desc_t *conv_prop_invariant_src_d(const convolution_desc_t *desc); 40 const memory_desc_t *conv_prop_invariant_wei_d(const convolution_desc_t *desc); 41 const memory_desc_t *conv_prop_invariant_bia_d(const convolution_desc_t *desc); 42 const memory_desc_t *conv_prop_invariant_dst_d(const convolution_desc_t *desc); 43 44 struct convolution_fwd_pd_t; 45 46 struct convolution_pd_t : public primitive_desc_t { 47 static constexpr auto base_pkind = primitive_kind::convolution; 48 descdnnl::impl::convolution_pd_t49 const convolution_desc_t *desc() const { return &desc_; } op_descdnnl::impl::convolution_pd_t50 const op_desc_t *op_desc() const override { 51 return reinterpret_cast<const op_desc_t *>(this->desc()); 52 } 53 querydnnl::impl::convolution_pd_t54 status_t query(query_t what, int idx, void *result) const override { 55 switch (what) { 56 case query::prop_kind: 57 *(prop_kind_t *)result = desc()->prop_kind; 58 break; 59 case pkind_traits<base_pkind>::query_d: 60 *(const convolution_desc_t **)result = desc(); 61 break; 62 default: return primitive_desc_t::query(what, idx, result); 63 } 64 return status::success; 65 } 66 67 /* common conv aux functions */ 68 MBdnnl::impl::convolution_pd_t69 dim_t MB() const { return invariant_src_md()->dims[0]; } 70 ICdnnl::impl::convolution_pd_t71 dim_t IC() const { return invariant_src_md()->dims[1]; } OCdnnl::impl::convolution_pd_t72 dim_t OC() const { return invariant_dst_md()->dims[1]; } Gdnnl::impl::convolution_pd_t73 dim_t G() const { return with_groups() ? invariant_wei_md()->dims[0] : 1; } 74 IDdnnl::impl::convolution_pd_t75 dim_t ID() const { 76 return ndims() >= 5 ? invariant_src_md()->dims[ndims() - 3] : 1; 77 } IHdnnl::impl::convolution_pd_t78 dim_t IH() const { 79 return ndims() >= 4 ? invariant_src_md()->dims[ndims() - 2] : 1; 80 } IWdnnl::impl::convolution_pd_t81 dim_t IW() const { return invariant_src_md()->dims[ndims() - 1]; } 82 ODdnnl::impl::convolution_pd_t83 dim_t OD() const { 84 return ndims() >= 5 ? invariant_dst_md()->dims[ndims() - 3] : 1; 85 } OHdnnl::impl::convolution_pd_t86 dim_t OH() const { 87 return ndims() >= 4 ? invariant_dst_md()->dims[ndims() - 2] : 1; 88 } OWdnnl::impl::convolution_pd_t89 dim_t OW() const { return invariant_dst_md()->dims[ndims() - 1]; } 90 KDdnnl::impl::convolution_pd_t91 dim_t KD() const { 92 return ndims() >= 5 93 ? invariant_wei_md()->dims[ndims() + with_groups() - 3] 94 : 1; 95 } KHdnnl::impl::convolution_pd_t96 dim_t KH() const { 97 return ndims() >= 4 98 ? invariant_wei_md()->dims[ndims() + with_groups() - 2] 99 : 1; 100 } KWdnnl::impl::convolution_pd_t101 dim_t KW() const { 102 return invariant_wei_md()->dims[ndims() + with_groups() - 1]; 103 } 104 KSDdnnl::impl::convolution_pd_t105 dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; } KSHdnnl::impl::convolution_pd_t106 dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; } KSWdnnl::impl::convolution_pd_t107 dim_t KSW() const { return desc_.strides[ndims() - 3]; } 108 KDDdnnl::impl::convolution_pd_t109 dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; } KDHdnnl::impl::convolution_pd_t110 dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 1; } KDWdnnl::impl::convolution_pd_t111 dim_t KDW() const { return desc_.dilates[ndims() - 3]; } 112 padFrontdnnl::impl::convolution_pd_t113 dim_t padFront() const { 114 return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; 115 } padBackdnnl::impl::convolution_pd_t116 dim_t padBack() const { 117 return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; 118 } padTdnnl::impl::convolution_pd_t119 dim_t padT() const { 120 return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; 121 } padBdnnl::impl::convolution_pd_t122 dim_t padB() const { 123 return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; 124 } padLdnnl::impl::convolution_pd_t125 dim_t padL() const { return desc_.padding[0][ndims() - 3]; } padRdnnl::impl::convolution_pd_t126 dim_t padR() const { return desc_.padding[1][ndims() - 3]; } 127 ndimsdnnl::impl::convolution_pd_t128 int ndims() const { return invariant_src_md()->ndims; } 129 with_biasdnnl::impl::convolution_pd_t130 bool with_bias() const { 131 auto *bia_d = desc()->prop_kind == prop_kind::backward_weights 132 ? &desc()->diff_bias_desc 133 : &desc()->bias_desc; 134 return !memory_desc_wrapper(bia_d).is_zero(); 135 } with_groupsdnnl::impl::convolution_pd_t136 bool with_groups() const { 137 return invariant_wei_md()->ndims == ndims() + 1; 138 } 139 is_fwddnnl::impl::convolution_pd_t140 bool is_fwd() const { 141 return utils::one_of(desc_.prop_kind, prop_kind::forward_training, 142 prop_kind::forward_inference); 143 } 144 is_bwd_ddnnl::impl::convolution_pd_t145 bool is_bwd_d() const { 146 return desc_.prop_kind == prop_kind::backward_data; 147 } 148 is_bwd_wdnnl::impl::convolution_pd_t149 bool is_bwd_w() const { 150 return desc_.prop_kind == prop_kind::backward_weights; 151 } 152 has_zero_dim_memorydnnl::impl::convolution_pd_t153 bool has_zero_dim_memory() const { 154 const auto s_d = memory_desc_wrapper(*invariant_src_md()); 155 const auto d_d = memory_desc_wrapper(*invariant_dst_md()); 156 return s_d.has_zero_dim() || d_d.has_zero_dim(); 157 } 158 invariant_src_mddnnl::impl::convolution_pd_t159 const memory_desc_t *invariant_src_md() const { 160 return desc()->prop_kind == prop_kind::backward_data ? diff_src_md() 161 : src_md(); 162 } invariant_wei_mddnnl::impl::convolution_pd_t163 const memory_desc_t *invariant_wei_md(int index = 0) const { 164 return desc()->prop_kind == prop_kind::backward_weights 165 ? diff_weights_md(index) 166 : weights_md(index); 167 } invariant_bia_mddnnl::impl::convolution_pd_t168 const memory_desc_t *invariant_bia_md() const { 169 return invariant_wei_md(1); 170 } invariant_dst_mddnnl::impl::convolution_pd_t171 const memory_desc_t *invariant_dst_md() const { 172 return is_fwd() ? dst_md() : diff_dst_md(); 173 } 174 175 protected: 176 convolution_desc_t desc_; 177 const convolution_fwd_pd_t *hint_fwd_pd_; 178 convolution_pd_tdnnl::impl::convolution_pd_t179 convolution_pd_t(const convolution_desc_t *adesc, 180 const primitive_attr_t *attr, 181 const convolution_fwd_pd_t *hint_fwd_pd) 182 : primitive_desc_t(attr, base_pkind) 183 , desc_(*adesc) 184 , hint_fwd_pd_(hint_fwd_pd) {} 185 set_default_formats_common_templatednnl::impl::convolution_pd_t186 bool set_default_formats_common_template(memory_desc_t &src_md, 187 format_tag_t src_tag, memory_desc_t &wei_md, format_tag_t wei_tag, 188 memory_desc_t &dst_md, format_tag_t dst_tag, 189 memory_desc_t &bia_md) { 190 using namespace format_tag; 191 192 #define IS_OK(f) \ 193 do { \ 194 if ((f) != status::success) return false; \ 195 } while (0) 196 if (src_md.format_kind == format_kind::any 197 && !utils::one_of(src_tag, any, undef)) 198 IS_OK(memory_desc_init_by_tag(src_md, src_tag)); 199 if (dst_md.format_kind == format_kind::any 200 && !utils::one_of(dst_tag, any, undef)) 201 IS_OK(memory_desc_init_by_tag(dst_md, dst_tag)); 202 if (wei_md.format_kind == format_kind::any 203 && !utils::one_of(wei_tag, any, undef)) 204 IS_OK(memory_desc_init_by_tag(wei_md, wei_tag)); 205 if (with_bias() && bia_md.format_kind == format_kind::any) 206 IS_OK(memory_desc_init_by_tag(bia_md, x)); 207 #undef IS_OK 208 209 return true; 210 } 211 set_default_alg_kinddnnl::impl::convolution_pd_t212 bool set_default_alg_kind(alg_kind_t alg_kind) { 213 assert(utils::one_of(alg_kind, alg_kind::convolution_direct, 214 alg_kind::convolution_winograd)); 215 if (desc_.alg_kind == alg_kind::convolution_auto) 216 desc_.alg_kind = alg_kind; 217 return desc_.alg_kind == alg_kind; 218 } 219 expect_data_typesdnnl::impl::convolution_pd_t220 bool expect_data_types(data_type_t src_dt, data_type_t wei_dt, 221 data_type_t bia_dt, data_type_t dst_dt, data_type_t acc_dt) const { 222 bool ok = true 223 && (src_dt == data_type::undef 224 || invariant_src_md()->data_type == src_dt) 225 && (wei_dt == data_type::undef 226 || invariant_wei_md()->data_type == wei_dt) 227 && (dst_dt == data_type::undef 228 || invariant_dst_md()->data_type == dst_dt) 229 && (acc_dt == data_type::undef 230 || desc_.accum_data_type == acc_dt); 231 if (with_bias() && bia_dt != data_type::undef) 232 ok = ok && invariant_bia_md()->data_type == bia_dt; 233 return ok; 234 } 235 }; 236 237 struct convolution_fwd_pd_t : public convolution_pd_t { 238 typedef convolution_fwd_pd_t base_class; 239 typedef convolution_fwd_pd_t hint_class; 240 arg_usagednnl::impl::convolution_fwd_pd_t241 arg_usage_t arg_usage(int arg) const override { 242 if (utils::one_of(arg, DNNL_ARG_SRC, DNNL_ARG_WEIGHTS)) 243 return arg_usage_t::input; 244 245 if (arg == DNNL_ARG_BIAS && with_bias()) return arg_usage_t::input; 246 247 if (arg == DNNL_ARG_DST) return arg_usage_t::output; 248 249 return primitive_desc_t::arg_usage(arg); 250 } 251 arg_mddnnl::impl::convolution_fwd_pd_t252 const memory_desc_t *arg_md(int arg) const override { 253 switch (arg) { 254 case DNNL_ARG_SRC: return src_md(0); 255 case DNNL_ARG_WEIGHTS: return weights_md(0); 256 case DNNL_ARG_BIAS: return weights_md(1); 257 case DNNL_ARG_DST: return dst_md(0); 258 default: return convolution_pd_t::arg_md(arg); 259 } 260 } 261 src_mddnnl::impl::convolution_fwd_pd_t262 const memory_desc_t *src_md(int index = 0) const override { 263 return index == 0 ? &src_md_ : &glob_zero_md; 264 } dst_mddnnl::impl::convolution_fwd_pd_t265 const memory_desc_t *dst_md(int index = 0) const override { 266 return index == 0 ? &dst_md_ : &glob_zero_md; 267 } weights_mddnnl::impl::convolution_fwd_pd_t268 const memory_desc_t *weights_md(int index = 0) const override { 269 if (index == 0) return &weights_md_; 270 if (index == 1 && with_bias()) return &bias_md_; 271 return &glob_zero_md; 272 } 273 n_inputsdnnl::impl::convolution_fwd_pd_t274 int n_inputs() const override { 275 return 2 + with_bias() + attr_post_op_dw_inputs() 276 + n_binary_po_inputs(); 277 } 278 n_outputsdnnl::impl::convolution_fwd_pd_t279 int n_outputs() const override { return 1; } 280 281 protected: 282 memory_desc_t src_md_; 283 memory_desc_t weights_md_; 284 memory_desc_t bias_md_; 285 memory_desc_t dst_md_; 286 convolution_fwd_pd_tdnnl::impl::convolution_fwd_pd_t287 convolution_fwd_pd_t(const convolution_desc_t *adesc, 288 const primitive_attr_t *attr, 289 const convolution_fwd_pd_t *hint_fwd_pd) 290 : convolution_pd_t(adesc, attr, hint_fwd_pd) 291 , src_md_(desc_.src_desc) 292 , weights_md_(desc_.weights_desc) 293 , bias_md_(desc_.bias_desc) 294 , dst_md_(desc_.dst_desc) {} 295 set_default_formats_commondnnl::impl::convolution_fwd_pd_t296 bool set_default_formats_common( 297 format_tag_t src_tag, format_tag_t wei_tag, format_tag_t dst_tag) { 298 return set_default_formats_common_template(src_md_, src_tag, 299 weights_md_, wei_tag, dst_md_, dst_tag, bias_md_); 300 } 301 attr_post_op_dw_inputsdnnl::impl::convolution_fwd_pd_t302 int attr_post_op_dw_inputs() const { 303 const auto &po = attr_.post_ops_; 304 int conv = po.find(primitive_kind::convolution); 305 if (conv == -1) return 0; 306 return po.entry_[conv].depthwise_conv.bias_dt == data_type::undef ? 1 307 : 2; 308 } 309 }; 310 311 struct convolution_bwd_data_pd_t : public convolution_pd_t { 312 typedef convolution_bwd_data_pd_t base_class; 313 typedef convolution_fwd_pd_t hint_class; 314 arg_usagednnl::impl::convolution_bwd_data_pd_t315 arg_usage_t arg_usage(int arg) const override { 316 if (utils::one_of(arg, DNNL_ARG_WEIGHTS, DNNL_ARG_DIFF_DST)) 317 return arg_usage_t::input; 318 319 if (arg == DNNL_ARG_DIFF_SRC) return arg_usage_t::output; 320 321 return primitive_desc_t::arg_usage(arg); 322 } 323 arg_mddnnl::impl::convolution_bwd_data_pd_t324 const memory_desc_t *arg_md(int arg) const override { 325 switch (arg) { 326 case DNNL_ARG_DIFF_SRC: return diff_src_md(0); 327 case DNNL_ARG_WEIGHTS: return weights_md(0); 328 case DNNL_ARG_BIAS: return weights_md(1); 329 case DNNL_ARG_DIFF_DST: return diff_dst_md(0); 330 default: return convolution_pd_t::arg_md(arg); 331 } 332 } 333 diff_src_mddnnl::impl::convolution_bwd_data_pd_t334 const memory_desc_t *diff_src_md(int index = 0) const override { 335 return index == 0 ? &diff_src_md_ : &glob_zero_md; 336 } diff_dst_mddnnl::impl::convolution_bwd_data_pd_t337 const memory_desc_t *diff_dst_md(int index = 0) const override { 338 return index == 0 ? &diff_dst_md_ : &glob_zero_md; 339 } weights_mddnnl::impl::convolution_bwd_data_pd_t340 const memory_desc_t *weights_md(int index = 0) const override { 341 if (index == 0) return &weights_md_; 342 if (index == 1 && with_bias()) return &bias_md_; 343 return &glob_zero_md; 344 } 345 n_inputsdnnl::impl::convolution_bwd_data_pd_t346 int n_inputs() const override { return 2 + with_bias(); } n_outputsdnnl::impl::convolution_bwd_data_pd_t347 int n_outputs() const override { return 1; } 348 support_biasdnnl::impl::convolution_bwd_data_pd_t349 virtual bool support_bias() const { return false; } 350 351 protected: 352 memory_desc_t diff_src_md_; 353 memory_desc_t weights_md_; 354 memory_desc_t bias_md_; 355 memory_desc_t diff_dst_md_; 356 convolution_bwd_data_pd_tdnnl::impl::convolution_bwd_data_pd_t357 convolution_bwd_data_pd_t(const convolution_desc_t *adesc, 358 const primitive_attr_t *attr, 359 const convolution_fwd_pd_t *hint_fwd_pd) 360 : convolution_pd_t(adesc, attr, hint_fwd_pd) 361 , diff_src_md_(desc_.diff_src_desc) 362 , weights_md_(desc_.weights_desc) 363 , bias_md_(desc_.bias_desc) 364 , diff_dst_md_(desc_.diff_dst_desc) {} 365 set_default_formats_commondnnl::impl::convolution_bwd_data_pd_t366 bool set_default_formats_common(format_tag_t diff_src_tag, 367 format_tag_t wei_tag, format_tag_t diff_dst_tag) { 368 return set_default_formats_common_template(diff_src_md_, diff_src_tag, 369 weights_md_, wei_tag, diff_dst_md_, diff_dst_tag, bias_md_); 370 } 371 }; 372 373 struct convolution_bwd_weights_pd_t : public convolution_pd_t { 374 typedef convolution_bwd_weights_pd_t base_class; 375 typedef convolution_fwd_pd_t hint_class; 376 convolution_bwd_weights_pd_tdnnl::impl::convolution_bwd_weights_pd_t377 convolution_bwd_weights_pd_t(const convolution_desc_t *adesc, 378 const primitive_attr_t *attr, 379 const convolution_fwd_pd_t *hint_fwd_pd) 380 : convolution_pd_t(adesc, attr, hint_fwd_pd) 381 , src_md_(desc_.src_desc) 382 , diff_weights_md_(desc_.diff_weights_desc) 383 , diff_bias_md_(desc_.diff_bias_desc) 384 , diff_dst_md_(desc_.diff_dst_desc) {} 385 arg_usagednnl::impl::convolution_bwd_weights_pd_t386 arg_usage_t arg_usage(int arg) const override { 387 if (utils::one_of(arg, DNNL_ARG_SRC, DNNL_ARG_DIFF_DST)) 388 return arg_usage_t::input; 389 390 if (arg == DNNL_ARG_DIFF_WEIGHTS) return arg_usage_t::output; 391 392 if (arg == DNNL_ARG_DIFF_BIAS && with_bias()) 393 return arg_usage_t::output; 394 395 return primitive_desc_t::arg_usage(arg); 396 } 397 arg_mddnnl::impl::convolution_bwd_weights_pd_t398 const memory_desc_t *arg_md(int arg) const override { 399 switch (arg) { 400 case DNNL_ARG_SRC: return src_md(0); 401 case DNNL_ARG_DIFF_WEIGHTS: return diff_weights_md(0); 402 case DNNL_ARG_DIFF_BIAS: return diff_weights_md(1); 403 case DNNL_ARG_DIFF_DST: return diff_dst_md(0); 404 default: return convolution_pd_t::arg_md(arg); 405 } 406 } 407 src_mddnnl::impl::convolution_bwd_weights_pd_t408 const memory_desc_t *src_md(int index = 0) const override { 409 return index == 0 ? &src_md_ : &glob_zero_md; 410 } diff_dst_mddnnl::impl::convolution_bwd_weights_pd_t411 const memory_desc_t *diff_dst_md(int index = 0) const override { 412 return index == 0 ? &diff_dst_md_ : &glob_zero_md; 413 } diff_weights_mddnnl::impl::convolution_bwd_weights_pd_t414 const memory_desc_t *diff_weights_md(int index = 0) const override { 415 if (index == 0) return &diff_weights_md_; 416 if (index == 1 && with_bias()) return &diff_bias_md_; 417 return &glob_zero_md; 418 } 419 n_inputsdnnl::impl::convolution_bwd_weights_pd_t420 int n_inputs() const override { return 2; } n_outputsdnnl::impl::convolution_bwd_weights_pd_t421 int n_outputs() const override { return 1 + with_bias(); } 422 423 protected: 424 memory_desc_t src_md_; 425 memory_desc_t diff_weights_md_; 426 memory_desc_t diff_bias_md_; 427 memory_desc_t diff_dst_md_; 428 set_default_formats_commondnnl::impl::convolution_bwd_weights_pd_t429 bool set_default_formats_common(format_tag_t src_tag, 430 format_tag_t diff_wei_tag, format_tag_t diff_dst_tag) { 431 return set_default_formats_common_template(src_md_, src_tag, 432 diff_weights_md_, diff_wei_tag, diff_dst_md_, diff_dst_tag, 433 diff_bias_md_); 434 } 435 }; 436 437 } // namespace impl 438 } // namespace dnnl 439 440 #endif 441 442 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s 443