1 /*******************************************************************************
2 * Copyright 2018-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef CPU_REF_DECONVOLUTION_HPP
18 #define CPU_REF_DECONVOLUTION_HPP
19
20 #include <assert.h>
21 #include <string.h>
22
23 #include "common/c_types_map.hpp"
24 #include "common/primitive.hpp"
25 #include "common/primitive_iterator.hpp"
26 #include "common/stream.hpp"
27 #include "common/type_helpers.hpp"
28 #include "common/utils.hpp"
29
30 #include "cpu/primitive_attr_postops.hpp"
31
32 #include "cpu/cpu_convolution_pd.hpp"
33 #include "cpu/cpu_deconvolution_pd.hpp"
34
35 namespace dnnl {
36 namespace impl {
37 namespace cpu {
38
weights_axes_permutation(memory_desc_t * o_md,const memory_desc_t * i_md,bool with_groups)39 static status_t weights_axes_permutation(
40 memory_desc_t *o_md, const memory_desc_t *i_md, bool with_groups) {
41 int perm[DNNL_MAX_NDIMS] {}; // deconv to conv weight permutation
42 for (int d = 0; d < DNNL_MAX_NDIMS; ++d)
43 perm[d] = d;
44 nstl::swap(perm[0 + with_groups], perm[1 + with_groups]);
45
46 return dnnl_memory_desc_permute_axes(o_md, i_md, perm);
47 }
48
/* Translates a deconvolution descriptor into the descriptor of its "mirror"
 * convolution:
 *   - deconv fwd      -> conv bwd_d (src/dst swapped);
 *   - deconv bwd_d    -> conv fwd   (diff tensors swapped);
 *   - deconv bwd_w    -> conv bwd_w (diff_dst plays the role of conv src).
 * The weights descriptor gets its channel axes permuted accordingly.
 *
 * @param dd      deconvolution descriptor to translate (not modified)
 * @param cd      [out] resulting convolution descriptor
 * @param bias_md optional bias descriptor forwarded to the conv (fwd case)
 * @param src_dt  data type to force on the conv src; must be set for the
 *                forward case and left undef otherwise (see asserts below)
 */
static status_t conv_descr_create(const deconvolution_desc_t *dd,
        convolution_desc_t *cd, const memory_desc_t *bias_md = nullptr,
        data_type_t src_dt = data_type::undef) {
    using namespace prop_kind;
    alg_kind_t alg_kind = dd->alg_kind == alg_kind::deconvolution_direct
            ? alg_kind::convolution_direct
            : alg_kind::convolution_winograd;

    const memory_desc_t *src_md, *dst_md, *d_weights_d;
    memory_desc_t src_md_patched;
    prop_kind_t prop_kind;

    if (utils::one_of(dd->prop_kind, forward_training, forward_inference)) {
        // Forward deconvolution runs as backward-by-data convolution; the
        // conv "src" is the deconv dst with the caller-requested data type.
        prop_kind = backward_data;
        assert(src_dt != data_type::undef);
        memory_desc_init_by_md_and_dt(src_md_patched, dd->dst_desc, src_dt);
        src_md = &src_md_patched;
        dst_md = &dd->src_desc;
        d_weights_d = &dd->weights_desc;
    } else if (dd->prop_kind == backward_data) {
        // Backward-by-data deconvolution runs as forward convolution.
        assert(src_dt == data_type::undef);
        prop_kind = forward_training;
        src_md = &dd->diff_dst_desc;
        dst_md = &dd->diff_src_desc;
        d_weights_d = &dd->weights_desc;
    } else {
        // Backward-by-weights: prop kind is kept, src/diff_dst swap roles.
        assert(src_dt == data_type::undef);
        prop_kind = dd->prop_kind;
        src_md = &dd->diff_dst_desc;
        dst_md = &dd->src_desc;
        d_weights_d = &dd->diff_weights_desc;
    }

    /* create weights desc for convolution */
    memory_desc_t c_weights_d;
    // Groups are detected by the extra leading dimension on the weights.
    const bool with_groups = d_weights_d->ndims == src_md->ndims + 1;
    CHECK(weights_axes_permutation(&c_weights_d, d_weights_d, with_groups));

    return conv_desc_init(cd, prop_kind, alg_kind, src_md, &c_weights_d,
            bias_md, dst_md, dd->strides, dd->dilates, dd->padding[0],
            dd->padding[1]);
}
91
// Reference forward deconvolution. Delegates the heavy lifting to a nested
// backward-by-data convolution primitive (see conv_descr_create) and applies
// bias and/or attributes (output scales, post-ops, zero points) itself in a
// follow-up pass over the conv output.
struct ref_deconvolution_fwd_t : public primitive_t {
    struct pd_t : public cpu_deconvolution_fwd_pd_t {
        pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr,
                const deconvolution_fwd_pd_t *hint_fwd_pd)
            : cpu_deconvolution_fwd_pd_t(adesc, attr, hint_fwd_pd) {}

        // Deep copy: the nested convolution pd is cloned so each pd_t owns
        // an independent instance.
        pd_t(const pd_t &other)
            : cpu_deconvolution_fwd_pd_t(other)
            , conv_pd_(other.conv_pd_->clone())
            , conv_supports_bias_(other.conv_supports_bias_)
            , dst_tag_(other.dst_tag_) {}

        ~pd_t() = default;

        // Reported impl name is the nested convolution's name.
        DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_fwd_t);

        // Picks the nested bwd_d convolution implementation. Returns
        // status::success once a suitable conv pd is stored in conv_pd_.
        status_t init_convolution(engine_t *engine) {
            using namespace format_tag;
            using namespace data_type;

            // Create empty attributes for bwd_d conv to pick up the fastest
            // impl available and apply post-ops and/or bias update later in
            // this impl via simple loop.
            primitive_attr_t conv_attr;

            convolution_desc_t cd;
            // When no attributes were requested, try to find a bwd_d conv impl
            // which supports bias update in-place, if requested, in requested
            // dst_dt. If appropriate conv impl was not found, enforce f32
            // diff_src for conv for correct result. If attributes are
            // requested, enforce conv impl to return f32 output no matter what.
            if (attr()->has_default_values()) {
                CHECK(conv_descr_create(
                        desc(), &cd, weights_md(1), dst_md()->data_type));
                dnnl_primitive_desc_iterator it(
                        engine, (op_desc_t *)&cd, &conv_attr, nullptr);
                if (!it.is_initialized()) return status::out_of_memory;

                while (++it != it.end()) {
                    conv_pd_ = *it;
                    if (with_bias()) {
                        // Only cpu bwd_d conv pds can report in-place bias
                        // support; skip impls that cannot update the bias.
                        conv_supports_bias_ = utils::downcast<
                                cpu_convolution_bwd_data_pd_t *>(conv_pd_.get())
                                                      ->support_bias();
                        if (!conv_supports_bias_) continue;
                    }
                    // Reject impls whose weights carry extra flags (e.g.
                    // compensation metadata) this wrapper does not handle.
                    bool ok = conv_pd_->weights_md()->extra.flags == 0;
                    if (ok) return status::success;
                }
            }

            // Intermediate f32 buffer is supported only for given condition.
            if (!attr()->has_default_values() || with_bias()) {
                // Enforce f32 dt for diff src and work with f32 output for bias
                // update or post ops after conv execution.
                CHECK(conv_descr_create(desc(), &cd, nullptr, data_type::f32));
                dnnl_primitive_desc_iterator it(
                        engine, (op_desc_t *)&cd, &conv_attr, nullptr);
                if (!it.is_initialized()) return status::out_of_memory;

                while (++it != it.end()) {
                    conv_pd_ = *it;
                    bool ok = conv_pd_->weights_md()->extra.flags == 0;
                    if (ok) return status::success;
                }
            }
            return status::unimplemented;
        }

        status_t init(engine_t *engine) {
            using namespace format_tag;
            using smask_t = primitive_attr_t::skip_mask_t;

            // Supported: fwd prop, direct/winograd algs, and only the
            // attribute kinds this wrapper knows how to apply afterwards.
            const bool ok = is_fwd()
                    && utils::one_of(desc()->alg_kind,
                            alg_kind::deconvolution_direct,
                            alg_kind::deconvolution_winograd)
                    && attr()->has_default_values(smask_t::oscale
                            | smask_t::post_ops | smask_t::zero_points_runtime)
                    && output_scales_mask_ok() && post_ops_ok()
                    && zero_points_ok();
            if (!ok) return status::unimplemented;

            CHECK(init_convolution(engine));

            // For "any" formats, inherit the layouts chosen by the nested
            // conv (note: deconv src corresponds to conv diff_dst and
            // deconv dst to conv diff_src).
            if (weights_md_.format_kind == format_kind::any)
                CHECK(weights_axes_permutation(
                        &weights_md_, conv_pd_->weights_md(), with_groups()));
            if (src_md_.format_kind == format_kind::any)
                src_md_ = *conv_pd_->diff_dst_md();
            if (dst_md_.format_kind == format_kind::any) {
                // re-apply dt manually since it could be changed due to bias
                const auto dst_dt = dst_md_.data_type;
                memory_desc_init_by_md_and_dt(
                        dst_md_, *conv_pd_->diff_src_md(), dst_dt);
            }
            if (bias_md_.format_kind == format_kind::any)
                CHECK(memory_desc_init_by_tag(bias_md_, x));

            // Cache the dst layout tag to dispatch specialized bias kernels.
            dst_tag_ = memory_desc_matches_one_of_tag(dst_md_,
                    utils::pick(ndims() - 3, ncw, nchw, ncdhw),
                    utils::pick(ndims() - 3, nwc, nhwc, ndhwc),
                    utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c),
                    utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c));

            init_scratchpad();
            return attr_.set_default_formats(dst_md(0));
        }

        std::shared_ptr<primitive_desc_t> conv_pd_; // nested bwd_d conv pd
        bool conv_supports_bias_ = false; // conv updates bias in-place
        format_tag_t dst_tag_; // recognized dst layout (set in init())

    private:
        void init_scratchpad() {
            using namespace memory_tracking::names;
            auto scratchpad = scratchpad_registry().registrar();
            scratchpad.book(key_nested, conv_pd_->scratchpad_registry());

            // This scratchpad is required for intermediate f32 conv output
            // since original memory can be of smaller size and will cause
            // out of boundary access.
            if ((with_bias() && !conv_supports_bias_)
                    || !attr()->has_default_values()) {
                const memory_desc_wrapper diff_src_d(conv_pd_->diff_src_md());
                assert(diff_src_d.data_type_size() == sizeof(float));
                scratchpad.book(key_deconv_bias, diff_src_d.nelems(true),
                        diff_src_d.data_type_size());
            }
            // This scratchpad is required to stash original dst memory for sum
            // post-op. It will be overwritten by conv execution and will not
            // be available to get the correct result.
            const memory_desc_wrapper dst_d(dst_md());
            if (attr()->post_ops_.find(primitive_kind::sum) != -1)
                scratchpad.book(key_deconv_sum, dst_d.nelems(true),
                        dst_d.data_type_size());

            // Per-(group x output-channel) buffer for src zero-point
            // compensation.
            if (!attr()->zero_points_.has_default_values(DNNL_ARG_SRC)) {
                scratchpad.book<int32_t>(key_deconv_zp, OC() * G());
            }
        }

        // Output scales: allowed only for int8 src, with common (0) or
        // per-output-channel (1 << 1) mask.
        bool output_scales_mask_ok() const {
            using namespace data_type;
            const auto &mask = attr()->output_scales_.mask_;
            return IMPLICATION(!utils::one_of(src_md()->data_type, s8, u8),
                           attr()->output_scales_.has_default_values())
                    && (mask == 0 || mask == 1 << 1);
        }

        // Depthwise-fused convolution post-op is not supported here.
        bool post_ops_ok() const {
            return attr()->post_ops_.find(primitive_kind::convolution) == -1;
        }

        // Zero points: int8 src only; no weights zero points; src/dst masks
        // must be common (0) or per-channel (1 << 1).
        bool zero_points_ok() const {
            using namespace data_type;
            int mask_src = 0, mask_dst = 0;
            attr()->zero_points_.get(DNNL_ARG_SRC, nullptr, &mask_src, nullptr);
            attr()->zero_points_.get(DNNL_ARG_DST, nullptr, &mask_dst, nullptr);

            return IMPLICATION(!utils::one_of(src_md()->data_type, s8, u8),
                           attr()->zero_points_.has_default_values())
                    && attr()->zero_points_.has_default_values(DNNL_ARG_WEIGHTS)
                    && (mask_src == 0 || mask_src == 1 << 1)
                    && (mask_dst == 0 || mask_dst == 1 << 1);
        }
    };

    ref_deconvolution_fwd_t(const pd_t *apd) : primitive_t(apd) {}

    status_t init(engine_t *engine) override {
        // Instantiate the nested convolution and the reference post-ops
        // executor used for the follow-up attribute pass.
        CHECK(pd()->conv_pd_->create_primitive(conv_p_, engine));

        ref_post_ops
                = utils::make_unique<ref_post_ops_t>(pd()->attr()->post_ops_);
        if (!ref_post_ops) return status::out_of_memory;
        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override;

private:
    // Bias application over the conv output; _common is the generic path,
    // the others are layout-specialized variants selected via pd()->dst_tag_.
    void compute_fwd_bias_common(const exec_ctx_t &ctx, void *dst,
            const float *conv_output, bool non_default_attr) const;

    void compute_fwd_bias_ncdhw(const exec_ctx_t &ctx, void *dst,
            const float *conv_output, bool non_default_attr) const;

    void compute_fwd_bias_ndhwc(const exec_ctx_t &ctx, void *dst,
            const float *conv_output, bool non_default_attr) const;

    template <dim_t blk_size>
    void compute_fwd_bias_nCdhwXc(const exec_ctx_t &ctx, void *dst,
            const float *conv_output, bool non_default_attr) const;

    // Dispatches to one of the compute_fwd_bias_* variants above.
    void compute_fwd_bias(const exec_ctx_t &ctx, void *dst,
            const float *conv_output, bool non_default_attr) const;

    // Applies attributes (scales, post-ops, zero points) to the f32 conv
    // output; original_dst stashes dst content needed by the sum post-op.
    status_t compute_ref_attrs(const exec_ctx_t &ctx, const float *conv_output,
            void *original_dst) const;

    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    std::shared_ptr<primitive_t> conv_p_; // nested bwd_d convolution
    std::unique_ptr<ref_post_ops_t> ref_post_ops;
};
297
// Reference backward-by-data deconvolution: executed entirely as a forward
// convolution (see conv_descr_create), so this wrapper only forwards
// descriptors, formats, and scratchpad to the nested primitive.
struct ref_deconvolution_bwd_data_t : public primitive_t {
    struct pd_t : public cpu_deconvolution_bwd_data_pd_t {
        pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr,
                const deconvolution_fwd_pd_t *hint_fwd_pd)
            : cpu_deconvolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd) {}

        // Deep copy: clone the nested convolution pd.
        pd_t(const pd_t &other)
            : cpu_deconvolution_bwd_data_pd_t(other)
            , conv_pd_(other.conv_pd_->clone()) {}

        ~pd_t() = default;

        DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_data_t);

        // Picks the nested forward-convolution implementation, propagating
        // this primitive's attributes to it unchanged.
        status_t init_convolution(engine_t *engine) {
            using namespace types;

            convolution_desc_t cd;
            status_t status = conv_descr_create(desc(), &cd);
            if (status != status::success) return status;
            primitive_attr_t conv_attr(*attr());
            if (!conv_attr.is_initialized()) return status::out_of_memory;

            dnnl_primitive_desc_iterator it(
                    engine, (op_desc_t *)&cd, &conv_attr, nullptr);
            if (!it.is_initialized()) return status::out_of_memory;
            while (++it != it.end()) {
                conv_pd_ = *it;
                // Skip impls whose weights need extra metadata.
                if (conv_pd_->weights_md()->extra.flags == 0)
                    return status::success;
            }

            return status::unimplemented;
        }

        status_t init(engine_t *engine) {
            using namespace data_type;
            auto dsrc_type = desc()->diff_src_desc.data_type;
            auto wei_type = desc()->weights_desc.data_type;
            auto ddst_type = desc()->diff_dst_desc.data_type;
            // Supported type configs: all-f32, or bf16 wei/ddst with
            // f32-or-bf16 diff_src. No attributes are supported.
            bool ok = true && desc()->prop_kind == prop_kind::backward_data
                    && (utils::everyone_is(f32, dsrc_type, wei_type, ddst_type)
                            || (utils::one_of(dsrc_type, f32, bf16)
                                    && utils::everyone_is(
                                            bf16, wei_type, ddst_type)))
                    && utils::one_of(desc()->alg_kind,
                            alg_kind::deconvolution_direct,
                            alg_kind::deconvolution_winograd)
                    && attr()->has_default_values();

            if (ok) {
                CHECK(init_convolution(engine));
                // For "any" formats, inherit the nested conv's layouts
                // (deconv diff_src <-> conv dst, deconv diff_dst <-> conv src).
                if (weights_md_.format_kind == format_kind::any)
                    CHECK(weights_axes_permutation(&weights_md_,
                            conv_pd_->weights_md(), with_groups()));
                if (diff_src_md_.format_kind == format_kind::any)
                    diff_src_md_ = *conv_pd_->dst_md();
                if (diff_dst_md_.format_kind == format_kind::any)
                    diff_dst_md_ = *conv_pd_->src_md();
                init_scratchpad();
                return status::success;
            }

            return status::unimplemented;
        }

        std::shared_ptr<primitive_desc_t> conv_pd_; // nested fwd conv pd

    private:
        // Only the nested convolution's scratchpad is needed.
        void init_scratchpad() {
            auto scratchpad = scratchpad_registry().registrar();
            scratchpad.book(memory_tracking::names::key_nested,
                    conv_pd_->scratchpad_registry());
        }
    };

    typedef typename prec_traits<data_type::f32>::type data_t;

    ref_deconvolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {}

    status_t init(engine_t *engine) override {
        return pd()->conv_pd_->create_primitive(conv_p_, engine);
    }

    status_t execute(const exec_ctx_t &ctx) const override;

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    std::shared_ptr<primitive_t> conv_p_; // nested fwd convolution
};
388
// Reference backward-by-weights deconvolution: runs the nested convolution's
// backward-by-weights pass and computes the bias gradient itself by reducing
// diff_dst (layout-specialized kernels selected via dst_tag_).
struct ref_deconvolution_bwd_weights_t : public primitive_t {
    struct pd_t : public cpu_deconvolution_bwd_weights_pd_t {
        pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr,
                const deconvolution_fwd_pd_t *hint_fwd_pd)
            : cpu_deconvolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) {}

        // Deep copy: clone the nested convolution pd.
        pd_t(const pd_t &other)
            : cpu_deconvolution_bwd_weights_pd_t(other)
            , conv_pd_(other.conv_pd_->clone())
            , dst_tag_(other.dst_tag_) {}

        ~pd_t() = default;

        DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_weights_t);

        // Picks the nested bwd_w convolution implementation.
        status_t init_convolution(engine_t *engine) {
            using namespace types;
            using namespace format_tag;

            convolution_desc_t cd;
            status_t status = conv_descr_create(desc(), &cd);
            if (status != status::success) return status;
            primitive_attr_t conv_attr(*attr());
            if (!conv_attr.is_initialized()) return status::out_of_memory;

            dnnl_primitive_desc_iterator it(
                    engine, (op_desc_t *)&cd, &conv_attr, nullptr);
            if (!it.is_initialized()) return status::out_of_memory;
            while (++it != it.end()) {
                conv_pd_ = *it;
                // With a bf16 src and bias requested, the bias-reduction
                // kernels below only handle these conv src layouts, so
                // restrict the accepted conv impls accordingly.
                bool bf16_ref_deconv_supports_bias = IMPLICATION(with_bias()
                                && desc()->src_desc.data_type
                                        == data_type::bf16,
                        memory_desc_matches_one_of_tag(*conv_pd_->src_md(),
                                utils::pick(ndims() - 3, ncw, nchw, ncdhw),
                                utils::pick(ndims() - 3, nwc, nhwc, ndhwc),
                                utils::pick(ndims() - 3, nCw16c, nChw16c,
                                        nCdhw16c)));
                if (conv_pd_->diff_weights_md()->extra.flags == 0
                        && bf16_ref_deconv_supports_bias) {
                    return status::success;
                }
            }
            return status::unimplemented;
        }

        status_t init(engine_t *engine) {
            using namespace format_tag;
            using namespace data_type;
            auto src_type = desc()->src_desc.data_type;
            auto dwei_type = desc()->diff_weights_desc.data_type;
            auto ddst_type = desc()->diff_dst_desc.data_type;
            // Supported type configs: all-f32, or bf16 src/ddst with
            // f32-or-bf16 diff_weights. No attributes are supported.
            bool ok = true && desc()->prop_kind == prop_kind::backward_weights
                    && (utils::everyone_is(f32, src_type, dwei_type, ddst_type)
                            || (utils::one_of(dwei_type, f32, bf16)
                                    && utils::everyone_is(
                                            bf16, src_type, ddst_type)))
                    && utils::one_of(desc()->alg_kind,
                            alg_kind::deconvolution_direct,
                            alg_kind::deconvolution_winograd)
                    && attr()->has_default_values();

            if (ok) {
                CHECK(init_convolution(engine));
                // For "any" formats, inherit the nested conv's layouts
                // (deconv src <-> conv diff_dst, deconv diff_dst <-> conv src).
                if (diff_weights_md_.format_kind == format_kind::any)
                    CHECK(weights_axes_permutation(&diff_weights_md_,
                            conv_pd_->diff_weights_md(), with_groups()));
                if (src_md_.format_kind == format_kind::any)
                    src_md_ = *conv_pd_->diff_dst_md();
                if (diff_dst_md_.format_kind == format_kind::any)
                    diff_dst_md_ = *conv_pd_->src_md();
                if (diff_bias_md_.format_kind == format_kind::any)
                    CHECK(memory_desc_init_by_tag(diff_bias_md_, x));

                // Cache the diff_dst layout tag to dispatch the specialized
                // bias-reduction kernels.
                dst_tag_ = memory_desc_matches_one_of_tag(diff_dst_md_,
                        utils::pick(ndims() - 3, ncw, nchw, ncdhw),
                        utils::pick(ndims() - 3, nwc, nhwc, ndhwc),
                        utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c),
                        utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c));
                init_scratchpad();
                return status::success;
            }

            return status::unimplemented;
        }

        std::shared_ptr<primitive_desc_t> conv_pd_; // nested bwd_w conv pd
        format_tag_t dst_tag_; // recognized diff_dst layout (set in init())

    private:
        // Only the nested convolution's scratchpad is needed.
        void init_scratchpad() {
            auto scratchpad = scratchpad_registry().registrar();
            scratchpad.book(memory_tracking::names::key_nested,
                    conv_pd_->scratchpad_registry());
        }
    };

    ref_deconvolution_bwd_weights_t(const pd_t *apd) : primitive_t(apd) {}

    status_t init(engine_t *engine) override {
        return pd()->conv_pd_->create_primitive(conv_p_, engine);
    }

    status_t execute(const exec_ctx_t &ctx) const override;

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    // Generic f32 bias-gradient reduction over diff_dst.
    void compute_bwd_bias(float *diff_bias, const float *diff_dst) const;

    // Layout-specialized bias-gradient kernels, templated on data types.
    template <data_type_t dbia_type, data_type_t ddst_type>
    void compute_bwd_bias_ncdhw(
            typename prec_traits<dbia_type>::type *diff_bias,
            const typename prec_traits<ddst_type>::type *diff_dst) const;

    template <data_type_t dbia_type, data_type_t ddst_type>
    void compute_bwd_bias_ndhwc(
            typename prec_traits<dbia_type>::type *diff_bias,
            const typename prec_traits<ddst_type>::type *diff_dst) const;

    template <data_type_t dbia_type, data_type_t ddst_type, dim_t blksize>
    void compute_bwd_bias_nCdhwXc(
            typename prec_traits<dbia_type>::type *diff_bias,
            const typename prec_traits<ddst_type>::type *diff_dst) const;

    // Dispatches to one of the compute_bwd_bias* kernels above based on
    // pd()->dst_tag_.
    template <data_type_t dbia_type, data_type_t ddst_type>
    void compute_bias(const exec_ctx_t &ctx) const;
    std::shared_ptr<primitive_t> conv_p_; // nested bwd_w convolution
};
517
518 } // namespace cpu
519 } // namespace impl
520 } // namespace dnnl
521
522 #endif
523
524 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
525