1 /******************************************************************************* 2 * Copyright 2020-2021 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 // Common for RNN and LSTM cell execution 18 19 #include "gpu/ocl/rnn/ref_rnn.hpp" 20 21 namespace dnnl { 22 namespace impl { 23 namespace gpu { 24 namespace ocl { 25 26 using namespace dnnl::impl::utils; 27 using namespace rnn_utils; 28 29 template <prop_kind_t aprop> 30 cell_execution_sig((_ref_rnn_common_t<aprop>::cell_execution)) { 31 const conf_t &rnn = this->pd()->rnn_conf; 32 data_type_t src_t = this->pd()->src_type; 33 34 cl_ulong cell_scratch_offset, cell_ws_iter_offset, cell_ws_lay_offset, 35 cell_wei_iter_offset; 36 37 set_offsets_fwd_gemm(rnn, iter, dir, lay, src_t, wei_iter_offset_ptr, 38 ws_states_offset_, cell_ws_iter_offset, cell_ws_lay_offset, 39 cell_scratch_offset, cell_wei_iter_offset); 40 41 if (aprop == prop_kind::forward) { 42 if (!rnn.merge_gemm_layer) { 43 gemm_primitive(engine, ctx, wei_layer, wei_layer_offset[0], 44 workspace, cell_ws_lay_offset, scratch_gates, 45 cell_scratch_offset, gemm_layer_fwd); 46 } 47 48 gemm_primitive(engine, ctx, wei_iter, cell_wei_iter_offset, workspace, 49 cell_ws_iter_offset, scratch_gates, cell_scratch_offset, 50 gemm_iter_fwd); 51 52 (this->*elemwise_common)(ctx, dir, lay, iter, rnn.dhc, rnn.mb, 53 workspace, scratch_gates, scratch_diff_states, scales, bias, 54 tm_scales); 55 56 } else { // backward 57 cl_ulong cell_diff_wei_iter_off, cell_diff_wei_lay_off, 58 cell_scr_diff_iter_off, cell_scr_diff_lay_off; 59 60 set_offsets_bwd_gemm(rnn, iter, dir, lay, cell_diff_wei_iter_off, 61 cell_diff_wei_lay_off, cell_scr_diff_lay_off, 62 cell_scr_diff_iter_off); 63 64 (this->*elemwise_common)(ctx, dir, lay, iter, rnn.dhc, rnn.mb, 65 workspace, scratch_gates, scratch_diff_states, scales, bias, 66 tm_scales); 67 68 gemm_primitive(engine, ctx, wei_iter, cell_wei_iter_offset, 69 scratch_gates, cell_scratch_offset, scratch_diff_states, 70 cell_scr_diff_iter_off, gemm_iter_bwd); 71 72 if (!rnn.merge_gemm_layer) { 73 gemm_primitive(engine, ctx, wei_layer, wei_layer_offset[0], 74 scratch_gates, cell_scratch_offset, scratch_diff_states, 75 cell_scr_diff_lay_off, gemm_layer_bwd); 76 77 gemm_primitive(engine, ctx, scratch_gates, cell_scratch_offset, 78 workspace, cell_ws_lay_offset, diff_weights_layer, 79 cell_diff_wei_lay_off, gemm_diff_wei_layer); 80 } 81 82 if (!rnn.merge_gemm_iter) { 83 gemm_primitive(engine, ctx, scratch_gates, cell_scratch_offset, 84 workspace, cell_ws_iter_offset, diff_weights_iter, 85 cell_diff_wei_iter_off, gemm_diff_wei_iter); 86 } 87 88 gates_reduction(ctx, dir, lay, iter, rnn.n_gates, rnn.dhc, rnn.mb, 89 scratch_gates, scratch_cell, diff_bias); 90 } 91 } 92 template cell_execution_sig(ref_rnn_fwd_t::cell_execution); 93 template cell_execution_sig(ref_rnn_bwd_t::cell_execution); 94 } // namespace ocl 95 } // namespace gpu 96 } // namespace impl 97 } // namespace dnnl 98