1 /******************************************************************************* 2 * Copyright 2020 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 /* 18 * Cell execution GRU with linear before reset 19 */ 20 21 #include "gpu/ocl/rnn/ref_rnn.hpp" 22 23 namespace dnnl { 24 namespace impl { 25 namespace gpu { 26 namespace ocl { 27 #define PART 1 28 29 using namespace dnnl::impl::utils; 30 using namespace rnn_utils; 31 32 template cell_execution_sig(ref_rnn_fwd_t::cell_execution_gru_lbr); 33 template cell_execution_sig(ref_rnn_bwd_t::cell_execution_gru_lbr); 34 35 template <prop_kind_t aprop> 36 cell_execution_sig((_ref_rnn_common_t<aprop>::cell_execution_gru_lbr)) { 37 const conf_t &rnn = this->pd()->rnn_conf; 38 data_type_t src_t = this->pd()->src_type; 39 40 cl_ulong cell_scratch_offset, cell_ws_iter_offset, cell_ws_lay_offset, 41 cell_wei_iter_offset; 42 43 set_offsets_fwd_gemm(rnn, iter, dir, lay, src_t, wei_iter_offset_ptr, 44 ws_states_offset_, cell_ws_iter_offset, cell_ws_lay_offset, 45 cell_scratch_offset, cell_wei_iter_offset); 46 47 if (aprop == prop_kind::forward) { 48 // call made when cell execution is enabled 49 if (!rnn.merge_gemm_layer) 50 gemm_primitive(engine, ctx, wei_layer, wei_layer_offset[0], 51 workspace, cell_ws_lay_offset, scratch_gates, 52 cell_scratch_offset, gemm_layer_fwd); 53 54 gemm_primitive(engine, ctx, wei_iter, cell_wei_iter_offset, workspace, 55 cell_ws_iter_offset, scratch_cell, 0, gemm_iter_fwd); 56 57 (this->*elemwise_func)(ctx, dir, lay, iter, rnn.dhc, rnn.mb, workspace, 58 scratch_gates, scratch_cell, scales, bias, tm_scales, PART); 59 60 } else { 61 cl_ulong cell_diff_wei_iter_off, cell_diff_wei_lay_off, 62 cell_diff_ws_iter_off, cell_diff_ws_lay_off; 63 64 set_offsets_bwd_gemm(rnn, iter, dir, lay, ws_diff_states_offset_, 65 cell_diff_wei_iter_off, cell_diff_wei_lay_off, 66 cell_diff_ws_lay_off, cell_diff_ws_iter_off); 67 68 (this->*elemwise_func)(ctx, dir, lay, iter, rnn.dhc, rnn.mb, workspace, 69 scratch_gates, scratch_cell, scales, bias, tm_scales, PART); 70 71 if (!rnn.merge_gemm_layer) { 72 gemm_primitive(engine, ctx, scratch_gates, cell_scratch_offset, 73 workspace, cell_ws_lay_offset, diff_weights_layer, 74 cell_diff_wei_lay_off, gemm_diff_wei_layer); 75 76 gemm_primitive(engine, ctx, wei_layer, wei_layer_offset[0], 77 scratch_gates, cell_scratch_offset, workspace, 78 cell_diff_ws_lay_off, gemm_layer_bwd); 79 } 80 81 gemm_primitive(engine, ctx, wei_iter, cell_wei_iter_offset, 82 scratch_cell, 0, workspace, cell_diff_ws_iter_off, 83 gemm_iter_bwd); 84 85 gemm_primitive(engine, ctx, scratch_cell, 0, workspace, 86 cell_ws_iter_offset, diff_weights_iter, cell_diff_wei_iter_off, 87 gemm_diff_wei_iter); 88 89 gates_reduction(ctx, dir, lay, iter, rnn.n_gates, rnn.dhc, rnn.mb, 90 scratch_gates, scratch_cell, diff_bias); 91 } 92 } 93 94 } // namespace ocl 95 } // namespace gpu 96 } // namespace impl 97 } // namespace dnnl 98