/*******************************************************************************
* Copyright 2020-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/ocl/rnn/rnn_utils.hpp"

#include "common/c_types_map.hpp"
#include "gpu/ocl/rnn/ref_rnn.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {
#define AOC array_offset_calculator

using namespace dnnl::impl::utils;
using namespace prop_kind;
using namespace data_type;

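// Weights are accepted only in plain (non-inner-blocked) 5D layouts. In the
// ldigo tag the dims are (layers, directions, input channels, gates, output
// channels) with the output-channel dim dense; the check below verifies the
// corresponding stride pattern.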
bool rnn_utils::is_ldigo(const memory_desc_wrapper &md) {
    if (md.format_kind() != format_kind::blocked) return false;

    auto blk = md.blocking_desc();
    auto str = blk.strides;
    auto dims = md.dims();
    return md.ndims() == 5 && blk.inner_nblks == 0 && str[4] == 1
            && str[3] == dims[4] && str[1] == str[2] * dims[2]
            && str[0] == str[1] * dims[1];
}

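// ldgoi keeps the same logical dims but is laid out with input channels
// innermost (physically l, d, g, o, i); it is the transposed form used for
// the backward pass (see set_expected_desc).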
bool rnn_utils::is_ldgoi(const memory_desc_wrapper &md) {
    if (md.format_kind() != format_kind::blocked) return false;

    auto blk = md.blocking_desc();
    auto str = blk.strides;
    auto dims = md.dims();
    return md.ndims() == 5 && blk.inner_nblks == 0 && str[2] == 1
            && str[3] == dims[4] * str[4] && str[1] == str[3] * dims[3]
            && str[0] == str[1] * dims[1];
}

void rnn_utils::init_rnn_conf(conf_t &rnn, const rnn_desc_t &rd,
        const memory_desc_wrapper &src_layer_d,
        const memory_desc_wrapper &src_iter_d,
        const memory_desc_wrapper &weights_layer_d,
        const memory_desc_wrapper &weights_iter_d,
        const memory_desc_wrapper &dst_layer_d) {

    rnn = utils::zero<decltype(rnn)>();
    rnn.is_fwd = utils::one_of(rd.prop_kind, prop_kind::forward_training,
            prop_kind::forward_inference);
    rnn.is_training = utils::one_of(
            rd.prop_kind, prop_kind::forward_training, prop_kind::backward);
    rnn.is_lbr = rd.cell_kind == dnnl_lbr_gru;
    rnn.is_vanilla_gru = rd.cell_kind == dnnl_vanilla_gru;

    switch (rd.direction) {
        case dnnl_unidirectional_left2right: rnn.exec_dir = l2r; break;
        case dnnl_unidirectional_right2left: rnn.exec_dir = r2l; break;
        case dnnl_bidirectional_concat: rnn.exec_dir = bi_concat; break;
        case dnnl_bidirectional_sum: rnn.exec_dir = bi_sum; break;
        default: break;
    }

    if (everyone_is(f32, src_layer_d.data_type(), dst_layer_d.data_type(),
                weights_layer_d.data_type()))
        rnn.dt_conf = all_f32;
    else if (everyone_is(bf16, src_layer_d.data_type(), dst_layer_d.data_type(),
                     weights_layer_d.data_type()))
        rnn.dt_conf = all_bf16;
    else if (everyone_is(f16, src_layer_d.data_type(), dst_layer_d.data_type(),
                     weights_layer_d.data_type()))
        rnn.dt_conf = all_f16;
    else if (dst_layer_d.data_type() == u8) {
        if (IMPLICATION(src_iter_d.md_, src_iter_d.data_type() == u8))
            rnn.dt_conf = u8u8u8u8;
        else
            rnn.dt_conf = f32u8f32u8;
    } else {
        if (IMPLICATION(src_iter_d.md_, src_iter_d.data_type() == u8))
            rnn.dt_conf = u8u8u8f32;
        else
            rnn.dt_conf = f32u8f32f32;
    }
    rnn.is_int8 = !one_of(rnn.dt_conf, all_f32, all_f16, all_bf16);

    rnn.aux_data_type
            = rnn.dt_conf == all_f16 ? data_type::f16 : data_type::f32;
    rnn.diff_data_type = data_type::f32;

    rnn.n_layer = weights_layer_d.dims()[0];
    rnn.n_iter = src_layer_d.dims()[0];
    rnn.n_dir = weights_layer_d.dims()[1];
    rnn.n_gates = weights_layer_d.dims()[3];
    rnn.n_states = rd.cell_kind == dnnl_vanilla_lstm ? 2 : 1;
    rnn.n_bias = rnn.n_gates + rnn.is_lbr;
    rnn.mb = src_layer_d.dims()[1];
    rnn.sic = weights_iter_d.dims()[2];
    rnn.slc = weights_layer_d.dims()[2];
    rnn.dhc = weights_layer_d.dims()[4];
    rnn.dlc = dst_layer_d.dims()[2];

    rnn.gates_ld = rnn.dhc * rnn.n_gates;
    rnn.gates_nld = rnn.mb;
    rnn.states_nld = rnn.mb;

    // Set the correct number of weights parts
    rnn.n_parts_weights_layer = 1;
    rnn.parts_weights_layer[0] = rnn.n_gates;
    rnn.parts_weights_layer[1] = 0;

    // There are two parts for vanilla GRU weights iteration
    rnn.n_parts_weights_iter = rnn.is_vanilla_gru ? 2 : 1;
    rnn.parts_weights_iter[0] = rnn.is_vanilla_gru ? 2 : rnn.n_gates;
    rnn.parts_weights_iter[1] = rnn.is_vanilla_gru ? 1 : 0;

    rnn.n_parts_bias = 1;
    rnn.parts_bias[0] = rnn.n_bias;
    rnn.parts_bias[1] = 0;

    bool is_gru = utils::one_of(
            rd.cell_kind, alg_kind::vanilla_gru, alg_kind::lbr_gru);

    // Decide whether to merge GEMMs across iterations or layers
    auto src_layer_ld = src_layer_d.blocking_desc().strides[1];
    auto dst_layer_ld = dst_layer_d.blocking_desc().strides[1];
    auto src_layer_is_trivial_stride
            = src_layer_d.blocking_desc().strides[0] == (src_layer_ld * rnn.mb);
    auto dst_layer_is_trivial_stride
            = dst_layer_d.blocking_desc().strides[0] == (dst_layer_ld * rnn.mb);

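    // A "trivial" stride means consecutive iterations of the layer
    // activations are contiguous in memory, so the per-iteration GEMMs can be
    // merged into one larger GEMM. The mb < 128 restriction on the forward
    // path looks like a heuristic to merge only when individual GEMMs would
    // otherwise be small.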
    rnn.merge_gemm_layer = ((rnn.is_fwd && src_layer_is_trivial_stride)
            || ((rd.prop_kind == prop_kind::backward)
                    && dst_layer_is_trivial_stride))
            && (((rnn.is_fwd && rnn.mb < 128) || !rnn.is_fwd) || rnn.is_int8);
    rnn.merge_gemm_iter
            = dst_layer_is_trivial_stride && !(rnn.is_fwd || is_gru);

    // Decide to copy bias
    rnn.copy_bias = rnn.is_int8;

    rnn.use_workspace = rnn.is_training;

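    // The switch below picks the external data types used by the kernels
    // (layer input, final dst, and cell output) for each dt_conf value.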
    switch (rnn.dt_conf) {
        case all_f32:
        case f32u8f32f32:
            rnn.input_data_type = f32;
            rnn.dst_data_type = f32;
            rnn.output_data_type = f32;
            break;
        case all_f16:
            rnn.input_data_type = f16;
            rnn.dst_data_type = f16;
            rnn.output_data_type = f16;
            break;
        case u8u8u8u8:
            rnn.input_data_type = u8;
            rnn.dst_data_type = u8;
            rnn.output_data_type = u8;
            break;
        case u8u8u8f32:
            rnn.input_data_type = u8;
            rnn.dst_data_type = f32;
            rnn.output_data_type = u8;
            break;
        case f32u8f32u8:
            rnn.input_data_type = f32;
            rnn.dst_data_type = u8;
            rnn.output_data_type = f32;
            break;
        case all_bf16:
            rnn.input_data_type = bf16;
            rnn.dst_data_type = bf16;
            rnn.output_data_type = bf16;
            break;
        default: assert(!"unimplemented");
    }
}

void rnn_utils::init_test_mode(conf_t &rnn, const primitive_attr_t &attr) {
    rnn.is_testmode = attr.rnn_tparams_.test_mode_;
    rnn.tm_ngates = attr.rnn_tparams_.ngates_;
    rnn.tm_cscale = attr.rnn_tparams_.cscale_;
}

void rnn_utils::set_rnn_conf(conf_t &rnn, const rnn_desc_t &rd,
        const memory_desc_wrapper &weights_layer_d,
        const memory_desc_wrapper &weights_iter_d,
        const memory_desc_wrapper &diff_weights_layer_d,
        const memory_desc_wrapper &diff_weights_iter_d) {

    // Set leading dimensions for input weights arrays depending on input format
    auto set_dims = [&](const memory_desc_wrapper &md, int &ld, int &nld) {
        ld = 0;
        nld = 0;
        if (md.is_blocking_desc()) {
            if (is_ldigo(md)) {
                ld = (int)md.blocking_desc().strides[2];
                nld = md.dims()[2];
            } else if (is_ldgoi(md)) {
                ld = (int)md.blocking_desc().strides[4];
                nld = md.dims()[3] * md.dims()[4];
            } else
                assert(!"unsupported weights format");
        }
    };
    set_dims(weights_layer_d, rnn.weights_layer_ld, rnn.weights_layer_nld);
    set_dims(weights_iter_d, rnn.weights_iter_ld, rnn.weights_iter_nld);
    if (!rnn.is_fwd) {
        set_dims(diff_weights_layer_d, rnn.diff_weights_layer_ld,
                rnn.diff_weights_layer_nld);
        set_dims(diff_weights_iter_d, rnn.diff_weights_iter_ld,
                rnn.diff_weights_iter_nld);
    }

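    // Element sizes follow the data type configuration: states are stored in
    // half precision for the f16/bf16 configs and in int8/int32 for the int8
    // ones, while auxiliary buffers stay in f32 unless everything is f16.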
    int sizeof_states_dt
            = rnn.dt_conf == all_f32 ? sizeof(cl_float) : sizeof(cl_half);
    int aux_elsz = rnn.aux_data_type == data_type::f16 ? sizeof(cl_half)
                                                       : sizeof(float);
    rnn.ws_states_elsz = rnn.dt_conf == all_f32
            ? sizeof(cl_float)
            : rnn.dt_conf == all_f16 || rnn.dt_conf == all_bf16
                    ? sizeof(cl_half)
                    : rnn.dt_conf == u8u8u8u8 ? sizeof(int8_t)
                                              : sizeof(int32_t);

    // Different size required for forward and backward pass
    rnn.scratch_gates_elsz = (!rnn.is_fwd && rnn.dt_conf == all_bf16)
            ? sizeof(cl_half)
            : aux_elsz;

    // Set workspace sizes to store:
    // states to compute a pass
    // diff states to compute bwd pass (training only)
    // intermediate results from the gates
    rnn.states_ws_ld = get_good_ld(
            nstl::max(rnn.slc, nstl::max(rnn.sic, rnn.dhc)), sizeof_states_dt);
    rnn.gates_ws_ld = get_good_ld(rnn.gates_ld,
            rnn.dt_conf == all_f16 ? sizeof(cl_half) : sizeof(cl_float));
    rnn.diff_states_ws_ld = get_good_ld(
            nstl::max(rnn.slc, nstl::max(rnn.sic, rnn.dhc)), sizeof(cl_float));
    rnn.scratch_gates_ld = get_good_ld(rnn.gates_ld, rnn.scratch_gates_elsz);

    bool is_lstm = rd.cell_kind == dnnl_vanilla_lstm;

    rnn.ws_states_size = (size_t)(rnn.n_layer + 1) * rnn.n_dir
            * (rnn.n_iter + 1) * rnn.mb * rnn.states_ws_ld * rnn.ws_states_elsz;
    // we do not need a good ld for iter_c as it is not involved in GEMM
    // for now reverting it back to what it was originally
    // TODO: separate diff_c_offsets from diff-states & separate h- and c- offsets
    rnn.ws_c_states_size = is_lstm ? (size_t)(rnn.n_layer + 1) * rnn.n_dir
                    * (rnn.n_iter + 1) * rnn.mb * rnn.states_ws_ld * aux_elsz
                                   : (size_t)0;
    rnn.ws_diff_states_size = rnn.is_training ? (size_t)(rnn.n_layer + 1)
                    * rnn.n_dir * (rnn.n_states + 1) * (rnn.n_iter + 1) * rnn.mb
                    * rnn.diff_states_ws_ld * aux_elsz
                                              : (size_t)0;
    rnn.ws_gates_size = (size_t)rnn.n_layer * rnn.n_dir * rnn.n_iter * rnn.mb
            * rnn.gates_ws_ld * aux_elsz;
    rnn.n_iter_scratch_gates
            = (rnn.merge_gemm_layer || rnn.merge_gemm_iter) ? rnn.n_iter : 1;
    rnn.scratch_gates_size = (size_t)rnn.n_iter_scratch_gates * rnn.gates_nld
            * rnn.scratch_gates_ld * rnn.scratch_gates_elsz;
    rnn.ws_dhG1_size
            = (rd.cell_kind == alg_kind::vanilla_gru && rnn.is_training)
            ? (size_t)rnn.gates_nld * rnn.diff_states_ws_ld * sizeof(float)
            : 0;
    rnn.ws_bias_size
            = (size_t)rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dhc * aux_elsz;

    // For intermediate step in post-gemm fwd lbr gru
    rnn.scratch_cell_size = rnn.is_lbr
            ? (size_t)rnn.gates_nld * rnn.scratch_gates_ld
                    * rnn.scratch_gates_elsz
            : (rd.cell_kind == alg_kind::vanilla_gru && rnn.is_training
                            ? (size_t)rnn.states_nld * rnn.states_ws_ld
                                    * rnn.ws_states_elsz
                            : 0);

    // Used for storing the intermediate value from fwd pass in training lbr gru
    rnn.ws_per_cell = (size_t)rnn.is_lbr * rnn.mb * rnn.dhc * aux_elsz;
    rnn.ws_grid_comp_size = (size_t)rnn.is_lbr * rnn.is_training * rnn.n_layer
            * rnn.n_dir * rnn.n_iter * rnn.ws_per_cell * aux_elsz;
}

int rnn_utils::get_good_ld(int dim, int sizeof_dt) {
    // We want the leading dimensions of matrices to be 64-byte aligned,
    // and not divisible by 256 to avoid 4K aliasing effects
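    // Illustrative example (not from the original source): for f32 data
    // (sizeof_dt = 4) and dim = 512, rnd_up(512, 16) = 512, which is
    // divisible by 256, so 16 extra elements are added and the returned
    // leading dimension is 528 (2112 bytes, still 64-byte aligned).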
    int ld = rnd_up(dim, 64 / sizeof_dt);
    return (ld % 256 == 0) ? ld + 64 / sizeof_dt : ld;
}

void rnn_utils::set_offsets(const conf_t &rnn, size_t &ws_gates_offset,
        size_t &ws_states_offset, size_t &ws_c_states_offset,
        size_t &ws_diff_states_offset, size_t &ws_grid_comp_offset,
        size_t &scratch_cell_offset, size_t &ws_dhG1_offset,
        size_t &ws_bias_offset, size_t &scratch_gates_offset,
        size_t &scratchpad_size, size_t &workspace_size) {

    const size_t page_size = 4096;
    size_t current_offset;

    // Mandatory workspaces: go to the workspace if use_workspace, to the
    // scratchpad otherwise. Assumes the workspace base pointer is page
    // aligned.

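    // The buffers are laid out back to back, each starting at a page-aligned
    // offset, in the order: gates, states, c_states, grid computation,
    // diff_states, dhG1; the scratch and bias buffers follow in the
    // scratchpad section below.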
    current_offset = 0;
    ws_gates_offset = current_offset;
    current_offset += rnn.ws_gates_size;

    current_offset = utils::rnd_up(current_offset, page_size);
    ws_states_offset = current_offset;
    current_offset += rnn.ws_states_size;

    current_offset = utils::rnd_up(current_offset, page_size);
    ws_c_states_offset = current_offset;
    current_offset += rnn.ws_c_states_size;

    current_offset = utils::rnd_up(current_offset, page_size);
    ws_grid_comp_offset = current_offset;
    current_offset += rnn.ws_grid_comp_size;

    current_offset = utils::rnd_up(current_offset, page_size);
    ws_diff_states_offset = current_offset;
    current_offset += rnn.ws_diff_states_size;

    current_offset = utils::rnd_up(current_offset, page_size);
    ws_dhG1_offset = current_offset;
    current_offset += rnn.ws_dhG1_size;

    workspace_size = rnn.use_workspace ? current_offset : 0;

    // Optional scratchpads
    // Assumes the scratchpad base pointer is page aligned.
    // If use_workspace, only the following goes to the scratchpad;
    // otherwise, everything goes to the scratchpad and we keep incrementing
    // the offset.
    current_offset = rnn.use_workspace ? 0 : current_offset;

    current_offset = utils::rnd_up(current_offset, page_size);
    scratch_gates_offset = current_offset;
    current_offset += rnn.scratch_gates_size;

    current_offset = utils::rnd_up(current_offset, page_size);
    scratch_cell_offset = current_offset;
    current_offset += rnn.scratch_cell_size;

    ws_bias_offset = 0;
    if (rnn.copy_bias) {
        current_offset = utils::rnd_up(current_offset, page_size);
        ws_bias_offset = current_offset;
        current_offset += rnn.ws_bias_size;
    }
    scratchpad_size = current_offset;
}

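// Computes both sizes by running set_offsets with throwaway offset outputs;
// only scratchpad_size and workspace_size are kept by the caller.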
void rnn_utils::get_scratchpad_and_workspace_sizes(
        const conf_t &rnn, size_t &scratchpad_size, size_t &workspace_size) {
    size_t ws_gates_offset, ws_states_offset, ws_c_states_offset,
            ws_diff_states_offset, ws_grid_comp_offset, scratch_cell_offset,
            ws_dhG1_offset, ws_bias_offset, scratch_gates_offset;
    set_offsets(rnn, ws_gates_offset, ws_states_offset, ws_c_states_offset,
            ws_diff_states_offset, ws_grid_comp_offset, scratch_cell_offset,
            ws_dhG1_offset, ws_bias_offset, scratch_gates_offset,
            scratchpad_size, workspace_size);
}

void rnn_utils::set_offsets_fwd_gemm(const conf_t &rnn, int dir, int lay,
        data_type_t src_t, size_t *wei_layer_off_ptr,
        const size_t &ws_states_offset_, size_t &grid_ws_lay_offset,
        size_t &grid_wei_lay_offset, size_t &grid_ws_iter_offset) {
    // Function overloaded. This version is called by the grid execution.
    int n_layer = rnn.n_layer;
    int n_dir = rnn.n_dir;

    AOC<size_t, 3> off_weights_lay(
            wei_layer_off_ptr, n_layer, n_dir, rnn.n_parts_weights_layer);

    grid_wei_lay_offset = off_weights_lay(lay, dir, 0);
    grid_ws_lay_offset = (cl_ulong)(ws_states_offset_
            + OFF4(lay, n_layer + 1, dir, n_dir, 1, rnn.n_iter + 1, 0,
                    rnn.mb * rnn.states_ws_ld)
                    * types::data_type_size(src_t));
    grid_ws_iter_offset = (cl_ulong)(ws_states_offset_
            + OFF4(lay + 1, rnn.n_layer + 1, dir, rnn.n_dir, 0, rnn.n_iter + 1,
                    0, rnn.mb * rnn.states_ws_ld)
                    * types::data_type_size(src_t));
    UNUSED(n_layer);
}

void rnn_utils::set_offsets_fwd_gemm(const conf_t &rnn, int iter, int dir,
        int lay, data_type_t src_t, size_t *wei_iter_off_ptr,
        const size_t &ws_states_offset_, size_t &cell_ws_iter_offset,
        size_t &cell_ws_lay_offset, size_t &cell_scratch_offset,
        size_t &cell_wei_iter_offset) {
    int n_layers = rnn.n_layer;
    int batch = rnn.mb;
    int n_iter = rnn.n_iter;
    int n_dir = rnn.n_dir;

    if (wei_iter_off_ptr) {
        AOC<size_t, 3> off_weights_iter(wei_iter_off_ptr, rnn.n_layer,
                rnn.n_dir, rnn.n_parts_weights_iter);
        cell_wei_iter_offset = off_weights_iter(lay, dir, 0);
    }

    cell_scratch_offset = (rnn.merge_gemm_iter || rnn.merge_gemm_layer)
            ? (cl_ulong)(
                    OFF2(iter, n_iter, 0, rnn.gates_nld * rnn.scratch_gates_ld)
                    * rnn.scratch_gates_elsz)
            : (size_t)0;
    cell_ws_iter_offset = (cl_ulong)(ws_states_offset_
            + OFF4(lay + 1, n_layers + 1, dir, n_dir, iter, n_iter + 1, 0,
                    batch * rnn.states_ws_ld)
                    * types::data_type_size(src_t));
    cell_ws_lay_offset = (cl_ulong)(ws_states_offset_
            + OFF4(lay, n_layers + 1, dir, n_dir, iter + 1, n_iter + 1, 0,
                    batch * rnn.states_ws_ld)
                    * types::data_type_size(src_t));
    UNUSED(n_layers);
}

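// For vanilla GRU the iteration weights are split in two parts (see
// init_rnn_conf); this helper adjusts the offsets for the second GEMM, which
// uses part 1 of the weights and skips the first two gates in the scratch
// buffer.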
void rnn_utils::set_gru_offsets_part2(const conf_t &rnn, int iter, int dir,
        int lay, data_type_t src_t, size_t *wei_iter_off_ptr,
        const size_t &ws_states_offset_, size_t &cell_wei_iter_offset,
        size_t &cell_scratch_offset, size_t &cell_ws_iter_offset) {

    AOC<size_t, 3> off_weights_iter(
            wei_iter_off_ptr, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_iter);
    cell_wei_iter_offset = off_weights_iter(lay, dir, 1);
    cell_scratch_offset += 2 * rnn.dhc * rnn.scratch_gates_elsz;
    cell_ws_iter_offset = (cl_ulong)(ws_states_offset_
            + OFF4(lay + 1, rnn.n_layer + 1, dir, rnn.n_dir, iter + 1,
                    rnn.n_iter + 1, 0, rnn.mb * rnn.states_ws_ld)
                    * types::data_type_size(src_t));
}

void rnn_utils::set_offsets_bwd_gemm(const conf_t &rnn, int iter, int dir,
        int lay, const size_t &ws_diff_states_off_,
        size_t &cell_diff_wei_iter_off, size_t &cell_diff_wei_lay_off,
        size_t &cell_diff_ws_lay_off) {
    // Function overloaded. This version is called by the grid execution and
    // forwards to the overload that is otherwise called in cell execution.
    cl_ulong dummy_var;
    set_offsets_bwd_gemm(rnn, iter, dir, lay, ws_diff_states_off_,
            cell_diff_wei_iter_off, cell_diff_wei_lay_off, cell_diff_ws_lay_off,
            dummy_var);
}

void rnn_utils::set_offsets_bwd_gemm(const conf_t &rnn, int iter, int dir,
        int lay, const size_t &ws_diff_states_off_,
        size_t &cell_diff_wei_iter_off, size_t &cell_diff_wei_lay_off,
        size_t &cell_diff_ws_lay_off, size_t &cell_diff_ws_iter_off,
        size_t &cell_diff_wei_iter_off2) {

    set_offsets_bwd_gemm(rnn, iter, dir, lay, ws_diff_states_off_,
            cell_diff_wei_iter_off, cell_diff_wei_lay_off, cell_diff_ws_lay_off,
            cell_diff_ws_iter_off);
    cell_diff_wei_iter_off2
            = cell_diff_wei_iter_off + 2 * rnn.dhc * sizeof(float);
}

void rnn_utils::set_offsets_bwd_gemm(const conf_t &rnn, int iter, int dir,
        int lay, const size_t &ws_diff_states_off_,
        size_t &cell_diff_wei_iter_off, size_t &cell_diff_wei_lay_off,
        size_t &cell_diff_ws_lay_off, size_t &cell_diff_ws_iter_off) {
    int n_layers = rnn.n_layer;
    int batch = rnn.mb;
    int n_iter = rnn.n_iter;
    int n_dir = rnn.n_dir;
    int n_states = rnn.n_states;

    cell_diff_ws_iter_off = ws_diff_states_off_
            + OFF5(lay, n_layers + 1, dir, n_dir, 0, n_states + 1, iter,
                      n_iter + 1, 0, rnn.states_nld * rnn.diff_states_ws_ld)
                    * sizeof(float);
    cell_diff_ws_lay_off = ws_diff_states_off_
            + OFF5(lay, n_layers + 1, dir, n_dir, n_states, n_states + 1, iter,
                      n_iter + 1, 0, rnn.states_nld * rnn.diff_states_ws_ld)
                    * sizeof(float);
    cell_diff_wei_lay_off
            = OFF3(lay, n_layers, dir, n_dir, 0,
                      rnn.diff_weights_layer_nld * rnn.diff_weights_layer_ld)
            * sizeof(float);
    cell_diff_wei_iter_off
            = OFF3(lay, n_layers, dir, n_dir, 0,
                      rnn.diff_weights_iter_nld * rnn.diff_weights_iter_ld)
            * sizeof(float);
    UNUSED(n_layers);
    UNUSED(batch);
}

status_t rnn_utils::set_good_strides(
        memory_desc_t &weights_md, format_tag_t tag) {
    auto &strides = weights_md.format_desc.blocking.strides;
    auto dims = weights_md.dims;
    using namespace format_tag;

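    // Only the stride that feeds the GEMM leading dimension is padded via
    // get_good_ld; the outer strides are then recomputed from it so the
    // descriptor stays consistent. Illustrative example: a row length of 512
    // f32 elements would be padded to a stride of 528 (see get_good_ld).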
    if (tag == ldigo) {
        strides[2] = rnn_utils::get_good_ld((int)strides[2],
                (int)types::data_type_size(weights_md.data_type));
        strides[1] = dims[2] * strides[2];
        strides[0] = dims[1] * strides[1];
    } else if (tag == ldgoi) {
        strides[4] = rnn_utils::get_good_ld((int)strides[4],
                (int)types::data_type_size(weights_md.data_type));
        strides[3] = dims[4] * strides[4];
        strides[1] = dims[3] * strides[3];
        strides[0] = dims[1] * strides[1];
    } else
        return status::unimplemented;

    return status::success;
}

status_t rnn_utils::set_expected_desc(
        conf_t &rnn, memory_desc_t &weights_md, bool is_iter) {
    using namespace format_tag;
    CHECK(memory_desc_init_by_tag(weights_md, rnn.is_fwd ? ldigo : ldgoi));

    // Adjust strides for good leading dimension in GEMM
    CHECK(set_good_strides(weights_md, rnn.is_fwd ? ldigo : ldgoi));

    // Signal that extra memory is needed for the int8 compensation data
    if (rnn.is_fwd && !one_of(rnn.dt_conf, all_f32, all_f16, all_bf16)) {
        weights_md.extra.flags = memory_extra_flags::rnn_u8s8_compensation;
        weights_md.extra.compensation_mask = 27; // ldigo 11011
    }
    return status::success;
}

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl