1 /******************************************************************************* 2 * Copyright 2019-2021 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 #ifndef CPU_MATMUL_REF_MATMUL_HPP 18 #define CPU_MATMUL_REF_MATMUL_HPP 19 20 #include <assert.h> 21 22 #include "common/bfloat16.hpp" 23 #include "common/c_types_map.hpp" 24 #include "common/primitive.hpp" 25 #include "common/type_helpers.hpp" 26 #include "common/utils.hpp" 27 28 #include "cpu/platform.hpp" 29 #include "cpu/primitive_attr_postops.hpp" 30 31 #include "cpu/matmul/cpu_matmul_pd.hpp" 32 33 namespace dnnl { 34 namespace impl { 35 namespace cpu { 36 namespace matmul { 37 38 template <impl::data_type_t src_type, impl::data_type_t weights_type = src_type, 39 impl::data_type_t dst_type = src_type, 40 impl::data_type_t acc_type = dst_type> 41 struct ref_matmul_t : public primitive_t { 42 struct pd_t : public cpu_matmul_pd_t { 43 using cpu_matmul_pd_t::cpu_matmul_pd_t; 44 45 DECLARE_COMMON_PD_T("ref:any", ref_matmul_t); 46 initdnnl::impl::cpu::matmul::ref_matmul_t::pd_t47 status_t init(engine_t *engine) { 48 using namespace data_type; 49 using smask_t = primitive_attr_t::skip_mask_t; 50 51 bool ok = src_md()->data_type == src_type 52 && weights_md()->data_type == weights_type 53 && desc()->accum_data_type == acc_type 54 && dst_md()->data_type == dst_type 55 && platform::has_data_type_support(src_type) 56 && attr()->has_default_values(smask_t::oscale_runtime 57 | smask_t::zero_points_runtime | smask_t::post_ops) 58 && attr_oscale_ok() && attr_zero_points_ok() 59 && set_default_formats(); 60 61 if (with_bias()) { 62 auto bia_dt = weights_md(1)->data_type; 63 if (acc_type == f32) 64 ok = ok && utils::one_of(bia_dt, f32); 65 else if (acc_type == s32) 66 ok = ok && utils::one_of(bia_dt, f32, s32, s8, u8); 67 } 68 return ok ? status::success : status::unimplemented; 69 } 70 71 private: attr_oscale_okdnnl::impl::cpu::matmul::ref_matmul_t::pd_t72 bool attr_oscale_ok() const { 73 const auto &oscale = attr()->output_scales_; 74 return oscale.mask_ == 0 || oscale.mask_ == (1 << (batched() + 1)); 75 } 76 attr_zero_points_okdnnl::impl::cpu::matmul::ref_matmul_t::pd_t77 bool attr_zero_points_ok() const { 78 int mask_src = 0, mask_wei = 0, mask_dst = 0; 79 attr()->zero_points_.get(DNNL_ARG_SRC, nullptr, &mask_src, nullptr); 80 attr()->zero_points_.get( 81 DNNL_ARG_WEIGHTS, nullptr, &mask_wei, nullptr); 82 attr()->zero_points_.get(DNNL_ARG_DST, nullptr, &mask_dst, nullptr); 83 84 return IMPLICATION(acc_type != data_type::s32, 85 attr()->zero_points_.has_default_values()) 86 && (mask_src == 0 || mask_src == 1 << 1) && (mask_wei == 0) 87 && (mask_dst == 0 || mask_dst == 1 << 1); 88 } 89 }; 90 ref_matmul_tdnnl::impl::cpu::matmul::ref_matmul_t91 ref_matmul_t(const pd_t *apd) : primitive_t(apd) {} 92 initdnnl::impl::cpu::matmul::ref_matmul_t93 status_t init(engine_t *engine) override { 94 ref_post_ops 95 = utils::make_unique<ref_post_ops_t>(pd()->attr()->post_ops_); 96 if (!ref_post_ops) return status::out_of_memory; 97 return status::success; 98 } 99 100 typedef typename prec_traits<src_type>::type src_data_t; 101 typedef typename prec_traits<weights_type>::type weights_data_t; 102 typedef typename prec_traits<dst_type>::type dst_data_t; 103 typedef typename prec_traits<acc_type>::type acc_data_t; 104 executednnl::impl::cpu::matmul::ref_matmul_t105 status_t execute(const exec_ctx_t &ctx) const override { 106 return execute_ref(ctx); 107 } 108 109 private: pddnnl::impl::cpu::matmul::ref_matmul_t110 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } 111 status_t execute_ref(const exec_ctx_t &ctx) const; 112 std::unique_ptr<ref_post_ops_t> ref_post_ops; 113 }; 114 115 } // namespace matmul 116 } // namespace cpu 117 } // namespace impl 118 } // namespace dnnl 119 120 #endif 121