1/*******************************************************************************
2* Copyright 2019-2020 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8*     http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/ocl/ocl_post_ops.h"
18#include "gpu/ocl/ocl_types.h"
19
/* Derive row strides for a 3D tensor that may be broadcast along any axis.
 *
 * `mask` selects which logical dimensions are materialized: bit 0 -> dim0
 * (batch), bit 1 -> dim1, bit 2 -> dim2.  A dimension that is masked out
 * (or a batch dimension of size 1) gets stride 0 so the same element is
 * reused across that axis (broadcasting).
 *
 * The materialized layout is: innermost stride 1 along dim1, stride dim1
 * along dim2, and stride dim1*dim2 along dim0.
 */
void get_strides(int mask, long dim0, long dim1, long dim2, long *str0,
        long *str1, long *str2) {
    const int batched = dim0 > 1;
    const long d0 = (batched && (mask & (1 << 0))) ? dim0 : 1;
    const long d1 = (mask & (1 << 1)) ? dim1 : 1;
    const long d2 = (mask & (1 << 2)) ? dim2 : 1;

    /* Collapsed (size-1) axes broadcast via a zero stride. */
    *str0 = (d0 == 1) ? 0 : d1 * d2;
    *str1 = (d1 == 1) ? 0 : 1;
    *str2 = (d2 == 1) ? 0 : d1;
}
32
/* Reference (naive) batched GEMM kernel with optional zero points, scales,
 * bias, sum, and eltwise post-ops.
 *
 * Work decomposition: global id 1 is the column index n, global id 2 is the
 * batch index mb; each work-item computes one full column of C by looping
 * over all rows m and reducing over K sequentially.
 *
 * Parameters (as used below):
 *   offset_*0           - element offsets applied to the base pointers.
 *   transa/transb       - select which of the two lda/ldb indexing forms
 *                         is used per matrix (see off_a/off_b).
 *   stride_a/b/c        - per-batch element strides.
 *   bias_mask, c0_mask  - broadcast masks decoded by get_strides() into
 *                         (MB, M, N) strides for bias and c0.
 *   a0, b0              - zero points; only element [0] is read, i.e. a
 *                         single common zero point per tensor.
 *   scales, scale_stride - output scales indexed by scale_stride * n:
 *                         stride 0 gives one common scale, nonzero gives
 *                         per-n scales (presumably per-output-channel --
 *                         confirm against the host-side dispatcher).
 *   beta                - weight of the existing C values for the sum
 *                         post-op (WITH_SUM).
 */
__kernel void ref_gemm(__global A_DATA_T *a, __global B_DATA_T *b,
        __global C_DATA_T *c, __global BIA_DATA_T *bias, long offset_a0,
        long offset_b0, long offset_c0, long offset_bias0, int transa,
        int transb, long MB, long M, long N, long K, long stride_a,
        long stride_b, long stride_c, long lda, long ldb, long ldc,
        float eltwise_alpha, float eltwise_beta, float eltwise_scale,
        int bias_mask, __global int *a0, __global int *b0, __global int *c0,
        int c0_mask, __global float *scales, long scale_stride, float beta) {

    int n = get_global_id(1);
    int mb = get_global_id(2);

#if WITH_BIAS
    bias += offset_bias0;

    /* Decode the bias broadcast mask into (MB, M, N) strides. */
    long b_strides[3];
    get_strides(
            bias_mask, MB, M, N, &b_strides[0], &b_strides[1], &b_strides[2]);
#endif

    a += offset_a0;
    b += offset_b0;
    c += offset_c0;

    /* Decode the output zero-point broadcast mask into (MB, M, N) strides. */
    long c0_strides[3];
    get_strides(
            c0_mask, MB, M, N, &c0_strides[0], &c0_strides[1], &c0_strides[2]);

    for (long m = 0; m < M; ++m) {
        ACC_DATA_T acc = 0;
        for (long k = 0; k < K; ++k) {
            /* Element (m, k) of A and (k, n) of B; the trans flags pick
             * which index is multiplied by the leading dimension. */
            long off_a = mb * stride_a + (transa ? m * lda + k : k * lda + m);
            long off_b = mb * stride_b + (transb ? k * ldb + n : n * ldb + k);
            /* Subtract the common zero points before multiplying so the
             * accumulation is done on the compensated values. */
            acc += TO_ACC(A_TO_REF(a[off_a]) - a0[0])
                    * TO_ACC(B_TO_REF(b[off_b]) - b0[0]);
        }

        long off_c = mb * stride_c + n * ldc + m;
#if WITH_BIAS || NON_DEFAULT_ATTRS
        /* Post-op chain runs in POST_OP_DATA_T precision:
         * bias -> scale -> sum (beta * old C) -> eltwise -> output zero
         * point -> convert/store.  The order below is fixed; do not
         * reorder these statements. */
        POST_OP_DATA_T temp = (POST_OP_DATA_T)acc;
#if WITH_BIAS
        long off_bias = mb * b_strides[0] + m * b_strides[1] + n * b_strides[2];
        temp += BIA_TO_REF(bias[off_bias]);
#endif
        temp *= scales[scale_stride * n];
#if WITH_SUM
        temp += (POST_OP_DATA_T)(beta * C_TO_REF(c[off_c]));
#endif
#if WITH_ELTWISE
        temp = fwd_eltwise(temp, eltwise_alpha, eltwise_beta, eltwise_scale);
#endif
        /* Output zero point is added last, after eltwise, so it is not
         * affected by the activation function. */
        long off_c0
                = mb * c0_strides[0] + m * c0_strides[1] + n * c0_strides[2];
        temp += c0[off_c0];
        c[off_c] = TO_C(temp);
#else
        /* No attributes: convert the raw accumulator directly. */
        c[off_c] = TO_C(acc);
#endif
    }
}
93