/*******************************************************************************
* Copyright 2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_IP_CONVOLUTION_HPP
#define CPU_X64_IP_CONVOLUTION_HPP

#include <string>

#include "common/c_types_map.hpp"
#include "common/primitive.hpp"
#include "common/primitive_iterator.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_convolution_pd.hpp"
#include "cpu/cpu_inner_product_pd.hpp"

#include "cpu/x64/cpu_isa_traits.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

namespace {

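// Reinterpret the convolution dst as the 2D {MB, OC} tensor expected by
// inner product. Since check_conv_ip() below only admits convolutions with
// unit output spatial, this turns, eg, a conv dst of {MB, OC, 1, 1} into an
// ip dst of {MB, OC} without touching the underlying data.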
status_t reshape_dst(memory_desc_t *o_md, const memory_desc_t *i_md) {
    dims_t reduce {};
    const dim_t ndims = 2; // dst is always nc for inner product
    // conv to ip: remove spatial
    for (int d = 0; d < ndims; ++d)
        reduce[d] = i_md->dims[d];

    return dnnl_memory_desc_reshape(o_md, i_md, ndims, reduce);
}

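// Reinterpret weights between conv and ip layouts: converting to ip
// (to_ip == true) drops the leading unit groups dim, eg {1, OC, IC, KH, KW}
// -> {OC, IC, KH, KW}; converting back restores it. With no groups this
// reshape leaves the dims unchanged.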
status_t maybe_reshape_weights(memory_desc_t *o_md, const memory_desc_t *i_md,
        bool with_groups, bool to_ip = false) {
    dims_t reduce {};
    const dim_t ndims = i_md->ndims + (to_ip ? -1 : +1) * with_groups;
    if (to_ip) {
        // conv to ip: maybe remove groups
        for (int d = 0; d < ndims; ++d)
            reduce[d] = i_md->dims[d + with_groups];
    } else {
        // ip to conv: maybe restore groups
        if (with_groups) reduce[0] = 1;
        for (int d = 0; d < ndims; ++d)
            reduce[d + with_groups] = i_md->dims[d];
    }

    return dnnl_memory_desc_reshape(o_md, i_md, ndims, reduce);
}

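// Decide whether the convolution is mathematically an inner product and
// whether computing it as one is expected to pay off: the shape checks below
// make the two primitives equivalent, and the heuristic then requires a
// batched problem with a kernel larger than 3x3x3 (ks > 27) on avx512_core.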
status_t check_conv_ip(convolution_pd_t *self) {
    // Check if convolution is equivalent to inner product
    const bool is_ip_applicable = true
            // no dilations
            && utils::everyone_is(0, self->KDD(), self->KDH(), self->KDW())
            // no "left" padding
            && utils::everyone_is(
                    0, self->padFront(), self->padT(), self->padL())
            // no "right" padding
            && utils::everyone_is(
                    0, self->padBack(), self->padB(), self->padR())
            // no non-trivial groups or output spatial
            && utils::everyone_is(
                    1, self->G(), self->OD(), self->OH(), self->OW())
            // only unit stride
            && utils::everyone_is(1, self->KSD(), self->KSH(), self->KSW());
    if (!is_ip_applicable) return status::unimplemented;

    // Simple heuristic to only target arches and shapes that benefit.
    // TODO: Extend to other arches and shapes as performance allows.
    const dim_t ks = self->KD() * self->KH() * self->KW();
    const dim_t ks_threshold = 27; // empirical
    const bool is_performant
            = 1 < self->MB() && ks > ks_threshold && mayiuse(avx512_core);
    if (!is_performant) return status::unimplemented;

    return status::success;
}

status_t check_tag(memory_desc_t &md, const format_tag_t tag) {
    const memory_desc_wrapper mdw(&md);
    if (mdw.matches_one_of_tag(tag) == format_tag::undef)
        return status::unimplemented;
    return status::success;
}

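// Set plain (nspc) layouts on any descriptor still in format_kind::any and
// verify that descriptors with fixed layouts already match; eg a 4D problem
// picks nhwc for src/dst and x for bias.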
status_t set_and_or_check_formats(const convolution_desc_t &desc,
        memory_desc_t &src_md, memory_desc_t &weights_md, memory_desc_t &dst_md,
        memory_desc_t &bias_md) {
    using namespace format_tag;
    auto atag = utils::pick(src_md.ndims - 3, nwc, nhwc, ndhwc);
    const bool is_fwd = utils::one_of(desc.prop_kind,
            prop_kind::forward_training, prop_kind::forward_inference);
    const bool with_bias = desc.prop_kind != prop_kind::backward_data;

    // Check that nspc is the default layout for convolutions,
    // or that the expected performance gain outweighs the potential
    // cost of extra reorders.
    // Currently this means:
    // - int8 with any forward prop_kind on any isa
    // - fp32/bf16 with any prop_kind on avx512_core and higher
    const bool is_set_allowed = false
            || (utils::one_of(
                        weights_md.data_type, data_type::f32, data_type::bf16)
                    && mayiuse(avx512_core))
            || (is_fwd && weights_md.data_type == data_type::s8);

    // NOTE: Only plain layouts should be supported since the dims of
    // dst_md_ must be reshaped from {N, C, H, W} to {N, C}. If the
    // conv layout is blocked by channel, then the ip layout will also
    // be blocked by channel (eg nChw16c -> nC16c). This can lead to
    // deployment of reference ip as well as strange weights layouts.
    if (is_set_allowed && src_md.format_kind == format_kind::any)
        CHECK(memory_desc_init_by_tag(src_md, atag));
    else
        CHECK(check_tag(src_md, atag));
    if (is_set_allowed && dst_md.format_kind == format_kind::any)
        CHECK(memory_desc_init_by_tag(dst_md, atag));
    else
        CHECK(check_tag(dst_md, atag));
    if (with_bias && bias_md.format_kind != format_kind::undef) {
        auto btag = x;
        if (bias_md.format_kind == format_kind::any)
            CHECK(memory_desc_init_by_tag(bias_md, btag));
        else
            CHECK(check_tag(bias_md, btag));
    }
    return status::success;
}

} // namespace

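// Forward convolution implemented by dispatching to an equivalent inner
// product: the pd instantiates a nested ip pd over reshaped memory
// descriptors, and execution is forwarded to the nested ip primitive.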
struct ip_convolution_fwd_t : public primitive_t {
    struct pd_t : public cpu_convolution_fwd_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const convolution_fwd_pd_t *hint_fwd_pd)
            : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd) {}

        pd_t(const pd_t &other)
            : cpu_convolution_fwd_pd_t(other), ip_pd_(other.ip_pd_->clone()) {}

        ~pd_t() = default;

        DECLARE_COMMON_PD_T(name_.c_str(), ip_convolution_fwd_t);

        status_t init_ip(engine_t *engine) {
            inner_product_desc_t ipd;
            CHECK(ip_desc_create(&ipd));
            dnnl_primitive_desc_iterator it(
                    engine, (op_desc_t *)&ipd, attr(), nullptr);
            if (!it.is_initialized()) return status::out_of_memory;

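            // Pick the first ip implementation whose weights descriptor
            // needs no extra metadata (eg s8s8 compensation), since such
            // weights would not map back to the conv weights by a plain
            // reshape.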
            while (++it != it.end()) {
                ip_pd_ = *it;
                const bool ok = ip_pd_->weights_md()->extra.flags == 0;
                if (ok) return status::success;
            }
            return status::unimplemented;
        }

        status_t init(engine_t *engine) {
            using namespace format_tag;
            using smask_t = primitive_attr_t::skip_mask_t;

            const bool ok = is_fwd()
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && attr()->has_default_values(
                            smask_t::oscale | smask_t::post_ops);
            if (!ok) return status::unimplemented;

            CHECK(check_conv_ip(this));

            CHECK(set_and_or_check_formats(
                    *desc(), src_md_, weights_md_, dst_md_, bias_md_));

            CHECK(init_ip(engine));

            if (weights_md_.format_kind == format_kind::any)
                CHECK(maybe_reshape_weights(
                        &weights_md_, ip_pd_->weights_md(), with_groups()));

            init_name();
            init_scratchpad();
            return status::success;
        }

        std::shared_ptr<primitive_desc_t> ip_pd_;

    private:
        std::string name_ = "ip:";

        void init_name() { name_.append(ip_pd_->name()); }

        void init_scratchpad() {
            using namespace memory_tracking::names;
            auto scratchpad = scratchpad_registry().registrar();
            scratchpad.book(key_nested, ip_pd_->scratchpad_registry());
        }

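        // Build the inner product descriptor from the conv descriptors:
        // dst loses its unit spatial dims and weights lose their unit
        // groups dim (if any) before the ip descriptor is initialized.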
        status_t ip_desc_create(inner_product_desc_t *ipd) {
            const bool to_ip = true;

            // reinterpret dst without spatial
            memory_desc_t ip_dst_d;
            CHECK(reshape_dst(&ip_dst_d, &dst_md_));

            // reinterpret weights without groups
            memory_desc_t ip_weights_d;
            CHECK(maybe_reshape_weights(
                    &ip_weights_d, &weights_md_, with_groups(), to_ip));

            return ip_desc_init(ipd, desc()->prop_kind, &src_md_, &ip_weights_d,
                    &bias_md_, &ip_dst_d);
        }
    };

    ip_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {}

    status_t init(engine_t *engine) override {
        CHECK(pd()->ip_pd_->create_primitive(ip_p_, engine));
        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override;

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    std::shared_ptr<primitive_t> ip_p_;
};

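// Backward-data counterpart: diff_src is produced by the nested inner
// product's backward-data implementation over the same reshaped descriptors.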
struct ip_convolution_bwd_data_t : public primitive_t {
    struct pd_t : public cpu_convolution_bwd_data_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const convolution_fwd_pd_t *hint_fwd_pd)
            : cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd) {}

        pd_t(const pd_t &other)
            : cpu_convolution_bwd_data_pd_t(other)
            , ip_pd_(other.ip_pd_->clone()) {}

        ~pd_t() = default;

        DECLARE_COMMON_PD_T(name_.c_str(), ip_convolution_bwd_data_t);

        status_t init_ip(engine_t *engine) {
            inner_product_desc_t ipd;
            CHECK(ip_desc_create(&ipd));
            dnnl_primitive_desc_iterator it(
                    engine, (op_desc_t *)&ipd, attr(), nullptr);
            if (!it.is_initialized()) return status::out_of_memory;
            while (++it != it.end()) {
                ip_pd_ = *it;
                const bool ok = ip_pd_->weights_md()->extra.flags == 0;
                if (ok) return status::success;
            }
            return status::unimplemented;
        }

        status_t init(engine_t *engine) {
            using namespace format_tag;

            const bool ok = desc()->prop_kind == prop_kind::backward_data
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && attr()->has_default_values();
            if (!ok) return status::unimplemented;

            CHECK(check_conv_ip(this));

            CHECK(set_and_or_check_formats(*desc(), diff_src_md_, weights_md_,
                    diff_dst_md_, bias_md_));

            CHECK(init_ip(engine));

            if (weights_md_.format_kind == format_kind::any)
                CHECK(maybe_reshape_weights(
                        &weights_md_, ip_pd_->weights_md(), with_groups()));

            init_name();
            init_scratchpad();
            return status::success;
        }

        std::shared_ptr<primitive_desc_t> ip_pd_;

    private:
        std::string name_ = "ip:";

        void init_name() { name_.append(ip_pd_->name()); }

        void init_scratchpad() {
            using namespace memory_tracking::names;
            auto scratchpad = scratchpad_registry().registrar();
            scratchpad.book(key_nested, ip_pd_->scratchpad_registry());
        }

        status_t ip_desc_create(inner_product_desc_t *ipd) {
            const bool to_ip = true;

            // reinterpret dst without spatial
            memory_desc_t ip_diff_dst_d;
            CHECK(reshape_dst(&ip_diff_dst_d, &diff_dst_md_));

            // reinterpret weights without groups
            memory_desc_t ip_weights_d;
            CHECK(maybe_reshape_weights(
                    &ip_weights_d, &weights_md_, with_groups(), to_ip));

            return ip_desc_init(ipd, desc()->prop_kind, &diff_src_md_,
                    &ip_weights_d, nullptr, &ip_diff_dst_d);
        }
    };

    ip_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {}

    status_t init(engine_t *engine) override {
        CHECK(pd()->ip_pd_->create_primitive(ip_p_, engine));
        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override;

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    std::shared_ptr<primitive_t> ip_p_;
};

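// Backward-weights counterpart: diff_weights and diff_bias are produced by
// the nested inner product's backward-weights implementation.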
struct ip_convolution_bwd_weights_t : public primitive_t {
    struct pd_t : public cpu_convolution_bwd_weights_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const convolution_fwd_pd_t *hint_fwd_pd)
            : cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) {}

        pd_t(const pd_t &other)
            : cpu_convolution_bwd_weights_pd_t(other)
            , ip_pd_(other.ip_pd_->clone()) {}

        ~pd_t() = default;

        DECLARE_COMMON_PD_T(name_.c_str(), ip_convolution_bwd_weights_t);

        status_t init_ip(engine_t *engine) {
            inner_product_desc_t ipd;
            CHECK(ip_desc_create(&ipd));
            dnnl_primitive_desc_iterator it(
                    engine, (op_desc_t *)&ipd, attr(), nullptr);
            if (!it.is_initialized()) return status::out_of_memory;

            while (++it != it.end()) {
                ip_pd_ = *it;
                const bool ok = ip_pd_->weights_md()->extra.flags == 0;
                if (ok) return status::success;
            }
            return status::unimplemented;
        }

        status_t init(engine_t *engine) {
            using namespace format_tag;

            const bool ok = desc()->prop_kind == prop_kind::backward_weights
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && attr()->has_default_values();
            if (!ok) return status::unimplemented;

            CHECK(check_conv_ip(this));

            CHECK(set_and_or_check_formats(*desc(), src_md_, diff_weights_md_,
                    diff_dst_md_, diff_bias_md_));

            CHECK(init_ip(engine));

            if (diff_weights_md_.format_kind == format_kind::any)
                CHECK(maybe_reshape_weights(&diff_weights_md_,
                        ip_pd_->diff_weights_md(), with_groups()));

            init_name();
            init_scratchpad();
            return status::success;
        }

        std::shared_ptr<primitive_desc_t> ip_pd_;

    private:
        std::string name_ = "ip:";

        void init_name() { name_.append(ip_pd_->name()); }

        void init_scratchpad() {
            using namespace memory_tracking::names;
            auto scratchpad = scratchpad_registry().registrar();
            scratchpad.book(key_nested, ip_pd_->scratchpad_registry());
        }

        status_t ip_desc_create(inner_product_desc_t *ipd) {
            const bool to_ip = true;

            // reinterpret dst without spatial
            memory_desc_t ip_diff_dst_d;
            CHECK(reshape_dst(&ip_diff_dst_d, &diff_dst_md_));

            // reinterpret weights without groups
            memory_desc_t ip_diff_weights_d;
            CHECK(maybe_reshape_weights(&ip_diff_weights_d, &diff_weights_md_,
                    with_groups(), to_ip));

            return ip_desc_init(ipd, desc()->prop_kind, &src_md_,
                    &ip_diff_weights_d, &diff_bias_md_, &ip_diff_dst_d);
        }
    };

    ip_convolution_bwd_weights_t(const pd_t *apd) : primitive_t(apd) {}

    status_t init(engine_t *engine) override {
        CHECK(pd()->ip_pd_->create_primitive(ip_p_, engine));
        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override;

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    std::shared_ptr<primitive_t> ip_p_;
};

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s