1 /*******************************************************************************
2 * Copyright 2019-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef COMMON_GEMM_UTILS_HPP
18 #define COMMON_GEMM_UTILS_HPP
19
20 #include "oneapi/dnnl/dnnl.h"
21
22 #include "common/c_types_map.hpp"
23 #include "common/nstl.hpp"
24 #include "common/primitive_iterator.hpp"
25 #include "common/utils.hpp"
26
27 namespace dnnl {
28 namespace impl {
29
check_gemm_input(char transa,char transb,int m,int n,int k,int lda,int ldb,int ldc,float alpha,float beta)30 static inline status_t check_gemm_input(char transa, char transb, int m, int n,
31 int k, int lda, int ldb, int ldc, float alpha, float beta) {
32 using namespace status;
33 bool consistency = true && utils::one_of(transa, 'T', 't', 'N', 'n')
34 && utils::one_of(transb, 'T', 't', 'N', 'n') && m >= 0 && n >= 0
35 && k >= 0;
36 if (!consistency) return invalid_arguments;
37 bool isTransA = utils::one_of(transa, 'T', 't');
38 bool isTransB = utils::one_of(transb, 'T', 't');
39 int nrowA = isTransA ? k : m;
40 int nrowB = isTransB ? n : k;
41 consistency = true && lda >= nstl::max(1, nrowA)
42 && ldb >= nstl::max(1, nrowB) && ldc >= nstl::max(1, m);
43 if (!consistency) return invalid_arguments;
44
45 return success;
46 }
47
check_gemm_x8x8s32_input(char offsetc,char transa,char transb,int m,int n,int k,int lda,int ldb,int ldc,float alpha,float beta)48 static inline status_t check_gemm_x8x8s32_input(char offsetc, char transa,
49 char transb, int m, int n, int k, int lda, int ldb, int ldc,
50 float alpha, float beta) {
51 using namespace status;
52 if (!utils::one_of(offsetc, 'F', 'f', 'C', 'c', 'R', 'r'))
53 return invalid_arguments;
54 return check_gemm_input(
55 transa, transb, m, n, k, lda, ldb, ldc, alpha, beta);
56 }
57
// This function makes a 2d tensor from an nd tensor.
// The 2d tensor just collapses dims[1...ndims-1] from the nd tensor.
// The only reason we do not use reshape here is that we want to allow
// fusing blocked dimensions and padded dimensions.
init_2d_desc(memory_desc_t * md_2d,const memory_desc_t * md_nd,bool transpose_dims=false)62 static inline void init_2d_desc(memory_desc_t *md_2d,
63 const memory_desc_t *md_nd, bool transpose_dims = false) {
64 auto p_dims = md_nd->padded_dims;
65 auto blk = md_nd->format_desc.blocking;
66 auto strides = blk.strides;
67
68 // we assume that the innermost dimension always has stride 1
69 assert(IMPLICATION(blk.inner_nblks == 0,
70 utils::array_min(strides, md_nd->ndims) == 1));
71
72 // TODO: add checks to see if the memory descriptor can be 2d-fied
73 // TODO: change signature to specifiy at which dimension shall we 2d-fy (currently 1st)
74 auto p_dim1 = utils::array_product(p_dims + 1, md_nd->ndims - 1);
75 auto stride1 = blk.inner_nblks == 0
76 ? utils::array_min(strides + 1, md_nd->ndims - 1)
77 : 1;
78
79 if (transpose_dims) {
80 dnnl_dims_t dims_2d = {p_dim1, p_dims[0]};
81 dnnl_dims_t strides_2d = {stride1, strides[0]};
82 dnnl_memory_desc_init_by_strides(
83 md_2d, 2, dims_2d, md_nd->data_type, strides_2d);
84 } else {
85 dnnl_dims_t dims_2d = {p_dims[0], p_dim1};
86 dnnl_dims_t strides_2d = {strides[0], stride1};
87 dnnl_memory_desc_init_by_strides(
88 md_2d, 2, dims_2d, md_nd->data_type, strides_2d);
89 }
90 }
91
create_2d_desc(memory_desc_t * md_2d,int d0,int d1,data_type_t dt,transpose_t trans,int ld)92 static inline void create_2d_desc(memory_desc_t *md_2d, int d0, int d1,
93 data_type_t dt, transpose_t trans, int ld) {
94 dnnl_dims_t dims_2d = {d0, d1};
95 if (trans == transpose::notrans) {
96 dnnl_dims_t strides_2d = {ld, 1};
97 dnnl_memory_desc_init_by_strides(md_2d, 2, dims_2d, dt, strides_2d);
98 } else {
99 dnnl_dims_t strides_2d = {1, ld};
100 dnnl_memory_desc_init_by_strides(md_2d, 2, dims_2d, dt, strides_2d);
101 }
102 }
103
create_gemm_pd(std::shared_ptr<primitive_desc_t> & gemm_pd_,engine_t * engine,const memory_desc_t * a_md,const memory_desc_t * b_md,const memory_desc_t * c_md,const memory_desc_t * bias_md,data_type_t acc_dt,const primitive_attr_t * attr,bool skip_ref=false)104 static inline status_t create_gemm_pd(
105 std::shared_ptr<primitive_desc_t> &gemm_pd_, engine_t *engine,
106 const memory_desc_t *a_md, const memory_desc_t *b_md,
107 const memory_desc_t *c_md, const memory_desc_t *bias_md,
108 data_type_t acc_dt, const primitive_attr_t *attr,
109 bool skip_ref = false) {
110 auto gemm_desc = gemm_desc_t();
111 gemm_desc.primitive_kind = primitive_kind::gemm;
112 gemm_desc.a_desc = *a_md;
113 gemm_desc.b_desc = *b_md;
114 gemm_desc.c_desc = *c_md;
115 gemm_desc.bias_desc = *bias_md;
116 gemm_desc.acc_type = acc_dt;
117
118 primitive_attr_t gemm_attr = *attr;
119
120 dnnl_primitive_desc_iterator it(
121 engine, (op_desc_t *)&gemm_desc, &gemm_attr, nullptr);
122
123 gemm_pd_ = *(++it);
124 if (!gemm_pd_) return status::unimplemented;
125 if (skip_ref && strstr(gemm_pd_.get()->name(), "ref") != NULL)
126 return status::unimplemented;
127
128 return status::success;
129 }
130
131 } // namespace impl
132 } // namespace dnnl
133
134 #endif
135