1 /*******************************************************************************
2 * Copyright 2019-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef COMMON_GEMM_UTILS_HPP
18 #define COMMON_GEMM_UTILS_HPP
19 
20 #include "oneapi/dnnl/dnnl.h"
21 
22 #include "common/c_types_map.hpp"
23 #include "common/nstl.hpp"
24 #include "common/primitive_iterator.hpp"
25 #include "common/utils.hpp"
26 
27 namespace dnnl {
28 namespace impl {
29 
check_gemm_input(char transa,char transb,int m,int n,int k,int lda,int ldb,int ldc,float alpha,float beta)30 static inline status_t check_gemm_input(char transa, char transb, int m, int n,
31         int k, int lda, int ldb, int ldc, float alpha, float beta) {
32     using namespace status;
33     bool consistency = true && utils::one_of(transa, 'T', 't', 'N', 'n')
34             && utils::one_of(transb, 'T', 't', 'N', 'n') && m >= 0 && n >= 0
35             && k >= 0;
36     if (!consistency) return invalid_arguments;
37     bool isTransA = utils::one_of(transa, 'T', 't');
38     bool isTransB = utils::one_of(transb, 'T', 't');
39     int nrowA = isTransA ? k : m;
40     int nrowB = isTransB ? n : k;
41     consistency = true && lda >= nstl::max(1, nrowA)
42             && ldb >= nstl::max(1, nrowB) && ldc >= nstl::max(1, m);
43     if (!consistency) return invalid_arguments;
44 
45     return success;
46 }
47 
check_gemm_x8x8s32_input(char offsetc,char transa,char transb,int m,int n,int k,int lda,int ldb,int ldc,float alpha,float beta)48 static inline status_t check_gemm_x8x8s32_input(char offsetc, char transa,
49         char transb, int m, int n, int k, int lda, int ldb, int ldc,
50         float alpha, float beta) {
51     using namespace status;
52     if (!utils::one_of(offsetc, 'F', 'f', 'C', 'c', 'R', 'r'))
53         return invalid_arguments;
54     return check_gemm_input(
55             transa, transb, m, n, k, lda, ldb, ldc, alpha, beta);
56 }
57 
// This function makes a 2d tensor from an nd tensor.
// The 2d tensor just collapses dims[1...ndims-1] from the nd tensor.
// The only reason we do not use reshape here is that we want to allow
// fusing blocked dimensions and padded dimensions.
init_2d_desc(memory_desc_t * md_2d,const memory_desc_t * md_nd,bool transpose_dims=false)62 static inline void init_2d_desc(memory_desc_t *md_2d,
63         const memory_desc_t *md_nd, bool transpose_dims = false) {
64     auto p_dims = md_nd->padded_dims;
65     auto blk = md_nd->format_desc.blocking;
66     auto strides = blk.strides;
67 
68     // we assume that the innermost dimension always has stride 1
69     assert(IMPLICATION(blk.inner_nblks == 0,
70             utils::array_min(strides, md_nd->ndims) == 1));
71 
72     // TODO: add checks to see if the memory descriptor can be 2d-fied
73     // TODO: change signature to specifiy at which dimension shall we 2d-fy (currently 1st)
74     auto p_dim1 = utils::array_product(p_dims + 1, md_nd->ndims - 1);
75     auto stride1 = blk.inner_nblks == 0
76             ? utils::array_min(strides + 1, md_nd->ndims - 1)
77             : 1;
78 
79     if (transpose_dims) {
80         dnnl_dims_t dims_2d = {p_dim1, p_dims[0]};
81         dnnl_dims_t strides_2d = {stride1, strides[0]};
82         dnnl_memory_desc_init_by_strides(
83                 md_2d, 2, dims_2d, md_nd->data_type, strides_2d);
84     } else {
85         dnnl_dims_t dims_2d = {p_dims[0], p_dim1};
86         dnnl_dims_t strides_2d = {strides[0], stride1};
87         dnnl_memory_desc_init_by_strides(
88                 md_2d, 2, dims_2d, md_nd->data_type, strides_2d);
89     }
90 }
91 
create_2d_desc(memory_desc_t * md_2d,int d0,int d1,data_type_t dt,transpose_t trans,int ld)92 static inline void create_2d_desc(memory_desc_t *md_2d, int d0, int d1,
93         data_type_t dt, transpose_t trans, int ld) {
94     dnnl_dims_t dims_2d = {d0, d1};
95     if (trans == transpose::notrans) {
96         dnnl_dims_t strides_2d = {ld, 1};
97         dnnl_memory_desc_init_by_strides(md_2d, 2, dims_2d, dt, strides_2d);
98     } else {
99         dnnl_dims_t strides_2d = {1, ld};
100         dnnl_memory_desc_init_by_strides(md_2d, 2, dims_2d, dt, strides_2d);
101     }
102 }
103 
create_gemm_pd(std::shared_ptr<primitive_desc_t> & gemm_pd_,engine_t * engine,const memory_desc_t * a_md,const memory_desc_t * b_md,const memory_desc_t * c_md,const memory_desc_t * bias_md,data_type_t acc_dt,const primitive_attr_t * attr,bool skip_ref=false)104 static inline status_t create_gemm_pd(
105         std::shared_ptr<primitive_desc_t> &gemm_pd_, engine_t *engine,
106         const memory_desc_t *a_md, const memory_desc_t *b_md,
107         const memory_desc_t *c_md, const memory_desc_t *bias_md,
108         data_type_t acc_dt, const primitive_attr_t *attr,
109         bool skip_ref = false) {
110     auto gemm_desc = gemm_desc_t();
111     gemm_desc.primitive_kind = primitive_kind::gemm;
112     gemm_desc.a_desc = *a_md;
113     gemm_desc.b_desc = *b_md;
114     gemm_desc.c_desc = *c_md;
115     gemm_desc.bias_desc = *bias_md;
116     gemm_desc.acc_type = acc_dt;
117 
118     primitive_attr_t gemm_attr = *attr;
119 
120     dnnl_primitive_desc_iterator it(
121             engine, (op_desc_t *)&gemm_desc, &gemm_attr, nullptr);
122 
123     gemm_pd_ = *(++it);
124     if (!gemm_pd_) return status::unimplemented;
125     if (skip_ref && strstr(gemm_pd_.get()->name(), "ref") != NULL)
126         return status::unimplemented;
127 
128     return status::success;
129 }
130 
131 } // namespace impl
132 } // namespace dnnl
133 
134 #endif
135