1 /******************************************************************************* 2 * Copyright 2017-2021 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 #ifndef CPU_X64_JIT_SSE41_CONVOLUTION_HPP 18 #define CPU_X64_JIT_SSE41_CONVOLUTION_HPP 19 20 #include "common/c_types_map.hpp" 21 #include "common/dnnl_thread.hpp" 22 #include "common/primitive.hpp" 23 #include "common/utils.hpp" 24 25 #include "cpu/cpu_convolution_pd.hpp" 26 27 #include "cpu/x64/jit_primitive_conf.hpp" 28 #include "cpu/x64/jit_sse41_conv_kernel_f32.hpp" 29 30 namespace dnnl { 31 namespace impl { 32 namespace cpu { 33 namespace x64 { 34 35 struct jit_sse41_convolution_fwd_t : public primitive_t { 36 struct pd_t : public cpu_convolution_fwd_pd_t { pd_tdnnl::impl::cpu::x64::jit_sse41_convolution_fwd_t::pd_t37 pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, 38 const typename pd_t::base_class *hint_fwd_pd) 39 : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd), jcp_() {} 40 41 DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit:", sse41, ""), 42 jit_sse41_convolution_fwd_t); 43 initdnnl::impl::cpu::x64::jit_sse41_convolution_fwd_t::pd_t44 status_t init(engine_t *engine) { 45 using namespace data_type; 46 bool ok = is_fwd() 47 && set_default_alg_kind(alg_kind::convolution_direct) 48 && expect_data_types(f32, f32, f32, f32, f32) 49 && attr()->has_default_values( 50 primitive_attr_t::skip_mask_t::post_ops, f32) 51 && !has_zero_dim_memory() && set_default_formats() 52 && attr_.set_default_formats(dst_md(0)) == status::success; 53 if (!ok) return status::unimplemented; 54 55 CHECK(jit_sse41_conv_fwd_kernel_f32::init_conf(jcp_, *desc(), 56 *src_md(), *weights_md(), *dst_md(), *attr(), 57 dnnl_get_max_threads())); 58 59 return status::success; 60 } 61 62 jit_conv_conf_t jcp_; 63 64 protected: set_default_formatsdnnl::impl::cpu::x64::jit_sse41_convolution_fwd_t::pd_t65 bool set_default_formats() { 66 using namespace format_tag; 67 68 const bool flat = IC() == 3; 69 auto src_tag = flat 70 ? utils::pick(ndims() - 3, ncw, nchw, ncdhw) 71 : utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); 72 auto dst_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); 73 auto wei_tag = with_groups() 74 ? utils::pick(2 * ndims() - 6 + flat, gOIw8i8o, gOwi8o, 75 gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o) 76 : utils::pick(2 * ndims() - 6 + flat, OIw8i8o, Owi8o, 77 OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o); 78 79 return set_default_formats_common(src_tag, wei_tag, dst_tag); 80 } 81 }; 82 jit_sse41_convolution_fwd_tdnnl::impl::cpu::x64::jit_sse41_convolution_fwd_t83 jit_sse41_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {} 84 85 typedef typename prec_traits<data_type::f32>::type data_t; 86 initdnnl::impl::cpu::x64::jit_sse41_convolution_fwd_t87 status_t init(engine_t *engine) override { 88 CHECK(safe_ptr_assign(kernel_, 89 new jit_sse41_conv_fwd_kernel_f32( 90 pd()->jcp_, *pd()->attr(), *pd()->dst_md(0)))); 91 return kernel_->create_kernel(); 92 } 93 executednnl::impl::cpu::x64::jit_sse41_convolution_fwd_t94 status_t execute(const exec_ctx_t &ctx) const override { 95 execute_forward(ctx); 96 return status::success; 97 } 98 99 private: 100 void execute_forward(const exec_ctx_t &ctx) const; pddnnl::impl::cpu::x64::jit_sse41_convolution_fwd_t101 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } 102 std::unique_ptr<jit_sse41_conv_fwd_kernel_f32> kernel_; 103 }; 104 105 } // namespace x64 106 } // namespace cpu 107 } // namespace impl 108 } // namespace dnnl 109 110 #endif 111