1 /******************************************************************************* 2 * Copyright 2020 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 #include "gpu/ocl/rnn/ref_rnn.hpp" 18 19 namespace dnnl { 20 namespace impl { 21 namespace gpu { 22 namespace ocl { 23 24 template <prop_kind_t aprop> 25 elemwise_sig((_ref_rnn_common_t<aprop>::rnn_elemwise)) { 26 auto nd_range = compute::nd_range_t({dhc, batch}); 27 28 const compute::kernel_t &kernel = (aprop == prop_kind::forward) 29 ? elemwise_fwd_kernel_ 30 : elemwise_bwd_kernel_; 31 32 compute::kernel_arg_list_t arg_list; 33 arg_list.set(0, dir); 34 arg_list.set(1, lay); 35 arg_list.set(2, iter); 36 arg_list.set(3, workspace); 37 arg_list.set(4, scratch_gates); 38 arg_list.set(5, bias); 39 arg_list.set(6, pd()->desc()->alpha); 40 // for test mode 41 arg_list.set(7, tm_scales ? *tm_scales : memory_storage_t::empty_storage()); 42 arg_list.set(8, pd()->rnn_conf.tm_cscale); 43 parallel_for(ctx, nd_range, kernel, arg_list); 44 } 45 template elemwise_sig(ref_rnn_fwd_t::rnn_elemwise); 46 template elemwise_sig(ref_rnn_bwd_t::rnn_elemwise); 47 48 template <prop_kind_t aprop> 49 elemwise_sig((_ref_rnn_common_t<aprop>::lstm_elemwise)) { 50 auto nd_range = compute::nd_range_t({dhc, batch}); 51 52 const compute::kernel_t &kernel = (aprop == prop_kind::forward) 53 ? elemwise_fwd_kernel_ 54 : elemwise_bwd_kernel_; 55 56 compute::kernel_arg_list_t arg_list; 57 arg_list.set(0, dir); 58 arg_list.set(1, lay); 59 arg_list.set(2, iter); 60 arg_list.set(3, workspace); 61 arg_list.set(4, scratch_gates); 62 arg_list.set(5, bias); 63 arg_list.set(6, pd()->desc()->alpha); 64 // for test mode 65 arg_list.set(7, tm_scales ? *tm_scales : memory_storage_t::empty_storage()); 66 arg_list.set(8, pd()->rnn_conf.tm_cscale); 67 parallel_for(ctx, nd_range, kernel, arg_list); 68 } 69 template elemwise_sig(ref_rnn_fwd_t::lstm_elemwise); 70 template elemwise_sig(ref_rnn_bwd_t::lstm_elemwise); 71 72 template <prop_kind_t aprop> 73 elemwise_sig((_ref_rnn_common_t<aprop>::lstm_elemwise_u8s8)) { 74 auto nd_range = compute::nd_range_t({dhc, batch}); 75 76 float data_shift = pd()->attr()->rnn_data_qparams_.shift_; 77 float data_scale = pd()->attr()->rnn_data_qparams_.scale_; 78 79 compute::kernel_arg_list_t arg_list; 80 arg_list.set(0, dir); 81 arg_list.set(1, lay); 82 arg_list.set(2, iter); 83 arg_list.set(3, workspace); 84 arg_list.set(4, scratch_gates); 85 arg_list.set(5, scales ? *scales : memory_storage_t::empty_storage()); 86 arg_list.set(6, bias); 87 arg_list.set(7, pd()->desc()->alpha); 88 arg_list.set(8, data_shift); 89 arg_list.set(9, data_scale); 90 // for test mode 91 arg_list.set( 92 10, tm_scales ? *tm_scales : memory_storage_t::empty_storage()); 93 arg_list.set(11, pd()->rnn_conf.tm_cscale); 94 parallel_for(ctx, nd_range, elemwise_fwd_kernel_, arg_list); 95 } 96 template elemwise_sig(ref_rnn_fwd_t::lstm_elemwise_u8s8); 97 template elemwise_sig(ref_rnn_bwd_t::lstm_elemwise_u8s8); 98 99 template <prop_kind_t aprop> 100 elemwise_sig((_ref_rnn_common_t<aprop>::gru_lbr_elemwise)) { 101 auto nd_range = compute::nd_range_t({dhc, batch}); 102 103 const compute::kernel_t &kernel = (aprop == prop_kind::forward) 104 ? elemwise_fwd_kernel_ 105 : elemwise_bwd_kernel_; 106 107 compute::kernel_arg_list_t arg_list; 108 arg_list.set(0, dir); 109 arg_list.set(1, lay); 110 arg_list.set(2, iter); 111 arg_list.set(3, workspace); 112 arg_list.set(4, scratch_gates); 113 arg_list.set(5, bias); 114 arg_list.set(6, pd()->desc()->alpha); 115 // for test mode 116 arg_list.set(7, tm_scales ? *tm_scales : memory_storage_t::empty_storage()); 117 arg_list.set(8, pd()->rnn_conf.tm_cscale); 118 arg_list.set(9, scratch_cell); 119 parallel_for(ctx, nd_range, kernel, arg_list); 120 } 121 template elemwise_sig(ref_rnn_fwd_t::gru_lbr_elemwise); 122 template elemwise_sig(ref_rnn_bwd_t::gru_lbr_elemwise); 123 124 template <prop_kind_t aprop> 125 elemwise_sig((_ref_rnn_common_t<aprop>::gru_elemwise)) { 126 auto nd_range = compute::nd_range_t({dhc, batch}); 127 128 const compute::kernel_t &kernel = (aprop == prop_kind::forward) 129 ? elemwise_fwd_kernel_ 130 : elemwise_bwd_kernel_; 131 132 compute::kernel_arg_list_t arg_list; 133 arg_list.set(0, dir); 134 arg_list.set(1, lay); 135 arg_list.set(2, iter); 136 arg_list.set(3, workspace); 137 arg_list.set(4, scratch_gates); 138 arg_list.set(5, bias); 139 arg_list.set(6, pd()->desc()->alpha); 140 arg_list.set(7, tm_scales ? *tm_scales : memory_storage_t::empty_storage()); 141 arg_list.set(8, pd()->rnn_conf.tm_cscale); 142 arg_list.set(9, part); 143 if (aprop != dnnl_forward) { arg_list.set(10, scratch_cell); } 144 parallel_for(ctx, nd_range, kernel, arg_list); 145 } 146 template elemwise_sig(ref_rnn_fwd_t::gru_elemwise); 147 template elemwise_sig(ref_rnn_bwd_t::gru_elemwise); 148 } // namespace ocl 149 } // namespace gpu 150 } // namespace impl 151 } // namespace dnnl 152