/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#pragma once

#include <tuple>
#include <cmath>
#include <limits>
#include <algorithm>
#include <numeric>

#include <dmlc/omp.h>

#include "ctc_helper.h"

namespace mxnet_warpctc {

template<typename ProbT>
class CpuCTC {
public:
    CpuCTC(int alphabet_size, int minibatch, void* workspace,
           int blank_label) :
            alphabet_size_(alphabet_size), minibatch_(minibatch),
            workspace_(workspace), blank_label_(blank_label) {
    }

    // Noncopyable
    CpuCTC(const CpuCTC&) = delete;
    CpuCTC& operator=(const CpuCTC&) = delete;

    ctcStatus_t cost_and_grad(const ProbT* const activations,
                              ProbT* grads,
                              ProbT* costs,
                              const int* const flat_labels,
                              const int* const label_lengths,
                              const int* const input_lengths);

    ctcStatus_t score_forward(const ProbT* const activations,
                              ProbT* costs,
                              const int* const flat_labels,
                              const int* const label_lengths,
                              const int* const input_lengths);

private:

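    // Per-example scratch carved out of the shared workspace buffer:
    // alphas (S x T forward trellis), betas (a single column of S backward
    // values), the label sequence interleaved with blanks, the s_inc/e_inc
    // tables that drive the pruned trellis window, and an 'output' accumulator
    // of size alphabet_size used per time step. 'repeats' counts adjacent
    // duplicate labels in the transcription.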
    class CpuCTC_metadata {

    private:
        int setup_labels(const int* const labels, int blank_label, int L, int S);

    public:
        CpuCTC_metadata(int L, int S, int T, int mb, int alphabet_size,
                        void* workspace, size_t bytes_used, int blank_label,
                        const int* const labels);

        ProbT* alphas;
        ProbT* betas;
        int* labels_w_blanks;
        int* e_inc;
        int* s_inc;
        ProbT* output;
        int repeats;
    };

    int alphabet_size_; // Number of characters plus blank
    int minibatch_;
    void* workspace_;
    int blank_label_;

    void log_softmax(const ProbT* const activations, ProbT* log_probs,
                     const int* const input_lengths);

    std::tuple<ProbT, bool>
            cost_and_grad_kernel(ProbT* grad, const ProbT* const log_probs,
                                 const int* const labels, int T, int L,
                                 int mb, size_t bytes_used);

    ProbT compute_alphas(const ProbT* log_probs, int repeats, int S, int T,
                         const int* const e_inc,
                         const int* const s_inc,
                         const int* const labels,
                         ProbT* alphas);

    ProbT compute_betas_and_grad(ProbT* grad, const ProbT* const log_probs,
                                 ProbT log_partition, int repeats,
                                 int S, int T, const int* const e_inc,
                                 const int* const s_inc,
                                 const int* const labels,
                                 ProbT* alphas,
                                 ProbT* betas,
                                 ProbT* output);
};

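// Example: a minimal sketch of how a caller might drive CpuCTC. It assumes the
// workspace has already been sized and allocated by the surrounding code (the
// MXNet CTC operator does this; upstream warp-ctc exposes get_workspace_size()
// for the same purpose), so the buffer name and sizes below are illustrative
// only and not part of this header's API.
//
//   std::vector<char> workspace(workspace_bytes);            // sized elsewhere
//   CpuCTC<float> ctc(alphabet_size, minibatch, workspace.data(), 0 /*blank*/);
//   std::vector<float> grads(alphabet_size * minibatch * max_T);  // zero-initialized
//   std::vector<float> costs(minibatch);
//   ctcStatus_t status = ctc.cost_and_grad(activations, grads.data(), costs.data(),
//                                          flat_labels, label_lengths, input_lengths);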
template<typename ProbT>
CpuCTC<ProbT>::CpuCTC_metadata::CpuCTC_metadata(int L, int S, int T, int mb,
                                                int alphabet_size,
                                                void* workspace, size_t bytes_used,
                                                int blank_label,
                                                const int* const labels) {

    alphas = reinterpret_cast<ProbT *>(static_cast<char *>(workspace) + bytes_used);
    bytes_used += sizeof(ProbT) * S * T;
    std::fill(alphas, alphas + S * T, ctc_helper::neg_inf<ProbT>());

    betas = reinterpret_cast<ProbT *>(static_cast<char *>(workspace) + bytes_used);
    bytes_used += sizeof(ProbT) * S;
    std::fill(betas, betas + S, ctc_helper::neg_inf<ProbT>());

    labels_w_blanks = reinterpret_cast<int *>(static_cast<char *>(workspace) + bytes_used);
    bytes_used += sizeof(int) * S;

    e_inc = reinterpret_cast<int *>(static_cast<char *>(workspace) + bytes_used);
    bytes_used += sizeof(int) * S;

    s_inc = reinterpret_cast<int *>(static_cast<char *>(workspace) + bytes_used);
    bytes_used += sizeof(int) * S;

    output = reinterpret_cast<ProbT *>(static_cast<char *>(workspace) + bytes_used);
    bytes_used += sizeof(ProbT) * alphabet_size;

    repeats = setup_labels(labels, blank_label, L, S);
}

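// Worked example for the routine below: for labels {1, 1, 2} with blank_label 0,
// L = 3 and S = 7, setup_labels produces
//   labels_w_blanks = {0, 1, 0, 1, 0, 2, 0}
//   s_inc = {1, 1, 1, 2},  e_inc = {1, 1, 2, 1},  repeats = 1
// The increment tables record how far the active window over labels_w_blanks
// may advance per step: an entry of 2 where consecutive labels differ (the
// blank between them can be skipped), and two entries of 1 where a label
// repeats, since the separating blank is then mandatory.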
template<typename ProbT>
int CpuCTC<ProbT>::CpuCTC_metadata::setup_labels(const int* const labels,
                                                 int blank_label, int L, int S) {
    int e_counter = 0;
    int s_counter = 0;

    s_inc[s_counter++] = 1;

    int repeats = 0;

    for (int i = 1; i < L; ++i) {
        if (labels[i-1] == labels[i]) {
            s_inc[s_counter++] = 1;
            s_inc[s_counter++] = 1;
            e_inc[e_counter++] = 1;
            e_inc[e_counter++] = 1;
            ++repeats;
        } else {
            s_inc[s_counter++] = 2;
            e_inc[e_counter++] = 2;
        }
    }
    e_inc[e_counter++] = 1;

    for (int i = 0; i < L; ++i) {
        labels_w_blanks[2 * i] = blank_label;
        labels_w_blanks[2 * i + 1] = labels[i];
    }
    labels_w_blanks[S - 1] = blank_label;

    return repeats;
}

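// Activations are laid out time-major across the whole minibatch: the column
// for (time step c, example mb) starts at (mb + minibatch_ * c) * alphabet_size_.
// log_softmax writes log-probabilities into log_probs with the same layout,
// using the usual max-subtraction for numerical stability.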
template<typename ProbT>
void
CpuCTC<ProbT>::log_softmax(const ProbT* const activations, ProbT* log_probs,
                           const int* const input_lengths) {
#pragma omp parallel for
    for (int mb = 0; mb < minibatch_; ++mb) {
        for (int c = 0; c < input_lengths[mb]; ++c) {
            int col_offset = (mb + minibatch_ * c) * alphabet_size_;
            ProbT max_activation = -std::numeric_limits<ProbT>::infinity();
            for (int r = 0; r < alphabet_size_; ++r)
                max_activation = std::max(max_activation, activations[r + col_offset]);

            ProbT denom = ProbT(0.);
            for (int r = 0; r < alphabet_size_; ++r) {
                denom += std::exp(activations[r + col_offset] - max_activation);
            }

            for (int r = 0; r < alphabet_size_; ++r) {
                log_probs[r + col_offset] = activations[r + col_offset]
                                            - max_activation - std::log(denom);
            }
        }
    }
}

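// Per-example kernel: builds the blank-augmented label sequence (S = 2L + 1),
// runs the forward pass to obtain the log-likelihood, then the backward pass
// to fill in the gradient. If the sequence cannot be aligned (L + repeats > T)
// it returns a cost of 0 (see the TODO below), and it flags cases where the
// forward and backward log-likelihoods disagree by more than
// ctc_helper::threshold.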
template<typename ProbT>
std::tuple<ProbT, bool>
CpuCTC<ProbT>::cost_and_grad_kernel(ProbT* grad, const ProbT* const log_probs,
                                    const int* const labels,
                                    int T, int L, int mb, size_t bytes_used) {

    const int S = 2*L + 1; // Number of labels with blanks

    CpuCTC_metadata ctcm(L, S, T, mb, alphabet_size_, workspace_, bytes_used, blank_label_, labels);

    bool over_threshold = false;

    if (L + ctcm.repeats > T) {
        return std::make_tuple(ProbT(0), over_threshold); // TODO, not right to return 0
    }

    ProbT llForward = compute_alphas(log_probs, ctcm.repeats, S, T, ctcm.e_inc,
                                     ctcm.s_inc, ctcm.labels_w_blanks,
                                     ctcm.alphas);

    ProbT llBackward = compute_betas_and_grad(grad, log_probs, llForward, ctcm.repeats,
                                              S, T, ctcm.e_inc, ctcm.s_inc,
                                              ctcm.labels_w_blanks,
                                              ctcm.alphas,
                                              ctcm.betas,
                                              ctcm.output);

    ProbT diff = std::abs(llForward - llBackward);
    if (diff > ctc_helper::threshold) {
        over_threshold = true;
    }

    return std::make_tuple(-llForward, over_threshold);
}

// Computes forward probabilities.
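// Only the reachable band of the (S, T) trellis is visited: 'start' and 'end'
// delimit the states that are both reachable from the beginning of the sequence
// and can still complete the remaining labels in the time steps left; they are
// advanced each step using the precomputed s_inc/e_inc tables. The standard CTC
// forward recurrence is applied in log space via ctc_helper::log_plus.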
template<typename ProbT>
ProbT CpuCTC<ProbT>::compute_alphas(const ProbT* log_probs, int repeats, int S, int T,
                                    const int* const e_inc,
                                    const int* const s_inc,
                                    const int* const labels,
                                    ProbT* alphas) {

    int start = (((S / 2) + repeats - T) < 0) ? 0 : 1,
        end = S > 1 ? 2 : 1;

    for (int i = start; i < end; ++i) {
        alphas[i] = log_probs[labels[i]];
    }

    for (int t = 1; t < T; ++t) {
        int remain = (S / 2) + repeats - (T - t);
        if (remain >= 0)
            start += s_inc[remain];
        if (t <= (S / 2) + repeats)
            end += e_inc[t - 1];
        int startloop = start;
        int idx1 = t * S, idx2 = (t - 1) * S, idx3 = t * (alphabet_size_ * minibatch_);

        if (start == 0) {
            alphas[idx1] = alphas[idx2] + log_probs[blank_label_ + idx3];
            startloop += 1;
        }

        for (int i = startloop; i < end; ++i) {
            ProbT prev_sum = ctc_helper::log_plus<ProbT>()(alphas[i + idx2], alphas[(i-1) + idx2]);

            // Skip two if not on blank and not on repeat.
            if (labels[i] != blank_label_ && i != 1 && labels[i] != labels[i-2])
                prev_sum = ctc_helper::log_plus<ProbT>()(prev_sum, alphas[(i-2) + idx2]);

            alphas[i + idx1] = prev_sum + log_probs[labels[i] + idx3];
        }
    }

    ProbT loglike = ctc_helper::neg_inf<ProbT>();
    for (int i = start; i < end; ++i) {
        loglike = ctc_helper::log_plus<ProbT>()(loglike, alphas[i + (T - 1) * S]);
    }

    return loglike;
}


// Starting from t = T - 1, we sweep backward over the alpha array, computing one
// column of betas as we go. At each position we update the product alpha * beta
// (in log space) and accumulate it into the gradient associated with that label.
// NOTE: computes the gradient w.r.t. the UNNORMALIZED final-layer activations.
// The grads passed in are assumed to be already zeroed!
template<typename ProbT>
ProbT CpuCTC<ProbT>::compute_betas_and_grad(ProbT* grad, const ProbT* const log_probs,
                                            ProbT log_partition, int repeats,
                                            int S, int T, const int* const e_inc,
                                            const int* const s_inc,
                                            const int* const labels,
                                            ProbT* alphas,
                                            ProbT* betas,
                                            ProbT* output) {
    int start = S > 1 ? (S - 2) : 0,
        end = (T > (S / 2) + repeats) ? S : S - 1;

    std::fill(output, output + alphabet_size_, ctc_helper::neg_inf<ProbT>());

    // Set the starting values in the beta column at the very right edge.
    for (int i = start; i < end; ++i) {
        betas[i] = log_probs[labels[i] + (T - 1) * (alphabet_size_ * minibatch_)];

        // Compute alpha * beta in log space at this position in (S, T) space.
        alphas[i + (T - 1) * S] += betas[i];

        // Update the gradient associated with this label,
        // essentially performing a reduce-by-key in a sequential manner.
        output[labels[i]] =
                ctc_helper::log_plus<ProbT>()(alphas[i + (T - 1) * S], output[labels[i]]);
    }

    // Update the gradient w.r.t. each unique label.
    for (int i = 0; i < alphabet_size_; ++i) {
        int idx3 = (T - 1) * alphabet_size_ * minibatch_ + i;

        if (output[i] == 0.0 || output[i] == ctc_helper::neg_inf<ProbT>() ||
            log_probs[idx3] == ctc_helper::neg_inf<ProbT>()) {
            grad[idx3] = std::exp(log_probs[idx3]);
        } else {
            grad[idx3] = std::exp(log_probs[idx3])
                         - std::exp(output[i] - log_probs[idx3] - log_partition);
        }
    }

    // Loop from the second-to-last column all the way to the left.
    for (int t = T - 2; t >= 0; --t) {
        int remain = (S / 2) + repeats - (T - t);
        if (remain >= -1)
            start -= s_inc[remain + 1];
        if (t < (S / 2) + repeats)
            end -= e_inc[t];

        int endloop = end == S ? end - 1 : end;
        int idx1 = t * S, idx3 = t * (alphabet_size_ * minibatch_);

        std::fill(output, output + alphabet_size_, ctc_helper::neg_inf<ProbT>());

        for (int i = start; i < endloop; ++i) {
            ProbT next_sum = ctc_helper::log_plus<ProbT>()(betas[i], betas[(i+1)]);
            // Skip two if not on blank and not on repeat.
            if (labels[i] != blank_label_ && i != (S-2) && labels[i] != labels[i+2]) {
                next_sum = ctc_helper::log_plus<ProbT>()(next_sum, betas[(i+2)]);
            }
            betas[i] = next_sum + log_probs[labels[i] + idx3];

            // Compute alpha * beta in log space.
            alphas[i + idx1] += betas[i];

            // Update the gradient associated with this label.
            output[labels[i]] =
                    ctc_helper::log_plus<ProbT>()(alphas[i + idx1], output[labels[i]]);
        }

        if (end == S) {
            betas[(S-1)] = betas[(S-1)] + log_probs[blank_label_ + idx3];
            alphas[(S-1) + idx1] += betas[(S-1)];

            output[labels[S-1]] =
                    ctc_helper::log_plus<ProbT>()(alphas[S-1 + idx1], output[labels[S-1]]);
        }

        // Go over the unique labels and compute the final grad
        // w.r.t. each one at this time step.
        for (int i = 0; i < alphabet_size_; ++i) {

            if (output[i] == 0.0 || output[i] == ctc_helper::neg_inf<ProbT>() ||
                log_probs[idx3] == ctc_helper::neg_inf<ProbT>()) {
                grad[idx3] = std::exp(log_probs[idx3]);
            } else {
                grad[idx3] = std::exp(log_probs[idx3])
                             - std::exp(output[i] - log_probs[idx3] - log_partition);
            }
            ++idx3;
        }
    }

    ProbT loglike = ctc_helper::neg_inf<ProbT>();
    for (int i = start; i < end; ++i) {
        loglike = ctc_helper::log_plus<ProbT>()(loglike, betas[i]);
    }

    return loglike;
}

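// Workspace layout used below: the front of the shared buffer holds the
// log-softmax output for the whole minibatch (minibatch_ * alphabet_size_ * maxT
// values), followed by one fixed-size block of per-example scratch (alphas,
// betas, labels_w_blanks, e_inc, s_inc, output) per minibatch entry, so the
// OpenMP loop can process examples independently.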
template<typename ProbT>
ctcStatus_t
CpuCTC<ProbT>::cost_and_grad(const ProbT* const activations,
                             ProbT* grads,
                             ProbT* costs,
                             const int* const flat_labels,
                             const int* const label_lengths,
                             const int* const input_lengths) {
    if (activations == nullptr ||
        grads == nullptr ||
        costs == nullptr ||
        flat_labels == nullptr ||
        label_lengths == nullptr ||
        input_lengths == nullptr)
        return CTC_STATUS_INVALID_VALUE;

    ProbT* log_probs = static_cast<ProbT *>(workspace_);

    int maxT = *std::max_element(input_lengths, input_lengths + minibatch_);

    size_t bytes_used = sizeof(ProbT) * minibatch_ * alphabet_size_ * maxT;

    // Per-minibatch memory.
    size_t per_minibatch_bytes = 0;

    int maxL = *std::max_element(label_lengths, label_lengths + minibatch_);
    int maxS = 2 * maxL + 1;

    // output
    per_minibatch_bytes += sizeof(float) * alphabet_size_;

    // alphas
    per_minibatch_bytes += sizeof(float) * maxS * maxT;

    // betas
    per_minibatch_bytes += sizeof(float) * maxS;

    // labels w/blanks, e_inc, s_inc
    per_minibatch_bytes += 3 * sizeof(int) * maxS;

    log_softmax(activations, log_probs, input_lengths);

#pragma omp parallel for
    for (int mb = 0; mb < minibatch_; ++mb) {
        const int T = input_lengths[mb]; // Length of utterance (time)
        const int L = label_lengths[mb]; // Number of labels in transcription

        bool mb_status;

        std::tie(costs[mb], mb_status) =
                cost_and_grad_kernel(grads + mb * alphabet_size_,
                                     log_probs + mb * alphabet_size_,
                                     flat_labels + std::accumulate(label_lengths, label_lengths + mb, 0),
                                     T, L, mb,
                                     bytes_used + mb * per_minibatch_bytes);
    }

    return CTC_STATUS_SUCCESS;
}

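// Forward-only scoring: the same setup as cost_and_grad, but only the alpha
// (forward) pass is run per example, so costs are produced without gradients.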
template<typename ProbT>
ctcStatus_t CpuCTC<ProbT>::score_forward(const ProbT* const activations,
                                         ProbT* costs,
                                         const int* const flat_labels,
                                         const int* const label_lengths,
                                         const int* const input_lengths) {
    if (activations == nullptr ||
        costs == nullptr ||
        flat_labels == nullptr ||
        label_lengths == nullptr ||
        input_lengths == nullptr)
        return CTC_STATUS_INVALID_VALUE;

    ProbT* log_probs = static_cast<ProbT *>(workspace_);

    int maxT = *std::max_element(input_lengths, input_lengths + minibatch_);

    size_t bytes_used = sizeof(ProbT) * minibatch_ * alphabet_size_ * maxT;

    // Per-minibatch memory.
    size_t per_minibatch_bytes = 0;

    int maxL = *std::max_element(label_lengths, label_lengths + minibatch_);
    int maxS = 2 * maxL + 1;

    // output
    per_minibatch_bytes += sizeof(float) * alphabet_size_;

    // alphas
    per_minibatch_bytes += sizeof(float) * maxS * maxT;

    // betas
    per_minibatch_bytes += sizeof(float) * maxS;

    // labels w/blanks, e_inc, s_inc
    per_minibatch_bytes += 3 * sizeof(int) * maxS;

    log_softmax(activations, log_probs, input_lengths);

#pragma omp parallel for
    for (int mb = 0; mb < minibatch_; ++mb) {
        const int T = input_lengths[mb]; // Length of utterance (time)
        const int L = label_lengths[mb]; // Number of labels in transcription
        const int S = 2*L + 1;           // Number of labels with blanks

        CpuCTC_metadata ctcm(L, S, T, mb, alphabet_size_, workspace_,
                             bytes_used + mb * per_minibatch_bytes, blank_label_,
                             flat_labels + std::accumulate(label_lengths, label_lengths + mb, 0));

        if (L + ctcm.repeats > T) {
            costs[mb] = ProbT(0);
        } else {
            costs[mb] = -compute_alphas(log_probs + mb * alphabet_size_, ctcm.repeats, S, T,
                                        ctcm.e_inc, ctcm.s_inc, ctcm.labels_w_blanks,
                                        ctcm.alphas);
        }
    }

    return CTC_STATUS_SUCCESS;
}

}  // namespace mxnet_warpctc