1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 #pragma once
21
22 #include <tuple>
23 #include <cmath>
24 #include <limits>
25 #include <algorithm>
26 #include <numeric>
27
28 #include <dmlc/omp.h>
29
30 #include "ctc_helper.h"
31
32 namespace mxnet_warpctc {
33
34 template<typename ProbT>
35 class CpuCTC {
36 public:
37 // Noncopyable
CpuCTC(int alphabet_size,int minibatch,void * workspace,int blank_label)38 CpuCTC(int alphabet_size, int minibatch, void* workspace,
39 int blank_label) :
40 alphabet_size_(alphabet_size), minibatch_(minibatch),
41 workspace_(workspace), blank_label_(blank_label) {
42
43 };
44
45 CpuCTC(const CpuCTC&) = delete;
46 CpuCTC& operator=(const CpuCTC&) = delete;
47
48 ctcStatus_t cost_and_grad(const ProbT* const activations,
49 ProbT *grads,
50 ProbT* costs,
51 const int* const flat_labels,
52 const int* const label_lengths,
53 const int* const input_lengths);
54
55
56 ctcStatus_t score_forward(const ProbT* const activations,
57 ProbT* costs,
58 const int* const flat_labels,
59 const int* const label_lengths,
60 const int* const input_lengths);
61
62 private:
63
64 class CpuCTC_metadata {
65
66 private:
67 int setup_labels(const int* const labels, int blank_label, int L, int S);
68
69 public:
70 CpuCTC_metadata(int L, int S, int T, int mb, int alphabet_size,
71 void* workspace, size_t bytes_used, int blank_label,
72 const int* const labels);
73
74 ProbT* alphas;
75 ProbT* betas;
76 int* labels_w_blanks;
77 int* e_inc;
78 int* s_inc;
79 ProbT* output;
80 int repeats;
81 };
82
83 int alphabet_size_; // Number of characters plus blank
84 int minibatch_;
85 void* workspace_;
86 int blank_label_;
87
88 void log_softmax(const ProbT* const activations, ProbT* log_probs,
89 const int* const input_lengths);
90
91 std::tuple<ProbT, bool>
92 cost_and_grad_kernel(ProbT *grad, const ProbT* const log_probs,
93 const int* const labels, int T, int L,
94 int mb, size_t bytes_used);
95
96 ProbT compute_alphas(const ProbT* log_probs, int repeats, int S, int T,
97 const int* const e_inc,
98 const int* const s_inc,
99 const int* const labels,
100 ProbT* alphas);
101
102 ProbT compute_betas_and_grad(ProbT* grad, const ProbT* const log_probs,
103 ProbT log_partition, int repeats,
104 int S, int T, const int* const e_inc,
105 const int* const s_inc,
106 const int* const labels,
107 ProbT* alphas,
108 ProbT* betas,
109 ProbT* output);
110 };
111
112 template<typename ProbT>
CpuCTC_metadata(int L,int S,int T,int mb,int alphabet_size,void * workspace,size_t bytes_used,int blank_label,const int * const labels)113 CpuCTC<ProbT>::CpuCTC_metadata::CpuCTC_metadata(int L, int S, int T, int mb,
114 int alphabet_size,
115 void* workspace, size_t bytes_used,
116 int blank_label,
117 const int* const labels) {
118
119 alphas = reinterpret_cast<ProbT *>(static_cast<char *>(workspace) + bytes_used);
120 bytes_used += sizeof(ProbT) * S * T;
121 std::fill(alphas, alphas + S * T, ctc_helper::neg_inf<ProbT>());
122 betas = reinterpret_cast<ProbT *>(static_cast<char *>(workspace) + bytes_used);
123 bytes_used += sizeof(ProbT) * S;
124 std::fill(betas, betas + S, ctc_helper::neg_inf<ProbT>());
125 labels_w_blanks = reinterpret_cast<int *>(static_cast<char *>(workspace) + bytes_used);
126 bytes_used += sizeof(int) * S;
127 e_inc = reinterpret_cast<int *>(static_cast<char *>(workspace) + bytes_used);
128 bytes_used += sizeof(int) * S;
129 s_inc = reinterpret_cast<int *>(static_cast<char *>(workspace) + bytes_used);
130 bytes_used += sizeof(int) * S;
131 output = reinterpret_cast<ProbT *>(static_cast<char *>(workspace) + bytes_used);
132 bytes_used += sizeof(ProbT) * alphabet_size;
133
134 repeats = setup_labels(labels, blank_label, L, S);
135 }
136
137 template<typename ProbT>
setup_labels(const int * const labels,int blank_label,int L,int S)138 int CpuCTC<ProbT>::CpuCTC_metadata::setup_labels(const int* const labels,
139 int blank_label, int L, int S) {
140 int e_counter = 0;
141 int s_counter = 0;
142
143 s_inc[s_counter++] = 1;
144
145 int repeats = 0;
146
147 for (int i = 1; i < L; ++i) {
148 if (labels[i-1] == labels[i]) {
149 s_inc[s_counter++] = 1;
150 s_inc[s_counter++] = 1;
151 e_inc[e_counter++] = 1;
152 e_inc[e_counter++] = 1;
153 ++repeats;
154 }
155 else {
156 s_inc[s_counter++] = 2;
157 e_inc[e_counter++] = 2;
158 }
159 }
160 e_inc[e_counter++] = 1;
161
162 for (int i = 0; i < L; ++i) {
163 labels_w_blanks[2 * i] = blank_label;
164 labels_w_blanks[2 * i + 1] = labels[i];
165 }
166 labels_w_blanks[S - 1] = blank_label;
167
168 return repeats;
169 }
170
171 template<typename ProbT>
172 void
log_softmax(const ProbT * const activations,ProbT * log_probs,const int * const input_lengths)173 CpuCTC<ProbT>::log_softmax(const ProbT* const activations, ProbT* log_probs,
174 const int* const input_lengths) {
175 #pragma omp parallel for
176 for (int mb = 0; mb < minibatch_; ++mb) {
177 for(int c = 0; c < input_lengths[mb]; ++c) {
178 int col_offset = (mb + minibatch_ * c) * alphabet_size_;
179 ProbT max_activation = -std::numeric_limits<ProbT>::infinity();
180 for(int r = 0; r < alphabet_size_; ++r)
181 max_activation = std::max(max_activation, activations[r + col_offset]);
182
183 ProbT denom = ProbT(0.);
184 for(int r = 0; r < alphabet_size_; ++r) {
185 denom += std::exp(activations[r + col_offset] - max_activation);
186 }
187
188 for(int r = 0; r < alphabet_size_; ++r) {
189 log_probs[r + col_offset] = activations[r + col_offset]
190 - max_activation - std::log(denom);
191 }
192 }
193 }
194 }
195
196 template<typename ProbT>
197 std::tuple<ProbT, bool>
cost_and_grad_kernel(ProbT * grad,const ProbT * const log_probs,const int * const labels,int T,int L,int mb,size_t bytes_used)198 CpuCTC<ProbT>::cost_and_grad_kernel(ProbT *grad, const ProbT* const log_probs,
199 const int* const labels,
200 int T, int L, int mb, size_t bytes_used) {
201
202 const int S = 2*L + 1; // Number of labels with blanks
203
204 CpuCTC_metadata ctcm(L, S, T, mb, alphabet_size_, workspace_, bytes_used, blank_label_, labels);
205
206 bool over_threshold = false;
207
208 if (L + ctcm.repeats > T) {
209 return std::make_tuple(ProbT(0), over_threshold); // TODO, not right to return 0
210 }
211
212 ProbT llForward = compute_alphas(log_probs, ctcm.repeats, S, T, ctcm.e_inc,
213 ctcm.s_inc, ctcm.labels_w_blanks,
214 ctcm.alphas);
215
216 ProbT llBackward = compute_betas_and_grad(grad, log_probs, llForward, ctcm.repeats,
217 S, T, ctcm.e_inc, ctcm.s_inc,
218 ctcm.labels_w_blanks,
219 ctcm.alphas,
220 ctcm.betas,
221 ctcm.output);
222
223 ProbT diff = std::abs(llForward - llBackward);
224 if (diff > ctc_helper::threshold) {
225 over_threshold = true;
226 }
227
228 return std::make_tuple(-llForward, over_threshold);
229 }
230
231 // Computes forward probabilities
232 template<typename ProbT>
compute_alphas(const ProbT * log_probs,int repeats,int S,int T,const int * const e_inc,const int * const s_inc,const int * const labels,ProbT * alphas)233 ProbT CpuCTC<ProbT>::compute_alphas(const ProbT* log_probs, int repeats, int S, int T,
234 const int* const e_inc,
235 const int* const s_inc,
236 const int* const labels,
237 ProbT* alphas) {
238
239 int start = (((S /2) + repeats - T) < 0) ? 0 : 1,
240 end = S > 1 ? 2 : 1;
241
242 for (int i = start; i < end; ++i) {
243 alphas[i] = log_probs[labels[i]];
244 }
245
246 for(int t = 1; t < T; ++t) {
247 int remain = (S / 2) + repeats - (T - t);
248 if(remain >= 0)
249 start += s_inc[remain];
250 if(t <= (S / 2) + repeats)
251 end += e_inc[t - 1];
252 int startloop = start;
253 int idx1 = t * S, idx2 = (t - 1) * S, idx3 = t * (alphabet_size_ * minibatch_);
254
255 if (start == 0) {
256 alphas[idx1] = alphas[idx2] + log_probs[blank_label_ + idx3];
257 startloop += 1;
258 }
259
260 for(int i = startloop; i < end; ++i) {
261 ProbT prev_sum = ctc_helper::log_plus<ProbT>()(alphas[i + idx2], alphas[(i-1) + idx2]);
262
263 // Skip two if not on blank and not on repeat.
264 if (labels[i] != blank_label_ && i != 1 && labels[i] != labels[i-2])
265 prev_sum = ctc_helper::log_plus<ProbT>()(prev_sum, alphas[(i-2) + idx2]);
266
267 alphas[i + idx1] = prev_sum + log_probs[labels[i] + idx3];
268 }
269 }
270
271 ProbT loglike = ctc_helper::neg_inf<ProbT>();
272 for(int i = start; i < end; ++i) {
273 loglike = ctc_helper::log_plus<ProbT>()(loglike, alphas[i + (T - 1) * S]);
274 }
275
276 return loglike;
277 }
278
279 // Starting from T, we sweep backward over the alpha array computing one column
280 // of betas as we go. At each position we can update product alpha * beta and then
281 // sum into the gradient associated with each label.
282 // NOTE computes gradient w.r.t UNNORMALIZED final layer activations.
283 // Assumed passed in grads are already zeroed!
284 template<typename ProbT>
compute_betas_and_grad(ProbT * grad,const ProbT * const log_probs,ProbT log_partition,int repeats,int S,int T,const int * const e_inc,const int * const s_inc,const int * const labels,ProbT * alphas,ProbT * betas,ProbT * output)285 ProbT CpuCTC<ProbT>::compute_betas_and_grad(ProbT* grad, const ProbT* const log_probs,
286 ProbT log_partition, int repeats,
287 int S, int T, const int* const e_inc,
288 const int* const s_inc,
289 const int* const labels,
290 ProbT* alphas,
291 ProbT* betas,
292 ProbT* output) {
293 int start = S > 1 ? (S - 2) : 0,
294 end = (T > (S / 2) + repeats) ? S : S-1;
295
296 std::fill(output, output + alphabet_size_, ctc_helper::neg_inf<ProbT>());
297
298 //set the starting values in the beta column at the very right edge
299 for (int i = start; i < end; ++i) {
300 betas[i] = log_probs[labels[i] + (T - 1) * (alphabet_size_ * minibatch_)];
301
302 //compute alpha * beta in log space at this position in (S, T) space
303 alphas[i + (T - 1) * S] += betas[i];
304
305 //update the gradient associated with this label
306 //essentially performing a reduce-by-key in a sequential manner
307 output[labels[i]] =
308 ctc_helper::log_plus<ProbT>()(alphas[i + (T - 1) * S], output[labels[i]]);
309 }
310
311 //update the gradient wrt to each unique label
312 for (int i = 0; i < alphabet_size_; ++i) {
313 int idx3 = (T - 1) * alphabet_size_ * minibatch_ + i;
314
315 if (output[i] == 0.0 || output[i] == ctc_helper::neg_inf<ProbT>() ||
316 log_probs[idx3] == ctc_helper::neg_inf<ProbT>()) {
317 grad[idx3] = std::exp(log_probs[idx3]);
318 } else {
319 grad[idx3] = std::exp(log_probs[idx3])
320 - std::exp(output[i] - log_probs[idx3] - log_partition);
321 }
322 }
323
324 //loop from the second to last column all the way to the left
325 for(int t = T - 2; t >= 0; --t) {
326 int remain = (S / 2) + repeats - (T - t);
327 if(remain >= -1)
328 start -= s_inc[remain + 1];
329 if(t < (S / 2) + repeats)
330 end -= e_inc[t];
331
332 int endloop = end == S ? end - 1 : end;
333 int idx1 = t * S, idx3 = t * (alphabet_size_ * minibatch_);
334
335 std::fill(output, output + alphabet_size_, ctc_helper::neg_inf<ProbT>());
336
337 for(int i = start; i < endloop; ++i) {
338 ProbT next_sum = ctc_helper::log_plus<ProbT>()(betas[i], betas[(i+1)]);
339 // Skip two if not on blank and not on repeat.
340 if (labels[i] != blank_label_ && i != (S-2) && labels[i] != labels[i+2]){
341 next_sum = ctc_helper::log_plus<ProbT>()(next_sum, betas[(i+2)]);
342 }
343 betas[i] = next_sum + log_probs[labels[i] + idx3];
344
345 //compute alpha * beta in log space
346 alphas[i + idx1] += betas[i];
347
348 //update the gradient associated with this label
349 output[labels[i]] =
350 ctc_helper::log_plus<ProbT>()(alphas[i + idx1], output[labels[i]]);
351 }
352
353 if (end == S) {
354 betas[(S-1)] = betas[(S-1)] + log_probs[blank_label_ + idx3];
355 alphas[(S-1) + idx1] += betas[(S-1)];
356
357 output[labels[S-1]] =
358 ctc_helper::log_plus<ProbT>()(alphas[S-1 + idx1], output[labels[S-1]]);
359 }
360
361 //go over the unique labels and compute the final grad
362 // wrt to each one at this time step
363 for (int i = 0; i < alphabet_size_; ++i) {
364
365 if (output[i] == 0.0 || output[i] == ctc_helper::neg_inf<ProbT>() ||
366 log_probs[idx3] == ctc_helper::neg_inf<ProbT>()) {
367 grad[idx3] = std::exp(log_probs[idx3]);
368 } else {
369 grad[idx3] = std::exp(log_probs[idx3])
370 - std::exp(output[i] - log_probs[idx3] - log_partition);
371 }
372 ++idx3;
373 }
374 }
375
376 ProbT loglike = ctc_helper::neg_inf<ProbT>();
377 for(int i = start; i < end; ++i) {
378 loglike = ctc_helper::log_plus<ProbT>()(loglike, betas[i]);
379 }
380
381 return loglike;
382 }
383
384 template<typename ProbT>
385 ctcStatus_t
cost_and_grad(const ProbT * const activations,ProbT * grads,ProbT * costs,const int * const flat_labels,const int * const label_lengths,const int * const input_lengths)386 CpuCTC<ProbT>::cost_and_grad(const ProbT* const activations,
387 ProbT *grads,
388 ProbT *costs,
389 const int* const flat_labels,
390 const int* const label_lengths,
391 const int* const input_lengths) {
392 if (activations == nullptr ||
393 grads == nullptr ||
394 costs == nullptr ||
395 flat_labels == nullptr ||
396 label_lengths == nullptr ||
397 input_lengths == nullptr
398 )
399 return CTC_STATUS_INVALID_VALUE;
400
401 ProbT* log_probs = static_cast<ProbT *>(workspace_);
402
403 int maxT = *std::max_element(input_lengths, input_lengths + minibatch_);
404
405 size_t bytes_used = sizeof(ProbT) * minibatch_ * alphabet_size_ * maxT;
406
407 //per minibatch memory
408 size_t per_minibatch_bytes = 0;
409
410 int maxL = *std::max_element(label_lengths, label_lengths + minibatch_);;
411 int maxS = 2 * maxL + 1;
412
413 //output
414 per_minibatch_bytes += sizeof(float) * alphabet_size_;
415
416 //alphas
417 per_minibatch_bytes += sizeof(float) * maxS * maxT;
418
419 //betas
420 per_minibatch_bytes += sizeof(float) * maxS;
421
422 //labels w/blanks, e_inc, s_inc
423 per_minibatch_bytes += 3 * sizeof(int) * maxS;
424
425 log_softmax(activations, log_probs, input_lengths);
426
427 #pragma omp parallel for
428 for (int mb = 0; mb < minibatch_; ++mb) {
429 const int T = input_lengths[mb]; // Length of utterance (time)
430 const int L = label_lengths[mb]; // Number of labels in transcription
431
432 bool mb_status;
433
434 std::tie(costs[mb], mb_status) =
435 cost_and_grad_kernel(grads + mb * alphabet_size_,
436 log_probs + mb * alphabet_size_,
437 flat_labels + std::accumulate(label_lengths, label_lengths + mb, 0),
438 T, L, mb,
439 bytes_used + mb * per_minibatch_bytes);
440 }
441
442 return CTC_STATUS_SUCCESS;
443 }
444
445 template<typename ProbT>
score_forward(const ProbT * const activations,ProbT * costs,const int * const flat_labels,const int * const label_lengths,const int * const input_lengths)446 ctcStatus_t CpuCTC<ProbT>::score_forward(const ProbT* const activations,
447 ProbT* costs,
448 const int* const flat_labels,
449 const int* const label_lengths,
450 const int* const input_lengths) {
451 if (activations == nullptr ||
452 costs == nullptr ||
453 flat_labels == nullptr ||
454 label_lengths == nullptr ||
455 input_lengths == nullptr
456 )
457 return CTC_STATUS_INVALID_VALUE;
458
459 ProbT* log_probs = static_cast<ProbT *>(workspace_);
460
461 int maxT = *std::max_element(input_lengths, input_lengths + minibatch_);
462
463 size_t bytes_used = sizeof(ProbT) * minibatch_ * alphabet_size_ * maxT;
464
465 //per minibatch memory
466 size_t per_minibatch_bytes = 0;
467
468 int maxL = *std::max_element(label_lengths, label_lengths + minibatch_);
469 int maxS = 2 * maxL + 1;
470
471 //output
472 per_minibatch_bytes += sizeof(float) * alphabet_size_;
473
474 //alphas
475 per_minibatch_bytes += sizeof(float) * maxS * maxT;
476
477 //betas
478 per_minibatch_bytes += sizeof(float) * maxS;
479
480 //labels w/blanks, e_inc, s_inc
481 per_minibatch_bytes += 3 * sizeof(int) * maxS;
482
483 log_softmax(activations, log_probs, input_lengths);
484
485 #pragma omp parallel for
486 for (int mb = 0; mb < minibatch_; ++mb) {
487 const int T = input_lengths[mb]; // Length of utterance (time)
488 const int L = label_lengths[mb]; // Number of labels in transcription
489 const int S = 2*L + 1; // Number of labels with blanks
490
491 CpuCTC_metadata ctcm(L, S, T, mb, alphabet_size_, workspace_,
492 bytes_used + mb * per_minibatch_bytes, blank_label_,
493 flat_labels + std::accumulate(label_lengths, label_lengths + mb, 0));
494
495
496 if (L + ctcm.repeats > T)
497 costs[mb] = ProbT(0);
498 else {
499 costs[mb] = -compute_alphas(log_probs + mb * alphabet_size_, ctcm.repeats, S, T,
500 ctcm.e_inc, ctcm.s_inc, ctcm.labels_w_blanks,
501 ctcm.alphas);
502 }
503
504 }
505
506 return CTC_STATUS_SUCCESS;
507 }
508
509 } // mxnet_warpctc
510