/*******************************************************************************
* Copyright 2020 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/ocl/rnn/ref_rnn.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {
template <prop_kind_t aprop>
elemwise_sig((_ref_rnn_common_t<aprop>::rnn_elemwise)) {
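    // One work-item per (hidden channel, minibatch) element.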
    auto nd_range = compute::nd_range_t({dhc, batch});

    const compute::kernel_t &kernel = (aprop == prop_kind::forward)
            ? elemwise_fwd_kernel_
            : elemwise_bwd_kernel_;

    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, dir);
    arg_list.set(1, lay);
    arg_list.set(2, iter);
    arg_list.set(3, workspace);
    arg_list.set(4, scratch_gates);
    arg_list.set(5, bias);
    arg_list.set(6, pd()->desc()->alpha);
    // for test mode
    arg_list.set(7, tm_scales ? *tm_scales : memory_storage_t::empty_storage());
    arg_list.set(8, pd()->rnn_conf.tm_cscale);
    parallel_for(ctx, nd_range, kernel, arg_list);
}
template elemwise_sig(ref_rnn_fwd_t::rnn_elemwise);
template elemwise_sig(ref_rnn_bwd_t::rnn_elemwise);

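// LSTM cell: same dispatch pattern and argument layout as rnn_elemwise.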
template <prop_kind_t aprop>
elemwise_sig((_ref_rnn_common_t<aprop>::lstm_elemwise)) {
    auto nd_range = compute::nd_range_t({dhc, batch});

    const compute::kernel_t &kernel = (aprop == prop_kind::forward)
            ? elemwise_fwd_kernel_
            : elemwise_bwd_kernel_;

    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, dir);
    arg_list.set(1, lay);
    arg_list.set(2, iter);
    arg_list.set(3, workspace);
    arg_list.set(4, scratch_gates);
    arg_list.set(5, bias);
    arg_list.set(6, pd()->desc()->alpha);
    // for test mode
    arg_list.set(7, tm_scales ? *tm_scales : memory_storage_t::empty_storage());
    arg_list.set(8, pd()->rnn_conf.tm_cscale);
    parallel_for(ctx, nd_range, kernel, arg_list);
}
template elemwise_sig(ref_rnn_fwd_t::lstm_elemwise);
template elemwise_sig(ref_rnn_bwd_t::lstm_elemwise);

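// Quantized (u8s8) LSTM cell: additionally passes the optional scales buffer
// and the data quantization shift/scale taken from the primitive attributes;
// only the forward elementwise kernel is launched here.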
template <prop_kind_t aprop>
elemwise_sig((_ref_rnn_common_t<aprop>::lstm_elemwise_u8s8)) {
    auto nd_range = compute::nd_range_t({dhc, batch});

    float data_shift = pd()->attr()->rnn_data_qparams_.shift_;
    float data_scale = pd()->attr()->rnn_data_qparams_.scale_;

    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, dir);
    arg_list.set(1, lay);
    arg_list.set(2, iter);
    arg_list.set(3, workspace);
    arg_list.set(4, scratch_gates);
    arg_list.set(5, scales ? *scales : memory_storage_t::empty_storage());
    arg_list.set(6, bias);
    arg_list.set(7, pd()->desc()->alpha);
    arg_list.set(8, data_shift);
    arg_list.set(9, data_scale);
    // for test mode
    arg_list.set(
            10, tm_scales ? *tm_scales : memory_storage_t::empty_storage());
    arg_list.set(11, pd()->rnn_conf.tm_cscale);
    parallel_for(ctx, nd_range, elemwise_fwd_kernel_, arg_list);
}
template elemwise_sig(ref_rnn_fwd_t::lstm_elemwise_u8s8);
template elemwise_sig(ref_rnn_bwd_t::lstm_elemwise_u8s8);

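// Linear-before-reset GRU cell: same arguments as rnn_elemwise plus the
// scratch cell buffer.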
template <prop_kind_t aprop>
elemwise_sig((_ref_rnn_common_t<aprop>::gru_lbr_elemwise)) {
    auto nd_range = compute::nd_range_t({dhc, batch});

    const compute::kernel_t &kernel = (aprop == prop_kind::forward)
            ? elemwise_fwd_kernel_
            : elemwise_bwd_kernel_;

    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, dir);
    arg_list.set(1, lay);
    arg_list.set(2, iter);
    arg_list.set(3, workspace);
    arg_list.set(4, scratch_gates);
    arg_list.set(5, bias);
    arg_list.set(6, pd()->desc()->alpha);
    // for test mode
    arg_list.set(7, tm_scales ? *tm_scales : memory_storage_t::empty_storage());
    arg_list.set(8, pd()->rnn_conf.tm_cscale);
    arg_list.set(9, scratch_cell);
    parallel_for(ctx, nd_range, kernel, arg_list);
}
template elemwise_sig(ref_rnn_fwd_t::gru_lbr_elemwise);
template elemwise_sig(ref_rnn_bwd_t::gru_lbr_elemwise);

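// GRU cell: in addition to the common arguments, the kernel receives the
// current part index; the backward pass also needs the scratch cell buffer.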
template <prop_kind_t aprop>
elemwise_sig((_ref_rnn_common_t<aprop>::gru_elemwise)) {
    auto nd_range = compute::nd_range_t({dhc, batch});

    const compute::kernel_t &kernel = (aprop == prop_kind::forward)
            ? elemwise_fwd_kernel_
            : elemwise_bwd_kernel_;

    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, dir);
    arg_list.set(1, lay);
    arg_list.set(2, iter);
    arg_list.set(3, workspace);
    arg_list.set(4, scratch_gates);
    arg_list.set(5, bias);
    arg_list.set(6, pd()->desc()->alpha);
    // for test mode
    arg_list.set(7, tm_scales ? *tm_scales : memory_storage_t::empty_storage());
    arg_list.set(8, pd()->rnn_conf.tm_cscale);
    arg_list.set(9, part);
    if (aprop != prop_kind::forward) { arg_list.set(10, scratch_cell); }
    parallel_for(ctx, nd_range, kernel, arg_list);
}
template elemwise_sig(ref_rnn_fwd_t::gru_elemwise);
template elemwise_sig(ref_rnn_bwd_t::gru_elemwise);
} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl