1 /*******************************************************************************
2 * Copyright 2018-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef CPU_REF_DECONVOLUTION_HPP
18 #define CPU_REF_DECONVOLUTION_HPP
19 
20 #include <assert.h>
21 #include <string.h>
22 
23 #include "common/c_types_map.hpp"
24 #include "common/primitive.hpp"
25 #include "common/primitive_iterator.hpp"
26 #include "common/stream.hpp"
27 #include "common/type_helpers.hpp"
28 #include "common/utils.hpp"
29 
30 #include "cpu/primitive_attr_postops.hpp"
31 
32 #include "cpu/cpu_convolution_pd.hpp"
33 #include "cpu/cpu_deconvolution_pd.hpp"
34 
35 namespace dnnl {
36 namespace impl {
37 namespace cpu {
38 
weights_axes_permutation(memory_desc_t * o_md,const memory_desc_t * i_md,bool with_groups)39 static status_t weights_axes_permutation(
40         memory_desc_t *o_md, const memory_desc_t *i_md, bool with_groups) {
41     int perm[DNNL_MAX_NDIMS] {}; // deconv to conv weight permutation
42     for (int d = 0; d < DNNL_MAX_NDIMS; ++d)
43         perm[d] = d;
44     nstl::swap(perm[0 + with_groups], perm[1 + with_groups]);
45 
46     return dnnl_memory_desc_permute_axes(o_md, i_md, perm);
47 }
48 
conv_descr_create(const deconvolution_desc_t * dd,convolution_desc_t * cd,const memory_desc_t * bias_md=nullptr,data_type_t src_dt=data_type::undef)49 static status_t conv_descr_create(const deconvolution_desc_t *dd,
50         convolution_desc_t *cd, const memory_desc_t *bias_md = nullptr,
51         data_type_t src_dt = data_type::undef) {
52     using namespace prop_kind;
53     alg_kind_t alg_kind = dd->alg_kind == alg_kind::deconvolution_direct
54             ? alg_kind::convolution_direct
55             : alg_kind::convolution_winograd;
56 
57     const memory_desc_t *src_md, *dst_md, *d_weights_d;
58     memory_desc_t src_md_patched;
59     prop_kind_t prop_kind;
60 
61     if (utils::one_of(dd->prop_kind, forward_training, forward_inference)) {
62         prop_kind = backward_data;
63         assert(src_dt != data_type::undef);
64         memory_desc_init_by_md_and_dt(src_md_patched, dd->dst_desc, src_dt);
65         src_md = &src_md_patched;
66         dst_md = &dd->src_desc;
67         d_weights_d = &dd->weights_desc;
68     } else if (dd->prop_kind == backward_data) {
69         assert(src_dt == data_type::undef);
70         prop_kind = forward_training;
71         src_md = &dd->diff_dst_desc;
72         dst_md = &dd->diff_src_desc;
73         d_weights_d = &dd->weights_desc;
74     } else {
75         assert(src_dt == data_type::undef);
76         prop_kind = dd->prop_kind;
77         src_md = &dd->diff_dst_desc;
78         dst_md = &dd->src_desc;
79         d_weights_d = &dd->diff_weights_desc;
80     }
81 
82     /* create weights desc for convolution */
83     memory_desc_t c_weights_d;
84     const bool with_groups = d_weights_d->ndims == src_md->ndims + 1;
85     CHECK(weights_axes_permutation(&c_weights_d, d_weights_d, with_groups));
86 
87     return conv_desc_init(cd, prop_kind, alg_kind, src_md, &c_weights_d,
88             bias_md, dst_md, dd->strides, dd->dilates, dd->padding[0],
89             dd->padding[1]);
90 }
91 
// Reference deconvolution forward primitive.
//
// Implemented on top of a backward-data convolution: deconv src plays the
// role of conv diff_dst and the weights' OC/IC axes are swapped (see
// weights_axes_permutation()). Bias addition and post-ops that the chosen
// convolution implementation cannot handle are applied by this primitive in
// a separate pass over the convolution output (see the compute_* helpers).
struct ref_deconvolution_fwd_t : public primitive_t {
    struct pd_t : public cpu_deconvolution_fwd_pd_t {
        pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr,
                const deconvolution_fwd_pd_t *hint_fwd_pd)
            : cpu_deconvolution_fwd_pd_t(adesc, attr, hint_fwd_pd) {}

        pd_t(const pd_t &other)
            : cpu_deconvolution_fwd_pd_t(other)
            , conv_pd_(other.conv_pd_->clone())
            , conv_supports_bias_(other.conv_supports_bias_)
            , dst_tag_(other.dst_tag_) {}

        ~pd_t() = default;

        DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_fwd_t);

        // Iterate over available bwd_data convolution implementations and
        // keep the first suitable one in conv_pd_.
        status_t init_convolution(engine_t *engine) {
            using namespace format_tag;
            using namespace data_type;

            // Create empty attributes for bwd_d conv to pick up the fastest
            // impl available and apply post-ops and/or bias update later in
            // this impl via simple loop.
            primitive_attr_t conv_attr;

            convolution_desc_t cd;
            // When no attributes were requested, try to find a bwd_d conv impl
            // which supports bias update in-place, if requested, in requested
            // dst_dt. If appropriate conv impl was not found, enforce f32
            // diff_src for conv for correct result. If attributes are
            // requested, enforce conv impl to return f32 output no matter what.
            if (attr()->has_default_values()) {
                CHECK(conv_descr_create(
                        desc(), &cd, weights_md(1), dst_md()->data_type));
                dnnl_primitive_desc_iterator it(
                        engine, (op_desc_t *)&cd, &conv_attr, nullptr);
                if (!it.is_initialized()) return status::out_of_memory;

                while (++it != it.end()) {
                    conv_pd_ = *it;
                    if (with_bias()) {
                        conv_supports_bias_ = utils::downcast<
                                cpu_convolution_bwd_data_pd_t *>(conv_pd_.get())
                                                      ->support_bias();
                        if (!conv_supports_bias_) continue;
                    }
                    // Skip impls whose weights carry extra metadata (e.g.
                    // compensation) that this wrapper cannot provide.
                    bool ok = conv_pd_->weights_md()->extra.flags == 0;
                    if (ok) return status::success;
                }
            }

            // Intermediate f32 buffer is supported only for given condition.
            if (!attr()->has_default_values() || with_bias()) {
                // Enforce f32 dt for diff src and work with f32 output for bias
                // update or post ops after conv execution.
                CHECK(conv_descr_create(desc(), &cd, nullptr, data_type::f32));
                dnnl_primitive_desc_iterator it(
                        engine, (op_desc_t *)&cd, &conv_attr, nullptr);
                if (!it.is_initialized()) return status::out_of_memory;

                while (++it != it.end()) {
                    conv_pd_ = *it;
                    bool ok = conv_pd_->weights_md()->extra.flags == 0;
                    if (ok) return status::success;
                }
            }
            return status::unimplemented;
        }

        status_t init(engine_t *engine) {
            using namespace format_tag;
            using smask_t = primitive_attr_t::skip_mask_t;

            const bool ok = is_fwd()
                    && utils::one_of(desc()->alg_kind,
                            alg_kind::deconvolution_direct,
                            alg_kind::deconvolution_winograd)
                    && attr()->has_default_values(smask_t::oscale
                            | smask_t::post_ops | smask_t::zero_points_runtime)
                    && output_scales_mask_ok() && post_ops_ok()
                    && zero_points_ok();
            if (!ok) return status::unimplemented;

            CHECK(init_convolution(engine));

            // Any memory descriptors left as `any` are resolved from the
            // chosen convolution pd (with axes permuted back for weights).
            if (weights_md_.format_kind == format_kind::any)
                CHECK(weights_axes_permutation(
                        &weights_md_, conv_pd_->weights_md(), with_groups()));
            if (src_md_.format_kind == format_kind::any)
                src_md_ = *conv_pd_->diff_dst_md();
            if (dst_md_.format_kind == format_kind::any) {
                // re-apply dt manually since it could be changed due to bias
                const auto dst_dt = dst_md_.data_type;
                memory_desc_init_by_md_and_dt(
                        dst_md_, *conv_pd_->diff_src_md(), dst_dt);
            }
            if (bias_md_.format_kind == format_kind::any)
                CHECK(memory_desc_init_by_tag(bias_md_, x));

            // Remember which layout family dst uses so execute() can pick a
            // specialized bias/post-op kernel (plain, channels-last, blocked).
            dst_tag_ = memory_desc_matches_one_of_tag(dst_md_,
                    utils::pick(ndims() - 3, ncw, nchw, ncdhw),
                    utils::pick(ndims() - 3, nwc, nhwc, ndhwc),
                    utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c),
                    utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c));

            init_scratchpad();
            return attr_.set_default_formats(dst_md(0));
        }

        std::shared_ptr<primitive_desc_t> conv_pd_; // underlying bwd_d conv
        bool conv_supports_bias_ = false; // conv applies bias in-place itself
        format_tag_t dst_tag_; // layout family of dst (see init())

    private:
        void init_scratchpad() {
            using namespace memory_tracking::names;
            auto scratchpad = scratchpad_registry().registrar();
            scratchpad.book(key_nested, conv_pd_->scratchpad_registry());

            // This scratchpad is required for intermediate f32 conv output
            // since original memory can be of smaller size and will cause
            // out of boundary access.
            if ((with_bias() && !conv_supports_bias_)
                    || !attr()->has_default_values()) {
                const memory_desc_wrapper diff_src_d(conv_pd_->diff_src_md());
                assert(diff_src_d.data_type_size() == sizeof(float));
                scratchpad.book(key_deconv_bias, diff_src_d.nelems(true),
                        diff_src_d.data_type_size());
            }
            // This scratchpad is required to stash original dst memory for sum
            // post-op. It will be overwritten by conv execution and will not
            // be available to get the correct result.
            const memory_desc_wrapper dst_d(dst_md());
            if (attr()->post_ops_.find(primitive_kind::sum) != -1)
                scratchpad.book(key_deconv_sum, dst_d.nelems(true),
                        dst_d.data_type_size());

            // Zero-point compensation buffer, one int32 entry per output
            // channel across all groups.
            if (!attr()->zero_points_.has_default_values(DNNL_ARG_SRC)) {
                scratchpad.book<int32_t>(key_deconv_zp, OC() * G());
            }
        }

        // Output scales are only meaningful for int8 src; mask must be
        // common (0) or per-output-channel (1 << 1).
        bool output_scales_mask_ok() const {
            using namespace data_type;
            const auto &mask = attr()->output_scales_.mask_;
            return IMPLICATION(!utils::one_of(src_md()->data_type, s8, u8),
                           attr()->output_scales_.has_default_values())
                    && (mask == 0 || mask == 1 << 1);
        }

        // Fused depthwise convolution post-op is not supported here.
        bool post_ops_ok() const {
            return attr()->post_ops_.find(primitive_kind::convolution) == -1;
        }

        // Zero points: int8 src only, no weights zero points, and src/dst
        // masks restricted to common (0) or per-output-channel (1 << 1).
        bool zero_points_ok() const {
            using namespace data_type;
            int mask_src = 0, mask_dst = 0;
            attr()->zero_points_.get(DNNL_ARG_SRC, nullptr, &mask_src, nullptr);
            attr()->zero_points_.get(DNNL_ARG_DST, nullptr, &mask_dst, nullptr);

            return IMPLICATION(!utils::one_of(src_md()->data_type, s8, u8),
                           attr()->zero_points_.has_default_values())
                    && attr()->zero_points_.has_default_values(DNNL_ARG_WEIGHTS)
                    && (mask_src == 0 || mask_src == 1 << 1)
                    && (mask_dst == 0 || mask_dst == 1 << 1);
        }
    };

    ref_deconvolution_fwd_t(const pd_t *apd) : primitive_t(apd) {}

    status_t init(engine_t *engine) override {
        CHECK(pd()->conv_pd_->create_primitive(conv_p_, engine));

        // NOTE(review): member lacks the trailing '_' used elsewhere in this
        // file; renaming would require touching the .cpp as well.
        ref_post_ops
                = utils::make_unique<ref_post_ops_t>(pd()->attr()->post_ops_);
        if (!ref_post_ops) return status::out_of_memory;
        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override;

private:
    // Bias application over the conv output, one variant per dst layout
    // family; dispatch happens in compute_fwd_bias() based on pd()->dst_tag_.
    void compute_fwd_bias_common(const exec_ctx_t &ctx, void *dst,
            const float *conv_output, bool non_default_attr) const;

    void compute_fwd_bias_ncdhw(const exec_ctx_t &ctx, void *dst,
            const float *conv_output, bool non_default_attr) const;

    void compute_fwd_bias_ndhwc(const exec_ctx_t &ctx, void *dst,
            const float *conv_output, bool non_default_attr) const;

    template <dim_t blk_size>
    void compute_fwd_bias_nCdhwXc(const exec_ctx_t &ctx, void *dst,
            const float *conv_output, bool non_default_attr) const;

    void compute_fwd_bias(const exec_ctx_t &ctx, void *dst,
            const float *conv_output, bool non_default_attr) const;

    // Apply attributes (output scales, zero points, post-ops) on the f32
    // conv output and write the final result into the user dst.
    status_t compute_ref_attrs(const exec_ctx_t &ctx, const float *conv_output,
            void *original_dst) const;

    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    std::shared_ptr<primitive_t> conv_p_;
    std::unique_ptr<ref_post_ops_t> ref_post_ops;
};
297 
// Reference deconvolution backward-data primitive.
//
// Implemented via a forward convolution (deconv bwd_data == conv fwd) with
// the weights' OC/IC axes swapped; see conv_descr_create().
struct ref_deconvolution_bwd_data_t : public primitive_t {
    struct pd_t : public cpu_deconvolution_bwd_data_pd_t {
        pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr,
                const deconvolution_fwd_pd_t *hint_fwd_pd)
            : cpu_deconvolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd) {}

        pd_t(const pd_t &other)
            : cpu_deconvolution_bwd_data_pd_t(other)
            , conv_pd_(other.conv_pd_->clone()) {}

        ~pd_t() = default;

        DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_data_t);

        // Pick the first forward convolution impl whose weights need no
        // extra metadata and store it in conv_pd_.
        status_t init_convolution(engine_t *engine) {
            using namespace types;

            convolution_desc_t cd;
            status_t status = conv_descr_create(desc(), &cd);
            if (status != status::success) return status;
            // Attributes are forwarded to the convolution as-is (they are
            // required to be default here anyway, see init()).
            primitive_attr_t conv_attr(*attr());
            if (!conv_attr.is_initialized()) return status::out_of_memory;

            dnnl_primitive_desc_iterator it(
                    engine, (op_desc_t *)&cd, &conv_attr, nullptr);
            if (!it.is_initialized()) return status::out_of_memory;
            while (++it != it.end()) {
                conv_pd_ = *it;
                if (conv_pd_->weights_md()->extra.flags == 0)
                    return status::success;
            }

            return status::unimplemented;
        }

        status_t init(engine_t *engine) {
            using namespace data_type;
            auto dsrc_type = desc()->diff_src_desc.data_type;
            auto wei_type = desc()->weights_desc.data_type;
            auto ddst_type = desc()->diff_dst_desc.data_type;
            // Supported configs: all-f32, or bf16 wei/ddst with f32/bf16
            // diff_src; attributes must be default.
            bool ok = true && desc()->prop_kind == prop_kind::backward_data
                    && (utils::everyone_is(f32, dsrc_type, wei_type, ddst_type)
                            || (utils::one_of(dsrc_type, f32, bf16)
                                    && utils::everyone_is(
                                            bf16, wei_type, ddst_type)))
                    && utils::one_of(desc()->alg_kind,
                            alg_kind::deconvolution_direct,
                            alg_kind::deconvolution_winograd)
                    && attr()->has_default_values();

            if (ok) {
                CHECK(init_convolution(engine));
                // Resolve `any` descriptors from the chosen convolution pd
                // (weights get their axes permuted back).
                if (weights_md_.format_kind == format_kind::any)
                    CHECK(weights_axes_permutation(&weights_md_,
                            conv_pd_->weights_md(), with_groups()));
                if (diff_src_md_.format_kind == format_kind::any)
                    diff_src_md_ = *conv_pd_->dst_md();
                if (diff_dst_md_.format_kind == format_kind::any)
                    diff_dst_md_ = *conv_pd_->src_md();
                init_scratchpad();
                return status::success;
            }

            return status::unimplemented;
        }

        std::shared_ptr<primitive_desc_t> conv_pd_; // underlying fwd conv

    private:
        // Only the nested convolution scratchpad is needed.
        void init_scratchpad() {
            auto scratchpad = scratchpad_registry().registrar();
            scratchpad.book(memory_tracking::names::key_nested,
                    conv_pd_->scratchpad_registry());
        }
    };

    typedef typename prec_traits<data_type::f32>::type data_t;

    ref_deconvolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {}

    status_t init(engine_t *engine) override {
        return pd()->conv_pd_->create_primitive(conv_p_, engine);
    }

    status_t execute(const exec_ctx_t &ctx) const override;

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    std::shared_ptr<primitive_t> conv_p_;
};
388 
// Reference deconvolution backward-weights primitive.
//
// Implemented via a backward-weights convolution with src/diff_dst roles
// swapped and the weights' OC/IC axes permuted; see conv_descr_create().
// The bias gradient, when requested, is reduced from diff_dst by this
// primitive itself (see the compute_bwd_bias* helpers).
struct ref_deconvolution_bwd_weights_t : public primitive_t {
    struct pd_t : public cpu_deconvolution_bwd_weights_pd_t {
        pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr,
                const deconvolution_fwd_pd_t *hint_fwd_pd)
            : cpu_deconvolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) {}

        pd_t(const pd_t &other)
            : cpu_deconvolution_bwd_weights_pd_t(other)
            , conv_pd_(other.conv_pd_->clone())
            , dst_tag_(other.dst_tag_) {}

        ~pd_t() = default;

        DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_weights_t);

        // Pick the first bwd_weights convolution impl compatible with this
        // wrapper and store it in conv_pd_.
        status_t init_convolution(engine_t *engine) {
            using namespace types;
            using namespace format_tag;

            convolution_desc_t cd;
            status_t status = conv_descr_create(desc(), &cd);
            if (status != status::success) return status;
            primitive_attr_t conv_attr(*attr());
            if (!conv_attr.is_initialized()) return status::out_of_memory;

            dnnl_primitive_desc_iterator it(
                    engine, (op_desc_t *)&cd, &conv_attr, nullptr);
            if (!it.is_initialized()) return status::out_of_memory;
            while (++it != it.end()) {
                conv_pd_ = *it;
                // With bf16 src and bias, the bias-reduction helpers only
                // handle the layouts listed below for the conv src (deconv
                // diff_dst), so restrict the conv impl accordingly.
                bool bf16_ref_deconv_supports_bias = IMPLICATION(with_bias()
                                && desc()->src_desc.data_type
                                        == data_type::bf16,
                        memory_desc_matches_one_of_tag(*conv_pd_->src_md(),
                                utils::pick(ndims() - 3, ncw, nchw, ncdhw),
                                utils::pick(ndims() - 3, nwc, nhwc, ndhwc),
                                utils::pick(ndims() - 3, nCw16c, nChw16c,
                                        nCdhw16c)));
                if (conv_pd_->diff_weights_md()->extra.flags == 0
                        && bf16_ref_deconv_supports_bias) {
                    return status::success;
                }
            }
            return status::unimplemented;
        }

        status_t init(engine_t *engine) {
            using namespace format_tag;
            using namespace data_type;
            auto src_type = desc()->src_desc.data_type;
            auto dwei_type = desc()->diff_weights_desc.data_type;
            auto ddst_type = desc()->diff_dst_desc.data_type;
            // Supported configs: all-f32, or bf16 src/ddst with f32/bf16
            // diff_weights; attributes must be default.
            bool ok = true && desc()->prop_kind == prop_kind::backward_weights
                    && (utils::everyone_is(f32, src_type, dwei_type, ddst_type)
                            || (utils::one_of(dwei_type, f32, bf16)
                                    && utils::everyone_is(
                                            bf16, src_type, ddst_type)))
                    && utils::one_of(desc()->alg_kind,
                            alg_kind::deconvolution_direct,
                            alg_kind::deconvolution_winograd)
                    && attr()->has_default_values();

            if (ok) {
                CHECK(init_convolution(engine));
                // Resolve `any` descriptors from the chosen convolution pd.
                if (diff_weights_md_.format_kind == format_kind::any)
                    CHECK(weights_axes_permutation(&diff_weights_md_,
                            conv_pd_->diff_weights_md(), with_groups()));
                if (src_md_.format_kind == format_kind::any)
                    src_md_ = *conv_pd_->diff_dst_md();
                if (diff_dst_md_.format_kind == format_kind::any)
                    diff_dst_md_ = *conv_pd_->src_md();
                if (diff_bias_md_.format_kind == format_kind::any)
                    CHECK(memory_desc_init_by_tag(diff_bias_md_, x));

                // Remember the diff_dst layout family so execute() can pick
                // a specialized bias-reduction kernel.
                dst_tag_ = memory_desc_matches_one_of_tag(diff_dst_md_,
                        utils::pick(ndims() - 3, ncw, nchw, ncdhw),
                        utils::pick(ndims() - 3, nwc, nhwc, ndhwc),
                        utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c),
                        utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c));
                init_scratchpad();
                return status::success;
            }

            return status::unimplemented;
        }

        std::shared_ptr<primitive_desc_t> conv_pd_; // underlying bwd_w conv
        format_tag_t dst_tag_; // layout family of diff_dst (see init())

    private:
        // Only the nested convolution scratchpad is needed.
        void init_scratchpad() {
            auto scratchpad = scratchpad_registry().registrar();
            scratchpad.book(memory_tracking::names::key_nested,
                    conv_pd_->scratchpad_registry());
        }
    };

    ref_deconvolution_bwd_weights_t(const pd_t *apd) : primitive_t(apd) {}

    status_t init(engine_t *engine) override {
        return pd()->conv_pd_->create_primitive(conv_p_, engine);
    }

    status_t execute(const exec_ctx_t &ctx) const override;

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    // Bias-gradient reduction over diff_dst, one variant per layout family;
    // compute_bias() dispatches on pd()->dst_tag_ and the data types.
    void compute_bwd_bias(float *diff_bias, const float *diff_dst) const;

    template <data_type_t dbia_type, data_type_t ddst_type>
    void compute_bwd_bias_ncdhw(
            typename prec_traits<dbia_type>::type *diff_bias,
            const typename prec_traits<ddst_type>::type *diff_dst) const;

    template <data_type_t dbia_type, data_type_t ddst_type>
    void compute_bwd_bias_ndhwc(
            typename prec_traits<dbia_type>::type *diff_bias,
            const typename prec_traits<ddst_type>::type *diff_dst) const;

    template <data_type_t dbia_type, data_type_t ddst_type, dim_t blksize>
    void compute_bwd_bias_nCdhwXc(
            typename prec_traits<dbia_type>::type *diff_bias,
            const typename prec_traits<ddst_type>::type *diff_dst) const;

    template <data_type_t dbia_type, data_type_t ddst_type>
    void compute_bias(const exec_ctx_t &ctx) const;
    std::shared_ptr<primitive_t> conv_p_;
};
517 
518 } // namespace cpu
519 } // namespace impl
520 } // namespace dnnl
521 
522 #endif
523 
524 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
525