1 /*******************************************************************************
2 * Copyright 2020 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include "gpu/ocl/gemm/ref_gemm.hpp"
18 
19 namespace dnnl {
20 namespace impl {
21 namespace gpu {
22 namespace ocl {
23 
execute(const gemm_exec_ctx_t & ctx) const24 status_t ref_gemm_t::execute(const gemm_exec_ctx_t &ctx) const {
25     const auto &a = GEMM_CTX_ARG_STORAGE(b);
26     const auto &b = GEMM_CTX_ARG_STORAGE(a);
27     const auto &bias = GEMM_CTX_ARG_STORAGE(bias);
28     auto &c = GEMM_CTX_ARG_STORAGE(c);
29 
30     const auto exec_d = ctx.desc() ? ctx.desc() : pd()->desc();
31 
32     dim_t off_a0 = a.offset() / types::data_type_size(exec_d->a_type());
33     dim_t off_b0 = b.offset() / types::data_type_size(exec_d->b_type());
34     dim_t off_c0 = c.offset() / types::data_type_size(exec_d->c_type());
35     dim_t off_bias0 = pd()->with_bias()
36             ? bias.offset() / types::data_type_size(exec_d->bias_type())
37             : 0;
38 
39     const memory_storage_t *scales = !pd()->attr()->output_scales_.defined()
40             ? &GEMM_CTX_ARG_STORAGE(output_scales)
41             : &CTX_GPU_RES_STORAGE(SCALES_);
42     const memory_storage_t *a0 = !pd()->attr()->zero_points_.defined(DNNL_ARG_A)
43             ? &GEMM_CTX_ARG_STORAGE(a_zero_point)
44             : &CTX_GPU_RES_STORAGE(A0_);
45 
46     const memory_storage_t *b0 = !pd()->attr()->zero_points_.defined(DNNL_ARG_B)
47             ? &GEMM_CTX_ARG_STORAGE(b_zero_point)
48             : &CTX_GPU_RES_STORAGE(B0_);
49 
50     const memory_storage_t *c0 = !pd()->attr()->zero_points_.defined(DNNL_ARG_C)
51             ? &GEMM_CTX_ARG_STORAGE(c_zero_point)
52             : &CTX_GPU_RES_STORAGE(C0_);
53 
54     int c0_mask = 0;
55     pd()->attr()->zero_points_.get(DNNL_ARG_C, nullptr, &c0_mask, nullptr);
56 
57     const dim_t MB = exec_d->batch();
58     const dim_t M = exec_d->m();
59     const dim_t N = exec_d->n();
60     const dim_t K = exec_d->k();
61     const dim_t stride_a = exec_d->stride_a();
62     const dim_t stride_b = exec_d->stride_b();
63     const dim_t stride_c = exec_d->stride_c();
64     const dim_t lda = exec_d->lda();
65     const dim_t ldb = exec_d->ldb();
66     const dim_t ldc = exec_d->ldc();
67 
68     const dim_t scale_stride = pd()->attr()->output_scales_.mask_ == 0 ? 0 : 1;
69     const float eltwise_alpha = pd()->attr_info.eltwise_alpha;
70     const float eltwise_beta = pd()->attr_info.eltwise_beta;
71     const float eltwise_scale = pd()->attr_info.eltwise_scale;
72     const int bias_mask = exec_d->bias_mask();
73     const float beta = pd()->attr_info.sum_scale;
74 
75     const int tra = exec_d->transa() == transpose::trans;
76     const int trb = exec_d->transb() == transpose::trans;
77 
78     compute::kernel_arg_list_t arg_list;
79     arg_list.set(0, a);
80     arg_list.set(1, b);
81     arg_list.set(2, c);
82     arg_list.set(3, bias);
83     arg_list.set(4, off_a0);
84     arg_list.set(5, off_b0);
85     arg_list.set(6, off_c0);
86     arg_list.set(7, off_bias0);
87     arg_list.set(8, tra);
88     arg_list.set(9, trb);
89     arg_list.set(10, MB);
90     arg_list.set(11, M);
91     arg_list.set(12, N);
92     arg_list.set(13, K);
93     arg_list.set(14, stride_a);
94     arg_list.set(15, stride_b);
95     arg_list.set(16, stride_c);
96     arg_list.set(17, lda);
97     arg_list.set(18, ldb);
98     arg_list.set(19, ldc);
99     arg_list.set(20, eltwise_alpha);
100     arg_list.set(21, eltwise_beta);
101     arg_list.set(22, eltwise_scale);
102     arg_list.set(23, bias_mask);
103     arg_list.set(24, *a0);
104     arg_list.set(25, *b0);
105     arg_list.set(26, *c0);
106     arg_list.set(27, c0_mask);
107     arg_list.set(28, *scales);
108     arg_list.set(29, scale_stride);
109     arg_list.set(30, beta);
110 
111     const size_t gws[3] = {1, (size_t)N, (size_t)MB};
112     const auto nd_range = compute::nd_range_t(gws);
113 
114     status_t status = parallel_for(ctx, nd_range, kernel_, arg_list);
115 
116     return status;
117 }
118 
119 } // namespace ocl
120 } // namespace gpu
121 } // namespace impl
122 } // namespace dnnl
123