1 /*******************************************************************************
2 * Copyright 2019-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef CPU_MATMUL_REF_MATMUL_HPP
18 #define CPU_MATMUL_REF_MATMUL_HPP
19 
20 #include <assert.h>
21 
22 #include "common/bfloat16.hpp"
23 #include "common/c_types_map.hpp"
24 #include "common/primitive.hpp"
25 #include "common/type_helpers.hpp"
26 #include "common/utils.hpp"
27 
28 #include "cpu/platform.hpp"
29 #include "cpu/primitive_attr_postops.hpp"
30 
31 #include "cpu/matmul/cpu_matmul_pd.hpp"
32 
33 namespace dnnl {
34 namespace impl {
35 namespace cpu {
36 namespace matmul {
37 
38 template <impl::data_type_t src_type, impl::data_type_t weights_type = src_type,
39         impl::data_type_t dst_type = src_type,
40         impl::data_type_t acc_type = dst_type>
41 struct ref_matmul_t : public primitive_t {
42     struct pd_t : public cpu_matmul_pd_t {
43         using cpu_matmul_pd_t::cpu_matmul_pd_t;
44 
45         DECLARE_COMMON_PD_T("ref:any", ref_matmul_t);
46 
initdnnl::impl::cpu::matmul::ref_matmul_t::pd_t47         status_t init(engine_t *engine) {
48             using namespace data_type;
49             using smask_t = primitive_attr_t::skip_mask_t;
50 
51             bool ok = src_md()->data_type == src_type
52                     && weights_md()->data_type == weights_type
53                     && desc()->accum_data_type == acc_type
54                     && dst_md()->data_type == dst_type
55                     && platform::has_data_type_support(src_type)
56                     && attr()->has_default_values(smask_t::oscale_runtime
57                             | smask_t::zero_points_runtime | smask_t::post_ops)
58                     && attr_oscale_ok() && attr_zero_points_ok()
59                     && set_default_formats();
60 
61             if (with_bias()) {
62                 auto bia_dt = weights_md(1)->data_type;
63                 if (acc_type == f32)
64                     ok = ok && utils::one_of(bia_dt, f32);
65                 else if (acc_type == s32)
66                     ok = ok && utils::one_of(bia_dt, f32, s32, s8, u8);
67             }
68             return ok ? status::success : status::unimplemented;
69         }
70 
71     private:
attr_oscale_okdnnl::impl::cpu::matmul::ref_matmul_t::pd_t72         bool attr_oscale_ok() const {
73             const auto &oscale = attr()->output_scales_;
74             return oscale.mask_ == 0 || oscale.mask_ == (1 << (batched() + 1));
75         }
76 
attr_zero_points_okdnnl::impl::cpu::matmul::ref_matmul_t::pd_t77         bool attr_zero_points_ok() const {
78             int mask_src = 0, mask_wei = 0, mask_dst = 0;
79             attr()->zero_points_.get(DNNL_ARG_SRC, nullptr, &mask_src, nullptr);
80             attr()->zero_points_.get(
81                     DNNL_ARG_WEIGHTS, nullptr, &mask_wei, nullptr);
82             attr()->zero_points_.get(DNNL_ARG_DST, nullptr, &mask_dst, nullptr);
83 
84             return IMPLICATION(acc_type != data_type::s32,
85                            attr()->zero_points_.has_default_values())
86                     && (mask_src == 0 || mask_src == 1 << 1) && (mask_wei == 0)
87                     && (mask_dst == 0 || mask_dst == 1 << 1);
88         }
89     };
90 
ref_matmul_tdnnl::impl::cpu::matmul::ref_matmul_t91     ref_matmul_t(const pd_t *apd) : primitive_t(apd) {}
92 
initdnnl::impl::cpu::matmul::ref_matmul_t93     status_t init(engine_t *engine) override {
94         ref_post_ops
95                 = utils::make_unique<ref_post_ops_t>(pd()->attr()->post_ops_);
96         if (!ref_post_ops) return status::out_of_memory;
97         return status::success;
98     }
99 
100     typedef typename prec_traits<src_type>::type src_data_t;
101     typedef typename prec_traits<weights_type>::type weights_data_t;
102     typedef typename prec_traits<dst_type>::type dst_data_t;
103     typedef typename prec_traits<acc_type>::type acc_data_t;
104 
executednnl::impl::cpu::matmul::ref_matmul_t105     status_t execute(const exec_ctx_t &ctx) const override {
106         return execute_ref(ctx);
107     }
108 
109 private:
pddnnl::impl::cpu::matmul::ref_matmul_t110     const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
111     status_t execute_ref(const exec_ctx_t &ctx) const;
112     std::unique_ptr<ref_post_ops_t> ref_post_ops;
113 };
114 
115 } // namespace matmul
116 } // namespace cpu
117 } // namespace impl
118 } // namespace dnnl
119 
120 #endif
121