1 /*******************************************************************************
2 * Copyright 2017-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef CPU_X64_JIT_SSE41_CONVOLUTION_HPP
18 #define CPU_X64_JIT_SSE41_CONVOLUTION_HPP
19 
20 #include "common/c_types_map.hpp"
21 #include "common/dnnl_thread.hpp"
22 #include "common/primitive.hpp"
23 #include "common/utils.hpp"
24 
25 #include "cpu/cpu_convolution_pd.hpp"
26 
27 #include "cpu/x64/jit_primitive_conf.hpp"
28 #include "cpu/x64/jit_sse41_conv_kernel_f32.hpp"
29 
30 namespace dnnl {
31 namespace impl {
32 namespace cpu {
33 namespace x64 {
34 
35 struct jit_sse41_convolution_fwd_t : public primitive_t {
36     struct pd_t : public cpu_convolution_fwd_pd_t {
pd_tdnnl::impl::cpu::x64::jit_sse41_convolution_fwd_t::pd_t37         pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
38                 const typename pd_t::base_class *hint_fwd_pd)
39             : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd), jcp_() {}
40 
41         DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit:", sse41, ""),
42                 jit_sse41_convolution_fwd_t);
43 
initdnnl::impl::cpu::x64::jit_sse41_convolution_fwd_t::pd_t44         status_t init(engine_t *engine) {
45             using namespace data_type;
46             bool ok = is_fwd()
47                     && set_default_alg_kind(alg_kind::convolution_direct)
48                     && expect_data_types(f32, f32, f32, f32, f32)
49                     && attr()->has_default_values(
50                             primitive_attr_t::skip_mask_t::post_ops, f32)
51                     && !has_zero_dim_memory() && set_default_formats()
52                     && attr_.set_default_formats(dst_md(0)) == status::success;
53             if (!ok) return status::unimplemented;
54 
55             CHECK(jit_sse41_conv_fwd_kernel_f32::init_conf(jcp_, *desc(),
56                     *src_md(), *weights_md(), *dst_md(), *attr(),
57                     dnnl_get_max_threads()));
58 
59             return status::success;
60         }
61 
62         jit_conv_conf_t jcp_;
63 
64     protected:
set_default_formatsdnnl::impl::cpu::x64::jit_sse41_convolution_fwd_t::pd_t65         bool set_default_formats() {
66             using namespace format_tag;
67 
68             const bool flat = IC() == 3;
69             auto src_tag = flat
70                     ? utils::pick(ndims() - 3, ncw, nchw, ncdhw)
71                     : utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
72             auto dst_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
73             auto wei_tag = with_groups()
74                     ? utils::pick(2 * ndims() - 6 + flat, gOIw8i8o, gOwi8o,
75                             gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o)
76                     : utils::pick(2 * ndims() - 6 + flat, OIw8i8o, Owi8o,
77                             OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o);
78 
79             return set_default_formats_common(src_tag, wei_tag, dst_tag);
80         }
81     };
82 
jit_sse41_convolution_fwd_tdnnl::impl::cpu::x64::jit_sse41_convolution_fwd_t83     jit_sse41_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {}
84 
85     typedef typename prec_traits<data_type::f32>::type data_t;
86 
initdnnl::impl::cpu::x64::jit_sse41_convolution_fwd_t87     status_t init(engine_t *engine) override {
88         CHECK(safe_ptr_assign(kernel_,
89                 new jit_sse41_conv_fwd_kernel_f32(
90                         pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));
91         return kernel_->create_kernel();
92     }
93 
executednnl::impl::cpu::x64::jit_sse41_convolution_fwd_t94     status_t execute(const exec_ctx_t &ctx) const override {
95         execute_forward(ctx);
96         return status::success;
97     }
98 
99 private:
100     void execute_forward(const exec_ctx_t &ctx) const;
pddnnl::impl::cpu::x64::jit_sse41_convolution_fwd_t101     const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
102     std::unique_ptr<jit_sse41_conv_fwd_kernel_f32> kernel_;
103 };
104 
105 } // namespace x64
106 } // namespace cpu
107 } // namespace impl
108 } // namespace dnnl
109 
110 #endif
111