/*******************************************************************************
* Copyright 2019-2020 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/ocl/ocl_post_ops.h"
#include "gpu/ocl/ocl_types.h"

/* Compute offset strides for a (batch x M x N) auxiliary tensor (bias or
 * zero-point buffer) that may be broadcast along any dimension.
 *
 * mask: one bit per logical dimension -- bit 0 = batch, bit 1 = M, bit 2 = N.
 *       A set bit means the tensor spans the full extent of that dimension;
 *       a clear bit means it is broadcast (treated as size 1, stride 0).
 *       The batch bit is honored only when dim0 > 1 (true batched case).
 * dim0/dim1/dim2: full sizes of the batch, M, and N dimensions.
 * str0/str1/str2: out-params; element offsets are then formed as
 *       mb * str0 + m * str1 + n * str2.
 *
 * Within one batch slice the M x N plane is laid out column-major
 * (*str1 == 1, *str2 == dims[1], i.e. m is the contiguous index),
 * matching the BLAS-style column-major convention used by ref_gemm below. */
void get_strides(int mask, long dim0, long dim1, long dim2, long *str0,
        long *str1, long *str2) {
    int is_3d = dim0 > 1;
    long dims[3];
    /* Collapse broadcast dimensions to size 1 so their stride becomes 0. */
    dims[0] = (is_3d && mask & (1 << 0)) ? dim0 : 1;
    dims[1] = mask & (1 << 1) ? dim1 : 1;
    dims[2] = mask & (1 << 2) ? dim2 : 1;

    *str0 = dims[0] == 1 ? 0 : dims[1] * dims[2]; /* batch stride = plane size */
    *str1 = dims[1] == 1 ? 0 : 1; /* m is contiguous */
    *str2 = dims[2] == 1 ? 0 : dims[1]; /* n steps over whole columns */
}

/* Reference (naive, O(M*N*K) per batch) GEMM kernel:
 *
 *     C = scale * (op(A) - a0) * (op(B) - b0) [+ bias] [+ beta*C] [eltwise] + c0
 *
 * with op() = optional transpose and column-major storage for A, B and C
 * (the non-transposed element A(m,k) is read at a[k * lda + m], and C(m,n)
 * is written at c[n * ldc + m]).
 *
 * Work decomposition: one work-item per (n, mb) pair -- global id 1 is the
 * output column n, global id 2 is the batch index mb -- and each work-item
 * loops over all M rows of its column.
 *
 * Parameters:
 *   a, b, c          - matrix buffers; element types come from the build-time
 *                      macros A_DATA_T / B_DATA_T / C_DATA_T.
 *   bias             - per-element bias, only read when WITH_BIAS is set;
 *                      broadcast layout described by bias_mask (see
 *                      get_strides above).
 *   offset_*0        - element offsets of the start of each buffer.
 *   transa, transb   - nonzero selects the transposed access pattern.
 *   MB, M, N, K      - batch count and GEMM sizes.
 *   stride_a/b/c     - batch strides (elements between consecutive matrices).
 *   lda, ldb, ldc    - leading dimensions (column strides).
 *   eltwise_*        - parameters forwarded to fwd_eltwise when WITH_ELTWISE.
 *   a0, b0           - A/B zero points; only element [0] is used, i.e. a
 *                      single common zero point per matrix.
 *   c0, c0_mask      - C zero points with broadcast mask (per-element layout
 *                      via get_strides).
 *   scales           - output scales, indexed as scales[scale_stride * n]
 *                      (scale_stride == 0 gives a single common scale).
 *   beta             - sum post-op coefficient, used when WITH_SUM.
 *
 * NOTE(review): a0[0] and b0[0] are subtracted unconditionally, and c0 is
 * added whenever WITH_BIAS || NON_DEFAULT_ATTRS -- presumably the host side
 * passes zero-filled buffers when the corresponding zero-point attributes
 * are not set; confirm against the caller. */
__kernel void ref_gemm(__global A_DATA_T *a, __global B_DATA_T *b,
        __global C_DATA_T *c, __global BIA_DATA_T *bias, long offset_a0,
        long offset_b0, long offset_c0, long offset_bias0, int transa,
        int transb, long MB, long M, long N, long K, long stride_a,
        long stride_b, long stride_c, long lda, long ldb, long ldc,
        float eltwise_alpha, float eltwise_beta, float eltwise_scale,
        int bias_mask, __global int *a0, __global int *b0, __global int *c0,
        int c0_mask, __global float *scales, long scale_stride, float beta) {

    int n = get_global_id(1); /* output column */
    int mb = get_global_id(2); /* batch index */

#if WITH_BIAS
    bias += offset_bias0;

    /* Strides for the (possibly broadcast) bias tensor. */
    long b_strides[3];
    get_strides(
            bias_mask, MB, M, N, &b_strides[0], &b_strides[1], &b_strides[2]);
#endif

    a += offset_a0;
    b += offset_b0;
    c += offset_c0;

    /* Strides for the (possibly broadcast) C zero-point tensor. */
    long c0_strides[3];
    get_strides(
            c0_mask, MB, M, N, &c0_strides[0], &c0_strides[1], &c0_strides[2]);

    for (long m = 0; m < M; ++m) {
        /* Accumulate the dot product in the wider ACC_DATA_T type, applying
         * the common A/B zero points to each operand before multiplying. */
        ACC_DATA_T acc = 0;
        for (long k = 0; k < K; ++k) {
            long off_a = mb * stride_a + (transa ? m * lda + k : k * lda + m);
            long off_b = mb * stride_b + (transb ? k * ldb + n : n * ldb + k);
            acc += TO_ACC(A_TO_REF(a[off_a]) - a0[0])
                    * TO_ACC(B_TO_REF(b[off_b]) - b0[0]);
        }

        long off_c = mb * stride_c + n * ldc + m;
#if WITH_BIAS || NON_DEFAULT_ATTRS
        /* Post-op pipeline, in order: bias -> scale -> sum -> eltwise -> c0. */
        POST_OP_DATA_T temp = (POST_OP_DATA_T)acc;
#if WITH_BIAS
        long off_bias = mb * b_strides[0] + m * b_strides[1] + n * b_strides[2];
        temp += BIA_TO_REF(bias[off_bias]);
#endif
        temp *= scales[scale_stride * n];
#if WITH_SUM
        /* Sum post-op: blend in the existing C value scaled by beta. */
        temp += (POST_OP_DATA_T)(beta * C_TO_REF(c[off_c]));
#endif
#if WITH_ELTWISE
        temp = fwd_eltwise(temp, eltwise_alpha, eltwise_beta, eltwise_scale);
#endif
        /* C zero point is applied last, after all other post-ops. */
        long off_c0
                = mb * c0_strides[0] + m * c0_strides[1] + n * c0_strides[2];
        temp += c0[off_c0];
        c[off_c] = TO_C(temp);
#else
        /* Fast path: no bias and no attributes -- plain converted store. */
        c[off_c] = TO_C(acc);
#endif
    }
}