1 /*******************************************************************************
2 * Copyright 2019-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef COMMON_GEMM_TYPES_HPP
18 #define COMMON_GEMM_TYPES_HPP
19 
20 #include <assert.h>
21 
22 #include "oneapi/dnnl/dnnl_types.h"
23 
24 namespace dnnl {
25 namespace impl {
26 
/** Transposition flag for a GEMM input matrix. */
enum transpose_t { dnnl_notrans = 0, dnnl_trans = 1 };

namespace transpose {
// Short aliases so call sites can write transpose::notrans / transpose::trans.
constexpr transpose_t notrans = dnnl_notrans;
constexpr transpose_t trans = dnnl_trans;
} // namespace transpose
33 
/** Placement of the offset applied to matrix C in integer GEMM:
 * a single fixed value, one value per column, or one value per row. */
enum offsetc_t { dnnl_fixed = 0, dnnl_column = 1, dnnl_row = 2 };

namespace offsetc {
// Short aliases so call sites can write offsetc::fixed / column / row.
constexpr offsetc_t fixed = dnnl_fixed;
constexpr offsetc_t column = dnnl_column;
constexpr offsetc_t row = dnnl_row;
} // namespace offsetc
41 
42 /** A descriptor for a matrix multiplication (gemm) operation */
43 struct dnnl_gemm_desc_t {
44     /* To make the interface consistent, the descriptor represent the
45      * GEMM operation in row major */
46 
47     /** The kind of primitive. Used for self identifying the primitive
48      * descriptor. Must be #dnnl_gemm. */
49     dnnl_primitive_kind_t primitive_kind;
50     dnnl_memory_desc_t a_desc;
51     dnnl_memory_desc_t b_desc;
52     dnnl_memory_desc_t c_desc;
53     dnnl_memory_desc_t bias_desc;
54     /** Type for accumulating A*B. */
55     dnnl_data_type_t acc_type;
56 
57     // These accessors are to be used by the GEMM implementation
58     // Because the GEMM implementation currently assumes column major
59     // These accessors return data in column major fashion
60 
is_batcheddnnl::impl::dnnl_gemm_desc_t61     inline bool is_batched() const { return c_desc.ndims >= 3; }
62 
63     // Simplified accessors that comply to GEMM API
get_transdnnl::impl::dnnl_gemm_desc_t64     transpose_t get_trans(dnnl_memory_desc_t md) const {
65         return md.format_desc.blocking.strides[md.ndims - 1] != 1
66                 ? transpose::trans
67                 : transpose::notrans;
68     }
transadnnl::impl::dnnl_gemm_desc_t69     transpose_t transa() const { return get_trans(b_desc); };
transbdnnl::impl::dnnl_gemm_desc_t70     transpose_t transb() const { return get_trans(a_desc); };
batchdnnl::impl::dnnl_gemm_desc_t71     dnnl_dim_t batch() const {
72         // if ndims < 3, it should return 1
73         int64_t batch = 1;
74         for (int i = 0; i < c_desc.ndims - 2; ++i) {
75             if (c_desc.dims[i] == DNNL_RUNTIME_DIM_VAL)
76                 return DNNL_RUNTIME_DIM_VAL;
77             batch *= c_desc.dims[i];
78         }
79         return batch;
80     }
81 
82     /** Number of rows of C. */
mdnnl::impl::dnnl_gemm_desc_t83     dnnl_dim_t m() const { return c_desc.dims[c_desc.ndims - 1]; }
84     /** Number of columns of C. */
ndnnl::impl::dnnl_gemm_desc_t85     dnnl_dim_t n() const { return c_desc.dims[c_desc.ndims - 2]; }
86     /** Size of inner dimension shared between A and B. */
kdnnl::impl::dnnl_gemm_desc_t87     dnnl_dim_t k() const { return a_desc.dims[a_desc.ndims - 1]; }
88 
89     /** Stride between 2 matrices A in a batch. */
stride_adnnl::impl::dnnl_gemm_desc_t90     dnnl_dim_t stride_a(int dim = 0) const {
91         return (dim >= b_desc.ndims - 2 || b_desc.dims[dim] == 1)
92                 ? 0
93                 : b_desc.format_desc.blocking.strides[dim];
94     };
95     /** Stride between 2 matrices B in a batch. */
stride_bdnnl::impl::dnnl_gemm_desc_t96     dnnl_dim_t stride_b(int dim = 0) const {
97         return (dim >= a_desc.ndims - 2 || a_desc.dims[dim] == 1)
98                 ? 0
99                 : a_desc.format_desc.blocking.strides[dim];
100     };
101     /** Stride between 2 matrices C in a batch. */
stride_cdnnl::impl::dnnl_gemm_desc_t102     dnnl_dim_t stride_c(int dim = 0) const {
103         return (dim >= c_desc.ndims - 2)
104                 ? 0
105                 : c_desc.format_desc.blocking.strides[dim];
106     };
107 
108     // This assumes that one of the dimensions has strides 1
get_lddnnl::impl::dnnl_gemm_desc_t109     dnnl_dim_t get_ld(dnnl_memory_desc_t md) const {
110         auto strides = md.format_desc.blocking.strides;
111         assert(strides[md.ndims - 1] == 1 || strides[md.ndims - 2] == 1);
112         return strides[md.ndims - 1] != 1 ? strides[md.ndims - 1]
113                                           : strides[md.ndims - 2];
114     }
115     /** Leading dimension of A. */
ldadnnl::impl::dnnl_gemm_desc_t116     dnnl_dim_t lda() const { return get_ld(b_desc); }
117     /** Leading dimension of B. */
ldbdnnl::impl::dnnl_gemm_desc_t118     dnnl_dim_t ldb() const { return get_ld(a_desc); }
119     /** Leading dimension of C. */
ldcdnnl::impl::dnnl_gemm_desc_t120     dnnl_dim_t ldc() const { return get_ld(c_desc); }
121 
122     /** Type of matrix A. */
a_typednnl::impl::dnnl_gemm_desc_t123     dnnl_data_type_t a_type() const { return b_desc.data_type; }
124     /** Type of matrix B. */
b_typednnl::impl::dnnl_gemm_desc_t125     dnnl_data_type_t b_type() const { return a_desc.data_type; }
126     /** Type of matrix C. */
c_typednnl::impl::dnnl_gemm_desc_t127     dnnl_data_type_t c_type() const { return c_desc.data_type; }
128     /** Type of bias. */
bias_typednnl::impl::dnnl_gemm_desc_t129     dnnl_data_type_t bias_type() const { return bias_desc.data_type; }
130     /** Type of bias. */
bias_maskdnnl::impl::dnnl_gemm_desc_t131     int bias_mask() const {
132         assert(bias_desc.ndims <= 3);
133         int mask = 0;
134         // TODO: update the mask for batched dimension if we start
135         // supporting more batch dimensions
136         if (is_batched()) mask |= (bias_desc.dims[0] > 1) ? 1 << 0 : 0;
137 
138         // because the bias mask is in row major, we have to convert
139         // to col major here by swapping two last dimensions
140         int m_idx = is_batched();
141         mask |= (bias_desc.dims[m_idx] > 1) ? 1 << (bias_desc.ndims - m_idx)
142                                             : 0;
143         mask |= (bias_desc.dims[m_idx + 1] > 1)
144                 ? 1 << (bias_desc.ndims - (m_idx + 1))
145                 : 0;
146         return mask;
147     }
148 };
149 
150 } // namespace impl
151 } // namespace dnnl
152 
153 #endif // COMMON_GEMM_TYPES_HPP
154