1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #pragma once
21 
22 #include "../contrib/moderngpu/include/device/ctascan.cuh"
23 #include "../contrib/moderngpu/include/device/ctamerge.cuh"
24 
25 #include "ctc_helper.h"
26 
27 using namespace mgpu;
28 
// Block-wide helper that collapses runs of equal adjacent keys into segments.
//
// Template parameters:
//   NT   - stride used for the final copy back into `keys`; presumably the
//          number of threads in the block (the CTAScan is instantiated with it)
//   VT   - number of consecutive keys examined by each thread
//   T    - value type (not used by preprocessKeys)
//   KeyT - type of the keys being compared
//   Op   - reduction operator type (not used by preprocessKeys)
template<int NT, int VT, typename T, typename KeyT, typename Op>
struct CTASegReduce {

    enum {NV = NT * VT};

    // Scratch space: the scan storage is only needed during the scan, the
    // index buffer only afterwards, so the two overlap.
    union Storage {
        typename CTAScan<NT>::Storage scanStorage;
        int indices[NV];
    };

    // Adapted from the global kernel KernelReduceByKeyPreprocess.
    //
    // Must be called by every thread of the block (it contains block barriers).
    //
    //   keys            - in/out; at least NV entries. Only the first `count`
    //                     are compared. On return the leading *numUniqueLabels
    //                     entries hold one key per segment.
    //   count           - number of valid keys
    //   numUniqueLabels - out; receives the total produced by CTAScan::Scan over
    //                     the per-thread segment-end counts, i.e. the number of
    //                     segments
    //   seg_start/end   - out, per thread; slot j describes segment
    //                     (threadIdx.x + j * blockDim.x) as an inclusive range of
    //                     positions in the original key order
    //   scanout         - scratch visible to the whole block, at least NV ints;
    //                     receives the last position of every segment
    __device__ static void preprocessKeys(KeyT *keys, int count,
                                          int *numUniqueLabels, int seg_start[VT],
                                          int seg_end[VT], int *scanout) {
        __shared__ Storage shared;

        const int tid = threadIdx.x;

        // Compare adjacent keys within each thread and mark discontinuities.
        // Keys are held in their own type (KeyT) so that the comparison is exact,
        // and the right-hand neighbour is only loaded while it is a valid key:
        // this keeps the last thread from reading keys[NV].
        int endFlags = 0;
        KeyT key = keys[VT * tid];
        #pragma unroll
        for (int i = 0; i < VT; ++i) {
            const int index = VT * tid + 1 + i;
            if (index < count) {
                const KeyT next = keys[index];
                if (key != next) {
                    endFlags |= 1 << i;
                }
                key = next;
            } else if (index == count) {
                // The last valid key always closes a segment.
                endFlags |= 1 << i;
            }
        }

        __syncthreads();

        // Count the number of encountered end flags
        int scan = CTAScan<NT>::Scan(tid, popc(endFlags), shared.scanStorage, numUniqueLabels);

        // The scan storage aliases shared.indices, so wait before reusing it.
        __syncthreads();

        // Output the unique keys, using indices as scratch space, and remember
        // the position at which each segment ends.
        int outputPos = scan;
        #pragma unroll
        for (int i = 0; i < VT; ++i) {

            if ( (endFlags >> i) & 1) {
                shared.indices[outputPos] = keys[VT * tid + i];
                scanout[outputPos] = VT * tid + i;
                outputPos++;
            }
        }

        __syncthreads();

        // Create start and end: a segment starts right after the previous one ends.
        for (int idx = tid, j = 0; idx < (*numUniqueLabels); idx += blockDim.x, ++j) {
            seg_start[j] = (idx == 0) ? 0 : (scanout[idx-1] + 1);
            seg_end[j] = scanout[idx];
        }

        __syncthreads();

        // Copy from the scratch space back into the keys. All NV slots are
        // copied; only the first *numUniqueLabels were written above.
        #pragma unroll
        for (int i = 0; i < VT; ++i) {
            keys[i * NT + tid] = shared.indices[i * NT + tid];
        }

        __syncthreads();
    }
};
98 
// Forward (alpha) pass of CTC. Each thread block works on the minibatch entry
// selected by blockIdx.x and fills a T x S table, where S = 2L + 1 and L / T are
// that entry's label count and utterance length.
//
// Row 0 is initialised directly; rows 1 .. T-1 are each derived from the row
// above. Only the log probabilities of symbols that occur in the label sequence
// are read, so access into log_probs is sparse.
//
// Layout as used below:
//   log_probs : element for (t, symbol) of this entry lives at
//               out_dim * blockIdx.x + t * (out_dim * stride) + symbol
//   alphas    : this entry's table starts at blockIdx.x * (S_memoffset * T_memoffset),
//               rows are S elements apart
//   labels_with_blanks : written here (blank, l0, blank, l1, ..., blank), starting
//               at blockIdx.x * S_memoffset
//   nll_forward[blockIdx.x] receives the negative log likelihood.
//
// Entries with L + repeats > T are skipped entirely (nothing is written).
// NOTE(review): loops stride by both NT and blockDim.x, and the shared label
// buffer holds NT * VT ints - presumably the launch uses blockDim.x == NT and
// S <= NT * VT; confirm against the host-side launcher.
template<typename ProbT, int NT, int VT>
__global__
void compute_alpha_kernel (const ProbT* log_probs, const int *label_sizes,
                           const int *utt_length, const int *repeats_in_labels,
                           const int *labels_without_blanks, const int *label_offsets,
                           int *labels_with_blanks, ProbT *alphas,
                           ProbT* nll_forward, int stride, int out_dim,
                           int S_memoffset, int T_memoffset, int blank_label) {

    ctc_helper::log_plus<ProbT> log_add;

    const unsigned int block = blockIdx.x;
    const unsigned int block_threads = blockDim.x;
    const int tid = threadIdx.x;

    const int L = label_sizes[block];
    const int T = utt_length[block];
    const int repeats = repeats_in_labels[block];
    const int S = 2*L + 1;
    const int prob_offset = out_dim * block;

    const int NV = NT * VT;
    __shared__ int label[NV];

    // Not enough time steps to emit this label sequence: leave the entry alone.
    // The test is uniform across the block, so no thread is left at a barrier.
    if ((L + repeats) > T)
        return;

    // Interleave blanks with the raw labels: blank, l0, blank, l1, ..., blank
    const unsigned int blanked_base = block * S_memoffset;
    {
        const int raw_base = label_offsets[block];
        for (int l = tid; l < L; l += block_threads) {
            const int offset = blanked_base + 2 * l;
            labels_with_blanks[offset] = blank_label;
            labels_with_blanks[offset+1] = labels_without_blanks[raw_base + l];
        }
        // Trailing blank
        if (tid == 0) {
            labels_with_blanks[blanked_base + 2 * L] = blank_label;
        }
    }
    __syncthreads();

    const int *label_global = &labels_with_blanks[blanked_base];
    ProbT *alpha = &alphas[block * (S_memoffset * T_memoffset)];

    // Clear the first row to neg_inf here rather than in a separate pass
    const int first_row_cells = min(S, NV);
    #pragma unroll
    for (int s = tid; s < first_row_cells; s += block_threads) {
        alpha[s] = ctc_helper::neg_inf<ProbT>();
    }

    // Stage this entry's blank-augmented labels in shared memory
    #pragma unroll
    for (int s = tid; s < S; s += NT) {
        label[s] = label_global[s];
    }

    __syncthreads();

    // [start, end) is the range of columns that may be occupied at t = 0
    int start = 1;
    if (L + repeats < T)
        start = 0;
    int end = 1;
    if (S > 1)
        end = 2;

    // Row t = 0
    for (int s = start + tid; s < end; s += block_threads)
        alpha[s] = log_probs[prob_offset + label[s]];

    __syncthreads();

    // Remaining rows, one time step per iteration
    for (int t = 1; t < T; ++t) {

        // Offsets of the row being written and of the row it is derived from
        const int cur_row = t * S;
        const int prev_row = (t - 1) * S;

        // Offset of time step t inside log_probs; consecutive time steps are
        // (out_dim * stride) elements apart
        const int prob_col = t * (out_dim * stride);

        // Column 0 has no left neighbour and is handled by a single thread
        if (tid == 0) {
            switch (start) {
                case 0:
                    alpha[cur_row] = alpha[prev_row] +
                                     log_probs[prob_offset + prob_col + blank_label];
                    break;
                case 1:
                    alpha[cur_row] = alpha[prev_row];
                    break;
                default:
                    break;
            }
        }

        __syncthreads();

        // Columns 1 .. S-1 depend only on the previous row, so they can be
        // processed independently. Each cell combines two or three neighbours
        // from the row above and then adds the log prob of its own symbol.
        #pragma unroll
        for (int s = tid + 1; s < S; s += block_threads) {

            ProbT acc = log_add(alpha[s + prev_row], alpha[(s-1) + prev_row]);

            // The two-back transition is only allowed from a non-blank symbol
            // that differs from the symbol two positions earlier. The s != 1
            // test must precede the label[s-2] read.
            const bool two_back = (label[s] != blank_label) &&
                                  (s != 1) && (label[s] != label[s-2]);
            if (two_back)
                acc = log_add(acc, alpha[(s-2) + prev_row]);

            alpha[s + cur_row] = acc + log_probs[prob_offset + prob_col + label[s]];
        }

        __syncthreads();
    }

    if (tid == 0) {
        // Combine the rightmost one or two cells of the last row
        ProbT loglike = ctc_helper::neg_inf<ProbT>();

        // Distance the [start, end) window travels from the first row to the last
        const int val = 2 * (L-1) + 1 - (((L + repeats) == T) ? 1 : 0);
        const int shift = val * (L!=0);

        const int last_row = (T - 1) * S;
        const int last = shift + end;
        for (int s = shift + start; s < last; ++s)
            loglike = log_add(loglike, alpha[s + last_row]);

        nll_forward[block] = -loglike;
    }
}
233 
// Backward (beta) pass of CTC fused with the gradient computation. Each thread
// block works on the minibatch entry selected by blockIdx.x.
//
// The beta table is never materialised: only the current row is kept, in shared
// memory. For every time step, walking from T-1 down to 0, the kernel
//   1. derives the beta row for t from the row for t+1,
//   2. forms alpha + beta per column,
//   3. log-sums those values over all columns that carry the same symbol
//      (the columns were grouped by symbol once, up front, with a block sort
//      followed by CTASegReduce::preprocessKeys),
//   4. writes grads for time step t.
// After t = 0 the negative log likelihood of the backward pass is stored in
// nll_backward[blockIdx.x].
//
// Indexing of log_probs / grads / alphas / labels_with_blanks matches
// compute_alpha_kernel; nll_forward must already hold that kernel's result.
// Entries with L + repeats > T are skipped entirely (nothing is written).
// NOTE(review): loops stride by both NT and blockDim.x and the shared buffers
// hold NT * VT elements - presumably blockDim.x == NT and S <= NT * VT; confirm
// against the host-side launcher.
template<typename ProbT, int NT, int VT>
__global__
void compute_betas_and_grad_kernel (const ProbT* log_probs, const int *label_sizes,
                                    const int *utt_length, const int *repeats_in_labels,
                                    const int *labels_with_blanks, ProbT *alphas,
                                    const ProbT* nll_forward, ProbT *nll_backward,
                                    ProbT *grads, int stride, int out_dim,
                                    int S_memoffset, int T_memoffset, int blank_label) {

    typedef ctc_helper::log_plus<ProbT> LogPlus;
    typedef CTASegReduce<NT, VT, ProbT, int, LogPlus> SegReduce;

    LogPlus log_add;

    const unsigned int block = blockIdx.x;
    const unsigned int block_threads = blockDim.x;
    const int tid = threadIdx.x;

    const int L = label_sizes[block];
    const int T = utt_length[block];
    const int repeats = repeats_in_labels[block];
    const int S = 2*L + 1;
    const int prob_offset = out_dim * block;
    const ProbT log_partition = -nll_forward[block];

    const int* label_global = &labels_with_blanks[block * S_memoffset];
    ProbT* alpha = &alphas[block * (S_memoffset * T_memoffset)];

    const int NV = NT * VT;

    // The preprocessing scratch (result) is dead before the beta row is first
    // written, so both share the same storage.
    union TempStorage {
        ProbT beta[NV];
        int result[NV];
    };

    __shared__ TempStorage temp_buffer;

    __shared__ int label[NV];

    // Temporaries needed for segmented reduce
    // TODO: see if we can combine the shared memory requirements
    __shared__ int keys_shared[NV];
    __shared__ int gather_indices[NV];
    __shared__ ProbT output[NV];

    // Per-thread staging for the freshly computed beta cells
    ProbT beta_val[VT];

    // Not enough time steps to emit this label sequence. The test is uniform
    // across the block, so no thread is left waiting at a barrier.
    if ((L + repeats) > T)
        return;

    // [start, end) is the range of columns that may be occupied at t = T-1
    int start = 0;
    if (S > 1)
        start = S - 2;
    int end = S - 1;
    if (L + repeats < T)
        end = S;

    // Stage the labels in shared memory; slots past S are padded with INT_MAX
    #pragma unroll
    for (int s = tid; s < NV; s += NT) {
        label[s] = (s < S) ? label_global[s] : INT_MAX;
    }

    __syncthreads();

    int uniquelabels;
    int seg_start[VT];
    int seg_end[VT];

    // Group columns by symbol: sort (label, column) pairs by label, keep the
    // permutation in gather_indices, then turn the sorted keys into segments.
    {
        int key[VT];
        int gather_val[VT];

        #pragma unroll
        for (int v = 0; v < VT; ++v) {
            const int col = tid * VT + v;
            gather_val[v] = col;
            key[v] = label[col];
        }

        __syncthreads();

        CTAMergesort<NT, VT, true, true, int, int, mgpu::less<int>>
            (key, gather_val, keys_shared, gather_indices, S, tid, mgpu::less<int>());

        __syncthreads();

        for (int v = 0; v < VT; ++v) {
            gather_indices[tid * VT + v] = gather_val[v];
        }

        __syncthreads();

        SegReduce::preprocessKeys(keys_shared, S, &uniquelabels, seg_start, seg_end,
                                  temp_buffer.result);
        __syncthreads();
    }

    // TODO: probably not necessary
    __syncthreads();

    // Reset the beta row (this overwrites the preprocessing scratch)
    #pragma unroll
    for (int s = tid; s < NV; s += NT) {
        temp_buffer.beta[s] = ctc_helper::neg_inf<ProbT>();
    }
    __syncthreads();

    // Row t = T-1: seed the rightmost cell(s) (two of them when L is non-zero
    // and there is slack in T)
    for (int s = start + tid; s < end; s += block_threads)
        temp_buffer.beta[s] =
            log_probs[prob_offset + (T - 1) * (out_dim * stride) + label[s]];

    __syncthreads();

    // alpha + beta for the last row
    #pragma unroll
    for (int s = tid; s < S; s += NT) {
        output[s] = alpha[s + (T - 1) * S] + temp_buffer.beta[s];
    }

    __syncthreads();

    // Walk backward in time. The first iteration (t = T-1) uses the row seeded
    // above; every later iteration first advances the beta row.
    for (int t = T - 1; t >= 0; --t) {

        // Offset of row t in alpha
        const int cur_row = t * S;

        // Offset of time step t inside log_probs / grads
        const int prob_col = t * (out_dim * stride);

        if (t < T-1) {

            // Each cell combines two or three neighbours from the row for t+1.
            // Results go to registers first because the row is updated in place.
            #pragma unroll
            for (int s = tid, slot = 0; s < (S-1); s += NT, slot++) {
                ProbT acc = log_add(temp_buffer.beta[s], temp_buffer.beta[s+1]);

                // The two-ahead transition is only allowed from a non-blank
                // symbol that differs from the symbol two positions later. The
                // s != S-2 test must precede the label[s+2] read.
                const bool two_ahead = (label[s] != blank_label) &&
                                       (s != (S-2)) && (label[s] != label[s+2]);
                if (two_ahead)
                    acc = log_add(acc, temp_buffer.beta[s+2]);

                beta_val[slot] = acc + log_probs[prob_offset + prob_col + label[s]];
            }

            __syncthreads();

            // The rightmost column has nothing to its right; it is advanced in
            // place by a single thread
            if ((tid == 0) && (end == S))
                temp_buffer.beta[(S-1)] += log_probs[prob_offset + prob_col + blank_label];

            // Publish the new row
            #pragma unroll
            for (int s = tid, slot = 0; s < (S-1); s += NT, slot++) {
               temp_buffer.beta[s] = beta_val[slot];
            }

            __syncthreads();

            // alpha + beta for row t, input to the segmented reduce below
            #pragma unroll
            for (int s = tid; s < S; s += NT) {
               output[s] = alpha[s + cur_row] + temp_buffer.beta[s];
            }

            __syncthreads();

        }

        __syncthreads();

        // Segmented reduction of output, keyed by symbol
        {
            ProbT accum[VT];

            // Segment (tid + j * blockDim.x) is reduced into accum[j]
            for (int seg = tid, j = 0; seg < uniquelabels; seg += block_threads, ++j) {

                accum[j] = ctc_helper::neg_inf<ProbT>();
                const int last = seg_end[j];
                for (int i = seg_start[j]; i <= last; ++i) {
                    accum[j] = log_add(accum[j], output[gather_indices[i]]);
                }
            }
            __syncthreads();

            // output has been fully consumed, so reuse it for the per-symbol sums
            for (int seg = tid, j = 0; seg < uniquelabels; seg += block_threads, ++j) {
                output[seg] = accum[j];
            }
            __syncthreads();

            // Default gradient for every symbol at this time step: exp(log_prob)
            for (int sym = tid; sym < out_dim; sym += block_threads) {
                const int cell = prob_offset + prob_col + sym;
                grads[cell] = exp(log_probs[cell]);
            }

            __syncthreads();

            // Symbols that occur in the label sequence get the full expression
            for (int seg = tid; seg < uniquelabels; seg += block_threads) {
                const int cell = prob_offset + prob_col + keys_shared[seg];

                const ProbT grad = output[seg];

                // Keep the default when the sum is exactly zero or when either
                // term is neg_inf
                const bool keep_default =
                    (grad == 0.0) || (log_probs[cell] == ctc_helper::neg_inf<ProbT>()) ||
                    (grad == ctc_helper::neg_inf<ProbT>());
                if (keep_default)
                    continue;

                grads[cell] =
                    exp(log_probs[cell]) - exp(grad - log_probs[cell] - log_partition);
            }

            __syncthreads();
        }

        // After the first row has been processed, emit the backward log likelihood
        if ((t == 0) && (tid == 0)) {
            ProbT loglike = ctc_helper::neg_inf<ProbT>();

            // Distance the [start, end) window travels from the last row to the first
            const int val = 2 * (L-1) + 1 - (((L + repeats) == T) ? 1 : 0);
            const int shift = -val * (L != 0);

            // Combine the leftmost one or two cells of the first row
            const int last = shift + end;
            for (int s = shift + start; s < last; ++s)
                loglike = log_add(loglike, temp_buffer.beta[s]);

            nll_backward[block] = -loglike;
        }

        // Original author's note: "For some reason this is important"
        __syncthreads();
    }
}
472 
// Subtracts log(denom[column]) from every element of log_probs, in place, where
// column = element index / alphabet_size.
//
// Each thread handles up to VT elements spaced (blockDim.x * gridDim.x) apart,
// so the launch has to cover ceil(count / VT) threads in total. Elements at or
// beyond `count` are left untouched. The functor `f` is accepted but not used.
template <typename ProbT, int VT = 1, typename Op>
__global__ void compute_log_probs_kernel(Op f, ProbT* log_probs,
                                         const ProbT* const denom,
                                         int alphabet_size,
                                         int count) {

    const int step = blockDim.x * gridDim.x;
    int pos = blockDim.x * blockIdx.x + threadIdx.x;
#pragma unroll
    for (int pass = 0; pass < VT; ++pass, pos += step) {
        if (pos >= count)
            continue;

        // All elements of one column share the same denominator
        const int column = pos / alphabet_size;
        log_probs[pos] -= log(denom[column]);
    }
}
490 
// Replaces every element of log_probs, in place, with
// f(element - col_max[column]), where column = element index / alphabet_size.
//
// Each thread handles up to VT elements spaced (blockDim.x * gridDim.x) apart,
// so the launch has to cover ceil(count / VT) threads in total. Elements at or
// beyond `count` are left untouched.
template <typename ProbT, int VT = 1, typename Op>
__global__ void prepare_stable_LSM_kernel(Op f, ProbT* log_probs,
                                          const ProbT* const col_max,
                                          int alphabet_size,
                                          int count) {

    const int step = blockDim.x * gridDim.x;
    int pos = blockDim.x * blockIdx.x + threadIdx.x;
#pragma unroll
    for (int pass = 0; pass < VT; ++pass, pos += step) {
        if (pos >= count)
            continue;

        // All elements of one column are shifted by the same maximum
        const int column = pos / alphabet_size;
        log_probs[pos] = f(log_probs[pos] - col_max[column]);
    }
}
508