Home
last modified time | relevance | path

Searched refs:diff_weights_md (Results 1 – 25 of 156) sorted by relevance

1234567

/dports/math/onednn/oneDNN-2.5.1/src/gpu/ocl/
H A Dxe_hp_1st_bwd_convolution.cpp43 const memory_desc_wrapper weights_mdw(diff_weights_md()); in init_conf()
45 const memory_desc_wrapper bias_mdw(diff_weights_md(1)); in init_conf()
48 *diff_weights_md(1), *attr()); in init_conf()
159 auto temp_wei_md = *diff_weights_md(); in init_conf()
171 auto temp_bias_md = *diff_weights_md(1); in init_conf()
177 &temp_bias_md, diff_weights_md(1), &r_attr); in init_conf()
258 diff_weights_md(conf.with_bias ? 1 : 0)->data_type, "BIA"); in init_kernel_ctx()
274 auto temp_wei_md = *diff_weights_md(); in init_scratchpad()
283 auto temp_bias_md = *diff_weights_md(1); in init_scratchpad()
309 auto temp_wei_md = *pd()->diff_weights_md(); in execute_backward_weights()
[all …]
H A Dgemm_inner_product.hpp221 && utils::one_of(diff_weights_md()->data_type, f32, bf16) in init()
226 src_md(), diff_weights_md(), diff_dst_md()) in init()
228 src_md(), diff_weights_md(), diff_dst_md()); in init()
235 init_2d_desc(&c_md, diff_weights_md(), true); in init()
239 init_2d_desc(&c_md, diff_weights_md()); in init()
284 const auto &wmd = *this->diff_weights_md(); in wei_tr()
/dports/misc/mxnet/incubator-mxnet-1.9.0/3rdparty/mkldnn/src/gpu/ocl/
H A Dgemm_inner_product.hpp219 && utils::one_of(diff_weights_md()->data_type, f32, bf16) in init()
224 src_md(), diff_weights_md(), diff_dst_md()) in init()
226 src_md(), diff_weights_md(), diff_dst_md()); in init()
233 init_2d_desc(&c_md, diff_weights_md(), true); in init()
237 init_2d_desc(&c_md, diff_weights_md()); in init()
252 const auto &wmd = *this->diff_weights_md(); in wei_tr()
275 kernel_ctx, pd()->diff_weights_md(1)->data_type, "BIA"); in init()
/dports/misc/mxnet/incubator-mxnet-1.9.0/3rdparty/mkldnn/src/gpu/nvidia/
H A Dcudnn_conv_inner_product.hpp324 && memory_format_ok(diff_weights_md(0)) in init()
327 with_bias(), memory_format_ok(diff_weights_md(1))); in init()
371 + diff_weights_md(0)->format_desc.blocking.inner_nblks in no_blocking()
372 + diff_weights_md(1)->format_desc.blocking.inner_nblks in no_blocking()
379 diff_weights_md(1)->data_type == data_type::f32) in data_types_ok()
381 diff_weights_md(0)->data_type, in data_types_ok()
H A Dcudnn_conv_inner_product_impl.hpp593 pd->diff_weights_md(), dims_[io::wei], strides_[io::wei]); in init()
598 CHECK(filter_tag(*pd->diff_weights_md(0), w_tag)); in init()
610 if (!supported_filter_format(pd->diff_weights_md(0)) in init()
621 memory_desc_wrapper(pd->diff_weights_md(0)).size(), in init()
624 CHECK(get_format(pd->diff_weights_md(0), diff_weights_format)); in init()
630 CHECK(convert_data_type(pd->diff_weights_md(0), &data_types_[io::wei])); in init()
645 pd->diff_weights_md(1), &data_types_[io::bia])); in init()
H A Dcudnn_gemm_inner_product_impl.hpp352 wie_tr_ = (pd->diff_weights_md(0)->format_desc.blocking.strides[0] in init()
356 CHECK(convert_data_type(pd->diff_weights_md(0), &data_types_[io::wei])); in init()
363 pd->diff_weights_md(0), dims_[io::wei], strides_[io::wei]); in init()
370 memory_desc_wrapper(pd->diff_weights_md(0)).size(), in init()
398 pd->diff_weights_md(1), &data_types_[io::bia])); in init()
/dports/math/onednn/oneDNN-2.5.1/src/gpu/nvidia/
H A Dcudnn_conv_inner_product.hpp325 && memory_format_ok(diff_weights_md(0)) in init()
328 with_bias(), memory_format_ok(diff_weights_md(1))); in init()
372 + diff_weights_md(0)->format_desc.blocking.inner_nblks in no_blocking()
373 + diff_weights_md(1)->format_desc.blocking.inner_nblks in no_blocking()
380 diff_weights_md(1)->data_type == data_type::f32) in data_types_ok()
382 diff_weights_md(0)->data_type, in data_types_ok()
H A Dcudnn_conv_inner_product_impl.hpp593 pd->diff_weights_md(), dims_[io::wei], strides_[io::wei]); in init()
598 CHECK(filter_tag(*pd->diff_weights_md(0), w_tag)); in init()
610 if (!supported_filter_format(pd->diff_weights_md(0)) in init()
621 memory_desc_wrapper(pd->diff_weights_md(0)).size(), in init()
624 CHECK(get_format(pd->diff_weights_md(0), diff_weights_format)); in init()
630 CHECK(convert_data_type(pd->diff_weights_md(0), &data_types_[io::wei])); in init()
645 pd->diff_weights_md(1), &data_types_[io::bia])); in init()
H A Dcudnn_gemm_inner_product_impl.hpp358 wie_tr_ = (pd->diff_weights_md(0)->format_desc.blocking.strides[0] in init()
362 CHECK(convert_data_type(pd->diff_weights_md(0), &data_types_[io::wei])); in init()
369 pd->diff_weights_md(0), dims_[io::wei], strides_[io::wei]); in init()
376 memory_desc_wrapper(pd->diff_weights_md(0)).size(), in init()
404 pd->diff_weights_md(1), &data_types_[io::bia])); in init()
/dports/misc/mxnet/incubator-mxnet-1.9.0/3rdparty/mkldnn/src/common/
H A Drnn_pd.hpp347 case DNNL_ARG_DIFF_WEIGHTS_LAYER: return diff_weights_md(0); in arg_md()
348 case DNNL_ARG_DIFF_WEIGHTS_ITER: return diff_weights_md(1); in arg_md()
350 return is_lstm_peephole() ? diff_weights_md(2) : &glob_zero_md; in arg_md()
353 ? diff_weights_md(2 + is_lstm_peephole()) in arg_md()
356 return diff_weights_md( in arg_md()
374 const memory_desc_t *diff_weights_md(int index = 0) const override { in diff_weights_md() function
H A Dbatch_normalization_pd.hpp250 case DNNL_ARG_DIFF_SHIFT: return diff_weights_md(0); in arg_md()
268 const memory_desc_t *diff_weights_md(int index = 0) const override { in diff_weights_md() function
284 + (!types::is_zero_md(diff_weights_md())) in n_outputs()
310 diff_weights_md()->data_type)); in check_scale_shift_data_type()
H A Ddeconvolution_pd.hpp143 ? diff_weights_md(index) in invariant_wei_md()
295 case DNNL_ARG_DIFF_WEIGHTS: return diff_weights_md(0); in arg_md()
296 case DNNL_ARG_DIFF_BIAS: return diff_weights_md(1); in arg_md()
308 const memory_desc_t *diff_weights_md(int index = 0) const override { in diff_weights_md() function
H A Dinner_product_pd.hpp139 ? diff_weights_md(index) in invariant_wei_md()
346 case DNNL_ARG_DIFF_WEIGHTS: return diff_weights_md(0); in arg_md()
347 case DNNL_ARG_DIFF_BIAS: return diff_weights_md(1); in arg_md()
359 const memory_desc_t *diff_weights_md(int index = 0) const override { in diff_weights_md() function
H A Dlayer_normalization_pd.hpp245 case DNNL_ARG_DIFF_SHIFT: return diff_weights_md(0); in arg_md()
266 const memory_desc_t *diff_weights_md(int index = 0) const override { in diff_weights_md() function
301 diff_weights_md()->data_type)); in check_scale_shift_data_type()
/dports/math/onednn/oneDNN-2.5.1/src/common/
H A Drnn_pd.hpp347 case DNNL_ARG_DIFF_WEIGHTS_LAYER: return diff_weights_md(0); in arg_md()
348 case DNNL_ARG_DIFF_WEIGHTS_ITER: return diff_weights_md(1); in arg_md()
350 return is_lstm_peephole() ? diff_weights_md(2) : &glob_zero_md; in arg_md()
353 ? diff_weights_md(2 + is_lstm_peephole()) in arg_md()
356 return diff_weights_md( in arg_md()
374 const memory_desc_t *diff_weights_md(int index = 0) const override { in diff_weights_md() function
H A Dbatch_normalization_pd.hpp250 case DNNL_ARG_DIFF_SHIFT: return diff_weights_md(0); in arg_md()
268 const memory_desc_t *diff_weights_md(int index = 0) const override { in diff_weights_md() function
284 + (!types::is_zero_md(diff_weights_md())) in n_outputs()
310 diff_weights_md()->data_type)); in check_scale_shift_data_type()
H A Ddeconvolution_pd.hpp143 ? diff_weights_md(index) in invariant_wei_md()
295 case DNNL_ARG_DIFF_WEIGHTS: return diff_weights_md(0); in arg_md()
296 case DNNL_ARG_DIFF_BIAS: return diff_weights_md(1); in arg_md()
308 const memory_desc_t *diff_weights_md(int index = 0) const override { in diff_weights_md() function
H A Dinner_product_pd.hpp139 ? diff_weights_md(index) in invariant_wei_md()
346 case DNNL_ARG_DIFF_WEIGHTS: return diff_weights_md(0); in arg_md()
347 case DNNL_ARG_DIFF_BIAS: return diff_weights_md(1); in arg_md()
359 const memory_desc_t *diff_weights_md(int index = 0) const override { in diff_weights_md() function
H A Dlayer_normalization_pd.hpp245 case DNNL_ARG_DIFF_SHIFT: return diff_weights_md(0); in arg_md()
266 const memory_desc_t *diff_weights_md(int index = 0) const override { in diff_weights_md() function
301 diff_weights_md()->data_type)); in check_scale_shift_data_type()
/dports/misc/mxnet/incubator-mxnet-1.9.0/3rdparty/mkldnn/src/cpu/x64/prelu/
H A Djit_prelu_reduction_kernel.cpp38 , data_type_(pd->diff_weights_md(0)->data_type) in jit_prelu_reduction_kernel_t()
40 , tail_block_size_(prelu::get_block_tail_size(pd->diff_weights_md(0))) in jit_prelu_reduction_kernel_t()
41 , c_blk_nelems_(prelu::c_blk_nelems(pd->diff_weights_md(0), false)) {} in jit_prelu_reduction_kernel_t()
206 if (isa == avx && prelu::is_s8u8({pd->diff_weights_md(0)->data_type})) in create()
/dports/math/onednn/oneDNN-2.5.1/src/cpu/x64/prelu/
H A Djit_prelu_reduction_kernel.cpp38 , data_type_(pd->diff_weights_md(0)->data_type) in jit_prelu_reduction_kernel_t()
40 , tail_block_size_(prelu::get_block_tail_size(pd->diff_weights_md(0))) in jit_prelu_reduction_kernel_t()
41 , c_blk_nelems_(prelu::c_blk_nelems(pd->diff_weights_md(0), false)) {} in jit_prelu_reduction_kernel_t()
206 if (isa == avx && prelu::is_s8u8({pd->diff_weights_md(0)->data_type})) in create()
/dports/misc/mxnet/incubator-mxnet-1.9.0/3rdparty/mkldnn/src/cpu/
H A Dgemm_inner_product.hpp144 diff_weights_md()->data_type, in init()
146 with_bias() ? diff_weights_md(1)->data_type in init()
151 src_md(), diff_weights_md(), diff_dst_md()); in init()
/dports/math/onednn/oneDNN-2.5.1/src/cpu/
H A Dgemm_inner_product.hpp145 diff_weights_md()->data_type, in init()
147 with_bias() ? diff_weights_md(1)->data_type in init()
152 src_md(), diff_weights_md(), diff_dst_md()); in init()
/dports/misc/mxnet/incubator-mxnet-1.9.0/3rdparty/mkldnn/src/cpu/x64/
H A Dgemm_bf16_inner_product.hpp208 && diff_wei_data_type == diff_weights_md()->data_type in init()
210 one_of(diff_weights_md(1)->data_type, f32, bf16)) in init()
214 src_md(), diff_weights_md(), diff_dst_md()); in init()
258 && diff_weights_md(1)->data_type == data_type::f32; in init_scratchpad()
/dports/math/onednn/oneDNN-2.5.1/src/cpu/x64/
H A Dgemm_bf16_inner_product.hpp208 && diff_wei_data_type == diff_weights_md()->data_type in init()
210 one_of(diff_weights_md(1)->data_type, f32, bf16)) in init()
214 src_md(), diff_weights_md(), diff_dst_md()); in init()
258 && diff_weights_md(1)->data_type == data_type::f32; in init_scratchpad()

1234567