1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 #pragma once
21
22 #include "../contrib/moderngpu/include/device/ctascan.cuh"
23 #include "../contrib/moderngpu/include/device/ctamerge.cuh"
24
25 #include "ctc_helper.h"
26
27 using namespace mgpu;
28
template<int NT, int VT, typename T, typename KeyT, typename Op>
struct CTASegReduce {
    // T and Op are not referenced by this helper; they stay in the template
    // parameter list so existing instantiations keep compiling.

    enum {NV = NT * VT};

    // The scan storage and the unique-key scratch area are never live at the
    // same time (they are separated by __syncthreads()), so they share storage.
    union Storage {
        typename CTAScan<NT>::Storage scanStorage;
        int indices[NV];
    };

    // Adapted from moderngpu's global kernel KernelReduceByKeyPreprocess.
    //
    // Finds the runs of equal adjacent keys in keys[0, count) and compacts one
    // key per run to the front of `keys`. Contains __syncthreads(), so every
    // thread of the block must call it.
    //
    // keys            in/out, NV entries. On input keys[0, count) must have
    //                 equal keys adjacent (e.g. sorted). On output
    //                 keys[0, *numUniqueLabels) holds the key of each run in
    //                 order; entries past that are left unchanged.
    // count           number of valid input keys (expected <= NV).
    // numUniqueLabels out: number of runs (the scan total).
    // seg_start/
    // seg_end         out, per thread: entry j describes run tid + j*blockDim.x
    //                 as the inclusive input range [seg_start[j], seg_end[j]].
    // scanout         out, at least NV ints: scanout[r] is the input position of
    //                 the last key of run r.
    __device__ static void preprocessKeys(KeyT *keys, int count,
                                          int *numUniqueLabels, int seg_start[VT],
                                          int seg_end[VT], int *scanout) {
        __shared__ Storage shared;

        const int tid = threadIdx.x;

        // Compare adjacent keys within each thread and mark discontinuities:
        // bit i of endFlags is set when position VT * tid + i ends a run.
        // Keys are held as KeyT so they are compared without conversion.
        int endFlags = 0;
        KeyT key = keys[VT * tid];
        #pragma unroll
        for (int i = 0; i < VT; ++i) {
            const int index = VT * tid + 1 + i;
            // keys[index] is only compared when index < count; reading it
            // unconditionally would touch keys[NV] for the last thread.
            const KeyT next = (index < count) ? keys[index] : key;
            if (index == count || (index < count && key != next)) {
                endFlags |= 1 << i;
            }
            key = next;
        }

        __syncthreads();

        // Count the number of encountered end flags. The result is used as this
        // thread's first output slot (an exclusive prefix count); the block-wide
        // total is written to *numUniqueLabels.
        int scan = CTAScan<NT>::Scan(tid, popc(endFlags), shared.scanStorage, numUniqueLabels);

        __syncthreads();

        // Stage the key of every run in shared scratch and record, in scanout,
        // the input position where each run ends.
        int outputPos = scan;
        #pragma unroll
        for (int i = 0; i < VT; ++i) {
            if ((endFlags >> i) & 1) {
                shared.indices[outputPos] = keys[VT * tid + i];
                scanout[outputPos] = VT * tid + i;
                outputPos++;
            }
        }

        __syncthreads();

        // Run r spans input positions [scanout[r - 1] + 1, scanout[r]]
        // (run 0 starts at 0). Runs are handed out to threads round-robin.
        for (int idx = tid, j = 0; idx < (*numUniqueLabels); idx += blockDim.x, ++j) {
            seg_start[j] = (idx == 0) ? 0 : (scanout[idx - 1] + 1);
            seg_end[j] = scanout[idx];
        }

        __syncthreads();

        // Copy the compacted keys back from scratch. Only the first
        // *numUniqueLabels scratch entries were written above.
        #pragma unroll
        for (int i = 0; i < VT; ++i) {
            const int idx = i * NT + tid;
            if (idx < *numUniqueLabels) {
                keys[idx] = shared.indices[idx];
            }
        }

        __syncthreads();
    }
};
98
// Computes forward probabilities. This fills in a T * S matrix.
// Row t=0 is initialized first; the recursion then runs from t=1 (2nd row) to
// t=T-1 (last row). Each row has S elements where S = 2L + 1.
//
// We only need to read in probabilities corresponding to the labels, so only a sparse
// set of values is read from the log probs matrix: the number of labels in an utterance
// is much smaller than the character set. This is much more true for Mandarin than English.
template<typename ProbT, int NT, int VT>
__global__
void compute_alpha_kernel (const ProbT* log_probs, const int *label_sizes,
                           const int *utt_length, const int *repeats_in_labels,
                           const int *labels_without_blanks, const int *label_offsets,
                           int *labels_with_blanks, ProbT *alphas,
                           ProbT* nll_forward, int stride, int out_dim,
                           int S_memoffset, int T_memoffset, int blank_label) {
    // One block handles one sequence of the minibatch (selected by blockIdx.x).
    // Shared label storage holds NV = NT * VT entries, so S is expected to be <= NV.
    ctc_helper::log_plus<ProbT> log_add;

    const int thread_id   = threadIdx.x;
    const int num_labels  = label_sizes[blockIdx.x];      // L
    const int num_frames  = utt_length[blockIdx.x];       // T
    const int num_states  = 2 * num_labels + 1;           // S
    const int prob_offset = out_dim * blockIdx.x;
    const int repeats     = repeats_in_labels[blockIdx.x];

    const int NV = NT * VT;
    __shared__ int label[NV];

    // Not enough frames for the labels plus the blanks forced between repeats.
    // NOTE(review): nll_forward is not written for this sequence in that case.
    if (num_labels + repeats > num_frames)
        return;

    // Build this sequence's blank-interleaved label row: blank, l0, blank, l1, ..., blank.
    int* extended = labels_with_blanks + blockIdx.x * S_memoffset;
    const int src_base = label_offsets[blockIdx.x];
    for (int k = thread_id; k < num_labels; k += blockDim.x) {
        extended[2 * k]     = blank_label;
        extended[2 * k + 1] = labels_without_blanks[src_base + k];
    }
    if (thread_id == 0)
        extended[2 * num_labels] = blank_label;

    __syncthreads();

    ProbT* alpha = &alphas[blockIdx.x * (S_memoffset * T_memoffset)];

    // Row t = 0 defaults to log(0); done here rather than by the host.
    const int row0_count = min(num_states, NV);
    #pragma unroll
    for (int s = thread_id; s < row0_count; s += blockDim.x)
        alpha[s] = ctc_helper::neg_inf<ProbT>();

    // Stage the extended labels in shared memory.
    #pragma unroll
    for (int s = thread_id; s < num_states; s += NT)
        label[s] = extended[s];

    __syncthreads();

    // When L + repeats == T, alpha[0] (the leading blank) is left at log(0) for t = 0.
    int first = (num_labels + repeats < num_frames) ? 0 : 1;
    int last  = (num_states > 1) ? 2 : 1;

    // Seed row t = 0 with the log-probabilities of the admissible start states.
    for (int s = first + thread_id; s < last; s += blockDim.x)
        alpha[s] = log_probs[prob_offset + label[s]];

    __syncthreads();

    // Remaining rows, one time step at a time.
    for (int t = 1; t < num_frames; ++t) {
        const int cur_row  = t * num_states;
        const int prev_row = cur_row - num_states;
        // Consecutive time steps of log_probs are out_dim * stride apart
        // (stride is the minibatch size).
        const int prob_col = t * (out_dim * stride);

        // Column 0 only has itself on the left.
        if (thread_id == 0) {
            ProbT carried = alpha[prev_row];
            if (first == 0)
                carried += log_probs[prob_offset + prob_col + blank_label];
            alpha[cur_row] = carried;
        }

        __syncthreads();

        // Every other column combines two or three neighbours of the previous row.
        #pragma unroll
        for (int s = thread_id + 1; s < num_states; s += blockDim.x) {
            ProbT acc = log_add(alpha[prev_row + s], alpha[prev_row + s - 1]);

            // A two-step jump is allowed onto a non-blank that differs from
            // the label two positions back.
            const bool may_skip = (s != 1) &&
                                  (label[s] != blank_label) &&
                                  (label[s] != label[s - 2]);
            if (may_skip)
                acc = log_add(acc, alpha[prev_row + s - 2]);

            alpha[cur_row + s] = acc + log_probs[prob_offset + prob_col + label[s]];
        }

        __syncthreads();
    }

    // Sum the rightmost one/two entries of the last row and store -log likelihood.
    if (thread_id == 0) {
        const bool tight = (num_labels + repeats) == num_frames;
        const int shift = (num_labels != 0)
                              ? (2 * (num_labels - 1) + 1 - (tight ? 1 : 0))
                              : 0;
        const int final_row = (num_frames - 1) * num_states;

        ProbT loglike = ctc_helper::neg_inf<ProbT>();
        for (int s = first + shift; s < last + shift; ++s)
            loglike = log_add(loglike, alpha[s + final_row]);

        nll_forward[blockIdx.x] = -loglike;
    }
}
233
// Computes backward probabilities. This also fills in a T * S matrix.
//
// See the comments above compute_alpha_kernel for more context.
// Backward pass for one sequence per block (blockIdx.x): computes the beta
// rows from t = T-1 down to t = 0, combines them with the precomputed alphas
// (row stride S) and, for every time step, writes the gradient for that
// step's out_dim entries of `grads`. Also writes -log likelihood from the
// backward direction into nll_backward.
//
// All shared arrays hold NV = NT * VT entries, so S = 2L + 1 is expected to
// be <= NV. log_probs and grads are indexed as
// prob_offset + t * (out_dim * stride) + class.
// NOTE(review): offsets are computed in 32-bit int.
template<typename ProbT, int NT, int VT>
__global__
void compute_betas_and_grad_kernel (const ProbT* log_probs, const int *label_sizes,
                                    const int *utt_length, const int *repeats_in_labels,
                                    const int *labels_with_blanks, ProbT *alphas,
                                    const ProbT* nll_forward, ProbT *nll_backward,
                                    ProbT *grads, int stride, int out_dim,
                                    int S_memoffset, int T_memoffset, int blank_label) {

    ctc_helper::log_plus<ProbT> log_plus_f;
    typedef CTASegReduce<NT, VT, ProbT, int, ctc_helper::log_plus<ProbT>> SegReduce;

    const int tid = threadIdx.x;
    const int L = label_sizes[blockIdx.x];
    const int T = utt_length[blockIdx.x];
    const int S = 2*L + 1;
    const int prob_offset = out_dim * blockIdx.x;
    const int repeats = repeats_in_labels[blockIdx.x];
    const ProbT log_partition = -nll_forward[blockIdx.x];

    const int* labels = labels_with_blanks;
    const int* label_global = &labels[blockIdx.x * S_memoffset];
    ProbT* alpha = &alphas[blockIdx.x * (S_memoffset * T_memoffset)];

    const int NV = NT * VT;

    // beta (current backward row) and result (scan output of the key
    // preprocessing) are never live at the same time, so they share storage.
    union TempStorage {
        ProbT beta[NV];
        int result[NV];
    };

    __shared__ TempStorage temp_buffer;

    // Extended labels for this sequence, padded with INT_MAX past S.
    __shared__ int label[NV];

    // Temporaries needed for segmented reduce
    // TODO: see if we can combine the shared memory requirements
    __shared__ int keys_shared[NV];     // sorted labels, then one entry per unique label
    __shared__ int gather_indices[NV];  // lattice position of each sorted label
    __shared__ ProbT output[NV];        // alpha + beta of the current row, then per-label sums

    ProbT beta_val[VT];

    // Not enough frames for the labels plus forced blanks.
    // NOTE(review): grads and nll_backward are not written for this sequence.
    if ((L + repeats) > T)
        return;

    // No time steps (only reachable with L == 0): the time loop below would not
    // run, and the last-row setup would read alpha[(T - 1) * S] == alpha[-S],
    // before this sequence's buffer. Nothing is written either way.
    if (T == 0)
        return;

    int start = S > 1 ? (S - 2) : 0;
    int end = (L + repeats < T) ? S : S-1;

    // Setup shared memory buffers; INT_MAX padding sorts after every real label.
    #pragma unroll
    for (int idx = tid; idx < NV; idx += NT) {
        label[idx] = (idx < S) ? label_global[idx] : INT_MAX;
    }

    __syncthreads();

    int uniquelabels;
    int seg_start[VT];
    int seg_end[VT];

    // Sort labels and record indices from which to gather from
    {
        int key[VT];
        int gather_val[VT];

        #pragma unroll
        for (int i = 0; i < VT; ++i) {
            const int idx = tid * VT + i;
            gather_val[i] = idx;
            key[i] = label[idx];
        }

        __syncthreads();

        CTAMergesort<NT, VT, true, true, int, int, mgpu::less<int>>
            (key, gather_val, keys_shared, gather_indices, S, tid, mgpu::less<int>());

        __syncthreads();

        for (int i = 0; i < VT; ++i) {
            const int idx = tid * VT + i;
            gather_indices[idx] = gather_val[i];
        }

        __syncthreads();

        // Collapse equal sorted labels into segments [seg_start, seg_end].
        SegReduce::preprocessKeys(keys_shared, S, &uniquelabels, seg_start, seg_end,
                                  temp_buffer.result);
        __syncthreads();
    }

    // TODO: probably not necessary
    __syncthreads();

    // Reset the beta row to log(0); temp_buffer.result is no longer needed.
    #pragma unroll
    for (int idx = tid; idx < NV; idx += NT) {
        temp_buffer.beta[idx] = ctc_helper::neg_inf<ProbT>();
    }
    __syncthreads();

    // Initialize the two rightmost values in the last row (assuming L non-zero)
    for(int i = tid; i < (end-start); i += blockDim.x)
        temp_buffer.beta[i + start] =
            log_probs[prob_offset + (T - 1) * (out_dim * stride) + label[i + start]];

    __syncthreads();

    // output = alpha + beta for the last row (t = T - 1).
    #pragma unroll
    for (int idx = tid; idx < S; idx += NT) {
        output[idx] = alpha[idx + (T - 1) * S] + temp_buffer.beta[idx];
    }

    __syncthreads();

    // Walk backward in time from the last row. The beta recursion runs for
    // t < T - 1; the gradient is produced for every row, including the last.
    for(int t = T - 1; t >= 0; --t) {

        // Start offset of the current row
        const int start_cur_row = t * S;

        // Starting offset of column that we read from the log probs array
        const int start_prob_col = t * (out_dim * stride);

        if (t < T-1) {

            // Filling up one row at a time but going back in time from the last row
            // to the first. As in the forward pass, there is no loop dependence and we
            // do a variable length filter of maximum filter size of 3
            #pragma unroll
            for(int idx = tid, i = 0; idx < (S-1); idx += NT, i++) {
                ProbT next_sum = log_plus_f(temp_buffer.beta[idx], temp_buffer.beta[idx+1]);

                // Skip two if not on blank and not on repeat.
                if ((label[idx] != blank_label) &&
                    (idx != (S-2)) && (label[idx] != label[idx+2]))
                    next_sum = log_plus_f(next_sum, temp_buffer.beta[idx+2]);

                beta_val[i] = next_sum + log_probs[prob_offset + start_prob_col + label[idx]];
            }

            __syncthreads();

            // The rightmost column has nothing to its right; it only advances
            // when it is an admissible end state (end == S).
            if ((tid == 0) && (end == S))
                temp_buffer.beta[(S-1)] = temp_buffer.beta[(S-1)] +
                    log_probs[prob_offset + start_prob_col + blank_label];

            // Update input buffer for next iteration
            #pragma unroll
            for(int idx = tid, i = 0; idx < (S-1); idx += NT, i++) {
                temp_buffer.beta[idx] = beta_val[i];
            }

            __syncthreads();

            // Beta computation done - combine with alpha for the segmented reduce
            #pragma unroll
            for(int idx = tid; idx < S; idx += NT) {
                output[idx] = alpha[idx + start_cur_row] + temp_buffer.beta[idx];
            }

            __syncthreads();

        }

        __syncthreads();

        // Compute segmented reduction of output by using label as key
        {
            // Log-sum of alpha + beta over all lattice positions sharing a label
            ProbT accum[VT];

            for (int idx = tid, j = 0; idx < uniquelabels; idx += blockDim.x, ++j) {

                accum[j] = ctc_helper::neg_inf<ProbT>();
                for (int i = seg_start[j]; i <= seg_end[j]; ++i) {
                    accum[j] = log_plus_f(accum[j], output[gather_indices[i]]);
                }
            }
            __syncthreads();

            // Write accumulated value into output since that is not used
            for (int idx = tid, j = 0; idx < uniquelabels; idx += blockDim.x, ++j) {
                output[idx] = accum[j];
            }
            __syncthreads();

            // Default gradient for every class at this time step: exp(log_prob).
            for (int idx = tid; idx < out_dim; idx += blockDim.x) {
                const int grads_offset = prob_offset + start_prob_col + idx;
                grads[grads_offset] = exp(log_probs[grads_offset]);
            }

            __syncthreads();

            // Classes that occur in the label sequence get
            // exp(log_prob) - exp(sum - log_prob - log_partition), unless one of
            // the skip conditions below holds (then the default is kept).
            for (int idx = tid; idx < uniquelabels; idx += blockDim.x) {
                const int grads_offset = prob_offset + start_prob_col + keys_shared[idx];

                ProbT grad = output[idx];

                if ((grad == ProbT(0)) || (log_probs[grads_offset] == ctc_helper::neg_inf<ProbT>()) ||
                    (grad == ctc_helper::neg_inf<ProbT>())) {
                } else {
                    grads[grads_offset] =
                        exp(log_probs[grads_offset]) - exp(grad - log_probs[grads_offset] - log_partition);
                }
            }

            __syncthreads();
        }

        // Output backward log likelihood
        if ((t == 0) && (tid == 0)) {
            ProbT loglike = ctc_helper::neg_inf<ProbT>();

            const int val = 2 * (L-1) + 1 - (((L + repeats) == T) ? 1 : 0);

            start = (-val * (L != 0) + start);
            end = (-val * (L != 0) + end);

            // Sum and return the leftmost one/two value(s) in first row
            for(int i = start; i < end; ++i)
                loglike = log_plus_f(loglike, temp_buffer.beta[i]);

            nll_backward[blockIdx.x] = -loglike;
        }

        // End-of-step barrier; the original author noted it as important.
        __syncthreads();
    }
}
472
template <typename ProbT, int VT = 1, typename Op>
__global__ void compute_log_probs_kernel(Op f, ProbT* log_probs,
                                         const ProbT* const denom,
                                         int alphabet_size,
                                         int count) {
    // Subtracts log(denom[c]) from every element of log_probs, where
    // c = element index / alphabet_size. Each thread visits up to VT elements
    // spaced blockDim.x * gridDim.x apart; elements at or beyond `count` are
    // skipped. The functor `f` is not used by this kernel.
    const int base      = blockDim.x * blockIdx.x + threadIdx.x;
    const int grid_span = blockDim.x * gridDim.x;

#pragma unroll
    for (int k = 0; k < VT; ++k) {
        const int elem = base + k * grid_span;
        if (elem >= count)
            continue;
        log_probs[elem] -= log(denom[elem / alphabet_size]);
    }
}
490
template <typename ProbT, int VT = 1, typename Op>
__global__ void prepare_stable_LSM_kernel(Op f, ProbT* log_probs,
                                          const ProbT* const col_max,
                                          int alphabet_size,
                                          int count) {
    // Replaces each element x of log_probs with f(x - col_max[c]), where
    // c = element index / alphabet_size. Each thread visits up to VT elements
    // spaced blockDim.x * gridDim.x apart; elements at or beyond `count` are
    // skipped.
    const int base      = blockDim.x * blockIdx.x + threadIdx.x;
    const int grid_span = blockDim.x * gridDim.x;

#pragma unroll
    for (int k = 0; k < VT; ++k) {
        const int elem = base + k * grid_span;
        if (elem >= count)
            continue;
        const ProbT shifted = log_probs[elem] - col_max[elem / alphabet_size];
        log_probs[elem] = f(shifted);
    }
}
508