1 /*******************************************************************************
2 * Copyright 2020-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 // Common for RNN and LSTM cell execution
18 
19 #include "gpu/ocl/rnn/ref_rnn.hpp"
20 
21 namespace dnnl {
22 namespace impl {
23 namespace gpu {
24 namespace ocl {
25 
26 using namespace dnnl::impl::utils;
27 using namespace rnn_utils;
28 
29 template <prop_kind_t aprop>
30 cell_execution_sig((_ref_rnn_common_t<aprop>::cell_execution)) {
31     const conf_t &rnn = this->pd()->rnn_conf;
32     data_type_t src_t = this->pd()->src_type;
33 
34     cl_ulong cell_scratch_offset, cell_ws_iter_offset, cell_ws_lay_offset,
35             cell_wei_iter_offset;
36 
37     set_offsets_fwd_gemm(rnn, iter, dir, lay, src_t, wei_iter_offset_ptr,
38             ws_states_offset_, cell_ws_iter_offset, cell_ws_lay_offset,
39             cell_scratch_offset, cell_wei_iter_offset);
40 
41     if (aprop == prop_kind::forward) {
42         if (!rnn.merge_gemm_layer) {
43             gemm_primitive(engine, ctx, wei_layer, wei_layer_offset[0],
44                     workspace, cell_ws_lay_offset, scratch_gates,
45                     cell_scratch_offset, gemm_layer_fwd);
46         }
47 
48         gemm_primitive(engine, ctx, wei_iter, cell_wei_iter_offset, workspace,
49                 cell_ws_iter_offset, scratch_gates, cell_scratch_offset,
50                 gemm_iter_fwd);
51 
52         (this->*elemwise_common)(ctx, dir, lay, iter, rnn.dhc, rnn.mb,
53                 workspace, scratch_gates, scratch_diff_states, scales, bias,
54                 tm_scales);
55 
56     } else { // backward
57         cl_ulong cell_diff_wei_iter_off, cell_diff_wei_lay_off,
58                 cell_scr_diff_iter_off, cell_scr_diff_lay_off;
59 
60         set_offsets_bwd_gemm(rnn, iter, dir, lay, cell_diff_wei_iter_off,
61                 cell_diff_wei_lay_off, cell_scr_diff_lay_off,
62                 cell_scr_diff_iter_off);
63 
64         (this->*elemwise_common)(ctx, dir, lay, iter, rnn.dhc, rnn.mb,
65                 workspace, scratch_gates, scratch_diff_states, scales, bias,
66                 tm_scales);
67 
68         gemm_primitive(engine, ctx, wei_iter, cell_wei_iter_offset,
69                 scratch_gates, cell_scratch_offset, scratch_diff_states,
70                 cell_scr_diff_iter_off, gemm_iter_bwd);
71 
72         if (!rnn.merge_gemm_layer) {
73             gemm_primitive(engine, ctx, wei_layer, wei_layer_offset[0],
74                     scratch_gates, cell_scratch_offset, scratch_diff_states,
75                     cell_scr_diff_lay_off, gemm_layer_bwd);
76 
77             gemm_primitive(engine, ctx, scratch_gates, cell_scratch_offset,
78                     workspace, cell_ws_lay_offset, diff_weights_layer,
79                     cell_diff_wei_lay_off, gemm_diff_wei_layer);
80         }
81 
82         if (!rnn.merge_gemm_iter) {
83             gemm_primitive(engine, ctx, scratch_gates, cell_scratch_offset,
84                     workspace, cell_ws_iter_offset, diff_weights_iter,
85                     cell_diff_wei_iter_off, gemm_diff_wei_iter);
86         }
87 
88         gates_reduction(ctx, dir, lay, iter, rnn.n_gates, rnn.dhc, rnn.mb,
89                 scratch_gates, scratch_cell, diff_bias);
90     }
91 }
92 template cell_execution_sig(ref_rnn_fwd_t::cell_execution);
93 template cell_execution_sig(ref_rnn_bwd_t::cell_execution);
94 } // namespace ocl
95 } // namespace gpu
96 } // namespace impl
97 } // namespace dnnl
98