/*******************************************************************************
* Copyright 2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_IP_CONVOLUTION_HPP
#define CPU_X64_IP_CONVOLUTION_HPP

#include <string>

#include "common/c_types_map.hpp"
#include "common/primitive.hpp"
#include "common/primitive_iterator.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_convolution_pd.hpp"
#include "cpu/cpu_inner_product_pd.hpp"

#include "cpu/x64/cpu_isa_traits.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

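// Implement convolution via a nested inner product when the two are
// mathematically equivalent: no padding or dilation, unit strides, a
// single group, and unit output spatial dims (see check_conv_ip() below).
// In that case every output point is a dot product of the full input
// spatial volume with the kernel, e.g. a 7x7 kernel applied without
// padding to a 7x7 input yields dst {MB, OC, 1, 1}, which matches an
// inner product with src {MB, IC, 7, 7} and weights {OC, IC, 7, 7}.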
namespace {

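// Squeeze the trivial spatial dims out of the convolution dst so it can
// serve as an inner product dst, e.g. {MB, OC, 1, 1} -> {MB, OC}.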
status_t reshape_dst(memory_desc_t *o_md, const memory_desc_t *i_md) {
    dims_t reduce {};
    const dim_t ndims = 2; // dst is always nc for inner product
    // conv to ip: remove spatial
    for (int d = 0; d < ndims; ++d)
        reduce[d] = i_md->dims[d];

    return dnnl_memory_desc_reshape(o_md, i_md, ndims, reduce);
}

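// Drop (to_ip) or restore (!to_ip) the leading groups dim of the weights,
// e.g. {G, OC/G, IC, KH, KW} <-> {OC, IC, KH, KW}. Reshape preserves the
// total number of elements, so this is only a reinterpretation; note that
// check_conv_ip() guarantees G == 1 for every shape reaching this helper.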
status_t maybe_reshape_weights(memory_desc_t *o_md, const memory_desc_t *i_md,
        bool with_groups, bool to_ip = false) {
    dims_t reduce {};
    const dim_t ndims = i_md->ndims + (to_ip ? -1 : +1) * with_groups;
    if (to_ip) {
        // conv to ip: maybe remove groups
        for (int d = 0; d < ndims; ++d)
            reduce[d] = i_md->dims[d + with_groups];
    } else {
        // ip to conv: maybe restore groups
        if (with_groups) reduce[0] = 1;
        for (int d = 0; d < ndims; ++d)
            reduce[d + with_groups] = i_md->dims[d];
    }

    return dnnl_memory_desc_reshape(o_md, i_md, ndims, reduce);
}

status_t check_conv_ip(convolution_pd_t *self) {
    // Check if convolution is equivalent to inner product
    const bool is_ip_applicable = true
            // no dilations
            && utils::everyone_is(0, self->KDD(), self->KDH(), self->KDW())
            // no "left" padding
            && utils::everyone_is(
                    0, self->padFront(), self->padT(), self->padL())
            // no "right" padding
            && utils::everyone_is(
                    0, self->padBack(), self->padB(), self->padR())
            // no non-trivial groups or output spatial
            && utils::everyone_is(
                    1, self->G(), self->OD(), self->OH(), self->OW())
            // only unit stride
            && utils::everyone_is(1, self->KSD(), self->KSH(), self->KSW());
    if (!is_ip_applicable) return status::unimplemented;

    // Simple heuristic to only target arches and shapes that benefit.
    // TODO: Extend to other arches and shapes as performance allows.
    const dim_t ks = self->KD() * self->KH() * self->KW();
    const dim_t ks_threshold = 27; // empirical
    const bool is_performant
            = 1 < self->MB() && ks > ks_threshold && mayiuse(avx512_core);
    if (!is_performant) return status::unimplemented;

    return status::success;
}

status_t check_tag(memory_desc_t &md, const format_tag_t tag) {
    const memory_desc_wrapper mdw(&md);
    if (mdw.matches_one_of_tag(tag) == format_tag::undef)
        return status::unimplemented;
    return status::success;
}

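// Either set the plain nspc layout (nwc/nhwc/ndhwc) on descriptors that
// are still format_kind::any, when the heuristic below allows it, or
// verify that user-fixed descriptors already use that layout; anything
// else is rejected so the implementation never introduces blocked tags.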
status_t set_and_or_check_formats(const convolution_desc_t &desc,
        memory_desc_t &src_md, memory_desc_t &weights_md, memory_desc_t &dst_md,
        memory_desc_t &bias_md, primitive_attr_t &attr) {
    using namespace format_tag;
    auto atag = utils::pick(src_md.ndims - 3, nwc, nhwc, ndhwc);
    const bool is_fwd = utils::one_of(desc.prop_kind,
            prop_kind::forward_training, prop_kind::forward_inference);
    const bool with_bias = desc.prop_kind != prop_kind::backward_data;

    // Check that nspc is the default layout for convolutions,
    // or that the expected performance gain outweighs the potential
    // cost of extra reorders.
    // Currently this means:
    // - int8 with any forward prop_kind on any isa
    // - fp32/bf16 with any prop_kind on avx512_core and higher
    const bool is_set_allowed = false
            || (utils::one_of(
                        weights_md.data_type, data_type::f32, data_type::bf16)
                    && mayiuse(avx512_core))
            || (is_fwd && weights_md.data_type == data_type::s8);

    // NOTE: Only plain layouts should be supported since the dims of
    // dst_md_ must be reshaped from {N, C, H, W} to {N, C}. If the
    // conv layout is blocked by channel, then the ip layout will also
    // be blocked by channel (e.g. nChw16c -> nC16c). This can lead to
    // the reference ip being dispatched, as well as strange weights layouts.
    if (is_set_allowed && src_md.format_kind == format_kind::any)
        CHECK(memory_desc_init_by_tag(src_md, atag));
    else
        CHECK(check_tag(src_md, atag));
    if (is_set_allowed && dst_md.format_kind == format_kind::any)
        CHECK(memory_desc_init_by_tag(dst_md, atag));
    else
        CHECK(check_tag(dst_md, atag));
    if (with_bias && bias_md.format_kind != format_kind::undef) {
        auto btag = x;
        if (bias_md.format_kind == format_kind::any)
            CHECK(memory_desc_init_by_tag(bias_md, btag));
        else
            CHECK(check_tag(bias_md, btag));
    }
    return attr.set_default_formats(&dst_md);
}

} // namespace

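// Forward convolution dispatched to a nested inner product primitive.
// pd_t::init() reshapes the conv memory descriptors to ip equivalents,
// then iterates over ip implementations and takes the first one whose
// weights need no extra compensation (extra.flags == 0); its scratchpad
// is booked under key_nested, and execution forwards to the nested ip
// primitive created in init().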
struct ip_convolution_fwd_t : public primitive_t {
    struct pd_t : public cpu_convolution_fwd_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const convolution_fwd_pd_t *hint_fwd_pd)
            : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd) {}

        pd_t(const pd_t &other)
            : cpu_convolution_fwd_pd_t(other), ip_pd_(other.ip_pd_->clone()) {}

        ~pd_t() = default;

        DECLARE_COMMON_PD_T(name_.c_str(), ip_convolution_fwd_t);

        status_t init_ip(engine_t *engine) {
            inner_product_desc_t ipd;
            CHECK(ip_desc_create(&ipd));
            dnnl_primitive_desc_iterator it(
                    engine, (op_desc_t *)&ipd, attr(), nullptr);
            if (!it.is_initialized()) return status::out_of_memory;

            while (++it != it.end()) {
                ip_pd_ = *it;
                const bool ok = ip_pd_->weights_md()->extra.flags == 0;
                if (ok) return status::success;
            }
            return status::unimplemented;
        }

        status_t init(engine_t *engine) {
            using namespace format_tag;
            using smask_t = primitive_attr_t::skip_mask_t;

            const bool ok = is_fwd()
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && attr()->has_default_values(
                            smask_t::oscale | smask_t::post_ops);
            if (!ok) return status::unimplemented;

            CHECK(check_conv_ip(this));

            CHECK(set_and_or_check_formats(
                    *desc(), src_md_, weights_md_, dst_md_, bias_md_, attr_));

            CHECK(init_ip(engine));

            if (weights_md_.format_kind == format_kind::any)
                CHECK(maybe_reshape_weights(
                        &weights_md_, ip_pd_->weights_md(), with_groups()));

            init_name();
            init_scratchpad();
            return status::success;
        }

        std::shared_ptr<primitive_desc_t> ip_pd_;

    private:
        std::string name_ = "ip:";

        void init_name() { name_.append(ip_pd_->name()); }

        void init_scratchpad() {
            using namespace memory_tracking::names;
            auto scratchpad = scratchpad_registry().registrar();
            scratchpad.book(key_nested, ip_pd_->scratchpad_registry());
        }

        status_t ip_desc_create(inner_product_desc_t *ipd) {
            const bool to_ip = true;

            // reinterpret dst without spatial
            memory_desc_t ip_dst_d;
            CHECK(reshape_dst(&ip_dst_d, &dst_md_));

            // reinterpret weights without groups
            memory_desc_t ip_weights_d;
            CHECK(maybe_reshape_weights(
                    &ip_weights_d, &weights_md_, with_groups(), to_ip));

            return ip_desc_init(ipd, desc()->prop_kind, &src_md_, &ip_weights_d,
                    &bias_md_, &ip_dst_d);
        }
    };

    ip_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {}

    status_t init(engine_t *engine) override {
        CHECK(pd()->ip_pd_->create_primitive(ip_p_, engine));
        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override;

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    std::shared_ptr<primitive_t> ip_p_;
};

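// Backward-data counterpart: identical dispatch scheme, with diff_src and
// diff_dst taking the roles of src and dst, and no bias passed to the
// inner product descriptor.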
struct ip_convolution_bwd_data_t : public primitive_t {
    struct pd_t : public cpu_convolution_bwd_data_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const convolution_fwd_pd_t *hint_fwd_pd)
            : cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd) {}

        pd_t(const pd_t &other)
            : cpu_convolution_bwd_data_pd_t(other)
            , ip_pd_(other.ip_pd_->clone()) {}

        ~pd_t() = default;

        DECLARE_COMMON_PD_T(name_.c_str(), ip_convolution_bwd_data_t);

        status_t init_ip(engine_t *engine) {
            inner_product_desc_t ipd;
            CHECK(ip_desc_create(&ipd));
            dnnl_primitive_desc_iterator it(
                    engine, (op_desc_t *)&ipd, attr(), nullptr);
            if (!it.is_initialized()) return status::out_of_memory;
            while (++it != it.end()) {
                ip_pd_ = *it;
                const bool ok = ip_pd_->weights_md()->extra.flags == 0;
                if (ok) return status::success;
            }
            return status::unimplemented;
        }

        status_t init(engine_t *engine) {
            using namespace format_tag;

            const bool ok = desc()->prop_kind == prop_kind::backward_data
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && attr()->has_default_values();
            if (!ok) return status::unimplemented;

            CHECK(check_conv_ip(this));

            CHECK(set_and_or_check_formats(*desc(), diff_src_md_, weights_md_,
                    diff_dst_md_, bias_md_, attr_));

            CHECK(init_ip(engine));

            if (weights_md_.format_kind == format_kind::any)
                CHECK(maybe_reshape_weights(
                        &weights_md_, ip_pd_->weights_md(), with_groups()));

            init_name();
            init_scratchpad();
            return status::success;
        }

        std::shared_ptr<primitive_desc_t> ip_pd_;

    private:
        std::string name_ = "ip:";

        void init_name() { name_.append(ip_pd_->name()); }

        void init_scratchpad() {
            using namespace memory_tracking::names;
            auto scratchpad = scratchpad_registry().registrar();
            scratchpad.book(key_nested, ip_pd_->scratchpad_registry());
        }

        status_t ip_desc_create(inner_product_desc_t *ipd) {
            const bool to_ip = true;

            // reinterpret dst without spatial
            memory_desc_t ip_diff_dst_d;
            CHECK(reshape_dst(&ip_diff_dst_d, &diff_dst_md_));

            // reinterpret weights without groups
            memory_desc_t ip_weights_d;
            CHECK(maybe_reshape_weights(
                    &ip_weights_d, &weights_md_, with_groups(), to_ip));

            return ip_desc_init(ipd, desc()->prop_kind, &diff_src_md_,
                    &ip_weights_d, nullptr, &ip_diff_dst_d);
        }
    };

    ip_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {}

    status_t init(engine_t *engine) override {
        CHECK(pd()->ip_pd_->create_primitive(ip_p_, engine));
        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override;

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    std::shared_ptr<primitive_t> ip_p_;
};

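// Backward-weights counterpart: here the nested inner product produces
// diff_weights (and diff_bias), so init() may reshape the ip
// diff_weights_md back to the grouped convolution layout.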
struct ip_convolution_bwd_weights_t : public primitive_t {
    struct pd_t : public cpu_convolution_bwd_weights_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const convolution_fwd_pd_t *hint_fwd_pd)
            : cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) {}

        pd_t(const pd_t &other)
            : cpu_convolution_bwd_weights_pd_t(other)
            , ip_pd_(other.ip_pd_->clone()) {}

        ~pd_t() = default;

        DECLARE_COMMON_PD_T(name_.c_str(), ip_convolution_bwd_weights_t);

        status_t init_ip(engine_t *engine) {
            inner_product_desc_t ipd;
            CHECK(ip_desc_create(&ipd));
            dnnl_primitive_desc_iterator it(
                    engine, (op_desc_t *)&ipd, attr(), nullptr);
            if (!it.is_initialized()) return status::out_of_memory;

            while (++it != it.end()) {
                ip_pd_ = *it;
                const bool ok = ip_pd_->weights_md()->extra.flags == 0;
                if (ok) return status::success;
            }
            return status::unimplemented;
        }

        status_t init(engine_t *engine) {
            using namespace format_tag;

            const bool ok = desc()->prop_kind == prop_kind::backward_weights
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && attr()->has_default_values();
            if (!ok) return status::unimplemented;

            CHECK(check_conv_ip(this));

            CHECK(set_and_or_check_formats(*desc(), src_md_, diff_weights_md_,
                    diff_dst_md_, diff_bias_md_, attr_));

            CHECK(init_ip(engine));

            if (diff_weights_md_.format_kind == format_kind::any)
                CHECK(maybe_reshape_weights(&diff_weights_md_,
                        ip_pd_->diff_weights_md(), with_groups()));

            init_name();
            init_scratchpad();
            return status::success;
        }

        std::shared_ptr<primitive_desc_t> ip_pd_;

    private:
        std::string name_ = "ip:";

        void init_name() { name_.append(ip_pd_->name()); }

        void init_scratchpad() {
            using namespace memory_tracking::names;
            auto scratchpad = scratchpad_registry().registrar();
            scratchpad.book(key_nested, ip_pd_->scratchpad_registry());
        }

        status_t ip_desc_create(inner_product_desc_t *ipd) {
            const bool to_ip = true;

            // reinterpret dst without spatial
            memory_desc_t ip_diff_dst_d;
            CHECK(reshape_dst(&ip_diff_dst_d, &diff_dst_md_));

            // reinterpret weights without groups
            memory_desc_t ip_diff_weights_d;
            CHECK(maybe_reshape_weights(&ip_diff_weights_d, &diff_weights_md_,
                    with_groups(), to_ip));

            return ip_desc_init(ipd, desc()->prop_kind, &src_md_,
                    &ip_diff_weights_d, &diff_bias_md_, &ip_diff_dst_d);
        }
    };

    ip_convolution_bwd_weights_t(const pd_t *apd) : primitive_t(apd) {}

    status_t init(engine_t *engine) override {
        CHECK(pd()->ip_pd_->create_primitive(ip_p_, engine));
        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override;

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    std::shared_ptr<primitive_t> ip_p_;
};

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s