1 /*******************************************************************************
2 * Copyright 2020 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 /*
18  * Cell execution GRU with linear before reset
19  */
20 
21 #include "gpu/ocl/rnn/ref_rnn.hpp"
22 
23 namespace dnnl {
24 namespace impl {
25 namespace gpu {
26 namespace ocl {
27 #define PART 1
28 
29 using namespace dnnl::impl::utils;
30 using namespace rnn_utils;
31 
// Explicit instantiations of the cell-execution routine for the forward and
// backward propagation variants of the reference RNN primitive.
template cell_execution_sig(ref_rnn_fwd_t::cell_execution_gru_lbr);
template cell_execution_sig(ref_rnn_bwd_t::cell_execution_gru_lbr);
34 
35 template <prop_kind_t aprop>
36 cell_execution_sig((_ref_rnn_common_t<aprop>::cell_execution_gru_lbr)) {
37     const conf_t &rnn = this->pd()->rnn_conf;
38     data_type_t src_t = this->pd()->src_type;
39 
40     cl_ulong cell_scratch_offset, cell_ws_iter_offset, cell_ws_lay_offset,
41             cell_wei_iter_offset;
42 
43     set_offsets_fwd_gemm(rnn, iter, dir, lay, src_t, wei_iter_offset_ptr,
44             ws_states_offset_, cell_ws_iter_offset, cell_ws_lay_offset,
45             cell_scratch_offset, cell_wei_iter_offset);
46 
47     if (aprop == prop_kind::forward) {
48         // call made when cell execution is enabled
49         if (!rnn.merge_gemm_layer)
50             gemm_primitive(engine, ctx, wei_layer, wei_layer_offset[0],
51                     workspace, cell_ws_lay_offset, scratch_gates,
52                     cell_scratch_offset, gemm_layer_fwd);
53 
54         gemm_primitive(engine, ctx, wei_iter, cell_wei_iter_offset, workspace,
55                 cell_ws_iter_offset, scratch_cell, 0, gemm_iter_fwd);
56 
57         (this->*elemwise_func)(ctx, dir, lay, iter, rnn.dhc, rnn.mb, workspace,
58                 scratch_gates, scratch_cell, scales, bias, tm_scales, PART);
59 
60     } else {
61         cl_ulong cell_diff_wei_iter_off, cell_diff_wei_lay_off,
62                 cell_diff_ws_iter_off, cell_diff_ws_lay_off;
63 
64         set_offsets_bwd_gemm(rnn, iter, dir, lay, ws_diff_states_offset_,
65                 cell_diff_wei_iter_off, cell_diff_wei_lay_off,
66                 cell_diff_ws_lay_off, cell_diff_ws_iter_off);
67 
68         (this->*elemwise_func)(ctx, dir, lay, iter, rnn.dhc, rnn.mb, workspace,
69                 scratch_gates, scratch_cell, scales, bias, tm_scales, PART);
70 
71         if (!rnn.merge_gemm_layer) {
72             gemm_primitive(engine, ctx, scratch_gates, cell_scratch_offset,
73                     workspace, cell_ws_lay_offset, diff_weights_layer,
74                     cell_diff_wei_lay_off, gemm_diff_wei_layer);
75 
76             gemm_primitive(engine, ctx, wei_layer, wei_layer_offset[0],
77                     scratch_gates, cell_scratch_offset, workspace,
78                     cell_diff_ws_lay_off, gemm_layer_bwd);
79         }
80 
81         gemm_primitive(engine, ctx, wei_iter, cell_wei_iter_offset,
82                 scratch_cell, 0, workspace, cell_diff_ws_iter_off,
83                 gemm_iter_bwd);
84 
85         gemm_primitive(engine, ctx, scratch_cell, 0, workspace,
86                 cell_ws_iter_offset, diff_weights_iter, cell_diff_wei_iter_off,
87                 gemm_diff_wei_iter);
88 
89         gates_reduction(ctx, dir, lay, iter, rnn.n_gates, rnn.dhc, rnn.mb,
90                 scratch_gates, scratch_cell, diff_bias);
91     }
92 }
93 
94 } // namespace ocl
95 } // namespace gpu
96 } // namespace impl
97 } // namespace dnnl
98