1 /*******************************************************************************
2 * Copyright 2020 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "gpu/ocl/gemm/ref_gemm.hpp"
18
19 namespace dnnl {
20 namespace impl {
21 namespace gpu {
22 namespace ocl {
23
execute(const gemm_exec_ctx_t & ctx) const24 status_t ref_gemm_t::execute(const gemm_exec_ctx_t &ctx) const {
25 const auto &a = GEMM_CTX_ARG_STORAGE(b);
26 const auto &b = GEMM_CTX_ARG_STORAGE(a);
27 const auto &bias = GEMM_CTX_ARG_STORAGE(bias);
28 auto &c = GEMM_CTX_ARG_STORAGE(c);
29
30 const auto exec_d = ctx.desc() ? ctx.desc() : pd()->desc();
31
32 dim_t off_a0 = a.offset() / types::data_type_size(exec_d->a_type());
33 dim_t off_b0 = b.offset() / types::data_type_size(exec_d->b_type());
34 dim_t off_c0 = c.offset() / types::data_type_size(exec_d->c_type());
35 dim_t off_bias0 = pd()->with_bias()
36 ? bias.offset() / types::data_type_size(exec_d->bias_type())
37 : 0;
38
39 const memory_storage_t *scales = !pd()->attr()->output_scales_.defined()
40 ? &GEMM_CTX_ARG_STORAGE(output_scales)
41 : &CTX_GPU_RES_STORAGE(SCALES_);
42 const memory_storage_t *a0 = !pd()->attr()->zero_points_.defined(DNNL_ARG_A)
43 ? &GEMM_CTX_ARG_STORAGE(a_zero_point)
44 : &CTX_GPU_RES_STORAGE(A0_);
45
46 const memory_storage_t *b0 = !pd()->attr()->zero_points_.defined(DNNL_ARG_B)
47 ? &GEMM_CTX_ARG_STORAGE(b_zero_point)
48 : &CTX_GPU_RES_STORAGE(B0_);
49
50 const memory_storage_t *c0 = !pd()->attr()->zero_points_.defined(DNNL_ARG_C)
51 ? &GEMM_CTX_ARG_STORAGE(c_zero_point)
52 : &CTX_GPU_RES_STORAGE(C0_);
53
54 int c0_mask = 0;
55 pd()->attr()->zero_points_.get(DNNL_ARG_C, nullptr, &c0_mask, nullptr);
56
57 const dim_t MB = exec_d->batch();
58 const dim_t M = exec_d->m();
59 const dim_t N = exec_d->n();
60 const dim_t K = exec_d->k();
61 const dim_t stride_a = exec_d->stride_a();
62 const dim_t stride_b = exec_d->stride_b();
63 const dim_t stride_c = exec_d->stride_c();
64 const dim_t lda = exec_d->lda();
65 const dim_t ldb = exec_d->ldb();
66 const dim_t ldc = exec_d->ldc();
67
68 const dim_t scale_stride = pd()->attr()->output_scales_.mask_ == 0 ? 0 : 1;
69 const float eltwise_alpha = pd()->attr_info.eltwise_alpha;
70 const float eltwise_beta = pd()->attr_info.eltwise_beta;
71 const float eltwise_scale = pd()->attr_info.eltwise_scale;
72 const int bias_mask = exec_d->bias_mask();
73 const float beta = pd()->attr_info.sum_scale;
74
75 const int tra = exec_d->transa() == transpose::trans;
76 const int trb = exec_d->transb() == transpose::trans;
77
78 compute::kernel_arg_list_t arg_list;
79 arg_list.set(0, a);
80 arg_list.set(1, b);
81 arg_list.set(2, c);
82 arg_list.set(3, bias);
83 arg_list.set(4, off_a0);
84 arg_list.set(5, off_b0);
85 arg_list.set(6, off_c0);
86 arg_list.set(7, off_bias0);
87 arg_list.set(8, tra);
88 arg_list.set(9, trb);
89 arg_list.set(10, MB);
90 arg_list.set(11, M);
91 arg_list.set(12, N);
92 arg_list.set(13, K);
93 arg_list.set(14, stride_a);
94 arg_list.set(15, stride_b);
95 arg_list.set(16, stride_c);
96 arg_list.set(17, lda);
97 arg_list.set(18, ldb);
98 arg_list.set(19, ldc);
99 arg_list.set(20, eltwise_alpha);
100 arg_list.set(21, eltwise_beta);
101 arg_list.set(22, eltwise_scale);
102 arg_list.set(23, bias_mask);
103 arg_list.set(24, *a0);
104 arg_list.set(25, *b0);
105 arg_list.set(26, *c0);
106 arg_list.set(27, c0_mask);
107 arg_list.set(28, *scales);
108 arg_list.set(29, scale_stride);
109 arg_list.set(30, beta);
110
111 const size_t gws[3] = {1, (size_t)N, (size_t)MB};
112 const auto nd_range = compute::nd_range_t(gws);
113
114 status_t status = parallel_for(ctx, nd_range, kernel_, arg_list);
115
116 return status;
117 }
118
119 } // namespace ocl
120 } // namespace gpu
121 } // namespace impl
122 } // namespace dnnl
123