/*******************************************************************************
* Copyright 2020-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/ocl/rnn/rnn_utils.hpp"

#include "common/c_types_map.hpp"
#include "gpu/ocl/rnn/ref_rnn.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {
#define AOC array_offset_calculator

using namespace dnnl::impl::utils;
using namespace prop_kind;
using namespace data_type;

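// RNN weights are 5D tensors with logical dims (layers, dirs, input, gates,
// output). is_ldigo()/is_ldgoi() detect the two dense physical layouts this
// implementation works with: ldigo keeps that logical order, ldgoi stores the
// data in (layers, dirs, gates, output, input) order.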
bool rnn_utils::is_ldigo(const memory_desc_wrapper &md) {
    if (md.format_kind() != format_kind::blocked) return false;

    auto blk = md.blocking_desc();
    auto str = blk.strides;
    auto dims = md.dims();
    return md.ndims() == 5 && blk.inner_nblks == 0 && str[4] == 1
            && str[3] == dims[4] && str[1] == str[2] * dims[2]
            && str[0] == str[1] * dims[1];
}

bool rnn_utils::is_ldgoi(const memory_desc_wrapper &md) {
    if (md.format_kind() != format_kind::blocked) return false;

    auto blk = md.blocking_desc();
    auto str = blk.strides;
    auto dims = md.dims();
    return md.ndims() == 5 && blk.inner_nblks == 0 && str[2] == 1
            && str[3] == dims[4] * str[4] && str[1] == str[3] * dims[3]
            && str[0] == str[1] * dims[1];
}

void rnn_utils::init_rnn_conf(conf_t &rnn, const rnn_desc_t &rd,
        const memory_desc_wrapper &src_layer_d,
        const memory_desc_wrapper &src_iter_d,
        const memory_desc_wrapper &weights_layer_d,
        const memory_desc_wrapper &weights_iter_d,
        const memory_desc_wrapper &dst_layer_d) {

    rnn = utils::zero<decltype(rnn)>();
    rnn.is_fwd = utils::one_of(rd.prop_kind, prop_kind::forward_training,
            prop_kind::forward_inference);
    rnn.is_training = utils::one_of(
            rd.prop_kind, prop_kind::forward_training, prop_kind::backward);
    rnn.is_lbr = rd.cell_kind == dnnl_lbr_gru;
    rnn.is_vanilla_gru = rd.cell_kind == dnnl_vanilla_gru;

    switch (rd.direction) {
        case dnnl_unidirectional_left2right: rnn.exec_dir = l2r; break;
        case dnnl_unidirectional_right2left: rnn.exec_dir = r2l; break;
        case dnnl_bidirectional_concat: rnn.exec_dir = bi_concat; break;
        case dnnl_bidirectional_sum: rnn.exec_dir = bi_sum; break;
        default: break;
    }

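    // Pick the data type configuration from the source, destination and
    // weights data types; the mixed u8/f32 combinations correspond to the
    // int8 inference paths.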
    if (everyone_is(f32, src_layer_d.data_type(), dst_layer_d.data_type(),
                weights_layer_d.data_type()))
        rnn.dt_conf = all_f32;
    else if (everyone_is(bf16, src_layer_d.data_type(), dst_layer_d.data_type(),
                     weights_layer_d.data_type()))
        rnn.dt_conf = all_bf16;
    else if (everyone_is(f16, src_layer_d.data_type(), dst_layer_d.data_type(),
                     weights_layer_d.data_type()))
        rnn.dt_conf = all_f16;
    else if (dst_layer_d.data_type() == u8) {
        if (IMPLICATION(src_iter_d.md_, src_iter_d.data_type() == u8))
            rnn.dt_conf = u8u8u8u8;
        else
            rnn.dt_conf = f32u8f32u8;
    } else {
        if (IMPLICATION(src_iter_d.md_, src_iter_d.data_type() == u8))
            rnn.dt_conf = u8u8u8f32;
        else
            rnn.dt_conf = f32u8f32f32;
    }
    rnn.is_int8 = !one_of(rnn.dt_conf, all_f32, all_f16, all_bf16);

    rnn.aux_data_type
            = rnn.dt_conf == all_f16 ? data_type::f16 : data_type::f32;
    rnn.diff_data_type = data_type::f32;

    rnn.n_layer = weights_layer_d.dims()[0];
    rnn.n_iter = src_layer_d.dims()[0];
    rnn.n_dir = weights_layer_d.dims()[1];
    rnn.n_gates = weights_layer_d.dims()[3];
    rnn.n_states = rd.cell_kind == dnnl_vanilla_lstm ? 2 : 1;
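    // LBR GRU carries one extra bias term, hence n_gates + 1 when is_lbr.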
    rnn.n_bias = rnn.n_gates + rnn.is_lbr;
    rnn.mb = src_layer_d.dims()[1];
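    // slc/sic: source layer/iter channels; dhc: destination hidden channels;
    // dlc: destination layer channels (2 * dhc for bidirectional concat).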
    rnn.sic = weights_iter_d.dims()[2];
    rnn.slc = weights_layer_d.dims()[2];
    rnn.dhc = weights_layer_d.dims()[4];
    rnn.dlc = dst_layer_d.dims()[2];

    rnn.gates_ld = rnn.dhc * rnn.n_gates;
    rnn.gates_nld = rnn.mb;
    rnn.states_nld = rnn.mb;

    // Set the correct number of weights parts
    rnn.n_parts_weights_layer = 1;
    rnn.parts_weights_layer[0] = rnn.n_gates;
    rnn.parts_weights_layer[1] = 0;

    // There are two parts for the vanilla GRU iteration weights
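    // (part 0 covers the update and reset gates, part 1 the candidate gate,
    // whose iteration GEMM must run after the reset gate has been applied to
    // the previous state).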
    rnn.n_parts_weights_iter = rnn.is_vanilla_gru ? 2 : 1;
    rnn.parts_weights_iter[0] = rnn.is_vanilla_gru ? 2 : rnn.n_gates;
    rnn.parts_weights_iter[1] = rnn.is_vanilla_gru ? 1 : 0;

    rnn.n_parts_bias = 1;
    rnn.parts_bias[0] = rnn.n_bias;
    rnn.parts_bias[1] = 0;

    bool is_gru = utils::one_of(
            rd.cell_kind, alg_kind::vanilla_gru, alg_kind::lbr_gru);

    // Decide whether to merge GEMM across iterations or layers
    auto src_layer_ld = src_layer_d.blocking_desc().strides[1];
    auto dst_layer_ld = dst_layer_d.blocking_desc().strides[1];
    auto src_layer_is_trivial_stride
            = src_layer_d.blocking_desc().strides[0] == (src_layer_ld * rnn.mb);
    auto dst_layer_is_trivial_stride
            = dst_layer_d.blocking_desc().strides[0] == (dst_layer_ld * rnn.mb);

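    // A trivial outer stride means consecutive iterations are contiguous in
    // memory, so the per-iteration GEMMs can be folded into a single larger
    // GEMM over all time steps.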
    rnn.merge_gemm_layer = ((rnn.is_fwd && src_layer_is_trivial_stride)
                                   || ((rd.prop_kind == prop_kind::backward)
                                           && dst_layer_is_trivial_stride))
            && (((rnn.is_fwd && rnn.mb < 128) || !rnn.is_fwd) || rnn.is_int8);
    rnn.merge_gemm_iter
            = dst_layer_is_trivial_stride && !(rnn.is_fwd || is_gru);

    // Decide whether to copy the bias
    rnn.copy_bias = rnn.is_int8;

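    // Training keeps a workspace so the backward pass can reuse the states and
    // gates produced by the forward pass.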
    rnn.use_workspace = rnn.is_training;

    switch (rnn.dt_conf) {
        case all_f32:
        case f32u8f32f32:
            rnn.input_data_type = f32;
            rnn.dst_data_type = f32;
            rnn.output_data_type = f32;
            break;
        case all_f16:
            rnn.input_data_type = f16;
            rnn.dst_data_type = f16;
            rnn.output_data_type = f16;
            break;
        case u8u8u8u8:
            rnn.input_data_type = u8;
            rnn.dst_data_type = u8;
            rnn.output_data_type = u8;
            break;
        case u8u8u8f32:
            rnn.input_data_type = u8;
            rnn.dst_data_type = f32;
            rnn.output_data_type = u8;
            break;
        case f32u8f32u8:
            rnn.input_data_type = f32;
            rnn.dst_data_type = u8;
            rnn.output_data_type = f32;
            break;
        case all_bf16:
            rnn.input_data_type = bf16;
            rnn.dst_data_type = bf16;
            rnn.output_data_type = bf16;
            break;
        default: assert(!"unimplemented");
    }
}

void rnn_utils::init_test_mode(conf_t &rnn, const primitive_attr_t &attr) {
    rnn.is_testmode = attr.rnn_tparams_.test_mode_;
    rnn.tm_ngates = attr.rnn_tparams_.ngates_;
    rnn.tm_cscale = attr.rnn_tparams_.cscale_;
}

void rnn_utils::set_rnn_conf(conf_t &rnn, const rnn_desc_t &rd,
        const memory_desc_wrapper &weights_layer_d,
        const memory_desc_wrapper &weights_iter_d,
        const memory_desc_wrapper &diff_weights_layer_d,
        const memory_desc_wrapper &diff_weights_iter_d) {

    // Set leading dimensions for input weights arrays depending on input format
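    // ld is the leading dimension of the GEMM view of the weights (the element
    // stride between consecutive rows of that view) and nld is the number of
    // such rows.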
    auto set_dims = [&](const memory_desc_wrapper &md, int &ld, int &nld) {
        ld = 0;
        nld = 0;
        if (md.is_blocking_desc()) {
            if (is_ldigo(md)) {
                ld = (int)md.blocking_desc().strides[2];
                nld = md.dims()[2];
            } else if (is_ldgoi(md)) {
                ld = (int)md.blocking_desc().strides[4];
                nld = md.dims()[3] * md.dims()[4];
            } else
                assert(!"unsupported weights format");
        }
    };
    set_dims(weights_layer_d, rnn.weights_layer_ld, rnn.weights_layer_nld);
    set_dims(weights_iter_d, rnn.weights_iter_ld, rnn.weights_iter_nld);
    if (!rnn.is_fwd) {
        set_dims(diff_weights_layer_d, rnn.diff_weights_layer_ld,
                rnn.diff_weights_layer_nld);
        set_dims(diff_weights_iter_d, rnn.diff_weights_iter_ld,
                rnn.diff_weights_iter_nld);
    }

    int sizeof_states_dt
            = rnn.dt_conf == all_f32 ? sizeof(cl_float) : sizeof(cl_half);
    int aux_elsz = rnn.aux_data_type == data_type::f16 ? sizeof(cl_half)
                                                       : sizeof(float);
    rnn.ws_states_elsz = rnn.dt_conf == all_f32
            ? sizeof(cl_float)
            : rnn.dt_conf == all_f16 || rnn.dt_conf == all_bf16
                    ? sizeof(cl_half)
                    : rnn.dt_conf == u8u8u8u8 ? sizeof(int8_t)
                                              : sizeof(int32_t);

    // Different size required for forward and backward pass
    rnn.scratch_gates_elsz = (!rnn.is_fwd && rnn.dt_conf == all_bf16)
            ? sizeof(cl_half)
            : aux_elsz;

    // Set workspace sizes to store:
    // - states to compute a pass
    // - diff states to compute the bwd pass (training only)
    // - intermediate results from the gates
    rnn.states_ws_ld = get_good_ld(
            nstl::max(rnn.slc, nstl::max(rnn.sic, rnn.dhc)), sizeof_states_dt);
    rnn.gates_ws_ld = get_good_ld(rnn.gates_ld,
            rnn.dt_conf == all_f16 ? sizeof(cl_half) : sizeof(cl_float));
    rnn.diff_states_ws_ld = get_good_ld(
            nstl::max(rnn.slc, nstl::max(rnn.sic, rnn.dhc)), sizeof(cl_float));
    rnn.scratch_gates_ld = get_good_ld(rnn.gates_ld, rnn.scratch_gates_elsz);

    bool is_lstm = rd.cell_kind == dnnl_vanilla_lstm;

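    // The states workspace keeps (n_layer + 1) x (n_iter + 1) entries per
    // direction: the extra layer slot holds the network input and the extra
    // iteration slot the initial (src_iter) states.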
    rnn.ws_states_size = (size_t)(rnn.n_layer + 1) * rnn.n_dir
            * (rnn.n_iter + 1) * rnn.mb * rnn.states_ws_ld * rnn.ws_states_elsz;
    // We do not need a good ld for iter_c as it is not involved in GEMM;
    // for now reverting it back to what it was originally.
    // TODO: separate diff_c_offsets from diff-states & separate h- and c- offsets
    rnn.ws_c_states_size = is_lstm ? (size_t)(rnn.n_layer + 1) * rnn.n_dir
                    * (rnn.n_iter + 1) * rnn.mb * rnn.states_ws_ld * aux_elsz
                                   : (size_t)0;
    rnn.ws_diff_states_size = rnn.is_training ? (size_t)(rnn.n_layer + 1)
                    * rnn.n_dir * (rnn.n_states + 1) * (rnn.n_iter + 1) * rnn.mb
                    * rnn.diff_states_ws_ld * aux_elsz
                                              : (size_t)0;
    rnn.ws_gates_size = (size_t)rnn.n_layer * rnn.n_dir * rnn.n_iter * rnn.mb
            * rnn.gates_ws_ld * aux_elsz;
    rnn.n_iter_scratch_gates
            = (rnn.merge_gemm_layer || rnn.merge_gemm_iter) ? rnn.n_iter : 1;
    rnn.scratch_gates_size = (size_t)rnn.n_iter_scratch_gates * rnn.gates_nld
            * rnn.scratch_gates_ld * rnn.scratch_gates_elsz;
    rnn.ws_dhG1_size
            = (rd.cell_kind == alg_kind::vanilla_gru && rnn.is_training)
            ? (size_t)rnn.gates_nld * rnn.diff_states_ws_ld * sizeof(float)
            : 0;
    rnn.ws_bias_size
            = (size_t)rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dhc * aux_elsz;

    // For the intermediate step in the post-gemm fwd lbr gru
    rnn.scratch_cell_size = rnn.is_lbr
            ? (size_t)rnn.gates_nld * rnn.scratch_gates_ld
                    * rnn.scratch_gates_elsz
            : (rd.cell_kind == alg_kind::vanilla_gru && rnn.is_training
                            ? (size_t)rnn.states_nld * rnn.states_ws_ld
                                    * rnn.ws_states_elsz
                            : 0);

    // Used for storing the intermediate value from fwd pass in training lbr gru
    rnn.ws_per_cell = (size_t)rnn.is_lbr * rnn.mb * rnn.dhc * aux_elsz;
    rnn.ws_grid_comp_size = (size_t)rnn.is_lbr * rnn.is_training * rnn.n_layer
            * rnn.n_dir * rnn.n_iter * rnn.ws_per_cell * aux_elsz;
}

int rnn_utils::get_good_ld(int dim, int sizeof_dt) {
    // We want the matrix leading dimensions to be 64-byte aligned,
    // and not divisible by 256 to avoid 4K aliasing effects.
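    // Example with sizeof_dt = 4 (f32): dim = 250 -> rnd_up(250, 16) = 256,
    // which is a multiple of 256, so the returned ld is 256 + 16 = 272.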
    int ld = rnd_up(dim, 64 / sizeof_dt);
    return (ld % 256 == 0) ? ld + 64 / sizeof_dt : ld;
}

void rnn_utils::set_offsets(const conf_t &rnn, size_t &ws_gates_offset,
        size_t &ws_states_offset, size_t &ws_c_states_offset,
        size_t &ws_diff_states_offset, size_t &ws_grid_comp_offset,
        size_t &scratch_cell_offset, size_t &ws_dhG1_offset,
        size_t &ws_bias_offset, size_t &scratch_gates_offset,
        size_t &scratchpad_size, size_t &workspace_size) {

    const size_t page_size = 4096;
    size_t current_offset;

    // Mandatory workspaces go to the workspace if use_workspace, to the
    // scratchpad otherwise. Assumes the workspace base pointer is page aligned.
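    // Layout (each section rounded up to a page boundary):
    // gates | states | c_states | grid_comp | diff_states | dhG1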

    current_offset = 0;
    ws_gates_offset = current_offset;
    current_offset += rnn.ws_gates_size;

    current_offset = utils::rnd_up(current_offset, page_size);
    ws_states_offset = current_offset;
    current_offset += rnn.ws_states_size;

    current_offset = utils::rnd_up(current_offset, page_size);
    ws_c_states_offset = current_offset;
    current_offset += rnn.ws_c_states_size;

    current_offset = utils::rnd_up(current_offset, page_size);
    ws_grid_comp_offset = current_offset;
    current_offset += rnn.ws_grid_comp_size;

    current_offset = utils::rnd_up(current_offset, page_size);
    ws_diff_states_offset = current_offset;
    current_offset += rnn.ws_diff_states_size;

    current_offset = utils::rnd_up(current_offset, page_size);
    ws_dhG1_offset = current_offset;
    current_offset += rnn.ws_dhG1_size;

    workspace_size = rnn.use_workspace ? current_offset : 0;

    // Optional scratchpads.
    // Assumes the scratchpad base pointer is page aligned.
    // If use_workspace, only the following buffers go to the scratchpad;
    // otherwise, everything goes to the scratchpad and we keep incrementing
    // the offset.
    current_offset = rnn.use_workspace ? 0 : current_offset;

    current_offset = utils::rnd_up(current_offset, page_size);
    scratch_gates_offset = current_offset;
    current_offset += rnn.scratch_gates_size;

    current_offset = utils::rnd_up(current_offset, page_size);
    scratch_cell_offset = current_offset;
    current_offset += rnn.scratch_cell_size;

    ws_bias_offset = 0;
    if (rnn.copy_bias) {
        current_offset = utils::rnd_up(current_offset, page_size);
        ws_bias_offset = current_offset;
        current_offset += rnn.ws_bias_size;
    }
    scratchpad_size = current_offset;
}

void rnn_utils::get_scratchpad_and_workspace_sizes(
        const conf_t &rnn, size_t &scratchpad_size, size_t &workspace_size) {
    size_t ws_gates_offset, ws_states_offset, ws_c_states_offset,
            ws_diff_states_offset, ws_grid_comp_offset, scratch_cell_offset,
            ws_dhG1_offset, ws_bias_offset, scratch_gates_offset;
    set_offsets(rnn, ws_gates_offset, ws_states_offset, ws_c_states_offset,
            ws_diff_states_offset, ws_grid_comp_offset, scratch_cell_offset,
            ws_dhG1_offset, ws_bias_offset, scratch_gates_offset,
            scratchpad_size, workspace_size);
}

void rnn_utils::set_offsets_fwd_gemm(const conf_t &rnn, int dir, int lay,
        data_type_t src_t, size_t *wei_layer_off_ptr,
        const size_t &ws_states_offset_, size_t &grid_ws_lay_offset,
        size_t &grid_wei_lay_offset, size_t &grid_ws_iter_offset) {
    // This overload is called by the grid execution.
    int n_layer = rnn.n_layer;
    int n_dir = rnn.n_dir;

    AOC<size_t, 3> off_weights_lay(
            wei_layer_off_ptr, n_layer, n_dir, rnn.n_parts_weights_layer);

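    // off_weights_lay holds the offsets of the (layer, dir, part) weights
    // blocks; OFF4 computes a dense row-major offset from (index, extent)
    // argument pairs.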
    grid_wei_lay_offset = off_weights_lay(lay, dir, 0);
    grid_ws_lay_offset = (cl_ulong)(ws_states_offset_
            + OFF4(lay, n_layer + 1, dir, n_dir, 1, rnn.n_iter + 1, 0,
                      rnn.mb * rnn.states_ws_ld)
                    * types::data_type_size(src_t));
    grid_ws_iter_offset = (cl_ulong)(ws_states_offset_
            + OFF4(lay + 1, rnn.n_layers + 1, dir, rnn.n_dir, 0, rnn.n_iter + 1,
                      0, rnn.mb * rnn.states_ws_ld)
                    * types::data_type_size(src_t));
    UNUSED(n_layer);
}

void rnn_utils::set_offsets_fwd_gemm(const conf_t &rnn, int iter, int dir,
        int lay, data_type_t src_t, size_t *wei_iter_off_ptr,
        const size_t &ws_states_offset_, size_t &cell_ws_iter_offset,
        size_t &cell_ws_lay_offset, size_t &cell_scratch_offset,
        size_t &cell_wei_iter_offset) {
    int n_layers = rnn.n_layer;
    int batch = rnn.mb;
    int n_iter = rnn.n_iter;
    int n_dir = rnn.n_dir;

    if (wei_iter_off_ptr) {
        AOC<size_t, 3> off_weights_iter(wei_iter_off_ptr, rnn.n_layer,
                rnn.n_dir, rnn.n_parts_weights_iter);
        cell_wei_iter_offset = off_weights_iter(lay, dir, 0);
    }

    cell_scratch_offset = (rnn.merge_gemm_iter || rnn.merge_gemm_layer)
            ? (cl_ulong)(
                    OFF2(iter, n_iter, 0, rnn.gates_nld * rnn.scratch_gates_ld)
                    * rnn.scratch_gates_elsz)
            : (size_t)0;
    cell_ws_iter_offset = (cl_ulong)(ws_states_offset_
            + OFF4(lay + 1, n_layers + 1, dir, n_dir, iter, n_iter + 1, 0,
                      batch * rnn.states_ws_ld)
                    * types::data_type_size(src_t));
    cell_ws_lay_offset = (cl_ulong)(ws_states_offset_
            + OFF4(lay, n_layers + 1, dir, n_dir, iter + 1, n_iter + 1, 0,
                      batch * rnn.states_ws_ld)
                    * types::data_type_size(src_t));
    UNUSED(n_layers);
}

void rnn_utils::set_gru_offsets_part2(const conf_t &rnn, int iter, int dir,
        int lay, data_type_t src_t, size_t *wei_iter_off_ptr,
        const size_t &ws_states_offset_, size_t &cell_wei_iter_offset,
        size_t &cell_scratch_offset, size_t &cell_ws_iter_offset) {

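    // Part 2 of the vanilla GRU iteration GEMM handles the candidate gate:
    // use the second weights part and skip the first two gates (2 * dhc
    // elements) in the scratch gates.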
    AOC<size_t, 3> off_weights_iter(
            wei_iter_off_ptr, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_iter);
    cell_wei_iter_offset = off_weights_iter(lay, dir, 1);
    cell_scratch_offset += 2 * rnn.dhc * rnn.scratch_gates_elsz;
    cell_ws_iter_offset = (cl_ulong)(ws_states_offset_
            + OFF4(lay + 1, rnn.n_layers + 1, dir, rnn.n_dir, iter + 1,
                      rnn.n_iter + 1, 0, rnn.mb * rnn.states_ws_ld)
                    * types::data_type_size(src_t));
}

void rnn_utils::set_offsets_bwd_gemm(const conf_t &rnn, int iter, int dir,
        int lay, const size_t &ws_diff_states_off_,
        size_t &cell_diff_wei_iter_off, size_t &cell_diff_wei_lay_off,
        size_t &cell_diff_ws_lay_off) {
    // This overload is called by the grid execution; it forwards to the
    // overload below, which is otherwise called by the cell execution.
    cl_ulong dummy_var;
    set_offsets_bwd_gemm(rnn, iter, dir, lay, ws_diff_states_off_,
            cell_diff_wei_iter_off, cell_diff_wei_lay_off, cell_diff_ws_lay_off,
            dummy_var);
}

void rnn_utils::set_offsets_bwd_gemm(const conf_t &rnn, int iter, int dir,
        int lay, const size_t &ws_diff_states_off_,
        size_t &cell_diff_wei_iter_off, size_t &cell_diff_wei_lay_off,
        size_t &cell_diff_ws_lay_off, size_t &cell_diff_ws_iter_off,
        size_t &cell_diff_wei_iter_off2) {

    set_offsets_bwd_gemm(rnn, iter, dir, lay, ws_diff_states_off_,
            cell_diff_wei_iter_off, cell_diff_wei_lay_off, cell_diff_ws_lay_off,
            cell_diff_ws_iter_off);
    cell_diff_wei_iter_off2
            = cell_diff_wei_iter_off + 2 * rnn.dhc * sizeof(float);
}

void rnn_utils::set_offsets_bwd_gemm(const conf_t &rnn, int iter, int dir,
        int lay, const size_t &ws_diff_states_off_,
        size_t &cell_diff_wei_iter_off, size_t &cell_diff_wei_lay_off,
        size_t &cell_diff_ws_lay_off, size_t &cell_diff_ws_iter_off) {
    int n_layers = rnn.n_layer;
    int batch = rnn.mb;
    int n_iter = rnn.n_iter;
    int n_dir = rnn.n_dir;
    int n_states = rnn.n_states;

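    // ws_diff_states is laid out as (layer, dir, state, iter): state slots
    // 0..n_states-1 carry the diffs propagated across iterations, and slot
    // n_states carries the diff propagated across layers.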
    cell_diff_ws_iter_off = ws_diff_states_off_
            + OFF5(lay, n_layers + 1, dir, n_dir, 0, n_states + 1, iter,
                      n_iter + 1, 0, rnn.states_nld * rnn.diff_states_ws_ld)
                    * sizeof(float);
    cell_diff_ws_lay_off = ws_diff_states_off_
            + OFF5(lay, n_layers + 1, dir, n_dir, n_states, n_states + 1, iter,
                      n_iter + 1, 0, rnn.states_nld * rnn.diff_states_ws_ld)
                    * sizeof(float);
    cell_diff_wei_lay_off
            = OFF3(lay, n_layers, dir, n_dir, 0,
                      rnn.diff_weights_layer_nld * rnn.diff_weights_layer_ld)
            * sizeof(float);
    cell_diff_wei_iter_off
            = OFF3(lay, n_layers, dir, n_dir, 0,
                      rnn.diff_weights_iter_nld * rnn.diff_weights_iter_ld)
            * sizeof(float);
    UNUSED(n_layers);
    UNUSED(batch);
}

status_t rnn_utils::set_good_strides(
        memory_desc_t &weights_md, format_tag_t tag) {
    auto &strides = weights_md.format_desc.blocking.strides;
    auto dims = weights_md.dims;
    using namespace format_tag;

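    // get_good_ld pads the innermost non-unit stride; recompute the outer
    // strides so the rest of the tensor stays dense on top of that padding.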
    if (tag == ldigo) {
        strides[2] = rnn_utils::get_good_ld((int)strides[2],
                (int)types::data_type_size(weights_md.data_type));
        strides[1] = dims[2] * strides[2];
        strides[0] = dims[1] * strides[1];
    } else if (tag == ldgoi) {
        strides[4] = rnn_utils::get_good_ld((int)strides[4],
                (int)types::data_type_size(weights_md.data_type));
        strides[3] = dims[4] * strides[4];
        strides[1] = dims[3] * strides[3];
        strides[0] = dims[1] * strides[1];
    } else
        return status::unimplemented;

    return status::success;
}

status_t rnn_utils::set_expected_desc(
        conf_t &rnn, memory_desc_t &weights_md, bool is_iter) {
    using namespace format_tag;
    CHECK(memory_desc_init_by_tag(weights_md, rnn.is_fwd ? ldigo : ldgoi));

    // Adjust strides for a good leading dimension in GEMM
    CHECK(set_good_strides(weights_md, rnn.is_fwd ? ldigo : ldgoi));

    // Signal that extra memory is needed for int8 compensation
    if (rnn.is_fwd && !one_of(rnn.dt_conf, all_f32, all_f16, all_bf16)) {
        weights_md.extra.flags = memory_extra_flags::rnn_u8s8_compensation;
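        // mask 27 = 0b11011: compensation values are kept per (l, d, g, o)
        // and accumulated over the input-channel dimension.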
        weights_md.extra.compensation_mask = 27; // ldigo 11011;
    }
    return status::success;
}

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl