1 /*******************************************************************************
2 * Copyright 2020-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include "gpu/ocl/gemm_matmul.hpp"
18 
19 #include "gpu/gemm/gpu_gemm_utils.hpp"
20 
21 namespace dnnl {
22 namespace impl {
23 namespace gpu {
24 namespace ocl {
25 
execute(const exec_ctx_t & ctx) const26 status_t gemm_matmul_t::execute(const exec_ctx_t &ctx) const {
27     using namespace memory_tracking::names;
28     using namespace gemm_utils;
29 
30     const auto src_d = ctx.memory_mdw(DNNL_ARG_SRC);
31     const auto weights_d = ctx.memory_mdw(DNNL_ARG_WEIGHTS);
32     const auto dst_d = ctx.memory_mdw(DNNL_ARG_DST);
33     const auto bia_d = ctx.memory_mdw(DNNL_ARG_BIAS);
34 
35     memory_storage_t *scales = &CTX_IN_STORAGE(DNNL_ARG_ATTR_OUTPUT_SCALES);
36     memory_storage_t *a0
37             = &CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC);
38 
39     memory_storage_t *b0
40             = &CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS);
41 
42     memory_storage_t *c0
43             = &CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST);
44 
45     gemm_exec_args_t gemm_args;
46     gemm_args.a = &CTX_IN_STORAGE(DNNL_ARG_SRC);
47     gemm_args.b = &CTX_IN_STORAGE(DNNL_ARG_WEIGHTS);
48     gemm_args.c = &CTX_OUT_STORAGE(DNNL_ARG_DST);
49     gemm_args.bias = &CTX_IN_STORAGE(DNNL_ARG_BIAS);
50 
51     gemm_args.a_zero_point = a0;
52     gemm_args.b_zero_point = b0;
53     gemm_args.c_zero_point = c0;
54     gemm_args.output_scales = scales;
55     gemm_args.exec_args = ctx.args();
56     auto gemm_desc = gemm_desc_t();
57     gemm_desc.primitive_kind = primitive_kind::gemm;
58     gemm_desc.a_desc = *src_d.md_;
59     gemm_desc.b_desc = *weights_d.md_;
60     gemm_desc.c_desc = *dst_d.md_;
61     gemm_desc.bias_desc = *bia_d.md_;
62     gemm_desc.acc_type = pd()->desc()->accum_data_type;
63 
64     gemm_exec_ctx_t gemm_ctx(ctx, gemm_args, &gemm_desc);
65 
66     nested_scratchpad_t ns(ctx, key_nested, gemm_);
67     gemm_ctx.set_scratchpad_grantor(ns.grantor());
68 
69     status_t gemm_exec_status = gpu_gemm(gemm_)->execute(gemm_ctx);
70     if (gemm_exec_status != status::success) return gemm_exec_status;
71 
72     return status::success;
73 }
74 
75 } // namespace ocl
76 } // namespace gpu
77 } // namespace impl
78 } // namespace dnnl
79